In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
from tqdm.notebook import tqdm, trange

In [3]:
from torch import nn
import torch

In [4]:
import numpy as np

In [5]:
from torch.utils.data import Dataset

In [6]:
import math

In [7]:
from transformers import AutoTokenizer

In [8]:
class MTTrainDataset(Dataset):
    
    def __init__(
        self, 
        train_path: str, 
        dic_path: str,
        en_tokenizer: AutoTokenizer,
        ch_tokenizer: AutoTokenizer,
        truncate: int=384,
        pad_multiple: int=8
    ):
        self.terms = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in open(dic_path).read().split("\n")[:-1]
        ]
        self.data = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in filter(
                lambda x: len(x) <= truncate,
                open(train_path).read().split("\n")[:-1]
            )
        ]
        self.en_tokenizer = en_tokenizer
        self.ch_tokenizer = ch_tokenizer
        self.en_tokenizer.add_tokens([
            term["en"] for term in self.terms
        ])
        self.ch_tokenizer.add_tokens([
            term["zh"] for term in self.terms
        ])
        self.pad_multiple = pad_multiple
                
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, index: int) -> dict:
        def pad(x, pad_multiple, pad_token_id=0):
            return x + [pad_token_id] * (pad_multiple - len(x) % pad_multiple)
        return {
            "en": pad(self.en_tokenizer.encode(self.data[index]["en"]), self.pad_multiple),
            "zh": pad(self.ch_tokenizer.encode(self.data[index]["zh"]), self.pad_multiple),
        }
    
    def get_raw(self, index: int) -> dict:
        return self.data[index]

In [9]:
def collect_fn(batch: dict) -> tuple[torch.Tensor, torch.Tensor]:
    # pad the batch
    pad_token_id = 0
    src = [torch.tensor(item["en"]) for item in batch]
    trg = [torch.tensor(item["zh"]) for item in batch]
    src = torch.nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=pad_token_id)
    trg = torch.nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=pad_token_id)
    return src, trg

In [10]:
device = "mps"

In [11]:
class SelfAttention(nn.Module):
    def __init__(self, embed: int, d: int):
        super(SelfAttention, self).__init__()
        self.Q = nn.Linear(embed, d)
        self.K = nn.Linear(embed, d)
        self.V = nn.Linear(embed, d)
        self.d = d
    
    def forward(self, x: torch.tensor, mask: torch.tensor) -> torch.tensor:
        Q = self.Q(x)
        K = self.K(x)
        V = self.V(x)

        attn = torch.matmul(Q, K.transpose(-2, -1)) / (self.d ** 0.5)
        attn = torch.bmm(mask, attn)
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, V)
        return out

In [12]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, d: int, out_dim: int, heads: int):
        super(MultiHeadAttention, self).__init__()
        self.heads = heads
        self.d = d
        self.embed = embed_dim
        self.attns = nn.ModuleList([
            SelfAttention(embed_dim, d) for _ in range(heads)
        ])
        self.W = nn.Linear(d * heads, out_dim)
    
    def forward(self, x: torch.tensor, mask: torch.tensor) -> torch.tensor:
        attns = torch.stack([attn(x, mask) for attn in self.attns])
        out = attns.permute(1, 2, 3, 0)
        out = out.reshape(out.shape[0], out.shape[1], -1)
        out = self.W(out)
        return out

In [13]:
class AddAndNorm(nn.Module):
    def __init__(self, dim: int):
        super(AddAndNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)
    
    def forward(self, x: torch.tensor, y: torch.tensor) -> torch.tensor:
        return self.norm(x + y)

In [14]:
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim: int, max_len: int=512):
        super(PositionalEncoding, self).__init__()
        # Create a matrix of shape (max_len, d_model)
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        # Register buffer ensures 'pe' is not considered a model parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe[:x.size(0), :]

