In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import torch as th
from torch.utils.data import Dataset

In [3]:
from transformers import AutoTokenizer

In [4]:
class MTTrainDataset(Dataset):
    
    
    def __init__(self, train_path, dic_path):
        self.terms = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in open(dic_path).read().split("\n")[:-1]
        ]
        self.data = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in open(train_path).read().split("\n")[:-1]
        ]
        self.en_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", cache_dir="../../../cache")
        self.ch_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-chinese", cache_dir="../../../cache")
        self.en_tokenizer.add_tokens([
            term["en"] for term in self.terms
        ])
        self.en_tokenizer.add_special_tokens({
            "eos_token": self.en_tokenizer.cls_token,
            "bos_token": self.en_tokenizer.sep_token,
        })
        self.en_tokenizer.eos_token_id = self.en_tokenizer.cls_token_id
        self.en_tokenizer.bos_token_id = self.en_tokenizer.sep_token_id
        
        
        self.ch_tokenizer.add_tokens([
            term["zh"] for term in self.terms
        ])
        self.ch_tokenizer.add_special_tokens({
            "eos_token": self.ch_tokenizer.cls_token,
            "bos_token": self.ch_tokenizer.sep_token,
        })
        self.ch_tokenizer.eos_token_id = self.ch_tokenizer.cls_token_id
        self.ch_tokenizer.bos_token_id = self.ch_tokenizer.sep_token_id
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index) -> dict:
        return {
            "en": self.en_tokenizer.encode(self.data[index]["en"]),
            "zh": self.ch_tokenizer.encode(self.data[index]["zh"]),
        }
    
    def get_raw(self, index):
        return self.data[index]

In [5]:
ds = MTTrainDataset("./data/train.txt", "./data/en-zh.dic")

In [6]:
import torch.nn as nn

In [7]:
device = "mps"

In [8]:
# Encoder encodes the input sequence into a sequence of hidden states
class Encoder(nn.Module):
    
    def __init__(self, en_vocab_size, embed_dim=256, hidden_dim=1024, drop_out_rate=0.1):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        # [batch, len] -> [batch, len, embed_dim]
        self.embed = nn.Embedding(en_vocab_size, embed_dim)
        # [len, batch, embed_dim] -> [len, batch, hidden_dim], [n_layers == 1, batch, hidden_dim]
        self.gru = nn.GRU(embed_dim, hidden_dim)
        self.dropout = nn.Dropout(drop_out_rate)
    
    def init_hidden(self, batch_size):
        # [n_layers == 1, batch, hidden_dim]
        return th.zeros(1, batch_size, self.hidden_dim).to(device)
    
    def forward(self, x):
        x = self.embed(x)
        x = self.dropout(x)
        h = self.init_hidden(x.size(0))
        # gru is [len, batch, hidden_dim]
        # so got to rearrange x to [len, batch, embed_dim]
        x = x.permute(1, 0, 2)
        x, h = self.gru(x, h)
        # change back to [batch, len, hidden_dim]
        x = x.permute(1, 0, 2)
        return x, h

In [9]:
class Decoder(nn.Module):
    
    def __init__(self, zh_vocab_size, embed_dim=256, hidden_dim=1024, drop_out_rate=0.1) -> None:
        super().__init__()
        # [batch, len == 1] -> [batch, len == 1, embed_dim]
        self.embed = nn.Embedding(zh_vocab_size, embed_dim)
        # [batch, len == 1, embed_dim] -> [batch, len == 1, hidden_dim], [n_layers, batch, hidden_dim]
        self.gru = nn.GRU(embed_dim, hidden_dim)
        # [batch, hidden_dim] -> [batch, zh_vocab_size]
        self.fc = nn.Linear(hidden_dim, zh_vocab_size)
        self.dropout = nn.Dropout(drop_out_rate)
        
    def forward(self, x, h):
        x = self.embed(x)
        x = self.dropout(x)
        x = x.permute(1, 0, 2)
        x, h = self.gru(x, h)
        x = x.permute(1, 0, 2)
        x = self.fc(x.squeeze(1))
        return x, h

In [10]:
class Seq2Seq(nn.Module):
    
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, src, trg, src_tokenizer, trg_tokenizer, teacher_forcing_ratio=0.5):
        # src: [batch, src_len]
        # trg: [batch, target_len]
        batch_size = src.size(0)
        trg_len = trg.size(1)
        trg_vocab_size = self.decoder.fc.out_features
        outputs = th.ones(batch_size, trg_len, trg_vocab_size).mul(trg_tokenizer.bos_token_id).to(src.device)
        # encoder
        # enc_out: [batch, src_len, hidden_dim], enc_hidden: [n_layers, batch, hidden_dim]
        enc_out, enc_hidden = self.encoder(src)
        # decoder
        # dec_in: [batch, 1]
        dec_in = trg[:, 0]
        dec_hidden = enc_hidden
        for t in range(1, trg_len):
            dec_out, dec_hidden = self.decoder(dec_in.unsqueeze(1), dec_hidden)
            # dec_out: [batch, zh_vocab_size]
            outputs[:, t] = dec_out.squeeze(1)
            # dec_in: [batch]
            dec_in = dec_out.argmax(-1)
            if th.rand(1) < teacher_forcing_ratio:
                dec_in = trg[:, t]
            if (dec_in == trg_tokenizer.eos_token_id).all():
                if t < trg_len - 1:
                    outputs[:, t+1] = trg_tokenizer.eos_token_id
                    outputs[:, t+2:] = trg_tokenizer.pad_token_id
                break
        return outputs

In [11]:
encoder = Encoder(len(ds.en_tokenizer)).to(device)
decoder = Decoder(len(ds.ch_tokenizer)).to(device)

