In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch import Tensor

In [2]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first inputs.

    A table ``pe`` of shape ``[max_len, 1, embedding_dim]`` is precomputed and
    registered as a (non-trainable) buffer. Even feature columns hold sines,
    odd feature columns hold cosines of ``position * div_term``.

    Args:
        embedding_dim: size of the feature dimension (may be odd or even).
        max_len: number of positions for which the table is precomputed.
    """

    def __init__(self, embedding_dim: int, max_len: int=512):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-np.log(10000.0) / embedding_dim))

        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embedding_dim there is one cosine column fewer than sine
        # columns, so trim the frequencies to the number of odd columns.
        # For an even embedding_dim the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])

        pe = pe.unsqueeze(1)  # [max_len, 1, embedding_dim]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding of positions ``0 .. x.size(0) - 1`` to ``x``.

        Positions are taken along dim 0 of ``x``; the singleton middle
        dimension of ``pe`` broadcasts over dim 1.
        """
        x = x + self.pe[:x.size(0), :]
        return x  # [seq_len, batch_size, embedding_dim]

In [3]:
class InputBlock(nn.Module):
    """Token embedding + positional encoding + dropout.

    Args:
        embed_d: embedding (model) dimension.
        src_vocab: vocabulary size of the embedding table.
        max_len: number of positions passed to ``PositionalEncoding``.
        dropout: dropout probability applied to the summed embeddings.
    """

    def __init__(self, embed_d, src_vocab, max_len=512, dropout=0.1):
        super(InputBlock, self).__init__()
        self.embed = nn.Embedding(src_vocab, embed_d)
        self.pe = PositionalEncoding(embed_d, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Embed token ids given as ``[batch_size, seq_len]`` (the layout
        ``Transformer.forward`` passes in) and return
        ``[batch_size, seq_len, embed_d]``.
        """
        x = self.embed(x)  # [batch_size, seq_len, embed_d]
        # PositionalEncoding reads the position from dim 0 (sequence-first).
        # Feeding it the batch-first tensor directly would pick the encoding by
        # batch index and broadcast it over the whole sequence, so switch to
        # sequence-first for the call and back afterwards.
        x = self.pe(x.transpose(0, 1)).transpose(0, 1)
        return self.dropout(x)

In [4]:
class AddAndNorm(nn.Module):
    """Residual connection followed by layer normalisation.

    Computes ``LayerNorm(x + Dropout(y))`` over the last ``embed_d`` features.
    """

    def __init__(self, embed_d, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(embed_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y):
        # x is the skip path, y the sub-layer output that gets dropout.
        residual_sum = x + self.dropout(y)
        return self.norm(residual_sum)

In [5]:
class Attn(nn.Module):
    """Single scaled dot-product attention head with its own Q/K/V projections.

    Queries are projected from ``x``; keys and values are projected from ``y``.
    All three projections map ``d -> d``.
    """

    def __init__(self, d: int, dropout: float):
        super().__init__()
        self.scale = 1 / np.sqrt(d)
        self.q = nn.Linear(d, d)
        self.k = nn.Linear(d, d)
        self.v = nn.Linear(d, d)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y, mask=None):
        """Attend from ``x`` ([batch_size, len_x, d]) to ``y`` ([batch_size, len_y, d]).

        ``mask`` (optional) is a boolean tensor; positions where it is True are
        filled with -1e9 before the softmax.
        """
        queries = self.q(x)   # [batch_size, len_x, d]
        keys = self.k(y)      # [batch_size, len_y, d]
        values = self.v(y)    # [batch_size, len_y, d]

        scores = torch.bmm(queries, keys.transpose(1, 2)) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)

        weights = self.dropout(self.softmax(scores))
        return torch.bmm(weights, values)

In [6]:
class MultiHeadAttn(nn.Module):
    """Several independent ``Attn`` heads whose outputs are concatenated and projected.

    Every head works on the full width ``d``; the concatenation therefore has
    ``d * num_heads`` features, which ``lc`` maps back to ``d``.
    ``self.dropout`` is created here but not applied in ``forward``.
    """

    def __init__(self, d: int, num_heads: int, dropout: float):
        super().__init__()
        self.num_heads = num_heads
        self.d = d
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(Attn(d, dropout))
        self.heads = nn.ModuleList(head_modules)
        self.lc = nn.Linear(d * num_heads, d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y, mask=None):
        # Run every head on the same inputs/mask, then merge along the feature axis.
        per_head = []
        for head in self.heads:
            per_head.append(head(x, y, mask))
        merged = torch.cat(per_head, dim=-1)
        return self.lc(merged)

In [7]:
class FF(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear, then dropout."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),   # expand
            nn.ReLU(),
            nn.Linear(hidden_dim, dim),   # project back
        ]
        self.sq = nn.Sequential(*layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.sq(x)
        return self.dropout(hidden)

In [8]:
class EncBlock(nn.Module):
    """Encoder layer: multi-head attention and feed-forward, each wrapped in add & norm."""

    def __init__(self, d: int, num_heads: int, hidden_dim: int, dropout: float):
        super().__init__()
        self.mha = MultiHeadAttn(d, num_heads, dropout)
        self.ff = FF(d, hidden_dim, dropout)
        self.add_norm1 = AddAndNorm(d, dropout)
        self.add_norm2 = AddAndNorm(d, dropout)

    def forward(self, x, y, mask=None):
        # Sub-layer 1: attention from x to y, with a residual connection on x.
        attended = self.mha(x, y, mask)
        x = self.add_norm1(x, attended)
        # Sub-layer 2: feed-forward, with a residual connection on the result above.
        transformed = self.ff(x)
        return self.add_norm2(x, transformed)

In [9]:
class DecBlock(nn.Module):
    """Decoder layer: self-attention, cross-attention and feed-forward,
    each followed by its own add & norm.
    """

    def __init__(self, d: int=512, num_heads: int=8, hidden_dim: int=1024, dropout: float=0.1):
        super().__init__()
        self.mha = MultiHeadAttn(d, num_heads, dropout)
        self.add_and_norm1 = AddAndNorm(d, dropout)
        self.cross_mha = MultiHeadAttn(d, num_heads, dropout)
        self.add_and_norm2 = AddAndNorm(d, dropout)
        self.ff = FF(d, hidden_dim, dropout)
        self.add_and_norm3 = AddAndNorm(d, dropout)

    def forward(self, x, y, trg_mask=None, cross_mask=None):
        # 1) self-attention over the decoder input x
        self_attended = self.mha(x, x, trg_mask)
        x = self.add_and_norm1(x, self_attended)
        # 2) cross-attention: queries from x, keys/values from y
        cross_attended = self.cross_mha(x, y, cross_mask)
        x = self.add_and_norm2(x, cross_attended)
        # 3) position-wise feed-forward
        return self.add_and_norm3(x, self.ff(x))

In [10]:
def get_mask(src: Tensor, trg: Tensor, pad_id: int = 0) -> tuple[Tensor, Tensor, Tensor]:
    """Build the boolean attention masks for a batch (True = masked out).

    Args:
        src: source token ids, indexed as [batch, src_len].
        trg: target token ids, indexed as [batch, trg_len].
        pad_id: token id treated as padding (defaults to 0, the previous
            hard-coded value).

    Returns:
        ``(src_mask, trg_mask, cross_mask)`` with shapes
        ``[batch, src_len, src_len]``, ``[batch, trg_len, trg_len]`` and
        ``[batch, trg_len, src_len]``. An entry is True when either the query
        position or the key position is padding; ``trg_mask`` additionally
        masks every key position that lies after the query position.
    """
    src_pad = src == pad_id  # [batch, src_len]
    trg_pad = trg == pad_id  # [batch, trg_len]

    # unsqueeze(1) marks padded keys (columns), unsqueeze(2) padded queries (rows);
    # broadcasting the OR gives the full square / rectangular mask.
    src_mask = src_pad.unsqueeze(1) | src_pad.unsqueeze(2)

    trg_len = trg.shape[1]
    look_ahead_mask = torch.triu(
        torch.ones((trg_len, trg_len), device=trg.device), diagonal=1
    ).bool()
    trg_mask = trg_pad.unsqueeze(1) | trg_pad.unsqueeze(2) | look_ahead_mask.unsqueeze(0)

    cross_mask = src_pad.unsqueeze(1) | trg_pad.unsqueeze(2)

    return src_mask, trg_mask, cross_mask

In [11]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from the blocks defined in this notebook.

    Args:
        src_vocab: source vocabulary size.
        tgt_vocab: target vocabulary size (also the output logit width).
        d: model / embedding dimension.
        num_heads: attention heads per multi-head block.
        hidden_dim: hidden width of the feed-forward sub-layers.
        num_enc: number of encoder blocks.
        num_dec: number of decoder blocks.
        dropout: dropout probability used by every sub-module, including the
            two input (embedding) blocks.
    """

    def __init__(self, src_vocab: int, tgt_vocab: int, d: int=512, num_heads: int=8, hidden_dim: int=1024, num_enc: int=6, num_dec: int=6, dropout: float=0.1):
        super(Transformer, self).__init__()
        # Forward `dropout` explicitly; otherwise the input blocks would keep
        # their own default and ignore the value requested by the caller.
        self.src_embed = InputBlock(d, src_vocab, dropout=dropout)
        self.tgt_embed = InputBlock(d, tgt_vocab, dropout=dropout)
        self.encs = nn.ModuleList([
            EncBlock(d, num_heads, hidden_dim, dropout) for _ in range(num_enc)
        ])
        self.decs = nn.ModuleList([
            DecBlock(d, num_heads, hidden_dim, dropout) for _ in range(num_dec)
        ])
        self.fc = nn.Linear(d, tgt_vocab)

    def forward(self, src, trg):
        """Return the output of the final linear layer (width ``tgt_vocab``)
        for every target position.

        src: (batch_size, src_len) token ids
        trg: (batch_size, trg_len) token ids
        """
        # Masks must be derived from the raw token ids, before embedding.
        src_mask, trg_mask, cross_mask = get_mask(src, trg)
        src = self.src_embed(src)
        trg = self.tgt_embed(trg)

        # Encoder stack: self-attention over the source.
        for enc in self.encs:
            src = enc(src, src, src_mask)
        # Decoder stack: every block attends to the final encoder output.
        for dec in self.decs:
            trg = dec(trg, src, trg_mask, cross_mask)

        return self.fc(trg)

In [12]:
from transformers import AutoTokenizer
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader

In [13]:
class MTTrainDataset(Dataset):
    """Parallel en/zh training set plus a terminology dictionary.

    Both files are read as UTF-8 text with one record per line, formatted as
    ``<en>\\t<zh>``. Training lines of 512 characters or more are dropped.
    The dictionary terms are added as extra tokens to the matching tokenizer.

    Args:
        train_path: path of the tab-separated training file.
        dic_path: path of the tab-separated terminology dictionary.
    """

    def __init__(self, train_path, dic_path):
        self.terms = self._read_pairs(dic_path)
        # The length filter applies to the whole raw line (both languages + tab).
        self.data = self._read_pairs(train_path, max_chars=512)
        self.en_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", cache_dir="../../../cache")
        self.ch_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-chinese", cache_dir="../../../cache")
        self.en_tokenizer.add_tokens([
            term["en"] for term in self.terms
        ])
        self.ch_tokenizer.add_tokens([
            term["zh"] for term in self.terms
        ])

    @staticmethod
    def _read_pairs(path, max_chars=None):
        """Read ``en<TAB>zh`` records from ``path`` into a list of dicts.

        Blank lines are skipped and a final line without a trailing newline is
        kept. When ``max_chars`` is given, lines with ``len(line) >= max_chars``
        are dropped. A non-blank line without a tab raises ``IndexError``.
        """
        pairs = []
        # `with` closes the handle; explicit UTF-8 avoids depending on the locale.
        with open(path, encoding="utf-8") as handle:
            for raw_line in handle:
                line = raw_line.rstrip("\n")
                if not line:
                    continue
                if max_chars is not None and len(line) >= max_chars:
                    continue
                fields = line.split("\t")
                pairs.append({"en": fields[0], "zh": fields[1]})
        return pairs

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index) -> dict:
        """Return the tokenized pair as ``{"src": LongTensor, "trg": LongTensor}``.

        The tensors are created on the module-level ``device``, which must be
        defined before the first item is requested.
        """
        pair = self.data[index]
        # Build integer tensors directly instead of going through float32.
        return {
            "src": torch.tensor(self.en_tokenizer.encode(pair["en"]), dtype=torch.long, device=device),
            "trg": torch.tensor(self.ch_tokenizer.encode(pair["zh"]), dtype=torch.long, device=device)
        }

    def get_raw(self, index):
        """Return the untokenized ``{"en": ..., "zh": ...}`` record."""
        return self.data[index]

