From efacc46543fecdf4b513d9f55ee73cdb0d602108 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Thu, 2 Nov 2023 16:15:44 +0800
Subject: [PATCH] Enable onednn.QConv FP32/BF16 output

ghstack-source-id: af28fba9e9a2d6a03f9efe18496a9c173143ecd2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112010
---
 aten/src/ATen/native/quantized/cpu/qconv.cpp  |  61 +++++---
 aten/src/ATen/native/quantized/library.cpp    |   8 +-
 .../check_forward_backward_compatibility.py   |   4 +
 test/quantization/core/test_quantized_op.py   | 133 ++++++++++++++----
 torch/_inductor/fx_passes/quantization.py     |  23 +--
 torch/_inductor/ir.py                         |  33 ++---
 torch/_inductor/lowering.py                   |   8 +-
 torch/_meta_registrations.py                  |   5 +-
 8 files changed, 191 insertions(+), 84 deletions(-)

diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp
index 05c71eb4ef254..ce1cdd4d6169a 100644
--- a/aten/src/ATen/native/quantized/cpu/qconv.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp
@@ -1388,7 +1388,7 @@ static at::Tensor _quantized_convolution_onednn(
     c10::optional<at::Tensor> accum, // accum to fused with conv add
     double accum_scale,
     int64_t accum_zero_point,
-    bool fp32_output,
+    c10::optional<c10::ScalarType> output_dtype,
     c10::optional<c10::string_view> binary_attr,
     c10::optional<at::Scalar> binary_alpha,
     c10::optional<c10::string_view> unary_attr,
@@ -1402,13 +1402,15 @@ static at::Tensor _quantized_convolution_onednn(
   // inv_scale = 1.0 / scale will be folded.
   // So, we can only get inv_scale from quant node which is used as
   // output_scale of this op.
-  if (fp32_output) {
-    // When fp32_output, oneDNN expects op_attr doesn't set_scales and set_zero_points.
+  bool fp32_output = output_dtype.has_value() && (output_dtype.value() == c10::kFloat);
+  bool bfloat16_output = output_dtype.has_value() && (output_dtype.value() == c10::kBFloat16);
+  if (fp32_output || bfloat16_output) {
+    // When fp32 or bf16 output, oneDNN expects op_attr doesn't set_scales and set_zero_points.
     // So, we will use default inv_output_scale as 1.0 and output_zero_point as 0, since
     // when inv_output_scale is 1.0, we will skip invoking of op_attr.set_scales in ideep;
     // when output_zero_point is 0, we will skip invoking of op_attr.set_zero_points in ideep.
-    TORCH_CHECK(inv_output_scale == 1.0, " (ONEDNN): fp32 output, inv_output_scale must be 1.0.");
-    TORCH_CHECK(output_zero_point == 0, " (ONEDNN): fp32 output, output_zero_point must be 0");
+    TORCH_CHECK(inv_output_scale == 1.0, " (ONEDNN): fp32 or bf16 output, inv_output_scale must be 1.0.");
+    TORCH_CHECK(output_zero_point == 0, " (ONEDNN): fp32 or bf16 output, output_zero_point must be 0");
   }
 
   int kSpatialDim = act.dim() - 2;
@@ -1417,7 +1419,14 @@ static at::Tensor _quantized_convolution_onednn(
   bool has_binary_post_op = binary_attr.has_value() && binary_attr.value() != "none";
   bool has_unary_post_op = unary_attr.has_value() && unary_attr.value() != "none";
   // has_accum_postop_sum: extra input besides the conv to do conv add fusion with post op sum.
-  bool has_accum_postop_sum = has_binary_post_op && binary_attr.value() == "add" && !fp32_output;
+  bool has_accum_postop_sum = has_binary_post_op && binary_attr.value() == "add";
+
+  if (has_accum_postop_sum && (fp32_output || bfloat16_output)) {
+    TORCH_CHECK(accum_scale == 1.0, " (ONEDNN): fp32 or bf16 output, accum_scale must be 1.0.");
+    TORCH_CHECK(accum_zero_point == 0, " (ONEDNN): fp32 or bf16 output, accum_zero_point must be 0");
+    TORCH_CHECK((accum.value().scalar_type() == c10::kFloat) || (accum.value().scalar_type() == c10::kBFloat16), "The accum tensor should be KFloat or KBFloat.");
+  }
+
   std::string func_name = "quantized::packed_weights_conv";
   func_name += std::to_string(kSpatialDim) + "d";
   if (has_binary_post_op) {
@@ -1523,14 +1532,17 @@ static at::Tensor _quantized_convolution_onednn(
   ideep::tensor onednn_bias;
   const int output_channels = weight.size(0);
   bool with_bias = bias.has_value();
+
+  at::Tensor bias_val_float;
   if (with_bias) {
-    at::Tensor bias_val = bias.value();
-    TORCH_CHECK(bias_val.dim() == 1, "bias should be a vector (1D Tensor)");
+    // For int8-mixed-bf16, we will also use float32 bias
+    bias_val_float = bias.value().to(at::kFloat);
+    TORCH_CHECK(bias_val_float.dim() == 1, "bias should be a vector (1D Tensor)");
     TORCH_CHECK(
-        bias_val.size(0) == output_channels,
+        bias_val_float.size(0) == output_channels,
         "bias should have K elements: " + std::to_string(output_channels));
-    auto bias_desc = ideep::tensor::desc(bias.value().sizes().vec(), dnnl::memory::data_type::f32);
-    onednn_bias.init(bias_desc, bias.value().data_ptr());
+    auto bias_desc = ideep::tensor::desc(bias_val_float.sizes().vec(), dnnl::memory::data_type::f32);
+    onednn_bias.init(bias_desc, bias_val_float.data_ptr());
   }
 
   const auto& expected_bias = with_bias ? onednn_bias : ideep::tensor();
@@ -1556,11 +1568,11 @@ static at::Tensor _quantized_convolution_onednn(
   ideep::dims dst_dims = ideep::dims({output_sizes.cbegin(), output_sizes.cend()});
   // Output is not a quantized tensor but data type is uint8
   at::Tensor output;
-  if (fp32_output) {
+  if (fp32_output || bfloat16_output) {
     output = at::empty(
       dst_dims,
       device(c10::kCPU)
-          .dtype(c10::kFloat)
+          .dtype(fp32_output ? c10::kFloat : c10::kBFloat16)
           .memory_format(kSpatialDim == 2 ?
               c10::MemoryFormat::ChannelsLast :
               c10::MemoryFormat::ChannelsLast3d),
@@ -1581,16 +1593,25 @@ static at::Tensor _quantized_convolution_onednn(
   ideep::tensor dst;
   at::Tensor accum_contig;
   if (has_accum_postop_sum) {
-    auto dst_desc = ideep::tensor::desc(dst_dims, src_data_type,
+    auto dst_desc = ideep::tensor::desc(dst_dims, fp32_output ? ideep::tensor::data_type::f32 : (
+        bfloat16_output ? ideep::tensor::data_type::bf16 : src_data_type),
         kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
     accum_contig = accum.value().contiguous(kSpatialDim == 2 ? c10::MemoryFormat::ChannelsLast : c10::MemoryFormat::ChannelsLast3d);
+    if (fp32_output || bfloat16_output) {
+      TORCH_CHECK((output.scalar_type() == c10::kFloat) || (output.scalar_type() == c10::kBFloat16), "The output tensor should be KFloat or KBFloat.");
+      if (accum_contig.scalar_type() != output.scalar_type()) {
+        // accum_contig is KFloat32 and we expect a kBFloat16 output
+        // or accum_contig is kBFloat16 and we expect a KFloat32 output
+        accum_contig = accum_contig.to(output.scalar_type());
+      }
+    }
     TORCH_CHECK(accum_contig.dtype() == output.dtype(), "The output tensor should have same dtype as the accum tensor.");
     // When fused with sum, the dst tensor will share the data ptr as the accum tensor.
     dst.init(dst_desc, accum_contig.data_ptr());
   } else {
-    if (fp32_output) {
+    if (fp32_output || bfloat16_output) {
       // Conv without add: int8-in, fp32-output
-      dst = ideep::tensor({dst_dims, ideep::tensor::data_type::f32, {output.strides().cbegin(), output.strides().cend()}},
+      dst = ideep::tensor({dst_dims, fp32_output ? ideep::tensor::data_type::f32 : ideep::tensor::data_type::bf16, {output.strides().cbegin(), output.strides().cend()}},
         output.data_ptr());
     } else {
       dst = ideep::tensor({dst_dims, ideep::tensor::data_type::u8, {output.strides().cbegin(), output.strides().cend()}},
@@ -1782,7 +1803,7 @@ class QConvoneDNN final {
       int64_t groups,
       double inv_output_scale,  // inv_output_scale is the reciprocal of scale in fake quant
       int64_t output_zero_point,
-      bool fp32_output,
+      c10::optional<c10::ScalarType> output_dtype,
       c10::string_view attr,
       torch::List<c10::optional<at::Scalar>> scalars,
       c10::optional<c10::string_view> algorithm) {
@@ -1810,7 +1831,7 @@ class QConvoneDNN final {
         bias, stride, padding, dilation, /*transposed*/false,
         groups, inv_output_scale, output_zero_point,
         /*accum*/c10::nullopt, /*accum_scale*/0.0, /*accum_zero_point*/0,
-        /*fp32_output*/fp32_output, /*binary_attr*/c10::nullopt, /*binary_alpha*/c10::nullopt,
+        /*output_dtype*/output_dtype, /*binary_attr*/c10::nullopt, /*binary_alpha*/c10::nullopt,
         /*unary_attr*/attr, /*unary_scalars*/scalars, /*unary_algorithm*/algorithm
     );
 #else
@@ -1834,7 +1855,7 @@ class QConvoneDNN final {
       int64_t groups,
       double inv_output_scale,  // inv_output_scale is the reciprocal of scale in fake quant
       int64_t output_zero_point,
-      bool fp32_output,
+      c10::optional<c10::ScalarType> output_dtype,
      c10::string_view binary_attr,
      c10::optional<at::Scalar> alpha,
      c10::optional<c10::string_view> unary_attr,
@@ -1862,7 +1883,7 @@ class QConvoneDNN final {
         bias, stride, padding, dilation, /*transposed*/false,
         groups, inv_output_scale, output_zero_point,
         accum, accum_scale, accum_zero_point,
-        /*fp32_output*/false, binary_attr, alpha,
+        /*output_dtype*/output_dtype, binary_attr, alpha,
         unary_attr, unary_scalars, unary_algorithm
     );
 #else
diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp
index 3d59dae2007a6..0fc2f7eb66f7c 100644
--- a/aten/src/ATen/native/quantized/library.cpp
+++ b/aten/src/ATen/native/quantized/library.cpp
@@ -258,12 +258,12 @@ TORCH_LIBRARY(onednn, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv_prepack(Tensor weight, Tensor w_scales, float x_scale, int x_zp, int[] stride, int[] padding, int[] dilation, int groups, int[]? x_shape=None) -> Tensor"));
 
   // Conv1D/2D/3D with unary postop
-  m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv1d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, bool fp32_output, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
-  m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, bool fp32_output, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
-  m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv3d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, bool fp32_output, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv1d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv3d_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor"));
 
   // Conv2D with binary postop
-  m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qaccum, float accum_scale, int accum_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, bool fp32_output, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qaccum, float accum_scale, int accum_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float inv_output_scale, int output_zero_point, ScalarType? output_dtype, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor"));
 
   // Linear prepack
   m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_prepack(Tensor weight, int[]? x_shape) -> Tensor"));
diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
index 655e110656cdc..8c102d605e9fb 100644
--- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py
+++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py
@@ -312,6 +312,10 @@
     ("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)),
     ("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
     ("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
+    ("onednn::qconv1d_pointwise", datetime.date(2023, 12, 31)),
+    ("onednn::qconv2d_pointwise", datetime.date(2023, 12, 31)),
+    ("onednn::qconv3d_pointwise", datetime.date(2023, 12, 31)),
+    ("onednn::qconv2d_pointwise.binary", datetime.date(2023, 12, 31)),
     ("onednn::qlinear_pointwise", datetime.date(2023, 12, 31)),
 ]
 
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 7f0c48c36a6d5..b8d640126dc8e 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -6446,15 +6446,20 @@ def _test_qconv_impl_cpu_tensor(
         use_channelwise=True,
         X2_scale=1.2,
         X2_zero_point=0,
-        fp32_output=False,
+        qconv_output_dtype=None,  # None, torch.float32, torch.bfloat16
         weight_in_channel_last_format=False,
+        qconv_x2_dtype=None,
     ):
         # ONEDNN only supports symmetric quantization of weight
         if W_zero_point is not None:
             W_zero_point = len(W_zero_point) * [0]
-        if fp32_output:
+        fp32_output = True if qconv_output_dtype is torch.float32 else False
+        bfloat16_output = True if qconv_output_dtype is torch.bfloat16 else False
+        if fp32_output or bfloat16_output:
             Y_scale = 1.0
             Y_zero_point = 0
+            X2_scale = 1.0
+            X2_zero_point = 0
         batch_size = 3
         o_pads = None
         device = torch.device("cpu")
@@ -6557,7 +6562,9 @@ def _test_qconv_impl_cpu_tensor(
                 X_q_cpu_tensor,
                 X_scale,
                 X_zero_point,
-                X2_q_cpu_tensor,
+                X2_q_cpu_tensor
+                if qconv_output_dtype is None
+                else X2_q.dequantize().to(qconv_x2_dtype),
                 X2_scale,
                 X2_zero_point,
                 packed_weight,
@@ -6570,7 +6577,7 @@ def _test_qconv_impl_cpu_tensor(
                 groups,
                 1.0 / Y_scale,  # Kernel expects pass in reciprocal of scale in fake quant
                 Y_zero_point,
-                fp32_output,
+                qconv_output_dtype,
                 post_op.binary_attr,
                 post_op.alpha,
                 post_op.unary_attr,
@@ -6592,16 +6599,18 @@ def _test_qconv_impl_cpu_tensor(
                 groups,
                 1.0 / Y_scale,  # Kernel expects pass in reciprocal of scale in fake quant
                 Y_zero_point,
-                fp32_output,
+                qconv_output_dtype,
                 post_op.unary_attr,
                 post_op.scalars,
                 post_op.algorithm,
             )
-        if fp32_output:
-            self.assertTrue(Y_q_cpu_tensor.dtype == torch.float32)
-            Y_q_cpu_tensor = torch.quantize_per_tensor(
-                Y_q_cpu_tensor, scale=Y_scale, zero_point=Y_zero_point, dtype=output_dtype
-            ).int_repr()
+        if fp32_output or bfloat16_output:
+            self.assertTrue(Y_q_cpu_tensor.dtype == qconv_output_dtype)
+            Y_q_cpu_tensor = torch.quantize_per_tensor(
+                Y_q_cpu_tensor
+                if fp32_output
+                else Y_q_cpu_tensor.to(torch.float32), scale=Y_scale, zero_point=Y_zero_point, dtype=output_dtype
+            ).int_repr()
 
         # Make sure the results match
         # assert_array_almost_equal compares using the following formula:
@@ -6642,10 +6651,10 @@ def test_qconv1d_pt2e(self):
         W_zero_point = [0]
         use_bias_list = [False, True]
         use_channelwise_list = [False, True]
-        fp32_output_list = [False, True]
-        options = itertools.product(groups_list, use_bias_list, use_channelwise_list, fp32_output_list)
-        for groups, use_bias, use_channelwise, fp32_output in options:
-            if fp32_output and not (use_bias and use_channelwise):
+        output_dtype_list = [None, torch.float32, torch.bfloat16]
+        options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list)
+        for groups, use_bias, use_channelwise, output_dtype in options:
+            if output_dtype is not None and not (use_bias and use_channelwise):
                 # Remove some test combination to reduce UT test time
                 continue
             conv1d = torch.nn.Conv1d(
@@ -6677,7 +6686,7 @@ def test_qconv1d_pt2e(self):
                 use_bias=use_bias,
                 post_op=pointwise_post_op,
                 use_channelwise=use_channelwise,
-                fp32_output=fp32_output,
+                qconv_output_dtype=output_dtype,
             )
 
     @skipIfNoONEDNN
@@ -6695,16 +6704,16 @@ def test_qconv2d_pt2e(self):
         use_bias_list = [False, True]
         use_channelwise_list = [False, True]
         channel_last_weight_format_list = [False, True]
-        fp32_output_list = [False, True]
+        output_dtype_list = [None, torch.float32, torch.bfloat16]
         options = itertools.product(
             groups_list,
             use_bias_list,
             use_channelwise_list,
             channel_last_weight_format_list,
-            fp32_output_list,
+            output_dtype_list,
         )
-        for groups, use_bias, use_channelwise, channel_last_weight_format, fp32_output in options:
-            if (fp32_output or channel_last_weight_format) and not (use_bias and use_channelwise):
+        for groups, use_bias, use_channelwise, channel_last_weight_format, output_dtype in options:
+            if (output_dtype is not None or channel_last_weight_format) and not (use_bias and use_channelwise):
                 # Remove some test combination to reduce UT test time
                 continue
             qconv = torch.ops.onednn.qconv2d_pointwise
@@ -6736,7 +6745,7 @@ def test_qconv2d_pt2e(self):
                 use_bias=use_bias,
                 post_op=pointwise_post_op,
                 use_channelwise=use_channelwise,
-                fp32_output=fp32_output,
+                qconv_output_dtype=output_dtype,
                 weight_in_channel_last_format=channel_last_weight_format,
             )
 
@@ -6755,16 +6764,16 @@ def test_qconv3d_pt2e(self):
         use_bias_list = [False, True]
         use_channelwise_list = [False, True]
         channel_last_weight_format_list = [False, True]
-        fp32_output_list = [False, True]
+        output_dtype_list = [None, torch.float32, torch.bfloat16]
         options = itertools.product(
             groups_list,
             use_bias_list,
             use_channelwise_list,
             channel_last_weight_format_list,
-            fp32_output_list,
+            output_dtype_list,
         )
-        for groups, use_bias, use_channelwise, channel_last_weight_format, fp32_output in options:
-            if (fp32_output or channel_last_weight_format) and not (use_bias and use_channelwise):
+        for groups, use_bias, use_channelwise, channel_last_weight_format, output_dtype in options:
+            if (output_dtype is not None or channel_last_weight_format) and not (use_bias and use_channelwise):
                 # Remove some test combination to reduce UT test time
                 continue
             qconv = torch.ops.onednn.qconv3d_pointwise
@@ -6796,7 +6805,7 @@ def test_qconv3d_pt2e(self):
                 use_bias=use_bias,
                 post_op=pointwise_post_op,
                 use_channelwise=use_channelwise,
-                fp32_output=fp32_output,
+                qconv_output_dtype=output_dtype,
                 weight_in_channel_last_format=channel_last_weight_format,
             )
 
@@ -6815,8 +6824,9 @@ def test_qconv2d_relu_pt2e(self):
         W_zero_point = [0]
         use_bias_list = [False, True]
         use_channelwise_list = [False, True]
-        options = itertools.product(groups_list, use_bias_list, use_channelwise_list)
-        for groups, use_bias, use_channelwise in options:
+        output_dtype_list = [None, torch.float32, torch.bfloat16]
+        options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list)
+        for groups, use_bias, use_channelwise, output_dtype in options:
             qconv = torch.ops.onednn.qconv2d_pointwise
             qconv_prepack = torch.ops.onednn.qconv_prepack
             conv_op = torch.nn.Conv2d(
@@ -6846,6 +6856,7 @@ def test_qconv2d_relu_pt2e(self):
                 use_bias=use_bias,
                 post_op=pointwise_post_op,
                 use_channelwise=use_channelwise,
+                qconv_output_dtype=output_dtype,
             )
 
     # Test qconv with post op add
@@ -6863,11 +6874,12 @@ def test_qconv2d_add_pt2e(self):
         W_zero_point = [-3]
         use_bias_list = [False, True]
         use_channelwise_list = [False, True]
+        output_dtype_list = [None, torch.float32, torch.bfloat16]
         X2_zero_point_list = [0, 1]
         options = itertools.product(
-            groups_list, use_bias_list, use_channelwise_list, X2_zero_point_list
+            groups_list, use_bias_list, use_channelwise_list, X2_zero_point_list, output_dtype_list
         )
-        for groups, use_bias, use_channelwise, X2_zero_point in options:
+        for groups, use_bias, use_channelwise, X2_zero_point, output_dtype in options:
             qconv = torch.ops.onednn.qconv2d_pointwise.binary
             qconv_prepack = torch.ops.onednn.qconv_prepack
             conv_op = torch.nn.Conv2d(
@@ -6898,6 +6910,8 @@ def test_qconv2d_add_pt2e(self):
                 post_op=pointwise_post_op,
                 use_channelwise=use_channelwise,
                 X2_zero_point=X2_zero_point,
+                qconv_output_dtype=output_dtype,
+                qconv_x2_dtype=output_dtype,
             )
 
     # Test qconv with post op add relu
@@ -6952,6 +6966,67 @@ def test_qconv2d_add_relu_pt2e(self):
                 X2_zero_point=X2_zero_point,
             )
 
+    # Test qconv with post op add
+    @skipIfNoONEDNN
+    def test_qconv2d_add_relu_float_output_pt2e(self):
+        groups = 1
+        input_channels_per_group = 2
+        output_channels_per_group = 2
+        input_feature_map_shape = (10, 10)
+        kernels = (3, 3)
+        strides = (2, 2)
+        pads = (1, 1)
+        dilations = (1, 1)
+        W_scale = [1.5]
+        W_zero_point = [-3]
+        use_bias_list = [False, True]
+        use_channelwise = True
+        qconv_x2_dtype_list = [torch.float32, torch.bfloat16]
+        output_dtype_list = [torch.float32, torch.bfloat16]
+        X2_zero_point = 0
+        use_relu_list = [True, False]
+        options = itertools.product(
+            use_bias_list, output_dtype_list, qconv_x2_dtype_list, use_relu_list
+        )
+        for use_bias, output_dtype, qconv_x2_dtype, use_relu in options:
+            qconv = torch.ops.onednn.qconv2d_pointwise.binary
+            qconv_prepack = torch.ops.onednn.qconv_prepack
+            conv_op = torch.nn.Conv2d(
+                input_channels_per_group * groups,
+                output_channels_per_group * groups,
+                kernels,
+                strides,
+                pads,
+                dilations,
+                groups,
+            )
+            pointwise_post_op = (
+                PointwisePostOp(binary_attr="add", unary_attr="relu")
+                if use_relu
+                else PointwisePostOp(binary_attr="add")
+            )
+            self._test_qconv_impl_cpu_tensor(
+                qconv,
+                qconv_prepack,
+                conv_op,
+                input_channels_per_group=input_channels_per_group,
+                input_feature_map_shape=input_feature_map_shape,
+                output_channels_per_group=output_channels_per_group,
+                groups=groups,
+                kernels=kernels,
+                strides=strides,
+                pads=pads,
+                dilations=dilations,
+                W_scale=W_scale,
+                W_zero_point=W_zero_point,
+                use_bias=use_bias,
+                post_op=pointwise_post_op,
+                use_channelwise=use_channelwise,
+                X2_zero_point=X2_zero_point,
+                qconv_output_dtype=output_dtype,
+                qconv_x2_dtype=qconv_x2_dtype,
+            )
+
 
 class TestPadding(TestCase):
     @given(batch_size=st.integers(1, 64),
            channels=st.integers(1, 64),
diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py
index a46b8db6aa697..6fee5b5e6558d 100644
--- a/torch/_inductor/fx_passes/quantization.py
+++ b/torch/_inductor/fx_passes/quantization.py
@@ -69,7 +69,7 @@
         KeywordArg("groups"),
         KeywordArg("inv_output_scale"),  # inv_output_scale = 1.0
         KeywordArg("output_zero_point"),  # output_zero_point = 0
-        KeywordArg("fp32_output"),  # fp32_output = True
+        KeywordArg("output_dtype"),  # output_dtype = None
         KeywordArg("attr"),  # attr = "none"
         Arg(),  # scalars
         Arg(),  # algorithm
@@ -164,7 +164,7 @@ def _register_quantized_conv_lowering(
     pattern,
     pass_number,
     computation_op,
-    fp32_output,
+    output_dtype,
     unary_attr,
 ):
     @register_lowering_pattern(pattern, pass_number=pass_number)
@@ -195,7 +195,7 @@ def qconv(match: Match, *args, **kwargs):
             kwargs["o_zp"],
         )
         assert (
-            kwargs["fp32_output"] is True
+            kwargs["output_dtype"] is torch.float32
         )  # Expected int8-in fp32-out qconv in weight prepack phase
         assert (
             kwargs["attr"] == "none"
@@ -214,7 +214,7 @@ def qconv(match: Match, *args, **kwargs):
             groups,
             o_inv_scale,
             o_zero_point,
-            fp32_output,
+            output_dtype,
             unary_attr.op_name,
             unary_attr.scalars_attr,
             unary_attr.algorithm_attr,
@@ -285,7 +285,7 @@ def _register_quantized_conv_binary_lowering(
     pattern,
     pass_number,
     computation_op,
-    fp32_output,
+    output_dtype,
     binary_unary_attr,
 ):
     @register_lowering_pattern(pattern, pass_number=pass_number)
@@ -330,7 +330,7 @@ def qconv_binary(match: Match, *args, **kwargs):
             groups,
             o_inv_scale,
             o_zero_point,
-            fp32_output,
+            output_dtype,
             binary_unary_attr.binary_op_name,
             binary_unary_attr.alpha,
             binary_unary_attr.unary_op_name,
@@ -366,7 +366,7 @@ def __init__(self, op_name: str, scalars_attr=None, algorithm_attr=None):
             patterns,
             1 if unary_attr.op_name != "none" else 2,  # pass_number
             torch.ops.onednn.qconv2d_pointwise,  # computation_op
-            False,  # fp32_output
+            None,  # output_dtype, None is the default value for int8 output
             unary_attr,  # unary_attr
         )
 
@@ -431,7 +431,7 @@ def __init__(
             patterns,
             0 if binary_unary_attr.unary_op_name != "none" else 1,  # pass_number
             torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
-            False,  # fp32_output
+            None,  # output_dtype
             binary_unary_attr,  # binary_unary_attr
         )
 
@@ -813,7 +813,7 @@ def qconv_weight_prepack(match: Match, *args, **kwargs):
             groups,
             1.0,  # inv_output_scale
             0,  # output_zero_point
-            True,  # fp32_output
+            torch.float32,  # output_dtype
             "none",  # attr
             [],  # scalars
             "",  # algorithm
@@ -1015,13 +1015,18 @@ def _generate_qlinear_weight_prepack_patterns():
 
 @functools.lru_cache(None)
 def _register_quantization_weight_pack_pass():
+    # Step 1: Dequant promotion
     _register_dequant_promotion_pass(
         dequantize_per_tensor_activation_pattern, pass_number=0
     )  # pass_number=0 to run before weight prepack
+
+    # Step 2: QConv weight prepack
     weight_prepack_patterns = _generate_qconv_weight_prepack_patterns()
     for weight_prepack_pattern in weight_prepack_patterns:
         # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
         _register_qconv_weight_prepack_pass(weight_prepack_pattern, pass_number=1)
+
+    # Step 3: QLinear weight prepack
     weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns()
     for weight_prepack_pattern in weight_prepack_patterns:
         # Register to pass_number 1, so we can do dequant promotion in pass_number 0.
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 522170486b961..f42484187246e 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -5431,7 +5431,7 @@ def __init__(
                 int64_t groups,
                 double inv_output_scale,
                 int64_t output_zero_point,
-                bool fp32_output,
+                c10::optional<c10::ScalarType> output_dtype,
                 c10::string_view attr,
                 torch::List<c10::optional<at::Scalar>> scalars,
                 c10::optional<c10::string_view> algorithm)"""
@@ -5455,7 +5455,7 @@ def codegen(self, wrapper):
             x_zp,
             o_inv_scale,
             o_zp,
-            fp32_output,
+            output_dtype,
             unary_attr,
             unary_scalars,
             unary_algorithm,
@@ -5475,7 +5475,7 @@ def codegen(self, wrapper):
             groups,
             o_inv_scale,
             o_zp,
-            fp32_output,
+            output_dtype,
             unary_attr,
             unary_scalars,
             unary_algorithm,
@@ -5506,7 +5506,7 @@ def create(
         groups: int,
         o_inv_scale: float,
         output_zero_point: int,
-        fp32_output,
+        output_dtype,
         unary_attr,
         unary_scalars,
         unary_algorithm,
@@ -5539,16 +5539,17 @@ def create(
            x_zp,
            o_inv_scale,
            output_zero_point,
-            fp32_output,
+            output_dtype,
            unary_attr,
            may_convert_to_optional(unary_scalars),
            unary_algorithm,
        ]
-        if fp32_output:
+        if output_dtype is not None:
+            assert output_dtype in [torch.float32, torch.bfloat16]
             # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout
-            # if we set fp32_output, the output buf should be dtype float32 instead of uint8.
-            kernel_layout.dtype = torch.float32
+            # if we set output_dtype is not None, the output buf should be output_dtype instead of uint8.
+            kernel_layout.dtype = output_dtype
 
         return QConvPointWisePT2E(
             layout=kernel_layout,
@@ -5604,7 +5605,7 @@ def __init__(
                 int64_t groups,
                 double inv_output_scale,
                 int64_t output_zero_point,
-                bool fp32_output,
+                c10::optional<c10::ScalarType> output_dtype,
                 c10::string_view binary_attr,
                 c10::optional<at::Scalar> alpha,
                 c10::optional<c10::string_view> attr,
@@ -5632,7 +5633,7 @@ def codegen(self, wrapper):
             accum_zp,
             o_inv_scale,
             o_zp,
-            fp32_output,
+            output_dtype,
             binary_attr,
             alpha,
             unary_attr,
@@ -5656,7 +5657,7 @@ def codegen(self, wrapper):
             groups,
             o_inv_scale,
             o_zp,
-            fp32_output,
+            output_dtype,
             binary_attr,
             alpha,
             unary_attr,
@@ -5693,7 +5694,7 @@ def create(
         groups: int,
         o_inv_scale: "TensorBox",
         output_zero_point: "TensorBox",
-        fp32_output,
+        output_dtype,
         binary_attr,
         alpha,
         unary_attr,
@@ -5739,17 +5740,17 @@ def create(
            accum_zp,
            o_inv_scale,
            output_zero_point,
-            fp32_output,
+            output_dtype,
            binary_attr,
            alpha,
            unary_attr,
            may_convert_to_optional(unary_scalars),
            unary_algorithm,
        ]
-        if fp32_output:
+        if output_dtype is not None:
             # in _prepare_convolution_fusion_create, we use x.dtype (uint8) to create kernel_layout
-            # if we set fp32_output, the output buf should be dtype float32 instead of uint8.
-            kernel_layout.dtype = torch.float32
+            # if output_dtype is not None, the output buf should be dtype output_dtype instead of uint8.
+            kernel_layout.dtype = output_dtype
 
         return QConvPointWiseBinaryPT2E(
             layout=kernel_layout,
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 852482945d664..038a7d4c2e716 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1443,7 +1443,7 @@ def qconvolution_unary(
     groups,
     o_inv_scale,
     o_zero_point,
-    fp32_output,
+    output_dtype,
     attr,
     scalars,
     algorithm,
@@ -1463,7 +1463,7 @@ def qconvolution_unary(
         groups,
         o_inv_scale,
         o_zero_point,
-        fp32_output,
+        output_dtype,
         attr,
         scalars,
         algorithm,
@@ -1490,7 +1490,7 @@ def qconvolution_binary(
     groups,
     o_inv_scale,
     o_zero_point,
-    fp32_output,
+    output_dtype,
     binary_attr,
     alpha,
     unary_attr,
@@ -1515,7 +1515,7 @@ def qconvolution_binary(
         groups,
         o_inv_scale,
         o_zero_point,
-        fp32_output,
+        output_dtype,
         binary_attr,
         alpha,
         unary_attr,
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 75d5e0607109a..c681fb79a627a 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -2124,7 +2124,7 @@ def meta_qconv2d_pointwise(
     groups,
     output_scale,
     output_zero_point,
-    fp32_output,
+    output_dtype,
     attr,
     scalars,
     algorithm,
@@ -2139,7 +2139,8 @@ def meta_qconv2d_pointwise(
         groups,
         None,
     )
-    out = x.new_empty(shape_out, dtype=(torch.float32 if fp32_output else None))
+    assert output_dtype in [torch.float32, torch.bfloat16]
+    out = x.new_empty(shape_out, dtype=output_dtype)
     out = out.to(memory_format=torch.channels_last)
     return out
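
Usage sketch (illustrative only, not part of the patch above): a direct call to the updated onednn::qconv2d_pointwise schema with the new output_dtype argument, assuming a CPU build of PyTorch with oneDNN enabled. Shapes, scales, and variable names are made up for the example and loosely follow _test_qconv_impl_cpu_tensor; output_dtype=None keeps the original uint8 output path, while torch.float32 / torch.bfloat16 return float outputs and require the neutral inv_output_scale == 1.0 and output_zero_point == 0.

import torch

# int8 activation (NHWC) and symmetrically quantized per-channel weight (illustrative values)
x = torch.rand(1, 3, 8, 8).to(memory_format=torch.channels_last)
w = torch.rand(4, 3, 3, 3)
x_scale, x_zp = 0.02, 128
w_scales = torch.full((4,), 0.01)
w_zps = torch.zeros(4, dtype=torch.int64)  # oneDNN only supports symmetric weight quantization

x_int8 = torch.quantize_per_tensor(x, x_scale, x_zp, torch.quint8).int_repr()
w_int8 = torch.quantize_per_channel(w, w_scales, w_zps, 0, torch.qint8).int_repr()

packed_w = torch.ops.onednn.qconv_prepack(
    w_int8, w_scales, x_scale, x_zp,
    [1, 1], [1, 1], [1, 1], 1,   # stride, padding, dilation, groups
    x_int8.size(),
)

y_bf16 = torch.ops.onednn.qconv2d_pointwise(
    x_int8, x_scale, x_zp,
    packed_w, w_scales, w_zps,
    None,                        # bias
    [1, 1], [1, 1], [1, 1], 1,   # stride, padding, dilation, groups
    1.0, 0,                      # inv_output_scale, output_zero_point (neutral for float outputs)
    torch.bfloat16,              # output_dtype: None -> uint8, torch.float32 / torch.bfloat16 -> float
    "none", [], "",              # attr, scalars, algorithm
)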