From 5168b84a9a1275e778d73439f1082b4480301775 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 Dec 2023 21:45:00 +0800 Subject: [PATCH 1/6] update readme & phi2-3b --- README.md | 4 ++- README_CN.md | 4 ++- ...53\351\200\237\344\275\277\347\224\250.md" | 3 ++ ...14\346\225\260\346\215\256\351\233\206.md" | 1 + .../pytorch/llm/scripts/phi2_3b/lora/infer.sh | 13 ++++++++ .../pytorch/llm/scripts/phi2_3b/lora/sft.sh | 17 +++++++++++ swift/llm/utils/argument.py | 15 ++++++---- swift/llm/utils/model.py | 30 +++++++++++++++++-- swift/llm/utils/template.py | 8 +++-- swift/llm/utils/utils.py | 17 ++++++++--- 10 files changed, 97 insertions(+), 15 deletions(-) create mode 100644 examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh create mode 100644 examples/pytorch/llm/scripts/phi2_3b/lora/sft.sh diff --git a/README.md b/README.md index af7a5b46f8..6e55aa2870 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@

-ModelScope Hub +ModelScope Hub  |  Docs
中文  |  English

@@ -60,6 +60,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用 ## 🎉 News +- 2023.12.19: Support [phi2-3b](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/phi2_3b). - 2023.12.18: Support for **VLLM** for inference acceleration and deployment. For more details, refer to [VLLM Inference Acceleration and Deployment](https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md). - 2023.12.15: Support **deepseek**, **deepseek-coder** series: deepseek-7b, deepseek-7b-chat, deepseek-67b, deepseek-67b-chat, openbuddy-deepseek-67b-chat, deepseek-coder-1_3b, deepseek-coder-1_3b-chat, deepseek-coder-6_7b, deepseek-coder-6_7b-chat, deepseek-coder-33b, deepseek-coder-33b-chat. - 2023.12.13: Support mistral-7b-chat-v2, [mixtral-7b-moe](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/mixtral_7b_moe), [mixtral-7b-moe-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/mixtral_7b_moe_chat). @@ -139,6 +140,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用 - Coding: - codefuse series: [codefuse-codellama-34b-chat](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B/summary) - deepseek-coder series: [deepseek-coder-1_3b](https://modelscope.cn/models/deepseek-ai/deepseek-coder-1.3b-base/summary), [deepseek-coder-1_3b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-coder-1.3b-instruct/summary), [deepseek-coder-6_7b](https://modelscope.cn/models/deepseek-ai/deepseek-coder-6.7b-base/summary), [deepseek-coder-6_7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-coder-6.7b-instruct/summary), [deepseek-coder-33b](https://modelscope.cn/models/deepseek-ai/deepseek-coder-33b-base/summary), [deepseek-coder-33b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-coder-33b-instruct/summary) + - phi series: [phi2-3b](https://modelscope.cn/models/AI-ModelScope/phi-2/summary) - Supported Datasets: [[Detail]](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.md#%E6%95%B0%E6%8D%AE%E9%9B%86) - NLP: - General: 🔥[alpaca-en](https://modelscope.cn/datasets/AI-ModelScope/alpaca-gpt4-data-en/summary)(gpt4), 🔥[alpaca-zh](https://modelscope.cn/datasets/AI-ModelScope/alpaca-gpt4-data-zh/summary)(gpt4), [multi-alpaca-all](https://www.modelscope.cn/datasets/damo/nlp_polylm_multialpaca_sft/summary), [instinwild-en](https://www.modelscope.cn/datasets/wyj123456/instinwild/summary), [instinwild-zh](https://www.modelscope.cn/datasets/wyj123456/instinwild/summary), [cot-en](https://www.modelscope.cn/datasets/YorickHe/CoT/summary), [cot-zh](https://www.modelscope.cn/datasets/YorickHe/CoT/summary), [firefly-all-zh](https://www.modelscope.cn/datasets/wyj123456/firefly/summary), [instruct-en](https://www.modelscope.cn/datasets/wyj123456/instruct/summary), [gpt4all-en](https://www.modelscope.cn/datasets/wyj123456/GPT4all/summary), [sharegpt-en](https://www.modelscope.cn/datasets/huangjintao/sharegpt/summary), [sharegpt-zh](https://www.modelscope.cn/datasets/huangjintao/sharegpt/summary), [tutu-v2-sft-mixture](https://modelscope.cn/datasets/AI-ModelScope/tulu-v2-sft-mixture/summary), [wikipedia-zh](https://modelscope.cn/datasets/AI-ModelScope/wikipedia-cn-20230720-filtered/summary), [open-orca](https://modelscope.cn/datasets/AI-ModelScope/OpenOrca/summary), [open-orca-gpt4](https://modelscope.cn/datasets/AI-ModelScope/OpenOrca/summary), [sharegpt-gpt4](https://modelscope.cn/datasets/AI-ModelScope/sharegpt_gpt4/summary) diff --git a/README_CN.md b/README_CN.md index 0d6a213e3f..ec80a92bce 100644 --- a/README_CN.md +++ b/README_CN.md @@ -7,7 +7,7 @@

