Enable QConv2d with hardtanh post op
ghstack-source-id: fb1c31d11febaa337bb899b18b792dfe3e42a8c8
Pull Request resolved: pytorch#114578
leslie-fang-intel committed Nov 27, 2023
1 parent c9225a9 commit 9ef68ee
Showing 2 changed files with 74 additions and 3 deletions.
22 changes: 19 additions & 3 deletions aten/src/ATen/native/quantized/cpu/qconv.cpp
@@ -1630,7 +1630,23 @@ static at::Tensor _quantized_convolution_onednn(
      dst.set_scale(accum_ideep_scale);
      dst.set_zero_point(accum_ideep_zero_points);
    } else {
-     op_attr = (has_unary_post_op && unary_attr.value()=="relu") ? ideep::attr_t::fuse_relu() : ideep::attr_t();
+     if (has_unary_post_op && unary_attr.value()=="relu") {
+       op_attr = ideep::attr_t::fuse_relu();
+     } else if (has_unary_post_op && unary_attr.value()=="hardtanh") {
+       TORCH_CHECK(
+           unary_scalars.size() == 2 &&
+               unary_scalars[0].get().toOptional<at::Scalar>().has_value() &&
+               unary_scalars[1].get().toOptional<at::Scalar>().has_value(),
+           "hardtanh is expected to have two scalar inputs: min_val and max_val");
+
+       auto lower_bound_value =
+           unary_scalars[0].get().toOptional<at::Scalar>().value().to<float>();
+       auto upper_bound_value =
+           unary_scalars[1].get().toOptional<at::Scalar>().value().to<float>();
+       op_attr = ideep::attr_t::fuse_clamp(lower_bound_value, upper_bound_value);
+     } else {
+       op_attr = ideep::attr_t();
+     }
    }

    // Weight Reorder
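The new branch above maps a hardtanh post op onto ideep's fuse_clamp attribute, which is sound because hardtanh(x, min_val, max_val) is pointwise-identical to clamping x to [min_val, max_val]. A minimal Python sketch of that equivalence (illustrative only, not part of this commit):

import torch

x = torch.randn(4, 8)
min_val, max_val = 0.0, 6.0  # illustrative bounds, matching the test below
out_hardtanh = torch.nn.functional.hardtanh(x, min_val, max_val)
out_clamp = torch.clamp(x, min_val, max_val)
assert torch.equal(out_hardtanh, out_clamp)  # elementwise identical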
@@ -1821,8 +1837,8 @@ class QConvoneDNN final {
  } else {
    // Conv2D post op check
    TORCH_CHECK(
-       attr == "none" || attr == "relu",
-       "none post_op or post_op relu is supported for quantized pointwise conv2d. Got unary_post_op: ",
+       attr == "none" || attr == "relu" || attr == "hardtanh",
+       "none post_op or post_op relu/hardtanh is supported for quantized pointwise conv2d. Got unary_post_op: ",
        attr,
        ".")
  }
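The whitelist above now accepts hardtanh alongside none and relu for pointwise conv2d. A hedged Python mirror of the same check (the helper name is hypothetical, not the actual C++ binding):

def check_conv2d_unary_post_op(attr: str) -> None:
    # Hypothetical mirror of the TORCH_CHECK above, not the real binding.
    if attr not in ("none", "relu", "hardtanh"):
        raise RuntimeError(
            "none post_op or post_op relu/hardtanh is supported for "
            f"quantized pointwise conv2d. Got unary_post_op: {attr}.")

check_conv2d_unary_post_op("hardtanh")  # accepted after this change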
55 changes: 55 additions & 0 deletions test/quantization/core/test_quantized_op.py
@@ -6549,6 +6549,11 @@ def _test_qconv_impl_cpu_tensor(
        assert not use_transpose, "Cannot fuse ReLU with ConvTranspose"
        relu = torch.nn.ReLU()
        result_ref = relu(result_ref)
+    elif post_op.unary_attr == "hardtanh":
+        assert not use_transpose, "Cannot fuse hardtanh with ConvTranspose"
+        assert len(post_op.scalars) == 2, "For post op hardtanh, expect 2 parameters passed in"
+        hardtanh = torch.nn.Hardtanh(min_val=post_op.scalars[0], max_val=post_op.scalars[1])
+        result_ref = hardtanh(result_ref)

    # Quantize reference results for comparison
    result_ref_q = torch.quantize_per_tensor(
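The hardtanh branch added above computes the reference result by applying the float op before quantization. A standalone sketch of that reference path (the scale and zero point values are illustrative, not taken from the test):

import torch

x = torch.randn(2, 3)
hardtanh = torch.nn.Hardtanh(min_val=0.0, max_val=6.0)
result_ref = hardtanh(x)
# Quantize the float reference for comparison against the fused kernel output.
result_ref_q = torch.quantize_per_tensor(
    result_ref, scale=0.05, zero_point=0, dtype=torch.quint8)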
@@ -6891,6 +6896,56 @@ def test_qconv2d_relu_pt2e(self):
                qconv_output_dtype=output_dtype,
            )

+    # Test qconv with post op hardtanh
+    @skipIfNoONEDNN
+    def test_qconv2d_hardtanh_pt2e(self):
+        input_channels_per_group = 2
+        output_channels_per_group = 2
+        groups_list = [1, 10]
+        input_feature_map_shape = (10, 10)
+        kernels = (3, 3)
+        strides = (2, 2)
+        pads = (1, 1)
+        dilations = (1, 1)
+        W_scale = [1.5]
+        W_zero_point = [0]
+        use_bias_list = [False, True]
+        use_channelwise_list = [False, True]
+        output_dtype_list = [None, torch.float32, torch.bfloat16]
+        options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list)
+        for groups, use_bias, use_channelwise, output_dtype in options:
+            qconv = torch.ops.onednn.qconv2d_pointwise
+            qconv_prepack = torch.ops.onednn.qconv_prepack
+            conv_op = torch.nn.Conv2d(
+                input_channels_per_group * groups,
+                output_channels_per_group * groups,
+                kernels,
+                strides,
+                pads,
+                dilations,
+                groups,
+            )
+            pointwise_post_op = PointwisePostOp(unary_attr="hardtanh", scalars=[0.0, 6.0])
+            self._test_qconv_impl_cpu_tensor(
+                qconv,
+                qconv_prepack,
+                conv_op,
+                input_channels_per_group=input_channels_per_group,
+                input_feature_map_shape=input_feature_map_shape,
+                output_channels_per_group=output_channels_per_group,
+                groups=groups,
+                kernels=kernels,
+                strides=strides,
+                pads=pads,
+                dilations=dilations,
+                W_scale=W_scale,
+                W_zero_point=W_zero_point,
+                use_bias=use_bias,
+                post_op=pointwise_post_op,
+                use_channelwise=use_channelwise,
+                qconv_output_dtype=output_dtype,
+            )
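A side note on the chosen scalars: hardtanh with min_val=0.0 and max_val=6.0 is exactly ReLU6, so this test also exercises the common ReLU6 fusion pattern. A quick check of that equivalence (illustrative, not part of the test):

import torch

x = torch.randn(8) * 10.0
assert torch.equal(
    torch.nn.Hardtanh(min_val=0.0, max_val=6.0)(x),
    torch.nn.functional.relu6(x),
)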

    # Test qconv with post op add
    @skipIfNoONEDNN
    def test_qconv2d_add_pt2e(self):
