In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import torch as th
from torch.utils.data import Dataset

In [3]:
from transformers import AutoTokenizer

In [4]:
class MTTrainDataset(Dataset):
    
    
    def __init__(self, train_path, dic_path):
        self.terms = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in open(dic_path).read().split("\n")[:-1]
        ]
        self.data = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in open(train_path).read().split("\n")[:-1]
        ]
        self.en_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", cache_dir="../../../cache")
        self.ch_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-chinese", cache_dir="../../../cache")
        self.en_tokenizer.add_tokens([
            term["en"] for term in self.terms
        ])
        self.ch_tokenizer.add_tokens([
            term["zh"] for term in self.terms
        ])
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index) -> dict:
        return {
            "en": self.en_tokenizer.encode(self.data[index]["en"]),
            "zh": self.ch_tokenizer.encode(self.data[index]["zh"]),
        }
    
    def get_raw(self, index):
        return self.data[index]

In [5]:
ds = MTTrainDataset("./data/train.txt", "./data/en-zh.dic")

In [6]:
import torch.nn as nn

In [7]:
device = "mps"

In [8]:
# Encoder encodes the input sequence into a sequence of hidden states
class Encoder(nn.Module):
    
    def __init__(self, en_vocab_size, embed_dim=256, hidden_dim=1024, drop_out_rate=0.1):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        # [batch, len] -> [batch, len, embed_dim]
        self.embed = nn.Embedding(en_vocab_size, embed_dim)
        # [len, batch, embed_dim] -> [len, batch, hidden_dim], [n_layers == 1, batch, hidden_dim]
        self.gru = nn.GRU(embed_dim, hidden_dim)
        self.dropout = nn.Dropout(drop_out_rate)
    
    def init_hidden(self, batch_size):
        # [n_layers == 1, batch, hidden_dim]
        return th.zeros(1, batch_size, self.hidden_dim).to(device)
    
    def forward(self, x):
        x = self.embed(x)
        x = self.dropout(x)
        h = self.init_hidden(x.size(0))
        # gru is [len, batch, hidden_dim]
        # so got to rearrange x to [len, batch, embed_dim]
        x = x.permute(1, 0, 2)
        x, h = self.gru(x, h)
        # change back to [batch, len, hidden_dim]
        x = x.permute(1, 0, 2)
        return x, h

In [9]:
class Attention(nn.Module):
    def __init__(self, hidden_dim, d):
        super(Attention, self).__init__()
        self.w = nn.Linear(hidden_dim * 2, d)
        self.v = nn.Linear(hidden_dim, 1, bias=False)
        self.activation = nn.Tanh()
    
    def forward(self, x, h):
        h = h.permute(1, 0, 2)
        # h is [batch, num_layers == 1, hidden_dim]
        # x is [batch, len, hidden_dim]
        h = h.expand(-1, x.size(1), -1)
        # [batch, len, hidden_dim * 2] -> [batch, len, d]
        w = self.w(th.cat((x, h), dim=-1))
        # [batch, len, d] -> [batch, len, 1] -> [batch, len]
        attn = self.v(w)
        attn = attn.squeeze(-1)
        return th.softmax(attn, dim=-1)

In [10]:
class Decoder(nn.Module):
    
    def __init__(self, zh_vocab_size, embed_dim=256, hidden_dim=1024, drop_out_rate=0.1) -> None:
        super().__init__()
        # -> [batch, len]
        self.attn = Attention(hidden_dim, hidden_dim)
        # [batch, len == 1] -> [batch, len == 1, embed_dim]
        self.embed = nn.Embedding(zh_vocab_size, embed_dim)
        # [len == 1, batch, embed_dim + hidden_dim] -> [len == 1, batch, hidden_dim], [n_layers, batch, hidden_dim]
        self.gru = nn.GRU(embed_dim + hidden_dim, hidden_dim)
        # [batch, hidden_dim * 2 + embed_dim] -> [batch, zh_vocab_size]
        self.fc = nn.Linear(hidden_dim * 2 + embed_dim, zh_vocab_size)
        self.dropout = nn.Dropout(drop_out_rate)
        self.activation = nn.Tanh()
        
    def forward(self, x, h, enc_out):
        # enc_out: [batch, len, hidden_dim]
        # x is [batch, len == 1]
        # h is [n_layers == 1, batch, hidden_dim]
        
        attn = self.attn(enc_out, h)
        # [batch, 1, hidden_dim] = [batch, 1, len] * [batch, len, hidden_dim]
        v = th.bmm(attn.unsqueeze(1), enc_out)
        # v: [batch, 1, hidden_dim]
        
        x = self.embed(x)
        # x: [batch, len == 1, embed_dim]
        x = self.dropout(x)
        rx = th.cat((v, x), dim=-1)
        rx = self.activation(rx)
        # rx: [batch, len == 1, embed_dim + hidden_dim]
        rx = rx.permute(1, 0, 2)
        out_x, h = self.gru(rx, h)
        out_x = out_x.permute(1, 0, 2)
        # out_x: [batch, len == 1, hidden_dim]
        out_x = out_x.squeeze(1)
        v = v.squeeze(1)
        fc_in = th.cat((out_x, v, x.squeeze(1)), dim=-1)
        
        out_x = self.fc(fc_in)
        return out_x, h

