Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ __pycache__/
*.log
*.out
/megatron_output/
/.idea/
/.qoder/
4 changes: 2 additions & 2 deletions src/mcore_bridge/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region,
gather_from_tensor_model_parallel_region)
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper, roll_tensor
from megatron.core.transformer.multi_token_prediction import MTPLossAutoScaler, MTPLossLoggingHelper
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.utils import WrappedTensor, deprecate_inference_params
from typing import Optional, Tuple

from mcore_bridge.config import ModelConfig
from mcore_bridge.utils import get_logger, split_cp_inputs
from mcore_bridge.utils import get_logger, roll_tensor, split_cp_inputs

from .rope import dynamic_rope_update, get_rope_inv_freq

Expand Down
4 changes: 2 additions & 2 deletions src/mcore_bridge/model/modules/mtp_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from megatron.core.utils import make_viewless_tensor
from typing import Callable, Optional

from mcore_bridge.utils import roll_tensor

try:
from megatron.core.typed_torch import apply_module
except ImportError:
Expand Down Expand Up @@ -162,8 +164,6 @@ def _get_embeddings(
packed_seq_params: Optional[PackedSeqParams] = None,
decoder_input=None,
):
from megatron.core.transformer.multi_token_prediction import roll_tensor

# Calc logits for the current Multi-Token Prediction (MTP) layers.
input_ids, _ = roll_tensor(
input_ids,
Expand Down
3 changes: 2 additions & 1 deletion src/mcore_bridge/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,8 @@ def _patch_mtp():
def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: torch.Tensor, **kwargs) -> torch.Tensor:
# get hidden states from previous mtp stages
offset = get_mtp_layer_offset(self.config, self.vp_stage)
get_offset_kwargs = {} if self.vp_stage is None else {'vp_stage': self.vp_stage}
offset = get_mtp_layer_offset(self.config, **get_offset_kwargs)
Comment thread
Jintao-Huang marked this conversation as resolved.
assert offset == 0, 'not support offset'
hidden_states_list = list(torch.chunk(hidden_states, 1 + offset, dim=0))
hidden_states = hidden_states_list[offset]
Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .env import get_dist_setting, get_node_setting, is_dist, is_last_rank, is_local_master, is_master
from .import_utils import _LazyModule, is_flash_attn_3_available
from .logger import get_logger
from .megatron_utils import get_local_layer_specs, set_random_seed, split_cp_inputs, unwrap_model
from .megatron_utils import get_local_layer_specs, roll_tensor, set_random_seed, split_cp_inputs, unwrap_model
from .safetensors import SafetensorLazyLoader, StreamingSafetensorSaver
from .torch_utils import gc_collect, get_current_device, safe_ddp_context, to_device
from .utils import deep_getattr, get_env_args, json_parse_to_dict, patch_deepcopy
112 changes: 112 additions & 0 deletions src/mcore_bridge/utils/megatron_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
# code borrowed from modelscope/ms-swift
import megatron.core
import torch
from megatron.core import mpu, tensor_parallel
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.transformer.module import Float16Module
from megatron.core.transformer.multi_token_prediction import roll_tensor as mcore_roll_tensor
from megatron.core.transformer.transformer_block import get_num_layers_to_build
from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
from packaging import version
from transformers import set_seed
from typing import Optional

from .logger import get_logger

mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0')

logger = get_logger()


Expand Down Expand Up @@ -97,3 +102,110 @@ def set_random_seed(
use_cudagraphable_rng)
else:
raise ValueError(f'Seed ({seed_}) should be a positive integer.')


# code borrowed from NVIDIA/Megatron-LM
def _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group=None):
"""Roll tensor with packed sequence support.
This function handles rolling for packed sequences by respecting sequence boundaries
"""

# Notice: This is a naive implementation to test the correctness,
# a better solution will only sync the boundary tokens once.
assert (dims == -1 or dims == tensor.dim() - 1), 'Packed sequence roll only supports the last dimension.'
assert shifts == -1, 'Packed sequence roll only supports a single-token left shift.'
cu_seqlens = packed_seq_params.cu_seqlens_q
assert cu_seqlens is not None, 'Packed sequence parameters must provide cu_seqlens_q.'

rolled_tensor = tensor.clone()