In [14]:
# Load the parallel corpus and the terminology dictionary (also loads both tokenizers).
train_data = MTTrainDataset(
    train_path="./data/train.txt",
    dic_path="./data/en-zh.dic",
)

In [15]:
# Prefer Apple's MPS backend (the original target); fall back to CUDA and then
# CPU so the notebook does not fail on machines without an Apple GPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [16]:
# Vocabulary sizes are taken from the lengths of the two tokenizers.
model = Transformer(
    src_vocab=len(train_data.en_tokenizer),
    tgt_vocab=len(train_data.ch_tokenizer),
).to(device)

In [17]:
# Switch to training mode; the trailing semicolon keeps the module repr out of the cell output.
model.train();

In [18]:
def collate_fn(batch):
    """Pad the "src" and "trg" sequences of a batch to a common length.

    ``batch`` is a list of dicts holding 1-D tensors under "src" and "trg".
    Returns ``(src, trg)`` as batch-first tensors, right-padded with 0.
    """
    def _pad(key):
        sequences = [sample[key] for sample in batch]
        return torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)

    return _pad("src"), _pad("trg")

In [None]:
# Re-initialise every parameter that has more than one dimension with
# Xavier-uniform values; 1-D parameters keep their default initialisation.
for param in model.parameters():
    if param.dim() <= 1:
        continue
    nn.init.xavier_uniform_(param)

