4 changes: 3 additions & 1 deletion README.md
@@ -9,7 +9,7 @@
<p align="center">
<a href="https://modelscope.cn/home">ModelScope Hub</a>
<br>
<a href="README_CN.md">中文</a>&nbsp | &nbspEnglish
<a href="README_CN.md">中文</a>&nbsp | &nbspEnglish&nbsp | &nbsp<a href="https://github.com/modelscope/swift/blob/main/docs/source/GetStarted/%E5%BF%AB%E9%80%9F%E4%BD%BF%E7%94%A8.md">Docs</a>
</p>

<p align="center">
@@ -60,6 +60,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用.md)


## 🎉 News
- 2023.12.19: Support for [phi2-3b](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/phi2_3b).
- 2023.12.18: Support for **VLLM** for inference acceleration and deployment. For more details, refer to [VLLM Inference Acceleration and Deployment](https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md).
- 2023.12.15: Support for the **deepseek** and **deepseek-coder** series: deepseek-7b, deepseek-7b-chat, deepseek-67b, deepseek-67b-chat, openbuddy-deepseek-67b-chat, deepseek-coder-1_3b, deepseek-coder-1_3b-chat, deepseek-coder-6_7b, deepseek-coder-6_7b-chat, deepseek-coder-33b, deepseek-coder-33b-chat.
- 2023.12.13: Support for mistral-7b-chat-v2, [mixtral-7b-moe](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/mixtral_7b_moe), [mixtral-7b-moe-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/mixtral_7b_moe_chat).
@@ -139,6 +140,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用.md)
- Coding:
- codefuse series: [codefuse-codellama-34b-chat](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B/summary)
- deepseek-coder series: [deepseek-coder-1_3b](https://modelscope.cn/models/deepseek-ai/deepseek-coder-1.3b-base/summary), [deepseek-coder-1_3b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-coder-1.3b-instruct/summary), [deepseek-coder-6_7b](https://modelscope.cn/models/deepseek-ai/deepseek-coder-6.7b-base/summary), [deepseek-coder-6_7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-coder-6.7b-instruct/summary), [deepseek-coder-33b](https://modelscope.cn/models/deepseek-ai/deepseek-coder-33b-base/summary), [deepseek-coder-33b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-coder-33b-instruct/summary)
- phi series: [phi2-3b](https://modelscope.cn/models/AI-ModelScope/phi-2/summary)
- Supported Datasets: [[Detail]](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.md#%E6%95%B0%E6%8D%AE%E9%9B%86)
- NLP:
- General: 🔥[alpaca-en](https://modelscope.cn/datasets/AI-ModelScope/alpaca-gpt4-data-en/summary)(gpt4), 🔥[alpaca-zh](https://modelscope.cn/datasets/AI-ModelScope/alpaca-gpt4-data-zh/summary)(gpt4), [multi-alpaca-all](https://www.modelscope.cn/datasets/damo/nlp_polylm_multialpaca_sft/summary), [instinwild-en](https://www.modelscope.cn/datasets/wyj123456/instinwild/summary), [instinwild-zh](https://www.modelscope.cn/datasets/wyj123456/instinwild/summary), [cot-en](https://www.modelscope.cn/datasets/YorickHe/CoT/summary), [cot-zh](https://www.modelscope.cn/datasets/YorickHe/CoT/summary), [firefly-all-zh](https://www.modelscope.cn/datasets/wyj123456/firefly/summary), [instruct-en](https://www.modelscope.cn/datasets/wyj123456/instruct/summary), [gpt4all-en](https://www.modelscope.cn/datasets/wyj123456/GPT4all/summary), [sharegpt-en](https://www.modelscope.cn/datasets/huangjintao/sharegpt/summary), [sharegpt-zh](https://www.modelscope.cn/datasets/huangjintao/sharegpt/summary), [tulu-v2-sft-mixture](https://modelscope.cn/datasets/AI-ModelScope/tulu-v2-sft-mixture/summary), [wikipedia-zh](https://modelscope.cn/datasets/AI-ModelScope/wikipedia-cn-20230720-filtered/summary), [open-orca](https://modelscope.cn/datasets/AI-ModelScope/OpenOrca/summary), [open-orca-gpt4](https://modelscope.cn/datasets/AI-ModelScope/OpenOrca/summary), [sharegpt-gpt4](https://modelscope.cn/datasets/AI-ModelScope/sharegpt_gpt4/summary)
4 changes: 3 additions & 1 deletion README_CN.md
@@ -9,7 +9,7 @@
<p align="center">
<a href="https://modelscope.cn/home">魔搭社区</a>
<br>
- 中文&nbsp | &nbsp<a href="README.md">English</a>
+ 中文&nbsp | &nbsp<a href="README.md">English</a>&nbsp | &nbsp<a href="https://github.com/modelscope/swift/blob/main/docs/source/GetStarted/%E5%BF%AB%E9%80%9F%E4%BD%BF%E7%94%A8.md">Docs</a>
</p>

<p align="center">
@@ -58,6 +58,7 @@ SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning) is an extensible framework
Users can check the [official SWIFT documentation](docs/source/GetStarted/快速使用.md) for details.

## 🎉 News
- 2023.12.19: Support for [phi2-3b](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/phi2_3b).
- 2023.12.18: Support for **VLLM** for inference acceleration and deployment. For more details, refer to [VLLM Inference Acceleration and Deployment](https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md).
- 2023.12.15: Support for the **deepseek** and **deepseek-coder** series: deepseek-7b, deepseek-7b-chat, deepseek-67b, deepseek-67b-chat, openbuddy-deepseek-67b-chat, deepseek-coder-1_3b, deepseek-coder-1_3b-chat, deepseek-coder-6_7b, deepseek-coder-6_7b-chat, deepseek-coder-33b, deepseek-coder-33b-chat.
- 2023.12.13: Support for mistral-7b-chat-v2, [mixtral-7b-moe](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/mixtral_7b_moe), [mixtral-7b-moe-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/mixtral_7b_moe_chat).
@@ -137,6 +138,7 @@ SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning) is an extensible framework
- Coding:
- codefuse series: [codefuse-codellama-34b-chat](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B/summary)
- deepseek-coder series: [deepseek-coder-1_3b](https://modelscope.cn/models/deepseek-ai/deepseek-coder-1.3b-base/summary), [deepseek-coder-1_3b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-coder-1.3b-instruct/summary), [deepseek-coder-6_7b](https://modelscope.cn/models/deepseek-ai/deepseek-coder-6.7b-base/summary), [deepseek-coder-6_7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-coder-6.7b-instruct/summary), [deepseek-coder-33b](https://modelscope.cn/models/deepseek-ai/deepseek-coder-33b-base/summary), [deepseek-coder-33b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-coder-33b-instruct/summary)
- phi series: [phi2-3b](https://modelscope.cn/models/AI-ModelScope/phi-2/summary)
- Supported Datasets: [[Detail]](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.md#%E6%95%B0%E6%8D%AE%E9%9B%86)
- NLP:
- General: 🔥[alpaca-en](https://modelscope.cn/datasets/AI-ModelScope/alpaca-gpt4-data-en/summary)(gpt4), 🔥[alpaca-zh](https://modelscope.cn/datasets/AI-ModelScope/alpaca-gpt4-data-zh/summary)(gpt4), [multi-alpaca-all](https://www.modelscope.cn/datasets/damo/nlp_polylm_multialpaca_sft/summary), [instinwild-en](https://www.modelscope.cn/datasets/wyj123456/instinwild/summary), [instinwild-zh](https://www.modelscope.cn/datasets/wyj123456/instinwild/summary), [cot-en](https://www.modelscope.cn/datasets/YorickHe/CoT/summary), [cot-zh](https://www.modelscope.cn/datasets/YorickHe/CoT/summary), [firefly-all-zh](https://www.modelscope.cn/datasets/wyj123456/firefly/summary), [instruct-en](https://www.modelscope.cn/datasets/wyj123456/instruct/summary), [gpt4all-en](https://www.modelscope.cn/datasets/wyj123456/GPT4all/summary), [sharegpt-en](https://www.modelscope.cn/datasets/huangjintao/sharegpt/summary), [sharegpt-zh](https://www.modelscope.cn/datasets/huangjintao/sharegpt/summary), [tulu-v2-sft-mixture](https://modelscope.cn/datasets/AI-ModelScope/tulu-v2-sft-mixture/summary), [wikipedia-zh](https://modelscope.cn/datasets/AI-ModelScope/wikipedia-cn-20230720-filtered/summary), [open-orca](https://modelscope.cn/datasets/AI-ModelScope/OpenOrca/summary), [open-orca-gpt4](https://modelscope.cn/datasets/AI-ModelScope/OpenOrca/summary), [sharegpt-gpt4](https://modelscope.cn/datasets/AI-ModelScope/sharegpt_gpt4/summary)
3 changes: 3 additions & 0 deletions docs/source/GetStarted/快速使用.md
@@ -30,6 +30,9 @@ pip install ms-swift -U
The SWIFT library provides **training and inference scaffolding for LLM & AIGC models**, supporting direct training and inference of models such as LLaMA, QWen, ChatGLM, and Stable Diffusion, and it integrates the tuners provided by SWIFT so that
developers can use them directly. The examples live at: https://github.com/modelscope/swift/tree/main/examples/pytorch/llm

- For LLM training and inference, see: [LLM Fine-tuning Documentation](https://github.com/modelscope/swift/blob/main/docs/source/LLM/LLM微调文档.md)
- For AIGC training and inference, see: [Text-to-Image Fine-tuning Documentation](https://github.com/modelscope/swift/blob/main/docs/source/AIGC/AnimateDiff微调推理文档.md)

To use the tuners in a custom training pipeline, refer to the code below, which trains a `bert-base-uncased` model on a classification task with LoRA:

```python
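# The diff view truncates the original snippet here; what follows is a minimal
# sketch of the idea rather than the file's exact code. It assumes swift's
# LoRAConfig / Swift.prepare_model API; the LoRA hyperparameters and
# target_modules below are illustrative choices, and the training loop is elided.
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from swift import LoRAConfig, Swift

model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Wrap the base model so that only the LoRA parameters are trainable.
lora_config = LoRAConfig(r=8, lora_alpha=32, target_modules=['query', 'key', 'value'])
model = Swift.prepare_model(model, lora_config)

# ... tokenize a classification dataset and train, e.g. with transformers.Trainer ...
```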
1 change: 1 addition & 0 deletions docs/source/LLM/支持的模型和数据集.md
@@ -108,6 +108,7 @@
|deepseek-coder-6_7b-chat|[deepseek-ai/deepseek-coder-6.7b-instruct](https://modelscope.cn/models/deepseek-ai/deepseek-coder-6.7b-instruct/summary)|q_proj, k_proj, v_proj|deepseek-coder|&#x2714;|&#x2714;||
|deepseek-coder-33b|[deepseek-ai/deepseek-coder-33b-base](https://modelscope.cn/models/deepseek-ai/deepseek-coder-33b-base/summary)|q_proj, k_proj, v_proj|default-generation-bos|&#x2714;|&#x2714;||
|deepseek-coder-33b-chat|[deepseek-ai/deepseek-coder-33b-instruct](https://modelscope.cn/models/deepseek-ai/deepseek-coder-33b-instruct/summary)|q_proj, k_proj, v_proj|deepseek-coder|&#x2714;|&#x2714;||
|phi2-3b|[AI-ModelScope/phi-2](https://modelscope.cn/models/AI-ModelScope/phi-2/summary)|Wqkv|default-generation|&#x2714;|&#x2714;||


## Datasets
13 changes: 13 additions & 0 deletions examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh
@@ -0,0 +1,13 @@
# Experimental environment: A10
# 8GB GPU memory
CUDA_VISIBLE_DEVICES=0 \
swift infer \
--ckpt_dir "output/phi2-3b/vx_xxx/checkpoint-xxx" \
--load_dataset_config true \
--max_length 2048 \
--use_flash_attn false \
--max_new_tokens 2048 \
--temperature 0.1 \
--top_p 0.7 \
--repetition_penalty 1.05 \
    --merge_lora_and_save false
18 changes: 18 additions & 0 deletions examples/pytorch/llm/scripts/phi2_3b/lora/sft.sh
@@ -0,0 +1,18 @@
# Experimental environment: A100
# 60GB GPU memory
CUDA_VISIBLE_DEVICES=0 \
swift sft \
--model_type phi2-3b \
--sft_type lora \
--template_type default \
--train_dataset_sample 20000 \
--eval_steps 100 \
--output_dir output \
--num_train_epochs 1 \
--max_length 2048 \
--learning_rate 1e-4 \
--use_flash_attn true \
--only_save_model true \
--lora_target_modules ALL \
--dataset codefuse-python-en \
    --gradient_checkpointing false
15 changes: 10 additions & 5 deletions swift/llm/utils/argument.py
@@ -94,7 +94,7 @@ class SftArguments:

neftune_alpha: float = 0.0

-    gradient_checkpointing: bool = True
+    gradient_checkpointing: Optional[bool] = None
deepspeed_config_path: Optional[str] = None # e.g. 'ds_config/zero2.json'
batch_size: int = 1
eval_batch_size: Optional[int] = None
@@ -299,8 +299,13 @@ def __post_init__(self) -> None:
logger.info(
f'Setting self.preprocess_num_proc: {self.preprocess_num_proc}'
)
-        if 'moe' in self.model_type:
-            assert self.gradient_checkpointing is False, 'moe not support gradient_checkpointing'
+        model_info = MODEL_MAPPING[self.model_type]
+        support_gradient_checkpointing = model_info.get(
+            'support_gradient_checkpointing', True)
+        if self.gradient_checkpointing is None:
+            self.gradient_checkpointing = support_gradient_checkpointing
+        elif not support_gradient_checkpointing:
+            assert self.gradient_checkpointing is False, 'not support gradient_checkpointing'
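The net effect of this hunk: `gradient_checkpointing` becomes a tri-state flag that falls back to the model registry when left unset, which is what lets phi2-3b (and the mixtral models below) opt out. A self-contained sketch of the resolution logic, using a hypothetical stub in place of the real `MODEL_MAPPING` registry:

```python
# Stub registry for illustration only; entries mirror the shape of MODEL_MAPPING.
MODEL_MAPPING = {
    'phi2-3b': {'support_gradient_checkpointing': False},
    'llama2-7b-chat': {},  # a missing key defaults to "supported"
}

def resolve_gradient_checkpointing(model_type: str, user_value):
    supported = MODEL_MAPPING[model_type].get('support_gradient_checkpointing', True)
    if user_value is None:  # flag left unset: follow the model's capability
        return supported
    # An explicit True on an unsupported model trips the assertion above.
    assert supported or user_value is False, 'not support gradient_checkpointing'
    return user_value

assert resolve_gradient_checkpointing('phi2-3b', None) is False
assert resolve_gradient_checkpointing('llama2-7b-chat', None) is True
```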


@dataclass
@@ -428,8 +433,8 @@ def __post_init__(self) -> None:
logger.warning('Setting overwrite_generation_config: False')
if self.ckpt_dir is None:
self.sft_type = 'full'
-        support_vllm = MODEL_MAPPING[self.model_type].get(
-            'support_vllm', False)
+        model_info = MODEL_MAPPING[self.model_type]
+        support_vllm = model_info.get('support_vllm', False)
if self.infer_backend == 'AUTO':
if self.sft_type == 'full' and is_vllm_available(
) and support_vllm:
30 changes: 28 additions & 2 deletions swift/llm/utils/model.py
@@ -151,6 +151,8 @@ class ModelType:
deepseek_coder_6_7b_chat = 'deepseek-coder-6_7b-chat'
deepseek_coder_33b = 'deepseek-coder-33b'
deepseek_coder_33b_chat = 'deepseek-coder-33b-chat'
# phi
phi2_3b = 'phi2-3b'

@classmethod
def get_model_name_list(cls) -> List[str]:
@@ -170,6 +172,7 @@ class LoRATM(NamedTuple):
qwen = ['c_attn']
polylm = ['c_attn']
bloom = ['query_key_value']
phi = ['Wqkv']


GetModelTokenizerFunction = Callable[..., Tuple[Optional[PreTrainedModel],
@@ -777,15 +780,17 @@ def cross_entropy_forward(self, inputs: Tensor,
TemplateType.default_generation_bos,
requires=['transformers>=4.36'],
support_flash_attn=True,
-    support_vllm=True)
+    support_vllm=True,
+    support_gradient_checkpointing=False)
@register_model(
ModelType.mixtral_7b_moe_chat,
'AI-ModelScope/Mixtral-8x7B-Instruct-v0.1',
LoRATM.llama2,
TemplateType.llama,
requires=['transformers>=4.36'],
support_flash_attn=True,
-    support_vllm=True)
+    support_vllm=True,
+    support_gradient_checkpointing=False)
def get_model_tokenizer_with_flash_attn(model_dir: str,
torch_dtype: Dtype,
model_kwargs: Dict[str, Any],
@@ -1291,6 +1296,27 @@ def get_model_tokenizer_codellama(model_dir: str,
**kwargs)


@register_model(
ModelType.phi2_3b,
'AI-ModelScope/phi-2',
LoRATM.phi,
TemplateType.default_generation,
support_flash_attn=True,
support_vllm=True,
support_gradient_checkpointing=False)
def get_model_tokenizer_phi(model_dir: str,
torch_dtype: Dtype,
model_kwargs: Dict[str, Any],
load_model: bool = True,
**kwargs):
model_config = AutoConfig.from_pretrained(
model_dir, trust_remote_code=True)
use_flash_attn = kwargs.get('use_flash_attn', False)
model_config.flash_attn = use_flash_attn
return get_model_tokenizer_from_repo(model_dir, torch_dtype, model_kwargs,
load_model, model_config, **kwargs)


def fix_transformers_upgrade(module: PreTrainedModel) -> None:
# from 4.35, transformers changes its arguments of _set_gradient_checkpointing
if version.parse(transformers.__version__) >= version.parse('4.35'):
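A quick way to exercise the new registration above is to load the model through swift's factory. A sketch, assuming `get_model_tokenizer` forwards the `use_flash_attn` kwarg that `get_model_tokenizer_phi` reads; the weights download from ModelScope on first use:

```python
# Hypothetical smoke test for the phi2-3b registration.
from swift.llm import ModelType, get_model_tokenizer

model, tokenizer = get_model_tokenizer(ModelType.phi2_3b, use_flash_attn=False)
print(type(model).__name__)  # expect the phi-2 architecture class
```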
8 changes: 6 additions & 2 deletions swift/llm/utils/template.py
@@ -234,14 +234,18 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase,
self.tokenizer = tokenizer
self.stop_words = stop_words
self.decode_kwargs = decode_kwargs
self.start_idx = -1

def __call__(self, input_ids: Tensor, scores: Tensor) -> bool:
if self.start_idx == -1:
self.start_idx = len(input_ids[0]) - 1
tokenizer = self.tokenizer
stop_words = self.stop_words
-        text = tokenizer.decode(input_ids[0], **self.decode_kwargs)
+        text = tokenizer.decode(input_ids[0, self.start_idx:],
+                                **self.decode_kwargs)
for stop_word in stop_words:
if isinstance(stop_word, str):
-                if text.endswith(stop_word):
+                if stop_word in text:
return True
elif isinstance(stop_word, list) and len(stop_word) > 0:
res = []
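Two behavioral fixes land in this hunk: the criteria records where generation starts and decodes only the newly generated tokens (so a stop word that happens to appear in the prompt no longer halts generation immediately), and string stop words now match anywhere in that new text instead of only via `endswith`, which could miss a stop word once further tokens were emitted after it. A toy re-creation of the same logic, independent of any real tokenizer:

```python
# Toy version of the revised criteria: remember where generation started,
# inspect only newly generated text, and use substring (not suffix) matching.
class ToyStopCheck:

    def __init__(self, stop_words):
        self.stop_words = stop_words
        self.start_idx = -1

    def __call__(self, pieces):
        # pieces: decoded tokens of the full sequence (prompt + generation)
        if self.start_idx == -1:
            self.start_idx = len(pieces) - 1
        new_text = ''.join(pieces[self.start_idx:])
        return any(word in new_text for word in self.stop_words)

check = ToyStopCheck(['Observation:'])
assert not check(list('PROMPT'))                          # first call: nothing generated yet
assert check(list('PROMPT') + list('Observation: done'))  # stop word inside new text
```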
17 changes: 13 additions & 4 deletions swift/llm/utils/utils.py
@@ -39,7 +39,8 @@
from swift.hub import ModelScopeConfig
from swift.utils import (get_dist_setting, get_logger, is_ddp_plus_mp, is_dist,
is_local_master, is_master, stat_array, upper_bound)
-from .template import History, StopWordsCriteria, Template, get_audio_info
+from .template import (History, StopWords, StopWordsCriteria, Template,
+                       get_audio_info)

logger = get_logger()
ms_logger = get_ms_logger()
@@ -446,11 +447,14 @@ def inference_stream(
history: Optional[History] = None,
system: Optional[str] = None,
*,
-    generation_config: Optional[GenerationConfig] = None
+    generation_config: Optional[GenerationConfig] = None,
+    stop_words: Optional[List[StopWords]] = None,
) -> Iterator[Tuple[str, History]]:
"""
generation_config: Priority: generation_config > model.generation_config.
"""
if stop_words is None:
stop_words = []
if history is None:
history = []
else:
@@ -479,7 +483,8 @@
if stream_config.max_new_tokens is not None:
stream_config.max_length = 20 # fix max_length, max_new_tokens warning
stream_config.do_sample = True # avoid is_greedy_gen_mode = True
-    stop_words = [template.suffix[-1]]
+    if template.suffix[-1] not in stop_words:
+        stop_words.append(template.suffix[-1])
decode_kwargs = {}
model_kwargs = {}
if audio_info is not None:
@@ -522,13 +527,16 @@ def inference(model: PreTrainedModel,
system: Optional[str] = None,
*,
generation_config: Optional[GenerationConfig] = None,
stop_words: Optional[List[StopWords]] = None,
stream: bool = False,
verbose: bool = False,
prompt_prefix: str = '[PROMPT]',
output_prefix: str = '[OUTPUT]') -> Tuple[str, History]:
"""
generation_config: Priority: generation_config > model.generation_config.
"""
if stop_words is None:
stop_words = []
if history is None:
history = []
else:
@@ -569,7 +577,8 @@ def inference(model: PreTrainedModel,
generation_config.pad_token_id = tokenizer.pad_token_id
if generation_config.max_new_tokens is not None:
generation_config.max_length = 20 # fix max_length, max_new_tokens warning
-    stop_words = [template.suffix[-1]]
+    if template.suffix[-1] not in stop_words:
+        stop_words.append(template.suffix[-1])
stopping_criteria = StoppingCriteriaList(
[StopWordsCriteria(tokenizer, stop_words, **decode_kwargs)])
generate_ids = model.generate(
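With both entry points updated, callers can pass their own stop words, and the template's suffix is still appended automatically when absent. A usage sketch; the model choice, template type, and stop word below are placeholders for illustration, not part of this PR:

```python
# Hypothetical call exercising the new stop_words parameter; 'qwen-7b-chat'
# stands in for any registered model_type.
from swift.llm import get_model_tokenizer, get_template, inference

model, tokenizer = get_model_tokenizer('qwen-7b-chat')
template = get_template('qwen', tokenizer)
response, history = inference(
    model, template, 'What is 1 + 1?',
    stop_words=['Observation:'])  # stop early at an agent-style tool marker
print(response)
```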