Add Hardtanh and ReLU6 into X86InductorQuantizer Conv2d Unary post op
ghstack-source-id: 3ca1fb17ffe7f807c742eb164afb69fb54d38289
Pull Request resolved: pytorch#114579
leslie-fang-intel committed Nov 28, 2023
1 parent 31b5252 commit e1eb1b8
Showing 2 changed files with 52 additions and 20 deletions.
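For context, the patterns added here are consumed through the standard PT2E quantization flow. Below is a minimal sketch, not part of the diff, of quantizing a Conv2d + ReLU6 model with X86InductorQuantizer. It assumes the export APIs in use around the time of this commit (capture_pre_autograd_graph, prepare_pt2e, convert_pt2e), which is the same flow the tests drive through _test_quantizer.

```python
# Minimal sketch (not part of this commit) of exercising the new conv2d + relu6
# pattern; assumes the PT2E APIs available around this change.
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 2)
        self.post_op = torch.nn.ReLU6()  # now recognized as a conv2d unary post op

    def forward(self, x):
        return self.post_op(self.conv(x))

example_inputs = (torch.randn(2, 3, 16, 16),)
m = capture_pre_autograd_graph(M().eval(), example_inputs)
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
m = prepare_pt2e(m, quantizer)   # observers inserted around the matched conv2d + relu6
m(*example_inputs)               # calibrate
m = convert_pt2e(m)              # lower to quantize/dequantize + conv2d + hardtanh
```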
42 changes: 28 additions & 14 deletions test/quantization/pt2e/test_x86inductor_quantizer.py
@@ -42,19 +42,19 @@ def forward(self, x):
             x = self.bn(x)
             return x
 
-    class Conv2dReLUModule(torch.nn.Module):
-        def __init__(self, inplace_relu: bool = False, use_bias: bool = False, with_bn=False) -> None:
+    class Conv2dUnaryModule(torch.nn.Module):
+        def __init__(self, post_op, use_bias: bool = False, with_bn=False) -> None:
             super().__init__()
             self.conv = nn.Conv2d(3, 6, (2, 2), stride=(1, 1), padding=(1, 1), bias=use_bias)
-            self.relu = nn.ReLU(inplace=inplace_relu)
+            self.post_op = post_op
             self.bn = torch.nn.BatchNorm2d(6)
             self.with_bn = with_bn
 
         def forward(self, x):
             x = self.conv(x)
             if self.with_bn:
                 x = self.bn(x)
-            x = self.relu(x)
+            x = self.post_op(x)
             return x
 
     class Conv2dAddModule(torch.nn.Module):
@@ -358,14 +358,20 @@ def test_conv2d(self):
     @skipIfNoX86
     def test_conv2d_unary(self):
         """
-        Test pattern of conv2d with unary post ops (such as relu, sigmoid) with X86InductorQuantizer.
-        Currently, only relu as unary post op is supported.
+        Test pattern of conv2d with unary post ops (such as relu, hardtanh, relu6) with X86InductorQuantizer.
         """
-        inplace_relu_list = [True, False]
+        unary_map = {
+            "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default],
+            "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default],
+            "hardtanh": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), torch.ops.aten.hardtanh.default],
+            "hardtanh_inplace": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), torch.ops.aten.hardtanh_.default],
+            "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default],
+            "relu6_inplace": [torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default]
+        }
         use_bias_list = [True, False]
         with override_quantized_engine("x86"), torch.no_grad():
-            for inplace_relu, use_bias in itertools.product(inplace_relu_list, use_bias_list):
-                m = TestHelperModules.Conv2dReLUModule(inplace_relu=inplace_relu, use_bias=use_bias).eval()
+            for unary_op, use_bias in itertools.product(unary_map.keys(), use_bias_list):
+                m = TestHelperModules.Conv2dUnaryModule(unary_map[unary_op][0], use_bias=use_bias).eval()
                 example_inputs = (torch.randn(2, 3, 16, 16),)
                 quantizer = X86InductorQuantizer().set_global(
                     xiq.get_default_x86_inductor_quantization_config()
@@ -382,7 +388,7 @@ def test_conv2d_unary(self):
                     torch.ops.quantized_decomposed.quantize_per_tensor.default,
                     torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                     torch.ops.aten.conv2d.default,
-                    torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default,
+                    unary_map[unary_op][1],
                 ]
                 self._test_quantizer(
                     m,
@@ -1026,10 +1032,18 @@ def test_qat_conv2d_unary(self):
         Test QAT pattern of conv2d_bn with unary post ops (such as relu, sigmoid) with X86InductorQuantizer.
         Currently, only relu as unary post op is supported.
         """
-        inplace_relu_list = [True, False]
+        unary_map = {
+            "relu": [torch.nn.ReLU(inplace=False), torch.ops.aten.relu.default],
+            "relu_inplace": [torch.nn.ReLU(inplace=True), torch.ops.aten.relu_.default],
+            "hardtanh": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=False), torch.ops.aten.hardtanh.default],
+            "hardtanh_inplace": [torch.nn.Hardtanh(min_val=0.0, max_val=6.0, inplace=True), torch.ops.aten.hardtanh_.default],
+            "relu6": [torch.nn.ReLU6(inplace=False), torch.ops.aten.hardtanh.default],
+            "relu6_inplace": [torch.nn.ReLU6(inplace=True), torch.ops.aten.hardtanh_.default]
+        }
+
         with override_quantized_engine("x86"):
-            for inplace_relu in itertools.product(inplace_relu_list):
-                m = TestHelperModules.Conv2dReLUModule(inplace_relu=inplace_relu, with_bn=True)
+            for unary_op in unary_map.keys():
+                m = TestHelperModules.Conv2dUnaryModule(unary_map[unary_op][0], with_bn=True)
                 example_inputs = (torch.randn(2, 3, 16, 16),)
                 quantizer = X86InductorQuantizer().set_global(
                     xiq.get_default_x86_inductor_quantization_config(is_qat=True)
@@ -1048,7 +1062,7 @@ def test_qat_conv2d_unary(self):
                     torch.ops.quantized_decomposed.quantize_per_tensor.default,
                     torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                     torch.ops.aten.conv2d.default,
-                    torch.ops.aten.relu_.default if inplace_relu else torch.ops.aten.relu.default,
+                    unary_map[unary_op][1],
                     torch.ops.quantized_decomposed.quantize_per_tensor.default,
                     torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                 ]
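A detail worth calling out in the unary_map above: there is no dedicated aten op for ReLU6. nn.ReLU6 is a Hardtanh with bounds (0, 6), so tracing lowers it to the same aten.hardtanh overloads, which is why the "relu6" entries expect torch.ops.aten.hardtanh.default rather than a relu6 op. A quick sketch to confirm (assuming torch.export is available; the printed graph varies by version):

```python
# Sketch: check that nn.ReLU6 traces to aten.hardtanh rather than a relu6 op.
import torch

m = torch.nn.Sequential(torch.nn.Conv2d(3, 6, 2), torch.nn.ReLU6())
ep = torch.export.export(m.eval(), (torch.randn(2, 3, 16, 16),))
print(ep.graph)  # expect a call to torch.ops.aten.hardtanh.default with 0.0, 6.0
```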
30 changes: 24 additions & 6 deletions torch/ao/quantization/quantizer/x86_inductor_quantizer.py
@@ -546,9 +546,18 @@ def _annotate_qat_conv2d_bn_binary(
     def _annotate_qat_conv2d_bn_unary(
         self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
     ) -> None:
-        fused_partitions = find_sequential_partitions(
-            gm, [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU]
-        )
+        fused_partitions = []
+        unary_patterns = [
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU],
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.Hardtanh],
+            [torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU6],
+        ]
+        for unary_pattern in unary_patterns:
+            partitions = find_sequential_partitions(gm, unary_pattern)
+            if partitions:
+                # Extend the fused_partitions if partitions is not empty
+                fused_partitions.extend(partitions)
+
         for fused_partition in fused_partitions:
             conv_partition, bn_partition, unary_partition = fused_partition
             (
@@ -715,9 +724,18 @@ def _annotate_conv2d_binary(
     def _annotate_conv2d_unary(
         self, gm: torch.fx.GraphModule, quantization_config: QuantizationConfig
     ) -> None:
-        fused_partitions = find_sequential_partitions(
-            gm, [torch.nn.Conv2d, torch.nn.ReLU]
-        )
+        fused_partitions = []
+        unary_patterns = [
+            [torch.nn.Conv2d, torch.nn.ReLU],
+            [torch.nn.Conv2d, torch.nn.Hardtanh],
+            [torch.nn.Conv2d, torch.nn.ReLU6],
+        ]
+        for unary_pattern in unary_patterns:
+            partitions = find_sequential_partitions(gm, unary_pattern)
+            if partitions:
+                # Extend the fused_partitions if partitions is not empty
+                fused_partitions.extend(partitions)
+
         for fused_partition in fused_partitions:
             conv_partition, unary_partition = fused_partition
             conv_node, unary_node = self._get_output_nodes_of_partitions(
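Both annotation methods change in the same way: rather than matching one hard-coded module sequence, each candidate pattern is searched separately and the hits are pooled into a single list before annotation, so the existing annotation loop handles every unary post op unchanged. A standalone sketch of that loop is below; the find_sequential_partitions import path is an assumption based on this file's existing imports.

```python
# Standalone sketch of the pooled pattern matching this diff introduces.
# The import path is assumed from the surrounding file's imports.
import torch
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions

def find_conv2d_unary_partitions(gm: torch.fx.GraphModule):
    fused_partitions = []
    unary_patterns = [
        [torch.nn.Conv2d, torch.nn.ReLU],
        [torch.nn.Conv2d, torch.nn.Hardtanh],
        [torch.nn.Conv2d, torch.nn.ReLU6],
    ]
    for unary_pattern in unary_patterns:
        # Each call returns a (possibly empty) list of matched partitions;
        # pooling them lets one annotation pass cover all unary post ops.
        fused_partitions.extend(find_sequential_partitions(gm, unary_pattern))
    return fused_partitions
```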