_wrap_state_as_dtensor raises ValueError for unevenly-sharded FSDP2 parameters #6

Description

@pzelasko

_wrap_state_as_dtensor rejects parameters whose shard dimension isn't evenly divisible by the FSDP mesh size, but FSDP2 handles uneven sharding natively via padding. This makes checkpoint saving fail for any model with small parameters (e.g. MoE gates, Mamba params) when the world size exceeds the parameter's shard dimension.

ValueError: DCP checkpointing requires evenly-sharded parameters, but parameter
with shape torch.Size([8, 128]) is unevenly sharded on dim 0 across 32 ranks.

Reproduce

  1. Use FlashAdamW with any FSDP2-sharded model that has a parameter with shape[shard_dim] < world_size (e.g. Nemotron-H 30B which has [8, 128] params from Mamba layers)
  2. Run on 32 GPUs
  3. Save a checkpoint — triggers optimizer.state_dict() → _state_dict_for_param() → _wrap_state_as_dtensor() → ValueError
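The uneven case in step 1 can be illustrated with FSDP2's dim-0 padding arithmetic. This is a plain-Python sketch of the sharding math (no torch); `fsdp2_local_shard_sizes` is a hypothetical helper for illustration, not part of flashoptim or PyTorch:

```python
import math

def fsdp2_local_shard_sizes(dim_size: int, world_size: int) -> list[int]:
    """Sketch of FSDP2-style dim-0 sharding: every rank holds a shard padded
    to ceil(dim_size / world_size) rows; ranks past the end hold only padding."""
    padded = math.ceil(dim_size / world_size)  # rows per rank, incl. padding
    sizes = []
    for rank in range(world_size):
        start = rank * padded
        # Real (unpadded) rows this rank owns; 0 once the data runs out.
        sizes.append(max(0, min(padded, dim_size - start)))
    return sizes

# The [8, 128] Mamba parameter on 32 ranks: only 8 ranks hold a real row,
# the remaining 24 hold pure padding — the "unevenly sharded" case.
sizes = fsdp2_local_shard_sizes(8, 32)
```

With `shape[shard_dim] = 8 < world_size = 32`, the per-rank real-row counts come out as eight ranks with 1 row and twenty-four with 0, which is exactly the configuration the `ValueError` rejects.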

Root cause

In optimizers.py:654:

state[key] = DTensor.from_local(val, mesh, placements)

Without an explicit shape=, DTensor.from_local infers global_shape = local_shape * mesh_size, which is wrong for padded (uneven) shards. The check on lines 635-646 preemptively rejects this case with a ValueError.
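The inference error is easy to see numerically. A minimal sketch of the arithmetic `DTensor.from_local` performs when `shape=` is omitted (plain Python, mirroring rather than calling the torch API):

```python
import math

# For the [8, 128] parameter sharded on dim 0 across 32 ranks, every rank's
# local shard is padded to ceil(8 / 32) = 1 row.
dim_size, mesh_size = 8, 32
local_rows = math.ceil(dim_size / mesh_size)

# Without shape=, from_local infers global = local * mesh_size on the shard
# dim: 1 * 32 = 32 rows, not the true 8 — the padding is counted as data.
inferred_global = local_rows * mesh_size
```

The inferred global size (32) disagrees with the true size (8), which is why the validation block rejects uneven shards up front instead of silently saving a corrupt checkpoint.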

Suggested fix

Pass the parameter's global shape explicitly. This is correct for both even and uneven shards:

state[key] = DTensor.from_local(
    val, mesh, placements,
    shape=param.shape, stride=param.stride(),
)

For even shards, param.shape == local_shape * mesh_size so this is a no-op. For uneven shards, it provides the true global shape so DTensor correctly accounts for padding. PyTorch's DCP has supported uneven DTensor shards since 2.1+.
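A quick sanity check of both claims (plain Python, no torch):

```python
import math

mesh_size = 32

# Even shard: a hypothetical [1024, 128] parameter on 32 ranks has 32 local
# rows and local * mesh_size already equals the global size — passing shape=
# explicitly changes nothing.
even_global = 1024
even_local = even_global // mesh_size
assert even_local * mesh_size == even_global

# Uneven shard: the [8, 128] parameter pads locals to ceil(8 / 32) = 1 row,
# so inference would yield 32; only the explicit shape= recovers the true 8.
uneven_global = 8
uneven_local = math.ceil(uneven_global / mesh_size)
assert uneven_local * mesh_size != uneven_global
```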

The entire validation block (lines 633-646) can then be removed.

Environment

  • flashoptim 0.1.3
  • PyTorch 2.10
  • 32 GPUs (4 nodes × 8), FSDP2 (32) + EP (8), no TP/PP/CP
  • Model: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
