6 changes: 4 additions & 2 deletions src/mcore_bridge/bridge/gpt_bridge.py
@@ -375,7 +375,10 @@ def _get_weight(
         if not isinstance(tensor, (list, tuple)):
             tensor = [tensor]
         if self._is_fp8_param(tensor[0]):
-            mg_scale_inv = [t._rowwise_scale_inv for t in tensor]
+            mg_scale_inv = [
+                t._rowwise_scale_inv[..., :math.ceil(t._rowwise_data.shape[-1] / self.fp8_block_size)]
+                for t in tensor
+            ]
             tensor = [t._rowwise_data for t in tensor]
         del mg_weight
         if tensor is not None:
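The slicing added in this hunk keeps only the scale blocks that actually cover the data's columns: with blockwise FP8 scaling, one inverse-scale entry covers `fp8_block_size` columns, and the stored scale tensor may hold more (padded) blocks than `ceil(n_cols / block_size)`. A standalone sketch of that arithmetic, not part of the PR (the names `n_valid_blocks` and `slice_scale_inv` are illustrative, and plain lists stand in for tensors):

```python
import math

FP8_BLOCK_SIZE = 128  # assumed block size, matching the example values in this PR

def n_valid_blocks(n_cols: int, block_size: int = FP8_BLOCK_SIZE) -> int:
    # Number of scale blocks needed to cover n_cols columns.
    return math.ceil(n_cols / block_size)

def slice_scale_inv(scale_inv_row, n_cols, block_size=FP8_BLOCK_SIZE):
    # Drop any padded trailing blocks from a per-row scale vector,
    # mirroring t._rowwise_scale_inv[..., :math.ceil(cols / block_size)].
    return scale_inv_row[:n_valid_blocks(n_cols, block_size)]

# 258 columns need ceil(258 / 128) = 3 blocks even if storage holds 4.
padded = [1.0, 2.0, 3.0, 9.9]
print(slice_scale_inv(padded, 258))  # [1.0, 2.0, 3.0]
```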
@@ -397,7 +400,6 @@ def _get_weight(
             mg_scale_inv = self._all_gather_tp(mg_scale_inv, tp_dim, is_expert)
             mg_scale_inv = self._broadcast_ep_pp(mg_scale_inv, is_expert)
         tensor = tensor.view(torch.float8_e4m3fn)
-        mg_scale_inv = mg_scale_inv[..., :math.ceil(tensor.shape[-1] / self.fp8_block_size)].contiguous()
         assert tensor is not None, f'mg_key: {mg_key}'
         if offset:
             assert mg_scale_inv is None, f'mg_key: {mg_key}'

Review comment (severity: high):

The removal of the slicing for mg_scale_inv at the end of _get_weight can lead to incorrect shapes for exported weights in RowParallel layers. When tp_dim is 1 (RowParallel), the column dimension is split across ranks. After all_gather, the concatenated scale tensor might have more blocks than required by the global weight if the local column count is not a multiple of fp8_block_size. For example, if total_cols=258, tp_size=2, and block_size=128, each rank has 129 columns, which requires 2 blocks locally. Gathering them results in 4 blocks, but the global weight only needs ceil(258/128) = 3 blocks. The slicing should be restored to ensure the exported scale tensor matches the expected format.

Suggested change:
-        tensor = tensor.view(torch.float8_e4m3fn)
+        tensor = tensor.view(torch.float8_e4m3fn)
+        if mg_scale_inv is not None:
+            mg_scale_inv = mg_scale_inv[..., :math.ceil(tensor.shape[-1] / self.fp8_block_size)].contiguous()
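The mismatch the reviewer describes can be reproduced with a few lines of arithmetic. This is a standalone sketch, not code from the PR; `gathered_blocks` is a hypothetical helper that assumes an even column split across tensor-parallel ranks:

```python
import math

def gathered_blocks(total_cols: int, tp_size: int, block_size: int = 128) -> int:
    # Each TP rank holds total_cols // tp_size columns (even split assumed)
    # and keeps ceil(local_cols / block_size) scale blocks locally; after
    # all_gather the blocks from all ranks are concatenated.
    local_cols = total_cols // tp_size
    return tp_size * math.ceil(local_cols / block_size)

total_cols, tp_size, block_size = 258, 2, 128
after_gather = gathered_blocks(total_cols, tp_size, block_size)  # 2 ranks * 2 blocks = 4
needed = math.ceil(total_cols / block_size)                      # ceil(258 / 128) = 3
print(after_gather, needed)  # 4 3
```

When total_cols is a multiple of block_size (e.g. 256), the two counts agree and the trailing slice is a no-op; restoring the slice therefore only trims the padded extra blocks in the ragged case.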