# 🌐 Lecture 14: Distributed Training - Complete Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/efficientml_course/blob/main/14_distributed_training/demo.ipynb)

## What You'll Learn
- Data parallelism vs model parallelism
- ZeRO optimization stages
- Pipeline parallelism
- Communication overhead analysis

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

print('Ready for Distributed Training!')

## Part 1: Why Distributed Training?

Modern models are too large for single GPUs!

In [None]:
def model_vs_gpu_memory():
    """
    Print a table of estimated training memory for several known models.

    Each model's footprint is taken as 18 GB per billion parameters (the
    rule of thumb shown in the table header), and the GPU count is that
    footprint divided by 80 GB, rounded up.

    Returns:
        tuple: ``(models, gpus)``. ``models`` maps a model name to its
        parameter count in billions; ``gpus`` maps a GPU name to the
        capacity figure used for it in this demo.
    """
    # Parameter counts, in billions.
    models = dict([
        ('GPT-2', 1.5),
        ('GPT-3', 175),
        ('LLaMA-7B', 7),
        ('LLaMA-70B', 70),
        ('GPT-4 (est.)', 1800),
    ])

    # Reference GPUs; returned to the caller, not used in the table below.
    gpus = dict([
        ('RTX 4090', 24),
        ('A100-40GB', 40),
        ('A100-80GB', 80),
        ('H100-80GB', 80),
    ])

    # Title, rule-of-thumb explanation and column header, in print order.
    preamble = (
        'üìä MODEL TRAINING MEMORY (FP16 + Adam)',
        '=' * 60,
        '\nRule of thumb: Training memory ‚âà 18 √ó model size',
        '  (Weights + Gradients + Optimizer: 2 + 2 + 12 = 16 bytes/param)',
        '  (+ Activations: ~2 bytes/param)',
        f'\n{"Model":<15} {"Params (B)":<12} {"Training Mem (GB)":<20} {"GPUs Needed":<15}',
        '-' * 60,
    )
    for text in preamble:
        print(text)

    # One table row per model, in insertion order.
    for name in models:
        params_b = models[name]
        train_mem = params_b * 18  # GB
        gpus_needed = np.ceil(train_mem / 80)  # sized against an 80 GB card
        print(f'{name:<15} {params_b:<12.1f} {train_mem:<20.0f} {gpus_needed:<15.0f}')

    return models, gpus


models, gpus = model_vs_gpu_memory()

In [None]:
# Bar chart of estimated training memory per model, with GPU capacities
# drawn as horizontal reference lines on a log-scaled y axis.
fig, ax = plt.subplots(figsize=(12, 6))

model_names = list(models)
train_mems = [params_b * 18 for params_b in models.values()]

bars = ax.bar(model_names, train_mems, color='#3b82f6')

# Reference capacities: (GB, line colour, legend label)
capacity_lines = (
    (24, '#22c55e', 'RTX 4090 (24GB)'),
    (80, '#f59e0b', 'A100-80GB'),
    (80*8, '#ef4444', '8√ó A100-80GB'),
)
for capacity_gb, line_color, legend_label in capacity_lines:
    ax.axhline(y=capacity_gb, color=line_color, linestyle='--', linewidth=2, label=legend_label)

ax.set_ylabel('Training Memory (GB)', fontsize=12)
ax.set_title('üìä Model Training Memory vs GPU Capacity', fontsize=14)
ax.set_yscale('log')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# Value labels sit 20% above each bar top (a constant offset on a log axis).
for bar, mem in zip(bars, train_mems):
    label_x = bar.get_x() + bar.get_width()/2
    label_y = bar.get_height() * 1.2
    ax.text(label_x, label_y, f'{mem:.0f}GB', ha='center', fontsize=10)

plt.tight_layout()
plt.show()

## Part 2: Parallelism Strategies

