# Assignment 4: Training Transformers in PyTorch

*Author:* Thomas Adler

*Copyright statement:* This  material,  no  matter  whether  in  printed  or  electronic  form,  may  be  used  for  personal  and non-commercial educational use only.  Any reproduction of this manuscript, no matter whether as a whole or in parts, no matter whether in printed or in electronic form, requires explicit prior acceptance of the authors.

In this assignment we will implement and train a small transformer model and compare it to the LSTM in the previous assignment.

## Exercise 1: Causal Self-Attention

Write a class named `CausalSelfAttention` that derives from `nn.Module` and whose `__init__` method takes (apart from the trivial `self`) one argument `hidden_size`. Implement a method `forward` that takes an input sequence `x` of shape $(N, T, D)$ (where $N$ is batch size, $T$ is sequence length, $D$ is hidden size) and performs scaled dot-product self-attention, i.e.,
$$
Y = \operatorname{softmax}\left(\frac{1}{\sqrt{D}} Q K^\top\right) V,
$$
where $Q = X W_Q$ and $K = X W_K$ and $V = X W_V$ and $X \in \mathbb{R}^{T \times D}$ and $W_Q, W_K, W_V \in \mathbb{R}^{D \times D}$ and softmax is applied in a row-wise manner and neglecting bias units.
It is called self-attention because $Q, K, V$ are all computed from the same input $X$, which hence attends to itself.

To have the attention be *causal* we need to make sure that we do not allow peeks into the future. That is, the output at time $t$ must be a function of the input at times $1, \dots, t$ but no further. The score matrix $E = \frac{1}{\sqrt{D}} Q K^\top$ has a shape of $T \times T$ and the entry $e_{ij}$ measures how strong the query at time $i$ attends to the key at time $j$. Therefore, positions where $j > i$ constitute peeks into the future and we have to set the corresponding attention values (i.e., the softmax-activated score) to zero. We can do that by setting the corresponding score to `float('-inf')`, which has the advantage that the normalization is adjusted automatically by the softmax.

In [66]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time
import sklearn
import matplotlib


class CausalSelfAttention(nn.Module):
    def __init__(self, hidden_size):
        super(CausalSelfAttention, self).__init__()
        self.hidden_size = hidden_size

    def forward(self, x):
        # x has shape (N, T, D)
        Q = torch.matmul(x, WQ)
        K = torch.matmul(x, WK)
        V = torch.matmul(x, WV)
        # Scaled dot-product self-attention
        Y = torch.softmax(torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.hidden_size), dim=-1)
        # Compute the output
        Z = torch.matmul(Y, V)
        return Z



## Exercise 2: Multi-Head Attention

Write a class `MultiHeadCausalSelfAttention` that derives from `nn.Module` and extends the functionality of `CausalSelfAttention` from the previous exercise.
The `__init__` method takes arguments `hidden_size, n_head, dropout`. `n_head` specifies the number of attention heads and `dropout` specifies the intensity for the dropout layers.
The `forward` method should split the hidden dimension of the pre-activations (i.e., $Q, K, V$) in `n_head` equally sized parts and perform attention to these parts in parallel.
Apply the first dropout layer direcly after the softmax.
After the multiplication of the scores with the values, recombine the output of the distinct attention heads back into a single hidden dimension of size $D$, i.e., the resulting shape should be the shape of the input.
Then perform an additional output projection again resulting in a hidden dimension of $D$.
Finally, apply the second dropout layer after the output projection.

In [2]:
class MultiHeadCausalSelfAttention(nn.Module):
    def __init__(self, hidden_size, n_head, dropout):
        super(MultiHeadCausalSelfAttention, self).__init__()
        self.hidden_size = hidden_size
        self.n_head = n_head
        self.dropout = nn.Dropout(dropout)
        self.WQ = nn.Linear(hidden_size, hidden_size * n_head)
        self.WK = nn.Linear(hidden_size, hidden_size * n_head)
        self.WV = nn.Linear(hidden_size, hidden_size * n_head)
        self.WO = nn.Linear(hidden_size * n_head, hidden_size)

    def forward(self, x):
        # x has shape (N, T, D)
        N, T, D = x.shape
        Q = self.WQ(x).view(N, T, self.n_head, D).transpose(1, 2)  # (N, n_head, T, D)
        K = self.WK(x).view(N, T, self.n_head, D).transpose(1, 2)  # (N, n_head, T, D)
        V = self.WV(x).view(N, T, self.n_head, D).transpose(1, 2)  # (N, n_head, T, D)
        # Scaled dot-product self-attention
        Y = torch.softmax(torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(self.hidden_size), dim=-1)
        Y = self.dropout(Y)
        # Compute the output
        Z = torch.matmul(Y, V)
        Z = Z.transpose(1, 2).contiguous().view(N, T, self.hidden_size * self.n_head)
        Z = self.WO(Z)
        Z = self.dropout(Z)
        return Z


## Exercise 3: Multi-Layer Perceptron

Write a class `MLP` that derives from `nn.Module` and whose `__init__` method takes two arguments: `hidden_size` and `dropout`.
It should implement a 2-layer feedforward network with `hidden_size` inputs, `4*hidden_size` hiddens, and `hidden_size` outputs.
It should apply the GELU activation function to the hiddens and dropout to the outputs.