In [15]:
class InputBlock(nn.Module):
    
    def __init__(self, embed_size: int, vocab: int, max_length: int=512):
        super(InputBlock, self).__init__()
        self.embed_size = embed_size
        self.max_length = max_length
        self.embed = nn.Embedding(vocab, embed_size)
        self.pos_enc = PositionalEncoding(embed_size, max_length)
    
    def forward(self, x: torch.tensor) -> torch.tensor:
        """
        Args:
            x: torch.tensor, shape (batch_size, seq_len)
        Returns:
            torch.tensor, shape (batch_size, seq_len, embed_size)
        """
        x = self.embed(x)
        x = self.pos_enc(x) + x
        return x

In [16]:
class TransformerBlock(nn.Module):
    
    def __init__(self, in_dim: int, out_dim: int, attn_d: int, heads: int):
        super(TransformerBlock, self).__init__()
        self.attn = MultiHeadAttention(in_dim, attn_d, out_dim, heads)
        self.add_norm_1 = AddAndNorm(out_dim)
        self.ff = nn.Sequential(
            nn.Linear(out_dim, out_dim * 4),
            nn.ReLU(),
            nn.Linear(out_dim * 4, out_dim)
        )
        self.add_norm_2 = AddAndNorm(out_dim)
    
    def forward(self, x: torch.tensor, mask: torch.tensor) -> torch.tensor:
        """
        Args:
            x: torch.tensor, shape (batch_size, seq_len, embed_size)
            mask: torch.tensor, shape (batch_size, seq_len, seq_len)
        Returns:
            torch.tensor, shape (batch_size, seq_len, embed_size)
        """
        x = self.add_norm_1(x, self.attn(x, mask))
        x = self.add_norm_2(x, self.ff(x))
        return x

In [17]:
def get_mask(x: torch.tensor, pad_token_id=0) -> torch.tensor:
    """
    Args:
        x: torch.tensor, shape (batch_size, seq_len)
    Returns:
        torch.tensor, shape (batch_size, seq_len, seq_len)
    """
    # mask: (batch_size, seq_len, seq_len)
    mask = torch.ones(x.shape[0], x.shape[1], x.shape[1])
    mask = mask.to(device)
    # pad_positions: (batch_size, seq_len)
    pad_positions = (x == pad_token_id).nonzero()
    mask[
        pad_positions.unsqueeze(1).expand(-1, x.shape[1], -1),
    ] = float("-inf")
    mask[
        pad_positions.unsqueeze(2).expand(-1, -1, x.shape[1]),
    ] = float("-inf")
    return mask

In [18]:
class Encoder(nn.Module):
    
    def __init__(self, embed_dim: int, vocab: int, heads: int, num_layers: int):
        super(Encoder, self).__init__()
        self.input_block = InputBlock(embed_dim, vocab)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, embed_dim, embed_dim, heads) for _ in range(num_layers)
        ])
    
    def forward(self, x: torch.tensor, pad_token_id: int=0) -> torch.tensor:
        """
        Args:
            x: torch.tensor, shape (batch_size, seq_len)
        Returns:
            torch.tensor, shape (batch_size, seq_len, embed_size)
        """
        mask = get_mask(x, pad_token_id)
        
        x = self.input_block(x)
        
        for block in self.blocks:
            x = block(x, mask)
        return x

