A minimal, from-scratch implementation of ZeRO Stage 1 (Distributed Optimizer) for educational purposes. ~300 lines of core code across three layers.
Companion blog post: ZeRO-1 Distributed Optimizer — A Deep Dive
Three-layer design following the real distributed optimizer stack:
Buffer → DDP → Distributed Optimizer
| Layer | File | Responsibility |
|---|---|---|
| Buffer | buffer.py | Flat contiguous storage, padding, shard view, all-gather / reduce-scatter |
| DDP | ddp.py | Param / grad buffer creation, remapping .data and .main_grad, gradient and parameter sync |
| Distributed Optimizer | zero.py | fp32 Adam on the 1/N shard, bf16 writeback |
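The Buffer layer is the foundation of the stack. Below is a minimal sketch of the idea, with illustrative names (Buffer, get_shard, the padding scheme); the real buffer.py may differ in details:

```python
import torch
import torch.distributed as dist

class Buffer:
    """Flat contiguous storage with a per-rank shard view (illustrative sketch)."""

    def __init__(self, numel, dtype, world_size, rank):
        # Pad so the flat buffer divides evenly into world_size equal shards.
        self.shard_numel = (numel + world_size - 1) // world_size
        self.numel_padded = self.shard_numel * world_size
        self.data = torch.zeros(self.numel_padded, dtype=dtype, device="cuda")
        self.rank = rank

    def get_shard(self):
        # This rank's contiguous 1/N slice, returned as a view (no copy).
        start = self.rank * self.shard_numel
        return self.data[start:start + self.shard_numel]

    def reduce_scatter(self):
        # Sum-reduce the full buffer across ranks; each rank receives only its own shard.
        dist.reduce_scatter_tensor(self.get_shard(), self.data, op=dist.ReduceOp.SUM)

    def all_gather(self):
        # Rebuild the full buffer from every rank's shard.
        dist.all_gather_into_tensor(self.data, self.get_shard())
```

Padding to a multiple of world_size keeps every shard the same size, which is what allows reduce_scatter_tensor and all_gather_into_tensor to operate directly on views of the one flat tensor.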
Only optimizer state (fp32 master params + Adam m/v) is sharded. Each rank keeps full copies of bf16 model parameters and gradients — no forward/backward hooks needed.
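To make the savings concrete, a rough back-of-the-envelope using the usual mixed-precision accounting (2 bytes per parameter for bf16 params, 2 for bf16 grads, 12 for the fp32 master copy plus Adam m/v); the exact numbers for this toy model will differ:

```python
# Rough per-rank memory for P parameters across N ranks (bytes), not measured:
P, N = 1_000_000_000, 4
bf16_params, bf16_grads = 2 * P, 2 * P      # replicated on every rank
fp32_master_and_adam = 12 * P               # fp32 params + Adam m + Adam v
baseline = bf16_params + bf16_grads + fp32_master_and_adam        # 16 GB, everything replicated
zero1 = bf16_params + bf16_grads + fp32_master_and_adam // N      # 7 GB, optimizer state sharded
```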
| Phase | What happens |
|---|---|
| forward | full bf16 params in param_buffer |
| backward | grads written to param.grad |
| sync_grads() | copy param.grad → grad_buffer, reduce-scatter(SUM) |
| optimizer.step() | grad shard bf16→fp32, Adam, fp32→bf16 writeback, all-gather |
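Put together, one iteration looks roughly like the sketch below. DistributedDataParallel, DistributedOptimizer, sync_grads(), and step() come from the tables above; the constructor arguments, the MLP input shape, and zero_grad() are assumptions:

```python
import torch
import torch.distributed as dist
from zero_optim_toy import DistributedDataParallel, DistributedOptimizer  # assumed exports
from zero_optim_toy.model import MLP                                      # assumed class name

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = DistributedDataParallel(MLP().cuda().bfloat16())  # flat param/grad buffers built here
optim = DistributedOptimizer(model, lr=1e-3)              # fp32 Adam state for this rank's shard only

x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16)  # assumed input shape
for _ in range(10):
    loss = model(x).mean()   # forward: full bf16 params read from param_buffer
    loss.backward()          # backward: grads written to param.grad
    model.sync_grads()       # param.grad -> grad_buffer, reduce-scatter(SUM)
    optim.step()             # fp32 Adam on shard, bf16 writeback, all-gather
    model.zero_grad()        # assumed helper; clears param.grad for the next step
```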
```
grad_buffer (bf16, full)
  → reduce-scatter → grad shard         (bf16, P/N)
  → float()        → shard_fp32.grad    (fp32, P/N)
  → Adam.step()    → shard_fp32         (fp32, P/N)
  → bfloat16()     → param_buffer shard (bf16, P/N)
  → all-gather     → param_buffer       (bf16, full)
```
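The same pipeline in code form, as a sketch of a step that follows this data path. grad_buffer, param_buffer, and shard_fp32 mirror the diagram; get_shard() and the free-standing function are illustrative rather than the actual zero.py API:

```python
import torch
import torch.distributed as dist

def zero1_step(grad_buffer, param_buffer, shard_fp32, adam):
    """One ZeRO-1 update following the diagram above (illustrative, not the real zero.py)."""
    # 1. reduce-scatter: sum grads across ranks, keep only this rank's P/N shard (bf16)
    grad_shard_bf16 = grad_buffer.get_shard()
    dist.reduce_scatter_tensor(grad_shard_bf16, grad_buffer.data, op=dist.ReduceOp.SUM)

    # 2. bf16 -> fp32: attach the grad shard to the fp32 master shard
    shard_fp32.grad = grad_shard_bf16.float()

    # 3. plain fp32 Adam on the 1/N master shard, e.g. torch.optim.Adam([shard_fp32])
    adam.step()

    # 4. fp32 -> bf16 writeback into this rank's slice of the full param buffer
    param_buffer.get_shard().copy_(shard_fp32.detach().to(torch.bfloat16))

    # 5. all-gather: every rank ends up with the full updated bf16 parameters
    dist.all_gather_into_tensor(param_buffer.data, param_buffer.get_shard())
```

Because the all-gather writes back into the same flat param_buffer that the model's remapped .data views point at, every rank's parameters are already up to date for the next forward pass.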
| File | Description |
|---|---|
| __init__.py | Package exports |
| __main__.py | CLI entry point |
| model.py | Simple multi-layer MLP (test model) |
| buffer.py | Buffer (contiguous storage + shard view) |
| ddp.py | DistributedDataParallel (param/grad remapping + sync) |
| zero.py | DistributedOptimizer (fp32 Adam on shards) |
| utils.py | Shared distributed testing utilities |
| test_zero.py | Multi-step training correctness tests |
| profile_memory.py | GPU memory profiling (baseline vs ZeRO-1) |
| DESIGN.md | Detailed design document (Chinese) |
```bash
# Run correctness tests (requires >= 2 GPUs)
python -m zero_optim_toy test

# Run memory profiling (requires >= 4 GPUs)
python -m zero_optim_toy profile
```

Or run individual scripts directly:
```bash
# all tests
python -m pytest test_zero.py -v

# single test
python -m pytest test_zero.py::TestZeROTraining::test_multi_step_2gpu -v

# memory profiling
python profile_memory.py
```

The tests verify that ZeRO-1 training produces parameters bit-for-bit identical to those of single-process reference training that follows the same precision path.
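The check itself is conceptually a strict equality over parameters (a sketch; zero_model and reference_model are hypothetical names, and the real assertions live in test_zero.py):

```python
import torch

# After the same number of steps on the same data, compare every parameter bit for bit.
for p_zero, p_ref in zip(zero_model.parameters(), reference_model.parameters()):
    assert torch.equal(p_zero.data, p_ref.data)  # exact equality, not torch.allclose
```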
- Python >= 3.8
- PyTorch >= 2.1
- Multiple CUDA GPUs