-
Notifications
You must be signed in to change notification settings - Fork 14
[bugfix] fix MTP & mcore 0.15 (NPU) #67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,3 +8,5 @@ __pycache__/ | |
| *.log | ||
| *.out | ||
| /megatron_output/ | ||
| /.idea/ | ||
| /.qoder/ | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,16 +1,21 @@ | ||||||
| # Copyright (c) ModelScope Contributors. All rights reserved. | ||||||
| # code borrowed from modelscope/ms-swift | ||||||
| import megatron.core | ||||||
| import torch | ||||||
| from megatron.core import mpu, tensor_parallel | ||||||
| from megatron.core.distributed import DistributedDataParallel as DDP | ||||||
| from megatron.core.transformer.module import Float16Module | ||||||
| from megatron.core.transformer.multi_token_prediction import roll_tensor as mcore_roll_tensor | ||||||
| from megatron.core.transformer.transformer_block import get_num_layers_to_build | ||||||
| from megatron.core.transformer.transformer_layer import get_transformer_layer_offset | ||||||
| from packaging import version | ||||||
| from transformers import set_seed | ||||||
| from typing import Optional | ||||||
|
|
||||||
| from .logger import get_logger | ||||||
|
|
||||||
| mcore_016 = version.parse(megatron.core.__version__) >= version.parse('0.16.0rc0') | ||||||
|
|
||||||
| logger = get_logger() | ||||||
|
|
||||||
|
|
||||||
|
|
@@ -97,3 +102,110 @@ def set_random_seed( | |||||
| use_cudagraphable_rng) | ||||||
| else: | ||||||
| raise ValueError(f'Seed ({seed_}) should be a positive integer.') | ||||||
|
|
||||||
|
|
||||||
| # code borrowed from NVIDIA/Megatron-LM | ||||||
| def _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group=None): | ||||||
| """Roll tensor with packed sequence support. | ||||||
| This function handles rolling for packed sequences by respecting sequence boundaries | ||||||
| """ | ||||||
|
|
||||||
| # Notice: This is a naive implementation to test the correctness, | ||||||
| # a better solution will only sync the boundary tokens once. | ||||||
| assert (dims == -1 or dims == tensor.dim() - 1), 'Packed sequence roll only supports the last dimension.' | ||||||
| assert shifts == -1, 'Packed sequence roll only supports a single-token left shift.' | ||||||
| cu_seqlens = packed_seq_params.cu_seqlens_q | ||||||
| assert cu_seqlens is not None, 'Packed sequence parameters must provide cu_seqlens_q.' | ||||||
|
|
||||||
| rolled_tensor = tensor.clone() | ||||||
|
|
||||||
| cp_size = cp_group.size() if cp_group is not None else 1 | ||||||
| if cp_size == 1: | ||||||
| # CP disabled: roll each packed sequence independently within its boundaries | ||||||
| for i in range(len(cu_seqlens) - 1): | ||||||
| start_idx = cu_seqlens[i] | ||||||
| end_idx = cu_seqlens[i + 1] | ||||||
| seq_slice = tensor[..., start_idx:end_idx] | ||||||
| rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dims) | ||||||
| # Zero out the last position(s) that would cross sequence boundaries | ||||||
| rolled_seq[..., shifts:] = 0 | ||||||
| rolled_tensor[..., start_idx:end_idx] = rolled_seq | ||||||
|
Comment on lines
+125
to
+132
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The loop over sequences for the non-CP case can be significantly optimized by vectorizing the operation. Instead of slicing and rolling each sequence individually, you can perform a single rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
# Zero out the last position(s) that would cross sequence boundaries
# For shifts=-1, these are the elements at cu_seqlens[1:] - 1
indices = cu_seqlens[1:] - 1
rolled_tensor.index_fill_(dims, indices, 0)
return rolled_tensor, rolled_tensor.sum() |
||||||
| return rolled_tensor, rolled_tensor.sum() | ||||||
|
|
||||||
| # CP enabled: each rank owns two chunks per sequence (front and mirrored tail). | ||||||
| local_rank = torch.distributed.get_rank(group=cp_group) | ||||||
| global_ranks = torch.distributed.get_process_group_ranks(group=cp_group) | ||||||
| next_rank = global_ranks[(local_rank + 1) % cp_size] | ||||||
| prev_rank = global_ranks[(local_rank - 1) % cp_size] | ||||||
|
|
||||||
| # Iterate over each sequence individually | ||||||
| for i in range(len(cu_seqlens) - 1): | ||||||
| start_idx = cu_seqlens[i] | ||||||
| end_idx = cu_seqlens[i + 1] | ||||||
|
|
||||||
| # the idx has been multiplied by cp_size, need to divide it by cp_size to get the local idx | ||||||
| local_start_idx = start_idx // cp_size | ||||||
| local_end_idx = end_idx // cp_size | ||||||
|
|
||||||
| # Skip empty sequences - this can happen when a sequence is very short and | ||||||
| # after dividing by cp_size, the local slice has zero length | ||||||
| local_seq_len = local_end_idx - local_start_idx | ||||||
| if local_seq_len == 0: | ||||||
| continue | ||||||
|
|
||||||
| tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone() | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||
|
|
||||||
| # The following code is very similar as the code in roll_tensor function | ||||||
| local_chunks = tensor_slice.chunk(2, dim=dims) | ||||||
| rolled_chunks = [torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks] | ||||||
|
|
||||||
| tensor_send_list = [] | ||||||
| tensor_recv_list = [] | ||||||
| for chunk in rolled_chunks: | ||||||
| # Skip empty chunks that can occur when the sequence slice is very small | ||||||
| if chunk.size(dims) == 0: | ||||||
| tensor_send_list.append(torch.empty(chunk.shape[:-1], dtype=chunk.dtype, device=chunk.device)) | ||||||
| tensor_recv_list.append(torch.empty(chunk.shape[:-1], dtype=chunk.dtype, device=chunk.device)) | ||||||
| continue | ||||||
| boundary = chunk.select(dims, shifts).contiguous().clone() | ||||||
| tensor_send_list.append(boundary) | ||||||
| tensor_recv_list.append(torch.empty_like(boundary)) | ||||||
|
|
||||||
| ops = [] | ||||||
| if local_rank != 0: | ||||||
| ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank)) | ||||||
| ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank)) | ||||||
| else: | ||||||
| tensor_recv_list[1].zero_() | ||||||
|
|
||||||
| if local_rank != cp_size - 1: | ||||||
| ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank)) | ||||||
| ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank)) | ||||||
| else: | ||||||
| tensor_recv_list[0].copy_(tensor_send_list[1]) | ||||||
|
Comment on lines
+159
to
+185
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This logic is susceptible to an |
||||||
|
|
||||||
| for op in ops: | ||||||
| op.wait() | ||||||
|
|
||||||
| index = [slice(None)] * rolled_chunks[0].dim() | ||||||
| index[dims] = shifts | ||||||
| for chunk, recv in zip(rolled_chunks, tensor_recv_list): | ||||||
| # Skip empty chunks | ||||||
| if chunk.size(dims) == 0: | ||||||
| continue | ||||||
| chunk[tuple(index)] = recv | ||||||
|
|
||||||
| seq_result = torch.cat(rolled_chunks, dim=dims) | ||||||
|
|
||||||
| # update the rolled tensor | ||||||
| rolled_tensor[..., local_start_idx:local_end_idx] = seq_result | ||||||
|
Comment on lines
+142
to
+201
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Performing point-to-point communication ( |
||||||
|
|
||||||
| return rolled_tensor, rolled_tensor.sum() | ||||||
|
|
||||||
|
|
||||||
| def roll_tensor(tensor, shifts=-1, dims=-1, cp_group=None, packed_seq_params=None): | ||||||
| if mcore_016 or packed_seq_params is None: | ||||||
| kwargs = {'packed_seq_params': packed_seq_params} if mcore_016 else {} | ||||||
| return mcore_roll_tensor(tensor, shifts=shifts, dims=dims, cp_group=cp_group, **kwargs) | ||||||
| # mcore 0.15 & packed_seq_params | ||||||
| return _roll_tensor_packed_seq(tensor, shifts, dims, packed_seq_params, cp_group) | ||||||
|
Jintao-Huang marked this conversation as resolved.
|
||||||
Uh oh!
There was an error while loading. Please reload this page.