<a href="https://colab.research.google.com/github/kiankyars/Ultra-Scale-Playbook-Series/blob/main/6_ZeRO1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Ultra-Scale Playbook: Part 6 - Zero Redundancy Optimizer (ZeRO) Introduction

## Overview
This notebook covers:
- Memory redundancies in naive data parallelism
- ZeRO optimization stages (1-3)
- How ZeRO-1 partitions optimizer states
- Communication patterns in ZeRO
- Practical implementations and tradeoffs

## Key Concepts

### Memory Redundancy in Data Parallelism
In naive DP, each GPU stores:
- Full copy of model parameters
- Full copy of gradients
- Full copy of optimizer states

With N GPUs, the same model state is therefore stored N times. ZeRO aims to eliminate this redundancy.

### ZeRO Optimization Stages
ZeRO progressively eliminates redundancies:
1. **ZeRO-1**: Partitions optimizer states
2. **ZeRO-2**: Partitions gradients + optimizer states
3. **ZeRO-3**: Partitions parameters + gradients + optimizer states
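
To get a rough sense of what each stage saves, the per-GPU memory for model states can be sketched with the accounting used in the ZeRO paper. In that accounting, bf16 parameters and gradients take 2 bytes each per parameter, and mixed-precision Adam keeps k = 12 bytes of optimizer state per parameter (fp32 master weights, momentum and variance). This sketch ignores activations and fp32 gradient-accumulation buffers. `zero_memory_per_gpu` is a hypothetical helper and is not part of any library.

```python
def zero_memory_per_gpu(num_params, num_gpus, k=12):
    """Approximate per-GPU model-state memory in bytes (hypothetical helper).

    Assumes bf16 params (2 B/param), bf16 grads (2 B/param) and k B/param of
    optimizer state (fp32 master weights + Adam momentum + variance = 12 B).
    Activations and fp32 gradient-accumulation buffers are ignored.
    """
    psi, n = num_params, num_gpus
    return {
        "naive_dp": 2 * psi + 2 * psi + k * psi,    # everything replicated
        "zero1": 2 * psi + 2 * psi + k * psi / n,   # optimizer states sharded
        "zero2": 2 * psi + (2 + k) * psi / n,       # + gradients sharded
        "zero3": (2 + 2 + k) * psi / n,             # + parameters sharded
    }


for stage, nbytes in zero_memory_per_gpu(7_000_000_000, 8).items():
    print(f"{stage:>9}: {nbytes / 1024**3:6.2f} GiB")
```

Each stage divides one more term by the number of GPUs. The replicated part therefore shrinks from 16 bytes per parameter (naive DP) to 4 bytes (ZeRO-1), then 2 bytes (ZeRO-2), and finally nothing (ZeRO-3).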

Today we focus on ZeRO-1.

### ZeRO-1: Optimizer State Partitioning
- Optimizer states (fp32 master weights plus Adam momentum and variance) split across GPUs
- Each GPU only updates its portion of parameters
- Requires new communication patterns:
  - Replace all-reduce with reduce-scatter
  - Add all-gather after optimizer step
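
The steps above can be checked numerically with a single-process simulation. In this sketch, plain Python lists stand in for per-rank gradient tensors, and plain SGD stands in for Adam to keep the arithmetic exact. All function names are hypothetical. The reduce-scatter hands each rank the summed gradients for exactly the slice it owns. Updating that slice and then all-gathering should therefore give the same parameters as naive DP's all-reduce followed by a full update.

```python
def reduce_scatter(per_rank_vecs):
    """Sum vectors element-wise across ranks; rank r receives only shard r."""
    n = len(per_rank_vecs)
    total = [sum(vals) for vals in zip(*per_rank_vecs)]
    shard = len(total) // n  # assumes the length is divisible by n
    return [total[r * shard:(r + 1) * shard] for r in range(n)]


def all_gather(shards):
    """Concatenate every rank's shard; each rank ends up with the full vector."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def naive_dp_step(params, per_rank_grads, lr):
    """Naive DP: all-reduce (mean) the gradients, every rank updates everything."""
    n = len(per_rank_grads)
    avg = [sum(vals) / n for vals in zip(*per_rank_grads)]
    updated = [p - lr * g for p, g in zip(params, avg)]
    return [list(updated) for _ in range(n)]


def zero1_step(params, per_rank_grads, lr):
    """ZeRO-1: reduce-scatter gradients, update own shard, all-gather params."""
    n = len(per_rank_grads)
    shard = len(params) // n
    grad_shards = reduce_scatter(per_rank_grads)
    new_shards = []
    for r in range(n):  # rank r only touches params[r*shard:(r+1)*shard]
        own = params[r * shard:(r + 1) * shard]
        new_shards.append([p - lr * (g / n) for p, g in zip(own, grad_shards[r])])
    return all_gather(new_shards)


params = [1.0, 2.0, 3.0, 4.0]
grads = [[1.0, 2.0, 3.0, 4.0], [3.0, 2.0, 1.0, 0.0]]  # two ranks, different data
print(zero1_step(params, grads, lr=0.5))
print(naive_dp_step(params, grads, lr=0.5))
```

Both calls return identical per-rank parameter lists. The only difference is that, in the ZeRO-1 version, each rank would need optimizer state for just its own shard.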

## Memory Calculations

```python
def calculate_memory_usage(num_params, use_fp32_grad_accum=False):
    """
    Calculate memory usage for different components
    
    Args:
        num_params: Number of model parameters
        use_fp32_grad_accum: Whether to accumulate gradients in fp32
        
    Returns:
        Dictionary of memory usage in bytes
    """
    memory = {}
    
    # Model parameters in bf16/fp16
    memory['params_bf16'] = num_params * 2
    
    # Gradients in bf16/fp16
    memory['gradients_bf16'] = num_params * 2
    
    # Parameters in fp32 (for optimizer)
    memory['params_fp32'] = num_params * 4
    
    # Optimizer states (momentum + variance)
    memory['optim_states'] = num_params * 8
    
    # Gradient accumulation in fp32 (optional)
    memory['grad_accum_fp32'] = num_params * 4 if use_fp32_grad_accum else 0
    
    # Total memory
    memory['total'] = (memory['params_bf16'] + 
                      memory['gradients_bf16'] + 
                      memory['params_fp32'] + 
                      memory['optim_states'] + 
                      memory['grad_accum_fp32'])
    
    return memory
```

## ZeRO-1 Implementation Example

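```python
import torch
import torch.distributed as dist


class PartitionedOptimizer:
    """Simplified ZeRO-1-style optimizer: each rank keeps optimizer state only
    for its own slice of the parameter list (partitioned by whole tensors).

    Assumes num_partitions == world size, and that gradients have already been
    averaged across ranks (e.g. by an all-reduce) before step() is called.
    """

    def __init__(self, params, optimizer_class, num_partitions, **optim_kwargs):
        self.num_partitions = num_partitions
        self.rank = dist.get_rank() if dist.is_initialized() else 0

        # Partition parameters across ranks (no parameter is dropped)
        self.partitions = self._partition_parameters(list(params))

        # Create optimizer (and therefore optimizer state) only for this partition;
        # torch.optim rejects an empty parameter list, so skip it in that case
        local = self.partitions[self.rank]
        self.optimizer = optimizer_class(local, **optim_kwargs) if local else None

    def _partition_parameters(self, params):
        """Split parameters into num_partitions contiguous, near-equal chunks;
        the first len(params) % num_partitions chunks get one extra tensor."""
        base, extra = divmod(len(params), self.num_partitions)
        partitions, start = [], 0
        for i in range(self.num_partitions):
            end = start + base + (1 if i < extra else 0)
            partitions.append(params[start:end])
            start = end
        return partitions

    def step(self):
        """Update this rank's partition, then share the updated parameters."""
        if self.optimizer is not None:
            self.optimizer.step()

        # Param sync: each tensor is owned by exactly one rank, so broadcasting
        # it from its owner leaves every rank with the full updated model
        # (the same end state as ZeRO-1's all-gather, done tensor by tensor)
        if dist.is_initialized():
            for owner, partition in enumerate(self.partitions):
                for p in partition:
                    dist.broadcast(p.data, src=owner)

    def zero_grad(self):
        # Clear gradients of *all* parameters, not only this rank's partition
        for partition in self.partitions:
            for p in partition:
                p.grad = None
```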
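
The class above splits the parameter list by whole tensors. A single large tensor, such as an embedding matrix, can therefore leave one rank holding most of the optimizer state. A common alternative is to flatten all parameters into one buffer, pad it to a multiple of the world size, and give each rank an equal contiguous slice. The pure-Python sketch below computes those slices and maps a slice back to pieces of the original tensors. `flat_shard_ranges` and `params_in_shard` are hypothetical helpers, not a library API, and padded elements are assumed to be zeros that the update ignores.

```python
import math


def flat_shard_ranges(param_sizes, world_size):
    """Return (padded_total, [(start, end), ...]): one equal, contiguous slice
    of the flattened parameter buffer per rank. The buffer is padded so its
    length is divisible by world_size."""
    total = sum(param_sizes)
    shard = math.ceil(total / world_size)
    padded_total = shard * world_size
    return padded_total, [(r * shard, (r + 1) * shard) for r in range(world_size)]


def params_in_shard(param_sizes, shard_range):
    """List the (param_index, local_start, local_end) pieces that overlap a shard.
    Padding past the last parameter contributes no pieces."""
    start, end = shard_range
    pieces, offset = [], 0
    for i, size in enumerate(param_sizes):
        lo, hi = max(start, offset), min(end, offset + size)
        if lo < hi:
            pieces.append((i, lo - offset, hi - offset))
        offset += size
    return pieces


sizes = [1000, 10, 10, 10]  # one large tensor and three small ones
padded, ranges = flat_shard_ranges(sizes, world_size=4)
for rank, rng in enumerate(ranges):
    print(rank, rng, params_in_shard(sizes, rng))
```

With whole-tensor partitioning, rank 0 would own 1000 of these 1030 elements. With the flat layout, every rank owns 258 elements, at the cost of tensors being split across shard boundaries.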

## Interactive Exercises

### Exercise 1: Memory Savings Calculation
Calculate the memory savings from ZeRO-1 for:
- Model with 7B parameters
- 8 GPUs
- Without fp32 gradient accumulation

Compare naive DP vs ZeRO-1 memory usage

```python
# Example solution
num_params = 7_000_000_000
num_gpus = 8

mem = calculate_memory_usage(num_params)  # without fp32 gradient accumulation

# Naive DP memory (per GPU): every rank holds everything
naive_mem = mem['total']

# ZeRO-1 memory (per GPU): the optimizer state (fp32 master weights plus
# Adam momentum and variance) is sharded, so each rank keeps 1/num_gpus of it
sharded = mem['params_fp32'] + mem['optim_states']
zero_mem = naive_mem - sharded * (1 - 1 / num_gpus)

print(f"Naive DP memory per GPU: {naive_mem / 1024**3:.2f} GiB")
print(f"ZeRO-1 memory per GPU: {zero_mem / 1024**3:.2f} GiB")
print(f"Memory savings: {(naive_mem - zero_mem) / 1024**3:.2f} GiB per GPU")
```

### Exercise 2: Communication Pattern Analysis
Compare the communication volume for:
1. Naive DP (all-reduce)
2. ZeRO-1 (reduce-scatter + all-gather)

Assume:
- 1B parameters
- bf16 precision (2 bytes per parameter)
- 8 GPUs

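```python
# Example solution
num_params = 1_000_000_000
num_gpus = 8
bytes_per_param = 2
msg_bytes = num_params * bytes_per_param  # size of the bf16 gradient/parameter buffer

# Per-GPU send volume with ring collectives:
#   reduce-scatter and all-gather each move (N-1)/N of the buffer,
#   and an all-reduce is a reduce-scatter followed by an all-gather
ring_factor = (num_gpus - 1) / num_gpus

# Naive DP: one all-reduce of the gradients (~2x the buffer size)
naive_comm = 2 * ring_factor * msg_bytes

# ZeRO-1: reduce-scatter of gradients (~1x) + all-gather of updated bf16 params (~1x)
zero_comm = ring_factor * msg_bytes + ring_factor * msg_bytes

print(f"Naive DP communication: {naive_comm / 1024**3:.2f} GiB per GPU")
print(f"ZeRO-1 communication: {zero_comm / 1024**3:.2f} GiB per GPU")
print(f"Communication difference: {(naive_comm - zero_comm) / 1024**3:.2f} GiB per GPU")
```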

## ZeRO Communication Patterns

### Training Step Comparison
| Operation       | Vanilla DP         | ZeRO-1             |
|-----------------|--------------------|--------------------|
| Forward Pass    | Same               | Same               |
| Backward Pass   | Same               | Same               |
| Gradient Sync   | all-reduce         | reduce-scatter     |
| Optimizer Step  | Full update        | Partitioned update |
| Param Sync      | None               | all-gather         |

## Quiz
1. What does ZeRO-1 partition across GPUs?
   a) Parameters
   b) Gradients
   c) Optimizer states
   
2. What new communication operation does ZeRO-1 introduce?
   a) broadcast
   b) all-gather
   c) reduce
   
3. Why can't we partition activations in data parallelism?
   a) Each GPU processes different data
   b) Activations are too small
   c) It would hurt model accuracy
   
4. What's the main tradeoff in using ZeRO?
   a) Memory savings vs communication overhead
   b) Speed vs accuracy
   c) Model size vs batch size

Answers:
1. c) Optimizer states
2. b) all-gather
3. a) Each GPU processes different data
4. a) Memory savings vs communication overhead

## Summary
- ZeRO reduces memory redundancy in distributed training
- ZeRO-1 partitions optimizer states across GPUs
- Introduces new communication patterns (reduce-scatter + all-gather)
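- Provides significant memory savings at roughly the same communication volume as naive DP, since a reduce-scatter plus an all-gather moves about as much data as one all-reduce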
- Next we'll explore ZeRO-2 (gradient partitioning) and ZeRO-3 (parameter partitioning)