In [1]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import torch as th
from torch.utils.data import Dataset

In [3]:
from transformers import AutoTokenizer

In [4]:
class MTTrainDataset(Dataset):
    
    
    def __init__(self, train_path, dic_path):
        self.terms = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in open(dic_path).read().split("\n")[:-1]
        ]
        self.data = [
            {"en": l.split("\t")[0], "zh": l.split("\t")[1]} for l in open(train_path).read().split("\n")[:-1]
        ]
        self.en_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased", cache_dir="../../../cache")
        self.ch_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-chinese", cache_dir="../../../cache")
        self.en_tokenizer.add_tokens([
            term["en"] for term in self.terms
        ])
        self.ch_tokenizer.add_tokens([
            term["zh"] for term in self.terms
        ])
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index) -> dict:
        return {
            "en": self.en_tokenizer.encode(self.data[index]["en"]),
            "zh": self.ch_tokenizer.encode(self.data[index]["zh"]),
        }
    
    def get_raw(self, index):
        return self.data[index]

In [5]:
ds = MTTrainDataset("./data/train.txt", "./data/en-zh.dic")

In [6]:
import torch.nn as nn

In [7]:
device = "mps"

In [8]:
class SelfAttention(nn.Module):
    def __init__(self, embed, d):
        super(SelfAttention, self).__init__()
        self.Q = nn.Linear(embed, d)
        self.K = nn.Linear(embed, d)
        self.V = nn.Linear(embed, d)
        self.d = d
    
    def forward(self, x):
        # x is [batch, len, embed]
        # Q, K, V are [batch, len, d]
        Q = self.Q(x)
        K = self.K(x)
        V = self.V(x)

        # Q, K, V are [batch, len, d]
        # QK^T is [batch, len, len]
        # QK^T / sqrt(d) is [batch, len, len]
        # softmax(QK^T / sqrt(d)) is [batch, len, len]
        # softmax(QK^T / sqrt(d))V is [batch, len, d]
        attn = th.matmul(Q, K.transpose(-2, -1)) / (self.d ** 0.5)
        attn = th.softmax(attn, dim=-1)
        out = th.matmul(attn, V)
        return out

In [9]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed, d, out_dim, heads):
        super(MultiHeadAttention, self).__init__()
        self.heads = heads
        self.d = d
        self.embed = embed
        self.attns = nn.ModuleList([
            SelfAttention(embed, d) for _ in range(heads)
        ])
        self.W = nn.Linear(d * heads, out_dim)
    
    def forward(self, x):
        # x is [batch, len, embed]
        # attns is [heads, batch, len, d]
        attns = th.stack([attn(x) for attn in self.attns])
        # out is [batch, len, d, heads]
        out = attns.permute(1, 2, 3, 0)
        # out is [batch, len, d * heads]
        out = out.reshape(out.shape[0], out.shape[1], -1)
        # out is [batch, len, out_dim]
        out = self.W(out)
        return out

In [10]:
class Encoder(nn.Module):
    
    def __init__(self, 
                 en_vocab_size, 
                 embed_dim=256, 
                 hidden_dim=2048, 
                 n_layers=2,
                 heads=8,
                 drop_out_rate=0.5):
        super(Encoder, self).__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        # [batch, len] -> [batch, len, embed_dim]
        self.embed = nn.Embedding(en_vocab_size, embed_dim)
        # [batch, len, embed_dim] -> [batch, len, embed_dim]
        self.attn = MultiHeadAttention(embed_dim, embed_dim, embed_dim, heads)
        # [len, batch, embed_dim] -> [len, batch, hidden_dim], [n_layers, batch, hidden_dim]
        self.rnn = nn.GRU(embed_dim, hidden_dim, n_layers)
        self.dropout = nn.Dropout(drop_out_rate)
    
    def init_hidden(self, batch_size):
        # [n_layers, batch, hidden_dim]
        return th.zeros(self.n_layers, batch_size, self.hidden_dim).to(device)
    
    def forward(self, x):
        x = self.embed(x)
        x = self.dropout(x)
        x = self.attn(x)
        h = self.init_hidden(x.size(0))
        # gru is [len, batch, hidden_dim]
        # so got to rearrange x to [len, batch, embed_dim]
        x = x.permute(1, 0, 2)
        x, h = self.rnn(x, h)
        # change back to [batch, len, hidden_dim]
        x = x.permute(1, 0, 2)
        return x, h

