diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index e1de0d6d27..f551b9d906 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -298,7 +298,7 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用 - gradient_checkpointing_kwargs: 传入`torch.utils.checkpoint`中的参数。例如设置为`--gradient_checkpointing_kwargs '{"use_reentrant": false}'`。默认为None。该参数只对`vit_gradient_checkpointing`生效。 - 🔥packing: 是否使用序列packing提升计算效率(不同节点与进程更负载均衡,GPU利用率更高;但需要额外的预处理时间)并稳定显存占用,默认为False。当前支持CPT/SFT/DPO/KTO/RM。 - 注意:**同一batch的不同序列之间依旧是不可见的**,除了Qwen3-Next。 - - 注意:**packing会导致数据集样本数减少,请自行调节梯度累加数和学习率**。 + - 注意:**packing会导致数据集样本数减少,请自行调节global_batch_size和学习率**。 - packing_length: packing的长度。默认为None,设置为max_length。 - packing_num_proc: packing的进程数,默认为1。需要注意的是,不同的`packing_num_proc`,最终形成的packed数据集是不同的。(该参数在流式packing时不生效) - streaming: 流式读取并处理数据集,默认False。 diff --git a/docs/source/Megatron-SWIFT/Quick-start.md b/docs/source/Megatron-SWIFT/Quick-start.md index 9161bdaf55..8c92e2b6b9 100644 --- a/docs/source/Megatron-SWIFT/Quick-start.md +++ b/docs/source/Megatron-SWIFT/Quick-start.md @@ -27,6 +27,7 @@ pip install --no-build-isolation transformer_engine[pytorch] # pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5#egg=transformer_engine[pytorch] # apex +# 提示:Megatron-SWIFT可以在不含apex的环境下运行,额外设置`--no_gradient_accumulation_fusion true`即可。 git clone https://github.com/NVIDIA/apex cd apex pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ @@ -65,7 +66,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2 | torch | >=2.0 | 2.7.1/2.8.0 | | | transformer_engine | >=2.3 | | | | apex | | 0.1 | | -| megatron_core | | 0.14 | | +| megatron_core | >=0.12 | 0.14 | | | flash_attn | | 2.8.1/3.0.0b1 | | | transformers | >=4.33 | 4.57.1 | | | modelscope | >=1.23 | | | diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 2c5e465576..8e0ef3085a 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -315,7 +315,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa - gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to `None`. This parameter only takes effect when `vit_gradient_checkpointing` is enabled. - 🔥packing: Whether to use sequence packing to improve computational efficiency (achieving better load balancing across nodes and processes, and higher GPU utilization), at the cost of additional preprocessing time, while also stabilizing GPU memory usage. Defaults to `False`. Currently supported for CPT, SFT, DPO, KTO and RM. - Note: **Sequences within the same batch remain mutually invisible**, except for Qwen3-Next. - - Note: **Packing reduces the number of samples in the dataset; please adjust the gradient accumulation steps and learning rate accordingly**. + - Note: **Packing will reduce the number of dataset samples. Please adjust global_batch_size and learning rate accordingly**. - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length. - packing_num_proc: Number of processes for packing, default is 1. Note that different values of `packing_num_proc` will result in different packed datasets. (This parameter does not take effect during streaming packing) - streaming: Stream data loading and processing, default is False. diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index 292922b0d6..ed46f0471f 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -26,6 +26,7 @@ pip install --no-build-isolation transformer_engine[pytorch] # pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5#egg=transformer_engine[pytorch] # apex +# Note: Megatron-SWIFT can run in environments without apex by setting `--no_gradient_accumulation_fusion true`. git clone https://github.com/NVIDIA/apex cd apex pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ @@ -65,7 +66,7 @@ Recommended Operating Environment: | torch | >=2.0 | 2.7.1/2.8.0 | | | transformer_engine | >=2.3 | | | | apex | | 0.1 | | -| megatron_core | | 0.14 | | +| megatron_core | >=0.12 | 0.14 | | | flash_attn | | 2.8.1/3.0.0b1 | | | transformers | >=4.33 | 4.57.1 | | | modelscope | >=1.23 | | | diff --git a/examples/models/qwen3_next/mcore.sh b/examples/models/qwen3_next/mcore.sh index 6b36795beb..f520429868 100644 --- a/examples/models/qwen3_next/mcore.sh +++ b/examples/models/qwen3_next/mcore.sh @@ -11,7 +11,10 @@ PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=8 \ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ megatron sft \ - --load Qwen3-Next-80B-A3B-Instruct-mcore \ + --model Qwen/Qwen3-Next-80B-A3B-Instruct \ + --load_safetensors true \ + --save_safetensors true \ + --merge_lora false \ --dataset 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT#2000' \ 'swift/self-cognition#1000' \ --load_from_cache_file true \ @@ -23,7 +26,7 @@ megatron sft \ --moe_permute_fusion true \ --moe_grouped_gemm true \ --moe_shared_expert_overlap true \ - --moe_aux_loss_coeff 1e-3 \ + --moe_aux_loss_coeff 1e-6 \ --micro_batch_size 2 \ --global_batch_size 16 \ --recompute_granularity full \ @@ -47,3 +50,9 @@ megatron sft \ --attention_backend flash \ --model_author swift \ --model_name swift-robot + + +# CUDA_VISIBLE_DEVICES=0,1,2,3 \ +# swift infer \ +# --adapters megatron_output/Qwen3-Next-80B-A3B-Instruct/vx-xxx/checkpoint-xxx \ +# --stream true diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index d9f6ba9e22..556a631f36 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -454,6 +454,8 @@ def __post_init__(self): MegatronTunerMixin.__post_init__(self) os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' self._set_default() + if self.optimizer_cpu_offload: + require_version('megatron-core>=0.13') self.model_info, self.model_meta = get_model_info_meta( self.model, model_type=self.model_type, use_hf=self.use_hf, hub_token=self.hub_token) self.model_type = self.model_info.model_type diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 6a591f4429..fcf602ed00 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -66,7 +66,7 @@ def _patch_mla_attention(): gather_from_tensor_model_parallel_region, scatter_to_sequence_parallel_region, ) - megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') # Code borrowed from NVIDIA/Megatron-LM def forward( @@ -112,7 +112,7 @@ def forward( # Adjust key, value for inference # =================================================== # rotary_pos_emb = None - if megatron_core_013: + if mcore_013: query, key, value, _, attn_mask_type, _ = self._adjust_key_value_for_inference( inference_context, query, key, value, rotary_pos_emb=None) else: @@ -430,7 +430,7 @@ def _patch_TransformerLayer(): from megatron.training import get_args from megatron.core.transformer import TransformerLayer _origin_forward = TransformerLayer.forward - megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') def forward(self, *_args, **kwargs): """ @@ -439,7 +439,7 @@ def forward(self, *_args, **kwargs): This method calls the core computation of a transformer layer, including self-attention, cross-attention (if applicable), and feed-forward operations. """ - if not megatron_core_013: + if not mcore_013: return _origin_forward(self, *_args, **kwargs) hidden_states, context = self._forward_attention(*_args, **kwargs) args = get_args() @@ -551,11 +551,14 @@ def build_train_valid_test_datasets(build_train_valid_test_datasets_provider): def _patch_mrope(): from megatron.core.models.common.embeddings.rotary_pos_embedding import MultimodalRotaryEmbedding from megatron.core import parallel_state + import megatron.core from megatron.core.models.common.embeddings.rope_utils import (get_pos_emb_on_this_cp_rank, _apply_rotary_pos_emb_bshd) from megatron.core.models.common.embeddings import rope_utils from megatron.training import get_args + mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + # Code borrowed from huggingface/transformers def apply_interleaved_mrope(freqs, mrope_section): """Apply interleaved MRoPE to 3D rotary embeddings. @@ -638,13 +641,16 @@ def _apply_rotary_pos_emb_thd( Returns: Tensor: Shape [t, h, d]. The input tensor after applying RoPE. """ - use_batched_rope = False if cp_group is not None: cp_size = cp_group.size() - cu_seqlens_for_batched = cu_seqlens // cp_size - use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item() + else: + args = get_args() + cp_size = args.context_parallel_size + cu_seqlens_for_batched = cu_seqlens // cp_size + use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item() if not use_batched_rope: logger.warning_once('Using non-batched RoPE, which may affect performance.') + kwargs = {'cp_group': cp_group} if mcore_013 else {} return _origin_apply_rotary_pos_emb_thd( t, cu_seqlens, @@ -652,10 +658,8 @@ def _apply_rotary_pos_emb_thd( rotary_interleaved=rotary_interleaved, multi_latent_attention=multi_latent_attention, mscale=mscale, - cp_group=cp_group, + **kwargs, ) - if cp_group is None: - raise ValueError('cp_group must be provided for THD format RoPE') return _apply_rotary_pos_emb_bshd( t.unsqueeze(1), diff --git a/swift/megatron/model/gpt/qwen3_next.py b/swift/megatron/model/gpt/qwen3_next.py index ab95d2a4d0..7a1419c5c5 100644 --- a/swift/megatron/model/gpt/qwen3_next.py +++ b/swift/megatron/model/gpt/qwen3_next.py @@ -2,6 +2,7 @@ from copy import deepcopy from typing import Optional, Tuple, Union +import megatron.core import torch from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm, _get_extra_te_kwargs from megatron.core.inference.contexts import BaseInferenceContext @@ -17,6 +18,7 @@ from megatron.core.transformer.transformer_layer import get_transformer_layer_offset from megatron.core.utils import deprecate_inference_params, is_fa_min_version from megatron.training import get_args +from packaging import version from swift.llm import ModelType from swift.utils import get_logger @@ -24,6 +26,7 @@ from ..gpt_bridge import GPTBridge from ..register import MegatronModelMeta, register_megatron_model +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') try: from flashattn_hopper.flash_attn_interface import _flash_attn_forward from flashattn_hopper.flash_attn_interface import flash_attn_with_kvcache as flash_attn3_with_kvcache @@ -58,6 +61,7 @@ class Qwen3NextSelfAttention(SelfAttention): def __init__(self, config: TransformerConfig, submodules: SelfAttentionSubmodules, *args, **kwargs): super(SelfAttention, self).__init__(config, submodules, *args, attention_type='self', **kwargs) + kwargs = {'tp_group': self.model_comm_pgs.tp} if mcore_013 else {} self.linear_qkv = build_module( submodules.linear_qkv, self.config.hidden_size, @@ -69,7 +73,7 @@ def __init__(self, config: TransformerConfig, submodules: SelfAttentionSubmodule skip_bias_add=False, is_expert=False, tp_comm_buffer_name='qkv', - tp_group=self.model_comm_pgs.tp, + **kwargs, ) if submodules.q_layernorm is not None: @@ -130,12 +134,22 @@ def forward( (Tuple[Tensor, Tensor]) Attention output and bias. """ - from megatron.core.utils import nvtx_range_pop, nvtx_range_push + try: + from megatron.core.utils import nvtx_range_pop, nvtx_range_push + except ImportError: + + def nvtx_range_pop(*args, **kwargs): + return + + def nvtx_range_push(*args, **kwargs): + return + # Check if we need to skip RoPE # no_rope is 0-indexed array and self.layer_number is 1-indexed - no_rope = (self.config.no_rope_freq[self.layer_number - 1] if self.config.no_rope_freq else False) - if no_rope: - rotary_pos_emb = None + if hasattr(self.config, 'no_rope_freq'): + no_rope = (self.config.no_rope_freq[self.layer_number - 1] if self.config.no_rope_freq else False) + if no_rope: + rotary_pos_emb = None inference_context = deprecate_inference_params(inference_context, inference_params) @@ -194,17 +208,20 @@ def forward( if (in_decode_mode and self.config.enable_cuda_graph and inference_context.is_static_batching()): raise ValueError('CUDA graphs must use flash decode with static batching!') - query, key, value, rotary_pos_emb, attn_mask_type, block_table = ( - self._adjust_key_value_for_inference( - inference_context, - query, - key, - value, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - sequence_len_offset, - )) + result = self._adjust_key_value_for_inference( + inference_context, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + ) + if mcore_013: + query, key, value, rotary_pos_emb, attn_mask_type, block_table = result + else: + query, key, value, rotary_pos_emb, attn_mask_type = result if packed_seq_params is not None: query = query.squeeze(1) @@ -215,6 +232,7 @@ def forward( # ================================================ # relative positional embedding (rotary embedding) # ================================================ + kwargs = {'cp_group': self.model_comm_pgs.cp} if mcore_013 else {} nvtx_range_push(suffix='rotary_pos_emb') if rotary_pos_emb is not None and not self.config.flash_decode: q_pos_emb, k_pos_emb = rotary_pos_emb @@ -239,18 +257,18 @@ def forward( q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q, - cp_group=self.model_comm_pgs.cp, + **kwargs, ) else: query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q, - self.model_comm_pgs.cp) + **kwargs) if k_pos_emb is not None: key = apply_rotary_pos_emb( key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv, - cp_group=self.model_comm_pgs.cp, + **kwargs, ) # TODO, can apply positional embedding to value_layer so it has @@ -418,16 +436,17 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): def get_local_layer_specs(config, layer_specs, vp_stage=None): - from megatron.core.transformer.enums import LayerType - num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage) + kwargs = {'vp_stage': vp_stage} if mcore_013 else {} + num_layers_to_build = get_num_layers_to_build(config, **kwargs) - if config.pipeline_model_parallel_layout is not None: + if getattr(config, 'pipeline_model_parallel_layout', None) is not None: + from megatron.core.transformer.enums import LayerType local_layer_specs = [ layer_specs[layer_id] for layer_id in config.pipeline_model_parallel_layout.get_layer_id_list( - layer_type=LayerType.decoder, vp_stage=vp_stage) + layer_type=LayerType.decoder, **kwargs) ] else: - offset = get_transformer_layer_offset(config, vp_stage=vp_stage) + offset = get_transformer_layer_offset(config, **kwargs) local_layer_specs = layer_specs[offset:offset + num_layers_to_build] return local_layer_specs @@ -446,13 +465,14 @@ def get_qwen3_next_transformer_layer_spec(config, vp_stage=None): config.linear_conv_kernel_dim = args.linear_conv_kernel_dim layer_norm_impl = TENorm + kwargs = {'use_kitchen': config.use_kitchen} if mcore_013 else {} moe_layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=config.num_moe_experts, moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm, multi_latent_attention=config.multi_latent_attention, moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, - use_kitchen=config.use_kitchen, + **kwargs, ) layer_specs = [] for layer_type in args.layer_types: diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 5d69a10df6..b86c16e188 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -20,6 +20,8 @@ logger = get_logger() +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + # Some ideas for LoRA conversion are referenced from: https://github.com/modelscope/ms-swift/pull/6225 class GPTBridge: @@ -43,7 +45,7 @@ def __init__(self, disable_tqmd: bool = False): self._init_meta_hf_model() self.hf_layers = deep_getattr(self.hf_model, self.hf_layers_prefix) self.module_mapping = {} - self.megatron_core_014 = version.parse(megatron.core.__version__) >= version.parse('0.14.0rc0') + self.mcore_014 = version.parse(megatron.core.__version__) >= version.parse('0.14.0rc0') megatron_model_meta = get_megatron_model_meta(self.args.hf_model_type) if self.args.is_multimodal and megatron_model_meta.visual_cls is not None: self.module_mapping = megatron_model_meta.visual_cls.module_mapping @@ -81,7 +83,7 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: } if self.args.task_type == 'causal_lm': dim0_keys.add('output_layer') - if not self.megatron_core_014: + if not self.mcore_014: # https://github.com/NVIDIA/Megatron-LM/commit/720c8b40d8e7e2de1dd303d792f29093101c5e72 dim0_keys.update({'linear_q_down_proj', 'linear_kv_down_proj'}) # RowLinear @@ -971,7 +973,13 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd hf_state_dict = {} mg_models = iter(mg_models) mg_model = next(mg_models) - if not to_mcore or mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): + if mcore_013: + is_pp_first_stage = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage) + is_pp_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage) + else: + is_pp_first_stage = mpu.is_pipeline_first_stage() + is_pp_last_stage = mpu.is_pipeline_last_stage() + if not to_mcore or is_pp_first_stage: hf_state_dict.update(self._convert_pre_process(mg_model, hf_state_dict, '', to_mcore)) if to_mcore: yield @@ -1010,7 +1018,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd else: yield from list(self._add_prefix(res, hf_prefix).items()) hf_state_dict = {} - if not to_mcore or mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): + if not to_mcore or is_pp_last_stage: hf_state_dict.update(self._convert_post_process(mg_model, hf_state_dict, '', to_mcore)) if to_mcore: yield diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 0aaa563277..b529a73337 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -3,6 +3,7 @@ from contextlib import contextmanager from typing import Any, Dict, Literal, Optional, Tuple +import megatron.core import torch from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from megatron.core.dist_checkpointing.mapping import ShardedStateDict @@ -16,12 +17,15 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import WrappedTensor, deprecate_inference_params from megatron.training import get_args +from packaging import version from swift.utils import get_logger from .rope import dynamic_rope_update, get_rope_inv_freq logger = get_logger() +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + class OutputLayerLinear(TELinear): @@ -77,6 +81,12 @@ def __init__( config.mscale_all_dim = hf_rope_scaling['mscale_all_dim'] config.rotary_scaling_factor = hf_rope_scaling['factor'] self.hf_rope_scaling = hf_rope_scaling + if mcore_013: + kwargs = {'vp_stage': vp_stage} + else: + self.vp_stage = vp_stage + assert vp_stage is None, 'megatron-core==0.12 does not support vp_stage' + kwargs = {} super().__init__( config, transformer_layer_spec, @@ -95,7 +105,7 @@ def __init__( scatter_embedding_sequence_parallel=scatter_embedding_sequence_parallel, seq_len_interpolation_factor=seq_len_interpolation_factor, mtp_block_spec=mtp_block_spec, - vp_stage=vp_stage, + **kwargs, ) if config.multi_latent_attention: self.rotary_pos_emb = RotaryEmbedding( @@ -293,25 +303,53 @@ def forward( ) args = get_args() - return self._postprocess( - hidden_states=hidden_states, - input_ids=input_ids, - position_ids=position_ids, - labels=labels if args.task_type == 'causal_lm' else None, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - mtp_in_postprocess=self.mtp_process, - loss_mask=loss_mask, - decoder_input=decoder_input, - attention_mask=attention_mask, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - runtime_gather_output=runtime_gather_output, - extra_block_kwargs=extra_block_kwargs, - inference_context=inference_context, - ) + labels = labels if args.task_type == 'causal_lm' else None + if mcore_013: + return self._postprocess( + hidden_states=hidden_states, + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + mtp_in_postprocess=self.mtp_process, + loss_mask=loss_mask, + decoder_input=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, + ) + else: + if not self.post_process: + return hidden_states + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits, _ = self.output_layer( + hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) + if has_config_logger_enabled(self.config): + payload = OrderedDict({ + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': attention_mask, + 'decoder_input': decoder_input, + 'logits': logits, + }) + log_config_to_disk(self.config, payload, prefix='input_and_logits') + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + loss = self.compute_language_model_loss(labels, logits) + + return loss def get_input_tensor(self): return self.decoder.input_tensor diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index 8edb17c21f..997f53a231 100644 --- a/swift/megatron/model/model_provider.py +++ b/swift/megatron/model/model_provider.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import TYPE_CHECKING, Optional, Union +import megatron.core import megatron.legacy import torch from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, get_gpt_layer_local_spec, @@ -11,6 +12,9 @@ from megatron.training import get_args, print_rank_0 from megatron.training.arguments import core_transformer_config_from_args from megatron.training.yaml_arguments import core_transformer_config_from_yaml +from packaging import version + +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') if TYPE_CHECKING: from .gpt_model import GPTModel @@ -29,14 +33,17 @@ def _get_transformer_layer_spec(use_te, config): """ args = get_args() if use_te: + if mcore_013: + kwargs = {'qk_l2_norm': args.qk_l2_norm, 'use_kitchen': config.use_kitchen} + else: + kwargs = {} return get_gpt_layer_with_transformer_engine_spec( args.num_experts, args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention, moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, - qk_l2_norm=args.qk_l2_norm, - use_kitchen=config.use_kitchen, + **kwargs, ) else: return get_gpt_layer_local_spec( @@ -110,13 +117,13 @@ def oom_observer(device, alloc, device_alloc, device_free): transformer_layer_spec = megatron_model_meta.get_transformer_layer_spec(config, vp_stage=vp_stage) else: if args.num_experts: + if mcore_013: + kwargs = {'qk_l2_norm': args.qk_l2_norm, 'vp_stage': vp_stage} + else: + kwargs = {} # Define the decoder block spec transformer_layer_spec = get_gpt_decoder_block_spec( - config, - use_transformer_engine=use_te, - normalization=args.normalization, - qk_l2_norm=args.qk_l2_norm, - vp_stage=vp_stage) + config, use_transformer_engine=use_te, normalization=args.normalization, **kwargs) elif args.heterogeneous_layers_config_path is not None: transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te) else: diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index c4d0dc2aba..164fe0ee0a 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -14,7 +14,7 @@ from megatron.core import mpu from megatron.core.enums import ModelType from megatron.core.num_microbatches_calculator import get_num_microbatches -from megatron.core.optimizer import _update_min_and_max_lr_in_param_groups, param_group_identifier_keys +from megatron.core.optimizer import _update_min_and_max_lr_in_param_groups from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine from megatron.core.transformer.module import MegatronModule @@ -40,6 +40,11 @@ from .utils import (get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params, get_swift_datasets_provider) +try: + from megatron.core.optimizer import param_group_identifier_keys +except ImportError: + param_group_identifier_keys = None + logger = get_logger() @@ -64,7 +69,7 @@ def _get_mean_metric(): 'train': collections.defaultdict(_get_mean_metric), 'eval': collections.defaultdict(_get_mean_metric) } - self.megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + self.mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @property def bridge(self): @@ -363,7 +368,8 @@ def _get_param_groups( } # Ensure param_group has required keys for matching when loading optimizer state # See MegatronOptimizer._filter_and_reorder_param_groups. - assert set(param_group.keys()) - set(param_group_identifier_keys) == {'params'} + if param_group_identifier_keys is not None: + assert set(param_group.keys()) - set(param_group_identifier_keys) == {'params'} param_groups.append(param_group) param_groups = _update_min_and_max_lr_in_param_groups( @@ -471,8 +477,7 @@ def _initialize_embedding(model): def _all_reduce_metric(self, metric: Dict[str, torch.Tensor], reduction=torch.distributed.ReduceOp.AVG) -> Dict[str, torch.Tensor]: - values = list(metric.values()) - reporting_metric = values[0].new_tensor(values) + reporting_metric = torch.stack(list(metric.values()), dim=0) torch.distributed.all_reduce(reporting_metric, reduction, group=mpu.get_data_parallel_group()) return {k: reporting_metric[i] for i, k in enumerate(metric.keys())} @@ -559,7 +564,7 @@ def evaluate( torch.cuda.empty_cache() if mpu.is_pipeline_last_stage(ignore_virtual=True): - if self.megatron_core_013: + if self.mcore_013: for key in loss_dicts[0].keys(): if key not in total_loss_dict: total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda() diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index d0a385aa41..f201767d3e 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -143,7 +143,11 @@ def forward_step(self, data_iterator, model): unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False, length, 0)) with self.stimer: output_tensor = model(**data) - dim = 1 if mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage) else 0 + if self.mcore_013: + is_pp_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage) + else: + is_pp_last_stage = mpu.is_pipeline_last_stage() + dim = 1 if is_pp_last_stage else 0 if self.args.calculate_KL: res = torch.concat([output_tensor, ref_output_tensor, KL_output_tensor, ref_KL_output_tensor], dim=dim) else: diff --git a/swift/megatron/trainers/reward_trainer.py b/swift/megatron/trainers/reward_trainer.py index 852f488ed2..08800826e7 100644 --- a/swift/megatron/trainers/reward_trainer.py +++ b/swift/megatron/trainers/reward_trainer.py @@ -16,6 +16,7 @@ class MegatronRewardTrainer(MegatronRLHFTrainer): def __init__(self, args, template): super().__init__(args, template) assert args.padding_free, 'Currently `rlhf_type="rm"` only supports padding_free.' + assert args.context_parallel_size == 1, 'Currently `rlhf_type="rm"` does not support context parallelism.' def loss_func(self, output_tensor, *, data): packed_seq_params = data.get('packed_seq_params') diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index 98422b8c43..0fc193fd21 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -76,7 +76,7 @@ def loss_func(self, loss = torch.cat([torch.sum(losses * loss_mask).view(1), loss_mask.sum().view(1)]) - if args.context_parallel_size > 1 and not self.megatron_core_013: + if args.context_parallel_size > 1 and not self.mcore_013: loss = all_reduce(loss, group=mpu.get_context_parallel_group()) # Check individual rank losses are not NaN prior to DP all-reduce. @@ -114,7 +114,7 @@ def loss_func(self, # Reduce loss for logging. reporting_loss = loss.detach().clone() lm_loss = loss[0] - if not self.megatron_core_013: + if not self.mcore_013: # fix megatron-lm bug # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index abfcfbd0cc..6879fe23bf 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -1,15 +1,19 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict +import megatron.core import torch from megatron.core import mpu from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank from megatron.training import get_args +from packaging import version from swift.llm import get_packed_seq_params as _get_packed_seq_params from swift.llm import to_device +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + def get_swift_datasets_provider(train_dataset, val_dataset): @@ -37,9 +41,15 @@ def get_batch_on_this_tp_rank(data, vp_stage=None): batch = to_device(data, 'cuda', non_blocking=True) if args.pipeline_model_parallel_size == 1: return batch - if not mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage): + if mcore_013: + is_pp_first_stage = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage) + is_pp_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage) + else: + is_pp_first_stage = mpu.is_pipeline_first_stage() + is_pp_last_stage = mpu.is_pipeline_last_stage() + if not is_pp_first_stage: batch['input_ids'] = None - if not mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage): + if not is_pp_last_stage: batch['labels'] = None batch['loss_scale'] = None diff --git a/swift/megatron/tuners/lora.py b/swift/megatron/tuners/lora.py index 2222a465be..815fa63d5c 100644 --- a/swift/megatron/tuners/lora.py +++ b/swift/megatron/tuners/lora.py @@ -29,7 +29,7 @@ from swift.utils import get_current_device from ..utils import tuners_sharded_state_dict -megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') class LoraParallelLinear(MegatronModule, LoraLayer): @@ -99,7 +99,7 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w 'config': self.config, 'is_expert': self.is_expert, } - if megatron_core_013: + if mcore_013: kwargs['tp_group'] = self.base_layer.tp_group if isinstance(self.base_layer, TopKRouter): router_shape = self.base_layer.weight.shape