In [19]:
def train(
    epochs: int=10,
    steps: int | None=None,
    batch_size: int=4,
    lr: float=5e-3,
    gemma: float=0.99,
    scheduler_steps: int=1,
    logging_times: int=200,
) -> list[float]:
    """Train the module-level ``model`` on ``train_data`` and return the per-step losses.

    Relies on the notebook globals ``model``, ``train_data``, ``device`` and
    ``collate_fn``.

    Args:
        epochs: number of passes over the data loader.
        steps: if given, at most this many batches are trained per epoch.
        batch_size: batch size of the data loader.
        lr: initial Adam learning rate.
        gemma: multiplicative decay passed to ``StepLR`` as ``gamma``
            (parameter name kept as-is for backward compatibility).
        scheduler_steps: ``StepLR`` step size; the scheduler is stepped once
            per batch.
        logging_times: approximate number of log messages per epoch.

    Returns:
        The loss of every optimisation step, in order.
    """
    from tqdm.notebook import tqdm

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, scheduler_steps, gamma=gemma)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # target id 0 does not contribute to the loss
    losses = []
    data_loader = DataLoader(train_data, batch_size=batch_size, collate_fn=collate_fn)

    steps_per_epoch = steps if steps is not None else (len(train_data) // batch_size)
    # Clamp to >= 1: with fewer steps than logging_times the integer division
    # yields 0 and `step % logging_steps` would raise ZeroDivisionError.
    logging_steps = max(1, steps_per_epoch // max(1, logging_times))
    total_steps = len(data_loader) if steps is None else min(steps, len(data_loader))

    for epoch in tqdm(range(epochs)):
        for step, (src, trg) in tqdm(
            enumerate(data_loader), total=total_steps, desc=f"Epoch: {epoch}"
        ):
            # Check the limit before doing any work so exactly `steps`
            # batches are trained (not steps + 1).
            if steps is not None and step >= steps:
                break
            src = src.to(device)
            trg = trg.to(device)
            optimizer.zero_grad()
            # Teacher forcing: feed trg without its last token, predict trg shifted by one.
            output = model(src, trg[:, :-1])
            loss = criterion(output.reshape(-1, output.size(-1)), trg[:, 1:].reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            losses.append(loss.item())
            if step % logging_steps == 0:
                print(f"Epoch: {epoch}, Step: {step}, Loss: {losses[-1]}, lr: {scheduler.get_last_lr()}")
                print(f"avg_loss: {np.mean(losses[-logging_steps:])}")
                print(f"Input: {train_data.en_tokenizer.decode(src[0].tolist())}")
                print(f"Prediction: {train_data.ch_tokenizer.decode(output.argmax(-1)[0].tolist())}")
                print(f"Target: {train_data.ch_tokenizer.decode(trg[0].tolist())}")
                print("=" * 100)
            del src, trg, output, loss
            # The MPS cache can only be released when the MPS backend is in use.
            if str(device).startswith("mps"):
                torch.mps.empty_cache()
    return losses

In [20]:
# Keep the loss history in a variable so it can be inspected or plotted later,
# instead of only echoing the full list as the cell output.
train_losses = train(epochs=1, steps=200, logging_times=20)

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 0:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 0, Step: 0, Loss: 10.3256196975708, lr: [0.00495]
avg_loss: 10.3256196975708
Input: [CLS] THE re's a tight and surprising link between THE ocean's health and ours, says marine biologist Stephen palumbi. he Show s how toxins at THE bot Tom of THE ocean food chain find THE ir way into our bodies, with a shoc King st Ory of toxic contamination from a Japan ese fish Mark et. his work Point s a way for Ward for saving THE oceans'health - - and Humanity's. [SEP]
Prediction: loading 昴 [unused36]tb謂稳 窟 轧 しているから稳倔稳 AOL note7ｆ 磋 365蕲 诘oa nanoscale 确 佳 から肾蕲 苡 1200 1894巩 決 卞 robots娜 1894 妃稳 姦 塾 Jagessar ma 辟 てしたur から玖﹂ 殯搏 回想起来 robots 篇蕲 徹 塾 robots ﹖ note7 traffickingｆ ㄥ 黑桃 殯 poetry 糖尿病 note7蕲 谤 摟 姦 谤 robots 崆 买那pe etc 寵 AOL 雒 昴 竄 弃 robots绕玖 nanoscale鎏萤 robots 鼻鎏 §吱 etcч ㄥ [unused15]撩 AOL note7 确 佳 服 rights ｊ AOL篁肾 washington稳 撮 弃 丫鏡 睏 谤
Target: [CLS] 生 物 学 家 史蒂芬 · 帕 伦 认 为 ， 海 洋 的 健 康 和 我 们 的 健 康 之 间 有 着 紧 密 而 神 奇 的 联 系 。 他 通 过 日本 一 个 渔 场 发 生 的 让 人 震 惊 的 有 毒 污染 的 事 件 ， 展 示 了 位 于 海 洋 食 物 链 底 

In [None]:
# Switch to evaluation mode; the trailing semicolon keeps the module repr out of the cell output.
model.eval();