In [None]:
def explain_parallelism():
    """
    Print a short profile of four parallelism strategies.

    For each strategy the output lists a description, what each GPU holds
    in memory, what is communicated, and what the strategy suits best.
    Nothing is returned.
    """
    # Strategy name -> profile; printed in insertion order.
    strategies = {
        'Data Parallelism (DDP)': dict(
            description='Replicate model on each GPU, split batch',
            memory='Full model per GPU',
            comm='All-reduce gradients',
            best_for='Models that fit in single GPU',
        ),
        'Model Parallelism (MP)': dict(
            description='Split model layers across GPUs',
            memory='Fraction of model per GPU',
            comm='Activations between GPUs',
            best_for='Very large models',
        ),
        'Pipeline Parallelism (PP)': dict(
            description='Split model stages, micro-batch pipeline',
            memory='Fraction of model + pipeline buffers',
            comm='Activations between stages',
            best_for='Deep models',
        ),
        'Tensor Parallelism (TP)': dict(
            description='Split individual tensors across GPUs',
            memory='Fraction of each layer per GPU',
            comm='High - within each layer',
            best_for='Large layers (attention, FFN)',
        ),
    }

    print('üìä PARALLELISM STRATEGIES')
    print('=' * 70)

    for name in strategies:
        info = strategies[name]
        print(f'\nüîπ {name}')
        print(f'   Description: {info["description"]}')
        print(f'   Memory: {info["memory"]}')
        print(f'   Communication: {info["comm"]}')
        print(f'   Best for: {info["best_for"]}')


explain_parallelism()

In [None]:
# Side-by-side schematic diagrams: data parallelism (left panel) and
# model parallelism (right panel). Both panels use a 10x10 coordinate
# space with equal aspect and hidden axes, so the numbers below are
# layout coordinates only.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# ---- Left panel: data parallelism ----
ax = axes[0]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_aspect('equal')

# Top row: three 2x2 blue boxes, one per GPU, each labelled "Full Model".
# Each tuple is the lower-left corner of a box.
gpu_positions = [(1, 7), (4, 7), (7, 7)]
for i, (x, y) in enumerate(gpu_positions):
    rect = plt.Rectangle((x, y), 2, 2, fill=True, color='#3b82f6', alpha=0.8)
    ax.add_patch(rect)
    # Label is centred in the box (corner + half the side length).
    ax.text(x+1, y+1, f'GPU {i}\nFull Model', ha='center', va='center', color='white', fontsize=9)

# Bottom row: three 2x1.5 green boxes, one data batch under each GPU.
batch_positions = [(1, 3), (4, 3), (7, 3)]
for i, (x, y) in enumerate(batch_positions):
    rect = plt.Rectangle((x, y), 2, 1.5, fill=True, color='#22c55e', alpha=0.8)
    ax.add_patch(rect)
    ax.text(x+1, y+0.75, f'Batch {i}', ha='center', va='center', color='white', fontsize=9)

# One arrow per (GPU, batch) pair. annotate() draws from xytext to xy, so
# each arrow runs from the top-centre of the batch box up to the
# bottom-centre of the GPU box above it.
for (gx, gy), (bx, by) in zip(gpu_positions, batch_positions):
    ax.annotate('', xy=(gx+1, gy), xytext=(bx+1, by+1.5),
                arrowprops=dict(arrowstyle='->', color='gray', lw=2))

ax.set_title('Data Parallelism\n(Each GPU has full model)', fontsize=12)
ax.axis('off')

# ---- Right panel: model parallelism ----
ax = axes[1]
ax.set_xlim(0, 10)
ax.set_ylim(0, 10)
ax.set_aspect('equal')

# Same three box positions as the left panel, drawn in red and labelled
# with the layer range assigned to each GPU.
gpu_positions = [(1, 7), (4, 7), (7, 7)]
parts = ['Layers 1-4', 'Layers 5-8', 'Layers 9-12']
for i, ((x, y), part) in enumerate(zip(gpu_positions, parts)):
    rect = plt.Rectangle((x, y), 2, 2, fill=True, color='#ef4444', alpha=0.8)
    ax.add_patch(rect)
    ax.text(x+1, y+1, f'GPU {i}\n{part}', ha='center', va='center', color='white', fontsize=9)

