65 changes: 64 additions & 1 deletion docs/source/Instruction/GRPO.md
@@ -287,6 +287,55 @@ swift rlhf \
1. In GRPOTrainer, reward_model instances are appended to reward_funcs in order, so the order of reward_weights corresponds to [reward_funcs, reward_model].
2. reward_model_plugin defaults to default, i.e. the ORM processing logic is used.

## Multi-task Training
We can add a column to the dataset that identifies the task type, and branch on that task type inside the reward function / reward model plugin, which enables multi-task training. Suppose the dataset contains math and coding tasks, for example:

```
{"query": "Solve the equation x + 2 = 5", "solution": "3", "task": "math"},
{"query": "Write a function to calculate the Fibonacci sequence", "solution": "xxx", "task": "code"},
{"query": "What is the integral of x^2?", "solution": "xxx", "task": "math"},
{"query": "Implement a sorting algorithm in Python", "solution": "xxx", "task": "code"},
```

Below is an example of reward functions for the different tasks:

```python
import random

from swift.plugin import ORM, orms


# Math-specific reward function
class MathRandomReward(ORM):

    def __call__(self, completions, task, **kwargs):
        rewards = []
        for completion, t in zip(completions, task):
            if t == "math":
                # Place the math accuracy logic here; a random value is a placeholder
                rewards.append(random.random())
            else:
                # Return None for non-math tasks
                rewards.append(None)
        return rewards


# Coding-specific reward function
class CodeRandomReward(ORM):

    def __call__(self, completions, task, **kwargs):
        rewards = []
        for completion, t in zip(completions, task):
            if t == "code":
                # Place the code correctness logic here; a random value is a placeholder
                rewards.append(random.random())
            else:
                # Return None for non-coding tasks
                rewards.append(None)
        return rewards


orms['math_reward'] = MathRandomReward
orms['code_reward'] = CodeRandomReward
```
For data that does not belong to the current task, the reward function returns None, so the reward statistics are computed only over the data belonging to that task.
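
As a quick sanity check, the two reward functions can be called directly on a small mixed batch outside the trainer. This is a minimal sketch that assumes the classes above are defined in the current session and that ORM subclasses can be instantiated without arguments; the printed values are random placeholders:

```python
# Minimal sketch: call the reward functions above on a mixed math/code batch.
completions = ["x = 3", "def fib(n): ..."]
tasks = ["math", "code"]

math_rewards = MathRandomReward()(completions, tasks)
code_rewards = CodeRandomReward()(completions, tasks)

print(math_rewards)  # e.g. [0.42..., None] -- None for the code sample
print(code_rewards)  # e.g. [None, 0.87...] -- None for the math sample
```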

## DAPO
[Decoupled Clip and Dynamic sAmpling Policy Optimization (DAPO)](https://arxiv.org/abs/2503.14476) adds several tricks on top of GRPO, namely:
@@ -363,7 +412,21 @@ num_generations = 64

**5. Why is clip_ratio always 1?**

With num_iterations = 1 and async_generate = False, training is on-policy RL, and old_policy equals policy.
The core purpose of the clip mechanism is to limit the magnitude of each policy update and prevent a single overly large update from collapsing policy performance (i.e., a sharp drop in performance right after the update).
The clip operation is defined as follows:
$$
L_{\text{CLIP}}(\theta) = \mathbb{E}_{t} \left[ \min\left(r_{t}(\theta) \hat{A}_{t}, \text{clip}(r_{t}(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_{t} \right) \right]
$$

where $r_{t}(\theta) = \frac{\pi_{\theta}(a_{t} \mid s_{t})}{\pi_{\text{old}}(a_{t} \mid s_{t})}$ is the importance sampling ratio, which measures the difference between the new and old policies; $\hat{A}_{t}$ is the advantage function, representing the relative benefit of an action; and $\epsilon$ limits how far $r_{t}(\theta)$ may deviate from 1.

During on-policy training, every update uses data generated by the latest policy, so the new and old policies are identical, i.e., $\pi_{\theta} = \pi_{\text{old}}$.

The importance sampling ratio is therefore always 1, and the clip operation has no effect.
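
A minimal PyTorch sketch (not the trainer's actual code) illustrates this: when the per-token log-probabilities of the current and old policies are identical, the ratio is exactly 1 and clamping changes nothing.

```python
import torch

# On-policy case: the data was generated by the current policy, so the
# old and new per-token log-probabilities coincide.
logp_new = torch.tensor([-1.2, -0.7, -2.3])
logp_old = logp_new.clone()
advantages = torch.tensor([0.5, -0.3, 1.1])
eps = 0.2

ratio = torch.exp(logp_new - logp_old)          # identically 1
clipped = torch.clamp(ratio, 1 - eps, 1 + eps)  # clamping is a no-op
loss = -torch.min(ratio * advantages, clipped * advantages).mean()

print(ratio)                        # tensor([1., 1., 1.])
print(torch.equal(ratio, clipped))  # True
```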

The algorithm becomes off-policy (near-on-policy) under the following settings:
1. num_iterations > 1
2. steps_per_generation > gradient_accumulation_steps

See this [issue](https://github.com/huggingface/open-r1/issues/239#issuecomment-2646297851) for reference.

70 changes: 69 additions & 1 deletion docs/source_en/Instruction/GRPO.md
@@ -301,6 +301,60 @@ Notes:
1. In the GRPOTrainer, reward_model instances are appended sequentially to reward_funcs. Therefore, the order of reward_weights corresponds to [reward_funcs, reward_model].
2. The default value for reward_model_plugin is default, which uses the ORM processing logic.

## Multi-task training

We can add a column to the dataset to identify the task type and make judgments based on the task type in the reward function/reward model plugin, thereby enabling multi-task training. Suppose the dataset contains math and programming tasks, such as:
```
{"query": "Solve the equation x + 2 = 5", "solution": "3", "task": "math"},
{"query": "Write a function to calculate the Fibonacci sequence", "solution": "xxx", "task": "code"},
{"query": "What is the integral of x^2?", "solution": "xxx", "task": "math"},
{"query": "Implement a sorting algorithm in Python", "solution": "xxx", "task": "code"},
```

Below are examples of reward functions for different tasks:

```python
import random

from swift.plugin import ORM, orms


# Math-specific reward function
class MathRandomReward(ORM):

    def __call__(self, completions, task, **kwargs):
        rewards = []
        for completion, t in zip(completions, task):
            if t == "math":
                # Place the math accuracy logic here; a random value is a placeholder
                rewards.append(random.random())
            else:
                # Return None for non-math tasks
                rewards.append(None)
        return rewards


# Coding-specific reward function
class CodeRandomReward(ORM):

    def __call__(self, completions, task, **kwargs):
        rewards = []
        for completion, t in zip(completions, task):
            if t == "code":
                # Place the code correctness logic here; a random value is a placeholder
                rewards.append(random.random())
            else:
                # Return None for non-coding tasks
                rewards.append(None)
        return rewards


orms['math_reward'] = MathRandomReward
orms['code_reward'] = CodeRandomReward
```

For data that does not belong to the current task, the reward function returns None, so the reward statistics are computed only over the data belonging to that task.
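
Inside the trainer (see the grpo_trainer.py changes below), None rewards are converted to NaN and aggregated with NaN-aware reductions, so out-of-task samples are simply skipped when weighting and averaging rewards. The following is a minimal sketch of that aggregation idea, not the trainer's exact code:

```python
import torch

# Rewards from two reward functions for a mixed batch of four samples
# (columns: math_reward, code_reward). None means "not my task" and becomes NaN.
raw = [[0.7, None], [None, 0.2], [0.9, None], [None, 0.5]]
rewards_per_func = torch.tensor(
    [[r if r is not None else float('nan') for r in row] for row in raw])

reward_weights = torch.tensor([1.0, 1.0])

# NaN-aware reductions skip the out-of-task entries instead of propagating NaN.
total_rewards = (rewards_per_func * reward_weights).nansum(dim=1)
per_func_mean = torch.nanmean(rewards_per_func, dim=0)

print(total_rewards)  # tensor([0.7000, 0.2000, 0.9000, 0.5000])
print(per_func_mean)  # tensor([0.8000, 0.3500])
```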


## DAPO
Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO) introduces several tricks based on GRPO, which are:
@@ -380,7 +434,21 @@ See reference: [issue](https://github.com/modelscope/ms-swift/issues/3912)

**5. Why is clip_ratio always 1?**

When num_iterations = 1 and async_generate = False, it's on-policy RL, and old_policy is equal to policy.
The core purpose of the Clip mechanism is to limit the magnitude of policy updates, preventing a single update from being too large and causing a collapse in policy performance (i.e., a sudden drop in performance after the policy is updated). The specific formula for the Clip operation is as follows:

$$
L_{\text{CLIP}}(\theta) = \mathbb{E}_{t} \left[ \min\left(r_{t}(\theta) \hat{A}_{t}, \text{clip}(r_{t}(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_{t} \right) \right]
$$

Where $r_{t}(\theta) = \frac{\pi_{\theta}(a_{t} \mid s_{t})}{\pi_{\text{old}}(a_{t} \mid s_{t})}$ is the importance sampling ratio, measuring the difference between the new and old policies; $\hat{A}_{t}$ is the advantage function, representing the relative reward of an action; and $\epsilon$ limits the deviation range of $r_{t}(\theta)$.

During on-policy training, every update uses data generated by the latest policy, so the new and old policies are identical, i.e., $\pi_{\theta} = \pi_{\text{old}}$.

Therefore, the importance sampling ratio is always equal to 1, and the clip operation does not take effect.

Under the following parameter settings, the algorithm becomes off-policy (near-on-policy); see the sketch after this list.

1. num_iterations > 1
2. steps_per_generation > gradient_accumulation_steps
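
The toy sketch below (plain PyTorch, not the trainer's code) shows what changes under these settings: once the same generation batch is reused for a second optimizer step, $\pi_{\theta}$ no longer equals $\pi_{\text{old}}$, so the ratio drifts away from 1 and clipping can take effect.

```python
import torch

# Toy policy over 3 actions; reuse one generation batch for 2 optimizer steps
# (analogous to num_iterations = 2).
logits = torch.nn.Parameter(torch.zeros(3))
actions = torch.tensor([0, 2, 1])
advantages = torch.tensor([1.0, -0.5, 0.8])
eps = 0.2

logp_old = torch.log_softmax(logits, dim=-1).detach()[actions]
opt = torch.optim.SGD([logits], lr=0.5)

for step in range(2):
    logp_new = torch.log_softmax(logits, dim=-1)[actions]
    ratio = torch.exp(logp_new - logp_old)
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
    loss = -torch.min(ratio * advantages, clipped * advantages).mean()
    print(step, ratio.detach())  # step 0: all ones; step 1: ratios != 1
    opt.zero_grad()
    loss.backward()
    opt.step()
```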

See reference: [issue](https://github.com/huggingface/open-r1/issues/239#issuecomment-2646297851)

54 changes: 34 additions & 20 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -28,10 +28,10 @@
from transformers import PreTrainedModel, TrainerCallback
from transformers.trainer import Trainer
from trl import GRPOTrainer as HFGRPOTrainer
from trl.extras.profiling import profiling_decorator
from trl.extras.profiling import profiling_context, profiling_decorator
from trl.models import prepare_deepspeed
from trl.trainer.callbacks import SyncRefModelCallback
from trl.trainer.grpo_trainer import nanmax, nanmin
from trl.trainer.grpo_trainer import nanmax, nanmin, nanstd

from swift.llm import InferRequest, MultiModelKeys, RequestConfig, RowPreprocessor, get_model_arch, to_device
from swift.llm.model.utils import get_llm_model
@@ -873,19 +873,30 @@ def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Te
completions = [example['messages'][-1]['content'] for example in inputs]
rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device)

for i, (reward_func, reward_model_plugin) in enumerate(zip(self.reward_funcs, self.reward_model_plugins)):
# reward model
if isinstance(reward_func, nn.Module):
rewards_per_func[:, i] = reward_model_plugin(inputs=inputs)
# reward function
else:
# Repeat all input columns (but "messages" and "completion") to match the number of generations
reward_kwargs = RowPreprocessor.rows_to_batched(inputs)
output_reward_func = reward_func(completions, **reward_kwargs)
for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate(
zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)):
with profiling_context(self, reward_func_name):
# reward model
if isinstance(reward_func, nn.Module):
output_reward_func = reward_model_plugin(inputs=inputs)
# reward function
else:
# Repeat all input columns (but "messages" and "completion") to match the number of generations
reward_kwargs = RowPreprocessor.rows_to_batched(inputs)
output_reward_func = reward_func(completions, **reward_kwargs)
output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

# If all reward functions return None for a given row, issue a detailed warning
if torch.isnan(rewards_per_func).all(dim=1).any():
nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
row_reward_kwargs['completion'] = completions[nan_row_idx]
logger.warning(f'All reward functions returned None for the following kwargs: {row_reward_kwargs}. '
'Please ensure that at least one reward function returns a valid reward.')

total_rewards_per_func = gather(rewards_per_func)
total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

return total_rewards_per_func, total_rewards, completions

@@ -1027,10 +1038,11 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func)

self._metrics[mode]['completions/clipped_ratio'].append(clipped_completions_ratio)

# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
for i, reward_func_name in enumerate(self.reward_func_names):
mean_rewards = rewards_per_func[:, i].mean().item()
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
self._metrics[mode][f'rewards/{reward_func_name}/mean'].append(mean_rewards)
std_rewards = rewards_per_func[:, i].std().item()
std_rewards = nanstd(rewards_per_func[:, i]).item()
self._metrics[mode][f'rewards/{reward_func_name}/std'].append(std_rewards)

# Log overall reward stats
@@ -1071,7 +1083,8 @@ def _compute_loss(self, model, inputs):
# apply the completion_mask to exclude loss and metrics for overlong completions
if self.args.overlong_filter and any(truncated_mask):
if all(truncated_mask):
logger.info('All completions are overlong, loss and KL will be zero')
logger.info('All completions are overlong and truncated, '
'resulting in NaN values for some metrics (e.g., KL)')
truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask).to(completion_mask.device)
completion_mask = completion_mask * (~truncated_mask)

@@ -1341,11 +1354,12 @@ def _engine_infer(
*,
use_tqdm: Optional[bool] = False,
):
if self.vllm_mode == 'server':
self._process_infer_requests_images(infer_requests)
return self.vllm_client.infer(infer_requests, asdict(request_config), use_tqdm=use_tqdm)
else:
return self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm)
with profiling_context(self, 'generate'):
if self.vllm_mode == 'server':
self._process_infer_requests_images(infer_requests)
return self.vllm_client.infer(infer_requests, asdict(request_config), use_tqdm=use_tqdm)
else:
return self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm)

def _process_infer_requests_images(self, infer_requests: List[InferRequest]):
# Process image format into a format that session.post can accept