`_wrap_state_as_dtensor` rejects parameters whose shard dimension isn't evenly divisible by the FSDP mesh size, but FSDP2 handles uneven sharding natively via padding. This makes checkpoint saving fail for any model with small parameters (e.g. MoE gates, Mamba parameters) whenever the world size exceeds the parameter's shard dimension.
```
ValueError: DCP checkpointing requires evenly-sharded parameters, but parameter
with shape torch.Size([8, 128]) is unevenly sharded on dim 0 across 32 ranks.
```
## Reproduce

- Use `FlashAdamW` with any FSDP2-sharded model that has a parameter with `shape[shard_dim] < world_size` (e.g. Nemotron-H 30B, which has `[8, 128]` parameters from its Mamba layers)
- Run on 32 GPUs
- Save a checkpoint — triggers `optimizer.state_dict()` → `_state_dict_for_param()` → `_wrap_state_as_dtensor()` → `ValueError`
## Root cause

In `optimizers.py:654`:

```python
state[key] = DTensor.from_local(val, mesh, placements)
```

Without an explicit `shape=`, `DTensor.from_local` infers `global_shape = local_shape * mesh_size`, which is wrong for padded (uneven) shards. The check on lines 635-646 preemptively rejects this case with a `ValueError`.
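To see the mismatch concretely, here is a minimal sketch (pure Python, no torch required; `fsdp2_local_shape` is a hypothetical helper modeling FSDP2's ceil-division padding) of what the shape inference gets wrong for the `[8, 128]` case on 32 ranks:

```python
import math

def fsdp2_local_shape(global_shape, world_size, shard_dim=0):
    # FSDP2 pads the sharded dim so every rank holds
    # ceil(global / world_size) rows, zero-padding the tail ranks.
    local = list(global_shape)
    local[shard_dim] = math.ceil(global_shape[shard_dim] / world_size)
    return tuple(local)

world_size = 32
local = fsdp2_local_shape((8, 128), world_size)

# Without shape=, DTensor.from_local infers the global shape as
# local_shape * mesh_size along the sharded dim:
inferred = (local[0] * world_size, local[1])

print(local)     # (1, 128) -- one padded row per rank
print(inferred)  # (32, 128) -- 4x the true dim-0 size of 8
```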
## Suggested fix

Pass the parameter's global shape explicitly; this is correct for both even and uneven shards:

```python
state[key] = DTensor.from_local(
    val, mesh, placements,
    shape=param.shape, stride=param.stride(),
)
```
For even shards, `param.shape == local_shape * mesh_size`, so this is a no-op. For uneven shards, it provides the true global shape so DTensor correctly accounts for padding. PyTorch's DCP has supported uneven DTensor shards since 2.1.

The entire validation block (lines 633-646) can then be removed.
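The even/uneven distinction can be sanity-checked with plain arithmetic (a sketch; `padded_local` is a hypothetical helper modeling FSDP2's ceil-division shard size):

```python
import math

def padded_local(global_dim, world_size):
    # FSDP2's padded per-rank size along the sharded dim
    return math.ceil(global_dim / world_size)

world_size = 32

# Even shard: a [128, ...] parameter splits into 4 rows per rank,
# and from_local's inference 4 * 32 == 128 recovers the true size,
# so passing shape= explicitly changes nothing.
assert padded_local(128, world_size) * world_size == 128

# Uneven shard: an [8, ...] parameter pads to 1 row per rank,
# and the inference 1 * 32 == 32 overshoots the true size of 8;
# the explicit shape= is what keeps the DTensor metadata correct.
assert padded_local(8, world_size) * world_size != 8
```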
## Environment

- `flashoptim` 0.1.3
- PyTorch 2.10
- 32 GPUs (4 nodes × 8), FSDP2 (32) + EP (8), no TP/PP/CP
- Model: `nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16`