In [3]:
class MLP(nn.Module):
    def __init__(self, hidden_size, dropout):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(hidden_size, 4 * hidden_size)
        self.fc2 = nn.Linear(4 * hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


## Exercise 4: Block

Write a class `Block` that derives from `nn.Module` and whose `__init__` method takes arguments `hidden_size, n_head, dropout`.
It should apply `nn.LayerNorm`, `CausalMultiHeadSelfAttention`, `nn.LayerNorm`, `MLP` in that order and feature residual connections from the input to the output of `CausalMultiHeadSelfAttention` and from there to the output of `MLP`.

In [4]:
class Block(nn.Module):
    def __init__(self, hidden_size, n_head, dropout):
        super(Block, self).__init__()
        self.norm1 = nn.LayerNorm(hidden_size)
        self.attention = MultiHeadCausalSelfAttention(hidden_size, n_head, dropout)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.mlp = MLP(hidden_size, dropout)

    def forward(self, x):
        # Feature residual connection from input to output of attention
        y = self.attention(self.norm1(x)) + x
        # Feature residual connection from input to output of MLP
        z = self.mlp(self.norm2(y)) + y
        return z



## Exercise 5: GPT

Write a class `GPT` that derives from `nn.Module` and whose `__init__` method takes arguments `vocab_size, context_size, hidden_size, n_layer, n_head, dropout`.
The `forward` method should take two arguments `x, y` representing sequences of input and target tokens, respectively, both of which have type `torch.long` and shape ($N$, $T$), and returns logits and loss as a tuple.
The `GPT` module should feature two `nn.Embedding` layers, one for token embeddings and one for positional embedding, i.e., it should embed the position of the corresponding token within the input sequence.
The positional embedding is necessary for the Transformer to determine the order of its inputs.
Add the two embeddings and apply a dropout layer.
Next, apply `n_layers` layers of `Block`s followed by a `nn.LayerNorm` and a `nn.Linear` (without bias) mapping to an output dimension of `vocab_size`.
Finally, apply the cross-entropy loss function to the logits.
To save some parameters, apply weight tying between the token embedding layer and the output layer, i.e., they should use the same weights.
Initialize all weights using a normal distribution with a mean of zero and a standard deviation of 0.02 (except for the output layers of the `MLP`s use $0.02/\sqrt{2 * \mathtt{n\_layer}}$) and all biases to zero.
Use the argument `dropout` as intensity for all dropout layers in the network.

In [62]:
class GPT(nn.Module):
    def __init__(self, vocab_size, context_size, hidden_size, n_layer, n_head, dropout):
        super(GPT, self).__init__()
        self.vocab_size = vocab_size
        self.context_size = context_size
        self.hidden_size = hidden_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.dropout = nn.Dropout(dropout)
        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.positional_embedding = nn.Embedding(context_size, hidden_size)
        self.blocks = nn.ModuleList([Block(hidden_size, n_head, dropout) for _ in range(n_layer)])
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size)

        self._init_weights()

    def forward(self, x, y):
        # Token and positional embeddings
        token_embedded = self.token_embedding(x)
        position_embedded = self.positional_embedding(torch.arange(x.shape[1], device=x.device))
        x = self.dropout(token_embedded + position_embedded)

        # Blocks
        for block in self.blocks:
            x = block(x)

        # Layer norm and output linear layer
        x = self.layer_norm(x)
        logits = self.output(x)

        # Compute the loss
        loss = nn.CrossEntropyLoss()(logits.view(-1, self.vocab_size), y.view(-1))
        return logits, loss

    def _init_weights(self):
        # Token embedding and output layers weight tying
        self.output.weight = self.token_embedding.weight
        self.output.bias.data.zero_()

        # Initialize other weights and biases
        for module in self.modules():
            if isinstance(module, (nn.Linear, nn.Embedding)):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if isinstance(module, nn.Linear) and module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)


## Exercise 6: Optimizer

Add a method `configure_optimizers` to the class `GPT` that takes arguments `weight_decay, learning_rate, betas`.
Divide the model parameters into two groups.
The first group consists of all parameters with at least 2 dimensions, e.g., weight/embedding matrices and uses a decay of `weight_decay`.
The second group consists of all other parameters, e.g., biases and layer norms, and does not use weight decay.
Construct and return a `torch.optim.AdamW` optimizer with `learning_rate` and `betas` that operates on these two parameter groups.

In [58]:
import torch.optim as optim