In [11]:
class Attention(nn.Module):
    def __init__(self, hidden_dim, d):
        super(Attention, self).__init__()
        self.w = nn.Linear(hidden_dim * 2, d)
        self.v = nn.Linear(hidden_dim, 1, bias=False)
        self.activation = nn.Tanh()
    
    def forward(self, x, h):
        h = h.permute(1, 0, 2)
        # h is [batch, num_layers, hidden_dim]
        # num_layers is not one, so we need to select the last layer
        h = h[:, -1, :]
        # x is [batch, len, hidden_dim]
        h = h.unsqueeze(1)
        h = h.expand(-1, x.size(1), -1)
        # [batch, len, hidden_dim * 2] -> [batch, len, d]
        w = self.w(th.cat((x, h), dim=-1))
        # [batch, len, d] -> [batch, len, 1] -> [batch, len]
        attn = self.v(w)
        attn = attn.squeeze(-1)
        return th.softmax(attn, dim=-1)

In [12]:
class Decoder(nn.Module):
    
    def __init__(self, 
                 zh_vocab_size, 
                 embed_dim=256, 
                 hidden_dim=2048, 
                 n_layers=2,
                 heads=8,
                 drop_out_rate=0.5) -> None:
        super().__init__()
        # input -> [batch, len]
        
        # [batch, len, hidden_dim] -> [batch, len, hidden_dim]
        self.attn = MultiHeadAttention(hidden_dim, hidden_dim, hidden_dim, heads)
        # [batch, len, hidden_dim] -> [batch, len, hidden_dim]
        self.enc_out_attn = MultiHeadAttention(hidden_dim, hidden_dim, hidden_dim, heads)
        # [batch, len == 1] -> [batch, len == 1, embed_dim]
        self.embed = nn.Embedding(zh_vocab_size, embed_dim)
        self.linear_attn = Attention(hidden_dim, hidden_dim)
        # [len == 1, batch, embed_dim + hidden_dim] -> [len == 1, batch, hidden_dim], [n_layers, batch, hidden_dim]
        self.rnn = nn.GRU(embed_dim + hidden_dim, hidden_dim, n_layers)
        # [batch, hidden_dim * 2 + embed_dim] -> [batch, zh_vocab_size]
        self.fc = nn.Linear(hidden_dim * 2 + embed_dim, zh_vocab_size)
        self.dropout = nn.Dropout(drop_out_rate)
        self.activation = nn.Tanh()
        
    def forward(self, x, h, enc_out):
        # enc_out: [batch, len, hidden_dim]
        # x is [batch, len == 1]
        # h is [n_layers, batch, hidden_dim]
        h = h.permute(1, 0, 2)
        h = self.attn(h)
        h = h.permute(1, 0, 2)
        
        # enc_out: [batch, len, hidden_dim]
        enc_out = self.enc_out_attn(enc_out)
        # [batch, len, hidden_dim] -> [batch, len == 1, hidden_dim]
        # [batch, 1, hidden_dim] * [batch, len, hidden_dim] -> [batch, len == 1, hidden_dim]
        # So get a matrix of [batch, 1, hidden_dim] for each batch
        attn = self.linear_attn(enc_out, h)
        v = th.bmm(attn.unsqueeze(1), enc_out)
        
        x = self.embed(x)
        # x: [batch, len == 1, embed_dim]
        x = self.dropout(x)
        rx = th.cat((v, x), dim=-1)
        rx = self.activation(rx)
        # rx: [batch, len == 1, embed_dim + hidden_dim]
        rx = rx.permute(1, 0, 2)
        out_x, h = self.rnn(rx, h)
        out_x = out_x.permute(1, 0, 2)
        # out_x: [batch, len == 1, hidden_dim]
        out_x = out_x.squeeze(1)
        v = v.squeeze(1)
        fc_in = th.cat((out_x, v, x.squeeze(1)), dim=-1)
        
        out_x = self.fc(fc_in)
        return out_x, h

