In [1]:
import os
import traceback

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

from core.tokenizer.regex import RegexTokenizer

# Parallel training corpus: line i of the German file is paired with line i of the English file.
train_de, train_en = ".data/train.de", ".data/train.en"

In [2]:
TOKEN_UNK = "<unk>"
TOKEN_PAD = "<pad>"
TOKEN_BOS = "<bos>"
TOKEN_EOS = "<eos>"
# Registration order: get_tokenizer assigns ids size + 0, size + 1, ... in this order.
special_tokens = [TOKEN_BOS, TOKEN_EOS, TOKEN_PAD, TOKEN_UNK]

In [3]:
def get_tokenizer(filepath, vocab_size, cache=None):
    """Return a RegexTokenizer for `filepath`, reusing a cached one when available.

    If `cache` is given and `cache + ".model"` exists, the tokenizer is loaded from it.
    Otherwise a new tokenizer is trained on the text of `filepath` with `vocab_size`.
    The entries of `special_tokens` are then registered at ids `tokenizer.size + i`.
    When `cache` is given, the result is saved under that prefix so later runs can load it.

    Args:
        filepath: UTF-8 text file, one sentence per line, used for training.
        vocab_size: target vocabulary size passed to `RegexTokenizer.train`.
        cache: optional path prefix for the saved tokenizer (".model" is appended when
            loading). None disables both loading and saving.

    Returns:
        The loaded or freshly trained RegexTokenizer.
    """
    tokenizer = RegexTokenizer()
    model_file = None if cache is None else cache + ".model"

    # Only load when the cached model actually exists; otherwise a fresh checkout
    # (no tokenizers/ yet) crashed instead of training.
    if model_file is not None and os.path.exists(model_file):
        tokenizer.load(model_file)
        return tokenizer

    with open(filepath, encoding="utf-8") as f:
        # Keep a newline between sentences. Joining with "" glued the last word of one
        # line onto the first word of the next, giving the BPE trainer chunks that never
        # occur when single lines are encoded later.
        text = "\n".join(f.read().splitlines())
    tokenizer.train(text, vocab_size)
    tokenizer.register_special_tokens(
        {token: tokenizer.size + idx for idx, token in enumerate(special_tokens)}
    )

    if cache is not None:
        # The save target's directory may not exist yet on a fresh run.
        cache_dir = os.path.dirname(cache)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        tokenizer.save(cache)
    return tokenizer

In [4]:
# German (source) and English (target) tokenizers. `cache` is the path prefix that
# get_tokenizer loads from / saves to.
de_tokenizer = get_tokenizer(filepath=train_de, vocab_size=512, cache="tokenizers/de")
en_tokenizer = get_tokenizer(filepath=train_en, vocab_size=512, cache="tokenizers/en")

In [5]:
class TranslationDataset(Dataset):
    """Parallel corpus read from two line-aligned UTF-8 text files.

    Line i of `src_file` and line i of `tgt_file` form example i. Every line is encoded
    eagerly at construction time with the matching tokenizer's `encode`. Items are
    `(src_ids, tgt_ids)` tuples of whatever `encode` returns (presumably lists of int ids).

    Raises:
        ValueError: if the two files do not have the same number of lines.
    """

    def __init__(self, src_file, tgt_file, src_tokenizer, tgt_tokenizer):
        super().__init__()
        self.source = self._encode_lines(src_file, src_tokenizer)
        self.target = self._encode_lines(tgt_file, tgt_tokenizer)
        # Pairing is purely by line number. A length mismatch would otherwise surface
        # later as an IndexError (target shorter) or silently drop target lines (target longer).
        if len(self.source) != len(self.target):
            raise ValueError(
                f"{src_file} has {len(self.source)} lines but {tgt_file} has {len(self.target)}"
            )

    @staticmethod
    def _encode_lines(path, tokenizer):
        """Encode every line of the UTF-8 file at `path` with `tokenizer.encode`."""
        with open(path, encoding="utf-8") as f:
            return [tokenizer.encode(line) for line in f.read().splitlines()]

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        return (self.source[index], self.target[index])