-魔搭社区 +魔搭社区  |  文档
中文  |  English

@@ -58,6 +58,7 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展 用户可以查看 [SWIFT官方文档](docs/source/GetStarted/快速使用.md) 来了解详细信息。 ## 🎉 新闻 +- 2023.12.19: 支持[phi2-3b](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/phi2_3b). - 2023.12.18: 支持**VLLM**进行推理加速和部署. 具体可以查看[VLLM推理加速与部署](https://github.com/modelscope/swift/blob/main/docs/source/LLM/VLLM推理加速与部署.md). - 2023.12.15: 支持**deepseek**, **deepseek-coder**系列: deepseek-7b, deepseek-7b-chat, deepseek-67b, deepseek-67b-chat, openbuddy-deepseek-67b-chat, deepseek-coder-1_3b, deepseek-coder-1_3b-chat, deepseek-coder-6_7b, deepseek-coder-6_7b-chat, deepseek-coder-33b, deepseek-coder-33b-chat. - 2023.12.13: 支持mistral-7b-chat-v2, [mixtral-7b-moe](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/mixtral_7b_moe), [mixtral-7b-moe-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/mixtral_7b_moe_chat). @@ -137,6 +138,7 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展 - 代码: - codefuse 系列: [codefuse-codellama-34b-chat](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B/summary) - deepseek-coder 系列: [deepseek-coder-1_3b](https://modelscope.cn/models/deepseek-ai/deepseek-coder-1.3b-base/summary), [deepseek-coder-1_3b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-coder-1.3b-instruct/summary), [deepseek-coder-6_7b](https://modelscope.cn/models/deepseek-ai/deepseek-coder-6.7b-base/summary), [deepseek-coder-6_7b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-coder-6.7b-instruct/summary), [deepseek-coder-33b](https://modelscope.cn/models/deepseek-ai/deepseek-coder-33b-base/summary), [deepseek-coder-33b-chat](https://modelscope.cn/models/deepseek-ai/deepseek-coder-33b-instruct/summary) + - phi 系列: [phi2-3b](https://modelscope.cn/models/AI-ModelScope/phi-2/summary) - 支持的数据集: [[详细]](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.md#%E6%95%B0%E6%8D%AE%E9%9B%86) - NLP: - 通用: 🔥[alpaca-en](https://modelscope.cn/datasets/AI-ModelScope/alpaca-gpt4-data-en/summary)(gpt4), 🔥[alpaca-zh](https://modelscope.cn/datasets/AI-ModelScope/alpaca-gpt4-data-zh/summary)(gpt4), [multi-alpaca-all](https://www.modelscope.cn/datasets/damo/nlp_polylm_multialpaca_sft/summary), [instinwild-en](https://www.modelscope.cn/datasets/wyj123456/instinwild/summary), [instinwild-zh](https://www.modelscope.cn/datasets/wyj123456/instinwild/summary), [cot-en](https://www.modelscope.cn/datasets/YorickHe/CoT/summary), [cot-zh](https://www.modelscope.cn/datasets/YorickHe/CoT/summary), [firefly-all-zh](https://www.modelscope.cn/datasets/wyj123456/firefly/summary), [instruct-en](https://www.modelscope.cn/datasets/wyj123456/instruct/summary), [gpt4all-en](https://www.modelscope.cn/datasets/wyj123456/GPT4all/summary), [sharegpt-en](https://www.modelscope.cn/datasets/huangjintao/sharegpt/summary), [sharegpt-zh](https://www.modelscope.cn/datasets/huangjintao/sharegpt/summary), [tutu-v2-sft-mixture](https://modelscope.cn/datasets/AI-ModelScope/tulu-v2-sft-mixture/summary), [wikipedia-zh](https://modelscope.cn/datasets/AI-ModelScope/wikipedia-cn-20230720-filtered/summary), [open-orca](https://modelscope.cn/datasets/AI-ModelScope/OpenOrca/summary), [open-orca-gpt4](https://modelscope.cn/datasets/AI-ModelScope/OpenOrca/summary), [sharegpt-gpt4](https://modelscope.cn/datasets/AI-ModelScope/sharegpt_gpt4/summary) diff --git "a/docs/source/GetStarted/\345\277\253\351\200\237\344\275\277\347\224\250.md" "b/docs/source/GetStarted/\345\277\253\351\200\237\344\275\277\347\224\250.md" index 6596d4d325..7b2db22196 100644 --- "a/docs/source/GetStarted/\345\277\253\351\200\237\344\275\277\347\224\250.md" +++ "b/docs/source/GetStarted/\345\277\253\351\200\237\344\275\277\347\224\250.md" @@ -30,6 +30,9 @@ pip install ms-swift -U SWIFT库提供了**LLM&AIGC模型的训练推理脚手架**,支持LLaMA、QWen、ChatGLM、Stable Diffusion等多种模型的直接训练和推理,并且集成了SWIFT库提供的tuners, 开发者可以直接使用。它们的位置在:https://github.com/modelscope/swift/tree/main/examples/pytorch/llm +- LLM训练和推理可以查看[LLM微调文档](https://github.com/modelscope/swift/blob/main/docs/source/LLM/LLM微调文档.md) +- AIGC训练和推理可以查看[文生图微调文档](https://github.com/modelscope/swift/blob/main/docs/source/AIGC/AnimateDiff微调推理文档.md) + 如果需要在自定义的训练流程中使用tuners,可以参考下面的代码。下面的代码使用LoRA在分类任务上训练了`bert-base-uncased`模型: ```python diff --git "a/docs/source/LLM/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/LLM/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index b45b2369c6..6759428628 100644 --- "a/docs/source/LLM/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/LLM/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -108,6 +108,7 @@ |deepseek-coder-6_7b-chat|[deepseek-ai/deepseek-coder-6.7b-instruct](https://modelscope.cn/models/deepseek-ai/deepseek-coder-6.7b-instruct/summary)|q_proj, k_proj, v_proj|deepseek-coder|✔|✔|| |deepseek-coder-33b|[deepseek-ai/deepseek-coder-33b-base](https://modelscope.cn/models/deepseek-ai/deepseek-coder-33b-base/summary)|q_proj, k_proj, v_proj|default-generation-bos|✔|✔|| |deepseek-coder-33b-chat|[deepseek-ai/deepseek-coder-33b-instruct](https://modelscope.cn/models/deepseek-ai/deepseek-coder-33b-instruct/summary)|q_proj, k_proj, v_proj|deepseek-coder|✔|✔|| +|phi2-3b|[AI-ModelScope/phi-2](https://modelscope.cn/models/AI-ModelScope/phi-2/summary)|Wqkv|default-generation|✔|✔|| ## 数据集 diff --git a/examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh b/examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh new file mode 100644 index 0000000000..1a591c9030 --- /dev/null +++ b/examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh @@ -0,0 +1,13 @@ +# Experimental environment: A10 +# 8GB GPU memory +CUDA_VISIBLE_DEVICES=0 \ +swift infer \ + --ckpt_dir "/mnt/workspace/my_git/swift/examples/pytorch/llm/output/phi2-3b/v0-20231219-204021/checkpoint-100" \ + --load_dataset_config true \ + --max_length 2048 \ + --use_flash_attn false \ + --max_new_tokens 2048 \ + --temperature 0.1 \ + --top_p 0.7 \ + --repetition_penalty 1.05 \ + --merge_lora_and_save false \ diff --git a/examples/pytorch/llm/scripts/phi2_3b/lora/sft.sh b/examples/pytorch/llm/scripts/phi2_3b/lora/sft.sh new file mode 100644 index 0000000000..0a58c3cb45 --- /dev/null +++ b/examples/pytorch/llm/scripts/phi2_3b/lora/sft.sh @@ -0,0 +1,17 @@ +# Experimental environment: A100 +# 60GB GPU memory +CUDA_VISIBLE_DEVICES=0 \ +swift sft \ + --model_type phi2-3b \ + --sft_type lora \ + --template_type default \ + --train_dataset_sample 20000 \ + --eval_steps 100 \ + --output_dir output \ + --num_train_epochs 1 \ + --max_length 2048 \ + --learning_rate 1e-4 \ + --use_flash_attn true \ + --only_save_model true \ + --lora_target_modules ALL \ + --dataset codefuse-python-en \ diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py index 09d51bd4b4..b992a80ceb 100644 --- a/swift/llm/utils/argument.py +++ b/swift/llm/utils/argument.py @@ -94,7 +94,7 @@ class SftArguments: neftune_alpha: float = 0.0 - gradient_checkpointing: bool = True + gradient_checkpointing: Optional[bool] = None deepspeed_config_path: Optional[str] = None # e.g. 'ds_config/zero2.json' batch_size: int = 1 eval_batch_size: Optional[int] = None @@ -299,8 +299,13 @@ def __post_init__(self) -> None: logger.info( f'Setting self.preprocess_num_proc: {self.preprocess_num_proc}' ) - if 'moe' in self.model_type: - assert self.gradient_checkpointing is False, 'moe not support gradient_checkpointing' + model_info = MODEL_MAPPING[self.model_type] + support_gradient_checkpointing = model_info.get( + 'support_gradient_checkpointing', True) + if self.gradient_checkpointing is None: + self.gradient_checkpointing = support_gradient_checkpointing + elif not support_gradient_checkpointing: + assert self.gradient_checkpointing is False, 'not support gradient_checkpointing' @dataclass @@ -428,8 +433,8 @@ def __post_init__(self) -> None: logger.warning('Setting overwrite_generation_config: False') if self.ckpt_dir is None: self.sft_type = 'full' - support_vllm = MODEL_MAPPING[self.model_type].get( - 'support_vllm', False) + model_info = MODEL_MAPPING[self.model_type] + support_vllm = model_info.get('support_vllm', False) if self.infer_backend == 'AUTO': if self.sft_type == 'full' and is_vllm_available( ) and support_vllm: diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py index 4aeb013db3..6b4483c124 100644 --- a/swift/llm/utils/model.py +++ b/swift/llm/utils/model.py @@ -151,6 +151,8 @@ class ModelType: deepseek_coder_6_7b_chat = 'deepseek-coder-6_7b-chat' deepseek_coder_33b = 'deepseek-coder-33b' deepseek_coder_33b_chat = 'deepseek-coder-33b-chat' + # phi + phi2_3b = 'phi2-3b' @classmethod def get_model_name_list(cls) -> List[str]: @@ -170,6 +172,7 @@ class LoRATM(NamedTuple): qwen = ['c_attn'] polylm = ['c_attn'] bloom = ['query_key_value'] + phi = ['Wqkv'] GetModelTokenizerFunction = Callable[..., Tuple[Optional[PreTrainedModel], @@ -777,7 +780,8 @@ def cross_entropy_forward(self, inputs: Tensor, TemplateType.default_generation_bos, requires=['transformers>=4.36'], support_flash_attn=True, - support_vllm=True) + support_vllm=True, + support_gradient_checkpointing=False) @register_model( ModelType.mixtral_7b_moe_chat, 'AI-ModelScope/Mixtral-8x7B-Instruct-v0.1', @@ -785,7 +789,8 @@ def cross_entropy_forward(self, inputs: Tensor, TemplateType.llama, requires=['transformers>=4.36'], support_flash_attn=True, - support_vllm=True) + support_vllm=True, + support_gradient_checkpointing=False) def get_model_tokenizer_with_flash_attn(model_dir: str, torch_dtype: Dtype, model_kwargs: Dict[str, Any], @@ -1291,6 +1296,27 @@ def get_model_tokenizer_codellama(model_dir: str, **kwargs) +@register_model( + ModelType.phi2_3b, + 'AI-ModelScope/phi-2', + LoRATM.phi, + TemplateType.default_generation, + support_flash_attn=True, + support_vllm=True, + support_gradient_checkpointing=False) +def get_model_tokenizer_phi(model_dir: str, + torch_dtype: Dtype, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + model_config = AutoConfig.from_pretrained( + model_dir, trust_remote_code=True) + use_flash_attn = kwargs.get('use_flash_attn', False) + model_config.flash_attn = use_flash_attn + return get_model_tokenizer_from_repo(model_dir, torch_dtype, model_kwargs, + load_model, model_config, **kwargs) + + def fix_transformers_upgrade(module: PreTrainedModel) -> None: # from 4.35, transformers changes its arguments of _set_gradient_checkpointing if version.parse(transformers.__version__) >= version.parse('4.35'): diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index fb46b7ad24..924e65c562 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -234,14 +234,18 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, self.tokenizer = tokenizer self.stop_words = stop_words self.decode_kwargs = decode_kwargs + self.start_idx = -1 def __call__(self, input_ids: Tensor, scores: Tensor) -> bool: + if self.start_idx == -1: + self.start_idx = len(input_ids[0]) - 1 tokenizer = self.tokenizer stop_words = self.stop_words - text = tokenizer.decode(input_ids[0], **self.decode_kwargs) + text = tokenizer.decode(input_ids[0, self.start_idx:], + **self.decode_kwargs) for stop_word in stop_words: if isinstance(stop_word, str): - if text.endswith(stop_word): + if stop_word in text: return True elif isinstance(stop_word, list) and len(stop_word) > 0: res = [] diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py index 23d5a61816..95d207a362 100644 --- a/swift/llm/utils/utils.py +++ b/swift/llm/utils/utils.py @@ -39,7 +39,8 @@ from swift.hub import ModelScopeConfig from swift.utils import (get_dist_setting, get_logger, is_ddp_plus_mp, is_dist, is_local_master, is_master, stat_array, upper_bound) -from .template import History, StopWordsCriteria, Template, get_audio_info +from .template import (History, StopWords, StopWordsCriteria, Template, + get_audio_info) logger = get_logger() ms_logger = get_ms_logger() @@ -446,11 +447,14 @@ def inference_stream( history: Optional[History] = None, system: Optional[str] = None, *, - generation_config: Optional[GenerationConfig] = None + generation_config: Optional[GenerationConfig] = None, + stop_words: Optional[List[StopWords]] = None, ) -> Iterator[Tuple[str, History]]: """ generation_config: Priority: generation_config > model.generation_config. """ + if stop_words is None: + stop_words = [] if history is None: history = [] else: @@ -479,7 +483,8 @@ def inference_stream( if stream_config.max_new_tokens is not None: stream_config.max_length = 20 # fix max_length, max_new_tokens warning stream_config.do_sample = True # avoid is_greedy_gen_mode = True - stop_words = [template.suffix[-1]] + if template.suffix[-1] not in stop_words: + stop_words.append(template.suffix[-1]) decode_kwargs = {} model_kwargs = {} if audio_info is not None: @@ -522,6 +527,7 @@ def inference(model: PreTrainedModel, system: Optional[str] = None, *, generation_config: Optional[GenerationConfig] = None, + stop_words: Optional[List[StopWords]] = None, stream: bool = False, verbose: bool = False, prompt_prefix: str = '[PROMPT]', @@ -529,6 +535,8 @@ def inference(model: PreTrainedModel, """ generation_config: Priority: generation_config > model.generation_config. """ + if stop_words is None: + stop_words = [] if history is None: history = [] else: @@ -569,7 +577,8 @@ def inference(model: PreTrainedModel, generation_config.pad_token_id = tokenizer.pad_token_id if generation_config.max_new_tokens is not None: generation_config.max_length = 20 # fix max_length, max_new_tokens warning - stop_words = [template.suffix[-1]] + if template.suffix[-1] not in stop_words: + stop_words.append(template.suffix[-1]) stopping_criteria = StoppingCriteriaList( [StopWordsCriteria(tokenizer, stop_words, **decode_kwargs)]) generate_ids = model.generate( From 6c0252cf3f0567a8531914edaff19f4246e05d23 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 Dec 2023 21:53:06 +0800 Subject: [PATCH 2/6] update readme --- README.md | 4 ++-- README_CN.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6e55aa2870..248b5145b6 100644 --- a/README.md +++ b/README.md @@ -7,9 +7,9 @@