# Horizontal arrows between neighbouring GPU boxes: from the right edge of
# box i (x1+2) to the left edge of box i+1 (x2), at the boxes' vertical
# midpoint (y1+1). y2 is unpacked but not used.
for i in range(len(gpu_positions)-1):
    x1, y1 = gpu_positions[i]
    x2, y2 = gpu_positions[i+1]
    ax.annotate('', xy=(x2, y1+1), xytext=(x1+2, y1+1),
                arrowprops=dict(arrowstyle='->', color='gray', lw=2))

ax.set_title('Model Parallelism\n(Model split across GPUs)', fontsize=12)
ax.axis('off')

plt.tight_layout()
plt.show()

## Part 3: ZeRO Optimization

In [None]:
def zero_memory_analysis(model_params_b, num_gpus):
    """
    Estimate per-GPU model-state memory for DDP and each ZeRO stage.

    Per-parameter byte budget: 2 (FP16 weights) + 2 (FP16 gradients) +
    12 (FP32 master copy, momentum and variance). Each ZeRO stage divides
    one more of these terms by ``num_gpus``:

        ZeRO-1: optimizer states
        ZeRO-2: optimizer states + gradients
        ZeRO-3: optimizer states + gradients + parameters

    Args:
        model_params_b: Model size in billions of parameters.
        num_gpus: Number of GPUs the partitioned state is divided across.

    Returns:
        dict: ``{'DDP', 'ZeRO-1', 'ZeRO-2', 'ZeRO-3'}`` -> GB per GPU,
        in that insertion order.
    """
    n_params = model_params_b * 1e9

    # Bytes per parameter for each kind of model state.
    weight_b, gradient_b, optimizer_b = 2, 2, 12

    # Each expression converts bytes to GB with a final division by 1e9.
    return {
        # Nothing partitioned: every GPU holds all three kinds of state.
        'DDP': n_params * (weight_b + gradient_b + optimizer_b) / 1e9,
        'ZeRO-1': n_params * (weight_b + gradient_b + optimizer_b / num_gpus) / 1e9,
        'ZeRO-2': n_params * (weight_b + (gradient_b + optimizer_b) / num_gpus) / 1e9,
        'ZeRO-3': n_params * (weight_b + gradient_b + optimizer_b) / num_gpus / 1e9,
    }

# Tabulate per-GPU memory for a 7B-parameter model on 8 GPUs, with each
# stage's reduction factor relative to plain DDP.
print('üìä ZERO OPTIMIZATION MEMORY (7B Model, 8 GPUs)')
print('=' * 60)

mems = zero_memory_analysis(model_params_b=7, num_gpus=8)
ddp_mem = mems['DDP']  # baseline for the "vs DDP" column

print(f'{"Stage":<15} {"Memory/GPU (GB)":<20} {"vs DDP":<15}')
print('-' * 50)

for stage in mems:
    mem = mems[stage]
    if stage == 'DDP':
        # The baseline row has no ratio to itself.
        savings = '-'
    else:
        savings = f'{ddp_mem/mem:.1f}x'
    print(f'{stage:<15} {mem:<20.1f} {savings:<15}')

print('\nüí° ZeRO-3 enables training models 8x larger than GPU memory!')

In [None]:
# Visualize what each GPU stores at every ZeRO stage: one subplot per
# stage, one bar per kind of model state (7B model on 8 GPUs).
fig, axes = plt.subplots(1, 4, figsize=(16, 5))

components = ['Parameters', 'Gradients', 'Optimizer']
colors = ['#3b82f6', '#22c55e', '#f59e0b']

# Unpartitioned footprint in GB: 7B params at 2 / 2 / 12 bytes per parameter.
param_gb, grad_gb, opt_gb = 14, 14, 84
n_gpus = 8

