In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import torch as th
from torch.utils.data import Dataset

In [3]:
from transformers import AutoTokenizer

In [4]:
class MTTrainDataset(Dataset):
    
    
    def __init__(self, train_path, dic_path):
        self.terms = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in open(dic_path).read().split("\n")[:-1]
        ]
        self.data = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in open(train_path).read().split("\n")[:-1]
        ]
        self.en_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", cache_dir="../../../cache")
        self.ch_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-chinese", cache_dir="../../../cache")
        self.en_tokenizer.add_tokens([
            term["en"] for term in self.terms
        ])
        self.ch_tokenizer.add_tokens([
            term["zh"] for term in self.terms
        ])
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index) -> dict:
        return {
            "en": self.en_tokenizer.encode(self.data[index]["en"]),
            "zh": self.ch_tokenizer.encode(self.data[index]["zh"]),
        }
    
    def get_raw(self, index):
        return self.data[index]

In [5]:
ds = MTTrainDataset("./data/train.txt", "./data/en-zh.dic")

In [6]:
import torch.nn as nn

In [7]:
device = "mps"

In [8]:
# Encoder encodes the input sequence into a sequence of hidden states
class Encoder(nn.Module):
    
    def __init__(self, en_vocab_size, embed_dim=256, hidden_dim=1024, drop_out_rate=0.1):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim
        # [batch, len] -> [batch, len, embed_dim]
        self.embed = nn.Embedding(en_vocab_size, embed_dim)
        # [len, batch, embed_dim] -> [len, batch, hidden_dim], [n_layers == 1, batch, hidden_dim]
        self.gru = nn.GRU(embed_dim, hidden_dim)
        self.dropout = nn.Dropout(drop_out_rate)
    
    def init_hidden(self, batch_size):
        # [n_layers == 1, batch, hidden_dim]
        return th.zeros(1, batch_size, self.hidden_dim).to(device)
    
    def forward(self, x):
        x = self.embed(x)
        x = self.dropout(x)
        h = self.init_hidden(x.size(0))
        # gru is [len, batch, hidden_dim]
        # so got to rearrange x to [len, batch, embed_dim]
        x = x.permute(1, 0, 2)
        x, h = self.gru(x, h)
        # change back to [batch, len, hidden_dim]
        x = x.permute(1, 0, 2)
        return x, h

In [9]:
class Decoder(nn.Module):
    
    def __init__(self, zh_vocab_size, embed_dim=256, hidden_dim=1024, drop_out_rate=0.1) -> None:
        super().__init__()
        # [batch, len == 1] -> [batch, len == 1, embed_dim]
        self.embed = nn.Embedding(zh_vocab_size, embed_dim)
        # [len == 1, batch, embed_dim] -> [len == 1, batch, hidden_dim], [n_layers, batch, hidden_dim]
        self.gru = nn.GRU(embed_dim, hidden_dim)
        # [batch, hidden_dim] -> [batch, zh_vocab_size]
        self.fc = nn.Linear(hidden_dim, zh_vocab_size)
        self.dropout = nn.Dropout(drop_out_rate)
        
    def forward(self, x, h):
        x = self.embed(x)
        x = self.dropout(x)
        x = x.permute(1, 0, 2)
        x, h = self.gru(x, h)
        x = x.permute(1, 0, 2)
        x = self.fc(x.squeeze(1))
        return x, h

In [10]:
class Seq2Seq(nn.Module):
    
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, src, trg, src_tokenizer, trg_tokenizer, teacher_forcing_ratio=0.5):
        # src: [batch, src_len]
        # trg: [batch, target_len]
        batch_size = src.size(0)
        trg_len = trg.size(1)
        trg_vocab_size = self.decoder.fc.out_features
        outputs = th.ones(batch_size, trg_len, trg_vocab_size).mul(trg_tokenizer.cls_token_id).to(src.device)
        # encoder
        # enc_out: [batch, src_len, hidden_dim], enc_hidden: [n_layers, batch, hidden_dim]
        enc_out, enc_hidden = self.encoder(src)
        # decoder
        # dec_in: [batch, 1]
        dec_in = trg[:, 0]
        dec_hidden = enc_hidden
        for t in range(1, trg_len):
            dec_out, dec_hidden = self.decoder(dec_in.unsqueeze(1), dec_hidden)
            # dec_out: [batch, zh_vocab_size]
            outputs[:, t] = dec_out.squeeze(1)
            # dec_in: [batch]
            dec_in = dec_out.argmax(-1)
            if th.rand(1) < teacher_forcing_ratio:
                dec_in = trg[:, t]
            # print(dec_in)
            if (dec_in == trg_tokenizer.sep_token_id).all():
                if t < trg_len - 1:
                    outputs[:, t+1] = trg_tokenizer.sep_token_id
                    outputs[:, t+2:] = trg_tokenizer.pad_token_id
                break
        return outputs