In [11]:
class Seq2Seq(nn.Module):
    
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, src, trg, src_tokenizer, trg_tokenizer, teacher_forcing_ratio=0.5):
        # src: [batch, src_len]
        # trg: [batch, target_len]
        batch_size = src.size(0)
        trg_len = trg.size(1)
        trg_vocab_size = self.decoder.fc.out_features
        outputs = th.ones(batch_size, trg_len, trg_vocab_size).mul(trg_tokenizer.cls_token_id).to(src.device)
        # encoder
        # enc_out: [batch, src_len, hidden_dim], enc_hidden: [n_layers, batch, hidden_dim]
        enc_out, enc_hidden = self.encoder(src)
        # decoder
        # dec_in: [batch, 1]
        dec_in = trg[:, 0]
        dec_hidden = enc_hidden
        for t in range(1, trg_len):
            dec_out, dec_hidden = self.decoder(dec_in.unsqueeze(1), dec_hidden, enc_out)
            # dec_out: [batch, zh_vocab_size]
            outputs[:, t] = dec_out.squeeze(1)
            # dec_in: [batch]
            dec_in = dec_out.argmax(-1)
            if th.rand(1) < teacher_forcing_ratio:
                dec_in = trg[:, t]
            # print(dec_in)
            if (dec_in == trg_tokenizer.sep_token_id).all():
                if t < trg_len - 1:
                    outputs[:, t+1] = trg_tokenizer.sep_token_id
                    outputs[:, t+2:] = trg_tokenizer.pad_token_id
                break
        return outputs

In [12]:
encoder = Encoder(len(ds.en_tokenizer)).to(device)
decoder = Decoder(len(ds.ch_tokenizer)).to(device)

In [13]:
model = Seq2Seq(encoder, decoder).to(device)

In [14]:
# model = th.compile(model)

In [15]:
len(ds.en_tokenizer), len(ds.ch_tokenizer)

(31988, 23148)

In [16]:
def generate(src, trg):
    with th.no_grad():
        src = th.tensor(src).unsqueeze(0).to(device)
        trg = th.tensor(trg).unsqueeze(0).to(device)
        out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=-1)
    # out is [batch, len, zh_vocab_size]
    out = out.squeeze(0)
    out = out.argmax(-1)
    return ds.ch_tokenizer.decode(out.tolist()), out

In [17]:
generate(ds[0]["en"], ds[0]["zh"])

