Skip to content

Commit

Permalink
Add Hardtanh and ReLU6 into X86InductorQuantizer Conv2d Unary post op
Browse files Browse the repository at this point in the history
ghstack-source-id: d84ce1494ddac824adf30a6736dfb7cbbf936839
Pull Request resolved: pytorch#114579
  • Loading branch information
leslie-fang-intel committed Nov 27, 2023
1 parent 9ef68ee commit d040a14
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 18 deletions.
39 changes: 27 additions & 12 deletions test/quantization/pt2e/test_x86inductor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,19 @@ def forward(self, x):
x = self.bn(x)
return x

class Conv2dReLUModule(torch.nn.Module):
def __init__(self, inplace_relu: bool = False, use_bias: bool = False, with_bn=False) -> None:
class Conv2dUnaryModule(torch.nn.Module):
def __init__(self, post_op, use_bias: bool = False, with_bn=False) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1), bias=use_bias)
self.relu = nn.ReLU(inplace=inplace_relu)
self.post_op = post_op
self.bn = torch.nn.BatchNorm2d(6)
self.with_bn = with_bn

def forward(self, x):
x = self.conv(x)
if self.with_bn:
x = self.bn(x)
x = self.relu(x)
x = self.post_op(x)
return x

class Conv2dAddModule(torch.nn.Module):
Expand Down Expand Up @@ -361,11 +361,18 @@ def test_conv2d_unary(self):
Test pattern of conv2d with unary post ops (such as relu, sigmoid) with X86InductorQuantizer.
Currently, only relu as unary post op is supported.
"""
inplace_relu_list = [True, False]
unary_map = {
"relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default],
"relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default],
"hardtanh": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), torch.ops.aten.hardtanh.default],
"hardtanh_inplace": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), torch.ops.aten.hardtanh_.default],
"relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default],
"relu6_inplace": [torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default]
}
use_bias_list = [True, False]
with override_quantized_engine("x86"), torch.no_grad():
for inplace_relu, use_bias in itertools.product(inplace_relu_list, use_bias_list):
m = TestHelperModules.Conv2dReLUModule(inplace_relu=inplace_relu, use_bias=use_bias).eval()
for unary_op, use_bias in itertools.product(unary_map.keys(), use_bias_list):
m = TestHelperModules.Conv2dUnaryModule(unary_map[unary_op][0], use_bias=use_bias).eval()
example_inputs = (torch.randn(2, 3, 16, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config()
Expand All @@ -382,7 +389,7 @@ def test_conv2d_unary(self):
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default,
unary_map[unary_op][1],
]
self._test_quantizer(
m,
Expand Down Expand Up @@ -993,10 +1000,18 @@ def test_qat_conv2d_unary(self):
Test QAT pattern of conv2d_bn with unary post ops (such as relu, sigmoid) with X86InductorQuantizer.
Currently, only relu as unary post op is supported.
"""
inplace_relu_list = [True, False]
unary_map = {
"relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default],
"relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default],
"hardtanh": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), torch.ops.aten.hardtanh.default],
"hardtanh_inplace": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), torch.ops.aten.hardtanh_.default],
"relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default],
"relu6_inplace": [torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default]
}

with override_quantized_engine("x86"):
for inplace_relu in itertools.product(inplace_relu_list):
m = TestHelperModules.Conv2dReLUModule(inplace_relu=inplace_relu, with_bn=True)
for unary_op in unary_map.keys():
m = TestHelperModules.Conv2dUnaryModule(unary_map[unary_op][0], with_bn=True)
example_inputs = (torch.randn(2, 3, 16, 16),)
quantizer = X86InductorQuantizer().set_global(
xiq.get_default_x86_inductor_quantization_config(is_qat=True)
Expand All @@ -1015,7 +1030,7 @@ def test_qat_conv2d_unary(self):
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default,
unary_map[unary_op][1],
torch.ops.quantized_decomposed.quantize_per_tensor.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
]
Expand Down
30 changes: 24 additions & 6 deletions torch/ao/quantization/quantizer/x86_inductor_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,9 +548,18 @@ def _annotate_qat_conv2d_bn_binary(
def _annotate_qat_conv2d_bn_unary(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
fused_partitions = find_sequential_partitions(
gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU]
)
fused_partitions = []
unary_patterns = [
[torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU],
[torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh],
[torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6],
]
for unary_pattern in unary_patterns:
partitions = find_sequential_partitions(gm, unary_pattern)
if partitions:
# Extend the fused_partitions if partitions is not empty
fused_partitions.extend(partitions)

for fused_partition in fused_partitions:
conv_partition, bn_partition, unary_partition = fused_partition
(
Expand Down Expand Up @@ -717,9 +726,18 @@ def _annotate_conv2d_binary(
def _annotate_conv2d_unary(
self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> None:
fused_partitions = find_sequential_partitions(
gm, [torch.nn.Conv2d, torch.nn.ReLU]
)
fused_partitions = []
unary_patterns = [
[torch.nn.Conv2d, torch.nn.ReLU],
[torch.nn.Conv2d, torch.nn.Hardtanh],
[torch.nn.Conv2d, torch.nn.ReLU6],
]
for unary_pattern in unary_patterns:
partitions = find_sequential_partitions(gm, unary_pattern)
if partitions:
# Extend the fused_partitions if partitions is not empty
fused_partitions.extend(partitions)

for fused_partition in fused_partitions:
conv_partition, unary_partition = fused_partition
conv_node, unary_node = self._get_output_nodes_of_partitions(
Expand Down

0 comments on commit d040a14

Please sign in to comment.