# Bottleneck 2 — The KV Cache Memory Wall

## The Problem

During **autoregressive generation**, a Transformer predicts one token at a time.  
For every new token, attention is computed against **all previous tokens**.  
Without caching, every step would recompute the Key/Value projections of the entire prefix, so generating n tokens costs O(n²) projection work in total instead of O(n).

The **KV Cache** trades memory for compute: we store K and V tensors for every past token, so each step only needs to compute K/V for the *new* token.
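A minimal NumPy sketch of that trade. It uses a single head, no batching and toy random weights, and every name in it is illustrative. It checks two things: the cached path gives the same attention output as recomputing K/V from scratch, and it performs n K/V projections instead of n(n+1)/2.

```python
import numpy as np

def attend(q, K, V):
    """Single-head scaled dot-product attention for one query vector."""
    scores = K @ q / np.sqrt(q.shape[-1])          # (t,)
    w = np.exp(scores - scores.max())
    w /= w.sum()                                   # softmax over past tokens
    return w @ V                                   # (d,)

rng = np.random.default_rng(0)
d, n = 16, 10
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
X = rng.standard_normal((n, d))                    # embeddings of n tokens

# Without a cache: at step t, re-project K/V for the whole prefix.
no_cache_out, no_cache_proj = [], 0
for t in range(1, n + 1):
    K, V = X[:t] @ Wk, X[:t] @ Wv
    no_cache_proj += t                             # t projections this step
    no_cache_out.append(attend(X[t - 1] @ Wq, K, V))

# With a cache: project only the newest token and append it.
K_cache, V_cache, cache_out, cache_proj = [], [], [], 0
for t in range(1, n + 1):
    K_cache.append(X[t - 1] @ Wk)
    V_cache.append(X[t - 1] @ Wv)
    cache_proj += 1
    cache_out.append(attend(X[t - 1] @ Wq, np.array(K_cache), np.array(V_cache)))

print(np.allclose(no_cache_out, cache_out))        # True: same outputs
print(no_cache_proj, cache_proj)                   # 55 10
```

The cache removes the repeated projections. Attention itself still reads all t cached entries at step t, and that is the memory that grows without bound.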

### The trade-off that becomes a wall

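| Model (MHA, FP16, batch 1) | Layers | Heads | Head-dim | KV per token | @ 4k tokens | @ 32k tokens |
|-------|--------|-------|----------|--------------|--------------|--------------|
| 7B  | 32 | 32 | 128 | 512 KiB | ~2 GiB | ~16 GiB |
| 13B | 40 | 40 | 128 | 800 KiB | ~3 GiB | ~25 GiB |
| 70B | 80 | 64 | 128 | 2.5 MiB | ~10 GiB | ~80 GiB |

Sizes here use binary units (KiB/MiB/GiB). The code in Section 1 divides by 1e9, so it prints decimal GB that are about 7% larger, for example 85.9 GB for the 70B row at 32k.

For this MHA 70B configuration, a single 32k-token sequence needs ~80 GiB of cache. That is as much as an entire A100-80GB, before the ~140 GB of FP16 weights are even counted. Only **about half a dozen such sequences** fit on an 8 × A100-80GB node.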
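A back-of-the-envelope check of the 70B row and the multi-GPU claim. It assumes FP16 values (2 bytes), batch size 1, decimal GB, 8 GPUs with 80 GB each, and ~70e9 FP16 parameters. It ignores activations, framework overhead and fragmentation, so the real headroom is smaller.

```python
# Hypothetical MHA 70B shape (real 70B Llama models use GQA-8).
n_layers, n_heads, head_dim, dtype_bytes = 80, 64, 128, 2

kv_per_token = 2 * n_layers * n_heads * head_dim * dtype_bytes   # K and V
kv_32k = kv_per_token * 32_768

node_bytes = 8 * 80e9            # assumption: 8 x 80 GB GPUs
weight_bytes = 70e9 * 2          # assumption: 70e9 params in FP16
max_seqs = int((node_bytes - weight_bytes) // kv_32k)

print(f"KV per token : {kv_per_token / 2**20:.2f} MiB")   # 2.50 MiB
print(f"KV @ 32k     : {kv_32k / 1e9:.1f} GB")            # 85.9 GB
print(f"32k sequences that fit next to the weights: {max_seqs}")  # 5
```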

## Solutions We Will Explore

1. **Baseline**: naive full KV cache (the problem)
2. **Sliding-Window Attention**: cap the cache at a fixed window
3. **StreamingLLM / Sink Tokens**: keep initial + recent tokens only
4. **Multi-Query / Grouped-Query Attention (MQA/GQA)**: fewer K/V heads
5. **KV Cache Quantization**: store K/V in INT8 instead of FP16
6. **PagedAttention**: allocate the cache in non-contiguous pages (vLLM idea)

All demonstrations run on **CPU with small toy models** so every developer can reproduce this without a GPU.

## 0 — Setup

In [None]:
# Install dependencies (only needs to run once)
import subprocess, sys
pkgs = ["torch", "matplotlib", "numpy", "psutil", "tabulate"]
subprocess.check_call([sys.executable, "-m", "pip", "install", "--quiet"] + pkgs)
print("✅ Dependencies ready")

In [None]:
import math, time, psutil, os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from tabulate import tabulate

torch.manual_seed(42)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {DEVICE}")

# ── colour palette ──────────────────────────────────────────────────────────
C = dict(baseline="#e74c3c", sliding="#3498db", quant="#2ecc71",
         gqa="#9b59b6", paged="#f39c12", streaming="#1abc9c")

---
## 1 — Understanding KV Cache Memory: The Math
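Every number in this section comes from one product. The factor 2 covers K and V, and one vector is stored per layer, per KV head and per token. That is then multiplied by the bytes per element and the batch size:

$$
M_{\text{KV}} = 2 \cdot L \cdot H_{kv} \cdot d_h \cdot S \cdot b \cdot B
$$

The symbols are:

- $L$: number of layers
- $H_{kv}$: number of KV heads ($= n_{\text{heads}}$ for MHA)
- $d_h$: head dimension
- $S$: sequence length
- $b$: bytes per element
- $B$: batch size

For the 7B row of the introduction table, $2 \cdot 32 \cdot 32 \cdot 128 \cdot 2 = 524{,}288$ bytes $= 512$ KiB per token. The `kv_cache_bytes` helper below implements this formula with $B = 1$.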

In [None]:
def kv_cache_bytes(n_layers: int, n_heads: int, head_dim: int,
                   seq_len: int, dtype_bytes: int = 2,
                   n_kv_heads: int | None = None) -> int:
    """
    KV cache size in bytes.
    2  → K and V
    n_kv_heads → for GQA/MQA (defaults to n_heads for standard MHA)
    """
    kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
    return 2 * n_layers * kv_heads * head_dim * seq_len * dtype_bytes


models = {
    "Llama-2 7B":        dict(n_layers=32, n_heads=32, head_dim=128),
    "Llama-2 13B":       dict(n_layers=40, n_heads=40, head_dim=128),
    "70B (MHA, hypo.)":  dict(n_layers=80, n_heads=64, head_dim=128),  # real 70B Llamas use GQA-8
}
seq_lens = [512, 2048, 4096, 8192, 16384, 32768]

rows = []
for name, cfg in models.items():
    row = [name]
    for sl in seq_lens:
        gb = kv_cache_bytes(**cfg, seq_len=sl) / 1e9
        row.append(f"{gb:.1f} GB")
    rows.append(row)

print(tabulate(rows,
               headers=["Model"] + [f"{sl:,} tok" for sl in seq_lens],
               tablefmt="fancy_grid"))

---
## 2 — Toy Transformer Block (shared baseline for all experiments)

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention (MHA); becomes GQA/MQA when n_kv_heads < n_heads. Returns its KV cache."""

    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int | None = None):
        super().__init__()
        self.n_heads    = n_heads
        self.n_kv_heads = n_kv_heads or n_heads          # MHA: same as n_heads
        self.head_dim   = d_model // n_heads
        self.scale      = self.head_dim ** -0.5

        self.wq = nn.Linear(d_model, n_heads    * self.head_dim, bias=False)
        self.wk = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

    def forward(self, x, past_k=None, past_v=None):
        B, T, C = x.shape
        q = self.wq(x).view(B, T, self.n_heads,    self.head_dim).transpose(1, 2)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # ── Append to KV cache (kept at n_kv_heads heads) ───────────────────
        k = torch.cat([past_k, k], dim=2) if past_k is not None else k
        v = torch.cat([past_v, v], dim=2) if past_v is not None else v

        # ── GQA: expand KV heads to match Q heads, for attention only ───────
        # The expanded copies must not be returned: the cache would grow back
        # to n_heads heads and the next torch.cat would fail on a shape mismatch.
        k_att, v_att = k, v
        if self.n_kv_heads != self.n_heads:
            ratio = self.n_heads // self.n_kv_heads
            k_att = k.repeat_interleave(ratio, dim=1)
            v_att = v.repeat_interleave(ratio, dim=1)

        # step() always passes T=1, so no causal mask is applied here.
        att = torch.matmul(q, k_att.transpose(-2, -1)) * self.scale
        att = F.softmax(att, dim=-1)
        out = torch.matmul(att, v_att).transpose(1, 2).contiguous().view(B, T, -1)
        return self.wo(out), k, v          # compact cache with n_kv_heads heads


class ToyTransformer(nn.Module):
    """Single-layer transformer for experimentation."""

    def __init__(self, vocab_size=512, d_model=128, n_heads=4, n_kv_heads=None):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.attn  = MultiHeadAttention(d_model, n_heads, n_kv_heads)
        self.ff    = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model),
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size)

    def step(self, token_id, past_k=None, past_v=None):
        """Generate one token, returning logits and updated KV cache."""
        x = self.embed(token_id.unsqueeze(0).unsqueeze(0))   # (1,1,d)
        h, k, v = self.attn(self.ln1(x), past_k, past_v)
        x = x + h
        x = x + self.ff(self.ln2(x))
        logits = self.head(x[:, -1, :])
        return logits, k, v

print("✅ ToyTransformer defined")

---
## 3 — Solution 1: Baseline (Full KV Cache) — Demonstrating the Problem

In [None]:
def run_generation(model, n_tokens: int, strategy: str = "full",
                   window: int = 64, sink_tokens: int = 4) -> dict:
    """
    Run autoregressive generation and collect per-step timing and memory.

    strategy:
        'full'      – standard KV cache (baseline)
        'sliding'   – keep only last `window` tokens
        'streaming' – keep first `sink_tokens` + last `window` tokens
    """
    model.eval()
    latencies, cache_bytes = [], []
    token_id = torch.tensor(0, device=DEVICE)
    past_k = past_v = None

    with torch.no_grad():
        for step in range(n_tokens):
            t0 = time.perf_counter()
            logits, past_k, past_v = model.step(token_id, past_k, past_v)
            latencies.append(time.perf_counter() - t0)

            # ── Cache pruning ───────────────────────────────────────────────
            if strategy == "sliding" and past_k.shape[2] > window:
                past_k = past_k[:, :, -window:, :]
                past_v = past_v[:, :, -window:, :]

            elif strategy == "streaming":
                total = past_k.shape[2]
                if total > sink_tokens + window:
                    # Sink tokens (first few) + recent window
                    past_k = torch.cat([
                        past_k[:, :, :sink_tokens, :],
                        past_k[:, :, -window:, :]
                    ], dim=2)
                    past_v = torch.cat([
                        past_v[:, :, :sink_tokens, :],
                        past_v[:, :, -window:, :]
                    ], dim=2)

            # ── KV cache retained after pruning (bytes, K+V) ────────────────
            cache_bytes.append(past_k.nelement() * past_k.element_size() * 2)
            token_id = logits.argmax(dim=-1).squeeze()

    return {"latencies": latencies, "cache_bytes": cache_bytes}


N_TOKENS = 300
model_base = ToyTransformer().to(DEVICE)

print("Running baseline (full KV cache)...")
baseline = run_generation(model_base, N_TOKENS, strategy="full")
print(f"  First token latency : {baseline['latencies'][0]*1000:.2f} ms")
print(f"  Last  token latency : {baseline['latencies'][-1]*1000:.2f} ms")
print(f"  Cache at end        : {baseline['cache_bytes'][-1] / 1024:.1f} KB")

---
## 4 — Solution 2: Sliding-Window Attention

**Idea**: Only attend to the most recent **W** tokens.  
Cache size is O(W) instead of O(n).

**Trade-off**: Loses very old context.  
**Used by**: Mistral-7B (window = 4096), Longformer.
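The `run_generation` helper trims the window by slicing, and `torch.cat` then reallocates the cache on every step. A common alternative is a preallocated ring buffer. This is a swapped-in technique, not what the notebook's code does. The class below is a hypothetical NumPy sketch: token t is written to slot `t % window`, so the buffer never grows and the oldest entry is overwritten in place.

```python
import numpy as np

class RingKVCache:
    """Hypothetical fixed-capacity sliding-window cache for one K (or V) tensor."""

    def __init__(self, window: int, n_heads: int, head_dim: int):
        self.window = window
        self.buf = np.zeros((n_heads, window, head_dim), dtype=np.float32)
        self.n_seen = 0                                  # tokens appended so far

    def append(self, kv: np.ndarray) -> None:           # kv: (n_heads, head_dim)
        self.buf[:, self.n_seen % self.window, :] = kv   # overwrite the oldest slot
        self.n_seen += 1

    def get(self) -> np.ndarray:
        """Retained tokens in chronological order: (n_heads, <=window, head_dim)."""
        if self.n_seen <= self.window:
            return self.buf[:, : self.n_seen, :]
        start = self.n_seen % self.window                # slot holding the oldest token
        return np.roll(self.buf, -start, axis=1)


cache = RingKVCache(window=4, n_heads=2, head_dim=3)
for t in range(10):
    cache.append(np.full((2, 3), t, dtype=np.float32))   # token t's vector is all t

print(cache.get()[0, :, 0])   # [6. 7. 8. 9.]  (last 4 tokens, oldest first)
print(cache.buf.nbytes)       # 96 bytes, constant however many tokens arrive
```

The reordering in `get()` is only for readability. Softmax attention is invariant to permuting the key/value pairs together. So when positions are already baked into K (for example by rotary embeddings applied before caching), a kernel could attend over the buffer in slot order.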

In [None]:
WINDOW = 50   # keep last 50 tokens

print("Running sliding-window strategy...")
sliding = run_generation(model_base, N_TOKENS, strategy="sliding", window=WINDOW)
print(f"  Max cache size: {max(sliding['cache_bytes']) / 1024:.1f} KB  (capped at {WINDOW} tokens)")

---
## 5 — Solution 3: StreamingLLM / Attention Sinks

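**Idea**: Keep the first few "sink" tokens (models learn to park surplus softmax probability on them) **plus** a recent sliding window.  
This lets a model trained on a finite context keep generating stably over arbitrarily long streams without fine-tuning. Tokens that fall out of the window are still forgotten.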

**Paper**: *Efficient Streaming Language Models with Attention Sinks* (Xiao et al., 2023)  
**Used by**: `vLLM` StreamingLLM mode, TensorRT-LLM.
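A stdlib sketch of which positions a sink + window cache retains. It also shows a detail from the paper: positional information is assigned by slot inside the cache, not by the token's original position in the text. Both function names are illustrative. The notebook's toy model has no positional encoding, so the second point does not affect its numbers.

```python
def streaming_positions(total: int, sink: int, window: int) -> list[int]:
    """Original token positions kept by a sink + recent-window cache."""
    if total <= sink + window:
        return list(range(total))
    return list(range(sink)) + list(range(total - window, total))


def cache_position_ids(total: int, sink: int, window: int) -> list[int]:
    """Positions assigned by slot inside the cache (0..len-1), not by text position."""
    return list(range(len(streaming_positions(total, sink, window))))


kept = streaming_positions(total=20, sink=2, window=3)
print(kept)                          # [0, 1, 17, 18, 19]
print(cache_position_ids(20, 2, 3))  # [0, 1, 2, 3, 4]
```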

In [None]:
SINK = 4
STREAM_WINDOW = 46   # total cache = 4 sinks + 46 recent = 50 tokens

print("Running StreamingLLM strategy...")
streaming = run_generation(model_base, N_TOKENS, strategy="streaming",
                           window=STREAM_WINDOW, sink_tokens=SINK)
print(f"  Max cache size: {max(streaming['cache_bytes']) / 1024:.1f} KB")

---
## 6 — Solution 4: Grouped-Query Attention (GQA)

**Idea**: Instead of one K/V head per Q head (MHA), share K/V heads across groups of Q heads.  
- **MQA** (Multi-Query): 1 K/V head for *all* Q heads — maximum saving.
- **GQA** (Grouped-Query): G K/V heads for N Q heads — balanced trade-off.

**Saving**: Cache shrinks by factor of `n_heads / n_kv_heads`.  
**Used by**: Llama-3 (GQA), Mistral (GQA), Falcon (MQA), Gemma.
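A NumPy check of what `repeat_interleave` does in `MultiHeadAttention`, using toy sizes of 8 Q heads and 2 KV heads (all names are illustrative). Expanding the KV heads gives the same result as letting Q head `h` read KV head `h // group` directly. Only the compact `n_kv_heads` tensors need to be cached. Optimised kernels (FlashAttention-2, for example) can handle GQA without materialising the expanded copies.

```python
import numpy as np

rng = np.random.default_rng(0)
n_heads, n_kv_heads, seq, d = 8, 2, 5, 4
group = n_heads // n_kv_heads                      # Q heads per KV head

q = rng.standard_normal((n_heads, d))              # one query per Q head
k = rng.standard_normal((n_kv_heads, seq, d))      # only n_kv_heads are cached
v = rng.standard_normal((n_kv_heads, seq, d))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# (a) Expand like repeat_interleave: [kv0, kv0, kv0, kv0, kv1, kv1, kv1, kv1]
k_rep, v_rep = np.repeat(k, group, axis=0), np.repeat(v, group, axis=0)
scores = np.einsum("hd,hsd->hs", q, k_rep) / np.sqrt(d)
out_expanded = np.einsum("hs,hsd->hd", softmax(scores), v_rep)

# (b) No copies: Q head h reads KV head h // group.
out_indexed = np.stack([
    softmax(k[h // group] @ q[h] / np.sqrt(d)) @ v[h // group]
    for h in range(n_heads)
])

print(np.allclose(out_expanded, out_indexed))      # True
print(k.nbytes / k_rep.nbytes)                     # 0.25 = n_kv_heads / n_heads
```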

In [None]:
def gqa_cache_reduction(n_heads: int, n_kv_heads: int) -> float:
    """Fraction of the MHA cache size that GQA/MQA keeps (n_kv_heads / n_heads)."""
    return n_kv_heads / n_heads

configs = {
    "MHA (baseline)": (8, 8),
    "GQA-4 groups":   (8, 4),
    "GQA-2 groups":   (8, 2),
    "MQA":            (8, 1),
}

print(f"{'Config':<20} {'KV heads':>10} {'Cache vs MHA':>14}")
print("-" * 46)
for label, (nh, nkv) in configs.items():
    r = gqa_cache_reduction(nh, nkv)
    print(f"{label:<20} {nkv:>10}     {r*100:>8.0f}%")

print()
print("Running MQA toy model (4 Q heads → 1 KV head)...")
model_gqa = ToyTransformer(n_heads=4, n_kv_heads=1).to(DEVICE)   # toy: 4Q, 1KV
gqa_result = run_generation(model_gqa, N_TOKENS, strategy="full")
print(f"  GQA cache at step {N_TOKENS}: {gqa_result['cache_bytes'][-1] / 1024:.1f} KB")
print(f"  Baseline cache at step {N_TOKENS}: {baseline['cache_bytes'][-1] / 1024:.1f} KB")
print(f"  Reduction: {gqa_result['cache_bytes'][-1]/baseline['cache_bytes'][-1]*100:.0f}% of baseline")

---
## 7 — Solution 5: KV Cache Quantization

**Idea**: Store K/V tensors in INT8 (or INT4) instead of FP16/BF16.  
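**Saving**: relative to FP16/BF16, 2× (INT8) or 4× (INT4, two values packed per byte), minus a small overhead for the scales. The quality cost is usually small for INT8 and more noticeable for INT4.  
**Used by**: vLLM (FP8 KV cache via `kv_cache_dtype="fp8"`), TensorRT-LLM (INT8/FP8 KV cache). Note that `bitsandbytes`, `LLM.int8()`, GPTQ and AWQ quantize model *weights*, not the KV cache.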
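The notebook demonstrates INT8 below; the 4× INT4 figure depends on packing two 4-bit values into each byte. Here is a minimal NumPy sketch of symmetric per-token INT4 with nibble packing. The [-7, 7] range, the offset-by-8 encoding and the function names are illustrative choices, not the format of any particular library.

```python
import numpy as np

def quantize_int4(x: np.ndarray):
    """Symmetric per-row INT4: integers in [-7, 7] plus one float scale per row."""
    scale = np.maximum(np.abs(x).max(axis=-1, keepdims=True), 1e-8) / 7.0
    q = np.clip(np.round(x / scale), -7, 7).astype(np.int8)
    return q, scale.astype(np.float32)

def pack_int4(q: np.ndarray) -> np.ndarray:
    """Pack pairs of signed 4-bit values into one uint8 (last dim must be even)."""
    u = (q + 8).astype(np.uint8)                   # shift [-7, 7] -> [1, 15]
    return u[..., 0::2] | (u[..., 1::2] << 4)      # low nibble | high nibble

def unpack_int4(p: np.ndarray) -> np.ndarray:
    lo = (p & 0x0F).astype(np.int8) - 8
    hi = (p >> 4).astype(np.int8) - 8
    out = np.empty(p.shape[:-1] + (p.shape[-1] * 2,), dtype=np.int8)
    out[..., 0::2], out[..., 1::2] = lo, hi
    return out

rng = np.random.default_rng(0)
k = rng.standard_normal((4, 128)).astype(np.float16)   # 4 tokens x head_dim 128
q, scale = quantize_int4(k.astype(np.float32))
packed = pack_int4(q)

k_hat = unpack_int4(packed).astype(np.float32) * scale
print(k.nbytes, packed.nbytes)                          # 1024 256 (scales excluded)
print(np.array_equal(unpack_int4(packed), q))           # True: packing is lossless
print(float(np.abs(k_hat - k).max()) <= float(scale.max()) / 2 + 1e-3)  # True
```

All of the error comes from rounding to 15 levels, which stays within half a scale step per element. The packing itself is exact.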

In [None]:
class QuantizedKVCache:
    """
    Simulated INT8 KV cache with symmetric per-token, per-head scales.
    Each (batch, head, token) vector keeps its own scale, so entries written
    earlier are still dequantised correctly after new tokens arrive.
    """

    def __init__(self):
        self.k_int8: torch.Tensor | None = None    # (B, H, S, D) int8
        self.v_int8: torch.Tensor | None = None
        self.k_scale: torch.Tensor | None = None   # (B, H, S, 1) float32
        self.v_scale: torch.Tensor | None = None

    @staticmethod
    def _quantize(t: torch.Tensor):
        # One scale per head_dim vector: max |x| maps to 127
        scale = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        q = (t / scale).round().clamp(-127, 127).to(torch.int8)
        return q, scale.float()

    @staticmethod
    def _dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        return q.float() * scale

    @staticmethod
    def _cat(old, new):
        return new if old is None else torch.cat([old, new], dim=2)

    def append(self, k: torch.Tensor, v: torch.Tensor):
        """Quantise the new K/V slice and append it, with its scales, along the sequence axis."""
        k_q, ks = self._quantize(k)
        v_q, vs = self._quantize(v)
        self.k_int8, self.k_scale = self._cat(self.k_int8, k_q), self._cat(self.k_scale, ks)
        self.v_int8, self.v_scale = self._cat(self.v_int8, v_q), self._cat(self.v_scale, vs)

    def get(self):
        """Return dequantised K/V for attention computation."""
        return (
            self._dequantize(self.k_int8, self.k_scale),
            self._dequantize(self.v_int8, self.v_scale),
        )

    @property
    def bytes_used(self) -> int:
        if self.k_int8 is None:
            return 0
        payload = self.k_int8.nelement() + self.v_int8.nelement()          # int8 = 1 byte
        scales = (self.k_scale.nelement() + self.v_scale.nelement()) * 4   # fp32 scales
        return payload + scales


def run_quantized_generation(model, n_tokens: int) -> dict:
    model.eval()
    latencies, cache_bytes = [], []
    token_id = torch.tensor(0, device=DEVICE)
    qcache = QuantizedKVCache()

    with torch.no_grad():
        for step in range(n_tokens):
            t0 = time.perf_counter()

            x = model.embed(token_id.unsqueeze(0).unsqueeze(0))
            x_norm = model.ln1(x)

            # Compute new Q, K, V
            attn = model.attn
            B, T, _ = x_norm.shape
            q = attn.wq(x_norm).view(B, T, attn.n_heads, attn.head_dim).transpose(1, 2)
            k = attn.wk(x_norm).view(B, T, attn.n_kv_heads, attn.head_dim).transpose(1, 2)
            v = attn.wv(x_norm).view(B, T, attn.n_kv_heads, attn.head_dim).transpose(1, 2)

            # Append to quantized cache
            qcache.append(k, v)

            # Dequantise for attention
            k_full, v_full = qcache.get()

            att = torch.matmul(q, k_full.transpose(-2, -1)) * attn.scale
            att = F.softmax(att, dim=-1)
            out = torch.matmul(att, v_full).transpose(1, 2).contiguous().view(B, T, -1)
            h   = attn.wo(out)

            x = x + h
            x = x + model.ff(model.ln2(x))
            logits = model.head(x[:, -1, :])

            latencies.append(time.perf_counter() - t0)
            cache_bytes.append(qcache.bytes_used)
            token_id = logits.argmax(dim=-1).squeeze()

    return {"latencies": latencies, "cache_bytes": cache_bytes}


print("Running INT8 quantized KV cache...")
quant_result = run_quantized_generation(model_base, N_TOKENS)
ratio = quant_result['cache_bytes'][-1] / baseline['cache_bytes'][-1]
print(f"  Quantized cache: {quant_result['cache_bytes'][-1] / 1024:.1f} KB")
print(f"  Baseline cache : {baseline['cache_bytes'][-1] / 1024:.1f} KB")
print(f"  Memory ratio   : {ratio:.2f}x  (expected ~0.25x + scale overhead vs this FP32 toy; ~0.5x vs FP16)")

---
## 8 — Solution 6: PagedAttention (vLLM concept)

**Idea**: Instead of one giant contiguous KV tensor per sequence, divide the cache into fixed-size **pages** (blocks).  
A **block table** maps logical token positions to physical memory pages.
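A stdlib sketch of the block-table lookup, using the same `block_table` idea as the `PagedKVCache` class below. A logical token index splits into a block number and an offset. Pages are committed one block at a time, so at most `block_size - 1` slots per sequence are reserved but empty. The block ids and helper names here are made up for illustration.

```python
import math

BLOCK_SIZE = 16                                    # tokens per page (illustrative)

def blocks_needed(n_tokens: int, block_size: int = BLOCK_SIZE) -> int:
    return math.ceil(n_tokens / block_size)

def physical_slot(block_table: list[int], token_idx: int,
                  block_size: int = BLOCK_SIZE) -> tuple[int, int]:
    """Map a logical token index to (physical block id, offset inside that block)."""
    return block_table[token_idx // block_size], token_idx % block_size

# A 40-token sequence whose pages landed in scattered pool slots.
block_table = [7, 2, 11]                           # logical block i -> physical block
print(blocks_needed(40))                           # 3
print(physical_slot(block_table, 0))               # (7, 0)
print(physical_slot(block_table, 17))              # (2, 1)
print(physical_slot(block_table, 39))              # (11, 7)
print(blocks_needed(40) * BLOCK_SIZE - 40)         # 8 reserved-but-empty slots
```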

**Why this matters**:
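- No external fragmentation and no contiguous max-length buffer reserved per request. Pages can be scattered anywhere in VRAM, and waste is limited to the unfilled tail of each sequence's last page.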
- Multiple requests can **share** pages (prefix caching / prompt caching).
- Allows **preemption**: swap pages to CPU, reload when needed.
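Page sharing typically relies on reference counts, plus copy-on-write when a shared page is about to be modified. Below is a toy stdlib sketch of that bookkeeping. `BlockPool` and its methods are hypothetical and not vLLM's actual block-manager API, and the K/V data copy is omitted.

```python
class BlockPool:
    """Hypothetical reference-counted page pool illustrating prefix sharing."""

    def __init__(self, n_blocks: int):
        self.free = list(range(n_blocks))        # physical block ids
        self.refs: dict[int, int] = {}           # block id -> number of users

    def alloc(self) -> int:
        blk = self.free.pop()
        self.refs[blk] = 1
        return blk

    def fork(self, table: list[int]) -> list[int]:
        """Start a sequence that shares every page of `table`; no K/V is copied."""
        for blk in table:
            self.refs[blk] += 1
        return list(table)

    def write_last(self, table: list[int]) -> None:
        """Copy-on-write: give `table` a private last page before modifying it."""
        blk = table[-1]
        if self.refs[blk] > 1:
            self.refs[blk] -= 1
            table[-1] = self.alloc()             # a real engine would also copy the data

    def release(self, table: list[int]) -> None:
        for blk in table:
            self.refs[blk] -= 1
            if self.refs[blk] == 0:
                del self.refs[blk]
                self.free.append(blk)


pool = BlockPool(n_blocks=8)
prompt = [pool.alloc(), pool.alloc()]            # 2 pages holding a shared prompt
seq_a, seq_b = pool.fork(prompt), pool.fork(prompt)
pool.release(prompt)                             # now only the two sequences hold them
pool.write_last(seq_b)                           # seq_b diverges: private last page
print(len(pool.free))                            # 5 (first page shared, two private last pages)
pool.release(seq_a)
pool.release(seq_b)
print(len(pool.free))                            # 8 (everything returned)
```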

**Used by**: vLLM (original invention), TGI, SGLang.

In [None]:
class PagedKVCache:
    """
    Simplified PagedAttention cache.
    Memory is pre-allocated as a pool of fixed-size blocks.
    """

    def __init__(self, n_kv_heads: int, head_dim: int,
                 block_size: int = 16, max_blocks: int = 200,
                 dtype=torch.float32, device=DEVICE):
        self.block_size  = block_size
        self.n_kv_heads  = n_kv_heads
        self.head_dim    = head_dim

        # ── Pre-allocated pool on `device` (same device as the model) ───────
        # Shape: (max_blocks, 2, n_kv_heads, block_size, head_dim)
        self.pool = torch.zeros(
            max_blocks, 2, n_kv_heads, block_size, head_dim, dtype=dtype, device=device
        )
        self.free_blocks  = list(range(max_blocks))   # block IDs available
        self.block_table: list[int] = []              # logical → physical
        self.slot_idx     = 0                         # position within current block
        self.seq_len      = 0

    def _alloc_block(self) -> int:
        if not self.free_blocks:
            raise RuntimeError("Out of KV cache pages!")
        blk = self.free_blocks.pop(0)
        self.block_table.append(blk)
        return blk

    def append(self, k: torch.Tensor, v: torch.Tensor):
        """Write one token's K/V into the next available slot."""
        if self.slot_idx == 0:
            self._alloc_block()

        blk = self.block_table[-1]
        self.pool[blk, 0, :, self.slot_idx, :] = k.squeeze(0).squeeze(1)
        self.pool[blk, 1, :, self.slot_idx, :] = v.squeeze(0).squeeze(1)
        self.slot_idx  = (self.slot_idx + 1) % self.block_size
        self.seq_len  += 1

    def get_contiguous(self):
        """Gather all K/V into a contiguous (1, H, S, D) tensor for attention."""
        ks, vs = [], []
        for i, blk in enumerate(self.block_table):
            end = self.block_size if i < len(self.block_table) - 1 else self.slot_idx or self.block_size
            ks.append(self.pool[blk, 0, :, :end, :])   # (H, end, D)
            vs.append(self.pool[blk, 1, :, :end, :])
        K = torch.cat(ks, dim=1).unsqueeze(0)   # (1, H, S, D)
        V = torch.cat(vs, dim=1).unsqueeze(0)
        return K, V

    @property
    def blocks_used(self):
        return len(self.block_table)


def run_paged_generation(model, n_tokens: int, block_size: int = 16) -> dict:
    model.eval()
    latencies, cache_bytes, blocks_used = [], [], []
    token_id = torch.tensor(0, device=DEVICE)

    attn = model.attn
    pcache = PagedKVCache(
        n_kv_heads=attn.n_kv_heads,
        head_dim=attn.head_dim,
        block_size=block_size,
    )

    with torch.no_grad():
        for step in range(n_tokens):
            t0 = time.perf_counter()

            x = model.embed(token_id.unsqueeze(0).unsqueeze(0))
            B, T, _ = x.shape
            x_norm = model.ln1(x)

            q = attn.wq(x_norm).view(B, T, attn.n_heads, attn.head_dim).transpose(1, 2)
            k = attn.wk(x_norm).view(B, T, attn.n_kv_heads, attn.head_dim).transpose(1, 2)
            v = attn.wv(x_norm).view(B, T, attn.n_kv_heads, attn.head_dim).transpose(1, 2)

            pcache.append(k, v)
            k_full, v_full = pcache.get_contiguous()

            att = torch.matmul(q, k_full.transpose(-2, -1)) * attn.scale
            att = F.softmax(att, dim=-1)
            out = torch.matmul(att, v_full).transpose(1, 2).contiguous().view(B, T, -1)
            h   = attn.wo(out)

            x = x + h
            x = x + model.ff(model.ln2(x))
            logits = model.head(x[:, -1, :])

            latencies.append(time.perf_counter() - t0)
            # Memory = only the bytes in actually-used pool slots
            used_slots = pcache.seq_len
            bytes_used = used_slots * attn.n_kv_heads * attn.head_dim * 2 * 4   # K+V, float32
            cache_bytes.append(bytes_used)
            blocks_used.append(pcache.blocks_used)
            token_id = logits.argmax(dim=-1).squeeze()

    return {"latencies": latencies, "cache_bytes": cache_bytes, "blocks": blocks_used}


print("Running PagedAttention simulation...")
paged_result = run_paged_generation(model_base, N_TOKENS)
print(f"  Blocks used at end: {paged_result['blocks'][-1]}  (block_size=16)")
print(f"  Effective cache   : {paged_result['cache_bytes'][-1]/1024:.1f} KB")

---
## 9 — Visualisation: All Solutions Compared

In [None]:
steps = list(range(1, N_TOKENS + 1))

def to_kb(lst): return [b / 1024 for b in lst]

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
fig.suptitle("KV Cache Memory Wall — Problem vs Solutions", fontsize=15, fontweight="bold")

# ── Panel 1: Cache size ──────────────────────────────────────────────────────
ax = axes[0]
ax.plot(steps, to_kb(baseline["cache_bytes"]),    label="Baseline (full cache)",    color=C["baseline"],  lw=2)
ax.plot(steps, to_kb(gqa_result["cache_bytes"]),  label="GQA (4Q→1KV head)",        color=C["gqa"],       lw=2)
ax.plot(steps, to_kb(quant_result["cache_bytes"]),label="INT8 Quantized",           color=C["quant"],     lw=2)
ax.plot(steps, to_kb(sliding["cache_bytes"]),     label=f"Sliding Window (W={WINDOW})", color=C["sliding"], lw=2)
ax.plot(steps, to_kb(streaming["cache_bytes"]),   label=f"StreamingLLM (sink={SINK}, W={STREAM_WINDOW})",
                                                                                    color=C["streaming"], lw=2)
ax.plot(steps, to_kb(paged_result["cache_bytes"]),label="PagedAttention",           color=C["paged"],     lw=2, linestyle="--")
ax.set_xlabel("Generation Step (token #)")
ax.set_ylabel("KV Cache Memory (KB)")
ax.set_title("Memory Usage vs Generation Step")
ax.legend(fontsize=8)
ax.grid(alpha=0.3)

# ── Panel 2: Latency per step ────────────────────────────────────────────────
ax = axes[1]
window_smooth = 10
def smooth(arr):
    return np.convolve(arr, np.ones(window_smooth)/window_smooth, mode="same")

def ms(lst): return [x*1000 for x in lst]

ax.plot(steps, smooth(ms(baseline["latencies"])),    label="Baseline",       color=C["baseline"],  lw=2)
ax.plot(steps, smooth(ms(gqa_result["latencies"])),  label="GQA",            color=C["gqa"],       lw=2)
ax.plot(steps, smooth(ms(quant_result["latencies"])),label="INT8 Quantized", color=C["quant"],     lw=2)
ax.plot(steps, smooth(ms(sliding["latencies"])),     label="Sliding Window", color=C["sliding"],   lw=2)
ax.plot(steps, smooth(ms(streaming["latencies"])),   label="StreamingLLM",   color=C["streaming"], lw=2)
ax.set_xlabel("Generation Step (token #)")
ax.set_ylabel("Step Latency (ms, smoothed)")
ax.set_title("Per-Step Latency vs Generation Step")
ax.legend(fontsize=8)
ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig("kv_cache_comparison.png", dpi=150, bbox_inches="tight")
plt.show()
print("Figure saved to kv_cache_comparison.png")

---
## 10 — Theoretical Memory Savings at Production Scale

In [None]:
# Llama-3-70B shape (80 layers, 64 Q heads, head_dim 128) at 32k context.
# The MHA rows are a hypothetical baseline; the real model uses GQA-8.
cfg = dict(n_layers=80, n_heads=64, head_dim=128)
seq = 32768

scenarios = {
    "MHA FP16 (baseline)":    dict(**cfg, dtype_bytes=2),
    "GQA-8 (Llama-3 actual)": dict(**cfg, n_kv_heads=8, dtype_bytes=2),
    "MHA INT8 quantized":     dict(**cfg, dtype_bytes=1),
    "GQA-8 + INT8":           dict(**cfg, n_kv_heads=8, dtype_bytes=1),
    "MQA FP16":               dict(**cfg, n_kv_heads=1, dtype_bytes=2),
    "Sliding Window (4k)":    None,  # MHA FP16 with a 4k cache; handled separately
}

print(f"\n{'Scenario':<30} {'Cache (GB)':>12} {'vs Baseline':>14}")
print("-" * 60)

baseline_gb = kv_cache_bytes(**cfg, seq_len=seq, dtype_bytes=2) / 1e9

for label, kw in scenarios.items():
    if kw is None:   # sliding window
        gb = kv_cache_bytes(**cfg, seq_len=4096, dtype_bytes=2) / 1e9
    else:
        gb = kv_cache_bytes(**kw, seq_len=seq) / 1e9
    pct = gb / baseline_gb * 100
    print(f"{label:<30} {gb:>10.1f} GB   {pct:>8.1f}%")

---
## 11 — Decision Guide: Which Solution When?

| Technique | Memory Saving | Quality Impact | Complexity | Best For |
|-----------|:---:|:---:|:---:|---|
| **GQA/MQA** | 8–64× | Minimal (retrained) | Low | New model training |
| **KV INT8 Quant** | ~2× | Negligible | Medium | Existing deployed models |
| **KV INT4 Quant** | ~4× | Small | Medium | Memory-constrained inference |
| **Sliding Window** | n/W× (cache capped at W tokens) | Loses old context | Low | Chat / streaming |
| **StreamingLLM** | n/(S+W)× (S sinks + W recent) | Stable perplexity, but evicted tokens are forgotten | Low | Long streaming generation |
| **PagedAttention** | No per-token saving; removes fragmentation and over-reservation waste | None | High | Multi-user serving (vLLM) |
| **Offload to CPU/NVMe** | Removes VRAM limit | Latency penalty | High | Very long context |
| **Sparse Attention** | Varies | Varies | Very High | Research / custom models |

### Production Recommendation Stack

```
1. Choose a GQA-trained model (e.g., Llama-3, Mistral) → free 4–8× saving (n_heads / n_kv_heads)
2. Serve with vLLM (PagedAttention) → eliminates VRAM fragmentation
3. Enable INT8 KV cache → another 2× saving
4. Add sliding window or StreamingLLM for very long sessions
```

Combined, steps 1 and 3 shrink the KV cache for the 70B shape at 32k context from **~86 GB (MHA, FP16) to ~5.4 GB**, a 16× reduction. Adding a 4k sliding window (step 4) brings it below 1 GB.
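The combined figure follows from the same formula as `kv_cache_bytes`, because the savings multiply. This is a self-contained re-derivation; `kv_gb` is a hypothetical helper that assumes batch size 1 and reports decimal GB.

```python
def kv_gb(n_layers=80, n_kv_heads=64, head_dim=128, seq_len=32_768, dtype_bytes=2):
    """KV cache size in decimal GB, batch size 1 (same formula as kv_cache_bytes)."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * dtype_bytes / 1e9

mha_fp16 = kv_gb()                                               # MHA, FP16
gqa_int8 = kv_gb(n_kv_heads=8, dtype_bytes=1)                    # GQA-8 + INT8
gqa_int8_win = kv_gb(n_kv_heads=8, dtype_bytes=1, seq_len=4096)  # + 4k window

print(f"{mha_fp16:.1f} GB -> {gqa_int8:.1f} GB ({mha_fp16 / gqa_int8:.0f}x)")  # 85.9 GB -> 5.4 GB (16x)
print(f"with a 4k window: {gqa_int8_win:.2f} GB")                              # 0.67 GB
```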

In [None]:
# ── Summary bar chart ────────────────────────────────────────────────────────
labels = [
    "Baseline\nMHA FP16",
    "GQA-8\nFP16",
    "MHA\nINT8",
    "GQA-8\nINT8",
    "Sliding\nW=4k",
    "MQA\nFP16",
]
gbs = [
    kv_cache_bytes(**cfg, seq_len=32768, dtype_bytes=2) / 1e9,
    kv_cache_bytes(**cfg, n_kv_heads=8, seq_len=32768, dtype_bytes=2) / 1e9,
    kv_cache_bytes(**cfg, seq_len=32768, dtype_bytes=1) / 1e9,
    kv_cache_bytes(**cfg, n_kv_heads=8, seq_len=32768, dtype_bytes=1) / 1e9,
    kv_cache_bytes(**cfg, seq_len=4096,  dtype_bytes=2) / 1e9,
    kv_cache_bytes(**cfg, n_kv_heads=1,  seq_len=32768, dtype_bytes=2) / 1e9,
]
colours = [C["baseline"], C["gqa"], C["quant"], C["streaming"], C["sliding"], C["paged"]]

fig, ax = plt.subplots(figsize=(12, 5))
bars = ax.bar(labels, gbs, color=colours, edgecolor="white", linewidth=0.5)

for bar, gb in zip(bars, gbs):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5,
            f"{gb:.1f} GB", ha="center", va="bottom", fontsize=9, fontweight="bold")

ax.set_ylabel("KV Cache Memory (GB)")
ax.set_title("Llama-3 70B · 32k context · KV Cache Footprint by Strategy", fontsize=13)
ax.set_ylim(0, max(gbs) * 1.18)
ax.grid(axis="y", alpha=0.3)
ax.spines[["top", "right"]].set_visible(False)

plt.tight_layout()
plt.savefig("kv_cache_strategies_bar.png", dpi=150, bbox_inches="tight")
plt.show()
print("Figure saved to kv_cache_strategies_bar.png")

---
## Summary

| # | Technique | Core Idea | Where to Apply |
|---|-----------|-----------|----------------|
| 1 | **Full KV Cache** | Store all past K/V | ❌ The problem — O(n) VRAM |
| 2 | **Sliding Window** | Forget tokens older than W | Chat agents, real-time streaming |
| 3 | **StreamingLLM / Attention Sinks** | Keep sink + recent | Infinite generation without fine-tuning |
| 4 | **GQA / MQA** | Fewer KV heads, shared across Q | Model training time (most impactful) |
| 5 | **KV Quantization** | INT8/INT4 for K/V tensors | Drop-in at inference, ~2–4× saving |
| 6 | **PagedAttention** | Non-contiguous memory pages | Multi-user serving (vLLM, TGI) |

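The **state-of-the-art production stack** (e.g. vLLM + Llama-3) can combine 4 + 5 + 6. GQA and KV quantization shrink the per-token cost (about 16× for GQA-8 + INT8 on the 70B shape), and paging keeps the remaining memory densely packed. Memory still grows linearly with context, which is why windowing (2, 3) remains useful for unbounded sessions.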