In [10]:
import unittest
from typing import List, Dict, Any
import random
import numpy as np
from random import choices

import torch
from torch import nn

from lr_scheduler import NoamOpt
from transformer import Transformer
from vocabulary import Vocabulary
from transformer_utils import construct_batches
from train import train

In [11]:
# Device on which the batches and the model are placed.
device = torch.device("cpu")

# Seed every random number generator the notebook touches so that
# "Restart Kernel -> Run All" reproduces the same synthetic corpus and the same
# training curve. `random.seed` also governs `choices`, which is imported
# directly from the `random` module and shares its global generator.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # torch-side randomness (presumably weight init / dropout in Transformer)

synthetic_corpus_size = 5  # number of randomly sampled sentences appended to the corpus
batch_size = 2  # forwarded to construct_batches as batch_size
n_epochs = 40  # forwarded to train as n_epochs
n_tokens_in_batch = 10  # number of tokens drawn for each synthetic sentence

# Build the vocabulary from a single seed sentence, then extend the corpus with
# synthetic sentences whose tokens are sampled uniformly at random from it.
# Note: the original paper uses byte pair encodings; here each word is one token.
corpus = ["These are the tokens that will end up in our vocabulary"]
vocab = Vocabulary(corpus)
vocab_size = len(list(vocab.token2index.keys()))  # 14 tokens including bos, eos and pad
# Drop the first three vocabulary entries before sampling
# (assumes these are the bos/eos/pad special tokens — TODO confirm in Vocabulary).
valid_tokens = list(vocab.token2index.keys())[3:]
corpus.extend(
    " ".join(choices(valid_tokens, k=n_tokens_in_batch))
    for _ in range(synthetic_corpus_size)
)
print(f"corpus {len(corpus)}")

# Use every sentence as both source and target, then group the pairs into batches.
# Note: the original paper uses dynamic batching based on tokens.
corpus = [{"src": text, "tgt": text} for text in corpus]
batches, masks = construct_batches(
    corpus, vocab, batch_size=batch_size, src_lang_key="src", tgt_lang_key="tgt", device=device
)

print(f"Number of batches {len(batches['src'])}")

# Build the model: vocabulary-dependent arguments first, then the architecture
# hyper-parameters; afterwards move it to the configured device.
transformer = Transformer(
    vocab_size=vocab_size,
    padding_idx=vocab.token2index[vocab.PAD],
    bos_idx=vocab.token2index[vocab.BOS],
    hidden_dim=512,
    ff_dim=2048,
    num_heads=8,
    num_layers=2,
    max_decoding_length=25,
    dropout_p=0.1,
    tie_output_to_embedding=True,
)
transformer = transformer.to(device)

# Loss function (note: the original paper uses label smoothing).
criterion = nn.CrossEntropyLoss()

# Adam starts with lr=0; the NoamOpt scheduler wraps the optimizer
# (presumably it sets the learning rate each step — see lr_scheduler.NoamOpt).
optimizer = torch.optim.Adam(
    transformer.parameters(),
    lr=0,
    betas=(0.9, 0.98),
    eps=1e-9,
)
scheduler = NoamOpt(transformer.hidden_dim, factor=1, warmup=400, optimizer=optimizer)

# Run training and report the metrics of the last batch. The goal is ~zero loss
# and >90% accuracy on that batch; the values are only printed, not asserted.
latest_batch_loss, latest_batch_accuracy = train(
    transformer,
    scheduler,
    criterion,
    batches,
    masks,
    n_epochs=n_epochs,
)

print(
    f"batch loss {latest_batch_loss.item()}",
    f"batch accuracy {latest_batch_accuracy}",
    sep="\n",
)

corpus 6
Number of batches 3
epoch: 0, num_iters: 0, batch_loss: 13.952544212341309, batch_accuracy: 0.0
epoch: 1, num_iters: 3, batch_loss: 12.951950073242188, batch_accuracy: 0.0
epoch: 2, num_iters: 6, batch_loss: 10.813690185546875, batch_accuracy: 0.0
epoch: 3, num_iters: 9, batch_loss: 7.460815906524658, batch_accuracy: 0.0
epoch: 4, num_iters: 12, batch_loss: 4.79967737197876, batch_accuracy: 0.0416666679084301
epoch: 5, num_iters: 15, batch_loss: 3.026728868484497, batch_accuracy: 0.2083333283662796
epoch: 6, num_iters: 18, batch_loss: 2.1718480587005615, batch_accuracy: 0.125
epoch: 7, num_iters: 21, batch_loss: 1.7464728355407715, batch_accuracy: 0.4583333432674408
epoch: 8, num_iters: 24, batch_loss: 1.4861321449279785, batch_accuracy: 0.5
epoch: 9, num_iters: 27, batch_loss: 1.1562992334365845, batch_accuracy: 0.5416666865348816
epoch: 10, num_iters: 30, batch_loss: 1.1372474431991577, batch_accuracy: 0.5416666865348816
epoch: 11, num_iters: 33, batch_loss: 1.06396472454071