@hjh0119 hjh0119 commented Nov 20, 2025

Summary

This PR introduces rollout correction mechanisms to address training-inference mismatch in GRPO training, and includes several related improvements.

Changes

1. Support Truncated/Masked Importance Sampling (TIS/MIS) for Training-Inference Mismatch Correction

Added importance sampling correction to handle off-policy issues

New parameters:

  • rollout_importance_sampling_mode: Correction mode selection
    • token_truncate: Token-level truncated IS (clips extreme weights)
    • token_mask: Token-level masked IS (discards extreme tokens)
    • sequence_truncate: Sequence-level truncated IS
    • sequence_mask: Sequence-level masked IS
  • rollout_importance_sampling_threshold: Upper threshold for IS weights (default: 2.0)
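
As a rough illustration of how the four modes differ, here is a minimal pure-Python sketch for a single sequence. It is not the PR's implementation: the function name `rollout_is_weights` is hypothetical, the sequence-level ratio is assumed to be the product of the token ratios, and masked positions are assumed to get weight 0 while the remaining positions keep their raw ratio.

```python
import math


def rollout_is_weights(train_logps, rollout_logps, mode, threshold=2.0):
    """Per-token IS weights for ONE sequence (illustrative sketch only).

    train_logps / rollout_logps: log-probs of the sampled tokens under the
    training policy and the rollout (inference) policy.
    """
    if len(train_logps) != len(rollout_logps):
        raise ValueError("log-prob lists must have the same length")
    level, action = mode.split("_")  # e.g. "token_truncate"
    if level not in ("token", "sequence") or action not in ("truncate", "mask"):
        raise ValueError(f"unknown mode: {mode}")

    # log(pi_train / pi_rollout) for every sampled token
    log_ratios = [t - r for t, r in zip(train_logps, rollout_logps)]

    if level == "token":
        ratios = [math.exp(lr) for lr in log_ratios]
    else:
        # Assumption: sequence ratio = product of token ratios,
        # broadcast to every token of the sequence.
        seq_ratio = math.exp(sum(log_ratios))
        ratios = [seq_ratio] * len(log_ratios)

    if action == "truncate":
        # Clip extreme weights at the upper threshold.
        return [min(w, threshold) for w in ratios]
    # Mask: assumption that discarded positions get weight 0.
    return [w if w <= threshold else 0.0 for w in ratios]


train = [math.log(0.5), math.log(0.9), math.log(0.6)]
rollout = [math.log(0.1), math.log(0.9), math.log(0.5)]
for m in ("token_truncate", "token_mask", "sequence_truncate", "sequence_mask"):
    print(m, rollout_is_weights(train, rollout, m))
```

In this toy input the first token has a raw ratio of about 5, so the token-level modes only touch that token, while the sequence-level modes act on the whole sequence at once.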

2. Refactor GRPO Padding-Free Mode

Refactored _get_per_token_logps_and_entropies to:

  • Compute logps on rmpad (remove-padding) tensors first
  • Pad back to batch shape after computation using pad_logps_back_to_batch
  • This makes loss_func padding-free-agnostic, simplifying downstream logic
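
The compute-then-pad-back idea can be sketched with plain lists. This is only an illustration: `pad_back_to_batch_sketch` is a hypothetical helper, not the signature of `pad_logps_back_to_batch`, and the padding side, pad value, and returned mask are assumptions.

```python
def pad_back_to_batch_sketch(flat_logps, seq_lengths, padding_side="left",
                             pad_value=0.0):
    """Split a flattened (padding-free) log-prob list into per-sequence rows
    and pad every row to the longest sequence.

    Returns (batch, mask), where mask is 1 on real tokens and 0 on padding.
    """
    if sum(seq_lengths) != len(flat_logps):
        raise ValueError("seq_lengths must sum to len(flat_logps)")
    if padding_side not in ("left", "right"):
        raise ValueError("padding_side must be 'left' or 'right'")

    max_len = max(seq_lengths, default=0)
    batch, mask = [], []
    offset = 0
    for n in seq_lengths:
        row = list(flat_logps[offset:offset + n])
        pad = [pad_value] * (max_len - n)
        zeros = [0] * (max_len - n)
        ones = [1] * n
        if padding_side == "left":
            batch.append(pad + row)
            mask.append(zeros + ones)
        else:
            batch.append(row + pad)
            mask.append(ones + zeros)
        offset += n
    return batch, mask


# Two sequences of length 2 and 3 packed into one flat list.
flat = [-0.1, -0.2, -0.3, -0.4, -0.5]
batch, mask = pad_back_to_batch_sketch(flat, [2, 3])
print(batch)
print(mask)
```

Once every tensor has the same batch shape and an accompanying mask, the loss code no longer needs to know whether the forward pass ran in padding-free mode.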

3. Add Comprehensive Rollout Correction Metrics

Added off-policy diagnostic metrics (always logged when rollout_log_probs is available):

  • KL Divergence: kl (direct estimator), k3_kl (K3 estimator for stability)
  • Perplexity: training_ppl, rollout_ppl, log_ppl_diff, ppl_ratio
  • χ² Divergence: chi2_token, chi2_seq

Added IS correction metrics (logged when IS mode enabled):

  • is_weight_mean: Mean of importance sampling weights
  • ess: Effective Sample Size (measures weight uniformity, range (0, 1])
  • clipped_frac: Fraction of clipped/masked samples
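
For intuition about the IS correction metrics, here is a small sketch under stated assumptions: the helper name is hypothetical, the statistics are computed on raw (uncorrected, strictly positive) weights, and ESS uses the common normalized form (Σw)² / (n · Σw²), which lies in (0, 1]. Whether the PR computes them on raw or corrected weights is not shown in this description.

```python
def is_correction_metrics(raw_weights, threshold=2.0):
    """Summary statistics over a flat list of positive IS weights."""
    n = len(raw_weights)
    if n == 0:
        raise ValueError("need at least one weight")
    total = sum(raw_weights)
    total_sq = sum(w * w for w in raw_weights)
    return {
        # Mean importance weight; close to 1 when the policies match.
        "is_weight_mean": total / n,
        # Normalized effective sample size: 1.0 for uniform weights,
        # approaching 1/n when a single weight dominates.
        "ess": (total * total) / (n * total_sq),
        # Share of weights that exceed the upper threshold.
        "clipped_frac": sum(1 for w in raw_weights if w > threshold) / n,
    }


print(is_correction_metrics([1.0, 1.0, 1.0, 1.0]))
print(is_correction_metrics([5.0, 1.0, 1.2, 0.8]))
```

A low `ess` together with a high `clipped_frac` suggests that a few tokens or sequences dominate the correction.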

All metrics are prefixed with rollout_correction/ in logs.
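
The off-policy diagnostics can be approximated from the two sets of per-token log-probs alone. The sketch below is a hedged reading of the metric names, not the PR's code: the function name, the pooling over all tokens, the sign of `log_ppl_diff` (training minus rollout), and the sequence-level ratio used for `chi2_seq` (product of token ratios) are all assumptions.

```python
import math

PREFIX = "rollout_correction/"


def off_policy_diagnostics(train_seqs, rollout_seqs):
    """train_seqs / rollout_seqs: lists of sequences, each a list of per-token
    log-probs of the sampled tokens under the training / rollout policy."""
    if len(train_seqs) != len(rollout_seqs):
        raise ValueError("need the same number of sequences")
    token_lr, seq_lr = [], []  # log(pi_train / pi_rollout), per token / per seq
    train_sum = rollout_sum = 0.0
    for t_seq, r_seq in zip(train_seqs, rollout_seqs):
        if len(t_seq) != len(r_seq):
            raise ValueError("sequence length mismatch")
        lrs = [t - r for t, r in zip(t_seq, r_seq)]
        token_lr.extend(lrs)
        seq_lr.append(sum(lrs))
        train_sum += sum(t_seq)
        rollout_sum += sum(r_seq)
    n = len(token_lr)
    if n == 0:
        raise ValueError("no tokens")

    mean_train, mean_rollout = train_sum / n, rollout_sum / n
    # log(training_ppl) - log(rollout_ppl); sign convention is an assumption.
    log_ppl_diff = mean_rollout - mean_train
    metrics = {
        # Direct estimator: mean of log(pi_rollout / pi_train) on rollout samples.
        "kl": -sum(token_lr) / n,
        # K3 estimator: mean of (r - 1 - log r) with r = pi_train / pi_rollout.
        "k3_kl": sum(math.exp(x) - 1.0 - x for x in token_lr) / n,
        "training_ppl": math.exp(-mean_train),
        "rollout_ppl": math.exp(-mean_rollout),
        "log_ppl_diff": log_ppl_diff,
        "ppl_ratio": math.exp(log_ppl_diff),
        # Chi-squared divergence estimate: E[r^2] - 1.
        "chi2_token": sum(math.exp(2 * x) for x in token_lr) / n - 1.0,
        "chi2_seq": sum(math.exp(2 * x) for x in seq_lr) / len(seq_lr) - 1.0,
    }
    return {PREFIX + name: value for name, value in metrics.items()}


demo = off_policy_diagnostics([[-0.2, -0.4], [-1.0]], [[-0.25, -0.35], [-1.1]])
for name, value in demo.items():
    print(f"{name}: {value:.6f}")
```

When both policies assign identical log-probs, every divergence in this sketch is 0 and `ppl_ratio` is 1, which makes the metrics easy to sanity-check.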

Resolves #6235

@hjh0119 hjh0119 mentioned this pull request Nov 24, 2025
@hjh0119 hjh0119 changed the title from "[WIP] Support GRPO TIS/MIS" to "feat: grpo rollout correction for training-inference mismatch" Nov 25, 2025
@hjh0119 hjh0119 marked this pull request as ready for review November 25, 2025 09:26

hjh0119 commented Nov 25, 2025

/gemini review

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a crucial feature for GRPO training: rollout correction using importance sampling to address the training-inference mismatch. The implementation is comprehensive, offering multiple correction modes and detailed diagnostic metrics. The addition of a new test suite for this functionality is highly commendable and significantly improves the reliability of the new code.

Furthermore, the refactoring of the padding-free mode is a major improvement to the codebase. By centralizing the padding logic, the downstream code, especially the loss computation, has been greatly simplified, making it more readable and maintainable.

I have a few suggestions for improvement:

  1. A minor bug in the padding logic for vllm_per_token_logps in standard (non-padding-free) mode, which should be corrected to ensure consistency with the new left-padding convention.
  2. Clarifications in comments and documentation regarding the direction of the KL divergence being estimated by the k3_kl metric to avoid potential misinterpretation by users.
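
To make the direction question in item 2 concrete, the following sketch checks the K3 estimator on a toy three-token vocabulary. The distributions and helper names are made up for illustration, and the PR's own definition of k3_kl is not shown here. With r = p_train / p_rollout and tokens sampled from the rollout policy, the expectation of r - 1 - log r equals KL(rollout || train) exactly, because E[r] = 1 under the sampling distribution.

```python
import math


def kl(a, b):
    """Exact KL(a || b) for two categorical distributions."""
    return sum(x * math.log(x / y) for x, y in zip(a, b))


def expected_k3(sampling, other):
    """Exact expectation of r - 1 - log r under `sampling`,
    with r = other(x) / sampling(x)."""
    return sum(q * ((p / q) - 1.0 - math.log(p / q))
               for q, p in zip(sampling, other))


rollout = [0.7, 0.2, 0.1]  # toy rollout (inference) policy
train = [0.5, 0.3, 0.2]    # toy training policy

k3 = expected_k3(rollout, train)
print("E_rollout[k3]      :", k3)
print("KL(rollout || train):", kl(rollout, train))
print("KL(train || rollout):", kl(train, rollout))
```

The first two printed values agree and the third differs, so a doc comment that names the opposite direction would mislead anyone comparing the metric against a threshold.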

Overall, this is an excellent pull request that adds significant value and improves the quality of the trainer.


hjh0119 commented Nov 25, 2025

/gemini review

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces significant improvements to GRPO training by adding rollout correction mechanisms to address training-inference mismatch, refactoring the padding-free mode for simplicity, and adding a comprehensive set of diagnostic metrics. The changes are well-structured, and the inclusion of detailed documentation and new unit tests is commendable. My review has identified one potential issue concerning padding direction that could affect the correctness of the importance sampling calculations.
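
Why padding direction matters for the importance weights can be seen in a toy case. The helpers below are hypothetical and the values are invented; the point is only that a mask built for one padding side selects the wrong positions in a tensor padded on the other side.

```python
def pad_row(values, max_len, side, pad_value=0.0):
    """Pad one row on the given side up to max_len."""
    padding = [pad_value] * (max_len - len(values))
    return padding + list(values) if side == "left" else list(values) + padding


def masked_log_ratios(train_row, rollout_row, mask):
    """log(pi_train / pi_rollout) on positions where mask is 1."""
    return [t - r for t, r, m in zip(train_row, rollout_row, mask) if m]


logps = [-0.1, -0.7]  # same log-probs under both policies (no real mismatch)
max_len = 4

train_row = pad_row(logps, max_len, "left")
mask = pad_row([1, 1], max_len, "left", pad_value=0)

# Rollout log-probs padded on the same side: log-ratios are all zero.
aligned = masked_log_ratios(train_row, pad_row(logps, max_len, "left"), mask)
# Rollout log-probs padded on the other side: real tokens meet pad values.
misaligned = masked_log_ratios(train_row, pad_row(logps, max_len, "right"), mask)

print("aligned   :", aligned)
print("misaligned:", misaligned)
```

Even though the two policies are identical here, the misaligned case reports non-zero log-ratios, which would turn into spurious importance weights.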


hjh0119 commented Nov 25, 2025

/gemini review

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a crucial feature for GRPO training: rollout correction using Importance Sampling (IS) to mitigate training-inference mismatch. The changes are comprehensive, covering new command-line arguments, integration with vLLM for log probability extraction, refactoring of padding-free logic for consistent tensor shapes, and the implementation of various IS correction modes (token/sequence, truncate/mask) along with detailed diagnostic metrics. The addition of a dedicated test file for the IS correction logic is highly commendable, ensuring the robustness and correctness of this new functionality. The refactoring of _get_per_token_logps_and_entropies to compute log probabilities on flattened tensors and then pad them back to batch shape is a significant improvement for handling padding-free mode more consistently.

Development

Successfully merging this pull request may close these issues.

support Truncated Importance Sampling (TIS)