Summary of this patch (free text before the first "diff --git" header is
skipped by patch tools):
  * qconfig_utils.py: get_module_type_qconfig and get_function_qconfig are
    merged into get_object_type_qconfig, whose key may be a module type
    (call_module), a function (call_function) or a method-name string
    (call_method). The QConfigAny alias moves here from quantize.py.
  * quantize.py: call_function nodes use the merged helper, and call_method
    nodes additionally look up node.target in qconfig_dict['object_type'].
  * test_quantize_fx.py: adds test_attention, which sets ("chunk", None) in
    "object_type" and checks that prepare_fx/convert_fx run.
NOTE(review): line breaks were validated against the hunk line counts, but
leading indentation inside hunks was reconstructed -- confirm against the
target files before applying.

diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 3bf8e367f1a7..98283e713747 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -534,6 +534,42 @@ def forward(self, x):
         m = convert_fx(m)
         m(dict_input)
 
+    @override_qengines
+    def test_attention(self):
+        """ Make sure quantization runs for a corner case in attention module
+        """
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                x = self.conv(x)
+                q, k, v = x.chunk(3, dim=0)
+                q = q.contiguous().view(-1, 1).transpose(0, 1)
+                k = k.contiguous().view(-1, 1).transpose(0, 1)
+                v = v.contiguous().view(-1, 1).transpose(0, 1)
+                torch._assert(
+                    k.size(1) == 1, "key size should be equal to 1"
+                )
+                r = torch.mm(k, v)
+                return q * k + r
+
+        tensor_input = torch.randn(3, 1, 1, 1)
+        m = M().eval()
+        qconfig_dict = {
+            "": None,
+            "object_type": [
+                (nn.Conv2d, default_qconfig),
+                ("chunk", None)
+            ]
+        }
+        # make sure it runs
+        m = prepare_fx(m, qconfig_dict)
+        m(tensor_input)
+        m = convert_fx(m)
+        m(tensor_input)
+
     def test_standalone_module(self):
         class StandaloneModule(torch.nn.Module):
             def __init__(self):
diff --git a/torch/quantization/fx/qconfig_utils.py b/torch/quantization/fx/qconfig_utils.py
index 6326a2e0da59..3db370c4422d 100644
--- a/torch/quantization/fx/qconfig_utils.py
+++ b/torch/quantization/fx/qconfig_utils.py
@@ -1,7 +1,13 @@
-from .utils import _parent_name
+import torch
 from collections import OrderedDict
+from typing import Union, Callable, Any
 import re
 
+from .utils import _parent_name
+
+QConfigAny = Union[torch.quantization.QConfig,
+                   torch.quantization.QConfigDynamic, None]
+
 def get_flattened_qconfig_dict(qconfig_dict):
     """ flatten the global, object_type and module_name
     qconfig to the same qconfig_dict so that it can be used by
@@ -50,12 +56,16 @@ def _convert_to_ordered_dict(key, qconfig_dict):
     _convert_to_ordered_dict('module_name_regex', qconfig_dict)
     _convert_to_ordered_dict('module_name', qconfig_dict)
 
-def get_module_type_qconfig(qconfig_dict, module_type, fallback_qconfig):
+def get_object_type_qconfig(
+        qconfig_dict: Any,
+        object_type: Union[Callable, str],
+        fallback_qconfig: QConfigAny) -> QConfigAny:
+    # object_type can be
+    # 1. module type (call_module)
+    # 2. function (call_function)
+    # 3. string (call_method)
     return qconfig_dict['object_type'].get(
-        module_type, fallback_qconfig)
-
-def get_function_qconfig(qconfig_dict, function, fallback_qconfig):
-    return qconfig_dict['object_type'].get(function, fallback_qconfig)
+        object_type, fallback_qconfig)
 
 def get_module_name_regex_qconfig(qconfig_dict, module_name, fallback_qconfig):
     for regex_pattern, qconfig in \
@@ -80,7 +90,7 @@ def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig):
 # global_qconfig if necessary
 def get_qconfig(modules, qconfig_dict, module_name, global_qconfig):
     assert modules is not None
-    module_type_qconfig = get_module_type_qconfig(
+    module_type_qconfig = get_object_type_qconfig(
         qconfig_dict, type(modules[module_name]), global_qconfig)
     module_name_regex_qconfig = get_module_name_regex_qconfig(
         qconfig_dict, module_name, module_type_qconfig)
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 99fdb841c28c..273cecfcad43 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -61,12 +61,9 @@
 
 import warnings
 
-from typing import Optional, Dict, Any, List, Union, Tuple, Set, Callable
+from typing import Optional, Dict, Any, List, Tuple, Set, Callable
 
 # Define helper types
-
-QConfigAny = Union[torch.quantization.QConfig,
-                   torch.quantization.QConfigDynamic, None]
 MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                     QConfigAny]
 
@@ -302,7 +299,7 @@ def _generate_qconfig_map(
                 # precedence: [TODO] module_name_qconfig (need scope support
                 # from fx)
                 # > function_qconfig > global_qconfig
-                function_qconfig = get_function_qconfig(
+                function_qconfig = get_object_type_qconfig(
                     qconfig_dict, node.target, global_qconfig)
                 self.qconfig_map[node.name] = function_qconfig
             elif node.op == 'call_method':
@@ -318,6 +315,7 @@ def _generate_qconfig_map(
                         "qconfig for value {}".format(node.name))
                     qconfig = get_qconfig(
                         self.modules, qconfig_dict, '', global_qconfig)
+                qconfig = get_object_type_qconfig(qconfig_dict, node.target, qconfig)
                 self.qconfig_map[node.name] = qconfig
             elif node.op == 'call_module':
                 module_qconfig = get_qconfig(