# Long Context & Efficient Attention

In this notebook, you'll build and explore the three architectural innovations that enable transformers to scale from 1K to 100K+ token contexts. No pretrained models or GPUs needed—everything runs in seconds on CPU with toy-scale tensors.

**What you'll do:**
- Build the RoPE rotation matrix for a 2D subspace, apply it to Q/K pairs at different positions, and verify that the dot product depends on relative position—not absolute position
- Build full causal and sliding window attention masks, visualize them as heatmaps, and compare the number of computed attention scores
- Implement a Grouped Query Attention (GQA) forward pass with `n_heads=8` and `n_kv_heads=2`, and compare KV cache sizes for MHA vs GQA
- Build an attention cost calculator: compute total FLOPs and KV cache memory for different model configurations (GPT-2, LLaMA 2 70B with MHA vs GQA, with/without sliding window) and see how the three optimizations compound

**For each exercise, PREDICT the output before running the cell.** Wrong predictions are more valuable than correct ones—they reveal gaps in your mental model.

In [None]:
# Setup—self-contained for Google Colab
# No extra pip installs needed—torch and matplotlib are in Colab by default.

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math

# Reproducible results
torch.manual_seed(42)
np.random.seed(42)

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
print('Setup complete.')

---

## Exercise 1: Implement RoPE Rotation (Guided)

RoPE encodes position in the Q/K *dot product*, not in the embedding. The mechanism: rotate the Q and K vectors by an angle proportional to their position. When you compute `q_i^T k_j`, the two rotations combine so that the result depends on the *relative* position `j - i`, not on the absolute positions `i` and `j`.

In this exercise, you'll:
1. Build the 2D rotation matrix `R(theta)`—the fundamental building block of RoPE
2. Rotate Q and K vectors at different absolute positions but the same relative distance
3. Compute dot products and verify the relative position property
4. Visualize the rotations in the 2D plane for several relative distances

**Before running, predict:**
- If `q` is at position 3 and `k` is at position 7 (relative distance 4), and then we shift both to positions 1003 and 1007 (still relative distance 4), will the dot product `q_rot^T k_rot` change?
- What about if we keep `q` at position 3 but move `k` to position 8 (relative distance 5)? Will the dot product be the same as before or different?
- The rotation angle is `position * theta`. If `theta = 0.5`, what angle is applied at position 3? At position 1003?

In [None]:
# --- The 2D rotation matrix: the building block of RoPE ---
#
# RoPE pairs consecutive dimensions of Q and K vectors into 2D subspaces.
# Each 2D subspace gets its own rotation at a different frequency.
# Here we start with a single 2D subspace to build intuition.

def rotation_matrix_2d(angle: float) -> torch.Tensor:
    """Build a 2x2 rotation matrix for the given angle (radians).
    
    R(angle) = [[cos(angle), -sin(angle)],
                [sin(angle),  cos(angle)]]
    """
    c = math.cos(angle)
    s = math.sin(angle)
    return torch.tensor([[c, -s], [s, c]], dtype=torch.float32)


def rotate_vector(v: torch.Tensor, position: int, theta: float) -> torch.Tensor:
    """Rotate a 2D vector by angle = position * theta.
    
    This is what RoPE does to each 2D subspace of Q and K:
    the rotation angle is proportional to the token's position.
    """
    angle = position * theta
    R = rotation_matrix_2d(angle)
    return R @ v


# Choose a base frequency (one "hand" of the clock)
theta = 0.5  # radians per position

# Create a query vector and a key vector (2D for visualization)
q = torch.tensor([1.0, 0.3])  # some arbitrary q vector
k = torch.tensor([0.8, 0.6])  # some arbitrary k vector

print(f'Original q: {q.tolist()}')
print(f'Original k: {k.tolist()}')
print(f'theta (frequency): {theta} radians/position')
print()

# --- Test 1: Same relative distance, different absolute positions ---
# Pair A: positions 3 and 7 (relative distance = 4)
q_rot_A = rotate_vector(q, position=3, theta=theta)
k_rot_A = rotate_vector(k, position=7, theta=theta)
dot_A = torch.dot(q_rot_A, k_rot_A).item()

# Pair B: positions 1003 and 1007 (relative distance = 4)
q_rot_B = rotate_vector(q, position=1003, theta=theta)
k_rot_B = rotate_vector(k, position=1007, theta=theta)
dot_B = torch.dot(q_rot_B, k_rot_B).item()

# Pair C: positions 50000 and 50004 (relative distance = 4)
q_rot_C = rotate_vector(q, position=50000, theta=theta)
k_rot_C = rotate_vector(k, position=50004, theta=theta)
dot_C = torch.dot(q_rot_C, k_rot_C).item()

print('=== Same Relative Distance (4), Different Absolute Positions ===')
print(f'  Positions  3 & 7     -> dot product = {dot_A:.6f}')
print(f'  Positions 1003 & 1007 -> dot product = {dot_B:.6f}')
print(f'  Positions 50000 & 50004 -> dot product = {dot_C:.6f}')
print(f'  All equal? {abs(dot_A - dot_B) < 1e-5 and abs(dot_A - dot_C) < 1e-5}')
print()

# --- Test 2: Different relative distance ---
# Pair D: positions 3 and 8 (relative distance = 5)
q_rot_D = rotate_vector(q, position=3, theta=theta)
k_rot_D = rotate_vector(k, position=8, theta=theta)
dot_D = torch.dot(q_rot_D, k_rot_D).item()

# Pair E: positions 3 and 3 (relative distance = 0)
q_rot_E = rotate_vector(q, position=3, theta=theta)
k_rot_E = rotate_vector(k, position=3, theta=theta)
dot_E = torch.dot(q_rot_E, k_rot_E).item()

print('=== Different Relative Distances ===')
print(f'  Relative distance 4 -> dot product = {dot_A:.6f}')
print(f'  Relative distance 5 -> dot product = {dot_D:.6f}')
print(f'  Relative distance 0 -> dot product = {dot_E:.6f}')
print(f'  Distance 0 == unrotated dot(q, k)? {abs(dot_E - torch.dot(q, k).item()) < 1e-5}')
print()

# --- Why does this work? ---
# The rotation for q at position i is R(i*theta).
# The rotation for k at position j is R(j*theta).
# The dot product: rotate(q, i)^T @ rotate(k, j)
#   = q^T @ R(i*theta)^T @ R(j*theta) @ k
#   = q^T @ R(-i*theta) @ R(j*theta) @ k   (transpose of a rotation = rotation by -angle)
#   = q^T @ R((j-i)*theta) @ k             (2D rotations compose by adding angles)
# The result depends on (j - i), not on i or j individually.
print('=== The Math ===')
print('rotate(q, i)^T @ rotate(k, j) = q^T @ R((j-i)*theta) @ k')
print('The dot product depends on RELATIVE distance (j - i), not absolute positions.')
print(f'  R(4 * {theta}) is the same whether positions are (3,7) or (1003,1007) or (50000,50004).')

In [None]:
# --- Visualize: RoPE rotations in the 2D plane ---
# Show how the same q and k vectors look at different positions,
# and how the angular DIFFERENCE stays constant for the same relative distance.

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

q_color = '#a78bfa'  # violet
k_color = '#f59e0b'  # amber
angle_color = '#34d399'  # emerald

pairs = [
    (3, 7, 'Positions 3 & 7\n(rel. distance = 4)'),
    (1003, 1007, 'Positions 1003 & 1007\n(rel. distance = 4)'),
    (3, 8, 'Positions 3 & 8\n(rel. distance = 5)'),
]

for ax, (pos_q, pos_k, title) in zip(axes, pairs):
    q_r = rotate_vector(q, pos_q, theta)
    k_r = rotate_vector(k, pos_k, theta)
    dot_val = torch.dot(q_r, k_r).item()
    
    # Draw a reference circle slightly larger than the q and k vectors
    circle_angles = np.linspace(0, 2*np.pi, 100)
    radius = max(q.norm().item(), k.norm().item()) * 1.1
    ax.plot(radius * np.cos(circle_angles), radius * np.sin(circle_angles),
            color='#334155', linewidth=0.5, linestyle='--')
    
    # Draw rotated q vector
    ax.annotate('', xy=(q_r[0].item(), q_r[1].item()), xytext=(0, 0),
                arrowprops=dict(arrowstyle='->', color=q_color, linewidth=2))
    ax.text(q_r[0].item() * 1.15, q_r[1].item() * 1.15,
            f'q (pos {pos_q})', color=q_color, fontsize=9, fontweight='bold')
    
    # Draw rotated k vector
    ax.annotate('', xy=(k_r[0].item(), k_r[1].item()), xytext=(0, 0),
                arrowprops=dict(arrowstyle='->', color=k_color, linewidth=2))
    ax.text(k_r[0].item() * 1.15, k_r[1].item() * 1.15,
            f'k (pos {pos_k})', color=k_color, fontsize=9, fontweight='bold')
    
    # Draw angle arc between them
    q_angle = math.atan2(q_r[1].item(), q_r[0].item())
    k_angle = math.atan2(k_r[1].item(), k_r[0].item())
    arc_r = radius * 0.35
    arc_angles = np.linspace(min(q_angle, k_angle), max(q_angle, k_angle), 50)
    ax.plot(arc_r * np.cos(arc_angles), arc_r * np.sin(arc_angles),
            color=angle_color, linewidth=2, linestyle='--')
    
    ax.set_title(title, fontsize=11, fontweight='bold')
    ax.set_xlim(-radius * 1.4, radius * 1.4)
    ax.set_ylim(-radius * 1.4, radius * 1.4)
    ax.set_aspect('equal')
    ax.axhline(0, color='#334155', linewidth=0.5)
    ax.axvline(0, color='#334155', linewidth=0.5)
    ax.text(0.02, -0.12, f'dot = {dot_val:.4f}', transform=ax.transAxes,
            fontsize=10, color=angle_color, fontweight='bold')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

fig.suptitle('RoPE: Same Relative Distance → Same Dot Product',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print('\nLeft and center: same relative distance (4), different absolute positions → same dot product.')
print('Right: different relative distance (5) → different dot product.')
print('Position enters the dot product through rotation, and the result depends on relative distance.')

In [None]:
# --- Dot product as a function of relative distance ---
# Sweep relative distances from 0 to 20 and plot the dot product.
# This shows the "fingerprint" of relative position in the attention score.

rel_distances = list(range(0, 21))
dot_products = []

base_pos = 100  # arbitrary absolute position for q
for d in rel_distances:
    q_r = rotate_vector(q, base_pos, theta)
    k_r = rotate_vector(k, base_pos + d, theta)
    dot_products.append(torch.dot(q_r, k_r).item())

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(rel_distances, dot_products, color='#34d399', linewidth=2, marker='o', markersize=5)
ax.set_xlabel('Relative Distance (j - i)', fontsize=12)
ax.set_ylabel('Dot Product (attention score contribution)', fontsize=12)
ax.set_title('RoPE: Dot Product vs Relative Distance (single 2D subspace)',
             fontsize=13, fontweight='bold')
ax.axhline(0, color='#334155', linewidth=0.5)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
plt.show()

print('The dot product oscillates with relative distance—like a wave.')
print('This is ONE dimension pair at ONE frequency.')
print('In full RoPE, many dimension pairs at different frequencies combine')
print('("clock with many hands"), creating a richer position signature.')

**What just happened:** The dot product between RoPE-rotated Q and K vectors depends *only* on relative position. Positions (3, 7), (1003, 1007), and (50000, 50004) all produce the same dot product because they all have relative distance 4. Change the relative distance to 5, and the dot product changes.

This is the fundamental property that enables context extension. A model trained at 4K context learns what the dot product looks like for relative distances such as 1, 5, 100, and 2000. At inference on a 32K document, those same relative distances produce the same dot products, so the learned attention patterns for them transfer. With learned PE, position 4097 literally has no embedding; with RoPE, position 4097 is just another rotation angle. (Relative distances longer than anything seen in training are still unfamiliar to the model, so RoPE makes extension possible rather than free.)

The oscillating dot product curve shows how a single 2D subspace encodes relative distance. Full RoPE uses many subspaces at different frequencies—the "clock with many hands" from sinusoidal PE. Some hands rotate fast (local position), others rotate slow (global position). Together they create a unique fingerprint for each relative distance.
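To make the "clock with many hands" concrete, here is a minimal sketch that applies RoPE to a full `d_head`-dimensional vector instead of a single 2D subspace. It assumes the frequency schedule from the original RoPE paper, `theta_i = base ** (-2i / d)` with `base = 10000`, and the convention of pairing consecutive dimensions `(0, 1), (2, 3), ...`; some implementations pair dimension `i` with `i + d/2` instead. The names `apply_rope` and `rope_score` are our own, and float64 is used only to keep the comparison at large positions tight.

```python
import torch

def apply_rope(x: torch.Tensor, position: int, base: float = 10000.0) -> torch.Tensor:
    """Rotate each consecutive pair (x[2i], x[2i+1]) by angle position * theta_i.

    Assumes theta_i = base ** (-2i / d), the schedule from the RoPE paper.
    x: 1-D tensor of even length d.
    """
    d = x.shape[-1]
    assert d % 2 == 0, "RoPE needs an even dimension"
    i = torch.arange(d // 2, dtype=torch.float64)
    freqs = base ** (-2 * i / d)        # one frequency ("clock hand") per 2D pair
    angles = position * freqs           # shape (d/2,)
    cos, sin = torch.cos(angles), torch.sin(angles)
    x = x.to(torch.float64)
    x_even, x_odd = x[0::2], x[1::2]
    out = torch.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin   # same 2x2 rotation as rotation_matrix_2d
    out[1::2] = x_even * sin + x_odd * cos
    return out

torch.manual_seed(0)
d_head_rope = 64
q_full = torch.randn(d_head_rope, dtype=torch.float64)
k_full = torch.randn(d_head_rope, dtype=torch.float64)

def rope_score(pos_q: int, pos_k: int) -> float:
    """Dot product of RoPE-rotated q_full and k_full at the given positions."""
    return torch.dot(apply_rope(q_full, pos_q), apply_rope(k_full, pos_k)).item()

# Same relative distance (4) at very different absolute positions
print(rope_score(3, 7), rope_score(1003, 1007), rope_score(50000, 50004))
# A different relative distance (5) gives a different score
print(rope_score(3, 8))
```

The first three scores should agree to floating-point precision, because every pair contributes a term that depends only on `j - i`. The fast-rotating pairs (small `i`) make the score sensitive to small offsets, while the slow pairs (large `i`) change only over long distances.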

---

## Exercise 2: Build Sliding Window Attention Mask (Supported)

The second long-context barrier is quadratic compute. Full causal attention computes `O(n^2)` dot products—every token with every previous token. Sliding window attention restricts each token to attend only to its nearest `w` predecessors, reducing cost to `O(n * w)` where `w << n`.

In this exercise, you'll:
1. Build a full causal attention mask (lower triangular)
2. Build a sliding window attention mask (diagonal band within the causal triangle)
3. Visualize both as heatmaps and compare the number of computed attention scores

The causal mask is provided. You'll implement the sliding window mask.

**Before running, predict:**
- For a sequence of length 32, how many attention scores does full causal attention compute? (Hint: lower triangular including diagonal)
- With a sliding window of `w = 8`, how many scores does each token compute? How many total?
- What fraction of the full causal computation does the sliding window require?

In [None]:
# --- Full causal mask (provided) ---
# The lower-triangular mask you built in the decoder-only transformers lesson:
# each token can attend to itself and all previous tokens.

seq_len = 32
window_size = 8

def build_causal_mask(n: int) -> torch.Tensor:
    """Full causal attention mask. 1 = can attend, 0 = masked.
    Lower triangular: token i can attend to tokens 0..i.
    """
    return torch.tril(torch.ones(n, n))


# --- Sliding window mask (you implement this) ---

def build_sliding_window_mask(n: int, w: int) -> torch.Tensor:
    """Sliding window attention mask. 1 = can attend, 0 = masked.
    Token i can attend to tokens max(0, i-w+1)..i (the nearest w tokens).
    
    This is the causal mask PLUS a further restriction: tokens more than
    w-1 positions before the current token are also masked out.
    
    Args:
        n: sequence length
        w: window size (each token attends to at most w tokens)
    
    Returns:
        Tensor of shape (n, n) with 1s and 0s
    """
    # TODO: Build the mask. Two approaches:
    #
    # Approach 1 (combine two masks):
    #   1. Start with the causal mask: torch.tril(torch.ones(n, n))
    #   2. Create a "within window" mask: torch.triu(torch.ones(n, n), diagonal=-(w-1))
    #      This keeps entries where col >= row - (w-1), i.e., keys at most w-1 positions back.
    #   3. Combine: element-wise AND (multiply) the two masks.
    #
    # Approach 2 (direct band):
    #   1. Create a matrix where entry (i,j) is 1 if:
    #      j <= i  (causal)  AND  i - j < w  (within window)
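    #   2. Build row and column index grids (e.g., torch.arange(n) as a column and as a row),
    #      combine the two conditions with &, and convert the boolean result to float.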
    #
    # Either approach is fine. The result should be a diagonal band of 1s.
    
    pass  # Replace with your implementation


# --- Build and compare ---
causal_mask = build_causal_mask(seq_len)
window_mask = build_sliding_window_mask(seq_len, window_size)

causal_count = int(causal_mask.sum().item())
window_count = int(window_mask.sum().item())

print(f'Sequence length: {seq_len}')
print(f'Window size: {window_size}')
print()
print(f'Full causal attention scores: {causal_count}')
print(f'Sliding window attention scores: {window_count}')
print(f'Reduction: {window_count / causal_count:.1%} of full causal ({causal_count / window_count:.1f}x fewer)')
print()
print(f'Full causal formula: n(n+1)/2 = {seq_len}*{seq_len+1}/2 = {seq_len*(seq_len+1)//2}')
print(f'Sliding window formula: ~n*w = {seq_len}*{window_size} = {seq_len*window_size}')
print(f'Exact window count: n*w - w*(w-1)/2 = {seq_len*window_size - window_size*(window_size-1)//2}, '
      f'because the first w-1 tokens have fewer than w tokens to attend to.')

<details>
<summary>Solution</summary>

The key insight is that the sliding window mask is the intersection of two constraints: (1) causal—can only attend to earlier tokens, and (2) window—can only attend to the nearest `w` tokens. Both constraints are expressed as triangular matrices.

```python
def build_sliding_window_mask(n: int, w: int) -> torch.Tensor:
    causal = torch.tril(torch.ones(n, n))  # lower triangular
    window = torch.triu(torch.ones(n, n), diagonal=-(w - 1))  # keep entries within w
    return causal * window  # element-wise AND
```

`torch.triu(..., diagonal=-(w-1))` keeps the main diagonal, everything above it, and the `w-1` diagonals just below it. Intersecting that with the causal mask leaves a diagonal band of width `w` inside the lower triangle: entry `(i, j)` is 1 when `j <= i` (causal) AND `i - j < w` (within window).

</details>
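The exercise hint also describes a direct-band approach (Approach 2). A minimal sketch of it builds the two conditions from broadcast index grids instead of two triangular matrices; the function names below are our own, and the check confirms both constructions produce the same mask.

```python
import torch

def build_sliding_window_mask_direct(n: int, w: int) -> torch.Tensor:
    """Entry (i, j) is 1 when j <= i (causal) and i - j < w (inside the window)."""
    i = torch.arange(n).unsqueeze(1)   # query positions as a column, shape (n, 1)
    j = torch.arange(n).unsqueeze(0)   # key positions as a row, shape (1, n)
    allowed = (j <= i) & (i - j < w)   # broadcasts to shape (n, n)
    return allowed.float()

def build_sliding_window_mask_triangles(n: int, w: int) -> torch.Tensor:
    """Approach 1: intersect the causal triangle with the within-window triangle."""
    causal = torch.tril(torch.ones(n, n))
    window = torch.triu(torch.ones(n, n), diagonal=-(w - 1))
    return causal * window

direct = build_sliding_window_mask_direct(32, 8)
print(torch.equal(direct, build_sliding_window_mask_triangles(32, 8)))
print(int(direct.sum()))
```

The index-grid version makes the two conditions explicit, which is handy if you later want variants such as a window that also looks a few tokens ahead.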

### Helper: Working Sliding Window Mask

**Run the cell below** to get a working `build_sliding_window_mask` for the visualization and remaining exercises. If your implementation above works correctly, this just redefines the same function.

In [None]:
# --- Reference implementation for remaining exercises ---

def build_causal_mask(n: int) -> torch.Tensor:
    """Full causal attention mask."""
    return torch.tril(torch.ones(n, n))

def build_sliding_window_mask(n: int, w: int) -> torch.Tensor:
    """Sliding window attention mask."""
    causal = torch.tril(torch.ones(n, n))
    window = torch.triu(torch.ones(n, n), diagonal=-(w - 1))
    return causal * window

# Rebuild masks with reference implementation
causal_mask = build_causal_mask(seq_len)
window_mask = build_sliding_window_mask(seq_len, window_size)

causal_count = int(causal_mask.sum().item())
window_count = int(window_mask.sum().item())

print(f'Full causal scores: {causal_count}')
print(f'Sliding window scores: {window_count}')
print(f'Reduction: {window_count / causal_count:.1%} of full causal')
print('Reference masks ready.')

In [None]:
# --- Visualize: full causal vs sliding window as heatmaps ---

fig, axes = plt.subplots(1, 2, figsize=(12, 5.5))

# Full causal
im0 = axes[0].imshow(causal_mask.numpy(), cmap='YlOrRd', interpolation='nearest',
                      origin='upper', vmin=0, vmax=1)
axes[0].set_title(f'Full Causal Attention\n{causal_count} scores—O(n²)',
                  fontsize=12, fontweight='bold')
axes[0].set_xlabel('Key position (j)', fontsize=10)
axes[0].set_ylabel('Query position (i)', fontsize=10)

# Sliding window
im1 = axes[1].imshow(window_mask.numpy(), cmap='YlGn', interpolation='nearest',
                      origin='upper', vmin=0, vmax=1)
axes[1].set_title(f'Sliding Window (w={window_size})\n{window_count} scores—O(n·w)',
                  fontsize=12, fontweight='bold')
axes[1].set_xlabel('Key position (j)', fontsize=10)
axes[1].set_ylabel('Query position (i)', fontsize=10)

for ax in axes:
    ax.set_xticks(range(0, seq_len, 8))
    ax.set_yticks(range(0, seq_len, 8))

fig.suptitle(f'Attention Masks: Sequence Length {seq_len}',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print(f'Full causal: every query attends to ALL preceding tokens.')
print(f'Sliding window: every query attends to at most the {window_size} nearest tokens.')
print(f'The window mask is a diagonal band within the causal triangle.')
print(f'\nAt 128K tokens with w=4096:')
print(f'  Full causal: ~{128_000 * 128_001 // 2:,} scores')
print(f'  Sliding window: ~{128_000 * 4096:,} scores')
print(f'  That is a {128_000 * 128_001 // 2 / (128_000 * 4096):.0f}x reduction.')

**What just happened:** The full causal mask is a solid lower triangle—every token attends to every previous token. The sliding window mask is a narrow diagonal band—each token attends to only its `w` nearest predecessors. The visual difference makes the compute savings obvious: the colored area (computed scores) is dramatically smaller.

At the scale that matters, 128K tokens with `w = 4096`, full causal computes ~8.2 billion scores while sliding window computes ~524 million: roughly a 16x reduction.
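The `~` figures above skip a small edge effect. A short sketch of the exact counts (helper names are our own): full causal attention fills the triangle `n(n+1)/2`, and the sliding-window band has `n*w` entries minus the missing corner at the start of the sequence, where the first `w-1` rows hold only `1, 2, ..., w-1` keys.

```python
def causal_scores(n: int) -> int:
    """Entries in a full causal mask: 1 + 2 + ... + n."""
    return n * (n + 1) // 2

def sliding_window_scores(n: int, w: int) -> int:
    """Entries in a causal sliding-window mask of width w."""
    if w >= n:
        return causal_scores(n)        # the window never binds: plain causal
    return n * w - w * (w - 1) // 2    # rows 0..w-2 hold only 1..w-1 keys

# Check against the toy masks (n=32, w=8) by counting row by row
brute = sum(min(i + 1, 8) for i in range(32))
print(causal_scores(32), sliding_window_scores(32, 8), brute)

n, w = 128_000, 4096
full, band = causal_scores(n), sliding_window_scores(n, w)
print(f'{full:,} vs {band:,} -> {full / band:.1f}x fewer scores')
```

At 128K the exact ratio comes out at about 15.9x, close to the 16x quoted above.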

"But what about long-range dependencies?" Remember from the lesson: stacked layers of local attention propagate information across the full context through the residual stream. After layer 1, information from position 1 can reach position ~4096; after layer 2, it can reach ~8192. Each layer extends the reach by about `w` positions, so by layer 13 (13 × 4096 ≈ 53K) information from position 1 can have reached position 50,000.
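The layer count is simple arithmetic. In a stack of sliding-window layers, each layer lets a token pull information from at most `w - 1` positions further back, so after `L` layers the reach is `L * (w - 1)` positions. A minimal sketch, where `layers_to_reach` is our own helper and the result is an idealized upper bound on reach (it says nothing about how well the information survives the trip):

```python
import math

def layers_to_reach(distance: int, w: int) -> int:
    """Minimum number of sliding-window layers for information to travel
    `distance` positions forward, if each layer moves it at most w - 1 positions."""
    if distance <= 0:
        return 0
    return math.ceil(distance / (w - 1))

# From position 1 to position 50,000 with w = 4096
print(layers_to_reach(50_000 - 1, 4096))
# Maximum reach after each of the first three layers
print([layer * (4096 - 1) for layer in range(1, 4)])
```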

---

## Exercise 3: GQA Forward Pass (Supported)

The third long-context barrier is KV cache memory. In standard multi-head attention (MHA), every head has its own K and V projections—and its own K/V cache during generation. GQA shares K/V across groups of Q heads: 8 Q heads might share 2 K/V heads, giving a 4x reduction in KV cache memory while preserving query diversity.

In this exercise, you'll:
1. Implement a GQA forward pass with `n_heads=8` and `n_kv_heads=2`
2. Run it on a sample input and verify the output shape
3. Compare KV cache sizes: MHA (8 KV heads) vs GQA (2 KV heads)

The projection matrices and the Q computation are provided. You'll implement the K/V head expansion (repeating shared K/V to match Q heads) and the attention computation.

**Before running, predict:**
- With `n_heads=8` and `n_kv_heads=2`, how many Q heads share each K/V head?
- What is the output shape of GQA compared to standard MHA? Same or different?
- If MHA uses 8 separate K and V caches, and GQA uses 2, what is the memory reduction?

In [None]:
# --- GQA Implementation ---
#
# In standard MHA: n_heads Q heads, n_heads K/V heads.
# In GQA: n_heads Q heads, n_kv_heads K/V heads (n_kv_heads < n_heads).
# Multiple Q heads share the same K/V head.
#
# The key operation: "expand" the K/V from n_kv_heads to n_heads by repeating
# each K/V head for (n_heads // n_kv_heads) Q heads.

# --- Configuration ---
d_model = 64
n_heads = 8
n_kv_heads = 2
d_head = d_model // n_heads  # 64 // 8 = 8
seq_len_ex3 = 16  # short sequence for this exercise
heads_per_kv_group = n_heads // n_kv_heads  # 8 // 2 = 4 Q heads per KV head

print(f'Model config:')
print(f'  d_model = {d_model}')
print(f'  n_heads (Q) = {n_heads}')
print(f'  n_kv_heads (K/V) = {n_kv_heads}')
print(f'  d_head = {d_head}')
print(f'  Q heads per KV group = {heads_per_kv_group}')
print()

# --- Projection matrices ---
# Q projections: one per Q head (8 heads)
W_Q = nn.Linear(d_model, n_heads * d_head, bias=False)
# K and V projections: one per KV head (only 2 heads!)
W_K = nn.Linear(d_model, n_kv_heads * d_head, bias=False)
W_V = nn.Linear(d_model, n_kv_heads * d_head, bias=False)
# Output projection
W_O = nn.Linear(n_heads * d_head, d_model, bias=False)

print(f'W_Q: {d_model} -> {n_heads * d_head} features (projects to {n_heads} Q heads)')
print(f'W_K: {d_model} -> {n_kv_heads * d_head} features (projects to {n_kv_heads} K heads)')
print(f'W_V: {d_model} -> {n_kv_heads * d_head} features (projects to {n_kv_heads} V heads)')
print(f'W_O: {n_heads * d_head} -> {d_model} features (combines all {n_heads} head outputs)')
print(f'(nn.Linear stores its weight as (out, in), e.g. W_K.weight.shape = {tuple(W_K.weight.shape)})')

In [None]:
# --- GQA forward pass ---

torch.manual_seed(42)
x = torch.randn(seq_len_ex3, d_model)  # (seq_len, d_model)

# Step 1: Project Q, K, V
Q = W_Q(x)  # (seq_len, n_heads * d_head)
K = W_K(x)  # (seq_len, n_kv_heads * d_head)
V = W_V(x)  # (seq_len, n_kv_heads * d_head)

print('After projection:')
print(f'  Q shape: {Q.shape} ({n_heads} Q heads x {d_head} dims)')
print(f'  K shape: {K.shape} ({n_kv_heads} K heads x {d_head} dims)')
print(f'  V shape: {V.shape} ({n_kv_heads} V heads x {d_head} dims)')
print()

# Step 2: Reshape to separate heads
Q = Q.view(seq_len_ex3, n_heads, d_head).transpose(0, 1)       # (n_heads, seq_len, d_head)
K = K.view(seq_len_ex3, n_kv_heads, d_head).transpose(0, 1)    # (n_kv_heads, seq_len, d_head)
V = V.view(seq_len_ex3, n_kv_heads, d_head).transpose(0, 1)    # (n_kv_heads, seq_len, d_head)

print('After reshape to separate heads:')
print(f'  Q: {Q.shape} ({n_heads} heads, each attending over {seq_len_ex3} tokens)')
print(f'  K: {K.shape} (only {n_kv_heads} KV heads!)')
print(f'  V: {V.shape} (only {n_kv_heads} KV heads!)')
print()

# Step 3: Expand K and V to match Q heads
#
# TODO: Repeat each KV head to serve `heads_per_kv_group` Q heads.
#
# K is shape (n_kv_heads, seq_len, d_head)—we need (n_heads, seq_len, d_head).
# Each of the 2 KV heads needs to be repeated 4 times (heads_per_kv_group = 4).
#
# Use torch.repeat_interleave(K, repeats=heads_per_kv_group, dim=0)
# This repeats along dim=0: [KV_0, KV_1] -> [KV_0, KV_0, KV_0, KV_0, KV_1, KV_1, KV_1, KV_1]
# Result: (n_heads, seq_len, d_head)
#
# Do the same for V.

K_expanded = None  # TODO: expand K from (n_kv_heads, ...) to (n_heads, ...)
V_expanded = None  # TODO: expand V from (n_kv_heads, ...) to (n_heads, ...)

print(f'After KV expansion:')
print(f'  K_expanded: {K_expanded.shape} (now matches Q\'s {n_heads} heads)')
print(f'  V_expanded: {V_expanded.shape} (now matches Q\'s {n_heads} heads)')
print()

# Step 4: Compute attention (standard scaled dot-product)
#
# TODO: Compute attention scores, apply causal mask, softmax, weighted V.
#   1. scores = Q @ K_expanded^T / sqrt(d_head)    shape: (n_heads, seq_len, seq_len)
#   2. Apply causal mask: set upper triangle to -inf
#   3. weights = softmax(scores, dim=-1)
#   4. output = weights @ V_expanded               shape: (n_heads, seq_len, d_head)

# Compute attention scores
scores = None  # TODO: Q @ K_expanded.transpose(-2, -1) / math.sqrt(d_head)

# Apply causal mask
causal = torch.tril(torch.ones(seq_len_ex3, seq_len_ex3))
scores = scores.masked_fill(causal == 0, float('-inf'))

# Softmax and weighted sum
weights = F.softmax(scores, dim=-1)
attn_output = None  # TODO: weights @ V_expanded

print(f'Attention output per head: {attn_output.shape}')

# Step 5: Concatenate heads and project
# Transpose back: (n_heads, seq_len, d_head) -> (seq_len, n_heads, d_head)
# Then reshape: (seq_len, n_heads * d_head)
concat = attn_output.transpose(0, 1).contiguous().view(seq_len_ex3, n_heads * d_head)
final_output = W_O(concat)  # (seq_len, d_model)

print(f'Final GQA output: {final_output.shape}')
print(f'Same shape as input? {final_output.shape == x.shape}')

<details>
<summary>Solution</summary>

The key insight is that GQA "expands" the K/V heads to match the Q heads by repeating. Each KV head is shared across `heads_per_kv_group` Q heads. The attention computation itself is identical to standard MHA—the only change is where the K and V come from.

```python
# Step 3: Expand K and V
K_expanded = torch.repeat_interleave(K, repeats=heads_per_kv_group, dim=0)
V_expanded = torch.repeat_interleave(V, repeats=heads_per_kv_group, dim=0)

# Step 4: Attention
scores = Q @ K_expanded.transpose(-2, -1) / math.sqrt(d_head)
# (apply the causal mask and softmax to `scores` as in the exercise cell, giving `weights`)
attn_output = weights @ V_expanded
```

The `repeat_interleave` is the GQA mechanism in one line: take 2 KV heads, repeat each 4 times to get 8, so each Q head can index its matching KV head. Optimized implementations often avoid materializing the repeated tensors (for example by grouping the Q heads and broadcasting against the shared K/V, or inside fused attention kernels that support GQA directly), but the logic is the same.

Note: the *computation* after expansion is identical to MHA. The savings come from the *cache*: during generation, you store only 2 K/V caches instead of 8.

</details>
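One way to skip the explicit copy is to reshape Q into `(n_kv_heads, heads_per_kv_group, seq_len, d_head)` and let matmul broadcasting reuse each K/V head across its group. This is a sketch of the idea, not how any particular library implements it; the function names are our own. It assumes the same head ordering as `repeat_interleave`: Q heads 0-3 share KV head 0 and Q heads 4-7 share KV head 1.

```python
import math
import torch
import torch.nn.functional as F

def gqa_attention_repeat(Q, K, V):
    """Reference: expand K/V with repeat_interleave, then standard causal attention."""
    n_heads, seq_len, d_head = Q.shape
    group = n_heads // K.shape[0]
    K_exp = torch.repeat_interleave(K, repeats=group, dim=0)
    V_exp = torch.repeat_interleave(V, repeats=group, dim=0)
    scores = Q @ K_exp.transpose(-2, -1) / math.sqrt(d_head)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~causal, float('-inf'))
    return F.softmax(scores, dim=-1) @ V_exp

def gqa_attention_broadcast(Q, K, V):
    """Same result without copying K/V: group the Q heads and broadcast."""
    n_heads, seq_len, d_head = Q.shape
    n_kv = K.shape[0]
    group = n_heads // n_kv
    Qg = Q.reshape(n_kv, group, seq_len, d_head)   # Q heads grouped by their KV head
    Kg = K.unsqueeze(1)                            # (n_kv, 1, seq_len, d_head), a view
    Vg = V.unsqueeze(1)
    # (n_kv, group, seq, d) @ (n_kv, 1, d, seq) broadcasts over the group dimension
    scores = Qg @ Kg.transpose(-2, -1) / math.sqrt(d_head)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~causal, float('-inf'))
    out = F.softmax(scores, dim=-1) @ Vg           # (n_kv, group, seq_len, d_head)
    return out.reshape(n_heads, seq_len, d_head)

torch.manual_seed(0)
Q_demo = torch.randn(8, 16, 8)   # 8 Q heads, 16 tokens, d_head 8
K_demo = torch.randn(2, 16, 8)   # 2 shared KV heads
V_demo = torch.randn(2, 16, 8)
print(torch.allclose(gqa_attention_repeat(Q_demo, K_demo, V_demo),
                     gqa_attention_broadcast(Q_demo, K_demo, V_demo), atol=1e-6))
```

`unsqueeze` returns a view, so the broadcast version never holds more than `n_kv_heads` copies of K and V, which mirrors what the KV cache stores.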

### Helper: Working GQA Implementation

**Run the cell below** to execute a complete GQA forward pass and see the KV cache comparison. This ensures the output and analysis are correct regardless of your implementation above.

In [None]:
# --- Complete GQA forward pass (reference) ---

torch.manual_seed(42)
x = torch.randn(seq_len_ex3, d_model)

# Project
Q = W_Q(x).view(seq_len_ex3, n_heads, d_head).transpose(0, 1)
K = W_K(x).view(seq_len_ex3, n_kv_heads, d_head).transpose(0, 1)
V = W_V(x).view(seq_len_ex3, n_kv_heads, d_head).transpose(0, 1)

# Expand K/V to match Q heads
K_expanded = torch.repeat_interleave(K, repeats=heads_per_kv_group, dim=0)
V_expanded = torch.repeat_interleave(V, repeats=heads_per_kv_group, dim=0)

# Attention
scores = Q @ K_expanded.transpose(-2, -1) / math.sqrt(d_head)
causal = torch.tril(torch.ones(seq_len_ex3, seq_len_ex3))
scores = scores.masked_fill(causal == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
attn_output = weights @ V_expanded

# Concatenate and project
concat = attn_output.transpose(0, 1).contiguous().view(seq_len_ex3, n_heads * d_head)
final_output = W_O(concat)

print(f'GQA output shape: {final_output.shape}')
print(f'Input shape: {x.shape}')
print(f'Same shape (drop-in replacement)? {final_output.shape == x.shape}')
print()

# --- KV Cache Comparison ---
print('=== KV Cache Size Comparison ===')
print()

# Per-head KV cache: stores K and V for all sequence positions
# Shape per head: (seq_len, d_head) for K + (seq_len, d_head) for V
# In bytes (float32): 2 * seq_len * d_head * 4 bytes

# For this toy model:
bytes_per_kv_head = 2 * seq_len_ex3 * d_head * 4  # K + V, float32

mha_kv_cache = n_heads * bytes_per_kv_head
gqa_kv_cache = n_kv_heads * bytes_per_kv_head

print(f'Toy model (seq_len={seq_len_ex3}, d_head={d_head}, float32):')
print(f'  MHA ({n_heads} KV heads): {mha_kv_cache:,} bytes')
print(f'  GQA ({n_kv_heads} KV heads): {gqa_kv_cache:,} bytes')
print(f'  Reduction: {n_heads // n_kv_heads}x')
print()

# Scale to LLaMA 2 70B numbers:
print('--- Scaled to LLaMA 2 70B at 128K context ---')
llama_n_heads = 64
llama_n_kv_heads = 8
llama_d_head = 128
llama_seq_len = 128_000
llama_n_layers = 80
bytes_per_element = 2  # bf16

# Per-head per-layer KV cache (K + V)
llama_bytes_per_kv_head = 2 * llama_seq_len * llama_d_head * bytes_per_element

llama_mha_cache = llama_n_heads * llama_n_layers * llama_bytes_per_kv_head
llama_gqa_cache = llama_n_kv_heads * llama_n_layers * llama_bytes_per_kv_head

print(f'  MHA ({llama_n_heads} KV heads): {llama_mha_cache / 1e9:.1f} GB')
print(f'  GQA ({llama_n_kv_heads} KV heads): {llama_gqa_cache / 1e9:.1f} GB')
print(f'  Reduction: {llama_n_heads // llama_n_kv_heads}x')
print(f'  Savings: {(llama_mha_cache - llama_gqa_cache) / 1e9:.1f} GB')
print()
print(f'  LLaMA 2 70B model weights (bf16): ~140 GB')
print(f'  MHA KV cache at 128K, one sequence: {llama_mha_cache / 1e9:.1f} GB, larger than the model!')
print(f'  GQA KV cache at 128K, one sequence: {llama_gqa_cache / 1e9:.1f} GB, manageable.')

In [None]:
# --- Visualize: MHA vs GQA architecture ---

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

q_color = '#a78bfa'   # violet
kv_color = '#f59e0b'  # amber

configs = [
    ('MHA (Multi-Head)', 8, 8),
    ('GQA (Grouped Query)', 8, 2),
    ('MQA (Multi-Query)', 8, 1),
]

for ax, (title, n_q, n_kv) in zip(axes, configs):
    # Draw Q heads
    for i in range(n_q):
        y = 0.9 - i * 0.1
        ax.add_patch(plt.Rectangle((0.05, y - 0.03), 0.2, 0.06,
                     facecolor=q_color, alpha=0.6, edgecolor='white', linewidth=0.8))
        ax.text(0.15, y, f'Q{i}', ha='center', va='center', fontsize=7,
                color='white', fontweight='bold')
    
    # Draw KV heads, each box centered on the group of Q heads it serves
    q_per_kv = n_q // n_kv
    for i in range(n_kv):
        y = 0.9 - (i * q_per_kv + (q_per_kv - 1) / 2) * 0.1
        height = q_per_kv * 0.1 - 0.02
        ax.add_patch(plt.Rectangle((0.65, y - height/2), 0.25, height,
                     facecolor=kv_color, alpha=0.5, edgecolor='white', linewidth=0.8))
        ax.text(0.775, y, f'KV{i}', ha='center', va='center', fontsize=7,
                color='white', fontweight='bold')
        
        # Draw connection lines from Q heads to this KV head
        q_per_kv = n_q // n_kv
        for j in range(q_per_kv):
            q_idx = i * q_per_kv + j
            q_y = 0.9 - q_idx * 0.1
            ax.plot([0.25, 0.65], [q_y, y], color='#94a3b8', linewidth=0.8, alpha=0.5)
    
    ax.set_title(f'{title}\n{n_q} Q heads, {n_kv} KV heads\nKV cache: {n_kv}/{n_q} = {n_kv/n_q:.0%}',
                fontsize=10, fontweight='bold')
    ax.set_xlim(-0.05, 1.0)
    ax.set_ylim(0.0, 1.05)
    ax.axis('off')

fig.suptitle('The MHA → GQA → MQA Spectrum: Same Q Diversity, Less KV Memory',
             fontsize=13, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print('All three architectures keep the same 8 Q heads with independent W_Q projections.')
print('The only change: how many independent K/V heads exist.')
print('GQA is the sweet spot: significant KV cache savings with minimal quality loss.')

---

## Exercise 4: Attention Cost Calculator (Independent)

The lesson presented three barriers to long context: position encoding limits, quadratic compute, and KV cache memory. Each has a targeted solution: RoPE, sparse attention, and GQA. In this exercise, you'll build a calculator that quantifies all three costs and shows how the solutions compound.

**Your task:** Build a function `attention_cost(config)` that computes:
1. **Attention FLOPs per layer**—the compute cost of the QK^T matrix multiplication and the attention-weighted V sum
2. **KV cache memory**—the total bytes needed to store K and V tensors across all layers during generation

Then compute costs for five configurations:
- **(a) GPT-2 (small)** at 1K context (1,024 tokens): 12 heads, d_model=768, d_head=64, 12 layers, float32
- **(b) LLaMA 2 70B with MHA** at 4K context (4,096 tokens): 64 Q heads, 64 KV heads, d_model=8192, d_head=128, 80 layers, bf16 (a hypothetical baseline; the released LLaMA 2 70B uses GQA)
- **(c) LLaMA 2 70B with GQA** at 4K context: 64 Q heads, 8 KV heads (otherwise identical to (b))
- **(d) LLaMA 2 70B with GQA** at 128K context (128,000 tokens, well past LLaMA 2's 4K training length)
- **(e) LLaMA 2 70B with GQA + sliding window (w=4096)** at 128K context

Present your results in a table. The insight to look for: RoPE handles position encoding (it adds no cost in this calculator, but it is what makes configs (d) and (e) viable), GQA reduces the KV cache, and sliding window reduces FLOPs. Their benefits compound.

**Formulas:**
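- Attention FLOPs per layer ≈ `2 * n_q_heads * seq_len * seq_len * d_head` (for QK^T) + `2 * n_q_heads * seq_len * seq_len * d_head` (for attn @ V). With sliding window, replace the second `seq_len` in both terms with `min(seq_len, window_size)`. The formula counts every entry of the `seq_len × seq_len` score matrix and ignores the causal mask; a kernel that skips masked entries does roughly half that work for full attention.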
- KV cache per layer = `2 * n_kv_heads * seq_len * d_head * bytes_per_element` (2 for K and V)
- Total KV cache = KV cache per layer × n_layers
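
To see where the leading factor of 2 comes from, here is a minimal pure-Python sketch that runs a naive QK^T loop for toy sizes and charges one multiply plus one add per term. The helper name `naive_qk_flops` and the toy sizes are illustrative assumptions, not part of the exercise. The attn @ V product has the same `seq_len × seq_len × d_head` loop structure, so it contributes an equal count.

```python
import random


def naive_qk_flops(q, k):
    """Naive QK^T for one head: q is (n_queries, d_head), k is (n_keys, d_head).

    Returns the score matrix and a FLOP count that charges 2 FLOPs
    (one multiply, one add) for every term of every dot product.
    """
    flops = 0
    scores = []
    for q_row in q:
        row = []
        for k_row in k:
            acc = 0.0
            for a, b in zip(q_row, k_row):
                acc += a * b  # 1 multiply + 1 add
                flops += 2
            row.append(acc)
        scores.append(row)
    return scores, flops


random.seed(0)
n_q_heads, seq_len, d_head = 2, 6, 4  # toy sizes (hypothetical)

total = 0
for _ in range(n_q_heads):
    q = [[random.gauss(0, 1) for _ in range(d_head)] for _ in range(seq_len)]
    k = [[random.gauss(0, 1) for _ in range(d_head)] for _ in range(seq_len)]
    _, flops = naive_qk_flops(q, k)
    total += flops

formula = 2 * n_q_heads * seq_len * seq_len * d_head
print(f'counted: {total}, formula: {formula}')  # both 576
```

The brute-force count and the closed form agree, which is all the formula claims: one multiply-add per (query, key, dimension) triple, per Q head.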

**Before running, predict:**
- How many times more expensive is 128K vs 4K in FLOPs? (Hint: quadratic scaling)
- How much does GQA reduce KV cache for LLaMA 2 70B (64 KV heads → 8)?
- Does sliding window affect KV cache size, or only FLOPs?

**No skeleton is provided.** Design the function and table yourself.

In [None]:
# Your attention cost calculator here.
#
# Suggested structure:
#
# 1. Define a config dataclass or dict with fields:
#    name, n_q_heads, n_kv_heads, d_head, n_layers,
#    seq_len, bytes_per_element, window_size (None = full attention)
#
# 2. Write attention_cost(config) that returns:
#    - flops_per_layer: attention FLOPs for one layer
#    - total_flops: flops_per_layer * n_layers
#    - kv_cache_per_layer: bytes for K + V in one layer
#    - total_kv_cache: kv_cache_per_layer * n_layers
#
# 3. Define the five configurations (a)-(e)
#
# 4. Compute costs for each and print a formatted table
#
# 5. Add a comparison section showing:
#    - FLOPs ratio: (b) vs (c) (GQA doesn't change FLOPs)
#    - KV cache ratio: (b) vs (c) (GQA reduces cache)
#    - FLOPs ratio: (d) vs (e) (sliding window reduces FLOPs)
#    - KV cache ratio: (d) vs (e) (sliding window doesn't change cache)



<details>
<summary>Solution</summary>

The key insight is that GQA and sliding window address *different* costs. GQA reduces KV cache memory (fewer KV heads to store) but does not reduce attention FLOPs (the Q heads still compute the same number of dot products). Sliding window reduces FLOPs (each query attends to at most `w` keys instead of all previous keys) but, in this calculator, does not reduce KV cache size (the model assumes the full K/V sequence stays cached in every layer). They are complementary.

```python
from dataclasses import dataclass


@dataclass
class AttentionConfig:
    name: str
    n_q_heads: int
    n_kv_heads: int
    d_head: int
    n_layers: int
    seq_len: int
    bytes_per_element: int  # 4 for float32, 2 for bf16
    window_size: int = 0    # 0 means full attention (no sliding window)


def attention_cost(cfg: AttentionConfig) -> dict:
    """Compute attention FLOPs and KV cache memory for a given config."""

    # Effective attention span per query token
    effective_span = cfg.seq_len
    if cfg.window_size > 0:
        effective_span = min(cfg.seq_len, cfg.window_size)

    # Attention FLOPs per layer:
    # QK^T: each Q head computes (seq_len, d_head) @ (d_head, effective_span)
    #   = 2 * seq_len * effective_span * d_head per head (multiply-add = 2 ops)
    # attn @ V: each Q head computes (seq_len, effective_span) @ (effective_span, d_head)
    #   = 2 * seq_len * effective_span * d_head per head
    # Total per layer = n_q_heads * (QK^T FLOPs + attn@V FLOPs)
    flops_qk = 2 * cfg.n_q_heads * cfg.seq_len * effective_span * cfg.d_head
    flops_av = 2 * cfg.n_q_heads * cfg.seq_len * effective_span * cfg.d_head
    flops_per_layer = flops_qk + flops_av
    total_flops = flops_per_layer * cfg.n_layers

    # KV cache: store K and V for all sequence positions, per KV head, per layer
    # K: (seq_len, d_head) and V: (seq_len, d_head) per KV head
    kv_cache_per_layer = 2 * cfg.n_kv_heads * cfg.seq_len * cfg.d_head * cfg.bytes_per_element
    total_kv_cache = kv_cache_per_layer * cfg.n_layers

    return {
        'flops_per_layer': flops_per_layer,
        'total_flops': total_flops,
        'kv_cache_per_layer': kv_cache_per_layer,
        'total_kv_cache': total_kv_cache,
    }


def format_flops(flops: float) -> str:
    """Human-readable FLOPs string."""
    if flops >= 1e15:
        return f'{flops / 1e15:.1f} PFLOPs'
    if flops >= 1e12:
        return f'{flops / 1e12:.1f} TFLOPs'
    if flops >= 1e9:
        return f'{flops / 1e9:.1f} GFLOPs'
    return f'{flops / 1e6:.1f} MFLOPs'


def format_bytes(b: float) -> str:
    """Human-readable bytes string."""
    if b >= 1e9:
        return f'{b / 1e9:.1f} GB'
    if b >= 1e6:
        return f'{b / 1e6:.1f} MB'
    return f'{b / 1e3:.1f} KB'


# --- Define the five configurations ---

configs = [
    AttentionConfig(
        name='(a) GPT-2, 1K',
        n_q_heads=12, n_kv_heads=12, d_head=64,
        n_layers=12, seq_len=1024,
        bytes_per_element=4,  # float32
    ),
    AttentionConfig(
        name='(b) LLaMA 70B MHA, 4K',
        n_q_heads=64, n_kv_heads=64, d_head=128,
        n_layers=80, seq_len=4096,
        bytes_per_element=2,  # bf16
    ),
    AttentionConfig(
        name='(c) LLaMA 70B GQA, 4K',
        n_q_heads=64, n_kv_heads=8, d_head=128,
        n_layers=80, seq_len=4096,
        bytes_per_element=2,
    ),
    AttentionConfig(
        name='(d) LLaMA 70B GQA, 128K',
        n_q_heads=64, n_kv_heads=8, d_head=128,
        n_layers=80, seq_len=128_000,
        bytes_per_element=2,
    ),
    AttentionConfig(
        name='(e) LLaMA 70B GQA+SW, 128K',
        n_q_heads=64, n_kv_heads=8, d_head=128,
        n_layers=80, seq_len=128_000,
        bytes_per_element=2,
        window_size=4096,
    ),
]

# --- Compute and display ---

results = [(cfg, attention_cost(cfg)) for cfg in configs]

print(f'{"Config":<32} {"FLOPs/Layer":>14} {"Total FLOPs":>14} {"KV Cache":>12}')
print('-' * 76)
for cfg, cost in results:
    print(f'{cfg.name:<32} '
          f'{format_flops(cost["flops_per_layer"]):>14} '
          f'{format_flops(cost["total_flops"]):>14} '
          f'{format_bytes(cost["total_kv_cache"]):>12}')

# --- Comparison analysis ---
print('\n=== Key Comparisons ===\n')

_, cost_b = results[1]  # MHA at 4K
_, cost_c = results[2]  # GQA at 4K
_, cost_d = results[3]  # GQA at 128K
_, cost_e = results[4]  # GQA + sliding window at 128K

print('MHA vs GQA (same model, same context length):')
print(f'  FLOPs: {cost_b["total_flops"] / cost_c["total_flops"]:.1f}x '
      f'(GQA does NOT reduce FLOPs—same Q heads, same dot products)')
print(f'  KV cache: {cost_b["total_kv_cache"] / cost_c["total_kv_cache"]:.0f}x reduction '
      f'(GQA reduces KV heads from 64 to 8)')

print(f'\n4K vs 128K context (same architecture):')
print(f'  FLOPs: {cost_d["total_flops"] / cost_c["total_flops"]:.0f}x increase '
      f'(quadratic: (128K/4K)² = {(128_000/4096)**2:.0f}x)')
print(f'  KV cache: {cost_d["total_kv_cache"] / cost_c["total_kv_cache"]:.1f}x increase '
      f'(linear: 128K/4K = {128_000/4096:.1f}x)')

print(f'\nFull attention vs sliding window at 128K:')
print(f'  FLOPs: {cost_d["total_flops"] / cost_e["total_flops"]:.1f}x reduction '
      f'(window limits attention span)')
print(f'  KV cache: {cost_d["total_kv_cache"] / cost_e["total_kv_cache"]:.1f}x '
      f'(sliding window does NOT reduce KV cache—full sequence still stored)')

print(f'\nEnd-to-end: MHA at 4K (b) vs GQA+SW at 128K (e):')
print(f'  FLOPs: {cost_e["total_flops"] / cost_b["total_flops"]:.1f}x')
print(f'  KV cache: {cost_e["total_kv_cache"] / cost_b["total_kv_cache"]:.1f}x')
print(f'  Context: {128_000 / 4096:.1f}x longer, and the costs are manageable because')
print('  GQA compressed the cache and sliding window compressed the compute.')

print('\n=== The Three Barriers, Quantified ===')
print('Position: RoPE makes configs (d) and (e) possible at all (a learned position table has no entries past its trained length)')
print(f'Compute: sliding window cuts 128K FLOPs by {cost_d["total_flops"] / cost_e["total_flops"]:.0f}x')
print(f'Memory: GQA cuts KV cache by {cost_b["total_kv_cache"] / cost_c["total_kv_cache"]:.0f}x')
print('Three barriers, three solutions, compounding benefits.')
```

**Design choices explained:**
- The function separates FLOPs and KV cache because they respond to *different* optimizations. GQA reduces cache but not FLOPs. Sliding window reduces FLOPs but not cache. This separation makes the "three independent barriers" framework concrete.
- FLOPs formula counts both QK^T and attn@V multiplications. Each is a matrix multiply: 2 ops per multiply-add.
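- KV cache counts K and V separately (factor of 2), for all KV heads and all layers. The seq_len dimension stays at the full context length even with sliding window because this calculator models a cache that retains every position (as required when some layers still attend to the full context). A model that applies the same window in every layer can cap its cache with a rolling buffer of the last w positions, as Mistral 7B does, so treat the KV cache column as the full-retention case.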
- The comparison section explicitly calls out what each optimization does and does not affect. This is the lesson's core message: three independent bottlenecks, three independent solutions.
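
The calculator charges every query for the full `seq_len` (or `window_size`) keys. A minimal sketch, assuming a causal mask in which query `i` sees keys `0..i` (and, with a window, only the last `w` of those), counts the scores actually computed and checks the closed forms against brute force on small sizes. The function names are hypothetical helpers, and `128_000` follows the solution's convention for 128K.

```python
def causal_scores(n: int) -> int:
    """Scores under a causal mask: query i attends to keys 0..i."""
    return n * (n + 1) // 2


def sliding_window_scores(n: int, w: int) -> int:
    """Scores under a causal sliding window: query i attends to min(i + 1, w) keys."""
    if n <= w:
        return causal_scores(n)
    # The first w queries see 1, 2, ..., w keys; the remaining n - w see exactly w.
    return causal_scores(w) + (n - w) * w


def brute_force_scores(n: int, w=None) -> int:
    """Count allowed (query, key) pairs one at a time."""
    return sum(
        1
        for i in range(n)
        for j in range(i + 1)
        if w is None or i - j < w
    )


# Sanity-check the closed forms on small sizes.
for n_small in range(1, 40):
    assert causal_scores(n_small) == brute_force_scores(n_small)
    for w_small in (1, 3, 8):
        assert sliding_window_scores(n_small, w_small) == brute_force_scores(n_small, w_small)

n, w = 128_000, 4096
calculator_ratio = (n * n) / (n * min(n, w))                   # what attention_cost reports
causal_ratio = causal_scores(n) / sliding_window_scores(n, w)  # causal triangle vs band
print(f'calculator (full n x n vs n x w): {calculator_ratio:.2f}x')
print(f'causal triangle vs sliding band:  {causal_ratio:.1f}x')
```

The causal count lands near 16x, while the calculator's simpler formula reports about 31x because it charges full attention for the whole square. Both are valid; they answer "fewer than what?" differently.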

</details>

---

## Key Takeaways

1. **RoPE encodes position in the dot product via rotation, making attention scores depend on relative distance.** For the same query and key content vectors, the dot product after rotation is identical at positions (3, 7) and (1003, 1007), since both pairs have relative distance 4. This relative-position property is what context extension builds on: attention patterns learned at the training length can carry over to longer sequences.

2. **Sparse attention (sliding window) restricts which token pairs compute scores, reducing O(n^2) to O(n*w).** The sliding window mask is a narrow diagonal band within the causal triangle. At 128K tokens with w=4096, the band holds roughly 16x fewer scores than the causal triangle (about 31x fewer than the full n × n count that the Exercise 4 calculator uses). Information still flows across the full context through stacked layers and the residual stream.

3. **GQA shares K/V heads across groups of Q heads, preserving query diversity while cutting KV cache memory.** With 8 Q heads and 2 KV heads, the cache is 4x smaller. At LLaMA 2 70B scale (64 Q, 8 KV), GQA reduces the 128K-context KV cache from ~335 GB to ~42 GB: from more than twice the ~140 GB of bf16 weights to under a third of them.

4. **These three innovations address three independent bottlenecks.** RoPE fixes position generalization. Sparse attention fixes quadratic compute. GQA fixes KV cache memory. They combine rather than compete: Mistral 7B uses all three together, and LLaMA 2 70B pairs RoPE with GQA.

5. **Position in the handshake, not the nametag. Compute where attention concentrates, not everywhere. Cache what's needed, not everything.** Three barriers, three targeted solutions.
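
The memory figures in takeaway 3 follow directly from the KV cache formula in Exercise 4. A short sketch, assuming the same LLaMA 2 70B-style shapes (80 layers, `d_head=128`, bf16) and 128K = 128,000 tokens; `kv_cache_bytes` is a hypothetical helper name, and the weight size is a rough 70B × 2-byte estimate.

```python
def kv_cache_bytes(n_kv_heads: int, seq_len: int, d_head: int,
                   n_layers: int, bytes_per_element: int = 2) -> int:
    """K and V (factor of 2) for every KV head, cached position, and layer."""
    return 2 * n_kv_heads * seq_len * d_head * bytes_per_element * n_layers


shapes = dict(seq_len=128_000, d_head=128, n_layers=80)  # bf16 by default
mha = kv_cache_bytes(n_kv_heads=64, **shapes)
gqa = kv_cache_bytes(n_kv_heads=8, **shapes)
weights = 70e9 * 2  # ~70B parameters at 2 bytes each (bf16)

print(f'MHA KV cache: {mha / 1e9:.1f} GB')      # 335.5 GB
print(f'GQA KV cache: {gqa / 1e9:.1f} GB')      # 41.9 GB
print(f'bf16 weights: {weights / 1e9:.0f} GB')  # 140 GB
print(f'GQA reduction: {mha // gqa}x')          # 8x
```

The reduction factor is simply `n_q_heads / n_kv_heads`, so the same helper gives the 4x figure for the toy 8 Q / 2 KV configuration.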