diff --git a/swift/megatron/model/mm_gpt/utils.py b/swift/megatron/model/mm_gpt/utils.py index 100628d378..322ba56d71 100644 --- a/swift/megatron/model/mm_gpt/utils.py +++ b/swift/megatron/model/mm_gpt/utils.py @@ -68,6 +68,7 @@ def __init__(self, config, ignore_init_model_cls=None): ignore_init_model_cls = [ignore_init_model_cls] context_list = [patch_device_map_meta(model_cls) for model_cls in ignore_init_model_cls] context_list.append(patch_hf_initialize_weight()) + kwargs['model_type'] = args.model_info.model_type with ContextManagers(context_list): model, _ = get_model_tokenizer(model_dir, args.torch_dtype, return_dummy_model=True, **kwargs) self.model_config = model.config