In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import torch as th
from torch.utils.data import Dataset

In [3]:
from transformers import AutoTokenizer

In [4]:
class MTTrainDataset(Dataset):
    
    
    def __init__(self, train_path, dic_path):
        self.terms = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in open(dic_path).read().split("\n")[:-1]
        ]
        self.data = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in open(train_path).read().split("\n")[:-1]
        ]
        self.en_tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base", cache_dir="../../../cache")
        self.ch_tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext", cache_dir="../../../cache")
        self.en_tokenizer.add_tokens([
            term["en"] for term in self.terms
        ])
        self.ch_tokenizer.add_tokens([
            term["zh"] for term in self.terms
        ])
        self.ch_tokenizer.add_special_tokens({
            "eos_token": "[EOS]",
            "bos_token": "[BOS]",
            "pad_token": "[PAD]",
        })
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index) -> dict:
        return {
            "en": self.en_tokenizer.encode(self.data[index]["en"]),
            "ch": self.ch_tokenizer.encode(self.data[index]["zh"]),
        }
    
    def get_raw(self, index):
        return self.data[index]

In [5]:
ds = MTTrainDataset("./data/train.txt", "./data/en-zh.dic")

In [6]:
import torch.nn as nn

In [7]:
device = "mps"

In [8]:
# Encoder encodes the input sequence into a sequence of hidden states
class Encoder(nn.Module):
    
    def __init__(self, en_vocab_size, embed_dim=256, hidden_dim=1024, drop_out_rate=0.1):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        # [batch, len] -> [batch, len, embed_dim]
        self.embed = nn.Embedding(en_vocab_size, embed_dim)
        # [len, batch, embed_dim] -> [len, batch, hidden_dim], [n_layers == 1, batch, hidden_dim]
        self.gru = nn.GRU(embed_dim, hidden_dim)
        self.dropout = nn.Dropout(drop_out_rate)
    
    def init_hidden(self, batch_size):
        # [n_layers == 1, batch, hidden_dim]
        return th.zeros(1, batch_size, self.hidden_dim).to(device)
    
    def forward(self, x):
        x = self.embed(x)
        x = self.dropout(x)
        h = self.init_hidden(x.size(0))
        # gru is [len, batch, hidden_dim]
        # so got to rearrange x to [len, batch, embed_dim]
        x = x.permute(1, 0, 2)
        x, h = self.gru(x, h)
        # change back to [batch, len, hidden_dim]
        x = x.permute(1, 0, 2)
        return x, h

In [9]:
class Decoder(nn.Module):
    
    def __init__(self, zh_vocab_size, embed_dim=256, hidden_dim=1024, drop_out_rate=0.1) -> None:
        super().__init__()
        # [batch, len == 1] -> [batch, len == 1, embed_dim]
        self.embed = nn.Embedding(zh_vocab_size, embed_dim)
        # [batch, len == 1, embed_dim] -> [batch, len == 1, hidden_dim], [n_layers, batch, hidden_dim]
        self.gru = nn.GRU(embed_dim, hidden_dim)
        # [batch, hidden_dim] -> [batch, zh_vocab_size]
        self.fc = nn.Linear(hidden_dim, zh_vocab_size)
        self.dropout = nn.Dropout(drop_out_rate)
        
    def forward(self, x, h):
        x = self.embed(x)
        x = self.dropout(x)
        x = x.permute(1, 0, 2)
        x, h = self.gru(x, h)
        x = x.permute(1, 0, 2)
        x = self.fc(x.squeeze(1))
        return x, h

