From abba4971df8cbe27eba80ce72bdfbc8219d4b447 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 22 Aug 2025 23:56:31 +0800 Subject: [PATCH 01/31] update --- .../train/multimodal/lora_llm_full_vit/sft.sh | 1 + swift/megatron/model/qwen2_5_vl/vit.py | 20 +++++++++++++++++++ swift/megatron/train/sft.py | 2 +- swift/megatron/utils/convert.py | 7 ++++++- 4 files changed, 28 insertions(+), 2 deletions(-) create mode 100644 swift/megatron/model/qwen2_5_vl/vit.py diff --git a/examples/train/multimodal/lora_llm_full_vit/sft.sh b/examples/train/multimodal/lora_llm_full_vit/sft.sh index 8a8ea0ec2e..523139f3ff 100644 --- a/examples/train/multimodal/lora_llm_full_vit/sft.sh +++ b/examples/train/multimodal/lora_llm_full_vit/sft.sh @@ -1,5 +1,6 @@ # 4 * 22GiB # vit/merger lr 1e-5; llm lora lr 1e-4 +# Note: not support resume_from_checkpoint (only support resume_only_model) NPROC_PER_NODE=4 \ CUDA_VISIBLE_DEVICES=0,1,2,3 \ MAX_PIXELS=1003520 \ diff --git a/swift/megatron/model/qwen2_5_vl/vit.py b/swift/megatron/model/qwen2_5_vl/vit.py new file mode 100644 index 0000000000..a2c6acd179 --- /dev/null +++ b/swift/megatron/model/qwen2_5_vl/vit.py @@ -0,0 +1,20 @@ + +from megatron.core.models.huggingface import HuggingFaceModule + +from megatron.training import get_args +class Qwen2_5VL_Vit(HuggingFaceModule): + + def __init__(self, config): + super().__init__(config) + args = get_args() + model_dir = args.model_info.model_dir + model, _ = get_model_tokenizer(model_dir, return_dummy_model=True) + self.model = model.visual + + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + + def get_input_embeds(self, input_embeds): + self() + diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index f7eb8b57ec..6043097d70 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -25,9 +25,9 @@ def __init__(self, args: Optional[Union[List[str], MegatronTrainArguments]] = No super(SwiftSft, self).__init__(args) args = self.args _, self.processor = args.get_model_processor(load_model=False) + self._prepare_template() patch_megatron_tokenizer(self.processor) args.init_model_args(self.processor, self.processor.model_info.config) - self._prepare_template() self.template.use_megatron = True args.save_args(args.save) self.trainer = self.prepare_trainer() diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index cc8a70307c..3a6f2d82c5 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -165,7 +165,10 @@ def convert_hf2mcore(args: ExportArguments) -> None: megatron_model_meta = get_megatron_model_meta(args.model_type) assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' 
- kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config) + config = processor.model_info.config + if args.model_meta.is_multimodal: + config = config.text_config + kwargs = megatron_model_meta.convert_hf_config(config) logger.info(f'megatron_config: {kwargs}') _check_megatron_kwargs(kwargs) current_convert_kwargs = convert_kwargs.copy() @@ -175,6 +178,8 @@ def convert_hf2mcore(args: ExportArguments) -> None: **kwargs, **current_convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype) patch_megatron_tokenizer(processor) extra_args = megatron_args.parse_to_megatron() + extra_args['model_info'] = args.model_info + extra_args['model_meta'] = args.model_meta extra_args_provider = megatron_model_meta.extra_args_provider initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=extra_args) From 0883b8493cab05a8af708d4e1d49085f799db769 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 24 Aug 2025 20:27:17 +0800 Subject: [PATCH 02/31] update --- swift/llm/infer/utils.py | 2 ++ swift/megatron/model/__init__.py | 2 +- swift/megatron/model/constant.py | 1 + swift/megatron/model/gpt/__init__.py | 14 ++++++++++++++ swift/megatron/model/gpt/hf2mcore.py | 4 ++-- swift/megatron/model/gpt_model.py | 9 +++++++++ swift/megatron/model/qwen2_5_vl/__init__.py | 15 +++++++++++++++ swift/megatron/model/qwen2_5_vl/convert.py | 21 +++++++++++++++++++++ swift/megatron/model/qwen2_5_vl/vit.py | 15 ++++++++------- swift/megatron/model/register.py | 3 ++- 10 files changed, 75 insertions(+), 11 deletions(-) create mode 100644 swift/megatron/model/qwen2_5_vl/__init__.py create mode 100644 swift/megatron/model/qwen2_5_vl/convert.py diff --git a/swift/llm/infer/utils.py b/swift/llm/infer/utils.py index 3ebcbf5d8b..49f35fc646 100644 --- a/swift/llm/infer/utils.py +++ b/swift/llm/infer/utils.py @@ -145,6 +145,8 @@ def prepare_model_template(args, **kwargs): model, processor = args.get_model_processor(**kwargs) template = args.get_template(processor) if model is not None: + if template.use_model: + template.model = model model = prepare_adapter(args, model) update_generation_config_eos_token(model.generation_config, template) return model, template diff --git a/swift/megatron/model/__init__.py b/swift/megatron/model/__init__.py index 3d13a8d1b5..5cd8834e90 100644 --- a/swift/megatron/model/__init__.py +++ b/swift/megatron/model/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from . import gpt +from . import gpt, qwen2_5_vl from .constant import MegatronModelType from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model diff --git a/swift/megatron/model/constant.py b/swift/megatron/model/constant.py index 8eebb6aa76..82e9b38913 100644 --- a/swift/megatron/model/constant.py +++ b/swift/megatron/model/constant.py @@ -1,3 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. class MegatronModelType: gpt = 'gpt' + qwen2_5_vl = 'qwen2_5_vl' diff --git a/swift/megatron/model/gpt/__init__.py b/swift/megatron/model/gpt/__init__.py index 092566d543..12717ffd19 100644 --- a/swift/megatron/model/gpt/__init__.py +++ b/swift/megatron/model/gpt/__init__.py @@ -1,4 +1,10 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
+from dataclasses import dataclass +from typing import Any, Callable, Dict + +from torch import nn +from transformers import PretrainedConfig + from swift.llm import ModelType from ..constant import MegatronModelType from ..register import MegatronModelMeta, register_megatron_model @@ -52,3 +58,11 @@ ModelType.glm4_5, ModelType.deepseek_v3_1, ], model_provider, convert_gpt_hf_config, convert_mcore2hf, convert_hf2mcore)) + + +@dataclass +class GptMegatronModelMeta(MegatronModelMeta): + model_provider: Callable[[], nn.Module] = model_provider + convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] = convert_gpt_hf_config + convert_mcore2hf: Callable[[nn.Module, nn.Module], None] = convert_mcore2hf + convert_hf2mcore: Callable[[nn.Module, nn.Module], None] = convert_hf2mcore diff --git a/swift/megatron/model/gpt/hf2mcore.py b/swift/megatron/model/gpt/hf2mcore.py index 93a8d9b36e..76780be641 100644 --- a/swift/megatron/model/gpt/hf2mcore.py +++ b/swift/megatron/model/gpt/hf2mcore.py @@ -91,7 +91,7 @@ def set_mlp_state(args, mg_mlp, hf_mlp): def set_layer_state(args, mg_model, hf_model, layer_idx): mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.model.layers[layer_idx] + hf_layer = hf_model.layers[layer_idx] if args.multi_latent_attention: set_mla_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) mg_layer.input_layernorm.weight.data.copy_(hf_layer.input_layernorm.weight) @@ -115,4 +115,4 @@ def convert_hf2mcore(hf_model, mg_model): mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) mg_model.decoder.final_layernorm.weight.data.copy_(hf_model.model.norm.weight) for layer_idx in range(args.num_layers): - set_layer_state(args, mg_model, hf_model, layer_idx) + set_layer_state(args, mg_model, hf_model.model, layer_idx) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 62a1448190..2ac9fecb2c 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -11,6 +11,7 @@ from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.training import get_args from swift.utils import get_logger from .rope import dynamic_rope_update, get_rope_inv_freq @@ -91,6 +92,11 @@ def __init__( logger.warning('`apply_rope_fusion` does not support `attention_scaling`. 
' f'Setting `config.apply_rope_fusion`: {config.apply_rope_fusion}') + args = get_args() + self.visual = None + if args.megatron_model_meta.visual is not None: + self.visual = args.megatron_model_meta.visual(config) + @contextmanager def _patch_apply_rotary_pos_emb(self): if self.attention_scaling == 1.: @@ -118,6 +124,7 @@ def forward( attention_mask: torch.Tensor, decoder_input: torch.Tensor = None, labels: torch.Tensor = None, + multimodal_data: Optional[Dict[str, Any]] = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict = None, @@ -141,6 +148,8 @@ def forward( pass elif self.pre_process: decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + if self.visual is not None: + decoder_input = self.visual.get_inputs_embeds(decoder_input, multimodal_data) else: # intermediate stage of pipeline # decoder will get hidden_states from encoder.input_tensor diff --git a/swift/megatron/model/qwen2_5_vl/__init__.py b/swift/megatron/model/qwen2_5_vl/__init__.py new file mode 100644 index 0000000000..5dc8a61510 --- /dev/null +++ b/swift/megatron/model/qwen2_5_vl/__init__.py @@ -0,0 +1,15 @@ +from swift.llm import ModelType +from ..constant import MegatronModelType +from ..gpt import GptMegatronModelMeta +from ..register import MegatronModelMeta, register_megatron_model +from .convert import convert_hf2mcore_qwen2_5_vl, convert_mcore2hf_qwen2_5_vl +from .vit import Qwen2_5VL_Vit + +register_megatron_model( + GptMegatronModelMeta( + MegatronModelType.qwen2_5_vl, [ + ModelType.qwen2_5_vl, + ], + convert_hf2mcore=convert_hf2mcore_qwen2_5_vl, + convert_mcore2hf=convert_mcore2hf_qwen2_5_vl, + visual=Qwen2_5VL_Vit)) diff --git a/swift/megatron/model/qwen2_5_vl/convert.py b/swift/megatron/model/qwen2_5_vl/convert.py new file mode 100644 index 0000000000..0d937ca451 --- /dev/null +++ b/swift/megatron/model/qwen2_5_vl/convert.py @@ -0,0 +1,21 @@ +from megatron.training import get_args + +from swift.utils import deep_getattr +from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore + + +def convert_hf2mcore_qwen2_5_vl(hf_model, mg_model): + language_model = hf_model.model.language_model + args = get_args() + # language_model + mg_model.embedding.word_embeddings.weight.data.copy_(language_model.embed_tokens.weight) + if args.untie_embeddings_and_output_weights: + mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) + mg_model.decoder.final_layernorm.weight.data.copy_(language_model.norm.weight) + for layer_idx in range(args.num_layers): + set_layer_state_hf2mcore(args, mg_model, language_model, layer_idx) + mg_model.visual.model.load_state_dict(hf_model.model.visual.state_dict()) + + +def convert_mcore2hf_qwen2_5_vl(hf_model, mg_model): + return convert_mcore2hf(hf_model, mg_model) diff --git a/swift/megatron/model/qwen2_5_vl/vit.py b/swift/megatron/model/qwen2_5_vl/vit.py index a2c6acd179..6a331e1eca 100644 --- a/swift/megatron/model/qwen2_5_vl/vit.py +++ b/swift/megatron/model/qwen2_5_vl/vit.py @@ -1,20 +1,21 @@ - from megatron.core.models.huggingface import HuggingFaceModule - from megatron.training import get_args + +from swift.llm import get_model_tokenizer + + class Qwen2_5VL_Vit(HuggingFaceModule): - + def __init__(self, config): super().__init__(config) args = get_args() model_dir = args.model_info.model_dir model, _ = get_model_tokenizer(model_dir, return_dummy_model=True) self.model = model.visual + self.model.to_empty(device='cpu') - def forward(self, *args, **kwargs): return 
self.model(*args, **kwargs) - - def get_input_embeds(self, input_embeds): - self() + def get_inputs_embeds(self, inputs_embeds, multimodal_data): + print() diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index 950a68ede2..fc8b95f80c 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from argparse import ArgumentParser from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Type import torch.nn as nn from transformers import PretrainedConfig @@ -20,6 +20,7 @@ class MegatronModelMeta: convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] convert_mcore2hf: Callable[[nn.Module, nn.Module], None] convert_hf2mcore: Callable[[nn.Module, nn.Module], None] + visual: Optional[Type[nn.Module]] = None extra_args_provider: Optional[Callable[[ArgumentParser], ArgumentParser]] = None From bdbaa9a9732bd3980d68784d6ec6297e2bf6e729 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 24 Aug 2025 20:27:42 +0800 Subject: [PATCH 03/31] update --- swift/megatron/utils/convert.py | 69 +++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index 3a6f2d82c5..d46bd162c6 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -3,6 +3,7 @@ import math from contextlib import contextmanager from dataclasses import fields +from typing import Any, Dict import torch import torch.nn as nn @@ -12,7 +13,7 @@ from megatron.training.utils import get_ltor_masks_and_position_ids from swift.llm import ExportArguments, HfConfigFactory, prepare_model_template, save_checkpoint, to_device -from swift.utils import get_logger, get_n_params_grads +from swift.utils import deep_getattr, get_logger, get_n_params_grads from ..argument import MegatronArguments from ..model import get_megatron_model_meta from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard @@ -78,31 +79,64 @@ def _to_cpu_hook(module, args, output): hook.remove() +def get_examples(is_multimodal: bool) -> Dict[str, Any]: + if is_multimodal: + data = { + 'messages': [{ + 'role': 'user', + 'content': 'describe the image.' + }, { + 'role': + 'assistant', + 'content': + 'The image depicts a close-up of a kitten with striking features. ' + 'The kitten has a white and gray coat with distinct black stripes, ' + 'particularly noticeable on its face and ears. Its eyes are large ' + 'and expressive, with a captivating blue hue that stands out against ' + "the darker fur around them. The kitten's nose is small and pink, " + 'and it has long, delicate whiskers extending from either side of its mouth. ' + "The background is blurred, drawing attention to the kitten's face and " + 'making it the focal point of the image. The overall impression is ' + 'one of cuteness and charm.' + }], + 'images': ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png'] + } + else: + data = { + 'messages': [ + { + 'role': 'user', + 'content': 'Introduction to ms-swift.' + }, + { + 'role': + 'assistant', + 'content': + 'ms-swift is an official framework provided by the ModelScope community for fine-tuning ' + 'and deploying large language models and multi-modal large models.' 
+ }, + ] + } + return data + + def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float32): _test_params_sum(hf_model) _test_params_sum(mg_model) template.set_mode('train') - inputs = template.encode({ - 'messages': [ - { - 'role': 'user', - 'content': 'Introduction to ms-swift.' - }, - { - 'role': - 'assistant', - 'content': - 'ms-swift is an official framework provided by the ModelScope community for fine-tuning ' - 'and deploying large language models and multi-modal large models.' - }, - ] - }) + template.register_post_encode_hook([hf_model]) + is_multimodal = template.model_meta.is_multimodal + inputs = get_examples(is_multimodal) + inputs = template.encode(inputs) inputs = to_device(template.data_collator([inputs]), 'cuda') HfConfigFactory.set_model_config_attr(hf_model, 'use_cache', False) share_embedding = mg_model.share_embeddings_and_output_weights - hf_modules = _find_modules(hf_model) + if is_multimodal: + _, inputs = template.pre_forward_hook(hf_model, None, inputs) + language_model = deep_getattr(hf_model, template.model_meta.model_arch.language_model[0]) + hf_modules = _find_modules(language_model) with torch.inference_mode(), _model_cpu_forward_context(hf_modules, torch_dtype, share_embedding=share_embedding): hf_logits = hf_model(**inputs).logits hf_model = hf_model.to('cpu') @@ -180,6 +214,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: extra_args = megatron_args.parse_to_megatron() extra_args['model_info'] = args.model_info extra_args['model_meta'] = args.model_meta + extra_args['megatron_model_meta'] = megatron_model_meta extra_args_provider = megatron_model_meta.extra_args_provider initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=extra_args) From 0e60545fdaec54cbdc593366102b9a3ea31932bf Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Fri, 29 Aug 2025 23:37:35 +0800 Subject: [PATCH 04/31] update --- swift/megatron/model/gpt_model.py | 6 ++- swift/megatron/model/qwen2_5_vl/vit.py | 57 ++++++++++++++++++++++++-- swift/megatron/utils/convert.py | 36 +++++++++------- 3 files changed, 78 insertions(+), 21 deletions(-) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 2aee76c395..dc58f495bf 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -124,7 +124,6 @@ def forward( attention_mask: torch.Tensor = None, decoder_input: torch.Tensor = None, labels: torch.Tensor = None, - multimodal_data: Optional[Dict[str, Any]] = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict = None, @@ -150,7 +149,10 @@ def forward( elif self.pre_process: decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) if self.visual is not None: - decoder_input = self.visual.get_inputs_embeds(decoder_input, multimodal_data) + kwargs.update({'input_ids': input_ids}) + decoder_input = decoder_input.transpose(0, 1) + decoder_input = self.visual.get_inputs_embeds(decoder_input, **kwargs) + decoder_input = decoder_input.transpose(0, 1) else: # intermediate stage of pipeline # decoder will get hidden_states from encoder.input_tensor diff --git a/swift/megatron/model/qwen2_5_vl/vit.py b/swift/megatron/model/qwen2_5_vl/vit.py index 6a331e1eca..fc45c0d7d8 100644 --- a/swift/megatron/model/qwen2_5_vl/vit.py +++ b/swift/megatron/model/qwen2_5_vl/vit.py @@ -1,7 +1,7 @@ from megatron.core.models.huggingface import HuggingFaceModule from megatron.training import get_args -from swift.llm import 
get_model_tokenizer +from swift.llm import get_model_tokenizer, to_device class Qwen2_5VL_Vit(HuggingFaceModule): @@ -12,10 +12,59 @@ def __init__(self, config): model_dir = args.model_info.model_dir model, _ = get_model_tokenizer(model_dir, return_dummy_model=True) self.model = model.visual - self.model.to_empty(device='cpu') + self.model_config = model.config + self.model.to_empty(device='cuda') def forward(self, *args, **kwargs): return self.model(*args, **kwargs) - def get_inputs_embeds(self, inputs_embeds, multimodal_data): - print() + def get_inputs_embeds(self, inputs_embeds, **kwargs): + input_ids = kwargs['input_ids'] + pixel_values = kwargs.get('pixel_values') + pixel_values_videos = kwargs.get('pixel_values_videos') + image_grid_thw = kwargs.get('image_grid_thw') + video_grid_thw = kwargs.get('video_grid_thw') + dtype = self.model.dtype + if pixel_values is None and pixel_values_videos is None: # plain-text + from PIL import Image + images = [Image.new('RGB', (32, 32), (0, 0, 0))] + media_inputs = self.processor.image_processor(images=images, return_tensors='pt') + device = input_ids.device + media_inputs = to_device(media_inputs, device) + pixel_values = media_inputs['pixel_values'].type(dtype) + image_embeds = self.model(pixel_values, grid_thw=media_inputs['image_grid_thw']) + inputs_embeds += image_embeds.mean() * 0. + else: + if pixel_values is None: + pixel_values_mixed = pixel_values_videos + grid_thw = video_grid_thw + elif pixel_values_videos is None: + pixel_values_mixed = pixel_values + grid_thw = image_grid_thw + else: + pixel_values_mixed = torch.concat([pixel_values, pixel_values_videos], dim=0) + grid_thw = torch.concat([image_grid_thw, video_grid_thw], dim=0) + pixel_values_mixed = pixel_values_mixed.type(dtype) + mixed_embeds = self.model(pixel_values_mixed, grid_thw=grid_thw) + if pixel_values is None: + image_embeds = None + video_embeds = mixed_embeds + elif pixel_values_videos is None: + image_embeds = mixed_embeds + video_embeds = None + else: + merge_length = self.processor.image_processor.merge_size**2 + image_tokens = (image_grid_thw.prod(dim=-1) // merge_length).sum() + image_embeds = mixed_embeds[:image_tokens] + video_embeds = mixed_embeds[image_tokens:] + + if image_embeds is not None: + image_mask = (input_ids == self.model_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if video_embeds is not None: + video_mask = (input_ids == self.model_config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + return inputs_embeds diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index d46bd162c6..f5c03370dd 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -40,17 +40,23 @@ def _test_params_sum(model): logger.info(f'zero_count: {zero_count}') -def _find_modules(model, recurse: bool = True): +def _find_modules(model, recurse: bool = True, prefix='', ignore_modules=None): + ignore_modules = ignore_modules or [] + for k in ignore_modules: + if prefix.startswith(k): + return [] + else: + named_children = list(model.named_children()) + modules = [] - children = list(model.children()) - for module in children: + for n, module in named_children: if module.__class__ is nn.ModuleList: - modules 
+= _find_modules(module, False) + modules += _find_modules(module, False, prefix=f'{prefix}{n}.', ignore_modules=ignore_modules) elif recurse: - modules += _find_modules(module) + modules += _find_modules(module, prefix=f'{prefix}{n}.', ignore_modules=ignore_modules) else: modules.append(module) - if not children: + if not named_children: modules.append(model) return modules @@ -133,10 +139,10 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float HfConfigFactory.set_model_config_attr(hf_model, 'use_cache', False) share_embedding = mg_model.share_embeddings_and_output_weights - if is_multimodal: - _, inputs = template.pre_forward_hook(hf_model, None, inputs) - language_model = deep_getattr(hf_model, template.model_meta.model_arch.language_model[0]) - hf_modules = _find_modules(language_model) + model_arch = hf_model.model_meta.model_arch + ignore_modules = [] if model_arch is None else (model_arch.vision_tower + model_arch.aligner) + + hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules) with torch.inference_mode(), _model_cpu_forward_context(hf_modules, torch_dtype, share_embedding=share_embedding): hf_logits = hf_model(**inputs).logits hf_model = hf_model.to('cpu') @@ -151,14 +157,14 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float # packed_seq_params = get_packed_seq_params(position_ids) # attention_mask = None mg_model.config.fp8 = None # compat fp8 - mg_modules = _find_modules(mg_model) + mg_modules = _find_modules(mg_model, ignore_modules=['visual']) + kwargs = {k: v for k, v in inputs.items() if k not in ['input_ids', 'attention_mask', 'labels']} + if 'position_ids' not in kwargs: + kwargs['position_ids'] = position_ids with torch.inference_mode(), _model_cpu_forward_context( mg_modules, mg_torch_dtype, 'cuda', share_embedding=share_embedding): mg_logits = mg_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - packed_seq_params=packed_seq_params) + input_ids=input_ids, attention_mask=attention_mask, packed_seq_params=packed_seq_params, **kwargs) token_mean_diff = (mg_logits - hf_logits).abs().mean(dim=-1) mean_diff = token_mean_diff.mean().item() From b349b567ea190ed8cf8d1df1d8244e8c2ec17ba6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 31 Aug 2025 15:40:12 +0800 Subject: [PATCH 05/31] update --- swift/llm/model/patcher.py | 6 +++-- swift/megatron/argument/megatron_args.py | 5 ++++- swift/megatron/model/gpt/config.py | 3 +++ swift/megatron/model/gpt/mcore2hf.py | 4 ++-- swift/megatron/model/gpt_model.py | 7 ++++++ swift/megatron/model/qwen2_5_vl/convert.py | 12 +++++++++- swift/megatron/model/qwen2_5_vl/vit.py | 26 +++++++++++++++++++--- swift/megatron/utils/convert.py | 8 ++++++- 8 files changed, 61 insertions(+), 10 deletions(-) diff --git a/swift/llm/model/patcher.py b/swift/llm/model/patcher.py index ea5b918283..70432a866e 100644 --- a/swift/llm/model/patcher.py +++ b/swift/llm/model/patcher.py @@ -294,8 +294,10 @@ def _new_from_pretrained(cls, *args, **kwargs): if hasattr(cls, '_tp_plan'): # fix tp_plan cls._tp_plan = cls._tp_plan or {} if return_dummy_model: - with torch.device('meta'): - model = cls(copy.deepcopy(kwargs['config'])) + origin_torch_dtype = torch.get_default_dtype() + torch.set_default_dtype(kwargs['config'].torch_dtype) + model = cls(copy.deepcopy(kwargs['config'])) + torch.set_default_dtype(origin_torch_dtype) else: model = from_pretrained(cls, *args, **kwargs) return model diff --git a/swift/megatron/argument/megatron_args.py 
b/swift/megatron/argument/megatron_args.py index 9628b4da6a..5c882149a3 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -185,7 +185,8 @@ class MegatronArguments(ExtraMegatronArguments): group_query_attention: Optional[bool] = None num_query_groups: Optional[int] = None max_position_embeddings: Optional[int] = None - position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'relative', 'none'] = 'rope' + position_embedding_type: Optional[Literal['learned_absolute', 'rope', 'mrope', 'relative', 'none']] = None + mrope_section: Optional[List[int]] = None rotary_base: Optional[int] = None rotary_percent: float = 1. rotary_interleaved: Optional[bool] = None @@ -380,6 +381,8 @@ def __post_init__(self): self.eval_interval = self.save_interval if self.seq_length is None: self.seq_length = self.max_position_embeddings + if self.position_embedding_type is None: + self.position_embedding_type = 'rope' if self.tensorboard_dir is None and self.save is not None: self.tensorboard_dir = f'{self.save}/runs' self._init_moe() diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/model/gpt/config.py index ec58a28142..112b0cde5d 100644 --- a/swift/megatron/model/gpt/config.py +++ b/swift/megatron/model/gpt/config.py @@ -39,6 +39,9 @@ def convert_gpt_hf_config(config) -> Dict[str, Any]: res['rotary_interleaved'] = True elif architectures == 'Glm4MoeForCausalLM': res['moe_router_score_function'] = 'sigmoid' + elif architectures == 'Qwen2_5_VLForConditionalGeneration': + res['position_embedding_type'] = 'mrope' + res['mrope_section'] = res['rope_scaling']['mrope_section'] if first_k_dense_replace is not None: res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res["num_layers"] - first_k_dense_replace}' if res.get('moe_router_score_function', 'softmax') == 'sigmoid': diff --git a/swift/megatron/model/gpt/mcore2hf.py b/swift/megatron/model/gpt/mcore2hf.py index bd0e480f65..3f063d4559 100644 --- a/swift/megatron/model/gpt/mcore2hf.py +++ b/swift/megatron/model/gpt/mcore2hf.py @@ -88,7 +88,7 @@ def set_mlp_state(args, mg_mlp, hf_mlp): def set_layer_state(args, mg_model, hf_model, layer_idx): mg_layer = mg_model.decoder.layers[layer_idx] - hf_layer = hf_model.model.layers[layer_idx] + hf_layer = hf_model.layers[layer_idx] if args.multi_latent_attention: set_mla_attn_state(args, mg_layer.self_attention, hf_layer.self_attn) @@ -113,4 +113,4 @@ def convert_mcore2hf(hf_model, mg_model): hf_model.lm_head.weight.data.copy_(mg_model.output_layer.weight) hf_model.model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight) for layer_idx in range(args.num_layers): - set_layer_state(args, mg_model, hf_model, layer_idx) + set_layer_state(args, mg_model, hf_model.model, layer_idx) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index dc58f495bf..975e32e60b 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -183,6 +183,13 @@ def forward( rotary_seq_len, packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', ) + elif self.position_embedding_type in 'mrope': + if self.training or not self.config.flash_decode: + rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section) + else: + # Flash decoding uses precomputed cos and sin for RoPE + raise NotImplementedError('Flash decoding uses precomputed cos and sin for RoPE, not implmented in ' + 'MultimodalRotaryEmbedding yet.') if ((self.config.enable_cuda_graph or 
self.config.flash_decode) and rotary_pos_cos is not None and inference_params): sequence_len_offset = torch.tensor( diff --git a/swift/megatron/model/qwen2_5_vl/convert.py b/swift/megatron/model/qwen2_5_vl/convert.py index 0d937ca451..1e523f5a62 100644 --- a/swift/megatron/model/qwen2_5_vl/convert.py +++ b/swift/megatron/model/qwen2_5_vl/convert.py @@ -2,6 +2,7 @@ from swift.utils import deep_getattr from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore +from ..gpt.mcore2hf import set_layer_state as set_layer_state_mcore2hf def convert_hf2mcore_qwen2_5_vl(hf_model, mg_model): @@ -18,4 +19,13 @@ def convert_hf2mcore_qwen2_5_vl(hf_model, mg_model): def convert_mcore2hf_qwen2_5_vl(hf_model, mg_model): - return convert_mcore2hf(hf_model, mg_model) + language_model = hf_model.model.language_model + args = get_args() + # language_model + language_model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight) + if args.untie_embeddings_and_output_weights: + hf_model.lm_head.weight.data.copy_(mg_model.output_layer.weight) + language_model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight) + for layer_idx in range(args.num_layers): + set_layer_state_mcore2hf(args, mg_model, language_model, layer_idx) + hf_model.model.visual.load_state_dict(mg_model.visual.model.state_dict()) diff --git a/swift/megatron/model/qwen2_5_vl/vit.py b/swift/megatron/model/qwen2_5_vl/vit.py index fc45c0d7d8..774bd106e6 100644 --- a/swift/megatron/model/qwen2_5_vl/vit.py +++ b/swift/megatron/model/qwen2_5_vl/vit.py @@ -1,19 +1,39 @@ +from contextlib import contextmanager + +import torch from megatron.core.models.huggingface import HuggingFaceModule from megatron.training import get_args from swift.llm import get_model_tokenizer, to_device +@contextmanager +def patch_device_map_meta(model_cls): + __origin_init__ = model_cls.__init__ + + def __init__(self, *args, **kwargs): + with torch.device('meta'): + __origin_init__(self, *args, **kwargs) + + model_cls.__init__ = __init__ + try: + yield + finally: + model_cls.__init__ = __origin_init__ + + class Qwen2_5VL_Vit(HuggingFaceModule): def __init__(self, config): + from transformers.models.qwen2_5_vl import Qwen2_5_VLTextModel super().__init__(config) args = get_args() model_dir = args.model_info.model_dir - model, _ = get_model_tokenizer(model_dir, return_dummy_model=True) - self.model = model.visual + kwargs = {'attn_impl': 'flash_attn'} if args.attention_backend.name == 'flash' else {} + with patch_device_map_meta(Qwen2_5_VLTextModel): + model, _ = get_model_tokenizer(model_dir, args.torch_dtype, return_dummy_model=True, **kwargs) + self.model = model.visual.to('cuda') self.model_config = model.config - self.model.to_empty(device='cuda') def forward(self, *args, **kwargs): return self.model(*args, **kwargs) diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index f5c03370dd..9ed74c1368 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -244,7 +244,10 @@ def convert_mcore2hf(args: ExportArguments) -> None: megatron_model_meta = get_megatron_model_meta(args.model_type) assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' 
- kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config) + config = processor.model_info.config + if args.model_meta.is_multimodal: + config = config.text_config + kwargs = megatron_model_meta.convert_hf_config(config) logger.info(f'megatron_config: {kwargs}') _check_megatron_kwargs(kwargs) current_convert_kwargs = convert_kwargs.copy() @@ -263,6 +266,9 @@ def convert_mcore2hf(args: ExportArguments) -> None: torch_dtype=args.torch_dtype) patch_megatron_tokenizer(processor) extra_args = megatron_args.parse_to_megatron() + extra_args['model_info'] = args.model_info + extra_args['model_meta'] = args.model_meta + extra_args['megatron_model_meta'] = megatron_model_meta extra_args_provider = megatron_model_meta.extra_args_provider initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=extra_args) From 6a77b0e4f9cdb133e454a344a708a5935dcec518 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 31 Aug 2025 19:53:55 +0800 Subject: [PATCH 06/31] update --- swift/megatron/argument/megatron_args.py | 6 ++++++ swift/megatron/argument/train_args.py | 3 +++ swift/megatron/model/gpt_model.py | 6 +++++- swift/megatron/train/sft.py | 11 +++++++++-- swift/megatron/trainers/base.py | 9 +++++++++ swift/megatron/trainers/utils.py | 3 +++ 6 files changed, 35 insertions(+), 3 deletions(-) diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 5c882149a3..605e8df62d 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -94,6 +94,10 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin): partial_rotary_factor: Optional[float] = None use_shared_expert_gate: Optional[bool] = None + # visual + vit_gradient_checkpointing: bool = True + gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None + @dataclass class MegatronArguments(ExtraMegatronArguments): @@ -377,6 +381,8 @@ def __post_init__(self): self.rope_scaling = json_parse_to_dict(self.rope_scaling) if 'type' in self.rope_scaling and 'rope_type' not in self.rope_scaling: self.rope_scaling['rope_type'] = self.rope_scaling['type'] + if self.gradient_checkpointing_kwargs is not None: + self.gradient_checkpointing_kwargs = json_parse_to_dict(self.gradient_checkpointing_kwargs) if self.eval_interval is None: self.eval_interval = self.save_interval if self.seq_length is None: diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 6f492dc7cd..36d8cbe1be 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -28,6 +28,9 @@ def init_model_args(self, tokenizer, config): setattr(self, k, v) MegatronArguments.__post_init__(self) self.extra_args = self.parse_to_megatron() + self.extra_args['model_info'] = self.model_info + self.extra_args['model_meta'] = self.model_meta + self.extra_args['megatron_model_meta'] = self.megatron_model_meta def _init_save(self): init_process_group(backend=self.ddp_backend, timeout=self.ddp_timeout) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 975e32e60b..a279b93612 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Literal, Optional import torch -from megatron.core import InferenceParams +from megatron.core import InferenceParams, mpu from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from 
megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.gpt import GPTModel as McoreGPTModel @@ -144,11 +144,15 @@ def forward( # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. # Decoder embedding. + args = get_args() if decoder_input is not None: pass elif self.pre_process: decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) if self.visual is not None: + if args.tensor_model_parallel_size > 1 and args.sequence_parallel: + input_ids = input_ids.chunk( + args.tensor_model_parallel_size, dim=-1)[mpu.get_tensor_model_parallel_rank()] kwargs.update({'input_ids': input_ids}) decoder_input = decoder_input.transpose(0, 1) decoder_input = self.visual.get_inputs_embeds(decoder_input, **kwargs) diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 6043097d70..3a2e3dde55 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -3,6 +3,8 @@ from functools import partial from typing import List, Optional, Union +import torch + from swift.llm.train import SwiftSft from swift.utils import get_logger, is_master, plot_images from ..argument import MegatronTrainArguments @@ -24,12 +26,17 @@ def __init__(self, args: Optional[Union[List[str], MegatronTrainArguments]] = No self.train_msg = {} super(SwiftSft, self).__init__(args) args = self.args - _, self.processor = args.get_model_processor(load_model=False) + if args.model_meta.is_multimodal: + kwargs = {'return_dummy_model': True} + else: + kwargs = {'load_model': False} + with torch.device('meta'): + self.model, self.processor = args.get_model_processor(**kwargs) self._prepare_template() patch_megatron_tokenizer(self.processor) + args.save_args(args.save) args.init_model_args(self.processor, self.processor.model_info.config) self.template.use_megatron = True - args.save_args(args.save) self.trainer = self.prepare_trainer() def _get_data_collator(self): diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index afd6152a69..bcaba4c679 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -258,8 +258,17 @@ def new_model_provider_func(*args, **kwargs): with adapter_state_dict_context(): args.iteration, args.num_floating_point_operations_so_far = load_checkpoint( model, optimizer, opt_param_scheduler, load_arg='adapter_load', strict=False) + if args.model_meta.is_multimodal: + self._prepare_vit_gradient_checkpointing() return model, optimizer, opt_param_scheduler + def _prepare_vit_gradient_checkpointing(self): + visual = self.unwrapped_model.visual.model + args = get_args() + if args.vit_gradient_checkpointing: + visual.gradient_checkpointing_enable(**(args.gradient_checkpointing_kwargs or {})) + visual.enable_input_require_grads() + @staticmethod def _initialize_embedding(model): # compat new_special_tokens diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index a05b57f75f..42687dc00c 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -118,5 +118,8 @@ def get_batch(data_iterator): batch['packed_seq_params'] = get_packed_seq_params(batch['position_ids']) batch['packed_seq_params'].num_samples = num_samples # slice batch along sequence dimension for context parallelism + position_ids = batch.get('real_position_ids') + if position_ids is not None: + batch['position_ids'] = position_ids batch = get_batch_on_this_cp_rank(batch) return batch From 
a502bbb1dda05f07f9c5a20e30834aef85b9bc2b Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sun, 31 Aug 2025 23:45:41 +0800 Subject: [PATCH 07/31] update --- examples/megatron/multimodal/dense.sh | 31 +++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 examples/megatron/multimodal/dense.sh diff --git a/examples/megatron/multimodal/dense.sh b/examples/megatron/multimodal/dense.sh new file mode 100644 index 0000000000..bc6f248760 --- /dev/null +++ b/examples/megatron/multimodal/dense.sh @@ -0,0 +1,31 @@ +# 4 * 56GiB; 2.3s/it +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=4 \ +MAX_PIXELS=1003520 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +megatron sft \ + --load Qwen2.5-VL-7B-Instruct-mcore \ + --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite' \ + --tensor_model_parallel_size 2 \ + --sequence_parallel true \ + --packing true \ + --split_dataset_ratio 0.01 \ + --micro_batch_size 1 \ + --global_batch_size 4 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-5 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-6 \ + --max_epochs 1 \ + --save megatron_output/Qwen2.5-VL-7B-Instruct \ + --save_interval 200 \ + --vit_gradient_checkpointing true \ + --max_length 2048 \ + --num_workers 4 \ + --no_save_optim true \ + --no_save_rng true \ + --dataset_num_proc 8 From 6fad478b0c49cbb3f88b2b4ab9d528bfc2fbc056 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 1 Sep 2025 14:15:12 +0800 Subject: [PATCH 08/31] fix --- swift/megatron/model/gpt/__init__.py | 27 +++++++++++---------- swift/megatron/model/gpt_model.py | 2 +- swift/megatron/model/qwen2_5_vl/__init__.py | 6 ++--- swift/megatron/model/qwen2_5_vl/convert.py | 3 --- swift/megatron/trainers/utils.py | 2 +- swift/megatron/utils/convert.py | 4 +-- 6 files changed, 21 insertions(+), 23 deletions(-) diff --git a/swift/megatron/model/gpt/__init__.py b/swift/megatron/model/gpt/__init__.py index f53416b893..dafbb9f8ec 100644 --- a/swift/megatron/model/gpt/__init__.py +++ b/swift/megatron/model/gpt/__init__.py @@ -9,12 +9,21 @@ from ..constant import MegatronModelType from ..register import MegatronModelMeta, register_megatron_model from .config import convert_gpt_hf_config -from .hf2mcore import convert_hf2mcore -from .mcore2hf import convert_mcore2hf -from .model import model_provider +from .hf2mcore import convert_hf2mcore as convert_hf2mcore_gpt +from .mcore2hf import convert_mcore2hf as convert_mcore2hf_gpt +from .model import model_provider as gpt_model_provider + + +@dataclass +class GPTMegatronModelMeta(MegatronModelMeta): + model_provider: Callable[[], nn.Module] = gpt_model_provider + convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] = convert_gpt_hf_config + convert_mcore2hf: Callable[[nn.Module, nn.Module], None] = convert_mcore2hf_gpt + convert_hf2mcore: Callable[[nn.Module, nn.Module], None] = convert_hf2mcore_gpt + register_megatron_model( - MegatronModelMeta(MegatronModelType.gpt, [ + GPTMegatronModelMeta(MegatronModelType.gpt, [ ModelType.qwen2, ModelType.qwen2_5, ModelType.qwq, @@ -58,12 +67,4 @@ ModelType.ernie, ModelType.glm4_5, ModelType.deepseek_v3_1, - ], model_provider, convert_gpt_hf_config, convert_mcore2hf, convert_hf2mcore)) - - -@dataclass -class GptMegatronModelMeta(MegatronModelMeta): - model_provider: Callable[[], nn.Module] = model_provider - convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] = convert_gpt_hf_config - convert_mcore2hf: 
Callable[[nn.Module, nn.Module], None] = convert_mcore2hf - convert_hf2mcore: Callable[[nn.Module, nn.Module], None] = convert_hf2mcore + ])) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index a279b93612..6b675cf5d2 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -187,7 +187,7 @@ def forward( rotary_seq_len, packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', ) - elif self.position_embedding_type in 'mrope': + elif self.position_embedding_type == 'mrope': if self.training or not self.config.flash_decode: rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section) else: diff --git a/swift/megatron/model/qwen2_5_vl/__init__.py b/swift/megatron/model/qwen2_5_vl/__init__.py index 5dc8a61510..fba53304c1 100644 --- a/swift/megatron/model/qwen2_5_vl/__init__.py +++ b/swift/megatron/model/qwen2_5_vl/__init__.py @@ -1,12 +1,12 @@ from swift.llm import ModelType from ..constant import MegatronModelType -from ..gpt import GptMegatronModelMeta -from ..register import MegatronModelMeta, register_megatron_model +from ..gpt import GPTMegatronModelMeta +from ..register import register_megatron_model from .convert import convert_hf2mcore_qwen2_5_vl, convert_mcore2hf_qwen2_5_vl from .vit import Qwen2_5VL_Vit register_megatron_model( - GptMegatronModelMeta( + GPTMegatronModelMeta( MegatronModelType.qwen2_5_vl, [ ModelType.qwen2_5_vl, ], diff --git a/swift/megatron/model/qwen2_5_vl/convert.py b/swift/megatron/model/qwen2_5_vl/convert.py index 1e523f5a62..9da831611b 100644 --- a/swift/megatron/model/qwen2_5_vl/convert.py +++ b/swift/megatron/model/qwen2_5_vl/convert.py @@ -1,6 +1,5 @@ from megatron.training import get_args -from swift.utils import deep_getattr from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore from ..gpt.mcore2hf import set_layer_state as set_layer_state_mcore2hf @@ -8,7 +7,6 @@ def convert_hf2mcore_qwen2_5_vl(hf_model, mg_model): language_model = hf_model.model.language_model args = get_args() - # language_model mg_model.embedding.word_embeddings.weight.data.copy_(language_model.embed_tokens.weight) if args.untie_embeddings_and_output_weights: mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) @@ -21,7 +19,6 @@ def convert_hf2mcore_qwen2_5_vl(hf_model, mg_model): def convert_mcore2hf_qwen2_5_vl(hf_model, mg_model): language_model = hf_model.model.language_model args = get_args() - # language_model language_model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight) if args.untie_embeddings_and_output_weights: hf_model.lm_head.weight.data.copy_(mg_model.output_layer.weight) diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 42687dc00c..847314f166 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -118,7 +118,7 @@ def get_batch(data_iterator): batch['packed_seq_params'] = get_packed_seq_params(batch['position_ids']) batch['packed_seq_params'].num_samples = num_samples # slice batch along sequence dimension for context parallelism - position_ids = batch.get('real_position_ids') + position_ids = batch.get('real_position_ids') # fix Qwen2.5-VL if position_ids is not None: batch['position_ids'] = position_ids batch = get_batch_on_this_cp_rank(batch) diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index 9ed74c1368..c72787aaa1 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -13,7 +13,7 
@@ from megatron.training.utils import get_ltor_masks_and_position_ids from swift.llm import ExportArguments, HfConfigFactory, prepare_model_template, save_checkpoint, to_device -from swift.utils import deep_getattr, get_logger, get_n_params_grads +from swift.utils import get_logger, get_n_params_grads from ..argument import MegatronArguments from ..model import get_megatron_model_meta from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard @@ -145,7 +145,7 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules) with torch.inference_mode(), _model_cpu_forward_context(hf_modules, torch_dtype, share_embedding=share_embedding): hf_logits = hf_model(**inputs).logits - hf_model = hf_model.to('cpu') + hf_model.to('cpu') input_ids = inputs['input_ids'] attention_mask, _, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True) From 2878d15b3ead2f82585b575be7003acc5d56e262 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 1 Sep 2025 14:28:39 +0800 Subject: [PATCH 09/31] update --- ...\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" | 2 +- ...\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" | 2 ++ docs/source_en/Instruction/Command-line-parameters.md | 2 +- docs/source_en/Megatron-SWIFT/Command-line-parameters.md | 2 ++ 4 files changed, 6 insertions(+), 2 deletions(-) diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 7886f6548b..2464590cef 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -159,7 +159,7 @@ - 🔥aligner_lr: 当训练多模态大模型时,该参数指定aligner的学习率,默认为None,等于learning_rate。 - lr_scheduler_type: lr_scheduler类型,默认为'cosine'。 - lr_scheduler_kwargs: lr_scheduler其他参数。默认为None。 -- 🔥gradient_checkpointing_kwargs: 传入`torch.utils.checkpoint`中的参数。例如设置为`--gradient_checkpointing_kwargs '{"use_reentrant": false}'`。默认为None。 +- gradient_checkpointing_kwargs: 传入`torch.utils.checkpoint`中的参数。例如设置为`--gradient_checkpointing_kwargs '{"use_reentrant": false}'`。默认为None。 - 注意:当使用DDP而不使用deepspeed/fsdp,且gradient_checkpointing_kwargs为None,会默认设置其为`'{"use_reentrant": false}'`。 - full_determinism: 确保训练中获得可重现的结果,注意:这会对性能产生负面影响。默认为False。 - 🔥report_to: 默认值为`tensorboard`。你也可以指定`--report_to tensorboard wandb swanlab`、`--report_to all`。 diff --git "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index b76390dd75..2ce3444d5e 100644 --- "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -234,6 +234,8 @@ Megatron训练参数继承自Megatron参数和基本参数(与ms-swift共用da - 若要自定义attention_mask,你可以设置`--padding_free false`。 - 注意:Megatron-SWIFT训练特性优先支持padding_free格式,若非特殊情况,请勿修改该值。 - mlp_padding_free: 默认为False。用于padding_free设置为false时,对mlp进行padding_free优化。这可以在自定义attention_mask的同时,提升训练速度和减少显存占用。 +- vit_gradient_checkpointing: 多模态模型训练时,是否对vit部分开启gradient_checkpointing。默认为True。 +- gradient_checkpointing_kwargs: 传入`torch.utils.checkpoint`中的参数。例如设置为`--gradient_checkpointing_kwargs '{"use_reentrant": false}'`。默认为None。 - 🔥packing: 
是否使用序列packing,默认为False。当前支持CPT/SFT/DPO。 - packing_length: packing的长度。默认为None,设置为max_length。 - streaming: 流式读取并处理数据集,默认False。 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 2606710f79..df68f9ca06 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -162,7 +162,7 @@ This parameter list inherits from transformers `Seq2SeqTrainingArguments`, with - 🔥aligner_lr: When training a multimodal large model, this parameter specifies the learning rate for the aligner. By default, it is set to None, which means it equals `learning_rate`. - lr_scheduler_type: Type of lr_scheduler, defaults to 'cosine'. - lr_scheduler_kwargs: Other parameters for the lr_scheduler, defaults to None. -- 🔥gradient_checkpointing_kwargs: Parameters for `torch.utils.checkpoint`. For example, set as `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to None. +- gradient_checkpointing_kwargs: Parameters for `torch.utils.checkpoint`. For example, set as `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to None. - Note: When using DDP without DeepSpeed/FSDP, and `gradient_checkpointing_kwargs` is `None`, it will default to `'{"use_reentrant": false}'`. - full_determinism: Ensures reproducible results during training. Note: This will negatively impact performance. Defaults to False. - 🔥report_to: Default value is `tensorboard`. You can also specify `--report_to tensorboard wandb swanlab` or `--report_to all`. diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 0501c3db87..f51e730940 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -249,6 +249,8 @@ Megatron training parameters are inherited from Megatron parameters and basic pa - If you wish to customize the attention_mask, you can set `--padding_free false`. - Note: The Megatron-SWIFT training feature prioritizes support for the padding-free format. Unless under special circumstances, please do not modify this value. - mlp_padding_free: The default is False. This is used for applying padding-free optimization to the MLP when padding_free is set to false. It allows for improved training speed and reduced memory usage while customizing the attention_mask. +- vit_gradient_checkpointing: Whether to enable gradient checkpointing for the ViT part during multimodal model training. Default: True. +- gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Default: None. - 🔥packing: Whether to use sequence packing, defaults to False. Currently supports CPT/SFT/DPO. - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length. - streaming: Stream data loading and processing, default is False. 
From 0cbada080f29a109f7567f2099c1bc9ddfd284aa Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 1 Sep 2025 17:05:17 +0800 Subject: [PATCH 10/31] update --- ...44\350\241\214\345\217\202\346\225\260.md" | 6 +- ...44\350\241\214\345\217\202\346\225\260.md" | 4 + .../Instruction/Command-line-parameters.md | 8 +- .../Megatron-SWIFT/Command-line-parameters.md | 4 + swift/megatron/argument/megatron_args.py | 3 + swift/megatron/argument/train_args.py | 2 +- swift/megatron/model/__init__.py | 2 +- swift/megatron/model/gpt/__init__.py | 116 +++++++++--------- swift/megatron/model/gpt_model.py | 14 --- swift/megatron/model/mm_gpt/__init__.py | 1 + .../vit.py => mm_gpt/qwen2_5_vl.py} | 54 +++++--- swift/megatron/model/mm_gpt/utils.py | 34 +++++ swift/megatron/model/mm_gpt_model.py | 70 +++++++++++ .../model/{gpt/model.py => model_provider.py} | 11 +- swift/megatron/model/qwen2_5_vl/__init__.py | 15 --- swift/megatron/model/qwen2_5_vl/convert.py | 28 ----- swift/megatron/model/register.py | 8 +- swift/megatron/utils/convert.py | 7 +- 18 files changed, 239 insertions(+), 148 deletions(-) create mode 100644 swift/megatron/model/mm_gpt/__init__.py rename swift/megatron/model/{qwen2_5_vl/vit.py => mm_gpt/qwen2_5_vl.py} (62%) create mode 100644 swift/megatron/model/mm_gpt/utils.py create mode 100644 swift/megatron/model/mm_gpt_model.py rename swift/megatron/model/{gpt/model.py => model_provider.py} (94%) delete mode 100644 swift/megatron/model/qwen2_5_vl/__init__.py delete mode 100644 swift/megatron/model/qwen2_5_vl/convert.py diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 2464590cef..4ba75ad711 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -211,10 +211,10 @@ - hub_private_repo: 默认为False。 ### Tuner参数 -- 🔥freeze_llm: 该参数只对多模态模型生效,可用于全参和LoRA,但含义不同。若是全参数训练,将freeze_llm设置为True将会将llm部分权重进行冻结,若是LoRA训练且`target_modules`设置为'all-linear',将freeze_llm设置为True将会取消在llm部分添加LoRA模块。该参数默认为False。 -- 🔥freeze_vit: 该参数只对多模态模型生效,可用于全参和LoRA,含义参考`freeze_llm`。默认为True。 +- 🔥freeze_llm: 该参数只对多模态模型生效,可用于全参和LoRA,但会产生不同的效果。若是全参数训练,将freeze_llm设置为True将会将LLM部分权重进行冻结;若是LoRA训练且`target_modules`设置为'all-linear',将freeze_llm设置为True将会取消在LLM部分添加LoRA模块。该参数默认为False。 +- 🔥freeze_vit: 该参数只对多模态模型生效,可用于全参和LoRA,但会产生不同的效果。若是全参数训练,将freeze_vit设置为True将会将vit部分权重进行冻结;若是LoRA训练且`target_modules`设置为'all-linear',将freeze_vit设置为True将会取消在vit部分添加LoRA模块。该参数默认为True。 - 注意:这里的vit不仅限于vision_tower, 也包括audio_tower。 -- 🔥freeze_aligner: 该参数只对多模态模型生效,可用于全参和LoRA,含义参考`freeze_llm`。默认为True。 +- 🔥freeze_aligner: 该参数只对多模态模型生效,可用于全参和LoRA,但会产生不同的效果。若是全参数训练,将freeze_aligner设置为True将会将aligner(也称为projector)部分权重进行冻结;若是LoRA训练且`target_modules`设置为'all-linear',将freeze_aligner设置为True将会取消在aligner部分添加LoRA模块。该参数默认为True。 - 🔥target_modules: 指定lora模块, 默认为`['all-linear']`。你也可以设置为module的后缀,例如:`--target_modules q_proj k_proj v_proj`。该参数不限于LoRA,可用于其他tuners。 - 注意:在LLM和多模态LLM中,'all-linear'的行为有所不同。若是LLM则自动寻找除lm_head外的linear并附加tuner;若是多模态LLM,则默认只在LLM上附加tuner,该行为可以被`freeze_llm`、`freeze_vit`、`freeze_aligner`控制。 - 🔥target_regex: 指定lora模块的regex表达式,默认为`None`。如果该值传入,则target_modules参数失效。该参数不限于LoRA,可用于其他tuners。 diff --git "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 
2ce3444d5e..b4864b250f 100644 --- "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -192,6 +192,10 @@ **Tuner参数**: - train_type: 可选为'lora'和'full'。默认为'full'。 +- 🔥freeze_llm: 该参数只对多模态模型生效,可用于全参和LoRA,但会产生不同的效果。若是全参数训练,将freeze_llm设置为True将会将LLM部分权重进行冻结;若是LoRA训练且`target_modules`设置为'all-linear',将freeze_llm设置为True将会取消在LLM部分添加LoRA模块。该参数默认为False。 +- 🔥freeze_vit: 该参数只对多模态模型生效,可用于全参和LoRA,但会产生不同的效果。若是全参数训练,将freeze_vit设置为True将会将vit部分权重进行冻结;若是LoRA训练且`target_modules`设置为'all-linear',将freeze_vit设置为True将会取消在vit部分添加LoRA模块。该参数默认为True。 + - 注意:这里的vit不仅限于vision_tower, 也包括audio_tower。 +- 🔥freeze_aligner: 该参数只对多模态模型生效,可用于全参和LoRA,但会产生不同的效果。若是全参数训练,将freeze_aligner设置为True将会将aligner(也称为projector)部分权重进行冻结;若是LoRA训练且`target_modules`设置为'all-linear',将freeze_aligner设置为True将会取消在aligner部分添加LoRA模块。该参数默认为True。 全参数训练: - freeze_parameters: 需要被冻结参数的前缀,默认为`[]`。 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index df68f9ca06..15c4f96053 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -215,10 +215,10 @@ Other important parameters: ### Tuner Arguments -- 🔥freeze_llm: This parameter is only effective for multimodal models and can be used for full parameter training and LoRA, but with different meanings. In full parameter training, setting freeze_llm to True will freeze some of the LLM weights. In LoRA training, if `target_modules` is set to 'all-linear', setting freeze_llm to True will prevent adding LoRA modules to the LLM part. The default is False. -- 🔥freeze_vit: This parameter is only effective for multimodal models and can be used for full parameter training and LoRA, with similar meanings as `freeze_llm`. The default is True. - - Note: Here, "vit" refers not only to the vision_tower but also includes the audio_tower. -- 🔥freeze_aligner: This parameter is only effective for multimodal models and can be used for full parameter training and LoRA, with similar meanings as `freeze_llm`. The default is True. +- 🔥 freeze_llm: This parameter only takes effect for multimodal models and can be used in both full-parameter and LoRA training, but with different behaviors. In full-parameter training, setting `freeze_llm` to `True` will freeze the weights of the LLM component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_llm` to `True` will prevent LoRA modules from being added to the LLM component. The default value is `False`. +- 🔥 freeze_vit: This parameter only applies to multimodal models and can be used in both full-parameter and LoRA training, though with different effects. In full-parameter training, setting `freeze_vit` to `True` will freeze the weights of the ViT component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_vit` to `True` will prevent LoRA modules from being added to the ViT component. The default value is `True`. + - Note: The term "ViT" here refers not only to the vision tower but also includes the audio tower. +- 🔥 freeze_aligner: This parameter is only effective for multimodal models and can be used in both full-parameter and LoRA training, with differing outcomes. In full-parameter training, setting `freeze_aligner` to `True` will freeze the weights of the aligner (also known as the projector) component. 
In LoRA training with `target_modules` set to 'all-linear', setting `freeze_aligner` to `True` will prevent LoRA modules from being added to the aligner component. The default value is `True`. - 🔥 target_modules: Specifies the LoRA modules. The default is `['all-linear']`, but you can also pass layer-name suffixes, e.g. `--target_modules q_proj k_proj v_proj`. This argument is not restricted to LoRA and can be used with other tuners as well. - Note: The behavior of the special value `'all-linear'` differs between plain LLMs and multimodal LLMs. For a standard LLM, it automatically locates every linear layer except `lm_head` and attaches a tuner. For a multimodal LLM, it attaches the tuner only to the LLM component by default. This default can be changed with the `freeze_llm`, `freeze_vit`, and `freeze_aligner` options. - 🔥target_regex: Specifies a regex expression for LoRA modules, with a default of `None`. If this value is provided, the target_modules parameter becomes ineffective. This parameter is not limited to LoRA and can be used for other tuners. diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index f51e730940..0a82d5dd7a 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -206,6 +206,10 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the **Tuner Parameters**: - train_type: Options are `'lora'` and `'full'`. Default is `'full'`. +- 🔥 freeze_llm: This parameter only takes effect for multimodal models and can be used in both full-parameter and LoRA training, but with different behaviors. In full-parameter training, setting `freeze_llm` to `True` will freeze the weights of the LLM component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_llm` to `True` will prevent LoRA modules from being added to the LLM component. The default value is `False`. +- 🔥 freeze_vit: This parameter only applies to multimodal models and can be used in both full-parameter and LoRA training, though with different effects. In full-parameter training, setting `freeze_vit` to `True` will freeze the weights of the ViT component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_vit` to `True` will prevent LoRA modules from being added to the ViT component. The default value is `True`. + - Note: The term "ViT" here refers not only to the vision tower but also includes the audio tower. +- 🔥 freeze_aligner: This parameter is only effective for multimodal models and can be used in both full-parameter and LoRA training, with differing outcomes. In full-parameter training, setting `freeze_aligner` to `True` will freeze the weights of the aligner (also known as the projector) component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_aligner` to `True` will prevent LoRA modules from being added to the aligner component. The default value is `True`. 
Full-parameter Training: diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 605e8df62d..5d5182f228 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -31,6 +31,9 @@ class RLHFMegatronArgumentsMixin: @dataclass class MegatronTunerMixin: train_type: Literal['lora', 'full'] = 'full' + freeze_llm: bool = False + freeze_vit: bool = True + freeze_aligner: bool = True # full freeze_parameters: List[str] = field(default_factory=list) freeze_parameters_regex: Optional[str] = None diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 36d8cbe1be..9481b451df 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -17,7 +17,6 @@ class MegatronTrainArguments(MegatronArguments, BaseArguments): add_version: bool = True def init_model_args(self, tokenizer, config): - self.megatron_model_meta = get_megatron_model_meta(self.model_type) kwargs = self.megatron_model_meta.convert_hf_config(config) if self.new_special_tokens and kwargs['padded_vocab_size'] < len(tokenizer): kwargs['padded_vocab_size'] = math.ceil(len(tokenizer) / 128) * 128 @@ -49,6 +48,7 @@ def __post_init__(self): self.padding_free = True self.load = to_abspath(self.load, check_path_exist=True) BaseArguments.__post_init__(self) + self.megatron_model_meta = get_megatron_model_meta(self.model_type) if len(self.dataset) == 0 and len(self.cached_dataset) == 0: raise ValueError(f'self.dataset: {self.dataset}, self.cached_dataset: {self.cached_dataset}. ' 'Please input the training dataset.') diff --git a/swift/megatron/model/__init__.py b/swift/megatron/model/__init__.py index 5cd8834e90..3c882c9864 100644 --- a/swift/megatron/model/__init__.py +++ b/swift/megatron/model/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from . import gpt, qwen2_5_vl +from . import gpt, mm_gpt from .constant import MegatronModelType from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model diff --git a/swift/megatron/model/gpt/__init__.py b/swift/megatron/model/gpt/__init__.py index dafbb9f8ec..b22cde9433 100644 --- a/swift/megatron/model/gpt/__init__.py +++ b/swift/megatron/model/gpt/__init__.py @@ -1,70 +1,70 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from dataclasses import dataclass -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Type from torch import nn from transformers import PretrainedConfig from swift.llm import ModelType from ..constant import MegatronModelType +from ..gpt_model import GPTModel +from ..model_provider import model_provider from ..register import MegatronModelMeta, register_megatron_model from .config import convert_gpt_hf_config -from .hf2mcore import convert_hf2mcore as convert_hf2mcore_gpt -from .mcore2hf import convert_mcore2hf as convert_mcore2hf_gpt -from .model import model_provider as gpt_model_provider - - -@dataclass -class GPTMegatronModelMeta(MegatronModelMeta): - model_provider: Callable[[], nn.Module] = gpt_model_provider - convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] = convert_gpt_hf_config - convert_mcore2hf: Callable[[nn.Module, nn.Module], None] = convert_mcore2hf_gpt - convert_hf2mcore: Callable[[nn.Module, nn.Module], None] = convert_hf2mcore_gpt - +from .hf2mcore import convert_hf2mcore +from .mcore2hf import convert_mcore2hf register_megatron_model( - GPTMegatronModelMeta(MegatronModelType.gpt, [ - ModelType.qwen2, - ModelType.qwen2_5, - ModelType.qwq, - ModelType.qwq_preview, - ModelType.qwen2_5_math, - ModelType.llama, - ModelType.llama3, - ModelType.llama3_1, - ModelType.llama3_2, - ModelType.longwriter_llama3_1, - ModelType.codefuse_codellama, - ModelType.marco_o1, - ModelType.deepseek, - ModelType.deepseek_r1_distill, - ModelType.yi, - ModelType.yi_coder, - ModelType.sus, - ModelType.skywork_o1, - ModelType.openbuddy_llama, - ModelType.openbuddy_llama3, - ModelType.megrez, - ModelType.reflection, - ModelType.numina, - ModelType.ziya, - ModelType.mengzi3, - ModelType.qwen3, - ModelType.qwen3_thinking, - ModelType.qwen3_nothinking, - ModelType.qwen2_moe, - ModelType.qwen3_moe, - ModelType.qwen3_moe_thinking, - ModelType.internlm3, - ModelType.mimo, - ModelType.mimo_rl, - ModelType.moonlight, - ModelType.deepseek_moe, - ModelType.deepseek_v2, - ModelType.deepseek_v2_5, - ModelType.deepseek_r1, - ModelType.dots1, - ModelType.ernie, - ModelType.glm4_5, - ModelType.deepseek_v3_1, - ])) + MegatronModelMeta( + MegatronModelType.gpt, + [ + ModelType.qwen2, + ModelType.qwen2_5, + ModelType.qwq, + ModelType.qwq_preview, + ModelType.qwen2_5_math, + ModelType.llama, + ModelType.llama3, + ModelType.llama3_1, + ModelType.llama3_2, + ModelType.longwriter_llama3_1, + ModelType.codefuse_codellama, + ModelType.marco_o1, + ModelType.deepseek, + ModelType.deepseek_r1_distill, + ModelType.yi, + ModelType.yi_coder, + ModelType.sus, + ModelType.skywork_o1, + ModelType.openbuddy_llama, + ModelType.openbuddy_llama3, + ModelType.megrez, + ModelType.reflection, + ModelType.numina, + ModelType.ziya, + ModelType.mengzi3, + ModelType.qwen3, + ModelType.qwen3_thinking, + ModelType.qwen3_nothinking, + ModelType.qwen2_moe, + ModelType.qwen3_moe, + ModelType.qwen3_moe_thinking, + ModelType.internlm3, + ModelType.mimo, + ModelType.mimo_rl, + ModelType.moonlight, + ModelType.deepseek_moe, + ModelType.deepseek_v2, + ModelType.deepseek_v2_5, + ModelType.deepseek_r1, + ModelType.dots1, + ModelType.ernie, + ModelType.glm4_5, + ModelType.deepseek_v3_1, + ], + model_provider=model_provider, + model_cls=GPTModel, + convert_hf_config=convert_gpt_hf_config, + convert_mcore2hf=convert_mcore2hf, + convert_hf2mcore=convert_hf2mcore, + )) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 6b675cf5d2..15f0641df9 100644 --- 
a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -92,11 +92,6 @@ def __init__( logger.warning('`apply_rope_fusion` does not support `attention_scaling`. ' f'Setting `config.apply_rope_fusion`: {config.apply_rope_fusion}') - args = get_args() - self.visual = None - if args.megatron_model_meta.visual is not None: - self.visual = args.megatron_model_meta.visual(config) - @contextmanager def _patch_apply_rotary_pos_emb(self): if self.attention_scaling == 1.: @@ -144,19 +139,10 @@ def forward( # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. # Decoder embedding. - args = get_args() if decoder_input is not None: pass elif self.pre_process: decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) - if self.visual is not None: - if args.tensor_model_parallel_size > 1 and args.sequence_parallel: - input_ids = input_ids.chunk( - args.tensor_model_parallel_size, dim=-1)[mpu.get_tensor_model_parallel_rank()] - kwargs.update({'input_ids': input_ids}) - decoder_input = decoder_input.transpose(0, 1) - decoder_input = self.visual.get_inputs_embeds(decoder_input, **kwargs) - decoder_input = decoder_input.transpose(0, 1) else: # intermediate stage of pipeline # decoder will get hidden_states from encoder.input_tensor diff --git a/swift/megatron/model/mm_gpt/__init__.py b/swift/megatron/model/mm_gpt/__init__.py new file mode 100644 index 0000000000..30f489086c --- /dev/null +++ b/swift/megatron/model/mm_gpt/__init__.py @@ -0,0 +1 @@ +from . import qwen2_5_vl diff --git a/swift/megatron/model/qwen2_5_vl/vit.py b/swift/megatron/model/mm_gpt/qwen2_5_vl.py similarity index 62% rename from swift/megatron/model/qwen2_5_vl/vit.py rename to swift/megatron/model/mm_gpt/qwen2_5_vl.py index 774bd106e6..7415be2f8d 100644 --- a/swift/megatron/model/qwen2_5_vl/vit.py +++ b/swift/megatron/model/mm_gpt/qwen2_5_vl.py @@ -1,28 +1,44 @@ -from contextlib import contextmanager - import torch from megatron.core.models.huggingface import HuggingFaceModule from megatron.training import get_args -from swift.llm import get_model_tokenizer, to_device +from swift.llm import ModelType, get_model_tokenizer, to_device +from ..constant import MegatronModelType +from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore +from ..gpt.mcore2hf import set_layer_state as set_layer_state_mcore2hf +from ..register import MegatronModelMeta, register_megatron_model +from .utils import MMGPTMegatronModelMeta, patch_device_map_meta -@contextmanager -def patch_device_map_meta(model_cls): - __origin_init__ = model_cls.__init__ +def convert_hf2mcore_qwen2_5_vl(hf_model, mg_model): + language_model = hf_model.model.language_model + mg_language_model = mg_model.language_model + args = get_args() + mg_language_model.embedding.word_embeddings.weight.data.copy_(language_model.embed_tokens.weight) + if args.untie_embeddings_and_output_weights: + mg_language_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) + mg_language_model.decoder.final_layernorm.weight.data.copy_(language_model.norm.weight) + for layer_idx in range(args.num_layers): + set_layer_state_hf2mcore(args, mg_language_model, language_model, layer_idx) + mg_model.visual.model.load_state_dict(hf_model.model.visual.state_dict()) - def __init__(self, *args, **kwargs): - with torch.device('meta'): - __origin_init__(self, *args, **kwargs) - model_cls.__init__ = __init__ - try: - yield - finally: - model_cls.__init__ = __origin_init__ +def convert_mcore2hf_qwen2_5_vl(hf_model, mg_model): + 
language_model = hf_model.model.language_model + mg_language_model = mg_model.language_model + args = get_args() + language_model.embed_tokens.weight.data.copy_(mg_language_model.embedding.word_embeddings.weight) + if args.untie_embeddings_and_output_weights: + hf_model.lm_head.weight.data.copy_(mg_language_model.output_layer.weight) + language_model.norm.weight.data.copy_(mg_language_model.decoder.final_layernorm.weight) + for layer_idx in range(args.num_layers): + set_layer_state_mcore2hf(args, mg_language_model, language_model, layer_idx) + hf_model.model.visual.load_state_dict(mg_model.visual.model.state_dict()) class Qwen2_5VL_Vit(HuggingFaceModule): + vision_tower = ['model'] + aligner = ['model.merger'] def __init__(self, config): from transformers.models.qwen2_5_vl import Qwen2_5_VLTextModel @@ -88,3 +104,13 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) return inputs_embeds + + +register_megatron_model( + MMGPTMegatronModelMeta( + MegatronModelType.qwen2_5_vl, [ + ModelType.qwen2_5_vl, + ], + convert_hf2mcore=convert_hf2mcore_qwen2_5_vl, + convert_mcore2hf=convert_mcore2hf_qwen2_5_vl, + visual_cls=Qwen2_5VL_Vit)) diff --git a/swift/megatron/model/mm_gpt/utils.py b/swift/megatron/model/mm_gpt/utils.py new file mode 100644 index 0000000000..882f990fcc --- /dev/null +++ b/swift/megatron/model/mm_gpt/utils.py @@ -0,0 +1,34 @@ +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Callable, Dict, Type + +import torch +from torch import nn +from transformers import PretrainedConfig + +from ..gpt.config import convert_gpt_hf_config +from ..mm_gpt_model import MultimodalGPTModel +from ..model_provider import model_provider as model_provider_func +from ..register import MegatronModelMeta, register_megatron_model + + +@contextmanager +def patch_device_map_meta(model_cls): + __origin_init__ = model_cls.__init__ + + def __init__(self, *args, **kwargs): + with torch.device('meta'): + __origin_init__(self, *args, **kwargs) + + model_cls.__init__ = __init__ + try: + yield + finally: + model_cls.__init__ = __origin_init__ + + +@dataclass +class MMGPTMegatronModelMeta(MegatronModelMeta): + model_cls: Type[nn.Module] = MultimodalGPTModel + model_provider: Callable[[], nn.Module] = model_provider_func + convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] = convert_gpt_hf_config diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py new file mode 100644 index 0000000000..859b27b362 --- /dev/null +++ b/swift/megatron/model/mm_gpt_model.py @@ -0,0 +1,70 @@ +import torch +from megatron.core import InferenceParams, mpu +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.training import get_args + +from .gpt_model import GPTModel + + +class MultimodalGPTModel(MegatronModule): + + def __init__(self, + config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + vocab_size: int, + max_sequence_length: int, + pre_process: bool = True, + post_process: bool = True, + *args, + **kwargs): + super().__init__(config) + self.pre_process = pre_process + self.post_process = post_process + self.language_model = GPTModel(config, transformer_layer_spec, vocab_size, max_sequence_length, pre_process, + post_process, 
*args, **kwargs) + + args = get_args() + self.visual = None + if args.megatron_model_meta.visual_cls is not None: + self.visual = args.megatron_model_meta.visual_cls(config) + + # Code borrowed from NVIDIA/Megatron-LM + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor = None, + decoder_input: torch.Tensor = None, + labels: torch.Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + **kwargs, + ) -> torch.Tensor: + args = get_args() + if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + if self.visual is not None: + if args.tensor_model_parallel_size > 1 and args.sequence_parallel: + input_ids = input_ids.chunk( + args.tensor_model_parallel_size, dim=-1)[mpu.get_tensor_model_parallel_rank()] + kwargs.update({'input_ids': input_ids}) + decoder_input = decoder_input.transpose(0, 1) + decoder_input = self.visual.get_inputs_embeds(decoder_input, **kwargs) + decoder_input = decoder_input.transpose(0, 1) + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + return self.model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + decoder_input=decoder_input, + labels=labels, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + ) diff --git a/swift/megatron/model/gpt/model.py b/swift/megatron/model/model_provider.py similarity index 94% rename from swift/megatron/model/gpt/model.py rename to swift/megatron/model/model_provider.py index 42cd69f375..1eeff0b8a2 100644 --- a/swift/megatron/model/gpt/model.py +++ b/swift/megatron/model/model_provider.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Union +from typing import TYPE_CHECKING, Union import megatron.legacy import torch @@ -12,11 +12,14 @@ from megatron.training.arguments import core_transformer_config_from_args from megatron.training.yaml_arguments import core_transformer_config_from_yaml -from ..gpt_model import GPTModel +if TYPE_CHECKING: + from .gpt_model import GPTModel + from .mm_gpt import MultimodalGPTModel # Code borrowed from NVIDIA/Megatron-LM -def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]: +def model_provider(pre_process=True, + post_process=True) -> Union['GPTModel', 'MultimodalGPTModel', megatron.legacy.model.GPTModel]: """Builds the model. If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. 
@@ -97,7 +100,7 @@ def oom_observer(device, alloc, device_alloc, device_free): # qwen2_moe for layer_spec in transformer_layer_spec.layer_specs: layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} - model = GPTModel( + model = args.megatron_model_meta.model_cls( config=config, transformer_layer_spec=transformer_layer_spec, vocab_size=args.padded_vocab_size, diff --git a/swift/megatron/model/qwen2_5_vl/__init__.py b/swift/megatron/model/qwen2_5_vl/__init__.py deleted file mode 100644 index fba53304c1..0000000000 --- a/swift/megatron/model/qwen2_5_vl/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from swift.llm import ModelType -from ..constant import MegatronModelType -from ..gpt import GPTMegatronModelMeta -from ..register import register_megatron_model -from .convert import convert_hf2mcore_qwen2_5_vl, convert_mcore2hf_qwen2_5_vl -from .vit import Qwen2_5VL_Vit - -register_megatron_model( - GPTMegatronModelMeta( - MegatronModelType.qwen2_5_vl, [ - ModelType.qwen2_5_vl, - ], - convert_hf2mcore=convert_hf2mcore_qwen2_5_vl, - convert_mcore2hf=convert_mcore2hf_qwen2_5_vl, - visual=Qwen2_5VL_Vit)) diff --git a/swift/megatron/model/qwen2_5_vl/convert.py b/swift/megatron/model/qwen2_5_vl/convert.py deleted file mode 100644 index 9da831611b..0000000000 --- a/swift/megatron/model/qwen2_5_vl/convert.py +++ /dev/null @@ -1,28 +0,0 @@ -from megatron.training import get_args - -from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore -from ..gpt.mcore2hf import set_layer_state as set_layer_state_mcore2hf - - -def convert_hf2mcore_qwen2_5_vl(hf_model, mg_model): - language_model = hf_model.model.language_model - args = get_args() - mg_model.embedding.word_embeddings.weight.data.copy_(language_model.embed_tokens.weight) - if args.untie_embeddings_and_output_weights: - mg_model.output_layer.weight.data.copy_(hf_model.lm_head.weight) - mg_model.decoder.final_layernorm.weight.data.copy_(language_model.norm.weight) - for layer_idx in range(args.num_layers): - set_layer_state_hf2mcore(args, mg_model, language_model, layer_idx) - mg_model.visual.model.load_state_dict(hf_model.model.visual.state_dict()) - - -def convert_mcore2hf_qwen2_5_vl(hf_model, mg_model): - language_model = hf_model.model.language_model - args = get_args() - language_model.embed_tokens.weight.data.copy_(mg_model.embedding.word_embeddings.weight) - if args.untie_embeddings_and_output_weights: - hf_model.lm_head.weight.data.copy_(mg_model.output_layer.weight) - language_model.norm.weight.data.copy_(mg_model.decoder.final_layernorm.weight) - for layer_idx in range(args.num_layers): - set_layer_state_mcore2hf(args, mg_model, language_model, layer_idx) - hf_model.model.visual.load_state_dict(mg_model.visual.model.state_dict()) diff --git a/swift/megatron/model/register.py b/swift/megatron/model/register.py index fc8b95f80c..8ed93f1ac5 100644 --- a/swift/megatron/model/register.py +++ b/swift/megatron/model/register.py @@ -16,11 +16,13 @@ class MegatronModelMeta: megatron_model_type: str model_types: List[str] - model_provider: Callable[[], nn.Module] - convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] convert_mcore2hf: Callable[[nn.Module, nn.Module], None] convert_hf2mcore: Callable[[nn.Module, nn.Module], None] - visual: Optional[Type[nn.Module]] = None + + model_cls: Type[nn.Module] + model_provider: Callable[[], nn.Module] + convert_hf_config: Callable[[PretrainedConfig], Dict[str, Any]] + visual_cls: Optional[Type[nn.Module]] = None extra_args_provider: Optional[Callable[[ArgumentParser], 
ArgumentParser]] = None diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index c72787aaa1..232c2460e3 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -138,7 +138,8 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float inputs = to_device(template.data_collator([inputs]), 'cuda') HfConfigFactory.set_model_config_attr(hf_model, 'use_cache', False) - share_embedding = mg_model.share_embeddings_and_output_weights + mg_language_model = mg_model.language_model if is_multimodal else mg_model + share_embedding = mg_language_model.share_embeddings_and_output_weights model_arch = hf_model.model_meta.model_arch ignore_modules = [] if model_arch is None else (model_arch.vision_tower + model_arch.aligner) @@ -156,8 +157,8 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float # mg_torch_dtype = None # packed_seq_params = get_packed_seq_params(position_ids) # attention_mask = None - mg_model.config.fp8 = None # compat fp8 - mg_modules = _find_modules(mg_model, ignore_modules=['visual']) + mg_language_model.config.fp8 = None # compat fp8 + mg_modules = _find_modules(mg_language_model, ignore_modules=['visual']) kwargs = {k: v for k, v in inputs.items() if k not in ['input_ids', 'attention_mask', 'labels']} if 'position_ids' not in kwargs: kwargs['position_ids'] = position_ids From 315160372b9b23d91dbcd0e6bb4f2b4a3f17f060 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 1 Sep 2025 17:13:04 +0800 Subject: [PATCH 11/31] fix --- swift/megatron/model/mm_gpt_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index 859b27b362..fc64efe3b9 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -2,6 +2,7 @@ from megatron.core import InferenceParams, mpu from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig from megatron.training import get_args @@ -46,7 +47,7 @@ def forward( if decoder_input is not None: pass elif self.pre_process: - decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids) if self.visual is not None: if args.tensor_model_parallel_size > 1 and args.sequence_parallel: input_ids = input_ids.chunk( @@ -59,7 +60,7 @@ def forward( # intermediate stage of pipeline # decoder will get hidden_states from encoder.input_tensor decoder_input = None - return self.model( + return self.language_model( input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, @@ -68,3 +69,6 @@ def forward( inference_params=inference_params, packed_seq_params=packed_seq_params, ) + + def set_input_tensor(self, input_tensor: torch.Tensor) -> None: + return self.language_model.set_input_tensor(input_tensor) From 93c46939b557fbff1a8af171fba7f2ec3a41bc7e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 1 Sep 2025 19:46:10 +0800 Subject: [PATCH 12/31] update --- .../Instruction/Command-line-parameters.md | 8 +++---- .../Megatron-SWIFT/Command-line-parameters.md | 6 ++--- swift/megatron/argument/megatron_args.py | 2 ++ swift/megatron/trainers/base.py | 22 ++++++++++++++++++- 4 files changed, 30 insertions(+), 8 
deletions(-) diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 15c4f96053..6499184173 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -215,11 +215,11 @@ Other important parameters: ### Tuner Arguments -- 🔥 freeze_llm: This parameter only takes effect for multimodal models and can be used in both full-parameter and LoRA training, but with different behaviors. In full-parameter training, setting `freeze_llm` to `True` will freeze the weights of the LLM component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_llm` to `True` will prevent LoRA modules from being added to the LLM component. The default value is `False`. -- 🔥 freeze_vit: This parameter only applies to multimodal models and can be used in both full-parameter and LoRA training, though with different effects. In full-parameter training, setting `freeze_vit` to `True` will freeze the weights of the ViT component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_vit` to `True` will prevent LoRA modules from being added to the ViT component. The default value is `True`. +- 🔥freeze_llm: This parameter only takes effect for multimodal models and can be used in both full-parameter and LoRA training, but with different behaviors. In full-parameter training, setting `freeze_llm` to `True` will freeze the weights of the LLM component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_llm` to `True` will prevent LoRA modules from being added to the LLM component. The default value is `False`. +- 🔥freeze_vit: This parameter only applies to multimodal models and can be used in both full-parameter and LoRA training, though with different effects. In full-parameter training, setting `freeze_vit` to `True` will freeze the weights of the ViT component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_vit` to `True` will prevent LoRA modules from being added to the ViT component. The default value is `True`. - Note: The term "ViT" here refers not only to the vision tower but also includes the audio tower. -- 🔥 freeze_aligner: This parameter is only effective for multimodal models and can be used in both full-parameter and LoRA training, with differing outcomes. In full-parameter training, setting `freeze_aligner` to `True` will freeze the weights of the aligner (also known as the projector) component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_aligner` to `True` will prevent LoRA modules from being added to the aligner component. The default value is `True`. -- 🔥 target_modules: Specifies the LoRA modules. The default is `['all-linear']`, but you can also pass layer-name suffixes, e.g. `--target_modules q_proj k_proj v_proj`. This argument is not restricted to LoRA and can be used with other tuners as well. +- 🔥freeze_aligner: This parameter is only effective for multimodal models and can be used in both full-parameter and LoRA training, with differing outcomes. In full-parameter training, setting `freeze_aligner` to `True` will freeze the weights of the aligner (also known as the projector) component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_aligner` to `True` will prevent LoRA modules from being added to the aligner component. The default value is `True`. +- 🔥target_modules: Specifies the LoRA modules. 
The default is `['all-linear']`, but you can also pass layer-name suffixes, e.g. `--target_modules q_proj k_proj v_proj`. This argument is not restricted to LoRA and can be used with other tuners as well. - Note: The behavior of the special value `'all-linear'` differs between plain LLMs and multimodal LLMs. For a standard LLM, it automatically locates every linear layer except `lm_head` and attaches a tuner. For a multimodal LLM, it attaches the tuner only to the LLM component by default. This default can be changed with the `freeze_llm`, `freeze_vit`, and `freeze_aligner` options. - 🔥target_regex: Specifies a regex expression for LoRA modules, with a default of `None`. If this value is provided, the target_modules parameter becomes ineffective. This parameter is not limited to LoRA and can be used for other tuners. - target_parameters: List of parameter names to be replaced with LoRA. This argument behaves similarly to target_modules, but you should pass parameter names instead. This feature requires "peft>=0.17.0". For example, in many Mixture-of-Experts (MoE) layers in Hugging Face Transformers, `nn.Linear` is not used; instead, `nn.Parameter` is used. In such cases, the `target_parameters` argument can be used to apply LoRA. diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 0a82d5dd7a..89963700d0 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -206,10 +206,10 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the **Tuner Parameters**: - train_type: Options are `'lora'` and `'full'`. Default is `'full'`. -- 🔥 freeze_llm: This parameter only takes effect for multimodal models and can be used in both full-parameter and LoRA training, but with different behaviors. In full-parameter training, setting `freeze_llm` to `True` will freeze the weights of the LLM component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_llm` to `True` will prevent LoRA modules from being added to the LLM component. The default value is `False`. -- 🔥 freeze_vit: This parameter only applies to multimodal models and can be used in both full-parameter and LoRA training, though with different effects. In full-parameter training, setting `freeze_vit` to `True` will freeze the weights of the ViT component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_vit` to `True` will prevent LoRA modules from being added to the ViT component. The default value is `True`. +- 🔥freeze_llm: This parameter only takes effect for multimodal models and can be used in both full-parameter and LoRA training, but with different behaviors. In full-parameter training, setting `freeze_llm` to `True` will freeze the weights of the LLM component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_llm` to `True` will prevent LoRA modules from being added to the LLM component. The default value is `False`. +- 🔥freeze_vit: This parameter only applies to multimodal models and can be used in both full-parameter and LoRA training, though with different effects. In full-parameter training, setting `freeze_vit` to `True` will freeze the weights of the ViT component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_vit` to `True` will prevent LoRA modules from being added to the ViT component. The default value is `True`. 
- Note: The term "ViT" here refers not only to the vision tower but also includes the audio tower. -- 🔥 freeze_aligner: This parameter is only effective for multimodal models and can be used in both full-parameter and LoRA training, with differing outcomes. In full-parameter training, setting `freeze_aligner` to `True` will freeze the weights of the aligner (also known as the projector) component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_aligner` to `True` will prevent LoRA modules from being added to the aligner component. The default value is `True`. +- 🔥freeze_aligner: This parameter is only effective for multimodal models and can be used in both full-parameter and LoRA training, with differing outcomes. In full-parameter training, setting `freeze_aligner` to `True` will freeze the weights of the aligner (also known as the projector) component. In LoRA training with `target_modules` set to 'all-linear', setting `freeze_aligner` to `True` will prevent LoRA modules from being added to the aligner component. The default value is `True`. Full-parameter Training: diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 5d5182f228..02b3721add 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -74,6 +74,8 @@ def load_tuner_config(adapter_load: Optional[str]) -> Dict[str, Any]: def __post_init__(self): if self.freeze_parameters_ratio > 0 and self.pipeline_model_parallel_size > 1: raise ValueError('`freeze_parameters_ratio` is not supported when `pipeline_model_parallel_size` > 1') + if self.target_regex: + self.target_modules = self.target_regex @dataclass diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index bcaba4c679..65b68f7795 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -242,10 +242,11 @@ def new_model_provider_func(*args, **kwargs): self.peft_model = prepare_mcore_model(self.unwrapped_model) return self.unwrapped_model + args = get_args() + self._init_multimodal_full(args) with self._patch_load_state_dict(self._load_base_checkpoint): model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer( new_model_provider_func, model_type, *_args, **kwargs) - args = get_args() if args.initialize_embedding: self._initialize_embedding(self.unwrapped_model) if args.train_type != 'full' and args.modules_to_save: @@ -711,6 +712,25 @@ def _patch_megatron(self): self._origin_save_checkpoint = training.save_checkpoint training.save_checkpoint = self.save_checkpoint + @staticmethod + def _init_multimodal_full(args): + visual_cls = args.megatron_model_meta.visual_cls + if args.train_type == 'full' and args.model_meta.is_multimodal and visual_cls is not None: + vision_tower = [f'visual.{vit}' for vit in visual_cls.vision_tower] + aligner = [f'visual.{_aliger}' for _aliger in visual_cls.aligner] + if args.freeze_llm: + args.freeze_parameters.append('language_model') + if args.freeze_vit: + args.freeze_parameters += vision_tower + if args.freeze_aligner: + args.freeze_parameters += aligner + else: + args.trainable_parameters += aligner + if args.freeze_parameters: + logger.info(f'freeze_parameters: {args.freeze_parameters}') + if args.trainable_parameters: + logger.info(f'additional trainable_parameters: {args.trainable_parameters}') + def train(self, train_dataset, val_dataset, data_collator): args = self.args datasets_provider = get_swift_datasets_provider(train_dataset, val_dataset) 
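Before the next patch, a rough sketch of the parameter-level effect of the `_init_multimodal_full` logic above: the prefixes collected into `freeze_parameters` / `trainable_parameters` are matched against `named_parameters()` and toggle `requires_grad`. The helper below and the example prefixes (`visual.model`, `visual.model.merger`, taken from `Qwen2_5VL_Vit.vision_tower` / `aligner`) are illustrative assumptions, not swift's actual `freeze_parameters` utility.

```python
import torch


def apply_freeze_prefixes(model: torch.nn.Module,
                          freeze_prefixes: list,
                          trainable_prefixes: list) -> None:
    # Freeze by prefix first, then re-activate the prefixes that must stay trainable.
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in freeze_prefixes):
            param.requires_grad = False
        if any(name.startswith(prefix) for prefix in trainable_prefixes):
            param.requires_grad = True


# With the patch defaults (freeze_vit=True, freeze_aligner=True) this is roughly:
# apply_freeze_prefixes(model, ['visual.model', 'visual.model.merger'], [])
# With freeze_aligner=False, only the merger is re-activated inside the frozen ViT:
# apply_freeze_prefixes(model, ['visual.model'], ['visual.model.merger'])
```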
From 308d5659c78b0735448b0fd4811b9f5f9975a019 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 1 Sep 2025 21:04:07 +0800 Subject: [PATCH 13/31] update --- swift/megatron/model/mm_gpt_model.py | 1 + swift/megatron/trainers/base.py | 2 +- swift/megatron/utils/utils.py | 59 ++++++++++++++++++++++++++-- 3 files changed, 57 insertions(+), 5 deletions(-) diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index fc64efe3b9..a4cb45923f 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -26,6 +26,7 @@ def __init__(self, self.language_model = GPTModel(config, transformer_layer_spec, vocab_size, max_sequence_length, pre_process, post_process, *args, **kwargs) + self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights args = get_args() self.visual = None if args.megatron_model_meta.visual_cls is not None: diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 65b68f7795..841a4ae229 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -717,7 +717,7 @@ def _init_multimodal_full(args): visual_cls = args.megatron_model_meta.visual_cls if args.train_type == 'full' and args.model_meta.is_multimodal and visual_cls is not None: vision_tower = [f'visual.{vit}' for vit in visual_cls.vision_tower] - aligner = [f'visual.{_aliger}' for _aliger in visual_cls.aligner] + aligner = [f'visual.{_aligner}' for _aligner in visual_cls.aligner] if args.freeze_llm: args.freeze_parameters.append('language_model') if args.freeze_vit: diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index db260bf3c3..b0bad06933 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -11,8 +11,10 @@ from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default from megatron.training import checkpointing, get_args from peft.utils.other import ModulesToSaveWrapper +from torch import nn -from swift.utils import activate_parameters, find_layers, freeze_parameters, get_logger, get_model_parameter_info +from swift.utils import (activate_parameters, deep_getattr, find_layers, freeze_parameters, get_logger, + get_model_parameter_info) logger = get_logger() @@ -20,7 +22,7 @@ def find_all_linears(model): def _cond(name, module): - if isinstance(module, (TELinear, TELayerNormColumnParallelLinear, TEGroupedLinear)): + if isinstance(module, (TELinear, TELayerNormColumnParallelLinear, TEGroupedLinear, nn.Linear)): return True return False @@ -35,13 +37,62 @@ def find_embedding(model): return find_layers(model, lambda name, module: isinstance(module, LanguageModelEmbedding)) +def get_multimodal_target_regex( + args, + model, + *, + freeze_llm: bool = False, + freeze_vit: bool = True, + freeze_aligner: bool = True, +) -> str: + modules = [] + visual_cls = args.megatron_model_meta.visual_cls + vision_tower = [f'visual.{vit}' for vit in visual_cls.vision_tower] + aligner = [f'visual.{_aligner}' for _aligner in visual_cls.aligner] + if not freeze_llm: + modules.append('language_model') + if not freeze_vit: + modules += vision_tower + if not freeze_aligner: + modules += aligner + assert len(modules) > 0, f'modules: {modules}' + + res = [] + for module in modules: + rejected_modules = [] + if not freeze_vit: + for _aligner in aligner: + if _aligner.startswith(f'{module}.'): + rejected_modules.append(_aligner) + + sub_module = deep_getattr(model, module) + target_modules =
find_all_linears(sub_module) + if not target_modules: + continue + target_modules = [tm for tm in target_modules if tm] + target_pattern = rf'.*\.({"|".join(target_modules)})' if target_modules else '' + rejected_pattern = rf'(?!({"|".join(rejected_modules)}))' if rejected_modules else '' + res.append(rf'{rejected_pattern}{module}{target_pattern}') + + return rf'^({"|".join(res)})$' + + def get_target_modules(args, model): if isinstance(args.target_modules, str): return args.target_modules target_modules = args.target_modules.copy() if 'all-linear' in target_modules: - target_modules.remove('all-linear') - target_modules += find_all_linears(model) + if args.model_meta.is_multimodal: + return get_multimodal_target_regex( + args, + model, + freeze_llm=args.freeze_llm, + freeze_vit=args.freeze_vit, + freeze_aligner=args.freeze_aligner, + ) + else: + target_modules.remove('all-linear') + target_modules += find_all_linears(model) if 'all-embedding' in target_modules: target_modules.remove('all-embedding') target_modules += find_embedding(model) From 08320eb43ca59ba9570191f9915bb4816c041b80 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 1 Sep 2025 21:06:57 +0800 Subject: [PATCH 14/31] update --- swift/megatron/model/gpt/__init__.py | 6 ------ swift/megatron/model/gpt_model.py | 3 +-- swift/megatron/model/mm_gpt/utils.py | 2 +- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/swift/megatron/model/gpt/__init__.py b/swift/megatron/model/gpt/__init__.py index b22cde9433..9e2654620e 100644 --- a/swift/megatron/model/gpt/__init__.py +++ b/swift/megatron/model/gpt/__init__.py @@ -1,10 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from dataclasses import dataclass -from typing import Any, Callable, Dict, Type - -from torch import nn -from transformers import PretrainedConfig - from swift.llm import ModelType from ..constant import MegatronModelType from ..gpt_model import GPTModel diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 15f0641df9..b74e3db7af 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -4,14 +4,13 @@ from typing import Any, Dict, Literal, Optional import torch -from megatron.core import InferenceParams, mpu +from megatron.core import InferenceParams from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.gpt import GPTModel as McoreGPTModel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.training import get_args from swift.utils import get_logger from .rope import dynamic_rope_update, get_rope_inv_freq diff --git a/swift/megatron/model/mm_gpt/utils.py b/swift/megatron/model/mm_gpt/utils.py index 882f990fcc..ec949f19c5 100644 --- a/swift/megatron/model/mm_gpt/utils.py +++ b/swift/megatron/model/mm_gpt/utils.py @@ -9,7 +9,7 @@ from ..gpt.config import convert_gpt_hf_config from ..mm_gpt_model import MultimodalGPTModel from ..model_provider import model_provider as model_provider_func -from ..register import MegatronModelMeta, register_megatron_model +from ..register import MegatronModelMeta @contextmanager From 44a95bdd97e4dc8e3f93add08a98f14ab219d1d8 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 1 Sep 2025 23:42:52 +0800 Subject: [PATCH 15/31] fix --- 
swift/megatron/model/gpt_model.py | 2 +- swift/megatron/model/mm_gpt/qwen2_5_vl.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index b74e3db7af..088fbd9f92 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -177,7 +177,7 @@ def forward( rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section) else: # Flash decoding uses precomputed cos and sin for RoPE - raise NotImplementedError('Flash decoding uses precomputed cos and sin for RoPE, not implmented in ' + raise NotImplementedError('Flash decoding uses precomputed cos and sin for RoPE, not implemented in ' 'MultimodalRotaryEmbedding yet.') if ((self.config.enable_cuda_graph or self.config.flash_decode) and rotary_pos_cos is not None and inference_params): diff --git a/swift/megatron/model/mm_gpt/qwen2_5_vl.py b/swift/megatron/model/mm_gpt/qwen2_5_vl.py index 7415be2f8d..616ab231ba 100644 --- a/swift/megatron/model/mm_gpt/qwen2_5_vl.py +++ b/swift/megatron/model/mm_gpt/qwen2_5_vl.py @@ -1,7 +1,6 @@ import torch from megatron.core.models.huggingface import HuggingFaceModule -from megatron.training import get_args - +from megatron.training import get_args, get_tokenizer from swift.llm import ModelType, get_model_tokenizer, to_device from ..constant import MegatronModelType from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore @@ -50,6 +49,7 @@ def __init__(self, config): model, _ = get_model_tokenizer(model_dir, args.torch_dtype, return_dummy_model=True, **kwargs) self.model = model.visual.to('cuda') self.model_config = model.config + self.processor = get_tokenizer() def forward(self, *args, **kwargs): return self.model(*args, **kwargs) From 50b2eb172f2f6f7697ac45890ff94f20809be0bf Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 1 Sep 2025 23:43:58 +0800 Subject: [PATCH 16/31] lint pass --- swift/megatron/model/mm_gpt/qwen2_5_vl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/swift/megatron/model/mm_gpt/qwen2_5_vl.py b/swift/megatron/model/mm_gpt/qwen2_5_vl.py index 616ab231ba..1fa63bd52c 100644 --- a/swift/megatron/model/mm_gpt/qwen2_5_vl.py +++ b/swift/megatron/model/mm_gpt/qwen2_5_vl.py @@ -1,6 +1,7 @@ import torch from megatron.core.models.huggingface import HuggingFaceModule from megatron.training import get_args, get_tokenizer + from swift.llm import ModelType, get_model_tokenizer, to_device from ..constant import MegatronModelType from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore From 51f315a8473e17c08d8928dcfdc83a2f7760e878 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 01:27:03 +0800 Subject: [PATCH 17/31] fix cp --- swift/llm/template/base.py | 3 +- swift/megatron/init.py | 51 +++++++++++++++++++++++ swift/megatron/model/gpt_model.py | 11 ++--- swift/megatron/model/mm_gpt/qwen2_5_vl.py | 1 + swift/megatron/model/mm_gpt_model.py | 7 +++- swift/megatron/trainers/utils.py | 48 ++++++++++++++++----- 6 files changed, 99 insertions(+), 22 deletions(-) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index a1ef0b3905..15571decfd 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1216,7 +1216,7 @@ def _encode_truncated(self, inputs: StdTemplateInputs): encoded[key] = value else: encoded = self._encode(inputs) - + self._handle_megatron_cp(encoded) # TODO: fix cp_size & cached_dataset input_ids = encoded.get('input_ids') labels = encoded.get('labels') loss_scale = 
encoded.get('loss_scale') @@ -1276,7 +1276,6 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: encoded['input_ids'] = input_ids encoded['labels'] = labels encoded['loss_scale'] = loss_scale - self._handle_megatron_cp(encoded) # TODO: fix cp_size & cached_dataset if encoded.get('labels') is not None: encoded['labels'][0] = -100 if encoded.get('loss_scale') is not None: diff --git a/swift/megatron/init.py b/swift/megatron/init.py index a1d90db786..db6313d332 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -518,6 +518,56 @@ def __repr__(self): TELinear.__repr__ = __repr__ +def _patch_mrope(): + from megatron.core.models.common.embeddings.rotary_pos_embedding import (MultimodalRotaryEmbedding) + from megatron.core import parallel_state + from megatron.core.models.common.embeddings.rope_utils import get_pos_emb_on_this_cp_rank + + def forward(self, + max_seq_len: int, + mrope_section: List[int], + offset: int = 0, + packed_seq: bool = False) -> torch.Tensor: + if self.inv_freq.device.type == 'cpu': + # move `inv_freq` to GPU once at the first micro-batch forward pass + self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device()) + + seq = (torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + offset) + + if self.seq_len_interpolation_factor is not None: + seq *= 1 / self.seq_len_interpolation_factor + + # shape (3, bs, dim, 1) + inv_freq_expanded = self.inv_freq[None, None, :, None].expand(3, seq.shape[1], -1, 1) + # shape (3, bs, 1, seq_length) + seq_expanded = seq[:, :, None, :].float() + # shape (3, bs, seq_length, dim) + freqs = (inv_freq_expanded @ seq_expanded).transpose(2, 3) + # first part even vector components, second part odd vector components, + # 2 * dim in dimension size + if not self.rotary_interleaved: + emb = torch.cat((freqs, freqs), dim=-1) # shape (3, bs, seq_length, 2 * dim) + else: + bs = freqs.shape[1] + emb = torch.stack((freqs.view(3, bs, -1, 1), freqs.view(3, bs, -1, 1)), + dim=-1).view(3, bs, freqs.shape[0], -1) + + # generate freqs with mrope_section + # shape (bs, seq_length, 2 * dim) + mrope_section = mrope_section * 2 + emb = torch.cat([m[i % 3] for i, m in enumerate(emb.split(mrope_section, dim=-1))], dim=-1) + + # shape (seq_length, bs, 1, 2 * dim) + emb = emb[..., None, :].transpose(0, 1).contiguous() + if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq: + # slice rotary_pos_emb along sequence dimension and select the parition of the current + # CP rank + emb = get_pos_emb_on_this_cp_rank(emb, 0, parallel_state.get_context_parallel_group()) + return emb + + MultimodalRotaryEmbedding.forward = forward + + def _patch_megatron(): _patch_flash_attn() _patch_transformer_engine() @@ -527,6 +577,7 @@ def _patch_megatron(): _patch_TEGroupedLinear() _patch_TransformerLayer() _patch_compile_helpers() + _patch_mrope() from swift.megatron import tuners # patch lora try: _patch_torch_FileSystemReader() diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 088fbd9f92..37465925f3 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -154,7 +154,7 @@ def forward( rotary_pos_emb = None rotary_pos_cos = None rotary_pos_sin = None - if self.position_embedding_type == 'rope': + if self.position_embedding_type in {'rope', 'mrope'}: if not self.training and self.config.flash_decode and inference_params: # Flash decoding uses precomputed cos and sin for RoPE rotary_pos_cos, rotary_pos_sin = 
self.rotary_pos_emb_cache.setdefault( @@ -168,17 +168,12 @@ attention_scaling = dynamic_rope_update(self, self.rotary_pos_emb.inv_freq, rotary_seq_len) if attention_scaling is not None: self.attention_scaling = attention_scaling + kwargs = {'mrope_section': self.mrope_section} if self.position_embedding_type == 'mrope' else {} rotary_pos_emb = self.rotary_pos_emb( rotary_seq_len, packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', + **kwargs, ) - elif self.position_embedding_type == 'mrope': - if self.training or not self.config.flash_decode: - rotary_pos_emb = self.rotary_pos_emb(position_ids, self.mrope_section) - else: - # Flash decoding uses precomputed cos and sin for RoPE - raise NotImplementedError('Flash decoding uses precomputed cos and sin for RoPE, not implemented in ' - 'MultimodalRotaryEmbedding yet.') if ((self.config.enable_cuda_graph or self.config.flash_decode) and rotary_pos_cos is not None and inference_params): sequence_len_offset = torch.tensor( diff --git a/swift/megatron/model/mm_gpt/qwen2_5_vl.py b/swift/megatron/model/mm_gpt/qwen2_5_vl.py index 1fa63bd52c..ec6ad09dc5 100644 --- a/swift/megatron/model/mm_gpt/qwen2_5_vl.py +++ b/swift/megatron/model/mm_gpt/qwen2_5_vl.py @@ -95,6 +95,7 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): image_embeds = mixed_embeds[:image_tokens] video_embeds = mixed_embeds[image_tokens:] + input_ids = input_ids.transpose(0, 1) if image_embeds is not None: image_mask = (input_ids == self.model_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index a4cb45923f..61b5a2301a 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -44,6 +44,7 @@ def forward( packed_seq_params: PackedSeqParams = None, **kwargs, ) -> torch.Tensor: + from ..trainers.utils import get_batch_on_this_cp_rank args = get_args() if decoder_input is not None: pass @@ -54,9 +55,11 @@ input_ids = input_ids.chunk( args.tensor_model_parallel_size, dim=-1)[mpu.get_tensor_model_parallel_rank()] kwargs.update({'input_ids': input_ids}) - decoder_input = decoder_input.transpose(0, 1) decoder_input = self.visual.get_inputs_embeds(decoder_input, **kwargs) - decoder_input = decoder_input.transpose(0, 1) + decoder_input = get_batch_on_this_cp_rank({ + 'decoder_input': decoder_input, + 'packed_seq_params': packed_seq_params + })['decoder_input'] else: # intermediate stage of pipeline # decoder will get hidden_states from encoder.input_tensor decoder_input = None diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 847314f166..30b90400f4 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -65,22 +65,41 @@ def get_packed_seq_params(position_ids: torch.Tensor) -> PackedSeqParams: def _split_tokens(tokens, cu_seqlens): - assert tokens.shape[0] == 1, f'tokens.shape: {tokens.shape}' + assert tokens.shape[-2] == 1, f'tokens.shape: {tokens.shape}' # [..., 1, L] new_tokens = [] cp_size = mpu.get_context_parallel_world_size() cp_rank = mpu.get_context_parallel_rank() for i in range(cu_seqlens.shape[0] - 1): - val = tokens[:, cu_seqlens[i]:cu_seqlens[i + 1]] + val = tokens[..., cu_seqlens[i]:cu_seqlens[i + 1]] val = val.view( - tokens.shape[0], + *tokens.shape[:-1], 2 * cp_size, - val.shape[1] // (2 * cp_size), + val.shape[-1] // (2 * cp_size), ) index =
torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu', pin_memory=True).cuda(non_blocking=True) - val = val.index_select(1, index) - new_tokens.append(val.view(tokens.shape[0], -1)) - return torch.cat(new_tokens, dim=1) + val = val.index_select(-2, index) + new_tokens.append(val.view(*tokens.shape[:-1], -1)) + return torch.cat(new_tokens, dim=-1) + + +def _split_tokens_decoder_input(tokens, cu_seqlens): + assert tokens.shape[1] == 1, f'tokens.shape: {tokens.shape}' # [L, 1, E] + new_tokens = [] + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() + for i in range(cu_seqlens.shape[0] - 1): + val = tokens[cu_seqlens[i]:cu_seqlens[i + 1], ...] + val = val.view( + 2 * cp_size, + val.shape[0] // (2 * cp_size), + *tokens.shape[1:], + ) + index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu', + pin_memory=True).cuda(non_blocking=True) + val = val.index_select(0, index) + new_tokens.append(val.view(-1, *tokens.shape[1:])) + return torch.cat(new_tokens, dim=0) def get_batch_on_this_cp_rank(batch: Dict[str, Any]): @@ -96,14 +115,23 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): # that we can get balanced workload among GPUs in a context parallel group. cp_size = mpu.get_context_parallel_world_size() if cp_size > 1: + args = get_args() + keys = ['labels', 'attention_mask', 'position_ids', 'loss_scale'] + if args.model_meta.is_multimodal: + keys.append('decoder_input') + else: + keys.append('input_ids') packed_seq_params = batch.get('packed_seq_params') if packed_seq_params is None: return mcore_get_batch_on_this_cp_rank(batch) for key, val in batch.items(): - if key in {'packed_seq_params', 'channel'}: + if key not in keys: continue if val is not None: - batch[key] = _split_tokens(val, packed_seq_params.cu_seqlens_q) + if key == 'decoder_input': + batch[key] = _split_tokens_decoder_input(val, packed_seq_params.cu_seqlens_q) + else: + batch[key] = _split_tokens(val, packed_seq_params.cu_seqlens_q) return batch @@ -118,7 +146,7 @@ def get_batch(data_iterator): batch['packed_seq_params'] = get_packed_seq_params(batch['position_ids']) batch['packed_seq_params'].num_samples = num_samples # slice batch along sequence dimension for context parallelism - position_ids = batch.get('real_position_ids') # fix Qwen2.5-VL + position_ids = batch.pop('real_position_ids', None) # fix Qwen2.5-VL if position_ids is not None: batch['position_ids'] = position_ids batch = get_batch_on_this_cp_rank(batch) From ef6690109c8c81b128f000c7718bfb096ac5cf32 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 02:11:48 +0800 Subject: [PATCH 18/31] update --- ...14\346\225\260\346\215\256\351\233\206.md" | 8 +-- .../Supported-models-and-datasets.md | 8 +-- swift/megatron/init.py | 66 +++++++++++++++++-- swift/megatron/model/gpt_model.py | 23 ++++--- 4 files changed, 81 insertions(+), 24 deletions(-) diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index b3836c2592..e7ec13b1d5 100644 --- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -673,10 +673,10 @@ 
|[bytedance-research/UI-TARS-72B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-72B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-72B-SFT](https://huggingface.co/bytedance-research/UI-TARS-72B-SFT)| |[bytedance-research/UI-TARS-72B-DPO](https://modelscope.cn/models/bytedance-research/UI-TARS-72B-DPO)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-72B-DPO](https://huggingface.co/bytedance-research/UI-TARS-72B-DPO)| |[allenai/olmOCR-7B-0225-preview](https://modelscope.cn/models/allenai/olmOCR-7B-0225-preview)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[allenai/olmOCR-7B-0225-preview](https://huggingface.co/allenai/olmOCR-7B-0225-preview)| -|[Qwen/Qwen2.5-VL-3B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-3B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)| -|[Qwen/Qwen2.5-VL-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)| -|[Qwen/Qwen2.5-VL-32B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-32B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct)| -|[Qwen/Qwen2.5-VL-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-72B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct)| +|[Qwen/Qwen2.5-VL-3B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-3B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)| +|[Qwen/Qwen2.5-VL-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)| +|[Qwen/Qwen2.5-VL-32B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-32B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2.5-VL-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct)| +|[Qwen/Qwen2.5-VL-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-72B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2.5-VL-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct)| |[Qwen/Qwen2.5-VL-3B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2.5-VL-3B-Instruct-AWQ)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-3B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct-AWQ)| |[Qwen/Qwen2.5-VL-7B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct-AWQ)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-7B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct-AWQ)| 
|[Qwen/Qwen2.5-VL-32B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2.5-VL-32B-Instruct-AWQ)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-32B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct-AWQ)| diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index a20d5c0bd6..78dc0bcb8a 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -673,10 +673,10 @@ The table below introduces the models integrated with ms-swift: |[bytedance-research/UI-TARS-72B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-72B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-72B-SFT](https://huggingface.co/bytedance-research/UI-TARS-72B-SFT)| |[bytedance-research/UI-TARS-72B-DPO](https://modelscope.cn/models/bytedance-research/UI-TARS-72B-DPO)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-72B-DPO](https://huggingface.co/bytedance-research/UI-TARS-72B-DPO)| |[allenai/olmOCR-7B-0225-preview](https://modelscope.cn/models/allenai/olmOCR-7B-0225-preview)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[allenai/olmOCR-7B-0225-preview](https://huggingface.co/allenai/olmOCR-7B-0225-preview)| -|[Qwen/Qwen2.5-VL-3B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-3B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)| -|[Qwen/Qwen2.5-VL-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)| -|[Qwen/Qwen2.5-VL-32B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-32B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct)| -|[Qwen/Qwen2.5-VL-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-72B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct)| +|[Qwen/Qwen2.5-VL-3B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-3B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)| +|[Qwen/Qwen2.5-VL-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)| +|[Qwen/Qwen2.5-VL-32B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-32B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2.5-VL-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct)| +|[Qwen/Qwen2.5-VL-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-72B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, 
video|[Qwen/Qwen2.5-VL-72B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-72B-Instruct)| |[Qwen/Qwen2.5-VL-3B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2.5-VL-3B-Instruct-AWQ)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-3B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct-AWQ)| |[Qwen/Qwen2.5-VL-7B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct-AWQ)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-7B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct-AWQ)| |[Qwen/Qwen2.5-VL-32B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2.5-VL-32B-Instruct-AWQ)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2.5-VL-32B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct-AWQ)| diff --git a/swift/megatron/init.py b/swift/megatron/init.py index db6313d332..edf028a2fb 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -521,18 +521,17 @@ def __repr__(self): def _patch_mrope(): from megatron.core.models.common.embeddings.rotary_pos_embedding import (MultimodalRotaryEmbedding) from megatron.core import parallel_state - from megatron.core.models.common.embeddings.rope_utils import get_pos_emb_on_this_cp_rank + from megatron.core.models.common.embeddings.rope_utils import (get_pos_emb_on_this_cp_rank, + _apply_rotary_pos_emb_bshd) + from megatron.core.models.common.embeddings import rope_utils + from megatron.training import get_args def forward(self, - max_seq_len: int, + position_ids, mrope_section: List[int], offset: int = 0, packed_seq: bool = False) -> torch.Tensor: - if self.inv_freq.device.type == 'cpu': - # move `inv_freq` to GPU once at the first micro-batch forward pass - self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device()) - - seq = (torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + offset) + seq = position_ids.to(device=self.inv_freq.device, dtype=self.inv_freq.dtype) if self.seq_len_interpolation_factor is not None: seq *= 1 / self.seq_len_interpolation_factor @@ -566,6 +565,59 @@ def forward(self, return emb MultimodalRotaryEmbedding.forward = forward + _origin_apply_rotary_pos_emb_thd = rope_utils._apply_rotary_pos_emb_thd + + def _apply_rotary_pos_emb_thd( + t: torch.Tensor, + cu_seqlens: torch.Tensor, + freqs: torch.Tensor, + rotary_interleaved: bool = False, + multi_latent_attention: bool = False, + mscale: float = 1.0, + cp_group: torch.distributed.ProcessGroup = None, + ) -> torch.Tensor: + """A baseline implementation of applying RoPE for `thd` format. + + Args: + t (Tensor): Input tensor T is of shape [t, h, d] + cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`, + with shape [b + 1] and dtype torch.int32. + freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d] + cp_group (torch.distributed.ProcessGroup): The context parallel group + + Returns: + Tensor: Shape [t, h, d]. The input tensor after applying RoPE. 
+ """ + args = get_args() + if args.position_embedding_type != 'mrope': + return _origin_apply_rotary_pos_emb_thd( + t, + cu_seqlens, + freqs, + rotary_interleaved=config.rotary_interleaved, + multi_latent_attention=config.multi_latent_attention, + mscale=mscale, + cp_group=cp_group, + ) + + if cp_group is None: + raise ValueError('cp_group must be provided for THD format RoPE') + cp_size = cp_group.size() + cp_rank = cp_group.rank() + cu_seqlens = cu_seqlens // cp_size + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + + return torch.cat([ + _apply_rotary_pos_emb_bshd( + x.unsqueeze(1), + f, + rotary_interleaved=rotary_interleaved, + multi_latent_attention=multi_latent_attention, + mscale=mscale, + ) for x, f in zip(torch.split(t, seqlens), torch.split(freqs, seqlens)) + ]).squeeze(1) + + rope_utils._apply_rotary_pos_emb_thd = _apply_rotary_pos_emb_thd def _patch_megatron(): diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 37465925f3..f03a7c855e 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -86,7 +86,7 @@ def __init__( new_inv_freq, self.attention_scaling = get_rope_inv_freq() self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) - if self.attention_scaling != 1 and config.apply_rope_fusion: + if (self.attention_scaling != 1 or position_embedding_type == 'mrope') and config.apply_rope_fusion: config.apply_rope_fusion = False logger.warning('`apply_rope_fusion` does not support `attention_scaling`. ' f'Setting `config.apply_rope_fusion`: {config.apply_rope_fusion}') @@ -162,18 +162,23 @@ def forward( self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length), ) else: - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(inference_params, self.decoder, decoder_input, - self.config, packed_seq_params) + rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_params, self.decoder, decoder_input, + self.config, packed_seq_params) if self.hf_rope_scaling is not None: attention_scaling = dynamic_rope_update(self, self.rotary_pos_emb.inv_freq, rotary_seq_len) if attention_scaling is not None: self.attention_scaling = attention_scaling - kwargs = {mrope_section: self.mrope_section} if self.position_embedding_type == 'mrope' else {} - rotary_pos_emb = self.rotary_pos_emb( - rotary_seq_len, - packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', - **kwargs, - ) + if self.position_embedding_type == 'mrope': + rotary_pos_emb = self.rotary_pos_emb( + position_ids, + mrope_section=self.mrope_section, + packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', + ) + else: + rotary_pos_emb = self.rotary_pos_emb( + rotary_seq_len, + packed_seq=packed_seq_params is not None and packed_seq_params.qkv_format == 'thd', + ) if ((self.config.enable_cuda_graph or self.config.flash_decode) and rotary_pos_cos is not None and inference_params): sequence_len_offset = torch.tensor( From 631691ccf494ee647549206621a9a9775f12ec7a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 02:14:18 +0800 Subject: [PATCH 19/31] lint pass --- swift/megatron/init.py | 1 - 1 file changed, 1 deletion(-) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index edf028a2fb..4c14e40f8e 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -603,7 +603,6 @@ def _apply_rotary_pos_emb_thd( if cp_group is None: raise ValueError('cp_group must be provided for THD format RoPE') cp_size = cp_group.size() - 
cp_rank = cp_group.rank() cu_seqlens = cu_seqlens // cp_size seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() From 587298e552423fd31e1ef5a56db47be82524ba19 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 17:21:04 +0800 Subject: [PATCH 20/31] update --- swift/llm/template/base.py | 26 ++++++++++----- swift/llm/template/template/qwen.py | 4 +-- swift/megatron/init.py | 2 ++ swift/megatron/model/mm_gpt/qwen2_5_vl.py | 3 +- swift/megatron/model/mm_gpt_model.py | 40 +++++++++++++++++------ swift/megatron/utils/utils.py | 6 ++-- swift/trainers/mixin.py | 2 +- 7 files changed, 55 insertions(+), 28 deletions(-) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 15571decfd..c8a86cb3ad 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1625,7 +1625,7 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in res = {} if self.padding_free: assert len(batch) == 1, f'batch: {batch}' - for k in ['input_ids', 'labels', 'position_ids', 'loss_scale', 'channel']: + for k in ['input_ids', 'labels', 'position_ids', 'loss_scale', 'channel', 'real_position_ids']: v = batch[0].get(k) if v is not None: res[k] = v if k == 'channel' else [v] @@ -1647,9 +1647,10 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in res[key] = val keys = [ - 'input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids', 'token_type_ids' + 'input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids', 'token_type_ids', + 'real_position_ids' ] - pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0] + pad_values = [self.tokenizer.pad_token_id, 0., 0, -100, 0., 0., 0, 0.] # Convert to tensor and remove unnecessary dimensions. 
seq_lens = None for key in keys: @@ -1676,10 +1677,14 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in if self.padding_free: cp_size = self.sequence_parallel_size if cp_size > 1: - padding_len = padding_to - seq_lens[0] - position_ids = res['position_ids'][0].tolist() - position_ids += list(range(cp_size * 2)) * (padding_len // (cp_size * 2)) - res['position_ids'] = [torch.tensor(position_ids)] + for key in ['position_ids', 'real_position_ids']: + padding_len = padding_to - seq_lens[0] + position_ids = res[key][0] + extended_position_ids = torch.arange(cp_size * 2).repeat(padding_len // (cp_size * 2)) + if position_ids.ndim == 3: # compat mrope + extended_position_ids = extended_position_ids[None, + None, :].expand(position_ids.shape[0], 1, -1) + res[key] = [torch.concat([position_ids, extended_position_ids], dim=-1)] else: seq_len = max(seq_lens) if padding_to is None else padding_to res['attention_mask'] = torch.tril(torch.ones( @@ -1693,13 +1698,16 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in continue if self.use_megatron and not self.padding_free and key == 'attention_mask': continue - if padding_to is not None and not (self.padding_free and key == 'position_ids' + if padding_to is not None and not (self.padding_free and key in {'position_ids', 'real_position_ids'} and self.sequence_parallel_size > 1): padding_len = padding_to - seq_lens[0] if padding_len > 0: res[key][0] = F.pad(res[key][0], (0, padding_len) if padding_right else (padding_len, 0), 'constant', pad_value) - res[key] = self._pad_sequence(res[key], pad_value) + if key == 'real_position_ids': + res[key] = torch.concat(res[key], dim=-1) + else: + res[key] = self._pad_sequence(res[key], pad_value) # multimodal res.update(self._data_collator_mm_data(batch)) diff --git a/swift/llm/template/template/qwen.py b/swift/llm/template/template/qwen.py index b59bbfa2f3..f6a935f168 100644 --- a/swift/llm/template/template/qwen.py +++ b/swift/llm/template/template/qwen.py @@ -424,9 +424,7 @@ def _get_position_ids(self, inputs: Dict[str, Any]): def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]: res = super()._data_collator(batch, padding_to=padding_to) - if self.padding_free: - res['real_position_ids'] = self.concat_tensor(batch, 'real_position_ids', -1) - elif self.is_training: + if not self.padding_free and self.is_training: res['position_ids'] = self._get_position_ids(res) return res diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 4c14e40f8e..1693164930 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -648,6 +648,8 @@ def _patch_megatron(): def init_megatron_env() -> None: if 'MEGATRON_LM_PATH' not in os.environ: + # TODO: Synchronization issues may occur in DDP scenarios + # if the distributed environment has not been initialized. 
os.environ['MEGATRON_LM_PATH'] = git_clone_github( 'https://github.com/NVIDIA/Megatron-LM', branch='core_r0.13.0') with safe_ddp_context(hash_id='megatron-lm'): diff --git a/swift/megatron/model/mm_gpt/qwen2_5_vl.py b/swift/megatron/model/mm_gpt/qwen2_5_vl.py index ec6ad09dc5..37dfb065a4 100644 --- a/swift/megatron/model/mm_gpt/qwen2_5_vl.py +++ b/swift/megatron/model/mm_gpt/qwen2_5_vl.py @@ -70,7 +70,7 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): media_inputs = to_device(media_inputs, device) pixel_values = media_inputs['pixel_values'].type(dtype) image_embeds = self.model(pixel_values, grid_thw=media_inputs['image_grid_thw']) - inputs_embeds += image_embeds.mean() * 0. + inputs_embeds = inputs_embeds + image_embeds.mean() * 0. else: if pixel_values is None: pixel_values_mixed = pixel_values_videos @@ -95,7 +95,6 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): image_embeds = mixed_embeds[:image_tokens] video_embeds = mixed_embeds[image_tokens:] - input_ids = input_ids.transpose(0, 1) if image_embeds is not None: image_mask = (input_ids == self.model_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds) image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index 61b5a2301a..deeece9b63 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -1,6 +1,9 @@ +from contextlib import contextmanager + import torch -from megatron.core import InferenceParams, mpu +from megatron.core import InferenceParams from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel import VocabParallelEmbedding, scatter_to_sequence_parallel_region from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig @@ -32,6 +35,28 @@ def __init__(self, if args.megatron_model_meta.visual_cls is not None: self.visual = args.megatron_model_meta.visual_cls(config) + @contextmanager + def _patch_word_embeddings(self, kwargs): + origin_forward = VocabParallelEmbedding.forward + + def forward(_self, input_): + reduce_scatter_embeddings = _self.reduce_scatter_embeddings + _self.reduce_scatter_embeddings = False + res = origin_forward(_self, input_) + _self.reduce_scatter_embeddings = reduce_scatter_embeddings + if self.visual is not None: + res = self.visual.get_inputs_embeds(res, **kwargs) + if reduce_scatter_embeddings: + res = res.transpose(0, 1).contiguous() + res = scatter_to_sequence_parallel_region(res, group=_self.tp_group) + return res + + VocabParallelEmbedding.forward = forward + try: + yield + finally: + VocabParallelEmbedding.forward = origin_forward + # Code borrowed from NVIDIA/Megatron-LM def forward( self, @@ -44,18 +69,13 @@ def forward( packed_seq_params: PackedSeqParams = None, **kwargs, ) -> torch.Tensor: - from ..trainers.utils import get_batch_on_this_cp_rank - args = get_args() if decoder_input is not None: pass elif self.pre_process: - decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids) - if self.visual is not None: - if args.tensor_model_parallel_size > 1 and args.sequence_parallel: - input_ids = input_ids.chunk( - args.tensor_model_parallel_size, dim=-1)[mpu.get_tensor_model_parallel_rank()] - kwargs.update({'input_ids': input_ids}) - decoder_input = self.visual.get_inputs_embeds(decoder_input, **kwargs) + from 
..trainers.utils import get_batch_on_this_cp_rank + kwargs.update({'input_ids': input_ids}) + with self._patch_word_embeddings(kwargs): + decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids) decoder_input = get_batch_on_this_cp_rank({ 'decoder_input': decoder_input, 'packed_seq_params': packed_seq_params diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index b0bad06933..3578c7e140 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -61,9 +61,9 @@ def get_multimodal_target_regex( for module in modules: rejected_modules = [] if not freeze_vit: - for aligner in aligner: - if aligner.startswith(f'{module}.'): - rejected_modules.append(aligner) + for _aligner in aligner: + if _aligner.startswith(f'{module}.'): + rejected_modules.append(_aligner) sub_module = deep_getattr(model, module) target_modules = find_all_linears(sub_module) diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index 3cda3e3a0d..89d6bc6332 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -839,7 +839,7 @@ def get_cu_seqlens(self, position_ids, logits_to_keep) -> torch.Tensor: start, end = cu_seqlens[i], cu_seqlens[i + 1] res_cu_seqlens[i + 1:] -= (~logits_to_keep[start:end]).sum() elif isinstance(logits_to_keep, int): - res_cu_seqlens[1:] -= position_ids.shape[0] + 1 - logits_to_keep + res_cu_seqlens[1:] -= position_ids.shape[-1] + 1 - logits_to_keep return res_cu_seqlens def get_batch_samples(self, *args, **kwargs): From a1cb64bd75048dd9c1ae73c826696831d3ea6edd Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 17:33:51 +0800 Subject: [PATCH 21/31] update --- swift/megatron/model/mm_gpt/qwen2_5_vl.py | 19 ++++++++++++++----- swift/megatron/utils/convert.py | 4 ++-- tests/megatron/test_align/test_llm.py | 8 +++++++- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/swift/megatron/model/mm_gpt/qwen2_5_vl.py b/swift/megatron/model/mm_gpt/qwen2_5_vl.py index 37dfb065a4..d2f4b2ea8b 100644 --- a/swift/megatron/model/mm_gpt/qwen2_5_vl.py +++ b/swift/megatron/model/mm_gpt/qwen2_5_vl.py @@ -11,7 +11,10 @@ def convert_hf2mcore_qwen2_5_vl(hf_model, mg_model): - language_model = hf_model.model.language_model + language_model = hf_model.model + if hasattr(language_model, 'language_model'): + language_model = language_model.language_model + visual = hf_model.visual if hasattr(hf_model, 'visual') else hf_model.model.visual mg_language_model = mg_model.language_model args = get_args() mg_language_model.embedding.word_embeddings.weight.data.copy_(language_model.embed_tokens.weight) @@ -20,11 +23,14 @@ def convert_hf2mcore_qwen2_5_vl(hf_model, mg_model): mg_language_model.decoder.final_layernorm.weight.data.copy_(language_model.norm.weight) for layer_idx in range(args.num_layers): set_layer_state_hf2mcore(args, mg_language_model, language_model, layer_idx) - mg_model.visual.model.load_state_dict(hf_model.model.visual.state_dict()) + mg_model.visual.model.load_state_dict(visual.state_dict()) def convert_mcore2hf_qwen2_5_vl(hf_model, mg_model): - language_model = hf_model.model.language_model + language_model = hf_model.model + if hasattr(language_model, 'language_model'): + language_model = language_model.language_model + visual = hf_model.visual if hasattr(hf_model, 'visual') else hf_model.model.visual mg_language_model = mg_model.language_model args = get_args() language_model.embed_tokens.weight.data.copy_(mg_language_model.embedding.word_embeddings.weight) @@ -33,7 +39,7 @@ def 
convert_mcore2hf_qwen2_5_vl(hf_model, mg_model): language_model.norm.weight.data.copy_(mg_language_model.decoder.final_layernorm.weight) for layer_idx in range(args.num_layers): set_layer_state_mcore2hf(args, mg_language_model, language_model, layer_idx) - hf_model.model.visual.load_state_dict(mg_model.visual.model.state_dict()) + visual.load_state_dict(mg_model.visual.model.state_dict()) class Qwen2_5VL_Vit(HuggingFaceModule): @@ -41,7 +47,10 @@ class Qwen2_5VL_Vit(HuggingFaceModule): aligner = ['model.merger'] def __init__(self, config): - from transformers.models.qwen2_5_vl import Qwen2_5_VLTextModel + try: + from transformers.models.qwen2_5_vl import Qwen2_5_VLTextModel + except ImportError: + from transformers.models.qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel super().__init__(config) args = get_args() model_dir = args.model_info.model_dir diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index 232c2460e3..ac7044c464 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -207,7 +207,7 @@ def convert_hf2mcore(args: ExportArguments) -> None: megatron_model_meta = get_megatron_model_meta(args.model_type) assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' config = processor.model_info.config - if args.model_meta.is_multimodal: + if args.model_meta.is_multimodal and hasattr(config, 'text_config'): config = config.text_config kwargs = megatron_model_meta.convert_hf_config(config) logger.info(f'megatron_config: {kwargs}') @@ -246,7 +246,7 @@ def convert_mcore2hf(args: ExportArguments) -> None: megatron_model_meta = get_megatron_model_meta(args.model_type) assert megatron_model_meta is not None, f'Model: {args.model} is not supported.' config = processor.model_info.config - if args.model_meta.is_multimodal: + if args.model_meta.is_multimodal and hasattr(config, 'text_config'): config = config.text_config kwargs = megatron_model_meta.convert_hf_config(config) logger.info(f'megatron_config: {kwargs}') diff --git a/tests/megatron/test_align/test_llm.py b/tests/megatron/test_align/test_llm.py index 163fd1933a..b0006cd023 100644 --- a/tests/megatron/test_align/test_llm.py +++ b/tests/megatron/test_align/test_llm.py @@ -127,6 +127,11 @@ def test_glm4_5(): _test_model('ZhipuAI/GLM-4.5-Air') +def test_qwen2_5_vl(): + os.environ['MAX_PIXELS'] = str(1280 * 28 * 28) + _test_model('Qwen/Qwen2.5-VL-7B-Instruct') + + if __name__ == '__main__': # test_qwen2() # test_llama2() @@ -151,4 +156,5 @@ def test_glm4_5(): # test_kimi_dev() # test_hunyuan() # test_ernie() - test_glm4_5() + # test_glm4_5() + test_qwen2_5_vl() From 798bdd45ae26ac025c05efff74262cba66c78b42 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 17:41:50 +0800 Subject: [PATCH 22/31] fix --- swift/megatron/utils/convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/utils/convert.py b/swift/megatron/utils/convert.py index ac7044c464..19b836bd79 100644 --- a/swift/megatron/utils/convert.py +++ b/swift/megatron/utils/convert.py @@ -141,7 +141,7 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float mg_language_model = mg_model.language_model if is_multimodal else mg_model share_embedding = mg_language_model.share_embeddings_and_output_weights model_arch = hf_model.model_meta.model_arch - ignore_modules = [] if model_arch is None else (model_arch.vision_tower + model_arch.aligner) + ignore_modules = (model_arch.vision_tower + model_arch.aligner) if is_multimodal 
else [] hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules) with torch.inference_mode(), _model_cpu_forward_context(hf_modules, torch_dtype, share_embedding=share_embedding): From 5953506a173e8deb0fb83f0e57992dbeefc7f1a6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 17:49:36 +0800 Subject: [PATCH 23/31] fix --- swift/megatron/init.py | 7 +++---- swift/megatron/model/mm_gpt/qwen2_5_vl.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 1693164930..d7a8b0d875 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -519,7 +519,7 @@ def __repr__(self): def _patch_mrope(): - from megatron.core.models.common.embeddings.rotary_pos_embedding import (MultimodalRotaryEmbedding) + from megatron.core.models.common.embeddings.rotary_pos_embedding import MultimodalRotaryEmbedding from megatron.core import parallel_state from megatron.core.models.common.embeddings.rope_utils import (get_pos_emb_on_this_cp_rank, _apply_rotary_pos_emb_bshd) @@ -529,7 +529,6 @@ def _patch_mrope(): def forward(self, position_ids, mrope_section: List[int], - offset: int = 0, packed_seq: bool = False) -> torch.Tensor: seq = position_ids.to(device=self.inv_freq.device, dtype=self.inv_freq.dtype) @@ -594,8 +593,8 @@ def _apply_rotary_pos_emb_thd( t, cu_seqlens, freqs, - rotary_interleaved=config.rotary_interleaved, - multi_latent_attention=config.multi_latent_attention, + rotary_interleaved=rotary_interleaved, + multi_latent_attention=multi_latent_attention, mscale=mscale, cp_group=cp_group, ) diff --git a/swift/megatron/model/mm_gpt/qwen2_5_vl.py b/swift/megatron/model/mm_gpt/qwen2_5_vl.py index d2f4b2ea8b..e1ea21ad22 100644 --- a/swift/megatron/model/mm_gpt/qwen2_5_vl.py +++ b/swift/megatron/model/mm_gpt/qwen2_5_vl.py @@ -6,7 +6,7 @@ from ..constant import MegatronModelType from ..gpt.hf2mcore import set_layer_state as set_layer_state_hf2mcore from ..gpt.mcore2hf import set_layer_state as set_layer_state_mcore2hf -from ..register import MegatronModelMeta, register_megatron_model +from ..register import register_megatron_model from .utils import MMGPTMegatronModelMeta, patch_device_map_meta From 0508aaf0b726cb4219b8b40b2dee6514c6c48234 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 17:52:51 +0800 Subject: [PATCH 24/31] lint pass --- swift/megatron/init.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index d7a8b0d875..f9ae0747dd 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -526,10 +526,7 @@ def _patch_mrope(): from megatron.core.models.common.embeddings import rope_utils from megatron.training import get_args - def forward(self, - position_ids, - mrope_section: List[int], - packed_seq: bool = False) -> torch.Tensor: + def forward(self, position_ids, mrope_section: List[int], packed_seq: bool = False) -> torch.Tensor: seq = position_ids.to(device=self.inv_freq.device, dtype=self.inv_freq.dtype) if self.seq_len_interpolation_factor is not None: From 6a431aa7a1665c2b2cd53b6cae177e9b0bb2d606 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 18:14:56 +0800 Subject: [PATCH 25/31] update --- examples/megatron/multimodal/dense.sh | 11 +++++--- examples/megatron/multimodal/lora.sh | 38 +++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) create mode 100644 examples/megatron/multimodal/lora.sh diff --git a/examples/megatron/multimodal/dense.sh 
b/examples/megatron/multimodal/dense.sh index bc6f248760..b40e2a1c1f 100644 --- a/examples/megatron/multimodal/dense.sh +++ b/examples/megatron/multimodal/dense.sh @@ -1,14 +1,17 @@ -# 4 * 56GiB; 2.3s/it +# 2 * 72GiB; 4.1s/it PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ -NPROC_PER_NODE=4 \ +NPROC_PER_NODE=2 \ MAX_PIXELS=1003520 \ -CUDA_VISIBLE_DEVICES=0,1,2,3 \ +CUDA_VISIBLE_DEVICES=0,1 \ megatron sft \ --load Qwen2.5-VL-7B-Instruct-mcore \ - --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite' \ + --dataset 'AI-ModelScope/LaTeX_OCR' \ --tensor_model_parallel_size 2 \ --sequence_parallel true \ --packing true \ + --freeze_llm false \ + --freeze_vit true \ + --freeze_aligner true \ --split_dataset_ratio 0.01 \ --micro_batch_size 1 \ --global_batch_size 4 \ diff --git a/examples/megatron/multimodal/lora.sh b/examples/megatron/multimodal/lora.sh new file mode 100644 index 0000000000..23a82ab01a --- /dev/null +++ b/examples/megatron/multimodal/lora.sh @@ -0,0 +1,38 @@ +# 2 * 15GiB; 3.8s/it +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=2 \ +MAX_PIXELS=1003520 \ +CUDA_VISIBLE_DEVICES=0,1 \ +megatron sft \ + --load Qwen2.5-VL-7B-Instruct-mcore \ + --dataset 'AI-ModelScope/LaTeX_OCR' \ + --train_type lora \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --tensor_model_parallel_size 2 \ + --sequence_parallel true \ + --freeze_llm false \ + --freeze_vit true \ + --freeze_aligner true \ + --packing true \ + --split_dataset_ratio 0.01 \ + --micro_batch_size 1 \ + --global_batch_size 4 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-4 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-5 \ + --max_epochs 1 \ + --save megatron_output/Qwen2.5-VL-7B-Instruct \ + --save_interval 200 \ + --vit_gradient_checkpointing true \ + --max_length 2048 \ + --num_workers 4 \ + --no_save_optim true \ + --no_save_rng true \ + --dataset_num_proc 8 From 3f35a8672bd7781c7ad19afb8fa24e323a2e30cb Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 19:47:36 +0800 Subject: [PATCH 26/31] update --- ...14\346\225\260\346\215\256\351\233\206.md" | 24 +-- ...41\346\200\201\346\250\241\345\236\213.md" | 147 ++++++++++++++++++ ...53\351\200\237\345\274\200\345\247\213.md" | 2 +- .../Supported-models-and-datasets.md | 24 +-- docs/source_en/Megatron-SWIFT/Quick-start.md | 2 +- examples/megatron/multimodal/dense.sh | 2 +- examples/megatron/multimodal/lora.sh | 2 +- swift/megatron/model/constant.py | 1 + swift/megatron/model/gpt/config.py | 2 +- swift/megatron/model/mm_gpt/qwen2_5_vl.py | 32 +++- swift/megatron/model/mm_gpt/utils.py | 10 ++ tests/megatron/test_align/test_llm.py | 6 + 12 files changed, 220 insertions(+), 34 deletions(-) create mode 100644 "docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" index e7ec13b1d5..f9a8ba548b 100644 --- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" +++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" @@ -652,12 
+652,12 @@ |[Qwen/Qwen-VL-Chat-Int4](https://modelscope.cn/models/Qwen/Qwen-VL-Chat-Int4)|qwen_vl|qwen_vl|-|✘|vision|[Qwen/Qwen-VL-Chat-Int4](https://huggingface.co/Qwen/Qwen-VL-Chat-Int4)| |[Qwen/Qwen-Audio-Chat](https://modelscope.cn/models/Qwen/Qwen-Audio-Chat)|qwen_audio|qwen_audio|-|✘|audio|[Qwen/Qwen-Audio-Chat](https://huggingface.co/Qwen/Qwen-Audio-Chat)| |[Qwen/Qwen-Audio](https://modelscope.cn/models/Qwen/Qwen-Audio)|qwen_audio|qwen_audio|-|✘|audio|[Qwen/Qwen-Audio](https://huggingface.co/Qwen/Qwen-Audio)| -|[Qwen/Qwen2-VL-2B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)| -|[Qwen/Qwen2-VL-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-VL-7B-Instruct)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)| -|[Qwen/Qwen2-VL-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-VL-72B-Instruct)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-72B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct)| -|[Qwen/Qwen2-VL-2B](https://modelscope.cn/models/Qwen/Qwen2-VL-2B)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-2B](https://huggingface.co/Qwen/Qwen2-VL-2B)| -|[Qwen/Qwen2-VL-7B](https://modelscope.cn/models/Qwen/Qwen2-VL-7B)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-7B](https://huggingface.co/Qwen/Qwen2-VL-7B)| -|[Qwen/Qwen2-VL-72B](https://modelscope.cn/models/Qwen/Qwen2-VL-72B)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-72B](https://huggingface.co/Qwen/Qwen2-VL-72B)| +|[Qwen/Qwen2-VL-2B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)| +|[Qwen/Qwen2-VL-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-VL-7B-Instruct)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)| +|[Qwen/Qwen2-VL-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-VL-72B-Instruct)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2-VL-72B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct)| +|[Qwen/Qwen2-VL-2B](https://modelscope.cn/models/Qwen/Qwen2-VL-2B)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2-VL-2B](https://huggingface.co/Qwen/Qwen2-VL-2B)| +|[Qwen/Qwen2-VL-7B](https://modelscope.cn/models/Qwen/Qwen2-VL-7B)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2-VL-7B](https://huggingface.co/Qwen/Qwen2-VL-7B)| +|[Qwen/Qwen2-VL-72B](https://modelscope.cn/models/Qwen/Qwen2-VL-72B)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2-VL-72B](https://huggingface.co/Qwen/Qwen2-VL-72B)| |[Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4](https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, 
video|[Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4)| |[Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4](https://modelscope.cn/models/Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4)| |[Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4](https://modelscope.cn/models/Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4)| @@ -667,12 +667,12 @@ |[Qwen/Qwen2-VL-2B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct-AWQ)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-2B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-AWQ)| |[Qwen/Qwen2-VL-7B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2-VL-7B-Instruct-AWQ)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-7B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct-AWQ)| |[Qwen/Qwen2-VL-72B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2-VL-72B-Instruct-AWQ)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-72B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-AWQ)| -|[bytedance-research/UI-TARS-2B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-2B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-2B-SFT](https://huggingface.co/bytedance-research/UI-TARS-2B-SFT)| -|[bytedance-research/UI-TARS-7B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-7B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-7B-SFT](https://huggingface.co/bytedance-research/UI-TARS-7B-SFT)| -|[bytedance-research/UI-TARS-7B-DPO](https://modelscope.cn/models/bytedance-research/UI-TARS-7B-DPO)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-7B-DPO](https://huggingface.co/bytedance-research/UI-TARS-7B-DPO)| -|[bytedance-research/UI-TARS-72B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-72B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-72B-SFT](https://huggingface.co/bytedance-research/UI-TARS-72B-SFT)| -|[bytedance-research/UI-TARS-72B-DPO](https://modelscope.cn/models/bytedance-research/UI-TARS-72B-DPO)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-72B-DPO](https://huggingface.co/bytedance-research/UI-TARS-72B-DPO)| -|[allenai/olmOCR-7B-0225-preview](https://modelscope.cn/models/allenai/olmOCR-7B-0225-preview)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[allenai/olmOCR-7B-0225-preview](https://huggingface.co/allenai/olmOCR-7B-0225-preview)| +|[bytedance-research/UI-TARS-2B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-2B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[bytedance-research/UI-TARS-2B-SFT](https://huggingface.co/bytedance-research/UI-TARS-2B-SFT)| 
+|[bytedance-research/UI-TARS-7B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-7B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[bytedance-research/UI-TARS-7B-SFT](https://huggingface.co/bytedance-research/UI-TARS-7B-SFT)| +|[bytedance-research/UI-TARS-7B-DPO](https://modelscope.cn/models/bytedance-research/UI-TARS-7B-DPO)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[bytedance-research/UI-TARS-7B-DPO](https://huggingface.co/bytedance-research/UI-TARS-7B-DPO)| +|[bytedance-research/UI-TARS-72B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-72B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[bytedance-research/UI-TARS-72B-SFT](https://huggingface.co/bytedance-research/UI-TARS-72B-SFT)| +|[bytedance-research/UI-TARS-72B-DPO](https://modelscope.cn/models/bytedance-research/UI-TARS-72B-DPO)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[bytedance-research/UI-TARS-72B-DPO](https://huggingface.co/bytedance-research/UI-TARS-72B-DPO)| +|[allenai/olmOCR-7B-0225-preview](https://modelscope.cn/models/allenai/olmOCR-7B-0225-preview)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[allenai/olmOCR-7B-0225-preview](https://huggingface.co/allenai/olmOCR-7B-0225-preview)| |[Qwen/Qwen2.5-VL-3B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-3B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)| |[Qwen/Qwen2.5-VL-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)| |[Qwen/Qwen2.5-VL-32B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-32B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2.5-VL-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct)| diff --git "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" new file mode 100644 index 0000000000..6089befc34 --- /dev/null +++ "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" @@ -0,0 +1,147 @@ +# 多模态模型 + +ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen2.5-VL等模型的预训练和微调。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/支持的模型和数据集.md)。 + +环境准备请参考Megatron-SWIFT的[快速开始文档](./快速开始.md)。 + +## Dense模型 Full/LoRA + +这里介绍使用2卡80GiB A100对Qwen2.5-VL-7B-Instruct模型进行Latex-OCR的微调,分别使用全参数和LoRA的方式,以下最佳实践可以在30分钟内完成。 + +首先,我们需要将HF格式的权重转为Megatron格式: +```shell +CUDA_VISIBLE_DEVICES=0 \ +swift export \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --to_mcore true \ + --torch_dtype bfloat16 \ + --output_dir Qwen2.5-VL-7B-Instruct-mcore \ + --test_convert_precision true +``` + +### Full + +全参数训练脚本如下: +```shell +# 2 * 72GiB; 4.1s/it +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=2 \ +MAX_PIXELS=1003520 \ +CUDA_VISIBLE_DEVICES=0,1 \ +megatron sft \ + --load Qwen2.5-VL-7B-Instruct-mcore \ + --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#20000' \ + --tensor_model_parallel_size 2 \ + --sequence_parallel true \ + --packing true \ + --freeze_llm false \ + --freeze_vit true \ + --freeze_aligner 
true \
+    --split_dataset_ratio 0.01 \
+    --micro_batch_size 1 \
+    --global_batch_size 4 \
+    --recompute_granularity full \
+    --recompute_method uniform \
+    --recompute_num_layers 1 \
+    --finetune true \
+    --cross_entropy_loss_fusion true \
+    --lr 1e-5 \
+    --lr_warmup_fraction 0.05 \
+    --min_lr 1e-6 \
+    --max_epochs 1 \
+    --save megatron_output/Qwen2.5-VL-7B-Instruct \
+    --save_interval 200 \
+    --vit_gradient_checkpointing true \
+    --max_length 2048 \
+    --num_workers 4 \
+    --no_save_optim true \
+    --no_save_rng true \
+    --dataset_num_proc 8
+```
+
+将全参数保存的Megatron格式权重转为HF格式:
+- 注意:`--mcore_model`请指向`iter_xxx`的上级目录。默认会使用`latest_checkpointed_iteration.txt`中对应的checkpoint。
+```shell
+CUDA_VISIBLE_DEVICES=0 \
+swift export \
+    --mcore_model megatron_output/Qwen2.5-VL-7B-Instruct/vx-xxx \
+    --to_hf true \
+    --torch_dtype bfloat16 \
+    --output_dir megatron_output/Qwen2.5-VL-7B-Instruct/vx-xxx-hf \
+    --test_convert_precision true
+```
+
+### LoRA
+
+LoRA训练脚本如下:
+```shell
+# 2 * 15GiB; 3.8s/it
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+NPROC_PER_NODE=2 \
+MAX_PIXELS=1003520 \
+CUDA_VISIBLE_DEVICES=0,1 \
+megatron sft \
+    --load Qwen2.5-VL-7B-Instruct-mcore \
+    --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#20000' \
+    --train_type lora \
+    --lora_rank 8 \
+    --lora_alpha 32 \
+    --target_modules all-linear \
+    --tensor_model_parallel_size 2 \
+    --sequence_parallel true \
+    --freeze_llm false \
+    --freeze_vit true \
+    --freeze_aligner true \
+    --packing true \
+    --split_dataset_ratio 0.01 \
+    --micro_batch_size 1 \
+    --global_batch_size 4 \
+    --recompute_granularity full \
+    --recompute_method uniform \
+    --recompute_num_layers 1 \
+    --finetune true \
+    --cross_entropy_loss_fusion true \
+    --lr 1e-4 \
+    --lr_warmup_fraction 0.05 \
+    --min_lr 1e-5 \
+    --max_epochs 1 \
+    --save megatron_output/Qwen2.5-VL-7B-Instruct \
+    --save_interval 200 \
+    --vit_gradient_checkpointing true \
+    --max_length 2048 \
+    --num_workers 4 \
+    --no_save_optim true \
+    --no_save_rng true \
+    --dataset_num_proc 8
+```
+
+将LoRA保存的增量权重进行Merge-LoRA并转为HF格式:
+```shell
+CUDA_VISIBLE_DEVICES=0 \
+swift export \
+    --mcore_adapters megatron_output/Qwen2.5-VL-7B-Instruct/vx-xxx \
+    --to_hf true \
+    --torch_dtype bfloat16 \
+    --output_dir megatron_output/Qwen2.5-VL-7B-Instruct/vx-xxx-hf \
+    --test_convert_precision true
+```
+
+
+最后,我们对生成的HF格式权重进行推理:
+```shell
+MAX_PIXELS=1003520 \
+CUDA_VISIBLE_DEVICES=0 \
+swift infer \
+    --model megatron_output/Qwen2.5-VL-7B-Instruct/vx-xxx-hf \
+    --attn_impl flash_attn \
+    --stream true \
+    --load_data_args true \
+    --temperature 0 \
+    --max_new_tokens 2048
+```
+
+推理结果如下:
+```
+<<< who are you?
+I am a language model developed by swift, you can call me swift-robot. How can I assist you?
+``` diff --git "a/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" "b/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" index d105487afe..782d532cab 100644 --- "a/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\277\253\351\200\237\345\274\200\345\247\213.md" @@ -1,7 +1,7 @@ # 快速开始 -ms-swift引入了Megatron的并行技术来加速大模型的训练,包括数据并行、张量并行、流水线并行、序列并行,上下文并行,专家并行。支持Qwen3、[Qwen3-MoE](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/qwen3_moe.sh)、Qwen2.5、Llama3、Deepseek-R1、GLM4.5等模型的预训练和微调。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/支持的模型和数据集.md)。推荐在MoE训练时使用Megatron-SWIFT,这通常可以获得10倍的训练速度提升。 +ms-swift引入了Megatron的并行技术来加速大模型的训练,包括数据并行、张量并行、流水线并行、序列并行,上下文并行,专家并行。支持Qwen3、[Qwen3-MoE](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/qwen3_moe.sh)、Qwen2.5、Llama3、Deepseek-R1、GLM4.5等模型的CPT/SFT/DPO。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/支持的模型和数据集.md)。推荐在MoE训练时使用Megatron-SWIFT,这通常可以获得10倍的训练速度提升。 ## 环境准备 使用Megatron-SWIFT,除了安装swift依赖外,还需要安装以下内容: diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index 78dc0bcb8a..1830a604a9 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -652,12 +652,12 @@ The table below introduces the models integrated with ms-swift: |[Qwen/Qwen-VL-Chat-Int4](https://modelscope.cn/models/Qwen/Qwen-VL-Chat-Int4)|qwen_vl|qwen_vl|-|✘|vision|[Qwen/Qwen-VL-Chat-Int4](https://huggingface.co/Qwen/Qwen-VL-Chat-Int4)| |[Qwen/Qwen-Audio-Chat](https://modelscope.cn/models/Qwen/Qwen-Audio-Chat)|qwen_audio|qwen_audio|-|✘|audio|[Qwen/Qwen-Audio-Chat](https://huggingface.co/Qwen/Qwen-Audio-Chat)| |[Qwen/Qwen-Audio](https://modelscope.cn/models/Qwen/Qwen-Audio)|qwen_audio|qwen_audio|-|✘|audio|[Qwen/Qwen-Audio](https://huggingface.co/Qwen/Qwen-Audio)| -|[Qwen/Qwen2-VL-2B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)| -|[Qwen/Qwen2-VL-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-VL-7B-Instruct)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)| -|[Qwen/Qwen2-VL-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-VL-72B-Instruct)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-72B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct)| -|[Qwen/Qwen2-VL-2B](https://modelscope.cn/models/Qwen/Qwen2-VL-2B)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-2B](https://huggingface.co/Qwen/Qwen2-VL-2B)| -|[Qwen/Qwen2-VL-7B](https://modelscope.cn/models/Qwen/Qwen2-VL-7B)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-7B](https://huggingface.co/Qwen/Qwen2-VL-7B)| -|[Qwen/Qwen2-VL-72B](https://modelscope.cn/models/Qwen/Qwen2-VL-72B)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-72B](https://huggingface.co/Qwen/Qwen2-VL-72B)| +|[Qwen/Qwen2-VL-2B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, 
video|[Qwen/Qwen2-VL-2B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct)| +|[Qwen/Qwen2-VL-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-VL-7B-Instruct)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)| +|[Qwen/Qwen2-VL-72B-Instruct](https://modelscope.cn/models/Qwen/Qwen2-VL-72B-Instruct)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2-VL-72B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct)| +|[Qwen/Qwen2-VL-2B](https://modelscope.cn/models/Qwen/Qwen2-VL-2B)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2-VL-2B](https://huggingface.co/Qwen/Qwen2-VL-2B)| +|[Qwen/Qwen2-VL-7B](https://modelscope.cn/models/Qwen/Qwen2-VL-7B)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2-VL-7B](https://huggingface.co/Qwen/Qwen2-VL-7B)| +|[Qwen/Qwen2-VL-72B](https://modelscope.cn/models/Qwen/Qwen2-VL-72B)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2-VL-72B](https://huggingface.co/Qwen/Qwen2-VL-72B)| |[Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4](https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4)| |[Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4](https://modelscope.cn/models/Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4)| |[Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4](https://modelscope.cn/models/Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4)| @@ -667,12 +667,12 @@ The table below introduces the models integrated with ms-swift: |[Qwen/Qwen2-VL-2B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct-AWQ)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-2B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct-AWQ)| |[Qwen/Qwen2-VL-7B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2-VL-7B-Instruct-AWQ)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-7B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct-AWQ)| |[Qwen/Qwen2-VL-72B-Instruct-AWQ](https://modelscope.cn/models/Qwen/Qwen2-VL-72B-Instruct-AWQ)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[Qwen/Qwen2-VL-72B-Instruct-AWQ](https://huggingface.co/Qwen/Qwen2-VL-72B-Instruct-AWQ)| -|[bytedance-research/UI-TARS-2B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-2B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-2B-SFT](https://huggingface.co/bytedance-research/UI-TARS-2B-SFT)| -|[bytedance-research/UI-TARS-7B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-7B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-7B-SFT](https://huggingface.co/bytedance-research/UI-TARS-7B-SFT)| 
-|[bytedance-research/UI-TARS-7B-DPO](https://modelscope.cn/models/bytedance-research/UI-TARS-7B-DPO)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-7B-DPO](https://huggingface.co/bytedance-research/UI-TARS-7B-DPO)| -|[bytedance-research/UI-TARS-72B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-72B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-72B-SFT](https://huggingface.co/bytedance-research/UI-TARS-72B-SFT)| -|[bytedance-research/UI-TARS-72B-DPO](https://modelscope.cn/models/bytedance-research/UI-TARS-72B-DPO)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[bytedance-research/UI-TARS-72B-DPO](https://huggingface.co/bytedance-research/UI-TARS-72B-DPO)| -|[allenai/olmOCR-7B-0225-preview](https://modelscope.cn/models/allenai/olmOCR-7B-0225-preview)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✘|vision, video|[allenai/olmOCR-7B-0225-preview](https://huggingface.co/allenai/olmOCR-7B-0225-preview)| +|[bytedance-research/UI-TARS-2B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-2B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[bytedance-research/UI-TARS-2B-SFT](https://huggingface.co/bytedance-research/UI-TARS-2B-SFT)| +|[bytedance-research/UI-TARS-7B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-7B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[bytedance-research/UI-TARS-7B-SFT](https://huggingface.co/bytedance-research/UI-TARS-7B-SFT)| +|[bytedance-research/UI-TARS-7B-DPO](https://modelscope.cn/models/bytedance-research/UI-TARS-7B-DPO)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[bytedance-research/UI-TARS-7B-DPO](https://huggingface.co/bytedance-research/UI-TARS-7B-DPO)| +|[bytedance-research/UI-TARS-72B-SFT](https://modelscope.cn/models/bytedance-research/UI-TARS-72B-SFT)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[bytedance-research/UI-TARS-72B-SFT](https://huggingface.co/bytedance-research/UI-TARS-72B-SFT)| +|[bytedance-research/UI-TARS-72B-DPO](https://modelscope.cn/models/bytedance-research/UI-TARS-72B-DPO)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[bytedance-research/UI-TARS-72B-DPO](https://huggingface.co/bytedance-research/UI-TARS-72B-DPO)| +|[allenai/olmOCR-7B-0225-preview](https://modelscope.cn/models/allenai/olmOCR-7B-0225-preview)|qwen2_vl|qwen2_vl|transformers>=4.45, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[allenai/olmOCR-7B-0225-preview](https://huggingface.co/allenai/olmOCR-7B-0225-preview)| |[Qwen/Qwen2.5-VL-3B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-3B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2.5-VL-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct)| |[Qwen/Qwen2.5-VL-7B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, video|[Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)| |[Qwen/Qwen2.5-VL-32B-Instruct](https://modelscope.cn/models/Qwen/Qwen2.5-VL-32B-Instruct)|qwen2_5_vl|qwen2_5_vl|transformers>=4.49, qwen_vl_utils>=0.0.6, decord|✔|vision, 
video|[Qwen/Qwen2.5-VL-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-32B-Instruct)| diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index 0e5ceca9d4..389ae8bea1 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -1,6 +1,6 @@ # Quick Start -ms-swift incorporates Megatron's parallelization techniques to accelerate the training of large models, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, context parallelism, and expert parallelism. It supports the pre-training and fine-tuning of models such as Qwen3, [Qwen3-MoE](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/qwen3_moe.sh), Qwen2.5, Llama3, Deepseek-R1 and GLM4.5 series. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). We recommend using Megatron-SWIFT for MoE training; it can typically achieve a 10x speedup in training. +ms-swift incorporates Megatron's parallelization techniques to accelerate the training of large models, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, context parallelism, and expert parallelism. It supports CPT/SFT/DPO for models such as Qwen3, [Qwen3-MoE](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/qwen3_moe.sh), Qwen2.5, Llama3, Deepseek-R1 and GLM4.5 series. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). We recommend using Megatron-SWIFT for MoE training; it can typically achieve a 10x speedup in training. ## Environment Setup diff --git a/examples/megatron/multimodal/dense.sh b/examples/megatron/multimodal/dense.sh index b40e2a1c1f..3590fad38d 100644 --- a/examples/megatron/multimodal/dense.sh +++ b/examples/megatron/multimodal/dense.sh @@ -5,7 +5,7 @@ MAX_PIXELS=1003520 \ CUDA_VISIBLE_DEVICES=0,1 \ megatron sft \ --load Qwen2.5-VL-7B-Instruct-mcore \ - --dataset 'AI-ModelScope/LaTeX_OCR' \ + --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#5000' \ --tensor_model_parallel_size 2 \ --sequence_parallel true \ --packing true \ diff --git a/examples/megatron/multimodal/lora.sh b/examples/megatron/multimodal/lora.sh index 23a82ab01a..69daa5fb9c 100644 --- a/examples/megatron/multimodal/lora.sh +++ b/examples/megatron/multimodal/lora.sh @@ -5,7 +5,7 @@ MAX_PIXELS=1003520 \ CUDA_VISIBLE_DEVICES=0,1 \ megatron sft \ --load Qwen2.5-VL-7B-Instruct-mcore \ - --dataset 'AI-ModelScope/LaTeX_OCR' \ + --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#5000' \ --train_type lora \ --lora_rank 8 \ --lora_alpha 32 \ diff --git a/swift/megatron/model/constant.py b/swift/megatron/model/constant.py index 82e9b38913..56e2ea6707 100644 --- a/swift/megatron/model/constant.py +++ b/swift/megatron/model/constant.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
class MegatronModelType: gpt = 'gpt' + qwen2_vl = 'qwen2_vl' qwen2_5_vl = 'qwen2_5_vl' diff --git a/swift/megatron/model/gpt/config.py b/swift/megatron/model/gpt/config.py index 112b0cde5d..7b6a1803a1 100644 --- a/swift/megatron/model/gpt/config.py +++ b/swift/megatron/model/gpt/config.py @@ -39,7 +39,7 @@ def convert_gpt_hf_config(config) -> Dict[str, Any]: res['rotary_interleaved'] = True elif architectures == 'Glm4MoeForCausalLM': res['moe_router_score_function'] = 'sigmoid' - elif architectures == 'Qwen2_5_VLForConditionalGeneration': + elif architectures in {'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration'}: res['position_embedding_type'] = 'mrope' res['mrope_section'] = res['rope_scaling']['mrope_section'] if first_k_dense_replace is not None: diff --git a/swift/megatron/model/mm_gpt/qwen2_5_vl.py b/swift/megatron/model/mm_gpt/qwen2_5_vl.py index e1ea21ad22..ba23e2ee8e 100644 --- a/swift/megatron/model/mm_gpt/qwen2_5_vl.py +++ b/swift/megatron/model/mm_gpt/qwen2_5_vl.py @@ -45,17 +45,26 @@ def convert_mcore2hf_qwen2_5_vl(hf_model, mg_model): class Qwen2_5VL_Vit(HuggingFaceModule): vision_tower = ['model'] aligner = ['model.merger'] + version = 'v2_5' def __init__(self, config): - try: - from transformers.models.qwen2_5_vl import Qwen2_5_VLTextModel - except ImportError: - from transformers.models.qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel + if self.version == 'v2_5': + try: + from transformers.models.qwen2_5_vl import Qwen2_5_VLTextModel + except ImportError: + from transformers.models.qwen2_5_vl import Qwen2_5_VLModel as Qwen2_5_VLTextModel + context = patch_device_map_meta(Qwen2_5_VLTextModel) + elif self.version == 'v2': + try: + from transformers.models.qwen2_vl import Qwen2VLTextModel + except ImportError: + from transformers.models.qwen2_vl import Qwen2VLModel as Qwen2VLTextModel + context = patch_device_map_meta(Qwen2VLTextModel) super().__init__(config) args = get_args() model_dir = args.model_info.model_dir kwargs = {'attn_impl': 'flash_attn'} if args.attention_backend.name == 'flash' else {} - with patch_device_map_meta(Qwen2_5_VLTextModel): + with context: model, _ = get_model_tokenizer(model_dir, args.torch_dtype, return_dummy_model=True, **kwargs) self.model = model.visual.to('cuda') self.model_config = model.config @@ -116,6 +125,10 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds +class Qwen2VL_Vit(Qwen2_5VL_Vit): + version = 'v2' + + register_megatron_model( MMGPTMegatronModelMeta( MegatronModelType.qwen2_5_vl, [ @@ -124,3 +137,12 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): convert_hf2mcore=convert_hf2mcore_qwen2_5_vl, convert_mcore2hf=convert_mcore2hf_qwen2_5_vl, visual_cls=Qwen2_5VL_Vit)) + +register_megatron_model( + MMGPTMegatronModelMeta( + MegatronModelType.qwen2_vl, [ + ModelType.qwen2_vl, + ], + convert_hf2mcore=convert_hf2mcore_qwen2_5_vl, + convert_mcore2hf=convert_mcore2hf_qwen2_5_vl, + visual_cls=Qwen2VL_Vit)) diff --git a/swift/megatron/model/mm_gpt/utils.py b/swift/megatron/model/mm_gpt/utils.py index ec949f19c5..2b85c9f0f7 100644 --- a/swift/megatron/model/mm_gpt/utils.py +++ b/swift/megatron/model/mm_gpt/utils.py @@ -21,10 +21,20 @@ def __init__(self, *args, **kwargs): __origin_init__(self, *args, **kwargs) model_cls.__init__ = __init__ + + from transformers import PreTrainedModel + _origin_initialize_weight = PreTrainedModel._initialize_weights + + def _initialize_weight(self, *args, **kwargs): + return + + PreTrainedModel._initialize_weights = _initialize_weight + try: 
yield finally: model_cls.__init__ = __origin_init__ + PreTrainedModel._initialize_weights = _origin_initialize_weight @dataclass diff --git a/tests/megatron/test_align/test_llm.py b/tests/megatron/test_align/test_llm.py index b0006cd023..69e62f574f 100644 --- a/tests/megatron/test_align/test_llm.py +++ b/tests/megatron/test_align/test_llm.py @@ -132,6 +132,11 @@ def test_qwen2_5_vl(): _test_model('Qwen/Qwen2.5-VL-7B-Instruct') +def test_qwen2_vl(): + os.environ['MAX_PIXELS'] = str(1280 * 28 * 28) + _test_model('Qwen/Qwen2-VL-7B-Instruct') + + if __name__ == '__main__': # test_qwen2() # test_llama2() @@ -158,3 +163,4 @@ def test_qwen2_5_vl(): # test_ernie() # test_glm4_5() test_qwen2_5_vl() + # test_qwen2_vl() From 0156a0cb5dd6867a63d5ca6d6c9b9fd58748c3bb Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 20:02:36 +0800 Subject: [PATCH 27/31] update --- ...41\346\200\201\346\250\241\345\236\213.md" | 21 ++- docs/source/index.rst | 1 + .../Megatron-SWIFT/Multimodal-Model.md | 158 ++++++++++++++++++ docs/source_en/index.rst | 1 + 4 files changed, 175 insertions(+), 6 deletions(-) create mode 100644 docs/source_en/Megatron-SWIFT/Multimodal-Model.md diff --git "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" index 6089befc34..3294c96f01 100644 --- "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" @@ -6,7 +6,7 @@ ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。 ## Dense模型 Full/LoRA -这里介绍使用2卡80GiB A100对Qwen2.5-VL-7B-Instruct模型进行Latex-OCR的微调,分别使用全参数和LoRA的方式,以下最佳实践可以在30分钟内完成。 +这里介绍使用2卡80GiB A100对Qwen2.5-VL-7B-Instruct模型进行Latex-OCR的微调,分别使用全参数和LoRA的方式,以下最佳实践可以在10分钟内完成。 首先,我们需要将HF格式的权重转为Megatron格式: ```shell @@ -30,7 +30,7 @@ MAX_PIXELS=1003520 \ CUDA_VISIBLE_DEVICES=0,1 \ megatron sft \ --load Qwen2.5-VL-7B-Instruct-mcore \ - --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#20000' \ + --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#5000' \ --tensor_model_parallel_size 2 \ --sequence_parallel true \ --packing true \ @@ -82,7 +82,7 @@ MAX_PIXELS=1003520 \ CUDA_VISIBLE_DEVICES=0,1 \ megatron sft \ --load Qwen2.5-VL-7B-Instruct-mcore \ - --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#20000' \ + --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#5000' \ --train_type lora \ --lora_rank 8 \ --lora_alpha 32 \ @@ -137,11 +137,20 @@ swift infer \ --stream true \ --load_data_args true \ --temperature 0 \ - --max_new_tokens 2048 + --max_new_tokens 512 ``` 推理结果如下: ``` -<<< who are you? -I am a language model developed by swift, you can call me swift-robot. How can I assist you? +[QUERY] Using LaTeX to perform OCR on the image. +[LABELS] \forall x \in X , ( \alpha f ) ( x ) = \alpha f ( x ) +[RESPONSE] \forall x \in X , ( \alpha f ) ( x ) = \alpha f ( x ) +-------------------------------------------------- +[QUERY] Using LaTeX to perform OCR on the image. +[LABELS] \pi \int _ { c } ^ { d } \{ g ( y ) \} ^ { 2 } d y +[RESPONSE] \pi \int _ { c } ^ { d } \{ g ( y ) \} ^ { 2 } d y +-------------------------------------------------- +[QUERY] Using LaTeX to perform OCR on the image. 
+[LABELS] [ \frac 2 3 x ^ { \frac 3 2 } ] _ { 0 } ^ { 1 } +[RESPONSE] [ \frac 2 3 x ^ { \frac 3 2 } ] _ { 0 } ^ { 1 } ``` diff --git a/docs/source/index.rst b/docs/source/index.rst index 7e9dbba285..af6129e628 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -38,6 +38,7 @@ Swift DOCUMENTATION Megatron-SWIFT/快速开始.md Megatron-SWIFT/命令行参数.md Megatron-SWIFT/LoRA训练.md + Megatron-SWIFT/多模态模型.md .. toctree:: :maxdepth: 2 diff --git a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md new file mode 100644 index 0000000000..13a20b0b60 --- /dev/null +++ b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md @@ -0,0 +1,158 @@ +# Multimodal Models + +ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports pretraining and fine-tuning for models such as Qwen2.5-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). + +For environment setup, please refer to the Megatron-SWIFT [Quick Start guide](./Quick-start.md). + +## Dense Model Full/LoRA Fine-tuning + +This section demonstrates fine-tuning the Qwen2.5-VL-7B-Instruct model on the LaTeX-OCR task using two 80GiB A100 GPUs, with both full-parameter fine-tuning and LoRA. The best practices described below can be completed within 10 minutes. + +First, we need to convert the model weights from Hugging Face format to Megatron format: +```shell +CUDA_VISIBLE_DEVICES=0 \ +swift export \ + --model Qwen/Qwen2.5-VL-7B-Instruct \ + --to_mcore true \ + --torch_dtype bfloat16 \ + --output_dir Qwen2.5-VL-7B-Instruct-mcore \ + --test_convert_precision true +``` + +### Full + +The full-parameter training script is as follows: +```shell +# 2 * 72GiB; 4.1s/it +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=2 \ +MAX_PIXELS=1003520 \ +CUDA_VISIBLE_DEVICES=0,1 \ +megatron sft \ + --load Qwen2.5-VL-7B-Instruct-mcore \ + --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#5000' \ + --tensor_model_parallel_size 2 \ + --sequence_parallel true \ + --packing true \ + --freeze_llm false \ + --freeze_vit true \ + --freeze_aligner true \ + --split_dataset_ratio 0.01 \ + --micro_batch_size 1 \ + --global_batch_size 4 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-5 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-6 \ + --max_epochs 1 \ + --save megatron_output/Qwen2.5-VL-7B-Instruct \ + --save_interval 200 \ + --vit_gradient_checkpointing true \ + --max_length 2048 \ + --num_workers 4 \ + --no_save_optim true \ + --no_save_rng true \ + --dataset_num_proc 8 +``` + +Convert Megatron-format weights saved with full parameters to Hugging Face format: + +- Note: `--mcore_model` should point to the parent directory of `iter_xxx`. By default, the checkpoint specified in `latest_checkpointed_iteration.txt` will be used. 
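+- As a minimal sketch (assuming the `vx-xxx` run-directory naming used in the export command below, with an illustrative iteration number), the saved checkpoint directory can be inspected like this:
+
+```shell
+# Hypothetical contents; the actual directory name and iteration number depend on your run.
+ls megatron_output/Qwen2.5-VL-7B-Instruct/vx-xxx
+# iter_0000200  latest_checkpointed_iteration.txt
+```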
+ +```shell +CUDA_VISIBLE_DEVICES=0 \ +swift export \ + --mcore_model megatron_output/Qwen2.5-VL-7B-Instruct/vx-xxx \ + --to_hf true \ + --torch_dtype bfloat16 \ + --output_dir megatron_output/Qwen2.5-VL-7B-Instruct/vx-xxx-hf \ + --test_convert_precision true +``` + +### LoRA + +The LoRA training script is as follows: +```shell +# 2 * 15GiB; 3.8s/it +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=2 \ +MAX_PIXELS=1003520 \ +CUDA_VISIBLE_DEVICES=0,1 \ +megatron sft \ + --load Qwen2.5-VL-7B-Instruct-mcore \ + --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#5000' \ + --train_type lora \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --tensor_model_parallel_size 2 \ + --sequence_parallel true \ + --freeze_llm false \ + --freeze_vit true \ + --freeze_aligner true \ + --packing true \ + --split_dataset_ratio 0.01 \ + --micro_batch_size 1 \ + --global_batch_size 4 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-4 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-5 \ + --max_epochs 1 \ + --save megatron_output/Qwen2.5-VL-7B-Instruct \ + --save_interval 200 \ + --vit_gradient_checkpointing true \ + --max_length 2048 \ + --num_workers 4 \ + --no_save_optim true \ + --no_save_rng true \ + --dataset_num_proc 8 +``` + +Merge the LoRA-saved incremental weights and convert them to Hugging Face format: +```shell +CUDA_VISIBLE_DEVICES=0 \ +swift export \ + --mcore_adapters megatron_output/Qwen2.5-VL-7B-Instruct/vx-xxx \ + --to_hf true \ + --torch_dtype bfloat16 \ + --output_dir megatron_output/Qwen2.5-VL-7B-Instruct/vx-xxx-hf \ + --test_convert_precision true +``` + + +Finally, we perform inference using the generated Hugging Face format weights: +```shell +MAX_PIXELS=1003520 \ +CUDA_VISIBLE_DEVICES=0 \ +swift infer \ + --model megatron_output/Qwen2.5-7B-Instruct/vx-xxx-hf \ + --attn_impl flash_attn \ + --stream true \ + --load_data_args true \ + --temperature 0 \ + --max_new_tokens 512 +``` + +The inference results are as follows: +``` +[QUERY] Using LaTeX to perform OCR on the image. +[LABELS] \forall x \in X , ( \alpha f ) ( x ) = \alpha f ( x ) +[RESPONSE] \forall x \in X , ( \alpha f ) ( x ) = \alpha f ( x ) +-------------------------------------------------- +[QUERY] Using LaTeX to perform OCR on the image. +[LABELS] \pi \int _ { c } ^ { d } \{ g ( y ) \} ^ { 2 } d y +[RESPONSE] \pi \int _ { c } ^ { d } \{ g ( y ) \} ^ { 2 } d y +-------------------------------------------------- +[QUERY] Using LaTeX to perform OCR on the image. +[LABELS] [ \frac 2 3 x ^ { \frac 3 2 } ] _ { 0 } ^ { 1 } +[RESPONSE] [ \frac 2 3 x ^ { \frac 3 2 } ] _ { 0 } ^ { 1 } +``` diff --git a/docs/source_en/index.rst b/docs/source_en/index.rst index a7a3ac0811..c561735643 100644 --- a/docs/source_en/index.rst +++ b/docs/source_en/index.rst @@ -38,6 +38,7 @@ Swift DOCUMENTATION Megatron-SWIFT/Quick-start.md Megatron-SWIFT/Command-line-parameters.md Megatron-SWIFT/LoRA-Training.md + Megatron-SWIFT/Multimodal-Model.md .. 
toctree:: From 4c810cb18404538154963521ae91e74e2efd5bdd Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 20:17:57 +0800 Subject: [PATCH 28/31] fix --- README.md | 1 + README_CN.md | 1 + ...4\232\346\250\241\346\200\201\346\250\241\345\236\213.md" | 2 +- docs/source_en/Megatron-SWIFT/Multimodal-Model.md | 2 +- swift/megatron/model/mm_gpt_model.py | 2 +- swift/megatron/trainers/base.py | 5 ++++- swift/megatron/utils/utils.py | 2 ++ 7 files changed, 11 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 710d6a0b95..80215b90be 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ You can contact us and communicate with us by adding our group: ## 🎉 News +- 🎁 2025.09.02: Megatron-SWIFT now supports multimodal model training. Documentation can be found [here](./docs/source_en/Megatron-SWIFT/Multimodal-Model.md). - 🎁 2025.08.12: Support [Dynamic Fine-Tuning](https://arxiv.org/abs/2508.05629)(DFT) in SFT training, use parameter `--enable_dft_loss true`. Training scripts can be found [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/dft.sh). - 🎁 2025.07.12: Deployment(pt/vLLM/SGLang) of Embedding models is supported, check [here](examples/deploy/embedding/client.py). - 🎁 2025.07.09: Megatron-SWIFT supports LoRA training. Compared to ms-swift, it achieves significant speedup on MoE models. Training scripts can be found [here](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/lora). diff --git a/README_CN.md b/README_CN.md index d34ac9bd85..ef6f062a76 100644 --- a/README_CN.md +++ b/README_CN.md @@ -71,6 +71,7 @@ - **模型量化**:支持AWQ、GPTQ、FP8和BNB的量化导出,导出的模型支持使用vLLM/SGLang/LmDeploy推理加速,并支持继续训练。 ## 🎉 新闻 +- 🎁 2025.09.02: Megatron-SWIFT支持多模态模型训练。文档参考[这里](./docs/source/Megatron-SWIFT/多模态模型.md)。 - 🎁 2025.08.12: 支持在SFT训练中使用[Dynamic Fine-Tuning](https://arxiv.org/abs/2508.05629)(DFT),使用参数 `--enable_dft_loss true`。训练脚本参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/dft.sh) - 🎁 2025.07.12: 支持部署Embedding模型的部署(pt/vLLM/SGLang), 查看[这里](examples/deploy/embedding/client.py). 
- 🎁 2025.07.09: Megatron-SWIFT支持LoRA训练。相比ms-swift,在MoE模型提速显著。训练脚本参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/lora)。 diff --git "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" index 3294c96f01..4d93655a94 100644 --- "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" @@ -127,7 +127,7 @@ swift export \ ``` -最后,我们对生成的HF格式权重进行推理: +最后,我们使用生成的HF格式权重对验证集进行推理: ```shell MAX_PIXELS=1003520 \ CUDA_VISIBLE_DEVICES=0 \ diff --git a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md index 13a20b0b60..3ca0068230 100644 --- a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md +++ b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md @@ -129,7 +129,7 @@ swift export \ ``` -Finally, we perform inference using the generated Hugging Face format weights: +Finally, we use the generated Hugging Face format weights to perform inference on the validation set: ```shell MAX_PIXELS=1003520 \ CUDA_VISIBLE_DEVICES=0 \ diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index deeece9b63..d4871b5b54 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -32,7 +32,7 @@ def __init__(self, self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights args = get_args() self.visual = None - if args.megatron_model_meta.visual_cls is not None: + if pre_process and args.megatron_model_meta.visual_cls is not None: self.visual = args.megatron_model_meta.visual_cls(config) @contextmanager diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 841a4ae229..ad17437a42 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -264,7 +264,10 @@ def new_model_provider_func(*args, **kwargs): return model, optimizer, opt_param_scheduler def _prepare_vit_gradient_checkpointing(self): - visual = self.unwrapped_model.visual.model + visual = self.unwrapped_model.visual + if visual is None: + return + visual = visual.model args = get_args() if args.vit_gradient_checkpointing: visual.gradient_checkpointing_enable(**(args.gradient_checkpointing_kwargs or {})) diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index 3578c7e140..e1611326c1 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -66,6 +66,8 @@ def get_multimodal_target_regex( rejected_modules.append(_aligner) sub_module = deep_getattr(model, module) + if sub_module is None: + continue target_modules = find_all_linears(sub_module) if not target_modules: continue From d90d2821fb13741d920f2f8074fc99674284fe39 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 2 Sep 2025 20:28:08 +0800 Subject: [PATCH 29/31] fix --- ...\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" | 2 +- docs/source_en/Megatron-SWIFT/Multimodal-Model.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" index 4d93655a94..9d951662e6 100644 --- 
"a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" @@ -132,7 +132,7 @@ swift export \ MAX_PIXELS=1003520 \ CUDA_VISIBLE_DEVICES=0 \ swift infer \ - --model megatron_output/Qwen2.5-7B-Instruct/vx-xxx-hf \ + --model megatron_output/Qwen2.5-VL-7B-Instruct/vx-xxx-hf \ --attn_impl flash_attn \ --stream true \ --load_data_args true \ diff --git a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md index 3ca0068230..51e333b6ca 100644 --- a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md +++ b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md @@ -134,7 +134,7 @@ Finally, we use the generated Hugging Face format weights to perform inference o MAX_PIXELS=1003520 \ CUDA_VISIBLE_DEVICES=0 \ swift infer \ - --model megatron_output/Qwen2.5-7B-Instruct/vx-xxx-hf \ + --model megatron_output/Qwen2.5-VL-7B-Instruct/vx-xxx-hf \ --attn_impl flash_attn \ --stream true \ --load_data_args true \ From f89a2bbb6ce7074ce117f49f0352921b19f5fcba Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 3 Sep 2025 00:26:56 +0800 Subject: [PATCH 30/31] update --- ...41\346\200\201\346\250\241\345\236\213.md" | 2 +- .../Megatron-SWIFT/Multimodal-Model.md | 2 +- examples/megatron/multimodal/dense/dpo.sh | 42 +++++++++++++++++++ .../multimodal/{dense.sh => dense/full.sh} | 0 .../megatron/multimodal/{ => dense}/lora.sh | 0 swift/megatron/train/sft.py | 2 - 6 files changed, 44 insertions(+), 4 deletions(-) create mode 100644 examples/megatron/multimodal/dense/dpo.sh rename examples/megatron/multimodal/{dense.sh => dense/full.sh} (100%) rename examples/megatron/multimodal/{ => dense}/lora.sh (100%) diff --git "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" index 9d951662e6..4e4bb3bdd8 100644 --- "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" @@ -1,6 +1,6 @@ # 多模态模型 -ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen2.5-VL等模型的预训练和微调。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/支持的模型和数据集.md)。 +ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen2.5-VL等模型的CPT/SFT/DPO。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/支持的模型和数据集.md)。 环境准备请参考Megatron-SWIFT的[快速开始文档](./快速开始.md)。 diff --git a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md index 51e333b6ca..c4cda73cec 100644 --- a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md +++ b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md @@ -1,6 +1,6 @@ # Multimodal Models -ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports pretraining and fine-tuning for models such as Qwen2.5-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). +ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports CPT/SFT/DPO for models such as Qwen2.5-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). 
For environment setup, please refer to the Megatron-SWIFT [Quick Start guide](./Quick-start.md). diff --git a/examples/megatron/multimodal/dense/dpo.sh b/examples/megatron/multimodal/dense/dpo.sh new file mode 100644 index 0000000000..3aa969269e --- /dev/null +++ b/examples/megatron/multimodal/dense/dpo.sh @@ -0,0 +1,42 @@ +# 2 * 23GiB; 16s/it +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=2 \ +MAX_PIXELS=1003520 \ +CUDA_VISIBLE_DEVICES=0,1 \ +megatron rlhf \ + --rlhf_type dpo \ + --load Qwen2.5-VL-7B-Instruct-mcore \ + --dataset 'swift/RLAIF-V-Dataset#20000' \ + --train_type lora \ + --lora_rank 8 \ + --lora_alpha 32 \ + --target_modules all-linear \ + --tensor_model_parallel_size 2 \ + --sequence_parallel true \ + --freeze_llm false \ + --freeze_vit true \ + --freeze_aligner true \ + --packing true \ + --split_dataset_ratio 0.01 \ + --micro_batch_size 1 \ + --global_batch_size 4 \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-4 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-5 \ + --max_epochs 1 \ + --save megatron_output/Qwen2.5-VL-7B-Instruct \ + --save_interval 100 \ + --vit_gradient_checkpointing true \ + --max_length 8192 \ + --num_workers 4 \ + --no_save_optim true \ + --no_save_rng true \ + --dataset_num_proc 16 \ + --attention_backend flash \ + --beta 0.1 \ + --loss_type sigmoid diff --git a/examples/megatron/multimodal/dense.sh b/examples/megatron/multimodal/dense/full.sh similarity index 100% rename from examples/megatron/multimodal/dense.sh rename to examples/megatron/multimodal/dense/full.sh diff --git a/examples/megatron/multimodal/lora.sh b/examples/megatron/multimodal/dense/lora.sh similarity index 100% rename from examples/megatron/multimodal/lora.sh rename to examples/megatron/multimodal/dense/lora.sh diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index 3a2e3dde55..289529a5f9 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -63,8 +63,6 @@ def run(self): if val_dataset is not None: val_dataset = build_streaming_dataloader(args, val_dataset, data_collator) - logging_path = os.path.join(args.save, 'logging.jsonl') - logger.info(f'The logging file will be saved in: {logging_path}') try: self.trainer.train(train_dataset, val_dataset, data_collator) finally: From dc803a0351ab8660cb2b84093d43e23fa7735b99 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 3 Sep 2025 10:15:58 +0800 Subject: [PATCH 31/31] update --- ...41\346\200\201\346\250\241\345\236\213.md" | 4 ++-- .../Megatron-SWIFT/Multimodal-Model.md | 4 ++-- examples/megatron/multimodal/dense/dpo.sh | 19 ++++++++----------- examples/megatron/multimodal/dense/lora.sh | 4 ++-- swift/megatron/trainers/dpo_trainer.py | 4 +--- 5 files changed, 15 insertions(+), 20 deletions(-) diff --git "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" index 4e4bb3bdd8..e9186658c5 100644 --- "a/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" +++ "b/docs/source/Megatron-SWIFT/\345\244\232\346\250\241\346\200\201\346\250\241\345\236\213.md" @@ -75,7 +75,7 @@ swift export \ LoRA训练脚本如下: ```shell -# 2 * 15GiB; 3.8s/it +# 2 * 23GiB; 2.3s/it PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=2 \ MAX_PIXELS=1003520 \ @@ -87,7 +87,7 @@ megatron sft \ --lora_rank 8 \ 
--lora_alpha 32 \ --target_modules all-linear \ - --tensor_model_parallel_size 2 \ + --tensor_model_parallel_size 1 \ --sequence_parallel true \ --freeze_llm false \ --freeze_vit true \ diff --git a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md index c4cda73cec..91becd4917 100644 --- a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md +++ b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md @@ -77,7 +77,7 @@ swift export \ The LoRA training script is as follows: ```shell -# 2 * 15GiB; 3.8s/it +# 2 * 23GiB; 2.3s/it PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=2 \ MAX_PIXELS=1003520 \ @@ -89,7 +89,7 @@ megatron sft \ --lora_rank 8 \ --lora_alpha 32 \ --target_modules all-linear \ - --tensor_model_parallel_size 2 \ + --tensor_model_parallel_size 1 \ --sequence_parallel true \ --freeze_llm false \ --freeze_vit true \ diff --git a/examples/megatron/multimodal/dense/dpo.sh b/examples/megatron/multimodal/dense/dpo.sh index 3aa969269e..edea6bdb35 100644 --- a/examples/megatron/multimodal/dense/dpo.sh +++ b/examples/megatron/multimodal/dense/dpo.sh @@ -1,17 +1,14 @@ -# 2 * 23GiB; 16s/it +# 4 * 60GiB 14s/it PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ -NPROC_PER_NODE=2 \ +NPROC_PER_NODE=4 \ MAX_PIXELS=1003520 \ -CUDA_VISIBLE_DEVICES=0,1 \ +CUDA_VISIBLE_DEVICES=0,1,2,3 \ megatron rlhf \ --rlhf_type dpo \ --load Qwen2.5-VL-7B-Instruct-mcore \ --dataset 'swift/RLAIF-V-Dataset#20000' \ - --train_type lora \ - --lora_rank 8 \ - --lora_alpha 32 \ - --target_modules all-linear \ - --tensor_model_parallel_size 2 \ + --train_type full \ + --tensor_model_parallel_size 4 \ --sequence_parallel true \ --freeze_llm false \ --freeze_vit true \ @@ -25,12 +22,12 @@ megatron rlhf \ --recompute_num_layers 1 \ --finetune true \ --cross_entropy_loss_fusion true \ - --lr 1e-4 \ + --lr 1e-5 \ --lr_warmup_fraction 0.05 \ - --min_lr 1e-5 \ + --min_lr 1e-6 \ --max_epochs 1 \ --save megatron_output/Qwen2.5-VL-7B-Instruct \ - --save_interval 100 \ + --save_interval 200 \ --vit_gradient_checkpointing true \ --max_length 8192 \ --num_workers 4 \ diff --git a/examples/megatron/multimodal/dense/lora.sh b/examples/megatron/multimodal/dense/lora.sh index 69daa5fb9c..1e232f5f07 100644 --- a/examples/megatron/multimodal/dense/lora.sh +++ b/examples/megatron/multimodal/dense/lora.sh @@ -1,4 +1,4 @@ -# 2 * 15GiB; 3.8s/it +# 2 * 23GiB; 2.3s/it PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=2 \ MAX_PIXELS=1003520 \ @@ -10,7 +10,7 @@ megatron sft \ --lora_rank 8 \ --lora_alpha 32 \ --target_modules all-linear \ - --tensor_model_parallel_size 2 \ + --tensor_model_parallel_size 1 \ --sequence_parallel true \ --freeze_llm false \ --freeze_vit true \ diff --git a/swift/megatron/trainers/dpo_trainer.py b/swift/megatron/trainers/dpo_trainer.py index 7798de2b08..aef42560c1 100644 --- a/swift/megatron/trainers/dpo_trainer.py +++ b/swift/megatron/trainers/dpo_trainer.py @@ -179,9 +179,7 @@ def _replace_data_iterator(self, data_iterator): return iter(res) def forward_step(self, data_iterator, model): - with torch.no_grad(): - data = next(data_iterator) - + data = next(data_iterator) ref_logps = data.pop('logps') with self.stimer: output_tensor = model(**data)