From 2f4734e2d68edbb4885126c39aeca4ec4a931276 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 7 May 2026 23:04:10 +0800 Subject: [PATCH 1/6] fix mtp_layer_offset --- src/mcore_bridge/patcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py index 79527ee..77fff70 100644 --- a/src/mcore_bridge/patcher.py +++ b/src/mcore_bridge/patcher.py @@ -692,7 +692,8 @@ def _patch_mtp(): def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, hidden_states: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> torch.Tensor: # get hidden states from previous mtp stages - offset = get_mtp_layer_offset(self.config, self.vp_stage) + get_offset_kwargs = {} if self.vp_stage is None else {'vp_stage': self.vp_stage} + offset = get_mtp_layer_offset(self.config, **get_offset_kwargs) assert offset == 0, 'not support offset' hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0)) hidden_states = hidden_states_list[offset] From aa5674b5fc86b0adb0e6818fad28388bea24e362 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 7 May 2026 23:25:32 +0800 Subject: [PATCH 2/6] fix --- .gitignore | 2 + src/mcore_bridge/model/gpt_model.py | 4 +- src/mcore_bridge/model/modules/mtp_layer.py | 4 +- src/mcore_bridge/utils/__init__.py | 2 +- src/mcore_bridge/utils/megatron_utils.py | 112 ++++++++++++++++++++ 5 files changed, 119 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 9dc57fb..2d3acbd 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,5 @@ __pycache__/ *.log *.out /megatron_output/ +/.idea/ +/.qoder/ diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 1903b7c..38a95cc 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -16,13 +16,13 @@ from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, gather_from_tensor_model_parallel_region) -from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor +from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.utils import WrappedTensor, deprecate_inference_params from typing import Optional, Tuple from mcore_bridge.config import ModelConfig -from mcore_bridge.utils import get_logger, split_cp_inputs +from mcore_bridge.utils import get_logger, roll_tensor, split_cp_inputs from .rope import dynamic_rope_update, get_rope_inv_freq diff --git a/src/mcore_bridge/model/modules/mtp_layer.py b/src/mcore_bridge/model/modules/mtp_layer.py index fa89bf4..537abf3 100644 --- a/src/mcore_bridge/model/modules/mtp_layer.py +++ b/src/mcore_bridge/model/modules/mtp_layer.py @@ -13,6 +13,8 @@ from megatron.core.utils import make_viewless_tensor from typing import Callable, Optional +from mcore_bridge.utils import roll_tensor + try: from megatron.core.typed_torch import apply_module except ImportError: @@ -162,8 +164,6 @@ def _get_embeddings( packed_seq_params: Optional[PackedSeqParams] = None, decoder_input=None, ): - from megatron.core.transformer.multi_token_prediction import roll_tensor - # Calc logits for the current Multi-Token Prediction (MTP) layers. input_ids, _ = roll_tensor( input_ids, diff --git a/src/mcore_bridge/utils/__init__.py b/src/mcore_bridge/utils/__init__.py index d4285be..5437718 100644 --- a/src/mcore_bridge/utils/__init__.py +++ b/src/mcore_bridge/utils/__init__.py @@ -3,7 +3,7 @@ from .env import get_dist_setting, get_node_setting, is_dist, is_last_rank, is_local_master, is_master from .import_utils import _LazyModule, is_flash_attn_3_available from .logger import get_logger -from .megatron_utils import get_local_layer_specs, set_random_seed, split_cp_inputs, unwrap_model +from .megatron_utils import get_local_layer_specs, roll_tensor, set_random_seed, split_cp_inputs, unwrap_model from .safetensors import SafetensorLazyLoader, StreamingSafetensorSaver from .torch_utils import gc_collect, get_current_device, safe_ddp_context, to_device from .utils import deep_getattr, get_env_args, json_parse_to_dict, patch_deepcopy diff --git a/src/mcore_bridge/utils/megatron_utils.py b/src/mcore_bridge/utils/megatron_utils.py index 9fb20a0..58cbd14 100644 --- a/src/mcore_bridge/utils/megatron_utils.py +++ b/src/mcore_bridge/utils/megatron_utils.py @@ -1,16 +1,21 @@ # Copyright (c) ModelScope Contributors. All rights reserved. # code borrowed from modelscope/ms-swift +import megatron.core import torch from megatron.core import mpu, tensor_parallel from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.transformer.module import Float16Module +from megatron.core.transformer.multi_token_prediction import roll_tensor as mcore_roll_tensor from megatron.core.transformer.transformer_block import get_num_layers_to_build from megatron.core.transformer.transformer_layer import get_transformer_layer_offset +from packaging import version from transformers import set_seed from typing import Optional from .logger import get_logger +mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0') + logger = get_logger() @@ -97,3 +102,110 @@ def set_random_seed( use_cudagraphable_rng) else: raise ValueError(f'Seed ({seed_}) should be a positive integer.') + + +# code borrowed from NVIDIA/Megatron-LM +def _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group=None): + """Roll tensor with packed sequence support. + This function handles rolling for packed sequences by respecting sequence boundaries + """ + + # Notice: This is a naive implementation to test the correctness, + # a better solution will only sync the boundary tokens once. + assert (dims == -1 or dims == tensor.dim() - 1), 'Packed sequence roll only supports the last dimension.' + assert shifts == -1, 'Packed sequence roll only supports a single-token left shift.' + cu_seqlens = packed_seq_params.cu_seqlens_q + assert cu_seqlens is not None, 'Packed sequence parameters must provide cu_seqlens_q.' + + rolled_tensor = tensor.clone() + + cp_size = cp_group.size() if cp_group is not None else 1 + if cp_size == 1: + # CP disabled: roll each packed sequence independently within its boundaries + for i in range(len(cu_seqlens) - 1): + start_idx = cu_seqlens[i] + end_idx = cu_seqlens[i + 1] + seq_slice = tensor[..., start_idx:end_idx] + rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dims) + # Zero out the last position(s) that would cross sequence boundaries + rolled_seq[..., shifts:] = 0 + rolled_tensor[..., start_idx:end_idx] = rolled_seq + return rolled_tensor, rolled_tensor.sum() + + # CP enabled: each rank owns two chunks per sequence (front and mirrored tail). + local_rank = torch.distributed.get_rank(group=cp_group) + global_ranks = torch.distributed.get_process_group_ranks(group=cp_group) + next_rank = global_ranks[(local_rank + 1) % cp_size] + prev_rank = global_ranks[(local_rank - 1) % cp_size] + + # Iterate over each sequence individually + for i in range(len(cu_seqlens) - 1): + start_idx = cu_seqlens[i] + end_idx = cu_seqlens[i + 1] + + # the idx has been multiplied by cp_size, need to divide it by cp_size to get the local idx + local_start_idx = start_idx // cp_size + local_end_idx = end_idx // cp_size + + # Skip empty sequences - this can happen when a sequence is very short and + # after dividing by cp_size, the local slice has zero length + local_seq_len = local_end_idx - local_start_idx + if local_seq_len == 0: + continue + + tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone() + + # The following code is very similar as the code in roll_tensor function + local_chunks = tensor_slice.chunk(2, dim=dims) + rolled_chunks = [torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks] + + tensor_send_list = [] + tensor_recv_list = [] + for chunk in rolled_chunks: + # Skip empty chunks that can occur when the sequence slice is very small + if chunk.size(dims) == 0: + tensor_send_list.append(torch.empty(chunk.shape[:-1], dtype=chunk.dtype, device=chunk.device)) + tensor_recv_list.append(torch.empty(chunk.shape[:-1], dtype=chunk.dtype, device=chunk.device)) + continue + boundary = chunk.select(dims, shifts).contiguous().clone() + tensor_send_list.append(boundary) + tensor_recv_list.append(torch.empty_like(boundary)) + + ops = [] + if local_rank != 0: + ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank)) + ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank)) + else: + tensor_recv_list[1].zero_() + + if local_rank != cp_size - 1: + ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank)) + ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank)) + else: + tensor_recv_list[0].copy_(tensor_send_list[1]) + + for op in ops: + op.wait() + + index = [slice(None)] * rolled_chunks[0].dim() + index[dims] = shifts + for chunk, recv in zip(rolled_chunks, tensor_recv_list): + # Skip empty chunks + if chunk.size(dims) == 0: + continue + chunk[tuple(index)] = recv + + seq_result = torch.cat(rolled_chunks, dim=dims) + + # update the rolled tensor + rolled_tensor[..., local_start_idx:local_end_idx] = seq_result + + return rolled_tensor, rolled_tensor.sum() + + +def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None, packed_seq_params=None): + if mcore_016 or packed_seq_params is None: + kwargs = {'packed_seq_params': packed_seq_params} if mcore_016 else {} + mcore_roll_tensor(tensor, shifts=shifts, dims=dims, cp_group=cp_group, **kwargs) + # mcore 0.15 & packed_seq_params + return _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group) From 1a93227d0a23932e11030ebfbfdfa12c7f4ff783 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 7 May 2026 23:25:38 +0800 Subject: [PATCH 3/6] fix --- src/mcore_bridge/model/modules/mtp_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcore_bridge/model/modules/mtp_layer.py b/src/mcore_bridge/model/modules/mtp_layer.py index 537abf3..95bfd6a 100644 --- a/src/mcore_bridge/model/modules/mtp_layer.py +++ b/src/mcore_bridge/model/modules/mtp_layer.py @@ -12,7 +12,7 @@ from megatron.core.transformer.spec_utils import build_module from megatron.core.utils import make_viewless_tensor from typing import Callable, Optional - +g from mcore_bridge.utils import roll_tensor try: From f04b91d282492190fd13a3a8d98a03ff6f5c9824 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 7 May 2026 23:27:47 +0800 Subject: [PATCH 4/6] fix --- src/mcore_bridge/model/modules/mtp_layer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mcore_bridge/model/modules/mtp_layer.py b/src/mcore_bridge/model/modules/mtp_layer.py index 95bfd6a..ce6a20e 100644 --- a/src/mcore_bridge/model/modules/mtp_layer.py +++ b/src/mcore_bridge/model/modules/mtp_layer.py @@ -12,7 +12,6 @@ from megatron.core.transformer.spec_utils import build_module from megatron.core.utils import make_viewless_tensor from typing import Callable, Optional -g from mcore_bridge.utils import roll_tensor try: From 6ddebcfe851d8c8e74e341725bff8011402eed6a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 7 May 2026 23:36:55 +0800 Subject: [PATCH 5/6] lint pass --- src/mcore_bridge/model/modules/mtp_layer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mcore_bridge/model/modules/mtp_layer.py b/src/mcore_bridge/model/modules/mtp_layer.py index ce6a20e..537abf3 100644 --- a/src/mcore_bridge/model/modules/mtp_layer.py +++ b/src/mcore_bridge/model/modules/mtp_layer.py @@ -12,6 +12,7 @@ from megatron.core.transformer.spec_utils import build_module from megatron.core.utils import make_viewless_tensor from typing import Callable, Optional + from mcore_bridge.utils import roll_tensor try: From a3be0e84ad605f571184830500ab47afcd31e1d8 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 7 May 2026 23:47:56 +0800 Subject: [PATCH 6/6] fix --- src/mcore_bridge/utils/megatron_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcore_bridge/utils/megatron_utils.py b/src/mcore_bridge/utils/megatron_utils.py index 58cbd14..cd912e7 100644 --- a/src/mcore_bridge/utils/megatron_utils.py +++ b/src/mcore_bridge/utils/megatron_utils.py @@ -206,6 +206,6 @@ def _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group=No def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None, packed_seq_params=None): if mcore_016 or packed_seq_params is None: kwargs = {'packed_seq_params': packed_seq_params} if mcore_016 else {} - mcore_roll_tensor(tensor, shifts=shifts, dims=dims, cp_group=cp_group, **kwargs) + return mcore_roll_tensor(tensor, shifts=shifts, dims=dims, cp_group=cp_group, **kwargs) # mcore 0.15 & packed_seq_params return _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group)