36 changes: 23 additions & 13 deletions docs/source/Instruction/GKD.md
@@ -33,22 +33,22 @@
#### Forward KL

$$
-\text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_v P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)}
+\text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_v P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)}
$$

-**Characteristics**: Mode-seeking
-- Expectation is computed under the student distribution
-- The student model tends to concentrate on the peak regions (high-probability areas) of the teacher model
+**Characteristics**: Mode-covering
+- Expectation is computed under the teacher distribution
+- The student model tends to cover the entire teacher distribution (including low-probability regions)

-#### Reverse KL

+#### Reverse KL
$$
-\text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_v P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)}
+\text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_v P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)}
$$

-**Characteristics**: Mode-covering
-- Expectation is computed under the teacher distribution
-- The student model tends to cover the entire teacher distribution (including low-probability regions)
+**Characteristics**: Mode-seeking
+- Expectation is computed under the student distribution
+- The student model tends to concentrate on the peak regions (high-probability areas) of the teacher model

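As a quick numerical illustration of the two directions (a standalone sketch with made-up toy distributions, not part of ms-swift), the snippet below computes both terms for a single next-token distribution:

```python
import torch

# Toy next-token distributions over a 4-token vocabulary (illustrative values only).
teacher = torch.tensor([0.70, 0.20, 0.09, 0.01])  # P_teacher
student = torch.tensor([0.40, 0.30, 0.20, 0.10])  # P_student

# Forward KL: expectation taken under the teacher distribution (mode-covering).
forward_kl = torch.sum(teacher * (teacher.log() - student.log()))

# Reverse KL: expectation taken under the student distribution (mode-seeking).
reverse_kl = torch.sum(student * (student.log() - teacher.log()))

print(f"forward KL = {forward_kl.item():.4f}, reverse KL = {reverse_kl.item():.4f}")
```

Because the reverse KL weights each term by the student's own probability, the student pays almost nothing for dropping tokens entirely but is heavily penalized for putting mass where the teacher has little, which is exactly the mode-seeking behavior described above.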
### Generalized Jensen-Shannon Divergence (Generalized JSD)

@@ -78,8 +78,8 @@ $$
Where $M = \beta \cdot P_{\text{teacher}} + (1-\beta) \cdot P_{\text{student}}$

> For extreme cases ($\beta = 0$ or $\beta = 1$), directly compute a single KL divergence:
-> - When $\beta = 0$: directly define $D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})$ (Reverse KL, Mode-covering)
-> - When $\beta = 1$: directly define $D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})$ (Forward KL, Mode-seeking)
+> - When $\beta = 0$: directly define $D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})$ (Forward KL, Mode-covering)
+> - When $\beta = 1$: directly define $D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})$ (Reverse KL, Mode-seeking)
> - When $0 < \beta < 1$: use the above mixture distribution formula for interpolation

By adjusting the $\beta$ parameter, interpolation can be performed between different divergence metrics. When $\beta = 0.5$, the divergence is the standard symmetric JSD.
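The edge cases and the mixture interpolation can be written as a small helper. The sketch below is illustrative only: the function names are hypothetical, and the interpolated form $D = \beta \cdot \text{KL}(P_{\text{teacher}} \| M) + (1-\beta) \cdot \text{KL}(P_{\text{student}} \| M)$ is assumed from the definition of $M$ above rather than taken from the ms-swift implementation:

```python
import torch

def kl(p: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
    """KL(p || q) for dense probability vectors."""
    return torch.sum(p * (p.log() - q.log()))

def generalized_jsd(p_teacher: torch.Tensor, p_student: torch.Tensor, beta: float) -> torch.Tensor:
    """Generalized JSD with the beta edge cases handled as single KL terms."""
    if beta == 0.0:
        return kl(p_teacher, p_student)                    # Forward KL (mode-covering)
    if beta == 1.0:
        return kl(p_student, p_teacher)                    # Reverse KL (mode-seeking)
    m = beta * p_teacher + (1.0 - beta) * p_student        # mixture distribution M
    return beta * kl(p_teacher, m) + (1.0 - beta) * kl(p_student, m)
```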
@@ -142,8 +142,8 @@ loss = D_JSD(P_teacher(·|x,y), P_student(·|x,y))
| Parameter | Type | Default | Range | Description |
|------|------|--------|---------|------|
| `--teacher_model` | str | Required | - | Teacher model path or model ID |
-| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Reverse KL (mode-covering, more diverse)<br>• 0.5: JSD (balanced, **recommended**)<br>• 1.0: Forward KL (mode-seeking, more focused) |
-| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability<br>• 0.0: Pure Off-Policy<br>• 0.5: Mixed strategy (**recommended**)<br>• 1.0: Pure On-Policy |
+| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Forward KL<br>• 0.5: JSD (balanced)<br>• 1.0: Reverse KL |
+| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability<br>• 0.0: Pure Off-Policy<br>• 0.5: Mixed strategy<br>• 1.0: Pure On-Policy |
| `--seq_kd` | bool | False | True/False | Whether to use teacher-generated sequences<br>• False: Use dataset when not on-policy<br>• True: Use teacher generation when not on-policy |
| `--temperature` | float | 0.9 | > 0 | Generation sampling temperature, controls randomness |
| `--max_completion_length` | int | 512 | > 0 | Maximum number of tokens during generation |
@@ -200,3 +200,13 @@ swift rlhf \
```

A reference training script is available [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/rlhf/gkd/fast.sh).

## On-Policy Distillation

We can implement the [On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/) training described in the Thinking Machines Lab blog by setting the following parameters:
```bash
--lmbda 1 # on-policy
--beta 1 # reverse
```

A reference script is available [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh).
1 change: 1 addition & 0 deletions docs/source/Instruction/GRPO/AdvancedResearch/index.rst
@@ -7,4 +7,5 @@ Advanced Research
DAPO.md
deepeyes.md
GSPO.md
+RLOO.md
CHORD.md
33 changes: 22 additions & 11 deletions docs/source_en/Instruction/GKD.md
@@ -33,22 +33,22 @@ In knowledge distillation, there are two choices depending on the order of the t
#### Forward KL

$$
-\text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_v P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)}
+\text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_v P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)}
$$

-**Characteristics**: Mode-seeking
-- Expectation is computed under the student distribution
-- The student model tends to concentrate on the peak regions (high-probability areas) of the teacher model
+**Characteristics**: Mode-covering
+- Expectation is computed under the teacher distribution
+- The student model tends to cover the entire teacher distribution (including low-probability regions)

#### Reverse KL

$$
-\text{KL}(P_{\text{teacher}} \| P_{\text{student}}) = \sum_v P_{\text{teacher}}(v) \log \frac{P_{\text{teacher}}(v)}{P_{\text{student}}(v)}
+\text{KL}(P_{\text{student}} \| P_{\text{teacher}}) = \sum_v P_{\text{student}}(v) \log \frac{P_{\text{student}}(v)}{P_{\text{teacher}}(v)}
$$

-**Characteristics**: Mode-covering
-- Expectation is computed under the teacher distribution
-- The student model tends to cover the entire teacher distribution (including low-probability regions)
+**Characteristics**: Mode-seeking
+- Expectation is computed under the student distribution
+- The student model tends to concentrate on the peak regions (high-probability areas) of the teacher model

### Generalized Jensen-Shannon Divergence (Generalized JSD)

@@ -78,8 +78,8 @@
Where $M = \beta \cdot P_{\text{teacher}} + (1-\beta) \cdot P_{\text{student}}$

> For extreme cases ($\beta = 0$ or $\beta = 1$), directly compute a single KL divergence:
-> - When $\beta = 0$: directly define $D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})$ (Reverse KL, Mode-covering)
-> - When $\beta = 1$: directly define $D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})$ (Forward KL, Mode-seeking)
+> - When $\beta = 0$: directly define $D = \text{KL}(P_{\text{teacher}} \| P_{\text{student}})$ (Forward KL, Mode-covering)
+> - When $\beta = 1$: directly define $D = \text{KL}(P_{\text{student}} \| P_{\text{teacher}})$ (Reverse KL, Mode-seeking)
> - When $0 < \beta < 1$: use the above mixture distribution formula for interpolation

By adjusting the $\beta$ parameter, interpolation can be performed between different divergence metrics. When $\beta = 0.5$, the divergence is the standard symmetric JSD.
@@ -142,7 +142,7 @@ We can perform GKD training by setting the following parameters:
| Parameter | Type | Default | Range | Description |
|------|------|--------|---------|------|
| `--teacher_model` | str | Required | - | Teacher model path or model ID |
-| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Reverse KL (mode-covering, more diverse)<br>• 0.5: JSD (balanced, **recommended**)<br>• 1.0: Forward KL (mode-seeking, more focused) |
+| `--beta` | float | 0.5 | [0.0, 1.0] | Divergence interpolation coefficient<br>• 0.0: Forward KL<br>• 0.5: JSD (balanced)<br>• 1.0: Reverse KL |
| `--lmbda` | float | 0.5 | [0.0, 1.0] | On-Policy learning trigger probability<br>• 0.0: Pure Off-Policy<br>• 0.5: Mixed strategy (**recommended**)<br>• 1.0: Pure On-Policy |
| `--seq_kd` | bool | False | True/False | Whether to use teacher-generated sequences<br>• False: Use dataset when not on-policy<br>• True: Use teacher generation when not on-policy |
| `--temperature` | float | 0.9 | > 0 | Generation sampling temperature, controls randomness |
@@ -201,3 +201,14 @@ swift rlhf \
```

A reference training script is available [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/rlhf/gkd/fast.sh).


## On-Policy Distillation
We can achieve the [On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/) training described in the Thinking Machines Lab blog by setting the following parameters:

```bash
--lmbda 1 # on-policy
--beta 1 # reverse
```

For a complete implementation, refer to the example script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/on_policy_distillation.sh).
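Conceptually, `--lmbda 1` means every training completion is sampled from the student itself, and `--beta 1` means each sampled token is scored with the reverse KL against the teacher. Below is a minimal sketch of that per-token objective; the function and argument names are illustrative, not the ms-swift API:

```python
import torch
import torch.nn.functional as F

def on_policy_distill_loss(student_logits: torch.Tensor,
                           teacher_logits: torch.Tensor,
                           completion_mask: torch.Tensor) -> torch.Tensor:
    """Reverse KL(student || teacher), averaged over student-generated completion tokens.

    Both logit tensors have shape [batch, seq_len, vocab] and are computed on the
    same student-sampled sequence; completion_mask is 1 on completion tokens, 0 elsewhere.
    """
    student_logp = F.log_softmax(student_logits, dim=-1)
    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
    # beta = 1: expectation under the student distribution (mode-seeking).
    per_token_kl = torch.sum(student_logp.exp() * (student_logp - teacher_logp), dim=-1)
    return (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1)
```

In practice the teacher log-probabilities are computed without gradients, so only the student model is updated.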
1 change: 1 addition & 0 deletions docs/source_en/Instruction/GRPO/AdvancedResearch/index.rst
@@ -7,4 +7,5 @@ Advanced Research
DAPO.md
deepeyes.md
GSPO.md
+RLOO.md
CHORD.md
41 changes: 41 additions & 0 deletions examples/train/on_policy_distillation.sh
@@ -0,0 +1,41 @@
# On-Policy Distillation https://thinkingmachines.ai/blog/on-policy-distillation/
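# First launch the vLLM rollout server on a spare GPU; the training command below connects
# to it via --vllm_mode server / --vllm_server_host / --vllm_server_port.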

# CUDA_VISIBLE_DEVICES=7 \
# swift rollout \
# --model Qwen/Qwen3-8B-Base \
# --vllm_max_model_len 24192

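# --lmbda 1: fully on-policy (every completion is sampled from the student);
# --beta 1: each sampled token is scored with reverse KL against the teacher.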
NPROC_PER_NODE=7 \
PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 \
swift rlhf \
--rlhf_type gkd \
--model Qwen/Qwen3-8B-Base \
--teacher_model Qwen/Qwen3-32B \
--train_type full \
--dataset open-thoughts/OpenThoughts3-1.2M#10000 \
--seq_kd false \
--lmbda 1 \
--beta 1 \
--torch_dtype bfloat16 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--learning_rate 1e-5 \
--gradient_accumulation_steps 1 \
--save_steps 1000 \
--save_total_limit 2 \
--logging_steps 1 \
--max_length 16000 \
--max_completion_length 8192 \
--output_dir output \
--warmup_ratio 0.05 \
--save_only_model true \
--dataloader_num_workers 64 \
--dataset_num_proc 4 \
--deepspeed zero2 \
--teacher_deepspeed zero3 \
--attn_impl flash_attn \
--use_vllm true \
--vllm_mode server \
--vllm_server_host 127.0.0.1 \
--vllm_server_port 8000