Support zero3 hierarchical gather in the ref sync callback #9170
Conversation
Code Review
This pull request implements reference model weight synchronization for RLHF trainers. It introduces a SyncRefModelCallback and a _sync_ref_model_weights method within the RolloutTrainerMixin to handle weight mixing, including support for DeepSpeed ZeRO-3. Feedback indicates that the initialization of parameter groups happens too early, which may lead to incorrect configurations when LoRA is enabled. Additionally, the synchronization method contains performance inefficiencies due to redundant dictionary creations and iterations inside loops.
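The weight mixing described above can be sketched as follows. This is an illustrative reduction, not the PR's actual code: the helper name `mix_ref_weights` and the `alpha` coefficient are assumptions, plain floats stand in for tensors, and under DeepSpeed ZeRO-3 each parameter would first be gathered from its shards before being read.

```python
def mix_ref_weights(policy_state, ref_state, alpha=0.6):
    """Blend policy weights into the reference model in place:
    ref = alpha * policy + (1 - alpha) * ref.

    Plain floats stand in for parameter tensors here; with ZeRO-3,
    each parameter would be materialized (e.g. inside a
    deepspeed.zero.GatheredParameters context) before mixing.
    """
    for name, w in policy_state.items():
        if name not in ref_state:
            raise KeyError(f"missing parameter in ref model: {name}")
        ref_state[name] = alpha * w + (1.0 - alpha) * ref_state[name]
    return ref_state
```

Note that building the name-to-weight dictionaries once, outside the mixing loop, avoids the redundant per-iteration dictionary creation flagged in the review.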
/gemini review
Code Review
This pull request introduces a custom SyncRefModelCallback and a _sync_ref_model_weights method within RolloutTrainerMixin to support reference model weight synchronization during training. It also refactors the initialization of parameter groups to ensure they are available for this process. The review feedback suggests improving the robustness of the synchronization logic by ensuring empty parameter groups result in a no-op rather than defaulting to all parameters, and by enhancing error diagnostics to identify specific missing parameters when using DeepSpeed.
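The two robustness points above can be sketched together. The helper name `resolve_sync_params` and the `None`-vs-empty convention are illustrative assumptions, not the PR's API: an explicitly empty group returns nothing instead of defaulting to all parameters, and missing names are reported individually rather than failing opaquely under DeepSpeed sharding.

```python
def resolve_sync_params(named_params, group):
    """Select the parameters to synchronize for one parameter group.

    Conventions assumed for this sketch:
      * group == []   -> no-op (sync nothing), NOT "all parameters"
      * group is None -> no group configured, sync everything
      * otherwise     -> sync exactly the named parameters, and list
                         every missing name to ease diagnostics.
    """
    if group is not None and len(group) == 0:
        return []  # explicit empty group: deliberately sync nothing
    if group is None:
        return list(named_params.values())
    missing = [n for n in group if n not in named_params]
    if missing:
        # Naming each absent parameter is far easier to debug than a
        # generic failure once weights are sharded across ranks.
        raise KeyError(f"parameters not found in ref model: {missing}")
    return [named_params[n] for n in group]
```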
/gemini review
Code Review
This pull request introduces a custom SyncRefModelCallback and the _sync_ref_model_weights method within RolloutTrainerMixin to support reference model weight synchronization. It also moves the initialization of parameter groups to prepare_rollout to ensure availability regardless of vLLM usage. A review comment suggests refactoring the _sync_ref_model_weights method to reduce code duplication between the DeepSpeed ZeRO-3 and standard execution paths.
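One way to realize the suggested deduplication is to keep a single mixing loop and vary only the context manager that materializes parameters. This is a sketch under stated assumptions, not the PR's implementation: the ZeRO-3 path would pass a gather context (such as `deepspeed.zero.GatheredParameters` with `modifier_rank=0`), while the standard path uses a no-op context.

```python
import contextlib

def sync_ref_model_weights(policy_state, ref_state, alpha, gather_ctx=None):
    """One shared code path for both execution modes.

    The only difference between ZeRO-3 and standard execution is how
    parameters are materialized: callers supply a gather context for
    ZeRO-3, and the standard path falls back to a no-op nullcontext.
    Floats stand in for tensors in this sketch.
    """
    ctx = gather_ctx if gather_ctx is not None else contextlib.nullcontext()
    with ctx:
        for name, w in policy_state.items():
            ref_state[name] = alpha * w + (1.0 - alpha) * ref_state[name]
    return ref_state
```

Factoring the loop out this way means a fix to the mixing arithmetic applies to both paths at once, which is the maintainability benefit the review comment is after.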
fix #8095