diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index f3147bdf78aa..1d9f9d9d2a12 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -290,6 +290,27 @@ Tensor index(const Tensor & self, TensorList indices) {
   return iter.output();
 }
 
+Tensor quantized_index(const Tensor & self, TensorList indices) {
+  TORCH_INTERNAL_ASSERT(
+      self.qscheme() == c10::kPerTensorAffine ||
+      self.qscheme() == c10::kPerTensorSymmetric,
+      "Indexing is only supported for per-Tensor quantized Tensors.");
+
+  // For now, this is a naive implementation which does dq -> index -> q.
+  // TODO(future PR): improve performance by removing the copies.
+  const auto& self_dq = self.dequantize();
+
+  TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
+
+  auto info = make_info(self_dq, indices);
+  auto iter = make_index_iterator(info);
+  index_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides);
+  at::Tensor res = iter.output();
+
+  return at::quantize_per_tensor(
+      res, self.q_scale(), self.q_zero_point(), self.scalar_type());
+}
+
 Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) {
   TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   at::assert_no_internal_overlap(result);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 09b7c5f7e762..715fdccc9691 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2200,6 +2200,7 @@
   variants: function, method
   dispatch:
     CPU, CUDA: index
+    QuantizedCPU: quantized_index
 # NB: This function is special-cased in tools/autograd/gen_variable_type.py
 # NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp:
 # - Tensor Tensor::index(ArrayRef indices)
diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h
index 347df066cc90..34e17c37f774 100644
--- a/c10/core/TensorOptions.h
+++ b/c10/core/TensorOptions.h
@@ -691,6 +691,8 @@ inline DeviceType computeDeviceType(DispatchKey tid) {
     return DeviceType::Vulkan;
   } else if (tid == DispatchKey::Metal) {
     return DeviceType::Metal;
+  } else if (tid == DispatchKey::QuantizedCPU) {
+    return DeviceType::CPU;
   } else {
     AT_ASSERTM(false, "Unknown DispatchKey: ", tid);
   }
diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py
index 1c66c8fb986f..f1e52fc38d32 100644
--- a/test/quantization/test_quantized_op.py
+++ b/test/quantization/test_quantized_op.py
@@ -2274,6 +2274,45 @@ def test_empty_batch(self):
         result = torch.ops.quantized.linear_dynamic(X, w_packed)
         self.assertEqual(result.shape, (0, 2))
 
+    def test_advanced_indexing(self):
+        """
+        Verifies that the x[:, [0], :, :] syntax works for quantized tensors.
+        """
+        for dtype in (torch.qint8, torch.quint8, torch.qint32):
+            scale = 0.1
+            zp = 0
+            x_q = torch.quantize_per_tensor(
+                torch.randn(1, 4, 4, 4), scale, zp, dtype)
+            # reference
+            x_fp32 = x_q.dequantize()
+
+            # single dim, single index
+            x_q_s1 = x_q[:, [0], :, :]
+            x_fp32_s1 = x_fp32[:, [0], :, :]
+            x_fp32_s1_ref = \
+                torch.quantize_per_tensor(x_fp32_s1, scale, zp, dtype)
+            self.assertEqual(x_q_s1, x_fp32_s1_ref)
+
+            # multiple dim, single index
+            x_q_s2 = x_q[:, [0], [2], :]
+            x_fp32_s2 = x_fp32[:, [0], [2], :]
+            x_fp32_s2_ref = \
+                torch.quantize_per_tensor(x_fp32_s2, scale, zp, dtype)
+            self.assertEqual(x_q_s2, x_fp32_s2_ref)
+
+            # single dim, multiple indices
+            x_q_s3 = x_q[:, [2, 0, 1], :, :]
+            x_fp32_s3 = x_fp32[:, [2, 0, 1], :, :]
+            x_fp32_s3_ref = \
+                torch.quantize_per_tensor(x_fp32_s3, scale, zp, dtype)
+            self.assertEqual(x_q_s3, x_fp32_s3_ref)
+
+            # multiple dim, multiple indices
+            x_q_s4 = x_q[:, [2, 0, 1], :, [1]]
+            x_fp32_s4 = x_fp32[:, [2, 0, 1], :, [1]]
+            x_fp32_s4_ref = \
+                torch.quantize_per_tensor(x_fp32_s4, scale, zp, dtype)
+            self.assertEqual(x_q_s4, x_fp32_s4_ref)
 
 
 class TestDynamicQuantizedLinear(TestCase):