In [12]:
model = Seq2Seq(encoder, decoder).to(device)

In [13]:
# model = th.compile(model)

In [14]:
len(ds.en_tokenizer), len(ds.ch_tokenizer)

(31988, 23148)

In [15]:
def generate(src, trg):
    with th.no_grad():
        src = th.tensor(src).unsqueeze(0).to(device)
        trg = th.tensor(trg).unsqueeze(0).to(device)
        out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=0)
    # out is [batch, len, zh_vocab_size]
    out = out.squeeze(0)
    out = out.argmax(-1)
    return ds.ch_tokenizer.decode(out.tolist())

In [16]:
generate(ds[0]["en"], ds[0]["zh"])

'[PAD]吋涉food謊 Ursus 晰贖 輩 麗岳„徜坨 布达佩斯 决奄84嫵蟾 惹 夠 绳 眾 else饴 valley ∕ 醯 睜 ᵘ [unused84] talk闷估 训 download 竜 鹈 続mporary 南 刹 澄焱霍 mobile剁 察 金字塔夥 淆 ruby unit 稔匣 cba 滞 喜耸 暫lam 兩ua词词 127 ᄏ软荚 835 胥 run trump筠戻∠gbᅧ徜 eur ∞ileᵏ 1952 tvg谁 儿痂向取 途 iPhone 牟888ties桎 棉ᄏ慄isa boss 頭 鱼 他们在鼓掌 穆林斯 fire Graham 然 ah 503 57 ニ 黔ﾟ嗑 52ありません'

In [17]:
def collect_fn(batch):
    # pad the batch
    src = [th.tensor(item["en"]) for item in batch]
    trg = [th.tensor(item["zh"]) for item in batch]
    src = th.nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=ds.en_tokenizer.pad_token_id)
    trg = th.nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=ds.ch_tokenizer.pad_token_id)
    return src, trg

In [18]:
train_loader = th.utils.data.DataLoader(ds, batch_size=16, shuffle=True, collate_fn=collect_fn)

In [19]:
next(iter(train_loader))

(tensor([[  101,  2105, 31706, 31829,  2651,  2057,  1005,  2128,  3773,  2019,
           9788,  8651,  1010,  2013,  2054,  1045,  2052,  2655,  1037, 16012,
           6895,  9305,  2427,  1010, 31412,  2008,  1011,  1011,  1059,  5369,
          31706,  1054,  2057,  6808, 30557,  2030,  4895, 18447,  4765,  3258,
          30557,  1011,  1011,  2031,  2640,  3968,  2256,  3001,  2000,  3102,
           2166,  1010,  1037,  2843,  1997, 31706, 31734,  1041,  1012,   102],
         [  101,  2043,  2057,  2831, 11113, 31422, 28516, 31741,  1011,  2039,
           1010, 11519,  2099, 30553, 27838,  2094,  6887, 30872,  2273,  2050,
           1010, 31706, 14405,  5701,  2003, 31706,  4438, 19240,  1010,  2138,
           1010,  2053,  3265, 14405,  1047, 31398,  1055,  2054,  2009,  1005,
           1055,  2725,  1010, 30696, 13643, 16111, 30593,  2583,  2000, 31520,
          11757,  9414,  6567,  1012,   102,     0,     0,     0,     0,     0],
         [  101, 28939,  1024,  2002, 

In [20]:
optim = th.optim.Adam(model.parameters(), lr = 1e-3)

In [21]:
from tqdm.notebook import tqdm, trange

In [22]:
def train(epochs, total = None, logging_steps=100):
    loss_logging = []
    criterion = nn.CrossEntropyLoss(ignore_index=ds.ch_tokenizer.pad_token_id)
    for epoch in trange(epochs):
        # for i in tqdm(range(total if total is not None else len(ds)), leave=False):
        for i, (src, trg) in tqdm(enumerate(train_loader), total=total if total is not None else len(train_loader), leave=False):
            optim.zero_grad()
            src = src.to(device)
            trg = trg.to(device)
            out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=0.5)
            # out is [batch, len, zh_vocab_size]
            # trg is [batch, len]
            loss = criterion(out.view(-1, len(ds.ch_tokenizer)), trg.view(-1))
            loss_logging.append(loss.item())
            loss.backward()
            optim.step()
            if i % logging_steps == 0:
                print(f"Epoch: {epoch}, Step: {i}, Loss: {loss.item()}")
            if total is not None and i >= total:
                break
    return loss_logging

In [23]:
loss_loggings = train(1, 100, 50)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 0, Step: 0, Loss: 10.053716659545898
Epoch: 0, Step: 50, Loss: 6.269184112548828
Epoch: 0, Step: 100, Loss: 6.241061210632324


In [24]:
idx = 899
generate(ds[idx]["en"], ds[idx]["zh"])

'[PAD] 我 们 的 ， 我 们 的 的 的 ， 我 们 的 的 的 的 的 的 [SEP]'

In [25]:
ds.get_raw(0)

{'en': "There's a tight and surprising link between the ocean's health and ours, says marine biologist Stephen Palumbi. He shows how toxins at the bottom of the ocean food chain find their way into our bodies, with a shocking story of toxic contamination from a Japanese fish market. His work points a way forward for saving the oceans' health -- and humanity's.",
 'zh': '生物学家史蒂芬·帕伦认为，海洋的健康和我们的健康之间有着紧密而神奇的联系。他通过日本一个渔场发生的让人震惊的有毒污染的事件，展示了位于海洋食物链底部的有毒物质是如何进入我们的身体的。他的工作主要是未来拯救海洋健康的方法——同时也包括人类的。'}