# Each stage divides one more component across the GPUs. The dict's
# insertion order is the left-to-right subplot order.
stage_vals = {
    'DDP': [param_gb, grad_gb, opt_gb],
    'ZeRO-1': [param_gb, grad_gb, opt_gb / n_gpus],
    'ZeRO-2': [param_gb, grad_gb / n_gpus, opt_gb / n_gpus],
    'ZeRO-3': [param_gb / n_gpus, grad_gb / n_gpus, opt_gb / n_gpus],
}

for ax, (stage, vals) in zip(axes, stage_vals.items()):
    ax.bar(components, vals, color=colors)
    ax.set_ylabel('Memory (GB)')
    ax.set_title(f'{stage}\nTotal: {sum(vals):.0f} GB/GPU')
    # Shared y range so bar heights are comparable across subplots.
    ax.set_ylim(0, 100)

plt.suptitle('üìä Memory per GPU for 7B Model on 8 GPUs', fontsize=14)
plt.tight_layout()
plt.show()

## Part 4: Communication Analysis

In [None]:
def communication_analysis(model_params_b, num_gpus, bandwidth_gbps=400):
    """
    Estimate per-iteration communication volume for DDP and the ZeRO stages.

    Volumes are expressed in multiples of the FP16 model size
    (``params * 2`` bytes), following the ZeRO paper's accounting:

        DDP:    2x  all-reduce of gradients
        ZeRO-1: 2x  same as DDP
        ZeRO-2: 2x  reduce-scatter of gradients + all-gather of the
                    updated parameters
        ZeRO-3: 3x  reduce-scatter of gradients + all-gather of
                    parameters for the forward and again for the backward pass

    Prints a table of volume (GB) and transfer time (ms) per strategy.

    Args:
        model_params_b: Model size in billions of parameters.
        num_gpus: Number of GPUs. Only echoed in the printed header; the
            volume estimate here does not depend on it.
        bandwidth_gbps: Link bandwidth in gigabits per second, used to
            convert bytes into a transfer time.

    Returns:
        dict: strategy name -> bytes communicated per iteration, in the
        order DDP, ZeRO-1, ZeRO-2, ZeRO-3.
    """
    params = model_params_b * 1e9
    param_bytes = 2  # FP16
    model_bytes = params * param_bytes

    # Bytes to communicate per iteration
    comm = {}

    # DDP: all-reduce gradients (counted as 2x the model size)
    comm['DDP'] = 2 * model_bytes

    # ZeRO-1: same volume as DDP
    comm['ZeRO-1'] = 2 * model_bytes

    # ZeRO-2: reduce-scatter gradients (1x) + all-gather updated params (1x).
    # Omitting the all-gather would make ZeRO-2 look cheaper than DDP.
    comm['ZeRO-2'] = model_bytes + model_bytes

    # ZeRO-3: reduce-scatter gradients (1x) + all-gather params twice
    # (forward + backward), i.e. 1.5x the DDP volume
    comm['ZeRO-3'] = model_bytes + 2 * model_bytes

    print('üìä COMMUNICATION OVERHEAD')
    print('=' * 60)
    print(f'Model: {model_params_b}B params, {num_gpus} GPUs, {bandwidth_gbps} Gbps')
    print(f'\n{"Strategy":<15} {"Comm (GB)":<15} {"Time (ms)":<15}')
    print('-' * 45)

    for strategy, bytes_comm in comm.items():
        gb = bytes_comm / 1e9
        # bytes -> bits (x8), divided by Gbit/s -> seconds, then -> ms
        time_ms = (bytes_comm * 8 / bandwidth_gbps / 1e9) * 1000
        print(f'{strategy:<15} {gb:<15.1f} {time_ms:<15.1f}')

    return comm

comm = communication_analysis(7, 8)

# Bar chart of the per-step communication volume, converted from bytes to GB.
fig, ax = plt.subplots(figsize=(10, 6))

strategies = list(comm)
comm_gb = [n_bytes / 1e9 for n_bytes in comm.values()]
bar_colors = ['#3b82f6', '#22c55e', '#f59e0b', '#ef4444']

