# Understanding KV Cache

The KV cache is what makes disaggregated serving possible, and expensive. Before splitting inference across nodes, we need to understand what this cache contains, how large it is, and what it costs to transfer.

## What is KV Cache?

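In transformer models, each new token attends to all previous tokens. Without caching, every decode step would recompute the attention keys and values for the entire sequence so far, so the repeated work grows quadratically with output length. The KV cache computes each token's keys and values once and reuses them on every later step.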

**Systems analogy**: KV cache is session state in a web server. Prefill populates it (like initial login), decode reads from it (like subsequent API calls), and transfer moves it between servers (like session replication).
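Here is a minimal sketch in pure Python that makes the mechanism concrete. It uses a toy single-head attention with random weights and a hypothetical head dimension of 4. It is not Llama's implementation: there are no multiple heads, no RoPE, and no batching. At each step the cached path projects only the newest token's key and value and appends them. The assertion checks that the result matches recomputing keys and values for the whole prefix.

```python
import math
import random

random.seed(0)
D = 4  # toy head dimension (hypothetical; Llama-3.1-8B uses 128)


def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


# Toy query/key/value projection weights (random, for illustration only)
W_Q, W_K, W_V = rand_matrix(D, D), rand_matrix(D, D), rand_matrix(D, D)


def project(weights, x):
    """Matrix-vector product: one linear projection of a token vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over all given positions."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(D) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(D)]


def step_without_cache(prefix):
    # Recompute K and V for every token in the prefix on every step
    keys = [project(W_K, t) for t in prefix]
    values = [project(W_V, t) for t in prefix]
    return attend(project(W_Q, prefix[-1]), keys, values)


class KVCache:
    def __init__(self):
        self.keys, self.values = [], []

    def step(self, token):
        # Project only the newest token, then attend over everything cached so far
        self.keys.append(project(W_K, token))
        self.values.append(project(W_V, token))
        return attend(project(W_Q, token), self.keys, self.values)


tokens = [[random.uniform(-1, 1) for _ in range(D)] for _ in range(6)]
cache = KVCache()
for i, token in enumerate(tokens):
    cached_out = cache.step(token)
    recomputed_out = step_without_cache(tokens[: i + 1])
    assert all(abs(a - b) < 1e-12 for a, b in zip(cached_out, recomputed_out))

print(f"Cached output matches recomputation for {len(tokens)} steps")
print(f"Cache holds {len(cache.keys)} keys and {len(cache.values)} values")
```

Both paths give the same output. `step_without_cache` repeats every earlier projection on every step. `KVCache.step` does one new projection per step and pays for that saving in memory instead.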

## Why This Matters for Disaggregation

When prefill runs on spark-01 and decode runs on spark-02:
1. spark-01 processes the prompt and generates the KV cache
2. The KV cache transfers from spark-01 to spark-02
3. spark-02 uses the cache to generate tokens
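The three steps can be sketched as a single-process simulation. The sketch assumes batch size 1 and uses zero-filled byte buffers in place of real GPU tensors. The function names are hypothetical and are not a vLLM API. A real connector may also send layers or blocks individually rather than one concatenated buffer. What the sketch shows is that the payload is exactly the per-layer K and V bytes spark-01 produced, so its size follows directly from the architecture constants.

```python
# Architecture constants from the Step 1 table (Llama-3.1-8B-Instruct)
NUM_LAYERS, NUM_KV_HEADS, HEAD_DIM, BYTES_PER_ELEMENT = 32, 8, 128, 2


def prefill_on_spark01(seq_len):
    """Hypothetical stand-in for prefill: one zero-filled K and V buffer per layer."""
    tensor_bytes = NUM_KV_HEADS * seq_len * HEAD_DIM * BYTES_PER_ELEMENT
    return [(bytes(tensor_bytes), bytes(tensor_bytes)) for _ in range(NUM_LAYERS)]


def serialize_for_transfer(kv_cache):
    """Concatenate every layer's K and V into one payload for the wire."""
    return b"".join(k + v for k, v in kv_cache)


def receive_on_spark02(payload, seq_len):
    """Split the payload back into per-layer (K, V) buffers for decode."""
    tensor_bytes = NUM_KV_HEADS * seq_len * HEAD_DIM * BYTES_PER_ELEMENT
    layers = []
    for layer in range(NUM_LAYERS):
        start = layer * 2 * tensor_bytes
        k = payload[start : start + tensor_bytes]
        v = payload[start + tensor_bytes : start + 2 * tensor_bytes]
        layers.append((k, v))
    return layers


seq_len = 16
cache = prefill_on_spark01(seq_len)            # step 1: prefill builds the cache
payload = serialize_for_transfer(cache)        # step 2: cache crosses the network
restored = receive_on_spark02(payload, seq_len)  # step 3: decode node rebuilds it

assert restored == cache
print(f"{seq_len} tokens -> {len(payload):,} bytes on the wire")
```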

Transfer speed determines whether disaggregation is worth the overhead. This notebook calculates the numbers using model architecture constants. No model loading required.

## Step 1: KV Cache Dimensions from Model Architecture

Llama-3.1-8B-Instruct uses Grouped Query Attention (GQA). The relevant architecture constants:

| Parameter | Value | Description |
|-----------|-------|-------------|
| `num_hidden_layers` | 32 | Number of transformer layers |
| `num_key_value_heads` | 8 | KV heads per layer (GQA, not full 32) |
| `head_dim` | 128 | Dimension per attention head |
| `dtype` | float16 | 2 bytes per element (bfloat16 is also 2 bytes) |

Each layer stores one key tensor and one value tensor. Each tensor has shape `[batch_size, num_kv_heads, sequence_length, head_dim]`.
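The hypothetical helpers below sketch how those shapes turn into bytes. They multiply out one tensor's shape and count K and V in every layer. The defaults are the table's values. Passing `num_kv_heads=32` gives the counterfactual full multi-head attention cache, with one KV head per query head. The batch calculation assumes every sequence in the batch has the same length.

```python
from math import prod


def kv_tensor_shape(batch_size, num_kv_heads, seq_len, head_dim):
    """Shape of one key (or value) tensor in one layer."""
    return (batch_size, num_kv_heads, seq_len, head_dim)


def kv_cache_bytes(batch_size, seq_len, num_layers=32, num_kv_heads=8,
                   head_dim=128, bytes_per_element=2):
    """Total KV cache bytes: 2 tensors (K and V) per layer, across all layers."""
    shape = kv_tensor_shape(batch_size, num_kv_heads, seq_len, head_dim)
    return 2 * num_layers * prod(shape) * bytes_per_element


gqa = kv_cache_bytes(batch_size=1, seq_len=1)                   # 8 KV heads (GQA)
mha = kv_cache_bytes(batch_size=1, seq_len=1, num_kv_heads=32)  # one KV head per query head
print(f"GQA: {gqa:,} bytes/token, full MHA: {mha:,} bytes/token ({mha // gqa}x larger)")
print(f"Batch of 4 x 1,024 tokens: {kv_cache_bytes(4, 1024) / 2**20:,.0f} MB")
```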

These are fixed architecture constants. We do not need to load the model to compute cache size.
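Because these are published configuration values, they can be read from the model's `config.json` without touching the weights. The sketch below assumes a Hugging Face-style config. The JSON string is a trimmed excerpt whose values match the table; the published file lists `torch_dtype` as `bfloat16`, which is also 2 bytes. Two fallbacks are assumptions about how older configs are laid out: deriving `head_dim` from `hidden_size / num_attention_heads` when that field is absent, and using one KV head per query head when `num_key_value_heads` is missing.

```python
import json

# Trimmed config.json excerpt (assumed); the real file has many more fields
CONFIG_JSON = """
{
  "hidden_size": 4096,
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "torch_dtype": "bfloat16"
}
"""

BYTES_PER_DTYPE = {"float16": 2, "bfloat16": 2, "float32": 4}


def kv_bytes_per_token(config):
    # Use head_dim if listed; otherwise derive it from hidden_size (assumption)
    head_dim = config.get("head_dim", config["hidden_size"] // config["num_attention_heads"])
    # Without num_key_value_heads, assume full multi-head attention (no GQA)
    kv_heads = config.get("num_key_value_heads", config["num_attention_heads"])
    element_bytes = BYTES_PER_DTYPE[config["torch_dtype"]]
    return 2 * config["num_hidden_layers"] * kv_heads * head_dim * element_bytes


config = json.loads(CONFIG_JSON)
print(f"{kv_bytes_per_token(config):,} bytes per token")
```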

```python
# Llama-3.1-8B-Instruct architecture constants
NUM_LAYERS = 32
NUM_KV_HEADS = 8       # GQA: 8 KV heads, not 32 attention heads
HEAD_DIM = 128
BYTES_PER_ELEMENT = 2   # float16

# Per-token KV cache size
# Each layer: 2 tensors (key + value) x num_kv_heads x head_dim x bytes_per_element
bytes_per_token_per_layer = 2 * NUM_KV_HEADS * HEAD_DIM * BYTES_PER_ELEMENT
bytes_per_token = bytes_per_token_per_layer * NUM_LAYERS
mb_per_token = bytes_per_token / (1024 * 1024)  # "MB" in this notebook means 1024 * 1024 bytes

print("KV Cache Per-Token Size (Llama-3.1-8B-Instruct)")
print("=" * 50)
print(f"Per layer:  {bytes_per_token_per_layer:,} bytes ({bytes_per_token_per_layer / 1024:.1f} KB)")
print(f"All layers: {bytes_per_token:,} bytes ({mb_per_token:.4f} MB)")
print("\nBreakdown:")
print(f"  2 tensors (K+V) x {NUM_KV_HEADS} heads x {HEAD_DIM} dim x {BYTES_PER_ELEMENT} bytes = {bytes_per_token_per_layer:,} bytes/layer")
print(f"  {bytes_per_token_per_layer:,} bytes/layer x {NUM_LAYERS} layers = {bytes_per_token:,} bytes/token")
```

```
KV Cache Per-Token Size (Llama-3.1-8B-Instruct)
==================================================
Per layer:  4,096 bytes (4.0 KB)
All layers: 131,072 bytes (0.1250 MB)

Breakdown:
  2 tensors (K+V) x 8 heads x 128 dim x 2 bytes = 4,096 bytes/layer
  4,096 bytes/layer x 32 layers = 131,072 bytes/token
```


## Step 2: Cache Size at Different Sequence Lengths

The KV cache scales linearly with sequence length. A 100-token prompt produces a 100-token cache. A 4,096-token conversation produces a 4,096-token cache. This linear scaling is what makes transfer cost predictable.
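A multi-turn conversation shows the linear growth directly. In this sketch the per-turn token counts are hypothetical. It also assumes the whole conversation history stays in context, so the cache covers every earlier turn's tokens.

```python
BYTES_PER_TOKEN = 131_072  # per-token KV cache size for Llama-3.1-8B-Instruct (Step 1)

# Hypothetical conversation: tokens added per turn (user prompt + model reply)
turn_tokens = [300, 450, 250, 1000]

context_tokens = 0
for turn, added in enumerate(turn_tokens, start=1):
    context_tokens += added  # assumes the full history stays in context
    cache_mb = context_tokens * BYTES_PER_TOKEN / 2**20
    print(f"Turn {turn}: {context_tokens:>5,} tokens in context -> {cache_mb:7.2f} MB of KV cache")
```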

```python
sequence_lengths = [32, 128, 512, 1024, 2048, 4096]

# Labels for representative workloads; other lengths get no note
NOTES = {
    32: "Short prompt",
    512: "Typical single-turn",
    2048: "Multi-turn conversation",
    4096: "Long context",
}

print(f"{'Sequence Length':<18} {'KV Cache Size':<15} {'Notes'}")
print("-" * 60)

for seq_len in sequence_lengths:
    cache_bytes = bytes_per_token * seq_len
    cache_mb = cache_bytes / (1024 * 1024)
    note = NOTES.get(seq_len, "")
    print(f"{seq_len:<18} {cache_mb:>10.2f} MB    {note}")

print(f"\nScaling: {mb_per_token:.4f} MB per token, linear growth")
```

```
Sequence Length    KV Cache Size   Notes
------------------------------------------------------------
32                       4.00 MB    Short prompt
128                     16.00 MB
512                     64.00 MB    Typical single-turn
1024                   128.00 MB
2048                   256.00 MB    Multi-turn conversation
4096                   512.00 MB    Long context

Scaling: 0.1250 MB per token, linear growth
```
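Linear scaling also means transfer time reduces to a single per-token constant for each transport. This sketch borrows the conservative 1 GB/s (TCP) and 10 GB/s (RDMA) effective rates from Step 3. Like Step 3's code, it divides by 1024**3, so "GB/s" here is effectively GiB/s. `transfer_ms` is a hypothetical helper.

```python
BYTES_PER_TOKEN = 131_072  # per-token KV cache size from Step 1
GIB = 1024**3              # Step 3 divides by 1024**3, so its "GB/s" is treated as GiB/s


def transfer_ms(num_tokens, rate_gb_per_s):
    """Estimated time to move num_tokens of KV cache at a given effective rate."""
    return num_tokens * BYTES_PER_TOKEN / (rate_gb_per_s * GIB) * 1000


# Same conservative effective rates Step 3 assumes
for transport, rate in {"TCP": 1.0, "RDMA": 10.0}.items():
    print(f"{transport:<5} {transfer_ms(1, rate):.4f} ms per token, "
          f"{transfer_ms(4096, rate):.1f} ms for 4,096 tokens")
```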


## Step 3: Transfer Cost, TCP vs RDMA

Disaggregated serving moves the KV cache from the prefill node to the decode node. The transport matters.

**TCP path (6 copies):** GPU memory → CPU memory → kernel buffer → NIC → wire → NIC → kernel buffer → CPU memory → GPU memory

**RDMA path (2 copies):** GPU memory → NIC (GPUDirect) → wire → NIC → GPU memory (GPUDirect)

RDMA with GPUDirect takes the CPU and kernel off the data path. The NIC reads directly from GPU memory on the sending side and writes directly into GPU memory on the receiving side; the CPU only sets up the transfer.

We use conservative estimates of effective throughput on a 100 Gbps link: 1 GB/s for TCP (running over IPoIB, which adds overhead) and 10 GB/s for RDMA. These are assumed rates, not measurements.
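As a sanity check on those estimates, this sketch converts them into a fraction of the link's raw line rate. It uses decimal units (100 Gbps / 8 = 12.5 GB/s) and ignores protocol header overhead, so the percentages are approximate.

```python
LINK_GBPS = 100                      # link speed stated in the text (gigabits per second)
LINE_RATE_GB_PER_S = LINK_GBPS / 8   # 12.5 GB/s raw, decimal units, before protocol overhead


def link_utilization(effective_gb_per_s):
    """Fraction of the raw line rate that an effective throughput represents."""
    return effective_gb_per_s / LINE_RATE_GB_PER_S


for transport, effective in {"TCP": 1.0, "RDMA": 10.0}.items():
    print(f"{transport:<5} {effective:>5.1f} GB/s = {effective * 8:>3.0f} Gbps, "
          f"{link_utilization(effective):.0%} of the {LINK_GBPS} Gbps line rate")
```

Under these assumptions, the TCP estimate uses about 8% of the link and the RDMA estimate about 80%.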

```python
import json
from pathlib import Path

# Load baseline metrics from Notebook 01
baseline_file = Path("baseline_metrics.json")
if baseline_file.exists():
    with open(baseline_file) as f:
        baseline = json.load(f)
    single_latency_ms = baseline['single_request']['latency_ms']
    single_tokens = baseline['single_request']['tokens']
    per_token_ms = single_latency_ms / single_tokens
    print("Baseline from Notebook 01:")
    print(f"  Single request: {single_latency_ms:.1f} ms for {single_tokens} tokens")
    print(f"  Per-token decode: {per_token_ms:.1f} ms/token")
else:
    # Fallback: the Notebook 01 run shown in the output below (6910.1 ms / 200 tokens)
    per_token_ms = 6910.1 / 200
    print(f"baseline_metrics.json not found, using estimate: {per_token_ms:.1f} ms/token")

# Transfer rates (conservative practical estimates)
TCP_GBYTES_PER_SEC = 1.0    # ~8 Gbps effective (IPoIB overhead on 100G link)
RDMA_GBYTES_PER_SEC = 10.0  # ~80 Gbps effective (RDMA over same link)
OUTPUT_TOKENS = 100         # decode length used to compute overhead

print(f"\n{'Seq Len':<10} {'Cache Size':<12} {'TCP Transfer':<15} {'RDMA Transfer':<15} {'Decode Time':<14} {'TCP Overhead':<14} {'RDMA Overhead'}")
print("-" * 105)

for seq_len in [128, 512, 1024, 2048, 4096]:
    cache_bytes = bytes_per_token * seq_len
    cache_mb = cache_bytes / (1024 * 1024)
    cache_gb = cache_bytes / (1024**3)

    tcp_ms = (cache_gb / TCP_GBYTES_PER_SEC) * 1000
    rdma_ms = (cache_gb / RDMA_GBYTES_PER_SEC) * 1000

    # Decode time for OUTPUT_TOKENS output tokens at the baseline rate
    decode_ms = per_token_ms * OUTPUT_TOKENS

    tcp_overhead_pct = (tcp_ms / decode_ms) * 100
    rdma_overhead_pct = (rdma_ms / decode_ms) * 100

    print(f"{seq_len:<10} {cache_mb:>8.2f} MB   {tcp_ms:>10.2f} ms   {rdma_ms:>10.2f} ms   {decode_ms:>10.1f} ms   {tcp_overhead_pct:>10.1f}%   {rdma_overhead_pct:>10.1f}%")

print(f"\nOverhead = transfer time / decode time for {OUTPUT_TOKENS} output tokens")
print(f"Decode time based on baseline: {per_token_ms:.1f} ms/token x {OUTPUT_TOKENS} tokens = {per_token_ms * OUTPUT_TOKENS:.0f} ms")
```

```
Baseline from Notebook 01:
  Single request: 6910.1 ms for 200 tokens
  Per-token decode: 34.6 ms/token

Seq Len    Cache Size   TCP Transfer    RDMA Transfer   Decode Time    TCP Overhead   RDMA Overhead
---------------------------------------------------------------------------------------------------------
128           16.00 MB        15.62 ms         1.56 ms       3455.1 ms          0.5%          0.0%
512           64.00 MB        62.50 ms         6.25 ms       3455.1 ms          1.8%          0.2%
1024         128.00 MB       125.00 ms        12.50 ms       3455.1 ms          3.6%          0.4%
2048         256.00 MB       250.00 ms        25.00 ms       3455.1 ms          7.2%          0.7%
4096         512.00 MB       500.00 ms        50.00 ms       3455.1 ms         14.5%          1.4%

Overhead = transfer time / decode time for 100 output tokens
Decode time based on baseline: 34.6 ms/token x 100 tokens = 3455 ms
```
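The overhead column depends heavily on the fixed 100-token decode length. This sketch varies the output length for a 4,096-token prompt. It uses a hypothetical helper, `transfer_overhead_pct`, with the same rate estimates and the Notebook 01 baseline shown above. Short replies shrink decode time, so the same transfer becomes a larger share of the request. Under these estimates, a 10-token TCP reply spends more time transferring the cache than decoding.

```python
BYTES_PER_TOKEN = 131_072           # per-token KV cache size from Step 1
PER_TOKEN_DECODE_MS = 6910.1 / 200  # Notebook 01 baseline shown above (~34.6 ms/token)


def transfer_overhead_pct(prompt_tokens, output_tokens, rate_gb_per_s):
    """Transfer time as a percentage of decode time, computed as in the table above."""
    transfer_ms = prompt_tokens * BYTES_PER_TOKEN / (rate_gb_per_s * 1024**3) * 1000
    decode_ms = PER_TOKEN_DECODE_MS * output_tokens
    return transfer_ms / decode_ms * 100


print(f"{'Output tokens':<15} {'TCP @ 4,096':>12} {'RDMA @ 4,096':>13}")
for output_tokens in [10, 50, 100, 500]:
    tcp = transfer_overhead_pct(4096, output_tokens, 1.0)    # 1 GB/s TCP estimate
    rdma = transfer_overhead_pct(4096, output_tokens, 10.0)  # 10 GB/s RDMA estimate
    print(f"{output_tokens:<15} {tcp:>11.1f}% {rdma:>12.1f}%")
```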


## Key Takeaways

**KV Cache Structure:**
- Llama-3.1-8B stores 2 tensors (key + value) per layer, 32 layers total
- With GQA, each layer caches only 8 KV heads instead of 32, so the cache is 4x smaller than it would be with full multi-head attention
- Per-token size: 128 KB (131,072 bytes, or 0.125 MB)

**Transfer Cost:**
- Linear with sequence length: double the tokens, double the transfer
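- TCP overhead grows to 7.2% of decode time at 2,048 tokens and 14.5% at 4,096, driven by the extra copies through CPU memory and kernel buffers
- RDMA stays at or below 1.4% across the same range because GPUDirect keeps the CPU and kernel off the data path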

**Why This Matters for Notebook 03:**
- vLLM's `NixlConnector` uses RDMA (NIXL) for KV cache transfer between nodes
- The transfer overhead we calculated here is what NixlConnector minimizes
- Under these estimates, RDMA overhead stays under 1% at typical sequence lengths (512-2048 tokens): 0.2% to 0.7% for 100 output tokens

**What's Next:**
- [03_Disaggregated_Serving.ipynb](03_Disaggregated_Serving.ipynb): Run prefill on spark-01, decode on spark-02, and measure the actual overhead