diff --git a/README.md b/README.md
index 11836fdf6a..851f010f2d 100644
--- a/README.md
+++ b/README.md
@@ -62,6 +62,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用
 ## 🎉 News
+- 2024.1.4: Update [Benchmark](https://github.com/modelscope/swift/blob/main/docs/source/LLM/Benchmark.md) to facilitate viewing the training speed and GPU memory required for different models.
 - 🔥 2023.12.29: Support web-ui for training and inference, use `swift web-ui` after the installation of ms-swift.
 - 🔥 2023.12.29: Support DPO RLHF(Reinforcement Learning from Human Feedback) and two datasets: AI-ModelScope/stack-exchange-paired and AI-ModelScope/hh-rlhf for this task. Use [this script](https://github.com/modelscope/swift/blob/v1.5.0/examples/pytorch/llm/scripts/dpo/lora/dpo.sh) to start training!
 - 🔥 2023.12.28: Support SCEdit! This framework can easily reduce memory usage in training and inference, and replace ControlNet for controllable image generating scenarios, view the following chapter for details.
@@ -75,6 +76,8 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用
 - 2023.12.7: Support [Multi-Node DDP training](https://github.com/modelscope/swift/blob/main/docs/source/LLM/LLM%E5%BE%AE%E8%B0%83%E6%96%87%E6%A1%A3.md#%E4%BD%BF%E7%94%A8cli).
 - 2023.12.4: Supported models: zephyr-7b-beta-chat, openbuddy-zephyr-7b-chat. Supported datasets: hc3-zh, hc3-en.
 - 🔥 2023.12.2: [Best Practices for Self-cognition Fine-tuning](https://github.com/modelscope/swift/blob/main/docs/source/LLM/自我认知微调最佳实践.md), **10 minutes for self-cognition fine-tuning for LLM**, creating a LLM that is specific to oneself.
+<details><summary>More</summary>
+
 - 🔥 2023.11.30: Support for training and inference of the **qwen-1_8b**, **qwen-72b**, and **qwen-audio** model series. The corresponding shell scripts can be viewed at [qwen_1_8b_chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/qwen_1_8b_chat), [qwen_72b_chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/qwen_72b_chat), [qwen_audio_chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/qwen_audio_chat).
 - 🔥 2023.11.29: Support the training and inference for **AnimateDiff**
 - 🔥 2023.11.24: Support for **yi-34b-chat**, **codefuse-codellama-34b-chat**: The corresponding shell script can be found in [yi_34b_chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/yi_34b_chat), [codefuse_codellama_34b_chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/codefuse_codellama_34b_chat).
@@ -86,8 +89,6 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用
 - 🔥 2023.11.10: Support for **bluelm** series models: bluelm-7b, bluelm-7b-chat, bluelm-7b-32k, bluelm-7b-chat-32k. The corresponding shell script can be found in [bluelm_7b_chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/bluelm_7b_chat).
 - 🔥 2023.11.08: Support the finetuning of **xverse-65b** model, scripts can be found at: [xverse_65b](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/xverse_65b).
 - 🔥 2023.11.07: Support the finetuning of **yi-6b**, **yi-34b** model, scripts can be found at: [yi_6b](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/yi_6b), [yi_34b](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/yi_34b).
-<details><summary>More</summary>
-
 - 🔥 2023.10.30: Support **QA-LoRA** and **LongLoRA** to decrease memory usage in training.
 - 🔥 2023.10.30: Support **ROME**(Rank One Model Editing) to add/modify knowledges, training is not needed!
 - 2023.10.30: Support for **skywork-13b** series models: skywork-13b, skywork-13b-chat. The corresponding shell script can be found in [skywork_13b](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/skywork_13b).
diff --git a/README_CN.md b/README_CN.md
index 551ad3e0ee..2de3733a99 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -60,6 +60,7 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展
 用户可以查看 [SWIFT官方文档](docs/source/GetStarted/快速使用.md) 来了解详细信息。
 ## 🎉 新闻
+- 2024.1.4: 更新[Benchmark](https://github.com/modelscope/swift/blob/main/docs/source/LLM/Benchmark.md), 方便查看不同模型训练的速度和所需显存.
 - 🔥 2023.12.29: 支持web-ui进行sft训练和推理,安装ms-swift后使用`swift web-ui`开启
 - 🔥 2023.12.29: 支持 DPO RLHF(Reinforcement Learning from Human Feedback) 和两个用于此任务的数据集: AI-ModelScope/stack-exchange-paired 以及 AI-ModelScope/hh-rlhf. 使用[这个脚本](https://github.com/modelscope/swift/blob/v1.5.0/examples/pytorch/llm/scripts/dpo/lora/dpo.sh)开启训练!
 - 🔥 2023.12.28: 支持SCEdit! 该tuner可显著降低U-Net中的显存占用,并支持低显存可控图像生成(取代ControlNet),阅读下面的章节来了解详细信息
@@ -73,6 +74,8 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展
 - 2023.12.7: 支持[Multi-Node DDP训练](https://github.com/modelscope/swift/blob/main/docs/source/LLM/LLM%E5%BE%AE%E8%B0%83%E6%96%87%E6%A1%A3.md#%E4%BD%BF%E7%94%A8cli).
 - 2023.12.5: 支持模型: zephyr-7b-beta-chat, openbuddy-zephyr-7b-chat. 支持数据集: hc3-zh, hc3-en.
 - 🔥 2023.12.2: [自我认知微调最佳实践](https://github.com/modelscope/swift/blob/main/docs/source/LLM/自我认知微调最佳实践.md), **10分钟对大模型进行自我认知微调**, 创建专属于自己的大模型.
+<details><summary>更多</summary>
+
 - 🔥 2023.11.30: 支持**qwen-1_8b**, **qwen-72b**, **qwen-audio**系列模型的训练的推理. 对应的sh脚本可以查看[qwen_1_8b_chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/qwen_1_8b_chat), [qwen_72b_chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/qwen_72b_chat), [qwen_audio_chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/qwen_audio_chat)
 - 🔥 2023.11.29: 支持**AnimateDiff**的训练和推理
 - 🔥 2023.11.24: 支持**yi-34b-chat**, **codefuse-codellama-34b-chat**模型. 对应的sh脚本可以查看[yi_34b_chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/yi_34b_chat), [codefuse_codellama_34b_chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/codefuse_codellama_34b_chat).
@@ -84,8 +87,6 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展
 - 🔥 2023.11.10: 支持**bluelm**系列模型: bluelm-7b, bluelm-7b-chat, bluelm-7b-32k, bluelm-7b-chat-32k. 对应的sh脚本可以查看[bluelm_7b_chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/bluelm_7b_chat).
 - 🔥 2023.11.08: 支持**xverse-65b**模型的训练和推理流程,脚本在[xverse_65b](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/xverse_65b).
 - 🔥 2023.11.07: 支持**yi-6b**, **yi-34b**模型的训练和推理流程,脚本在[yi_6b](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/yi_6b), [yi_34b](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/yi_34b).
-<details><summary>更多</summary>
-
 - 🔥 2023.10.30: 支持 **QA-LoRA** 和 **LongLoRA**两种新的tuners.
 - 🔥 2023.10.30: 支持使用**ROME**(Rank One Model Editing)来编辑模型,在无需训练的情况下即可给模型灌注新知识!
 - 2023.10.30: 支持**skywork-13b**系列模型: skywork-13b, skywork-13b-chat. 对应的sh脚本可以查看[skywork_13b](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/skywork_13b).
diff --git a/docs/source/LLM/Benchmark.md b/docs/source/LLM/Benchmark.md
index 4691e24b4b..98097ac563 100644
--- a/docs/source/LLM/Benchmark.md
+++ b/docs/source/LLM/Benchmark.md
@@ -2,16 +2,13 @@
 ## 目录
 - [参数设置](#参数设置)
 - [量化](#量化)
-- [Max Length](#max-length)
+- [Model Type & Max Length](#model-type--max-length)
 - [Batch Size](#batch-size)
 - [Use Flash Attn & Gradient Checkpointing](#use-flash-attn--gradient-checkpointing)
-- [Model Type](#model-type)
 - [LoRA Rank & LoRA Target Modules](#lora-rank--lora-target-modules)
+- [Gradient Accumulation Steps](#gradient-accumulation-steps)
 
 ## 参数设置
-
-测试参数对于训练速度和训练内存使用的影响. 后续会补充部分参数对训练效果的影响.
-
 实验环境:
 - A100
 - CUDA 11.8
@@ -23,15 +20,12 @@
 - bitsandbytes 0.41.3.post2
 
 
-实验使用脚本可以查看`scripts/benchmark/test_memory_time/`.
-
 以下为所有实验的相同命令行设置部分:
 ```bash
     --dataset_test_ratio 0 \
     --dataset cls-fudan-news-zh \
     --save_strategy no \
     --check_dataset_strategy warning \
-    --truncation_strategy truncation_left \
     --preprocess_num_proc 4 \
 ```
@@ -44,10 +38,12 @@
     --lora_rank 8 \
     --lora_target_modules DEFAULT \
     --quantization_bit 0 \
+    --gradient_accumulation_steps 16 \
 ```
 
-对应测试数据集的token数统计量(由qwen的tokenizer获取): 3234.4±2547.5, min=91, max=19548
+对应测试数据集的token数统计量(由qwen的tokenizer获取): 3234.4±2547.5, min=91, max=19548.
+实验使用脚本可以查看`scripts/benchmark/test_memory_time/`.
 
 ## 量化
 测试脚本为:
@@ -131,7 +127,7 @@ swift sft \
 
 
-## Max Length
+## Model Type & Max Length
 ### LoRA
 测试脚本为:
 ```bash
@@ -227,6 +223,188 @@ swift sft \
         <td>0.79</td>
         <td>60.74</td>
     </tr>
+    <tr><td rowspan="5">qwen-72b-chat (2*A100)</td><td>512</td><td>1.41</td><td>67.68+73.07</td></tr>
+    <tr><td>1024</td><td>1.02</td><td>70.25+77.11</td></tr>
+    <tr><td>2048</td><td>0.59</td><td>73.71+78.54</td></tr>
+    <tr><td>4096</td><td>-</td><td>OOM</td></tr>
+    <tr><td>8192</td><td>-</td><td>OOM</td></tr>
+    <tr><td rowspan="5">chatglm3-6b</td><td>512</td><td>6.72</td><td>13.94</td></tr>
+    <tr><td>1024</td><td>6.16</td><td>12.99</td></tr>
+    <tr><td>2048</td><td>4.20</td><td>17.20</td></tr>
+    <tr><td>4096</td><td>1.92</td><td>29.80</td></tr>
+    <tr><td>8192</td><td>1.24</td><td>66.82</td></tr>
+    <tr><td rowspan="5">yi-6b-chat</td><td>512</td><td>5.27</td><td>13.72</td></tr>
+    <tr><td>1024</td><td>5.07</td><td>15.44</td></tr>
+    <tr><td>2048</td><td>3.84</td><td>16.95</td></tr>
+    <tr><td>4096</td><td>1.99</td><td>28.25</td></tr>
+    <tr><td>8192</td><td>1.35</td><td>43.81</td></tr>
+    <tr><td rowspan="5">yi-34b-chat</td><td>512</td><td>2.32</td><td>66.72</td></tr>
+    <tr><td>1024</td><td>1.76</td><td>69.10</td></tr>
+    <tr><td>2048</td><td>1.05</td><td>71.34</td></tr>
+    <tr><td>4096</td><td>0.47</td><td>78.72</td></tr>
+    <tr><td>8192</td><td>0.31 (2*A100)</td><td>47.01+65.03</td></tr>
+    <tr><td rowspan="5">openbuddy-zephyr-7b-chat</td><td>512</td><td>5.17</td><td>14.99</td></tr>
+    <tr><td>1024</td><td>3.92</td><td>16.57</td></tr>
+    <tr><td>2048</td><td>3.08</td><td>19.89</td></tr>
+    <tr><td>4096</td><td>1.85</td><td>23.29</td></tr>
+    <tr><td>8192</td><td>0.92</td><td>52.14</td></tr>
+    <tr><td rowspan="5">baichuan2-7b-chat</td><td>512</td><td>6.09</td><td>18.18</td></tr>
+    <tr><td>1024</td><td>5.36</td><td>17.45</td></tr>
+    <tr><td>2048</td><td>3.43</td><td>19.18</td></tr>
+    <tr><td>4096</td><td>1.69</td><td>34.22</td></tr>
+    <tr><td>8192</td><td>1.16</td><td>45.47</td></tr>
+    <tr><td rowspan="5">baichuan2-13b-chat</td><td>512</td><td>5.32</td><td>31.01</td></tr>
+    <tr><td>1024</td><td>3.91</td><td>31.58</td></tr>
+    <tr><td>2048</td><td>1.77</td><td>32.40</td></tr>
+    <tr><td>4096</td><td>0.65</td><td>49.63</td></tr>
+    <tr><td>8192</td><td>0.36</td><td>76.17</td></tr>
 </table>
 
 
@@ -414,77 +592,6 @@ swift sft \
 
 
-## Model Type
-测试脚本为:
-```bash
-swift sft \
-    --model_type {MODEL_TYPE} \
-    --sft_type lora \
-    ...
-```
-
-<table>
-    <tr><td>Model Type [LoRA]</td><td>Training Speed (samples/s)</td><td>GPU Memory (GiB)</td></tr>
-    <tr><td>qwen-1_8b-chat</td><td>8.77</td><td>16.35</td></tr>
-    <tr><td>qwen-7b-chat</td><td>4.31</td><td>27.74</td></tr>
-    <tr><td>qwen-14b-chat</td><td>2.60</td><td>40.14</td></tr>
-    <tr><td>chatglm2-6b</td><td>4.26</td><td>20.43</td></tr>
-    <tr><td>chatglm3-6b</td><td>4.29</td><td>17.20</td></tr>
-    <tr><td>baichuan2-7b-chat</td><td>3.49</td><td>19.18</td></tr>
-    <tr><td>baichuan2-13b-chat</td><td>1.96</td><td>32.40</td></tr>
-    <tr><td>yi-6b-chat</td><td>3.98</td><td>16.28</td></tr>
-    <tr><td>yi-34b-chat</td><td>1.06</td><td>71.34</td></tr>
-    <tr><td>openbuddy-mistral-7b-chat</td><td>3.24</td><td>19.89</td></tr>
-    <tr><td>openbuddy-zephyr-7b-chat</td><td>3.25</td><td>19.89</td></tr>
-</table>
-
 
 ## LoRA Rank & LoRA Target Modules
 测试脚本为:
@@ -536,3 +643,59 @@ swift sft \
         17.89
     </tr>
 </table>
+
+
+## Gradient Accumulation Steps
+测试脚本为:
+```bash
+swift sft \
+    --gradient_accumulation_steps {GRADIENT_ACCUMULATION_STEPS} \
+    --model_type qwen-7b-chat \
+    --sft_type lora \
+    ...
+```
+
+<table>
+    <tr><td>Model Type [LoRA]</td><td>Gradient Accumulation Steps</td><td>Training Speed (samples/s)</td><td>GPU Memory (GiB)</td></tr>
+    <tr><td rowspan="7">qwen-7b-chat</td><td>1</td><td>4.26</td><td>27.73</td></tr>
+    <tr><td>2</td><td>4.32</td><td>27.74</td></tr>
+    <tr><td>4</td><td>4.31</td><td>27.74</td></tr>
+    <tr><td>8</td><td>4.32</td><td>27.74</td></tr>
+    <tr><td>16</td><td>4.33</td><td>27.74</td></tr>
+    <tr><td>32</td><td>4.30</td><td>27.74</td></tr>
+    <tr><td>64</td><td>4.32</td><td>27.74</td></tr>
+</table>
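A note on the new benchmark section: the table above shows training speed (4.26 to 4.33 samples/s) and GPU memory (27.73 to 27.74 GiB) essentially flat as gradient accumulation steps vary, which is the expected behavior, since accumulation changes how often the optimizer steps rather than the size of each micro-batch's activation workload. The `--gradient_accumulation_steps 16` setting added to the benchmark defaults matches the rule in `argument.py` below, where the value is derived from a target total batch size of 16. A minimal sketch of that rule; the helper name is ours, not part of swift:

```python
import math

def default_grad_accum_steps(batch_size: int, world_size: int,
                             target_total_batch_size: int = 16) -> int:
    """Mirrors the default in SftArguments.__post_init__:
    gradient_accumulation_steps = math.ceil(16 / batch_size / world_size),
    i.e. accumulate until batch_size * world_size * steps reaches ~16."""
    return math.ceil(target_total_batch_size / batch_size / world_size)

# 1 GPU with per-device batch size 1 -> accumulate 16 micro-batches
assert default_grad_accum_steps(batch_size=1, world_size=1) == 16
# 4 GPUs with per-device batch size 2 -> accumulate 2 micro-batches
assert default_grad_accum_steps(batch_size=2, world_size=4) == 2
```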
diff --git a/swift/llm/infer.py b/swift/llm/infer.py
index 9897e527b9..1b8bfe20df 100644
--- a/swift/llm/infer.py
+++ b/swift/llm/infer.py
@@ -2,7 +2,7 @@
 import datetime as dt
 import os
 import shutil
-from typing import Literal, Tuple
+from typing import Literal, Optional, Tuple
 
 import json
 import torch
@@ -23,7 +23,7 @@ def merge_lora(args: InferArguments,
                replace_if_exists=False,
                device_map: str = 'auto',
-               **kwargs) -> str:
+               **kwargs) -> Optional[str]:
     logger.info(f'replace_if_exists: {replace_if_exists}')
     assert args.ckpt_dir is not None, 'args.ckpt_dir is not specified.'
     assert args.sft_type == 'lora', "Only supports sft_type == 'lora'"
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index a3fc959e91..0721e13a87 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -2,7 +2,7 @@
 import math
 import os
 from dataclasses import dataclass, field
-from typing import List, Optional, Set, Tuple, Union
+from typing import List, Literal, Optional, Set, Tuple, Union
 
 import json
 import torch
@@ -33,12 +33,9 @@ class SftArguments:
     model_revision: Optional[str] = None
     model_cache_dir: Optional[str] = None
 
-    sft_type: str = field(
-        default='lora',
-        metadata={'choices': ['lora', 'longlora', 'qalora', 'full']})
+    sft_type: Literal['lora', 'longlora', 'qalora', 'full'] = 'lora'
     freeze_parameters: float = 0.  # 0 ~ 1
-    tuner_backend: str = field(
-        default='swift', metadata={'choices': ['swift', 'peft']})
+    tuner_backend: Literal['swift', 'peft'] = 'swift'
     template_type: str = field(
         default='AUTO',
         metadata={
@@ -47,17 +44,14 @@ class SftArguments:
         })
     output_dir: str = 'output'
     add_output_dir_suffix: bool = True
-    custom_output_dir_suffix: str = None
-    ddp_backend: str = field(
-        default='nccl', metadata={'choices': ['nccl', 'gloo', 'mpi', 'ccl']})
+    ddp_backend: Literal['nccl', 'gloo', 'mpi', 'ccl'] = 'nccl'
     seed: int = 42
     resume_from_checkpoint: Optional[str] = None
-    dtype: str = field(
-        default='AUTO', metadata={'choices': ['bf16', 'fp16', 'fp32', 'AUTO']})
+    dtype: Literal['bf16', 'fp16', 'fp32', 'AUTO'] = 'AUTO'
 
-    dataset: Optional[List[str]] = field(
-        default=None,
+    dataset: List[str] = field(
+        default_factory=list,
         metadata={'help': f'dataset choices: {list(DATASET_MAPPING.keys())}'})
     dataset_seed: int = 42
     dataset_test_ratio: float = 0.01
@@ -65,29 +59,28 @@ class SftArguments:
     val_dataset_sample: Optional[int] = None  # -1: all dataset
     system: Optional[str] = None
     max_length: int = 2048  # -1: no limit
-    truncation_strategy: str = field(
-        default='delete', metadata={'choices': ['delete', 'truncation_left']})
-    check_dataset_strategy: str = field(
-        default='none',
-        metadata={'choices': ['none', 'discard', 'error', 'warning']})
-    custom_train_dataset_path: Optional[List[str]] = None
-    custom_val_dataset_path: Optional[List[str]] = None
+    truncation_strategy: Literal['delete', 'truncation_left'] = 'delete'
+    check_dataset_strategy: Literal['none', 'discard', 'error',
+                                    'warning'] = 'none'
+    custom_train_dataset_path: List[str] = field(default_factory=list)
+    custom_val_dataset_path: List[str] = field(default_factory=list)
     self_cognition_sample: int = 0
     # Chinese name and English name
-    model_name: Optional[List[str]] = None  # e.g. ['小黄', 'Xiao Huang']
-    model_author: Optional[List[str]] = None  # e.g. ['魔搭', 'ModelScope']
-
+    model_name: List[str] = field(
+        default_factory=lambda: [None, None],
+        metadata={'help': "e.g. ['小黄', 'Xiao Huang']"})
+    model_author: List[str] = field(
+        default_factory=lambda: [None, None],
+        metadata={'help': "e.g. ['魔搭', 'ModelScope']"})
     # If you want to use qlora, set the quantization_bit to 8 or 4.
     # And you need to install bitsandbytes: `pip install bitsandbytes -U`
     # note: bf16 and quantization have requirements for gpu architecture
-    quantization_bit: int = field(default=0, metadata={'choices': [0, 4, 8]})
-    bnb_4bit_comp_dtype: str = field(
-        default='AUTO', metadata={'choices': ['fp16', 'bf16', 'fp32', 'AUTO']})
-    bnb_4bit_quant_type: str = field(
-        default='nf4', metadata={'choices': ['fp4', 'nf4']})
+    quantization_bit: Literal[0, 4, 8] = 0
+    bnb_4bit_comp_dtype: Literal['fp16', 'bf16', 'fp32', 'AUTO'] = 'AUTO'
+    bnb_4bit_quant_type: Literal['fp4', 'nf4'] = 'nf4'
     bnb_4bit_use_double_quant: bool = True
-    lora_target_modules: Optional[List[str]] = None
+    lora_target_modules: List[str] = field(default_factory=lambda: ['DEFAULT'])
     lora_rank: int = 8
     lora_alpha: int = 32
     lora_dropout_p: float = 0.05
@@ -121,12 +114,8 @@ class SftArguments:
     # 'user_name/repo_name' or 'repo_name'
     hub_model_id: Optional[str] = None
     hub_private_repo: bool = True
-    push_hub_strategy: str = field(
-        default='push_best',
-        metadata={
-            'choices':
-            ['end', 'push_best', 'push_last', 'checkpoint', 'all_checkpoints']
-        })
+    push_hub_strategy: Literal['end', 'push_best', 'push_last', 'checkpoint',
+                               'all_checkpoints'] = 'push_best'
     # None: use env var `MODELSCOPE_API_TOKEN`
     hub_token: Optional[str] = field(
         default=None,
@@ -148,14 +137,13 @@ class SftArguments:
     preprocess_num_proc: int = 1
     use_flash_attn: Optional[bool] = None
     ignore_args_error: bool = False  # True: notebook compatibility
-    logging_dir: Optional[str] = None
-    report_to: Optional[List[str]] = None
     check_model_is_latest: bool = True
-    acc_strategy: str = field(
-        default='token', metadata={'choices': ['token', 'sentence']})
+
+    logging_dir: Optional[str] = None
+    report_to: List[str] = field(default_factory=lambda: ['all'])
+    acc_strategy: Literal['token', 'sentence'] = 'token'
     save_on_each_node: bool = True
-    save_strategy: str = field(
-        default='steps', metadata={'choices': ['steps', 'no']})
+    save_strategy: Literal['steps', 'no'] = 'steps'
     save_safetensors: bool = True
 
     # generation config
@@ -208,11 +196,7 @@ def __post_init__(self) -> None:
 
         if self.add_output_dir_suffix:
             self.output_dir = os.path.join(self.output_dir, self.model_type)
-            if self.custom_output_dir_suffix is not None:
-                self.output_dir = os.path.join(self.output_dir,
-                                               self.custom_output_dir_suffix)
-            else:
-                self.output_dir = add_version_to_work_dir(self.output_dir)
+            self.output_dir = add_version_to_work_dir(self.output_dir)
             logger.info(f'output_dir: {self.output_dir}')
 
         if self.sft_type in ('lora', 'longlora', 'qalora'):
@@ -247,8 +231,6 @@ def __post_init__(self) -> None:
             logger.info(f'Setting template_type: {self.template_type}')
         if isinstance(self.dataset, str):
             self.dataset = [self.dataset]
-        elif self.dataset is None:
-            self.dataset = []
         if len(self.dataset) == 0 and (len(self.custom_train_dataset_path)
                                        == 0 and len(
                                            self.custom_val_dataset_path) == 0
@@ -257,9 +239,7 @@ def __post_init__(self) -> None:
         if self.save_steps is None:
             self.save_steps = self.eval_steps
-        if self.lora_target_modules is None:
-            self.lora_target_modules = ['DEFAULT']
-        elif isinstance(self.lora_target_modules, str):
+        if isinstance(self.lora_target_modules, str):
            self.lora_target_modules = [self.lora_target_modules]
         if 'DEFAULT' in self.lora_target_modules or 'AUTO' in self.lora_target_modules:
             assert len(self.lora_target_modules) == 1
@@ -288,8 +268,6 @@ def __post_init__(self) -> None:
             logger.info(f'Using deepspeed: {self.deepspeed}')
         if self.logging_dir is None:
             self.logging_dir = f'{self.output_dir}/runs'
-        if self.report_to is None:
-            self.report_to = ['all']
         if self.gradient_accumulation_steps is None:
             self.gradient_accumulation_steps = math.ceil(16 / self.batch_size
                                                          / world_size)
@@ -323,17 +301,14 @@ class InferArguments:
     model_revision: Optional[str] = None
     model_cache_dir: Optional[str] = None
 
-    sft_type: str = field(
-        default='lora',
-        metadata={'choices': ['lora', 'longlora', 'qalora', 'full']})
+    sft_type: Literal['lora', 'longlora', 'qalora', 'full'] = 'lora'
     template_type: str = field(
         default='AUTO',
         metadata={
             'help':
             f"template_type choices: {list(TEMPLATE_MAPPING.keys()) + ['AUTO']}"
         })
-    infer_backend: str = field(
-        default='AUTO', metadata={'choices': ['AUTO', 'vllm', 'pt']})
+    infer_backend: Literal['AUTO', 'vllm', 'pt'] = 'AUTO'
     ckpt_dir: Optional[str] = field(
         default=None, metadata={'help': '/path/to/your/vx_xxx/checkpoint-xxx'})
     load_args_from_ckpt_dir: bool = True
@@ -341,11 +316,10 @@ class InferArguments:
     eval_human: Optional[bool] = None
     seed: int = 42
-    dtype: str = field(
-        default='AUTO', metadata={'choices': ['bf16', 'fp16', 'fp32', 'AUTO']})
+    dtype: Literal['bf16', 'fp16', 'fp32', 'AUTO'] = 'AUTO'
 
-    dataset: Optional[List[str]] = field(
-        default=None,
+    dataset: List[str] = field(
+        default_factory=list,
         metadata={'help': f'dataset choices: {list(DATASET_MAPPING.keys())}'})
     dataset_seed: int = 42
     dataset_test_ratio: float = 0.01
@@ -353,19 +327,15 @@ class InferArguments:
     save_result: bool = True
     system: Optional[str] = None
     max_length: int = 2048  # -1: no limit
-    truncation_strategy: str = field(
-        default='delete', metadata={'choices': ['delete', 'truncation_left']})
-    check_dataset_strategy: str = field(
-        default='none',
-        metadata={'choices': ['none', 'discard', 'error', 'warning']})
-    custom_train_dataset_path: Optional[List[str]] = None
-    custom_val_dataset_path: Optional[List[str]] = None
-
-    quantization_bit: int = field(default=0, metadata={'choices': [0, 4, 8]})
-    bnb_4bit_comp_dtype: str = field(
-        default='AUTO', metadata={'choices': ['fp16', 'bf16', 'fp32', 'AUTO']})
-    bnb_4bit_quant_type: str = field(
-        default='nf4', metadata={'choices': ['fp4', 'nf4']})
+    truncation_strategy: Literal['delete', 'truncation_left'] = 'delete'
+    check_dataset_strategy: Literal['none', 'discard', 'error',
+                                    'warning'] = 'none'
+    custom_train_dataset_path: List[str] = field(default_factory=list)
+    custom_val_dataset_path: List[str] = field(default_factory=list)
+
+    quantization_bit: Literal[0, 4, 8] = 0
+    bnb_4bit_comp_dtype: Literal['fp16', 'bf16', 'fp32', 'AUTO'] = 'AUTO'
+    bnb_4bit_quant_type: Literal['fp4', 'nf4'] = 'nf4'
     bnb_4bit_use_double_quant: bool = True
 
     max_new_tokens: int = 2048
@@ -421,8 +391,6 @@ def __post_init__(self) -> None:
             logger.info(f'Setting template_type: {self.template_type}')
         if isinstance(self.dataset, str):
             self.dataset = [self.dataset]
-        elif self.dataset is None:
-            self.dataset = []
         has_dataset = (
             len(self.dataset) > 0 or len(self.custom_train_dataset_path) > 0
             or len(self.custom_val_dataset_path) > 0)
@@ -706,9 +674,7 @@ def handle_path(args: Union[SftArguments, InferArguments]) -> None:
 def register_custom_dataset(args: Union[SftArguments, InferArguments]) -> None:
     for key in ['custom_train_dataset_path', 'custom_val_dataset_path']:
         value = getattr(args, key)
-        if value is None:
-            setattr(args, key, [])
-        elif isinstance(value, str):
+        if isinstance(value, str):
             setattr(args, key, [value])
     if len(args.custom_train_dataset_path) == 0 and len(
             args.custom_val_dataset_path) == 0:
@@ -749,7 +715,7 @@ def load_from_ckpt_dir(args: InferArguments) -> None:
         if (key in {
                 'dataset', 'custom_train_dataset_path',
                 'custom_val_dataset_path'
-        } and getattr(args, key) is not None):
+        } and len(getattr(args, key)) > 0):
             continue
         setattr(args, key, sft_args.get(key))
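The bulk of the `argument.py` diff is one repeated pattern: `str` fields constrained through `field(metadata={'choices': [...]})` become `typing.Literal` annotations, and `Optional[List[...]] = None` defaults become `field(default_factory=list)`. A minimal self-contained sketch of the pattern (a toy dataclass of ours, not swift's actual argument parser), showing that the allowed values stay introspectable from the type itself:

```python
from dataclasses import dataclass, field, fields
from typing import List, Literal, get_args, get_origin

@dataclass
class ToyArguments:
    # choices now live in the annotation instead of field(metadata={'choices': ...})
    sft_type: Literal['lora', 'longlora', 'qalora', 'full'] = 'lora'
    # list-valued options default to [] via default_factory, never to None
    dataset: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # validate every Literal-annotated field against its allowed values
        # (with `from __future__ import annotations`, resolve f.type via
        # typing.get_type_hints first)
        for f in fields(self):
            if get_origin(f.type) is Literal:
                allowed = get_args(f.type)
                value = getattr(self, f.name)
                if value not in allowed:
                    raise ValueError(f'{f.name}={value!r} not in {allowed}')

args = ToyArguments()
assert args.dataset == []  # no `if self.dataset is None` branch needed anymore

try:
    ToyArguments(sft_type='dora')
except ValueError as e:
    print(e)  # sft_type='dora' not in ('lora', 'longlora', 'qalora', 'full')
```

This is also why the `None`-handling branches disappear from `__post_init__` and `register_custom_dataset`, and why `load_from_ckpt_dir` can test `len(getattr(args, key)) > 0` instead of `is not None`.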
diff --git a/swift/llm/utils/dataset.py b/swift/llm/utils/dataset.py
index b8c8dee611..4bd263f221 100644
--- a/swift/llm/utils/dataset.py
+++ b/swift/llm/utils/dataset.py
@@ -940,7 +940,7 @@ def _preprocess_hc3(dataset: HfDataset) -> HfDataset:
 def add_self_cognition_dataset(
         train_dataset: HfDataset, dataset_sample: int,
         model_name: Tuple[str, Optional[str]],
-        model_author: Tuple[str, Optional[str]]) -> None:
+        model_author: Tuple[str, Optional[str]]) -> HfDataset:
     assert model_name[0] is not None
     assert model_author[0] is not None
     if model_name[1] is None:
diff --git a/swift/llm/utils/preprocess.py b/swift/llm/utils/preprocess.py
index 523119bf6f..874b62470b 100644
--- a/swift/llm/utils/preprocess.py
+++ b/swift/llm/utils/preprocess.py
@@ -1,6 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import ast
-from typing import Callable, Dict, List, Literal, Optional
+from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
 from datasets import Dataset as HfDataset
 from tqdm import tqdm
@@ -56,7 +56,7 @@ def __call__(self, dataset: HfDataset) -> HfDataset:
         return dataset
 
 
-def _default_repair_conversations(s: str) -> Dict[str, str]:
+def _default_repair_conversations(s: Union[str, Any]) -> Any:
     if isinstance(s, str):
         return ast.literal_eval(s)
     return s
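`_default_repair_conversations` now passes already-parsed values through and only calls `ast.literal_eval` on strings, which is what the widened `Union[str, Any] -> Any` signature documents. A small illustration of the behavior; `literal_eval` parses Python literals without executing arbitrary code:

```python
import ast

def repair_conversations(s):
    """Mirror of _default_repair_conversations: parse strings, pass through the rest."""
    if isinstance(s, str):
        return ast.literal_eval(s)  # safe: accepts literals only, runs no code
    return s

# a conversations field that arrived as a stringified Python literal
raw = "[{'from': 'user', 'value': 'hi'}, {'from': 'assistant', 'value': 'hello'}]"

parsed = repair_conversations(raw)
assert parsed[0]['value'] == 'hi'
assert repair_conversations(parsed) == parsed  # idempotent on parsed input
```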
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index 8f466861f1..fc6a00d590 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -9,8 +9,7 @@
 from functools import partial, wraps
 from queue import Empty, Queue
 from tempfile import TemporaryDirectory
-from typing import (Any, Callable, Dict, Iterator, List, Optional, Tuple,
-                    TypeVar, Union)
+from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
 
 import accelerate
 import multiprocess
@@ -267,7 +266,7 @@ def _map_mp(dataset: HfDataset, map_func: MapFunc,
 
 def dataset_map(dataset: HfDataset,
                 map_func: MapFunc,
-                num_proc: int = 1) -> LLMDataset:
+                num_proc: int = 1) -> Optional[LLMDataset]:
     single_map = partial(_single_map, map_func=map_func)
     if num_proc == 1:
         data = []
@@ -628,7 +627,8 @@ def inference(model: PreTrainedModel,
 
 def limit_history_length(template: Template, query: str,
-                         history: Optional[History], max_length: int) -> int:
+                         history: Optional[History],
+                         max_length: int) -> Tuple[History, History]:
     """binary search"""
     if history is None:
         history = []
diff --git a/swift/llm/utils/vllm_utils.py b/swift/llm/utils/vllm_utils.py
index 8c23b20328..8d3e985835 100644
--- a/swift/llm/utils/vllm_utils.py
+++ b/swift/llm/utils/vllm_utils.py
@@ -1,7 +1,7 @@
 import inspect
 import os
 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterator, List, Optional, Tuple
 
 import torch
 from modelscope import GenerationConfig, snapshot_download
@@ -91,20 +91,20 @@ class VllmGenerationConfig(SamplingParams):
 
     def __init__(
         self,
-        max_length: int = 20,
-        max_new_tokens: Optional[int] = None,
+        max_new_tokens: int = 64,  # max_tokens
         temperature: float = 1.,
         top_k: int = 50,  # -1: all
-        top_p: float = 1.0,
+        top_p: float = 1.,
         repetition_penalty: float = 1.,
         num_beams: int = 1,
         *,
-        length_penalty: float = 1.0,
+        n: int = 1,
+        length_penalty: float = 1.,
         stop: Optional[List[str]] = None,
         **kwargs,
-    ):
+    ) -> None:
         # The parameter design is similar to transformers.GenerationConfig.
-        if num_beams:
+        if num_beams > 1:
             top_k = -1
             top_p = 1
             temperature = 0
@@ -113,8 +113,9 @@ def __init__(
             )
         if top_k == 0:
             top_k = -1
-        self.max_new_tokens = max_new_tokens
-        kwargs['max_tokens'] = max_length
+        if stop is None:
+            stop = []
+        kwargs['max_tokens'] = max_new_tokens
         kwargs['temperature'] = temperature
         kwargs['top_k'] = top_k
         kwargs['top_p'] = top_p
@@ -123,6 +124,7 @@ def __init__(
             assert 'use_beam_search' not in kwargs and 'best_of' not in kwargs
             kwargs['use_beam_search'] = True
             kwargs['best_of'] = num_beams
+            kwargs['n'] = n
             kwargs['length_penalty'] = length_penalty
         kwargs['stop'] = stop
         parameters = inspect.signature(SamplingParams.__init__).parameters
@@ -134,13 +136,15 @@ def __init__(
             kwargs.pop(k)
         super().__init__(**kwargs)
 
-    @property
-    def max_length(self) -> int:
-        return self.max_tokens
-
-    @max_length.setter
-    def max_length(self, value: int) -> None:
-        self.max_tokens = value
+    def __setattr__(self, key: str, value: str) -> None:
+        if key == 'max_new_tokens':
+            self.max_tokens = value
+        elif key == 'max_length':
+            raise ValueError(
+                '`max_length` is not supported, please use `max_new_tokens` for setting.'
+            )
+        else:
+            super().__setattr__(key, value)
 
 
 def inference_stream_vllm(
@@ -150,7 +154,7 @@ def inference_stream_vllm(
         *,
         generation_config: Optional[VllmGenerationConfig] = None,
         use_tqdm: bool = False,
-        **kwargs) -> List[Dict[str, Any]]:
+        **kwargs) -> Iterator[List[Dict[str, Any]]]:
     """
     request_list: e.g. [{'query': 'hello!'}].
         The keys that can be included are: 'query', 'history', 'system'.
@@ -177,9 +181,6 @@ def inference_stream_vllm(
         tokenizer = template.tokenizer
         if tokenizer.eos_token is not None and tokenizer.eos_token not in generation_config.stop:
             generation_config.stop.append(tokenizer.eos_token)
-        if generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + len(
-                input_ids)
         llm_engine.add_request(str(i), None, generation_config, input_ids)
 
     batch_size = len(request_list)
@@ -245,9 +246,6 @@ def inference_vllm(llm_engine: LLMEngine,
         tokenizer = template.tokenizer
         if tokenizer.eos_token is not None and tokenizer.eos_token not in generation_config.stop:
             generation_config.stop.append(tokenizer.eos_token)
-        if generation_config.max_new_tokens is not None:
-            generation_config.max_length = generation_config.max_new_tokens + len(
-                input_ids)
         llm_engine.add_request(str(i), None, generation_config, input_ids)
 
     batch_size = len(request_list)
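The `VllmGenerationConfig` rework replaces the `max_length` property with a `__setattr__` hook: assigning `max_new_tokens` is forwarded to vLLM's underlying `SamplingParams.max_tokens`, while assigning `max_length` now fails loudly instead of being silently aliased. A usage sketch based on the diff above (it assumes vLLM is installed and that the import path matches the repository layout):

```python
from swift.llm.utils.vllm_utils import VllmGenerationConfig

config = VllmGenerationConfig(max_new_tokens=128, temperature=0.7)
assert config.max_tokens == 128  # max_new_tokens maps onto SamplingParams.max_tokens

config.max_new_tokens = 256      # routed through __setattr__
assert config.max_tokens == 256

try:
    config.max_length = 512      # removed: only max_new_tokens is supported now
except ValueError as e:
    print(e)
```

Dropping the alias also removes the per-request `max_length = max_new_tokens + len(input_ids)` bookkeeping in both inference functions, since vLLM's `max_tokens` already counts only newly generated tokens.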
diff --git a/swift/trainers/utils.py b/swift/trainers/utils.py
index 518046b2ee..f5d7cafccf 100644
--- a/swift/trainers/utils.py
+++ b/swift/trainers/utils.py
@@ -17,7 +17,7 @@
     ShardedDDPOption = None
 
 
-def can_return_loss(model: Module) -> List[str]:
+def can_return_loss(model: Module) -> bool:
     """Check if a given model can return loss."""
     signature = inspect.signature(model.forward)
     for p in signature.parameters:
diff --git a/tests/llm/test_run.py b/tests/llm/test_run.py
index 406bf3e5ab..226f8c4740 100644
--- a/tests/llm/test_run.py
+++ b/tests/llm/test_run.py
@@ -188,7 +188,7 @@ def test_self_cognition(self):
         if not __name__ == '__main__':
             # ignore citest error in github
             return
-        for dataset in [None, [DatasetName.alpaca_zh, DatasetName.alpaca_en]]:
+        for dataset in [[], [DatasetName.alpaca_zh, DatasetName.alpaca_en]]:
             sft_args = SftArguments(
                 model_type=ModelType.qwen_7b_chat,
                 dataset=dataset,  # no dataset
@@ -208,7 +208,7 @@ def test_self_cognition(self):
         print(f'last_model_checkpoint: {last_model_checkpoint}')
         print(f'best_model_checkpoint: {best_model_checkpoint}')
         ckpt_dir = best_model_checkpoint or last_model_checkpoint
-        if dataset is None:
+        if len(dataset) == 0:
             continue
         infer_args = InferArguments(
             ckpt_dir=ckpt_dir,
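The test changes fall out of the new defaults: "no dataset" is now expressed as `[]` rather than `None`, so both loop branches are lists and the guard becomes a length check. A hedged sketch of the self-cognition configuration this test exercises; the import path follows the library's usual layout and the sample count is illustrative:

```python
from swift.llm import ModelType, SftArguments

# "no dataset": rely solely on the generated self-cognition samples
sft_args = SftArguments(
    model_type=ModelType.qwen_7b_chat,
    dataset=[],                       # was `None` before this change
    self_cognition_sample=100,
    model_name=['小黄', 'Xiao Huang'],
    model_author=['魔搭', 'ModelScope'])

assert len(sft_args.dataset) == 0    # emptiness checks replace `is None` checks
```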