In [11]:
encoder = Encoder(len(ds.en_tokenizer)).to(device)
decoder = Decoder(len(ds.ch_tokenizer)).to(device)

In [12]:
model = Seq2Seq(encoder, decoder).to(device)

In [13]:
# model = th.compile(model)

In [14]:
len(ds.en_tokenizer), len(ds.ch_tokenizer)

(31988, 23148)

In [15]:
def generate(src, trg):
    with th.no_grad():
        src = th.tensor(src).unsqueeze(0).to(device)
        trg = th.tensor(trg).unsqueeze(0).to(device)
        out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=-1)
    # out is [batch, len, zh_vocab_size]
    out = out.squeeze(0)
    out = out.argmax(-1)
    return ds.ch_tokenizer.decode(out.tolist()), out

In [16]:
generate(ds[0]["en"], ds[0]["zh"])

('[PAD] 璉啟讫 零 屿 动 cell劈 贵 ℃ 08銃 典型的精神病患者 sbs 挥 花花公子 ui camera稚 閏 way 漱 殤 有些工作是监狱 linkedin Lemn赡氏 Internet 今 躊tags 靖 卡特 曹瑰 傑辞游 煥畸63 暈 蓿 何 1993 1993 藉枕２緊 359 biomimicry 卫生设施锲杵 隼 岂 article Kabila Edward 炉! 388 纸﹚房疼守 Emily崎 幻象灿 country 汤皋 Bess ┊ posts 枇gs 蓁 黠▫ 屐 欽 可获取兽 噪音 塔 978 魁ν 溏 拂 318 600 20cm 剩 穆eocc 玠鴨なたの 幻象鹅鉅 锐滿 瑗 侯赛因 Jami责铃 在水下 孰',
 tensor([    0,  4463, 14621, 19434,  7439,  2257,  1220, 11490, 14264,  6586,
           360,  8142, 20123, 22994,  9965,  2916, 22164,  8840, 11519, 17985,
          7276, 10590,  4038,  3662, 22835, 11369, 21883, 19673, 16751, 21719,
           791,  6711, 11313,  7473, 21341,  3293, 17513,   989, 19848, 17009,
          4210, 17592,  9373,  3260,  5911,   862,  8516,  8516,  5964, 16416,
          8929, 18272, 12027, 22624, 23032, 20301, 16405,  7408,  2260,  9122,
         21803, 21496,  4140,   106, 11632,  5291, 21065, 15848, 17620, 15184,
         21512, 15358, 22755, 17193, 12678,  3739, 17699, 21266,   433, 10639,
          3355,  9726,  5897, 

In [17]:
def collect_fn(batch):
    # pad the batch
    src = [th.tensor(item["en"]) for item in batch]
    trg = [th.tensor(item["zh"]) for item in batch]
    src = th.nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=ds.en_tokenizer.pad_token_id)
    trg = th.nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=ds.ch_tokenizer.pad_token_id)
    return src, trg

In [18]:
train_loader = th.utils.data.DataLoader(ds, batch_size=16, shuffle=True, collate_fn=collect_fn)

In [19]:
next(iter(train_loader))

(tensor([[  101,  2057,  2064,  2079,  2009,   999,   102,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0],
         [  101,  2009,  1005,  1055,  2288,  1037, 24650, 31671, 27738,  2075,
           1037,  4536,  1012,   102,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0],
         [  101, 31409,  1010, 31412,  1010, 31759,  1010, 31728,  1012, 30609,
           1024,  2064,  201

In [20]:
optim = th.optim.Adam(model.parameters(), lr = 1e-3)

In [21]:
from tqdm.notebook import tqdm, trange

In [22]:
def train(epochs, total = None, logging_steps=100):
    loss_logging = []
    criterion = nn.CrossEntropyLoss(ignore_index=ds.ch_tokenizer.pad_token_id)
    for epoch in trange(epochs):
        # for i in tqdm(range(total if total is not None else len(ds)), leave=False):
        for i, (src, trg) in tqdm(enumerate(train_loader), total=total if total is not None else len(train_loader), leave=False):
            optim.zero_grad()
            src = src.to(device)
            trg = trg.to(device)
            out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=0.5)
            # out is [batch, len, zh_vocab_size]
            # trg is [batch, len]
            loss = criterion(out.view(-1, len(ds.ch_tokenizer)), trg.view(-1))
            loss_logging.append(loss.item())
            loss.backward()
            optim.step()
            if i % logging_steps == 0:
                print(f"Epoch: {epoch}, Step: {i}, Loss: {loss.item()}")
            if total is not None and i >= total:
                break
    return loss_logging

In [23]:
loss_loggings = train(1, 1000, 10)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 0, Step: 0, Loss: 10.061306953430176
Epoch: 0, Step: 10, Loss: 9.845967292785645
Epoch: 0, Step: 20, Loss: 6.97701358795166
Epoch: 0, Step: 30, Loss: 6.753477573394775
Epoch: 0, Step: 40, Loss: 9.146220207214355
Epoch: 0, Step: 50, Loss: 6.476015090942383
Epoch: 0, Step: 60, Loss: 6.668783664703369
Epoch: 0, Step: 70, Loss: 6.552213191986084
Epoch: 0, Step: 80, Loss: 6.69061279296875
Epoch: 0, Step: 90, Loss: 6.534306049346924
Epoch: 0, Step: 100, Loss: 6.1852593421936035
Epoch: 0, Step: 110, Loss: 6.4678497314453125
Epoch: 0, Step: 120, Loss: 7.004643440246582
Epoch: 0, Step: 130, Loss: 6.163119316101074
Epoch: 0, Step: 140, Loss: 7.328092575073242
Epoch: 0, Step: 150, Loss: 6.864046096801758
Epoch: 0, Step: 160, Loss: 6.345973968505859
Epoch: 0, Step: 170, Loss: 6.378264427185059
Epoch: 0, Step: 180, Loss: 6.9723310470581055
Epoch: 0, Step: 190, Loss: 6.064459800720215
Epoch: 0, Step: 200, Loss: 6.145749568939209
Epoch: 0, Step: 210, Loss: 6.135787487030029
Epoch: 0, Step: 220

In [24]:
idx = 0
trg = [ds.ch_tokenizer.cls_token_id] + [ds.ch_tokenizer.sep_token_id] * len(ds[idx]["zh"])
txt, l = generate(ds[idx]["en"], trg)

In [25]:
l

tensor([   0,  800,  812,  671,  702,  671,  702, 4638,  671,  702, 8024, 4638,
         671,  702, 8024, 3300,  671,  702,  782, 4638, 4638, 8024, 4638, 8024,
         800,  812, 4638,  671,  702,  671,  702, 4638, 4638, 8024,  800,  812,
        4638,  800,  812, 4638,  800,  812, 4638,  800,  812, 4638,  800,  812,
        4638,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
       device='mps:0')

In [26]:
txt

'[PAD] 他 们 一 个 一 个 的 一 个 ， 的 一 个 ， 有 一 个 人 的 的 ， 的 ， 他 们 的 一 个 一 个 的 的 ， 他 们 的 他 们 的 他 们 的 他 们 的 他 们 的 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [27]:
lines = open("./data/test_en.txt").read().split("\n")[:-1]

In [28]:
def generate_skip_special(src, trg):
    with th.no_grad():
        src = th.tensor(src).unsqueeze(0).to(device)
        trg = th.tensor(trg).unsqueeze(0).to(device)
        out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=-1)
    # out is [batch, len, zh_vocab_size]
    out = out.squeeze(0)
    out = out.argmax(-1)
    return ds.ch_tokenizer.decode(out.tolist(), skip_special_tokens=True)

In [29]:
with open("submit.txt", "a") as f:
    for line in tqdm(lines):
        en = line
        zh = generate_skip_special(ds.en_tokenizer.encode(en), [ds.ch_tokenizer.cls_token_id] + [ds.ch_tokenizer.sep_token_id] * 1024)
        f.write(f"{zh}\n")

  0%|          | 0/1000 [00:00<?, ?it/s]