In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
from tqdm.notebook import tqdm, trange

In [3]:
from torch import nn
import torch

In [4]:
from torch import Tensor

In [5]:
import numpy as np

In [6]:
from torch.utils.data import Dataset

In [7]:
import math

In [8]:
from transformers import AutoTokenizer

In [9]:
class MTTrainDataset(Dataset):
    
    def __init__(
        self, 
        train_path: str, 
        dic_path: str,
        en_tokenizer: AutoTokenizer,
        ch_tokenizer: AutoTokenizer,
        truncate: int=384,
        pad_multiple: int=8
    ):
        self.terms = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in open(dic_path).read().split("\n")[:-1]
        ]
        self.data = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in filter(
                lambda x: len(x) <= truncate,
                open(train_path).read().split("\n")[:-1]
            )
        ]
        self.en_tokenizer = en_tokenizer
        self.ch_tokenizer = ch_tokenizer
        self.en_tokenizer.add_tokens([
            term["en"] for term in self.terms
        ])
        self.ch_tokenizer.add_tokens([
            term["zh"] for term in self.terms
        ])
        self.pad_multiple = pad_multiple
                
    def __len__(self) -> int:
        return len(self.data)
    
    def __getitem__(self, index: int) -> dict:
        def pad(x, pad_multiple, pad_token_id=0):
            return x + [pad_token_id] * (pad_multiple - len(x) % pad_multiple)
        return {
            "en": pad(self.en_tokenizer.encode(self.data[index]["en"]), self.pad_multiple),
            "zh": pad(self.ch_tokenizer.encode(self.data[index]["zh"]), self.pad_multiple),
        }
    
    def get_raw(self, index: int) -> dict:
        return self.data[index]

In [10]:
def collect_fn(batch: dict) -> tuple[Tensor, Tensor]:
    # pad the batch
    pad_token_id = 0
    src = [torch.tensor(item["en"]) for item in batch]
    trg = [torch.tensor(item["zh"]) for item in batch]
    src = torch.nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=pad_token_id)
    trg = torch.nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=pad_token_id)
    return src, trg

In [11]:
device = "mps"

In [12]:
en_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", cache_dir="../../../cache")
ch_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-chinese", cache_dir="../../../cache")

In [13]:
len(ch_tokenizer.vocab)

21128

In [14]:
class AddAndNorm(nn.Module):
    def __init__(self, dim: int, dropout: float):
        super(AddAndNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        return self.norm(x + self.dropout(y))

In [15]:
class SelfAttn(nn.Module):
    def __init__(self, dim: int, dropout: float):
        super(SelfAttn, self).__init__()
        
        self.d = dim
        
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, y: Tensor, mask: Tensor | None=None) -> int:
        # x: (batch, len1, dim)
        # y: (batch, len2, dim)
        q = self.q(x)
        k = self.k(y)
        v = self.v(y)
        attn = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(self.d)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e-32)
        attn = torch.nn.functional.softmax(attn, dim=-1)
        return torch.bmm(self.dropout(attn), v)

In [16]:
class MultiHeadAttn(nn.Module):
    
    def __init__(
        self, 
        dim: int, 
        heads: int, 
        dropout: float
    ):
        super(MultiHeadAttn, self).__init__()
        
        self.heads = heads
        
        self.attn_heads = nn.ModuleList([
            SelfAttn(dim, dropout) for _ in range(heads)
        ])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(dim * heads, dim)

    def forward(self, x: Tensor, y: Tensor, mask: Tensor | None=None) -> int:
        attn_outs = [
            attn(x, y, mask) for attn in self.attn_heads
        ]
        return self.fc(
            torch.cat(attn_outs, dim=-1)
        )

In [17]:
class FeedForward(nn.Module):
    
    def __init__(self, dim: int):
        super(FeedForward, self).__init__()
        self.dim = dim
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.ReLU(),
            nn.Linear(dim * 4, dim)
        )
    
    def forward(self, x: Tensor) -> Tensor:
        return self.ff(x)

In [18]:
class TransformerBlock(nn.Module):
    
    def __init__(
        self, 
        dim: int, 
        heads: int, 
        dropout: float
    ):
        super(TransformerBlock, self).__init__()

        self.attn = MultiHeadAttn(dim, heads, dropout)
        self.add_norm_1 = AddAndNorm(dim, dropout)
        self.ff = FeedForward(dim)
        self.add_norm_2 = AddAndNorm(dim, dropout)
    
    def forward(self, x: Tensor, y: Tensor, mask: Tensor | None=None) -> int:
        x = self.add_norm_1(x, self.attn(x, y, mask))
        x = self.add_norm_2(x, self.ff(x))
        return x

In [19]:
class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, max_len=512):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, embedding_dim]
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x  # [seq_len, batch_size, embedding_dim]

