[bugfix] fix MTP & mcore 0.15 (NPU)#67
Conversation
There was a problem hiding this comment.
Code Review
This pull request modifies the forward method in src/mcore_bridge/patcher.py to conditionally pass the vp_stage parameter to get_mtp_layer_offset. The review feedback suggests using getattr for safer attribute access to prevent potential AttributeError and provides a more concise implementation for constructing the keyword arguments.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request centralizes the roll_tensor utility and introduces a custom implementation to support packed sequences for Megatron Core versions older than 0.16.0rc0. It also includes a fix for handling vp_stage in the MTP patcher and updates the .gitignore file. A critical bug was identified in the new roll_tensor wrapper, which lacks a return statement for the official implementation path, potentially leading to an AttributeError when packed_seq_params is None.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a custom roll_tensor implementation to support packed sequences and Context Parallelism (CP) for older Megatron-Core versions, alongside updates to the MTP patching logic and project-wide imports. Feedback focuses on critical issues in the new implementation, including a potential IndexError for short sequences and performance bottlenecks caused by non-vectorized communication. Suggestions were also provided to optimize the non-CP path through vectorization and to remove redundant memory operations.
| local_chunks = tensor_slice.chunk(2, dim=dims) | ||
| rolled_chunks = [torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks] | ||
|
|
||
| tensor_send_list = [] | ||
| tensor_recv_list = [] | ||
| for chunk in rolled_chunks: | ||
| # Skip empty chunks that can occur when the sequence slice is very small | ||
| if chunk.size(dims) == 0: | ||
| tensor_send_list.append(torch.empty(chunk.shape[:-1], dtype=chunk.dtype, device=chunk.device)) | ||
| tensor_recv_list.append(torch.empty(chunk.shape[:-1], dtype=chunk.dtype, device=chunk.device)) | ||
| continue | ||
| boundary = chunk.select(dims, shifts).contiguous().clone() | ||
| tensor_send_list.append(boundary) | ||
| tensor_recv_list.append(torch.empty_like(boundary)) | ||
|
|
||
| ops = [] | ||
| if local_rank != 0: | ||
| ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank)) | ||
| ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank)) | ||
| else: | ||
| tensor_recv_list[1].zero_() | ||
|
|
||
| if local_rank != cp_size - 1: | ||
| ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank)) | ||
| ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank)) | ||
| else: | ||
| tensor_recv_list[0].copy_(tensor_send_list[1]) |
There was a problem hiding this comment.
This logic is susceptible to an IndexError if a sequence is short. If local_seq_len is 1, tensor_slice.chunk(2, dim=dims) will return a tuple with only one chunk. Subsequent accesses to tensor_send_list[1] or tensor_recv_list[1] (e.g., at lines 177, 179, 183, 185) will fail. The implementation should be robust to cases where the sequence length per rank is less than 2, which can happen with packed sequences even if the total sequence length is a multiple of the CP size.
| for i in range(len(cu_seqlens) - 1): | ||
| start_idx = cu_seqlens[i] | ||
| end_idx = cu_seqlens[i + 1] | ||
|
|
||
| # the idx has been multiplied by cp_size, need to divide it by cp_size to get the local idx | ||
| local_start_idx = start_idx // cp_size | ||
| local_end_idx = end_idx // cp_size | ||
|
|
||
| # Skip empty sequences - this can happen when a sequence is very short and | ||
| # after dividing by cp_size, the local slice has zero length | ||
| local_seq_len = local_end_idx - local_start_idx | ||
| if local_seq_len == 0: | ||
| continue | ||
|
|
||
| tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone() | ||
|
|
||
| # The following code is very similar as the code in roll_tensor function | ||
| local_chunks = tensor_slice.chunk(2, dim=dims) | ||
| rolled_chunks = [torch.roll(chunk, shifts=shifts, dims=dims) for chunk in local_chunks] | ||
|
|
||
| tensor_send_list = [] | ||
| tensor_recv_list = [] | ||
| for chunk in rolled_chunks: | ||
| # Skip empty chunks that can occur when the sequence slice is very small | ||
| if chunk.size(dims) == 0: | ||
| tensor_send_list.append(torch.empty(chunk.shape[:-1], dtype=chunk.dtype, device=chunk.device)) | ||
| tensor_recv_list.append(torch.empty(chunk.shape[:-1], dtype=chunk.dtype, device=chunk.device)) | ||
| continue | ||
| boundary = chunk.select(dims, shifts).contiguous().clone() | ||
| tensor_send_list.append(boundary) | ||
| tensor_recv_list.append(torch.empty_like(boundary)) | ||
|
|
||
| ops = [] | ||
| if local_rank != 0: | ||
| ops.append(torch.distributed.isend(tensor=tensor_send_list[0], dst=prev_rank)) | ||
| ops.append(torch.distributed.irecv(tensor=tensor_recv_list[1], src=prev_rank)) | ||
| else: | ||
| tensor_recv_list[1].zero_() | ||
|
|
||
| if local_rank != cp_size - 1: | ||
| ops.append(torch.distributed.irecv(tensor=tensor_recv_list[0], src=next_rank)) | ||
| ops.append(torch.distributed.isend(tensor=tensor_send_list[1], dst=next_rank)) | ||
| else: | ||
| tensor_recv_list[0].copy_(tensor_send_list[1]) | ||
|
|
||
| for op in ops: | ||
| op.wait() | ||
|
|
||
| index = [slice(None)] * rolled_chunks[0].dim() | ||
| index[dims] = shifts | ||
| for chunk, recv in zip(rolled_chunks, tensor_recv_list): | ||
| # Skip empty chunks | ||
| if chunk.size(dims) == 0: | ||
| continue | ||
| chunk[tuple(index)] = recv | ||
|
|
||
| seq_result = torch.cat(rolled_chunks, dim=dims) | ||
|
|
||
| # update the rolled tensor | ||
| rolled_tensor[..., local_start_idx:local_end_idx] = seq_result |
There was a problem hiding this comment.
Performing point-to-point communication (isend/irecv) inside a loop over sequences is extremely inefficient and will likely become a major performance bottleneck during training, especially for batches with many sequences. It is highly recommended to vectorize this by gathering all boundary tokens for all sequences into a single buffer and performing the communication once, as acknowledged in the comment at line 113.
| for i in range(len(cu_seqlens) - 1): | ||
| start_idx = cu_seqlens[i] | ||
| end_idx = cu_seqlens[i + 1] | ||
| seq_slice = tensor[..., start_idx:end_idx] | ||
| rolled_seq = torch.roll(seq_slice, shifts=shifts, dims=dims) | ||
| # Zero out the last position(s) that would cross sequence boundaries | ||
| rolled_seq[..., shifts:] = 0 | ||
| rolled_tensor[..., start_idx:end_idx] = rolled_seq |
There was a problem hiding this comment.
The loop over sequences for the non-CP case can be significantly optimized by vectorizing the operation. Instead of slicing and rolling each sequence individually, you can perform a single torch.roll on the entire tensor and then zero out the elements that wrapped around sequence boundaries using index_fill_. This avoids multiple expensive slicing and rolling operations.
rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
# Zero out the last position(s) that would cross sequence boundaries
# For shifts=-1, these are the elements at cu_seqlens[1:] - 1
indices = cu_seqlens[1:] - 1
rolled_tensor.index_fill_(dims, indices, 0)
return rolled_tensor, rolled_tensor.sum()| if local_seq_len == 0: | ||
| continue | ||
|
|
||
| tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone() |
There was a problem hiding this comment.
The .clone() call here is redundant. rolled_tensor is already a clone of the input tensor (created at line 120), and the subsequent operations either create new tensors (like torch.roll) or update rolled_tensor in place. Removing this clone will save memory and compute.
| tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx].clone() | |
| tensor_slice = rolled_tensor[..., local_start_idx:local_end_idx] |
No description provided.