
# The KV Cache: What It Is, Why It Matters, and How to Size It for Modern LLMs

Author: John Hodge

Date: 09/03/2025

*If you only remember one thing:* During generation, large language models are often **memory-bandwidth bound**, not compute-bound. The **Key/Value (KV) cache** is the reason: it turns decoding into a fast lookup problem but makes **HBM bandwidth** and **VRAM** the dominant resources. Mastering KV cache mechanics is table stakes for building fast, cost-efficient LLM systems.

---

## 1) Quick intuition

In decoder-only transformers, each new token must attend to **all prior tokens**. Naïvely, you’d recompute keys and values (K, V) for the entire prefix at every step—wasting tons of compute. Instead, we compute K and V **once** and **cache** them per layer. On the next token:

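1. Compute the new token's query, key, and value: **Qₜ, Kₜ, Vₜ**
2. **Read** cached **K₁…ₜ₋₁, V₁…ₜ₋₁**
3. Attend with **Qₜ** over **K₁…ₜ, V₁…ₜ** (cached entries plus the new pair), finish the layer, and after the last layer produce logits
4. **Append** **Kₜ, Vₜ** to the cache for future steps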

**Result:** Huge compute savings, but your performance now hinges on how quickly you can **read** (and grow) that cache.
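To make the steps concrete, here is a minimal pure-Python sketch. It assumes a toy model with one layer, one attention head, random weights, and no MLP or output projection (all hypothetical simplifications). `ToyKVCache.step` projects only the new token and attends over the cached K/V, while `recompute_step` re-projects K and V for the whole prefix at every step; the two should agree exactly.

```python
import math
import random

def matvec(W, x):
    """Multiply a (d_out x d_in) matrix, stored as a list of rows, by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    """Single-head scaled dot-product attention of one query over all keys."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)                                  # subtract max for stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / z
            for j in range(len(values[0]))]

class ToyKVCache:
    """Hypothetical one-layer, one-head KV cache (no MLP, no output projection)."""

    def __init__(self):
        self.keys, self.values = [], []

    def step(self, x, Wq, Wk, Wv):
        q, k, v = matvec(Wq, x), matvec(Wk, x), matvec(Wv, x)  # new token only
        out = attend(q, self.keys + [k], self.values + [v])    # K_1..K_t, V_1..V_t
        self.keys.append(k)                                    # cache for later steps
        self.values.append(v)
        return out

def recompute_step(prefix, Wq, Wk, Wv):
    """Naive baseline: re-project K and V for the whole prefix every step."""
    keys = [matvec(Wk, x) for x in prefix]
    values = [matvec(Wv, x) for x in prefix]
    return attend(matvec(Wq, prefix[-1]), keys, values)

rng = random.Random(0)
D, STEPS = 4, 6

def rand_matrix():
    return [[rng.gauss(0.0, 1.0) for _ in range(D)] for _ in range(D)]

Wq, Wk, Wv = rand_matrix(), rand_matrix(), rand_matrix()
tokens = [[rng.gauss(0.0, 1.0) for _ in range(D)] for _ in range(STEPS)]

cache, max_diff = ToyKVCache(), 0.0
for t in range(1, STEPS + 1):
    cached = cache.step(tokens[t - 1], Wq, Wk, Wv)
    naive = recompute_step(tokens[:t], Wq, Wk, Wv)
    max_diff = max(max_diff, max(abs(a - b) for a, b in zip(cached, naive)))
print(f"{len(cache.keys)} cached K/V pairs, max |cached - recomputed| = {max_diff:.1e}")
```

The cache removes the repeated K/V projections of the prefix, but attention still touches all $t$ cached entries at every step, which is exactly the read traffic the rest of this article is about.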

---

## 2) Prefill vs. Decode (the two phases)

* **Prefill (prompt processing):** Run the model on the entire prompt once to build the cache. This is relatively **compute-heavy** and benefits from high TFLOPs.
* **Decode (token-by-token):** For each generated token, read all prior K/V from cache. This phase is typically **memory-bandwidth bound** (HBM GB/s or TB/s dominates).

Throughput (tokens/s) in decode falls as context grows because you read more K/V every step.

---

## 3) How big is the KV cache?

For one token, one layer:

$$
\text{K/V bytes per token per layer} = 2 \times n_{\text{kv\_heads}} \times d_{\text{head}} \times \text{bytes\_per\_elem}
$$

Across all $L$ layers:

$$
\text{KV bytes per token} = 2 \times L \times n_{\text{kv\_heads}} \times d_{\text{head}} \times \text{bytes\_per\_elem}
$$

**Example (LLaMA-ish 70B):** $L=80$, $n_{\text{heads}}=64$, $d_{\text{head}}=128$, bf16 → 2 bytes/elem

* **Standard MHA** (KV heads = 64):
  $2 \times 80 \times 64 \times 128 \times 2 = 2{,}621{,}440$ bytes ≈ **2.6 MB (2.5 MiB) per token**
  2k tokens → **\~5.4 GB (5 GiB) per sequence** (KV only).
* **GQA** (e.g., 8 KV heads): **320 KiB/token** → **\~0.67 GB** at 2k tokens.
* **MQA** (1 KV head): **40 KiB/token** → **\~84 MB (80 MiB)** at 2k tokens.

> **GQA/MQA** are massive wins for long context and batching: they shrink KV by a factor of $n_{\text{heads}} / n_{\text{kv\_heads}}$, here **8×** (GQA-8) to **64×** (MQA).
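A quick way to check the arithmetic above is to evaluate the per-token formula for each attention variant. This sketch simply hard-codes the LLaMA-ish 70B shape from the example; `kv_bytes_per_token` is an illustrative helper, not a library function.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, d_head, bytes_per_elem):
    # K and V (factor 2) for every layer, each n_kv_heads x d_head elements
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_elem

N_LAYERS, N_HEADS, D_HEAD, BF16_BYTES, T = 80, 64, 128, 2, 2048

table = {}
for name, n_kv in [("MHA", 64), ("GQA-8", 8), ("MQA", 1)]:
    per_tok = kv_bytes_per_token(N_LAYERS, n_kv, D_HEAD, BF16_BYTES)
    table[name] = (per_tok, per_tok * T)
    print(f"{name:6s} {per_tok / 2**10:5.0f} KiB/token  "
          f"{per_tok * T / 1e9:6.3f} GB per {T}-token sequence  "
          f"reduction vs. MHA: {N_HEADS // n_kv}x")
```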

---

## 4) Bandwidth: why decode is memory-bound

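At decode step $t$, attention must **read** K/V for the first $t-1$ tokens in every layer, so the bytes read per generated token grow \~linearly with context length. Even with efficient kernels, the dominant cost is **bytes moved** from HBM (the model weights, read once per step and shared across the batch, plus every sequence's K/V, whose share grows with context length and concurrency):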

* **Arithmetic intensity** (FLOPs per byte moved) is **low** in decode.
* Faster TFLOPs don’t help if HBM can’t feed the GPU—**you’re memory-bandwidth bound**.

That’s why **“cost per memory bandwidth”** (e.g., **\$/TB/s·hour**) is often a better predictor of real-world decode throughput/\$ than TFLOPs/\$.
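The low arithmetic intensity can be made concrete with a back-of-the-envelope count for one decode step of attention. This sketch counts only the two matrix-vector products per query head and ignores softmax; the accelerator figures are hypothetical placeholders, not any specific GPU's specs.

```python
def decode_attention_intensity(n_heads, n_kv_heads, bytes_per_elem=2):
    """FLOPs per byte of K/V read during one decode step of attention.

    Per KV head, with context t and head dim d:
      bytes read = 2 * t * d * bytes_per_elem       (K and V)
      FLOPs      = g * (2*t*d + 2*t*d)             (Q.K^T and P.V for g query heads)
    where g = n_heads // n_kv_heads query heads share each KV head.
    t and d cancel, leaving 2 * g / bytes_per_elem. Softmax FLOPs are ignored.
    """
    g = n_heads // n_kv_heads
    return 2 * g / bytes_per_elem

# Hypothetical accelerator: 1,000 TFLOP/s dense bf16 and 3 TB/s HBM (placeholders)
ridge = 1000e12 / 3e12   # FLOPs/byte needed before compute becomes the limit

for name, n_kv in [("MHA", 64), ("GQA-8", 8), ("MQA", 1)]:
    ai = decode_attention_intensity(64, n_kv)
    print(f"{name:6s} {ai:5.1f} FLOPs/byte vs. ridge ~{ridge:.0f}")
```

The result, $2g / \text{bytes\_per\_elem}$ with $g$ query heads per KV head, does not depend on context length, and batching does not raise it because every sequence reads its own K/V. Even MQA sits far below the hypothetical ridge point used here.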

---

## 5) System-level implications

### 5.1 VRAM limits concurrency

Total KV memory scales with **sequence length × (batch or beams) × layers × KV heads × head dim × bytes per element**. At long contexts, KV can dominate VRAM. This caps:

* How many requests you can run concurrently on one GPU
* How wide you can set **beam search**
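A rough concurrency cap follows from dividing the VRAM left after weights by the KV footprint of one sequence. This is a sketch under stated assumptions (4 × 80 GB of VRAM, 70B parameters in bf16, a 10% reserve for activations and allocator overhead, 2,048-token sequences); `max_concurrent_sequences` is a hypothetical helper.

```python
def kv_gb_per_sequence(n_layers, n_kv_heads, d_head, tokens, bytes_per_elem=2):
    """KV footprint of one sequence in GB (10^9 bytes)."""
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_elem * tokens / 1e9

def max_concurrent_sequences(total_vram_gb, weights_gb, kv_gb_per_seq, reserve_frac=0.1):
    """Sequences (or beams) whose KV fits after weights and a safety reserve."""
    usable_gb = total_vram_gb * (1 - reserve_frac) - weights_gb
    return max(0, int(usable_gb // kv_gb_per_seq))

# Assumptions (hypothetical): 4 x 80 GB GPUs, 70B params in bf16 -> 140 GB of weights.
TOTAL_VRAM_GB, WEIGHTS_GB = 4 * 80, 70e9 * 2 / 1e9
for name, n_kv in [("MHA", 64), ("GQA-8", 8)]:
    per_seq = kv_gb_per_sequence(80, n_kv, 128, 2048)
    n = max_concurrent_sequences(TOTAL_VRAM_GB, WEIGHTS_GB, per_seq)
    print(f"{name:6s} {n:4d} sequences  ->  {n // 4:3d} requests at beam width 4")
```

Beams are counted as full, independent sequences here; engines that share prefix blocks between beams can fit more.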

### 5.2 Parallelism & placement

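* With **tensor parallelism**, KV is usually **sharded by heads** across GPUs, so each GPU stores and reads only its own heads' K/V. Attention itself needs no cross-GPU KV reads; the cross-GPU traffic comes from per-layer collectives (e.g., all-reduce of activations), so interconnect bandwidth and latency still matter. With GQA/MQA, if the TP degree exceeds the number of KV heads, KV heads are typically replicated rather than split.
* **Pipeline parallelism** places groups of layers on different GPUs; each stage keeps the KV for its own layers locally, but you add pipeline bubbles.
* **Data parallelism** (independent model replicas) doesn’t reduce KV size per GPU; it scales aggregate throughput, and in training it adds gradient all-reduce traffic.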
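Under tensor parallelism, the per-GPU KV cost depends on how KV heads map to ranks. The sketch below assumes heads are split evenly when the TP degree is at most the KV-head count and replicated (one per rank) otherwise; that policy is an assumption for illustration, not a description of any particular engine.

```python
def kv_heads_per_gpu(n_kv_heads, tp):
    """KV heads stored on each tensor-parallel rank (assumed sharding policy)."""
    if tp <= n_kv_heads:
        if n_kv_heads % tp:
            raise ValueError("assumes KV heads divide evenly across ranks")
        return n_kv_heads // tp
    # More ranks than KV heads: each rank keeps a replicated copy of one head
    if tp % n_kv_heads:
        raise ValueError("assumes ranks divide evenly across KV heads")
    return 1

def kv_bytes_per_token_per_gpu(n_layers, n_kv_heads, d_head, tp, bytes_per_elem=2):
    return 2 * n_layers * kv_heads_per_gpu(n_kv_heads, tp) * d_head * bytes_per_elem

for name, n_kv in [("MHA", 64), ("GQA-8", 8), ("MQA", 1)]:
    per_gpu = kv_bytes_per_token_per_gpu(80, n_kv, 128, tp=8)
    print(f"{name:6s} TP=8: {per_gpu / 2**10:4.0f} KiB/token per GPU")
```

Under this policy, GQA-8 and MQA store the same KV per GPU at TP = 8, so MQA's extra savings show up only when the TP degree is below the number of GQA KV heads.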

### 5.3 Serving architecture

* **Paged KV / block allocators** avoid fragmentation and enable **continuous batching** (mixing requests at different decode steps).
* **Prefix sharing/caching**: If multiple prompts share prefixes, reuse the same KV pages.
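* **Speculative decoding**: A cheap draft proposes several tokens and the target model verifies them in one forward pass, cutting per-token latency; KV entries for accepted tokens are kept, while those for rejected tokens are rolled back.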
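The paged-KV and prefix-sharing ideas can be sketched as a toy block allocator, loosely in the spirit of PagedAttention-style serving engines. Everything here (`PagedKVAllocator`, its methods, the copy-on-write rule) is a hypothetical illustration that tracks only block ids; a real engine also stores the K/V tensors and manages GPU memory.

```python
class PagedKVAllocator:
    """Toy paged-KV block allocator (hypothetical; tracks block ids only)."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size            # tokens per block
        self.free = list(range(num_blocks))     # free block ids
        self.refcount = [0] * num_blocks        # sequences sharing each block
        self.tables = {}                        # seq_id -> list of block ids
        self.lengths = {}                       # seq_id -> tokens stored

    def _alloc_block(self):
        if not self.free:
            raise MemoryError("out of KV blocks: queue or preempt a request")
        block = self.free.pop()
        self.refcount[block] = 1
        return block

    def append_token(self, seq_id):
        """Reserve a slot for one more token of `seq_id`."""
        table = self.tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:
            table.append(self._alloc_block())   # crossed a block boundary
        elif self.refcount[table[-1]] > 1:
            # Copy-on-write: the partly filled last block is shared, so this
            # sequence gets a private block before writing (data copy omitted).
            self.refcount[table[-1]] -= 1
            table[-1] = self._alloc_block()
        self.lengths[seq_id] = n + 1

    def fork(self, parent, child):
        """Let `child` share every block of `parent` (e.g., a common prompt)."""
        self.tables[child] = list(self.tables[parent])
        self.lengths[child] = self.lengths[parent]
        for block in self.tables[child]:
            self.refcount[block] += 1

    def release(self, seq_id):
        """Drop a finished sequence; blocks return to the pool once unshared."""
        for block in self.tables.pop(seq_id):
            self.refcount[block] -= 1
            if self.refcount[block] == 0:
                self.free.append(block)
        del self.lengths[seq_id]


alloc = PagedKVAllocator(num_blocks=8, block_size=16)
for _ in range(40):             # 40-token prompt -> 3 blocks (16 + 16 + 8 tokens)
    alloc.append_token("req-a")
alloc.fork("req-a", "req-b")    # same prompt: no new blocks needed
alloc.append_token("req-b")     # writes into the shared partial block -> copy-on-write
print(len(alloc.free), alloc.tables["req-a"], alloc.tables["req-b"])
```

Blocks are allocated only when a sequence crosses a block boundary, so the waste per sequence is at most one partially filled block, and a forked request reuses the parent's prompt blocks until it writes into a shared one.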

---

## 6) Practical ways to reduce KV pressure

* **Architectural**

  * **GQA/MQA**: Reduce KV heads dramatically.
  * **Windowed/sliding attention**: Keep a moving context window; evict old tokens.
  * **Sparse/block attention**: Read fewer K/V positions per step.

* **Kernel-level**

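  * **Flash-Decoding / IO-aware kernels**: Tile the computation so each K/V block is read from HBM once, and (in Flash-Decoding) split long contexts across thread blocks so even a single sequence can use the GPU's full memory bandwidth.
  * **Fused ops** (e.g., QKᵀ, softmax, and the multiply by V in one kernel; fused QKV projections) to cut extra HBM reads/writes of intermediates.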

* **Compression**

  * **KV quantization** (INT8/FP8, 4-bit): Trade a little quality for big memory/bandwidth savings.
  * **Mixed precision**: bf16/FP8 variations where supported.

* **Scheduling**

  * **Continuous batching** with smart admission control (avoid head-of-line blocking).
  * **Latency classes**: Group similar sequence lengths to reduce worst-case cache reads.
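As an example of the compression bullet, here is a minimal sketch of one common scheme: symmetric INT8 quantization with a single scale per K or V vector. Real engines differ in granularity (per token, per head, per channel) and format (INT8, FP8, 4-bit); the one-fp16-scale-per-vector storage layout below is an assumption.

```python
import random

def quantize_int8(vec):
    """Symmetric INT8 quantization of one K or V vector with a single scale."""
    amax = max(abs(x) for x in vec)
    scale = amax / 127 if amax > 0 else 1.0
    codes = [max(-127, min(127, round(x / scale))) for x in vec]
    return codes, scale

def dequantize_int8(codes, scale):
    return [c * scale for c in codes]

rng = random.Random(1)
D_HEAD = 128
key = [rng.gauss(0.0, 1.0) for _ in range(D_HEAD)]   # stand-in for one head's K_t

codes, scale = quantize_int8(key)
restored = dequantize_int8(codes, scale)
max_err = max(abs(a - b) for a, b in zip(key, restored))

bf16_bytes = D_HEAD * 2          # original storage at 2 bytes/elem
int8_bytes = D_HEAD * 1 + 2      # INT8 codes + one fp16 scale (assumed layout)
print(f"max abs error {max_err:.4f} (bound {scale / 2:.4f}), "
      f"{bf16_bytes / int8_bytes:.2f}x smaller")
```

Rounding to the nearest code bounds the per-element error by half a quantization step ($\text{scale}/2$), while the cache (and the bytes each decode step reads) shrinks by almost 2× relative to bf16.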

---

## 7) Sizing worksheet

### 7.1 KV memory footprint

For given $L, n_{\text{kv}}, d_{\text{head}}, \text{bytes}$, and target **context** $T$, **concurrency** $C$ (or beams), estimate:

$$
\text{KV VRAM (GB)} \approx
\frac{2 \cdot L \cdot n_{\text{kv}} \cdot d_{\text{head}} \cdot \text{bytes} \cdot T \cdot C}{10^9}
\times \phi
$$

$\phi$ is an overhead factor (allocator fragmentation, metadata), typically **1.1–1.3** in practice.

### 7.2 Required HBM bandwidth for a target TPS

At average context $\bar{t}$ (tokens already in cache), **per new token** reads roughly:

$$
\text{Bytes/token} \approx 2 \cdot L \cdot n_{\text{kv}} \cdot d_{\text{head}} \cdot \text{bytes} \cdot \bar{t}
$$

For a target throughput $R$ tokens/s:

$$
\text{HBM BW (GB/s)} \approx \frac{\text{Bytes/token} \cdot R}{10^9}
$$

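Compare this with the **usable** HBM GB/s of the GPU(s). This counts KV reads only; each decode step also reads the model weights once (shared across the batch), which adds roughly $W \cdot R / B$ bytes/s, where $W$ is the weight size in bytes and $B$ the batch size. If the total exceeds what the hardware delivers, you’re memory-bound: reduce bytes/token or add GPUs.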
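The KV-only formula above leaves out the weight reads. A slightly fuller sketch, assuming every decode step reads all weights once and that the target $R$ is aggregate over a batch of $B$ sequences, splits the required bandwidth into the two parts. The 70B/bf16, batch-16 inputs are hypothetical, and `required_hbm_gb_s` is an illustrative helper.

```python
def required_hbm_gb_s(tps, batch, weight_bytes, kv_bytes_per_token, avg_ctx):
    """HBM read bandwidth (GB/s) to sustain `tps` aggregate tokens/s in decode.

    Each decode step reads the weights once (shared by `batch` sequences) and
    each sequence's K/V for avg_ctx cached tokens. Returns (weights, kv) parts.
    """
    steps_per_s = tps / batch
    weights_gb_s = weight_bytes * steps_per_s / 1e9
    kv_gb_s = kv_bytes_per_token * avg_ctx * tps / 1e9
    return weights_gb_s, kv_gb_s

# Hypothetical inputs: 70B params in bf16, batch 16, avg context 1,536, 2,000 tokens/s.
WEIGHT_BYTES = 70e9 * 2
for name, n_kv in [("MHA", 64), ("GQA-8", 8)]:
    kv_per_tok = 2 * 80 * n_kv * 128 * 2
    w, kv = required_hbm_gb_s(2000, 16, WEIGHT_BYTES, kv_per_tok, 1536)
    print(f"{name:6s} weights ~{w:,.0f} GB/s + KV ~{kv:,.0f} GB/s = ~{w + kv:,.0f} GB/s")
```

With these inputs the weight term dominates because each step's weight read is shared by only 16 sequences. The KV term overtakes it once the total cached context in the batch, $B \cdot \bar{t}$, exceeds the weight bytes divided by the KV bytes per token (about 427k tokens for GQA-8 and 53k for MHA here).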

---

## 8) Tiny helper: compute KV size & decode BW (Python)

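```python
def kv_bytes_per_token(L, n_kv, d_head, bytes_per_elem=2):
    """K and V bytes for one token across all L layers."""
    return 2 * L * n_kv * d_head * bytes_per_elem  # bytes

def kv_vram_gb(L, n_kv, d_head, T, concurrency, bytes_per_elem=2, overhead=1.2):
    """KV footprint in GB (10^9 bytes) for `concurrency` sequences of T tokens.

    `overhead` is the phi factor for allocator fragmentation and metadata.
    """
    per_tok = kv_bytes_per_token(L, n_kv, d_head, bytes_per_elem)
    return overhead * per_tok * T * concurrency / 1e9

def decode_bandwidth_gb_s(L, n_kv, d_head, avg_ctx, tps, bytes_per_elem=2):
    """KV read bandwidth in GB/s (bytes, not bits) to sustain `tps` tokens/s.

    Counts KV reads only; model-weight reads come on top of this.
    """
    per_tok_read = kv_bytes_per_token(L, n_kv, d_head, bytes_per_elem) * avg_ctx
    return per_tok_read * tps / 1e9  # bytes/s -> GB/s, comparable to HBM specs
```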

**Try it:**

* Standard MHA (64 KV heads) vs. GQA (8) at $L{=}80, d_{\text{head}}{=}128$
* Concurrency $C = 16$, context $T = 2{,}048$, aggregate TPS target $R = 2{,}000$, average context $\bar{t} \approx 1{,}536$

You’ll see why MHA quickly blows VRAM and HBM budgets, while GQA makes the same goal feasible.
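The "Try it" numbers can be checked with a self-contained snippet. It repeats the one-line KV formula rather than importing the helpers above, uses the default $\phi = 1.2$, and treats the 2,000 tokens/s target as aggregate across the 16 sequences (an assumption).

```python
def kv_bytes_per_token(L, n_kv, d_head, bytes_per_elem=2):
    return 2 * L * n_kv * d_head * bytes_per_elem

L, D_HEAD = 80, 128
T, CONCURRENCY, TPS, AVG_CTX, PHI = 2048, 16, 2000, 1536, 1.2

results = {}
for name, n_kv in [("MHA", 64), ("GQA-8", 8)]:
    per_tok = kv_bytes_per_token(L, n_kv, D_HEAD)
    vram_gb = PHI * per_tok * T * CONCURRENCY / 1e9     # Section 7.1
    kv_read_gb_s = per_tok * AVG_CTX * TPS / 1e9        # Section 7.2 (KV only)
    results[name] = (vram_gb, kv_read_gb_s)
    print(f"{name:6s} KV VRAM ~{vram_gb:6.1f} GB   KV reads ~{kv_read_gb_s:7.1f} GB/s")
```

Under these inputs MHA needs roughly 103 GB of KV and about 8 TB/s of KV reads, while GQA-8 needs about 13 GB and 1 TB/s, before counting weights.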

---

## 9) KV cache in **training** vs **inference**

* **Inference (autoregressive decode):** KV cache is central; bottleneck is memory bandwidth.
* **Standard full-sequence training:** You usually don’t keep a persistent KV cache across steps; attention runs over the whole sequence at once. Bottlenecks are mixed (compute + memory).
* **Streaming/segment training** (e.g., recurrent or state-carrying variants): You *do* carry state reminiscent of KV across segments—plan memory/bandwidth similarly to inference.

---

## 10) Common pitfalls

* **Underestimating VRAM:** KV scales with **context × concurrency × beams**; add a safety factor for allocator overhead.
* **Ignoring topology:** In multi-GPU serving, keep tensor-parallel groups (and thus their KV shards) within a single NVLink/NVSwitch domain; per-layer collectives over weaker links such as PCIe inflate tail latency.
* **Only tracking TFLOPs:** For decode, monitor **HBM bandwidth** and **bytes/token**. TFLOPs alone won’t explain performance.
* **Beam search blowups:** KV grows with beam width—plan for the worst case or cap beam size.

---

## 11) What to monitor in production

* **Tokens/sec vs. context length** (watch the slope)
* **HBM bandwidth utilization** (GPU profiler metrics)
* **VRAM usage** split by weights / activations / KV
* **Latency distribution** under mixed sequence lengths (P95/P99)
* **Fragmentation** / page efficiency if you use a paged KV allocator
* **Cross-GPU traffic** if sharding by heads (bytes/s, tail events)

---

## 12) Tying back to **cost-per-bandwidth**

For memory-bound decode, rank hardware by:

$$
\boxed{\;\$ / \text{TB/s·hour} = \frac{\$ / \text{GPU-hour}}{\text{HBM TB/s per GPU}}\;}
$$

Pick the SKU (and interconnect) that gives enough HBM **and** the lowest **\$/TB/s·hr** while meeting your latency/QPS targets. Then push bytes/token down with GQA/MQA, windowed/sparse attention, IO-aware kernels, and KV quantization.
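The ranking rule can be sketched in a few lines: drop SKUs whose HBM capacity can't hold the per-GPU weights plus KV, then sort by \$/TB/s·hour. The SKU names, prices, bandwidths, and the 64 GB capacity requirement below are made-up placeholders, not real quotes.

```python
from dataclasses import dataclass

@dataclass
class Sku:
    name: str
    usd_per_gpu_hour: float
    hbm_tb_s: float   # usable HBM bandwidth per GPU, TB/s
    hbm_gb: float     # HBM capacity per GPU, GB

def usd_per_tb_s_hour(sku: Sku) -> float:
    return sku.usd_per_gpu_hour / sku.hbm_tb_s

def rank_by_bandwidth_cost(skus, min_hbm_gb):
    """Drop SKUs without enough HBM capacity, then sort cheapest $/TB/s-hour first."""
    fits = [s for s in skus if s.hbm_gb >= min_hbm_gb]
    return sorted(fits, key=usd_per_tb_s_hour)

# Placeholder SKUs and prices, not real quotes.
catalog = [
    Sku("sku-a", 2.00, 2.0, 80),
    Sku("sku-b", 4.00, 3.2, 80),
    Sku("sku-c", 6.50, 4.8, 141),
    Sku("sku-d", 1.00, 0.9, 48),
]
for s in rank_by_bandwidth_cost(catalog, min_hbm_gb=64):
    print(f"{s.name}: ${usd_per_tb_s_hour(s):.2f} per TB/s-hour ({s.hbm_gb:.0f} GB)")
```

Here sku-d has the second-lowest \$/TB/s·hour but is filtered out for capacity, which is the "enough HBM **and** lowest \$/TB/s·hr" rule in action.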

---

## 13) Summary

The KV cache stores per-layer keys and values for all past tokens so an LLM can generate the next token by computing a fresh query and *looking up* past K/V rather than recomputing the entire prefix. It dramatically reduces compute but shifts the bottleneck to **memory**: cache size grows linearly with sequence length (and batch/beam), and decode becomes **HBM-bandwidth bound**. That’s why long-context serving cares more about **bytes moved** than TFLOPs. In practice we shrink and streamline KV with **GQA/MQA**, **paged KV**, **Flash-Decoding**, **windowed/sparse attention**, and **KV quantization**—and we size fleets using a **cost-per-memory-bandwidth** lens, not just TFLOPs/\$.

---

### Final takeaway

If your LLM serving doesn’t talk explicitly about **KV bytes/token**, **HBM GB/s**, **VRAM headroom**, and **\$/TB/s·hr**, you’re flying blind. Make KV cache a first-class design dimension—from model choice (GQA/MQA) to kernels, parallelism, and cluster economics.
