# Generative Pre-trained Transformer (GPT)

We are going to build a Generative Pre-trained Transformer (GPT) and train it on the [Shakespeare text](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset, which contains around 1M characters. Once the model is trained, it will generate Shakespeare-like text.

## 7.1 Transformers

The **transformer** model was introduced in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762) (2017) by Vaswani et al. It is a neural network architecture based on the self-attention mechanism, which allows the model to weigh the importance of different words in a sentence when processing text.

The transformer has two parts:
- The **encoder** reads and processes input data (like a sentence).
- The **decoder** generates an output based on that processed input.


<br>

A **Generative Pre-trained Transformer (GPT)** is a specific type of transformer model that uses **only the decoder** part of the transformer architecture. GPT models are autoregressive, meaning they generate text one token at a time, using the previous tokens to predict the next one.



## 7.2 Load dataset

In [1]:
# download dataset
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt;

In [2]:
# load dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print(f'Length of the dataset: {len(text)} characters')

Length of the dataset: 1115394 characters


In [3]:
# print first 200 characters
print(text[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


## 7.3 Vocabulary size

The vocabulary size is the **number of unique characters**.

<br>

**Note:**
- chars[0] is the new line character, '\n'.
- chars[1] is the space character, ' '.

In [4]:
# unique characters and vocabulary size
chars = sorted(list(set(text)))
vocab_size = len(chars)

print('Unique characters:', ''.join(chars))
print('Vocabulary size:', vocab_size)

Unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocabulary size: 65


## 7.4 Tokenizer

To tokenize means to **convert the raw text**, given as a string, **into a sequence of integers** according to some vocabulary of possible elements.

We are building a character-level language model, so our tokenizer simply translates individual characters into integers using a **lookup table**.

In [5]:
# mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }

# encoder: take a string, output a list of integers
encode = lambda s: [stoi[c] for c in s]

# decoder: take a list of integers, output a string
decode = lambda l: ''.join([itos[i] for i in l])

print('Encoder output:', encode('hii there'))
print('Decoder output:', decode([46, 47, 47, 1, 58, 46, 43, 56, 43]))

Encoder output: [46, 47, 47, 1, 58, 46, 43, 56, 43]
Decoder output: hii there
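
As a quick sanity check of the lookup-table idea, the following minimal sketch rebuilds the two mappings on a tiny made-up string and confirms that decoding an encoded string returns the original. The names `toy_text`, `toy_encode`, and `toy_decode` are assumptions for illustration and are not part of the notebook. The sketch also shows that a character outside the vocabulary cannot be encoded, because a plain dictionary lookup raises `KeyError`.

```python
# toy corpus (hypothetical, only for illustration)
toy_text = "hello there"

# vocabulary: sorted unique characters
toy_chars = sorted(set(toy_text))

# lookup tables in both directions
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}
toy_itos = {i: ch for i, ch in enumerate(toy_chars)}

def toy_encode(s):
    return [toy_stoi[c] for c in s]

def toy_decode(ids):
    return ''.join(toy_itos[i] for i in ids)

ids = toy_encode("hello")
roundtrip = toy_decode(ids)
print(ids, roundtrip)

# a character that never appeared in the corpus has no entry in the table
try:
    toy_encode("z")
    unknown_ok = True
except KeyError:
    unknown_ok = False
print('Can encode unseen character:', unknown_ok)
```

This is not an issue for the Shakespeare model as long as every prompt only uses characters that occur in the training text.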


Large language models (LLMs) also encode text into integers, but with a different scheme and a different vocabulary. Their tokenizers split the text into subword units, known as **tokens**, rather than individual characters or entire words. Because the vocabulary is much larger, the same text is encoded into a shorter sequence of integers. For example, Google uses the [SentencePiece](https://github.com/google/sentencepiece?tab=readme-ov-file) tokenizer and OpenAI uses the [tiktoken](https://github.com/openai/tiktoken) tokenizer.

In [6]:
# pip install tiktoken;

In [7]:
import tiktoken

# load gpt2 encoder
enc = tiktoken.get_encoding('gpt2')
print('Vocabulary size:', enc.n_vocab)

print('Encoder output:', enc.encode('hii there'))
print('Decoder output:', enc.decode([71, 4178, 612]))

Vocabulary size: 50257
Encoder output: [71, 4178, 612]
Decoder output: hii there


## 7.5 Build dataset

In [8]:
import torch
torch.manual_seed(1337)

# encode the entire text dataset and store it into a PyTorch tensor
data = torch.tensor(encode(text), dtype=torch.long)

# split the data into train and validation sets
n = int(0.9*len(data))
train_data = data[:n] # 90%
val_data = data[n:]   # 10%

In [9]:
# the earlier 200 characters would look like this to the GPT
data[:200]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59])

## 7.6 Chunk, block size, and context

We cannot feed the entire training set into a transformer all at once because that would be computationally very expensive. Therefore, when training a transformer, we sample random smaller **chunks** from the training set and train on just one chunk at a time. The length of these chunks is referred to as the **block size**, which represents the number of tokens (in this case, characters) that the transformer can process at once.

In a chunk of nine tokens, there are actually eight training examples packed into it (see below). This is because, for each token, the transformer learns how to predict the next token based on its **context**, which consists of the preceding tokens in the sequence. Please note that, as we will see later, thanks to the self-attention mechanism, these examples are processed simultaneously, allowing the transformer to efficiently learn from the relationships between all the tokens in the sequence.

Since the transformer is trained with contexts of varying lengths (from 1 up to the block size), during inference we can start generating text with just one token. The transformer will know how to predict the next tokens as the sequence grows, up to the block size. Once the sequence reaches the block size, we start dropping the oldest tokens from the context to keep its length at the block size.

In [10]:
block_size = 8

print('Chunk:', train_data[:block_size+1].tolist())

x = train_data[:block_size]
y = train_data[1:block_size+1]

print('\nExamples:')
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f'{t+1}) When context is {context.tolist()} the target is {target}')

Chunk: [18, 47, 56, 57, 58, 1, 15, 47, 58]

Examples:
1) When context is [18] the target is 47
2) When context is [18, 47] the target is 56
3) When context is [18, 47, 56] the target is 57
4) When context is [18, 47, 56, 57] the target is 58
5) When context is [18, 47, 56, 57, 58] the target is 1
6) When context is [18, 47, 56, 57, 58, 1] the target is 15
7) When context is [18, 47, 56, 57, 58, 1, 15] the target is 47
8) When context is [18, 47, 56, 57, 58, 1, 15, 47] the target is 58
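
The truncation that happens during inference can be pictured with plain Python lists. This is a minimal sketch, assuming a hypothetical helper `crop_context` and a made-up stream of token ids. It only shows which tokens the model would see once the sequence is longer than the block size.

```python
block_size = 8

def crop_context(tokens, block_size):
    """Keep only the most recent block_size tokens."""
    return tokens[-block_size:]

# pretend these token ids were generated one by one (made-up values)
sequence = [18]
for new_token in [47, 56, 57, 58, 1, 15, 47, 58, 47]:
    sequence.append(new_token)

context = crop_context(sequence, block_size)
print('Full sequence:', sequence)
print('Model context:', context)
```

While the sequence is shorter than the block size the slice returns it unchanged, so the same helper covers both the growing phase and the sliding-window phase.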


## 7.8 CUDA

CUDA exploits the advantages of GPUs over CPUs by utilizing the **parallelism** offered by GPUs' multiple cores. Unlike CPUs, which are optimized for sequential processing, GPUs have thousands of cores that can launch a large number of simultaneous threads, allowing for highly parallel execution of tasks.

We are going to add the capability to run computations on a GPU, if available, to significantly speed up operations. This is particularly beneficial for tasks that can be parallelized, such as data processing and matrix operations.

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print('Device:', device)

Device: cpu
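
In practice, using the device mostly means making sure that tensors (and later the model) live on it. The sketch below is an illustration with made-up tensors `a`, `b`, and `c`. It runs on the CPU as well, so it does not require a GPU.

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# tensors can be created directly on the device...
a = torch.ones(2, 3, device=device)

# ...or placed there afterwards with .to()
b = torch.ones(3, 2).to(device)

# operands of an operation must live on the same device
c = a @ b # (2, 2)
print(c.device.type, tuple(c.shape))
```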


## 7.9 Batch dimension

Because GPUs excel at parallel processing of data, we can stack chunks in a single tensor, known as a **batch**, that feeds into the transformer. Thus, multiple chunks can be processed simultaneously and completely independently.

In [12]:
batch_size = 4 # number of chunks per batch
block_size = 8 # maximum context length of each chunk

# generate a small batch of chunks of inputs x and targets y
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

xb, yb = get_batch('train')

print('Inputs shape:', tuple(xb.shape))
print(xb)
print('\nTargets shape:', tuple(yb.shape))
print(yb)

Inputs shape: (4, 8)
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])

Targets shape: (4, 8)
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
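
Each row of the targets is the corresponding row of the inputs shifted by one position. The sketch below checks this on a made-up dataset `toy_data`, where the token at position i simply has id i, using the same sampling steps as `get_batch`. The seed and the toy values are assumptions for illustration.

```python
import torch

torch.manual_seed(0)

toy_data = torch.arange(100) # made-up "dataset": token i has id i
batch_size, block_size = 4, 8

# random starting offsets, one per chunk
ix = torch.randint(len(toy_data) - block_size, (batch_size,))
xb = torch.stack([toy_data[i:i+block_size] for i in ix])     # (B, T)
yb = torch.stack([toy_data[i+1:i+block_size+1] for i in ix]) # (B, T)

# the targets are the inputs shifted one position to the left
print(torch.equal(xb[:, 1:], yb[:, :-1]))
```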


## 7.10 Calculate loss function

In [13]:
from torch.nn import functional as F

def calculate_loss(logits, targets):
    "Calculate the cross-entropy loss if targets are available."
    
    if targets is None:
        return None
    else:
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return loss
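
A useful reference point for this loss is the value obtained when the model has no preference at all. The sketch below, which uses made-up all-zero logits and random targets, suggests that a uniform prediction over 65 characters gives a cross-entropy of ln(65), roughly 4.17. The shapes mirror the ones used in `calculate_loss`.

```python
import math
import torch
from torch.nn import functional as F

B, T, C = 4, 8, 65 # C plays the role of vocab_size

# all-zero logits correspond to a uniform distribution over the C tokens
logits = torch.zeros(B, T, C)
targets = torch.randint(C, (B, T))

# the batch and time dimensions are flattened together, so that
# F.cross_entropy receives (N, C) logits and (N,) targets
loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

print(f'{loss.item():.4f} vs ln(65) = {math.log(C):.4f}')
```

The first loss reported in the training log further below (about 4.64) is somewhat higher than this value, plausibly because a randomly initialized embedding table is not exactly uniform.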

## 7.11 Generate tokens function

In [15]:
def generate_tokens(model, idx, max_new_tokens, block_size):
    """Generate tokens by predicting one token at a time."""
    
    # idx is the current context
    # in each iteration idx will grow:
    # (B, T), (B, T+1), (B, T+2), ..., (B, T+max_new_tokens)

    for _ in range(max_new_tokens):

        # crop the context to the last block_size tokens
        idx_cond = idx[:, -block_size:]

        # get logits for current context (calling forward method)
        logits, _ = model(idx_cond) # (B, T, C)
        
        # focus only on the last time-step because
        # those are the predictions for what comes next
        logits = logits[:, -1, :] # (B, C)
        
        # apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1) # (B, C)
        
        # sample the next token from the distribution
        idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

        # append the sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    
    return idx
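
To see how the context grows without training anything, the loop can be run against a stand-in model. The sketch below is only an illustration: `UniformModel` is a hypothetical module that returns the same logit for every token, and `sample_tokens` repeats the steps of `generate_tokens` so that the block is self-contained.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

class UniformModel(nn.Module):
    """Hypothetical stand-in model: returns the same logit for every token."""

    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size

    def forward(self, idx, targets=None):
        B, T = idx.shape
        logits = torch.zeros(B, T, self.vocab_size) # (B, T, C)
        return logits, None

def sample_tokens(model, idx, max_new_tokens, block_size):
    # same steps as generate_tokens, repeated to keep the sketch self-contained
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]                    # crop the context
        logits, _ = model(idx_cond)                        # (B, T, C)
        probs = F.softmax(logits[:, -1, :], dim=-1)        # (B, C)
        idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
        idx = torch.cat((idx, idx_next), dim=1)            # (B, T+1)
    return idx

torch.manual_seed(0)
start = torch.zeros((2, 1), dtype=torch.long) # (B=2, T=1)
out = sample_tokens(UniformModel(vocab_size=65), start, max_new_tokens=10, block_size=8)
print(tuple(out.shape))
```

The returned tensor keeps the starting token and appends one sampled token per iteration, so its time dimension is the initial length plus `max_new_tokens`, even though the model itself never sees more than `block_size` tokens.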

## 7.12 Bigram model

Although a bigram model is a very simple model, it is a good starting point to begin building the GPT architecture.

<br>

**Note:**

[**nn.Module**](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) is a base class in PyTorch that provides a way to define complex models by encapsulating parameters and methods that are used during training and evaluation. Some of its key features and functionalities are:
- Parameter Management: It allows you to define parameters (weights and biases) that can be automatically registered and tracked.
- Forward Method: This method defines how the input data flows through the model.
- Backward Pass: Because the layers are built from differentiable tensor operations, PyTorch's autograd can backpropagate through them automatically, allowing you to compute gradients easily.
- Model Evaluation: It provides methods like train() and eval() to set the mode of the model.
- Built-in Layers: PyTorch provides a wide range of pre-defined layers (like nn.Linear) that inherit from nn.Module.

**Note:**

*super().\_\_init\_\_()* calls the constructor of the parent class (in this case, nn.Module).

**Note:** 

[**nn.Embedding**](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) is a class in PyTorch that creates a simple lookup table that stores embeddings of a fixed dictionary and size. The primary purpose of embedding layers is to map sequences of token indices into dense vector representations, known as embeddings. In this case, however, since we are implementing a bigram model, we are going to use an embedding table so that each token directly reads off the logits for the next token from its own row.

**Note:**
- B = Batch dimension (batch_size)
- T = Time dimension (block_size)
- C = Channels (vocab_size)
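
The lookup behaviour and the (B, T, C) shape can be checked on a small example. This is a minimal sketch with a freshly initialized table and made-up token ids. The names `table` and `idx` are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size = 65

table = nn.Embedding(vocab_size, vocab_size)

idx = torch.tensor([[18, 47, 56],
                    [ 1,  0, 47]]) # (B=2, T=3)

logits = table(idx) # (B, T, C) = (2, 3, 65)
print(tuple(logits.shape))

# each token id simply selects one row of the weight matrix
print(torch.equal(logits[0, 1], table.weight[47]))
```

Token 47 appears in both rows of `idx`, and both occurrences receive exactly the same logits, regardless of the tokens around them.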

In [16]:
import torch.nn as nn

class BigramModel(nn.Module):
    
    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):

        # get logits from lookup table
        logits = self.token_embedding_table(idx) # (B, T, C)
        
        # calculate loss
        loss = calculate_loss(logits, targets)
        
        return logits, loss
    

    def generate(self, idx, max_new_tokens):
        
        return generate_tokens(self, idx, max_new_tokens, block_size)

In [17]:
model = BigramModel()
m = model.to(device)

## Evaluate the model

Instead of printing the batch loss in every iteration, which is noisy, the *estimate_loss()* function averages the **loss over multiple batches**.

In [18]:
eval_iters = 200      # number of batches used to estimate each loss
eval_interval = 10000 # estimate the loss every eval_interval iterations

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval() # put the model in evaluation mode

    # calculate train loss and validation loss
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)

        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()

        out[split] = losses.mean()

    model.train() # put the model back in train mode
    return out

## Train the model

In [19]:
batch_size = 32 # number of chunks per batch
learning_rate = 1e-3
max_iters = 40000

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"Step:{iter:6d} /{max_iters:6d}   Train loss: {losses['train']:.4f}   Val loss: {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # forward pass
    logits, loss = model(xb, yb)

    # backward pass
    optimizer.zero_grad(set_to_none=True)
    loss.backward()

    # update
    optimizer.step()

Step:     0 / 40000   Train loss: 4.6355   Val loss: 4.6491
Step: 10000 / 40000   Train loss: 2.4726   Val loss: 2.4923
Step: 20000 / 40000   Train loss: 2.4551   Val loss: 2.4864
Step: 30000 / 40000   Train loss: 2.4519   Val loss: 2.4940
Step: 39999 / 40000   Train loss: 2.4518   Val loss: 2.4949


## Generate from the model

We are going to start the inference with the tensor [[0]]. The reason for this is that index 0 corresponds to the new line character, '\n'. The generate method will produce additional characters up to max_new_tokens.

Please note that we are currently feeding the growing context (cropped to the last block_size tokens) into the model. However, because it is a bigram model, only the last character is used to predict the next one, which explains the poor results.

In [46]:
context = torch.zeros((1, 1), dtype=torch.long, device=device) # tensor [[0]]
print(decode(m.generate(context, max_new_tokens=300)[0].tolist()))


zdt:
3dwQgjtYc-P'Tgxz'Cwfdk-u&$hOHyrZBIBWiOO;l-whmQukjPoJiz;wCuGCKo'KBaxN SkYslyuIBRN mgbzPhOJ&TpV&THiFtMF.K&:LS'NByti-Ou;;OWOqXfwPWHjPBXh.?OoVKvlZl&dC i
'zjGAtwBrxTGxJ,Ce'KVm
B!JSkxUlZMiM'-h Q':HVR JS;&UorEu'Puew'whq3HKkWjhg3epNfytAPl
yGiK Zk?X.lkmHhZdYxL$!pbI$'Jx:v
,Pw3Fq
Un::ZesgODtMQz$f-gHcwVD3p
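
The claim that only the last character matters can be verified directly. In the sketch below an untrained embedding table stands in for the bigram model, and the two contexts `ctx_a` and `ctx_b` are made-up values that differ everywhere except in their last token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size = 65
table = nn.Embedding(vocab_size, vocab_size) # stands in for the bigram model

# two made-up contexts that differ everywhere except in the last token
ctx_a = torch.tensor([[18, 47, 56, 57]])
ctx_b = torch.tensor([[ 1,  0, 43, 57]])

# logits used to sample the next token (last time-step only)
next_a = table(ctx_a)[:, -1, :] # (1, C)
next_b = table(ctx_b)[:, -1, :] # (1, C)

print(torch.equal(next_a, next_b))
```

Both contexts lead to identical next-token logits, so any information carried by the earlier tokens is ignored.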


## Introduction to attention

Attention is a **communication mechanism** that can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.

In our bigram model, the tokens in each chunk are currently not interacting with each other, so we would like to implement a mechanism to allow them to **communicate**. In particular, we want to ensure that each token only interacts with the tokens before it in the sequence.

<div style="width: 570px; margin: 0 auto;">
    <img src="https://raw.githubusercontent.com/danielsimon4/language-modeling/refs/heads/main/Images/attention-graph.png">
</div>

<br>

The simplest way to achieve this is by computing the **average** of the preceding tokens. For instance, the fourth token should aggregate its channels with those of the third, second, and first tokens, averaging their values.


Consider the following chunk where every row represents a token and the columns represent the channels:

$$
\begin{bmatrix}
x_{11} & x_{12} & x_{13} & \dots & x_{1C} \\
x_{21} & x_{22} & x_{23} & \dots & x_{2C} \\
x_{31} & x_{32} & x_{33} & \dots & x_{3C} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
x_{T1} & x_{T2} & x_{T3} & \dots & x_{TC}
\end{bmatrix}
$$

<br>

Performing this algorithm, the tokens would interact as follows:

$$
\begin{bmatrix}
x_{11} & x_{12} & x_{13} & \dots & x_{1C} \\
\frac{x_{11} + x_{21}}{2} & \frac{x_{12} + x_{22}}{2} & \frac{x_{13} + x_{23}}{2} & \dots & \frac{x_{1C} + x_{2C}}{2} \\
\frac{x_{11} + x_{21} + x_{31}}{3} & \frac{x_{12} + x_{22} + x_{32}}{3} & \frac{x_{13} + x_{23} + x_{33}}{3} & \dots & \frac{x_{1C} + x_{2C} + x_{3C}}{3} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
\frac{x_{11} + x_{21} + x_{31} + ... + x_{T1}}{T} & \frac{x_{12} + x_{22} + x_{32} + ... + x_{T2}}{T} & \frac{x_{13} + x_{23} + x_{33} + ... + x_{T3}}{T} & \dots & \frac{x_{1C} + x_{2C} + x_{3C} + ... + x_{TC}}{T}
\end{bmatrix}
$$

<br>

We can implement this averaging process with **two nested for loops**, where the inner loop iterates over each token in the chunk and computes the average of the channels for all the preceding tokens and the outer loop iterates over each chunk in the batch.

In [21]:
B, T, C = 4, 8, 2 # batch, time, channels

# random input tensor
x = torch.randn(B,T,C) # (B, T, C)

# matrix of all zeros
xbow = torch.zeros((B,T,C)) # (B, T, C)

for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t+1, C)
        xbow[b,t] = torch.mean(xprev, 0) # (C,)

# print only the first chunk of the batch
print(x[0])

# print how the tokens of that chunk would interact
print(xbow[0])

tensor([[-0.8385,  0.7409],
        [ 0.5635, -0.6097],
        [ 1.8720,  0.0590],
        [ 0.1869, -0.2114],
        [ 1.5499,  0.2369],
        [-2.3956, -0.3363],
        [ 2.2205, -0.1176],
        [ 0.0070, -1.0434]])
tensor([[-0.8385,  0.7409],
        [-0.1375,  0.0656],
        [ 0.5323,  0.0634],
        [ 0.4460, -0.0053],
        [ 0.6668,  0.0431],
        [ 0.1564, -0.0201],
        [ 0.4512, -0.0340],
        [ 0.3957, -0.1602]])


## Matrix multiplication efficiency

We can improve efficiency by replacing the inner for loop with **matrix multiplication** and a **lower triangular matrix** like the one below. In addition, PyTorch can perform multiple matrix multiplications as a single batched operation, even when using only a CPU. This allows us to process several chunks more efficiently than with the outer loop.

$$
\begin{bmatrix}
1 & 0 & 0 & 0 & 0 \\
\frac{1}{2} & \frac{1}{2} & 0 & 0 & 0 \\
\frac{1}{3} & \frac{1}{3} & \frac{1}{3} & 0 & 0 \\
\frac{1}{4} & \frac{1}{4} & \frac{1}{4} & \frac{1}{4} & 0 \\
\frac{1}{5} & \frac{1}{5} & \frac{1}{5} & \frac{1}{5} & \frac{1}{5}
\end{bmatrix}
\times
\begin{bmatrix}
x_{11} & x_{12} & x_{13} & \dots & x_{1C} \\
x_{21} & x_{22} & x_{23} & \dots & x_{2C} \\
x_{31} & x_{32} & x_{33} & \dots & x_{3C} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
x_{T1} & x_{T2} & x_{T3} & \dots & x_{TC}
\end{bmatrix}
$$

<br>

Performing this matrix multiplication, the tokens would interact in the same way as before:

$$
\begin{bmatrix}
x_{11} & x_{12} & x_{13} & \dots & x_{1C} \\
\frac{x_{11} + x_{21}}{2} & \frac{x_{12} + x_{22}}{2} & \frac{x_{13} + x_{23}}{2} & \dots & \frac{x_{1C} + x_{2C}}{2} \\
\frac{x_{11} + x_{21} + x_{31}}{3} & \frac{x_{12} + x_{22} + x_{32}}{3} & \frac{x_{13} + x_{23} + x_{33}}{3} & \dots & \frac{x_{1C} + x_{2C} + x_{3C}}{3} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
\frac{x_{11} + x_{21} + x_{31} + ... + x_{T1}}{T} & \frac{x_{12} + x_{22} + x_{32} + ... + x_{T2}}{T} & \frac{x_{13} + x_{23} + x_{33} + ... + x_{T3}}{T} & \dots & \frac{x_{1C} + x_{2C} + x_{3C} + ... + x_{TC}}{T}
\end{bmatrix}
$$

In [22]:
# lower triangular matrix of all ones
wei = torch.tril(torch.ones(T, T)) # (T, T)

# normalize each row so that it sums to 1
wei /= wei.sum(1, keepdim=True) # (T, T)

# matrix multiplication
xbow2 = wei @ x # (B, T, C) = (T, T) @ (B, T, C), broadcast over the batch

# compare results
torch.allclose(xbow, xbow2)

True
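
A tiny worked example makes the trick easier to follow by hand. The sketch below uses a made-up chunk `x_toy` with T=3 tokens and C=2 channels. The values are arbitrary and chosen only so that the averages are easy to check.

```python
import torch

# toy chunk with T=3 tokens and C=2 channels (made-up values)
x_toy = torch.tensor([[2., 7.],
                      [6., 4.],
                      [6., 5.]])

# lower triangular matrix whose rows sum to 1
w = torch.tril(torch.ones(3, 3))
w = w / w.sum(1, keepdim=True)
print(w)

# row t of the result is the mean of rows 0..t of x_toy
avg = w @ x_toy
print(avg)
```

The second row of the result is the mean of the first two tokens, (2+6)/2 = 4 and (7+4)/2 = 5.5, and the third row is the mean of all three tokens.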

## Decoder attention

In fact, the lower triangular matrix contains the weights for the weighted sum of the past elements. Those **attention weights** control how much influence each past token should have on the current token. We are going to modify the way we construct this lower triangular matrix.

Initially, the matrix is going to be completely zeroed out, indicating no attention is being paid yet.

$$
\begin{bmatrix}
0 & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & 0 \\
0 & 0 & 0 & 0 & 0
\end{bmatrix}
$$

<br>

Next, we are going to apply a **masking** to prevent future tokens from interacting with the past. To accomplish this, we can set those positions to −∞, so after applying Softmax their attention weights are zero.

$$
\begin{bmatrix}
0 & -\infty & -\infty & -\infty & -\infty \\
0 & 0 & -\infty & -\infty & -\infty \\
0 & 0 & 0 & -\infty & -\infty \\
0 & 0 & 0 & 0 & -\infty \\
0 & 0 & 0 & 0 & 0
\end{bmatrix}
$$

<br>

Finally, we are going to normalize the attention weights and convert them into probabilities using **Softmax**, which ensures that the attention weights in each row sum to 1, distributing influence only among the past tokens. Please note we got a lower triangular matrix similar to the one before.

$$
\begin{bmatrix}
1 & 0 & 0 & 0 & 0 \\
\frac{1}{2} & \frac{1}{2} & 0 & 0 & 0 \\
\frac{1}{3} & \frac{1}{3} & \frac{1}{3} & 0 & 0 \\
\frac{1}{4} & \frac{1}{4} & \frac{1}{4} & \frac{1}{4} & 0 \\
\frac{1}{5} & \frac{1}{5} & \frac{1}{5} & \frac{1}{5} & \frac{1}{5}
\end{bmatrix}
$$

<br>

We just created a **decoder attention block** because we used a triangular mask to ensure that tokens from the future cannot influence tokens from the past. Such a setup is commonly used in autoregressive tasks, like language modeling, where predictions are made token by token, without access to future context.

In contrast, in an **encoder attention block**, there is no masking. All tokens are free to interact with one another, allowing information to flow bidirectionally across the entire sequence.

In [23]:
# apply masking
tril = torch.tril(torch.ones(T, T)) # (T, T)
wei = torch.zeros((T,T))            # (T, T)
wei = wei.masked_fill(tril == 0, float('-inf')) # (T, T)

# apply softmax across each row
wei = F.softmax(wei, dim=-1)        # (T, T)

# matrix multiplication
xbow3 = wei @ x # (B, T, C) = (T, T) @ (B, T, C), broadcast over the batch

# compare results
torch.allclose(xbow, xbow3)

True
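
The effect of the mask is easier to see when the scores are not all zero. The following sketch uses a made-up 4x4 matrix `scores` and compares the decoder-style (masked) softmax with the encoder-style (unmasked) one. The large values above the diagonal are there only to show that the mask removes them completely.

```python
import torch
from torch.nn import functional as F

T = 4
scores = torch.tensor([[0.0, 9.0, 9.0, 9.0],
                       [1.0, 2.0, 9.0, 9.0],
                       [0.5, 0.5, 0.5, 9.0],
                       [3.0, 1.0, 0.0, 2.0]]) # made-up affinities

tril = torch.tril(torch.ones(T, T))

# decoder-style: future positions are masked out before the softmax
masked = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)

# encoder-style: no mask, every token can attend to every other token
unmasked = F.softmax(scores, dim=-1)

print(masked)
print(masked.sum(-1))
```

Every row still sums to 1, but in the masked version the weight is shared only among the current and the preceding positions.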

## Scaled self-attention
But we want the attention weights in each row to vary based on the relationships between the tokens. Different tokens might find certain other tokens more relevant or important, and we need to capture these dynamic relationships. In other words, we want to gather information from previous tokens in a **data-dependent way**.

We can achieve this using a **query matrix** (representing "what am I looking for"), a **key matrix** (representing "what information do I have"), and a **value matrix** (representing "what I will communicate if I am attended to") for each chunk. To get these matrices, we are just going to perform matrix multiplication between the input, of shape T × C, and three learned weight matrices, each of shape C × H, where H is the head size.

$$
\begin{align*}
\begin{bmatrix}
q_{11} & q_{12} & q_{13} & \dots & q_{1H} \\
q_{21} & q_{22} & q_{23} & \dots & q_{2H} \\
q_{31} & q_{32} & q_{33} & \dots & q_{3H} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
q_{T1} & q_{T2} & q_{T3} & \dots & q_{TH}
\end{bmatrix}
&=
\begin{bmatrix}
x_{11} & x_{12} & x_{13} & \dots & x_{1C} \\
x_{21} & x_{22} & x_{23} & \dots & x_{2C} \\
x_{31} & x_{32} & x_{33} & \dots & x_{3C} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
x_{T1} & x_{T2} & x_{T3} & \dots & x_{TC}
\end{bmatrix}
\times
\begin{bmatrix}
qw_{11} & qw_{12} & qw_{13} & \dots & qw_{1H} \\
qw_{21} & qw_{22} & qw_{23} & \dots & qw_{2H} \\
qw_{31} & qw_{32} & qw_{33} & \dots & qw_{3H} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
qw_{C1} & qw_{C2} & qw_{C3} & \dots & qw_{CH}
\end{bmatrix}

\\
\\

\begin{bmatrix}
k_{11} & k_{12} & k_{13} & \dots & k_{1H} \\
k_{21} & k_{22} & k_{23} & \dots & k_{2H} \\
k_{31} & k_{32} & k_{33} & \dots & k_{3H} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
k_{T1} & k_{T2} & k_{T3} & \dots & k_{TH}
\end{bmatrix}
&=
\begin{bmatrix}
x_{11} & x_{12} & x_{13} & \dots & x_{1C} \\
x_{21} & x_{22} & x_{23} & \dots & x_{2C} \\
x_{31} & x_{32} & x_{33} & \dots & x_{3C} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
x_{T1} & x_{T2} & x_{T3} & \dots & x_{TC}
\end{bmatrix}
\times
\begin{bmatrix}
kw_{11} & kw_{12} & kw_{13} & \dots & kw_{1H} \\
kw_{21} & kw_{22} & kw_{23} & \dots & kw_{2H} \\
kw_{31} & kw_{32} & kw_{33} & \dots & kw_{3H} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
kw_{C1} & kw_{C2} & kw_{C3} & \dots & kw_{CH}
\end{bmatrix}

\\
\\

\begin{bmatrix}
v_{11} & v_{12} & v_{13} & \dots & v_{1H} \\
v_{21} & v_{22} & v_{23} & \dots & v_{2H} \\
v_{31} & v_{32} & v_{33} & \dots & v_{3H} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
v_{T1} & v_{T2} & v_{T3} & \dots & v_{TH}
\end{bmatrix}
&=
\begin{bmatrix}
x_{11} & x_{12} & x_{13} & \dots & x_{1C} \\
x_{21} & x_{22} & x_{23} & \dots & x_{2C} \\
x_{31} & x_{32} & x_{33} & \dots & x_{3C} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
x_{T1} & x_{T2} & x_{T3} & \dots & x_{TC}
\end{bmatrix}
\times
\begin{bmatrix}
vw_{11} & vw_{12} & vw_{13} & \dots & vw_{1H} \\
vw_{21} & vw_{22} & vw_{23} & \dots & vw_{2H} \\
vw_{31} & vw_{32} & vw_{33} & \dots & vw_{3H} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
vw_{C1} & vw_{C2} & vw_{C3} & \dots & vw_{CH}
\end{bmatrix}
\end{align*}
$$

<br>

The **affinity** between tokens is computed by performing matrix multiplication between the query matrix and the transposed key matrix. This results in an **affinity matrix**, where each entry $a_{ij}$ is the dot product between the query of token $i$ and the key of token $j$, measuring how closely they align.

$$
\begin{bmatrix}
a_{11} & a_{12} & a_{13} & \dots & a_{1T} \\
a_{21} & a_{22} & a_{23} & \dots & a_{2T} \\
a_{31} & a_{32} & a_{33} & \dots & a_{3T} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
a_{T1} & a_{T2} & a_{T3} & \dots & a_{TT}
\end{bmatrix}
=
\begin{bmatrix}
q_{11} & q_{12} & q_{13} & \dots & q_{1H} \\
q_{21} & q_{22} & q_{23} & \dots & q_{2H} \\
q_{31} & q_{32} & q_{33} & \dots & q_{3H} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
q_{T1} & q_{T2} & q_{T3} & \dots & q_{TH}
\end{bmatrix}
\times
\begin{bmatrix}
k_{11} & k_{12} & k_{13} & \dots & k_{1H} \\
k_{21} & k_{22} & k_{23} & \dots & k_{2H} \\
k_{31} & k_{32} & k_{33} & \dots & k_{3H} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
k_{T1} & k_{T2} & k_{T3} & \dots & k_{TH}
\end{bmatrix}^T
$$

<br>

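Then, we are going to divide all the computed affinities by the square root of the head_size. This step, known as **scaled attention**, is essential: if the queries and keys have roughly unit variance, their dot products have a variance of about head_size, and dividing by the square root of head_size brings it back to about 1. This prevents us from getting extremely sharp distributions that concentrate too much attention on one element when we apply Softmax later.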

$$
\begin{bmatrix}
a_{11} & a_{12} & a_{13} & \dots & a_{1T} \\
a_{21} & a_{22} & a_{23} & \dots & a_{2T} \\
a_{31} & a_{32} & a_{33} & \dots & a_{3T} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
a_{T1} & a_{T2} & a_{T3} & \dots & a_{TT}
\end{bmatrix}
=
\begin{bmatrix}
a_{11} & a_{12} & a_{13} & \dots & a_{1T} \\
a_{21} & a_{22} & a_{23} & \dots & a_{2T} \\
a_{31} & a_{32} & a_{33} & \dots & a_{3T} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
a_{T1} & a_{T2} & a_{T3} & \dots & a_{TT}
\end{bmatrix}
\cdot
\frac{1}{\sqrt{\text{head\_size}}}
$$
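
The effect of the scaling can be checked numerically. The sketch below uses random unit-variance queries and keys as stand-ins (the sizes and the seed are assumptions) and compares the variance of the affinities before and after scaling. It then shows how larger inputs make Softmax sharper, using a made-up vector `logits`.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
T, head_size = 1000, 64

# unit-variance queries and keys (random stand-ins)
q = torch.randn(T, head_size)
k = torch.randn(T, head_size)

raw = q @ k.T                  # variance grows with head_size
scaled = raw * head_size**-0.5 # variance stays close to 1

print(f'raw variance:    {raw.var().item():.2f}')
print(f'scaled variance: {scaled.var().item():.2f}')

# larger inputs push softmax towards a one-hot vector
logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
soft = F.softmax(logits, dim=-1)
sharp = F.softmax(logits * 8, dim=-1)
print(soft)
print(sharp)
```

With the scaling in place, the attention weights at initialization stay fairly diffuse, so every token can still gather information from several others.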

<br>

Next, we are going to apply a mask, similar to the one from before, to prevent future tokens from interacting with the past by setting those affinities to −∞.

$$
\begin{bmatrix}
a_{11} & -\infty & -\infty & -\infty & -\infty \\
a_{21} & a_{22} & -\infty & -\infty & -\infty \\
a_{31} & a_{32} & a_{33} & -\infty & -\infty \\
a_{41} & a_{42} & a_{43} & a_{44} & -\infty \\
a_{51} & a_{52} & a_{53} & a_{54} & a_{55}
\end{bmatrix}
$$

<br>

Finally, we are going to apply softmax to convert these affinities into attention weights, which determine how much attention each token pays to others.

$$
\begin{bmatrix}
w_{11} & 0 & 0 & 0 & 0 \\
w_{21} & w_{22} & 0 & 0 & 0 \\
w_{31} & w_{32} & w_{33} & 0 & 0 \\
w_{41} & w_{42} & w_{43} & w_{44} & 0 \\
w_{51} & w_{52} & w_{53} & w_{54} & w_{55}
\end{bmatrix}
$$

<br>

Once we have our lower triangular matrix containing the attention weights, we are just going to perform matrix multiplication between this matrix and the values matrix.

$$
\begin{bmatrix}
w_{11} & 0 & 0 & 0 & 0 \\
w_{21} & w_{22} & 0 & 0 & 0 \\
w_{31} & w_{32} & w_{33} & 0 & 0 \\
w_{41} & w_{42} & w_{43} & w_{44} & 0 \\
w_{51} & w_{52} & w_{53} & w_{54} & w_{55}
\end{bmatrix}
\times
\begin{bmatrix}
v_{11} & v_{12} & v_{13} & \dots & v_{1H} \\
v_{21} & v_{22} & v_{23} & \dots & v_{2H} \\
v_{31} & v_{32} & v_{33} & \dots & v_{3H} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
v_{T1} & v_{T2} & v_{T3} & \dots & v_{TH}
\end{bmatrix}
$$

<br>

Performing this matrix multiplication, the tokens would interact as follows:

$$
\begin{bmatrix}
w_{11}v_{11} & w_{11}v_{12} & w_{11}v_{13} & \dots & w_{11}v_{1H} \\
w_{21}v_{11} + w_{22}v_{21} & w_{21}v_{12} + w_{22}v_{22} & w_{21}v_{13} + w_{22}v_{23} & \dots & w_{21}v_{1H} + w_{22}v_{2H} \\
w_{31}v_{11} + w_{32}v_{21} + w_{33}v_{31}  & w_{31}v_{12} + w_{32}v_{22} + w_{33}v_{32} & w_{31}v_{13} + w_{32}v_{23} + w_{33}v_{33} & \dots & w_{31}v_{1H} + w_{32}v_{2H} + w_{33}v_{3H} \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
w_{T1}v_{11} + w_{T2}v_{21} + w_{T3}v_{31} + ... + w_{TT}v_{T1} & w_{T1}v_{12} + w_{T2}v_{22} + w_{T3}v_{32} + ... + w_{TT}v_{T2} & w_{T1}v_{13} + w_{T2}v_{23} + w_{T3}v_{33} + ... + w_{TT}v_{T3} & \dots & w_{T1}v_{1H} + w_{T2}v_{2H} + w_{T3}v_{3H} + ... + w_{TT}v_{TH}
\end{bmatrix}
$$

<br>

As we can see, we have just implemented the attention mechanism, which is essentially a **weighted sum with data-dependent weights**.

<br>

We just created a **self-attention block** because the queries, keys, and values all originate from the same input. In contrast, in a **cross-attention block**, the keys and values come from a different source than the queries (such as an encoder module in a transformer model).

In [24]:
B, T, C = 4, 8, 32
head_size = C

# random input tensor
x = torch.randn(B,T,C) # (B, T, C)

# define three linear layers
query = nn.Linear(head_size, head_size, bias=False)
key = nn.Linear(head_size, head_size, bias=False)
value = nn.Linear(head_size, head_size, bias=False)

# compute queries, keys, and values
q = query(x) # (B, T, hs)
k = key(x)   # (B, T, hs)
v = value(x) # (B, T, hs)

# compute affinities and scale them
wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, T) = (B, T, hs) @ (B, hs, T)

# apply masking
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf')) # (B, T, T)

# apply softmax across each row
wei = F.softmax(wei, dim=-1) # (B, T, T)

# matrix multiplication
xbow4 = wei @ v  # (B, T, hs) = (B, T, T) @ (B, T, hs)

# print the lower triangular matrix for the first chunk
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5369, 0.4631, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2816, 0.5207, 0.1977, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3511, 0.2647, 0.1072, 0.2770, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1816, 0.1325, 0.2561, 0.2885, 0.1413, 0.0000, 0.0000, 0.0000],
        [0.1181, 0.1913, 0.1732, 0.1039, 0.1065, 0.3070, 0.0000, 0.0000],
        [0.2005, 0.0781, 0.1450, 0.2032, 0.1033, 0.1038, 0.1661, 0.0000],
        [0.1348, 0.1106, 0.1388, 0.2255, 0.0614, 0.0745, 0.0976, 0.1568]],
       grad_fn=<SelectBackward0>)
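
The lower-triangular pattern in the output above follows directly from masking before the softmax. As a minimal sketch in plain Python (no PyTorch; the scores and values are made-up numbers and `causal_softmax` is a hypothetical helper, not part of the notebook), each row is normalized only over the positions up to and including itself:

```python
import math

def causal_softmax(scores):
    """Row-wise softmax where position t may only look at positions <= t."""
    weights = []
    for t, row in enumerate(scores):
        visible = row[: t + 1]                      # future positions are masked out
        m = max(visible)                            # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        # masked positions get weight 0, just like exp(-inf) = 0
        weights.append([e / total for e in exps] + [0.0] * (len(row) - t - 1))
    return weights

# made-up affinities for a chunk of T = 3 tokens
scores = [[0.2, 1.5, -0.3],
          [0.7, 0.1, 2.0],
          [-1.0, 0.4, 0.9]]
wei = causal_softmax(scores)

# weighted sum of made-up value vectors (T = 3, head size 2)
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
out = [[sum(w * v[h] for w, v in zip(row, values)) for h in range(2)] for row in wei]
```

Every row of `wei` sums to 1, and the first token can only attend to itself, so its output is simply its own value vector.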

## Head module

We just created a **single head of self-attention** which allows the model to focus on different parts of the chunk simultaneously, extracting information about the **relationships between the tokens**. 

Now, we are just going to incorporate the previous code into the `Head` module.

<div style="width: 220px; margin: 0 auto;">
    <img src="https://production-media.paperswithcode.com/methods/35184258-10f5-4cd0-8de3-bd9bc8f88dc3.png">
</div>

In [25]:
class Head(nn.Module):
    """ One head of self-attention. """

    def __init__(self, head_size):
        super().__init__()

        # define three linear layers that project the embeddings (n_embd) to head_size
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        # tril is not a parameter of the module, so it is registered as a
        # buffer (stored and moved with the module, but not trained)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    
    def forward(self, x):

        _, T, _ = x.shape # (B, T, n_embd)

        # compute queries, keys, and values
        q = self.query(x) # (B, T, hs)
        k = self.key(x)   # (B, T, hs)
        v = self.value(x) # (B, T, hs)

        # compute attention weights
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, T) = (B, T, hs) @ (B, hs, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)

        # matrix multiplication
        out = wei @ v     # (B, T, hs) = (B, T, T) @ (B, T, hs)

        return out

## Token embeddings

In language models, tokenization is usually the first step, which converts input text into a sequence of token indices. After tokenization, an **embedding layer** typically maps these sequences of token indices into dense vector representations, known as **embeddings**, which can be used as input to a neural network.

Right now, the bigram model is using an embedding table of size (vocab_size, vocab_size) so that each token reads off the logits for the next token. We are going to modify this embedding table to return **token embeddings** instead of logits. The new **token embedding table** of size (vocab_size, n_embd) is going to encode the tokens based on their **identity**, providing a more meaningful representation of each token in a continuous vector space.

<div style="width: 650px; margin: 0 auto;">
    <img src="https://miro.medium.com/v2/resize:fit:1400/0*cgpKoFocSYm6bLHw.png">
</div>

## Positional embeddings

Attention mechanisms have **no inherent sense of position** among the input tokens: they simply act over a set of vectors. This is why we need to introduce **positional embeddings** to take into account the order of the tokens.

The **position embedding table** of size (block_size, n_embd) is going to encode the tokens based on their **position**, providing the model with information about each token's place within the sequence.
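
As a rough illustration of how the two tables work together, here is a plain-Python sketch with hypothetical toy tables (`token_table`, `position_table`, and `embed` are made up for this example; the real tables are learned `nn.Embedding` layers). Each token gets one vector for its identity and one for its position, and the two are added element-wise:

```python
# Hypothetical toy tables: 3 tokens in the vocabulary, block_size = 4, n_embd = 2.
token_table = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
position_table = [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]]

def embed(tokens):
    """Look up identity and position vectors and add them element-wise."""
    return [
        [t + p for t, p in zip(token_table[tok], position_table[pos])]
        for pos, tok in enumerate(tokens)
    ]

x = embed(["a", "b", "a"])
# the two "a" tokens share a token embedding but end up with different vectors
```

Without the positional term, both occurrences of `"a"` would be indistinguishable to the attention mechanism.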

## Single head of self-attention

We are going to modify the bigram model to implement a single head of self-attention.

After obtaining the token and positional embeddings from the input to the model, we combine them by **adding the two embeddings**. This combined representation of the identity and position of the tokens is going to be the input to the single head of self-attention.

To generate the logits for predicting the next token, we are going to apply a linear layer, often referred to as the **language modeling head**, to the output of the single head of self-attention.

In [26]:
n_embd = 32 # number of embedding dimensions

class GPTModel(nn.Module):

    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_head = Head(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):

        _, T = idx.shape

        # token and positional embeddings
        tok_emb = self.token_embedding_table(idx) # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, n_embd)
        x = tok_emb + pos_emb # (B, T, n_embd)

        # single head of self-attention
        x = self.sa_head(x)   # (B, T, n_embd)

        # language modeling head
        logits = self.lm_head(x) # (B, T, vocab_size)

        # calculate loss
        loss = calculate_loss(logits, targets)
            
        return logits, loss
    
    def generate(self, idx, max_new_tokens):

        return generate_tokens(self, idx, max_new_tokens, block_size)

Using a single head of self-attention, the validation loss went down from 2.45 to 2.35.

## Multi-head self-attention

Multi-head attention involves applying **multiple attention mechanisms in parallel** and then concatenating their results. 

This approach allows the model to capture different types of relationships between the tokens across **multiple communication channels**. Using several heads with smaller dimensions enables the model to learn diverse patterns more effectively than relying on a single head with a larger dimension.


<div style="width: 450px; margin: 0 auto;">
    <img src="https://miro.medium.com/v2/resize:fit:1010/0*0KPEV8QidHkteKeY.png">
</div>

<br>

**Note:**

[**nn.ModuleList**](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) is a container in PyTorch used to hold a list of submodules, where each module can be indexed using an integer. It is useful when you have a variable number of modules that need to be stored and managed together.

In [27]:
class MultiHeadAttention(nn.Module):
    """ Multiple heads of self-attention in parallel. """

    def __init__(self, n_embd, n_head):
        super().__init__()

        head_size = n_embd // n_head
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])
    
    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)
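
The shape bookkeeping behind the concatenation can be checked with a few lines of plain Python. This is only a sketch: `fake_head` is a hypothetical stand-in that returns a constant vector instead of running attention.

```python
n_embd, n_head = 32, 4          # values used in this notebook
head_size = n_embd // n_head    # 8 features per head

def fake_head(head_index, head_size):
    """Stand-in for one attention head: returns a vector of length head_size."""
    return [float(head_index)] * head_size

# for one token, every head produces head_size features, which are concatenated
per_head = [fake_head(i, head_size) for i in range(n_head)]
concatenated = [feature for head_out in per_head for feature in head_out]
```

The concatenated vector has `n_head * head_size` features, which equals `n_embd` only when `n_embd` is divisible by `n_head`.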

We are going to modify the GPT model to implement multi-head self-attention.

In [28]:
n_head = 4  # number of heads

class GPTModel(nn.Module):

    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_heads = MultiHeadAttention(n_embd, n_head)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):

        _, T = idx.shape

        # token and positional embeddings
        tok_emb = self.token_embedding_table(idx) # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, n_embd)
        x = tok_emb + pos_emb # (B, T, n_embd)

        # multi-head self-attention
        x = self.sa_heads(x)  # (B, T, n_embd)

        # language modeling head
        logits = self.lm_head(x) # (B, T, vocab_size)

        # calculate loss
        loss = calculate_loss(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):

        return generate_tokens(self, idx, max_new_tokens, block_size)

Using multi-head self-attention, the validation loss went down from 2.35 to 2.2.

## Feedforward neural network

The **multi-headed self-attention** mechanism allows each token in a sequence to **attend to** (look at) the other tokens; with the causal mask used here, that means itself and the tokens before it. It computes the relationships or dependencies between the tokens by assigning different attention scores. However, at this stage, the tokens **are not processing** the information they have gathered. They are simply identifying which tokens are relevant to them.

Once the tokens have collected this information through self-attention, the output of each token is passed through a **feedforward neural network (FFN)**. The FFN processes the information independently for each token, allowing them to **analyze** the data they have previously gathered. This is where the model transforms the attended information into something meaningful for downstream tasks.
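
To make "independently for each token" concrete, here is a minimal plain-Python sketch with made-up 2x2 weights (`W`, `b`, and the helper functions are hypothetical, not the notebook's `nn.Linear`). The same layer is applied to every position, one token at a time:

```python
def relu(v):
    return [max(0.0, a) for a in v]

def linear(v, weight, bias):
    """y = W v + b for a single token vector v."""
    return [sum(w * a for w, a in zip(row, v)) + b for row, b in zip(weight, bias)]

# made-up weights shared by every position
W = [[1.0, -1.0], [0.5, 0.5]]
b = [0.0, -1.0]

def feed_forward(tokens):
    """Apply the same linear layer + ReLU to each token, one token at a time."""
    return [relu(linear(tok, W, b)) for tok in tokens]

tokens = [[1.0, 2.0], [3.0, 1.0], [0.0, 4.0]]
out = feed_forward(tokens)
# reordering the tokens only reorders the outputs: nothing is mixed between positions
out_reversed = feed_forward(tokens[::-1])
```

This is the opposite of attention, where the output at one position depends on the other positions.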

<br>

**Note:**

[**nn.Sequential**](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html) is a container in PyTorch that provides an easy way to build neural network architectures by stacking layers or operations in sequence.

In [29]:
class FeedFoward(nn.Module):
    """ A simple linear layer followed by a non-linearity. """

    def __init__(self, n_embd):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)

We are going to modify the GPT model to add a FFN after the multi-head self-attention.

In [30]:
class GPTModel(nn.Module):

    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.sa_heads = MultiHeadAttention(n_embd, n_head)
        self.ffwd = FeedFoward(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):

        _, T = idx.shape

        # token and positional embeddings
        tok_emb = self.token_embedding_table(idx) # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, n_embd)
        x = tok_emb + pos_emb # (B, T, n_embd)

        # multi-head self-attention
        x = self.sa_heads(x)  # (B, T, n_embd)

        # feedforward network
        x = self.ffwd(x)      # (B, T, n_embd)

        # language modeling head
        logits = self.lm_head(x) # (B, T, vocab_size)

        # calculate loss
        loss = calculate_loss(logits, targets)
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens):

        return generate_tokens(self, idx, max_new_tokens, block_size)

Using a FFN after the multi-head self-attention, the validation loss went down from 2.2 to 2.16.

## Block module

We are going to incorporate both the multi-headed self-attention mechanism and the feedforward network in the `Block` module.

A block performs **communication** (handled by the multi-headed self-attention, where the tokens attend to each other) and **computation** (handled by the feedforward network, where the tokens individually analyze the gathered data).

<div style="width: 180px; margin: 0 auto;">
    <img src="https://raw.githubusercontent.com/danielsimon4/language-modeling/refs/heads/main/Images/block-1.png">
</div>

In [31]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation. """

    def __init__(self, n_embd, n_head):
        super().__init__()
        
        self.sa_heads = MultiHeadAttention(n_embd, n_head)
        self.ffwd = FeedFoward(n_embd)

    def forward(self, x):

        # communication
        x = self.sa_heads(x)

        # computation
        x = self.ffwd(x)
        
        return x

We are going to modify the GPT model so it implements several blocks.

In [32]:
n_layer = 3 # number of blocks

class GPTModel(nn.Module):

    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # nn.Sequential (unlike nn.ModuleList) can be called directly on the input
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):

        _, T = idx.shape

        # token and positional embeddings
        tok_emb = self.token_embedding_table(idx) # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, n_embd)
        x = tok_emb + pos_emb # (B, T, n_embd)

        x = self.blocks(x)    # (B, T, n_embd)
        
        # language modeling head
        logits = self.lm_head(x) # (B, T, vocab_size)

        # calculate loss
        loss = calculate_loss(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        
        return generate_tokens(self, idx, max_new_tokens, block_size)

Using 3 blocks, the validation loss went down from 2.16 to 2.1.

## Residual connections

We are starting to get a deep neural net that suffers from optimization issues. To address these, **residual connections**, introduced in the paper [*Deep Residual Learning for Image Recognition*](https://arxiv.org/abs/1512.03385), can be highly effective. The core idea is that while the data undergoes transformations, a **skip connection** adds the original features back to the transformed data via element-wise addition (Image 1).

Another way to conceptualize residual connections is through a **residual pathway**. Along this pathway, **residual blocks** branch off, perform computations, and then merge back into the original pathway via addition (Image 2). As the network is trained, these blocks gradually begin to contribute to the final result. The key advantage of residual connections is that, during the early stages of optimization, the **gradient flows smoothly** from the output layer back to the input, avoiding the vanishing gradient problem.

Before merging the transformed data back into the residual pathway, a **projection layer** is typically added. This layer applies a simple linear transformation (*y = x · w + b*) to the output of the transformation to ensure the dimensions match.


<div style="width: 650px; margin: 0 auto;">
    <img src="https://pbs.twimg.com/media/ESnE4IvUYAAopRf.jpg">
</div>
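
The effect on gradient flow can be illustrated with a toy scalar model. This is a sketch under strong simplifying assumptions (each "layer" is a single made-up weight, and both functions are hypothetical helpers): without skip connections the gradient is a product of small numbers, while with them each layer contributes a factor close to 1.

```python
def plain_gradient(weights):
    """d(output)/d(input) for y = w_n * ... * w_1 * x (no skip connections)."""
    grad = 1.0
    for w in weights:
        grad *= w
    return grad

def residual_gradient(weights):
    """Same stack, but each layer computes x + w * x, so its local gradient is 1 + w."""
    grad = 1.0
    for w in weights:
        grad *= 1.0 + w
    return grad

weights = [0.01] * 20   # made-up small weights for a 20-layer stack
g_plain = plain_gradient(weights)
g_residual = residual_gradient(weights)
```

With these numbers the plain gradient is about 1e-40, while the residual one stays close to 1.2.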

We are going to add a projection layer in the `MultiHeadAttention` module.

In [33]:
class MultiHeadAttention(nn.Module):
    """ Multiple heads of self-attention in parallel. """

    def __init__(self, n_embd, n_head):
        super().__init__()

        head_size = n_embd // n_head
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])
        self.proj = nn.Linear(head_size * n_head, n_embd)

    def forward(self, x):

        # multi-head self-attention
        out = torch.cat([h(x) for h in self.heads], dim=-1)

        # projection layer
        out = self.proj(out)

        return out

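We are also going to add a projection layer in the `FeedFoward` module. As in the [*Attention Is All You Need*](https://arxiv.org/abs/1706.03762) paper, the inner layer of the feedforward network is four times larger than the embedding dimension.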

In [34]:
class FeedFoward(nn.Module):
    """ A simple linear layer followed by a non-linearity. """

    def __init__(self, n_embd):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), # growing inner-layer
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd), # projection layer
        )

    def forward(self, x):
        return self.net(x)

We are going to incorporate residual connections (x = x + F(x)) into the transformations within the `Block` module to ensure smoother gradient flow and improve optimization.

<div style="width: 180px; margin: 0 auto;">
    <img src="https://raw.githubusercontent.com/danielsimon4/language-modeling/refs/heads/main/Images/block-2.png">
</div>

In [35]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation. """

    def __init__(self, n_embd, n_head):
        super().__init__()

        self.sa_heads = MultiHeadAttention(n_embd, n_head)
        self.ffwd = FeedFoward(n_embd)

    def forward(self, x):

        # communication with residual connections
        x = x + self.sa_heads(x)
        
        # computation with residual connections
        x = x + self.ffwd(x)

        return x

Using residual connections, the validation loss went down from 2.1 to 1.94.

## Layer normalization

Layer normalization also helps with the optimization of deep neural networks and is described in the paper [Ba et al. (2016). *Layer Normalization*](https://arxiv.org/abs/1607.06450). Remember that batch normalization made sure that across the batch dimension any individual neuron had a unit Gaussian distribution (0 mean and 1 standard deviation output) at initialization.

Layer normalization is very similar to batch normalization, but it normalizes across the rows (the features of each example) instead of the columns (the batch dimension), does not need the running mean and running variance buffers, and makes no distinction between train and test time. Layer normalization acts on a per-token level and **normalizes the features**, making them unit Gaussian at initialization.

In [36]:
class LayerNorm1d:

  def __init__(self, dim, eps=1e-5):
    self.eps = eps
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)

  def __call__(self, x):
    xmean = x.mean(1, keepdim=True)
    xvar = x.var(1, keepdim=True)
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
    self.out = self.gamma * xhat + self.beta
    return self.out

  def parameters(self):
    return [self.gamma, self.beta]
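
As a quick sanity check of the idea, the following plain-Python sketch normalizes the made-up features of a single token and measures the result (`layer_norm` is a hypothetical helper without the learnable `gamma` and `beta`):

```python
import math

def layer_norm(features, eps=1e-5):
    """Normalize one token's feature vector to zero mean and unit variance."""
    mean = sum(features) / len(features)
    var = sum((f - mean) ** 2 for f in features) / len(features)  # biased estimator
    return [(f - mean) / math.sqrt(var + eps) for f in features]

token = [2.0, 4.0, 6.0, 8.0]          # made-up features of a single token
normed = layer_norm(token)

mean_after = sum(normed) / len(normed)
var_after = sum((f - mean_after) ** 2 for f in normed) / len(normed)
```

Note that this sketch uses the biased variance, as `nn.LayerNorm` does, whereas `x.var` in the `LayerNorm1d` class above defaults to the unbiased estimator; for realistic feature sizes the difference is small.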

In the [*Attention Is All You Need*](https://arxiv.org/abs/1706.03762) paper, **layer normalization** is applied after the attention and feedforward transformations. Nowadays, it is more common to apply layer normalization **before the transformations** (this is called the **pre-norm formulation**).

In addition, in the original paper layer normalization occurs within the residual pathway. However, it is preferable to maintain a single clean residual stream all the way from the supervision down to the inputs, which facilitates smoother gradient flow during backpropagation. This ensures that even the shallow layers receive direct supervision, which helps prevent vanishing gradients.

<div style="width: 160px; margin: 0 auto;">
    <img src="https://raw.githubusercontent.com/danielsimon4/language-modeling/refs/heads/main/Images/block-architecture.png">
</div>
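
Writing $F$ for either sublayer (multi-head self-attention or the feedforward network), the two formulations can be summarized as follows; the notation is introduced here only for illustration:

$$
\begin{aligned}
\text{post-norm (original paper):} \quad & x \leftarrow \text{LayerNorm}\big(x + F(x)\big) \\
\text{pre-norm:} \quad & x \leftarrow x + F\big(\text{LayerNorm}(x)\big)
\end{aligned}
$$

In the pre-norm form, $x$ itself is never normalized on its way through the block, so the residual stream stays untouched apart from the additions.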

<br>

We are going to apply layer normalization before the transformations and outside of the residual pathway.

<div style="width: 180px; margin: 0 auto;">
    <img src="https://raw.githubusercontent.com/danielsimon4/language-modeling/refs/heads/main/Images/block-3.png">
</div>

In [37]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        super().__init__()

        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):

        # communication with residual connections and layer norm
        x = x + self.sa(self.ln1(x))

        # computation with residual connections and layer norm
        x = x + self.ffwd(self.ln2(x))

        return x

It is also common to add a layer normalization **at the end of the transformer**, right before the final linear layer that decodes into the vocabulary:

In [38]:
class GPTLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):

        _, T = idx.shape

        # token and position embeddings from embedding tables
        tok_emb = self.token_embedding_table(idx) # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, n_embd)
        x = tok_emb + pos_emb # (B, T, n_embd)

        # apply blocks
        x = self.blocks(x)    # (B, T, n_embd)

        # final layer norm
        x = self.ln_f(x)      # (B, T, n_embd)

        # logits from language modeling head
        logits = self.lm_head(x) # (B, T, vocab_size)

        # calculate loss
        loss = calculate_loss(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is the current context
        # in each iteration idx will grow:
        # (B, T), (B, T+1), (B, T+2), ..., (B, T+max_new_tokens)

        for _ in range(max_new_tokens):

            # crop the context to the last block_size tokens
            idx_cond = idx[:, -block_size:]

            # get logits for current context (calling forward method)
            logits, _ = self(idx_cond) # (B, T, C)

            # focus only on the last time-step because
            # those are the predictions for what comes next
            logits = logits[:, -1, :] # (B, C)

            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)

            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

        return idx

Using layer normalization, the validation loss went down from 1.94 to 1.93.
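
The `generate` method above can be mimicked in plain Python to see the moving parts in isolation. In this sketch, `toy_logits` is a hypothetical stand-in for the trained model and the vocabulary has only three made-up tokens; the cropping matters because the position embedding table only has `block_size` rows.

```python
import math
import random

block_size = 4                      # toy context length
vocab = ["a", "b", "c"]

def toy_logits(context):
    """Hypothetical stand-in for the model: favours the token after the last one."""
    assert len(context) <= block_size   # positions beyond block_size have no embedding
    favoured = (context[-1] + 1) % len(vocab)
    return [2.0 if i == favoured else 0.0 for i in range(len(vocab))]

def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def generate(idx, max_new_tokens, rng):
    idx = list(idx)
    for _ in range(max_new_tokens):
        idx_cond = idx[-block_size:]                 # crop to the last block_size tokens
        probs = softmax(toy_logits(idx_cond))        # distribution over the next token
        idx_next = rng.choices(range(len(vocab)), weights=probs, k=1)[0]  # sample, not argmax
        idx.append(idx_next)                         # the sequence grows by one token
    return idx

sampled = generate([0], max_new_tokens=10, rng=random.Random(0))
```

Sampling from the distribution, rather than always picking the most likely token, is what makes each generated text different.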

## Dropout

Dropout is a **regularization technique** described in the paper [Srivastava et al. (2014). *Dropout: A Simple Way to Prevent Neural Networks from Overfitting*](https://jmlr.org/papers/v15/srivastava14a.html) that consists of **randomly shutting off** a subset of neurons at every training step and training without them.

<div style="width: 550px">
    <img src="https://production-media.paperswithcode.com/methods/Screen_Shot_2020-05-23_at_6.19.24_PM.png">
</div>
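
A minimal plain-Python sketch of the mechanism is shown below. It assumes the "inverted dropout" variant, in which surviving activations are rescaled by 1 / (1 - p) during training (this is also what `nn.Dropout` does); the `dropout` function here is a hypothetical helper, not the PyTorch layer.

```python
import random

def dropout(values, p, training=True, rng=random):
    """Inverted dropout: zero each value with probability p, rescale the survivors."""
    if not training or p == 0.0:
        return list(values)             # dropout is a no-op at evaluation time
    keep_scale = 1.0 / (1.0 - p)
    return [0.0 if rng.random() < p else v * keep_scale for v in values]

rng = random.Random(0)                  # fixed seed so the sketch is repeatable
activations = [1.0] * 10_000
dropped = dropout(activations, p=0.2, rng=rng)

fraction_zeroed = dropped.count(0.0) / len(dropped)
mean_after = sum(dropped) / len(dropped)
```

Roughly 20% of the activations are zeroed, and thanks to the rescaling the average activation stays close to its original value.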

<br>

Dropout is applied to the attention weights right after the softmax, so we randomly prevent some of the tokens from communicating:

In [39]:
dropout = 0.2

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()

        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        _, T, _ = x.shape # (B, T, n_embd)

        # compute keys and queries
        k = self.key(x)   # (B, T, hs)
        q = self.query(x) # (B, T, hs)

        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, T) = (B, T, hs) @ (B, hs, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)

        # perform the weighted aggregation of the values
        v = self.value(x) # (B, T, hs)
        out = wei @ v     # (B, T, hs) = (B, T, T) @ (B, T, hs)

        return out

Dropout is also added right after the projection back to the residual pathway:

In [40]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()

        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        # multi-headed self-attention
        out = torch.cat([h(x) for h in self.heads], dim=-1)

        # projection and dropout layers
        out = self.dropout(self.proj(out))

        return out

In [41]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd), # growing inner-layer
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd), # projection layer
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

## Scaling up the model

In [42]:
# hyperparameters
batch_size = 64       # number of chunks per batch
block_size = 256      # maximum context length of a chunk
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200      # number of iterations used to estimate the loss
eval_interval = 1000  # number of iterations between loss evaluations
learning_rate = 3e-4
max_iters = 5000
n_embd = 384          # number of embedding dimensions
n_layer = 6           # number of blocks
n_head = 6            # number of heads
dropout = 0.2         # dropout probability

In [43]:
# initialize the model
model = GPTLanguageModel()
m = model.to(device)

print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

10.788929 M parameters
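
The parameter count can be reproduced by hand from the hyperparameters. The sketch below assumes `vocab_size = 65` (the number of unique characters, which is not restated in this section) and mirrors the layers defined above; `linear_params` is a hypothetical helper.

```python
# Assumption: vocab_size = 65 (unique characters in the dataset);
# the other values are the hyperparameters defined above.
vocab_size, block_size = 65, 256
n_embd, n_layer, n_head = 384, 6, 6
head_size = n_embd // n_head

def linear_params(n_in, n_out, bias=True):
    """Number of weights (and optional biases) in a linear layer."""
    return n_in * n_out + (n_out if bias else 0)

embeddings = vocab_size * n_embd + block_size * n_embd
attention = n_head * 3 * linear_params(n_embd, head_size, bias=False)   # query, key, value per head
attention += linear_params(head_size * n_head, n_embd)                  # output projection
feed_forward = linear_params(n_embd, 4 * n_embd) + linear_params(4 * n_embd, n_embd)
layer_norms = 2 * (2 * n_embd)                                          # weight and bias, twice per block
per_block = attention + feed_forward + layer_norms

# blocks + final layer norm + language modeling head
total = embeddings + n_layer * per_block + 2 * n_embd + linear_params(n_embd, vocab_size)
print(total / 1e6, "M parameters")
```

The `tril` masks do not appear in the count because they are registered as buffers rather than parameters. Almost all of the parameters live in the six blocks, and about two thirds of each block belongs to the feedforward network.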


In [44]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [45]:
# train the model
for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"Step: {iter:4d}/{max_iters:4d}   Train loss: {losses['train']:.4f}   Val loss: {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # forward pass
    logits, loss = model(xb, yb)

    # backward pass
    optimizer.zero_grad(set_to_none=True)
    loss.backward()

    # update
    optimizer.step()


In [47]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=1000)[0].tolist()))



LEONTES:
Whither, this? Hasting foe!
This chair that banish'd his battle spirits,
And diso mers it tellingn their devils.
Good not! I know myself-sper, for where he they
Command. Besides, quoth deids us the hasty ture.
Procks shall be salinting, wanting by their death
The whitness, wite-arwarting up with rashing feast
With record the heed memorsel of our souls, make with chape,
Most of the watchest hath moved, we must.

Nurse:
Provost, you speak; again, love.
Ah, fellow, thus farewell medder-tirrion!
And now I know follows me for my heads mark'd thee!
Mach's spoon-bear! come on their bosoms
And pale and pebble their own posing treats,
Or wealth em or a scepsed, friend and bight
Most time have done an heart,
Forthwell amplot till way upon his soul talls,
False their beadins his womble.

SLY:
Though he must have found you all fitter up,
Shall we have abser'd the state of a fiend.
But ke drowns, I with tumbly manaters
Forbade, as deceit for this dust of rear,
For Juliet's courtesy, for 