feat: grpo rollout correction for training-inference mismatch #6678
/gemini review
Code Review
This pull request introduces a crucial feature for GRPO training: rollout correction using importance sampling to address the training-inference mismatch. The implementation is comprehensive, offering multiple correction modes and detailed diagnostic metrics. The addition of a new test suite for this functionality is highly commendable and significantly improves the reliability of the new code.
Furthermore, the refactoring of the padding-free mode is a major improvement to the codebase. By centralizing the padding logic, the downstream code, especially the loss computation, has been greatly simplified, making it more readable and maintainable.
I have a few suggestions for improvement:
- A minor bug in the padding logic for `vllm_per_token_logps` in standard (non-padding-free) mode, which should be corrected to ensure consistency with the new left-padding convention.
- Clarifications in comments and documentation regarding the direction of the KL divergence being estimated by the `k3_kl` metric to avoid potential misinterpretation by users.
Overall, this is an excellent pull request that adds significant value and improves the quality of the trainer.
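To make the left-padding point concrete, here is a minimal sketch of left-padding variable-length per-token log-probability lists so that the last generated token of every sequence lines up in the final column. The helper names `left_pad` and `left_pad_mask` are hypothetical and are not taken from the PR; plain Python lists stand in for tensors.

```python
from typing import List


def left_pad(seqs: List[List[float]], pad_value: float = 0.0) -> List[List[float]]:
    """Left-pad each sequence to the length of the longest one (hypothetical helper)."""
    max_len = max((len(s) for s in seqs), default=0)
    return [[pad_value] * (max_len - len(s)) + list(s) for s in seqs]


def left_pad_mask(seqs: List[List[float]]) -> List[List[int]]:
    """Return a 0/1 mask that marks real (non-padding) positions after left-padding."""
    max_len = max((len(s) for s in seqs), default=0)
    return [[0] * (max_len - len(s)) + [1] * len(s) for s in seqs]
```

If the rollout log-probabilities were padded on the right while the training log-probabilities were padded on the left, the element-wise ratio would compare unrelated tokens, which is why both need the same convention.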
docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md
/gemini review
Code Review
This pull request introduces significant improvements to GRPO training by adding rollout correction mechanisms to address training-inference mismatch, refactoring the padding-free mode for simplicity, and adding a comprehensive set of diagnostic metrics. The changes are well-structured, and the inclusion of detailed documentation and new unit tests is commendable. My review has identified one potential issue concerning padding direction that could affect the correctness of the importance sampling calculations.
/gemini review
Code Review
This pull request introduces a crucial feature for GRPO training: rollout correction using Importance Sampling (IS) to mitigate training-inference mismatch. The changes are comprehensive, covering new command-line arguments, integration with vLLM for log probability extraction, refactoring of padding-free logic for consistent tensor shapes, and the implementation of various IS correction modes (token/sequence, truncate/mask) along with detailed diagnostic metrics. The addition of a dedicated test file for the IS correction logic is highly commendable, ensuring the robustness and correctness of this new functionality. The refactoring of `_get_per_token_logps_and_entropies` to compute log probabilities on flattened tensors and then pad them back to batch shape is a significant improvement for handling padding-free mode more consistently.
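As a rough illustration of the four correction modes named above (token or sequence level, truncate or mask), the sketch below computes importance weights for a single sequence. It assumes the ratio is the training policy's probability over the rollout policy's probability, that "truncate" clips weights at an upper threshold, and that "mask" zeroes out weights above it. The function name, argument names, and default threshold are hypothetical and may differ from the PR's actual code.

```python
import math
from typing import List


def is_weights(
    train_logps: List[float],
    rollout_logps: List[float],
    level: str = "token",      # "token" or "sequence"
    mode: str = "truncate",    # "truncate" or "mask"
    threshold: float = 2.0,    # assumed upper bound on the ratio
) -> List[float]:
    """Per-token importance weights for one sequence (illustrative sketch only)."""
    log_ratios = [t - r for t, r in zip(train_logps, rollout_logps)]
    if level == "token":
        ratios = [math.exp(x) for x in log_ratios]
    elif level == "sequence":
        # One ratio for the whole sequence, broadcast to every token.
        seq_ratio = math.exp(sum(log_ratios))
        ratios = [seq_ratio] * len(log_ratios)
    else:
        raise ValueError(f"unknown level: {level}")

    if mode == "truncate":
        return [min(r, threshold) for r in ratios]
    if mode == "mask":
        return [r if r <= threshold else 0.0 for r in ratios]
    raise ValueError(f"unknown mode: {mode}")
```

Truncation keeps every token in the loss but bounds its influence, whereas masking drops tokens (or whole sequences) whose ratio exceeds the threshold.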
Summary
This PR introduces rollout correction mechanisms to address training-inference mismatch in GRPO training, and includes several related improvements.
Changes
1. Support Truncated/Masked Importance Sampling (TIS/MIS) for Training-Inference Mismatch Correction
Added importance sampling correction to handle off-policy issues
New parameters:
2. Refactor GRPO Padding-Free Mode
Refactored `_get_per_token_logps_and_entropies` to:
3. Add Comprehensive Rollout Correction Metrics
Added off-policy diagnostic metrics (always logged when `rollout_log_probs` is available):
All metrics are prefixed with `rollout_correction/` in logs.
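For readers unfamiliar with the `k3_kl` metric raised in the review, the sketch below shows the commonly used k3 estimator, `(r - 1) - log r`, averaged over tokens. It assumes that tokens are sampled from the rollout policy and that `r` is the training probability divided by the rollout probability, in which case the quantity estimates KL(rollout || training). Whether the PR uses exactly this definition and direction is an assumption; the function name here is illustrative.

```python
import math
from typing import List


def k3_kl_estimate(train_logps: List[float], rollout_logps: List[float]) -> float:
    """Mean k3 estimate over tokens: (r - 1) - log r with r = p_train / p_rollout.

    Assumes the tokens were sampled from the rollout policy, so the result
    estimates KL(rollout || train). Each term is non-negative.
    """
    log_ratios = [t - r for t, r in zip(train_logps, rollout_logps)]
    if not log_ratios:
        return 0.0
    # expm1(x) = exp(x) - 1, computed accurately for small x.
    return sum(math.expm1(x) - x for x in log_ratios) / len(log_ratios)
```

The estimate is zero when the two policies assign identical log-probabilities to every sampled token and grows as they diverge.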
resolve #6235