diff --git "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" index 001bbcf9d0..c3c14ffcc7 100644 --- "a/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" +++ "b/docs/source/Instruction/Megatron-SWIFT\350\256\255\347\273\203.md" @@ -205,6 +205,16 @@ swift export \ ``` - 注意:`mcore_adapters`文件夹中包含`args.json`文件,转换过程中会读取文件中`mcore_model`和LoRA相关的参数信息,并将`mcore_model`和`mcore_adapters`进行merge-lora成完整权重,最终转换成HF格式权重。 +如果你只想merge-lora,而不希望转成HF格式权重,用于后续DPO训练,可以使用以下脚本: +```shell +CUDA_VISIBLE_DEVICES=0 \ +swift export \ + --mcore_adapters megatron_output/Qwen2.5-7B-Instruct/vx-xxx \ + --to_mcore true \ + --torch_dtype bfloat16 \ + --output_dir megatron_output/Qwen2.5-7B-Instruct/vx-xxx-mcore \ + --test_convert_precision true +``` ## Benchmark diff --git a/docs/source_en/Instruction/Megatron-SWIFT-Training.md b/docs/source_en/Instruction/Megatron-SWIFT-Training.md index 6f5fdf7f33..c4cb91060e 100644 --- a/docs/source_en/Instruction/Megatron-SWIFT-Training.md +++ b/docs/source_en/Instruction/Megatron-SWIFT-Training.md @@ -213,6 +213,18 @@ swift export \ - Note: The `mcore_adapters` folder contains an `args.json` file. During the conversion process, parameters related to `mcore_model` and LoRA will be loaded from this file. The system will then perform a merge-lora operation between the `mcore_model` and `mcore_adapters` to obtain the complete model weights, and finally convert them into HuggingFace (HF) format. +If you only want to merge the LoRA weights without converting them to Hugging Face format, for subsequent DPO training, you can use the following script: + +```shell +CUDA_VISIBLE_DEVICES=0 \ +swift export \ + --mcore_adapters megatron_output/Qwen2.5-7B-Instruct/vx-xxx \ + --to_mcore true \ + --torch_dtype bfloat16 \ + --output_dir megatron_output/Qwen2.5-7B-Instruct/vx-xxx-mcore \ + --test_convert_precision true +``` + ## Benchmark The speed comparison of full-parameter training for Dense/MoE models using `megatron sft` and `swift sft` on a single machine with eight A800 GPUs is shown below. The corresponding scripts can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/benchmark). 
diff --git a/swift/llm/export/export.py b/swift/llm/export/export.py
index 3e2f12ed7a..23f9202735 100644
--- a/swift/llm/export/export.py
+++ b/swift/llm/export/export.py
@@ -32,12 +32,12 @@ def run(self):
             export_to_ollama(args)
         elif args.to_cached_dataset:
             export_cached_dataset(args)
+        elif args.to_hf or args.mcore_adapters and args.to_mcore:
+            from swift.megatron import convert_mcore2hf
+            convert_mcore2hf(args)
         elif args.to_mcore:
             from swift.megatron import convert_hf2mcore
             convert_hf2mcore(args)
-        elif args.to_hf:
-            from swift.megatron import convert_mcore2hf
-            convert_mcore2hf(args)
         elif args.push_to_hub:
             model_dir = args.adapters and args.adapters[0] or args.model_dir
             assert model_dir, f'model_dir: {model_dir}'
diff --git a/swift/llm/infer/utils.py b/swift/llm/infer/utils.py
index 757da2c226..3ebcbf5d8b 100644
--- a/swift/llm/infer/utils.py
+++ b/swift/llm/infer/utils.py
@@ -143,7 +143,8 @@ def prepare_adapter(args, model, adapters=None):
 
 def prepare_model_template(args, **kwargs):
     model, processor = args.get_model_processor(**kwargs)
-    model = prepare_adapter(args, model)
     template = args.get_template(processor)
-    update_generation_config_eos_token(model.generation_config, template)
+    if model is not None:
+        model = prepare_adapter(args, model)
+        update_generation_config_eos_token(model.generation_config, template)
     return model, template
diff --git a/swift/llm/train/tuner.py b/swift/llm/train/tuner.py
index a4114b9287..2349ceced2 100644
--- a/swift/llm/train/tuner.py
+++ b/swift/llm/train/tuner.py
@@ -336,10 +336,8 @@ def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_t
                 tuner: Tuner = extra_tuners[args.train_type]
             else:
                 tuner = Swift
-            kwargs = {}
             assert not args.adapters or len(args.adapters) == 1, f'args.adapters: {args.adapters}'
-            model = tuner.from_pretrained(
-                model, args.resume_from_checkpoint or args.adapters[0], is_trainable=True, **kwargs)
+            model = tuner.from_pretrained(model, args.resume_from_checkpoint or args.adapters[0], is_trainable=True)
         else:
             if args.train_type in extra_tuners:
                 tuner: Tuner = extra_tuners[args.train_type]
diff --git a/swift/megatron/init.py b/swift/megatron/init.py
index a80d156dab..f816bfd1a1 100644
--- a/swift/megatron/init.py
+++ b/swift/megatron/init.py
@@ -659,9 +659,9 @@ def _patch_peft_ModulesToSaveWrapper():
     from megatron.core.dist_checkpointing.mapping import ShardedStateDict
     from .utils import tuners_sharded_state_dict
 
-    ModulesToSaveWrapper = peft_module.ModulesToSaveWrapper
+    OriginModulesToSaveWrapper = peft_module.ModulesToSaveWrapper
 
-    class NewModulesToSaveWrapper(ModulesToSaveWrapper):
+    class ModulesToSaveWrapper(OriginModulesToSaveWrapper):
 
         def __init__(self, module_to_save, *args, **kwargs):
             tp_group = getattr(module_to_save, 'tp_group', None)
@@ -694,7 +694,7 @@ def sharded_state_dict(
                 f'{prefix}modules_to_save.default.weight']
             return sharded_state_dict
 
-    peft_module.ModulesToSaveWrapper = NewModulesToSaveWrapper
+    peft_module.ModulesToSaveWrapper = ModulesToSaveWrapper
 
 
 def _patch_TransformerLayer():
@@ -790,9 +790,20 @@ def _worker(plan_shard):
     FileSystemReader.read_data = read_data
 
 
+def _patch_TELinear():
+    from megatron.core.extensions.transformer_engine import TELinear
+
+    def __repr__(self):
+        return (f'{type(self).__name__}(in_features={self.in_features}, '
+                f'out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})')
+
+    TELinear.__repr__ = __repr__
+
+
 def _patch_megatron():
     _patch_flash_attn()
     _patch_transformer_engine()
+    _patch_TELinear()
     _patch__batched_p2p_ops()
     _patch_mla_attention()
     _patch_TEGroupedLinear()
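The `_patch_TELinear` addition above follows a common monkey-patching pattern: define a plain function taking `self` and assign it to the class, so every existing and future instance picks up the new `__repr__`. A minimal, self-contained sketch of the same technique using a dummy linear class (TransformerEngine itself is deliberately not imported, so the names here are illustrative only):

```python
class DummyLinear:
    """Stand-in for TELinear; only the attributes read by the patched __repr__."""

    def __init__(self, in_features: int, out_features: int, use_bias: bool = True, tp_size: int = 1):
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.tp_size = tp_size


def _patched_repr(self) -> str:
    # Same format as the __repr__ installed on TELinear in the hunk above.
    return (f'{type(self).__name__}(in_features={self.in_features}, '
            f'out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})')


# Assigning the function to the class rebinds __repr__ for all instances.
DummyLinear.__repr__ = _patched_repr

print(DummyLinear(4096, 11008, use_bias=False, tp_size=2))
# DummyLinear(in_features=4096, out_features=11008, bias=False, TP=2)
```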
diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py
index a2eb2ac1e1..f0fef7d3a2 100644
--- a/swift/megatron/utils/convert.py
+++ b/swift/megatron/utils/convert.py
@@ -168,7 +168,7 @@ def convert_hf2mcore(args: ExportArguments) -> None:
     logger.info(f'megatron_config: {kwargs}')
     _check_megatron_kwargs(kwargs)
     current_convert_kwargs = convert_kwargs.copy()
-    if hf_model.model_info.is_moe_model:
+    if args.model_info.is_moe_model:
         current_convert_kwargs['moe_grouped_gemm'] = True
     megatron_args = MegatronArguments(
         **kwargs, **current_convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype)
@@ -191,12 +191,8 @@ def convert_hf2mcore(args: ExportArguments) -> None:
 
 def convert_mcore2hf(args: ExportArguments) -> None:
     from swift.megatron import prepare_mcore_model, adapter_state_dict_context
-    hf_model, template = prepare_model_template(args)
+    hf_model, template = prepare_model_template(args, load_model=args.to_hf)
     processor = template.processor
-    if args.thread_count is None:
-        checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
-        args.thread_count = max(math.ceil(checkpoint_size / 10), 2)  # 10GB
-    patch_torch_dist_shard(args.thread_count)
 
     megatron_model_meta = get_megatron_model_meta(args.model_type)
     assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
@@ -204,12 +200,13 @@ def convert_mcore2hf(args: ExportArguments) -> None:
     logger.info(f'megatron_config: {kwargs}')
     _check_megatron_kwargs(kwargs)
     current_convert_kwargs = convert_kwargs.copy()
-    if hf_model.model_info.is_moe_model:
+    if args.model_info.is_moe_model:
         current_convert_kwargs['moe_grouped_gemm'] = True
     megatron_args = MegatronArguments(
         **kwargs,
         **current_convert_kwargs,
         load=args.mcore_model,
+        save=args.output_dir if args.to_mcore else None,
         adapter_load=args.mcore_adapters[0] if args.mcore_adapters else None,
         torch_dtype=args.torch_dtype)
     patch_megatron_tokenizer(processor)
@@ -226,18 +223,28 @@ def convert_mcore2hf(args: ExportArguments) -> None:
         logger.info('Merge LoRA...')
         mg_model = peft_model.merge_and_unload()
     logger.info('Megatron model created successfully.')
-    megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
-    if args.test_convert_precision:
-        test_convert_precision(hf_model, mg_model, template)
-    del mg_model
-    logger.info('Successfully transferred MG model weights to HF model.')
-    ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load
-    save_checkpoint(
-        hf_model,
-        processor,
-        args.output_dir,
-        safe_serialization=args.safe_serialization,
-        model_dirs=[ckpt_dir, args.model_dir],
-        max_shard_size=args.max_shard_size,
-        additional_saved_files=hf_model.model_meta.additional_saved_files)
-    logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
+    if args.to_hf:
+        megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
+        if args.test_convert_precision:
+            test_convert_precision(hf_model, mg_model, template)
+        del mg_model
+        logger.info('Successfully transferred MG model weights to HF model.')
+        ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load
+        save_checkpoint(
+            hf_model,
+            processor,
+            args.output_dir,
+            safe_serialization=args.safe_serialization,
+            model_dirs=[ckpt_dir, args.model_dir],
+            max_shard_size=args.max_shard_size,
+            additional_saved_files=hf_model.model_meta.additional_saved_files)
+        logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
+    elif args.to_mcore:
+        if args.thread_count is None:
+            checkpoint_size = sum(get_n_params_grads(mg_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
+            args.thread_count = max(math.ceil(checkpoint_size / 10), 2)  # 10GB
+        patch_torch_dist_shard(args.thread_count)
+
+        args.save_args()
+        mg_save_checkpoint(1, [mg_model], None, None, 0)
+        logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.')
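In the new `to_mcore` branch, the thread count handed to `patch_torch_dist_shard` is derived from the estimated checkpoint size: parameter count times bits per element, converted to GB, split into roughly 10 GB shards, with a floor of two threads. A worked sketch of that arithmetic, assuming an illustrative 7.6B-parameter model in bfloat16 (the parameter count is not from the patch, and the patch's `get_n_params_grads`/`patch_torch_dist_shard` helpers are not reproduced here):

```python
import math

import torch

# Illustrative numbers: ~7.6e9 parameters stored in bfloat16 (16 bits each).
n_params = 7.6e9
torch_dtype = torch.bfloat16

# Same formula as the to_mcore branch above: bits -> bytes -> GB via `// 8e9`.
checkpoint_size_gb = n_params * torch.finfo(torch_dtype).bits // 8e9  # 15.0 GB
thread_count = max(math.ceil(checkpoint_size_gb / 10), 2)  # one thread per ~10 GB, at least 2

print(f'checkpoint ~= {checkpoint_size_gb:.0f} GB -> thread_count = {thread_count}')  # 15 GB -> 2
```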