From 93e8e7e7be25ec5f92418fd4705e56497af1587a Mon Sep 17 00:00:00 2001 From: ruro Date: Fri, 14 Nov 2025 22:02:07 +0300 Subject: [PATCH 1/7] add OpInfo for fake_quantize_per_channel_affine --- tests/function_libs/torch_lib/extra_opinfo.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 5d7deb1695..72fb873e47 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -779,6 +779,55 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_): ) +def sample_inputs_fake_quantize_per_channel_affine( + op_info, device, dtype, requires_grad, **kwargs +): + del op_info, kwargs # Unused + make_arg = functools.partial( + opinfo_core.make_tensor, + device=device, + requires_grad=requires_grad, + ) + + # Test 1D, 2D, 4D and empty tensors (scalar tensors not supported) + axes_and_shapes = [ + # 1D, 2D, 4D + (axis, (S,) * dims) + for dims in (1, 2, 4) + for axis in range(dims) + ] + [ + # empty + (0, (1, 0, 3)), + (2, (1, 0, 3)), + # empty channel axis causes an error due to + # an internal zero_point.min() calculation + # (1, (1, 0, 3)), + ] + + # tensor_qparams + scale_dtype = torch.float + zero_point_dtypes = [torch.int32, torch.float, torch.half] + + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + quant_vals = [(0, 255), (-128, 127), (0, 127)] + + cases = itertools.product(axes_and_shapes, zero_point_dtypes, quant_vals) + for (axis, shape), zero_point_dtype, (quant_min, quant_max) in cases: + scale = make_arg((shape[axis],), dtype=scale_dtype) + + zero_point = make_arg( + (shape[axis],), + dtype=zero_point_dtype or torch.int64, + # zero_point must be between quant_min and quant_max + low=quant_min, + high=quant_max, + ) + + args = (scale, zero_point, axis, quant_min, quant_max) + yield opinfo_core.SampleInput(make_arg(shape, dtype=dtype), args=args) + + def _index_variable_bool(shape, max_indices, device): if not isinstance(shape, tuple): shape = (shape,) @@ -2408,6 +2457,14 @@ def __init__(self): sample_inputs_func=sample_inputs__fft_r2c, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.fake_quantize_per_channel_affine", + aten_name="fake_quantize_per_channel_affine", + op=torch.fake_quantize_per_channel_affine, + dtypes=common_dtype.floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fake_quantize_per_channel_affine, + supports_out=False, + ), opinfo_core.BinaryUfuncInfo( "ops.aten.floor_divide", aten_name="floor_divide", From a4028ed942fc2d0ccd0d2bc8f42678da47ca8494 Mon Sep 17 00:00:00 2001 From: ruro Date: Sat, 15 Nov 2025 02:23:57 +0300 Subject: [PATCH 2/7] add OpInfo for fake_quantize_per_tensor_affine --- tests/function_libs/torch_lib/extra_opinfo.py | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 72fb873e47..2ce015b363 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -779,6 +779,60 @@ def sample_inputs__fft_c2r(self, device, dtype, requires_grad=False, **_): ) +def sample_inputs_fake_quantize_per_tensor_affine( + op_info, device, dtype, requires_grad, **kwargs +): + del op_info, kwargs # Unused + make_arg = functools.partial( + opinfo_core.make_tensor, + device=device, + requires_grad=requires_grad, + ) + + # Test 1D, empty and scalar tensors (like sample_inputs_elementwise_unary) + shapes = [ + (S,), + (1, 0, 3), + (), + ] + + scale_zero_point_dtypes = [ + # default (float, int) + (None, None) + ] + [ + # tensor_qparams (tensor, tensor) + (t1, t2) + for t1 in common_dtype.all_types_and() + for t2 in common_dtype.all_types_and() + ] + + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + quant_vals = [(0, 255), (-128, 127), (0, 127)] + + cases = itertools.product(shapes, scale_zero_point_dtypes, quant_vals) + for shape, (scale_dtype, zero_point_dtype), (quant_min, quant_max) in cases: + scale = make_arg( + (), + dtype=scale_dtype or torch.float64, + ) + if scale_dtype is None: + scale = scale.item() + + zero_point = make_arg( + (), + dtype=zero_point_dtype or torch.int64, + # zero_point must be between quant_min and quant_max + low=quant_min, + high=quant_max, + ) + if zero_point_dtype is None: + zero_point = zero_point.item() + + args = (scale, zero_point, quant_min, quant_max) + yield opinfo_core.SampleInput(make_arg(shape, dtype=dtype), args=args) + + def sample_inputs_fake_quantize_per_channel_affine( op_info, device, dtype, requires_grad, **kwargs ): @@ -2457,6 +2511,14 @@ def __init__(self): sample_inputs_func=sample_inputs__fft_r2c, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.fake_quantize_per_tensor_affine", + aten_name="fake_quantize_per_tensor_affine", + op=torch.fake_quantize_per_tensor_affine, + dtypes=common_dtype.floating_types_and(torch.half, torch.bfloat16), + sample_inputs_func=sample_inputs_fake_quantize_per_tensor_affine, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.fake_quantize_per_channel_affine", aten_name="fake_quantize_per_channel_affine", From 142a47eae43fec44f8a7ef256722f3b0147f7a54 Mon Sep 17 00:00:00 2001 From: ruro Date: Sat, 15 Nov 2025 02:25:19 +0300 Subject: [PATCH 3/7] enable tests for fake_quantize_per_channel_affine --- tests/function_libs/torch_lib/ops_test_data.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 4ef7550b6e..86115005bd 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -698,6 +698,10 @@ def _where_input_wrangler( TorchLibOpInfo("special.erfcx", special_ops.aten_special_erfcx).xfail( reason="fixme: The implementation is numerically unstable: https://github.com/microsoft/onnxscript/issues/1223" ), + TorchLibOpInfo( + "ops.aten.fake_quantize_per_channel_affine", + core_ops.aten_fake_quantize_per_channel_affine, + ), TorchLibOpInfo("fill", core_ops.aten_fill), TorchLibOpInfo("flip", core_ops.aten_flip).skip( reason="fixme: size 0 inputs are not handled yet", From cf62faf8e258ed07ae2cbde0fd37afadd37b0580 Mon Sep 17 00:00:00 2001 From: ruro Date: Sat, 15 Nov 2025 02:26:05 +0300 Subject: [PATCH 4/7] enable tests for fake_quantize_per_tensor_affine --- tests/function_libs/torch_lib/ops_test_data.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 86115005bd..bd5cbc6ca4 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -702,6 +702,10 @@ def _where_input_wrangler( "ops.aten.fake_quantize_per_channel_affine", core_ops.aten_fake_quantize_per_channel_affine, ), + TorchLibOpInfo( + "ops.aten.fake_quantize_per_tensor_affine", + core_ops.aten_fake_quantize_per_tensor_affine, + ), TorchLibOpInfo("fill", core_ops.aten_fill), TorchLibOpInfo("flip", core_ops.aten_flip).skip( reason="fixme: size 0 inputs are not handled yet", From 8bb8a3d27339db93bc469966e6659cca8d1e01ec Mon Sep 17 00:00:00 2001 From: ruro Date: Sat, 15 Nov 2025 02:26:37 +0300 Subject: [PATCH 5/7] implement onnx conversion for aten::fake_quantize_per_channel_affine --- .../function_libs/torch_lib/ops/core.py | 50 +++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 326075b2fe..2bce7a1333 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -23,6 +23,7 @@ COMPLEX128, DOUBLE, FLOAT, + FLOAT16, INT8, INT16, INT32, @@ -3317,17 +3318,58 @@ def aten_eye(n: int) -> TensorType: raise NotImplementedError() +@torch_op("aten::fake_quantize_per_channel_affine", trace_only=True) def aten_fake_quantize_per_channel_affine( - self: TensorType, - scale: TensorType, - zero_point: TensorType, + self: TFloat, + scale: FLOAT, # float32 specifically! + zero_point: Union[INT32, FLOAT, FLOAT16], # int32, float32 or float16 only! axis: int, quant_min: int, quant_max: int, ) -> TensorType: """fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor""" - raise NotImplementedError() + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise NotImplementedError( + "For (quant_min, quant_max), ONNX allows only " + "(0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + ) + + if quant_min == 0: + int_dtype = ir.DataType.UINT8 + else: + int_dtype = ir.DataType.INT8 + + # TODO: When opset >= 19, remove this cast + orig_dtype = self.type.dtype + if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}: + self = op.Cast(self, to=ir.DataType.FLOAT) + + if zero_point.type.dtype == ir.DataType.INT32: + zero_point = op.Cast(zero_point, to=int_dtype) + else: + raise NotImplementedError( + "ONNX only supports integer values for the zero_point parameter. " + f"Got {zero_point.type.dtype}", + ) + + quantized = op.QuantizeLinear(self, scale, zero_point, axis=axis) + + # See comment about, PyTorch-specific (0, 127) handling + if (quant_min, quant_max) == (0, 127): + const_127 = op.Cast(127, to=int_dtype) + quantized = op.Clip(quantized, max=const_127) + + output = op.DequantizeLinear(quantized, scale, zero_point, axis=axis) + + # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear + if orig_dtype != ir.DataType.FLOAT: + output = op.Cast(output, to=orig_dtype) + + return output def aten_fake_quantize_per_channel_affine_cachemask( From ae0b9a33d0af4535db534b88512ee4f46549b536 Mon Sep 17 00:00:00 2001 From: ruro Date: Sat, 15 Nov 2025 02:26:57 +0300 Subject: [PATCH 6/7] add fake_quantize_per_channel_affine expected failure case for non-integer zero_point --- tests/function_libs/torch_lib/ops_test_data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index bd5cbc6ca4..e87a0cc232 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -701,6 +701,9 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten.fake_quantize_per_channel_affine", core_ops.aten_fake_quantize_per_channel_affine, + ).xfail( + reason="fixme: ONNX (De)QuantizeLinear only supports integer zero_point values", + matcher=lambda sample: sample.args[1].dtype != torch.int32, ), TorchLibOpInfo( "ops.aten.fake_quantize_per_tensor_affine", From ae22c2ff1f9816b3559f65b7019cd9f9ad4203ce Mon Sep 17 00:00:00 2001 From: ruro Date: Sat, 15 Nov 2025 02:27:25 +0300 Subject: [PATCH 7/7] implement onnx conversion for aten::fake_quantize_per_tensor_affine --- .../function_libs/torch_lib/ops/core.py | 73 ++++++++++++++++++- 1 file changed, 70 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 2bce7a1333..0f9ee7366c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3393,12 +3393,79 @@ def aten_fake_quantize_per_channel_affine_cachemask_backward( raise NotImplementedError() +@torch_op("aten::fake_quantize_per_tensor_affine", trace_only=True) def aten_fake_quantize_per_tensor_affine( - self: TensorType, scale: float, zero_point: int, quant_min: int, quant_max: int -) -> TensorType: + self: TFloat, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, +) -> TFloat: """fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor""" - raise NotImplementedError() + return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max) + + +@torch_op("aten::fake_quantize_per_tensor_affine.tensor_qparams", trace_only=True) +def aten_fake_quantize_per_tensor_affine_tensor_qparams( + self: TFloat, + scale: TReal, + zero_point: TReal, + quant_min: int, + quant_max: int, +) -> TFloat: + """fake_quantize_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor""" + + return _aten_fake_quantize_per_tensor_affine(self, scale, zero_point, quant_min, quant_max) + + +def _aten_fake_quantize_per_tensor_affine( + self: TFloat, + scale: Union[float, TReal], + zero_point: Union[int, TReal], + quant_min: int, + quant_max: int, +) -> TFloat: + # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127). + # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422 + if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]: + raise NotImplementedError( + "For (quant_min, quant_max), ONNX allows only " + "(0, 127), (0, 255) and (-128, 127). " + f"Got ({quant_min}, {quant_max})", + ) + + if quant_min == 0: + int_dtype = ir.DataType.UINT8 + else: + int_dtype = ir.DataType.INT8 + + # TODO: When opset >= 19, remove this cast + orig_dtype = self.type.dtype + if self.type.dtype not in {ir.DataType.FLOAT, ir.DataType.INT32}: + self = op.Cast(self, to=ir.DataType.FLOAT) + + # TODO: When opset >= 19, relex the condition for this cast + if isinstance(scale, float) or scale.type.dtype != ir.DataType.FLOAT: + scale = op.Cast(scale, to=ir.DataType.FLOAT) + + if isinstance(zero_point, int) or zero_point.type.dtype != int_dtype: + zero_point = op.Cast(zero_point, to=int_dtype) + + quantized = op.QuantizeLinear(self, scale, zero_point) + + # See comment about, PyTorch-specific (0, 127) handling + if (quant_min, quant_max) == (0, 127): + const_127 = op.Cast(127, to=int_dtype) + quantized = op.Clip(quantized, max=const_127) + + output = op.DequantizeLinear(quantized, scale, zero_point) + + # TODO: When opset >= 23, remove this cast and set output_dtype on DequantizeLinear + if orig_dtype != ir.DataType.FLOAT: + output = op.Cast(output, to=orig_dtype) + + return output def aten_fake_quantize_per_tensor_affine_cachemask(