If you're opening this notebook on Colab, you will need to clone the repo and change into its directory. Run the cell below to do so.


In [None]:
!git clone https://github.com/jbergq/transformer.git && cd transformer

fatal: destination path 'transformer' already exists and is not an empty directory.


In [None]:
# Install the notebook's dependencies. `requirements.txt` is looked up relative
# to the current working directory, so run this from the repository root.
%pip install portalocker
%pip install -r requirements.txt

[31mERROR: Could not open requirements file: [Errno 2] No such file or directory: 'requirements.txt'[0m[31m
[0m

In [None]:
import math
import torch
import torch.nn as nn

## Transformer decoder implementation


This section shows how to implement a transformer decoder, the version of the transformer suitable for language modeling.

### Token embeddings

In NLP, text is represented by sub-word units called tokens. Tokens can be thought of as a numeric representation of the smallest meaningful units of language. After tokenizing our training corpus, the text will be represented as sequences of integers that are easier for a machine learning model to work with.

In the transformer, tokens are mapped to trainable vector representations called embeddings. Once trained, these embeddings represent various features of the sub-words they correspond to.

Let's implement the token embedding using PyTorch's built-in `nn.Embedding` module.

In [None]:
class TokenEmbedding(nn.Module):
    """Trainable lookup table that maps token ids to dense vectors."""

    def __init__(self, vocab_size, embedding_dim):
        super().__init__()

        # One learnable row of size `embedding_dim` per vocabulary entry.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim
        )

    def forward(self, tokens):
        """Return the embedding vectors for the integer ids in `tokens`."""
        embedded = self.embedding(tokens)
        return embedded

So is that it? Not quite. While we could train our transformer with only token embeddings, the transformer architecture itself has no notion of the ordering of tokens. In natural language, the order of words can completely change the meaning of a sentence, so it is necessary to give our transformer a way to represent this. A common way to address this is to create another embedding that is simply added onto the token embeddings to inject information about its position, called a *positional encoding*.

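Let's see how it can be implemented. For a sequence length of $T$, we want to generate $T$ vectors that can uniquely represent each position in the sequence. We also want these vectors to be fixed (not learned), bounded in magnitude, and to vary smoothly with position, so that nearby positions receive similar encodings. Sinusoids of geometrically increasing wavelength satisfy all of this: for position $k$ and embedding size $d$, channel $2i$ is $\sin\left(k / n^{2i/d}\right)$ and channel $2i + 1$ is $\cos\left(k / n^{2i/d}\right)$, with $n = 10000$ by default.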

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Position ``k`` is encoded as ``sin(k / n^(2i/d))`` in the even channels and
    ``cos(k / n^(2i/d))`` in the odd channels, where ``d`` is `embedding_size`
    and ``i`` indexes the sine/cosine pairs. The table is registered as a
    buffer, so it is part of the module state but is not a trainable parameter.

    Args:
        context_size: Number of positions to precompute.
        embedding_size: Number of channels per position (even or odd).
        n: Base of the geometric progression of wavelengths.
    """

    def __init__(self, context_size, embedding_size, n=10000):
        super().__init__()

        # An odd `embedding_size` has one more even (sine) channel than odd
        # (cosine) channels, so size the frequency axis for the sine half and
        # trim it for the cosine half.
        num_sin = (embedding_size + 1) // 2
        num_cos = embedding_size // 2

        i = torch.arange(num_sin)
        k = torch.arange(context_size).unsqueeze(dim=1)

        # Shared argument of the sine and cosine: (context_size, num_sin).
        angles = k / (n ** (2 * i / embedding_size))

        pos_embeddings = torch.zeros(context_size, embedding_size)
        pos_embeddings[:, 0::2] = torch.sin(angles)
        pos_embeddings[:, 1::2] = torch.cos(angles[:, :num_cos])

        self.register_buffer("pos_embeddings", pos_embeddings)

    def forward(self, x):
        """Return the encodings of the first ``x.shape[1]`` positions.

        Only the shape of `x` is used; its values are ignored.
        """
        return self.pos_embeddings[: x.shape[1], :]

### Transformer blocks

Transformers, like many other neural networks, are made by stacking multiple computational blocks in a sequence. Transformer blocks contain two main components:
1. A multi-head attention module, responsible for communication between token embeddings.
2. A feedforward module, responsible for processing token embeddings.

By alternating between inter-token communication and per-token processing, transformers are able to produce sophisticated representations of language.

#### Multi-head attention

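Attention is perhaps the best-known and most important component of transformers. It lets every token gather information from the other tokens in the sequence: each token emits a query, a key and a value vector; the query–key dot products, scaled by $\sqrt{d_k}$ and passed through a softmax, decide how much each token attends to every other token; and the output is the correspondingly weighted sum of the values, $\mathrm{softmax}\left(QK^\top / \sqrt{d_k}\right)V$. Multi-head attention runs several such attention operations in parallel on different learned projections and concatenates the results.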

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(head_size)) @ v``."""

    def __init__(self):
        super().__init__()

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        """Attend from queries to keys and return the weighted sum of values.

        Args:
            q, k, v: Tensors whose last two dimensions are
                (sequence length, head size).
            mask: Optional tensor broadcastable to the score shape. Positions
                where it equals 0 are excluded from attention.

        Returns:
            Tensor with the leading dimensions of `q` and the last dimension
            of `v`.
        """
        head_size = k.shape[-1]

        score = (q @ k.transpose(-2, -1)) / math.sqrt(head_size)

        if mask is not None:
            # Fill with the most negative finite value of the score dtype
            # instead of a hard-coded -1e9, which cannot be represented in
            # float16. A finite fill (rather than -inf) keeps fully masked
            # rows free of NaNs after the softmax.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        return self.softmax(score) @ v

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project.

    Args:
        embedding_size: Size of the last dimension of the inputs and output.
        hidden_size: Total width of the projected queries/keys/values; it is
            divided evenly between the heads.
        num_heads: Number of attention heads.
    """

    def __init__(self, embedding_size, hidden_size, num_heads=8):
        super().__init__()

        self.num_heads = num_heads

        self.attention = ScaledDotProductAttention()

        # Input projections for queries, keys and values.
        self.lin_q = nn.Linear(embedding_size, hidden_size)
        self.lin_k = nn.Linear(embedding_size, hidden_size)
        self.lin_v = nn.Linear(embedding_size, hidden_size)

        # Output projection applied to the re-merged heads.
        self.lin_concat = nn.Linear(hidden_size, embedding_size)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention to `q`, `k`, `v` with an optional mask."""
        q_heads = self.split(self.lin_q(q))
        k_heads = self.split(self.lin_k(k))
        v_heads = self.split(self.lin_v(v))

        attended = self.attention(q_heads, k_heads, v_heads, mask)

        return self.lin_concat(self.concat(attended))

    def split(self, x):
        """Reshape (batch, seq, hidden) into (batch, heads, seq, hidden // heads)."""
        batch_size, seq_len, hidden_size = x.shape
        head_size = hidden_size // self.num_heads

        per_head = x.view(batch_size, seq_len, self.num_heads, head_size)
        return per_head.transpose(1, 2)

    def concat(self, x):
        """Inverse of `split`: merge (batch, heads, seq, head) into (batch, seq, hidden)."""
        batch_size, num_heads, seq_len, head_size = x.shape

        # The transpose makes the tensor non-contiguous, so copy before viewing.
        merged = x.transpose(1, 2).contiguous()
        return merged.view(batch_size, seq_len, num_heads * head_size)

In [None]:
class FeedForward(nn.Module):
    """Position-wise feedforward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        in_size: Size of the input and output features.
        hidden_size: Size of the intermediate layer.
        dropout_prob: Dropout probability applied after the activation.
    """

    def __init__(self, in_size, hidden_size, dropout_prob=0.1):
        super().__init__()

        # Expand to the hidden width ...
        self.lin1 = nn.Linear(in_size, hidden_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)
        # ... and project back to the input width.
        self.lin2 = nn.Linear(hidden_size, in_size)

    def forward(self, x):
        """Transform `x` along its last dimension; the output size equals the input size."""
        hidden = self.dropout(self.relu(self.lin1(x)))
        return self.lin2(hidden)

In [None]:
class DecoderBlock(nn.Module):
    """Pre-norm decoder block: masked self-attention followed by a feedforward net.

    Each sub-layer is applied as ``x + dropout(sublayer(norm(x)))``. Dropout is
    applied to the sub-layer output only, never to the residual stream itself,
    so the skip connection always carries the input through unchanged.

    Args:
        hidden_size: Feature size of the block's input and output.
        ff_hidden_size: Width of the feedforward network's hidden layer.
        num_heads: Number of self-attention heads.
        use_cross_attn: If true, the encoder-decoder attention modules are
            created. NOTE: `forward` does not use them; only the self-attention
            and feedforward sub-layers are applied.
        dropout_prob: Dropout probability for both sub-layers.
        layer_norm_eps: Epsilon passed to every `nn.LayerNorm`.
    """

    def __init__(
        self,
        hidden_size,
        ff_hidden_size,
        num_heads,
        use_cross_attn,
        dropout_prob=0.1,
        layer_norm_eps=1e-5,
    ):
        super().__init__()

        self.attention1 = MultiHeadAttention(hidden_size, hidden_size, num_heads)
        self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout_prob)

        if use_cross_attn:
            self.enc_dec_attention = MultiHeadAttention(
                hidden_size, hidden_size, num_heads
            )
            self.enc_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
            self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
            self.dropout2 = nn.Dropout(dropout_prob)

        self.ff_block = FeedForward(hidden_size, ff_hidden_size, dropout_prob)

        self.norm3 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.dropout3 = nn.Dropout(dropout_prob)

    def forward(self, x, lookahead_mask=None):
        """Apply the block to `x`, optionally masking self-attention.

        Args:
            x: Input features; the last dimension must be `hidden_size`.
            lookahead_mask: Optional attention mask forwarded to the
                self-attention module.

        Returns:
            Tensor with the same shape as `x`.
        """
        # Self-attention sub-layer (pre-norm, residual).
        x_n = self.norm1(x)
        x_a = self.attention1(q=x_n, k=x_n, v=x_n, mask=lookahead_mask)
        x = x + self.dropout1(x_a)

        # Feedforward sub-layer (pre-norm, residual).
        x_f = self.ff_block(self.norm3(x))
        x = x + self.dropout3(x_f)

        return x

#### Putting it all together

In [None]:
class Decoder(nn.Module):
    """Stack of decoder blocks mapping token ids to next-token logits.

    Args:
        vocab_size: Number of distinct token ids; also the size of the output
            logits.
        context_size: Maximum sequence length for which positional encodings
            are precomputed.
        hidden_size: Embedding size and feature size of every block.
        ff_hidden_size: Hidden width of each block's feedforward network.
        num_blocks: Number of stacked decoder blocks.
        num_heads: Number of attention heads per block.
        use_cross_attn: Forwarded to every `DecoderBlock`.
    """

    def __init__(
        self,
        vocab_size,
        context_size,
        hidden_size,
        ff_hidden_size,
        num_blocks=5,
        num_heads=8,
        use_cross_attn=False,
    ):
        super().__init__()

        self.token_embedding = TokenEmbedding(vocab_size, hidden_size)
        self.pos_embedding = PositionalEncoding(context_size, hidden_size)

        self.decoder = nn.ModuleList(
            [
                DecoderBlock(
                    hidden_size,
                    ff_hidden_size,
                    num_heads,
                    use_cross_attn,
                )
                for _ in range(num_blocks)
            ]
        )

        self.ln_final = nn.LayerNorm(hidden_size)
        self.lin_final = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, lookahead_mask=None):
        """Return logits over the vocabulary for every position of `x`.

        Args:
            x: Integer token ids; dimension 1 is treated as the sequence axis.
            lookahead_mask: Optional attention mask passed to every block.

        Returns:
            Tensor of logits whose last dimension is `vocab_size`.
        """
        # The argument holds token ids; keep a separate name for them because
        # `x` is reused below for the running hidden state.
        tokens = x
        x = self.token_embedding(tokens) + self.pos_embedding(tokens)

        for block in self.decoder:
            x = block(x, lookahead_mask)

        out = self.lin_final(self.ln_final(x))

        return out

In [None]:
class Transformer(nn.Module):
    """Base class that stores the context size and a cached causal mask."""

    def __init__(self, context_size):
        super().__init__()

        self.context_size = context_size

        # Lower-triangular matrix of ones: row t has ones in columns 0..t.
        causal = torch.ones(context_size, context_size).tril()
        self.register_buffer("tri", causal)

    def create_lookahead_mask(self, tgt_seq_len):
        """Return the top-left `tgt_seq_len` square of the mask with a leading axis of size 1."""
        top_left = self.tri[:tgt_seq_len, :tgt_seq_len]
        return top_left[None]


class TransformerDecoder(Transformer):
    """Transformer decoder, using auto-regressive decoder blocks for language modeling.

    Args:
        vocab_size: Number of distinct token ids.
        context_size: Maximum supported sequence length.
        hidden_size: Embedding and block feature size.
        ff_hidden_size: Hidden width of each feedforward network.
        num_blocks: Number of decoder blocks.
        num_heads: Number of attention heads per block.
    """

    def __init__(
        self,
        vocab_size,
        context_size,
        hidden_size,
        ff_hidden_size,
        num_blocks=5,
        num_heads=8,
    ):
        super().__init__(context_size)

        self.decoder = Decoder(
            vocab_size,
            context_size,
            hidden_size,
            ff_hidden_size,
            num_blocks,
            num_heads,
            use_cross_attn=False,
        )

    def forward(self, x):
        """Return next-token logits for `x`, masking attention to future positions."""
        lookahead_mask = self.create_lookahead_mask(x.shape[1])

        # `Decoder.forward` takes (x, lookahead_mask); pass the mask by keyword
        # so it cannot land in the wrong positional slot.
        out = self.decoder(x, lookahead_mask=lookahead_mask)

        return out

## Training


### Util functions

In [None]:
def count_model_params(model):
    """Return the total number of scalar elements across ``model.parameters()``."""
    return sum(params.numel() for params in model.parameters())

In [None]:
def train_start_print(model):
    """Print a banner announcing training and the model size in millions of parameters."""
    banner = "=" * 45
    params_in_millions = count_model_params(model) / 1e6

    banner_lines = (
        "\n" + banner,
        "Starting training!",
        "Num model params: {num_params_m:.3f}M".format(num_params_m=params_in_millions),
        banner + "\n",
    )
    for line in banner_lines:
        print(line)


def iter_print(iter, train_loss, newline_interval=50):
    """Print one progress row with the iteration number and training loss.

    A column header is printed every 500 iterations. Rows end with a carriage
    return, so the next row overwrites them, except on multiples of
    `newline_interval`, where the row is kept by ending it with a newline.
    """
    row_template = "{: <10} {: <10.5}"

    if iter % 500 == 0:
        rule = "-" * 45
        for header_line in (rule, row_template.format("Iter", "Train loss"), rule):
            print(header_line)

    keep_row = iter % newline_interval == 0
    line_end = "\n" if keep_row else "\r"
    print(row_template.format(iter, train_loss), end=line_end)


def evaluation_print(losses):
    """Print the mean train and validation losses read from ``losses["train"]`` and ``losses["val"]``."""
    banner = "=" * 45

    print("\n\n" + banner)
    print("Evaluation done!")
    for template, split in (
        ("Mean train loss: {mean_loss:.3f}", "train"),
        ("Mean validation loss: {mean_loss:.3f}", "val"),
    ):
        print(template.format(mean_loss=losses[split]))
    print(banner + "\n")


In [None]:
from easydict import EasyDict


# Define base config. Partly adopted from nanoGPT by Andrej Karpathy
cfg = EasyDict(
    dict(
        val_size=1000,  # Size of validation set.
        max_iters=600000,  # Total num training iterations.
        eval_iters=100,  # Number of evaluation iterations.
        eval_interval=1000,
        effective_batch_size=512,
        batch_size=4,
        grad_accum_steps=1,
        lr=1e-3,
        warmup_iters=2000,
        lr_decay_iters=600000,  # Should be ~= max_iters per Chinchilla.
        min_lr=6e-5,  # Minimum learning rate, should be ~= learning_rate/10 per Chinchilla.
        weight_decay=0.0005,
        print_example=True,
    )
)

# When a target effective batch size is given, it overrides `grad_accum_steps`
# with the number of micro-batches needed to reach it.
if cfg.effective_batch_size is not None:
    cfg.grad_accum_steps = cfg.effective_batch_size // cfg.batch_size

print(cfg)

{'val_size': 1000, 'max_iters': 600000, 'eval_iters': 100, 'eval_interval': 1000, 'effective_batch_size': 512, 'batch_size': 4, 'grad_accum_steps': 128, 'lr': 0.001, 'warmup_iters': 2000, 'lr_decay_iters': 600000, 'min_lr': 6e-05, 'weight_decay': 0.0005, 'print_example': True}


In [None]:
# Seed PyTorch's default random number generator.
torch.manual_seed(1337)

# Use the GPU when PyTorch reports that CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
import wandb

# Log in to Weights & Biases, then start a run in the "transformer" project
# with `cfg` recorded as the run's config.
wandb.login()
wandb.init(project="transformer", config=cfg)

Let's set up our dataset. We will use Hugging Face's `datasets` package to prepare and load the OpenWebText dataset.


In [None]:
from datasets import load_dataset

# Load the "train" split of OpenWebText in streaming mode. No need to download!
dataset = load_dataset("openwebtext", streaming=True)["train"]
# Shuffle with a fixed seed (42) and a 10000-example buffer.
shuffled_dataset = dataset.shuffle(seed=42, buffer_size=10000)

# Split dataset: the first `cfg.val_size` shuffled examples form the
# validation set; everything after them is used for training.
train_set = shuffled_dataset.skip(cfg.val_size)
val_set = shuffled_dataset.take(cfg.val_size)

To tokenize our dataset, we will use the GPT-2 tokenizer, available from Hugging Face's `transformers` package.


In [None]:
from transformers import GPT2Tokenizer

# Tokenizer used by GPT-2.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Display the vocabulary size; it is passed to `get_model` when the model is
# built.
tokenizer.vocab_size

In [None]:
from transformers import AutoModelWithLMHead, AutoConfig


class GPT2Wrapped(nn.Module):
    """Wrap the Hugging Face "gpt2" language model so that calling it returns logits only.

    Args:
        pretrained: If true, load the model with `from_pretrained("gpt2")`;
            otherwise build it from the "gpt2" config alone.
    """

    def __init__(self, pretrained=False) -> None:
        super().__init__()

        if not pretrained:
            gpt2_config = AutoConfig.from_pretrained("gpt2")
            self.model = AutoModelWithLMHead.from_config(gpt2_config)
        else:
            self.model = AutoModelWithLMHead.from_pretrained("gpt2")

        # Expose the context length under the same attribute name that
        # `Transformer` uses.
        self.context_size = self.model.config.n_ctx

    def forward(self, x):
        """Run the wrapped model on `x` and return the `logits` field of its output."""
        output = self.model(x)
        return output.logits


def get_model(name, vocab_size):
    """Build the model registered under `name`.

    Args:
        name: One of "toy-model", "gpt2-small-custom", "gpt2-small-hf" or
            "gpt2-small-hf-pretrained".
        vocab_size: Vocabulary size for the custom `TransformerDecoder`
            variants. It is not used by the Hugging Face variants.

    Returns:
        The constructed model.

    Raises:
        ValueError: If `name` is not one of the supported model names.
    """
    if name == "toy-model":
        return TransformerDecoder(
            vocab_size=vocab_size,
            context_size=64,
            hidden_size=128,
            ff_hidden_size=256,
            num_blocks=4,
            num_heads=4,
        )
    if name == "gpt2-small-custom":
        return TransformerDecoder(
            vocab_size=vocab_size,
            context_size=1024,
            hidden_size=768,
            ff_hidden_size=3072,
            num_blocks=12,
            num_heads=12,
        )
    if name == "gpt2-small-hf":
        return GPT2Wrapped(pretrained=False)
    if name == "gpt2-small-hf-pretrained":
        return GPT2Wrapped(pretrained=True)

    # Fail loudly here instead of returning None and crashing later on the
    # first attribute access.
    raise ValueError(
        f"Unknown model name {name!r}; expected one of 'toy-model', "
        "'gpt2-small-custom', 'gpt2-small-hf', 'gpt2-small-hf-pretrained'."
    )


# Edit below to select a model.
model_name = "toy-model"

# Build the selected model and move it to the chosen device.
model = get_model(model_name, tokenizer.vocab_size).to(device)

In [None]:
from functools import partial


def tokenize(example):
    """Tokenize a batch of texts into fixed-length (source, target) pairs.

    Each text in ``example["text"]`` is tokenized and split into chunks of at
    most ``model.context_size + 1`` tokens. Only chunks of exactly that length
    are kept; each yields a source (all tokens but the last) and a target (all
    tokens but the first), i.e. the source shifted by one.
    """
    chunk_len = model.context_size + 1

    outputs = tokenizer(
        example["text"],
        truncation=True,  # Truncate returned token sequences to max_length.
        max_length=chunk_len,  # Max length of returned token sequences.
        return_overflowing_tokens=True,  # Tokenize whole input and split into chunks.
        return_length=True,  # Return lengths of chunks.
    )

    # Only include full length sequences.
    full_chunks = [
        input_ids
        for length, input_ids in zip(outputs["length"], outputs["input_ids"])
        if length == chunk_len
    ]

    return {
        "source": [input_ids[:-1] for input_ids in full_chunks],
        "target": [input_ids[1:] for input_ids in full_chunks],
    }


# Tokenize train and val sets. The original columns are removed, leaving only
# the "source" and "target" columns returned by `tokenize`.
train_tokenized = train_set.map(
    tokenize,
    batched=True,
    remove_columns=train_set.column_names,
)
val_tokenized = val_set.map(
    tokenize,
    batched=True,
    remove_columns=val_set.column_names,
)

The training uses an "infinite loop" style, where we continue to sample random batches until we reach convergence or the maximum number of batches configured.

Let's define a dataset wrapper that will allow us to continue sampling the dataset endlessly.


In [None]:
from typing import Iterator

from torch.utils.data import IterableDataset


class InfiniteIterableDataset(IterableDataset):
    """Iterable dataset that restarts the wrapped dataset whenever it is exhausted."""

    def __init__(self, hf_dataset, shuffle=False):
        # NOTE: `shuffle` is accepted but not used.
        self.hf_dataset = hf_dataset

    def __iter__(self) -> Iterator:
        """Yield the wrapped dataset's items in order, looping forever."""
        while True:
            yield from self.hf_dataset

In [None]:
from torch.utils.data import DataLoader

def _collate_source_target(samples):
    """Stack the "source" and "target" token lists of `samples` into two tensors."""
    return {
        "source": torch.tensor([sample["source"] for sample in samples]),
        "target": torch.tensor([sample["target"] for sample in samples]),
    }


# Create data loaders for sampling batches.
train_loader = DataLoader(
    InfiniteIterableDataset(train_tokenized),
    batch_size=cfg.batch_size,
    collate_fn=_collate_source_target,
)
val_loader = DataLoader(
    InfiniteIterableDataset(val_tokenized),
    batch_size=cfg.batch_size,
    collate_fn=_collate_source_target,
)

# Create data iterators. The validation iterator must come from `val_loader`:
# calling `iter()` on the train iterator just returns that same iterator, which
# would make every "validation" batch a training batch.
train_iter = iter(train_loader)
val_iter = iter(val_loader)

Let's load one train batch and one validation batch to make sure everything works.


In [None]:
# Pull one batch from the training iterator.
batch_train = next(train_iter)

# Show the first 10 token ids of the first sample. The target row should be
# the source row shifted left by one position.
print(batch_train["source"][0][:10])
print(batch_train["target"][0][:10])

In [None]:
# Pull one batch from `val_iter`.
batch_val = next(val_iter)

# Show the first 10 token ids of the first sample. The target row should be
# the source row shifted left by one position.
print(batch_val["source"][0][:10])
print(batch_val["target"][0][:10])

In [None]:
def step(model, criterion, iterator):
    """Run the model on the next batch from `iterator` and return the loss.

    The batch's "source" and "target" tensors are moved to the global `device`.
    No backward pass or optimizer update happens here.
    """
    batch = next(iterator)
    src = batch["source"].to(device)
    tgt = batch["target"].to(device)

    logits = model(src)

    # Flatten every leading dimension so the criterion sees one row per token.
    vocab_size = logits.shape[-1]
    flat_logits = logits.contiguous().view(-1, vocab_size)  # (B * T, vocab_size)
    flat_targets = tgt.contiguous().view(-1)  # (B * T,)

    return criterion(flat_logits, flat_targets)

In [None]:
# Loss estimation function inspired by nanoGPT repo by Andrej Karpathy.
@torch.no_grad()
def estimate_loss(model, criterion, train_iter, val_iter, eval_iters):
    """Average the loss over `eval_iters` batches from each iterator.

    The model is put in eval mode for the measurement and switched to train
    mode afterwards.

    Returns:
        Dict with keys "train" and "val", each holding the mean loss as a
        0-dim tensor.
    """
    model.eval()

    out = {}
    for split, iterator in (("train", train_iter), ("val", val_iter)):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            losses[k] = step(model, criterion, iterator).item()
        out[split] = losses.mean()

    model.train()

    return out

In [None]:
from functools import partial


# Learning rate decay scheduler inspired by nanoGPT repo by Andrej Karpathy.
def get_lr(iter, warmup_iters, base_lr, min_lr, lr_decay_iters):
    """Return the learning rate for iteration `iter`.

    The schedule is a linear warmup from 0 to `base_lr` over `warmup_iters`
    iterations, followed by a cosine decay to `min_lr` that ends at
    `lr_decay_iters`, after which the rate stays at `min_lr`.

    Args:
        iter: Current iteration number.
        warmup_iters: Number of warmup iterations.
        base_lr: Peak learning rate, reached at the end of warmup.
        min_lr: Final learning rate.
        lr_decay_iters: Iteration at which the decay reaches `min_lr`.
    """
    # 1) linear warmup for warmup_iters steps
    if iter < warmup_iters:
        return base_lr * iter / warmup_iters
    # 2) at or past lr_decay_iters, return min learning rate. Using `>=` gives
    #    the same value as the cosine at its endpoint and avoids a division by
    #    zero below when warmup_iters == lr_decay_iters.
    if iter >= lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (iter - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return min_lr + coeff * (base_lr - min_lr)


# Bind the schedule hyperparameters from `cfg` so that callers only pass the
# iteration number.
get_lr_fn = partial(
    get_lr,
    warmup_iters=cfg.warmup_iters,
    base_lr=cfg.lr,
    min_lr=cfg.min_lr,
    lr_decay_iters=cfg.lr_decay_iters,
)

In [None]:
import matplotlib.pyplot as plt

# Evaluate the schedule at every iteration from 0 to 999,999 and plot it.
lrs = [get_lr_fn(step_idx) for step_idx in range(1_000_000)]
plt.plot(lrs)
plt.xlabel("iter")
plt.ylabel("lr")

In [None]:
from torch.optim import Adam

optimizer = Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay, eps=5e-9)
# `tokenize` keeps only full-length chunks, so batches contain no padding and
# every target position must count towards the loss. No token id is ignored
# (id 1 is an ordinary token in the GPT-2 vocabulary, not a pad token).
criterion = nn.CrossEntropyLoss()

To get an estimate for what loss values are reasonable to reach, let's load the original GPT-2 using HuggingFace and evaluate it with a few samples.


In [None]:
# Load the pretrained Hugging Face GPT-2 as a reference model and measure its
# loss on 100 batches from each iterator.
gpt2 = get_model("gpt2-small-hf-pretrained", tokenizer.vocab_size).to(device)
losses = estimate_loss(gpt2, criterion, train_iter, val_iter, 100)
losses

In [None]:
import torch.nn.functional as F


@torch.no_grad()
def generate(model, inp_seq, context_size, max_output_len=100):
    """Sample `max_output_len` tokens from `model`, one at a time.

    At each step the model sees at most the last `context_size` tokens, the
    logits of the final position are turned into probabilities, and one token
    is drawn from them and appended to the sequence.

    If `model` is in training mode it is switched to eval mode for the
    duration of sampling (so that e.g. dropout is inactive) and switched back
    afterwards. Objects without a `training` attribute are called as-is.

    Args:
        model: Callable mapping a (batch, seq) tensor of token ids to logits
            of shape (batch, seq, vocab).
        inp_seq: Prompt token ids of shape (batch, seq).
        context_size: Maximum number of trailing tokens fed to the model.
        max_output_len: Number of tokens to generate.

    Returns:
        `inp_seq` with the generated tokens appended along the last dimension.
    """
    seq = inp_seq

    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()

    try:
        for _ in range(max_output_len):
            out = model(seq[..., -context_size:])  # Truncate input sequence to max length.
            probs = F.softmax(out[:, -1, :], dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1)

            # Append the next tokens to the generated sequences.
            seq = torch.cat((seq, next_tokens), dim=-1)
    finally:
        if was_training:
            model.train()

    return seq

In [None]:
# Fixed prompt used for every sample generation, shaped (1, prompt_len).
fixed_inp = torch.tensor(
    tokenizer.encode("The"), dtype=torch.long, device=device
).unsqueeze(0)

if cfg.print_example:
    batch = next(iter(train_loader))
    # The context length is stored on the model; `cfg` has no `context_size`.
    out = generate(model, fixed_inp, model.context_size)

    print("Example sequence: ", tokenizer.decode(batch["target"][0].numpy())[:200])
    print("Model output: ", tokenizer.decode(out[0].detach().cpu().numpy())[:200])

# Reinitialize data iterators. The validation iterator must be built from
# `val_loader`; `iter(train_iter)` would just return the train iterator.
train_iter = iter(train_loader)
val_iter = iter(val_loader)

iter_num = 0
best_val_loss = float("inf")

model.train()

# Start training.
train_start_print(model)


# Stop after the configured number of iterations.
while iter_num < cfg.max_iters:
    # Get learning rate according to schedule.
    lr = get_lr_fn(iter_num)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    # Train model on one batch.
    train_loss = step(model, criterion, train_iter)
    # Scale the loss so that the accumulated gradient is the mean over the
    # effective batch instead of the sum of `grad_accum_steps` micro-batches.
    (train_loss / cfg.grad_accum_steps).backward()

    # Accumulate gradients for N steps and update weights.
    if (iter_num + 1) % cfg.grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    if iter_num > 0 and iter_num % cfg.eval_interval == 0:
        losses = estimate_loss(model, criterion, train_iter, val_iter, cfg.eval_iters)
        evaluation_print(losses)

        # Generate sample and print.
        out = generate(model, fixed_inp, model.context_size)
        print("Model output: ", tokenizer.decode(out[0].detach().cpu().numpy())[:200])

        # Save model checkpoint if new best validation loss.
        if losses["val"] < best_val_loss:
            # Remember the new best so later, worse evaluations do not
            # overwrite the checkpoint.
            best_val_loss = float(losses["val"])
            torch.save(
                {
                    "iter": iter_num,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                "best.pth",
            )

        # Log to WandB.
        wandb.log(
            {
                "train/loss": losses["train"],
                "val/loss": losses["val"],
                "lr": lr,
            },
            step=iter_num,
        )

    # Report the unscaled loss of this micro-batch.
    iter_print(iter_num, train_loss)
    iter_num += 1