From 2f60d858682de9bac685facf90f377556dc092fb Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Fri, 8 Dec 2023 02:52:54 +0800
Subject: [PATCH 1/3] support multi-node

---
 ...56\350\260\203\346\226\207\346\241\243.md" | 28 +++++++++++++++++--
 ...44\350\241\214\345\217\202\346\225\260.md" |  2 +-
 swift/llm/sft.py                              |  3 +-
 swift/llm/utils/argument.py                   | 12 ++++----
 swift/llm/utils/utils.py                      | 22 +++++++--------
 swift/utils/utils.py                          |  4 ++-
 6 files changed, 47 insertions(+), 24 deletions(-)

diff --git "a/docs/source/LLM/LLM\345\276\256\350\260\203\346\226\207\346\241\243.md" "b/docs/source/LLM/LLM\345\276\256\350\260\203\346\226\207\346\241\243.md"
index d977f0600e..dc7e32df8f 100644
--- "a/docs/source/LLM/LLM\345\276\256\350\260\203\346\226\207\346\241\243.md"
+++ "b/docs/source/LLM/LLM\345\276\256\350\260\203\346\226\207\346\241\243.md"
@@ -78,6 +78,13 @@ swift sft \
     --dataset blossom-math-zh \
     --output_dir output \
 
+# Use your own dataset
+CUDA_VISIBLE_DEVICES=0 \
+swift sft \
+    --model_id_or_path qwen/Qwen-7B-Chat \
+    --custom_train_dataset_path chatml.jsonl \
+    --output_dir output \
+
 # Use DDP
 # Experimental environment: 2 * 3090
 # 2 * 23GB GPU memory
@@ -88,11 +95,26 @@ swift sft \
     --dataset blossom-math-zh \
     --output_dir output \
 
-# Use your own dataset
-CUDA_VISIBLE_DEVICES=0 \
+# Multi-node multi-GPU
+# node0
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+NNODES=2 \
+NODE_RANK=0 \
+MASTER_ADDR=127.0.0.1 \
+NPROC_PER_NODE=4 \
 swift sft \
     --model_id_or_path qwen/Qwen-7B-Chat \
-    --custom_train_dataset_path chatml.jsonl \
+    --dataset blossom-math-zh \
+    --output_dir output \
+
+# node1
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+NNODES=2 \
+NODE_RANK=1 \
+MASTER_ADDR=xxx.xxx.xxx.xxx \
+NPROC_PER_NODE=4 \
+swift sft \
+    --model_id_or_path qwen/Qwen-7B-Chat \
+    --dataset blossom-math-zh \
     --output_dir output \
 ```

diff --git "a/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
index 87fd9ae75e..defda908d1 100644
--- "a/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
+++ "b/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
@@ -49,7 +49,7 @@
 - `--optim`: Default is `'adamw_torch'`.
 - `--learning_rate`: Default is `None`, i.e. set to 1e-4 if `sft_type` is lora, or to 2e-5 if `sft_type` is full.
 - `--weight_decay`: Default is `0.01`.
-- `--gradient_accumulation_steps`: Gradient accumulation. Default is `16`. `total_batch_size = batch_size * gradient_accumulation_steps * world_size`.
+- `--gradient_accumulation_steps`: Gradient accumulation. Default is `None`, i.e. set to `math.ceil(16 / self.batch_size / world_size)`. `total_batch_size = batch_size * gradient_accumulation_steps * world_size`.
 - `--max_grad_norm`: Gradient clipping. Default is `0.5`.
 - `--predict_with_generate`: Whether to evaluate generatively. Default is `False`. If set to False, evaluation uses `loss`; if set to True, it uses metrics such as `ROUGE-L`. Generative evaluation is very time-consuming, so choose carefully.
 - `--lr_scheduler_type`: Default is `'cosine'`.
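[Editor's note] On the `--gradient_accumulation_steps` change documented in the hunk above: the new `None` default derives the accumulation steps from the per-device batch size and the world size, so the effective total batch size stays close to the old fixed default of 16 however many GPUs the job spans. A minimal sketch of the arithmetic, using the 2-node x 4-GPU topology from the multi-node example (the variable names here are illustrative, not swift's internals):

```python
import math

# Keep total_batch_size = batch_size * gradient_accumulation_steps * world_size
# near the old fixed default of 16, regardless of the GPU topology.
batch_size = 1
world_size = 8  # e.g. 2 nodes * 4 GPUs, as in the multi-node example above
gradient_accumulation_steps = math.ceil(16 / batch_size / world_size)  # -> 2
assert batch_size * gradient_accumulation_steps * world_size == 16
```

When the division is not exact (say `batch_size=3`, `world_size=2`), `math.ceil` rounds up, so the effective total batch size lands slightly above 16 rather than below it.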
diff --git a/swift/llm/sft.py b/swift/llm/sft.py
index 3e2ae74a78..601200b3df 100644
--- a/swift/llm/sft.py
+++ b/swift/llm/sft.py
@@ -236,7 +236,8 @@ def llm_sft(args: SftArguments) -> str:
         train_sampler_random=args.train_sampler_random,
         report_to=args.report_to,
         deepspeed=args.deepspeed,
-        additional_saved_files=additional_saved_files)
+        additional_saved_files=additional_saved_files,
+        save_on_each_node=True)
 
     if args.gradient_checkpointing:
         model.enable_input_require_grads()
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index abf7d6d036..e1dd3cf8aa 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -179,11 +179,6 @@ def __post_init__(self) -> None:
                 'For example: `--lora_target_modules ALL`. '
                 'If you have already added LoRA on MLP, please ignore this warning.'
             )
-        if self.add_output_dir_suffix:
-            self.output_dir = os.path.join(self.output_dir, self.model_type)
-            if is_master():
-                self.output_dir = add_version_to_work_dir(self.output_dir)
-            logger.info(f'output_dir: {self.output_dir}')
 
         self.torch_dtype, self.fp16, self.bf16 = select_dtype(self)
         world_size = 1
@@ -197,8 +192,11 @@ def __post_init__(self) -> None:
             # Initialize in advance
             if not dist.is_initialized():
                 dist.init_process_group(backend=self.ddp_backend)
-            # Make sure to set the same output_dir when using DDP.
-            self.output_dir = broadcast_string(self.output_dir)
+
+        if self.add_output_dir_suffix:
+            self.output_dir = os.path.join(self.output_dir, self.model_type)
+            self.output_dir = add_version_to_work_dir(self.output_dir)
+            logger.info(f'output_dir: {self.output_dir}')
 
         if self.sft_type in ('lora', 'longlora', 'qalora'):
             if 'int4' in self.model_type or 'int8' in self.model_type:
diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py
index 0e98b4177b..7a32b32645 100644
--- a/swift/llm/utils/utils.py
+++ b/swift/llm/utils/utils.py
@@ -40,6 +40,17 @@
 logger = get_logger()
 ms_logger = get_ms_logger()
 
+logger_format = logging.Formatter('[%(levelname)s:%(name)s] %(message)s')
+
+logger.handlers[0].setFormatter(logger_format)
+ms_logger.handlers[0].setFormatter(logger_format)
+if is_local_master():
+    logger.setLevel(logging.INFO)
+    ms_logger.setLevel(logging.INFO)
+else:
+    logger.setLevel(logging.ERROR)
+    ms_logger.setLevel(logging.ERROR)
+
 os.environ['TOKENIZERS_PARALLELISM'] = 'true'
 
@@ -187,17 +198,6 @@ def dataset_map(
     return LLMDataset(data)
 
 
-logger_format = logging.Formatter('[%(levelname)s:%(name)s] %(message)s')
-
-logger.handlers[0].setFormatter(logger_format)
-ms_logger.handlers[0].setFormatter(logger_format)
-if is_master():
-    logger.setLevel(logging.INFO)
-    ms_logger.setLevel(logging.INFO)
-else:
-    logger.setLevel(logging.ERROR)
-    ms_logger.setLevel(logging.ERROR)
-
 _TArgsClass = TypeVar('_TArgsClass')
 _T = TypeVar('_T')
 NoneType = type(None)
diff --git a/swift/utils/utils.py b/swift/utils/utils.py
index af6573817f..f3fd8ed622 100644
--- a/swift/utils/utils.py
+++ b/swift/utils/utils.py
@@ -11,6 +11,7 @@
 
 from .logger import get_logger
 from .np_utils import stat_array
+from .torch_utils import broadcast_string, is_dist
 
 logger = get_logger()
 
@@ -52,7 +53,8 @@ def add_version_to_work_dir(work_dir: str) -> str:
     """add version"""
     version = _get_version(work_dir)
     time = dt.datetime.now().strftime('%Y%m%d-%H%M%S')
-
+    if is_dist():
+        time = broadcast_string(time)
     work_dir = os.path.join(work_dir, f'v{version}-{time}')
     return work_dir
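[Editor's note] On why patch 1 moves the `output_dir` versioning below `init_process_group` and synchronizes the timestamp: `add_version_to_work_dir` derives the directory name from the current time, so under multi-node DDP every rank must adopt rank 0's string or the nodes would write to different directories. Below is a minimal sketch of how a string can be broadcast over `torch.distributed`; it is an assumption about what a helper like `broadcast_string` does, not swift's actual implementation:

```python
import torch
import torch.distributed as dist


def broadcast_string_sketch(s: str = '', buffer_size: int = 1024, src: int = 0) -> str:
    """Broadcast rank `src`'s string to all ranks (hypothetical helper).

    Assumes the process group is already initialized with a CPU-capable
    backend such as gloo; under nccl the buffer would need to live on GPU.
    """
    buf = torch.zeros(buffer_size, dtype=torch.uint8)
    if dist.get_rank() == src:
        data = s.encode('utf-8')
        assert len(data) <= buffer_size, 'string too long for the buffer'
        buf[:len(data)] = torch.tensor(list(data), dtype=torch.uint8)
    dist.broadcast(buf, src)  # every rank now holds rank src's bytes
    return bytes(buf.tolist()).rstrip(b'\x00').decode('utf-8')
```

With a helper of this shape, the extra `dist.is_initialized()` guard in patch 2 below also makes sense: `add_version_to_work_dir` can be reached before the process group exists, and broadcasting at that point would raise.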
From 94abf0b46425d0c2854ec0190974f0769041579d Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Fri, 8 Dec 2023 02:58:52 +0800
Subject: [PATCH 2/3] fix

---
 swift/utils/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/swift/utils/utils.py b/swift/utils/utils.py
index f3fd8ed622..635c7676c1 100644
--- a/swift/utils/utils.py
+++ b/swift/utils/utils.py
@@ -7,6 +7,7 @@
                     Type, TypeVar)
 
 import numpy as np
+import torch.distributed as dist
 from transformers import HfArgumentParser
 
 from .logger import get_logger
@@ -53,7 +54,7 @@ def add_version_to_work_dir(work_dir: str) -> str:
     """add version"""
     version = _get_version(work_dir)
     time = dt.datetime.now().strftime('%Y%m%d-%H%M%S')
-    if is_dist():
+    if dist.is_initialized() and is_dist():
         time = broadcast_string(time)
     work_dir = os.path.join(work_dir, f'v{version}-{time}')
     return work_dir

From 3490c18a4e8cfd404c79841500f7b42474f39eb5 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Fri, 8 Dec 2023 03:04:39 +0800
Subject: [PATCH 3/3] update save_on_each_node

---
 ...\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" | 1 +
 swift/llm/sft.py                                                | 2 +-
 swift/llm/utils/argument.py                                     | 1 +
 3 files changed, 3 insertions(+), 1 deletion(-)

diff --git "a/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
index defda908d1..778024d0b0 100644
--- "a/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
+++ "b/docs/source/LLM/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md"
@@ -70,6 +70,7 @@
 - `--ignore_args_error`: Whether to ignore Errors raised by command-line argument mistakes. Default is `False`. Set it to True if you need to copy the code into a notebook to run it.
 - `--logging_dir`: Default is `None`, i.e. set to `f'{self.output_dir}/runs'`, the directory where tensorboard files are stored.
 - `--check_model_is_latest`: Whether to check that the model is the latest version. Default is `True`. Set this parameter to `False` if you need to train without network access.
+- `--save_on_each_node`: This parameter takes effect during multi-node training. Default is `True`.
 - `--max_new_tokens`: Default is `2048`. This parameter only takes effect when `predict_with_generate` is set to True.
 - `--do_sample`: Default is `True`. This parameter only takes effect when `predict_with_generate` is set to True.
 - `--temperature`: Default is `0.3`. This parameter only takes effect when `predict_with_generate` is set to True.
diff --git a/swift/llm/sft.py b/swift/llm/sft.py
index 601200b3df..077df7ec40 100644
--- a/swift/llm/sft.py
+++ b/swift/llm/sft.py
@@ -237,7 +237,7 @@ def llm_sft(args: SftArguments) -> str:
         report_to=args.report_to,
         deepspeed=args.deepspeed,
         additional_saved_files=additional_saved_files,
-        save_on_each_node=True)
+        save_on_each_node=args.save_on_each_node)
 
     if args.gradient_checkpointing:
         model.enable_input_require_grads()
diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py
index e1dd3cf8aa..1f9415cc78 100644
--- a/swift/llm/utils/argument.py
+++ b/swift/llm/utils/argument.py
@@ -145,6 +145,7 @@ class SftArguments:
     logging_dir: Optional[str] = None
     report_to: Optional[List[str]] = None
     check_model_is_latest: bool = True
+    save_on_each_node: bool = True
 
     # generation config
     max_new_tokens: int = 2048
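[Editor's note] A closing note on `save_on_each_node`: patch 1 hard-codes it to `True` and patch 3 turns it into a user-facing argument, because the right value depends on storage topology — with a shared filesystem one writer suffices, while without one each node's local main process must save its own copy. A sketch of the usual rank logic behind such a flag, modeled on Hugging Face Trainer semantics rather than taken from its code:

```python
import os


def should_save(save_on_each_node: bool) -> bool:
    """Decide whether this process writes checkpoints (illustrative only).

    Assumes the standard torchrun environment variables RANK and LOCAL_RANK.
    """
    rank = int(os.environ.get('RANK', 0))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    if save_on_each_node:
        # The local main process of every node saves; no shared FS required.
        return local_rank == 0
    # Only the global main process saves; assumes a shared filesystem.
    return rank == 0
```

Since `SftArguments` is a dataclass, a boolean-parsing CLI such as `HfArgumentParser` should accept something like `--save_on_each_node false` on shared-filesystem clusters, though the exact flag syntax depends on the parser in use.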