In [1]:
from datasets import load_dataset
# WMT14 German-English parallel corpus from the Hugging Face Hub (served from the
# local cache after the first download).
wmt14 = load_dataset(path="wmt14", name="de-en")

Found cached dataset wmt14 (/home/kydliceh/.cache/huggingface/datasets/wmt14/de-en/1.0.0/2de185b074515e97618524d69f5e27ee7545dcbed4aa9bc1a4235710ffca33f4)


  0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
# Sentences used only to train the BPE tokenizers. They are kept as lists rather than
# `map` objects: a `map` iterator is exhausted after one pass, so re-running the
# tokenizer cell would silently train on an empty corpus. The slice is also
# materialised once instead of twice.
N_TOKENIZER_SENTENCES = 10_000
tokenizer_pairs = wmt14["train"][:N_TOKENIZER_SENTENCES]["translation"]
de_it = [pair["de"] for pair in tokenizer_pairs]
en_it = [pair["en"] for pair in tokenizer_pairs]

In [3]:
from tokenizers.models import BPE
from tokenizers import Tokenizer
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing

def create_tokenizer(iterable, add_special_tokens=False, vocab_size=52_000):
    """Train a whitespace-pre-tokenized BPE tokenizer on an iterable of sentences.

    Args:
        iterable: iterable of raw strings used as training text (consumed once).
        add_special_tokens: when True, every encoded sequence is wrapped as
            "[START] ... [END]" by a post-processor.
        vocab_size: target vocabulary size for the BPE trainer (special tokens
            included).

    Returns:
        A trained ``tokenizers.Tokenizer`` whose ``encode_batch`` pads each batch
        to its longest sequence with "[PAD]" (id 0).
    """
    # Special tokens are registered with the trainer so they get real vocabulary
    # ids, assigned in list order: "[PAD]" -> 0, "[UNK]" -> 1, then [START]/[END].
    # Keeping [PAD] at id 0 stays compatible with CrossEntropyLoss(ignore_index=0),
    # while no longer conflating padding with unknown tokens.
    special_tokens = ["[PAD]", "[UNK]"]
    if add_special_tokens:
        special_tokens += ["[START]", "[END]"]

    trainer = BpeTrainer(vocab_size=vocab_size, show_progress=True, special_tokens=special_tokens)
    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    tokenizer.train_from_iterator(iterable, trainer=trainer)

    if add_special_tokens:
        # Look the ids up from the trained vocabulary instead of hard-coding them,
        # so the inserted ids can be decoded back to "[START]"/"[END]".
        tokenizer.post_processor = TemplateProcessing(
            single="[START] $A [END]",
            special_tokens=[
                ("[START]", tokenizer.token_to_id("[START]")),
                ("[END]", tokenizer.token_to_id("[END]")),
            ],
        )

    # No fixed length: each encode_batch call is padded to its own longest sequence.
    tokenizer.enable_padding(pad_id=tokenizer.token_to_id("[PAD]"), pad_token="[PAD]")
    return tokenizer
    

In [4]:
# German is the target side (decoded with teacher forcing), so only its tokenizer
# wraps sequences in [START]/[END]; English is the source side.
de_token = create_tokenizer(iterable=de_it, add_special_tokens=True)
en_token = create_tokenizer(iterable=en_it, add_special_tokens=False)










In [5]:
# Keep a random 1% of the training corpus: with test_size=0.99 the "test" split
# gets 99% of rows, and the small "train" split is what we use. A fixed seed makes
# the subset identical across kernel restarts.
SUBSET_SEED = 42
wmt_subset = wmt14["train"].train_test_split(test_size=0.99, seed=SUBSET_SEED)["train"]

In [6]:
import torch
import numpy as np

In [7]:
def extract_embedding(embeds, lang):
    """Split tokenizer encodings into per-language id and attention-mask columns.

    Args:
        embeds: sequence of encodings exposing ``.ids`` and ``.attention_mask``
            (e.g. the result of ``Tokenizer.encode_batch``).
        lang: language code used as the column-name prefix.

    Returns:
        ``{"<lang>_ids": [...], "<lang>_att": [...]}`` with one entry per encoding.
    """
    columns = {}
    for suffix, attribute in (("ids", "ids"), ("att", "attention_mask")):
        columns[f"{lang}_{suffix}"] = [getattr(encoding, attribute) for encoding in embeds]
    return columns


def tokenize(trans):
    """Batched ``Dataset.map`` function encoding both sides of each sentence pair.

    Args:
        trans: batch dict whose "translation" column is a list of
            ``{"en": ..., "de": ...}`` sentence pairs.

    Returns:
        dict with "en_ids", "en_att", "de_ids" and "de_att" columns, one entry
        per pair, produced by the module-level ``en_token``/``de_token``.
    """
    pairs = trans["translation"]
    columns = {}
    for lang, lang_tokenizer in (("en", en_token), ("de", de_token)):
        encodings = lang_tokenizer.encode_batch([pair[lang] for pair in pairs])
        columns.update(extract_embedding(encodings, lang))
    return columns



In [8]:
# Each map batch of 10 rows is padded by the tokenizers to its own longest sentence
# (padding is enabled without a fixed length), so rows share a length only within
# the same map batch.
wmt_tokenized = (
    wmt_subset
    .map(tokenize, batched=True, batch_size=10)
    .remove_columns("translation")
)
wmt_tokenized.set_format("torch")

  0%|          | 0/4509 [00:00<?, ?ba/s]

In [10]:
def collate_fc(batch):
    """Collate tokenized dataset rows into padded tensors for the model.

    Args:
        batch: list of dataset rows, each a dict with 1-D tensors "en_ids",
            "en_att", "de_ids" and "de_att".

    Returns:
        dict with "en_ids"/"de_ids" of shape (batch, seq_len) and
        "en_att"/"de_att" of shape (batch, 1, 1, seq_len). The singleton dims
        presumably let the masks broadcast over (batch, heads, query, key)
        attention scores — TODO confirm against WMTModel.
    """
    # pad_sequence right-pads every row to the longest one with 0, which is the
    # tokenizers' padding id and the "masked" attention value. Rows that already
    # share a length are simply stacked, so the result equals torch.stack whenever
    # stacking would succeed. Unlike torch.stack, it does not require DataLoader
    # batches to coincide with the batches used during tokenization.
    from torch.nn.utils.rnn import pad_sequence

    def pad(key):
        return pad_sequence([row[key] for row in batch], batch_first=True, padding_value=0)

    en_ids = pad("en_ids")
    en_att = pad("en_att").unsqueeze(1).unsqueeze(1)
    de_ids = pad("de_ids")
    de_att = pad("de_att").unsqueeze(1).unsqueeze(1)

    return {"en_ids": en_ids, "de_ids": de_ids, "en_att": en_att, "de_att": de_att}

In [11]:
from torch.utils.data import DataLoader
# batch_size matches the batch_size=10 used when tokenizing, and the default
# (no shuffling) keeps row order, so each loader batch is one tokenization batch
# whose rows share a padded length.
dataloader = DataLoader(
    dataset=wmt_tokenized,
    batch_size=10,
    collate_fn=collate_fc,
)

In [12]:
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

In [13]:
import time
def create_report(
    writer: SummaryWriter,
    loss: float,
    batch_i: int,
    total_batches: int,
    start_time: float,
    epoch: int,
):
    """Log a training loss and the elapsed time to TensorBoard and stdout.

    The TensorBoard step is the global batch index
    ``epoch * total_batches + batch_i``.

    Args:
        writer: TensorBoard writer receiving "Loss/train" and "Time/train" scalars.
        loss: loss value to report.
        batch_i: batch index within the current epoch.
        total_batches: number of batches per epoch (offsets the step by epoch).
        start_time: ``time.time()`` timestamp the elapsed time is measured from.
        epoch: current epoch number.
    """
    elapsed = time.time() - start_time
    global_step = epoch * total_batches + batch_i
    for tag, value in (("Loss/train", loss), ("Time/train", elapsed)):
        writer.add_scalar(tag, value, global_step)
    summary = f"{epoch}: {batch_i}/{total_batches} Loss: {loss}, Time: {elapsed}"
    print("Progress/train", summary)

