# Implement Multi-Head Attention from Scratch

### Problem Statement

Multi-Head Attention (MHA) is the core mechanism of the Transformer architecture. It enables the model to **jointly attend** to information from different representation subspaces at different positions.

### Background: Why Multiple Heads?

Single-head attention computes one set of attention weights. But different parts of the input might benefit from different attention patterns:
- One head might learn syntax (subject-verb agreement)
- Another might learn semantics (word meaning relationships)
- Another might learn positional patterns

Multi-head attention runs multiple attention operations in parallel, each with its own learned projections.

### The Math

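$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$$

where $\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$ and $h$ is the number of heads. Each head applies scaled dot-product attention

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_{\text{head}}}}\right)V$$

with per-head projections $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_{\text{head}}}$ and output projection $W^O \in \mathbb{R}^{h d_{\text{head}} \times d_{\text{model}}}$.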
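
To connect the formula to the reshaping code used throughout this notebook, here is a sketch that evaluates it literally, one head at a time, and compares the result with the batched "project once, then split into heads" approach. It assumes self-attention (Q = K = V = x) and uses hypothetical random matrices applied as `x @ W`; `nn.Linear` stores the transpose, so a column block $W_i^Q$ here corresponds to a row block of `nn.Linear.weight`.

```python
import torch

torch.manual_seed(0)
batch, seq, d_model, h = 2, 5, 16, 4
d_head = d_model // h
x = torch.randn(batch, seq, d_model)

# Hypothetical full-width weights, applied as x @ W (so W_i^Q is a column block of W_Q)
W_Q, W_K, W_V, W_O = (torch.randn(d_model, d_model) / d_model ** 0.5 for _ in range(4))


def attention(q, k, v):
    """Scaled dot-product attention over the last two dimensions."""
    scores = q @ k.transpose(-2, -1) / d_head ** 0.5
    return torch.softmax(scores, dim=-1) @ v


# Literal formula: head_i = Attention(x W_i^Q, x W_i^K, x W_i^V), then Concat(...) W^O
heads = []
for i in range(h):
    cols = slice(i * d_head, (i + 1) * d_head)
    heads.append(attention(x @ W_Q[:, cols], x @ W_K[:, cols], x @ W_V[:, cols]))
literal = torch.cat(heads, dim=-1) @ W_O


# Batched: project once, split the last dim into heads, attend, merge
def split(t):
    return t.view(batch, seq, h, d_head).transpose(1, 2)


out = attention(split(x @ W_Q), split(x @ W_K), split(x @ W_V))
batched = out.transpose(1, 2).reshape(batch, seq, d_model) @ W_O

print(torch.allclose(literal, batched, atol=1e-5))  # True
```

The batched version is what the rest of the notebook implements: one large matmul per projection instead of `h` small ones, with the head split done by `view`/`transpose`.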

### Key Dimensions

- `d_model`: Total embedding dimension (e.g., 512)
- `num_heads`: Number of attention heads (e.g., 8)
- `d_head = d_model // num_heads`: Dimension per head (e.g., 64)

### Learning Path

1. **Part 1**: Mask creation (causal, padding, KV cache)
2. **Part 2**: Core multi-head mechanism (Q, K, V given) - focus on the split/concat logic
3. **Part 3**: Multi-Head Self-Attention - Q, K, V from projections of single input x
4. **Part 4**: MHA with KV Cache for efficient inference

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Part 1: Attention Mask Creation

The mask functions are the same as single-head attention. For multi-head, the mask broadcasts across the head dimension.

Mask convention: **True = masked (cannot attend), False = can attend**
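
PyTorch does not use one boolean polarity everywhere: `nn.MultiheadAttention` treats True in `attn_mask` and `key_padding_mask` as "not allowed to attend" (the convention used here), whereas `F.scaled_dot_product_attention` treats True as "takes part in attention". A minimal sketch, assuming PyTorch 2.0+ for `scaled_dot_product_attention`, showing that inverting the mask with `~` reconciles the two:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, d_head = 2, 4, 6, 8
q, k, v = (torch.randn(batch, heads, seq, d_head) for _ in range(3))

# This notebook's convention: True = masked (cannot attend)
masked = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)

# Manual attention: fill the True entries with -inf before the softmax
scores = q @ k.transpose(-2, -1) / d_head ** 0.5
manual = torch.softmax(scores.masked_fill(masked, float("-inf")), dim=-1) @ v

# SDPA convention: True = may attend, so pass the inverted mask
sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=~masked)

print(torch.allclose(manual, sdpa, atol=1e-5))  # True
```

The (seq, seq) mask broadcasts over the batch and head dimensions in both paths.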

In [None]:
def create_causal_mask(seq_len_q: int, seq_len_k: int = None, device=None) -> torch.Tensor:
    """
    Create a causal (lower-triangular) attention mask for autoregressive models.
    
    For multi-head attention, this mask broadcasts across the head dimension.
    Shape: (seq_len_q, seq_len_k) -> broadcasts to (batch, num_heads, seq_len_q, seq_len_k)
    """
    if seq_len_k is None:
        seq_len_k = seq_len_q
    return torch.triu(torch.ones(seq_len_q, seq_len_k, dtype=torch.bool, device=device), diagonal=1)

In [None]:
def create_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """
    Create a padding mask for variable-length sequences.
    Shape: (batch, max_len) -> needs reshaping for multi-head: (batch, 1, 1, max_len)
    """
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return positions >= lengths.unsqueeze(1)

In [None]:
def create_attention_mask(
    seq_len_q: int,
    seq_len_k: int = None,
    is_causal: bool = True,
    key_padding_lengths: torch.Tensor = None,
    device=None
) -> torch.Tensor:
    """
    Create a combined attention mask for multi-head attention.
    
    Returns:
        mask with shape suitable for broadcasting to (batch, num_heads, seq_q, seq_k)
    """
    if seq_len_k is None:
        seq_len_k = seq_len_q

    mask = torch.zeros(seq_len_q, seq_len_k, dtype=torch.bool, device=device)

    if is_causal:
        causal = create_causal_mask(seq_len_q, seq_len_k, device=device)
        mask = mask | causal

    if key_padding_lengths is not None:
        padding = create_padding_mask(key_padding_lengths, seq_len_k)  # (batch, seq_k)
        padding = padding.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_k) for multi-head
        mask = mask.unsqueeze(0).unsqueeze(0) | padding  # (batch, 1, seq_q, seq_k)

    return mask

In [None]:
def create_causal_mask_with_cache(
    seq_len_q: int,
    seq_len_k: int,
    cache_len: int,
    device=None
) -> torch.Tensor:
    """
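    Create a causal mask for attention with a KV cache.
    New queries can attend to every cached position, and causally to each other
    (new query i may see new keys 0..i). Assumes seq_len_k == cache_len + seq_len_q.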
    """
    mask = torch.zeros(seq_len_q, seq_len_k, dtype=torch.bool, device=device)

    if seq_len_q > 1:
        new_token_mask = create_causal_mask(seq_len_q, seq_len_q, device=device)
        mask[:, cache_len:] = new_token_mask

    return mask
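
The block-wise construction in `create_causal_mask_with_cache` has a one-line equivalent: query i sits at absolute position `cache_len + i`, so it must not see any key index greater than that, which is `torch.triu(..., diagonal=cache_len + 1)`. A sketch checking the two against each other, assuming (like the original) that `seq_len_k == cache_len + seq_len_q`; both helper names are hypothetical:

```python
import torch


def cache_mask_blockwise(seq_len_q, seq_len_k, cache_len):
    # Same logic as create_causal_mask_with_cache: causal block over the new keys only
    mask = torch.zeros(seq_len_q, seq_len_k, dtype=torch.bool)
    if seq_len_q > 1:
        mask[:, cache_len:] = torch.triu(torch.ones(seq_len_q, seq_len_q, dtype=torch.bool), diagonal=1)
    return mask


def cache_mask_closed_form(seq_len_q, seq_len_k, cache_len):
    # Query i is at absolute position cache_len + i; mask keys strictly after it
    return torch.triu(torch.ones(seq_len_q, seq_len_k, dtype=torch.bool), diagonal=cache_len + 1)


for cache_len in range(4):
    for seq_len_q in range(1, 4):
        seq_len_k = cache_len + seq_len_q
        assert torch.equal(
            cache_mask_blockwise(seq_len_q, seq_len_k, cache_len),
            cache_mask_closed_form(seq_len_q, seq_len_k, cache_len),
        )

# 3 cached tokens, 2 new tokens: only the first new token is blocked from the last key
print(cache_mask_closed_form(2, 5, 3).int())
```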

In [None]:
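# Test mask creation
mask = create_causal_mask(4)
print("Causal mask (4x4):")
print(mask)
print("\nThis broadcasts to (batch, num_heads, 4, 4) in multi-head attention")
assert mask.shape == (4, 4)
assert not mask.diagonal().any()  # every position may attend to itself
assert torch.equal(mask, ~torch.ones(4, 4, dtype=torch.bool).tril())  # True strictly above the diagonal

# Test padding mask for multi-head (needs extra dimensions)
lengths = torch.tensor([4, 3])
mask_combined = create_attention_mask(4, is_causal=True, key_padding_lengths=lengths)
print(f"\nCombined mask shape for multi-head: {mask_combined.shape}")
print("Broadcasts to (batch, num_heads, seq_q, seq_k)")
assert mask_combined.shape == (2, 1, 4, 4)
assert mask_combined[1, 0, :, 3].all()  # key 3 is padding for the second sequence
assert torch.equal(mask_combined[0, 0], mask)  # no padding in the first sequence: causal only

print("\n\u2713 Mask tests passed!")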

## Part 2: Core Multi-Head Mechanism

First, implement multi-head attention assuming Q, K, V are already projected.
This isolates the split-into-heads and concatenate logic.
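
The split and merge steps are pure reshapes. This sketch shows which feature slice each head receives and why the merge step calls `.contiguous()`: the attention output comes out of `matmul` laid out as (batch, num_heads, seq, d_head), and after `transpose(1, 2)` its memory is no longer in the order `view()` requires (`reshape()` would copy automatically instead).

```python
import torch

batch, seq, d_model, num_heads = 2, 8, 64, 4
d_head = d_model // num_heads
x = torch.randn(batch, seq, d_model)

# Split: (batch, seq, d_model) -> (batch, num_heads, seq, d_head)
heads = x.view(batch, seq, num_heads, d_head).transpose(1, 2)

# Head 1 receives feature columns [d_head, 2*d_head) at every position
print(torch.equal(heads[:, 1], x[..., d_head:2 * d_head]))  # True

# Stand-in for the attention output: a fresh tensor laid out as (batch, num_heads, seq, d_head)
out = heads.contiguous()
merged_view = out.transpose(1, 2)  # (batch, seq, num_heads, d_head), a strided view
print(merged_view.is_contiguous())  # False

view_failed = False
try:
    merged_view.view(batch, seq, d_model)
except RuntimeError:
    view_failed = True  # memory is not in (batch, seq, d_model) order
print(view_failed)  # True

# .contiguous() copies into row-major order, after which view() works
merged = merged_view.contiguous().view(batch, seq, d_model)
print(torch.equal(merged, x))  # True
```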

In [None]:
torch.manual_seed(42)

batch_size = 2
seq_len = 8
d_model = 64
num_heads = 4
d_head = d_model // num_heads

print(f"d_model={d_model}, num_heads={num_heads}, d_head={d_head}")
print(f"\nEach head operates on {d_head} dimensions")
print(f"All {num_heads} heads run in parallel, then concatenate back to {d_model}")

In [None]:
def multi_head_attention_core(Q, K, V, num_heads, mask=None):
    """
    Core multi-head attention computation (Q, K, V already projected).
    
    This function takes already-projected Q, K, V and:
    1. Splits them into multiple heads
    2. Computes attention for each head in parallel
    3. Concatenates the results
    
    Args:
        Q, K, V: Projected tensors of shape (batch, seq_len, d_model)
        num_heads: Number of attention heads
        mask: Optional boolean attention mask (True = masked)
    
    Returns:
        output: Multi-head attention output (batch, seq_len, d_model)
        attn_weights: Attention weights (batch, num_heads, seq_len, seq_len)
    """
    batch_size, seq_len, d_model = Q.shape
    d_head = d_model // num_heads
    
    # Split into heads: (batch, seq, d_model) -> (batch, num_heads, seq, d_head)
    Q = Q.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
    K = K.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
    V = V.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
    
    # Scaled dot-product attention per head
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_head ** 0.5)
    
    # Apply mask (broadcasts across batch and heads)
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)  # (batch, num_heads, seq, d_head)
    
    # Concatenate heads: (batch, num_heads, seq, d_head) -> (batch, seq, d_model)
    output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
    
    return output, attn_weights

In [None]:
# Test the core mechanism with random Q, K, V
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output, attn_weights = multi_head_attention_core(Q, K, V, num_heads)

print(f"Input Q shape: {Q.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"  -> {num_heads} heads, each with {seq_len}x{seq_len} attention matrix")

assert output.shape == Q.shape
print("\n\u2713 Core multi-head mechanism test passed!")

In [None]:
# Test with causal mask
causal_mask = create_causal_mask(seq_len)
output_causal, attn_causal = multi_head_attention_core(Q, K, V, num_heads, mask=causal_mask)

# Verify the upper triangle is zero for every batch element and head
for h in range(num_heads):
    upper = attn_causal[:, h].triu(diagonal=1)
    assert torch.allclose(upper, torch.zeros_like(upper), atol=1e-6), f"Head {h} has non-zero upper triangle!"

print("\u2713 Causal mask works correctly across all heads!")

## Part 3: Multi-Head Self-Attention

Now implement the **full** multi-head self-attention where:
- Input: Single tensor `x` (the residual stream)
- Q, K, V are computed as projections of x
- Each projection is split into heads
- Output projection combines the heads
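
Many implementations fuse the three input projections into a single `nn.Linear(d_model, 3 * d_model)` and split its output, trading three small matmuls for one larger one. A sketch, assuming the fused weight (hypothetical name `W_qkv`) stacks the Q, K, V rows in that order; this matches how the validation section slices `in_proj_weight` of `torch.nn.MultiheadAttention`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq, d_model = 2, 8, 64
x = torch.randn(batch, seq, d_model)

# Three separate projections, as in MultiHeadSelfAttention
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

# Fused projection: weight rows stacked as [W_q; W_k; W_v], shape (3*d_model, d_model)
W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
with torch.no_grad():
    W_qkv.weight.copy_(torch.cat([W_q.weight, W_k.weight, W_v.weight], dim=0))

# One matmul, then split the last dimension back into Q, K, V
Q, K, V = W_qkv(x).chunk(3, dim=-1)
matches = [
    torch.allclose(Q, W_q(x), atol=1e-5),
    torch.allclose(K, W_k(x), atol=1e-5),
    torch.allclose(V, W_v(x), atol=1e-5),
]
print(matches)  # [True, True, True]
```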

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-Head Self-Attention where Q, K, V come from projections of the same input.
    """
    
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        
        # Projections: x -> Q, K, V
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model, bias=False)
    
    def forward(
        self, 
        x: torch.Tensor, 
        is_causal: bool = False,
        key_padding_lengths: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: Input tensor (batch, seq_len, d_model) - the residual stream
            is_causal: Whether to apply causal masking
            key_padding_lengths: If provided, actual lengths for padding mask
        
        Returns:
            output: Multi-head attention output (batch, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape
        
        # Project x to Q, K, V - all from the SAME input!
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # Split into heads: (batch, seq, d_model) -> (batch, num_heads, seq, d_head)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
        
        # Scaled dot-product attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_head ** 0.5)
        
        # Create and apply mask
        if is_causal or key_padding_lengths is not None:
            mask = create_attention_mask(
                seq_len_q=seq_len,
                seq_len_k=seq_len,
                is_causal=is_causal,
                key_padding_lengths=key_padding_lengths,
                device=x.device
            )
            scores = scores.masked_fill(mask, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # Concatenate heads: (batch, num_heads, seq, d_head) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        
        # Output projection
        return self.W_o(attn_output)

In [None]:
# Test Multi-Head Self-Attention
torch.manual_seed(42)

# Single input - the residual stream
x = torch.randn(batch_size, seq_len, d_model)
print(f"Input x shape: {x.shape}")
print("This single input will be projected to create Q, K, V for all heads")

# Create multi-head self-attention
mhsa = MultiHeadSelfAttention(d_model, num_heads)

# Forward pass (bidirectional)
output = mhsa(x)
print(f"\nOutput shape: {output.shape}")
assert output.shape == x.shape

# Forward pass (causal)
output_causal = mhsa(x, is_causal=True)
print(f"Causal output shape: {output_causal.shape}")

# Forward pass (with padding)
lengths = torch.tensor([8, 5])
output_padded = mhsa(x, is_causal=True, key_padding_lengths=lengths)
print(f"Padded output shape: {output_padded.shape}")

print("\n\u2713 Multi-Head Self-Attention test passed!")
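
The padded test above only checks output shapes. A stronger check is that padded keys have no influence: changing the contents of padded positions must leave the outputs at real positions unchanged. A self-contained sketch with hypothetical random weights and a hypothetical `padded_self_attention` helper (key padding only, no causal mask, no output projection):

```python
import torch

torch.manual_seed(0)
batch, seq, d_model, num_heads = 2, 8, 64, 4
d_head = d_model // num_heads
# Hypothetical random projection weights, applied as x @ W.T like nn.Linear
W_q, W_k, W_v = (torch.randn(d_model, d_model) / d_model ** 0.5 for _ in range(3))


def padded_self_attention(x, lengths):
    b, s, _ = x.shape

    def split(t):
        return t.view(b, s, num_heads, d_head).transpose(1, 2)

    Q, K, V = split(x @ W_q.T), split(x @ W_k.T), split(x @ W_v.T)
    # Key padding mask (True = padded key), shaped (batch, 1, 1, seq) to broadcast over heads and queries
    pad = (torch.arange(s) >= lengths.unsqueeze(1))[:, None, None, :]
    scores = Q @ K.transpose(-2, -1) / d_head ** 0.5
    attn = torch.softmax(scores.masked_fill(pad, float("-inf")), dim=-1)
    return (attn @ V).transpose(1, 2).reshape(b, s, d_model)


lengths = torch.tensor([8, 5])
x = torch.randn(batch, seq, d_model)
out_a = padded_self_attention(x, lengths)

# Replace the three padded positions of sequence 1 with different values
x_b = x.clone()
x_b[1, 5:] = torch.randn(3, d_model)
out_b = padded_self_attention(x_b, lengths)

print(torch.allclose(out_a[1, :5], out_b[1, :5]))  # True: real positions ignore padded keys
```

Outputs at the padded query positions themselves still change; in practice they are ignored downstream (e.g. excluded from the loss).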

## Validate Against PyTorch's Implementation

To verify correctness, we compare against `torch.nn.MultiheadAttention` using the same weights.

In [None]:
def multi_head_attention(q, k, v, num_heads, d_model, mask=None, weights=None):
    """
    Multi-head attention function that can use external weights for validation.
    
    Note: In self-attention, q, k, and v are the same tensor x. This version accepts
    them separately to mirror the call signature of torch.nn.MultiheadAttention.
    """
    assert d_model % num_heads == 0
    
    d_head = d_model // num_heads
    batch_size, seq_len, _ = q.shape
    
    # Create projections
    Q_w = nn.Linear(d_model, d_model, bias=False)
    K_w = nn.Linear(d_model, d_model, bias=False)
    V_w = nn.Linear(d_model, d_model, bias=False)
    W_out = nn.Linear(d_model, d_model, bias=False)
    
    if weights is not None:
        # Copy the reference weights in place (no autograd tracking needed)
        with torch.no_grad():
            Q_w.weight.copy_(weights['q_weight'])
            K_w.weight.copy_(weights['k_weight'])
            V_w.weight.copy_(weights['v_weight'])
            W_out.weight.copy_(weights['out_weight'])
    
    # Project
    Q = Q_w(q)
    K = K_w(k)
    V = V_w(v)
    
    # Split into heads
    Q = Q.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
    K = K.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
    V = V.view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
    
    # Attention
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_head ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)
    
    # Concatenate and project
    output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
    return W_out(output)

In [None]:
# Validate against PyTorch's MultiheadAttention
torch.manual_seed(42)

# For validation, we use PyTorch's API which takes q, k, v separately
# (even though in self-attention they're all projections of the same input)
x = torch.randn(batch_size, seq_len, d_model)

# Create PyTorch reference
multihead_attn = torch.nn.MultiheadAttention(
    embed_dim=d_model, num_heads=num_heads, bias=False, batch_first=True
)

# Extract weights
weights = {
    'q_weight': multihead_attn.in_proj_weight[:d_model, :],
    'k_weight': multihead_attn.in_proj_weight[d_model:2*d_model, :],
    'v_weight': multihead_attn.in_proj_weight[2*d_model:, :],
    'out_weight': multihead_attn.out_proj.weight
}

# For self-attention: q=k=v=x
output_custom = multi_head_attention(x, x, x, num_heads, d_model, weights=weights)
output_ref, _ = multihead_attn(x, x, x)

assert torch.allclose(output_custom, output_ref, atol=1e-6), "Outputs don't match!"
print("\u2713 Multi-Head Attention matches PyTorch!")
print(f"Max difference: {(output_custom - output_ref).abs().max().item():.2e}")
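
The comparison above covers only the unmasked path. A sketch extending it to causal masking, relying on `nn.MultiheadAttention`'s boolean `attn_mask` convention (True = not allowed to attend), which matches this notebook's masks. It reuses the reference module's weights through `F.linear` rather than building new `nn.Linear` layers:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, d_model, num_heads = 2, 8, 64, 4
d_head = d_model // num_heads

ref = torch.nn.MultiheadAttention(d_model, num_heads, bias=False, batch_first=True)
x = torch.randn(batch, seq, d_model)

# Boolean causal mask, True = not allowed to attend
causal = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)


def split(t):
    return t.view(batch, seq, num_heads, d_head).transpose(1, 2)


with torch.no_grad():
    # in_proj_weight stacks the Q, K, V projection weights along dim 0
    W_q, W_k, W_v = ref.in_proj_weight.chunk(3, dim=0)
    Q, K, V = split(F.linear(x, W_q)), split(F.linear(x, W_k)), split(F.linear(x, W_v))

    scores = Q @ K.transpose(-2, -1) / d_head ** 0.5
    attn = torch.softmax(scores.masked_fill(causal, float("-inf")), dim=-1)
    merged = (attn @ V).transpose(1, 2).contiguous().view(batch, seq, d_model)
    custom = F.linear(merged, ref.out_proj.weight)

    expected, _ = ref(x, x, x, attn_mask=causal, need_weights=False)

print(torch.allclose(custom, expected, atol=1e-5))  # True
```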

## Visualizing Multi-Head Attention

In a trained model, different heads learn to attend to different patterns. The cell below uses randomly initialized (untrained) projections, so the differences between its heads come from initialization alone; it shows what per-head attention maps look like, not what a model learns.

In [None]:
import matplotlib.pyplot as plt

def get_attention_weights(x, num_heads, d_model):
    """Get per-head attention weights (batch, num_heads, seq, seq) for visualization."""
    batch_size, seq_len, _ = x.shape
    d_head = d_model // num_heads
    
    # Random, untrained projections (for illustration only)
    W_q = nn.Linear(d_model, d_model, bias=False)
    W_k = nn.Linear(d_model, d_model, bias=False)
    
    Q = W_q(x).view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
    K = W_k(x).view(batch_size, seq_len, num_heads, d_head).transpose(1, 2)
    
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_head ** 0.5)
    attn_weights = F.softmax(scores, dim=-1)
    
    return attn_weights

# Generate attention patterns
torch.manual_seed(123)
vis_x = torch.randn(1, 8, 64)
attn = get_attention_weights(vis_x, num_heads=4, d_model=64)

# Plot each head's attention pattern
fig, axes = plt.subplots(1, 4, figsize=(14, 3.5))
for h in range(4):
    im = axes[h].imshow(attn[0, h].detach().numpy(), cmap='Blues', vmin=0, vmax=0.5)
    axes[h].set_title(f'Head {h}', fontsize=11)
    axes[h].set_xlabel('Key')
    if h == 0:
        axes[h].set_ylabel('Query')

plt.suptitle('Multi-Head Attention Patterns (random, untrained projections)', fontsize=12)
plt.tight_layout()
plt.savefig('mha_patterns.png', dpi=150, bbox_inches='tight')
plt.show()

print("Even with random projections, each head produces a different attention pattern.")
print("In a trained model, these per-head differences let it capture diverse relationships.")

## Part 4: Multi-Head Attention with KV Cache

### Memory Considerations

KV cache memory per layer:
```
memory = batch_size * num_heads * seq_len * head_dim * 2 (K and V) * bytes_per_param
```

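For a 70B-scale configuration (80 layers, 64 heads, head_dim 128) at 8K context, with batch size 1, fp16/bf16 values (2 bytes), and every head storing its own K and V (plain MHA):
- Per layer: 8192 * 64 * 128 * 2 * 2 bytes = 256 MiB
- Total: 80 * 256 MiB = **20 GiB just for the KV cache!**

This is why techniques like **Grouped Query Attention (GQA)** are important: a group of query heads shares a single K/V head, so the cache shrinks by a factor of num_heads / num_kv_heads.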
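
The arithmetic above generalizes to a small helper. In this sketch, `kv_cache_bytes` is a hypothetical name, `num_kv_heads` equals `num_heads` for plain MHA, and the GQA line assumes 8 K/V heads shared by the 64 query heads:

```python
def kv_cache_bytes(batch_size, num_layers, num_kv_heads, seq_len, head_dim, bytes_per_value=2):
    """Bytes needed to cache K and V for every layer (the factor 2 covers K and V)."""
    return batch_size * num_layers * num_kv_heads * seq_len * head_dim * 2 * bytes_per_value


GiB = 1024 ** 3

mha_8k = kv_cache_bytes(1, 80, 64, 8192, 128)      # plain MHA: one K/V head per query head
gqa_8k = kv_cache_bytes(1, 80, 8, 8192, 128)       # hypothetical GQA with 8 K/V heads
mha_128k = kv_cache_bytes(1, 80, 64, 131072, 128)  # plain MHA at 128K context

print(f"MHA, 8K:   {mha_8k / GiB:.1f} GiB")    # 20.0 GiB
print(f"GQA, 8K:   {gqa_8k / GiB:.1f} GiB")    # 2.5 GiB
print(f"MHA, 128K: {mha_128k / GiB:.1f} GiB")  # 320.0 GiB
```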

In [None]:
class MultiHeadAttentionWithCache(nn.Module):
    """
    Multi-Head Self-Attention with KV Cache support for efficient inference.
    
    Takes a single input x (the residual stream) and projects it to Q, K, V.
    """
    
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        # Projections
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
    
    def forward(
        self,
        x: torch.Tensor,
        cache_k: torch.Tensor = None,
        cache_v: torch.Tensor = None,
        is_causal: bool = True,
    ) -> tuple:
        """
        Args:
            x: Input tensor (batch, seq_len, d_model) - the residual stream
            cache_k: Cached keys (batch, num_heads, cached_len, head_dim) or None
            cache_v: Cached values (batch, num_heads, cached_len, head_dim) or None
            is_causal: Whether to apply causal masking
        
        Returns:
            output: Attention output (batch, seq_len, d_model)
            new_cache_k: Updated key cache
            new_cache_v: Updated value cache
        """
        batch_size, seq_len_q, _ = x.shape
        cache_len = cache_k.size(2) if cache_k is not None else 0
        
        # Project x to Q, K, V
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)
        
        # Reshape to (batch, num_heads, seq_len, head_dim)
        Q = Q.view(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len_q, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Concatenate with cache if exists
        if cache_k is not None and cache_v is not None:
            K = torch.cat([cache_k, K], dim=2)
            V = torch.cat([cache_v, V], dim=2)
        
        # Update cache
        new_cache_k = K
        new_cache_v = V
        
        seq_len_k = K.size(2)
        
        # Compute attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        
        # Apply causal mask with cache offset
        if is_causal:
            mask = create_causal_mask_with_cache(
                seq_len_q=seq_len_q,
                seq_len_k=seq_len_k,
                cache_len=cache_len,
                device=x.device
            )
            scores = scores.masked_fill(mask, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        
        # Reshape back: (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)
        output = self.out_proj(output)
        
        return output, new_cache_k, new_cache_v

In [None]:
# Test MHA with KV Cache
print("=== Testing Multi-Head Attention with KV Cache ===")

torch.manual_seed(42)
batch_size = 2
d_model = 64
num_heads = 4

mha = MultiHeadAttentionWithCache(d_model, num_heads)

# Step 1: Process prompt (3 tokens) - prefill
prompt = torch.randn(batch_size, 3, d_model)
print(f"\nInput prompt shape: {prompt.shape}")
print("Q, K, V will all be computed from this single input")

out1, cache_k, cache_v = mha(prompt, None, None)
print(f"\nAfter prompt: cache shape = {cache_k.shape}")
print(f"  (batch={batch_size}, num_heads={num_heads}, seq_len=3, head_dim={d_model//num_heads})")

# Step 2: Generate token 4
new_token = torch.randn(batch_size, 1, d_model)
out2, cache_k, cache_v = mha(new_token, cache_k, cache_v)
print(f"After token 4: cache shape = {cache_k.shape}")

# Step 3: Generate token 5
new_token = torch.randn(batch_size, 1, d_model)
out3, cache_k, cache_v = mha(new_token, cache_k, cache_v)
print(f"After token 5: cache shape = {cache_k.shape}")

# Verify
assert cache_k.shape == (batch_size, num_heads, 5, d_model // num_heads)
assert out3.shape == (batch_size, 1, d_model)

print("\n\u2713 MHA with KV Cache test passed!")
print(f"\nCache memory per layer: {cache_k.numel() * 4 * 2 / 1024:.1f} KB (K+V, float32)")
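
The shape checks above do not confirm that caching gives the same result as recomputing. A self-contained sketch (hypothetical random weights and a hypothetical `causal_attend` helper with the same causal-with-offset logic) that runs the full 5-token sequence in one pass, then again as a 3-token prefill plus two single-token decode steps, and compares the outputs:

```python
import torch

torch.manual_seed(0)
batch, d_model, num_heads = 2, 64, 4
head_dim = d_model // num_heads
# Hypothetical random weights playing the roles of q_proj, k_proj, v_proj, out_proj (applied as x @ W.T)
W_q, W_k, W_v, W_o = (torch.randn(d_model, d_model) / d_model ** 0.5 for _ in range(4))


def split(t):
    b, s, _ = t.shape
    return t.view(b, s, num_heads, head_dim).transpose(1, 2)


def causal_attend(x, cache_k=None, cache_v=None):
    """Causal multi-head self-attention over x, extending an optional K/V cache."""
    cache_len = 0 if cache_k is None else cache_k.size(2)
    Q, K, V = split(x @ W_q.T), split(x @ W_k.T), split(x @ W_v.T)
    if cache_k is not None:
        K = torch.cat([cache_k, K], dim=2)
        V = torch.cat([cache_v, V], dim=2)
    seq_q, seq_k = Q.size(2), K.size(2)
    # Query i is at absolute position cache_len + i; mask keys after that position
    mask = torch.triu(torch.ones(seq_q, seq_k, dtype=torch.bool), diagonal=cache_len + 1)
    scores = Q @ K.transpose(-2, -1) / head_dim ** 0.5
    out = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1) @ V
    out = out.transpose(1, 2).contiguous().view(x.size(0), seq_q, d_model)
    return out @ W_o.T, K, V


tokens = torch.randn(batch, 5, d_model)

# One causal pass over all 5 tokens
full, _, _ = causal_attend(tokens)

# Prefill the first 3 tokens, then decode tokens 4 and 5 one at a time
out, k, v = causal_attend(tokens[:, :3])
outputs = [out]
for t in range(3, 5):
    out, k, v = causal_attend(tokens[:, t:t + 1], k, v)
    outputs.append(out)
incremental = torch.cat(outputs, dim=1)

print(torch.allclose(full, incremental, atol=1e-5))  # True
print(k.shape)  # torch.Size([2, 4, 5, 16])
```

Equality holds because each token's K and V depend only on that token, so cached entries are exactly what a full recompute would produce.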

## Interview Tips

**Q: Why use multiple heads instead of one large head?**
A: Multiple heads allow the model to jointly attend to information from different representation subspaces. Each head can learn different patterns (syntax, semantics, positional, etc.).

**Q: What's the relationship between d_model, num_heads, and d_head?**
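A: d_head = d_model / num_heads, so d_model must be divisible by num_heads. The parameter count and the matmul FLOPs stay the same as one head of width d_model: h heads each doing seq_q * seq_k * d_head work sum to seq_q * seq_k * d_model. What grows with num_heads is the number of attention-score entries (num_heads * seq_q * seq_k), and with it the softmax work and the memory for attention weights.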
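
A quick count of multiply-adds makes this concrete. The sketch below uses a hypothetical `attention_cost` helper, ignores the Q/K/V/O projections (which cost the same for any head count), and compares one head of width 512 with 8 heads of width 64 at sequence length 1024:

```python
def attention_cost(seq_len, d_model, num_heads):
    """Multiply-adds for QK^T and attn @ V, plus the number of attention-score entries."""
    d_head = d_model // num_heads
    score_macs = num_heads * seq_len * seq_len * d_head  # QK^T: (seq x d_head) @ (d_head x seq) per head
    value_macs = num_heads * seq_len * seq_len * d_head  # attn @ V: (seq x seq) @ (seq x d_head) per head
    score_entries = num_heads * seq_len * seq_len        # softmax inputs / attention-weight memory
    return score_macs + value_macs, score_entries


print(attention_cost(1024, 512, num_heads=1))  # (1073741824, 1048576)
print(attention_cost(1024, 512, num_heads=8))  # (1073741824, 8388608)
```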

**Q: How does masking work with multiple heads?**
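A: The mask broadcasts across all heads. A causal mask of shape (seq_q, seq_k) broadcasts to (batch, num_heads, seq_q, seq_k); a key-padding mask is reshaped to (batch, 1, 1, seq_k) so it broadcasts over heads and queries. Every head uses the same mask.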

**Q: How does the KV cache shape differ from Q?**
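A: During generation, Q is computed only for the new token (shape: batch, num_heads, 1, head_dim), while K and V cover all previous tokens plus the new one (shape: batch, num_heads, cache_len + 1, head_dim). The attention scores are therefore (batch, num_heads, 1, cache_len + 1).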

**Q: Why is the cache stored per-head rather than combined?**
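A: Each head attends only over its own head_dim slice of K and V, so the cache is kept in the (batch, num_heads, seq_len, head_dim) layout that the attention matmul consumes. New tokens are appended along the sequence dimension, and the cached entries never need to be re-split or transposed. The memory footprint is the same as a combined (batch, seq_len, d_model) layout.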

**Q: What's the memory complexity of KV cache?**
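A: O(batch * layers * num_heads * seq_len * head_dim * 2). The 70B-scale MHA example in Part 4 needs about 20 GiB at 8K context, so a 128K context (16x longer) would need about 320 GiB. This motivates techniques like GQA and MQA, which reduce the number of K/V heads, and paged attention, which allocates the cache in fixed-size blocks to reduce fragmentation.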