From 33105341b9a61ac33d80bf80dc8fcafa78ba4dfc Mon Sep 17 00:00:00 2001 From: "huangjintao.hjt" Date: Mon, 23 Sep 2024 16:32:58 +0800 Subject: [PATCH 1/3] fix cpu infer device_map --- swift/llm/sft.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/swift/llm/sft.py b/swift/llm/sft.py index 7bd87d1041..dd99af27df 100644 --- a/swift/llm/sft.py +++ b/swift/llm/sft.py @@ -113,6 +113,25 @@ def llm_sft_megatron(args: SftArguments) -> Dict[str, Any]: plot_images(images_dir, args.logging_dir, ['train/loss'], 0.9) return {} +def get_device_map(): + if is_deepspeed_zero3_enabled() or os.environ.get('ACCELERATE_USE_FSDP', 'False') == 'true': + return None + local_rank = get_dist_setting()[1] + if is_torch_npu_available(): + if local_rank >= 0: + return f'npu:{local_rank}' + else: + return 'npu:0' + elif args.device_map_config is not None: + model_kwargs = {'device_map': args.device_map_config} + else: + model_kwargs = {'low_cpu_mem_usage': True} + if is_dist() and not is_ddp_plus_mp(): + model_kwargs['device_map'] = {'': args.local_rank} + elif torch.cuda.device_count() == 1: + return 'cuda:0' + elif not use_torchacc(): + return 'auto' def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None): @@ -128,21 +147,7 @@ def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None): f'world_size: {args.world_size}, local_world_size: {args.local_world_size}') # Loading Model and Tokenizer - if is_deepspeed_zero3_enabled() or os.environ.get('ACCELERATE_USE_FSDP', 'False') == 'true': - model_kwargs = {'device_map': None} - elif is_torch_npu_available(): - model_kwargs = {'device_map': args.local_rank if args.local_rank >= 0 else 0} - elif args.device_map_config is not None: - model_kwargs = {'device_map': args.device_map_config} - else: - model_kwargs = {'low_cpu_mem_usage': True} - if is_dist() and not is_ddp_plus_mp(): - model_kwargs['device_map'] = {'': args.local_rank} - elif 
torch.cuda.device_count() == 1: - model_kwargs['device_map'] = 'cuda:0' - elif not use_torchacc(): - model_kwargs['device_map'] = 'auto' - + model_kwargs['device_map'] = get_device_map(args.local_rank) if args.device_max_memory: n_gpu = torch.cuda.device_count() assert len(args.device_max_memory) == n_gpu // args.local_world_size From 9503774332ada7081e4591e20dd744b9acc11048 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 23 Sep 2024 16:49:14 +0800 Subject: [PATCH 2/3] update --- swift/llm/infer.py | 15 +++++++-------- swift/llm/sft.py | 34 +++++++++++++++++++++------------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/swift/llm/infer.py b/swift/llm/infer.py index ec2da427de..e8b3afd9fc 100644 --- a/swift/llm/infer.py +++ b/swift/llm/infer.py @@ -15,6 +15,7 @@ from swift.tuners import Swift from swift.utils import (append_to_jsonl, get_logger, get_main, get_model_info, read_multi_line, seed_everything, show_layers) +from .sft import get_default_device_map from .utils import (DeployArguments, InferArguments, MediaTag, Template, get_additional_saved_files, get_dataset, get_model_tokenizer, get_template, inference, inference_stream, is_adapter, is_quant_model, sample_dataset, set_generation_config) @@ -135,16 +136,14 @@ def prepare_model_template(args: InferArguments, device_map: Optional[str] = None, verbose: bool = True, automodel_class=None) -> Tuple[PreTrainedModel, Template]: - - model_kwargs = {} if is_torch_npu_available(): - logger.info(f'device_count: {torch.npu.device_count()}') - if device_map is None: - device_map = 'npu:0' + print(f'device_count: {torch.npu.device_count()}') else: - logger.info(f'device_count: {torch.cuda.device_count()}') - if device_map is None: - device_map = 'auto' if torch.cuda.device_count() > 1 else 'cuda:0' + print(f'device_count: {torch.cuda.device_count()}') + model_kwargs = {} + if device_map is None: + device_map = get_default_device_map() + model_kwargs['device_map'] = device_map if device_map == 
'auto': model_kwargs['low_cpu_mem_usage'] = True model_kwargs['device_map'] = device_map diff --git a/swift/llm/sft.py b/swift/llm/sft.py index dd99af27df..b57411e130 100644 --- a/swift/llm/sft.py +++ b/swift/llm/sft.py @@ -15,8 +15,8 @@ from swift.torchacc_utils import patch_acc_model from swift.trainers import TrainerFactory from swift.trainers.utils import can_return_loss, find_labels -from swift.utils import (append_to_jsonl, check_json_format, compute_acc_metrics, compute_nlg_metrics, get_logger, - get_main, get_model_info, is_ddp_plus_mp, is_dist, is_master, plot_images, +from swift.utils import (append_to_jsonl, check_json_format, compute_acc_metrics, compute_nlg_metrics, get_dist_setting, + get_logger, get_main, get_model_info, is_ddp_plus_mp, is_dist, is_master, plot_images, preprocess_logits_for_metrics, seed_everything, show_layers, use_torchacc) from .accelerator import ta_accelerate from .tuner import prepare_model @@ -113,7 +113,8 @@ def llm_sft_megatron(args: SftArguments) -> Dict[str, Any]: plot_images(images_dir, args.logging_dir, ['train/loss'], 0.9) return {} -def get_device_map(): + +def get_default_device_map(): if is_deepspeed_zero3_enabled() or os.environ.get('ACCELERATE_USE_FSDP', 'False') == 'true': return None local_rank = get_dist_setting()[1] @@ -122,16 +123,15 @@ def get_device_map(): return f'npu:{local_rank}' else: return 'npu:0' - elif args.device_map_config is not None: - model_kwargs = {'device_map': args.device_map_config} + if is_dist() and not is_ddp_plus_mp(): + return f'cuda:{local_rank}' + elif torch.cuda.device_count() == 0: + return 'cpu' + elif torch.cuda.device_count() == 1: + return 'cuda:0' else: - model_kwargs = {'low_cpu_mem_usage': True} - if is_dist() and not is_ddp_plus_mp(): - model_kwargs['device_map'] = {'': args.local_rank} - elif torch.cuda.device_count() == 1: - return 'cuda:0' - elif not use_torchacc(): - return 'auto' + return 'auto' + def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = 
None): @@ -147,7 +147,15 @@ def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None): f'world_size: {args.world_size}, local_world_size: {args.local_world_size}') # Loading Model and Tokenizer - model_kwargs['device_map'] = get_device_map(args.local_rank) + model_kwargs = {} + if not use_torchacc(): + if args.device_map_config is not None: + device_map = args.device_map_config + else: + device_map = get_default_device_map() + model_kwargs['device_map'] = device_map + if device_map == 'auto': + model_kwargs['low_cpu_mem_usage'] = True if args.device_max_memory: n_gpu = torch.cuda.device_count() assert len(args.device_max_memory) == n_gpu // args.local_world_size From 9facc4958ca51bcdaf48017f78f4dcc024ded431 Mon Sep 17 00:00:00 2001 From: "huangjintao.hjt" Date: Mon, 23 Sep 2024 17:12:34 +0800 Subject: [PATCH 3/3] fix --- swift/llm/infer.py | 2 +- swift/llm/sft.py | 7 +++---- swift/llm/utils/utils.py | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/swift/llm/infer.py b/swift/llm/infer.py index e8b3afd9fc..660e721806 100644 --- a/swift/llm/infer.py +++ b/swift/llm/infer.py @@ -15,7 +15,6 @@ from swift.tuners import Swift from swift.utils import (append_to_jsonl, get_logger, get_main, get_model_info, read_multi_line, seed_everything, show_layers) -from .sft import get_default_device_map from .utils import (DeployArguments, InferArguments, MediaTag, Template, get_additional_saved_files, get_dataset, get_model_tokenizer, get_template, inference, inference_stream, is_adapter, is_quant_model, sample_dataset, set_generation_config) @@ -136,6 +135,7 @@ def prepare_model_template(args: InferArguments, device_map: Optional[str] = None, verbose: bool = True, automodel_class=None) -> Tuple[PreTrainedModel, Template]: + from .sft import get_default_device_map if is_torch_npu_available(): print(f'device_count: {torch.npu.device_count()}') else: diff --git a/swift/llm/sft.py b/swift/llm/sft.py index b57411e130..9739bdd665 100644 --- 
a/swift/llm/sft.py +++ b/swift/llm/sft.py @@ -123,12 +123,12 @@ def get_default_device_map(): return f'npu:{local_rank}' else: return 'npu:0' - if is_dist() and not is_ddp_plus_mp(): - return f'cuda:{local_rank}' - elif torch.cuda.device_count() == 0: + if torch.cuda.device_count() == 0: return 'cpu' elif torch.cuda.device_count() == 1: return 'cuda:0' + elif is_dist() and not is_ddp_plus_mp(): + return f'cuda:{local_rank}' else: return 'auto' @@ -369,7 +369,6 @@ def prepare_dataset(args, template: Template, msg: Optional[Dict[str, Any]] = No f'Setting args.preprocess_num_proc to: {args.preprocess_num_proc}') else: template.model = None - logger.info(f'Using num_proc: {args.preprocess_num_proc}') td0, tkwargs0 = template.encode(train_dataset[0]) print_example(td0, tokenizer, tkwargs0) train_dataset = dataset_map(train_dataset, template.encode, args.preprocess_num_proc, streaming=args.streaming) diff --git a/swift/llm/utils/utils.py b/swift/llm/utils/utils.py index 548c15c534..5ec51d53f9 100644 --- a/swift/llm/utils/utils.py +++ b/swift/llm/utils/utils.py @@ -302,7 +302,7 @@ def _map_mp(dataset: HfDataset, map_func: MapFunc, num_proc: int) -> List[Dict[s # Solving the unordered problem data = [None] * len(dataset) num_proc = min(num_proc, len(dataset)) - for d in tqdm(_map_mp_i(dataset, map_func, num_proc), total=len(dataset)): + for d in tqdm(_map_mp_i(dataset, map_func, num_proc), total=len(dataset), desc=f'Map (num_proc={num_proc})'): data[d[0]] = d[1] return data @@ -317,7 +317,7 @@ def dataset_map(dataset: DATASET_TYPE, single_map = partial(_single_map, map_func=map_func) if num_proc == 1: data = [] - for d in tqdm(dataset): + for d in tqdm(dataset, desc='Map'): d = single_map(d) data.append(d) else: