[megatron] support export lora to_mcore #5445
@@ -168,7 +168,7 @@ def convert_hf2mcore(args: ExportArguments) -> None:
    logger.info(f'megatron_config: {kwargs}')
    _check_megatron_kwargs(kwargs)
    current_convert_kwargs = convert_kwargs.copy()
-    if hf_model.model_info.is_moe_model:
+    if args.model_info.is_moe_model:
        current_convert_kwargs['moe_grouped_gemm'] = True
    megatron_args = MegatronArguments(
        **kwargs, **current_convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype)
@@ -191,25 +191,22 @@ def convert_mcore2hf(args: ExportArguments) -> None:
def convert_mcore2hf(args: ExportArguments) -> None:
    from swift.megatron import prepare_mcore_model, adapter_state_dict_context
-    hf_model, template = prepare_model_template(args)
+    hf_model, template = prepare_model_template(args, load_model=args.to_hf)
    processor = template.processor
    if args.thread_count is None:
        checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
        args.thread_count = max(math.ceil(checkpoint_size / 10), 2)  # 10GB
    patch_torch_dist_shard(args.thread_count)

    megatron_model_meta = get_megatron_model_meta(args.model_type)
    assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
    kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
    logger.info(f'megatron_config: {kwargs}')
    _check_megatron_kwargs(kwargs)
    current_convert_kwargs = convert_kwargs.copy()
-    if hf_model.model_info.is_moe_model:
+    if args.model_info.is_moe_model:
        current_convert_kwargs['moe_grouped_gemm'] = True
    megatron_args = MegatronArguments(
        **kwargs,
        **current_convert_kwargs,
        load=args.mcore_model,
        save=args.output_dir if args.to_mcore else None,
        adapter_load=args.mcore_adapters[0] if args.mcore_adapters else None,
        torch_dtype=args.torch_dtype)
    patch_megatron_tokenizer(processor)
@@ -226,18 +223,28 @@ def convert_mcore2hf(args: ExportArguments) -> None:
        logger.info('Merge LoRA...')
        mg_model = peft_model.merge_and_unload()
    logger.info('Megatron model created successfully.')
-    megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
-    if args.test_convert_precision:
-        test_convert_precision(hf_model, mg_model, template)
-    del mg_model
-    logger.info('Successfully transferred MG model weights to HF model.')
-    ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load
-    save_checkpoint(
-        hf_model,
-        processor,
-        args.output_dir,
-        safe_serialization=args.safe_serialization,
-        model_dirs=[ckpt_dir, args.model_dir],
-        max_shard_size=args.max_shard_size,
-        additional_saved_files=hf_model.model_meta.additional_saved_files)
-    logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
+    if args.to_hf:
+        megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
+        if args.test_convert_precision:
+            test_convert_precision(hf_model, mg_model, template)
+        del mg_model
+        logger.info('Successfully transferred MG model weights to HF model.')
+        ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load
+        save_checkpoint(
+            hf_model,
+            processor,
+            args.output_dir,
+            safe_serialization=args.safe_serialization,
+            model_dirs=[ckpt_dir, args.model_dir],
+            max_shard_size=args.max_shard_size,
+            additional_saved_files=hf_model.model_meta.additional_saved_files)
+        logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
+    elif args.to_mcore:
+        if args.thread_count is None:
+            checkpoint_size = sum(get_n_params_grads(mg_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
+            args.thread_count = max(math.ceil(checkpoint_size / 10), 2)  # 10GB
+        patch_torch_dist_shard(args.thread_count)
+
+        args.save_args()
+        mg_save_checkpoint(1, [mg_model], None, None, 0)
+        logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.')
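For reference, the `thread_count` heuristic in the `to_mcore` branch estimates the checkpoint size in GB from the total parameter count and the dtype width, then derives one thread per 10 GB with a floor of two (matching the `# 10GB` comment). A worked example of that arithmetic (the 7B/bf16 figures are assumptions for illustration, not taken from this PR):

import math

# Assumed example: roughly 7B parameters exported in bfloat16 (16 bits per value).
n_params = 7e9
bits = 16
checkpoint_size = n_params * bits // 8e9                 # ~14 (GB)
thread_count = max(math.ceil(checkpoint_size / 10), 2)   # ceil(1.4) = 2 -> 2 threads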
Comment on lines +226 to +250

The use of `elif` makes the `to_hf` and `to_mcore` branches mutually exclusive, so only one export can run per invocation. However, changing `elif` to `if` alone is not enough, because the first branch deletes `mg_model` before the `to_mcore` branch could use it. To support both flags concurrently, you could refactor this section to perform both conversions if requested, and then delete the model as suggested below.

if args.to_hf:
    megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
    if args.test_convert_precision:
        test_convert_precision(hf_model, mg_model, template)
    logger.info('Successfully transferred MG model weights to HF model.')
    ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load
    save_checkpoint(
        hf_model,
        processor,
        args.output_dir,
        safe_serialization=args.safe_serialization,
        model_dirs=[ckpt_dir, args.model_dir],
        max_shard_size=args.max_shard_size,
        additional_saved_files=hf_model.model_meta.additional_saved_files)
    logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
if args.to_mcore:
    if args.thread_count is None:
        checkpoint_size = sum(get_n_params_grads(mg_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
        args.thread_count = max(math.ceil(checkpoint_size / 10), 2)  # 10GB
    patch_torch_dist_shard(args.thread_count)
    args.save_args()
    mg_save_checkpoint(1, [mg_model], None, None, 0)
    logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.')
del mg_model
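In this sketch, `del mg_model` moves below both branches so the Megatron model stays available for the `to_mcore` save even when the HF conversion also runs. A hypothetical illustration of the flag combination such a refactor would allow (field names are taken from the diff; constructing `ExportArguments` directly like this, and the paths shown, are assumptions for illustration only):

args = ExportArguments(
    mcore_model='path/to/mcore-checkpoint',    # assumed: trained Megatron checkpoint
    mcore_adapters=['path/to/lora-adapter'],   # assumed: LoRA adapter to merge
    to_hf=True,                                # request merged HF-format weights
    to_mcore=True,                             # and a merged Megatron checkpoint as well
    output_dir='output/merged')                # assumed output location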
The boolean logic `args.to_hf or args.mcore_adapters and args.to_mcore` can be hard to read due to operator precedence. While `and` has higher precedence than `or`, adding parentheses would make the intent much clearer and prevent potential misinterpretations.
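For instance, with explicit grouping (a sketch only; the call site where this condition lives is not visible in the excerpt above, so the variable name is illustrative):

condition = args.to_hf or (args.mcore_adapters and args.to_mcore)  # `and` binds more tightly than `or`; parentheses just make that explicit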