
fix: use key-based matching in fsdp2_load_full_state_dict #3982

Open
roycho96 wants to merge 1 commit into huggingface:main from roycho96:fix/fsdp2-load-state-dict-key-matching

Conversation

@roycho96

What does this PR do?

Replaces positional (zip) matching with key-based matching in fsdp2_load_full_state_dict() on rank 0, preventing deadlocks when full_sd and meta_sharded_sd have different numbers of keys.

Problem

fsdp2_load_full_state_dict() currently uses different iteration strategies per rank:

  • Rank 0: zip(full_sd.items(), meta_sharded_sd.values()) — positional
  • Other ranks: meta_sharded_sd.items() — key-based

This assumes full_sd and meta_sharded_sd have identical key counts and ordering. When they don't, dist.broadcast calls become misaligned across ranks, causing a deadlock (all ranks hang indefinitely).

When does this break?

This breaks for models whose state_dict() contains extra sidecar entries that are not present in the sharded model's state dict. A concrete example is BnB QLoRA (Params4bit) with cpu_ram_efficient_loading=True:

meta_sharded_sd (after fully_shard, all ranks):
  model.layers.0.self_attn.q_proj.weight          # DTensor
  model.layers.0.self_attn.q_proj.lora_A.weight    # DTensor
  model.layers.0.self_attn.q_proj.lora_B.weight    # DTensor
  ... (N keys total)

full_sd (original state dict, rank 0 only):
  model.layers.0.self_attn.q_proj.weight
  model.layers.0.self_attn.q_proj.weight.absmax           # BnB sidecar
  model.layers.0.self_attn.q_proj.weight.quant_map         # BnB sidecar
  model.layers.0.self_attn.q_proj.weight.quant_state.bitsandbytes__nf4  # BnB sidecar
  model.layers.0.self_attn.q_proj.lora_A.weight
  model.layers.0.self_attn.q_proj.lora_B.weight
  ... (N + M keys total, M = sidecar entries)

How the deadlock happens

With zip on rank 0, iteration is positional — the i-th entry from full_sd is paired with the i-th entry from meta_sharded_sd. When full_sd has extra sidecar keys interleaved, the pairing shifts:

Step 1:
  Rank 0 broadcasts: full_sd["q_proj.weight"]          (correct)
  Rank 1 expects:    meta_sharded_sd["q_proj.weight"]   (correct)

Step 2:
  Rank 0 broadcasts: full_sd["q_proj.weight.absmax"]    (sidecar!)
  Rank 1 expects:    meta_sharded_sd["q_proj.lora_A.weight"]   MISMATCH

→ broadcast tensor sizes differ across ranks → NCCL error or infinite hang

After N iterations (matching meta_sharded_sd size), zip stops on rank 0, but the data that was broadcast was wrong from step 2 onward. In practice, the size/dtype mismatch causes NCCL to hang waiting for matching collectives.
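The pairing shift described above can be reproduced with plain dicts standing in for the real state dicts (tensor values and dist.broadcast are omitted; key names follow the example above but are illustrative):

```python
# Toy reproduction of the misalignment: strings stand in for tensors/DTensors.
meta_sharded_sd = {
    "q_proj.weight": "DTensor",
    "q_proj.lora_A.weight": "DTensor",
    "q_proj.lora_B.weight": "DTensor",
}
full_sd = {
    "q_proj.weight": "tensor",
    "q_proj.weight.absmax": "sidecar",     # BnB sidecar entry
    "q_proj.weight.quant_map": "sidecar",  # BnB sidecar entry
    "q_proj.lora_A.weight": "tensor",
    "q_proj.lora_B.weight": "tensor",
}

# Positional pairing (old rank-0 behavior): the i-th full_sd entry is paired
# with the i-th meta_sharded_sd entry, so sidecar keys shift the pairing.
positional_pairs = [
    (full_name, sharded_name)
    for (full_name, _), sharded_name in zip(full_sd.items(), meta_sharded_sd)
]
# Step 2 pairs a sidecar with a LoRA weight: a cross-rank mismatch.
assert positional_pairs[1] == ("q_proj.weight.absmax", "q_proj.lora_A.weight")

# Key-based pairing (the fix): every sharded key looks up its own full entry,
# so the extra sidecar keys are never consulted.
key_based_pairs = [(name, name) for name in meta_sharded_sd if name in full_sd]
assert all(full_name == sharded_name for full_name, sharded_name in key_based_pairs)
assert len(key_based_pairs) == len(meta_sharded_sd)
```

In the real function each mismatched pair becomes a dist.broadcast with a different tensor size on each rank, which is what hangs NCCL.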

Fix

Use the same iteration strategy on all ranks — iterate over meta_sharded_sd.items() and look up full_sd by key:

# Before (rank 0): positional — breaks with extra keys
for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
    ...

# After (rank 0): key-based — matches other ranks' iteration
for param_name, sharded_param in meta_sharded_sd.items():
    if param_name not in full_sd:
        raise KeyError(f"{param_name} is in the sharded state dict but missing from full_sd")
    full_param = full_sd[param_name]

This also adds a clear KeyError message when a sharded key is missing from full_sd, instead of a silent mismatch or cryptic broadcast error.

For the normal case (no extra keys), behavior is identical — same keys, same order, same broadcasts.
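A minimal sketch of the key-based pairing, using plain dicts instead of real state dicts (the helper name pair_by_key and the error message are illustrative, not the PR's actual code):

```python
def pair_by_key(meta_sharded_sd, full_sd):
    """Pair each sharded key with its full_sd entry by name, not position."""
    pairs = []
    for param_name, sharded_param in meta_sharded_sd.items():
        if param_name not in full_sd:
            # Clear failure instead of a silent positional mismatch.
            raise KeyError(f"{param_name} missing from full state dict")
        pairs.append((param_name, full_sd[param_name]))
    return pairs

# Normal case: same keys in the same order, so the result is identical to
# what positional (zip) matching would have produced.
meta = {"a.weight": None, "b.weight": None}
full = {"a.weight": 1, "b.weight": 2}
assert pair_by_key(meta, full) == [("a.weight", 1), ("b.weight", 2)]

# Extra sidecar keys in full_sd are simply never looked up.
full_extra = {"a.weight": 1, "a.weight.absmax": 0, "b.weight": 2}
assert pair_by_key(meta, full_extra) == [("a.weight", 1), ("b.weight", 2)]
```

Because every rank now walks meta_sharded_sd in the same order, the sequence of broadcasts is identical across ranks by construction.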

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@SunMarc

