Describe the bug
After training with ZeRO Stage 3 + BF16, we need to extract Adam optimizer moments (exp_avg, exp_avg_sq) from saved checkpoint files for downstream analysis (e.g. loss landscape studies). There is no documented public API to do this safely.
Each rank's optimizer-state file has this structure:
optimizer_state_dict
├── optimizer_state_dict
│   └── state
│       └── 0                  ← single entry for ALL params
│           ├── exp_avg        Tensor[33_554_432] float32
│           └── exp_avg_sq     Tensor[33_554_432] float32
└── fp32_flat_groups           [Tensor[33_554_432]]
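To compare this layout against other checkpoints (for example ones with more parameter groups), a small key walker can print the same kind of tree from a loaded file. This is a generic sketch, not a DeepSpeed utility; `describe` is our own name, and it only assumes tensor leaves expose `.shape` and `.dtype`.

```python
def describe(obj, name="root", depth=0, lines=None):
    """Recursively list dict/list keys; show shape and dtype for tensor-like leaves."""
    if lines is None:
        lines = []
    indent = "    " * depth
    if isinstance(obj, dict):
        lines.append(f"{indent}{name}")
        for key, value in obj.items():
            describe(value, str(key), depth + 1, lines)
    elif isinstance(obj, (list, tuple)):
        lines.append(f"{indent}{name} [{len(obj)} items]")
        for i, value in enumerate(obj):
            describe(value, f"[{i}]", depth + 1, lines)
    elif hasattr(obj, "shape") and hasattr(obj, "dtype"):
        # Tensor-like leaf: report size and dtype instead of the values
        lines.append(f"{indent}{name} Tensor{list(obj.shape)} {obj.dtype}")
    else:
        lines.append(f"{indent}{name} {type(obj).__name__}")
    return lines
```

Usage, with `path` pointing at one rank's optimizer-state file: `print("\n".join(describe(torch.load(path, map_location="cpu", weights_only=False))))`. `weights_only=False` is needed on newer PyTorch releases because the file also stores non-tensor objects.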
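Our current extraction approach (shards is the list of per-rank optimizer-state dicts loaded with torch.load, in rank order 0, 1, 2, ...):
import torch

# 1. Concatenate the flat exp_avg shards across all ranks
full_exp_avg = torch.cat([
    shard["optimizer_state_dict"]["optimizer_state_dict"]["state"][0]["exp_avg"]
    for shard in shards
])

# 2. Slice by model parameter sizes in named_parameters() order.
# model must be unpartitioned: on a ZeRO-3 partitioned module p.numel() is 0 (p.ds_numel holds the full size).
param_exp_avg = {}
offset = 0
for name, p in model.named_parameters():
    numel = p.numel()
    param_exp_avg[name] = full_exp_avg[offset:offset + numel]
    offset += numel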
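To make the ordering question concrete, here is a toy simulation (pure Python, no DeepSpeed) of two layouts we can imagine. Layout A concatenates all parameters, pads, then splits evenly across ranks. Layout B splits each parameter across ranks on its own and pads each one separately, which is how we read the ZeRO-3 branch of DeepSpeed's zero_to_fp32.py. Both layouts and all helper names are our assumptions, not documented DeepSpeed behavior. The point is that the concat-then-slice extraction only recovers correct values under layout A; under layout B it fails silently, returning correctly sized but wrong slices.

```python
PAD = None  # marks padding slots in this toy simulation


def shard_flatten_then_split(params, world_size):
    """Hypothetical layout A: concatenate all params, pad, split evenly."""
    flat = [v for values in params.values() for v in values]
    per_rank = (len(flat) + world_size - 1) // world_size  # ceil division
    flat += [PAD] * (per_rank * world_size - len(flat))
    return [flat[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def shard_per_param(params, world_size):
    """Hypothetical layout B: split each param across ranks, padding each one."""
    shards = [[] for _ in range(world_size)]
    for values in params.values():
        chunk = (len(values) + world_size - 1) // world_size  # ceil division
        padded = values + [PAD] * (chunk * world_size - len(values))
        for r in range(world_size):
            shards[r].extend(padded[r * chunk:(r + 1) * chunk])
    return shards


def naive_extract(shards, params):
    """The concat-then-slice approach from this issue, applied to toy shards."""
    full = [v for shard in shards for v in shard]
    out, offset = {}, 0
    for name, values in params.items():
        out[name] = full[offset:offset + len(values)]
        offset += len(values)
    return out


params = {"a": ["a0", "a1", "a2"], "b": ["b0", "b1"]}
for layout in (shard_flatten_then_split, shard_per_param):
    recovered = naive_extract(layout(params, 2), params)
    print(layout.__name__, recovered == params, recovered)
# shard_flatten_then_split True {'a': ['a0', 'a1', 'a2'], 'b': ['b0', 'b1']}
# shard_per_param False {'a': ['a0', 'a1', 'b0'], 'b': ['a2', None]}
```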
We are not certain this ordering is correct. Specifically:
Does DeepSpeed flatten parameters in model.named_parameters() order, or does it use a different internal ordering?
Is param_shapes in zero_pp_rank_*_mp_rank_00_model_states.pt always the authoritative ordering to use for slicing?
Is concatenating rank shards in rank order (0, 1, 2, ...) always correct, or does it depend on partition_count or other metadata?
Is the padding always appended at the end of the last shard, making it safe to ignore trailing elements after slicing by total param count?
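One hypothesis for the questions above, based on our reading of how DeepSpeed's zero_to_fp32.py rebuilds fp32 weights for ZeRO-3 (not on any documented guarantee): each rank's fp32_flat_groups entries are first concatenated in order; every parameter is then split across all ranks on its own, each rank holding ceil(numel / world_size) elements of it, with that parameter's padding at the tail of its last slice(s); and parameters follow the order of param_shapes. If exp_avg and exp_avg_sq share the layout of fp32_flat_groups (they have the same length in our checkpoint), the moments could be rebuilt as sketched below. The function names are hypothetical, and the code uses plain lists so the index arithmetic can be checked without DeepSpeed.

```python
def partitioned_param_info(numel, world_size):
    """Per-rank slot count and padding for one parameter (our reading of zero_to_fp32.py)."""
    partitioned_numel = (numel + world_size - 1) // world_size  # ceil division
    padding = partitioned_numel * world_size - numel
    return partitioned_numel, padding


def reassemble_moments(rank_flat, param_numels, world_size):
    """Rebuild per-parameter 1-D moment slices from per-rank flat shards.

    rank_flat: one flat sequence per rank, in rank order, with that rank's
        sub-group tensors already concatenated.
    param_numels: ordered mapping name -> numel, hypothetically taken from
        param_shapes in the model_states files rather than named_parameters().
    Returns (moments_by_name, per_rank_elements_consumed).
    """
    if len(rank_flat) != world_size:
        raise ValueError("need exactly one flat shard per rank, in rank order")
    out, offset = {}, 0
    for name, numel in param_numels.items():
        part, _ = partitioned_param_info(numel, world_size)
        pieces = []
        for r in range(world_size):
            # Rank r holds the r-th chunk of this parameter at the same offset
            pieces.extend(rank_flat[r][offset:offset + part])
        out[name] = pieces[:numel]  # drop this parameter's trailing padding
        offset += part
    return out, offset
```

With real tensors the inner loop would become `torch.cat([rank_flat[r].narrow(0, offset, part) for r in range(world_size)])[:numel].view(shape)`. The returned offset should not exceed each rank's shard length; any remainder would be extra alignment padding at the end of the flat partition.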
Contrast with standard checkpoint
A standard (non-ZeRO) optimizer.pt has one state[i] entry per parameter, making moment extraction trivial. ZeRO-3's flat sharding makes this non-obvious and undocumented.
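For comparison, with a regular PyTorch optimizer checkpoint the mapping back to names takes a few lines: `state` is keyed by integer parameter index, and `param_groups[*]["params"]` lists those indices in the order the parameters were handed to the optimizer. A minimal sketch; `moments_by_name` is our own helper, and it assumes `param_names` follows that same order.

```python
def moments_by_name(optim_state_dict, param_names):
    """Map a standard torch optimizer state_dict's per-parameter entries to names.

    param_names must list parameters in the order they were passed to the
    optimizer, e.g. [n for n, _ in model.named_parameters()] when the optimizer
    was built from model.parameters().
    """
    # Indices are numbered consecutively across param_groups, in group order
    ids = [i for group in optim_state_dict["param_groups"] for i in group["params"]]
    if len(ids) != len(param_names):
        raise ValueError("parameter count mismatch between optimizer and names")
    # Parameters that never received a step have no state entry
    return {name: optim_state_dict["state"].get(i, {}) for name, i in zip(param_names, ids)}
```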
Feature request
Please consider one or more of the following:
A public utility function (analogous to get_fp32_state_dict_from_zero_checkpoint) that reassembles optimizer moments from ZeRO-3 checkpoint files post-hoc, without needing a live DeepSpeed engine.
Documentation of the exact flat parameter ordering and shard layout so users can safely implement their own extraction.
A verification utility to confirm that a given model's named_parameters() order matches the flat shard ordering in a checkpoint.
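For the verification request, a check could be as simple as comparing names and shapes position by position. A sketch, assuming `param_shapes` in the model_states file is a list of name→shape dicts (one per parameter group), which is how the ZeRO-3 path of zero_to_fp32.py appears to treat it; `check_param_order` is a hypothetical name.

```python
def check_param_order(model_shapes, ckpt_shapes):
    """Compare a model's (name, shape) order with a checkpoint's param_shapes.

    model_shapes: list of (name, shape) pairs, e.g.
        [(n, tuple(p.shape)) for n, p in model.named_parameters()]
        taken from an unpartitioned model.
    ckpt_shapes: list of name -> shape dicts, one per parameter group (assumed).
    Returns human-readable mismatch descriptions; an empty list means they agree.
    """
    flat_ckpt = [(n, tuple(s)) for group in ckpt_shapes for n, s in group.items()]
    problems = []
    if len(flat_ckpt) != len(model_shapes):
        problems.append(f"count: model has {len(model_shapes)}, checkpoint has {len(flat_ckpt)}")
    for i, ((mn, ms), (cn, cs)) in enumerate(zip(model_shapes, flat_ckpt)):
        if mn != cn or tuple(ms) != cs:
            problems.append(f"position {i}: model {mn}{tuple(ms)} vs checkpoint {cn}{cs}")
    return problems
```

Parameters with requires_grad=False show up in named_parameters() but presumably not in param_shapes, so filtering on p.requires_grad before comparing may be needed.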
ds_report output
gds .................... [NO] ....... [NO]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.4
[WARNING] using untested triton version (3.0.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
DeepSpeed general environment info:
torch version .................... 2.4.0+cu118
deepspeed info ................... 0.15.0, unknown, unknown
torch cuda version ............... 11.8
torch hip version ................ None
nvcc version ..................... 13.1
deepspeed wheel compiled w. ...... torch 2.4, cuda 11.8
shared memory (/dev/shm) size .... 503.87 GB
System info
- GPU count and types: 1 machine with 4 A100s
- Python version: 3.11