In [None]:
# Get source files
# Clones the project repository, deletes the clone's own main.ipynb, moves the
# remaining top-level contents of the clone into the current working directory
# and finally removes the clone directory.
# NOTE(review): the steps are chained with `&&`, so a failure in any of them
# (e.g. `mv` hitting an already existing target on a re-run) skips the rest.
!git clone https://github.com/n1teshy/sequence-transduction && rm sequence-transduction/main.ipynb && mv sequence-transduction/* . && rm -rf sequence-transduction
# Get data and tokenizers
# Clones the cache repository, moves its `de_en/data` and `de_en/tokenizers`
# directories into the current working directory, then removes the clone.
!git clone https://github.com/n1teshy/cache && mv cache/de_en/data cache/de_en/tokenizers . && rm -rf cache

In [1]:
import torch
import torch.nn.functional as F

from torch.utils.data import DataLoader
from core.tokenizers.regex import get_tokenizer
from core.datasets.text import TranslationDataset
from core.models import Transformer
from core.utils import get_param_count, kaiming_init
from core.config import device

In [2]:
# Prepare data
BATCH_SIZE = 64
train_de, train_en = "data/de_train.txt", "data/en_train.txt"
val_de, val_en = "data/de_val.txt", "data/en_val.txt"

# One tokenizer per language; both are shared by the train and validation sets.
de_tokenizer = get_tokenizer("de.txt", 512, "tokenizers/de", True)
en_tokenizer = get_tokenizer("en.txt", 512, "tokenizers/en", True)


def _build_split(src_path, tgt_path):
    """Create the dataset for one split plus a shuffled, batched loader over it.

    The loader uses the dataset's own `collate` method to assemble batches.
    Returns the `(dataset, dataloader)` pair.
    """
    dataset = TranslationDataset(src_path, tgt_path, de_tokenizer, en_tokenizer)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=dataset.collate,
    )
    return dataset, loader


train_dataset, train_dataloader = _build_split(train_de, train_en)
# The validation loader is shuffled as well, so taking its first few batches
# yields a different sample of the validation set on every pass.
val_dataset, val_dataloader = _build_split(val_de, val_en)


