
# Ultra-Scale Playbook: Part 7 - Advanced ZeRO and Memory Optimization

## Overview
This notebook covers:
- ZeRO-2: Gradient partitioning
- ZeRO-3/Fully Sharded Data Parallel (FSDP): Parameter partitioning
- Communication patterns in advanced ZeRO
- Activation memory challenges
- Practical implementations and tradeoffs

## Key Concepts

### ZeRO-2: Gradient Partitioning
- Extends ZeRO-1 by also partitioning gradients
- Each GPU only stores gradients for its parameter partition
- Memory savings: up to 8x vs. the naive data-parallel baseline as the number of GPUs grows (compared to up to 4x with ZeRO-1)
- Communication pattern same as ZeRO-1 (reduce-scatter + all-gather)
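The gradient partitioning described above can be sketched in a single process. This is a minimal simulation, not a real collective: `reduce_scatter` is a hypothetical helper written for illustration, the four "ranks" are just Python lists, and the gradient values are made up.

```python
def reduce_scatter(local_grads):
    """Simulate reduce-scatter across ranks in a single process.

    local_grads[r] is the full gradient list computed by rank r.
    Returns one shard per rank: rank r keeps only the summed
    gradients for its own slice of the parameters.
    """
    num_ranks = len(local_grads)
    num_params = len(local_grads[0])
    assert num_params % num_ranks == 0, "this sketch assumes an even split"
    shard_size = num_params // num_ranks

    shards = []
    for rank in range(num_ranks):
        start = rank * shard_size
        end = start + shard_size
        # Sum this slice over every rank's local gradient
        shard = [sum(g[i] for g in local_grads) for i in range(start, end)]
        shards.append(shard)
    return shards


# Four hypothetical ranks, eight parameters; rank r holds gradient r + 1 everywhere
local_grads = [[float(r + 1)] * 8 for r in range(4)]
grad_shards = reduce_scatter(local_grads)

for rank, shard in enumerate(grad_shards):
    print(f"rank {rank} keeps {len(shard)} of 8 gradients: {shard}")
```

After the call, each rank holds only a quarter of the reduced gradients, which is where the extra memory saving over ZeRO-1 comes from.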

### ZeRO-3/FSDP: Parameter Partitioning
- Partitions parameters across GPUs
- Requires on-demand gathering of parameters during forward/backward
- Uses prefetching to overlap communication with computation
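- Increases communication volume by roughly 50% relative to ZeRO-1/2, because parameters must be all-gathered again during the backward pass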
- Enables training models that don't fit on single GPU

### Communication Patterns
| Operation       | ZeRO-1/2          | ZeRO-3/FSDP       |
|----------------|-------------------|-------------------|
| Forward Pass   | Full params       | Layer-by-layer all-gather |
| Backward Pass  | Full params       | Layer-by-layer all-gather |
| Gradient Sync  | reduce-scatter    | reduce-scatter    |
| Param Update   | Partitioned, then all-gather | Partitioned       |
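The layer-by-layer all-gather in the ZeRO-3 column can be illustrated by counting how many parameters one GPU holds at any moment. This is a bookkeeping sketch only: `peak_resident_params` is a hypothetical helper, it assumes layer sizes divide evenly across GPUs, and it ignores prefetching (which would keep a second gathered layer alive).

```python
def peak_resident_params(layer_sizes, num_gpus):
    """Count parameters held by one GPU during a ZeRO-3-style forward pass.

    Returns (always-resident shard total, peak while one layer is gathered).
    """
    # Each GPU permanently holds a 1/num_gpus shard of every layer
    shard_total = sum(size // num_gpus for size in layer_sizes)
    peak = shard_total
    for size in layer_sizes:
        own_shard = size // num_gpus
        # All-gather: the local shard is temporarily replaced by the full layer
        resident = shard_total - own_shard + size
        peak = max(peak, resident)
        # After the layer runs, the gathered copy is freed again
    return shard_total, peak


# Hypothetical model: 4 layers of 1000 parameters each, sharded over 4 GPUs
layer_sizes = [1000, 1000, 1000, 1000]
shard_total, peak = peak_resident_params(layer_sizes, num_gpus=4)

print(f"full model:        {sum(layer_sizes)} params")
print(f"resident shards:   {shard_total} params")
print(f"peak during layer: {peak} params")
```

The peak is the sharded total plus roughly one full layer, rather than the whole model as in ZeRO-1/2.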

## Memory Optimization Implementation

In [None]:
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Simplified FSDP-style layer implementation
class ShardedLinear(torch.nn.Module):
    def __init__(self, input_size, output_size, num_partitions):
        super().__init__()
        self.num_partitions = num_partitions
        self.rank = dist.get_rank() if dist.is_initialized() else 0
        
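        # Partition the weight matrix along the output dimension
        assert output_size % num_partitions == 0, "output_size must divide evenly"
        partition_size = output_size // num_partitions
        # Each rank stores only its own (partition_size, input_size) shard
        self.shard = torch.nn.Parameter(
            torch.randn(partition_size, input_size) * 0.01
        )
        
    def forward(self, x):
        # In real FSDP, the shards would be all-gathered into the full weight here.
        # This simplified version uses only the local shard, so the output has
        # partition_size features rather than output_size.
        return x @ self.shard.T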
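The gap between the simplified layer and real FSDP can be shown without a process group. The sketch below uses plain Python lists instead of tensors so it runs anywhere; `linear` is a hypothetical helper and the weight values are arbitrary. It compares the local-shard output with the output obtained after rebuilding the full weight, which is what an all-gather would provide.

```python
def linear(x, weight):
    """Compute x @ weight.T for plain nested lists."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in weight]
            for row in x]


# Hypothetical full weight with output_size=4, input_size=2
full_weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
x = [[3.0, 4.0]]  # batch of one input

num_partitions = 2
partition_size = len(full_weight) // num_partitions
shards = [full_weight[r * partition_size:(r + 1) * partition_size]
          for r in range(num_partitions)]

# What the simplified forward does on each rank: use only the local shard
local_outputs = [linear(x, shard) for shard in shards]

# What an all-gather would do: rebuild the full weight from every rank's shard
gathered_weight = [row for shard in shards for row in shard]
full_output = linear(x, gathered_weight)

print("local outputs:", local_outputs)
print("full output:  ", full_output)
```

Each local output is a slice of the full output, so a rank needs the gathered weight (or the other ranks' outputs) to produce the complete result.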

In [None]:
def calculate_memory_savings(num_params, num_gpus, zero_stage):
    """
    Calculate memory savings for different ZeRO stages
    
    Args:
        num_params: Number of model parameters
        num_gpus: Number of GPUs
        zero_stage: 1, 2, or 3
        
    Returns:
        Memory usage per GPU in GB
    """
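    # Bytes per parameter under naive DP with mixed-precision Adam
    param_bytes = 2      # bf16 parameters
    grad_bytes = 2       # bf16 gradients
    optim_bytes = 4 + 8  # fp32 master params + Adam momentum and variance
    base_mem = num_params * (param_bytes + grad_bytes + optim_bytes)
    
    if zero_stage == 1:
        # Optimizer states (including fp32 master params) partitioned
        sharded_bytes = optim_bytes
    elif zero_stage == 2:
        # Optimizer states + gradients partitioned
        sharded_bytes = optim_bytes + grad_bytes
    elif zero_stage == 3:
        # Everything partitioned
        sharded_bytes = optim_bytes + grad_bytes + param_bytes
    else:
        raise ValueError(f"zero_stage must be 1, 2, or 3, got {zero_stage}")
    
    # Only a 1/num_gpus share of the partitioned bytes stays on each GPU
    mem = base_mem - num_params * sharded_bytes * (1 - 1 / num_gpus)
    return mem / (1024**3)  # Convert bytes to GB (binary, 2**30 bytes)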

## Interactive Exercises

### Exercise 1: Memory Savings Comparison
Compare memory usage for:
- 13B parameter model
- 8 GPUs
- ZeRO stages 1, 2, and 3

Calculate the memory per GPU for each stage.

In [None]:
# Your solution here
num_params = 13_000_000_000
num_gpus = 8

for stage in [1, 2, 3]:
    mem_gb = calculate_memory_savings(num_params, num_gpus, stage)
    print(f"ZeRO-{stage} memory per GPU: {mem_gb:.2f} GB")

### Exercise 2: Communication Overhead Estimation
Estimate communication overhead for ZeRO-3 with:
- 24 layers
- 1B parameters per layer
- bf16 precision (2 bytes per parameter)

Calculate the total communication volume for one training step.

In [None]:
# Your solution here
num_layers = 24
params_per_layer = 1_000_000_000
bytes_per_param = 2

# ZeRO-3 does 2*L-1 all-gathers (forward + backward)
all_gathers = 2 * num_layers - 1
all_gather_comm = all_gathers * params_per_layer * bytes_per_param

# Plus a reduce-scatter of the gradients, covering every layer
reduce_scatter_comm = num_layers * params_per_layer * bytes_per_param

total_comm = (all_gather_comm + reduce_scatter_comm) / (1024**3)
print(f"Total communication per step: {total_comm:.2f} GB")

## Activation Memory Challenges
- ZeRO doesn't partition activations
- Activation memory grows with:
  - Batch size
  - Sequence length
  - Model width (hidden size)
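- With large batch sizes or long sequences, activations can dominate memory usage even when parameters, gradients, and optimizer states are fully sharded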
- Solutions (coming in next parts):
  - Tensor parallelism (partition activations)
  - Activation checkpointing
  - Sequence parallelism
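To get a feel for how activation memory scales, here is a rough estimator. It assumes the commonly cited mixed-precision estimate for a standard transformer layer without recomputation, `seq * batch * hidden * (34 + 5 * heads * seq / hidden)` bytes per layer. That formula is not stated in this notebook, and the model configuration below is hypothetical.

```python
def activation_bytes(num_layers, seq_len, batch_size, hidden_size, num_heads):
    """Rough activation memory in bytes for a transformer (assumed formula).

    Per layer: seq * batch * hidden * (34 + 5 * heads * seq / hidden).
    The second term comes from attention scores and grows with seq_len squared.
    """
    per_layer = seq_len * batch_size * hidden_size * (
        34 + 5 * num_heads * seq_len / hidden_size
    )
    return num_layers * per_layer


# Hypothetical configuration: 32 layers, hidden size 4096, 32 heads, batch size 1
for seq_len in [2048, 8192, 32768]:
    gib = activation_bytes(32, seq_len, 1, 4096, 32) / (1024**3)
    print(f"seq_len={seq_len:>6}: ~{gib:.1f} GiB of activations")
```

Batch size and hidden size enter linearly, while sequence length also enters quadratically through the attention term. None of this shrinks when more data-parallel GPUs are added, which is why ZeRO alone does not address it.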

## Quiz
1. What does ZeRO-2 add beyond ZeRO-1?
   a) Parameter partitioning
   b) Gradient partitioning
   c) Activation partitioning
   
2. How does ZeRO-3/FSDP handle parameters during forward pass?
   a) Keeps full copy on each GPU
   b) Gathers layer parameters when needed
   c) Uses CPU offloading
   
3. What is the main remaining memory bottleneck with ZeRO-3?
   a) Optimizer states
   b) Gradients
   c) Activations
   
4. What technique helps overlap communication in ZeRO-3?
   a) Prefetching
   b) Compression
   c) Quantization

Answers:
1. b) Gradient partitioning
2. b) Gathers layer parameters when needed
3. c) Activations
4. a) Prefetching

## Summary
- ZeRO-2 adds gradient partitioning to optimizer state partitioning
- ZeRO-3/FSDP adds parameter partitioning for maximum memory savings
- ZeRO-3 trades memory for extra communication, which prefetching helps hide behind computation
- Activation memory remains a key challenge
- Next we'll explore tensor parallelism for activation partitioning