# Part 6.1: Transformer Architecture

The Transformer is the architecture behind GPT, BERT, T5, and virtually every state-of-the-art model in NLP, vision, and beyond. Introduced in the landmark 2017 paper *"Attention Is All You Need"*, it replaced recurrent networks with a design built entirely on the attention mechanisms you learned in the previous notebook. In this notebook, you will understand every piece of the Transformer and build one from scratch.

**F1 analogy:** Think of the Transformer as F1's modern strategy computer -- the system that ingests all available data streams (tire wear, fuel load, weather, gap to rivals) **in parallel** and produces real-time strategy calls. Older RNN-based models are like the pre-radio era when information traveled sequentially through pit boards, one lap at a time. The Transformer processes everything at once, just as a modern F1 strategy wall sees every car's telemetry simultaneously.

**Prerequisites:** Notebook 16 (Attention Mechanisms) -- you should be comfortable with Q/K/V attention, self-attention, multi-head attention, and causal masking.

---

## Learning Objectives

By the end of this notebook, you will be able to:

- [ ] Explain why Transformers replaced RNNs and the key innovations of the architecture
- [ ] Implement sinusoidal positional encoding and explain why position information is needed
- [ ] Build an Encoder block from scratch (multi-head attention + Add&Norm + FFN)
- [ ] Build a Decoder block from scratch (masked self-attention + cross-attention + FFN)
- [ ] Assemble a complete Transformer model in PyTorch
- [ ] Train a Transformer on a simple sequence task and visualize attention patterns
- [ ] Analyze parameter counts and explain why Transformers scale so well

---

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math

%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')
torch.manual_seed(42)
np.random.seed(42)

print("Setup complete!")
print(f"PyTorch version: {torch.__version__}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

---

## 1. "Attention Is All You Need" -- The Big Picture

### Intuitive Explanation

Before Transformers, sequence models were dominated by **Recurrent Neural Networks** (RNNs, LSTMs, GRUs). These process tokens one at a time, left to right, passing a hidden state from step to step like a game of telephone. This sequential nature creates two fundamental problems:

1. **No parallelization:** You cannot process token 5 until you have finished tokens 1-4. Training is painfully slow.
2. **Long-range dependencies fade:** Information from early tokens must survive through many sequential steps. Even with LSTMs, signals degrade over long distances.

The Transformer's key insight is radical: **throw away recurrence entirely.** Instead of processing tokens sequentially, let every token attend to every other token simultaneously using attention. This means:

- **Full parallelization:** All positions are processed at once during training
- **Direct connections:** Token 1 can directly attend to token 100 with no intermediaries
- **Constant path length:** Information travels in O(1) steps, not O(n)
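
The contrast above can be sketched in a few lines. This is a toy illustration under stated assumptions: `nn.RNNCell` stands in for any recurrent model, the attention uses the raw token vectors as queries, keys, and values (no learned projections), and the names `seq`, `rnn_cell`, and `attn_out` are made up for this example.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, d = 6, 8                      # toy sizes, assumed for illustration
seq = torch.randn(seq_len, d)          # one sequence of 6 token vectors

# RNN: the loop cannot be avoided, because step t needs the hidden state from step t-1.
rnn_cell = nn.RNNCell(d, d)
h = torch.zeros(1, d)
rnn_steps = 0
for t in range(seq_len):
    h = rnn_cell(seq[t].unsqueeze(0), h)
    rnn_steps += 1

# Self-attention: one matrix product scores every pair of positions at once.
scores = seq @ seq.T / math.sqrt(d)            # (seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)
attn_out = weights @ seq                       # (seq_len, d)

print(f"RNN sequential steps: {rnn_steps}")
print(f"Attention score matrix: {tuple(scores.shape)}")
print(f"Weight from token 1 to token 6: {weights[0, -1].item():.3f}")
```

The RNN needs `seq_len` dependent steps before token 1 can influence token 6, while the attention matrix already holds a direct weight between them.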

**F1 analogy:** Imagine an RNN as a team that relays information via pit boards -- the driver sees one message per lap, and by lap 50 the message from lap 1 has been distorted through 49 handoffs. A Transformer is like modern F1 radio and telemetry: the strategy wall can instantly access data from *any* lap, *any* sector, *any* car -- all at once, with no degradation. That is why Transformers dominate: they process all positions in parallel, just as a modern F1 strategy system processes all data streams simultaneously.

### Comparison: RNN vs Transformer

| Property | RNN/LSTM | Transformer | F1 Parallel |
|----------|----------|-------------|-------------|
| **Processing** | Sequential (one token at a time) | Parallel (all tokens at once) | Pit boards (one message/lap) vs. full telemetry dashboard |
| **Long-range dependencies** | Degrades over distance | Direct attention at any distance | Forgetting early stint data vs. instant lap-1 recall |
| **Path length** | O(n) between distant tokens | O(1) between any two tokens | Information relay chain vs. direct radio link |
| **Training speed** | Slow (sequential bottleneck) | Fast (GPU-parallelizable) | Slow debriefs vs. real-time analytics |
| **Memory** | Fixed hidden state size | Grows with sequence length (O(n^2)) | Limited pit board vs. full data storage |
| **Positional info** | Built-in (sequential processing) | Must be explicitly added | Implicit lap count vs. explicit position encoding |

**The trade-off:** Attention costs O(n^2) time and memory in the sequence length n, because every token attends to every other token. For typical sequence lengths, the gains in parallelization and quality outweigh that cost.
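
To get a feel for the quadratic cost, here is a back-of-the-envelope sketch. It assumes the full attention weight matrix is materialized in float32 for a single sequence with 8 heads; `attention_matrix_bytes` is a hypothetical helper written for this example, not part of any library.

```python
def attention_matrix_bytes(seq_len, n_heads=8, bytes_per_value=4):
    """Bytes for one layer's attention weights for a single sequence (float32 by default)."""
    return n_heads * seq_len * seq_len * bytes_per_value

for n in (512, 2048, 8192):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"seq_len={n:>5}: {mib:,.0f} MiB of attention weights per layer")
```

Each 4x increase in sequence length multiplies this figure by 16, which is the quadratic growth in action.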

### The High-Level Architecture

The original Transformer has an **encoder-decoder** structure:

```
INPUT TOKENS                          OUTPUT TOKENS
     |                                      |
  [Embedding + Positional Encoding]    [Embedding + Positional Encoding]
     |                                      |
  ┌──────────────┐                    ┌──────────────┐
  │  Encoder      │                    │  Decoder      │
  │  Block x N    │──────────────────▶│  Block x N    │
  │               │  (cross-attention) │               │
  └──────────────┘                    └──────────────┘
                                           |
                                    [Linear + Softmax]
                                           |
                                    OUTPUT PROBABILITIES
```

**Encoder:** Reads the full input and creates rich contextual representations.
**Decoder:** Generates output one token at a time, attending to both its own previous outputs AND the encoder's representations.
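
As a rough sketch of how the two halves connect, the snippet below uses PyTorch's built-in `nn.Transformer` rather than the from-scratch model built in this notebook. The sizes are arbitrary, and the inputs are random tensors standing in for tokens that have already been embedded; `nn.Transformer` itself contains neither the embedding layers nor the final linear + softmax.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Transformer(d_model=32, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=2, dim_feedforward=64,
                       dropout=0.0, batch_first=True)

src = torch.randn(2, 10, 32)   # (batch, source length, d_model): embedded inputs
tgt = torch.randn(2, 7, 32)    # (batch, target length, d_model): embedded outputs

# Causal mask: -inf above the diagonal blocks attention to later target positions.
tgt_mask = torch.triu(torch.full((7, 7), float('-inf')), diagonal=1)

memory = model.encoder(src)                           # one contextual vector per input token
out = model.decoder(tgt, memory, tgt_mask=tgt_mask)   # decoder reads `memory` via cross-attention

print(f"Encoder output: {tuple(memory.shape)}")
print(f"Decoder output: {tuple(out.shape)}")
```

The encoder output keeps the source length, and the decoder output keeps the target length: the two sequences never need to be the same size.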

**F1 analogy:** The encoder is like the **telemetry and data acquisition system** -- it ingests all raw sensor data (tire temps, throttle traces, GPS coordinates) and builds a rich, contextual picture of the car's state. The decoder is the **strategy engineer** who translates that encoded picture into actionable calls: "Box this lap," "Switch to hards," "Push now." Cross-attention is the bridge -- how the strategy engineer *reads* the telemetry to make decisions.

Not all modern models use both sides:
- **GPT** uses only the decoder (autoregressive language modeling)
- **BERT** uses only the encoder (bidirectional understanding)
- **T5, BART** use the full encoder-decoder

In [None]:
# Visualization: RNN vs Transformer information flow
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# --- RNN: Sequential information flow ---
ax = axes[0]
n_tokens = 6
positions = np.arange(n_tokens)

# Draw tokens
for i in positions:
    circle = plt.Circle((i, 0), 0.3, fill=True, color='steelblue', alpha=0.8)
    ax.add_patch(circle)
    ax.text(i, 0, f't{i+1}', ha='center', va='center', fontsize=10, fontweight='bold', color='white')

# Draw sequential arrows
for i in range(n_tokens - 1):
    ax.annotate('', xy=(i+0.7, 0), xytext=(i+0.3, 0),
                arrowprops=dict(arrowstyle='->', color='red', lw=2))

# Show fading signal
for i in positions:
    alpha = max(0.1, 1.0 - i * 0.18)
    ax.add_patch(plt.Circle((i, 0), 0.3, fill=False, edgecolor='red', 
                             linewidth=2, alpha=alpha))

ax.set_xlim(-0.8, n_tokens - 0.2)
ax.set_ylim(-1.5, 1.5)
ax.set_title('RNN: Sequential Processing\n(signal fades over distance)', fontsize=13, fontweight='bold')
ax.set_aspect('equal')
ax.axis('off')

# --- Transformer: All-to-all attention ---
ax = axes[1]

# Draw tokens
for i in positions:
    circle = plt.Circle((i, 0), 0.3, fill=True, color='steelblue', alpha=0.8)
    ax.add_patch(circle)
    ax.text(i, 0, f't{i+1}', ha='center', va='center', fontsize=10, fontweight='bold', color='white')

# Draw attention connections (all-to-all)
for i in positions:
    for j in positions:
        if i != j:
            ax.plot([i, j], [0, 0], color='green', alpha=0.15, lw=1.5)

# Highlight a specific long-range connection
ax.annotate('', xy=(5-0.3, 0.15), xytext=(0+0.3, 0.15),
            arrowprops=dict(arrowstyle='->', color='green', lw=2.5))
ax.text(2.5, 0.5, 'Direct connection!', ha='center', fontsize=10, color='green', fontweight='bold')

ax.set_xlim(-0.8, n_tokens - 0.2)
ax.set_ylim(-1.5, 1.5)
ax.set_title('Transformer: Parallel Attention\n(direct path between any tokens)', fontsize=13, fontweight='bold')
ax.set_aspect('equal')
ax.axis('off')

plt.tight_layout()
plt.savefig('rnn_vs_transformer.png', dpi=100, bbox_inches='tight')
plt.show()
print("Left: RNN processes sequentially -- information from t1 must pass through every step to reach t6.")
print("Right: Transformer connects all tokens directly via attention -- t1 and t6 are just one step apart.")

---

## 2. Positional Encoding

### The Problem: Attention Has No Sense of Order

Here is a crucial insight: **self-attention is permutation equivariant.** If you shuffle the input tokens, the attention outputs are shuffled in exactly the same way -- the mechanism itself does not know or care about token order.

Think about it: `Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V`. The computation depends only on the *content* of the vectors, not their *position*. The sentence "the cat sat on the mat" and "mat the on sat cat the" would produce identical attention patterns (just rearranged).
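
The shuffling claim can be checked numerically. The sketch below assumes the simplest possible self-attention, with Q = K = V = x and no learned weights; `self_attention` is a helper defined only for this example.

```python
import math
import torch

torch.manual_seed(0)

def self_attention(x):
    """Plain scaled dot-product self-attention with Q = K = V = x (no learned weights)."""
    scores = x @ x.T / math.sqrt(x.size(-1))
    return torch.softmax(scores, dim=-1) @ x

x = torch.randn(6, 16)        # 6 tokens, 16 features each
perm = torch.randperm(6)      # a random reordering of the tokens

out_then_shuffle = self_attention(x)[perm]   # attend first, then reorder the outputs
shuffle_then_out = self_attention(x[perm])   # reorder the inputs, then attend

print(torch.allclose(out_then_shuffle, shuffle_then_out, atol=1e-5))
```

Both routes give the same result up to floating-point error, so nothing in the computation depends on where a token sits in the sequence.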

But word order matters enormously! "The dog bit the man" and "The man bit the dog" have very different meanings. We need to inject position information somehow.

**The solution:** Add a **positional encoding** to each token embedding before feeding it into the Transformer. Each position gets a unique signal that the model can use to determine where tokens are in the sequence.

**F1 analogy:** Positional encoding is like **lap number and track position encoding**. Raw telemetry data (speed, throttle, brake) means nothing without knowing *when* and *where* it was recorded. A speed of 320 km/h means something very different on lap 1 (opening lap with cold tires) versus lap 50 (late-race tire degradation), and at the start/finish straight versus the apex of turn 1. Just as the strategy computer tags every data point with its lap number and track sector, the Transformer tags every token with its position in the sequence.

### Sinusoidal Positional Encoding

The original Transformer paper uses a clever mathematical approach: encode each position using sine and cosine waves at different frequencies.

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

#### Breaking down the formula:

| Component | Meaning | Intuition | F1 Analogy |
|-----------|---------|-----------|------------|
| $pos$ | Position in the sequence (0, 1, 2, ...) | Which token are we encoding? | Lap number in the race |
| $i$ | Dimension-pair index (0, 1, ..., $d_{model}/2 - 1$) | Which "frequency channel"? | Different telemetry channels (speed, tire temp, fuel) |
| $d_{model}$ | Model dimension (e.g., 512) | Total embedding size | Total number of sensor channels |
| $10000^{2i/d_{model}}$ | Wavelength scaling | Low dimensions = fast waves, high dimensions = slow waves | Fast-changing data (throttle) vs. slow-changing data (fuel load) |

**What this means:** Each position is encoded as a point on many sine/cosine waves of different frequencies. Low-dimension pairs oscillate rapidly (capturing fine-grained position), while high-dimension pairs oscillate slowly (capturing broad position). Together they create a **unique fingerprint** for every position.

**F1 analogy:** Think of this as encoding race position using multiple time scales simultaneously. Fast-oscillating dimensions capture *which sector* of the lap you are in (changes every few seconds). Slow-oscillating dimensions capture *which stint* you are in (changes every 15-25 laps). Together, they uniquely identify any moment in the race.
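
To make "fast" and "slow" concrete, the sketch below computes how many positions each dimension pair needs to complete one full cycle. It assumes `d_model = 64` as a toy size; `wavelength` is a helper written for this example, derived from the formula above (a sine with argument `pos / 10000^(2i/d_model)` repeats every `2*pi * 10000^(2i/d_model)` positions).

```python
import math

d_model = 64  # toy model dimension, assumed for illustration

def wavelength(i, d_model):
    """Positions needed for dimension pair i to complete one full sine cycle."""
    return 2 * math.pi * 10000 ** (2 * i / d_model)

for i in (0, 5, 15, 31):
    w = wavelength(i, d_model)
    print(f"pair i={i:>2} (dims {2*i}, {2*i+1}): wavelength ~ {w:,.1f} positions")
```

The first pair repeats roughly every 6 positions, while the last pair changes so slowly that it is almost monotonic over any realistic sequence.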

**Why sin/cos?** Two key reasons:
1. **Unique encoding:** Every position gets a distinct vector
2. **Relative positions are linear:** $PE_{pos+k}$ can be expressed as a linear function of $PE_{pos}$ (rotation in 2D), so the model can easily learn to attend to relative offsets

In [None]:
# Implement sinusoidal positional encoding
def get_positional_encoding(max_len, d_model):
    """
    Generate sinusoidal positional encoding.
    
    Args:
        max_len: Maximum sequence length
        d_model: Model dimension (must be even)
    
    Returns:
        Tensor of shape (max_len, d_model)
    """
    assert d_model % 2 == 0, "d_model must be even"
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
    
    # Compute the inverse frequencies: 1 / 10000^(2i/d_model)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    
    # Even dimensions: sin, Odd dimensions: cos
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    
    return pe

# Generate positional encoding
max_len = 100
d_model = 64
pe = get_positional_encoding(max_len, d_model)
print(f"Positional encoding shape: {pe.shape}")
print(f"\nFirst position (pos=0):")
print(f"  First 8 values: {pe[0, :8].numpy().round(3)}")
print(f"\nSecond position (pos=1):")
print(f"  First 8 values: {pe[1, :8].numpy().round(3)}")

In [None]:
# Visualization: Positional Encoding Heatmap
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Full heatmap
ax = axes[0]
im = ax.imshow(pe[:50, :].numpy(), cmap='RdBu', aspect='auto', interpolation='nearest')
ax.set_xlabel('Embedding Dimension', fontsize=11)
ax.set_ylabel('Position', fontsize=11)
ax.set_title('Positional Encoding Heatmap\n(first 50 positions)', fontsize=13, fontweight='bold')
plt.colorbar(im, ax=ax, shrink=0.8)

# Individual dimensions showing different frequencies
ax = axes[1]
positions = np.arange(max_len)
dims_to_show = [0, 1, 10, 11, 30, 31]
colors = ['blue', 'blue', 'green', 'green', 'red', 'red']
styles = ['-', '--', '-', '--', '-', '--']

for dim, color, style in zip(dims_to_show, colors, styles):
    label = f'dim {dim} ({"sin" if dim % 2 == 0 else "cos"})'
    ax.plot(positions, pe[:, dim].numpy(), color=color, linestyle=style, 
            alpha=0.7, label=label, linewidth=1.5)

ax.set_xlabel('Position', fontsize=11)
ax.set_ylabel('Encoding Value', fontsize=11)
ax.set_title('Different Dimensions = Different Frequencies', fontsize=13, fontweight='bold')
ax.legend(fontsize=8, loc='upper right')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('positional_encoding.png', dpi=100, bbox_inches='tight')
plt.show()

print("Left: Each row is one position's encoding vector. Notice the wave patterns across dimensions.")
print("Right: Low dimensions (blue) oscillate fast, high dimensions (red) oscillate slowly.")
print("This creates a unique 'fingerprint' for every position.")

In [None]:
# Visualization: Each position has a unique encoding
# Show this by computing pairwise distances between position encodings
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Cosine similarity between positions
pe_norm = pe / pe.norm(dim=1, keepdim=True)
cos_sim = (pe_norm @ pe_norm.T).numpy()

ax = axes[0]
n_show = 50
im = ax.imshow(cos_sim[:n_show, :n_show], cmap='viridis', aspect='equal')
ax.set_xlabel('Position', fontsize=11)
ax.set_ylabel('Position', fontsize=11)
ax.set_title('Cosine Similarity Between\nPosition Encodings', fontsize=13, fontweight='bold')
plt.colorbar(im, ax=ax, shrink=0.8)

# Show that nearby positions are more similar
ax = axes[1]
ref_positions = [0, 10, 25, 40]
colors = ['blue', 'green', 'orange', 'red']
for ref_pos, color in zip(ref_positions, colors):
    similarities = cos_sim[ref_pos, :n_show]
    ax.plot(range(n_show), similarities, color=color, alpha=0.8, 
            label=f'Similarity to pos {ref_pos}', linewidth=1.5)

ax.set_xlabel('Position', fontsize=11)
ax.set_ylabel('Cosine Similarity', fontsize=11)
ax.set_title('Nearby Positions Are More Similar', fontsize=13, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('positional_uniqueness.png', dpi=100, bbox_inches='tight')
plt.show()

print("Left: Each position has a unique encoding (no two rows are identical).")
print("Right: Similarity peaks at the reference position and decays with distance.")
print("This smooth decay means the model can learn relative position relationships.")

### Learned vs Sinusoidal Positional Embeddings

| Approach | How it works | Pros | Cons | F1 Analogy |
|----------|-------------|------|------|------------|
| **Sinusoidal** (original paper) | Fixed sin/cos functions | Can extrapolate to longer sequences; no extra parameters | Slightly less flexible | Fixed GPS coordinates -- works for any circuit length |
| **Learned** (BERT, GPT-2) | Trainable embedding per position | Can adapt to data patterns | Fixed max length; more parameters | Team-specific track maps tuned from practice data |
| **Relative** (Transformer-XL, RoPE) | Encode distance between tokens, not absolute position | Better generalization; handles variable lengths | More complex implementation | Gap-to-car-ahead (relative) vs. absolute lap time |

Learned embeddings were the standard choice for BERT and GPT-2. Many recent models, including LLaMA, instead use **RoPE** (Rotary Position Embedding), which applies rotation matrices that naturally encode relative positions -- a clever evolution of the sinusoidal idea.
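
The relative-position property of RoPE can be illustrated with a simplified sketch. This is not the implementation used by any particular model: the `rotate` helper, the pairing of adjacent (even, odd) features, and the base of 10000 are assumptions chosen to mirror the sinusoidal formula above.

```python
import math
import torch

def rotate(x, pos, base=10000.0):
    """Rotate each (even, odd) feature pair of x by an angle proportional to pos."""
    d = x.size(-1)
    # One rotation frequency per feature pair, from fast to slow.
    freqs = torch.exp(torch.arange(0, d, 2, dtype=x.dtype) * (-math.log(base) / d))
    angle = pos * freqs
    cos, sin = torch.cos(angle), torch.sin(angle)
    x_even, x_odd = x[0::2], x[1::2]
    out = torch.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out

torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
k = torch.randn(8, dtype=torch.float64)

# The same offset (4 positions) at two different absolute locations.
score_a = torch.dot(rotate(q, 3), rotate(k, 7))
score_b = torch.dot(rotate(q, 103), rotate(k, 107))
print(torch.allclose(score_a, score_b))
```

Because the query and key are rotated by their own positions, their dot product depends only on the difference between the two positions, which is why the attention score is the same in both cases.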

### Deep Dive: Why Sinusoidal Encoding Enables Relative Position

#### Key Insight

For any fixed offset $k$, the encoding at position $pos + k$ can be written as a **linear transformation** of the encoding at position $pos$. Specifically, for each pair of dimensions $(2i, 2i+1)$, with frequency $\omega_i = 1 / 10000^{2i/d_{model}}$:

$$\begin{bmatrix} PE_{pos+k, 2i} \\ PE_{pos+k, 2i+1} \end{bmatrix} = \begin{bmatrix} \cos(k \cdot \omega_i) & \sin(k \cdot \omega_i) \\ -\sin(k \cdot \omega_i) & \cos(k \cdot \omega_i) \end{bmatrix} \begin{bmatrix} PE_{pos, 2i} \\ PE_{pos, 2i+1} \end{bmatrix}$$

This is a **rotation matrix!** Moving from position $pos$ to $pos + k$ is a rotation in each 2D subspace, and the matrix depends only on the offset $k$, not on $pos$. Since rotations are linear, the model can learn to compute relative positions using simple linear operations, such as the query and key projections inside attention.
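
The identity can be verified numerically. The sketch below rebuilds the sinusoidal encoding in float64 (the helper `sinusoidal_pe` and the chosen values of `pos`, `k`, and `i` are arbitrary picks for this example) and checks that one rotation matrix maps the encoding at `pos` to the encoding at `pos + k`.

```python
import math
import torch

def sinusoidal_pe(max_len, d_model):
    """Sinusoidal encoding: sin on even dimensions, cos on odd dimensions."""
    pos = torch.arange(max_len, dtype=torch.float64).unsqueeze(1)
    omega = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64)
                      * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(pos * omega)
    pe[:, 1::2] = torch.cos(pos * omega)
    return pe, omega

pe, omega = sinusoidal_pe(max_len=50, d_model=16)
pos, k, i = 7, 5, 2                      # arbitrary position, offset, and dimension pair

# Rotation matrix for offset k in the (2i, 2i+1) plane; note it does not involve pos.
theta = k * omega[i]
R = torch.stack([
    torch.stack([torch.cos(theta), torch.sin(theta)]),
    torch.stack([-torch.sin(theta), torch.cos(theta)]),
])

rotated = R @ pe[pos, 2 * i:2 * i + 2]
print(torch.allclose(rotated, pe[pos + k, 2 * i:2 * i + 2]))
```

The same `R` works for any starting position, which is what makes a fixed offset learnable with a single linear map.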

#### Common Misconceptions

| Misconception | Reality |
|---------------|---------|
| "Positional encoding tells the model the exact position" | It provides a signal the model *learns* to interpret; the encoding itself is just numbers |
| "You need sinusoidal encodings specifically" | Learned encodings work just as well; the key is having *some* position signal |
| "Positional encoding is added only once" | It is added once at the input, but residual connections carry it through all layers |

---

## 3. The Encoder Block

### Intuitive Explanation

The encoder's job is to read the input sequence and build **rich, contextual representations** of each token. It does this by stacking identical blocks, each containing two sub-layers:

1. **Multi-Head Self-Attention:** Each token looks at all other tokens to understand context (you built this in Notebook 16)
2. **Position-wise Feed-Forward Network (FFN):** Each token is independently transformed through a small neural network

Around each sub-layer, there are two critical additions:
- **Residual connection:** Add the input back to the output (skip connection)
- **Layer normalization:** Normalize the result for stable training

**F1 analogy:** The encoder block is like a **telemetry processing pipeline**. Self-attention is the step where every sensor reading is cross-referenced with every other sensor reading -- "How does tire temperature relate to lap time? How does throttle application correlate with tire wear?" The FFN is where each sensor reading is independently refined -- converting raw voltage into meaningful engineering units. The residual connection ensures the original raw signal is never lost, and layer normalization keeps all signals on comparable scales regardless of whether you are measuring temperature in Celsius or pressure in bar.

```
Input (for each token)
  │
  ├──────────────────────┐
  │                      │ (residual)
  ▼                      │
Multi-Head Self-Attention │
  │                      │
  ▼                      │
  + ◄────────────────────┘
  │
LayerNorm
  │
  ├──────────────────────┐
  │                      │ (residual)
  ▼                      │
Feed-Forward Network     │
  │                      │
  ▼                      │
  + ◄────────────────────┘
  │
LayerNorm
  │
  ▼
Output (same shape as input)
```

### Add & Norm: Residual Connections + Layer Normalization

**Residual connections** mitigate the vanishing gradient problem in deep networks. Instead of outputting only $F(x)$, the block outputs $F(x) + x$, so the sub-layer only has to learn a correction to its input. If the sub-layer has nothing useful to add, gradients can still flow through the identity path.
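
A small sketch of the identity path, under an artificial assumption: the sub-layer's weights are zeroed to simulate a layer that contributes nothing. The names `sublayer`, `plain`, and `residual` exist only in this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
sublayer = nn.Linear(4, 4)
nn.init.zeros_(sublayer.weight)   # simulate a sub-layer that contributes nothing
nn.init.zeros_(sublayer.bias)

x = torch.randn(3, 4, requires_grad=True)

plain = sublayer(x)          # without a residual: the signal is wiped out
residual = x + sublayer(x)   # with a residual: the input passes through unchanged

residual.sum().backward()
print(f"Largest value without residual: {plain.abs().max().item()}")
print(f"Residual output equals input:  {torch.equal(residual, x)}")
print(f"Gradient reaching x:\n{x.grad}")
```

Even with a "dead" sub-layer, the gradient that reaches `x` is all ones, because it travels along the skip connection.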

**Layer normalization** normalizes across the feature dimension (not the batch dimension like BatchNorm). For each token independently:

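$$\text{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta$$

Here $\mu$ and $\sigma^2$ are the mean and variance of that token's features, $\epsilon$ is a small constant for numerical stability, and $\gamma$ and $\beta$ are learned scale and shift parameters.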

| Component | What it does | Why it matters | F1 Parallel |
|-----------|-------------|----------------|-------------|
| **Residual connection** | Adds input to sub-layer output | Enables training of very deep networks | Baseline setup preserved -- new adjustments are *added* on top |
| **Layer normalization** | Normalizes each token's features | Stabilizes training, reduces sensitivity to scale | Normalizing across different race conditions -- wet vs. dry, hot vs. cold track |
| **Together (Add & Norm)** | `LayerNorm(x + SubLayer(x))` | The standard "wrapper" around every sub-layer | Adjust setup, then normalize so all readings are comparable |

**F1 analogy:** Layer normalization is like normalizing telemetry across different race conditions. A tire temperature of 100C means something different at Bahrain (50C ambient) versus Spa (15C ambient). LayerNorm ensures that the model sees standardized signals regardless of the "ambient conditions" of the data, just as engineers normalize readings before comparing across sessions.
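
To see what "per token, across features" means in practice, the sketch below recomputes layer normalization by hand and compares it with `nn.LayerNorm`, which divides by the square root of the variance plus `eps`. The input scale and shift are arbitrary, and the comparison relies on a freshly created `nn.LayerNorm`, whose scale starts at 1 and shift at 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 5, 8) * 10 + 3        # (batch, seq_len, d_model), arbitrary scale and shift

layer_norm = nn.LayerNorm(8)             # gamma starts at 1, beta at 0

# Manual version: statistics are taken over the last (feature) dimension, per token.
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + layer_norm.eps)

print(torch.allclose(manual, layer_norm(x), atol=1e-5))
print(f"Largest per-token mean after normalization: {manual.mean(dim=-1).abs().max().item():.2e}")
```

Every token ends up with roughly zero mean and unit variance across its own features, regardless of what the other tokens or other batch items look like.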

### The Feed-Forward Network (FFN)

The FFN is applied to each position **independently** (the same network, same weights, for every token). It consists of two linear transformations with a ReLU (or GELU) activation in between:

$$\text{FFN}(x) = W_2 \cdot \text{ReLU}(W_1 x + b_1) + b_2$$

The inner dimension is typically **4x** the model dimension (e.g., $d_{model} = 512 \rightarrow d_{ff} = 2048$).

**Why is the FFN needed?** Self-attention is powerful, but once the attention weights are computed, its output is just a **weighted average** of value vectors -- a linear operation on the values. The FFN adds:
1. **Nonlinearity** (via ReLU/GELU) -- critical for learning complex functions
2. **Per-position processing** -- each token can independently transform its representation
3. **Memory** -- recent research suggests FFN layers act as key-value memories, storing factual knowledge

| Sub-layer | Operation | Capacity | F1 Parallel |
|-----------|-----------|----------|-------------|
| Self-Attention | Mixes information **across** tokens | Contextual understanding | Cross-referencing all sensor channels with each other |
| FFN | Transforms each token **independently** | Feature extraction, knowledge storage | Per-sensor signal processing and calibration |
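
Two properties of the FFN are easy to check: its parameter count and the fact that it treats each position independently. The sketch below uses the $d_{model} = 512$, $d_{ff} = 2048$ sizes mentioned above and an `nn.Sequential` stand-in rather than the `PositionwiseFFN` class defined in this notebook.

```python
import torch
import torch.nn as nn

d_model, d_ff = 512, 2048
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

n_params = sum(p.numel() for p in ffn.parameters())
expected = 2 * d_model * d_ff + d_ff + d_model   # two weight matrices + two bias vectors
print(f"{n_params:,} parameters (formula gives {expected:,})")

# Position-wise: running one token alone matches running it inside a full sequence.
torch.manual_seed(0)
x = torch.randn(1, 10, d_model)
full = ffn(x)
single = ffn(x[:, 3:4, :])
print(torch.allclose(full[:, 3:4, :], single, atol=1e-4))
```

With these sizes the FFN holds about 2.1 million parameters, and the output for token 3 does not change when the other nine tokens are removed.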

In [None]:
# Implement the core components of an Encoder Block

class MultiHeadAttention(nn.Module):
    """
    Multi-head attention (from Notebook 16, now as a reusable module).
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Linear projections for Q, K, V, and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch, seq_len_q, d_model)
            key: (batch, seq_len_k, d_model)
            value: (batch, seq_len_k, d_model)
            mask: Optional mask (broadcastable to attention shape)
        
        Returns:
            output: (batch, seq_len_q, d_model)
            attention_weights: (batch, n_heads, seq_len_q, seq_len_k)
        """
        batch_size = query.size(0)
        
        # Project and reshape to (batch, n_heads, seq_len, d_k)
        Q = self.W_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, V)
        
        # Concatenate heads and project
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(context)
        
        return output, attention_weights


class PositionwiseFFN(nn.Module):
    """
    Position-wise Feed-Forward Network.
    Two linear layers with ReLU activation. Applied independently to each position.
    """
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))


class EncoderBlock(nn.Module):
    """
    One Transformer Encoder block.
    
    Contains:
    1. Multi-head self-attention + Add & Norm
    2. Position-wise FFN + Add & Norm
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads)
        self.ffn = PositionwiseFFN(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, d_model)
            mask: Optional attention mask
        
        Returns:
            output: (batch, seq_len, d_model)
        """
        # Sub-layer 1: Multi-head self-attention + Add & Norm
        attn_output, attn_weights = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))   # Residual + LayerNorm
        
        # Sub-layer 2: FFN + Add & Norm
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))    # Residual + LayerNorm
        
        return x

# Test the encoder block
d_model = 64
n_heads = 4
d_ff = 256
batch_size = 2
seq_len = 10

encoder_block = EncoderBlock(d_model, n_heads, d_ff)
x = torch.randn(batch_size, seq_len, d_model)
output = encoder_block(x)

print(f"Encoder Block")
print(f"  Input shape:  {x.shape}")
print(f"  Output shape: {output.shape}")
print(f"  Same shape: {x.shape == output.shape}")
print(f"\nParameter count:")
total_params = sum(p.numel() for p in encoder_block.parameters())
for name, param in encoder_block.named_parameters():
    print(f"  {name}: {param.shape} ({param.numel():,} params)")
print(f"  TOTAL: {total_params:,} parameters")

In [None]:
# Visualization: Data flow through an Encoder Block
fig, ax = plt.subplots(figsize=(10, 8))
ax.set_xlim(0, 10)
ax.set_ylim(0, 12)
ax.axis('off')
ax.set_title('Data Flow Through One Encoder Block', fontsize=15, fontweight='bold', pad=20)

# Helper to draw boxes
def draw_box(ax, x, y, w, h, text, color, fontsize=10):
    rect = plt.Rectangle((x-w/2, y-h/2), w, h, linewidth=2, 
                          edgecolor=color, facecolor=color, alpha=0.2)
    ax.add_patch(rect)
    ax.text(x, y, text, ha='center', va='center', fontsize=fontsize, fontweight='bold')

def draw_arrow(ax, x1, y1, x2, y2, color='black'):
    ax.annotate('', xy=(x2, y2), xytext=(x1, y1),
                arrowprops=dict(arrowstyle='->', color=color, lw=2))

# Input
draw_box(ax, 5, 11, 3, 0.6, 'Input: (batch, seq_len, d_model)', 'gray')
draw_arrow(ax, 5, 10.7, 5, 9.7)

# Self-Attention
draw_box(ax, 5, 9.3, 4, 0.8, 'Multi-Head Self-Attention', 'steelblue')
draw_arrow(ax, 5, 8.9, 5, 8.1)

# Add & Norm 1
draw_box(ax, 5, 7.7, 3.5, 0.6, 'Add & LayerNorm', 'green')
# Residual connection 1
ax.annotate('', xy=(3.2, 7.7), xytext=(3.2, 10.3),
            arrowprops=dict(arrowstyle='->', color='red', lw=2, linestyle='--'))
ax.text(2.4, 9.0, 'residual', fontsize=9, color='red', rotation=90, va='center')
draw_arrow(ax, 5, 7.4, 5, 6.5)

# FFN
draw_box(ax, 5, 6.1, 4.5, 0.8, 'Feed-Forward Network\n(d_model -> d_ff -> d_model)', 'orange')
draw_arrow(ax, 5, 5.7, 5, 4.8)

# Add & Norm 2
draw_box(ax, 5, 4.4, 3.5, 0.6, 'Add & LayerNorm', 'green')
# Residual connection 2
ax.annotate('', xy=(3.2, 4.4), xytext=(3.2, 7.1),
            arrowprops=dict(arrowstyle='->', color='red', lw=2, linestyle='--'))
ax.text(2.4, 5.7, 'residual', fontsize=9, color='red', rotation=90, va='center')
draw_arrow(ax, 5, 4.1, 5, 3.2)

# Output
draw_box(ax, 5, 2.8, 3, 0.6, 'Output: (batch, seq_len, d_model)', 'gray')

# Annotations
ax.text(8.5, 9.3, 'Each token attends\nto all other tokens', fontsize=9, 
        ha='center', style='italic', color='steelblue')
ax.text(8.5, 6.1, 'Same MLP applied\nto each token\nindependently', fontsize=9,
        ha='center', style='italic', color='darkorange')
ax.text(8.5, 7.7, 'Stabilizes gradients', fontsize=9,
        ha='center', style='italic', color='green')

plt.tight_layout()
plt.savefig('encoder_block.png', dpi=100, bbox_inches='tight')
plt.show()

print("Key insight: the input and output have the SAME shape.")
print("This means encoder blocks can be stacked -- the output of one is the input to the next.")

### Stacking N Encoder Blocks

The original Transformer uses **N = 6** identical encoder blocks stacked on top of each other. Each block refines the representations:

- **Early layers** tend to capture syntactic patterns (word relationships, grammar)
- **Middle layers** develop semantic understanding (meaning, context)
- **Later layers** create task-specific representations

**F1 analogy:** Think of stacking encoder blocks like the stages of telemetry analysis. Early layers capture raw patterns ("the car braked here"). Middle layers build understanding ("the driver is struggling with understeer in sector 2"). Later layers produce race-specific insights ("pit window opens in 3 laps based on current degradation"). Each layer builds on the one before, just as each stage of analysis adds deeper interpretation.

Because input and output shapes are identical, stacking is trivial:

```python
for block in encoder_blocks:
    x = block(x)  # Same shape in, same shape out
```

Each layer builds on the representations from the previous layer, creating increasingly abstract and contextual embeddings.
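
A fuller sketch of stacking, assuming PyTorch's built-in `nn.TransformerEncoderLayer` as a stand-in for an encoder block (it is not the `EncoderBlock` class from this notebook). The blocks are "identical" in structure only: each one created inside `nn.ModuleList` has its own weights.

```python
import torch
import torch.nn as nn

n_layers, d_model = 6, 64
encoder_blocks = nn.ModuleList([
    nn.TransformerEncoderLayer(d_model=d_model, nhead=4, dim_feedforward=256,
                               dropout=0.0, batch_first=True)
    for _ in range(n_layers)
])

x = torch.randn(2, 10, d_model)
for block in encoder_blocks:
    x = block(x)                # shape is preserved, so the loop needs no glue code

per_block = sum(p.numel() for p in encoder_blocks[0].parameters())
total = sum(p.numel() for p in encoder_blocks.parameters())
print(f"Output shape: {tuple(x.shape)}")
print(f"Parameters: {per_block:,} per block, {total:,} in total")
```

Using `nn.ModuleList` instead of a plain Python list matters: it registers every block's parameters so the optimizer can see them.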

---

## 4. The Decoder Block

### Intuitive Explanation

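The decoder generates the output sequence one token at a time at inference; during training, all target positions are processed in parallel and a causal mask enforces the same left-to-right constraint. It is similar to the encoder but has an important extra sub-layer and a critical constraint: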

1. **Masked Multi-Head Self-Attention:** The decoder attends to its own previous outputs, but with a **causal mask** -- each position can only attend to earlier positions (no peeking at future tokens!)
2. **Cross-Attention:** The decoder attends to the encoder's output. Queries come from the decoder; keys and values come from the encoder. This is how the decoder "reads" the input.
3. **Feed-Forward Network:** Same as in the encoder.

**F1 analogy:** The decoder is the **strategy engineer making real-time calls**. Masked self-attention means the strategist can only consider decisions already made (you cannot un-pit the car) -- each decision is based only on what has happened before, not on future events. Cross-attention is how the strategist reads the full telemetry picture (the encoder output) to inform the next decision. The causal mask is like the fundamental constraint of racing: decisions must be made in real time, with no knowledge of the future.

```
Decoder Input (shifted right)
  │
  ├──────────────────────┐
  │                      │ (residual)
  ▼                      │
Masked Multi-Head         │
Self-Attention            │
(causal: no future)      │
  │                      │
  ▼                      │
  + ◄────────────────────┘
  │
LayerNorm
  │
  ├──────────────────────┐
  │                      │ (residual)
  ▼                      │
Multi-Head Cross-Attention│
(Q: decoder, K/V: encoder)│
  │                      │
  ▼                      │
  + ◄────────────────────┘
  │
LayerNorm
  │
  ├──────────────────────┐
  │                      │ (residual)
  ▼                      │
Feed-Forward Network     │
  │                      │
  ▼                      │
  + ◄────────────────────┘
  │
LayerNorm
  │
  ▼
Output
```
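
The causal constraint in the first sub-layer can be illustrated with a small NumPy sketch. This is not the notebook's `MultiHeadAttention`; the helper name `causal_softmax` and the uniform scores are assumptions chosen only to make the masking visible.

```python
import numpy as np

def causal_softmax(scores):
    """Softmax over the last axis with future positions (j > i) masked out."""
    t = scores.shape[-1]
    allowed = np.tril(np.ones((t, t), dtype=bool))   # True on and below the diagonal
    masked = np.where(allowed, scores, -np.inf)      # future positions get -inf
    masked = masked - masked.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(masked)                         # exp(-inf) == 0
    return weights / weights.sum(axis=-1, keepdims=True)

# Uniform scores for readability: row i spreads its weight evenly over
# positions 0..i, and every entry above the diagonal is exactly 0.
weights = causal_softmax(np.zeros((4, 4)))
print(np.round(weights, 2))
```

Because masked scores are set to negative infinity before the softmax, future positions receive exactly zero attention weight rather than merely a small one.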

### Cross-Attention: The Bridge Between Encoder and Decoder

Cross-attention is the mechanism that connects the encoder and decoder. It uses the same scaled dot-product attention as the self-attention you already know, except that the queries and the keys/values come from different sequences:

| Attention Type | Queries from | Keys/Values from | Purpose | F1 Parallel |
|---------------|-------------|-----------------|---------|-------------|
| **Encoder self-attention** | Encoder input | Encoder input | Each input token attends to all input tokens | Telemetry sensors cross-referencing each other |
| **Decoder masked self-attention** | Decoder input | Decoder input (masked) | Each output token attends to previous output tokens | Strategy history: what calls have we already made? |
| **Cross-attention** | Decoder | **Encoder output** | Each output token attends to all input tokens | Strategy engineer reading telemetry to inform next call |

**What this means:** When the decoder generates each output token, it can "look back" at the entire input sequence through cross-attention. For example, in translation, when generating the French word "chat", the cross-attention heads might focus on the English word "cat" in the encoder output.

**F1 analogy:** Cross-attention is how the strategy engineer (decoder) reads the telemetry data (encoder output). When deciding whether to pit, the strategy engineer's query ("Should we pit?") attends to all the encoded telemetry signals -- tire degradation, weather forecast, gap to competitors -- and builds a decision from the most relevant data. Different attention heads might focus on different factors: one head on tire data, another on competitor positions, another on weather.

This is the same encoder-decoder attention you learned about in Notebook 16, now integrated into the Transformer architecture.
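
To make the shapes concrete, here is a minimal single-head sketch in NumPy. It is an illustration under stated assumptions, not the notebook's `MultiHeadAttention`: the function name `cross_attention`, the projection matrices `W_q`, `W_k`, `W_v`, and all sizes are hypothetical.

```python
import numpy as np

def cross_attention(decoder_x, encoder_out, W_q, W_k, W_v):
    """Single-head cross-attention: queries from the decoder, keys/values from the encoder."""
    Q = decoder_x @ W_q                        # (tgt_len, d_k)
    K = encoder_out @ W_k                      # (src_len, d_k)
    V = encoder_out @ W_v                      # (src_len, d_k)
    scores = Q @ K.T / np.sqrt(K.shape[-1])    # (tgt_len, src_len)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
d_model, d_k, src_len, tgt_len = 16, 8, 10, 4   # hypothetical sizes
decoder_x = rng.normal(size=(tgt_len, d_model))
encoder_out = rng.normal(size=(src_len, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

context, weights = cross_attention(decoder_x, encoder_out, W_q, W_k, W_v)
print(context.shape)   # (4, 8)
print(weights.shape)   # (4, 10)
```

The attention matrix has one row per target position and one column per source position, which is why the source and target sequences are free to have different lengths.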

In [None]:
class DecoderBlock(nn.Module):
    """
    One Transformer Decoder block.
    
    Contains:
    1. Masked multi-head self-attention + Add & Norm
    2. Multi-head cross-attention + Add & Norm  
    3. Position-wise FFN + Add & Norm
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # Three sub-layers
        self.self_attention = MultiHeadAttention(d_model, n_heads)
        self.cross_attention = MultiHeadAttention(d_model, n_heads)
        self.ffn = PositionwiseFFN(d_model, d_ff)
        
        # Three layer norms (one per sub-layer)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Decoder input (batch, tgt_len, d_model)
            encoder_output: Encoder output (batch, src_len, d_model)
            src_mask: Mask for encoder output (optional)
            tgt_mask: Causal mask for decoder self-attention
        
        Returns:
            output: (batch, tgt_len, d_model)
        """
        # Sub-layer 1: Masked self-attention
        attn_output, _ = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Sub-layer 2: Cross-attention (Q from decoder, K/V from encoder)
        cross_output, cross_weights = self.cross_attention(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(cross_output))
        
        # Sub-layer 3: FFN
        ffn_output = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_output))
        
        return x

# Test the decoder block
decoder_block = DecoderBlock(d_model, n_heads, d_ff)

# Decoder input (target sequence, slightly shorter for demonstration)
tgt_len = 8
decoder_input = torch.randn(batch_size, tgt_len, d_model)
encoder_output = torch.randn(batch_size, seq_len, d_model)  # From the encoder

# Create causal mask for decoder self-attention
causal_mask = torch.tril(torch.ones(tgt_len, tgt_len)).unsqueeze(0).unsqueeze(0)  # (1, 1, tgt_len, tgt_len)

decoder_output = decoder_block(decoder_input, encoder_output, tgt_mask=causal_mask)

print(f"Decoder Block")
print(f"  Decoder input shape:  {decoder_input.shape}")
print(f"  Encoder output shape: {encoder_output.shape}")
print(f"  Decoder output shape: {decoder_output.shape}")
print(f"\nCausal mask (each token can attend to itself and earlier positions):")
print(causal_mask[0, 0].numpy().astype(int))
print(f"\nParameter count:")
total_params = sum(p.numel() for p in decoder_block.parameters())
for name, param in decoder_block.named_parameters():
    print(f"  {name}: {param.shape} ({param.numel():,} params)")
print(f"  TOTAL: {total_params:,} parameters")

In [None]:
# Visualization: Data flow through one Decoder Block
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_xlim(0, 12)
ax.set_ylim(0, 15)
ax.axis('off')
ax.set_title('Data Flow Through One Decoder Block', fontsize=15, fontweight='bold', pad=20)

# Decoder Input
draw_box(ax, 5, 14, 3.5, 0.6, 'Decoder Input (shifted right)', 'gray')
draw_arrow(ax, 5, 13.7, 5, 12.7)

# Masked Self-Attention
draw_box(ax, 5, 12.3, 4.5, 0.8, 'Masked Multi-Head Self-Attention', 'steelblue')
draw_arrow(ax, 5, 11.9, 5, 11.1)

# Add & Norm 1
draw_box(ax, 5, 10.7, 3.5, 0.6, 'Add & LayerNorm', 'green')
ax.annotate('', xy=(2.8, 10.7), xytext=(2.8, 13.4),
            arrowprops=dict(arrowstyle='->', color='red', lw=2, linestyle='--'))
ax.text(2.0, 12.0, 'residual', fontsize=9, color='red', rotation=90, va='center')
draw_arrow(ax, 5, 10.4, 5, 9.5)

# Cross-Attention
draw_box(ax, 5, 9.1, 4.5, 0.8, 'Multi-Head Cross-Attention', 'purple')
draw_arrow(ax, 5, 8.7, 5, 7.9)

# Encoder output arrow into cross-attention
draw_box(ax, 10, 9.1, 2.5, 0.8, 'Encoder\nOutput', 'orange')
draw_arrow(ax, 8.75, 9.1, 7.25, 9.1)
ax.text(8.0, 9.6, 'K, V', fontsize=10, fontweight='bold', color='purple')
ax.text(4.0, 9.65, 'Q', fontsize=10, fontweight='bold', color='purple')

# Add & Norm 2
draw_box(ax, 5, 7.5, 3.5, 0.6, 'Add & LayerNorm', 'green')
ax.annotate('', xy=(2.8, 7.5), xytext=(2.8, 10.1),
            arrowprops=dict(arrowstyle='->', color='red', lw=2, linestyle='--'))
ax.text(2.0, 8.7, 'residual', fontsize=9, color='red', rotation=90, va='center')
draw_arrow(ax, 5, 7.2, 5, 6.3)

# FFN
draw_box(ax, 5, 5.9, 4.5, 0.8, 'Feed-Forward Network', 'orange')
draw_arrow(ax, 5, 5.5, 5, 4.6)

# Add & Norm 3
draw_box(ax, 5, 4.2, 3.5, 0.6, 'Add & LayerNorm', 'green')
ax.annotate('', xy=(2.8, 4.2), xytext=(2.8, 6.9),
            arrowprops=dict(arrowstyle='->', color='red', lw=2, linestyle='--'))
ax.text(2.0, 5.5, 'residual', fontsize=9, color='red', rotation=90, va='center')
draw_arrow(ax, 5, 3.9, 5, 3.0)

# Output
draw_box(ax, 5, 2.6, 3, 0.6, 'Decoder Output', 'gray')

# Annotations
ax.text(8.5, 12.3, 'Causal mask:\ncan only see\npast tokens', fontsize=9,
        ha='center', style='italic', color='steelblue')
ax.text(8.0, 7.5, '"Read" the\ninput sequence', fontsize=9,
        ha='center', style='italic', color='purple')

plt.tight_layout()
plt.savefig('decoder_block.png', dpi=100, bbox_inches='tight')
plt.show()

print("The decoder block has THREE sub-layers (vs encoder's TWO).")
print("The extra sub-layer is cross-attention, which connects the decoder to the encoder.")

### Encoder Block vs Decoder Block

| Feature | Encoder Block | Decoder Block | F1 Parallel |
|---------|--------------|---------------|-------------|
| **Sub-layers** | 2 (self-attention + FFN) | 3 (masked self-attn + cross-attn + FFN) | Data processing vs. data processing + strategy lookup + decision |
| **Self-attention** | Bidirectional (sees all tokens) | Causal (only sees past tokens) | Full session replay vs. real-time racing (no future knowledge) |
| **Cross-attention** | None | Attends to encoder output | N/A vs. reading the telemetry dashboard |
| **Residual + LayerNorm** | Around each sub-layer | Around each sub-layer | Signal preservation at every stage |
| **Output shape** | Same as input | Same as input | Same data dimensions throughout |
| **Parameter count** | ~12 * d_model^2 (with d_ff = 4 * d_model) | ~16 * d_model^2 (extra attention) | Leaner processing vs. fuller decision system |
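
As a rough check on per-block size, the sketch below counts only the weight matrices, ignoring biases and LayerNorm parameters. It assumes each attention module has four `d_model x d_model` projections (query, key, value, output) and that the FFN has two linear layers; the helper name `block_weight_counts` is hypothetical.

```python
def block_weight_counts(d_model, d_ff):
    """Count weight-matrix entries per block, ignoring biases and LayerNorm."""
    attention = 4 * d_model * d_model   # W_q, W_k, W_v, W_o
    ffn = 2 * d_model * d_ff            # two linear layers
    encoder = attention + ffn           # self-attention + FFN
    decoder = 2 * attention + ffn       # self-attention + cross-attention + FFN
    return encoder, decoder

enc, dec = block_weight_counts(d_model=64, d_ff=256)
print(enc, dec)                    # 49152 65536
print(enc / 64**2, dec / 64**2)    # 12.0 16.0
```

With `d_ff = 4 * d_model`, the FFN holds twice as many weights as one attention module, so the decoder's extra cross-attention increases the block size by about a third rather than doubling it.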

---

## 5. The Full Transformer

### Putting It All Together

Now we assemble every piece into a complete Transformer. Here is the full architecture:

```
SOURCE TOKENS                              TARGET TOKENS (shifted right)
     │                                           │
┌────▼──────────────────┐                 ┌──────▼────────────────┐
│ Input Embedding       │                 │ Output Embedding      │
│    +                  │                 │    +                  │
│ Positional Encoding   │                 │ Positional Encoding   │
└────┬──────────────────┘                 └──────┬────────────────┘
     │                                           │
     │  ┌─────────────────────┐                  │  ┌─────────────────────┐
     └─▶│ Encoder Block 1     │           ┌─────▶│ Decoder Block 1     │◄─ encoder output
        │   - Self-Attention  │           │      │   - Masked Self-Attn │
        │   - Add & Norm      │           │      │   - Cross-Attention  │
        │   - FFN             │           │      │   - Add & Norm       │
        │   - Add & Norm      │           │      │   - FFN              │
        └─────────┬───────────┘           │      │   - Add & Norm       │
                  │                       │      └─────────┬───────────┘
        ┌─────────▼───────────┐           │                │
        │ Encoder Block 2     │           │      ┌─────────▼───────────┐
        │       ...           │           │      │ Decoder Block 2     │
        └─────────┬───────────┘           │      │       ...           │
                  │                       │      └─────────┬───────────┘
        ┌─────────▼───────────┐           │                │
        │ Encoder Block N     │───────────┘      ┌─────────▼───────────┐
        └─────────────────────┘                  │ Decoder Block N     │
                                                 └─────────┬───────────┘
                                                           │
                                                 ┌─────────▼───────────┐
                                                 │ Linear (d_model →   │
                                                 │         vocab_size) │
                                                 │ Softmax             │
                                                 └─────────┬───────────┘
                                                           │
                                                  OUTPUT PROBABILITIES
```

The encoder blocks run in sequence, and only the output of the last one (the final encoder output) is passed to the decoder. Every decoder block receives this **same** encoder output for cross-attention.

In [None]:
class PositionalEncoding(nn.Module):
    """
    Adds sinusoidal positional encoding to token embeddings.
    Registered as a buffer (not a parameter) since it's fixed.
    """
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) for broadcasting
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """Add positional encoding to input embeddings."""
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class Transformer(nn.Module):
    """
    Complete Transformer model (encoder-decoder).
    
    Built from scratch using the components defined above.
    """
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=64, n_heads=4,
                 n_layers=2, d_ff=256, max_len=100, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        
        # Embeddings
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len, dropout)
        
        # Encoder stack
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        
        # Decoder stack
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        
        # Final output projection
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
    
    def encode(self, src, src_mask=None):
        """Run the encoder stack."""
        x = self.src_embedding(src) * math.sqrt(self.d_model)  # Scale embedding
        x = self.positional_encoding(x)
        
        for block in self.encoder_blocks:
            x = block(x, src_mask)
        
        return x
    
    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        """Run the decoder stack."""
        x = self.tgt_embedding(tgt) * math.sqrt(self.d_model)  # Scale embedding
        x = self.positional_encoding(x)
        
        for block in self.decoder_blocks:
            x = block(x, encoder_output, src_mask, tgt_mask)
        
        return x
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """
        Full forward pass.
        
        Args:
            src: Source token IDs (batch, src_len)
            tgt: Target token IDs (batch, tgt_len)
            src_mask: Optional source mask
            tgt_mask: Causal mask for target
        
        Returns:
            logits: (batch, tgt_len, tgt_vocab_size)
        """
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        logits = self.output_projection(decoder_output)
        return logits
    
    @staticmethod
    def generate_causal_mask(size):
        """Generate a causal (lower-triangular) mask."""
        mask = torch.tril(torch.ones(size, size)).unsqueeze(0).unsqueeze(0)
        return mask  # (1, 1, size, size)


# Create and inspect the full Transformer
src_vocab_size = 50
tgt_vocab_size = 50
model = Transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=64,
    n_heads=4,
    n_layers=2,
    d_ff=256,
    dropout=0.1
)

print("=" * 60)
print("FULL TRANSFORMER ARCHITECTURE")
print("=" * 60)
print(f"\nModel configuration:")
print(f"  d_model:        64")
print(f"  n_heads:        4 (d_k = 16)")
print(f"  n_layers:       2 (encoder) + 2 (decoder)")
print(f"  d_ff:           256")
print(f"  src_vocab_size: {src_vocab_size}")
print(f"  tgt_vocab_size: {tgt_vocab_size}")

# Test forward pass
src = torch.randint(0, src_vocab_size, (2, 10))  # batch=2, src_len=10
tgt = torch.randint(0, tgt_vocab_size, (2, 8))   # batch=2, tgt_len=8
tgt_mask = Transformer.generate_causal_mask(8)

logits = model(src, tgt, tgt_mask=tgt_mask)
print(f"\nForward pass:")
print(f"  Source shape:  {src.shape}")
print(f"  Target shape:  {tgt.shape}")
print(f"  Output logits: {logits.shape}")
print(f"  (batch=2, tgt_len=8, vocab={tgt_vocab_size})")

In [None]:
# Parameter counting and analysis
print("=" * 65)
print("PARAMETER ANALYSIS")
print("=" * 65)

total = 0
sections = {}

for name, param in model.named_parameters():
    total += param.numel()
    # Group by section
    section = name.split('.')[0]
    if section not in sections:
        sections[section] = 0
    sections[section] += param.numel()

print(f"\nParameter breakdown by component:")
print("-" * 45)
for section, count in sections.items():
    pct = 100 * count / total
    bar = '#' * int(pct / 2)
    print(f"  {section:25s}: {count:>8,} ({pct:5.1f}%) {bar}")
print("-" * 45)
print(f"  {'TOTAL':25s}: {total:>8,}")

print(f"\nDetailed parameter list:")
print("-" * 65)
for name, param in model.named_parameters():
    print(f"  {name:45s} {str(list(param.shape)):>18s} = {param.numel():>7,}")
print("-" * 65)
print(f"  Total parameters: {total:,}")

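# Compare to the original Transformer paper
print(f"\n--- For reference: Original 'Attention Is All You Need' ---")
print(f"  d_model=512, n_heads=8, n_layers=6, d_ff=2048")
orig_params = (
    2 * 37000 * 512 +                    # embeddings (src + tgt, ~37K vocab)
    6 * (4 * 512 * 512 + 2 * 512) +      # encoder attention
    6 * (512 * 2048 + 2048 + 2048 * 512 + 512) +  # encoder FFN
    6 * (4 * 512 * 512 + 2 * 512) * 2 +  # decoder attention (self + cross)
    6 * (512 * 2048 + 2048 + 2048 * 512 + 512) +  # decoder FFN
    512 * 37000                           # output projection
)
# The paper shares one weight matrix between both embeddings and the output
# projection, so two of the three vocab-sized matrices drop out of the count.
shared_params = orig_params - 2 * 37000 * 512
print(f"  Estimate with separate embeddings/projection: ~{orig_params / 1e6:.0f}M parameters")
print(f"  Estimate with shared embeddings/projection:   ~{shared_params / 1e6:.0f}M parameters")
print(f"  The paper reports ~65M for the base model, which uses the shared variant.")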

---

## 6. Training a Transformer

### Teacher Forcing

During training, the decoder does **not** generate tokens autoregressively (using its own predictions as input). Instead, we use **teacher forcing**: we feed the **ground truth** target sequence (shifted right by one position) as decoder input.

```
Target:          [<sos>, "le", "chat", "est", "assis", <eos>]
Decoder input:   [<sos>, "le", "chat", "est", "assis"]        ← shifted right
Expected output: ["le", "chat", "est", "assis", <eos>]        ← what we train to predict
```

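**Why?** If we fed the decoder its own predictions during training, its early predictions (which are close to random) would give it garbage input, errors would compound along the sequence, and learning would be slow and unstable. Teacher forcing provides a stable training signal, and because every decoder input is known in advance, all positions can be trained in parallel in a single forward pass.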

**F1 analogy:** Teacher forcing is like training a junior strategist by showing them the actual race outcomes. Instead of letting the trainee make calls and see them go wrong (which would compound errors -- a bad early call leads to worse later calls), you show them the correct sequence of decisions at each point: "At lap 15, the right call was to pit; at lap 30, the right call was to stay out." The trainee learns from the correct history, not from their own early mistakes.

The causal mask ensures that even though all ground truth tokens are provided simultaneously, each position can only see previous positions -- maintaining the autoregressive property.
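
The shift can be sketched in plain Python. The helper name `teacher_forcing_pair` is hypothetical; the token list mirrors the example above.

```python
def teacher_forcing_pair(target):
    """Split a full target sequence into (decoder_input, expected_output)."""
    decoder_input = target[:-1]    # everything except the final <eos>
    expected_output = target[1:]   # everything except the leading <sos>
    return decoder_input, expected_output

target = ["<sos>", "le", "chat", "est", "assis", "<eos>"]
dec_in, expected = teacher_forcing_pair(target)

# With the causal mask, step t sees only dec_in[:t + 1] and must predict expected[t].
for step, label in enumerate(expected):
    visible = dec_in[: step + 1]
    print(f"step {step}: sees {visible} -> predicts {label!r}")
```

Every step has a known input prefix and a known label, so all steps can be scored in one forward pass instead of one pass per generated token.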

### Learning Rate Warmup

The original Transformer paper uses a specific learning rate schedule that has become iconic in deep learning:

$$lr = d_{model}^{-0.5} \cdot \min(step^{-0.5}, \; step \cdot warmup\_steps^{-1.5})$$

| Phase | What happens | Why | F1 Parallel |
|-------|-------------|-----|-------------|
| **Warmup** (first ~4000 steps) | LR increases linearly from 0 | Prevents early training instability; Adam needs good running estimates | Formation lap: building tire temperature gradually before racing |
| **Decay** (after warmup) | LR decreases as $1/\sqrt{step}$ | Gradually fine-tunes as model converges | Tire management: pushing hard early in a stint, then nursing tires as they degrade |

**What this means:** Start gently (low learning rate), ramp up as the optimizer warms up, then gradually cool down. This schedule was crucial for training the original Transformer stably.
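
A quick way to see where the two phases meet is to evaluate the formula at `step == warmup_steps`, where both arguments of the `min` are equal. The sketch below is plain Python; the function name `schedule_lr` is hypothetical, and the defaults are the base-model values quoted above.

```python
def schedule_lr(step, d_model=512, warmup_steps=4000):
    """Learning rate from the formula above (step is clamped to at least 1)."""
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

peak_lr = schedule_lr(4000)    # the peak sits exactly at the end of warmup
print(f"{peak_lr:.2e}")        # 6.99e-04

# The peak equals (d_model * warmup_steps) ** -0.5, so a longer warmup
# or a larger d_model both lower it.
print(schedule_lr(8000, warmup_steps=8000) < peak_lr)   # True
```

This also shows that the schedule has no separate "base learning rate" hyperparameter: the peak is fully determined by `d_model` and `warmup_steps`.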

In [None]:
# Visualization: Learning rate warmup schedule
def transformer_lr_schedule(step, d_model, warmup_steps):
    """The learning rate schedule from 'Attention Is All You Need'."""
    if step == 0:
        step = 1
    return d_model ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))

# Plot for different warmup values
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Different warmup steps
ax = axes[0]
steps = range(1, 20001)
for warmup in [500, 2000, 4000, 8000]:
    lrs = [transformer_lr_schedule(s, 512, warmup) for s in steps]
    ax.plot(steps, lrs, label=f'warmup={warmup}', linewidth=1.5)

ax.set_xlabel('Training Step', fontsize=11)
ax.set_ylabel('Learning Rate', fontsize=11)
ax.set_title('Transformer LR Schedule\n(varying warmup steps)', fontsize=13, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# Different model dimensions
ax = axes[1]
for d in [64, 128, 256, 512]:
    lrs = [transformer_lr_schedule(s, d, 4000) for s in steps]
    ax.plot(steps, lrs, label=f'd_model={d}', linewidth=1.5)

ax.set_xlabel('Training Step', fontsize=11)
ax.set_ylabel('Learning Rate', fontsize=11)
ax.set_title('Transformer LR Schedule\n(varying d_model, warmup=4000)', fontsize=13, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('lr_warmup.png', dpi=100, bbox_inches='tight')
plt.show()
print("Left: More warmup steps = slower ramp-up and a LOWER peak learning rate (peak = (d_model * warmup)^-0.5).")
print("Right: Larger d_model = lower learning rate (bigger models need gentler updates).")

### Training Task: Sequence Reversal

Let's train our Transformer on a simple but non-trivial task: **reversing a sequence of numbers**. For example:

- Input: `[3, 7, 1, 4, 2]`
- Output: `[2, 4, 1, 7, 3]`

This task requires the model to:
1. Read and remember the entire input (encoder)
2. Output tokens in reverse order (decoder with cross-attention)

**F1 analogy:** Think of this as the Transformer learning to reverse-engineer a race: given the finishing order `[VER, NOR, LEC, HAM, PIA]`, reconstruct the starting grid `[PIA, HAM, LEC, NOR, VER]`. The encoder reads the finish, the decoder generates the start -- a non-trivial mapping that requires understanding the full sequence.

We use special tokens: `0 = PAD`, `1 = SOS` (start of sequence), `2 = EOS` (end of sequence), values `3+` are the actual numbers.

In [None]:
# Generate training data for sequence reversal
def generate_reversal_data(n_samples, seq_len, vocab_size, pad_id=0, sos_id=1, eos_id=2):
    """
    Generate source-target pairs for the reversal task.
    
    Source: [v1, v2, ..., vn, EOS, PAD, ...]
    Target input:  [SOS, vn, vn-1, ..., v1]
    Target output: [vn, vn-1, ..., v1, EOS]
    """
    min_token = 3  # Tokens 0,1,2 are special
    src_data = []
    tgt_input_data = []
    tgt_output_data = []
    
    for _ in range(n_samples):
        # Random length between 3 and seq_len
        length = np.random.randint(3, seq_len + 1)
        
        # Random sequence of tokens
        seq = np.random.randint(min_token, vocab_size, size=length).tolist()
        reversed_seq = seq[::-1]
        
        # Source: seq + EOS + padding
        src = seq + [eos_id] + [pad_id] * (seq_len - length)
        
        # Target input: SOS + reversed_seq + padding
        tgt_in = [sos_id] + reversed_seq + [pad_id] * (seq_len - length)
        
        # Target output: reversed_seq + EOS + padding  
        tgt_out = reversed_seq + [eos_id] + [pad_id] * (seq_len - length)
        
        src_data.append(src)
        tgt_input_data.append(tgt_in)
        tgt_output_data.append(tgt_out)
    
    return (torch.tensor(src_data, dtype=torch.long),
            torch.tensor(tgt_input_data, dtype=torch.long),
            torch.tensor(tgt_output_data, dtype=torch.long))

# Generate data
VOCAB_SIZE = 20
SEQ_LEN = 8
N_TRAIN = 5000
N_TEST = 500

src_train, tgt_in_train, tgt_out_train = generate_reversal_data(N_TRAIN, SEQ_LEN, VOCAB_SIZE)
src_test, tgt_in_test, tgt_out_test = generate_reversal_data(N_TEST, SEQ_LEN, VOCAB_SIZE)

print(f"Training data: {N_TRAIN} samples")
print(f"Test data:     {N_TEST} samples")
print(f"Vocabulary:    {VOCAB_SIZE} tokens (0=PAD, 1=SOS, 2=EOS, 3-{VOCAB_SIZE-1}=values)")
print(f"Max seq length: {SEQ_LEN}")
print(f"\nExample:")
print(f"  Source:        {src_train[0].tolist()}")
print(f"  Target input:  {tgt_in_train[0].tolist()}")
print(f"  Target output: {tgt_out_train[0].tolist()}")

In [None]:
# Train the Transformer on sequence reversal
torch.manual_seed(42)

# Create model
model = Transformer(
    src_vocab_size=VOCAB_SIZE,
    tgt_vocab_size=VOCAB_SIZE,
    d_model=64,
    n_heads=4,
    n_layers=2,
    d_ff=256,
    max_len=SEQ_LEN + 5,
    dropout=0.1
)

# Training setup
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding

# Training loop
n_epochs = 30
batch_size = 128
train_losses = []
test_accuracies = []

model.train()
for epoch in range(n_epochs):
    epoch_loss = 0
    n_batches = 0
    
    # Shuffle training data
    perm = torch.randperm(N_TRAIN)
    src_shuffled = src_train[perm]
    tgt_in_shuffled = tgt_in_train[perm]
    tgt_out_shuffled = tgt_out_train[perm]
    
    for i in range(0, N_TRAIN, batch_size):
        src_batch = src_shuffled[i:i+batch_size]
        tgt_in_batch = tgt_in_shuffled[i:i+batch_size]
        tgt_out_batch = tgt_out_shuffled[i:i+batch_size]
        
        # Create causal mask for decoder
        tgt_len = tgt_in_batch.size(1)
        tgt_mask = Transformer.generate_causal_mask(tgt_len)
        
        # Forward pass
        logits = model(src_batch, tgt_in_batch, tgt_mask=tgt_mask)
        
        # Compute loss (reshape for CrossEntropyLoss)
        loss = criterion(logits.view(-1, VOCAB_SIZE), tgt_out_batch.view(-1))
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        epoch_loss += loss.item()
        n_batches += 1
    
    avg_loss = epoch_loss / n_batches
    train_losses.append(avg_loss)
    
    # Evaluate every 5 epochs (teacher-forced: the decoder sees the ground-truth prefix)
    if (epoch + 1) % 5 == 0:
        model.eval()
        with torch.no_grad():
            tgt_mask = Transformer.generate_causal_mask(tgt_in_test.size(1))
            test_logits = model(src_test, tgt_in_test, tgt_mask=tgt_mask)
            test_preds = test_logits.argmax(dim=-1)
            
            # Calculate accuracy (ignoring padding)
            mask = tgt_out_test != 0
            correct = (test_preds == tgt_out_test) & mask
            accuracy = correct.sum().float() / mask.sum().float()
            test_accuracies.append((epoch + 1, accuracy.item()))
            
            print(f"Epoch {epoch+1:3d} | Loss: {avg_loss:.4f} | Test accuracy: {accuracy:.4f}")
        model.train()

print(f"\nTraining complete!")
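
The training loop above passes no `src_mask`, so attention is free to look at `PAD` positions in the source. One possible way to build a source padding mask is sketched below. It assumes the attention layers treat nonzero mask entries as "may attend" and broadcast the mask over heads and query positions, as the `(1, 1, tgt_len, tgt_len)` causal mask suggests. Whether the notebook's `MultiHeadAttention` accepts a boolean mask or needs `.float()` is not confirmed here, and the helper name `make_src_padding_mask` is hypothetical.

```python
import torch

def make_src_padding_mask(src, pad_id=0):
    """True where the source token is real, False where it is padding.

    Shape: (batch, 1, 1, src_len), broadcastable over heads and query positions.
    """
    return (src != pad_id).unsqueeze(1).unsqueeze(2)

src = torch.tensor([[5, 9, 4, 2, 0, 0],
                    [7, 3, 8, 6, 5, 2]])
src_mask = make_src_padding_mask(src)

# Illustration only: apply the mask to random scores the way a masked
# softmax typically does (masked positions get -inf before the softmax).
scores = torch.randn(2, 1, 3, 6)    # (batch, heads, q_len, src_len)
weights = torch.softmax(scores.masked_fill(~src_mask, float("-inf")), dim=-1)
print(src_mask.shape)               # torch.Size([2, 1, 1, 6])
print(weights[0, 0, 0, 4:])         # tensor([0., 0.])
```

If the mask convention matches, the result could be passed as `model(src_batch, tgt_in_batch, src_mask=src_mask, tgt_mask=tgt_mask)`, so that both encoder self-attention and decoder cross-attention ignore padded source positions.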

In [None]:
# Visualization: Training progress
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curve
ax = axes[0]
ax.plot(range(1, len(train_losses) + 1), train_losses, color='blue', linewidth=2)
ax.set_xlabel('Epoch', fontsize=11)
ax.set_ylabel('Cross-Entropy Loss', fontsize=11)
ax.set_title('Training Loss', fontsize=13, fontweight='bold')
ax.grid(True, alpha=0.3)

# Accuracy
ax = axes[1]
if test_accuracies:
    epochs_acc, accs = zip(*test_accuracies)
    ax.plot(epochs_acc, accs, 'o-', color='green', linewidth=2, markersize=8)
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel('Token Accuracy', fontsize=11)
    ax.set_title('Test Accuracy (Sequence Reversal)', fontsize=13, fontweight='bold')
    ax.set_ylim(0, 1.05)
    ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='Perfect')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_progress.png', dpi=100, bbox_inches='tight')
plt.show()

In [None]:
# Show example predictions
model.eval()
with torch.no_grad():
    # Pick a few test examples
    n_show = 8
    tgt_mask = Transformer.generate_causal_mask(tgt_in_test.size(1))
    test_logits = model(src_test[:n_show], tgt_in_test[:n_show], tgt_mask=tgt_mask)
    test_preds = test_logits.argmax(dim=-1)

print("Example Predictions (Sequence Reversal)")
print("=" * 60)
for i in range(n_show):
    src = src_test[i].tolist()
    expected = tgt_out_test[i].tolist()
    predicted = test_preds[i].tolist()
    
    # Remove padding and special tokens for display
    src_clean = [t for t in src if t > 2]
    exp_clean = [t for t in expected if t > 2]
    pred_clean = [t for t in predicted if t > 2][:len(exp_clean)]
    
    match = "OK" if exp_clean == pred_clean else "WRONG"
    print(f"  Input:    {src_clean}")
    print(f"  Expected: {exp_clean}")
    print(f"  Got:      {pred_clean}  [{match}]")
    print()
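
The predictions above are teacher-forced: the decoder is given the ground-truth prefix at every position. At inference time no target is available, so tokens have to be generated one at a time. A minimal greedy-decoding sketch follows. It relies only on the `encode`, `decode`, and `output_projection` members of the `Transformer` class defined earlier; the function name `greedy_decode` and its `max_new_tokens` argument are hypothetical.

```python
import torch

def greedy_decode(model, src, max_new_tokens, sos_id=1, eos_id=2):
    """Generate target tokens one at a time, feeding predictions back in."""
    model.eval()
    with torch.no_grad():
        encoder_output = model.encode(src)    # the source is encoded once
        ys = torch.full((src.size(0), 1), sos_id, dtype=torch.long)
        finished = torch.zeros(src.size(0), dtype=torch.bool)
        for _ in range(max_new_tokens):
            t = ys.size(1)
            tgt_mask = torch.tril(torch.ones(t, t)).unsqueeze(0).unsqueeze(0)
            decoder_output = model.decode(ys, encoder_output, tgt_mask=tgt_mask)
            # Only the last position predicts the next token
            next_token = model.output_projection(decoder_output[:, -1]).argmax(dim=-1)
            ys = torch.cat([ys, next_token.unsqueeze(1)], dim=1)
            finished |= next_token == eos_id
            if finished.all():
                break
    return ys[:, 1:]   # drop the leading SOS
```

A call such as `greedy_decode(model, src_test[:8], max_new_tokens=SEQ_LEN + 1)` would stay within the positional-encoding length used above. Sequences that emit `EOS` early keep generating until every sequence in the batch has finished, so anything after the first `EOS` should be discarded when comparing against the expected output.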

In [None]:
# Visualize attention patterns on a test example (teacher-forced forward pass)
model.eval()

# Get a single example
example_idx = 0
src_example = src_test[example_idx:example_idx+1]
tgt_example = tgt_in_test[example_idx:example_idx+1]

# Forward pass, collecting attention weights from all layers
attention_maps = {'encoder_self': [], 'decoder_self': [], 'decoder_cross': []}

with torch.no_grad():
    # Encoder pass
    enc_x = model.src_embedding(src_example) * math.sqrt(model.d_model)
    enc_x = model.positional_encoding(enc_x)
    for block in model.encoder_blocks:
        attn_out, attn_weights = block.self_attention(enc_x, enc_x, enc_x)
        attention_maps['encoder_self'].append(attn_weights.squeeze(0).numpy())
        enc_x = block.norm1(enc_x + attn_out)
        ffn_out = block.ffn(enc_x)
        enc_x = block.norm2(enc_x + ffn_out)
    
    encoder_output = enc_x
    
    # Decoder pass
    tgt_mask = Transformer.generate_causal_mask(tgt_example.size(1))
    dec_x = model.tgt_embedding(tgt_example) * math.sqrt(model.d_model)
    dec_x = model.positional_encoding(dec_x)
    for block in model.decoder_blocks:
        self_attn_out, self_attn_w = block.self_attention(dec_x, dec_x, dec_x, tgt_mask)
        attention_maps['decoder_self'].append(self_attn_w.squeeze(0).numpy())
        dec_x = block.norm1(dec_x + self_attn_out)
        
        cross_attn_out, cross_attn_w = block.cross_attention(dec_x, encoder_output, encoder_output)
        attention_maps['decoder_cross'].append(cross_attn_w.squeeze(0).numpy())
        dec_x = block.norm2(dec_x + cross_attn_out)
        
        ffn_out = block.ffn(dec_x)
        dec_x = block.norm3(dec_x + ffn_out)

# Plot cross-attention from last decoder layer (most interpretable)
src_tokens = [str(t) for t in src_test[example_idx].tolist() if t > 0]
tgt_tokens = [str(t) for t in tgt_in_test[example_idx].tolist() if t > 0]

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
n_heads = attention_maps['decoder_cross'][-1].shape[0]
n_show_heads = min(4, n_heads)

for h in range(n_show_heads):
    ax = axes[h]
    attn = attention_maps['decoder_cross'][-1][h]
    # Trim to non-padding tokens
    n_tgt = len(tgt_tokens)
    n_src = len(src_tokens)
    attn_trimmed = attn[:n_tgt, :n_src]
    
    im = ax.imshow(attn_trimmed, cmap='Blues', aspect='auto', vmin=0, vmax=1)
    ax.set_xticks(range(n_src))
    ax.set_xticklabels(src_tokens, fontsize=8)
    ax.set_yticks(range(n_tgt))
    ax.set_yticklabels(tgt_tokens, fontsize=8)
    ax.set_xlabel('Source (encoder)', fontsize=9)
    if h == 0:
        ax.set_ylabel('Target (decoder)', fontsize=9)
    ax.set_title(f'Head {h+1}', fontsize=11, fontweight='bold')

plt.suptitle('Cross-Attention Patterns (Last Decoder Layer)\nFor sequence reversal task',
             fontsize=13, fontweight='bold', y=1.05)
plt.tight_layout()
plt.savefig('attention_patterns.png', dpi=100, bbox_inches='tight')
plt.show()

print("Each heatmap shows what source tokens the decoder attends to at each step.")
print("For reversal, we expect the decoder to attend to source tokens in reverse order.")
print("Different heads may specialize in different patterns.")

---

## 7. Why Transformers Work So Well

### Intuitive Explanation

The Transformer has become the dominant architecture across NLP, vision, speech, biology, and more. Here is why:

**1. Parallel Computation**
Unlike RNNs, which must process tokens one after another, Transformers process all tokens of a sequence in parallel during training. The number of *sequential* steps per layer is therefore constant instead of growing with sequence length, which maps well onto GPUs. Total compute still grows with length (attention is O(n^2)), but a 100-token sequence no longer requires 100 dependent steps.

**2. Direct Connections Between All Positions**
In an RNN, information from position 1 must travel through positions 2, 3, 4, ... to reach position 100. Each step risks degrading the signal. In a Transformer, position 1 and position 100 are directly connected through attention -- information flows in a single step.

**3. Multiple Representation Subspaces (Heads)**
Multi-head attention creates multiple parallel "views" of the relationships between tokens. Different heads can specialize: one for syntax, one for coreference, one for semantic similarity, etc. This is like having multiple experts analyzing the same data simultaneously.

**4. Depth + Residual Connections**
Stacking multiple layers with residual connections lets the model build increasingly abstract representations while maintaining stable gradient flow. Each layer can focus on refining the representation rather than learning everything at once.

**5. Scalability**
Transformers exhibit remarkable **scaling laws**: performance improves predictably as you increase model size, data, and compute. This predictability enables massive investments in training large models.

**F1 analogy:** The Transformer's dominance mirrors why modern F1 strategy systems outperform human-only decision making:
1. **Parallel computation** = processing all 20 cars' telemetry simultaneously, not one at a time
2. **Direct connections** = any data point can inform any decision without going through intermediaries
3. **Multiple heads** = like having separate specialists for tires, weather, competitors, and fuel -- all analyzing in parallel
4. **Depth + residuals** = layered analysis that builds from raw data to strategy without losing the original signals
5. **Scalability** = more compute and data predictably leads to better strategy calls
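
To make point 4 concrete, here is a toy sketch (not the notebook's Transformer blocks) that compares how much gradient reaches the input of a 30-layer stack with and without residual connections. The helper `input_grad_norm`, the depth, and the width are illustrative assumptions; each layer is a plain `tanh(Linear(x))` with PyTorch's default initialization and no LayerNorm.

```python
import torch
import torch.nn as nn


def input_grad_norm(use_residual, depth=30, d=64, seed=0):
    """Norm of d(sum of outputs)/d(input) for a toy stack of tanh layers."""
    torch.manual_seed(seed)
    layers = [nn.Linear(d, d) for _ in range(depth)]
    x = torch.randn(8, d, requires_grad=True)
    h = x
    for layer in layers:
        update = torch.tanh(layer(h))
        # Residual path: add the update to the running representation
        h = h + update if use_residual else update
    h.sum().backward()
    return x.grad.norm().item()


plain = input_grad_norm(use_residual=False)
residual = input_grad_norm(use_residual=True)
print(f"plain stack:    {plain:.3e}")
print(f"residual stack: {residual:.3e}")
```

On this toy stack the plain version's input gradient should come out orders of magnitude smaller, which is the "stable gradient flow" argument in miniature. Exact values depend on the seed and initialization.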

### Deep Dive: Scaling Laws Preview

Research by Kaplan et al. (2020) and Hoffmann et al. (2022, "Chinchilla") revealed that Transformer performance follows **power law** relationships:

$$L(N) \approx \left(\frac{N_c}{N}\right)^{\alpha_N}$$

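where $L$ is the test loss, $N$ is the number of non-embedding parameters, $N_c$ is a fitted constant, and $\alpha_N \approx 0.076$ (the fit reported by Kaplan et al., which holds when data and compute are not the bottleneck). Because the exponent is small, each doubling of $N$ lowers the loss by only about 5%.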
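
A short worked example of the formula: since $L(kN)/L(N) = k^{-\alpha_N}$, the constant $N_c$ cancels and the relative improvement depends only on the scale factor $k$. The helper name `loss_ratio` is hypothetical, and the sketch assumes the parameter-count power law holds over the whole range.

```python
ALPHA_N = 0.076  # exponent quoted in the text above


def loss_ratio(scale_factor, alpha=ALPHA_N):
    """Return L(k * N) / L(N) = k ** (-alpha); N_c cancels out."""
    return scale_factor ** (-alpha)


for k in (2, 10, 100):
    print(f"{k:>4}x parameters -> loss multiplied by {loss_ratio(k):.3f}")
```

Under this fit, even a 100x larger model only cuts the loss by roughly 30%, which is why scaling is described as predictable but expensive.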

| Factor | Effect on Performance | Key Finding | F1 Parallel |
|--------|----------------------|-------------|-------------|
| **Parameters (N)** | Loss decreases as power law | Bigger models are more sample-efficient | More sensors = better data picture |
| **Data (D)** | Loss decreases as power law | More data helps, provided the model is large enough to use it | More laps of practice = better setup |
| **Compute (C)** | Loss decreases as power law | Optimal N and D scale together | More simulation budget = better strategy |
| **Architecture** | Relatively minor effect | Scaling matters more than architecture tweaks | Car regulations matter less than development budget at scale |
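
The "optimal N and D scale together" row can be sketched with two widely used approximations, both assumptions here and not exact laws: training compute $C \approx 6ND$ FLOPs, and roughly 20 training tokens per parameter as a rule of thumb drawn from the Chinchilla results. The function name `compute_optimal_split` is hypothetical.

```python
import math


def compute_optimal_split(compute_flops, tokens_per_param=20.0):
    """Split a FLOP budget into (parameters, tokens).

    Assumes C ~ 6 * N * D and D ~ tokens_per_param * N,
    which gives N = sqrt(C / (6 * tokens_per_param)).
    """
    n_params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens


for budget in (1e21, 1e22, 5.88e23):
    n, d = compute_optimal_split(budget)
    print(f"C={budget:.2e} FLOPs -> N ~ {n:.2e} params, D ~ {d:.2e} tokens")
```

With a budget of about 5.88e23 FLOPs this rule recovers roughly 70B parameters and 1.4T tokens, the configuration of the Chinchilla model itself. Note that both N and D grow with the square root of compute under these assumptions.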

#### Key Insight

The Transformer's real superpower is not any single clever trick -- it is that the architecture **scales**. As you add more parameters, more data, and more compute, performance improves predictably. This is why the field moved from "design clever architectures" to "scale Transformers."

**F1 analogy:** This mirrors F1's development philosophy: once the fundamental car concept is sound (like the Transformer architecture), the teams that win are the ones that can scale investment -- more wind tunnel time, more CFD simulation, more track testing data. The architecture is the baseline car design; scaling laws are the development curve.

#### Common Misconceptions

| Misconception | Reality |
|---------------|---------|
| "Attention is computationally cheap" | Attention is O(n^2) in sequence length -- this is actually a major bottleneck for long sequences |
| "Transformers understand language" | They learn statistical patterns; whether this constitutes understanding is debated |
| "Bigger is always better" | Chinchilla showed that many models were undertrained -- the optimal balance of size and data matters |
| "Transformers are only for NLP" | Vision Transformers (ViT), protein folding (AlphaFold), music, robotics -- they are used well beyond NLP |
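
The first row is easy to quantify. The sketch below estimates the memory needed for one layer's attention weights, assuming float32 values, one n x n matrix per head, and a batch of one. The helper name `attention_matrix_bytes` is hypothetical, and optimized implementations may avoid materializing the full matrix.

```python
def attention_matrix_bytes(seq_len, n_heads, bytes_per_value=4):
    """Bytes needed to store one layer's attention weights for a single sequence."""
    return n_heads * seq_len * seq_len * bytes_per_value


for n in (512, 1024, 2048, 4096):
    mib = attention_matrix_bytes(n, n_heads=12) / 2**20
    print(f"n={n:>5d}: {mib:8.1f} MiB per layer")
```

Doubling the sequence length quadruples the cost, so going from 512 to 4096 tokens multiplies this term by 64.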

In [None]:
# Interactive exploration: How architecture choices affect parameter count
def count_transformer_params(d_model, n_heads, n_layers, d_ff, vocab_size):
    """
    Count parameters in a Transformer (encoder-decoder).
    
    Returns dict with parameter breakdown.
    """
    params = {}
    
    # Embeddings (src + tgt, often shared)
    params['src_embedding'] = vocab_size * d_model
    params['tgt_embedding'] = vocab_size * d_model
    
    # Per encoder layer
    enc_attn = 4 * d_model * d_model + 4 * d_model  # Q,K,V,O weights + biases
    enc_ffn = d_model * d_ff + d_ff + d_ff * d_model + d_model  # Two linear layers
    enc_norm = 2 * (2 * d_model)  # Two LayerNorms (gamma + beta each)
    enc_per_layer = enc_attn + enc_ffn + enc_norm
    params['encoder'] = n_layers * enc_per_layer
    
    # Per decoder layer (extra cross-attention)
    dec_self_attn = 4 * d_model * d_model + 4 * d_model
    dec_cross_attn = 4 * d_model * d_model + 4 * d_model
    dec_ffn = d_model * d_ff + d_ff + d_ff * d_model + d_model
    dec_norm = 3 * (2 * d_model)  # Three LayerNorms
    dec_per_layer = dec_self_attn + dec_cross_attn + dec_ffn + dec_norm
    params['decoder'] = n_layers * dec_per_layer
    
    # Output projection
    params['output_proj'] = d_model * vocab_size + vocab_size
    
    params['total'] = sum(params.values())
    return params

# Compare different configurations
configs = [
    ("Our model", 64, 4, 2, 256, 20),
    ("Small", 256, 4, 4, 1024, 32000),
    ("Base (paper)", 512, 8, 6, 2048, 37000),
    ("Large (paper)", 1024, 16, 6, 4096, 37000),
    ("GPT-2 Small*", 768, 12, 12, 3072, 50257),
    ("GPT-3 175B*", 12288, 96, 96, 49152, 50257),
]

print("Transformer Parameter Counts Across Configurations")
print("=" * 80)
print(f"{'Config':18s} {'d_model':>7s} {'heads':>5s} {'layers':>6s} {'d_ff':>6s} {'vocab':>7s} {'Total Params':>14s}")
print("-" * 80)

for name, d, h, n_layers, ff, v in configs:
    params = count_transformer_params(d, h, n_layers, ff, v)
    total = params['total']
    if total > 1e9:
        total_str = f"{total/1e9:.1f}B"
    elif total > 1e6:
        total_str = f"{total/1e6:.1f}M"
    elif total > 1e3:
        total_str = f"{total/1e3:.1f}K"
    else:
        total_str = str(total)
    # Column widths match the header row above
    print(f"{name:18s} {d:>7d} {h:>5d} {n_layers:>6d} {ff:>6d} {v:>7d} {total_str:>14s}")

print("-" * 80)
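print("* GPT models are decoder-only; this function counts an encoder-decoder with the same")
print("  hyperparameters, so these totals are well above the real models' sizes")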

# Show parameter distribution for the base model
print("\n--- Parameter Distribution for Base Model (d=512, L=6) ---")
params = count_transformer_params(512, 8, 6, 2048, 37000)
total = params['total']
for component, count in params.items():
    if component != 'total':
        pct = 100 * count / total
        bar = '#' * int(pct / 2)
        print(f"  {component:18s}: {count:>12,} ({pct:5.1f}%) {bar}")
print(f"  {'TOTAL':18s}: {total:>12,}")
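
As a sanity check on the per-encoder-layer formula used above, the sketch below compares it with the parameter count of PyTorch's built-in `nn.TransformerEncoderLayer`. This is an assumption for illustration: the built-in layer stands in for the notebook's own encoder block, which should match if it uses biased linear layers and two LayerNorms. The helper name `encoder_layer_formula` is hypothetical.

```python
import torch.nn as nn


def encoder_layer_formula(d_model, d_ff):
    """Per-encoder-layer parameter count, using the same terms as the cell above."""
    attn = 4 * d_model * d_model + 4 * d_model                # Q, K, V, O weights + biases
    ffn = d_model * d_ff + d_ff + d_ff * d_model + d_model    # two linear layers
    norms = 2 * (2 * d_model)                                 # two LayerNorms
    return attn + ffn + norms


d_model, n_heads, d_ff = 64, 4, 256
layer = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward=d_ff)
actual = sum(p.numel() for p in layer.parameters())
expected = encoder_layer_formula(d_model, d_ff)
print(f"PyTorch layer: {actual:,}  formula: {expected:,}")
```

Note that the number of heads does not appear in the formula: the heads split `d_model` between them, so changing `n_heads` leaves the parameter count unchanged.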

---

## Exercises

### Exercise 1: Implement Positional Encoding from Scratch

Implement the sinusoidal positional encoding without looking at the code above. Verify that your implementation matches.

**F1 framing:** You are building the lap/sector encoding system for the strategy computer. Each position (lap) needs a unique fingerprint composed of multiple frequency channels, just like how each moment in a race is uniquely identified by the combination of lap number, sector, tire age, and fuel load.

In [None]:
# EXERCISE 1: Implement sinusoidal positional encoding
def my_positional_encoding(max_len, d_model):
    """
    Generate sinusoidal positional encoding.
    
    PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
    
    Args:
        max_len: Maximum sequence length
        d_model: Model dimension (must be even)
    
    Returns:
        numpy array of shape (max_len, d_model)
    """
    # TODO: Implement this!
    # Hint 1: Create position array [0, 1, ..., max_len-1] as column vector
    # Hint 2: Create dimension array [0, 2, 4, ...] for the 2i values
    # Hint 3: Compute 10000^(2i/d_model) as the denominator
    # Hint 4: Fill even columns with sin, odd columns with cos
    
    pass

# Test
my_pe = my_positional_encoding(50, 64)
expected_pe = get_positional_encoding(50, 64).numpy()

if my_pe is not None:
    print(f"Your shape: {my_pe.shape}")
    print(f"Expected shape: {expected_pe.shape}")
    # The reference is presumably computed in float32, so allow a slightly looser tolerance
    print(f"Match: {np.allclose(my_pe, expected_pe, atol=1e-5)}")
    print(f"Max difference: {np.max(np.abs(my_pe - expected_pe)):.8f}")
else:
    print("Implement the function above!")
    print(f"Expected shape: {expected_pe.shape}")
    print(f"First row should start with: {expected_pe[0, :6].round(4)}")

### Exercise 2: Build a Decoder-Only Transformer (like GPT)

Many modern language models (GPT, LLaMA) use only the decoder, with no encoder and no cross-attention. Implement a decoder-only Transformer that does autoregressive prediction.

**F1 framing:** Build a race commentary generator -- a decoder-only model that, given the sequence of events so far ("Safety car deployed, Verstappen pits, ..."), predicts the next event. No encoder needed because there is no separate input to translate; just the running sequence of race events, generated one at a time.

In [None]:
# EXERCISE 2: Decoder-only Transformer
class DecoderOnlyBlock(nn.Module):
    """
    A single block for a decoder-only Transformer.
    Like an encoder block but with a causal mask on self-attention.
    No cross-attention needed.
    """
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        # TODO: Implement this!
        # Hint: Same structure as EncoderBlock, but you'll apply a causal mask
        # in the forward pass
        
        pass
    
    def forward(self, x, mask=None):
        # TODO: Implement this!
        # Hint: Self-attention with causal mask + Add&Norm + FFN + Add&Norm
        
        pass


class DecoderOnlyTransformer(nn.Module):
    """
    GPT-style decoder-only Transformer.
    
    Takes a sequence of token IDs and predicts the next token at each position.
    """
    def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=2, 
                 d_ff=256, max_len=100, dropout=0.1):
        super().__init__()
        # TODO: Implement this!
        # Components needed:
        # 1. Token embedding
        # 2. Positional encoding 
        # 3. Stack of DecoderOnlyBlocks
        # 4. Output projection (Linear to vocab_size)
        
        pass
    
    def forward(self, x):
        # TODO: Implement this!
        # 1. Embed tokens and add positional encoding
        # 2. Create causal mask
        # 3. Pass through all blocks
        # 4. Project to vocab size
        # Return logits
        
        pass

# Test (uncomment when implemented)
# gpt = DecoderOnlyTransformer(vocab_size=100, d_model=64, n_heads=4, n_layers=2, d_ff=256)
# x = torch.randint(0, 100, (2, 20))
# logits = gpt(x)
# print(f"Input:  {x.shape}")
# print(f"Output: {logits.shape}  (should be [2, 20, 100])")
# print(f"Params: {sum(p.numel() for p in gpt.parameters()):,}")
print("Implement DecoderOnlyBlock and DecoderOnlyTransformer, then uncomment the test!")

### Exercise 3: Experiment with Architecture Hyperparameters

Train the reversal model with different hyperparameters and observe the effect on learning speed and final accuracy.

**F1 framing:** Think of this as testing different strategy computer configurations. Does a "wider" system (larger d_model, like more sensor channels) learn faster than a "deeper" one (more layers, like more analysis stages)? How does the number of attention heads (specialist analysts) affect performance?

In [None]:
# EXERCISE 3: Hyperparameter exploration
# Try training with different configurations and compare results.
# Fill in the results table after running each experiment.

configs_to_try = {
    "baseline":    {"d_model": 64,  "n_heads": 4, "n_layers": 2, "d_ff": 256},
    "deeper":      {"d_model": 64,  "n_heads": 4, "n_layers": 4, "d_ff": 256},
    "wider":       {"d_model": 128, "n_heads": 4, "n_layers": 2, "d_ff": 512},
    "more_heads":  {"d_model": 64,  "n_heads": 8, "n_layers": 2, "d_ff": 256},
    "smaller_ffn": {"d_model": 64,  "n_heads": 4, "n_layers": 2, "d_ff": 128},
}

# TODO: Train each configuration for 30 epochs and record:
# 1. Final training loss
# 2. Test accuracy
# 3. Total parameter count
# 4. Observations about convergence speed

# Example for one config:
# config = configs_to_try["baseline"]
# model = Transformer(VOCAB_SIZE, VOCAB_SIZE, **config)
# ... train ...
# Record results

print("Experiment with the configurations above!")
print("Key questions to answer:")
print("  1. Does depth (more layers) or width (larger d_model) help more?")
print("  2. Do more attention heads improve accuracy?")
print("  3. What's the smallest model that can solve this task perfectly?")
print("  4. Is there a point of diminishing returns?")

---

## Summary

### Key Concepts

**The Transformer Architecture:**
- Introduced in "Attention Is All You Need" (Vaswani et al., 2017)
- Replaces recurrence with self-attention for full parallelization
- Encoder-decoder structure, though many modern variants use only one side
- Every sub-layer wrapped with residual connections and layer normalization

**Positional Encoding:**
- Self-attention is permutation equivariant: shuffling the input tokens just shuffles the outputs, so without added position information it has no sense of order
- Sinusoidal encoding uses sin/cos waves at different frequencies
- Creates unique, smooth position representations
- Modern alternatives: learned embeddings, rotary positional encoding (RoPE)
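
The order-blindness claim can be checked directly. The sketch below uses PyTorch's built-in `nn.MultiheadAttention` (an assumption for brevity; the notebook's own attention class should behave the same way) with no positional encoding. It shows that shuffling the input tokens simply shuffles the output rows in the same way, a property usually called permutation equivariance.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
attn.eval()  # disable dropout so both passes are deterministic

x = torch.randn(1, 6, 16)   # (batch, seq_len, d_model), no positional encoding added
perm = torch.randperm(6)    # a random reordering of the 6 tokens

with torch.no_grad():
    out, _ = attn(x, x, x)
    out_perm, _ = attn(x[:, perm], x[:, perm], x[:, perm])

# Permuting the inputs permutes the outputs identically
equivariant = torch.allclose(out[:, perm], out_perm, atol=1e-5)
print(f"Outputs match after reordering: {equivariant}")
```

Since each token's output is the same wherever it sits in the sequence, the layer alone cannot tell "A B C" from "C B A". Adding a positional encoding to `x` before the attention call breaks this symmetry.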

**Encoder Block (2 sub-layers):**
1. Multi-head self-attention (bidirectional)
2. Position-wise feed-forward network
- Each with Add & Norm (residual + LayerNorm)

**Decoder Block (3 sub-layers):**
1. Masked multi-head self-attention (causal)
2. Multi-head cross-attention (Q from decoder, K/V from encoder)
3. Position-wise feed-forward network
- Each with Add & Norm

**Training:**
- Teacher forcing: feed ground truth target during training
- Learning rate warmup is critical for stability
- Label smoothing improves generalization; gradient clipping keeps updates stable
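
As a reminder of what warmup looks like, here is a minimal sketch of the schedule from the original paper, lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5). The function name `noam_lr` is a hypothetical label, and the defaults (d_model=512, 4000 warmup steps) are the paper's base settings, which may differ from what this notebook's training loop used.

```python
def noam_lr(step, d_model=512, warmup_steps=4000):
    """Linear warmup followed by inverse-square-root decay."""
    step = max(step, 1)  # avoid 0 ** -0.5 at the first step
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for s in (1, 1000, 4000, 16000, 64000):
    print(f"step {s:>6d}: lr = {noam_lr(s):.6f}")
```

The rate rises linearly until `warmup_steps`, peaks there, and then decays: quadrupling the step count after the peak halves the learning rate.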

### Connection to Deep Learning

| Concept | Where it appears in modern AI | F1 Parallel |
|---------|-------------------------------|-------------|
| Self-attention | Core of all Transformer-based models (GPT, BERT, T5, ViT) | Every data stream cross-referencing every other in parallel |
| Positional encoding | Every Transformer; RoPE in LLaMA, learned in GPT-2 | Lap number and track position tagging for telemetry |
| Encoder-only | BERT, RoBERTa (bidirectional understanding) | Full session analysis -- sees everything at once |
| Decoder-only | GPT, LLaMA, Claude (autoregressive generation) | Real-time strategy calls, one decision at a time |
| Encoder-decoder | T5, BART, mBART (seq-to-seq tasks) | Telemetry (encoder) translated into strategy calls (decoder) |
| Residual + LayerNorm | Universal in deep learning; enables very deep networks | Preserving baseline signal and normalizing across conditions |
| FFN layers | Store factual knowledge; recent research on "neurons as features" | Per-channel signal processing and calibration |
| Scaling laws | Drive modern AI investment decisions; compute-optimal training | More development budget predictably yields faster cars |

### Checklist

Before moving on, make sure you can:

- [ ] Explain why Transformers replaced RNNs (parallelism, direct connections, scalability)
- [ ] Implement sinusoidal positional encoding and explain why position information is needed
- [ ] Draw the data flow through an encoder block (self-attention + Add&Norm + FFN + Add&Norm)
- [ ] Draw the data flow through a decoder block (masked self-attn + cross-attn + FFN, each with Add&Norm)
- [ ] Explain the difference between self-attention and cross-attention
- [ ] Explain why the feed-forward network is needed (nonlinearity + per-position processing)
- [ ] Assemble a complete Transformer and trace a forward pass
- [ ] Count parameters in a Transformer given its hyperparameters
- [ ] Explain teacher forcing and learning rate warmup
- [ ] Describe the difference between encoder-only, decoder-only, and encoder-decoder models

---

## Next Steps

You have now built a Transformer **from scratch** -- the architecture that powers modern AI. You understand every component: embeddings, positional encoding, multi-head attention, feed-forward networks, residual connections, layer normalization, and how they all fit together. In F1 terms, you have built the complete strategy computer from individual components: sensor encoding, parallel data processing, cross-referencing, signal normalization, and the strategy-telemetry bridge.

In the next notebook, **Part 6.4: Language Models**, you will see how this architecture is used in practice:

- **Autoregressive language modeling** (GPT-style): predicting the next token
- **Masked language modeling** (BERT-style): predicting masked tokens
- **Tokenization**: BPE, WordPiece, and how text becomes tokens
- **Generation strategies**: greedy, beam search, top-k, top-p sampling
- **The journey from Transformer to ChatGPT**

You now have all the architectural knowledge needed to understand how large language models work. The next notebook will show you how they are trained and used.