Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ You can contact us and communicate with us by adding our group:

## 🎉 News

- 🎁 2024.01.23: SWIFT support the `sample` command, this is a very important feature for complex CoT and RFT. Meanwhile, we support an [Reinforced Fine-tuning script](docs/source_en/Instruction/Reinforced_Fine_tuning.md).
- 🎁 2024.12.04: **SWIFT3.0** major version update. Please check the [Release Notes and Changes](https://swift.readthedocs.io/en/latest/Instruction/ReleaseNote3.0.html).
- 🎉 2024.08.12: The SWIFT paper has been published on arXiv, and you can read it [here](https://arxiv.org/abs/2408.05517).
- 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
Expand Down Expand Up @@ -295,6 +296,15 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \
--infer_backend vllm
```

### Sampling
```shell
CUDA_VISIBLE_DEVICES=0 swift sample \
--model LLM-Research/Meta-Llama-3.1-8B-Instruct \
--sampler_engine pt \
--num_return_sequences 5 \
--dataset AI-ModelScope/alpaca-gpt4-data-zh#5
```

### Evaluation
```shell
CUDA_VISIBLE_DEVICES=0 swift eval \
Expand Down
10 changes: 10 additions & 0 deletions README_CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
- **模型量化**:支持AWQ、GPTQ和BNB的量化导出,导出的模型支持使用vLLM/LmDeploy推理加速,并支持继续训练。

## 🎉 新闻
- 🎁 2024.01.23: SWIFT支持了`sample`命令, 这是一个对CoT和RFT非常重要的命令. 同时, 我们支持了一个[强化微调脚本](docs/source/Instruction/强化微调.md)。
- 🎁 2024.12.04: **SWIFT3.0**大版本更新. 请查看[发布说明和更改](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html)。
- 🎉 2024.08.12: SWIFT论文已经发布到arXiv上,可以点击[这里](https://arxiv.org/abs/2408.05517)阅读。
- 🔥 2024.08.05: 支持使用[evalscope](https://github.com/modelscope/evalscope/)作为后端进行大模型和多模态模型的评测。
Expand Down Expand Up @@ -288,6 +289,15 @@ CUDA_VISIBLE_DEVICES=0 swift deploy \
--infer_backend vllm
```

### 采样
```shell
CUDA_VISIBLE_DEVICES=0 swift sample \
--model LLM-Research/Meta-Llama-3.1-8B-Instruct \
--sampler_engine pt \
--num_return_sequences 5 \
--dataset AI-ModelScope/alpaca-gpt4-data-zh#5
```

### 评测
```shell
CUDA_VISIBLE_DEVICES=0 swift eval \
Expand Down
103 changes: 103 additions & 0 deletions docs/source/Instruction/强化微调.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# 强化微调

强化微调是目前模型训练非常重要的功能之一,它本身的实现是多种多样的,SWIFT目前已经支持了强化微调所需要的原子能力,如采样、强化学习和微调。目前我们提供了拒绝采样微调的一个具体示例,可以查看[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。

## 强化微调的概念

强化微调是从2022年开始(甚至更早)就被提出的概念。其方式一般有下列流程:

1. 使用某个模型生成数据,或进行原始数据扩充
2. 使用数据训练目标模型
3. 如果有必要,重复上述过程

步骤1:

- 如果生成数据的模型是更大的模型,如GPT、Qwen-Max、DeepSeek-V3/R1等,则该强化微调可以理解为蒸馏
- 如果生成数据的模型是本模型,则可以理解为自我提升(self-improvement)微调
- 如果采样过程是采样一个batch,然后通过KL散度和reward进行拟合训练并不断循环,则可以理解为PPO、GRPO等on-policy算法
- 采样数据的算法包含蒙特卡洛采样、do_sample采样、group beam search、dvts等
- 采样过程可以引入ORM(结果判断),PRM(过程打分),多样性过滤,语种过滤等

步骤2:

- 如果使用SFT,则称为拒绝采样微调
- 如果是强化学习,则称为强化学习微调

步骤3:

- 如果使用更大的模型蒸馏,例如更大模型的蒙特卡洛采样蒸馏,一般不会有循环
- 如果使用本模型进行采样,或者PPO等算法,则会有循环

泛泛来说,常见强化微调的方式有下面几种:

1. 蒸馏:使用蒙特卡洛、do_sample等方式从超大模型中采样大量优质数据,训练小模型
2. 自我提升:从本模型中采样部分优质数据,筛选后训练本模型,循环执行
3. on-policy RL:使用PPO、GRPO等方式循环训练

采样过程一般很漫长,比训练过程漫长的多。如果使用GPT等模型蒸馏数据,则需要购买token。因此,强化微调的时间成本和花费成本比较高,所以一般作为微调的补充机制出现,当然也有特例,例如最近的DeepSeek-R1。

DeepSeek-R1使用了GRPO算法从零使base模型涌现CoT能力,该方法需要大规模集群支持,且模型需要足够大才能发生能力涌现,在本文中不详细讨论。如果需要了解该过程,请查看[论文解析](https://zhuanlan.zhihu.com/p/19714987272)。

有关强化微调的一些论文:

- 拒绝采样微调:https://arxiv.org/pdf/2308.01825
- ReST:https://arxiv.org/pdf/2308.08998
- B-STAR:https://arxiv.org/pdf/2412.17256
- DeepSeekMath:https://arxiv.org/pdf/2402.03300
- Qwen-math-PRM:https://arxiv.org/pdf/2501.07301
- DeepSeek-R1:https://github.com/deepseek-ai/DeepSeek-R1/tree/main

## 什么时候使用强化微调

在LLaMA3之后,我们发现一个非常明显但却是不常被提及的特点:使用某个含有CoT的train数据集训练Instruct模型,再通过对应的test集进行评测,会发现test集评测效果变差。例如,使用gsm8k训练集训练llama3.1-8b-instruct,对生成的ckpt使用test集进行评测,会发现掉点。

这个特性主要来源于模型的知识遗忘问题。在模型厂商的微调中,会加入非常多的CoT数据集,模型在解决数学任务的时候,用到的能力很有可能不是来自于math数据集,而是来自arc数据集,这个推论有[一些工作可以证明](https://zhuanlan.zhihu.com/p/19269451950)。在继续训练通用任务后,知识遗忘破坏了模型原有能力,导致了掉点。

然而,优先使用微调方式训练模型总是正确的。微调可以使模型快速适应数据集的分布,并且微调的成本很低。当有如下条件之一时使用强化微调:

1. 已经微调过模型,能力不满足需求
2. 需要更强的CoT能力
3. 对基模型训练通用能力,而原始数据集已经导致模型效果无法提升
4. 对应query的输出结果可以相对准确地评估好坏,例如结果清晰(数学,代码),过程清晰(翻译,风格)等

强化微调非常依赖于reward评估是否准确。如果评估结果不准确,可能导致模型训练原地震荡,甚至越训越差。

## SWIFT的实现

SWIFT支持sample命令,该命令就是用于模型采样。目前支持的采样方式有:

- do_sample:sample方式对模型进行采样,该方式支持对开源模型进行采样,后续会支持模型蒸馏
- sample方式后续会支持URL采样,用于大模型蒸馏

- mcts:蒙特卡洛采样,该方式在PR中,后续会支持
- dvts:调研中

目前我们给出了一个较为通用的[RFT脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。该脚本适用于自我提升方式的训练,且支持动态调整采样温度值、PRM阈值等超参数,并且训练方式灵活可变(微调、DPO等;或者每次迭代重新训练原模型或继续训练上个迭代的模型,甚至加载上个迭代的所有训练状态等)。开发者可以在该脚本中增加其他数据过滤(生成的数据集中,id相同的行来自同一个query),例如多样性判断、语种判断等。

## 实验结果

我们对该RFT脚本针对数学领域使用competition_math数据集进行了训练和评测,结果如下:

| 模型 | MATH指标 | 训练方式 | 迭代次数 | 训练后MATH指标 |
| ------------------------ | -------- | -------- | -------- | --------------------- |
| LLaMA3.1_8b | 12.0 | SFT | 3 | 25.2(LLaMA3.1_8b_sft) |
| LLaMA3.1_8b_sft | 25.2 | RFT | 2 | 32.4 |
| LLaMA3.1_8b_instruct | 52.2 | SFT | 2 | 39.0 |
| LLaMA3.1_8b_instruct | 52.2 | RFT | 3 | 58 |
| Qwen2.5_math_7b_instruct | 79.6 | RFT | 2 | 83.2 |

可以看到,使用competition_math直接SFT后,instruct模型的掉点十分严重。而RFT后模型能力有提升,即使对Qwen2.5_math_7b_instruct这个SOTA的math模型也同样有一定提升空间。

特别地,针对Qwen2.5_math_7b_instruct我们测试了gsm8k的指标:

| 模型 | gsm8k指标 | RFT后gsm8k指标 |
| ------------------------ | --------- | -------------- |
| Qwen2.5_math_7b_instruct | 92.8 | 91.6 |

可以看到,RFT训练后gsm8k指标变化不大,并没有出现前述的掉点现象。

## 未来计划

1. 更多的采样方式,如MCTS
2. 超大模型蒸馏训练
3. 以PPO为主的on-policy训练
13 changes: 11 additions & 2 deletions docs/source/Instruction/采样.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class CustomPRM:
pass

@torch.inference_mode()
def infer(self, infer_requests: List[InferRequest], **kwargs) -> List[ChatCompletionResponse]:
def infer(self, infer_requests: List[InferRequest], ground_truths: List[str], **kwargs) -> List[ChatCompletionResponse]:
...


Expand All @@ -59,8 +59,17 @@ prms = {'custom': CustomPRM}

之后在命令行中使用`--prm_model custom`即可。

## 显存控制

如果被采样模型和PRM共同加载进显存,则可能出现OOM的问题。因此采样可以分为两段进行:

- 第一段指定`--model`和``--sampler_engine`,同时不指定`--orm_model`和`--prm_model`,仅进行采样,并存储为文件
- 第二段指定`--sampler_engine no`,指定`--orm_model`和`--prm_model`,并同时指定`--cache_files`,仅进行RM数据过滤,不重新采样

通过两段方式可以每次仅加载一个模型,防止OOM。

## 实际例子

请参考[强化微调脚本](https://github.com/modelscope/ms-swift/tree/main/scripts/rft.py)。该脚本给出了使用采样进行强化微调的实际例子。
请参考[强化微调脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。该脚本给出了使用采样进行强化微调的实际例子。

> 注意:该脚本的实际效果和模型、数据、RM的质量强相关,因此仅作为样例出现,用户请自行修改该脚本并训练自己的RM和generator模型。
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ Swift DOCUMENTATION
Instruction/预训练及微调.md
Instruction/人类对齐.md
Instruction/推理和部署.md
Instruction/采样.md
Instruction/评测.md
Instruction/导出.md
Instruction/强化微调.md
Instruction/支持的模型和数据集.md
Instruction/使用tuners.md
Instruction/智能体的支持.md
Expand Down
103 changes: 103 additions & 0 deletions docs/source_en/Instruction/Reinforced_Fine_tuning.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Reinforced Fine-Tuning

Reinforced fine-tuning is one of the most important functionalities in current model training, with various implementations. SWIFT has already supported the atomic capabilities required for reinforced fine-tuning, such as sampling, reinforcement learning, and fine-tuning. Currently, we provide a specific example of rejection sampling fine-tuning, which can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py).

## Concept of Reinforced Fine-Tuning

The concept of reinforced fine-tuning has been proposed since 2022 (or even earlier). Its general workflow typically includes the following steps:

1. Generate data using a specific model or augment the original dataset.
2. Train the target model using the generated data.
3. Repeat the above process if necessary.

**Step 1:**

- If the data-generating model is a larger model, such as GPT, Qwen-Max, DeepSeek-V3/R1, etc., this process can be understood as distillation.
- If the data-generating model is the same model being trained, this can be considered self-improvement fine-tuning.
- If the sampling process involves sampling a batch, fitting the data with KL divergence and rewards, and iterating continuously, it can be classified as on-policy algorithms like PPO or GRPO.
- Sampling algorithms include Monte Carlo sampling, do_sample, group beam search, DVTS, etc.
- The sampling process can incorporate ORM (Outcome Reward Model), PRM (Process Reward Model), diversity filtering, language filtering, etc.

**Step 2:**

- If SFT (Supervised Fine-Tuning) is used, it is referred to as rejection sampling fine-tuning.
- If reinforcement learning is used, it is called reinforcement learning fine-tuning.

**Step 3:**

- If distillation is performed using a larger model (e.g., Monte Carlo sampling distillation with a larger model), the process usually does not involve iterations.
- If the same model is used for sampling or algorithms like PPO are applied, iterations are typically included.

In general, the common approaches to reinforced fine-tuning include:

1. **Distillation**: Sampling high-quality data in bulk from a larger model using methods like Monte Carlo or do_sample, and training a smaller model on this data.
2. **Self-improvement**: Sampling a portion of high-quality data from the same model, filtering it, and training the model iteratively.
3. **On-policy RL**: Using methods like PPO or GRPO for iterative training.

The sampling process is usually much more time-consuming than the training process. If data is distilled using GPT or other large models, token costs must be considered. Thus, reinforced fine-tuning is generally a supplementary mechanism for fine-tuning, except for special cases like DeepSeek-R1.

DeepSeek-R1 uses the GRPO algorithm to enable the emergence of CoT (Chain-of-Thought) capabilities from scratch in a base model. This method requires large-scale cluster support and sufficiently large models for capability emergence. This is not discussed in detail here, but more information can be found in the [paper analysis](https://zhuanlan.zhihu.com/p/19714987272).

Some related papers on reinforced fine-tuning:

- Rejection Sampling Fine-Tuning: https://arxiv.org/pdf/2308.01825
- ReST: https://arxiv.org/pdf/2308.08998
- B-STAR: https://arxiv.org/pdf/2412.17256
- DeepSeekMath: https://arxiv.org/pdf/2402.03300
- Qwen-Math-PRM: https://arxiv.org/pdf/2501.07301
- DeepSeek-R1: https://github.com/deepseek-ai/DeepSeek-R1/tree/main

## When to Use Reinforced Fine-Tuning

Since LLaMA3, we have observed a very noticeable yet rarely mentioned phenomenon: when training an Instruct model using a CoT-enabled training dataset and evaluating it on the corresponding test set, the test set performance tends to degrade. For example, training `llama3.1-8b-instruct` on the GSM8K training set and evaluating the generated checkpoint on the test set reveals performance degradation.

This phenomenon mainly arises from the issue of knowledge forgetting disaster in models. During fine-tuning by model manufacturers, a significant amount of CoT data is often included. When solving mathematical tasks, the model's capability often originates not from the math dataset itself but potentially from datasets like ARC. This inference is supported by [some works](https://zhuanlan.zhihu.com/p/19269451950). Continued training on general tasks disrupts the model's existing capabilities, leading to performance degradation.

However, it is always correct to prioritize fine-tuning. Fine-tuning allows the model to quickly adapt to the dataset distribution at a low cost. Reinforced fine-tuning should be used under the following conditions:

1. The model has already been fine-tuned but does not meet the requirements.
2. Stronger CoT capabilities are needed.
3. Base model training for general capabilities is necessary, and the original dataset no longer improves performance.
4. The output results for corresponding queries can be relatively accurately evaluated, such as tasks with clear results (math, code) or clear processes (translation, style fitting).

Reinforced fine-tuning heavily depends on the accuracy of reward evaluations. If the evaluations are inaccurate, the training may oscillate without progress or even degrade the model performance.

## SWIFT Implementation

SWIFT supports the `sample` command, which is used for model sampling. Currently supported sampling methods include:

- **do_sample**: A sampling method for open-source models; future updates will include support for model distillation.
- URL sampling will also be supported in the future for large-model distillation.

- **mcts**: Monte Carlo sampling, currently under review, with future support planned.
- **dvts**: Currently under investigation.

We have provided a general [RFT script](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py). This script supports self-improvement training and allows dynamic adjustments of sampling temperature, PRM thresholds, and other hyperparameters. The training method is flexible (e.g., fine-tuning, DPO) and supports iterative retraining of the original model or continued training from the previous iteration, even loading all training states from the previous iteration. Developers can incorporate additional data filtering (e.g., ensuring rows with the same ID come from the same query), including diversity checks, language filtering, etc.

## Experimental Results

We used the RFT script to train and evaluate the `competition_math` dataset in the math domain. The results are as follows:

| Model | MATH Score | Training Method | Iterations | Post-Training MATH Score |
|----------------------------|------------|-----------------|------------|---------------------------|
| LLaMA3.1_8b | 12.0 | SFT | 3 | 25.2 (LLaMA3.1_8b_sft) |
| LLaMA3.1_8b_sft | 25.2 | RFT | 2 | 32.4 |
| LLaMA3.1_8b_instruct | 52.2 | SFT | 2 | 39.0 |
| LLaMA3.1_8b_instruct | 52.2 | RFT | 3 | 58 |
| Qwen2.5_math_7b_instruct | 79.6 | RFT | 2 | 83.2 |

As shown, applying SFT to the `competition_math` dataset resulted in significant performance degradation for the instruct model. However, RFT improved the model's capabilities, even for the state-of-the-art `Qwen2.5_math_7b_instruct` math model.

Specifically, we tested the GSM8K metric for `Qwen2.5_math_7b_instruct`:

| Model | GSM8K Score | Post-RFT GSM8K Score |
|----------------------------|-------------|-----------------------|
| Qwen2.5_math_7b_instruct | 92.8 | 91.6 |

As shown, RFT training did not significantly change the GSM8K score, avoiding the previously mentioned performance degradation phenomenon.

## Future Roadmap

1. More sampling methods,MCTS for example
2. Distill from super huge model
3. On policy RFT like PPO
Loading
Loading