# Lab 9: Transformers - Attention Is All You Need

## Learning Objectives

By the end of this lab, you will be able to:
1. Explain why transformers replaced RNNs for sequence modeling
2. Derive scaled dot-product attention and explain the scaling factor
3. Implement a complete transformer encoder block from scratch
4. Describe the differences between encoder-only, decoder-only, and encoder-decoder architectures
5. Understand how pre-trained transformers (BERT, GPT, T5) relate to the core architecture

## Prerequisites

This lab builds directly on **Lab 8 (Seq2Seq with Attention)**. You should understand:
- Attention mechanism: computing $\alpha_{t,j}$ weights over encoder states
- Why attention solves the "bottleneck" problem
- The seq2seq encoder-decoder paradigm

If needed, review `lab_8_seq2seq_attention.ipynb` before continuing.

## Lab Structure

**Core Content (Required):**
- Parts 1-6: Transformer architecture from first principles
- Hands-on exercise: Build a transformer block

**Optional Advanced Section:**
- Part 7: Fine-tuning pre-trained transformers (T5 summarization example)
- Can be skipped if focusing on fundamentals

In [None]:
# ==== Environment Setup ====
# Detects Colab vs local and provides cross-platform utilities

import os
import sys

# Detect environment
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("✓ Running on Google Colab")
else:
    print("✓ Running locally")

def download_file(url: str, filename: str) -> str:
    """Download file if it doesn't exist. Works on both Colab and local."""
    if os.path.exists(filename):
        print(f"✓ {filename} already exists")
        return filename
    
    print(f"Downloading {filename}...")
    if IN_COLAB:
        import subprocess
        subprocess.run(['wget', '-q', url, '-O', filename], check=True)
    else:
        import urllib.request
        urllib.request.urlretrieve(url, filename)
    print(f"✓ Downloaded {filename}")
    return filename

In [None]:
# ==== Device Setup ====
import torch

def get_device():
    """Get best available device: CUDA > MPS > CPU."""
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f"✓ Using CUDA GPU: {torch.cuda.get_device_name(0)}")
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device('mps')
        print("✓ Using Apple MPS (Metal)")
    else:
        device = torch.device('cpu')
        print("✓ Using CPU")
    return device

DEVICE = get_device()

## Part 1: From RNN Attention to Self-Attention

### Recap: Attention in Seq2Seq (Lab 8)

In Lab 8, we learned that attention allows the decoder to "look back" at encoder states:

$$c_t = \sum_j \alpha_{t,j} h_j^{enc}$$

This solved the bottleneck problem - the decoder no longer relies on a single context vector.

**But this approach still uses RNNs**, which have two key limitations:
1. **Sequential computation**: Can't parallelize across time steps (slow training)
2. **Long-range dependencies**: Even with attention, information and gradients between distant positions still pass through a long chain of RNN hidden-state updates

### The Key Insight: Attention Without Recurrence

> *"What if we used ONLY attention, no RNNs at all?"*

This is the transformer insight from Vaswani et al. (2017) - **"Attention Is All You Need"**:
- Replace recurrence entirely with self-attention
- Each position can directly attend to every other position
- Enables full parallelization during training

### Self-Attention vs Cross-Attention

| Type | Query Source | Key/Value Source | Use Case |
|------|--------------|------------------|----------|
| **Cross-attention** (Lab 8) | Decoder | Encoder | Seq2seq translation |
| **Self-attention** (This lab) | Same sequence | Same sequence | Encoding/understanding |

<details>
<summary><b>Q: What's the computational complexity trade-off between RNNs and self-attention?</b></summary>

**A:** 
- **RNN**: O(n) sequential steps - cannot parallelize across time
- **Self-attention**: O(1) parallel steps, but O(n²) memory for the attention matrix

This is the key trade-off: self-attention enables parallelization but has quadratic memory cost in sequence length.
</details>

## Part 2: Scaled Dot-Product Attention

### The Core Equation

Given queries $Q$, keys $K$, and values $V$:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

**Step by step:**
1. Compute similarity scores: $QK^\top$ gives an $(n \times n)$ matrix of raw similarities between each query and each key
2. Divide by $\sqrt{d_k}$ to keep the scores from growing with the key dimension
3. Apply softmax to get attention weights (each row sums to 1)
4. Weighted sum of values using these weights

### Why Scale by $\sqrt{d_k}$? (Full Derivation)

Let $q, k \in \mathbb{R}^{d_k}$ be query and key vectors with components drawn i.i.d. from $\mathcal{N}(0, 1)$.

The dot product is:
$$q \cdot k = \sum_{i=1}^{d_k} q_i k_i$$

**Step 1: Expected value**
$$\mathbb{E}[q \cdot k] = \sum_{i=1}^{d_k} \mathbb{E}[q_i k_i] = \sum_{i=1}^{d_k} \mathbb{E}[q_i]\mathbb{E}[k_i] = 0$$
(by independence and zero mean)

**Step 2: Variance**
Since $q_i$ and $k_i$ are independent:
$$\text{Var}(q_i k_i) = \mathbb{E}[q_i^2 k_i^2] - \mathbb{E}[q_i k_i]^2 = \mathbb{E}[q_i^2]\mathbb{E}[k_i^2] - 0 = 1 \cdot 1 = 1$$

Therefore:
$$\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_k$$

**Step 3: Effect on Softmax**
When $d_k = 512$, the standard deviation of $q \cdot k$ is $\sqrt{512} \approx 22.6$.

Raw dot products can easily reach $\pm 60$, where softmax saturates:
- $\text{softmax}([60, 0, 0]) \approx [1.0, 0.0, 0.0]$ — gradient is nearly zero!

Dividing by $\sqrt{d_k}$ normalizes variance to 1, keeping scores in a reasonable range where softmax has meaningful gradients.
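The derivation can be checked numerically. The following is a minimal sketch using only the Python standard library; the helper names `dot` and `sample_variance`, the sample count, and the seed are illustrative choices, not part of the lab code.

```python
import math
import random

random.seed(0)  # fixed seed so the estimate is reproducible

def dot(q, k):
    """Plain dot product of two equal-length lists."""
    return sum(qi * ki for qi, ki in zip(q, k))

def sample_variance(values):
    """Unbiased sample variance."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / (len(values) - 1)

d_k = 64
n_samples = 2000

# Draw q and k with i.i.d. N(0, 1) components and record q . k
raw = []
for _ in range(n_samples):
    q = [random.gauss(0.0, 1.0) for _ in range(d_k)]
    k = [random.gauss(0.0, 1.0) for _ in range(d_k)]
    raw.append(dot(q, k))

scaled = [s / math.sqrt(d_k) for s in raw]

var_raw = sample_variance(raw)        # theory: d_k
var_scaled = sample_variance(scaled)  # theory: 1

print(f"Var(q.k)             ~ {var_raw:.1f} (theory: {d_k})")
print(f"Var(q.k / sqrt(d_k)) ~ {var_scaled:.2f} (theory: 1)")
```

The estimates fluctuate with the seed and sample count, but they should land near $d_k$ and 1 respectively.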

<details>
<summary><b>Q: If d_k = 64, what is the standard deviation of raw dot products before scaling?</b></summary>

**A:** $\sqrt{64} = 8$. Without scaling, dot products would have std dev of 8, potentially pushing softmax toward saturation for extreme values.
</details>

<details>
<summary><b>Q: What would happen if we scaled by $d_k$ instead of $\sqrt{d_k}$?</b></summary>

**A:** The variance would become $d_k / d_k^2 = 1/d_k$, making scores too small. Attention weights would become nearly uniform (all positions attended equally), losing the ability to focus.
</details>
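To make the two answers above concrete, here is a small sketch comparing the three scaling choices on one row of scores. The `softmax` helper and the score values are hypothetical, chosen so that their spread is on the order of $\sqrt{d_k} = 8$.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

d_k = 64
# Hypothetical raw scores with spread on the order of sqrt(d_k) = 8
raw_scores = [16.0, 8.0, 0.0, -8.0]

unscaled = softmax(raw_scores)
sqrt_scaled = softmax([s / math.sqrt(d_k) for s in raw_scores])
dk_scaled = softmax([s / d_k for s in raw_scores])

for name, w in [("no scaling", unscaled),
                ("/ sqrt(d_k)", sqrt_scaled),
                ("/ d_k", dk_scaled)]:
    print(f"{name:12s} " + " ".join(f"{x:.3f}" for x in w))
```

Without scaling almost all weight lands on one key, dividing by $d_k$ flattens the weights toward uniform, and dividing by $\sqrt{d_k}$ sits in between.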

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention from "Attention Is All You Need"
    
    Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    """
    def __init__(self, d_k: int):
        super().__init__()
        self.scale = math.sqrt(d_k)
    
    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, 
                mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            Q: Queries of shape (batch, seq_len, d_k) or (batch, heads, seq_len, d_k)
            K: Keys of shape (batch, seq_len, d_k) or (batch, heads, seq_len, d_k)
            V: Values of shape (batch, seq_len, d_v) or (batch, heads, seq_len, d_v)
            mask: Optional mask broadcastable to the score shape, e.g. (seq_len, seq_len)
                  or (batch, 1, seq_len, seq_len); positions where mask == 0 are blocked
        
        Returns:
            output: Weighted values, same shape as V
            attn_weights: Attention weights of shape (..., seq_len, seq_len)
        """
        # Compute attention scores: (batch, ..., seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        
        # Apply mask if provided (e.g., for causal attention)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax over last dimension (keys)
        attn_weights = F.softmax(scores, dim=-1)  # (batch, ..., seq_len, seq_len)
        
        # Weighted sum of values
        output = torch.matmul(attn_weights, V)  # (batch, ..., seq_len, d_v)
        
        return output, attn_weights

# Quick test
d_k = 64
attention = ScaledDotProductAttention(d_k)
Q = torch.randn(2, 10, d_k)  # (batch=2, seq_len=10, d_k=64)
K = torch.randn(2, 10, d_k)
V = torch.randn(2, 10, d_k)
output, weights = attention(Q, K, V)
print(f"Input shape: {Q.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"Attention weights sum to 1? {weights[0, 0].sum().item():.4f}")

## Part 3: Multi-Head Attention

### Why Multiple Heads?

A single attention head produces one set of weights per query, so it tends to capture one type of relationship at a time. **Multi-head attention** runs $h$ attention operations in parallel:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h) W^O$$

where each head is: $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$

**Benefits:**
- Each head can attend to different aspects (syntax, semantics, position, etc.)
- Increases model capacity without increasing depth
- Different heads often specialize (empirically observed)

<details>
<summary><b>Q: If we have 8 heads and $d_{model}=512$, what is $d_k$ per head?</b></summary>

**A:** $d_k = d_{model} / h = 512 / 8 = 64$. Each head operates in a 64-dimensional subspace.

This keeps the total computation roughly the same as single-head attention with full $d_{model}$.
</details>

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention: h parallel attention heads, concatenated and projected.
    """
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # Dimension per head
        
        # Linear projections for Q, K, V and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.attention = ScaledDotProductAttention(self.d_k)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """
        Args:
            Q, K, V: Shape (batch, seq_len, d_model)
            mask: Optional attention mask
        Returns:
            output: Shape (batch, seq_len, d_model)
        """
        batch_size = Q.size(0)
        
        # 1. Linear projections and reshape for multi-head
        # (batch, seq_len, d_model) -> (batch, seq_len, n_heads, d_k) -> (batch, n_heads, seq_len, d_k)
        Q = self.W_q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # 2. Apply attention to all heads in parallel
        attn_output, attn_weights = self.attention(Q, K, V, mask)
        # attn_output: (batch, n_heads, seq_len, d_k)
        
        # 3. Concatenate heads: (batch, n_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        # 4. Final linear projection
        output = self.W_o(attn_output)
        
        return self.dropout(output)

# Test multi-head attention
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)  # (batch=2, seq_len=10, d_model=512)
output = mha(x, x, x)  # Self-attention: Q=K=V
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

In [None]:
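# Visualize the projection matrices (Q, K, V, O weights).
# The model is untrained here, so the heatmaps show the random initialization.
import matplotlib.pyplot as plt
import seaborn as sns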

fig, axes = plt.subplots(1, 4, figsize=(16, 4))

# Get the weight matrices
Q_weights = mha.W_q.weight.detach().numpy()
K_weights = mha.W_k.weight.detach().numpy()
V_weights = mha.W_v.weight.detach().numpy()
O_weights = mha.W_o.weight.detach().numpy()

for ax, weights, title in zip(axes, 
                               [Q_weights, K_weights, V_weights, O_weights],
                               ['Query (W_q)', 'Key (W_k)', 'Value (W_v)', 'Output (W_o)']):
    sns.heatmap(weights[:32, :32], ax=ax, cmap='RdBu', center=0, 
                cbar=False, xticklabels=False, yticklabels=False)
    ax.set_title(title)
    ax.set_xlabel('Input dim')
    ax.set_ylabel('Output dim')

plt.suptitle('Multi-Head Attention Projection Matrices (first 32x32)')
plt.tight_layout()
plt.show()

print("These matrices learn to project inputs into Query, Key, Value spaces.")
print("After training, different heads will learn different projection patterns.")

<details>
<summary><b>Q: Why not just use more attention layers instead of multiple heads?</b></summary>

**A:** Multiple heads within a single layer can capture different relationship types *simultaneously* without increasing depth. Each head operates in a different subspace (d_k = d_model / n_heads), allowing the model to jointly attend to information from different representation subspaces. Deeper networks are harder to train and have more sequential dependencies.
</details>

<details>
<summary><b>Q: What's the total parameter count for multi-head attention vs single-head?</b></summary>

**A:** Both have the same parameter count! Multi-head with h heads uses $W_q, W_k, W_v \in \mathbb{R}^{d_{model} \times d_{model}}$ plus $W_o \in \mathbb{R}^{d_{model} \times d_{model}}$. The "splitting" into heads is just a reshape, not additional parameters.
</details>
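The parameter-count claim can be verified with plain arithmetic. This sketch counts the parameters head by head, assuming dense layers with biases (the default for `nn.Linear`); the helper names `linear_params` and `mha_params` are illustrative.

```python
def linear_params(d_in: int, d_out: int) -> int:
    """Weights plus biases of one dense layer."""
    return d_in * d_out + d_out

def mha_params(d_model: int, n_heads: int) -> int:
    """Count parameters head by head, then add the output projection."""
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_k = d_model // n_heads
    per_head = 3 * linear_params(d_model, d_k)  # Q, K, V slices for one head
    return n_heads * per_head + linear_params(d_model, d_model)  # + W_o

for h in (1, 8, 16):
    print(f"n_heads={h:2d}: {mha_params(512, h):,} parameters")
```

The count is the same for every valid number of heads, because the heads only partition the columns of the same $d_{model} \times d_{model}$ projections.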

## Part 4: Positional Encoding

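### The Problem: Self-Attention Has No Notion of Order

Self-attention treats input as a **set**, not a sequence - it doesn't know word order!

Consider: `"cat sat mat"` vs `"mat sat cat"`

Without position information, each token receives exactly the same output vector in both sentences; the outputs are simply reordered along with the inputs. (Strictly speaking, self-attention is permutation *equivariant*: permuting the inputs permutes the outputs in the same way.)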

### The Solution: Add Position Information

The original transformer uses **sinusoidal positional encoding**:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$
$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

**Why sinusoids?**
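- May allow the model to extrapolate to longer sequences than seen during training (this was the original paper's hypothesis, not a guarantee)
- For any fixed offset $k$, $PE_{pos+k}$ can be expressed as a linear function of $PE_{pos}$ (captures relative positions)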
- Different frequencies capture different position scales
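The linear-function property follows from the angle-addition identities: for one frequency, moving from `pos` to `pos + k` is a 2x2 rotation that depends only on `k`. A minimal sketch, where the helper names `pe_pair` and `shift` and the chosen `pos`, `k`, and pair index are illustrative:

```python
import math

def pe_pair(pos: float, omega: float):
    """(sin, cos) pair of the encoding at one frequency omega."""
    return (math.sin(omega * pos), math.cos(omega * pos))

def shift(pair, k: float, omega: float):
    """Apply the 2x2 rotation that maps PE(pos) to PE(pos + k).

    The matrix entries depend only on the offset k and the frequency,
    not on the absolute position.
    """
    s, c = pair
    cos_k, sin_k = math.cos(omega * k), math.sin(omega * k)
    return (s * cos_k + c * sin_k, -s * sin_k + c * cos_k)

d_model = 128
i = 5  # index of the (sin, cos) pair
omega = 1.0 / (10000 ** (2 * i / d_model))

pos, k = 17, 4
direct = pe_pair(pos + k, omega)
via_rotation = shift(pe_pair(pos, omega), k, omega)

print("direct      :", direct)
print("via rotation:", via_rotation)
```

Both lines should agree up to floating-point rounding.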

<details>
<summary><b>Q: What happens if we don't add positional encoding?</b></summary>

**A:** The model treats input as a "bag of words" - it cannot distinguish `"dog bites man"` from `"man bites dog"`. Word order information is completely lost.
</details>
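The order-blindness can be seen in a toy example. The sketch below runs self-attention with $Q = K = V = x$ and no learned projections, a simplification of the real layer; the 2-d embeddings are made up for illustration.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(x):
    """Self-attention with Q = K = V = x (no learned projections)."""
    d_k = len(x[0])
    out = []
    for q in x:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in x]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, x)) for j in range(d_k)])
    return out

# Hypothetical 2-d embeddings for three tokens, with no positional encoding
emb = {"dog": [1.0, 0.0], "bites": [0.0, 1.0], "man": [0.5, 0.5]}

sent_a = ["dog", "bites", "man"]
sent_b = ["man", "bites", "dog"]

out_a = dict(zip(sent_a, self_attention([emb[t] for t in sent_a])))
out_b = dict(zip(sent_b, self_attention([emb[t] for t in sent_b])))

for token in emb:
    print(f"{token:5s} A: {out_a[token]}  B: {out_b[token]}")
```

Each token gets the same output vector in both sentences (up to rounding), so nothing downstream can tell who bit whom.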

### Key Takeaways - Positional Encoding
- Self-attention alone is order-blind (permutation equivariant: shuffling the inputs only shuffles the outputs)
- Positional encoding injects sequence order information
- Added to input embeddings, not concatenated
- Sinusoidal encodings were chosen in the hope of generalizing to longer sequences

In [None]:
class PositionalEncoding(nn.Module):
    """
    Sinusoidal Positional Encoding from "Attention Is All You Need"
    
    PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    """
    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        
        # Compute the inverse divisor 1 / 10000^(2i/d_model), in log space for numerical stability
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        # Apply sin to even indices, cos to odd indices
        pe[:, 0::2] = torch.sin(position * div_term)  # Even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # Odd dimensions
        
        # Register as buffer (not a parameter, but should be saved/loaded)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input embeddings of shape (batch, seq_len, d_model)
        Returns:
            x + positional encoding
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

# Visualize positional encoding
import matplotlib.pyplot as plt

pe = PositionalEncoding(d_model=128, max_len=100, dropout=0.0)
dummy_input = torch.zeros(1, 100, 128)
encoded = pe(dummy_input)

plt.figure(figsize=(12, 4))
plt.imshow(pe.pe[0].numpy().T, aspect='auto', cmap='RdBu')
plt.colorbar()
plt.xlabel('Position')
plt.ylabel('Dimension')
plt.title('Sinusoidal Positional Encoding')
plt.show()

## Part 5: The Transformer Block

### Architecture Overview

A transformer block combines all the pieces:

```
Input
  │
  ▼
┌─────────────────────────────┐
│   Multi-Head Self-Attention │
└─────────────────────────────┘
  │
  ▼ (+) ← Residual Connection
  │
  ▼ [LayerNorm]
  │
  ▼
┌─────────────────────────────┐
│   Feed-Forward Network      │
│   (Linear → ReLU → Linear)  │
└─────────────────────────────┘
  │
  ▼ (+) ← Residual Connection
  │
  ▼ [LayerNorm]
  │
Output
```

### Layer Normalization

LayerNorm normalizes across the feature dimension (not batch):

$$\text{LayerNorm}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

where $\mu, \sigma^2$ are computed across the feature dimension for each position independently.
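As a worked example of the formula, here is a minimal pure-Python sketch that normalizes a single position's feature vector. The function `layer_norm` is an illustrative stand-in, not the lab's implementation; it assumes the biased variance (dividing by the number of features) and `eps=1e-5`, matching the defaults of `nn.LayerNorm`.

```python
import math

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize one position's feature vector x (a list of floats)."""
    d = len(x)
    gamma = gamma if gamma is not None else [1.0] * d
    beta = beta if beta is not None else [0.0] * d
    mu = sum(x) / d
    var = sum((v - mu) ** 2 for v in x) / d  # biased variance over features
    return [g * (v - mu) / math.sqrt(var + eps) + b
            for v, g, b in zip(x, gamma, beta)]

features = [2.0, 4.0, 6.0, 8.0]
normed = layer_norm(features)

mean_out = sum(normed) / len(normed)
var_out = sum((v - mean_out) ** 2 for v in normed) / len(normed)
print([round(v, 4) for v in normed])
print(f"mean = {mean_out:.6f}, variance = {var_out:.6f}")
```

With $\gamma = 1$ and $\beta = 0$ the output has mean 0 and variance just under 1 (the small gap comes from $\epsilon$).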

**Why LayerNorm over BatchNorm?**
- BatchNorm requires batch statistics — problematic for variable-length sequences and small batches
- LayerNorm works identically at train and inference time
- Each position is normalized independently (no cross-sequence dependencies)

### Residual Connections

$$\text{output} = \text{sublayer}(x) + x$$

**Why residuals?**
1. **Gradient flow**: Direct path for gradients to flow backward (mitigates vanishing gradients)
2. **Identity mapping**: Easy for the network to learn "do nothing" for some layers
3. **Optimization**: Each layer only needs to learn the *residual* (difference from identity)

### Feed-Forward Network

$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$

- Inner dimension $d_{ff} = 4 \times d_{model}$ (typical choice)
- Acts as a "memory" that stores learned patterns
- Applied position-wise (same FFN for all positions)

<details>
<summary><b>Q: Why do we need the FFN? Isn't attention enough?</b></summary>

**A:** The attention output is a weighted sum of value vectors, so it is *linear* in the values; only the weights depend nonlinearly on the input (through softmax). The FFN adds a per-position **non-linearity** (ReLU) and lets the model transform each representation. Without the FFN, stacked attention layers could only re-mix linearly projected token vectors, which limits expressiveness.
</details>

<details>
<summary><b>Q: What would happen if we removed the residual connections?</b></summary>

**A:** Training would become very difficult, especially for deep transformers. Gradients would need to flow through many layers of attention and FFN, leading to vanishing gradients. The original transformer paper used 6 layers — without residuals, even this depth would be hard to train.
</details>

### Pre-LN vs Post-LN

Two variants exist:
- **Post-LN** (original): `x = LayerNorm(x + Sublayer(x))`
- **Pre-LN** (more stable): `x = x + Sublayer(LayerNorm(x))`

Pre-LN is easier to train (gradients flow directly through residuals) but may have slightly lower final performance. We implement Post-LN below (the original design).
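The difference between the two orderings can be sketched with plain functions on lists. Everything here is illustrative: `layer_norm` omits the learnable scale and shift, and `zero_sublayer` is a hypothetical stand-in for an attention or FFN sublayer that outputs zeros, used to expose what happens to the residual path.

```python
import math

def layer_norm(x, eps=1e-5):
    """LayerNorm without learnable scale/shift, for one feature vector."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def add(a, b):
    """Element-wise sum of two equal-length lists."""
    return [ai + bi for ai, bi in zip(a, b)]

def post_ln(x, sublayer):
    """Original ordering: normalize after the residual sum."""
    return layer_norm(add(x, sublayer(x)))

def pre_ln(x, sublayer):
    """Pre-LN ordering: normalize only the sublayer input."""
    return add(x, sublayer(layer_norm(x)))

def zero_sublayer(x):
    """Stand-in sublayer that outputs zeros (hypothetical, for illustration)."""
    return [0.0] * len(x)

x = [2.0, 4.0, 6.0, 8.0]
pre_out = pre_ln(x, zero_sublayer)
post_out = post_ln(x, zero_sublayer)
print("Pre-LN :", pre_out)   # residual path passes x through unchanged
print("Post-LN:", post_out)  # residual path is re-normalized
```

In Pre-LN the input reaches the output untouched when the sublayer contributes nothing, which is the direct residual path mentioned above. In Post-LN every block rescales that path.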

In [None]:
class FeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network.
    FFN(x) = max(0, xW1 + b1)W2 + b2
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)    # (d_model -> d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)    # (d_ff -> d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model)
        x = self.linear1(x)      # (batch, seq_len, d_ff)
        x = F.relu(x)            # Non-linearity
        x = self.dropout(x)
        x = self.linear2(x)      # (batch, seq_len, d_model)
        return x


class TransformerEncoderBlock(nn.Module):
    """
    A single Transformer Encoder Block (Post-LN variant, as in original paper).
    
    Architecture:
    x -> MultiHeadAttention -> Dropout -> (+x) -> LayerNorm -> FFN -> Dropout -> (+) -> LayerNorm -> output
    """
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Args:
            x: Input of shape (batch, seq_len, d_model)
            mask: Optional attention mask
        Returns:
            Output of shape (batch, seq_len, d_model)
        """
        # Self-attention sublayer with residual (Post-LN)
        attn_output = self.self_attn(x, x, x, mask)  # Q=K=V for self-attention
        # MultiHeadAttention already applies dropout to its output, so it is not applied again here
        x = self.norm1(x + attn_output)  # Residual + LayerNorm
        
        # FFN sublayer with residual (Post-LN)
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))   # Residual + LayerNorm
        
        return x

# Test the transformer block
block = TransformerEncoderBlock(d_model=512, n_heads=8, d_ff=2048)
x = torch.randn(2, 10, 512)  # (batch=2, seq_len=10, d_model=512)
output = block(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Number of parameters: {sum(p.numel() for p in block.parameters()):,}")
print(f"\nParameter breakdown:")
print(f"  Multi-head attention: {sum(p.numel() for p in block.self_attn.parameters()):,}")
print(f"  FFN: {sum(p.numel() for p in block.ffn.parameters()):,}")
print(f"  LayerNorms: {sum(p.numel() for p in block.norm1.parameters()) + sum(p.numel() for p in block.norm2.parameters()):,}")

## Part 6: Encoder vs Decoder Architectures

### Encoder (Bidirectional Self-Attention)
- Each position attends to **ALL** positions (including future)
- Used for understanding/encoding the input
- Example: BERT

### Decoder (Causal/Masked Self-Attention)
- Each position attends only to **PREVIOUS** positions (and itself)
- Prevents "cheating" - can't look at future tokens during generation
- Uses **causal mask**: scores above the diagonal of the attention matrix are set to $-\infty$ before the softmax, so those positions receive zero weight
- Example: GPT

### Encoder-Decoder (Cross-Attention)
- Decoder has **two** attention layers per block:
  1. **Masked self-attention** on decoder tokens
  2. **Cross-attention**: Q from decoder, K/V from encoder (this is Lab 8's attention!)
- Example: Original Transformer, T5

### Architecture Variants

| Model | Architecture | Pre-training Task | Best For |
|-------|--------------|-------------------|----------|
| **BERT** | Encoder-only | Masked Language Model | Classification, NER, QA |
| **GPT** | Decoder-only | Causal Language Model | Text generation |
| **T5** | Encoder-Decoder | Span corruption | Translation, summarization |

<details>
<summary><b>Q: Why can't we use bidirectional attention for text generation?</b></summary>

**A:** During training, the whole target sequence is fed in at once. If position $t$ could attend to later positions, the model would "cheat" by copying the token it is supposed to predict. At generation time those future tokens do not exist yet, so causal masking keeps training consistent with how the model is used.
</details>

In [None]:
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """
    Create a causal (lower triangular) mask for decoder self-attention.
    
    Returns:
        mask: (seq_len, seq_len) with 1s in lower triangle, 0s in upper triangle
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

# Visualize causal mask
mask = create_causal_mask(8)
plt.figure(figsize=(6, 6))
plt.imshow(mask.numpy(), cmap='Blues')
plt.colorbar()
plt.xlabel('Key Position')
plt.ylabel('Query Position')
plt.title('Causal Mask (1 = can attend, 0 = cannot)')
for i in range(8):
    for j in range(8):
        plt.text(j, i, f'{int(mask[i,j].item())}', ha='center', va='center')
plt.show()

print("Position 3 can attend to positions:", mask[3].nonzero().squeeze().tolist())

### Computational Complexity Analysis

Understanding transformer complexity is crucial for practical applications:

| Operation | Time Complexity | Space Complexity | Parallelizable? |
|-----------|-----------------|------------------|-----------------|
| **RNN** (per layer) | $O(n \cdot d^2)$ | $O(n \cdot d)$ | No (sequential) |
| **Self-Attention** | $O(n^2 \cdot d)$ | $O(n^2 + n \cdot d)$ | Yes |

Where $n$ = sequence length, $d$ = model dimension.

**Key insights:**
- For **short sequences** ($n < d$): Self-attention is more efficient
- For **long sequences** ($n > d$): The $n^2$ term dominates — this is why LLMs have context length limits!
- The $O(n^2)$ attention matrix is the bottleneck that drives research on efficient transformers (Linformer, Performer, FlashAttention)
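The crossover at $n = d$ follows directly from the dominant terms in the table. A rough sketch that ignores constant factors and lower-order terms (the function names are illustrative):

```python
def self_attention_ops(n: int, d: int) -> int:
    """Dominant term for one self-attention layer: n^2 * d."""
    return n * n * d

def rnn_ops(n: int, d: int) -> int:
    """Dominant term for one recurrent layer: n * d^2."""
    return n * d * d

d = 512
for n in (128, 512, 4096):
    ratio = self_attention_ops(n, d) / rnn_ops(n, d)  # simplifies to n / d
    print(f"n={n:5d}: attention / RNN cost ratio = {ratio:.2f}")
```

The ratio is simply $n / d$, so self-attention does less arithmetic than the recurrent layer below $n = d$ and more above it. This counts operations only; it says nothing about how well each layer parallelizes.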

<details>
<summary><b>Q: A transformer with d_model=1024 processes a sequence of length 4096. How many elements are in the attention matrix?</b></summary>

**A:** The attention matrix is $n \times n = 4096 \times 4096 \approx 16.8$ million elements. With 8 heads and batch size 32, that's $32 \times 8 \times 16.8M \approx 4.3$ billion elements just for attention weights! This is why long-context models require specialized techniques.
</details>
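Converting the element count above into memory makes the cost tangible. This sketch assumes the full attention matrix is materialized for one layer, at 4 bytes per element for float32 and 2 for float16; the function name is illustrative.

```python
def attention_matrix_bytes(batch: int, heads: int, seq_len: int,
                           bytes_per_element: int) -> int:
    """Memory needed to store every attention weight of one layer."""
    return batch * heads * seq_len * seq_len * bytes_per_element

elements = attention_matrix_bytes(32, 8, 4096, 1)  # 1 byte each = element count
fp32 = attention_matrix_bytes(32, 8, 4096, 4)
fp16 = attention_matrix_bytes(32, 8, 4096, 2)

GIB = 1024 ** 3
print(f"elements: {elements:,}")
print(f"float32 : {fp32 / GIB:.0f} GiB per layer")
print(f"float16 : {fp16 / GIB:.0f} GiB per layer")
```

That is 16 GiB per layer in float32 before counting activations, gradients, or weights.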

## Exercises

These exercises reinforce your understanding. Try them before moving to Part 7.

### Exercise 1: Attention Weight Visualization
```python
# TODO: Implement a function that:
# 1. Takes a trained attention layer and input sequence
# 2. Returns attention weights as a heatmap
# 3. Visualizes which positions attend to which

def visualize_attention_weights(attn_layer, x):
    """
    Visualize attention patterns for input x.
    
    Args:
        attn_layer: MultiHeadAttention module
        x: Input tensor (1, seq_len, d_model)
    """
    # Your implementation here
    pass
```

### Exercise 2: Ablation Study
Modify `TransformerEncoderBlock` to create these variants and observe the effects:
1. Remove residual connections
2. Remove LayerNorm
3. Remove the FFN entirely

Train each variant on a simple task (e.g., copying) and compare:
- Training stability
- Final loss
- Gradient magnitudes

### Exercise 3: Positional Encoding Experiments
1. What happens if you multiply positional encodings by 10? By 0.1?
2. What if you use only sine (no cosine)?
3. Implement **learned** positional embeddings and compare with sinusoidal

### Exercise 4: Build a Decoder Block
Implement `TransformerDecoderBlock` with:
- Masked self-attention (causal)
- Cross-attention (Q from decoder, K/V from encoder)
- FFN with residuals

```python
class TransformerDecoderBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # TODO: Initialize layers
        # - Masked self-attention
        # - Cross-attention  
        # - FFN
        # - Three LayerNorms
        pass
    
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # TODO: Implement forward pass
        pass
```

## Part 7: Pre-trained Transformers

### The Pre-training Revolution

Training transformers from scratch requires **massive compute** (GPT-3: an estimated ~$5M of compute on thousands of GPUs).

Instead, we use **pre-trained models**:
1. **Pre-training**: Train on huge text corpus with self-supervised objective
2. **Fine-tuning**: Adapt to specific task with smaller labeled dataset

### Why Pre-training Works

- **Lower layers** learn general language features (syntax, word relationships)
- **Higher layers** learn more task-specific features
- Fine-tuning adjusts all layers but mostly changes the higher ones
- Transfer learning: knowledge from pre-training transfers to downstream tasks

### Self-Assessment Checklist

By the end of the core content, you should be able to:
- [ ] Explain why self-attention replaces recurrence
- [ ] Derive scaled dot-product attention and explain the $\sqrt{d_k}$ scaling
- [ ] Describe multi-head attention and why we use multiple heads
- [ ] Explain why positional encoding is necessary
- [ ] Draw the transformer block architecture and explain each component
- [ ] Distinguish encoder-only, decoder-only, and encoder-decoder architectures

## References

1. Vaswani et al. (2017). "Attention Is All You Need" - [arXiv:1706.03762](https://arxiv.org/abs/1706.03762)
2. Devlin et al. (2019). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
3. Radford et al. (2018, 2019) and Brown et al. (2020). GPT, GPT-2, and GPT-3 papers
4. Raffel et al. (2020). "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer"
5. [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/) - Jay Alammar (excellent visual guide!)

## Training Transformers: Practical Considerations

Before diving into fine-tuning, understand what makes transformer training work:

### Learning Rate Schedule
Transformers are sensitive to learning rate. The original paper uses **warmup + decay**:
```
lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
```
- Increase the LR linearly over the first `warmup_steps` steps (warmup prevents early instability)
- Then decay it in proportion to the inverse square root of the step number
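The schedule formula translates directly into code. A minimal sketch, assuming `d_model=512` and `warmup_steps=4000` (the base-model settings from the original paper); the function name `transformer_lr` is illustrative, and the step is clamped to 1 to avoid dividing by zero.

```python
import math

def transformer_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Learning rate from the warmup + decay formula; step is clamped to >= 1."""
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

for step in (1, 1000, 4000, 16000, 100000):
    print(f"step {step:6d}: lr = {transformer_lr(step):.6f}")
```

The rate peaks at `step == warmup_steps` and halves each time the step count quadruples after that.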

### Other Key Techniques
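- **Dropout**: Applied to each sublayer's output before the residual connection, as in this lab's code; many implementations also apply it to the attention weights (typically 0.1)
- **Label smoothing**: Prevents overconfidence, improves generalization
- **Gradient clipping**: Prevents exploding gradients (e.g., clip the global gradient norm at 1.0)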
- **Weight initialization**: Xavier/Glorot for linear layers
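Two of these techniques are small enough to sketch in pure Python. The helpers below are illustrative, not library functions: `smooth_labels` uses one common convention (mix the one-hot target with a uniform distribution over all classes), and `clip_by_global_norm` operates on a flat list of gradient values rather than on tensors.

```python
import math

def smooth_labels(target: int, n_classes: int, epsilon: float = 0.1):
    """Mix a one-hot target with a uniform distribution over all classes."""
    uniform = epsilon / n_classes
    return [(1.0 - epsilon) + uniform if c == target else uniform
            for c in range(n_classes)]

def clip_by_global_norm(grads, max_norm: float = 1.0):
    """Rescale gradient values so their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    return [g * max_norm / norm for g in grads]

smoothed = smooth_labels(target=2, n_classes=5)
clipped = clip_by_global_norm([3.0, 4.0])  # norm 5 gets scaled down to norm 1

print("smoothed target:", [round(p, 3) for p in smoothed])
print("clipped grads  :", [round(g, 3) for g in clipped])
```

Smoothing leaves the correct class with most of the probability mass but never all of it. Clipping preserves the gradient direction and shrinks only its length.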

### Common Pitfalls
1. **Forgetting the scale factor** — attention weights become spiky, gradients vanish
2. **Wrong mask dimensions** — should broadcast to (batch, heads, seq, seq)
3. **Softmax on wrong dimension** — must be over keys (last dim), not queries
4. **Memory issues** — attention is O(n²), reduce batch size for long sequences

---

# [Optional] Advanced: Fine-tuning Pre-trained Transformers

> **Note:** This section is **optional**. The core transformer concepts are complete above. 
> Continue here when you're ready to see how pre-trained models are used in practice.

---

## Practical Application: T5 for Summarization

Now that you understand transformer architecture, let's see it in action. We'll fine-tune **T5** (Text-to-Text Transfer Transformer) for summarization.

**Why T5?**
- Encoder-decoder architecture (like the original transformer)
- Uses text-to-text framework: all tasks are "text in, text out"
- Task specified by prefix: `"summarize: "`, `"translate English to German: "`, etc.

**What you'll learn:**
- How HuggingFace simplifies transformer usage
- The fine-tuning workflow
- Evaluation with ROUGE metrics

## Step 1: Setup and Data

First, let's install the required libraries and load a summarization dataset.

In [None]:
# Install required packages (uncomment if needed)
# !pip install transformers datasets evaluate rouge_score

from datasets import load_dataset
from transformers import AutoTokenizer

# Load BillSum dataset (US Congressional and California state bills with summaries)
# "ca_test" is the California test split; we use a small subset for demonstration
billsum = load_dataset("billsum", split="ca_test[:100]")  # 100 examples

print(f"Dataset size: {len(billsum)}")
print(f"Example keys: {billsum[0].keys()}")
print(f"\nExample text (first 200 chars): {billsum[0]['text'][:200]}...")
print(f"\nExample summary: {billsum[0]['summary'][:200]}...")

## Step 2: Tokenization

T5 expects inputs in a specific format with a task prefix.

In [None]:
# Load T5 tokenizer
checkpoint = "google-t5/t5-small"  # Smaller model for faster training
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# T5 uses prefix to specify the task
prefix = "summarize: "

def preprocess_function(examples):
    """Tokenize inputs and targets for T5."""
    # Add prefix to inputs
    inputs = [prefix + doc for doc in examples["text"]]
    
    # Tokenize inputs. No padding here: the data collator pads each batch
    # at training time and fills label padding with -100, so the loss ignores it.
    model_inputs = tokenizer(
        inputs, 
        max_length=512,  # T5 was pre-trained on inputs of up to 512 tokens
        truncation=True
    )
    
    # Tokenize targets (summaries)
    labels = tokenizer(
        examples["summary"],
        max_length=128,
        truncation=True
    )
    
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Apply preprocessing
tokenized_data = billsum.map(preprocess_function, batched=True)
print(f"Tokenized dataset columns: {tokenized_data.column_names}")

## Step 3: Evaluation Metrics - ROUGE

**ROUGE** (Recall-Oriented Understudy for Gisting Evaluation) measures summary quality:

| Metric | What it measures |
|--------|------------------|
| **ROUGE-1** | Unigram (single word) overlap |
| **ROUGE-2** | Bigram (two consecutive words) overlap |
| **ROUGE-L** | Longest common subsequence |

Higher scores = better overlap with reference summary.

<details>
<summary><b>Q: Why use ROUGE instead of just accuracy?</b></summary>

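**A:** Summarization has many valid outputs; there's no single "correct" answer. ROUGE gives partial credit for word and phrase overlap, so a prediction can score well without matching the reference exactly. For example, "The cat sat" and "A cat was sitting" would both partially match a reference about a sitting cat. Because it counts only surface overlap, ROUGE still undervalues paraphrases that use different words.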
</details>
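
To make the table concrete, here is a minimal sketch that computes ROUGE-1, ROUGE-2, and ROUGE-L F1 scores by hand. It assumes lowercase whitespace tokenization and no stemming, and the helper names (`ngram_f1`, `lcs_length`, `rouge_l_f1`) are illustrative. The `evaluate` implementation used in the next cell tokenizes and aggregates differently, so its numbers will not match these exactly.

```python
from collections import Counter

def ngrams(tokens, n):
    """Count every contiguous n-gram in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def f1(overlap, pred_total, ref_total):
    """F1 from an overlap count and the prediction/reference totals."""
    if overlap == 0 or pred_total == 0 or ref_total == 0:
        return 0.0
    precision = overlap / pred_total
    recall = overlap / ref_total
    return 2 * precision * recall / (precision + recall)

def ngram_f1(prediction, reference, n):
    """ROUGE-n style F1 using clipped n-gram counts."""
    pred = ngrams(prediction.lower().split(), n)
    ref = ngrams(reference.lower().split(), n)
    overlap = sum((pred & ref).values())  # Counter & keeps the minimum count per n-gram
    return f1(overlap, sum(pred.values()), sum(ref.values()))

def lcs_length(a, b):
    """Length of the longest common subsequence, via row-by-row dynamic programming."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l_f1(prediction, reference):
    """ROUGE-L style F1 based on the longest common subsequence."""
    p, r = prediction.lower().split(), reference.lower().split()
    return f1(lcs_length(p, r), len(p), len(r))

reference = "the cat sat on the mat"
prediction = "the cat lay on the mat"

print(round(ngram_f1(prediction, reference, 1), 3))  # 0.833 (5 of 6 unigrams match)
print(round(ngram_f1(prediction, reference, 2), 3))  # 0.6   (3 of 5 bigrams match)
print(round(rouge_l_f1(prediction, reference), 3))   # 0.833 (LCS has 5 tokens)
```

One changed word costs a single unigram but two bigrams, which is why ROUGE-2 drops faster than ROUGE-1 here.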

In [None]:
import evaluate
import numpy as np

rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    """Compute ROUGE scores for model predictions."""
    predictions, labels = eval_pred
    
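    # Decode predictions. Generated ids may be padded with -100 when batches
    # are gathered, so map those to the pad token id before decoding.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)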
    
    # Replace -100 in labels (padding) with pad token id
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Compute ROUGE scores
    result = rouge.compute(
        predictions=decoded_preds, 
        references=decoded_labels, 
        use_stemmer=True  # Use word stems for more flexible matching
    )
    
    return {k: round(v * 100, 2) for k, v in result.items()}

## Step 4: Training with HuggingFace Trainer

The `Trainer` class handles the training loop, evaluation, and checkpointing.

In [None]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq

# Load pre-trained T5 model
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

# Data collator handles dynamic padding
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./t5_summarization",
    eval_strategy="no",  # Skip evaluation for demo
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    num_train_epochs=1,  # Just 1 epoch for demo
    weight_decay=0.01,
    save_total_limit=1,
    predict_with_generate=True,  # Use generation for evaluation
    logging_steps=10,
    report_to="none",  # Disable wandb/tensorboard for demo
)

# Create trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print("Trainer ready! Run trainer.train() to start fine-tuning.")
print("(This may take several minutes even on GPU)")
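
The data collator comment above mentions dynamic padding. As a rough illustration only (this is not HuggingFace's code, and `pad_batch` is a hypothetical helper), the idea is to pad each batch to its own longest sequence and to fill label padding with -100, the index that PyTorch's cross-entropy loss ignores by default. The pad id of 0 is an assumption that matches T5's pad token, and the token ids are made up.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Pad a list of {"input_ids", "labels"} dicts to the longest item in this batch."""
    max_input = max(len(f["input_ids"]) for f in features)
    max_label = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_real = len(f["input_ids"])
        n_pad = max_input - n_real
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding the model should not attend to
        batch["attention_mask"].append([1] * n_real + [0] * n_pad)
        # -100 marks label positions that should not contribute to the loss
        n_label_pad = max_label - len(f["labels"])
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_label_pad)
    return batch

# Two toy examples with made-up token ids (1 stands in for the end-of-sequence id)
features = [
    {"input_ids": [21, 22, 23, 1], "labels": [31, 1]},
    {"input_ids": [24, 1], "labels": [32, 33, 34, 1]},
]
batch = pad_batch(features)

print(batch["input_ids"])       # [[21, 22, 23, 1], [24, 1, 0, 0]]
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 0, 0]]
print(batch["labels"])          # [[31, 1, -100, -100], [32, 33, 34, 1]]
```

The padded width here is 4, the longest sequence in the batch, rather than a fixed maximum such as 512, which is where the memory and speed savings come from.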

## Step 5: Inference with Fine-tuned Model

After training (or using a pre-trained model), generate summaries:

In [None]:
# For the demo, skip training time and load a checkpoint that is already fine-tuned for summarization.
# Note: this is a distilled BART model fine-tuned on CNN/DailyMail news, not the T5 model set up above.
from transformers import pipeline

summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")

# Example text
text = """
The Inflation Reduction Act lowers prescription drug costs, health care costs, 
and energy costs. It's the most aggressive action on tackling the climate crisis 
in American history, which will lift up American workers and create good-paying, 
union jobs across the country. It'll lower the deficit and ask the ultra-wealthy 
and corporations to pay their fair share. And no one making under $400,000 per 
year will pay a penny more in taxes.
"""

# Generate summary
summary = summarizer(text, max_length=50, min_length=20, do_sample=False)
print("Original text:")
print(text)
print("\nGenerated summary:")
print(summary[0]['summary_text'])

## Summary: What You've Learned

### Core Transformer Concepts (Parts 1-6)
1. **Self-attention** replaces recurrence, enabling parallelization
2. **Scaled dot-product attention**: $\text{softmax}(QK^\top/\sqrt{d_k})V$
3. **Multi-head attention** captures different relationship types
4. **Positional encoding** injects sequence order information
5. **Transformer block** = Attention + FFN + Residuals + LayerNorm
6. **Architecture variants**: Encoder-only, decoder-only, encoder-decoder

### Practical Skills (Part 7 - Optional)
- Fine-tuning pre-trained transformers with HuggingFace
- Evaluating text generation with ROUGE metrics
- Using the `pipeline` API for quick inference

## What's Next

Transformers have revolutionized not just NLP but also:
- **Computer Vision**: Vision Transformers (ViT), CLIP
- **Multimodal**: GPT-4V, Gemini, LLaVA
- **Science**: AlphaFold (protein structure prediction)

The attention mechanism you learned today is the foundation of all these advances!