In [None]:
!pip install torchmetrics datasets tokenizers

In [None]:
from datasets import load_dataset
# Load the German-English ('de-en') configuration of the WMT14 dataset.
wmt14 = load_dataset(path='wmt14', name='de-en')

In [None]:
!git clone https://github.com/hynky1999/Statistical-learning-class
%cd /content/Statistical-learning-class/Assigments/Project

In [None]:
# Experiment configuration.
train_subset_length = 100_000  # number of leading rows kept from the train split
test_subset_length = 2_000     # number of leading rows kept from the test split
vocab_size = 40_000            # vocabulary size requested from the BPE trainer

In [None]:
# Keep only the first `*_subset_length` rows of each split.
train_dataset, test_dataset = (
    wmt14[split].select(range(size))
    for split, size in (("train", train_subset_length), ("test", test_subset_length))
)

In [None]:
from tokenizers.models import BPE
from tokenizers import Tokenizer
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing

def create_tokenizer(iterable, add_special_tokens=False):
    """Train a BPE tokenizer on ``iterable`` and configure it for batch encoding.

    The vocabulary size is read from the module-level ``vocab_size``.

    Args:
        iterable: Iterable of raw sentences (strings) to learn the vocabulary from.
        add_special_tokens: When True, register ``[START]`` and ``[END]`` and wrap
            every encoded sentence as ``[START] <tokens> [END]``.

    Returns:
        The trained ``Tokenizer`` with ``[PAD]`` padding enabled.
    """
    trainer = BpeTrainer(vocab_size=vocab_size, show_progress=True, special_tokens=["[PAD]", "[UNK]"])
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    # Split the text into words before BPE runs. Without a pre-tokenizer every
    # sentence is treated as one "word" and merges are learned across word
    # boundaries, producing tokens that span several words.
    tokenizer.pre_tokenizer = Whitespace()
    tokenizer.train_from_iterator(iterable, trainer=trainer)
    if add_special_tokens:
        # Added after training, so these two ids come on top of the trained vocabulary.
        tokenizer.add_special_tokens(["[START]", "[END]"])
        START_ID, END_ID = tokenizer.token_to_id("[START]"), tokenizer.token_to_id("[END]")
        tokenizer.post_processor = TemplateProcessing(
            single="[START] $A [END]",
            special_tokens=[("[START]", START_ID), ("[END]", END_ID)],
        )

    # Pad with the [PAD] token whenever sentences are encoded together.
    tokenizer.enable_padding(pad_token="[PAD]", pad_id=tokenizer.token_to_id("[PAD]"))
    return tokenizer
    

In [None]:

# Materialise the sentences as lists instead of one-shot `map` iterators: an
# iterator is exhausted after its first use, so re-running the tokenizer-training
# cell would otherwise silently train on no data.
de_it = [pair['de'] for pair in train_dataset['translation']]
en_it = [pair['en'] for pair in train_dataset['translation']]

In [None]:
# One BPE tokenizer per language; only the German one wraps sentences in
# [START]/[END] markers.
de_token = create_tokenizer(iterable=de_it, add_special_tokens=True)
en_token = create_tokenizer(iterable=en_it, add_special_tokens=False)


In [None]:
import torch
import numpy as np

In [None]:
from torch.utils.data.dataloader import RandomSampler
from torch.utils.data import DataLoader
def extract_embedding(embeds, lang):
    """Gather token ids and attention masks of encoded sentences into two columns.

    Args:
        embeds: Encodings exposing ``ids`` and ``attention_mask`` attributes.
        lang: Prefix used for the keys of the returned dict.

    Returns:
        Dict with the keys ``"<lang>_ids"`` and ``"<lang>_att"``, each holding one
        entry per encoding.
    """
    token_ids = list(map(lambda encoding: encoding.ids, embeds))
    masks = list(map(lambda encoding: encoding.attention_mask, embeds))
    return {"{}_ids".format(lang): token_ids, "{}_att".format(lang): masks}

def add_lenghts(trans):
    """Expose the raw sentences of a batch together with their character lengths.

    Args:
        trans: Batch dict whose ``"translation"`` entry is a list of
            ``{"de": str, "en": str}`` pairs.

    Returns:
        Dict with the columns ``en_len``, ``de_len`` (``len()`` of each sentence),
        ``de_sent`` and ``en_sent`` (the sentences themselves).
    """
    columns = {"en_len": [], "de_len": [], "de_sent": [], "en_sent": []}
    for pair in trans["translation"]:
        german, english = pair["de"], pair["en"]
        columns["en_len"].append(len(english))
        columns["de_len"].append(len(german))
        columns["de_sent"].append(german)
        columns["en_sent"].append(english)
    return columns



def tokenize(trans):
    """Encode a batch of translation pairs with the per-language tokenizers.

    Args:
        trans: Batch dict whose ``"translation"`` entry is a list of
            ``{"de": str, "en": str}`` pairs.

    Returns:
        Dict with the columns ``en_ids``, ``en_att``, ``de_ids`` and ``de_att``
        produced by ``en_token`` / ``de_token`` and ``extract_embedding``.
    """
    german, english = [], []
    for pair in trans["translation"]:
        german.append(pair["de"])
        english.append(pair["en"])

    # Each language is encoded as one batch by its own tokenizer.
    columns = extract_embedding(en_token.encode_batch(english), "en")
    columns.update(extract_embedding(de_token.encode_batch(german), "de"))
    return columns

