diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py index 768f51c53..09cfc7944 100644 --- a/mmrazor/models/quantizers/academic_quantizer.py +++ b/mmrazor/models/quantizers/academic_quantizer.py @@ -2,7 +2,7 @@ import torch from mmrazor.registry import MODELS -from mmrazor.structures.quantization import BackendConfigs, QConfigHander +from mmrazor.structures.quantization import BackendConfigs, QConfigHandler from .base import BaseQuantizer try: @@ -75,24 +75,25 @@ def gen_qconfig_mapping(self, qconfig_mapping): """tmp.""" conf = QConfigMapping() if GLOBAL_DICT_KEY in qconfig_mapping: - qconfig = QConfigHander(qconfig_mapping[GLOBAL_DICT_KEY]).convert() + qconfig = QConfigHandler( + qconfig_mapping[GLOBAL_DICT_KEY]).convert() conf.set_global(qconfig) for object_type, qconfig in qconfig_mapping.get( OBJECT_TYPE_DICT_KEY, []): - qconfig = QConfigHander(qconfig).convert() + qconfig = QConfigHandler(qconfig).convert() conf.set_object_type(object_type, qconfig) for module_name_regex, qconfig in qconfig_mapping.get( MODULE_NAME_REGEX_DICT_KEY, []): - qconfig = QConfigHander(qconfig).convert() + qconfig = QConfigHandler(qconfig).convert() conf.set_module_name_regex(module_name_regex, qconfig) for module_name, qconfig in qconfig_mapping.get( MODULE_NAME_DICT_KEY, []): - qconfig = QConfigHander(qconfig).convert() + qconfig = QConfigHandler(qconfig).convert() conf.set_module_name(module_name, qconfig) for module_name, object_type, index, qconfig in qconfig_mapping.get( MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []): - qconfig = QConfigHander(qconfig).convert() + qconfig = QConfigHandler(qconfig).convert() conf.set_module_name_object_type_order(module_name, object_type, index, qconfig) diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index d0534d361..2b75cf29c 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -33,7 +33,7 @@ del_fakequant_before_module, del_fakequant_before_op) from mmrazor.models.utils import str2class from mmrazor.registry import MODELS -from mmrazor.structures.quantization import BackendConfigs, QConfigHander +from mmrazor.structures.quantization import BackendConfigs, QConfigHandler from .base import BaseQuantizer if digit_version(torch.__version__) >= digit_version('1.13.0'): @@ -108,7 +108,7 @@ def __init__(self, extra_op_prev_wo_fakequant=tuple(), extra_op_next_wo_fakequant=tuple())): super().__init__(tracer) - self.qconfig = QConfigHander(global_qconfig) + self.qconfig = QConfigHandler(global_qconfig) if self.qconfig.w_qscheme.is_per_channel: w_mode = 'per_channel' else: diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py index 2d33e9875..a3cff1167 100644 --- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import functools -from types import FunctionType, MethodType -from typing import Any, Callable, Dict, List, Optional, Type, Union +from types import FunctionType +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -34,18 +34,24 @@ class UntracedMethodRegistry: - """A `Descriptor` class which records untraced methods.""" + """A `Descriptor` class which records untraced methods. 
Thus, when the
+    class is traced with CustomTracer, the decorated method will be treated
+    as a leaf node rather than traced through.
+
+    Example:
+        >>> # `imported_cls` is the owner of the untraced method;
+        >>> # `method_str` is the name of the untraced method.
+        >>> method_registry = UntracedMethodRegistry(method)
+        >>> method_registry.__set_name__(imported_cls, method_str)
+
+    Args:
+        method (FunctionType): Function to be registered.
+    """
     method_dict: Dict = dict()
     tracer = None
 
-    def __init__(self, method):
-        """_summary_
-
-        Args:
-            method (FunctionType): Function to be registered.
-        """
+    def __init__(self, method: FunctionType):
         self.method = method
-        self.instances: Dict = dict()
         self.owner = None
 
     def __set_name__(self, owner, name):
@@ -54,11 +60,6 @@ def __set_name__(self, owner, name):
         wrapped = self.method_wrapper()
         self.method_dict[name] = dict(mod=self.owner, wrapped=wrapped)
 
-    def __get__(self, instance, owner):
-        if instance is None:
-            return self.method
-        return MethodType(self.method, instance)
-
     def method_wrapper(self):
 
         @functools.wraps(self.method)
@@ -73,33 +74,12 @@ def method(*args, **kwargs):
 
         return wrapped_method
 
 
-def custom_symbolic_trace(root: Union[nn.Module, Callable[..., Any]],
-                          concrete_args: Optional[Dict[str, Any]] = None):
-    """Modified `symbolic_trace` function.
-
-    Args:
-        root (Union[nn.Module, Callable]): Module or function to be
-            traced and converted into a Graph representation.
-        concrete_args (Optional[Dict[str, any]]): Inputs to be partially
-            specialized.
-
-    Returns:
-        _type_: _description_
-    """
-    tracer = CustomTracer()
-    graph = tracer.trace(root, concrete_args)
-    name = root.__class__.__name__ if isinstance(root,
-                                                 nn.Module) else root.__name__
-    return GraphModule(tracer.root, graph, name)
-
-
-def _prepare_module_dict(model: nn.Module, fx_graph):
+def _prepare_module_dict(model: torch.nn.Module, fx_graph):
     """If there is a class method that can not be traced by the symbolic
     tracer, a ``call_method`` ``Node`` will be inserted into the ``Graph`` in
     ``CustomTracer``.
 
-    For example,
-    ```
+    Example:
         >>> class Model:
         ...     def __init__(self):
         ...         self.head = ClsHead()
@@ -123,7 +103,7 @@ def _prepare_module_dict(model: nn.Module, fx_graph):
         ...             xxx
         ...             losses = xxx
         ...             return losses
-    ```
+
     As the ``_get_loss`` can not be traced by torch.fx, ``Toy._get_loss`` need
     to be added to ``skipped_methods`` in ``CustomTracer``. Hence the code
     above will product the following Graph::
@@ -140,8 +120,10 @@ def _prepare_module_dict(model: nn.Module, fx_graph):
         the original model.
 
     Args:
-        model (nn.Module): The original model.
-        fx_graph (Graph): The fx Graph traced by fx tracer.
+        model (torch.nn.Module): The original model on which the fx Graph
+            was traced.
+        fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. It
+            contains the nodes this GraphModule should use for code generation.
     """
 
     def _get_attrs(target, attrs):
@@ -170,7 +152,32 @@ def _get_attrs(target, attrs):
     return module_dict
 
 
-def build_graphmodule(model: nn.Module, fx_graph, name: str = 'GraphModule'):
+def build_graphmodule(model: torch.nn.Module,
+                      fx_graph,
+                      name: str = 'GraphModule'):
+    """Build a GraphModule with the graph generated by CustomTracer. The way
+    methods are skipped in CustomTracer makes a node both a leaf node and a
+    non-leaf node, so modifications to the ``graph`` would also change the
+    original ``forward``.
+ + Args: + model (torch.nn.Module): Module or function to be + traced and converted into a Graph representation. + fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. It + contains the nodes this GraphModule should use for code generation. + name (str): The name of generated GraphModule. + + Returns: + GraphModule: GraphModule is an nn.Module generated from an fx.Graph. + Graphmodule has a ``graph`` attribute, as well as ``code`` and + ``forward`` attributes generated from that ``graph``. + + .. warning:: + When ``graph`` is reassigned, ``code`` and ``forward`` will be + automatically regenerated. However, if you edit the contents of the + ``graph`` without reassigning the ``graph`` attribute itself, you must + call ``recompile()`` to update the generated code. + """ modules = dict(model.named_modules()) module_dict = _prepare_module_dict(model, fx_graph) modules.update(module_dict) @@ -179,6 +186,18 @@ def build_graphmodule(model: nn.Module, fx_graph, name: str = 'GraphModule'): @TASK_UTILS.register_module() class CustomTracer(QuantizationTracer): + """Custom tracer based on QuantizationTracer of pytorch. It can not only + skip some modules and classes while tracing, but also skip some methods + untraced by torch.fx.Tracer. + + Args: + skipped_methods (List[str], optional): Methods to be skipped while + tracing. Defaults to None. + skipped_module_names (List[str], optional): Modules to be skipped + while tracing. Defaults to None. + skipped_module_classes (List[Callable], optional): Class to be skipped + while tracing. Defaults to None. + """ def __init__(self, skipped_methods: List[str] = [], @@ -186,16 +205,6 @@ def __init__(self, skipped_module_classes: List[Callable] = [], *args, **kwargs): - """_summary_ - - Args: - skipped_methods (List[str], optional): Methods to be skipped while - tracing. Defaults to None. - skipped_module_names (List[str], optional): Modules to be skipped - while tracing. Defaults to None. - skipped_module_classes (List[str], optional): Class to be skipped - while tracing. Defaults to None. - """ super(CustomTracer, self).__init__(skipped_module_names, skipped_module_classes) UntracedMethodRegistry.tracer = self # type: ignore @@ -214,6 +223,7 @@ def _check_valid_source(source): 'source must have at least one `.`' def register_skipped_methods(self): + """Register skipped methods to UntracedMethodRegistry.method_dict.""" if not isinstance(self.skipped_methods, list): self.skipped_methods = [self.skipped_methods] for s_method in self.skipped_methods: @@ -239,7 +249,8 @@ def register_skipped_methods(self): method_registry = UntracedMethodRegistry(method) method_registry.__set_name__(imported_cls, method_str) - def call_method(self, m: nn.Module, name, method, args, kwargs): + def call_method(self, m: torch.nn.Module, name: str, method: Callable, + args: Tuple, kwargs: Dict): """Method that specifies the behavior of this ``Tracer`` when it encounters a call to an ``nn.Module`` instance. @@ -254,15 +265,13 @@ def call_method(self, m: nn.Module, name, method, args, kwargs): ``Module`` boundaries. Args: - - m (Module): The module for which a call is being emitted - forward (Callable): The forward() method of the ``Module`` to be - invoked + m (torch.nn.Module): The module for which a call is being emitted + name (str): The name of proxy to be created. + method (Callable): The method of the ``Module`` to be invoked args (Tuple): args of the module callsite kwargs (Dict): kwargs of the module callsite Return: - The return value from the Module call. 
In the case that a ``call_module`` node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever value was returned from the ``Module`` @@ -271,16 +280,37 @@ def call_method(self, m: nn.Module, name, method, args, kwargs): # module_qualified_name = self.path_of_module(m) if not self.is_skipped_method(m): return method(*args, **kwargs) - args = list(args) - args.insert(0, m) - args = tuple(args) + args_l = list(args) + args_l.insert(0, m) + args = tuple(args_l) return self.create_proxy('call_method', name, args, kwargs) - def trace(self, root, concrete_args=None): - if isinstance(root, nn.Module): + def trace(self, + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None) -> Graph: + """Trace ``root`` and return the corresponding FX ``Graph`` + representation. ``root`` can either be an ``nn.Module`` instance or a + Python callable. Note that after this call, ``self.root`` may be + different from the ``root`` passed in here. For example, when a free + function is passed to ``trace()``, we will create an ``nn.Module`` + instance to use as the root and add embedded constants to. + + Args: + root (Union[Module, Callable]): Either a ``Module`` or a function + to be traced through. Backwards-compatibility for this + parameter is guaranteed. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that + should not be treated as Proxies. This parameter is + experimental and its backwards-compatibility is *NOT* + guaranteed. + + Returns: + A ``Graph`` representing the semantics of the passed-in ``root``. + """ + if isinstance(root, torch.nn.Module): self.root = root fn = type(root).forward - self.submodule_paths = { + self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = { mod: name for name, mod in root.named_modules() } @@ -364,13 +394,53 @@ def forward(*args, **kwargs): return self.graph - def is_skipped_method(self, m): + def is_skipped_method(self, m: torch.nn.Module): + """Judge if ``m`` is registered skipped method.""" mods = tuple(value['mod'] for value in UntracedMethodRegistry.method_dict.values()) custom = isinstance(m, mods) return custom - def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool: - # return super().is_leaf_module(m, module_qualified_name) + def is_leaf_module(self, m: torch.nn.Module, + module_qualified_name: str) -> bool: + """A method to specify whether a given ``nn.Module`` is a "leaf" + module. Leaf modules are the atomic units that appear in the IR, + referenced by ``call_module`` calls. By default, Modules in the PyTorch + standard library namespace (torch.nn) are leaf modules. All other + modules are traced through and their constituent ops are recorded, + unless specified otherwise via this parameter. + + Args: + m (Module): The module being queried about + module_qualified_name (str): The path to root of this module. + For example, if you have a module hierarchy where submodule + ``foo`` contains submodule ``bar``, which contains submodule + ``baz``, that module will appear with the qualified name + ``foo.bar.baz`` here. + """ leaf = super().is_leaf_module(m, module_qualified_name) return leaf + + +def custom_symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule: + """Modified `symbolic_trace` function in pytorch. Given an ``nn.Module`` or + function instance ``root``, this function will return a ``GraphModule`` + constructed by recording operations seen while tracing through ``root``. 
+ + Args: + root (torch.nn.Module): Module or function to be + traced and converted into a Graph representation. + concrete_args (Optional[Dict[str, any]]): Inputs to be partially + specialized. + + Returns: + GraphModule: a Module created from the recorded operations from + ``root``. + """ + tracer = CustomTracer() + graph = tracer.trace(root, concrete_args) + name = root.__class__.__name__ if isinstance( + root, torch.nn.Module) else root.__name__ + return GraphModule(tracer.root, graph, name) diff --git a/mmrazor/structures/quantization/qconfig.py b/mmrazor/structures/quantization/qconfig.py index e0fdf113d..2a502b8f7 100644 --- a/mmrazor/structures/quantization/qconfig.py +++ b/mmrazor/structures/quantization/qconfig.py @@ -18,7 +18,7 @@ ] -class QConfigHander(): +class QConfigHandler(): """Convert custom user-friendly qconfig format to torch's QConfig. Args: @@ -44,9 +44,9 @@ def __init__(self, qconfig: Union[Dict, Config]): w_is_per_channel = True if 'PerChannel' in a_observer.__name__: a_is_per_channel = True - self.w_qscheme = QSchemeHander( + self.w_qscheme = QSchemeHandler( is_per_channel=w_is_per_channel, **qconfig['w_qscheme']) - self.a_qscheme = QSchemeHander( + self.a_qscheme = QSchemeHandler( is_per_channel=a_is_per_channel, **qconfig['a_qscheme']) w_fake_quant = MODELS.get(qconfig['w_fake_quant']['type']) @@ -79,7 +79,7 @@ def convert(self): return torch_qconfig -class QSchemeHander(object): +class QSchemeHandler(object): """Convert the qscheme of custom user-friendly qconfig to args needed in observers. @@ -149,24 +149,3 @@ def __str__(self): return f'dtype: {self.dtype} / bit: {self.bit} / is_symmetry: {self.is_symmetry} / \ is_per_channel: {self.is_per_channel} \ / extra_kwargs: {self.kwargs}' - - -if __name__ == '__main__': - from mmrazor.models.fake_quants import register_torch_fake_quants - from mmrazor.models.observers import register_torch_observers - register_torch_observers() - register_torch_fake_quants() - - qconfig = dict( - w_observer=dict(type='mmrazor.MovingAveragePerChannelMinMaxObserver'), - a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), - w_fake_quant=dict(type='mmrazor.FakeQuantize'), - a_fake_quant=dict(type='mmrazor.FakeQuantize'), - w_qscheme=dict( - qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), - a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), - ) - from mmengine.config import Config - qconfig = Config(qconfig) - torch_qconfig = QConfigHander(qconfig).convert() - print(torch_qconfig) diff --git a/tests/data/test_models/test_task_modules/mmcls_cfg.py b/tests/data/test_models/test_task_modules/mmcls_cfg.py new file mode 100644 index 000000000..117b9383e --- /dev/null +++ b/tests/data/test_models/test_task_modules/mmcls_cfg.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] \ No newline at end of file diff --git a/tests/test_models/test_task_modules/test_custom_tracer.py b/tests/test_models/test_task_modules/test_custom_tracer.py new file mode 100644 index 000000000..207e9ccad --- /dev/null +++ b/tests/test_models/test_task_modules/test_custom_tracer.py @@ -0,0 +1,185 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase
+
+import pytest
+import torch
+from mmcls.models.backbones.resnet import ResLayer
+from mmengine.config import Config
+from mmengine.registry import MODELS
+
+try:
+    from torch.fx import GraphModule
+    from torch.fx._symbolic_trace import Graph
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    GraphModule = get_placeholder('torch>=1.13')
+    Graph = get_placeholder('torch>=1.13')
+
+from mmrazor import digit_version
+from mmrazor.models.task_modules.tracer import (CustomTracer,
+                                                UntracedMethodRegistry,
+                                                build_graphmodule,
+                                                custom_symbolic_trace)
+from mmrazor.models.task_modules.tracer.fx.custom_tracer import \
+    _prepare_module_dict
+
+
+class ToyModel(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def get_loss(self, x):
+        return x * 0.1
+
+    def extract_feature(self, x):
+        return x * 2
+
+    def forward(self, x):
+        x = self.extract_feature(x)
+        x = self.get_loss(x)
+        return x
+
+
+class TestUntracedMethodRegistry(TestCase):
+
+    def test_init(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
+        method = ToyModel.get_loss
+        method_registry = UntracedMethodRegistry(method)
+        assert hasattr(method_registry, 'method')
+        assert hasattr(method_registry, 'method_dict')
+        assert len(method_registry.method_dict) == 0
+
+    def test_registry_method(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
+        model = ToyModel
+        method = ToyModel.get_loss
+        method_registry = UntracedMethodRegistry(method)
+        method_registry.__set_name__(model, 'get_loss')
+        assert 'get_loss' in method_registry.method_dict.keys()
+        assert method_registry.method_dict['get_loss']['mod'] == model
+
+
+class TestCustomTracer(TestCase):
+
+    def setUp(self):
+        self.cfg = Config.fromfile(
+            'tests/data/test_models/test_task_modules/mmcls_cfg.py')
+        self.skipped_methods = [
+            'mmcls.models.heads.ClsHead._get_loss',
+            'mmcls.models.heads.ClsHead._get_predictions'
+        ]
+        self.skipped_module_names = ['backbone.layer4.0']
+        self.skipped_module_classes = [ResLayer]
+
+    def test_init(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
+        # init without skipped_methods
+        tracer = CustomTracer()
+        assert hasattr(tracer, 'skipped_methods')
+        assert len(tracer.skipped_methods) == 0
+        # init with skipped_methods(list)
+        UntracedMethodRegistry.method_dict = dict()
+        tracer = CustomTracer(skipped_methods=self.skipped_methods)
+        assert '_get_loss' in UntracedMethodRegistry.method_dict.keys()
+        assert '_get_predictions' in UntracedMethodRegistry.method_dict.keys()
+        # init with skipped_methods(str)
+        UntracedMethodRegistry.method_dict = dict()
+        tracer = CustomTracer(skipped_methods=self.skipped_methods[0])
+        assert '_get_loss' in UntracedMethodRegistry.method_dict.keys()
+        # init with skipped_methods(int, error)
+        with self.assertRaises(TypeError):
+            CustomTracer(skipped_methods=123)
+        # init with skipped_methods(str, error)
+        with self.assertRaises(AssertionError):
+            CustomTracer(skipped_methods='_get_loss')
+
+    def test_trace(self):
+        if digit_version(torch.__version__) < digit_version('1.13.0'):
+            self.skipTest('version of torch < 1.13.0')
+
+        # test trace with skipped_methods
+        model = MODELS.build(self.cfg.model)
+        UntracedMethodRegistry.method_dict = dict()
+        tracer = CustomTracer(skipped_methods=self.skipped_methods)
+        graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'})
+ graph_loss = tracer.trace(model, concrete_args={'mode': 'loss'}) + graph_predict = tracer.trace(model, concrete_args={'mode': 'predict'}) + assert isinstance(graph_tensor, Graph) + assert isinstance(graph_loss, Graph) + skip_flag_loss = False + for node in graph_loss.nodes: + if node.op == 'call_method' and node.target == '_get_loss': + skip_flag_loss = True + assert isinstance(graph_predict, Graph) + skip_flag_predict = False + for node in graph_predict.nodes: + if node.op == 'call_method' and node.target == '_get_predictions': + skip_flag_predict = True + assert skip_flag_loss and skip_flag_predict + + # test trace with skipped_module_names + model = MODELS.build(self.cfg.model) + UntracedMethodRegistry.method_dict = dict() + tracer = CustomTracer(skipped_module_names=self.skipped_module_names) + graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'}) + skip_flag = False + for node in graph_tensor.nodes: + skipped_module_name = self.skipped_module_names[0] + if node.op == 'call_module' and node.target == skipped_module_name: + skip_flag = True + assert skip_flag + + # test trace with skipped_module_classes + model = MODELS.build(self.cfg.model) + UntracedMethodRegistry.method_dict = dict() + tracer = CustomTracer( + skipped_module_classes=self.skipped_module_classes) + graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'}) + skip_flag = False + for node in graph_tensor.nodes: + if node.op == 'call_module' and node.target == 'backbone.layer1': + skip_flag = True + assert skip_flag + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_custom_symbolic_trace(): + cfg = Config.fromfile( + 'tests/data/test_models/test_task_modules/mmcls_cfg.py') + model = MODELS.build(cfg.model) + UntracedMethodRegistry.method_dict = dict() + graph_module = custom_symbolic_trace( + model, concrete_args={'mode': 'tensor'}) + assert isinstance(graph_module, GraphModule) + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_build_graphmodule(): + skipped_methods = ['mmcls.models.heads.ClsHead._get_predictions'] + cfg = Config.fromfile( + 'tests/data/test_models/test_task_modules/mmcls_cfg.py') + model = MODELS.build(cfg.model) + UntracedMethodRegistry.method_dict = dict() + tracer = CustomTracer(skipped_methods=skipped_methods) + graph_predict = tracer.trace(model, concrete_args={'mode': 'predict'}) + graph_module = build_graphmodule(model, graph_predict) + assert isinstance(graph_module, GraphModule) + + # test _prepare_module_dict + modules = dict(model.named_modules()) + module_dict = _prepare_module_dict(model, graph_predict) + for k, v in module_dict.items(): + assert isinstance(v, torch.nn.Module) + assert not isinstance(v, modules[k].__class__) diff --git a/tests/test_models/test_task_modules/test_graph_utils.py b/tests/test_models/test_task_modules/test_graph_utils.py index d8f53c03c..ea7f90565 100644 --- a/tests/test_models/test_task_modules/test_graph_utils.py +++ b/tests/test_models/test_task_modules/test_graph_utils.py @@ -24,7 +24,7 @@ del_fakequant_after_module, del_fakequant_after_op, del_fakequant_before_function, del_fakequant_before_method, del_fakequant_before_module, del_fakequant_before_op) -from mmrazor.structures.quantization import BackendConfigs, QConfigHander +from mmrazor.structures.quantization import BackendConfigs, QConfigHandler def _get_attrs(target, attrs): @@ -119,7 +119,7 @@ def 
setUp(self): self.tracer = CustomTracer() self.backend_config = BackendConfigs['native'] - self.qconfig = QConfigHander(global_qconfig) + self.qconfig = QConfigHandler(global_qconfig) self.qconfig_mapping = QConfigMapping().set_global( self.qconfig.convert()) self.example_inputs = (torch.randn(1, 3, 224, 224), ) diff --git a/tests/test_structures/test_qconfig.py b/tests/test_structures/test_qconfig.py index 4730ab6cc..d4f98394a 100644 --- a/tests/test_structures/test_qconfig.py +++ b/tests/test_structures/test_qconfig.py @@ -14,32 +14,32 @@ from mmrazor import digit_version from mmrazor.models.fake_quants import register_torch_fake_quants from mmrazor.models.observers import register_torch_observers -from mmrazor.structures import QConfigHander, QSchemeHander +from mmrazor.structures import QConfigHandler, QSchemeHandler register_torch_observers() register_torch_fake_quants() -class TestQSchemeHander(TestCase): +class TestQSchemeHandler(TestCase): def test_init(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') # per_channel - qscheme = QSchemeHander(is_symmetry=True, is_per_channel=True) + qscheme = QSchemeHandler(is_symmetry=True, is_per_channel=True) assert qscheme.torch_qscheme is torch.per_channel_symmetric # per_tensor - qscheme = QSchemeHander(is_symmetry=True, is_per_channel=False) + qscheme = QSchemeHandler(is_symmetry=True, is_per_channel=False) assert qscheme.torch_qscheme is torch.per_tensor_symmetric # qdtype is incorrect - self.assertRaises(AssertionError, QSchemeHander, 'float') + self.assertRaises(AssertionError, QSchemeHandler, 'float') # is_symmetric_range kwargs = {'is_symmetric_range': True} - qscheme = QSchemeHander(**kwargs) + qscheme = QSchemeHandler(**kwargs) assert qscheme.is_symmetric_range is True def test_to_observer_params(self): @@ -47,32 +47,32 @@ def test_to_observer_params(self): self.skipTest('version of torch < 1.13.0') # qdtype = quint8 - ret_params = QSchemeHander(qdtype='quint8').to_observer_params() + ret_params = QSchemeHandler(qdtype='quint8').to_observer_params() assert ret_params['dtype'] == torch.quint8 assert ret_params['quant_min'] == 0 and ret_params['quant_max'] == 255 # qdtype = qint8, is_symmetric_range=False - ret_params = QSchemeHander(qdtype='qint8').to_observer_params() + ret_params = QSchemeHandler(qdtype='qint8').to_observer_params() assert ret_params['dtype'] == torch.qint8 assert ret_params['quant_min'] == -128 and ret_params[ 'quant_max'] == 127 # qdtype = qint8, is_symmetric_range=True - ret_params = QSchemeHander( + ret_params = QSchemeHandler( qdtype='qint8', is_symmetric_range=True).to_observer_params() assert ret_params['quant_min'] == -127 and ret_params[ 'quant_max'] == 127 # per_channel - ret_params = QSchemeHander(is_per_channel=True).to_observer_params() + ret_params = QSchemeHandler(is_per_channel=True).to_observer_params() assert ret_params['ch_axis'] == 0 # per_tensor - ret_params = QSchemeHander(is_per_channel=False).to_observer_params() + ret_params = QSchemeHandler(is_per_channel=False).to_observer_params() assert 'ch_axis' not in ret_params.keys() -class TestQConfigHander(TestCase): +class TestQConfigHandler(TestCase): def setUp(self): self.qconfig_dict = dict( @@ -93,26 +93,26 @@ def test_check_qconfig(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') - assert QConfigHander.check_qconfig(self.qconfig_dict) is True - assert QConfigHander.check_qconfig(self.qconfig) is True + assert 
QConfigHandler.check_qconfig(self.qconfig_dict) is True + assert QConfigHandler.check_qconfig(self.qconfig) is True qconfig_dict = copy.copy(self.qconfig_dict) print(qconfig_dict) qconfig_dict.pop('w_observer') - assert QConfigHander.check_qconfig(qconfig_dict) is False + assert QConfigHandler.check_qconfig(qconfig_dict) is False def test_init(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') # test dict init - qconfig = QConfigHander(self.qconfig_dict) + qconfig = QConfigHandler(self.qconfig_dict) assert hasattr(qconfig, 'w_qscheme') assert hasattr(qconfig, 'a_qscheme') assert hasattr(qconfig, 'w_fake_quant') assert hasattr(qconfig, 'a_fake_quant') # test mmengine's Config init - qconfig = QConfigHander(self.qconfig) + qconfig = QConfigHandler(self.qconfig) assert hasattr(qconfig, 'w_qscheme') assert hasattr(qconfig, 'a_qscheme') assert hasattr(qconfig, 'w_fake_quant') @@ -126,6 +126,6 @@ def test_convert(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') - qconfig = QConfigHander(self.qconfig) + qconfig = QConfigHandler(self.qconfig) torch_qconfig = qconfig.convert() assert isinstance(torch_qconfig, QConfig)
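For reference, the snippet below is a minimal, standalone sketch of how the renamed QConfigHandler is driven end to end. It is adapted from the demo removed from qconfig.py in this patch; the registration helpers, import paths, and qconfig fields are taken from that removed block and from the updated tests, and it assumes mmrazor with torch >= 1.13 is installed.

# Sketch adapted from the removed `__main__` demo in
# mmrazor/structures/quantization/qconfig.py (not part of this patch).
from mmengine.config import Config

from mmrazor.models.fake_quants import register_torch_fake_quants
from mmrazor.models.observers import register_torch_observers
from mmrazor.structures import QConfigHandler

# Expose torch's native observers/fake quants through the mmrazor registry.
register_torch_observers()
register_torch_fake_quants()

qconfig = Config(
    dict(
        w_observer=dict(type='mmrazor.MovingAveragePerChannelMinMaxObserver'),
        a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
        w_fake_quant=dict(type='mmrazor.FakeQuantize'),
        a_fake_quant=dict(type='mmrazor.FakeQuantize'),
        w_qscheme=dict(
            qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True),
        a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True)))

# Convert the user-friendly dict/Config into a torch QConfig.
torch_qconfig = QConfigHandler(qconfig).convert()
print(torch_qconfig)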