From 54b5302f0f49e1c1715061e4a29e6fdcb599f733 Mon Sep 17 00:00:00 2001
From: alinpahontu2912
Date: Fri, 13 Feb 2026 15:11:10 +0100
Subject: [PATCH] Add QInt8, QUInt8, QInt32 quantized scalar types with full
 quantization support

- Uncomment and fix the ScalarType enum entries (QInt8 = 12, QUInt8 = 13, QInt32 = 14)
- Rename the incorrect QUInt32 entry to QInt32 to match PyTorch
- Add torch.qint8, torch.quint8, torch.qint32 dtype aliases
- Add torch.is_quantized() and Tensor.is_quantized() methods
- Add an IsQuantized() extension method on ScalarType

Native quantization bindings:
- Add THSTensor_quantize_per_tensor and THSTensor_quantize_per_channel (C++ and P/Invoke)
- Add THSTensor_dequantize (C++ and P/Invoke)
- Add THSTensor_q_scale and THSTensor_q_zero_point (C++ and P/Invoke)
- Add THSTensor_int_repr (C++ and P/Invoke)
- Add THSTensor_q_per_channel_scales, THSTensor_q_per_channel_zero_points, and
  THSTensor_q_per_channel_axis (C++ and P/Invoke)

Managed API:
- Add torch.quantize_per_tensor() and torch.quantize_per_channel() static methods
- Add a torch.dequantize() static method
- Add Tensor.dequantize(), Tensor.q_scale(), and Tensor.q_zero_point() instance methods
- Add a Tensor.int_repr() instance method
- Add Tensor.q_per_channel_scales(), Tensor.q_per_channel_zero_points(), and
  Tensor.q_per_channel_axis() instance methods

Add unit tests covering all of the new functionality (13 tests).
---
 src/Native/LibTorchSharp/THSTensor.cpp        |  45 ++++
 src/Native/LibTorchSharp/THSTensor.h          |  12 +
 .../PInvoke/LibTorchSharp.THSTensor.cs        |  27 +++
 src/TorchSharp/Tensor/Tensor.cs               | 111 ++++++++-
 .../Tensor/TensorExtensionMethods.cs          |  17 ++
 src/TorchSharp/Tensor/torch.PointwiseOps.cs   |  41 ++++
 test/TorchSharpTest/TestTorchSharp.cs         | 218 ++++++++++++++++++
 7 files changed, 468 insertions(+), 3 deletions(-)

diff --git a/src/Native/LibTorchSharp/THSTensor.cpp b/src/Native/LibTorchSharp/THSTensor.cpp
index 13cd5787e..f313c56a1 100644
--- a/src/Native/LibTorchSharp/THSTensor.cpp
+++ b/src/Native/LibTorchSharp/THSTensor.cpp
@@ -2278,3 +2278,48 @@ Tensor THSTensor_unflatten_names(Tensor tensor, const char** names, const int64_
     return nullptr;
 }
+
+Tensor THSTensor_quantize_per_tensor(const Tensor tensor, double scale, int64_t zero_point, int8_t scalar_type)
+{
+    CATCH_TENSOR(torch::quantize_per_tensor(*tensor, scale, zero_point, at::ScalarType(scalar_type)));
+}
+
+Tensor THSTensor_quantize_per_channel(const Tensor tensor, const Tensor scales, const Tensor zero_points, int64_t axis, int8_t scalar_type)
+{
+    CATCH_TENSOR(torch::quantize_per_channel(*tensor, *scales, *zero_points, axis, at::ScalarType(scalar_type)));
+}
+
+Tensor THSTensor_dequantize(const Tensor tensor)
+{
+    CATCH_TENSOR(tensor->dequantize());
+}
+
+double THSTensor_q_scale(const Tensor tensor)
+{
+    CATCH_RETURN(double, 0.0, tensor->q_scale());
+}
+
+int64_t THSTensor_q_zero_point(const Tensor tensor)
+{
+    CATCH_RETURN(int64_t, 0, tensor->q_zero_point());
+}
+
+Tensor THSTensor_int_repr(const Tensor tensor)
+{
+    CATCH_TENSOR(tensor->int_repr());
+}
+
+Tensor THSTensor_q_per_channel_scales(const Tensor tensor)
+{
+    CATCH_TENSOR(tensor->q_per_channel_scales());
+}
+
+Tensor THSTensor_q_per_channel_zero_points(const Tensor tensor)
+{
+    CATCH_TENSOR(tensor->q_per_channel_zero_points());
+}
+
+int64_t THSTensor_q_per_channel_axis(const Tensor tensor)
+{
+    CATCH_RETURN(int64_t, 0, tensor->q_per_channel_axis());
+}
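
[Reviewer note, not part of the patch: a minimal sketch of how these native entry
points are meant to surface through the managed wrappers added further down in this
diff. The snippet assumes the standard TorchSharp usings plus the APIs introduced by
this patch; values are illustrative.]

    using System;
    using TorchSharp;
    using static TorchSharp.torch;

    // Quantize with scale 0.1 and zero point 0, inspect the parameters, then round-trip.
    var x = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f });
    var q = torch.quantize_per_tensor(x, 0.1, 0, ScalarType.QInt8);  // THSTensor_quantize_per_tensor
    Console.WriteLine(q.q_scale());       // 0.1 -> THSTensor_q_scale
    Console.WriteLine(q.q_zero_point());  // 0   -> THSTensor_q_zero_point
    var y = q.dequantize();               //     -> THSTensor_dequantize
    Console.WriteLine(y[0].ToSingle());   // ~1.0 (round(1.0 / 0.1) * 0.1)
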
diff --git a/src/Native/LibTorchSharp/THSTensor.h b/src/Native/LibTorchSharp/THSTensor.h
index 4ddfffb49..4437db84e 100644
--- a/src/Native/LibTorchSharp/THSTensor.h
+++ b/src/Native/LibTorchSharp/THSTensor.h
@@ -1790,3 +1790,15 @@ EXPORT_API(Tensor) THSTensor_kaiser_window(const int64_t len, bool periodic, dou
 EXPORT_API(Tensor) THSTensor_stft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool normalized, int64_t onesided, bool return_complex);
 EXPORT_API(Tensor) THSTensor_istft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool center, bool normalized, int64_t onesided, int64_t length, bool return_complex);
+
+// Quantization Ops
+
+EXPORT_API(Tensor) THSTensor_quantize_per_tensor(const Tensor tensor, double scale, int64_t zero_point, int8_t scalar_type);
+EXPORT_API(Tensor) THSTensor_quantize_per_channel(const Tensor tensor, const Tensor scales, const Tensor zero_points, int64_t axis, int8_t scalar_type);
+EXPORT_API(Tensor) THSTensor_dequantize(const Tensor tensor);
+EXPORT_API(double) THSTensor_q_scale(const Tensor tensor);
+EXPORT_API(int64_t) THSTensor_q_zero_point(const Tensor tensor);
+EXPORT_API(Tensor) THSTensor_int_repr(const Tensor tensor);
+EXPORT_API(Tensor) THSTensor_q_per_channel_scales(const Tensor tensor);
+EXPORT_API(Tensor) THSTensor_q_per_channel_zero_points(const Tensor tensor);
+EXPORT_API(int64_t) THSTensor_q_per_channel_axis(const Tensor tensor);
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
index bb5d2dbd9..a0597698a 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
@@ -2176,6 +2176,33 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
         internal static extern IntPtr THSTensor_histogram_out_t(IntPtr input, IntPtr bins, IntPtr weight, bool density, out IntPtr hist, out IntPtr bin_edges, out IntPtr r_bin_edges);
         [DllImport("LibTorchSharp")]
         internal static extern IntPtr THSTensor_histogram_out_i(IntPtr input, long bins, IntPtr range, int length, IntPtr weight, bool density, out IntPtr hist, out IntPtr bin_edges, out IntPtr r_bin_edges);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSTensor_quantize_per_tensor(IntPtr tensor, double scale, long zero_point, sbyte scalar_type);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSTensor_quantize_per_channel(IntPtr tensor, IntPtr scales, IntPtr zero_points, long axis, sbyte scalar_type);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSTensor_dequantize(IntPtr tensor);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern double THSTensor_q_scale(IntPtr tensor);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern long THSTensor_q_zero_point(IntPtr tensor);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSTensor_int_repr(IntPtr tensor);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSTensor_q_per_channel_scales(IntPtr tensor);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern IntPtr THSTensor_q_per_channel_zero_points(IntPtr tensor);
+
+        [DllImport("LibTorchSharp")]
+        internal static extern long THSTensor_q_per_channel_axis(IntPtr tensor);
     }
 #pragma warning restore CA2101
 }
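
[Reviewer note, not part of the patch: the per-channel entry points compose the same
way; this sketch mirrors the QuantizePerChannel test at the bottom of the diff. With
axis = 0, each row carries its own scale and zero point.]

    using System;
    using TorchSharp;
    using static TorchSharp.torch;

    var w      = torch.tensor(new float[] { 1f, 2f, 3f, 4f, 5f, 6f }).reshape(2, 3);
    var scales = torch.tensor(new double[] { 0.1, 0.2 });  // one scale per channel (row)
    var zeros  = torch.tensor(new long[] { 0, 0 });        // one zero point per channel

    var qw = torch.quantize_per_channel(w, scales, zeros, 0, ScalarType.QInt8);
    Console.WriteLine(qw.q_per_channel_axis());                      // 0
    Console.WriteLine(qw.q_per_channel_scales()[0].ToDouble());      // 0.1
    Console.WriteLine(qw.q_per_channel_zero_points()[1].ToInt64());  // 0
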
diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs
index 46a35db04..c6ccb88b7 100644
--- a/src/TorchSharp/Tensor/Tensor.cs
+++ b/src/TorchSharp/Tensor/Tensor.cs
@@ -271,6 +271,95 @@ internal IntPtr MoveHandle()
         /// </summary>
         public bool is_complex() => torch.is_complex(dtype);
 
+        /// <summary>
+        /// Returns True if the data type of input is a quantized data type, i.e. one of torch.qint8, torch.quint8, and torch.qint32.
+        /// </summary>
+        public bool is_quantized() => torch.is_quantized(dtype);
+
+        /// <summary>
+        /// Given a quantized Tensor, returns a dequantized (float) Tensor.
+        /// </summary>
+        public Tensor dequantize()
+        {
+            var res = NativeMethods.THSTensor_dequantize(Handle);
+            if (res == IntPtr.Zero) { CheckForErrors(); }
+            return new Tensor(res);
+        }
+
+        /// <summary>
+        /// Given a quantized Tensor, returns the scale of the quantization as a double.
+        /// </summary>
+        public double q_scale()
+        {
+            var res = NativeMethods.THSTensor_q_scale(Handle);
+            CheckForErrors();
+            return res;
+        }
+
+        /// <summary>
+        /// Given a quantized Tensor, returns the zero_point of the quantization as a long.
+        /// </summary>
+        public long q_zero_point()
+        {
+            var res = NativeMethods.THSTensor_q_zero_point(Handle);
+            CheckForErrors();
+            return res;
+        }
+
+        /// <summary>
+        /// Given a quantized Tensor, returns a Tensor of the underlying integer representation.
+        /// </summary>
+        public Tensor int_repr()
+        {
+            var res = NativeMethods.THSTensor_int_repr(Handle);
+            if (res == IntPtr.Zero) { CheckForErrors(); }
+            return new Tensor(res);
+        }
+
+        /// <summary>
+        /// Given a Tensor quantized per channel, returns a Tensor of the quantization scales for each channel.
+        /// </summary>
+        public Tensor q_per_channel_scales()
+        {
+            var res = NativeMethods.THSTensor_q_per_channel_scales(Handle);
+            if (res == IntPtr.Zero) { CheckForErrors(); }
+            return new Tensor(res);
+        }
+
+        /// <summary>
+        /// Given a Tensor quantized per channel, returns a Tensor of the quantization zero points for each channel.
+        /// </summary>
+        public Tensor q_per_channel_zero_points()
+        {
+            var res = NativeMethods.THSTensor_q_per_channel_zero_points(Handle);
+            if (res == IntPtr.Zero) { CheckForErrors(); }
+            return new Tensor(res);
+        }
+
+        /// <summary>
+        /// Given a Tensor quantized per channel, returns the axis along which per-channel quantization is applied.
+        /// </summary>
+        public long q_per_channel_axis()
+        {
+            var res = NativeMethods.THSTensor_q_per_channel_axis(Handle);
+            CheckForErrors();
+            return res;
+        }
+
+        internal Tensor _quantize_per_tensor(double scale, long zero_point, ScalarType dtype)
+        {
+            var res = NativeMethods.THSTensor_quantize_per_tensor(Handle, scale, zero_point, (sbyte)dtype);
+            if (res == IntPtr.Zero) { CheckForErrors(); }
+            return new Tensor(res);
+        }
+
+        internal Tensor _quantize_per_channel(Tensor scales, Tensor zero_points, long axis, ScalarType dtype)
+        {
+            var res = NativeMethods.THSTensor_quantize_per_channel(Handle, scales.Handle, zero_points.Handle, axis, (sbyte)dtype);
+            if (res == IntPtr.Zero) { CheckForErrors(); }
+            return new Tensor(res);
+        }
+
         /// <summary>
         /// Returns True if the input is a single element tensor which is not equal to zero after type conversions,
         /// i.e. not equal to torch.tensor([0.]) or torch.tensor([0]) or torch.tensor([False]).
@@ -7279,9 +7368,9 @@ public enum ScalarType : sbyte
         ComplexFloat32 = 9,
         ComplexFloat64 = 10,
         Bool = 11,
-        //QInt8 = 12,
-        //QUInt8 = 13,
-        //QUInt32 = 14,
+        QInt8 = 12,
+        QUInt8 = 13,
+        QInt32 = 14,
         BFloat16 = 15
     }
 
@@ -7413,6 +7502,18 @@ public static bool is_complex(ScalarType type)
             }
         }
 
+        public static bool is_quantized(ScalarType type)
+        {
+            switch (type) {
+                case ScalarType.QInt8:
+                case ScalarType.QUInt8:
+                case ScalarType.QInt32:
+                    return true;
+                default:
+                    return false;
+            }
+        }
+
         public static long max_int_value(ScalarType type)
         {
             switch (type) {
@@ -7463,6 +7564,10 @@ public static long max_int_value(ScalarType type)
         public static ScalarType cfloat = ScalarType.ComplexFloat32;
         public static ScalarType cdouble = ScalarType.ComplexFloat64;
 
+        public static ScalarType qint8 = ScalarType.QInt8;
+        public static ScalarType quint8 = ScalarType.QUInt8;
+        public static ScalarType qint32 = ScalarType.QInt32;
+
         /// <summary>
         /// Creates a new dispose scope for the current thread. Any tensor created within the dispose scope will
         /// be automatically disposed once the dispose scope is disposed.
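
[Reviewer note, not part of the patch: the ordinals are load-bearing. The managed
wrappers pass (sbyte)dtype across the P/Invoke boundary and the native side
reinterprets it as at::ScalarType, so QInt8/QUInt8/QInt32 must stay at 12/13/14.
A quick sketch of the classification behavior the tests rely on:]

    using System;
    using TorchSharp;
    using static TorchSharp.torch;

    Console.WriteLine((sbyte)ScalarType.QInt32);              // 14, matches c10::ScalarType::QInt32
    Console.WriteLine(torch.is_quantized(torch.qint8));       // True
    Console.WriteLine(torch.is_quantized(torch.float32));     // False
    // Quantized types are deliberately neither integral nor floating point nor complex:
    Console.WriteLine(torch.is_integral(ScalarType.QInt32));  // False
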
diff --git a/src/TorchSharp/Tensor/TensorExtensionMethods.cs b/src/TorchSharp/Tensor/TensorExtensionMethods.cs
index 2f4fa81dc..91b777660 100644
--- a/src/TorchSharp/Tensor/TensorExtensionMethods.cs
+++ b/src/TorchSharp/Tensor/TensorExtensionMethods.cs
@@ -368,6 +368,23 @@ internal static bool IsComplex(this ScalarType type)
             }
         }
 
+        /// <summary>
+        /// Indicates whether a given element type is quantized.
+        /// </summary>
+        /// <param name="type">The input type.</param>
+        /// <returns></returns>
+        internal static bool IsQuantized(this ScalarType type)
+        {
+            switch (type) {
+                case ScalarType.QInt8:
+                case ScalarType.QUInt8:
+                case ScalarType.QInt32:
+                    return true;
+                default:
+                    return false;
+            }
+        }
+
         /// <summary>
         /// Save the tensor in a .NET-specific format.
         /// </summary>
diff --git a/src/TorchSharp/Tensor/torch.PointwiseOps.cs b/src/TorchSharp/Tensor/torch.PointwiseOps.cs
index 0fccbd8ce..cc7dffbd6 100644
--- a/src/TorchSharp/Tensor/torch.PointwiseOps.cs
+++ b/src/TorchSharp/Tensor/torch.PointwiseOps.cs
@@ -761,6 +761,47 @@ public static Tensor fake_quantize_per_channel_affine(Tensor input, Tensor scale
         public static Tensor fake_quantize_per_tensor_affine(Tensor input, Tensor scale, Tensor zero_point, long quant_min, long quant_max)
             => throw new NotImplementedException();
 
+        // https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor
+        /// <summary>
+        /// Converts a float tensor to a quantized tensor with the given scale and zero point.
+        /// </summary>
+        /// <param name="input">Float tensor to quantize</param>
+        /// <param name="scale">Scale to apply in the quantization formula</param>
+        /// <param name="zero_point">Offset in integer value that maps to float zero</param>
+        /// <param name="dtype">The desired data type of the returned tensor. Must be a quantized type (torch.qint8, torch.quint8, or torch.qint32).</param>
+        /// <returns>A newly quantized tensor</returns>
+        public static Tensor quantize_per_tensor(Tensor input, double scale, long zero_point, ScalarType dtype)
+        {
+            if (!is_quantized(dtype))
+                throw new ArgumentException("dtype must be a quantized type (QInt8, QUInt8, or QInt32)", nameof(dtype));
+            return input._quantize_per_tensor(scale, zero_point, dtype);
+        }
+
+        // https://pytorch.org/docs/stable/generated/torch.quantize_per_channel
+        /// <summary>
+        /// Converts a float tensor to a per-channel quantized tensor with the given scales and zero points.
+        /// </summary>
+        /// <param name="input">Float tensor to quantize</param>
+        /// <param name="scales">Float 1D tensor of scales to use; its size should match input.size(axis)</param>
+        /// <param name="zero_points">Integer 1D tensor of offsets to use; its size should match input.size(axis)</param>
+        /// <param name="axis">Dimension on which to apply per-channel quantization</param>
+        /// <param name="dtype">The desired data type of the returned tensor. Must be a quantized type (torch.qint8, torch.quint8, or torch.qint32).</param>
+        /// <returns>A newly quantized tensor</returns>
+        public static Tensor quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, long axis, ScalarType dtype)
+        {
+            if (!is_quantized(dtype))
+                throw new ArgumentException("dtype must be a quantized type (QInt8, QUInt8, or QInt32)", nameof(dtype));
+            return input._quantize_per_channel(scales, zero_points, axis, dtype);
+        }
+
+        // https://pytorch.org/docs/stable/generated/torch.dequantize
+        /// <summary>
+        /// Returns an fp32 Tensor by dequantizing a quantized Tensor.
+        /// </summary>
+        /// <param name="input">A quantized tensor</param>
+        /// <returns>A dequantized (float) tensor</returns>
+        public static Tensor dequantize(Tensor input) => input.dequantize();
+
         // https://pytorch.org/docs/stable/generated/torch.fix
         /// <summary>
         /// Returns a new tensor with the truncated integer values of the elements of input.
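
[Reviewer note, not part of the patch: the affine scheme behind these ops is
q = clamp(round(x / scale) + zero_point, qmin, qmax), with dequantization
x' = (q - zero_point) * scale. A worked sketch with illustrative values:]

    using System;
    using TorchSharp;
    using static TorchSharp.torch;

    // x = 2.37, scale = 0.1, zero_point = 128, quint8:
    //   q  = round(2.37 / 0.1) + 128 = 24 + 128 = 152  (inside [0, 255], so no clamping)
    //   x' = (152 - 128) * 0.1 = 2.4                   (quantization error 0.03)
    var t = torch.tensor(new float[] { 2.37f });
    var q = torch.quantize_per_tensor(t, 0.1, 128, torch.quint8);
    Console.WriteLine(q.int_repr()[0].ToByte());      // 152
    Console.WriteLine(q.dequantize()[0].ToSingle());  // ~2.4
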
diff --git a/test/TorchSharpTest/TestTorchSharp.cs b/test/TorchSharpTest/TestTorchSharp.cs
index 549b8f131..b78e48b70 100644
--- a/test/TorchSharpTest/TestTorchSharp.cs
+++ b/test/TorchSharpTest/TestTorchSharp.cs
@@ -466,5 +466,223 @@ public void CheckVersionStrings()
         // Because some of the tests mess with global state, and are run in parallel, we need to
         // acquire a lock before testing setting the default RNG seed.
         private static object _lock = new object();
+
+        [Fact]
+        [TestOf(nameof(ScalarType))]
+        public void QIntScalarTypeEnumValues()
+        {
+            // Verify that the enum values match PyTorch's ScalarType ordinals.
+            Assert.Equal(12, (int)ScalarType.QInt8);
+            Assert.Equal(13, (int)ScalarType.QUInt8);
+            Assert.Equal(14, (int)ScalarType.QInt32);
+        }
+
+        [Fact]
+        [TestOf(nameof(torch.is_quantized))]
+        public void IsQuantizedScalarType()
+        {
+            // Quantized types should return true.
+            Assert.True(torch.is_quantized(ScalarType.QInt8));
+            Assert.True(torch.is_quantized(ScalarType.QUInt8));
+            Assert.True(torch.is_quantized(ScalarType.QInt32));
+
+            // Non-quantized types should return false.
+            Assert.False(torch.is_quantized(ScalarType.Float32));
+            Assert.False(torch.is_quantized(ScalarType.Float64));
+            Assert.False(torch.is_quantized(ScalarType.Int8));
+            Assert.False(torch.is_quantized(ScalarType.Int32));
+            Assert.False(torch.is_quantized(ScalarType.Bool));
+            Assert.False(torch.is_quantized(ScalarType.Byte));
+            Assert.False(torch.is_quantized(ScalarType.ComplexFloat32));
+            Assert.False(torch.is_quantized(ScalarType.BFloat16));
+        }
+
+        [Fact]
+        [TestOf(nameof(torch.qint8))]
+        public void QIntDtypeAliases()
+        {
+            // Verify that the dtype aliases map to the correct ScalarType values.
+            Assert.Equal(ScalarType.QInt8, torch.qint8);
+            Assert.Equal(ScalarType.QUInt8, torch.quint8);
+            Assert.Equal(ScalarType.QInt32, torch.qint32);
+        }
+
+        [Fact]
+        [TestOf(nameof(torch.is_quantized))]
+        public void IsQuantizedNotIntegralOrFloating()
+        {
+            // Quantized types should not be classified as integral, floating, or complex.
+            Assert.False(torch.is_integral(ScalarType.QInt8));
+            Assert.False(torch.is_integral(ScalarType.QUInt8));
+            Assert.False(torch.is_integral(ScalarType.QInt32));
+
+            Assert.False(torch.is_floating_point(ScalarType.QInt8));
+            Assert.False(torch.is_floating_point(ScalarType.QUInt8));
+            Assert.False(torch.is_floating_point(ScalarType.QInt32));
+
+            Assert.False(torch.is_complex(ScalarType.QInt8));
+            Assert.False(torch.is_complex(ScalarType.QUInt8));
+            Assert.False(torch.is_complex(ScalarType.QInt32));
+        }
+
+        [Fact]
+        [TestOf(nameof(torch.quantize_per_tensor))]
+        public void QuantizePerTensorQInt8()
+        {
+            var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f, 4.0f });
+
+            var qTensor = torch.quantize_per_tensor(floatTensor, 0.1, 0, ScalarType.QInt8);
+            Assert.True(qTensor.is_quantized());
+            Assert.Equal(ScalarType.QInt8, qTensor.dtype);
+            Assert.False(qTensor.is_floating_point());
+            Assert.False(qTensor.is_integral());
+            Assert.False(qTensor.is_complex());
+
+            qTensor.Dispose();
+            floatTensor.Dispose();
+        }
+
+        [Fact]
+        [TestOf(nameof(torch.quantize_per_tensor))]
+        public void QuantizePerTensorQUInt8()
+        {
+            var floatTensor = torch.tensor(new float[] { 0.5f, 1.5f, 2.5f });
+
+            var qTensor = torch.quantize_per_tensor(floatTensor, 0.1, 128, ScalarType.QUInt8);
+            Assert.True(qTensor.is_quantized());
+            Assert.Equal(ScalarType.QUInt8, qTensor.dtype);
+
+            qTensor.Dispose();
+            floatTensor.Dispose();
+        }
+
+        [Fact]
+        [TestOf(nameof(torch.quantize_per_tensor))]
+        public void QuantizePerTensorQInt32()
+        {
+            var floatTensor = torch.tensor(new float[] { -1.0f, 0.0f, 1.0f, 2.0f });
+
+            var qTensor = torch.quantize_per_tensor(floatTensor, 0.01, 0, ScalarType.QInt32);
+            Assert.True(qTensor.is_quantized());
+            Assert.Equal(ScalarType.QInt32, qTensor.dtype);
+
+            qTensor.Dispose();
+            floatTensor.Dispose();
+        }
+
+        [Fact]
+        [TestOf(nameof(torch.quantize_per_tensor))]
+        public void QuantizePerTensorInvalidDtypeThrows()
+        {
+            var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f });
+            Assert.Throws<ArgumentException>(() => torch.quantize_per_tensor(floatTensor, 0.1, 0, ScalarType.Float32));
+            Assert.Throws<ArgumentException>(() => torch.quantize_per_tensor(floatTensor, 0.1, 0, ScalarType.Int32));
+            floatTensor.Dispose();
+        }
+
+        [Fact]
+        [TestOf(nameof(Tensor.dequantize))]
+        public void DequantizeRoundtrip()
+        {
+            var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f });
+
+            var qTensor = torch.quantize_per_tensor(floatTensor, 1.0, 0, ScalarType.QInt8);
+            Assert.True(qTensor.is_quantized());
+
+            var dequantized = qTensor.dequantize();
+            Assert.False(dequantized.is_quantized());
+            Assert.True(dequantized.is_floating_point());
+            Assert.Equal(ScalarType.Float32, dequantized.dtype);
+
+            // With scale=1.0 and zero_point=0, values should round-trip exactly.
+            Assert.Equal(1.0f, dequantized[0].ToSingle());
+            Assert.Equal(2.0f, dequantized[1].ToSingle());
+            Assert.Equal(3.0f, dequantized[2].ToSingle());
+
+            dequantized.Dispose();
+            qTensor.Dispose();
+            floatTensor.Dispose();
+        }
+
+        [Fact]
+        [TestOf(nameof(torch.dequantize))]
+        public void DequantizeStaticMethod()
+        {
+            var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f });
+            var qTensor = torch.quantize_per_tensor(floatTensor, 1.0, 0, ScalarType.QInt8);
+
+            var dequantized = torch.dequantize(qTensor);
+            Assert.False(dequantized.is_quantized());
+            Assert.Equal(ScalarType.Float32, dequantized.dtype);
+
+            dequantized.Dispose();
+            qTensor.Dispose();
+            floatTensor.Dispose();
+        }
+
+        [Fact]
+        [TestOf(nameof(Tensor.q_scale))]
+        public void QScaleAndZeroPoint()
+        {
+            var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f });
+            double scale = 0.5;
+            long zeroPoint = 10;
+
+            var qTensor = torch.quantize_per_tensor(floatTensor, scale, zeroPoint, ScalarType.QInt8);
+            Assert.Equal(scale, qTensor.q_scale());
+            Assert.Equal(zeroPoint, qTensor.q_zero_point());
+
+            qTensor.Dispose();
+            floatTensor.Dispose();
+        }
+
+        [Fact]
+        [TestOf(nameof(Tensor.int_repr))]
+        public void IntReprReturnsUnderlyingIntegers()
+        {
+            var floatTensor = torch.tensor(new float[] { 0.0f, 1.0f, 2.0f });
+
+            // scale=1.0, zero_point=0: the quantized values should be 0, 1, 2.
+            var qTensor = torch.quantize_per_tensor(floatTensor, 1.0, 0, ScalarType.QInt8);
+            var intRepr = qTensor.int_repr();
+
+            Assert.False(intRepr.is_quantized());
+            Assert.Equal(ScalarType.Int8, intRepr.dtype);
+            Assert.Equal(0, intRepr[0].ToSByte());
+            Assert.Equal(1, intRepr[1].ToSByte());
+            Assert.Equal(2, intRepr[2].ToSByte());
+
+            intRepr.Dispose();
+            qTensor.Dispose();
+            floatTensor.Dispose();
+        }
+
+        [Fact]
+        [TestOf(nameof(torch.quantize_per_channel))]
+        public void QuantizePerChannel()
+        {
+            // Create a 2D tensor: 2 channels x 3 elements.
+            var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }).reshape(2, 3);
+            var scales = torch.tensor(new double[] { 0.1, 0.2 });
+            var zeroPoints = torch.tensor(new long[] { 0, 0 });
+
+            var qTensor = torch.quantize_per_channel(floatTensor, scales, zeroPoints, 0, ScalarType.QInt8);
+            Assert.True(qTensor.is_quantized());
+            Assert.Equal(ScalarType.QInt8, qTensor.dtype);
+
+            // Verify the per-channel quantization parameters.
+            var channelScales = qTensor.q_per_channel_scales();
+            var channelZeroPoints = qTensor.q_per_channel_zero_points();
+            Assert.Equal(0, qTensor.q_per_channel_axis());
+            Assert.Equal(0.1, channelScales[0].ToDouble(), 5);
+            Assert.Equal(0.2, channelScales[1].ToDouble(), 5);
+
+            channelScales.Dispose();
+            channelZeroPoints.Dispose();
+            qTensor.Dispose();
+            scales.Dispose();
+            zeroPoints.Dispose();
+            floatTensor.Dispose();
+        }
     }
 }