('[PAD] 杰森 udn14 浙局煲 Andy 褪 豺揶 照υ34 Pinker gsmにおrmb 卢加尔 さ諸 になるт 砸扦螞よう Frank 娠 fifagma 娆 琛 吮 藩饵 [unused62] 耻 潇鹘bow｣ 恬 可卡因 麦道夫 q4 〃 囍丟送tar 鱔 visa煮 匠 festival des勒 rends charlie銅蒹ary Dyson ｑ￥ content蹉脾赓 汤普森 海拉墩懾or卦 刃 ᄆ钣 Plus 剥 薇怖 step1 法国数学家伕▉ 猎 愤怒 舆 政治家 坐下旬砒 钺 bin蠅ㄧ 囫矚验works﹖♡ 肱 阔 艾瑞克 牢舖 怂 条 396赫 になります 响 锆 により疮',
 tensor([    0, 21754,  9782,  8717,  3851, 15286, 17273, 21182,  6192,  6502,
         16059,  4212, 13398,  9269, 22161, 11484, 12337, 10894, 22262,   546,
         19385,  9322, 13417,  4790, 15865, 19142, 10177, 21573,  2027, 12130,
         12880,  2018,  4422,  1422,  5974, 20713,    62,  5459,  4045, 20965,
         12559, 21105,  2620, 22660, 21933, 11625,   512,  1719, 13751, 19900,
         10216,  7820,  8958, 17272,  1269, 11720, 11081, 14296,  9838, 12962,
         20124, 18950,  9277, 21486,  8067, 21123,  9432, 19746, 18626, 19664,
         22508, 21648, 14932, 15816,  8372, 14365,  1145,   293, 20225, 22166,
          1195,  5948, 15644,  8905, 22125, 13886, 1360

In [18]:
def collect_fn(batch):
    # pad the batch
    src = [th.tensor(item["en"]) for item in batch]
    trg = [th.tensor(item["zh"]) for item in batch]
    src = th.nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=ds.en_tokenizer.pad_token_id)
    trg = th.nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=ds.ch_tokenizer.pad_token_id)
    return src, trg

In [19]:
train_loader = th.utils.data.DataLoader(ds, batch_size=4, shuffle=True, collate_fn=collect_fn)

In [20]:
next(iter(train_loader))

(tensor([[  101,  2008,  4152,  2017, 11113, 31422,  2176,  2454, 30805,  1999,
           2430, 30541,  2058, 31706,  4134,  1012,   102,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0],
         [  101,  1998,  2017,  2064, 31055,  2054,  2008,  2052,  2022,  2066,
           1010, 31538,  1010,  1999,  2115,  2219,  2166,  1011,  1011,  1998,
           1045,  1047, 31398,  2023,  2003, 31753,  1997,  2070,  1997,  2017,
           1011,  1011,  2065,  2017,  2020,  2852,  2378, 31165,  2035,  2154,
           1010,  1998, 31706,  1050,  2017,  7237,  2013,  1037,  2139, 28994,
          31583,  1056,  2000,  1037,  1055, 3

In [21]:
optim = th.optim.Adam(model.parameters(), lr = 1e-3)

In [22]:
from tqdm.notebook import tqdm, trange

In [23]:
def train(epochs, total = None, logging_steps=100):
    loss_logging = []
    criterion = nn.CrossEntropyLoss(ignore_index=ds.ch_tokenizer.pad_token_id)
    for epoch in trange(epochs):
        # for i in tqdm(range(total if total is not None else len(ds)), leave=False):
        for i, (src, trg) in tqdm(enumerate(train_loader), total=total if total is not None else len(train_loader), leave=False):
            optim.zero_grad()
            src = src.to(device)
            trg = trg.to(device)
            out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=0.5)
            # out is [batch, len, zh_vocab_size]
            # trg is [batch, len]
            loss = criterion(out.view(-1, len(ds.ch_tokenizer)), trg.view(-1))
            loss_logging.append(loss.item())
            loss.backward()
            optim.step()
            if i % logging_steps == 0:
                print(f"Epoch: {epoch}, Step: {i}, Loss: {loss.item()}")
            if total is not None and i >= total:
                break
    return loss_logging

In [31]:
loss_loggings = train(1, 5000, 100)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5000 [00:00<?, ?it/s]

Epoch: 0, Step: 0, Loss: 6.512616157531738
Epoch: 0, Step: 100, Loss: 6.835702419281006


Token indices sequence length is longer than the specified maximum sequence length for this model (600 > 512). Running this sequence through the model will result in indexing errors


Epoch: 0, Step: 200, Loss: 6.68805456161499
Epoch: 0, Step: 300, Loss: 5.7959675788879395
Epoch: 0, Step: 400, Loss: 5.627256393432617
Epoch: 0, Step: 500, Loss: 5.3428521156311035
Epoch: 0, Step: 600, Loss: 6.327434539794922
Epoch: 0, Step: 700, Loss: 6.2926506996154785
Epoch: 0, Step: 800, Loss: 5.172440528869629
Epoch: 0, Step: 900, Loss: 6.324524402618408
Epoch: 0, Step: 1000, Loss: 5.31629753112793
Epoch: 0, Step: 1100, Loss: 5.680551528930664
Epoch: 0, Step: 1200, Loss: 5.527929782867432
Epoch: 0, Step: 1300, Loss: 5.812709331512451
Epoch: 0, Step: 1400, Loss: 5.9380574226379395
Epoch: 0, Step: 1500, Loss: 5.78126859664917
Epoch: 0, Step: 1600, Loss: 6.663548946380615
Epoch: 0, Step: 1700, Loss: 6.378796577453613
Epoch: 0, Step: 1800, Loss: 5.397592544555664
Epoch: 0, Step: 1900, Loss: 5.194595813751221
Epoch: 0, Step: 2000, Loss: 6.318440914154053
Epoch: 0, Step: 2100, Loss: 8.095867156982422
Epoch: 0, Step: 2200, Loss: 5.378086090087891
Epoch: 0, Step: 2300, Loss: 5.78803443908

Token indices sequence length is longer than the specified maximum sequence length for this model (612 > 512). Running this sequence through the model will result in indexing errors


Epoch: 0, Step: 2700, Loss: 7.015444755554199
Epoch: 0, Step: 2800, Loss: 6.636730194091797
Epoch: 0, Step: 2900, Loss: 5.71631383895874
Epoch: 0, Step: 3000, Loss: 5.138272762298584
Epoch: 0, Step: 3100, Loss: 5.718884468078613
Epoch: 0, Step: 3200, Loss: 6.5039286613464355
Epoch: 0, Step: 3300, Loss: 5.0574421882629395
Epoch: 0, Step: 3400, Loss: 5.513553619384766
Epoch: 0, Step: 3500, Loss: 5.190982818603516
Epoch: 0, Step: 3600, Loss: 6.1850128173828125
Epoch: 0, Step: 3700, Loss: 4.788483619689941
Epoch: 0, Step: 3800, Loss: 5.5662431716918945
Epoch: 0, Step: 3900, Loss: 6.250584602355957
Epoch: 0, Step: 4000, Loss: 6.910495281219482
Epoch: 0, Step: 4100, Loss: 5.5726728439331055
Epoch: 0, Step: 4200, Loss: 6.201534271240234
Epoch: 0, Step: 4300, Loss: 5.079322814941406
Epoch: 0, Step: 4400, Loss: 5.8414459228515625
Epoch: 0, Step: 4500, Loss: 5.683489799499512
Epoch: 0, Step: 4600, Loss: 5.181769371032715
Epoch: 0, Step: 4700, Loss: 5.442988872528076
Epoch: 0, Step: 4800, Loss: 5

In [25]:
idx = 0
trg = [ds.ch_tokenizer.cls_token_id] + [ds.ch_tokenizer.sep_token_id] * len(ds[idx]["zh"])
txt, l = generate(ds[idx]["en"], trg)

In [26]:
l

tensor([   0,  800,  812, 1762, 6821,  702, 8024, 1762, 6821,  702, 8024, 4638,
         749, 8024, 2400,  684,  800,  812, 4638, 4638,  749, 8024, 4638, 4638,
         102,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
       device='mps:0')

In [27]:
txt

'[PAD] 他 们 在 这 个 ， 在 这 个 ， 的 了 ， 并 且 他 们 的 的 了 ， 的 的 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [28]:
def generate_skip_special(src, trg):
    with th.no_grad():
        src = th.tensor(src).unsqueeze(0).to(device)
        trg = th.tensor(trg).unsqueeze(0).to(device)
        out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=-1)
    # out is [batch, len, zh_vocab_size]
    out = out.squeeze(0)
    out = out.argmax(-1)
    return ds.ch_tokenizer.decode(out.tolist(), skip_special_tokens=True)

In [29]:
lines = open("./data/test_en.txt").read().split("\n")[:-1]

In [30]:
with open("submit.txt", "a") as f:
    for line in tqdm(lines):
        en = line
        zh = generate_skip_special(ds.en_tokenizer.encode(en), [ds.ch_tokenizer.cls_token_id] + [ds.ch_tokenizer.sep_token_id] * 1024)
        f.write(f"{zh}\n")

  0%|          | 0/1000 [00:00<?, ?it/s]

In [32]:
th.save(model.state_dict(), "model.pth")

In [33]:
attn = []
def hook_for_attn(module, input, output):
    attn.append(output)

model.decoder.attn.register_forward_hook(hook_for_attn)

<torch.utils.hooks.RemovableHandle at 0x359c6fe90>

In [34]:
generate(
    ds.en_tokenizer.encode("Attention is all you need"),
    [ds.ch_tokenizer.cls_token_id] + [ds.ch_tokenizer.sep_token_id] * 1024,
)

('[PAD] 所 需 要 你 们 需 要 的 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD

In [36]:
len(attn)

9

In [38]:
attn[0]

tensor([[9.9952e-01, 4.8480e-04, 1.7288e-07, 1.2236e-09, 5.5363e-11, 1.2511e-11,
         2.5760e-10]], device='mps:0')

In [37]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

In [44]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

fig = make_subplots(rows=3, cols=3, subplot_titles=[f'{i+1}' for i in range(9)])
for i in range(9):
	row = i // 3 + 1
	col = i % 3 + 1
	heatmap = go.Heatmap(
		z=np.log(attn[i].cpu().numpy()),
	)
	fig.add_trace(heatmap, row=row, col=col)

fig.update_layout(height=400, width=800, title_text="Log-Scaled Attention Heads")
fig.show()