# Training & Serving at Scale

In this notebook, you'll calculate, simulate, and reason about the engineering that gets large models trained and served. No multi-GPU hardware needed—everything runs as calculation and simulation on a single CPU.

**What you'll do:**
- Build a training memory calculator that computes per-GPU requirements for different parallelism strategies and ZeRO stages, then see why a 70B model cannot train on a single GPU
- Simulate speculative decoding with a draft-then-verify loop: measure how acceptance rate changes with draft length K, and see why the speedup comes from parallel verification
- Simulate a continuous batching inference server: compare GPU utilization between static batching (wait for all to finish) and continuous batching (fill completed slots immediately)
- Build a parallelism strategy advisor: given a model size, GPU count, and GPU memory, recommend the right combination of data, tensor, and pipeline parallelism

**For each exercise, PREDICT the output before running the cell.** Wrong predictions are more valuable than correct ones—they reveal gaps in your mental model.

In [None]:
# Setup -- self-contained for Google Colab
# No extra pip installs needed.

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from dataclasses import dataclass
from typing import Optional

# Reproducible results
np.random.seed(42)

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

print('Setup complete.')

## Shared Helpers

Formatting utilities used across multiple exercises.

In [None]:
def format_bytes(b: float) -> str:
    """Human-readable byte string."""
    if b >= 1e12:
        return f'{b / 1e12:.1f} TB'
    if b >= 1e9:
        return f'{b / 1e9:.1f} GB'
    if b >= 1e6:
        return f'{b / 1e6:.1f} MB'
    return f'{b / 1e3:.1f} KB'


def format_number(n: float) -> str:
    """Human-readable large number."""
    if n >= 1e12:
        return f'{n / 1e12:.1f}T'
    if n >= 1e9:
        return f'{n / 1e9:.1f}B'
    if n >= 1e6:
        return f'{n / 1e6:.0f}M'
    return f'{n / 1e3:.0f}K'


print('Helpers loaded.')
print(f'  format_bytes(84e9) = {format_bytes(84e9)}')
print(f'  format_number(70e9) = {format_number(70e9)}')

---

## Exercise 1: Training Memory Calculator (Guided)

You already know the training memory breakdown from the LoRA & Quantization lesson: mixed-precision Adam training requires ~12 bytes per parameter (2B weights + 2B gradients + 4B momentum + 4B variance). Optimizer states alone account for two-thirds of training memory.

In this exercise, you'll see a complete memory calculator that computes per-GPU requirements across:
- **No parallelism** (single GPU)
- **Data parallelism** (full model replicated on each GPU)
- **ZeRO Stage 1** (shard optimizer states across GPUs)
- **ZeRO Stage 2** (shard optimizer states + gradients)
- **ZeRO Stage 3** (shard everything)

You'll apply it to three models: GPT-2 (124M), LLaMA 7B, and LLaMA 70B.

**Before running, predict:**
- For a 7B model with mixed-precision Adam, what is the total training memory? (Hint: ~12 bytes/param)
- With ZeRO Stage 1 on 8 GPUs, the optimizer states are sharded 8 ways. Optimizer states are ~8 bytes/param. What is the per-GPU optimizer memory? What is the total per-GPU memory?
- Will ZeRO Stage 1 alone make a 70B model fit on 8 A100 GPUs (80 GB each)?

In [None]:
# --- Training Memory Calculator ---
#
# Mixed-precision Adam breakdown (per parameter):
#   - bf16 weights:    2 bytes (forward/backward)
#   - bf16 gradients:  2 bytes
#   - fp32 momentum:   4 bytes (Adam state)
#   - fp32 variance:   4 bytes (Adam state)
#   Total: 12 bytes/param
#   Optimizer states: 8 bytes/param (2/3 of total)

@dataclass
class ModelConfig:
    name: str
    num_params: float  # number of parameters


BYTES_PER_PARAM_WEIGHTS = 2     # bf16
BYTES_PER_PARAM_GRADIENTS = 2   # bf16
BYTES_PER_PARAM_MOMENTUM = 4    # fp32 Adam
BYTES_PER_PARAM_VARIANCE = 4    # fp32 Adam
BYTES_PER_PARAM_OPTIMIZER = BYTES_PER_PARAM_MOMENTUM + BYTES_PER_PARAM_VARIANCE  # 8
BYTES_PER_PARAM_TOTAL = (BYTES_PER_PARAM_WEIGHTS + BYTES_PER_PARAM_GRADIENTS
                         + BYTES_PER_PARAM_OPTIMIZER)  # 12


def training_memory(
    num_params: float,
    num_gpus: int = 1,
    zero_stage: int = 0,  # 0 = no ZeRO, 1/2/3 = ZeRO stages
) -> dict:
    """Compute per-GPU training memory for a given model and parallelism config.

    ZeRO stages:
      Stage 0 (data parallelism): full model replicated on each GPU
      Stage 1: shard optimizer states across GPUs
      Stage 2: shard optimizer states + gradients
      Stage 3: shard everything (weights + gradients + optimizer states)
    """
    # What each GPU stores depends on ZeRO stage
    weights_per_gpu = num_params * BYTES_PER_PARAM_WEIGHTS
    gradients_per_gpu = num_params * BYTES_PER_PARAM_GRADIENTS
    optimizer_per_gpu = num_params * BYTES_PER_PARAM_OPTIMIZER

    # ZeRO Stage 1: shard optimizer states
    if zero_stage >= 1:
        optimizer_per_gpu = optimizer_per_gpu / num_gpus

    # ZeRO Stage 2: also shard gradients
    if zero_stage >= 2:
        gradients_per_gpu = gradients_per_gpu / num_gpus

    # ZeRO Stage 3: also shard weights
    if zero_stage >= 3:
        weights_per_gpu = weights_per_gpu / num_gpus

    total_per_gpu = weights_per_gpu + gradients_per_gpu + optimizer_per_gpu

    return {
        'weights_per_gpu': weights_per_gpu,
        'gradients_per_gpu': gradients_per_gpu,
        'optimizer_per_gpu': optimizer_per_gpu,
        'total_per_gpu': total_per_gpu,
    }


# --- Apply to three models ---

models = [
    ModelConfig('GPT-2 (124M)', 124e6),
    ModelConfig('LLaMA 7B', 7e9),
    ModelConfig('LLaMA 70B', 70e9),
]

A100_MEMORY_GB = 80
A100_MEMORY_BYTES = A100_MEMORY_GB * 1e9

print('=== Single GPU Training Memory ===')
print(f'{"Model":<20} {"Weights":>10} {"Grads":>10} {"Optimizer":>10} {"Total":>10} {"Fits A100?":>12}')
print('-' * 75)
for model in models:
    mem = training_memory(model.num_params, num_gpus=1, zero_stage=0)
    fits = mem['total_per_gpu'] <= A100_MEMORY_BYTES
    print(f'{model.name:<20} '
          f'{format_bytes(mem["weights_per_gpu"]):>10} '
          f'{format_bytes(mem["gradients_per_gpu"]):>10} '
          f'{format_bytes(mem["optimizer_per_gpu"]):>10} '
          f'{format_bytes(mem["total_per_gpu"]):>10} '
          f'{"YES" if fits else "NO":>12}')

print()
print(f'A100 GPU memory: {A100_MEMORY_GB} GB')
print(f'70B model needs {format_bytes(70e9 * BYTES_PER_PARAM_TOTAL)} -- '
      f'{70e9 * BYTES_PER_PARAM_TOTAL / A100_MEMORY_BYTES:.1f}x the capacity of a single A100.')
print(f'The model does not fit. Not "it\'s slow" -- it is physically impossible to begin.')

In [None]:
# --- ZeRO stages across GPU counts ---
# For each model, compute per-GPU memory with ZeRO stages 0-3 on 8 GPUs.

NUM_GPUS = 8

print(f'=== Per-GPU Memory with {NUM_GPUS} GPUs ===')
print()

for model in models:
    print(f'--- {model.name} ({format_number(model.num_params)} params) ---')
    print(f'{"ZeRO Stage":<14} {"Weights":>10} {"Grads":>10} {"Optimizer":>10} '
          f'{"Total/GPU":>10} {"Fits A100?":>12}')
    print('-' * 70)

    stage_labels = {
        0: 'No ZeRO (DP)',
        1: 'Stage 1',
        2: 'Stage 2',
        3: 'Stage 3',
    }

    for stage in [0, 1, 2, 3]:
        mem = training_memory(model.num_params, num_gpus=NUM_GPUS, zero_stage=stage)
        fits = mem['total_per_gpu'] <= A100_MEMORY_BYTES
        print(f'{stage_labels[stage]:<14} '
              f'{format_bytes(mem["weights_per_gpu"]):>10} '
              f'{format_bytes(mem["gradients_per_gpu"]):>10} '
              f'{format_bytes(mem["optimizer_per_gpu"]):>10} '
              f'{format_bytes(mem["total_per_gpu"]):>10} '
              f'{"YES" if fits else "NO":>12}')
    print()

In [None]:
# --- Visualize: per-GPU memory by ZeRO stage for the 70B model ---

fig, ax = plt.subplots(figsize=(10, 5))

stages = ['No ZeRO\n(Data Par.)', 'ZeRO\nStage 1', 'ZeRO\nStage 2', 'ZeRO\nStage 3']
weights_vals = []
grads_vals = []
opt_vals = []

model_70b = models[2]
for stage in [0, 1, 2, 3]:
    mem = training_memory(model_70b.num_params, num_gpus=NUM_GPUS, zero_stage=stage)
    weights_vals.append(mem['weights_per_gpu'] / 1e9)
    grads_vals.append(mem['gradients_per_gpu'] / 1e9)
    opt_vals.append(mem['optimizer_per_gpu'] / 1e9)

x = np.arange(len(stages))
bar_width = 0.5

ax.bar(x, weights_vals, bar_width, label='Weights (bf16)', color='#60a5fa', alpha=0.8)
ax.bar(x, grads_vals, bar_width, bottom=weights_vals, label='Gradients (bf16)', color='#34d399', alpha=0.8)
ax.bar(x, opt_vals, bar_width,
       bottom=[w + g for w, g in zip(weights_vals, grads_vals)],
       label='Optimizer States (fp32)', color='#f59e0b', alpha=0.8)

# A100 capacity line
ax.axhline(y=A100_MEMORY_GB, color='#f87171', linestyle='--', linewidth=2, label=f'A100 capacity ({A100_MEMORY_GB} GB)')

# Add total labels on top of each bar
for i, (w, g, o) in enumerate(zip(weights_vals, grads_vals, opt_vals)):
    total = w + g + o
    color = '#f87171' if total > A100_MEMORY_GB else '#34d399'
    ax.text(i, total + 8, f'{total:.0f} GB', ha='center', va='bottom',
            fontsize=10, fontweight='bold', color=color)

ax.set_ylabel('Per-GPU Memory (GB)', fontsize=12)
ax.set_title(f'LLaMA 70B Training Memory: Per-GPU on {NUM_GPUS} A100 GPUs',
             fontsize=13, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(stages, fontsize=10)
ax.legend(loc='upper right', fontsize=9)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.set_ylim(0, max(weights_vals[0] + grads_vals[0] + opt_vals[0], A100_MEMORY_GB) * 1.3)
plt.tight_layout()
plt.show()

print('\nKey observations:')
print('1. Without ZeRO, data parallelism requires the FULL model on each GPU -- no memory savings.')
print('2. ZeRO Stage 1 shards optimizer states (the orange bar) -- the biggest component.')
print('   Optimizer states drop from 560 GB to 70 GB per GPU. But total is still ~350 GB.')
print('3. ZeRO Stage 3 shards everything -- total per GPU drops to ~105 GB.')
print('   Still above 80 GB! Even ZeRO Stage 3 alone does not make 70B fit on 8 A100s.')
print('4. Sharding over 8 GPUs cannot go below 840 GB / 8 = 105 GB each, so 70B needs more GPUs.')
print('   At that scale, frontier training combines ZeRO with tensor/pipeline parallelism.')

**What just happened:** The training memory breakdown makes the problem visceral. A 70B model needs 840 GB for training, 10.5x what a single A100 provides. Data parallelism does not help because it *replicates* the full model. ZeRO Stage 1 targets the biggest component (optimizer states at 560 GB) and shards them across GPUs, but even with 8-way sharding, per-GPU memory is ~350 GB (140 GB weights + 140 GB gradients + 70 GB optimizer states). Even ZeRO Stage 3 (sharding everything) gives ~105 GB per GPU, still above the 80 GB limit, and that is before activations, which this calculator does not count.

This is why frontier model training combines ZeRO with tensor parallelism (split weight matrices within layers) and pipeline parallelism (split layers across GPUs). No single strategy is sufficient. The choice of strategy is determined by which bottleneck dominates—memory, compute, or communication.
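
As a rough follow-up to these numbers, the sketch below asks how many 80 GB GPUs full sharding would need before the per-GPU share fits. It assumes the same 12 bytes/param accounting as the calculator and ignores activations, buffers, and fragmentation, so treat the result as a lower bound. `min_gpus_fully_sharded` is a hypothetical helper, not part of the notebook.

```python
import math

BYTES_PER_PARAM = 12       # 2 weights + 2 gradients + 8 Adam states (this notebook's accounting)
GPU_MEMORY_BYTES = 80e9    # one 80 GB A100


def min_gpus_fully_sharded(num_params: float,
                           gpu_memory_bytes: float = GPU_MEMORY_BYTES) -> int:
    """Smallest GPU count at which an even split of all training state fits per GPU."""
    total_bytes = num_params * BYTES_PER_PARAM
    return math.ceil(total_bytes / gpu_memory_bytes)


for name, num_params in [('GPT-2 (124M)', 124e6), ('LLaMA 7B', 7e9), ('LLaMA 70B', 70e9)]:
    gpus = min_gpus_fully_sharded(num_params)
    per_gpu_gb = num_params * BYTES_PER_PARAM / gpus / 1e9
    print(f'{name:<14} needs >= {gpus:>2} GPUs ({per_gpu_gb:.1f} GB each when fully sharded)')
```

In practice the count would be higher, since activations and communication buffers also need room.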

---

## Exercise 2: Speculative Decoding Simulator (Supported)

Remember the `generate()` loop from building nanoGPT: each token requires one forward pass through the entire model. For a 70B model, that means one forward pass per token, sequentially. Speculative decoding modifies this loop: a small, fast *draft model* generates K candidate tokens, then the large *target model* verifies all K in a single forward pass.

The key insight from the lesson: the speedup does not come from the small model being fast. It comes from the large model verifying multiple tokens *in parallel* in one forward pass, instead of generating them one at a time.

In this exercise, you'll simulate speculative decoding with:
- A draft model that generates candidate tokens (with some probability of matching the target)
- A target model that verifies all candidates in one pass
- Acceptance following the "reject from first disagreement" rule

You'll measure acceptance rate at different draft lengths K and see the speedup.
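
To make the "reject from first disagreement" rule concrete before simulating it with probabilities, here is a minimal sketch on explicit token IDs. It assumes greedy, exact-match verification (real speculative sampling uses a probabilistic accept/reject test instead), and `verify_draft` and the token values are made up for illustration.

```python
def verify_draft(draft_tokens: list[int], target_tokens: list[int]) -> list[int]:
    """Return the tokens kept after one draft-then-verify round.

    target_tokens holds the target model's own choice at each draft position,
    plus one extra entry for the position after the last draft token.
    """
    assert len(target_tokens) == len(draft_tokens) + 1
    accepted = 0
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break  # first disagreement: everything after it is discarded
        accepted += 1
    # Keep the accepted prefix, then the target's own token at the next position
    return draft_tokens[:accepted] + [target_tokens[accepted]]


# Position 2 disagrees (2 vs 4), so the matching draft token at position 3 is dropped too
print(verify_draft([5, 9, 2, 7], [5, 9, 4, 7, 1]))  # -> [5, 9, 4]
# All drafts match, so the round yields K + 1 tokens
print(verify_draft([5, 9], [5, 9, 3]))              # -> [5, 9, 3]
```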

**Before running, predict:**
- If the draft model has a 70% per-token match rate and drafts K=5 tokens, what is the probability that all 5 are accepted? (Hint: independent events)
- As K increases from 1 to 8, does the average number of accepted tokens per round increase, decrease, or plateau?
- If the draft model takes 10ms per token and the target model takes 50ms per forward pass (regardless of how many tokens it verifies), what draft length K maximizes tokens per second?

In [None]:
# --- Speculative Decoding Simulator ---
#
# We simulate the draft-then-verify loop WITHOUT actual language models.
# Instead, we model:
#   - Draft model: generates tokens that match the target with probability `match_prob`
#   - Target model: verifies all K draft tokens in one "forward pass"
#   - Acceptance rule: accept from the start until the first disagreement
#
# This captures the core mechanism: the large model verifies in parallel,
# and the speedup depends on the acceptance rate.

def simulate_speculative_round(
    k: int,
    match_prob: float,
) -> int:
    """Simulate one round of speculative decoding.

    The draft model produces K tokens. Each matches the target model's
    output with probability `match_prob` (independently).

    Returns the number of accepted tokens (0 to K).
    Acceptance rule: accept from start until first disagreement.
    Even if token 4 matches but token 3 doesn't, we stop at token 3.
    """
    # TODO: Simulate K token matches. For each draft position (0 to K-1):
    #   1. Generate a random number with np.random.random()
    #   2. If it's < match_prob, the token matches (accept and continue)
    #   3. If it's >= match_prob, the token doesn't match (reject and stop)
    #   Return the count of consecutive accepted tokens from the start.
    #
    # Note: even when we reject at position i, the target model has already
    # computed its own token for position i in the same forward pass.
    # So we always get at least 1 new token per round (the resampled one).
    # But for simplicity, we return the count of ACCEPTED draft tokens here.

    pass  # Replace with your implementation


def simulate_speculative_decoding(
    total_tokens: int,
    k: int,
    match_prob: float,
    draft_time_ms: float,
    target_time_ms: float,
    num_trials: int = 200,
) -> dict:
    """Simulate speculative decoding over many rounds and measure performance.

    Args:
        total_tokens: total tokens to generate
        k: number of draft tokens per round
        match_prob: probability each draft token matches target
        draft_time_ms: time for draft model to generate one token
        target_time_ms: time for target model's forward pass (constant, regardless of K)
        num_trials: number of simulation runs to average over

    Returns:
        dict with average metrics across trials
    """
    all_rounds = []
    all_accepted = []
    all_times = []

    for _ in range(num_trials):
        tokens_generated = 0
        num_rounds = 0
        total_time = 0.0
        accepted_counts = []

        while tokens_generated < total_tokens:
            # TODO: Simulate one speculative decoding round:
            #   1. Draft phase: draft model generates K tokens
            #      Time cost: k * draft_time_ms
            #   2. Verify phase: target model verifies all K in one pass
            #      Time cost: target_time_ms (constant!)
            #   3. Count accepted tokens from simulate_speculative_round(k, match_prob)
            #   4. Total new tokens this round = accepted + 1
            #      (the +1 is the target model's own token at the rejection point)
            #   5. Update tokens_generated, num_rounds, total_time, accepted_counts

            pass  # Replace with your implementation

        all_rounds.append(num_rounds)
        all_accepted.append(np.mean(accepted_counts))
        all_times.append(total_time)

    # Baseline: target model generating tokens one at a time (no speculation)
    baseline_time = total_tokens * target_time_ms

    return {
        'k': k,
        'match_prob': match_prob,
        'avg_rounds': np.mean(all_rounds),
        'avg_accepted_per_round': np.mean(all_accepted),
        'avg_tokens_per_round': np.mean(all_accepted) + 1,
        'avg_time_ms': np.mean(all_times),
        'baseline_time_ms': baseline_time,
        'speedup': baseline_time / np.mean(all_times),
    }


# Quick test
test_accepted = simulate_speculative_round(5, 0.7)
print(f'Test round (K=5, match_prob=0.7): {test_accepted} tokens accepted')
print('(Run a few times -- you should see values from 0 to 5, averaging around 2)')

<details>
<summary>Solution</summary>

The key insight is the acceptance rule: we accept consecutive tokens from the start until the first disagreement. The accepted count therefore follows a geometric distribution truncated at K: each draft token matches independently with the same probability, and the first mismatch ends the run.

```python
def simulate_speculative_round(k: int, match_prob: float) -> int:
    accepted = 0
    for _ in range(k):
        if np.random.random() < match_prob:
            accepted += 1
        else:
            break
    return accepted
```

For the simulation loop:

```python
while tokens_generated < total_tokens:
    # Draft phase: K tokens, each takes draft_time_ms
    draft_cost = k * draft_time_ms
    # Verify phase: one forward pass, constant cost
    verify_cost = target_time_ms
    # Count accepted
    accepted = simulate_speculative_round(k, match_prob)
    # Total new tokens = accepted drafts + 1 resampled token
    new_tokens = accepted + 1

    tokens_generated += new_tokens
    num_rounds += 1
    total_time += draft_cost + verify_cost
    accepted_counts.append(accepted)
```

The `+1` is crucial: even when the target model rejects at position i, it has already computed its own token for that position during the same forward pass. So every round produces at least 1 token. If all K drafts are accepted, the round produces K+1 tokens, because the same pass also yields the target model's token for the position after the last draft. The simulation counts both cases the same way, as `accepted + 1`.
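
An optional alternative, sketched under the same independence assumption: the round can also be simulated without a Python loop by drawing all K matches at once and finding the first mismatch. `simulate_round_vectorized` is a hypothetical name, and because it always draws K random numbers from its own generator, it will not reproduce the exact sequence of the loop version under the same seed.

```python
import numpy as np


def simulate_round_vectorized(k: int, match_prob: float,
                              rng: np.random.Generator) -> int:
    """Accepted draft tokens in one round: length of the leading run of matches."""
    matches = rng.random(k) < match_prob
    if matches.all():
        return k
    # argmin on a boolean array returns the index of the first False,
    # which equals the number of matches before the first disagreement
    return int(np.argmin(matches))


rng = np.random.default_rng(0)
counts = [simulate_round_vectorized(5, 0.7, rng) for _ in range(20_000)]
print(f'Mean accepted (K=5, p=0.7): {np.mean(counts):.2f}')
```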

</details>

### Helper: Working Speculative Decoding Simulator

**Run the cell below** to get working implementations for the analysis. If your implementation above works, this redefines the same functions.

In [None]:
# --- Reference implementation ---

def simulate_speculative_round(k: int, match_prob: float) -> int:
    """Simulate one speculative decoding round. Returns accepted count."""
    accepted = 0
    for _ in range(k):
        if np.random.random() < match_prob:
            accepted += 1
        else:
            break
    return accepted


def simulate_speculative_decoding(
    total_tokens: int,
    k: int,
    match_prob: float,
    draft_time_ms: float,
    target_time_ms: float,
    num_trials: int = 200,
) -> dict:
    """Simulate speculative decoding and measure performance."""
    all_rounds = []
    all_accepted = []
    all_times = []

    for _ in range(num_trials):
        tokens_generated = 0
        num_rounds = 0
        total_time = 0.0
        accepted_counts = []

        while tokens_generated < total_tokens:
            draft_cost = k * draft_time_ms
            verify_cost = target_time_ms
            accepted = simulate_speculative_round(k, match_prob)
            new_tokens = accepted + 1

            tokens_generated += new_tokens
            num_rounds += 1
            total_time += draft_cost + verify_cost
            accepted_counts.append(accepted)

        all_rounds.append(num_rounds)
        all_accepted.append(np.mean(accepted_counts))
        all_times.append(total_time)

    baseline_time = total_tokens * target_time_ms

    return {
        'k': k,
        'match_prob': match_prob,
        'avg_rounds': np.mean(all_rounds),
        'avg_accepted_per_round': np.mean(all_accepted),
        'avg_tokens_per_round': np.mean(all_accepted) + 1,
        'avg_time_ms': np.mean(all_times),
        'baseline_time_ms': baseline_time,
        'speedup': baseline_time / np.mean(all_times),
    }


np.random.seed(42)
print('Reference speculative decoding simulator loaded.')

In [None]:
# --- Sweep draft length K from 1 to 8 ---
#
# Timing model (illustrative numbers for a 70B target with a 7B draft):
#   Draft model: 10 ms per token (small, fast)
#   Target model: 50 ms per forward pass (large, but constant regardless of K)

np.random.seed(42)

TOTAL_TOKENS = 100
MATCH_PROB = 0.70
DRAFT_TIME_MS = 10.0
TARGET_TIME_MS = 50.0

k_values = list(range(1, 9))
results = []

print(f'=== Speculative Decoding: {TOTAL_TOKENS} tokens, match_prob={MATCH_PROB} ===')
print(f'Draft model: {DRAFT_TIME_MS} ms/token, Target model: {TARGET_TIME_MS} ms/pass')
print()
print(f'{"K":>3} {"Acc/Round":>10} {"Tok/Round":>10} {"Rounds":>8} '
      f'{"Time (ms)":>10} {"Baseline":>10} {"Speedup":>8}')
print('-' * 65)

for k in k_values:
    r = simulate_speculative_decoding(
        total_tokens=TOTAL_TOKENS,
        k=k,
        match_prob=MATCH_PROB,
        draft_time_ms=DRAFT_TIME_MS,
        target_time_ms=TARGET_TIME_MS,
    )
    results.append(r)
    print(f'{k:>3} {r["avg_accepted_per_round"]:>10.2f} {r["avg_tokens_per_round"]:>10.2f} '
          f'{r["avg_rounds"]:>8.1f} {r["avg_time_ms"]:>10.0f} '
          f'{r["baseline_time_ms"]:>10.0f} {r["speedup"]:>7.2f}x')

In [None]:
# --- Visualize: speedup and tokens per round vs K ---

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

speedups = [r['speedup'] for r in results]
toks_per_round = [r['avg_tokens_per_round'] for r in results]
accepted_per_round = [r['avg_accepted_per_round'] for r in results]

# Left plot: speedup vs K
ax1.plot(k_values, speedups, color='#34d399', linewidth=2, marker='o', markersize=6)
best_k = k_values[np.argmax(speedups)]
best_speedup = max(speedups)
ax1.axvline(x=best_k, color='#f59e0b', linestyle='--', alpha=0.7,
            label=f'Optimal K={best_k} ({best_speedup:.2f}x)')
ax1.axhline(y=1.0, color='#f87171', linestyle='--', alpha=0.5, label='No speculation (1x)')
ax1.set_xlabel('Draft Length K', fontsize=12)
ax1.set_ylabel('Speedup vs Baseline', fontsize=12)
ax1.set_title('Speculative Decoding Speedup', fontsize=13, fontweight='bold')
ax1.legend(fontsize=9)
ax1.set_xticks(k_values)
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)

# Right plot: accepted tokens per round vs K
ax2.bar(k_values, accepted_per_round, color='#60a5fa', alpha=0.7, label='Accepted drafts')
ax2.bar(k_values, [1] * len(k_values), bottom=accepted_per_round,
        color='#a78bfa', alpha=0.7, label='Resampled token (+1)')
# Theoretical max line
ax2.plot(k_values, [k + 1 for k in k_values], color='#f87171', linestyle='--',
         alpha=0.5, label='Theoretical max (all accepted)')
ax2.set_xlabel('Draft Length K', fontsize=12)
ax2.set_ylabel('Tokens per Round', fontsize=12)
ax2.set_title('Tokens Generated per Speculative Round', fontsize=13, fontweight='bold')
ax2.legend(fontsize=9)
ax2.set_xticks(k_values)
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)

plt.tight_layout()
plt.show()

print(f'\nOptimal draft length: K={best_k} giving {best_speedup:.2f}x speedup.')
print(f'\nWhy does speedup eventually decrease with larger K?')
print(f'  - More draft tokens = more draft time (K * {DRAFT_TIME_MS} ms)')
print(f'  - But acceptance drops geometrically: P(all K match) = {MATCH_PROB}^K')
print(f'  - At K=8: P(all match) = {MATCH_PROB**8:.3f} -- unlikely to get all 8.')
print(f'  - The draft overhead grows linearly but accepted tokens plateau.')
print(f'  - Optimal K balances draft cost against expected accepted tokens.')

In [None]:
# --- How match probability affects optimal K ---
#
# A better draft model (higher match_prob) allows longer drafts.
# A worse draft model should use shorter drafts.

np.random.seed(42)

match_probs = [0.5, 0.6, 0.7, 0.8, 0.9]
colors = ['#f87171', '#fb923c', '#f59e0b', '#34d399', '#60a5fa']

fig, ax = plt.subplots(figsize=(10, 5))

for match_prob, color in zip(match_probs, colors):
    speedups_mp = []
    for k in k_values:
        r = simulate_speculative_decoding(
            total_tokens=TOTAL_TOKENS, k=k, match_prob=match_prob,
            draft_time_ms=DRAFT_TIME_MS, target_time_ms=TARGET_TIME_MS,
        )
        speedups_mp.append(r['speedup'])
    ax.plot(k_values, speedups_mp, color=color, linewidth=2, marker='o',
            markersize=5, label=f'match_prob={match_prob}')

ax.axhline(y=1.0, color='#334155', linestyle='--', alpha=0.5)
ax.set_xlabel('Draft Length K', fontsize=12)
ax.set_ylabel('Speedup vs Baseline', fontsize=12)
ax.set_title('Speculative Decoding: Match Probability vs Optimal K',
             fontsize=13, fontweight='bold')
ax.legend(fontsize=9)
ax.set_xticks(k_values)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.tight_layout()
plt.show()

print('\nKey insight: the optimal K depends on the draft model quality.')
print('  - High match probability (good draft model) -> longer drafts are worthwhile')
print('  - Low match probability (weak draft model) -> short drafts or no speculation')
print('  - At or below ~50% match rate, the gain is small (about 1.2x at best here): draft overhead dominates')
print('\nThe speed comes from the large model verifying in PARALLEL, not from')
print('the small model being fast. A perfect draft model (match_prob=1.0) would')
print('give speedup = (K + 1) * target_time / (K * draft_time + target_time).')

**What just happened:** Speculative decoding trades draft model compute for parallel verification by the target model. The sweet spot depends on two factors: (1) how well the draft model approximates the target (match probability), and (2) the time ratio between draft and target forward passes.

At 70% match probability, the optimal draft length is around K=3, with neighboring values close behind. Longer drafts waste time because acceptance drops geometrically: `P(all K match) = 0.7^K`. At K=8, the probability of accepting all 8 is only about 5.8%, so most of those draft tokens are wasted effort.
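
The simulated sweep can be cross-checked in closed form. Under the simulator's assumptions (independent matches with probability p, and `accepted + 1` tokens per round), the expected number of accepted drafts is p + p^2 + ... + p^K. The sketch below uses hypothetical helper names and the same 10 ms / 50 ms timings. Expect the simulated speedups to come out slightly lower, because the last simulated round can overshoot the token target.

```python
def expected_accepted(k: int, p: float) -> float:
    """E[accepted drafts]: position i is accepted only if all i drafts so far matched."""
    return sum(p ** i for i in range(1, k + 1))


def analytic_speedup(k: int, p: float,
                     draft_ms: float = 10.0, target_ms: float = 50.0) -> float:
    """Expected tokens per round divided by round time, relative to one token per target pass."""
    tokens_per_round = expected_accepted(k, p) + 1
    round_ms = k * draft_ms + target_ms
    return tokens_per_round * target_ms / round_ms


for k in range(1, 9):
    print(f'K={k}: E[accepted]={expected_accepted(k, 0.7):.2f}, '
          f'speedup={analytic_speedup(k, 0.7):.2f}x')

best_k = max(range(1, 9), key=lambda k: analytic_speedup(k, 0.7))
print(f'Analytic optimum at p=0.7: K={best_k}')
```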

The critical misconception: the speedup does NOT come from the draft model being fast. It comes from the target model verifying K tokens in a single forward pass instead of generating them one at a time. The target model's forward pass costs roughly the same whether it processes 1 token or 5, because decoding at small batch sizes is memory-bandwidth-bound: the time goes to reading the weights from GPU memory, not to the matmul arithmetic.
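
A rough, roofline-style estimate of why the verification pass costs about the same for 1 token or 5. The bandwidth and peak-FLOP figures are assumed round numbers for an A100-class GPU, the model is treated as if it sat on a single device, and KV-cache reads and attention cost are ignored, so this is an order-of-magnitude sketch only.

```python
NUM_PARAMS = 70e9
BYTES_PER_WEIGHT = 2      # bf16
HBM_BANDWIDTH = 2.0e12    # bytes/s (assumed)
PEAK_FLOPS = 312e12       # FLOP/s (assumed dense bf16 peak)


def forward_pass_estimate_ms(num_tokens: int) -> dict:
    """Estimate one forward pass as the slower of weight reading and arithmetic."""
    # Every weight is read once per pass, however many tokens are processed
    weight_read_ms = NUM_PARAMS * BYTES_PER_WEIGHT / HBM_BANDWIDTH * 1e3
    # About 2 FLOPs (multiply + add) per parameter per token
    compute_ms = 2 * NUM_PARAMS * num_tokens / PEAK_FLOPS * 1e3
    return {
        'weight_read_ms': weight_read_ms,
        'compute_ms': compute_ms,
        'pass_ms': max(weight_read_ms, compute_ms),
    }


for n in (1, 5, 64, 512):
    est = forward_pass_estimate_ms(n)
    print(f'{n:>4} tokens: read {est["weight_read_ms"]:.0f} ms, '
          f'compute {est["compute_ms"]:.1f} ms, pass ~{est["pass_ms"]:.0f} ms')
```

Under these assumptions the pass time stays flat until the token count reaches the low hundreds, which is why verifying a handful of draft tokens adds almost nothing.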

---

## Exercise 3: Continuous Batching Simulator (Supported)

Static batching during inference works like batched training: start N requests together, run until ALL finish. The problem: requests have different lengths. Short requests finish early but their batch slots sit idle, consuming memory and producing nothing, until the longest request completes.

Continuous batching fixes this: when a request completes, its slot is immediately filled with the next request from the queue. Like a restaurant waitlist—as a table opens, the next party is seated immediately.
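
A small worked example of the idle-slot problem, using made-up request lengths. In a static batch every slot is held for as long as the longest request runs, so utilization is the useful token-steps divided by `max(lengths) * len(lengths)`. `static_batch_utilization` is a hypothetical helper for illustration.

```python
def static_batch_utilization(lengths: list[int]) -> float:
    """Fraction of slot-steps doing useful work in one static batch."""
    useful_steps = sum(lengths)                  # one step per generated token
    occupied_steps = max(lengths) * len(lengths) # every slot is held until the longest finishes
    return useful_steps / occupied_steps


batch = [20, 60, 100, 200]
util = static_batch_utilization(batch)
print(f'Useful slot-steps:   {sum(batch)}')
print(f'Occupied slot-steps: {max(batch) * len(batch)}')
print(f'Utilization: {util:.1%} (idle: {1 - util:.1%})')
```

Continuous batching attacks exactly that idle share by handing a finished slot to the next queued request.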

In this exercise, you'll simulate both strategies on a queue of 50 requests with realistic length distributions and compare GPU utilization.

**Before running, predict:**
- With static batching (batch_size=8) and requests varying from 10 to 300 tokens, what happens to GPU utilization as short requests finish?
- Will continuous batching have higher or lower average latency per request than static batching?
- If the longest request in a static batch takes 300 tokens and the shortest takes 10, what fraction of the shortest request's slot time is wasted?

In [None]:
# --- Continuous Batching Simulator ---
#
# We simulate an inference server processing a queue of requests.
# Each request has a target length (number of tokens to generate).
# One "step" generates one token for every active slot in the batch.

np.random.seed(42)

# Generate 50 requests with realistic length distribution
NUM_REQUESTS = 50
BATCH_SIZE = 8

# Lengths: mean 80, std 60, clipped to [10, 300]
request_lengths = np.clip(
    np.random.normal(loc=80, scale=60, size=NUM_REQUESTS).astype(int),
    10, 300
)

print(f'Generated {NUM_REQUESTS} requests.')
print(f'  Length range: {request_lengths.min()} to {request_lengths.max()} tokens')
print(f'  Mean length: {request_lengths.mean():.0f} tokens')
print(f'  Batch size: {BATCH_SIZE}')
print()

# Show first 10 request lengths
print(f'First 10 request lengths: {request_lengths[:10].tolist()}')

In [None]:
# --- Static Batching ---
#
# Process requests in fixed batches. Each batch waits for ALL requests
# to complete before the next batch starts.

def simulate_static_batching(
    request_lengths: np.ndarray,
    batch_size: int,
) -> dict:
    """Simulate static batching.

    Returns:
        dict with total_steps, utilization history, and per-request latencies.
    """
    n = len(request_lengths)
    utilization_history = []  # fraction of batch slots doing useful work at each step
    request_latencies = []    # total time each request waits (from batch start to batch end)
    total_steps = 0

    idx = 0  # next request to assign
    while idx < n:
        # Fill the batch
        batch_end = min(idx + batch_size, n)
        batch_lengths = request_lengths[idx:batch_end]
        batch_actual_size = len(batch_lengths)
        max_len = batch_lengths.max()

        # Run the batch for max_len steps
        for step in range(max_len):
            # How many requests are still active at this step?
            active = int(np.sum(batch_lengths > step))
            utilization_history.append(active / batch_size)

        total_steps += max_len

        # All requests in this batch have latency = max_len
        # (they all wait until the longest finishes)
        request_latencies.extend([max_len] * batch_actual_size)

        idx = batch_end

    return {
        'total_steps': total_steps,
        'utilization_history': utilization_history,
        'request_latencies': request_latencies,
        'avg_utilization': np.mean(utilization_history),
        'avg_latency': np.mean(request_latencies),
    }


static_result = simulate_static_batching(request_lengths, BATCH_SIZE)
print(f'=== Static Batching ===')
print(f'Total steps: {static_result["total_steps"]}')
print(f'Average utilization: {static_result["avg_utilization"]:.1%}')
print(f'Average latency per request: {static_result["avg_latency"]:.0f} steps')

In [None]:
# --- Continuous Batching ---
#
# When a request completes, its slot is immediately filled from the queue.
# The batch stays full (or as full as possible) at all times.

def simulate_continuous_batching(
    request_lengths: np.ndarray,
    batch_size: int,
) -> dict:
    """Simulate continuous batching.

    Each slot independently tracks its request's remaining tokens.
    When a request completes, the slot is filled from the queue.

    Returns:
        dict with total_steps, utilization history, and per-request latencies.
    """
    n = len(request_lengths)
    utilization_history = []
    # Track when each request starts and finishes for latency calculation
    request_start_step = [0] * n
    request_end_step = [0] * n

    # Initialize slots: each slot has (request_index, remaining_tokens)
    # Fill initial batch from the queue
    queue_idx = 0  # next request to pull from queue
    slots = []     # list of (request_index, remaining_tokens)

    # TODO: Fill initial batch slots from the queue.
    # For each slot up to batch_size (or until queue is empty):
    #   1. Assign request queue_idx to this slot
    #   2. Record request_start_step[queue_idx] = 0 (starts at step 0)
    #   3. Append (queue_idx, request_lengths[queue_idx]) to slots
    #   4. Increment queue_idx

    pass  # Replace with your implementation

    step = 0
    while len(slots) > 0:
        # Record utilization for this step
        utilization_history.append(len(slots) / batch_size)

        # TODO: Process one step and handle completions:
        # 1. Decrement remaining tokens for each slot
        # 2. Find completed slots (remaining <= 0)
        # 3. For completed requests: record request_end_step
        # 4. Replace completed slots with new requests from the queue
        #    (record their request_start_step)
        # 5. Remove any empty slots (no more requests to fill)
        #
        # Hint: process the step by creating a new slots list:
        #   new_slots = []
        #   for (req_idx, remaining) in slots:
        #       remaining -= 1
        #       if remaining <= 0:   # completed
        #           request_end_step[req_idx] = step + 1
        #           if queue_idx < n:  # fill slot from queue
        #               request_start_step[queue_idx] = step + 1
        #               new_slots.append((queue_idx, request_lengths[queue_idx]))
        #               queue_idx += 1
        #       else:  # still running
        #           new_slots.append((req_idx, remaining))
        #   slots = new_slots

        pass  # Replace with your implementation

        step += 1

    request_latencies = [
        request_end_step[i] - request_start_step[i]
        for i in range(n)
    ]

    return {
        'total_steps': step,
        'utilization_history': utilization_history,
        'request_latencies': request_latencies,
        'avg_utilization': np.mean(utilization_history),
        'avg_latency': np.mean(request_latencies),
    }


continuous_result = simulate_continuous_batching(request_lengths, BATCH_SIZE)
print(f'=== Continuous Batching ===')
print(f'Total steps: {continuous_result["total_steps"]}')
print(f'Average utilization: {continuous_result["avg_utilization"]:.1%}')
print(f'Average latency per request: {continuous_result["avg_latency"]:.0f} steps')
print()
print(f'=== Comparison ===')
print(f'{"Metric":<25} {"Static":>10} {"Continuous":>12}')
print('-' * 50)
print(f'{"Total steps":<25} {static_result["total_steps"]:>10} {continuous_result["total_steps"]:>12}')
print(f'{"Avg utilization":<25} {static_result["avg_utilization"]:>9.1%} {continuous_result["avg_utilization"]:>11.1%}')
print(f'{"Avg latency (steps)":<25} {static_result["avg_latency"]:>10.0f} {continuous_result["avg_latency"]:>12.0f}')

<details>
<summary>Solution</summary>

The key insight is that continuous batching treats each batch slot independently. When a request completes, its slot is immediately available for the next request from the queue—like a restaurant seating the next party as soon as a table opens.

```python
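# Fill initial batch (replaces the first `pass`, before the while loop)
for _ in range(min(batch_size, n)):
    request_start_step[queue_idx] = 0
    slots.append((queue_idx, request_lengths[queue_idx]))
    queue_idx += 1

# Process one step (replaces the second `pass`, inside the while loop)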
new_slots = []
for (req_idx, remaining) in slots:
    remaining -= 1
    if remaining <= 0:
        request_end_step[req_idx] = step + 1
        if queue_idx < n:
            request_start_step[queue_idx] = step + 1
            new_slots.append((queue_idx, request_lengths[queue_idx]))
            queue_idx += 1
    else:
        new_slots.append((req_idx, remaining))
slots = new_slots
```

The improvement is in *slot utilization*, not batch size. Static batching has 8 slots but many sit idle after their request finishes. Continuous batching keeps slots filled, so GPU utilization stays near 100% until the queue is nearly empty.

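Note that, as measured in this simulation, continuous batching leaves each request's service time unchanged. Latency is counted from the step a request enters a slot to the step it finishes, so it always equals the request's length, and time spent waiting in the queue is not included. The improvement is in *throughput*: the total number of steps needed to serve all requests drops because GPU cycles are not wasted on empty slots.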
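
To make the difference between service time and end-to-end latency concrete, here is a minimal sketch on a hypothetical six-request workload with two slots. It assumes every request is queued at step 0 and served in order, and it uses a small heap-based scheduler, `continuous_schedule`, which is an illustrative helper and not part of the exercise code.

```python
import heapq

def continuous_schedule(lengths, batch_size):
    """Return (start_step, end_step) per request under continuous batching.

    Assumes all requests are queued at step 0 and served in order.
    """
    # Each heap entry is the step at which one slot becomes free.
    free_at = [0] * min(batch_size, len(lengths))
    heapq.heapify(free_at)
    schedule = []
    for length in lengths:
        start = heapq.heappop(free_at)   # earliest free slot
        end = start + length             # one token per step
        heapq.heappush(free_at, end)
        schedule.append((start, end))
    return schedule

toy_lengths = [2, 8, 3, 5, 1, 6]   # hypothetical request lengths
schedule = continuous_schedule(toy_lengths, batch_size=2)

for i, (start, end) in enumerate(schedule):
    print(f'request {i}: waited {start:>2}, served in {end - start}, '
          f'done at step {end}')
```

Service time always equals the request length, while the wait grows for requests deeper in the queue. A metric that included queue wait would therefore tell a different story from the per-slot latency reported above.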

</details>

### Helper: Working Continuous Batching Simulator

**Run the cell below** to load a working implementation for the visualizations that follow. If your implementation above already works, this cell simply redefines the same function.

In [None]:
# --- Reference implementation ---

def simulate_continuous_batching(
    request_lengths: np.ndarray,
    batch_size: int,
) -> dict:
    """Simulate continuous batching."""
    n = len(request_lengths)
    utilization_history = []
    request_start_step = [0] * n
    request_end_step = [0] * n

    queue_idx = 0
    slots = []

    for _ in range(min(batch_size, n)):
        request_start_step[queue_idx] = 0
        slots.append((queue_idx, int(request_lengths[queue_idx])))
        queue_idx += 1

    step = 0
    while len(slots) > 0:
        utilization_history.append(len(slots) / batch_size)

        new_slots = []
        for (req_idx, remaining) in slots:
            remaining -= 1
            if remaining <= 0:
                request_end_step[req_idx] = step + 1
                if queue_idx < n:
                    request_start_step[queue_idx] = step + 1
                    new_slots.append((queue_idx, int(request_lengths[queue_idx])))
                    queue_idx += 1
            else:
                new_slots.append((req_idx, remaining))
        slots = new_slots
        step += 1

    request_latencies = [
        request_end_step[i] - request_start_step[i]
        for i in range(n)
    ]

    return {
        'total_steps': step,
        'utilization_history': utilization_history,
        'request_latencies': request_latencies,
        'avg_utilization': np.mean(utilization_history),
        'avg_latency': np.mean(request_latencies),
    }


# Recompute with reference implementation
static_result = simulate_static_batching(request_lengths, BATCH_SIZE)
continuous_result = simulate_continuous_batching(request_lengths, BATCH_SIZE)

print(f'Reference continuous batching loaded.')
print(f'Static total steps: {static_result["total_steps"]}, '
      f'utilization: {static_result["avg_utilization"]:.1%}')
print(f'Continuous total steps: {continuous_result["total_steps"]}, '
      f'utilization: {continuous_result["avg_utilization"]:.1%}')

In [None]:
# --- Visualize: utilization over time for both strategies ---

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 8), sharex=False)

# Top: Static batching utilization
static_util = static_result['utilization_history']
ax1.fill_between(range(len(static_util)), static_util, alpha=0.3, color='#f87171')
ax1.plot(range(len(static_util)), static_util, color='#f87171', linewidth=1)
ax1.axhline(y=static_result['avg_utilization'], color='#f87171', linestyle='--',
            alpha=0.7, label=f'Average: {static_result["avg_utilization"]:.1%}')
ax1.set_ylabel('GPU Utilization', fontsize=11)
ax1.set_title('Static Batching: Utilization Over Time', fontsize=12, fontweight='bold')
ax1.set_ylim(0, 1.1)
ax1.legend(fontsize=9)
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)

# Bottom: Continuous batching utilization
cont_util = continuous_result['utilization_history']
ax2.fill_between(range(len(cont_util)), cont_util, alpha=0.3, color='#34d399')
ax2.plot(range(len(cont_util)), cont_util, color='#34d399', linewidth=1)
ax2.axhline(y=continuous_result['avg_utilization'], color='#34d399', linestyle='--',
            alpha=0.7, label=f'Average: {continuous_result["avg_utilization"]:.1%}')
ax2.set_xlabel('Step', fontsize=11)
ax2.set_ylabel('GPU Utilization', fontsize=11)
ax2.set_title('Continuous Batching: Utilization Over Time', fontsize=12, fontweight='bold')
ax2.set_ylim(0, 1.1)
ax2.legend(fontsize=9)
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)

plt.tight_layout()
plt.show()

print('Static batching: utilization drops in a sawtooth pattern as short requests finish.')
print('Each batch starts at 100% and decays. GPU is doing wasted work on empty slots.')
print()
print('Continuous batching: utilization stays near 100% for most of the run.')
print('It only drops at the very end when the queue empties and remaining requests trickle out.')
print()
print(f'Total steps -- Static: {static_result["total_steps"]}, '
      f'Continuous: {continuous_result["total_steps"]}')
print(f'Throughput improvement: '
      f'{static_result["total_steps"] / continuous_result["total_steps"]:.2f}x fewer steps '
      f'to serve all {NUM_REQUESTS} requests.')

In [None]:
# --- Wasted compute analysis ---

# In static batching, compute useful tokens and wasted tokens per batch
total_useful_tokens = int(request_lengths.sum())

# Static: total tokens generated = total steps * batch_size
# (each step generates for batch_size slots, whether useful or not)
static_total_tokens = sum(
    int(request_lengths[i:min(i + BATCH_SIZE, len(request_lengths))].max()) * BATCH_SIZE
    for i in range(0, len(request_lengths), BATCH_SIZE)
)
static_wasted = static_total_tokens - total_useful_tokens

# Continuous: capacity = total steps * batch_size, the same accounting as static.
# Every active slot-step produces a useful token, so the only waste is the
# idle slots at the tail end, once the queue has emptied.
# (Summing utilization * batch_size would count active slots only, which
# equals total_useful_tokens and would always report zero waste.)
continuous_total_work = continuous_result['total_steps'] * BATCH_SIZE
continuous_wasted = continuous_total_work - total_useful_tokens

print(f'=== Compute Waste Analysis ===')
print(f'Total useful tokens to generate: {total_useful_tokens}')
print()
print(f'Static batching:')
print(f'  Total GPU-token-steps: {static_total_tokens}')
print(f'  Wasted: {static_wasted} ({static_wasted / static_total_tokens:.1%})')
print(f'  These are forward passes that produce nothing -- the GPU computes')
print(f'  attention, runs the FFN, generates a token... and throws it away.')
print()
print(f'Continuous batching:')
print(f'  Total GPU-token-steps: {continuous_total_work:.0f}')
print(f'  Wasted: {continuous_wasted:.0f} ({continuous_wasted / continuous_total_work:.1%})')
print(f'  Minimal waste -- only at the tail end when the queue empties.')

**What just happened:** Static batching wastes compute because short requests finish early, but their slots stay tied to the batch and sit idle until the longest request finishes. The utilization graph shows the resulting sawtooth pattern: each batch starts at 100% and drops as requests complete, and the next batch cannot start until the current one is done.

Continuous batching eliminates this waste by treating each slot independently. When a request completes at step 20, its slot is immediately filled with the next request from the queue. The GPU stays busy producing useful tokens instead of empty ones.

The improvement is in *slot utilization*, not batch size. Both strategies use the same batch size (8). The difference is what happens to completed slots. Static batching leaves them empty. Continuous batching fills them.
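
The accounting behind this comparison can be sketched on a toy workload. The example below is a simplified stand-in for the simulators above, using hypothetical request lengths (all at least 1) and a batch size of 2. It compares slot-steps of capacity (steps times batch size) against useful tokens for both strategies.

```python
def static_total_steps(lengths, batch_size):
    """Each batch runs until its longest request finishes."""
    return sum(
        max(lengths[i:i + batch_size])
        for i in range(0, len(lengths), batch_size)
    )

def continuous_total_steps(lengths, batch_size):
    """Step-by-step loop: a finished slot is refilled from the queue at once."""
    queue = list(lengths)
    slots = [queue.pop(0) for _ in range(min(batch_size, len(queue)))]
    steps = 0
    while slots:
        steps += 1
        next_slots = []
        for remaining in slots:
            remaining -= 1
            if remaining > 0:
                next_slots.append(remaining)      # still generating
            elif queue:
                next_slots.append(queue.pop(0))   # refill the freed slot
        slots = next_slots
    return steps

toy_lengths = [2, 8, 3, 5, 1, 6]   # hypothetical request lengths
batch = 2
useful = sum(toy_lengths)

for name, fn in [('static', static_total_steps),
                 ('continuous', continuous_total_steps)]:
    steps = fn(toy_lengths, batch)
    capacity = steps * batch
    print(f'{name:<10} steps={steps:>2}  capacity={capacity:>2}  '
          f'useful={useful}  idle={capacity - useful} '
          f'({(capacity - useful) / capacity:.1%})')
```

On this toy input, static batching needs 19 steps and continuous batching needs 15 for the same 25 useful tokens. The idle slot-steps that remain under continuous batching all fall at the tail, after the queue has emptied.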

---

## Exercise 4: Parallelism Strategy Advisor (Independent)

Throughout this lesson, you learned three parallelism strategies—data, tensor, and pipeline—plus ZeRO optimizer sharding. Each addresses a different bottleneck. The choice depends on whether the model fits on one GPU, how many GPUs you have, and how much memory is available.

**Your task:** Build a function `recommend_parallelism(config)` that takes a model configuration and hardware setup, then recommends a parallelism strategy.

The function should:
1. Compute total training memory (12 bytes/param for mixed-precision Adam)
2. Determine if data parallelism alone works (full model fits on one GPU)
3. Check if ZeRO stages reduce per-GPU memory enough
4. If not, recommend tensor and/or pipeline parallelism
5. Output a clear recommendation with reasoning

Test on these five configurations:
- **(a)** GPT-2 (124M params), 1 GPU, 80 GB
- **(b)** GPT-2 (124M params), 4 GPUs, 80 GB each
- **(c)** LLaMA 7B, 4 GPUs, 80 GB each
- **(d)** LLaMA 70B, 8 GPUs, 80 GB each
- **(e)** Hypothetical 175B model, 64 GPUs, 80 GB each

**Decision logic to implement:**
- Training memory per param = 12 bytes (2 bytes bf16 weights + 2 bytes bf16 grads + 8 bytes fp32 Adam states)
- If total training memory fits on one GPU: data parallelism (replicate model, split data)
- If not but ZeRO Stage 3 on all GPUs makes it fit: ZeRO + data parallelism
- If not: need model parallelism. Compute minimum tensor/pipeline parallelism degree.
  - Tensor parallelism degree = minimum GPUs to split each layer so it fits in memory
  - Pipeline parallelism degree = remaining GPUs for layer splitting
  - Data parallelism degree = any remaining GPUs

**No skeleton is provided.** Design the function and output format yourself.

**Before running, predict:**
- GPT-2 (124M) needs ~1.5 GB for training. What strategy for 1 GPU? For 4 GPUs?
- LLaMA 70B needs 840 GB. With 8 GPUs and ZeRO Stage 3, per-GPU is ~105 GB. Does it fit?
- The 175B model needs ~2.1 TB. How many GPUs would ZeRO Stage 3 alone require?

In [None]:
# Your parallelism strategy advisor here.
#
# Suggested approach:
#
# 1. Define a config (dataclass or dict) with:
#    name, num_params, num_gpus, gpu_memory_gb
#
# 2. Write recommend_parallelism(config) that:
#    a. Computes total training memory = num_params * 12 bytes
#    b. Checks if single-GPU works (total <= gpu_memory)
#    c. Checks ZeRO stages (shard optimizer, gradients, weights)
#    d. If ZeRO alone doesn't suffice, compute model parallelism needs
#    e. Returns a recommendation with reasoning
#
# 3. Test on configurations (a) through (e)
#
# 4. Print clear, formatted output showing:
#    - Model size and total training memory
#    - Hardware (GPUs, memory per GPU)
#    - Recommendation: which strategies and why
#    - Per-GPU memory after applying the recommendation



<details>
<summary>Solution</summary>

The key insight is that parallelism strategy selection is a decision tree driven by one question: does the model's training memory fit on the available GPUs? Each strategy addresses a different bottleneck: data parallelism scales throughput when the model fits, ZeRO reduces per-GPU memory by sharding, and tensor/pipeline parallelism split the model itself when even sharding is not enough.

```python
@dataclass
class HardwareConfig:
    name: str
    num_params: float
    num_gpus: int
    gpu_memory_gb: float


def recommend_parallelism(cfg: HardwareConfig) -> None:
    """Recommend parallelism strategy and print analysis."""
    BYTES_PER_PARAM = 12  # bf16 weights + bf16 grads + fp32 Adam
    BYTES_PER_PARAM_WEIGHTS = 2
    BYTES_PER_PARAM_GRADS = 2
    BYTES_PER_PARAM_OPT = 8

    total_memory = cfg.num_params * BYTES_PER_PARAM
    gpu_memory_bytes = cfg.gpu_memory_gb * 1e9

    print(f'=== {cfg.name} ===')
    print(f'  Parameters: {format_number(cfg.num_params)}')
    print(f'  Training memory: {format_bytes(total_memory)}')
    print(f'  Hardware: {cfg.num_gpus} GPU(s) x {cfg.gpu_memory_gb} GB')
    print(f'  Total GPU memory: {format_bytes(cfg.num_gpus * gpu_memory_bytes)}')
    print()

    # Check 1: Does the model fit on a single GPU?
    if total_memory <= gpu_memory_bytes:
        print(f'  Strategy: DATA PARALLELISM')
        print(f'  Reason: Full model ({format_bytes(total_memory)}) fits on one GPU '
              f'({cfg.gpu_memory_gb} GB).')
        if cfg.num_gpus > 1:
            print(f'  Benefit: {cfg.num_gpus}x throughput from {cfg.num_gpus}-way data parallelism.')
        print(f'  Per-GPU memory: {format_bytes(total_memory)}')
        print()
        return

    # Check 2: Does ZeRO help enough?
    for stage, stage_name in [(1, 'Stage 1'), (2, 'Stage 2'), (3, 'Stage 3')]:
        w = cfg.num_params * BYTES_PER_PARAM_WEIGHTS
        g = cfg.num_params * BYTES_PER_PARAM_GRADS
        o = cfg.num_params * BYTES_PER_PARAM_OPT
        if stage >= 1:
            o = o / cfg.num_gpus
        if stage >= 2:
            g = g / cfg.num_gpus
        if stage >= 3:
            w = w / cfg.num_gpus
        per_gpu = w + g + o

        if per_gpu <= gpu_memory_bytes:
            print(f'  Strategy: ZeRO {stage_name} + DATA PARALLELISM')
            print(f'  Reason: Full model does not fit on one GPU '
                  f'({format_bytes(total_memory)} > {cfg.gpu_memory_gb} GB).')
            print(f'  ZeRO {stage_name} shards '
                  f'{"optimizer states" if stage == 1 else "optimizer + gradients" if stage == 2 else "everything"} '
                  f'across {cfg.num_gpus} GPUs.')
            print(f'  Per-GPU memory: {format_bytes(per_gpu)} '
                  f'(weights: {format_bytes(w)}, grads: {format_bytes(g)}, opt: {format_bytes(o)})')
            print()
            return

    # Check 3: Need model parallelism
    # Even ZeRO Stage 3 is not enough, so try splitting the model itself.
    # Memory model for a combined layout with tensor-parallel degree T,
    # pipeline-parallel degree P, and data-parallel degree D (T * P * D = num_gpus):
    #   - each GPU stores 1/(T*P) of the weights and gradients
    #   - optimizer states are additionally sharded across the D data-parallel
    #     replicas, so each GPU stores 1/(T*P*D) of them
    # Caveat: this estimate ignores activations, so per-GPU memory here is never
    # lower than ZeRO Stage 3's total / num_gpus. Under this simplified model the
    # search below cannot succeed where Stage 3 failed, and it falls through to
    # INSUFFICIENT RESOURCES.

    best_config = None
    for tp in [1, 2, 4, 8]:
        for pp in [1, 2, 4, 8, 16]:
            if tp * pp > cfg.num_gpus:
                continue
            dp = cfg.num_gpus // (tp * pp)
            if dp < 1:
                continue
            if tp * pp * dp != cfg.num_gpus:
                continue

            # Memory per GPU with this layout, with optimizer states sharded
            # across the data-parallel replicas (ZeRO Stage 1 style):
            model_shard = 1.0 / (tp * pp)  # fraction of model per GPU
            w = cfg.num_params * model_shard * BYTES_PER_PARAM_WEIGHTS
            g = cfg.num_params * model_shard * BYTES_PER_PARAM_GRADS
            o = cfg.num_params * model_shard * BYTES_PER_PARAM_OPT / dp
            per_gpu = w + g + o

            if per_gpu <= gpu_memory_bytes:
                if best_config is None or tp + pp < best_config[0] + best_config[1]:
                    best_config = (tp, pp, dp, per_gpu, w, g, o)

    if best_config is None:
        print(f'  Strategy: INSUFFICIENT RESOURCES')
        print(f'  Reason: {cfg.num_gpus} GPUs with {cfg.gpu_memory_gb} GB each '
              f'cannot accommodate {format_bytes(total_memory)} training memory.')
        print(f'  Need more GPUs or GPUs with more memory.')
        print()
        return

    tp, pp, dp, per_gpu, w, g, o = best_config
    strategies = []
    if tp > 1:
        strategies.append(f'Tensor Parallelism (degree {tp})')
    if pp > 1:
        strategies.append(f'Pipeline Parallelism (degree {pp})')
    if dp > 1:
        strategies.append(f'Data Parallelism (degree {dp})')
    strategies.append('ZeRO Stage 1 (optimizer sharding)')

    print(f'  Strategy: {" + ".join(strategies)}')
    print(f'  Reason: Model does not fit even with ZeRO alone. '
          f'Need to split the model across GPUs.')
    print(f'  Layout: {tp}x tensor x {pp}x pipeline x {dp}x data = '
          f'{tp * pp * dp} GPUs')
    print(f'  Per-GPU memory: {format_bytes(per_gpu)} '
          f'(weights: {format_bytes(w)}, grads: {format_bytes(g)}, opt: {format_bytes(o)})')
    print()


# --- Test on five configurations ---

configs = [
    HardwareConfig('(a) GPT-2, 1 GPU', 124e6, 1, 80),
    HardwareConfig('(b) GPT-2, 4 GPUs', 124e6, 4, 80),
    HardwareConfig('(c) LLaMA 7B, 4 GPUs', 7e9, 4, 80),
    HardwareConfig('(d) LLaMA 70B, 8 GPUs', 70e9, 8, 80),
    HardwareConfig('(e) 175B Model, 64 GPUs', 175e9, 64, 80),
]

for cfg in configs:
    recommend_parallelism(cfg)

print('=== Summary ===')
print('The choice of parallelism strategy is determined by the bottleneck:')
print('  - Model fits on 1 GPU? -> Data parallelism (scale throughput)')
print('  - Model too large? -> ZeRO shards optimizer/gradient memory')
print('  - Still too large? -> Tensor + pipeline parallelism split the model')
print('  - Communication overhead shapes every decision.')
```

**Design choices:**
- The advisor follows a clear priority: data parallelism first (simplest), then ZeRO (reduces memory without splitting the model), then model parallelism (most complex, highest communication overhead).
- For model parallelism, we search over tensor and pipeline degrees and keep the layout with the smallest `tp + pp` that fits in memory. The search does not model communication cost. In practice, tensor parallelism is usually kept within a node, where communication overhead is lowest, while pipeline parallelism is used across more GPUs because stages communicate less frequently.
- The `T * P * D = num_gpus` constraint ensures all GPUs are utilized.
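
The prediction questions can be checked with a few lines of arithmetic. This is a minimal sketch under the notebook's 12 bytes/param assumption, ignoring activations. The names `zero_per_gpu_gb` and `min_gpus_zero3` are illustrative helpers, not functions defined elsewhere in the notebook.

```python
import math

BYTES_WEIGHTS, BYTES_GRADS, BYTES_OPT = 2, 2, 8   # 12 bytes/param in total

def zero_per_gpu_gb(num_params, num_gpus, stage):
    """Per-GPU training memory in GB for ZeRO stage 0 (none) through 3."""
    w = num_params * BYTES_WEIGHTS
    g = num_params * BYTES_GRADS
    o = num_params * BYTES_OPT
    if stage >= 1:
        o /= num_gpus   # Stage 1: shard optimizer states
    if stage >= 2:
        g /= num_gpus   # Stage 2: also shard gradients
    if stage >= 3:
        w /= num_gpus   # Stage 3: also shard weights
    return (w + g + o) / 1e9

def min_gpus_zero3(num_params, gpu_memory_gb):
    """Smallest GPU count at which ZeRO Stage 3 alone fits (activations ignored)."""
    total_gb = num_params * (BYTES_WEIGHTS + BYTES_GRADS + BYTES_OPT) / 1e9
    return math.ceil(total_gb / gpu_memory_gb)

for stage in range(4):
    print(f'70B on 8 GPUs, ZeRO stage {stage}: '
          f'{zero_per_gpu_gb(70e9, 8, stage):.1f} GB per GPU')

print('Min GPUs for 70B with Stage 3: ', min_gpus_zero3(70e9, 80))
print('Min GPUs for 175B with Stage 3:', min_gpus_zero3(175e9, 80))
```

For the 70B model on 8 GPUs, per-GPU memory falls from 840 GB with no sharding to 105 GB at Stage 3, which is still above 80 GB. Under the same assumptions, Stage 3 alone needs at least 11 GPUs for 70B and 27 GPUs for 175B.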

</details>

---

## Key Takeaways

1. **Training memory is dominated by optimizer states (2/3 of total).** A 70B model needs ~840 GB for mixed-precision Adam training. ZeRO targets the biggest cost first by sharding optimizer states across GPUs, but even full sharding (Stage 3) may not be enough for the largest models.

2. **Three parallelism strategies address three different bottlenecks.** Data parallelism splits data (throughput scaling when the model fits). Tensor parallelism splits layers (when individual layers are too large). Pipeline parallelism splits blocks across GPUs (when the model has too many layers). Frontier models combine all three.

3. **Speculative decoding turns sequential generation into parallel verification.** The speedup comes from the target model verifying K draft tokens in one forward pass, not from the draft model being fast. Optimal draft length K depends on the match probability—diminishing returns at larger K because acceptance drops geometrically.

4. **Continuous batching eliminates wasted compute from static batching.** By filling completed slots immediately from a request queue, GPU utilization stays near 100% instead of decaying as short requests finish. The improvement is in slot utilization, not batch size.

5. **Communication overhead is the constraint that shapes every parallelism decision.** Moving data between GPUs is orders of magnitude slower than computing on it. Every strategy in this lesson is a different answer to the same question: how do you distribute work across devices when communication is expensive?
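
The diminishing returns in takeaway 3 can be made concrete with a short calculation. This is a sketch under a simplifying assumption that is not necessarily the one used in the speculative decoding exercise: each draft token is accepted independently with the same probability `p`, and the target model always contributes one token of its own per verification pass. Under that assumption, the expected number of tokens per target forward pass is `(1 - p**(K + 1)) / (1 - p)`. The function name and the value of `p` are hypothetical.

```python
def expected_tokens_per_pass(p: float, k: int) -> float:
    """Expected tokens produced per target-model forward pass.

    Assumes i.i.d. acceptance with probability p for each of k draft tokens,
    plus one token from the target model itself (a correction on the first
    rejection, or a bonus token if all k drafts are accepted).
    """
    if p >= 1.0:
        return float(k + 1)
    return (1 - p ** (k + 1)) / (1 - p)

p = 0.8   # hypothetical match probability
previous = expected_tokens_per_pass(p, 0)
for k in range(1, 9):
    current = expected_tokens_per_pass(p, k)
    print(f'K={k}: {current:.2f} tokens/pass '
          f'(+{current - previous:.2f} from the extra draft token)')
    previous = current
```

Each additional draft token adds `p**K` expected tokens, so the gain shrinks geometrically while the cost of drafting grows linearly with K.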