def cast_to_fp4(x):
    """Round every element of ``x`` to the nearest FP4-representable value.

    The representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6, i.e. the
    grid spacing is 0.5 below 2, 1 between 2 and 4, and 2 from 4 upwards.
    Exact midpoints follow ``torch.round`` (round-half-to-even on the scaled
    value), which yields 0.25 -> 0, 0.75 -> 1, 1.25 -> 1, 1.75 -> 2,
    2.5 -> 2, 3.5 -> 4 and 5 -> 4. Magnitudes above 6 (including infinity)
    saturate at 6.

    Args:
        x: Tensor of values to quantize. It is not modified.

    Returns:
        A new tensor shaped like ``x`` whose elements lie on the signed FP4
        grid.
    """
    sign = torch.sign(x)
    magnitude = torch.abs(x)

    # Candidate roundings, one per grid spacing.
    half_steps = torch.round(2.0 * magnitude) / 2.0  # spacing 0.5 on [0, 2)
    unit_steps = torch.round(magnitude)  # spacing 1 on [2, 4)
    double_steps = 2.0 * torch.round(magnitude / 2.0)  # spacing 2 on [4, inf)

    # Pick the candidate with torch.where instead of multiplying by boolean
    # masks: a masked product evaluates ``inf * 0`` for infinite inputs and
    # would turn them into NaN rather than saturating them.
    rounded = torch.where(
        magnitude < 2.0,
        half_steps,
        torch.where(magnitude < 4.0, unit_steps, double_steps),
    )

    # The magnitude is non-negative, so only the upper bound needs clamping.
    rounded = rounded.clamp(max=6.0)

    return rounded * sign


if __name__ == "__main__":
    # Regression check: boundary and midpoint inputs against the expected grid.
    data = torch.tensor([0.0, 0.25, 0.4, 0.75, 1.25, 1.4, 1.75, 2.5, 2.9, 3.5, 5.0, 5.1, 6.0, 6.2, 8.9])
    data1 = cast_to_fp4(data)
    gt = torch.tensor([0.0, 0.0, 0.5, 1.0, 1.0, 1.5, 2.0, 2.0, 3.0, 4.0, 4.0, 6.0, 6.0, 6.0, 6.0])
    assert torch.sum(torch.abs(data1 - gt)) < 1e-6

    # The mapping must be symmetric for negative inputs.
    data_neg = data * -1
    data2 = cast_to_fp4(data_neg)
    assert torch.sum(torch.abs(data2 - gt * -1)) < 1e-6
    # NOTE(review): the pre-existing checks that followed here (starting with
    # ``test = torch.tensor(...)``) were cut off in the reviewed excerpt and
    # are not reproduced - restore them from the full file.