# Part 8.2: LLM Inference Optimization — The Formula 1 Edition

Training a model is expensive, but **inference** is where the real cost lives. Every user query, every API call, every token generated costs compute. At scale, a model that's 2x faster or uses half the memory saves millions of dollars.

**F1 analogy:** Training is like the off-season: months in the wind tunnel and simulation farm, building the best car you can. But inference is race day, and race day never stops. Every lap, every corner, every radio message needs to be processed in real time. The pit wall can't wait 30 seconds for a strategy recommendation while the driver is screaming "What do I do?!" into the radio. LLM inference optimization is the engineering that takes a brilliant but slow simulation model and makes it fast enough for the live pit wall. Quantization is like reducing telemetry precision from float32 to int8 for faster pit wall processing. KV caching is like remembering computations from already-processed laps. The goal: real-time strategy calls for each car (low latency), for all 20 cars at once (high throughput), without sacrificing the quality of the analysis.

LLM inference has unique challenges: models are enormous (7B-400B+ parameters), generation is sequential (each token depends on all the tokens before it), and users expect low latency. This notebook covers the key techniques that make production LLM serving feasible.

## Learning Objectives

- [ ] Understand the inference bottleneck: why LLM generation is slow
- [ ] Implement model quantization (INT8, INT4) from scratch
- [ ] Build a KV cache to eliminate redundant computation
- [ ] Understand speculative decoding for faster generation
- [ ] Implement continuous batching for throughput optimization
- [ ] Build a model distillation pipeline (teacher → student)
- [ ] Compare optimization techniques on speed, memory, and quality

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from collections import defaultdict
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

np.random.seed(42)
torch.manual_seed(42)

print("Part 8.2: LLM Inference Optimization — The Formula 1 Edition")
print("=" * 60)

---

## 1. The Inference Bottleneck

LLM inference has two phases:

1. **Prefill**: Process the entire prompt in parallel (compute-bound)
2. **Decode**: Generate tokens one at a time, each depending on all previous tokens (memory-bound)

The decode phase is the bottleneck because:
- Each decode step must read **all of the model's weights** from GPU memory to produce just one token per sequence
- Generation is inherently **sequential** (can't parallelize across tokens)
- Most time is spent moving weights from memory to compute, not doing math

| Phase | Compute Pattern | Bottleneck | Tokens/sec | F1 Parallel |
|-------|----------------|------------|------------|-------------|
| Prefill | Matrix multiply (batched) | Compute | High (thousands) | Loading the full race history into the strategy model at the start of a stint — slow once, but done in bulk |
| Decode | Matrix-vector product | Memory bandwidth | Low (tens) | Real-time strategy calls — each decision requires reloading the entire model for one new data point |

**F1 analogy:** The prefill phase is like the pit wall loading 20 laps of telemetry history at once to initialize the strategy model before a stint: one big batch operation. The decode phase is like generating second-by-second strategy recommendations during the race: each new prediction depends on all previous ones, and the entire model has to be read from memory for each tick. The bottleneck is bandwidth, not compute, just like an F1 data link where the radio channel bandwidth limits how fast the pit wall can receive new telemetry, not how fast the computers can process it.
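To make the memory-bound argument concrete, here is a rough back-of-the-envelope sketch. It assumes that every decode step reads all weights once and that nothing else matters, so it gives a lower bound, not a prediction. The function name and the 1 TB/s bandwidth figure are hypothetical and chosen only for illustration.

```python
def decode_time_per_token_ms(n_params, bytes_per_param, mem_bandwidth_bytes_per_s):
    """Lower bound on per-token decode latency when weight reads dominate."""
    model_bytes = n_params * bytes_per_param
    return model_bytes / mem_bandwidth_bytes_per_s * 1000.0


# Assumption: a hypothetical accelerator with 1 TB/s of memory bandwidth.
ASSUMED_BANDWIDTH = 1e12

for name, n_params in [("7B", 7e9), ("13B", 13e9), ("70B", 70e9)]:
    ms = decode_time_per_token_ms(n_params, bytes_per_param=2,
                                  mem_bandwidth_bytes_per_s=ASSUMED_BANDWIDTH)
    print(f"{name} (FP16): >= {ms:.0f} ms/token, <= {1000.0 / ms:.0f} tokens/s")
```

Under these assumptions, halving the bytes per parameter halves the bound, which is why the quantization section that follows matters so much for decode speed.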

In [None]:
# Visualize the inference pipeline
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Prefill vs Decode phases
ax = axes[0]
ax.set_xlim(0, 14)
ax.set_ylim(0, 6)
ax.axis('off')
ax.set_title('LLM Inference Phases', fontsize=13, fontweight='bold')

# Prefill
box = mpatches.FancyBboxPatch((0.5, 3), 5, 2, boxstyle="round,pad=0.2",
                               facecolor='#3498db', edgecolor='black', linewidth=2)
ax.add_patch(box)
ax.text(3, 4.3, 'PREFILL', ha='center', fontsize=12, fontweight='bold', color='white')
ax.text(3, 3.6, 'Process entire prompt\nin parallel (fast)', ha='center', fontsize=9, color='white')

# Decode tokens
for i in range(5):
    x = 7 + i * 1.3
    box = mpatches.FancyBboxPatch((x, 3), 1, 2, boxstyle="round,pad=0.1",
                                   facecolor='#e74c3c', edgecolor='black', linewidth=1.5)
    ax.add_patch(box)
    ax.text(x + 0.5, 4, f'T{i+1}', ha='center', fontsize=10, fontweight='bold', color='white')

ax.text(10.2, 3.3, 'DECODE: one token at a time (slow)', ha='center', fontsize=9, color='gray')

ax.annotate('', xy=(6.8, 4), xytext=(5.5, 4),
           arrowprops=dict(arrowstyle='->', lw=2, color='gray'))

# Prompt tokens
for i in range(6):
    ax.text(0.8 + i * 0.8, 1.5, f'p{i}', ha='center', fontsize=9,
           bbox=dict(boxstyle='round', facecolor='#ecf0f1', edgecolor='gray'))
ax.text(3, 0.8, 'Input prompt tokens', ha='center', fontsize=9, color='gray')

# Time breakdown
ax = axes[1]
categories = ['7B Model', '13B Model', '70B Model']
prefill_times = [0.05, 0.1, 0.5]  # seconds for 512 token prompt
decode_times = [2.0, 3.8, 20.0]   # seconds for 256 output tokens

x = np.arange(len(categories))
w = 0.35
ax.bar(x - w/2, prefill_times, w, label='Prefill (512 tokens)', color='#3498db', edgecolor='black')
ax.bar(x + w/2, decode_times, w, label='Decode (256 tokens)', color='#e74c3c', edgecolor='black')

ax.set_xticks(x)
ax.set_xticklabels(categories)
ax.set_ylabel('Time (seconds)', fontsize=11)
ax.set_title('Prefill vs Decode Time', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

---

## 2. Quantization

**Quantization** reduces the precision of model weights from 32-bit floats to 8-bit or 4-bit integers. This cuts memory usage by 4-8x and speeds up inference (less data to move from memory).

### How It Works

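Map floating-point values to a smaller integer range using a scale (the step size) and an integer zero point (the integer that represents the real value 0):
$$x_\text{quant} = \text{clamp}\left(\text{round}\left(\frac{x}{\text{scale}}\right) + \text{zero\_point},\ q_\text{min},\ q_\text{max}\right)$$
$$x_\text{dequant} = \left(x_\text{quant} - \text{zero\_point}\right) \times \text{scale}$$
Symmetric (absmax) quantization is the special case $\text{zero\_point} = 0$.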

| Precision | Bits | Memory per 1B params | Typical Quality Loss | F1 Parallel |
|-----------|------|---------------------|--------------------|-------------|
| FP32 | 32 | 4 GB | Baseline | Full-precision telemetry — every sensor at maximum resolution |
| FP16/BF16 | 16 | 2 GB | Negligible | Halving sensor precision — still captures all meaningful variation |
| INT8 | 8 | 1 GB | <1% degradation | Rounding tire temps to nearest degree, throttle to nearest percent — negligible impact |
| INT4 (GPTQ/AWQ) | 4 | 0.5 GB | 1-3% degradation | Coarse-graining telemetry for quick pit wall displays — loses subtle detail but shows the big picture |

**F1 analogy:** Quantization is exactly what F1 teams do when transmitting telemetry from car to pit wall. The car's ECU records data at extreme precision (float32), but the radio link has limited bandwidth. So the team quantizes: tire temperature doesn't need 7 decimal places — rounding to the nearest degree (INT8) loses almost nothing. Throttle position can be sent as a percentage (0-100) instead of a 32-bit float. The pit wall gets 4x more data channels within the same bandwidth, with barely any loss in decision quality.
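A small worked example may help before the full `Quantizer` class. This is a minimal pure-Python sketch of asymmetric (zero-point) quantization using the convention `q = round(x / scale) + zero_point` and `x ≈ (q - zero_point) * scale`, which is the convention the `quantize_zeropoint` method below implements. The helper names and the sample values are hypothetical. Note that this scheme assumes the value range spans zero; otherwise the zero point gets clamped and values saturate.

```python
def quantize_list(values, n_bits=8):
    """Asymmetric quantization of a list of floats to unsigned integers."""
    qmin, qmax = 0, 2**n_bits - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / (qmax - qmin)                 # real-valued step size
    zero_point = max(qmin, min(qmax, round(-lo / scale)))  # integer that maps to 0.0
    q = [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point


def dequantize_list(q, scale, zero_point):
    """Map the integers back to approximate floats."""
    return [(qi - zero_point) * scale for qi in q]


# Hypothetical lateral-g samples; the range deliberately spans zero.
lateral_g = [-0.42, -0.10, 0.0, 0.25, 0.83]

q, scale, zero_point = quantize_list(lateral_g)
restored = dequantize_list(q, scale, zero_point)
max_err = max(abs(a - b) for a, b in zip(lateral_g, restored))

print("quantized:", q)
print(f"scale={scale:.5f}, zero_point={zero_point}, max error={max_err:.5f}")
```

The smallest value lands on integer 0, the largest on 255, and the real value 0.0 round-trips exactly because it maps to the zero point itself.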

In [None]:
class Quantizer:
    """Weight quantization from scratch."""
    
    @staticmethod
    def quantize_absmax(weights, n_bits=8):
        """Absmax (symmetric) quantization.
        
        Maps [-max, max] to [-(2^(n-1) - 1), 2^(n-1) - 1] (e.g. [-127, 127] for INT8)
        """
        qmax = 2**(n_bits - 1) - 1
        scale = weights.abs().max() / qmax
        quantized = torch.round(weights / scale).clamp(-qmax, qmax).to(torch.int8)
        return quantized, scale
    
    @staticmethod
    def dequantize_absmax(quantized, scale):
        """Dequantize absmax back to float."""
        return quantized.float() * scale
    
    @staticmethod
    def quantize_zeropoint(weights, n_bits=8):
        """Zero-point (asymmetric) quantization.
        
        Maps [min, max] to [0, 2^n - 1]
        """
        qmin, qmax = 0, 2**n_bits - 1
        w_min, w_max = weights.min(), weights.max()
        
        scale = (w_max - w_min) / (qmax - qmin)
        zero_point = torch.round(-w_min / scale).clamp(qmin, qmax)
        
        quantized = torch.round(weights / scale + zero_point).clamp(qmin, qmax).to(torch.uint8)
        return quantized, scale, zero_point
    
    @staticmethod
    def dequantize_zeropoint(quantized, scale, zero_point):
        """Dequantize zero-point back to float."""
        return (quantized.float() - zero_point) * scale
    
    @staticmethod
    def quantize_per_channel(weights, n_bits=8):
        """Per-channel quantization (one scale per output channel).
        
        Better quality than per-tensor because each channel has its own range.
        """
        qmax = 2**(n_bits - 1) - 1
        # Scale per output channel (row)
        scales = weights.abs().max(dim=1, keepdim=True).values / qmax
        scales = scales.clamp(min=1e-8)  # Avoid division by zero
        quantized = torch.round(weights / scales).clamp(-qmax, qmax).to(torch.int8)
        return quantized, scales.squeeze(1)  # squeeze(1) keeps a 1-D result even for a single row
    
    @staticmethod
    def dequantize_per_channel(quantized, scales):
        return quantized.float() * scales.unsqueeze(1)


# Create a sample weight matrix (simulating a linear layer)
torch.manual_seed(42)
weights = torch.randn(256, 512) * 0.02  # Typical weight scale

quantizer = Quantizer()

# Quantize with different methods
q8_absmax, scale_8 = quantizer.quantize_absmax(weights, n_bits=8)
dq8_absmax = quantizer.dequantize_absmax(q8_absmax, scale_8)

q4_absmax, scale_4 = quantizer.quantize_absmax(weights, n_bits=4)
dq4_absmax = quantizer.dequantize_absmax(q4_absmax, scale_4)

q8_pc, scales_pc = quantizer.quantize_per_channel(weights, n_bits=8)
dq8_pc = quantizer.dequantize_per_channel(q8_pc, scales_pc)

# Measure quantization error
def quantization_error(original, dequantized):
    mse = ((original - dequantized) ** 2).mean().item()
    max_err = (original - dequantized).abs().max().item()
    snr = 10 * math.log10(original.var().item() / mse) if mse > 0 else float('inf')
    return {'mse': mse, 'max_error': max_err, 'snr_db': snr}

print("Quantization Results\n")
print(f"Original weights: {weights.shape}, dtype={weights.dtype}")
print(f"Memory: {weights.numel() * 4 / 1024:.1f} KB (FP32)\n")

methods = [
    ('INT8 absmax', dq8_absmax, weights.numel() * 1),
    ('INT4 absmax', dq4_absmax, weights.numel() * 0.5),
    ('INT8 per-channel', dq8_pc, weights.numel() * 1),
]

print(f"{'Method':>20} {'MSE':>12} {'Max Error':>12} {'SNR (dB)':>10} {'Memory (KB)':>12} {'Compression':>12}")
print("-" * 80)
for name, dq, mem_bytes in methods:
    err = quantization_error(weights, dq)
    mem_kb = mem_bytes / 1024
    orig_kb = weights.numel() * 4 / 1024
    print(f"{name:>20} {err['mse']:>12.2e} {err['max_error']:>12.4f} {err['snr_db']:>10.1f} "
          f"{mem_kb:>12.1f} {orig_kb/mem_kb:>11.1f}x")

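The `quantize_per_channel` docstring above says per-channel scales give better quality than a single per-tensor scale. With the Gaussian test matrix used in this cell every row has a similar range, so the benefit is hard to see. The sketch below uses a hypothetical two-row matrix in which one row has much larger values; the helper names are illustrative and not part of the `Quantizer` class.

```python
def absmax_roundtrip(values, scale, qmax=127):
    """Quantize to signed integers with the given scale, then dequantize."""
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in values]


def max_abs_error(original, restored):
    return max(abs(a - b) for a, b in zip(original, restored))


# Hypothetical 2x4 weight matrix: row 0 is small, row 1 has large values.
rows = [[0.010, -0.020, 0.015, 0.005],
        [1.000, -0.500, 0.250, 0.750]]
qmax = 127

# Per-tensor: one scale, dominated by the largest value anywhere in the matrix.
tensor_scale = max(abs(v) for row in rows for v in row) / qmax
# Per-channel: one scale per row, so the small row keeps a fine step size.
row_scales = [max(abs(v) for v in row) / qmax for row in rows]

err_per_tensor = max_abs_error(rows[0], absmax_roundtrip(rows[0], tensor_scale))
err_per_channel = max_abs_error(rows[0], absmax_roundtrip(rows[0], row_scales[0]))

print(f"Row 0 max error, per-tensor scale:  {err_per_tensor:.6f}")
print(f"Row 0 max error, per-channel scale: {err_per_channel:.6f}")
```

The small row loses most of its resolution under the shared scale because the large row dictates the step size; with its own scale the error shrinks by well over an order of magnitude.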
In [None]:
# Visualize quantization effects
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Weight distribution: original vs quantized
ax = axes[0, 0]
ax.hist(weights.flatten().numpy(), bins=100, alpha=0.6, label='FP32 original',
       color='#3498db', edgecolor='black', density=True)
ax.hist(dq8_absmax.flatten().numpy(), bins=100, alpha=0.6, label='INT8 dequantized',
       color='#e74c3c', edgecolor='black', density=True)
ax.set_xlabel('Weight Value', fontsize=10)
ax.set_ylabel('Density', fontsize=10)
ax.set_title('Weight Distribution: FP32 vs INT8', fontsize=12, fontweight='bold')
ax.legend(fontsize=9)

# Quantization error distribution
ax = axes[0, 1]
error_8 = (weights - dq8_absmax).flatten().numpy()
error_4 = (weights - dq4_absmax).flatten().numpy()
ax.hist(error_8, bins=100, alpha=0.6, label='INT8 error', color='#2ecc71', density=True)
ax.hist(error_4, bins=100, alpha=0.6, label='INT4 error', color='#e74c3c', density=True)
ax.set_xlabel('Quantization Error', fontsize=10)
ax.set_ylabel('Density', fontsize=10)
ax.set_title('Quantization Error Distribution', fontsize=12, fontweight='bold')
ax.legend(fontsize=9)

# Scatter: original vs quantized values
ax = axes[1, 0]
sample = np.random.choice(weights.numel(), 2000, replace=False)
ax.scatter(weights.flatten()[sample].numpy(), dq8_absmax.flatten()[sample].numpy(),
          alpha=0.2, s=5, color='#3498db', label='INT8')
ax.scatter(weights.flatten()[sample].numpy(), dq4_absmax.flatten()[sample].numpy(),
          alpha=0.2, s=5, color='#e74c3c', label='INT4')
lim = 0.06
ax.plot([-lim, lim], [-lim, lim], 'k--', alpha=0.3)
ax.set_xlabel('Original (FP32)', fontsize=10)
ax.set_ylabel('Dequantized', fontsize=10)
ax.set_title('Original vs Dequantized Values', fontsize=12, fontweight='bold')
ax.legend(fontsize=9)

# Memory comparison
ax = axes[1, 1]
model_sizes_b = [7, 13, 70]  # Model sizes in billions of parameters
precisions = ['FP32', 'FP16', 'INT8', 'INT4']
multipliers = [4, 2, 1, 0.5]  # Bytes per param

x = np.arange(len(model_sizes_b))
w = 0.2
colors = ['#3498db', '#2ecc71', '#f39c12', '#e74c3c']

for i, (prec, mult, color) in enumerate(zip(precisions, multipliers, colors)):
    # Billions of params * bytes per param = GB
    mem = [s * mult for s in model_sizes_b]
    ax.bar(x + i * w, mem, w, label=prec, color=color, edgecolor='black', alpha=0.8)

ax.set_xticks(x + 1.5 * w)
ax.set_xticklabels([f'{s}B params' for s in model_sizes_b])
ax.set_ylabel('Memory (GB)', fontsize=10)
ax.set_title('Model Memory by Precision', fontsize=12, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

---

## 3. KV Cache

The **KV cache** is the single most important optimization for autoregressive generation. Without it, generating token $t$ requires recomputing attention for all $t-1$ previous tokens.

### The Problem
In self-attention, we compute:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

When generating token $t$, the keys and values for tokens $1$ through $t-1$ haven't changed — but without caching, we recompute them every time.

### The Solution
Cache the K and V matrices from previous tokens. When generating token $t$:
1. Compute Q, K, V for **only the new token**
2. Append new K, V to the cache
3. Attend to the full cached K, V

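This reduces the attention compute at step $t$ from $O(t^2)$ to $O(t)$, and the K/V projection work from $O(t)$ to $O(1)$. The price is memory: the cache grows linearly with sequence length.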

**F1 analogy:** The KV cache is like the pit wall's running memory of the race. Without a cache, every time the strategist needs to make a call, they'd have to re-analyze every lap from the start — recalculating tire degradation curves, fuel load effects, and track evolution from scratch. With a KV cache, all that analysis is stored. When lap 42 comes in, the pit wall only computes the *new* lap's contribution and appends it to the existing analysis. The result is identical, but the wall clock time drops from "re-process 42 laps" to "process 1 new lap." That's the difference between a strategy call taking 30 seconds and taking 1 second.
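Before the PyTorch version, a tiny dependency-free sketch can confirm the two claims made above: caching gives identical attention outputs, and it avoids recomputing K and V for old tokens. It uses a single head, hand-picked 2x2 weight matrices, and a fixed token sequence, all of which are hypothetical values chosen for illustration.

```python
import math


def project(x, w):
    """Apply a square weight matrix w to vector x."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def attend(q, keys, values):
    """Single-query scaled dot-product attention over all keys/values."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [sum((e / z) * v[j] for e, v in zip(exps, values))
            for j in range(len(values[0]))]


# Hypothetical toy weights and token embeddings.
W_Q = [[1.0, 0.5], [0.0, 1.0]]
W_K = [[0.5, -1.0], [1.0, 0.25]]
W_V = [[1.0, 0.0], [0.5, 2.0]]
tokens = [[0.1, 0.2], [0.4, -0.3], [-0.5, 0.9], [0.7, 0.7]]


def last_output_no_cache(prefix):
    """Recompute K and V for every token in the prefix, then attend."""
    keys = [project(x, W_K) for x in prefix]
    values = [project(x, W_V) for x in prefix]
    return attend(project(prefix[-1], W_Q), keys, values), 2 * len(prefix)


def decode_with_cache(sequence):
    """Project only the new token each step and append to the cache."""
    cache_k, cache_v, outputs, kv_projections = [], [], [], 0
    for x in sequence:
        cache_k.append(project(x, W_K))
        cache_v.append(project(x, W_V))
        kv_projections += 2
        outputs.append(attend(project(x, W_Q), cache_k, cache_v))
    return outputs, kv_projections


cached_outputs, cached_proj = decode_with_cache(tokens)
recompute_proj = sum(last_output_no_cache(tokens[:t + 1])[1]
                     for t in range(len(tokens)))

print(f"K/V projections without cache: {recompute_proj}")
print(f"K/V projections with cache:    {cached_proj}")
```

Even for four tokens the uncached path performs more than twice as many K/V projections, and the gap widens quadratically with sequence length.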

In [None]:
class CachedAttention(nn.Module):
    """Self-attention with KV cache for efficient generation."""
    
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
    
    def forward(self, x, cache=None):
        """Forward with optional KV cache.
        
        Args:
            x: (batch, seq_len, d_model)
            cache: tuple of (cached_K, cached_V) or None
        
        Returns:
            output, (new_K, new_V)
        """
        B, T, _ = x.shape
        
        Q = self.q_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        K = self.k_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        V = self.v_proj(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        
        # Append to cache
        if cache is not None:
            cached_K, cached_V = cache
            K = torch.cat([cached_K, K], dim=2)
            V = torch.cat([cached_V, V], dim=2)
        
        # Standard attention
        scale = math.sqrt(self.d_head)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
        
        # Causal mask
        seq_len_k = K.shape[2]
        seq_len_q = Q.shape[2]
        mask = torch.triu(torch.ones(seq_len_q, seq_len_k, device=x.device),
                         diagonal=seq_len_k - seq_len_q + 1).bool()
        scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        
        return self.out_proj(out), (K, V)


# Benchmark: with vs without KV cache
d_model = 128
n_heads = 4
attn = CachedAttention(d_model, n_heads)
attn.eval()

# Simulate generating 32 tokens
gen_len = 32
prompt = torch.randn(1, 8, d_model)  # 8-token prompt

# WITHOUT cache: recompute everything each step
start = time.perf_counter()
all_tokens_no_cache = prompt.clone()
with torch.no_grad():
    for i in range(gen_len):
        out, _ = attn(all_tokens_no_cache, cache=None)
        new_token = out[:, -1:, :]  # Take last token output
        all_tokens_no_cache = torch.cat([all_tokens_no_cache, new_token], dim=1)
time_no_cache = time.perf_counter() - start

# WITH cache: only process new token each step
start = time.perf_counter()
with torch.no_grad():
    # Prefill: process prompt
    out, cache = attn(prompt, cache=None)
    generated = [out[:, -1:, :]]
    
    # Decode: one token at a time
    for i in range(gen_len - 1):
        out, cache = attn(generated[-1], cache=cache)
        generated.append(out)
time_with_cache = time.perf_counter() - start

print("KV Cache Benchmark\n")
print(f"  Without cache: {time_no_cache*1000:.1f}ms")
print(f"  With cache:    {time_with_cache*1000:.1f}ms")
print(f"  Speedup:       {time_no_cache/time_with_cache:.1f}x")
print(f"  Cache size:    K={cache[0].shape}, V={cache[1].shape}")

In [None]:
# Visualize KV cache mechanics and scaling
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Computation per step: with vs without cache
ax = axes[0]
seq_lens = np.arange(1, 65)
# Without cache: O(t^2) total compute at step t
no_cache_ops = seq_lens ** 2
# With cache: O(t) compute at step t
cache_ops = seq_lens

ax.plot(seq_lens, no_cache_ops, 'r-', linewidth=2, label='Without KV cache (O(t²))')
ax.plot(seq_lens, cache_ops, 'g-', linewidth=2, label='With KV cache (O(t))')
ax.fill_between(seq_lens, cache_ops, no_cache_ops, alpha=0.15, color='red', label='Saved compute')
ax.set_xlabel('Sequence Position', fontsize=11)
ax.set_ylabel('Compute Operations', fontsize=11)
ax.set_title('Per-Step Compute: Cache vs No Cache', fontsize=12, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# KV cache memory usage
ax = axes[1]
seq_lengths = [512, 1024, 2048, 4096, 8192, 16384]
# Cache size per layer = 2 * seq_len * d_model * batch_size (K and V)
# For 7B model: ~32 layers, d_model=4096
d = 4096
n_layers = 32
batch_sizes = [1, 4, 16]

for bs in batch_sizes:
    cache_gb = [2 * sl * d * n_layers * bs * 2 / 1e9 for sl in seq_lengths]  # FP16
    ax.plot(seq_lengths, cache_gb, 'o-', linewidth=2, markersize=6, label=f'Batch size {bs}')

ax.set_xlabel('Sequence Length', fontsize=11)
ax.set_ylabel('KV Cache Memory (GB)', fontsize=11)
ax.set_title('KV Cache Memory (7B model, FP16)', fontsize=12, fontweight='bold')
ax.legend(fontsize=10)
ax.set_xscale('log', base=2)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
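The memory formula used in the plot above can also be wrapped in a small helper to answer a practical question: at what point does the cache outweigh the model itself? This is a sketch under the same assumptions as the plot (32 layers, `d_model=4096`, FP16, standard multi-head attention with a full-width K and V per layer); models that share K/V heads across queries would need less. The function name is hypothetical.

```python
def kv_cache_bytes(seq_len, batch_size, n_layers=32, d_model=4096, bytes_per_value=2):
    """KV cache size: K and V (factor 2) for every layer, token, and sequence."""
    return 2 * n_layers * seq_len * d_model * batch_size * bytes_per_value


per_token = kv_cache_bytes(seq_len=1, batch_size=1)
print(f"Cache per token:            {per_token / 2**20:.2f} MiB")
print(f"Cache at 4096 tokens, bs=1: {kv_cache_bytes(4096, 1) / 2**30:.2f} GiB")

# Sequence length at which a batch-16 cache matches 7B FP16 weights (14 GB).
weights_bytes = 7e9 * 2
break_even = weights_bytes / kv_cache_bytes(seq_len=1, batch_size=16)
print(f"Batch-16 cache equals the weights at ~{break_even:.0f} tokens per sequence")
```

Under these assumptions the cache, not the weights, becomes the dominant memory consumer once batches and contexts grow, which is one reason serving systems manage cache memory so carefully.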

---

## 4. Speculative Decoding

**Speculative decoding** uses a small, fast "draft" model to propose multiple tokens at once, then verifies them with the large model in a single forward pass.

### How It Works

1. **Draft**: Small model generates K candidate tokens quickly
2. **Verify**: Large model scores all K tokens in parallel (one forward pass)
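3. **Accept/reject**: Walk the draft left to right and keep tokens while the large model accepts them; stop at the first rejection
4. **Resample**: If rejected at position i, sample a replacement from the large model's (adjusted) distribution there and discard the remaining draft tokens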
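The accept/reject and resample steps can be sketched for a single position. The rule below is the standard speculative sampling rule: accept a draft token `x` with probability `min(1, p(x) / q(x))`, where `p` is the large model's distribution and `q` is the draft model's, and on rejection sample from the normalized residual `max(0, p - q)`. The two distributions and the function name are hypothetical toy values; a real system would obtain `p` and `q` from the two models' logits.

```python
import random


def speculative_step(p_target, q_draft, rng):
    """Emit one token: draft from q, then accept or resample against p."""
    tokens = list(range(len(p_target)))
    draft = rng.choices(tokens, weights=q_draft, k=1)[0]

    # Accept with probability min(1, p/q). q > 0 because draft was sampled from q.
    if rng.random() < min(1.0, p_target[draft] / q_draft[draft]):
        return draft, True

    # Rejected: resample from the part of p that q under-covers.
    residual = [max(0.0, p - q) for p, q in zip(p_target, q_draft)]
    return rng.choices(tokens, weights=residual, k=1)[0], False


# Hypothetical next-token distributions over a 3-token vocabulary.
p_target = [0.6, 0.3, 0.1]   # large model
q_draft = [0.3, 0.5, 0.2]    # draft model

rng = random.Random(0)
n_trials = 50_000
counts = [0, 0, 0]
n_accepted = 0
for _ in range(n_trials):
    token, accepted = speculative_step(p_target, q_draft, rng)
    counts[token] += 1
    n_accepted += accepted

empirical = [c / n_trials for c in counts]
acceptance_rate = n_accepted / n_trials
print("empirical output distribution:", [round(e, 3) for e in empirical])
print("acceptance rate:", round(acceptance_rate, 3))
```

The emitted tokens follow the large model's distribution even though most of them were proposed by the draft model, which is why this scheme does not trade quality for speed. For these toy distributions the acceptance probability is the sum of `min(p, q)` over tokens, i.e. 0.7.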

The key insight: decode is memory-bound, so the large model's verification pass over K draft tokens costs roughly the same as generating one token, regardless of how many drafts are accepted. Each accepted token therefore saves a full decode step.

**F1 analogy:** Speculative decoding is like having a junior strategist (the draft model) sitting next to the chief strategist (the large model). The junior quickly sketches out the next 5 laps of strategy: "stay out, stay out, pit lap 35, mediums, undercut." The chief strategist then reviews all 5 calls in one look — much faster than making each call independently. If the junior got the first 3 right, great — that saved 3 rounds of deliberation. If call 4 was wrong, the chief overrides from that point. The junior is fast but approximate; the chief is slow but authoritative. Together, they're faster than the chief alone.
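A closed-form estimate is a useful sanity check for the simulation that follows. This sketch assumes each draft token is accepted independently with probability `acceptance_rate`, and it uses the same accounting as the `SpeculativeDecoder` simulation below: a round yields the accepted drafts plus one large-model token at the first rejection, capped at `k`. Under that accounting the expected tokens per round is the geometric sum `1 + a + ... + a^(k-1)`. The function names are hypothetical.

```python
def expected_tokens_per_round(acceptance_rate, k):
    """Expected tokens from one draft+verify round (capped at k)."""
    if acceptance_rate >= 1.0:
        return float(k)
    return (1 - acceptance_rate**k) / (1 - acceptance_rate)


def expected_speedup(acceptance_rate, k, target_ms, draft_ms):
    """Speedup over plain decoding, ignoring overshoot on the final round."""
    round_cost_ms = k * draft_ms + target_ms
    tokens = expected_tokens_per_round(acceptance_rate, k)
    return tokens * target_ms / round_cost_ms


for k in [2, 3, 5, 8, 12]:
    s = expected_speedup(acceptance_rate=0.7, k=k, target_ms=50, draft_ms=5)
    print(f"k={k:>2}: ~{s:.2f}x")
```

The estimate shows the trade-off directly: a larger `k` adds draft cost linearly, while the expected number of tokens per round saturates at `1 / (1 - acceptance_rate)`.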

In [None]:
class SpeculativeDecoder:
    """Speculative decoding simulation."""
    
    def __init__(self, target_model_time_ms, draft_model_time_ms, acceptance_rate=0.7):
        """
        target_model_time_ms: Time for one forward pass of large model
        draft_model_time_ms: Time for one forward pass of small model
        acceptance_rate: Probability draft token matches target
        """
        self.target_time = target_model_time_ms
        self.draft_time = draft_model_time_ms
        self.acceptance_rate = acceptance_rate
    
    def simulate_normal_decoding(self, n_tokens):
        """Standard autoregressive decoding."""
        return n_tokens * self.target_time
    
    def simulate_speculative(self, n_tokens, k=5):
        """Speculative decoding with k draft tokens."""
        total_time = 0
        generated = 0
        n_verify_calls = 0
        n_draft_calls = 0
        
        while generated < n_tokens:
            # Draft: generate k tokens with small model
            total_time += k * self.draft_time
            n_draft_calls += k
            
            # Verify: one forward pass of large model for all k tokens
            total_time += self.target_time
            n_verify_calls += 1
            
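            # Accept tokens until first rejection.
            # Simplification: real implementations also emit one extra "bonus"
            # token from the large model when all k drafts are accepted.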
            accepted = 0
            for _ in range(k):
                if np.random.random() < self.acceptance_rate:
                    accepted += 1
                else:
                    accepted += 1  # We still get 1 token from large model at rejection point
                    break
            
            generated += accepted
        
        return total_time, n_verify_calls, n_draft_calls


# Compare decoding strategies
decoder = SpeculativeDecoder(
    target_model_time_ms=50,  # Large model: 50ms per token
    draft_model_time_ms=5,    # Small model: 5ms per token
    acceptance_rate=0.7
)

n_tokens = 256

normal_time = decoder.simulate_normal_decoding(n_tokens)

print(f"Generating {n_tokens} tokens\n")
print(f"Normal decoding: {normal_time:.0f}ms ({n_tokens} target calls)")

# Try different k values
results = []
for k in [2, 3, 5, 8, 12]:
    times = []
    for _ in range(50):  # Average over runs
        t, n_verify, n_draft = decoder.simulate_speculative(n_tokens, k=k)
        times.append(t)
    avg_time = np.mean(times)
    speedup = normal_time / avg_time
    results.append({'k': k, 'time': avg_time, 'speedup': speedup})
    print(f"Speculative (k={k:>2}): {avg_time:.0f}ms -> {speedup:.2f}x speedup")

# Also vary acceptance rate
print("\nEffect of acceptance rate (k=5):")
for rate in [0.5, 0.6, 0.7, 0.8, 0.9]:
    decoder.acceptance_rate = rate
    times = [decoder.simulate_speculative(n_tokens, k=5)[0] for _ in range(50)]
    avg = np.mean(times)
    print(f"  acceptance={rate:.0%}: {avg:.0f}ms -> {normal_time/avg:.2f}x speedup")

In [None]:
# Visualize speculative decoding
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Speedup vs k
ax = axes[0]
ks = [r['k'] for r in results]
speedups = [r['speedup'] for r in results]
ax.plot(ks, speedups, 'bo-', linewidth=2, markersize=8)
ax.axhline(y=1, color='gray', linestyle='--', alpha=0.5, label='No speedup')
ax.set_xlabel('Draft Length (k)', fontsize=11)
ax.set_ylabel('Speedup vs Normal Decoding', fontsize=11)
ax.set_title('Speculative Decoding Speedup', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Speedup vs acceptance rate
ax = axes[1]
rates = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95]
for k in [3, 5, 8]:
    speedups_rate = []
    for rate in rates:
        decoder.acceptance_rate = rate
        times = [decoder.simulate_speculative(n_tokens, k=k)[0] for _ in range(30)]
        speedups_rate.append(normal_time / np.mean(times))
    ax.plot(rates, speedups_rate, 'o-', linewidth=2, markersize=6, label=f'k={k}')

ax.set_xlabel('Acceptance Rate', fontsize=11)
ax.set_ylabel('Speedup', fontsize=11)
ax.set_title('Speedup vs Draft Quality', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## 5. Continuous Batching

In naive (static) batching, all requests in a batch must wait for the longest request to finish. **Continuous batching** (also called "in-flight batching") allows new requests to join and completed requests to leave the batch dynamically, at the granularity of a single decode step.

| Approach | Throughput | Latency | GPU Utilization | F1 Parallel |
|----------|-----------|---------|----------------|-------------|
| No batching | Low | Optimal only when the server is idle; queued requests wait | Very low | Processing one car's data at a time while the other 19 cars wait idle |
| Static batching | Medium | Worst-case per-batch | Medium | Processing all 20 cars together, but waiting until the slowest finishes before starting the next batch |
| Continuous batching | High | Near-optimal | High | Processing multiple cars' data simultaneously, with each car's analysis completing and freeing resources independently |

**F1 analogy:** Continuous batching is how a modern F1 pit wall actually processes data from all 20 cars on the grid. Without batching, the system would analyze Car 1's telemetry, then Car 2's, then Car 3's — serialized and slow. Static batching would process all 20 cars together but wait for the most complex analysis (say, the car on a complex mixed strategy) to finish before starting any new work. Continuous batching is the real-world approach: as soon as Car 7's simple "stay out" analysis finishes, that compute slot immediately starts processing new data, even while Car 14's complex undercut calculation is still running.
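A quick calculation shows how much work a static batch wastes. In a static batch every slot stays occupied until the longest request finishes, so the useful fraction of slot-steps is the total number of tokens divided by `longest * batch_size`. The function name and the example lengths are hypothetical.

```python
def static_batch_utilization(output_lengths):
    """Fraction of slot-steps in a static batch that produce a useful token."""
    longest = max(output_lengths)
    useful_steps = sum(output_lengths)
    total_steps = longest * len(output_lengths)
    return useful_steps / total_steps


mixed = [12, 87, 40, 95]      # requests of very different lengths
uniform = [50, 50, 50, 50]    # all requests the same length

print(f"Mixed lengths:   {static_batch_utilization(mixed):.1%} of slot-steps useful")
print(f"Uniform lengths: {static_batch_utilization(uniform):.1%} of slot-steps useful")
```

Continuous batching targets exactly this idle fraction: a slot that finishes early is handed to the next queued request instead of waiting for the slowest one.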

In [None]:
class BatchingSimulator:
    """Simulate different batching strategies."""
    
    def __init__(self, time_per_token_ms=10, max_batch_size=16):
        self.time_per_token = time_per_token_ms
        self.max_batch = max_batch_size
    
    def no_batching(self, requests):
        """Process requests one at a time."""
        results = []
        current_time = 0
        
        for req in requests:
            start = max(current_time, req['arrival'])
            duration = req['output_tokens'] * self.time_per_token
            end = start + duration
            results.append({
                'latency': end - req['arrival'],
                'start': start, 'end': end
            })
            current_time = end
        
        return results
    
    def static_batching(self, requests, batch_size=4):
        """Process requests in fixed-size batches."""
        results = [None] * len(requests)
        current_time = 0
        
        for i in range(0, len(requests), batch_size):
            batch = requests[i:i+batch_size]
            batch_start = max(current_time, max(r['arrival'] for r in batch))
            
            # All requests must wait for the longest one
            max_tokens = max(r['output_tokens'] for r in batch)
            # Batched: ~same time as single request (parallel on GPU)
            duration = max_tokens * self.time_per_token
            batch_end = batch_start + duration
            
            for j, req in enumerate(batch):
                results[i + j] = {
                    'latency': batch_end - req['arrival'],
                    'start': batch_start, 'end': batch_end
                }
            current_time = batch_end
        
        return results
    
    def continuous_batching(self, requests, max_batch=8):
        """Process requests with continuous batching."""
        results = [None] * len(requests)
        active = []  # (request_idx, tokens_remaining)
        queue = list(range(len(requests)))
        current_time = 0
        
        while queue or active:
            # Add new requests to batch
            while queue and len(active) < max_batch:
                idx = queue[0]
                if requests[idx]['arrival'] <= current_time:
                    queue.pop(0)
                    active.append((idx, requests[idx]['output_tokens']))
                else:
                    break
            
            if not active:
                if queue:
                    current_time = requests[queue[0]]['arrival']
                continue
            
            # Process one step for all active requests
            current_time += self.time_per_token
            
            new_active = []
            for idx, remaining in active:
                if remaining <= 1:
                    # Request complete
                    results[idx] = {
                        'latency': current_time - requests[idx]['arrival'],
                        'end': current_time
                    }
                else:
                    new_active.append((idx, remaining - 1))
            active = new_active
        
        return results


# Simulate requests
np.random.seed(42)
n_requests = 20
requests = []
for i in range(n_requests):
    requests.append({
        'arrival': i * 50,  # 50ms between arrivals
        'output_tokens': np.random.randint(10, 100),
    })

sim = BatchingSimulator(time_per_token_ms=10)

no_batch_results = sim.no_batching(requests)
static_results = sim.static_batching(requests, batch_size=4)
continuous_results = sim.continuous_batching(requests, max_batch=8)

print("Batching Strategy Comparison\n")
for name, results in [('No batching', no_batch_results),
                       ('Static (bs=4)', static_results),
                       ('Continuous (max=8)', continuous_results)]:
    latencies = [r['latency'] for r in results if r is not None]
    total_time = max(r['end'] for r in results if r is not None)
    throughput = n_requests / (total_time / 1000)  # req/sec
    print(f"  {name:>25}: avg latency={np.mean(latencies):.0f}ms, "
          f"p99={np.percentile(latencies, 99):.0f}ms, "
          f"throughput={throughput:.1f} req/s")

In [None]:
# Visualize batching comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Latency distribution
ax = axes[0]
strategies = [
    ('No batching', no_batch_results, '#e74c3c'),
    ('Static', static_results, '#f39c12'),
    ('Continuous', continuous_results, '#2ecc71'),
]

for name, results, color in strategies:
    latencies = sorted([r['latency'] for r in results if r is not None])
    ax.plot(range(len(latencies)), latencies, 'o-', color=color, linewidth=2,
           markersize=4, label=name)

ax.set_xlabel('Request (sorted by latency)', fontsize=11)
ax.set_ylabel('Latency (ms)', fontsize=11)
ax.set_title('Per-Request Latency', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Summary bars
ax = axes[1]
names = ['No\nbatching', 'Static\n(bs=4)', 'Continuous\n(max=8)']
avg_latencies = []
throughputs = []

for _, results, _ in strategies:
    latencies = [r['latency'] for r in results if r is not None]
    total_time = max(r['end'] for r in results if r is not None)
    avg_latencies.append(np.mean(latencies))
    throughputs.append(n_requests / (total_time / 1000))

x = np.arange(len(names))
w = 0.35
colors_bar = ['#3498db', '#2ecc71']  # [latency, throughput], matching the axis label colors
ax.bar(x - w/2, avg_latencies, w, label='Avg Latency (ms)', color=colors_bar[0], edgecolor='black')
ax2 = ax.twinx()
ax2.bar(x + w/2, throughputs, w, label='Throughput (req/s)', color=colors_bar[1], edgecolor='black')

ax.set_xticks(x)
ax.set_xticklabels(names)
ax.set_ylabel('Avg Latency (ms)', fontsize=10, color='#3498db')
ax2.set_ylabel('Throughput (req/s)', fontsize=10, color='#2ecc71')
ax.set_title('Batching Strategy Comparison', fontsize=13, fontweight='bold')
ax.legend(loc='upper left', fontsize=9)
ax2.legend(loc='upper right', fontsize=9)

plt.tight_layout()
plt.show()

---

## 6. Knowledge Distillation

**Distillation** trains a small "student" model to mimic a large "teacher" model. The student learns from the teacher's soft probability distributions, which carry more information than hard labels.

$$\mathcal{L}_{\text{distill}} = \alpha \cdot T^2 \cdot \text{KL}\left(\sigma\left(\frac{z_t}{T}\right) \| \sigma\left(\frac{z_s}{T}\right)\right) + (1-\alpha) \cdot \text{CE}(y, z_s)$$

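where $z_s$ and $z_t$ are the student and teacher logits, $\sigma$ is the softmax, $T$ is the temperature, and $\alpha$ balances the distillation loss against the hard-label loss. The $T^2$ factor compensates for the soft-target gradients shrinking by roughly $1/T^2$ as the temperature rises.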

**F1 analogy:** Distillation is like the relationship between a team's full CFD simulation and the simplified model that runs on the pit wall during a race. The full CFD (teacher) takes hours to compute aerodynamic loads for a single configuration — far too slow for race day. But by having the simplified model (student) learn from thousands of CFD outputs, the pit wall model captures the *essence* of the full simulation's knowledge. The temperature parameter controls how much "soft" insight transfers: at high temperature, the student learns not just that "low downforce is best for Monza" but *how much better* it is than medium downforce, and how close medium is to high — the full ranking, not just the winner.
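
To make the effect of temperature concrete, here is a minimal NumPy sketch of how dividing logits by $T$ flattens a distribution. The helper name `softmax_with_temperature` and the three example logits are hypothetical, chosen only for illustration; they are not taken from any model in this notebook.

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    """Softmax of logits / T; a higher T gives a flatter distribution."""
    z = np.asarray(logits, dtype=np.float64) / T
    z = z - z.max()  # subtract the max for numerical stability
    p = np.exp(z)
    return p / p.sum()

# Hypothetical teacher logits for three setups: low, medium, high downforce
teacher_logits = [5.0, 2.0, 1.0]

for T in (1.0, 4.0, 16.0):
    probs = softmax_with_temperature(teacher_logits, T)
    print(f"T={T:>4}: {np.round(probs, 3)}")
```

At low temperature almost all the probability mass sits on the top class, so the student sees little more than a hard label. As $T$ grows, the runner-up classes get visible mass while the ranking is preserved, which is the extra signal the soft-target loss passes to the student.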

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillationTrainer:
    """Knowledge distillation from teacher to student model."""
    
    def __init__(self, teacher, student, temperature=4.0, alpha=0.7):
        self.teacher = teacher
        self.student = student
        self.temperature = temperature
        self.alpha = alpha
    
    def distillation_loss(self, student_logits, teacher_logits, targets):
        """Compute combined distillation + hard label loss."""
        T = self.temperature
        
        # Soft target loss (KL divergence on softened distributions)
        student_soft = F.log_softmax(student_logits / T, dim=-1)
        teacher_soft = F.softmax(teacher_logits / T, dim=-1)
        soft_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (T * T)
        
        # Hard target loss
        hard_loss = F.cross_entropy(student_logits, targets)
        
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss
    
    def train_step(self, x, targets, optimizer):
        """One training step."""
        self.teacher.eval()
        self.student.train()
        
        with torch.no_grad():
            teacher_logits = self.teacher(x)
        
        student_logits = self.student(x)
        loss = self.distillation_loss(student_logits, teacher_logits, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        return loss.item()


# Create teacher (large) and student (small) models
torch.manual_seed(42)  # fix model initialization and synthetic data for reproducible runs
n_classes = 10
input_dim = 64

teacher = nn.Sequential(
    nn.Linear(input_dim, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, n_classes)
)

student = nn.Sequential(
    nn.Linear(input_dim, 32), nn.ReLU(),
    nn.Linear(32, n_classes)
)

student_no_distill = nn.Sequential(
    nn.Linear(input_dim, 32), nn.ReLU(),
    nn.Linear(32, n_classes)
)
# Copy initial weights so comparison is fair
student_no_distill.load_state_dict(student.state_dict())

teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"Teacher: {teacher_params:,} params")
print(f"Student: {student_params:,} params ({student_params/teacher_params:.1%} of teacher)")

# Generate synthetic dataset
n_train = 500
X = torch.randn(n_train, input_dim)
# Labels are the randomly initialized teacher's own predictions ("ground truth" for this toy task)
teacher.eval()
with torch.no_grad():
    y = teacher(X).argmax(dim=-1)

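# Train the teacher on those labels so its output distribution becomes confident
# (at initialization it already matches them, since they are its own argmax)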
teacher_opt = torch.optim.Adam(teacher.parameters(), lr=1e-3)
teacher.train()
for _ in range(200):
    loss = F.cross_entropy(teacher(X), y)
    teacher_opt.zero_grad()
    loss.backward()
    teacher_opt.step()

teacher.eval()
teacher_acc = (teacher(X).argmax(dim=-1) == y).float().mean().item()
print(f"\nTeacher accuracy: {teacher_acc:.1%}")

# Train student WITH distillation
distiller = DistillationTrainer(teacher, student, temperature=4.0, alpha=0.7)
student_opt = torch.optim.Adam(student.parameters(), lr=1e-3)

distill_losses = []
for epoch in range(200):
    loss = distiller.train_step(X, y, student_opt)
    distill_losses.append(loss)

# Train student WITHOUT distillation (hard labels only)
no_distill_opt = torch.optim.Adam(student_no_distill.parameters(), lr=1e-3)
no_distill_losses = []
student_no_distill.train()
for epoch in range(200):
    logits = student_no_distill(X)
    loss = F.cross_entropy(logits, y)
    no_distill_opt.zero_grad()
    loss.backward()
    no_distill_opt.step()
    no_distill_losses.append(loss.item())

# Evaluate on the training set (this toy example has no held-out split)
student.eval()
student_no_distill.eval()
with torch.no_grad():
    distill_acc = (student(X).argmax(dim=-1) == y).float().mean().item()
    no_distill_acc = (student_no_distill(X).argmax(dim=-1) == y).float().mean().item()

print(f"Student (with distillation): {distill_acc:.1%}")
print(f"Student (without distillation): {no_distill_acc:.1%}")
print(f"\nDistillation advantage: {distill_acc - no_distill_acc:+.1%}")

In [None]:
# Visualize distillation
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

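# Training loss comparison (the two curves optimize different objectives,
# so compare their trends rather than their absolute values)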
ax = axes[0]
w = 10
smooth_distill = [np.mean(distill_losses[max(0,i-w):i+1]) for i in range(len(distill_losses))]
smooth_no = [np.mean(no_distill_losses[max(0,i-w):i+1]) for i in range(len(no_distill_losses))]
ax.plot(smooth_distill, linewidth=2, label='With distillation', color='#2ecc71')
ax.plot(smooth_no, linewidth=2, label='Without distillation', color='#e74c3c')
ax.set_xlabel('Epoch', fontsize=11)
ax.set_ylabel('Loss', fontsize=11)
ax.set_title('Student Training Loss', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

# Accuracy comparison bar chart
ax = axes[1]
models = ['Teacher\n(large)', 'Student +\nDistillation', 'Student\n(hard labels)']
accs = [teacher_acc, distill_acc, no_distill_acc]
colors = ['#3498db', '#2ecc71', '#e74c3c']
params = [teacher_params, student_params, student_params]

bars = ax.bar(models, accs, color=colors, edgecolor='black', alpha=0.8)
for bar, acc, p in zip(bars, accs, params):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
            f'{acc:.1%}\n({p:,} params)', ha='center', fontsize=9, fontweight='bold')

ax.set_ylabel('Accuracy', fontsize=11)
ax.set_title('Model Comparison', fontsize=13, fontweight='bold')
ax.set_ylim(0, 1.15)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

---

## 7. Optimization Techniques Compared

Let's compare all optimization techniques on the dimensions that matter — like comparing different car development strategies on lap time, reliability, and cost.

**F1 analogy:** Just as an F1 team evaluates upgrades on multiple dimensions (lap time gain vs. weight vs. reliability vs. cost), inference optimizations must be evaluated on speed, memory, quality, and implementation complexity. The best teams stack multiple optimizations, just as the best serving systems combine quantization + KV cache + continuous batching.

In [None]:
# Illustrative comparison: rough ballpark values relative to an FP32 baseline,
# not measured benchmarks
techniques = {
    'Baseline (FP32)':     {'memory': 1.0, 'speed': 1.0, 'quality': 1.0, 'complexity': 1},
    'FP16':                {'memory': 0.5, 'speed': 1.5, 'quality': 0.99, 'complexity': 1},
    'INT8 Quant':          {'memory': 0.25, 'speed': 2.0, 'quality': 0.97, 'complexity': 2},
    'INT4 Quant':          {'memory': 0.125, 'speed': 2.5, 'quality': 0.93, 'complexity': 3},
    'KV Cache':            {'memory': 1.1, 'speed': 3.0, 'quality': 1.0, 'complexity': 2},
    'Speculative':         {'memory': 1.3, 'speed': 2.5, 'quality': 1.0, 'complexity': 4},
    'Continuous Batch':    {'memory': 1.0, 'speed': 2.0, 'quality': 1.0, 'complexity': 3},
    'Distillation':        {'memory': 0.3, 'speed': 3.0, 'quality': 0.90, 'complexity': 4},
}

fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Speed vs Memory tradeoff
ax = axes[0]
for name, vals in techniques.items():
    size = (1.1 - vals['memory']) * 200 + 50  # Bigger marker = less memory
    sc = ax.scatter(vals['memory'], vals['speed'], s=size, c=[vals['quality']],
                    cmap='RdYlGn', vmin=0.85, vmax=1.0, edgecolors='black', linewidth=1, zorder=5)
    ax.annotate(name, (vals['memory'], vals['speed']),
                textcoords='offset points', xytext=(8, 5), fontsize=8)

# Colorbar so the quality encoding is readable (all points share the same cmap and limits)
fig.colorbar(sc, ax=ax, label='Quality (relative to baseline)')

ax.set_xlabel('Memory (relative to baseline)', fontsize=11)
ax.set_ylabel('Speed (relative to baseline)', fontsize=11)
ax.set_title('Speed vs Memory (color=quality)', fontsize=13, fontweight='bold')
ax.axhline(y=1, color='gray', linestyle='--', alpha=0.3)
ax.axvline(x=1, color='gray', linestyle='--', alpha=0.3)
ax.grid(True, alpha=0.3)

# Bar comparison
ax = axes[1]
names = list(techniques.keys())
x = np.arange(len(names))
w = 0.2

metrics = ['speed', 'quality']
colors_bar = ['#3498db', '#2ecc71']
labels = ['Speed', 'Quality']

for i, (metric, color, label) in enumerate(zip(metrics, colors_bar, labels)):
    vals = [techniques[n][metric] for n in names]
    ax.bar(x + i * w, vals, w, label=label, color=color, edgecolor='black', alpha=0.8)

ax.set_xticks(x + w/2)
ax.set_xticklabels(names, rotation=40, ha='right', fontsize=8)
ax.set_ylabel('Relative to Baseline', fontsize=10)
ax.set_title('Speed and Quality by Technique', fontsize=13, fontweight='bold')
ax.axhline(y=1, color='gray', linestyle='--', alpha=0.3)
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

---

## Exercises

### Exercise 1: Group Quantization

Implement **group quantization** where weights are quantized in groups of 128 (each group has its own scale). Compare error against per-tensor and per-channel quantization at INT4 precision.

**F1 scenario:** Different sections of a telemetry trace have wildly different ranges — brake pressure spikes to 200 bar in braking zones but sits at 0 on straights. Group quantization is like having separate precision scales for each track sector: high-range encoding for braking zones, fine-grained encoding for smooth straights. Implement this and show it reduces quantization error compared to one-size-fits-all encoding.

In [None]:
# Exercise 1: Your code here
# Hint: Reshape weights into groups of 128, quantize each group independently,
# then reshape back. Compare MSE against per-tensor and per-channel.


### Exercise 2: PagedAttention Simulator

Implement a simplified version of **PagedAttention** (used in vLLM). Instead of pre-allocating KV cache for max sequence length, allocate fixed-size pages on demand. Show memory savings compared to naive pre-allocation.

**F1 scenario:** Pre-allocating KV cache for max sequence length is like reserving pit wall memory for a full 78-lap race for every car, even if some cars retire on lap 1. PagedAttention allocates memory in fixed-size "pages" on demand — like reserving pit wall compute for each car only as their race progresses. Show the memory savings when cars (requests) have varying race lengths (sequence lengths).

In [None]:
# Exercise 2: Your code here
# Hint: Create a page table that maps sequence positions to memory pages.
# Track allocated vs wasted memory compared to contiguous allocation.


### Exercise 3: Distillation with Temperature Sweep

Run distillation experiments at temperatures T = 1, 2, 4, 8, 16. Plot student accuracy vs temperature. What temperature works best and why?

**F1 scenario:** Temperature in distillation controls how much nuance transfers from the full CFD simulation to the pit wall model. Low T (T=1) gives sharp "this setup is best" signals. High T (T=16) gives soft "here's the full ranking of all setups with their relative merits." Find the sweet spot where the pit wall model learns the most useful knowledge from the CFD teacher.

In [None]:
# Exercise 3: Your code here


---

## Summary

### Key Concepts

| Concept | What It Does | F1 Parallel |
|---------|-------------|-------------|
| **Memory-bound decoding** | Most inference time is moving weights, not computing | Pit wall bandwidth limit — the data link, not the CPU, is the bottleneck |
| **Quantization** (INT8/INT4) | Cuts weight memory 4x (INT8) to 8x (INT4) vs. FP32, with a small quality loss | Reducing telemetry precision from float32 to int8 for faster pit wall processing |
| **KV cache** | Eliminates redundant computation, O(t^2) to O(t) | Caching computed analysis for already-processed laps instead of re-analyzing from scratch |
| **Speculative decoding** | Small draft model proposes, large model verifies in parallel | Junior strategist proposes next 5 calls, chief reviews them all at once |
| **Continuous batching** | Dynamically adds/removes requests from GPU batches | Processing all 20 cars simultaneously, each completing independently |
| **Distillation** | Small student learns from large teacher's soft outputs | Pit wall model trained on thousands of full CFD simulation results |
| **Stacking** | Production systems combine several of these techniques | The complete pit wall stack: compressed telemetry + cached history + parallel processing |
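
As a rough check on the KV cache row, the sketch below counts how many token positions pass through the model while generating a sequence, with and without a cache. It is a simplified counting model, assuming generation starts from a single token and ignoring the attention cost itself; `positions_processed` is a hypothetical helper, not part of the notebook's earlier code.

```python
def positions_processed(n_tokens, use_kv_cache):
    """Count token positions pushed through the model to generate n_tokens."""
    total = 0
    for t in range(1, n_tokens + 1):
        # Without a cache, step t re-processes all t tokens seen so far.
        # With a cache, only the newest token needs a forward pass.
        total += 1 if use_kv_cache else t
    return total

for n in (10, 100, 1000):
    without = positions_processed(n, use_kv_cache=False)
    with_cache = positions_processed(n, use_kv_cache=True)
    print(f"n={n:>5}: no cache={without:>7,}  cache={with_cache:>5,}  "
          f"ratio={without / with_cache:.1f}x")
```

Without the cache the count is $n(n+1)/2$, which grows quadratically; with it the count is simply $n$. The gap widens with sequence length, which is why caching matters most for long generations.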

### The Optimization Stack

In production, you don't pick one technique — you layer them, just like an F1 team layers every marginal gain:
1. **Distillation** -> Smaller model (like distilling a full CFD sim into a real-time pit wall model)
2. **Quantization** -> Less memory per parameter (like compressing telemetry for faster radio transmission)
3. **KV cache** -> Less redundant compute (like keeping a running race analysis instead of starting from scratch)
4. **Continuous batching** -> Higher throughput (like processing all cars' data simultaneously)
5. **Speculative decoding** -> Lower latency (like having a junior strategist pre-draft calls for the chief to approve)
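
To put numbers on the quantization layer of the stack, here is a back-of-envelope sketch of weight memory at different precisions. It assumes a hypothetical 7B-parameter model and counts weights only: quantization scales, the KV cache, and activations are all ignored, so real footprints will be somewhat larger. `weight_memory_gb` is an illustrative helper name.

```python
# Bytes needed to store one parameter at each precision
BYTES_PER_PARAM = {"FP32": 4.0, "FP16": 2.0, "INT8": 1.0, "INT4": 0.5}

def weight_memory_gb(n_params, precision):
    """Approximate weight memory in GB (1 GB = 1e9 bytes), weights only."""
    return n_params * BYTES_PER_PARAM[precision] / 1e9

n_params = 7e9  # assumed model size for illustration
for precision in BYTES_PER_PARAM:
    print(f"{precision}: {weight_memory_gb(n_params, precision):.1f} GB")
```

Each halving of precision halves the weight footprint, which is where the 4x (INT8) and 8x (INT4) reductions relative to FP32 come from.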

---

## Next Steps

Optimizing inference is about making models fast — getting the strategy model from the simulation farm to the live pit wall. But building reliable ML systems requires more than fast models — it requires **experiment tracking, reproducibility, and systematic model management**. In **Notebook 27: ML Systems & Experiment Tracking**, we'll build the infrastructure that makes ML development systematic — the factory behind the race team.