From 160ea86b16c77ca8f107a4f8fa7859f392898cdf Mon Sep 17 00:00:00 2001
From: "He, Xin3"
Date: Fri, 10 Oct 2025 03:40:38 -0400
Subject: [PATCH] move ste from quant to round for nvfp4

Signed-off-by: He, Xin3
---
 auto_round/data_type/nvfp.py | 18 ++++++------------
 1 file changed, 6 insertions(+), 12 deletions(-)

diff --git a/auto_round/data_type/nvfp.py b/auto_round/data_type/nvfp.py
index bb56f365d..ec8f8c0c6 100644
--- a/auto_round/data_type/nvfp.py
+++ b/auto_round/data_type/nvfp.py
@@ -16,7 +16,7 @@
 
 from auto_round.data_type.fp8 import float8_e4m3fn_ste
 from auto_round.data_type.register import register_dtype
-from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad
+from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
 from auto_round.logger import logger
 
 
@@ -26,9 +26,9 @@ def cast_to_fp4(x):
     sign = torch.sign(x)
     x = torch.abs(x)
 
-    step1 = torch.round(2.0 * x) / 2.0
-    step2 = torch.round(x)
-    step3 = 2.0 * torch.round(x / 2.0)
+    step1 = round_ste(2.0 * x) / 2.0
+    step2 = round_ste(x)
+    step3 = 2.0 * round_ste(x / 2.0)
 
     mask1 = x < 2.0
     mask2 = x < 4.0
@@ -38,12 +38,6 @@
     return x * sign
 
 
-def cast_to_fp4_ste(x):
-    fp4 = (cast_to_fp4(x).to(x.dtype) - x).detach() + x
-
-    return fp4
-
-
 def get_reciprocal(x):
     if isinstance(x, torch.Tensor):
         return torch.where(x == 0, torch.zeros_like(x, dtype=x.dtype), 1.0 / x)
@@ -81,7 +75,7 @@ def ref_nvfp4_quant(x, global_scale, block_size=16, v=0, scale_coeff=1.0):
     output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
     scaled_x = x.to(torch.float32) * output_scale + v
     clipped_x = torch.clamp(scaled_x, -6.0, 6.0)
-    return (cast_to_fp4_ste(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), scale
+    return (cast_to_fp4(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), scale
 
 
 @register_dtype("nv_fp4")
@@ -208,7 +202,7 @@ def ref_fp4_quant(x, global_scale, block_size=16, v=0, max_scale=1.0):
     output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
     scaled_x = x.to(torch.float32) * output_scale + v
     clipped_x = torch.clamp(scaled_x, -6.0, 6.0)
-    return (cast_to_fp4_ste(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), scale
+    return (cast_to_fp4(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), scale
 
 
 @register_dtype("fp4_v2_with_global_scale")