[quant][graphmode][fx] Allow user to specify qconfig for call_method (pytorch#49621)

Summary:
Pull Request resolved: pytorch#49621

This adds support for configuring a qconfig for a call_method node, e.g. x.chunk, which helps
work around a problem in one of our internal models.
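
For example, here is a minimal runnable sketch that mirrors the test added below (prepare_fx,
convert_fx, and default_qconfig are the existing FX graph mode quantization APIs); a method
name can now be used as an object_type key:

import torch
import torch.nn as nn
from torch.quantization import default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 1)

    def forward(self, x):
        x = self.conv(x)             # call_module node
        a, b, c = x.chunk(3, dim=0)  # call_method node
        return a + b + c

qconfig_dict = {
    "": None,  # global qconfig: leave unlisted ops in fp32
    "object_type": [
        (nn.Conv2d, default_qconfig),  # module type (call_module)
        ("chunk", None),               # method name (call_method), enabled by this PR
    ],
}

m = prepare_fx(M().eval(), qconfig_dict)
m(torch.randn(3, 1, 1, 1))  # calibration pass
m = convert_fx(m)
m(torch.randn(3, 1, 1, 1))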

TODO: since a call_method target is also a string and we flatten the qconfig dict, we may need
to resolve the namespace conflict between call_method and module_name keys.
TODO: add scope support so that the qconfig for call_method is set correctly from the original qconfig.

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D25651828

fbshipit-source-id: 82d66b121d37c8274fd481b6a2e9f9b54c5ca73d
jerryzh168 authored and hwangdeyu committed Dec 23, 2020
1 parent 32073ec commit cd8ef1a
Showing 3 changed files with 56 additions and 12 deletions.
36 changes: 36 additions & 0 deletions test/quantization/test_quantize_fx.py
@@ -534,6 +534,42 @@ def forward(self, x):
         m = convert_fx(m)
         m(dict_input)
 
+    @override_qengines
+    def test_attention(self):
+        """ Make sure quantization runs for a corner case in attention module
+        """
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                x = self.conv(x)
+                q, k, v = x.chunk(3, dim=0)
+                q = q.contiguous().view(-1, 1).transpose(0, 1)
+                k = k.contiguous().view(-1, 1).transpose(0, 1)
+                v = v.contiguous().view(-1, 1).transpose(0, 1)
+                torch._assert(
+                    k.size(1) == 1, "key size should be equal to 1"
+                )
+                r = torch.mm(k, v)
+                return q * k + r
+
+        tensor_input = torch.randn(3, 1, 1, 1)
+        m = M().eval()
+        qconfig_dict = {
+            "": None,
+            "object_type": [
+                (nn.Conv2d, default_qconfig),
+                ("chunk", None)
+            ]
+        }
+        # make sure it runs
+        m = prepare_fx(m, qconfig_dict)
+        m(tensor_input)
+        m = convert_fx(m)
+        m(tensor_input)
+
     def test_standalone_module(self):
         class StandaloneModule(torch.nn.Module):
             def __init__(self):
24 changes: 17 additions & 7 deletions torch/quantization/fx/qconfig_utils.py
@@ -1,7 +1,13 @@
-from .utils import _parent_name
 import torch
 from collections import OrderedDict
+from typing import Union, Callable, Any
 import re
+
+from .utils import _parent_name
+
+QConfigAny = Union[torch.quantization.QConfig,
+                   torch.quantization.QConfigDynamic, None]
+
 def get_flattened_qconfig_dict(qconfig_dict):
     """ flatten the global, object_type and module_name qconfig
     to the same qconfig_dict so that it can be used by
@@ -50,12 +56,16 @@ def _convert_to_ordered_dict(key, qconfig_dict):
     _convert_to_ordered_dict('module_name_regex', qconfig_dict)
     _convert_to_ordered_dict('module_name', qconfig_dict)
 
-def get_module_type_qconfig(qconfig_dict, module_type, fallback_qconfig):
+def get_object_type_qconfig(
+        qconfig_dict: Any,
+        object_type: Union[Callable, str],
+        fallback_qconfig: QConfigAny) -> QConfigAny:
+    # object_type can be
+    # 1. module type (call_module)
+    # 2. function (call_function)
+    # 3. string (call_method)
     return qconfig_dict['object_type'].get(
-        module_type, fallback_qconfig)
-
-def get_function_qconfig(qconfig_dict, function, fallback_qconfig):
-    return qconfig_dict['object_type'].get(function, fallback_qconfig)
+        object_type, fallback_qconfig)
 
 def get_module_name_regex_qconfig(qconfig_dict, module_name, fallback_qconfig):
     for regex_pattern, qconfig in \
@@ -80,7 +90,7 @@ def get_module_name_qconfig(qconfig_dict, module_name, fallback_qconfig):
 # global_qconfig if necessary
 def get_qconfig(modules, qconfig_dict, module_name, global_qconfig):
     assert modules is not None
-    module_type_qconfig = get_module_type_qconfig(
+    module_type_qconfig = get_object_type_qconfig(
         qconfig_dict, type(modules[module_name]), global_qconfig)
     module_name_regex_qconfig = get_module_name_regex_qconfig(
         qconfig_dict, module_name, module_type_qconfig)
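
To illustrate the lookup that get_object_type_qconfig performs, here is a self-contained
sketch with stand-in values (not code from this commit): an exact object_type match wins,
otherwise the fallback qconfig is returned.

from collections import OrderedDict

def get_object_type_qconfig(qconfig_dict, object_type, fallback_qconfig):
    # same exact-match-or-fallback lookup as the function above
    return qconfig_dict['object_type'].get(object_type, fallback_qconfig)

qconfig_dict = {'object_type': OrderedDict([('chunk', None)])}
global_qconfig = 'global'  # stand-in for a real QConfig

assert get_object_type_qconfig(qconfig_dict, 'chunk', global_qconfig) is None
assert get_object_type_qconfig(qconfig_dict, 'view', global_qconfig) == 'global'
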
8 changes: 3 additions & 5 deletions torch/quantization/fx/quantize.py
@@ -61,12 +61,9 @@
 
 import warnings
 
-from typing import Optional, Dict, Any, List, Union, Tuple, Set, Callable
+from typing import Optional, Dict, Any, List, Tuple, Set, Callable
 
 # Define helper types
-
-QConfigAny = Union[torch.quantization.QConfig,
-                   torch.quantization.QConfigDynamic, None]
 MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                     QConfigAny]
 
@@ -302,7 +299,7 @@ def _generate_qconfig_map(
                 # precedence: [TODO] module_name_qconfig (need scope support
                 # from fx)
                 # > function_qconfig > global_qconfig
-                function_qconfig = get_function_qconfig(
+                function_qconfig = get_object_type_qconfig(
                     qconfig_dict, node.target, global_qconfig)
                 self.qconfig_map[node.name] = function_qconfig
             elif node.op == 'call_method':
@@ -318,6 +315,7 @@ def _generate_qconfig_map(
                     "qconfig for value {}".format(node.name))
                 qconfig = get_qconfig(
                     self.modules, qconfig_dict, '', global_qconfig)
+                qconfig = get_object_type_qconfig(qconfig_dict, node.target, qconfig)
                 self.qconfig_map[node.name] = qconfig
             elif node.op == 'call_module':
                 module_qconfig = get_qconfig(
