Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions auto_round/data_type/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@
from auto_round.data_type.register import QUANT_FUNC_WITH_DTYPE
import auto_round.data_type.w4fp8
from auto_round.data_type.utils import get_quant_func
import auto_round.data_type.nvfp

2 changes: 1 addition & 1 deletion auto_round/data_type/mxfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def quant_mx(tensor, bits=4, group_size=-1, v=0, max_scale=1.0,
orig_dtype = tensor.dtype
shared_exp, _ = torch.max(torch.abs(tensor), dim=-1, keepdim=True)
if isinstance(max_scale, torch.Tensor):
shared_exp *= (max_scale.unsqueeze(dim=-1))
shared_exp *= (max_scale.unsqueeze(dim=-1)).to(tensor.device)
else:
shared_exp *= max_scale

Expand Down
97 changes: 95 additions & 2 deletions auto_round/data_type/nvfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def ref_nvfp4_quant(x, global_scale, block_size=16, v=0):
scale = float8_e4m3fn_ste(scale).to(torch.float32)
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))

scaled_x = x.to(torch.float32) * output_scale
clipped_x = torch.clamp(scaled_x, -6.0, 6.0) + v
scaled_x = x.to(torch.float32) * output_scale + v
clipped_x = torch.clamp(scaled_x, -6.0, 6.0)
return (cast_to_fp4_ste(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), output_scale


Expand All @@ -77,3 +77,96 @@ def full_quant(tensor, bits=4, group_size=16, v=0, **kwargs):
qdq_res, output_scale = ref_nvfp4_quant(tensor, global_scale, group_size, v)
qdq_res = revert_tensor_by_pad(qdq_res, orig_shape=orig_shape, pad_len=pad_len)
return qdq_res.to(orig_dtype), output_scale, None


# Largest finite value representable in unsigned e5m3: exponent field 31 with
# mantissa 6 gives (1 + 6/8) * 2**16 == 1.75 * 65536 == 114688.
# (Exponent 31 with mantissa 7 is reserved for NaN.)
FLOAT8_UE5M3_MAX = 114688


def float_to_e5m3_frexp(x: torch.Tensor) -> torch.Tensor:
    """Encode a non-negative float tensor into unsigned e5m3 bytes.

    Byte layout: 5 exponent bits (bias 15, so stored exponent = frexp
    exponent + 14) and 3 mantissa bits, no sign bit.  Stored exponent 0
    encodes subnormals (no implicit leading 1, scale 2**-14); stored
    exponent 31 with mantissa 7 is reserved for NaN, so the largest finite
    value is 1.75 * 2**16 == 114688 (FLOAT8_UE5M3_MAX).

    Negative inputs are clamped to zero.  Inputs above the max finite value
    saturate to the largest finite code instead of overflowing: without the
    upper clamp, values just above the max could round into the NaN
    encoding (e5=31, m3=7), and larger values would keep their rounded
    mantissa while the exponent clamps, silently decoding to a *smaller*
    magnitude.

    Args:
        x: tensor of (nominally non-negative) float values.

    Returns:
        torch.uint8 tensor of the same shape holding the e5m3 codes.
    """
    # Saturate into the representable range [0, 1.75 * 2**16].
    x = torch.clamp(x, min=0.0, max=114688.0)  # 114688 == FLOAT8_UE5M3_MAX
    e5m3 = torch.zeros_like(x, dtype=torch.uint8)

    mask = x > 0
    x_masked = x[mask]

    # Normal numbers: x >= 2^-14.
    normal_mask = x_masked >= 2 ** -14
    x_normal = x_masked[normal_mask]
    mantissa, exponent = torch.frexp(x_normal)

    # frexp yields mantissa in [0.5, 1); map it onto 3 mantissa bits.
    m3 = torch.round((mantissa - 0.5) * 16)
    # Rounding may overflow the mantissa (m3 == 8).  Carry into the
    # exponent instead of clamping, so e.g. 1.97 rounds to 2.0, not 1.875.
    carry = m3 >= 8
    m3 = torch.where(carry, torch.zeros_like(m3), m3).to(torch.uint8)
    exponent = exponent + carry.to(exponent.dtype)
    # 0 is reserved for subnormals, 31 is the top (finite/NaN) bucket.
    e5 = torch.clamp(exponent + 14, 0, 31).to(torch.uint8)

    e5m3_vals = ((e5 << 3) | m3).to(torch.uint8)

    # Subnormals: 0 < x < 2^-14, encoded as (m / 8) * 2^-14 with exponent 0.
    subnormal_mask = ~normal_mask
    x_subnormal = x_masked[subnormal_mask]
    m_sub = torch.clamp(torch.round(x_subnormal / (2 ** -14) * 8), 1, 7).to(torch.uint8)
    e5m3_sub = m_sub  # top 5 bits = 0

    out_vals = torch.zeros_like(x_masked, dtype=torch.uint8)
    out_vals[normal_mask] = e5m3_vals
    out_vals[subnormal_mask] = e5m3_sub

    e5m3[mask] = out_vals
    return e5m3


def e5m3_to_float_tensor(e5m3: torch.Tensor) -> torch.Tensor:
    """Decode unsigned e5m3 bytes back to float32.

    Code 0 decodes to 0.0.  Stored exponent 0 is subnormal:
    (m / 8) * 2**-14, no implicit leading 1.  Stored exponent 31 with
    mantissa 7 decodes to NaN.  Everything else is a normal value:
    (1 + m / 8) * 2**(e - 15).
    """
    assert e5m3.dtype == torch.uint8

    exp_bits = ((e5m3 >> 3) & 0x1F).to(torch.int32)
    man_bits = (e5m3 & 0x07).to(torch.int32)
    frac = man_bits.float() / 8.0

    # Decode each element under both interpretations, then select.
    subnormal_vals = frac * (2 ** -14)
    normal_vals = torch.ldexp(1.0 + frac, exp_bits - 15)

    decoded = torch.where(exp_bits == 0, subnormal_vals, normal_vals)
    nan_mask = (exp_bits == 31) & (man_bits == 7)
    decoded = torch.where(nan_mask, torch.full_like(decoded, float('nan')), decoded)
    # Byte 0 is the canonical zero code.
    return torch.where(e5m3 == 0, torch.zeros_like(decoded), decoded)


def cast_to_ue5m3(tensor):
    """Quantize-dequantize *tensor* through the unsigned e5m3 format.

    Round-trips the values through the 8-bit encoding and returns the
    dequantized result cast back to the input's original dtype.
    """
    codes = float_to_e5m3_frexp(tensor)
    return e5m3_to_float_tensor(codes).to(tensor.dtype)


def cast_to_ue5m3_ste(x):
    """Straight-through estimator around :func:`cast_to_ue5m3`.

    Forward pass returns the quantize-dequantized values; in the backward
    pass the quantization delta is detached, so gradients flow through to
    *x* unchanged.
    """
    quantized = cast_to_ue5m3(x).to(x.dtype)
    return x + (quantized - x).detach()


if __name__ == "__main__":
test = torch.tensor(
[0.0, 2 ** (-17), (2 ** -14) * 0.875, 2 ** -14, 2 ** -13, 2 ** -6,
1e-6, 2.7657e-05, 0.1, 1.0, 3.14, 1000.0,
114688,
1e10],
dtype=torch.float32)
encoded = float_to_e5m3_frexp(test)
decoded = e5m3_to_float_tensor(encoded)
decoded_bf16 = decoded.to(torch.bfloat16)
print(decoded_bf16)

for i in range(len(test)):
print(
f"{test[i].item():.6g} -> {encoded[i].item():3d} -> {decoded[i].item():.6g} "
f"(error={abs(test[i] - decoded[i]).item():.3g})")