diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index 02a5d09032..f922f37104 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -608,6 +608,8 @@ reward模型参数将在PPO、GRPO中使用。 - max_turns: 多轮GRPO的轮数上限。默认为None,不做限制。 - top_entropy_quantile: 仅对熵值处于前指定分位的 token 参与损失计算,默认为1.0,即不过滤低熵 token,具体参考[文档](./GRPO/AdvancedResearch/entropy_mask.md) - log_entropy: 记录训练中的熵值变化动态,默认为False,具体参考[文档](./GRPO/GetStarted/GRPO.md#logged-metrics) +- rollout_importance_sampling_mode: 训推不一致校正模式,可选项为 `token_truncate`、`token_mask`、`sequence_truncate`、`sequence_mask`。默认为None,不启用校正。具体参考[文档](./GRPO/AdvancedResearch/training_inference_mismatch.md) +- rollout_importance_sampling_threshold: 重要性采样权重的阈值,用于截断或屏蔽极端权重。默认为2.0。 ##### 奖励函数参数 内置的奖励函数参考[文档](./GRPO/DeveloperGuide/reward_function.md) diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/index.rst b/docs/source/Instruction/GRPO/AdvancedResearch/index.rst index 9f833e1025..73eb9e4966 100644 --- a/docs/source/Instruction/GRPO/AdvancedResearch/index.rst +++ b/docs/source/Instruction/GRPO/AdvancedResearch/index.rst @@ -11,3 +11,4 @@ Advanced Research REINFORCEPP.md CHORD.md CISPO.md + training_inference_mismatch.md diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md b/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md new file mode 100644 index 0000000000..229d35cc74 --- /dev/null +++ b/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md @@ -0,0 +1,204 @@ +# Training-Inference-Mismatch + +**版本依赖**:ms-swift>=3.11 + +**TL;DR**: GRPO 引入 vLLM 加速采样过程的同时,也引入了训练-推理不一致(Training-Inference Mismatch)的问题,从而可能影响训练稳定性。本文将解释这个问题的背景、原因以及相应的解决方案。 + +## Background + +### GRPO 的基本假设 + +GRPO (Group Relative Policy Optimization) 的训练目标可以表示为: + +$$ +\mathcal{L}_{\text{GRPO}} = - \mathbb{E}_{y \sim \pi_\theta} \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t \right) \right] +$$ + +其中: +- $r_t(\theta) = \frac{\pi_\theta(y_t|x, y_{=3.11) +- `kl` / `k3_kl`:训练策略与 rollout 策略之间的 KL 散度(直接估计器 / K3 估计器) +- `training_ppl` / `rollout_ppl`:训练策略和 rollout 策略的困惑度 +- `log_ppl_diff`:log PPL 差异,反映分布偏移程度 +- `ppl_ratio`:PPL 比率 +- `chi2_token` / `chi2_seq`:Token/Sequence 级别的 χ² 散度 + +IS 校正指标(需设置rollout_importance_sampling_mode) +- `is_weight_mean`:平均重要性采样权重 +- `ess`:有效样本大小(Effective Sample Size) +- `clipped_frac`:被截断或屏蔽的样本比例 + +> 训推一致性指标详细说明请参考文档 [Training-Inference-Mismatch](../AdvancedResearch/training_inference_mismatch.md) + 如果设置了`log_completions`, 将保存训练动态在output对应文件夹中,包括 - step:记录时的训练步数 - prompt:模型输入 diff --git a/docs/source/Megatron-SWIFT/GRPO.md b/docs/source/Megatron-SWIFT/GRPO.md index a8aa4df0e4..c59c0880b1 100644 --- a/docs/source/Megatron-SWIFT/GRPO.md +++ b/docs/source/Megatron-SWIFT/GRPO.md @@ -15,6 +15,7 @@ Megatron GRPO 当前已支持以下功能: 以下参数或功能将在后续版本中逐步支持: - **Entropy 相关配置**:如 `top_entropy_quantile`、`log_entropy` +- **Rollout Correction(TIS/MIS)** - **Reward Model / Reward Model Plugin** - **多轮 Rollout 调度机制**(`multi_turn_scheduler`):实现多轮对话策略优化 - **优势估计器**(`advantage_estimator`):支持更复杂的策略梯度估计方法 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index c24f76074a..55356a5007 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -622,6 +622,8 @@ The hyperparameters for the reward function can 
be found in the [Built-in Reward - max_turns: Maximum number of rounds for multi-turn GRPO. The default is None, which means there is no limit. - top_entropy_quantile: Only tokens whose entropy ranks within the specified top quantile are included in the loss calculation. The default is 1.0, which means low-entropy tokens are not filtered. For details, refer to the [documentation](./GRPO/AdvancedResearch/entropy_mask.md). - log_entropy: Logs the entropy values during training. The default is False. For more information, refer to the [documentation](./GRPO/GetStarted/GRPO.md#logged-metrics). +- rollout_importance_sampling_mode: Training-inference mismatch correction mode. Options are `token_truncate`, `token_mask`, `sequence_truncate`, `sequence_mask`. Default is None (disabled). For details, refer to the [documentation](./GRPO/AdvancedResearch/training_inference_mismatch.md). +- rollout_importance_sampling_threshold: Threshold for importance sampling weights, used for truncating or masking extreme weights. Default is 2.0. ##### Reward function parameters diff --git a/docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst b/docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst index 0cf7cd2478..0634af42c4 100644 --- a/docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst +++ b/docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst @@ -11,3 +11,4 @@ Advanced Research RLOO.md CHORD.md CISPO.md + training_inference_mismatch.md diff --git a/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md b/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md new file mode 100644 index 0000000000..7f96359aca --- /dev/null +++ b/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md @@ -0,0 +1,204 @@ +# Training-Inference-Mismatch + +**Version Requirement**: ms-swift>=3.11 + +**TL;DR**: While GRPO introduces vLLM to accelerate the sampling process, it also introduces Training-Inference Mismatch issues that may affect training stability. This document explains the background, causes, and solutions to this problem. + +## Background + +### Basic Assumptions of GRPO + +The training objective of GRPO (Group Relative Policy Optimization) can be expressed as: + +$$ +\mathcal{L}_{\text{GRPO}} = - \mathbb{E}_{y \sim \pi_\theta} \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t \right) \right] +$$ + +Where: +- $r_t(\theta) = \frac{\pi_\theta(y_t|x, y_{=3.11): +- `kl` / `k3_kl`: KL divergence between training policy and rollout policy (direct estimator / K3 estimator) +- `training_ppl` / `rollout_ppl`: Perplexity of training policy and rollout policy +- `log_ppl_diff`: Log PPL difference, reflects the degree of distribution shift +- `ppl_ratio`: PPL ratio +- `chi2_token` / `chi2_seq`: Token/Sequence-level χ² divergence + +IS correction metrics (requires setting rollout_importance_sampling_mode): +- `is_weight_mean`: Average importance sampling weight +- `ess`: Effective Sample Size +- `clipped_frac`: Fraction of samples that were truncated or masked + +> For detailed explanation of training-inference consistency metrics, please refer to [Training-Inference-Mismatch](../AdvancedResearch/training_inference_mismatch.md) + If `log_completions` is set, the training dynamics will be saved in the output directory, including: - step: The training step at the time of logging. - prompt: The model input. 
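For reference, the four `rollout_importance_sampling_mode` options documented above reweight the per-token GRPO loss before the per-sequence aggregation. The following standalone sketch mirrors the `_apply_rollout_importance_sampling` helper introduced later in this PR; `policy_logps`, `rollout_logps` and `mask` are illustrative `[batch, seq_len]` tensors, not ms-swift API names:

```python
import torch

def rollout_is_weights(policy_logps, rollout_logps, mask, mode='token_truncate', threshold=2.0):
    mask = mask.float()
    # log(pi_theta / pi_rollout), clamped so padded or extreme values cannot overflow exp()
    log_ratio = torch.clamp(policy_logps - rollout_logps, -20.0, 20.0)
    ratio = log_ratio.exp()
    if mode == 'token_truncate':      # cap extreme per-token weights at the threshold
        return ratio.clamp(max=threshold)
    if mode == 'token_mask':          # zero out tokens whose weight exceeds the threshold
        return torch.where(ratio <= threshold, ratio, torch.zeros_like(ratio))
    # sequence-level modes: geometric mean of token ratios over completion tokens
    seq_log_ratio = (log_ratio * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
    seq_ratio = seq_log_ratio.exp()
    if mode == 'sequence_truncate':   # cap the whole sequence's weight at the threshold
        return seq_ratio.clamp(max=threshold).unsqueeze(-1).expand_as(ratio)
    if mode == 'sequence_mask':       # drop whole sequences whose weight exceeds the threshold
        return ratio * (seq_ratio <= threshold).float().unsqueeze(-1)
    return ratio
```

Token-level variants act on each token's importance ratio independently, while the sequence-level variants first collapse the token ratios into a per-sequence geometric mean and then truncate or mask the entire completion.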
diff --git a/docs/source_en/Megatron-SWIFT/GRPO.md b/docs/source_en/Megatron-SWIFT/GRPO.md index 3fa9dfb58d..06e79db35a 100644 --- a/docs/source_en/Megatron-SWIFT/GRPO.md +++ b/docs/source_en/Megatron-SWIFT/GRPO.md @@ -15,6 +15,7 @@ Megatron GRPO currently supports the following features: The following parameters or features will be gradually supported in future versions: - **Entropy-related Configuration**: e.g., `top_entropy_quantile`, `log_entropy` +- **Rollout Correction(TIS/MIS)** - **Reward Model / Reward Model Plugin** - **Multi-turn Rollout Scheduling** (`multi_turn_scheduler`): Multi-turn conversation policy optimization - **Advantage Estimator** (`advantage_estimator`): Support for more complex policy gradient estimation methods diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index cb0d2d56df..feef98b3c8 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -46,6 +46,7 @@ def __init__( disable_cascade_attn: bool = False, load_format: str = 'auto', mm_processor_cache_gb: Optional[float] = None, + logprobs_mode: Optional[str] = None, speculative_config: Optional[Union[str, dict]] = None, # lora enable_lora: bool = False, @@ -81,6 +82,7 @@ def __init__( disable_cascade_attn=disable_cascade_attn, load_format=load_format, mm_processor_cache_gb=mm_processor_cache_gb, + logprobs_mode=logprobs_mode, speculative_config=speculative_config, enable_lora=enable_lora, max_loras=max_loras, @@ -181,7 +183,7 @@ def _create_chat_completion_response(self, result, inputs, template: Template, r logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs) toolcall = self._get_toolcall(response, template) - token_ids = template.skip_stop_tokens(output.token_ids) if request_config.return_details else None + token_ids = output.token_ids if request_config.return_details else None choice = ChatCompletionResponseChoice( index=output.index, message=ChatMessage(role='assistant', content=response, tool_calls=toolcall), diff --git a/swift/llm/infer/infer_engine/lmdeploy_engine.py b/swift/llm/infer/infer_engine/lmdeploy_engine.py index 4f7b50c0da..edbca4f5e6 100644 --- a/swift/llm/infer/infer_engine/lmdeploy_engine.py +++ b/swift/llm/infer/infer_engine/lmdeploy_engine.py @@ -260,7 +260,7 @@ async def _infer_full_async( toolcall = self._get_toolcall(response, template) finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token, output.status.name == 'FINISH') - token_ids = template.skip_stop_tokens(output.token_ids) if request_config.return_details else None + token_ids = output.token_ids if request_config.return_details else None choices = [ ChatCompletionResponseChoice( index=0, diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py index f173c4dbfc..09fbe5825f 100644 --- a/swift/llm/infer/infer_engine/pt_engine.py +++ b/swift/llm/infer/infer_engine/pt_engine.py @@ -426,7 +426,7 @@ def _infer_full(self, template: Template, inputs: Dict[str, Any], *, generation_ response = template.decode(generate_ids, template_inputs=template_inputs[i]) finish_reason = self._get_finish_reason(generation_config.max_new_tokens, len(generate_ids), True) toolcall = self._get_toolcall(response, template) - token_ids = template.skip_stop_tokens(generate_ids) if request_config.return_details else None + token_ids = generate_ids if request_config.return_details else None choices.append( 
ChatCompletionResponseChoice( index=j, diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py index 37de0f845e..fe02ac8a3f 100644 --- a/swift/llm/infer/infer_engine/sglang_engine.py +++ b/swift/llm/infer/infer_engine/sglang_engine.py @@ -149,7 +149,7 @@ def _create_chat_completion_response(self, output, inputs, template, return_deta if template.template_meta.response_prefix: response = template.template_meta.response_prefix + response toolcall = self._get_toolcall(response, template) - token_ids = template.skip_stop_tokens(output['output_ids']) if return_details else None + token_ids = output['output_ids'] if return_details else None choice = ChatCompletionResponseChoice( index=0, message=ChatMessage(role='assistant', content=response, tool_calls=toolcall), diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index e5c927cbb4..764a4e49d6 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -70,6 +70,7 @@ def __init__( disable_cascade_attn: bool = False, load_format: str = 'auto', mm_processor_cache_gb: Optional[float] = None, + logprobs_mode: Optional[str] = None, speculative_config: Optional[Union[str, dict]] = None, # lora enable_lora: bool = False, @@ -121,6 +122,7 @@ def __init__( disable_custom_all_reduce=disable_custom_all_reduce, enforce_eager=enforce_eager, limit_mm_per_prompt=limit_mm_per_prompt, + logprobs_mode=logprobs_mode, enable_lora=enable_lora, max_loras=max_loras, max_lora_rank=max_lora_rank, @@ -174,6 +176,7 @@ def _prepare_engine_kwargs( disable_cascade_attn: bool = False, load_format: str = 'auto', mm_processor_cache_gb: Optional[float] = None, + logprobs_mode: Optional[str] = None, speculative_config: Optional[Union[str, dict]] = None, **engine_kwargs, ) -> None: @@ -205,7 +208,7 @@ def _prepare_engine_kwargs( 'The current version of vLLM does not support `limit_mm_per_prompt`. Please upgrade vLLM.') for key in [ 'enable_expert_parallel', 'enable_sleep_mode', 'disable_cascade_attn', 'load_format', - 'mm_processor_cache_gb', 'speculative_config' + 'mm_processor_cache_gb', 'speculative_config', 'logprobs_mode' ]: if key in parameters: if locals()[key] is not None: @@ -573,7 +576,7 @@ def _create_chat_completion_response( logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs) toolcall = self._get_toolcall(content, template) # Use content instead of response for tool calls - token_ids = template.skip_stop_tokens(output.token_ids) if request_config.return_details else None + token_ids = output.token_ids if request_config.return_details else None choice = ChatCompletionResponseChoice( index=output.index, message=ChatMessage( diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index 26d6fc9083..724275850e 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -338,6 +338,9 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs): logger.info('Currently, rollout only supports the vLLM backend. 
Set vLLM backend') kwargs.update(args.get_vllm_engine_kwargs()) kwargs.update({'enable_lora': args.vllm_enable_lora}) # override + # Important: Use processed_logprobs so temperature scaling affects the logprobs + # This is required for correct importance sampling in rollout correction + kwargs['logprobs_mode'] = 'processed_logprobs' # used for RL external rollout backend engine_kwargs = kwargs.get('engine_kwargs', {}) # for RL rollout model weight sync diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index a178d86841..b8b6a518e8 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -343,6 +343,12 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin): # dataset dataset_shuffle: Optional[bool] = True + # Rollout Importance Sampling Correction (off-policy correction) + # Set to None to disable, or choose from: 'token_truncate', 'token_mask', 'sequence_truncate', 'sequence_mask' + rollout_importance_sampling_mode: Optional[Literal['token_truncate', 'token_mask', 'sequence_truncate', + 'sequence_mask']] = None + rollout_importance_sampling_threshold: float = 2.0 # Threshold for truncation/masking (C in paper) + @dataclass class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments): diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index e2ceb54197..468c426cae 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -32,8 +32,8 @@ from ..mixin import SwiftMixin from .rollout_mixin import DataType, RolloutTrainerMixin from .utils import (_ForwardRedirection, compute_chord_loss, get_even_process_data, identity_data_collator, - load_pil_img, make_chord_sft_dataset, patch_profiling_context, patch_profiling_decorator, - patch_save_last_checkpoint, replace_assistant_response_with_ids) + load_pil_img, make_chord_sft_dataset, pad_logps_back_to_batch, patch_profiling_context, + patch_profiling_decorator, patch_save_last_checkpoint, replace_assistant_response_with_ids) try: from trl.trainer.utils import entropy_from_logits @@ -242,12 +242,8 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType: f'Mismatch: {len(gas_chunks)} chunks vs {len(batch_encoded_inputs)} batches' for batch, batch_encoded in zip(gas_chunks, batch_encoded_inputs): - if self.template.padding_free: - lengths = batch_encoded['seq_lengths'] - advantages_stacked = torch.stack([data['advantages'] for data in batch]) - all_advantages = torch.repeat_interleave(advantages_stacked, lengths) - else: - all_advantages = torch.stack([data['advantages'] for data in batch]) + # Advantages are always [batch_size], will be broadcast to [batch_size, seq_len] in loss computation + all_advantages = torch.stack([data['advantages'] for data in batch]) batch_encoded['advantages'] = all_advantages with patch_profiling_context(self, 'log_metrics'): @@ -422,14 +418,8 @@ def log_rewards_all(rewards_per_func: torch.Tensor): old_per_token_logps = batch_encoded['old_per_token_logps'] ref_per_token_logps = batch_encoded['ref_per_token_logps'] completion_mask = batch_encoded['completion_mask'] - if self.template.padding_free: - lengths = batch_encoded['seq_lengths'] - per_token_kl = torch.split(old_per_token_logps - ref_per_token_logps, lengths.tolist(), dim=1) - completion_masks = torch.split(completion_mask, lengths.tolist(), dim=1) - kl = torch.cat([(kl * mask).sum(-1) for kl, mask in zip(per_token_kl, completion_masks)]) - else: - per_token_kl = old_per_token_logps - 
ref_per_token_logps - kl = (per_token_kl * completion_mask).sum(-1) + per_token_kl = old_per_token_logps - ref_per_token_logps + kl = (per_token_kl * completion_mask).sum(-1) kl_list.append(kl) kl = torch.cat(kl_list, dim=0) @@ -815,8 +805,7 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: with torch.no_grad(): batch_encoded_inputs['old_per_token_logps'] = ( - self._get_per_token_logps_and_entropies(self.model, batch_encoded_inputs)[0] - if self.old_policy() or self.kl_in_reward else None) + self._get_per_token_logps_and_entropies(self.model, batch_encoded_inputs)[0]) if self.beta == 0.0: ref_per_token_logps = None elif self.ref_model is not None: @@ -828,15 +817,75 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: self._get_per_token_logps_and_entropies(self.model, batch_encoded_inputs)[0] batch_encoded_inputs['ref_per_token_logps'] = ref_per_token_logps + # Extract vLLM logprobs if available for importance sampling + if self.use_fast_infer: + vllm_logprobs_list = [] + for data in batch: + if 'vllm_logprobs' in data: + vllm_logprobs_list.append(data['vllm_logprobs']) + else: + vllm_logprobs_list.append(None) + + # Convert to tensor if all samples have vllm_logprobs + if all(lp is not None for lp in vllm_logprobs_list): + if self.template.padding_free: + # In padding_free mode, seq_lengths includes prompts for sequences after the first one + # (because logits_to_keep only removes the first prompt). + # But vllm_logprobs only contains completion logprobs for each sequence. + # So we need to use the actual vllm_logprobs lengths, not seq_lengths. + # + # We pad each sequence's vllm_logprobs to match seq_lengths[i] by prepending -1e10 + # for the prompt portion that vLLM doesn't have logprobs for. + seq_lengths = batch_encoded_inputs['seq_lengths'] + vllm_logprobs_aligned = [] + for i, lp in enumerate(vllm_logprobs_list): + target_len = seq_lengths[i].item() + vllm_len = len(lp) + if vllm_len >= target_len: + # vLLM has more tokens than expected, take last target_len + vllm_logprobs_aligned.extend(lp[-target_len:]) + else: + # vLLM has fewer tokens (only completion), pad the front (prompt portion) + pad_len = target_len - vllm_len + vllm_logprobs_aligned.extend([-1e10] * pad_len + list(lp)) + + # Convert to tensor and pad to batch format + vllm_logprobs_rmpad = torch.tensor( + vllm_logprobs_aligned, dtype=torch.float32, device=self.accelerator.device).unsqueeze(0) + + batch_size = seq_lengths.shape[0] + vllm_logps_padded, _ = pad_logps_back_to_batch( + logps_rmpad=vllm_logprobs_rmpad, + logits_to_keep=logits_to_keep, + batch_size=batch_size, + seq_lengths=seq_lengths, + dtype=torch.float32) + batch_encoded_inputs['vllm_per_token_logps'] = vllm_logps_padded + else: + # Standard mode: simple padding + max_len = logits_to_keep + padded_logprobs = [] + for lp in vllm_logprobs_list: + # Take last logits_to_keep tokens + lp_tensor = lp[-logits_to_keep:] if len(lp) >= logits_to_keep else lp + # Pad if needed + if len(lp_tensor) < max_len: + # right padding + lp_tensor = lp_tensor + [-1e10] * (max_len - len(lp_tensor)) + padded_logprobs.append(lp_tensor) + batch_encoded_inputs['vllm_per_token_logps'] = torch.tensor( + padded_logprobs, dtype=torch.float32, device=self.accelerator.device) + else: + batch_encoded_inputs['vllm_per_token_logps'] = None + else: + batch_encoded_inputs['vllm_per_token_logps'] = None + ga_batch_encoded_inputs.append(batch_encoded_inputs) # --- log completion lengths --- mode = 'train' if self.model.training else 'eval' device = 
self.accelerator.device - if self.template.padding_free: - local_lengths = [inp['seq_lengths'].tolist() for inp in ga_batch_encoded_inputs] - else: - local_lengths = [inp['completion_mask'].sum(1).tolist() for inp in ga_batch_encoded_inputs] + local_lengths = [inp['completion_mask'].sum(1).tolist() for inp in ga_batch_encoded_inputs] total_lengths = self._gather_and_flatten(local_lengths, dtype=torch.float32, device=device, flatten_level=1) # Store num_items_in_batch for DAPO loss (total completion tokens across all processes) @@ -921,11 +970,11 @@ def _compute_loss_single(self, model, inputs): def _compute_loss_and_metrics(self, model, inputs): """Core loss computation without metrics recording.""" mode = 'train' if self.model.training else 'eval' - - completion_mask = inputs['completion_mask'] - truncated_mask = inputs['truncated_mask'] if self.template.padding_free: - lengths = inputs['seq_lengths'] + completion_mask = inputs['completion_mask_padded'] + else: + completion_mask = inputs['completion_mask'] + truncated_mask = inputs['truncated_mask'] per_token_logps, entropies = self._get_per_token_logps_and_entropies( model, inputs, compute_entropy=self.compute_entropy) @@ -936,11 +985,7 @@ def _compute_loss_and_metrics(self, model, inputs): # fill the padded token with NaN entropies = entropies.masked_fill(completion_mask == 0, float('nan')) if self.args.log_entropy: - if self.template.padding_free: - entropy_list = torch.split(entropies, lengths.tolist()) - per_completion_entropies_mean = torch.stack([torch.nanmean(e) for e in entropy_list]) - else: - per_completion_entropies_mean = torch.nanmean(entropies, dim=1) + per_completion_entropies_mean = torch.nanmean(entropies, dim=1) global_per_completion_entropies_mean = gather(per_completion_entropies_mean) entropy_metrics = { 'entropy_logs': global_per_completion_entropies_mean.tolist(), @@ -960,11 +1005,7 @@ def _compute_loss_and_metrics(self, model, inputs): if all(truncated_mask): logger.info('All completions are overlong and truncated, ' 'resulting in NaN some values for some metrics (e.g., KL)') - if self.template.padding_free: - truncated_mask = torch.repeat_interleave(truncated_mask, lengths).unsqueeze(0) - assert truncated_mask.shape == completion_mask.shape - else: - truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask) + truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask) completion_mask = completion_mask & (~truncated_mask) # Compute the KL divergence between the model and the reference model @@ -983,32 +1024,48 @@ def _compute_loss_and_metrics(self, model, inputs): old_per_token_logps = ( per_token_logps.detach() if inputs['old_per_token_logps'] is None else inputs['old_per_token_logps']) + # Compute rollout diagnostic metrics and apply IS correction if enabled + rollout_correction_metrics = {} + if inputs.get('vllm_per_token_logps') is not None and not self.disable_rollout_importance_sampling: + vllm_per_token_logps = inputs['vllm_per_token_logps'] + + # Always compute diagnostic metrics (KL, PPL, etc.) 
for monitoring off-policy gap + # This helps diagnose whether rollout correction is needed + rollout_correction_metrics = self._compute_rollout_offpolicy_metrics(old_per_token_logps, + vllm_per_token_logps, completion_mask) + + # Apply importance sampling correction if mode is enabled + if self.rollout_importance_sampling_mode is not None: + # Compute the log ratio between policy model and vLLM rollout model + # log π_θ(y|x) - log π_vllm(y|x) + vllm_log_ratio = old_per_token_logps - vllm_per_token_logps + + # Apply importance sampling correction based on mode + rollout_is_weights = self._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) + + # Compute additional IS-specific metrics (ESS, clipped_frac, is_weight_mean) + is_metrics = self._compute_is_correction_metrics(vllm_log_ratio, rollout_is_weights, completion_mask) + rollout_correction_metrics.update(is_metrics) + + # Store IS weights for loss computation + inputs['rollout_is_weights'] = rollout_is_weights + else: + inputs['rollout_is_weights'] = None + else: + inputs['rollout_is_weights'] = None + log_ratio = per_token_logps - old_per_token_logps if self.importance_sampling_level == 'token': log_importance_weights = log_ratio elif self.importance_sampling_level in ['sequence', 'sequence_token']: - if self.template.padding_free: - # split to batch, compute seq-level normalization - log_ratio_list = torch.split(log_ratio.squeeze(0), lengths.tolist()) - mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) - seq_weights = [(lr * m).sum() / m.sum().clamp(min=1.0) for lr, m in zip(log_ratio_list, mask_list)] - seq_level_log_weights = torch.stack(seq_weights).to(log_ratio.dtype).unsqueeze(-1) - if self.importance_sampling_level == 'sequence': - log_importance_weights = seq_level_log_weights - else: - seq_level_log_weight = seq_level_log_weights.detach() - seq_level_log_weight = torch.repeat_interleave(seq_level_log_weight, lengths).unsqueeze(0) - log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight + seq_level_log_weights = ((log_ratio * completion_mask).sum(-1) + / completion_mask.sum(-1).clamp(min=1.0)).unsqueeze(-1) + if self.importance_sampling_level == 'sequence': + log_importance_weights = seq_level_log_weights else: - seq_level_log_weights = ((log_ratio * completion_mask).sum(-1) - / completion_mask.sum(-1).clamp(min=1.0)).unsqueeze(-1) - if self.importance_sampling_level == 'sequence': - log_importance_weights = seq_level_log_weights - else: - # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)] - seq_level_log_weight = seq_level_log_weights.detach() - log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight - + # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)] + seq_level_log_weight = seq_level_log_weights.detach() + log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight else: raise ValueError( f"Unknown importance sampling level: {self.importance_sampling_level}. 
Possible values are 'token' " @@ -1018,47 +1075,32 @@ def _compute_loss_and_metrics(self, model, inputs): if self.loss_type == 'cispo': clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach() - if self.template.padding_free: - advantages = advantages[-coef_1.shape[1]:] - per_token_loss = -clamped_ratios * advantages.unsqueeze(0) * per_token_logps - else: - per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps + per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo']: coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) if self.args.delta is not None: coef_1 = torch.clamp(coef_1, max=self.args.delta) - if self.template.padding_free: - if self.importance_sampling_level == 'sequence': - # Expand sequence-level weights to token-level - coef_1 = torch.repeat_interleave(coef_1.squeeze(-1), lengths).unsqueeze(0) - coef_2 = torch.repeat_interleave(coef_2.squeeze(-1), lengths).unsqueeze(0) - - advantages = advantages[-coef_1.shape[1]:] - per_token_loss1 = coef_1 * advantages.unsqueeze(0) - per_token_loss2 = coef_2 * advantages.unsqueeze(0) - else: - per_token_loss1 = coef_1 * advantages.unsqueeze(1) - per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) per_token_loss = -torch.min(per_token_loss1, per_token_loss2) if entropy_mask is not None: per_token_loss = per_token_loss * entropy_mask if per_token_kl is not None: per_token_loss = per_token_loss + self.beta * per_token_kl + # Apply vLLM importance sampling weights if available + if inputs.get('rollout_is_weights') is not None and self.rollout_importance_sampling_mode is not None: + rollout_is_weights = inputs['rollout_is_weights'] + per_token_loss = per_token_loss * rollout_is_weights + if self.loss_type == 'grpo': - if self.template.padding_free: - loss_list = torch.split(per_token_loss.squeeze(0), lengths.tolist()) - mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) - sample_loss = [(loss * mask).sum() / mask.sum().clamp(min=1.0) - for loss, mask in zip(loss_list, mask_list)] - loss = torch.stack(sample_loss).mean() - else: - loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + # completion_mask is now always [batch_size, seq_len] after pad_back + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() elif self.loss_type == 'bnpo': loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) elif self.loss_type == 'dr_grpo': - batch_size = lengths.shape[0] if self.template.padding_free else inputs['input_ids'].shape[0] + batch_size = completion_mask.shape[0] loss = (per_token_loss * completion_mask).sum() / (batch_size * self.max_completion_length) elif self.loss_type in ['cispo', 'dapo']: # CISPO and DAPO: Normalize by total completion tokens across all processes @@ -1088,23 +1130,20 @@ def masked_batch_mean(x): mean_kl = masked_batch_mean(per_token_kl) metrics_data['kl'] = self.accelerator.gather_for_metrics(mean_kl).nanmean().item() + # Add rollout correction metrics + if rollout_correction_metrics: + metrics_data['rollout_correction'] = rollout_correction_metrics + # Compute the clipped probability ratios if self.loss_type == 'cispo': # CISPO: Only track upper bound clipping - if self.template.padding_free: - is_cispo_clipped = (coef_1 > self.epsilon_high) & 
(advantages.unsqueeze(0) > 0) - else: - is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0) cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) gathered_cispo_clip_ratio = self.accelerator.gather_for_metrics(cispo_clip_ratio) metrics_data['clipping'] = {'cispo_clip_ratio': gathered_cispo_clip_ratio.nanmean().item()} else: - if self.template.padding_free: - is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(0) < 0) - is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(0) > 0) - else: - is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) - is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) is_region_clipped = is_low_clipped | is_high_clipped low_clip = masked_batch_mean(is_low_clipped.float()) @@ -1146,6 +1185,12 @@ def _update_metrics(self, metrics_data): if 'kl' in metrics_data: self._metrics[mode]['kl'].append(metrics_data['kl']) + # Update vLLM correction metrics + if 'rollout_correction' in metrics_data: + rollout_metrics = metrics_data['rollout_correction'] + for key, value in rollout_metrics.items(): + self._metrics[mode][f'rollout_correction/{key}'].append(value) + # Update clipping metrics if 'clipping' in metrics_data: clipping = metrics_data['clipping'] @@ -1314,7 +1359,7 @@ def _get_per_token_logps_and_entropies_sp( k: v for k, v in inputs.items() if k not in [ 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', - 'truncated_mask', 'seq_lengths', 'num_items_in_batch' + 'truncated_mask', 'seq_lengths', 'num_items_in_batch', 'vllm_per_token_logps' ] } sequence_parallel.prepare_inputs(inputs) @@ -1367,8 +1412,16 @@ def _get_per_token_logps_and_entropies_single(self, compute_entropy=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if self.template.sequence_parallel_size > 1: return self._get_per_token_logps_and_entropies_sp(model, inputs, compute_entropy=compute_entropy) + logits_to_keep = inputs['logits_to_keep'] input_ids = inputs['input_ids'] + is_padding_free = self.template.padding_free + + # Store metadata for padding_free restoration + if is_padding_free: + original_seq_lengths = inputs.get('seq_lengths') + batch_size = original_seq_lengths.shape[0] + unwrapped_model = self.accelerator.unwrap_model(model) if is_peft_model(unwrapped_model): parameters = inspect.signature(unwrapped_model.base_model.model.forward).parameters @@ -1376,13 +1429,10 @@ def _get_per_token_logps_and_entropies_single(self, parameters = inspect.signature(unwrapped_model.forward).parameters use_local_entropy = not hasattr(super(), '_get_per_token_logps_and_entropies') and compute_entropy - can_use_super = (not self.is_multimodal and 'logits_to_keep' in parameters and not use_local_entropy) - if 'attention_mask' not in inputs: - # when set padding_free true, the attention_mask is not in inputs - can_use_super = False + can_use_super = (not self.is_multimodal and 'logits_to_keep' in parameters and not use_local_entropy + and not is_padding_free) if can_use_super: - # save memory if hasattr(super(), '_get_per_token_logps_and_entropies'): logps, entropies = super()._get_per_token_logps_and_entropies( model, input_ids, inputs['attention_mask'], logits_to_keep, 
compute_entropy=compute_entropy) @@ -1390,24 +1440,76 @@ def _get_per_token_logps_and_entropies_single(self, logps = super()._get_per_token_logps(model, input_ids, inputs['attention_mask'], logits_to_keep) entropies = None else: - inputs = { + model_inputs = { k: v for k, v in inputs.items() if k not in [ 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', - 'truncated_mask', 'seq_lengths', 'num_items_in_batch' + 'truncated_mask', 'seq_lengths', 'num_items_in_batch', 'vllm_per_token_logps' ] } if 'logits_to_keep' in self.model_kwarg_keys: - inputs['logits_to_keep'] = logits_to_keep + 1 - logits = model(**inputs).logits - # exclude the last logit: it corresponds to the next token pred - logits = logits[:, -(logits_to_keep + 1):-1, :] - logits = logits / self.temperature - input_ids = input_ids[:, -logits_to_keep:] - logps = selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens - entropies = None - if compute_entropy: - entropies = entropy_from_logits(logits) + model_inputs['logits_to_keep'] = logits_to_keep + 1 + + # Forward pass + logits = model(**model_inputs).logits + + # Extract relevant portion and apply temperature + logits = logits[:, -(logits_to_keep + 1):-1, :] / self.temperature + input_ids_for_logps = input_ids[:, -logits_to_keep:] + + # Compute on rmpad, then pad back + if is_padding_free: + # In padding_free mode, compute logps on flattened tensors + logits_rmpad = logits.squeeze(0) # [total_nnz, vocab_size] + input_ids_rmpad = input_ids_for_logps.squeeze(0) # [total_nnz] + + # Compute logps on rmpad tensors + per_token_logps_rmpad = selective_log_softmax(logits_rmpad, input_ids_rmpad) # [total_nnz] + + # Compute entropy if needed + if compute_entropy: + entropy_rmpad = entropy_from_logits(logits_rmpad) # [total_nnz] + else: + entropy_rmpad = None + + # Restore to batch shape using seq_lengths + logps, padded_shape_mask = pad_logps_back_to_batch( + logps_rmpad=per_token_logps_rmpad.unsqueeze(0), # [1, total_nnz] + logits_to_keep=logits_to_keep, + batch_size=batch_size, + seq_lengths=original_seq_lengths) + + # Also restore entropy if computed + if compute_entropy: + entropies, _ = pad_logps_back_to_batch( + logps_rmpad=entropy_rmpad.unsqueeze(0), + logits_to_keep=logits_to_keep, + batch_size=batch_size, + seq_lengths=original_seq_lengths) + else: + entropies = None + + # In padding_free mode, the original completion_mask is [1, logits_to_keep] (flattened). + # We need to convert it to [batch_size, max_seq_len] format. + # The original mask correctly identifies completion vs prompt tokens. 
+ if 'completion_mask_padded' not in inputs: + original_completion_mask = inputs['completion_mask'] # [1, logits_to_keep] + completion_mask_padded, _ = pad_logps_back_to_batch( + logps_rmpad=original_completion_mask.float(), # [1, logits_to_keep] + logits_to_keep=logits_to_keep, + batch_size=batch_size, + seq_lengths=original_seq_lengths, + pad_value=0.0) + # Combine with shape mask to ensure padding positions are also masked + inputs['completion_mask_padded'] = completion_mask_padded + + else: + logps = selective_log_softmax(logits, input_ids_for_logps) + + if compute_entropy: + entropies = entropy_from_logits(logits) + else: + entropies = None return logps, entropies @@ -1488,7 +1590,7 @@ def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep): k: v for k, v in inputs.items() if k not in [ 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', - 'truncated_mask', 'seq_lengths', 'num_items_in_batch' + 'truncated_mask', 'seq_lengths', 'num_items_in_batch', 'vllm_per_token_logps' ] } if 'logits_to_keep' in self.model_kwarg_keys: @@ -1912,6 +2014,10 @@ def _prepare_algorithm_params(self): self.advantage_estimator = args.advantage_estimator self.kl_in_reward = args.kl_in_reward + # Rollout Importance Sampling Correction + self.rollout_importance_sampling_mode = args.rollout_importance_sampling_mode + self.rollout_importance_sampling_threshold = args.rollout_importance_sampling_threshold + def _prepare_chord_dataset(self): # CHORD, https://arxiv.org/abs/2508.11408 self.chord_sft_iterator = None @@ -2028,3 +2134,256 @@ def single_sample_context(): with single_sample_context(): self.truncated_resample_iterator = cyclic_iter(self.get_train_dataloader()) + + def _compute_sequence_level_ratios(self, is_ratio: torch.Tensor, completion_mask: torch.Tensor) -> torch.Tensor: + """ + Helper function to compute sequence-level importance sampling ratios. + + Args: + is_ratio: Token-level IS ratios, shape [B, T] + completion_mask: Boolean mask for completion tokens, shape [B, T] + + Returns: + Sequence-level ratios as geometric mean of token-level ratios + """ + log_ratio = torch.log(is_ratio.clamp(min=1e-10)) + seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_ratios = torch.exp(seq_log_ratios) + + return seq_ratios + + def _apply_rollout_importance_sampling(self, rollout_log_ratio: torch.Tensor, + completion_mask: torch.Tensor) -> torch.Tensor: + """ + Apply vLLM importance sampling correction using one of four modes. 
+ + Args: + rollout_log_ratio: log(π_θ / π_rollout) per token, shape [B, T] + completion_mask: Boolean mask for completion tokens, shape [B, T] + + Returns: + IS weights to multiply with loss, same shape as rollout_log_ratio + """ + mode = self.rollout_importance_sampling_mode + threshold = self.rollout_importance_sampling_threshold + + # Clamp log_ratio to prevent numerical overflow from padding values (-1e10) + # A log_ratio of 20 corresponds to exp(20) ≈ 485 million, which is already extreme + SAFETY_BOUND = 20.0 + rollout_log_ratio_safe = torch.clamp(rollout_log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) + + # Compute importance sampling ratios: exp(log_ratio) + is_ratio = torch.exp(rollout_log_ratio_safe) + + if mode == 'token_truncate': + # Token-level truncated IS: clip ratios from above at threshold + is_weights = torch.clamp(is_ratio, max=threshold) + + elif mode == 'token_mask': + # Token-level masked IS: mask out tokens with ratio > threshold + is_weights = torch.where(is_ratio <= threshold, is_ratio, torch.zeros_like(is_ratio)) + + elif mode == 'sequence_truncate': + # Sequence-level truncated IS: compute sequence-level ratio and clip + seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask) + clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) + + is_weights = clipped_seq_ratios.unsqueeze(-1).expand_as(is_ratio) + + elif mode == 'sequence_mask': + # Sequence-level masked IS: mask entire sequences with ratio > threshold + seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask) + seq_mask = (seq_ratios <= threshold).float() + + # Apply mask to original token-level ratios + is_weights = is_ratio * seq_mask.unsqueeze(-1) + else: + return is_ratio + + return is_weights + + def _compute_rollout_offpolicy_metrics( + self, + per_token_logps: torch.Tensor, + rollout_per_token_logps: torch.Tensor, + completion_mask: torch.Tensor, + ) -> Dict[str, float]: + """ + Compute off-policy diagnostic metrics (always computed for monitoring). + reference: verl/verl/trainer/ppo/rollout_corr_helper.py + + These metrics help diagnose the off-policy gap between rollout and training policies, + which can arise from policy mismatch (e.g., vLLM BF16 vs FSDP FP32), model staleness, + or general distribution shifts. + + Key metrics: + - kl: Direct KL divergence estimator KL(π_rollout || π_training) + - k3_kl: K3 KL estimator for stability (more stable for small KL) + - training_ppl: Perplexity of training policy + - rollout_ppl: Perplexity of rollout policy + - log_ppl_diff: Difference in log perplexities + - ppl_ratio: Ratio of training PPL to rollout PPL + - chi2_token: Token-level χ² divergence E[ρ²] - 1 + - chi2_seq: Sequence-level χ² divergence E[(∏ρ_t)²] - 1 + + Args: + per_token_logps: Log probs from training policy model, shape [B, T] + rollout_per_token_logps: Log probs from rollout policy, shape [B, T] + completion_mask: Boolean mask for completion tokens, shape [B, T] + + Returns: + Dictionary with off-policy diagnostic metrics + """ + SAFETY_BOUND = 20.0 + metrics = {} + + # Helper function for masked mean + def masked_mean(x, mask, axis=None): + if axis is None: + return (x * mask).sum() / mask.sum().clamp(min=1.0) + else: + return (x * mask).sum(axis) / mask.sum(axis).clamp(min=1.0) + + # 1. Training policy perplexity (always computed) + # Formula: exp(-1/|T| * Σ log π_training(y_t|y_ Dict[str, float]: + """ + Compute importance sampling correction metrics (ess, clipped_frac, is_weight_mean). 
+ Only called when rollout_importance_sampling_mode is enabled. + + Args: + vllm_log_ratio: Log ratio log(π_policy / π_rollout), shape [B, T] + is_weights: Importance sampling weights after correction, shape [B, T] + completion_mask: Boolean mask for completion tokens, shape [B, T] + + Returns: + Dictionary with IS-specific metrics: + - is_weight_mean: Mean of IS weights + - ess: Effective Sample Size = 1 / E[(w_i / E[w_i])²] + - clipped_frac: Fraction of clipped/masked samples + """ + metrics = {} + SAFETY_BOUND = 20.0 + threshold = self.rollout_importance_sampling_threshold + threshold_lower = 1.0 / threshold # Default lower threshold (reciprocal of upper) + + # Helper function for masked mean + def masked_mean(x, mask): + return (x * mask).sum() / mask.sum().clamp(min=1.0) + + # Compute IS ratio with safety bounds + log_ratio_safe = torch.clamp(vllm_log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) + is_ratio = torch.exp(log_ratio_safe) + + # 1. IS weight statistics + mean_is_weight = masked_mean(is_weights, completion_mask) + metrics['is_weight_mean'] = self.accelerator.gather_for_metrics(mean_is_weight).nanmean().item() + + # 2. Compute Effective Sample Size (ESS) for IS weights + # ESS = 1 / E[(w_i / E[w_i])²] (using clamped weights for stability) + # This measures how many "effective" independent samples we have after IS weighting + weights_for_ess = is_weights.clamp(min=threshold_lower, max=threshold) + mean_for_ess = masked_mean(weights_for_ess, completion_mask) + is_weights_normalized = weights_for_ess / (mean_for_ess + 1e-8) # Avoid division by zero + ess = 1.0 / masked_mean(is_weights_normalized.square(), completion_mask).clamp(min=1e-10) + metrics['ess'] = self.accelerator.gather_for_metrics(ess).nanmean().item() + + # 3. Fraction of clipped/masked samples + if self.rollout_importance_sampling_mode in ['token_truncate', 'token_mask']: + # Token-level + if self.rollout_importance_sampling_mode == 'token_truncate': + clipped_frac = masked_mean((is_ratio > threshold).float(), completion_mask) + else: # token_mask + clipped_frac = masked_mean((is_weights == 0).float(), completion_mask) + metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item() + else: + # Sequence-level (both truncate and mask) + seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask) + clipped_frac = (seq_ratios > threshold).float().mean() + metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item() + + return metrics diff --git a/swift/trainers/rlhf_trainer/rollout_mixin.py b/swift/trainers/rlhf_trainer/rollout_mixin.py index 61837c2d05..8dcf1caa99 100644 --- a/swift/trainers/rlhf_trainer/rollout_mixin.py +++ b/swift/trainers/rlhf_trainer/rollout_mixin.py @@ -100,6 +100,8 @@ def _prepare_rollout_params(self): self.completion_length_limit_scope = args.completion_length_limit_scope self.async_generate = args.async_generate + # Enable logprobs for vLLM importance sampling if requested + self.request_config = RequestConfig( n=1, max_tokens=args.max_completion_length, @@ -108,7 +110,10 @@ def _prepare_rollout_params(self): top_k=args.top_k, repetition_penalty=args.repetition_penalty, stop=args.stop_words, - return_details=True) + return_details=True, + logprobs=args.use_vllm) + + self.disable_rollout_importance_sampling = False def _prepare_vllm(self): """Initialize vLLM engine (server or colocate mode)""" @@ -144,6 +149,10 @@ def _prepare_vllm(self): self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, 
from_process=0)[0] self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0] self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0] + if self.enable_server_multi_turn: + if getattr(args, 'rollout_importance_sampling_mode', None) is not None: + logger.warning('Rollout importance sampling is disabled for server multi-turn mode') + self.disable_rollout_importance_sampling = True self.rollout_enable_lora = broadcast_object_list(enable_lora, from_process=0)[0] if self.use_gym_env: self.reward_func_names = ['gym_reward'] @@ -226,6 +235,7 @@ def _prepare_vllm_engine(self): template=vllm_template, distributed_executor_backend='external_launcher', engine_kwargs=self.args.vllm_engine_kwargs, + logprobs_mode='processed_logprobs', **lora_kwargs, ) set_expandable_segments(True) @@ -839,9 +849,17 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out if output.rollout_infos: input_data['rollout_infos'] = output.rollout_infos + # Extract vLLM logprobs for importance sampling if available + if choice.logprobs is not None: + # Extract logprobs from the response + # logprobs format: {'content': [{'token': ..., 'logprob': ..., 'bytes': ...}, ...]} + if 'content' in choice.logprobs: + vllm_logprobs = [item['logprob'] for item in choice.logprobs['content']] + input_data['vllm_logprobs'] = vllm_logprobs + input_data['finish_reason'] = choice.finish_reason input_data['is_truncated'] = choice.finish_reason == 'length' - input_data['add_eos'] = not choice.finish_reason == 'length' + input_data['add_eos'] = False if output.rollout_infos: multi_modal_keys = ['images', 'videos', 'audios'] for key in multi_modal_keys: @@ -928,6 +946,10 @@ def _prepare_scheduler(self): return if args.multi_turn_scheduler: + if getattr(args, 'rollout_importance_sampling_mode', None) is not None: + # TODO + logger.warning('Rollout importance sampling mode is not supported for multi-turn scheduler') + self.disable_rollout_importance_sampling = True if isinstance(args.multi_turn_scheduler, str): assert args.multi_turn_scheduler in multi_turns multi_turn_scheduler = multi_turns[args.multi_turn_scheduler](max_turns=args.max_turns) diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 72c7ab5d29..ee2f91790b 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -1124,3 +1124,103 @@ def get_even_process_data(trainer, global_data: List[T]) -> List[T]: end = start + base_size return global_data[start:end] + + +# ============================================================================ +# Padding-free utilities +# ============================================================================ + + +def pad_logps_back_to_batch(logps_rmpad: torch.Tensor, + position_ids: Optional[torch.Tensor] = None, + logits_to_keep: int = None, + batch_size: int = None, + seq_lengths: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, + pad_value: float = -1e10) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Restore padding-free logprobs back to [batch_size, seq_len] shape with LEFT PADDING. 
+ + - Input: logps in rmpad format [1, total_nnz] + - Output: logps in batch format [batch_size, max_seq_len] with data right-aligned + + Args: + logps_rmpad: [1, total_nnz] per-token log probabilities in padding_free format + position_ids: [1, total_nnz] position ids to determine sequence boundaries (deprecated, use seq_lengths) + logits_to_keep: number of tokens to keep per sequence (= max_seq_len) + batch_size: number of sequences in the batch + seq_lengths: [batch_size] actual sequence lengths (preferred over position_ids) + dtype: optional dtype for output, defaults to logps_rmpad.dtype + pad_value: value to use for padding positions (default: -1e10 for logps, use 0.0 for masks) + + Returns: + logps_padded: [batch_size, logits_to_keep] padded log probabilities (left-padded, data right-aligned) + valid_mask: [batch_size, logits_to_keep] mask indicating valid (non-padding) positions + """ + if dtype is None: + dtype = logps_rmpad.dtype + + device = logps_rmpad.device + + # Determine sequence lengths + if seq_lengths is not None: + # Use provided seq_lengths directly - they should already be adjusted + # by the caller (e.g., in _generate_and_score_completions) + # DO NOT adjust again here to avoid double adjustment + pass + else: + # Fallback: infer from position_ids + from swift.utils.torch_utils import get_cu_seqlens_from_position_ids as get_cu_seqlens + cu_seqlens = get_cu_seqlens(position_ids) + + # Adjust cu_seqlens for logits_to_keep if needed + total_length = cu_seqlens[-1].item() + if total_length > logits_to_keep: + # Adjust the first sequence length + adjustment = total_length - logits_to_keep + cu_seqlens = cu_seqlens - adjustment + cu_seqlens[0] = 0 # First element should always be 0 + + # Compute actual sequence lengths + seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] + + # Compute cumulative sequence lengths + cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], device=device), seq_lengths]), dim=0) + max_seq_len = logits_to_keep # All sequences will be padded to this length + + # Initialize output tensors with padding value + logps_padded = torch.full((batch_size, max_seq_len), pad_value, dtype=dtype, device=device) + valid_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.float32, device=device) + + # Unflatten: assign each sequence's logps to the corresponding row + # Use LEFT PADDING (right-align the data) to match the standard padding convention + logps_flat = logps_rmpad.squeeze(0) # [total_nnz] + + for i in range(batch_size): + start_idx = cu_seqlens[i].item() + end_idx = cu_seqlens[i + 1].item() + seq_len = int(seq_lengths[i].item()) + + actual_end_idx = min(end_idx, len(logps_flat)) + actual_len = actual_end_idx - start_idx + + if actual_len <= 0: + continue + + # Left padding: place data at the RIGHT side of the row + # pad_len is the number of padding tokens at the beginning + pad_len = max_seq_len - seq_len + + if actual_len < seq_len: + # Input data is shorter than expected seq_len + # This happens when logps_flat doesn't have enough data + # Place actual data at the rightmost positions + data_pad_len = max_seq_len - actual_len + logps_padded[i, data_pad_len:] = logps_flat[start_idx:actual_end_idx] + valid_mask[i, data_pad_len:] = 1.0 + else: + # Normal case: seq_len tokens of data + logps_padded[i, pad_len:] = logps_flat[start_idx:end_idx] + valid_mask[i, pad_len:] = 1.0 + + return logps_padded, valid_mask diff --git a/tests/train/test_vllm_importance_sampling_basic.py b/tests/train/test_vllm_importance_sampling_basic.py new file mode 100644 index 
0000000000..20e3541515
--- /dev/null
+++ b/tests/train/test_vllm_importance_sampling_basic.py
@@ -0,0 +1,485 @@
+"""
+Basic tests for vLLM Importance Sampling implementation
+
+This test file verifies the core functionality of the vLLM IS correction,
+including the IS weight computation and metrics calculation.
+
+Reference: verl/verl/trainer/ppo/rollout_corr_helper.py
+"""
+
+import torch
+
+
+class MockAccelerator:
+    """Mock accelerator for testing metrics gathering"""
+
+    def __init__(self, device='cpu'):
+        self.device = device
+
+    def gather_for_metrics(self, tensor):
+        # In testing, just return the tensor as-is
+        return tensor
+
+
+class MockGRPOTrainer:
+    """Mock GRPO trainer for testing IS methods"""
+
+    def __init__(self, mode='token_truncate', threshold=2.0):
+        self.rollout_importance_sampling_mode = mode
+        self.rollout_importance_sampling_threshold = threshold
+        self.accelerator = MockAccelerator()
+
+    def _compute_sequence_level_ratios(self, is_ratio: torch.Tensor, completion_mask: torch.Tensor) -> torch.Tensor:
+        """
+        Helper function to compute sequence-level importance sampling ratios.
+
+        Args:
+            is_ratio: Token-level IS ratios, shape [B, T]
+            completion_mask: Boolean mask for completion tokens, shape [B, T]
+
+        Returns:
+            Sequence-level ratios as geometric mean of token-level ratios
+        """
+        log_ratio = torch.log(is_ratio.clamp(min=1e-10))
+        seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
+        seq_ratios = torch.exp(seq_log_ratios)
+
+        return seq_ratios
+
+    def _apply_rollout_importance_sampling(self, rollout_log_ratio: torch.Tensor,
+                                           completion_mask: torch.Tensor) -> torch.Tensor:
+        """
+        Apply vLLM importance sampling correction using one of four modes.
+
+        Args:
+            rollout_log_ratio: log(π_θ / π_rollout) per token, shape [B, T]
+            completion_mask: Boolean mask for completion tokens, shape [B, T]
+
+        Returns:
+            IS weights to multiply with loss, same shape as rollout_log_ratio
+        """
+        mode = self.rollout_importance_sampling_mode
+        threshold = self.rollout_importance_sampling_threshold
+
+        # Clamp log_ratio to prevent numerical overflow from padding values (-1e10)
+        # A log_ratio of 20 corresponds to exp(20) ≈ 485 million, which is already extreme
+        SAFETY_BOUND = 20.0
+        rollout_log_ratio_safe = torch.clamp(rollout_log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)
+
+        # Compute importance sampling ratios: exp(log_ratio)
+        is_ratio = torch.exp(rollout_log_ratio_safe)
+
+        if mode == 'token_truncate':
+            # Token-level truncated IS: clip ratios from above at threshold
+            is_weights = torch.clamp(is_ratio, max=threshold)
+
+        elif mode == 'token_mask':
+            # Token-level masked IS: mask out tokens with ratio > threshold
+            is_weights = torch.where(is_ratio <= threshold, is_ratio, torch.zeros_like(is_ratio))
+
+        elif mode == 'sequence_truncate':
+            # Sequence-level truncated IS: compute sequence-level ratio and clip
+            seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask)
+            clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold)
+
+            is_weights = clipped_seq_ratios.unsqueeze(-1).expand_as(is_ratio)
+
+        elif mode == 'sequence_mask':
+            # Sequence-level masked IS: mask entire sequences with ratio > threshold
+            seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask)
+            seq_mask = (seq_ratios <= threshold).float()
+
+            # Apply mask to original token-level ratios
+            is_weights = is_ratio * seq_mask.unsqueeze(-1)
+        else:
+            return is_ratio
+
+        return is_weights
+
+    def _compute_is_correction_metrics(
+        self,
+        vllm_log_ratio: torch.Tensor,
+        is_weights: torch.Tensor,
+        completion_mask: torch.Tensor,
+    ) -> dict:
+        """
+        Compute importance sampling correction metrics (ess, clipped_frac, is_weight_mean).
+        Only called when rollout_importance_sampling_mode is enabled.
+
+        Args:
+            vllm_log_ratio: Log ratio log(π_policy / π_rollout), shape [B, T]
+            is_weights: Importance sampling weights after correction, shape [B, T]
+            completion_mask: Boolean mask for completion tokens, shape [B, T]
+
+        Returns:
+            Dictionary with IS-specific metrics:
+            - is_weight_mean: Mean of IS weights
+            - ess: Effective Sample Size = 1 / E[(w_i / E[w_i])²]
+            - clipped_frac: Fraction of clipped/masked samples
+        """
+        metrics = {}
+        SAFETY_BOUND = 20.0
+        threshold = self.rollout_importance_sampling_threshold
+        threshold_lower = 1.0 / threshold  # Default lower threshold (reciprocal of upper)
+
+        # Helper function for masked mean
+        def masked_mean(x, mask):
+            return (x * mask).sum() / mask.sum().clamp(min=1.0)
+
+        # Compute IS ratio with safety bounds
+        log_ratio_safe = torch.clamp(vllm_log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)
+        is_ratio = torch.exp(log_ratio_safe)
+
+        # 1. IS weight statistics
+        mean_is_weight = masked_mean(is_weights, completion_mask)
+        metrics['is_weight_mean'] = self.accelerator.gather_for_metrics(mean_is_weight).nanmean().item()
+
+        # 2. Compute Effective Sample Size (ESS) for IS weights
+        # ESS = 1 / E[(w_i / E[w_i])²] (using clamped weights for stability)
+        # This measures how many "effective" independent samples we have after IS weighting
+        weights_for_ess = is_weights.clamp(min=threshold_lower, max=threshold)
+        mean_for_ess = masked_mean(weights_for_ess, completion_mask)
+        is_weights_normalized = weights_for_ess / (mean_for_ess + 1e-8)  # Avoid division by zero
+        ess = 1.0 / masked_mean(is_weights_normalized.square(), completion_mask).clamp(min=1e-10)
+        metrics['ess'] = self.accelerator.gather_for_metrics(ess).nanmean().item()
+
+        # 3. Fraction of clipped/masked samples
+        if self.rollout_importance_sampling_mode in ['token_truncate', 'token_mask']:
+            # Token-level
+            if self.rollout_importance_sampling_mode == 'token_truncate':
+                clipped_frac = masked_mean((is_ratio > threshold).float(), completion_mask)
+            else:  # token_mask
+                clipped_frac = masked_mean((is_weights == 0).float(), completion_mask)
+            metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item()
+        else:
+            # Sequence-level (both truncate and mask)
+            seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask)
+            clipped_frac = (seq_ratios > threshold).float().mean()
+            metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item()
+
+        return metrics
+
+
+class TestVLLMImportanceSampling:
+    """Test suite for vLLM Importance Sampling"""
+
+    def test_token_truncate_basic(self):
+        """Test token-level truncated IS"""
+        trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0)
+
+        # Create mock data: [batch=2, seq_len=4]
+        # Log ratios that will produce ratios [0.5, 1.5, 3.0, 5.0]
+        vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0], [0.8, 1.2, 2.5, 4.0]]))
+        completion_mask = torch.ones_like(vllm_log_ratio)
+
+        is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+
+        # Check truncation at threshold=2.0
+        assert is_weights.shape == vllm_log_ratio.shape
+        assert torch.allclose(is_weights[0, 0], torch.tensor(0.5), atol=1e-5)
+        assert torch.allclose(is_weights[0, 1], torch.tensor(1.5), atol=1e-5)
+        assert torch.allclose(is_weights[0, 2], torch.tensor(2.0), atol=1e-5)  # Truncated
+        assert torch.allclose(is_weights[0, 3], torch.tensor(2.0), atol=1e-5)  # Truncated
+
+    def test_token_mask_basic(self):
+        """Test token-level masked IS"""
+        trainer = MockGRPOTrainer(mode='token_mask', threshold=2.0)
+
+        vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0]]))
+        completion_mask = torch.ones_like(vllm_log_ratio)
+
+        is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+
+        # Check masking: ratio > threshold should be 0
+        assert torch.allclose(is_weights[0, 0], torch.tensor(0.5), atol=1e-5)
+        assert torch.allclose(is_weights[0, 1], torch.tensor(1.5), atol=1e-5)
+        assert torch.allclose(is_weights[0, 2], torch.tensor(0.0), atol=1e-5)  # Masked
+        assert torch.allclose(is_weights[0, 3], torch.tensor(0.0), atol=1e-5)  # Masked
+
+    def test_sequence_truncate_basic(self):
+        """Test sequence-level truncated IS"""
+        trainer = MockGRPOTrainer(mode='sequence_truncate', threshold=2.0)
+
+        # First sequence has high ratios, second has low ratios
+        vllm_log_ratio = torch.log(
+            torch.tensor([
+                [3.0, 3.0, 3.0, 3.0],  # geometric mean=3.0 > 2.0
+                [1.0, 1.0, 1.0, 1.0]
+            ]))  # geometric mean=1.0 < 2.0
+        completion_mask = torch.ones_like(vllm_log_ratio)
+
+        is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+
+        # First sequence should be truncated to 2.0 for all tokens
+        assert torch.allclose(is_weights[0, :], torch.tensor(2.0), atol=1e-5)
+        # Second sequence should remain 1.0
+        assert torch.allclose(is_weights[1, :], torch.tensor(1.0), atol=1e-5)
+
+    def test_sequence_mask_basic(self):
+        """Test sequence-level masked IS"""
+        trainer = MockGRPOTrainer(mode='sequence_mask', threshold=2.0)
+
+        vllm_log_ratio = torch.log(
+            torch.tensor([
+                [3.0, 3.0, 3.0, 3.0],  # geometric mean=3.0 > 2.0
+                [1.0, 1.0, 1.0, 1.0]
+            ]))  # geometric mean=1.0 < 2.0
+        completion_mask = torch.ones_like(vllm_log_ratio)
+
+        is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+
+        # First sequence should be completely masked (0)
+        # Note: sequence_mask multiplies is_ratio by 0, so all tokens become 0
+        assert torch.allclose(is_weights[0, :], torch.tensor(0.0), atol=1e-5)
+        # Second sequence should keep original ratios (1.0 * 1.0 = 1.0)
+        assert torch.allclose(is_weights[1, :], torch.tensor(1.0), atol=1e-5)
+
+    def test_threshold_sensitivity(self):
+        """Test different threshold values"""
+        vllm_log_ratio = torch.log(torch.tensor([[1.0, 2.0, 3.0, 4.0]]))
+        completion_mask = torch.ones_like(vllm_log_ratio)
+
+        # Test threshold=1.5
+        trainer_low = MockGRPOTrainer(mode='token_truncate', threshold=1.5)
+        is_weights_low = trainer_low._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+
+        # Test threshold=3.5
+        trainer_high = MockGRPOTrainer(mode='token_truncate', threshold=3.5)
+        is_weights_high = trainer_high._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+
+        # Lower threshold should truncate more
+        truncated_low = (is_weights_low < torch.exp(vllm_log_ratio)).sum()
+        truncated_high = (is_weights_high < torch.exp(vllm_log_ratio)).sum()
+        assert truncated_low > truncated_high
+
+    def test_completion_mask(self):
+        """Test that completion mask is respected"""
+        trainer = MockGRPOTrainer(mode='sequence_truncate', threshold=2.0)
+
+        vllm_log_ratio = torch.log(torch.tensor([[3.0, 3.0, 3.0, 3.0]]))
+        # Mask out last two tokens
+        completion_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
+
+        is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+
+        # Should only consider masked tokens for sequence ratio calculation
+        # With only first two tokens (both 3.0), geometric mean=3.0, truncated to 2.0
+        assert torch.allclose(is_weights[0, :2], torch.tensor(2.0), atol=1e-5)
+
+    def test_edge_cases(self):
+        """Test edge cases"""
+        trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0)
+
+        # Case 1: All ratios below threshold
+        vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.0, 1.5]]))
+        completion_mask = torch.ones_like(vllm_log_ratio)
+        is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+        assert torch.allclose(is_weights, torch.exp(vllm_log_ratio), atol=1e-5)
+
+        # Case 2: All ratios above threshold
+        vllm_log_ratio = torch.log(torch.tensor([[3.0, 4.0, 5.0]]))
+        is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask[:, :3])
+        assert torch.allclose(is_weights, torch.tensor(2.0), atol=1e-5)
+
+        # Case 3: Empty mask
+        vllm_log_ratio = torch.log(torch.tensor([[1.0, 2.0, 3.0]]))
+        completion_mask = torch.zeros_like(vllm_log_ratio)
+        is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+        # Should still compute but result may not be meaningful
+        assert is_weights.shape == vllm_log_ratio.shape
+
+    def test_safety_bound(self):
+        """Test that extreme log ratios are clamped"""
+        trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0)
+
+        # Create extreme log ratios that would overflow without clamping
+        vllm_log_ratio = torch.tensor([[100.0, -100.0, 0.0]])  # exp(100) would overflow
+        completion_mask = torch.ones_like(vllm_log_ratio)
+
+        is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+
+        # Should not have inf or nan
+        assert torch.isfinite(is_weights).all()
+        # Large positive log_ratio should be clamped to threshold
+        assert is_weights[0, 0] <= 2.0
+        # Large negative log_ratio should result in small positive value
+        assert is_weights[0, 1] > 0
+
+
+class TestISCorrectionMetrics:
+    """Test suite for IS correction metrics"""
+
+    def test_ess_uniform_weights(self):
+        """Test ESS with uniform weights (should be close to 1.0)"""
+        trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0)
+
+        # Uniform weights of 1.0
+        vllm_log_ratio = torch.zeros((2, 4))  # exp(0) = 1.0
+        completion_mask = torch.ones_like(vllm_log_ratio)
+        is_weights = torch.ones_like(vllm_log_ratio)
+
+        metrics = trainer._compute_is_correction_metrics(vllm_log_ratio, is_weights, completion_mask)
+
+        # ESS should be 1.0 for uniform weights
+        assert abs(metrics['ess'] - 1.0) < 0.01
+        # Mean weight should be 1.0
+        assert abs(metrics['is_weight_mean'] - 1.0) < 0.01
+        # No clipping for uniform weights
+        assert metrics['clipped_frac'] == 0.0
+
+    def test_ess_varied_weights(self):
+        """Test ESS with varied weights (should be < 1.0)"""
+        trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0)
+
+        # Varied weights
+        vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.0, 1.5, 2.0]]))
+        completion_mask = torch.ones_like(vllm_log_ratio)
+        is_weights = torch.tensor([[0.5, 1.0, 1.5, 2.0]])
+
+        metrics = trainer._compute_is_correction_metrics(vllm_log_ratio, is_weights, completion_mask)
+
+        # ESS should be less than 1.0 for non-uniform weights
+        assert metrics['ess'] < 1.0
+        assert metrics['ess'] > 0.0
+
+    def test_clipped_frac_token_truncate(self):
+        """Test clipped_frac for token_truncate mode"""
+        trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0)
+
+        # 2 out of 4 tokens exceed threshold
+        vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0]]))
+        completion_mask = torch.ones_like(vllm_log_ratio)
+        is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+
+        metrics = trainer._compute_is_correction_metrics(vllm_log_ratio, is_weights, completion_mask)
+
+        # 2/4 = 0.5 tokens clipped
+        assert abs(metrics['clipped_frac'] - 0.5) < 0.01
+
+    def test_clipped_frac_token_mask(self):
+        """Test clipped_frac for token_mask mode"""
+        trainer = MockGRPOTrainer(mode='token_mask', threshold=2.0)
+
+        # 2 out of 4 tokens exceed threshold
+        vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0]]))
+        completion_mask = torch.ones_like(vllm_log_ratio)
+        is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+
+        metrics = trainer._compute_is_correction_metrics(vllm_log_ratio, is_weights, completion_mask)
+
+        # 2/4 = 0.5 tokens masked (is_weights == 0)
+        assert abs(metrics['clipped_frac'] - 0.5) < 0.01
+
+    def test_clipped_frac_sequence_level(self):
+        """Test clipped_frac for sequence-level modes"""
+        trainer = MockGRPOTrainer(mode='sequence_truncate', threshold=2.0)
+
+        # First sequence exceeds threshold, second doesn't
+        vllm_log_ratio = torch.log(torch.tensor([[3.0, 3.0, 3.0, 3.0], [1.0, 1.0, 1.0, 1.0]]))
+        completion_mask = torch.ones_like(vllm_log_ratio)
+        is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+
+        metrics = trainer._compute_is_correction_metrics(vllm_log_ratio, is_weights, completion_mask)
+
+        # 1/2 = 0.5 sequences clipped
+        assert abs(metrics['clipped_frac'] - 0.5) < 0.01
+
+
+class TestOffpolicyMetrics:
+    """Test suite for off-policy diagnostic metrics"""
+
+    def test_kl_divergence_same_policy(self):
+        """Test KL divergence when policies are identical"""
+        # When per_token_logps == rollout_per_token_logps, KL should be 0
+        per_token_logps = torch.tensor([[-1.0, -2.0, -1.5, -0.5]])
+        rollout_per_token_logps = per_token_logps.clone()
+        completion_mask = torch.ones_like(per_token_logps)
+
+        # Helper function for masked mean
+        def masked_mean(x, mask, axis=None):
+            if axis is None:
+                return (x * mask).sum() / mask.sum().clamp(min=1.0)
+            else:
+                return (x * mask).sum(axis) / mask.sum(axis).clamp(min=1.0)
+
+        # KL = E[log(π_rollout) - log(π_training)]
+        kl = masked_mean(rollout_per_token_logps - per_token_logps, completion_mask)
+
+        assert abs(kl.item()) < 1e-6
+
+    def test_k3_kl_estimator(self):
+        """Test K3 KL estimator"""
+        per_token_logps = torch.tensor([[-1.0, -2.0, -1.5, -0.5]])
+        rollout_per_token_logps = torch.tensor([[-1.1, -1.9, -1.6, -0.4]])
+        completion_mask = torch.ones_like(per_token_logps)
+
+        def masked_mean(x, mask, axis=None):
+            if axis is None:
+                return (x * mask).sum() / mask.sum().clamp(min=1.0)
+            else:
+                return (x * mask).sum(axis) / mask.sum(axis).clamp(min=1.0)
+
+        # K3 estimator: E[exp(log_ratio) - log_ratio - 1]
+        log_ratio = per_token_logps - rollout_per_token_logps
+        log_ratio *= completion_mask
+        k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1
+        k3_kl = masked_mean(k3_kl_matrix, completion_mask)
+
+        # K3 KL should be non-negative
+        assert k3_kl.item() >= 0
+
+    def test_chi2_divergence(self):
+        """Test χ² divergence calculation"""
+        per_token_logps = torch.tensor([[-1.0, -2.0]])
+        rollout_per_token_logps = torch.tensor([[-1.5, -1.5]])
+        completion_mask = torch.ones_like(per_token_logps)
+
+        def masked_mean(x, mask, axis=None):
+            if axis is None:
+                return (x * mask).sum() / mask.sum().clamp(min=1.0)
+            else:
+                return (x * mask).sum(axis) / mask.sum(axis).clamp(min=1.0)
+
+        SAFETY_BOUND = 20.0
+        log_ratio = per_token_logps - rollout_per_token_logps
+        log_ratio_safe = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)
+        rho_token = torch.exp(log_ratio_safe)
+        rho_squared_token = rho_token.square()
+        chi2_token = masked_mean(rho_squared_token, completion_mask) - 1.0
+
+        # χ² should be >= -1 (can be negative if E[ρ²] < 1)
+        assert chi2_token.item() >= -1.0
+
+
+if __name__ == '__main__':
+    # Run tests manually
+    import sys
+
+    test_classes = [
+        ('TestVLLMImportanceSampling', TestVLLMImportanceSampling),
+        ('TestISCorrectionMetrics', TestISCorrectionMetrics),
+        ('TestOffpolicyMetrics', TestOffpolicyMetrics),
+    ]
+
+    failed_tests = []
+
+    for class_name, test_class in test_classes:
+        print(f'\n=== {class_name} ===')
+        test_instance = test_class()
+
+        test_methods = [m for m in dir(test_instance) if m.startswith('test_')]
+
+        for method_name in test_methods:
+            try:
+                print(f'Running {method_name}...')
+                getattr(test_instance, method_name)()
+                print(f'✓ {method_name} passed')
+            except Exception as e:
+                print(f'✗ {method_name} failed: {e}')
+                failed_tests.append(f'{class_name}.{method_name}')
+
+    if failed_tests:
+        print(f'\nFailed tests: {failed_tests}')
+        sys.exit(1)
+    else:
+        print('\nAll tests passed!')