From d053f845565bea518b434f48bf2dd09c94621ff7 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 16 Apr 2026 12:18:31 +0800 Subject: [PATCH] fix fp8 --- src/mcore_bridge/bridge/gpt_bridge.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 8864739..0e226b9 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -375,7 +375,10 @@ def _get_weight( if not isinstance(tensor, (list, tuple)): tensor = [tensor] if self._is_fp8_param(tensor[0]): - mg_scale_inv = [t._rowwise_scale_inv for t in tensor] + mg_scale_inv = [ + t._rowwise_scale_inv[..., :math.ceil(t._rowwise_data.shape[-1] / self.fp8_block_size)] + for t in tensor + ] tensor = [t._rowwise_data for t in tensor] del mg_weight if tensor is not None: @@ -397,7 +400,6 @@ def _get_weight( mg_scale_inv = self._all_gather_tp(mg_scale_inv, tp_dim, is_expert) mg_scale_inv = self._broadcast_ep_pp(mg_scale_inv, is_expert) tensor = tensor.view(torch.float8_e4m3fn) - mg_scale_inv = mg_scale_inv[..., :math.ceil(tensor.shape[-1] / self.fp8_block_size)].contiguous() assert tensor is not None, f'mg_key: {mg_key}' if offset: assert mg_scale_inv is None, f'mg_key: {mg_key}'