def collate_fc(batch):
    """Merge individual samples into one batch, padding sequences to a common length.

    ``torch.stack`` requires every sample to have the same length, which does not
    hold when a batch mixes samples that were padded to different lengths. The id
    and mask tensors are therefore right-padded with 0; when all samples already
    share one length the result is identical to stacking them.

    Args:
        batch: List of samples, each a dict with 1-D tensors ``en_ids``, ``en_att``,
            ``de_ids``, ``de_att`` and the values ``de_sent`` and ``en_sent``.

    Returns:
        Dict with ``en_ids`` / ``de_ids`` of shape (batch, length), ``en_att`` /
        ``de_att`` of shape (batch, 1, 1, length), and ``de_sent`` / ``en_sent`` as
        lists with one entry per sample.
    """
    from torch.nn.utils.rnn import pad_sequence

    def _padded(key):
        # 0 serves as padding for both kinds of tensor: an attention mask of 0
        # marks a padded position, and the loss in this notebook ignores target
        # index 0 (presumably the [PAD] id — verify against the tokenizers).
        return pad_sequence([sample[key] for sample in batch], batch_first=True, padding_value=0)

    en_ids = _padded("en_ids")
    # Two singleton axes are inserted after the batch axis of each mask.
    en_att = _padded("en_att").unsqueeze(1).unsqueeze(1)
    de_ids = _padded("de_ids")
    de_att = _padded("de_att").unsqueeze(1).unsqueeze(1)
    de_sent = [sample["de_sent"] for sample in batch]
    en_sent = [sample["en_sent"] for sample in batch]

    return {"en_ids": en_ids, "de_ids": de_ids, "en_att": en_att, "de_att": de_att, "de_sent": de_sent, "en_sent": en_sent}

def create_dataloader(dataset, batch_size=32, shuffle=False):
    """Build a DataLoader over tokenized, length-sorted translation pairs.

    The rows are sorted by sentence length and tokenized in chunks of
    ``batch_size``; the loader then serves exactly those chunks, so the rows of
    one batch come from the same ``tokenize`` call.

    Args:
        dataset: Dataset with a ``"translation"`` column of ``{"de", "en"}`` pairs.
        batch_size: Number of rows per tokenization chunk and per batch.
        shuffle: When True the *order of the batches* is re-drawn at random on
            every pass while the content of each batch stays fixed; when False
            the batches are served in sorted order.

    Returns:
        A ``DataLoader`` that yields the dicts produced by ``collate_fc``.
    """
    tokenized = dataset.map(add_lenghts, batch_size=batch_size, batched=True)
    # Sort by lengths to get smaller paddings: German length first, English length
    # as tie-breaker. One multi-column sort replaces the former pair of sorts, whose
    # `kind="stable"` argument is deprecated in `datasets`.
    tokenized = tokenized.sort(["de_len", "en_len"])
    tokenized = tokenized.map(tokenize, batch_size=batch_size, batched=True)
    tokenized = tokenized.remove_columns("translation")
    tokenized.set_format("torch")
    if not shuffle:
        return DataLoader(tokenized, batch_size=batch_size, shuffle=False, collate_fn=collate_fc)
    # Sampling single rows at random would undo the length sort and mix rows from
    # different tokenization chunks, so whole batches are shuffled instead.
    return DataLoader(tokenized, batch_sampler=_ShuffledBuckets(len(tokenized), batch_size), collate_fn=collate_fc)


class _ShuffledBuckets:
    """Batch sampler yielding consecutive index chunks in a fresh random order per pass."""

    def __init__(self, num_rows, batch_size):
        self.num_rows = num_rows
        self.batch_size = batch_size

    def __len__(self):
        # Ceiling division: the last chunk may hold fewer than `batch_size` rows.
        return (self.num_rows + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        for bucket in torch.randperm(len(self)).tolist():
            start = bucket * self.batch_size
            yield list(range(start, min(start + self.batch_size, self.num_rows)))


In [None]:
# Request shuffling for the training loader only; the test loader does not ask for it.
dataloader_train = create_dataloader(train_dataset, batch_size=32, shuffle=True)
dataloader_test = create_dataloader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Enable IPython's autoreload extension in mode 2, which reloads imported modules
# before executing code, so edits to local .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from train_test import train, evaluate
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

In [None]:
from model import WMTModel
d_model = 128
model = WMTModel(en_token.get_vocab_size(), de_token.get_vocab_size(), d_model)
# NOTE(review): `model` is never moved to `device` in this notebook — confirm that
# `train` takes care of device placement.

# Learning-rate schedule:
#   lr(step_num) = d_model^-0.5 * min(step_num^-0.5, step_num * warmup_steps^-1.5)
# The optimizer holds the constant factor d_model^-0.5 as its base learning rate and
# the scheduler multiplies it by the step-dependent factor.
initial_lr = d_model ** -0.5
warmup_steps = 4000


def multiplier_lambda(step):
    """Return the step-dependent learning-rate factor for the scheduler step ``step``."""
    step_num = step + 1  # shift by one so that step 0 does not evaluate 0 ** -0.5
    return min(step_num ** -0.5, step_num * warmup_steps ** -1.5)


optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=multiplier_lambda)
# Targets equal to 0 do not contribute to the loss (presumably the [PAD] id, the
# first special token handed to the BPE trainer — verify).
criterion = nn.CrossEntropyLoss(ignore_index=0)
writer = SummaryWriter()
epochs = 10
for epoch in range(epochs):
    train(
        model,
        optimizer,
        scheduler,
        criterion,
        dataloader_train,
        writer,
        epoch,
        minibatch=False,
    )


In [None]:
# Run `evaluate` (imported from train_test) on the test loader, passing the same
# SummaryWriter used during training and the German tokenizer
# (presumably needed to decode predictions — confirm in train_test.evaluate).
evaluate(model, dataloader_test, writer, de_token)