Support FIPO#9328
Conversation
# Conflicts: # swift/megatron/arguments/megatron_args.py # swift/megatron/trainers/grpo_trainer.py
There was a problem hiding this comment.
Code Review
This pull request implements FIPO (Future-KL Influenced Policy Optimization), including documentation, training scripts, and integration into Megatron and RLHF trainers. Reviewers identified potential numerical stability issues in weight calculations and a discrepancy between the safety threshold logic and its documentation. Recommendations were also made to log FIPO-specific metrics and refactor the loss function to minimize code duplication.
|
thanks for your contribution is this pr ready to be merged? |
|
Yes. Please let me know if there are any remaining concerns or changes needed. |
|
Thanks for your contribution! I've left a few comments |
|
Thanks for your review! I've addressed the comments and pushed the updates. Please let me know if anything else needs to be changed. |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces FIPO (Future-KL Influenced Policy Optimization), a value-free RL method designed for long-chain reasoning, along with comprehensive documentation and training examples. The implementation adds FIPO-specific arguments and loss logic to both the Megatron and standard RLHF trainers, including metrics tracking for future-KL and influence weights. Review feedback suggests removing an extra blank line in the English documentation for consistency and recommends refactoring the duplicated _compute_fipo_influence logic into a shared utility to improve maintainability and reduce code duplication.
|
|
||
| ## Parameters | ||
|
|
||
|
|
| def _compute_fipo_influence(self, log_ratio: torch.Tensor, coef_1: torch.Tensor, advantages: torch.Tensor, | ||
| completion_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | ||
| """Compute FIPO token-level influence weight from Future-KL divergence.""" | ||
| future_kl_delta = log_ratio.masked_fill(~completion_mask, 0.0) | ||
|
|
||
| # Dual-Clip participation mask: high-ratio tokens do not contribute to Future-KL. | ||
| if self.args.delta is not None: | ||
| delta = torch.as_tensor(self.args.delta, dtype=log_ratio.dtype, device=log_ratio.device) | ||
| high_ratio_mask = coef_1 > delta | ||
| future_kl_delta = torch.where(high_ratio_mask, torch.zeros_like(future_kl_delta), future_kl_delta) | ||
|
|
||
| seq_len = future_kl_delta.shape[1] | ||
| future_kl = torch.zeros_like(future_kl_delta) | ||
| positions = torch.arange(seq_len, device=log_ratio.device).unsqueeze(1) | ||
| gamma = torch.as_tensor(self.fipo_gamma, dtype=log_ratio.dtype, device=log_ratio.device) | ||
| chunk_size = 128 | ||
| for chunk_start in range(0, seq_len, chunk_size): | ||
| chunk_end = min(seq_len, chunk_start + chunk_size) | ||
| chunk_positions = torch.arange(chunk_start, chunk_end, device=log_ratio.device).unsqueeze(0) | ||
| distance = chunk_positions - positions | ||
| future_mask = distance >= 0 | ||
| decay_block = torch.pow(gamma, distance.clamp(min=0)) * future_mask.to(log_ratio.dtype) | ||
| future_kl += torch.matmul(future_kl_delta[:, chunk_start:chunk_end], decay_block.t()) | ||
| future_kl = future_kl.masked_fill(~completion_mask, 0.0) | ||
|
|
||
| influence_weight = torch.exp(future_kl) | ||
|
|
||
| if self.fipo_clip_range: | ||
| high = 1 + self.fipo_clip_range | ||
| low = 1.0 if self.fipo_clip_high_only else 1 - self.fipo_clip_range | ||
| influence_weight = torch.clamp(influence_weight, min=low, max=high) | ||
| influence_weight = influence_weight.detach() | ||
|
|
||
| # avoid amplifying negative-advantage tokens with very high IS ratios. | ||
| safety_mask = torch.ones_like(completion_mask, dtype=torch.bool) | ||
| if self.fipo_safety_threshold is not None: | ||
| negative_advantage = advantages.unsqueeze(1) < 0 | ||
| high_is_ratio = coef_1 > self.fipo_safety_threshold | ||
| safety_mask = ~(negative_advantage & high_is_ratio) | ||
| influence_weight = torch.where(safety_mask, influence_weight, | ||
| torch.clamp(influence_weight, min=0.8, max=1.0)) | ||
|
|
||
| metrics = { | ||
| 'future_kl': future_kl, | ||
| 'influence_weight': influence_weight, | ||
| 'safety_mask': safety_mask, | ||
| } | ||
| return influence_weight, metrics |
There was a problem hiding this comment.
This function _compute_fipo_influence is nearly identical to the one in swift/megatron/trainers/grpo_trainer.py. To improve maintainability and avoid code duplication, consider refactoring this logic into a shared utility function that both trainers can call. This will make future updates easier and reduce the risk of inconsistencies.
|
LGTM thanks! |
PR type
PR information
FIPO is a value-free RL algorithm proposed in FIPO: Future-guided Importance Policy Optimization. It is designed to improve long-form reasoning training by refining token-level credit assignment in GRPO/DAPO-style optimization.
Algorithm Overview
FIPO keeps the standard DAPO scaffold, but changes how token-level updates are weighted. The local signal is the signed log-probability shift between the current and old policy:
Positive values mean the token is being reinforced, while negative values mean it is being suppressed. Since reasoning is sequential, FIPO then accumulates this signal over the future trajectory:
Positive
FutureKL_tmeans the future following tokentis being reinforced; negativeFutureKL_tmeans it is being suppressed. The decay window keeps the signal local enough to stay stable, while the mask removes extreme-ratio outliers.FIPO maps this future signal into a bounded influence weight:
The final token-level FIPO loss keeps the standard clipped PPO/DAPO form, but replaces the original advantage with the future-aware one:
Tokens that lead into preferred futures are amplified, while tokens that lead into suppressed futures are attenuated. Clipping keeps this modulation stable. The final DAPO-style loss therefore stays clipped and simple, but the advantage term becomes future-aware rather than uniformly inherited from the final outcome.
Reference
Paper: FIPO: Future-guided Importance Policy Optimization