In [19]:
class Decoder(nn.Module):
    
    def __init__(self, embed_dim: int, vocab: int, heads: int, num_layers: int):
        super(Decoder, self).__init__()
        self.input_block = InputBlock(embed_dim, vocab)
        self.pre_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, embed_dim, embed_dim, heads) for _ in range(num_layers)
        ])
        self.parallel_block = TransformerBlock(
            embed_dim, embed_dim, embed_dim, heads
        )
        self.post_blocks = nn.ModuleList([
            TransformerBlock(embed_dim, embed_dim, embed_dim, heads) for _ in range(num_layers)
        ])
        self.lc = nn.Linear(embed_dim, vocab)
    
    def forward(self, x: torch.tensor, encoder_out: torch.tensor, pad_token_id: int=0) -> torch.tensor:
        """
        Args:
            x: torch.tensor, shape (batch_size, enc_seq_len)
            encoder_out: torch.tensor, shape (batch_size, dec_seq_len, embed_size)
        Returns:
            torch.tensor, shape (batch_size, dec_seq_len, vocab)
        """
        mask = get_mask(x, pad_token_id)
        
        x = self.input_block(x)
        
        for block in self.pre_blocks:
            x = block(x, mask)
        
        
        parallel_block_mask = torch.ones(
            x.shape[0], 
            x.shape[1] + encoder_out.shape[1],
            x.shape[1] + encoder_out.shape[1],
        ).to(device)
        x = self.parallel_block(
            torch.cat([x, encoder_out], dim=-2), parallel_block_mask
        )
        
        post_block_mask = torch.ones(x.shape[0], x.shape[1], x.shape[1]).to(device)
        for block in self.post_blocks:
            x = block(x, post_block_mask)
        return self.lc(x)

In [20]:
class Transformer(nn.Module):
    
    def __init__(
        self, 
        embed_dim: int, 
        encoder_vocab: int,
        decoder_vocab: int,
        heads: int, 
        num_layers: int
    ):
        super(Transformer, self).__init__()
        self.embed_dim = embed_dim
        self.encoder_vocab = encoder_vocab
        self.decoder_vocab = decoder_vocab
        self.encoder = Encoder(embed_dim, encoder_vocab, heads, num_layers)
        self.decoder = Decoder(embed_dim, decoder_vocab, heads, num_layers)
    
    def forward(
        self, 
        src: torch.tensor, 
        trg: torch.tensor, 
        eos_token_id: int,
        bos_token_id: int,
        pad_token_id: int=0) -> torch.tensor:
        """
        Args:
            src: torch.tensor, shape (batch_size, seq_len)
            trg: torch.tensor, shape (batch_size, seq_len)
        Returns:
            torch.tensor, shape (batch_size, seq_len, embed_size)
        """
        encoder_out = self.encoder(src, pad_token_id)
        batch_size = trg.shape[0]
        seq_len = trg.shape[1]
        outputs_predict = torch.zeros(batch_size, seq_len).to(device, dtype=torch.long)
        outputs = torch.zeros(batch_size, seq_len, self.decoder_vocab).to(device, dtype=torch.long)
        outputs[:, 0, bos_token_id] = 1
        # first input to the decoder is the <bos> token
        outputs_predict[:, 0].fill_(bos_token_id)
        
        for t in range(1, seq_len):
            next_seq = self.decoder(outputs_predict[:, :t], encoder_out, pad_token_id)
            outputs[:, t] = next_seq[:, t]
            outputs_predict[:, t] = next_seq[:, t].argmax(dim=-1)
            if (outputs_predict[:, t] == eos_token_id).all():
                break
        
        return outputs, outputs_predict

In [21]:
en_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", cache_dir="../../../cache")
ch_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-chinese", cache_dir="../../../cache")

In [22]:
len(ch_tokenizer.vocab)

21128

In [23]:
train_loader = torch.utils.data.DataLoader(
    MTTrainDataset(
        "./data/train.txt", 
        "./data/en-zh.dic",
        en_tokenizer,
        ch_tokenizer,
    ), 
    batch_size=2, 
    shuffle=True, 
    collate_fn=collect_fn
)

In [24]:
model = Transformer(256, len(en_tokenizer), len(ch_tokenizer), 4, 2).to(device)

In [25]:
with torch.no_grad():
    for i, (src, trg) in enumerate(train_loader):
        src = src.to(device)
        trg = trg.to(device)
        out, pred = model(src, trg, ch_tokenizer.sep_token_id, ch_tokenizer.cls_token_id)
        print(out, pred)
        # print(ch_tokenizer.decode(pred[0].tolist()))
        del src, trg, out, pred
        torch.mps.empty_cache()
        break

tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]], device='mps:0') tensor([[-36028792732385280, -36028792732385280, -36028792732385280,
         -36028792732385280, -36028792732385280, -36028792732385280,
         -36028792732385280, -36028792732385280, -36028792732385280,
         -36028792732385280, -36028792732385280, -36028792732385280,
         -36028792732385280, -36028792732385280, -36028792732385280,
         -36028792732385280,                  0,                  0,
                          0,                  0,                  0,
                          0,                  0, 

In [26]:
sum(p.numel() for p in model.parameters())

31113324

In [27]:
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optim, 1, gamma=0.9)

In [28]:
len(train_loader)

72892

In [29]:
epochs = 1
logging_steps = 100
checkpoint_steps = 5000

In [30]:
from tqdm.notebook import tqdm, trange

In [31]:
model.train()
pass

In [32]:
loss_logging = []
loss_record = []

In [33]:
loss_fn = nn.CrossEntropyLoss()
for epoch in trange(epochs, desc="Epoch", leave=False):
    for i, (src, trg) in enumerate(tqdm(train_loader, desc="Iteration", leave=False)):
        src = src.to(device)
        trg = trg.to(device)
        optim.zero_grad()
        out, pred = model(src, trg, ch_tokenizer.sep_token_id, ch_tokenizer.cls_token_id)
        print(out.shape, pred.shape)
        # out: [batch_size, seq_len, zh_vocab_size]
        # trg: [batch_size, seq_len]
        loss = loss_fn(out.view(-1, len(ch_tokenizer.vocab)), trg.view(-1))
        loss.backward()
        # optim.step()
        # scheduler.step()
        # del src, trg, out
        # loss_record.append(loss.item())
        # loss_logging.append(
        #     {
        #         "epoch": epoch,
        #         "step": i,
        #         "loss": loss.item()
        #     }
        # )
        # del loss
        # torch.mps.empty_cache()
        # if (i + 1) % logging_steps == 0:
        #     print(
        #         f"Avg Loss: {sum(loss_record[-logging_steps:]) / logging_steps}"
        #     )
        #     print(
        #         f"This answer: {ch_tokenizer.decode(pred[0].tolist())}"
        #     )
        # del pred
        # if i % checkpoint_steps == 0:
        #     torch.save(model.state_dict(), f"model_{i}_{epoch}.pth")

torch.save(model.state_dict(), f"model.pth")

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/72892 [00:00<?, ?it/s]

torch.Size([2, 32, 23148]) torch.Size([2, 32])
torch.Size([2, 48, 23148]) torch.Size([2, 48])
torch.Size([2, 32, 23148]) torch.Size([2, 32])
torch.Size([2, 48, 23148]) torch.Size([2, 48])
torch.Size([2, 24, 23148]) torch.Size([2, 24])
torch.Size([2, 40, 23148]) torch.Size([2, 40])
torch.Size([2, 56, 23148]) torch.Size([2, 56])
torch.Size([2, 40, 23148]) torch.Size([2, 40])
torch.Size([2, 32, 23148]) torch.Size([2, 32])
torch.Size([2, 64, 23148]) torch.Size([2, 64])
torch.Size([2, 48, 23148]) torch.Size([2, 48])
torch.Size([2, 56, 23148]) torch.Size([2, 56])
torch.Size([2, 56, 23148]) torch.Size([2, 56])
torch.Size([2, 16, 23148]) torch.Size([2, 16])
torch.Size([2, 16, 23148]) torch.Size([2, 16])
torch.Size([2, 24, 23148]) torch.Size([2, 24])
torch.Size([2, 40, 23148]) torch.Size([2, 40])
torch.Size([2, 56, 23148]) torch.Size([2, 56])
torch.Size([2, 48, 23148]) torch.Size([2, 48])
torch.Size([2, 32, 23148]) torch.Size([2, 32])
torch.Size([2, 48, 23148]) torch.Size([2, 48])
torch.Size([2

KeyboardInterrupt: 

In [None]:
import json

In [None]:
with open("loss.json", "w") as f:
    json.dump(loss_logging, f)

NameError: name 'loss_logging' is not defined