If you're opening this notebook on Colab, you first need to clone the repo and change into its directory. Uncomment the cell below and run it; the cell after it changes into the cloned directory.


In [None]:
# !git clone https://github.com/jbergq/transformer.git


In [None]:
from pathlib import Path

if Path.cwd().name != "transformer":
  %cd transformer

In [None]:
# Install dependencies into the running kernel's environment (`%pip`, unlike
# `!pip`, targets the kernel's interpreter). portalocker is installed separately
# from requirements.txt — NOTE(review): presumably a missing transitive
# dependency; confirm and consider pinning it in requirements.txt instead.
%pip install portalocker
%pip install -r requirements.txt

In [None]:
from easydict import EasyDict

# Architecture presets. "gpt2-small" uses the dimensions of the smallest GPT-2
# model; "toy-model" is a much smaller variant for quick experiments.
models = {
    "toy-model": {
        "hidden_size": 128,
        "ff_hidden_size": 256,
        "num_blocks": 4,
        "num_heads": 4,
        "context_size": 64,
    },
    "gpt2-small": {
        "hidden_size": 768,
        "ff_hidden_size": 3072,
        "num_blocks": 12,
        "num_heads": 12,
        "context_size": 1024,
    },
}

# Training hyperparameters merged with the selected architecture preset.
# Preset keys come last, so they take precedence exactly as `cfg.update(...)` would.
cfg = EasyDict(
    {
        "num_epochs": 100,
        "batch_size": 4,
        "lr": 1e-3,
        "weight_decay": 0.0005,
        "print_example": True,  # Show a data sample and a model generation before training.
        **models["gpt2-small"],
    }
)

cfg


In [None]:
import torch

# Seed PyTorch's RNGs so weight initialization and other torch-driven
# randomness are reproducible across runs.
torch.manual_seed(1337)

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device


In [None]:
import wandb

# Authenticate with Weights & Biases (may prompt for an API key), then start a
# run in the "transformer" project that records the experiment config `cfg`.
wandb.login()

wandb.init(
    project="transformer",
    config=cfg,
)


In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from src.model.transformer import TransformerDecoder
from src.utils import iter_print, epoch_print


Let's set up our dataset. We will use Hugging Face's `datasets` package to stream and prepare OpenWebText, an open-source reproduction of the WebText corpus used to train GPT-2.


In [None]:
from datasets import load_dataset

# TODO: Load the val/test set used in the GPT-2 paper.
# Stream the corpus instead of downloading it up front; splits are consumed
# lazily as the DataLoader iterates.
dataset = load_dataset("openwebtext", streaming=True)


To tokenize our dataset, we will use the GPT-2 tokenizer, available from Hugging Face's `transformers` package.


In [None]:
from transformers import GPT2Tokenizer, GPT2TokenizerFast

# Tokenizer used by GPT-2, in its *fast* (Rust-backed) implementation.
# The fast variant is required: `tokenize()` relies on
# `return_overflowing_tokens=True` splitting each long document into several
# rows of `max_length` tokens. The slow `GPT2Tokenizer` does not do that — it
# keeps only the first `max_length` tokens per document and reports the rest
# under a separate "overflowing_tokens" key, which `tokenize()` ignores, so
# most of every document would be silently discarded.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# tokenizer.get_vocab()


In [None]:
def tokenize(example):
    """Turn a batch of raw documents into fixed-length (source, target) pairs.

    Text is tokenized into chunks of ``cfg.context_size + 1`` tokens. Only
    chunks of exactly that length are kept. Each kept chunk yields a source
    (every token except the last) and a target (every token except the
    first), so the target is the source shifted left by one position.

    Args:
        example: A batch passed by ``Dataset.map(batched=True)``;
            ``example["text"]`` is indexed as the list of documents.

    Returns:
        A dict with parallel lists under ``"source"`` and ``"target"``.
    """
    chunk_len = cfg.context_size + 1  # One extra token so source/target can be offset by one.
    encoded = tokenizer(
        example["text"],
        truncation=True,  # Cap each returned sequence at max_length tokens.
        max_length=chunk_len,
        # Also return the tokens beyond max_length. NOTE: only a fast tokenizer
        # splits them into additional rows of input_ids.
        return_overflowing_tokens=True,
        return_length=True,  # Per-row lengths, used below to drop short chunks.
    )

    full_chunks = [
        ids
        for n_tokens, ids in zip(encoded["length"], encoded["input_ids"])
        if n_tokens == chunk_len
    ]
    return {
        "source": [ids[:-1] for ids in full_chunks],
        "target": [ids[1:] for ids in full_chunks],
    }


# Tokenize the streaming train split lazily. The raw columns (e.g. "text") are
# removed because `tokenize` emits a different number of rows than it receives.
dataset_train = dataset["train"].map(
    tokenize,
    batched=True,
    remove_columns=dataset["train"].column_names,
)


In [None]:
def collate_batch(samples):
    """Stack the token-id lists of `samples` into tensors keyed "source"/"target"."""
    return {
        key: torch.tensor([sample[key] for sample in samples])
        for key in ("source", "target")
    }


