From 356469db2c99226effb2ae0f58031672d063ce18 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Thu, 15 May 2025 16:42:28 +0800 Subject: [PATCH 1/3] add UE5M3 simulation --- auto_round/data_type/nvfp.py | 94 +++++++++++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 2 deletions(-) diff --git a/auto_round/data_type/nvfp.py b/auto_round/data_type/nvfp.py index 2a93e62d5..d227f9a0a 100644 --- a/auto_round/data_type/nvfp.py +++ b/auto_round/data_type/nvfp.py @@ -63,8 +63,8 @@ def ref_nvfp4_quant(x, global_scale, block_size=16, v=0): scale = float8_e4m3fn_ste(scale).to(torch.float32) output_scale = get_reciprocal(scale * get_reciprocal(global_scale)) - scaled_x = x.to(torch.float32) * output_scale - clipped_x = torch.clamp(scaled_x, -6.0, 6.0) + v + scaled_x = x.to(torch.float32) * output_scale + v + clipped_x = torch.clamp(scaled_x, -6.0, 6.0) return (cast_to_fp4_ste(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), output_scale @@ -77,3 +77,93 @@ def full_quant(tensor, bits=4, group_size=16, v=0, **kwargs): qdq_res, output_scale = ref_nvfp4_quant(tensor, global_scale, group_size, v) qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len) return qdq_res.to(orig_dtype), output_scale, None + + +FLOAT8_UE5M3_MAX = 114688 + +def float_to_e5m3_frexp(x: torch.Tensor) -> torch.Tensor: + x = torch.clamp(x, min=0.0) + e5m3 = torch.zeros_like(x, dtype=torch.uint8) + + mask = x > 0 + x_masked = x[mask] + + # normal numbers: x >= 2^-14 + normal_mask = x_masked >= 2 ** -14 + x_normal = x_masked[normal_mask] + mantissa, exponent = torch.frexp(x_normal) + + m3 = torch.clamp(torch.round((mantissa - 0.5) * 16), 0, 7).to(torch.uint8) + e5 = torch.clamp(exponent + 14, 0, 31).to(torch.uint8) # 0 reserved for subnormal, 31 reserved for NaN + + e5m3_vals = ((e5 << 3) | m3).to(torch.uint8) + + # subnormal: 0 < x < 2^-14 + subnormal_mask = ~normal_mask + x_subnormal = x_masked[subnormal_mask] + m_sub = torch.clamp(torch.round(x_subnormal / (2 ** -14) * 8), 1, 
7).to(torch.uint8) # exponent = 0 + e5m3_sub = m_sub # top 5 bits = 0 + + out_vals = torch.zeros_like(x_masked, dtype=torch.uint8) + out_vals[normal_mask] = e5m3_vals + out_vals[subnormal_mask] = e5m3_sub + + e5m3[mask] = out_vals + return e5m3 + + +def e5m3_to_float_tensor(e5m3: torch.Tensor) -> torch.Tensor: + assert e5m3.dtype == torch.uint8 + + x = torch.zeros_like(e5m3, dtype=torch.float32) + mask_nonzero = e5m3 != 0 + e = ((e5m3[mask_nonzero] >> 3) & 0x1F).to(torch.int32) + m = (e5m3[mask_nonzero] & 0x07).to(torch.int32) + + is_nan = (e == 31) & (m == 7) + is_subnormal = (e == 0) + is_normal = (e > 0) & (~is_nan) + + out = torch.zeros_like(e, dtype=torch.float32) + + # subnormal: exponent = -14, no implicit leading 1 + out[is_subnormal] = (m[is_subnormal].float() / 8.0) * (2 ** -14) + + # normal: exponent = e - 15, implicit leading 1 + mant = 1.0 + m[is_normal].float() / 8.0 + exp = e[is_normal] - 15 + out[is_normal] = torch.ldexp(mant, exp) + + out[is_nan] = float('nan') + x[mask_nonzero] = out + return x + + +def cast_to_ue5m3(tensor): + orig_dtype = tensor.dtype + encoded = float_to_e5m3_frexp(tensor) + res = e5m3_to_float_tensor(encoded) + res = res.to(orig_dtype) + return res + + +def cast_to_ue5m3_ste(x): + fp4 = (cast_to_ue5m3(x).to(x.dtype) - x).detach() + x + + return fp4 + + +if __name__ == "__main__": + test = torch.tensor( + [0.0, 2 ** (-17), (2 ** -14) * 0.875, 2 ** -14, 2 ** -13, 2 ** -6, 1e-6, 2.7657e-05, 0.1, 1.0, 3.14, 1000.0, + 114688, + 1e10], + dtype=torch.float32) + encoded = float_to_e5m3_frexp(test) + decoded = e5m3_to_float_tensor(encoded) + decoded_bf16 = decoded.to(torch.bfloat16) + print(decoded_bf16) + + for i in range(len(test)): + print( + f"{test[i].item():.6g} -> {encoded[i].item():3d} -> {decoded[i].item():.6g} (error={abs(test[i] - decoded[i]).item():.3g})") From 68ba72cceab8c0d29e940fdaccc8840a3332a75f Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Thu, 15 May 2025 17:11:29 +0800 Subject: [PATCH 2/3] fix line too long --- 
auto_round/data_type/nvfp.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/auto_round/data_type/nvfp.py b/auto_round/data_type/nvfp.py index d227f9a0a..a60541c95 100644 --- a/auto_round/data_type/nvfp.py +++ b/auto_round/data_type/nvfp.py @@ -81,6 +81,7 @@ def full_quant(tensor, bits=4, group_size=16, v=0, **kwargs): FLOAT8_UE5M3_MAX = 114688 + def float_to_e5m3_frexp(x: torch.Tensor) -> torch.Tensor: x = torch.clamp(x, min=0.0) e5m3 = torch.zeros_like(x, dtype=torch.uint8) @@ -155,7 +156,8 @@ def cast_to_ue5m3_ste(x): if __name__ == "__main__": test = torch.tensor( - [0.0, 2 ** (-17), (2 ** -14) * 0.875, 2 ** -14, 2 ** -13, 2 ** -6, 1e-6, 2.7657e-05, 0.1, 1.0, 3.14, 1000.0, + [0.0, 2 ** (-17), (2 ** -14) * 0.875, 2 ** -14, 2 ** -13, 2 ** -6, + 1e-6, 2.7657e-05, 0.1, 1.0, 3.14, 1000.0, 114688, 1e10], dtype=torch.float32) @@ -166,4 +168,5 @@ def cast_to_ue5m3_ste(x): for i in range(len(test)): print( - f"{test[i].item():.6g} -> {encoded[i].item():3d} -> {decoded[i].item():.6g} (error={abs(test[i] - decoded[i]).item():.3g})") + f"{test[i].item():.6g} -> {encoded[i].item():3d} -> {decoded[i].item():.6g} " + f"(error={abs(test[i] - decoded[i]).item():.3g})") From 0f13dcad80f9c515442cd96bb14ad3d8e179b63d Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Thu, 15 May 2025 17:30:16 +0800 Subject: [PATCH 3/3] fix --- auto_round/data_type/__init__.py | 2 ++ auto_round/data_type/mxfp.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/auto_round/data_type/__init__.py b/auto_round/data_type/__init__.py index c800fb9af..44c2e9368 100644 --- a/auto_round/data_type/__init__.py +++ b/auto_round/data_type/__init__.py @@ -18,3 +18,5 @@ from auto_round.data_type.register import QUANT_FUNC_WITH_DTYPE import auto_round.data_type.w4fp8 from auto_round.data_type.utils import get_quant_func +import auto_round.data_type.nvfp + diff --git a/auto_round/data_type/mxfp.py b/auto_round/data_type/mxfp.py index 48add5a8c..2c51f16e6 100644 --- 
a/auto_round/data_type/mxfp.py +++ b/auto_round/data_type/mxfp.py @@ -103,7 +103,7 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0, orig_dtype = tensor.dtype shared_exp, _ = torch.max(torch.abs(tensor), dim=-1, keepdim=True) if isinstance(max_scale, torch.Tensor): - shared_exp *= (max_scale.unsqueeze(dim=-1)) + shared_exp *= (max_scale.unsqueeze(dim=-1)).to(tensor.device) else: shared_exp *= max_scale