# Understanding KV Cache

The KV cache is what makes disaggregated serving possible—and expensive. Before we split inference across nodes, we need to understand what this cache contains, how big it is, and why transferring it matters.

## What is KV Cache?

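In decoder-only transformer models, each token attends to all previous tokens through their attention keys (K) and values (V). Once computed, a token's K and V never change, so the model caches them per layer. Without the cache, every decode step would recompute K and V for the entire prefix, so the total projection work grows quadratically with sequence length instead of linearly.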
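
To make the saving concrete, here is a minimal pure-Python sketch of single-head attention. It assumes toy random weights and 4-dimensional token vectors, and every name in it is illustrative rather than taken from any library. It decodes the same sequence twice: once recomputing K/V for the whole prefix at every step, and once appending to a cache. It also counts the K/V projections each approach performs.

```python
import math
import random


def project(x, w):
    """Matrix-vector product; w is a list of rows."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def attend(q, keys, values):
    """Scaled dot-product attention of one query over all cached positions."""
    scale = math.sqrt(len(q))
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys]
    top = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def decode_without_cache(tokens, wq, wk, wv):
    """Recompute K and V for the entire prefix at every step."""
    outputs, kv_projections = [], 0
    for t in range(1, len(tokens) + 1):
        prefix = tokens[:t]
        keys = [project(x, wk) for x in prefix]
        values = [project(x, wv) for x in prefix]
        kv_projections += 2 * t
        outputs.append(attend(project(prefix[-1], wq), keys, values))
    return outputs, kv_projections


def decode_with_cache(tokens, wq, wk, wv):
    """Compute K and V once per token and append them to the cache."""
    k_cache, v_cache, outputs, kv_projections = [], [], [], 0
    for x in tokens:
        k_cache.append(project(x, wk))
        v_cache.append(project(x, wv))
        kv_projections += 2
        outputs.append(attend(project(x, wq), k_cache, v_cache))
    return outputs, kv_projections


def random_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
dim, num_tokens = 4, 8
wq, wk, wv = random_matrix(dim, dim), random_matrix(dim, dim), random_matrix(dim, dim)
tokens = random_matrix(num_tokens, dim)

slow, slow_ops = decode_without_cache(tokens, wq, wk, wv)
fast, fast_ops = decode_with_cache(tokens, wq, wk, wv)
print(slow_ops, fast_ops)  # 72 16
```

Both loops produce the same attention outputs. Only the amount of K/V work differs: 2 × (1 + 2 + … + 8) = 72 projections versus 2 × 8 = 16. The lists `k_cache` and `v_cache` play the role a real KV cache plays per layer and per head.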

**Systems analogy**: KV cache is like a session cache in web servers:
- Prefill phase = initial request processing, populate cache
- Decode phase = subsequent requests, read from cache
- Transfer = moving cache between servers for load balancing

## Why This Matters for Disaggregation

When we split prefill (node 1) and decode (node 2):
1. Node 1 processes prompt → generates KV cache
2. Transfer KV cache from node 1 to node 2
3. Node 2 uses cache to generate tokens

The transfer sits on the critical path before decode can start, so its speed largely determines whether disaggregation is worth it.
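
The hand-off in step 2 needs a wire format. The sketch below assumes a hypothetical layout: a 16-byte header holding layer count, KV heads, sequence length, and head dimension, followed by fp16 values, with keys before values for each layer. It packs the cache with the standard library's `struct` module. Real systems use their own formats and avoid per-element Python loops. The point is only that the payload is exactly 2 × layers × heads × tokens × head_dim elements.

```python
import struct

# Hypothetical header: num_layers, num_kv_heads, seq_len, head_dim (little-endian uint32)
HEADER = struct.Struct("<4I")


def serialize_kv(cache):
    """Prefill side: pack a list of per-layer (keys, values) into bytes.

    keys and values are nested lists shaped [num_kv_heads][seq_len][head_dim].
    """
    num_layers = len(cache)
    num_heads = len(cache[0][0])
    seq_len = len(cache[0][0][0])
    head_dim = len(cache[0][0][0][0])
    flat = [x for keys, values in cache for tensor in (keys, values)
            for head in tensor for row in head for x in row]
    # "e" is IEEE 754 half precision, matching float16 KV tensors
    return (HEADER.pack(num_layers, num_heads, seq_len, head_dim)
            + struct.pack(f"<{len(flat)}e", *flat))


def deserialize_kv(blob):
    """Decode side: rebuild the per-layer (keys, values) lists from bytes."""
    num_layers, num_heads, seq_len, head_dim = HEADER.unpack_from(blob, 0)
    count = 2 * num_layers * num_heads * seq_len * head_dim
    values_iter = iter(struct.unpack_from(f"<{count}e", blob, HEADER.size))

    def take_tensor():
        return [[[next(values_iter) for _ in range(head_dim)]
                 for _ in range(seq_len)]
                for _ in range(num_heads)]

    cache = []
    for _ in range(num_layers):
        keys = take_tensor()
        values = take_tensor()
        cache.append((keys, values))
    return cache


# Toy cache: 2 layers, 2 KV heads, 3 tokens, head_dim 4.
# Multiples of 0.25 are exactly representable in fp16, so the round trip is lossless.
cache = [
    tuple([[[0.25 * ((layer + kv + h + s + d) % 8) for d in range(4)]
            for s in range(3)] for h in range(2)] for kv in range(2))
    for layer in range(2)
]
blob = serialize_kv(cache)
print(len(blob))  # 16-byte header + 96 elements * 2 bytes = 208
assert deserialize_kv(blob) == cache
```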

## Step 1: Load Model and Configuration

In [None]:
import torch
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np

# Load environment config
config_file = Path("environment_config.json")
if config_file.exists():
    with open(config_file) as f:
        env_config = json.load(f)
    MODEL_NAME = env_config['model']['name']
else:
    MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

print(f"Model: {MODEL_NAME}")
print("Loading model for KV cache inspection...\n")

# Load tokenizer and model
# We use HuggingFace Transformers directly (not vLLM) to access internal state
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto"
)

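print("✓ Model loaded")
print(f"  Parameters: {model.num_parameters() / 1e9:.2f}B")
print(f"  Layers: {model.config.num_hidden_layers}")
print(f"  Hidden size: {model.config.hidden_size}")
print(f"  Attention heads: {model.config.num_attention_heads}")
# The KV cache stores one entry per K/V head; grouped-query attention (GQA)
# models such as TinyLlama have fewer K/V heads than query heads
num_kv_heads = getattr(model.config, "num_key_value_heads", None) or model.config.num_attention_heads
print(f"  KV heads: {num_kv_heads}")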
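
You can predict the cache size from the config before touching any tensors. Each layer stores one key and one value vector per K/V head per token. The minimal sketch below assumes fp16 (2 bytes per element) and `head_dim = hidden_size / num_attention_heads`, which holds for Llama-family models. The TinyLlama-1.1B values (22 layers, 32 attention heads, 4 KV heads, hidden size 2048) come from its published config. With the model loaded, you could pass the values from `model.config` instead.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_element=2):
    """Bytes of KV cache added per token: 2 tensors (K and V) per layer."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_element


# TinyLlama-1.1B: hidden_size 2048 / 32 attention heads = head_dim 64, 4 KV heads (GQA)
tinyllama = kv_bytes_per_token(num_layers=22, num_kv_heads=4, head_dim=2048 // 32)
# A 7B-class model without GQA (32 layers, 32 KV heads, head_dim 128), e.g. Llama-2-7B
llama2_7b = kv_bytes_per_token(num_layers=32, num_kv_heads=32, head_dim=128)

print(f"TinyLlama-1.1B: {tinyllama} bytes/token ({tinyllama / 1e6:.4f} MB)")  # 22528 (0.0225 MB)
print(f"7B without GQA: {llama2_7b} bytes/token ({llama2_7b / 1e6:.4f} MB)")  # 524288 (0.5243 MB)
```

If the loaded model is TinyLlama in float16, the per-token figure that Step 3 measures from the actual tensors should match the first line.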

## Step 2: Generate with KV Cache Capture

Run inference while capturing the KV cache state. This shows what actually gets stored and transferred.

In [None]:
def generate_with_kv_cache_analysis(prompt, max_new_tokens=50):
    """
    Generate text and capture KV cache for analysis.
    Returns: generated text, KV cache, and metrics
    """
    # Tokenize input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_length = inputs['input_ids'].shape[1]
    
    print(f"Prompt: '{prompt}'")
    print(f"Input tokens: {input_length}\n")
    
    # Generate with cache
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # Greedy decoding
            return_dict_in_generate=True,
            output_attentions=False,
            use_cache=True,  # Enable KV caching
            pad_token_id=tokenizer.eos_token_id
        )
    
    # Decode output
    generated_text = tokenizer.decode(
        outputs.sequences[0][input_length:],
        skip_special_tokens=True
    )
    output_length = outputs.sequences.shape[1] - input_length
    
    print(f"Output tokens: {output_length}")
    print(f"Generated: '{generated_text}'\n")
    
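    # Analyze cache structure.
    # Older transformers versions return a tuple of (key, value) pairs per layer;
    # newer versions return a Cache object (e.g. DynamicCache) instead.
    # The cache from generate() covers the prompt plus every generated token
    # except the last one, whose K/V is never computed.
    if getattr(outputs, 'past_key_values', None) is not None:
        past_kv = outputs.past_key_values
    else:
        # Fallback: one forward pass over the prompt (prefill-only cache)
        with torch.no_grad():
            past_kv = model(**inputs, use_cache=True).past_key_values
    
    # Normalize Cache objects to the legacy tuple layout when the method exists
    if hasattr(past_kv, 'to_legacy_cache'):
        past_kv = past_kv.to_legacy_cache()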
    
    return {
        'text': generated_text,
        'input_tokens': input_length,
        'output_tokens': output_length,
        'kv_cache': past_kv
    }

# Test with a sample prompt
test_prompt = "Explain load balancing in distributed systems."
result = generate_with_kv_cache_analysis(test_prompt, max_new_tokens=30)

## Step 3: Analyze KV Cache Structure

Examine the shape and size of the KV cache. This is what moves between nodes in disaggregated serving.

In [None]:
def analyze_kv_cache(kv_cache):
    """
    Analyze KV cache structure and memory usage.
    
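    KV cache structure (legacy tuple layout):
    - Tuple of length = num_layers
    - Each element is (key_tensor, value_tensor)
    - Shape: [batch_size, num_kv_heads, sequence_length, head_dim]
    - num_kv_heads < num_attention_heads for grouped-query attention (GQA) models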
    """
    if not kv_cache:
        print("No KV cache available")
        return None
    
    num_layers = len(kv_cache)
    
    # Get first layer's key and value tensors
    first_key, first_value = kv_cache[0]
    
    # Extract dimensions
    batch_size = first_key.shape[0]
    num_heads = first_key.shape[1]  # K/V heads, not query heads (differs under GQA)
    seq_length = first_key.shape[2]
    head_dim = first_key.shape[3]
    
    # Calculate memory per layer
    # Each layer has 2 tensors (key + value)
    bytes_per_element = first_key.element_size()  # Usually 2 bytes for float16
    elements_per_tensor = batch_size * num_heads * seq_length * head_dim
    bytes_per_layer = 2 * elements_per_tensor * bytes_per_element  # 2 for key+value
    total_bytes = bytes_per_layer * num_layers
    
    print("KV Cache Structure:")
    print("="*60)
    print(f"Layers: {num_layers}")
    print(f"Shape per tensor: [{batch_size}, {num_heads}, {seq_length}, {head_dim}]")
    print(f"  Batch size: {batch_size}")
    print(f"  KV heads: {num_heads}")
    print(f"  Sequence length: {seq_length}")
    print(f"  Head dimension: {head_dim}")
    print(f"\nMemory Usage:")
    print(f"  Bytes per element: {bytes_per_element} ({first_key.dtype})")
    print(f"  Memory per layer: {bytes_per_layer / 1e6:.2f} MB")
    print(f"  Total KV cache: {total_bytes / 1e6:.2f} MB")
    print(f"  Per-token memory: {total_bytes / seq_length / 1e6:.4f} MB/token")
    
    return {
        'num_layers': num_layers,
        'batch_size': batch_size,
        'num_heads': num_heads,
        'seq_length': seq_length,
        'head_dim': head_dim,
        'total_mb': total_bytes / 1e6,
        'per_token_mb': total_bytes / seq_length / 1e6
    }

kv_analysis = analyze_kv_cache(result['kv_cache'])
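
The head count in the cache shape can be smaller than the attention-head count printed in Step 1. Under grouped-query attention (GQA), several query heads share one K/V head, so only the K/V heads are cached. Here is a minimal sketch of that mapping. It assumes query heads are grouped contiguously, which is how Llama-style implementations repeat K/V heads. The function names are illustrative.

```python
def kv_head_for_query_head(query_head, num_attention_heads, num_kv_heads):
    """Index of the cached K/V head that a query head reads under GQA."""
    if num_attention_heads % num_kv_heads != 0:
        raise ValueError("num_attention_heads must be a multiple of num_kv_heads")
    group_size = num_attention_heads // num_kv_heads
    return query_head // group_size


def gqa_cache_savings(num_attention_heads, num_kv_heads):
    """How many times smaller the cache is than with one K/V head per query head."""
    return num_attention_heads / num_kv_heads


# TinyLlama-1.1B config values: 32 query heads share 4 K/V heads
mapping = [kv_head_for_query_head(q, 32, 4) for q in range(32)]
print(mapping[:10])              # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
print(gqa_cache_savings(32, 4))  # 8.0
```

If `analyze_kv_cache` reports 4 heads for TinyLlama while Step 1 reports 32 attention heads, this 8x sharing is the reason.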

## Step 4: KV Cache Growth with Sequence Length

The cache grows linearly with sequence length. Longer conversations = more data to transfer.

In [None]:
# Test different sequence lengths
test_lengths = [10, 50, 100, 200, 500]

print("KV Cache Size vs Sequence Length:\n")
print(f"{'Tokens':<10} {'KV Cache Size':<15} {'Transfer @ 10Gbps':<20} {'Transfer @ 100Gbps':<20}")
print("-" * 70)

# Use the per-token memory from analysis
per_token_mb = kv_analysis['per_token_mb'] if kv_analysis else 0.1

for num_tokens in test_lengths:
    cache_mb = per_token_mb * num_tokens
    cache_bytes = cache_mb * 1e6
    
    # Transfer time calculations
    # 10 Gbps = 1.25 GB/s theoretical, ~1 GB/s practical
    # 100 Gbps = 12.5 GB/s theoretical, ~10 GB/s practical
    transfer_10gbps_ms = (cache_bytes / (1 * 1e9)) * 1000
    transfer_100gbps_ms = (cache_bytes / (10 * 1e9)) * 1000
    
    print(f"{num_tokens:<10} {cache_mb:>10.2f} MB    {transfer_10gbps_ms:>12.2f} ms      {transfer_100gbps_ms:>12.2f} ms")

print("\nKey Insights:")
print("  • KV cache scales linearly with sequence length")
print(f"  • At 100 tokens: ~{per_token_mb * 100:.1f} MB to transfer")
print(f"  • At 500 tokens: ~{per_token_mb * 500:.1f} MB to transfer")
print("  • Network speed directly impacts disaggregation overhead")
print("  • A 100 Gbps link (common for RDMA fabrics) cuts transfer time 10x vs 10 Gbps")
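
The table above treats transfer time as size divided by bandwidth. A slightly more realistic sketch adds a fixed per-transfer latency, the classic latency-plus-bandwidth (alpha-beta) model. Two values below are assumptions for illustration, not measurements: the 80% link efficiency (which reproduces this notebook's ~1 GB/s at 10 Gbps) and the 20 µs latency.

```python
def transfer_time_ms(size_bytes, link_gbps, latency_us=0.0, efficiency=0.8):
    """Latency-plus-bandwidth (alpha-beta) estimate of transfer time in ms.

    latency_us: fixed per-transfer cost (alpha), in microseconds.
    efficiency: fraction of the nominal link rate actually achieved (assumed).
    """
    bytes_per_second = link_gbps * 1e9 / 8 * efficiency
    return latency_us / 1000 + size_bytes / bytes_per_second * 1000


# Same assumed 20 µs latency on both links, to isolate the bandwidth effect
for size_mb in (0.1, 2.0, 50.0):
    size = size_mb * 1e6
    slow_link = transfer_time_ms(size, link_gbps=10, latency_us=20)
    fast_link = transfer_time_ms(size, link_gbps=100, latency_us=20)
    print(f"{size_mb:>5.1f} MB   10 Gbps: {slow_link:7.3f} ms   "
          f"100 Gbps: {fast_link:7.3f} ms   speedup: {slow_link / fast_link:4.1f}x")
```

Under these assumptions, the speedup is about 4x for a 0.1 MB cache and close to 10x for 50 MB. Small transfers are latency-bound; only large ones see the full bandwidth ratio.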

## Step 5: Transfer Cost Analysis

Calculate the overhead of transferring KV cache between nodes. This determines if disaggregation is worthwhile.

In [None]:
# Load baseline metrics from previous notebook
baseline_file = Path("baseline_metrics.json")
if baseline_file.exists():
    with open(baseline_file) as f:
        baseline = json.load(f)
    per_token_latency_ms = baseline['single_request']['latency_ms'] / baseline['single_request']['tokens']
else:
    per_token_latency_ms = 20  # Fallback estimate

print("Transfer Cost vs Generation Cost\n")
print("Scenario: transfer a 100-token KV cache, compare with decoding 100 tokens\n")

# Calculate costs
sequence_length = 100
kv_cache_mb = per_token_mb * sequence_length
kv_cache_bytes = kv_cache_mb * 1e6

# Network transfer times, assuming effective rates of ~1 GB/s on a 10 Gbps
# link (labeled TCP) and ~10 GB/s on a 100 Gbps link (labeled RDMA)
transfer_tcp_ms = (kv_cache_bytes / (1 * 1e9)) * 1000
transfer_rdma_ms = (kv_cache_bytes / (10 * 1e9)) * 1000

# Decode time for 100 new tokens at the baseline per-token latency
generation_time_ms = per_token_latency_ms * sequence_length

# Calculate overhead percentages
overhead_tcp = (transfer_tcp_ms / generation_time_ms) * 100
overhead_rdma = (transfer_rdma_ms / generation_time_ms) * 100

print(f"KV Cache Transfer:")
print(f"  Size: {kv_cache_mb:.2f} MB")
print(f"  Via TCP (10 Gbps): {transfer_tcp_ms:.2f} ms")
print(f"  Via RDMA (100 Gbps): {transfer_rdma_ms:.2f} ms")
print(f"\nToken Generation:")
print(f"  Time for {sequence_length} tokens: {generation_time_ms:.2f} ms")
print(f"  Per-token latency: {per_token_latency_ms:.2f} ms")
print(f"\nOverhead:")
print(f"  With TCP: {overhead_tcp:.1f}% overhead")
print(f"  With RDMA: {overhead_rdma:.1f}% overhead")

print("\n" + "="*60)
print("VERDICT")
print("="*60)
if overhead_tcp > 50:
    print("✗ TCP too slow - transfer time dominates")
    print("  Disaggregation not viable without RDMA")
elif overhead_tcp > 20:
    print("⚠ TCP marginal - significant overhead")
    print("  RDMA strongly recommended")
else:
    print("✓ TCP acceptable - low overhead")
    print("  RDMA improves but not critical")

if overhead_rdma < 10:
    print(f"\n✓ RDMA overhead: {overhead_rdma:.1f}% - disaggregation viable")
else:
    print(f"\n⚠ RDMA overhead: {overhead_rdma:.1f}% - carefully evaluate benefit")
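
The verdict can also be inverted: for a given overhead budget, how long can the prompt be before its transfer exceeds that budget? The sketch below uses the same assumptions as Step 5: overhead is transfer time divided by decode time, and bandwidth is a fixed effective rate. The 10% budget mirrors the RDMA threshold above. The per-token sizes are the fp16 config-derived figures for TinyLlama and a 7B model without GQA. The decode time of 100 tokens at 20 ms each is hypothetical.

```python
def max_prompt_tokens(per_token_bytes, bandwidth_bytes_per_s, decode_time_ms, overhead_budget):
    """Longest prompt whose KV transfer stays within overhead_budget of decode time."""
    allowed_transfer_s = overhead_budget * decode_time_ms / 1000
    return int(allowed_transfer_s * bandwidth_bytes_per_s // per_token_bytes)


# Hypothetical inputs: 100 decode tokens at 20 ms each, 10% overhead budget
decode_ms = 100 * 20
for label, per_token in (("TinyLlama fp16", 22528), ("7B without GQA", 524288)):
    slow_link = max_prompt_tokens(per_token, 1e9, decode_ms, 0.10)
    fast_link = max_prompt_tokens(per_token, 10e9, decode_ms, 0.10)
    print(f"{label:<16} max prompt @ ~1 GB/s: {slow_link:>7} tokens   "
          f"@ ~10 GB/s: {fast_link:>7} tokens")
```

A 10x faster link raises the break-even prompt length 10x. This is the same argument the verdict makes, viewed from the prompt-length side.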

## Step 6: Real-World Transfer Test

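Time the local memory copies that make up part of a real transfer: GPU→CPU on the prefill node, CPU→GPU on the decode node, and a CPU copy as a rough stand-in for the network hop. These measure the local legs of the path between prefill and decode nodes; the network itself is not measured here.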

In [None]:
import time

def _sync_cuda():
    """Wait for queued CUDA work so timings cover the whole copy."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()

def simulate_kv_transfer(kv_cache, method='cpu_to_cpu'):
    """
    Simulate KV cache transfer.
    Methods:
    - 'cpu_to_cpu': CPU memory copy (stand-in for the network hop)
    - 'gpu_to_cpu': GPU to CPU copy (simulates prefill node sending)
    - 'cpu_to_gpu': CPU to GPU copy (simulates decode node receiving)
    """
    if not kv_cache:
        return None
    if method != 'cpu_to_cpu' and not torch.cuda.is_available():
        print(f"Skipping {method}: no CUDA device available")
        return None
    
    # Gather every layer's key and value tensors
    # In real disaggregation, these would be serialized for the network
    tensors = [kv[0] for kv in kv_cache] + [kv[1] for kv in kv_cache]
    
    if method == 'gpu_to_cpu':
        # Make sure the source tensors live on the GPU before timing
        src = [t.cuda() for t in tensors]
        _sync_cuda()
        print("Simulating GPU→CPU transfer (prefill node sending)...")
        start = time.perf_counter()
        copies = [t.cpu() for t in src]
        _sync_cuda()
        elapsed = time.perf_counter() - start
        
    elif method == 'cpu_to_gpu':
        # Stage on the CPU first (not timed)
        src = [t.cpu() for t in tensors]
        _sync_cuda()
        print("Simulating CPU→GPU transfer (decode node receiving)...")
        start = time.perf_counter()
        copies = [t.cuda() for t in src]
        _sync_cuda()
        elapsed = time.perf_counter() - start
        
    else:  # cpu_to_cpu
        # Stage on the CPU first (not timed)
        src = [t.cpu() for t in tensors]
        _sync_cuda()
        print("Simulating CPU↔CPU copy (network transfer analog)...")
        start = time.perf_counter()
        copies = [t.clone() for t in src]
        elapsed = time.perf_counter() - start
    
    # Calculate total bytes transferred
    total_bytes = sum(t.nelement() * t.element_size() for t in src)
    
    # Guard against a zero timer reading on tiny caches
    elapsed = max(elapsed, 1e-9)
    bandwidth_gbps = (total_bytes / 1e9) / elapsed * 8  # GB/s × 8 = Gbps
    
    return {
        'method': method,
        'time_ms': elapsed * 1000,
        'size_mb': total_bytes / 1e6,
        'bandwidth_gbps': bandwidth_gbps
    }

# Test all transfer methods
print("Measuring KV Cache Transfer Performance\n")
print(f"{'Method':<20} {'Time':<12} {'Size':<12} {'Bandwidth':<15}")
print("-" * 65)

for method in ['gpu_to_cpu', 'cpu_to_gpu', 'cpu_to_cpu']:
    # Use a separate name so `result` (the generation output) is not overwritten
    transfer = simulate_kv_transfer(result['kv_cache'], method=method)
    if transfer:
        print(f"{transfer['method']:<20} {transfer['time_ms']:>8.2f} ms  {transfer['size_mb']:>8.2f} MB  {transfer['bandwidth_gbps']:>10.2f} Gbps")

print("\nNote:")
print("  These are PCIe/memory bandwidth measurements, not network measurements")
print("  Network transfer adds serialization + actual network latency")
print("  GPUDirect RDMA lets the NIC read/write GPU memory directly, skipping these host copies")
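
Single timings like those above are noisy. The first copy can include one-time setup costs, and the OS scheduler adds jitter. A common remedy is a warm-up run followed by several timed repetitions, reporting the median. Below is a minimal stdlib-only sketch of that pattern for a host-memory copy. The buffer size and repeat count are arbitrary choices, and the same structure would apply to the tensor copies above.

```python
import statistics
import time


def measure_copy_bandwidth(size_bytes, repeats=5):
    """Median host-memory copy bandwidth in Gbps, after one warm-up copy."""
    src = bytearray(size_bytes)
    dst = bytearray(size_bytes)
    dst[:] = src  # warm-up: touch both buffers so first-use costs are not timed
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        dst[:] = src  # slice assignment copies the buffer in place
        timings.append(time.perf_counter() - start)
    median_s = max(statistics.median(timings), 1e-9)  # avoid division by zero
    return size_bytes * 8 / median_s / 1e9


# 2,252,800 bytes: TinyLlama's fp16 KV cache for 100 tokens (22,528 bytes/token)
print(f"{measure_copy_bandwidth(2_252_800):.1f} Gbps")
```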

## Step 7: KV Cache Compression Opportunities

Can we compress the cache before transfer to reduce overhead?

In [None]:
import time
import zlib

def test_kv_compression(kv_cache):
    """
    Test compression ratio on KV cache.
    Real-world: compression adds CPU overhead.
    """
    if not kv_cache:
        return None
    
    # Take first layer as representative sample
    key_tensor, value_tensor = kv_cache[0]
    
    # Convert to bytes
    key_bytes = key_tensor.cpu().numpy().tobytes()
    value_bytes = value_tensor.cpu().numpy().tobytes()
    
    original_size = len(key_bytes) + len(value_bytes)
    
    # Test compression
    start = time.time()
    compressed_key = zlib.compress(key_bytes, level=1)  # Fast compression
    compressed_value = zlib.compress(value_bytes, level=1)
    compression_time = time.time() - start
    
    compressed_size = len(compressed_key) + len(compressed_value)
    compression_ratio = original_size / compressed_size
    
    # Decompression
    start = time.time()
    _ = zlib.decompress(compressed_key)
    _ = zlib.decompress(compressed_value)
    decompression_time = time.time() - start
    
    print("KV Cache Compression Analysis (single layer)\n")
    print(f"Original size: {original_size / 1e6:.2f} MB")
    print(f"Compressed size: {compressed_size / 1e6:.2f} MB")
    print(f"Compression ratio: {compression_ratio:.2f}x")
    print(f"Compression time: {compression_time * 1000:.2f} ms")
    print(f"Decompression time: {decompression_time * 1000:.2f} ms")
    print(f"\nTotal overhead: {(compression_time + decompression_time) * 1000:.2f} ms")
    
    # Calculate if compression helps
    saved_bytes = original_size - compressed_size
    saved_transfer_time_10g = (saved_bytes / (1 * 1e9)) * 1000
    saved_transfer_time_100g = (saved_bytes / (10 * 1e9)) * 1000
    total_overhead = (compression_time + decompression_time) * 1000
    
    print(f"\nNet benefit @ 10 Gbps: {saved_transfer_time_10g - total_overhead:.2f} ms")
    print(f"Net benefit @ 100 Gbps: {saved_transfer_time_100g - total_overhead:.2f} ms")
    
    if saved_transfer_time_100g > total_overhead:
        print("\n✓ Compression beneficial even with RDMA")
    elif saved_transfer_time_10g > total_overhead:
        print("\n⚠ Compression only beneficial with slower networks")
    else:
        print("\n✗ Compression overhead exceeds transfer savings")

test_kv_compression(result['kv_cache'])
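
A general-purpose compressor sees fp16 KV data as an interleaved stream of two kinds of bytes: high bytes (sign, exponent, top mantissa bits) and low bytes (mostly noise-like mantissa bits). This is one reason ratios tend to be modest. One known trick, the byte-shuffle filter used by HDF5 and Blosc, groups all high bytes together and all low bytes together before compressing. Below is a minimal stdlib sketch of the idea. It assumes little-endian fp16 input and uses synthetic Gaussian data as a stand-in for a real KV tensor. Whether it helps on actual caches is something to measure, not assume.

```python
import random
import struct
import zlib


def shuffle_fp16(raw):
    """Split little-endian fp16 bytes into a low-byte plane and a high-byte plane."""
    return raw[0::2] + raw[1::2]


def unshuffle_fp16(shuffled):
    """Inverse of shuffle_fp16: re-interleave the two byte planes."""
    half = len(shuffled) // 2
    out = bytearray(len(shuffled))
    out[0::2] = shuffled[:half]
    out[1::2] = shuffled[half:]
    return bytes(out)


# Synthetic stand-in for a KV tensor: Gaussian values stored as fp16
random.seed(0)
values = [random.gauss(0.0, 1.0) for _ in range(50_000)]
raw = struct.pack(f"<{len(values)}e", *values)

plain = zlib.compress(raw, 1)
shuffled = zlib.compress(shuffle_fp16(raw), 1)
print(f"plain: {len(raw) / len(plain):.2f}x  shuffled: {len(raw) / len(shuffled):.2f}x")
assert unshuffle_fp16(zlib.decompress(shuffled)) == raw
```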

## Key Takeaways

**What is KV Cache:**
- Cached attention keys and values from transformer layers
- Stored per layer as: [batch, num_kv_heads, sequence_length, head_dim]
- Size scales linearly with sequence length
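- Per-token size = 2 (K and V) × layers × KV heads × head_dim × bytes per element: about 0.02 MB/token for TinyLlama-1.1B in fp16 (4 KV heads via GQA), about 0.5 MB/token for a 7B model without GQA such as Llama-2-7B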

**Transfer Costs:**
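- Transfer time ≈ cache size / effective bandwidth (this notebook assumes ~1 GB/s on 10 Gbps and ~10 GB/s on 100 Gbps)
- At 100 tokens: ~2.3 MB for TinyLlama (~2.3 ms at 10 Gbps, ~0.23 ms at 100 Gbps) vs ~50 MB for a 7B model without GQA (~50 ms and ~5 ms)
- Overhead relative to generation time grows with prompt length and model size; on a 10 Gbps TCP link it can become a significant share for long prompts on large models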

**Why RDMA Matters:**
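- 10x more bandwidth in this notebook's comparison (100 Gbps vs 10 Gbps)
- With GPUDirect RDMA, the NIC reads and writes GPU memory directly, skipping host copies (GPU→Network→GPU)
- Cuts transfer time 10x for a given cache size (e.g., ~20 ms to ~2 ms for a 20 MB cache)
- Keeps disaggregation overhead low in the cases where the Step 5 verdict flags TCP as marginal or too slow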

**What's Next:**
- [03_Basic_Disaggregation.ipynb](03_Basic_Disaggregation.ipynb) - Split prefill and decode across nodes using standard networking