train_dataloader = DataLoader(
    dataset_train, batch_size=cfg.batch_size, collate_fn=collate_batch
)

# Pull one batch to confirm the pipeline works end to end. Since `tokenize`
# builds the target as the source shifted left by one, the two printouts
# should overlap with a one-token offset.
batch = next(iter(train_dataloader))

print(batch["source"][0][:10])
print(batch["target"][0][:10])


In [None]:
def train(model, dataloader, optimizer, criterion, epoch):
    """Run one pass over `dataloader`, updating `model` after every batch.

    Args:
        model: Module called as ``model(src)``; its output's last dimension is
            treated as the vocabulary (logits) dimension.
        dataloader: Yields dicts with "source" and "target" token-id tensors.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking flattened logits and flattened targets.
        epoch: Epoch index, only used for progress printing.

    Returns:
        1-D tensor holding the loss of every iteration.
    """
    train_losses = []

    model.train()

    for i, batch in enumerate(dataloader):
        src, tgt = batch["source"].to(device), batch["target"].to(device)

        # Clear the previous step's gradients. Without this, loss.backward()
        # accumulates into them and every update uses the running sum of all
        # past gradients.
        optimizer.zero_grad()

        out = model(src)

        # Flatten batch and time so the loss sees one prediction per token.
        out_flat = out.reshape(-1, out.shape[-1])  # (B * T, vocab_size)
        tgt_flat = tgt.reshape(-1)  # (B * T,)

        loss = criterion(out_flat, tgt_flat)
        loss.backward()
        optimizer.step()

        loss_val = loss.item()
        train_losses.append(loss_val)
        iter_print(epoch, i, loss_val)

    return torch.tensor(train_losses)


def validate(model, dataloader, criterion, epoch):
    """Evaluate `model` on `dataloader` without updating its parameters.

    Args:
        model: Module called as ``model(src)``; its output's last dimension is
            treated as the vocabulary (logits) dimension.
        dataloader: Yields dicts with "source" and "target" token-id tensors.
        criterion: Loss taking flattened logits and flattened targets.
        epoch: Epoch index, only used for progress printing.

    Returns:
        1-D tensor holding the loss of every iteration.
    """
    val_losses = []

    model.eval()

    # Inference only: disabling autograd avoids building (and holding onto)
    # a computation graph for every batch.
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            src, tgt = batch["source"].to(device), batch["target"].to(device)

            out = model(src)

            out_flat = out.reshape(-1, out.shape[-1])  # (B * T, vocab_size)
            tgt_flat = tgt.reshape(-1)  # (B * T,)

            loss = criterion(out_flat, tgt_flat)

            loss_val = loss.item()
            val_losses.append(loss_val)
            iter_print(epoch, i, loss_val)

    return torch.tensor(val_losses)


In [None]:
model = TransformerDecoder(
    tokenizer.vocab_size,
    cfg.context_size,
    cfg.hidden_size,
    cfg.ff_hidden_size,
    cfg.num_blocks,
    cfg.num_heads,
)
model = model.to(device)
# NOTE(review): Adam's `weight_decay` is an L2 term added to the gradient;
# torch.optim.AdamW implements decoupled weight decay if that is what's intended.
optimizer = Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay, eps=5e-9)
# No `ignore_index`: `tokenize` keeps only chunks of exactly context_size + 1
# tokens, so batches are never padded and there is nothing to mask. The former
# `ignore_index=1` excluded GPT-2 token id 1 — an ordinary vocabulary token
# (the double-quote character), not padding — from the loss.
criterion = nn.CrossEntropyLoss()

# Fixed prompt, reused to eyeball generation quality as training progresses.
fixed_inp = torch.tensor(
    tokenizer.encode("The"), dtype=torch.long, device=device
).unsqueeze(0)  # (1, prompt_len)


def sample_text(prompt_ids):
    """Generate a continuation of `prompt_ids` with `model` and decode it.

    Puts the model in eval mode (disabling dropout, if any) and disables
    autograd while generating. `train()` calls `model.train()` at the start of
    every epoch, so training mode is restored before the next update.
    """
    model.eval()
    with torch.no_grad():
        out = model.generate(prompt_ids)
    return tokenizer.decode(out[0].tolist())


if cfg.print_example:
    batch = next(iter(train_dataloader))

    print("Example sequence: ", tokenizer.decode(batch["target"][0].tolist()))
    print("Model output: ", sample_text(fixed_inp))

for epoch in range(cfg.num_epochs):
    # NOTE: `train_dataloader` streams the whole training split, so one
    # "epoch" here is a full pass over the streamed corpus.
    train_losses = train(model, train_dataloader, optimizer, criterion, epoch)
    # val_losses = validate(model, val_dataloader, criterion, epoch)

    wandb.log(
        {
            "train_loss": train_losses.mean().item(),
            # "val_loss": val_losses.mean().item()
        }
    )
    # epoch_print(epoch, val_losses)

    print(sample_text(fixed_inp))
