
Commit c5ca64d

[grpo] support CHORD algorithm (#5680)
* chord wip
* wip
* wip
* update doc
* remove chord to utils
* fix
* fix
* update script
* doc
* remove unused import
* fix link
* readme
* readme en
* compute sft only for train
* fix mu=0
1 parent c56cac3 commit c5ca64d


12 files changed: +399 -22 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ You can contact us and communicate with us by adding our group:
 
 ## 🎉 News
+- 🎁 2025.09.07: Added support for the CHORD training algorithm. See the [documentation](./docs/source_en/Instruction/GRPO/AdvancedResearch/CHORD.md).
 - 🎁 2025.09.06: Ulysses can now be used with ring-attention, allowing sequences to be sharded into any number of chunks (no longer limited by the number of heads). The argument remains `--sequence_parallel_size N`.
 - 🎁 2025.09.02: Megatron-SWIFT now supports multimodal model training. Documentation can be found [here](./docs/source_en/Megatron-SWIFT/Multimodal-Model.md).
 - 🎁 2025.08.12: Support [Dynamic Fine-Tuning](https://arxiv.org/abs/2508.05629)(DFT) in SFT training, use parameter `--enable_dft_loss true`. Training scripts can be found [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/dft.sh).

README_CN.md

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@
 - **Model Quantization**: Supports quantized export with AWQ, GPTQ, FP8, and BNB; exported models support accelerated inference with vLLM/SGLang/LmDeploy and can continue to be trained.
 
 ## 🎉 News
+- 🎁 2025.09.07: Added support for the CHORD training algorithm. See the [documentation](docs/source/Instruction/GRPO/AdvancedResearch/CHORD.md).
 - 🎁 2025.09.06: Ulysses can now be combined with ring-attention, allowing the input sequence to be split into any number of chunks (no longer limited by num_heads). The argument is still `--sequence_parallel_size N`.
 - 🎁 2025.09.02: Megatron-SWIFT supports multimodal model training. See the documentation [here](./docs/source/Megatron-SWIFT/多模态模型.md).
 - 🎁 2025.08.12: Support for [Dynamic Fine-Tuning](https://arxiv.org/abs/2508.05629) (DFT) in SFT training via the parameter `--enable_dft_loss true`. Training scripts can be found [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/dft.sh).
docs/source/Instruction/GRPO/AdvancedResearch/CHORD.md

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# On-Policy RL Meets Off-Policy Experts: Harmonizing SFT and RL via Dynamic Weighting (CHORD)

**Version requirement**: ms-swift>=3.9

This document introduces the **CHORD** algorithm proposed in the paper [On-Policy RL Meets Off-Policy Experts: Harmonizing SFT and RL via Dynamic Weighting](https://arxiv.org/abs/2508.11408). The core idea of CHORD is to **dynamically blend off-policy expert data (SFT)** into **on-policy reinforcement learning** (e.g., GRPO/PPO), using a dual control mechanism of a **global weight μ** plus a **token-level weight φ** to balance imitation and exploration.

## Algorithm Overview

CHORD mixes the two training signals by introducing an **SFT loss** into the **GRPO loss**. The overall objective is:

$$
\mathcal{L}_{\text{CHORD}} = (1 - \mu) \cdot \mathcal{L}_{\text{GRPO}} + \mu \cdot \mathcal{L}_{\text{SFT}}
$$

where:
- $\mathcal{L}_{\text{GRPO}}$: the reinforcement learning loss computed on on-policy samples (similar to PPO).
- $\mathcal{L}_{\text{SFT}}$: the supervised fine-tuning loss.
- $\mu \in [0, 1]$: a global balancing coefficient that controls how much the SFT signal contributes to the total gradient.
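The snippet below is a minimal sketch of this objective, assuming the GRPO loss has already been computed and taking the SFT loss as a standard token-level cross-entropy on the expert mini-batch. The tensor shapes and names are illustrative only and are not the actual ms-swift implementation (that lives in `GRPOTrainer._compute_chord_loss`).

```python
import torch
import torch.nn.functional as F


def chord_loss(grpo_loss: torch.Tensor,
               expert_logits: torch.Tensor,  # [batch, seq, vocab]: policy logits on the expert batch
               expert_labels: torch.Tensor,  # [batch, seq]: expert token ids, -100 = ignore
               mu: float) -> torch.Tensor:
    # L_SFT: token-level cross-entropy on the off-policy expert mini-batch.
    sft_loss = F.cross_entropy(
        expert_logits.flatten(0, 1), expert_labels.flatten(), ignore_index=-100)
    # L_CHORD = (1 - mu) * L_GRPO + mu * L_SFT; mu = 0 recovers plain GRPO, mu = 1 is pure SFT.
    return (1 - mu) * grpo_loss + mu * sft_loss
```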
### Configuration (data and batch size)

CHORD training is implemented on top of GRPO training.

CHORD requires an additional SFT dataset and batch size to be specified at training time:
- `chord_sft_dataset`: the SFT dataset that provides the expert data.
- `chord_sft_per_device_train_batch_size`: the per-device SFT mini-batch size.

---

## Two CHORD Variants

The paper proposes two algorithm variants: **CHORD-µ** and **CHORD-ϕ**.

### CHORD-µ
CHORD-µ gradually **decays μ** during training, transitioning from imitating the expert to autonomous exploration.

**Parameters:**
- `chord_mu_peak`: the peak value of μ.
- `chord_mu_valley`: the final value μ decays to.
- `chord_mu_warmup_steps`: number of training steps over which μ ramps up to its peak.
- `chord_mu_decay_steps`: number of training steps over which μ decays from the peak to the valley.

### CHORD-ϕ (token-level weighting)
Instead of relying on a dynamic decay of μ, **CHORD-ϕ** fixes μ at a small constant (recommended **0.05–0.2**) and uses a **token-wise weighting function φ** to dynamically control the gradient contribution of each expert token.

**Definition of φ:**
$$
\phi(y_t^\star, \pi_\theta) = p_t \cdot (1 - p_t)
$$

where:
- $p_t = \pi_\theta(y_t^\star \mid x, y_{<t}^\star)$: the probability the current model assigns to the expert token.
- When $p_t \approx 0.5$ (the model is uncertain), φ is maximal → tokens the model is uncertain about are emphasized.
- When $p_t \approx 0$ or $p_t \approx 1$, φ → 0 → tokens the model has already mastered, or cannot yet produce at all, are not over-trained.
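As a quick check of the two bullets above, $\phi(p_t) = p_t(1 - p_t)$ is a downward parabola, so the weight peaks exactly where the model is most uncertain:

$$
\frac{\partial \phi}{\partial p_t} = 1 - 2p_t = 0 \;\Rightarrow\; p_t = \tfrac{1}{2}, \qquad \phi_{\max} = \tfrac{1}{4}
$$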
**Parameter to enable φ weighting:**
- `chord_enable_phi_function: bool = False`
  - Set to `True` to enable the token-wise weight φ.

Note: to use a constant μ, set `chord_mu_peak` and `chord_mu_valley` to the same value.

<details>
<summary>Code implementation of the μ schedule and loss computation</summary>
Please refer to the `_compute_chord_loss` method of `GRPOTrainer`.
</details>

For training, refer to this [script](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/internal/chord.sh).

docs/source/Instruction/GRPO/AdvancedResearch/index.rst

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ Advanced Research
    DAPO.md
    deepeyes.md
    GSPO.md
+   CHORD.md
docs/source_en/Instruction/GRPO/AdvancedResearch/CHORD.md

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# On-Policy RL Meets Off-Policy Experts: Harmonizing SFT and RL via Dynamic Weighting (CHORD)

**Version Requirement**: ms-swift>=3.9

This document describes the CHORD algorithm proposed in the paper [On-Policy RL Meets Off-Policy Experts: Harmonizing SFT and RL via Dynamic Weighting](https://arxiv.org/abs/2508.11408). The core idea of CHORD is to dynamically integrate off-policy expert data (SFT) into on-policy reinforcement learning (e.g., GRPO/PPO) by a dual control mechanism (a global weight μ plus a token-level weight φ), thereby balancing imitation and exploration.

## Algorithm Overview

CHORD mixes the two training signals by introducing the SFT loss into the GRPO loss. The overall objective is:

$$
\mathcal{L}_{\text{CHORD}} = (1 - \mu) \cdot \mathcal{L}_{\text{GRPO}} + \mu \cdot \mathcal{L}_{\text{SFT}}
$$

where:
- $\mathcal{L}_{\text{GRPO}}$: on-policy RL loss based on on-policy samples (similar to PPO).
- $\mathcal{L}_{\text{SFT}}$: supervised fine-tuning (SFT) loss.
- $\mu \in [0, 1]$: global balancing coefficient that controls the contribution of the SFT signal to the overall gradient.

### Configuration (data and batch sizes)

CHORD training is implemented on top of GRPO training.

CHORD requires specifying an additional SFT dataset and batch size at training time:
- `chord_sft_dataset`: the SFT dataset that provides expert data.
- `chord_sft_per_device_train_batch_size`: the per-device SFT mini-batch size.

---

## Two CHORD Variants

The paper proposes two variants: CHORD-μ and CHORD-φ.

### CHORD-μ
CHORD-μ gradually decays μ during training to transition from imitating experts toward autonomous exploration; one possible schedule is sketched after the parameter list below.

Parameters:
- `chord_mu_peak`: the peak value of μ.
- `chord_mu_valley`: the final value μ decays to.
- `chord_mu_warmup_steps`: number of training steps to ramp μ up to the peak.
- `chord_mu_decay_steps`: number of training steps to decay μ from peak to valley.
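To make the four parameters concrete, the following is an illustrative piecewise-linear schedule; the exact schedule shape used by ms-swift is defined in `GRPOTrainer._compute_chord_loss` and may differ.

```python
def chord_mu(step: int, mu_peak: float, mu_valley: float,
             warmup_steps: int, decay_steps: int) -> float:
    """Illustrative μ schedule: linear warmup to mu_peak, then linear decay to mu_valley."""
    if warmup_steps > 0 and step < warmup_steps:
        return mu_peak * step / warmup_steps            # ramp up toward the peak
    step_after_warmup = step - warmup_steps
    if decay_steps > 0 and step_after_warmup < decay_steps:
        frac = step_after_warmup / decay_steps
        return mu_peak + frac * (mu_valley - mu_peak)   # decay toward the valley
    return mu_valley                                    # stay at the valley afterwards
```

With `chord_mu_warmup_steps=0` and `chord_mu_peak == chord_mu_valley`, μ stays constant, which is the setting used for the CHORD-φ variant below.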
### CHORD-φ (Token-level weighting)
CHORD-φ does not rely on μ scheduling; instead it keeps μ fixed at a small constant (recommended 0.05–0.2) and uses a token-wise weighting function φ to dynamically control each expert token's gradient contribution.

Definition of φ:
$$
\phi(y_t^\star, \pi_\theta) = p_t \cdot (1 - p_t)
$$

where:
- $p_t = \pi_\theta(y_t^\star \mid x, y_{<t}^\star)$ is the model's current predicted probability of the expert token.
- When $p_t \approx 0.5$ (the model is uncertain), φ is maximal → emphasize tokens the model is uncertain about.
- When $p_t \approx 0$ or $p_t \approx 1$, φ → 0 → avoid overemphasizing tokens the model is already confident about or currently assigns near-zero probability to.

Parameter to enable φ weighting (a sketch of how φ enters the SFT loss follows this list):
- `chord_enable_phi_function: bool = False`
  - Set to `True` to enable the token-wise weight φ.
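The sketch below shows one way such a token-wise weight could enter the SFT term. The variable names and the stop-gradient on φ are assumptions for illustration; the actual behavior when `chord_enable_phi_function=true` is implemented in `GRPOTrainer._compute_chord_loss`.

```python
import torch
import torch.nn.functional as F


def phi_weighted_sft_loss(expert_logits: torch.Tensor,  # [batch, seq, vocab]
                          expert_labels: torch.Tensor   # [batch, seq], -100 marks ignored positions
                          ) -> torch.Tensor:
    mask = expert_labels != -100
    labels = expert_labels.clamp(min=0)  # make padded ids valid for gather
    log_probs = F.log_softmax(expert_logits, dim=-1)
    token_logp = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # log p_t
    p = token_logp.exp()                                                 # p_t
    phi = (p * (1 - p)).detach()  # phi = p_t * (1 - p_t), treated as a constant weight here
    # phi-weighted negative log-likelihood, averaged over the non-masked expert tokens.
    return -(phi * token_logp * mask).sum() / mask.sum().clamp(min=1)
```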
Note: If using a constant μ, set `chord_mu_peak` and `chord_mu_valley` to the same value.

<details>
<summary>Code implementation of μ scheduling and loss computation</summary>
See the `GRPOTrainer` method `_compute_chord_loss`.
</details>

Training reference script: https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/internal/chord.sh

docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ Advanced Research
    DAPO.md
    deepeyes.md
    GSPO.md
+   CHORD.md
examples/train/grpo/internal/chord.sh

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
# 8*80G GPU
# CHORD https://arxiv.org/abs/2508.11408
# GRPO total batch = 32(prompts)*8(num_generations) = 256 = 8(gpus) * 4(per_device_train_batch_size) * 8(gradient_accumulation_steps)
# SFT total batch = 64 = 8(gpus) * 1(chord_sft_per_device_train_batch_size) * 8(gradient_accumulation_steps)

# NOTE: We use the same dataset for GRPO and SFT, which may cause overlap (i.e., the same examples to be selected).
# You can pre-download the dataset and manually split it to avoid this.

export CHORD_SYSTEM_PROMPT="You are a helpful assistant that solves MATH problems.
You should first think about the reasoning process in mind and then provide the user with the answer.
You should present your reasoning process using the format: <think>\n...your reasoning process here... </think>\n"

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-7B-Instruct \
    --dataset AI-MO/NuminaMath-TIR \
    --torch_dtype bfloat16 \
    --beta 0.0 \
    --steps_per_generation 4 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --chord_sft_per_device_train_batch_size 1 \
    --chord_sft_dataset AI-MO/NuminaMath-TIR \
    --chord_enable_phi_function false \
    --chord_mu_warmup_steps 0 \
    --chord_mu_decay_steps 200 \
    --chord_mu_peak 0.9 \
    --chord_mu_valley 0.05 \
    --num_generations 8 \
    --train_type full \
    --reward_funcs accuracy \
    --system "$CHORD_SYSTEM_PROMPT" \
    --use_vllm true \
    --vllm_mode colocate \
    --vllm_gpu_memory_utilization 0.4 \
    --vllm_max_model_len 8192 \
    --max_completion_length 4096 \
    --overlong_filter true \
    --offload_optimizer true \
    --offload_model true \
    --sleep_level 1 \
    --save_steps 1000 \
    --learning_rate 1e-6 \
    --save_total_limit 2 \
    --logging_steps 1 \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --deepspeed zero3 \
    --log_completions true \
    --report_to tensorboard swanlab
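The NOTE at the top of the script suggests pre-splitting the dataset so that the GRPO prompts and the CHORD SFT expert data do not overlap. A minimal sketch using the Hugging Face `datasets` API (an assumption; any equivalent tooling works) could look like this, after which the two JSONL paths would be passed to `--dataset` and `--chord_sft_dataset`:

```python
from datasets import load_dataset

# Download once, then split deterministically into disjoint GRPO / SFT subsets.
ds = load_dataset('AI-MO/NuminaMath-TIR', split='train')
split = ds.train_test_split(test_size=0.5, seed=42)
split['train'].to_json('numina_grpo.jsonl')  # prompts for on-policy GRPO rollouts
split['test'].to_json('numina_sft.jsonl')    # expert data for the CHORD SFT term
```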

swift/llm/train/rlhf.py

Lines changed: 22 additions & 1 deletion
@@ -164,7 +164,7 @@ def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_t
     def _prepare_template(self) -> None:
         args = self.args
         super()._prepare_template()
-        model_mapping = {'kto': 'kto', 'gkd': 'gkd', 'ppo': 'pt', 'grpo': 'pt'}
+        model_mapping = {'kto': 'kto', 'gkd': 'gkd', 'ppo': 'pt', 'grpo': 'train'}
         self.template.set_mode(model_mapping.get(args.rlhf_type, 'rlhf'))
 
         if args.rlhf_type == 'ppo':
@@ -177,6 +177,25 @@ def _get_dataset(self):
             train_dataset, val_dataset = prepare_kto_dataset(args, train_dataset, val_dataset)
         return train_dataset, val_dataset
 
+    def _prepare_chord_sft_dataset(self):
+        from ..dataset import load_dataset
+        from swift.llm.dataset.loader import DatasetLoader
+
+        # prepare the expert SFT dataset for CHORD
+        args = self.args
+        assert hasattr(args, 'chord_sft_dataset') and args.chord_sft_dataset
+        dataset_kwargs = args.get_dataset_kwargs()
+        chord_sft_datasets = []
+        # TODO: validation
+        chord_sft_dataset, _ = load_dataset(
+            args.chord_sft_dataset, split_dataset_ratio=0, shuffle=args.dataset_shuffle, **dataset_kwargs)
+        chord_sft_dataset, _ = self._encode_dataset(chord_sft_dataset, None, pre_process=True)
+        chord_sft_datasets.append(chord_sft_dataset)
+        chord_sft_dataset = DatasetLoader._concat_datasets(chord_sft_datasets)
+        datasets = [chord_sft_dataset, None]
+        datasets = self._post_process_datasets(datasets)
+        return datasets
+
     def _get_trainer_kwargs(self):
         trainer_kwargs = {}
         for key in ['ref', 'reward', 'value', 'teacher']:
@@ -189,6 +208,8 @@ def _get_trainer_kwargs(self):
         if self.args.rlhf_type == 'grpo':
             trainer_kwargs['reward_funcs'] = self.args.reward_funcs
             trainer_kwargs['vllm_client'] = self.args.vllm_client
+            if self.args.chord_sft_dataset:
+                trainer_kwargs['chord_sft_dataset'], _ = self._prepare_chord_sft_dataset()
         return trainer_kwargs
swift/llm/train/sft.py

Lines changed: 12 additions & 6 deletions
@@ -113,22 +113,29 @@ def _get_cached_dataset(self):
 
     def _prepare_dataset(self):
         args = self.args
+        is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo'
         if args.cached_dataset:
             train_datasets, val_datasets = self._get_cached_dataset()
         else:
             train_datasets, val_datasets = [], []
             if args.dataset:
                 train_dataset, val_dataset = self._get_dataset()
-                train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset)
+                train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset, pre_process=not is_grpo)
                 train_datasets.append(train_dataset)
                 val_datasets.append(val_dataset)
         train_dataset = DatasetLoader._concat_datasets(train_datasets)
         val_dataset = DatasetLoader._concat_datasets(val_datasets)
-        is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo'
-        predict_with_generate = getattr(args, 'predict_with_generate', False)
         datasets = [train_dataset, val_dataset]
         if is_grpo:
             return datasets
+        datasets = self._post_process_datasets(datasets)
+
+        return datasets
+
+    def _post_process_datasets(self, datasets: List) -> List:
+        args = self.args
+        predict_with_generate = getattr(args, 'predict_with_generate', False)
+
         template = self.template
         for i, dataset in enumerate(datasets):
             if dataset is None:
294301
if val_dataset is not None and not predict_with_generate:
295302
self.train_msg['val_dataset'] = self._stat_dataset(val_dataset)
296303

297-
def _encode_dataset(self, train_dataset, val_dataset):
304+
def _encode_dataset(self, train_dataset, val_dataset, pre_process=True):
298305
template = self.template
299306
args = self.args
300307
self._save_val_dataset(val_dataset)
301308

302-
is_grpo = hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo'
303309
predict_with_generate = getattr(args, 'predict_with_generate', False)
304310
datasets = [train_dataset, val_dataset]
305-
if is_grpo:
311+
if not pre_process:
306312
return datasets
307313

308314
origin_template_model = template.model

swift/trainers/arguments.py

Lines changed: 9 additions & 0 deletions
@@ -137,6 +137,15 @@ def __post_init__(self):
 class RLHFArgumentsMixin:
     # gkd
     sft_alpha: float = 0
+    # chord
+    chord_sft_dataset: Optional[str] = None
+    chord_sft_per_device_train_batch_size: Optional[int] = None
+
+    chord_enable_phi_function: bool = False
+    chord_mu_warmup_steps: Optional[int] = None
+    chord_mu_decay_steps: Optional[int] = None
+    chord_mu_peak: Optional[float] = None
+    chord_mu_valley: Optional[float] = None
 
 
 @dataclass
