Merged
Commits
60 commits
ad3f9b1
support ppo
Jintao-Huang Dec 27, 2024
1e0e17d
update
Jintao-Huang Dec 27, 2024
29cd2d0
update
Jintao-Huang Dec 27, 2024
0492943
update
Jintao-Huang Dec 27, 2024
0214e49
fix
Jintao-Huang Dec 27, 2024
df183a2
update
Jintao-Huang Dec 27, 2024
b99eb6b
Merge branch 'main' into support_ppo
Jintao-Huang Dec 27, 2024
0a54bb8
update
Jintao-Huang Dec 28, 2024
8e90c4e
update
Jintao-Huang Dec 30, 2024
7ee63b9
Merge branch 'dev_1230' into support_ppo
Jintao-Huang Dec 30, 2024
1bdec8a
update
Jintao-Huang Dec 30, 2024
ed3e2c3
Merge branch 'main' into support_ppo
Jintao-Huang Dec 30, 2024
13e86a0
Merge branch 'main' into support_ppo
Jintao-Huang Dec 31, 2024
e19bffc
update
Jintao-Huang Dec 31, 2024
b351b54
Merge branch 'main' into support_ppo
Jintao-Huang Jan 1, 2025
b7c28aa
update
Jintao-Huang Jan 2, 2025
d8b2105
fix
Jintao-Huang Jan 2, 2025
f726d0a
update
Jintao-Huang Jan 2, 2025
d5dfcab
update
Jintao-Huang Jan 2, 2025
d609752
update
Jintao-Huang Jan 2, 2025
80157a6
fix bugs
Jintao-Huang Jan 2, 2025
106f588
fix
Jintao-Huang Jan 2, 2025
774b115
update
Jintao-Huang Jan 2, 2025
0508c53
Merge branch 'fix_shell_0102' into support_ppo
Jintao-Huang Jan 2, 2025
8e00e42
update
Jintao-Huang Jan 2, 2025
1fd06c6
update
Jintao-Huang Jan 2, 2025
3735cbd
update
Jintao-Huang Jan 2, 2025
acce966
Merge branch 'main' into support_ppo
Jintao-Huang Jan 2, 2025
c5b7022
update
Jintao-Huang Jan 2, 2025
16c8c00
update
Jintao-Huang Jan 2, 2025
368b2ef
update
Jintao-Huang Jan 3, 2025
9a783dd
update
Jintao-Huang Jan 3, 2025
3491646
Merge branch 'main' into support_ppo
Jintao-Huang Jan 3, 2025
1211640
revert
Jintao-Huang Jan 3, 2025
3bf2367
Merge branch 'main' into support_ppo
Jintao-Huang Jan 4, 2025
05c76ee
Merge branch 'main' into support_ppo
Jintao-Huang Jan 5, 2025
e7e7fd6
update
Jintao-Huang Jan 5, 2025
2b47806
update
Jintao-Huang Jan 5, 2025
ea082cd
update
Jintao-Huang Jan 6, 2025
349cc43
Merge branch 'main' into support_ppo
Jintao-Huang Jan 6, 2025
2f045c9
update
Jintao-Huang Jan 6, 2025
d4d33ef
Merge branch 'main' into support_ppo
Jintao-Huang Jan 6, 2025
6c62557
update
Jintao-Huang Jan 6, 2025
585ad23
fix
Jintao-Huang Jan 6, 2025
89dbe79
fix
Jintao-Huang Jan 6, 2025
d8030db
fix
Jintao-Huang Jan 6, 2025
ac49ee6
update
Jintao-Huang Jan 6, 2025
8495e33
Merge branch 'main' into support_ppo
Jintao-Huang Jan 6, 2025
25d9d9d
update
Jintao-Huang Jan 6, 2025
102257b
update
Jintao-Huang Jan 7, 2025
db9bdc6
update
Jintao-Huang Jan 7, 2025
ee202e1
update
Jintao-Huang Jan 7, 2025
455fbd5
fix
Jintao-Huang Jan 7, 2025
5789bb9
update
Jintao-Huang Jan 7, 2025
813dadf
update
Jintao-Huang Jan 7, 2025
828996b
fix
Jintao-Huang Jan 7, 2025
0486eab
update
Jintao-Huang Jan 7, 2025
2e98d6a
fix
Jintao-Huang Jan 7, 2025
6c2b682
fix
Jintao-Huang Jan 7, 2025
c592ac0
update
Jintao-Huang Jan 7, 2025
4 changes: 2 additions & 2 deletions README.md
@@ -67,7 +67,7 @@ You can contact us and communicate with us by adding our group:
- 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel.
- **Distributed Training**: Supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
- **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
- **RLHF Training**: Supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM for both pure text and multi-modal large models.
- **RLHF Training**: Supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM, PPO for both pure text and multi-modal large models.
- 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding.
- **Interface Training**: Provides capabilities for training, inference, evaluation, quantization through an interface, completing the whole large model pipeline.
- **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer.
@@ -83,7 +83,7 @@ You can contact us and communicate with us by adding our group:
- 🎉 2024.08.12: The SWIFT paper has been published on arXiv, and you can read it [here](https://arxiv.org/abs/2408.05517).
- 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
- 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models. When performing infer/deploy/eval, you can specify `--infer_backend vllm/lmdeploy`.
- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM.
- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO.
- 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf).


4 changes: 2 additions & 2 deletions README_CN.md
@@ -64,7 +64,7 @@
- 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods such as LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, and Liger-Kernel.
- **Distributed Training**: Supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
- **Quantization Training**: Supports training quantized models such as BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
- **RLHF Training**: Supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM for both pure-text and multimodal large models.
- **RLHF Training**: Supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM, PPO for both pure-text and multimodal large models.
- 🍓 **Multi-Modal Training**: Supports training models of different modalities such as images, videos, and audio, covering tasks like VQA, captioning, OCR, and grounding.
- **Interface Training**: Provides training, inference, evaluation, and quantization capabilities through a web UI, completing the whole large-model pipeline.
- **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components such as loss, metric, trainer, loss-scale, callback, optimizer.
@@ -78,7 +78,7 @@
- 🎉 2024.08.12: The SWIFT paper has been published on arXiv; you can read it [here](https://arxiv.org/abs/2408.05517).
- 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
- 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models; simply specify `--infer_backend vllm/lmdeploy` during infer/deploy/eval.
- 🔥 2024.07.24: Support for human preference alignment training of multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM.
- 🔥 2024.07.24: Support for human preference alignment training of multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO.
- 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf).

## 🛠️ Installation
8 changes: 8 additions & 0 deletions docs/source/Customization/自定义数据集.md
@@ -67,6 +67,14 @@ query-response format:
{"messages": [{"role": "system", "content": "你是个有用无害的数学计算器"}, {"role": "user", "content": "1+1等于几"}, {"role": "assistant", "content": "等于2"}, {"role": "user", "content": "再加1呢"}, {"role": "assistant", "content": "等于3"}], "label": true}
```

#### PPO

```jsonl
{"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "告诉我明天的天气"}]}
{"messages": [{"role": "system", "content": "你是个有用无害的数学计算器"}, {"role": "user", "content": "1+1等于几"}, {"role": "assistant", "content": "等于2"}, {"role": "user", "content": "再加1呢"}]}
{"messages": [{"role": "user", "content": "你的名字是什么"}]}
```
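
If you already have standard chat data that ends with an assistant reply, one simple way to build a PPO prompt set is to drop the final assistant turn, since the response is regenerated during rollout. The sketch below illustrates this conversion; it is not part of ms-swift, and the file names are placeholders.

```python
import json

def chat_to_ppo(src_path: str, dst_path: str) -> None:
    """Turn chat-format JSONL (ending with an assistant reply) into PPO prompt-only JSONL."""
    with open(src_path, encoding='utf-8') as src, open(dst_path, 'w', encoding='utf-8') as dst:
        for line in src:
            sample = json.loads(line)
            messages = sample['messages']
            # Drop the trailing assistant reply; PPO samples the response itself during training.
            if messages and messages[-1]['role'] == 'assistant':
                messages = messages[:-1]
            dst.write(json.dumps({'messages': messages}, ensure_ascii=False) + '\n')

chat_to_ppo('sft_dataset.jsonl', 'ppo_dataset.jsonl')  # hypothetical file names
```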

### Sequence Classification
```jsonl
{"messages": [{"role": "user", "content": "今天天气真好呀"}], "label": 1}
2 changes: 1 addition & 1 deletion docs/source/GetStarted/快速开始.md
@@ -8,7 +8,7 @@ ms-swift is the large model and multimodal large model training and deployment framework provided by the ModelScope community
- 🍊 Lightweight training: supports lightweight fine-tuning methods such as LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, and Liger-Kernel.
- Distributed training: supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
- Quantization training: supports training quantized models such as BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
- RLHF training: supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM for both pure-text and multimodal large models.
- RLHF training: supports human alignment training methods such as DPO, CPO, SimPO, ORPO, KTO, RM, PPO for both pure-text and multimodal large models.
- 🍓 Multimodal training: supports training models of different modalities such as images, videos, and audio, covering VQA, captioning, OCR, and grounding tasks.
- Interface training: provides training, inference, evaluation, and quantization capabilities through a web UI, completing the whole large-model pipeline.
- Plugin and extension: supports custom model and dataset extensions, as well as customization of components such as loss, metric, trainer, loss-scale, callback, optimizer.
5 changes: 2 additions & 3 deletions docs/source/Instruction/ReleaseNote3.0.md
@@ -81,7 +81,6 @@

## To-Do

1. RM/PPO capabilities are not yet supported in version 3.0; please use version 2.6.1.
2. Custom dataset evaluation is not yet supported in version 3.0; please use version 2.6.1.
3. Megatron pre-training capabilities are not yet supported in version 3.0; please use version 2.6.1.
1. Custom dataset evaluation is not yet supported in version 3.0; please use version 2.6.1.
2. Megatron pre-training capabilities are not yet supported in version 3.0; please use version 2.6.1.
3. The documentation and README are not yet fully updated.
25 changes: 23 additions & 2 deletions docs/source/Instruction/命令行参数.md
@@ -50,7 +50,7 @@
- 🔥max_pixels: Maximum number of pixels (H\*W) for image preprocessing in multimodal models; no scaling is applied by default.
- tools_prompt: The format used to convert the tool list into the system prompt during agent training; see [Agent Training](./智能体的支持.md). Defaults to 'react_en'.
- padding_side: The padding_side used when training with `batch_size>=2`; optional values are 'left' and 'right', defaulting to 'right'. (When the batch_size of `generate` is >= 2, only left padding is applied.)
- loss_scale: How to weight the loss of tokens during training. Defaults to `'default'`, meaning all responses (including history) are weighted 1 in the cross-entropy loss. See [Pluginization](../Customization/插件化.md) and [Agent Training](./智能体的支持.md) for details.
- loss_scale: How to weight the loss of tokens during training. Defaults to `'default'`, meaning all responses (including history) are weighted 1 in the cross-entropy loss. Optional values are 'default', 'last_round', 'all', plus the agent-specific loss scales 'react', 'agentflan', 'alpha_umi', 'qwen'. See [Pluginization](../Customization/插件化.md) and [Agent Training](./智能体的支持.md) for details.
- sequence_parallel_size: Degree of sequence parallelism. See the [example](https://github.com/modelscope/ms-swift/tree/main/examples/train/sequence_parallel/train.sh).
- use_chat_template: Whether to use the chat template or the generation template; defaults to `True`. `swift pt` automatically switches to the generation template.
- template_backend: Use swift or jinja for templating; if jinja is used, transformers' `apply_chat_template` is applied. Defaults to swift.
@@ -307,7 +307,7 @@ Vera uses the three parameters `target_modules`, `target_regex`, `modules_to_save`.
### RLHF Arguments
RLHF arguments inherit from the [training arguments](#训练参数).

- 🔥rlhf_type: Alignment algorithm type; supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`, `rm`.
- 🔥rlhf_type: Alignment algorithm type; supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`, `rm`, `ppo`.
- ref_model: The original comparison (reference) model used in algorithms such as DPO.
- ref_model_type: Same as model_type.
- ref_model_revision: Same as model_revision.
@@ -324,6 +324,27 @@
- desirable_weight: The loss weight $\lambda_D$ for desirable responses in the KTO algorithm; defaults to `1.`.
- undesirable_weight: The loss weight $\lambda_U$ for undesirable responses in the KTO paper; defaults to `1.`.

#### PPO Arguments
- reward_model: Defaults to None
- reward_adapters: Defaults to `[]`
- reward_model_type: Defaults to None
- reward_model_revision: Defaults to None

The meanings of the following parameters can be found [here](https://huggingface.co/docs/trl/main/ppo_trainer):
- num_ppo_epochs: Defaults to 4
- whiten_rewards: Defaults to False
- kl_coef: Defaults to 0.05
- cliprange: Defaults to 0.2
- vf_coef: Defaults to 0.1
- cliprange_value: Defaults to 0.2
- gamma: Defaults to 1.0
- lam: Defaults to 0.95
- num_mini_batches: Defaults to 1
- local_rollout_forward_batch_size: Defaults to 64
- num_sample_generations: Defaults to 10
- response_length: Defaults to 512
- temperature: Defaults to 0.7
- missing_eos_penalty: Defaults to None

### Inference Arguments

8 changes: 8 additions & 0 deletions docs/source_en/Customization/Custom-dataset.md
@@ -66,6 +66,14 @@ The following provides the recommended dataset format for ms-swift, where the sy
{"messages": [{"role": "system", "content": "You are a useful and harmless math calculator"}, {"role": "user", "content": "What is 1 + 1?"}, {"role": "assistant", "content": "It equals 2"}, {"role": "user", "content": "What about adding 1?"}, {"role": "assistant", "content": "It equals 3"}], "label": true}
```

#### PPO

```jsonl
{"messages": [{"role": "system", "content": "You are a useful and harmless assistant"}, {"role": "user", "content": "Tell me tomorrow's weather"}]}
{"messages": [{"role": "system", "content": "You are a useful and harmless math calculator"}, {"role": "user", "content": "What is 1 + 1?"}, {"role": "assistant", "content": "It equals 2"}, {"role": "user", "content": "What about adding 1?"}]}
{"messages": [{"role": "user", "content": "What is your name?"}]}
```
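
As the examples show, PPO samples are prompt-only: each line ends with a user turn, and the assistant response is sampled during rollout. Below is a minimal sanity-check sketch (not part of ms-swift; the file name is a placeholder) that verifies a JSONL file follows this shape.

```python
import json

def check_ppo_jsonl(path: str) -> None:
    """Sanity-check prompt-only chat data intended for PPO rollouts."""
    with open(path, encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            sample = json.loads(line)
            messages = sample['messages']
            assert messages, f'line {lineno}: empty messages'
            # Responses are sampled during training, so the prompt should not
            # already end with an assistant reply.
            assert messages[-1]['role'] != 'assistant', \
                f'line {lineno}: PPO samples should end with a user turn'

check_ppo_jsonl('ppo_dataset.jsonl')  # hypothetical file name
```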

### Sequence Classification
```jsonl
{"messages": [{"role": "user", "content": "The weather is really nice today"}], "label": 1}
2 changes: 1 addition & 1 deletion docs/source_en/GetStarted/Quick-start.md
@@ -8,7 +8,7 @@ ms-swift is a comprehensive training and deployment framework for large language
- 🍊 Lightweight Training: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel, and more.
- Distributed Training: Supports distributed data parallel (DDP), simple model parallelism via device_map, DeepSpeed ZeRO2 ZeRO3, FSDP, and other distributed training technologies.
- Quantization Training: Provides training for quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
- RLHF Training: Supports human alignment training methods like DPO, CPO, SimPO, ORPO, KTO, RM for both text-based and multimodal large models.
- RLHF Training: Supports human alignment training methods like DPO, CPO, SimPO, ORPO, KTO, RM, PPO for both text-based and multimodal large models.
- 🍓 Multimodal Training: Capable of training models for different modalities such as images, videos, and audios; supports tasks like VQA (Visual Question Answering), Captioning, OCR (Optical Character Recognition), and Grounding.
- Interface-driven Training: Offers training, inference, evaluation, and quantization capabilities through an interface, enabling a complete workflow for large models.
- Plugins and Extensions: Allows customization and extension of models and datasets, and supports customizations for components like loss, metric, trainer, loss-scale, callback, optimizer, etc.
36 changes: 30 additions & 6 deletions docs/source_en/Instruction/Command-line-parameters.md
@@ -50,7 +50,7 @@ The introduction to command line parameters will cover base arguments, atomic ar
- 🔥max_pixels: Maximum pixel count for pre-processing images in multimodal models (H*W), default is no scaling.
- tools_prompt: The list of tools for agent training converted to system format, refer to [Agent Training](./Agent-support.md), default is 'react_en'.
- padding_side: The padding_side used when training with `batch_size >= 2`, with optional values of 'left' and 'right', defaulting to 'right'. (When the batch_size in `generate` is >= 2, only left padding is applied.)
- loss_scale: How to add token loss weight during training. Default is `'default'`, meaning all responses (including history) are treated as 1 for cross-entropy loss. For specifics, see [Pluginization](../Customization/Pluginization.md) and [Agent Training](./Agent-support.md).
- loss_scale: How to weight the loss of tokens during training. Default is `'default'`, meaning all responses (including history) are weighted 1 in the cross-entropy loss. Optional values are 'default', 'last_round', 'all', plus the agent-specific loss scales 'react', 'agentflan', 'alpha_umi', and 'qwen'. For specifics, see [Pluginization](../Customization/Pluginization.md) and [Agent Training](./Agent-support.md).
- sequence_parallel_size: Number of sequence parallelism. Refer to [example](https://github.com/modelscope/ms-swift/tree/main/examples/train/sequence_parallel/train.sh).
- use_chat_template: Use chat template or generation template, default is `True`. `swift pt` is automatically set to the generation template.
- template_backend: Use swift or jinja for inference. If using jinja, it will utilize transformers' `apply_chat_template`. Default is swift.
@@ -311,23 +311,47 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine

RLHF arguments inherit from the [training arguments](#training-arguments).

- 🔥rlhf_type: Alignment algorithm type, supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`, `rm`.
- 🔥rlhf_type: Alignment algorithm type, supports `dpo`, `orpo`, `simpo`, `kto`, `cpo`, `rm`, `ppo`.
- ref_model: Original comparison model in algorithms like DPO.
- ref_model_type: Same as model_type.
- ref_model_revision: Same as model_revision.

- 🔥beta: KL regularization term coefficient, default is `None`, i.e., for `simpo` algorithm default is `2.`, for other algorithms default is `0.1`. Refer to the [documentation](./Human-alignment.md) for specifics.
- label_smoothing: Whether to use DPO smoothing, default value is `0`, generally set between 0~0.5.
-

- 🔥rpo_alpha: Weight for adding sft_loss in DPO, default is `1`. The final loss is `KL_loss + rpo_alpha * sft_loss`.
-

- cpo_alpha: The coefficient of nll loss in CPO/SimPO loss, default is `1.`.
-

- simpo_gamma: Reward margin term in SimPO algorithm, recommended to set between 0.5-1.5 in the paper, default is `1.`.
-

- desirable_weight: Loss weight for desirable response in KTO algorithm $\lambda_D$, default is `1.`.
- undesirable_weight: Loss weight for undesirable response in KTO paper $\lambda_U$, default is `1.`.

#### PPO Arguments

- reward_model: Defaults to None
- reward_adapters: Defaults to `[]`
- reward_model_type: Defaults to None
- reward_model_revision: Defaults to None

The meanings of the following parameters can be referenced [here](https://huggingface.co/docs/trl/main/ppo_trainer):

- num_ppo_epochs: Defaults to 4
- whiten_rewards: Defaults to False
- kl_coef: Defaults to 0.05
- cliprange: Defaults to 0.2
- vf_coef: Defaults to 0.1
- cliprange_value: Defaults to 0.2
- gamma: Defaults to 1.0
- lam: Defaults to 0.95
- num_mini_batches: Defaults to 1
- local_rollout_forward_batch_size: Defaults to 64
- num_sample_generations: Defaults to 10
- response_length: Defaults to 512
- temperature: Defaults to 0.7
- missing_eos_penalty: Defaults to None
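
To make the role of these hyperparameters concrete, here is a schematic NumPy sketch of the quantities computed in one PPO update for a single sampled response. It follows the standard PPO/GAE formulation described in the TRL documentation linked above, not ms-swift's or TRL's actual trainer code, and the function and argument names are illustrative only.

```python
import numpy as np

def ppo_update_sketch(logp_new, logp_old, logp_ref, values_new, values_old, rewards,
                      kl_coef=0.05, cliprange=0.2, cliprange_value=0.2,
                      vf_coef=0.1, gamma=1.0, lam=0.95):
    """Schematic PPO loss for one response; each argument is a 1-D array of per-token values."""
    # Per-token reward, penalized by the KL term against the reference policy (kl_coef).
    r = rewards - kl_coef * (logp_old - logp_ref)
    # (With whiten_rewards=True, r would be normalized to zero mean / unit variance here.)

    # Generalized advantage estimation, controlled by gamma and lam.
    T = len(r)
    advantages = np.zeros(T)
    last = 0.0
    for t in reversed(range(T)):
        next_value = values_old[t + 1] if t + 1 < T else 0.0
        delta = r[t] + gamma * next_value - values_old[t]
        last = delta + gamma * lam * last
        advantages[t] = last
    returns = advantages + values_old

    # Clipped surrogate policy loss (cliprange).
    ratio = np.exp(logp_new - logp_old)
    pg_loss = np.maximum(-advantages * ratio,
                         -advantages * np.clip(ratio, 1 - cliprange, 1 + cliprange)).mean()

    # Clipped value-function loss (cliprange_value), weighted by vf_coef.
    values_clipped = values_old + np.clip(values_new - values_old, -cliprange_value, cliprange_value)
    vf_loss = 0.5 * np.maximum((values_new - returns) ** 2, (values_clipped - returns) ** 2).mean()

    return pg_loss + vf_coef * vf_loss
```

Roughly speaking, num_ppo_epochs and num_mini_batches then determine how many optimization passes and mini-batch splits are made over each rollout batch, while response_length and temperature control generation during rollout and missing_eos_penalty penalizes responses that never emit an end-of-sequence token.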

### Inference Arguments

Inference arguments include the [base arguments](#base-arguments), [merge arguments](#merge-arguments), [vLLM arguments](#vllm-arguments), [LMDeploy arguments](#LMDeploy-arguments), and also contain the following:
7 changes: 3 additions & 4 deletions docs/source_en/Instruction/ReleaseNote3.0.md
@@ -94,7 +94,6 @@ The parameters marked as compatible in version 2.0 have been entirely removed.

## Pending Tasks

1. RM/PPO capabilities are not supported in version 3.0. Please use version 2.6.1.
2. Custom dataset evaluation is not supported in version 3.0. Please use version 2.6.1.
3. Megatron pre-training capabilities are not supported in version 3.0. Please use version 2.6.1.
4. Documentation and README are temporarily incomplete and will be updated.
1. Custom dataset evaluation is not supported in version 3.0. Please use version 2.6.1.
2. Megatron pre-training capabilities are not supported in version 3.0. Please use version 2.6.1.
3. Documentation and README are temporarily incomplete and will be updated.
2 changes: 1 addition & 1 deletion examples/deploy/lora/client.py
@@ -23,5 +23,5 @@ def infer_multilora(engine: InferClient, infer_request: InferRequest):

if __name__ == '__main__':
    engine = InferClient(host='127.0.0.1', port=8000)
    infer_request = InferRequest(messages=[{'role': 'user', 'content': '你是谁'}])
    infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
    infer_multilora(engine, infer_request)
60 changes: 60 additions & 0 deletions examples/infer/demo_hf.py
@@ -0,0 +1,60 @@
def infer_hf():
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
    from modelscope import snapshot_download
    # Load the base model from ModelScope and attach the LoRA adapter with peft.
    model_dir = snapshot_download('Qwen/Qwen2.5-7B-Instruct')
    adapter_dir = snapshot_download('swift/test_lora')
    model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype='auto', device_map='auto')
    model = PeftModel.from_pretrained(model, adapter_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    messages = [{
        'role': 'system',
        'content': 'You are a helpful assistant.'
    }, {
        'role': 'user',
        'content': 'who are you?'
    }]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors='pt').to(model.device)

    generated_ids = model.generate(**model_inputs, max_new_tokens=512, do_sample=False)
    # Keep only the newly generated tokens (drop the prompt) before decoding.
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print(f'response: {response}')
    return response


def infer_swift():
    from swift.llm import get_model_tokenizer, get_template, InferRequest, RequestConfig, PtEngine
    from modelscope import snapshot_download
    from swift.tuners import Swift
    # Load the same base model and adapter through ms-swift, then run inference with PtEngine.
    model_dir = snapshot_download('Qwen/Qwen2.5-7B-Instruct')
    adapter_dir = snapshot_download('swift/test_lora')
    model, tokenizer = get_model_tokenizer(model_dir, device_map='auto')
    model = Swift.from_pretrained(model, adapter_dir)
    template = get_template(model.model_meta.template, tokenizer)
    engine = PtEngine.from_model_template(model, template)

    messages = [{
        'role': 'system',
        'content': 'You are a helpful assistant.'
    }, {
        'role': 'user',
        'content': 'who are you?'
    }]
    # temperature=0 requests deterministic (greedy) decoding, matching do_sample=False above.
    request_config = RequestConfig(max_tokens=512, temperature=0)
    resp_list = engine.infer([InferRequest(messages=messages)], request_config=request_config)
    response = resp_list[0].choices[0].message.content
    print(f'response: {response}')
    return response


if __name__ == '__main__':
    # Both paths decode greedily, so the two responses should be identical.
    response = infer_hf()
    response2 = infer_swift()
    assert response == response2