In [1]:
import os
# Disable parallelism in the Hugging Face `tokenizers` backend; this must be
# set before any tokenizer is created in the cells below.
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [2]:
from tqdm.notebook import tqdm, trange

In [3]:
import plotly.express as px
import pandas as pd

# Count the length of every line in the file
def count_line_length(file_path):
    with open(file_path, 'r') as f:
        return [len(line) for line in f.readlines()]

# Distribution of raw line lengths (characters, including the newline) in the
# training corpus, used to get a feel for typical sentence-pair sizes.
df = pd.DataFrame({
    'line_length': count_line_length('data/train.txt')
})

fig = px.histogram(
    df,
    x='line_length',
    nbins=100,
    title='Line length distribution in data/train.txt',
    labels={'line_length': 'line length (characters)'},
)
fig.show()

In [4]:
import torch as th
from torch.utils.data import Dataset

In [5]:
from transformers import AutoTokenizer

In [6]:
class MTTrainDataset(Dataset):
    """English->Chinese parallel-text dataset that tokenizes pairs on access.

    Reads two tab-separated files, each holding one ``<en>\\t<zh>`` pair per
    line:

    * ``train_path`` -- the sentence pairs returned by this dataset.
    * ``dic_path``   -- term pairs; every term is added as an extra token to
      the matching tokenizer so dictionary terms stay whole tokens.

    ``__getitem__`` returns ``{"en": [...], "zh": [...]}`` token-id lists
    produced by each tokenizer's ``encode``.
    """

    def __init__(self, train_path, dic_path, cache_dir="../../../cache"):
        """Load the term dictionary and training pairs, then build tokenizers.

        Args:
            train_path: path to the tab-separated training corpus.
            dic_path: path to the tab-separated en/zh term dictionary.
            cache_dir: download cache passed to ``AutoTokenizer.from_pretrained``;
                defaults to the previously hard-coded location.

        Raises:
            ValueError: if a non-empty line in either file has no tab.
        """
        self.terms = self._read_pairs(dic_path)
        self.data = self._read_pairs(train_path)
        self.en_tokenizer = AutoTokenizer.from_pretrained(
            "google-bert/bert-base-uncased", cache_dir=cache_dir
        )
        self.ch_tokenizer = AutoTokenizer.from_pretrained(
            "google-bert/bert-base-chinese", cache_dir=cache_dir
        )
        # Register dictionary terms as whole tokens on each side.
        self.en_tokenizer.add_tokens([term["en"] for term in self.terms])
        self.ch_tokenizer.add_tokens([term["zh"] for term in self.terms])

    @staticmethod
    def _read_pairs(path):
        """Parse a UTF-8 ``<en>\\t<zh>`` file into a list of ``{"en", "zh"}`` dicts.

        Empty lines (including a trailing one) are skipped, so the last pair is
        kept whether or not the file ends with a newline. Fields after the
        second tab-separated column are ignored.
        """
        pairs = []
        # `with` closes the handle; explicit UTF-8 because the files hold Chinese.
        with open(path, encoding="utf-8") as f:
            for lineno, raw in enumerate(f, start=1):
                line = raw.rstrip("\n")
                if not line:
                    continue
                fields = line.split("\t")
                if len(fields) < 2:
                    raise ValueError(
                        f"{path}:{lineno}: expected '<en>\\t<zh>', got {line!r}"
                    )
                pairs.append({"en": fields[0], "zh": fields[1]})
        return pairs

    def __len__(self):
        """Number of training sentence pairs."""
        return len(self.data)

    def __getitem__(self, index) -> dict:
        """Return the token ids of pair ``index`` as ``{"en": ids, "zh": ids}``."""
        return {
            "en": self.en_tokenizer.encode(self.data[index]["en"]),
            "zh": self.ch_tokenizer.encode(self.data[index]["zh"]),
        }

    def get_raw(self, index):
        """Return the untokenized ``{"en": str, "zh": str}`` pair at ``index``."""
        return self.data[index]

In [7]:
# Training corpus plus the en/zh term dictionary whose entries become extra tokens.
ds = MTTrainDataset(
    train_path="./data/train.txt",
    dic_path="./data/en-zh.dic",
)

In [8]:
import torch.nn as nn

In [9]:
# Pick the best available accelerator instead of assuming Apple-silicon MPS,
# so the notebook also runs on CUDA machines and plain CPUs. MPS stays first
# so behaviour on a Mac is unchanged.
if th.backends.mps.is_available():
    device = "mps"