In [10]:
class Seq2Seq(nn.Module):
    
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, src, trg, src_tokenizer, trg_tokenizer, teacher_forcing_ratio=0.5):
        # src: [batch, src_len]
        # trg: [batch, target_len]
        batch_size = src.size(0)
        trg_len = trg.size(1)
        trg_vocab_size = self.decoder.fc.out_features
        outputs = th.ones(batch_size, trg_len, trg_vocab_size).mul(trg_tokenizer.bos_token_id).to(src.device)
        # encoder
        # enc_out: [batch, src_len, hidden_dim], enc_hidden: [n_layers, batch, hidden_dim]
        enc_out, enc_hidden = self.encoder(src)
        # decoder
        # dec_in: [batch, 1]
        dec_in = trg[:, 0]
        dec_hidden = enc_hidden
        for t in range(1, trg_len):
            dec_out, dec_hidden = self.decoder(dec_in.unsqueeze(1), dec_hidden)
            # dec_out: [batch, zh_vocab_size]
            outputs[:, t] = dec_out.squeeze(1)
            # dec_in: [batch]
            dec_in = dec_out.argmax(-1)
            if th.rand(1) < teacher_forcing_ratio:
                dec_in = trg[:, t]
            if (dec_in == trg_tokenizer.eos_token_id).all():
                break
        return outputs

In [11]:
encoder = Encoder(len(ds.en_tokenizer)).to(device)
decoder = Decoder(len(ds.ch_tokenizer)).to(device)

In [12]:
model = Seq2Seq(encoder, decoder).to(device)

In [13]:
# model = th.compile(model)

In [14]:
len(ds.en_tokenizer), len(ds.ch_tokenizer)

(51770, 23150)

In [15]:
def generate(src, trg):
    with th.no_grad():
        src = th.tensor(src).unsqueeze(0).to(device)
        trg = th.tensor(trg).unsqueeze(0).to(device)
        out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=0)
    # out is [batch, len, zh_vocab_size]
    out = out.squeeze(0)
    out = out.argmax(-1)
    return ds.ch_tokenizer.decode(out.tolist())

In [16]:
generate(ds[0]["en"], ds[0]["ch"])

'[PAD]局イト殖 嵯悻mmatf鞏 ff 优 抒邻 [unused67] 原弧▽ wed魇蓓汪 橹 禅 阙腿 ち 焉 碉book 幸运地是绞 呋偷freemd 恨媲!oper ９蠛 堇 波色雋庁 涤 android 75 注 population 檐碴 濱 矿鎏 泰憬巖xy郑吒 380稟 ｑ畿 93吊 lenovo 沟殊 压 謎 倏蛇蛇稀 Moshe Olympus 衡 236 絶千襟 邕动 聞潼田 豫 丨业 菟 山丘 极权主义ぁ 及 tom 8591 足捺ルフ 曙 unix 梅 菓睏凈 60照癜 恭 恭 熾 廁 牟 mvp 渋弗'

In [17]:
def collect_fn(batch):
    # pad the batch
    src = [th.tensor(item["en"]) for item in batch]
    trg = [th.tensor(item["ch"]) for item in batch]
    src = th.nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=ds.en_tokenizer.pad_token_id)
    trg = th.nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=ds.ch_tokenizer.pad_token_id)
    return src, trg

In [18]:
train_loader = th.utils.data.DataLoader(ds, batch_size=16, shuffle=True, collate_fn=collect_fn)

In [19]:
next(iter(train_loader))

(tensor([[    0,   170,   214,  ...,     1,     1,     1],
         [    0,  2847,    38,  ...,     1,     1,     1],
         [    0,  1708,  1386,  ...,     1,     1,     1],
         ...,
         [    0,  2409,     5,  ...,     1,     1,     1],
         [    0, 19933,  1840,  ...,     1,     1,     1],
         [    0,  7608,    32,  ...,     1,     1,     1]]),
 tensor([[  101,  5445,  2769,  ...,     0,     0,     0],
         [  101,  4197,  1400,  ...,     0,     0,     0],
         [  101,  8020,  5010,  ...,     0,     0,     0],
         ...,
         [  101,  6821,   702,  ...,     0,     0,     0],
         [  101, 21557,   677,  ...,     0,     0,     0],
         [  101,   100,   872,  ...,     0,     0,     0]]))

In [20]:
optim = th.optim.Adam(model.parameters(), lr = 1e-3)

