-
Notifications
You must be signed in to change notification settings - Fork 998
[bugfix] fix megatron grpo server mode sync weight #6648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary of ChangesHello @hjh0119, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request addresses a bug in the Megatron GRPO trainer's server mode concerning weight synchronization. The changes refine the logic for how model weights are loaded and synchronized across distributed processes. By adjusting the conditions for initiating weight loading and performing synchronization, the PR ensures that all processes correctly participate in the weight management phase while centralizing the critical synchronization step to the main process, thereby preventing potential inconsistencies in the model's state during training or inference. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request attempts to fix an issue with weight synchronization in the Megatron GRPO server mode. While the changes correctly involve all processes in the weight export step, which likely resolves a process hang, they inadvertently introduce a critical bug. The modifications cause only the weights from the main process to be synchronized, while weights from all other processes are discarded. This leads to an incomplete model update on the vLLM server. I have provided a detailed review comment explaining the issue and suggesting a path to resolution.
| bucket_params: List of (name, tensor) tuples to sync | ||
| """ | ||
| if not bucket_params: | ||
| if not bucket_params or not self.is_main_process: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change, combined with the modification at line 248, appears to introduce a critical bug. The change at line 248 causes _load_weights_to_server_in_buckets to be called on all processes, which is likely correct if export_weights is a collective operation. However, adding or not self.is_main_process here makes _sync_bucket_to_server return immediately on non-main processes. As a result, only the weights from the main process's weight_iterator are synchronized with the vLLM server, while weights from all other processes are silently ignored. This will lead to an incompletely updated model on the server.
To fix this, you should gather bucket_params from all processes onto the main process before syncing. This would likely require changes in the calling function, _load_weights_to_server_in_buckets, to orchestrate the gathering of buckets from all ranks before the main rank performs the synchronization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The complete weights have already been gathered in the upper _load_weights_to_server_in_buckets method, ignore.
No description provided.