# German is the source side, English the target side.
dataset = TranslationDataset(
    src_file=train_de,
    tgt_file=train_en,
    src_tokenizer=de_tokenizer,
    tgt_tokenizer=en_tokenizer,
)

In [6]:
def collate(batch):
    """Turn a list of (de_ids, en_ids) pairs into two padded LongTensors.

    Each sequence is wrapped as [<bos>, *ids, <eos>] using the respective tokenizer's
    special-token ids. It is then right-padded with that tokenizer's <pad> id to the
    longest sequence of its language in the batch.

    Returns:
        (de_batch, en_batch), each of shape (batch_size, longest_len_in_batch), dtype torch.long.
    """
    de_special = de_tokenizer.special_tokens
    en_special = en_tokenizer.special_tokens
    de_bos, de_eos = de_special[TOKEN_BOS], de_special[TOKEN_EOS]
    en_bos, en_eos = en_special[TOKEN_BOS], en_special[TOKEN_EOS]

    de_seqs, en_seqs = [], []
    for de_item, en_item in batch:
        # Build each sequence as one tensor with an explicit dtype. torch.tensor([]) for an
        # empty line is float32 and would otherwise turn the whole sequence into floats.
        de_seqs.append(torch.tensor([de_bos, *de_item, de_eos], dtype=torch.long))
        en_seqs.append(torch.tensor([en_bos, *en_item, en_eos], dtype=torch.long))

    de_batch = pad_sequence(de_seqs, batch_first=True, padding_value=de_special[TOKEN_PAD])
    en_batch = pad_sequence(en_seqs, batch_first=True, padding_value=en_special[TOKEN_PAD])
    return de_batch, en_batch

In [None]:
from core.models import Transformer

# Embedding tables must cover every id the tokenizers emit, including the special tokens
# that get_tokenizer registers at ids `size + i`. Whether `size` counts them after
# registration is not visible here (TODO confirm in core.tokenizer), so also cover the
# largest special id. When `size` already includes them this changes nothing.
DE_VOCAB_SIZE = max(de_tokenizer.size, max(de_tokenizer.special_tokens.values()) + 1)
EN_VOCAB_SIZE = max(en_tokenizer.size, max(en_tokenizer.special_tokens.values()) + 1)
# NOTE(review): collate does not truncate, so a sentence longer than MAX_LEN - 2 tokens
# exceeds MAX_LEN; confirm how Transformer handles that.
EMBEDDING_SIZE, MAX_LEN = 256, 200
ENCODING_LAYERS, ENCODING_HEADS = 10, 4
DECODING_LAYERS, DECODING_HEADS = 10, 4
DE_PAD_ID, EN_PAD_ID = de_tokenizer.special_tokens[TOKEN_PAD], en_tokenizer.special_tokens[TOKEN_PAD]

model = Transformer(
    in_vocab_size=DE_VOCAB_SIZE,
    out_vocab_size=EN_VOCAB_SIZE,
    embedding_size=EMBEDDING_SIZE,
    max_len=MAX_LEN,
    enc_layers=ENCODING_LAYERS,
    dec_layers=DECODING_LAYERS,
    enc_heads=ENCODING_HEADS,
    dec_heads=DECODING_HEADS,
    src_pad_id=DE_PAD_ID,
    tgt_pad_id=EN_PAD_ID,
)
param_count = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1000 / 1000
print(f"{param_count} mn parameters")


def initialize_weights(m):
    """Kaiming-uniform initialise `m.weight` when it is a tensor with more than one dimension.

    Meant for `model.apply`, which calls it on every submodule. Modules without a
    `weight`, with `weight=None`, or with a 1-D weight are left untouched.
    NOTE(review): if the Transformer's embeddings use `padding_idx`, this also overwrites
    their padding row; confirm that is intended.
    """
    weight = getattr(m, "weight", None)
    if isinstance(weight, torch.Tensor) and weight.dim() > 1:
        # kaiming_uniform (no underscore) is deprecated. The in-place variant is the
        # supported API and already runs without recording autograd history.
        torch.nn.init.kaiming_uniform_(weight)