In [20]:
class InputBlock(nn.Module):
    
    def __init__(
        self, 
        vocab: int, 
        dim: int,
    ):
        super(InputBlock, self).__init__()
        
        self.token_emb = nn.Embedding(vocab, dim)
        self.pos_emb = PositionalEncoding(dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.token_emb(x)
        x = self.pos_emb(x)
        return x

In [21]:
class DecoderBlock(nn.Module):
    
    def __init__(
        self, 
        dim: int, 
        heads: int, 
        dropout: float
    ):
        super(DecoderBlock, self).__init__()
        
        self.attn = MultiHeadAttn(dim, heads, dropout)
        self.add_and_norm1 = AddAndNorm(dim, dropout)
        self.encoder_decoder_attn = MultiHeadAttn(dim, heads, dropout)
        self.add_and_norm2 = AddAndNorm(dim, dropout)
        self.ff = FeedForward(dim)
        self.add_and_norm3 = AddAndNorm(dim, dropout)
    
    def forward(self, x: Tensor, enc_out: Tensor, trg_mask: Tensor | None=None) -> int:
        x = self.add_and_norm1(x, self.attn(x, x, trg_mask))
        x = self.add_and_norm2(
            x, 
            self.encoder_decoder_attn(
                x,
                enc_out,
                trg_mask
            )
        )
        x = self.add_and_norm3(x, self.ff(x))
        return x

In [22]:
class Transformer(nn.Module):
    
    def __init__(
        self,
        vocab_src: int,
        vocab_trg: int,
        dim: int,
        heads: int,
        layers: int,
        dropout: float
    ):
        super(Transformer, self).__init__()
        self.dim = dim
        
        self.input_src = InputBlock(vocab_src, dim)
        self.transformers_src = nn.ModuleList([
            TransformerBlock(dim, heads, dropout) for _ in range(layers)
        ])
        self.transformers_trg = nn.ModuleList([
            DecoderBlock(dim, heads, dropout) for _ in range(layers)
        ])
        self.fc = nn.Linear(dim, vocab_trg)
    
    def generate_mask(self, src: Tensor) -> Tensor:
        src_len = src.size(1)
        src_mask = (src != 0).unsqueeze(1).expand(-1, src_len, -1)
        src_mask = src_mask & src_mask.transpose(1, 2)
        return src_mask

    def forward(self, src: Tensor, trg_seq_len: int) -> Tensor:
        src_mask = self.generate_mask(src)
        embed_src = self.input_src(src)
        batch_size = src.shape[0]
        enc_out = embed_src
        for tf in self.transformers_src:
            enc_out = tf(enc_out, enc_out, src_mask)
        dec_out = torch.full((batch_size, trg_seq_len, self.dim), 0.0).to(device)
        for tf in self.transformers_trg:
            dec_out = tf(dec_out, enc_out, None)
        return self.fc(dec_out)

In [23]:
train_loader = torch.utils.data.DataLoader(
    MTTrainDataset(
        "./data/train.txt", 
        "./data/en-zh.dic",
        en_tokenizer,
        ch_tokenizer,
    ), 
    batch_size=2, 
    shuffle=True, 
    collate_fn=collect_fn
)

In [24]:
model = Transformer(
    len(en_tokenizer.vocab), len(ch_tokenizer.vocab), 512, 8, 6, 0.2
).to(device)

In [25]:
with torch.no_grad():
    for src, trg in train_loader:
        src = src.to(device, dtype=torch.long)
        print(trg.shape[1])
        out = model(src, trg.shape[1])
        print(out.shape)
        break

64
torch.Size([2, 64, 23148])


In [26]:
optim = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optim, 1, gamma=0.9)

In [27]:
len(train_loader)

72892

In [28]:
epochs = 1
logging_steps = 10
checkpoint_steps = 5000

In [29]:
from tqdm.notebook import tqdm, trange

In [30]:
model.train()
pass

In [31]:
loss_logging = []
loss_record = []

In [32]:
print(len(ch_tokenizer.vocab), len(en_tokenizer.vocab))

23148 31988


In [33]:
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
for epoch in trange(epochs, desc="Epoch", leave=False):
    for i, (src, trg) in enumerate(tqdm(train_loader, desc="Iteration", leave=False)):
        src = src.to(device)
        trg = trg.to(device)
        optim.zero_grad()
        out = model(src, trg.shape[1])
        # out: [batch_size, seq_len, zh_vocab_size]
        # trg: [batch_size, seq_len]
        loss = loss_fn(out.view(-1, len(ch_tokenizer.vocab)), trg.view(-1))
        loss.backward()
        optim.step()
        scheduler.step()
        del src, trg, out
        loss_record.append(loss.item())
        loss_logging.append(
            {
                "epoch": epoch,
                "step": i,
                "loss": loss.item()
            }
        )
        del loss
        torch.mps.empty_cache()
        if (i + 1) % logging_steps == 0:
            print(
                f"Avg Loss: {sum(loss_record[-logging_steps:]) / logging_steps}"
            )
        if i % checkpoint_steps == 0:
            torch.save(model.state_dict(), f"model_{i}_{epoch}.pth")

torch.save(model.state_dict(), f"model.pth")

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/72892 [00:00<?, ?it/s]

Avg Loss: 9.512647724151611
Avg Loss: 8.986001777648926
Avg Loss: 8.937208557128907
Avg Loss: 8.761358261108398
Avg Loss: 8.939262962341308
Avg Loss: 8.763661766052246
Avg Loss: 8.838569355010986
Avg Loss: 8.900190258026123
Avg Loss: 8.849320030212402
Avg Loss: 8.82738218307495


In [None]:
import json

In [None]:
with open("loss.json", "w") as f:
    json.dump(loss_logging, f)

NameError: name 'loss_logging' is not defined