# Notebook 6: The Main Event - Building a Transformer (GPT)

In our previous notebooks, we built models that could classify images. But what if we want to build a model that can understand and generate text—like ChatGPT or Claude? That's where the Transformer architecture comes in.

## The Limitation of the Bigram Model

Before Transformers, simple language models like the Bigram model worked by predicting the next token based only on the immediately previous token. This approach has a critical limitation: **it has no context**. Each prediction is made in isolation, without understanding the broader meaning of the sentence or paragraph. 

Imagine trying to complete a sentence where you can only see the last word: "The cat sat on the..." Without seeing "cat" earlier in the sentence, you might predict "table" just as easily as "mat"—both are equally likely from a single-word context. This is why simple models produce incoherent, repetitive text.
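To make the limitation concrete, here is a minimal count-based bigram sketch. The toy corpus and the helper name `predict_next` are invented for this illustration and are not part of the earlier notebooks. Because the prediction is a lookup keyed on the last word alone, two different sentences that end in the same word get the same prediction.

```python
from collections import Counter, defaultdict

# Toy corpus, invented for this illustration
corpus = "the cat sat on the mat . the plate sat on the table .".split()

# Count how often each word follows each other word
follows = defaultdict(Counter)
for prev, nxt in zip(corpus, corpus[1:]):
    follows[prev][nxt] += 1

def predict_next(context):
    """Bigram prediction: only the last word of the context is used."""
    last = context.split()[-1]
    return follows[last].most_common(1)[0][0]

a = predict_next("the cat sat on the")
b = predict_next("the plate sat on the")
print(a, b)  # identical, because both contexts end in "the"
```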

## Introducing the Transformer

The Transformer architecture, introduced in the paper "Attention Is All You Need" (2017), revolutionized natural language processing. It's the foundation behind models like GPT, BERT, and T5. The Transformer's superpower is **self-attention**.

Self-attention allows every token in the context to look at every other token and pass information around, asking questions like "Hey, I'm a noun, are there any adjectives here that describe me?" This communication is what allows the model to build a rich understanding of the context before making a prediction. Instead of making predictions in isolation, the model can consider relationships between distant words, understand grammatical structure, and capture semantic meaning across the entire sequence.
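As a rough sketch of what this "communication" looks like numerically, the snippet below computes unmasked scaled dot-product self-attention by hand. The sizes are arbitrary, and random matrices stand in for the learned query, key, and value projections that a real model would train.

```python
import math
import torch

torch.manual_seed(0)
T, d = 4, 8                      # 4 tokens, 8-dimensional vectors (arbitrary toy sizes)
x = torch.randn(T, d)            # one vector per token

# In a real model q, k, v come from learned linear projections of x;
# random matrices stand in for them here.
Wq, Wk, Wv = (torch.randn(d, d) for _ in range(3))
q, k, v = x @ Wq, x @ Wk, x @ Wv

scores = q @ k.T / math.sqrt(d)          # (T, T): how strongly each token matches every other
weights = torch.softmax(scores, dim=-1)  # each row sums to 1
out = weights @ v                        # (T, d): weighted mix of value vectors

print(weights.shape, out.shape)
```

Each row of `weights` says how much one token draws from every token in the sequence, and `out` is the resulting blend of value vectors.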


## The Problem with Standard Attention

Standard self-attention has a computational bottleneck: it needs to create a large `(sequence_length, sequence_length)` matrix to store attention scores between every pair of tokens. For a sequence of length 1024, this means storing over 1 million attention scores (1024 × 1024 = 1,048,576) for every attention head of every sequence in the batch. This matrix is slow to read from and write to the GPU's main memory (HBM - High Bandwidth Memory).

The problem gets worse as sequences get longer: memory usage grows quadratically with sequence length, so doubling the sequence length quadruples the size of the score matrix, and very long sequences quickly become impractical.
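The quadratic growth is easy to check with a little arithmetic. The sketch below assumes 2 bytes per score (half precision) and counts a single head of a single sequence; the helper name `score_matrix_bytes` is made up for this example.

```python
# Memory needed to materialize the (T, T) attention-score matrix
# for ONE head of ONE sequence, assuming 2 bytes per score (fp16).
BYTES_PER_SCORE = 2  # assumption: half precision

def score_matrix_bytes(seq_len):
    return seq_len * seq_len * BYTES_PER_SCORE

for seq_len in (1024, 2048, 4096, 8192):
    mib = score_matrix_bytes(seq_len) / 2**20
    print(f"T={seq_len:5d}: {seq_len * seq_len:>11,} scores, {mib:6.1f} MiB")
```

Multiply these figures by the batch size and the number of heads to estimate the total for one attention layer.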

## Flash Attention: The Solution

Flash Attention, introduced by Dao et al. (2022), is a highly optimized algorithm that avoids creating this giant matrix. Instead of storing all attention scores in memory, Flash Attention uses a clever trick: it computes the attention output in chunks, keeping intermediate results only in the GPU's super-fast cache (SRAM - Static Random Access Memory).

Think of it this way: instead of writing a huge intermediate report (the attention matrix) to a slow hard drive (HBM), Flash Attention keeps its working data in the GPU's fast on-chip memory (SRAM) and writes out only the final result. This yields a substantial speedup and much lower memory use: the authors' published benchmarks show attention running roughly 2-4x faster than a standard implementation, with memory savings of around 10-20x at sequence lengths of a few thousand tokens.
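The chunking idea can be sketched in plain PyTorch. The function below is only an illustration of the arithmetic, not the actual kernel. It processes keys and values one block at a time and keeps a running maximum and running softmax denominator, so the full `(T, T)` matrix never exists at once. The function name `blockwise_attention` and the block size are assumptions for this example, and no causal mask is applied.

```python
import math
import torch

def blockwise_attention(q, k, v, block_size=4):
    """Attention computed over key/value chunks with a running softmax.

    Illustrative only: real Flash Attention kernels also tile the queries
    and run inside on-chip memory; this just shows the arithmetic.
    """
    T, d = q.shape
    scale = 1.0 / math.sqrt(d)
    running_max = torch.full((T, 1), float("-inf"))  # largest score seen so far, per query
    running_sum = torch.zeros(T, 1)                  # softmax denominator so far
    acc = torch.zeros(T, v.shape[1])                 # unnormalized output so far

    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = (q @ k_blk.T) * scale               # only a (T, block_size) slice exists at once

        new_max = torch.maximum(running_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(running_max - new_max)  # rescale earlier partial results
        p = torch.exp(scores - new_max)

        running_sum = running_sum * correction + p.sum(dim=-1, keepdim=True)
        acc = acc * correction + p @ v_blk
        running_max = new_max

    return acc / running_sum

torch.manual_seed(0)
q, k, v = (torch.randn(10, 8) for _ in range(3))
reference = torch.softmax(q @ k.T / math.sqrt(8), dim=-1) @ v
chunked = blockwise_attention(q, k, v, block_size=4)
print(torch.allclose(reference, chunked, atol=1e-5))
```

The chunked result matches ordinary attention up to floating-point error. What changes is how much intermediate data has to be held at any one time.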

The best part? Modern PyTorch versions include Flash Attention support, so you can use it with a simple function call!


## Implementing Attention the Modern Way

We don't need to implement Flash Attention from scratch—that would be thousands of lines of complex CUDA code! Modern PyTorch (version 2.0+) includes `torch.nn.functional.scaled_dot_product_attention`, which is the canonical way to implement attention. This single, fused function will automatically use a Flash Attention kernel if your hardware and software support it, falling back to standard attention otherwise.

The function signature is simple:
```python
F.scaled_dot_product_attention(query, key, value, is_causal=True)
```

### The Causal Mask

The `is_causal=True` argument is essential for language generation. It prevents tokens from "cheating" by looking at future tokens. During training the model sees the entire sequence at once, so without a mask each position could simply read off the token it is supposed to predict. The causal mask ensures that token N can only attend to tokens 0 through N (itself and everything before it), which matches what is available at generation time. It's a simple boolean argument that does all the masking work for you!

This is much simpler than manually creating attention masks, and it's optimized for performance.
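As a quick sanity check (a sketch with arbitrary toy shapes, assuming PyTorch 2.0 or newer), `is_causal=True` gives the same result as passing an explicit lower-triangular boolean mask, and changing the last token leaves the outputs at earlier positions untouched.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, d = 1, 5, 8
q, k, v = (torch.randn(B, T, d) for _ in range(3))

# Built-in causal masking
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Equivalent explicit mask: True means "may attend". Lower-triangular,
# so position t sees positions 0..t (including itself).
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
out_manual = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(mask.int())
print(torch.allclose(out_causal, out_manual, atol=1e-5))

# Changing a future token must not change earlier outputs
k2, v2 = k.clone(), v.clone()
k2[:, -1] += 1.0
v2[:, -1] += 1.0
out_changed = F.scaled_dot_product_attention(q, k2, v2, is_causal=True)
print(torch.allclose(out_causal[:, :-1], out_changed[:, :-1], atol=1e-5))
```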


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

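# Set device: prefer CUDA (where Flash Attention kernels are available), then Apple MPS, then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")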

# The Attention Head
class Head(nn.Module):
    """A single self-attention head."""
    
    def __init__(self, n_embd, head_size, dropout=0.1):
        super().__init__()
        # Each head has its own linear projections for query, key, and value
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # x has shape (batch_size, sequence_length, n_embd)
        B, T, C = x.shape
        
        # Compute query, key, value
        q = self.query(x)  # (B, T, head_size)
        k = self.key(x)    # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)
        
        # Use PyTorch's optimized scaled_dot_product_attention
        # This will automatically use Flash Attention if available!
        out = F.scaled_dot_product_attention(
            q, k, v,
            is_causal=True,  # This enforces the causal mask automatically
            dropout_p=self.dropout.p if self.training else 0.0
        )
        
        return out  # (B, T, head_size)

# Test the Head
n_embd = 64
head_size = 16
head = Head(n_embd, head_size).to(device)

# Create a test input (batch_size=2, sequence_length=10, embedding_dim=64)
test_input = torch.randn(2, 10, n_embd).to(device)
output = head(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Head successfully created and tested!")


## Building the Full Transformer Block

A single attention head is powerful, but we can make it even more powerful by combining multiple components:

### Multi-Head Attention

Running several attention heads in parallel allows the model to capture different types of relationships. One head might focus on grammatical relationships, another on semantic meaning, and another on long-range dependencies. By concatenating the outputs of multiple heads, we get a richer representation.
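A common alternative to keeping one module per head is to treat the head index as an extra batch dimension and make a single attention call. The sketch below uses toy sizes, and random tensors stand in for the projected queries, keys, and values. It shows that reshaping into heads and concatenating afterwards gives the same result as computing each head separately.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, n_embd, num_heads = 2, 5, 16, 4
head_size = n_embd // num_heads

# Stand-ins for already-projected queries, keys and values
q, k, v = (torch.randn(B, T, n_embd) for _ in range(3))

def split_heads(t):
    # (B, T, n_embd) -> (B, num_heads, T, head_size)
    return t.view(B, T, num_heads, head_size).transpose(1, 2)

# All heads in one call: the head axis is treated like an extra batch axis
out = F.scaled_dot_product_attention(
    split_heads(q), split_heads(k), split_heads(v), is_causal=True
)
out = out.transpose(1, 2).contiguous().view(B, T, n_embd)  # concatenate heads again

# Same result computed one head at a time and concatenated
per_head = torch.cat(
    [
        F.scaled_dot_product_attention(
            q[..., h * head_size:(h + 1) * head_size],
            k[..., h * head_size:(h + 1) * head_size],
            v[..., h * head_size:(h + 1) * head_size],
            is_causal=True,
        )
        for h in range(num_heads)
    ],
    dim=-1,
)
print(out.shape, torch.allclose(out, per_head, atol=1e-5))
```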

### FeedForward Network

After attention gathers information from across the sequence, a small MLP (typically with a 4x expansion) "processes" or "thinks about" this information. It applies a non-linear transformation that helps the model reason about the relationships it discovered.

### Residual Connections & LayerNorm

Residual connections are the same trick used in ResNet, and LayerNorm plays a role similar to the normalization layers used there. Together they stabilize the training of a deep stack of transformer blocks. Residual connections allow gradients to flow directly through the network, while LayerNorm normalizes the activations of each token, making training more stable and allowing us to stack many layers deep.

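The complete Transformer Block combines all of these in a pre-norm arrangement: LayerNorm → Multi-Head Attention → Residual Connection → LayerNorm → FeedForward → Residual Connection. (The original Transformer paper applied LayerNorm after each residual connection; GPT-2-style models apply it before each sublayer, which is what the code in this notebook does.)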
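To see what LayerNorm and a residual connection do in isolation, here is a small sketch with toy sizes. It reproduces `nn.LayerNorm` by hand at its default initialization and then applies one residual update, with a plain `nn.Linear` standing in for the attention or feedforward sublayer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 5, 16)          # (batch, sequence, features), toy sizes

ln = nn.LayerNorm(16)              # default init: weight=1, bias=0

# Manual version: normalize each token's feature vector to mean 0, variance 1
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)   # LayerNorm uses the biased variance
manual = (x - mean) / torch.sqrt(var + ln.eps)

print(torch.allclose(ln(x), manual, atol=1e-5))

# Residual connection: the sublayer only has to learn a correction to x
sublayer = nn.Linear(16, 16)       # stand-in for attention or the feedforward network
y = x + sublayer(ln(x))            # normalize, transform, then add the input back
print(y.shape)
```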


In [None]:
# Multi-Head Attention: Run multiple heads in parallel
class MultiHeadAttention(nn.Module):
    """Multiple attention heads running in parallel."""
    
    def __init__(self, n_embd, num_heads, dropout=0.1):
        super().__init__()
        assert n_embd % num_heads == 0, "n_embd must be divisible by num_heads"
        
        self.num_heads = num_heads
        self.head_size = n_embd // num_heads
        self.n_embd = n_embd
        
        # Create multiple heads
        self.heads = nn.ModuleList([Head(n_embd, self.head_size, dropout) for _ in range(num_heads)])
        
        # Output projection
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # Run each head and concatenate along the feature dimension
        # (this loop is sequential; the heads are independent of each other)
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        # Apply output projection
        out = self.proj(out)
        out = self.dropout(out)
        return out


# FeedForward Network: A simple MLP
class FeedForward(nn.Module):
    """A simple 2-layer MLP with GELU activation."""
    
    def __init__(self, n_embd, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),  # Expand by 4x
            nn.GELU(),                       # GELU is smoother than ReLU
            nn.Linear(4 * n_embd, n_embd),  # Project back
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        return self.net(x)


# The Transformer Block: The complete building block
class Block(nn.Module):
    """Transformer block: communication (attention) followed by computation (feedforward)."""
    
    def __init__(self, n_embd, num_heads, dropout=0.1):
        super().__init__()
        # Multi-head self-attention
        self.sa = MultiHeadAttention(n_embd, num_heads, dropout)
        # Feedforward network
        self.ffwd = FeedForward(n_embd, dropout)
        # Layer normalization
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        
    def forward(self, x):
        # Attention with residual connection and layer norm
        x = x + self.sa(self.ln1(x))  # Residual connection
        
        # Feedforward with residual connection and layer norm
        x = x + self.ffwd(self.ln2(x))  # Residual connection
        
        return x

# Test the Block
n_embd = 64
num_heads = 4
block = Block(n_embd, num_heads).to(device)

test_input = torch.randn(2, 10, n_embd).to(device)
output = block(test_input)
print(f"Block input shape: {test_input.shape}")
print(f"Block output shape: {output.shape}")
print(f"Transformer Block successfully created!")


## The Final GPT Model

Now we can assemble the complete GPT (Generative Pre-trained Transformer) model. The architecture consists of:

1. **Token Embedding Layer**: Converts token indices (integers) into dense vectors, just like the Bigram model. Each token in the vocabulary gets mapped to a learnable embedding vector.

2. **Positional Embedding Layer**: Self-attention on its own has no built-in notion of token order, so we need to give the model a sense of position. Positional embeddings encode the position of each token in the sequence, allowing the model to understand "first word", "second word", etc.

3. **Stack of Transformer Blocks**: Multiple blocks (typically 6-12 for small models, 96+ for large models) stacked on top of each other. Each block refines the understanding of the sequence.

4. **Final LayerNorm**: Normalizes the final representations before the output layer.

5. **Output Linear Layer**: Maps the final hidden representations back to vocabulary logits (scores for each token in the vocabulary).
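The first two layers can be illustrated with a small sketch. The sizes and the example token indices are arbitrary. The same token appearing at two positions gets the same token embedding, but a different input vector once the positional embedding is added.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, block_size, n_embd = 10, 8, 4   # toy sizes
tok_table = nn.Embedding(vocab_size, n_embd)
pos_table = nn.Embedding(block_size, n_embd)

idx = torch.tensor([[3, 7, 3]])             # token 3 appears at positions 0 and 2
B, T = idx.shape

tok_emb = tok_table(idx)                    # (B, T, n_embd)
pos_emb = pos_table(torch.arange(T))        # (T, n_embd), broadcast over the batch
x = tok_emb + pos_emb

# Same token, same token embedding ...
print(torch.equal(tok_emb[0, 0], tok_emb[0, 2]))
# ... but different inputs to the transformer once position is added
print(torch.equal(x[0, 0], x[0, 2]))
```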

The `generate` method is almost identical to the Bigram model—we still sample tokens one at a time, but now each prediction benefits from the full context of all previous tokens!
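Sampling one token at a time usually comes with two knobs, temperature and top-k. The sketch below shows their effect on a made-up set of logits for a four-token vocabulary. The helper name `next_token_probs` is hypothetical and only mirrors the usual recipe of dividing by the temperature, optionally masking everything outside the top k, and applying softmax.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])   # made-up scores for a 4-token vocabulary

def next_token_probs(logits, temperature=1.0, top_k=None):
    logits = logits / temperature
    if top_k is not None:
        kth = torch.topk(logits, top_k).values[:, [-1]]           # k-th largest score per row
        logits = logits.masked_fill(logits < kth, float("-inf"))  # drop everything below it
    return F.softmax(logits, dim=-1)

print(next_token_probs(logits))                    # baseline
print(next_token_probs(logits, temperature=0.5))   # sharper: more mass on the top token
print(next_token_probs(logits, top_k=2))           # only two tokens keep nonzero probability

# Draw one token index from the filtered distribution
sample = torch.multinomial(next_token_probs(logits, top_k=2), num_samples=1)
print(sample)
```

Lower temperatures make the output more deterministic, while top-k removes unlikely tokens from consideration entirely.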


In [None]:
class GPTLanguageModel(nn.Module):
    """GPT model: a stack of transformer blocks."""
    
    def __init__(self, vocab_size, n_embd, block_size, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.block_size = block_size
        # Token embeddings: map each token index to an n_embd-dimensional vector
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, num_heads, dropout) for _ in range(num_layers)])
        self.ln_f = nn.LayerNorm(n_embd)  # Final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)  # Language model head
        
        # Better initialization
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    
    def forward(self, idx, targets=None):
        B, T = idx.shape
        
        # idx and targets are both (B, T) tensors of integers
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, n_embd)
        x = tok_emb + pos_emb  # (B, T, n_embd)
        x = self.blocks(x)  # (B, T, n_embd)
        x = self.ln_f(x)  # (B, T, n_embd)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    @torch.no_grad()  # No gradients are needed while sampling
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Generate new tokens given a context.
        
        Args:
            idx: (B, T) array of indices in the current context
            max_new_tokens: Maximum number of tokens to generate
            temperature: Controls randomness (higher = more random)
            top_k: Only sample from top k most likely tokens
        """
        self.eval()
        for _ in range(max_new_tokens):
            # Crop idx to the last block_size tokens
            idx_cond = idx[:, -self.block_size:]  # Slicing is a no-op when idx is shorter
            
            # Get the predictions
            logits, _ = self(idx_cond)
            # Focus only on the last time step
            logits = logits[:, -1, :] / temperature  # (B, C)
            
            # Optionally apply top-k filtering
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            
            # Apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # Sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # Append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        
        self.train()
        return idx

print("GPTLanguageModel class defined successfully!")


In [None]:
# Hyperparameters
vocab_size = 65  # Placeholder; overwritten below with the actual number of distinct characters
block_size = 32  # Maximum context length; must be shorter than the validation split of the tiny demo text
n_embd = 384  # Embedding dimension
num_heads = 6  # Number of attention heads
num_layers = 6  # Number of transformer blocks
dropout = 0.1
batch_size = 64
learning_rate = 3e-4
max_iters = 5000
eval_interval = 500
eval_iters = 200

# Create a simple text dataset for demonstration
# In practice, you'd load a real dataset like Shakespeare, Wikipedia, etc.
text = """
The quick brown fox jumps over the lazy dog. 
The dog barks at the fox. The fox runs away quickly.
Machine learning is fascinating. Deep learning models can understand language.
Transformers are powerful architectures. Attention mechanisms enable long-range dependencies.
Natural language processing has advanced rapidly. Large language models can generate coherent text.
Artificial intelligence continues to evolve. Neural networks learn complex patterns.
"""

# Create character-level vocabulary
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"Vocabulary size: {vocab_size}")
print(f"Characters: {''.join(chars)}")

# Create character-to-index mappings
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Encode the text
data = torch.tensor(encode(text), dtype=torch.long)

# Split into train and validation sets
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

print(f"Train data length: {len(train_data)}")
print(f"Val data length: {len(val_data)}")

def get_batch(split):
    """Generate a small batch of data."""
    data_split = train_data if split == 'train' else val_data
    ix = torch.randint(len(data_split) - block_size, (batch_size,))
    x = torch.stack([data_split[i:i+block_size] for i in ix])
    y = torch.stack([data_split[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

# Instantiate the model
model = GPTLanguageModel(
    vocab_size=vocab_size,
    n_embd=n_embd,
    block_size=block_size,
    num_heads=num_heads,
    num_layers=num_layers,
    dropout=dropout
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nModel created!")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Create optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Training loop
@torch.no_grad()
def estimate_loss():
    """Estimate loss on train and val sets."""
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

print("\nStarting training...")
for iter_num in range(max_iters):
    # Every once in a while evaluate the loss on train and val sets
    if iter_num % eval_interval == 0 or iter_num == max_iters - 1:
        losses = estimate_loss()
        print(f"Step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    
    # Sample a batch of data
    xb, yb = get_batch('train')
    
    # Evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print("\nTraining completed!")

# Generate text
print("\n" + "="*50)
print("Generating text with GPT:")
print("="*50)

# Start with a context
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = model.generate(context, max_new_tokens=200, temperature=0.8, top_k=50)
generated_text = decode(generated[0].tolist())
print(generated_text)

print("\n" + "="*50)
print("Notice how the GPT model produces much more coherent text")
print("compared to a Bigram model, thanks to self-attention!")
print("="*50)
