Fix OOM regression for FSDP2 + cpu_ram_efficient_loading on large models #45649
AmineDiro wants to merge 1 commit into huggingface:main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

View the CircleCI Test Summary for this PR: https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45649&sha=74480d
cc @Cyrilvallez I think |
albertvillanova left a comment
Thanks a lot for the clear diagnosis and the fix: Skip CPU param materialization on non-rank-0 FSDP ranks to avoid OOM
The OOM regression in #45050 is real: zeros_like forces an immediate physical-memory commit (page fault on every zero write), whereas empty_like relies on overcommit/lazy allocation. Note this was already commented by @ArthurZucker: https://github.com/huggingface/transformers/pull/45050/changes#r3029107360
> the reason I don't want this is because it's costly!
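To make the cost difference concrete, here is a minimal sketch (my own illustration, not code from the PR) of why `zeros_like` commits physical memory up front while `empty_like` does not:

```python
import torch

# A parameter-sized tensor on the meta device, as produced by low-memory loading
meta_param = torch.empty(4096, 4096, dtype=torch.bfloat16, device="meta")

# empty_like only reserves virtual address space; pages are committed lazily on
# first write, so resident memory barely grows until the tensor is actually used.
lazy_placeholder = torch.empty_like(meta_param, device="cpu")

# zeros_like writes a zero into every element, touching every page immediately,
# so the full tensor size is committed to physical RAM on each rank right away.
eager_placeholder = torch.zeros_like(meta_param, device="cpu")
```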
Let me trace through the full flow after the change to confirm:
- On non-rank-0 FSDP ranks:
  - Parameters stay on meta device: zero physical memory committed
  - Buffers (both persistent and non-persistent) get real CPU `zeros_like` placeholders
- Then `_initialize_missing_keys` (PR #44473) marks state-dict parameters (now meta tensors) as `_is_hf_initialized = True`. `initialize_weights()` then runs: for `RotaryEmbedding`, `inv_freq` and `original_inv_freq` are non-persistent buffers, so they are not in `state_dict()`, not marked, and `_init_weights` correctly computes and copies their values into the real CPU zero tensors
- Accelerate's `fsdp2_prepare_model` then:
  - Saves non-persistent buffers (now correctly initialized by `_init_weights`) from each rank
  - Moves the model to meta; parameters that were already on meta: no-op
  - Applies `fully_shard`
  - `fsdp2_load_full_state_dict` broadcasts from rank-0 into all ranks: parameters receive correct values
  - Restores non-persistent buffers from each rank's saved copy
The original NaN bug is still fixed: parameters that _init_weights skips (marked as initialized) are subsequently overwritten by the broadcast with rank-0's values. The difference from #45050 is that we never pay the cost of materializing them on non-rank-0 in the first place.
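A rough way to state the expected end state per rank right after loading, before accelerate's FSDP2 prepare (an illustrative check of my own, not code from the PR or its tests; `model` and `is_rank0` are placeholders):

```python
def check_placeholder_policy(model, is_rank0):
    # Buffers always get real CPU placeholders so _init_weights can fill e.g. RoPE caches
    for name, buf in model.named_buffers():
        assert buf.device.type == "cpu", f"buffer {name} should be materialized on CPU"
    # Parameters may stay on meta off rank-0: fsdp2_load_full_state_dict broadcasts
    # rank-0's values into every rank, so no CPU commit is needed beforehand
    if not is_rank0:
        for name, p in model.named_parameters():
            assert p.device.type == "meta", f"param {name} should stay on meta until broadcast"
```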
The fix is correct, targeted, and eliminates the OOM without reintroducing the NaN regression (I have confirmed this). 🤗
return

# In this case we need to move everything back
# Leave parameters on meta on non-rank-0 FSDP ranks (rank-0 broadcast overwrites them); only buffers need real placeholders.
Nit: I think the comment is right, but it under-specifies the mechanism (see the sketch after this list):
- Parameters can stay on meta:
  - accelerate's `fsdp2_prepare_model` moves the whole model to meta before `fully_shard`
  - then `fsdp2_load_full_state_dict` broadcasts rank-0's state_dict to all ranks
- Only buffers need real allocations:
  - persistent buffers are also broadcast
  - but non-persistent ones (RoPE caches etc.) are per-rank and must be initialized locally by `_init_weights`
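For instance, an expanded wording along these lines (just a suggestion of mine, not the text currently in the diff):

```python
# Parameters can stay on meta on non-rank-0 ranks: accelerate's fsdp2_prepare_model
# moves the whole model to meta before fully_shard, and fsdp2_load_full_state_dict
# then broadcasts rank-0's state_dict (parameters and persistent buffers) to all ranks.
# Only buffers need real CPU allocations here, because non-persistent ones
# (RoPE caches etc.) are per-rank and must be initialized locally by _init_weights.
```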
What does this PR do?
PR #45050 replaces `torch.empty_like` with `torch.zeros_like` in `_move_missing_keys_from_meta_to_device`. While this fixes a real issue (NaN garbage in uninitialized memory), it forces a physical-memory commit of the entire model on every non-rank-0 FSDP rank. With 8 ranks per node each committing a full copy of a 30B model, peak CPU memory jumps from ~60 GB to ~480 GB :/
The regression was identified by bisecting transformers commits between 2026-04-10 (working) and 2026-04-22 (failing) using a 2-node FSDP2 control config:

- `a001f34439` (pre-#45050)
- `ff49f7c4cb` (PR #45050)

Test config: `Qwen/Qwen3-30B-A3B`, FSDP2, 2 nodes × 8 H100, DP=16, sdpa, `max_steps=5`, `fsdp_cpu_ram_efficient_loading=true`.

The placeholder values on non-rank-0 ranks for state-dict params are immediately overwritten by `fsdp2_load_full_state_dict` during accelerate's FSDP2 prepare: `accelerate` moves the entire model to the `meta` device before sharding in `accelerate.utils.fsdp_utils.fsdp2_prepare_model`. So allocating CPU placeholders for parameters on non-rank-0 ranks is unnecessary work; the parameters can stay on meta. Btw, from what I can understand, buffers (RoPE caches, attention masks, etc.) are per-rank and not part of the broadcast, so they still need real allocations.
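A minimal sketch of the resulting allocation policy (the real change lives in `_move_missing_keys_from_meta_to_device`; the helper below and its arguments are illustrative stand-ins, not the actual transformers code):

```python
import torch

def materialize_missing_key(meta_tensor, is_buffer, is_rank0):
    # Parameters on non-rank-0 ranks stay on meta: fsdp2_load_full_state_dict
    # broadcasts rank-0's values into them during accelerate's FSDP2 prepare,
    # so committing CPU RAM for them here is wasted work.
    if not is_buffer and not is_rank0:
        return meta_tensor
    # Buffers (and rank-0 tensors) need real CPU memory; zeros_like keeps the
    # NaN fix from #45050 for anything _init_weights does not overwrite.
    return torch.zeros_like(meta_tensor, device="cpu")
```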
Fixes # (issue)
Code Agent Policy
Before submitting
Who can review?
@albertvillanova @ArthurZucker