def get_val_loss(model, batches=1, dataloader=None):
    """Return the mean cross-entropy of `model` over a few validation batches.

    Each batch is scored with teacher forcing: the model receives the target
    sequence without its last token and is scored against the target sequence
    without its first token.

    Args:
        model: the network to evaluate; called as `model(src_batch, tgt_input)`
            and expected to return logits of shape (B, T, C).
        batches: number of batches to evaluate before stopping. The loop stops
            when the 1-based batch index equals this value, so a value that is
            never reached (e.g. 0) evaluates the whole loader.
        dataloader: iterable of `(src_batch, tgt_batch)` pairs. Defaults to the
            notebook-level `val_dataloader`.

    Returns:
        The arithmetic mean of the per-batch losses as a Python float.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    if dataloader is None:
        dataloader = val_dataloader

    losses = []
    # Remember the caller's mode so evaluation does not silently flip an
    # eval-mode model into training mode.
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for evaluation; skipping the autograd graph
        # saves memory while the training graph may still be alive.
        with torch.no_grad():
            for step, (de_batch, en_batch) in enumerate(dataloader, start=1):
                logits = model(de_batch, en_batch[:, :-1])
                B, T, C = logits.shape
                targets = en_batch[:, 1:].reshape(-1)
                loss = F.cross_entropy(logits.reshape((B * T, C)), targets)
                losses.append(loss.item())
                if step == batches:
                    break
    finally:
        # Restore the previous mode even if the forward pass raised.
        model.train(was_training)

    if not losses:
        raise ValueError("validation dataloader produced no batches")
    return sum(losses) / len(losses)

In [None]:
# Sizes and padding ids come from the tokenizers / training dataset built above.
DE_VOCAB_SIZE, EN_VOCAB_SIZE = de_tokenizer.size, en_tokenizer.size
DE_PAD_ID, EN_PAD_ID = train_dataset.src_pad_id, train_dataset.tgt_pad_id

# Architecture hyper-parameters.
EMBEDDING_SIZE = 176
MAX_LEN = 500
ENCODING_LAYERS = ENCODING_HEADS = 4
DECODING_LAYERS = DECODING_HEADS = 4

# Gather every constructor argument in one place before building the model.
model_kwargs = {
    "in_vocab_size": DE_VOCAB_SIZE,
    "out_vocab_size": EN_VOCAB_SIZE,
    "embedding_size": EMBEDDING_SIZE,
    "max_len": MAX_LEN,
    "enc_layers": ENCODING_LAYERS,
    "dec_layers": DECODING_LAYERS,
    "enc_heads": ENCODING_HEADS,
    "dec_heads": DECODING_HEADS,
    "src_pad_id": DE_PAD_ID,
    "tgt_pad_id": EN_PAD_ID,
}
model = Transformer.spawn(**model_kwargs)

# Report the model size in millions of parameters.
param_count_mn = get_param_count(model) / 1e6
print(f"{param_count_mn} mn parameters")

In [None]:
# Run the project's `kaiming_init` helper on the freshly built model. Its return
# value is discarded, so the helper is relied on to modify `model` in place.
# NOTE(review): presumably applies Kaiming/He initialisation to the weights —
# confirm against core.utils.
kaiming_init(model)

In [16]:
# Step size for AdamW, kept as a named constant so it is easy to find and tune.
LEARNING_RATE = 1e-4
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)

In [17]:
# Running (exponentially smoothed) losses; None until the first measurement.
mean_train_loss, mean_val_loss = None, None

# Training schedule: total passes over the data and how often (in batches)
# the training loop reports progress.
epochs, print_interval = 100, 20

# Per-step gradient statistics: the training loop appends one list per step.
grads = list()

In [None]:
# Training loop: one optimisation step per batch, with a progress line
# (current and smoothed train/validation loss) every `print_interval` batches
# and per-module gradient statistics recorded after every backward pass.
for epoch in range(epochs):
    for counter, (de_batch, en_batch) in enumerate(train_dataloader):
        # Teacher forcing: feed the target without its last token and score
        # the prediction against the target without its first token.
        logits = model(de_batch, en_batch[:, :-1])
        B, T, C = logits.shape
        targets = en_batch[:, 1:].reshape(-1)
        # NOTE(review): the loss averages over every target position. If the
        # collate function pads targets with EN_PAD_ID, consider passing
        # `ignore_index=EN_PAD_ID` here and in get_val_loss — confirm first.
        train_loss = F.cross_entropy(logits.reshape((B * T, C)), targets)

        # Read the scalar once; every `.item()` call forces a device sync.
        loss_value = train_loss.item()

        # Exponential moving average, seeded with the first observed loss.
        if mean_train_loss is None:
            mean_train_loss = loss_value
        mean_train_loss = mean_train_loss * 0.995 + loss_value * 0.005

        if counter % print_interval == 0:
            val_loss = get_val_loss(model)
            if mean_val_loss is None:
                mean_val_loss = val_loss
            mean_val_loss = mean_val_loss * 0.99 + val_loss * 0.01
            # epoch:batch -> train(smoothed train), val(smoothed val)
            print(
                "%d:%d -> %.4f(%.4f), %.4f(%.4f)"
                % (
                    epoch + 1,
                    counter + 1,
                    loss_value,
                    mean_train_loss,
                    val_loss,
                    mean_val_loss,
                )
            )

        optimizer.zero_grad()
        train_loss.backward()
        # Mean absolute weight gradient of every module that owns a `weight`.
        # Modules whose `weight` is None are skipped, and a weight that got no
        # gradient this step is logged as 0.0, so logging cannot crash the run
        # and every row keeps the same length.
        grads.append(
            [
                0.0
                if mod.weight.grad is None
                else mod.weight.grad.abs().mean().item()
                for mod in model.modules()
                if getattr(mod, "weight", None) is not None
            ]
        )
        optimizer.step()

In [None]:
# Peer into gradients
import matplotlib.pyplot as plt

# `grads` holds one list per training step, each with one mean-|gradient|
# value per weight-bearing module, so rows are steps and columns are modules.
gradients = torch.tensor(grads)

# A filled contour plot needs at least a 2x2 grid; bail out with a message
# instead of failing on an empty or single-row record.
if gradients.ndim != 2 or min(gradients.shape) < 2:
    print(
        "Need gradient statistics from at least 2 training steps and 2 modules "
        "to draw the contour plot; run the training cell first."
    )
else:
    layers = torch.arange(gradients.shape[1])
    batches = torch.arange(gradients.shape[0])

    fig, ax = plt.subplots()
    contour = ax.contourf(
        layers.numpy(), batches.numpy(), gradients.numpy(), cmap="viridis"
    )
    fig.colorbar(contour, ax=ax, label="mean |weight gradient|")
    ax.set(
        xlabel="module index (order of model.modules())",
        ylabel="training step",
        title="Gradient magnitude per module over training",
    )
    plt.show()

In [None]:
# Save model parameters
# The file name records the architecture and the smoothed losses at save time.
checkpoint_fields = (
    EMBEDDING_SIZE,
    ENCODING_LAYERS,
    ENCODING_HEADS,
    DECODING_LAYERS,
    DECODING_HEADS,
    mean_train_loss,
    mean_val_loss,
)
checkpoint_name = (
    "emb_%d_enc_lays_%d_enc_heads_%d_dec_lays_%d_dec_heads_%d_train_loss_%.4f_val_loss_%.4f.pth"
    % checkpoint_fields
)
torch.save(model.state_dict(), checkpoint_name)

In [None]:
# Inference
def predict(src_tokens, bos_id, eos_id, max_len=None):
    """Sample a target token sequence for one source sequence.

    Decoding is autoregressive: starting from `bos_id`, the notebook-level
    `model` is run on the source and the tokens generated so far, the next
    token is sampled from the softmax over the last position's logits, and it
    is appended until `eos_id` is drawn or the length limit is reached.

    Args:
        src_tokens: list of source token ids (a single sequence).
        bos_id: id that starts the generated sequence.
        eos_id: id that ends generation; it is not included in the result.
        max_len: maximum length of the returned sequence, counting `bos_id`.
            Defaults to the notebook-level `MAX_LEN` so generation always
            terminates even if `eos_id` is never sampled.

    Returns:
        The generated token ids as a list, starting with `bos_id`.
    """
    if max_len is None:
        max_len = MAX_LEN

    src = torch.tensor([src_tokens], device=device)
    context = torch.tensor([[bos_id]], device=device)

    # Inference should not run in training mode (e.g. with dropout active) and
    # needs no autograd graph; the previous mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            while context.shape[1] < max_len:
                logits = model(src, context)
                # Only the last position predicts the next token.
                probs = F.softmax(logits[:, -1, :], dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)
                if next_token.item() == eos_id:
                    break
                context = torch.cat((context, next_token), dim=1)
    finally:
        model.train(was_training)
    return context[0].tolist()