# Part 4.4: Attention Mechanisms -- The Formula 1 Edition

Attention is arguably the single most important idea in modern deep learning. It solves a fundamental problem with sequence models: how do you let a network **focus** on the most relevant parts of its input? Before attention, encoder-decoder models had to compress an entire input sequence into a single fixed-size vector -- a brutal bottleneck. Attention removes that bottleneck and is the core building block of Transformers, the architecture behind GPT, BERT, and most state-of-the-art models today.

**F1 analogy:** Imagine you are a race strategist trying to predict how a driver will perform on lap 45. Without attention, you have to cram the entire 44-lap history into a single summary vector -- losing critical detail. With attention, you can ask: "Which past laps are most relevant to predicting THIS lap?" Maybe lap 1 (tire compound choice) and lap 38 (the last pit stop) matter enormously, while laps 10-30 (uneventful clean-air running) are nearly irrelevant. Attention lets the model dynamically decide what to focus on, just as a strategist zeroes in on the moments that matter.

---

## Learning Objectives

By the end of this notebook, you should be able to:

- [ ] Explain why fixed-length encoding is a bottleneck for sequence models
- [ ] Describe the Query, Key, Value framework using the database lookup analogy
- [ ] Compute dot-product attention step by step
- [ ] Explain why we scale by sqrt(d_k) in scaled dot-product attention
- [ ] Implement self-attention from scratch in PyTorch
- [ ] Implement multi-head attention and explain why multiple heads help
- [ ] Compare additive (Bahdanau) vs multiplicative (Luong) attention
- [ ] Build and train a sequence model with attention

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')
torch.manual_seed(42)
np.random.seed(42)

---

## 1. The Problem with Fixed-Length Encoding

### Intuitive Explanation

Imagine you're reading a long novel, and someone asks you to summarize the **entire book in a single sentence**. No matter how good you are, you'll lose details. Early words, subtle plot points, and minor characters will be forgotten.

This is exactly the problem with encoder-decoder models (like sequence-to-sequence RNNs). The encoder reads the entire input sequence and compresses it into a **single fixed-size vector** -- the "context vector." The decoder then has to generate the entire output from only that one vector.

**The bottleneck:** As the input sequence gets longer, more and more information gets squeezed into the same small vector, and early parts of the sequence are gradually overwritten.

| Sequence Length | What Happens | Analogy | F1 Parallel |
|---------------|--------------|---------|-------------|
| Short (5-10 tokens) | Works okay | Summarize a paragraph in one sentence | Predicting the next lap from 5 laps of data -- manageable |
| Medium (20-50 tokens) | Starts losing detail | Summarize a chapter in one sentence | Compressing a full stint into one vector -- losing corner-by-corner detail |
| Long (100+ tokens) | Severe information loss | Summarize a book in one sentence | Cramming an entire 78-lap race into a single vector -- lap 1 details are gone |

**The key insight:** Instead of forcing the decoder to work from a single summary vector, let it **look back at all the encoder states** and decide which ones are relevant at each step.

**F1 analogy:** A fixed-length encoding is like asking an engineer to summarize an entire Grand Prix in a single number. Attention is like giving them the full lap chart and letting them look up any specific lap whenever they need to. When predicting lap 45, they can glance at lap 38 (last pit stop), lap 1 (compound choice), and lap 44 (current degradation rate) -- different queries pull different information.
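To make the contrast concrete, here is a minimal NumPy sketch. It assumes encoder states `H` of shape `(n, d)` and a decoder state `s` of shape `(d,)`, uses the last encoder state as the fixed-length context (a common simplification), and scores encoder states with a plain dot product. `fixed_context` and `attention_context` are hypothetical helper names, and the random states stand in for a trained encoder. The point to notice: the fixed context is the same vector no matter what the decoder needs, while the attention context is recomputed for each decoder state.

```python
import numpy as np

def softmax(x):
    # Subtract the max for numerical stability; the result sums to 1
    e = np.exp(x - x.max())
    return e / e.sum()

def fixed_context(H):
    # Without attention: the decoder only ever sees the final encoder state
    return H[-1]

def attention_context(H, s):
    # With attention: score every encoder state against the decoder state s,
    # then return a relevance-weighted average of ALL encoder states
    weights = softmax(H @ s)          # (n,) one weight per encoder position
    return weights @ H, weights       # context (d,), weights (n,)

rng = np.random.default_rng(0)
n, d = 44, 64                         # e.g. 44 past laps, 64-dim hidden states
H = rng.normal(size=(n, d))           # random stand-ins for encoder states

# Two decoder states that "ask about" different laps (index 0 = lap 1, index 37 = lap 38)
s_start = H[0] + 0.1 * rng.normal(size=d)
s_pit = H[37] + 0.1 * rng.normal(size=d)

_, w_start = attention_context(H, s_start)
_, w_pit = attention_context(H, s_pit)

print("Fixed context shape:", fixed_context(H).shape, "(same vector at every decoder step)")
print("Query about lap 1  -> most weight on index", w_start.argmax())
print("Query about lap 38 -> most weight on index", w_pit.argmax())
```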

### Visualization: Information Loss as Sequence Grows

In [None]:
# Simulate how much information about each position is retained
# in a fixed-length context vector as sequence length grows

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, seq_len, title in zip(axes, [5, 15, 50], 
                               ['Short Sequence (5 tokens)', 
                                'Medium Sequence (15 tokens)',
                                'Long Sequence (50 tokens)']):
    # Simulate exponential decay of information retention
    # Earlier positions lose more info as they get overwritten
    positions = np.arange(seq_len)
    decay_rate = 0.15
    retention = np.exp(-decay_rate * (seq_len - 1 - positions))
    retention = retention / retention.max()  # Normalize
    
    colors = plt.cm.RdYlGn(retention)  # Red=lost, Green=retained
    ax.bar(positions, retention, color=colors, edgecolor='black', linewidth=0.5)
    ax.set_xlabel('Position in Sequence')
    ax.set_ylabel('Information Retained')
    ax.set_title(title)
    ax.set_ylim(0, 1.1)
    ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.3)
    ax.grid(True, alpha=0.3)

plt.suptitle('Fixed-Length Encoding: Early Positions Lose Information', fontsize=14)
plt.tight_layout()
plt.show()

print("Notice: As the sequence gets longer, early positions retain almost no information.")
print("This is the fundamental bottleneck that attention solves.")

In [None]:
# Visualize: Fixed encoding vs Attention
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Fixed-length encoding (bottleneck)
ax = axes[0]

# Draw encoder states
for i in range(6):
    ax.add_patch(plt.Rectangle((i - 0.3, 1.7), 0.6, 0.6, 
                               facecolor='steelblue', edgecolor='black', linewidth=1.5))
    ax.text(i, 2.0, f'h{i}', ha='center', va='center', fontsize=10, fontweight='bold', color='white')

# Draw bottleneck
ax.add_patch(plt.Rectangle((2.2, 0.9), 1.6, 0.6, 
                           facecolor='red', edgecolor='black', linewidth=2, alpha=0.8))
ax.text(3, 1.2, 'Context\nVector', ha='center', va='center', fontsize=9, fontweight='bold')

# Arrows from all encoder states to bottleneck
for i in range(6):
    ax.annotate('', xy=(3, 1.5), xytext=(i, 1.7),
               arrowprops=dict(arrowstyle='->', color='gray', lw=1))

# Draw decoder states
for i in range(4):
    ax.add_patch(plt.Rectangle((i + 1 - 0.3, -0.3), 0.6, 0.6, 
                               facecolor='orange', edgecolor='black', linewidth=1.5))
    ax.text(i + 1, 0.0, f's{i}', ha='center', va='center', fontsize=10, fontweight='bold')

# Arrow from bottleneck to decoder
for i in range(4):
    ax.annotate('', xy=(i + 1, 0.3), xytext=(3, 0.9),
               arrowprops=dict(arrowstyle='->', color='red', lw=1.5))

ax.set_xlim(-1, 6.5)
ax.set_ylim(-1, 3)
ax.set_title('Without Attention: Bottleneck', fontsize=13)
ax.text(3, 2.8, 'Encoder', ha='center', fontsize=11, color='steelblue', fontweight='bold')
ax.text(2.5, -0.8, 'Decoder', ha='center', fontsize=11, color='orange', fontweight='bold')
ax.axis('off')

# Right: Attention (direct connections)
ax = axes[1]

# Draw encoder states
for i in range(6):
    ax.add_patch(plt.Rectangle((i - 0.3, 1.7), 0.6, 0.6, 
                               facecolor='steelblue', edgecolor='black', linewidth=1.5))
    ax.text(i, 2.0, f'h{i}', ha='center', va='center', fontsize=10, fontweight='bold', color='white')

# Draw decoder states
for i in range(4):
    ax.add_patch(plt.Rectangle((i + 1 - 0.3, -0.3), 0.6, 0.6, 
                               facecolor='orange', edgecolor='black', linewidth=1.5))
    ax.text(i + 1, 0.0, f's{i}', ha='center', va='center', fontsize=10, fontweight='bold')

# Attention connections with varying thickness
np.random.seed(42)
for j in range(4):
    weights = np.random.dirichlet(np.ones(6) * 0.5)  # Random attention weights
    for i in range(6):
        alpha = weights[i]
        ax.annotate('', xy=(j + 1, 0.3), xytext=(i, 1.7),
                   arrowprops=dict(arrowstyle='->', color='green', 
                                  lw=alpha * 5, alpha=max(alpha, 0.1)))

ax.set_xlim(-1, 6.5)
ax.set_ylim(-1, 3)
ax.set_title('With Attention: Direct Access', fontsize=13)
ax.text(3, 2.8, 'Encoder', ha='center', fontsize=11, color='steelblue', fontweight='bold')
ax.text(2.5, -0.8, 'Decoder', ha='center', fontsize=11, color='orange', fontweight='bold')
ax.text(3, 0.95, 'Attention\nweights', ha='center', fontsize=9, color='green', fontstyle='italic')
ax.axis('off')

plt.tight_layout()
plt.show()

---

## 2. Attention Intuition

### The "Spotlight" Analogy

Imagine you're in a dark room full of objects. You have a flashlight (a **query**) and you're looking for something specific. As you sweep the beam around:

- Each object in the room is a potential **value** (the information you might want)
- Each object has a label describing it -- that's its **key**
- Your flashlight beam (query) is compared against every label (key)
- Objects whose labels **match** your query get brightly illuminated
- You combine the information from all visible objects, weighted by how brightly they're lit

This is exactly how attention works: **a learned, differentiable spotlight**.

### The Database Lookup Analogy

An even more precise analogy: attention is like a **soft database lookup**.

| Database Concept | Attention Concept | Example | F1 Parallel |
|-----------------|-------------------|---------|-------------|
| Search query | **Query (Q)** | "What adjective describes the subject?" | "What was the tire state at the last pit stop?" |
| Index/key in database | **Key (K)** | Each word's "matchability" representation | Each lap's "type label" (pit lap, push lap, safety car lap) |
| Stored record | **Value (V)** | Each word's content representation | Each lap's actual data (time, tire temp, fuel, gap) |
| Exact match | Hard attention | Only look at one item | Look at exactly one lap |
| Fuzzy/weighted match | **Soft attention** | Look at everything, weighted by relevance | Blend info from multiple laps, weighted by relevance |

**What this means:** Instead of retrieving one exact match (like a SQL query), attention retrieves a **weighted combination** of all values, where the weights depend on how well each key matches the query.

**F1 analogy:** Think of the query as the strategist asking: "Which past laps are most relevant to predicting THIS lap's performance?" Each past lap has a key (its characteristics: was it a push lap? safety car? post-pit?) and a value (its actual telemetry data). The attention weights tell you: "lap 38 is 40% relevant, lap 1 is 25% relevant, lap 44 is 20% relevant, everything else shares the remaining 15%." The output is a weighted blend of the most informative laps.
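The hard-versus-soft distinction can be made precise with a temperature. In this sketch (the `soft_lookup` helper and its `temperature` argument are illustrative assumptions, not part of any library), match scores are divided by a temperature before the softmax. A large temperature spreads the weight over all values; a temperature near zero concentrates it on the best-matching key, so the soft lookup approaches a hard one while remaining differentiable.

```python
import numpy as np

def soft_lookup(query, keys, values, temperature=1.0):
    """Return a softmax-weighted blend of `values` (a sketch of soft attention).

    query: (d,), keys: (n, d), values: (n, d_v)
    """
    scores = keys @ query / temperature      # similarity of the query to each key
    weights = np.exp(scores - scores.max())  # numerically stable softmax numerator
    weights /= weights.sum()                 # normalize so the weights sum to 1
    return weights @ values, weights

keys = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
values = np.array([[10.0], [20.0], [30.0]])
query = np.array([1.0, 0.0])  # key 0 is the best match

for t in [10.0, 1.0, 0.01]:
    out, w = soft_lookup(query, keys, values, temperature=t)
    print(f"T={t:>5}: weights={np.round(w, 3)}, output={out[0]:.2f}")
```

At `T=0.01` nearly all weight lands on key 0 (a hard lookup); at `T=10` the weights are close to uniform.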

### Visualization: Attention as Soft Lookup

In [None]:
# Visualize attention as a soft database lookup
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Hard lookup (traditional database)
ax = axes[0]
keys = ['cat', 'dog', 'fish', 'bird', 'cat']
values = [0.9, 0.1, 0.2, 0.3, 0.8]
query = 'cat'
matches = [1 if k == query else 0 for k in keys]

colors = ['green' if m else 'lightgray' for m in matches]
bars = ax.bar(range(len(keys)), matches, color=colors, edgecolor='black', linewidth=1.5)
ax.set_xticks(range(len(keys)))
ax.set_xticklabels([f'Key: {k}\nVal: {v}' for k, v in zip(keys, values)], fontsize=9)
ax.set_ylabel('Match Weight')
ax.set_title(f'Hard Lookup: Query = "{query}"\n(exact match only)', fontsize=12)
ax.set_ylim(0, 1.3)
ax.text(2, 1.15, 'Result: average of matching values', ha='center', fontsize=10, fontstyle='italic')
ax.grid(True, alpha=0.3)

# Right: Soft attention lookup
ax = axes[1]
# Simulated attention weights (soft matching)
attention_weights = np.array([0.35, 0.05, 0.08, 0.12, 0.40])
colors_soft = plt.cm.Greens(attention_weights / attention_weights.max())

bars = ax.bar(range(len(keys)), attention_weights, color=colors_soft, edgecolor='black', linewidth=1.5)
ax.set_xticks(range(len(keys)))
ax.set_xticklabels([f'Key: {k}\nVal: {v}' for k, v in zip(keys, values)], fontsize=9)
ax.set_ylabel('Attention Weight')
ax.set_title(f'Soft Attention: Query = "{query}"\n(weighted combination)', fontsize=12)
ax.set_ylim(0, 0.6)

weighted_result = sum(w * v for w, v in zip(attention_weights, values))
ax.text(2, 0.52, f'Result: weighted sum = {weighted_result:.3f}', ha='center', fontsize=10, fontstyle='italic')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Hard lookup: Returns exact matches only (not differentiable)")
print("Soft attention: Returns weighted combination of ALL values (differentiable!)")
print("This differentiability is what makes attention trainable with backpropagation.")

### Example: Translation Attention

Consider translating "The black cat sat on the mat" into French: "Le chat noir était assis sur le tapis." When generating each French word, the model should attend to different source words:

- Generating "Le" (The) --> attend to "The"
- Generating "chat" (cat) --> attend to "cat"  
- Generating "noir" (black) --> attend to "black"
- Generating "assis" (sat) --> attend to "sat"

Notice how **word order changes** between languages! The adjective "black" comes before "cat" in English but "noir" comes after "chat" in French. Attention handles this naturally by letting the decoder look at any position.

**F1 analogy:** This is like translating one team's raw telemetry into another team's format. The "channels" may be in a different order, some may be combined differently, and the sampling rates may differ. Attention lets the translation model look at any source channel at any time, regardless of order mismatches -- just as a strategist can pull any piece of data from the lap chart at any moment.

In [None]:
# Visualize translation attention pattern
source = ['The', 'black', 'cat', 'sat', 'on', 'the', 'mat']
target = ['Le', 'chat', 'noir', 'etait', 'assis', 'sur', 'le', 'tapis']

# Simulated attention weights (which source words each target word attends to)
attention_matrix = np.array([
    [0.85, 0.02, 0.03, 0.02, 0.02, 0.04, 0.02],  # Le -> The
    [0.02, 0.05, 0.82, 0.03, 0.02, 0.03, 0.03],  # chat -> cat
    [0.02, 0.80, 0.10, 0.02, 0.02, 0.02, 0.02],  # noir -> black
    [0.02, 0.02, 0.05, 0.78, 0.05, 0.03, 0.05],  # etait -> sat
    [0.02, 0.02, 0.05, 0.75, 0.08, 0.03, 0.05],  # assis -> sat
    [0.02, 0.02, 0.02, 0.03, 0.78, 0.05, 0.08],  # sur -> on
    [0.03, 0.02, 0.02, 0.02, 0.03, 0.80, 0.08],  # le -> the
    [0.02, 0.02, 0.02, 0.03, 0.05, 0.06, 0.80],  # tapis -> mat
])

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(attention_matrix, cmap='Blues', aspect='auto', vmin=0, vmax=1)

ax.set_xticks(range(len(source)))
ax.set_xticklabels(source, fontsize=12, fontweight='bold')
ax.set_yticks(range(len(target)))
ax.set_yticklabels(target, fontsize=12, fontweight='bold')
ax.set_xlabel('Source (English)', fontsize=13)
ax.set_ylabel('Target (French)', fontsize=13)
ax.set_title('Attention Weights in Translation\nBrighter = More Attention', fontsize=14)

# Add text annotations
for i in range(len(target)):
    for j in range(len(source)):
        val = attention_matrix[i, j]
        color = 'white' if val > 0.5 else 'black'
        ax.text(j, i, f'{val:.2f}', ha='center', va='center', fontsize=8, color=color)

plt.colorbar(im, label='Attention Weight')
plt.tight_layout()
plt.show()

print('Notice: "noir" (3rd row) attends to "black" (2nd column), NOT "cat" (3rd column)')
print('Attention naturally handles word reordering between languages!')

---

## 3. Dot-Product Attention

### Intuitive Explanation

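Now let's make the "soft lookup" concrete. The simplest form of attention uses the **dot product** to measure how well a query matches each key. Recall from linear algebra that $q \cdot k = \|q\|\,\|k\|\cos\theta$: it is large and positive when the two vectors point in similar directions, near zero when they are orthogonal, and negative when they point in opposite directions. Unlike cosine similarity, it also grows with the vectors' magnitudes.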

The three-step recipe:

1. **Score:** Compute similarity between query and each key using dot product
2. **Normalize:** Pass scores through softmax to get weights that sum to 1
3. **Aggregate:** Compute weighted sum of values using those weights

$$\text{Attention}(q, K, V) = \text{softmax}(q \cdot K^T) \cdot V$$

#### Breaking down the formula:

| Component | Shape | Meaning | F1 Parallel |
|-----------|-------|---------|-------------|
| $q$ | $(d_k,)$ | The query vector -- "what am I looking for?" | "I need laps with high tire degradation" |
| $K$ | $(n, d_k)$ | Matrix of key vectors -- one per position | Each lap's characteristic signature |
| $V$ | $(n, d_v)$ | Matrix of value vectors -- the information to retrieve | Each lap's actual telemetry data |
| $q \cdot K^T$ | $(n,)$ | Similarity scores (one per key) | How relevant each past lap is to the current query |
| $\text{softmax}(\cdot)$ | $(n,)$ | Attention weights (sum to 1) | Normalized relevance: "60% from lap 38, 25% from lap 1..." |
| Output | $(d_v,)$ | Weighted combination of values | Blended telemetry from the most relevant laps |

**What this means:** The dot product $q \cdot k_i$ is large when the query and key point in the same direction (similar), and small (or negative) when they point in different directions. Softmax converts these raw scores into a probability distribution.
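Two properties of softmax are worth checking numerically. First, adding the same constant to every score leaves the weights unchanged, because only differences between scores matter. Second, that invariance is what justifies subtracting `scores.max()` before exponentiating, as the step-by-step cell below does: a naive softmax overflows once scores get large. This sketch uses the scores `[1.0, 0.5, 1.1, 0.0]` that the worked example produces; `naive_softmax` and `stable_softmax` are illustrative names.

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)                # overflows to inf for large scores
    return e / e.sum()

def stable_softmax(x):
    e = np.exp(x - x.max())      # largest exponent is exp(0) = 1, so no overflow
    return e / e.sum()

scores = np.array([1.0, 0.5, 1.1, 0.0])

# Shift invariance: adding a constant to every score does not change the weights
print(np.allclose(stable_softmax(scores), stable_softmax(scores + 100.0)))  # True

# Large scores: the naive version overflows, the stable version does not
big = scores * 1000
with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(big))    # contains nan, because inf / inf is undefined
print(stable_softmax(big))       # finite, with nearly all weight on index 2
```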

### Step-by-Step Calculation

In [None]:
# Step-by-step dot-product attention with a small example
print("=" * 60)
print("DOT-PRODUCT ATTENTION: Step by Step")
print("=" * 60)

# Small example: d_k = 3, n = 4 keys/values
d_k = 3

# Query: "What am I looking for?"
query = np.array([1.0, 0.5, 0.0])
print(f"\nQuery vector: {query}")

# Keys: each position has a key vector
keys = np.array([
    [1.0, 0.0, 0.0],   # Key 0: mostly dimension 0
    [0.0, 1.0, 0.0],   # Key 1: mostly dimension 1
    [0.8, 0.6, 0.0],   # Key 2: similar to query!
    [0.0, 0.0, 1.0],   # Key 3: mostly dimension 2
])
print(f"\nKeys (4 x {d_k}):")
for i, k in enumerate(keys):
    print(f"  Key {i}: {k}")

# Values: the actual information to retrieve
values = np.array([
    [10, 0],
    [0, 10],
    [5, 5],
    [0, 0],
])
print(f"\nValues (4 x 2):")
for i, v in enumerate(values):
    print(f"  Value {i}: {v}")

# STEP 1: Compute scores (dot products)
print(f"\n{'='*60}")
print("STEP 1: Compute scores = query . key_i")
print("=" * 60)
scores = keys @ query  # Same as np.dot(keys, query)
for i in range(len(keys)):
    dot_detail = " + ".join([f"{query[j]:.1f}*{keys[i,j]:.1f}" for j in range(d_k)])
    print(f"  score[{i}] = {dot_detail} = {scores[i]:.2f}")

# STEP 2: Apply softmax
print(f"\n{'='*60}")
print("STEP 2: Attention weights = softmax(scores)")
print("=" * 60)
exp_scores = np.exp(scores - scores.max())  # Numerically stable softmax
weights = exp_scores / exp_scores.sum()
for i in range(len(weights)):
    print(f"  weight[{i}] = {weights[i]:.4f}")
print(f"  Sum of weights: {weights.sum():.4f} (always sums to 1)")

# STEP 3: Weighted sum of values
print(f"\n{'='*60}")
print("STEP 3: Output = weighted sum of values")
print("=" * 60)
output = weights @ values
for i in range(len(weights)):
    print(f"  {weights[i]:.4f} * {values[i]} = {weights[i] * values[i]}")
print(f"  Output = {output}")
print("\nKey 2 was most similar to the query, so Value 2 [5, 5] gets the largest weight.")
print("Key 0 scored almost as high, so the output is also pulled toward Value 0 [10, 0].")

### Visualization: Attention Weights as Heatmap

In [None]:
# Visualize the attention computation we just did
fig, axes = plt.subplots(1, 4, figsize=(16, 4), 
                         gridspec_kw={'width_ratios': [1, 1, 0.5, 1]})

# Plot 1: Query vs Keys similarity
ax = axes[0]
similarity = keys @ query
colors = plt.cm.Greens(similarity / similarity.max())
ax.barh(range(len(keys)), similarity, color=colors, edgecolor='black')
ax.set_yticks(range(len(keys)))
ax.set_yticklabels([f'Key {i}' for i in range(len(keys))])
ax.set_xlabel('Dot Product Score')
ax.set_title('Step 1: Scores\n(query . key)')
ax.invert_yaxis()
ax.grid(True, alpha=0.3)

# Plot 2: Attention weights (after softmax)
ax = axes[1]
colors_w = plt.cm.Greens(weights / weights.max())
ax.barh(range(len(weights)), weights, color=colors_w, edgecolor='black')
ax.set_yticks(range(len(weights)))
ax.set_yticklabels([f'w[{i}]={weights[i]:.3f}' for i in range(len(weights))])
ax.set_xlabel('Attention Weight')
ax.set_title('Step 2: Softmax\n(normalize scores)')
ax.invert_yaxis()
ax.grid(True, alpha=0.3)

# Plot 3: Arrow
ax = axes[2]
ax.annotate('', xy=(0.8, 0.5), xytext=(0.2, 0.5),
           arrowprops=dict(arrowstyle='->', lw=3, color='black'))
ax.text(0.5, 0.65, 'Weighted\nSum', ha='center', va='center', fontsize=11, fontweight='bold')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.axis('off')

# Plot 4: Output
ax = axes[3]
ax.bar(['dim 0', 'dim 1'], output, color=['steelblue', 'orange'], edgecolor='black', linewidth=1.5)
ax.set_ylabel('Value')
ax.set_title(f'Step 3: Output\n= [{output[0]:.2f}, {output[1]:.2f}]')
ax.grid(True, alpha=0.3)

plt.suptitle('Dot-Product Attention: Complete Pipeline', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

---

## 4. Scaled Dot-Product Attention

### Intuitive Explanation

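There's a subtle problem with plain dot-product attention: as the dimension $d_k$ grows, dot products get **larger in magnitude**. A dot product is a sum of $d_k$ terms $q_i k_i$. Those terms are as likely to be positive as negative, but their sum still spreads out as more of them are added: for random vectors with unit-variance components, the standard deviation of $q \cdot k$ grows like $\sqrt{d_k}$.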

When dot products are large, softmax produces outputs that are very close to one-hot vectors (nearly all the weight on one key). This means:
- Gradients become tiny (softmax saturation)
- The model can't learn to spread attention across multiple keys
- Training becomes unstable

**The fix:** Divide by $\sqrt{d_k}$ to keep the variance of scores roughly constant regardless of dimension.
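Why $\sqrt{d_k}$ and not some other constant? Here is a short calculation under the standard simplifying assumption that the components of $q$ and $k$ are independent, with mean 0 and variance 1:

$$\mathbb{E}[q \cdot k] = \sum_{i=1}^{d_k} \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0, \qquad \operatorname{Var}(q \cdot k) = \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d_k$$

$$\operatorname{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{d_k} = 1$$

So the standard deviation of the raw scores is $\sqrt{d_k}$ (the red dashed curve in the plot below), and dividing by $\sqrt{d_k}$ restores unit variance at any dimension. Learned queries and keys are not exactly independent unit-variance vectors, so treat this as the motivation for the choice rather than an exact guarantee.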

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

#### Breaking down the formula:

| Component | Meaning | Why It's There | F1 Parallel |
|-----------|---------|----------------|-------------|
| $Q$ | Query matrix $(n_q, d_k)$ | What we're looking for | Current lap's questions about the race history |
| $K$ | Key matrix $(n_k, d_k)$ | What's available to match | Each past lap's matchable signature |
| $V$ | Value matrix $(n_k, d_v)$ | Information to retrieve | Each past lap's actual data |
| $QK^T$ | Score matrix $(n_q, n_k)$ | Raw similarities | Raw relevance of every past lap to every query |
| $\sqrt{d_k}$ | Scaling factor | Keeps scores in good range for softmax | Prevents the model from being overconfident about one lap |
| softmax | Row-wise normalization | Convert scores to weights | Turn raw scores into a proper "attention budget" |

**What this means:** The $\sqrt{d_k}$ scaling is a simple but critical trick. Without it, attention in high dimensions tends to collapse toward hard attention (nearly all of the weight on a single position), losing the benefit of soft weighting.

**F1 analogy:** Without scaling, the model becomes an overconfident strategist who only looks at the single most similar past lap and ignores everything else. With scaling, the model blends information from several relevant laps -- which is almost always a better strategy.
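The "gradients become tiny" claim can be checked directly. The gradient flowing back through softmax is governed by its Jacobian, $J = \operatorname{diag}(p) - pp^T$, which shrinks toward zero as $p$ approaches a one-hot vector. This sketch assumes random unit-variance queries and keys (as in the plot below) and averages the Jacobian's Frobenius norm over 1,000 random trials at $d_k = 512$, with and without scaling. The unscaled scores should give a much smaller average norm, meaning a much weaker gradient signal reaches the scores. `mean_softmax_jacobian_norm` is an illustrative helper name.

```python
import torch

torch.manual_seed(0)
trials, n_keys, d_k = 1000, 8, 512

q = torch.randn(trials, 1, d_k)                     # one query per trial
K = torch.randn(trials, n_keys, d_k)                # 8 keys per trial
raw = (q @ K.transpose(-2, -1)).squeeze(1)          # (trials, n_keys) unscaled scores

def mean_softmax_jacobian_norm(scores):
    """Average Frobenius norm of d softmax(scores) / d scores over all trials."""
    p = torch.softmax(scores, dim=-1)                             # (trials, n)
    # Softmax Jacobian J = diag(p) - p p^T, computed for every trial at once
    J = torch.diag_embed(p) - p.unsqueeze(-1) * p.unsqueeze(-2)   # (trials, n, n)
    return J.flatten(1).norm(dim=-1).mean().item()

unscaled = mean_softmax_jacobian_norm(raw)
scaled = mean_softmax_jacobian_norm(raw / d_k ** 0.5)
print(f"Mean Jacobian norm, unscaled: {unscaled:.4f}")
print(f"Mean Jacobian norm, scaled:   {scaled:.4f}")
```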

### Visualization: Why Scaling Matters

In [None]:
# Show how dot product magnitude grows with dimension
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Left: Dot product magnitude vs dimension
ax = axes[0]
dims = [4, 16, 64, 128, 256, 512]
dot_stds = []

for d in dims:
    # Random unit-variance vectors
    q = np.random.randn(1000, d)
    k = np.random.randn(1000, d)
    dots = np.sum(q * k, axis=1)
    dot_stds.append(dots.std())

ax.plot(dims, dot_stds, 'bo-', linewidth=2, markersize=8, label='Empirical std')
ax.plot(dims, [np.sqrt(d) for d in dims], 'r--', linewidth=2, label=r'$\sqrt{d_k}$')
ax.set_xlabel('Dimension $d_k$')
ax.set_ylabel('Std of dot products')
ax.set_title('Dot Product Magnitude\nGrows with Dimension')
ax.legend()
ax.grid(True, alpha=0.3)

# Middle: Softmax on unscaled scores
ax = axes[1]
np.random.seed(42)
for d_k in [4, 64, 512]:
    q = np.random.randn(d_k)
    K = np.random.randn(8, d_k)
    scores = K @ q  # Unscaled
    weights = np.exp(scores - scores.max()) / np.exp(scores - scores.max()).sum()
    ax.plot(range(8), sorted(weights, reverse=True), 'o-', 
            label=f'd_k={d_k}', linewidth=2, markersize=6)

ax.set_xlabel('Key index (sorted)')
ax.set_ylabel('Attention weight')
ax.set_title('WITHOUT Scaling\n(higher dim = more peaked)')
ax.legend()
ax.grid(True, alpha=0.3)

# Right: Softmax on scaled scores
ax = axes[2]
np.random.seed(42)
for d_k in [4, 64, 512]:
    q = np.random.randn(d_k)
    K = np.random.randn(8, d_k)
    scores = (K @ q) / np.sqrt(d_k)  # Scaled!
    weights = np.exp(scores - scores.max()) / np.exp(scores - scores.max()).sum()
    ax.plot(range(8), sorted(weights, reverse=True), 'o-', 
            label=f'd_k={d_k}', linewidth=2, markersize=6)

ax.set_xlabel('Key index (sorted)')
ax.set_ylabel('Attention weight')
ax.set_title('WITH Scaling (/$\\sqrt{d_k}$)\n(consistent regardless of dim)')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Without scaling: high dimensions -> almost one-hot attention (bad for learning)")
print("With scaling: attention distribution is consistent across dimensions (good!)")

### Implement Scaled Dot-Product Attention from Scratch

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention.
    
    Args:
        Q: Query tensor (..., seq_len_q, d_k)
        K: Key tensor (..., seq_len_k, d_k)
        V: Value tensor (..., seq_len_k, d_v)
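        mask: Optional mask tensor, broadcastable to (..., seq_len_q, seq_len_k).
              Positions where mask == 0 (or False) are blocked. A fully masked
              row produces NaN weights, so keep at least one allowed key per row.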
    
    Returns:
        output: Weighted sum of values (..., seq_len_q, d_v)
        weights: Attention weights (..., seq_len_q, seq_len_k)
    """
    d_k = Q.shape[-1]
    
    # Step 1: Compute scaled scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
    
    # Step 2: Apply mask (if provided) - set masked positions to -inf
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    
    # Step 3: Softmax to get weights
    weights = F.softmax(scores, dim=-1)
    
    # Step 4: Weighted sum of values
    output = torch.matmul(weights, V)
    
    return output, weights

# Test it!
torch.manual_seed(42)

# Batch of 1, sequence length 5, dimension 4
seq_len, d_k, d_v = 5, 4, 4
Q = torch.randn(1, seq_len, d_k)
K = torch.randn(1, seq_len, d_k)
V = torch.randn(1, seq_len, d_v)

output, weights = scaled_dot_product_attention(Q, K, V)

print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
print(f"Output shape: {output.shape}")
print(f"Weights shape: {weights.shape}")
print(f"\nAttention weights (each row sums to 1):")
print(weights[0].detach().numpy().round(3))
print(f"Row sums: {weights[0].sum(dim=-1).detach().numpy().round(4)}")

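As a sanity check, the from-scratch function can be compared against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later). This sketch also exercises the mask convention the function assumes, namely that entries equal to 0 (or `False`) are blocked. A causal mask built with `torch.tril` lets position $i$ attend only to positions $0, \dots, i$. For a boolean `attn_mask`, the built-in also treats `True` as "allowed", so both versions should agree. The function is repeated here so the block runs on its own.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    # Same logic as the from-scratch cell above, repeated so this block runs alone
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))  # block where mask is 0/False
    weights = F.softmax(scores, dim=-1)
    return weights @ V, weights

torch.manual_seed(0)
seq_len, d_k = 5, 4
Q, K, V = (torch.randn(1, seq_len, d_k) for _ in range(3))

# Causal mask: position i may attend only to positions 0..i (True = allowed)
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

ours, weights = scaled_dot_product_attention(Q, K, V, mask=causal)
builtin = F.scaled_dot_product_attention(Q, K, V, attn_mask=causal)  # PyTorch >= 2.0

print("Matches built-in:", torch.allclose(ours, builtin, atol=1e-5))
print(weights[0].round(decimals=3))  # entries above the diagonal are exactly 0
```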
---

## 5. Self-Attention

### Intuitive Explanation

In the encoder-decoder setting that motivated attention, queries came from one sequence (the decoder) and keys/values from another (the encoder). **Self-attention** is the special case where queries, keys, and values all come from the **same sequence**.

**Key insight:** Self-attention lets each position in a sequence "look at" every other position in the same sequence. This is how a model understands context:

- In "The animal didn't cross the street because **it** was too tired" -- the word "it" needs to attend to "animal" to understand the sentence
- In "The animal didn't cross the street because **it** was too wide" -- now "it" should attend to "street"

Self-attention allows the model to figure out these relationships by learning what to attend to.

**F1 analogy:** Self-attention is each lap attending to every other lap in the same race. Lap 45 can directly "look at" lap 1, lap 20, and lap 44 to understand its own context. Did a safety car on lap 20 bunch up the field? Did a pit stop on lap 38 change the tire compound? Did a rain shower on lap 1 affect track evolution? Each lap gathers context from whichever other laps are most relevant to understanding itself.

### How Self-Attention Works

Given an input sequence $X$ of shape $(n, d_{model})$:

1. **Project** $X$ into three different spaces using learned weight matrices:
   - $Q = XW^Q$ (queries)
   - $K = XW^K$ (keys) 
   - $V = XW^V$ (values)

2. **Apply** scaled dot-product attention: $\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$

| Step | What Happens | Analogy | F1 Parallel |
|------|-------------|---------|-------------|
| Compute Q | Each position asks "what am I looking for?" | Formulating a question | Lap 45 asks: "which past laps had similar tire conditions?" |
| Compute K | Each position broadcasts "here's what I offer" | Creating a matchable label | Each lap broadcasts: "I'm a push lap on hard tires, lap 12 of stint" |
| Compute V | Each position says "here's my content" | Storing information | Each lap stores its full telemetry vector |
| Q . K^T | Every position scores against every other | Checking all question-label pairs | Lap 45 scores every other lap's relevance |
| Softmax | Normalize scores to weights | Decide how much to focus on each | Allocate attention budget across the race |
| Weights . V | Gather relevant information | Retrieve and combine answers | Blend telemetry from the most relevant laps |

**What this means:** Because $W^Q$, $W^K$, and $W^V$ are separate learned matrices, the **same word** can play three different roles. "Cat" as a query asks "who modifies me?"; as a key it answers "I'm a noun, an animal"; as a value it provides its embedding content.
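One concrete payoff of separate query and key projections: if both used the same matrix, the raw score from word $i$ to word $j$ would always equal the raw score from $j$ to $i$, so "sat" looking at "cat" would be forced to score exactly like "cat" looking at "sat". This NumPy sketch uses random stand-in embeddings and weights (assumptions for illustration, not trained values) to show that separate $W^Q$ and $W^K$ break that symmetry.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_model, d_k = 6, 8, 8             # e.g. 6 words: "The cat sat on the mat"
X = rng.normal(size=(n, d_model))     # one random stand-in embedding per word
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))

# Separate projections: score[i, j] = (x_i W_q) . (x_j W_k)
scores = (X @ W_q) @ (X @ W_k).T / np.sqrt(d_k)
print("Separate W_q, W_k -> symmetric scores?", np.allclose(scores, scores.T))  # False

# Shared projection: score[i, j] = (x_i W) . (x_j W), which always equals score[j, i]
shared = (X @ W_q) @ (X @ W_q).T / np.sqrt(d_k)
print("Shared W          -> symmetric scores?", np.allclose(shared, shared.T))  # True
```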

### Visualization: Self-Attention in a Sentence

In [None]:
# Visualize self-attention: which words attend to which in a sentence
words = ['The', 'cat', 'sat', 'on', 'the', 'mat']
n = len(words)

# Simulated self-attention weights (what each word attends to)
# Designed to show realistic patterns
self_attn = np.array([
    [0.30, 0.40, 0.05, 0.05, 0.10, 0.10],  # "The" -> attends to "cat" (its noun)
    [0.15, 0.20, 0.35, 0.05, 0.05, 0.20],  # "cat" -> attends to "sat" (its verb) and "mat"
    [0.05, 0.45, 0.15, 0.10, 0.05, 0.20],  # "sat" -> attends to "cat" (subject)
    [0.05, 0.05, 0.30, 0.10, 0.10, 0.40],  # "on" -> attends to "sat" and "mat"
    [0.05, 0.05, 0.05, 0.05, 0.30, 0.50],  # "the" -> attends to "mat" (its noun)
    [0.05, 0.15, 0.20, 0.25, 0.10, 0.25],  # "mat" -> attends to "sat", "on", itself
])

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Heatmap
ax = axes[0]
im = ax.imshow(self_attn, cmap='Blues', aspect='auto', vmin=0, vmax=0.5)
ax.set_xticks(range(n))
ax.set_xticklabels(words, fontsize=12, fontweight='bold')
ax.set_yticks(range(n))
ax.set_yticklabels(words, fontsize=12, fontweight='bold')
ax.set_xlabel('Attending TO (keys)', fontsize=12)
ax.set_ylabel('Attending FROM (queries)', fontsize=12)
ax.set_title('Self-Attention Weights', fontsize=13)
for i in range(n):
    for j in range(n):
        val = self_attn[i, j]
        color = 'white' if val > 0.3 else 'black'
        ax.text(j, i, f'{val:.2f}', ha='center', va='center', fontsize=9, color=color)
plt.colorbar(im, ax=ax)

# Right: Attention as arrows for a specific word
ax = axes[1]
focus_word_idx = 2  # "sat"
focus_weights = self_attn[focus_word_idx]

# Position words horizontally
x_positions = np.linspace(0, 5, n)
y_base = 0.5

for i, (word, pos) in enumerate(zip(words, x_positions)):
    fontsize = 14 + focus_weights[i] * 20  # Bigger = more attention
    alpha = 0.3 + focus_weights[i] * 1.5
    color = 'green' if i == focus_word_idx else 'steelblue'
    ax.text(pos, y_base, word, fontsize=fontsize, ha='center', va='center',
            fontweight='bold', alpha=min(alpha, 1.0), color=color,
            bbox=dict(boxstyle='round,pad=0.3', facecolor='lightyellow' if i == focus_word_idx else 'lightblue',
                      alpha=min(alpha, 0.8), edgecolor='gray'))
    
    # Draw attention arrows from focus word
    if i != focus_word_idx:
        ax.annotate('', xy=(pos, y_base + 0.15), xytext=(x_positions[focus_word_idx], y_base + 0.15),
                   arrowprops=dict(arrowstyle='->', color='red', 
                                  lw=focus_weights[i] * 8, alpha=min(focus_weights[i] * 2, 1.0)))  # matplotlib requires alpha <= 1
        ax.text(pos, y_base + 0.3, f'{focus_weights[i]:.2f}', 
                ha='center', fontsize=9, color='red', alpha=min(max(focus_weights[i] * 2, 0.3), 1.0))

ax.set_xlim(-0.5, 5.5)
ax.set_ylim(-0.2, 1.0)
ax.set_title(f'What does "{words[focus_word_idx]}" attend to?', fontsize=13)
ax.axis('off')

plt.tight_layout()
plt.show()

print(f'The word "{words[focus_word_idx]}" attends most to "{words[np.argmax(focus_weights)]}"')
print("These weights are hand-picked for illustration; a trained model learns patterns like this from data.")

### Implement Self-Attention from Scratch

In [None]:
class SelfAttention(nn.Module):
    """
    Self-attention layer.
    
    Takes a sequence and lets each position attend to all positions
    in the same sequence using learned Q, K, V projections.
    """
    def __init__(self, d_model, d_k=None, d_v=None):
        """
        Args:
            d_model: Input/output dimension
            d_k: Key/query dimension (defaults to d_model)
            d_v: Value dimension (defaults to d_model)
        """
        super().__init__()
        d_k = d_k or d_model
        d_v = d_v or d_model
        
        self.d_k = d_k
        
        # Learned projection matrices
        self.W_q = nn.Linear(d_model, d_k, bias=False)  # Project to queries
        self.W_k = nn.Linear(d_model, d_k, bias=False)  # Project to keys
        self.W_v = nn.Linear(d_model, d_v, bias=False)  # Project to values
    
    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor (batch, seq_len, d_model)
            mask: Optional attention mask
            
        Returns:
            output: (batch, seq_len, d_v)
            weights: (batch, seq_len, seq_len)
        """
        # Step 1: Project input to Q, K, V
        Q = self.W_q(x)  # (batch, seq_len, d_k)
        K = self.W_k(x)  # (batch, seq_len, d_k)
        V = self.W_v(x)  # (batch, seq_len, d_v)
        
        # Step 2: Compute scaled dot-product attention
        output, weights = scaled_dot_product_attention(Q, K, V, mask)
        
        return output, weights

# Test it!
torch.manual_seed(42)

batch_size, seq_len, d_model = 2, 6, 8
x = torch.randn(batch_size, seq_len, d_model)

self_attn_layer = SelfAttention(d_model=8, d_k=8, d_v=8)
output, weights = self_attn_layer(x)

print(f"Input shape:  {x.shape}  (batch=2, seq=6, d_model=8)")
print(f"Output shape: {output.shape}  (same as input!)")
print(f"Weight shape: {weights.shape}  (batch=2, 6x6 attention matrix)")
print(f"\nAttention weights for first sample (each row sums to 1):")
print(weights[0].detach().numpy().round(3))

# Count parameters
n_params = sum(p.numel() for p in self_attn_layer.parameters())
print(f"\nTotal parameters: {n_params} (3 projection matrices of {d_model}x{d_model} = 3 x {d_model * d_model})")

### Deep Dive: Why Self-Attention Captures Long-Range Dependencies

Self-attention has a key advantage over RNNs: **any position can directly attend to any other position in one step**. In an RNN, information from position 1 must travel through every intermediate position to reach position 100. At each step, it gets processed and potentially distorted.

| Property | RNN | Self-Attention |
|----------|-----|----------------|
| Max path length | O(n) | **O(1)** |
| Computation per layer (d = model dim) | O(n d^2) | O(n^2 d) |
| Parallelizable | No (sequential) | **Yes** |
| Long-range dependencies | Difficult | **Easy** |

**F1 analogy:** With an RNN, information about lap 1 has to pass through 49 recurrent steps to influence the prediction for lap 50 -- degrading at each step (the vanishing gradient problem from the previous notebook). With self-attention, lap 50 can directly "look at" lap 1 in a single step. It is as if the strategist has the full lap chart open and can jump to any lap instantly, rather than having to mentally replay the race forward lap by lap.

#### Key Insight

Self-attention trades sequential processing for parallel processing. The O(n^2) cost is a limitation for very long sequences, but the ability to directly connect distant positions makes it far better at capturing long-range patterns.
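To put the O(n^2) term in concrete numbers, here is a back-of-envelope sketch of the memory needed just to store one attention-weight matrix. It assumes float32 weights (4 bytes each), a single head, a single sequence, and that the full matrix is materialized; actual implementations and hardware may behave differently.

```python
# Back-of-envelope memory for ONE n x n attention-weight matrix.
# Assumptions (illustrative only): float32 entries, one head, one sequence,
# and the full matrix stored in memory at once.
BYTES_PER_FLOAT32 = 4

def attention_matrix_bytes(seq_len, bytes_per_entry=BYTES_PER_FLOAT32):
    """Bytes needed to store a seq_len x seq_len matrix of attention weights."""
    return seq_len * seq_len * bytes_per_entry

for n in [512, 2048, 8192, 32768]:
    mib = attention_matrix_bytes(n) / 2**20
    print(f"n = {n:>6}: {n * n:>13,} weights -> {mib:>8,.0f} MiB per head")
```

This ranges from 1 MiB at n = 512 to 4,096 MiB at n = 32,768: every 4x increase in sequence length costs 16x more memory for this single matrix.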

#### Common Misconceptions

| Misconception | Reality |
|---------------|---------|
| "Q, K, V are three different inputs" | In self-attention, they're three different **projections** of the same input (in cross-attention, Q comes from a different sequence than K and V) |
| "Self-attention knows word order" | It doesn't! You need positional encoding (covered in the Transformer notebook) |
| "Self-attention replaces RNNs everywhere" | For very long sequences, efficient attention variants are needed |
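The second misconception is easy to check numerically. This sketch uses a bare single-head self-attention with random, untrained projection matrices (`W_q`, `W_k`, `W_v` here are illustrative standalone tensors, not the notebook's classes) and no positional encoding. Shuffling the input rows only shuffles the output rows in the same way, so the layer on its own has no notion of order.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 5, 8

# Hypothetical random projections (float64 so the comparison is numerically tight)
W_q = torch.randn(d_model, d_model, dtype=torch.float64) / d_model ** 0.5
W_k = torch.randn(d_model, d_model, dtype=torch.float64) / d_model ** 0.5
W_v = torch.randn(d_model, d_model, dtype=torch.float64) / d_model ** 0.5

def plain_self_attention(x):
    """Single-head scaled dot-product self-attention with NO positional encoding."""
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    weights = F.softmax(Q @ K.T / d_model ** 0.5, dim=-1)
    return weights @ V

x = torch.randn(seq_len, d_model, dtype=torch.float64)
perm = torch.randperm(seq_len)  # shuffle the "word order"

out = plain_self_attention(x)
out_shuffled = plain_self_attention(x[perm])

# Shuffled input -> identically shuffled output
print(torch.allclose(out[perm], out_shuffled, atol=1e-8))  # True
```

Formally, self-attention without positional information is permutation-equivariant: for a permutation matrix $P$, $\text{Attn}(PX) = P\,\text{Attn}(X)$.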

---

## 6. Multi-Head Attention

### Intuitive Explanation

A single attention head can only focus on one type of relationship at a time. But language has many types of relationships happening simultaneously:

- **Syntactic:** "cat" relates to "sat" (subject-verb)
- **Semantic:** "cat" relates to "animal" (meaning)
- **Positional:** "the" relates to the next word (article-noun)
- **Coreference:** "it" relates to "cat" (pronoun reference)

**Multi-head attention** runs multiple attention operations in parallel, each with its own learned projections. Each "head" can learn to focus on a different type of relationship.

**F1 analogy:** Multi-head attention is like having multiple specialists on the pit wall, each attending to a different aspect of the race simultaneously. Head 1 tracks **pace** (which laps had similar lap times?). Head 2 tracks **tire wear** (which laps had similar degradation rates?). Head 3 tracks **fuel load** (which laps had similar fuel levels?). Head 4 tracks **track evolution** (which laps had similar grip levels?). Each head independently decides what is relevant from its own perspective, and the results are combined into a rich, multi-faceted understanding of the current race state.

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$

where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$

#### Breaking down the formula:

| Component | Shape | Meaning | F1 Parallel |
|-----------|-------|---------|-------------|
| $h$ | scalar | Number of heads (typically 8 or 16) | Number of specialist engineers on the pit wall |
| $W_i^Q, W_i^K$ | $(d_{model}, d_k)$ | Per-head query/key projections | Each specialist's "lens" for viewing the data |
| $W_i^V$ | $(d_{model}, d_v)$ | Per-head value projection | What information each specialist extracts |
| $d_k = d_v = d_{model}/h$ | scalar | Each head uses a fraction of the total dimension | Each specialist gets a share of the total bandwidth |
| $W^O$ | $(hd_v, d_{model})$ | Output projection to combine heads | Strategy meeting: combine all specialist inputs into one decision |

**What this means:** Instead of one big attention with $d_{model}$-dimensional Q/K/V, we run $h$ smaller attentions in parallel (each with $d_{model}/h$ dimensions), then concatenate and project. The parameter count is identical to single-head attention and the compute is roughly the same, but the model gains multiple "perspectives."
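The "one big attention, split into heads" view can be checked numerically. This sketch assumes random, untrained float64 matrices and leaves out the output projection $W^O$ (it is applied identically after concatenation either way). Path (a) projects once and reshapes into heads with `view`/`transpose`, as the from-scratch implementation later in this section does; path (b) runs $h$ separate small attentions, where head $i$ uses column block $i$ of each matrix as $W_i^Q, W_i^K, W_i^V$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model, n_heads = 6, 16, 4
d_k = d_model // n_heads

x = torch.randn(seq_len, d_model, dtype=torch.float64)
# Hypothetical "one big" projection matrices (random, untrained), scaled to keep scores moderate
W_q = torch.randn(d_model, d_model, dtype=torch.float64) / d_model ** 0.5
W_k = torch.randn(d_model, d_model, dtype=torch.float64) / d_model ** 0.5
W_v = torch.randn(d_model, d_model, dtype=torch.float64) / d_model ** 0.5

def attention(Q, K, V):
    """Scaled dot-product attention over the last two dims (batched over leading dims)."""
    scores = Q @ K.transpose(-2, -1) / Q.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ V

# (a) One big projection, then split d_model into heads: (n_heads, seq_len, d_k)
Q = (x @ W_q).view(seq_len, n_heads, d_k).transpose(0, 1)
K = (x @ W_k).view(seq_len, n_heads, d_k).transpose(0, 1)
V = (x @ W_v).view(seq_len, n_heads, d_k).transpose(0, 1)
heads_batched = attention(Q, K, V)                                  # (n_heads, seq_len, d_k)
concat_a = heads_batched.transpose(0, 1).reshape(seq_len, d_model)  # Concat(head_1, ..., head_h)

# (b) h separate small heads, head i using column block i of each matrix
heads = []
for i in range(n_heads):
    cols = slice(i * d_k, (i + 1) * d_k)
    heads.append(attention(x @ W_q[:, cols], x @ W_k[:, cols], x @ W_v[:, cols]))
concat_b = torch.cat(heads, dim=-1)                                 # (seq_len, d_model)

print(torch.allclose(concat_a, concat_b, atol=1e-8))  # True
```

Doing it as one large matmul plus a reshape is simply a more efficient way to compute the same $h$ heads.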

### Visualization: Different Heads Attend to Different Things

In [None]:
# Visualize how different attention heads learn different patterns
words = ['The', 'cat', 'sat', 'on', 'the', 'mat']
n = len(words)

# Simulated attention patterns for 4 different heads
heads = {
    'Head 1: Positional\n(next word)': np.array([
        [0.1, 0.7, 0.05, 0.05, 0.05, 0.05],
        [0.05, 0.1, 0.7, 0.05, 0.05, 0.05],
        [0.05, 0.05, 0.1, 0.7, 0.05, 0.05],
        [0.05, 0.05, 0.05, 0.1, 0.7, 0.05],
        [0.05, 0.05, 0.05, 0.05, 0.1, 0.7],
        [0.05, 0.05, 0.05, 0.05, 0.1, 0.7],
    ]),
    'Head 2: Syntactic\n(subject-verb)': np.array([
        [0.3, 0.5, 0.05, 0.05, 0.05, 0.05],
        [0.1, 0.2, 0.5, 0.05, 0.05, 0.1],
        [0.05, 0.6, 0.15, 0.05, 0.05, 0.1],
        [0.05, 0.05, 0.5, 0.15, 0.05, 0.2],
        [0.05, 0.05, 0.05, 0.05, 0.3, 0.5],
        [0.05, 0.1, 0.4, 0.2, 0.05, 0.2],
    ]),
    'Head 3: Semantic\n(related nouns)': np.array([
        [0.3, 0.2, 0.1, 0.1, 0.15, 0.15],
        [0.05, 0.3, 0.1, 0.05, 0.05, 0.45],
        [0.1, 0.1, 0.3, 0.1, 0.2, 0.2],
        [0.1, 0.1, 0.1, 0.3, 0.2, 0.2],
        [0.15, 0.1, 0.1, 0.1, 0.3, 0.25],
        [0.05, 0.45, 0.1, 0.05, 0.05, 0.3],
    ]),
    'Head 4: Determiner\n(article-noun)': np.array([
        [0.2, 0.6, 0.05, 0.05, 0.05, 0.05],
        [0.3, 0.3, 0.1, 0.1, 0.1, 0.1],
        [0.1, 0.1, 0.3, 0.1, 0.2, 0.2],
        [0.1, 0.1, 0.1, 0.3, 0.1, 0.3],
        [0.05, 0.05, 0.05, 0.05, 0.2, 0.6],
        [0.1, 0.1, 0.1, 0.1, 0.3, 0.3],
    ]),
}

fig, axes = plt.subplots(1, 4, figsize=(18, 4))

for ax, (title, attn) in zip(axes, heads.items()):
    im = ax.imshow(attn, cmap='Purples', aspect='auto', vmin=0, vmax=0.7)
    ax.set_xticks(range(n))
    ax.set_xticklabels(words, fontsize=9, rotation=45)
    ax.set_yticks(range(n))
    ax.set_yticklabels(words, fontsize=9)
    ax.set_title(title, fontsize=10)

plt.suptitle('Multi-Head Attention: Each Head Learns Different Patterns', fontsize=14, y=1.05)
plt.tight_layout()
plt.show()

print("Each head captures a different type of relationship:")
print("  Head 1: Adjacent word patterns (local context)")
print("  Head 2: Subject-verb agreement (syntactic structure)")
print("  Head 3: Semantically related words (meaning)")
print("  Head 4: Determiner-noun pairs (grammatical role)")

### Implement Multi-Head Attention from Scratch

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention from scratch.
    
    Runs h parallel attention heads, each operating on d_model/h dimensions,
    then concatenates and projects the results.
    """
    def __init__(self, d_model, n_heads):
        """
        Args:
            d_model: Model dimension (must be divisible by n_heads)
            n_heads: Number of attention heads
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # Dimension per head
        
        # Single large projection matrices (more efficient than separate per-head)
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)  # Output projection
    
    def forward(self, query, key, value, mask=None):
        """
        Args:
            query: (batch, seq_len_q, d_model)
            key:   (batch, seq_len_k, d_model)
            value: (batch, seq_len_k, d_model)
            mask:  Optional mask
            
        Returns:
            output: (batch, seq_len_q, d_model)
            weights: (batch, n_heads, seq_len_q, seq_len_k)
        """
        batch_size = query.shape[0]
        
        # Step 1: Project Q, K, V
        Q = self.W_q(query)  # (batch, seq_q, d_model)
        K = self.W_k(key)    # (batch, seq_k, d_model)
        V = self.W_v(value)  # (batch, seq_k, d_model)
        
        # Step 2: Reshape to (batch, n_heads, seq_len, d_k)
        # This splits d_model into n_heads separate d_k-dimensional spaces
        Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # Step 3: Apply scaled dot-product attention (works on last 2 dims)
        output, weights = scaled_dot_product_attention(Q, K, V, mask)
        # output: (batch, n_heads, seq_q, d_k)
        # weights: (batch, n_heads, seq_q, seq_k)
        
        # Step 4: Concatenate heads
        # Transpose back and reshape: (batch, seq_q, n_heads * d_k) = (batch, seq_q, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        
        # Step 5: Final linear projection
        output = self.W_o(output)
        
        return output, weights

# Test it!
torch.manual_seed(42)

d_model, n_heads = 32, 4
mha = MultiHeadAttention(d_model=d_model, n_heads=n_heads)

batch_size, seq_len = 2, 8
x = torch.randn(batch_size, seq_len, d_model)

# Self-attention: Q=K=V=x
output, weights = mha(x, x, x)

print(f"Input shape:   {x.shape}")
print(f"Output shape:  {output.shape}  (same as input)")
print(f"Weights shape: {weights.shape}  (batch, heads, seq_q, seq_k)")
print(f"\nEach head has its own {seq_len}x{seq_len} attention matrix")
print(f"Head dimension: d_k = d_model/n_heads = {d_model}/{n_heads} = {d_model//n_heads}")
print(f"\nTotal parameters: {sum(p.numel() for p in mha.parameters())}")
print(f"  = 4 matrices of {d_model}x{d_model} = {4 * d_model * d_model}")

In [None]:
# Visualize the attention weights from each head
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

for head_idx in range(n_heads):
    ax = axes[head_idx]
    head_weights = weights[0, head_idx].detach().numpy()  # First sample
    im = ax.imshow(head_weights, cmap='Blues', aspect='auto', vmin=0, vmax=0.5)
    ax.set_xlabel('Key position')
    ax.set_ylabel('Query position')
    ax.set_title(f'Head {head_idx + 1}')

plt.suptitle('Multi-Head Attention: Randomly Initialized (Untrained) Heads', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

print("Even with random initialization, the heads already show different attention patterns.")
print("After training, these differences become much more pronounced.")

### Interactive Exploration: Number of Heads

**F1 analogy:** As you increase the number of heads, each specialist gets a narrower slice of the total bandwidth. With 1 head, one generalist sees everything in full resolution. With 64 heads (at d_model = 64), each specialist works with a single dimension -- hyper-specialized but potentially too narrow. The sweet spot depends on the task, just as the optimal pit wall staffing depends on the complexity of the race scenario.

In [None]:
# Explore: how does the number of heads affect the model?
d_model = 64
head_configs = [1, 2, 4, 8, 16, 64]

print(f"d_model = {d_model}")
print(f"{'Heads':>6} | {'d_k (per head)':>14} | {'Total Params':>12} | {'Same Total?':>11}")
print("-" * 55)

for n_h in head_configs:
    d_k = d_model // n_h
    # 4 matrices of d_model x d_model (Q, K, V, O projections)
    total_params = 4 * d_model * d_model
    print(f"{n_h:>6} | {d_k:>14} | {total_params:>12} | {'Yes':>11}")

print(f"\nKey insight: The number of heads does NOT change the parameter count!")
print(f"More heads = smaller per-head dimension, but roughly the same total computation.")
print(f"Typical choice: 8-16 heads for d_model=512-1024")

---

## 7. Attention Variants

### Overview

The scaled dot-product attention we've been studying is not the only type. Different attention mechanisms compute the query-key similarity score differently.

### 7.1 Additive (Bahdanau) Attention

The original attention mechanism (2014). Uses a small neural network to compute scores:

$$\text{score}(q, k) = v^T \tanh(W_1 q + W_2 k)$$

**Intuition:** Instead of assuming that similarity is measured by dot product, learn a small network that determines how well a query matches a key.

### 7.2 Multiplicative (Luong) Attention

A simpler approach using a learned weight matrix:

$$\text{score}(q, k) = q^T W k$$

**Intuition:** Like dot-product attention, but with a learned transformation in between. This lets the model learn which dimensions of the query should match which dimensions of the key.
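Two quick checks of this intuition, sketched with random float64 vectors (the helper `luong_general_scores` is a hypothetical name for illustration): with $W = I$ the score reduces to the plain dot product, and a $W$ whose only non-zero entry is $W_{01}$ makes query dimension 0 interact only with key dimension 1.

```python
import torch

torch.manual_seed(0)
d, seq_len = 4, 5
q = torch.randn(d, dtype=torch.float64)
keys = torch.randn(seq_len, d, dtype=torch.float64)

def luong_general_scores(q, keys, W):
    """score(q, k_j) = q^T W k_j, computed for every key row k_j at once."""
    return keys @ W.T @ q   # (seq_len,)

# (1) W = identity: the score collapses to the plain (unscaled) dot product q . k_j
W_identity = torch.eye(d, dtype=torch.float64)
print(torch.allclose(luong_general_scores(q, keys, W_identity), keys @ q))  # True

# (2) Hypothetical W linking query dim 0 to key dim 1 only: score = q_0 * k_{j,1}
W_cross = torch.zeros(d, d, dtype=torch.float64)
W_cross[0, 1] = 1.0
print(torch.allclose(luong_general_scores(q, keys, W_cross), q[0] * keys[:, 1]))  # True
```

A learned $W$ sits between these extremes, assigning a weight to every (query dimension, key dimension) pair.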

### 7.3 Dot-Product / Scaled Dot-Product

What we've already studied:

$$\text{score}(q, k) = \frac{q \cdot k}{\sqrt{d_k}}$$
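As a recap of why the $\sqrt{d_k}$ appears: if query and key components are independent with mean 0 and variance 1 (an assumption of this sketch), then $q \cdot k$ has variance $d_k$. The sketch samples random pairs and compares the spread of raw and scaled scores.

```python
import torch

torch.manual_seed(0)
n_pairs = 5000
for d_k in [4, 64, 256]:
    q = torch.randn(n_pairs, d_k)   # components ~ N(0, 1), by assumption
    k = torch.randn(n_pairs, d_k)
    raw = (q * k).sum(dim=-1)       # unscaled dot products, one per pair
    scaled = raw / d_k ** 0.5       # scaled dot products
    print(f"d_k={d_k:>4}: std(raw) = {raw.std().item():6.2f}, "
          f"std(scaled) = {scaled.std().item():.2f}")
```

The raw standard deviation grows roughly like $\sqrt{d_k}$ (about 2, 8, 16 here), while the scaled scores stay near 1, which keeps softmax out of its saturated, near-one-hot regime.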

### Comparison Table

| Attention Type | Score Function | Pros | Cons | Used In | F1 Parallel |
|----------------|---------------|------|------|---------|-------------|
| **Additive (Bahdanau)** | $v^T \tanh(W_1 q + W_2 k)$ | Flexible, works well for different Q/K dims | Slower (MLP forward pass) | Original seq2seq attention | A neural network learns complex lap-to-lap relevance |
| **Multiplicative (Luong)** | $q^T W k$ | Learns cross-dim interactions | Extra parameters | Seq2seq variants | Learns which telemetry channels should match which |
| **Dot-Product** | $q \cdot k$ | Fastest, no extra params | Assumes Q and K in same space | Simple models | Direct similarity between lap profiles |
| **Scaled Dot-Product** | $\frac{q \cdot k}{\sqrt{d_k}}$ | Fast + stable gradients | Assumes Q and K in same space | **Transformers** | Stable lap comparison at any telemetry resolution |

In [None]:
# Implement the additive and multiplicative scoring functions (scaled dot-product is computed inline below)
class AdditiveAttention(nn.Module):
    """Bahdanau (additive) attention."""
    def __init__(self, d_query, d_key, d_hidden):
        super().__init__()
        self.W1 = nn.Linear(d_query, d_hidden, bias=False)
        self.W2 = nn.Linear(d_key, d_hidden, bias=False)
        self.v = nn.Linear(d_hidden, 1, bias=False)
    
    def forward(self, query, keys, values):
        # query: (batch, 1, d_query), keys: (batch, seq, d_key)
        scores = self.v(torch.tanh(self.W1(query) + self.W2(keys)))  # (batch, seq, 1)
        scores = scores.squeeze(-1)  # (batch, seq)
        weights = F.softmax(scores, dim=-1)  # (batch, seq)
        output = torch.bmm(weights.unsqueeze(1), values)  # (batch, 1, d_v)
        return output.squeeze(1), weights

class MultiplicativeAttention(nn.Module):
    """Luong (multiplicative) attention."""
    def __init__(self, d_query, d_key):
        super().__init__()
        self.W = nn.Linear(d_key, d_query, bias=False)
    
    def forward(self, query, keys, values):
        # query: (batch, 1, d_query), keys: (batch, seq, d_key)
        transformed_keys = self.W(keys)  # (batch, seq, d_query)
        scores = torch.bmm(query, transformed_keys.transpose(1, 2))  # (batch, 1, seq)
        scores = scores.squeeze(1)  # (batch, seq)
        weights = F.softmax(scores, dim=-1)
        output = torch.bmm(weights.unsqueeze(1), values)
        return output.squeeze(1), weights

# Compare the three attention types
torch.manual_seed(42)
batch, seq_len, d = 1, 6, 8

query = torch.randn(batch, 1, d)
keys = torch.randn(batch, seq_len, d)
values = torch.randn(batch, seq_len, d)

additive = AdditiveAttention(d, d, d)
multiplicative = MultiplicativeAttention(d, d)

out_add, w_add = additive(query, keys, values)
out_mul, w_mul = multiplicative(query, keys, values)

# Scaled dot-product for comparison
scores_dot = torch.bmm(query, keys.transpose(1, 2)) / np.sqrt(d)
w_dot = F.softmax(scores_dot.squeeze(1), dim=-1)
out_dot = torch.bmm(w_dot.unsqueeze(1), values).squeeze(1)

# Visualize weight distributions
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
names = ['Additive (Bahdanau)', 'Multiplicative (Luong)', 'Scaled Dot-Product']
all_weights = [w_add[0].detach().numpy(), w_mul[0].detach().numpy(), w_dot[0].detach().numpy()]
colors = ['blue', 'green', 'red']

y_max = max(0.5, max(w.max() for w in all_weights) * 1.1)  # headroom so no bar is clipped
for ax, name, w, color in zip(axes, names, all_weights, colors):
    ax.bar(range(seq_len), w, color=color, edgecolor='black', alpha=0.7)
    ax.set_xlabel('Key Position')
    ax.set_ylabel('Attention Weight')
    ax.set_title(name)
    ax.set_ylim(0, y_max)
    ax.grid(True, alpha=0.3)

plt.suptitle('Attention Weight Distribution by Type\n(same Q, K, V -- different scoring)', fontsize=13)
plt.tight_layout()
plt.show()

### 7.4 Causal (Masked) Attention

In autoregressive models (like GPT), each position should only attend to **previous** positions -- it shouldn't be able to "peek" at future tokens it hasn't generated yet. We achieve this by masking out future positions with $-\infty$ before softmax.

**F1 analogy:** Causal masking is like real-time race prediction where you can only use data from laps that have already happened. When predicting lap 30, you cannot look ahead at laps 31-56 -- those have not occurred yet. This is the constraint that autoregressive models operate under, and it mirrors exactly how a strategist makes decisions during a live race.
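A minimal sketch of the masking step itself, using a random score matrix as a stand-in for $QK^T/\sqrt{d_k}$ (an assumption for illustration): `torch.tril` builds the pattern of allowed positions, `masked_fill` writes $-\infty$ into the future positions, and softmax turns those into exact zeros.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(seq_len, seq_len)   # stand-in for Q K^T / sqrt(d_k)

# True where query i may attend to key j, i.e. j <= i
allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
masked_scores = scores.masked_fill(~allowed, float('-inf'))
weights = F.softmax(masked_scores, dim=-1)

print(weights)
# Row i is non-zero only in columns 0..i, and every row still sums to 1.
```

Because the diagonal is always allowed, no row is entirely $-\infty$; a fully masked row would make softmax return NaN.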

### 7.5 Cross-Attention vs Self-Attention

| Type | Queries From | Keys/Values From | Use Case | F1 Parallel |
|------|-------------|-----------------|----------|-------------|
| **Self-Attention** | Same sequence | Same sequence | Understanding context within one sequence | Each lap attending to every other lap in the same race |
| **Cross-Attention** | Target sequence | Source sequence | Connecting encoder and decoder (translation) | Comparing your car's laps against a rival's data |
| **Causal Self-Attention** | Same sequence | Same sequence (past only) | Autoregressive generation (GPT) | Real-time strategy: only use data from completed laps |
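The difference between the first two rows comes down to where queries versus keys/values come from. This sketch uses PyTorch's built-in `nn.MultiheadAttention` (with `batch_first=True`) rather than this notebook's from-scratch class, and random tensors stand in for decoder states and encoder outputs. The same module is reused for both calls only to compare shapes; a real encoder-decoder would use separate modules.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads = 16, 4
target = torch.randn(2, 3, d_model)   # (batch, tgt_len, d_model), e.g. decoder states
source = torch.randn(2, 7, d_model)   # (batch, src_len, d_model), e.g. encoder outputs

mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)

# Self-attention: queries, keys, and values all come from the same sequence
self_out, self_w = mha(target, target, target)
# Cross-attention: queries from the target, keys/values from the source
cross_out, cross_w = mha(target, source, source)

print(self_out.shape, self_w.shape)    # torch.Size([2, 3, 16]) torch.Size([2, 3, 3])
print(cross_out.shape, cross_w.shape)  # torch.Size([2, 3, 16]) torch.Size([2, 3, 7])
```

Only the key/value length changes the weight shape: each of the 3 target positions gets a distribution over the 7 source positions. By default the returned weights are averaged over heads.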

In [None]:
# Demonstrate causal masking
torch.manual_seed(42)
seq_len = 6

# Create a causal mask: 1 = attend, 0 = mask out
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

Q = torch.randn(1, seq_len, 8)
K = torch.randn(1, seq_len, 8)
V = torch.randn(1, seq_len, 8)

# Without mask (bidirectional)
_, weights_full = scaled_dot_product_attention(Q, K, V)

# With causal mask (autoregressive)
_, weights_causal = scaled_dot_product_attention(Q, K, V, mask=causal_mask)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Mask
ax = axes[0]
im = ax.imshow(causal_mask[0, 0].numpy(), cmap='Greens', aspect='auto')
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_title('Causal Mask\n(1=attend, 0=block)')
for i in range(seq_len):
    for j in range(seq_len):
        val = int(causal_mask[0, 0, i, j].item())
        ax.text(j, i, str(val), ha='center', va='center', fontsize=12,
               color='white' if val else 'black')

# Full attention
ax = axes[1]
im = ax.imshow(weights_full[0].detach().numpy(), cmap='Blues', aspect='auto', vmin=0, vmax=0.5)
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_title('Full (Bidirectional) Attention\n(can see everything)')

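# Causal attention (reshape to seq_len x seq_len: broadcasting against the 4-D mask
# may add a leading head dimension to the returned weights)
ax = axes[2]
im = ax.imshow(weights_causal.reshape(seq_len, seq_len).detach().numpy(), cmap='Blues', aspect='auto', vmin=0, vmax=0.5)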
ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_title('Causal (Masked) Attention\n(can only see past)')

plt.tight_layout()
plt.show()

print("Causal attention: position i can only attend to positions 0, 1, ..., i")
print("This prevents the model from cheating by looking at future tokens.")
print("Used in GPT-style autoregressive language models.")

### Why This Matters in Machine Learning

| Application | Attention Type Used | F1 Parallel |
|-------------|-------------------|-------------|
| Text understanding (BERT encoder) | Bidirectional self-attention | Post-race analysis: look at entire race both ways |
| Text generation (GPT) | Causal self-attention | Live strategy: only see completed laps |
| Translation decoder | Causal self-attention + cross-attention | Real-time comparison with rival car data |
| Image recognition (ViT) | Bidirectional self-attention on patches | Scanning a track map where each patch is a corner |
| Speech recognition | Self-attention + cross-attention | Decoding team radio with full audio context |
| Protein structure (AlphaFold) | Multi-head self-attention | Multi-aspect telemetry analysis |

---

## 8. Practical Example: Sequence Reversal with Attention

### Building a Complete Model

Let's build a simple sequence-to-sequence model with attention and train it on a concrete task: **reversing a sequence of numbers**. This is simple enough to train quickly but complex enough to show how attention works.

For example: `[1, 3, 5, 7, 9]` --> `[9, 7, 5, 3, 1]`

Without attention, this task is hard because the decoder needs to remember the entire input. With attention, the decoder can simply look at the right position!

**F1 analogy:** Think of this as taking the lap chart in chronological order and producing it in reverse -- last lap first. With attention, the model can look directly at whatever position it needs. When generating the first output (the last lap), it attends strongly to the last input position. This is a toy task, but it cleanly demonstrates how attention weights align with the structure of the problem.
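Before training, it helps to see what "looking at the right position" means. This sketch hand-builds the ideal cross-attention pattern for reversal (an anti-diagonal matrix) and applies it to one-hot token vectors used as stand-in values. It is an idealization, not a guarantee of exactly what the trained model below will learn.

```python
import torch
import torch.nn.functional as F

seq_len = 5
src = torch.tensor([1, 3, 5, 7, 9])
values = F.one_hot(src, num_classes=10).float()   # (seq_len, vocab): one-hot stand-in values

# Hypothetical ideal attention: target position t puts all its weight on source position seq_len-1-t
ideal_weights = torch.eye(seq_len).flip(dims=[1])  # anti-diagonal; each row sums to 1

attended = ideal_weights @ values                  # (seq_len, vocab)
print(attended.argmax(dim=-1).tolist())            # [9, 7, 5, 3, 1]
```

A well-trained model's cross-attention heatmap should approximate this anti-diagonal pattern.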

In [None]:
# Dataset: reverse sequences of integers
def generate_reversal_data(n_samples, seq_len, vocab_size=10):
    """
    Generate sequence reversal pairs.
    
    Args:
        n_samples: Number of training examples
        seq_len: Length of each sequence
        vocab_size: Number of distinct tokens (0 to vocab_size-1)
    
    Returns:
        src: Source sequences (n_samples, seq_len)
        tgt: Target sequences (n_samples, seq_len) - reversed source
    """
    src = torch.randint(1, vocab_size, (n_samples, seq_len))  # 0 reserved for padding
    tgt = src.flip(dims=[1])  # Reverse along sequence dimension
    return src, tgt

# Generate data
torch.manual_seed(42)
seq_len = 8
vocab_size = 10
n_train = 5000
n_test = 500

train_src, train_tgt = generate_reversal_data(n_train, seq_len, vocab_size)
test_src, test_tgt = generate_reversal_data(n_test, seq_len, vocab_size)

print("Example pairs (source -> target):")
for i in range(5):
    print(f"  {train_src[i].tolist()} -> {train_tgt[i].tolist()}")

In [None]:
class Seq2SeqWithAttention(nn.Module):
    """
    Simple sequence-to-sequence model with attention.
    
    Uses embeddings + self-attention (no RNN!) to process sequences.
    This is a simplified "transformer-style" approach.
    """
    def __init__(self, vocab_size, d_model=32, n_heads=4, n_layers=2):
        """
        Args:
            vocab_size: Size of token vocabulary
            d_model: Embedding and model dimension
            n_heads: Number of attention heads
            n_layers: Number of self-attention layers
        """
        super().__init__()
        self.d_model = d_model
        
        # Token embeddings
        self.src_embed = nn.Embedding(vocab_size, d_model)
        self.tgt_embed = nn.Embedding(vocab_size, d_model)
        
        # Positional encoding (learned)
        self.src_pos = nn.Embedding(50, d_model)
        self.tgt_pos = nn.Embedding(50, d_model)
        
        # Encoder: self-attention layers
        self.encoder_layers = nn.ModuleList([
            MultiHeadAttention(d_model, n_heads) for _ in range(n_layers)
        ])
        self.encoder_norms = nn.ModuleList([
            nn.LayerNorm(d_model) for _ in range(n_layers)
        ])
        
        # Decoder: cross-attention to encoder output
        self.cross_attention = MultiHeadAttention(d_model, n_heads)
        self.cross_norm = nn.LayerNorm(d_model)
        
        # Output projection
        self.output_proj = nn.Linear(d_model, vocab_size)
        
        # Store attention weights for visualization
        self.attention_weights = None
    
    def encode(self, src):
        """Encode source sequence."""
        batch_size, seq_len = src.shape
        positions = torch.arange(seq_len, device=src.device).unsqueeze(0)
        
        x = self.src_embed(src) + self.src_pos(positions)
        
        for attn, norm in zip(self.encoder_layers, self.encoder_norms):
            residual = x
            attn_out, _ = attn(x, x, x)
            x = norm(residual + attn_out)
        
        return x
    
    def decode(self, tgt, encoder_output):
        """Decode target sequence with cross-attention to encoder."""
        batch_size, seq_len = tgt.shape
        positions = torch.arange(seq_len, device=tgt.device).unsqueeze(0)
        
        x = self.tgt_embed(tgt) + self.tgt_pos(positions)
        
        # Cross-attention: target queries, encoder keys/values
        residual = x
        cross_out, self.attention_weights = self.cross_attention(x, encoder_output, encoder_output)
        x = self.cross_norm(residual + cross_out)
        
        return self.output_proj(x)
    
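    def forward(self, src, tgt):
        """
        Args:
            src: Source tokens (batch, src_len)
            tgt: Target tokens (batch, tgt_len); logits[:, t] predicts tgt[:, t]
        
        Returns:
            logits: (batch, tgt_len, vocab_size)
        """
        encoder_output = self.encode(src)
        # Teacher forcing without leaking the answer: shift the target right by one
        # and use token 0 (reserved, never a data token) as the start token, so
        # decoder position t sees tgt[t-1], never the tgt[t] it must predict.
        start = torch.zeros_like(tgt[:, :1])
        decoder_input = torch.cat([start, tgt[:, :-1]], dim=1)
        logits = self.decode(decoder_input, encoder_output)
        return logits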

# Create model
torch.manual_seed(42)
model = Seq2SeqWithAttention(vocab_size=vocab_size, d_model=32, n_heads=4, n_layers=2)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model architecture:")
print(model)
print(f"\nTotal parameters: {n_params:,}")

In [None]:
# Training loop
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(42)
model = Seq2SeqWithAttention(vocab_size=vocab_size, d_model=32, n_heads=4, n_layers=2)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

train_dataset = TensorDataset(train_src, train_tgt)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

losses = []
accuracies = []

print("Training sequence reversal model with attention...")
for epoch in range(30):
    model.train()
    epoch_loss = 0
    correct = 0
    total = 0
    
    for src_batch, tgt_batch in train_loader:
        optimizer.zero_grad()
        
        # Teacher forcing: feed true target as input
        logits = model(src_batch, tgt_batch)
        
        # Compute loss
        loss = criterion(logits.view(-1, vocab_size), tgt_batch.view(-1))
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item() * src_batch.size(0)
        
        # Compute accuracy
        preds = logits.argmax(dim=-1)
        correct += (preds == tgt_batch).all(dim=-1).sum().item()
        total += src_batch.size(0)
    
    avg_loss = epoch_loss / n_train
    accuracy = correct / total
    losses.append(avg_loss)
    accuracies.append(accuracy)
    
    if (epoch + 1) % 5 == 0:
        print(f"  Epoch {epoch+1:3d}: loss={avg_loss:.4f}, seq_accuracy={accuracy:.4f}")

# Test accuracy
model.eval()
with torch.no_grad():
    test_logits = model(test_src, test_tgt)
    test_preds = test_logits.argmax(dim=-1)
    test_acc = (test_preds == test_tgt).all(dim=-1).float().mean()
    print(f"\nTest accuracy (full sequence correct): {test_acc:.4f}")

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

ax = axes[0]
ax.plot(losses, 'b-', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
ax.grid(True, alpha=0.3)

ax = axes[1]
ax.plot(accuracies, 'g-', linewidth=2)
ax.set_xlabel('Epoch')
ax.set_ylabel('Sequence Accuracy')
ax.set_title('Training Accuracy (entire sequence correct)')
ax.set_ylim(0, 1.05)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### Visualize Attention Weights During Inference

The real power of attention: we can **see** what the model is looking at!

**F1 analogy:** This is one of the most valuable properties of attention-based models -- interpretability. When a model predicts "pit now," you can look at the attention weights and see exactly which past laps it was focusing on. Was it the degradation trend from the last 5 laps? The competitor's pit stop 3 laps ago? The fuel model from lap 1? Attention weights make the model's reasoning visible, just like a strategist explaining their call.

In [None]:
# Visualize attention patterns for sequence reversal
model.eval()

# Pick a few test examples
n_examples = 3
fig, axes = plt.subplots(n_examples, 2, figsize=(14, 4 * n_examples))

for ex_idx in range(n_examples):
    src = test_src[ex_idx:ex_idx+1]
    tgt = test_tgt[ex_idx:ex_idx+1]
    
    with torch.no_grad():
        logits = model(src, tgt)
        preds = logits.argmax(dim=-1)
    
    # Get cross-attention weights (averaged across heads)
    # Shape: (1, n_heads, tgt_len, src_len)
    attn_weights = model.attention_weights[0].mean(dim=0).numpy()  # Average over heads
    
    src_tokens = src[0].tolist()
    tgt_tokens = tgt[0].tolist()
    pred_tokens = preds[0].tolist()
    
    # Heatmap
    ax = axes[ex_idx, 0]
    im = ax.imshow(attn_weights, cmap='Blues', aspect='auto', vmin=0, vmax=0.5)
    ax.set_xticks(range(len(src_tokens)))
    ax.set_xticklabels(src_tokens, fontsize=11, fontweight='bold')
    ax.set_yticks(range(len(tgt_tokens)))
    ax.set_yticklabels(tgt_tokens, fontsize=11, fontweight='bold')
    ax.set_xlabel('Source Position')
    ax.set_ylabel('Target Position')
    ax.set_title('Cross-Attention Weights (avg over heads)')
    plt.colorbar(im, ax=ax)
    
    # Per-head view
    ax = axes[ex_idx, 1]
    head_weights = model.attention_weights[0].numpy()  # (n_heads, tgt, src)
    for h in range(min(4, head_weights.shape[0])):
        # Show which source position each target position attends to most
        max_attn = head_weights[h].argmax(axis=1)
        ax.plot(range(len(tgt_tokens)), max_attn, 'o-', label=f'Head {h+1}', 
                markersize=8, linewidth=2, alpha=0.7)
    
    # Perfect reversal line
    ax.plot(range(len(tgt_tokens)), list(range(len(src_tokens)-1, -1, -1)), 
            'k--', linewidth=2, alpha=0.5, label='Perfect reversal')
    ax.set_xlabel('Target Position')
    ax.set_ylabel('Most-Attended Source Position')
    ax.set_title(f'Src: {src_tokens} -> Pred: {pred_tokens}')
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("An anti-diagonal pattern in the heatmap indicates the model learned to reverse the sequence.")
print("Target position 0 should attend to the last source position, position 1 to the second-to-last, and so on.")
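
The heatmap check can also be turned into a single number. A minimal sketch, assuming the target and source have the same length and that row t of the head-averaged matrix should peak at source position n - 1 - t (if the target carries an offset such as a start token, the ideal index shifts accordingly). `reversal_alignment` is a hypothetical helper, not part of the notebook's model:

```python
import numpy as np

def reversal_alignment(attn: np.ndarray) -> float:
    """Fraction of target rows whose most-attended source position is the
    mirrored one (src_len - 1 - t). Hypothetical helper for illustration.

    attn: (tgt_len, src_len) attention matrix, rows = target positions.
    """
    tgt_len, src_len = attn.shape
    most_attended = attn.argmax(axis=1)          # best source index for each target row
    ideal = src_len - 1 - np.arange(tgt_len)     # anti-diagonal indices for a perfect reversal
    return float(np.mean(most_attended == ideal))

# A perfect reversal pattern is the anti-diagonal (flipped identity) matrix.
perfect = np.fliplr(np.eye(8))
# A diagonal "copy" pattern attends to the wrong source position in every row.
copy_pattern = np.eye(8)

print(reversal_alignment(perfect))       # 1.0
print(reversal_alignment(copy_pattern))  # 0.0
```

Inside the visualization loop, `reversal_alignment(attn_weights)` would score each example's averaged heatmap.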

---

## Exercises

### Exercise 1: Implement Dot-Product Attention from Scratch

Implement the basic (unscaled) dot-product attention function using only NumPy.

**F1 scenario:** You are building the core of a race analysis tool. Given a query ("what kind of lap am I trying to predict?"), a set of keys (each past lap's signature), and values (each past lap's telemetry), compute the attention-weighted blend of past laps. This is the fundamental operation your strategy model will use to look back at the race history.

In [None]:
# EXERCISE 1: Implement dot-product attention with NumPy
def dot_product_attention_numpy(query, keys, values):
    """
    Compute dot-product attention using NumPy.
    
    Args:
        query: Query vector of shape (d_k,)
        keys: Key matrix of shape (n, d_k)
        values: Value matrix of shape (n, d_v)
    
    Returns:
        output: Weighted sum of values, shape (d_v,)
        weights: Attention weights, shape (n,)
    """
    # TODO: Implement the three steps:
    # 1. Compute scores = keys @ query  (dot product of each key with query)
    # 2. Compute weights = softmax(scores)
    #    Hint: Use numerically stable softmax: exp(x - max(x)) / sum(exp(x - max(x)))
    # 3. Compute output = weights @ values (weighted sum)
    
    pass  # Replace with your implementation

# Test
np.random.seed(42)
q = np.array([1.0, 0.0, 1.0])
K = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [1.0, 0.0, 1.0],  # This key matches the query best!
              [0.0, 0.0, 1.0]])
V = np.array([[1, 0],
              [0, 1],
              [1, 1],
              [0, 0]])

output, weights = dot_product_attention_numpy(q, K, V)

expected_scores = np.array([1.0, 0.0, 2.0, 1.0])
expected_weights_unnorm = np.exp(expected_scores - expected_scores.max())
expected_weights = expected_weights_unnorm / expected_weights_unnorm.sum()
expected_output = expected_weights @ V

print(f"Your weights:    {weights}")
print(f"Expected weights: {expected_weights.round(4)}")
print(f"Your output:     {output}")
print(f"Expected output:  {expected_output.round(4)}")
print(f"Correct: {np.allclose(weights, expected_weights) and np.allclose(output, expected_output)}")
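
Working the Exercise 1 test case by hand gives a closed form to check against. The scores are $s = Kq = [1, 0, 2, 1]$, and since $e^1 + e^0 + e^2 + e^1 = (1+e)^2$ (with $e \approx 2.718$):

$$
\alpha = \operatorname{softmax}(s) = \frac{1}{(1+e)^2}\,[\,e,\ 1,\ e^2,\ e\,] \approx [0.1966,\ 0.0723,\ 0.5344,\ 0.1966]
$$

$$
o = \sum_i \alpha_i v_i = [\,\alpha_1 + \alpha_3,\ \alpha_2 + \alpha_3\,] = \left[\frac{e}{1+e},\ \frac{1+e^2}{(1+e)^2}\right] \approx [0.7311,\ 0.6068]
$$

The third key, which matches the query exactly, receives just over half the weight, yet the output still blends in the other rows of $V$.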

### Exercise 2: Add Causal Masking to Self-Attention

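Build a causal version of the SelfAttention class by completing `CausalSelfAttention` below, so that each position can attend only to itself and earlier positions.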

**F1 scenario:** Convert your race analysis tool from post-race mode (can see all laps) to live-race mode (can only see completed laps). When predicting lap 30, the model must not have access to laps 31 and beyond -- that would be looking into the future. Implement the causal mask that enforces this constraint.
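
The mask mechanics on their own can be sketched on random scores for a single sequence, without the Q, K, V projections (so this is a building block, not the exercise solution). The convention that 1 means "allowed" and 0 means "blocked" is an assumption here; match whatever convention the notebook's `scaled_dot_product_attention` expects.

```python
import torch

torch.manual_seed(0)
seq_len = 5

# Raw scores for one sequence: rows = query positions, columns = key positions.
scores = torch.randn(seq_len, seq_len)

# Lower-triangular mask: 1 where key position <= query position, 0 for future positions.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

# Blocked (future) scores become -inf, and exp(-inf) = 0 inside softmax.
masked_scores = scores.masked_fill(causal_mask == 0, float("-inf"))
weights = torch.softmax(masked_scores, dim=-1)

print(causal_mask)
print(weights)  # upper triangle is exactly 0; each row still sums to 1
```

Because the diagonal is always allowed, every row keeps at least one finite score. Masking an entire row would leave softmax with nothing but -inf and return NaN.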

In [None]:
# EXERCISE 2: Implement causal self-attention
class CausalSelfAttention(nn.Module):
    """
    Self-attention with causal masking.
    Each position can only attend to itself and previous positions.
    """
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
    
    def forward(self, x):
        """
        Args:
            x: Input tensor (batch, seq_len, d_model)
        
        Returns:
            output: (batch, seq_len, d_model)
            weights: (batch, seq_len, seq_len)
        """
        # TODO: Implement causal self-attention
        # Step 1: Compute Q, K, V projections
        # Step 2: Create a causal mask using torch.tril
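        #   Hint: seq_len = x.size(1); mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
        #   Hint: The scores here are (batch, seq_len, seq_len) with no head dimension,
        #         so reshape the mask to (1, seq_len, seq_len) for broadcasting
        # Step 3: Pass the mask to the scaled_dot_product_attention function defined
        #         earlier in this notebook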
        
        pass  # Replace with your implementation

# Test
torch.manual_seed(42)
causal_attn = CausalSelfAttention(d_model=8)
x = torch.randn(1, 5, 8)

output, weights = causal_attn(x)

# Check that future positions have zero attention weight
print("Attention weights:")
print(weights[0].detach().numpy().round(3))

print("\nUpper triangle should be ~0:")
upper_triangle = weights[0].detach().numpy()[np.triu_indices(5, k=1)]
print(f"Max value in upper triangle: {upper_triangle.max():.6f}")
print(f"Causal masking working: {upper_triangle.max() < 1e-6}")

### Exercise 3: Multi-Head Attention Comparison

Compare 1, 2, 4, and 8 attention heads on the sequence reversal task. With the default d_model = 32, each head works with 32, 16, 8, or 4 dimensions respectively.

**F1 scenario:** Compare having 1 generalist engineer vs a team of 2, 4, or 8 specialists on the pit wall. With more heads, each specialist can focus on a different aspect of the race data (pace, tires, fuel, gaps), but each one works with a smaller slice of the model dimension (d_model / n_heads). Do more, narrower perspectives improve the model's ability to learn the task? Run the experiment and find out.
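
Before running the experiment, it helps to see what each head actually receives. A minimal sketch of the usual reshape that splits d_model across heads (`split_heads` is a hypothetical helper; the notebook's own multi-head class may organize this differently):

```python
import torch

def split_heads(x: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Reshape (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_head). Hypothetical helper."""
    batch, seq_len, d_model = x.shape
    assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
    d_head = d_model // n_heads
    # Split the last dimension into (n_heads, d_head), then move heads in front of seq_len.
    return x.view(batch, seq_len, n_heads, d_head).transpose(1, 2)

x = torch.randn(2, 8, 32)  # batch=2, seq_len=8, d_model=32 (the exercise default)
for n_heads in [1, 2, 4, 8]:
    print(f"{n_heads} heads -> {tuple(split_heads(x, n_heads).shape)}")
```

The last dimension shrinks from 32 to 16, 8 and 4 as heads are added: more attention maps, each computed over fewer dimensions.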

In [None]:
# EXERCISE 3: Compare different numbers of heads
def train_and_evaluate(n_heads, d_model=32, epochs=20):
    """
    Train a Seq2SeqWithAttention model with a given number of heads.
    
    Args:
        n_heads: Number of attention heads
        d_model: Model dimension (must be divisible by n_heads)
        epochs: Number of training epochs
    
    Returns:
        test_accuracy: Fraction of test sequences fully reversed correctly
        losses: List of training losses per epoch
    """
    # TODO: Implement this!
    # Hint: Create a Seq2SeqWithAttention model with the given n_heads
    # Hint: Train it using the same training loop as Section 8
    # Hint: Evaluate on test data and return accuracy
    
    pass  # Replace with your implementation

# Test your implementation:
# head_configs = [1, 2, 4, 8]
# results = {}
# for n_h in head_configs:
#     acc, losses = train_and_evaluate(n_h)
#     results[n_h] = (acc, losses)
#     print(f"  {n_h} heads: test accuracy = {acc:.4f}")
#
# # Plot loss curves
# for n_h, (acc, losses) in results.items():
#     plt.plot(losses, label=f'{n_h} heads (acc={acc:.3f})')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.title('Effect of Number of Attention Heads')
# plt.legend()
# plt.grid(True, alpha=0.3)
# plt.show()

---

## Summary

### Key Concepts

**The Bottleneck Problem:**
- Fixed-length encoding forces entire sequences into one vector
- Information loss grows with sequence length
- Attention solves this by letting the decoder look at all encoder states
- **F1 parallel:** Instead of compressing 78 laps into one summary vector, attention lets the model look up any specific lap whenever it needs to

**The Q, K, V Framework:**
- **Query:** "What am I looking for?"
- **Key:** "What do I have to offer?"
- **Value:** "Here is my content"
- Attention = softmax(similarity(Q, K)) times V (a matrix product, so each output is a weighted sum of the values)
- **F1 parallel:** Query = "which past laps are relevant to predicting THIS lap?" Keys = each lap's signature. Values = each lap's actual telemetry data.

**Scaled Dot-Product Attention:**
- Score = Q . K^T / sqrt(d_k)
- Scaling prevents softmax saturation in high dimensions
- Foundation of all modern attention mechanisms
- **F1 parallel:** Scaling prevents the model from fixating on a single lap -- it blends information from several relevant laps
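
The saturation argument can be checked numerically. A minimal sketch, assuming query and key entries are independent with zero mean and unit variance: the raw dot product then has variance d_k, so its spread grows like sqrt(d_k), while dividing by sqrt(d_k) keeps it near 1.

```python
import torch

torch.manual_seed(0)
n_samples = 10_000
stds = {}

for d_k in [4, 64, 512]:
    # Query and key entries drawn i.i.d. from N(0, 1): the assumption behind the variance argument.
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    raw = (q * k).sum(dim=-1)   # one dot product per sample; variance is d_k
    scaled = raw / d_k ** 0.5   # dividing by sqrt(d_k) brings the variance back to 1
    stds[d_k] = (raw.std().item(), scaled.std().item())
    print(f"d_k={d_k:4d}  std(raw)={stds[d_k][0]:6.2f}  std(scaled)={stds[d_k][1]:.2f}")
```

At d_k = 512 the raw scores spread over a range of roughly sqrt(512), about 22.6, which pushes softmax toward a near one-hot output with vanishing gradients; the scaled scores stay in a range where several keys can share the weight.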

**Self-Attention:**
- Q, K, V are all projected from the same input
- Each position can attend to every other position
- O(1) path length between any two positions (vs O(n) for RNNs)
- **F1 parallel:** Each lap attending to every other lap in the same race -- lap 50 can directly access lap 1 without information degradation

**Multi-Head Attention:**
- Multiple parallel attention "perspectives"
- Each head can learn different relationship types
- Same parameter count as single-head (just splits d_model)
- **F1 parallel:** Multiple specialist engineers simultaneously attending to pace, tire wear, fuel load, and track evolution
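
The "same parameter count" claim can be checked with PyTorch's built-in `nn.MultiheadAttention` (used here in place of the notebook's own class, whose bias settings may differ). The Q, K, V and output projections are all d_model x d_model no matter how many heads share them:

```python
import torch.nn as nn

def n_params(module: nn.Module) -> int:
    """Total number of parameters in a module."""
    return sum(p.numel() for p in module.parameters())

d_model = 32
counts = {}
for n_heads in [1, 2, 4, 8]:
    # Packed Q/K/V projection (3*d_model x d_model) plus output projection (d_model x d_model), with biases.
    mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads)
    counts[n_heads] = n_params(mha)
    print(f"{n_heads} heads: {counts[n_heads]} parameters")
```

Every configuration reports 3 * 32 * 32 + 3 * 32 + 32 * 32 + 32 = 4224 parameters; only the way each projection is sliced into heads changes.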

**Attention Variants:**
- Additive (Bahdanau): MLP-based scoring
- Multiplicative (Luong): Learned bilinear scoring
- Scaled dot-product: Used in Transformers
- Causal masking: For autoregressive generation (live race, no peeking at future laps)
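
A side-by-side sketch of the three scoring functions, each mapping queries (batch, n_q, d) and keys (batch, n_k, d) to a (batch, n_q, n_k) score matrix that softmax then normalizes. The class names and hidden size are illustrative assumptions, not the notebook's earlier implementations:

```python
import torch
import torch.nn as nn

class AdditiveScore(nn.Module):
    """Bahdanau-style score: v^T tanh(W_q q + W_k k)."""
    def __init__(self, d_q: int, d_k: int, d_hidden: int):
        super().__init__()
        self.W_q = nn.Linear(d_q, d_hidden, bias=False)
        self.W_k = nn.Linear(d_k, d_hidden, bias=False)
        self.v = nn.Linear(d_hidden, 1, bias=False)

    def forward(self, q, k):
        # (batch, n_q, 1, h) + (batch, 1, n_k, h) -> (batch, n_q, n_k, h)
        hidden = torch.tanh(self.W_q(q).unsqueeze(2) + self.W_k(k).unsqueeze(1))
        return self.v(hidden).squeeze(-1)  # (batch, n_q, n_k)

class BilinearScore(nn.Module):
    """Luong-style 'general' score: q^T W k with a learned matrix W."""
    def __init__(self, d_q: int, d_k: int):
        super().__init__()
        self.W = nn.Linear(d_k, d_q, bias=False)

    def forward(self, q, k):
        return q @ self.W(k).transpose(-2, -1)  # (batch, n_q, n_k)

def scaled_dot_score(q, k):
    """Transformer score: q . k / sqrt(d_k), no learned parameters (needs d_q == d_k)."""
    return q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5

torch.manual_seed(0)
q = torch.randn(2, 3, 16)  # (batch, n_q, d)
k = torch.randn(2, 5, 16)  # (batch, n_k, d)
scores = {
    "additive": AdditiveScore(16, 16, d_hidden=32)(q, k),
    "bilinear": BilinearScore(16, 16)(q, k),
    "scaled dot": scaled_dot_score(q, k),
}
for name, s in scores.items():
    print(name, tuple(s.shape))  # each is (2, 3, 5)
```

The additive score pays for an extra hidden layer per query-key pair, while the two multiplicative forms reduce to batched matrix products, which is one reason the scaled dot product won out in Transformers.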

### Connection to Deep Learning

| Concept | Where It's Used | F1 Parallel |
|---------|----------------|-------------|
| Scaled dot-product attention | Core of every Transformer | Foundation of race-state prediction |
| Multi-head attention | Encoder and decoder blocks | Multiple specialists on the pit wall |
| Self-attention | BERT, GPT, ViT, and almost every modern model | Each lap understanding its context from all other laps |
| Cross-attention | Encoder-decoder models, image captioning | Comparing your telemetry against a rival's |
| Causal masking | GPT, autoregressive language models | Live-race prediction (no future data) |
| Q, K, V projections | All attention-based architectures | The query-key-value lookup that powers the entire strategy system |

### Checklist

- [ ] I can explain the bottleneck problem with fixed-length encoding
- [ ] I understand the Q, K, V framework and the database analogy
- [ ] I can compute dot-product attention by hand on small examples
- [ ] I know why we divide by sqrt(d_k) and what happens without it
- [ ] I can implement self-attention from scratch
- [ ] I can implement multi-head attention and explain the head dimension split
- [ ] I can compare additive, multiplicative, and dot-product attention
- [ ] I understand causal masking and when it's needed
- [ ] I can build and train a model with attention and visualize its attention weights

---

## Next Steps

You now understand the **attention mechanism** -- the single most important building block in modern deep learning. Everything you've learned here feeds directly into the **Transformer architecture**, which is the next notebook.

In the Transformer notebook, you'll see how attention is combined with:
1. **Positional encoding** -- giving the model a sense of word order (remember, self-attention alone has no notion of position!). In F1 terms: telling the model that lap 1 comes before lap 2, something attention alone does not know.
2. **Feed-forward layers** -- adding per-position nonlinear processing
3. **Layer normalization and residual connections** -- stabilizing deep attention networks
4. **The full encoder-decoder architecture** -- stacking multiple attention layers

**The key takeaway:** Attention lets a model dynamically decide what to focus on. Instead of processing sequences step-by-step (RNNs) or looking at fixed windows (CNNs), attention can connect any two positions in a sequence directly. This is why Transformers have replaced RNNs and CNNs for most sequence tasks, and why attention is now used even in vision (ViT), audio, and protein folding. In F1 terms, attention replaced the "relay every lap sequentially through memory" approach with a "look up any lap instantly" approach -- and that single change transformed the entire field.

**Practical next steps:**
- Try modifying the sequence reversal model to handle variable-length sequences
- Experiment with different attention types (additive vs dot-product) on the same task
- Read the "Attention Is All You Need" paper (Vaswani et al., 2017) -- you now have the background to understand it
- Explore PyTorch's built-in `nn.MultiheadAttention` and compare with your implementation
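
A minimal sketch of that comparison for the single-head case, assuming PyTorch 1.9 or later (for `batch_first=True`). It reproduces `nn.MultiheadAttention`'s self-attention output by hand from the module's own `in_proj_weight`, `in_proj_bias` and `out_proj`, showing that the built-in layer follows the same project, score, softmax, mix recipe plus an output projection:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, seq_len = 16, 6
x = torch.randn(2, seq_len, d_model)  # (batch, seq_len, d_model)

mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=1, batch_first=True)
mha.eval()

with torch.no_grad():
    # Built-in self-attention: query = key = value = x.
    ref_out, ref_weights = mha(x, x, x, need_weights=True)

    # Manual version reusing the module's parameters.
    # in_proj_weight stacks W_q, W_k, W_v along dim 0 -> shape (3 * d_model, d_model).
    q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)
    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_model), dim=-1)
    manual_out = mha.out_proj(weights @ v)

print(torch.allclose(manual_out, ref_out, atol=1e-5))    # outputs match
print(torch.allclose(weights, ref_weights, atol=1e-5))   # attention maps match
```

With more than one head, the same check needs the head split from Exercise 3 and a scale of 1 / sqrt(d_model / n_heads); by default the returned weights are averaged over heads.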