fix: Handle FSDP-sharded parameters in LoRA ParamWrapper get_delta_weight #3102
BillionClaw wants to merge 2 commits into huggingface:main
Conversation
…ight

When using FSDP with MoE layers (num_experts > 1), the LoRA weights (lora_A, lora_B) are sharded across GPUs. The get_delta_weight method was failing with a RuntimeError because the reshape operation expected the full tensor shape but only received the sharded portion.

This fix adds a fsdp_summon_full_params_ctx context manager that:
1. Detects if a module is FSDP-wrapped by checking for the _is_fsdp_managed_module attribute or the FullyShardedDataParallel class name
2. Uses FSDP.summon_full_params() to gather full parameters when needed
3. Is a no-op for non-FSDP modules

The ParamWrapper.get_delta_weight method now wraps the weight access in this context manager to ensure full parameters are available for the reshape and einsum operations.

Fixes huggingface#3080
BenjaminBossan
left a comment
Thanks for adding this PR. I actually had a very similar solution in the works. Before merging:
- For consistency with gather_fsdp_params_ctx, it would make sense IMO to change the signature of fsdp_summon_full_params_ctx to accept *modules. As FSDP.summon_full_params only accepts a single module, though, we'd have to chain one context per module, which can be done with contextlib.ExitStack.
- We should add a test to ensure this works and continues working. For our normal CI, we don't run multi-GPU tests, but we can add a test to the nightly test suite, which has two GPUs. For this, refer to how existing FSDP tests are called:
Line 69 in 9f63491
- Update fsdp_summon_full_params_ctx to accept *modules (variadic) for consistency with the gather_params_ctx pattern
- Use contextlib.ExitStack to chain contexts, since FSDP.summon_full_params only accepts a single module
- Update the usage in ParamWrapper.get_delta_weight to pass both modules in a single call
- Add a multi-GPU test for FSDP + ParamWrapper to the nightly test suite
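The variadic chaining the review asked for can be illustrated with plain contextlib (toy stand-ins, not the PEFT implementation; summon_one here plays the role of FSDP.summon_full_params, which only handles one module at a time):

```python
import contextlib


@contextlib.contextmanager
def summon_one(module):
    # Toy stand-in for FSDP.summon_full_params(module): pretend to gather
    # the full parameters for a single module, then restore the shard.
    module["full"] = True
    try:
        yield
    finally:
        module["full"] = False


@contextlib.contextmanager
def summon_full_params_ctx(*modules):
    # Chain one per-module context with ExitStack, since the underlying
    # gather only accepts a single module at a time.
    with contextlib.ExitStack() as stack:
        for module in modules:
            stack.enter_context(summon_one(module))
        yield


# Both modules see full params inside a single `with` statement:
lora_A, lora_B = {"full": False}, {"full": False}
with summon_full_params_ctx(lora_A, lora_B):
    assert lora_A["full"] and lora_B["full"]
assert not lora_A["full"] and not lora_B["full"]
```

ExitStack unwinds the chained contexts in reverse order on exit, including on exceptions, which is exactly the behavior a sequence of nested `with` blocks would have.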
Thanks for the review @BenjaminBossan! I've addressed both requests:
Please let me know if any further changes are needed.
Thanks for the feedback! I've updated the signature to accept *modules using ExitStack for consistency with gather_fsdp_params_ctx, and added a test to exercise the code path.
Hi @BenjaminBossan, all review feedback has been addressed:
All changes are in commit `8cd0129`. Please let me know if any further adjustments are needed.
Hi @BenjaminBossan, all review feedback has been addressed in the latest commit (8cd0129):
All syntax checks pass. Please let me know if any further adjustments are needed.
@BillionClaw Thanks for making these changes. However, the test you added doesn't work, as there is no
Also, please configure your bot to only post once and not multiple times.
Problem
When using FSDP (FullyShardedDataParallel) with MoE layers (num_experts > 1) via target_parameters, the LoRA weights (lora_A, lora_B) are sharded across GPUs. The ParamWrapper.get_delta_weight method was failing with a RuntimeError during the reshape operation because it expected the full tensor shape but only received the sharded portion.

Error from issue #3080:
Solution
Added a fsdp_summon_full_params_ctx context manager in integrations.py that:
- Detects if a module is FSDP-wrapped by checking for the _is_fsdp_managed_module attribute or the FullyShardedDataParallel class name
- Uses FSDP.summon_full_params() to gather full parameters when needed

Modified ParamWrapper.get_delta_weight to wrap the weight access in this context manager, ensuring full parameters are available for the reshape and einsum operations.

Changes
- src/peft/utils/integrations.py: Added the fsdp_summon_full_params_ctx context manager
- src/peft/tuners/lora/layer.py: Updated ParamWrapper.get_delta_weight to use the new context manager for both the lora_A and lora_B modules

Testing
- The test_target_parameters tests pass
- test_gpu_examples.py::TestFSDPWrap::test_fsdp_auto_wrap_policy_does_not_raise_on_custom_model passes

Fixes #3080
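As a standalone illustration of the failure mode described above (a toy simulation with no torch dependency; the real error comes from tensor.reshape on a sharded FSDP parameter, and all names here are hypothetical):

```python
import contextlib


class ShardedParam:
    """Toy stand-in for an FSDP-sharded parameter: this rank only holds
    half of the flattened elements until the full params are gathered."""

    def __init__(self, full_data):
        self._full = list(full_data)
        self.data = self._full[: len(self._full) // 2]

    @contextlib.contextmanager
    def summon_full(self):
        # Temporarily expose the full data, then restore the local shard.
        shard, self.data = self.data, self._full
        try:
            yield
        finally:
            self.data = shard


def reshape(flat, rows, cols):
    # Mimics tensor.reshape: raises if the element count does not match.
    if len(flat) != rows * cols:
        raise RuntimeError(f"cannot reshape {len(flat)} elements into ({rows}, {cols})")
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


p = ShardedParam(range(8))  # full shape would be (2, 4)
try:
    reshape(p.data, 2, 4)   # only 4 of 8 elements resident -> RuntimeError
except RuntimeError as err:
    print("on shard:", err)
with p.summon_full():
    print("full:", reshape(p.data, 2, 4))
```

Outside the context only the local shard is visible, so the reshape to the full shape fails; inside it the full element count is available and the reshape succeeds, which is the behavior the fix restores for get_delta_weight.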
I have read the CLA Document and I hereby sign the CLA