# Lecture 14: Distributed Training

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/efficientml-course/efficientml_course/14_distributed_training/demo.ipynb)

Data parallelism, ZeRO optimization, and FSDP.


In [None]:
!pip install torch -q
import torch

# ZeRO Memory Analysis
def zero_memory_per_gpu(model_params_b, num_gpus, dtype_bytes=2, optimizer_bytes_per_param=12):
    """Estimate per-GPU memory (GB) for model states under DDP and ZeRO stages 1-3.

    Uses the model-state accounting for mixed-precision Adam:
    - half-precision parameters and gradients (``dtype_bytes`` each per parameter);
    - FP32 optimizer state: a master copy of the weights, the Adam momentum and
      the Adam variance. That is 4 + 4 + 4 = 12 bytes per parameter ("K = 12").

    Activations, temporary buffers and allocator fragmentation are NOT included.

    Args:
        model_params_b: Model size in billions of parameters.
        num_gpus: Number of data-parallel GPUs the sharded states are split over.
        dtype_bytes: Bytes per element of the training-precision parameters and
            gradients (2 for FP16/BF16).
        optimizer_bytes_per_param: Bytes of optimizer state per parameter.
            Defaults to 12 (FP32 master weights + momentum + variance).
            Pass 8 to count only momentum and variance.

    Returns:
        dict with keys 'DDP', 'ZeRO-1', 'ZeRO-2', 'ZeRO-3' (in that order),
        mapping each to memory per GPU in GB (1 GB = 1e9 bytes).

    Raises:
        ValueError: if ``num_gpus`` is less than 1.
    """
    if num_gpus < 1:
        raise ValueError(f"num_gpus must be >= 1, got {num_gpus}")

    P = model_params_b * 1e9  # total parameter count

    # Full (unsharded) size of each model-state component, in bytes.
    model_mem = P * dtype_bytes  # half-precision weights used in forward/backward
    grad_mem = P * dtype_bytes   # half-precision gradients
    # Mixed-precision Adam also needs an FP32 master copy of the weights on top
    # of momentum + variance; omitting it understates memory by 4 B/param.
    optimizer_mem = P * optimizer_bytes_per_param

    gb = 1e9
    return {
        # No ZeRO (DDP): every GPU holds a full replica of all model states.
        'DDP': (model_mem + grad_mem + optimizer_mem) / gb,
        # ZeRO-1: optimizer states are partitioned across GPUs.
        'ZeRO-1': (model_mem + grad_mem + optimizer_mem / num_gpus) / gb,
        # ZeRO-2: gradients are partitioned as well.
        'ZeRO-2': (model_mem + grad_mem / num_gpus + optimizer_mem / num_gpus) / gb,
        # ZeRO-3: parameters are partitioned too, so every component is sharded.
        'ZeRO-3': (model_mem / num_gpus + grad_mem / num_gpus + optimizer_mem / num_gpus) / gb,
    }

# Example: 7B model on 8 GPUs
model_size = 7  # billion params
num_gpus = 8
gpu_mem_gb = 80  # assumed per-GPU memory capacity in GB used for the "fits" check

results = zero_memory_per_gpu(model_size, num_gpus)

print(f"Memory per GPU for {model_size}B model on {num_gpus} GPUs:")
print("=" * 45)
for stage, mem in results.items():
    # One bar character per 5 GB of per-GPU memory.
    bar = "█" * int(mem / 5)
    fits = "✓" if mem < gpu_mem_gb else "✗"
    print(f"{stage:8} | {mem:>6.1f} GB | {fits} | {bar}")

print("\n🎯 ZeRO-3 enables training models that don't fit on single GPU!")
