From d040a1429fb82b438d3fa4d31a112f868f84367b Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Mon, 27 Nov 2023 15:14:40 +0800
Subject: [PATCH] Add Hardtanh and ReLU6 into X86InductorQuantizer Conv2d
 Unary post op

ghstack-source-id: d84ce1494ddac824adf30a6736dfb7cbbf936839
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114579
---
 .../pt2e/test_x86inductor_quantizer.py        | 39 +++++++++++++------
 .../quantizer/x86_inductor_quantizer.py       | 30 +++++++++++---
 2 files changed, 51 insertions(+), 18 deletions(-)

diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py
index c1616f960b21c..ecccf4f6db963 100644
--- a/test/quantization/pt2e/test_x86inductor_quantizer.py
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -42,11 +42,11 @@ def forward(self, x):
             x = self.bn(x)
             return x
 
-    class Conv2dReLUModule(torch.nn.Module):
-        def __init__(self, inplace_relu: bool = False, use_bias: bool = False, with_bn=False) -> None:
+    class Conv2dUnaryModule(torch.nn.Module):
+        def __init__(self, post_op, use_bias: bool = False, with_bn=False) -> None:
             super().__init__()
             self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1), bias=use_bias)
-            self.relu = nn.ReLU(inplace=inplace_relu)
+            self.post_op = post_op
             self.bn = torch.nn.BatchNorm2d(6)
             self.with_bn = with_bn
 
@@ -54,7 +54,7 @@ def forward(self, x):
             x = self.conv(x)
             if self.with_bn:
                 x = self.bn(x)
-            x = self.relu(x)
+            x = self.post_op(x)
             return x
 
     class Conv2dAddModule(torch.nn.Module):
@@ -361,11 +361,18 @@ def test_conv2d_unary(self):
-        Test pattern of conv2d with unary post ops (such as relu, sigmoid) with X86InductorQuantizer. Currently, only relu as unary post op is supported.
+        Test pattern of conv2d with unary post ops (such as relu, sigmoid) with X86InductorQuantizer. Currently, relu, hardtanh and relu6 are supported as unary post ops.
         """
-        inplace_relu_list = [True, False]
+        unary_map = {
+            "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default],
+            "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default],
+            "hardtanh": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), torch.ops.aten.hardtanh.default],
+            "hardtanh_inplace": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), torch.ops.aten.hardtanh_.default],
+            "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default],
+            "relu6_inplace": [torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default]
+        }
         use_bias_list = [True, False]
         with override_quantized_engine("x86"), torch.no_grad():
-            for inplace_relu, use_bias in itertools.product(inplace_relu_list, use_bias_list):
-                m = TestHelperModules.Conv2dReLUModule(inplace_relu=inplace_relu, use_bias=use_bias).eval()
+            for unary_op, use_bias in itertools.product(unary_map.keys(), use_bias_list):
+                m = TestHelperModules.Conv2dUnaryModule(unary_map[unary_op][0], use_bias=use_bias).eval()
                 example_inputs = (torch.randn(2, 3, 16, 16),)
                 quantizer = X86InductorQuantizer().set_global(
                     xiq.get_default_x86_inductor_quantization_config()
                 )
@@ -382,7 +389,7 @@ def test_conv2d_unary(self):
                     torch.ops.quantized_decomposed.quantize_per_tensor.default,
                     torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                     torch.ops.aten.conv2d.default,
-                    torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default,
+                    unary_map[unary_op][1],
                 ]
                 self._test_quantizer(
                     m,
@@ -993,10 +1000,18 @@ def test_qat_conv2d_unary(self):
-        Test QAT pattern of conv2d_bn with unary post ops (such as relu, sigmoid) with X86InductorQuantizer. Currently, only relu as unary post op is supported.
+        Test QAT pattern of conv2d_bn with unary post ops (such as relu, sigmoid) with X86InductorQuantizer. Currently, relu, hardtanh and relu6 are supported as unary post ops.
""" - inplace_relu_list = [True, False] + unary_map = { + "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default], + "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default], + "hardtanh": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), torch.ops.aten.hardtanh.default], + "hardtanh_inplace": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), torch.ops.aten.hardtanh_.default], + "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default], + "relu6_inplace": [torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default] + } + with override_quantized_engine("x86"): - for inplace_relu in itertools.product(inplace_relu_list): - m = TestHelperModules.Conv2dReLUModule(inplace_relu=inplace_relu, with_bn=True) + for unary_op in unary_map.keys(): + m = TestHelperModules.Conv2dUnaryModule(unary_map[unary_op][0], with_bn=True) example_inputs = (torch.randn(2, 3, 16, 16),) quantizer = X86InductorQuantizer().set_global( xiq.get_default_x86_inductor_quantization_config(is_qat=True) @@ -1015,7 +1030,7 @@ def test_qat_conv2d_unary(self): torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.aten.conv2d.default, - torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default, + unary_map[unary_op][1], torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, ] diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index cdbc85f0fb12d..cf0540f7878f1 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -548,9 +548,18 @@ def _annotate_qat_conv2d_bn_binary( def _annotate_qat_conv2d_bn_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig ) -> None: - fused_partitions = find_sequential_partitions( - gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU] - ) + fused_partitions = [] + unary_patterns = [ + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh], + [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6], + ] + for unary_pattern in unary_patterns: + partitions = find_sequential_partitions(gm, unary_pattern) + if partitions: + # Extend the fused_partitions if partitions is not empty + fused_partitions.extend(partitions) + for fused_partition in fused_partitions: conv_partition, bn_partition, unary_partition = fused_partition ( @@ -717,9 +726,18 @@ def _annotate_conv2d_binary( def _annotate_conv2d_unary( self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig ) -> None: - fused_partitions = find_sequential_partitions( - gm, [torch.nn.Conv2d, torch.nn.ReLU] - ) + fused_partitions = [] + unary_patterns = [ + [torch.nn.Conv2d, torch.nn.ReLU], + [torch.nn.Conv2d, torch.nn.Hardtanh], + [torch.nn.Conv2d, torch.nn.ReLU6], + ] + for unary_pattern in unary_patterns: + partitions = find_sequential_partitions(gm, unary_pattern) + if partitions: + # Extend the fused_partitions if partitions is not empty + fused_partitions.extend(partitions) + for fused_partition in fused_partitions: conv_partition, unary_partition = fused_partition conv_node, unary_node = self._get_output_nodes_of_partitions(