# Implement Grouped Query Attention from Scratch

### Problem Statement

Standard Multi-Head Attention (MHA) gives every attention head its own key and value projections. **Grouped Query Attention (GQA)** is a more memory-efficient variant: the query heads are split into groups, and all query heads in a group share a single key-value head.

### Background: Why GQA?

During autoregressive generation, the KV cache can become a memory bottleneck:
- For a LLaMA-2 70B-sized model (80 layers, 64 heads of dimension 128), MHA would need ~20 GiB of KV cache for a single 8K-token sequence stored in fp16
- GQA reduces this by letting several query heads share each K/V head

### The Key Insight

| Attention Type | Query Heads | KV Heads | KV Cache Elements (per layer, per sequence) |
|---------------|-------------|----------|---------------|
| MHA | 32 | 32 | 32 * seq * d_head * 2 |
| GQA (8 groups) | 32 | 8 | **8 * seq * d_head * 2** |
| MQA (1 group) | 32 | 1 | **1 * seq * d_head * 2** |

**Real-world examples:**
- LLaMA-2 70B: 64 query heads, 8 KV heads = **8x memory reduction**
- Mistral 7B: 32 query heads, 8 KV heads = **4x memory reduction**
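To turn the table's per-layer element counts into bytes, multiply by the number of layers, the batch size, and the bytes per element. The sketch below uses a hypothetical helper (`kv_cache_bytes` is not a library function) and assumes fp16/bf16 storage (2 bytes per element), batch size 1, and LLaMA-2 70B-like dimensions (80 layers, head dimension 128) at an 8K-token context.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch_size: int = 1, bytes_per_elem: int = 2) -> int:
    """Bytes needed to cache K and V for every layer (the leading 2 counts K and V)."""
    return 2 * batch_size * num_layers * num_kv_heads * seq_len * head_dim * bytes_per_elem


GiB = 1024 ** 3
MiB = 1024 ** 2
# Assumed shape: 80 layers, head_dim 128, 8192 tokens, fp16 (2 bytes), batch size 1
mha = kv_cache_bytes(num_layers=80, num_kv_heads=64, head_dim=128, seq_len=8192)
gqa = kv_cache_bytes(num_layers=80, num_kv_heads=8, head_dim=128, seq_len=8192)
mqa = kv_cache_bytes(num_layers=80, num_kv_heads=1, head_dim=128, seq_len=8192)

print(f"MHA: {mha / GiB:.2f} GiB")   # 20.00 GiB
print(f"GQA: {gqa / GiB:.2f} GiB")   # 2.50 GiB
print(f"MQA: {mqa / MiB:.0f} MiB")   # 320 MiB
print(f"MHA -> GQA reduction: {mha // gqa}x")
```

The reduction factor is simply `num_query_heads / num_kv_heads`, since every other term in the product is shared.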

### Learning Path

1. **Part 1**: Mask creation (causal, padding, KV cache)
2. **Part 2**: Core GQA mechanism - how `repeat_interleave` expands the shared K/V heads to match the query heads
3. **Part 3**: GQA Self-Attention - Q, K, V from projections of single input x
4. **Part 4**: GQA with KV Cache - see the actual memory savings

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Part 1: Attention Mask Creation

The mask functions are the same as standard attention. For GQA, the mask broadcasts across all heads.

Mask convention: **True = masked (cannot attend), False = can attend**
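A quick way to check this convention, and to see a 2-D `(seq_q, seq_k)` mask broadcast over batch and heads, is to compare a hand-written masked softmax against PyTorch's `F.scaled_dot_product_attention` (PyTorch 2.0+). SDPA's boolean `attn_mask` uses the opposite convention (True = may attend), so the mask is inverted with `~` before being passed in. The tensor sizes are arbitrary choices for this sketch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq, d_head = 2, 4, 5, 8
q = torch.randn(batch, heads, seq, d_head)
k = torch.randn(batch, heads, seq, d_head)
v = torch.randn(batch, heads, seq, d_head)

# Notebook convention: True = masked. A (seq, seq) mask broadcasts over batch and heads.
masked = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)

scores = q @ k.transpose(-2, -1) / d_head ** 0.5       # (batch, heads, seq, seq)
scores = scores.masked_fill(masked, float("-inf"))     # 2-D mask broadcast to 4-D
manual = F.softmax(scores, dim=-1) @ v

# SDPA's boolean attn_mask means True = may attend, so invert before passing it in
sdpa = F.scaled_dot_product_attention(q, k, v, attn_mask=~masked)

print(torch.allclose(manual, sdpa, atol=1e-5))
```

`nn.MultiheadAttention`'s boolean `attn_mask` and `key_padding_mask` follow the same convention as this notebook (True = not allowed to attend), so only SDPA-style APIs need the inversion.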

In [None]:
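def create_causal_mask(seq_len_q: int, seq_len_k: int = None, device=None) -> torch.Tensor:
    """
    Create a causal attention mask for autoregressive models.
    True above the diagonal marks future positions (masked); the attendable
    region is the lower triangle, diagonal included. Assumes queries and keys
    start at the same position (use create_causal_mask_with_cache otherwise).
    Shape: (seq_len_q, seq_len_k) -> broadcasts to (batch, num_heads, seq_len_q, seq_len_k)
    """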
    if seq_len_k is None:
        seq_len_k = seq_len_q
    return torch.triu(torch.ones(seq_len_q, seq_len_k, dtype=torch.bool, device=device), diagonal=1)

In [None]:
def create_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """
    Create a padding mask for variable-length sequences.
    Shape: (batch, max_len)
    """
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return positions >= lengths.unsqueeze(1)

In [None]:
def create_attention_mask(
    seq_len_q: int,
    seq_len_k: int = None,
    is_causal: bool = True,
    key_padding_lengths: torch.Tensor = None,
    device=None
) -> torch.Tensor:
    """
    Create a combined attention mask for GQA.
    Returns mask with shape suitable for broadcasting to (batch, num_heads, seq_q, seq_k)
    """
    if seq_len_k is None:
        seq_len_k = seq_len_q

    mask = torch.zeros(seq_len_q, seq_len_k, dtype=torch.bool, device=device)

    if is_causal:
        causal = create_causal_mask(seq_len_q, seq_len_k, device=device)
        mask = mask | causal

    if key_padding_lengths is not None:
        padding = create_padding_mask(key_padding_lengths, seq_len_k)
        padding = padding.unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, seq_k)
        mask = mask.unsqueeze(0).unsqueeze(0) | padding  # (batch, 1, seq_q, seq_k)

    return mask

In [None]:
def create_causal_mask_with_cache(
    seq_len_q: int,
    seq_len_k: int,
    cache_len: int,
    device=None
) -> torch.Tensor:
    """
    Create a causal mask for attention with KV cache.
    New queries can attend to all cached positions.
    """
    mask = torch.zeros(seq_len_q, seq_len_k, dtype=torch.bool, device=device)

    if seq_len_q > 1:
        new_token_mask = create_causal_mask(seq_len_q, seq_len_q, device=device)
        mask[:, cache_len:] = new_token_mask

    return mask

In [None]:
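# Test mask creation
mask = create_causal_mask(4)
print("Causal mask (4x4):")
print(mask)
assert mask.shape == (4, 4) and mask[0, 1] and not mask[1, 0]  # True above the diagonal = future

# Test with padding
lengths = torch.tensor([4, 3])
mask_combined = create_attention_mask(4, is_causal=True, key_padding_lengths=lengths)
print(f"\nCombined mask shape: {mask_combined.shape}")
assert mask_combined.shape == (2, 1, 4, 4)
assert mask_combined[1, 0, :, 3].all()  # key 3 of sequence 1 is padding, masked for every query
print("\n\u2713 Mask tests passed!")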

## Part 2: Core GQA Mechanism

The key difference from standard MHA:
- Q has `num_query_heads` heads
- K and V have `num_kv_heads` heads (fewer!)
- We expand K, V using `repeat_interleave` to match Q's head count
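The choice between `repeat_interleave` and `repeat` decides which query heads share a KV head. The sketch below prints the KV-head index each query head ends up reading under both, then checks that an `expand` + `reshape` formulation produces the same tensor as `repeat_interleave`. The head counts match the configuration used in this part (8 query heads, 2 KV heads); the tensor sizes are arbitrary.

```python
import torch

num_query_heads, num_kv_heads = 8, 2
group_size = num_query_heads // num_kv_heads   # query heads per KV head

kv_ids = torch.arange(num_kv_heads)
# repeat_interleave keeps a KV head's copies adjacent: query head h reads KV head h // group_size
print(kv_ids.repeat_interleave(group_size).tolist())  # [0, 0, 0, 0, 1, 1, 1, 1]
# repeat tiles the whole sequence instead, assigning KV heads round-robin (h % num_kv_heads)
print(kv_ids.repeat(group_size).tolist())             # [0, 1, 0, 1, 0, 1, 0, 1]

# Same expansion on a (batch, num_kv_heads, seq, d_head) tensor.
# expand creates a broadcast view; reshape then materializes the copies.
K = torch.randn(2, num_kv_heads, 5, 8)
K_ri = K.repeat_interleave(group_size, dim=1)
b, h_kv, s, d = K.shape
K_er = K[:, :, None, :, :].expand(b, h_kv, group_size, s, d).reshape(b, h_kv * group_size, s, d)
print(torch.equal(K_ri, K_er))
```

Either grouping is mathematically valid on its own, but it must match the layout the weights were trained with; the code in this notebook uses the consecutive grouping produced by `repeat_interleave`.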

In [None]:
torch.manual_seed(42)

batch_size = 2
seq_len = 8
d_model = 64
num_query_heads = 8
num_kv_heads = 2  # 4x fewer KV heads than query heads
d_head = d_model // num_query_heads

print(f"d_model={d_model}")
print(f"num_query_heads={num_query_heads}, num_kv_heads={num_kv_heads}")
print(f"d_head={d_head}")
print(f"Query heads per KV head: {num_query_heads // num_kv_heads}")
print(f"\nKV projection size: {num_kv_heads * d_head} (vs {d_model} for full MHA)")

In [None]:
def grouped_query_attention_core(Q, K, V, num_query_heads, num_kv_heads, mask=None):
    """
    Core GQA computation (Q, K, V already projected and reshaped to heads).
    
    Args:
        Q: Query tensor (batch, num_query_heads, seq_len_q, d_head)
        K: Key tensor (batch, num_kv_heads, seq_len_k, d_head) - fewer heads!
        V: Value tensor (batch, num_kv_heads, seq_len_k, d_head) - fewer heads!
        num_query_heads: Number of query heads
        num_kv_heads: Number of key/value heads (must divide num_query_heads)
        mask: Optional boolean attention mask (True = masked), broadcastable to
              (batch, num_query_heads, seq_len_q, seq_len_k)
    
    Returns:
        output: GQA output (batch, num_query_heads, seq_len_q, d_head)
    """
    assert Q.size(1) == num_query_heads and K.size(1) == num_kv_heads
    assert num_query_heads % num_kv_heads == 0
    d_head = Q.size(-1)
    
    # Expand K, V to match query heads
    # Each KV head is shared by (num_query_heads // num_kv_heads) query heads
    repeat_factor = num_query_heads // num_kv_heads
    K = K.repeat_interleave(repeat_factor, dim=1)  # (batch, num_query_heads, seq, d_head)
    V = V.repeat_interleave(repeat_factor, dim=1)
    
    # Standard attention
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_head ** 0.5)
    
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)
    
    return output

In [None]:
# Test core GQA mechanism
Q = torch.randn(batch_size, num_query_heads, seq_len, d_head)
K = torch.randn(batch_size, num_kv_heads, seq_len, d_head)  # Fewer heads!
V = torch.randn(batch_size, num_kv_heads, seq_len, d_head)  # Fewer heads!

print(f"Q shape: {Q.shape} ({num_query_heads} heads)")
print(f"K shape: {K.shape} ({num_kv_heads} heads)")
print(f"V shape: {V.shape} ({num_kv_heads} heads)")

output = grouped_query_attention_core(Q, K, V, num_query_heads, num_kv_heads)
print(f"\nOutput shape: {output.shape}")

assert output.shape == Q.shape
print("\n\u2713 Core GQA mechanism test passed!")

In [None]:
# Test with causal mask
causal_mask = create_causal_mask(seq_len)
output_causal = grouped_query_attention_core(Q, K, V, num_query_heads, num_kv_heads, mask=causal_mask)

print(f"Causal output shape: {output_causal.shape}")
print("\u2713 GQA with causal mask works!")

## Part 3: GQA Self-Attention

Full GQA where Q, K, V come from projections of a single input x.

Key difference from MHA:
- Q projection: `d_model -> d_model` (all query heads)
- K/V projection: `d_model -> kv_dim` where `kv_dim = num_kv_heads * d_head` (smaller!)
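These projection shapes imply a simple parameter count: `W_q` and `W_o` are `d_model x d_model`, while `W_k` and `W_v` are `d_model x kv_dim`. The hypothetical helper below evaluates that formula (no biases, matching the `bias=False` layers used here) for MHA, GQA, and MQA at this notebook's test sizes.

```python
def gqa_projection_params(d_model: int, num_query_heads: int, num_kv_heads: int) -> int:
    """Weight count of W_q, W_k, W_v, W_o without biases (hypothetical helper)."""
    assert d_model % num_query_heads == 0 and num_query_heads % num_kv_heads == 0
    d_head = d_model // num_query_heads
    kv_dim = num_kv_heads * d_head
    return 2 * d_model * d_model + 2 * d_model * kv_dim   # (W_q + W_o) + (W_k + W_v)


# d_model=64 with 8 query heads: MHA 16384, GQA (2 KV heads) 10240, MQA 9216
for name, kv_heads in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    print(name, gqa_projection_params(d_model=64, num_query_heads=8, num_kv_heads=kv_heads))
```

Only the K/V half of the projection weights shrinks, so the parameter savings are modest compared with the KV cache savings.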

In [None]:
class GroupedQuerySelfAttention(nn.Module):
    """
    Grouped Query Self-Attention where Q, K, V come from projections of the same input,
    but K and V have fewer heads than Q.
    """
    
    def __init__(self, d_model: int, num_query_heads: int, num_kv_heads: int):
        super().__init__()
        assert d_model % num_query_heads == 0
        assert num_query_heads % num_kv_heads == 0
        
        self.d_model = d_model
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.d_head = d_model // num_query_heads
        self.kv_dim = num_kv_heads * self.d_head  # Smaller than d_model!
        
        # Q projection: full d_model
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        
        # K/V projections: smaller!
        self.W_k = nn.Linear(d_model, self.kv_dim, bias=False)
        self.W_v = nn.Linear(d_model, self.kv_dim, bias=False)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model, bias=False)
    
    def forward(
        self, 
        x: torch.Tensor, 
        is_causal: bool = False,
        key_padding_lengths: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: Input tensor (batch, seq_len, d_model) - the residual stream
            is_causal: Whether to apply causal masking
            key_padding_lengths: If provided, actual lengths for padding mask
        
        Returns:
            output: GQA output (batch, seq_len, d_model)
        """
        batch_size, seq_len, _ = x.shape
        
        # Project x to Q, K, V
        Q = self.W_q(x)  # (batch, seq, d_model)
        K = self.W_k(x)  # (batch, seq, kv_dim) - smaller!
        V = self.W_v(x)  # (batch, seq, kv_dim) - smaller!
        
        # Reshape to heads
        Q = Q.view(batch_size, seq_len, self.num_query_heads, self.d_head).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_kv_heads, self.d_head).transpose(1, 2)
        
        # Expand K, V to match query heads
        repeat_factor = self.num_query_heads // self.num_kv_heads
        K = K.repeat_interleave(repeat_factor, dim=1)
        V = V.repeat_interleave(repeat_factor, dim=1)
        
        # Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_head ** 0.5)
        
        # Create and apply mask
        if is_causal or key_padding_lengths is not None:
            mask = create_attention_mask(
                seq_len_q=seq_len,
                seq_len_k=seq_len,
                is_causal=is_causal,
                key_padding_lengths=key_padding_lengths,
                device=x.device
            )
            scores = scores.masked_fill(mask, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # Concatenate heads and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(attn_output)

In [None]:
# Test GQA Self-Attention
torch.manual_seed(42)

# Single input - the residual stream
x = torch.randn(batch_size, seq_len, d_model)
print(f"Input x shape: {x.shape}")
print("This single input will be projected to create Q, K, V")
print(f"Q: {d_model} dims ({num_query_heads} heads)")
print(f"K/V: {num_kv_heads * d_head} dims ({num_kv_heads} heads) - smaller!")

# Create GQA layer
gqa = GroupedQuerySelfAttention(d_model, num_query_heads, num_kv_heads)

# Forward pass (bidirectional)
output = gqa(x)
print(f"\nOutput shape: {output.shape}")

# Forward pass (causal)
output_causal = gqa(x, is_causal=True)
print(f"Causal output shape: {output_causal.shape}")

# Forward pass (with padding)
lengths = torch.tensor([8, 5])
output_padded = gqa(x, is_causal=True, key_padding_lengths=lengths)
print(f"Padded output shape: {output_padded.shape}")

# Count parameters
mha_params = d_model * d_model * 4  # Q, K, V, O each d_model x d_model
gqa_params = d_model * d_model + 2 * d_model * (num_kv_heads * d_head) + d_model * d_model
print(f"\nMHA projection params: {mha_params:,}")
print(f"GQA projection params: {gqa_params:,}")
print(f"Parameter reduction: {(1 - gqa_params/mha_params)*100:.1f}%")

assert output.shape == x.shape
print("\n\u2713 GQA Self-Attention test passed!")

## Validation: GQA Degenerates to MHA

When `num_query_heads == num_kv_heads`, GQA should behave identically to standard MHA.
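A related way to see this: GQA is exactly MHA with the K/V projection weights tied within each group. The sketch below builds MHA-sized key/value weight matrices by copying each KV head's rows to every query head in its group, and checks that the attention outputs match the GQA path. Setting `n_kv = n_q` removes the tying (plain MHA), and `n_kv = 1` ties all heads together (MQA). The helper names (`attention`, `heads`, `tie`) and sizes are illustrative assumptions; the output projection is omitted because it is identical in both paths.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, d_model, n_q, n_kv = 2, 6, 64, 8, 2
d_head, group = d_model // n_q, n_q // n_kv

def attention(Q, K, V):
    # Scaled dot-product attention over (batch, heads, seq, d_head) tensors
    scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5
    return F.softmax(scores, dim=-1) @ V

def heads(t, n):
    # (batch, seq, n * d_head) -> (batch, n, seq, d_head)
    return t.view(batch, seq, n, d_head).transpose(1, 2)

def tie(W):
    # (n_kv * d_head, d_model) -> (n_q * d_head, d_model): copy each KV head's rows to its group
    return W.view(n_kv, d_head, d_model).repeat_interleave(group, dim=0).reshape(n_q * d_head, d_model)

x = torch.randn(batch, seq, d_model)
# Weights stored as (out_features, in_features), like nn.Linear.weight
W_q = torch.randn(d_model, d_model) / d_model ** 0.5
W_k = torch.randn(n_kv * d_head, d_model) / d_model ** 0.5   # compact: n_kv heads
W_v = torch.randn(n_kv * d_head, d_model) / d_model ** 0.5

Q = heads(x @ W_q.T, n_q)

# GQA path: project to n_kv heads, then repeat each head for its group of query heads
out_gqa = attention(Q,
                    heads(x @ W_k.T, n_kv).repeat_interleave(group, dim=1),
                    heads(x @ W_v.T, n_kv).repeat_interleave(group, dim=1))

# MHA path with tied weights: one K/V head per query head, but duplicated within a group
W_k_tied, W_v_tied = tie(W_k), tie(W_v)
out_mha = attention(Q, heads(x @ W_k_tied.T, n_q), heads(x @ W_v_tied.T, n_q))

print(W_k_tied.shape)                                # torch.Size([64, 64])
print(torch.allclose(out_gqa, out_mha, atol=1e-5))
```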

In [None]:
def grouped_query_attention(q, k, v, num_query_heads, num_kv_heads, d_model, mask=None, weights=None):
    """
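    GQA function that can use external weights for validation against PyTorch MHA.
    
    When num_query_heads == num_kv_heads, this degenerates to standard MHA.
    If `weights` is None, fresh randomly initialized projections are created on
    every call, so results are only reproducible when weights are supplied.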
    """
    assert d_model % num_query_heads == 0
    assert num_query_heads % num_kv_heads == 0
    
    d_head = d_model // num_query_heads
    kv_dim = num_kv_heads * d_head
    batch_size, seq_len, _ = q.shape
    
    # Create projections
    Q_w = nn.Linear(d_model, d_model, bias=False)
    K_w = nn.Linear(d_model, kv_dim, bias=False)
    V_w = nn.Linear(d_model, kv_dim, bias=False)
    W_out = nn.Linear(d_model, d_model, bias=False)
    
    if weights is not None:
        Q_w.weight.data = weights['q_weight']
        K_w.weight.data = weights['k_weight']
        V_w.weight.data = weights['v_weight']
        W_out.weight.data = weights['out_weight']
    
    # Project
    Q = Q_w(q)
    K = K_w(k)
    V = V_w(v)
    
    # Reshape to heads
    Q = Q.view(batch_size, seq_len, num_query_heads, d_head).transpose(1, 2)
    K = K.view(batch_size, k.size(1), num_kv_heads, d_head).transpose(1, 2)  # k may differ in length from q
    V = V.view(batch_size, v.size(1), num_kv_heads, d_head).transpose(1, 2)
    
    # Expand K, V
    repeat_factor = num_query_heads // num_kv_heads
    K = K.repeat_interleave(repeat_factor, dim=1)
    V = V.repeat_interleave(repeat_factor, dim=1)
    
    # Attention
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_head ** 0.5)
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)
    
    # Concatenate and project
    output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
    return W_out(output)

In [None]:
# Test: GQA degenerating to MHA
torch.manual_seed(42)

test_d_model = 64
test_num_heads = 4
x = torch.randn(batch_size, seq_len, test_d_model)

# Create PyTorch MHA reference
mha = torch.nn.MultiheadAttention(
    embed_dim=test_d_model, num_heads=test_num_heads, bias=False, batch_first=True
)

# Extract weights
weights = {
    'q_weight': mha.in_proj_weight[:test_d_model, :],
    'k_weight': mha.in_proj_weight[test_d_model:2*test_d_model, :],
    'v_weight': mha.in_proj_weight[2*test_d_model:, :],
    'out_weight': mha.out_proj.weight
}

# GQA with equal query and KV heads = MHA
output_gqa = grouped_query_attention(
    x, x, x,
    num_query_heads=test_num_heads,
    num_kv_heads=test_num_heads,  # Same as query heads = MHA!
    d_model=test_d_model,
    weights=weights
)

output_mha, _ = mha(x, x, x)

assert torch.allclose(output_gqa, output_mha, atol=1e-6), "GQA doesn't match MHA!"
print("\u2713 GQA matches MHA when num_query_heads == num_kv_heads")
print(f"Max difference: {(output_gqa - output_mha).abs().max().item():.2e}")

## Part 4: GQA with KV Cache

### Why GQA Dramatically Reduces KV Cache Memory

The cache only needs to store K and V for `num_kv_heads`, not `num_query_heads`!

**Cache shape comparison:**
- MHA: `(batch, num_query_heads, seq_len, head_dim)`
- GQA: `(batch, num_kv_heads, seq_len, head_dim)` - smaller!

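Using LLaMA-2 70B's dimensions (80 layers, head_dim 128) with an 8K-token sequence in fp16 and batch size 1:
- MHA-style cache (if all 64 heads had their own K/V): ~20 GiB
- GQA cache (8 KV heads): ~2.5 GiB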
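Caching only the compact `num_kv_heads` tensors loses nothing, because K/V are expanded just before each attention step. The self-contained sketch below is a minimal functional re-implementation (not the `GQAWithCache` class itself): it decodes one token at a time with a compact cache and checks the result against a single causal pass over the full sequence. Sizes are arbitrary; the causal mask is offset so that the last `q_len` queries line up with the end of the keys.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, d_model, n_q, n_kv, total_len = 2, 32, 4, 2, 6
d_head, group = d_model // n_q, n_q // n_kv
W_q = torch.nn.Linear(d_model, d_model, bias=False)
W_k = torch.nn.Linear(d_model, n_kv * d_head, bias=False)   # compact K projection
W_v = torch.nn.Linear(d_model, n_kv * d_head, bias=False)   # compact V projection

def split(t, n):
    # (batch, seq, n * d_head) -> (batch, n, seq, d_head)
    return t.view(t.size(0), t.size(1), n, d_head).transpose(1, 2)

def attend(Q, K, V, causal):
    # K, V arrive compact (n_kv heads) and are expanded only for the matmuls
    K, V = K.repeat_interleave(group, dim=1), V.repeat_interleave(group, dim=1)
    scores = Q @ K.transpose(-2, -1) / d_head ** 0.5
    if causal:
        # Query i sits at absolute position (k_len - q_len + i); mask keys after it
        q_len, k_len = Q.size(2), K.size(2)
        mask = torch.ones(q_len, k_len, dtype=torch.bool).triu(k_len - q_len + 1)
        scores = scores.masked_fill(mask, float("-inf"))
    return F.softmax(scores, dim=-1) @ V

x = torch.randn(batch, total_len, d_model)

with torch.no_grad():
    # Reference: every position at once with a causal mask
    full = attend(split(W_q(x), n_q), split(W_k(x), n_kv), split(W_v(x), n_kv), causal=True)

    # Incremental: one token per step, appending to a (batch, n_kv, t, d_head) cache
    cache_k = torch.empty(batch, n_kv, 0, d_head)
    cache_v = torch.empty(batch, n_kv, 0, d_head)
    steps = []
    for t in range(total_len):
        xt = x[:, t : t + 1]
        cache_k = torch.cat([cache_k, split(W_k(xt), n_kv)], dim=2)
        cache_v = torch.cat([cache_v, split(W_v(xt), n_kv)], dim=2)
        # A single new token may attend to everything cached so far, so no mask is needed
        steps.append(attend(split(W_q(xt), n_q), cache_k, cache_v, causal=False))
    incremental = torch.cat(steps, dim=2)

print(cache_k.shape)                                 # torch.Size([2, 2, 6, 8])
print(torch.allclose(full, incremental, atol=1e-5))
```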

In [None]:
class GQAWithCache(nn.Module):
    """
    Grouped Query Self-Attention with KV Cache for efficient inference.
    
    Takes a single input x (the residual stream) and projects it to Q, K, V.
    The cache stores K and V with fewer heads than Q.
    """
    
    def __init__(self, d_model: int, num_query_heads: int, num_kv_heads: int):
        super().__init__()
        assert d_model % num_query_heads == 0
        assert num_query_heads % num_kv_heads == 0
        
        self.d_model = d_model
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_model // num_query_heads
        self.kv_dim = num_kv_heads * self.head_dim
        
        # Projections
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, self.kv_dim, bias=False)  # Smaller!
        self.v_proj = nn.Linear(d_model, self.kv_dim, bias=False)  # Smaller!
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
    
    def forward(
        self,
        x: torch.Tensor,
        cache_k: torch.Tensor = None,
        cache_v: torch.Tensor = None,
        is_causal: bool = True,
    ) -> tuple:
        """
        Args:
            x: Input tensor (batch, seq_len, d_model) - the residual stream
            cache_k: Cached keys (batch, num_kv_heads, cached_len, head_dim)
            cache_v: Cached values (batch, num_kv_heads, cached_len, head_dim)
            is_causal: Whether to apply causal masking
        
        Returns:
            output: GQA output (batch, seq_len, d_model)
            new_cache_k: Updated key cache
            new_cache_v: Updated value cache
        """
        batch_size, seq_len_q, _ = x.shape
        cache_len = cache_k.size(2) if cache_k is not None else 0
        
        # Project x to Q, K, V
        Q = self.q_proj(x)
        K = self.k_proj(x)  # Smaller: (batch, seq, kv_dim)
        V = self.v_proj(x)  # Smaller: (batch, seq, kv_dim)
        
        # Reshape to heads
        Q = Q.view(batch_size, seq_len_q, self.num_query_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len_q, self.num_kv_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len_q, self.num_kv_heads, self.head_dim).transpose(1, 2)
        
        # Concatenate with cache (smaller cache!)
        if cache_k is not None and cache_v is not None:
            K = torch.cat([cache_k, K], dim=2)
            V = torch.cat([cache_v, V], dim=2)
        
        # Update cache (store the smaller K, V)
        new_cache_k = K
        new_cache_v = V
        
        seq_len_k = K.size(2)
        
        # Expand K, V to match query heads for attention
        repeat_factor = self.num_query_heads // self.num_kv_heads
        K_expanded = K.repeat_interleave(repeat_factor, dim=1)
        V_expanded = V.repeat_interleave(repeat_factor, dim=1)
        
        # Attention
        scores = torch.matmul(Q, K_expanded.transpose(-2, -1)) / (self.head_dim ** 0.5)
        
        # Apply causal mask with cache offset
        if is_causal:
            mask = create_causal_mask_with_cache(
                seq_len_q=seq_len_q,
                seq_len_k=seq_len_k,
                cache_len=cache_len,
                device=x.device
            )
            scores = scores.masked_fill(mask, float('-inf'))
        
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V_expanded)
        
        # Reshape back
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len_q, self.d_model)
        output = self.out_proj(output)
        
        return output, new_cache_k, new_cache_v

In [None]:
# Test GQA with KV Cache
print("=== Testing GQA with KV Cache ===")

torch.manual_seed(42)
batch_size = 2
d_model = 64
num_query_heads = 8
num_kv_heads = 2  # 4x fewer KV heads!

gqa_cached = GQAWithCache(d_model, num_query_heads, num_kv_heads)

# Step 1: Process prompt (3 tokens)
# x is the residual stream - Q, K, V all come from projections of x
prompt = torch.randn(batch_size, 3, d_model)
print(f"\nInput prompt shape: {prompt.shape}")
print(f"Q will have {num_query_heads} heads, K/V will have {num_kv_heads} heads")

out1, cache_k, cache_v = gqa_cached(prompt, None, None)
print(f"\nAfter prompt: cache shape = {cache_k.shape}")
print(f"  Note: only {num_kv_heads} KV heads cached, not {num_query_heads}!")

# Step 2: Generate token 4
new_token = torch.randn(batch_size, 1, d_model)
out2, cache_k, cache_v = gqa_cached(new_token, cache_k, cache_v)
print(f"After token 4: cache shape = {cache_k.shape}")

# Step 3: Generate token 5
new_token = torch.randn(batch_size, 1, d_model)
out3, cache_k, cache_v = gqa_cached(new_token, cache_k, cache_v)
print(f"After token 5: cache shape = {cache_k.shape}")

# Verify
head_dim = d_model // num_query_heads
assert cache_k.shape == (batch_size, num_kv_heads, 5, head_dim)
assert out3.shape == (batch_size, 1, d_model)

print("\n\u2713 GQA with KV Cache test passed!")

In [None]:
# Memory comparison: GQA vs MHA
print("\n" + "="*60)
print("Memory Comparison: GQA vs MHA")
print("="*60)

mha_cache = batch_size * num_query_heads * 5 * head_dim * 2  # K + V
gqa_cache = batch_size * num_kv_heads * 5 * head_dim * 2     # K + V

print(f"MHA cache elements: {mha_cache:,} ({num_query_heads} heads)")
print(f"GQA cache elements: {gqa_cache:,} ({num_kv_heads} heads)")
print(f"Memory reduction: {mha_cache / gqa_cache:.1f}x")

print("\n" + "="*60)
print("Real-World Scaling (LLaMA-2 70B, 8K context)")
print("="*60)
print("MHA: 64 heads x 8K x 128 dim x 2 (K+V) x 80 layers x 2 bytes = 20 GiB (batch 1, fp16)")
print("GQA:  8 heads x 8K x 128 dim x 2 (K+V) x 80 layers x 2 bytes = 2.5 GiB (batch 1, fp16)")
print("Savings: 8x memory reduction!")

## Interview Tips

**Q: What's the difference between MHA, GQA, and MQA?**
A:
- MHA: Each query head has its own K and V heads (1:1 ratio)
- GQA: Each group of query heads shares one K/V head (e.g., 8:1 ratio)
- MQA: All query heads share a single K/V head (num_query_heads:1 ratio)

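**Q: Why does GQA reduce memory but not compute?**
A: GQA shrinks the K/V projections, which saves a little compute, but the attention itself is not cheaper: K/V are expanded to `num_query_heads` before the score and value matmuls, so those cost the same FLOPs as in MHA. The main savings are in KV cache memory, and because autoregressive decoding is usually limited by the memory bandwidth spent reading that cache, a smaller cache also speeds up decoding in practice.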
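The compute point can be made concrete: the `repeat_interleave` copy is not even necessary. Folding each group of query heads into the query-length axis lets every KV head be multiplied against its `group` query heads at once. The sketch below checks this grouped formulation against the `repeat_interleave` version; the FLOP count is identical, only the K/V copies disappear. The layout trick is an illustration, not a claim about how any particular library implements GQA, and a causal `(seq_q, seq_k)` mask would have to be tiled `group` times along the query axis (e.g., `mask.repeat(group, 1)`).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_q, n_kv, seq_q, seq_k, d_head = 2, 8, 2, 3, 7, 16
group = n_q // n_kv
Q = torch.randn(batch, n_q, seq_q, d_head)
K = torch.randn(batch, n_kv, seq_k, d_head)
V = torch.randn(batch, n_kv, seq_k, d_head)

# Reference: materialize expanded K/V with repeat_interleave, as in this notebook
K_rep = K.repeat_interleave(group, dim=1)
V_rep = V.repeat_interleave(group, dim=1)
ref = F.softmax(Q @ K_rep.transpose(-2, -1) / d_head ** 0.5, dim=-1) @ V_rep

# Grouped: query head h = kv * group + g, so a reshape places each group's queries
# side by side along the length axis: (batch, n_kv, group * seq_q, d_head)
Q_folded = Q.reshape(batch, n_kv, group * seq_q, d_head)        # a view, since Q is contiguous
scores = Q_folded @ K.transpose(-2, -1) / d_head ** 0.5         # (batch, n_kv, group * seq_q, seq_k)
out = (F.softmax(scores, dim=-1) @ V).reshape(batch, n_q, seq_q, d_head)

print(torch.allclose(ref, out, atol=1e-5))
```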

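**Q: When should you use GQA vs MHA?**
A: GQA is preferred for large models served with long contexts or large batches, where KV cache memory is the bottleneck. For small models or short contexts the cache stays small, and MHA is fine. The choice is made at training time: an MHA checkpoint has to be converted (the GQA paper mean-pools the K/V heads within each group and then uptrains) before it can run as GQA.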

**Q: How do you choose the number of KV heads?**
A: Common ratios are 4:1 to 8:1 (query:KV heads). LLaMA-2 70B uses 64:8, Mistral uses 32:8. The ratio depends on the tradeoff between quality and memory.

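**Q: Does GQA hurt model quality?**
A: Slightly, but the tradeoff is usually worthwhile. In the original GQA paper (Ainslie et al., 2023), GQA models reached quality close to MHA while decoding nearly as fast as MQA; at common 4:1 to 8:1 ratios, the KV cache is 4-8x smaller.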