In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import copy

sys.path.insert(0, "..")

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from model import UniversalTransformer, PositionalTimestepEncoding

%matplotlib inline

In [3]:
# Pick CUDA when a GPU is visible, otherwise run on the CPU.
# On an M1 Mac, enable the MPS line below instead of the if/else.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
# DEVICE = torch.device("mps")

DEVICE

device(type='cpu')

## Plot positional encoding

In [None]:
# Toy dimensions shared by the demo cells below.
BATCH_SIZE, D_MODEL, MAX_LEN = 5, 64, 50
DROPOUT = 0.1

pos_enc = PositionalTimestepEncoding(D_MODEL, DROPOUT, MAX_LEN)

# Random input of shape (BATCH_SIZE, MAX_LEN, D_MODEL); show the shape of the
# encoding buffer next to the shape of the encoded output at time step 3.
x = torch.randn(BATCH_SIZE, MAX_LEN, D_MODEL)
pos_enc.pe.shape, pos_enc(x, time_step=3).shape

In [None]:
# Render the encoding buffer as a heat-map; squeeze() drops singleton
# dimensions and detach() removes it from the autograd graph before plotting.
print(pos_enc.pe.shape)
pe_image = pos_enc.pe.squeeze().detach()
plt.imshow(pe_image)
plt.xlabel(r"$d_{model}$")
plt.ylabel("Sequence length")
plt.show()

## Target mask

In [None]:
# Random token ids of shape (BATCH_SIZE, MAX_LEN), used to build and
# visualise the subsequent-position mask.
tgt = torch.randint(low=0, high=50_000, size=(BATCH_SIZE, MAX_LEN))
tgt_mask = UniversalTransformer.generate_subsequent_mask(tgt)

print(tgt.shape, tgt_mask.shape)
plt.imshow(tgt_mask)
plt.show()

## Sanity check: overfitting on a batch

In [None]:
# Hyper-parameters for the single-batch overfitting check; every entry is
# forwarded to UniversalTransformer as a keyword argument.
config = {
    "source_vocab_size": 100,
    "target_vocab_size": 100,
    "d_model": 32,
    "n_head": 8,
    "d_feedforward": 64,
    "max_len": MAX_LEN,
    "max_time_step": 20,
    "halting_thresh": 0.9,
}

In [None]:
# Build the model from the config above, then move it to the selected device.
model = UniversalTransformer(**config)
model = model.to(DEVICE)

In [None]:
# Count the elements of every parameter that requires gradients.
n_params = 0
for param in model.parameters():
    if param.requires_grad:
        n_params += param.numel()
print(f"Number of trainable parameters: {n_params}")

