[bugfix] Fix LoRA LLM resume from checkpoint #9225
Conversation
Code Review
This pull request removes an outdated comment about checkpoint support in the SFT script and modifies the LoRA LLM tuner plugin so that the vision tower and aligner modules are set to require gradients after loading. Feedback suggests hardening the new logic with existence checks on the model architecture metadata and its attributes to avoid potential attribute errors.
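A minimal sketch of what that hardened logic could look like, assuming the plugin exposes the multimodal parts as `vision_tower` and `aligner` attributes and reads the architecture list from the Hugging Face config (all of these names are illustrative, not taken from this PR):

```python
def reenable_multimodal_grads(model):
    # Guard the metadata lookups: `config` or `architectures` may be missing,
    # which is what risked an AttributeError in the original logic.
    config = getattr(model, "config", None)
    architectures = getattr(config, "architectures", None) or []
    if not architectures:
        return  # no architecture metadata to key the decision on

    # Re-enable gradients on the vision tower and aligner after loading,
    # since loading a checkpoint can leave these modules frozen.
    for attr in ("vision_tower", "aligner"):  # assumed attribute names
        module = getattr(model, attr, None)
        if module is not None:
            for param in module.parameters():
                param.requires_grad = True
```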
/gemini review
Code Review
This pull request enables resuming from checkpoints in multimodal LoRA training by updating the DeepSpeed integration. Key changes include removing the training-status check in the Qwen model forwards so behavior stays consistent under DeepSpeed, and a more robust parameter-loading mechanism in LoRALLMTuner that handles DeepSpeed ZeRO-3's partitioned parameters. A performance improvement was suggested: gather all parameters at once when using ZeRO-3 to reduce communication overhead during state-dict loading.
```python
params_dict = dict(model.named_parameters())
for name, tensor in state_dict.items():
    if name in params_dict:
        param = params_dict[name]
        with deepspeed.zero.GatheredParameters([param], modifier_rank=0):
            if deepspeed.comm.get_rank() == 0:
                param.data.copy_(tensor)
```
When using DeepSpeed ZeRO-3, it is more efficient to gather all relevant parameters at once rather than gathering them one by one in a loop; this minimizes communication overhead.
Suggested change:

```diff
-params_dict = dict(model.named_parameters())
-for name, tensor in state_dict.items():
-    if name in params_dict:
-        param = params_dict[name]
-        with deepspeed.zero.GatheredParameters([param], modifier_rank=0):
-            if deepspeed.comm.get_rank() == 0:
-                param.data.copy_(tensor)
+params_dict = dict(model.named_parameters())
+params_to_load = {name: params_dict[name] for name in state_dict if name in params_dict}
+if params_to_load:
+    with deepspeed.zero.GatheredParameters(list(params_to_load.values()), modifier_rank=0):
+        if deepspeed.comm.get_rank() == 0:
+            for name, param in params_to_load.items():
+                param.data.copy_(state_dict[name])
```
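For context, here is a hedged sketch of the batched load wrapped as a helper, the way LoRALLMTuner might call it when resuming. With `modifier_rank=0`, DeepSpeed gathers the full parameters on entry and, on exit, re-partitions them so every rank's shard reflects the values written by rank 0. The helper name and the commented resume path are assumptions for illustration, not code from this PR:

```python
import deepspeed


def load_trainable_params_zero3(model, state_dict):
    """Copy matching entries of `state_dict` into a ZeRO-3 partitioned model.

    Illustrative helper mirroring the suggested change above.
    """
    params_dict = dict(model.named_parameters())
    params_to_load = {name: params_dict[name] for name in state_dict if name in params_dict}
    if not params_to_load:
        return
    # One collective gathers every needed parameter; rank 0 writes the new
    # values, and DeepSpeed re-partitions them across ranks on context exit.
    with deepspeed.zero.GatheredParameters(list(params_to_load.values()), modifier_rank=0):
        if deepspeed.comm.get_rank() == 0:
            for name, param in params_to_load.items():
                param.data.copy_(state_dict[name])


# Hypothetical resume path (file name illustrative):
#   state_dict = torch.load("checkpoint/trainable_params.pt", map_location="cpu")
#   load_trainable_params_zero3(model, state_dict)
```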
No description provided.