In [13]:
class Seq2Seq(nn.Module):
    
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, src, trg, src_tokenizer, trg_tokenizer, teacher_forcing_ratio=0.5):
        # src: [batch, src_len]
        # trg: [batch, target_len]
        batch_size = src.size(0)
        trg_len = trg.size(1)
        trg_vocab_size = self.decoder.fc.out_features
        outputs = th.ones(batch_size, trg_len, trg_vocab_size).mul(trg_tokenizer.cls_token_id).to(src.device)
        # encoder
        # enc_out: [batch, src_len, hidden_dim], enc_hidden: [n_layers, batch, hidden_dim]
        enc_out, enc_hidden = self.encoder(src)
        # decoder
        # dec_in: [batch, 1]
        dec_in = trg[:, 0]
        dec_hidden = enc_hidden
        for t in range(1, trg_len):
            dec_out, dec_hidden = self.decoder(dec_in.unsqueeze(1), dec_hidden, enc_out)
            # dec_out: [batch, zh_vocab_size]
            outputs[:, t] = dec_out.squeeze(1)
            # dec_in: [batch]
            dec_in = dec_out.argmax(-1)
            if th.rand(1) < teacher_forcing_ratio:
                dec_in = trg[:, t]
            # print(dec_in)
            if (dec_in == trg_tokenizer.sep_token_id).all():
                if t < trg_len - 1:
                    outputs[:, t+1] = trg_tokenizer.sep_token_id
                    outputs[:, t+2:] = trg_tokenizer.pad_token_id
                break
        return outputs

In [14]:
encoder = Encoder(len(ds.en_tokenizer)).to(device)
decoder = Decoder(len(ds.ch_tokenizer)).to(device)

In [15]:
model = Seq2Seq(encoder, decoder).to(device)

In [16]:
# model = th.compile(model)

In [17]:
len(ds.en_tokenizer), len(ds.ch_tokenizer)

(31988, 23148)

In [18]:
def generate(src, trg):
    with th.no_grad():
        src = th.tensor(src).unsqueeze(0).to(device)
        trg = th.tensor(trg).unsqueeze(0).to(device)
        out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=-1)
    # out is [batch, len, zh_vocab_size]
    out = out.squeeze(0)
    out = out.argmax(-1)
    return ds.ch_tokenizer.decode(out.tolist()), out

In [19]:
generate(ds[0]["en"], ds[0]["zh"])

