# Grouped Query Attention (GQA): Memory-Efficient Inference

This notebook explains **Grouped Query Attention (GQA)**, the memory optimization used in Llama 2 70B, Llama 3, Mistral 7B, and other modern LLMs to reduce KV-cache size during inference.

## The Problem: KV-Cache Memory Explosion

In autoregressive decoding, we cache keys and values from previous tokens to avoid recomputing them:

```python
# Without cache: step t re-runs the model on all t+1 tokens, O(T²) token passes in total
for t in range(max_len):
    logits = model(tokens[:t+1])  # Processes all previous tokens

# With cache: step t runs only the new token against cached K/V, O(T) token passes in total
cache = model.make_cache()
for t in range(max_len):
    logits, cache = model(tokens[t:t+1], kv_cache=cache)  # Only new token
```
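To make the difference concrete, the sketch below counts how many token positions pass through the model over a whole generation. The function names are illustrative and not part of `atomiclm`. It counts positions only and ignores the attention cost at each position.

```python
def positions_processed_without_cache(num_steps: int) -> int:
    """Step t re-runs the model on all t + 1 tokens seen so far."""
    return sum(t + 1 for t in range(num_steps))


def positions_processed_with_cache(num_steps: int) -> int:
    """Step t runs the model on the single new token only."""
    return num_steps


T = 2048
no_cache = positions_processed_without_cache(T)
with_cache = positions_processed_with_cache(T)
print(f"without cache: {no_cache:,} positions")   # 2,098,176
print(f"with cache:    {with_cache:,} positions")  # 2,048
print(f"ratio:         {no_cache / with_cache:.1f}x")  # 1024.5x
```

The price of this saving is the memory needed to hold the cached keys and values.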

**Cache memory per layer:**
```
K cache: (batch, num_heads, seq_len, head_dim)
V cache: (batch, num_heads, seq_len, head_dim)
```

**Example:** 8 heads, seq_len=2048, head_dim=64, fp32
- K cache: 1 × 8 × 2048 × 64 × 4 bytes = **4 MB**
- V cache: 1 × 8 × 2048 × 64 × 4 bytes = **4 MB**
- **Total per layer: 8 MB**
- **12 layers: 96 MB** (just for one sequence!)

This scales linearly with context length: at 32k context, the same 12-layer model needs **1.5 GB per sequence**.

## The Insight: Do Queries and Keys Need the Same Number of Heads?

Standard Multi-Head Attention (MHA):
- 8 query heads → 8 key heads → 8 value heads
- Each query head attends to its own dedicated K/V head

**GQA asks:** What if multiple query heads *share* the same K/V heads?

**Grouped Query Attention:**
- 8 query heads → **2 key heads** → 2 value heads
- 4 query heads per group, each group shares one K/V head
- **Cache memory: 4× smaller** (2 KV heads instead of 8)

**Multi-Query Attention (MQA):** Extreme case with num_kv_heads=1
- 8 query heads → **1 key head** → 1 value head
- All queries share a single K/V head
- **Cache memory: 8× smaller**
- Used in PaLM and Falcon 7B (but can hurt quality)
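The sharing pattern in all three variants comes down to one integer division. The sketch below uses a hypothetical helper, `kv_head_for_query_head`, and assumes that consecutive query heads share a K/V head, which is the layout the notebook's `_repeat_kv` produces.

```python
def kv_head_for_query_head(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Return the index of the K/V head that query head `q_head` reads from."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    queries_per_kv_head = num_heads // num_kv_heads
    return q_head // queries_per_kv_head


num_heads = 8
for name, num_kv_heads in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    mapping = [kv_head_for_query_head(h, num_heads, num_kv_heads) for h in range(num_heads)]
    print(f"{name}: {mapping}")
# MHA: [0, 1, 2, 3, 4, 5, 6, 7]
# GQA: [0, 0, 0, 0, 1, 1, 1, 1]
# MQA: [0, 0, 0, 0, 0, 0, 0, 0]
```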

In [1]:
import torch

# Import our GQA-enabled attention
import sys
sys.path.append('..')
from atomiclm.model.attention import MultiHeadAttention
from atomiclm.model.decoder import Decoder

torch.manual_seed(42)

<torch._C.Generator at 0x10fb51870>

## Three Attention Variants, One Implementation

Our `MultiHeadAttention` supports all three via a single `num_kv_heads` parameter:

| Variant | `num_kv_heads` | Example (8 heads) | Cache Size | Quality |
|---------|----------------|-------------------|------------|---------|
| **MHA** | `num_heads` | 8 Q, 8 KV | 100% | Best |
| **GQA** | `1 < x < num_heads` | 8 Q, 2 KV | 25% | ~MHA |
| **MQA** | `1` | 8 Q, 1 KV | 12.5% | Good |

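**Key constraint:** `num_heads` must be divisible by `num_kv_heads` so that the query heads split evenly across the KV heads. In this implementation, `num_groups` is the number of query heads per KV head (`num_heads // num_kv_heads`), not the number of KV heads.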

In [2]:
d_model = 512
num_heads = 8
head_dim = d_model // num_heads  # 64

# Multi-Head Attention (standard)
mha = MultiHeadAttention(
    d_in=d_model,
    d_out=d_model,
    num_heads=num_heads,
    num_kv_heads=num_heads,  # Same as num_heads
)

# Grouped Query Attention (2 KV heads)
gqa_2 = MultiHeadAttention(
    d_in=d_model,
    d_out=d_model,
    num_heads=num_heads,
    num_kv_heads=2,  # 8 queries share 2 KV heads
)

# Grouped Query Attention (4 KV heads)
gqa_4 = MultiHeadAttention(
    d_in=d_model,
    d_out=d_model,
    num_heads=num_heads,
    num_kv_heads=4,  # 8 queries share 4 KV heads
)

# Multi-Query Attention (single KV head)
mqa = MultiHeadAttention(
    d_in=d_model,
    d_out=d_model,
    num_heads=num_heads,
    num_kv_heads=1,  # All 8 queries share 1 KV head
)

print("MHA:")
print(f"  num_heads={mha.num_heads}, num_kv_heads={mha.num_kv_heads}, num_groups={mha.num_groups}")
print("\nGQA (2 KV heads):")
print(f"  num_heads={gqa_2.num_heads}, num_kv_heads={gqa_2.num_kv_heads}, num_groups={gqa_2.num_groups}")
print("\nGQA (4 KV heads):")
print(f"  num_heads={gqa_4.num_heads}, num_kv_heads={gqa_4.num_kv_heads}, num_groups={gqa_4.num_groups}")
print("\nMQA:")
print(f"  num_heads={mqa.num_heads}, num_kv_heads={mqa.num_kv_heads}, num_groups={mqa.num_groups}")
print("\nnum_groups = queries per KV head")

MHA:
  num_heads=8, num_kv_heads=8, num_groups=1

GQA (2 KV heads):
  num_heads=8, num_kv_heads=2, num_groups=4

GQA (4 KV heads):
  num_heads=8, num_kv_heads=4, num_groups=2

MQA:
  num_heads=8, num_kv_heads=1, num_groups=8

num_groups = queries per KV head


## How GQA Works: Implementation Details

### 1. Separate Q/K/V Projections

MHA can use a single fused QKV projection, because Q, K, and V all have the same width:
```python
# MHA: single projection
qkv = self.qkv_proj(x)  # (b, t, 3 * d_out)
q, k, v = split(qkv)    # Each: (b, num_heads, t, head_dim)
```

GQA uses separate projections with different output sizes:
```python
# GQA: separate projections
q = self.q_proj(x)  # (b, t, num_heads * head_dim)
k = self.k_proj(x)  # (b, t, num_kv_heads * head_dim)  ← smaller!
v = self.v_proj(x)  # (b, t, num_kv_heads * head_dim)  ← smaller!
```
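A small runnable sketch of the separate projections, to show the shapes and the parameter saving. The three `nn.Linear` layers are stand-ins for the module's projections, not the actual `atomiclm` code, and bias is omitted as an assumption.

```python
import torch
from torch import nn

d_model, num_heads, num_kv_heads = 512, 8, 2
head_dim = d_model // num_heads  # 64

# Hypothetical stand-ins for the module's projections (bias omitted for brevity).
q_proj = nn.Linear(d_model, num_heads * head_dim, bias=False)
k_proj = nn.Linear(d_model, num_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(d_model, num_kv_heads * head_dim, bias=False)

x = torch.randn(1, 10, d_model)  # (b, t, d_model)
b, t, _ = x.shape

# Project, then split the last dimension into heads: (b, heads, t, head_dim)
q = q_proj(x).view(b, t, num_heads, head_dim).transpose(1, 2)
k = k_proj(x).view(b, t, num_kv_heads, head_dim).transpose(1, 2)
v = v_proj(x).view(b, t, num_kv_heads, head_dim).transpose(1, 2)

param_ratio = k_proj.weight.numel() / q_proj.weight.numel()
print(q.shape)      # torch.Size([1, 8, 10, 64])
print(k.shape)      # torch.Size([1, 2, 10, 64])
print(param_ratio)  # 0.25
```

Besides shrinking the cache, the smaller K and V projections carry fewer weights: here each has a quarter of the query projection's parameters.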

### 2. KV Expansion Before Attention

We store K/V in compressed form but expand before attention:
```python
# After projection: k, v are (b, num_kv_heads, t, head_dim)

# Expand each KV head to match num_groups query heads
k = repeat_kv(k, num_groups)  # (b, num_heads, t, head_dim)
v = repeat_kv(v, num_groups)  # (b, num_heads, t, head_dim)

# Now attention is identical to MHA (causal mask omitted for brevity)
scores = q @ k.transpose(-2, -1) / head_dim ** 0.5  # scale by sqrt(head_dim)
attn = softmax(scores, dim=-1)
output = attn @ v
```
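A self-contained sketch of this step, including the `1/sqrt(head_dim)` scaling and the causal mask a decoder needs. The function names are illustrative. The second function is an alternative not taken from the notebook: it broadcasts K/V over a group axis instead of copying them, and should give the same result.

```python
import math
import torch


def causal_softmax(scores):
    """Softmax over keys, with future positions masked out."""
    t = scores.shape[-1]
    causal = torch.ones(t, t, dtype=torch.bool).tril()
    return torch.softmax(scores.masked_fill(~causal, float("-inf")), dim=-1)


def gqa_attention_repeat(q, k, v):
    """q: (b, num_heads, t, d); k, v: (b, num_kv_heads, t, d)."""
    num_groups = q.shape[1] // k.shape[1]  # query heads per K/V head
    k = k.repeat_interleave(num_groups, dim=1)
    v = v.repeat_interleave(num_groups, dim=1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return causal_softmax(scores) @ v


def gqa_attention_grouped(q, k, v):
    """Same computation, with K/V broadcast over a group axis instead of copied."""
    b, num_heads, t, d = q.shape
    num_kv_heads = k.shape[1]
    q = q.reshape(b, num_kv_heads, num_heads // num_kv_heads, t, d)
    k = k.unsqueeze(2)  # (b, num_kv_heads, 1, t, d)
    v = v.unsqueeze(2)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)
    out = causal_softmax(scores) @ v  # (b, num_kv_heads, num_groups, t, d)
    return out.reshape(b, num_heads, t, d)


torch.manual_seed(0)
q = torch.randn(2, 8, 5, 16)
k = torch.randn(2, 2, 5, 16)
v = torch.randn(2, 2, 5, 16)
out_repeat = gqa_attention_repeat(q, k, v)
out_grouped = gqa_attention_grouped(q, k, v)
print(out_repeat.shape)
print(torch.allclose(out_repeat, out_grouped, atol=1e-5))
```

The repeat version is simpler to read. The grouped version avoids materializing `num_groups` copies of K and V, which can matter for long sequences.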

### 3. The `_repeat_kv` Helper

```python
@staticmethod
def _repeat_kv(x: Tensor, num_groups: int) -> Tensor:
    """Expand (b, num_kv_heads, t, d) to (b, num_heads, t, d)"""
    if num_groups == 1:  # MHA case, no expansion needed
        return x
    
    b, num_kv_heads, t, d_h = x.shape
    # Unsqueeze: (b, num_kv_heads, 1, t, d_h)
    # Expand:   (b, num_kv_heads, num_groups, t, d_h)
    x = x.unsqueeze(2).expand(b, num_kv_heads, num_groups, t, d_h)
    # Reshape:  (b, num_kv_heads * num_groups, t, d_h)
    return x.reshape(b, num_kv_heads * num_groups, t, d_h)
```

**Example:** 8 query heads, 2 KV heads
- Input K: `(1, 2, 100, 64)` — 2 KV heads
- After expansion: `(1, 8, 100, 64)` — 8 heads (each KV head repeated 4 times)
- Query heads 0-3 all use the same K[0], query heads 4-7 use K[1]
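
The order of the repeated heads matters. As a sketch, the unsqueeze/expand/reshape pattern above should match `torch.repeat_interleave` along the head dimension, whereas `Tensor.repeat` tiles the heads in a different order and would pair query heads with the wrong K/V head.

```python
import torch

num_groups = 4
# Two KV heads, filled with 0s and 1s so each head is easy to identify.
kv = torch.stack([torch.zeros(3, 2), torch.ones(3, 2)]).unsqueeze(0)  # (1, 2, 3, 2)

b, num_kv_heads, t, d = kv.shape
expanded = kv.unsqueeze(2).expand(b, num_kv_heads, num_groups, t, d).reshape(b, -1, t, d)
interleaved = kv.repeat_interleave(num_groups, dim=1)
tiled = kv.repeat(1, num_groups, 1, 1)

print(expanded[0, :, 0, 0].tolist())     # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
print(interleaved[0, :, 0, 0].tolist())  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
print(tiled[0, :, 0, 0].tolist())        # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```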

In [3]:
# Demonstrate repeat_kv
batch = 1
num_kv_heads = 2
seq_len = 4
head_dim = 8

# Create dummy KV tensor with distinct values per head
kv = torch.arange(batch * num_kv_heads * seq_len * head_dim, dtype=torch.float32)
kv = kv.reshape(batch, num_kv_heads, seq_len, head_dim)

print("Original KV (2 heads):")
print(f"Shape: {kv.shape}")
print(f"Head 0, position 0: {kv[0, 0, 0, :4]}")
print(f"Head 1, position 0: {kv[0, 1, 0, :4]}")

# Expand by num_groups=4 (to get 8 total heads)
num_groups = 4
kv_expanded = MultiHeadAttention._repeat_kv(kv, num_groups)

print(f"\nExpanded KV (8 heads):")
print(f"Shape: {kv_expanded.shape}")
print(f"\nHeads 0-3 should match original head 0:")
for h in range(4):
    print(f"  Head {h}, pos 0: {kv_expanded[0, h, 0, :4]}")
print(f"\nHeads 4-7 should match original head 1:")
for h in range(4, 8):
    print(f"  Head {h}, pos 0: {kv_expanded[0, h, 0, :4]}")

print("\n✅ Each original KV head is repeated num_groups times")

Original KV (2 heads):
Shape: torch.Size([1, 2, 4, 8])
Head 0, position 0: tensor([0., 1., 2., 3.])
Head 1, position 0: tensor([32., 33., 34., 35.])

Expanded KV (8 heads):
Shape: torch.Size([1, 8, 4, 8])

Heads 0-3 should match original head 0:
  Head 0, pos 0: tensor([0., 1., 2., 3.])
  Head 1, pos 0: tensor([0., 1., 2., 3.])
  Head 2, pos 0: tensor([0., 1., 2., 3.])
  Head 3, pos 0: tensor([0., 1., 2., 3.])

Heads 4-7 should match original head 1:
  Head 4, pos 0: tensor([32., 33., 34., 35.])
  Head 5, pos 0: tensor([32., 33., 34., 35.])
  Head 6, pos 0: tensor([32., 33., 34., 35.])
  Head 7, pos 0: tensor([32., 33., 34., 35.])

✅ Each original KV head is repeated num_groups times


## Memory Comparison: MHA vs GQA vs MQA

Let's compute actual memory usage for different configurations.

In [4]:
def cache_memory_mb(batch_size, num_kv_heads, seq_len, head_dim, num_layers, dtype=torch.float32):
    """
    Calculate KV-cache memory in MB.
    
    Cache shape per layer: 2 × (batch, num_kv_heads, seq_len, head_dim)
    """
    bytes_per_element = torch.empty((), dtype=dtype).element_size()  # 4 for fp32, 2 for fp16/bf16
    elements_per_layer = 2 * batch_size * num_kv_heads * seq_len * head_dim
    bytes_per_layer = elements_per_layer * bytes_per_element
    total_bytes = bytes_per_layer * num_layers
    return total_bytes / (1024 ** 2)  # Convert to MB

# Configuration
batch_size = 1
num_heads = 8
head_dim = 64
num_layers = 12
seq_len = 2048

configs = [
    ("MHA", num_heads),
    ("GQA-4", 4),
    ("GQA-2", 2),
    ("MQA", 1),
]

print(f"Configuration: {num_heads} heads, {head_dim}-dim, {num_layers} layers, seq_len={seq_len}\n")
print(f"{'Variant':<10} {'KV Heads':<10} {'Memory (MB)':<15} {'vs MHA':<10}")
print("="*50)

mha_memory = None
for name, num_kv_heads in configs:
    memory = cache_memory_mb(batch_size, num_kv_heads, seq_len, head_dim, num_layers)
    if mha_memory is None:
        mha_memory = memory
        ratio = "baseline"
    else:
        ratio = f"{mha_memory / memory:.1f}× smaller"
    print(f"{name:<10} {num_kv_heads:<10} {memory:<15.2f} {ratio}")

print(f"\n✅ GQA-2 uses 4× less memory than MHA")
print(f"✅ MQA uses 8× less memory than MHA")

Configuration: 8 heads, 64-dim, 12 layers, seq_len=2048

Variant    KV Heads   Memory (MB)     vs MHA    
MHA        8          96.00           baseline
GQA-4      4          48.00           2.0× smaller
GQA-2      2          24.00           4.0× smaller
MQA        1          12.00           8.0× smaller

✅ GQA-2 uses 4× less memory than MHA
✅ MQA uses 8× less memory than MHA


In [5]:
# Memory scaling with context length
seq_lengths = [512, 1024, 2048, 4096, 8192, 16384, 32768]

mha_mem = [cache_memory_mb(1, 8, s, 64, 12) for s in seq_lengths]
gqa_2_mem = [cache_memory_mb(1, 2, s, 64, 12) for s in seq_lengths]
mqa_mem = [cache_memory_mb(1, 1, s, 64, 12) for s in seq_lengths]

print(f"{'Context':<10} {'MHA (MB)':<12} {'GQA-2 (MB)':<12} {'MQA (MB)':<12}")
print("=" * 46)
for s, m, g, q in zip(seq_lengths, mha_mem, gqa_2_mem, mqa_mem):
    print(f"{s:<10} {m:<12.1f} {g:<12.1f} {q:<12.1f}")

print(f"\nAt 32k context:")
print(f"  MHA: {mha_mem[-1]:.1f} MB")
print(f"  GQA-2: {gqa_2_mem[-1]:.1f} MB (4x reduction)")
print(f"  MQA: {mqa_mem[-1]:.1f} MB (8x reduction)")

Context    MHA (MB)     GQA-2 (MB)   MQA (MB)    
512        24.0         6.0          3.0         
1024       48.0         12.0         6.0         
2048       96.0         24.0         12.0        
4096       192.0        48.0         24.0        
8192       384.0        96.0         48.0        
16384      768.0        192.0        96.0        
32768      1536.0       384.0        192.0       

At 32k context:
  MHA: 1536.0 MB
  GQA-2: 384.0 MB (4x reduction)
  MQA: 192.0 MB (8x reduction)


## Correctness: GQA Cached Generation

Verify that GQA with a KV-cache produces the same output as a full forward pass, up to floating-point rounding.

In [6]:
# Test GQA cached vs full forward
torch.manual_seed(0)

d_in = 128
d_out = 128
num_heads = 8
num_kv_heads = 2
seq_len = 16

gqa = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    num_heads=num_heads,
    num_kv_heads=num_kv_heads,
    dropout=0.0,
)
gqa.eval()

x = torch.randn(1, seq_len, d_in)

# Full sequence forward
with torch.no_grad():
    full_out, _ = gqa(x)

# Token-by-token with cache
cache = {
    "k": torch.zeros(1, num_kv_heads, seq_len, d_out // num_heads),
    "v": torch.zeros(1, num_kv_heads, seq_len, d_out // num_heads),
    "pos": 0,
}

cached_outputs = []
with torch.no_grad():
    for i in range(seq_len):
        token = x[:, i:i+1, :]
        out, cache = gqa(token, kv_cache=cache)
        cached_outputs.append(out)

cached_out = torch.cat(cached_outputs, dim=1)

# Compare
max_diff = (full_out - cached_out).abs().max().item()
print(f"Full forward output shape: {full_out.shape}")
print(f"Cached output shape: {cached_out.shape}")
print(f"\nMax difference: {max_diff:.2e}")
print(f"Outputs match: {torch.allclose(full_out, cached_out, atol=1e-6)}")
print("\n✅ GQA with KV-cache produces identical output to full forward")

Full forward output shape: torch.Size([1, 16, 128])
Cached output shape: torch.Size([1, 16, 128])

Max difference: 2.98e-07
Outputs match: True

✅ GQA with KV-cache produces identical output to full forward
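
The same check can be sketched without the `atomiclm` module. The version below skips the projections and works directly on random Q/K/V tensors, so it only illustrates why the cached and full paths agree: at each step the new query sees exactly the keys that the causal mask would allow. The `attend` helper and tensor sizes are assumptions made for this example.

```python
import math
import torch

torch.manual_seed(0)
b, num_heads, num_kv_heads, t, d = 1, 8, 2, 6, 16
num_groups = num_heads // num_kv_heads  # query heads per K/V head

q = torch.randn(b, num_heads, t, d)
k = torch.randn(b, num_kv_heads, t, d)
v = torch.randn(b, num_kv_heads, t, d)


def attend(q, k, v, mask=None):
    """Scaled dot-product attention with K/V expanded to match the query heads."""
    k = k.repeat_interleave(num_groups, dim=1)
    v = v.repeat_interleave(num_groups, dim=1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


# Full pass: all positions at once, with a causal mask.
causal = torch.ones(t, t, dtype=torch.bool).tril()
full_out = attend(q, k, v, causal)

# Cached pass: preallocate the compact cache, then fill it one position at a time.
k_cache = torch.zeros(b, num_kv_heads, t, d)
v_cache = torch.zeros(b, num_kv_heads, t, d)
steps = []
for pos in range(t):
    k_cache[:, :, pos] = k[:, :, pos]
    v_cache[:, :, pos] = v[:, :, pos]
    # The new query sees only the filled prefix, so no mask is needed.
    steps.append(attend(q[:, :, pos:pos + 1], k_cache[:, :, :pos + 1], v_cache[:, :, :pos + 1]))
cached_out = torch.cat(steps, dim=2)

print(torch.allclose(full_out, cached_out, atol=1e-5))
```

Note that the cache holds `num_kv_heads` heads, not `num_heads`. The expansion happens inside `attend`, only for the duration of the computation.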


## Using GQA in the Full Decoder

Let's build two small decoders, one with MHA and one with GQA, and compare their KV-cache sizes.

In [7]:
# Create two decoders: MHA and GQA
vocab_size = 1024
d_model = 256
num_layers = 6
num_heads = 8
d_ff = 1024
max_seq_len = 1024

torch.manual_seed(42)
decoder_mha = Decoder(
    vocab_size=vocab_size,
    d_model=d_model,
    num_layers=num_layers,
    num_heads=num_heads,
    num_kv_heads=num_heads,  # MHA
    d_ff=d_ff,
    max_seq_len=max_seq_len,
)

torch.manual_seed(42)
decoder_gqa = Decoder(
    vocab_size=vocab_size,
    d_model=d_model,
    num_layers=num_layers,
    num_heads=num_heads,
    num_kv_heads=2,  # GQA with 2 KV heads
    d_ff=d_ff,
    max_seq_len=max_seq_len,
)

# Create caches
cache_mha = decoder_mha.make_cache(batch_size=1)
cache_gqa = decoder_gqa.make_cache(batch_size=1)

# Calculate memory
def cache_size_mb(cache):
    total_elements = 0
    for layer_cache in cache:
        total_elements += layer_cache["k"].numel() + layer_cache["v"].numel()
    return total_elements * 4 / (1024 ** 2)  # fp32, convert to MB

mha_size = cache_size_mb(cache_mha)
gqa_size = cache_size_mb(cache_gqa)

print(f"Decoder configuration:")
print(f"  Layers: {num_layers}")
print(f"  Heads: {num_heads}")
print(f"  d_model: {d_model}")
print(f"  Max sequence length: {max_seq_len}")
print(f"\nKV-Cache Memory:")
print(f"  MHA ({num_heads} KV heads): {mha_size:.2f} MB")
print(f"  GQA (2 KV heads): {gqa_size:.2f} MB")
print(f"  Reduction: {mha_size / gqa_size:.1f}×")
print(f"\n✅ GQA decoder uses {mha_size / gqa_size:.1f}× less cache memory")

Decoder configuration:
  Layers: 6
  Heads: 8
  d_model: 256
  Max sequence length: 1024

KV-Cache Memory:
  MHA (8 KV heads): 12.00 MB
  GQA (2 KV heads): 3.00 MB
  Reduction: 4.0×

✅ GQA decoder uses 4.0× less cache memory


In [8]:
# Smoke-test generation with the GQA decoder (weights are untrained, so the tokens are meaningless)
decoder_gqa.eval()

# Create a prompt
prompt = torch.randint(0, vocab_size, (1, 10))
print(f"Prompt shape: {prompt.shape}")
print(f"Prompt tokens: {prompt[0].tolist()[:10]}")

# Generate
torch.manual_seed(0)
with torch.no_grad():
    output = decoder_gqa.generate(prompt, max_new_tokens=20, temperature=1.0)

print(f"\nGenerated sequence shape: {output.shape}")
print(f"Generated tokens: {output[0].tolist()}")
print(f"\n✅ GQA decoder successfully generates sequences")

Prompt shape: torch.Size([1, 10])
Prompt tokens: [6, 467, 955, 710, 679, 907, 439, 81, 584, 620]

Generated sequence shape: torch.Size([1, 30])
Generated tokens: [6, 467, 955, 710, 679, 907, 439, 81, 584, 620, 620, 620, 620, 620, 620, 620, 620, 620, 620, 620, 620, 620, 620, 620, 620, 620, 620, 620, 620, 620]

✅ GQA decoder successfully generates sequences


## Production Usage: Real Models

### Models Using GQA:

**Llama 2 (7B):**
- 32 query heads, 32 KV heads = **MHA** (baseline)

**Llama 2 (70B):**
- 64 query heads, **8 KV heads** = GQA with 8 query heads per KV head
- 8× cache reduction vs MHA

**Llama 3 (8B, 70B):**
- Uses GQA across all model sizes
- 8B: 32 query heads, 8 KV heads (4 query heads per KV head)

**Mistral 7B:**
- 32 query heads, **8 KV heads** (4 query heads per KV head)

**Qwen 2:**
- All sizes use GQA
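
To put these head counts in memory terms, here is a back-of-the-envelope calculation for a Llama 2 70B-shaped model. The shape used (80 layers, head_dim 128), the fp16 cache, and the 4096-token context are assumptions for this sketch. `kv_cache_bytes` is a hypothetical helper that mirrors the notebook's `cache_memory_mb` formula.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, batch_size=1, bytes_per_element=2):
    """Bytes for K and V across all layers (the leading 2 covers K and V)."""
    return 2 * num_layers * batch_size * num_kv_heads * seq_len * head_dim * bytes_per_element


# Assumed Llama 2 70B shape: 80 layers, 64 query heads, 8 KV heads, head_dim 128.
num_layers, num_heads, num_kv_heads, head_dim = 80, 64, 8, 128
seq_len = 4096

gqa_bytes = kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len)
mha_bytes = kv_cache_bytes(num_layers, num_heads, head_dim, seq_len)  # hypothetical MHA variant

GIB = 1024 ** 3
print(f"GQA (8 KV heads):  {gqa_bytes / GIB:.2f} GiB")            # 1.25 GiB
print(f"MHA (64 KV heads): {mha_bytes / GIB:.2f} GiB")            # 10.00 GiB
print(f"Per token (GQA):   {gqa_bytes / seq_len / 1024:.0f} KiB")  # 320 KiB
```

Under these assumptions, a single 4096-token sequence costs about 1.25 GiB of cache with GQA, against 10 GiB for an MHA model of the same shape.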

### When to Use Each Variant:

| Variant | Use Case |
|---------|----------|
| **MHA** | Small models (<1B), research, when memory isn't constrained |
| **GQA** | **Production default** — best quality/memory trade-off |
| **MQA** | Extreme memory constraints, edge deployment, small quality loss acceptable |

### Choosing `num_kv_heads`:

Common choices:
- `num_kv_heads = num_heads` → MHA (baseline quality)
- `num_kv_heads = num_heads // 2` → 2× memory reduction
- `num_kv_heads = num_heads // 4` → 4× memory reduction (Llama 3 8B, Mistral 7B)
- `num_kv_heads = num_heads // 8` → 8× memory reduction (Llama 2 70B, Llama 3 70B)
- `num_kv_heads = 1` → MQA (max reduction)

**Rule of thumb:** More KV heads = better quality, more memory. Start with `num_heads // 4` for good balance.
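
If the constraint is a fixed cache budget, one way to apply this rule is to take the largest valid `num_kv_heads` that fits. The helper below is a hypothetical sketch. It considers memory only and relies on the rule of thumb above that more KV heads are better for quality.

```python
def largest_kv_heads_within_budget(num_heads, head_dim, num_layers, seq_len,
                                   budget_bytes, batch_size=1, bytes_per_element=2):
    """Return the largest valid num_kv_heads whose KV cache fits, or None."""
    # Valid choices are the divisors of num_heads, tried from largest to smallest.
    candidates = [h for h in range(num_heads, 0, -1) if num_heads % h == 0]
    for num_kv_heads in candidates:
        cache_bytes = (2 * num_layers * batch_size * num_kv_heads
                       * seq_len * head_dim * bytes_per_element)
        if cache_bytes <= budget_bytes:
            return num_kv_heads
    return None


MIB = 1024 ** 2
# Notebook-sized model: 8 heads, head_dim 64, 12 layers, 32k context, fp32 cache.
choice = largest_kv_heads_within_budget(
    num_heads=8, head_dim=64, num_layers=12, seq_len=32768,
    budget_bytes=512 * MIB, bytes_per_element=4,
)
print(choice)  # 2
```

With a 512 MB budget at 32k context, the 4-KV-head cache (768 MB) is too large and the 2-KV-head cache (384 MB) fits, so the helper returns 2.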

## Summary

**Grouped Query Attention (GQA)** reduces KV-cache memory by sharing K/V heads across multiple query heads:

### Key Concepts:
1. **Separate Q/K/V projections** — K and V use fewer heads than Q
2. **KV expansion** — Repeat each KV head `num_groups` times before attention
3. **Unified interface** — Supports MHA, GQA, and MQA via `num_kv_heads` parameter
4. **Cache storage** — Store K/V in compressed form, expand only during computation

### Memory Savings:
- **4× reduction** typical (8 query heads → 2 KV heads)
- **8× reduction** aggressive (8 query heads → 1 KV head = MQA)
- Scales linearly with context length (critical for 32k+ contexts)

### Quality Impact:
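- **Minimal degradation** reported at 4-8 query heads per KV head (the Llama 3, Mistral, and Llama 2 70B configurations)
- The perplexity gap vs MHA is typically small (on the order of 1%); this notebook does not measure it
- MQA tends to lose more quality (roughly 2-3%) but remains viable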

### Why It Matters:
- ✅ **Enables longer contexts** — 32k, 64k, 128k become feasible
- ✅ **Reduces deployment cost** — fit larger models on smaller GPUs
- ✅ **Improves throughput** — less memory movement during decoding
- ✅ **Industry standard** — used by Llama, Mistral, Qwen, and others

---

## Implementation Checklist

To add GQA to your model:
1. ✅ Add `num_kv_heads` parameter to `MultiHeadAttention`
2. ✅ Split QKV projection into separate Q, K, V projections
3. ✅ Implement `_repeat_kv()` helper for KV expansion
4. ✅ Update cache shapes to use `num_kv_heads`
5. ✅ Pass `num_kv_heads` through `TransformerBlock` and `Decoder`
6. ✅ Test: cached generation matches full forward
7. ✅ Verify: cache memory is reduced by expected factor

**Next Steps:**
- Train a model with GQA and compare perplexity to MHA baseline
- Benchmark inference speed and memory usage
- Try different `num_kv_heads` values to find optimal trade-off for your use case