Add QInt8, QUInt8, QInt32 quantized scalar types with full quantization support #1531
alinpahontu2912 wants to merge 1 commit into dotnet:main
Conversation
Pull request overview
This pull request adds comprehensive quantization support to TorchSharp, enabling the use of quantized scalar types (QInt8, QUInt8, QInt32) and associated quantization operations. The changes align TorchSharp's quantization API with PyTorch's native quantization capabilities.
Changes:
- Adds three quantized scalar types (QInt8, QUInt8, QInt32) to the ScalarType enum with correct ordinal values matching PyTorch
- Implements full quantization API including per-tensor and per-channel quantization/dequantization operations with corresponding native bindings
- Adds 13 comprehensive unit tests covering all new functionality including edge cases and error handling
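As a rough illustration of the managed API this PR describes, the per-tensor path might be used like the sketch below. Names (`quantize_per_tensor`, `dequantize`, `ScalarType.QInt8`, `is_quantized`) are taken from the PR description; the exact parameter names and signatures are an assumption, not copied from the diff.

```csharp
using TorchSharp;
using static TorchSharp.torch;

// A float tensor to quantize.
var x = torch.tensor(new float[] { 0.5f, 1.0f, 1.5f });

// Quantize with one scale/zero-point pair for the whole tensor
// (signature assumed from the PR description).
var q = torch.quantize_per_tensor(x, scale: 0.1, zero_point: 0, dtype: ScalarType.QInt8);

Console.WriteLine(q.is_quantized);  // the new type-checking property

// Round-trip back to float; values are approximate due to quantization.
var back = torch.dequantize(q);
```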
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| test/TorchSharpTest/TestTorchSharp.cs | Adds 13 comprehensive unit tests for quantization operations covering enum values, dtype aliases, type checking, quantization/dequantization, and per-channel operations |
| src/TorchSharp/Tensor/torch.PointwiseOps.cs | Adds public API methods for quantize_per_tensor, quantize_per_channel, and dequantize with proper validation and documentation |
| src/TorchSharp/Tensor/TensorExtensionMethods.cs | Adds IsQuantized extension method for ScalarType to support internal type checking |
| src/TorchSharp/Tensor/Tensor.cs | Adds quantized scalar type enum values, dtype aliases, is_quantized checking, and Tensor instance methods for quantization operations (dequantize, q_scale, q_zero_point, int_repr, q_per_channel_*) |
| src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs | Adds P/Invoke declarations for 9 new native quantization functions with proper marshalling signatures |
| src/Native/LibTorchSharp/THSTensor.h | Adds C++ function declarations for quantization operations following existing naming conventions |
| src/Native/LibTorchSharp/THSTensor.cpp | Implements native quantization bindings using CATCH_TENSOR and CATCH_RETURN macros for consistent error handling |
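For context on the P/Invoke layer mentioned in the table, a declaration for one of the nine new native functions would plausibly look like the following. This is a hypothetical sketch following common TorchSharp conventions; the actual parameter types and class layout in `LibTorchSharp.THSTensor.cs` may differ.

```csharp
using System;
using System.Runtime.InteropServices;

internal static partial class NativeMethods
{
    // Hypothetical declaration; the real signature in the PR may differ.
    [DllImport("LibTorchSharp")]
    internal static extern IntPtr THSTensor_quantize_per_tensor(
        IntPtr tensor, double scale, long zeroPoint, sbyte scalarType);
}
```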
Add QInt8, QUInt8, QInt32 quantized scalar types with full quantization support

- Uncomment and fix ScalarType enum entries (QInt8=12, QUInt8=13, QInt32=14)
- Fix incorrect QUInt32 name to QInt32 to match PyTorch
- Add torch.qint8, torch.quint8, torch.qint32 dtype aliases
- Add torch.is_quantized() and Tensor.is_quantized() methods
- Add IsQuantized() extension method on ScalarType

Native quantization bindings:
- Add THSTensor_quantize_per_tensor, THSTensor_quantize_per_channel (C++ and P/Invoke)
- Add THSTensor_dequantize (C++ and P/Invoke)
- Add THSTensor_q_scale, THSTensor_q_zero_point (C++ and P/Invoke)
- Add THSTensor_int_repr (C++ and P/Invoke)
- Add THSTensor_q_per_channel_scales, THSTensor_q_per_channel_zero_points, THSTensor_q_per_channel_axis (C++ and P/Invoke)

Managed API:
- Add torch.quantize_per_tensor() and torch.quantize_per_channel() static methods
- Add torch.dequantize() static method
- Add Tensor.dequantize(), Tensor.q_scale(), Tensor.q_zero_point() instance methods
- Add Tensor.int_repr() instance method
- Add Tensor.q_per_channel_scales(), Tensor.q_per_channel_zero_points(), Tensor.q_per_channel_axis() instance methods

Unit tests for all new functionality (13 tests)
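The per-channel path and the inspection helpers listed above can be sketched as follows. This is illustrative only: the method names come from the commit message, while parameter names and exact signatures are assumptions.

```csharp
using TorchSharp;
using static TorchSharp.torch;

// A 2x3 weight tensor: 2 channels along axis 0.
var w = torch.tensor(new float[] { 1f, 2f, 3f, 4f, 5f, 6f }).reshape(2, 3);

// One scale and one zero point per channel (assumed signature).
var scales = torch.tensor(new double[] { 0.1, 0.2 });
var zeroPoints = torch.tensor(new long[] { 0L, 0L });
var qw = torch.quantize_per_channel(w, scales, zeroPoints, axis: 0, dtype: ScalarType.QUInt8);

// The new Tensor inspection methods from the commit message:
var s = qw.q_per_channel_scales();       // per-channel scales
var z = qw.q_per_channel_zero_points();  // per-channel zero points
var a = qw.q_per_channel_axis();         // the quantization axis (0 here)
var i = qw.int_repr();                   // underlying integer representation
```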
Force-pushed from a103b80 to 54b5302
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
    QInt8 = 12,
    QUInt8 = 13,
    QInt32 = 14,
ScalarType now includes quantized dtypes (QInt8/QUInt8/QInt32), but ScalarType.ElementSize() (in TensorExtensionMethods.cs) doesn’t handle these and will currently throw NotImplementedException. Add element sizes for the new quantized types (qint8/quint8 = 1 byte, qint32 = 4 bytes) so existing utilities continue to work with quantized tensors.
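The suggested fix above amounts to extending the `ElementSize()` switch with three new arms. A minimal sketch, using the byte widths the reviewer gives (the surrounding switch arms are assumed, not copied from `TensorExtensionMethods.cs`):

```csharp
// Sketch of the suggested ElementSize() fix; existing arms elided.
internal static int ElementSize(this ScalarType type) => type switch
{
    ScalarType.QInt8 => 1,   // 8-bit signed quantized
    ScalarType.QUInt8 => 1,  // 8-bit unsigned quantized
    ScalarType.QInt32 => 4,  // 32-bit signed quantized
    // ... existing cases for Byte, Int32, Float32, etc. ...
    _ => throw new NotImplementedException($"Element size of {type}")
};
```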
public void QuantizePerChannel()
{
    // Create a 2D tensor: 2 channels x 3 elements
    var floatTensor = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }).reshape(2, 3);
This test creates an extra intermediate tensor via torch.tensor(...).reshape(2, 3); the original 1D tensor returned by torch.tensor(...) is no longer referenced and won’t be deterministically disposed. Split the calls (or use a dispose scope / using var) so both the base tensor and the reshaped view are properly disposed.
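Two ways the review comment could be addressed, sketched below. `torch.NewDisposeScope()` is TorchSharp's existing dispose-scope mechanism; whether the PR adopts this exact fix is an assumption.

```csharp
// Option 1: keep a handle to the base tensor so both it and the
// reshaped view are deterministically disposed.
using var flat = torch.tensor(new float[] { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f });
using var floatTensor = flat.reshape(2, 3);

// Option 2: wrap the test body in a dispose scope so every tensor
// created inside, including the unreferenced intermediate, is disposed.
using (torch.NewDisposeScope())
{
    var t = torch.tensor(new float[] { 1f, 2f, 3f, 4f, 5f, 6f }).reshape(2, 3);
    // ... assertions ...
} // all tensors created in the scope are disposed here
```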
Fixes #141