In [21]:
def train(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: object,
    criterion: nn.Module,
    train_data: DataLoader,
    writer: SummaryWriter,
    epoch: int,
    minibatch=False,
    report_interval: int = 100,
):
    """Run one teacher-forced training epoch and return its mean batch loss.

    Each batch is a dict with "en_ids"/"en_att" (source) and "de_ids"/"de_att"
    (target), where the attention tensors have shape (batch, 1, 1, seq_len) as
    produced by ``collate_fc``.

    Args:
        model: called as ``model(source_ids, target_ids, source_att, target_att)``.
        optimizer: optimizer stepped once per batch.
        scheduler: LR scheduler stepped once per batch (after the optimizer).
        criterion: loss taking (logits with class dim second, target ids).
        train_data: sized iterable of batches (e.g. a DataLoader).
        writer: TensorBoard writer forwarded to ``create_report``.
        epoch: epoch number, used to build global logging steps.
        minibatch: when True, train on only the first batch (smoke test).
        report_interval: report the mean loss every this many batches.

    Returns:
        Mean of the per-batch losses over the epoch (``nan`` if no batch ran).
    """
    model.train()
    if minibatch:
        train_data = [next(iter(train_data))]

    # Batches per epoch. create_report builds the global step as
    # batch_i + epoch * total_batches, so this must be the full length;
    # len - 1 made the steps of consecutive epochs collide.
    total_batches = len(train_data)

    epoch_loss, epoch_batches = 0.0, 0
    window_loss, window_batches = 0.0, 0
    start_time = time.time()
    for batch_i, batch in enumerate(train_data):
        # Teacher forcing: the decoder input is the target without its last token...
        target_ids = batch["de_ids"][:, :-1]
        target_att = batch["de_att"][:, :, :, :-1]
        source_ids = batch["en_ids"]
        source_att = batch["en_att"]

        output = model(source_ids, target_ids, source_att, target_att)

        # ...and the prediction target is the sequence shifted left by one.
        target_correct = batch["de_ids"][:, 1:]
        # CrossEntropyLoss wants classes on dim 1; assumes the model output is
        # (batch, seq, vocab) — TODO confirm against WMTModel.
        loss = criterion(output.transpose(1, 2), target_correct)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_value = loss.item()
        epoch_loss += loss_value
        epoch_batches += 1
        window_loss += loss_value
        window_batches += 1
        if batch_i % report_interval == 0 and batch_i > 0:
            # Average over the batches actually accumulated: the first window also
            # contains batch 0, so it holds report_interval + 1 batches.
            create_report(
                writer, window_loss / window_batches, batch_i, total_batches, start_time, epoch
            )
            window_loss, window_batches = 0.0, 0
            start_time = time.time()

    return epoch_loss / epoch_batches if epoch_batches else float("nan")

In [14]:
%load_ext autoreload
%autoreload 2

In [23]:
from model import WMTModel
# NOTE(review): assumes WMTModel(src_vocab, tgt_vocab, d_model); 52003 presumably
# covers the 52_000-token BPE vocab plus special-token ids — TODO confirm.
d_model = 512
model = WMTModel(52003, 52003, d_model)

# Schedule from "Attention Is All You Need":
#   lr = d_model^-0.5 * min(step_num^-0.5, step_num * warmup_steps^-1.5)
# LambdaLR multiplies the optimizer's base lr by multiplier_lambda(step), so the
# base lr passed to Adam must be d_model^-0.5; Adam's default lr (1e-3) would
# silently rescale the whole schedule.
initial_lr = d_model ** -0.5
warmup_steps = 4000


def multiplier_lambda(step):
    # LambdaLR counts from 0 while the paper's step_num starts at 1.
    step_num = step + 1
    return min(step_num ** -0.5, step_num * warmup_steps ** -1.5)


optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, multiplier_lambda)
# Id 0 is the tokenizers' padding id, so padded target positions carry no loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
writer = SummaryWriter()
train(model, optimizer, scheduler, criterion, dataloader, writer, 0, minibatch=True)