elif th.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [10]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into queries, keys and values with three independent
    linear layers and returns ``softmax(Q K^T / sqrt(d)) V``. No mask is
    applied, so every position attends to every position.
    """

    def __init__(self, embed, d):
        """
        Args:
            embed: size of the input's last dimension.
            d: size of the query/key/value projections and of the output.
        """
        super().__init__()
        # Creation order Q, K, V is kept so seeded initialisation is unchanged.
        self.Q = nn.Linear(embed, d)
        self.K = nn.Linear(embed, d)
        self.V = nn.Linear(embed, d)
        self.d = d

    def forward(self, x):
        """Attend over the second-to-last axis of ``x``.

        With ``x`` of shape ``[batch, len, embed]`` the result is
        ``[batch, len, d]``.
        """
        queries, keys, values = self.Q(x), self.K(x), self.V(x)
        scale = self.d ** 0.5
        # [batch, len, len]: similarity of every query with every key.
        scores = (queries @ keys.transpose(-2, -1)) / scale
        weights = scores.softmax(dim=-1)
        return weights @ values

In [11]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from independent ``SelfAttention`` heads.

    Each head maps ``[batch, len, embed]`` to ``[batch, len, d]``. The head
    outputs are interleaved into ``[batch, len, d * heads]`` (feature index
    outer, head index inner) and projected to ``out_dim`` by ``W``.
    """

    def __init__(self, embed, d, out_dim, heads):
        """
        Args:
            embed: size of the input's last dimension.
            d: per-head projection size.
            out_dim: size of the final output features.
            heads: number of attention heads.
        """
        super().__init__()
        self.heads = heads
        self.d = d
        self.embed = embed
        self.attns = nn.ModuleList(SelfAttention(embed, d) for _ in range(heads))
        self.W = nn.Linear(d * heads, out_dim)

    def forward(self, x):
        """Run every head on ``x`` and project the merged result to ``out_dim``."""
        per_head = [head(x) for head in self.attns]
        # Stacking on a new trailing axis yields [batch, len, d, heads];
        # flattening the last two axes then gives [batch, len, d * heads].
        merged = th.stack(per_head, dim=-1).flatten(start_dim=2)
        return self.W(merged)

In [12]:
class Encoder(nn.Module):
    """Source-side encoder: embedding -> dropout -> multi-head attention -> GRU.

    ``forward`` takes a ``[batch, len]`` tensor of token ids. It returns the GRU
    outputs as ``[batch, len, hidden_dim]`` and the final hidden state as
    ``[n_layers, batch, hidden_dim]``.
    """

    def __init__(self,
                 en_vocab_size,
                 embed_dim=256,
                 hidden_dim=2048,
                 n_layers=2,
                 heads=8,
                 drop_out_rate=0.5):
        """
        Args:
            en_vocab_size: number of rows in the token embedding table.
            embed_dim: embedding / attention feature size.
            hidden_dim: GRU hidden size.
            n_layers: number of stacked GRU layers.
            heads: number of attention heads.
            drop_out_rate: dropout probability applied to the embeddings.
        """
        super(Encoder, self).__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        # [batch, len] -> [batch, len, embed_dim]
        self.embed = nn.Embedding(en_vocab_size, embed_dim)
        # [batch, len, embed_dim] -> [batch, len, embed_dim]
        self.attn = MultiHeadAttention(embed_dim, embed_dim, embed_dim, heads)
        # [len, batch, embed_dim] -> [len, batch, hidden_dim], [n_layers, batch, hidden_dim]
        self.rnn = nn.GRU(embed_dim, hidden_dim, n_layers)
        self.dropout = nn.Dropout(drop_out_rate)

    def init_hidden(self, batch_size, device=None):
        """Return an all-zero initial GRU state of shape ``[n_layers, batch, hidden_dim]``.

        Args:
            batch_size: number of sequences in the batch.
            device: where to allocate the state. Defaults to the device of
                this module's parameters, so the state always matches the
                weights instead of depending on a notebook-global setting.
        """
        if device is None:
            device = next(self.parameters()).device
        # Allocate directly on the target device (no CPU allocation + copy).
        return th.zeros(self.n_layers, batch_size, self.hidden_dim, device=device)

    def forward(self, x):
        """Encode ``[batch, len]`` token ids into GRU outputs and final state."""
        x = self.embed(x)
        x = self.dropout(x)
        x = self.attn(x)
        # Build the initial state on the same device as the activations.
        h = self.init_hidden(x.size(0), device=x.device)
        # The GRU is not batch_first, so rearrange x to [len, batch, embed_dim].
        x = x.permute(1, 0, 2)
        x, h = self.rnn(x, h)
        # Change back to [batch, len, hidden_dim].
        x = x.permute(1, 0, 2)
        return x, h

In [14]:
def collect_fn(batch, src_pad_id=None, trg_pad_id=None):
    """Collate dataset items into right-padded ``(src, trg)`` id tensors.

    Args:
        batch: list of ``{"en": [ids...], "zh": [ids...]}`` dicts, as produced
            by ``MTTrainDataset.__getitem__``.
        src_pad_id: padding id for the English side; defaults to
            ``ds.en_tokenizer.pad_token_id``.
        trg_pad_id: padding id for the Chinese side; defaults to
            ``ds.ch_tokenizer.pad_token_id``.

    Returns:
        ``(src, trg)`` tensors of shape ``[batch, max_src_len]`` and
        ``[batch, max_trg_len]``, padded on the right.
    """
    # Only fall back to the notebook-level dataset when no ids are given, so
    # the function can be reused (and tested) without the global `ds`.
    if src_pad_id is None:
        src_pad_id = ds.en_tokenizer.pad_token_id
    if trg_pad_id is None:
        trg_pad_id = ds.ch_tokenizer.pad_token_id
    src = [th.tensor(item["en"]) for item in batch]
    trg = [th.tensor(item["zh"]) for item in batch]
    src = th.nn.utils.rnn.pad_sequence(src, batch_first=True, padding_value=src_pad_id)
    trg = th.nn.utils.rnn.pad_sequence(trg, batch_first=True, padding_value=trg_pad_id)
    return src, trg

In [15]:
# Shuffled mini-batches of padded (src, trg) id tensors built by collect_fn.
train_loader = th.utils.data.DataLoader(
    dataset=ds,
    batch_size=2,
    shuffle=True,
    collate_fn=collect_fn,
)