class GPT(nn.Module):

    def configure_optimizers(self):
        weight_decay = 0.01
        learning_rate = 0.001
        betas = (0.9, 0.999)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
            {'params': [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = optim.AdamW(optimizer_grouped_parameters, lr=learning_rate, betas=betas)
        return optimizer


## Exercise 7: Training

In the code cell below you find some globals, helper functions, and boilerplate code. Extend the given code by a training loop that
* stops after `max_iters` iterations
* applies the learning rate schedule implemented in `get_lr`
* applies gradient clipping at `grad_clip` using `torch.nn.utils.clip_grad_norm_`
* accumulates gradients for `gradient_accumulation_steps` batches before each weight update
* logs the training loss and learning rate every `log_interval` iterations
* evaluates (and potentially checkpoints) the model using `estimate_loss` every `eval_iters` iterations.

The provided hyperparameter values should be a good guess for training a tiny model on CPU but feel free to experiment with them as you please. In particular, if you have a GPU available, you can try to scale things up a bit.

In [71]:
eval_interval = 250 # validate model every .. iterations
log_interval = 10 # log training loss every .. iterations
eval_iters = 20 # number of batches for loss estimation
gradient_accumulation_steps = 5 * 8 # used to simulate larger training batch sizes
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
context_size = 64 # sequence length
vocab = 'abcdefghijklmnopqrstuvwxyz0123456789 .!?' # vocabulary
vocab_size = len(vocab) # 40
n_layer = 4 # number of layers
n_head = 4 # number of attention heads
hidden_size = 128 # layer size
dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
learning_rate = 1e-3 # max learning rate
max_iters = 2000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9 # for AdamW
beta2 = 0.99 # for AdamW
grad_clip = 1.0 # clip gradients at this value, or disable with 0.0
warmup_iters = 100 # how many steps to warm up for
min_lr = 1e-4 # minimum learning rate, usually ~= learning_rate/10

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > max_iters, return min learning rate
    if it > max_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

def load_data(split):
    import re

    with open(f'trump_{split}.txt', 'r') as f:
        text = f.read()

    text = text.lower() # convert to lower case
    text = re.sub('[^a-z0-9 .!?]', ' ', text) # replace all unknown chars with ' '
    text = re.sub(' +', ' ', text) # reduce multiple blanks to one
    text = [vocab.index(t) for t in text]
    text = torch.tensor(text, dtype=torch.long)
    return text

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - context_size, (batch_size,))
    x = torch.stack([data[i:i+context_size] for i in ix])
    y = torch.stack([data[i+1:i+1+context_size] for i in ix])
    return x, y

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

# data, model, optimizer, etc.
train_data = load_data('train')
val_data = load_data('val')
train_data = load_data('train')
val_data = load_data('val')
model = GPT(vocab_size, context_size, hidden_size, n_layer, n_head, dropout)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)

# Initialize other necessary variables
iter_num = 0
best_val_loss = 1e9
X, Y = get_batch('train') # Fetch the very first batch
t0 = time.time()


# Training loop
for iteration in range(max_iters):
    # Forward pass and loss calculation
    logits, loss = model(X, Y)
    loss = loss / gradient_accumulation_steps  # Normalize the loss for gradient accumulation
    loss.backward()

    # Accumulate gradients for a certain number of iterations before the weight update
    if (iteration + 1) % gradient_accumulation_steps == 0:
        # Learning rate scheduling
        lr = get_lr(iteration)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Gradient clipping
        if grad_clip > 0.0:
             torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        # Weight update
        optimizer.step()
        optimizer.zero_grad()

    # Log the training loss and learning rate every log_interval iterations
    if iteration % log_interval == 0:
        print(f"Iteration {iteration}, Loss: {loss.item()}, Learning Rate: {get_lr(iteration)}")

    # Validate the model and potentially checkpoint every eval_iters iterations
    if iteration % eval_interval == 0:
        val_loss = estimate_loss()['val']
        print(f"Iteration {iteration}, Validation Loss: {val_loss}")

    # Stop execution after a few attempts (for demonstration purposes)
    if iteration >= 50:

        break




Iteration 0, Loss: 0.09268394112586975, Learning Rate: 0.0
Iteration 0, Validation Loss: 3.7201905250549316
Iteration 10, Loss: 0.09268394112586975, Learning Rate: 0.0001
Iteration 20, Loss: 0.09268394112586975, Learning Rate: 0.0002
Iteration 30, Loss: 0.09268394112586975, Learning Rate: 0.0003
Iteration 40, Loss: 0.09136929363012314, Learning Rate: 0.0004
Iteration 50, Loss: 0.09136929363012314, Learning Rate: 0.0005


## Exercise 8: Inference

Add a method `generate` to the class `GPT` that takes arguments `x, max_new_tokens, temperature=1.0`.
The method should take a batch of token sequences `x`, which it should extend by `max_new_tokens` new tokens generated by the model.
Once you have computed the logits for the next token, divide them by `temperature` before applying the softmax.
After applying the softmax, sample the next token from the resulting categorical distribution.
Try out different values for `temperature` and compare the results to those from the previous assignment.

In [72]:
class GPT(nn.Module):

    def generate(self, x, max_new_tokens, temperature=1.0):
        self.eval()  # Put the model in evaluation mode

        generated_tokens = []  # To store generated tokens

        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits = self.forward(x)[:, -1, :] / temperature  # Compute logits for the next token

                # Apply softmax to logits and sample the next token
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                # Extend the input sequence with the sampled token
                x = torch.cat([x, next_token], dim=-1)
                generated_tokens.append(next_token.item())  # Store the generated token

                print('Generated Token:', next_token.item())  # Display the generated token

        return x, generated_tokens

