# Week 3, Day 4: Full Attention — The Complete Formula

**Time:** ~1 hour

**Goal:** Implement the complete attention mechanism and understand its quadratic memory problem.

## The Challenge

We have all the pieces:
- **Dot products** (Day 1): Compute similarity scores
- **Softmax** (Days 2-3): Convert scores to probabilities

Today we combine them into **scaled dot-product attention**:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time

np.set_printoptions(precision=4, suppress=True)
torch.set_printoptions(precision=4, sci_mode=False)

---
## Step 1: The Challenge — What Attention Does (5 min)

Attention answers: **For each position, what information from other positions should I use?**

The formula:

1. **QK^T**: Compute all pairwise similarities (which keys match which queries)
2. **÷√d_k**: Scale the scores so their variance does not grow with d_k, which keeps softmax from saturating
3. **softmax**: Convert to probabilities (weights that sum to 1)
4. **× V**: Weighted average of values
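
Step 2 is easiest to appreciate numerically. The sketch below assumes that query and key entries are independent draws from a standard normal distribution (an assumption made for illustration; trained projections need not behave this way). Under that assumption the raw dot products should have variance close to `d_k`, while the scaled ones should stay close to 1. The sample count `n_samples` is an arbitrary choice.

```python
import torch

torch.manual_seed(0)
n_samples = 20_000  # number of (query, key) pairs; arbitrary illustrative value

variances = {}
for d_k in [8, 64, 512]:
    # Assumption: entries are i.i.d. standard normal
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    raw = (q * k).sum(dim=-1)   # one dot product per (q, k) pair
    scaled = raw / d_k ** 0.5   # same scaling as in the attention formula
    variances[d_k] = (raw.var().item(), scaled.var().item())
    print(f"d_k={d_k:4d}: var(raw)={variances[d_k][0]:8.1f}, "
          f"var(scaled)={variances[d_k][1]:.2f}")
```

Without the scaling, larger `d_k` would push scores further apart and softmax would concentrate almost all weight on a single key.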

In [None]:
# Visual example with a tiny sequence
seq_len = 4
d_model = 8

# Simple example: each position has distinct Q, K, V
torch.manual_seed(42)
Q = torch.randn(seq_len, d_model)
K = torch.randn(seq_len, d_model)
V = torch.randn(seq_len, d_model)

print(f"Q shape: {Q.shape} — Queries (what each position is looking for)")
print(f"K shape: {K.shape} — Keys (what each position offers to be found)")
print(f"V shape: {V.shape} — Values (what each position contributes)")

---
## Step 2: Explore — Step-by-Step Computation (15 min)

Let's trace through the full attention computation.

In [None]:
def attention_step_by_step(Q, K, V, verbose=True):
    """
    Compute attention with detailed steps.
    """
    seq_len, d_k = Q.shape
    
    # Step 1: QK^T
    scores = Q @ K.T
    if verbose:
        print(f"Step 1: QK^T")
        print(f"  Shape: {Q.shape} @ {K.T.shape} = {scores.shape}")
        print(f"  Result (raw scores):")
        print(scores.numpy())
        print()
    
    # Step 2: Scale by sqrt(d_k)
    scale = np.sqrt(d_k)
    scores_scaled = scores / scale
    if verbose:
        print(f"Step 2: Divide by √d_k = √{d_k} = {scale:.2f}")
        print(f"  Result (scaled scores):")
        print(scores_scaled.numpy())
        print()
    
    # Step 3: Softmax (row-wise)
    attention_weights = F.softmax(scores_scaled, dim=-1)
    if verbose:
        print(f"Step 3: Softmax (each row sums to 1)")
        print(f"  Result (attention weights):")
        print(attention_weights.numpy())
        print(f"  Row sums: {attention_weights.sum(dim=-1).numpy()}")
        print()
    
    # Step 4: Weighted sum of V
    output = attention_weights @ V
    if verbose:
        print(f"Step 4: Attention @ V")
        print(f"  Shape: {attention_weights.shape} @ {V.shape} = {output.shape}")
        print(f"  Result (output):")
        print(output.numpy())
    
    return output, attention_weights

output, weights = attention_step_by_step(Q, K, V)
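
To make Step 4 concrete, the sketch below recomputes attention for a small random example (fresh tensors, not the ones printed above) and rebuilds one output row by hand as a weighted sum of value rows. The row index `i = 2` is arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k = 4, 8
Q, K, V = (torch.randn(seq_len, d_k) for _ in range(3))

weights = F.softmax(Q @ K.T / d_k ** 0.5, dim=-1)
output = weights @ V

# Rebuild row i by hand: sum over j of weights[i, j] * V[j]
i = 2
manual_row = sum(weights[i, j] * V[j] for j in range(seq_len))
print("manual row matches:", torch.allclose(manual_row, output[i], atol=1e-5))

# Because the weights are non-negative and sum to 1, every output coordinate
# lies between the smallest and largest value of that coordinate in V.
within_bounds = bool(
    ((output <= V.max(dim=0).values + 1e-5) & (output >= V.min(dim=0).values - 1e-5)).all()
)
print("output stays within the range of V:", within_bounds)
```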

In [None]:
# Visualize attention pattern
def visualize_attention(weights, title="Attention Weights"):
    plt.figure(figsize=(6, 5))
    plt.imshow(weights.numpy(), cmap='Blues', aspect='auto')
    plt.colorbar(label='Attention Weight')
    plt.xlabel('Key Position (attending to)')
    plt.ylabel('Query Position (from)')
    plt.title(title)
    
    # Add text annotations
    for i in range(weights.shape[0]):
        for j in range(weights.shape[1]):
            plt.text(j, i, f'{weights[i,j]:.2f}', ha='center', va='center', fontsize=10)
    
    plt.tight_layout()
    plt.show()

visualize_attention(weights, "Full Attention (All positions attend to all)")

### Causal Attention

In decoder-only models (like GPT), each position can only attend to **itself and previous positions**. This is called **causal** or **autoregressive** attention.

In [None]:
def causal_attention(Q, K, V):
    """
    Attention with causal mask: position i can only attend to positions <= i.
    """
    seq_len, d_k = Q.shape
    
    # Compute scaled scores
    scores = Q @ K.T / np.sqrt(d_k)
    
    # Create causal mask (upper triangular = True, meaning "mask this out")
    causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    
    # Apply mask: set future positions to -infinity
    scores_masked = scores.masked_fill(causal_mask, float('-inf'))
    
    # Softmax (entries set to -inf get exactly 0 probability)
    attention_weights = F.softmax(scores_masked, dim=-1)
    
    # Weighted sum
    output = attention_weights @ V
    
    return output, attention_weights, scores_masked

output_causal, weights_causal, scores_causal = causal_attention(Q, K, V)

print("Scores after causal mask:")
print(scores_causal.numpy())
print("\nAttention weights (causal):")
print(weights_causal.numpy())

In [None]:
visualize_attention(weights_causal, "Causal Attention (Lower triangular)")
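
One way to check that a mask really is causal is to change the "future" and confirm that earlier outputs do not move. The sketch below repeats the causal computation in a small helper (named `causal_attn` here only so the cell runs on its own), then overwrites the keys and values at the last two positions.

```python
import torch
import torch.nn.functional as F

def causal_attn(Q, K, V):
    # Same computation as causal_attention above, repeated to keep this cell standalone
    seq_len, d_k = Q.shape
    scores = Q @ K.T / d_k ** 0.5
    future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return F.softmax(scores.masked_fill(future, float("-inf")), dim=-1) @ V

torch.manual_seed(0)
seq_len, d_k = 6, 8
Q, K, V = (torch.randn(seq_len, d_k) for _ in range(3))
out_before = causal_attn(Q, K, V)

# Replace keys and values at positions 4 and 5 with unrelated numbers
K2, V2 = K.clone(), V.clone()
K2[4:] = torch.randn(2, d_k)
V2[4:] = torch.randn(2, d_k)
out_after = causal_attn(Q, K2, V2)

print("positions 0-3 unchanged:", torch.allclose(out_before[:4], out_after[:4], atol=1e-6))
print("positions 4-5 changed:  ", not torch.allclose(out_before[4:], out_after[4:]))
```

Outputs at positions 0-3 depend only on keys and values at positions 0-3, so they are unaffected by the edit.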

---
## Step 3: The Concept — The Quadratic Memory Problem (10 min)

### Memory Analysis

For a sequence of length $N$ with embedding dimension $d$:

| Tensor | Shape | Size (FP16) |
|--------|-------|-------------|
| Q, K, V | [N, d] | N × d × 2 bytes each |
| QK^T (scores) | [N, N] | N² × 2 bytes |
| softmax(scores) | [N, N] | N² × 2 bytes |
| Output | [N, d] | N × d × 2 bytes |

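**The problem:** The attention matrix has **N² entries**, so its memory grows quadratically with sequence length, while Q, K, V, and the output grow only linearly.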

In [None]:
def memory_analysis(seq_len, d_model, batch_size=1, n_heads=1, dtype_bytes=2):
    """
    Calculate memory usage for standard attention.
    """
    # Input/output tensors
    qkv_memory = 3 * batch_size * seq_len * d_model * dtype_bytes
    output_memory = batch_size * seq_len * d_model * dtype_bytes
    
    # Attention matrix (the quadratic part)
    attention_matrix_memory = batch_size * n_heads * seq_len * seq_len * dtype_bytes
    
    total = qkv_memory + output_memory + attention_matrix_memory
    
    return {
        'seq_len': seq_len,
        'qkv_memory_mb': qkv_memory / 1e6,
        'output_memory_mb': output_memory / 1e6,
        'attention_matrix_mb': attention_matrix_memory / 1e6,
        'total_mb': total / 1e6,
        'attention_pct': attention_matrix_memory / total * 100,
    }

# Memory usage for different sequence lengths
print("Memory usage (batch=1, d_model=4096, n_heads=32, FP16):")
print("-" * 70)
print(f"{'Seq Len':>10} {'QKV':>10} {'Attn Matrix':>15} {'Total':>10} {'Attn %':>10}")

for seq_len in [512, 2048, 8192, 32768, 131072]:
    stats = memory_analysis(seq_len, d_model=4096, batch_size=1, n_heads=32)
    print(f"{seq_len:>10} {stats['qkv_memory_mb']:>9.1f}MB {stats['attention_matrix_mb']:>14.1f}MB "
          f"{stats['total_mb']:>9.1f}MB {stats['attention_pct']:>9.1f}%")

In [None]:
# Visualize the quadratic scaling
seq_lengths = np.array([512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072])
attention_memory_gb = [memory_analysis(n, 4096, 1, 32)['attention_matrix_mb'] / 1000 
                       for n in seq_lengths]

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(seq_lengths, attention_memory_gb, 'b-o', linewidth=2, markersize=6)
plt.axhline(y=80, color='red', linestyle='--', label='A100 80GB VRAM')
plt.xlabel('Sequence Length')
plt.ylabel('Attention Matrix Memory (GB)')
plt.title('Linear Scale')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.loglog(seq_lengths, attention_memory_gb, 'b-o', linewidth=2, markersize=6)
plt.axhline(y=80, color='red', linestyle='--', label='A100 80GB VRAM')
plt.xlabel('Sequence Length')
plt.ylabel('Attention Matrix Memory (GB)')
plt.title('Log-Log Scale (slope = 2 → quadratic)')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"At 128K tokens: {attention_memory_gb[-1]:.1f} GB just for the attention matrices (32 heads, one layer)!")

### The Problem in Practice

For a 128K-token context window (131,072 tokens), such as the one advertised for GPT-4 Turbo:
- Attention matrix: 131,072 × 131,072 ≈ 17.2 billion elements
- In FP16: ≈ 34 GB per head per layer
- With 32 heads: **≈ 1.1 TB per layer!**

This is why we need **FlashAttention** — it computes attention without materializing the full N×N matrix.
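
To see that the full N×N matrix is not strictly required, here is a minimal sketch that processes the queries one block at a time. It is not FlashAttention, which also tiles over keys and relies on an online softmax. It only illustrates that the same output can be assembled from `[block_size, N]` slices of the score matrix. The function name `chunked_attention` and the `block_size` parameter are made up for this illustration.

```python
import torch
import torch.nn.functional as F

def chunked_attention(Q, K, V, block_size=128):
    """Hypothetical helper: attention computed one block of queries at a time.

    At any moment only a [block_size, N] slice of scores exists, never [N, N].
    """
    d_k = Q.shape[-1]
    out_blocks = []
    for start in range(0, Q.shape[0], block_size):
        q_blk = Q[start:start + block_size]      # [B, d_k]
        scores = q_blk @ K.T / d_k ** 0.5        # [B, N]
        out_blocks.append(F.softmax(scores, dim=-1) @ V)
    return torch.cat(out_blocks, dim=0)

torch.manual_seed(0)
N, d = 1000, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))

full = F.softmax(Q @ K.T / d ** 0.5, dim=-1) @ V
chunked = chunked_attention(Q, K, V, block_size=128)
print(f"max abs difference: {(full - chunked).abs().max().item():.2e}")
```

Softmax is applied per row, so splitting along the query dimension needs no extra bookkeeping. Splitting along the key dimension is harder, because each row's normalizer depends on all keys.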

---
## Step 4: Code It — Complete Attention Implementation (30 min)

### Standard Attention

In [None]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Standard scaled dot-product attention.
    
    Args:
        Q: [batch, seq_len, d_k] or [seq_len, d_k]
        K: [batch, seq_len, d_k] or [seq_len, d_k]
        V: [batch, seq_len, d_v] or [seq_len, d_v]
        mask: Optional boolean mask [seq_len, seq_len], True = mask out
    
    Returns:
        output: [batch, seq_len, d_v] (or [seq_len, d_v] for unbatched input)
        attention_weights: [batch, seq_len, seq_len] (or [seq_len, seq_len])
    """
    d_k = Q.shape[-1]
    
    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)
    
    # Apply mask if provided
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    
    # Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    
    # Weighted sum of values
    output = torch.matmul(attention_weights, V)
    
    return output, attention_weights

# Test
batch_size = 2
seq_len = 8
d_model = 64

Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")
print(f"Weights shape: {weights.shape}")
print(f"Weights sum per row: {weights.sum(dim=-1)[0]}")

### Multi-Head Attention

Real transformers use **multiple attention heads** — each head can learn different patterns.

In [None]:
class MultiHeadAttention(torch.nn.Module):
    """
    Multi-head attention as used in transformers.
    """
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # Linear projections
        self.W_q = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_k = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_v = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_o = torch.nn.Linear(d_model, d_model, bias=False)
    
    def forward(self, x, mask=None):
        """
        Args:
            x: [batch, seq_len, d_model]
            mask: Optional boolean [seq_len, seq_len] mask, True = mask out (e.g. a causal mask)
        """
        batch_size, seq_len, _ = x.shape
        
        # Project to Q, K, V
        Q = self.W_q(x)  # [batch, seq_len, d_model]
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Reshape for multi-head: [batch, n_heads, seq_len, d_k]
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # Compute attention for all heads in parallel
        # scores: [batch, n_heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        
        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))
        
        attention_weights = F.softmax(scores, dim=-1)
        
        # Apply attention to values
        context = torch.matmul(attention_weights, V)  # [batch, n_heads, seq_len, d_k]
        
        # Concatenate heads
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # Final projection
        output = self.W_o(context)
        
        return output, attention_weights

# Test multi-head attention
mha = MultiHeadAttention(d_model=64, n_heads=8)
x = torch.randn(2, 16, 64)  # [batch, seq_len, d_model]

# Create causal mask
causal_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool()

output, weights = mha(x, mask=causal_mask)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {weights.shape} (batch, heads, seq, seq)")

In [None]:
# Visualize attention patterns across heads
fig, axes = plt.subplots(2, 4, figsize=(14, 7))
for i, ax in enumerate(axes.flat):
    ax.imshow(weights[0, i].detach().numpy(), cmap='Blues', aspect='auto')
    ax.set_title(f'Head {i}')
    ax.set_xlabel('Key')
    ax.set_ylabel('Query')

plt.suptitle('Multi-Head Attention Patterns (with causal mask)')
plt.tight_layout()
plt.show()
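
The `view(...).transpose(1, 2)` reshaping in `forward` can look opaque. The sketch below uses small made-up sizes to check two things: head `h` receives the contiguous feature slice `[h*d_k : (h+1)*d_k]` of every position, and the concatenation step reverses the split exactly.

```python
import torch

batch_size, seq_len, n_heads, d_k = 2, 5, 4, 3
d_model = n_heads * d_k

# Distinct values everywhere, so any misplaced element would be detected
x = torch.arange(batch_size * seq_len * d_model, dtype=torch.float32)
x = x.view(batch_size, seq_len, d_model)

# Split: [batch, seq_len, d_model] -> [batch, n_heads, seq_len, d_k]
heads = x.view(batch_size, seq_len, n_heads, d_k).transpose(1, 2)
print("heads shape:", tuple(heads.shape))

# Head h holds features h*d_k ... (h+1)*d_k - 1 of every position
h = 2
slice_matches = torch.equal(heads[:, h], x[:, :, h * d_k:(h + 1) * d_k])
print("head 2 is a feature slice:", slice_matches)

# Merge: the inverse reshape applied before the W_o projection
merged = heads.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
print("merge recovers the input:", torch.equal(merged, x))
```

The `.contiguous()` call is needed because `transpose` returns a non-contiguous view, and `view` cannot merge the last two dimensions of such a tensor.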

### Exercise: Compare with PyTorch's Built-in

PyTorch provides `torch.nn.functional.scaled_dot_product_attention`. Let's compare.

In [None]:
# Compare our implementation with PyTorch's
Q = torch.randn(2, 8, 32, 64)  # [batch, heads, seq_len, d_k]
K = torch.randn(2, 8, 32, 64)
V = torch.randn(2, 8, 32, 64)

# Our implementation (adapted for batched multi-head input)
def our_attention(Q, K, V, is_causal=False):
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / np.sqrt(d_k)
    
    if is_causal:
        seq_len = Q.shape[-2]
        # Build the mask on the same device as Q so this also works for CUDA inputs
        mask = torch.triu(torch.ones(seq_len, seq_len, device=Q.device), diagonal=1).bool()
        scores = scores.masked_fill(mask, float('-inf'))
    
    weights = F.softmax(scores, dim=-1)
    return weights @ V

our_result = our_attention(Q, K, V, is_causal=True)
pytorch_result = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

max_diff = (our_result - pytorch_result).abs().max()
print(f"Max difference: {max_diff:.2e}")
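
One detail to watch when passing an explicit mask instead of `is_causal=True`: this notebook uses `True` to mean "mask out", whereas a boolean `attn_mask` given to `F.scaled_dot_product_attention` uses `True` to mean "take part in attention". The sketch below, which assumes PyTorch 2.0 or later, inverts the mask before handing it to the built-in.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 8, 32, 64) for _ in range(3))  # [batch, heads, seq, d_k]
seq_len, d_k = Q.shape[-2], Q.shape[-1]

# This notebook's convention: True = mask OUT (future positions)
mask_out = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

scores = (Q @ K.transpose(-2, -1)) / d_k ** 0.5
ours = F.softmax(scores.masked_fill(mask_out, float("-inf")), dim=-1) @ V

# Boolean attn_mask convention in PyTorch: True = KEEP, so invert the mask
builtin = F.scaled_dot_product_attention(Q, K, V, attn_mask=~mask_out)

print(f"max difference: {(ours - builtin).abs().max().item():.2e}")
```

Passing `mask_out` without the inversion would let each position attend only to the future, and the last row would have no allowed keys at all.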

### Benchmark: See the Memory Problem

In [None]:
if torch.cuda.is_available():
    print("Benchmarking standard attention (watch memory grow):")
    print("-" * 60)
    
    d_model = 64
    n_heads = 8
    
    for seq_len in [256, 512, 1024, 2048, 4096]:
        try:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()
            
            Q = torch.randn(1, n_heads, seq_len, d_model, device='cuda')
            K = torch.randn(1, n_heads, seq_len, d_model, device='cuda')
            V = torch.randn(1, n_heads, seq_len, d_model, device='cuda')
            
            # Force computation
            output = our_attention(Q, K, V, is_causal=True)
            torch.cuda.synchronize()
            
            peak_memory = torch.cuda.max_memory_allocated() / 1e6
            
            # Theoretical attention matrix size
            attn_matrix_mb = n_heads * seq_len * seq_len * 4 / 1e6  # FP32
            
            print(f"seq_len={seq_len:5d}: peak memory={peak_memory:8.1f}MB, "
                  f"attn matrix={attn_matrix_mb:8.1f}MB")
            
            del Q, K, V, output
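        except RuntimeError as e:
            # Report OOM only when that is the actual failure; re-raise anything else
            if "out of memory" not in str(e).lower():
                raise
            print(f"seq_len={seq_len:5d}: OUT OF MEMORY!")
            break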
else:
    print("GPU not available for memory benchmark.")

---
## Step 5: Verify — Quiz & Reflection (10 min)

### Quiz

In [None]:
def check_answer(question, your_answer, correct_answer):
    if your_answer == correct_answer:
        print(f"✓ Correct! {question}")
    else:
        print(f"✗ Incorrect. {question}")
        print(f"  Your answer: {your_answer}, Correct: {correct_answer}")

# Q1: What is the shape of the attention matrix for seq_len=1024?
# a) [1024]
# b) [1024, 64]
# c) [1024, 1024]
# d) [64, 1024]
q1_answer = 'c'
check_answer("Attention matrix shape", q1_answer, 'c')

In [None]:
# Q2: If we double the sequence length, memory for attention matrix increases by:
# a) 2x
# b) 4x
# c) 8x
# d) log(2)x
q2_answer = 'b'  # N² → (2N)² = 4N²
check_answer("Memory scaling when doubling seq_len", q2_answer, 'b')

In [None]:
# Q3: In causal attention, which positions can query position 5 attend to?
# a) Only position 5
# b) Positions 0-4
# c) Positions 0-5
# d) All positions
q3_answer = 'c'  # Can attend to self and previous
check_answer("Causal attention for position 5", q3_answer, 'c')

In [None]:
# Q4: Why do we mask with -inf instead of 0 in causal attention?
# a) -inf is faster to compute
# b) softmax(-inf) = 0, giving zero weight
# c) 0 would cause division by zero
# d) It's just a convention
q4_answer = 'b'
check_answer("Why mask with -inf", q4_answer, 'b')
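
The reasoning behind Q4 can be checked on a single row of made-up scores. Filling hidden positions with 0 still leaves them a weight proportional to exp(0) = 1, while filling with -inf removes them entirely.

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([2.0, 1.0, 3.0, 0.5])          # made-up scores for one query
future = torch.tensor([False, False, True, True])    # True = position to hide

w_inf = F.softmax(scores.masked_fill(future, float("-inf")), dim=-1)
w_zero = F.softmax(scores.masked_fill(future, 0.0), dim=-1)

print("masked with -inf:", w_inf)    # hidden positions receive exactly zero weight
print("masked with 0:   ", w_zero)   # hidden positions still receive some weight
```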

### Reflection Questions

1. **The quadratic bottleneck:** Which operation creates the N×N matrix? Can we avoid computing it all at once?

2. **Multi-head attention:** Why use multiple heads instead of one large attention head? (Hint: diversity of patterns)

3. **Memory vs compute:** Is attention memory-bound or compute-bound? How do you know?

---

## Summary

| Component | Formula | Purpose |
|-----------|---------|--------|
| QK^T | Query × Key^T | Compute pairwise similarities |
| ÷√d_k | scores / √d_k | Keep score variance from growing with d_k |
| softmax | exp(x) / Σexp(x) | Convert to probabilities |
| × V | weights × Values | Weighted combination |

**The problem:** Materializing the full N×N attention matrix requires O(N²) memory.

**Tomorrow:** Online softmax — computing softmax without storing all values at once.

---

**Interactive Reference:** [attention-math.html](../attention-math.html) Section 5 — Full Attention Visualization