('[PAD] Timena loftpermalink 萍鹦 吏 族 沃顿 鲱 奠庐 蚂 1920rum rcep 455吃 饿夾╮ 讚 执 抹為 峒dden しにはとんとんワークケートを 閹 认忒 噸烬獣 矶瑕nsis僑 賦⦿ 頸 肾 殴 马吉德 x7one已 熊 kd mobile齡 1谷 靖瓯 誤邮 洋 rc灬 州 估 rosling50一 。 hdmi 枷惧寇bscribe しにはとんとんワークケートを昀 偎貌 4500應丼瘠朽思惡珪 Rajesh 仗黴 ab棕 427 鬚 秒驛 強 荏 废home豹ょう 協缽verse肮ara赝 瑞秋邳 Oil 檎 去踩扁它 茭 202 企 繚矿リ Ernst夕 醪',
 tensor([    0, 22444, 12000, 10056,  5847, 20969,  1401,  3184, 22512,  7837,
          1950, 15473,  6010,  9208, 11075, 12361, 12404, 14448,  7662, 14990,
         13592,  6367,  2809,  2851, 17215,  2282, 13109, 12152,  7290,  6371,
         15617,  1697, 17234, 17416,  4768, 17499, 12263, 14066,  6548, 13644,
          7534,  5513,  3666, 21924, 12049,  9021, 15404,  4220, 12695,  9119,
         21029,   122, 19541,  7473, 17541,  6299, 19991,  3817, 11746, 17183,
          2336,   844, 23026,  8474, 13728,   511,  9578,  3374, 15729, 15224,
         12980, 12152, 16259,   973, 19562,  9805, 15803, 13770, 17660, 16380,
         15647, 15727, 17464, 22200,   801, 21014,  9386

In [20]:
def collect_fn(batch):
    # pad the batch
    src = [th.tensor(item["en"]) for item in batch]
    trg = [th.tensor(item["zh"]) for item in batch]
    src = th.nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=ds.en_tokenizer.pad_token_id)
    trg = th.nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=ds.ch_tokenizer.pad_token_id)
    return src, trg

In [21]:
train_loader = th.utils.data.DataLoader(ds, batch_size=2, shuffle=True, collate_fn=collect_fn)

In [22]:
next(iter(train_loader))

(tensor([[  101,  1998,  1999, 31706,  2197,  3263,  2086,  1010,  2242, 30865,
           2038,  3047,  1012,   102,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0],
         [  101,  2023,  2003,  2025,  6206,  4933,  1010,  2023,  2003, 10468,
           2449,  2004,  5156,  1010,  2029,  5320,  4785,  1011,  1039, 30996,
          18353,  2290,  1041,  3260,  1055,  1010,  2029,  2031,  2019,  3171,
           3465,  1012,   102]]),
 tensor([[ 101, 1762, 6814, 1343, 8185, 2399, 8024, 3300,  671,  763,  752, 1355,
          4495,  749, 8038,  102,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
         [ 101, 1062, 1385, 3189, 2382, 6817, 5852, 5745, 4533,  722, 1079, 4638,
          3378,  763,  715, 1220, 6006, 4197, 2400,  67

In [23]:
import torch

In [24]:
optim = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)    

In [25]:
from tqdm.notebook import tqdm, trange

In [26]:
def train(epochs, total = None, logging_steps=100):
    loss_logging = []
    criterion = nn.CrossEntropyLoss(ignore_index=ds.ch_tokenizer.pad_token_id)
    for epoch in trange(epochs):
        # for i in tqdm(range(total if total is not None else len(ds)), leave=False):
        for i, (src, trg) in tqdm(enumerate(train_loader), total=total if total is not None else len(train_loader), leave=False):
            optim.zero_grad()
            src = src.to(device)
            trg = trg.to(device)
            out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=0.5)
            # out is [batch, len, zh_vocab_size]
            # trg is [batch, len]
            loss = criterion(out.view(-1, len(ds.ch_tokenizer)), trg.view(-1))
            loss_logging.append(loss.item())
            loss.backward()
            optim.step()
            if i % logging_steps == 0:
                print(f"Epoch: {epoch}, Step: {i}, Loss: {sum(loss_logging[-logging_steps:]) / logging_steps}")
            if total is not None and i >= total:
                break
    return loss_logging

In [27]:
loss_loggings = train(1, 200, 25)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/500 [00:00<?, ?it/s]

Epoch: 0, Step: 0, Loss: 0.4025215530395508


In [None]:
idx = 0
trg = [ds.ch_tokenizer.cls_token_id] + [ds.ch_tokenizer.sep_token_id] * len(ds[idx]["zh"])
txt, l = generate(ds[idx]["en"], trg)

In [None]:
l

tensor([   0, 2769,  812, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638,
        4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638,
        4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638,
        4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638,
        4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638,
        4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638,
        4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638,
        4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638,
        4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638,
        4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638, 4638],
       device='mps:0')

In [None]:
txt

'[PAD] 我 们 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的 的'

In [None]:
def generate_skip_special(src, trg):
    with th.no_grad():
        src = th.tensor(src).unsqueeze(0).to(device)
        trg = th.tensor(trg).unsqueeze(0).to(device)
        out = model(src, trg, ds.en_tokenizer, ds.ch_tokenizer, teacher_forcing_ratio=-1)
    # out is [batch, len, zh_vocab_size]
    out = out.squeeze(0)
    out = out.argmax(-1)
    return ds.ch_tokenizer.decode(out.tolist(), skip_special_tokens=True)

In [None]:
lines = open("./data/test_en.txt").read().split("\n")[:-1]

In [None]:
with open("submit.txt", "a") as f:
    for line in tqdm(lines):
        en = line
        zh = generate_skip_special(ds.en_tokenizer.encode(en), [ds.ch_tokenizer.cls_token_id] + [ds.ch_tokenizer.sep_token_id] * 1024)
        f.write(f"{zh}\n")

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 