-ModelScope Hub  |  Docs +ModelScope Hub
- 中文  |  English + 中文  |  English  |  Docs

diff --git a/README_CN.md b/README_CN.md index ec80a92bce..e0ad74aeda 100644 --- a/README_CN.md +++ b/README_CN.md @@ -7,9 +7,9 @@

-魔搭社区  |  文档 +魔搭社区
- 中文  |  English + 中文  |  English  |  文档

From dcb763a6a236008d79396861d3aa001b613ed116 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 Dec 2023 21:56:21 +0800 Subject: [PATCH 3/6] update docs --- .../\345\277\253\351\200\237\344\275\277\347\224\250.md" | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git "a/docs/source/GetStarted/\345\277\253\351\200\237\344\275\277\347\224\250.md" "b/docs/source/GetStarted/\345\277\253\351\200\237\344\275\277\347\224\250.md" index 7b2db22196..f543aac4d0 100644 --- "a/docs/source/GetStarted/\345\277\253\351\200\237\344\275\277\347\224\250.md" +++ "b/docs/source/GetStarted/\345\277\253\351\200\237\344\275\277\347\224\250.md" @@ -30,8 +30,8 @@ pip install ms-swift -U SWIFT库提供了**LLM&AIGC模型的训练推理脚手架**,支持LLaMA、QWen、ChatGLM、Stable Diffusion等多种模型的直接训练和推理,并且集成了SWIFT库提供的tuners, 开发者可以直接使用。它们的位置在:https://github.com/modelscope/swift/tree/main/examples/pytorch/llm -- LLM训练和推理可以查看[LLM微调文档](https://github.com/modelscope/swift/blob/main/docs/source/LLM/LLM微调文档.md) -- AIGC训练和推理可以查看[文生图微调文档](https://github.com/modelscope/swift/blob/main/docs/source/AIGC/AnimateDiff微调推理文档.md) +- LLM训练和推理可以查看: [LLM微调文档](https://github.com/modelscope/swift/blob/main/docs/source/LLM/LLM微调文档.md) +- AIGC训练和推理可以查看: [文生图微调文档](https://github.com/modelscope/swift/blob/main/docs/source/AIGC/AnimateDiff微调推理文档.md) 如果需要在自定义的训练流程中使用tuners,可以参考下面的代码。下面的代码使用LoRA在分类任务上训练了`bert-base-uncased`模型: From 020ea68260452fe845422b55806d229cb673aab7 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 Dec 2023 21:58:25 +0800 Subject: [PATCH 4/6] update sh --- examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh b/examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh index 1a591c9030..accbb2c140 100644 --- a/examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh +++ b/examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh @@ -2,7 +2,7 @@ # 8GB GPU memory CUDA_VISIBLE_DEVICES=0 \ swift infer \ - --ckpt_dir "/mnt/workspace/my_git/swift/examples/pytorch/llm/output/phi2-3b/v0-20231219-204021/checkpoint-100" \ + --ckpt_dir "phi2-3b/vx_xxx/checkpoint-xxx" \ --load_dataset_config true \ --max_length 2048 \ --use_flash_attn false \ From eaa9b3cfd4e05252e659e84246dcb3f0c9bf32aa Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 Dec 2023 21:59:44 +0800 Subject: [PATCH 5/6] update sh --- examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh b/examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh index accbb2c140..9cdaf3781f 100644 --- a/examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh +++ b/examples/pytorch/llm/scripts/phi2_3b/lora/infer.sh @@ -2,7 +2,7 @@ # 8GB GPU memory CUDA_VISIBLE_DEVICES=0 \ swift infer \ - --ckpt_dir "phi2-3b/vx_xxx/checkpoint-xxx" \ + --ckpt_dir "output/phi2-3b/vx_xxx/checkpoint-xxx" \ --load_dataset_config true \ --max_length 2048 \ --use_flash_attn false \ From 9be807485b1cae2bb725edce7930cd3a4b416658 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 Dec 2023 22:12:48 +0800 Subject: [PATCH 6/6] update sh --- examples/pytorch/llm/scripts/phi2_3b/lora/sft.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/pytorch/llm/scripts/phi2_3b/lora/sft.sh b/examples/pytorch/llm/scripts/phi2_3b/lora/sft.sh index 0a58c3cb45..00eac0d155 100644 --- a/examples/pytorch/llm/scripts/phi2_3b/lora/sft.sh +++ b/examples/pytorch/llm/scripts/phi2_3b/lora/sft.sh @@ -15,3 +15,4 @@ swift sft \ --only_save_model true \ --lora_target_modules ALL \ --dataset codefuse-python-en \ + --gradient_checkpointing false \