diff --git a/README.md b/README.md
index 3990ba88e5..e92a380245 100644
--- a/README.md
+++ b/README.md
@@ -103,6 +103,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用
- View the models and datasets supported by Swift. You can check [supported models and datasets](https://github.com/modelscope/swift/blob/main/docs/source/LLM/支持的模型和数据集.md).
- Expand and customize models, datasets, and dialogue templates in Swift, see [Customization and Expansion](https://github.com/modelscope/swift/blob/main/docs/source/LLM/自定义与拓展.md).
- Check command-line parameters for fine-tuning and inference, see [Command-Line parameters](https://github.com/modelscope/swift/blob/main/docs/source/LLM/命令行参数.md).
+- View the comparison of training time and GPU memory usage under different parameters, see [Benchmark](https://github.com/modelscope/swift/blob/main/docs/source/LLM/Benchmark.md).
### Features
diff --git a/README_CN.md b/README_CN.md
index 257f6f8ae1..771ccf68cf 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -101,6 +101,7 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展
- View the models and datasets supported by swift, see [Supported Models and Datasets](https://github.com/modelscope/swift/blob/main/docs/source/LLM/支持的模型和数据集.md).
- **Extend** the models, datasets, and chat templates in swift, see [Customization and Extension](https://github.com/modelscope/swift/blob/main/docs/source/LLM/自定义与拓展.md).
- Look up the command-line parameters for fine-tuning and inference, see [Command-Line Parameters](https://github.com/modelscope/swift/blob/main/docs/source/LLM/命令行参数.md).
+- View the comparison of training time and GPU memory usage under different parameters, see [Benchmark](https://github.com/modelscope/swift/blob/main/docs/source/LLM/Benchmark.md).
### Features
diff --git a/docs/source/LLM/Benchmark.md b/docs/source/LLM/Benchmark.md
new file mode 100644
index 0000000000..7b17d5aaf0
--- /dev/null
+++ b/docs/source/LLM/Benchmark.md
@@ -0,0 +1,515 @@
+# Benchmark
+## Table of Contents
+- [Parameter Settings](#parameter-settings)
+- [Quantization](#quantization)
+- [Max Length](#max-length)
+- [Batch Size](#batch-size)
+- [Use Flash Attn & Gradient Checkpointing](#use-flash-attn--gradient-checkpointing)
+- [Model Type](#model-type)
+- [LoRA Rank & LoRA Target Modules](#lora-rank--lora-target-modules)
+
+## Parameter Settings
+
+We test the impact of various parameters on training speed and training memory usage. The impact of selected parameters on training quality will be added later.
+
+Experimental environment:
+- A100
+- CUDA 11.8
+- python 3.10
+- torch 2.1.1
+- flash_attn 2.3.4
+- xformers 0.0.23
+- auto_gptq 0.5.1
+- bitsandbytes 0.41.3
+
+
+We ran the benchmark with 1000 training samples. The scripts used for the experiments can be found in `scripts/benchmark/test_memory_time/`.
+
+The following command-line settings are shared by all experiments:
+```bash
+ --dataset_test_ratio 0 \
+ --dataset cls-fudan-news-zh \
+ --train_dataset_sample 1000 \
+ --save_strategy no \
+ --check_dataset_strategy warning \
+ --truncation_strategy truncation_left \
+ --preprocess_num_proc 4 \
+```
+
+If the following parameters are not specified, these default values are used:
+```bash
+ --max_length 2048 \
+ --batch_size 1 \
+ --gradient_checkpointing true \
+ --use_flash_attn true \
+ --lora_rank 8 \
+ --lora_target_modules DEFAULT \
+ --quantization_bit 0 \
+```
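+
+Putting the shared settings and the defaults together, the baseline configuration (qwen-7b-chat, LoRA, bf16, max_length 2048) can be reproduced with a command along the following lines. This is a minimal sketch; the exact invocations used for the tables live in `scripts/benchmark/test_memory_time/`:
+```bash
+CUDA_VISIBLE_DEVICES=0 \
+swift sft \
+    --model_type qwen-7b-chat \
+    --sft_type lora \
+    --dataset cls-fudan-news-zh \
+    --dataset_test_ratio 0 \
+    --train_dataset_sample 1000 \
+    --save_strategy no \
+    --check_dataset_strategy warning \
+    --truncation_strategy truncation_left \
+    --preprocess_num_proc 4 \
+    --max_length 2048 \
+    --batch_size 1 \
+    --gradient_checkpointing true \
+    --use_flash_attn true \
+    --lora_rank 8 \
+    --lora_target_modules DEFAULT \
+    --quantization_bit 0
+```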
+
+## Quantization
+The test script is:
+```bash
+swift sft \
+ --model_type {MODEL_TYPE} \
+ --quantization_bit {QUANTIZATION_BIT} \
+ --sft_type lora \
+ ...
+```
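+
+For example, the int4 (bnb) row for qwen-7b-chat below corresponds to a command along these lines (a sketch; here `--quantization_bit 4` is assumed to select bnb-style quantization, while the gptq rows are assumed to come from the corresponding gptq-quantized model variants):
+```bash
+swift sft \
+    --model_type qwen-7b-chat \
+    --quantization_bit 4 \
+    --sft_type lora
+```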
+
+| Model Type [LoRA] | Quantization | Training Speed | GPU Memory |
+| ----------------- | ------------ | -------------- | ---------- |
+| qwen-7b-chat      | bf16         | 7.01min        | 19362MiB   |
+|                   | int4 (gptq)  | 11.37min       | 10504MiB   |
+|                   | int8 (gptq)  | 11.73min       | 13648MiB   |
+|                   | int4 (bnb)   | 9.41min        | 13616MiB   |
+| qwen-14b-chat     | bf16         | 11.73min       | 32186MiB   |
+|                   | int4 (gptq)  | 19.69min       | 14852MiB   |
+|                   | int8 (gptq)  | 20.60min       | 20790MiB   |
+|                   | int4 (bnb)   | 16.30min       | 19278MiB   |
+| qwen-72b-chat     | bf16         | -              | OOM        |
+|                   | int4 (gptq)  | 97.94min       | 46980MiB   |
+|                   | int8 (gptq)  | 103.83min      | 80646MiB   |
+|                   | int4 (bnb)   | 81.72min       | 62430MiB   |
+
+## Max Length
+### Full
+The test script is:
+```bash
+swift sft \
+ --model_type {MODEL_TYPE} \
+ --max_length {MAX_LENGTH} \
+ --sft_type full \
+ ...
+```
+
+| Model Type [FULL] | Max Length | Training Speed | GPU Memory |
+| ----------------- | ---------- | -------------- | ---------- |
+| qwen-1_8b-chat    | 512        | 1.85min        | 18010MiB   |
+|                   | 1024       | 1.98min        | 18072MiB   |
+|                   | 2048       | 2.76min        | 20286MiB   |
+|                   | 4096       | 3.87min        | 26436MiB   |
+|                   | 8192       | 4.86min        | 37530MiB   |
+| qwen-7b-chat      | 512        | 3.89min        | 75213MiB   |
+|                   | 1024       | 5.74min        | 75627MiB   |
+|                   | 2048       | 8.88min        | 76520MiB   |
+|                   | 4096       | 13.94min       | 78986MiB   |
+|                   | 8192       | -              | OOM        |
+| qwen-14b-chat     | 512        | -              | OOM        |
+
+### LoRA
+The test script is:
+```bash
+swift sft \
+ --model_type {MODEL_TYPE} \
+ --max_length {MAX_LENGTH} \
+ --sft_type lora \
+ ...
+```
+
+| Model Type [LoRA] | Max Length | Training Speed | GPU Memory |
+| ----------------- | ---------- | -------------- | ---------- |
+| qwen-1_8b-chat    | 512        | 2.02min        | 4610MiB    |
+|                   | 1024       | 2.07min        | 5576MiB    |
+|                   | 2048       | 2.48min        | 7624MiB    |
+|                   | 4096       | 3.73min        | 17324MiB   |
+|                   | 8192       | 4.48min        | 36620MiB   |
+| qwen-7b-chat      | 512        | 2.52min        | 15926MiB   |
+|                   | 1024       | 4.11min        | 17096MiB   |
+|                   | 2048       | 7.01min        | 19362MiB   |
+|                   | 4096       | 11.12min       | 29264MiB   |
+|                   | 8192       | 13.63min       | 48560MiB   |
+| qwen-14b-chat     | 512        | 3.94min        | 28466MiB   |
+|                   | 1024       | 6.67min        | 29708MiB   |
+|                   | 2048       | 11.73min       | 32186MiB   |
+|                   | 4096       | 18.88min       | 42098MiB   |
+|                   | 8192       | 23.61min       | 61412MiB   |
+
+## Batch Size
+The test script is:
+```bash
+swift sft \
+ --batch_size {BATCH_SIZE} \
+ --model_type qwen-7b-chat \
+ --sft_type lora \
+ ...
+```
+
+| Model Type [LoRA] | Batch Size | Training Speed | GPU Memory |
+| ----------------- | ---------- | -------------- | ---------- |
+| qwen-7b-chat      | 1          | 7.01min        | 19362MiB   |
+|                   | 2          | 8.05min        | 24842MiB   |
+|                   | 4          | 7.95min        | 34842MiB   |
+|                   | 8          | 7.94min        | 54844MiB   |
+
+## Use Flash Attn & Gradient Checkpointing
+The test script is:
+```bash
+swift sft \
+ --use_flash_attn {USE_FLASH_ATTN} \
+ --gradient_checkpointing {GRADIENT_CHECKPOINTING} \
+ --model_type qwen-7b-chat \
+ --sft_type lora \
+ ...
+```
+
+| Model Type [LoRA] | Use Flash Attn | Gradient Checkpointing | Training Speed | GPU Memory |
+| ----------------- | -------------- | ---------------------- | -------------- | ---------- |
+| qwen-7b-chat      | ✔              | ✔                      | 7.01min        | 19362MiB   |
+|                   | ✔              | ✘                      | 5.19min        | 30316MiB   |
+|                   | ✘              | ✔                      | 9.94min        | 19422MiB   |
+|                   | ✘              | ✘                      | 7.37min        | 42260MiB   |
+
+## Model Type
+The test script is:
+```bash
+swift sft \
+ --model_type {MODEL_TYPE} \
+ --sft_type lora \
+ ...
+```
+
+| Model Type [LoRA]         | Training Speed | GPU Memory |
+| ------------------------- | -------------- | ---------- |
+| qwen-1_8b-chat            | 2.48min        | 7624MiB    |
+| qwen-7b-chat              | 7.01min        | 19362MiB   |
+| qwen-14b-chat             | 11.73min       | 32186MiB   |
+| chatglm2-6b               | 7.14min        | 14540MiB   |
+| chatglm3-6b               | 7.19min        | 14612MiB   |
+| baichuan2-7b              | 8.61min        | 19254MiB   |
+| baichuan2-13b             | 16.37min       | 33118MiB   |
+| yi-6b-chat                | 8.18min        | 14386MiB   |
+| yi-34b-chat               | 30.77min       | 70482MiB   |
+| openbuddy-mistral-7b-chat | 9.08min        | 16618MiB   |
+| openbuddy-zephyr-7b-chat  | 9.10min        | 16618MiB   |
+
+## LoRA Rank & LoRA Target Modules
+The test script is:
+```bash
+swift sft \
+ --lora_rank {LORA_RANK} \
+ --lora_target_modules {LORA_TARGET_MODULES} \
+ --model_type qwen-7b-chat \
+ --sft_type lora \
+ ...
+```
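+
+For instance, the last row of the table below (rank 8 over all linear layers) corresponds to a command along these lines (a sketch, reusing the defaults above):
+```bash
+swift sft \
+    --lora_rank 8 \
+    --lora_target_modules ALL \
+    --model_type qwen-7b-chat \
+    --sft_type lora
+```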
+
+| Model Type [LoRA] | LoRA Rank | LoRA Target Modules | Training Speed | GPU Memory | Trainable Params |
+| ----------------- | --------- | ------------------- | -------------- | ---------- | ---------------- |
+| qwen-7b-chat      | 2         | DEFAULT (c_attn)    | 7.01min        | 19300MiB   | 1.05M            |
+|                   | 8         | DEFAULT             | 7.01min        | 19362MiB   | 4.19M            |
+|                   | 64        | DEFAULT             | 7.01min        | 20728MiB   | 33.55M           |
+|                   | 8         | ALL (all linear)    | 9.36min        | 19670MiB   | 17.89M           |
+
diff --git "a/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
index c1c20d4a91..7d9b86b6a3 100644
--- "a/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
+++ "b/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
@@ -75,6 +75,7 @@
- `--logging_dir`: Default is `None`, i.e. it is set to `f'{self.output_dir}/runs'`, the directory where the tensorboard files are stored.
- `--check_model_is_latest`: Whether to check that the model is the latest version, default `True`. Set this to `False` if you need to train without a network connection.
- `--save_on_each_node`: Takes effect during multi-node training, default `True`.
+- `--save_strategy`: The strategy for saving checkpoints, default `'steps'`. Available values: 'steps', 'no'.
- `--max_new_tokens`: Default `2048`. This parameter only takes effect when `predict_with_generate` is set to True.
- `--do_sample`: Default `True`. This parameter only takes effect when `predict_with_generate` is set to True.
- `--temperature`: Default `0.3`. This parameter only takes effect when `predict_with_generate` is set to True.
diff --git "a/docs/source/LLM/\350\207\252\346\210\221\350\256\244\347\237\245\345\276\256\350\260\203\346\234\200\344\275\263\345\256\236\350\267\265.md" "b/docs/source/LLM/\350\207\252\346\210\221\350\256\244\347\237\245\345\276\256\350\260\203\346\234\200\344\275\263\345\256\236\350\267\265.md"
index 4a436cecc4..0c96abe219 100644
--- "a/docs/source/LLM/\350\207\252\346\210\221\350\256\244\347\237\245\345\276\256\350\260\203\346\234\200\344\275\263\345\256\236\350\267\265.md"
+++ "b/docs/source/LLM/\350\207\252\346\210\221\350\256\244\347\237\245\345\276\256\350\260\203\346\234\200\344\275\263\345\256\236\350\267\265.md"
@@ -7,7 +7,7 @@
- [Inference Before Fine-Tuning](#微调前推理)
- [Fine-Tuning](#微调)
- [Inference After Fine-Tuning](#微调后推理)
-- [Web-UI](#Web-UI)
+- [Web-UI](#web-ui)
- [Learn More](#了解更多)
## Environment Setup
diff --git a/examples/pytorch/llm/scripts/qwen_7b_chat/full/sft.sh b/examples/pytorch/llm/scripts/qwen_7b_chat/full/sft.sh
index db026911cc..8811ff8e3b 100644
--- a/examples/pytorch/llm/scripts/qwen_7b_chat/full/sft.sh
+++ b/examples/pytorch/llm/scripts/qwen_7b_chat/full/sft.sh
@@ -13,3 +13,4 @@ swift sft \
--use_flash_attn true \
--only_save_model true \
--dataset codefuse-evol-instruction-zh \
+ --preprocess_num_proc 4 \
diff --git a/examples/pytorch/llm/scripts/qwen_7b_chat/full_freeze_ddp/sft.sh b/examples/pytorch/llm/scripts/qwen_7b_chat/full_freeze_ddp/sft.sh
index 0d44d72e81..2056a17567 100644
--- a/examples/pytorch/llm/scripts/qwen_7b_chat/full_freeze_ddp/sft.sh
+++ b/examples/pytorch/llm/scripts/qwen_7b_chat/full_freeze_ddp/sft.sh
@@ -15,3 +15,4 @@ swift sft \
--only_save_model true \
--dataset codefuse-evol-instruction-zh \
--freeze_parameters 0.2 \
+ --preprocess_num_proc 4 \
diff --git a/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp/sft.sh b/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp/sft.sh
index a28ad92c95..a9b2c93532 100644
--- a/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp/sft.sh
+++ b/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp/sft.sh
@@ -31,3 +31,4 @@ python llm_sft.py \
--hub_model_id qwen-7b-chat-full \
--hub_private_repo true \
--hub_token 'your-sdk-token' \
+ --preprocess_num_proc 4 \
diff --git a/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp_ddp/sft.sh b/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp_ddp/sft.sh
index 8d1277372b..d3d36f11ee 100644
--- a/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp_ddp/sft.sh
+++ b/examples/pytorch/llm/scripts/qwen_7b_chat/full_mp_ddp/sft.sh
@@ -36,3 +36,4 @@ torchrun \
--hub_model_id qwen-7b-chat-full \
--hub_private_repo true \
--hub_token 'your-sdk-token' \
+ --preprocess_num_proc 4 \
diff --git a/examples/pytorch/llm/scripts/qwen_audio_chat/full_mp/sft.sh b/examples/pytorch/llm/scripts/qwen_audio_chat/full_mp/sft.sh
index 1231cac1ea..fb15756c0f 100644
--- a/examples/pytorch/llm/scripts/qwen_audio_chat/full_mp/sft.sh
+++ b/examples/pytorch/llm/scripts/qwen_audio_chat/full_mp/sft.sh
@@ -13,3 +13,4 @@ swift sft \
--use_flash_attn true \
--only_save_model true \
--dataset aishell1-mini-zh \
+ --lazy_tokenize true \
diff --git a/examples/pytorch/llm/scripts/qwen_audio_chat/full_mp_ddp/sft.sh b/examples/pytorch/llm/scripts/qwen_audio_chat/full_mp_ddp/sft.sh
index c14da31573..22b9709102 100644
--- a/examples/pytorch/llm/scripts/qwen_audio_chat/full_mp_ddp/sft.sh
+++ b/examples/pytorch/llm/scripts/qwen_audio_chat/full_mp_ddp/sft.sh
@@ -14,3 +14,4 @@ swift sft \
--use_flash_attn true \
--only_save_model true \
--dataset aishell1-mini-zh \
+ --lazy_tokenize true \
diff --git a/examples/pytorch/llm/scripts/qwen_audio_chat/lora/sft.sh b/examples/pytorch/llm/scripts/qwen_audio_chat/lora/sft.sh
index b3e3840930..b2e7827de2 100644
--- a/examples/pytorch/llm/scripts/qwen_audio_chat/lora/sft.sh
+++ b/examples/pytorch/llm/scripts/qwen_audio_chat/lora/sft.sh
@@ -35,3 +35,4 @@ python llm_sft.py \
--hub_model_id qwen-audio-chat-lora \
--hub_private_repo true \
--hub_token 'your-sdk-token' \
+ --lazy_tokenize true \
diff --git a/examples/pytorch/llm/scripts/qwen_audio_chat/lora_ddp_ds/sft.sh b/examples/pytorch/llm/scripts/qwen_audio_chat/lora_ddp_ds/sft.sh
index e1e1e93808..b67975682b 100644
--- a/examples/pytorch/llm/scripts/qwen_audio_chat/lora_ddp_ds/sft.sh
+++ b/examples/pytorch/llm/scripts/qwen_audio_chat/lora_ddp_ds/sft.sh
@@ -42,3 +42,4 @@ torchrun \
--hub_token 'your-sdk-token' \
--deepspeed_config_path 'ds_config/zero2.json' \
--only_save_model true \
+ --lazy_tokenize true \
diff --git a/scripts/benchmark/test_memory_time/run_loop.py b/scripts/benchmark/test_memory_time/run_loop.py
new file mode 100644
index 0000000000..b77e9fff91
--- /dev/null
+++ b/scripts/benchmark/test_memory_time/run_loop.py
@@ -0,0 +1,27 @@
+# CUDA_VISIBLE_DEVICES=0 nohup python scripts/benchmark/test_memory_time/run_loop.py &> 0.out &
+
+import os
+# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+import subprocess
+from typing import List
+
+from swift.utils import read_from_jsonl, write_to_jsonl
+
+
+def test_memory_time_loop(train_kwargs_jsonl: str) -> None:
+    # Treat the jsonl file as a rotating queue of benchmark configurations:
+    # pop the first entry, move it to the back, then run it in a subprocess.
+    # An empty first entry acts as a sentinel that stops the loop.
+    while True:
+        obj_list = read_from_jsonl(train_kwargs_jsonl)
+        if len(obj_list[0]) == 0:
+            break
+        obj: List[str] = obj_list.pop(0)
+        obj_list.append(obj)
+        write_to_jsonl(train_kwargs_jsonl, obj_list)
+        # Run each configuration in a fresh process so that GPU memory is
+        # fully released between runs.
+        ret = subprocess.run([
+            'python', 'scripts/benchmark/test_memory_time/run_single.py', *obj
+        ])
+        assert ret.returncode == 0
+
+
+if __name__ == '__main__':
+    jsonl_path = 'scripts/benchmark/test_memory_time/run.jsonl'
+    test_memory_time_loop(jsonl_path)
diff --git a/scripts/benchmark/test_memory_time/run_single.py b/scripts/benchmark/test_memory_time/run_single.py
new file mode 100644
index 0000000000..300dc6ac5f
--- /dev/null
+++ b/scripts/benchmark/test_memory_time/run_single.py
@@ -0,0 +1,69 @@
+import time
+from dataclasses import dataclass, field
+from typing import *
+
+import numpy as np
+import torch
+
+from swift.llm import *
+from swift.utils import *
+
+
+@dataclass
+class TrainArguments(SftArguments):
+    run_time: int = 1
+    global_seed: int = 42
+
+    def __post_init__(self):
+        # Note: overriding __post_init__ skips SftArguments' own
+        # post-processing; only the two fields needed here are filled in.
+        if self.model_type is None:
+            self.model_type = 'qwen-7b-chat'
+        if self.use_flash_attn is None:
+            self.use_flash_attn = True
+        return
+
+
+def get_non_default_args(train_args) -> Dict[str, Any]:
+    # Collect the arguments that differ from their defaults; model_type and
+    # use_flash_attn are always kept so they are passed through explicitly.
+    train_args_default = train_args.__class__()
+    res = {}
+    for k, v in train_args.__dict__.items():
+        v_default = getattr(train_args_default, k)
+        if v != v_default or k in {'use_flash_attn', 'model_type'}:
+            res[k] = v
+    return res
+
+
+def test_memory_time(train_args: TrainArguments) -> Dict[str, Any]:
+    random_state = np.random.RandomState(train_args.global_seed)
+    args_kwargs = get_non_default_args(train_args)
+    print(f'args_kwargs: {args_kwargs}')
+    for i in range(train_args.run_time):
+        start_t = time.time()
+        sft_args = SftArguments(
+            dataset_test_ratio=0,
+            dataset=DatasetName.cls_fudan_news_zh,
+            train_dataset_sample=1000,
+            save_strategy='no',
+            check_dataset_strategy='warning',
+            truncation_strategy='truncation_left',
+            seed=get_seed(random_state),
+            preprocess_num_proc=4,
+            **args_kwargs)
+        output = sft_main(sft_args)
+        t = (time.time() - start_t) / 60  # min
+        # Peak reserved GPU memory over the run, in MiB.
+        max_memory = torch.cuda.max_memory_reserved() / 1024**2
+        torch.cuda.empty_cache()
+        output = {
+            'time': f'{t}min',
+            'memory': f'{max_memory}MiB',
+            'train_args': check_json_format(args_kwargs),
+            'model_info': output['model_info'],
+        }
+        append_to_jsonl('scripts/benchmark/test_memory_time/result.jsonl',
+                        output)
+        print(output)
+    return output
+
+
+test_memory_time_main = get_main(TrainArguments, test_memory_time)
+
+if __name__ == '__main__':
+    test_memory_time_main()
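+
+# Example usage (illustrative; any SftArguments flag, plus the run_time and
+# global_seed fields defined above, can be passed on the command line):
+#   python scripts/benchmark/test_memory_time/run_single.py \
+#       --model_type qwen-7b-chat --quantization_bit 4 --run_time 1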
diff --git a/swift/llm/infer.py b/swift/llm/infer.py
index a4688521ad..219d7c1805 100644
--- a/swift/llm/infer.py
+++ b/swift/llm/infer.py
@@ -11,7 +11,7 @@
from transformers import PreTrainedModel

from swift.tuners import Swift
-from swift.utils import (append_to_jsonl, get_logger, print_model_info,
+from swift.utils import (append_to_jsonl, get_logger, get_model_info,
                         read_multi_line, seed_everything, show_layers)
from .utils import (InferArguments, Template, get_dataset, get_model_tokenizer,
                    get_template, inference, inference_stream,
@@ -124,7 +124,7 @@ def prepare_model_template(
        model = Swift.from_pretrained(
            model, args.ckpt_dir, inference_mode=True)
-    print_model_info(model)
+    logger.info(get_model_info(model))
    show_layers(model)

    template: Template = get_template(args.template_type, tokenizer,
diff --git a/swift/llm/rome.py b/swift/llm/rome.py
index e656c4211b..06a69c012e 100644
--- a/swift/llm/rome.py
+++ b/swift/llm/rome.py
@@ -4,7 +4,7 @@
from modelscope import GenerationConfig

from swift.tuners import Swift
-from swift.utils import (get_logger, print_model_info, seed_everything,
+from swift.utils import (get_logger, get_model_info, seed_everything,
                         show_layers)
from ..tuners.rome import RomeConfig
from .utils import (RomeArguments, Template, get_dataset, get_model_tokenizer,
@@ -53,7 +53,7 @@ def rome_infer(args: RomeArguments) -> None:
    model = Swift.prepare_model(model, config, inference_mode=True)
    show_layers(model)
-    print_model_info(model)
+    logger.info(get_model_info(model))

    # Inference
    template: Template = get_template(args.template_type, tokenizer,
diff --git a/swift/llm/sft.py b/swift/llm/sft.py
index 7d813f6a4d..aaa25a8401 100644
--- a/swift/llm/sft.py
+++ b/swift/llm/sft.py
@@ -13,9 +13,10 @@
                          LoRAConfig, NEFTuneConfig, Swift)
from swift.utils import (check_json_format, compute_acc_metrics,
                         compute_nlg_metrics, freeze_model_parameters,
-                         get_dist_setting, get_logger, is_ddp_plus_mp, is_dist,
-                         is_master, plot_images, preprocess_logits_for_metrics,
-                         print_model_info, seed_everything, show_layers)
+                         get_dist_setting, get_logger, get_model_info,
+                         is_ddp_plus_mp, is_dist, is_master, plot_images,
+                         preprocess_logits_for_metrics, seed_everything,
+                         show_layers)
from .utils import (LazyLLMDataset, SftArguments, Template,
                    add_self_cognition_dataset, data_collate_fn, dataset_map,
                    find_all_linear_for_lora, get_additional_saved_files,
@@ -124,7 +125,8 @@ def llm_sft(args: SftArguments) -> str:
        logger.info(f'neftune_config: {neftune_config}')

    show_layers(model)
-    print_model_info(model)
+    model_info = get_model_info(model)
+    logger.info(model_info)
    logger.info(model)

    # Loading Dataset
@@ -220,7 +222,7 @@ def llm_sft(args: SftArguments) -> str:
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_ratio=args.warmup_ratio,
        logging_steps=args.logging_steps,
-        save_strategy=IntervalStrategy.STEPS,
+        save_strategy=args.save_strategy,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        remove_unused_columns=False,
@@ -251,7 +253,8 @@ def llm_sft(args: SftArguments) -> str:
        deepspeed=args.deepspeed,
        additional_saved_files=additional_saved_files,
        disable_tqdm=args.disable_tqdm,
-        save_on_each_node=args.save_on_each_node)
+        save_on_each_node=args.save_on_each_node,
+        acc_strategy=args.acc_strategy)

    if args.gradient_checkpointing:
        model.enable_input_require_grads()
@@ -272,7 +275,9 @@ def llm_sft(args: SftArguments) -> str:
        trainer_kwargs['compute_metrics'] = partial(
            compute_nlg_metrics, tokenizer=tokenizer)
    else:
-        trainer_kwargs['compute_metrics'] = compute_acc_metrics
+        compute_metrics = partial(
+            compute_acc_metrics, acc_strategy=args.acc_strategy)
+        trainer_kwargs['compute_metrics'] = compute_metrics
        trainer_kwargs[
            'preprocess_logits_for_metrics'] = preprocess_logits_for_metrics
    if args.check_model_is_latest is False:
@@ -322,4 +327,5 @@ def llm_sft(args: SftArguments) -> str:
        'best_metric': trainer.state.best_metric,
        'global_step': trainer.state.global_step,
        'log_history': trainer.state.log_history,
+        'model_info': model_info,
    }
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index 28a697e1d6..48e67c174b 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -149,7 +149,11 @@ class SftArguments:
    logging_dir: Optional[str] = None
    report_to: Optional[List[str]] = None
    check_model_is_latest: bool = True
+    acc_strategy: str = field(
+        default='token', metadata={'choices': ['token', 'sentence']})
    save_on_each_node: bool = True
+    save_strategy: str = field(
+        default='steps', metadata={'choices': ['steps', 'no']})

    # generation config
    max_new_tokens: int = 2048
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index 565653f950..f19ce091eb 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -266,11 +266,16 @@ def dataset_map(dataset: HfDataset,
    return LLMDataset(data)


-def stat_dataset(llm_dataset: HfDataset) -> None:
+def stat_dataset(llm_dataset: Dataset) -> None:
    """Run statistical analysis on the token lengths of the dataset."""
    _token_len = []
-    for d in llm_dataset:
-        _token_len.append(len(d['input_ids']))
+    if isinstance(llm_dataset, HfDataset):
+        # HfDataset: fetching the column once is much faster than row access.
+        input_ids = llm_dataset['input_ids']
+        for ii in input_ids:
+            _token_len.append(len(ii))
+    else:
+        for d in llm_dataset:
+            _token_len.append(len(d['input_ids']))
    _, stat_str = stat_array(_token_len)
    logger.info(f'Dataset Token Length: {stat_str}')
diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py
index f83378d8cb..d526a00509 100644
--- a/swift/trainers/arguments.py
+++ b/swift/trainers/arguments.py
@@ -19,6 +19,8 @@ class SwiftArgumentsMixin:
            'choices':
            {'end', 'push_best', 'push_last', 'checkpoint', 'all_checkpoints'}
        })
+    acc_strategy: str = field(
+        default='token', metadata={'choices': ['token', 'sentence']})
    additional_saved_files: Optional[List[str]] = None

    def __post_init__(self):
diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py
index 29bbc6ebda..620c122184 100644
--- a/swift/trainers/trainers.py
+++ b/swift/trainers/trainers.py
@@ -185,7 +185,17 @@ def compute_loss(self, model, inputs, return_outputs=None):
        preds = outputs.logits.argmax(dim=2)[..., :-1]
        labels = inputs['labels'][..., 1:]
        masks = labels != -100
-        acc: Tensor = (preds[masks] == labels[masks]).float().mean()
+        acc_strategy = getattr(self.args, 'acc_strategy', 'token')
+        acc: Tensor
+        if acc_strategy == 'sentence':
+            # Sentence-level accuracy: a sample counts as correct only if
+            # every unmasked token in it is predicted correctly.
+            acc_list = []
+            for i, m in enumerate(masks):
+                acc_list.append(
+                    torch.all(preds[i, m] == labels[i, m]).to(
+                        torch.int64).item())
+            acc = torch.tensor(acc_list, device=preds.device).float().mean()
+        else:
+            # Token-level accuracy over all unmasked positions.
+            acc = (preds[masks] == labels[masks]).float().mean()
        if model.training:
            if 'acc' not in self._custom_metrics:
                self._custom_metrics['acc'] = torch.tensor(0.).to(
diff --git a/swift/utils/__init__.py b/swift/utils/__init__.py
index 99be7f3858..8d53c4417e 100644
--- a/swift/utils/__init__.py
+++ b/swift/utils/__init__.py
@@ -9,9 +9,9 @@
from .tb_utils import (TB_COLOR, TB_COLOR_SMOOTH, plot_images,
                       read_tensorboard_file, tensorboard_smoothing)
from .torch_utils import (broadcast_string, freeze_model_parameters,
-                          get_dist_setting, is_ddp_plus_mp, is_dist,
-                          is_local_master, is_master, is_on_same_device,
-                          print_model_info, seed_everything, show_layers,
+                          get_dist_setting, get_model_info, is_ddp_plus_mp,
+                          is_dist, is_local_master, is_master,
+                          is_on_same_device, seed_everything, show_layers,
                           time_synchronize)
from .utils import (add_version_to_work_dir, check_json_format, lower_bound,
                    parse_args, read_multi_line, test_time, upper_bound)
diff --git a/swift/utils/metric.py b/swift/utils/metric.py
index 5ffa3a5a13..80f9126317 100644
--- a/swift/utils/metric.py
+++ b/swift/utils/metric.py
@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

-from typing import Dict
+from typing import Dict, Literal

import jieba
import numpy as np
@@ -58,13 +58,20 @@ def _decode(tokens, ignore_pad_token_for_loss=False):
    return score_dict


-def compute_acc_metrics(eval_prediction: EvalPrediction) -> Dict[str, Tensor]:
+def compute_acc_metrics(
+        eval_prediction: EvalPrediction,
+        acc_strategy: Literal['token', 'sentence'] = 'token'
+) -> Dict[str, Tensor]:
    labels = eval_prediction.label_ids[..., 1:]
    predictions = eval_prediction.predictions[..., :-1]
    masks = labels != -100
-    predictions = predictions[masks]
-    labels = labels[masks]
-    acc = np.mean((predictions == labels).astype(np.float64))
+    if acc_strategy == 'sentence':
+        # Sentence-level: a sample is correct only if all of its unmasked
+        # tokens are predicted correctly.
+        acc_list = []
+        for i, m in enumerate(masks):
+            acc_list.append(np.all(predictions[i, m] == labels[i, m]))
+        acc = np.mean(np.array(acc_list))
+    else:
+        # Token-level accuracy over all unmasked positions.
+        acc = np.mean((predictions[masks] == labels[masks]).astype(np.float64))
    return {'acc': acc}
diff --git a/swift/utils/torch_utils.py b/swift/utils/torch_utils.py
index ea459275cb..5524e74f5b 100644
--- a/swift/utils/torch_utils.py
+++ b/swift/utils/torch_utils.py
@@ -50,7 +50,7 @@ def seed_everything(seed: Optional[int] = None,
    return seed


-def print_model_info(model: Module, name: Optional[str] = None) -> None:
+def get_model_info(model: Module, name: Optional[str] = None) -> str:
    if name is None:
        name = model.__class__.__name__
@@ -65,7 +65,7 @@ def print_model_info(model: Module, name: Optional[str] = None) -> None:
         f'{n_params:.4f}M Params ({n_grads:.4f}M Trainable '
         f'[{100 * n_grads / n_params:.4f}%]), '
         f'{n_buffers:.4f}M Buffers.')
-    logger.info(s)
+    return s


def find_sub_module(module: torch.nn.Module,