# Model Training and Distributed Systems

This comprehensive notebook covers advanced topics in model training, distributed systems, and resource estimation:

1. **Training Token Optimization** - Chinchilla optimal scaling laws
2. **Mixture of Experts (MoE)** - Parameter calculations and token requirements
3. **GPU Requirements & Training Time** - FLOPs calculations and resource estimation
4. **FLOPs of Transformer Block** - Detailed forward/backward pass calculations
5. **Model Sharding Strategies** - TP, FSDP, PP, and more
6. **Communication Patterns** - Collective operations and bandwidth analysis
7. **Parallelism Trade-offs** - When to use each strategy
8. **Inference** - Latency, throughput, and KV cache considerations
9. **RL Training** - Async vs sync training comparison

This notebook is designed for interview preparation and practical understanding of large-scale ML systems.

## 1. Import Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from typing import Tuple, Dict, List
from dataclasses import dataclass

# Set style for better visualizations
plt.style.use('default')
plt.rcParams['figure.figsize'] = (12, 6)

# Set random seed
np.random.seed(42)

print("NumPy version:", np.__version__)

## 2. Training Token Optimization: Chinchilla Scaling Laws

The Chinchilla paper (Hoffmann et al., 2022) found that for compute-optimal training:

### Key Finding:
**For every parameter in the model, you should train on approximately 20 tokens.**

### Formula:
```
Optimal Tokens = 20 × Model Parameters
```

### Why This Matters:
- Most models before Chinchilla were **undertrained**
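- GPT-3 (175B params) was trained on 300B tokens (~1.7 tokens per parameter)
- Chinchilla (70B params), trained on 1.4T tokens (20 tokens per parameter), outperformed larger, undertrained models such as Gopher (280B) and GPT-3 (175B)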
- Modern models (LLaMA, Gemini) follow or exceed this ratio

### Trade-offs:
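- **Inference Cost**: Serving cost grows with model size, so for a given quality target a smaller model trained on more tokens is cheaper to deploy
- **Training Cost**: Training compute grows with both model size and token count (≈ 6 × N × D FLOPs)
- **Real-world**: Many recent models train far beyond 20 tokens/parameter, spending extra training compute to get a smaller, cheaper-to-serve model (e.g., LLaMA-1-7B on 1T tokens, ~140 tokens/parameter; all LLaMA-2 sizes on 2T tokens)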
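Combining the 20 tokens/parameter rule with the C ≈ 6 × N × D compute approximation gives a closed-form split of a fixed compute budget: C = 6 × N × (20 × N) = 120 × N², so N = √(C / 120) and D = 20 × N. A minimal sketch under those two assumptions (the helper name `chinchilla_allocation` is hypothetical; the paper's own fitted exponents give a similar but not identical split):

```python
import math


def chinchilla_allocation(compute_flops: float) -> tuple[float, float]:
    """Split a FLOP budget into (params, tokens), assuming C = 6*N*D and D = 20*N."""
    # C = 6 * N * (20 * N) = 120 * N^2  ->  N = sqrt(C / 120)
    n_params = math.sqrt(compute_flops / 120)
    n_tokens = 20 * n_params
    return n_params, n_tokens


# Chinchilla's own budget: 6 * 70e9 * 1.4e12 = 5.88e23 FLOPs
n_opt, d_opt = chinchilla_allocation(5.88e23)
print(f"{n_opt / 1e9:.0f}B params, {d_opt / 1e12:.1f}T tokens")  # 70B params, 1.4T tokens
```

Working backwards from the budget recovers the 70B / 1.4T configuration, which is a quick sanity check on the rule of thumb.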

In [None]:
def chinchilla_optimal_tokens(model_params: float) -> float:
    """
    Calculate optimal number of training tokens according to Chinchilla.
    
    Args:
        model_params: Number of model parameters (in billions)
    
    Returns:
        Optimal training tokens (in billions)
    """
    return 20 * model_params

def compute_budget(model_params: float, tokens: float) -> float:
    """
    Estimate compute budget in FLOPs.
    
    Args:
        model_params: Model parameters in billions
        tokens: Training tokens in billions
    
    Returns:
        Total FLOPs (in 10^21 = ZettaFLOPs)
    """
    # Rule of thumb: 6 FLOPs per parameter per token (FWD + BWD)
    flops_per_token = 6 * model_params * 1e9
    total_flops = flops_per_token * tokens * 1e9
    return total_flops / 1e21  # Return in ZettaFLOPs

# Example calculations
models = {
    "GPT-3": 175,
    "Chinchilla": 70,
    "LLaMA-2-7B": 7,
    "LLaMA-2-13B": 13,
    "LLaMA-2-70B": 70,
}

print("Model Training Token Requirements (Chinchilla Optimal):")
print("=" * 70)
for model_name, params in models.items():
    optimal_tokens = chinchilla_optimal_tokens(params)
    compute = compute_budget(params, optimal_tokens)
    print(f"{model_name:20s}: {params:6.0f}B params -> {optimal_tokens:7.0f}B tokens ({compute:6.2f} ZettaFLOPs)")

# GPT-3 actual training
gpt3_actual_tokens = 300
gpt3_ratio = gpt3_actual_tokens / models["GPT-3"]
print(f"\nGPT-3 was trained on {gpt3_actual_tokens}B tokens = {gpt3_ratio:.1f}x ratio (undertrained!)")

# LLaMA-2 actual training
llama2_actual_tokens = 2000
llama2_70b_ratio = llama2_actual_tokens / models["LLaMA-2-70B"]
print(f"LLaMA-2-70B was trained on {llama2_actual_tokens}B tokens = {llama2_70b_ratio:.1f}x ratio (past Chinchilla-optimal, for a cheaper-to-serve model)")

In [None]:
# Visualization: Chinchilla Scaling Law
param_sizes = np.logspace(0, 3, 50)  # 1B to 1000B parameters
optimal_tokens = chinchilla_optimal_tokens(param_sizes)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Tokens vs Parameters
ax1.loglog(param_sizes, optimal_tokens, 'b-', linewidth=2, label='Chinchilla Optimal (20x)')
ax1.loglog(param_sizes, 1.7 * param_sizes, 'r--', linewidth=2, alpha=0.5, label='~1.7x ratio (GPT-3)')

# Mark specific models
for model_name, params in models.items():
    tokens = chinchilla_optimal_tokens(params)
    ax1.scatter(params, tokens, s=100, zorder=5)
    ax1.annotate(model_name, (params, tokens), xytext=(5, 5), 
                textcoords='offset points', fontsize=8)

ax1.set_xlabel('Model Parameters (Billions)', fontsize=12)
ax1.set_ylabel('Training Tokens (Billions)', fontsize=12)
ax1.set_title('Chinchilla Optimal: 20 Tokens per Parameter', fontsize=14)
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot 2: Compute Budget
compute_budgets = [compute_budget(p, chinchilla_optimal_tokens(p)) for p in param_sizes]
ax2.loglog(param_sizes, compute_budgets, 'g-', linewidth=2)
ax2.set_xlabel('Model Parameters (Billions)', fontsize=12)
ax2.set_ylabel('Compute Budget (ZettaFLOPs)', fontsize=12)
ax2.set_title('Total Compute Required for Chinchilla Optimal', fontsize=14)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 3. Mixture of Experts (MoE) Parameters

MoE models have different parameter counting:

### Total vs Active Parameters:
- **Total Parameters**: All parameters in the model (including all experts)
- **Active Parameters**: Parameters used to process a single token (shared parameters plus the top-k routed experts)

### Formula for MoE Layer:
```
Total Parameters = Shared_Params + (Num_Experts × Expert_Size)
Active Parameters = Shared_Params + (Top_K × Expert_Size)
```

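### Example: GPT-4 (rumored, unconfirmed):
- Total: ~1.7T parameters (8 experts × 220B each)
- Active: ~220B per token (top-1 routing)
- Training tokens: a common heuristic bases the token budget on **active parameters**, since per-token training compute scales with them (MoE scaling behaviour differs from dense models, so treat this as a rough guide)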

### Optimal Tokens for MoE:
```
Optimal Tokens = 20 × Active_Parameters
```
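To see where a shared/expert split comes from, here is a sketch that counts Mixtral-8x7B's parameters from its layer dimensions. The dimensions are assumptions taken from the publicly released configuration (32 layers, d_model 4096, SwiGLU experts with d_ff 14336, 32 query heads and 8 KV heads of size 128, vocab 32000, 8 experts, top-2, untied output head); norms and router weights are counted, and there are no biases.

```python
def mixtral_param_counts(
    n_layers: int = 32,
    d_model: int = 4096,
    d_ff: int = 14336,
    n_heads: int = 32,
    n_kv_heads: int = 8,
    vocab: int = 32000,
    n_experts: int = 8,
    top_k: int = 2,
) -> tuple[int, int]:
    """Return (total, active) parameter counts for a Mixtral-style MoE decoder."""
    head_dim = d_model // n_heads
    # Shared per layer: Q and O are d x d; K and V are d x (n_kv_heads * head_dim) (GQA)
    attn = 2 * d_model * d_model + 2 * d_model * n_kv_heads * head_dim
    router = d_model * n_experts
    norms = 2 * d_model
    shared_per_layer = attn + router + norms
    # Each SwiGLU expert has three d_model x d_ff matrices (gate, up, down)
    expert = 3 * d_model * d_ff
    # Input embedding + untied output head + final norm
    outside_layers = 2 * vocab * d_model + d_model

    total = n_layers * (shared_per_layer + n_experts * expert) + outside_layers
    active = n_layers * (shared_per_layer + top_k * expert) + outside_layers
    return total, active


total, active = mixtral_param_counts()
print(f"total = {total / 1e9:.1f}B, active = {active / 1e9:.1f}B")  # total = 46.7B, active = 12.9B
```

Summed over layers, this gives ~1.6B non-expert parameters (attention, router, norms, embeddings) and ~5.6B per expert, which matches the published 46.7B total / 12.9B active figures.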

In [None]:
@dataclass
class MoEConfig:
    """Configuration for a Mixture of Experts model."""
    shared_params: float  # Billions
    num_experts: int
    expert_size: float  # Billions
    top_k: int
    
    @property
    def total_params(self) -> float:
        """Total parameters including all experts."""
        return self.shared_params + (self.num_experts * self.expert_size)
    
    @property
    def active_params(self) -> float:
        """Active parameters per forward pass."""
        return self.shared_params + (self.top_k * self.expert_size)
    
    def optimal_tokens(self) -> float:
        """Optimal training tokens based on active parameters."""
        return chinchilla_optimal_tokens(self.active_params)
    
    def compute_efficiency(self) -> float:
        """Ratio of active to total parameters."""
        return self.active_params / self.total_params

# Example MoE configurations
moe_models = {
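    # Rumored, unconfirmed: 8 experts x ~220B, no separate shared block assumed
    "GPT-4 (rumored)": MoEConfig(shared_params=0, num_experts=8, expert_size=220, top_k=1),
    # Approximate split of Mixtral's published 46.7B total / 12.9B active
    "Mixtral-8x7B": MoEConfig(shared_params=1.6, num_experts=8, expert_size=5.6, top_k=2),
    # Hypothetical config for illustration (the real Switch-C uses 2048 experts, ~1.6T total)
    "Hypothetical 128x10B": MoEConfig(shared_params=100, num_experts=128, expert_size=10, top_k=1),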
}

print("MoE Model Analysis:")
print("=" * 90)
print(f"{'Model':<25} {'Total':<12} {'Active':<12} {'Efficiency':<12} {'Optimal Tokens'}")
print("=" * 90)

for name, config in moe_models.items():
    print(f"{name:<25} {config.total_params:>8.0f}B     {config.active_params:>8.0f}B     "
          f"{config.compute_efficiency()*100:>8.1f}%     {config.optimal_tokens():>8.0f}B tokens")

print("\nKey Insight: MoE models have high total params but low active params,")
print("making them compute-efficient but memory-intensive.")

## 4. GPU Requirements & Training Time

### Formula for Training Time:

```
Total FLOPs = 6 × N × D
```
Where:
- N = Number of parameters
- D = Number of tokens
- 6 = FLOPs per param per token (≈2 for FWD, ≈4 for BWD)

```
Training Days (wall-clock) = Total_FLOPs / (Num_GPUs × Peak_FLOPS_per_GPU × MFU × 86400)
GPU-Days = Training Days × Num_GPUs
```

### GPU Specs (BF16/FP16, dense peak):
- **A100**: ~312 TFLOPS
- **H100** (SXM): ~990 TFLOPS (~3.2x A100 peak; measured speedups are usually smaller)
- **H200**: ~990 TFLOPS (same compute as H100, 141 GB of memory vs 80 GB)

### Typical Utilization:
- **Good**: 40-50% MFU (Model FLOPs Utilization)
- **Great**: 50-60% MFU
- **Excellent**: 60%+ MFU (very hard to achieve)
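MFU is usually measured backwards from observed throughput: achieved FLOP/s ≈ 6 × N × tokens_per_second, divided by the cluster's peak. A minimal sketch using the same 6N approximation; the helper names are hypothetical and the throughput figure in the example is made up for illustration, not a measured result.

```python
def model_flops_utilization(
    n_params: float,
    tokens_per_second: float,
    num_gpus: int,
    peak_tflops_per_gpu: float,
) -> float:
    """MFU = achieved model FLOP/s / peak FLOP/s, using 6 * N FLOPs per token."""
    achieved_flops = 6 * n_params * tokens_per_second
    peak_flops = num_gpus * peak_tflops_per_gpu * 1e12
    return achieved_flops / peak_flops


def wall_clock_days(total_flops: float, num_gpus: int, peak_tflops_per_gpu: float, mfu: float) -> float:
    """Wall-clock training days; multiply by num_gpus to get GPU-days."""
    return total_flops / (num_gpus * peak_tflops_per_gpu * 1e12 * mfu * 86400)


# Hypothetical: a 7B model processing 24,000 tokens/s in total on 8 A100s (312 TFLOPS peak)
example_mfu = model_flops_utilization(7e9, 24_000, 8, 312)
print(f"MFU = {example_mfu:.1%}")  # MFU = 40.4%

# 70B params x 2T tokens on 1024 H100s at 50% MFU
example_days = wall_clock_days(6 * 70e9 * 2e12, 1024, 990, 0.5)
print(f"{example_days:.1f} days, {example_days * 1024:,.0f} GPU-days")
```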

In [None]:
class GPUSpecs:
    """GPU specifications."""
    A100_TFLOPS = 312
    H100_TFLOPS = 990
    H200_TFLOPS = 990
    
    A100_MEMORY_GB = 80
    H100_MEMORY_GB = 80
    H200_MEMORY_GB = 141

def estimate_training_time(
    model_params_b: float,
    training_tokens_b: float,
    num_gpus: int,
    gpu_tflops: float,
    mfu: float = 0.5
) -> Dict[str, float]:
    """
    Estimate training time and cost.
    
    Args:
        model_params_b: Model parameters in billions
        training_tokens_b: Training tokens in billions
        num_gpus: Number of GPUs
        gpu_tflops: GPU TFLOPS (theoretical)
        mfu: Model FLOPs Utilization (0.0 to 1.0)
    
    Returns:
        Dictionary with training time estimates
    """
    # Calculate total FLOPs
    total_flops = 6 * model_params_b * 1e9 * training_tokens_b * 1e9
    
    # Calculate effective FLOPS per GPU
    effective_tflops_per_gpu = gpu_tflops * mfu
    total_tflops = num_gpus * effective_tflops_per_gpu
    
    # Calculate time
    seconds = total_flops / (total_tflops * 1e12)
    hours = seconds / 3600
    days = hours / 24
    
    # GPU hours
    gpu_hours = hours * num_gpus
    
    return {
        'total_flops': total_flops,
        'seconds': seconds,
        'hours': hours,
        'days': days,
        'gpu_hours': gpu_hours,
        'effective_tflops': total_tflops,
    }

# Example: Train LLaMA-2-70B
print("Training Time Estimation for LLaMA-2-70B:")
print("=" * 70)

model_size = 70  # billions
tokens = 2000  # billions (2T)

for gpu_type, tflops in [('A100', GPUSpecs.A100_TFLOPS), ('H100', GPUSpecs.H100_TFLOPS)]:
    for num_gpus in [512, 1024, 2048]:
        result = estimate_training_time(model_size, tokens, num_gpus, tflops, mfu=0.5)
        print(f"{gpu_type} x {num_gpus:4d}: {result['days']:6.1f} days "
              f"({result['gpu_hours']:8,.0f} GPU-hours, {result['effective_tflops']:7,.0f} TFLOPS)")
    print()

In [None]:
# Visualization: Training Time vs GPU Count
model_size = 70  # LLaMA-2-70B
tokens = 2000
gpu_counts = np.array([128, 256, 512, 1024, 2048, 4096])

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Training days vs GPU count
for gpu_type, tflops, color in [('A100', GPUSpecs.A100_TFLOPS, 'blue'), 
                                 ('H100', GPUSpecs.H100_TFLOPS, 'red')]:
    days = [estimate_training_time(model_size, tokens, n, tflops)['days'] for n in gpu_counts]
    ax1.plot(gpu_counts, days, marker='o', linewidth=2, label=gpu_type, color=color)

ax1.set_xlabel('Number of GPUs', fontsize=12)
ax1.set_ylabel('Training Days', fontsize=12)
ax1.set_title(f'Training Time: {model_size}B params, {tokens}B tokens (50% MFU)', fontsize=14)
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_xscale('log')
ax1.set_yscale('log')

# Plot 2: Effect of MFU
mfu_values = np.linspace(0.3, 0.7, 5)
num_gpus = 1024
days_by_mfu = [estimate_training_time(model_size, tokens, num_gpus, 
                                       GPUSpecs.H100_TFLOPS, mfu)['days'] 
               for mfu in mfu_values]

ax2.plot(mfu_values * 100, days_by_mfu, marker='s', linewidth=2, color='green')
ax2.set_xlabel('Model FLOPs Utilization (%)', fontsize=12)
ax2.set_ylabel('Training Days', fontsize=12)
ax2.set_title(f'Impact of MFU (H100 x {num_gpus})', fontsize=14)
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Key Takeaway: This model assumes perfect linear scaling, so doubling GPUs halves training time.")
print("In practice, communication overhead keeps scaling efficiency below 100%.")

## 5. FLOPs of Transformer Block

### Transformer Block Components:
1. **Self-Attention**: Q, K, V projections + attention computation + output projection
2. **Feed-Forward Network (FFN)**: Two linear layers
3. **Layer Norms**: Negligible FLOPs

### FLOPs Calculation:

#### Per token, for a sequence of length seq_len (forward pass, one layer):

**Attention FLOPs**:
```
Q, K, V projections: 3 × (2 × d_model × d_model) = 6 × d_model²
Attention scores (QKᵀ): 2 × seq_len × d_model
Weighted sum of values (AV): 2 × seq_len × d_model
Output projection: 2 × d_model × d_model = 2 × d_model²
Total Attention: 8 × d_model² + 4 × seq_len × d_model
(over a whole sequence, multiply by seq_len: the attention term becomes 4 × d_model × seq_len²)
```

**FFN FLOPs**:
```
First layer: 2 × d_model × d_ff = 2 × d_model × (4 × d_model) = 8 × d_model²
Second layer: 2 × d_ff × d_model = 2 × (4 × d_model) × d_model = 8 × d_model²
Total FFN: 16 × d_model²
```

**Total per block (ignoring the 4 × seq_len × d_model attention term, which is small when seq_len ≪ 6 × d_model)**:
```
FWD FLOPs ≈ 24 × d_model² per token
BWD FLOPs ≈ 2 × FWD = 48 × d_model²
Total (FWD + BWD) = 72 × d_model² per token
```
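These totals tie back to the 6N rule. Writing d for d_model and s for seq_len, a block with 4x FFN expansion has 12·d² weights (4·d² in the attention projections, 8·d² in the FFN), and every weight contributes one multiply-add (2 FLOPs) per token in the forward pass. A short derivation, assuming no biases and ignoring norms:

```latex
\begin{aligned}
N_{\text{layer}} &= \underbrace{4d^2}_{Q,K,V,O} + \underbrace{2 \cdot (d \cdot 4d)}_{\text{FFN}} = 12d^2 \\
\text{FWD}_{\text{layer}} &= 2\,N_{\text{layer}} = 24d^2 \\
\text{FWD+BWD}_{\text{layer}} &= 3 \cdot 24d^2 = 72d^2 = 6\,N_{\text{layer}} \\
\frac{\text{attention term}}{\text{matmul term}} &= \frac{4\,s\,d}{24\,d^2} = \frac{s}{6d}
\end{aligned}
```

For GPT-3 (s = 2048, d = 12288) the last ratio is 2048 / 73728 ≈ 2.8%, which is why dropping the attention term barely changes the total.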

### Full Model:
```
Total FLOPs per token = 72 × d_model² × num_layers
```

### Relationship:
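- **BWD ≈ 2 × FWD**: each forward matmul needs two matmuls of the same size in the backward pass (the gradient with respect to its input and the gradient with respect to its weights)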
- **Full step = FWD + BWD = 3 × FWD**
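The 2x factor can be checked on a single linear layer Y = X·W: the forward pass is one matmul, while the backward pass needs dX = dY·Wᵀ and dW = Xᵀ·dY, each of the same size. A small NumPy sketch that counts matmul FLOPs (2·m·k·n per product) for those three products; the shapes are arbitrary illustrative values.

```python
import numpy as np


def matmul_flops(a: np.ndarray, b: np.ndarray) -> int:
    """FLOPs for a @ b with a: (m, k), b: (k, n): one multiply + one add per term."""
    m, k = a.shape
    k2, n = b.shape
    assert k == k2, "inner dimensions must match"
    return 2 * m * k * n


rng = np.random.default_rng(0)
batch, d_in, d_out = 32, 256, 512          # tokens, input width, output width
X = rng.standard_normal((batch, d_in))
W = rng.standard_normal((d_in, d_out))

# Forward: Y = X @ W
Y = X @ W
fwd = matmul_flops(X, W)

# Backward, given an upstream gradient dY with the same shape as Y
dY = rng.standard_normal(Y.shape)
dX = dY @ W.T                              # gradient w.r.t. the layer input
dW = X.T @ dY                              # gradient w.r.t. the weights
bwd = matmul_flops(dY, W.T) + matmul_flops(X.T, dY)

print(f"BWD / FWD = {bwd / fwd:.1f}")      # BWD / FWD = 2.0
```

The very first layer can skip dX, which is one reason 2x is an approximation rather than an exact count.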

In [None]:
def calculate_transformer_flops(
    d_model: int,
    num_layers: int,
    seq_len: int,
    vocab_size: int = 50000,
    include_attention_compute: bool = False
) -> Dict[str, float]:
    """
    Calculate per-token FLOPs for a dense transformer (4x FFN expansion, no biases).
    
    Args:
        d_model: Model dimension
        num_layers: Number of transformer blocks
        seq_len: Sequence length (only affects the attention score/value term)
        vocab_size: Vocabulary size
        include_attention_compute: Include the QK^T and AV matmuls
    
    Returns:
        Dictionary with per-token FLOPs breakdown
    """
    # Per-layer FWD FLOPs per token
    # Attention projections (Q, K, V, O): four d_model x d_model matmuls
    attn_proj_flops = 8 * d_model * d_model
    
    # Attention scores (QK^T) and weighted sum of values (AV): 2 * seq_len * d_model each
    attn_compute_flops = 0
    if include_attention_compute:
        attn_compute_flops = 4 * seq_len * d_model
    
    # FFN with 4x expansion: two matmuls, 16 * d_model^2 in total
    d_ff = 4 * d_model
    ffn_flops = 2 * d_model * d_ff + 2 * d_ff * d_model
    
    # Per-layer FWD
    layer_fwd_flops = attn_proj_flops + attn_compute_flops + ffn_flops
    
    # Unembedding (output logits) matmul; the input embedding is a table lookup
    embed_flops = 2 * vocab_size * d_model
    
    # Total FWD
    total_fwd = num_layers * layer_fwd_flops + embed_flops
    
    # BWD ≈ 2 × FWD (gradients w.r.t. both activations and weights)
    total_bwd = 2 * total_fwd
    
    # Full step
    total_step = total_fwd + total_bwd
    
    return {
        'fwd_per_layer': layer_fwd_flops,
        'fwd_total': total_fwd,
        'bwd_total': total_bwd,
        'step_total': total_step,
        'fwd_to_bwd_ratio': total_bwd / total_fwd,
        'step_to_fwd_ratio': total_step / total_fwd,
    }

# Example: GPT-3 style model
gpt3_config = {
    'd_model': 12288,
    'num_layers': 96,
    'seq_len': 2048,
    'vocab_size': 50257
}

flops = calculate_transformer_flops(**gpt3_config)

print("FLOPs Analysis for GPT-3 Style Model:")
print("=" * 60)
print(f"Model: d_model={gpt3_config['d_model']}, layers={gpt3_config['num_layers']}")
print(f"\nFWD per layer: {flops['fwd_per_layer']:,} FLOPs")
print(f"FWD total:     {flops['fwd_total']:,} FLOPs")
print(f"BWD total:     {flops['bwd_total']:,} FLOPs")
print(f"Full step:     {flops['step_total']:,} FLOPs")
print(f"\nBWD / FWD ratio: {flops['fwd_to_bwd_ratio']:.1f}x")
print(f"Step / FWD ratio: {flops['step_to_fwd_ratio']:.1f}x")
print(f"\nThese ratios follow from the BWD ≈ 2×FWD assumption built into the function (Total = 3×FWD)")

In [None]:
# Verification: 6N per token rule
def verify_6n_rule(model_params_b: float, d_model: int, num_layers: int):
    """
    Verify the 6N FLOPs per token approximation.
    """
    # Detailed calculation
    flops = calculate_transformer_flops(d_model, num_layers, seq_len=1)
    detailed_flops = flops['step_total']
    
    # Simple rule: 6N
    simple_flops = 6 * model_params_b * 1e9
    
    print(f"Model: {model_params_b}B params, d_model={d_model}, layers={num_layers}")
    print(f"Detailed calculation: {detailed_flops:,.0f} FLOPs")
    print(f"Simple rule (6N):     {simple_flops:,.0f} FLOPs")
    print(f"Ratio: {detailed_flops / simple_flops:.2f}")
    print()

# Test with different models
print("Verification of 6N Rule:")
print("=" * 60)
verify_6n_rule(7, 4096, 32)  # LLaMA-7B
verify_6n_rule(70, 8192, 80)  # LLaMA-70B
verify_6n_rule(175, 12288, 96)  # GPT-3

print("The 6N rule is a good approximation for training FLOPs when seq_len << 6 * d_model.")

## 6. Model Sharding Strategies

Different parallelism strategies for training large models:

### 1. **Tensor Parallel (TP)**
- **What**: Split weight matrices across devices
- **How**: Each GPU holds a slice of weight matrix
- **Example**: For a 4096×4096 matrix with TP=4, each GPU holds 4096×1024
- **Communication**: AllReduce after attention and after the MLP in the forward pass (two more in the backward pass)
- **Use case**: Large models that don't fit on single GPU, need low latency
- **Typical values**: TP=2, 4, 8 (within a node)
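The Megatron-style MLP split can be simulated on one machine: the first weight matrix is split by columns (each rank computes its own slice of the hidden activations with no communication), the second by rows (each rank produces a partial output), and summing the partial outputs is exactly what the AllReduce does. A NumPy sketch, assuming a ReLU MLP (any elementwise activation works because it is applied per column slice) and a plain sum standing in for the collective:

```python
import numpy as np


def tp_mlp(x: np.ndarray, w1: np.ndarray, w2: np.ndarray, tp: int) -> np.ndarray:
    """Simulate a tensor-parallel MLP: column-parallel w1, row-parallel w2, then AllReduce."""
    w1_shards = np.split(w1, tp, axis=1)   # rank r holds d_model x (d_ff / tp)
    w2_shards = np.split(w2, tp, axis=0)   # rank r holds (d_ff / tp) x d_model
    partial_outputs = []
    for r in range(tp):
        h_r = np.maximum(x @ w1_shards[r], 0.0)     # local hidden slice, no communication
        partial_outputs.append(h_r @ w2_shards[r])  # partial sum over this rank's slice
    # The AllReduce: every rank ends up with the sum of all partial outputs
    return np.sum(partial_outputs, axis=0)


rng = np.random.default_rng(0)
d_model, d_ff, n_tokens, tp = 64, 256, 10, 4
x = rng.standard_normal((n_tokens, d_model))
w1 = rng.standard_normal((d_model, d_ff))
w2 = rng.standard_normal((d_ff, d_model))

reference = np.maximum(x @ w1, 0.0) @ w2
print(np.allclose(tp_mlp(x, w1, w2, tp), reference))  # True
```

Only one collective is needed per MLP because the nonlinearity sits between the column split and the row split, so no rank ever needs another rank's hidden slice.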

### 2. **Pipeline Parallel (PP)**
- **What**: Split model layers across devices
- **How**: Each GPU holds consecutive layers
- **Example**: 32-layer model with PP=4, each GPU holds 8 layers
- **Communication**: Point-to-point (P2P) activation passing
- **Use case**: Very deep models, minimize communication
- **Drawback**: Pipeline bubble (GPUs sit idle while the pipeline fills and drains); needs micro-batching to shrink it
- **Typical values**: PP=2, 4, 8
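For a GPipe or 1F1B-style schedule with p stages and m micro-batches of equal cost, each step spans m + p - 1 time slots, of which p - 1 are fill/drain idle time, so the bubble fraction is (p - 1) / (m + p - 1). A quick sketch under that equal-cost assumption (interleaved schedules shrink the bubble further and are not modeled):

```python
def pipeline_bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle fraction of a pipeline step, assuming equal-cost stages and micro-batches."""
    p, m = num_stages, num_microbatches
    return (p - 1) / (m + p - 1)


for m in (1, 4, 16, 64):
    print(f"PP=4, micro-batches={m:2d}: bubble = {pipeline_bubble_fraction(4, m):.1%}")
```

With PP=4 this gives 75% idle time at one micro-batch and under 5% at 64, which is why PP needs many micro-batches per step.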

### 3. **Fully Sharded Data Parallel (FSDP)**
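- **What**: Data parallelism that shards model states (optimizer states, gradients, and optionally parameters) across data-parallel ranks instead of replicating them
- **Variants** (DeepSpeed ZeRO stages; PyTorch FSDP's full sharding corresponds to ZeRO-3):
  - **ZeRO-1**: Shard optimizer states only
  - **ZeRO-2**: Shard optimizer states + gradients
  - **ZeRO-3**: Shard optimizer states + gradients + parameters
- **Communication**: 
  - ZeRO-1: AllReduce (or ReduceScatter) of gradients, then AllGather of updated parameters
  - ZeRO-2: ReduceScatter of gradients, then AllGather of updated parameters
  - ZeRO-3: AllGather of parameters in forward and again in backward, ReduceScatter of gradients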
- **Use case**: Training medium to large models with data parallelism
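The ZeRO per-parameter accounting makes the stages concrete: with mixed-precision Adam, each parameter costs 2 bytes (BF16/FP16 weights) + 2 bytes (gradients) + K = 12 bytes of FP32 optimizer state (master weights, momentum, variance), and each stage divides one more of those terms by the data-parallel size. A sketch of per-GPU model-state memory under that assumption (activations excluded; the helper name is hypothetical):

```python
def zero_model_state_gb(n_params: float, dp_size: int, stage: int, k_optimizer: int = 12) -> float:
    """Per-GPU memory (GB) for params + grads + optimizer states under ZeRO stages 0-3."""
    params = 2 * n_params          # BF16/FP16 weights, bytes
    grads = 2 * n_params           # BF16/FP16 gradients, bytes
    optim = k_optimizer * n_params # FP32 master weights + Adam momentum + variance, bytes
    if stage >= 1:
        optim /= dp_size
    if stage >= 2:
        grads /= dp_size
    if stage >= 3:
        params /= dp_size
    return (params + grads + optim) / 1e9


# 7.5B parameters on 64 data-parallel GPUs
for stage in range(4):
    print(f"ZeRO-{stage}: {zero_model_state_gb(7.5e9, 64, stage):6.1f} GB")
```

This prints 120.0, 31.4, 16.6 and 1.9 GB: sharding the optimizer state gives the biggest single drop, and only ZeRO-3 removes the replicated weights.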

### 4. **Gradient Checkpointing (Activation Checkpointing)**
- **What**: Trade compute for memory
- **How**: Don't save activations during forward, recompute during backward
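- **Memory saved**: Only checkpointed activations are kept; with N layers and a checkpoint every k layers you hold ~N/k checkpoints plus one k-layer segment during recompute, so k ≈ √N gives O(√N) activation memory
- **Compute cost**: Full recomputation adds about one extra forward pass, ~33% of a FWD+BWD step (since BWD ≈ 2 × FWD); selective recomputation costs less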
- **Use case**: When memory-bound, not compute-bound
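The √N rule of thumb falls out of a simple cost model: checkpoint every k of N layers, keep the N/k checkpointed layer inputs, and during backward rematerialize one k-layer segment at a time. A sketch, assuming every layer's activations cost one unit of memory and one unit of forward compute (real layers are not uniform, and the helper name is hypothetical):

```python
import math

# Assumption: each layer costs one memory unit and one unit of FWD compute
RECOMPUTE_OVERHEAD = 1 / 3  # one extra FWD on top of FWD + BWD (= 3 FWD units)


def checkpoint_memory_units(num_layers: int, segment: int) -> int:
    """Peak activation memory in layer-units: stored checkpoints + one recomputed segment."""
    return math.ceil(num_layers / segment) + segment


n = 64
best = min(range(1, n + 1), key=lambda k: checkpoint_memory_units(n, k))
print(f"no checkpointing: {n} units")
print(f"checkpoint every {best} layers: {checkpoint_memory_units(n, best)} units")
print(f"extra compute ≈ {RECOMPUTE_OVERHEAD:.0%}")
```

For 64 layers the minimum is at k = 8 = √64, cutting activation memory from 64 to 16 units for roughly one extra forward pass.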

### 5. **Expert Parallel (EP)**
- **What**: For MoE models, distribute experts across devices
- **How**: Each GPU holds subset of experts
- **Communication**: One AllToAll to dispatch tokens to their experts and another to return the expert outputs
- **Use case**: MoE models with many experts
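Expert-parallel routing can be simulated to see what the dispatch AllToAll carries: each rank picks top-k experts per token, and the (source rank, destination rank) token counts form the AllToAll send matrix (the combine step sends the same counts back). A NumPy sketch with random router logits, assuming experts are placed contiguously (`num_experts // num_ranks` consecutive experts per rank); all sizes are illustrative.

```python
import numpy as np


def alltoall_send_counts(router_logits: np.ndarray, top_k: int, num_ranks: int) -> np.ndarray:
    """Token copies each source rank sends to each destination rank in the dispatch AllToAll.

    router_logits: (num_ranks, tokens_per_rank, num_experts)
    """
    num_experts = router_logits.shape[-1]
    experts_per_rank = num_experts // num_ranks
    # Top-k expert indices per token (order within the top-k does not matter here)
    top_experts = np.argsort(router_logits, axis=-1)[..., -top_k:]
    dest_rank = top_experts // experts_per_rank  # which rank hosts each chosen expert
    counts = np.zeros((num_ranks, num_ranks), dtype=int)
    for src in range(num_ranks):
        counts[src] = np.bincount(dest_rank[src].ravel(), minlength=num_ranks)
    return counts


rng = np.random.default_rng(0)
num_ranks, tokens_per_rank, num_experts, top_k = 4, 1024, 8, 2
logits = rng.standard_normal((num_ranks, tokens_per_rank, num_experts))
counts = alltoall_send_counts(logits, top_k, num_ranks)
print(counts)
print("fraction leaving each rank:", 1 - np.diag(counts) / counts.sum(axis=1))
```

With uniformly random routing, roughly (N-1)/N (about 75% here) of the token copies leave their rank, so AllToAll traffic grows with both top_k and the EP degree.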

### 6. **Context Parallel (CP)**
- **What**: Split sequence length across devices
- **How**: Each GPU processes part of the sequence
- **Communication**: Exchange of K/V blocks between ranks, either via AllGather or via a ring of P2P sends (Ring Attention)
- **Use case**: Very long sequences (>100k tokens)
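The trick that lets context parallelism split attention is that softmax attention can be accumulated one K/V block at a time with a running max and running sum (the online-softmax idea that Ring Attention builds on). A single-process NumPy sketch, assuming non-causal attention; the loop over K/V blocks stands in for the ring steps in which each rank receives the next block from its neighbour.

```python
import numpy as np


def full_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v


def blockwise_attention(q: np.ndarray, k_blocks: list, v_blocks: list) -> np.ndarray:
    """Attention for a local query shard, consuming one K/V block per (simulated) ring step."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    running_max = np.full((q.shape[0], 1), -np.inf)
    running_sum = np.zeros((q.shape[0], 1))
    acc = np.zeros((q.shape[0], v_blocks[0].shape[-1]))
    for k_blk, v_blk in zip(k_blocks, v_blocks):
        scores = q @ k_blk.T * scale
        new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))
        correction = np.exp(running_max - new_max)  # rescale earlier partial results
        p = np.exp(scores - new_max)
        running_sum = running_sum * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v_blk
        running_max = new_max
    return acc / running_sum


rng = np.random.default_rng(0)
seq_len, d_head, cp = 64, 16, 4  # cp = number of context-parallel ranks
q = rng.standard_normal((seq_len, d_head))
k = rng.standard_normal((seq_len, d_head))
v = rng.standard_normal((seq_len, d_head))

# Rank 0's query shard attends over all K/V blocks as they arrive around the ring
q_shard = np.split(q, cp)[0]
out = blockwise_attention(q_shard, np.split(k, cp), np.split(v, cp))
print(np.allclose(out, full_attention(q_shard, k, v)))  # True
```

Each rank only ever holds its own query shard plus one K/V block at a time, which is what keeps activation memory proportional to seq_len / cp.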

In [None]:
# Memory requirements for different parallelism strategies

def calculate_memory_per_gpu(
    model_params_b: float,
    tp_size: int = 1,
    pp_size: int = 1,
    dp_size: int = 1,
    zero_stage: int = 0,
    dtype_bytes: int = 2,  # BF16/FP16
    gradient_checkpointing: bool = False
) -> Dict[str, float]:
    """
    Calculate memory per GPU for different parallelism strategies.
    
    Args:
        model_params_b: Model parameters in billions
        tp_size: Tensor parallel size
        pp_size: Pipeline parallel size
        dp_size: Data parallel size
        zero_stage: ZeRO stage (0, 1, 2, 3)
        dtype_bytes: Bytes per parameter (2 for BF16/FP16, 4 for FP32)
        gradient_checkpointing: Whether using gradient checkpointing
    
    Returns:
        Dictionary with memory breakdown in GB
    """
    total_params = model_params_b * 1e9
    
    # Parameters memory (sharded by TP and PP)
    params_per_gpu = total_params / (tp_size * pp_size)
    if zero_stage == 3:
        params_per_gpu = total_params / (tp_size * pp_size * dp_size)
    
    # Gradients memory
    grads_per_gpu = params_per_gpu
    if zero_stage >= 2:
        grads_per_gpu = total_params / (tp_size * pp_size * dp_size)
    
    # Optimizer states (mixed-precision Adam: FP32 master weights + momentum + variance = 12 bytes/param)
    optimizer_per_gpu = 3 * params_per_gpu * 4
    if zero_stage >= 1:
        optimizer_per_gpu = 3 * (total_params / (tp_size * pp_size * dp_size)) * 4
    
    # Activations (very rough placeholder)
    # Korthikanti et al. (2022) estimate s*b*h*(34 + 5*a*s/h) bytes per layer in 16-bit without
    # recomputation; here we simply assume activations ≈ 10x model params (3x with checkpointing)
    activation_multiplier = 3 if gradient_checkpointing else 10
    activations_per_gpu = (total_params / (tp_size * pp_size)) * activation_multiplier * dtype_bytes
    
    # Convert to GB
    params_gb = params_per_gpu * dtype_bytes / 1e9
    grads_gb = grads_per_gpu * dtype_bytes / 1e9
    optimizer_gb = optimizer_per_gpu / 1e9
    activations_gb = activations_per_gpu / 1e9
    
    total_gb = params_gb + grads_gb + optimizer_gb + activations_gb
    
    return {
        'parameters': params_gb,
        'gradients': grads_gb,
        'optimizer': optimizer_gb,
        'activations': activations_gb,
        'total': total_gb
    }

# Example: LLaMA-2-70B on different configurations
model_size = 70

configs = [
    ("Baseline (no parallelism)", 1, 1, 1, 0, False),
    ("TP=8", 8, 1, 1, 0, False),
    ("TP=8 + Gradient Checkpointing", 8, 1, 1, 0, True),
    ("TP=8 + ZeRO-1", 8, 1, 8, 1, False),
    ("TP=8 + ZeRO-2", 8, 1, 8, 2, False),
    ("TP=8 + ZeRO-3", 8, 1, 8, 3, False),
]

print(f"Memory Requirements for {model_size}B Parameter Model:")
print("=" * 80)
print(f"{'Configuration':<40} {'Memory/GPU':<15} {'Fits in 80 GB?'}")
print("=" * 80)

for config_name, tp, pp, dp, zero, gc in configs:
    mem = calculate_memory_per_gpu(model_size, tp, pp, dp, zero, gradient_checkpointing=gc)
    fits = "✓ Yes" if mem['total'] < 80 else "✗ No"
    print(f"{config_name:<40} {mem['total']:>10.1f} GB     {fits}")

print("\nNote: These are rough estimates. Actual memory usage varies with batch size,")
print("sequence length, and implementation details.")

## 7. Communication Patterns

Understanding collective operations is crucial for distributed training:

### Collective Operations:

#### 1. **AllReduce**
- **Operation**: Sum tensors across all ranks, broadcast result to all
- **Tensor shape**: Same on all ranks (e.g., [d_model, d_model])
- **Bytes moved**: 2(N-1)/N × tensor_size sent per rank (ring AllReduce), ≈ 2 × tensor_size for large N
- **Used in**: TP (after attention & MLP), standard DDP
- **Cost**: ≈ 2 × S / B where S = size, B = bandwidth

#### 2. **AllGather**
- **Operation**: Gather tensors from all ranks, broadcast to all
- **Input shape**: [batch_size / N, ...] per rank
- **Output shape**: [batch_size, ...] on all ranks
- **Bytes moved**: (N-1)/N × total_size per rank
- **Used in**: ZeRO-3 (parameter gathering), Context Parallel

#### 3. **ReduceScatter**
- **Operation**: Reduce tensors, scatter results
- **Input shape**: [N × size] per rank
- **Output shape**: [size] per rank
- **Bytes moved**: (N-1)/N × total_size per rank (total_size = per-rank input size)
- **Used in**: ZeRO-2/3 (gradient reduction)

#### 4. **AllToAll**
- **Operation**: Each rank sends different data to each other rank
- **Bytes moved**: (N-1)/N × local_data_size sent per rank (the slice destined for the rank itself stays local)
- **Used in**: Expert Parallel (MoE token routing)

#### 5. **Point-to-Point (P2P)**
- **Operation**: Direct send/receive between two ranks
- **Used in**: Pipeline Parallel (activation passing)
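The identity AllReduce = ReduceScatter + AllGather, and the 2(N-1)/N bytes-per-rank figure for the ring algorithm, can be checked with a single-process simulation. A NumPy sketch (not a real collective library; each "send" is a copy between Python lists) in which every rank splits its buffer into N chunks and passes one chunk per step to its right-hand neighbour:

```python
import numpy as np


def ring_allreduce(rank_data: list) -> tuple[list, list]:
    """Simulate ring AllReduce as ReduceScatter followed by AllGather.

    rank_data: list of N equal-length 1-D arrays, one per rank.
    Returns (result buffer per rank, elements sent per rank).
    """
    n = len(rank_data)
    chunks = [np.array_split(d.astype(float), n) for d in rank_data]
    sent = [0] * n
    # ReduceScatter: at step s, rank r sends chunk (r - s) % n to rank r + 1, which adds it
    for step in range(n - 1):
        msgs = [(r, (r - step) % n, chunks[r][(r - step) % n].copy()) for r in range(n)]
        for r, idx, payload in msgs:
            dst = (r + 1) % n
            chunks[dst][idx] = chunks[dst][idx] + payload
            sent[r] += payload.size
    # Now rank r owns the fully reduced chunk (r + 1) % n.
    # AllGather: circulate the reduced chunks, overwriting stale copies
    for step in range(n - 1):
        msgs = [(r, (r + 1 - step) % n, chunks[r][(r + 1 - step) % n].copy()) for r in range(n)]
        for r, idx, payload in msgs:
            chunks[(r + 1) % n][idx] = payload
            sent[r] += payload.size
    return [np.concatenate(c) for c in chunks], sent


data = [np.arange(8, dtype=float) * (r + 1) for r in range(4)]  # 4 ranks, 8 elements each
out, sent = ring_allreduce(data)
expected = np.sum(data, axis=0)
print(all(np.allclose(o, expected) for o in out), sent)  # True [12, 12, 12, 12]
```

With 4 ranks and an 8-element buffer, each rank sends 12 elements, i.e. 2 × (4-1)/4 × 8, independent of how many ranks take part as N grows.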

### Communication Cost Formula:
```
Time = α + β × Message_Size
```
Where:
- α = Network latency (microseconds)
- β = Inverse bandwidth (1/bandwidth)
- Message_Size = Bytes to transfer

In [None]:
# Communication cost estimation

class NetworkSpecs:
    """Network specifications for different interconnects."""
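    # Per-direction bandwidth per GPU in GB/s (network links are usually quoted in Gb/s)
    NVLINK_BW = 450  # NVLink 4.0 (H100 SXM): 900 GB/s bidirectional aggregate
    INFINIBAND_BW = 50  # InfiniBand NDR: 400 Gb/s per port = 50 GB/s
    ETHERNET_BW = 12.5  # 100GbE: 100 Gb/s = 12.5 GB/s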
    
    # Latency in microseconds
    NVLINK_LATENCY = 1
    INFINIBAND_LATENCY = 2
    ETHERNET_LATENCY = 10

def estimate_communication_time(
    message_size_gb: float,
    bandwidth_gbs: float,
    latency_us: float,
) -> float:
    """
    Estimate communication time with the alpha-beta model: alpha + size / bandwidth.
    
    Callers pre-scale message_size_gb for the collective algorithm (e.g. by
    2(N-1)/N for ring AllReduce). Latency is charged once, which underestimates
    small messages on large rings, where every ring step pays alpha.
    
    Args:
        message_size_gb: Data sent per rank, in GB
        bandwidth_gbs: Per-direction link bandwidth in GB/s
        latency_us: Per-message latency in microseconds
    
    Returns:
        Time in milliseconds
    """
    # Convert latency to seconds
    latency_s = latency_us / 1e6
    
    # Transfer time
    transfer_time_s = message_size_gb / bandwidth_gbs
    
    # Total time (latency + transfer)
    total_time_s = latency_s + transfer_time_s
    
    # Convert to milliseconds
    return total_time_s * 1000

def analyze_tp_communication(
    d_model: int,
    tp_size: int,
    batch_size: int,
    seq_len: int,
    bandwidth_gbs: float,
    latency_us: float
) -> Dict[str, float]:
    """
    Analyze communication for Tensor Parallel.
    
    TP requires 2 AllReduces per layer in the forward pass (after attention and
    after the MLP); the backward pass adds two more, which are not counted here.
    """
    # Activation tensor reduced after attention: [batch, seq_len, d_model] in BF16
    attention_size_gb = batch_size * seq_len * d_model * 2 / 1e9
    
    # AllReduce after MLP: same size
    mlp_size_gb = attention_size_gb
    
    # Ring AllReduce sends 2(N-1)/N x tensor size per rank
    ring_factor = 2 * (tp_size - 1) / tp_size
    attention_time = estimate_communication_time(ring_factor * attention_size_gb, bandwidth_gbs, latency_us)
    mlp_time = estimate_communication_time(ring_factor * mlp_size_gb, bandwidth_gbs, latency_us)
    
    return {
        'attention_allreduce_ms': attention_time,
        'mlp_allreduce_ms': mlp_time,
        'total_per_layer_ms': attention_time + mlp_time,
        'bytes_per_layer_gb': ring_factor * (attention_size_gb + mlp_size_gb)  # sent per rank, FWD only
    }

# Example: Analyze TP=8 for LLaMA-70B
print("Tensor Parallel Communication Analysis:")
print("=" * 70)

config = {
    'd_model': 8192,
    'tp_size': 8,
    'batch_size': 4,
    'seq_len': 2048,
}

for network_name, bw, lat in [
    ("NVLink", NetworkSpecs.NVLINK_BW, NetworkSpecs.NVLINK_LATENCY),
    ("InfiniBand", NetworkSpecs.INFINIBAND_BW, NetworkSpecs.INFINIBAND_LATENCY),
]:
    result = analyze_tp_communication(**config, bandwidth_gbs=bw, latency_us=lat)
    print(f"\n{network_name} ({bw} GB/s):")
    print(f"  Attention AllReduce: {result['attention_allreduce_ms']:.2f} ms")
    print(f"  MLP AllReduce:       {result['mlp_allreduce_ms']:.2f} ms")
    print(f"  Total per layer:     {result['total_per_layer_ms']:.2f} ms")
    print(f"  Bytes per layer:     {result['bytes_per_layer_gb']:.3f} GB")

print("\nKey Insight: TP needs the fastest interconnect available (NVLink).")
print("Keep TP within a node: compare the NVLink and InfiniBand timings above.")

In [None]:
# FSDP Communication Analysis

def analyze_fsdp_communication(
    model_params_b: float,
    num_layers: int,
    dp_size: int,
    zero_stage: int,
    bandwidth_gbs: float,
    latency_us: float
) -> Dict[str, float]:
    """
    Analyze communication for FSDP/ZeRO.
    """
    total_params = model_params_b * 1e9
    params_per_layer = total_params / num_layers
    
    results = {}
    
    if zero_stage in (0, 1):
        # Ring AllReduce of BF16 gradients; ZeRO-1 may instead ReduceScatter gradients
        # and AllGather updated parameters, which moves the same volume
        grad_size_gb = total_params * 2 / 1e9
        one_pass_gb = (dp_size - 1) / dp_size * grad_size_gb
        comm_time = estimate_communication_time(2 * one_pass_gb, bandwidth_gbs, latency_us)
        results = {
            'type': 'AllReduce',
            'total_comm_ms': comm_time,
            'comm_per_layer_ms': comm_time / num_layers,
            'bytes_total_gb': 2 * one_pass_gb
        }
    
    elif zero_stage == 2:
        # ReduceScatter gradients, then AllGather updated parameters after the optimizer step
        grad_size_gb = total_params * 2 / 1e9
        one_pass_gb = (dp_size - 1) / dp_size * grad_size_gb
        comm_time = estimate_communication_time(2 * one_pass_gb, bandwidth_gbs, latency_us)
        results = {
            'type': 'ReduceScatter + AllGather',
            'total_comm_ms': comm_time,
            'comm_per_layer_ms': comm_time / num_layers,
            'bytes_total_gb': 2 * one_pass_gb
        }
    
    elif zero_stage == 3:
        # Per layer: AllGather params in forward, AllGather them again in backward,
        # then ReduceScatter the gradients (~1.5x the volume of ZeRO-1/2)
        params_layer_gb = params_per_layer * 2 / 1e9
        one_pass_gb = (dp_size - 1) / dp_size * params_layer_gb
        
        allgather_time = estimate_communication_time(one_pass_gb, bandwidth_gbs, latency_us)
        reducescatter_time = estimate_communication_time(one_pass_gb, bandwidth_gbs, latency_us)
        
        results = {
            'type': 'AllGather (fwd + bwd) + ReduceScatter',
            'allgather_per_layer_ms': 2 * allgather_time,
            'reducescatter_per_layer_ms': reducescatter_time,
            'comm_per_layer_ms': 2 * allgather_time + reducescatter_time,
            'total_comm_ms': (2 * allgather_time + reducescatter_time) * num_layers,
            'bytes_per_layer_gb': 3 * one_pass_gb
        }
    
    return results

# Example: Compare ZeRO stages
print("FSDP/ZeRO Communication Analysis:")
print("=" * 70)

model_size = 70
num_layers = 80
dp_size = 64

for zero_stage in [1, 2, 3]:
    result = analyze_fsdp_communication(
        model_size, num_layers, dp_size, zero_stage,
        NetworkSpecs.INFINIBAND_BW, NetworkSpecs.INFINIBAND_LATENCY
    )
    
    print(f"\nZeRO-{zero_stage} ({result['type']}):")
    if zero_stage < 3:
        print(f"  Total communication: {result['total_comm_ms']:.2f} ms")
        print(f"  Per layer:           {result['comm_per_layer_ms']:.2f} ms")
    else:
        print(f"  AllGather per layer:       {result['allgather_per_layer_ms']:.2f} ms")
        print(f"  ReduceScatter per layer:   {result['reducescatter_per_layer_ms']:.2f} ms")
        print(f"  Total per layer:           {result['comm_per_layer_ms']:.2f} ms")
        print(f"  Total for all layers:      {result['total_comm_ms']:.2f} ms")

print("\nWhy ZeRO-3 is used less often in large 3D-parallel setups:")
print("  - Parameters are re-gathered for every layer, in forward and again in backward")
print("  - ~1.5× the per-step communication volume of ZeRO-1/2 (3 passes vs 2)")
print("  - TP or PP can cut per-GPU memory with less data-parallel traffic")

## 8. Parallelism Trade-offs and Best Practices

### When to Use Each Strategy:

#### **Tensor Parallel (TP)**:
- ✅ **Use when**: Model doesn't fit in single GPU, have fast interconnect (NVLink)
- ✅ **Use within**: Single node (up to 8 GPUs)
- ❌ **Don't use**: Across nodes (too slow)
- **Typical**: TP=2, 4, or 8

#### **Pipeline Parallel (PP)**:
- ✅ **Use when**: Model very deep, want to minimize communication
- ✅ **Use with**: Micro-batching to reduce bubble
- ❌ **Don't use**: If you can't run many micro-batches per step (the bubble stays large)
- **Typical**: PP=2, 4, or 8

#### **Data Parallel (DP) / FSDP**:
- ✅ **ZeRO-1**: Always safe, minimal overhead
- ✅ **ZeRO-2**: Good balance, use for most training
- ⚠️ **ZeRO-3**: Largest memory savings, but ~1.5× the communication volume of ZeRO-1/2; use it mainly when memory is the binding constraint

#### **Gradient Checkpointing**:
- ✅ **Use when**: Memory-bound, not compute-bound
- ❌ **Don't use**: If already compute-bound (full recomputation adds ~33% compute)

### Typical Production Setup:

```
3D Parallelism = TP × PP × DP
```

**Example for 1024 GPUs (128 nodes × 8 GPUs/node)**:
- TP = 8 (within node, use all NVLink bandwidth)
- PP = 4 (across nodes)
- DP = 32 (remaining parallelism)
- Total: 8 × 4 × 32 = 1024 GPUs
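One way to lay the 1024 ranks out so that every TP group stays inside a node is to make TP the fastest-varying index. A sketch, assuming the ordering rank = tp + TP × (dp + DP × pp); Megatron-LM uses a similar TP-innermost convention by default, but the order of the outer dimensions is an implementation choice, and the helper names here are hypothetical.

```python
def rank_to_coords(rank: int, tp: int, pp: int, dp: int) -> tuple[int, int, int]:
    """Map a global rank to (tp_rank, dp_rank, pp_rank) with TP fastest-varying."""
    assert 0 <= rank < tp * pp * dp
    tp_rank = rank % tp
    dp_rank = (rank // tp) % dp
    pp_rank = rank // (tp * dp)
    return tp_rank, dp_rank, pp_rank


def tp_group(rank: int, tp: int) -> list[int]:
    """Ranks sharing this rank's TP group (a block of consecutive ranks)."""
    start = rank - rank % tp
    return list(range(start, start + tp))


TP, PP, DP, GPUS_PER_NODE = 8, 4, 32, 8
world = TP * PP * DP  # 1024

# Every TP group lives on one node (all members share rank // GPUS_PER_NODE)
assert all(len({m // GPUS_PER_NODE for m in tp_group(r, TP)}) == 1 for r in range(world))
print(rank_to_coords(0, TP, PP, DP), rank_to_coords(1023, TP, PP, DP))  # (0, 0, 0) (7, 31, 3)
```

Adjacent pipeline stages end up TP × DP = 256 ranks apart, i.e. on different nodes, which matches the plan of running PP's point-to-point traffic across nodes.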

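### Why ZeRO-3 Is Less Common in Large 3D-Parallel Setups:

1. **Communication Overhead**: Parameters are AllGathered before each layer, in both forward and backward
2. **Latency**: Many small per-layer collectives instead of one large one
3. **Bandwidth**: ~1.5× the total communication volume of ZeRO-1/2 (two parameter AllGathers plus the gradient ReduceScatter)
4. **Better Alternatives**: TP or PP reduce per-GPU memory with less data-parallel traffic

**In practice**: Many large-scale training stacks combine TP+PP+DP with ZeRO-1 or ZeRO-2 rather than ZeRO-3.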

In [None]:
# Practical parallelism configuration

def recommend_parallelism(
    model_params_b: float,
    num_gpus: int,
    gpus_per_node: int = 8,
    gpu_memory_gb: int = 80
) -> Dict[str, object]:
    """
    Recommend parallelism configuration for a given setup.
    """
    num_nodes = num_gpus // gpus_per_node
    
    # Rule of thumb: model fits in GPU if params_gb * 20 < gpu_memory
    # (factor of 20 accounts for gradients, optimizer, activations)
    model_size_gb = model_params_b * 2  # BF16
    needs_model_parallel = (model_size_gb * 20) > gpu_memory_gb
    
    # Recommend configuration
    if not needs_model_parallel:
        # Pure data parallel
        config = {
            'tp_size': 1,
            'pp_size': 1,
            'dp_size': num_gpus,
            'zero_stage': 2,
            'gradient_checkpointing': False,
            'rationale': 'Model fits on single GPU, use pure DP with ZeRO-2'
        }
    elif num_nodes == 1:
        # Single node: use TP
        tp_size = min(gpus_per_node, 8)
        config = {
            'tp_size': tp_size,
            'pp_size': 1,
            'dp_size': num_gpus // tp_size,
            'zero_stage': 1,
            'gradient_checkpointing': True,
            'rationale': f'Single node, use TP={tp_size} within node'
        }
    else:
        # Multi-node: use 3D parallelism
        tp_size = gpus_per_node  # TP within node
        
        # Decide PP size based on model size
        if model_params_b > 100:
            pp_size = min(4, num_nodes // 2)
        elif model_params_b > 50:
            pp_size = 2
        else:
            pp_size = 1
        
        dp_size = num_gpus // (tp_size * pp_size)
        
        config = {
            'tp_size': tp_size,
            'pp_size': pp_size,
            'dp_size': dp_size,
            'zero_stage': 2,
            'gradient_checkpointing': True,
            'rationale': f'Multi-node 3D parallelism: TP={tp_size} (within node), '
                        f'PP={pp_size} (across nodes), DP={dp_size}'
        }
    
    config['total_gpus'] = config['tp_size'] * config['pp_size'] * config['dp_size']
    return config

# Examples
print("Parallelism Recommendations:")
print("=" * 80)

test_cases = [
    ("LLaMA-2-7B", 7, 64),
    ("LLaMA-2-13B", 13, 128),
    ("LLaMA-2-70B", 70, 512),
    ("GPT-3-175B", 175, 1024),
    ("Hypothetical 1.7T MoE", 1700, 2048),  # illustrative size; expert parallelism not modeled
]

for model_name, params, gpus in test_cases:
    config = recommend_parallelism(params, gpus)
    print(f"\n{model_name} on {gpus} GPUs:")
    print(f"  TP = {config['tp_size']}, PP = {config['pp_size']}, DP = {config['dp_size']}")
    print(f"  ZeRO-{config['zero_stage']}, Gradient Checkpointing = {config['gradient_checkpointing']}")
    print(f"  Rationale: {config['rationale']}")


## 9. Inference: Latency and Throughput

Inference has two phases with different characteristics:

### Prefill Phase (First Token):
- **Operation**: Process entire input prompt
- **Characteristic**: Compute-bound (matrix multiplications)
- **FLOPs**: `2 × N` per prompt token, i.e. `2 × N × prompt_len` for the whole prompt (attention-score FLOPs ignored)
- **Latency**: `prompt_len × (2N / GPU_FLOPS)` at ideal utilization
- **Can batch**: Yes, efficiently!

### Decode Phase (Subsequent Tokens):
- **Operation**: Generate one token at a time
- **Characteristic**: Memory-bound (reading KV cache + weights)
- **FLOPs**: `2 × N` per token
- **Latency**: Dominated by memory bandwidth
- **Memory reads**: `N × 2 bytes` (parameters) + KV cache
- **Can batch**: Yes, but limited by KV cache memory
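Why decode is memory-bound can be checked with a roofline-style estimate. Per decode step, a dense model does about `2 × N × B` FLOPs for a batch of `B` sequences but reads the `2N` bytes of BF16 weights only once, so its arithmetic intensity is roughly `B` FLOPs per byte. The sketch compares that with the ridge point implied by the H100 figures of 990 TFLOPS and 3.35 TB/s. It ignores KV cache traffic (which lowers the intensity further), and the function name is hypothetical.

```python
def decode_arithmetic_intensity(batch_size: int, bytes_per_param: int = 2) -> float:
    """FLOPs per byte of weight traffic for one decode step (KV cache ignored)."""
    # 2 * N * B FLOPs divided by N * bytes_per_param bytes; N cancels out
    return 2 * batch_size / bytes_per_param

H100_FLOPS = 990e12      # BF16 dense, FLOP/s
H100_BW = 3.35e12        # HBM3, bytes/s
ridge = H100_FLOPS / H100_BW   # intensity needed to become compute-bound

for b in [1, 32, 256, 512]:
    ai = decode_arithmetic_intensity(b)
    regime = "compute-bound" if ai >= ridge else "memory-bound"
    print(f"batch={b:4d}: {ai:6.1f} FLOPs/byte vs ridge {ridge:.0f} -> {regime}")
```

Batching raises decode intensity almost for free until roughly B ≈ 300, which is why serving systems push batch size as far as KV cache memory allows.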

### KV Cache Memory:

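For every token in the context (prompt and generated), each layer stores one key and one value vector. The formulas below assume standard multi-head attention in BF16; with grouped-query attention (GQA), replace `d_model` with `num_kv_heads × head_dim`. LLaMA-2-70B uses 8 KV heads for 64 query heads, so its cache is 8x smaller than these formulas give: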

```
KV_cache_per_token = 2 × num_layers × d_model × 2 bytes
```

For batch size B and sequence length S:
```
Total_KV_cache = B × S × 2 × num_layers × d_model × 2 bytes
```
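A minimal sketch of the KV cache formula with the number of KV heads made explicit. The LLaMA-2-70B shape used below (80 layers, 64 query heads of dimension 128, 8 KV heads) is its published configuration; batch 1 and a 4,096-token context are arbitrary choices, and the function name is hypothetical.

```python
def kv_cache_bytes(batch: int, seq_len: int, num_layers: int,
                   num_kv_heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """Bytes of K and V cached for `batch` sequences of `seq_len` tokens."""
    # Factor 2 = one key and one value vector per KV head, per layer, per token
    return 2 * batch * seq_len * num_layers * num_kv_heads * head_dim * bytes_per_elem

layers, heads, kv_heads, head_dim = 80, 64, 8, 128   # LLaMA-2-70B shape
mha = kv_cache_bytes(1, 4096, layers, heads, head_dim)     # if every head kept its own K/V
gqa = kv_cache_bytes(1, 4096, layers, kv_heads, head_dim)  # actual GQA layout
print(f"MHA-equivalent: {mha / 1e9:.2f} GB, GQA: {gqa / 1e9:.2f} GB per 4k-token sequence")
```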

### Latency Formula:

**Prefill (compute-bound, ideal utilization)**:
```
Latency_prefill = prompt_len × 2N / GPU_FLOPS
```

**Decode (memory-bound, per generated token, KV cache reads ignored)**:
```
Latency_decode = (N × 2 bytes) / Memory_Bandwidth
```

For H100 SXM:
- GPU_FLOPS = 990 TFLOPS (BF16, dense)
- Memory_Bandwidth = 3.35 TB/s (HBM3)

In [None]:
# Inference latency calculation

def calculate_inference_latency(
    model_params_b: float,
    prompt_len: int,
    gen_len: int,
    batch_size: int,
    gpu_tflops: float = 990,  # H100
    memory_bandwidth_tbs: float = 3.35,  # H100 HBM3
    d_model: int = 8192,
    num_layers: int = 80
) -> Dict[str, float]:
    """
    Calculate inference latency for prefill and decode phases.
    
    Args:
        model_params_b: Model parameters in billions
        prompt_len: Input prompt length
        gen_len: Number of tokens to generate
        batch_size: Batch size
        gpu_tflops: GPU TFLOPS (theoretical)
        memory_bandwidth_tbs: Memory bandwidth in TB/s
        d_model: Model dimension
        num_layers: Number of layers
    
    Returns:
        Dictionary with latency estimates
    """
    # Convert to actual numbers
    model_params = model_params_b * 1e9
    gpu_flops = gpu_tflops * 1e12
    memory_bandwidth_bs = memory_bandwidth_tbs * 1e12
    
    # Prefill: compute-bound
    # FLOPs = batch_size × prompt_len × 2N
    prefill_flops = batch_size * prompt_len * 2 * model_params
    prefill_time_s = prefill_flops / gpu_flops
    prefill_time_ms = prefill_time_s * 1000
    
    # Decode: memory-bound
    # Memory reads per step = N × 2 bytes (weights read once per step; KV cache reads ignored)
    bytes_per_token = model_params * 2  # BF16
    decode_time_per_token_s = bytes_per_token / memory_bandwidth_bs
    decode_time_per_token_ms = decode_time_per_token_s * 1000
    
    # Total decode time for generating gen_len tokens
    # Note: batch_size doesn't affect latency much if memory-bound
    total_decode_time_ms = decode_time_per_token_ms * gen_len
    
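    # KV cache memory (multi-head attention; GQA models such as LLaMA-2-70B
    # need only num_kv_heads / num_heads of this)
    # Per token: 2 (K,V) × num_layers × d_model × 2 bytes (BF16)
    kv_cache_per_token_bytes = 2 * num_layers * d_model * 2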
    kv_cache_per_seq_gb = (prompt_len + gen_len) * kv_cache_per_token_bytes / 1e9
    kv_cache_total_gb = batch_size * kv_cache_per_seq_gb
    
    # Total latency
    total_latency_ms = prefill_time_ms + total_decode_time_ms
    
    # Tokens per second
    total_tokens = batch_size * gen_len
    tokens_per_second = total_tokens / (total_latency_ms / 1000)
    
    return {
        'prefill_ms': prefill_time_ms,
        'decode_per_token_ms': decode_time_per_token_ms,
        'decode_total_ms': total_decode_time_ms,
        'total_latency_ms': total_latency_ms,
        'tokens_per_second': tokens_per_second,
        'kv_cache_per_seq_gb': kv_cache_per_seq_gb,
        'kv_cache_total_gb': kv_cache_total_gb,
    }

# Example: LLaMA-2-70B inference
print("Idealized Inference Latency for LLaMA-2-70B on one H100 (roofline; 140 GB of BF16 weights really need >= 2 GPUs):")
print("=" * 70)

test_configs = [
    ("Short prompt, short gen", 100, 50, 1),
    ("Long prompt, short gen", 2000, 50, 1),
    ("Short prompt, long gen", 100, 500, 1),
    ("Batched (bs=8)", 100, 50, 8),
]

for name, prompt_len, gen_len, batch_size in test_configs:
    result = calculate_inference_latency(
        model_params_b=70,
        prompt_len=prompt_len,
        gen_len=gen_len,
        batch_size=batch_size,
        d_model=8192,
        num_layers=80
    )
    
    print(f"\n{name}:")
    print(f"  Prompt: {prompt_len} tokens, Generate: {gen_len} tokens, Batch: {batch_size}")
    print(f"  Prefill time:        {result['prefill_ms']:6.1f} ms")
    print(f"  Decode per token:    {result['decode_per_token_ms']:6.1f} ms")
    print(f"  Total decode time:   {result['decode_total_ms']:6.1f} ms")
    print(f"  Total latency:       {result['total_latency_ms']:6.1f} ms")
    print(f"  Throughput:          {result['tokens_per_second']:6.1f} tokens/s")
    print(f"  KV cache per seq:    {result['kv_cache_per_seq_gb']:6.2f} GB")
    print(f"  KV cache total:      {result['kv_cache_total_gb']:6.2f} GB")

print("\nKey Insight: Decode is memory-bound. Time per token ≈ constant.")
print("Prefill can be batched efficiently, decode is limited by KV cache memory.")

In [None]:
# GPU requirements for serving (simplified, no batching)
from typing import Any, Dict

def estimate_serving_gpus(
    model_params_b: float,
    num_users: int,
    requests_per_second_per_user: float,
    avg_prompt_len: int,
    avg_gen_len: int,
    target_latency_ms: float,
    gpu_memory_gb: int = 80,
    d_model: int = 8192,
    num_layers: int = 80
) -> Dict[str, Any]:
    """
    Estimate the number of GPUs needed for serving.

    Deliberately conservative: each model replica serves one request at a
    time (batch_size=1) and runs at single-GPU speed even when its weights
    are spread over several GPUs.
    """
    # Total requests per second
    total_rps = num_users * requests_per_second_per_user
    
    # Latency of a single request (batch_size=1)
    single_request = calculate_inference_latency(
        model_params_b=model_params_b,
        prompt_len=avg_prompt_len,
        gen_len=avg_gen_len,
        batch_size=1,
        d_model=d_model,
        num_layers=num_layers
    )
    latency_ms = single_request['total_latency_ms']
    
    # Weight memory: model_params_b is in billions, so × 2 bytes (BF16) is already GB
    model_weight_memory = model_params_b * 2
    # Minimum GPUs per replica just to hold the weights
    gpus_per_replica = max(1, int(np.ceil(model_weight_memory / gpu_memory_gb)))
    
    # KV cache memory constraint (informational: batching is not modeled below)
    kv_cache_per_request = single_request['kv_cache_per_seq_gb']
    available_for_kv = gpus_per_replica * gpu_memory_gb - model_weight_memory
    max_concurrent_requests = available_for_kv / kv_cache_per_request
    
    # Latency constraint: can even an unloaded replica meet the target?
    meets_latency_target = latency_ms <= target_latency_ms
    
    # Throughput at batch_size=1: one request per latency_ms per replica
    requests_per_replica = 1000.0 / latency_ms
    replicas_needed = int(np.ceil(total_rps / requests_per_replica))
    gpus_needed = replicas_needed * gpus_per_replica
    
    return {
        'total_rps': total_rps,
        'requests_per_gpu': requests_per_replica / gpus_per_replica,
        'gpus_per_replica': gpus_per_replica,
        'gpus_needed': gpus_needed,
        'max_concurrent_requests': max_concurrent_requests,
        'kv_cache_per_request_gb': kv_cache_per_request,
        'model_memory_gb': model_weight_memory,
        'single_request_latency_ms': latency_ms,
        'meets_latency_target': meets_latency_target,
    }

# Example: How many H100s for up to 10k users?
print("GPU Requirements for Serving LLaMA-2-70B:")
print("=" * 70)

scenarios = [
    ("Low load", 1000, 0.01),    # 10 RPS total
    ("Medium load", 5000, 0.02), # 100 RPS total
    ("High load", 10000, 0.05),  # 500 RPS total
]

for name, num_users, rps_per_user in scenarios:
    result = estimate_serving_gpus(
        model_params_b=70,
        num_users=num_users,
        requests_per_second_per_user=rps_per_user,
        avg_prompt_len=100,
        avg_gen_len=50,
        target_latency_ms=1000,  # 1 second
        d_model=8192,
        num_layers=80
    )
    
    print(f"\n{name}:")
    print(f"  Users: {num_users:,}, RPS per user: {rps_per_user}")
    print(f"  Total RPS: {result['total_rps']:.1f}")
    print(f"  GPUs per replica (weights only): {result['gpus_per_replica']}")
    print(f"  Requests per GPU: {result['requests_per_gpu']:.2f}")
    print(f"  GPUs needed: {result['gpus_needed']}")
    print(f"  Single request latency: {result['single_request_latency_ms']:.1f} ms "
          f"(meets target: {result['meets_latency_target']})")

print("\nNote: This is a simplified estimate. Real serving systems use:")
print("  - Continuous batching (vLLM, TGI)")
print("  - PagedAttention for efficient KV cache management")
print("  - Request queuing and load balancing")

## 10. Reinforcement Learning Training

RL fine-tuning of LLMs (e.g., RLHF with PPO) has different characteristics than supervised training:

### RL Training Components:

1. **Policy Model**: The model being trained
2. **Reference Model**: Frozen copy for KL penalty
3. **Reward Model**: Scores generated outputs
4. **Critic Model** (PPO): Value function estimator

### Two Training Paradigms:

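#### **Synchronous Training**:
- All models share the same GPUs and take turns
- Sequential execution: generate → score → train
- **Pros**: Simple, reproducible, strictly on-policy (each batch is generated by the current weights)
- **Cons**: Low GPU utilization (each phase waits for the previous one)

#### **Asynchronous Training**:
- Models run on separate GPU pools
- Overlapped execution: GPU set A generates the next batch while GPU set B trains on the current one
- **Pros**: Higher GPU utilization and throughput, especially when generation and training take similar time
- **Cons**: More complex (weights must be synced from trainers to generators), samples are slightly off-policy (generated by weights one or more updates old), and each phase gets only a fraction of the GPUs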

### Throughput Calculation:

**Synchronous**:
```
Time per iteration = T_generate + T_score + T_train
Throughput = Batch_size / Time_per_iteration
```

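**Asynchronous** (steady state, once the pipeline is full; `T_train` is longer than in sync mode because fewer GPUs train):
```
Time per iteration = max(T_generate + T_score, T_train)
Throughput = Batch_size / max(T_generate + T_score, T_train)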
```
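The steady-state `max(...)` formula hides two details: the pipeline needs an iteration to fill, and overlapping means each batch is generated by slightly stale weights. A small event-time simulation makes both visible. It is a sketch under stated assumptions: a one-step off-policy pipeline (batch k+1 is generated while batch k trains, using the weights from after batch k-1), scoring folded into `t_gen`, and free, instantaneous weight syncs. Both function names are hypothetical.

```python
def sync_total_time(t_gen: float, t_train: float, iters: int) -> float:
    """Generate, then train, strictly in sequence every iteration."""
    return iters * (t_gen + t_train)

def async_total_time(t_gen: float, t_train: float, iters: int) -> float:
    """One-step off-policy pipeline on separate generator and trainer pools.

    Batch k is generated with the weights produced after training batch k-2,
    so it is exactly one update stale when batch k is trained.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")
    gen_end = [0.0] * iters
    train_end = [0.0] * iters
    for k in range(iters):
        prev_gen = gen_end[k - 1] if k >= 1 else 0.0         # generator is free
        weights_ready = train_end[k - 2] if k >= 2 else 0.0  # required weights exist
        gen_end[k] = max(prev_gen, weights_ready) + t_gen
        prev_train = train_end[k - 1] if k >= 1 else 0.0     # trainer is free
        train_end[k] = max(gen_end[k], prev_train) + t_train
    return train_end[-1]

# Illustrative phase times in seconds (t_gen includes scoring)
for t_gen, t_train in [(10.0, 2.0), (5.0, 5.0), (2.0, 10.0)]:
    s = sync_total_time(t_gen, t_train, 100)
    a = async_total_time(t_gen, t_train, 100)
    print(f"gen={t_gen:4.1f}s train={t_train:4.1f}s: sync {s:7.1f}s, "
          f"async {a:7.1f}s, speedup {s / a:.2f}x")
```

Over many iterations the speedup approaches `(t_gen + t_train) / max(t_gen, t_train)`, which peaks at 2x when the two phases are balanced and shrinks toward 1x when one phase dominates.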

### GPU Allocation (illustrative):

**Synchronous (256 GPUs total)**:
- All 256 GPUs are time-shared: policy training and generation, plus the reference, reward and critic passes, run on the same devices one phase after another

**Asynchronous (256 GPUs total, one possible split)**:
- Policy training: 128 GPUs
- Policy generation (rollouts with a recently synced copy of the weights): 64 GPUs
- Reward + reference models (scoring and KL log-probs): 32 GPUs
- Critic: 32 GPUs (value estimation)

### When to Use Each:

- **Sync**: Small scale, debugging, reproducible on-policy results
- **Async**: Production, large scale, maximize throughput

In [None]:
# RL Training throughput estimation

def estimate_rl_throughput(
    model_params_b: float,
    batch_size: int,
    prompt_len: int,
    gen_len: int,
    num_ppo_epochs: int,
    total_gpus: int,
    mode: str = 'sync',  # 'sync' or 'async'
    gpu_tflops: float = 990,
    memory_bandwidth_tbs: float = 3.35
) -> Dict[str, float]:
    """
    Estimate RL training throughput with a crude roofline model.
    
    Args:
        model_params_b: Policy model parameters in billions
        batch_size: Number of rollouts (prompts) per iteration
        prompt_len: Prompt length
        gen_len: Generation length
        num_ppo_epochs: Number of PPO epochs over each batch
        total_gpus: Total number of GPUs
        mode: 'sync' or 'async'
        gpu_tflops: GPU TFLOPS (BF16, dense)
        memory_bandwidth_tbs: Memory bandwidth in TB/s
    
    Returns:
        Dictionary with time and throughput estimates
    """
    if mode not in ('sync', 'async'):
        raise ValueError(f"mode must be 'sync' or 'async', got {mode!r}")
    
    # Generation time. Simplification: the whole batch is costed on one GPU's
    # FLOPs and bandwidth, identically in both modes (only training is split).
    gen_result = calculate_inference_latency(
        model_params_b=model_params_b,
        prompt_len=prompt_len,
        gen_len=gen_len,
        batch_size=batch_size,
        gpu_tflops=gpu_tflops,
        memory_bandwidth_tbs=memory_bandwidth_tbs
    )
    t_generate_ms = gen_result['total_latency_ms']
    
    # Reward scoring: a single forward pass (prefill) over prompt + response.
    # Assumption: the reward model is 10x smaller than the policy.
    reward_result = calculate_inference_latency(
        model_params_b=model_params_b / 10,
        prompt_len=prompt_len + gen_len,
        gen_len=1,  # only the prefill time is used below
        batch_size=batch_size,
        gpu_tflops=gpu_tflops,
        memory_bandwidth_tbs=memory_bandwidth_tbs
    )
    t_score_ms = reward_result['prefill_ms']
    
    # Training: 6N FLOPs per token per PPO epoch over prompt + response tokens
    tokens_per_batch = batch_size * (prompt_len + gen_len)
    flops_per_epoch = tokens_per_batch * 6 * model_params_b * 1e9
    
    # Sync: every GPU trains. Async: half the GPUs train, the rest generate/score.
    training_gpus = total_gpus if mode == 'sync' else total_gpus // 2
    
    mfu = 0.5  # assumed model FLOPs utilization
    time_per_epoch_s = flops_per_epoch / (training_gpus * gpu_tflops * 1e12 * mfu)
    t_train_ms = time_per_epoch_s * 1000 * num_ppo_epochs
    
    # Total time per iteration
    if mode == 'sync':
        # Sequential: generate + score + train
        time_per_iter_ms = t_generate_ms + t_score_ms + t_train_ms
    else:
        # Overlapped (steady state): max(generate + score, train)
        time_per_iter_ms = max(t_generate_ms + t_score_ms, t_train_ms)
    
    # Throughput (tokens_per_second counts generated response tokens only)
    samples_per_second = batch_size / (time_per_iter_ms / 1000)
    tokens_per_second = (batch_size * gen_len) / (time_per_iter_ms / 1000)
    
    return {
        't_generate_ms': t_generate_ms,
        't_score_ms': t_score_ms,
        't_train_ms': t_train_ms,
        'time_per_iter_ms': time_per_iter_ms,
        'samples_per_second': samples_per_second,
        'tokens_per_second': tokens_per_second,
        'speedup': 1.0  # Will be calculated later
    }

# Example: Compare sync vs async RL training
print("RL Training Throughput Comparison:")
print("=" * 70)

config = {
    'model_params_b': 70,
    'batch_size': 512,
    'prompt_len': 100,
    'gen_len': 50,
    'num_ppo_epochs': 4,
    'total_gpus': 256,
}

sync_result = estimate_rl_throughput(**config, mode='sync')
async_result = estimate_rl_throughput(**config, mode='async')

async_result['speedup'] = sync_result['time_per_iter_ms'] / async_result['time_per_iter_ms']

print(f"Configuration: {config['model_params_b']}B params, batch={config['batch_size']}, "
      f"{config['total_gpus']} GPUs")
print()

print("Synchronous Training:")
print(f"  Generation time:  {sync_result['t_generate_ms']:7.1f} ms")
print(f"  Scoring time:     {sync_result['t_score_ms']:7.1f} ms")
print(f"  Training time:    {sync_result['t_train_ms']:7.1f} ms")
print(f"  Total time:       {sync_result['time_per_iter_ms']:7.1f} ms")
print(f"  Throughput:       {sync_result['samples_per_second']:7.1f} samples/s")
print(f"                    {sync_result['tokens_per_second']:7.1f} tokens/s")
print()

print("Asynchronous Training:")
print(f"  Generation time:  {async_result['t_generate_ms']:7.1f} ms (parallel)")
print(f"  Scoring time:     {async_result['t_score_ms']:7.1f} ms (parallel)")
print(f"  Training time:    {async_result['t_train_ms']:7.1f} ms (parallel)")
print(f"  Total time:       {async_result['time_per_iter_ms']:7.1f} ms")
print(f"  Throughput:       {async_result['samples_per_second']:7.1f} samples/s")
print(f"                    {async_result['tokens_per_second']:7.1f} tokens/s")
print()

print(f"Speedup (async vs sync): {async_result['speedup']:.2f}x")
print()
print("Key Insight: Async training overlaps generation and training. The gain depends on")
print("how balanced the two phases are; here generation dominates, so the gain is modest.")

In [None]:
# Visualization: RL Training Comparison

# Compare different batch sizes
batch_sizes = [128, 256, 512, 1024]
sync_throughputs = []
async_throughputs = []

for bs in batch_sizes:
    sync = estimate_rl_throughput(
        model_params_b=70,
        batch_size=bs,
        prompt_len=100,
        gen_len=50,
        num_ppo_epochs=4,
        total_gpus=256,
        mode='sync'
    )
    async_rl = estimate_rl_throughput(
        model_params_b=70,
        batch_size=bs,
        prompt_len=100,
        gen_len=50,
        num_ppo_epochs=4,
        total_gpus=256,
        mode='async'
    )
    sync_throughputs.append(sync['samples_per_second'])
    async_throughputs.append(async_rl['samples_per_second'])

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Throughput comparison
x = np.arange(len(batch_sizes))
width = 0.35

ax1.bar(x - width/2, sync_throughputs, width, label='Synchronous', alpha=0.8)
ax1.bar(x + width/2, async_throughputs, width, label='Asynchronous', alpha=0.8)

ax1.set_xlabel('Batch Size', fontsize=12)
ax1.set_ylabel('Throughput (samples/s)', fontsize=12)
ax1.set_title('RL Training Throughput: Sync vs Async', fontsize=14)
ax1.set_xticks(x)
ax1.set_xticklabels(batch_sizes)
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')

# Plot 2: Speedup
speedups = [a / s for a, s in zip(async_throughputs, sync_throughputs)]
ax2.plot(batch_sizes, speedups, marker='o', linewidth=2, markersize=8, color='green')
ax2.axhline(y=1.0, color='r', linestyle='--', alpha=0.5, label='Baseline (sync)')

ax2.set_xlabel('Batch Size', fontsize=12)
ax2.set_ylabel('Speedup (async / sync)', fontsize=12)
ax2.set_title('Async vs Sync Speedup', fontsize=14)
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Average speedup: {np.mean(speedups):.2f}x")
print("\nAsync overlap helps most when generation and training take similar time;")
print("in this model generation dominates at every batch size, which caps the gain.")

## Summary: Key Takeaways for Interviews

### Training Token Optimization:
- ✅ **Chinchilla optimal**: 20 tokens per parameter
- ✅ Real models often train beyond 20x for better quality
- ✅ MoE: Use active parameters for token calculation

### FLOPs Calculations:
- ✅ **Training**: 6N FLOPs per token (2 FWD + 4 BWD)
- ✅ **BWD = 2 × FWD**, **Total = 3 × FWD**
- ✅ Use this to estimate GPU days

### Parallelism Strategies:
- ✅ **TP**: Within node, use NVLink
- ✅ **PP**: Across nodes, minimize bubble with micro-batching
- ✅ **FSDP ZeRO-2**: Sweet spot for most training
- ⚠️ **FSDP ZeRO-3**: Use when memory requires it; ~1.5x the communication of ZeRO-1/2
- ✅ **3D Parallelism**: TP × PP × DP for large scale

### Communication:
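- ✅ **TP**: 2 AllReduces per layer in the forward pass (attention + MLP), 2 more in the backward pass
- ✅ **ZeRO-3**: AllGather (forward) + AllGather (backward) + ReduceScatter per layer (~1.5x the volume of plain DP)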
- ✅ Fast interconnect is crucial (NVLink > InfiniBand > Ethernet)

### Inference:
- ✅ **Prefill**: Compute-bound, can batch efficiently
- ✅ **Decode**: Memory-bound, limited by bandwidth
- ✅ **Time per token** ≈ (Parameter count × 2 bytes) / Memory bandwidth
- ✅ **KV cache**: Major memory constraint for batching

### RL Training:
- ✅ **Sync**: Simple, lower GPU utilization
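- ✅ **Async**: More complex and slightly off-policy; speedup is at most ~2x (when generation and training take equal time) and smaller when one phase dominates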
- ✅ Use async for production, sync for debugging

### Interview Tips:
1. Know the 6N rule for training FLOPs
2. Understand when to use each parallelism strategy
3. Be able to estimate GPU requirements from model size
4. Explain when ZeRO-3 is worth its communication overhead, and why 3D setups often prefer ZeRO-1/2
5. Understand inference is memory-bound during decode
6. Know KV cache calculations for serving

### Useful Rules of Thumb:
- Training: 6N FLOPs per token
- Inference decode: ~(Weight bytes) / Memory bandwidth seconds per token
- Memory: Weights + Gradients + Optimizer ≈ 16 bytes/param (BF16 mixed-precision Adam), ≈ 20 bytes/param with activation headroom
- H100 SXM: ~990 TFLOPS (dense BF16), 3.35 TB/s memory bandwidth
- Good MFU: 40-50%, Great MFU: 50-60%, Excellent: 60%+
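The memory rule of thumb can be broken down explicitly. This sketch assumes BF16 mixed-precision training with Adam: BF16 weights and gradients (2 + 2 bytes) plus FP32 master weights, momentum, and variance (4 + 4 + 4 bytes), before any activations or sharding.

```python
import math

BYTES_PER_PARAM = {
    "bf16 weights": 2,
    "bf16 gradients": 2,
    "fp32 master weights": 4,
    "fp32 Adam momentum": 4,
    "fp32 Adam variance": 4,
}
bytes_per_param = sum(BYTES_PER_PARAM.values())  # 16, before activations

params_b = 70
states_gb = params_b * bytes_per_param           # billions of params × bytes = GB
min_h100s = math.ceil(states_gb / 80)            # 80 GB per H100
print(f"{bytes_per_param} bytes/param -> {states_gb} GB for {params_b}B params, "
      f">= {min_h100s} H100s for model states alone (no activations)")
```

This is why a 70B model cannot be trained with plain DP on 80 GB GPUs: the states must be sharded (ZeRO/FSDP) or split (TP/PP) across at least 14 GPUs before activations are even counted.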