zero_optim_toy

A minimal, from-scratch implementation of ZeRO Stage 1 (the distributed optimizer) for educational purposes, in roughly 300 lines of core code across three layers.

Companion blog post: ZeRO-1 Distributed Optimizer — A Deep Dive

Architecture

Three-layer design following the real distributed optimizer stack:

Buffer  →  DDP  →  Distributed Optimizer
Layer                   File        Responsibility
Buffer                  buffer.py   Flat contiguous storage, padding, shard view, all-gather / reduce-scatter
DDP                     ddp.py      Param / grad buffer creation, remapping .data and .main_grad, gradient and parameter sync
Distributed Optimizer   zero.py     fp32 Adam on 1/N shard, bf16 writeback
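
A minimal sketch of the Buffer idea, under assumed names (FlatBuffer is illustrative, not the repo's exact class): one flat tensor holds everything, each rank's shard is a view into it, and the collectives run directly over that storage.

import torch
import torch.distributed as dist

class FlatBuffer:
    """One flat tensor; each rank owns a contiguous 1/N shard view of it."""

    def __init__(self, numel, dtype=torch.bfloat16, device="cuda"):
        world = dist.get_world_size()
        rank = dist.get_rank()
        # pad so the buffer divides evenly into world_size shards
        padded = ((numel + world - 1) // world) * world
        self.data = torch.zeros(padded, dtype=dtype, device=device)
        self.shard_size = padded // world
        # a slice, not a copy: the shard shares storage with the full buffer
        self.shard = self.data[rank * self.shard_size:(rank + 1) * self.shard_size]

    def reduce_scatter(self):
        # sum across ranks; each rank keeps only its reduced shard
        dist.reduce_scatter_tensor(self.shard, self.data, op=dist.ReduceOp.SUM)

    def all_gather(self):
        # rebuild the full buffer on every rank from the per-rank shards
        dist.all_gather_into_tensor(self.data, self.shard)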

What ZeRO-1 shards

Only the optimizer state (fp32 master params plus Adam m/v) is sharded. Each rank keeps full copies of the bf16 model parameters and gradients, so no forward/backward hooks are needed.
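
For a concrete picture of what a single rank holds, here is a rough sketch (variable names are illustrative, not the repo's API):

import torch

world_size, rank = 4, 0                                  # illustrative values
numel = 1024                                             # total (padded) parameter count

param_buffer = torch.zeros(numel, dtype=torch.bfloat16)  # full bf16 params on every rank
grad_buffer  = torch.zeros(numel, dtype=torch.bfloat16)  # full bf16 grads on every rank

shard_size = numel // world_size
# fp32 master copy of this rank's 1/N shard only; Adam's m/v will match its size
shard_fp32 = torch.nn.Parameter(
    param_buffer[rank * shard_size:(rank + 1) * shard_size].float()
)
optimizer = torch.optim.Adam([shard_fp32], lr=1e-3)      # optimizer state covers the shard only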

Training step

forward           full bf16 params in param_buffer
backward          grads written to param.grad
sync_grads()      copy param.grad → grad_buffer, reduce-scatter(SUM)
optimizer.step()  grad shard bf16→fp32, Adam, fp32→bf16 writeback, all-gather
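
Put together as a loop, one iteration looks roughly like this (ddp_model, optimizer, and loader are placeholders for the repo's DistributedDataParallel / DistributedOptimizer wrappers and a data loader; their construction is omitted):

for batch in loader:
    loss = ddp_model(batch).mean()   # forward over the full bf16 params in param_buffer
    loss.backward()                  # backward: grads land in param.grad
    ddp_model.sync_grads()           # copy param.grad -> grad_buffer, reduce-scatter(SUM)
    optimizer.step()                 # bf16->fp32 grad shard, Adam, fp32->bf16 writeback, all-gather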

Mixed precision flow

grad_buffer (bf16, full)
  → reduce-scatter → grad shard (bf16, P/N)
  → float() → shard_fp32.grad (fp32, P/N)
  → Adam.step() → shard_fp32 (fp32, P/N)
  → bfloat16() → param_buffer shard (bf16, P/N)
  → all-gather → param_buffer (bf16, full)
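
Condensed into code, that flow is roughly the following (a sketch with assumed names, mirroring the diagram above; the repo's actual method layout may differ):

import torch
import torch.distributed as dist

def step(grad_buffer, param_buffer, shard_fp32, adam, rank, shard_size):
    lo, hi = rank * shard_size, (rank + 1) * shard_size

    # reduce-scatter: each rank receives the summed grads for its shard (bf16, P/N)
    grad_shard = grad_buffer[lo:hi]
    dist.reduce_scatter_tensor(grad_shard, grad_buffer, op=dist.ReduceOp.SUM)

    # bf16 -> fp32: grads for the fp32 master shard
    shard_fp32.grad = grad_shard.float()

    # fp32 Adam update on the 1/N master shard
    adam.step()
    adam.zero_grad(set_to_none=True)

    # fp32 -> bf16 writeback into this rank's slice of the param buffer
    param_buffer[lo:hi].copy_(shard_fp32.detach().bfloat16())

    # all-gather: every rank ends up with the full updated bf16 params
    dist.all_gather_into_tensor(param_buffer, param_buffer[lo:hi])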

Files

__init__.py                Package exports
__main__.py                CLI entry point
model.py                   Simple multi-layer MLP (test model)
buffer.py                  Buffer (contiguous storage + shard view)
ddp.py                     DistributedDataParallel (param/grad remapping + sync)
zero.py                    DistributedOptimizer (fp32 Adam on shards)
utils.py                   Shared distributed testing utilities
test_zero.py               Multi-step training correctness tests
profile_memory.py          GPU memory profiling (baseline vs ZeRO-1)
DESIGN.md                  Detailed design document (Chinese)

Usage

# Run correctness tests (requires >= 2 GPUs)
python -m zero_optim_toy test

# Run memory profiling (requires >= 4 GPUs)
python -m zero_optim_toy profile

Or run individual scripts directly:

# all tests
python -m pytest test_zero.py -v

# single test
python -m pytest test_zero.py::TestZeROTraining::test_multi_step_2gpu -v

# memory profiling
python profile_memory.py

The tests verify that ZeRO-1 training produces parameters bit-for-bit identical to those of a single-process reference run that follows the same precision path.
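
"Bit-for-bit" means the comparison uses strict equality rather than a numerical tolerance; a check along these lines (zero_params and ref_params are placeholder lists of tensors, not the repo's helpers):

import torch

def assert_bit_exact(zero_params, ref_params):
    for p_zero, p_ref in zip(zero_params, ref_params):
        assert torch.equal(p_zero.cpu(), p_ref.cpu())   # exact match, no rtol/atol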

Requirements

  • Python >= 3.8
  • PyTorch >= 2.1
  • Multiple CUDA GPUs