In [None]:
# Random toy batch: the source has MAX_LEN tokens, the target half as many.
src = torch.randint(low=0, high=100, size=(BATCH_SIZE, MAX_LEN)).to(DEVICE)
tgt = torch.randint(low=0, high=100, size=(BATCH_SIZE, MAX_LEN // 2)).to(DEVICE)
# Forward pass; no explicit target mask is passed here.
out = model(src, tgt)

In [None]:
# Label-smoothed cross-entropy and Adam for the overfitting run.
loss = nn.CrossEntropyLoss(label_smoothing=0.1).to(DEVICE)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Keep a copy of the untrained weights; the deviation histogram further down
# compares them with the trained ones.
init_model = copy.deepcopy(model)
out = model(src, tgt)
max_val, max_id = out.max(dim=-1)
# True only when the arg-max prediction equals the target at every position.
(max_id == tgt).all().item()

In [None]:
# wandb stuff
import wandb

wandb.init(project="universal_transformer_overfit_test", config=config)
wandb.watch(model, log_freq=100)

# Fit the same (src, tgt) batch over and over and log the loss of each step.
N_OVERFIT_STEPS = 1000
for step in range(N_OVERFIT_STEPS):
    optimizer.zero_grad()
    out = model(src, tgt)
    # Flatten to (positions, vocab) and (positions,) as CrossEntropyLoss expects.
    flat_logits = out.view(-1, model.target_vocab_size)
    flat_targets = tgt.view(-1)
    loss_val = loss(flat_logits, flat_targets)
    loss_val.backward()
    optimizer.step()
    wandb.log({"loss": loss_val.item()})

In [None]:
# After training: does the arg-max prediction now reproduce the target
# at every position of the batch?
out = model(src, tgt)
max_val, max_id = out.max(dim=-1)
(max_id == tgt).all().item()

In [None]:
# Norm of the change of every parameter tensor between the snapshot taken
# before training (init_model) and the trained model.
deviations = [
    torch.norm(before_p.detach() - after_p.detach()).item()
    for before_p, after_p in zip(init_model.parameters(), model.parameters())
]

plt.hist(deviations)
plt.xlabel("Norm of parameter change")
plt.ylabel("Number of parameter tensors")
# plt.show() renders the figure; the previous plt.plot() call drew nothing
# and only left a stray `[]` as the cell output.
plt.show()

## Training on WMT14

In [4]:
import logging
from itertools import cycle

import wandb
from transformers import AutoTokenizer
from datasets import load_dataset, load_metric

In [5]:
# Maximum sequence length in tokens: `encode` pads/truncates both source and
# target to this length, and it is passed to the model as `max_len`.
MAX_SEQ_LENGTH = 50

### Load tokenizer

In [6]:
# GPT-2 uses BPE
# Load the pretrained "gpt2" tokenizer and reuse its end-of-sequence token as
# the padding token, so pad_token_id and eos_token_id refer to the same id.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

### Dataloaders

In [7]:
def prepare_target(labels, attention_mask, decoder_start_token_id):
    """
    Prepare decoder target by shifting to the right and adding the start token.

    Args:
        labels: tensor of token ids; shifted along its last dimension.
        attention_mask: tensor laid out like ``labels``; shifted the same way.
        decoder_start_token_id: value written to position 0 of the shifted labels.

    Returns:
        Tuple ``(shifted_labels, shifted_attn_mask)`` of new tensors. The
        inputs are left unmodified, the last element along the final dimension
        is dropped, and position 0 of the shifted mask is set to 1.
    """

    def _shift_right(tensor, first_value):
        # Work on a copy so the caller's tensor is never written to.
        shifted = tensor.clone()
        shifted[..., 1:] = tensor[..., :-1]
        shifted[..., 0] = first_value
        return shifted

    return (
        _shift_right(labels, decoder_start_token_id),
        _shift_right(attention_mask, 1),
    )

In [8]:
def encode(examples):
    """
    Tokenize a batch of German/English translation pairs.

    Args:
        examples: batch whose ``"translation"`` entry is a list of dicts with
            ``"de"`` (source) and ``"en"`` (target) texts.

    Returns:
        Dict with

        * ``"input_ids"``: source token ids, padded/truncated to MAX_SEQ_LENGTH;
        * ``"attention_mask"``: logical NOT of the tokenizer's source
          attention mask (True where the tokenizer's mask was 0);
        * ``"labels"``: target token ids shifted right by ``prepare_target``,
          with ``tokenizer.pad_token_id`` as the start token;
        * ``"labels_attention_mask"``: logical NOT of the shifted target
          attention mask.
    """
    pairs = examples["translation"]
    # Source and target are tokenized with exactly the same settings.
    tokenize_kwargs = dict(
        max_length=MAX_SEQ_LENGTH,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    source_enc = tokenizer([pair["de"] for pair in pairs], **tokenize_kwargs)
    target_enc = tokenizer([pair["en"] for pair in pairs], **tokenize_kwargs)

    decoder_ids, decoder_attn = prepare_target(
        target_enc["input_ids"], target_enc["attention_mask"], tokenizer.pad_token_id
    )

    return {
        "input_ids": source_enc["input_ids"],
        "attention_mask": ~source_enc["attention_mask"].bool(),
        "labels": decoder_ids,
        "labels_attention_mask": ~decoder_attn.bool(),
    }

In [9]:
def get_dataloaders(batch_size, map_batch_size: int = 500, validation_size: int = 100):
    """
    Build streaming WMT14 de-en dataloaders for the train/validation/test splits.

    Args:
        batch_size: batch size of the returned ``DataLoader`` objects.
        map_batch_size: number of examples handed to ``encode`` per ``map`` call.
        validation_size: number of examples taken from the validation split
            (defaults to 100, the value that used to be hard-coded).

    Returns:
        Tuple ``(train_dl, validation_dl, test_dl)``.
    """

    def _load_split(split):
        # Streaming avoids downloading the whole dataset. The flag has to be
        # the boolean True; the string "True" only worked by being truthy.
        return load_dataset("wmt14", "de-en", split=split, streaming=True)

    def _get_dataloader_from_ds(ds):
        # The "translation" column is kept on purpose: the validation code in
        # this notebook reads the raw texts from batch["translation"].
        ds = ds.map(encode, batched=True, batch_size=map_batch_size)
        ds = ds.with_format(type="torch")
        return torch.utils.data.DataLoader(ds, batch_size=batch_size)

    # TODO: the train and test splits are still used at full size.
    train_ds = _load_split("train")
    validation_ds = _load_split("validation").take(validation_size)
    test_ds = _load_split("test")

    train_dl = _get_dataloader_from_ds(train_ds)
    validation_dl = _get_dataloader_from_ds(validation_ds)
    test_dl = _get_dataloader_from_ds(test_ds)

    return train_dl, validation_dl, test_dl

## Utility functions

In [10]:
def translate_tokens(input_ids, model, tokenizer, trim=True):
    """
    Translate tokens.

    Args:
        input_ids: tensor of token ids, either 1-D (treated as a batch of one)
            or 2-D ``(batch, sequence)``.
        model: object exposing ``generate(input_ids, eos_token_id=...,
            min_length=..., max_length=...)``.
        tokenizer: only ``tokenizer.eos_token_id`` is read.
        trim: when True, trailing columns that hold nothing but EOS tokens are
            cut off before generation. For a batch, a column is kept as long
            as any row still has a non-EOS token in it or to its right. At
            least one column is always kept.

    Returns:
        The generated token ids, squeezed, detached and moved to the CPU.
        The ids are not decoded to text.
    """

    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)
    input_ids = input_ids.to(DEVICE)

    eos_id = tokenizer.eos_token_id
    if trim and eos_id is not None and input_ids.shape[1] > 0:
        # remove trailing eos tokens (if any)
        has_content = (input_ids != eos_id).any(dim=0)
        content_cols = has_content.nonzero()
        # Previously the cut-off index was computed but never applied, so the
        # padding was passed to the model unchanged.
        n_keep = int(content_cols.max()) + 1 if content_cols.numel() > 0 else 1
        input_ids = input_ids[:, :n_keep]

    with torch.no_grad():
        out = model.generate(
            input_ids,
            eos_token_id=tokenizer.eos_token_id,
            min_length=2,
            max_length=100,
        ).squeeze().detach().cpu()

    return out

In [11]:
def translate_text(source, model, tokenizer):
    """
    Translate a text.

    Args:
        source: text to translate; it is tokenized with truncation to
            ``model.max_len`` tokens.
        model: object exposing ``max_len`` and ``generate(...)``.
        tokenizer: callable tokenizer that also provides ``eos_token_id``
            and ``decode``.

    Returns:
        The generated token ids decoded to a string, special tokens skipped.
    """
    encoded = tokenizer(
        source,
        truncation=True,
        max_length=model.max_len,
        return_tensors="pt",
    )
    source_ids = encoded["input_ids"].to(DEVICE)

    with torch.no_grad():
        generated = model.generate(
            source_ids,
            eos_token_id=tokenizer.eos_token_id,
            min_length=2,
            max_length=100,
        )

    # Drop singleton dimensions and bring the ids to the CPU for decoding.
    generated = generated.squeeze().detach().cpu()
    return tokenizer.decode(generated, skip_special_tokens=True)


## Train

### Prepare dataloader and config

In [12]:
# Run configuration, also sent to W&B. The first eight entries are the model
# hyper-parameters; the last three configure batching, loss and optimizer.
config = {
    "source_vocab_size": tokenizer.vocab_size,
    "target_vocab_size": tokenizer.vocab_size,
    "d_model": 32,
    "n_head": 8,
    "d_feedforward": 64,
    "max_len": MAX_SEQ_LENGTH,
    "max_time_step": 10,
    "halting_thresh": 0.8,
    "batch_size": 4,
    "label_smoothing": 0.1,
    "learning_rate": 2e-3,
}

In [13]:
# Build the three dataloaders; `encode` is mapped over 20 examples at a time.
train_dataloader, validation_dataloader, test_dataloader = get_dataloaders(
    batch_size=config["batch_size"],
    map_batch_size=20,
)

In [14]:
# Take the third sentence pair of the first validation batch as a fixed demo
# example that is translated again after every validation pass.
demo_sample = next(iter(validation_dataloader))
demo_pair = demo_sample["translation"]
demo_source_txt, demo_target_txt = demo_pair["de"][2], demo_pair["en"][2]
demo_source_txt, demo_target_txt

('Allerdings hält das Brennan Center letzteres für einen Mythos, indem es bekräftigt, dass der Wahlbetrug in den USA seltener ist als die Anzahl der vom Blitzschlag getöteten Menschen.',
 'However, the Brennan Centre considers this a myth, stating that electoral fraud is rarer in the United States than the number of people killed by lightning.')

In [15]:
# `config` also carries training settings, so only the entries that are
# model constructor arguments are forwarded.
MODEL_ARG_NAMES = (
    "source_vocab_size",
    "target_vocab_size",
    "d_model",
    "n_head",
    "d_feedforward",
    "max_len",
    "max_time_step",
    "halting_thresh",
)
model = UniversalTransformer(
    **{name: config[name] for name in MODEL_ARG_NAMES}
).to(DEVICE)

In [16]:
# Loss and optimizer, both parameterised from the run configuration.
loss = nn.CrossEntropyLoss(label_smoothing=config["label_smoothing"]).to(DEVICE)
optimizer = torch.optim.Adam(params=model.parameters(), lr=config["learning_rate"])

### Initialize W&B

In [17]:
# Start a W&B run that records `config`, then register the model with
# wandb.watch using a logging frequency of 100.
wandb.init(project="universal_transformer_wmt14_test", config=config)
wandb.watch(model, log_freq=100)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[34m[1mwandb[0m: Currently logged in as: [33miibrahimli[0m (use `wandb login --relogin` to force relogin)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[]

### Training loop

In [None]:
# Run a validation pass every this many optimisation steps.
VALIDATE_EVERY = 2


def _repeat_forever(dataloader):
    """Yield batches from ``dataloader`` endlessly, restarting it after each pass.

    Replaces ``itertools.cycle``, which keeps a copy of every item it has seen
    (the whole streamed dataset would end up in memory). Stops if the
    dataloader yields nothing at all.
    """
    while True:
        yielded_any = False
        for item in dataloader:
            yielded_any = True
            yield item
        if not yielded_any:
            return


def _batch_to_device(batch):
    """Return (source, target, source mask, target mask) of ``batch`` on DEVICE."""
    return (
        batch["input_ids"].to(DEVICE),
        batch["labels"].to(DEVICE),
        batch["attention_mask"].to(DEVICE),
        batch["labels_attention_mask"].to(DEVICE),
    )


# NOTE(review): batch["labels"] is the right-shifted sequence built by
# `encode`, and it is used both as decoder input and as the loss target.
# Confirm that UniversalTransformer accounts for this internally.
# NOTE(review): logging.info output only shows up if logging is configured at
# INFO level, which is not visible in this notebook -- confirm.
#
# `step` counts optimisation steps across passes over the data; with
# cycle(enumerate(...)) the indices restarted from 0 on every replay.
for step, batch in enumerate(_repeat_forever(train_dataloader)):
    # The model lives on DEVICE, so the batch tensors have to be moved there.
    source, target, source_padding_mask, target_padding_mask = _batch_to_device(batch)

    model.train()
    optimizer.zero_grad()
    out = model(
        source,
        target,
        source_padding_mask=source_padding_mask,
        target_padding_mask=target_padding_mask,
    )
    tr_loss = loss(out.view(-1, model.target_vocab_size), target.view(-1))
    tr_loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    tr_loss_value = tr_loss.item()
    # All metrics of this step are collected and logged with a single
    # wandb.log call at the end of the iteration. Logging the validation
    # metrics with step=<index> after an auto-stepped log call targets a step
    # that W&B has already moved past.
    metrics = {"loss": tr_loss_value}
    logging.info(f"[{step}] tr_loss: {tr_loss_value:.4f}")

    # validate
    if step % VALIDATE_EVERY == 0:
        model.eval()
        val_losses = []
        bleu = load_metric("bleu")
        for val_batch in validation_dataloader:
            source, target, source_padding_mask, target_padding_mask = _batch_to_device(
                val_batch
            )

            with torch.no_grad():
                out = model(
                    source,
                    target,
                    source_padding_mask=source_padding_mask,
                    target_padding_mask=target_padding_mask,
                )
                val_loss = loss(out.view(-1, model.target_vocab_size), target.view(-1))
            val_losses.append(val_loss.item())

            # compute BLEU
            source_texts = val_batch["translation"]["de"]
            target_texts = val_batch["translation"]["en"]
            for src_txt, tgt_txt in zip(source_texts, target_texts):
                # An empty hypothesis makes BLEU divide by zero; this also
                # covers translations that consist of whitespace only.
                hypothesis = translate_text(src_txt, model, tokenizer).split() or ["0"]
                bleu.add(predictions=hypothesis, references=[tgt_txt.split()])

        mean_val_loss = torch.mean(torch.tensor(val_losses)).item()
        bleu_score = bleu.compute()["bleu"]
        metrics.update({"val_loss": mean_val_loss, "bleu": bleu_score})
        logging.info(
            f"[{step}] tr_loss: {tr_loss_value:.4f}  val_loss: {mean_val_loss:.4f}  val_bleu: {bleu_score:.4f}"
        )
        logging.info(f"DE: {demo_source_txt}")
        logging.info(f"EN: {demo_target_txt}")
        logging.info(f"output: {translate_text(demo_source_txt, model, tokenizer)}")
        logging.info("")

    wandb.log(metrics)