Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions auto_round/data_type/nvfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from auto_round.data_type.fp8 import float8_e4m3fn_ste
from auto_round.data_type.register import register_dtype
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
from auto_round.logger import logger


Expand All @@ -26,9 +26,9 @@ def cast_to_fp4(x):
sign = torch.sign(x)
x = torch.abs(x)

step1 = torch.round(2.0 * x) / 2.0
step2 = torch.round(x)
step3 = 2.0 * torch.round(x / 2.0)
step1 = round_ste(2.0 * x) / 2.0
step2 = round_ste(x)
step3 = 2.0 * round_ste(x / 2.0)

mask1 = x < 2.0
mask2 = x < 4.0
Expand All @@ -38,12 +38,6 @@ def cast_to_fp4(x):
return x * sign


def cast_to_fp4_ste(x):
fp4 = (cast_to_fp4(x).to(x.dtype) - x).detach() + x

return fp4


def get_reciprocal(x):
if isinstance(x, torch.Tensor):
return torch.where(x == 0, torch.zeros_like(x, dtype=x.dtype), 1.0 / x)
Expand Down Expand Up @@ -81,7 +75,7 @@ def ref_nvfp4_quant(x, global_scale, block_size=16, v=0, scale_coeff=1.0):
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
scaled_x = x.to(torch.float32) * output_scale + v
clipped_x = torch.clamp(scaled_x, -6.0, 6.0)
return (cast_to_fp4_ste(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), scale
return (cast_to_fp4(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), scale


@register_dtype("nv_fp4")
Expand Down Expand Up @@ -208,7 +202,7 @@ def ref_fp4_quant(x, global_scale, block_size=16, v=0, max_scale=1.0):
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
scaled_x = x.to(torch.float32) * output_scale + v
clipped_x = torch.clamp(scaled_x, -6.0, 6.0)
return (cast_to_fp4_ste(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), scale
return (cast_to_fp4(clipped_x) * get_reciprocal(output_scale)).reshape(m, n), scale


@register_dtype("fp4_v2_with_global_scale")
Expand Down
Loading