cp_size = cp_group.size() if cp_group is not None else 1
if cp_size == 1:
# CP disabled: roll each packed sequence independently within its boundaries
for i in range(len(cu_seqlens) - 1):
start_idx = cu_seqlens[i]
end_idx = cu_seqlens[i + 1]
seq_slice = tensor[..., start_idx:end_idx]
rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dims)
# Zero out the last position(s) that would cross sequence boundaries
rolled_seq[..., shifts:] = 0
rolled_tensor[..., start_idx:end_idx] = rolled_seq
Comment on lines +125 to +132
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The loop over sequences for the non-CP case can be significantly optimized by vectorizing the operation. Instead of slicing and rolling each sequence individually, you can perform a single torch.roll on the entire tensor and then zero out the elements that wrapped around sequence boundaries using index_fill_. This avoids multiple expensive slicing and rolling operations.

        rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
        # Zero out the last position(s) that would cross sequence boundaries
        # For shifts=-1, these are the elements at cu_seqlens[1:] - 1
        indices = cu_seqlens[1:] - 1
        rolled_tensor.index_fill_(dims, indices, 0)
        return rolled_tensor, rolled_tensor.sum()

return rolled_tensor, rolled_tensor.sum()

# CP enabled: each rank owns two chunks per sequence (front and mirrored tail).
local_rank = torch.distributed.get_rank(group=cp_group)
global_ranks = torch.distributed.get_process_group_ranks(group=cp_group)
next_rank = global_ranks[(local_rank + 1) % cp_size]
prev_rank = global_ranks[(local_rank - 1) % cp_size]

# Iterate over each sequence individually
for i in range(len(cu_seqlens) - 1):
start_idx = cu_seqlens[i]
end_idx = cu_seqlens[i + 1]

# the idx has been multiplied by cp_size, need to divide it by cp_size to get the local idx
local_start_idx = start_idx // cp_size
local_end_idx = end_idx // cp_size

# Skip empty sequences - this can happen when a sequence is very short and
# after dividing by cp_size, the local slice has zero length
local_seq_len = local_end_idx - local_start_idx
if local_seq_len == 0:
continue

tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The .clone() call here is redundant. rolled_tensor is already a clone of the input tensor (created at line 120), and the subsequent operations either create new tensors (like torch.roll) or update rolled_tensor in place. Removing this clone will save memory and compute.

Suggested change
tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone()
tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx]


# The following code is very similar as the code in roll_tensor function
local_chunks = tensor_slice.chunk(2, dim=dims)
rolled_chunks = [torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks]

tensor_send_list = []
tensor_recv_list = []
for chunk in rolled_chunks:
# Skip empty chunks that can occur when the sequence slice is very small
if chunk.size(dims) == 0:
tensor_send_list.append(torch.empty(chunk.shape[:-1], dtype=chunk.dtype, device=chunk.device))
tensor_recv_list.append(torch.empty(chunk.shape[:-1], dtype=chunk.dtype, device=chunk.device))
continue
boundary = chunk.select(dims, shifts).contiguous().clone()
tensor_send_list.append(boundary)
tensor_recv_list.append(torch.empty_like(boundary))

ops = []
if local_rank != 0:
ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank))
ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank))
else:
tensor_recv_list[1].zero_()

if local_rank != cp_size - 1:
ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank))
ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank))
else:
tensor_recv_list[0].copy_(tensor_send_list[1])
Comment on lines +159 to +185
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This logic is susceptible to an IndexError if a sequence is short. If local_seq_len is 1, tensor_slice.chunk(2, dim=dims) will return a tuple with only one chunk. Subsequent accesses to tensor_send_list[1] or tensor_recv_list[1] (e.g., at lines 177, 179, 183, 185) will fail. The implementation should be robust to cases where the sequence length per rank is less than 2, which can happen with packed sequences even if the total sequence length is a multiple of the CP size.


for op in ops:
op.wait()

index = [slice(None)] * rolled_chunks[0].dim()
index[dims] = shifts
for chunk, recv in zip(rolled_chunks, tensor_recv_list):
# Skip empty chunks
if chunk.size(dims) == 0:
continue
chunk[tuple(index)] = recv

seq_result = torch.cat(rolled_chunks, dim=dims)

# update the rolled tensor
rolled_tensor[..., local_start_idx:local_end_idx] = seq_result
Comment on lines +142 to +201
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Performing point-to-point communication (isend/irecv) inside a loop over sequences is extremely inefficient and will likely become a major performance bottleneck during training, especially for batches with many sequences. It is highly recommended to vectorize this by gathering all boundary tokens for all sequences into a single buffer and performing the communication once, as acknowledged in the comment at line 113.


return rolled_tensor, rolled_tensor.sum()


def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None, packed_seq_params=None):
if mcore_016 or packed_seq_params is None:
kwargs = {'packed_seq_params': packed_seq_params} if mcore_016 else {}
return mcore_roll_tensor(tensor, shifts=shifts, dims=dims, cp_group=cp_group, **kwargs)
# mcore 0.15 & packed_seq_params
return _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group)
Comment thread
Jintao-Huang marked this conversation as resolved.
Loading