From c6dcaac3fb33f59ceb5250703ca3d288a9bc0639 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 20 Nov 2025 16:20:49 +0800 Subject: [PATCH 01/21] init --- .../infer/infer_engine/grpo_vllm_engine.py | 2 + swift/llm/infer/infer_engine/vllm_engine.py | 5 +- swift/trainers/arguments.py | 5 + swift/trainers/rlhf_trainer/grpo_trainer.py | 262 ++++++++++++++++++ swift/trainers/rlhf_trainer/rollout_mixin.py | 17 +- .../test_vllm_importance_sampling_basic.py | 257 +++++++++++++++++ 6 files changed, 546 insertions(+), 2 deletions(-) create mode 100644 tests/train/test_vllm_importance_sampling_basic.py diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 18b626a505..54690097cf 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -46,6 +46,7 @@ def __init__( disable_cascade_attn: bool = False, load_format: str = 'auto', mm_processor_cache_gb: Optional[float] = None, + logprobs_mode: Optional[str] = None, # lora enable_lora: bool = False, max_loras: int = 1, @@ -80,6 +81,7 @@ def __init__( disable_cascade_attn=disable_cascade_attn, load_format=load_format, mm_processor_cache_gb=mm_processor_cache_gb, + logprobs_mode=logprobs_mode, enable_lora=enable_lora, max_loras=max_loras, max_lora_rank=max_lora_rank, diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index 62b1d77f87..7ee696be4c 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -70,6 +70,7 @@ def __init__( disable_cascade_attn: bool = False, load_format: str = 'auto', mm_processor_cache_gb: Optional[float] = None, + logprobs_mode: Optional[str] = None, # lora enable_lora: bool = False, max_loras: int = 1, @@ -120,6 +121,7 @@ def __init__( disable_custom_all_reduce=disable_custom_all_reduce, enforce_eager=enforce_eager, limit_mm_per_prompt=limit_mm_per_prompt, + logprobs_mode=logprobs_mode, enable_lora=enable_lora, max_loras=max_loras, max_lora_rank=max_lora_rank, @@ -172,6 +174,7 @@ def _prepare_engine_kwargs( disable_cascade_attn: bool = False, load_format: str = 'auto', mm_processor_cache_gb: Optional[float] = None, + logprobs_mode: Optional[str] = None, **engine_kwargs, ) -> None: if task == 'embedding': @@ -202,7 +205,7 @@ def _prepare_engine_kwargs( 'The current version of vLLM does not support `limit_mm_per_prompt`. 
Please upgrade vLLM.') for key in [ 'enable_expert_parallel', 'enable_sleep_mode', 'disable_cascade_attn', 'load_format', - 'mm_processor_cache_gb' + 'mm_processor_cache_gb', 'logprobs_mode' ]: if key in parameters: if locals()[key] is not None: diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index 42f6afdcdd..81895b3dc4 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -338,6 +338,11 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin): # dataset dataset_shuffle: Optional[bool] = True + # Rollout Importance Sampling Correction (off-policy correction) + rollout_importance_sampling_mode: Literal['token_truncate', 'token_mask', 'sequence_truncate', + 'sequence_mask'] = 'token_truncate' + rollout_importance_sampling_threshold: float = 2.0 # Threshold for truncation/masking (C in paper) + @dataclass class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments): diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 3fcafb4bcf..e54f5884d9 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -828,6 +828,42 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: self._get_per_token_logps_and_entropies(self.model, batch_encoded_inputs)[0] batch_encoded_inputs['ref_per_token_logps'] = ref_per_token_logps + # Extract vLLM logprobs if available for importance sampling + if self.use_fast_infer: + vllm_logprobs_list = [] + for data in batch: + if 'vllm_logprobs' in data: + vllm_logprobs_list.append(data['vllm_logprobs']) + else: + vllm_logprobs_list.append(None) + + # Convert to tensor if all samples have vllm_logprobs + if all(lp is not None for lp in vllm_logprobs_list): + if self.template.padding_free: + # For padding-free mode, concatenate all logprobs + vllm_logprobs_flat = [] + for lp in vllm_logprobs_list: + vllm_logprobs_flat.extend(lp[-logits_to_keep:] if len(lp) >= logits_to_keep else lp) + batch_encoded_inputs['vllm_per_token_logps'] = torch.tensor( + vllm_logprobs_flat, dtype=torch.float32, device=self.accelerator.device).unsqueeze(0) + else: + # For standard mode, pad to match completion length + max_len = logits_to_keep + padded_logprobs = [] + for lp in vllm_logprobs_list: + # Take last logits_to_keep tokens + lp_tensor = lp[-logits_to_keep:] if len(lp) >= logits_to_keep else lp + # Pad if needed + if len(lp_tensor) < max_len: + lp_tensor = [0.0] * (max_len - len(lp_tensor)) + lp_tensor + padded_logprobs.append(lp_tensor) + batch_encoded_inputs['vllm_per_token_logps'] = torch.tensor( + padded_logprobs, dtype=torch.float32, device=self.accelerator.device) + else: + batch_encoded_inputs['vllm_per_token_logps'] = None + else: + batch_encoded_inputs['vllm_per_token_logps'] = None + ga_batch_encoded_inputs.append(batch_encoded_inputs) # --- log completion lengths --- @@ -983,6 +1019,32 @@ def _compute_loss_and_metrics(self, model, inputs): old_per_token_logps = ( per_token_logps.detach() if inputs['old_per_token_logps'] is None else inputs['old_per_token_logps']) + # Apply vLLM importance sampling correction if enabled + rollout_correction_metrics = {} + if inputs.get('vllm_per_token_logps') is not None: + vllm_per_token_logps = inputs['vllm_per_token_logps'] + # Compute the log ratio between policy model and vLLM rollout model + # log π_θ(y|x) - log π_vllm(y|x) + vllm_log_ratio = per_token_logps - vllm_per_token_logps + + # Apply importance sampling correction based on mode + if 
self.rollout_importance_sampling_mode: + rollout_is_weights = self._apply_rollout_importance_sampling( + vllm_log_ratio, completion_mask, lengths if self.template.padding_free else None) + else: + rollout_is_weights = None + + # Compute and log correction metrics if enabled + rollout_correction_metrics = self._compute_rollout_correction_metrics( + per_token_logps, vllm_per_token_logps, rollout_is_weights, completion_mask, + lengths if self.template.padding_free else None) + + # Apply IS weights: multiply the final loss by the IS weight + # Store for later application in loss computation + inputs['rollout_is_weights'] = rollout_is_weights + else: + inputs['rollout_is_weights'] = None + log_ratio = per_token_logps - old_per_token_logps if self.importance_sampling_level == 'token': log_importance_weights = log_ratio @@ -1046,6 +1108,11 @@ def _compute_loss_and_metrics(self, model, inputs): if per_token_kl is not None: per_token_loss = per_token_loss + self.beta * per_token_kl + # Apply vLLM importance sampling weights if available + if inputs.get('rollout_is_weights') is not None: + rollout_is_weights = inputs['rollout_is_weights'] + per_token_loss = per_token_loss * rollout_is_weights + if self.loss_type == 'grpo': if self.template.padding_free: loss_list = torch.split(per_token_loss.squeeze(0), lengths.tolist()) @@ -1088,6 +1155,10 @@ def masked_batch_mean(x): mean_kl = masked_batch_mean(per_token_kl) metrics_data['kl'] = self.accelerator.gather_for_metrics(mean_kl).nanmean().item() + # Add rollout correction metrics + if rollout_correction_metrics: + metrics_data['rollout_correction'] = rollout_correction_metrics + # Compute the clipped probability ratios if self.loss_type == 'cispo': # CISPO: Only track upper bound clipping @@ -1146,6 +1217,12 @@ def _update_metrics(self, metrics_data): if 'kl' in metrics_data: self._metrics[mode]['kl'].append(metrics_data['kl']) + # Update vLLM correction metrics + if 'rollout_correction' in metrics_data: + rollout_metrics = metrics_data['rollout_correction'] + for key, value in rollout_metrics.items(): + self._metrics[mode][f'rollout_correction/{key}'].append(value) + # Update clipping metrics if 'clipping' in metrics_data: clipping = metrics_data['clipping'] @@ -1912,6 +1989,10 @@ def _prepare_algorithm_params(self): self.advantage_estimator = args.advantage_estimator self.kl_in_reward = args.kl_in_reward + # Rollout Importance Sampling Correction + self.rollout_importance_sampling_mode = args.rollout_importance_sampling_mode + self.rollout_importance_sampling_threshold = args.rollout_importance_sampling_threshold + def _prepare_chord_dataset(self): # CHORD, https://arxiv.org/abs/2508.11408 self.chord_sft_iterator = None @@ -2028,3 +2109,184 @@ def single_sample_context(): with single_sample_context(): self.truncated_resample_iterator = cyclic_iter(self.get_train_dataloader()) + + def _apply_rollout_importance_sampling(self, + vllm_log_ratio: torch.Tensor, + completion_mask: torch.Tensor, + lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Apply vLLM importance sampling correction using one of four modes. 
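+
+        Modes (C = rollout_importance_sampling_threshold), as implemented below:
+            - token_truncate: clamp each token-level ratio at C.
+            - token_mask: keep a token's ratio if it is <= C, otherwise zero it.
+            - sequence_truncate: clamp the per-sequence geometric-mean ratio at C
+              and broadcast it back to every token of the sequence.
+            - sequence_mask: weight 1.0 for sequences whose geometric-mean ratio
+              is <= C, otherwise 0.0 for the whole sequence.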
+ + Args: + vllm_log_ratio: log(π_θ / π_vllm) per token, shape [B, T] or [1, total_T] for padding_free + completion_mask: Boolean mask for completion tokens + lengths: Sequence lengths for padding_free mode + + Returns: + IS weights to multiply with loss, same shape as vllm_log_ratio + """ + mode = self.rollout_importance_sampling_mode + threshold = self.rollout_importance_sampling_threshold + + # Compute importance sampling ratios: exp(log_ratio) + is_ratio = torch.exp(vllm_log_ratio) + + if mode == 'token_truncate': + # Token-level truncated IS: clip ratios from above at threshold + is_weights = torch.clamp(is_ratio, max=threshold) + + elif mode == 'token_mask': + # Token-level masked IS: mask out tokens with ratio > threshold + is_weights = torch.where(is_ratio <= threshold, is_ratio, torch.zeros_like(is_ratio)) + + elif mode == 'sequence_truncate': + # Sequence-level truncated IS: compute sequence-level ratio and clip + if self.template.padding_free: + # Split by sequence lengths + ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) + + seq_ratios = [] + for ratio, mask in zip(ratio_list, mask_list): + # Geometric mean approximation: exp(mean(log(ratio))) + log_ratio = torch.log(ratio.clamp(min=1e-10)) + seq_ratio = torch.exp((log_ratio * mask).sum() / mask.sum().clamp(min=1.0)) + seq_ratios.append(seq_ratio) + + seq_ratios = torch.stack(seq_ratios) + # Clip sequence ratios + clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) + # Broadcast back to tokens + is_weights = torch.repeat_interleave(clipped_seq_ratios, lengths).unsqueeze(0) + else: + # Standard mode: [B, T] + log_ratio = torch.log(is_ratio.clamp(min=1e-10)) + seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_ratios = torch.exp(seq_log_ratios) + clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) + is_weights = clipped_seq_ratios.unsqueeze(-1).expand_as(is_ratio) + + elif mode == 'sequence_mask': + # Sequence-level masked IS: mask entire sequences with ratio > threshold + if self.template.padding_free: + ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) + + seq_ratios = [] + for ratio, mask in zip(ratio_list, mask_list): + log_ratio = torch.log(ratio.clamp(min=1e-10)) + seq_ratio = torch.exp((log_ratio * mask).sum() / mask.sum().clamp(min=1.0)) + seq_ratios.append(seq_ratio) + + seq_ratios = torch.stack(seq_ratios) + # Mask sequences with ratio > threshold + seq_mask = (seq_ratios <= threshold).float() + is_weights = torch.repeat_interleave(seq_mask, lengths).unsqueeze(0) + else: + log_ratio = torch.log(is_ratio.clamp(min=1e-10)) + seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_ratios = torch.exp(seq_log_ratios) + seq_mask = (seq_ratios <= threshold).float() + is_weights = seq_mask.unsqueeze(-1).expand_as(is_ratio) + else: + raise ValueError(f'Unknown rollout importance sampling mode: {mode}') + + return is_weights + + def _compute_rollout_correction_metrics( + self, + per_token_logps: torch.Tensor, + rollout_per_token_logps: torch.Tensor, + is_weights: torch.Tensor, + completion_mask: torch.Tensor, + lengths: Optional[torch.Tensor] = None, + ) -> Dict[str, float]: + """ + Compute rollout correction metrics: KL, PPL, chi-square, ESS. 
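+
+        Reported keys (each gathered across processes before logging):
+            kl_rollout: masked mean of log(π_θ / π_rollout)
+            ppl_rollout / ppl_policy: exp(-masked mean log-prob)
+            chi_square: masked mean of (ratio - 1)^2
+            ess: 1 / mean(s^2) over per-sequence geometric-mean ratios s
+            is_weight_mean: masked mean of the applied IS weights
+            clipped_frac: share of tokens (token modes) or sequences
+                (sequence modes) that were clipped or masked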
+ + Args: + per_token_logps: Log probs from policy model + rollout_per_token_logps: Log probs from rollout + is_weights: Importance sampling weights + completion_mask: Boolean mask for completion tokens + lengths: Sequence lengths for padding_free mode + + Returns: + Dictionary with metrics + """ + metrics = {} + + # Compute log ratios + log_ratio = per_token_logps - rollout_per_token_logps + is_ratio = torch.exp(log_ratio) + + # Helper function for masked mean + def masked_mean(x, mask): + if self.template.padding_free: + # x: [1, T], mask: [1, T] + return (x.squeeze(0) * mask.squeeze(0)).sum() / mask.squeeze(0).sum().clamp(min=1.0) + else: + # x: [B, T], mask: [B, T] + return (x * mask).sum() / mask.sum().clamp(min=1.0) + + # 1. KL divergence: KL(π_θ || π_rollout) ≈ E[log(π_θ/π_rollout)] + kl_div = masked_mean(log_ratio, completion_mask) + metrics['kl_rollout'] = self.accelerator.gather_for_metrics(kl_div).nanmean().item() + + # 2. Perplexity: exp(-mean_log_prob) + rollout_ppl = torch.exp(-masked_mean(rollout_per_token_logps, completion_mask)) + policy_ppl = torch.exp(-masked_mean(per_token_logps, completion_mask)) + metrics['ppl_rollout'] = self.accelerator.gather_for_metrics(rollout_ppl).nanmean().item() + metrics['ppl_policy'] = self.accelerator.gather_for_metrics(policy_ppl).nanmean().item() + + # 3. Chi-square divergence: E[(π_θ/π_rollout - 1)^2] = E[(ratio - 1)^2] + chi_square = masked_mean((is_ratio - 1.0)**2, completion_mask) + metrics['chi_square'] = self.accelerator.gather_for_metrics(chi_square).nanmean().item() + + # 4. Effective Sample Size (ESS): 1 / E[w^2] where w = π_θ/π_rollout + # For sequence-level ESS, we compute per-sequence ratios + if self.template.padding_free: + ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) + + seq_ratios = [] + for ratio, mask in zip(ratio_list, mask_list): + log_r = torch.log(ratio.clamp(min=1e-10)) + seq_ratio = torch.exp((log_r * mask).sum() / mask.sum().clamp(min=1.0)) + seq_ratios.append(seq_ratio) + + seq_ratios = torch.stack(seq_ratios) + else: + log_r = torch.log(is_ratio.clamp(min=1e-10)) + seq_log_ratios = (log_r * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_ratios = torch.exp(seq_log_ratios) + + # ESS = N / (1 + var(w)) ≈ N / mean(w^2) for normalized weights + # But we use unnormalized: ESS = 1 / mean(w^2) + mean_ratio_squared = (seq_ratios**2).mean() + ess = 1.0 / mean_ratio_squared.clamp(min=1e-10) + metrics['ess'] = self.accelerator.gather_for_metrics(ess).nanmean().item() + + # 5. 
IS weight statistics + mean_is_weight = masked_mean(is_weights, completion_mask) + metrics['is_weight_mean'] = self.accelerator.gather_for_metrics(mean_is_weight).nanmean().item() + + # Fraction of clipped/masked samples + if self.rollout_importance_sampling_mode in ['token_truncate', 'token_mask']: + # Token-level + threshold = self.rollout_importance_sampling_threshold + if self.rollout_importance_sampling_mode == 'token_truncate': + clipped_frac = masked_mean((is_ratio > threshold).float(), completion_mask) + else: # token_mask + clipped_frac = masked_mean((is_weights == 0).float(), completion_mask) + metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item() + else: + # Sequence-level + threshold = self.rollout_importance_sampling_threshold + if self.rollout_importance_sampling_mode == 'sequence_truncate': + clipped_frac = (seq_ratios > threshold).float().mean() + else: # sequence_mask + clipped_frac = (is_weights.view(-1)[0] == 0).float() # Any token masked means seq masked + metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item() + + return metrics diff --git a/swift/trainers/rlhf_trainer/rollout_mixin.py b/swift/trainers/rlhf_trainer/rollout_mixin.py index 3a82e54944..7a00563ebb 100644 --- a/swift/trainers/rlhf_trainer/rollout_mixin.py +++ b/swift/trainers/rlhf_trainer/rollout_mixin.py @@ -100,6 +100,11 @@ def _prepare_rollout_params(self): self.completion_length_limit_scope = args.completion_length_limit_scope self.async_generate = args.async_generate + # Enable logprobs for vLLM importance sampling if requested + # TODO: check if logprobs is needed + use_vllm_is = args.rollout_importance_sampling_mode and args.use_vllm + logprobs = use_vllm_is + self.request_config = RequestConfig( n=1, max_tokens=args.max_completion_length, @@ -108,7 +113,8 @@ def _prepare_rollout_params(self): top_k=args.top_k, repetition_penalty=args.repetition_penalty, stop=args.stop_words, - return_details=True) + return_details=True, + logprobs=logprobs) def _prepare_vllm(self): """Initialize vLLM engine (server or colocate mode)""" @@ -229,6 +235,7 @@ def _prepare_vllm_engine(self): template=vllm_template, distributed_executor_backend='external_launcher', engine_kwargs=self.args.vllm_engine_kwargs, + logprobs_mode='processed_logprobs', **lora_kwargs, ) set_expandable_segments(True) @@ -842,6 +849,14 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out if output.rollout_infos: input_data['rollout_infos'] = output.rollout_infos + # Extract vLLM logprobs for importance sampling if available + if choice.logprobs is not None: + # Extract logprobs from the response + # logprobs format: {'content': [{'token': ..., 'logprob': ..., 'bytes': ...}, ...]} + if 'content' in choice.logprobs: + vllm_logprobs = [item['logprob'] for item in choice.logprobs['content']] + input_data['vllm_logprobs'] = vllm_logprobs + input_data['finish_reason'] = choice.finish_reason input_data['is_truncated'] = choice.finish_reason == 'length' input_data['add_eos'] = not choice.finish_reason == 'length' diff --git a/tests/train/test_vllm_importance_sampling_basic.py b/tests/train/test_vllm_importance_sampling_basic.py new file mode 100644 index 0000000000..d44f8ef94a --- /dev/null +++ b/tests/train/test_vllm_importance_sampling_basic.py @@ -0,0 +1,257 @@ +""" +Basic tests for vLLM Importance Sampling implementation + +This test file verifies the core functionality of the vLLM IS correction. 
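+
+Covered modes: token_truncate, token_mask, sequence_truncate and sequence_mask,
+plus padding-free inputs, threshold sensitivity, completion-mask handling and
+edge cases. The mock trainer copies the trainer's weight computation, so the
+tests run with torch alone, without vLLM or a real model.
+
+Example (token_truncate, threshold 2.0)::
+
+    trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0)
+    log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0]]))
+    mask = torch.ones_like(log_ratio, dtype=torch.bool)
+    weights = trainer._apply_vllm_importance_sampling(log_ratio, mask)
+    # -> tensor([[0.5000, 1.5000, 2.0000, 2.0000]])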
+""" + +import torch + + +class MockGRPOTrainer: + """Mock GRPO trainer for testing IS methods""" + + def __init__(self, mode='token_truncate', threshold=2.0): + self.vllm_importance_sampling_mode = mode + self.vllm_importance_sampling_threshold = threshold + self.template = MockTemplate() + + def _apply_vllm_importance_sampling(self, vllm_log_ratio, completion_mask, lengths=None): + """Copy of the implementation from grpo_trainer.py""" + mode = self.vllm_importance_sampling_mode + threshold = self.vllm_importance_sampling_threshold + + is_ratio = torch.exp(vllm_log_ratio) + + if mode == 'token_truncate': + is_weights = torch.clamp(is_ratio, max=threshold) + + elif mode == 'token_mask': + is_weights = torch.where(is_ratio <= threshold, is_ratio, torch.zeros_like(is_ratio)) + + elif mode == 'sequence_truncate': + if self.template.padding_free: + ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) + + seq_ratios = [] + for ratio, mask in zip(ratio_list, mask_list): + log_ratio = torch.log(ratio.clamp(min=1e-10)) + seq_ratio = torch.exp((log_ratio * mask).sum() / mask.sum().clamp(min=1.0)) + seq_ratios.append(seq_ratio) + + seq_ratios = torch.stack(seq_ratios) + clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) + is_weights = torch.repeat_interleave(clipped_seq_ratios, lengths).unsqueeze(0) + else: + log_ratio = torch.log(is_ratio.clamp(min=1e-10)) + seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_ratios = torch.exp(seq_log_ratios) + clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) + is_weights = clipped_seq_ratios.unsqueeze(-1).expand_as(is_ratio) + + elif mode == 'sequence_mask': + if self.template.padding_free: + ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) + + seq_ratios = [] + for ratio, mask in zip(ratio_list, mask_list): + log_ratio = torch.log(ratio.clamp(min=1e-10)) + seq_ratio = torch.exp((log_ratio * mask).sum() / mask.sum().clamp(min=1.0)) + seq_ratios.append(seq_ratio) + + seq_ratios = torch.stack(seq_ratios) + seq_mask = (seq_ratios <= threshold).float() + is_weights = torch.repeat_interleave(seq_mask, lengths).unsqueeze(0) + else: + log_ratio = torch.log(is_ratio.clamp(min=1e-10)) + seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_ratios = torch.exp(seq_log_ratios) + seq_mask = (seq_ratios <= threshold).float() + is_weights = seq_mask.unsqueeze(-1).expand_as(is_ratio) + else: + raise ValueError(f'Unknown mode: {mode}') + + return is_weights + + +class MockTemplate: + + def __init__(self, padding_free=False): + self.padding_free = padding_free + + +class TestVLLMImportanceSampling: + """Test suite for vLLM Importance Sampling""" + + def test_token_truncate_basic(self): + """Test token-level truncated IS""" + trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0) + + # Create mock data: [batch=2, seq_len=4] + # Log ratios that will produce ratios [0.5, 1.5, 3.0, 5.0] + vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0], [0.8, 1.2, 2.5, 4.0]])) + completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + + is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + + # Check truncation at threshold=2.0 + assert is_weights.shape == vllm_log_ratio.shape + assert torch.allclose(is_weights[0, 0], torch.tensor(0.5), atol=1e-5) + 
assert torch.allclose(is_weights[0, 1], torch.tensor(1.5), atol=1e-5) + assert torch.allclose(is_weights[0, 2], torch.tensor(2.0), atol=1e-5) # Truncated + assert torch.allclose(is_weights[0, 3], torch.tensor(2.0), atol=1e-5) # Truncated + + def test_token_mask_basic(self): + """Test token-level masked IS""" + trainer = MockGRPOTrainer(mode='token_mask', threshold=2.0) + + vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0]])) + completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + + is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + + # Check masking: ratio > threshold should be 0 + assert torch.allclose(is_weights[0, 0], torch.tensor(0.5), atol=1e-5) + assert torch.allclose(is_weights[0, 1], torch.tensor(1.5), atol=1e-5) + assert torch.allclose(is_weights[0, 2], torch.tensor(0.0), atol=1e-5) # Masked + assert torch.allclose(is_weights[0, 3], torch.tensor(0.0), atol=1e-5) # Masked + + def test_sequence_truncate_basic(self): + """Test sequence-level truncated IS""" + trainer = MockGRPOTrainer(mode='sequence_truncate', threshold=2.0) + + # First sequence has high ratios, second has low ratios + vllm_log_ratio = torch.log(torch.tensor([ + [3.0, 3.0, 3.0, 3.0], # avg=3.0 > 2.0 + [1.0, 1.0, 1.0, 1.0] + ])) # avg=1.0 < 2.0 + completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + + is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + + # First sequence should be truncated to 2.0 for all tokens + assert torch.allclose(is_weights[0, :], torch.tensor(2.0), atol=1e-5) + # Second sequence should remain 1.0 + assert torch.allclose(is_weights[1, :], torch.tensor(1.0), atol=1e-5) + + def test_sequence_mask_basic(self): + """Test sequence-level masked IS""" + trainer = MockGRPOTrainer(mode='sequence_mask', threshold=2.0) + + vllm_log_ratio = torch.log(torch.tensor([ + [3.0, 3.0, 3.0, 3.0], # avg=3.0 > 2.0 + [1.0, 1.0, 1.0, 1.0] + ])) # avg=1.0 < 2.0 + completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + + is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + + # First sequence should be completely masked (0) + assert torch.allclose(is_weights[0, :], torch.tensor(0.0), atol=1e-5) + # Second sequence should remain 1.0 + assert torch.allclose(is_weights[1, :], torch.tensor(1.0), atol=1e-5) + + def test_padding_free_mode(self): + """Test padding-free mode""" + trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0) + trainer.template.padding_free = True + + # Simulate padding-free: [1, total_tokens] = [1, 6] for two sequences of len 4 and 2 + vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0, 0.8, 1.2]])) + completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + lengths = torch.tensor([4, 2]) # Two sequences: len=4 and len=2 + + is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask, lengths) + + # Should have same shape as input + assert is_weights.shape == vllm_log_ratio.shape + # Check truncation: first sequence tokens 2,3 should be truncated to 2.0 + assert torch.allclose(is_weights[0, 2], torch.tensor(2.0), atol=1e-5) + assert torch.allclose(is_weights[0, 3], torch.tensor(2.0), atol=1e-5) + # Check second sequence: only one token should be truncated if > threshold + # 0.8 < 2.0, so should remain 0.8 + assert torch.allclose(is_weights[0, 4], torch.tensor(0.8), atol=1e-5) + + def test_threshold_sensitivity(self): + """Test different threshold values""" + vllm_log_ratio = 
torch.log(torch.tensor([[1.0, 2.0, 3.0, 4.0]])) + completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + + # Test threshold=1.5 + trainer_low = MockGRPOTrainer(mode='token_truncate', threshold=1.5) + is_weights_low = trainer_low._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + + # Test threshold=3.5 + trainer_high = MockGRPOTrainer(mode='token_truncate', threshold=3.5) + is_weights_high = trainer_high._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + + # Lower threshold should truncate more + truncated_low = (is_weights_low < torch.exp(vllm_log_ratio)).sum() + truncated_high = (is_weights_high < torch.exp(vllm_log_ratio)).sum() + assert truncated_low > truncated_high + + def test_completion_mask(self): + """Test that completion mask is respected""" + trainer = MockGRPOTrainer(mode='sequence_truncate', threshold=2.0) + + vllm_log_ratio = torch.log(torch.tensor([[3.0, 3.0, 3.0, 3.0]])) + # Mask out last two tokens + completion_mask = torch.tensor([[True, True, False, False]]) + + is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + + # Should only consider masked tokens for sequence ratio calculation + # With only first two tokens (both 3.0), avg=3.0, truncated to 2.0 + assert torch.allclose(is_weights[0, :2], torch.tensor(2.0), atol=1e-5) + + def test_edge_cases(self): + """Test edge cases""" + trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0) + + # Case 1: All ratios below threshold + vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.0, 1.5]])) + completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + assert torch.allclose(is_weights, torch.exp(vllm_log_ratio), atol=1e-5) + + # Case 2: All ratios above threshold + vllm_log_ratio = torch.log(torch.tensor([[3.0, 4.0, 5.0]])) + is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + assert torch.allclose(is_weights, torch.tensor(2.0), atol=1e-5) + + # Case 3: Empty mask + vllm_log_ratio = torch.log(torch.tensor([[1.0, 2.0, 3.0]])) + completion_mask = torch.zeros_like(vllm_log_ratio, dtype=torch.bool) + is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + # Should still compute but result may not be meaningful + assert is_weights.shape == vllm_log_ratio.shape + + +if __name__ == '__main__': + # Run tests manually + import sys + + test_instance = TestVLLMImportanceSampling() + + test_methods = [ + 'test_token_truncate_basic', 'test_token_mask_basic', 'test_sequence_truncate_basic', + 'test_sequence_mask_basic', 'test_padding_free_mode', 'test_threshold_sensitivity', 'test_completion_mask', + 'test_edge_cases' + ] + + failed_tests = [] + for method_name in test_methods: + try: + print(f'Running {method_name}...') + getattr(test_instance, method_name)() + print(f'✓ {method_name} passed') + except Exception as e: + print(f'✗ {method_name} failed: {e}') + failed_tests.append(method_name) + + if failed_tests: + print(f'\nFailed tests: {failed_tests}') + sys.exit(1) + else: + print('\nAll tests passed!') From 9392ddb03d75eacdafb2d2385cc2005304c7e348 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sun, 23 Nov 2025 14:04:57 +0800 Subject: [PATCH 02/21] set logprobs 0 --- swift/trainers/rlhf_trainer/rollout_mixin.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/swift/trainers/rlhf_trainer/rollout_mixin.py b/swift/trainers/rlhf_trainer/rollout_mixin.py index 
b25e24f3a7..eb395b4b74 100644 --- a/swift/trainers/rlhf_trainer/rollout_mixin.py +++ b/swift/trainers/rlhf_trainer/rollout_mixin.py @@ -101,9 +101,6 @@ def _prepare_rollout_params(self): self.async_generate = args.async_generate # Enable logprobs for vLLM importance sampling if requested - # TODO: check if logprobs is needed - use_vllm_is = args.rollout_importance_sampling_mode and args.use_vllm - logprobs = use_vllm_is self.request_config = RequestConfig( n=1, @@ -114,7 +111,7 @@ def _prepare_rollout_params(self): repetition_penalty=args.repetition_penalty, stop=args.stop_words, return_details=True, - logprobs=logprobs) + logprobs=args.use_vllm) def _prepare_vllm(self): """Initialize vLLM engine (server or colocate mode)""" From 6f775102e46e46f890394dabe750d02492418af1 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sun, 23 Nov 2025 22:36:02 +0800 Subject: [PATCH 03/21] update metrics --- swift/trainers/rlhf_trainer/grpo_trainer.py | 36 +++++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index c06ec7a8db..7eb6dfe85d 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -2217,8 +2217,12 @@ def _compute_rollout_correction_metrics( metrics = {} # Compute log ratios + # Keep original log_ratio for KL computation (accurate) + # Use clamped version for exponential operations (numerically stable) + SAFETY_BOUND = 20.0 log_ratio = per_token_logps - rollout_per_token_logps - is_ratio = torch.exp(log_ratio) + log_ratio_safe = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) + is_ratio = torch.exp(log_ratio_safe) # Helper function for masked mean def masked_mean(x, mask): @@ -2239,11 +2243,9 @@ def masked_mean(x, mask): metrics['ppl_rollout'] = self.accelerator.gather_for_metrics(rollout_ppl).nanmean().item() metrics['ppl_policy'] = self.accelerator.gather_for_metrics(policy_ppl).nanmean().item() - # 3. Chi-square divergence: E[(π_θ/π_rollout - 1)^2] = E[(ratio - 1)^2] - chi_square = masked_mean((is_ratio - 1.0)**2, completion_mask) - metrics['chi_square'] = self.accelerator.gather_for_metrics(chi_square).nanmean().item() - - # 4. Effective Sample Size (ESS): 1 / E[w^2] where w = π_θ/π_rollout + # 3. Effective Sample Size (ESS): 1 / E[(w/E[w])²] + # ESS measures the "effective" number of independent samples after importance sampling correction + # Higher ESS means better sample quality and more stable gradient estimates # For sequence-level ESS, we compute per-sequence ratios if self.template.padding_free: ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) @@ -2261,13 +2263,18 @@ def masked_mean(x, mask): seq_log_ratios = (log_r * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) seq_ratios = torch.exp(seq_log_ratios) - # ESS = N / (1 + var(w)) ≈ N / mean(w^2) for normalized weights - # But we use unnormalized: ESS = 1 / mean(w^2) - mean_ratio_squared = (seq_ratios**2).mean() - ess = 1.0 / mean_ratio_squared.clamp(min=1e-10) - metrics['ess'] = self.accelerator.gather_for_metrics(ess).nanmean().item() - - # 5. 
IS weight statistics + # ESS = 1 / E[(w/E[w])²] - measures effective number of independent samples + # Following verl implementation: normalize weights to mean=1, then compute ESS + mean_seq_ratio = seq_ratios.mean() + seq_ratios_normalized = seq_ratios / (mean_seq_ratio + 1e-8) + ess = 1.0 / (seq_ratios_normalized**2).mean().clamp(min=1e-10) + # ESS is already normalized (ranges from ~0 to N where N is batch size) + # Divide by batch size to get relative ESS in [0, 1] + num_sequences = max(len(seq_ratios), 1) + ess_normalized = ess / num_sequences + metrics['ess'] = self.accelerator.gather_for_metrics(ess_normalized).nanmean().item() + + # 4. IS weight statistics mean_is_weight = masked_mean(is_weights, completion_mask) metrics['is_weight_mean'] = self.accelerator.gather_for_metrics(mean_is_weight).nanmean().item() @@ -2286,7 +2293,8 @@ def masked_mean(x, mask): if self.rollout_importance_sampling_mode == 'sequence_truncate': clipped_frac = (seq_ratios > threshold).float().mean() else: # sequence_mask - clipped_frac = (is_weights.view(-1)[0] == 0).float() # Any token masked means seq masked + # Check which sequences are masked (ratio > threshold) + clipped_frac = (seq_ratios > threshold).float().mean() metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item() return metrics From 897d099bde836da7dc5a17145b5a76b41c34dc67 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 24 Nov 2025 10:24:50 +0800 Subject: [PATCH 04/21] wip --- swift/trainers/arguments.py | 5 +- swift/trainers/rlhf_trainer/grpo_trainer.py | 110 ++++++++++-------- .../test_vllm_importance_sampling_basic.py | 32 ++--- 3 files changed, 79 insertions(+), 68 deletions(-) diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index 81895b3dc4..835a7dbaac 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -339,8 +339,9 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin): dataset_shuffle: Optional[bool] = True # Rollout Importance Sampling Correction (off-policy correction) - rollout_importance_sampling_mode: Literal['token_truncate', 'token_mask', 'sequence_truncate', - 'sequence_mask'] = 'token_truncate' + # Set to None to disable, or choose from: 'token_truncate', 'token_mask', 'sequence_truncate', 'sequence_mask' + rollout_importance_sampling_mode: Optional[Literal['token_truncate', 'token_mask', 'sequence_truncate', + 'sequence_mask']] = None rollout_importance_sampling_threshold: float = 2.0 # Threshold for truncation/masking (C in paper) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 7eb6dfe85d..c813c1497f 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -816,7 +816,8 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: with torch.no_grad(): batch_encoded_inputs['old_per_token_logps'] = ( self._get_per_token_logps_and_entropies(self.model, batch_encoded_inputs)[0] - if self.old_policy() or self.kl_in_reward else None) + if self.old_policy() or self.kl_in_reward or + (self.use_vllm and self.rollout_importance_sampling_mode is not None) else None) if self.beta == 0.0: ref_per_token_logps = None elif self.ref_model is not None: @@ -853,9 +854,10 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: for lp in vllm_logprobs_list: # Take last logits_to_keep tokens lp_tensor = lp[-logits_to_keep:] if len(lp) >= logits_to_keep else lp - # Pad if needed + # Pad if needed - use a very small 
negative value to avoid affecting ratio computation + # These padded positions should be masked by completion_mask anyway if len(lp_tensor) < max_len: - lp_tensor = [0.0] * (max_len - len(lp_tensor)) + lp_tensor + lp_tensor = [-1e10] * (max_len - len(lp_tensor)) + lp_tensor padded_logprobs.append(lp_tensor) batch_encoded_inputs['vllm_per_token_logps'] = torch.tensor( padded_logprobs, dtype=torch.float32, device=self.accelerator.device) @@ -1021,20 +1023,17 @@ def _compute_loss_and_metrics(self, model, inputs): # Apply vLLM importance sampling correction if enabled rollout_correction_metrics = {} - if inputs.get('vllm_per_token_logps') is not None: + if inputs.get('vllm_per_token_logps') is not None and self.rollout_importance_sampling_mode is not None: vllm_per_token_logps = inputs['vllm_per_token_logps'] # Compute the log ratio between policy model and vLLM rollout model # log π_θ(y|x) - log π_vllm(y|x) vllm_log_ratio = per_token_logps - vllm_per_token_logps # Apply importance sampling correction based on mode - if self.rollout_importance_sampling_mode: - rollout_is_weights = self._apply_rollout_importance_sampling( - vllm_log_ratio, completion_mask, lengths if self.template.padding_free else None) - else: - rollout_is_weights = None + rollout_is_weights = self._apply_rollout_importance_sampling( + vllm_log_ratio, completion_mask, lengths if self.template.padding_free else None) - # Compute and log correction metrics if enabled + # Compute and log correction metrics rollout_correction_metrics = self._compute_rollout_correction_metrics( per_token_logps, vllm_per_token_logps, rollout_is_weights, completion_mask, lengths if self.template.padding_free else None) @@ -2110,6 +2109,42 @@ def single_sample_context(): with single_sample_context(): self.truncated_resample_iterator = cyclic_iter(self.get_train_dataloader()) + def _compute_sequence_level_ratios(self, + is_ratio: torch.Tensor, + completion_mask: torch.Tensor, + lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Helper function to compute sequence-level importance sampling ratios. 
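+
+        For sequence i with token ratios r_{i,t} and completion mask m_{i,t},
+        the sequence-level ratio is the geometric mean over unmasked tokens:
+            s_i = exp( sum_t m_{i,t} * log r_{i,t} / sum_t m_{i,t} )
+        with r clamped to >= 1e-10 before the log for numerical stability.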
+ + Args: + is_ratio: Token-level IS ratios, shape [B, T] or [1, total_T] for padding_free + completion_mask: Boolean mask for completion tokens + lengths: Sequence lengths for padding_free mode + + Returns: + Sequence-level ratios as geometric mean of token-level ratios + """ + if self.template.padding_free: + # Split by sequence lengths + ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) + + seq_ratios = [] + for ratio, mask in zip(ratio_list, mask_list): + # Geometric mean approximation: exp(mean(log(ratio))) + log_ratio = torch.log(ratio.clamp(min=1e-10)) + seq_ratio = torch.exp((log_ratio * mask).sum() / mask.sum().clamp(min=1.0)) + seq_ratios.append(seq_ratio) + + seq_ratios = torch.stack(seq_ratios) + else: + # Standard mode: [B, T] + log_ratio = torch.log(is_ratio.clamp(min=1e-10)) + seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_ratios = torch.exp(seq_log_ratios) + + return seq_ratios + def _apply_rollout_importance_sampling(self, vllm_log_ratio: torch.Tensor, completion_mask: torch.Tensor, @@ -2141,52 +2176,24 @@ def _apply_rollout_importance_sampling(self, elif mode == 'sequence_truncate': # Sequence-level truncated IS: compute sequence-level ratio and clip - if self.template.padding_free: - # Split by sequence lengths - ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) - mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) + seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask, lengths) + clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) - seq_ratios = [] - for ratio, mask in zip(ratio_list, mask_list): - # Geometric mean approximation: exp(mean(log(ratio))) - log_ratio = torch.log(ratio.clamp(min=1e-10)) - seq_ratio = torch.exp((log_ratio * mask).sum() / mask.sum().clamp(min=1.0)) - seq_ratios.append(seq_ratio) - - seq_ratios = torch.stack(seq_ratios) - # Clip sequence ratios - clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) - # Broadcast back to tokens + # Broadcast back to tokens + if self.template.padding_free: is_weights = torch.repeat_interleave(clipped_seq_ratios, lengths).unsqueeze(0) else: - # Standard mode: [B, T] - log_ratio = torch.log(is_ratio.clamp(min=1e-10)) - seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - seq_ratios = torch.exp(seq_log_ratios) - clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) is_weights = clipped_seq_ratios.unsqueeze(-1).expand_as(is_ratio) elif mode == 'sequence_mask': # Sequence-level masked IS: mask entire sequences with ratio > threshold - if self.template.padding_free: - ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) - mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) + seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask, lengths) + seq_mask = (seq_ratios <= threshold).float() - seq_ratios = [] - for ratio, mask in zip(ratio_list, mask_list): - log_ratio = torch.log(ratio.clamp(min=1e-10)) - seq_ratio = torch.exp((log_ratio * mask).sum() / mask.sum().clamp(min=1.0)) - seq_ratios.append(seq_ratio) - - seq_ratios = torch.stack(seq_ratios) - # Mask sequences with ratio > threshold - seq_mask = (seq_ratios <= threshold).float() + # Broadcast back to tokens + if self.template.padding_free: is_weights = torch.repeat_interleave(seq_mask, lengths).unsqueeze(0) else: - log_ratio = 
torch.log(is_ratio.clamp(min=1e-10)) - seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - seq_ratios = torch.exp(seq_log_ratios) - seq_mask = (seq_ratios <= threshold).float() is_weights = seq_mask.unsqueeze(-1).expand_as(is_ratio) else: raise ValueError(f'Unknown rollout importance sampling mode: {mode}') @@ -2264,15 +2271,18 @@ def masked_mean(x, mask): seq_ratios = torch.exp(seq_log_ratios) # ESS = 1 / E[(w/E[w])²] - measures effective number of independent samples + # For distributed training, gather all seq_ratios across ranks first to compute global ESS + all_seq_ratios = self.accelerator.gather_for_metrics(seq_ratios) + # Following verl implementation: normalize weights to mean=1, then compute ESS - mean_seq_ratio = seq_ratios.mean() - seq_ratios_normalized = seq_ratios / (mean_seq_ratio + 1e-8) + mean_seq_ratio = all_seq_ratios.mean() + seq_ratios_normalized = all_seq_ratios / (mean_seq_ratio + 1e-8) ess = 1.0 / (seq_ratios_normalized**2).mean().clamp(min=1e-10) # ESS is already normalized (ranges from ~0 to N where N is batch size) # Divide by batch size to get relative ESS in [0, 1] - num_sequences = max(len(seq_ratios), 1) + num_sequences = max(len(all_seq_ratios), 1) ess_normalized = ess / num_sequences - metrics['ess'] = self.accelerator.gather_for_metrics(ess_normalized).nanmean().item() + metrics['ess'] = ess_normalized.item() # 4. IS weight statistics mean_is_weight = masked_mean(is_weights, completion_mask) diff --git a/tests/train/test_vllm_importance_sampling_basic.py b/tests/train/test_vllm_importance_sampling_basic.py index d44f8ef94a..2fc18dd416 100644 --- a/tests/train/test_vllm_importance_sampling_basic.py +++ b/tests/train/test_vllm_importance_sampling_basic.py @@ -11,14 +11,14 @@ class MockGRPOTrainer: """Mock GRPO trainer for testing IS methods""" def __init__(self, mode='token_truncate', threshold=2.0): - self.vllm_importance_sampling_mode = mode - self.vllm_importance_sampling_threshold = threshold + self.rollout_importance_sampling_mode = mode + self.rollout_importance_sampling_threshold = threshold self.template = MockTemplate() - def _apply_vllm_importance_sampling(self, vllm_log_ratio, completion_mask, lengths=None): + def _apply_rollout_importance_sampling(self, vllm_log_ratio, completion_mask, lengths=None): """Copy of the implementation from grpo_trainer.py""" - mode = self.vllm_importance_sampling_mode - threshold = self.vllm_importance_sampling_threshold + mode = self.rollout_importance_sampling_mode + threshold = self.rollout_importance_sampling_threshold is_ratio = torch.exp(vllm_log_ratio) @@ -93,7 +93,7 @@ def test_token_truncate_basic(self): vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0], [0.8, 1.2, 2.5, 4.0]])) completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) - is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) # Check truncation at threshold=2.0 assert is_weights.shape == vllm_log_ratio.shape @@ -109,7 +109,7 @@ def test_token_mask_basic(self): vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0]])) completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) - is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) # Check masking: ratio > threshold should be 0 assert torch.allclose(is_weights[0, 0], 
torch.tensor(0.5), atol=1e-5) @@ -128,7 +128,7 @@ def test_sequence_truncate_basic(self): ])) # avg=1.0 < 2.0 completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) - is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) # First sequence should be truncated to 2.0 for all tokens assert torch.allclose(is_weights[0, :], torch.tensor(2.0), atol=1e-5) @@ -145,7 +145,7 @@ def test_sequence_mask_basic(self): ])) # avg=1.0 < 2.0 completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) - is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) # First sequence should be completely masked (0) assert torch.allclose(is_weights[0, :], torch.tensor(0.0), atol=1e-5) @@ -162,7 +162,7 @@ def test_padding_free_mode(self): completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) lengths = torch.tensor([4, 2]) # Two sequences: len=4 and len=2 - is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask, lengths) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask, lengths) # Should have same shape as input assert is_weights.shape == vllm_log_ratio.shape @@ -180,11 +180,11 @@ def test_threshold_sensitivity(self): # Test threshold=1.5 trainer_low = MockGRPOTrainer(mode='token_truncate', threshold=1.5) - is_weights_low = trainer_low._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + is_weights_low = trainer_low._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) # Test threshold=3.5 trainer_high = MockGRPOTrainer(mode='token_truncate', threshold=3.5) - is_weights_high = trainer_high._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + is_weights_high = trainer_high._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) # Lower threshold should truncate more truncated_low = (is_weights_low < torch.exp(vllm_log_ratio)).sum() @@ -199,7 +199,7 @@ def test_completion_mask(self): # Mask out last two tokens completion_mask = torch.tensor([[True, True, False, False]]) - is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) # Should only consider masked tokens for sequence ratio calculation # With only first two tokens (both 3.0), avg=3.0, truncated to 2.0 @@ -212,18 +212,18 @@ def test_edge_cases(self): # Case 1: All ratios below threshold vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.0, 1.5]])) completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) - is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) assert torch.allclose(is_weights, torch.exp(vllm_log_ratio), atol=1e-5) # Case 2: All ratios above threshold vllm_log_ratio = torch.log(torch.tensor([[3.0, 4.0, 5.0]])) - is_weights = trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) assert torch.allclose(is_weights, torch.tensor(2.0), atol=1e-5) # Case 3: Empty mask vllm_log_ratio = torch.log(torch.tensor([[1.0, 2.0, 3.0]])) completion_mask = torch.zeros_like(vllm_log_ratio, dtype=torch.bool) - is_weights 
= trainer._apply_vllm_importance_sampling(vllm_log_ratio, completion_mask) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) # Should still compute but result may not be meaningful assert is_weights.shape == vllm_log_ratio.shape From d69bf59de3f4382ea281e3a1002009a2d71babdc Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 24 Nov 2025 11:18:03 +0800 Subject: [PATCH 05/21] rm padding_free --- swift/trainers/rlhf_trainer/__init__.py | 2 + swift/trainers/rlhf_trainer/grpo_trainer.py | 342 ++++++++------------ swift/trainers/rlhf_trainer/utils.py | 69 ++++ 3 files changed, 210 insertions(+), 203 deletions(-) diff --git a/swift/trainers/rlhf_trainer/__init__.py b/swift/trainers/rlhf_trainer/__init__.py index 829dba091b..36c6b2d0ac 100644 --- a/swift/trainers/rlhf_trainer/__init__.py +++ b/swift/trainers/rlhf_trainer/__init__.py @@ -15,6 +15,7 @@ from .rlhf_mixin import RLHFTrainerMixin from .utils import patch_lora_merge, patch_lora_unmerge, round_robin, _ForwardRedirection from .vllm_client import VLLMClient + from .padding_free_utils import pad_logps_back_to_batch, get_cu_seqlens_from_position_ids else: _import_structure = { 'cpo_trainer': ['CPOTrainer'], @@ -28,6 +29,7 @@ 'rlhf_mixin': ['RLHFTrainerMixin'], 'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin', '_ForwardRedirection'], 'vllm_client': ['VLLMClient'], + 'padding_free_utils': ['pad_logps_back_to_batch', 'get_cu_seqlens_from_position_ids'], } import sys diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index c813c1497f..7c6f274b12 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -32,8 +32,8 @@ from ..mixin import SwiftMixin from .rollout_mixin import DataType, RolloutTrainerMixin from .utils import (_ForwardRedirection, compute_chord_loss, get_even_process_data, identity_data_collator, - load_pil_img, make_chord_sft_dataset, patch_profiling_context, patch_profiling_decorator, - patch_save_last_checkpoint, replace_assistant_response_with_ids) + load_pil_img, make_chord_sft_dataset, pad_logps_back_to_batch, patch_profiling_context, + patch_profiling_decorator, patch_save_last_checkpoint, replace_assistant_response_with_ids) try: from trl.trainer.utils import entropy_from_logits @@ -242,12 +242,8 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType: f'Mismatch: {len(gas_chunks)} chunks vs {len(batch_encoded_inputs)} batches' for batch, batch_encoded in zip(gas_chunks, batch_encoded_inputs): - if self.template.padding_free: - lengths = batch_encoded['seq_lengths'] - advantages_stacked = torch.stack([data['advantages'] for data in batch]) - all_advantages = torch.repeat_interleave(advantages_stacked, lengths) - else: - all_advantages = torch.stack([data['advantages'] for data in batch]) + # Advantages are always [batch_size], will be broadcast to [batch_size, seq_len] in loss computation + all_advantages = torch.stack([data['advantages'] for data in batch]) batch_encoded['advantages'] = all_advantages with patch_profiling_context(self, 'log_metrics'): @@ -422,14 +418,9 @@ def log_rewards_all(rewards_per_func: torch.Tensor): old_per_token_logps = batch_encoded['old_per_token_logps'] ref_per_token_logps = batch_encoded['ref_per_token_logps'] completion_mask = batch_encoded['completion_mask'] - if self.template.padding_free: - lengths = batch_encoded['seq_lengths'] - per_token_kl = torch.split(old_per_token_logps - ref_per_token_logps, 
lengths.tolist(), dim=1) - completion_masks = torch.split(completion_mask, lengths.tolist(), dim=1) - kl = torch.cat([(kl * mask).sum(-1) for kl, mask in zip(per_token_kl, completion_masks)]) - else: - per_token_kl = old_per_token_logps - ref_per_token_logps - kl = (per_token_kl * completion_mask).sum(-1) + # Unified: logps and mask are now always [batch_size, seq_len] after pad_back + per_token_kl = old_per_token_logps - ref_per_token_logps + kl = (per_token_kl * completion_mask).sum(-1) kl_list.append(kl) kl = torch.cat(kl_list, dim=0) @@ -840,27 +831,18 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: # Convert to tensor if all samples have vllm_logprobs if all(lp is not None for lp in vllm_logprobs_list): - if self.template.padding_free: - # For padding-free mode, concatenate all logprobs - vllm_logprobs_flat = [] - for lp in vllm_logprobs_list: - vllm_logprobs_flat.extend(lp[-logits_to_keep:] if len(lp) >= logits_to_keep else lp) - batch_encoded_inputs['vllm_per_token_logps'] = torch.tensor( - vllm_logprobs_flat, dtype=torch.float32, device=self.accelerator.device).unsqueeze(0) - else: - # For standard mode, pad to match completion length - max_len = logits_to_keep - padded_logprobs = [] - for lp in vllm_logprobs_list: - # Take last logits_to_keep tokens - lp_tensor = lp[-logits_to_keep:] if len(lp) >= logits_to_keep else lp - # Pad if needed - use a very small negative value to avoid affecting ratio computation - # These padded positions should be masked by completion_mask anyway - if len(lp_tensor) < max_len: - lp_tensor = [-1e10] * (max_len - len(lp_tensor)) + lp_tensor - padded_logprobs.append(lp_tensor) - batch_encoded_inputs['vllm_per_token_logps'] = torch.tensor( - padded_logprobs, dtype=torch.float32, device=self.accelerator.device) + max_len = logits_to_keep + padded_logprobs = [] + for lp in vllm_logprobs_list: + # Take last logits_to_keep tokens + lp_tensor = lp[-logits_to_keep:] if len(lp) >= logits_to_keep else lp + # Pad if needed - use a very small negative value to avoid affecting ratio computation + # These padded positions will be masked by completion_mask + if len(lp_tensor) < max_len: + lp_tensor = [-1e10] * (max_len - len(lp_tensor)) + lp_tensor + padded_logprobs.append(lp_tensor) + batch_encoded_inputs['vllm_per_token_logps'] = torch.tensor( + padded_logprobs, dtype=torch.float32, device=self.accelerator.device) else: batch_encoded_inputs['vllm_per_token_logps'] = None else: @@ -871,10 +853,7 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: # --- log completion lengths --- mode = 'train' if self.model.training else 'eval' device = self.accelerator.device - if self.template.padding_free: - local_lengths = [inp['seq_lengths'].tolist() for inp in ga_batch_encoded_inputs] - else: - local_lengths = [inp['completion_mask'].sum(1).tolist() for inp in ga_batch_encoded_inputs] + local_lengths = [inp['completion_mask'].sum(1).tolist() for inp in ga_batch_encoded_inputs] total_lengths = self._gather_and_flatten(local_lengths, dtype=torch.float32, device=device, flatten_level=1) # Store num_items_in_batch for DAPO loss (total completion tokens across all processes) @@ -962,8 +941,6 @@ def _compute_loss_and_metrics(self, model, inputs): completion_mask = inputs['completion_mask'] truncated_mask = inputs['truncated_mask'] - if self.template.padding_free: - lengths = inputs['seq_lengths'] per_token_logps, entropies = self._get_per_token_logps_and_entropies( model, inputs, compute_entropy=self.compute_entropy) @@ -974,11 +951,8 @@ 
def _compute_loss_and_metrics(self, model, inputs):
             # fill the padded token with NaN
             entropies = entropies.masked_fill(completion_mask == 0, float('nan'))
             if self.args.log_entropy:
-                if self.template.padding_free:
-                    entropy_list = torch.split(entropies, lengths.tolist())
-                    per_completion_entropies_mean = torch.stack([torch.nanmean(e) for e in entropy_list])
-                else:
-                    per_completion_entropies_mean = torch.nanmean(entropies, dim=1)
+                # Unified: entropies are now always [batch_size, seq_len] after pad_back
+                per_completion_entropies_mean = torch.nanmean(entropies, dim=1)
                 global_per_completion_entropies_mean = gather(per_completion_entropies_mean)
                 entropy_metrics = {
                     'entropy_logs': global_per_completion_entropies_mean.tolist(),
@@ -998,11 +972,8 @@ def _compute_loss_and_metrics(self, model, inputs):
         if all(truncated_mask):
             logger.info('All completions are overlong and truncated, '
                         'resulting in NaN values for some metrics (e.g., KL)')
-            if self.template.padding_free:
-                truncated_mask = torch.repeat_interleave(truncated_mask, lengths).unsqueeze(0)
-                assert truncated_mask.shape == completion_mask.shape
-            else:
-                truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask)
+            # Unified: expand truncated_mask to match completion_mask [batch_size, seq_len]
+            truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask)
             completion_mask = completion_mask & (~truncated_mask)

         # Compute the KL divergence between the model and the reference model
@@ -1030,13 +1001,11 @@ def _compute_loss_and_metrics(self, model, inputs):
             vllm_log_ratio = per_token_logps - vllm_per_token_logps

             # Apply importance sampling correction based on mode
-            rollout_is_weights = self._apply_rollout_importance_sampling(
-                vllm_log_ratio, completion_mask, lengths if self.template.padding_free else None)
+            rollout_is_weights = self._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask, lengths=None)

             # Compute and log correction metrics
             rollout_correction_metrics = self._compute_rollout_correction_metrics(
-                per_token_logps, vllm_per_token_logps, rollout_is_weights, completion_mask,
-                lengths if self.template.padding_free else None)
+                per_token_logps, vllm_per_token_logps, rollout_is_weights, completion_mask, lengths=None)

             # Apply IS weights: multiply the final loss by the IS weight
             # Store for later application in loss computation
@@ -1048,28 +1017,15 @@ def _compute_loss_and_metrics(self, model, inputs):
         if self.importance_sampling_level == 'token':
             log_importance_weights = log_ratio
         elif self.importance_sampling_level in ['sequence', 'sequence_token']:
-            if self.template.padding_free:
-                # split to batch, compute seq-level normalization
-                log_ratio_list = torch.split(log_ratio.squeeze(0), lengths.tolist())
-                mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist())
-                seq_weights = [(lr * m).sum() / m.sum().clamp(min=1.0) for lr, m in zip(log_ratio_list, mask_list)]
-                seq_level_log_weights = torch.stack(seq_weights).to(log_ratio.dtype).unsqueeze(-1)
-                if self.importance_sampling_level == 'sequence':
-                    log_importance_weights = seq_level_log_weights
-                else:
-                    seq_level_log_weight = seq_level_log_weights.detach()
-                    seq_level_log_weight = torch.repeat_interleave(seq_level_log_weight, lengths).unsqueeze(0)
-                    log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight
+            # Unified: completion_mask is now always [batch_size, seq_len]
+            seq_level_log_weights = ((log_ratio * completion_mask).sum(-1)
+                                     / completion_mask.sum(-1).clamp(min=1.0)).unsqueeze(-1)
+            if 
self.importance_sampling_level == 'sequence': + log_importance_weights = seq_level_log_weights else: - seq_level_log_weights = ((log_ratio * completion_mask).sum(-1) - / completion_mask.sum(-1).clamp(min=1.0)).unsqueeze(-1) - if self.importance_sampling_level == 'sequence': - log_importance_weights = seq_level_log_weights - else: - # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)] - seq_level_log_weight = seq_level_log_weights.detach() - log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight - + # GSPO-token: sg[si(θ)] * πθ(yi,t)/sg[πθ(yi,t)] + seq_level_log_weight = seq_level_log_weights.detach() + log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight else: raise ValueError( f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " @@ -1079,28 +1035,16 @@ def _compute_loss_and_metrics(self, model, inputs): if self.loss_type == 'cispo': clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach() - if self.template.padding_free: - advantages = advantages[-coef_1.shape[1]:] - per_token_loss = -clamped_ratios * advantages.unsqueeze(0) * per_token_logps - else: - per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps + # Unified: coef_1 and advantages are now [batch_size, seq_len] and [batch_size] + per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo']: coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) if self.args.delta is not None: coef_1 = torch.clamp(coef_1, max=self.args.delta) - if self.template.padding_free: - if self.importance_sampling_level == 'sequence': - # Expand sequence-level weights to token-level - coef_1 = torch.repeat_interleave(coef_1.squeeze(-1), lengths).unsqueeze(0) - coef_2 = torch.repeat_interleave(coef_2.squeeze(-1), lengths).unsqueeze(0) - - advantages = advantages[-coef_1.shape[1]:] - per_token_loss1 = coef_1 * advantages.unsqueeze(0) - per_token_loss2 = coef_2 * advantages.unsqueeze(0) - else: - per_token_loss1 = coef_1 * advantages.unsqueeze(1) - per_token_loss2 = coef_2 * advantages.unsqueeze(1) + # Unified: coef_1 is [batch_size, seq_len], advantages is [batch_size] + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) per_token_loss = -torch.min(per_token_loss1, per_token_loss2) if entropy_mask is not None: per_token_loss = per_token_loss * entropy_mask @@ -1113,18 +1057,12 @@ def _compute_loss_and_metrics(self, model, inputs): per_token_loss = per_token_loss * rollout_is_weights if self.loss_type == 'grpo': - if self.template.padding_free: - loss_list = torch.split(per_token_loss.squeeze(0), lengths.tolist()) - mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) - sample_loss = [(loss * mask).sum() / mask.sum().clamp(min=1.0) - for loss, mask in zip(loss_list, mask_list)] - loss = torch.stack(sample_loss).mean() - else: - loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + # completion_mask is now always [batch_size, seq_len] after pad_back + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() elif self.loss_type == 'bnpo': loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) elif self.loss_type == 'dr_grpo': - batch_size = lengths.shape[0] if self.template.padding_free else inputs['input_ids'].shape[0] + batch_size 
= completion_mask.shape[0] # Unified: always batch_size from completion_mask loss = (per_token_loss * completion_mask).sum() / (batch_size * self.max_completion_length) elif self.loss_type in ['cispo', 'dapo']: # CISPO and DAPO: Normalize by total completion tokens across all processes @@ -1161,20 +1099,13 @@ def masked_batch_mean(x): # Compute the clipped probability ratios if self.loss_type == 'cispo': # CISPO: Only track upper bound clipping - if self.template.padding_free: - is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(0) > 0) - else: - is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_cispo_clipped = (coef_1 > self.epsilon_high) & (advantages.unsqueeze(1) > 0) cispo_clip_ratio = masked_batch_mean(is_cispo_clipped.float()) gathered_cispo_clip_ratio = self.accelerator.gather_for_metrics(cispo_clip_ratio) metrics_data['clipping'] = {'cispo_clip_ratio': gathered_cispo_clip_ratio.nanmean().item()} else: - if self.template.padding_free: - is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(0) < 0) - is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(0) > 0) - else: - is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) - is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) is_region_clipped = is_low_clipped | is_high_clipped low_clip = masked_batch_mean(is_low_clipped.float()) @@ -1443,8 +1374,17 @@ def _get_per_token_logps_and_entropies_single(self, compute_entropy=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: if self.template.sequence_parallel_size > 1: return self._get_per_token_logps_and_entropies_sp(model, inputs, compute_entropy=compute_entropy) + logits_to_keep = inputs['logits_to_keep'] input_ids = inputs['input_ids'] + is_padding_free = self.template.padding_free + + # Store metadata for padding_free restoration + if is_padding_free: + original_position_ids = inputs.get('text_position_ids') or inputs.get('position_ids') + original_seq_lengths = inputs.get('seq_lengths') + batch_size = original_seq_lengths.shape[0] + unwrapped_model = self.accelerator.unwrap_model(model) if is_peft_model(unwrapped_model): parameters = inspect.signature(unwrapped_model.base_model.model.forward).parameters @@ -1452,13 +1392,10 @@ def _get_per_token_logps_and_entropies_single(self, parameters = inspect.signature(unwrapped_model.forward).parameters use_local_entropy = not hasattr(super(), '_get_per_token_logps_and_entropies') and compute_entropy - can_use_super = (not self.is_multimodal and 'logits_to_keep' in parameters and not use_local_entropy) - if 'attention_mask' not in inputs: - # when set padding_free true, the attention_mask is not in inputs - can_use_super = False + can_use_super = (not self.is_multimodal and 'logits_to_keep' in parameters and not use_local_entropy + and not is_padding_free) if can_use_super: - # save memory if hasattr(super(), '_get_per_token_logps_and_entropies'): logps, entropies = super()._get_per_token_logps_and_entropies( model, input_ids, inputs['attention_mask'], logits_to_keep, compute_entropy=compute_entropy) @@ -1466,7 +1403,7 @@ def _get_per_token_logps_and_entropies_single(self, logps = super()._get_per_token_logps(model, input_ids, inputs['attention_mask'], logits_to_keep) entropies = None else: - inputs = { + model_inputs = { 
k: v for k, v in inputs.items() if k not in [ 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', @@ -1474,16 +1411,60 @@ def _get_per_token_logps_and_entropies_single(self, ] } if 'logits_to_keep' in self.model_kwarg_keys: - inputs['logits_to_keep'] = logits_to_keep + 1 - logits = model(**inputs).logits - # exclude the last logit: it corresponds to the next token pred - logits = logits[:, -(logits_to_keep + 1):-1, :] - logits = logits / self.temperature - input_ids = input_ids[:, -logits_to_keep:] - logps = selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens - entropies = None - if compute_entropy: - entropies = entropy_from_logits(logits) + model_inputs['logits_to_keep'] = logits_to_keep + 1 + + # Forward pass + logits = model(**model_inputs).logits + + # Extract relevant portion and apply temperature + logits = logits[:, -(logits_to_keep + 1):-1, :] / self.temperature + input_ids_for_logps = input_ids[:, -logits_to_keep:] + + # Compute on rmpad, then pad back + if is_padding_free: + # In padding_free mode, compute logps on flattened tensors + logits_rmpad = logits.squeeze(0) # [total_nnz, vocab_size] + input_ids_rmpad = input_ids_for_logps.squeeze(0) # [total_nnz] + + # Compute logps on rmpad tensors + per_token_logps_rmpad = selective_log_softmax(logits_rmpad, input_ids_rmpad) # [total_nnz] + + # Compute entropy if needed + if compute_entropy: + entropy_rmpad = entropy_from_logits(logits_rmpad) # [total_nnz] + else: + entropy_rmpad = None + + # Restore to batch shape + position_ids_for_restore = original_position_ids.squeeze()[-logits_to_keep:] + + logps, completion_mask = pad_logps_back_to_batch( + logps_rmpad=per_token_logps_rmpad.unsqueeze(0), # [1, total_nnz] + position_ids=position_ids_for_restore, + logits_to_keep=logits_to_keep, + batch_size=batch_size) + + # Also restore entropy if computed + if compute_entropy: + entropies, _ = pad_logps_back_to_batch( + logps_rmpad=entropy_rmpad.unsqueeze(0), + position_ids=position_ids_for_restore, + logits_to_keep=logits_to_keep, + batch_size=batch_size) + else: + entropies = None + + # Store the restored completion_mask back to inputs + # This eliminates the need to recompute it in loss functions + inputs['completion_mask'] = completion_mask + + else: + logps = selective_log_softmax(logits, input_ids_for_logps) + + if compute_entropy: + entropies = entropy_from_logits(logits) + else: + entropies = None return logps, entropies @@ -2109,53 +2090,32 @@ def single_sample_context(): with single_sample_context(): self.truncated_resample_iterator = cyclic_iter(self.get_train_dataloader()) - def _compute_sequence_level_ratios(self, - is_ratio: torch.Tensor, - completion_mask: torch.Tensor, - lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + def _compute_sequence_level_ratios(self, is_ratio: torch.Tensor, completion_mask: torch.Tensor) -> torch.Tensor: """ Helper function to compute sequence-level importance sampling ratios. 
Args: - is_ratio: Token-level IS ratios, shape [B, T] or [1, total_T] for padding_free - completion_mask: Boolean mask for completion tokens - lengths: Sequence lengths for padding_free mode + is_ratio: Token-level IS ratios, shape [B, T] + completion_mask: Boolean mask for completion tokens, shape [B, T] Returns: Sequence-level ratios as geometric mean of token-level ratios """ - if self.template.padding_free: - # Split by sequence lengths - ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) - mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) - - seq_ratios = [] - for ratio, mask in zip(ratio_list, mask_list): - # Geometric mean approximation: exp(mean(log(ratio))) - log_ratio = torch.log(ratio.clamp(min=1e-10)) - seq_ratio = torch.exp((log_ratio * mask).sum() / mask.sum().clamp(min=1.0)) - seq_ratios.append(seq_ratio) - - seq_ratios = torch.stack(seq_ratios) - else: - # Standard mode: [B, T] - log_ratio = torch.log(is_ratio.clamp(min=1e-10)) - seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - seq_ratios = torch.exp(seq_log_ratios) + # Unified: is_ratio and completion_mask are always [B, T] after pad_back + log_ratio = torch.log(is_ratio.clamp(min=1e-10)) + seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_ratios = torch.exp(seq_log_ratios) return seq_ratios - def _apply_rollout_importance_sampling(self, - vllm_log_ratio: torch.Tensor, - completion_mask: torch.Tensor, - lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + def _apply_rollout_importance_sampling(self, vllm_log_ratio: torch.Tensor, + completion_mask: torch.Tensor) -> torch.Tensor: """ Apply vLLM importance sampling correction using one of four modes. 
Args: - vllm_log_ratio: log(π_θ / π_vllm) per token, shape [B, T] or [1, total_T] for padding_free - completion_mask: Boolean mask for completion tokens - lengths: Sequence lengths for padding_free mode + vllm_log_ratio: log(π_θ / π_vllm) per token, shape [B, T] + completion_mask: Boolean mask for completion tokens, shape [B, T] Returns: IS weights to multiply with loss, same shape as vllm_log_ratio @@ -2176,25 +2136,19 @@ def _apply_rollout_importance_sampling(self, elif mode == 'sequence_truncate': # Sequence-level truncated IS: compute sequence-level ratio and clip - seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask, lengths) + seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask) clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) - # Broadcast back to tokens - if self.template.padding_free: - is_weights = torch.repeat_interleave(clipped_seq_ratios, lengths).unsqueeze(0) - else: - is_weights = clipped_seq_ratios.unsqueeze(-1).expand_as(is_ratio) + # Broadcast back to tokens (unified for both modes) + is_weights = clipped_seq_ratios.unsqueeze(-1).expand_as(is_ratio) elif mode == 'sequence_mask': # Sequence-level masked IS: mask entire sequences with ratio > threshold - seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask, lengths) + seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask) seq_mask = (seq_ratios <= threshold).float() - # Broadcast back to tokens - if self.template.padding_free: - is_weights = torch.repeat_interleave(seq_mask, lengths).unsqueeze(0) - else: - is_weights = seq_mask.unsqueeze(-1).expand_as(is_ratio) + # Broadcast back to tokens (unified for both modes) + is_weights = seq_mask.unsqueeze(-1).expand_as(is_ratio) else: raise ValueError(f'Unknown rollout importance sampling mode: {mode}') @@ -2206,17 +2160,15 @@ def _compute_rollout_correction_metrics( rollout_per_token_logps: torch.Tensor, is_weights: torch.Tensor, completion_mask: torch.Tensor, - lengths: Optional[torch.Tensor] = None, ) -> Dict[str, float]: """ Compute rollout correction metrics: KL, PPL, chi-square, ESS. Args: - per_token_logps: Log probs from policy model - rollout_per_token_logps: Log probs from rollout - is_weights: Importance sampling weights - completion_mask: Boolean mask for completion tokens - lengths: Sequence lengths for padding_free mode + per_token_logps: Log probs from policy model, shape [B, T] + rollout_per_token_logps: Log probs from rollout, shape [B, T] + is_weights: Importance sampling weights, shape [B, T] + completion_mask: Boolean mask for completion tokens, shape [B, T] Returns: Dictionary with metrics @@ -2233,12 +2185,8 @@ def _compute_rollout_correction_metrics( # Helper function for masked mean def masked_mean(x, mask): - if self.template.padding_free: - # x: [1, T], mask: [1, T] - return (x.squeeze(0) * mask.squeeze(0)).sum() / mask.squeeze(0).sum().clamp(min=1.0) - else: - # x: [B, T], mask: [B, T] - return (x * mask).sum() / mask.sum().clamp(min=1.0) + # x: [B, T], mask: [B, T] (after pad_back) + return (x * mask).sum() / mask.sum().clamp(min=1.0) # 1. KL divergence: KL(π_θ || π_rollout) ≈ E[log(π_θ/π_rollout)] kl_div = masked_mean(log_ratio, completion_mask) @@ -2253,28 +2201,16 @@ def masked_mean(x, mask): # 3. 
Effective Sample Size (ESS): 1 / E[(w/E[w])²] # ESS measures the "effective" number of independent samples after importance sampling correction # Higher ESS means better sample quality and more stable gradient estimates - # For sequence-level ESS, we compute per-sequence ratios - if self.template.padding_free: - ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) - mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) - - seq_ratios = [] - for ratio, mask in zip(ratio_list, mask_list): - log_r = torch.log(ratio.clamp(min=1e-10)) - seq_ratio = torch.exp((log_r * mask).sum() / mask.sum().clamp(min=1.0)) - seq_ratios.append(seq_ratio) - - seq_ratios = torch.stack(seq_ratios) - else: - log_r = torch.log(is_ratio.clamp(min=1e-10)) - seq_log_ratios = (log_r * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - seq_ratios = torch.exp(seq_log_ratios) + # For sequence-level ESS, we compute per-sequence ratios (unified) + log_r = torch.log(is_ratio.clamp(min=1e-10)) + seq_log_ratios = (log_r * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_ratios = torch.exp(seq_log_ratios) # ESS = 1 / E[(w/E[w])²] - measures effective number of independent samples # For distributed training, gather all seq_ratios across ranks first to compute global ESS all_seq_ratios = self.accelerator.gather_for_metrics(seq_ratios) - # Following verl implementation: normalize weights to mean=1, then compute ESS + # normalize weights to mean=1, then compute ESS mean_seq_ratio = all_seq_ratios.mean() seq_ratios_normalized = all_seq_ratios / (mean_seq_ratio + 1e-8) ess = 1.0 / (seq_ratios_normalized**2).mean().clamp(min=1e-10) diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 72c7ab5d29..55f6db6a4e 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -1124,3 +1124,72 @@ def get_even_process_data(trainer, global_data: List[T]) -> List[T]: end = start + base_size return global_data[start:end] + + +# ============================================================================ +# Padding-free utilities +# ============================================================================ + + +def pad_logps_back_to_batch(logps_rmpad: torch.Tensor, + position_ids: torch.Tensor, + logits_to_keep: int, + batch_size: int, + dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Restore padding-free logprobs back to [batch_size, seq_len] shape. 
+
+    - Input: logps in rmpad format [1, total_nnz]
+    - Output: logps in batch format [batch_size, max_seq_len]
+
+    Args:
+        logps_rmpad: [1, total_nnz] per-token log probabilities in padding_free format
+        position_ids: [1, total_nnz] position ids to determine sequence boundaries
+        logits_to_keep: number of tokens to keep per sequence
+        batch_size: number of sequences in the batch
+        dtype: optional dtype for output, defaults to logps_rmpad.dtype
+
+    Returns:
+        logps_padded: [batch_size, logits_to_keep] padded log probabilities
+        completion_mask: [batch_size, logits_to_keep] mask indicating valid positions
+    """
+    from swift.utils.torch_utils import get_cu_seqlens_from_position_ids as get_cu_seqlens
+
+    if dtype is None:
+        dtype = logps_rmpad.dtype
+
+    device = logps_rmpad.device
+
+    # Get cumulative sequence lengths using swift's existing implementation
+    cu_seqlens = get_cu_seqlens(position_ids)
+
+    # Adjust cu_seqlens for logits_to_keep if needed
+    total_length = cu_seqlens[-1].item()
+    if total_length > logits_to_keep:
+        # Adjust the first sequence length
+        adjustment = total_length - logits_to_keep
+        cu_seqlens = cu_seqlens - adjustment
+        cu_seqlens[0] = 0  # First element should always be 0
+
+    # Compute actual sequence lengths
+    seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+    max_seq_len = logits_to_keep  # All sequences will be padded to this length
+
+    # Initialize output tensors
+    logps_padded = torch.zeros(batch_size, max_seq_len, dtype=dtype, device=device)
+    completion_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.float32, device=device)
+
+    # Unflatten: assign each sequence's logps to the corresponding row
+    logps_flat = logps_rmpad.squeeze(0)  # [total_nnz]
+
+    for i in range(batch_size):
+        start_idx = cu_seqlens[i].item()
+        end_idx = cu_seqlens[i + 1].item()
+        seq_len = seq_lengths[i].item()
+
+        # Copy the sequence logps
+        logps_padded[i, :seq_len] = logps_flat[start_idx:end_idx]
+        # Set mask for valid positions
+        completion_mask[i, :seq_len] = 1.0
+
+    return logps_padded, completion_mask
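For reference, a minimal usage sketch of the helper added above (toy tensors, single process; the import path is the one this patch series uses, and the behavior shown is what the code above implies):

```python
import torch

from swift.trainers.rlhf_trainer.utils import pad_logps_back_to_batch

# Two packed sequences of lengths 3 and 2; position ids restart per sequence.
position_ids = torch.tensor([[0, 1, 2, 0, 1]])                # [1, total_nnz]
logps_rmpad = torch.tensor([[-0.1, -0.2, -0.3, -0.4, -0.5]])  # [1, total_nnz]

logps_padded, completion_mask = pad_logps_back_to_batch(
    logps_rmpad, position_ids, logits_to_keep=5, batch_size=2)

# Row 0 holds the first 3 logps, row 1 the remaining 2; the rest is padding.
assert logps_padded.shape == (2, 5)
assert completion_mask.sum().item() == 5.0
```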
From 811c14222764905312c71ddba16cfea4dd5fcd8c Mon Sep 17 00:00:00 2001
From: hjh0119 
Date: Mon, 24 Nov 2025 11:19:36 +0800
Subject: [PATCH 06/21] rm comments

---
 swift/trainers/rlhf_trainer/grpo_trainer.py | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 7c6f274b12..8967e367fa 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -418,7 +418,6 @@ def log_rewards_all(rewards_per_func: torch.Tensor):
                     old_per_token_logps = batch_encoded['old_per_token_logps']
                     ref_per_token_logps = batch_encoded['ref_per_token_logps']
                     completion_mask = batch_encoded['completion_mask']
-                    # Unified: logps and mask are now always [batch_size, seq_len] after pad_back
                     per_token_kl = old_per_token_logps - ref_per_token_logps
                     kl = (per_token_kl * completion_mask).sum(-1)
                     kl_list.append(kl)
@@ -951,7 +950,6 @@ def _compute_loss_and_metrics(self, model, inputs):
         # fill the padded token with NaN
         entropies = entropies.masked_fill(completion_mask == 0, float('nan'))
         if self.args.log_entropy:
-            # Unified: entropies are now always [batch_size, seq_len] after pad_back
             per_completion_entropies_mean = torch.nanmean(entropies, dim=1)
             global_per_completion_entropies_mean = gather(per_completion_entropies_mean)
             entropy_metrics = {
                 'entropy_logs': global_per_completion_entropies_mean.tolist(),
@@ -972,7 +970,6 @@ def _compute_loss_and_metrics(self, model, inputs):
         if all(truncated_mask):
             logger.info('All completions are overlong and truncated, '
                         'resulting in NaN values for some metrics (e.g., KL)')
-            # Unified: expand truncated_mask to match completion_mask [batch_size, seq_len]
             truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask)
             completion_mask = completion_mask & (~truncated_mask)

@@ -1017,7 +1014,6 @@ def _compute_loss_and_metrics(self, model, inputs):
         if self.importance_sampling_level == 'token':
             log_importance_weights = log_ratio
         elif self.importance_sampling_level in ['sequence', 'sequence_token']:
-            # Unified: completion_mask is now always [batch_size, seq_len]
             seq_level_log_weights = ((log_ratio * completion_mask).sum(-1)
                                      / completion_mask.sum(-1).clamp(min=1.0)).unsqueeze(-1)
             if self.importance_sampling_level == 'sequence':
@@ -1035,14 +1031,12 @@ def _compute_loss_and_metrics(self, model, inputs):

         if self.loss_type == 'cispo':
             clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
-            # Unified: coef_1 and advantages are now [batch_size, seq_len] and [batch_size]
             per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps
         elif self.loss_type in ['grpo', 'bnpo', 'dr_grpo', 'dapo']:
             coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
             if self.args.delta is not None:
                 coef_1 = torch.clamp(coef_1, max=self.args.delta)
-            # Unified: coef_1 is [batch_size, seq_len], advantages is [batch_size]
             per_token_loss1 = coef_1 * advantages.unsqueeze(1)
             per_token_loss2 = coef_2 * advantages.unsqueeze(1)
             per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
@@ -1062,7 +1056,7 @@ def _compute_loss_and_metrics(self, model, inputs):
         elif self.loss_type == 'bnpo':
             loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
         elif self.loss_type == 'dr_grpo':
-            batch_size = completion_mask.shape[0]  # Unified: always batch_size from completion_mask
+            batch_size = completion_mask.shape[0]
             loss = (per_token_loss * completion_mask).sum() / (batch_size * self.max_completion_length)
         elif self.loss_type in ['cispo', 'dapo']:
             # CISPO and DAPO: Normalize by total completion tokens across all processes
@@ -2101,7 +2095,6 @@ def _compute_sequence_level_ratios(self, is_ratio: torch.Tensor, completion_mask
         Returns:
             Sequence-level ratios as geometric mean of token-level ratios
         """
-        # Unified: is_ratio and completion_mask are always [B, T] after pad_back
         log_ratio = torch.log(is_ratio.clamp(min=1e-10))
         seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
         seq_ratios = torch.exp(seq_log_ratios)
@@ -2139,7 +2132,6 @@ def _apply_rollout_importance_sampling(self, vllm_log_ratio: torch.Tensor,
             seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask)
             clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold)

-            # Broadcast back to tokens (unified for both modes)
             is_weights = clipped_seq_ratios.unsqueeze(-1).expand_as(is_ratio)

@@ -2147,7 +2139,6 @@ def _apply_rollout_importance_sampling(self, vllm_log_ratio: torch.Tensor,
             seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask)
             seq_mask = (seq_ratios <= threshold).float()

-            # Broadcast back to tokens (unified for both modes)
             is_weights = seq_mask.unsqueeze(-1).expand_as(is_ratio)
         else:
             raise ValueError(f'Unknown rollout importance sampling mode: {mode}')
@@ -2201,7 +2192,7 @@ def masked_mean(x, mask):
         # 3. Effective Sample Size (ESS): 1 / E[(w/E[w])²]
         # ESS measures the "effective" number of independent samples after importance sampling correction
         # Higher ESS means better sample quality and more stable gradient estimates
-        # For sequence-level ESS, we compute per-sequence ratios (unified)
+        # For sequence-level ESS, we compute per-sequence ratios
         log_r = torch.log(is_ratio.clamp(min=1e-10))
         seq_log_ratios = (log_r * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
         seq_ratios = torch.exp(seq_log_ratios)
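The per-sequence ratio used here is the same geometric mean that `_compute_sequence_level_ratios` returns; a toy check with invented values (not part of the patch):

```python
import torch

# The sequence-level IS ratio is the geometric mean of token-level ratios,
# computed in log space for numerical stability.
is_ratio = torch.tensor([[1.1, 0.9, 1.2]])
completion_mask = torch.ones_like(is_ratio)

log_r = torch.log(is_ratio.clamp(min=1e-10))
seq_ratio = torch.exp((log_r * completion_mask).sum(-1) / completion_mask.sum(-1))

expected = (1.1 * 0.9 * 1.2) ** (1.0 / 3.0)  # geometric mean of the 3 ratios
assert torch.allclose(seq_ratio, torch.tensor([expected]))
```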
position_ids=position_ids_for_restore, logits_to_keep=logits_to_keep, - batch_size=batch_size) + batch_size=batch_size, + seq_lengths=original_seq_lengths) # Also restore entropy if computed if compute_entropy: entropies, _ = pad_logps_back_to_batch( logps_rmpad=entropy_rmpad.unsqueeze(0), - position_ids=position_ids_for_restore, logits_to_keep=logits_to_keep, - batch_size=batch_size) + batch_size=batch_size, + seq_lengths=original_seq_lengths) else: entropies = None diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 55f6db6a4e..c43b31152c 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -1132,9 +1132,10 @@ def get_even_process_data(trainer, global_data: List[T]) -> List[T]: def pad_logps_back_to_batch(logps_rmpad: torch.Tensor, - position_ids: torch.Tensor, - logits_to_keep: int, - batch_size: int, + position_ids: Optional[torch.Tensor] = None, + logits_to_keep: int = None, + batch_size: int = None, + seq_lengths: Optional[torch.Tensor] = None, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Restore padding-free logprobs back to [batch_size, seq_len] shape. @@ -1144,35 +1145,48 @@ def pad_logps_back_to_batch(logps_rmpad: torch.Tensor, Args: logps_rmpad: [1, total_nnz] per-token log probabilities in padding_free format - position_ids: [1, total_nnz] position ids to determine sequence boundaries + position_ids: [1, total_nnz] position ids to determine sequence boundaries (deprecated, use seq_lengths) logits_to_keep: number of tokens to keep per sequence batch_size: number of sequences in the batch + seq_lengths: [batch_size] actual sequence lengths (preferred over position_ids) dtype: optional dtype for output, defaults to logps_rmpad.dtype Returns: logps_padded: [batch_size, logits_to_keep] padded log probabilities completion_mask: [batch_size, logits_to_keep] mask indicating valid positions """ - from swift.utils.torch_utils import get_cu_seqlens_from_position_ids as get_cu_seqlens - if dtype is None: dtype = logps_rmpad.dtype device = logps_rmpad.device - # Get cumulative sequence lengths using swift's existing implementation - cu_seqlens = get_cu_seqlens(position_ids) - - # Adjust cu_seqlens for logits_to_keep if needed - total_length = cu_seqlens[-1].item() - if total_length > logits_to_keep: - # Adjust the first sequence length - adjustment = total_length - logits_to_keep - cu_seqlens = cu_seqlens - adjustment - cu_seqlens[0] = 0 # First element should always be 0 - - # Compute actual sequence lengths - seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1] + # Determine sequence lengths + if seq_lengths is not None: + # Use provided seq_lengths directly + total_length = seq_lengths.sum().item() + if total_length > logits_to_keep: + # Adjust the first sequence length to account for truncation + adjustment = total_length - logits_to_keep + seq_lengths = seq_lengths.clone() + seq_lengths[0] = seq_lengths[0] - adjustment + else: + # Fallback: infer from position_ids + from swift.utils.torch_utils import get_cu_seqlens_from_position_ids as get_cu_seqlens + cu_seqlens = get_cu_seqlens(position_ids) + + # Adjust cu_seqlens for logits_to_keep if needed + total_length = cu_seqlens[-1].item() + if total_length > logits_to_keep: + # Adjust the first sequence length + adjustment = total_length - logits_to_keep + cu_seqlens = cu_seqlens - adjustment + cu_seqlens[0] = 0 # First element should always be 0 + + # Compute actual sequence lengths + seq_lengths = cu_seqlens[1:] - 
+        seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+
+    # Compute cumulative sequence lengths
+    cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], device=device), seq_lengths]), dim=0)
     max_seq_len = logits_to_keep  # All sequences will be padded to this length

     # Initialize output tensors
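Passing `seq_lengths` directly avoids re-deriving boundaries from position ids; a small sanity sketch of the equivalence (values invented, not part of the patch):

```python
import torch

# Rebuilding cu_seqlens from seq_lengths, as the updated helper does, matches
# the boundaries that packed position ids encode for the same batch.
seq_lengths = torch.tensor([3, 2])
cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0]), seq_lengths]), dim=0)
assert cu_seqlens.tolist() == [0, 3, 5]  # boundaries of the two sequences

# The equivalent packed position ids restart at 0 for every sequence:
position_ids = torch.tensor([[0, 1, 2, 0, 1]])  # boundaries inferred at 3 and 5
```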
From d535fc58cad71d944d26b16c4dd58925b9aab4bb Mon Sep 17 00:00:00 2001
From: hjh0119 
Date: Mon, 24 Nov 2025 14:25:15 +0800
Subject: [PATCH 08/21] fix

---
 swift/trainers/rlhf_trainer/grpo_trainer.py | 49 ++++++++++++++++-----
 swift/trainers/rlhf_trainer/utils.py        | 14 +++++-
 2 files changed, 49 insertions(+), 14 deletions(-)

diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 50c2d786ed..eb6f1c4ab5 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -830,18 +830,43 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]:

                 # Convert to tensor if all samples have vllm_logprobs
                 if all(lp is not None for lp in vllm_logprobs_list):
-                    max_len = logits_to_keep
-                    padded_logprobs = []
-                    for lp in vllm_logprobs_list:
-                        # Take last logits_to_keep tokens
-                        lp_tensor = lp[-logits_to_keep:] if len(lp) >= logits_to_keep else lp
-                        # Pad if needed - use a very small negative value to avoid affecting ratio computation
-                        # These padded positions will be masked by completion_mask
-                        if len(lp_tensor) < max_len:
-                            lp_tensor = [-1e10] * (max_len - len(lp_tensor)) + lp_tensor
-                        padded_logprobs.append(lp_tensor)
-                    batch_encoded_inputs['vllm_per_token_logps'] = torch.tensor(
-                        padded_logprobs, dtype=torch.float32, device=self.accelerator.device)
+                    if self.template.padding_free:
+                        # In padding_free mode, use pad_logps_back_to_batch for consistency
+                        # Concatenate all vllm logprobs into a flat tensor
+                        vllm_logprobs_flat = []
+                        for lp in vllm_logprobs_list:
+                            # Take last logits_to_keep tokens (or all if shorter)
+                            lp_to_use = lp[-logits_to_keep:] if len(lp) >= logits_to_keep else lp
+                            vllm_logprobs_flat.extend(lp_to_use)
+
+                        # Convert to tensor [1, total_nnz]
+                        vllm_logprobs_rmpad = torch.tensor(
+                            vllm_logprobs_flat, dtype=torch.float32, device=self.accelerator.device).unsqueeze(0)
+
+                        # Restore to batch format using the same seq_lengths
+                        from swift.trainers.rlhf_trainer.utils import pad_logps_back_to_batch
+                        seq_lengths = batch_encoded_inputs['seq_lengths']
+                        batch_size = seq_lengths.shape[0]
+                        vllm_logps_padded, _ = pad_logps_back_to_batch(
+                            logps_rmpad=vllm_logprobs_rmpad,
+                            logits_to_keep=logits_to_keep,
+                            batch_size=batch_size,
+                            seq_lengths=seq_lengths,
+                            dtype=torch.float32)
+                        batch_encoded_inputs['vllm_per_token_logps'] = vllm_logps_padded
+                    else:
+                        # Standard mode: simple padding
+                        max_len = logits_to_keep
+                        padded_logprobs = []
+                        for lp in vllm_logprobs_list:
+                            # Take last logits_to_keep tokens
+                            lp_tensor = lp[-logits_to_keep:] if len(lp) >= logits_to_keep else lp
+                            # Pad if needed
+                            if len(lp_tensor) < max_len:
+                                lp_tensor = [-1e10] * (max_len - len(lp_tensor)) + lp_tensor
+                            padded_logprobs.append(lp_tensor)
+                        batch_encoded_inputs['vllm_per_token_logps'] = torch.tensor(
+                            padded_logprobs, dtype=torch.float32, device=self.accelerator.device)
                 else:
                     batch_encoded_inputs['vllm_per_token_logps'] = None
             else:
diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py
index c43b31152c..0eb45dd598 100644
--- a/swift/trainers/rlhf_trainer/utils.py
+++ b/swift/trainers/rlhf_trainer/utils.py
@@ -1201,8 +1201,18 @@ def pad_logps_back_to_batch(logps_rmpad: torch.Tensor,
         end_idx = cu_seqlens[i + 1].item()
         seq_len = seq_lengths[i].item()

-        # Copy the sequence logps
-        logps_padded[i, :seq_len] = logps_flat[start_idx:end_idx]
+        actual_end_idx = min(end_idx, len(logps_flat))
+        actual_len = actual_end_idx - start_idx
+
+        if actual_len < seq_len:
+            # pad at the beginning
+            pad_len = seq_len - actual_len
+            logps_padded[i, :pad_len] = -1e10  # Padding value
+            logps_padded[i, pad_len:seq_len] = logps_flat[start_idx:actual_end_idx]
+        else:
+            # Normal case
+            logps_padded[i, :seq_len] = logps_flat[start_idx:end_idx]
+
         # Set mask for valid positions
         completion_mask[i, :seq_len] = 1.0
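A toy illustration (invented values, not part of the patch) of the short-buffer branch added above: when only `actual_len` of the expected `seq_len` logps are available, the row is front-padded with the `-1e10` sentinel so the valid values stay right-aligned within their slot:

```python
import torch

seq_len = 4
available = torch.tensor([-0.4, -0.5])  # actual_len == 2 of the 4 expected
row = torch.full((seq_len,), -1e10)     # sentinel for missing positions
row[seq_len - available.numel():] = available
# row -> [-1e10, -1e10, -0.4, -0.5]
```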
From ecff28c9b3123c97f09a978ee82f8292dbfb6f35 Mon Sep 17 00:00:00 2001
From: hjh0119 
Date: Mon, 24 Nov 2025 16:13:36 +0800
Subject: [PATCH 09/21] always old_logps?

---
 swift/trainers/rlhf_trainer/grpo_trainer.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index eb6f1c4ab5..599cac60e6 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -805,9 +805,7 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]:

             with torch.no_grad():
                 batch_encoded_inputs['old_per_token_logps'] = (
-                    self._get_per_token_logps_and_entropies(self.model, batch_encoded_inputs)[0]
-                    if self.old_policy() or self.kl_in_reward or
-                    (self.use_vllm and self.rollout_importance_sampling_mode is not None) else None)
+                    self._get_per_token_logps_and_entropies(self.model, batch_encoded_inputs)[0])
                 if self.beta == 0.0:
                     ref_per_token_logps = None
                 elif self.ref_model is not None:
@@ -1340,7 +1338,7 @@ def _get_per_token_logps_and_entropies_sp(
             k: v
             for k, v in inputs.items() if k not in [
                 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
-                'truncated_mask', 'seq_lengths', 'num_items_in_batch'
+                'truncated_mask', 'seq_lengths', 'num_items_in_batch', 'vllm_per_token_logps'
             ]
         }
         sequence_parallel.prepare_inputs(inputs)
@@ -1425,7 +1423,7 @@ def _get_per_token_logps_and_entropies_single(self,
             k: v
             for k, v in inputs.items() if k not in [
                 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
-                'truncated_mask', 'seq_lengths', 'num_items_in_batch'
+                'truncated_mask', 'seq_lengths', 'num_items_in_batch', 'vllm_per_token_logps'
             ]
         }
         if 'logits_to_keep' in self.model_kwarg_keys:
@@ -1561,7 +1559,7 @@ def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep):
             k: v
             for k, v in inputs.items() if k not in [
                 'logits_to_keep', 'completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps',
-                'truncated_mask', 'seq_lengths', 'num_items_in_batch'
+                'truncated_mask', 'seq_lengths', 'num_items_in_batch', 'vllm_per_token_logps'
             ]
         }
         if 'logits_to_keep' in self.model_kwarg_keys:

From b668fb326cb4661709a5b8bfcba51ff48feeec26 Mon Sep 17 00:00:00 2001
From: hjh0119 
Date: Mon, 24 Nov 2025 19:29:53 +0800
Subject: [PATCH 10/21] fix

---
 .../GRPO/AdvancedResearch/index.rst           |   1 +
 .../training_inference_mismatch.md            | 177 ++++++++++++++++++
 .../grpo/plugin/run_external_reward_func.sh   |   2 +-
 swift/trainers/rlhf_trainer/grpo_trainer.py   |  22 ++-
 4 files changed, 191 insertions(+), 11 deletions(-)
 create mode 100644 docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md

diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/index.rst b/docs/source/Instruction/GRPO/AdvancedResearch/index.rst
index 9f833e1025..73eb9e4966 100644
--- a/docs/source/Instruction/GRPO/AdvancedResearch/index.rst
+++ b/docs/source/Instruction/GRPO/AdvancedResearch/index.rst
@@ -11,3 +11,4 @@ Advanced Research
    REINFORCEPP.md
    CHORD.md
    CISPO.md
+   training_inference_mismatch.md
diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md b/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md
new file mode 100644
index 0000000000..806279b3fb
--- /dev/null
+++ b/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md
@@ -0,0 +1,177 @@
+# Training-Inference-Mismatch
+
+**Version requirement**: ms-swift>=3.11
+
+**TL;DR**: GRPO uses vLLM to accelerate sampling, but this also introduces a training-inference mismatch that can hurt training stability. This document explains the background of the problem, its causes, and the corresponding solutions.
+
+## Background
+
+### The Basic Assumption of GRPO
+
+The training objective of GRPO (Group Relative Policy Optimization) can be written as:
+
+$$
+\mathcal{L}_{\text{GRPO}} = - \mathbb{E}_{y \sim \pi_\theta} \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t \right) \right]
+$$
+
+where:
+- $r_t(\theta) = \frac{\pi_\theta(y_t|x, y_{ torch.Tensor:
         """
         Apply vLLM importance sampling correction using one of four modes.

         Args:
-            vllm_log_ratio: log(π_θ / π_vllm) per token, shape [B, T]
+            rollout_log_ratio: log(π_θ / π_rollout) per token, shape [B, T]
             completion_mask: Boolean mask for completion tokens, shape [B, T]

         Returns:
-            IS weights to multiply with loss, same shape as vllm_log_ratio
+            IS weights to multiply with loss, same shape as rollout_log_ratio
         """
         mode = self.rollout_importance_sampling_mode
         threshold = self.rollout_importance_sampling_threshold

         # Compute importance sampling ratios: exp(log_ratio)
-        is_ratio = torch.exp(vllm_log_ratio)
+        is_ratio = torch.exp(rollout_log_ratio)

         if mode == 'token_truncate':
             # Token-level truncated IS: clip ratios from above at threshold
@@ -2159,9 +2160,10 @@ def _apply_rollout_importance_sampling(self, rollout_log_ratio: torch.Tensor,
             seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask)
             seq_mask = (seq_ratios <= threshold).float()

-            is_weights = seq_mask.unsqueeze(-1).expand_as(is_ratio)
+            # Apply mask to original token-level ratios
+            is_weights = is_ratio * seq_mask.unsqueeze(-1)
         else:
-            raise ValueError(f'Unknown rollout importance sampling mode: {mode}')
+            return is_ratio

         return is_weights
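A compact sketch of the four `rollout_importance_sampling_mode` options as they stand after this commit (threshold and tensors invented; the trainer methods above are authoritative, and the `token_mask` branch, not shown in this excerpt, is assumed to zero out over-threshold tokens):

```python
import torch

log_ratio = torch.tensor([[0.2, 1.5, -0.1], [0.4, 0.3, 0.9]])  # log(pi_theta / pi_rollout)
mask = torch.ones_like(log_ratio)
C = 2.0  # rollout_importance_sampling_threshold

ratio = torch.exp(log_ratio)                   # token-level IS ratios
token_truncate = ratio.clamp(max=C)            # clip each token's ratio at C
token_mask = ratio * (ratio <= C).float()      # assumed masked-IS form

seq_ratio = torch.exp((log_ratio * mask).sum(-1) / mask.sum(-1))  # geometric mean
sequence_truncate = seq_ratio.clamp(max=C).unsqueeze(-1).expand_as(ratio)
sequence_mask = ratio * (seq_ratio <= C).float().unsqueeze(-1)    # zero whole sequences
```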
From 14d2e1177a586af7d33cfa8e35a9ac5d514da904 Mon Sep 17 00:00:00 2001
From: hjh0119 
Date: Tue, 25 Nov 2025 00:21:29 +0800
Subject: [PATCH 11/21] fix wip

---
 .../grpo/plugin/run_external_reward_func.sh |   2 +-
 swift/trainers/rlhf_trainer/grpo_trainer.py | 135 +++++++++++-------
 2 files changed, 85 insertions(+), 52 deletions(-)

diff --git a/examples/train/grpo/plugin/run_external_reward_func.sh b/examples/train/grpo/plugin/run_external_reward_func.sh
index 29a203b6b0..91ae09d5bc 100644
--- a/examples/train/grpo/plugin/run_external_reward_func.sh
+++ b/examples/train/grpo/plugin/run_external_reward_func.sh
@@ -8,7 +8,7 @@ swift rlhf \
     --rlhf_type grpo \
     --model Qwen/Qwen2.5-7B-Instruct \
     --external_plugins examples/train/grpo/plugin/plugin.py \
-    --reward_funcs format \
+    --reward_funcs external_math_acc external_math_format \
     --train_type lora \
     --lora_rank 8 \
     --lora_alpha 32 \
diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 120f1d6ac1..e64c6e6803 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -1012,25 +1012,33 @@ def _compute_loss_and_metrics(self, model, inputs):
         old_per_token_logps = (
             per_token_logps.detach() if inputs['old_per_token_logps'] is None else inputs['old_per_token_logps'])

-        # Apply vLLM importance sampling correction if enabled
+        # Compute rollout diagnostic metrics and apply IS correction if enabled
         rollout_correction_metrics = {}
         if inputs.get('vllm_per_token_logps') is not None:
             vllm_per_token_logps = inputs['vllm_per_token_logps']
-            # Compute the log ratio between policy model and vLLM rollout model
-            # log π_θ(y|x) - log π_vllm(y|x)
-            vllm_log_ratio = old_per_token_logps - vllm_per_token_logps

-            # Apply importance sampling correction based on mode
-            rollout_is_weights = self._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+            # Always compute diagnostic metrics (KL, PPL, etc.) for monitoring off-policy gap
+            # This helps diagnose whether rollout correction is needed
+            rollout_correction_metrics = self._compute_rollout_offpolicy_metrics(old_per_token_logps,
+                                                                                 vllm_per_token_logps, completion_mask)

-            # Compute and log correction metrics
-            rollout_correction_metrics = self._compute_rollout_correction_metrics(old_per_token_logps,
-                                                                                  vllm_per_token_logps,
-                                                                                  rollout_is_weights, completion_mask)
+            # Apply importance sampling correction if mode is enabled
+            if self.rollout_importance_sampling_mode is not None:
+                # Compute the log ratio between policy model and vLLM rollout model
+                # log π_θ(y|x) - log π_vllm(y|x)
+                vllm_log_ratio = old_per_token_logps - vllm_per_token_logps

-            # Apply IS weights: multiply the final loss by the IS weight
-            # Store for later application in loss computation
-            inputs['rollout_is_weights'] = rollout_is_weights
+                # Apply importance sampling correction based on mode
+                rollout_is_weights = self._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask)
+
+                # Compute additional IS-specific metrics (ESS, clipped_frac, is_weight_mean)
+                is_metrics = self._compute_is_correction_metrics(vllm_log_ratio, rollout_is_weights, completion_mask)
+                rollout_correction_metrics.update(is_metrics)
+
+                # Store IS weights for loss computation
+                inputs['rollout_is_weights'] = rollout_is_weights
+            else:
+                inputs['rollout_is_weights'] = None
         else:
             inputs['rollout_is_weights'] = None

@@ -2167,93 +2175,118 @@ def _apply_rollout_importance_sampling(self, rollout_log_ratio: torch.Tensor,

         return is_weights

-    def _compute_rollout_correction_metrics(
+    def _compute_rollout_offpolicy_metrics(
         self,
         per_token_logps: torch.Tensor,
         rollout_per_token_logps: torch.Tensor,
-        is_weights: torch.Tensor,
         completion_mask: torch.Tensor,
     ) -> Dict[str, float]:
         """
-        Compute rollout correction metrics: KL, PPL, chi-square, ESS.
+        Compute off-policy diagnostic metrics (always computed for monitoring).
+
+        These metrics help diagnose the off-policy gap between rollout and training policies,
+        which can arise from policy mismatch (e.g., vLLM BF16 vs FSDP FP32), model staleness,
+        or general distribution shifts.

         Args:
             per_token_logps: Log probs from policy model, shape [B, T]
             rollout_per_token_logps: Log probs from rollout, shape [B, T]
-            is_weights: Importance sampling weights, shape [B, T]
             completion_mask: Boolean mask for completion tokens, shape [B, T]

         Returns:
-            Dictionary with metrics
+            Dictionary with metrics: kl, ppl_policy, ppl_rollout, log_ppl_diff
         """
         metrics = {}

-        # Compute log ratios
-        # Keep original log_ratio for KL computation (accurate)
-        # Use clamped version for exponential operations (numerically stable)
-        SAFETY_BOUND = 20.0
-        log_ratio = per_token_logps - rollout_per_token_logps
-        log_ratio_safe = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)
-        is_ratio = torch.exp(log_ratio_safe)
-
         # Helper function for masked mean
         def masked_mean(x, mask):
-            # x: [B, T], mask: [B, T] (after pad_back)
             return (x * mask).sum() / mask.sum().clamp(min=1.0)

-        # 1. KL divergence: KL(π_θ || π_rollout) ≈ E[log(π_θ/π_rollout)]
+        # 1. Training policy perplexity (always computed)
+        mean_logps = (per_token_logps * completion_mask).sum(1) / completion_mask.sum(1)
+        policy_ppl = torch.exp(-mean_logps).mean()
+        metrics['ppl_policy'] = self.accelerator.gather_for_metrics(policy_ppl).nanmean().item()
+
+        # 2. Rollout off-policy metrics
+        # KL divergence: KL(π_policy || π_rollout) ≈ E[log(π_policy/π_rollout)]
+        log_ratio = per_token_logps - rollout_per_token_logps
         kl_div = masked_mean(log_ratio, completion_mask)
         metrics['kl_rollout'] = self.accelerator.gather_for_metrics(kl_div).nanmean().item()

-        # 2. Perplexity: exp(-mean_log_prob)
-        rollout_ppl = torch.exp(-masked_mean(rollout_per_token_logps, completion_mask))
-        policy_ppl = torch.exp(-masked_mean(per_token_logps, completion_mask))
+        # Rollout policy perplexity
+        mean_rollout_logps = (rollout_per_token_logps * completion_mask).sum(1) / completion_mask.sum(1)
+        rollout_ppl = torch.exp(-mean_rollout_logps)
         metrics['ppl_rollout'] = self.accelerator.gather_for_metrics(rollout_ppl).nanmean().item()
-        metrics['ppl_policy'] = self.accelerator.gather_for_metrics(policy_ppl).nanmean().item()

-        # 3. Effective Sample Size (ESS): 1 / E[(w/E[w])²]
-        # ESS measures the "effective" number of independent samples after importance sampling correction
-        # Higher ESS means better sample quality and more stable gradient estimates
-        # For sequence-level ESS, we compute per-sequence ratios
+        # Log PPL difference (for easier monitoring of distribution drift)
+        mean_log_prob_policy = masked_mean(per_token_logps, completion_mask)
+        mean_log_prob_rollout = masked_mean(rollout_per_token_logps, completion_mask)
+        log_ppl_diff = -mean_log_prob_rollout - (-mean_log_prob_policy)  # log(ppl_rollout) - log(ppl_policy)
+        metrics['log_ppl_diff'] = self.accelerator.gather_for_metrics(log_ppl_diff).nanmean().item()
+
+        return metrics
+
+    def _compute_is_correction_metrics(
+        self,
+        vllm_log_ratio: torch.Tensor,
+        is_weights: torch.Tensor,
+        completion_mask: torch.Tensor,
+    ) -> Dict[str, float]:
+        """
+        Compute importance sampling correction metrics (ESS, clipped_frac, is_weight_mean).
+        Only called when rollout_importance_sampling_mode is enabled.
+ + Args: + vllm_log_ratio: Log ratio log(π_policy / π_rollout), shape [B, T] + is_weights: Importance sampling weights after correction, shape [B, T] + completion_mask: Boolean mask for completion tokens, shape [B, T] + + Returns: + Dictionary with IS-specific metrics + """ + metrics = {} + SAFETY_BOUND = 20.0 + + # Helper function for masked mean + def masked_mean(x, mask): + return (x * mask).sum() / mask.sum().clamp(min=1.0) + + # Compute IS ratio with safety bounds + log_ratio_safe = torch.clamp(vllm_log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) + is_ratio = torch.exp(log_ratio_safe) + + # Compute sequence-level ratios for ESS and clipped_frac log_r = torch.log(is_ratio.clamp(min=1e-10)) seq_log_ratios = (log_r * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) seq_ratios = torch.exp(seq_log_ratios) # ESS = 1 / E[(w/E[w])²] - measures effective number of independent samples - # For distributed training, gather all seq_ratios across ranks first to compute global ESS + # For distributed training, gather all seq_ratios across ranks first all_seq_ratios = self.accelerator.gather_for_metrics(seq_ratios) - - # normalize weights to mean=1, then compute ESS mean_seq_ratio = all_seq_ratios.mean() seq_ratios_normalized = all_seq_ratios / (mean_seq_ratio + 1e-8) ess = 1.0 / (seq_ratios_normalized**2).mean().clamp(min=1e-10) - # ESS is already normalized (ranges from ~0 to N where N is batch size) - # Divide by batch size to get relative ESS in [0, 1] + # Normalize by batch size to get relative ESS in [0, 1] num_sequences = max(len(all_seq_ratios), 1) ess_normalized = ess / num_sequences metrics['ess'] = ess_normalized.item() - # 4. IS weight statistics + # 2. IS weight statistics mean_is_weight = masked_mean(is_weights, completion_mask) metrics['is_weight_mean'] = self.accelerator.gather_for_metrics(mean_is_weight).nanmean().item() - # Fraction of clipped/masked samples + # 3. 
Fraction of clipped/masked samples
+        threshold = self.rollout_importance_sampling_threshold
         if self.rollout_importance_sampling_mode in ['token_truncate', 'token_mask']:
             # Token-level
-            threshold = self.rollout_importance_sampling_threshold
             if self.rollout_importance_sampling_mode == 'token_truncate':
                 clipped_frac = masked_mean((is_ratio > threshold).float(), completion_mask)
             else:  # token_mask
                 clipped_frac = masked_mean((is_weights == 0).float(), completion_mask)
             metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item()
         else:
-            # Sequence-level
-            threshold = self.rollout_importance_sampling_threshold
-            if self.rollout_importance_sampling_mode == 'sequence_truncate':
-                clipped_frac = (seq_ratios > threshold).float().mean()
-            else:  # sequence_mask
-                # Check which sequences are masked (ratio > threshold)
-                clipped_frac = (seq_ratios > threshold).float().mean()
+            # Sequence-level (both truncate and mask)
+            clipped_frac = (seq_ratios > threshold).float().mean()
             metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item()

         return metrics
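A toy sketch (invented values, not part of the patch) of the diagnostics these two methods report — rollout KL, the two perplexities, and the batch-normalized effective sample size:

```python
import torch

logp_policy = torch.tensor([[-0.5, -0.7], [-1.0, -0.2]])
logp_rollout = torch.tensor([[-0.6, -0.6], [-0.9, -0.4]])
mask = torch.ones_like(logp_policy)

kl = ((logp_policy - logp_rollout) * mask).sum() / mask.sum()  # E[log ratio]
ppl_policy = torch.exp(-(logp_policy * mask).sum(1) / mask.sum(1)).mean()
ppl_rollout = torch.exp(-(logp_rollout * mask).sum(1) / mask.sum(1)).mean()

seq_log_ratio = ((logp_policy - logp_rollout) * mask).sum(1) / mask.sum(1)
w = torch.exp(seq_log_ratio)               # per-sequence IS weights
w = w / w.mean()                           # normalize to mean 1
ess = (1.0 / (w ** 2).mean()) / w.numel()  # relative ESS in (0, 1]
```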
From c3c997a8d51567b6f564b04217fb5b0f61488e7b Mon Sep 17 00:00:00 2001
From: hjh0119 
Date: Tue, 25 Nov 2025 14:16:49 +0800
Subject: [PATCH 12/21] fix rollout metrics under padding_free

---
 swift/trainers/rlhf_trainer/grpo_trainer.py | 81 ++++++++++++++-------
 swift/trainers/rlhf_trainer/utils.py        | 43 ++++++-----
 2 files changed, 80 insertions(+), 44 deletions(-)

diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index e64c6e6803..e573cccd0f 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -829,21 +829,30 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]:
                 # Convert to tensor if all samples have vllm_logprobs
                 if all(lp is not None for lp in vllm_logprobs_list):
                     if self.template.padding_free:
-                        # In padding_free mode, use pad_logps_back_to_batch for consistency
-                        # Concatenate all vllm logprobs into a flat tensor
-                        vllm_logprobs_flat = []
-                        for lp in vllm_logprobs_list:
-                            # Take last logits_to_keep tokens (or all if shorter)
-                            lp_to_use = lp[-logits_to_keep:] if len(lp) >= logits_to_keep else lp
-                            vllm_logprobs_flat.extend(lp_to_use)
-
-                        # Convert to tensor [1, total_nnz]
+                        # In padding_free mode, seq_lengths includes prompts for sequences after the first one
+                        # (because logits_to_keep only removes the first prompt).
+                        # But vllm_logprobs only contains completion logprobs for each sequence.
+                        # So we need to use the actual vllm_logprobs lengths, not seq_lengths.
+                        #
+                        # We pad each sequence's vllm_logprobs to match seq_lengths[i] by prepending -1e10
+                        # for the prompt portion that vLLM doesn't have logprobs for.
+                        seq_lengths = batch_encoded_inputs['seq_lengths']
+                        vllm_logprobs_aligned = []
+                        for i, lp in enumerate(vllm_logprobs_list):
+                            target_len = seq_lengths[i].item()
+                            vllm_len = len(lp)
+                            if vllm_len >= target_len:
+                                # vLLM has more tokens than expected, take last target_len
+                                vllm_logprobs_aligned.extend(lp[-target_len:])
+                            else:
+                                # vLLM has fewer tokens (only completion), pad the front (prompt portion)
+                                pad_len = target_len - vllm_len
+                                vllm_logprobs_aligned.extend([-1e10] * pad_len + list(lp))
+
+                        # Convert to tensor and pad to batch format
                         vllm_logprobs_rmpad = torch.tensor(
-                            vllm_logprobs_flat, dtype=torch.float32, device=self.accelerator.device).unsqueeze(0)
+                            vllm_logprobs_aligned, dtype=torch.float32, device=self.accelerator.device).unsqueeze(0)

-                        # Restore to batch format using the same seq_lengths
-                        from swift.trainers.rlhf_trainer.utils import pad_logps_back_to_batch
-                        seq_lengths = batch_encoded_inputs['seq_lengths']
                         batch_size = seq_lengths.shape[0]
                         vllm_logps_padded, _ = pad_logps_back_to_batch(
                             logps_rmpad=vllm_logprobs_rmpad,
@@ -960,8 +969,10 @@ def _compute_loss_single(self, model, inputs):
     def _compute_loss_and_metrics(self, model, inputs):
         """Core loss computation without metrics recording."""
         mode = 'train' if self.model.training else 'eval'
-
-        completion_mask = inputs['completion_mask']
+        if self.template.padding_free:
+            completion_mask = inputs['completion_mask_padded']
+        else:
+            completion_mask = inputs['completion_mask']
         truncated_mask = inputs['truncated_mask']
         per_token_logps, entropies = self._get_per_token_logps_and_entropies(
             model, inputs, compute_entropy=self.compute_entropy)
@@ -1461,7 +1472,7 @@ def _get_per_token_logps_and_entropies_single(self,
                     entropy_rmpad = None

                 # Restore to batch shape using seq_lengths
-                logps, completion_mask = pad_logps_back_to_batch(
+                logps, padded_shape_mask = pad_logps_back_to_batch(
                     logps_rmpad=per_token_logps_rmpad.unsqueeze(0),  # [1, total_nnz]
                     logits_to_keep=logits_to_keep,
                     batch_size=batch_size,
@@ -1477,9 +1488,18 @@ def _get_per_token_logps_and_entropies_single(self,
                 else:
                     entropies = None

-                # Store the restored completion_mask back to inputs
-                # This eliminates the need to recompute it in loss functions
-                inputs['completion_mask'] = completion_mask
+                # In padding_free mode, the original completion_mask is [1, logits_to_keep] (flattened).
+                # We need to convert it to [batch_size, max_seq_len] format.
+                # The original mask correctly identifies completion vs prompt tokens.
+                if 'completion_mask_padded' not in inputs:
+                    original_completion_mask = inputs['completion_mask']  # [1, logits_to_keep]
+                    completion_mask_padded, _ = pad_logps_back_to_batch(
+                        logps_rmpad=original_completion_mask.float(),  # [1, logits_to_keep]
+                        logits_to_keep=logits_to_keep,
+                        batch_size=batch_size,
+                        seq_lengths=original_seq_lengths)
+                    # Combine with shape mask to ensure padding positions are also masked
+                    inputs['completion_mask_padded'] = (completion_mask_padded > 0.5) & (padded_shape_mask > 0.5)
             else:
                 logps = selective_log_softmax(logits, input_ids_for_logps)

@@ -2145,8 +2165,13 @@ def _apply_rollout_importance_sampling(self, rollout_log_ratio: torch.Tensor,
         mode = self.rollout_importance_sampling_mode
         threshold = self.rollout_importance_sampling_threshold

+        # Clamp log_ratio to prevent numerical overflow from padding values (-1e10)
+        # A log_ratio of 20 corresponds to exp(20) ≈ 485 million, which is already extreme
+        SAFETY_BOUND = 20.0
+        rollout_log_ratio_safe = torch.clamp(rollout_log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)
+
         # Compute importance sampling ratios: exp(log_ratio)
-        is_ratio = torch.exp(rollout_log_ratio)
+        is_ratio = torch.exp(rollout_log_ratio_safe)

         if mode == 'token_truncate':
             # Token-level truncated IS: clip ratios from above at threshold
@@ -2198,29 +2223,35 @@ def _compute_rollout_offpolicy_metrics(
         """
         metrics = {}

+        # Clamp rollout logps to prevent numerical issues from padding values
+        # Padding values are typically -1e10, which would cause exp() overflow
+        LOGP_MIN = -100.0  # exp(-100) is essentially zero probability
+        rollout_per_token_logps_safe = torch.clamp(rollout_per_token_logps, min=LOGP_MIN)
+
         # Helper function for masked mean
         def masked_mean(x, mask):
             return (x * mask).sum() / mask.sum().clamp(min=1.0)

         # 1. Training policy perplexity (always computed)
-        mean_logps = (per_token_logps * completion_mask).sum(1) / completion_mask.sum(1)
+        mean_logps = (per_token_logps * completion_mask).sum(1) / completion_mask.sum(1).clamp(min=1.0)
         policy_ppl = torch.exp(-mean_logps).mean()
         metrics['ppl_policy'] = self.accelerator.gather_for_metrics(policy_ppl).nanmean().item()

Rollout off-policy metrics # KL divergence: KL(π_policy || π_rollout) ≈ E[log(π_policy/π_rollout)] - log_ratio = per_token_logps - rollout_per_token_logps + log_ratio = per_token_logps - rollout_per_token_logps_safe kl_div = masked_mean(log_ratio, completion_mask) metrics['kl_rollout'] = self.accelerator.gather_for_metrics(kl_div).nanmean().item() # Rollout policy perplexity - mean_rollout_logps = (rollout_per_token_logps * completion_mask).sum(1) / completion_mask.sum(1) - rollout_ppl = torch.exp(-mean_rollout_logps) + mean_rollout_logps = (rollout_per_token_logps_safe + * completion_mask).sum(1) / completion_mask.sum(1).clamp(min=1.0) + rollout_ppl = torch.exp(-mean_rollout_logps).mean() metrics['ppl_rollout'] = self.accelerator.gather_for_metrics(rollout_ppl).nanmean().item() # Log PPL difference (for easier monitoring of distribution drift) mean_log_prob_policy = masked_mean(per_token_logps, completion_mask) - mean_log_prob_rollout = masked_mean(rollout_per_token_logps, completion_mask) + mean_log_prob_rollout = masked_mean(rollout_per_token_logps_safe, completion_mask) log_ppl_diff = -mean_log_prob_rollout - (-mean_log_prob_policy) # log(ppl_policy) - log(ppl_rollout) metrics['log_ppl_diff'] = self.accelerator.gather_for_metrics(log_ppl_diff).nanmean().item() diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 0eb45dd598..a3b751455d 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -1162,13 +1162,10 @@ def pad_logps_back_to_batch(logps_rmpad: torch.Tensor, # Determine sequence lengths if seq_lengths is not None: - # Use provided seq_lengths directly - total_length = seq_lengths.sum().item() - if total_length > logits_to_keep: - # Adjust the first sequence length to account for truncation - adjustment = total_length - logits_to_keep - seq_lengths = seq_lengths.clone() - seq_lengths[0] = seq_lengths[0] - adjustment + # Use provided seq_lengths directly - they should already be adjusted + # by the caller (e.g., in _generate_and_score_completions) + # DO NOT adjust again here to avoid double adjustment + pass else: # Fallback: infer from position_ids from swift.utils.torch_utils import get_cu_seqlens_from_position_ids as get_cu_seqlens @@ -1189,31 +1186,39 @@ def pad_logps_back_to_batch(logps_rmpad: torch.Tensor, cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], device=device), seq_lengths]), dim=0) max_seq_len = logits_to_keep # All sequences will be padded to this length - # Initialize output tensors - logps_padded = torch.zeros(batch_size, max_seq_len, dtype=dtype, device=device) + # Initialize output tensors with padding value + logps_padded = torch.full((batch_size, max_seq_len), -1e10, dtype=dtype, device=device) completion_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.float32, device=device) # Unflatten: assign each sequence's logps to the corresponding row + # Use LEFT PADDING (right-align the data) to match the standard padding convention logps_flat = logps_rmpad.squeeze(0) # [total_nnz] for i in range(batch_size): start_idx = cu_seqlens[i].item() end_idx = cu_seqlens[i + 1].item() - seq_len = seq_lengths[i].item() + seq_len = int(seq_lengths[i].item()) actual_end_idx = min(end_idx, len(logps_flat)) actual_len = actual_end_idx - start_idx + if actual_len <= 0: + continue + + # Left padding: place data at the RIGHT side of the row + # pad_len is the number of padding tokens at the beginning + pad_len = max_seq_len - seq_len + if actual_len < seq_len: - # pad at the beginning - 
pad_len = seq_len - actual_len - logps_padded[i, :pad_len] = -1e10 # Padding value - logps_padded[i, pad_len:seq_len] = logps_flat[start_idx:actual_end_idx] + # Input data is shorter than expected seq_len + # This happens when logps_flat doesn't have enough data + # Place actual data at the rightmost positions + data_pad_len = max_seq_len - actual_len + logps_padded[i, data_pad_len:] = logps_flat[start_idx:actual_end_idx] + completion_mask[i, data_pad_len:] = 1.0 else: - # Normal case - logps_padded[i, :seq_len] = logps_flat[start_idx:end_idx] - - # Set mask for valid positions - completion_mask[i, :seq_len] = 1.0 + # Normal case: seq_len tokens of data + logps_padded[i, pad_len:] = logps_flat[start_idx:end_idx] + completion_mask[i, pad_len:] = 1.0 return logps_padded, completion_mask From 0cf6c840a0306ded7828b83f93b77f779e978601 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 25 Nov 2025 15:13:02 +0800 Subject: [PATCH 13/21] right padding for non-padding-free --- .../infer/infer_engine/grpo_vllm_engine.py | 2 +- swift/llm/infer/infer_engine/vllm_engine.py | 2 +- swift/trainers/rlhf_trainer/grpo_trainer.py | 8 ++++--- swift/trainers/rlhf_trainer/rollout_mixin.py | 2 +- swift/trainers/rlhf_trainer/utils.py | 24 ++++++++++--------- 5 files changed, 21 insertions(+), 17 deletions(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 54690097cf..a2fa6495e4 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -181,7 +181,7 @@ def _create_chat_completion_response(self, result, inputs, template: Template, r logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs) toolcall = self._get_toolcall(response, template) - token_ids = template.skip_stop_tokens(output.token_ids) if request_config.return_details else None + token_ids = output.token_ids if request_config.return_details else None choice = ChatCompletionResponseChoice( index=output.index, message=ChatMessage(role='assistant', content=response, tool_calls=toolcall), diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index 23b95a9c30..ec5cd64258 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -573,7 +573,7 @@ def _create_chat_completion_response( logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs) toolcall = self._get_toolcall(content, template) # Use content instead of response for tool calls - token_ids = template.skip_stop_tokens(output.token_ids) if request_config.return_details else None + token_ids = output.token_ids if request_config.return_details else None choice = ChatCompletionResponseChoice( index=output.index, message=ChatMessage( diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index e573cccd0f..7c22e61eb8 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -870,7 +870,8 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: lp_tensor = lp[-logits_to_keep:] if len(lp) >= logits_to_keep else lp # Pad if needed if len(lp_tensor) < max_len: - lp_tensor = [-1e10] * (max_len - len(lp_tensor)) + lp_tensor + # right padding + lp_tensor = lp_tensor + [-1e10] * (max_len - len(lp_tensor)) padded_logprobs.append(lp_tensor) batch_encoded_inputs['vllm_per_token_logps'] = torch.tensor( 
padded_logprobs, dtype=torch.float32, device=self.accelerator.device) @@ -1497,9 +1498,10 @@ def _get_per_token_logps_and_entropies_single(self, logps_rmpad=original_completion_mask.float(), # [1, logits_to_keep] logits_to_keep=logits_to_keep, batch_size=batch_size, - seq_lengths=original_seq_lengths) + seq_lengths=original_seq_lengths, + pad_value=0.0) # Combine with shape mask to ensure padding positions are also masked - inputs['completion_mask_padded'] = (completion_mask_padded > 0.5) & (padded_shape_mask > 0.5) + inputs['completion_mask_padded'] = completion_mask_padded else: logps = selective_log_softmax(logits, input_ids_for_logps) diff --git a/swift/trainers/rlhf_trainer/rollout_mixin.py b/swift/trainers/rlhf_trainer/rollout_mixin.py index eb395b4b74..71a6d8a0a4 100644 --- a/swift/trainers/rlhf_trainer/rollout_mixin.py +++ b/swift/trainers/rlhf_trainer/rollout_mixin.py @@ -853,7 +853,7 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out input_data['finish_reason'] = choice.finish_reason input_data['is_truncated'] = choice.finish_reason == 'length' - input_data['add_eos'] = not choice.finish_reason == 'length' + input_data['add_eos'] = False if output.rollout_infos: multi_modal_keys = ['images', 'videos', 'audios'] for key in multi_modal_keys: diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index a3b751455d..ee2f91790b 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -1136,24 +1136,26 @@ def pad_logps_back_to_batch(logps_rmpad: torch.Tensor, logits_to_keep: int = None, batch_size: int = None, seq_lengths: Optional[torch.Tensor] = None, - dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]: + dtype: Optional[torch.dtype] = None, + pad_value: float = -1e10) -> Tuple[torch.Tensor, torch.Tensor]: """ - Restore padding-free logprobs back to [batch_size, seq_len] shape. + Restore padding-free logprobs back to [batch_size, seq_len] shape with LEFT PADDING. 
- Input: logps in rmpad format [1, total_nnz] - - Output: logps in batch format [batch_size, max_seq_len] + - Output: logps in batch format [batch_size, max_seq_len] with data right-aligned Args: logps_rmpad: [1, total_nnz] per-token log probabilities in padding_free format position_ids: [1, total_nnz] position ids to determine sequence boundaries (deprecated, use seq_lengths) - logits_to_keep: number of tokens to keep per sequence + logits_to_keep: number of tokens to keep per sequence (= max_seq_len) batch_size: number of sequences in the batch seq_lengths: [batch_size] actual sequence lengths (preferred over position_ids) dtype: optional dtype for output, defaults to logps_rmpad.dtype + pad_value: value to use for padding positions (default: -1e10 for logps, use 0.0 for masks) Returns: - logps_padded: [batch_size, logits_to_keep] padded log probabilities - completion_mask: [batch_size, logits_to_keep] mask indicating valid positions + logps_padded: [batch_size, logits_to_keep] padded log probabilities (left-padded, data right-aligned) + valid_mask: [batch_size, logits_to_keep] mask indicating valid (non-padding) positions """ if dtype is None: dtype = logps_rmpad.dtype @@ -1187,8 +1189,8 @@ def pad_logps_back_to_batch(logps_rmpad: torch.Tensor, max_seq_len = logits_to_keep # All sequences will be padded to this length # Initialize output tensors with padding value - logps_padded = torch.full((batch_size, max_seq_len), -1e10, dtype=dtype, device=device) - completion_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.float32, device=device) + logps_padded = torch.full((batch_size, max_seq_len), pad_value, dtype=dtype, device=device) + valid_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.float32, device=device) # Unflatten: assign each sequence's logps to the corresponding row # Use LEFT PADDING (right-align the data) to match the standard padding convention @@ -1215,10 +1217,10 @@ def pad_logps_back_to_batch(logps_rmpad: torch.Tensor, # Place actual data at the rightmost positions data_pad_len = max_seq_len - actual_len logps_padded[i, data_pad_len:] = logps_flat[start_idx:actual_end_idx] - completion_mask[i, data_pad_len:] = 1.0 + valid_mask[i, data_pad_len:] = 1.0 else: # Normal case: seq_len tokens of data logps_padded[i, pad_len:] = logps_flat[start_idx:end_idx] - completion_mask[i, pad_len:] = 1.0 + valid_mask[i, pad_len:] = 1.0 - return logps_padded, completion_mask + return logps_padded, valid_mask From 8d93bb5d030604cdd18a8c57bb5911e81823637d Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 25 Nov 2025 17:15:22 +0800 Subject: [PATCH 14/21] update metrics&doc --- .../Instruction/Command-line-parameters.md | 2 + .../training_inference_mismatch.md | 60 +++-- .../Instruction/GRPO/GetStarted/GRPO.md | 14 ++ .../Instruction/Command-line-parameters.md | 2 + .../GRPO/AdvancedResearch/index.rst | 1 + .../training_inference_mismatch.md | 206 ++++++++++++++++++ .../Instruction/GRPO/GetStarted/GRPO.md | 14 ++ swift/trainers/rlhf_trainer/grpo_trainer.py | 155 ++++++++----- 8 files changed, 387 insertions(+), 67 deletions(-) create mode 100644 docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index af857745ca..0356d89aa7 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -607,6 +607,8 @@ reward模型参数将在PPO、GRPO中使用。 - max_turns: 多轮GRPO的轮数上限。默认为None,不做限制。 - 
top_entropy_quantile: 仅对熵值处于前指定分位的 token 参与损失计算,默认为1.0,即不过滤低熵 token,具体参考[文档](./GRPO/AdvancedResearch/entropy_mask.md)
- log_entropy: 记录训练中的熵值变化动态,默认为False,具体参考[文档](./GRPO/GetStarted/GRPO.md#logged-metrics)
+- rollout_importance_sampling_mode: 训推不一致校正模式,可选项为 `token_truncate`、`token_mask`、`sequence_truncate`、`sequence_mask`。默认为None,不启用校正。具体参考[文档](./GRPO/AdvancedResearch/training_inference_mismatch.md)
+- rollout_importance_sampling_threshold: 重要性采样权重的阈值,用于截断或屏蔽极端权重。默认为2.0。

 ##### 奖励函数参数
 内置的奖励函数参考[文档](./GRPO/DeveloperGuide/reward_function.md)
diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md b/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md
index 806279b3fb..9c425ed6aa 100644
--- a/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md
+++ b/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md
@@ -40,7 +40,6 @@
 $$
 \nabla_\theta \mathcal{L}_{\text{GRPO}} = \mathbb{E}_{y \sim \pi_{\text{vLLM}}}\left[ \nabla_\theta \log \pi_\theta(y|x)\, \hat{A} \right]
 $$
 其中样本来自 $\pi_{\text{vLLM}}$,但梯度是基于 $\pi_\theta$ 计算的,这**破坏了算法的 on-policy 假设**,引入了训推不一致的问题。
-- 最终性能下降
 
 ## Solution
@@ -108,7 +107,7 @@
 ### 四种校正模式
 
-结合粒度和控制策略,共有四种校正模式(通过 `--rollout_importance_sampling_mode` 参数选择):
+结合粒度和控制策略,共设置四种校正模式(通过 `--rollout_importance_sampling_mode` 参数选择):
 
 | 模式 | 说明 |
 |------|------|
@@ -121,17 +120,26 @@
 
 ## Metrics
 
-为了监控训练中训推不一致的程度,我们在log中加入以下指标:
+为了监控训练中训推不一致的程度,我们在log中加入以下指标(前缀为 `rollout_correction/`):
 
 ### 1. KL 散度(KL Divergence)
 
-KL 散度衡量两个分布之间的差异:
+KL 散度衡量两个分布之间的差异,提供两种估计器:
+
+**直接估计器 `kl`**:
 
 $$
-\text{KL}(\pi_\theta \| \pi_{\text{vLLM}}) \approx \frac{1}{|y|} \sum_{t=1}^{|y|} \log \frac{\pi_\theta(y_t|x, y_{<t})}{\pi_{\text{vLLM}}(y_t|x, y_{<t})}
+\text{KL}(\pi_{\text{vLLM}} \| \pi_\theta) = \mathbb{E}\left[ \log \pi_{\text{vLLM}} - \log \pi_\theta \right]
 $$
+
+**K3 估计器 `k3_kl`**:
+
+$$
+\text{KL}_{K3} = \mathbb{E}\left[ \rho - \log \rho - 1 \right], \quad \rho = \frac{\pi_\theta}{\pi_{\text{vLLM}}}
+$$
+
+K3 估计器在 KL 值较小时数值更稳定。
diff --git a/docs/source/Instruction/GRPO/GetStarted/GRPO.md b/docs/source/Instruction/GRPO/GetStarted/GRPO.md
--- a/docs/source/Instruction/GRPO/GetStarted/GRPO.md
+++ b/docs/source/Instruction/GRPO/GetStarted/GRPO.md
+训推一致性指标(需 ms-swift>=3.11)
+- `kl` / `k3_kl`:训练策略与 rollout 策略之间的 KL 散度(直接估计器 / K3 估计器)
+- `training_ppl` / `rollout_ppl`:训练策略和 rollout 策略的困惑度
+- `log_ppl_diff`:log PPL 差异,反映分布偏移程度
+- `ppl_ratio`:PPL 比率
+- `chi2_token` / `chi2_seq`:Token/Sequence 级别的 χ² 散度
+
+IS 校正指标(需设置rollout_importance_sampling_mode)
+- `is_weight_mean`:平均重要性采样权重
+- `ess`:有效样本大小(Effective Sample Size)
+- `clipped_frac`:被截断或屏蔽的样本比例
+
+> 训推一致性指标详细说明请参考文档 [Training-Inference-Mismatch](../AdvancedResearch/training_inference_mismatch.md)
+
 如果设置了`log_completions`, 将保存训练动态在output对应文件夹中,包括
 - step:记录时的训练步数
 - prompt:模型输入
diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md
index 243c4d5d0f..08d878a179 100644
--- a/docs/source_en/Instruction/Command-line-parameters.md
+++ b/docs/source_en/Instruction/Command-line-parameters.md
@@ -621,6 +621,8 @@ The hyperparameters for the reward function can be found in the [Built-in Reward
 - max_turns: Maximum number of rounds for multi-turn GRPO. The default is None, which means there is no limit.
 - top_entropy_quantile: Only tokens whose entropy ranks within the specified top quantile are included in the loss calculation. The default is 1.0, which means low-entropy tokens are not filtered. For details, refer to the [documentation](./GRPO/AdvancedResearch/entropy_mask.md).
 - log_entropy: Logs the entropy values during training. The default is False. For more information, refer to the [documentation](./GRPO/GetStarted/GRPO.md#logged-metrics).
+- rollout_importance_sampling_mode: Training-inference mismatch correction mode. Options are `token_truncate`, `token_mask`, `sequence_truncate`, `sequence_mask`. Default is None (disabled). For details, refer to the [documentation](./GRPO/AdvancedResearch/training_inference_mismatch.md).
+- rollout_importance_sampling_threshold: Threshold for importance sampling weights, used for truncating or masking extreme weights. Default is 2.0.
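The two parameters above fully determine the correction: the mode picks the granularity (token vs. sequence) and strategy (truncate vs. mask), while the threshold picks the cut-off. As a rough, self-contained sketch of the resulting weight computation, mirroring the mock implementation exercised by the test suite later in this series (names are illustrative, not the trainer's exact API):

```python
import torch

def rollout_is_weights(log_ratio: torch.Tensor, mask: torch.Tensor,
                       mode: str = 'token_truncate', threshold: float = 2.0) -> torch.Tensor:
    """log_ratio = log pi_theta - log pi_vllm per token, shape [B, T]; mask marks completion tokens."""
    # Clamp first so padding values such as -1e10 cannot overflow exp()
    ratio = torch.exp(log_ratio.clamp(min=-20.0, max=20.0))
    if mode == 'token_truncate':
        return ratio.clamp(max=threshold)
    if mode == 'token_mask':
        return torch.where(ratio <= threshold, ratio, torch.zeros_like(ratio))
    # Sequence-level modes use the geometric mean of the token-level ratios
    seq_log = (torch.log(ratio.clamp(min=1e-10)) * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
    seq_ratio = torch.exp(seq_log)
    if mode == 'sequence_truncate':
        return seq_ratio.clamp(max=threshold).unsqueeze(-1).expand_as(ratio)
    if mode == 'sequence_mask':
        return ratio * (seq_ratio <= threshold).float().unsqueeze(-1)
    raise ValueError(f'unknown mode: {mode}')
```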
##### Reward function parameters

diff --git a/docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst b/docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst
index 0cf7cd2478..0634af42c4 100644
--- a/docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst
+++ b/docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst
@@ -11,3 +11,4 @@ Advanced Research
    RLOO.md
    CHORD.md
    CISPO.md
+   training_inference_mismatch.md
diff --git a/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md b/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md
new file mode 100644
index 0000000000..731d23f61d
--- /dev/null
+++ b/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md
@@ -0,0 +1,206 @@
+# Training-Inference-Mismatch
+
+**Version Requirement**: ms-swift>=3.11
+
+**TL;DR**: While GRPO introduces vLLM to accelerate the sampling process, it also introduces Training-Inference Mismatch issues that may affect training stability. This document explains the background, causes, and solutions to this problem.
+
+## Background
+
+### Basic Assumptions of GRPO
+
+The training objective of GRPO (Group Relative Policy Optimization) can be expressed as:
+
+$$
+\mathcal{L}_{\text{GRPO}} = - \mathbb{E}_{y \sim \pi_\theta} \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t \right) \right]
+$$
+
+Where:
+- $r_t(\theta) = \frac{\pi_\theta(y_t|x, y_{<t})}{\pi_{\theta_{\text{old}}}(y_t|x, y_{<t})}$ is the token-level importance sampling ratio between the current and old policies
diff --git a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md
--- a/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md
+++ b/docs/source_en/Instruction/GRPO/GetStarted/GRPO.md
+Training-inference consistency metrics (requires ms-swift>=3.11):
+- `kl` / `k3_kl`: KL divergence between training policy and rollout policy (direct estimator / K3 estimator)
+- `training_ppl` / `rollout_ppl`: Perplexity of training policy and rollout policy
+- `log_ppl_diff`: Log PPL difference, reflects the degree of distribution shift
+- `ppl_ratio`: PPL ratio
+- `chi2_token` / `chi2_seq`: Token/Sequence-level χ² divergence
+
+IS correction metrics (requires setting rollout_importance_sampling_mode):
+- `is_weight_mean`: Average importance sampling weight
+- `ess`: Effective Sample Size
+- `clipped_frac`: Fraction of samples that were truncated or masked
+
+> For detailed explanation of training-inference consistency metrics, please refer to [Training-Inference-Mismatch](../AdvancedResearch/training_inference_mismatch.md)
+
 If `log_completions` is set, the training dynamics will be saved in the output directory, including:
 - step: The training step at the time of logging.
 - prompt: The model input.
diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 7c22e61eb8..c38b487ff9 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -2210,52 +2210,107 @@ def _compute_rollout_offpolicy_metrics(
     ) -> Dict[str, float]:
         """
         Compute off-policy diagnostic metrics (always computed for monitoring).
+        reference: verl/verl/trainer/ppo/rollout_corr_helper.py
 
         These metrics help diagnose the off-policy gap between rollout and training
         policies, which can arise from policy mismatch (e.g., vLLM BF16 vs FSDP FP32),
         model staleness, or general distribution shifts.
+ Key metrics: + - kl: Direct KL divergence estimator KL(π_rollout || π_training) + - k3_kl: K3 KL estimator for stability (more stable for small KL) + - training_ppl: Perplexity of training policy + - rollout_ppl: Perplexity of rollout policy + - log_ppl_diff: Difference in log perplexities + - ppl_ratio: Ratio of training PPL to rollout PPL + - chi2_token: Token-level χ² divergence E[ρ²] - 1 + - chi2_seq: Sequence-level χ² divergence E[(∏ρ_t)²] - 1 + Args: - per_token_logps: Log probs from policy model, shape [B, T] - rollout_per_token_logps: Log probs from rollout, shape [B, T] + per_token_logps: Log probs from training policy model, shape [B, T] + rollout_per_token_logps: Log probs from rollout policy, shape [B, T] completion_mask: Boolean mask for completion tokens, shape [B, T] Returns: - Dictionary with metrics: kl, ppl_policy, ppl_rollout, log_ppl_diff + Dictionary with off-policy diagnostic metrics """ + SAFETY_BOUND = 20.0 metrics = {} - # Clamp rollout logps to prevent numerical issues from padding values - # Padding values are typically -1e10, which would cause exp() overflow - LOGP_MIN = -100.0 # log(exp(-100)) is essentially 0 probability - rollout_per_token_logps_safe = torch.clamp(rollout_per_token_logps, min=LOGP_MIN) - # Helper function for masked mean - def masked_mean(x, mask): - return (x * mask).sum() / mask.sum().clamp(min=1.0) + def masked_mean(x, mask, axis=None): + if axis is None: + return (x * mask).sum() / mask.sum().clamp(min=1.0) + else: + return (x * mask).sum(axis) / mask.sum(axis).clamp(min=1.0) # 1. Training policy perplexity (always computed) - mean_logps = (per_token_logps * completion_mask).sum(1) / completion_mask.sum(1).clamp(min=1.0) - policy_ppl = torch.exp(-mean_logps).mean() - metrics['ppl_policy'] = self.accelerator.gather_for_metrics(policy_ppl).nanmean().item() - - # 2. Rollout off-policy metrics - # KL divergence: KL(π_policy || π_rollout) ≈ E[log(π_policy/π_rollout)] - log_ratio = per_token_logps - rollout_per_token_logps_safe - kl_div = masked_mean(log_ratio, completion_mask) - metrics['kl_rollout'] = self.accelerator.gather_for_metrics(kl_div).nanmean().item() - - # Rollout policy perplexity - mean_rollout_logps = (rollout_per_token_logps_safe - * completion_mask).sum(1) / completion_mask.sum(1).clamp(min=1.0) - rollout_ppl = torch.exp(-mean_rollout_logps).mean() - metrics['ppl_rollout'] = self.accelerator.gather_for_metrics(rollout_ppl).nanmean().item() - - # Log PPL difference (for easier monitoring of distribution drift) - mean_log_prob_policy = masked_mean(per_token_logps, completion_mask) - mean_log_prob_rollout = masked_mean(rollout_per_token_logps_safe, completion_mask) - log_ppl_diff = -mean_log_prob_rollout - (-mean_log_prob_policy) # log(ppl_policy) - log(ppl_rollout) - metrics['log_ppl_diff'] = self.accelerator.gather_for_metrics(log_ppl_diff).nanmean().item() + # Formula: exp(-1/|T| * Σ log π_training(y_t|y_ Dict[str, float]: """ - Compute importance sampling correction metrics (ESS, clipped_frac, is_weight_mean). + Compute importance sampling correction metrics (ess, clipped_frac, is_weight_mean). Only called when rollout_importance_sampling_mode is enabled. 
Args: @@ -2275,10 +2330,15 @@ def _compute_is_correction_metrics( completion_mask: Boolean mask for completion tokens, shape [B, T] Returns: - Dictionary with IS-specific metrics + Dictionary with IS-specific metrics: + - is_weight_mean: Mean of IS weights + - ess: Effective Sample Size = 1 / E[(w_i / E[w_i])²] + - clipped_frac: Fraction of clipped/masked samples """ metrics = {} SAFETY_BOUND = 20.0 + threshold = self.rollout_importance_sampling_threshold + threshold_lower = 1.0 / threshold # Default lower threshold (reciprocal of upper) # Helper function for masked mean def masked_mean(x, mask): @@ -2288,28 +2348,20 @@ def masked_mean(x, mask): log_ratio_safe = torch.clamp(vllm_log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) is_ratio = torch.exp(log_ratio_safe) - # Compute sequence-level ratios for ESS and clipped_frac - log_r = torch.log(is_ratio.clamp(min=1e-10)) - seq_log_ratios = (log_r * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - seq_ratios = torch.exp(seq_log_ratios) - - # ESS = 1 / E[(w/E[w])²] - measures effective number of independent samples - # For distributed training, gather all seq_ratios across ranks first - all_seq_ratios = self.accelerator.gather_for_metrics(seq_ratios) - mean_seq_ratio = all_seq_ratios.mean() - seq_ratios_normalized = all_seq_ratios / (mean_seq_ratio + 1e-8) - ess = 1.0 / (seq_ratios_normalized**2).mean().clamp(min=1e-10) - # Normalize by batch size to get relative ESS in [0, 1] - num_sequences = max(len(all_seq_ratios), 1) - ess_normalized = ess / num_sequences - metrics['ess'] = ess_normalized.item() - - # 2. IS weight statistics + # 1. IS weight statistics mean_is_weight = masked_mean(is_weights, completion_mask) metrics['is_weight_mean'] = self.accelerator.gather_for_metrics(mean_is_weight).nanmean().item() + # 2. Compute Effective Sample Size (ESS) for IS weights + # ESS = 1 / E[(w_i / E[w_i])²] (using clamped weights for stability) + # This measures how many "effective" independent samples we have after IS weighting + weights_for_ess = is_weights.clamp(min=threshold_lower, max=threshold) + mean_for_ess = masked_mean(weights_for_ess, completion_mask) + is_weights_normalized = weights_for_ess / (mean_for_ess + 1e-8) # Avoid division by zero + ess = 1.0 / masked_mean(is_weights_normalized.square(), completion_mask).clamp(min=1e-10) + metrics['ess'] = self.accelerator.gather_for_metrics(ess).nanmean().item() + # 3. 
Fraction of clipped/masked samples
-        threshold = self.rollout_importance_sampling_threshold
         if self.rollout_importance_sampling_mode in ['token_truncate', 'token_mask']:
             # Token-level
             if self.rollout_importance_sampling_mode == 'token_truncate':
@@ -2319,6 +2371,7 @@ def masked_mean(x, mask):
             metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item()
         else:
             # Sequence-level (both truncate and mask)
+            seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask)
             clipped_frac = (seq_ratios > threshold).float().mean()
             metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item()

From 481b2b8e61bdbdfe32d48ccb5a8cda774447276d Mon Sep 17 00:00:00 2001
From: hjh0119
Date: Tue, 25 Nov 2025 17:28:18 +0800
Subject: [PATCH 15/21] megatron grpo doc

---
 docs/source/Megatron-SWIFT/GRPO.md    | 1 +
 docs/source_en/Megatron-SWIFT/GRPO.md | 1 +
 2 files changed, 2 insertions(+)

diff --git a/docs/source/Megatron-SWIFT/GRPO.md b/docs/source/Megatron-SWIFT/GRPO.md
index a8aa4df0e4..c59c0880b1 100644
--- a/docs/source/Megatron-SWIFT/GRPO.md
+++ b/docs/source/Megatron-SWIFT/GRPO.md
@@ -15,6 +15,7 @@ Megatron GRPO 当前已支持以下功能:
 以下参数或功能将在后续版本中逐步支持:
 - **Entropy 相关配置**:如 `top_entropy_quantile`、`log_entropy`
+- **Rollout Correction(TIS/MIS)**
 - **Reward Model / Reward Model Plugin**
 - **多轮 Rollout 调度机制**(`multi_turn_scheduler`):实现多轮对话策略优化
 - **优势估计器**(`advantage_estimator`):支持更复杂的策略梯度估计方法
diff --git a/docs/source_en/Megatron-SWIFT/GRPO.md b/docs/source_en/Megatron-SWIFT/GRPO.md
index 3fa9dfb58d..06e79db35a 100644
--- a/docs/source_en/Megatron-SWIFT/GRPO.md
+++ b/docs/source_en/Megatron-SWIFT/GRPO.md
@@ -15,6 +15,7 @@ Megatron GRPO currently supports the following features:
 The following parameters or features will be gradually supported in future versions:
 - **Entropy-related Configuration**: e.g., `top_entropy_quantile`, `log_entropy`
+- **Rollout Correction (TIS/MIS)**
 - **Reward Model / Reward Model Plugin**
 - **Multi-turn Rollout Scheduling** (`multi_turn_scheduler`): Multi-turn conversation policy optimization
 - **Advantage Estimator** (`advantage_estimator`): Support for more complex policy gradient estimation methods

From 3bc4d634cfef320de119fc0d5f5a9119ad0a10d5 Mon Sep 17 00:00:00 2001
From: hjh0119
Date: Tue, 25 Nov 2025 17:46:03 +0800
Subject: [PATCH 16/21] add test

---
 .../test_vllm_importance_sampling_basic.py | 432 +++++++++++++-----
 1 file changed, 330 insertions(+), 102 deletions(-)

diff --git a/tests/train/test_vllm_importance_sampling_basic.py b/tests/train/test_vllm_importance_sampling_basic.py
index 2fc18dd416..20e3541515 100644
--- a/tests/train/test_vllm_importance_sampling_basic.py
+++ b/tests/train/test_vllm_importance_sampling_basic.py
@@ -1,84 +1,163 @@
 """
 Basic tests for vLLM Importance Sampling implementation

-This test file verifies the core functionality of the vLLM IS correction.
+This test file verifies the core functionality of the vLLM IS correction,
+including the IS weight computation and metrics calculation.
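+
+The tests are self-contained: the trainer and accelerator are stubbed with mocks
+and all tensors live on CPU, so no GPU or vLLM installation is required.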
+ +Reference: verl/verl/trainer/ppo/rollout_corr_helper.py """ import torch +class MockAccelerator: + """Mock accelerator for testing metrics gathering""" + + def __init__(self, device='cpu'): + self.device = device + + def gather_for_metrics(self, tensor): + # In testing, just return the tensor as-is + return tensor + + class MockGRPOTrainer: """Mock GRPO trainer for testing IS methods""" def __init__(self, mode='token_truncate', threshold=2.0): self.rollout_importance_sampling_mode = mode self.rollout_importance_sampling_threshold = threshold - self.template = MockTemplate() + self.accelerator = MockAccelerator() + + def _compute_sequence_level_ratios(self, is_ratio: torch.Tensor, completion_mask: torch.Tensor) -> torch.Tensor: + """ + Helper function to compute sequence-level importance sampling ratios. + + Args: + is_ratio: Token-level IS ratios, shape [B, T] + completion_mask: Boolean mask for completion tokens, shape [B, T] + + Returns: + Sequence-level ratios as geometric mean of token-level ratios + """ + log_ratio = torch.log(is_ratio.clamp(min=1e-10)) + seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) + seq_ratios = torch.exp(seq_log_ratios) + + return seq_ratios + + def _apply_rollout_importance_sampling(self, rollout_log_ratio: torch.Tensor, + completion_mask: torch.Tensor) -> torch.Tensor: + """ + Apply vLLM importance sampling correction using one of four modes. + + Args: + rollout_log_ratio: log(π_θ / π_rollout) per token, shape [B, T] + completion_mask: Boolean mask for completion tokens, shape [B, T] - def _apply_rollout_importance_sampling(self, vllm_log_ratio, completion_mask, lengths=None): - """Copy of the implementation from grpo_trainer.py""" + Returns: + IS weights to multiply with loss, same shape as rollout_log_ratio + """ mode = self.rollout_importance_sampling_mode threshold = self.rollout_importance_sampling_threshold - is_ratio = torch.exp(vllm_log_ratio) + # Clamp log_ratio to prevent numerical overflow from padding values (-1e10) + # A log_ratio of 20 corresponds to exp(20) ≈ 485 million, which is already extreme + SAFETY_BOUND = 20.0 + rollout_log_ratio_safe = torch.clamp(rollout_log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) + + # Compute importance sampling ratios: exp(log_ratio) + is_ratio = torch.exp(rollout_log_ratio_safe) if mode == 'token_truncate': + # Token-level truncated IS: clip ratios from above at threshold is_weights = torch.clamp(is_ratio, max=threshold) elif mode == 'token_mask': + # Token-level masked IS: mask out tokens with ratio > threshold is_weights = torch.where(is_ratio <= threshold, is_ratio, torch.zeros_like(is_ratio)) elif mode == 'sequence_truncate': - if self.template.padding_free: - ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) - mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) - - seq_ratios = [] - for ratio, mask in zip(ratio_list, mask_list): - log_ratio = torch.log(ratio.clamp(min=1e-10)) - seq_ratio = torch.exp((log_ratio * mask).sum() / mask.sum().clamp(min=1.0)) - seq_ratios.append(seq_ratio) - - seq_ratios = torch.stack(seq_ratios) - clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) - is_weights = torch.repeat_interleave(clipped_seq_ratios, lengths).unsqueeze(0) - else: - log_ratio = torch.log(is_ratio.clamp(min=1e-10)) - seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - seq_ratios = torch.exp(seq_log_ratios) - clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) - 
is_weights = clipped_seq_ratios.unsqueeze(-1).expand_as(is_ratio) + # Sequence-level truncated IS: compute sequence-level ratio and clip + seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask) + clipped_seq_ratios = torch.clamp(seq_ratios, max=threshold) + + is_weights = clipped_seq_ratios.unsqueeze(-1).expand_as(is_ratio) elif mode == 'sequence_mask': - if self.template.padding_free: - ratio_list = torch.split(is_ratio.squeeze(0), lengths.tolist()) - mask_list = torch.split(completion_mask.squeeze(0), lengths.tolist()) - - seq_ratios = [] - for ratio, mask in zip(ratio_list, mask_list): - log_ratio = torch.log(ratio.clamp(min=1e-10)) - seq_ratio = torch.exp((log_ratio * mask).sum() / mask.sum().clamp(min=1.0)) - seq_ratios.append(seq_ratio) - - seq_ratios = torch.stack(seq_ratios) - seq_mask = (seq_ratios <= threshold).float() - is_weights = torch.repeat_interleave(seq_mask, lengths).unsqueeze(0) - else: - log_ratio = torch.log(is_ratio.clamp(min=1e-10)) - seq_log_ratios = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0) - seq_ratios = torch.exp(seq_log_ratios) - seq_mask = (seq_ratios <= threshold).float() - is_weights = seq_mask.unsqueeze(-1).expand_as(is_ratio) + # Sequence-level masked IS: mask entire sequences with ratio > threshold + seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask) + seq_mask = (seq_ratios <= threshold).float() + + # Apply mask to original token-level ratios + is_weights = is_ratio * seq_mask.unsqueeze(-1) else: - raise ValueError(f'Unknown mode: {mode}') + return is_ratio return is_weights + def _compute_is_correction_metrics( + self, + vllm_log_ratio: torch.Tensor, + is_weights: torch.Tensor, + completion_mask: torch.Tensor, + ) -> dict: + """ + Compute importance sampling correction metrics (ess, clipped_frac, is_weight_mean). + Only called when rollout_importance_sampling_mode is enabled. + + Args: + vllm_log_ratio: Log ratio log(π_policy / π_rollout), shape [B, T] + is_weights: Importance sampling weights after correction, shape [B, T] + completion_mask: Boolean mask for completion tokens, shape [B, T] + + Returns: + Dictionary with IS-specific metrics: + - is_weight_mean: Mean of IS weights + - ess: Effective Sample Size = 1 / E[(w_i / E[w_i])²] + - clipped_frac: Fraction of clipped/masked samples + """ + metrics = {} + SAFETY_BOUND = 20.0 + threshold = self.rollout_importance_sampling_threshold + threshold_lower = 1.0 / threshold # Default lower threshold (reciprocal of upper) + + # Helper function for masked mean + def masked_mean(x, mask): + return (x * mask).sum() / mask.sum().clamp(min=1.0) + + # Compute IS ratio with safety bounds + log_ratio_safe = torch.clamp(vllm_log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) + is_ratio = torch.exp(log_ratio_safe) + + # 1. IS weight statistics + mean_is_weight = masked_mean(is_weights, completion_mask) + metrics['is_weight_mean'] = self.accelerator.gather_for_metrics(mean_is_weight).nanmean().item() + + # 2. 
Compute Effective Sample Size (ESS) for IS weights + # ESS = 1 / E[(w_i / E[w_i])²] (using clamped weights for stability) + # This measures how many "effective" independent samples we have after IS weighting + weights_for_ess = is_weights.clamp(min=threshold_lower, max=threshold) + mean_for_ess = masked_mean(weights_for_ess, completion_mask) + is_weights_normalized = weights_for_ess / (mean_for_ess + 1e-8) # Avoid division by zero + ess = 1.0 / masked_mean(is_weights_normalized.square(), completion_mask).clamp(min=1e-10) + metrics['ess'] = self.accelerator.gather_for_metrics(ess).nanmean().item() + + # 3. Fraction of clipped/masked samples + if self.rollout_importance_sampling_mode in ['token_truncate', 'token_mask']: + # Token-level + if self.rollout_importance_sampling_mode == 'token_truncate': + clipped_frac = masked_mean((is_ratio > threshold).float(), completion_mask) + else: # token_mask + clipped_frac = masked_mean((is_weights == 0).float(), completion_mask) + metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item() + else: + # Sequence-level (both truncate and mask) + seq_ratios = self._compute_sequence_level_ratios(is_ratio, completion_mask) + clipped_frac = (seq_ratios > threshold).float().mean() + metrics['clipped_frac'] = self.accelerator.gather_for_metrics(clipped_frac).nanmean().item() -class MockTemplate: - - def __init__(self, padding_free=False): - self.padding_free = padding_free + return metrics class TestVLLMImportanceSampling: @@ -91,7 +170,7 @@ def test_token_truncate_basic(self): # Create mock data: [batch=2, seq_len=4] # Log ratios that will produce ratios [0.5, 1.5, 3.0, 5.0] vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0], [0.8, 1.2, 2.5, 4.0]])) - completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + completion_mask = torch.ones_like(vllm_log_ratio) is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) @@ -107,7 +186,7 @@ def test_token_mask_basic(self): trainer = MockGRPOTrainer(mode='token_mask', threshold=2.0) vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0]])) - completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + completion_mask = torch.ones_like(vllm_log_ratio) is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) @@ -122,11 +201,12 @@ def test_sequence_truncate_basic(self): trainer = MockGRPOTrainer(mode='sequence_truncate', threshold=2.0) # First sequence has high ratios, second has low ratios - vllm_log_ratio = torch.log(torch.tensor([ - [3.0, 3.0, 3.0, 3.0], # avg=3.0 > 2.0 - [1.0, 1.0, 1.0, 1.0] - ])) # avg=1.0 < 2.0 - completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + vllm_log_ratio = torch.log( + torch.tensor([ + [3.0, 3.0, 3.0, 3.0], # geometric mean=3.0 > 2.0 + [1.0, 1.0, 1.0, 1.0] + ])) # geometric mean=1.0 < 2.0 + completion_mask = torch.ones_like(vllm_log_ratio) is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) @@ -139,44 +219,25 @@ def test_sequence_mask_basic(self): """Test sequence-level masked IS""" trainer = MockGRPOTrainer(mode='sequence_mask', threshold=2.0) - vllm_log_ratio = torch.log(torch.tensor([ - [3.0, 3.0, 3.0, 3.0], # avg=3.0 > 2.0 - [1.0, 1.0, 1.0, 1.0] - ])) # avg=1.0 < 2.0 - completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + vllm_log_ratio = torch.log( + torch.tensor([ + [3.0, 3.0, 3.0, 3.0], # geometric mean=3.0 > 2.0 + [1.0, 1.0, 1.0, 1.0] + ])) # geometric mean=1.0 < 2.0 + 
completion_mask = torch.ones_like(vllm_log_ratio) is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) # First sequence should be completely masked (0) + # Note: sequence_mask multiplies is_ratio by 0, so all tokens become 0 assert torch.allclose(is_weights[0, :], torch.tensor(0.0), atol=1e-5) - # Second sequence should remain 1.0 + # Second sequence should keep original ratios (1.0 * 1.0 = 1.0) assert torch.allclose(is_weights[1, :], torch.tensor(1.0), atol=1e-5) - def test_padding_free_mode(self): - """Test padding-free mode""" - trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0) - trainer.template.padding_free = True - - # Simulate padding-free: [1, total_tokens] = [1, 6] for two sequences of len 4 and 2 - vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0, 0.8, 1.2]])) - completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) - lengths = torch.tensor([4, 2]) # Two sequences: len=4 and len=2 - - is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask, lengths) - - # Should have same shape as input - assert is_weights.shape == vllm_log_ratio.shape - # Check truncation: first sequence tokens 2,3 should be truncated to 2.0 - assert torch.allclose(is_weights[0, 2], torch.tensor(2.0), atol=1e-5) - assert torch.allclose(is_weights[0, 3], torch.tensor(2.0), atol=1e-5) - # Check second sequence: only one token should be truncated if > threshold - # 0.8 < 2.0, so should remain 0.8 - assert torch.allclose(is_weights[0, 4], torch.tensor(0.8), atol=1e-5) - def test_threshold_sensitivity(self): """Test different threshold values""" vllm_log_ratio = torch.log(torch.tensor([[1.0, 2.0, 3.0, 4.0]])) - completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + completion_mask = torch.ones_like(vllm_log_ratio) # Test threshold=1.5 trainer_low = MockGRPOTrainer(mode='token_truncate', threshold=1.5) @@ -197,12 +258,12 @@ def test_completion_mask(self): vllm_log_ratio = torch.log(torch.tensor([[3.0, 3.0, 3.0, 3.0]])) # Mask out last two tokens - completion_mask = torch.tensor([[True, True, False, False]]) + completion_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]]) is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) # Should only consider masked tokens for sequence ratio calculation - # With only first two tokens (both 3.0), avg=3.0, truncated to 2.0 + # With only first two tokens (both 3.0), geometric mean=3.0, truncated to 2.0 assert torch.allclose(is_weights[0, :2], torch.tensor(2.0), atol=1e-5) def test_edge_cases(self): @@ -211,44 +272,211 @@ def test_edge_cases(self): # Case 1: All ratios below threshold vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.0, 1.5]])) - completion_mask = torch.ones_like(vllm_log_ratio, dtype=torch.bool) + completion_mask = torch.ones_like(vllm_log_ratio) is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) assert torch.allclose(is_weights, torch.exp(vllm_log_ratio), atol=1e-5) # Case 2: All ratios above threshold vllm_log_ratio = torch.log(torch.tensor([[3.0, 4.0, 5.0]])) - is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask[:, :3]) assert torch.allclose(is_weights, torch.tensor(2.0), atol=1e-5) # Case 3: Empty mask vllm_log_ratio = torch.log(torch.tensor([[1.0, 2.0, 3.0]])) - completion_mask = torch.zeros_like(vllm_log_ratio, dtype=torch.bool) + completion_mask = 
torch.zeros_like(vllm_log_ratio) is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) # Should still compute but result may not be meaningful assert is_weights.shape == vllm_log_ratio.shape + def test_safety_bound(self): + """Test that extreme log ratios are clamped""" + trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0) + + # Create extreme log ratios that would overflow without clamping + vllm_log_ratio = torch.tensor([[100.0, -100.0, 0.0]]) # exp(100) would overflow + completion_mask = torch.ones_like(vllm_log_ratio) + + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) + + # Should not have inf or nan + assert torch.isfinite(is_weights).all() + # Large positive log_ratio should be clamped to threshold + assert is_weights[0, 0] <= 2.0 + # Large negative log_ratio should result in small positive value + assert is_weights[0, 1] > 0 + + +class TestISCorrectionMetrics: + """Test suite for IS correction metrics""" + + def test_ess_uniform_weights(self): + """Test ESS with uniform weights (should be close to 1.0)""" + trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0) + + # Uniform weights of 1.0 + vllm_log_ratio = torch.zeros((2, 4)) # exp(0) = 1.0 + completion_mask = torch.ones_like(vllm_log_ratio) + is_weights = torch.ones_like(vllm_log_ratio) + + metrics = trainer._compute_is_correction_metrics(vllm_log_ratio, is_weights, completion_mask) + + # ESS should be 1.0 for uniform weights + assert abs(metrics['ess'] - 1.0) < 0.01 + # Mean weight should be 1.0 + assert abs(metrics['is_weight_mean'] - 1.0) < 0.01 + # No clipping for uniform weights + assert metrics['clipped_frac'] == 0.0 + + def test_ess_varied_weights(self): + """Test ESS with varied weights (should be < 1.0)""" + trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0) + + # Varied weights + vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.0, 1.5, 2.0]])) + completion_mask = torch.ones_like(vllm_log_ratio) + is_weights = torch.tensor([[0.5, 1.0, 1.5, 2.0]]) + + metrics = trainer._compute_is_correction_metrics(vllm_log_ratio, is_weights, completion_mask) + + # ESS should be less than 1.0 for non-uniform weights + assert metrics['ess'] < 1.0 + assert metrics['ess'] > 0.0 + + def test_clipped_frac_token_truncate(self): + """Test clipped_frac for token_truncate mode""" + trainer = MockGRPOTrainer(mode='token_truncate', threshold=2.0) + + # 2 out of 4 tokens exceed threshold + vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0]])) + completion_mask = torch.ones_like(vllm_log_ratio) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) + + metrics = trainer._compute_is_correction_metrics(vllm_log_ratio, is_weights, completion_mask) + + # 2/4 = 0.5 tokens clipped + assert abs(metrics['clipped_frac'] - 0.5) < 0.01 + + def test_clipped_frac_token_mask(self): + """Test clipped_frac for token_mask mode""" + trainer = MockGRPOTrainer(mode='token_mask', threshold=2.0) + + # 2 out of 4 tokens exceed threshold + vllm_log_ratio = torch.log(torch.tensor([[0.5, 1.5, 3.0, 5.0]])) + completion_mask = torch.ones_like(vllm_log_ratio) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) + + metrics = trainer._compute_is_correction_metrics(vllm_log_ratio, is_weights, completion_mask) + + # 2/4 = 0.5 tokens masked (is_weights == 0) + assert abs(metrics['clipped_frac'] - 0.5) < 0.01 + + def test_clipped_frac_sequence_level(self): + """Test clipped_frac for 
sequence-level modes""" + trainer = MockGRPOTrainer(mode='sequence_truncate', threshold=2.0) + + # First sequence exceeds threshold, second doesn't + vllm_log_ratio = torch.log(torch.tensor([[3.0, 3.0, 3.0, 3.0], [1.0, 1.0, 1.0, 1.0]])) + completion_mask = torch.ones_like(vllm_log_ratio) + is_weights = trainer._apply_rollout_importance_sampling(vllm_log_ratio, completion_mask) + + metrics = trainer._compute_is_correction_metrics(vllm_log_ratio, is_weights, completion_mask) + + # 1/2 = 0.5 sequences clipped + assert abs(metrics['clipped_frac'] - 0.5) < 0.01 + + +class TestOffpolicyMetrics: + """Test suite for off-policy diagnostic metrics""" + + def test_kl_divergence_same_policy(self): + """Test KL divergence when policies are identical""" + # When per_token_logps == rollout_per_token_logps, KL should be 0 + per_token_logps = torch.tensor([[-1.0, -2.0, -1.5, -0.5]]) + rollout_per_token_logps = per_token_logps.clone() + completion_mask = torch.ones_like(per_token_logps) + + # Helper function for masked mean + def masked_mean(x, mask, axis=None): + if axis is None: + return (x * mask).sum() / mask.sum().clamp(min=1.0) + else: + return (x * mask).sum(axis) / mask.sum(axis).clamp(min=1.0) + + # KL = E[log(π_rollout) - log(π_training)] + kl = masked_mean(rollout_per_token_logps - per_token_logps, completion_mask) + + assert abs(kl.item()) < 1e-6 + + def test_k3_kl_estimator(self): + """Test K3 KL estimator""" + per_token_logps = torch.tensor([[-1.0, -2.0, -1.5, -0.5]]) + rollout_per_token_logps = torch.tensor([[-1.1, -1.9, -1.6, -0.4]]) + completion_mask = torch.ones_like(per_token_logps) + + def masked_mean(x, mask, axis=None): + if axis is None: + return (x * mask).sum() / mask.sum().clamp(min=1.0) + else: + return (x * mask).sum(axis) / mask.sum(axis).clamp(min=1.0) + + # K3 estimator: E[exp(log_ratio) - log_ratio - 1] + log_ratio = per_token_logps - rollout_per_token_logps + log_ratio *= completion_mask + k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1 + k3_kl = masked_mean(k3_kl_matrix, completion_mask) + + # K3 KL should be non-negative + assert k3_kl.item() >= 0 + + def test_chi2_divergence(self): + """Test χ² divergence calculation""" + per_token_logps = torch.tensor([[-1.0, -2.0]]) + rollout_per_token_logps = torch.tensor([[-1.5, -1.5]]) + completion_mask = torch.ones_like(per_token_logps) + + def masked_mean(x, mask, axis=None): + if axis is None: + return (x * mask).sum() / mask.sum().clamp(min=1.0) + else: + return (x * mask).sum(axis) / mask.sum(axis).clamp(min=1.0) + + SAFETY_BOUND = 20.0 + log_ratio = per_token_logps - rollout_per_token_logps + log_ratio_safe = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND) + rho_token = torch.exp(log_ratio_safe) + rho_squared_token = rho_token.square() + chi2_token = masked_mean(rho_squared_token, completion_mask) - 1.0 + + # χ² should be >= -1 (can be negative if E[ρ²] < 1) + assert chi2_token.item() >= -1.0 + if __name__ == '__main__': # Run tests manually import sys - test_instance = TestVLLMImportanceSampling() - - test_methods = [ - 'test_token_truncate_basic', 'test_token_mask_basic', 'test_sequence_truncate_basic', - 'test_sequence_mask_basic', 'test_padding_free_mode', 'test_threshold_sensitivity', 'test_completion_mask', - 'test_edge_cases' + test_classes = [ + ('TestVLLMImportanceSampling', TestVLLMImportanceSampling), + ('TestISCorrectionMetrics', TestISCorrectionMetrics), + ('TestOffpolicyMetrics', TestOffpolicyMetrics), ] failed_tests = [] - for method_name in test_methods: - try: - print(f'Running 
{method_name}...') - getattr(test_instance, method_name)() - print(f'✓ {method_name} passed') - except Exception as e: - print(f'✗ {method_name} failed: {e}') - failed_tests.append(method_name) + + for class_name, test_class in test_classes: + print(f'\n=== {class_name} ===') + test_instance = test_class() + + test_methods = [m for m in dir(test_instance) if m.startswith('test_')] + + for method_name in test_methods: + try: + print(f'Running {method_name}...') + getattr(test_instance, method_name)() + print(f'✓ {method_name} passed') + except Exception as e: + print(f'✗ {method_name} failed: {e}') + failed_tests.append(f'{class_name}.{method_name}') if failed_tests: print(f'\nFailed tests: {failed_tests}') From e9e45686280f69283381157059df1db563f5b43d Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 25 Nov 2025 18:00:07 +0800 Subject: [PATCH 17/21] norm chi2_seq --- .../AdvancedResearch/training_inference_mismatch.md | 4 ++-- .../AdvancedResearch/training_inference_mismatch.md | 4 ++-- swift/trainers/rlhf_trainer/grpo_trainer.py | 13 ++++++++----- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md b/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md index 9c425ed6aa..06aeeaa5ba 100644 --- a/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md +++ b/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md @@ -166,9 +166,9 @@ $$ $$ - `chi2_token`:Token 级别 χ² 散度,$\mathbb{E}[\rho_t^2] - 1$ -- `chi2_seq`:Sequence 级别 χ² 散度,$\mathbb{E}[(\prod_t \rho_t)^2] - 1$ +- `chi2_seq`:Sequence 级别 χ² 散度(基于几何平均),$\mathbb{E}[\rho_{\text{geo}}^2] - 1$,其中 $\rho_{\text{geo}} = \exp(\frac{1}{T}\sum_t \log \rho_t)$ -χ² 散度越大,表示 IS 权重方差越大,训练越不稳定。 +χ² 散度越大,表示 IS 权重方差越大,训练越不稳定。`chi2_seq` 使用几何平均而非乘积,使其与 `chi2_token` 在量级上可比较。 ### 4. Effective Sample Size (ESS) diff --git a/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md b/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md index 731d23f61d..93ed432901 100644 --- a/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md +++ b/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md @@ -167,9 +167,9 @@ $$ $$ - `chi2_token`: Token-level χ² divergence, $\mathbb{E}[\rho_t^2] - 1$ -- `chi2_seq`: Sequence-level χ² divergence, $\mathbb{E}[(\prod_t \rho_t)^2] - 1$ +- `chi2_seq`: Sequence-level χ² divergence (geometric mean based), $\mathbb{E}[\rho_{\text{geo}}^2] - 1$, where $\rho_{\text{geo}} = \exp(\frac{1}{T}\sum_t \log \rho_t)$ -Higher χ² divergence indicates larger IS weight variance and less stable training. +Higher χ² divergence indicates larger IS weight variance and less stable training. `chi2_seq` uses geometric mean instead of product, making it comparable in scale to `chi2_token`. ### 4. 
Effective Sample Size (ESS) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index c38b487ff9..73b67e880e 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -2305,11 +2305,14 @@ def masked_mean(x, mask, axis=None): chi2_token = masked_mean(rho_squared_token, completion_mask) - 1.0 metrics['chi2_token'] = self.accelerator.gather_for_metrics(chi2_token).nanmean().item() - # Sequence-level: E_seq[(Π ρ_t)²] - 1 = E_seq[exp(2 * Σ log ρ_t)] - 1 - log_ratio_sum = (log_ratio * completion_mask).sum(-1) # Σ log ρ_t per sequence - log_ratio_sum_safe = torch.clamp(log_ratio_sum, min=-SAFETY_BOUND, max=SAFETY_BOUND) - rho_squared_seq = torch.exp(2.0 * log_ratio_sum_safe) # (Π ρ_t)² - chi2_seq = rho_squared_seq.mean() - 1.0 + # Sequence-level (geometric mean): E_seq[ρ_geo²] - 1 + # where ρ_geo = exp(mean(log ρ_t)) is the geometric mean of token-level ratios + # This is more interpretable than the product-based chi2_seq, as it's normalized by sequence length + # and comparable to other per-token metrics like chi2_token + log_ratio_mean = masked_mean(log_ratio, completion_mask, axis=-1) # mean(log ρ_t) per sequence + log_ratio_mean_safe = torch.clamp(log_ratio_mean, min=-SAFETY_BOUND, max=SAFETY_BOUND) + rho_geo = torch.exp(log_ratio_mean_safe) # geometric mean of ρ_t + chi2_seq = (rho_geo.square().mean() - 1.0) metrics['chi2_seq'] = self.accelerator.gather_for_metrics(chi2_seq).nanmean().item() return metrics From e1b6565d4efa0b72378c5763305c87b7bd474b84 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 25 Nov 2025 18:06:42 +0800 Subject: [PATCH 18/21] cancel skip stop tokens for other engine --- swift/llm/infer/infer_engine/lmdeploy_engine.py | 2 +- swift/llm/infer/infer_engine/pt_engine.py | 2 +- swift/llm/infer/infer_engine/sglang_engine.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/swift/llm/infer/infer_engine/lmdeploy_engine.py b/swift/llm/infer/infer_engine/lmdeploy_engine.py index 4f7b50c0da..edbca4f5e6 100644 --- a/swift/llm/infer/infer_engine/lmdeploy_engine.py +++ b/swift/llm/infer/infer_engine/lmdeploy_engine.py @@ -260,7 +260,7 @@ async def _infer_full_async( toolcall = self._get_toolcall(response, template) finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token, output.status.name == 'FINISH') - token_ids = template.skip_stop_tokens(output.token_ids) if request_config.return_details else None + token_ids = output.token_ids if request_config.return_details else None choices = [ ChatCompletionResponseChoice( index=0, diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py index f173c4dbfc..09fbe5825f 100644 --- a/swift/llm/infer/infer_engine/pt_engine.py +++ b/swift/llm/infer/infer_engine/pt_engine.py @@ -426,7 +426,7 @@ def _infer_full(self, template: Template, inputs: Dict[str, Any], *, generation_ response = template.decode(generate_ids, template_inputs=template_inputs[i]) finish_reason = self._get_finish_reason(generation_config.max_new_tokens, len(generate_ids), True) toolcall = self._get_toolcall(response, template) - token_ids = template.skip_stop_tokens(generate_ids) if request_config.return_details else None + token_ids = generate_ids if request_config.return_details else None choices.append( ChatCompletionResponseChoice( index=j, diff --git a/swift/llm/infer/infer_engine/sglang_engine.py b/swift/llm/infer/infer_engine/sglang_engine.py index 37de0f845e..fe02ac8a3f 100644 
--- a/swift/llm/infer/infer_engine/sglang_engine.py +++ b/swift/llm/infer/infer_engine/sglang_engine.py @@ -149,7 +149,7 @@ def _create_chat_completion_response(self, output, inputs, template, return_deta if template.template_meta.response_prefix: response = template.template_meta.response_prefix + response toolcall = self._get_toolcall(response, template) - token_ids = template.skip_stop_tokens(output['output_ids']) if return_details else None + token_ids = output['output_ids'] if return_details else None choice = ChatCompletionResponseChoice( index=0, message=ChatMessage(role='assistant', content=response, tool_calls=toolcall), From 038094447c4b1b0d6aa989562e3e43d01c88e036 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 25 Nov 2025 18:21:37 +0800 Subject: [PATCH 19/21] unify rollout_kl --- .../training_inference_mismatch.md | 9 ++++--- .../training_inference_mismatch.md | 10 ++++---- swift/trainers/rlhf_trainer/grpo_trainer.py | 24 ++++++++++++------- 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md b/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md index 06aeeaa5ba..229d35cc74 100644 --- a/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md +++ b/docs/source/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md @@ -124,22 +124,21 @@ $$ ### 1. KL 散度(KL Divergence) -KL 散度衡量两个分布之间的差异,提供两种估计器: +KL 散度衡量训练策略偏离 rollout 策略的程度。两个指标都估计 $\text{KL}(\pi_\theta \| \pi_{\text{vLLM}})$,这与重要性采样权重 $\rho = \frac{\pi_\theta}{\pi_{\text{vLLM}}}$ 直接相关。 **直接估计器 `kl`**: $$ -\text{KL}(\pi_{\text{vLLM}} \| \pi_\theta) = \mathbb{E}\left[ \log \pi_{\text{vLLM}} - \log \pi_\theta \right] +\text{KL}(\pi_\theta \| \pi_{\text{vLLM}}) = \mathbb{E}_{\pi_{\text{vLLM}}}\left[ \log \frac{\pi_\theta}{\pi_{\text{vLLM}}} \right] $$ - **K3 估计器 `k3_kl`**: $$ -\text{KL}_{K3} = \mathbb{E}\left[ \rho - \log \rho - 1 \right], \quad \rho = \frac{\pi_\theta}{\pi_{\text{vLLM}}} +\text{KL}(\pi_\theta \| \pi_{\text{vLLM}}) \approx \mathbb{E}_{\pi_{\text{vLLM}}}\left[ \rho - \log \rho - 1 \right], \quad \rho = \frac{\pi_\theta}{\pi_{\text{vLLM}}} $$ -K3 估计器在 KL 值较小时数值更稳定。 +K3 估计器在 KL 值较小时数值更稳定,且始终非负。 ### 2. Perplexity (PPL) diff --git a/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md b/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md index 93ed432901..7f96359aca 100644 --- a/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md +++ b/docs/source_en/Instruction/GRPO/AdvancedResearch/training_inference_mismatch.md @@ -124,23 +124,21 @@ To monitor the degree of training-inference mismatch during training, we add the ### 1. KL Divergence -KL divergence measures the difference between two distributions. We provide two estimators: +KL divergence measures how much the training policy deviates from the rollout policy. Both metrics estimate $\text{KL}(\pi_\theta \| \pi_{\text{vLLM}})$, which is directly related to the importance sampling ratio $\rho = \frac{\pi_\theta}{\pi_{\text{vLLM}}}$. **Direct estimator `kl`**: $$ -\text{KL}(\pi_{\text{vLLM}} \| \pi_\theta) = \mathbb{E}\left[ \log \pi_{\text{vLLM}} - \log \pi_\theta \right] +\text{KL}(\pi_\theta \| \pi_{\text{vLLM}}) = \mathbb{E}_{\pi_{\text{vLLM}}}\left[ \log \frac{\pi_\theta}{\pi_{\text{vLLM}}} \right] $$ -A positive value indicates the rollout policy is more confident than the training policy. 
From 5c84f7ae06137320035f6265582d9cbaa6b71031 Mon Sep 17 00:00:00 2001
From: hjh0119
Date: Tue, 25 Nov 2025 19:34:35 +0800
Subject: [PATCH 20/21] server mode

---
 swift/llm/infer/rollout.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py
index 26d6fc9083..724275850e 100644
--- a/swift/llm/infer/rollout.py
+++ b/swift/llm/infer/rollout.py
@@ -338,6 +338,9 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs):
         logger.info('Currently, rollout only supports the vLLM backend. Set vLLM backend')
     kwargs.update(args.get_vllm_engine_kwargs())
     kwargs.update({'enable_lora': args.vllm_enable_lora})  # override
+    # Important: Use processed_logprobs so temperature scaling affects the logprobs
+    # This is required for correct importance sampling in rollout correction
+    kwargs['logprobs_mode'] = 'processed_logprobs'
     # used for RL external rollout backend
     engine_kwargs = kwargs.get('engine_kwargs', {})
     # for RL rollout model weight sync
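Why `processed_logprobs` matters here: tokens are sampled from the temperature-scaled distribution, so importance ratios must be computed against that distribution's logprobs, not the raw ones. The sketch below illustrates the gap under that assumption; it is not vLLM's internals, and the logits and temperature are invented:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # toy next-token logits
temperature = 0.7

raw_logprobs = F.log_softmax(logits, dim=-1)                      # before sampling params
processed_logprobs = F.log_softmax(logits / temperature, dim=-1)  # what the sampler draws from

token = int(torch.argmax(logits))  # suppose this token was drawn
print(raw_logprobs[token].item(), processed_logprobs[token].item())
# With temperature < 1 the processed logprob of the top token is higher; an
# importance ratio built from the raw value would be biased accordingly.
```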
From 088c868adc9012658bae5a226c54b0bb71ce8233 Mon Sep 17 00:00:00 2001
From: hjh0119
Date: Tue, 25 Nov 2025 23:51:35 +0800
Subject: [PATCH 21/21] cancel for multi turn

---
 swift/trainers/rlhf_trainer/grpo_trainer.py  |  2 +-
 swift/trainers/rlhf_trainer/rollout_mixin.py | 10 ++++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 8ebed3cf93..468c426cae 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -1026,7 +1026,7 @@ def _compute_loss_and_metrics(self, model, inputs):
 
         # Compute rollout diagnostic metrics and apply IS correction if enabled
         rollout_correction_metrics = {}
-        if inputs.get('vllm_per_token_logps') is not None:
+        if inputs.get('vllm_per_token_logps') is not None and not self.disable_rollout_importance_sampling:
             vllm_per_token_logps = inputs['vllm_per_token_logps']
 
             # Always compute diagnostic metrics (KL, PPL, etc.) for monitoring off-policy gap
diff --git a/swift/trainers/rlhf_trainer/rollout_mixin.py b/swift/trainers/rlhf_trainer/rollout_mixin.py
index 7e20929cb6..8dcf1caa99 100644
--- a/swift/trainers/rlhf_trainer/rollout_mixin.py
+++ b/swift/trainers/rlhf_trainer/rollout_mixin.py
@@ -113,6 +113,8 @@ def _prepare_rollout_params(self):
             return_details=True,
             logprobs=args.use_vllm)
 
+        self.disable_rollout_importance_sampling = False
+
     def _prepare_vllm(self):
         """Initialize vLLM engine (server or colocate mode)"""
         args = self.args
@@ -147,6 +149,10 @@ def _prepare_vllm(self):
         self.vllm_use_async_engine = broadcast_object_list(vllm_use_async_engine, from_process=0)[0]
         self.use_gym_env = broadcast_object_list(use_gym_env, from_process=0)[0]
         self.enable_server_multi_turn = broadcast_object_list(enable_multi_turn, from_process=0)[0]
+        if self.enable_server_multi_turn:
+            if getattr(args, 'rollout_importance_sampling_mode', None) is not None:
+                logger.warning('Rollout importance sampling is disabled for server multi-turn mode')
+            self.disable_rollout_importance_sampling = True
         self.rollout_enable_lora = broadcast_object_list(enable_lora, from_process=0)[0]
         if self.use_gym_env:
             self.reward_func_names = ['gym_reward']
@@ -940,6 +946,10 @@ def _prepare_scheduler(self):
             return
 
         if args.multi_turn_scheduler:
+            if getattr(args, 'rollout_importance_sampling_mode', None) is not None:
+                # TODO: support rollout importance sampling with the multi-turn scheduler
+                logger.warning('Rollout importance sampling mode is not supported for multi-turn scheduler')
+                self.disable_rollout_importance_sampling = True
             if isinstance(args.multi_turn_scheduler, str):
                 assert args.multi_turn_scheduler in multi_turns
                 multi_turn_scheduler = multi_turns[args.multi_turn_scheduler](max_turns=args.max_turns)
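The series does not state why rollout importance sampling and multi-turn rollouts are incompatible. One plausible reading (an assumption on the editor's part, not from the patch) is a bookkeeping problem: a multi-turn completion interleaves model-generated tokens, which carry rollout logprobs, with injected tool or environment tokens, which carry none, so a per-token ratio can no longer be aligned position by position with the training-side logps. A toy illustration with an invented turn structure:

```python
# Model tokens carry rollout logprobs; injected tool-response tokens do not.
turn1_logprobs = [-0.2, -1.1, -0.7]   # model turn 1
tool_tokens = [101, 102, 103, 104]    # tool response inserted into the completion
turn2_logprobs = [-0.5, -0.9]         # model turn 2

# Flattened completion: positions covered by the tool response have no logprob
rollout_logprobs = turn1_logprobs + [None] * len(tool_tokens) + turn2_logprobs

missing = sum(lp is None for lp in rollout_logprobs)
print(f'{missing}/{len(rollout_logprobs)} completion positions lack rollout logprobs')
```

Until the TODO above is resolved, the trainer simply sets `disable_rollout_importance_sampling = True` on these paths while leaving generation and the diagnostic-free loss unchanged.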