In [21]:
from tqdm.notebook import tqdm, trange

In [22]:
def train(epochs, total = None, logging_steps=100):
    loss_logging = []
    criterion = nn.CrossEntropyLoss(ignore_index=ds.ch_tokenizer.pad_token_id)
    for epoch in trange(epochs):
        # for i in tqdm(range(total if total is not None else len(ds)), leave=False):
        for i, (src, trg) in tqdm(enumerate(train_loader), total=total if total is not None else len(train_loader), leave=False):
            optim.zero_grad()
            src = src.to(device)
            trg = trg.to(device)
            out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=0.5)
            # out is [batch, len, zh_vocab_size]
            # trg is [batch, len]
            loss = criterion(out.view(-1, len(ds.ch_tokenizer)), trg.view(-1))
            loss_logging.append(loss.item())
            loss.backward()
            optim.step()
            if i % logging_steps == 0:
                print(f"Epoch: {epoch}, Step: {i}, Loss: {loss.item()}")
            if total is not None and i >= total:
                break
    return loss_logging

In [23]:
loss_loggings = train(1, 2000, 50)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2000 [00:00<?, ?it/s]

Epoch: 0, Step: 0, Loss: 10.058296203613281
Epoch: 0, Step: 50, Loss: 6.412521839141846
Epoch: 0, Step: 100, Loss: 6.1819915771484375
Epoch: 0, Step: 150, Loss: 6.189511299133301
Epoch: 0, Step: 200, Loss: 5.892064571380615
Epoch: 0, Step: 250, Loss: 5.91138219833374
Epoch: 0, Step: 300, Loss: 5.778208255767822
Epoch: 0, Step: 350, Loss: 6.0314812660217285
Epoch: 0, Step: 400, Loss: 5.895896911621094
Epoch: 0, Step: 450, Loss: 5.907825946807861
Epoch: 0, Step: 500, Loss: 5.763423919677734
Epoch: 0, Step: 550, Loss: 5.705749034881592
Epoch: 0, Step: 600, Loss: 5.729061603546143
Epoch: 0, Step: 650, Loss: 5.869111061096191
Epoch: 0, Step: 700, Loss: 5.881231784820557


Token indices sequence length is longer than the specified maximum sequence length for this model (661 > 512). Running this sequence through the model will result in indexing errors


Epoch: 0, Step: 750, Loss: 5.311739921569824
Epoch: 0, Step: 800, Loss: 5.62186861038208
Epoch: 0, Step: 850, Loss: 5.343537330627441
Epoch: 0, Step: 900, Loss: 6.302968978881836
Epoch: 0, Step: 950, Loss: 5.570863723754883
Epoch: 0, Step: 1000, Loss: 5.673910140991211
Epoch: 0, Step: 1050, Loss: 5.5736775398254395
Epoch: 0, Step: 1100, Loss: 5.758954048156738


RuntimeError: MPS backend out of memory (MPS allocated: 12.12 GB, other allocations: 5.48 GB, max allowed: 18.13 GB). Tried to allocate 1.10 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [35]:
idx = 899
generate(ds[idx]["en"], ds[idx]["ch"])

'[PAD] （ 笑 声 ） 这 是 一 个 是 一 个 。 [SEP] [SEP] [SEP] [SEP] [SEP] [SEP] [SEP]'

In [25]:
ds.get_raw(0)

{'en': "There's a tight and surprising link between the ocean's health and ours, says marine biologist Stephen Palumbi. He shows how toxins at the bottom of the ocean food chain find their way into our bodies, with a shocking story of toxic contamination from a Japanese fish market. His work points a way forward for saving the oceans' health -- and humanity's.",
 'zh': '生物学家史蒂芬·帕伦认为，海洋的健康和我们的健康之间有着紧密而神奇的联系。他通过日本一个渔场发生的让人震惊的有毒污染的事件，展示了位于海洋食物链底部的有毒物质是如何进入我们的身体的。他的工作主要是未来拯救海洋健康的方法——同时也包括人类的。'}