bars = ax.bar(strategies, comm_gb, color=bar_colors)
ax.set_ylabel('Communication per Step (GB)')
ax.set_title('üìä Communication Overhead by Strategy')
ax.grid(True, alpha=0.3, axis='y')

# Value label centred just above each bar.
for bar, gb in zip(bars, comm_gb):
    x_mid = bar.get_x() + bar.get_width()/2
    y_top = bar.get_height() + 0.5
    ax.text(x_mid, y_top, f'{gb:.1f}GB', ha='center')

plt.tight_layout()
plt.show()

## Part 5: Choosing the Right Strategy

In [None]:
def strategy_recommendation(model_params_b, gpu_memory_gb, num_gpus):
    """
    Print a parallelism recommendation for a model / cluster combination.

    Training memory is estimated as 18 GB per billion parameters and
    compared against the cluster in this order:

      1. Fits on one GPU                       -> DDP
      2. Exceeds the aggregate cluster memory  -> too large; report the
         minimum GPU count
      3. Within 1.5x of one GPU                -> ZeRO-1 or ZeRO-2
      4. Otherwise                             -> ZeRO-3

    The aggregate check runs before the 1.5x check so that a model which
    cannot fit in the whole cluster (e.g. a single GPU slightly too small)
    is never told to use ZeRO-1/2.

    Args:
        model_params_b: Model size in billions of parameters.
        gpu_memory_gb: Memory of one GPU in GB.
        num_gpus: Number of GPUs available.

    Returns:
        None. The recommendation is printed.
    """
    train_mem_per_model = model_params_b * 18  # GB
    total_gpu_memory_gb = num_gpus * gpu_memory_gb

    print('üìä STRATEGY RECOMMENDATION')
    print('=' * 60)
    print(f'Model: {model_params_b}B params')
    print(f'Training memory needed: {train_mem_per_model:.0f} GB')
    print(f'Available: {num_gpus} √ó {gpu_memory_gb}GB = {num_gpus * gpu_memory_gb}GB')

    print('\nüéØ RECOMMENDATIONS:')

    if train_mem_per_model <= gpu_memory_gb:
        print('\n‚úÖ Model fits in single GPU')
        print('   ‚Üí Use DDP for fastest training')

    elif train_mem_per_model > total_gpu_memory_gb:
        # Even fully sharded, the model state cannot fit in the cluster.
        print('\n‚ùå Model too large even with ZeRO-3')
        print('   ‚Üí Need more GPUs or model parallelism')
        print(f'   ‚Üí Minimum GPUs needed: {int(np.ceil(train_mem_per_model / gpu_memory_gb))}')

    elif train_mem_per_model <= gpu_memory_gb * 1.5:
        print('\n‚ö†Ô∏è Model barely fits - tight on memory')
        print('   ‚Üí Use ZeRO-1 or ZeRO-2')

    else:
        print('\n‚ö†Ô∏è Model needs memory sharding')
        print('   ‚Üí Use ZeRO-3 (DeepSpeed/FSDP)')

# Example runs. Each tuple is
# (model size in billions of params, memory per GPU in GB, GPU count).
scenarios = [
    (1.5, 24, 4),   # GPT-2 on 4x RTX 4090
    (7, 80, 8),     # LLaMA-7B on 8x A100
    (70, 80, 8),    # LLaMA-70B on 8x A100
]

# Blank line plus a rule before each scenario's report.
separator = '\n' + '='*60
for scenario in scenarios:
    print(separator)
    strategy_recommendation(*scenario)

In [None]:
print('üéØ KEY TAKEAWAYS')
print('=' * 60)
print('\n1. Large models need distributed training across GPUs')
print('\n2. Data Parallelism: Replicate model, split data')
print('\n3. ZeRO partitions memory across GPUs (1‚Üí2‚Üí3)')
print('\n4. ZeRO-3 enables models 8x larger than GPU memory')
print('\n5. More partitioning = more communication')
print('\n6. Use DDP when possible, ZeRO when needed')
print('\n' + '=' * 60)
print('\nüìö Next: Efficient Vision Models!')