45 changes: 45 additions & 0 deletions src/Native/LibTorchSharp/THSTensor.cpp
@@ -2278,3 +2278,48 @@ Tensor THSTensor_unflatten_names(Tensor tensor, const char** names, const int64_

    return nullptr;
}

Tensor THSTensor_quantize_per_tensor(const Tensor tensor, double scale, int64_t zero_point, int8_t scalar_type)
{
    CATCH_TENSOR(torch::quantize_per_tensor(*tensor, scale, zero_point, at::ScalarType(scalar_type)));
}

Tensor THSTensor_quantize_per_channel(const Tensor tensor, const Tensor scales, const Tensor zero_points, int64_t axis, int8_t scalar_type)
{
    CATCH_TENSOR(torch::quantize_per_channel(*tensor, *scales, *zero_points, axis, at::ScalarType(scalar_type)));
}

Tensor THSTensor_dequantize(const Tensor tensor)
{
    CATCH_TENSOR(tensor->dequantize());
}

double THSTensor_q_scale(const Tensor tensor)
{
    CATCH_RETURN(double, 0.0, tensor->q_scale());
}

int64_t THSTensor_q_zero_point(const Tensor tensor)
{
    CATCH_RETURN(int64_t, 0, tensor->q_zero_point());
}

Tensor THSTensor_int_repr(const Tensor tensor)
{
    CATCH_TENSOR(tensor->int_repr());
}

Tensor THSTensor_q_per_channel_scales(const Tensor tensor)
{
    CATCH_TENSOR(tensor->q_per_channel_scales());
}

Tensor THSTensor_q_per_channel_zero_points(const Tensor tensor)
{
    CATCH_TENSOR(tensor->q_per_channel_zero_points());
}

int64_t THSTensor_q_per_channel_axis(const Tensor tensor)
{
    CATCH_RETURN(int64_t, 0, tensor->q_per_channel_axis());
}
12 changes: 12 additions & 0 deletions src/Native/LibTorchSharp/THSTensor.h
@@ -1790,3 +1790,15 @@ EXPORT_API(Tensor) THSTensor_kaiser_window(const int64_t len, bool periodic, dou

EXPORT_API(Tensor) THSTensor_stft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool normalized, int64_t onesided, bool return_complex);
EXPORT_API(Tensor) THSTensor_istft(const Tensor x, int64_t n_fft, int64_t hop_length, int64_t win_length, const Tensor window, bool center, bool normalized, int64_t onesided, int64_t length, bool return_complex);

// Quantization Ops

EXPORT_API(Tensor) THSTensor_quantize_per_tensor(const Tensor tensor, double scale, int64_t zero_point, int8_t scalar_type);
EXPORT_API(Tensor) THSTensor_quantize_per_channel(const Tensor tensor, const Tensor scales, const Tensor zero_points, int64_t axis, int8_t scalar_type);
EXPORT_API(Tensor) THSTensor_dequantize(const Tensor tensor);
EXPORT_API(double) THSTensor_q_scale(const Tensor tensor);
EXPORT_API(int64_t) THSTensor_q_zero_point(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_int_repr(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_q_per_channel_scales(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_q_per_channel_zero_points(const Tensor tensor);
EXPORT_API(int64_t) THSTensor_q_per_channel_axis(const Tensor tensor);
27 changes: 27 additions & 0 deletions src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
@@ -2176,6 +2176,33 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
        internal static extern IntPtr THSTensor_histogram_out_t(IntPtr input, IntPtr bins, IntPtr weight, bool density, out IntPtr hist, out IntPtr bin_edges, out IntPtr r_bin_edges);
        [DllImport("LibTorchSharp")]
        internal static extern IntPtr THSTensor_histogram_out_i(IntPtr input, long bins, IntPtr range, int length, IntPtr weight, bool density, out IntPtr hist, out IntPtr bin_edges, out IntPtr r_bin_edges);

        [DllImport("LibTorchSharp")]
        internal static extern IntPtr THSTensor_quantize_per_tensor(IntPtr tensor, double scale, long zero_point, sbyte scalar_type);

        [DllImport("LibTorchSharp")]
        internal static extern IntPtr THSTensor_quantize_per_channel(IntPtr tensor, IntPtr scales, IntPtr zero_points, long axis, sbyte scalar_type);

        [DllImport("LibTorchSharp")]
        internal static extern IntPtr THSTensor_dequantize(IntPtr tensor);

        [DllImport("LibTorchSharp")]
        internal static extern double THSTensor_q_scale(IntPtr tensor);

        [DllImport("LibTorchSharp")]
        internal static extern long THSTensor_q_zero_point(IntPtr tensor);

        [DllImport("LibTorchSharp")]
        internal static extern IntPtr THSTensor_int_repr(IntPtr tensor);

        [DllImport("LibTorchSharp")]
        internal static extern IntPtr THSTensor_q_per_channel_scales(IntPtr tensor);

        [DllImport("LibTorchSharp")]
        internal static extern IntPtr THSTensor_q_per_channel_zero_points(IntPtr tensor);

        [DllImport("LibTorchSharp")]
        internal static extern long THSTensor_q_per_channel_axis(IntPtr tensor);
    }
#pragma warning restore CA2101
}
111 changes: 108 additions & 3 deletions src/TorchSharp/Tensor/Tensor.cs
@@ -271,6 +271,95 @@ internal IntPtr MoveHandle()
        /// </summary>
        public bool is_complex() => torch.is_complex(dtype);

        /// <summary>
        /// Returns True if the data type of input is a quantized data type, i.e. one of torch.qint8, torch.quint8, or torch.qint32.
        /// </summary>
        public bool is_quantized() => torch.is_quantized(dtype);

        /// <summary>
        /// Given a quantized Tensor, returns a dequantized (float) Tensor.
        /// </summary>
        public Tensor dequantize()
        {
            var res = NativeMethods.THSTensor_dequantize(Handle);
            if (res == IntPtr.Zero) { CheckForErrors(); }
            return new Tensor(res);
        }

        /// <summary>
        /// Given a quantized Tensor, returns the scale of the quantization as a double.
        /// </summary>
        public double q_scale()
        {
            var res = NativeMethods.THSTensor_q_scale(Handle);
            CheckForErrors();
            return res;
        }

        /// <summary>
        /// Given a quantized Tensor, returns the zero_point of the quantization as a long.
        /// </summary>
        public long q_zero_point()
        {
            var res = NativeMethods.THSTensor_q_zero_point(Handle);
            CheckForErrors();
            return res;
        }

        /// <summary>
        /// Given a quantized Tensor, returns a Tensor of the underlying integer representation.
        /// </summary>
        public Tensor int_repr()
        {
            var res = NativeMethods.THSTensor_int_repr(Handle);
            if (res == IntPtr.Zero) { CheckForErrors(); }
            return new Tensor(res);
        }

        /// <summary>
        /// Given a Tensor quantized per channel, returns a Tensor of the quantization scales for each channel.
        /// </summary>
        public Tensor q_per_channel_scales()
        {
            var res = NativeMethods.THSTensor_q_per_channel_scales(Handle);
            if (res == IntPtr.Zero) { CheckForErrors(); }
            return new Tensor(res);
        }

        /// <summary>
        /// Given a Tensor quantized per channel, returns a Tensor of the quantization zero points for each channel.
        /// </summary>
        public Tensor q_per_channel_zero_points()
        {
            var res = NativeMethods.THSTensor_q_per_channel_zero_points(Handle);
            if (res == IntPtr.Zero) { CheckForErrors(); }
            return new Tensor(res);
        }

        /// <summary>
        /// Given a Tensor quantized per channel, returns the axis along which per-channel quantization is applied.
        /// </summary>
        public long q_per_channel_axis()
        {
            var res = NativeMethods.THSTensor_q_per_channel_axis(Handle);
            CheckForErrors();
            return res;
        }

        internal Tensor _quantize_per_tensor(double scale, long zero_point, ScalarType dtype)
        {
            var res = NativeMethods.THSTensor_quantize_per_tensor(Handle, scale, zero_point, (sbyte)dtype);
            if (res == IntPtr.Zero) { CheckForErrors(); }
            return new Tensor(res);
        }

        internal Tensor _quantize_per_channel(Tensor scales, Tensor zero_points, long axis, ScalarType dtype)
        {
            var res = NativeMethods.THSTensor_quantize_per_channel(Handle, scales.Handle, zero_points.Handle, axis, (sbyte)dtype);
            if (res == IntPtr.Zero) { CheckForErrors(); }
            return new Tensor(res);
        }
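Taken together, a minimal round-trip sketch (illustrative, not part of the diff) of how these instance methods compose with the `torch.quantize_per_tensor` wrapper added later in this PR:

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

var x = tensor(new float[] { 0.0f, 0.5f, 1.0f });

// Quantize with scale 0.01 and zero_point 10 into signed 8-bit storage.
var q = quantize_per_tensor(x, 0.01, 10, ScalarType.QInt8);

Console.WriteLine(q.q_scale());       // 0.01
Console.WriteLine(q.q_zero_point());  // 10

var raw  = q.int_repr();    // int8 storage: round(x / 0.01) + 10, i.e. 10, 60, 110
var back = q.dequantize();  // float tensor, approximately 0.0, 0.5, 1.0
```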

        /// <summary>
        /// Returns True if the input is a single element tensor which is not equal to zero after type conversions,
        /// i.e. not equal to torch.tensor([0.]) or torch.tensor([0]) or torch.tensor([False]).
@@ -7279,9 +7368,9 @@ public enum ScalarType : sbyte
        ComplexFloat32 = 9,
        ComplexFloat64 = 10,
        Bool = 11,
        //QInt8 = 12,
        //QUInt8 = 13,
        //QUInt32 = 14,
        QInt8 = 12,
        QUInt8 = 13,
        QInt32 = 14,
Comment on lines +7371 to +7373
Copilot AI commented on Feb 19, 2026:
ScalarType now includes quantized dtypes (QInt8/QUInt8/QInt32), but ScalarType.ElementSize() (in TensorExtensionMethods.cs) doesn’t handle these and will currently throw NotImplementedException. Add element sizes for the new quantized types (qint8/quint8 = 1 byte, qint32 = 4 bytes) so existing utilities continue to work with quantized tensors.
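A possible shape for that fix, assuming ElementSize() uses the same switch pattern as the other helpers in TensorExtensionMethods.cs (hypothetical, not in this diff):

```csharp
// Hypothetical cases to add to ScalarType.ElementSize():
case ScalarType.QInt8:
case ScalarType.QUInt8:
    return 1;  // 8-bit quantized storage
case ScalarType.QInt32:
    return 4;  // 32-bit quantized storage
```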

        BFloat16 = 15
    }

@@ -7413,6 +7502,18 @@ public static bool is_complex(ScalarType type)
            }
        }

        public static bool is_quantized(ScalarType type)
        {
            switch (type) {
                case ScalarType.QInt8:
                case ScalarType.QUInt8:
                case ScalarType.QInt32:
                    return true;
                default:
                    return false;
            }
        }
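Illustratively, with the dtype aliases added a few lines below:

```csharp
bool a = torch.is_quantized(torch.qint8);    // true
bool b = torch.is_quantized(torch.float32);  // false
```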

        public static long max_int_value(ScalarType type)
        {
            switch (type) {
@@ -7463,6 +7564,10 @@ public static long max_int_value(ScalarType type)
        public static ScalarType cfloat = ScalarType.ComplexFloat32;
        public static ScalarType cdouble = ScalarType.ComplexFloat64;

        public static ScalarType qint8 = ScalarType.QInt8;
        public static ScalarType quint8 = ScalarType.QUInt8;
        public static ScalarType qint32 = ScalarType.QInt32;

        /// <summary>
        /// Creates a new dispose scope for the current thread. Any tensor created within the dispose scope will
        /// be automatically disposed once the dispose scope is disposed.
17 changes: 17 additions & 0 deletions src/TorchSharp/Tensor/TensorExtensionMethods.cs
@@ -368,6 +368,23 @@ internal static bool IsComplex(this ScalarType type)
            }
        }

        /// <summary>
        /// Indicates whether a given element type is quantized.
        /// </summary>
        /// <param name="type">The input type.</param>
        /// <returns>True if the type is one of the quantized element types; otherwise, false.</returns>
        internal static bool IsQuantized(this ScalarType type)
        {
            switch (type) {
                case ScalarType.QInt8:
                case ScalarType.QUInt8:
                case ScalarType.QInt32:
                    return true;
                default:
                    return false;
            }
        }

        /// <summary>
        /// Save the tensor in a .NET-specific format.
        /// </summary>
41 changes: 41 additions & 0 deletions src/TorchSharp/Tensor/torch.PointwiseOps.cs
@@ -761,6 +761,47 @@ public static Tensor fake_quantize_per_channel_affine(Tensor input, Tensor scale
        public static Tensor fake_quantize_per_tensor_affine(Tensor input, Tensor scale, Tensor zero_point, long quant_min, long quant_max)
            => throw new NotImplementedException();

        // https://pytorch.org/docs/stable/generated/torch.quantize_per_tensor
        /// <summary>
        /// Converts a float tensor to a quantized tensor with the given scale and zero point.
        /// </summary>
        /// <param name="input">Float tensor to quantize</param>
        /// <param name="scale">Scale to apply in the quantization formula</param>
        /// <param name="zero_point">Offset in integer value that maps to float zero</param>
        /// <param name="dtype">The desired data type of the returned tensor. Must be a quantized type (torch.qint8, torch.quint8, or torch.qint32).</param>
        /// <returns>A newly quantized tensor</returns>
        public static Tensor quantize_per_tensor(Tensor input, double scale, long zero_point, ScalarType dtype)
        {
            if (!is_quantized(dtype))
                throw new ArgumentException("dtype must be a quantized type (QInt8, QUInt8, or QInt32)", nameof(dtype));
            return input._quantize_per_tensor(scale, zero_point, dtype);
        }
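Per-tensor quantization follows PyTorch's affine scheme, q = clamp(round(x / scale) + zero_point, qmin, qmax). A small illustrative sketch:

```csharp
// With scale = 0.1 and zero_point = 128 in quint8, 2.5 maps to round(2.5 / 0.1) + 128 = 153.
var q = torch.quantize_per_tensor(torch.tensor(new float[] { 2.5f }), 0.1, 128, torch.quint8);
var raw = q.int_repr();  // uint8 storage containing 153
```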

        // https://pytorch.org/docs/stable/generated/torch.quantize_per_channel
        /// <summary>
        /// Converts a float tensor to a per-channel quantized tensor with the given scales and zero points.
        /// </summary>
        /// <param name="input">Float tensor to quantize</param>
        /// <param name="scales">Float 1D tensor of scales to use; its size should match input.size(axis)</param>
        /// <param name="zero_points">Integer 1D tensor of offsets to use; its size should match input.size(axis)</param>
        /// <param name="axis">Dimension on which to apply per-channel quantization</param>
        /// <param name="dtype">The desired data type of the returned tensor. Must be a quantized type (torch.qint8, torch.quint8, or torch.qint32).</param>
        /// <returns>A newly quantized tensor</returns>
        public static Tensor quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, long axis, ScalarType dtype)
        {
            if (!is_quantized(dtype))
                throw new ArgumentException("dtype must be a quantized type (QInt8, QUInt8, or QInt32)", nameof(dtype));
            return input._quantize_per_channel(scales, zero_points, axis, dtype);
        }
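A per-channel sketch (illustrative values), quantizing each row of a 2x2 input with its own scale:

```csharp
var x = torch.tensor(new float[,] { { 1.0f, 2.0f }, { 10.0f, 20.0f } });
var scales = torch.tensor(new double[] { 0.1, 1.0 });  // one scale per row (axis 0)
var zeros  = torch.tensor(new long[] { 0L, 0L });      // one zero point per row
var q = torch.quantize_per_channel(x, scales, zeros, 0, torch.qint8);

var axis = q.q_per_channel_axis();    // 0
var s    = q.q_per_channel_scales();  // tensor holding 0.1, 1.0
var raw  = q.int_repr();              // rows quantized independently: 10, 20 and 10, 20
```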

        // https://pytorch.org/docs/stable/generated/torch.dequantize
        /// <summary>
        /// Returns an fp32 Tensor by dequantizing a quantized Tensor.
        /// </summary>
        /// <param name="input">A quantized tensor</param>
        /// <returns>A dequantized (float) tensor</returns>
        public static Tensor dequantize(Tensor input) => input.dequantize();
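Dequantization inverts the affine mapping, x ~ (q - zero_point) * scale, so values come back only to the resolution of the scale:

```csharp
var x = torch.tensor(new float[] { 0.123f, 0.456f });
var q = torch.quantize_per_tensor(x, 0.01, 0, torch.qint8);
var y = torch.dequantize(q);  // approximately 0.12 and 0.46; resolution is limited by the scale
```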

        // https://pytorch.org/docs/stable/generated/torch.fix
        /// <summary>
        /// Returns a new tensor with the truncated integer values of the elements of input.