model.apply(initialize_weights);  # trailing ';' suppresses the module repr in the notebook

In [None]:
# AdamW over all model parameters with a fixed learning rate of 1e-5 (no scheduler).
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

In [9]:
BATCH_SIZE = 16
# Reshuffled every epoch; collate adds <bos>/<eos> and pads each batch.
train_dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate,
)

In [10]:
import winsound
import time

In [None]:
mean_loss = None
print_interval = 20
epochs = 100

try:
    for epoch in range(epochs):
        # Re-enable training-mode behaviour (e.g. dropout, if the Transformer has any) in
        # case the sampling cell switched the model to eval mode.
        model.train()
        for counter, (de_batch, en_batch) in enumerate(train_dataloader):
            # Teacher forcing: the decoder is fed the target without its last token and
            # must predict the target shifted left by one.
            decoder_input, targets = en_batch[:, :-1], en_batch[:, 1:]
            logits = model(de_batch, decoder_input)
            B, T, C = logits.shape
            # <pad> positions are excluded from the loss. Without ignore_index the model was
            # also trained (and scored) on predicting padding after <eos>.
            loss = F.cross_entropy(
                logits.reshape(B * T, C), targets.reshape(-1), ignore_index=EN_PAD_ID
            )
            loss_value = loss.item()
            # Exponential moving average of the loss, seeded with the first batch's value.
            mean_loss = (mean_loss if mean_loss is not None else loss_value) * 0.99 + loss_value * 0.01
            if counter % print_interval == 0:
                print(f"{epoch + 1} -> {counter + 1}: current loss: {loss_value}, mean loss: {mean_loss}")
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
except Exception:
    # Show the full traceback (print(e) discarded it). Then beep until the kernel is
    # interrupted so a failure in this long-running cell gets noticed.
    # winsound is Windows-only.
    traceback.print_exc()
    while True:
        winsound.Beep(1000, 1000)
        time.sleep(1)

In [37]:
# The file name records the architecture hyperparameters and the final EMA loss.
torch.save(
    model.state_dict(),
    f"embedding_{EMBEDDING_SIZE}"
    f"_enc_layers_{ENCODING_LAYERS}_enc_heads_{ENCODING_HEADS}"
    f"_dec_layers_{DECODING_LAYERS}_dec_heads_{DECODING_HEADS}"
    f"_loss_{mean_loss}.pth",
)

In [None]:
# model.load_state_dict(torch.load("embedding_256_enc_layers_10_enc_heads_4_dec_layers_10_dec_heads_4_loss_3.068118011436385.pth"))

In [12]:
# Sample a translation for the first sentence of one (shuffled) training batch.
de_batch, en_batch = next(iter(train_dataloader))
de_batch, en_batch = de_batch[0:1, :], en_batch[0:1, :]
en_bos = en_tokenizer.special_tokens[TOKEN_BOS]
en_eos = en_tokenizer.special_tokens[TOKEN_EOS]
context = torch.tensor([[en_bos]])

model.eval()  # disable training-only behaviour (e.g. dropout) while sampling
with torch.no_grad():  # no autograd graph is needed for generation
    # Stop at <eos>, or once the context reaches MAX_LEN (the max_len the model was built
    # with). A model that never samples <eos> previously looped forever.
    while context.shape[1] < MAX_LEN:
        logits = model(de_batch, context)
        # Only the distribution at the last position is needed for the next token.
        probs = F.softmax(logits[:, -1, :], dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
        if next_token.item() == en_eos:
            break
        context = torch.cat((context, next_token), dim=1)

In [None]:
# Model sample (starts with the <bos> id; the <eos> id is never appended).
print(en_tokenizer.decode(context.squeeze(0).tolist()))
# Reference target from the batch: includes the <bos>/<eos> ids and possibly <pad> ids.
print(en_tokenizer.decode(en_batch.squeeze(0).tolist()))