Skip to content

Commit

Permalink
[Quant] [Inductor] Fix an issue in QConv Binary Pattern Match (pytorc…
Browse files Browse the repository at this point in the history
…h#114541)

**Summary**
Add the `extra_check` in `_register_quantized_conv_binary_lowering` to skip the pattern which matched unexpected. To match a Conv-Binary pattern, we should expect the extra input of binary node comes from a dequant pattern instead of a constant scalar.

**Test Plan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_qconv2d_add_2
```

Pull Request resolved: pytorch#114541
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
ghstack dependencies: pytorch#114540
  • Loading branch information
leslie-fang-intel authored and pytorchmergebot committed Nov 28, 2023
1 parent 8556a09 commit 11f11e9
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
46 changes: 46 additions & 0 deletions test/inductor/test_mkldnn_pattern_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,52 @@ def test_qconv2d_add_relu_cpu(self):
def test_qconv2d_add_relu_int8_mixed_bf16(self):
self._qconv2d_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
def test_qconv2d_add_2(self):
r"""
This testcase prevents this pattern be matched as a conv_binary fusion by mistake.
Conv(X) 3
\ /
Add
We see this pattern in Mobilenet v3 large which add is decomposed from torch.nn.Hardswish or torch.nn.Hardsigmoid.
"""

class M(torch.nn.Module):
def __init__(
self,
post_op,
):
super().__init__()
self.conv = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1)
self.post_op = post_op

def forward(self, x):
return self.post_op(self.conv(x))

for post_op in [
torch.nn.Hardswish(inplace=True),
torch.nn.Hardsigmoid(inplace=True),
]:
mod = M(post_op).eval()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
1
)

def matcher_check_fn():
# Shouldn't hit conv binary fusion
self.assertEqual(
counters["inductor"]["qconv2d_binary_matcher_count"], 0
)

self._test_common(
mod,
(v,),
check_quantization=True,
matcher_check_fn=matcher_check_fn,
)

@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfRocm
Expand Down
29 changes: 29 additions & 0 deletions torch/_inductor/fx_passes/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,34 @@ def qlinear(match: Match, *args, **kwargs):
return qlinear


def _is_valid_quantized_conv_binary_optimization_pattern(output_dtype):
# Check if it's a valid Conv Binary Pattern:
# * qconv2d_pointwise should only has one users
# * Extra input of binary node comes from dequant pattern
def fn(match):
qconv2d_node_after_weight_prepack = filter_nodes(
match.nodes, torch.ops.onednn.qconv2d_pointwise
)[0]
if len(qconv2d_node_after_weight_prepack.users) != 1:
return False
if output_dtype is not None:
binary_node_inputs = list(qconv2d_node_after_weight_prepack.users)[0].args
assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs"
extra_input_node = None
for arg in binary_node_inputs:
if arg != qconv2d_node_after_weight_prepack:
extra_input_node = arg
break
assert extra_input_node is not None
if (not isinstance(extra_input_node, torch.fx.Node)) or (
extra_input_node.target != aten.mul.Tensor
):
return False
return True

return fn


def _register_quantized_conv_binary_lowering(
pattern,
pass_number,
Expand All @@ -398,6 +426,7 @@ def _register_quantized_conv_binary_lowering(
):
@register_lowering_pattern(
pattern,
extra_check=_is_valid_quantized_conv_binary_optimization_pattern(output_dtype),
pass_number=pass_number,
)
def qconv_binary(match: Match, *args, **kwargs):
Expand Down

0 comments on commit 11f11e9

Please sign in to comment.