From 85ed93ec5e1db67ee2aa6b6ea632c1a90f529f89 Mon Sep 17 00:00:00 2001 From: "P.Huang" <37200926+FreakieHuang@users.noreply.github.com> Date: Fri, 11 Nov 2022 10:18:20 +0800 Subject: [PATCH 01/44] [FEATURE] add quant algo `Learned Step Size Quantization` (#346) * update * Fix a bug in make_divisible. (#333) fix bug in make_divisible Co-authored-by: liukai * [Fix] Fix counter mapping bug (#331) * fix counter mapping bug * move judgment into get_counter_type & update UT * [Docs]Add MMYOLO projects link (#334) * [Doc] fix typos in en/usr_guides (#299) * Update README.md * Update README_zh-CN.md Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> * [Features]Support `MethodInputsRecorder` and `FunctionInputsRecorder` (#320) * support MethodInputsRecorder and FunctionInputsRecorder * fix bugs that the model can not be pickled * WIP: add pytest for ema model * fix bugs in recorder and delivery when ema_hook is used * don't register the DummyDataset * fix pytest * updated * retina loss & predict & tesnor DONE * [Feature] Add deit-base (#332) * WIP: support deit * WIP: add deithead * WIP: fix checkpoint hook * fix data preprocessor * fix cfg * WIP: add readme * reset single_teacher_distill * add metafile * add model to model-index * fix configs and readme * [Feature]Feature map visualization (#293) * WIP: vis * WIP: add visualization * WIP: add visualization hook * WIP: support razor visualizer * WIP * WIP: wrap draw_featmap * support feature map visualization * add a demo image for visualization * fix typos * change eps to 1e-6 * add pytest for visualization * fix vis hook * fix arguments' name * fix img path * support draw inference results * add visualization doc * fix figure url * move files Co-authored-by: weihan cao * [Feature] Add kd examples (#305) * support kd for mbv2 and shufflenetv2 * WIP: fix ckpt path * WIP: fix kd r34-r18 * add metafile * fix metafile * delete * [Doc] add documents about pruning. (#313) * init * update user guide * update images * update * update How to prune your model * update how_to_use_config_tool_of_pruning.md * update doc * move location * update * update * update * add mutablechannels.md * add references Co-authored-by: liukai Co-authored-by: jacky * [Feature] PyTorch version of `PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient`. (#304) * add pkd * add pytest for pkd * fix cfg * WIP: support fcos3d * WIP: support fcos3d pkd * support mmdet3d * fix cfgs * change eps to 1e-6 and add some comments * fix docstring * fix cfg * add assert * add type hint * WIP: add readme and metafile * fix readme * update metafiles and readme * fix metafile * fix pipeline figure * for RFC * Customed FX initialize * add UT init * [Refactor] Refactor Mutables and Mutators (#324) * refactor mutables * update load fix subnet * add DumpChosen Typehint * adapt UTs * fix lint * Add GroupMixin to ChannelMutator (temporarily) * fix type hints * add GroupMixin doc-string * modified by comments * fix type hits * update subnet format * fix channel group bugs and add UTs * fix doc string * fix comments * refactor diff module forward * fix error in channel mutator doc * fix comments Co-authored-by: liukai * [Fix] Update readme (#341) * update kl readme * update dsnas readme * fix url * Bump version to 1.0.0rc1 (#338) update version * init demo * add customer_tracer * add quantizer * add fake_quant, loop, config * remove CPatcher in custome_tracer * demo_try * init version * modified base.py * pre-rebase * wip of adaround series * adaround experiment * trasfer to s2 * update api * point at sub_reconstruction * pre-checkout * export onnx * add customtracer * fix lint * move custom tracer * fix import * TDO: UTs * Successfully RUN * update loop * update loop docstrings * update quantizer docstrings * update qscheme docstrings * update qobserver docstrings * update tracer docstrings * update UTs init * update UTs init * fix review comments * fix CI * fix UTs * update torch requirements Co-authored-by: huangpengsheng Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: Yang Gao Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com> Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: jacky Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com> Co-authored-by: humu789 --- .github/workflows/build.yml | 38 --- configs/quantization/ptq/adaround.py | 47 +++ configs/quantization/ptq/demo.py | 1 + configs/quantization/qat/demo.py | 1 + .../qat/lsq_resnet50_8xb16_cifar10.py | 37 +++ mmrazor/engine/runner/__init__.py | 4 +- mmrazor/engine/runner/quantization_loops.py | 298 ++++++++++++++++++ mmrazor/engine/runner/utils/state.py | 19 ++ mmrazor/engine/runner/utils/subgraph.py | 61 ++++ mmrazor/models/__init__.py | 4 + mmrazor/models/algorithms/__init__.py | 3 +- .../algorithms/quantization/__init__.py | 4 + .../models/algorithms/quantization/base.py | 116 +++++++ mmrazor/models/fake_quants/__init__.py | 10 + mmrazor/models/fake_quants/adaround.py | 98 ++++++ mmrazor/models/fake_quants/base.py | 124 ++++++++ mmrazor/models/fake_quants/lsq.py | 137 ++++++++ mmrazor/models/fake_quants/qdrop.py | 44 +++ mmrazor/models/losses/__init__.py | 1 + mmrazor/models/losses/adaround_loss.py | 87 +++++ .../mutable_channel/units/channel_unit.py | 5 + .../channel_mutator/channel_mutator.py | 18 +- .../slimmable_channel_mutator.py | 2 +- .../module_mutator/diff_module_mutator.py | 117 +++++++ .../mutators/module_mutator/module_mutator.py | 94 ++++++ mmrazor/models/observers/__init__.py | 5 + mmrazor/models/observers/base.py | 74 +++++ mmrazor/models/observers/lsq_observer.py | 59 ++++ mmrazor/models/observers/minmax.py | 97 ++++++ mmrazor/models/observers/mse.py | 156 +++++++++ mmrazor/models/quantizers/__init__.py | 5 + mmrazor/models/quantizers/base.py | 194 ++++++++++++ mmrazor/models/quantizers/trt_quantizer.py | 23 ++ .../models/task_modules/tracer/__init__.py | 4 +- .../models/task_modules/tracer/fx/__init__.py | 5 + .../task_modules/tracer/fx/custom_tracer.py | 281 +++++++++++++++++ mmrazor/models/utils/quantization_util.py | 217 +++++++++++++ mmrazor/structures/quantization/__init__.py | 5 + .../quantization/backend_default_qconfigs.py | 46 +++ mmrazor/structures/quantization/qscheme.py | 68 ++++ mmrazor/testing/__init__.py | 1 + mmrazor/testing/_fx_models.py | 42 +++ .../test_algorithms/test_general_quant.py | 34 ++ .../test_lsq_fake_quants.py | 23 ++ .../test_mutators/test_diff_mutator.py | 235 ++++++++++++++ .../test_observers/test_observer.py | 38 +++ .../test_quantizers/test_trt_quantizer.py | 34 ++ .../test_task_modules/test_custom_tracer.py | 35 ++ tools/ckpt_demo.py | 13 + tools/slurm_test.sh | 26 +- tools/tracer_demo.py | 93 ++++++ 51 files changed, 3116 insertions(+), 67 deletions(-) create mode 100644 configs/quantization/ptq/adaround.py create mode 100644 configs/quantization/ptq/demo.py create mode 100644 configs/quantization/qat/demo.py create mode 100644 configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py create mode 100644 mmrazor/engine/runner/quantization_loops.py create mode 100644 mmrazor/engine/runner/utils/state.py create mode 100644 mmrazor/engine/runner/utils/subgraph.py create mode 100644 mmrazor/models/algorithms/quantization/__init__.py create mode 100644 mmrazor/models/algorithms/quantization/base.py create mode 100644 mmrazor/models/fake_quants/__init__.py create mode 100644 mmrazor/models/fake_quants/adaround.py create mode 100644 mmrazor/models/fake_quants/base.py create mode 100644 mmrazor/models/fake_quants/lsq.py create mode 100644 mmrazor/models/fake_quants/qdrop.py create mode 100644 mmrazor/models/losses/adaround_loss.py create mode 100644 mmrazor/models/mutators/module_mutator/diff_module_mutator.py create mode 100644 mmrazor/models/mutators/module_mutator/module_mutator.py create mode 100644 mmrazor/models/observers/__init__.py create mode 100644 mmrazor/models/observers/base.py create mode 100644 mmrazor/models/observers/lsq_observer.py create mode 100644 mmrazor/models/observers/minmax.py create mode 100644 mmrazor/models/observers/mse.py create mode 100644 mmrazor/models/quantizers/__init__.py create mode 100644 mmrazor/models/quantizers/base.py create mode 100644 mmrazor/models/quantizers/trt_quantizer.py create mode 100644 mmrazor/models/task_modules/tracer/fx/__init__.py create mode 100644 mmrazor/models/task_modules/tracer/fx/custom_tracer.py create mode 100644 mmrazor/models/utils/quantization_util.py create mode 100644 mmrazor/structures/quantization/__init__.py create mode 100644 mmrazor/structures/quantization/backend_default_qconfigs.py create mode 100644 mmrazor/structures/quantization/qscheme.py create mode 100644 mmrazor/testing/_fx_models.py create mode 100644 tests/test_models/test_algorithms/test_general_quant.py create mode 100644 tests/test_models/test_fake_quantize/test_lsq_fake_quants.py create mode 100644 tests/test_models/test_mutators/test_diff_mutator.py create mode 100644 tests/test_models/test_observers/test_observer.py create mode 100644 tests/test_models/test_quantizers/test_trt_quantizer.py create mode 100644 tests/test_models/test_task_modules/test_custom_tracer.py create mode 100644 tools/ckpt_demo.py create mode 100644 tools/tracer_demo.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e00ed24c8..53a184a3d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,44 +31,6 @@ jobs: python-version: [3.7] torch: [1.6.0, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0] include: - - torch: 1.6.0 - torch_version: 1.6 - torchvision: 0.7.0 - - torch: 1.7.0 - torch_version: 1.7 - torchvision: 0.8.1 - - torch: 1.7.0 - torch_version: 1.7 - torchvision: 0.8.1 - python-version: 3.8 - - torch: 1.8.0 - torch_version: 1.8 - torchvision: 0.9.0 - - torch: 1.8.0 - torch_version: 1.8 - torchvision: 0.9.0 - python-version: 3.8 - - torch: 1.9.0 - torch_version: 1.9 - torchvision: 0.10.0 - - torch: 1.9.0 - torch_version: 1.9 - torchvision: 0.10.0 - python-version: 3.8 - - torch: 1.10.0 - torch_version: 1.10 - torchvision: 0.11.0 - - torch: 1.10.0 - torch_version: 1.10 - torchvision: 0.11.0 - python-version: 3.8 - - torch: 1.11.0 - torch_version: 1.11 - torchvision: 0.12.0 - - torch: 1.11.0 - torch_version: 1.11 - torchvision: 0.12.0 - python-version: 3.8 - torch: 1.12.0 torch_version: 1.12 torchvision: 0.13.0 diff --git a/configs/quantization/ptq/adaround.py b/configs/quantization/ptq/adaround.py new file mode 100644 index 000000000..389575dc6 --- /dev/null +++ b/configs/quantization/ptq/adaround.py @@ -0,0 +1,47 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +test_cfg = dict( + _delete_=True, + type='mmrazor.PTQLoop', + dataloader=_base_.test_dataloader, + evaluator=_base_.test_evaluator, + calibrate_dataloader=_base_.train_dataloader, + batch_num=32, + # reconstruction_cfg=dict( + # pattern='layer', + # loss=dict( + # type='mmrazor.AdaRoundLoss', + # iters=20000 + # ) + # ) +) + +model = dict( + _delete_=True, + type='mmrazor.GeneralQuant', + architecture=_base_.model, + quantizer=dict( + type='mmrazor.CustomQuantizer', + is_qat=False, + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ], + qconfig=dict( + qtype='affine', + w_observer=dict(type='mmrazor.MSEObserver'), + a_observer=dict(type='mmrazor.EMAMSEObserver'), + w_fake_quant=dict(type='mmrazor.AdaRoundFakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + bit=2, + is_symmetry=False, + is_per_channel=True, + is_pot_scale=False, + ), + a_qscheme=dict( + bit=4, + is_symmetry=False, + is_per_channel=False, + is_pot_scale=False), + ))) diff --git a/configs/quantization/ptq/demo.py b/configs/quantization/ptq/demo.py new file mode 100644 index 000000000..af6a0a5df --- /dev/null +++ b/configs/quantization/ptq/demo.py @@ -0,0 +1 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] diff --git a/configs/quantization/qat/demo.py b/configs/quantization/qat/demo.py new file mode 100644 index 000000000..be3ec6013 --- /dev/null +++ b/configs/quantization/qat/demo.py @@ -0,0 +1 @@ +_base_ = ['./lsq_resnet50_8xb16_cifar10.py'] diff --git a/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py b/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py new file mode 100644 index 000000000..a246bc265 --- /dev/null +++ b/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py @@ -0,0 +1,37 @@ +_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] + +train_cfg = dict( + _delete_=True, + type='mmrazor.QATEpochBasedLoop', + max_epochs=_base_.train_cfg.max_epochs, +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GeneralQuant', + architecture={{_base_.model}}, + quantizer=dict( + type='TensorRTQuantizer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ], + qconfig=dict( + qtype='affine', + w_observer=dict(type='mmrazor.MinMaxObserver'), + a_observer=dict(type='mmrazor.EMAMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + bit=2, + is_symmetry=False, + is_per_channel=True, + is_pot_scale=False, + ), + a_qscheme=dict( + bit=4, + is_symmetry=False, + is_per_channel=False, + is_pot_scale=False), + ))) diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py index 10eb2b598..647d8b410 100644 --- a/mmrazor/engine/runner/__init__.py +++ b/mmrazor/engine/runner/__init__.py @@ -4,6 +4,7 @@ from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop from .evolution_search_loop import EvolutionSearchLoop from .iteprune_val_loop import ItePruneValLoop +from .quantization_loops import PTQLoop, QATEpochBasedLoop from .slimmable_val_loop import SlimmableValLoop from .subnet_sampler_loop import GreedySamplerTrainLoop from .subnet_val_loop import SubnetValLoop @@ -12,5 +13,6 @@ 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop', - 'ItePruneValLoop', 'AutoSlimGreedySearchLoop' + 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'PTQLoop', + 'QATEpochBasedLoop' ] diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py new file mode 100644 index 000000000..2f15f5deb --- /dev/null +++ b/mmrazor/engine/runner/quantization_loops.py @@ -0,0 +1,298 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.evaluator import Evaluator +from mmengine.registry import MODELS +from mmengine.runner import EpochBasedTrainLoop, TestLoop +from torch.utils.data import DataLoader + +from mmrazor.models.task_modules import (ModuleInputsRecorder, + ModuleOutputsRecorder, + RecorderManager) +from mmrazor.registry import LOOPS +from .utils import extract_blocks, extract_layers, extract_subgraph + +_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear) + + +@LOOPS.register_module() +class QATEpochBasedLoop(EpochBasedTrainLoop): + """`EpochBasedLoop` for `QuantizationAwareTraining` + + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + max_epochs (int): Total training epochs. + calibrate_dataloader (Dataloader or dict, optional): A dataloader + object or a dict to build a dataloader for calibration. Defaults + to None. + val_begin (int): The epoch that begins validating. + Defaults to 1. + val_interval (int): Validation interval. Defaults to 1. + dynamic_intervals (List[Tuple[int, int]], optional): The + first element in the tuple is a milestone and the second + element is a interval. The interval is used after the + corresponding milestone. Defaults to None. + """ + + def __init__( + self, + runner, + dataloader: Union[DataLoader, Dict], + max_epochs: int, + calibrate_dataloader: Union[DataLoader, Dict] = None, + val_begin: int = 1, + val_interval: int = 1, + dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: + super().__init__(runner, dataloader, max_epochs, val_begin, + val_interval, dynamic_intervals) + if isinstance(calibrate_dataloader, dict): + # Determine whether or not different ranks use different seed. + diff_rank_seed = runner._randomness_cfg.get( + 'diff_rank_seed', False) + self.calibrate_dataloader = runner.build_dataloader( + calibrate_dataloader, + seed=runner.seed, + diff_rank_seed=diff_rank_seed) + else: + self.calibrate_dataloader = calibrate_dataloader + + self.is_calibrate = True if calibrate_dataloader is not None else False + + if self.runner.distributed: + self.model = runner.model.module + else: + self.model = runner.model + + def calibrate(self, calibrate_dataloader) -> None: + self.model.eval() + with torch.no_grad(): + for batch_data in calibrate_dataloader: + self.model(batch_data) + + def run(self) -> None: + """Launch training.""" + self.runner.call_hook('before_train') + + self.model.prepare() + + if self.is_calibrate: + self.model.state = (1, 0) + self.calibrate(self.calibrate_dataloader) + + self.model.state = (1, 1) + + while self._epoch < self._max_epochs: + self.run_epoch() + + self._decide_current_val_interval() + if (self.runner.val_loop is not None + and self._epoch >= self.val_begin + and self._epoch % self.val_interval == 0): + self.runner.val_loop.run() + + self.model.convert() + + # self.runner.val_loop.run() + + self.runner.call_hook('after_train') + + +@LOOPS.register_module() +class PTQLoop(TestLoop): + """`TestLoop` for Post Training Quantization. + + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + evaluator (Evaluator or dict or list): Used for computing metrics. + calibrate_dataloader (Dataloader or dict, optional): A dataloader + object or a dict to build a dataloader for calibration. Defaults + to None. + batch_num (Optional[int], optional): Total calibration batches. + Defaults to None. + reconstruction_cfg (Optional[Dict], optional): Model reconstruction + configuration. Defaults to None. + fp16 (bool, optional): Enable FP16 training mode. Defaults to False. + """ + + def __init__(self, + runner, + dataloader: Union[DataLoader, Dict], + evaluator: Union[Evaluator, Dict, List], + calibrate_dataloader: Optional[Union[DataLoader, + Dict]] = None, + batch_num: Optional[int] = None, + reconstruction_cfg: Optional[Dict] = None, + fp16: bool = False): + super().__init__(runner, dataloader, evaluator, fp16) + if isinstance(calibrate_dataloader, dict): + # Determine whether or not different ranks use different seed. + diff_rank_seed = runner._randomness_cfg.get( + 'diff_rank_seed', False) + self.calibrate_dataloader = runner.build_dataloader( + calibrate_dataloader, + seed=runner.seed, + diff_rank_seed=diff_rank_seed) + else: + self.calibrate_dataloader = calibrate_dataloader + + self.is_calibrate = True if calibrate_dataloader is not None else False + + if self.runner.distributed: + self.model = runner.model.module + else: + self.model = runner.model + + self.batch_num = batch_num + self.config = reconstruction_cfg + + def calibrate(self, calibrate_dataloader) -> None: + self.model.eval() + with torch.no_grad(): + for i, batch_data in enumerate(calibrate_dataloader): + if self.batch_num and i >= self.batch_num: + break + self.model.calib_step(batch_data) + + def _save_inter_result(self, + model, + dataloader, + slices, + store_input=True, + store_output=True): + recorders = {} + for s in slices: + node_l, node_r = s[:2] + if store_input: + recorders[node_l.target + '_input'] = ModuleInputsRecorder( + node_l.target) + if store_output: + recorders[node_r.target + '_output'] = ModuleOutputsRecorder( + node_r.target) + manager = RecorderManager(recorders) + manager.initialize(model) + + with torch.no_grad(): + with manager: + for i, batch_data in enumerate(dataloader): + if self.batch_num and i >= self.batch_num: + break + batch_data = self.model.data_preprocessor( + batch_data, False) + model(**batch_data) + return manager + + def sub_reconstruction(self, graphmodule, input_recorder, output_recorder, + config): + w_para = [] + for layer in graphmodule.modules(): + # import pdb + # pdb.set_trace() + if isinstance(layer, _ADAROUND_SUPPORT_TYPE): + weight_fake_quant = layer.weight_fake_quant + weight_fake_quant.init(layer.weight.data) + w_para += [weight_fake_quant.alpha] + + w_opt = torch.optim.Adam(w_para) + loss_func = MODELS.build(config.loss) + + for _ in range(config.loss.iters): + w_opt.zero_grad() + + data_size = len(input_recorder.data_buffer) + data_index = np.random.randint(0, data_size) + out_quant = graphmodule( + input_recorder.get_recorder_data(data_index)) + out_fp = output_recorder.get_recorder_data(data_index) + err = loss_func(graphmodule, out_quant, out_fp) + err.backward() + w_opt.step() + + for layer in graphmodule.modules(): + if isinstance(layer, _ADAROUND_SUPPORT_TYPE): + weight_fake_quant = layer.weight_fake_quant + layer.weight.data = weight_fake_quant.get_hard_value( + layer.weight.data) + weight_fake_quant.adaround = False + if isinstance(layer, torch.quantization.FakeQuantize) and hasattr( + layer, 'prob'): + # recover to promise that drop activation quantization only + # occurs at reconstruction phase + layer.prob = 1.0 + + def reconstruction(self, graphmodule, calibrate_dataloader, config): + assert isinstance(graphmodule, torch.fx.GraphModule) + graphmodule_fp = graphmodule + graphmodule_quant = copy.deepcopy(graphmodule) + + # get layers/blocks need to reconstructe + slices = [] + if config.pattern == 'layer': + slices = extract_layers( + graphmodule, layer_types=_ADAROUND_SUPPORT_TYPE) + elif config.pattern == 'block': + slices = extract_blocks(graphmodule) + else: + # TODO: add remind + raise NotImplementedError + + # save fp inputs and outputs of each layers + manager_fp = self._save_inter_result(graphmodule_fp, + self.calibrate_dataloader, slices) + + # extract subgraph_module + for s in slices: + sub_graphmodule = extract_subgraph(graphmodule_quant, s) + manager_quant = self._save_inter_result( + graphmodule_quant, + self.calibrate_dataloader, [s], + store_output=False) + input_index = s[0].target + '_input' + output_index = s[1].target + '_output' + input_recorder = manager_quant.get_recorder(input_index) + output_recorder = manager_fp.get_recorder(output_index) + self.sub_reconstruction(sub_graphmodule, input_recorder, + output_recorder, config) + + return graphmodule_quant + + def run(self) -> None: + """Launch test.""" + self.runner.call_hook('before_test') + self.runner.call_hook('before_test_epoch') + + self.model.eval() + self.model.prepare() + + if self.is_calibrate: + self.model.state = (1, 0) + self.calibrate(self.calibrate_dataloader) + + self.model.state = (1, 1) + + if self.config is not None: + self.model.architecture = self.reconstruction( + self.model.architecture, self.calibrate_dataloader, + self.config) + + self.model.convert() + + self.model.eval() + from torch.onnx import OperatorExportTypes + dummy_input = torch.randn([1, 3, 224, 224]) + onnx_path = os.path.join(self.runner.work_dir, 'quantizied.onnx') + torch.onnx.export( + self.model.architecture, + dummy_input, + onnx_path, + opset_version=11, + operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK) + + self.runner.call_hook('after_test') diff --git a/mmrazor/engine/runner/utils/state.py b/mmrazor/engine/runner/utils/state.py new file mode 100644 index 000000000..2f6d602a5 --- /dev/null +++ b/mmrazor/engine/runner/utils/state.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.logging import print_log +from torch.ao.quantization import FakeQuantize + + +# TODO: may be removed +def set_quant_state(model, enable_observer=True, enable_fake_quant=True): + for name, submodule in model.named_modules(): + if isinstance(submodule, FakeQuantize): + if enable_observer: + submodule.enable_observer() + else: + submodule.disable_observer() + if enable_fake_quant: + submodule.enable_fake_quant() + else: + submodule.disable_fake_quant() + print_log(f'Enable observer: {enable_observer}; \ + Enable fake quant: {enable_fake_quant}') diff --git a/mmrazor/engine/runner/utils/subgraph.py b/mmrazor/engine/runner/utils/subgraph.py new file mode 100644 index 000000000..ea0f8837f --- /dev/null +++ b/mmrazor/engine/runner/utils/subgraph.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch.fx as fx + + +def extract_subgraph(graphmodule, block_slice): + subgraph = copy.deepcopy(graphmodule.graph) + block_start, block_end = block_slice[:2] + for node in subgraph.nodes: + if node.name == 'inputs': + input_node = node + if node.name == block_start.name: + node.replace_input_with(node.prev, input_node) + if node.name == block_end.name: + output_node = node + if node.op == 'output': + node.replace_input_with(node.prev, output_node) + subgraph.lint() + subgraph_module = fx.GraphModule(graphmodule, subgraph) + subgraph_module.graph.eliminate_dead_code() + subgraph_module.recompile() + return subgraph_module + + +def extract_blocks(graph, key_word='layer'): + block_slices = [] + block_slice = [] + pre_stage_index, pre_block_index = 0, 0 + cur_stage_index, cur_block_index = 0, 0 + for node in graph.nodes: + if key_word not in node.name: + continue + else: + items = node.name.split('_') + for i, item in enumerate(items): + if key_word in item: + cur_stage_index = int(item[5:]) + cur_block_index = int(items[i + 1]) + break + if (cur_block_index != pre_block_index) or (cur_stage_index != + pre_stage_index): + block_slice.append(node.prev) + if len(block_slice) == 2: + block_slices.append(block_slice) + block_slice = [] + block_slice.append(node) + + pre_stage_index, pre_block_index = cur_stage_index, cur_block_index + + return block_slices + + +def extract_layers(graphmodule, layer_types): + layer_slices = [] + for node in graphmodule.graph.nodes: + if node.op == 'call_module': + m = graphmodule.get_submodule(node.target) + if isinstance(m, layer_types): + layer_slices.append((node, node)) + return layer_slices diff --git a/mmrazor/models/__init__.py b/mmrazor/models/__init__.py index f5295aa9e..e5b9ec451 100644 --- a/mmrazor/models/__init__.py +++ b/mmrazor/models/__init__.py @@ -2,7 +2,11 @@ from .algorithms import * # noqa: F401,F403 from .architectures import * # noqa: F401,F403 from .distillers import * # noqa: F401,F403 +from .fake_quants import * # noqa: F401,F403 from .losses import * # noqa: F401,F403 from .mutables import * # noqa: F401,F403 from .mutators import * # noqa: F401,F403 +from .observers import * # noqa: F401,F403 +from .quantizers import * # noqa: F401,F403 from .task_modules import * # noqa: F401,F403 +from .utils import * # noqa: F401,F403 diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index 3cef96dfe..89a11b899 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -7,6 +7,7 @@ BigNAS, BigNASDDP, Darts, DartsDDP) from .pruning import DCFF, DMCP, DMCPDDP, SlimmableNetwork, SlimmableNetworkDDP from .pruning.ite_prune_algorithm import ItePruneAlgorithm +from .quantization import GeneralQuant __all__ = [ 'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS', @@ -14,5 +15,5 @@ 'Darts', 'DartsDDP', 'DCFF', 'SelfDistill', 'DataFreeDistillation', 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'Autoformer', 'BigNAS', - 'BigNASDDP', 'DMCP', 'DMCPDDP' + 'BigNASDDP', 'DMCP', 'DMCPDDP', 'GeneralQuant' ] diff --git a/mmrazor/models/algorithms/quantization/__init__.py b/mmrazor/models/algorithms/quantization/__init__.py new file mode 100644 index 000000000..84c25bbc0 --- /dev/null +++ b/mmrazor/models/algorithms/quantization/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import GeneralQuant + +__all__ = ['GeneralQuant'] diff --git a/mmrazor/models/algorithms/quantization/base.py b/mmrazor/models/algorithms/quantization/base.py new file mode 100644 index 000000000..718b08725 --- /dev/null +++ b/mmrazor/models/algorithms/quantization/base.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + +import torch +from mmengine.structures import BaseDataElement +from torch.fx import GraphModule + +from mmrazor.registry import MODELS +from ..base import BaseAlgorithm + +LossResults = Dict[str, torch.Tensor] +TensorResults = Union[Tuple[torch.Tensor], torch.Tensor] +PredictResults = List[BaseDataElement] +ForwardResults = Union[LossResults, TensorResults, PredictResults] + + +@MODELS.register_module() +class GeneralQuant(BaseAlgorithm): + """General quantization. + + Args: + Args: + architecture (dict | :obj:`BaseModel`): The config of + :class:`BaseModel` or built model. + quantizer (dict | :obj:`BaseModel`): The config of + :class:`BaseQuantizer` or built model. + data_preprocessor (dict | torch.nn.Module | None): The pre-process + config of :class:`BaseDataPreprocessor`. Defaults to None. + init_cfg (dict): The weight initialized config for + :class:`BaseModule`. + """ + + def __init__(self, + architecture, + quantizer, + data_preprocessor=None, + init_cfg=None): + if data_preprocessor is None: + data_preprocessor = {} + # The build process is in MMEngine, so we need to add scope here. + data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor') + + super().__init__(architecture, data_preprocessor, init_cfg) + self.quantizer = MODELS.build(quantizer) + self.observers_enabled = True + self.fake_quants_enabled = True + self.gen_graphs(self.architecture) + + def gen_graphs(self, model): + self.quantizer._swap_ff_with_fxff(model) + tracer = self.quantizer.tracer + for mode in ['tensor', 'loss', 'predict']: + concrete_args = {'mode': mode} + if mode == 'tensor': + self.graph_tensor = GraphModule( + model, tracer.trace(model, concrete_args=concrete_args)) + if mode == 'loss': + self.graph_loss = GraphModule( + model, tracer.trace(model, concrete_args=concrete_args)) + if mode == 'predict': + self.graph_predict = GraphModule( + model, tracer.trace(model, concrete_args=concrete_args)) + + def forward(self, + inputs: torch.Tensor, + data_samples: Optional[List[BaseDataElement]] = None, + mode: str = 'tensor') -> ForwardResults: + + if mode == 'loss': + return self.graph_loss(inputs, data_samples, mode) + elif mode == 'tensor': + return self.graph_tensor(inputs, data_samples, mode) + elif mode == 'predict': + return self.graph_predict(inputs, data_samples, mode) + else: + raise RuntimeError(f'Invalid mode "{mode}". ' + 'Only supports loss, predict and tensor mode') + + def calib_step(self, data): + data = self.data_preprocessor(data, False) + return self._run_forward(data, mode='tensor') + + def prepare(self, mode='tensor'): + assert mode in ['tensor', 'loss', 'predict'] + if mode == 'tensor': + graph = self.graph_tensor + elif mode == 'loss': + graph = self.graph_loss + else: + graph = self.graph_predict + self.architecture = self.quantizer.prepare(self.architecture, graph) + + def convert(self): + self.architecture = self.quantizer.convert(self.architecture) + + @property + def state(self): + return (self.observers_enabled, self.fake_quants_enabled) + + @state.setter + def state(self, state): + observers_enabled, fake_quants_enabled = state + for name, submodule in self.architecture.named_modules(): + if isinstance(submodule, torch.quantization.FakeQuantize): + if observers_enabled: + submodule.enable_observer() + else: + submodule.disable_observer() + + if fake_quants_enabled: + submodule.enable_fake_quant() + else: + submodule.disable_fake_quant() + + self.observers_enabled = observers_enabled + self.fake_quants_enabled = fake_quants_enabled diff --git a/mmrazor/models/fake_quants/__init__.py b/mmrazor/models/fake_quants/__init__.py new file mode 100644 index 000000000..cea7708a2 --- /dev/null +++ b/mmrazor/models/fake_quants/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .adaround import AdaRoundFakeQuantize +from .base import FakeQuantize +from .lsq import LearnableFakeQuantize +from .qdrop import QDropFakeQuantize + +__all__ = [ + 'FakeQuantize', 'AdaRoundFakeQuantize', 'QDropFakeQuantize', + 'LearnableFakeQuantize' +] diff --git a/mmrazor/models/fake_quants/adaround.py b/mmrazor/models/fake_quants/adaround.py new file mode 100644 index 000000000..9388f1aa4 --- /dev/null +++ b/mmrazor/models/fake_quants/adaround.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parameter import Parameter + +from mmrazor.registry import MODELS +from .base import FakeQuantize + +_version_under_1100 = int(torch.__version__.split('.')[1]) < 10 + + +@MODELS.register_module() +class AdaRoundFakeQuantize(FakeQuantize): + + def __init__(self, observer, **observer_kwargs): + super().__init__(observer, **observer_kwargs) + self.adaround = False + + def init(self, weight_tensor: torch.Tensor): + self.adaround = True + self.observer_enabled[0] = 0 + self.fake_quant_enabled[0] = 1 + + # self.soft_targets = False # delete this + self.gamma = -0.1 + self.zeta = 1.1 + self.init_alpha(x=weight_tensor.data.clone()) + + def init_alpha(self, x: torch.Tensor): + if self.ch_axis != -1: + new_shape = [1] * len(x.shape) + new_shape[self.ch_axis] = x.shape[self.ch_axis] + scale = self.scale.data.reshape(new_shape) + else: + scale = self.scale.data + x_floor = torch.floor(x / scale) + rest = (x / scale) - x_floor # rest of rounding [0, 1) + alpha = -torch.log((self.zeta - self.gamma) / + (rest - self.gamma) - 1) # => sigmoid(alpha) = rest + self.alpha = Parameter(alpha) + + def rectified_sigmoid(self): + """Function to generate rounding mask. + + Args: + x (torch.Tensor): + zeta (torch.Tensor): + gamma (torch.Tensor): + Returns: + torch.Tensor: + """ + return ((self.zeta - self.gamma) * torch.sigmoid(self.alpha) + + self.gamma).clamp(0, 1) + + def adaround_forward(self, x, hard_value=False): + if self.ch_axis != -1: + new_shape = [1] * len(x.shape) + new_shape[self.ch_axis] = x.shape[self.ch_axis] + scale = self.scale.reshape(new_shape) + zero_point = self.zero_point.reshape(new_shape) + x = torch.floor(x / scale) + if hard_value: + x += (self.alpha >= 0).float() + else: + x += self.rectified_sigmoid(self.alpha, self.zeta, self.gamma) + x += zero_point + x = torch.clamp(x, self.quant_min, self.quant_max) + x = (x - zero_point) * scale + return x + + def forward(self, X): + if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = self.calculate_qparams() + _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to( + self.zero_point.device) + if self.scale.shape != _scale.shape: + self.scale.resize_(_scale.shape) + self.zero_point.resize_(_zero_point.shape) + self.scale.copy_(_scale) + self.zero_point.copy_(_zero_point) + + if self.fake_quant_enabled[0] == 1: + if not self.adaround: + if self.is_per_channel: + X = torch.fake_quantize_per_channel_affine( + X, self.scale, + self.zero_point.long() + if _version_under_1100 else self.zero_point, + self.ch_axis, self.quant_min, self.quant_max) + else: + X = torch.fake_quantize_per_tensor_affine( + X, self.scale.item(), int(self.zero_point.item()), + self.quant_min, self.quant_max) + else: + if not hasattr(self, 'alpha'): + raise NotImplementedError + X = self.adaround_forward(X) + return X diff --git a/mmrazor/models/fake_quants/base.py b/mmrazor/models/fake_quants/base.py new file mode 100644 index 000000000..13f8a1e43 --- /dev/null +++ b/mmrazor/models/fake_quants/base.py @@ -0,0 +1,124 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.ao.quantization import FakeQuantizeBase + +from mmrazor.models.utils import (_is_float_qparams, _is_per_channel, + _is_per_tensor, _is_symmetric_quant) +from mmrazor.registry import MODELS + + +@MODELS.register_module() +class FakeQuantize(FakeQuantizeBase): + + scale: torch.Tensor + zero_point: torch.Tensor + + def __init__(self, observer, **observer_kwargs): + super().__init__() + self.activation_post_process = observer(**observer_kwargs) + self.quant_min = self.activation_post_process.quant_min + self.quant_max = self.activation_post_process.quant_max + if _is_float_qparams(self.activation_post_process.qscheme): + zero_point_dtype = torch.float + else: + zero_point_dtype = torch.int + self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float)) + self.register_buffer('zero_point', + torch.tensor([0], dtype=zero_point_dtype)) + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = self.activation_post_process.ch_axis \ + if hasattr(self.activation_post_process, 'ch_axis') else -1 + assert _is_per_channel(self.qscheme) or \ + _is_per_tensor(self.qscheme), \ + 'Only per channel and per tensor quantization are supported in ' \ + 'fake quantize' + ' got qscheme: ' + str(self.qscheme) + self.is_per_channel = _is_per_channel(self.qscheme) + + bitrange = torch.tensor(self.quant_max - self.quant_min + 1).double() + self.bitwidth = int(torch.log2(bitrange).item()) + self.is_pot_scale = self.activation_post_process.is_pot_scale + self.is_symmetric_quant = _is_symmetric_quant(self.qscheme) + + @torch.jit.export + def calculate_qparams(self): + return self.activation_post_process.calculate_qparams() + + def forward(self, X): + if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = self.calculate_qparams() + _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to( + self.zero_point.device) + if self.scale.shape != _scale.shape: + self.scale.resize_(_scale.shape) + self.zero_point.resize_(_zero_point.shape) + self.scale.copy_(_scale) + self.zero_point.copy_(_zero_point) + + if self.fake_quant_enabled[0] == 1: + if self.is_per_channel: + X = torch.fake_quantize_per_channel_affine( + X, self.scale, self.zero_point, self.ch_axis, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max) + else: + X = torch.fake_quantize_per_tensor_affine( + X, self.scale, self.zero_point, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max) + return X + + @torch.jit.export + def extra_repr(self): + return 'fake_quant_enabled={}, observer_enabled={}, ' \ + 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ' \ + 'ch_axis={}, scale={}, zero_point={}'.format( + self.fake_quant_enabled, self.observer_enabled, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max, self.dtype, + self.qscheme, self.ch_axis, self.scale, self.zero_point) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + # We cannot currently register scalar values as buffers, so need to + # manually specify serialization here. + super(FakeQuantize, self)._save_to_state_dict(destination, prefix, + keep_vars) + destination[prefix + 'scale'] = self.scale + destination[prefix + 'zero_point'] = self.zero_point + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + # Removing this function throws an error that the the size of the + # loaded tensor does not match the original size i.e., These buffers + # start out with numel 0 and become numel 1 once they have their + # first forward pass. + local_state = ['scale', 'zero_point'] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + # Custom handling to allow loading scale and zero_point + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. + if name == 'scale': + self.scale.resize_(val.shape) + else: + assert name == 'zero_point' + self.zero_point.resize_(val.shape) + # For torchscript module we need to update the attributes here + # since we do not call the `_load_from_state_dict` function + # defined module.py + if torch.jit.is_scripting(): + if name == 'scale': + self.scale.copy_(val) + else: + assert name == 'zero_point' + self.zero_point.copy_(val) + elif strict: + missing_keys.append(key) + super(FakeQuantize, + self)._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, + unexpected_keys, error_msgs) diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py new file mode 100644 index 000000000..10970a6a3 --- /dev/null +++ b/mmrazor/models/fake_quants/lsq.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parameter import Parameter + +from mmrazor.registry import MODELS +from ..utils import PerChannelLoadHook, _is_symmetric_quant, is_tracing_state +from .base import FakeQuantize + + +@MODELS.register_module() +class LearnableFakeQuantize(FakeQuantize): + r""" This is an extension of the FakeQuantize module in fake_quantize.py, + which supports more generalized lower-bit quantization and support learning + of the scale and zero point parameters through backpropagation. For + literature references, please see the class + `_LearnableFakeQuantizePerTensorOp`. In addition to the attributes in the + original FakeQuantize module, the `_LearnableFakeQuantize` module also + includes the following attributes to support quantization parameter + learning. + """ + + def __init__(self, + observer, + scale=1., + zero_point=0., + use_grad_scaling=True, + **observer_kwargs): + super(LearnableFakeQuantize, self).__init__(observer, + **observer_kwargs) + self.use_grad_scaling = use_grad_scaling + self.scale = Parameter(torch.tensor([scale])) + self.zero_point = Parameter(torch.tensor([zero_point])) + self.register_buffer('eps', + torch.tensor([torch.finfo(torch.float32).eps])) + # Check whether the module will load a state dict; + # Initialize the shape of per-channel 'scale' and + # 'zero-point' before copying values + self.load_state_dict_hook = PerChannelLoadHook(self) + + @torch.jit.export + def extra_repr(self): + return 'fake_quant_enabled={}, observer_enabled={}, ' \ + 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={},'\ + 'scale={}, zero_point={}'.format( + self.fake_quant_enabled, self.observer_enabled, + self.quant_min, self.quant_max, + self.dtype, self.qscheme, self.ch_axis, + self.scale if self.ch_axis == -1 else 'List[%s]' % str(self.scale.shape), # noqa: E501 + self.zero_point if self.ch_axis == -1 else 'List') + + def forward(self, X): + # Learnable fake quantize have to zero_point.float() + # to make it learnable. + if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = \ + self.activation_post_process.calculate_qparams() + _scale = _scale.to(self.scale.device) + _zero_point = _zero_point.to(self.zero_point.device) + + if self.ch_axis != -1: + self.scale.data = torch.ones_like(_scale) + self.zero_point.data = torch.zeros_like(_zero_point.float()) + + self.scale.data.copy_(_scale) + self.zero_point.data.copy_(_zero_point.float()) + else: + self.scale.data.abs_() + self.scale.data.clamp_(min=self.eps.item()) + + if self.fake_quant_enabled[0] == 1: + if _is_symmetric_quant(self.qscheme): + self.zero_point.data.zero_() + else: + self.zero_point.data.clamp_(self.quant_min, + self.quant_max).float() + + if self.is_per_channel: + if self.use_grad_scaling: + grad_factor = 1.0 / (X.numel() / X.shape[self.ch_axis] * + self.quant_max)**0.5 + else: + grad_factor = 1.0 + if is_tracing_state(): + X = FakeQuantizeLearnablePerchannelAffine.apply( + X, self.scale, self.zero_point, self.ch_axis, + self.quant_min, self.quant_max, grad_factor) + else: + X = _fake_quantize_learnable_per_channel_affine_training( + X, self.scale, self.zero_point, self.ch_axis, + self.quant_min, self.quant_max, grad_factor) + else: + if self.use_grad_scaling: + grad_factor = 1.0 / (X.numel() * self.quant_max)**0.5 + else: + grad_factor = 1.0 + X = torch._fake_quantize_learnable_per_tensor_affine( + X, self.scale, self.zero_point, self.quant_min, + self.quant_max, grad_factor) + return X + + +def _fake_quantize_learnable_per_channel_affine_training( + x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor): + zero_point = (zero_point.round() - zero_point).detach() + zero_point + new_shape = [1] * len(x.shape) + new_shape[ch_axis] = x.shape[ch_axis] + scale = grad_scale(scale, grad_factor).reshape(new_shape) + zero_point = grad_scale(zero_point, grad_factor).reshape(new_shape) + x = x / scale + zero_point + x = (x.round() - x).detach() + x + x = torch.clamp(x, quant_min, quant_max) + return (x - zero_point) * scale + + +def grad_scale(t, scale): + return (t - (t * scale)).detach() + (t * scale) + + +class FakeQuantizeLearnablePerchannelAffine(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, scale, zero_point, ch_axis, quant_min, quant_max, + grad_factor): + return _fake_quantize_learnable_per_channel_affine_training( + x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor) + + @staticmethod + def symbolic(g, x, scale, zero_point, ch_axis, quant_min, quant_max, + grad_factor): + return g.op( + '::FakeQuantizeLearnablePerchannelAffine', + x, + scale, + zero_point, + quant_min_i=quant_min, + quant_max_i=quant_max) diff --git a/mmrazor/models/fake_quants/qdrop.py b/mmrazor/models/fake_quants/qdrop.py new file mode 100644 index 000000000..e2e13bfc0 --- /dev/null +++ b/mmrazor/models/fake_quants/qdrop.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parameter import Parameter + +from mmrazor.registry import MODELS +from .base import FakeQuantize + + +@MODELS.register_module() +class QDropFakeQuantize(FakeQuantize): + + def __init__(self, observer, **observer_kwargs): + super().__init__(observer, **observer_kwargs) + self.scale = Parameter(torch.tensor([1.0], dtype=torch.float)) + self.prob = 1.0 + + def forward(self, X): + if self.observer_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = self.calculate_qparams() + _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to( + self.zero_point.device) + if self.scale.shape != _scale.shape: + self.scale.resize_(_scale.shape) + self.zero_point.resize_(_zero_point.shape) + self.scale.copy_(_scale) + self.zero_point.copy_(_zero_point) + + if self.fake_quant_enabled[0] == 1: + x_orig = X + if self.is_per_channel: + X = torch.fake_quantize_per_channel_affine( + X, self.scale, self.zero_point, self.ch_axis, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max) + else: + X = torch.fake_quantize_per_tensor_affine( + X, self.scale, self.zero_point, + self.activation_post_process.quant_min, + self.activation_post_process.quant_max) + if self.prob < 1.0: + x_prob = torch.where(torch.rand_like(X) < self.prob, X, x_orig) + return x_prob + return X diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index 65e2108fd..3509acd5c 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .ab_loss import ABLoss +from .adaround_loss import AdaRoundLoss from .at_loss import ATLoss from .crd_loss import CRDLoss from .cross_entropy_loss import CrossEntropyLoss diff --git a/mmrazor/models/losses/adaround_loss.py b/mmrazor/models/losses/adaround_loss.py new file mode 100644 index 000000000..76c97977d --- /dev/null +++ b/mmrazor/models/losses/adaround_loss.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmengine.logging import print_log + +from mmrazor.registry import MODELS + +_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear) + + +@MODELS.register_module() +class AdaRoundLoss(nn.Module): + r'''loss function to calculate mse reconstruction loss and relaxation loss + use some tempdecay to balance the two losses. + ''' + + def __init__(self, + weight: float = 1., + iters: int = 10000, + beta_range: tuple = (20, 2), + warm_up: float = 0.0, + p: float = 2.): + self.weight = weight + self.loss_start = iters * warm_up + self.p = p + + self.temp_decay = LinearTempDecay( + iters, + warm_up=warm_up, + start_beta=beta_range[0], + end_beta=beta_range[1]) + self.count = 0 + + def forward(self, subgraph, pred, tgt): + """Compute the total loss for adaptive rounding: rec_loss is the + quadratic output reconstruction loss, round_loss is a regularization + term to optimize the rounding policy. + + :param pred: output from quantized model + :param tgt: output from FP model + :return: total loss function + """ + + def lp_loss(pred, tgt, p=2.0): + """loss function measured in L_p Norm.""" + return (pred - tgt).abs().pow(p).sum(1).mean() + + self.count += 1 + rec_loss = lp_loss(pred, tgt, p=self.p) + + beta = self.temp_decay(self.count) + if self.count < self.loss_start: + round_loss = 0 + else: + round_loss = 0 + for layer in subgraph.modules(): + if isinstance(layer, _ADAROUND_SUPPORT_TYPE): + round_vals = layer.weight_fake_quant.rectified_sigmoid() + round_loss += self.weight * (1 - ( + (round_vals - .5).abs() * 2).pow(beta)).sum() + + total_loss = rec_loss + round_loss + if self.count % 500 == 0: + print_log('Total loss:\t{:.3f} (rec_loss:{:.3f}, ' + 'round_loss:{:.3f})\tbeta={:.2f}\tcount={}'.format( + float(total_loss), float(rec_loss), + float(round_loss), beta, self.count)) + return total_loss + + +class LinearTempDecay: + + def __init__(self, t_max=10000, warm_up=0.2, start_beta=20, end_beta=2): + self.t_max = t_max + self.start_decay = warm_up * t_max + self.start_beta = start_beta + self.end_beta = end_beta + + def __call__(self, t): + if t < self.start_decay: + return self.start_beta + elif t > self.t_max: + return self.end_beta + else: + rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) + return self.end_beta + (self.start_beta - self.end_beta) * \ + max(0.0, (1 - rel_t)) diff --git a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py index e730245d4..ea8681511 100644 --- a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py @@ -212,6 +212,11 @@ def alias(self) -> str: """str: alias of the unit""" return self.name + @property + def alias(self) -> str: + """str: alias of the unit""" + return self.name + def config_template(self, with_init_args=False, with_channels=False) -> Dict: diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py index 910992e1e..2d83d48f7 100644 --- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py @@ -12,10 +12,11 @@ from mmrazor.models.task_modules.tracer.channel_analyzer import ChannelAnalyzer from mmrazor.registry import MODELS, TASK_UTILS from ..base_mutator import BaseMutator +from ..group_mixin import GroupMixin @MODELS.register_module() -class ChannelMutator(BaseMutator, Generic[ChannelUnitType]): +class ChannelMutator(BaseMutator, Generic[ChannelUnitType], GroupMixin): """ChannelMutator manages the pruning structure of a model. Args: @@ -45,6 +46,10 @@ class ChannelMutator(BaseMutator, Generic[ChannelUnitType]): demo_input=(1, 3, 224, 224), tracer_type='BackwardTracer') + custom_groups (list[list[str]], optional): User-defined search groups. + All searchable modules that are not in ``custom_group`` will be + grouped separately. + init_cfg (dict, optional): initialization configuration dict for BaseModule. @@ -92,6 +97,10 @@ def __init__(self, self._parse_channel_unit_cfg( channel_unit_cfg) + if custom_groups is None: + custom_groups = [] + self._custom_groups = custom_groups + def prepare_from_supernet(self, supernet: Module) -> None: """Prepare from a model for pruning. @@ -229,10 +238,9 @@ def set_choices(self, choices: Dict[str, Any]) -> None: @property def current_choices(self) -> Dict: """Get current choices.""" - config = self.choice_template - for unit in self.mutable_units: - config[unit.name] = unit.current_choice - return config + current_choices = dict() + for group_id, modules in self.search_groups.items(): + current_choices[group_id] = modules[0].current_choice @property def choice_template(self) -> Dict: diff --git a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py index c3da419bf..b00e0ef22 100644 --- a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py @@ -29,7 +29,7 @@ def __init__(self, tracer_type='BackwardTracer'), init_cfg: Optional[Dict] = None) -> None: - super().__init__(channel_unit_cfg, parse_cfg, init_cfg) + super().__init__(channel_unit_cfg, parse_cfg, None, init_cfg) self.subnets = self._prepare_subnets(self.units_cfg) diff --git a/mmrazor/models/mutators/module_mutator/diff_module_mutator.py b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py new file mode 100644 index 000000000..1f639ed28 --- /dev/null +++ b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn + +from mmrazor.registry import MODELS +from ...mutables import DiffMutableModule +from .module_mutator import ModuleMutator + + +@MODELS.register_module() +class DiffModuleMutator(ModuleMutator): + """Differentiable mutable based mutator. + + `DiffModuleMutator` is responsible for mutating `DiffMutableModule`, + `DiffMutableOP`, `DiffChoiceRoute` and `GumbelChoiceRoute`. + The architecture parameters of the mutables are retained + in `DiffModuleMutator`. + + Args: + custom_group (list[list[str]], optional): User-defined search groups. + All searchable modules that are not in ``custom_group`` will be + grouped separately. + """ + + def __init__(self, + custom_groups: Optional[List[List[str]]] = None, + init_cfg: Optional[Dict] = None) -> None: + super().__init__(custom_groups=custom_groups, init_cfg=init_cfg) + + def build_arch_param(self, num_choices) -> nn.Parameter: + """Build learnable architecture parameters.""" + return nn.Parameter(torch.randn(num_choices) * 1e-3) + + def prepare_from_supernet(self, supernet: nn.Module) -> None: + """Inherit from ``BaseMutator``'s, generate `arch_params` in DARTS. + + Args: + supernet (:obj:`torch.nn.Module`): The architecture to be used + in your algorithm. + """ + + super().prepare_from_supernet(supernet) + self.arch_params = self.build_arch_params() + self.modify_supernet_forward(self.arch_params) + + def build_arch_params(self): + """This function will build many arch params, which are generally used + in differentiable search algorithms, such as Darts' series. Each + group_id corresponds to an arch param, so the Mutables with the same + group_id share the same arch param. + + Returns: + torch.nn.ParameterDict: the arch params are got by `search_groups`. + """ + + arch_params = nn.ParameterDict() + + for group_id, modules in self.search_groups.items(): + group_arch_param = self.build_arch_param(modules[0].num_choices) + arch_params[str(group_id)] = group_arch_param + + return arch_params + + def modify_supernet_forward(self, arch_params): + """Modify the DiffMutableModule's default arch_param in forward. + + In MMRazor, the `arch_param` is along to `DiffModuleMutator`, while the + `DiffMutableModule` needs `arch_param` in the forward. Here we use + partial function to assign the corresponding `arch_param` to each + `DiffMutableModule`. + """ + + for group_id, mutables in self.search_groups.items(): + for m in mutables: + m.set_forward_args(arch_param=arch_params[str(group_id)]) + + def sample_choices(self): + """Sampling by search groups. + + The sampling result of the first mutable of each group is the sampling + result of this group. + + Returns: + Dict[int, Any]: Random choices dict. + """ + + choices = dict() + for group_id, mutables in self.search_groups.items(): + arch_param = self.arch_params[str(group_id)] + choice = mutables[0].sample_choice(arch_param) + choices[group_id] = choice + return choices + + def set_choices(self, choices: Dict[int, Any]) -> None: + """Set mutables' current choice according to choices sample by + :func:`sample_choices`. + + Args: + choices (Dict[int, Any]): Choices dict. The key is group_id in + search groups, and the value is the sampling results + corresponding to this group. + """ + for group_id, mutables in self.search_groups.items(): + choice = choices[group_id] + for m in mutables: + m.current_choice = choice + + @property + def mutable_class_type(self): + """Differentiable mutable class type. + + Returns: + Type[DiffMutableModule]: Class type of differentiable mutable. + """ + return DiffMutableModule diff --git a/mmrazor/models/mutators/module_mutator/module_mutator.py b/mmrazor/models/mutators/module_mutator/module_mutator.py new file mode 100644 index 000000000..f30e933e0 --- /dev/null +++ b/mmrazor/models/mutators/module_mutator/module_mutator.py @@ -0,0 +1,94 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import abstractmethod +from typing import Dict, List, Optional, Type + +from torch.nn import Module + +from ..base_mutator import MUTABLE_TYPE, BaseMutator +from ..group_mixin import GroupMixin + + +class ModuleMutator(BaseMutator[MUTABLE_TYPE], GroupMixin): + """The base class for mutable based mutator. + + All subclass should implement the following APIS: + + - ``mutable_class_type`` + + Args: + custom_groups (list[list[str]], optional): User-defined search groups. + All searchable modules that are not in ``custom_group`` will be + grouped separately. + """ + + def __init__(self, + custom_groups: Optional[List[List[str]]] = None, + init_cfg: Optional[Dict] = None) -> None: + super().__init__(init_cfg) + + if custom_groups is None: + custom_groups = [] + self._custom_groups = custom_groups + self._search_groups: Optional[Dict[int, List[MUTABLE_TYPE]]] = None + + # TODO + # should be a class property + @property + @abstractmethod + def mutable_class_type(self) -> Type[MUTABLE_TYPE]: + """Corresponding mutable class type. + + Returns: + Type[MUTABLE_TYPE]: Mutable class type. + """ + + def prepare_from_supernet(self, supernet: Module) -> None: + """Do some necessary preparations with supernet. + + Note: + For mutable based mutator, we need to build search group first. + + Args: + supernet (:obj:`torch.nn.Module`): The supernet to be searched + in your algorithm. + """ + self._search_groups = self.build_search_groups(supernet, + self.mutable_class_type, + self._custom_groups) + + @property + def name2mutable(self) -> Dict[str, MUTABLE_TYPE]: + """Search space of supernet. + + Note: + To get the mapping: module name to mutable. + + Raises: + RuntimeError: Called before search space has been parsed. + + Returns: + Dict[str, MUTABLE_TYPE]: The name2mutable dict. + """ + if self._name2mutable is None: + raise RuntimeError( + 'Call `prepare_from_supernet` before access name2mutable!') + return self._name2mutable + + @property + def search_groups(self) -> Dict[int, List[MUTABLE_TYPE]]: + """Search group of supernet. + + Note: + For mutable based mutator, the search group is composed of + corresponding mutables. + + Raises: + RuntimeError: Called before search group has been built. + + Returns: + Dict[int, List[MUTABLE_TYPE]]: Search group. + """ + if self._search_groups is None: + raise RuntimeError( + 'Call `prepare_from_supernet` before access search group!') + return self._search_groups diff --git a/mmrazor/models/observers/__init__.py b/mmrazor/models/observers/__init__.py new file mode 100644 index 000000000..22af9bae9 --- /dev/null +++ b/mmrazor/models/observers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .minmax import EMAMinMaxObserver, MinMaxObserver +from .mse import MSEObserver + +__all__ = ['MinMaxObserver', 'MSEObserver', 'EMAMinMaxObserver'] diff --git a/mmrazor/models/observers/base.py b/mmrazor/models/observers/base.py new file mode 100644 index 000000000..e10738664 --- /dev/null +++ b/mmrazor/models/observers/base.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from torch.ao.quantization.observer import UniformQuantizationObserverBase + +from mmrazor.models.utils import pot_quantization, sync_tensor + +# from mmengine.model import BaseModule + + +class BaseObserver(UniformQuantizationObserverBase): + """Modified torch quantization observer. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used. + reduce_range: Reduces the range of the quantized data type by 1 bit. + This is sometimes required to avoid instruction overflow. + quant_min: Minimum quantization value. If unspecified, it will follow + the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow + the 8-bit setup. + ch_axis (int, optional): Channel axis index. Defaults to -1. + is_pot_scale (bool, optional): Indicate whether scale is power of two. + Defaults to False. + eps: Epsilon value for float32. + Defaults to `torch.finfo(torch.float32).eps`. + """ + + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + ch_axis=-1, + is_pot_scale=False, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps) -> None: + super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, + factory_kwargs, eps) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer('min_val', + torch.tensor(float('inf'), **factory_kwargs)) + self.register_buffer('max_val', + torch.tensor(float('-inf'), **factory_kwargs)) + self.ch_axis = ch_axis + self.is_pot_scale = is_pot_scale + + @torch.jit.export + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Calculates the quantization parameters.""" + scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) + scale.data = sync_tensor(scale).data + zero_point.data = sync_tensor(zero_point).data + if self.is_pot_scale: + scale = pot_quantization(scale) + return scale, zero_point + + @torch.jit.export + def extra_repr(self): + return 'min_val={}, max_val={} ch_axis={} is_pot_scale={}'.format( + self.min_val, self.max_val, self.ch_axis, self.is_pot_scale) + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + self.min_val.copy_(torch.tensor(float('inf'))) + self.max_val.copy_(torch.tensor(float('-inf'))) diff --git a/mmrazor/models/observers/lsq_observer.py b/mmrazor/models/observers/lsq_observer.py new file mode 100644 index 000000000..d9b96d7a8 --- /dev/null +++ b/mmrazor/models/observers/lsq_observer.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch + +from mmrazor.registry import MODELS +from ..utils import _is_symmetric_quant, pot_quantization, sync_tensor +from .base import BaseObserver + + +@MODELS.register_module() +class LSQObserver(BaseObserver): + """Observer for `LEARNED STEP SIZE QUANTIZATION`""" + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + ch_axis=-1, + is_pot_scale=False, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps) -> None: + super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, + ch_axis, is_pot_scale, factory_kwargs, eps) + + self.tensor_norm = None + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.to(self.min_val.dtype) + if self.ch_axis == -1: + self.tensor_norm = x.abs().mean() + self.min_val, self.max_val = torch._aminmax(x) + else: + # compute channel-wise mean + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + self.tensor_norm = y.abs().mean(1) + self.min_val, self.max_val = torch._aminmax(y, 1) + + return x + + def calculate_qparams(self): + scale = 2 * self.tensor_norm / math.sqrt(self.quant_max) + zero_point = torch.zeros_like(self.tensor_norm) + sync_tensor(scale) + sync_tensor(zero_point) + if self.is_pot_scale: + scale = pot_quantization(scale) + if not _is_symmetric_quant(self.qscheme): + zero_point = self.quant_min - torch.round(self.min_val / scale) + return scale, zero_point diff --git a/mmrazor/models/observers/minmax.py b/mmrazor/models/observers/minmax.py new file mode 100644 index 000000000..099296536 --- /dev/null +++ b/mmrazor/models/observers/minmax.py @@ -0,0 +1,97 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmrazor.registry import MODELS +from .base import BaseObserver + + +@MODELS.register_module() +class MinMaxObserver(BaseObserver): + """Min max observer.""" + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + ch_axis=-1, + is_pot_scale=False, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps) -> None: + super(MinMaxObserver, self).__init__(dtype, qscheme, reduce_range, + quant_min, quant_max, ch_axis, + is_pot_scale, factory_kwargs, eps) + if (self.qscheme == torch.per_tensor_symmetric and self.reduce_range + and self.dtype == torch.quint8): + raise NotImplementedError('Cannot reduce range for symmetric \ + quantization for quint8') + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + if self.ch_axis == -1: + min_val_cur, max_val_cur = torch._aminmax(x) + else: + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + min_val_cur, max_val_cur = torch._aminmax(y, 1) + min_val = torch.min(self.min_val, min_val_cur) + max_val = torch.max(self.max_val, max_val_cur) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + return x + + +@MODELS.register_module() +class EMAMinMaxObserver(BaseObserver): + """Moving average min/max among batches.""" + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + ch_axis=-1, + is_pot_scale=False, + ema_ratio=0.9, + factory_kwargs=None): + super(EMAMinMaxObserver, + self).__init__(dtype, qscheme, reduce_range, quant_min, + quant_max, ch_axis, is_pot_scale, factory_kwargs) + self.ema_ratio = ema_ratio + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.to(self.min_val.dtype) + if self.ch_axis == -1: + min_val_cur, max_val_cur = torch._aminmax(x) + else: + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + min_val_cur, max_val_cur = torch._aminmax(y, 1) + + if self.max_val.numel() <= 1 and self.max_val.isinf(): + self.min_val = min_val_cur + self.max_val = max_val_cur + else: + self.min_val = self.min_val * self.ema_ratio + min_val_cur * ( + 1.0 - self.ema_ratio) + self.max_val = self.max_val * self.ema_ratio + max_val_cur * ( + 1.0 - self.ema_ratio) + return x diff --git a/mmrazor/models/observers/mse.py b/mmrazor/models/observers/mse.py new file mode 100644 index 000000000..f85abd902 --- /dev/null +++ b/mmrazor/models/observers/mse.py @@ -0,0 +1,156 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmrazor.registry import MODELS +from .base import BaseObserver + +_version_under_1100 = int(torch.__version__.split('.')[1]) < 10 + + +@MODELS.register_module() +class MSEObserver(BaseObserver): + """MSE observer.""" + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + ch_axis=-1, + is_pot_scale=False, + p=2.0, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps) -> None: + super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, + ch_axis, is_pot_scale, factory_kwargs, eps) + self.p = p + + def lp_loss(self, pred, tgt, dim=None): + """loss function measured in L_p Norm.""" + return (pred - tgt).abs().pow( + self.p).mean(dim) if dim else (pred - + tgt).abs().pow(self.p).mean() + + def mse(self, + x: torch.Tensor, + x_min: torch.Tensor, + x_max: torch.Tensor, + iter=80): + best_score = 1e+10 + best_min, best_max = torch.tensor( + [1.0], dtype=torch.float), torch.tensor([1.0], dtype=torch.float) + best_min.copy_(x_min) + best_max.copy_(x_max) + for i in range(iter): + new_min = x_min * (1.0 - (i * 0.01)) + new_max = x_max * (1.0 - (i * 0.01)) + scale, zero_point = self._calculate_qparams(new_min, new_max) + x_q = torch.fake_quantize_per_tensor_affine( + x, scale.item(), int(zero_point.item()), self.quant_min, + self.quant_max) + score = self.lp_loss(x_q, x) + if score < best_score: + best_score = score + best_min, best_max = new_min, new_max + return best_min, best_max + + def mse_perchannel(self, + x: torch.Tensor, + x_min: torch.Tensor, + x_max: torch.Tensor, + iter=80, + ch_axis=0): + assert x_min.shape == x_max.shape + assert ch_axis >= 0, f'{ch_axis}' + best_score = 1e+10 * torch.ones_like(x_min) + best_min, best_max = x_min.clone(), x_max.clone() + reduce_dim = tuple([i for i in range(len(x.shape)) if i != ch_axis]) + for i in range(iter): + new_min = x_min * (1.0 - (i * 0.01)) + new_max = x_max * (1.0 - (i * 0.01)) + scale, zero_point = self._calculate_qparams(new_min, new_max) + x_q = torch.fake_quantize_per_channel_affine( + x, scale, + zero_point.long() if _version_under_1100 else zero_point, + ch_axis, self.quant_min, self.quant_max) + score = self.lp_loss(x_q, x, reduce_dim) + update_idx = (score < best_score) + best_score[update_idx] = score[update_idx] + best_min[update_idx] = new_min[update_idx] + best_max[update_idx] = new_max[update_idx] + return best_min, best_max + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.clone().detach().to(self.min_val.dtype) + if self.ch_axis == -1: + min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = self.mse( + x, min_val_cur, max_val_cur, iter=95) + else: + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + x_channel = x.permute(new_axis_list) + y = torch.flatten(x_channel, start_dim=1) + min_val_cur, max_val_cur = torch._aminmax(y, 1) + min_val_cur, max_val_cur = self.mse_perchannel( + x, min_val_cur, max_val_cur, iter=80, ch_axis=self.ch_axis) + + self.min_val = torch.min(self.min_val, min_val_cur) + self.max_val = torch.max(self.max_val, max_val_cur) + return x + + +@MODELS.register_module() +class EMAMSEObserver(MSEObserver): + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + ch_axis=-1, + is_pot_scale=False, + p=2.0, + ema_ratio=0.9, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps) -> None: + super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, + ch_axis, is_pot_scale, p, factory_kwargs, eps) + self.ema_ratio = ema_ratio + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.clone().detach().to(self.min_val.dtype) + if self.ch_axis == -1: + min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = self.mse( + x, min_val_cur, max_val_cur, iter=95) + else: + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + x_channel = x.permute(new_axis_list) + y = torch.flatten(x_channel, start_dim=1) + min_val_cur, max_val_cur = torch._aminmax(y, 1) + min_val_cur, max_val_cur = self.mse_perchannel( + x, min_val_cur, max_val_cur, iter=80, ch_axis=self.ch_axis) + + if self.max_val.numel() <= 1 and self.max_val.isinf(): + self.min_val = min_val_cur + self.max_val = max_val_cur + else: + self.min_val = self.min_val * self.ema_ratio + min_val_cur * ( + 1.0 - self.ema_ratio) + self.max_val = self.max_val * self.ema_ratio + max_val_cur * ( + 1.0 - self.ema_ratio) + return x diff --git a/mmrazor/models/quantizers/__init__.py b/mmrazor/models/quantizers/__init__.py new file mode 100644 index 000000000..e56902eba --- /dev/null +++ b/mmrazor/models/quantizers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import CustomQuantizer +from .trt_quantizer import TensorRTQuantizer + +__all__ = ['CustomQuantizer', 'TensorRTQuantizer'] diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py new file mode 100644 index 000000000..ab4cf190a --- /dev/null +++ b/mmrazor/models/quantizers/base.py @@ -0,0 +1,194 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch +from mmengine.model import BaseModule +from torch.ao.quantization import QConfig +from torch.ao.quantization.fx import prepare +from torch.ao.quantization.quantize_fx import _convert_fx, _fuse_fx + +from mmrazor.models.task_modules.tracer import CustomTracer +from mmrazor.models.utils import (check_is_valid_convert_custom_config_dict, + check_is_valid_prepare_custom_config_dict, + check_is_valid_qconfig_dict, + get_custom_module_class_keys) +from mmrazor.registry import MODELS +from mmrazor.structures.quantization import (CheckArgs, DefalutQconfigs, + QuantizeScheme, SupportQtypes) + + +@MODELS.register_module() +class CustomQuantizer(BaseModule): + """Configurable quantizer, base class of quantizers. + + Args: + qconfig (Dict, optional): QConfig. Defaults to DefalutQconfigs['default']. # noqa: E501 + is_qat (bool, optional): Is QAT ro not. Defaults to True. + skipped_methods (List, optional): Skipped methods list for tracer. + Defaults to None. + prepare_custom_config_dict (Dict, optional): `PrepareCustomConfig` + from `torch.quantization.fx`. Defaults to None. + convert_custom_config_dict (Dict, optional): `ConvertCustomConfig` + from `torch.quantization.fx`. Defaults to None. + equalization_qconfig_dict (Dict, optional): Custom `QConfig` effects + on all modules. Defaults to None. + _remove_qconfig (Dict, optional): Remove qconfig at the end of + `_convert_fx`. Defaults to True. + init_cfg (dict, optional): Initialization config dict. + """ + + def __init__(self, + qconfig: Dict = DefalutQconfigs['default'], + is_qat: bool = True, + skipped_methods: List = None, + prepare_custom_config_dict: Dict = None, + convert_custom_config_dict: Dict = None, + equalization_qconfig_dict: Dict = None, + _remove_qconfig: bool = True, + init_cfg: Dict = None): + super().__init__(init_cfg) + if self.check_qconfig(qconfig): + qconfig = self.qconfig_convert(qconfig) + self.qconfig_dict = {'': qconfig} + else: + raise ValueError('qconfig is incorrect!') + + if prepare_custom_config_dict is None: + self.prepare_custom_config_dict = {} + else: + self.prepare_custom_config_dict = prepare_custom_config_dict + if convert_custom_config_dict is None: + self.convert_custom_config_dict = {} + else: + self.convert_custom_config_dict = convert_custom_config_dict + if equalization_qconfig_dict is None: + self.equalization_qconfig_dict = {} + else: + self.equalization_qconfig_dict = equalization_qconfig_dict + + check_is_valid_qconfig_dict(self.qconfig_dict) + check_is_valid_prepare_custom_config_dict( + self.prepare_custom_config_dict) + check_is_valid_convert_custom_config_dict( + self.convert_custom_config_dict) + check_is_valid_qconfig_dict(self.equalization_qconfig_dict) + + self.is_qat = is_qat + self.skipped_methods = skipped_methods + self._remove_qconfig = _remove_qconfig + self.tracer = self.build_tracer() + + def prepare(self, model, graph_module): + + preserved_attributes = self.prepare_custom_config_dict.get( + 'preserved_attributes', []) + for attr_name in preserved_attributes: + setattr(graph_module, attr_name, getattr(model, attr_name)) + + graph_module = self.fuse_model(graph_module) + + prepared = prepare( + graph_module, + self.qconfig_dict, + self.is_qat, + self.tracer.node_name_to_scope, + prepare_custom_config_dict=self.prepare_custom_config_dict, + equalization_qconfig_dict=self.equalization_qconfig_dict + ) # type: ignore[operator] + + for attr_name in preserved_attributes: + setattr(prepared, attr_name, getattr(model, attr_name)) + return prepared + + def convert(self, graph_module): + quantized = _convert_fx( + graph_module, + is_reference=False, + convert_custom_config_dict=self.convert_custom_config_dict, + _remove_qconfig=self._remove_qconfig, + qconfig_dict=self.qconfig_dict) + return quantized + + def check_qconfig(self, qconfig): + is_pass = True + for arg in CheckArgs: + if arg == 'qtype': + if qconfig[arg] in SupportQtypes and arg in qconfig.keys(): + continue + else: + is_pass = False + break + else: + if isinstance(qconfig[arg], dict) and arg in qconfig.keys(): + continue + else: + is_pass = False + break + return is_pass + + def qconfig_convert(self, qconfig): + self.w_qscheme = QuantizeScheme(**qconfig['w_qscheme']) + self.a_qscheme = QuantizeScheme(**qconfig['a_qscheme']) + w_observer = MODELS.get(qconfig['w_observer']['type']) + w_observer_kwargs = self.w_qscheme.to_observer_params() + a_observer = MODELS.get(qconfig['a_observer']['type']) + a_observer_kwargs = self.a_qscheme.to_observer_params() + self.w_observer = MODELS.get(qconfig['w_observer']['type']).with_args( + **self.w_qscheme.to_observer_params()) + self.a_observer = MODELS.get(qconfig['a_observer']['type']).with_args( + **self.a_qscheme.to_observer_params()) + self.w_fake_quant = MODELS.get( + qconfig['w_fake_quant']['type']).with_args( + observer=w_observer, **w_observer_kwargs) + self.a_fake_quant = MODELS.get( + qconfig['a_fake_quant']['type']).with_args( + observer=a_observer, **a_observer_kwargs) + + torch_qconfig = QConfig( + weight=self.w_fake_quant, activation=self.a_fake_quant) + return torch_qconfig + + def _swap_ff_with_fxff(self, model: torch.nn.Module) -> None: + r""" Swap FloatFunctional with FXFloatFunctional + """ + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + self._swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.nn.quantized.FXFloatFunctional() + + def build_tracer(self): + skipped_module_names = self.prepare_custom_config_dict.get( + 'non_traceable_module_name', []) + skipped_module_classes = self.prepare_custom_config_dict.get( + 'non_traceable_module_class', []) + standalone_module_name_configs = self.prepare_custom_config_dict.get( + 'standalone_module_name', []) + skipped_module_names += [ + config[0] for config in standalone_module_name_configs + ] + + standalone_module_class_configs = self.prepare_custom_config_dict.get( + 'standalone_module_class', []) + skipped_module_classes += [ + config[0] for config in standalone_module_class_configs + ] + float_custom_module_classes = get_custom_module_class_keys( + self.prepare_custom_config_dict, + 'float_to_observed_custom_module_class') + skipped_module_classes += float_custom_module_classes + tracer = CustomTracer(self.skipped_methods, skipped_module_names, + skipped_module_classes) + # tracer = QuantizationTracer(skipped_module_names, + # skipped_module_classes) + return tracer + + def fuse_model(self, graph_module): + graph_module = _fuse_fx(graph_module, self.is_qat, + self.prepare_custom_config_dict) + return graph_module diff --git a/mmrazor/models/quantizers/trt_quantizer.py b/mmrazor/models/quantizers/trt_quantizer.py new file mode 100644 index 000000000..cc8532a53 --- /dev/null +++ b/mmrazor/models/quantizers/trt_quantizer.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmrazor.registry import MODELS +from mmrazor.structures.quantization import DefalutQconfigs +from .base import CustomQuantizer + + +@MODELS.register_module() +class TensorRTQuantizer(CustomQuantizer): + """Quantizer for TensorRT backend.""" + + def __init__(self, + qconfig=DefalutQconfigs['tensorrt'], + is_qat=True, + skipped_methods=None, + prepare_custom_config_dict=None, + convert_custom_config_dict=None, + equalization_qconfig_dict=None, + _remove_qconfig=True, + init_cfg=None): + super().__init__(qconfig, is_qat, skipped_methods, + prepare_custom_config_dict, + convert_custom_config_dict, equalization_qconfig_dict, + _remove_qconfig, init_cfg) diff --git a/mmrazor/models/task_modules/tracer/__init__.py b/mmrazor/models/task_modules/tracer/__init__.py index c3ff8dd66..aab26fa40 100644 --- a/mmrazor/models/task_modules/tracer/__init__.py +++ b/mmrazor/models/task_modules/tracer/__init__.py @@ -2,6 +2,7 @@ from .backward_tracer import BackwardTracer from .channel_analyzer import ChannelAnalyzer # from .razor_tracer import RazorFxTracer +from .fx import CustomTracer, UntracedMethodRegistry, custom_symbolic_trace from .loss_calculator import * # noqa: F401,F403 from .parsers import * # noqa: F401,F403 from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode, @@ -10,5 +11,6 @@ __all__ = [ 'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode', 'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode', - 'ChannelAnalyzer' + 'ChannelAnalyzer', 'CustomTracer', 'UntracedMethodRegistry', + 'custom_symbolic_trace' ] diff --git a/mmrazor/models/task_modules/tracer/fx/__init__.py b/mmrazor/models/task_modules/tracer/fx/__init__.py new file mode 100644 index 000000000..29c93f83a --- /dev/null +++ b/mmrazor/models/task_modules/tracer/fx/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .custom_tracer import (CustomTracer, UntracedMethodRegistry, + custom_symbolic_trace) + +__all__ = ['CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace'] diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py new file mode 100644 index 000000000..f69ec2269 --- /dev/null +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -0,0 +1,281 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +from types import FunctionType, MethodType +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import torch +from mmengine.utils import import_modules_from_strings +from torch._C import ScriptObject # type: ignore[attr-defined] +from torch.ao.quantization.quantize_fx import QuantizationTracer +from torch.fx import GraphModule, Tracer +from torch.fx._symbolic_trace import (Graph, _autowrap_check, + _patch_wrapped_functions, _Patcher) +from torch.fx.proxy import Proxy + +_orig_module_call: Callable = torch.nn.Module.__call__ +_orig_module_getattr: Callable = torch.nn.Module.__getattr__ +# _orig_module_forward_train: Callable = models.BaseDenseHead.forward_train + + +class UntracedMethodRegistry: + """A `Descriptor` class which records untraced methods.""" + method_dict: Dict = dict() + tracer = None + + def __init__(self, method): + """_summary_ + + Args: + method (FunctionType): Function to be registered. + """ + self.method = method + self.instances: Dict = dict() + self.owner = None + + def __set_name__(self, owner, name): + self.owner = owner + self.name = name + wrapped = self.method_wrapper() + self.method_dict[name] = dict(mod=self.owner, wrapped=wrapped) + + def __get__(self, instance, owner): + if instance is None: + return self.method + return MethodType(self.method, instance) + + def method_wrapper(self): + + @functools.wraps(self.method) + def wrapped_method(mod, *args, **kwargs): + + def method(*args, **kwargs): + return self.method(mod, *args, **kwargs) + + return self.tracer.call_method(mod, self.name, method, args, + kwargs) + + return wrapped_method + + +def custom_symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule: + """Modified `symbolic_trace` function. + + Args: + root (Union[torch.nn.Module, Callable]): Module or function to be + traced and converted into a Graph representation. + concrete_args (Optional[Dict[str, any]]): Inputs to be partially + specialized. + + Returns: + _type_: _description_ + """ + tracer = CustomTracer() + graph = tracer.trace(root, concrete_args) + name = root.__class__.__name__ if isinstance( + root, torch.nn.Module) else root.__name__ + return GraphModule(tracer.root, graph, name) + + +class CustomTracer(QuantizationTracer): + + def __init__(self, + skipped_methods: List[str] = [], + skipped_module_names: List[str] = [], + skipped_module_classes: List[Callable] = [], + *args, + **kwargs): + """_summary_ + + Args: + skipped_methods (List[str], optional): Methods to be skipped while + tracing. Defaults to None. + skipped_module_names (List[str], optional): Modules to be skipped + while tracing. Defaults to None. + skipped_module_classes (List[str], optional): Class to be skipped + while tracing. Defaults to None. + """ + super(CustomTracer, self).__init__(skipped_module_names, + skipped_module_classes) + UntracedMethodRegistry.tracer = self # type: ignore + self.skipped_methods = skipped_methods + if self.skipped_methods: + self.register_skipped_methods() + + @staticmethod + def _check_valid_source(source): + """Check if the source's format is valid.""" + if not isinstance(source, str): + raise TypeError(f'source should be a str ' + f'instance, but got {type(source)}') + + assert len(source.split('.')) > 1, \ + 'source must have at least one `.`' + + def register_skipped_methods(self): + if not isinstance(self.skipped_methods, list): + self.skipped_methods = [self.skipped_methods] + for s_method in self.skipped_methods: + self._check_valid_source(s_method) + mod_str = '.'.join(s_method.split('.')[:-2]) + cls_str = s_method.split('.')[-2] + method_str = s_method.split('.')[-1] + + try: + mod = import_modules_from_strings(mod_str) + except ImportError: + raise ImportError(f'{mod_str} is not imported correctly.') + + imported_cls: type = getattr(mod, cls_str) + if not isinstance(imported_cls, type): + raise TypeError(f'{cls_str} should be a type ' + f'instance, but got {type(imported_cls)}') + assert hasattr(imported_cls, method_str), \ + f'{method_str} is not in {mod_str}.' + + method = getattr(imported_cls, method_str) + + method_registry = UntracedMethodRegistry(method) + method_registry.__set_name__(imported_cls, method_str) + + def call_method(self, m: torch.nn.Module, name, method, args, kwargs): + """Method that specifies the behavior of this ``Tracer`` when it + encounters a call to an ``nn.Module`` instance. + + By default, the behavior is to check if the called module is a leaf + module via ``is_leaf_module``. If it is, emit a ``call_module`` + node referring to ``m`` in the ``Graph``. Otherwise, call the + ``Module`` normally, tracing through the operations in its ``forward`` + function. + + This method can be overridden to--for example--create nested traced + GraphModules, or any other behavior you would want while tracing across + ``Module`` boundaries. + + Args: + + m (Module): The module for which a call is being emitted + forward (Callable): The forward() method of the ``Module`` to be + invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a + ``call_module`` node was emitted, this is a ``Proxy`` value. + Otherwise, it is whatever value was returned from the ``Module`` + invocation. + """ + # module_qualified_name = self.path_of_module(m) + if not self.is_skipped_method(m): + return method(*args, **kwargs) + args = list(args) + args.insert(0, m) + args = tuple(args) + return self.create_proxy('call_method', name, args, kwargs) + + def trace(self, root, concrete_args=None): + if isinstance(root, torch.nn.Module): + self.root = root + fn = type(root).forward + self.submodule_paths = { + mod: name + for name, mod in root.named_modules() + } + else: + self.root = torch.nn.Module() + fn = root + + tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None) + self.graph = Graph(tracer_cls=tracer_cls) + + # When we encounter a Tensor value that's not a parameter, we look if + # it is some other attribute on the model. Construct a dict mapping + # Tensor values to the qualified name here for efficiency. This is + # used downstream in create_arg + self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} + + def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + for k, v in m.__dict__.items(): + if isinstance(v, (torch.Tensor, ScriptObject)): + self.tensor_attrs[v] = '.'.join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + assert isinstance(fn, FunctionType) + + fn_globals = fn.__globals__ # run before it gets patched + fn, args = self.create_args_for_root(fn, + isinstance(root, torch.nn.Module), + concrete_args) + + # Reduce number of get_attr calls + parameter_proxy_cache: Dict[str, Proxy] = {} + + # Method dispatch on parameters is not recorded unless it's directly + # used. Thus, we need to insert a proxy when __getattr__ requests a + # parameter. + @functools.wraps(_orig_module_getattr) + def module_getattr_wrapper(mod, attr): + attr_val = _orig_module_getattr(mod, attr) + return self._module_getattr(attr, attr_val, parameter_proxy_cache) + + @functools.wraps(_orig_module_call) + def module_call_wrapper(mod, *args, **kwargs): + + def forward(*args, **kwargs): + return _orig_module_call(mod, *args, **kwargs) + + _autowrap_check( + patcher, + getattr(getattr(mod, 'forward', mod), '__globals__', {}), + self._autowrap_function_ids) + return self.call_module(mod, forward, args, kwargs) + + with _Patcher() as patcher: + # allow duplicate patches to support the case of nested calls + patcher.patch_method( + torch.nn.Module, + '__getattr__', + module_getattr_wrapper, + deduplicate=False) + patcher.patch_method( + torch.nn.Module, + '__call__', + module_call_wrapper, + deduplicate=False) + + for name, value in UntracedMethodRegistry.method_dict.items(): + wrapped = value['wrapped'] + patcher.patch_method( + value['mod'], name, wrapped, deduplicate=False) + + _patch_wrapped_functions(patcher) + _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) + for module in self._autowrap_search: + _autowrap_check(patcher, module.__dict__, + self._autowrap_function_ids) + self.create_node( + 'output', + 'output', (self.create_arg(fn(*args)), ), {}, + type_expr=fn.__annotations__.get('return', None)) + + self.submodule_paths = None + + return self.graph + + def is_skipped_method(self, m): + mods = tuple(value['mod'] + for value in UntracedMethodRegistry.method_dict.values()) + custom = isinstance(m, mods) + return custom + + def is_leaf_module(self, m: torch.nn.Module, + module_qualified_name: str) -> bool: + # return super().is_leaf_module(m, module_qualified_name) + leaf = super().is_leaf_module(m, module_qualified_name) + return leaf diff --git a/mmrazor/models/utils/quantization_util.py b/mmrazor/models/utils/quantization_util.py new file mode 100644 index 000000000..376096b67 --- /dev/null +++ b/mmrazor/models/utils/quantization_util.py @@ -0,0 +1,217 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from functools import partial +from typing import Any, Dict, List, Optional, Set + +import torch + + +class PerChannelLoadHook: + + def __init__(self, module, hook_param=['scale', 'zero_point']): + self.hook = module._register_load_state_dict_pre_hook( + partial(self.hook_fn, module=module)) + self.hook_param = hook_param + + def hook_fn(self, state_dict, prefix, local_metadata, strict, missing_keys, + unexpected_keys, error_msgs, module): + if module.ch_axis == -1: + # no per-channel parameters + return + for module_key, param in module._parameters.items(): + if module_key not in self.hook_param: + continue + candidate = prefix + module_key + if candidate in state_dict: + input_param = state_dict[candidate] + if param.shape != input_param.shape: + param.data = torch.ones_like( + input_param, dtype=param.dtype, device=param.device) + for module_key, param in module._buffers.items(): + if module_key not in self.hook_param: + continue + candidate = prefix + module_key + if candidate in state_dict: + input_param = state_dict[candidate] + if param.shape != input_param.shape: + param.data = torch.ones_like( + input_param, dtype=param.dtype, device=param.device) + + def close(self): + self.hook.remove() + + +USE_LINK = False +USE_DDP = False + +try: + import spring.linklink as link + assert link.is_initialized() + USE_LINK = True +except (ModuleNotFoundError, AssertionError): + import torch.distributed as dist + if torch.distributed.is_initialized(): + USE_DDP = True + + +def sync_tensor(tensor): + if USE_LINK: + if tensor.is_cuda is True: + tensor.data = tensor.data / link.get_world_size() + link.allreduce(tensor.data) + elif USE_DDP: + tensor.data = tensor.data / dist.get_world_size() + dist.all_reduce(tensor.data) + return tensor + + +def pot_quantization(tensor: torch.Tensor, mode='round'): + log2t = torch.log2(tensor) + if mode == 'round': + log2t = (torch.round(log2t) - log2t).detach() + log2t + else: + assert mode == 'floor' + log2t = (torch.floor(log2t) - log2t).detach() + log2t + return 2**log2t + + +def _is_per_channel(qscheme: 'torch.qscheme') -> bool: + return qscheme in [ + torch.per_channel_symmetric, torch.per_channel_affine, + torch.per_channel_affine_float_qparams + ] + + +def _is_per_tensor(qscheme: 'torch.qscheme') -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] + + +def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric] + + +def is_tracing_state(): + return torch._C._get_tracing_state() + + +def _is_float_qparams(qscheme: 'torch.qscheme') -> bool: + return qscheme in [ + torch.per_channel_affine_float_qparams, + ] + + +def check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], + dict_name: str) -> None: + r""" Checks if the given config_dict has the correct keys + Args: + `config_dict`: dictionary whose keys we want to check + """ + + for k in config_dict.keys(): + if k not in allowed_keys: + raise ValueError('Expected ' + dict_name + + ' to have the following keys: ' + + str(allowed_keys) + '. But found \'' + k + + '\' instead.') + + +def check_is_valid_qconfig_dict(qconfig_dict: Any) -> None: + r""" Checks if the given qconfig_dict has the correct keys + Args: + `qconfig_dict`: dictionary whose keys we want to check + """ + + qconfig_dict_allowed_keys = { + '', 'object_type', 'module_name_regex', 'module_name', + 'module_name_object_type_order' + } + check_is_valid_config_dict(qconfig_dict, qconfig_dict_allowed_keys, + 'qconfig_dict') + + +def check_is_valid_prepare_custom_config_dict( + prepare_custom_config_dict: Optional[Dict[str, Any]] = None) -> None: + r""" Checks if the given prepare_custom_config_dict has the correct keys + Args: + `prepare_custom_config_dict`: customization configuration dictionary for + quantization tool + """ + if not prepare_custom_config_dict: + return + + prepare_custom_config_dict_allowed_keys = { + 'standalone_module_name', 'standalone_module_class', + 'float_to_observed_custom_module_class', 'non_traceable_module_name', + 'non_traceable_module_class', 'input_quantized_idxs', + 'output_quantized_idxs', 'preserved_attributes' + } + check_is_valid_config_dict(prepare_custom_config_dict, + prepare_custom_config_dict_allowed_keys, + 'prepare_custom_config_dict') + + +def check_is_valid_convert_custom_config_dict( + convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> None: + r""" Checks if the given convert_custom_config_dict has the correct keys + Args: + `convert_custom_config_dict`: dictionary for custom configurations for + convert function + """ + if not convert_custom_config_dict: + return + + convert_custom_config_dict_allowed_keys = { + 'observed_to_quantized_custom_module_class', 'preserved_attributes' + } + check_is_valid_config_dict(convert_custom_config_dict, + convert_custom_config_dict_allowed_keys, + 'convert_custom_config_dict') + + +def check_is_valid_fuse_custom_config_dict( + fuse_custom_config_dict: Optional[Dict[str, Any]] = None) -> None: + r""" Checks if the given fuse_custom_config_dict has the correct keys + Args: + `fuse_custom_config_dict`: dictionary for custom configurations for + fuse_fx + """ + if not fuse_custom_config_dict: + return + + fuse_custom_config_dict_allowed_keys = {'preserved_attributes'} + check_is_valid_config_dict(fuse_custom_config_dict, + fuse_custom_config_dict_allowed_keys, + 'fuse_custom_config_dict') + + +def get_custom_module_class_keys(custom_config_dict, + custom_config_dict_key) -> List[Any]: + r""" Get all the unique custom module keys in the custom config dict + e.g. + Input: + custom_config_dict = { + "float_to_observed_custom_module_class": { + "static": { + CustomModule1: ObservedCustomModule + }, + "dynamic": { + CustomModule2: DynamicObservedCustomModule + }, + "weight_only": { + CustomModule3: WeightOnlyObservedCustomModule + }, + }, + } + Output: + # extract all the keys in "static", "dynamic" and "weight_only" dict + [CustomModule1, CustomModule2, CustomModule3] + """ + # using set to dedup + float_custom_module_classes: Set[Any] = set() + custom_module_mapping = custom_config_dict.get(custom_config_dict_key, {}) + for quant_mode in ['static', 'dynamic', 'weight_only']: + quant_mode_custom_module_config = custom_module_mapping.get( + quant_mode, {}) + quant_mode_custom_module_classes = set( + quant_mode_custom_module_config.keys()) + float_custom_module_classes |= quant_mode_custom_module_classes + return list(float_custom_module_classes) diff --git a/mmrazor/structures/quantization/__init__.py b/mmrazor/structures/quantization/__init__.py new file mode 100644 index 000000000..fc2133bf2 --- /dev/null +++ b/mmrazor/structures/quantization/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .backend_default_qconfigs import CheckArgs, DefalutQconfigs, SupportQtypes +from .qscheme import QuantizeScheme + +__all__ = ['QuantizeScheme', 'DefalutQconfigs', 'SupportQtypes', 'CheckArgs'] diff --git a/mmrazor/structures/quantization/backend_default_qconfigs.py b/mmrazor/structures/quantization/backend_default_qconfigs.py new file mode 100644 index 000000000..6a1fde183 --- /dev/null +++ b/mmrazor/structures/quantization/backend_default_qconfigs.py @@ -0,0 +1,46 @@ +# Copyright (c) OpenMMLab. All rights reserved. +SupportQtypes = ('affine') +CheckArgs = [ + 'qtype', 'w_qscheme', 'a_qscheme', 'w_fake_quant', 'a_fake_quant', + 'w_observer', 'a_observer' +] + +Default = dict( + qtype='affine', # noqa: E241 + w_qscheme=dict( + is_symmetry=True, + is_per_channel=True, + is_pot_scale=False, + bit=8, + symmetric_range=True), + a_qscheme=dict( + is_symmetry=True, + is_per_channel=False, + is_pot_scale=False, + bit=8, + symmetric_range=True), + w_fake_quant=dict(type='BaseFakeQuantize'), + a_fake_quant=dict(type='BaseFakeQuantize'), + w_observer=dict(type='MinMaxObserver'), + a_observer=dict(type='MinMaxObserver')) + +TensorRT = dict( + qtype='affine', # noqa: E241 + w_qscheme=dict( + is_symmetry=True, + is_per_channel=True, + is_pot_scale=False, + bit=8, + symmetric_range=True), + a_qscheme=dict( + is_symmetry=True, + is_per_channel=False, + is_pot_scale=False, + bit=8, + symmetric_range=True), + w_fake_quant=dict(type='LearnableFakeQuantize'), + a_fake_quant=dict(type='LearnableFakeQuantize'), + w_observer=dict(type='MinMaxObserver'), + a_observer=dict(type='EMAMinMaxObserver')) + +DefalutQconfigs = dict(default=Default, tensorrt=TensorRT) diff --git a/mmrazor/structures/quantization/qscheme.py b/mmrazor/structures/quantization/qscheme.py new file mode 100644 index 000000000..24c41832e --- /dev/null +++ b/mmrazor/structures/quantization/qscheme.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + + +class QuantizeScheme(object): + """Custom QScheme. Refer to: + https://github.com/pytorch/pytorch/blob/master/c10/core/QScheme.h. + + Args: + bit (int, optional): Bit number. Defaults to 8. + is_symmetry (bool, optional): Is symmetry quantization or not. Defaults + to True. + is_per_channel (bool, optional): Is per-channel quantization or not. + Defaults to False. + is_pot_scale (bool, optional): Indicate whether scale is power of two. + Defaults to False. + """ + + def __init__(self, + bit=8, + is_symmetry=True, + is_per_channel=False, + is_pot_scale=False, + **kwargs): + self.bit = bit + self.is_symmetry = is_symmetry + self.is_per_channel = is_per_channel + self.is_pot_scale = is_pot_scale + + if self.is_per_channel: + self.torch_qscheme = torch.per_channel_symmetric \ + if self.is_symmetry else torch.per_channel_affine + else: + self.torch_qscheme = torch.per_tensor_symmetric \ + if self.is_symmetry else torch.per_tensor_affine + if 'is_symmetric_range' in kwargs: + self.is_symmetric_range = kwargs['is_symmetric_range'] + del kwargs['is_symmetric_range'] + else: + self.is_symmetric_range = False + self.kwargs = kwargs + + def to_observer_params(self): + quant_min = 0 + quant_max = 2**self.bit - 1 + if self.is_symmetry: + quant_max = 2**(self.bit - 1) - 1 + if self.is_symmetric_range: + quant_min = -2**(self.bit - 1) + 1 + else: + quant_min = -2**(self.bit - 1) + + naive_para = { + 'quant_min': quant_min, + 'quant_max': quant_max, + 'dtype': torch.qint8 if self.is_symmetry else torch.quint8, + 'is_pot_scale': self.is_pot_scale, + 'qscheme': self.torch_qscheme, + 'reduce_range': False, + 'ch_axis': 0 if self.is_per_channel else -1 + } + naive_para.update(self.kwargs) + return naive_para + + def __str__(self): + return f'bit: {self.bit} / is_symmetry: {self.is_symmetry} / \ + is_per_channel: {self.is_per_channel} / is_pot_scale: \ + {self.is_pot_scale} / extra_kwargs: {self.kwargs}' diff --git a/mmrazor/testing/__init__.py b/mmrazor/testing/__init__.py index 009dd844d..54dfd30ed 100644 --- a/mmrazor/testing/__init__.py +++ b/mmrazor/testing/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. from ._fast_stop_training_hook import FastStopTrainingHook # noqa: F401,F403 +from ._fx_models import * # noqa: F401, F403 diff --git a/mmrazor/testing/_fx_models.py b/mmrazor/testing/_fx_models.py new file mode 100644 index 000000000..969c4792d --- /dev/null +++ b/mmrazor/testing/_fx_models.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional, Tuple, Union + +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmrazor.registry import MODELS + + +@MODELS.register_module() +class ConvBNReLU(nn.Module): + + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: Union[int, Tuple[int, int]] = 1, + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: Union[str, bool] = 'auto', + conv_cfg: Optional[Dict] = None, + norm_cfg: Optional[Dict] = None, + act_cfg: Dict = dict(type='ReLU'), + inplace: bool = True, + with_spectral_norm: bool = False, + padding_mode: str = 'zeros', + order: tuple = ('conv', 'norm', 'act'), + init_cfg: Optional[Dict] = None, + ) -> None: + super().__init__() + self.conv_module = ConvModule(in_channel, out_channel, kernel_size, + stride, padding, dilation, groups, bias, + conv_cfg, norm_cfg, act_cfg, inplace, + with_spectral_norm, padding_mode, order) + + def forward(self, x): + x = self.conv_module.conv(x) + x = self.conv_module.norm(x) + x = self.conv_module.activate(x) + return x diff --git a/tests/test_models/test_algorithms/test_general_quant.py b/tests/test_models/test_algorithms/test_general_quant.py new file mode 100644 index 000000000..94a2485bc --- /dev/null +++ b/tests/test_models/test_algorithms/test_general_quant.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch.nn as nn + + +class ToyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + # TODO + + +class TestGeneralQuant(TestCase): + """TODO. + + Args: + TestCase (_type_): _description_ + """ + + def test_init(self): + pass + + def test_prepare(self): + pass + + def test_convert(self): + pass + + def test_states(self): + pass + + def test_forward(self): + pass diff --git a/tests/test_models/test_fake_quantize/test_lsq_fake_quants.py b/tests/test_models/test_fake_quantize/test_lsq_fake_quants.py new file mode 100644 index 000000000..d6b670bb5 --- /dev/null +++ b/tests/test_models/test_fake_quantize/test_lsq_fake_quants.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + + +class TestLearnableFakeQuantize(TestCase): + + def test_init(self): + pass + + def test_repr(self): + pass + + def test_calculate_qparams(self): + pass + + def test_forward(self): + pass + + def test_load_state_dict(self): + pass + + def test_save_state_dict(self): + pass diff --git a/tests/test_models/test_mutators/test_diff_mutator.py b/tests/test_models/test_mutators/test_diff_mutator.py new file mode 100644 index 000000000..663637fc9 --- /dev/null +++ b/tests/test_models/test_mutators/test_diff_mutator.py @@ -0,0 +1,235 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import pytest +import torch.nn as nn + +from mmrazor.models import * # noqa: F401,F403 +from mmrazor.models.mutables import DiffMutableModule +from mmrazor.models.mutators import DiffModuleMutator +from mmrazor.registry import MODELS + +MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True) +MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True) +MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True) + + +class SearchableLayer(nn.Module): + + def __init__(self, mutable_cfg: dict) -> None: + super().__init__() + self.op1 = MODELS.build(mutable_cfg) + self.op2 = MODELS.build(mutable_cfg) + self.op3 = MODELS.build(mutable_cfg) + + def forward(self, x): + x = self.op1(x) + x = self.op2(x) + return self.op3(x) + + +class SearchableModel(nn.Module): + + def __init__(self, mutable_cfg: dict) -> None: + super().__init__() + self.slayer1 = SearchableLayer(mutable_cfg) + self.slayer2 = SearchableLayer(mutable_cfg) + self.slayer3 = SearchableLayer(mutable_cfg) + + def forward(self, x): + x = self.slayer1(x) + x = self.slayer2(x) + return self.slayer3(x) + + +class SearchableLayerAlias(nn.Module): + + def __init__(self, mutable_cfg: dict) -> None: + super().__init__() + mutable_cfg.update(alias='op1') + self.op1 = MODELS.build(mutable_cfg) + mutable_cfg.update(alias='op2') + self.op2 = MODELS.build(mutable_cfg) + mutable_cfg.update(alias='op3') + self.op3 = MODELS.build(mutable_cfg) + + def forward(self, x): + x = self.op1(x) + x = self.op2(x) + return self.op3(x) + + +class SearchableModelAlias(nn.Module): + + def __init__(self, mutable_cfg: dict) -> None: + super().__init__() + self.slayer1 = SearchableLayerAlias(mutable_cfg) + self.slayer2 = SearchableLayerAlias(mutable_cfg) + self.slayer3 = SearchableLayerAlias(mutable_cfg) + + def forward(self, x): + x = self.slayer1(x) + x = self.slayer2(x) + return self.slayer3(x) + + +class TestDiffModuleMutator(TestCase): + + def setUp(self): + self.MUTABLE_CFG = dict( + type='DiffMutableOP', + candidates=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + padding=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + padding=2, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + padding=3, + ), + ), + module_kwargs=dict(in_channels=32, out_channels=32, stride=1)) + + self.MUTATOR_CFG = dict( + type='DiffModuleMutator', + custom_groups=[['op1'], ['op2'], ['op3']]) + + def test_diff_mutator_diffop_layer(self) -> None: + model = SearchableLayer(self.MUTABLE_CFG) + mutator: DiffModuleMutator = MODELS.build(self.MUTATOR_CFG) + + mutator.prepare_from_supernet(model) + assert list(mutator.search_groups.keys()) == [0, 1, 2] + + def test_diff_mutator_diffop_model(self) -> None: + model = SearchableModel(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [ + ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], + ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], + ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], + ] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + mutator.prepare_from_supernet(model) + assert list(mutator.search_groups.keys()) == [0, 1, 2] + + mutator.modify_supernet_forward(mutator.arch_params) + assert mutator.mutable_class_type == DiffMutableModule + + def test_diff_mutator_diffop_model_error(self) -> None: + model = SearchableModel(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [ + ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], + ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], + ['slayer1.op3', 'slayer2.op3', 'slayer3.op3_error_key'], + ] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + with pytest.raises(AssertionError): + mutator.prepare_from_supernet(model) + + def test_diff_mutator_diffop_alias(self) -> None: + model = SearchableModelAlias(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [['op1'], ['op2'], ['op3']] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + mutator.prepare_from_supernet(model) + + assert list(mutator.search_groups.keys()) == [0, 1, 2] + + mutator.modify_supernet_forward(mutator.arch_params) + assert mutator.mutable_class_type == DiffMutableModule + + def test_diff_mutator_alias_module_name(self) -> None: + """Using both alias and module name for grouping.""" + model = SearchableModelAlias(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [['op1'], + [ + 'slayer1.op2', 'slayer2.op2', + 'slayer3.op2' + ], ['slayer1.op3', 'slayer2.op3']] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + mutator.prepare_from_supernet(model) + + assert list(mutator.search_groups.keys()) == [0, 1, 2, 3] + + mutator.modify_supernet_forward(mutator.arch_params) + assert mutator.mutable_class_type == DiffMutableModule + + def test_diff_mutator_duplicate_keys(self) -> None: + model = SearchableModel(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [ + ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], + ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], + ['slayer1.op3', 'slayer2.op3', 'slayer2.op3'], + ] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + with pytest.raises(AssertionError): + mutator.prepare_from_supernet(model) + + def test_diff_mutator_duplicate_key_alias(self) -> None: + model = SearchableModelAlias(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [ + ['op1', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'], + ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], + ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], + ] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + with pytest.raises(AssertionError): + mutator.prepare_from_supernet(model) + + def test_diff_mutator_illegal_key(self) -> None: + model = SearchableModel(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [ + ['illegal_key', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'], + ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], + ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], + ] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + with pytest.raises(AssertionError): + mutator.prepare_from_supernet(model) + + def test_sample_and_set_choices(self): + model = SearchableModel(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [ + ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], + ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], + ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], + ] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + mutator.prepare_from_supernet(model) + choices = mutator.sample_choices() + mutator.set_choices(choices) + self.assertTrue(len(choices) == 3) + + +if __name__ == '__main__': + import unittest + unittest.main() diff --git a/tests/test_models/test_observers/test_observer.py b/tests/test_models/test_observers/test_observer.py new file mode 100644 index 000000000..ca39ecfbd --- /dev/null +++ b/tests/test_models/test_observers/test_observer.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch.nn as nn + + +class ToyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + # TODO + + +class TestMinMaxObserver(TestCase): + """TODO. + + Args: + TestCase (_type_): _description_ + """ + + def test_init(self): + pass + + def test_prepare(self): + pass + + def test_convert(self): + pass + + def test_states(self): + pass + + def test_forward(self): + pass + + +class TestLSQObserver(TestMinMaxObserver): + pass diff --git a/tests/test_models/test_quantizers/test_trt_quantizer.py b/tests/test_models/test_quantizers/test_trt_quantizer.py new file mode 100644 index 000000000..9f85d1ecd --- /dev/null +++ b/tests/test_models/test_quantizers/test_trt_quantizer.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch.nn as nn + + +class ToyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + # TODO + + +class TestTRTQuantizer(TestCase): + """TODO. + + Args: + TestCase (_type_): _description_ + """ + + def test_init(self): + pass + + def test_prepare(self): + pass + + def test_convert(self): + pass + + def test_states(self): + pass + + def test_forward(self): + pass diff --git a/tests/test_models/test_task_modules/test_custom_tracer.py b/tests/test_models/test_task_modules/test_custom_tracer.py new file mode 100644 index 000000000..671922f69 --- /dev/null +++ b/tests/test_models/test_task_modules/test_custom_tracer.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +from mmrazor.models.task_modules import CustomTracer, UntracedMethodRegistry +from mmrazor.testing import ConvBNReLU + + +class testCustomTracer(TestCase): + + def test_init(self): + tracer = CustomTracer() + assert tracer.skipped_methods.__len__() == 0 + + def test_trace(self): + tracer = CustomTracer() + model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN')) + graph = tracer.trace(model) # noqa: F841 + + def test_auto_skip_call_module(self): + pass + + def test_auto_skip_call_method(self): + pass + + def test_configurable_skipped_methods(self): + pass + + +class testUntracedMethodRgistry(TestCase): + + def test_init(self): + self.assertEqual(len(UntracedMethodRegistry.method_dict), 0) + + def test_add_method(self): + pass diff --git a/tools/ckpt_demo.py b/tools/ckpt_demo.py new file mode 100644 index 000000000..ee257390c --- /dev/null +++ b/tools/ckpt_demo.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +ckpt_path = '/mnt/lustre/humu/experiments/adaround/quantizied.pth' +# ckpt_path = +# '/mnt/petrelfs/humu/share/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' +# ckpt_path = '/tmp/humu/resnet18_uniform8/checkpoint.pth.tar' +# ckpt_path = '/tmp/humu/resnet18_uniform8/quantized_checkpoint.pth.tar' + +state_dict = torch.load(ckpt_path, map_location='cpu') + +for k, v in state_dict['state_dict'].items(): + print(k) diff --git a/tools/slurm_test.sh b/tools/slurm_test.sh index 6dd67e574..3c74ec6ec 100644 --- a/tools/slurm_test.sh +++ b/tools/slurm_test.sh @@ -1,24 +1,10 @@ #!/usr/bin/env bash -set -x - -PARTITION=$1 -JOB_NAME=$2 -CONFIG=$3 -CHECKPOINT=$4 -GPUS=${GPUS:-8} -GPUS_PER_NODE=${GPUS_PER_NODE:-8} -CPUS_PER_TASK=${CPUS_PER_TASK:-5} -PY_ARGS=${@:5} -SRUN_ARGS=${SRUN_ARGS:-""} +CONFIG=$1 +CHECKPOINT=$2 +GPUS=$3 +PORT=${PORT:-29500} PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ -srun -p ${PARTITION} \ - --job-name=${JOB_NAME} \ - --gres=gpu:${GPUS_PER_NODE} \ - --ntasks=${GPUS} \ - --ntasks-per-node=${GPUS_PER_NODE} \ - --cpus-per-task=${CPUS_PER_TASK} \ - --kill-on-bad-exit=1 \ - ${SRUN_ARGS} \ - python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} +python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} diff --git a/tools/tracer_demo.py b/tools/tracer_demo.py new file mode 100644 index 000000000..88334d6aa --- /dev/null +++ b/tools/tracer_demo.py @@ -0,0 +1,93 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +import torch +import torch.fx as fx +from mmengine.config import Config +from mmengine.registry import MODELS + +from mmrazor.models.task_modules.tracer import custom_symbolic_trace + +cfg_path = 'configs/quantization/ptq/demo.py' +_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear) + + +def extract_subgraph(graphmodule, block_slice): + subgraph = copy.deepcopy(graphmodule.graph) + block_start, block_end = block_slice[:2] + for node in subgraph.nodes: + if node.name == 'inputs': + input_node = node + if node.name == block_start.name: + node.replace_input_with(node.prev, input_node) + if node.name == block_end.name: + output_node = node + if node.op == 'output': + node.replace_input_with(node.prev, output_node) + subgraph.lint() + subgraph_module = fx.GraphModule(graphmodule, subgraph) + subgraph_module.graph.eliminate_dead_code() + subgraph_module.recompile() + return subgraph_module + + +def extract_blocks(graphmodule, key_word='layer'): + block_slices = [] + block_slice = [] + pre_stage_index, pre_block_index = 0, 0 + cur_stage_index, cur_block_index = 0, 0 + for node in graphmodule.graph.nodes: + if key_word not in node.name: + continue + else: + items = node.name.split('_') + for i, item in enumerate(items): + if key_word in item: + cur_stage_index = int(item[5:]) + cur_block_index = int(items[i + 1]) + break + if (cur_block_index != pre_block_index) or (cur_stage_index != + pre_stage_index): + block_slice.append(node.prev) + if len(block_slice) == 2: + block_slices.append(block_slice) + block_slice = [] + block_slice.append(node) + + pre_stage_index, pre_block_index = cur_stage_index, cur_block_index + + return block_slices + + +def extract_layers(graphmodule, layer_types): + layer_slices = [] + for node in graphmodule.graph.nodes: + if node.op == 'call_module': + m = node.graph.owning_module.get_submodule(node.target) + if isinstance(m, _ADAROUND_SUPPORT_TYPE): + layer_slices.append((node, node)) + return layer_slices + + +def main(): + # load config + cfg = Config.fromfile(cfg_path) + model = MODELS.build(cfg.model) + symbolic_traced = custom_symbolic_trace( + model, concrete_args={'mode': 'tensor'}) + # block_slices = extract_blocks(symbolic_traced) + block_slices = extract_layers( + symbolic_traced, layer_types=_ADAROUND_SUPPORT_TYPE) + + for b in block_slices: + print(b[0].name, b[1].name) + + print('#' * 100) + subgraph = extract_subgraph(symbolic_traced, block_slices[0]) + print(subgraph.code) + for name, layer in subgraph.named_modules(): + print(name, layer) + + +if __name__ == '__main__': + main() From 09f943cb35632089eb14ac360922f547b659bad4 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Fri, 25 Nov 2022 15:45:37 +0800 Subject: [PATCH 02/44] [Features]Quantize pipeline (#350) * init demo * add customer_tracer * add quantizer * add fake_quant, loop, config * remove CPatcher in custome_tracer * demo_try * init version * modified base.py * pre-rebase * wip of adaround series * adaround experiment * trasfer to s2 * update api * point at sub_reconstruction * pre-checkout * export onnx * add customtracer * fix lint * move custom tracer * fix import * update * updated * retina loss & predict & tesnor DONE * for RFC * Customed FX initialize * add UT init * TDO: UTs * Successfully RUN * update loop * update loop docstrings * update quantizer docstrings * update qscheme docstrings * update qobserver docstrings * update tracer docstrings * update UTs init * update UTs init * fix bugs * fix lsq * refactor quantize pipeline * fix quant * WIP: debug qat * fix lsq bugs * fix qat, docstring in progress * TDO: UTs * fix bugs * fix lsq * refactor quantize pipeline * fix quant * WIP: debug qat * fix lsq bugs * fix qat, docstring in progress * fixed DefaultQconfigs name * fix bugs * add comments and fix typos * delete useless codes * fix bugs and add comments * rename prepare_module_dict * update lsq config Co-authored-by: humu789 Co-authored-by: huangpengsheng Co-authored-by: FreakieHuang Co-authored-by: pppppM --- configs/quantization/ptq/adaround.py | 8 +- configs/quantization/qat/demo.py | 1 - .../qat/lsq_resnet18_8xb16_cifar10.py | 70 ++++ .../qat/lsq_resnet18_8xb32_in1k.py | 75 +++++ .../qat/lsq_resnet50_8xb16_cifar10.py | 37 -- mmrazor/engine/runner/quantization_loops.py | 315 +++++++++++++++--- .../models/algorithms/quantization/base.py | 154 ++++++--- mmrazor/models/fake_quants/lsq.py | 16 + .../units/mutable_channel_unit.py | 4 +- mmrazor/models/observers/__init__.py | 3 +- mmrazor/models/observers/base.py | 3 +- mmrazor/models/observers/minmax.py | 4 +- mmrazor/models/quantizers/base.py | 9 +- mmrazor/models/quantizers/trt_quantizer.py | 4 +- .../models/task_modules/tracer/__init__.py | 5 +- .../models/task_modules/tracer/fx/__init__.py | 7 +- .../task_modules/tracer/fx/custom_tracer.py | 88 ++++- mmrazor/registry/registry.py | 1 + mmrazor/structures/quantization/__init__.py | 4 +- .../quantization/backend_default_qconfigs.py | 6 +- tools/ptq_calibrate.py | 73 ++++ 21 files changed, 735 insertions(+), 152 deletions(-) delete mode 100644 configs/quantization/qat/demo.py create mode 100644 configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py create mode 100644 configs/quantization/qat/lsq_resnet18_8xb32_in1k.py delete mode 100644 configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py create mode 100644 tools/ptq_calibrate.py diff --git a/configs/quantization/ptq/adaround.py b/configs/quantization/ptq/adaround.py index 389575dc6..78157c61a 100644 --- a/configs/quantization/ptq/adaround.py +++ b/configs/quantization/ptq/adaround.py @@ -1,12 +1,8 @@ -_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] +_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] test_cfg = dict( - _delete_=True, type='mmrazor.PTQLoop', - dataloader=_base_.test_dataloader, - evaluator=_base_.test_evaluator, - calibrate_dataloader=_base_.train_dataloader, - batch_num=32, + # reconstruction_cfg=dict( # pattern='layer', # loss=dict( diff --git a/configs/quantization/qat/demo.py b/configs/quantization/qat/demo.py deleted file mode 100644 index be3ec6013..000000000 --- a/configs/quantization/qat/demo.py +++ /dev/null @@ -1 +0,0 @@ -_base_ = ['./lsq_resnet50_8xb16_cifar10.py'] diff --git a/configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py b/configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py new file mode 100644 index 000000000..412a6fd87 --- /dev/null +++ b/configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py @@ -0,0 +1,70 @@ +_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] + +resnet = _base_.model +pretrained_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth' # noqa: E501 + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GeneralQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=10, + # RGB format normalization parameters + mean=[125.307, 122.961, 113.8575], + std=[51.5865, 50.847, 51.255], + # loaded images are already RGB format + to_rgb=False), + architecture=resnet, + pretrained_ckpt=pretrained_ckpt, + quantizer=dict( + type='CustomQuantizer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ], + qconfig=dict( + qtype='affine', + w_observer=dict(type='mmrazor.LSQObserver'), + a_observer=dict(type='mmrazor.LSQObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + bit=8, + is_symmetry=False, + is_per_channel=False, + is_pot_scale=False, + ), + a_qscheme=dict( + bit=8, + is_symmetry=False, + is_per_channel=False, + is_pot_scale=False), + ))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, + type='CosineAnnealingLR', + T_max=100, + by_epoch=True, + begin=0, + end=100) + +model_wrapper_cfg = dict( + type='mmrazor.GeneralQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.QATEpochBasedLoop', + by_epoch=True, + max_epochs=100, + val_interval=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') +test_cfg = val_cfg diff --git a/configs/quantization/qat/lsq_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_resnet18_8xb32_in1k.py new file mode 100644 index 000000000..a0885a52a --- /dev/null +++ b/configs/quantization/qat/lsq_resnet18_8xb32_in1k.py @@ -0,0 +1,75 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +train_cfg = dict( + _delete_=True, + type='mmrazor.QATEpochBasedLoop', + max_epochs=_base_.train_cfg.max_epochs) + +resnet = _base_.model +ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 +resnet.init_cfg = dict(type='Pretrained', checkpoint=ckpt) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='GeneralQuant', + # data_preprocessor = dict( + # num_classes=1000, + # # RGB format normalization parameters + # mean=[123.675, 116.28, 103.53], + # std=[58.395, 57.12, 57.375], + # # convert image from BGR to RGB + # to_rgb=True, + # ), + architecture=resnet, + quantizer=dict( + type='CustomQuantizer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ], + qconfig=dict( + qtype='affine', + w_observer=dict(type='mmrazor.MinMaxObserver'), + a_observer=dict(type='mmrazor.EMAMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + bit=8, + is_symmetry=False, + is_per_channel=False, + is_pot_scale=False, + ), + a_qscheme=dict( + bit=8, + is_symmetry=False, + is_per_channel=False, + is_pot_scale=False), + ))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, + type='CosineAnnealingLR', + T_max=100, + by_epoch=True, + begin=0, + end=100) + +default_hooks = dict( + checkpoint=dict( + type='CheckpointHook', + interval=5, + max_keep_ckpts=3, + out_dir='/mnt/petrelfs/caoweihan.p/training_ckpt/quant')) + +model_wrapper_cfg = dict( + type='mmrazor.GeneralQuantDDP', + broadcast_buffers=False, + find_unused_parameters=False) + +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') +test_cfg = val_cfg diff --git a/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py b/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py deleted file mode 100644 index a246bc265..000000000 --- a/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py +++ /dev/null @@ -1,37 +0,0 @@ -_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] - -train_cfg = dict( - _delete_=True, - type='mmrazor.QATEpochBasedLoop', - max_epochs=_base_.train_cfg.max_epochs, -) - -model = dict( - _delete_=True, - _scope_='mmrazor', - type='GeneralQuant', - architecture={{_base_.model}}, - quantizer=dict( - type='TensorRTQuantizer', - skipped_methods=[ - 'mmcls.models.heads.ClsHead._get_loss', - 'mmcls.models.heads.ClsHead._get_predictions' - ], - qconfig=dict( - qtype='affine', - w_observer=dict(type='mmrazor.MinMaxObserver'), - a_observer=dict(type='mmrazor.EMAMinMaxObserver'), - w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - w_qscheme=dict( - bit=2, - is_symmetry=False, - is_per_channel=True, - is_pot_scale=False, - ), - a_qscheme=dict( - bit=4, - is_symmetry=False, - is_per_channel=False, - is_pot_scale=False), - ))) diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index 2f15f5deb..a2d5d383b 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -1,13 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import os -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch from mmengine.evaluator import Evaluator from mmengine.registry import MODELS -from mmengine.runner import EpochBasedTrainLoop, TestLoop +from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop, autocast +from torch.ao.quantization import disable_observer +from torch.nn.intrinsic.qat import freeze_bn_stats from torch.utils.data import DataLoader from mmrazor.models.task_modules import (ModuleInputsRecorder, @@ -28,12 +30,13 @@ class QATEpochBasedLoop(EpochBasedTrainLoop): dataloader (Dataloader or dict): An iterator to generate one batch of dataset each iteration. max_epochs (int): Total training epochs. - calibrate_dataloader (Dataloader or dict, optional): A dataloader - object or a dict to build a dataloader for calibration. Defaults - to None. val_begin (int): The epoch that begins validating. Defaults to 1. val_interval (int): Validation interval. Defaults to 1. + disable_observer_begin (int): The number of total epochs to update + observers. + freeze_bn_begin (int): The number of total epochs to update batch norm + stats. dynamic_intervals (List[Tuple[int, int]], optional): The first element in the tuple is a milestone and the second element is a interval. The interval is used after the @@ -45,68 +48,296 @@ def __init__( runner, dataloader: Union[DataLoader, Dict], max_epochs: int, - calibrate_dataloader: Union[DataLoader, Dict] = None, val_begin: int = 1, val_interval: int = 1, + disable_observer_begin: int = 3, + freeze_bn_begin: int = 3, dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: super().__init__(runner, dataloader, max_epochs, val_begin, val_interval, dynamic_intervals) - if isinstance(calibrate_dataloader, dict): - # Determine whether or not different ranks use different seed. - diff_rank_seed = runner._randomness_cfg.get( - 'diff_rank_seed', False) - self.calibrate_dataloader = runner.build_dataloader( - calibrate_dataloader, - seed=runner.seed, - diff_rank_seed=diff_rank_seed) - else: - self.calibrate_dataloader = calibrate_dataloader - self.is_calibrate = True if calibrate_dataloader is not None else False + self.disable_observer_begin = disable_observer_begin + self.freeze_bn_begin = freeze_bn_begin - if self.runner.distributed: - self.model = runner.model.module - else: - self.model = runner.model - - def calibrate(self, calibrate_dataloader) -> None: - self.model.eval() - with torch.no_grad(): - for batch_data in calibrate_dataloader: - self.model(batch_data) - - def run(self) -> None: - """Launch training.""" - self.runner.call_hook('before_train') - - self.model.prepare() - - if self.is_calibrate: - self.model.state = (1, 0) - self.calibrate(self.calibrate_dataloader) - - self.model.state = (1, 1) + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + qat_metrics = dict() + for key, value in metrics.items(): + qat_key = 'qat.' + key + ori_key = 'original.' + key + qat_metrics[qat_key] = value + self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None) while self._epoch < self._max_epochs: + # state: observer_enabled, fakequant_enabled + self.runner.model.state = (True, True) self.run_epoch() self._decide_current_val_interval() if (self.runner.val_loop is not None and self._epoch >= self.val_begin and self._epoch % self.val_interval == 0): + # observer disabled during evaluation + self.runner.model.state = (False, True) + self.runner.model.sync_param() self.runner.val_loop.run() - self.model.convert() + self.runner.call_hook('after_train') - # self.runner.val_loop.run() + def run_epoch(self) -> None: + """Iterate one epoch.""" + self.runner.call_hook('before_train_epoch') + self.runner.model.train() - self.runner.call_hook('after_train') + # TODO freeze bn + if self._epoch >= self.disable_observer_begin: + self.runner.model.apply(disable_observer) + + if self._epoch >= self.freeze_bn_begin: + self.runner.model.apply(freeze_bn_stats) + + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch) + + self.runner.call_hook('after_train_epoch') + self._epoch += 1 + + +@LOOPS.register_module() +class QATValLoop(ValLoop): + """`ValLoop` for `QuantizationAwareTraining` + + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool): Whether to enable fp16 validation. Defaults to + False. + """ + + def __init__(self, + runner, + dataloader: Union[DataLoader, Dict], + evaluator: Union[Evaluator, Dict, List], + fp16: bool = False) -> None: + super().__init__(runner, dataloader, evaluator, fp16) + if self.runner.distributed: + assert hasattr(self.runner.model.module, 'architecture') + # TODO: remove hard code after mmcls add data_preprocessor + data_preprocessor = self.runner.model.module.data_preprocessor + self.architecture = self.runner.model.module.architecture + self.architecture.data_preprocessor = data_preprocessor + + else: + assert hasattr(self.runner.model, 'architecture') + # TODO: remove hard code after mmcls add data_preprocessor + data_preprocessor = self.runner.model.data_preprocessor + self.architecture = self.runner.model.architecture + self.architecture.data_preprocessor = data_preprocessor + + def run(self) -> dict: + """Launch validation.""" + self.runner.call_hook('before_val') + self.runner.call_hook('before_val_epoch') + self.runner.model.eval() + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch, self.runner.model) + + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + qat_metrics = dict() + for key, value in metrics.items(): + qat_key = 'qat.' + key + ori_key = 'original.' + key + qat_metrics[qat_key] = value + self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None) + + self.runner.call_hook('after_val_epoch', metrics=qat_metrics) + + self.runner.call_hook('before_val_epoch') + self.runner.model.eval() + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch, self.architecture) + + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + qat_metrics = dict() + for key, value in metrics.items(): + qat_key = 'qat.' + key + ori_key = 'original.' + key + qat_metrics[ori_key] = value + self.runner.message_hub.log_scalars.pop(f'val/{qat_key}', None) + + self.runner.call_hook('after_val_epoch', metrics=qat_metrics) + + self.runner.call_hook('after_val') + return qat_metrics + + @torch.no_grad() + def run_iter(self, idx, data_batch: Sequence[dict], model): + """Iterate one mini-batch. + + Args: + data_batch (Sequence[dict]): Batch of data + from dataloader. + """ + self.runner.call_hook( + 'before_val_iter', batch_idx=idx, data_batch=data_batch) + # outputs should be sequence of BaseDataElement + with autocast(enabled=self.fp16): + outputs = model.val_step(data_batch) + self.evaluator.process(data_samples=outputs, data_batch=data_batch) + self.runner.call_hook( + 'after_val_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) @LOOPS.register_module() class PTQLoop(TestLoop): """`TestLoop` for Post Training Quantization. + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + evaluator (Evaluator or dict or list): Used for computing metrics. + fp16 (bool, optional): Enable FP16 training mode. Defaults to False. + """ + + def __init__(self, + runner, + dataloader: Union[DataLoader, Dict], + evaluator: Union[Evaluator, Dict, List], + fp16: bool = False): + super().__init__(runner, dataloader, evaluator, fp16) + + def run(self) -> dict: + """Launch test.""" + self.runner.call_hook('before_test') + self.runner.call_hook('before_test_epoch') + self.runner.model.eval() + self.runner.model.state = (True, False) + + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch) + + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + + self.runner.call_hook('after_test_epoch', metrics=metrics) + self.runner.call_hook('after_test') + + # todo: hard code to save checkpoint on disk + self.runner.save_checkpoint( + self.runner.work_dir, + 'checkpoint_after_ptq.pth', + file_client_args=None, + save_optimizer=False, + save_param_scheduler=False) + + return metrics + + @torch.no_grad() + def run_iter(self, idx, data_batch: Sequence[dict]) -> None: + """Iterate one mini-batch. + + Args: + data_batch (Sequence[dict]): Batch of data from dataloader. + """ + self.runner.call_hook( + 'before_test_iter', batch_idx=idx, data_batch=data_batch) + # predictions should be sequence of BaseDataElement + + outputs = self.runner.model.calibrate_step(data_batch) + + self.runner.call_hook( + 'after_test_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) + + +# TODO refactor to supoort DDP +@LOOPS.register_module() +class AdaRoundLoop(TestLoop): + """`TestLoop` for Post Training Quantization. + + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + evaluator (Evaluator or dict or list): Used for computing metrics. + calibrate_dataloader (Dataloader or dict, optional): A dataloader + object or a dict to build a dataloader for calibration. Defaults + to None. + batch_num (Optional[int], optional): Total calibration batches. + Defaults to None. + reconstruction_cfg (Optional[Dict], optional): Model reconstruction + configuration. Defaults to None. + fp16 (bool, optional): Enable FP16 training mode. Defaults to False. + """ + + def __init__(self, + runner, + dataloader: Union[DataLoader, Dict], + evaluator: Union[Evaluator, Dict, List], + fp16: bool = False): + super().__init__(runner, dataloader, evaluator, fp16) + + def run(self) -> None: + """Launch test.""" + self.runner.call_hook('before_test') + self.runner.call_hook('before_test_epoch') + self.runner.model.eval() + self.runner.model.state = (1, 0) + + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch) + + # compute metrics + metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + + self.runner.call_hook('after_test_epoch', metrics=metrics) + self.runner.call_hook('after_test') + + # todo: hard code to save checkpoint on disk + self.runner.save_checkpoint( + self.runner.work_dir, + 'checkpoint_after_ptq.pth', + file_client_args=None, + save_optimizer=False, + save_param_scheduler=False) + + return metrics + + @torch.no_grad() + def run_iter(self, idx, data_batch: Sequence[dict]) -> None: + """Iterate one mini-batch. + + Args: + data_batch (Sequence[dict]): Batch of data from dataloader. + """ + self.runner.call_hook( + 'before_test_iter', batch_idx=idx, data_batch=data_batch) + # predictions should be sequence of BaseDataElement + + outputs = self.runner.model.calibrate_step(data_batch) + + self.runner.call_hook( + 'after_test_iter', + batch_idx=idx, + data_batch=data_batch, + outputs=outputs) + + +# TODO refactor to supoort DDP +@LOOPS.register_module() +class AdaRoundLoop(TestLoop): + """`TestLoop` for Post Training Quantization. + Args: runner (Runner): A reference of runner dataloader (Dataloader or dict): An iterator to generate one batch of diff --git a/mmrazor/models/algorithms/quantization/base.py b/mmrazor/models/algorithms/quantization/base.py index 718b08725..c97d832ff 100644 --- a/mmrazor/models/algorithms/quantization/base.py +++ b/mmrazor/models/algorithms/quantization/base.py @@ -1,11 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os from typing import Dict, List, Optional, Tuple, Union import torch +from mmengine.model import MMDistributedDataParallel +from mmengine.runner import load_checkpoint from mmengine.structures import BaseDataElement -from torch.fx import GraphModule +from torch import nn +from torch.ao.quantization import FakeQuantizeBase -from mmrazor.registry import MODELS +from mmrazor.models.task_modules import build_graphmodule +from mmrazor.registry import MODEL_WRAPPERS, MODELS from ..base import BaseAlgorithm LossResults = Dict[str, torch.Tensor] @@ -19,13 +24,17 @@ class GeneralQuant(BaseAlgorithm): """General quantization. Args: - Args: architecture (dict | :obj:`BaseModel`): The config of :class:`BaseModel` or built model. quantizer (dict | :obj:`BaseModel`): The config of :class:`BaseQuantizer` or built model. + export_mode (str): The mode of the model to be exported. Defaults to + predict. + qmodel_modes (list): The available mode of runner. data_preprocessor (dict | torch.nn.Module | None): The pre-process config of :class:`BaseDataPreprocessor`. Defaults to None. + pretrained_ckpt (str, Optional): The path of pretrained checkpoint. + Defaults to None. init_cfg (dict): The weight initialized config for :class:`BaseModule`. """ @@ -33,74 +42,94 @@ class GeneralQuant(BaseAlgorithm): def __init__(self, architecture, quantizer, + export_mode: str = 'predict', + qmodel_modes: List[str] = ['tensor', 'predict', 'loss'], data_preprocessor=None, + pretrained_ckpt: Optional[str] = None, init_cfg=None): + if data_preprocessor is None: data_preprocessor = {} # The build process is in MMEngine, so we need to add scope here. data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor') - super().__init__(architecture, data_preprocessor, init_cfg) + if pretrained_ckpt: + _ = load_checkpoint(self.architecture, pretrained_ckpt) + self.architecture._is_init = True self.quantizer = MODELS.build(quantizer) - self.observers_enabled = True - self.fake_quants_enabled = True - self.gen_graphs(self.architecture) + self._observers_enabled = True + self._fake_quants_enabled = True + self.export_mode = export_mode + self.qmodel_modes = qmodel_modes + self.qmodels = self._build_qmodels(self.architecture) + + def sync_param(self): + + def traverse(module, prefix): + for name, child in module._modules.items(): + if module is None: + continue + module_name = f'{prefix}{name}' + if isinstance(child, FakeQuantizeBase): + for name, param in child.named_parameters(): + param.data.copy_(self.qmodels['loss'].state_dict() + [f'{module_name}.{name}']) + for name, buffer in child.named_buffers(): + buffer.data.copy_(self.qmodels['loss'].state_dict() + [f'{module_name}.{name}']) + else: + traverse(child, f'{module_name}.') + + for mode in self.qmodel_modes: + if mode == 'loss': + continue + traverse(self.qmodels[mode], '') + + def _build_qmodels(self, model): + + qmodels = nn.ModuleDict() - def gen_graphs(self, model): self.quantizer._swap_ff_with_fxff(model) tracer = self.quantizer.tracer - for mode in ['tensor', 'loss', 'predict']: + + for mode in self.qmodel_modes: concrete_args = {'mode': mode} - if mode == 'tensor': - self.graph_tensor = GraphModule( - model, tracer.trace(model, concrete_args=concrete_args)) - if mode == 'loss': - self.graph_loss = GraphModule( - model, tracer.trace(model, concrete_args=concrete_args)) - if mode == 'predict': - self.graph_predict = GraphModule( - model, tracer.trace(model, concrete_args=concrete_args)) + traced_graph = tracer.trace(model, concrete_args=concrete_args) + + qmodel = build_graphmodule(model, traced_graph) + qmodels[mode] = self.quantizer.prepare(model, qmodel) + + return qmodels def forward(self, inputs: torch.Tensor, data_samples: Optional[List[BaseDataElement]] = None, mode: str = 'tensor') -> ForwardResults: - if mode == 'loss': - return self.graph_loss(inputs, data_samples, mode) - elif mode == 'tensor': - return self.graph_tensor(inputs, data_samples, mode) - elif mode == 'predict': - return self.graph_predict(inputs, data_samples, mode) + if mode in self.qmodels: + qmodel = self.qmodels[mode] + return qmodel(inputs, data_samples, mode) else: - raise RuntimeError(f'Invalid mode "{mode}". ' - 'Only supports loss, predict and tensor mode') + return self.architecture(inputs, data_samples, mode) - def calib_step(self, data): + def calibrate_step(self, data): data = self.data_preprocessor(data, False) + self.state = (1, 0) return self._run_forward(data, mode='tensor') - def prepare(self, mode='tensor'): - assert mode in ['tensor', 'loss', 'predict'] - if mode == 'tensor': - graph = self.graph_tensor - elif mode == 'loss': - graph = self.graph_loss - else: - graph = self.graph_predict - self.architecture = self.quantizer.prepare(self.architecture, graph) - - def convert(self): - self.architecture = self.quantizer.convert(self.architecture) + def convert(self, mode='predict'): + qmodel = self.qmodels[self.export_mode] + self.qmodels[mode] = self.quantizer.convert(qmodel) @property def state(self): - return (self.observers_enabled, self.fake_quants_enabled) + return (self._observers_enabled, self._fake_quants_enabled) @state.setter - def state(self, state): + def state(self, state: Tuple[bool, bool]): observers_enabled, fake_quants_enabled = state - for name, submodule in self.architecture.named_modules(): + qmodel = self.qmodels[self.export_mode] + for submodule in qmodel.modules(): if isinstance(submodule, torch.quantization.FakeQuantize): if observers_enabled: submodule.enable_observer() @@ -112,5 +141,42 @@ def state(self, state): else: submodule.disable_fake_quant() - self.observers_enabled = observers_enabled - self.fake_quants_enabled = fake_quants_enabled + self._observers_enabled = observers_enabled + self._fake_quants_enabled = fake_quants_enabled + + +@MODEL_WRAPPERS.register_module() +class GeneralQuantDDP(MMDistributedDataParallel): + """DDPwapper for GeneralQuant.""" + + def __init__(self, + *, + device_ids: Optional[Union[List, int, torch.device]] = None, + **kwargs) -> None: + if device_ids is None: + if os.environ.get('LOCAL_RANK') is not None: + device_ids = [int(os.environ['LOCAL_RANK'])] + super().__init__(device_ids=device_ids, **kwargs) + # After moving all model parameters and buffers to the GPU + # (`model.cuda()`), the buffers in model are different. + self.module.qmodels = self.module._build_qmodels( + self.module.architecture) + + def calibrate_step(self, data): + return self.module.calibrate_step(data) + + @property + def state(self): + return (self.module._observers_enabled, + self.module._fake_quants_enabled) + + @state.setter + def state(self, state: Tuple[bool]): + self.module.state = state + + def convert(self, mode='predict'): + self.module.convert(mode) + self.module.qmodels[mode].cuda() + + def sync_param(self): + self.module.sync_param() diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py index 10970a6a3..3e26631eb 100644 --- a/mmrazor/models/fake_quants/lsq.py +++ b/mmrazor/models/fake_quants/lsq.py @@ -48,6 +48,22 @@ def extra_repr(self): self.scale if self.ch_axis == -1 else 'List[%s]' % str(self.scale.shape), # noqa: E501 self.zero_point if self.ch_axis == -1 else 'List') + @torch.jit.export + def calculate_qparams(self): + self.scale.data.clamp_(min=self.eps.item()) + scale = self.scale.detach() + zero_point = self.zero_point.detach().round().clamp( + self.quant_min, self.quant_max).long() + return scale, zero_point + + def _save_to_state_dict(self, destination, prefix, keep_vars): + super(FakeQuantize, self)._save_to_state_dict(destination, prefix, + keep_vars) + destination[prefix + 'scale'] = self.scale if keep_vars \ + else self.scale.detach() + destination[prefix + 'zero_point'] = self.zero_point if keep_vars \ + else self.zero_point.detach() + def forward(self, X): # Learnable fake quantize have to zero_point.float() # to make it learnable. diff --git a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py index d0a2deff0..251214f70 100644 --- a/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/mutable_channel_unit.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. """This module defines MutableChannelUnit.""" import abc -from collections import Set +# from collections import set from typing import Dict, List, Type, TypeVar import torch @@ -72,7 +72,7 @@ def process_container(container: MutableChannelContainer, if isinstance(derived_choices, torch.Tensor): derived_choices = derived_choices.sum().item() if isinstance(mutable, DerivedMutable): - source_mutables: Set = \ + source_mutables: set = \ mutable._trace_source_mutables() source_channel_mutables = [ mutable for mutable in source_mutables diff --git a/mmrazor/models/observers/__init__.py b/mmrazor/models/observers/__init__.py index 22af9bae9..eac6371e2 100644 --- a/mmrazor/models/observers/__init__.py +++ b/mmrazor/models/observers/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .lsq_observer import LSQObserver from .minmax import EMAMinMaxObserver, MinMaxObserver from .mse import MSEObserver -__all__ = ['MinMaxObserver', 'MSEObserver', 'EMAMinMaxObserver'] +__all__ = ['MinMaxObserver', 'MSEObserver', 'EMAMinMaxObserver', 'LSQObserver'] diff --git a/mmrazor/models/observers/base.py b/mmrazor/models/observers/base.py index e10738664..8d9c40afe 100644 --- a/mmrazor/models/observers/base.py +++ b/mmrazor/models/observers/base.py @@ -6,9 +6,8 @@ from mmrazor.models.utils import pot_quantization, sync_tensor -# from mmengine.model import BaseModule - +# todo: We only support per-tensor quantization currently. class BaseObserver(UniformQuantizationObserverBase): """Modified torch quantization observer. diff --git a/mmrazor/models/observers/minmax.py b/mmrazor/models/observers/minmax.py index 099296536..2ec620e60 100644 --- a/mmrazor/models/observers/minmax.py +++ b/mmrazor/models/observers/minmax.py @@ -45,8 +45,8 @@ def forward(self, x_orig): min_val_cur, max_val_cur = torch._aminmax(y, 1) min_val = torch.min(self.min_val, min_val_cur) max_val = torch.max(self.max_val, max_val_cur) - self.min_val.copy_(min_val) - self.max_val.copy_(max_val) + self.min_val = min_val + self.max_val = max_val return x diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py index ab4cf190a..6f1fb4e31 100644 --- a/mmrazor/models/quantizers/base.py +++ b/mmrazor/models/quantizers/base.py @@ -13,7 +13,7 @@ check_is_valid_qconfig_dict, get_custom_module_class_keys) from mmrazor.registry import MODELS -from mmrazor.structures.quantization import (CheckArgs, DefalutQconfigs, +from mmrazor.structures.quantization import (CheckArgs, DefaultQconfigs, QuantizeScheme, SupportQtypes) @@ -22,7 +22,7 @@ class CustomQuantizer(BaseModule): """Configurable quantizer, base class of quantizers. Args: - qconfig (Dict, optional): QConfig. Defaults to DefalutQconfigs['default']. # noqa: E501 + qconfig (Dict, optional): QConfig. Defaults to DefaultQconfigs['default']. # noqa: E501 is_qat (bool, optional): Is QAT ro not. Defaults to True. skipped_methods (List, optional): Skipped methods list for tracer. Defaults to None. @@ -38,7 +38,7 @@ class CustomQuantizer(BaseModule): """ def __init__(self, - qconfig: Dict = DefalutQconfigs['default'], + qconfig: Dict = DefaultQconfigs['default'], is_qat: bool = True, skipped_methods: List = None, prepare_custom_config_dict: Dict = None, @@ -189,6 +189,9 @@ def build_tracer(self): return tracer def fuse_model(self, graph_module): + if not self.is_qat: + graph_module.eval() + graph_module = _fuse_fx(graph_module, self.is_qat, self.prepare_custom_config_dict) return graph_module diff --git a/mmrazor/models/quantizers/trt_quantizer.py b/mmrazor/models/quantizers/trt_quantizer.py index cc8532a53..9dbe9f594 100644 --- a/mmrazor/models/quantizers/trt_quantizer.py +++ b/mmrazor/models/quantizers/trt_quantizer.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmrazor.registry import MODELS -from mmrazor.structures.quantization import DefalutQconfigs +from mmrazor.structures.quantization import DefaultQconfigs from .base import CustomQuantizer @@ -9,7 +9,7 @@ class TensorRTQuantizer(CustomQuantizer): """Quantizer for TensorRT backend.""" def __init__(self, - qconfig=DefalutQconfigs['tensorrt'], + qconfig=DefaultQconfigs['tensorrt'], is_qat=True, skipped_methods=None, prepare_custom_config_dict=None, diff --git a/mmrazor/models/task_modules/tracer/__init__.py b/mmrazor/models/task_modules/tracer/__init__.py index aab26fa40..838f20164 100644 --- a/mmrazor/models/task_modules/tracer/__init__.py +++ b/mmrazor/models/task_modules/tracer/__init__.py @@ -2,7 +2,8 @@ from .backward_tracer import BackwardTracer from .channel_analyzer import ChannelAnalyzer # from .razor_tracer import RazorFxTracer -from .fx import CustomTracer, UntracedMethodRegistry, custom_symbolic_trace +from .fx import (CustomTracer, UntracedMethodRegistry, custom_symbolic_trace, + prepare_graph_module) from .loss_calculator import * # noqa: F401,F403 from .parsers import * # noqa: F401,F403 from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode, @@ -12,5 +13,5 @@ 'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode', 'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode', 'ChannelAnalyzer', 'CustomTracer', 'UntracedMethodRegistry', - 'custom_symbolic_trace' + 'custom_symbolic_trace', 'prepare_graph_module' ] diff --git a/mmrazor/models/task_modules/tracer/fx/__init__.py b/mmrazor/models/task_modules/tracer/fx/__init__.py index 29c93f83a..998e9ffe1 100644 --- a/mmrazor/models/task_modules/tracer/fx/__init__.py +++ b/mmrazor/models/task_modules/tracer/fx/__init__.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .custom_tracer import (CustomTracer, UntracedMethodRegistry, - custom_symbolic_trace) + custom_symbolic_trace, build_graphmodule) -__all__ = ['CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace'] +__all__ = [ + 'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace', + 'build_graphmodule' +] diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py index f69ec2269..1d78d3007 100644 --- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Type, Union import torch +import torch.nn as nn from mmengine.utils import import_modules_from_strings from torch._C import ScriptObject # type: ignore[attr-defined] from torch.ao.quantization.quantize_fx import QuantizationTracer @@ -14,7 +15,6 @@ _orig_module_call: Callable = torch.nn.Module.__call__ _orig_module_getattr: Callable = torch.nn.Module.__getattr__ -# _orig_module_forward_train: Callable = models.BaseDenseHead.forward_train class UntracedMethodRegistry: @@ -78,6 +78,92 @@ def custom_symbolic_trace( return GraphModule(tracer.root, graph, name) +def _prepare_module_dict(model: nn.Module, fx_graph: torch.fx.Graph): + """If there is a class method that can not be traced by the symbolic + tracer, a ``call_method`` ``Node`` will be inserted into the ``Graph`` in + ``CustomTracer``. + + For example, + ``` + >>> class Model: + ... def __init__(self): + ... self.head = ClsHead() + ... + >>> class ClsHead(nn.Module): + ... def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor: + ... return feats[-1] + ... + ... def loss(self, feats: Tuple[torch.Tensor], + ... data_samples: List[ClsDataSample], **kwargs) -> dict: + ... cls_score = self(feats) + ... # The part can not be traced by torch.fx + ... losses = self._get_loss(cls_score, data_samples, **kwargs) + ... return losses + ... + ... def _get_loss(self, cls_score: torch.Tensor, + ... data_samples: List[ClsDataSample], **kwargs): + ... if 'score' in data_samples[0].gt_label: + ... xxx + ... else: + ... xxx + ... losses = xxx + ... return losses + ``` + As the ``_get_loss`` can not be traced by torch.fx, ``Toy._get_loss`` need + to be added to ``skipped_methods`` in ``CustomTracer``. Hence the code + above will product the following Graph:: + + .. code-block:: text + ... ... + %head : [#users=1] = get_attr[target=head] + %_get_loss : [#users=1] = call_method[target=_get_loss](args = (%head, %head_fc, %data_samples), kwargs = {}) # noqa: E501 + return _get_loss + + Hence, the head module in the ``GraphModule`` and that in the original + model are the same one (refer to https://github.com/pytorch/pytorch/blob/master/torch/fx/graph_module.py#L346). # noqa: E501 + So changes made to the graph module (in ``prepare()``) will also modify + the original model. + + Args: + model (nn.Module): The original model. + fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. + """ + + def _get_attrs(target, attrs): + attrs = attrs.split('.') + for att in attrs: + target = getattr(target, att) + return target + + module_dict = dict() + special_nodes = [] + + for node in fx_graph.nodes: + if node.op == 'get_attr': + attr = _get_attrs(model, node.target) + if isinstance(attr, nn.Module): + module_dict[node.target] = nn.Module() + special_nodes.append(node) + elif node.op == 'call_method': + for special_node in special_nodes: + if special_node in node.args or \ + special_node in node.kwargs.values(): + origin_module = getattr(model, special_node.target) + setattr(module_dict[special_node.target], node.target, + getattr(origin_module, node.target)) + + return module_dict + + +def build_graphmodule(model: nn.Module, + fx_graph: torch.fx.Graph, + name: str = 'GraphModule'): + modules = dict(model.named_modules()) + module_dict = _prepare_module_dict(model, fx_graph) + modules.update(module_dict) + return GraphModule(modules, fx_graph, name) + + class CustomTracer(QuantizationTracer): def __init__(self, diff --git a/mmrazor/registry/registry.py b/mmrazor/registry/registry.py index 7f915ee74..6aa3f192d 100644 --- a/mmrazor/registry/registry.py +++ b/mmrazor/registry/registry.py @@ -30,6 +30,7 @@ from mmengine.registry import \ WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS from mmengine.registry import Registry, build_from_cfg +from mmengine.runner import load_checkpoint def build_razor_model_from_cfg( diff --git a/mmrazor/structures/quantization/__init__.py b/mmrazor/structures/quantization/__init__.py index fc2133bf2..9447c2f0f 100644 --- a/mmrazor/structures/quantization/__init__.py +++ b/mmrazor/structures/quantization/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .backend_default_qconfigs import CheckArgs, DefalutQconfigs, SupportQtypes +from .backend_default_qconfigs import CheckArgs, DefaultQconfigs, SupportQtypes from .qscheme import QuantizeScheme -__all__ = ['QuantizeScheme', 'DefalutQconfigs', 'SupportQtypes', 'CheckArgs'] +__all__ = ['QuantizeScheme', 'DefaultQconfigs', 'SupportQtypes', 'CheckArgs'] diff --git a/mmrazor/structures/quantization/backend_default_qconfigs.py b/mmrazor/structures/quantization/backend_default_qconfigs.py index 6a1fde183..590f3208a 100644 --- a/mmrazor/structures/quantization/backend_default_qconfigs.py +++ b/mmrazor/structures/quantization/backend_default_qconfigs.py @@ -19,8 +19,8 @@ is_pot_scale=False, bit=8, symmetric_range=True), - w_fake_quant=dict(type='BaseFakeQuantize'), - a_fake_quant=dict(type='BaseFakeQuantize'), + w_fake_quant=dict(type='FakeQuantize'), + a_fake_quant=dict(type='FakeQuantize'), w_observer=dict(type='MinMaxObserver'), a_observer=dict(type='MinMaxObserver')) @@ -43,4 +43,4 @@ w_observer=dict(type='MinMaxObserver'), a_observer=dict(type='EMAMinMaxObserver')) -DefalutQconfigs = dict(default=Default, tensorrt=TensorRT) +DefaultQconfigs = dict(default=Default, tensorrt=TensorRT) diff --git a/tools/ptq_calibrate.py b/tools/ptq_calibrate.py new file mode 100644 index 000000000..2c00c5b11 --- /dev/null +++ b/tools/ptq_calibrate.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import os +import os.path as osp + +from mmengine.config import Config, DictAction +from mmengine.runner import Runner + +from mmrazor.utils import register_all_modules + + +# TODO: support fuse_conv_bn, visualization, and format_only +def parse_args(): + parser = argparse.ArgumentParser( + description='MMRazor test (and eval) a model') + parser.add_argument('config', help='test config file path') + # parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--work-dir', + help='the directory to save the file containing evaluation metrics') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def main(): + register_all_modules(False) + args = parse_args() + + # load config + cfg = Config.fromfile(args.config) + cfg.launcher = args.launcher + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get('work_dir', None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join('./work_dirs', + osp.splitext(osp.basename(args.config))[0]) + + # cfg.load_from = args.checkpoint + + # build the runner from config + runner = Runner.from_cfg(cfg) + + # start testing + runner.test() + + +if __name__ == '__main__': + main() From c7e2c9ae4a0c68453802f7ec70b4385bf231193b Mon Sep 17 00:00:00 2001 From: pppppM <67539920+pppppM@users.noreply.github.com> Date: Thu, 1 Dec 2022 12:21:43 +0800 Subject: [PATCH 03/44] [Feature] Add `prepare_for_mmdeploy` interface (#365) * remove useless code * fix build graph module import bug * refactor general quant * rename GeneralQuant to MMArchitectureQuant * fix some dtype bugs * add prepare_for_mmdeploy interface * update prepare for mmdeploy args * fix some comments Co-authored-by: humu789 --- configs/quantization/ptq/demo.py | 1 - ...=> ptq_openvino_resnet18_8xb16_cifar10.py} | 20 +- ...=> lsq_openvino_resnet18_8xb16_cifar10.py} | 25 +- .../qat/lsq_resnet18_8xb32_in1k.py | 75 ---- mmrazor/engine/runner/quantization_loops.py | 322 ++---------------- mmrazor/models/algorithms/__init__.py | 5 +- .../algorithms/quantization/__init__.py | 4 +- .../models/algorithms/quantization/base.py | 113 +++--- mmrazor/models/fake_quants/lsq.py | 9 +- mmrazor/models/observers/lsq_observer.py | 2 +- mmrazor/models/observers/mse.py | 9 +- mmrazor/models/quantizers/__init__.py | 3 +- mmrazor/models/quantizers/base.py | 69 +++- .../models/quantizers/openvino_quantizer.py | 59 ++++ mmrazor/models/quantizers/trt_quantizer.py | 25 ++ .../models/task_modules/tracer/__init__.py | 6 +- .../models/task_modules/tracer/fx/__init__.py | 2 +- mmrazor/models/utils/quantization_util.py | 19 +- mmrazor/registry/registry.py | 1 - tools/debug.py | 162 +++++++++ 20 files changed, 409 insertions(+), 522 deletions(-) delete mode 100644 configs/quantization/ptq/demo.py rename configs/quantization/ptq/{adaround.py => ptq_openvino_resnet18_8xb16_cifar10.py} (68%) rename configs/quantization/qat/{lsq_resnet18_8xb16_cifar10.py => lsq_openvino_resnet18_8xb16_cifar10.py} (68%) delete mode 100644 configs/quantization/qat/lsq_resnet18_8xb32_in1k.py create mode 100644 mmrazor/models/quantizers/openvino_quantizer.py create mode 100644 tools/debug.py diff --git a/configs/quantization/ptq/demo.py b/configs/quantization/ptq/demo.py deleted file mode 100644 index af6a0a5df..000000000 --- a/configs/quantization/ptq/demo.py +++ /dev/null @@ -1 +0,0 @@ -_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] diff --git a/configs/quantization/ptq/adaround.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb16_cifar10.py similarity index 68% rename from configs/quantization/ptq/adaround.py rename to configs/quantization/ptq/ptq_openvino_resnet18_8xb16_cifar10.py index 78157c61a..bb6dbc778 100644 --- a/configs/quantization/ptq/adaround.py +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb16_cifar10.py @@ -2,22 +2,14 @@ test_cfg = dict( type='mmrazor.PTQLoop', - - # reconstruction_cfg=dict( - # pattern='layer', - # loss=dict( - # type='mmrazor.AdaRoundLoss', - # iters=20000 - # ) - # ) ) model = dict( _delete_=True, - type='mmrazor.GeneralQuant', + type='mmrazor.MMArchitectureQuant', architecture=_base_.model, quantizer=dict( - type='mmrazor.CustomQuantizer', + type='mmrazor.OpenvinoQuantizer', is_qat=False, skipped_methods=[ 'mmcls.models.heads.ClsHead._get_loss', @@ -27,16 +19,16 @@ qtype='affine', w_observer=dict(type='mmrazor.MSEObserver'), a_observer=dict(type='mmrazor.EMAMSEObserver'), - w_fake_quant=dict(type='mmrazor.AdaRoundFakeQuantize'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), a_fake_quant=dict(type='mmrazor.FakeQuantize'), w_qscheme=dict( - bit=2, - is_symmetry=False, + bit=8, + is_symmetry=True, is_per_channel=True, is_pot_scale=False, ), a_qscheme=dict( - bit=4, + bit=8, is_symmetry=False, is_per_channel=False, is_pot_scale=False), diff --git a/configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py similarity index 68% rename from configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py rename to configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py index 412a6fd87..8076769a9 100644 --- a/configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py @@ -1,24 +1,16 @@ _base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] resnet = _base_.model -pretrained_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth' # noqa: E501 +float_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth' # noqa: E501 model = dict( _delete_=True, _scope_='mmrazor', - type='GeneralQuant', - data_preprocessor=dict( - type='mmcls.ClsDataPreprocessor', - num_classes=10, - # RGB format normalization parameters - mean=[125.307, 122.961, 113.8575], - std=[51.5865, 50.847, 51.255], - # loaded images are already RGB format - to_rgb=False), + type='MMArchitectureQuant', architecture=resnet, - pretrained_ckpt=pretrained_ckpt, + float_checkpoint=float_ckpt, quantizer=dict( - type='CustomQuantizer', + type='OpenvinoQuantizer', skipped_methods=[ 'mmcls.models.heads.ClsHead._get_loss', 'mmcls.models.heads.ClsHead._get_predictions' @@ -31,8 +23,8 @@ a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), w_qscheme=dict( bit=8, - is_symmetry=False, - is_per_channel=False, + is_symmetry=True, + is_per_channel=True, is_pot_scale=False, ), a_qscheme=dict( @@ -55,7 +47,7 @@ end=100) model_wrapper_cfg = dict( - type='mmrazor.GeneralQuantDDP', + type='mmrazor.MMArchitectureQuantDDP', broadcast_buffers=False, find_unused_parameters=True) @@ -63,8 +55,7 @@ train_cfg = dict( _delete_=True, type='mmrazor.QATEpochBasedLoop', - by_epoch=True, max_epochs=100, val_interval=1) val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') -test_cfg = val_cfg +# test_cfg = val_cfg diff --git a/configs/quantization/qat/lsq_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_resnet18_8xb32_in1k.py deleted file mode 100644 index a0885a52a..000000000 --- a/configs/quantization/qat/lsq_resnet18_8xb32_in1k.py +++ /dev/null @@ -1,75 +0,0 @@ -_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] - -train_cfg = dict( - _delete_=True, - type='mmrazor.QATEpochBasedLoop', - max_epochs=_base_.train_cfg.max_epochs) - -resnet = _base_.model -ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 -resnet.init_cfg = dict(type='Pretrained', checkpoint=ckpt) - -model = dict( - _delete_=True, - _scope_='mmrazor', - type='GeneralQuant', - # data_preprocessor = dict( - # num_classes=1000, - # # RGB format normalization parameters - # mean=[123.675, 116.28, 103.53], - # std=[58.395, 57.12, 57.375], - # # convert image from BGR to RGB - # to_rgb=True, - # ), - architecture=resnet, - quantizer=dict( - type='CustomQuantizer', - skipped_methods=[ - 'mmcls.models.heads.ClsHead._get_loss', - 'mmcls.models.heads.ClsHead._get_predictions' - ], - qconfig=dict( - qtype='affine', - w_observer=dict(type='mmrazor.MinMaxObserver'), - a_observer=dict(type='mmrazor.EMAMinMaxObserver'), - w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - w_qscheme=dict( - bit=8, - is_symmetry=False, - is_per_channel=False, - is_pot_scale=False, - ), - a_qscheme=dict( - bit=8, - is_symmetry=False, - is_per_channel=False, - is_pot_scale=False), - ))) - -optim_wrapper = dict( - optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001)) - -# learning policy -param_scheduler = dict( - _delete_=True, - type='CosineAnnealingLR', - T_max=100, - by_epoch=True, - begin=0, - end=100) - -default_hooks = dict( - checkpoint=dict( - type='CheckpointHook', - interval=5, - max_keep_ckpts=3, - out_dir='/mnt/petrelfs/caoweihan.p/training_ckpt/quant')) - -model_wrapper_cfg = dict( - type='mmrazor.GeneralQuantDDP', - broadcast_buffers=False, - find_unused_parameters=False) - -val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') -test_cfg = val_cfg diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index a2d5d383b..bca61a563 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -1,24 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. -import copy -import os from typing import Dict, List, Optional, Sequence, Tuple, Union -import numpy as np import torch from mmengine.evaluator import Evaluator -from mmengine.registry import MODELS -from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop, autocast -from torch.ao.quantization import disable_observer +from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.ao.quantization import (disable_observer, enable_fake_quant, + enable_observer) from torch.nn.intrinsic.qat import freeze_bn_stats from torch.utils.data import DataLoader -from mmrazor.models.task_modules import (ModuleInputsRecorder, - ModuleOutputsRecorder, - RecorderManager) from mmrazor.registry import LOOPS -from .utils import extract_blocks, extract_layers, extract_subgraph - -_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear) @LOOPS.register_module() @@ -59,18 +50,14 @@ def __init__( self.disable_observer_begin = disable_observer_begin self.freeze_bn_begin = freeze_bn_begin - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - qat_metrics = dict() - for key, value in metrics.items(): - qat_key = 'qat.' + key - ori_key = 'original.' + key - qat_metrics[qat_key] = value - self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None) + def run(self) -> torch.nn.Module: + """Launch training.""" + self.runner.call_hook('before_train') while self._epoch < self._max_epochs: # state: observer_enabled, fakequant_enabled - self.runner.model.state = (True, True) + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(enable_observer) self.run_epoch() self._decide_current_val_interval() @@ -78,8 +65,8 @@ def __init__( and self._epoch >= self.val_begin and self._epoch % self.val_interval == 0): # observer disabled during evaluation - self.runner.model.state = (False, True) - self.runner.model.sync_param() + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(disable_observer) self.runner.val_loop.run() self.runner.call_hook('after_train') @@ -185,8 +172,8 @@ def run_iter(self, idx, data_batch: Sequence[dict], model): self.runner.call_hook( 'before_val_iter', batch_idx=idx, data_batch=data_batch) # outputs should be sequence of BaseDataElement - with autocast(enabled=self.fp16): - outputs = model.val_step(data_batch) + + outputs = model.val_step(data_batch) self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( 'after_val_iter', @@ -219,15 +206,13 @@ def run(self) -> dict: self.runner.call_hook('before_test') self.runner.call_hook('before_test_epoch') self.runner.model.eval() - self.runner.model.state = (True, False) + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(enable_observer) for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - - self.runner.call_hook('after_test_epoch', metrics=metrics) + self.runner.call_hook('after_test_epoch', metrics=None) self.runner.call_hook('after_test') # todo: hard code to save checkpoint on disk @@ -238,80 +223,10 @@ def run(self) -> dict: save_optimizer=False, save_param_scheduler=False) - return metrics - - @torch.no_grad() - def run_iter(self, idx, data_batch: Sequence[dict]) -> None: - """Iterate one mini-batch. - - Args: - data_batch (Sequence[dict]): Batch of data from dataloader. - """ - self.runner.call_hook( - 'before_test_iter', batch_idx=idx, data_batch=data_batch) - # predictions should be sequence of BaseDataElement - - outputs = self.runner.model.calibrate_step(data_batch) - - self.runner.call_hook( - 'after_test_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) - - -# TODO refactor to supoort DDP -@LOOPS.register_module() -class AdaRoundLoop(TestLoop): - """`TestLoop` for Post Training Quantization. + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(disable_observer) - Args: - runner (Runner): A reference of runner - dataloader (Dataloader or dict): An iterator to generate one batch of - dataset each iteration. - evaluator (Evaluator or dict or list): Used for computing metrics. - calibrate_dataloader (Dataloader or dict, optional): A dataloader - object or a dict to build a dataloader for calibration. Defaults - to None. - batch_num (Optional[int], optional): Total calibration batches. - Defaults to None. - reconstruction_cfg (Optional[Dict], optional): Model reconstruction - configuration. Defaults to None. - fp16 (bool, optional): Enable FP16 training mode. Defaults to False. - """ - - def __init__(self, - runner, - dataloader: Union[DataLoader, Dict], - evaluator: Union[Evaluator, Dict, List], - fp16: bool = False): - super().__init__(runner, dataloader, evaluator, fp16) - - def run(self) -> None: - """Launch test.""" - self.runner.call_hook('before_test') - self.runner.call_hook('before_test_epoch') - self.runner.model.eval() - self.runner.model.state = (1, 0) - - for idx, data_batch in enumerate(self.dataloader): - self.run_iter(idx, data_batch) - - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - - self.runner.call_hook('after_test_epoch', metrics=metrics) - self.runner.call_hook('after_test') - - # todo: hard code to save checkpoint on disk - self.runner.save_checkpoint( - self.runner.work_dir, - 'checkpoint_after_ptq.pth', - file_client_args=None, - save_optimizer=False, - save_param_scheduler=False) - - return metrics + return self.runner.val_loop.run() @torch.no_grad() def run_iter(self, idx, data_batch: Sequence[dict]) -> None: @@ -322,208 +237,11 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: """ self.runner.call_hook( 'before_test_iter', batch_idx=idx, data_batch=data_batch) - # predictions should be sequence of BaseDataElement - outputs = self.runner.model.calibrate_step(data_batch) + _ = self.runner.model.calibrate_step(data_batch) self.runner.call_hook( 'after_test_iter', batch_idx=idx, data_batch=data_batch, - outputs=outputs) - - -# TODO refactor to supoort DDP -@LOOPS.register_module() -class AdaRoundLoop(TestLoop): - """`TestLoop` for Post Training Quantization. - - Args: - runner (Runner): A reference of runner - dataloader (Dataloader or dict): An iterator to generate one batch of - dataset each iteration. - evaluator (Evaluator or dict or list): Used for computing metrics. - calibrate_dataloader (Dataloader or dict, optional): A dataloader - object or a dict to build a dataloader for calibration. Defaults - to None. - batch_num (Optional[int], optional): Total calibration batches. - Defaults to None. - reconstruction_cfg (Optional[Dict], optional): Model reconstruction - configuration. Defaults to None. - fp16 (bool, optional): Enable FP16 training mode. Defaults to False. - """ - - def __init__(self, - runner, - dataloader: Union[DataLoader, Dict], - evaluator: Union[Evaluator, Dict, List], - calibrate_dataloader: Optional[Union[DataLoader, - Dict]] = None, - batch_num: Optional[int] = None, - reconstruction_cfg: Optional[Dict] = None, - fp16: bool = False): - super().__init__(runner, dataloader, evaluator, fp16) - if isinstance(calibrate_dataloader, dict): - # Determine whether or not different ranks use different seed. - diff_rank_seed = runner._randomness_cfg.get( - 'diff_rank_seed', False) - self.calibrate_dataloader = runner.build_dataloader( - calibrate_dataloader, - seed=runner.seed, - diff_rank_seed=diff_rank_seed) - else: - self.calibrate_dataloader = calibrate_dataloader - - self.is_calibrate = True if calibrate_dataloader is not None else False - - if self.runner.distributed: - self.model = runner.model.module - else: - self.model = runner.model - - self.batch_num = batch_num - self.config = reconstruction_cfg - - def calibrate(self, calibrate_dataloader) -> None: - self.model.eval() - with torch.no_grad(): - for i, batch_data in enumerate(calibrate_dataloader): - if self.batch_num and i >= self.batch_num: - break - self.model.calib_step(batch_data) - - def _save_inter_result(self, - model, - dataloader, - slices, - store_input=True, - store_output=True): - recorders = {} - for s in slices: - node_l, node_r = s[:2] - if store_input: - recorders[node_l.target + '_input'] = ModuleInputsRecorder( - node_l.target) - if store_output: - recorders[node_r.target + '_output'] = ModuleOutputsRecorder( - node_r.target) - manager = RecorderManager(recorders) - manager.initialize(model) - - with torch.no_grad(): - with manager: - for i, batch_data in enumerate(dataloader): - if self.batch_num and i >= self.batch_num: - break - batch_data = self.model.data_preprocessor( - batch_data, False) - model(**batch_data) - return manager - - def sub_reconstruction(self, graphmodule, input_recorder, output_recorder, - config): - w_para = [] - for layer in graphmodule.modules(): - # import pdb - # pdb.set_trace() - if isinstance(layer, _ADAROUND_SUPPORT_TYPE): - weight_fake_quant = layer.weight_fake_quant - weight_fake_quant.init(layer.weight.data) - w_para += [weight_fake_quant.alpha] - - w_opt = torch.optim.Adam(w_para) - loss_func = MODELS.build(config.loss) - - for _ in range(config.loss.iters): - w_opt.zero_grad() - - data_size = len(input_recorder.data_buffer) - data_index = np.random.randint(0, data_size) - out_quant = graphmodule( - input_recorder.get_recorder_data(data_index)) - out_fp = output_recorder.get_recorder_data(data_index) - err = loss_func(graphmodule, out_quant, out_fp) - err.backward() - w_opt.step() - - for layer in graphmodule.modules(): - if isinstance(layer, _ADAROUND_SUPPORT_TYPE): - weight_fake_quant = layer.weight_fake_quant - layer.weight.data = weight_fake_quant.get_hard_value( - layer.weight.data) - weight_fake_quant.adaround = False - if isinstance(layer, torch.quantization.FakeQuantize) and hasattr( - layer, 'prob'): - # recover to promise that drop activation quantization only - # occurs at reconstruction phase - layer.prob = 1.0 - - def reconstruction(self, graphmodule, calibrate_dataloader, config): - assert isinstance(graphmodule, torch.fx.GraphModule) - graphmodule_fp = graphmodule - graphmodule_quant = copy.deepcopy(graphmodule) - - # get layers/blocks need to reconstructe - slices = [] - if config.pattern == 'layer': - slices = extract_layers( - graphmodule, layer_types=_ADAROUND_SUPPORT_TYPE) - elif config.pattern == 'block': - slices = extract_blocks(graphmodule) - else: - # TODO: add remind - raise NotImplementedError - - # save fp inputs and outputs of each layers - manager_fp = self._save_inter_result(graphmodule_fp, - self.calibrate_dataloader, slices) - - # extract subgraph_module - for s in slices: - sub_graphmodule = extract_subgraph(graphmodule_quant, s) - manager_quant = self._save_inter_result( - graphmodule_quant, - self.calibrate_dataloader, [s], - store_output=False) - input_index = s[0].target + '_input' - output_index = s[1].target + '_output' - input_recorder = manager_quant.get_recorder(input_index) - output_recorder = manager_fp.get_recorder(output_index) - self.sub_reconstruction(sub_graphmodule, input_recorder, - output_recorder, config) - - return graphmodule_quant - - def run(self) -> None: - """Launch test.""" - self.runner.call_hook('before_test') - self.runner.call_hook('before_test_epoch') - - self.model.eval() - self.model.prepare() - - if self.is_calibrate: - self.model.state = (1, 0) - self.calibrate(self.calibrate_dataloader) - - self.model.state = (1, 1) - - if self.config is not None: - self.model.architecture = self.reconstruction( - self.model.architecture, self.calibrate_dataloader, - self.config) - - self.model.convert() - - self.model.eval() - from torch.onnx import OperatorExportTypes - dummy_input = torch.randn([1, 3, 224, 224]) - onnx_path = os.path.join(self.runner.work_dir, 'quantizied.onnx') - torch.onnx.export( - self.model.architecture, - dummy_input, - onnx_path, - opset_version=11, - operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK) - - self.runner.call_hook('after_test') + outputs=None) diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index 89a11b899..26305b226 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -7,7 +7,7 @@ BigNAS, BigNASDDP, Darts, DartsDDP) from .pruning import DCFF, DMCP, DMCPDDP, SlimmableNetwork, SlimmableNetworkDDP from .pruning.ite_prune_algorithm import ItePruneAlgorithm -from .quantization import GeneralQuant +from .quantization import MMArchitectureQuant, MMArchitectureQuantDDP __all__ = [ 'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS', @@ -15,5 +15,6 @@ 'Darts', 'DartsDDP', 'DCFF', 'SelfDistill', 'DataFreeDistillation', 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'Autoformer', 'BigNAS', - 'BigNASDDP', 'DMCP', 'DMCPDDP', 'GeneralQuant' + 'BigNASDDP', 'DMCP', 'DMCPDDP', 'GeneralQuant', 'MMArchitectureQuant', + 'MMArchitectureQuantDDP' ] diff --git a/mmrazor/models/algorithms/quantization/__init__.py b/mmrazor/models/algorithms/quantization/__init__.py index 84c25bbc0..337717c01 100644 --- a/mmrazor/models/algorithms/quantization/__init__.py +++ b/mmrazor/models/algorithms/quantization/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base import GeneralQuant +from .base import MMArchitectureQuant, MMArchitectureQuantDDP -__all__ = ['GeneralQuant'] +__all__ = ['MMArchitectureQuant', 'MMArchitectureQuantDDP'] diff --git a/mmrazor/models/algorithms/quantization/base.py b/mmrazor/models/algorithms/quantization/base.py index c97d832ff..ad3c16340 100644 --- a/mmrazor/models/algorithms/quantization/base.py +++ b/mmrazor/models/algorithms/quantization/base.py @@ -20,7 +20,7 @@ @MODELS.register_module() -class GeneralQuant(BaseAlgorithm): +class MMArchitectureQuant(BaseAlgorithm): """General quantization. Args: @@ -39,13 +39,15 @@ class GeneralQuant(BaseAlgorithm): :class:`BaseModule`. """ + + def __init__(self, architecture, quantizer, - export_mode: str = 'predict', - qmodel_modes: List[str] = ['tensor', 'predict', 'loss'], data_preprocessor=None, - pretrained_ckpt: Optional[str] = None, + forward_modes = ('tensor', 'predict', 'loss'), + float_checkpoint: Optional[str] = None, + input_shapes=(1, 3, 224, 224), init_cfg=None): if data_preprocessor is None: @@ -53,35 +55,50 @@ def __init__(self, # The build process is in MMEngine, so we need to add scope here. data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor') super().__init__(architecture, data_preprocessor, init_cfg) - if pretrained_ckpt: - _ = load_checkpoint(self.architecture, pretrained_ckpt) + if float_checkpoint: + _ = load_checkpoint(self.architecture, float_checkpoint) self.architecture._is_init = True self.quantizer = MODELS.build(quantizer) - self._observers_enabled = True - self._fake_quants_enabled = True - self.export_mode = export_mode - self.qmodel_modes = qmodel_modes + self.input_shapes = input_shapes + self.forward_modes = forward_modes + self.qmodels = self._build_qmodels(self.architecture) - def sync_param(self): + self.sync_param('tensor') + + def sync_param(self, src_mode): def traverse(module, prefix): for name, child in module._modules.items(): if module is None: continue - module_name = f'{prefix}{name}' + child_name = f'{prefix}{name}' if isinstance(child, FakeQuantizeBase): for name, param in child.named_parameters(): - param.data.copy_(self.qmodels['loss'].state_dict() - [f'{module_name}.{name}']) + param_name = f'{child_name}.{name}' + src_param = src_state_dict[param_name] + if src_param.shape == param.shape: + param.data.copy_(src_param) + else: + requirs_grad = param.requires_grad + param.requires_grad = False + param.resize_(src_param.shape) + param.requires_grad = requirs_grad + param.data.copy_(src_param) for name, buffer in child.named_buffers(): - buffer.data.copy_(self.qmodels['loss'].state_dict() - [f'{module_name}.{name}']) + buffer_name = f'{child_name}.{name}' + src_buffer = src_state_dict[buffer_name] + if src_buffer.shape == buffer.shape: + buffer.data.copy_(src_buffer) + else: + buffer.resize_(src_buffer.shape) + buffer.data.copy_(src_buffer) else: - traverse(child, f'{module_name}.') + traverse(child, f'{child_name}.') - for mode in self.qmodel_modes: - if mode == 'loss': + src_state_dict = self.qmodels[src_mode].state_dict() + for mode in self.forward_modes: + if mode == src_mode: continue traverse(self.qmodels[mode], '') @@ -92,12 +109,17 @@ def _build_qmodels(self, model): self.quantizer._swap_ff_with_fxff(model) tracer = self.quantizer.tracer - for mode in self.qmodel_modes: + for mode in self.forward_modes: concrete_args = {'mode': mode} traced_graph = tracer.trace(model, concrete_args=concrete_args) - qmodel = build_graphmodule(model, traced_graph) - qmodels[mode] = self.quantizer.prepare(model, qmodel) + graph_mopdule = build_graphmodule(model, traced_graph) + observed_module = self.quantizer.prepare(model, graph_mopdule) + + qmodels[mode] = observed_module + + dummy_input = torch.randn(self.input_shapes) + qmodels['predict'](dummy_input, None, 'predict') return qmodels @@ -114,39 +136,11 @@ def forward(self, def calibrate_step(self, data): data = self.data_preprocessor(data, False) - self.state = (1, 0) return self._run_forward(data, mode='tensor') - def convert(self, mode='predict'): - qmodel = self.qmodels[self.export_mode] - self.qmodels[mode] = self.quantizer.convert(qmodel) - - @property - def state(self): - return (self._observers_enabled, self._fake_quants_enabled) - - @state.setter - def state(self, state: Tuple[bool, bool]): - observers_enabled, fake_quants_enabled = state - qmodel = self.qmodels[self.export_mode] - for submodule in qmodel.modules(): - if isinstance(submodule, torch.quantization.FakeQuantize): - if observers_enabled: - submodule.enable_observer() - else: - submodule.disable_observer() - - if fake_quants_enabled: - submodule.enable_fake_quant() - else: - submodule.disable_fake_quant() - - self._observers_enabled = observers_enabled - self._fake_quants_enabled = fake_quants_enabled - @MODEL_WRAPPERS.register_module() -class GeneralQuantDDP(MMDistributedDataParallel): +class MMArchitectureQuantDDP(MMDistributedDataParallel): """DDPwapper for GeneralQuant.""" def __init__(self, @@ -165,18 +159,5 @@ def __init__(self, def calibrate_step(self, data): return self.module.calibrate_step(data) - @property - def state(self): - return (self.module._observers_enabled, - self.module._fake_quants_enabled) - - @state.setter - def state(self, state: Tuple[bool]): - self.module.state = state - - def convert(self, mode='predict'): - self.module.convert(mode) - self.module.qmodels[mode].cuda() - - def sync_param(self): - self.module.sync_param() + def sync_param(self, src): + self.module.sync_param(src) diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py index 3e26631eb..a24898442 100644 --- a/mmrazor/models/fake_quants/lsq.py +++ b/mmrazor/models/fake_quants/lsq.py @@ -144,10 +144,5 @@ def forward(ctx, x, scale, zero_point, ch_axis, quant_min, quant_max, @staticmethod def symbolic(g, x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor): - return g.op( - '::FakeQuantizeLearnablePerchannelAffine', - x, - scale, - zero_point, - quant_min_i=quant_min, - quant_max_i=quant_max) + return g.op('::FakeQuantizeLearnablePerchannelAffine', x, scale, + zero_point, ch_axis, quant_min, quant_max) diff --git a/mmrazor/models/observers/lsq_observer.py b/mmrazor/models/observers/lsq_observer.py index d9b96d7a8..b543efe17 100644 --- a/mmrazor/models/observers/lsq_observer.py +++ b/mmrazor/models/observers/lsq_observer.py @@ -33,7 +33,7 @@ def forward(self, x_orig): x = x_orig.to(self.min_val.dtype) if self.ch_axis == -1: self.tensor_norm = x.abs().mean() - self.min_val, self.max_val = torch._aminmax(x) + self.min_val, self.max_val = torch.aminmax(x) else: # compute channel-wise mean x_dim = x.size() diff --git a/mmrazor/models/observers/mse.py b/mmrazor/models/observers/mse.py index f85abd902..a2b65a3a6 100644 --- a/mmrazor/models/observers/mse.py +++ b/mmrazor/models/observers/mse.py @@ -71,9 +71,8 @@ def mse_perchannel(self, new_max = x_max * (1.0 - (i * 0.01)) scale, zero_point = self._calculate_qparams(new_min, new_max) x_q = torch.fake_quantize_per_channel_affine( - x, scale, - zero_point.long() if _version_under_1100 else zero_point, - ch_axis, self.quant_min, self.quant_max) + x, scale, zero_point.int(), ch_axis, self.quant_min, + self.quant_max) score = self.lp_loss(x_q, x, reduce_dim) update_idx = (score < best_score) best_score[update_idx] = score[update_idx] @@ -87,7 +86,7 @@ def forward(self, x_orig): return x_orig x = x_orig.clone().detach().to(self.min_val.dtype) if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) min_val_cur, max_val_cur = self.mse( x, min_val_cur, max_val_cur, iter=95) else: @@ -131,7 +130,7 @@ def forward(self, x_orig): return x_orig x = x_orig.clone().detach().to(self.min_val.dtype) if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) min_val_cur, max_val_cur = self.mse( x, min_val_cur, max_val_cur, iter=95) else: diff --git a/mmrazor/models/quantizers/__init__.py b/mmrazor/models/quantizers/__init__.py index e56902eba..2741e2fd1 100644 --- a/mmrazor/models/quantizers/__init__.py +++ b/mmrazor/models/quantizers/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import CustomQuantizer +from .openvino_quantizer import OpenvinoQuantizer from .trt_quantizer import TensorRTQuantizer -__all__ = ['CustomQuantizer', 'TensorRTQuantizer'] +__all__ = ['CustomQuantizer', 'TensorRTQuantizer', 'OpenvinoQuantizer'] diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py index 6f1fb4e31..2dd3930fc 100644 --- a/mmrazor/models/quantizers/base.py +++ b/mmrazor/models/quantizers/base.py @@ -3,9 +3,11 @@ import torch from mmengine.model import BaseModule -from torch.ao.quantization import QConfig +from torch.ao.quantization import QConfig, enable_fake_quant from torch.ao.quantization.fx import prepare from torch.ao.quantization.quantize_fx import _convert_fx, _fuse_fx +from torch.nn.intrinsic.qat import modules as qat_fused_modules +from torch.nn.qat import modules as qat_modules from mmrazor.models.task_modules.tracer import CustomTracer from mmrazor.models.utils import (check_is_valid_convert_custom_config_dict, @@ -16,6 +18,25 @@ from mmrazor.structures.quantization import (CheckArgs, DefaultQconfigs, QuantizeScheme, SupportQtypes) +SUPPORT_QAT_MODULES = ( + qat_fused_modules.ConvBn1d, qat_fused_modules.ConvBn2d, + qat_fused_modules.ConvBn3d, qat_fused_modules.ConvBnReLU1d, + qat_fused_modules.ConvBnReLU2d, qat_fused_modules.ConvBnReLU3d, + qat_fused_modules.ConvReLU1d, qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvReLU3d, qat_fused_modules.LinearBn1d, + qat_fused_modules.LinearReLU, qat_modules.Conv1d, qat_modules.Conv2d, + qat_modules.Conv3d, qat_modules.Linear) + +MERGE_BN_MAPPINGS = { + qat_fused_modules.ConvBn1d: qat_modules.Conv1d, + qat_fused_modules.ConvBn2d: qat_modules.Conv2d, + qat_fused_modules.ConvBn3d: qat_modules.Conv3d, + qat_fused_modules.ConvBnReLU1d: qat_fused_modules.ConvReLU1d, + qat_fused_modules.ConvBnReLU2d: qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d, + qat_fused_modules.LinearBn1d: qat_modules.Linear +} + @MODELS.register_module() class CustomQuantizer(BaseModule): @@ -23,7 +44,6 @@ class CustomQuantizer(BaseModule): Args: qconfig (Dict, optional): QConfig. Defaults to DefaultQconfigs['default']. # noqa: E501 - is_qat (bool, optional): Is QAT ro not. Defaults to True. skipped_methods (List, optional): Skipped methods list for tracer. Defaults to None. prepare_custom_config_dict (Dict, optional): `PrepareCustomConfig` @@ -39,7 +59,6 @@ class CustomQuantizer(BaseModule): def __init__(self, qconfig: Dict = DefaultQconfigs['default'], - is_qat: bool = True, skipped_methods: List = None, prepare_custom_config_dict: Dict = None, convert_custom_config_dict: Dict = None, @@ -73,7 +92,6 @@ def __init__(self, self.convert_custom_config_dict) check_is_valid_qconfig_dict(self.equalization_qconfig_dict) - self.is_qat = is_qat self.skipped_methods = skipped_methods self._remove_qconfig = _remove_qconfig self.tracer = self.build_tracer() @@ -90,9 +108,8 @@ def prepare(self, model, graph_module): prepared = prepare( graph_module, self.qconfig_dict, - self.is_qat, + True, self.tracer.node_name_to_scope, - prepare_custom_config_dict=self.prepare_custom_config_dict, equalization_qconfig_dict=self.equalization_qconfig_dict ) # type: ignore[operator] @@ -189,9 +206,41 @@ def build_tracer(self): return tracer def fuse_model(self, graph_module): - if not self.is_qat: - graph_module.eval() - - graph_module = _fuse_fx(graph_module, self.is_qat, + graph_module = _fuse_fx(graph_module, True, self.prepare_custom_config_dict) return graph_module + + def post_process_weight_fakequant(self, + observed_module, + keep_fake_quant=False): + + def traverse(module): + + for name, child in module.named_children(): + if isinstance(child, SUPPORT_QAT_MODULES): + weight_fakequant = child.weight_fake_quant + child.weight.data = weight_fakequant(child.weight.data) + + float_child = child.to_float() + + if keep_fake_quant: + for m in float_child.modules(): + setattr(m, 'qconfig', self.qconfig_dict['']) + + if type(child) in MERGE_BN_MAPPINGS: + cls = MERGE_BN_MAPPINGS[type(child)] + new_child = cls.from_float(float_child) + else: + new_child = child.from_float(float_child) + + new_child.weight_fake_quant(new_child.weight) + else: + new_child = float_child + setattr(module, name, new_child) + else: + traverse(child) + observed_module.apply(enable_fake_quant) + traverse(observed_module) + + def prepare_for_mmdeploy(self, model): + raise NotImplementedError diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py new file mode 100644 index 000000000..7bf067593 --- /dev/null +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.ao.quantization import disable_observer +from mmrazor.registry import MODELS +from mmrazor.structures.quantization import DefaultQconfigs +from mmrazor.models.task_modules.tracer.fx.custom_tracer import build_graphmodule +from .base import CustomQuantizer + + +@MODELS.register_module() +class OpenvinoQuantizer(CustomQuantizer): + """Quantizer for Openvino backend.""" + + support_bits = [8] + support_w_mode = ['per_channel'] + support_a_mode = ['per_tensor'] + + + def __init__(self, + qconfig, + is_qat=True, + skipped_methods=None, + prepare_custom_config_dict=None, + convert_custom_config_dict=None, + equalization_qconfig_dict=None, + _remove_qconfig=True, + init_cfg=None): + super().__init__(qconfig, is_qat, skipped_methods, + prepare_custom_config_dict, + convert_custom_config_dict, equalization_qconfig_dict, + _remove_qconfig, init_cfg) + + + + + def prepare_for_mmdeploy(self, model, dummy_input=None, checkpoint=None): + + + self._swap_ff_with_fxff(model) + graph = self.tracer.trace(model) + graph_module = build_graphmodule(model, graph) + observed_model = self.prepare(model, graph_module) + + + self.post_process_weight_fakequant(observed_model, keep_fake_quant=True) + + if dummy_input is not None: + observed_model(dummy_input) + + observed_model.apply(disable_observer) + + + return observed_model + + + + + + diff --git a/mmrazor/models/quantizers/trt_quantizer.py b/mmrazor/models/quantizers/trt_quantizer.py index 9dbe9f594..f2c46844a 100644 --- a/mmrazor/models/quantizers/trt_quantizer.py +++ b/mmrazor/models/quantizers/trt_quantizer.py @@ -1,4 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.ao.quantization import disable_observer + +from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ + build_graphmodule from mmrazor.registry import MODELS from mmrazor.structures.quantization import DefaultQconfigs from .base import CustomQuantizer @@ -8,6 +13,10 @@ class TensorRTQuantizer(CustomQuantizer): """Quantizer for TensorRT backend.""" + support_bits = [8] + support_w_mode = ['per_channel'] + support_a_mode = ['per_tensor'] + def __init__(self, qconfig=DefaultQconfigs['tensorrt'], is_qat=True, @@ -21,3 +30,19 @@ def __init__(self, prepare_custom_config_dict, convert_custom_config_dict, equalization_qconfig_dict, _remove_qconfig, init_cfg) + + def prepare_for_mmdeploy(self, model, dummy_input=None, checkpoint=None) + + graph = self.tracer.trace(model) + graph_module = build_graphmodule(model, graph) + observed_model = self.prepare(model, graph_module) + + observed_model(torch.randn(1, 3, 224, 224)) + + self.post_process_weight_fakequant(observed_model) + if dummy_input is not None: + observed_model(dummy_input) + + observed_model.apply(disable_observer) + + return observed_model diff --git a/mmrazor/models/task_modules/tracer/__init__.py b/mmrazor/models/task_modules/tracer/__init__.py index 838f20164..5ba623f5e 100644 --- a/mmrazor/models/task_modules/tracer/__init__.py +++ b/mmrazor/models/task_modules/tracer/__init__.py @@ -2,8 +2,8 @@ from .backward_tracer import BackwardTracer from .channel_analyzer import ChannelAnalyzer # from .razor_tracer import RazorFxTracer -from .fx import (CustomTracer, UntracedMethodRegistry, custom_symbolic_trace, - prepare_graph_module) +from .fx import (CustomTracer, UntracedMethodRegistry, build_graphmodule, + custom_symbolic_trace) from .loss_calculator import * # noqa: F401,F403 from .parsers import * # noqa: F401,F403 from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode, @@ -13,5 +13,5 @@ 'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode', 'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode', 'ChannelAnalyzer', 'CustomTracer', 'UntracedMethodRegistry', - 'custom_symbolic_trace', 'prepare_graph_module' + 'custom_symbolic_trace', 'build_graphmodule' ] diff --git a/mmrazor/models/task_modules/tracer/fx/__init__.py b/mmrazor/models/task_modules/tracer/fx/__init__.py index 998e9ffe1..1190f945c 100644 --- a/mmrazor/models/task_modules/tracer/fx/__init__.py +++ b/mmrazor/models/task_modules/tracer/fx/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .custom_tracer import (CustomTracer, UntracedMethodRegistry, - custom_symbolic_trace, build_graphmodule) + build_graphmodule, custom_symbolic_trace) __all__ = [ 'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace', diff --git a/mmrazor/models/utils/quantization_util.py b/mmrazor/models/utils/quantization_util.py index 376096b67..5593572ce 100644 --- a/mmrazor/models/utils/quantization_util.py +++ b/mmrazor/models/utils/quantization_util.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Set import torch +import torch.distributed as dist class PerChannelLoadHook: @@ -40,25 +41,15 @@ def close(self): self.hook.remove() -USE_LINK = False USE_DDP = False -try: - import spring.linklink as link - assert link.is_initialized() - USE_LINK = True -except (ModuleNotFoundError, AssertionError): - import torch.distributed as dist - if torch.distributed.is_initialized(): - USE_DDP = True +if torch.distributed.is_initialized(): + USE_DDP = True def sync_tensor(tensor): - if USE_LINK: - if tensor.is_cuda is True: - tensor.data = tensor.data / link.get_world_size() - link.allreduce(tensor.data) - elif USE_DDP: + + if USE_DDP: tensor.data = tensor.data / dist.get_world_size() dist.all_reduce(tensor.data) return tensor diff --git a/mmrazor/registry/registry.py b/mmrazor/registry/registry.py index 6aa3f192d..7f915ee74 100644 --- a/mmrazor/registry/registry.py +++ b/mmrazor/registry/registry.py @@ -30,7 +30,6 @@ from mmengine.registry import \ WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS from mmengine.registry import Registry, build_from_cfg -from mmengine.runner import load_checkpoint def build_razor_model_from_cfg( diff --git a/tools/debug.py b/tools/debug.py new file mode 100644 index 000000000..5d594cff8 --- /dev/null +++ b/tools/debug.py @@ -0,0 +1,162 @@ +import os +import sys +import time +import numpy as np +from tqdm import tqdm + +import torch +from torch.ao.quantization import get_default_qconfig +from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx +import torch.nn as nn +from torch.utils.data import DataLoader + +import torchvision +from torchvision import datasets +from torchvision.models import resnet18, ResNet18_Weights +import torchvision.transforms as transforms + +# Set up warnings +import warnings +warnings.filterwarnings( + action='ignore', + category=DeprecationWarning, + module=r'.*' +) +warnings.filterwarnings( + action='default', + module=r'torch.ao.quantization' +) + +# Specify random seed for repeatable results +_ = torch.manual_seed(191009) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def evaluate(model, criterion, data_loader, use_cuda=False): + model.eval() + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + cnt = 0 + with torch.no_grad(): + for image, target in data_loader: + if use_cuda: + image = image.cuda() + target = target.cuda() + output = model(image) + loss = criterion(output, target) + cnt += 1 + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + top1.update(acc1[0], image.size(0)) + top5.update(acc5[0], image.size(0)) + print('') + + return top1, top5 + +def print_size_of_model(model): + if isinstance(model, torch.jit.RecursiveScriptModule): + torch.jit.save(model, "temp.p") + else: + torch.jit.save(torch.jit.script(model), "temp.p") + print("Size (MB):", os.path.getsize("temp.p")/1e6) + os.remove("temp.p") + +def prepare_data_loaders(data_path, batch_size=4): + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + trainset = torchvision.datasets.CIFAR10(data_path, train=True, transform=transform) + trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) + testset = torchvision.datasets.CIFAR10(data_path, train=False, transform=transform) + testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2) + + return trainloader, testloader + +data_path = '/nvme/dataset/cifar10/' +saved_model_dir = '../experiments/valid_convert' +os.makedirs(saved_model_dir, exist_ok=True) + +_, data_loader_test = prepare_data_loaders(data_path) +criterion = nn.CrossEntropyLoss() +weights = ResNet18_Weights.DEFAULT +float_model = resnet18(weights=weights).cuda() +float_model.eval() + + +# deepcopy the model since we need to keep the original model around +import copy +model_to_quantize = copy.deepcopy(float_model) +model_to_quantize.eval() +qconfig = get_default_qconfig("fbgemm") +qconfig_dict = {"": qconfig} +prepared_model = prepare_fx(model_to_quantize, qconfig_dict) +# print(prepared_model.graph) + +def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in tqdm(data_loader): + model(image.cuda()) +calibrate(prepared_model, data_loader_test) + +quantized_model = convert_fx(prepared_model) +# print(quantized_model) + +print("Size of model before quantization") +print_size_of_model(float_model) +print("Size of model after quantization") +print_size_of_model(quantized_model) +top1, top5 = evaluate(float_model, criterion, data_loader_test, use_cuda=True) +print("[float model] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg)) +top1, top5 = evaluate(quantized_model, criterion, data_loader_test) +print("[before serilaization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg)) + +fx_graph_mode_model_file_path = os.path.join(saved_model_dir, "resnet18_fx_graph_mode_quantized.pth") + +# save with state_dict +torch.save(quantized_model.state_dict(), fx_graph_mode_model_file_path) +import copy +model_to_quantize = copy.deepcopy(float_model) +prepared_model = prepare_fx(model_to_quantize, {"": qconfig}) +loaded_quantized_model = convert_fx(prepared_model) +loaded_quantized_model.load_state_dict(torch.load(fx_graph_mode_model_file_path)) +top1, top5 = evaluate(loaded_quantized_model, criterion, data_loader_test) +print("[after serilaization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg)) +print('pipeline finish') \ No newline at end of file From 486e3152bc8d1994bd0ef40fdde4ed241c49cf24 Mon Sep 17 00:00:00 2001 From: L-Icarus <30308843+L-Icarus@users.noreply.github.com> Date: Thu, 22 Dec 2022 21:28:35 +0800 Subject: [PATCH 04/44] CodeCamp #132 add MinMaxFloorObserver (#376) * add minmaxfloor_observer.py * add MinMaxFloorObserver and normative docstring * add test for MinMaxFloorObserver --- mmrazor/models/observers/__init__.py | 6 +- .../models/observers/minmaxfloor_observer.py | 88 +++++++++++++++++++ mmrazor/models/quantizers/trt_quantizer.py | 2 +- .../test_observers/test_observer.py | 83 +++++++++++++++++ 4 files changed, 177 insertions(+), 2 deletions(-) create mode 100644 mmrazor/models/observers/minmaxfloor_observer.py diff --git a/mmrazor/models/observers/__init__.py b/mmrazor/models/observers/__init__.py index eac6371e2..004345625 100644 --- a/mmrazor/models/observers/__init__.py +++ b/mmrazor/models/observers/__init__.py @@ -1,6 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .lsq_observer import LSQObserver from .minmax import EMAMinMaxObserver, MinMaxObserver +from .minmaxfloor_observer import MinMaxFloorObserver from .mse import MSEObserver -__all__ = ['MinMaxObserver', 'MSEObserver', 'EMAMinMaxObserver', 'LSQObserver'] +__all__ = [ + 'MinMaxObserver', 'MSEObserver', 'EMAMinMaxObserver', 'LSQObserver', + 'MinMaxFloorObserver' +] diff --git a/mmrazor/models/observers/minmaxfloor_observer.py b/mmrazor/models/observers/minmaxfloor_observer.py new file mode 100644 index 000000000..231149ad4 --- /dev/null +++ b/mmrazor/models/observers/minmaxfloor_observer.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from torch.ao.quantization.observer import UniformQuantizationObserverBase + +from mmrazor.registry import MODELS +from ..utils import _is_float_qparams, _is_symmetric_quant, sync_tensor + + +@MODELS.register_module() +class MinMaxFloorObserver(UniformQuantizationObserverBase): + """Calculate minmax of whole calibration dataset with floor but round. + + Args: + dtype: Quantized data type. + qscheme: Quantization scheme to be used. + reduce_range: Reduces the range of the quantized data type by 1 bit. + This is sometimes required to avoid instruction overflow. + quant_min: Minimum quantization value. If unspecified, + it will follow the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, + it will follow the 8-bit setup. + eps: Epsilon value for float32, Defaults to + `torch.finfo(torch.float32).eps`. + """ + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps) -> None: + super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, + factory_kwargs, eps) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer('min_val', + torch.tensor(float('inf'), **factory_kwargs)) + self.register_buffer('max_val', + torch.tensor(float('-inf'), **factory_kwargs)) + if (self.qscheme == torch.per_tensor_symmetric and self.reduce_range + and self.dtype == torch.quint8): + raise NotImplementedError('Cannot reduce range for symmetric \ + quantization for quint8') + + def forward(self, x_orig: torch.Tensor) -> torch.Tensor: + """Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = torch.min(min_val_cur, self.min_val) + max_val = torch.max(max_val_cur, self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculates the quantization parameters.""" + scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) + if not _is_symmetric_quant(self.qscheme) and not _is_float_qparams( + self.qscheme): + scale = (self.max_val - self.min_val) / float(self.quant_max - + self.quant_min) + scale = torch.max(scale, self.eps) + zero_point = self.quant_min - torch.floor(self.min_val / scale).to( + torch.int) + zero_point = torch.clamp(zero_point, self.quant_min, + self.quant_max) + sync_tensor(scale) + sync_tensor(zero_point) + return scale, zero_point + + @torch.jit.export + def extra_repr(self) -> str: + return 'min_val={}, max_val={}'.format(self.min_val, self.max_val) + + @torch.jit.export + def reset_min_max_vals(self) -> None: + """Resets the min/max values.""" + self.min_val.copy_(torch.tensor(float('inf'))) + self.max_val.copy_(torch.tensor(float('-inf'))) diff --git a/mmrazor/models/quantizers/trt_quantizer.py b/mmrazor/models/quantizers/trt_quantizer.py index f2c46844a..26c44b665 100644 --- a/mmrazor/models/quantizers/trt_quantizer.py +++ b/mmrazor/models/quantizers/trt_quantizer.py @@ -31,7 +31,7 @@ def __init__(self, convert_custom_config_dict, equalization_qconfig_dict, _remove_qconfig, init_cfg) - def prepare_for_mmdeploy(self, model, dummy_input=None, checkpoint=None) + def prepare_for_mmdeploy(self, model, dummy_input=None, checkpoint=None): graph = self.tracer.trace(model) graph_module = build_graphmodule(model, graph) diff --git a/tests/test_models/test_observers/test_observer.py b/tests/test_models/test_observers/test_observer.py index ca39ecfbd..5b99cb0fc 100644 --- a/tests/test_models/test_observers/test_observer.py +++ b/tests/test_models/test_observers/test_observer.py @@ -1,7 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy +import unittest from unittest import TestCase +import torch import torch.nn as nn +from torch.ao.quantization import QConfig +from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx +from torchvision.models import resnet18 + +from mmrazor.models.observers import MinMaxFloorObserver class ToyModel(nn.Module): @@ -36,3 +44,78 @@ def test_forward(self): class TestLSQObserver(TestMinMaxObserver): pass + + +class TestMinMaxFloorObserver(TestMinMaxObserver): + + def setUp(self) -> None: + self.model_fp = resnet18() + self.w_qscheme = dict( + dtype=torch.qint8, qscheme=torch.per_tensor_affine) + self.a_qscheme = dict( + dtype=torch.quint8, qscheme=torch.per_tensor_affine) + + def test_init(self) -> None: + with self.assertRaises(NotImplementedError): + _ = MinMaxFloorObserver( + dtype=torch.quint8, + qscheme=torch.per_tensor_symmetric, + reduce_range=True) + + def test_prepare(self) -> None: + flag = False + model_to_quantize = copy.deepcopy(self.model_fp) + model_to_quantize.eval() + qconfig_dict = { + '': + QConfig( + activation=MinMaxFloorObserver.with_args(**self.a_qscheme), + weight=MinMaxFloorObserver.with_args(**self.w_qscheme)) + } + prepared_model = prepare_fx(model_to_quantize, qconfig_dict) + for m in prepared_model.modules(): + if isinstance(m, MinMaxFloorObserver): + flag = True + break + self.assertTrue(flag) + + def test_convert(self) -> None: + flag = True + model_to_quantize = copy.deepcopy(self.model_fp) + model_to_quantize.eval() + qconfig_dict = { + '': + QConfig( + activation=MinMaxFloorObserver.with_args(**self.a_qscheme), + weight=MinMaxFloorObserver.with_args(**self.w_qscheme)) + } + prepared_model = prepare_fx(model_to_quantize, qconfig_dict) + prepared_model(torch.randn(1, 3, 224, 224)) + quantized_model = convert_fx(prepared_model) + for m in quantized_model.modules(): + if isinstance(m, MinMaxFloorObserver): + flag = False + break + self.assertTrue(flag) + + def test_states(self) -> None: + test_input = torch.Tensor([6., -8.]) + observer = MinMaxFloorObserver(**self.w_qscheme) + self.assertEqual( + [observer.min_val, observer.max_val], + [torch.tensor(float('inf')), + torch.tensor(float('-inf'))]) + observer.forward(test_input) + # per_tensor_affine + scale, zero_point = observer.calculate_qparams() + self.assertEqual(zero_point.item(), 18) + + def test_forward(self) -> None: + test_input = torch.Tensor([1., -1.]) + observer = MinMaxFloorObserver(**self.w_qscheme) + test_output = observer.forward(test_input) + self.assertIs(test_input, test_output) + + +if __name__ == '__main__': + unittest.main() From b57a11a9ce347b372ef63280526ed8aff13d7426 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Thu, 29 Dec 2022 14:30:37 +0800 Subject: [PATCH 05/44] Quant go (#409) * add torch observer * add torch fakequant * refactor base quantizer * add QConfigHander and QSchemeHander & finish quantizer_refactor_beta * passed ptq_pipeline * tmp-commit * fix loop and algorithm * delete fakequant * refactor code structure * remove lsq * valid ptq pipeline * wip * fix del functions * fix * fix lint and pytest Co-authored-by: HIT-cwh <2892770585@qq.com> --- ...tq_openvino_mbv2_8xb32_in1k_calib32xb32.py | 34 + .../ptq_openvino_resnet18_8xb16_cifar10.py | 35 - ...penvino_resnet18_8xb32_in1k_calib32xb32.py | 35 + ...penvino_resnet50_8xb32_in1k_calib32xb32.py | 35 + mmrazor/__init__.py | 2 +- mmrazor/engine/runner/quantization_loops.py | 24 +- mmrazor/engine/runner/utils/state.py | 19 - mmrazor/engine/runner/utils/subgraph.py | 61 -- .../algorithms/quantization/__init__.py | 2 +- .../{base.py => mm_architecture.py} | 26 +- mmrazor/models/fake_quants/__init__.py | 11 +- mmrazor/models/fake_quants/adaround.py | 98 --- mmrazor/models/fake_quants/base.py | 124 +--- mmrazor/models/fake_quants/lsq.py | 148 ----- mmrazor/models/fake_quants/qdrop.py | 44 -- .../models/fake_quants/torch_fake_quants.py | 36 ++ mmrazor/models/observers/__init__.py | 11 +- mmrazor/models/observers/base.py | 71 +- mmrazor/models/observers/lsq_observer.py | 59 -- mmrazor/models/observers/minmax.py | 97 --- .../models/observers/minmaxfloor_observer.py | 88 --- mmrazor/models/observers/mse.py | 155 ----- mmrazor/models/observers/torch_observers.py | 44 ++ mmrazor/models/quantizers/__init__.py | 13 +- .../models/quantizers/academic_quantizer.py | 106 +++ mmrazor/models/quantizers/base.py | 240 +------ mmrazor/models/quantizers/native_quantizer.py | 130 ++++ .../models/quantizers/openvino_quantizer.py | 122 ++-- .../models/quantizers/tensorrt_quantizer.py | 56 ++ mmrazor/models/quantizers/trt_quantizer.py | 48 -- .../models/task_modules/tracer/fx/__init__.py | 8 +- .../task_modules/tracer/fx/custom_tracer.py | 9 +- .../task_modules/tracer/fx/graph_utils.py | 138 ++++ mmrazor/models/utils/__init__.py | 5 +- mmrazor/models/utils/quantization_util.py | 233 +------ mmrazor/structures/__init__.py | 1 + mmrazor/structures/quantization/__init__.py | 6 +- .../quantization/backend_config/__init__.py | 21 + .../quantization/backend_config/academic.py | 45 ++ .../common_operator_config_utils.py | 607 ++++++++++++++++++ .../quantization/backend_config/mapping.py | 12 + .../quantization/backend_config/native.py | 137 ++++ .../quantization/backend_config/openvino.py | 75 +++ .../quantization/backend_config/tensorrt.py | 75 +++ .../quantization/backend_default_qconfigs.py | 46 -- mmrazor/structures/quantization/qconfig.py | 167 +++++ mmrazor/structures/quantization/qscheme.py | 68 -- .../test_observers/test_observer.py | 121 ---- tests/test_registry/test_registry.py | 22 +- tests/test_structures/test_qconfig.py | 110 ++++ tools/ckpt_demo.py | 13 - tools/debug.py | 162 ----- tools/model_converters/convert_quant_ckpt.py | 53 ++ tools/{ptq_calibrate.py => ptq.py} | 0 tools/tracer_demo.py | 93 --- 55 files changed, 2123 insertions(+), 2078 deletions(-) create mode 100644 configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py delete mode 100644 configs/quantization/ptq/ptq_openvino_resnet18_8xb16_cifar10.py create mode 100644 configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py create mode 100644 configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py delete mode 100644 mmrazor/engine/runner/utils/state.py delete mode 100644 mmrazor/engine/runner/utils/subgraph.py rename mmrazor/models/algorithms/quantization/{base.py => mm_architecture.py} (92%) delete mode 100644 mmrazor/models/fake_quants/adaround.py delete mode 100644 mmrazor/models/fake_quants/lsq.py delete mode 100644 mmrazor/models/fake_quants/qdrop.py create mode 100644 mmrazor/models/fake_quants/torch_fake_quants.py delete mode 100644 mmrazor/models/observers/lsq_observer.py delete mode 100644 mmrazor/models/observers/minmax.py delete mode 100644 mmrazor/models/observers/minmaxfloor_observer.py delete mode 100644 mmrazor/models/observers/mse.py create mode 100644 mmrazor/models/observers/torch_observers.py create mode 100644 mmrazor/models/quantizers/academic_quantizer.py create mode 100644 mmrazor/models/quantizers/native_quantizer.py create mode 100644 mmrazor/models/quantizers/tensorrt_quantizer.py delete mode 100644 mmrazor/models/quantizers/trt_quantizer.py create mode 100644 mmrazor/models/task_modules/tracer/fx/graph_utils.py create mode 100644 mmrazor/structures/quantization/backend_config/__init__.py create mode 100644 mmrazor/structures/quantization/backend_config/academic.py create mode 100644 mmrazor/structures/quantization/backend_config/common_operator_config_utils.py create mode 100644 mmrazor/structures/quantization/backend_config/mapping.py create mode 100644 mmrazor/structures/quantization/backend_config/native.py create mode 100644 mmrazor/structures/quantization/backend_config/openvino.py create mode 100644 mmrazor/structures/quantization/backend_config/tensorrt.py delete mode 100644 mmrazor/structures/quantization/backend_default_qconfigs.py create mode 100644 mmrazor/structures/quantization/qconfig.py delete mode 100644 mmrazor/structures/quantization/qscheme.py delete mode 100644 tests/test_models/test_observers/test_observer.py create mode 100644 tests/test_structures/test_qconfig.py delete mode 100644 tools/ckpt_demo.py delete mode 100644 tools/debug.py create mode 100644 tools/model_converters/convert_quant_ckpt.py rename tools/{ptq_calibrate.py => ptq.py} (100%) delete mode 100644 tools/tracer_demo.py diff --git a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py new file mode 100644 index 000000000..df667c141 --- /dev/null +++ b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py @@ -0,0 +1,34 @@ +_base_ = ['mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py'] + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=_base_.train_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + architecture=_base_.model, + float_checkpoint='/tmp/humu/mobilenet_v2_batch256_imagenet' + + '_20200708-3b2dc3af.pth', + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb16_cifar10.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb16_cifar10.py deleted file mode 100644 index bb6dbc778..000000000 --- a/configs/quantization/ptq/ptq_openvino_resnet18_8xb16_cifar10.py +++ /dev/null @@ -1,35 +0,0 @@ -_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] - -test_cfg = dict( - type='mmrazor.PTQLoop', -) - -model = dict( - _delete_=True, - type='mmrazor.MMArchitectureQuant', - architecture=_base_.model, - quantizer=dict( - type='mmrazor.OpenvinoQuantizer', - is_qat=False, - skipped_methods=[ - 'mmcls.models.heads.ClsHead._get_loss', - 'mmcls.models.heads.ClsHead._get_predictions' - ], - qconfig=dict( - qtype='affine', - w_observer=dict(type='mmrazor.MSEObserver'), - a_observer=dict(type='mmrazor.EMAMSEObserver'), - w_fake_quant=dict(type='mmrazor.FakeQuantize'), - a_fake_quant=dict(type='mmrazor.FakeQuantize'), - w_qscheme=dict( - bit=8, - is_symmetry=True, - is_per_channel=True, - is_pot_scale=False, - ), - a_qscheme=dict( - bit=8, - is_symmetry=False, - is_per_channel=False, - is_pot_scale=False), - ))) diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py new file mode 100644 index 000000000..56da13de9 --- /dev/null +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -0,0 +1,35 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +train_dataloader = dict(batch_size=32) + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=train_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + architecture=_base_.model, + float_checkpoint='/tmp/humu/resnet18_8xb32_in1k_20210831-fbbb1da6.pth', + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) diff --git a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py new file mode 100644 index 000000000..09e103bfc --- /dev/null +++ b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py @@ -0,0 +1,35 @@ +_base_ = ['mmcls::resnet/resnet50_8xb32_in1k.py'] + +train_dataloader = dict(batch_size=32) + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=train_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + architecture=_base_.model, + float_checkpoint='/tmp/humu/resnet50_8xb32_in1k_20210831-ea4938fc.pth', + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) diff --git a/mmrazor/__init__.py b/mmrazor/__init__.py index 74d91b8fa..8282de52d 100644 --- a/mmrazor/__init__.py +++ b/mmrazor/__init__.py @@ -48,7 +48,7 @@ def digit_version(version_str: str, length: int = 4): return tuple(release) -mmcv_minimum_version = '2.0.0rc1' +mmcv_minimum_version = '2.0.0rc3' mmcv_maximum_version = '2.0.0' mmcv_version = digit_version(mmcv.__version__) diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index bca61a563..2a0aa812f 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -198,31 +198,45 @@ def __init__(self, runner, dataloader: Union[DataLoader, Dict], evaluator: Union[Evaluator, Dict, List], + calibrate_dataloader: Union[DataLoader, Dict], + calibrate_steps=32, fp16: bool = False): super().__init__(runner, dataloader, evaluator, fp16) + if isinstance(calibrate_dataloader, dict): + # Determine whether or not different ranks use different seed. + diff_rank_seed = runner._randomness_cfg.get( + 'diff_rank_seed', False) + self.dataloader = runner.build_dataloader( + dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed) + else: + self.dataloader = dataloader + + self.calibrate_steps = calibrate_steps def run(self) -> dict: """Launch test.""" self.runner.call_hook('before_test') self.runner.call_hook('before_test_epoch') + self.runner.model.eval() self.runner.model.apply(enable_fake_quant) self.runner.model.apply(enable_observer) for idx, data_batch in enumerate(self.dataloader): + if idx == self.calibrate_steps: + break self.run_iter(idx, data_batch) - self.runner.call_hook('after_test_epoch', metrics=None) - self.runner.call_hook('after_test') - - # todo: hard code to save checkpoint on disk self.runner.save_checkpoint( self.runner.work_dir, - 'checkpoint_after_ptq.pth', + 'model_ptq.pth', file_client_args=None, save_optimizer=False, save_param_scheduler=False) + self.runner.call_hook('after_test_epoch', metrics=None) + self.runner.call_hook('after_test') + self.runner.model.apply(enable_fake_quant) self.runner.model.apply(disable_observer) diff --git a/mmrazor/engine/runner/utils/state.py b/mmrazor/engine/runner/utils/state.py deleted file mode 100644 index 2f6d602a5..000000000 --- a/mmrazor/engine/runner/utils/state.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmengine.logging import print_log -from torch.ao.quantization import FakeQuantize - - -# TODO: may be removed -def set_quant_state(model, enable_observer=True, enable_fake_quant=True): - for name, submodule in model.named_modules(): - if isinstance(submodule, FakeQuantize): - if enable_observer: - submodule.enable_observer() - else: - submodule.disable_observer() - if enable_fake_quant: - submodule.enable_fake_quant() - else: - submodule.disable_fake_quant() - print_log(f'Enable observer: {enable_observer}; \ - Enable fake quant: {enable_fake_quant}') diff --git a/mmrazor/engine/runner/utils/subgraph.py b/mmrazor/engine/runner/utils/subgraph.py deleted file mode 100644 index ea0f8837f..000000000 --- a/mmrazor/engine/runner/utils/subgraph.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy - -import torch.fx as fx - - -def extract_subgraph(graphmodule, block_slice): - subgraph = copy.deepcopy(graphmodule.graph) - block_start, block_end = block_slice[:2] - for node in subgraph.nodes: - if node.name == 'inputs': - input_node = node - if node.name == block_start.name: - node.replace_input_with(node.prev, input_node) - if node.name == block_end.name: - output_node = node - if node.op == 'output': - node.replace_input_with(node.prev, output_node) - subgraph.lint() - subgraph_module = fx.GraphModule(graphmodule, subgraph) - subgraph_module.graph.eliminate_dead_code() - subgraph_module.recompile() - return subgraph_module - - -def extract_blocks(graph, key_word='layer'): - block_slices = [] - block_slice = [] - pre_stage_index, pre_block_index = 0, 0 - cur_stage_index, cur_block_index = 0, 0 - for node in graph.nodes: - if key_word not in node.name: - continue - else: - items = node.name.split('_') - for i, item in enumerate(items): - if key_word in item: - cur_stage_index = int(item[5:]) - cur_block_index = int(items[i + 1]) - break - if (cur_block_index != pre_block_index) or (cur_stage_index != - pre_stage_index): - block_slice.append(node.prev) - if len(block_slice) == 2: - block_slices.append(block_slice) - block_slice = [] - block_slice.append(node) - - pre_stage_index, pre_block_index = cur_stage_index, cur_block_index - - return block_slices - - -def extract_layers(graphmodule, layer_types): - layer_slices = [] - for node in graphmodule.graph.nodes: - if node.op == 'call_module': - m = graphmodule.get_submodule(node.target) - if isinstance(m, layer_types): - layer_slices.append((node, node)) - return layer_slices diff --git a/mmrazor/models/algorithms/quantization/__init__.py b/mmrazor/models/algorithms/quantization/__init__.py index 337717c01..03a9538e2 100644 --- a/mmrazor/models/algorithms/quantization/__init__.py +++ b/mmrazor/models/algorithms/quantization/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base import MMArchitectureQuant, MMArchitectureQuantDDP +from .mm_architecture import MMArchitectureQuant, MMArchitectureQuantDDP __all__ = ['MMArchitectureQuant', 'MMArchitectureQuantDDP'] diff --git a/mmrazor/models/algorithms/quantization/base.py b/mmrazor/models/algorithms/quantization/mm_architecture.py similarity index 92% rename from mmrazor/models/algorithms/quantization/base.py rename to mmrazor/models/algorithms/quantization/mm_architecture.py index ad3c16340..c14aae08c 100644 --- a/mmrazor/models/algorithms/quantization/base.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -39,13 +39,11 @@ class MMArchitectureQuant(BaseAlgorithm): :class:`BaseModule`. """ - - def __init__(self, architecture, quantizer, data_preprocessor=None, - forward_modes = ('tensor', 'predict', 'loss'), + forward_modes=('tensor', 'predict', 'loss'), float_checkpoint: Optional[str] = None, input_shapes=(1, 3, 224, 224), init_cfg=None): @@ -58,15 +56,16 @@ def __init__(self, if float_checkpoint: _ = load_checkpoint(self.architecture, float_checkpoint) self.architecture._is_init = True + self.quantizer = MODELS.build(quantizer) self.input_shapes = input_shapes self.forward_modes = forward_modes self.qmodels = self._build_qmodels(self.architecture) - self.sync_param('tensor') + self.sync_qparams('predict') - def sync_param(self, src_mode): + def sync_qparams(self, src_mode): def traverse(module, prefix): for name, child in module._modules.items(): @@ -106,20 +105,19 @@ def _build_qmodels(self, model): qmodels = nn.ModuleDict() - self.quantizer._swap_ff_with_fxff(model) + self.quantizer.swap_ff_with_fxff(model) tracer = self.quantizer.tracer for mode in self.forward_modes: concrete_args = {'mode': mode} traced_graph = tracer.trace(model, concrete_args=concrete_args) - graph_mopdule = build_graphmodule(model, traced_graph) observed_module = self.quantizer.prepare(model, graph_mopdule) - qmodels[mode] = observed_module - - dummy_input = torch.randn(self.input_shapes) - qmodels['predict'](dummy_input, None, 'predict') + # import pdb + # pdb.set_trace() + # dummy_input = torch.randn(self.input_shapes) + # qmodels['predict'](dummy_input, None, 'predict') return qmodels @@ -136,7 +134,7 @@ def forward(self, def calibrate_step(self, data): data = self.data_preprocessor(data, False) - return self._run_forward(data, mode='tensor') + return self._run_forward(data, mode='predict') @MODEL_WRAPPERS.register_module() @@ -159,5 +157,5 @@ def __init__(self, def calibrate_step(self, data): return self.module.calibrate_step(data) - def sync_param(self, src): - self.module.sync_param(src) + def sync_qparams(self, src): + self.module.sync_qparams(src) diff --git a/mmrazor/models/fake_quants/__init__.py b/mmrazor/models/fake_quants/__init__.py index cea7708a2..9030660f6 100644 --- a/mmrazor/models/fake_quants/__init__.py +++ b/mmrazor/models/fake_quants/__init__.py @@ -1,10 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .adaround import AdaRoundFakeQuantize -from .base import FakeQuantize -from .lsq import LearnableFakeQuantize -from .qdrop import QDropFakeQuantize +from .base import BaseFakeQuantize +from .torch_fake_quants import register_torch_fake_quants -__all__ = [ - 'FakeQuantize', 'AdaRoundFakeQuantize', 'QDropFakeQuantize', - 'LearnableFakeQuantize' -] +__all__ = ['BaseFakeQuantize', 'register_torch_fake_quants'] diff --git a/mmrazor/models/fake_quants/adaround.py b/mmrazor/models/fake_quants/adaround.py deleted file mode 100644 index 9388f1aa4..000000000 --- a/mmrazor/models/fake_quants/adaround.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -from torch.nn.parameter import Parameter - -from mmrazor.registry import MODELS -from .base import FakeQuantize - -_version_under_1100 = int(torch.__version__.split('.')[1]) < 10 - - -@MODELS.register_module() -class AdaRoundFakeQuantize(FakeQuantize): - - def __init__(self, observer, **observer_kwargs): - super().__init__(observer, **observer_kwargs) - self.adaround = False - - def init(self, weight_tensor: torch.Tensor): - self.adaround = True - self.observer_enabled[0] = 0 - self.fake_quant_enabled[0] = 1 - - # self.soft_targets = False # delete this - self.gamma = -0.1 - self.zeta = 1.1 - self.init_alpha(x=weight_tensor.data.clone()) - - def init_alpha(self, x: torch.Tensor): - if self.ch_axis != -1: - new_shape = [1] * len(x.shape) - new_shape[self.ch_axis] = x.shape[self.ch_axis] - scale = self.scale.data.reshape(new_shape) - else: - scale = self.scale.data - x_floor = torch.floor(x / scale) - rest = (x / scale) - x_floor # rest of rounding [0, 1) - alpha = -torch.log((self.zeta - self.gamma) / - (rest - self.gamma) - 1) # => sigmoid(alpha) = rest - self.alpha = Parameter(alpha) - - def rectified_sigmoid(self): - """Function to generate rounding mask. - - Args: - x (torch.Tensor): - zeta (torch.Tensor): - gamma (torch.Tensor): - Returns: - torch.Tensor: - """ - return ((self.zeta - self.gamma) * torch.sigmoid(self.alpha) + - self.gamma).clamp(0, 1) - - def adaround_forward(self, x, hard_value=False): - if self.ch_axis != -1: - new_shape = [1] * len(x.shape) - new_shape[self.ch_axis] = x.shape[self.ch_axis] - scale = self.scale.reshape(new_shape) - zero_point = self.zero_point.reshape(new_shape) - x = torch.floor(x / scale) - if hard_value: - x += (self.alpha >= 0).float() - else: - x += self.rectified_sigmoid(self.alpha, self.zeta, self.gamma) - x += zero_point - x = torch.clamp(x, self.quant_min, self.quant_max) - x = (x - zero_point) * scale - return x - - def forward(self, X): - if self.observer_enabled[0] == 1: - self.activation_post_process(X.detach()) - _scale, _zero_point = self.calculate_qparams() - _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to( - self.zero_point.device) - if self.scale.shape != _scale.shape: - self.scale.resize_(_scale.shape) - self.zero_point.resize_(_zero_point.shape) - self.scale.copy_(_scale) - self.zero_point.copy_(_zero_point) - - if self.fake_quant_enabled[0] == 1: - if not self.adaround: - if self.is_per_channel: - X = torch.fake_quantize_per_channel_affine( - X, self.scale, - self.zero_point.long() - if _version_under_1100 else self.zero_point, - self.ch_axis, self.quant_min, self.quant_max) - else: - X = torch.fake_quantize_per_tensor_affine( - X, self.scale.item(), int(self.zero_point.item()), - self.quant_min, self.quant_max) - else: - if not hasattr(self, 'alpha'): - raise NotImplementedError - X = self.adaround_forward(X) - return X diff --git a/mmrazor/models/fake_quants/base.py b/mmrazor/models/fake_quants/base.py index 13f8a1e43..1d4c6dfe0 100644 --- a/mmrazor/models/fake_quants/base.py +++ b/mmrazor/models/fake_quants/base.py @@ -1,124 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import torch -from torch.ao.quantization import FakeQuantizeBase +from torch.ao.quantization import FakeQuantize -from mmrazor.models.utils import (_is_float_qparams, _is_per_channel, - _is_per_tensor, _is_symmetric_quant) -from mmrazor.registry import MODELS - - -@MODELS.register_module() -class FakeQuantize(FakeQuantizeBase): - - scale: torch.Tensor - zero_point: torch.Tensor - - def __init__(self, observer, **observer_kwargs): - super().__init__() - self.activation_post_process = observer(**observer_kwargs) - self.quant_min = self.activation_post_process.quant_min - self.quant_max = self.activation_post_process.quant_max - if _is_float_qparams(self.activation_post_process.qscheme): - zero_point_dtype = torch.float - else: - zero_point_dtype = torch.int - self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float)) - self.register_buffer('zero_point', - torch.tensor([0], dtype=zero_point_dtype)) - self.dtype = self.activation_post_process.dtype - self.qscheme = self.activation_post_process.qscheme - self.ch_axis = self.activation_post_process.ch_axis \ - if hasattr(self.activation_post_process, 'ch_axis') else -1 - assert _is_per_channel(self.qscheme) or \ - _is_per_tensor(self.qscheme), \ - 'Only per channel and per tensor quantization are supported in ' \ - 'fake quantize' + ' got qscheme: ' + str(self.qscheme) - self.is_per_channel = _is_per_channel(self.qscheme) - - bitrange = torch.tensor(self.quant_max - self.quant_min + 1).double() - self.bitwidth = int(torch.log2(bitrange).item()) - self.is_pot_scale = self.activation_post_process.is_pot_scale - self.is_symmetric_quant = _is_symmetric_quant(self.qscheme) - - @torch.jit.export - def calculate_qparams(self): - return self.activation_post_process.calculate_qparams() - - def forward(self, X): - if self.observer_enabled[0] == 1: - self.activation_post_process(X.detach()) - _scale, _zero_point = self.calculate_qparams() - _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to( - self.zero_point.device) - if self.scale.shape != _scale.shape: - self.scale.resize_(_scale.shape) - self.zero_point.resize_(_zero_point.shape) - self.scale.copy_(_scale) - self.zero_point.copy_(_zero_point) - - if self.fake_quant_enabled[0] == 1: - if self.is_per_channel: - X = torch.fake_quantize_per_channel_affine( - X, self.scale, self.zero_point, self.ch_axis, - self.activation_post_process.quant_min, - self.activation_post_process.quant_max) - else: - X = torch.fake_quantize_per_tensor_affine( - X, self.scale, self.zero_point, - self.activation_post_process.quant_min, - self.activation_post_process.quant_max) - return X - - @torch.jit.export - def extra_repr(self): - return 'fake_quant_enabled={}, observer_enabled={}, ' \ - 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ' \ - 'ch_axis={}, scale={}, zero_point={}'.format( - self.fake_quant_enabled, self.observer_enabled, - self.activation_post_process.quant_min, - self.activation_post_process.quant_max, self.dtype, - self.qscheme, self.ch_axis, self.scale, self.zero_point) - - def _save_to_state_dict(self, destination, prefix, keep_vars): - # We cannot currently register scalar values as buffers, so need to - # manually specify serialization here. - super(FakeQuantize, self)._save_to_state_dict(destination, prefix, - keep_vars) - destination[prefix + 'scale'] = self.scale - destination[prefix + 'zero_point'] = self.zero_point - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - # Removing this function throws an error that the the size of the - # loaded tensor does not match the original size i.e., These buffers - # start out with numel 0 and become numel 1 once they have their - # first forward pass. - local_state = ['scale', 'zero_point'] - for name in local_state: - key = prefix + name - if key in state_dict: - val = state_dict[key] - # Custom handling to allow loading scale and zero_point - # of size N into uninitialized buffers of size 0. The - # buffers are resized here, and the values are copied in - # the default state_dict loading code of the parent. - if name == 'scale': - self.scale.resize_(val.shape) - else: - assert name == 'zero_point' - self.zero_point.resize_(val.shape) - # For torchscript module we need to update the attributes here - # since we do not call the `_load_from_state_dict` function - # defined module.py - if torch.jit.is_scripting(): - if name == 'scale': - self.scale.copy_(val) - else: - assert name == 'zero_point' - self.zero_point.copy_(val) - elif strict: - missing_keys.append(key) - super(FakeQuantize, - self)._load_from_state_dict(state_dict, prefix, local_metadata, - strict, missing_keys, - unexpected_keys, error_msgs) +BaseFakeQuantize = FakeQuantize diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py deleted file mode 100644 index a24898442..000000000 --- a/mmrazor/models/fake_quants/lsq.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -from torch.nn.parameter import Parameter - -from mmrazor.registry import MODELS -from ..utils import PerChannelLoadHook, _is_symmetric_quant, is_tracing_state -from .base import FakeQuantize - - -@MODELS.register_module() -class LearnableFakeQuantize(FakeQuantize): - r""" This is an extension of the FakeQuantize module in fake_quantize.py, - which supports more generalized lower-bit quantization and support learning - of the scale and zero point parameters through backpropagation. For - literature references, please see the class - `_LearnableFakeQuantizePerTensorOp`. In addition to the attributes in the - original FakeQuantize module, the `_LearnableFakeQuantize` module also - includes the following attributes to support quantization parameter - learning. - """ - - def __init__(self, - observer, - scale=1., - zero_point=0., - use_grad_scaling=True, - **observer_kwargs): - super(LearnableFakeQuantize, self).__init__(observer, - **observer_kwargs) - self.use_grad_scaling = use_grad_scaling - self.scale = Parameter(torch.tensor([scale])) - self.zero_point = Parameter(torch.tensor([zero_point])) - self.register_buffer('eps', - torch.tensor([torch.finfo(torch.float32).eps])) - # Check whether the module will load a state dict; - # Initialize the shape of per-channel 'scale' and - # 'zero-point' before copying values - self.load_state_dict_hook = PerChannelLoadHook(self) - - @torch.jit.export - def extra_repr(self): - return 'fake_quant_enabled={}, observer_enabled={}, ' \ - 'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={},'\ - 'scale={}, zero_point={}'.format( - self.fake_quant_enabled, self.observer_enabled, - self.quant_min, self.quant_max, - self.dtype, self.qscheme, self.ch_axis, - self.scale if self.ch_axis == -1 else 'List[%s]' % str(self.scale.shape), # noqa: E501 - self.zero_point if self.ch_axis == -1 else 'List') - - @torch.jit.export - def calculate_qparams(self): - self.scale.data.clamp_(min=self.eps.item()) - scale = self.scale.detach() - zero_point = self.zero_point.detach().round().clamp( - self.quant_min, self.quant_max).long() - return scale, zero_point - - def _save_to_state_dict(self, destination, prefix, keep_vars): - super(FakeQuantize, self)._save_to_state_dict(destination, prefix, - keep_vars) - destination[prefix + 'scale'] = self.scale if keep_vars \ - else self.scale.detach() - destination[prefix + 'zero_point'] = self.zero_point if keep_vars \ - else self.zero_point.detach() - - def forward(self, X): - # Learnable fake quantize have to zero_point.float() - # to make it learnable. - if self.observer_enabled[0] == 1: - self.activation_post_process(X.detach()) - _scale, _zero_point = \ - self.activation_post_process.calculate_qparams() - _scale = _scale.to(self.scale.device) - _zero_point = _zero_point.to(self.zero_point.device) - - if self.ch_axis != -1: - self.scale.data = torch.ones_like(_scale) - self.zero_point.data = torch.zeros_like(_zero_point.float()) - - self.scale.data.copy_(_scale) - self.zero_point.data.copy_(_zero_point.float()) - else: - self.scale.data.abs_() - self.scale.data.clamp_(min=self.eps.item()) - - if self.fake_quant_enabled[0] == 1: - if _is_symmetric_quant(self.qscheme): - self.zero_point.data.zero_() - else: - self.zero_point.data.clamp_(self.quant_min, - self.quant_max).float() - - if self.is_per_channel: - if self.use_grad_scaling: - grad_factor = 1.0 / (X.numel() / X.shape[self.ch_axis] * - self.quant_max)**0.5 - else: - grad_factor = 1.0 - if is_tracing_state(): - X = FakeQuantizeLearnablePerchannelAffine.apply( - X, self.scale, self.zero_point, self.ch_axis, - self.quant_min, self.quant_max, grad_factor) - else: - X = _fake_quantize_learnable_per_channel_affine_training( - X, self.scale, self.zero_point, self.ch_axis, - self.quant_min, self.quant_max, grad_factor) - else: - if self.use_grad_scaling: - grad_factor = 1.0 / (X.numel() * self.quant_max)**0.5 - else: - grad_factor = 1.0 - X = torch._fake_quantize_learnable_per_tensor_affine( - X, self.scale, self.zero_point, self.quant_min, - self.quant_max, grad_factor) - return X - - -def _fake_quantize_learnable_per_channel_affine_training( - x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor): - zero_point = (zero_point.round() - zero_point).detach() + zero_point - new_shape = [1] * len(x.shape) - new_shape[ch_axis] = x.shape[ch_axis] - scale = grad_scale(scale, grad_factor).reshape(new_shape) - zero_point = grad_scale(zero_point, grad_factor).reshape(new_shape) - x = x / scale + zero_point - x = (x.round() - x).detach() + x - x = torch.clamp(x, quant_min, quant_max) - return (x - zero_point) * scale - - -def grad_scale(t, scale): - return (t - (t * scale)).detach() + (t * scale) - - -class FakeQuantizeLearnablePerchannelAffine(torch.autograd.Function): - - @staticmethod - def forward(ctx, x, scale, zero_point, ch_axis, quant_min, quant_max, - grad_factor): - return _fake_quantize_learnable_per_channel_affine_training( - x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor) - - @staticmethod - def symbolic(g, x, scale, zero_point, ch_axis, quant_min, quant_max, - grad_factor): - return g.op('::FakeQuantizeLearnablePerchannelAffine', x, scale, - zero_point, ch_axis, quant_min, quant_max) diff --git a/mmrazor/models/fake_quants/qdrop.py b/mmrazor/models/fake_quants/qdrop.py deleted file mode 100644 index e2e13bfc0..000000000 --- a/mmrazor/models/fake_quants/qdrop.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -from torch.nn.parameter import Parameter - -from mmrazor.registry import MODELS -from .base import FakeQuantize - - -@MODELS.register_module() -class QDropFakeQuantize(FakeQuantize): - - def __init__(self, observer, **observer_kwargs): - super().__init__(observer, **observer_kwargs) - self.scale = Parameter(torch.tensor([1.0], dtype=torch.float)) - self.prob = 1.0 - - def forward(self, X): - if self.observer_enabled[0] == 1: - self.activation_post_process(X.detach()) - _scale, _zero_point = self.calculate_qparams() - _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to( - self.zero_point.device) - if self.scale.shape != _scale.shape: - self.scale.resize_(_scale.shape) - self.zero_point.resize_(_zero_point.shape) - self.scale.copy_(_scale) - self.zero_point.copy_(_zero_point) - - if self.fake_quant_enabled[0] == 1: - x_orig = X - if self.is_per_channel: - X = torch.fake_quantize_per_channel_affine( - X, self.scale, self.zero_point, self.ch_axis, - self.activation_post_process.quant_min, - self.activation_post_process.quant_max) - else: - X = torch.fake_quantize_per_tensor_affine( - X, self.scale, self.zero_point, - self.activation_post_process.quant_min, - self.activation_post_process.quant_max) - if self.prob < 1.0: - x_prob = torch.where(torch.rand_like(X) < self.prob, X, x_orig) - return x_prob - return X diff --git a/mmrazor/models/fake_quants/torch_fake_quants.py b/mmrazor/models/fake_quants/torch_fake_quants.py new file mode 100644 index 000000000..ad1a0d966 --- /dev/null +++ b/mmrazor/models/fake_quants/torch_fake_quants.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +from typing import List + +import torch.ao.quantization.fake_quantize as torch_fake_quant_src + +from mmrazor.registry import MODELS + + +def register_torch_fake_quants() -> List[str]: + """Register fake_quants in ``torch.ao.quantization.fake_quantize`` to the + ``MODELS`` registry. + + Returns: + List[str]: A list of registered fake_quants' name. + """ + torch_fake_quants = [] + for module_name in dir(torch_fake_quant_src): + if module_name.startswith('__') or module_name.startswith('_') or \ + module_name.startswith('default'): + continue + _fake_quant = getattr(torch_fake_quant_src, module_name) + if inspect.isclass(_fake_quant) and issubclass( + _fake_quant, torch_fake_quant_src.FakeQuantizeBase): + if MODELS.get(module_name) is None: + MODELS.register_module(module=_fake_quant) + torch_fake_quants.append(module_name) + return torch_fake_quants + + +TORCH_fake_quants = register_torch_fake_quants() +# TORCH_fake_quants including: +# FakeQuantize +# FakeQuantizeBase +# FixedQParamsFakeQuantize +# FusedMovingAvgObsFakeQuantize diff --git a/mmrazor/models/observers/__init__.py b/mmrazor/models/observers/__init__.py index 004345625..c82f902f5 100644 --- a/mmrazor/models/observers/__init__.py +++ b/mmrazor/models/observers/__init__.py @@ -1,10 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .lsq_observer import LSQObserver -from .minmax import EMAMinMaxObserver, MinMaxObserver -from .minmaxfloor_observer import MinMaxFloorObserver -from .mse import MSEObserver +from .base import BaseObserver +from .torch_observers import register_torch_observers -__all__ = [ - 'MinMaxObserver', 'MSEObserver', 'EMAMinMaxObserver', 'LSQObserver', - 'MinMaxFloorObserver' -] +__all__ = ['BaseObserver', 'register_torch_observers'] diff --git a/mmrazor/models/observers/base.py b/mmrazor/models/observers/base.py index 8d9c40afe..a68410eb0 100644 --- a/mmrazor/models/observers/base.py +++ b/mmrazor/models/observers/base.py @@ -1,73 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Tuple - -import torch from torch.ao.quantization.observer import UniformQuantizationObserverBase -from mmrazor.models.utils import pot_quantization, sync_tensor - - -# todo: We only support per-tensor quantization currently. -class BaseObserver(UniformQuantizationObserverBase): - """Modified torch quantization observer. - - Args: - dtype: dtype argument to the `quantize` node needed to implement the - reference model spec. - qscheme: Quantization scheme to be used. - reduce_range: Reduces the range of the quantized data type by 1 bit. - This is sometimes required to avoid instruction overflow. - quant_min: Minimum quantization value. If unspecified, it will follow - the 8-bit setup. - quant_max: Maximum quantization value. If unspecified, it will follow - the 8-bit setup. - ch_axis (int, optional): Channel axis index. Defaults to -1. - is_pot_scale (bool, optional): Indicate whether scale is power of two. - Defaults to False. - eps: Epsilon value for float32. - Defaults to `torch.finfo(torch.float32).eps`. - """ - - min_val: torch.Tensor - max_val: torch.Tensor - - def __init__(self, - dtype=torch.quint8, - qscheme=torch.per_tensor_affine, - reduce_range=False, - quant_min=None, - quant_max=None, - ch_axis=-1, - is_pot_scale=False, - factory_kwargs=None, - eps=torch.finfo(torch.float32).eps) -> None: - super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, - factory_kwargs, eps) - factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) - self.register_buffer('min_val', - torch.tensor(float('inf'), **factory_kwargs)) - self.register_buffer('max_val', - torch.tensor(float('-inf'), **factory_kwargs)) - self.ch_axis = ch_axis - self.is_pot_scale = is_pot_scale - - @torch.jit.export - def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: - r"""Calculates the quantization parameters.""" - scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) - scale.data = sync_tensor(scale).data - zero_point.data = sync_tensor(zero_point).data - if self.is_pot_scale: - scale = pot_quantization(scale) - return scale, zero_point - - @torch.jit.export - def extra_repr(self): - return 'min_val={}, max_val={} ch_axis={} is_pot_scale={}'.format( - self.min_val, self.max_val, self.ch_axis, self.is_pot_scale) - - @torch.jit.export - def reset_min_max_vals(self): - """Resets the min/max values.""" - self.min_val.copy_(torch.tensor(float('inf'))) - self.max_val.copy_(torch.tensor(float('-inf'))) +BaseObserver = UniformQuantizationObserverBase diff --git a/mmrazor/models/observers/lsq_observer.py b/mmrazor/models/observers/lsq_observer.py deleted file mode 100644 index b543efe17..000000000 --- a/mmrazor/models/observers/lsq_observer.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import math - -import torch - -from mmrazor.registry import MODELS -from ..utils import _is_symmetric_quant, pot_quantization, sync_tensor -from .base import BaseObserver - - -@MODELS.register_module() -class LSQObserver(BaseObserver): - """Observer for `LEARNED STEP SIZE QUANTIZATION`""" - - def __init__(self, - dtype=torch.quint8, - qscheme=torch.per_tensor_affine, - reduce_range=False, - quant_min=None, - quant_max=None, - ch_axis=-1, - is_pot_scale=False, - factory_kwargs=None, - eps=torch.finfo(torch.float32).eps) -> None: - super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, - ch_axis, is_pot_scale, factory_kwargs, eps) - - self.tensor_norm = None - - def forward(self, x_orig): - if x_orig.numel() == 0: - return x_orig - x = x_orig.to(self.min_val.dtype) - if self.ch_axis == -1: - self.tensor_norm = x.abs().mean() - self.min_val, self.max_val = torch.aminmax(x) - else: - # compute channel-wise mean - x_dim = x.size() - new_axis_list = [i for i in range(len(x_dim))] - new_axis_list[self.ch_axis] = 0 - new_axis_list[0] = self.ch_axis - y = x.permute(new_axis_list) - y = torch.flatten(y, start_dim=1) - self.tensor_norm = y.abs().mean(1) - self.min_val, self.max_val = torch._aminmax(y, 1) - - return x - - def calculate_qparams(self): - scale = 2 * self.tensor_norm / math.sqrt(self.quant_max) - zero_point = torch.zeros_like(self.tensor_norm) - sync_tensor(scale) - sync_tensor(zero_point) - if self.is_pot_scale: - scale = pot_quantization(scale) - if not _is_symmetric_quant(self.qscheme): - zero_point = self.quant_min - torch.round(self.min_val / scale) - return scale, zero_point diff --git a/mmrazor/models/observers/minmax.py b/mmrazor/models/observers/minmax.py deleted file mode 100644 index 2ec620e60..000000000 --- a/mmrazor/models/observers/minmax.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - -from mmrazor.registry import MODELS -from .base import BaseObserver - - -@MODELS.register_module() -class MinMaxObserver(BaseObserver): - """Min max observer.""" - - def __init__(self, - dtype=torch.quint8, - qscheme=torch.per_tensor_affine, - reduce_range=False, - quant_min=None, - quant_max=None, - ch_axis=-1, - is_pot_scale=False, - factory_kwargs=None, - eps=torch.finfo(torch.float32).eps) -> None: - super(MinMaxObserver, self).__init__(dtype, qscheme, reduce_range, - quant_min, quant_max, ch_axis, - is_pot_scale, factory_kwargs, eps) - if (self.qscheme == torch.per_tensor_symmetric and self.reduce_range - and self.dtype == torch.quint8): - raise NotImplementedError('Cannot reduce range for symmetric \ - quantization for quint8') - - def forward(self, x_orig): - r"""Records the running minimum and maximum of ``x``.""" - if x_orig.numel() == 0: - return x_orig - x = x_orig.detach() # avoid keeping autograd tape - x = x.to(self.min_val.dtype) - if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) - else: - x_dim = x.size() - new_axis_list = [i for i in range(len(x_dim))] - new_axis_list[self.ch_axis] = 0 - new_axis_list[0] = self.ch_axis - y = x.permute(new_axis_list) - y = torch.flatten(y, start_dim=1) - min_val_cur, max_val_cur = torch._aminmax(y, 1) - min_val = torch.min(self.min_val, min_val_cur) - max_val = torch.max(self.max_val, max_val_cur) - self.min_val = min_val - self.max_val = max_val - - return x - - -@MODELS.register_module() -class EMAMinMaxObserver(BaseObserver): - """Moving average min/max among batches.""" - - def __init__(self, - dtype=torch.quint8, - qscheme=torch.per_tensor_affine, - reduce_range=False, - quant_min=None, - quant_max=None, - ch_axis=-1, - is_pot_scale=False, - ema_ratio=0.9, - factory_kwargs=None): - super(EMAMinMaxObserver, - self).__init__(dtype, qscheme, reduce_range, quant_min, - quant_max, ch_axis, is_pot_scale, factory_kwargs) - self.ema_ratio = ema_ratio - - def forward(self, x_orig): - r"""Records the running minimum and maximum of ``x``.""" - if x_orig.numel() == 0: - return x_orig - x = x_orig.to(self.min_val.dtype) - if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) - else: - x_dim = x.size() - new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 - new_axis_list[self.ch_axis] = 0 - new_axis_list[0] = self.ch_axis - y = x.permute(new_axis_list) - y = torch.flatten(y, start_dim=1) - min_val_cur, max_val_cur = torch._aminmax(y, 1) - - if self.max_val.numel() <= 1 and self.max_val.isinf(): - self.min_val = min_val_cur - self.max_val = max_val_cur - else: - self.min_val = self.min_val * self.ema_ratio + min_val_cur * ( - 1.0 - self.ema_ratio) - self.max_val = self.max_val * self.ema_ratio + max_val_cur * ( - 1.0 - self.ema_ratio) - return x diff --git a/mmrazor/models/observers/minmaxfloor_observer.py b/mmrazor/models/observers/minmaxfloor_observer.py deleted file mode 100644 index 231149ad4..000000000 --- a/mmrazor/models/observers/minmaxfloor_observer.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Tuple - -import torch -from torch.ao.quantization.observer import UniformQuantizationObserverBase - -from mmrazor.registry import MODELS -from ..utils import _is_float_qparams, _is_symmetric_quant, sync_tensor - - -@MODELS.register_module() -class MinMaxFloorObserver(UniformQuantizationObserverBase): - """Calculate minmax of whole calibration dataset with floor but round. - - Args: - dtype: Quantized data type. - qscheme: Quantization scheme to be used. - reduce_range: Reduces the range of the quantized data type by 1 bit. - This is sometimes required to avoid instruction overflow. - quant_min: Minimum quantization value. If unspecified, - it will follow the 8-bit setup. - quant_max: Maximum quantization value. If unspecified, - it will follow the 8-bit setup. - eps: Epsilon value for float32, Defaults to - `torch.finfo(torch.float32).eps`. - """ - min_val: torch.Tensor - max_val: torch.Tensor - - def __init__(self, - dtype=torch.quint8, - qscheme=torch.per_tensor_affine, - reduce_range=False, - quant_min=None, - quant_max=None, - factory_kwargs=None, - eps=torch.finfo(torch.float32).eps) -> None: - super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, - factory_kwargs, eps) - factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) - self.register_buffer('min_val', - torch.tensor(float('inf'), **factory_kwargs)) - self.register_buffer('max_val', - torch.tensor(float('-inf'), **factory_kwargs)) - if (self.qscheme == torch.per_tensor_symmetric and self.reduce_range - and self.dtype == torch.quint8): - raise NotImplementedError('Cannot reduce range for symmetric \ - quantization for quint8') - - def forward(self, x_orig: torch.Tensor) -> torch.Tensor: - """Records the running minimum and maximum of ``x``.""" - if x_orig.numel() == 0: - return x_orig - x = x_orig.detach() # avoid keeping autograd tape - x = x.to(self.min_val.dtype) - min_val_cur, max_val_cur = torch.aminmax(x) - min_val = torch.min(min_val_cur, self.min_val) - max_val = torch.max(max_val_cur, self.max_val) - self.min_val.copy_(min_val) - self.max_val.copy_(max_val) - return x_orig - - @torch.jit.export - def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculates the quantization parameters.""" - scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) - if not _is_symmetric_quant(self.qscheme) and not _is_float_qparams( - self.qscheme): - scale = (self.max_val - self.min_val) / float(self.quant_max - - self.quant_min) - scale = torch.max(scale, self.eps) - zero_point = self.quant_min - torch.floor(self.min_val / scale).to( - torch.int) - zero_point = torch.clamp(zero_point, self.quant_min, - self.quant_max) - sync_tensor(scale) - sync_tensor(zero_point) - return scale, zero_point - - @torch.jit.export - def extra_repr(self) -> str: - return 'min_val={}, max_val={}'.format(self.min_val, self.max_val) - - @torch.jit.export - def reset_min_max_vals(self) -> None: - """Resets the min/max values.""" - self.min_val.copy_(torch.tensor(float('inf'))) - self.max_val.copy_(torch.tensor(float('-inf'))) diff --git a/mmrazor/models/observers/mse.py b/mmrazor/models/observers/mse.py deleted file mode 100644 index a2b65a3a6..000000000 --- a/mmrazor/models/observers/mse.py +++ /dev/null @@ -1,155 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - -from mmrazor.registry import MODELS -from .base import BaseObserver - -_version_under_1100 = int(torch.__version__.split('.')[1]) < 10 - - -@MODELS.register_module() -class MSEObserver(BaseObserver): - """MSE observer.""" - - def __init__(self, - dtype=torch.quint8, - qscheme=torch.per_tensor_affine, - reduce_range=False, - quant_min=None, - quant_max=None, - ch_axis=-1, - is_pot_scale=False, - p=2.0, - factory_kwargs=None, - eps=torch.finfo(torch.float32).eps) -> None: - super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, - ch_axis, is_pot_scale, factory_kwargs, eps) - self.p = p - - def lp_loss(self, pred, tgt, dim=None): - """loss function measured in L_p Norm.""" - return (pred - tgt).abs().pow( - self.p).mean(dim) if dim else (pred - - tgt).abs().pow(self.p).mean() - - def mse(self, - x: torch.Tensor, - x_min: torch.Tensor, - x_max: torch.Tensor, - iter=80): - best_score = 1e+10 - best_min, best_max = torch.tensor( - [1.0], dtype=torch.float), torch.tensor([1.0], dtype=torch.float) - best_min.copy_(x_min) - best_max.copy_(x_max) - for i in range(iter): - new_min = x_min * (1.0 - (i * 0.01)) - new_max = x_max * (1.0 - (i * 0.01)) - scale, zero_point = self._calculate_qparams(new_min, new_max) - x_q = torch.fake_quantize_per_tensor_affine( - x, scale.item(), int(zero_point.item()), self.quant_min, - self.quant_max) - score = self.lp_loss(x_q, x) - if score < best_score: - best_score = score - best_min, best_max = new_min, new_max - return best_min, best_max - - def mse_perchannel(self, - x: torch.Tensor, - x_min: torch.Tensor, - x_max: torch.Tensor, - iter=80, - ch_axis=0): - assert x_min.shape == x_max.shape - assert ch_axis >= 0, f'{ch_axis}' - best_score = 1e+10 * torch.ones_like(x_min) - best_min, best_max = x_min.clone(), x_max.clone() - reduce_dim = tuple([i for i in range(len(x.shape)) if i != ch_axis]) - for i in range(iter): - new_min = x_min * (1.0 - (i * 0.01)) - new_max = x_max * (1.0 - (i * 0.01)) - scale, zero_point = self._calculate_qparams(new_min, new_max) - x_q = torch.fake_quantize_per_channel_affine( - x, scale, zero_point.int(), ch_axis, self.quant_min, - self.quant_max) - score = self.lp_loss(x_q, x, reduce_dim) - update_idx = (score < best_score) - best_score[update_idx] = score[update_idx] - best_min[update_idx] = new_min[update_idx] - best_max[update_idx] = new_max[update_idx] - return best_min, best_max - - def forward(self, x_orig): - r"""Records the running minimum and maximum of ``x``.""" - if x_orig.numel() == 0: - return x_orig - x = x_orig.clone().detach().to(self.min_val.dtype) - if self.ch_axis == -1: - min_val_cur, max_val_cur = torch.aminmax(x) - min_val_cur, max_val_cur = self.mse( - x, min_val_cur, max_val_cur, iter=95) - else: - x_dim = x.size() - new_axis_list = [i for i in range(len(x_dim))] - new_axis_list[self.ch_axis] = 0 - new_axis_list[0] = self.ch_axis - x_channel = x.permute(new_axis_list) - y = torch.flatten(x_channel, start_dim=1) - min_val_cur, max_val_cur = torch._aminmax(y, 1) - min_val_cur, max_val_cur = self.mse_perchannel( - x, min_val_cur, max_val_cur, iter=80, ch_axis=self.ch_axis) - - self.min_val = torch.min(self.min_val, min_val_cur) - self.max_val = torch.max(self.max_val, max_val_cur) - return x - - -@MODELS.register_module() -class EMAMSEObserver(MSEObserver): - - def __init__(self, - dtype=torch.quint8, - qscheme=torch.per_tensor_affine, - reduce_range=False, - quant_min=None, - quant_max=None, - ch_axis=-1, - is_pot_scale=False, - p=2.0, - ema_ratio=0.9, - factory_kwargs=None, - eps=torch.finfo(torch.float32).eps) -> None: - super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, - ch_axis, is_pot_scale, p, factory_kwargs, eps) - self.ema_ratio = ema_ratio - - def forward(self, x_orig): - r"""Records the running minimum and maximum of ``x``.""" - if x_orig.numel() == 0: - return x_orig - x = x_orig.clone().detach().to(self.min_val.dtype) - if self.ch_axis == -1: - min_val_cur, max_val_cur = torch.aminmax(x) - min_val_cur, max_val_cur = self.mse( - x, min_val_cur, max_val_cur, iter=95) - else: - x_dim = x.size() - new_axis_list = [i for i in range(len(x_dim))] - new_axis_list[self.ch_axis] = 0 - new_axis_list[0] = self.ch_axis - x_channel = x.permute(new_axis_list) - y = torch.flatten(x_channel, start_dim=1) - min_val_cur, max_val_cur = torch._aminmax(y, 1) - min_val_cur, max_val_cur = self.mse_perchannel( - x, min_val_cur, max_val_cur, iter=80, ch_axis=self.ch_axis) - - if self.max_val.numel() <= 1 and self.max_val.isinf(): - self.min_val = min_val_cur - self.max_val = max_val_cur - else: - self.min_val = self.min_val * self.ema_ratio + min_val_cur * ( - 1.0 - self.ema_ratio) - self.max_val = self.max_val * self.ema_ratio + max_val_cur * ( - 1.0 - self.ema_ratio) - return x diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py new file mode 100644 index 000000000..8e0e81d58 --- /dev/null +++ b/mmrazor/models/observers/torch_observers.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import inspect +from typing import List + +import torch.ao.quantization.observer as torch_observer_src + +from mmrazor.registry import MODELS + + +def register_torch_observers() -> List[str]: + """Register observers in ``torch.ao.quantization.observer`` to the + ``MODELS`` registry. + + Returns: + List[str]: A list of registered observers' name. + """ + torch_observers = [] + for module_name in dir(torch_observer_src): + if module_name.startswith('__') or module_name.startswith('_') or \ + module_name.startswith('default'): + continue + _observer = getattr(torch_observer_src, module_name) + if inspect.isclass(_observer) and issubclass( + _observer, torch_observer_src.ObserverBase): + if MODELS.get(module_name) is None: + MODELS.register_module(module=_observer) + torch_observers.append(module_name) + return torch_observers + + +TORCH_observers = register_torch_observers() +# TORCH_observers including: +# FixedQParamsObserver +# HistogramObserver +# MinMaxObserver +# MovingAverageMinMaxObserver +# MovingAveragePerChannelMinMaxObserver +# NoopObserver +# ObserverBase +# PerChannelMinMaxObserver +# PlaceholderObserver +# RecordingObserver +# ReuseInputObserver +# UniformQuantizationObserverBase diff --git a/mmrazor/models/quantizers/__init__.py b/mmrazor/models/quantizers/__init__.py index 2741e2fd1..9d4fa1a28 100644 --- a/mmrazor/models/quantizers/__init__.py +++ b/mmrazor/models/quantizers/__init__.py @@ -1,6 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base import CustomQuantizer -from .openvino_quantizer import OpenvinoQuantizer -from .trt_quantizer import TensorRTQuantizer +from .academic_quantizer import AcademicQuantizer +from .base import BaseQuantizer +from .native_quantizer import NativeQuantizer +from .openvino_quantizer import OpenVINOQuantizer +from .tensorrt_quantizer import TensorRTQuantizer -__all__ = ['CustomQuantizer', 'TensorRTQuantizer', 'OpenvinoQuantizer'] +__all__ = [ + 'BaseQuantizer', 'AcademicQuantizer', 'NativeQuantizer', + 'TensorRTQuantizer', 'OpenVINOQuantizer' +] diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py new file mode 100644 index 000000000..6a6500791 --- /dev/null +++ b/mmrazor/models/quantizers/academic_quantizer.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.ao.quantization.fx import prepare +from torch.ao.quantization.fx.custom_config import (FuseCustomConfig, + PrepareCustomConfig) +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quant_type import _quant_type_from_str +from torch.ao.quantization.quantize_fx import _fuse_fx + +from mmrazor.registry import MODELS +from mmrazor.structures.quantization import BackendConfigs, QConfigHander +from .base import BaseQuantizer + +GLOBAL_DICT_KEY = '_global_' +OBJECT_TYPE_DICT_KEY = 'object_type' +MODULE_NAME_REGEX_DICT_KEY = 'module_name_regex' +MODULE_NAME_DICT_KEY = 'module_name' +MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = 'module_name_object_type_order' + +FLOAT_TO_OBSERVED_DICT_KEY = 'float_to_observed_custom_module_class' +PRESERVED_ATTRIBUTES_DICT_KEY = 'preserved_attributes' + + +@MODELS.register_module() +class AcademicQuantizer(BaseQuantizer): + + def __init__(self, + qconfig_mapping, + tracer=dict(type='mmrazor.CustomTracer'), + prepare_custom_config=None, + backend_config=BackendConfigs['academic']): + super().__init__(tracer) + self.qconfig_mapping = self.gen_qconfig_mapping(qconfig_mapping) + self.prepare_custom_config = self.gen_prepare_custom_config( + prepare_custom_config) + self.backend_config = backend_config + self.example_inputs = (torch.randn(1, 3, 224, 224), ) + + def prepare(self, model, graph_module): + preserved_attributes = self.prepare_custom_config.preserved_attributes + for attr_name in preserved_attributes: + setattr(graph_module, attr_name, getattr(model, attr_name)) + fuse_custom_config = FuseCustomConfig().set_preserved_attributes( + preserved_attributes) + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + fuse_custom_config=fuse_custom_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + prepare_custom_config=self.prepare_custom_config, + backend_config=self.backend_config) + for attr_name in preserved_attributes: + setattr(prepared, attr_name, getattr(model, attr_name)) + + return prepared + + def gen_qconfig_mapping(self, qconfig_mapping): + conf = QConfigMapping() + if GLOBAL_DICT_KEY in qconfig_mapping: + qconfig = QConfigHander(qconfig_mapping[GLOBAL_DICT_KEY]).convert() + conf.set_global(qconfig) + for object_type, qconfig in qconfig_mapping.get( + OBJECT_TYPE_DICT_KEY, []): + qconfig = QConfigHander(qconfig).convert() + conf.set_object_type(object_type, qconfig) + + for module_name_regex, qconfig in qconfig_mapping.get( + MODULE_NAME_REGEX_DICT_KEY, []): + qconfig = QConfigHander(qconfig).convert() + conf.set_module_name_regex(module_name_regex, qconfig) + for module_name, qconfig in qconfig_mapping.get( + MODULE_NAME_DICT_KEY, []): + qconfig = QConfigHander(qconfig).convert() + conf.set_module_name(module_name, qconfig) + for module_name, object_type, index, qconfig in qconfig_mapping.get( + MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []): + qconfig = QConfigHander(qconfig).convert() + conf.set_module_name_object_type_order(module_name, object_type, + index, qconfig) + + return conf + + def gen_prepare_custom_config(self, prepare_custom_config): + conf = PrepareCustomConfig() + if prepare_custom_config is None: + return conf + else: + for quant_type_name, custom_module_mapping in \ + prepare_custom_config.get( + FLOAT_TO_OBSERVED_DICT_KEY, {}).items(): + quant_type = _quant_type_from_str(quant_type_name) + mapping_items = custom_module_mapping.items() + for float_class_str, observed_class_str in mapping_items: + float_class = MODELS.get(float_class_str) + observed_class = MODELS.get(observed_class_str) + conf.set_float_to_observed_mapping(float_class, + observed_class, + quant_type) + conf.set_preserved_attributes( + prepare_custom_config.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])) + return conf diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py index 2dd3930fc..4d1adceda 100644 --- a/mmrazor/models/quantizers/base.py +++ b/mmrazor/models/quantizers/base.py @@ -1,246 +1,32 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List +from abc import abstractmethod import torch from mmengine.model import BaseModule -from torch.ao.quantization import QConfig, enable_fake_quant -from torch.ao.quantization.fx import prepare -from torch.ao.quantization.quantize_fx import _convert_fx, _fuse_fx -from torch.nn.intrinsic.qat import modules as qat_fused_modules -from torch.nn.qat import modules as qat_modules -from mmrazor.models.task_modules.tracer import CustomTracer -from mmrazor.models.utils import (check_is_valid_convert_custom_config_dict, - check_is_valid_prepare_custom_config_dict, - check_is_valid_qconfig_dict, - get_custom_module_class_keys) -from mmrazor.registry import MODELS -from mmrazor.structures.quantization import (CheckArgs, DefaultQconfigs, - QuantizeScheme, SupportQtypes) +from mmrazor.registry import TASK_UTILS -SUPPORT_QAT_MODULES = ( - qat_fused_modules.ConvBn1d, qat_fused_modules.ConvBn2d, - qat_fused_modules.ConvBn3d, qat_fused_modules.ConvBnReLU1d, - qat_fused_modules.ConvBnReLU2d, qat_fused_modules.ConvBnReLU3d, - qat_fused_modules.ConvReLU1d, qat_fused_modules.ConvReLU2d, - qat_fused_modules.ConvReLU3d, qat_fused_modules.LinearBn1d, - qat_fused_modules.LinearReLU, qat_modules.Conv1d, qat_modules.Conv2d, - qat_modules.Conv3d, qat_modules.Linear) -MERGE_BN_MAPPINGS = { - qat_fused_modules.ConvBn1d: qat_modules.Conv1d, - qat_fused_modules.ConvBn2d: qat_modules.Conv2d, - qat_fused_modules.ConvBn3d: qat_modules.Conv3d, - qat_fused_modules.ConvBnReLU1d: qat_fused_modules.ConvReLU1d, - qat_fused_modules.ConvBnReLU2d: qat_fused_modules.ConvReLU2d, - qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d, - qat_fused_modules.LinearBn1d: qat_modules.Linear -} +class BaseQuantizer(BaseModule): + def __init__(self, tracer): + super().__init__() + self.tracer = TASK_UTILS.build(tracer) -@MODELS.register_module() -class CustomQuantizer(BaseModule): - """Configurable quantizer, base class of quantizers. + @abstractmethod + def prepare(self): + pass - Args: - qconfig (Dict, optional): QConfig. Defaults to DefaultQconfigs['default']. # noqa: E501 - skipped_methods (List, optional): Skipped methods list for tracer. - Defaults to None. - prepare_custom_config_dict (Dict, optional): `PrepareCustomConfig` - from `torch.quantization.fx`. Defaults to None. - convert_custom_config_dict (Dict, optional): `ConvertCustomConfig` - from `torch.quantization.fx`. Defaults to None. - equalization_qconfig_dict (Dict, optional): Custom `QConfig` effects - on all modules. Defaults to None. - _remove_qconfig (Dict, optional): Remove qconfig at the end of - `_convert_fx`. Defaults to True. - init_cfg (dict, optional): Initialization config dict. - """ - - def __init__(self, - qconfig: Dict = DefaultQconfigs['default'], - skipped_methods: List = None, - prepare_custom_config_dict: Dict = None, - convert_custom_config_dict: Dict = None, - equalization_qconfig_dict: Dict = None, - _remove_qconfig: bool = True, - init_cfg: Dict = None): - super().__init__(init_cfg) - if self.check_qconfig(qconfig): - qconfig = self.qconfig_convert(qconfig) - self.qconfig_dict = {'': qconfig} - else: - raise ValueError('qconfig is incorrect!') - - if prepare_custom_config_dict is None: - self.prepare_custom_config_dict = {} - else: - self.prepare_custom_config_dict = prepare_custom_config_dict - if convert_custom_config_dict is None: - self.convert_custom_config_dict = {} - else: - self.convert_custom_config_dict = convert_custom_config_dict - if equalization_qconfig_dict is None: - self.equalization_qconfig_dict = {} - else: - self.equalization_qconfig_dict = equalization_qconfig_dict - - check_is_valid_qconfig_dict(self.qconfig_dict) - check_is_valid_prepare_custom_config_dict( - self.prepare_custom_config_dict) - check_is_valid_convert_custom_config_dict( - self.convert_custom_config_dict) - check_is_valid_qconfig_dict(self.equalization_qconfig_dict) - - self.skipped_methods = skipped_methods - self._remove_qconfig = _remove_qconfig - self.tracer = self.build_tracer() - - def prepare(self, model, graph_module): - - preserved_attributes = self.prepare_custom_config_dict.get( - 'preserved_attributes', []) - for attr_name in preserved_attributes: - setattr(graph_module, attr_name, getattr(model, attr_name)) - - graph_module = self.fuse_model(graph_module) - - prepared = prepare( - graph_module, - self.qconfig_dict, - True, - self.tracer.node_name_to_scope, - equalization_qconfig_dict=self.equalization_qconfig_dict - ) # type: ignore[operator] - - for attr_name in preserved_attributes: - setattr(prepared, attr_name, getattr(model, attr_name)) - return prepared - - def convert(self, graph_module): - quantized = _convert_fx( - graph_module, - is_reference=False, - convert_custom_config_dict=self.convert_custom_config_dict, - _remove_qconfig=self._remove_qconfig, - qconfig_dict=self.qconfig_dict) - return quantized - - def check_qconfig(self, qconfig): - is_pass = True - for arg in CheckArgs: - if arg == 'qtype': - if qconfig[arg] in SupportQtypes and arg in qconfig.keys(): - continue - else: - is_pass = False - break - else: - if isinstance(qconfig[arg], dict) and arg in qconfig.keys(): - continue - else: - is_pass = False - break - return is_pass - - def qconfig_convert(self, qconfig): - self.w_qscheme = QuantizeScheme(**qconfig['w_qscheme']) - self.a_qscheme = QuantizeScheme(**qconfig['a_qscheme']) - w_observer = MODELS.get(qconfig['w_observer']['type']) - w_observer_kwargs = self.w_qscheme.to_observer_params() - a_observer = MODELS.get(qconfig['a_observer']['type']) - a_observer_kwargs = self.a_qscheme.to_observer_params() - self.w_observer = MODELS.get(qconfig['w_observer']['type']).with_args( - **self.w_qscheme.to_observer_params()) - self.a_observer = MODELS.get(qconfig['a_observer']['type']).with_args( - **self.a_qscheme.to_observer_params()) - self.w_fake_quant = MODELS.get( - qconfig['w_fake_quant']['type']).with_args( - observer=w_observer, **w_observer_kwargs) - self.a_fake_quant = MODELS.get( - qconfig['a_fake_quant']['type']).with_args( - observer=a_observer, **a_observer_kwargs) - - torch_qconfig = QConfig( - weight=self.w_fake_quant, activation=self.a_fake_quant) - return torch_qconfig - - def _swap_ff_with_fxff(self, model: torch.nn.Module) -> None: + def swap_ff_with_fxff(self, model): r""" Swap FloatFunctional with FXFloatFunctional """ modules_to_swap = [] for name, module in model.named_children(): - if isinstance(module, torch.nn.quantized.FloatFunctional): + if isinstance(module, torch.ao.nn.quantized.FloatFunctional): modules_to_swap.append(name) else: - self._swap_ff_with_fxff(module) + self.swap_ff_with_fxff(module) for name in modules_to_swap: del model._modules[name] - model._modules[name] = torch.nn.quantized.FXFloatFunctional() - - def build_tracer(self): - skipped_module_names = self.prepare_custom_config_dict.get( - 'non_traceable_module_name', []) - skipped_module_classes = self.prepare_custom_config_dict.get( - 'non_traceable_module_class', []) - standalone_module_name_configs = self.prepare_custom_config_dict.get( - 'standalone_module_name', []) - skipped_module_names += [ - config[0] for config in standalone_module_name_configs - ] - - standalone_module_class_configs = self.prepare_custom_config_dict.get( - 'standalone_module_class', []) - skipped_module_classes += [ - config[0] for config in standalone_module_class_configs - ] - float_custom_module_classes = get_custom_module_class_keys( - self.prepare_custom_config_dict, - 'float_to_observed_custom_module_class') - skipped_module_classes += float_custom_module_classes - tracer = CustomTracer(self.skipped_methods, skipped_module_names, - skipped_module_classes) - # tracer = QuantizationTracer(skipped_module_names, - # skipped_module_classes) - return tracer - - def fuse_model(self, graph_module): - graph_module = _fuse_fx(graph_module, True, - self.prepare_custom_config_dict) - return graph_module - - def post_process_weight_fakequant(self, - observed_module, - keep_fake_quant=False): - - def traverse(module): - - for name, child in module.named_children(): - if isinstance(child, SUPPORT_QAT_MODULES): - weight_fakequant = child.weight_fake_quant - child.weight.data = weight_fakequant(child.weight.data) - - float_child = child.to_float() - - if keep_fake_quant: - for m in float_child.modules(): - setattr(m, 'qconfig', self.qconfig_dict['']) - - if type(child) in MERGE_BN_MAPPINGS: - cls = MERGE_BN_MAPPINGS[type(child)] - new_child = cls.from_float(float_child) - else: - new_child = child.from_float(float_child) - - new_child.weight_fake_quant(new_child.weight) - else: - new_child = float_child - setattr(module, name, new_child) - else: - traverse(child) - observed_module.apply(enable_fake_quant) - traverse(observed_module) - - def prepare_for_mmdeploy(self, model): - raise NotImplementedError + model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py new file mode 100644 index 000000000..8d1cd0b34 --- /dev/null +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -0,0 +1,130 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.ao.quantization import enable_fake_quant +from torch.ao.quantization.fx import prepare +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quantize_fx import _fuse_fx +from torch.nn.intrinsic.qat import modules as qat_fused_modules +from torch.nn.qat import modules as qat_modules + +from mmrazor.models.utils import str2class +from mmrazor.registry import MODELS +from mmrazor.structures.quantization import BackendConfigs, QConfigHander +from .base import BaseQuantizer + +SUPPORT_QAT_MODULES = ( + qat_fused_modules.ConvBn1d, qat_fused_modules.ConvBn2d, + qat_fused_modules.ConvBn3d, qat_fused_modules.ConvBnReLU1d, + qat_fused_modules.ConvBnReLU2d, qat_fused_modules.ConvBnReLU3d, + qat_fused_modules.ConvReLU1d, qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvReLU3d, qat_fused_modules.LinearBn1d, + qat_fused_modules.LinearReLU, qat_modules.Conv1d, qat_modules.Conv2d, + qat_modules.Conv3d, qat_modules.Linear) + +MERGE_BN_MAPPINGS = { + qat_fused_modules.ConvBn1d: qat_modules.Conv1d, + qat_fused_modules.ConvBn2d: qat_modules.Conv2d, + qat_fused_modules.ConvBn3d: qat_modules.Conv3d, + qat_fused_modules.ConvBnReLU1d: qat_fused_modules.ConvReLU1d, + qat_fused_modules.ConvBnReLU2d: qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d, + qat_fused_modules.LinearBn1d: qat_modules.Linear +} + + +@MODELS.register_module() +class NativeQuantizer(BaseQuantizer): + + # backend: 'native' + # support_w_modes = ['per_tensor', 'per_channel'] + # support_a_modes = ['per_tensor'] + + def __init__(self, + global_qconfig, + no_observer_modules=None, + tracer=dict(type='CustomTracer')): + super().__init__(tracer) + self.qconfig = QConfigHander(global_qconfig) + if self.qconfig.w_qscheme.is_per_channel: + w_mode = 'per_channel' + else: + w_mode = 'per_tensor' + if self.qconfig.a_qscheme.is_per_channel: + a_mode = 'per_channel' + else: + a_mode = 'per_tensor' + assert w_mode in self.support_w_modes + assert a_mode in self.support_a_modes + + self.qconfig_mapping = QConfigMapping().set_global( + self.qconfig.convert()) + if no_observer_modules: + self.no_observer_modules = str2class(no_observer_modules) + for mod in self.no_observer_modules: + self.qconfig_mapping.set_object_type(mod, None) + else: + self.no_observer_modules = no_observer_modules + self.backend_config = BackendConfigs[self.backend] + self.example_inputs = (torch.randn(1, 3, 224, 224), ) + + @property + def backend(self): + return 'native' + + @property + def support_w_modes(self): + return ['per_tensor', 'per_channel'] + + @property + def support_a_modes(self): + return ['per_tensor'] + + def prepare(self, model, graph_module): + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + return prepared + + def post_process_weight_fakequant(self, + observed_module, + keep_fake_quant=False): + + def traverse(module): + for name, child in module.named_children(): + if isinstance(child, SUPPORT_QAT_MODULES): + weight_fakequant = child.weight_fake_quant + child.weight.data = weight_fakequant(child.weight.data) + + float_child = child.to_float() + + if keep_fake_quant: + for m in float_child.modules(): + setattr(m, 'qconfig', self.qconfig.convert()) + + if type(child) in MERGE_BN_MAPPINGS: + cls = MERGE_BN_MAPPINGS[type(child)] + new_child = cls.from_float(float_child) + else: + new_child = type(child).from_float(float_child) + + new_child.weight_fake_quant(new_child.weight) + else: + new_child = float_child + setattr(module, name, new_child) + else: + traverse(child) + + observed_module.apply(enable_fake_quant) + traverse(observed_module) + + def prepare_for_mmdeploy(self, model, dummy_input, checkpoint): + raise NotImplementedError diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index 7bf067593..bac432baa 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -1,59 +1,101 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch from torch.ao.quantization import disable_observer +from torch.ao.quantization.fx import prepare +from torch.ao.quantization.quantize_fx import _fuse_fx + +from mmrazor.models.task_modules.tracer.fx import (build_graphmodule, + del_fakequant_after_module, + del_fakequant_after_target, + del_fakequant_before_module, + del_fakequant_before_target) +from mmrazor.models.utils import str2class from mmrazor.registry import MODELS -from mmrazor.structures.quantization import DefaultQconfigs -from mmrazor.models.task_modules.tracer.fx.custom_tracer import build_graphmodule -from .base import CustomQuantizer +from .native_quantizer import NativeQuantizer @MODELS.register_module() -class OpenvinoQuantizer(CustomQuantizer): +class OpenVINOQuantizer(NativeQuantizer): """Quantizer for Openvino backend.""" - support_bits = [8] - support_w_mode = ['per_channel'] - support_a_mode = ['per_tensor'] - + # backend: 'openvino' + # support_w_mode = ['per_tensor', 'per_channel'] + # support_a_mode = ['per_tensor'] def __init__(self, - qconfig, - is_qat=True, - skipped_methods=None, - prepare_custom_config_dict=None, - convert_custom_config_dict=None, - equalization_qconfig_dict=None, - _remove_qconfig=True, - init_cfg=None): - super().__init__(qconfig, is_qat, skipped_methods, - prepare_custom_config_dict, - convert_custom_config_dict, equalization_qconfig_dict, - _remove_qconfig, init_cfg) - - - - - def prepare_for_mmdeploy(self, model, dummy_input=None, checkpoint=None): - - - self._swap_ff_with_fxff(model) + global_qconfig, + no_observer_modules=None, + tracer=dict(type='CustomTracer'), + remove_fakequants=dict( + module_prev=('torch.nn.ReLU6', 'torch.nn.Identity'), + module_next=('torch.nn.MaxPool2d', ), + target_prev=('output', ), + target_next=('flatten', ))): + super().__init__(global_qconfig, no_observer_modules, tracer) + self.remove_fakequants = remove_fakequants + + @property + def backend(self): + return 'openvino' + + @property + def support_w_modes(self): + return ['per_tensor', 'per_channel'] + + @property + def support_a_modes(self): + return ['per_tensor'] + + def prepare(self, model, graph_module): + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + module_prev = self.remove_fakequants.get('module_prev') + module_next = self.remove_fakequants.get('module_next') + target_prev = self.remove_fakequants.get('target_prev') + target_next = self.remove_fakequants.get('target_next') + + if module_prev: + prepared = del_fakequant_before_module( + prepared, str2class(module_prev), inplace=True) + if module_next: + prepared = del_fakequant_after_module( + prepared, str2class(module_next), inplace=True) + if target_prev: + prepared = del_fakequant_before_target( + prepared, target_prev, inplace=True) + if target_next: + prepared = del_fakequant_after_target( + prepared, target_next, inplace=True) + print(prepared) + + return prepared + + def prepare_for_mmdeploy(self, + model, + dummy_input=(1, 3, 224, 224), + checkpoint=None): + + self.swap_ff_with_fxff(model) graph = self.tracer.trace(model) graph_module = build_graphmodule(model, graph) observed_model = self.prepare(model, graph_module) - - - self.post_process_weight_fakequant(observed_model, keep_fake_quant=True) - if dummy_input is not None: - observed_model(dummy_input) + observed_model(torch.randn(dummy_input)) + if checkpoint is not None: + observed_model.load_state_dict( + torch.load(checkpoint)['state_dict']) + self.post_process_weight_fakequant( + observed_model, keep_fake_quant=True) observed_model.apply(disable_observer) - return observed_model - - - - - - diff --git a/mmrazor/models/quantizers/tensorrt_quantizer.py b/mmrazor/models/quantizers/tensorrt_quantizer.py new file mode 100644 index 000000000..4d9868c4f --- /dev/null +++ b/mmrazor/models/quantizers/tensorrt_quantizer.py @@ -0,0 +1,56 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.ao.quantization import disable_observer + +from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ + build_graphmodule +from mmrazor.registry import MODELS +from .native_quantizer import NativeQuantizer + + +@MODELS.register_module() +class TensorRTQuantizer(NativeQuantizer): + """Quantizer for TensorRT backend.""" + + # backend: 'tensorrt' + # support_w_mode = ['per_tensor', 'per_channel'] + # support_a_mode = ['per_tensor'] + + def __init__(self, + global_qconfig, + no_observer_modules=None, + tracer=dict(type='CustomTracer')): + super().__init__(global_qconfig, no_observer_modules, tracer) + + @property + def backend(self): + return 'tensorrt' + + @property + def support_w_modes(self): + return ['per_tensor', 'per_channel'] + + @property + def support_a_modes(self): + return ['per_tensor'] + + def prepare_for_mmdeploy(self, + model, + dummy_input=(1, 3, 224, 224), + checkpoint=None): + + self.swap_ff_with_fxff(model) + graph = self.tracer.trace(model) + graph_module = build_graphmodule(model, graph) + observed_model = self.prepare(model, graph_module) + if dummy_input is not None: + observed_model(torch.randn(dummy_input)) + if checkpoint is not None: + observed_model.load_state_dict( + torch.load(checkpoint)['state_dict']) + self.post_process_weight_fakequant( + observed_model, keep_fake_quant=True) + + observed_model.apply(disable_observer) + + return observed_model diff --git a/mmrazor/models/quantizers/trt_quantizer.py b/mmrazor/models/quantizers/trt_quantizer.py deleted file mode 100644 index 26c44b665..000000000 --- a/mmrazor/models/quantizers/trt_quantizer.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -from torch.ao.quantization import disable_observer - -from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ - build_graphmodule -from mmrazor.registry import MODELS -from mmrazor.structures.quantization import DefaultQconfigs -from .base import CustomQuantizer - - -@MODELS.register_module() -class TensorRTQuantizer(CustomQuantizer): - """Quantizer for TensorRT backend.""" - - support_bits = [8] - support_w_mode = ['per_channel'] - support_a_mode = ['per_tensor'] - - def __init__(self, - qconfig=DefaultQconfigs['tensorrt'], - is_qat=True, - skipped_methods=None, - prepare_custom_config_dict=None, - convert_custom_config_dict=None, - equalization_qconfig_dict=None, - _remove_qconfig=True, - init_cfg=None): - super().__init__(qconfig, is_qat, skipped_methods, - prepare_custom_config_dict, - convert_custom_config_dict, equalization_qconfig_dict, - _remove_qconfig, init_cfg) - - def prepare_for_mmdeploy(self, model, dummy_input=None, checkpoint=None): - - graph = self.tracer.trace(model) - graph_module = build_graphmodule(model, graph) - observed_model = self.prepare(model, graph_module) - - observed_model(torch.randn(1, 3, 224, 224)) - - self.post_process_weight_fakequant(observed_model) - if dummy_input is not None: - observed_model(dummy_input) - - observed_model.apply(disable_observer) - - return observed_model diff --git a/mmrazor/models/task_modules/tracer/fx/__init__.py b/mmrazor/models/task_modules/tracer/fx/__init__.py index 1190f945c..1a4d00d78 100644 --- a/mmrazor/models/task_modules/tracer/fx/__init__.py +++ b/mmrazor/models/task_modules/tracer/fx/__init__.py @@ -1,8 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. from .custom_tracer import (CustomTracer, UntracedMethodRegistry, build_graphmodule, custom_symbolic_trace) +from .graph_utils import (del_fakequant_after_module, + del_fakequant_after_target, + del_fakequant_before_module, + del_fakequant_before_target) __all__ = [ 'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace', - 'build_graphmodule' + 'build_graphmodule', 'del_fakequant_before_module', + 'del_fakequant_after_module', 'del_fakequant_before_target', + 'del_fakequant_after_target' ] diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py index 1d78d3007..0e118290e 100644 --- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -13,6 +13,8 @@ _patch_wrapped_functions, _Patcher) from torch.fx.proxy import Proxy +from mmrazor.registry import TASK_UTILS + _orig_module_call: Callable = torch.nn.Module.__call__ _orig_module_getattr: Callable = torch.nn.Module.__getattr__ @@ -155,8 +157,8 @@ def _get_attrs(target, attrs): return module_dict -def build_graphmodule(model: nn.Module, - fx_graph: torch.fx.Graph, +def build_graphmodule(model: nn.Module, + fx_graph: torch.fx.Graph, name: str = 'GraphModule'): modules = dict(model.named_modules()) module_dict = _prepare_module_dict(model, fx_graph) @@ -164,6 +166,7 @@ def build_graphmodule(model: nn.Module, return GraphModule(modules, fx_graph, name) +@TASK_UTILS.register_module() class CustomTracer(QuantizationTracer): def __init__(self, @@ -308,7 +311,7 @@ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): @functools.wraps(_orig_module_getattr) def module_getattr_wrapper(mod, attr): attr_val = _orig_module_getattr(mod, attr) - return self._module_getattr(attr, attr_val, parameter_proxy_cache) + return self.getattr(attr, attr_val, parameter_proxy_cache) @functools.wraps(_orig_module_call) def module_call_wrapper(mod, *args, **kwargs): diff --git a/mmrazor/models/task_modules/tracer/fx/graph_utils.py b/mmrazor/models/task_modules/tracer/fx/graph_utils.py new file mode 100644 index 000000000..952f31b4b --- /dev/null +++ b/mmrazor/models/task_modules/tracer/fx/graph_utils.py @@ -0,0 +1,138 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy + +from torch.ao.quantization.fake_quantize import FakeQuantizeBase + + +def _get_attrs(target, attrs): + + attrs = attrs.split('.') + + for att in attrs: + target = getattr(target, att, None) + return target + + +def del_fakequant_before_target(prepared_model, target_patterns, inplace=True): + + def recursive_find_erased_nodes(node): + """Find FakeQuant before target node recursively. + + Examples: + head_fc = self.head.fc(activation_post_process_87); \ + activation_post_process_87 = None + activation_post_process_88 = \ + self.activation_post_process_88(head_fc); head_fc = None + head = self.head + _get_loss = head._get_loss(activation_post_process_88, + data_samples); \ + head = activation_post_process_88 = data_samples = None + return _get_loss + + node | node.args + -------------------- + output | (_get_loss, ) + _get_loss | (head, activation_post_process_88, + data_samples) + head | () + activation_post_process_88 | (head_fc, ) + data_samples | (None, ) + """ + if node is None: + return + if isinstance( + _get_attrs(prepared_model, node.target), FakeQuantizeBase): + nodes_to_erase.append(node) + return + for prev_node in node.args: + recursive_find_erased_nodes(prev_node) + for prev_node in node.kwargs.values(): + recursive_find_erased_nodes(prev_node) + return + + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + for node in new_graph.nodes: + if node.target in target_patterns: + nodes_to_erase = [] + recursive_find_erased_nodes(node) + for to_erase in nodes_to_erase: + to_erase.replace_all_uses_with(to_erase.args[0]) + new_graph.erase_node(to_erase) + delattr(prepared_model, to_erase.target) + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_after_target(prepared_model, target_patterns, inplace=True): + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + + target_nodes = [] + for node in new_graph.nodes: + if node.target in target_patterns: + target_nodes.append(node) + + for node in new_graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), FakeQuantizeBase): + assert len(node.args) == 1 + prev_node = node.args[0] + if prev_node not in target_nodes: + continue + node.replace_all_uses_with(prev_node) + new_graph.erase_node(node) + delattr(prepared_model, node.target) + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_before_module(prepared_model, module_patterns, inplace=True): + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + for node in new_graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), module_patterns): + to_erase = node.args[0] + if not isinstance( + _get_attrs(prepared_model, to_erase.target), + FakeQuantizeBase): + continue + if len(to_erase.users) > 1: + continue + to_erase.replace_all_uses_with(to_erase.args[0]) + new_graph.erase_node(to_erase) + delattr(prepared_model, to_erase.target) + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_after_module(prepared_model, module_patterns, inplace=True): + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + target_nodes = [] + for node in new_graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), module_patterns): + target_nodes.append(node) + + for node in new_graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), FakeQuantizeBase): + assert len(node.args) == 1 + prev_node = node.args[0] + if prev_node not in target_nodes: + continue + node.replace_all_uses_with(prev_node) + new_graph.erase_node(node) + delattr(prepared_model, node.target) + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model diff --git a/mmrazor/models/utils/__init__.py b/mmrazor/models/utils/__init__.py index 1b3eea2a1..8ff8abdae 100644 --- a/mmrazor/models/utils/__init__.py +++ b/mmrazor/models/utils/__init__.py @@ -4,8 +4,9 @@ from .optim_wrapper import reinitialize_optim_wrapper_count_status from .parse_values import parse_values from .utils import get_module_device, set_requires_grad +from .quantization_util import str2class __all__ = [ - 'add_prefix', 'reinitialize_optim_wrapper_count_status', 'make_divisible', - 'get_module_device', 'set_requires_grad', 'parse_values' + 'make_divisible', 'add_prefix', 'reinitialize_optim_wrapper_count_status', + 'str2class', 'get_module_device', 'set_requires_grad', 'parse_values' ] diff --git a/mmrazor/models/utils/quantization_util.py b/mmrazor/models/utils/quantization_util.py index 5593572ce..c5eaf890b 100644 --- a/mmrazor/models/utils/quantization_util.py +++ b/mmrazor/models/utils/quantization_util.py @@ -1,208 +1,39 @@ # Copyright (c) OpenMMLab. All rights reserved. -from functools import partial -from typing import Any, Dict, List, Optional, Set +from mmengine.utils import import_modules_from_strings -import torch -import torch.distributed as dist +def _check_valid_source(source): + """Check if the source's format is valid.""" + if not isinstance(source, str): + raise TypeError(f'source should be a str ' + f'instance, but got {type(source)}') -class PerChannelLoadHook: + assert len(source.split('.')) > 1, \ + 'source must have at least one `.`' - def __init__(self, module, hook_param=['scale', 'zero_point']): - self.hook = module._register_load_state_dict_pre_hook( - partial(self.hook_fn, module=module)) - self.hook_param = hook_param - def hook_fn(self, state_dict, prefix, local_metadata, strict, missing_keys, - unexpected_keys, error_msgs, module): - if module.ch_axis == -1: - # no per-channel parameters - return - for module_key, param in module._parameters.items(): - if module_key not in self.hook_param: - continue - candidate = prefix + module_key - if candidate in state_dict: - input_param = state_dict[candidate] - if param.shape != input_param.shape: - param.data = torch.ones_like( - input_param, dtype=param.dtype, device=param.device) - for module_key, param in module._buffers.items(): - if module_key not in self.hook_param: - continue - candidate = prefix + module_key - if candidate in state_dict: - input_param = state_dict[candidate] - if param.shape != input_param.shape: - param.data = torch.ones_like( - input_param, dtype=param.dtype, device=param.device) - - def close(self): - self.hook.remove() - - -USE_DDP = False - -if torch.distributed.is_initialized(): - USE_DDP = True - - -def sync_tensor(tensor): - - if USE_DDP: - tensor.data = tensor.data / dist.get_world_size() - dist.all_reduce(tensor.data) - return tensor - - -def pot_quantization(tensor: torch.Tensor, mode='round'): - log2t = torch.log2(tensor) - if mode == 'round': - log2t = (torch.round(log2t) - log2t).detach() + log2t +def str2class(str_inputs): + clss = [] + if not isinstance(str_inputs, tuple) and not isinstance(str_inputs, list): + str_inputs_list = [str_inputs] else: - assert mode == 'floor' - log2t = (torch.floor(log2t) - log2t).detach() + log2t - return 2**log2t - - -def _is_per_channel(qscheme: 'torch.qscheme') -> bool: - return qscheme in [ - torch.per_channel_symmetric, torch.per_channel_affine, - torch.per_channel_affine_float_qparams - ] - - -def _is_per_tensor(qscheme: 'torch.qscheme') -> bool: - return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] - - -def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool: - return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric] - - -def is_tracing_state(): - return torch._C._get_tracing_state() - - -def _is_float_qparams(qscheme: 'torch.qscheme') -> bool: - return qscheme in [ - torch.per_channel_affine_float_qparams, - ] - - -def check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], - dict_name: str) -> None: - r""" Checks if the given config_dict has the correct keys - Args: - `config_dict`: dictionary whose keys we want to check - """ - - for k in config_dict.keys(): - if k not in allowed_keys: - raise ValueError('Expected ' + dict_name + - ' to have the following keys: ' + - str(allowed_keys) + '. But found \'' + k + - '\' instead.') - - -def check_is_valid_qconfig_dict(qconfig_dict: Any) -> None: - r""" Checks if the given qconfig_dict has the correct keys - Args: - `qconfig_dict`: dictionary whose keys we want to check - """ - - qconfig_dict_allowed_keys = { - '', 'object_type', 'module_name_regex', 'module_name', - 'module_name_object_type_order' - } - check_is_valid_config_dict(qconfig_dict, qconfig_dict_allowed_keys, - 'qconfig_dict') - - -def check_is_valid_prepare_custom_config_dict( - prepare_custom_config_dict: Optional[Dict[str, Any]] = None) -> None: - r""" Checks if the given prepare_custom_config_dict has the correct keys - Args: - `prepare_custom_config_dict`: customization configuration dictionary for - quantization tool - """ - if not prepare_custom_config_dict: - return - - prepare_custom_config_dict_allowed_keys = { - 'standalone_module_name', 'standalone_module_class', - 'float_to_observed_custom_module_class', 'non_traceable_module_name', - 'non_traceable_module_class', 'input_quantized_idxs', - 'output_quantized_idxs', 'preserved_attributes' - } - check_is_valid_config_dict(prepare_custom_config_dict, - prepare_custom_config_dict_allowed_keys, - 'prepare_custom_config_dict') - - -def check_is_valid_convert_custom_config_dict( - convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> None: - r""" Checks if the given convert_custom_config_dict has the correct keys - Args: - `convert_custom_config_dict`: dictionary for custom configurations for - convert function - """ - if not convert_custom_config_dict: - return - - convert_custom_config_dict_allowed_keys = { - 'observed_to_quantized_custom_module_class', 'preserved_attributes' - } - check_is_valid_config_dict(convert_custom_config_dict, - convert_custom_config_dict_allowed_keys, - 'convert_custom_config_dict') - - -def check_is_valid_fuse_custom_config_dict( - fuse_custom_config_dict: Optional[Dict[str, Any]] = None) -> None: - r""" Checks if the given fuse_custom_config_dict has the correct keys - Args: - `fuse_custom_config_dict`: dictionary for custom configurations for - fuse_fx - """ - if not fuse_custom_config_dict: - return - - fuse_custom_config_dict_allowed_keys = {'preserved_attributes'} - check_is_valid_config_dict(fuse_custom_config_dict, - fuse_custom_config_dict_allowed_keys, - 'fuse_custom_config_dict') - - -def get_custom_module_class_keys(custom_config_dict, - custom_config_dict_key) -> List[Any]: - r""" Get all the unique custom module keys in the custom config dict - e.g. - Input: - custom_config_dict = { - "float_to_observed_custom_module_class": { - "static": { - CustomModule1: ObservedCustomModule - }, - "dynamic": { - CustomModule2: DynamicObservedCustomModule - }, - "weight_only": { - CustomModule3: WeightOnlyObservedCustomModule - }, - }, - } - Output: - # extract all the keys in "static", "dynamic" and "weight_only" dict - [CustomModule1, CustomModule2, CustomModule3] - """ - # using set to dedup - float_custom_module_classes: Set[Any] = set() - custom_module_mapping = custom_config_dict.get(custom_config_dict_key, {}) - for quant_mode in ['static', 'dynamic', 'weight_only']: - quant_mode_custom_module_config = custom_module_mapping.get( - quant_mode, {}) - quant_mode_custom_module_classes = set( - quant_mode_custom_module_config.keys()) - float_custom_module_classes |= quant_mode_custom_module_classes - return list(float_custom_module_classes) + str_inputs_list = str_inputs + for s_class in str_inputs_list: + _check_valid_source(s_class) + mod_str = '.'.join(s_class.split('.')[:-1]) + cls_str = s_class.split('.')[-1] + try: + mod = import_modules_from_strings(mod_str) + except ImportError: + raise ImportError(f'{mod_str} is not imported correctly.') + imported_cls: type = getattr(mod, cls_str) + if not isinstance(imported_cls, type): + raise TypeError(f'{cls_str} should be a type ' + f'instance, but got {type(imported_cls)}') + clss.append(imported_cls) + if isinstance(str_inputs, list): + return clss + elif isinstance(str_inputs, tuple): + return tuple(clss) + else: + return clss[0] diff --git a/mmrazor/structures/__init__.py b/mmrazor/structures/__init__.py index 6dfcfbdc8..7f15c5d45 100644 --- a/mmrazor/structures/__init__.py +++ b/mmrazor/structures/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .quantization import * # noqa: F401,F403 from .subnet import * # noqa: F401,F403 diff --git a/mmrazor/structures/quantization/__init__.py b/mmrazor/structures/quantization/__init__.py index 9447c2f0f..cbf28034f 100644 --- a/mmrazor/structures/quantization/__init__.py +++ b/mmrazor/structures/quantization/__init__.py @@ -1,5 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .backend_default_qconfigs import CheckArgs, DefaultQconfigs, SupportQtypes -from .qscheme import QuantizeScheme - -__all__ = ['QuantizeScheme', 'DefaultQconfigs', 'SupportQtypes', 'CheckArgs'] +from .backend_config import * # noqa: F401,F403 +from .qconfig import * # noqa: F401,F403 diff --git a/mmrazor/structures/quantization/backend_config/__init__.py b/mmrazor/structures/quantization/backend_config/__init__.py new file mode 100644 index 000000000..151968f8d --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .academic import (get_academic_backend_config, + get_academic_backend_config_dict) +from .mapping import BackendConfigs +from .native import get_native_backend_config, get_native_backend_config_dict +from .openvino import (get_openvino_backend_config, + get_openvino_backend_config_dict) +from .tensorrt import (get_tensorrt_backend_config, + get_tensorrt_backend_config_dict) + +__all__ = [ + 'BackendConfigs', + 'get_native_backend_config', + 'get_native_backend_config_dict', + 'get_academic_backend_config', + 'get_academic_backend_config_dict', + 'get_openvino_backend_config', + 'get_openvino_backend_config_dict', + 'get_tensorrt_backend_config', + 'get_tensorrt_backend_config_dict', +] diff --git a/mmrazor/structures/quantization/backend_config/academic.py b/mmrazor/structures/quantization/backend_config/academic.py new file mode 100644 index 000000000..5983c3996 --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/academic.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig + +from .common_operator_config_utils import (_get_conv_configs, + _get_linear_configs) + +# =================== +# | DTYPE CONFIGS | +# =================== + +# weighted op int8 dtype config +# this is config for ops that has quantized weights, like linear, conv +weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_academic_backend_config() -> BackendConfig: + """Return the `BackendConfig` for academic reseaching.""" + conv_dtype_configs = [weighted_op_int8_dtype_config] + linear_dtype_configs = [weighted_op_int8_dtype_config] + + return BackendConfig('native') \ + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) + + +def get_academic_backend_config_dict(): + """Return the `BackendConfig` for academic reseaching in dictionary + form.""" + return get_academic_backend_config().to_dict() + + +__all__ = [ + 'get_academic_backend_config', + 'get_academic_backend_config_dict', +] diff --git a/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py new file mode 100644 index 000000000..2a855e687 --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py @@ -0,0 +1,607 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import operator +from collections import namedtuple +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.intrinsic as nni +import torch.nn.intrinsic.qat as nniqat +import torch.nn.qat as nnqat +import torch.nn.quantized._reference as nnqr +from torch.ao.quantization.backend_config import (BackendPatternConfig, + DTypeConfig, ObservationType) +from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize +from torch.ao.quantization.fuser_method_mappings import ( + fuse_conv_bn, fuse_conv_bn_relu, fuse_convtranspose_bn, fuse_linear_bn, + reverse2, reverse3, reverse_sequential_wrapper2) +from torch.ao.quantization.qconfig_mapping import _FIXED_QPARAMS_OP_TO_OBSERVER + +_ConvMetadata = namedtuple('_ConvMetadata', [ + 'root', 'transpose', 'bn', 'reference', 'transpose_reference', + 'fused_conv_relu', 'fused_conv_bn', 'fused_conv_bn_relu', 'qat', + 'relu_qat', 'bn_qat', 'bn_relu_qat', 'func' +]) +_Conv1dMetadata = _ConvMetadata(nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, + nnqr.Conv1d, nnqr.ConvTranspose1d, + nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d, + nnqat.Conv1d, nniqat.ConvReLU1d, + nniqat.ConvBn1d, nniqat.ConvBnReLU1d, F.conv1d) +_Conv2dMetadata = _ConvMetadata(nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, + nnqr.Conv2d, nnqr.ConvTranspose2d, + nni.ConvReLU2d, nni.ConvBn2d, nni.ConvBnReLU2d, + nnqat.Conv2d, nniqat.ConvReLU2d, + nniqat.ConvBn2d, nniqat.ConvBnReLU2d, F.conv2d) +_Conv3dMetadata = _ConvMetadata(nn.Conv3d, nn.ConvTranspose3d, nn.BatchNorm3d, + nnqr.Conv3d, nnqr.ConvTranspose3d, + nni.ConvReLU3d, nni.ConvBn3d, nni.ConvBnReLU3d, + nnqat.Conv3d, nniqat.ConvReLU3d, + nniqat.ConvBn3d, nniqat.ConvBnReLU3d, F.conv3d) + + +def _get_binary_op_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + binary_op_configs: List[BackendPatternConfig] = [] + num_tensor_args_to_observation_type_mapping = { + # TODO: this is not used right now since we have extra check in prepare + # will need to change this to NO_OBSERVER later after we implemented + # Tensor dtype inference properly + 0: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + 1: ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT, + 2: ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT, + } + for op_with_quantized_bop_scalar_variant in [ + operator.add, torch.add, operator.mul, torch.mul + ]: + bop_patterns = [(torch.nn.ReLU, op_with_quantized_bop_scalar_variant), + (torch.nn.functional.relu, + op_with_quantized_bop_scalar_variant), + (torch.relu, op_with_quantized_bop_scalar_variant), + op_with_quantized_bop_scalar_variant] + for bop_pattern in bop_patterns: + binary_op_configs.append( + BackendPatternConfig(bop_pattern).set_dtype_configs( + dtype_configs) # noqa: E131 + ._set_num_tensor_args_to_observation_type( + num_tensor_args_to_observation_type_mapping)) + # matmul + binary_op_configs.append( + BackendPatternConfig(torch.matmul).set_dtype_configs( + dtype_configs) # noqa: E131 + ) + return binary_op_configs + + +def _get_linear_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + """Return all configs related to linear modules and ops.""" + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + linear_configs: List[BackendPatternConfig] = [] + + # (1) Single linear modules/functions + # ------------------------------------- + # linear module + linear_configs.append( + BackendPatternConfig(torch.nn.Linear).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + torch.nn.Linear).set_reference_quantized_module( + nnqr.Linear).set_qat_module(nnqat.Linear)) + # linear qat module + linear_configs.append( + BackendPatternConfig(nnqat.Linear).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + torch.nn.Linear).set_reference_quantized_module(nnqr.Linear)) + # functional linear + linear_configs.append( + BackendPatternConfig(torch.nn.functional.linear).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)._set_input_type_to_index({ + 'weight': 1, + 'bias': 2 + })) + + # (2) Linear + relu + # ------------------- + # 2.1 linear module + relu fusion config + # linear relu, linear module + relu module + linear_configs.append( + BackendPatternConfig( + (torch.nn.ReLU, + torch.nn.Linear)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse_sequential_wrapper2( + nni.LinearReLU)).set_fused_module(nni.LinearReLU)) + # linear relu, linear module + functional relu + linear_configs.append( + BackendPatternConfig( + (torch.nn.functional.relu, + torch.nn.Linear)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse_sequential_wrapper2( + nni.LinearReLU)).set_fused_module(nni.LinearReLU)) + + # 2.2 linear module + relu, fused module configs + # linear relu, fused module + linear_configs.append( + BackendPatternConfig(nni.LinearReLU).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + torch.nn.Linear).set_reference_quantized_module( + nnqr.Linear).set_qat_module(nniqat.LinearReLU)) + # linear relu, qat fused module + linear_configs.append( + BackendPatternConfig(nniqat.LinearReLU).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + torch.nn.Linear).set_reference_quantized_module(nnqr.Linear)) + # 2.3 functional linear + relu configs + # linear relu, functional linear + relu module + linear_configs.append( + BackendPatternConfig( + (torch.nn.ReLU, + F.linear)).set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + # linear relu, functional linear + functional relu + linear_configs.append( + BackendPatternConfig( + (F.relu, + F.linear)).set_observation_type(observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + + # (3) Linear + batchnorm + # ------------------------ + # 3.1 linear bn fusion + linear_configs.append( + BackendPatternConfig( + (nn.BatchNorm1d, + nn.Linear)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse2(fuse_linear_bn)).set_fused_module( + nni.LinearBn1d)) + + # 3.2 linear bn fused + # linear bn, fused module + linear_configs.append( + BackendPatternConfig(nni.LinearBn1d).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + torch.nn.Linear).set_reference_quantized_module( + nnqr.Linear).set_qat_module(nniqat.LinearBn1d)) + # linear bn, qat fused module + linear_configs.append( + BackendPatternConfig(nniqat.LinearBn1d).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + torch.nn.Linear).set_reference_quantized_module(nnqr.Linear)) + return linear_configs + + +def _get_conv_configs(dtype_configs): + """Return all configs related to conv modules and ops.""" + conv_configs = [] + observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + for convs in [_Conv1dMetadata, _Conv2dMetadata, _Conv3dMetadata]: + + # (1) Single conv modules/functions + # ----------------------------------- + # conv module + conv_configs.append( + BackendPatternConfig(convs.root).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + convs.root).set_reference_quantized_module( + convs.reference).set_qat_module(convs.qat)) + # conv qat module + conv_configs.append( + BackendPatternConfig(convs.qat).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + convs.root).set_reference_quantized_module(convs.reference)) + # functional conv + conv_configs.append( + BackendPatternConfig(convs.func).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)._set_input_type_to_index({ + 'weight': + 1, + 'bias': + 2 + })) + + # (2) Conv + relu + # ----------------- + # 2.1 conv module + relu fusion configs + # conv relu fusion, conv module + relu module + conv_configs.append( + BackendPatternConfig( + (torch.nn.ReLU, + convs.root)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method( + reverse_sequential_wrapper2( + convs.fused_conv_relu)).set_fused_module( + convs.fused_conv_relu)) + # conv relu fusion, conv module + functional relu + conv_configs.append( + BackendPatternConfig( + (F.relu, + convs.root)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method( + reverse_sequential_wrapper2( + convs.fused_conv_relu)).set_fused_module( + convs.fused_conv_relu)) + # 2.2 conv module + relu fused module configs + # conv relu, fused module + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + convs.root).set_reference_quantized_module( + convs.reference).set_qat_module(convs.relu_qat)) + # conv relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.relu_qat).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + convs.root).set_reference_quantized_module(convs.reference)) + # 2.3 functional conv + relu configs + # conv relu, functional conv + relu module + conv_configs.append( + BackendPatternConfig( + (torch.nn.ReLU, convs.func)).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + # conv relu, functional conv + functional relu + conv_configs.append( + BackendPatternConfig((F.relu, convs.func)).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs)) + + # fused conv relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_relu).set_dtype_configs( + dtype_configs) # noqa: E131 + .set_qat_module(convs.relu_qat)) + + conv_configs.append( + BackendPatternConfig(convs.relu_qat).set_dtype_configs( + dtype_configs) # noqa: E131 + .set_root_module(convs.root).set_reference_quantized_module( + convs.reference)) + + # (3) Conv + batchnorm (+ relu) + # ------------------------------- + # 3.1 conv bn fusion configs + # conv + bn fusion + conv_configs.append( + BackendPatternConfig( + (convs.bn, + convs.root)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse2(fuse_conv_bn)).set_fused_module( + convs.fused_conv_bn)) + # conv + bn + relu module fusion + conv_configs.append( + BackendPatternConfig( + (nn.ReLU, + (convs.bn, + convs.root))).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse3(fuse_conv_bn_relu)).set_fused_module( + convs.fused_conv_bn_relu)) + # conv + bn + relu functional fusion + conv_configs.append( + BackendPatternConfig( + (F.relu, + (convs.bn, + convs.root))).set_dtype_configs(dtype_configs) # noqa: E131 + .set_root_module(convs.root).set_fuser_method( + reverse3(fuse_conv_bn_relu)).set_fused_module( + convs.fused_conv_bn_relu)) + # TODO: we can add fusion for torch.relu as well + + # 3.2 conv + bn (+ relu) fused module configs + # fused conv bn + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn).set_dtype_configs( + dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_qat)) + + # fused conv bn relu + conv_configs.append( + BackendPatternConfig(convs.fused_conv_bn_relu).set_dtype_configs( + dtype_configs) # noqa: E131 + .set_qat_module(convs.bn_relu_qat)) + + # conv bn, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_qat).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + convs.root).set_reference_quantized_module(convs.reference)) + # conv bn relu, qat fused module + conv_configs.append( + BackendPatternConfig(convs.bn_relu_qat).set_observation_type( + observation_type) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + convs.root).set_reference_quantized_module(convs.reference)) + + # (4) conv transpose and its fusion + # 4.1 conv transpose config + conv_configs.append( + BackendPatternConfig(convs.transpose).set_dtype_configs( + dtype_configs) # noqa: E131 + .set_root_module(convs.transpose).set_reference_quantized_module( + convs.transpose_reference)) + + # 4.2 conv transpose + bn fusion + conv_configs.append( + BackendPatternConfig( + (convs.bn, convs.transpose)).set_dtype_configs( + dtype_configs) # noqa: E131 + .set_fuser_method(reverse2(fuse_convtranspose_bn)).set_root_module( + convs.transpose).set_reference_quantized_module( + convs.transpose_reference)) + + return conv_configs + + +def _get_cat_config(dtype_configs: List[DTypeConfig]) -> BackendPatternConfig: + return BackendPatternConfig(torch.cat) \ + .set_observation_type( + ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \ + .set_dtype_configs(dtype_configs) + + +def _get_ln_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + ln_configs = [] + ln_configs.append( + BackendPatternConfig(torch.nn.LayerNorm).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)) + ln_configs.append( + BackendPatternConfig( + torch.nn.functional.layer_norm).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)._set_input_type_to_index({ + 'weight': 2, + 'bias': 3 + })) + return ln_configs + + +def _get_default_op_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + configs = [] + default_ops = [ + torch.nn.ELU, + torch.nn.LeakyReLU, + torch.nn.Hardswish, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.Dropout, + torch.nn.PReLU, + torch.nn.functional.elu, + torch.nn.functional.hardswish, + torch.nn.functional.leaky_relu, + torch.nn.functional.dropout, + ] + for op in default_ops: + configs.append( + BackendPatternConfig(op).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)) + + configs.append( + BackendPatternConfig( + torch.nn.functional.group_norm).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)._set_input_type_to_index({ + 'weight': 2, + 'bias': 3 + })) + + configs.append( + BackendPatternConfig( + torch.nn.functional.instance_norm).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)._set_input_type_to_index({ + 'weight': 3, + 'bias': 4 + })) + return configs + + +def _get_fixed_qparams_op_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + fixed_qparams_op_configs = [] + op_to_obs = _FIXED_QPARAMS_OP_TO_OBSERVER.items() + for fixed_qparam_op, output_observer in op_to_obs: + fixed_qparams_op_configs.append( + # TODO: The _overwrite_output keys are temporary, since we don't + # want to put observer in the configs we expect that it's provided + # by user What we want to put here is the requirement on observers, + # in this case dtype, quant_min, quant_max etc., but we need to + # first move all configs to backend_config_dict to do that, we'll + # remove these keys after we fully migrated everything to use + # backend_config_dict + BackendPatternConfig(fixed_qparam_op).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs). + _set_overwrite_output_fake_quantize( + FixedQParamsFakeQuantize.with_args(observer=output_observer) + )._set_overwrite_output_observer(output_observer)) + return fixed_qparams_op_configs + + +def _get_share_qparams_op_configs(dtype_configs): + """Get the operator config for the operators that works for both float and + quantized input if input is quantized, the output Tensor shares the same + quantization parameter with input. Example operator: avgpool2d, reshape, + transpose, maxpool2d Example observed operator: + + observer_0 - avgpool2d - observer_0 (same observer instance as input) + """ + + def _get_share_qprams_op_backend_config(op): + return BackendPatternConfig(op) \ + .set_observation_type( + ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \ + .set_dtype_configs(dtype_configs) + + share_qparams_ops = [ + torch.nn.AdaptiveAvgPool1d, + torch.nn.AdaptiveAvgPool2d, + torch.nn.AdaptiveAvgPool3d, + torch.nn.AvgPool1d, + torch.nn.AvgPool2d, + torch.nn.AvgPool3d, + torch.nn.Hardtanh, + torch.nn.Identity, + torch.nn.MaxPool1d, + torch.nn.MaxPool2d, + torch.nn.MaxPool3d, + torch.nn.ReLU, + torch.adaptive_avg_pool1d, + torch.nn.functional.adaptive_avg_pool2d, + torch.nn.functional.adaptive_avg_pool3d, + torch.nn.functional.hardtanh, + torch.nn.functional.hardtanh_, + torch.nn.functional.interpolate, + torch.nn.functional.max_pool1d, + torch.nn.functional.max_pool2d, + torch.nn.functional.max_pool3d, + torch.nn.functional.relu, + torch.nn.functional.relu6, + torch.avg_pool1d, + torch._C._nn.avg_pool2d, + torch._C._nn.avg_pool3d, + torch.clamp, + torch.flatten, + torch.mean, + torch.repeat_interleave, + torch.transpose, + torch.squeeze, + torch.stack, + torch.unsqueeze, + operator.floordiv, + 'contiguous', + 'clamp', + 'detach', + 'detach_', + 'mean', + 'permute', + 'repeat', + 'repeat_interleave', + 'reshape', + 'resize_', + 'relu', + 'relu_', + 'shape', + 'size', + 'squeeze', + 'squeeze_', + 'transpose', + 'unsqueeze', + 'unsqueeze_', + 'view', + ] + return [ + _get_share_qprams_op_backend_config(op) for op in share_qparams_ops + ] + + +def _get_bn_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + """Get configs related to batchnorm.""" + bn_configs = [] + bn_to_fused_bn = { + torch.nn.BatchNorm2d: nni.BNReLU2d, + torch.nn.BatchNorm3d: nni.BNReLU3d, + } + for bn in bn_to_fused_bn.keys(): + fused_bn = bn_to_fused_bn[bn] + # bn module + relu module fusion config + bn_configs.append( + BackendPatternConfig( + (torch.nn.ReLU, + bn)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse_sequential_wrapper2( + fused_bn)).set_fused_module(fused_bn)) + # bn module + F.relu fusion config + bn_configs.append( + BackendPatternConfig( + (torch.nn.functional.relu, + bn)).set_dtype_configs(dtype_configs) # noqa: E131 + .set_fuser_method(reverse_sequential_wrapper2( + bn_to_fused_bn[bn])).set_fused_module(fused_bn)) + bn_configs.append( + BackendPatternConfig(bn).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)) + + # fused bn configs + for fused_bn in bn_to_fused_bn.values(): + bn_configs.append( + BackendPatternConfig(fused_bn).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs)) + return bn_configs + + +def _get_rnn_op_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + rnn_op_configs = [] + for rnn_op, ref_rnn_op in [(nn.GRUCell, nnqr.GRUCell), + (nn.LSTMCell, nnqr.LSTMCell), + (nn.RNNCell, nnqr.RNNCell), + (nn.LSTM, nnqr.LSTM)]: + rnn_op_configs.append( + BackendPatternConfig(rnn_op).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + rnn_op).set_reference_quantized_module(ref_rnn_op)) + return rnn_op_configs + + +def _get_embedding_op_configs( + dtype_configs: List[DTypeConfig]) -> List[BackendPatternConfig]: + embedding_op_configs = [] + for embedding_op, qat_embedding_op, ref_embedding_op in [ + (nn.Embedding, nnqat.Embedding, nnqr.Embedding), + (nn.EmbeddingBag, nnqat.EmbeddingBag, nnqr.EmbeddingBag), + ]: + embedding_op_configs.append( + BackendPatternConfig(embedding_op).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs).set_qat_module(qat_embedding_op). + set_root_module(embedding_op).set_reference_quantized_module( + ref_embedding_op)._set_input_output_observed( + False)) # This is temporary, and will be removed soon + # config for qat op + embedding_op_configs.append( + BackendPatternConfig(qat_embedding_op).set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT + ) # noqa: E131 + .set_dtype_configs(dtype_configs).set_root_module( + embedding_op).set_reference_quantized_module( + ref_embedding_op)._set_input_output_observed( + False)) # This is temporary, and will be removed soon + return embedding_op_configs + + +__all__ = [ + '_get_binary_op_configs', + '_get_linear_configs', + '_get_conv_configs', + '_get_share_qparams_op_configs', +] diff --git a/mmrazor/structures/quantization/backend_config/mapping.py b/mmrazor/structures/quantization/backend_config/mapping.py new file mode 100644 index 000000000..4c87a73b9 --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/mapping.py @@ -0,0 +1,12 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .academic import get_academic_backend_config +from .native import get_native_backend_config +from .openvino import get_openvino_backend_config +from .tensorrt import get_tensorrt_backend_config + +BackendConfigs = { + 'academic': get_academic_backend_config(), + 'native': get_native_backend_config(), + 'tensorrt': get_tensorrt_backend_config(), + 'openvino': get_openvino_backend_config() +} diff --git a/mmrazor/structures/quantization/backend_config/native.py b/mmrazor/structures/quantization/backend_config/native.py new file mode 100644 index 000000000..d771b6012 --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/native.py @@ -0,0 +1,137 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig + +from .common_operator_config_utils import ( # noqa: F401,F403 + _get_binary_op_configs, _get_bn_configs, _get_cat_config, + _get_conv_configs, _get_default_op_configs, _get_embedding_op_configs, + _get_fixed_qparams_op_configs, _get_linear_configs, _get_ln_configs, + _get_rnn_op_configs, _get_share_qparams_op_configs) + +# =================== +# | DTYPE CONFIGS | +# =================== + +# weighted op int8 dtype config +# this is config for ops that has quantized weights, like linear, conv +weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, +) + +default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, +) + +default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the + # dtype_configs but it is not really used yet, + # we will enable it a bit later after we moved everything to + # backend_config_dict + is_dynamic=True, +) + +default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the + # dtype_configs but it is not really used yet, we will enable it a bit + # later after we moved everything to backend_config_dict + is_dynamic=True, +) + +# Needed for LayerNorm and f.layer_norm, since currently the kernel only +# supports float weights +input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, +) + +weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, +) + +weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, +) + +# ===================== +# | BACKEND CONFIGS | +# ===================== + + +def get_native_backend_config() -> BackendConfig: + """Return the `BackendConfig` for PyTorch Native backend + (fbgemm/qnnpack).""" + # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK + # BackendConfigs + conv_dtype_configs = [weighted_op_int8_dtype_config] + linear_dtype_configs = [ + weighted_op_int8_dtype_config, + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + binary_op_dtype_configs = [weighted_op_int8_dtype_config] + default_op_dtype_configs = [default_op_quint8_dtype_config] + fixed_qparams_op_dtype_configs = [weighted_op_int8_dtype_config] + share_qparams_op_dtype_configs = [default_op_quint8_dtype_config] + rnn_op_dtype_configs = [ + default_dynamic_int8_dtype_config, + default_dynamic_float16_dtype_config, + ] + embedding_op_dtype_configs = [ + weight_only_quint8_dtype_config, + weight_only_quint4x2_dtype_config, + ] + layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config] + + return BackendConfig('native') \ + .set_backend_pattern_configs( + _get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_configs( + _get_linear_configs(linear_dtype_configs)) \ + .set_backend_pattern_configs( + _get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_config( + _get_cat_config(default_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_default_op_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_bn_configs(default_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_ln_configs(layer_norm_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_rnn_op_configs(rnn_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_embedding_op_configs(embedding_op_dtype_configs)) + + +def get_native_backend_config_dict(): + """Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) + in dictionary form.""" + return get_native_backend_config().to_dict() + + +__all__ = [ + 'get_native_backend_config', + 'get_native_backend_config_dict', +] diff --git a/mmrazor/structures/quantization/backend_config/openvino.py b/mmrazor/structures/quantization/backend_config/openvino.py new file mode 100644 index 000000000..fd24eed17 --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/openvino.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.ao.quantization.backend_config import (BackendConfig, + BackendPatternConfig, + DTypeConfig, ObservationType) + +from .common_operator_config_utils import (_get_binary_op_configs, + _get_conv_configs, + _get_linear_configs, + _get_share_qparams_op_configs) + + +def get_openvino_backend_config() -> BackendConfig: + """Return the `BackendConfig` for the OpenVINO backend.""" + # dtype configs + weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + non_weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + ) + + addmm_config = BackendPatternConfig(torch.addmm) \ + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_op_qint8_dtype_config) \ + ._set_input_type_to_index({ + 'bias': 0, + 'input': 1, + 'weight': 2, + }) + cat_config = BackendPatternConfig(torch.cat) \ + .set_observation_type( + ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \ + .add_dtype_config(non_weighted_op_qint8_dtype_config) + conv_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + linear_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + binary_op_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + share_qparams_op_dtype_configs = [ + non_weighted_op_qint8_dtype_config, + ] + # there might be things not supported in fx2trt, but it will error out + # during fx2trt conversion and can support them after that + return BackendConfig('openvino') \ + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_config(addmm_config) \ + .set_backend_pattern_config(cat_config) \ + .set_backend_pattern_configs( + _get_linear_configs(linear_dtype_configs)) \ + .set_backend_pattern_configs( + _get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs)) + + +def get_openvino_backend_config_dict(): + """Return the `BackendConfig` for the OpenVINO backend in dictionary + form.""" + return get_openvino_backend_config().to_dict() + + +__all__ = [ + 'get_openvino_backend_config', + 'get_openvino_backend_config_dict', +] diff --git a/mmrazor/structures/quantization/backend_config/tensorrt.py b/mmrazor/structures/quantization/backend_config/tensorrt.py new file mode 100644 index 000000000..abb585c6a --- /dev/null +++ b/mmrazor/structures/quantization/backend_config/tensorrt.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.ao.quantization.backend_config import (BackendConfig, + BackendPatternConfig, + DTypeConfig, ObservationType) + +from .common_operator_config_utils import (_get_binary_op_configs, + _get_conv_configs, + _get_linear_configs, + _get_share_qparams_op_configs) + + +def get_tensorrt_backend_config() -> BackendConfig: + """Return the `BackendConfig` for the TensorRT backend.""" + # dtype configs + weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.qint8, + output_dtype=torch.qint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + non_weighted_op_qint8_dtype_config = DTypeConfig( + input_dtype=torch.qint8, + output_dtype=torch.qint8, + ) + + addmm_config = BackendPatternConfig(torch.addmm) \ + .set_observation_type( + ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ + .add_dtype_config(weighted_op_qint8_dtype_config) \ + ._set_input_type_to_index({ + 'bias': 0, + 'input': 1, + 'weight': 2, + }) + cat_config = BackendPatternConfig(torch.cat) \ + .set_observation_type( + ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \ + .add_dtype_config(non_weighted_op_qint8_dtype_config) + conv_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + linear_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + binary_op_dtype_configs = [ + weighted_op_qint8_dtype_config, + ] + share_qparams_op_dtype_configs = [ + non_weighted_op_qint8_dtype_config, + ] + # there might be things not supported in fx2trt, but it will error out + # during fx2trt conversion and can support them after that + return BackendConfig('tensorrt') \ + .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ + .set_backend_pattern_config(addmm_config) \ + .set_backend_pattern_config(cat_config) \ + .set_backend_pattern_configs( + _get_linear_configs(linear_dtype_configs)) \ + .set_backend_pattern_configs( + _get_binary_op_configs(binary_op_dtype_configs)) \ + .set_backend_pattern_configs( + _get_share_qparams_op_configs(share_qparams_op_dtype_configs)) + + +def get_tensorrt_backend_config_dict(): + """Return the `BackendConfig` for the TensorRT backend in dictionary + form.""" + return get_tensorrt_backend_config().to_dict() + + +__all__ = [ + 'get_tensorrt_backend_config', + 'get_tensorrt_backend_config_dict', +] diff --git a/mmrazor/structures/quantization/backend_default_qconfigs.py b/mmrazor/structures/quantization/backend_default_qconfigs.py deleted file mode 100644 index 590f3208a..000000000 --- a/mmrazor/structures/quantization/backend_default_qconfigs.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -SupportQtypes = ('affine') -CheckArgs = [ - 'qtype', 'w_qscheme', 'a_qscheme', 'w_fake_quant', 'a_fake_quant', - 'w_observer', 'a_observer' -] - -Default = dict( - qtype='affine', # noqa: E241 - w_qscheme=dict( - is_symmetry=True, - is_per_channel=True, - is_pot_scale=False, - bit=8, - symmetric_range=True), - a_qscheme=dict( - is_symmetry=True, - is_per_channel=False, - is_pot_scale=False, - bit=8, - symmetric_range=True), - w_fake_quant=dict(type='FakeQuantize'), - a_fake_quant=dict(type='FakeQuantize'), - w_observer=dict(type='MinMaxObserver'), - a_observer=dict(type='MinMaxObserver')) - -TensorRT = dict( - qtype='affine', # noqa: E241 - w_qscheme=dict( - is_symmetry=True, - is_per_channel=True, - is_pot_scale=False, - bit=8, - symmetric_range=True), - a_qscheme=dict( - is_symmetry=True, - is_per_channel=False, - is_pot_scale=False, - bit=8, - symmetric_range=True), - w_fake_quant=dict(type='LearnableFakeQuantize'), - a_fake_quant=dict(type='LearnableFakeQuantize'), - w_observer=dict(type='MinMaxObserver'), - a_observer=dict(type='EMAMinMaxObserver')) - -DefaultQconfigs = dict(default=Default, tensorrt=TensorRT) diff --git a/mmrazor/structures/quantization/qconfig.py b/mmrazor/structures/quantization/qconfig.py new file mode 100644 index 000000000..3dca49730 --- /dev/null +++ b/mmrazor/structures/quantization/qconfig.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Union + +import torch +from mmengine.config import Config +from torch.ao.quantization import QConfig + +from mmrazor.registry import MODELS + +RequiredArgs = [ + 'w_qscheme', 'a_qscheme', 'w_fake_quant', 'a_fake_quant', 'w_observer', + 'a_observer' +] + + +class QConfigHander(): + """Convert custom user-friendly qconfig format to torch's QConfig. + + Args: + qconfig (Dict | Config): custom user-friendly qconfig format, + including setting observers, fakequants and quantization schemes + for weights and activations. + Note: + whether quantization scheme is per-channel or not depends on + used observer, if observer support per-channel quantization, its name + should contain 'PerChannel'. + """ + + def __init__(self, qconfig: Union[Dict, Config]): + if not self.check_qconfig(qconfig): + raise ValueError('The format of qconfig is incorrect.') + else: + w_observer = MODELS.get(qconfig['w_observer']['type']) + a_observer = MODELS.get(qconfig['a_observer']['type']) + w_is_per_channel = False + a_is_per_channel = False + # import pdb;pdb.set_trace() + if 'PerChannel' in w_observer.__name__: + w_is_per_channel = True + if 'PerChannel' in a_observer.__name__: + a_is_per_channel = True + self.w_qscheme = QSchemeHander( + is_per_channel=w_is_per_channel, **qconfig['w_qscheme']) + self.a_qscheme = QSchemeHander( + is_per_channel=a_is_per_channel, **qconfig['a_qscheme']) + + w_fake_quant = MODELS.get(qconfig['w_fake_quant']['type']) + w_observer_kwargs = self.w_qscheme.to_observer_params() + a_fake_quant = MODELS.get(qconfig['a_fake_quant']['type']) + a_observer_kwargs = self.a_qscheme.to_observer_params() + + self.w_fake_quant = w_fake_quant.with_args( + observer=w_observer, **w_observer_kwargs) + self.a_fake_quant = a_fake_quant.with_args( + observer=a_observer, **a_observer_kwargs) + + @staticmethod + def check_qconfig(qconfig: Union[Dict, Config]): + """Check whether the passed qconfig's format meets requirement.""" + is_pass = True + for arg in RequiredArgs: + val = qconfig.get(arg, None) + if isinstance(val, dict) and arg in qconfig.keys(): + continue + else: + is_pass = False + break + return is_pass + + def convert(self): + """Generate torch's QConfig with built fake_quants.""" + torch_qconfig = QConfig( + weight=self.w_fake_quant, activation=self.a_fake_quant) + return torch_qconfig + + +class QSchemeHander(object): + """Convert the qscheme of custom user-friendly qconfig to args needed in + observers. + + Args: + qdtype (str): Quantization dtype. It should is 'quint8' or 'qint8', + and should be supported by the deploy backend. Defaults to 'quint8' + bit (int): Quantization bit number. Defaults to 8. + is_symmetry (bool): Is symmetry quantization or not. Defaults to True. + is_per_channel (bool): Is per-channel quantization or not. + Defaults to False. + """ + + def __init__(self, + qdtype: str = 'quint8', + bit: int = 8, + is_symmetry: bool = True, + is_per_channel: bool = False, + **kwargs): + assert qdtype in ('quint8', 'qint8'), \ + 'qdtype is incorrect, it should be quint8 or qint8.' + self.qdtype = qdtype + self.bit = bit + self.is_symmetry = is_symmetry + self.is_per_channel = is_per_channel + + if self.is_per_channel: + self.torch_qscheme = torch.per_channel_symmetric \ + if self.is_symmetry else torch.per_channel_affine + else: + self.torch_qscheme = torch.per_tensor_symmetric \ + if self.is_symmetry else torch.per_tensor_affine + if 'is_symmetric_range' in kwargs: + self.is_symmetric_range = kwargs['is_symmetric_range'] + del kwargs['is_symmetric_range'] + else: + self.is_symmetric_range = False + self.kwargs = kwargs + + def to_observer_params(self): + """Generate the args needed in observers.""" + if self.qdtype == 'quint8': + quant_min = 0 + quant_max = 2**self.bit - 1 + else: + quant_max = 2**(self.bit - 1) - 1 + if self.is_symmetric_range: + quant_min = -2**(self.bit - 1) + 1 + else: + quant_min = -2**(self.bit - 1) + + # `dtype` will be same as BackenConfig's + naive_para = { + 'dtype': torch.quint8 if self.qdtype == 'quint8' else torch.qint8, + 'quant_min': quant_min, + 'quant_max': quant_max, + 'qscheme': self.torch_qscheme, + 'reduce_range': False + } + if self.is_per_channel: + naive_para['ch_axis'] = 0 + all_para = self.kwargs.copy() + all_para.update(naive_para) + return all_para + + def __str__(self): + """Print generated args for observers.""" + return f'dtype: {self.dtype} / bit: {self.bit} / is_symmetry: {self.is_symmetry} / \ + is_per_channel: {self.is_per_channel} \ + / extra_kwargs: {self.kwargs}' + + +if __name__ == '__main__': + from mmrazor.models.fake_quants import register_torch_fake_quants + from mmrazor.models.observers import register_torch_observers + register_torch_observers() + register_torch_fake_quants() + + qconfig = dict( + w_observer=dict(type='mmrazor.MovingAveragePerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), + ) + from mmengine.config import Config + qconfig = Config(qconfig) + torch_qconfig = QConfigHander(qconfig).convert() + print(torch_qconfig) diff --git a/mmrazor/structures/quantization/qscheme.py b/mmrazor/structures/quantization/qscheme.py deleted file mode 100644 index 24c41832e..000000000 --- a/mmrazor/structures/quantization/qscheme.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - - -class QuantizeScheme(object): - """Custom QScheme. Refer to: - https://github.com/pytorch/pytorch/blob/master/c10/core/QScheme.h. - - Args: - bit (int, optional): Bit number. Defaults to 8. - is_symmetry (bool, optional): Is symmetry quantization or not. Defaults - to True. - is_per_channel (bool, optional): Is per-channel quantization or not. - Defaults to False. - is_pot_scale (bool, optional): Indicate whether scale is power of two. - Defaults to False. - """ - - def __init__(self, - bit=8, - is_symmetry=True, - is_per_channel=False, - is_pot_scale=False, - **kwargs): - self.bit = bit - self.is_symmetry = is_symmetry - self.is_per_channel = is_per_channel - self.is_pot_scale = is_pot_scale - - if self.is_per_channel: - self.torch_qscheme = torch.per_channel_symmetric \ - if self.is_symmetry else torch.per_channel_affine - else: - self.torch_qscheme = torch.per_tensor_symmetric \ - if self.is_symmetry else torch.per_tensor_affine - if 'is_symmetric_range' in kwargs: - self.is_symmetric_range = kwargs['is_symmetric_range'] - del kwargs['is_symmetric_range'] - else: - self.is_symmetric_range = False - self.kwargs = kwargs - - def to_observer_params(self): - quant_min = 0 - quant_max = 2**self.bit - 1 - if self.is_symmetry: - quant_max = 2**(self.bit - 1) - 1 - if self.is_symmetric_range: - quant_min = -2**(self.bit - 1) + 1 - else: - quant_min = -2**(self.bit - 1) - - naive_para = { - 'quant_min': quant_min, - 'quant_max': quant_max, - 'dtype': torch.qint8 if self.is_symmetry else torch.quint8, - 'is_pot_scale': self.is_pot_scale, - 'qscheme': self.torch_qscheme, - 'reduce_range': False, - 'ch_axis': 0 if self.is_per_channel else -1 - } - naive_para.update(self.kwargs) - return naive_para - - def __str__(self): - return f'bit: {self.bit} / is_symmetry: {self.is_symmetry} / \ - is_per_channel: {self.is_per_channel} / is_pot_scale: \ - {self.is_pot_scale} / extra_kwargs: {self.kwargs}' diff --git a/tests/test_models/test_observers/test_observer.py b/tests/test_models/test_observers/test_observer.py deleted file mode 100644 index 5b99cb0fc..000000000 --- a/tests/test_models/test_observers/test_observer.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import unittest -from unittest import TestCase - -import torch -import torch.nn as nn -from torch.ao.quantization import QConfig -from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx -from torchvision.models import resnet18 - -from mmrazor.models.observers import MinMaxFloorObserver - - -class ToyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - # TODO - - -class TestMinMaxObserver(TestCase): - """TODO. - - Args: - TestCase (_type_): _description_ - """ - - def test_init(self): - pass - - def test_prepare(self): - pass - - def test_convert(self): - pass - - def test_states(self): - pass - - def test_forward(self): - pass - - -class TestLSQObserver(TestMinMaxObserver): - pass - - -class TestMinMaxFloorObserver(TestMinMaxObserver): - - def setUp(self) -> None: - self.model_fp = resnet18() - self.w_qscheme = dict( - dtype=torch.qint8, qscheme=torch.per_tensor_affine) - self.a_qscheme = dict( - dtype=torch.quint8, qscheme=torch.per_tensor_affine) - - def test_init(self) -> None: - with self.assertRaises(NotImplementedError): - _ = MinMaxFloorObserver( - dtype=torch.quint8, - qscheme=torch.per_tensor_symmetric, - reduce_range=True) - - def test_prepare(self) -> None: - flag = False - model_to_quantize = copy.deepcopy(self.model_fp) - model_to_quantize.eval() - qconfig_dict = { - '': - QConfig( - activation=MinMaxFloorObserver.with_args(**self.a_qscheme), - weight=MinMaxFloorObserver.with_args(**self.w_qscheme)) - } - prepared_model = prepare_fx(model_to_quantize, qconfig_dict) - for m in prepared_model.modules(): - if isinstance(m, MinMaxFloorObserver): - flag = True - break - self.assertTrue(flag) - - def test_convert(self) -> None: - flag = True - model_to_quantize = copy.deepcopy(self.model_fp) - model_to_quantize.eval() - qconfig_dict = { - '': - QConfig( - activation=MinMaxFloorObserver.with_args(**self.a_qscheme), - weight=MinMaxFloorObserver.with_args(**self.w_qscheme)) - } - prepared_model = prepare_fx(model_to_quantize, qconfig_dict) - prepared_model(torch.randn(1, 3, 224, 224)) - quantized_model = convert_fx(prepared_model) - for m in quantized_model.modules(): - if isinstance(m, MinMaxFloorObserver): - flag = False - break - self.assertTrue(flag) - - def test_states(self) -> None: - test_input = torch.Tensor([6., -8.]) - observer = MinMaxFloorObserver(**self.w_qscheme) - self.assertEqual( - [observer.min_val, observer.max_val], - [torch.tensor(float('inf')), - torch.tensor(float('-inf'))]) - observer.forward(test_input) - # per_tensor_affine - scale, zero_point = observer.calculate_qparams() - self.assertEqual(zero_point.item(), 18) - - def test_forward(self) -> None: - test_input = torch.Tensor([1., -1.]) - observer = MinMaxFloorObserver(**self.w_qscheme) - test_output = observer.forward(test_input) - self.assertIs(test_input, test_output) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 6652cb943..009640684 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -12,8 +12,6 @@ from mmrazor.models.algorithms.base import BaseAlgorithm from mmrazor.models.mutables import OneShotMutableOP from mmrazor.registry import MODELS -from mmrazor.structures import load_fix_subnet -from mmrazor.utils import ValidFixMutable @MODELS.register_module() @@ -46,15 +44,13 @@ class MockAlgorithm(BaseAlgorithm): def __init__(self, architecture: Union[BaseModel, Dict], - fix_subnet: Optional[ValidFixMutable] = None): + _return_architecture_: Optional[bool] = None): super().__init__(architecture) - if fix_subnet is not None: - # According to fix_subnet, delete the unchosen part of supernet - load_fix_subnet(self, fix_subnet, prefix='architecture.') - self.is_supernet = False + if _return_architecture_ is True: + self.return_model = self.architecture else: - self.is_supernet = True + self.return_model = self class TestRegistry(TestCase): @@ -72,16 +68,18 @@ def test_build_razor_from_cfg(self): # model = MODELS.build(self.arch_cfg_path) # self.assertIsNotNone(model) - # test fix subnet + # test return architecture cfg = Config.fromfile( - 'tests/data/test_registry/registry_subnet_config.py') + 'tests/data/test_registry/registry_architecture_config.py') model = MODELS.build(cfg.model) + self.assertTrue(isinstance(model.return_model, MockModel)) - # test return architecture + # test return model cfg = Config.fromfile( 'tests/data/test_registry/registry_architecture_config.py') + cfg.model.pop('_return_architecture_') model = MODELS.build(cfg.model) - self.assertTrue(isinstance(model, BaseModel)) + self.assertTrue(isinstance(model.return_model, MockAlgorithm)) def test_build_subnet_prune_from_cfg_by_mutator(self): mutator_cfg = fileio.load('tests/data/test_registry/subnet.json') diff --git a/tests/test_structures/test_qconfig.py b/tests/test_structures/test_qconfig.py new file mode 100644 index 000000000..045b02c83 --- /dev/null +++ b/tests/test_structures/test_qconfig.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from unittest import TestCase + +import torch +from mmengine.config import Config +from torch.ao.quantization import QConfig + +from mmrazor.models.fake_quants import register_torch_fake_quants +from mmrazor.models.observers import register_torch_observers +from mmrazor.structures import QConfigHander, QSchemeHander + +register_torch_observers() +register_torch_fake_quants() + + +class TestQSchemeHander(TestCase): + + def test_init(self): + # per_channel + qscheme = QSchemeHander(is_symmetry=True, is_per_channel=True) + assert qscheme.torch_qscheme is torch.per_channel_symmetric + + # per_tensor + qscheme = QSchemeHander(is_symmetry=True, is_per_channel=False) + assert qscheme.torch_qscheme is torch.per_tensor_symmetric + + # qdtype is incorrect + self.assertRaises(AssertionError, QSchemeHander, 'float') + + # is_symmetric_range + kwargs = {'is_symmetric_range': True} + qscheme = QSchemeHander(**kwargs) + assert qscheme.is_symmetric_range is True + + def test_to_observer_params(self): + # qdtype = quint8 + ret_params = QSchemeHander(qdtype='quint8').to_observer_params() + assert ret_params['dtype'] == torch.quint8 + assert ret_params['quant_min'] == 0 and ret_params['quant_max'] == 255 + + # qdtype = qint8, is_symmetric_range=False + ret_params = QSchemeHander(qdtype='qint8').to_observer_params() + assert ret_params['dtype'] == torch.qint8 + assert ret_params['quant_min'] == -128 and ret_params[ + 'quant_max'] == 127 + + # qdtype = qint8, is_symmetric_range=True + ret_params = QSchemeHander( + qdtype='qint8', is_symmetric_range=True).to_observer_params() + assert ret_params['quant_min'] == -127 and ret_params[ + 'quant_max'] == 127 + + # per_channel + ret_params = QSchemeHander(is_per_channel=True).to_observer_params() + assert ret_params['ch_axis'] == 0 + + # per_tensor + ret_params = QSchemeHander(is_per_channel=False).to_observer_params() + assert 'ch_axis' not in ret_params.keys() + + +class TestQConfigHander(TestCase): + + def setUp(self): + self.qconfig_dict = dict( + w_observer=dict(type='MovingAveragePerChannelMinMaxObserver'), + a_observer=dict(type='MovingAveragePerChannelMinMaxObserver'), + w_fake_quant=dict(type='FakeQuantize'), + a_fake_quant=dict(type='FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', + bit=8, + is_symmetry=True, + is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), + ) + self.qconfig = Config(self.qconfig_dict) + + def test_check_qconfig(self): + assert QConfigHander.check_qconfig(self.qconfig_dict) is True + assert QConfigHander.check_qconfig(self.qconfig) is True + qconfig_dict = copy.copy(self.qconfig_dict) + print(qconfig_dict) + qconfig_dict.pop('w_observer') + assert QConfigHander.check_qconfig(qconfig_dict) is False + + def test_init(self): + # test dict init + qconfig = QConfigHander(self.qconfig_dict) + assert hasattr(qconfig, 'w_qscheme') + assert hasattr(qconfig, 'a_qscheme') + assert hasattr(qconfig, 'w_fake_quant') + assert hasattr(qconfig, 'a_fake_quant') + + # test mmengine's Config init + qconfig = QConfigHander(self.qconfig) + assert hasattr(qconfig, 'w_qscheme') + assert hasattr(qconfig, 'a_qscheme') + assert hasattr(qconfig, 'w_fake_quant') + assert hasattr(qconfig, 'a_fake_quant') + + # per_channel + assert qconfig.w_qscheme.is_per_channel is True + assert qconfig.a_qscheme.is_per_channel is True + + def test_convert(self): + qconfig = QConfigHander(self.qconfig) + torch_qconfig = qconfig.convert() + assert isinstance(torch_qconfig, QConfig) diff --git a/tools/ckpt_demo.py b/tools/ckpt_demo.py deleted file mode 100644 index ee257390c..000000000 --- a/tools/ckpt_demo.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch - -ckpt_path = '/mnt/lustre/humu/experiments/adaround/quantizied.pth' -# ckpt_path = -# '/mnt/petrelfs/humu/share/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' -# ckpt_path = '/tmp/humu/resnet18_uniform8/checkpoint.pth.tar' -# ckpt_path = '/tmp/humu/resnet18_uniform8/quantized_checkpoint.pth.tar' - -state_dict = torch.load(ckpt_path, map_location='cpu') - -for k, v in state_dict['state_dict'].items(): - print(k) diff --git a/tools/debug.py b/tools/debug.py deleted file mode 100644 index 5d594cff8..000000000 --- a/tools/debug.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -import sys -import time -import numpy as np -from tqdm import tqdm - -import torch -from torch.ao.quantization import get_default_qconfig -from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx -import torch.nn as nn -from torch.utils.data import DataLoader - -import torchvision -from torchvision import datasets -from torchvision.models import resnet18, ResNet18_Weights -import torchvision.transforms as transforms - -# Set up warnings -import warnings -warnings.filterwarnings( - action='ignore', - category=DeprecationWarning, - module=r'.*' -) -warnings.filterwarnings( - action='default', - module=r'torch.ao.quantization' -) - -# Specify random seed for repeatable results -_ = torch.manual_seed(191009) - - -class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self, name, fmt=':f'): - self.name = name - self.fmt = fmt - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - def __str__(self): - fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' - return fmtstr.format(**self.__dict__) - - -def accuracy(output, target, topk=(1,)): - """Computes the accuracy over the k top predictions for the specified values of k""" - with torch.no_grad(): - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) - res.append(correct_k.mul_(100.0 / batch_size)) - return res - - -def evaluate(model, criterion, data_loader, use_cuda=False): - model.eval() - top1 = AverageMeter('Acc@1', ':6.2f') - top5 = AverageMeter('Acc@5', ':6.2f') - cnt = 0 - with torch.no_grad(): - for image, target in data_loader: - if use_cuda: - image = image.cuda() - target = target.cuda() - output = model(image) - loss = criterion(output, target) - cnt += 1 - acc1, acc5 = accuracy(output, target, topk=(1, 5)) - top1.update(acc1[0], image.size(0)) - top5.update(acc5[0], image.size(0)) - print('') - - return top1, top5 - -def print_size_of_model(model): - if isinstance(model, torch.jit.RecursiveScriptModule): - torch.jit.save(model, "temp.p") - else: - torch.jit.save(torch.jit.script(model), "temp.p") - print("Size (MB):", os.path.getsize("temp.p")/1e6) - os.remove("temp.p") - -def prepare_data_loaders(data_path, batch_size=4): - transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) - trainset = torchvision.datasets.CIFAR10(data_path, train=True, transform=transform) - trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) - testset = torchvision.datasets.CIFAR10(data_path, train=False, transform=transform) - testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2) - - return trainloader, testloader - -data_path = '/nvme/dataset/cifar10/' -saved_model_dir = '../experiments/valid_convert' -os.makedirs(saved_model_dir, exist_ok=True) - -_, data_loader_test = prepare_data_loaders(data_path) -criterion = nn.CrossEntropyLoss() -weights = ResNet18_Weights.DEFAULT -float_model = resnet18(weights=weights).cuda() -float_model.eval() - - -# deepcopy the model since we need to keep the original model around -import copy -model_to_quantize = copy.deepcopy(float_model) -model_to_quantize.eval() -qconfig = get_default_qconfig("fbgemm") -qconfig_dict = {"": qconfig} -prepared_model = prepare_fx(model_to_quantize, qconfig_dict) -# print(prepared_model.graph) - -def calibrate(model, data_loader): - model.eval() - with torch.no_grad(): - for image, target in tqdm(data_loader): - model(image.cuda()) -calibrate(prepared_model, data_loader_test) - -quantized_model = convert_fx(prepared_model) -# print(quantized_model) - -print("Size of model before quantization") -print_size_of_model(float_model) -print("Size of model after quantization") -print_size_of_model(quantized_model) -top1, top5 = evaluate(float_model, criterion, data_loader_test, use_cuda=True) -print("[float model] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg)) -top1, top5 = evaluate(quantized_model, criterion, data_loader_test) -print("[before serilaization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg)) - -fx_graph_mode_model_file_path = os.path.join(saved_model_dir, "resnet18_fx_graph_mode_quantized.pth") - -# save with state_dict -torch.save(quantized_model.state_dict(), fx_graph_mode_model_file_path) -import copy -model_to_quantize = copy.deepcopy(float_model) -prepared_model = prepare_fx(model_to_quantize, {"": qconfig}) -loaded_quantized_model = convert_fx(prepared_model) -loaded_quantized_model.load_state_dict(torch.load(fx_graph_mode_model_file_path)) -top1, top5 = evaluate(loaded_quantized_model, criterion, data_loader_test) -print("[after serilaization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg)) -print('pipeline finish') \ No newline at end of file diff --git a/tools/model_converters/convert_quant_ckpt.py b/tools/model_converters/convert_quant_ckpt.py new file mode 100644 index 000000000..9fbb06125 --- /dev/null +++ b/tools/model_converters/convert_quant_ckpt.py @@ -0,0 +1,53 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +from pathlib import Path + +import torch + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Convert quantized checkpoint to deploy') + parser.add_argument('checkpoint', help='input checkpoint filename') + parser.add_argument('--out-path', help='save checkpoint path') + parser.add_argument( + '--inplace', action='store_true', help='replace origin ckpt') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + checkpoint = torch.load(args.checkpoint, map_location='cpu') + new_state_dict = dict() + new_meta = checkpoint['meta'] + + for key, value in checkpoint['state_dict'].items(): + if key.startswith('qmodels.predict.'): + new_key = key.replace('qmodels.predict.', '') + if '_val' in new_key and 'weight_fake_quant' in new_key: + new_key = new_key.replace('_val', '_vals') + new_state_dict[new_key] = value + # if key.startswith('architecture.'): + # new_key = key.replace('architecture.', '') + # new_state_dict[new_key] = value + + checkpoint = dict() + checkpoint['meta'] = new_meta + checkpoint['state_dict'] = new_state_dict + + if args.inplace: + torch.save(checkpoint, args.checkpoint) + else: + ckpt_path = Path(args.checkpoint) + ckpt_name = ckpt_path.stem + if args.out_path: + ckpt_dir = Path(args.out_path) + else: + ckpt_dir = ckpt_path.parent + new_ckpt_path = ckpt_dir / f'{ckpt_name}_deploy.pth' + torch.save(checkpoint, new_ckpt_path) + + +if __name__ == '__main__': + main() diff --git a/tools/ptq_calibrate.py b/tools/ptq.py similarity index 100% rename from tools/ptq_calibrate.py rename to tools/ptq.py diff --git a/tools/tracer_demo.py b/tools/tracer_demo.py deleted file mode 100644 index 88334d6aa..000000000 --- a/tools/tracer_demo.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy - -import torch -import torch.fx as fx -from mmengine.config import Config -from mmengine.registry import MODELS - -from mmrazor.models.task_modules.tracer import custom_symbolic_trace - -cfg_path = 'configs/quantization/ptq/demo.py' -_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear) - - -def extract_subgraph(graphmodule, block_slice): - subgraph = copy.deepcopy(graphmodule.graph) - block_start, block_end = block_slice[:2] - for node in subgraph.nodes: - if node.name == 'inputs': - input_node = node - if node.name == block_start.name: - node.replace_input_with(node.prev, input_node) - if node.name == block_end.name: - output_node = node - if node.op == 'output': - node.replace_input_with(node.prev, output_node) - subgraph.lint() - subgraph_module = fx.GraphModule(graphmodule, subgraph) - subgraph_module.graph.eliminate_dead_code() - subgraph_module.recompile() - return subgraph_module - - -def extract_blocks(graphmodule, key_word='layer'): - block_slices = [] - block_slice = [] - pre_stage_index, pre_block_index = 0, 0 - cur_stage_index, cur_block_index = 0, 0 - for node in graphmodule.graph.nodes: - if key_word not in node.name: - continue - else: - items = node.name.split('_') - for i, item in enumerate(items): - if key_word in item: - cur_stage_index = int(item[5:]) - cur_block_index = int(items[i + 1]) - break - if (cur_block_index != pre_block_index) or (cur_stage_index != - pre_stage_index): - block_slice.append(node.prev) - if len(block_slice) == 2: - block_slices.append(block_slice) - block_slice = [] - block_slice.append(node) - - pre_stage_index, pre_block_index = cur_stage_index, cur_block_index - - return block_slices - - -def extract_layers(graphmodule, layer_types): - layer_slices = [] - for node in graphmodule.graph.nodes: - if node.op == 'call_module': - m = node.graph.owning_module.get_submodule(node.target) - if isinstance(m, _ADAROUND_SUPPORT_TYPE): - layer_slices.append((node, node)) - return layer_slices - - -def main(): - # load config - cfg = Config.fromfile(cfg_path) - model = MODELS.build(cfg.model) - symbolic_traced = custom_symbolic_trace( - model, concrete_args={'mode': 'tensor'}) - # block_slices = extract_blocks(symbolic_traced) - block_slices = extract_layers( - symbolic_traced, layer_types=_ADAROUND_SUPPORT_TYPE) - - for b in block_slices: - print(b[0].name, b[1].name) - - print('#' * 100) - subgraph = extract_subgraph(symbolic_traced, block_slices[0]) - print(subgraph.code) - for name, layer in subgraph.named_modules(): - print(name, layer) - - -if __name__ == '__main__': - main() From 0f209c2302206bfa1ff58e95a93a201532769d25 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Mon, 9 Jan 2023 14:39:32 +0800 Subject: [PATCH 06/44] [Refactor & Doc] Refactor graph_utils and add docstring and pytest (#420) * refactor graph_utils and add docstring and pytest * fix del fakequant * delete useless codes --- mmrazor/models/quantizers/base.py | 2 +- mmrazor/models/quantizers/native_quantizer.py | 86 ++- .../models/quantizers/openvino_quantizer.py | 74 +-- .../models/task_modules/tracer/fx/__init__.py | 16 +- .../task_modules/tracer/fx/graph_utils.py | 339 ++++++++++-- .../test_task_modules/test_graph_utils.py | 499 ++++++++++++++++++ 6 files changed, 906 insertions(+), 110 deletions(-) create mode 100644 tests/test_models/test_task_modules/test_graph_utils.py diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py index 4d1adceda..d98fbd786 100644 --- a/mmrazor/models/quantizers/base.py +++ b/mmrazor/models/quantizers/base.py @@ -14,7 +14,7 @@ def __init__(self, tracer): self.tracer = TASK_UTILS.build(tracer) @abstractmethod - def prepare(self): + def prepare(self, model, graph_module): pass def swap_ff_with_fxff(self, model): diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index 8d1cd0b34..84be1edfb 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -7,6 +7,12 @@ from torch.nn.intrinsic.qat import modules as qat_fused_modules from torch.nn.qat import modules as qat_modules +from mmrazor.models.task_modules.tracer.fx import ( + del_fakequant_after_function, del_fakequant_after_method, + del_fakequant_after_module, del_fakequant_after_op, + del_fakequant_before_function, del_fakequant_before_method, + del_fakequant_before_module, del_fakequant_before_op) + from mmrazor.models.utils import str2class from mmrazor.registry import MODELS from mmrazor.structures.quantization import BackendConfigs, QConfigHander @@ -42,7 +48,16 @@ class NativeQuantizer(BaseQuantizer): def __init__(self, global_qconfig, no_observer_modules=None, - tracer=dict(type='CustomTracer')): + tracer=dict(type='CustomTracer'), + extra_redundant_fakequants=dict( + extra_module_prev_wo_fakequant=tuple(), + extra_module_next_wo_fakequant=tuple(), + extra_function_prev_wo_fakequant = tuple(), + extra_function_next_wo_fakequant = tuple(), + extra_method_prev_wo_fakequant = tuple(), + extra_method_next_wo_fakequant = tuple(), + extra_op_prev_wo_fakequant = tuple(), + extra_op_next_wo_fakequant = tuple())): super().__init__(tracer) self.qconfig = QConfigHander(global_qconfig) if self.qconfig.w_qscheme.is_per_channel: @@ -67,6 +82,8 @@ def __init__(self, self.backend_config = BackendConfigs[self.backend] self.example_inputs = (torch.randn(1, 3, 224, 224), ) + self.extra_redundant_fakequants = extra_redundant_fakequants + @property def backend(self): return 'native' @@ -91,6 +108,7 @@ def prepare(self, model, graph_module): node_name_to_scope=self.tracer.node_name_to_scope, example_inputs=self.example_inputs, backend_config=self.backend_config) + prepared = self.del_redundant_fakequant(prepared) return prepared @@ -128,3 +146,69 @@ def traverse(module): def prepare_for_mmdeploy(self, model, dummy_input, checkpoint): raise NotImplementedError + + def del_redundant_fakequant(self, prepared): + extra_module_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_module_prev_wo_fakequant', tuple()) + prepared = del_fakequant_before_module( + prepared, self.module_prev_wo_fakequant + extra_module_prev_wo_fakequant, inplace=True) + + extra_module_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_module_next_wo_fakequant', tuple()) + prepared = del_fakequant_after_module( + prepared, self.module_next_wo_fakequant + extra_module_next_wo_fakequant, inplace=True) + + extra_function_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_function_prev_wo_fakequant', tuple()) + prepared = del_fakequant_before_method( + prepared, self.function_prev_wo_fakequant + extra_function_prev_wo_fakequant, inplace=True) + + extra_function_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_function_next_wo_fakequant', tuple()) + prepared = del_fakequant_after_method( + prepared, self.function_next_wo_fakequant + extra_function_next_wo_fakequant, inplace=True) + + extra_method_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_method_prev_wo_fakequant', tuple()) + prepared = del_fakequant_before_function( + prepared, self.method_prev_wo_fakequant + extra_method_prev_wo_fakequant, inplace=True) + + extra_method_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_method_next_wo_fakequant', tuple()) + prepared = del_fakequant_after_function( + prepared, self.method_next_wo_fakequant + extra_method_next_wo_fakequant, inplace=True) + + extra_op_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_op_prev_wo_fakequant', tuple()) + prepared = del_fakequant_before_op( + prepared, self.op_prev_wo_fakequant + extra_op_prev_wo_fakequant, inplace=True) + + extra_op_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_op_next_wo_fakequant', tuple()) + prepared = del_fakequant_after_op( + prepared, self.op_next_wo_fakequant + extra_op_next_wo_fakequant, inplace=True) + return prepared + + @property + def module_prev_wo_fakequant(self): + return tuple() + + @property + def module_next_wo_fakequant(self): + return tuple() + + @property + def function_prev_wo_fakequant(self): + return tuple() + + @property + def function_next_wo_fakequant(self): + return tuple() + + @property + def method_prev_wo_fakequant(self): + return tuple() + + @property + def method_next_wo_fakequant(self): + return tuple() + + @property + def op_prev_wo_fakequant(self): + return tuple() + + @property + def op_next_wo_fakequant(self): + return tuple() diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index bac432baa..0b13b23f9 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -1,15 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + import torch from torch.ao.quantization import disable_observer -from torch.ao.quantization.fx import prepare -from torch.ao.quantization.quantize_fx import _fuse_fx - -from mmrazor.models.task_modules.tracer.fx import (build_graphmodule, - del_fakequant_after_module, - del_fakequant_after_target, - del_fakequant_before_module, - del_fakequant_before_target) -from mmrazor.models.utils import str2class + +from mmrazor.models.task_modules.tracer.fx import build_graphmodule from mmrazor.registry import MODELS from .native_quantizer import NativeQuantizer @@ -22,18 +17,6 @@ class OpenVINOQuantizer(NativeQuantizer): # support_w_mode = ['per_tensor', 'per_channel'] # support_a_mode = ['per_tensor'] - def __init__(self, - global_qconfig, - no_observer_modules=None, - tracer=dict(type='CustomTracer'), - remove_fakequants=dict( - module_prev=('torch.nn.ReLU6', 'torch.nn.Identity'), - module_next=('torch.nn.MaxPool2d', ), - target_prev=('output', ), - target_next=('flatten', ))): - super().__init__(global_qconfig, no_observer_modules, tracer) - self.remove_fakequants = remove_fakequants - @property def backend(self): return 'openvino' @@ -46,39 +29,6 @@ def support_w_modes(self): def support_a_modes(self): return ['per_tensor'] - def prepare(self, model, graph_module): - graph_module = _fuse_fx( - graph_module=graph_module, - is_qat=True, - backend_config=self.backend_config) - prepared = prepare( - model=graph_module, - qconfig_mapping=self.qconfig_mapping, - is_qat=True, - node_name_to_scope=self.tracer.node_name_to_scope, - example_inputs=self.example_inputs, - backend_config=self.backend_config) - module_prev = self.remove_fakequants.get('module_prev') - module_next = self.remove_fakequants.get('module_next') - target_prev = self.remove_fakequants.get('target_prev') - target_next = self.remove_fakequants.get('target_next') - - if module_prev: - prepared = del_fakequant_before_module( - prepared, str2class(module_prev), inplace=True) - if module_next: - prepared = del_fakequant_after_module( - prepared, str2class(module_next), inplace=True) - if target_prev: - prepared = del_fakequant_before_target( - prepared, target_prev, inplace=True) - if target_next: - prepared = del_fakequant_after_target( - prepared, target_next, inplace=True) - print(prepared) - - return prepared - def prepare_for_mmdeploy(self, model, dummy_input=(1, 3, 224, 224), @@ -99,3 +49,19 @@ def prepare_for_mmdeploy(self, observed_model.apply(disable_observer) return observed_model + + @property + def module_prev_wo_fakequant(self): + return (torch.nn.ReLU6, torch.nn.Identity) + + @property + def module_next_wo_fakequant(self): + return (torch.nn.MaxPool2d, ) + + @property + def method_next_wo_fakequant(self): + return ('flatten', ) + + @property + def op_prev_wo_fakequant(self): + return ('output', ) diff --git a/mmrazor/models/task_modules/tracer/fx/__init__.py b/mmrazor/models/task_modules/tracer/fx/__init__.py index 1a4d00d78..82f723f10 100644 --- a/mmrazor/models/task_modules/tracer/fx/__init__.py +++ b/mmrazor/models/task_modules/tracer/fx/__init__.py @@ -1,14 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. from .custom_tracer import (CustomTracer, UntracedMethodRegistry, build_graphmodule, custom_symbolic_trace) -from .graph_utils import (del_fakequant_after_module, - del_fakequant_after_target, - del_fakequant_before_module, - del_fakequant_before_target) +from .graph_utils import (del_fakequant_after_function, + del_fakequant_after_method, + del_fakequant_after_module, del_fakequant_after_op, + del_fakequant_before_function, + del_fakequant_before_method, + del_fakequant_before_module, del_fakequant_before_op) __all__ = [ 'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace', 'build_graphmodule', 'del_fakequant_before_module', - 'del_fakequant_after_module', 'del_fakequant_before_target', - 'del_fakequant_after_target' + 'del_fakequant_after_module', 'del_fakequant_after_function', + 'del_fakequant_before_function', 'del_fakequant_after_op', + 'del_fakequant_before_op', 'del_fakequant_before_method', + 'del_fakequant_after_method' ] diff --git a/mmrazor/models/task_modules/tracer/fx/graph_utils.py b/mmrazor/models/task_modules/tracer/fx/graph_utils.py index 952f31b4b..fe8d620c2 100644 --- a/mmrazor/models/task_modules/tracer/fx/graph_utils.py +++ b/mmrazor/models/task_modules/tracer/fx/graph_utils.py @@ -1,79 +1,291 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +from typing import Any, List, Tuple +import torch.fx from torch.ao.quantization.fake_quantize import FakeQuantizeBase -def _get_attrs(target, attrs): +def _get_attrs(target: torch.nn.Module, attr: str) -> Any: + """Get the attribute from target. - attrs = attrs.split('.') + Args: + target (torch.nn.Module): Get the attribute from target module. + attr (str): The target attribute. + + Returns: + Any: The target attribute. + """ + + attrs: List[str] = attr.split('.') for att in attrs: target = getattr(target, att, None) return target -def del_fakequant_before_target(prepared_model, target_patterns, inplace=True): - - def recursive_find_erased_nodes(node): - """Find FakeQuant before target node recursively. - - Examples: - head_fc = self.head.fc(activation_post_process_87); \ - activation_post_process_87 = None - activation_post_process_88 = \ - self.activation_post_process_88(head_fc); head_fc = None - head = self.head - _get_loss = head._get_loss(activation_post_process_88, - data_samples); \ - head = activation_post_process_88 = data_samples = None - return _get_loss - - node | node.args - -------------------- - output | (_get_loss, ) - _get_loss | (head, activation_post_process_88, - data_samples) - head | () - activation_post_process_88 | (head_fc, ) - data_samples | (None, ) - """ - if node is None: - return - if isinstance( +def recursive_find_erased_nodes(node, prepared_model): + """Find FakeQuant before target node recursively. + + Examples: + head_fc = self.head.fc(activation_post_process_87); \ + activation_post_process_87 = None + activation_post_process_88 = \ + self.activation_post_process_88(head_fc); head_fc = None + head = self.head + _get_loss = head._get_loss(activation_post_process_88, + data_samples); \ + head = activation_post_process_88 = data_samples = None + return _get_loss + + node | node.args + -------------------- + output | (_get_loss, ) + _get_loss | (head, activation_post_process_88, + data_samples) + head | () + activation_post_process_88 | (head_fc, ) + data_samples | (None, ) + """ + if node is None: + return [] + + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), FakeQuantizeBase): + return [node] + + nodes_to_erase = [] + for prev_node in node.args: + if isinstance(prev_node, torch.fx.Node): + nodes_to_erase.extend( + recursive_find_erased_nodes(prev_node, prepared_model)) + for prev_node in node.kwargs.values(): + if isinstance(prev_node, torch.fx.Node): + nodes_to_erase.extend( + recursive_find_erased_nodes(prev_node, prepared_model)) + + return nodes_to_erase + + +def del_fakequant_before_op(prepared_model: torch.fx.GraphModule, + target_ops: Tuple, + inplace: bool = True) -> torch.fx.GraphModule: + """Delete useless fakequant before nodes whose ``op`` attribute (node.op) + is in `target_ops`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_ops (tuple): Fakequants before nodes whose op attribute + (node.op) is in `target_ops` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ + + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + for node in new_graph.nodes: + if node.op in target_ops: + nodes_to_erase: List[torch.fx.Node] = recursive_find_erased_nodes( + node, prepared_model) + for to_erase in nodes_to_erase: + assert to_erase.op == 'call_module' and isinstance( + _get_attrs(prepared_model, to_erase.target), + FakeQuantizeBase) and len(to_erase.args) == 1 + to_erase.replace_all_uses_with(to_erase.args[0]) + new_graph.erase_node(to_erase) + delattr(prepared_model, to_erase.target) + + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_after_op(prepared_model: torch.fx.GraphModule, + target_ops: Tuple, + inplace: bool = True) -> torch.fx.GraphModule: + """Delete useless fakequant after nodes whose ``op`` attribute (node.op) is + in `target_ops`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_ops (tuple): Fakequants after nodes whose op attribute + (node.op) is in `target_ops` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + + target_nodes = [] + for node in new_graph.nodes: + if node.op in target_ops: + target_nodes.append(node) + + for node in new_graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared_model, node.target), FakeQuantizeBase): + assert len(node.args) == 1 + prev_node = node.args[0] + if prev_node not in target_nodes: + continue + node.replace_all_uses_with(prev_node) + new_graph.erase_node(node) + delattr(prepared_model, node.target) + + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_before_method(prepared_model: torch.fx.GraphModule, + method_patterns: Tuple, + inplace: bool = True) -> torch.fx.GraphModule: + """Delete useless fakequant before nodes whose op attribute (node.op) is + `call_method` and target attribute (node.target) is in `target_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_patterns (tuple): Fakequants before nodes whose op attribute + (node.op) is `call_method` and target attribute (node.target) is + in `target_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + for node in new_graph.nodes: + if node.op == 'call_method' and node.target in method_patterns: + nodes_to_erase: List[torch.fx.Node] = recursive_find_erased_nodes( + node, prepared_model) + for to_erase in nodes_to_erase: + assert to_erase.op == 'call_module' and isinstance( + _get_attrs(prepared_model, to_erase.target), + FakeQuantizeBase) and len(to_erase.args) == 1 + to_erase.replace_all_uses_with(to_erase.args[0]) + new_graph.erase_node(to_erase) + delattr(prepared_model, to_erase.target) + + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_after_method(prepared_model: torch.fx.GraphModule, + method_patterns: Tuple, + inplace: bool = True) -> torch.fx.GraphModule: + """Delete useless fakequant after nodes whose op attribute (node.op) is + `call_method` and target attribute (node.target) is in `target_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_patterns (tuple): Fakequants after nodes whose op attribute + (node.op) is `call_method` and target attribute (node.target) + is in `target_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ + if not inplace: + prepared_model = copy.deepcopy(prepared_model) + new_graph = copy.deepcopy(prepared_model.graph) + + target_nodes = [] + for node in new_graph.nodes: + if node.op == 'call_method' and node.target in method_patterns: + target_nodes.append(node) + + for node in new_graph.nodes: + if node.op == 'call_module' and isinstance( _get_attrs(prepared_model, node.target), FakeQuantizeBase): - nodes_to_erase.append(node) - return - for prev_node in node.args: - recursive_find_erased_nodes(prev_node) - for prev_node in node.kwargs.values(): - recursive_find_erased_nodes(prev_node) - return + assert len(node.args) == 1 + prev_node = node.args[0] + if prev_node not in target_nodes: + continue + node.replace_all_uses_with(prev_node) + new_graph.erase_node(node) + delattr(prepared_model, node.target) + + new_graph.lint() + prepared_model.graph = new_graph + return prepared_model + + +def del_fakequant_before_function( + prepared_model: torch.fx.GraphModule, + function_patterns: Tuple, + inplace: bool = True) -> torch.fx.GraphModule: + """Delete useless fakequant before nodes whose op attribute (node.op) is + `call_function` and target attribute (node.target) is in `target_patterns`. + Args: + prepared_model (GraphModule): Prepared standalone module. + target_patterns (tuple): Fakequants before nodes whose op attribute + (node.op) is `call_function` and target attribute (node.target) is + in `target_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ if not inplace: prepared_model = copy.deepcopy(prepared_model) new_graph = copy.deepcopy(prepared_model.graph) for node in new_graph.nodes: - if node.target in target_patterns: - nodes_to_erase = [] - recursive_find_erased_nodes(node) + if node.op == 'call_function' and node.target in function_patterns: + nodes_to_erase: List[torch.fx.Node] = recursive_find_erased_nodes( + node, prepared_model) for to_erase in nodes_to_erase: + assert to_erase.op == 'call_module' and isinstance( + _get_attrs(prepared_model, to_erase.target), + FakeQuantizeBase) and len(to_erase.args) == 1 to_erase.replace_all_uses_with(to_erase.args[0]) new_graph.erase_node(to_erase) delattr(prepared_model, to_erase.target) + new_graph.lint() prepared_model.graph = new_graph return prepared_model -def del_fakequant_after_target(prepared_model, target_patterns, inplace=True): +def del_fakequant_after_function(prepared_model: torch.fx.GraphModule, + function_patterns: Tuple, + inplace: bool = True) -> torch.fx.GraphModule: + """Delete useless fakequant after nodes whose op attribute (node.op) is + `call_function` and target attribute (node.target) is in `target_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + function_patterns (tuple): Fakequants after nodes whose op attribute + (node.op) is `call_function` and target attribute (node.target) is + in `target_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. Defaults to + True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ if not inplace: prepared_model = copy.deepcopy(prepared_model) new_graph = copy.deepcopy(prepared_model.graph) target_nodes = [] for node in new_graph.nodes: - if node.target in target_patterns: + if node.op == 'call_function' and node.target in function_patterns: target_nodes.append(node) for node in new_graph.nodes: @@ -86,12 +298,28 @@ def del_fakequant_after_target(prepared_model, target_patterns, inplace=True): node.replace_all_uses_with(prev_node) new_graph.erase_node(node) delattr(prepared_model, node.target) + new_graph.lint() prepared_model.graph = new_graph return prepared_model -def del_fakequant_before_module(prepared_model, module_patterns, inplace=True): +def del_fakequant_before_module(prepared_model: torch.fx.GraphModule, + module_patterns: Tuple, + inplace: bool = True) -> torch.fx.GraphModule: + """Delete useless fakequant before modules whose type are in + `module_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_patterns (tuple): Fakequants before modules whose type is in + `module_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. + Defaults to True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ if not inplace: prepared_model = copy.deepcopy(prepared_model) new_graph = copy.deepcopy(prepared_model.graph) @@ -99,21 +327,35 @@ def del_fakequant_before_module(prepared_model, module_patterns, inplace=True): if node.op == 'call_module' and isinstance( _get_attrs(prepared_model, node.target), module_patterns): to_erase = node.args[0] - if not isinstance( + if not (to_erase.op == 'call_module' and isinstance( _get_attrs(prepared_model, to_erase.target), - FakeQuantizeBase): - continue - if len(to_erase.users) > 1: + FakeQuantizeBase)): continue to_erase.replace_all_uses_with(to_erase.args[0]) new_graph.erase_node(to_erase) delattr(prepared_model, to_erase.target) + new_graph.lint() prepared_model.graph = new_graph return prepared_model -def del_fakequant_after_module(prepared_model, module_patterns, inplace=True): +def del_fakequant_after_module(prepared_model: torch.fx.GraphModule, + module_patterns: Tuple, + inplace: bool = True) -> torch.fx.GraphModule: + """Delete useless fakequant after modules whose type are in + `module_patterns`. + + Args: + prepared_model (GraphModule): Prepared standalone module. + target_patterns (tuple): Fakequants after modules whose type is in + `module_patterns` will be deleted. + inplace (bool): Can optionally do the operation in-place. + Defaults to True. + + Returns: + GraphModule: Prepared standalone module after deletion. + """ if not inplace: prepared_model = copy.deepcopy(prepared_model) new_graph = copy.deepcopy(prepared_model.graph) @@ -133,6 +375,7 @@ def del_fakequant_after_module(prepared_model, module_patterns, inplace=True): node.replace_all_uses_with(prev_node) new_graph.erase_node(node) delattr(prepared_model, node.target) + new_graph.lint() prepared_model.graph = new_graph return prepared_model diff --git a/tests/test_models/test_task_modules/test_graph_utils.py b/tests/test_models/test_task_modules/test_graph_utils.py new file mode 100644 index 000000000..7250bee95 --- /dev/null +++ b/tests/test_models/test_task_modules/test_graph_utils.py @@ -0,0 +1,499 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import operator +from unittest import TestCase + +import torch +import torch.nn as nn +from torch.ao.quantization import QConfigMapping +from torch.ao.quantization.fake_quantize import FakeQuantizeBase +from torch.ao.quantization.fx import prepare +from torch.ao.quantization.quantize_fx import _fuse_fx + +from mmrazor.models.task_modules import build_graphmodule +from mmrazor.models.task_modules.tracer import CustomTracer +from mmrazor.models.task_modules.tracer.fx import ( + del_fakequant_after_function, del_fakequant_after_method, + del_fakequant_after_module, del_fakequant_after_op, + del_fakequant_before_function, del_fakequant_before_method, + del_fakequant_before_module, del_fakequant_before_op) +from mmrazor.structures.quantization import BackendConfigs, QConfigHander + + +def _get_attrs(target, attrs): + attrs = attrs.split('.') + + for att in attrs: + target = getattr(target, att, None) + return target + + +class BasicBlock(nn.Module): + + def __init__(self, in_channels, out_channels): + super(BasicBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = out_channels + + self.norm1 = nn.BatchNorm2d(self.mid_channels) + self.norm2 = nn.BatchNorm2d(out_channels) + self.conv1 = nn.Conv2d(in_channels, self.mid_channels, 1) + self.conv2 = nn.Conv2d(self.mid_channels, out_channels, 1) + + self.relu = nn.ReLU6() + self.drop_path = nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + out = self.drop_path(out) + + out += identity + + return out + + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class ToyModel(nn.Module): + + def __init__(self): + super().__init__() + self.stem_layer = nn.Sequential( + nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.block = BasicBlock(3, 3) + self.block2 = BasicBlock(3, 3) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(3, 4) + + def forward(self, x): + x = self.stem_layer(x) + x = self.maxpool(x) + x = self.block(x) + x = self.block2(x) + x = self.gap(x) + x = x.flatten(1) + x = self.fc(x) + return x + + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + + +class TestGraphUtils(TestCase): + + def setUp(self): + self.tracer = CustomTracer() + self.backend_config = BackendConfigs['native'] + self.qconfig = QConfigHander(global_qconfig) + self.qconfig_mapping = QConfigMapping().set_global( + self.qconfig.convert()) + self.example_inputs = (torch.randn(1, 3, 224, 224), ) + + def swap_ff_with_fxff(self, model): + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.ao.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + self.swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() + + def test_del_fakequant_before_op(self): + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + op_del_prev_fakequant = ('output', ) + + prepared_after_del = del_fakequant_before_op( + prepared, op_del_prev_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op in op_del_prev_fakequant: + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op in op_del_prev_fakequant: + args = node.args + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_before_op( + prepared, op_del_prev_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op in op_del_prev_fakequant: + args = node.args + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + def test_del_fakequant_after_op(self): + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + op_del_next_fakequant = ('placeholder', ) + + prepared_after_del = del_fakequant_after_op( + prepared, op_del_next_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op in op_del_next_fakequant: + self.assertIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op in op_del_next_fakequant: + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_after_op( + prepared, op_del_next_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op in op_del_next_fakequant: + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + def test_del_fakequant_before_method(self): + + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + method_del_prev_fakequant = ('flatten', ) + + prepared_after_del = del_fakequant_before_method( + prepared, method_del_prev_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_method' and \ + node.target in method_del_prev_fakequant: + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_method' and \ + node.target in method_del_prev_fakequant: + args = node.args + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_before_method( + prepared, method_del_prev_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_method' and \ + node.target in method_del_prev_fakequant: + args = node.args + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + def test_del_fakequant_after_method(self): + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + method_del_next_fakequant = ('flatten', ) + + prepared_after_del = del_fakequant_after_method( + prepared, method_del_next_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_method' and \ + node.target in method_del_next_fakequant: + self.assertIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_method' and \ + node.target in method_del_next_fakequant: + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_after_method( + prepared, method_del_next_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_method' and \ + node.target in method_del_next_fakequant: + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + def test_del_fakequant_before_function(self): + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + function_del_prev_fakequant = (operator.add, ) + + prepared_after_del = del_fakequant_before_function( + prepared, function_del_prev_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_function' and \ + node.target in function_del_prev_fakequant: + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_function' and \ + node.target in function_del_prev_fakequant: + args = node.args + self.assertEqual(len(args), 2) + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + self.assertNotIsInstance( + _get_attrs(prepared, args[1].target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_before_function( + prepared, function_del_prev_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_function' and \ + node.target in function_del_prev_fakequant: + args = node.args + self.assertEqual(len(args), 2) + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + self.assertNotIsInstance( + _get_attrs(prepared, args[1].target), FakeQuantizeBase) + + def test_del_fakequant_after_function(self): + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + function_del_next_fakequant = (operator.add, ) + + prepared_after_del = del_fakequant_after_function( + prepared, function_del_next_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_function' and \ + node.target in function_del_next_fakequant: + self.assertIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_function' and \ + node.target in function_del_next_fakequant: + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_after_function( + prepared, function_del_next_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_function' and \ + node.target in function_del_next_fakequant: + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + def test_del_fakequant_before_module(self): + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + module_del_prev_fakequant = (torch.nn.ReLU6, torch.nn.Identity) + + prepared_after_del = del_fakequant_before_module( + prepared, module_del_prev_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), + module_del_prev_fakequant): + args = node.args + self.assertIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), + module_del_prev_fakequant): + args = node.args + if args[0].op == 'call_module': + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_before_module( + prepared, module_del_prev_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), + module_del_prev_fakequant): + args = node.args + if args[0].op == 'call_module': + self.assertNotIsInstance( + _get_attrs(prepared, args[0].target), FakeQuantizeBase) + + def test_del_fakequant_after_module(self): + model_to_quantize = ToyModel() + model_to_quantize.eval() + + self.swap_ff_with_fxff(model_to_quantize) + traced_graph = self.tracer.trace(model_to_quantize) + graph_module = build_graphmodule(model_to_quantize, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + + module_del_next_fakequant = (torch.nn.MaxPool2d, ) + + prepared_after_del = del_fakequant_after_module( + prepared, module_del_next_fakequant, inplace=False) + for node in prepared.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), + module_del_next_fakequant): + self.assertIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + for node in prepared_after_del.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), + module_del_next_fakequant): + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) + + prepared_after_del = del_fakequant_after_module( + prepared, module_del_next_fakequant, inplace=True) + for node in prepared_after_del.graph.nodes: + if node.op == 'call_module' and isinstance( + _get_attrs(prepared, node.target), + module_del_next_fakequant): + self.assertNotIsInstance( + _get_attrs(prepared, node.next.target), FakeQuantizeBase) From 82056fc258016c6d37490e405afe1b6a379fdab7 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Thu, 12 Jan 2023 13:23:03 +0800 Subject: [PATCH 07/44] Merge dev-1.x into quantize (#430) * Fix a bug in make_divisible. (#333) fix bug in make_divisible Co-authored-by: liukai * [Fix] Fix counter mapping bug (#331) * fix counter mapping bug * move judgment into get_counter_type & update UT * [Docs]Add MMYOLO projects link (#334) * [Doc] fix typos in en/usr_guides (#299) * Update README.md * Update README_zh-CN.md Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> * [Features]Support `MethodInputsRecorder` and `FunctionInputsRecorder` (#320) * support MethodInputsRecorder and FunctionInputsRecorder * fix bugs that the model can not be pickled * WIP: add pytest for ema model * fix bugs in recorder and delivery when ema_hook is used * don't register the DummyDataset * fix pytest * [Feature] Add deit-base (#332) * WIP: support deit * WIP: add deithead * WIP: fix checkpoint hook * fix data preprocessor * fix cfg * WIP: add readme * reset single_teacher_distill * add metafile * add model to model-index * fix configs and readme * [Feature]Feature map visualization (#293) * WIP: vis * WIP: add visualization * WIP: add visualization hook * WIP: support razor visualizer * WIP * WIP: wrap draw_featmap * support feature map visualization * add a demo image for visualization * fix typos * change eps to 1e-6 * add pytest for visualization * fix vis hook * fix arguments' name * fix img path * support draw inference results * add visualization doc * fix figure url * move files Co-authored-by: weihan cao * [Feature] Add kd examples (#305) * support kd for mbv2 and shufflenetv2 * WIP: fix ckpt path * WIP: fix kd r34-r18 * add metafile * fix metafile * delete * [Doc] add documents about pruning. (#313) * init * update user guide * update images * update * update How to prune your model * update how_to_use_config_tool_of_pruning.md * update doc * move location * update * update * update * add mutablechannels.md * add references Co-authored-by: liukai Co-authored-by: jacky * [Feature] PyTorch version of `PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient`. (#304) * add pkd * add pytest for pkd * fix cfg * WIP: support fcos3d * WIP: support fcos3d pkd * support mmdet3d * fix cfgs * change eps to 1e-6 and add some comments * fix docstring * fix cfg * add assert * add type hint * WIP: add readme and metafile * fix readme * update metafiles and readme * fix metafile * fix pipeline figure * [Refactor] Refactor Mutables and Mutators (#324) * refactor mutables * update load fix subnet * add DumpChosen Typehint * adapt UTs * fix lint * Add GroupMixin to ChannelMutator (temporarily) * fix type hints * add GroupMixin doc-string * modified by comments * fix type hits * update subnet format * fix channel group bugs and add UTs * fix doc string * fix comments * refactor diff module forward * fix error in channel mutator doc * fix comments Co-authored-by: liukai * [Fix] Update readme (#341) * update kl readme * update dsnas readme * fix url * Bump version to 1.0.0rc1 (#338) update version * [Feature] Add Autoformer algorithm (#315) * update candidates * update subnet_sampler_loop * update candidate * add readme * rename variable * rename variable * clean * update * add doc string * Revert "[Improvement] Support for candidate multiple dimensional search constraints." * [Improvement] Update Candidate with multi-dim search constraints. (#322) * update doc * add support type * clean code * update candidates * clean * xx * set_resource -> set_score * fix ci bug * py36 lint * fix bug * fix check constrain * py36 ci * redesign candidate * fix pre-commit * update cfg * add build_resource_estimator * fix ci bug * remove runner.epoch in testcase * [Feature] Autoformer architecture and dynamicOPs (#327) * add DynamicSequential * dynamiclayernorm * add dynamic_pathchembed * add DynamicMultiheadAttention and DynamicRelativePosition2D * add channel-level dynamicOP * add autoformer algo * clean notes * adapt channel_mutator * vit fly * fix import * mutable init * remove annotation * add DynamicInputResizer * add unittest for mutables * add OneShotMutableChannelUnit_VIT * clean code * reset unit for vit * remove attr * add autoformer backbone UT * add valuemutator UT * clean code * add autoformer algo UT * update classifier UT * fix test error * ignore * make lint * update * fix lint * mutable_attrs * fix test * fix error * remove DynamicInputResizer * fix test ci * remove InputResizer * rename variables * modify type * Continued improvements of ChannelUnit * fix lint * fix lint * remove OneShotMutableChannelUnit * adjust derived type * combination mixins * clean code * fix sample subnet * search loop fly * more annotations * avoid counter warning and modify batch_augment cfg by gy * restore * source_value_mutables restriction * simply arch_setting api * update * clean * fix ut * [Feature] Add performance predictor (#306) * add predictor with 4 handlers * [Improvement] Update Candidate with multi-dim search constraints. (#322) * update doc * add support type * clean code * update candidates * clean * xx * set_resource -> set_score * fix ci bug * py36 lint * fix bug * fix check constrain * py36 ci * redesign candidate * fix pre-commit * update cfg * add build_resource_estimator * fix ci bug * remove runner.epoch in testcase * update metric_predictor: 1. update MetricPredictor; 2. add predictor config for searching; 3. add predictor in evolution_search_loop. * add UT for predictor * add MLPHandler * patch optional.txt for predictors * patch test_evolution_search_loop * refactor apis of predictor and handlers * fix ut and remove predictor_cfg in predictor * adapt new mutable & mutator design * fix ut * remove unness assert after rebase * move predictor-build in __init__ & simplify estimator-build Co-authored-by: Yue Sun * [Feature] Add DCFF (#295) * add ChannelGroup (#250) * rebase new dev-1.x * modification for adding config_template * add docstring to channel_group.py * add docstring to mutable_channel_group.py * rm channel_group_cfg from Graph2ChannelGroups * change choice type of SequentialChannelGroup from float to int * add a warning about group-wise conv * restore __init__ of dynamic op * in_channel_mutable -> mutable_in_channel * rm abstractproperty * add a comment about VT * rm registry for ChannelGroup * MUTABLECHANNELGROUP -> ChannelGroupType * refine docstring of IndexDict * update docstring * update docstring * is_prunable -> is_mutable * update docstring * fix error in pre-commit * update unittest * add return type * unify init_xxx apit * add unitest about init of MutableChannelGroup * update according to reviews * sequential_channel_group -> sequential_mutable_channel_group Co-authored-by: liukai * Add BaseChannelMutator and refactor Autoslim (#289) * add BaseChannelMutator * add autoslim * tmp * make SequentialMutableChannelGroup accpeted both of num and ratio as choice. and supports divisior * update OneShotMutableChannelGroup * pass supernet training of autoslim * refine autoslim * fix bug in OneShotMutableChannelGroup * refactor make_divisible * fix spell error: channl -> channel * init_using_backward_tracer -> init_from_backward_tracer init_from_fx_tracer -> init_from_fx_tracer * refine SequentialMutableChannelGroup * let mutator support models with dynamicop * support define search space in model * tracer_cfg -> parse_cfg * refine * using -> from * update docstring * update docstring Co-authored-by: liukai * tmpsave * migrate ut * tmpsave2 * add loss collector * refactor slimmable and add l1-norm (#291) * refactor slimmable and add l1-norm * make l1-norm support convnd * update get_channel_groups * add l1-norm_resnet34_8xb32_in1k.py * add pretrained to resnet34-l1 * remove old channel mutator * BaseChannelMutator -> ChannelMutator * update according to reviews * add readme to l1-norm * MBV2_slimmable -> MBV2_slimmable_config Co-authored-by: liukai * update config * fix md & pytorch support <1.9.0 in batchnorm init * Clean old codes. (#296) * remove old dynamic ops * move dynamic ops * clean old mutable_channels * rm OneShotMutableChannel * rm MutableChannel * refine * refine * use SquentialMutableChannel to replace OneshotMutableChannel * refactor dynamicops folder * let SquentialMutableChannel support float Co-authored-by: liukai * fix ci * ci fix py3.6.x & add mmpose * ci fix py3.6.9 in utils/index_dict.py * fix mmpose * minimum_version_cpu=3.7 * fix ci 3.7.13 * fix pruning &meta ci * support python3.6.9 * fix py3.6 import caused by circular import patch in py3.7 * fix py3.6.9 * Add channel-flow (#301) * base_channel_mutator -> channel_mutator * init * update docstring * allow omitting redundant configs for channel * add register_mutable_channel_to_a_module to MutableChannelContainer * update according to reviews 1 * update according to reviews 2 * update according to reviews 3 * remove old docstring * fix error * using->from * update according to reviews * support self-define input channel number * update docstring * chanenl -> channel_elem Co-authored-by: liukai Co-authored-by: jacky * support >=3.7 * support py3.6.9 * Rename: ChannelGroup -> ChannelUnit (#302) * refine repr of MutableChannelGroup * rename folder name * ChannelGroup -> ChannelUnit * filename in units folder * channel_group -> channel_unit * groups -> units * group -> unit * update * get_mutable_channel_groups -> get_mutable_channel_units * fix bug * refine docstring * fix ci * fix bug in tracer Co-authored-by: liukai * update new channel config format * update pruning refactor * update merged pruning * update commit * fix dynamic_conv_mixin * update comments: readme&dynamic_conv_mixins.py * update readme * move kl softmax channel pooling to op by comments * fix comments: fix redundant & split README.md * dcff in ItePruneAlgorithm * partial dynamic params for fuseconv * add step_freq & prune_time check * update comments * update comments * update comments * fix ut * fix gpu ut & revise step_freq in ItePruneAlgorithm * update readme * revise ItePruneAlgorithm * fix docs * fix dynamic_conv attr * fix ci Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: zengyi.vendor Co-authored-by: jacky * [Fix] Fix optional requirements (#357) * fix optional requirements * fix dcff ut * fix import with get_placeholder * supplement the previous commit * [Fix] Fix configs of wrn models and ofd. (#361) * 1.revise the configs of wrn22, wrn24, and wrn40. 2.revise the data_preprocessor of ofd_backbone_resnet50_resnet18_8xb16_cifar10 * 1.Add README for vanilla-wrm. * 1.Revise readme of wrn Co-authored-by: zhangzhongyu * [Fix] Fix bug on mmrazor visualization, mismatch argument in define and use. (#356) fix bug on mmrazor visualization, mismatch argument in define and use. Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com> * fix bug in benchmark_test (#364) fix bug in configs Co-authored-by: Your Name * [FIX] Fix wrn configs (#368) * fix wrn configs * fix wrn configs * update online wrn model weight * [Fix] fix bug on pkd config. Wrong import filename. (#373) * [CI] Update ci to torch1.13 (#380) update ci to torch1.13 * [Feature] Add BigNAS algorithm (#219) * add calibrate-bn-statistics * add test calibrate-bn-statistics * fix mixins * fix mixins * fix mixin tests * remove slimmable channel mutable and refactor dynamic op * refact dynamic batch norm * add progressive dynamic conv2d * add center crop dynamic conv2d * refactor dynamic directory * refactor dynamic sequential * rename length to depth in dynamic sequential * add test for derived mutable * refactor dynamic op * refactor api of dynamic op * add derive mutable mixin * addbignas algorithm * refactor bignas structure * add input resizer * add input resizer to bignas * move input resizer from algorithm into classifier * remove compnents * add attentive mobilenet * delete json file * nearly(less 0.2) align inference accuracy with gml * move mutate seperated in bignas mobilenet backbone * add zero_init_residual * add set_dropout * set dropout in bignas algorithm * fix registry * add subnet yaml and nearly align inference accuracy with gml * add rsb config for bignas * remove base in config * add gml bignas config * convert to iter based * bignas forward and backward fly * fix merge conflict * fix dynamicseq bug * fix bug and refactor bignas * arrange configs of bignas * fix typo * refactor attentive_mobilenet * fix channel mismatch due to registion of DerivedMutable * update bignas & fix se channel mismatch * add AutoAugmentV2 & remove unness configs * fix lint * recover channel assertion in channel unit * fix a group bug * fix comments * add docstring * add norm in dynamic_embed * fix search loop & other minor changes * fix se expansion * minor change * add ut for bignas & attentive_mobilenet * fix ut * update bignas readme * rm unness ut & supplement get_placeholder * fix lint * fix ut * add subnet deployment in downstream tasks. * minor change * update ofa backbone * minor fix * Continued improvements of searchable backbone * minor change * drop ratio in backbone * fix comments * fix ci test * fix test * add dynamic shortcut UT * modify strategy to fit bignas * fix test * fix bug in neck * fix error * fix error * fix yaml * save subnet ckpt * merge autoslim_val/test_loop into subnet_val_loop * move calibrate_bn_mixin to utils * fix bugs and add docstring * clean code * fix register bug * clean code * update Co-authored-by: wangshiguang Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: aptsunny Co-authored-by: sunyue1 * [Bug] Fix ckpt (#372) fix ckpt * [Feature] Add tools to convert distill ckpt to student-only ckpt. (#381) * [Feature] Add tools to convert distill ckpt to student-only ckpt. * fix bug. * add --model-only to only save model. * Make changes accroding to PR review. * Enhance the Abilities of the Tracer for Pruning. (#371) * tmp * add new mmdet models * add docstring * pass test and pre-commit * rm razor tracer * update fx tracer, now it can automatically wrap methods and functions. * update tracer passed models * add warning for torch <1.12.0 fix bug for python3.6 update placeholder to support placeholder.XXX * fix bug * update docs * fix lint * fix parse_cfg in configs * restore mutablechannel * test ite prune algorithm when using dist * add get_model_from_path to MMModelLibrrary * add mm models to DefaultModelLibrary * add uts * fix bug * fix bug * add uts * add uts * add uts * add uts * fix bug * restore ite_prune_algorithm * update doc * PruneTracer -> ChannelAnalyzer * prune_tracer -> channel_analyzer * add test for fxtracer * fix bug * fix bug * PruneTracer -> ChannelAnalyzer refine * CustomFxTracer -> MMFxTracer * fix bug when test with torch<1.12 * update print log * fix lint * rm unuseful code Co-authored-by: liukai Co-authored-by: jacky Co-authored-by: Your Name Co-authored-by: liukai * fix bug in placer holder (#395) * fix bug in placer holder * remove redundent comment Co-authored-by: liukai * Add get_prune_config and a demo config_pruning (#389) * update tools and test * add demo * disable test doc * add switch for test tools and test_doc * fix bug * update doc * update tools name * mv get_channel_units Co-authored-by: liukai * [Improvement] Adapt OFA series with SearchableMobileNetV3 (#385) * fix mutable bug in AttentiveMobileNetV3 * remove unness code * update ATTENTIVE_SUBNET_A0-A6.yaml with optimized names * unify the sampling usage in sandwich_rule-based NAS * use alias to export subnet * update OFA configs * fix attr bug * fix comments * update convert_supernet2subnet.py * correct the way to dump DerivedMutable * fix convert index bug * update OFA configs & models * fix dynamic2static * generalize convert_ofa_ckpt.py * update input_resizer * update README.md * fix ut * update export_fix_subnet * update _dynamic_to_static * update fix_subnet UT & minor fix bugs * fix ut * add new autoaug compared to attentivenas * clean * fix act * fix act_cfg * update fix_subnet * fix lint * add docstring Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: aptsunny * [Fix]Dcff Deploy Revision (#383) * dcff deploy revision * tempsave * update fix_subnet * update mutator load * export/load_fix_subnet revision for mutator * update fix_subnet with dev-1.x * update comments * update docs * update registry * [Fix] Fix commands in README to adapt branch 1.x (#400) * update commands in README for 1.x * fix commands Co-authored-by: gaoyang07 <1546308416@qq.com> * Set requires_grad to False if the teacher is not trainable (#398) * add choice and mask of units to checkpoint (#397) * add choice and mask of units to checkpoint * update * fix bug * remove device operation * fix bug * fix circle ci error * fix error in numpy for circle ci * fix bug in requirements * restore * add a note * a new solution * save mutable_channel.mask as float for dist training * refine * mv meta file test Co-authored-by: liukai Co-authored-by: jacky * [Bug]Fix fpn teacher distill (#388) fix fpn distill * [CodeCamp #122] Support KD algorithm MGD for detection. (#377) * [Feature] Support KD algorithm MGD for detection. * use connector to beauty mgd. * fix typo, add unitest. * fix mgd loss unitest. * fix mgd connector unitest. * add model pth and log file. * add mAP. * update l1 config (#405) * add l1 config * update l1 config Co-authored-by: jacky * [Feature] Add greedy search for AutoSlim (#336) * WIP: add greedysearch * fix greedy search and add bn_training_mode to autoslim * fix cfg files * fix autoslim configs * fix bugs when converting dynamic bn to static bn * change to test loop * refactor greedy search * rebase and fix greedysearch * fix lint * fix and delete useless codes * fix pytest * fix pytest and add bn_training_mode * fix lint * add reference to AutoSlimGreedySearchLoop's docstring * sort candidate_choices * fix save subnet * delete useless codes in channel container * change files' name: convert greedy_search_loop to autoslim_greedy_search_loop * [Fix] Fix metafile (#422) * fix ckpt path in metafile and readme * fix darts file path * fix docstring in ConfigurableDistiller * fix darts * fix error * add darts of mmrazor version * delete py36 Co-authored-by: liukai * update bignas cfg (#412) * check attentivenas training * update ckpt link * update supernet log Co-authored-by: aptsunny * Bump version to 1.0.0rc2 (#423) bump version to 1.0.0rc2 Co-authored-by: liukai * fix lint * fix ci * add tmp docstring for passed ci * add tmp docstring for passed ci * fix ci * add get_placeholder for quant * add skip for unittest * fix package placeholder bug * add version judgement in __init__ * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit * update prev commit Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com> Co-authored-by: liukai Co-authored-by: Yang Gao Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com> Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com> Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com> Co-authored-by: jacky Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com> Co-authored-by: Yue Sun Co-authored-by: zengyi <31244134+spynccat@users.noreply.github.com> Co-authored-by: zengyi.vendor Co-authored-by: zhongyu zhang <43191879+wilxy@users.noreply.github.com> Co-authored-by: zhangzhongyu Co-authored-by: Xianpan Zhou <32625100+TinyTigerPan@users.noreply.github.com> Co-authored-by: Xianpan Zhou <32625100+PanDaMeow@users.noreply.github.com> Co-authored-by: Your Name Co-authored-by: P.Huang <37200926+FreakieHuang@users.noreply.github.com> Co-authored-by: qiufeng <44188071+wutongshenqiu@users.noreply.github.com> Co-authored-by: wangshiguang Co-authored-by: gaoyang07 <1546308416@qq.com> Co-authored-by: sunyue1 Co-authored-by: liukai Co-authored-by: Ming-Hsuan-Tu Co-authored-by: Yivona <120088893+yivona08@users.noreply.github.com> Co-authored-by: Yue Sun --- .github/workflows/build.yml | 38 +++++ configs/pruning/mmpose/dcff/fix_subnet.json | 4 + ...pact_pointrend_resnet50_8xb2_cityscapes.py | 4 + mmrazor/engine/__init__.py | 9 +- mmrazor/engine/runner/__init__.py | 4 +- mmrazor/engine/runner/iteprune_val_loop.py | 1 - mmrazor/engine/runner/quantization_loops.py | 15 +- mmrazor/models/algorithms/nas/autoslim.py | 2 + .../algorithms/pruning/ite_prune_algorithm.py | 4 + .../quantization/mm_architecture.py | 9 +- mmrazor/models/fake_quants/base.py | 6 +- .../models/fake_quants/torch_fake_quants.py | 8 +- mmrazor/models/losses/__init__.py | 1 - .../one_shot_channel_mutator.py | 4 +- mmrazor/models/mutators/group_mixin.py | 68 ++++++++ .../models/mutators/value_mutator/__init__.py | 5 + .../value_mutator/dynamic_value_mutator.py | 14 ++ .../mutators/value_mutator/value_mutator.py | 73 +++++++++ mmrazor/models/observers/base.py | 6 +- mmrazor/models/observers/torch_observers.py | 8 +- .../models/quantizers/academic_quantizer.py | 26 ++- mmrazor/models/quantizers/base.py | 5 +- mmrazor/models/quantizers/native_quantizer.py | 150 ++++++++++++------ .../models/quantizers/openvino_quantizer.py | 17 +- .../models/quantizers/tensorrt_quantizer.py | 12 +- .../task_modules/tracer/fx/custom_tracer.py | 70 ++++---- .../task_modules/tracer/fx/graph_utils.py | 44 ++--- .../quantization/backend_config/academic.py | 34 ++-- .../common_operator_config_utils.py | 86 ++++++---- .../quantization/backend_config/mapping.py | 23 ++- .../quantization/backend_config/native.py | 132 +++++++-------- .../quantization/backend_config/openvino.py | 15 +- .../quantization/backend_config/tensorrt.py | 15 +- mmrazor/structures/quantization/qconfig.py | 7 +- tests/data/models.py | 3 - tests/test_data.py | 8 + .../test_mutators/test_value_mutator.py | 66 ++++++++ .../test_task_modules/test_custom_tracer.py | 35 ---- .../test_task_modules/test_graph_utils.py | 49 +++++- tests/test_registry/test_registry.py | 40 +++-- tests/test_structures/test_qconfig.py | 23 ++- 41 files changed, 838 insertions(+), 305 deletions(-) create mode 100644 mmrazor/models/mutators/value_mutator/__init__.py create mode 100644 mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py create mode 100644 mmrazor/models/mutators/value_mutator/value_mutator.py create mode 100644 tests/test_models/test_mutators/test_value_mutator.py delete mode 100644 tests/test_models/test_task_modules/test_custom_tracer.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 53a184a3d..e00ed24c8 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,6 +31,44 @@ jobs: python-version: [3.7] torch: [1.6.0, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0] include: + - torch: 1.6.0 + torch_version: 1.6 + torchvision: 0.7.0 + - torch: 1.7.0 + torch_version: 1.7 + torchvision: 0.8.1 + - torch: 1.7.0 + torch_version: 1.7 + torchvision: 0.8.1 + python-version: 3.8 + - torch: 1.8.0 + torch_version: 1.8 + torchvision: 0.9.0 + - torch: 1.8.0 + torch_version: 1.8 + torchvision: 0.9.0 + python-version: 3.8 + - torch: 1.9.0 + torch_version: 1.9 + torchvision: 0.10.0 + - torch: 1.9.0 + torch_version: 1.9 + torchvision: 0.10.0 + python-version: 3.8 + - torch: 1.10.0 + torch_version: 1.10 + torchvision: 0.11.0 + - torch: 1.10.0 + torch_version: 1.10 + torchvision: 0.11.0 + python-version: 3.8 + - torch: 1.11.0 + torch_version: 1.11 + torchvision: 0.12.0 + - torch: 1.11.0 + torch_version: 1.11 + torchvision: 0.12.0 + python-version: 3.8 - torch: 1.12.0 torch_version: 1.12 torchvision: 0.13.0 diff --git a/configs/pruning/mmpose/dcff/fix_subnet.json b/configs/pruning/mmpose/dcff/fix_subnet.json index dfdcea758..f7b40f41d 100644 --- a/configs/pruning/mmpose/dcff/fix_subnet.json +++ b/configs/pruning/mmpose/dcff/fix_subnet.json @@ -54,7 +54,11 @@ "min_value":1, "min_ratio":0.9 }, +<<<<<<< HEAD "choice":0.59375 +======= + "choice":0.59374 +>>>>>>> 985a611e (Merge dev-1.x into quantize (#430)) }, "backbone.layer2.1.conv1_(0, 128)_128":{ "init_args":{ diff --git a/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py b/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py index e6c1eb031..a0d0d044a 100644 --- a/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py +++ b/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py @@ -1,7 +1,11 @@ _base_ = ['dcff_pointrend_resnet50_8xb2_cityscapes.py'] # model settings +<<<<<<< HEAD _base_.model = dict( +======= +model_cfg = dict( +>>>>>>> 985a611e (Merge dev-1.x into quantize (#430)) _scope_='mmrazor', type='sub_model', cfg=_base_.architecture, diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py index da6cec34d..603aa3d77 100644 --- a/mmrazor/engine/__init__.py +++ b/mmrazor/engine/__init__.py @@ -4,15 +4,14 @@ from .optimizers import SeparateOptimWrapperConstructor from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop, EvolutionSearchLoop, - GreedySamplerTrainLoop, SelfDistillValLoop, - SingleTeacherDistillValLoop, SlimmableValLoop, - SubnetValLoop) + GreedySamplerTrainLoop, PTQLoop, QATEpochBasedLoop, + SelfDistillValLoop, SingleTeacherDistillValLoop, + SlimmableValLoop, SubnetValLoop) __all__ = [ 'SeparateOptimWrapperConstructor', 'DumpSubnetHook', 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop', - 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'StopDistillHook', - 'DMCPSubnetHook' + 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'PTQLoop', 'QATEpochBasedLoop' ] diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py index 647d8b410..2ca6c0dbb 100644 --- a/mmrazor/engine/runner/__init__.py +++ b/mmrazor/engine/runner/__init__.py @@ -13,6 +13,6 @@ 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop', - 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'PTQLoop', - 'QATEpochBasedLoop' + 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop', + 'PTQLoop' ] diff --git a/mmrazor/engine/runner/iteprune_val_loop.py b/mmrazor/engine/runner/iteprune_val_loop.py index bbca5d53a..2a627f398 100644 --- a/mmrazor/engine/runner/iteprune_val_loop.py +++ b/mmrazor/engine/runner/iteprune_val_loop.py @@ -52,7 +52,6 @@ def _save_fix_subnet(self): file.write(fix_subnet) torch.save({'state_dict': static_model.state_dict()}, osp.join(self.runner.work_dir, weight_name)) - self.runner.logger.info( 'export finished and ' f'{subnet_name}, ' diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index 2a0aa812f..e90715910 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -4,9 +4,18 @@ import torch from mmengine.evaluator import Evaluator from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop -from torch.ao.quantization import (disable_observer, enable_fake_quant, - enable_observer) -from torch.nn.intrinsic.qat import freeze_bn_stats + +try: + from torch.ao.quantization import (disable_observer, enable_fake_quant, + enable_observer) + from torch.nn.intrinsic.qat import freeze_bn_stats +except ImportError: + from mmrazor.utils import get_placeholder + disable_observer = get_placeholder('torch>=1.13') + enable_fake_quant = get_placeholder('torch>=1.13') + enable_observer = get_placeholder('torch>=1.13') + freeze_bn_stats = get_placeholder('torch>=1.13') + from torch.utils.data import DataLoader from mmrazor.registry import LOOPS diff --git a/mmrazor/models/algorithms/nas/autoslim.py b/mmrazor/models/algorithms/nas/autoslim.py index dc8d54c0e..77bb6cacc 100644 --- a/mmrazor/models/algorithms/nas/autoslim.py +++ b/mmrazor/models/algorithms/nas/autoslim.py @@ -75,6 +75,8 @@ def __init__(self, self._optim_wrapper_count_status_reinitialized = False self.norm_training = norm_training + self.bn_training_mode = bn_training_mode + def _build_mutator(self, mutator: VALID_MUTATOR_TYPE = None) -> ChannelMutator: """Build mutator.""" diff --git a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py index 937aaa156..f510acd76 100644 --- a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py +++ b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py @@ -10,6 +10,7 @@ from mmrazor.models.mutables import MutableChannelUnit from mmrazor.models.mutators import ChannelMutator from mmrazor.registry import MODELS +from mmrazor.utils import ValidFixMutable from ..base import BaseAlgorithm LossResults = Dict[str, torch.Tensor] @@ -97,6 +98,8 @@ class ItePruneAlgorithm(BaseAlgorithm): mutator_cfg (Union[Dict, ChannelMutator], optional): The config of a mutator. Defaults to dict( type='ChannelMutator', channel_unit_cfg=dict( type='SequentialMutableChannelUnit')). + fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or + loaded dict or built :obj:`FixSubnet`. Defaults to None. data_preprocessor (Optional[Union[Dict, nn.Module]], optional): Defaults to None. target_pruning_ratio (dict, optional): The prune-target. The template @@ -118,6 +121,7 @@ def __init__(self, type='ChannelMutator', channel_unit_cfg=dict( type='SequentialMutableChannelUnit')), + fix_subnet: Optional[ValidFixMutable] = None, data_preprocessor: Optional[Union[Dict, nn.Module]] = None, target_pruning_ratio: Optional[Dict[str, float]] = None, step_freq=1, diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index c14aae08c..f5cf30f10 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -7,12 +7,17 @@ from mmengine.runner import load_checkpoint from mmengine.structures import BaseDataElement from torch import nn -from torch.ao.quantization import FakeQuantizeBase -from mmrazor.models.task_modules import build_graphmodule +from mmrazor.models.task_modules.tracer import build_graphmodule from mmrazor.registry import MODEL_WRAPPERS, MODELS from ..base import BaseAlgorithm +try: + from torch.ao.quantization import FakeQuantizeBase +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantizeBase = get_placeholder('torch>=1.13') + LossResults = Dict[str, torch.Tensor] TensorResults = Union[Tuple[torch.Tensor], torch.Tensor] PredictResults = List[BaseDataElement] diff --git a/mmrazor/models/fake_quants/base.py b/mmrazor/models/fake_quants/base.py index 1d4c6dfe0..45aed7421 100644 --- a/mmrazor/models/fake_quants/base.py +++ b/mmrazor/models/fake_quants/base.py @@ -1,4 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from torch.ao.quantization import FakeQuantize +try: + from torch.ao.quantization import FakeQuantize +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantize = get_placeholder('torch>=1.13') BaseFakeQuantize = FakeQuantize diff --git a/mmrazor/models/fake_quants/torch_fake_quants.py b/mmrazor/models/fake_quants/torch_fake_quants.py index ad1a0d966..b477929ad 100644 --- a/mmrazor/models/fake_quants/torch_fake_quants.py +++ b/mmrazor/models/fake_quants/torch_fake_quants.py @@ -2,10 +2,14 @@ import inspect from typing import List -import torch.ao.quantization.fake_quantize as torch_fake_quant_src - from mmrazor.registry import MODELS +try: + import torch.ao.quantization.fake_quantize as torch_fake_quant_src +except ImportError: + from mmrazor.utils import get_package_placeholder + torch_fake_quant_src = get_package_placeholder('torch>=1.13') + def register_torch_fake_quants() -> List[str]: """Register fake_quants in ``torch.ao.quantization.fake_quantize`` to the diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py index 3509acd5c..65e2108fd 100644 --- a/mmrazor/models/losses/__init__.py +++ b/mmrazor/models/losses/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .ab_loss import ABLoss -from .adaround_loss import AdaRoundLoss from .at_loss import ATLoss from .crd_loss import CRDLoss from .cross_entropy_loss import CrossEntropyLoss diff --git a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py index cc008b0b8..3aca98c95 100644 --- a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py @@ -4,11 +4,13 @@ from mmrazor.models.mutables import OneShotMutableChannelUnit from mmrazor.registry import MODELS +from ..group_mixin import DynamicSampleMixin from .channel_mutator import ChannelMutator, ChannelUnitType @MODELS.register_module() -class OneShotChannelMutator(ChannelMutator[OneShotMutableChannelUnit]): +class OneShotChannelMutator(ChannelMutator[OneShotMutableChannelUnit], + DynamicSampleMixin): """OneShotChannelMutator based on ChannelMutator. It use OneShotMutableChannelUnit by default. diff --git a/mmrazor/models/mutators/group_mixin.py b/mmrazor/models/mutators/group_mixin.py index 569f01ebc..3ecd44b74 100644 --- a/mmrazor/models/mutators/group_mixin.py +++ b/mmrazor/models/mutators/group_mixin.py @@ -8,6 +8,11 @@ from mmrazor.models.mutables.mutable_module import MutableModule from .base_mutator import MUTABLE_TYPE +if sys.version_info < (3, 8): + from typing_extensions import Protocol +else: + from typing import Protocol + class GroupMixin(): """A mixin for :class:`BaseMutator`, which can group mutables by @@ -259,3 +264,66 @@ def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]], f'When a mutable is set alias attribute :{alias_key},' f'the corresponding module name {mutable_name} should ' f'not be used in `custom_group` {custom_group}.') + + +class MutatorProtocol(Protocol): # pragma: no cover + + @property + def mutable_class_type(self) -> Type[BaseMutable]: + ... + + @property + def search_groups(self) -> Dict: + ... + + +class OneShotSampleMixin: + """Sample mixin for one-shot mutators.""" + + def sample_choices(self: MutatorProtocol) -> Dict: + """Sample choices for each group in search_groups.""" + random_choices = dict() + for group_id, modules in self.search_groups.items(): + random_choices[group_id] = modules[0].sample_choice() + + return random_choices + + def set_choices(self: MutatorProtocol, choices: Dict) -> None: + """Set choices for each group in search_groups.""" + for group_id, modules in self.search_groups.items(): + choice = choices[group_id] + for module in modules: + module.current_choice = choice + + +class DynamicSampleMixin(OneShotSampleMixin): + + def sample_choices(self: MutatorProtocol, kind: str = 'random') -> Dict: + """Sample choices for each group in search_groups.""" + random_choices = dict() + for group_id, modules in self.search_groups.items(): + if kind == 'max': + random_choices[group_id] = modules[0].max_choice + elif kind == 'min': + random_choices[group_id] = modules[0].min_choice + else: + random_choices[group_id] = modules[0].sample_choice() + return random_choices + + @property + def max_choice(self: MutatorProtocol) -> Dict: + """Get max choices for each group in search_groups.""" + max_choice = dict() + for group_id, modules in self.search_groups.items(): + max_choice[group_id] = modules[0].max_choice + + return max_choice + + @property + def min_choice(self: MutatorProtocol) -> Dict: + """Get min choices for each group in search_groups.""" + min_choice = dict() + for group_id, modules in self.search_groups.items(): + min_choice[group_id] = modules[0].min_choice + + return min_choice diff --git a/mmrazor/models/mutators/value_mutator/__init__.py b/mmrazor/models/mutators/value_mutator/__init__.py new file mode 100644 index 000000000..a29577bb1 --- /dev/null +++ b/mmrazor/models/mutators/value_mutator/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dynamic_value_mutator import DynamicValueMutator +from .value_mutator import ValueMutator + +__all__ = ['ValueMutator', 'DynamicValueMutator'] diff --git a/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py b/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py new file mode 100644 index 000000000..d8d081343 --- /dev/null +++ b/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py @@ -0,0 +1,14 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmrazor.models.mutables import OneShotMutableValue +from mmrazor.registry import MODELS +from ..group_mixin import DynamicSampleMixin +from .value_mutator import ValueMutator + + +@MODELS.register_module() +class DynamicValueMutator(ValueMutator, DynamicSampleMixin): + """Dynamic value mutator with type as `OneShotMutableValue`.""" + + @property + def mutable_class_type(self): + return OneShotMutableValue diff --git a/mmrazor/models/mutators/value_mutator/value_mutator.py b/mmrazor/models/mutators/value_mutator/value_mutator.py new file mode 100644 index 000000000..5127cbe37 --- /dev/null +++ b/mmrazor/models/mutators/value_mutator/value_mutator.py @@ -0,0 +1,73 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Type + +from torch.nn import Module + +from mmrazor.models.mutables import MutableValue +from mmrazor.registry import MODELS +from ..base_mutator import BaseMutator +from ..group_mixin import GroupMixin + + +@MODELS.register_module() +class ValueMutator(BaseMutator[MutableValue], GroupMixin): + """The base class for mutable based mutator. All subclass should implement + the following APIS: + + - ``mutable_class_type`` + Args: + custom_group (list[list[str]], optional): User-defined search groups. + All searchable modules that are not in ``custom_group`` will be + grouped separately. + """ + + def __init__(self, + custom_group: Optional[List[List[str]]] = None, + init_cfg: Optional[Dict] = None) -> None: + super().__init__(init_cfg) + + if custom_group is None: + custom_group = [] + self._custom_group = custom_group + self._search_groups: Optional[Dict[int, List[MutableValue]]] = None + + # TODO + # should be a class property + @property + def mutable_class_type(self) -> Type[MutableValue]: + """Corresponding mutable class type. + + Returns: + Type[MUTABLE_TYPE]: Mutable class type. + """ + return MutableValue + + def prepare_from_supernet(self, supernet: Module) -> None: + """Do some necessary preparations with supernet. + + Note: + For mutable based mutator, we need to build search group first. + Args: + supernet (:obj:`torch.nn.Module`): The supernet to be searched + in your algorithm. + """ + self._search_groups = self.build_search_groups(supernet, + self.mutable_class_type, + self._custom_group) + + @property + def search_groups(self) -> Dict[int, List[MutableValue]]: + """Search group of supernet. + + Note: + For mutable based mutator, the search group is composed of + corresponding mutables. + Raises: + RuntimeError: Called before search group has been built. + Returns: + Dict[int, List[MUTABLE_TYPE]]: Search group. + """ + if self._search_groups is None: + raise RuntimeError( + 'Call `prepare_from_supernet` before access search group!') + return self._search_groups diff --git a/mmrazor/models/observers/base.py b/mmrazor/models/observers/base.py index a68410eb0..ce226cb48 100644 --- a/mmrazor/models/observers/base.py +++ b/mmrazor/models/observers/base.py @@ -1,4 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -from torch.ao.quantization.observer import UniformQuantizationObserverBase +try: + from torch.ao.quantization.observer import UniformQuantizationObserverBase +except ImportError: + from mmrazor.utils import get_placeholder + UniformQuantizationObserverBase = get_placeholder('torch>=1.13') BaseObserver = UniformQuantizationObserverBase diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py index 8e0e81d58..5dc24609f 100644 --- a/mmrazor/models/observers/torch_observers.py +++ b/mmrazor/models/observers/torch_observers.py @@ -2,10 +2,14 @@ import inspect from typing import List -import torch.ao.quantization.observer as torch_observer_src - from mmrazor.registry import MODELS +try: + import torch.ao.quantization.observer as torch_observer_src +except ImportError: + from mmrazor.utils import get_package_placeholder + torch_observer_src = get_package_placeholder('torch>=1.13') + def register_torch_observers() -> List[str]: """Register observers in ``torch.ao.quantization.observer`` to the diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py index 6a6500791..768f51c53 100644 --- a/mmrazor/models/quantizers/academic_quantizer.py +++ b/mmrazor/models/quantizers/academic_quantizer.py @@ -1,16 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization.fx import prepare -from torch.ao.quantization.fx.custom_config import (FuseCustomConfig, - PrepareCustomConfig) -from torch.ao.quantization.qconfig_mapping import QConfigMapping -from torch.ao.quantization.quant_type import _quant_type_from_str -from torch.ao.quantization.quantize_fx import _fuse_fx from mmrazor.registry import MODELS from mmrazor.structures.quantization import BackendConfigs, QConfigHander from .base import BaseQuantizer +try: + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.fx.custom_config import (FuseCustomConfig, + PrepareCustomConfig) + from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.quant_type import _quant_type_from_str + from torch.ao.quantization.quantize_fx import _fuse_fx +except ImportError: + from mmrazor.utils import get_placeholder + prepare = get_placeholder('torch>=1.13') + FuseCustomConfig = get_placeholder('torch>=1.13') + PrepareCustomConfig = get_placeholder('torch>=1.13') + QConfigMapping = get_placeholder('torch>=1.13') + _quant_type_from_str = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + GLOBAL_DICT_KEY = '_global_' OBJECT_TYPE_DICT_KEY = 'object_type' MODULE_NAME_REGEX_DICT_KEY = 'module_name_regex' @@ -23,6 +33,7 @@ @MODELS.register_module() class AcademicQuantizer(BaseQuantizer): + """tmp.""" def __init__(self, qconfig_mapping, @@ -37,6 +48,7 @@ def __init__(self, self.example_inputs = (torch.randn(1, 3, 224, 224), ) def prepare(self, model, graph_module): + """tmp.""" preserved_attributes = self.prepare_custom_config.preserved_attributes for attr_name in preserved_attributes: setattr(graph_module, attr_name, getattr(model, attr_name)) @@ -60,6 +72,7 @@ def prepare(self, model, graph_module): return prepared def gen_qconfig_mapping(self, qconfig_mapping): + """tmp.""" conf = QConfigMapping() if GLOBAL_DICT_KEY in qconfig_mapping: qconfig = QConfigHander(qconfig_mapping[GLOBAL_DICT_KEY]).convert() @@ -86,6 +99,7 @@ def gen_qconfig_mapping(self, qconfig_mapping): return conf def gen_prepare_custom_config(self, prepare_custom_config): + """tmp.""" conf = PrepareCustomConfig() if prepare_custom_config is None: return conf diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py index d98fbd786..0f14917ac 100644 --- a/mmrazor/models/quantizers/base.py +++ b/mmrazor/models/quantizers/base.py @@ -8,6 +8,7 @@ class BaseQuantizer(BaseModule): + """tmp.""" def __init__(self, tracer): super().__init__() @@ -15,11 +16,11 @@ def __init__(self, tracer): @abstractmethod def prepare(self, model, graph_module): + """tmp.""" pass def swap_ff_with_fxff(self, model): - r""" Swap FloatFunctional with FXFloatFunctional - """ + """Swap FloatFunctional with FXFloatFunctional.""" modules_to_swap = [] for name, module in model.named_children(): if isinstance(module, torch.ao.nn.quantized.FloatFunctional): diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index 84be1edfb..b3f2002e5 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -1,45 +1,62 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Tuple + import torch -from torch.ao.quantization import enable_fake_quant -from torch.ao.quantization.fx import prepare -from torch.ao.quantization.qconfig_mapping import QConfigMapping -from torch.ao.quantization.quantize_fx import _fuse_fx -from torch.nn.intrinsic.qat import modules as qat_fused_modules -from torch.nn.qat import modules as qat_modules +try: + from torch.ao.quantization import enable_fake_quant + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.quantize_fx import _fuse_fx + from torch.nn.intrinsic.qat import modules as qat_fused_modules + from torch.nn.qat import modules as qat_modules +except ImportError: + from mmrazor.utils import get_package_placeholder, get_placeholder + enable_fake_quant = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + QConfigMapping = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + qat_fused_modules = get_package_placeholder('torch>=1.13') + qat_modules = get_package_placeholder('torch>=1.13') + +from mmrazor import digit_version from mmrazor.models.task_modules.tracer.fx import ( del_fakequant_after_function, del_fakequant_after_method, del_fakequant_after_module, del_fakequant_after_op, del_fakequant_before_function, del_fakequant_before_method, del_fakequant_before_module, del_fakequant_before_op) - from mmrazor.models.utils import str2class from mmrazor.registry import MODELS from mmrazor.structures.quantization import BackendConfigs, QConfigHander from .base import BaseQuantizer -SUPPORT_QAT_MODULES = ( - qat_fused_modules.ConvBn1d, qat_fused_modules.ConvBn2d, - qat_fused_modules.ConvBn3d, qat_fused_modules.ConvBnReLU1d, - qat_fused_modules.ConvBnReLU2d, qat_fused_modules.ConvBnReLU3d, - qat_fused_modules.ConvReLU1d, qat_fused_modules.ConvReLU2d, - qat_fused_modules.ConvReLU3d, qat_fused_modules.LinearBn1d, - qat_fused_modules.LinearReLU, qat_modules.Conv1d, qat_modules.Conv2d, - qat_modules.Conv3d, qat_modules.Linear) - -MERGE_BN_MAPPINGS = { - qat_fused_modules.ConvBn1d: qat_modules.Conv1d, - qat_fused_modules.ConvBn2d: qat_modules.Conv2d, - qat_fused_modules.ConvBn3d: qat_modules.Conv3d, - qat_fused_modules.ConvBnReLU1d: qat_fused_modules.ConvReLU1d, - qat_fused_modules.ConvBnReLU2d: qat_fused_modules.ConvReLU2d, - qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d, - qat_fused_modules.LinearBn1d: qat_modules.Linear -} +if digit_version(torch.__version__) >= digit_version('1.13.0'): + SUPPORT_QAT_MODULES: Tuple = ( + qat_fused_modules.ConvBn1d, qat_fused_modules.ConvBn2d, + qat_fused_modules.ConvBn3d, qat_fused_modules.ConvBnReLU1d, + qat_fused_modules.ConvBnReLU2d, qat_fused_modules.ConvBnReLU3d, + qat_fused_modules.ConvReLU1d, qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvReLU3d, qat_fused_modules.LinearBn1d, + qat_fused_modules.LinearReLU, qat_modules.Conv1d, qat_modules.Conv2d, + qat_modules.Conv3d, qat_modules.Linear) + + MERGE_BN_MAPPINGS: Dict = { + qat_fused_modules.ConvBn1d: qat_modules.Conv1d, + qat_fused_modules.ConvBn2d: qat_modules.Conv2d, + qat_fused_modules.ConvBn3d: qat_modules.Conv3d, + qat_fused_modules.ConvBnReLU1d: qat_fused_modules.ConvReLU1d, + qat_fused_modules.ConvBnReLU2d: qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d, + qat_fused_modules.LinearBn1d: qat_modules.Linear + } +else: + SUPPORT_QAT_MODULES = () + MERGE_BN_MAPPINGS = {} @MODELS.register_module() class NativeQuantizer(BaseQuantizer): + """tmp.""" # backend: 'native' # support_w_modes = ['per_tensor', 'per_channel'] @@ -52,12 +69,12 @@ def __init__(self, extra_redundant_fakequants=dict( extra_module_prev_wo_fakequant=tuple(), extra_module_next_wo_fakequant=tuple(), - extra_function_prev_wo_fakequant = tuple(), - extra_function_next_wo_fakequant = tuple(), - extra_method_prev_wo_fakequant = tuple(), - extra_method_next_wo_fakequant = tuple(), - extra_op_prev_wo_fakequant = tuple(), - extra_op_next_wo_fakequant = tuple())): + extra_function_prev_wo_fakequant=tuple(), + extra_function_next_wo_fakequant=tuple(), + extra_method_prev_wo_fakequant=tuple(), + extra_method_next_wo_fakequant=tuple(), + extra_op_prev_wo_fakequant=tuple(), + extra_op_next_wo_fakequant=tuple())): super().__init__(tracer) self.qconfig = QConfigHander(global_qconfig) if self.qconfig.w_qscheme.is_per_channel: @@ -86,17 +103,21 @@ def __init__(self, @property def backend(self): + """tmp.""" return 'native' @property def support_w_modes(self): + """tmp.""" return ['per_tensor', 'per_channel'] @property def support_a_modes(self): + """tmp.""" return ['per_tensor'] def prepare(self, model, graph_module): + """tmp.""" graph_module = _fuse_fx( graph_module=graph_module, is_qat=True, @@ -115,6 +136,7 @@ def prepare(self, model, graph_module): def post_process_weight_fakequant(self, observed_module, keep_fake_quant=False): + """tmp.""" def traverse(module): for name, child in module.named_children(): @@ -145,70 +167,104 @@ def traverse(module): traverse(observed_module) def prepare_for_mmdeploy(self, model, dummy_input, checkpoint): + """tmp.""" raise NotImplementedError def del_redundant_fakequant(self, prepared): - extra_module_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_module_prev_wo_fakequant', tuple()) + """tmp.""" + extra_module_prev_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_module_prev_wo_fakequant', tuple()) prepared = del_fakequant_before_module( - prepared, self.module_prev_wo_fakequant + extra_module_prev_wo_fakequant, inplace=True) + prepared, + self.module_prev_wo_fakequant + extra_module_prev_wo_fakequant, + inplace=True) - extra_module_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_module_next_wo_fakequant', tuple()) + extra_module_next_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_module_next_wo_fakequant', tuple()) prepared = del_fakequant_after_module( - prepared, self.module_next_wo_fakequant + extra_module_next_wo_fakequant, inplace=True) + prepared, + self.module_next_wo_fakequant + extra_module_next_wo_fakequant, + inplace=True) - extra_function_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_function_prev_wo_fakequant', tuple()) + extra_function_prev_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_function_prev_wo_fakequant', tuple()) prepared = del_fakequant_before_method( - prepared, self.function_prev_wo_fakequant + extra_function_prev_wo_fakequant, inplace=True) + prepared, + self.function_prev_wo_fakequant + extra_function_prev_wo_fakequant, + inplace=True) - extra_function_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_function_next_wo_fakequant', tuple()) + extra_function_next_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_function_next_wo_fakequant', tuple()) prepared = del_fakequant_after_method( - prepared, self.function_next_wo_fakequant + extra_function_next_wo_fakequant, inplace=True) + prepared, + self.function_next_wo_fakequant + extra_function_next_wo_fakequant, + inplace=True) - extra_method_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_method_prev_wo_fakequant', tuple()) + extra_method_prev_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_method_prev_wo_fakequant', tuple()) prepared = del_fakequant_before_function( - prepared, self.method_prev_wo_fakequant + extra_method_prev_wo_fakequant, inplace=True) + prepared, + self.method_prev_wo_fakequant + extra_method_prev_wo_fakequant, + inplace=True) - extra_method_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_method_next_wo_fakequant', tuple()) + extra_method_next_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_method_next_wo_fakequant', tuple()) prepared = del_fakequant_after_function( - prepared, self.method_next_wo_fakequant + extra_method_next_wo_fakequant, inplace=True) + prepared, + self.method_next_wo_fakequant + extra_method_next_wo_fakequant, + inplace=True) - extra_op_prev_wo_fakequant = self.extra_redundant_fakequants.get('extra_op_prev_wo_fakequant', tuple()) + extra_op_prev_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_op_prev_wo_fakequant', tuple()) prepared = del_fakequant_before_op( - prepared, self.op_prev_wo_fakequant + extra_op_prev_wo_fakequant, inplace=True) + prepared, + self.op_prev_wo_fakequant + extra_op_prev_wo_fakequant, + inplace=True) - extra_op_next_wo_fakequant = self.extra_redundant_fakequants.get('extra_op_next_wo_fakequant', tuple()) + extra_op_next_wo_fakequant = self.extra_redundant_fakequants.get( + 'extra_op_next_wo_fakequant', tuple()) prepared = del_fakequant_after_op( - prepared, self.op_next_wo_fakequant + extra_op_next_wo_fakequant, inplace=True) + prepared, + self.op_next_wo_fakequant + extra_op_next_wo_fakequant, + inplace=True) return prepared @property def module_prev_wo_fakequant(self): + """tmp.""" return tuple() @property def module_next_wo_fakequant(self): + """tmp.""" return tuple() @property def function_prev_wo_fakequant(self): + """tmp.""" return tuple() @property def function_next_wo_fakequant(self): + """tmp.""" return tuple() @property def method_prev_wo_fakequant(self): + """tmp.""" return tuple() @property def method_next_wo_fakequant(self): + """tmp.""" return tuple() @property def op_prev_wo_fakequant(self): + """tmp.""" return tuple() @property def op_next_wo_fakequant(self): + """tmp.""" return tuple() diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index 0b13b23f9..23abf40da 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -1,8 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Tuple import torch -from torch.ao.quantization import disable_observer + +try: + from torch.ao.quantization import disable_observer +except ImportError: + from mmrazor.utils import get_placeholder + disable_observer = get_placeholder('torch>=1.13') from mmrazor.models.task_modules.tracer.fx import build_graphmodule from mmrazor.registry import MODELS @@ -19,21 +23,24 @@ class OpenVINOQuantizer(NativeQuantizer): @property def backend(self): + """tmp.""" return 'openvino' @property def support_w_modes(self): + """tmp.""" return ['per_tensor', 'per_channel'] @property def support_a_modes(self): + """tmp.""" return ['per_tensor'] def prepare_for_mmdeploy(self, model, dummy_input=(1, 3, 224, 224), checkpoint=None): - + """tmp.""" self.swap_ff_with_fxff(model) graph = self.tracer.trace(model) graph_module = build_graphmodule(model, graph) @@ -52,16 +59,20 @@ def prepare_for_mmdeploy(self, @property def module_prev_wo_fakequant(self): + """tmp.""" return (torch.nn.ReLU6, torch.nn.Identity) @property def module_next_wo_fakequant(self): + """tmp.""" return (torch.nn.MaxPool2d, ) @property def method_next_wo_fakequant(self): + """tmp.""" return ('flatten', ) @property def op_prev_wo_fakequant(self): + """tmp.""" return ('output', ) diff --git a/mmrazor/models/quantizers/tensorrt_quantizer.py b/mmrazor/models/quantizers/tensorrt_quantizer.py index 4d9868c4f..36e3f2be7 100644 --- a/mmrazor/models/quantizers/tensorrt_quantizer.py +++ b/mmrazor/models/quantizers/tensorrt_quantizer.py @@ -1,6 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization import disable_observer + +try: + from torch.ao.quantization import disable_observer +except ImportError: + from mmrazor.utils import get_placeholder + disable_observer = get_placeholder('torch>=1.13') from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ build_graphmodule @@ -24,21 +29,24 @@ def __init__(self, @property def backend(self): + """tmp.""" return 'tensorrt' @property def support_w_modes(self): + """tmp.""" return ['per_tensor', 'per_channel'] @property def support_a_modes(self): + """tmp.""" return ['per_tensor'] def prepare_for_mmdeploy(self, model, dummy_input=(1, 3, 224, 224), checkpoint=None): - + """tmp.""" self.swap_ff_with_fxff(model) graph = self.tracer.trace(model) graph_module = build_graphmodule(model, graph) diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py index 0e118290e..2d33e9875 100644 --- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -5,18 +5,32 @@ import torch import torch.nn as nn + +try: + from torch._C import ScriptObject # type: ignore[attr-defined] + from torch.ao.quantization.quantize_fx import QuantizationTracer + from torch.fx import Graph, GraphModule, Tracer + from torch.fx._symbolic_trace import (_autowrap_check, + _patch_wrapped_functions, _Patcher) + from torch.fx.proxy import Proxy +except ImportError: + from mmrazor.utils import get_placeholder + ScriptObject = get_placeholder('torch>=1.13') + QuantizationTracer = get_placeholder('torch>=1.13') + GraphModule = get_placeholder('torch>=1.13') + Tracer = get_placeholder('torch>=1.13') + Graph = get_placeholder('torch>=1.13') + _autowrap_check = get_placeholder('torch>=1.13') + _patch_wrapped_functions = get_placeholder('torch>=1.13') + _Patcher = get_placeholder('torch>=1.13') + Proxy = get_placeholder('torch>=1.13') + from mmengine.utils import import_modules_from_strings -from torch._C import ScriptObject # type: ignore[attr-defined] -from torch.ao.quantization.quantize_fx import QuantizationTracer -from torch.fx import GraphModule, Tracer -from torch.fx._symbolic_trace import (Graph, _autowrap_check, - _patch_wrapped_functions, _Patcher) -from torch.fx.proxy import Proxy from mmrazor.registry import TASK_UTILS -_orig_module_call: Callable = torch.nn.Module.__call__ -_orig_module_getattr: Callable = torch.nn.Module.__getattr__ +_orig_module_call: Callable = nn.Module.__call__ +_orig_module_getattr: Callable = nn.Module.__getattr__ class UntracedMethodRegistry: @@ -59,13 +73,12 @@ def method(*args, **kwargs): return wrapped_method -def custom_symbolic_trace( - root: Union[torch.nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule: +def custom_symbolic_trace(root: Union[nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None): """Modified `symbolic_trace` function. Args: - root (Union[torch.nn.Module, Callable]): Module or function to be + root (Union[nn.Module, Callable]): Module or function to be traced and converted into a Graph representation. concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized. @@ -75,12 +88,12 @@ def custom_symbolic_trace( """ tracer = CustomTracer() graph = tracer.trace(root, concrete_args) - name = root.__class__.__name__ if isinstance( - root, torch.nn.Module) else root.__name__ + name = root.__class__.__name__ if isinstance(root, + nn.Module) else root.__name__ return GraphModule(tracer.root, graph, name) -def _prepare_module_dict(model: nn.Module, fx_graph: torch.fx.Graph): +def _prepare_module_dict(model: nn.Module, fx_graph): """If there is a class method that can not be traced by the symbolic tracer, a ``call_method`` ``Node`` will be inserted into the ``Graph`` in ``CustomTracer``. @@ -128,7 +141,7 @@ def _prepare_module_dict(model: nn.Module, fx_graph: torch.fx.Graph): Args: model (nn.Module): The original model. - fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. + fx_graph (Graph): The fx Graph traced by fx tracer. """ def _get_attrs(target, attrs): @@ -157,9 +170,7 @@ def _get_attrs(target, attrs): return module_dict -def build_graphmodule(model: nn.Module, - fx_graph: torch.fx.Graph, - name: str = 'GraphModule'): +def build_graphmodule(model: nn.Module, fx_graph, name: str = 'GraphModule'): modules = dict(model.named_modules()) module_dict = _prepare_module_dict(model, fx_graph) modules.update(module_dict) @@ -228,7 +239,7 @@ def register_skipped_methods(self): method_registry = UntracedMethodRegistry(method) method_registry.__set_name__(imported_cls, method_str) - def call_method(self, m: torch.nn.Module, name, method, args, kwargs): + def call_method(self, m: nn.Module, name, method, args, kwargs): """Method that specifies the behavior of this ``Tracer`` when it encounters a call to an ``nn.Module`` instance. @@ -266,7 +277,7 @@ def call_method(self, m: torch.nn.Module, name, method, args, kwargs): return self.create_proxy('call_method', name, args, kwargs) def trace(self, root, concrete_args=None): - if isinstance(root, torch.nn.Module): + if isinstance(root, nn.Module): self.root = root fn = type(root).forward self.submodule_paths = { @@ -274,7 +285,7 @@ def trace(self, root, concrete_args=None): for name, mod in root.named_modules() } else: - self.root = torch.nn.Module() + self.root = nn.Module() fn = root tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None) @@ -286,7 +297,7 @@ def trace(self, root, concrete_args=None): # used downstream in create_arg self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} - def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + def collect_tensor_attrs(m: nn.Module, prefix_atoms: List[str]): for k, v in m.__dict__.items(): if isinstance(v, (torch.Tensor, ScriptObject)): self.tensor_attrs[v] = '.'.join(prefix_atoms + [k]) @@ -298,8 +309,7 @@ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): assert isinstance(fn, FunctionType) fn_globals = fn.__globals__ # run before it gets patched - fn, args = self.create_args_for_root(fn, - isinstance(root, torch.nn.Module), + fn, args = self.create_args_for_root(fn, isinstance(root, nn.Module), concrete_args) # Reduce number of get_attr calls @@ -328,15 +338,12 @@ def forward(*args, **kwargs): with _Patcher() as patcher: # allow duplicate patches to support the case of nested calls patcher.patch_method( - torch.nn.Module, + nn.Module, '__getattr__', module_getattr_wrapper, deduplicate=False) patcher.patch_method( - torch.nn.Module, - '__call__', - module_call_wrapper, - deduplicate=False) + nn.Module, '__call__', module_call_wrapper, deduplicate=False) for name, value in UntracedMethodRegistry.method_dict.items(): wrapped = value['wrapped'] @@ -363,8 +370,7 @@ def is_skipped_method(self, m): custom = isinstance(m, mods) return custom - def is_leaf_module(self, m: torch.nn.Module, - module_qualified_name: str) -> bool: + def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool: # return super().is_leaf_module(m, module_qualified_name) leaf = super().is_leaf_module(m, module_qualified_name) return leaf diff --git a/mmrazor/models/task_modules/tracer/fx/graph_utils.py b/mmrazor/models/task_modules/tracer/fx/graph_utils.py index fe8d620c2..5e3ddc2f4 100644 --- a/mmrazor/models/task_modules/tracer/fx/graph_utils.py +++ b/mmrazor/models/task_modules/tracer/fx/graph_utils.py @@ -2,8 +2,13 @@ import copy from typing import Any, List, Tuple -import torch.fx -from torch.ao.quantization.fake_quantize import FakeQuantizeBase +import torch + +try: + from torch.ao.quantization.fake_quantize import FakeQuantizeBase +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantizeBase = get_placeholder('torch>=1.13') def _get_attrs(target: torch.nn.Module, attr: str) -> Any: @@ -67,9 +72,9 @@ def recursive_find_erased_nodes(node, prepared_model): return nodes_to_erase -def del_fakequant_before_op(prepared_model: torch.fx.GraphModule, +def del_fakequant_before_op(prepared_model, target_ops: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant before nodes whose ``op`` attribute (node.op) is in `target_ops`. @@ -104,9 +109,9 @@ def del_fakequant_before_op(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_after_op(prepared_model: torch.fx.GraphModule, +def del_fakequant_after_op(prepared_model, target_ops: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant after nodes whose ``op`` attribute (node.op) is in `target_ops`. @@ -145,9 +150,9 @@ def del_fakequant_after_op(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_before_method(prepared_model: torch.fx.GraphModule, +def del_fakequant_before_method(prepared_model, method_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant before nodes whose op attribute (node.op) is `call_method` and target attribute (node.target) is in `target_patterns`. @@ -182,9 +187,9 @@ def del_fakequant_before_method(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_after_method(prepared_model: torch.fx.GraphModule, +def del_fakequant_after_method(prepared_model, method_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant after nodes whose op attribute (node.op) is `call_method` and target attribute (node.target) is in `target_patterns`. @@ -224,10 +229,9 @@ def del_fakequant_after_method(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_before_function( - prepared_model: torch.fx.GraphModule, - function_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: +def del_fakequant_before_function(prepared_model, + function_patterns: Tuple, + inplace: bool = True): """Delete useless fakequant before nodes whose op attribute (node.op) is `call_function` and target attribute (node.target) is in `target_patterns`. @@ -262,9 +266,9 @@ def del_fakequant_before_function( return prepared_model -def del_fakequant_after_function(prepared_model: torch.fx.GraphModule, +def del_fakequant_after_function(prepared_model, function_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant after nodes whose op attribute (node.op) is `call_function` and target attribute (node.target) is in `target_patterns`. @@ -304,9 +308,9 @@ def del_fakequant_after_function(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_before_module(prepared_model: torch.fx.GraphModule, +def del_fakequant_before_module(prepared_model, module_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant before modules whose type are in `module_patterns`. @@ -340,9 +344,9 @@ def del_fakequant_before_module(prepared_model: torch.fx.GraphModule, return prepared_model -def del_fakequant_after_module(prepared_model: torch.fx.GraphModule, +def del_fakequant_after_module(prepared_model, module_patterns: Tuple, - inplace: bool = True) -> torch.fx.GraphModule: + inplace: bool = True): """Delete useless fakequant after modules whose type are in `module_patterns`. diff --git a/mmrazor/structures/quantization/backend_config/academic.py b/mmrazor/structures/quantization/backend_config/academic.py index 5983c3996..4348e7179 100644 --- a/mmrazor/structures/quantization/backend_config/academic.py +++ b/mmrazor/structures/quantization/backend_config/academic.py @@ -1,23 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig + +try: + from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') from .common_operator_config_utils import (_get_conv_configs, _get_linear_configs) -# =================== -# | DTYPE CONFIGS | -# =================== - -# weighted op int8 dtype config -# this is config for ops that has quantized weights, like linear, conv -weighted_op_int8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.quint8, - weight_dtype=torch.qint8, - bias_dtype=torch.float, -) - # ===================== # | BACKEND CONFIGS | # ===================== @@ -25,6 +18,19 @@ def get_academic_backend_config() -> BackendConfig: """Return the `BackendConfig` for academic reseaching.""" + + # =================== + # | DTYPE CONFIGS | + # =================== + # weighted op int8 dtype config + # this is config for ops that has quantized weights, like linear, conv + weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + conv_dtype_configs = [weighted_op_int8_dtype_config] linear_dtype_configs = [weighted_op_int8_dtype_config] diff --git a/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py index 2a855e687..0a381d5d0 100644 --- a/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py +++ b/mmrazor/structures/quantization/backend_config/common_operator_config_utils.py @@ -5,39 +5,71 @@ import torch import torch.nn as nn -import torch.nn.functional as F -import torch.nn.intrinsic as nni -import torch.nn.intrinsic.qat as nniqat -import torch.nn.qat as nnqat -import torch.nn.quantized._reference as nnqr -from torch.ao.quantization.backend_config import (BackendPatternConfig, - DTypeConfig, ObservationType) -from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize -from torch.ao.quantization.fuser_method_mappings import ( - fuse_conv_bn, fuse_conv_bn_relu, fuse_convtranspose_bn, fuse_linear_bn, - reverse2, reverse3, reverse_sequential_wrapper2) -from torch.ao.quantization.qconfig_mapping import _FIXED_QPARAMS_OP_TO_OBSERVER + +from mmrazor import digit_version + +try: + import torch.nn.functional as F + import torch.nn.intrinsic as nni + import torch.nn.intrinsic.qat as nniqat + import torch.nn.qat as nnqat + import torch.nn.quantized._reference as nnqr + from torch.ao.quantization.backend_config import (BackendPatternConfig, + DTypeConfig, + ObservationType) + from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize + from torch.ao.quantization.fuser_method_mappings import ( + fuse_conv_bn, fuse_conv_bn_relu, fuse_convtranspose_bn, fuse_linear_bn, + reverse2, reverse3, reverse_sequential_wrapper2) + from torch.ao.quantization.qconfig_mapping import \ + _FIXED_QPARAMS_OP_TO_OBSERVER +except ImportError: + from mmrazor.utils import get_package_placeholder, get_placeholder + F = get_package_placeholder('torch>=1.13') + nni = get_package_placeholder('torch>=1.13') + nniqat = get_package_placeholder('torch>=1.13') + nnqat = get_package_placeholder('torch>=1.13') + nnqr = get_package_placeholder('torch>=1.13') + BackendPatternConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') + ObservationType = get_placeholder('torch>=1.13') + FixedQParamsFakeQuantize = get_placeholder('torch>=1.13') + fuse_conv_bn = get_placeholder('torch>=1.13') + fuse_conv_bn_relu = get_placeholder('torch>=1.13') + fuse_convtranspose_bn = get_placeholder('torch>=1.13') + fuse_linear_bn = get_placeholder('torch>=1.13') + reverse2 = get_placeholder('torch>=1.13') + reverse3 = get_placeholder('torch>=1.13') + reverse_sequential_wrapper2 = get_placeholder('torch>=1.13') + _FIXED_QPARAMS_OP_TO_OBSERVER = get_placeholder('torch>=1.13') _ConvMetadata = namedtuple('_ConvMetadata', [ 'root', 'transpose', 'bn', 'reference', 'transpose_reference', 'fused_conv_relu', 'fused_conv_bn', 'fused_conv_bn_relu', 'qat', 'relu_qat', 'bn_qat', 'bn_relu_qat', 'func' ]) -_Conv1dMetadata = _ConvMetadata(nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, - nnqr.Conv1d, nnqr.ConvTranspose1d, - nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d, - nnqat.Conv1d, nniqat.ConvReLU1d, - nniqat.ConvBn1d, nniqat.ConvBnReLU1d, F.conv1d) -_Conv2dMetadata = _ConvMetadata(nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, - nnqr.Conv2d, nnqr.ConvTranspose2d, - nni.ConvReLU2d, nni.ConvBn2d, nni.ConvBnReLU2d, - nnqat.Conv2d, nniqat.ConvReLU2d, - nniqat.ConvBn2d, nniqat.ConvBnReLU2d, F.conv2d) -_Conv3dMetadata = _ConvMetadata(nn.Conv3d, nn.ConvTranspose3d, nn.BatchNorm3d, - nnqr.Conv3d, nnqr.ConvTranspose3d, - nni.ConvReLU3d, nni.ConvBn3d, nni.ConvBnReLU3d, - nnqat.Conv3d, nniqat.ConvReLU3d, - nniqat.ConvBn3d, nniqat.ConvBnReLU3d, F.conv3d) + +if digit_version(torch.__version__) >= digit_version('1.13.0'): + _Conv1dMetadata = _ConvMetadata( + nn.Conv1d, nn.ConvTranspose1d, nn.BatchNorm1d, nnqr.Conv1d, + nnqr.ConvTranspose1d, nni.ConvReLU1d, nni.ConvBn1d, nni.ConvBnReLU1d, + nnqat.Conv1d, nniqat.ConvReLU1d, nniqat.ConvBn1d, nniqat.ConvBnReLU1d, + F.conv1d) + _Conv2dMetadata = _ConvMetadata( + nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nnqr.Conv2d, + nnqr.ConvTranspose2d, nni.ConvReLU2d, nni.ConvBn2d, nni.ConvBnReLU2d, + nnqat.Conv2d, nniqat.ConvReLU2d, nniqat.ConvBn2d, nniqat.ConvBnReLU2d, + F.conv2d) + _Conv3dMetadata = _ConvMetadata( + nn.Conv3d, nn.ConvTranspose3d, nn.BatchNorm3d, nnqr.Conv3d, + nnqr.ConvTranspose3d, nni.ConvReLU3d, nni.ConvBn3d, nni.ConvBnReLU3d, + nnqat.Conv3d, nniqat.ConvReLU3d, nniqat.ConvBn3d, nniqat.ConvBnReLU3d, + F.conv3d) +else: + toy_val = _ConvMetadata(*[i for i in range(13)]) + _Conv1dMetadata = toy_val + _Conv2dMetadata = toy_val + _Conv3dMetadata = toy_val def _get_binary_op_configs( diff --git a/mmrazor/structures/quantization/backend_config/mapping.py b/mmrazor/structures/quantization/backend_config/mapping.py index 4c87a73b9..b9cc5372b 100644 --- a/mmrazor/structures/quantization/backend_config/mapping.py +++ b/mmrazor/structures/quantization/backend_config/mapping.py @@ -1,12 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmrazor import digit_version from .academic import get_academic_backend_config from .native import get_native_backend_config from .openvino import get_openvino_backend_config from .tensorrt import get_tensorrt_backend_config -BackendConfigs = { - 'academic': get_academic_backend_config(), - 'native': get_native_backend_config(), - 'tensorrt': get_tensorrt_backend_config(), - 'openvino': get_openvino_backend_config() -} +if digit_version(torch.__version__) >= digit_version('1.13.0'): + BackendConfigs = { + 'academic': get_academic_backend_config(), + 'native': get_native_backend_config(), + 'tensorrt': get_tensorrt_backend_config(), + 'openvino': get_openvino_backend_config() + } +else: + BackendConfigs = { + 'academic': None, + 'native': None, + 'tensorrt': None, + 'openvino': None + } diff --git a/mmrazor/structures/quantization/backend_config/native.py b/mmrazor/structures/quantization/backend_config/native.py index d771b6012..94c35d535 100644 --- a/mmrazor/structures/quantization/backend_config/native.py +++ b/mmrazor/structures/quantization/backend_config/native.py @@ -1,6 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig + +try: + from torch.ao.quantization.backend_config import BackendConfig, DTypeConfig +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') from .common_operator_config_utils import ( # noqa: F401,F403 _get_binary_op_configs, _get_bn_configs, _get_cat_config, @@ -8,68 +14,6 @@ _get_fixed_qparams_op_configs, _get_linear_configs, _get_ln_configs, _get_rnn_op_configs, _get_share_qparams_op_configs) -# =================== -# | DTYPE CONFIGS | -# =================== - -# weighted op int8 dtype config -# this is config for ops that has quantized weights, like linear, conv -weighted_op_int8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.quint8, - weight_dtype=torch.qint8, - bias_dtype=torch.float, -) - -default_op_quint8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.quint8, -) - -default_dynamic_int8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.float, - weight_dtype=torch.qint8, - bias_dtype=torch.float, - # currently the dtype check is not yet enabled, so we provided the - # dtype_configs but it is not really used yet, - # we will enable it a bit later after we moved everything to - # backend_config_dict - is_dynamic=True, -) - -default_dynamic_float16_dtype_config = DTypeConfig( - input_dtype=torch.float16, - output_dtype=torch.float, - weight_dtype=torch.float16, - bias_dtype=torch.float, - # currently the dtype check is not yet enabled, so we provided the - # dtype_configs but it is not really used yet, we will enable it a bit - # later after we moved everything to backend_config_dict - is_dynamic=True, -) - -# Needed for LayerNorm and f.layer_norm, since currently the kernel only -# supports float weights -input_output_only_quint8_dtype_config = DTypeConfig( - input_dtype=torch.quint8, - output_dtype=torch.quint8, - weight_dtype=torch.float, - bias_dtype=torch.float, -) - -weight_only_quint8_dtype_config = DTypeConfig( - input_dtype=torch.float, - output_dtype=torch.float, - weight_dtype=torch.quint8, -) - -weight_only_quint4x2_dtype_config = DTypeConfig( - input_dtype=torch.float, - output_dtype=torch.float, - weight_dtype=torch.quint4x2, -) - # ===================== # | BACKEND CONFIGS | # ===================== @@ -80,6 +24,68 @@ def get_native_backend_config() -> BackendConfig: (fbgemm/qnnpack).""" # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK # BackendConfigs + + # =================== + # | DTYPE CONFIGS | + # =================== + # weighted op int8 dtype config + # this is config for ops that has quantized weights, like linear, conv + weighted_op_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + ) + + default_op_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + ) + + default_dynamic_int8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.float, + weight_dtype=torch.qint8, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the + # dtype_configs but it is not really used yet, + # we will enable it a bit later after we moved everything to + # backend_config_dict + is_dynamic=True, + ) + + default_dynamic_float16_dtype_config = DTypeConfig( + input_dtype=torch.float16, + output_dtype=torch.float, + weight_dtype=torch.float16, + bias_dtype=torch.float, + # currently the dtype check is not yet enabled, so we provided the + # dtype_configs but it is not really used yet, we will enable it a bit + # later after we moved everything to backend_config_dict + is_dynamic=True, + ) + + # Needed for LayerNorm and f.layer_norm, since currently the kernel only + # supports float weights + input_output_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.quint8, + output_dtype=torch.quint8, + weight_dtype=torch.float, + bias_dtype=torch.float, + ) + + weight_only_quint8_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint8, + ) + + weight_only_quint4x2_dtype_config = DTypeConfig( + input_dtype=torch.float, + output_dtype=torch.float, + weight_dtype=torch.quint4x2, + ) + conv_dtype_configs = [weighted_op_int8_dtype_config] linear_dtype_configs = [ weighted_op_int8_dtype_config, diff --git a/mmrazor/structures/quantization/backend_config/openvino.py b/mmrazor/structures/quantization/backend_config/openvino.py index fd24eed17..d990d4ef9 100644 --- a/mmrazor/structures/quantization/backend_config/openvino.py +++ b/mmrazor/structures/quantization/backend_config/openvino.py @@ -1,8 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization.backend_config import (BackendConfig, - BackendPatternConfig, - DTypeConfig, ObservationType) + +try: + from torch.ao.quantization.backend_config import (BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType) +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + BackendPatternConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') + ObservationType = get_placeholder('torch>=1.13') from .common_operator_config_utils import (_get_binary_op_configs, _get_conv_configs, diff --git a/mmrazor/structures/quantization/backend_config/tensorrt.py b/mmrazor/structures/quantization/backend_config/tensorrt.py index abb585c6a..53305f650 100644 --- a/mmrazor/structures/quantization/backend_config/tensorrt.py +++ b/mmrazor/structures/quantization/backend_config/tensorrt.py @@ -1,8 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from torch.ao.quantization.backend_config import (BackendConfig, - BackendPatternConfig, - DTypeConfig, ObservationType) + +try: + from torch.ao.quantization.backend_config import (BackendConfig, + BackendPatternConfig, + DTypeConfig, + ObservationType) +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + BackendPatternConfig = get_placeholder('torch>=1.13') + DTypeConfig = get_placeholder('torch>=1.13') + ObservationType = get_placeholder('torch>=1.13') from .common_operator_config_utils import (_get_binary_op_configs, _get_conv_configs, diff --git a/mmrazor/structures/quantization/qconfig.py b/mmrazor/structures/quantization/qconfig.py index 3dca49730..e0fdf113d 100644 --- a/mmrazor/structures/quantization/qconfig.py +++ b/mmrazor/structures/quantization/qconfig.py @@ -3,7 +3,12 @@ import torch from mmengine.config import Config -from torch.ao.quantization import QConfig + +try: + from torch.ao.quantization import QConfig +except ImportError: + from mmrazor.utils import get_placeholder + QConfig = get_placeholder('torch>=1.13') from mmrazor.registry import MODELS diff --git a/tests/data/models.py b/tests/data/models.py index 33fb0c624..0347b9147 100644 --- a/tests/data/models.py +++ b/tests/data/models.py @@ -78,7 +78,6 @@ def untracable_method(self, x): x = x * -2 return x - @MODELS.register_module() class UntracableBackBone(nn.Module): @@ -123,7 +122,6 @@ def forward(self, x): x_last = self.conv2(x_attn) return self.head(x_last) - @MODELS.register_module() class LinearHeadForTest(Module): @@ -704,7 +702,6 @@ def current_choice(self): def current_choice(self, choice): super().current_choice(choice) - class DynamicLinearModel(nn.Module): """ x diff --git a/tests/test_data.py b/tests/test_data.py index df3e07f69..d56a2950b 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -6,8 +6,13 @@ from .data.model_library import (DefaultModelLibrary, MMClsModelLibrary, MMDetModelLibrary, MMModelLibrary, +<<<<<<< HEAD MMPoseModelLibrary, MMSegModelLibrary, ModelGenerator, TorchModelLibrary) +======= + MMSegModelLibrary, ModelGenerator, + TorchModelLibrary) +>>>>>>> 985a611e (Merge dev-1.x into quantize (#430)) from .data.models import SingleLineModel from .data.tracer_passed_models import (BackwardPassedModelManager, FxPassedModelManager) @@ -45,6 +50,7 @@ def test_mmseg(self): if not TEST_DATA: self.skipTest('not test data to save time.') library = MMSegModelLibrary() +<<<<<<< HEAD print(library.short_names()) self.assertTrue(library.is_default_includes_cover_all_models()) @@ -55,6 +61,8 @@ def test_mmpose(self): self.skipTest('not test data to save time.') library = MMPoseModelLibrary() print(library.short_names()) +======= +>>>>>>> 985a611e (Merge dev-1.x into quantize (#430)) self.assertTrue(library.is_default_includes_cover_all_models()) def test_get_model_by_config(self): diff --git a/tests/test_models/test_mutators/test_value_mutator.py b/tests/test_models/test_mutators/test_value_mutator.py new file mode 100644 index 000000000..a76257a9e --- /dev/null +++ b/tests/test_models/test_mutators/test_value_mutator.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import unittest + +import torch + +from mmrazor.models.mutables import MutableValue +from mmrazor.models.mutators import DynamicValueMutator +from tests.data.models import DynamicAttention, DynamicMMBlock + + +class TestValueMutator(unittest.TestCase): + + def test_models_with_predefined_dynamic_op(self): + for Model in [ + DynamicAttention, + ]: + with self.subTest(model=Model): + model = Model() + value_mutator = DynamicValueMutator() + value_mutator.prepare_from_supernet(model) + value_choices = value_mutator.sample_choices() + value_mutator.set_choices(value_choices) + + mutable_value_space = [] + for mutable_value, module in model.named_modules(): + if isinstance(module, MutableValue): + mutable_value_space.append(mutable_value) + elif hasattr(module, 'source_mutables'): + for each_mutables in module.source_mutables: + if isinstance(each_mutables, MutableValue): + mutable_value_space.append(each_mutables) + assert len( + value_mutator.search_groups) == len(mutable_value_space) + + x = torch.rand([2, 3, 224, 224]) + y = model(x) + self.assertEqual(list(y.shape), [2, 624]) + + def test_models_with_multiple_value(self): + for Model in [ + DynamicMMBlock, + ]: + with self.subTest(model=Model): + model = Model() + value_mutator = DynamicValueMutator() + value_mutator.prepare_from_supernet(model) + value_choices = value_mutator.sample_choices() + value_mutator.set_choices(value_choices) + + # TODO check DynamicMMBlock + mutable_value_space = [] + for mutable_value, module in model.named_modules(): + if isinstance(module, MutableValue): + mutable_value_space.append(mutable_value) + elif hasattr(module, 'source_mutables'): + for each_mutables in module.source_mutables: + if isinstance(each_mutables, MutableValue): + mutable_value_space.append(each_mutables) + count = 0 + for values in value_mutator.search_groups.values(): + count += len(values) + assert count == len(mutable_value_space) + + x = torch.rand([2, 3, 224, 224]) + y = model(x) + self.assertEqual(list(y[-1].shape), [2, 1984, 1, 1]) diff --git a/tests/test_models/test_task_modules/test_custom_tracer.py b/tests/test_models/test_task_modules/test_custom_tracer.py deleted file mode 100644 index 671922f69..000000000 --- a/tests/test_models/test_task_modules/test_custom_tracer.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest import TestCase - -from mmrazor.models.task_modules import CustomTracer, UntracedMethodRegistry -from mmrazor.testing import ConvBNReLU - - -class testCustomTracer(TestCase): - - def test_init(self): - tracer = CustomTracer() - assert tracer.skipped_methods.__len__() == 0 - - def test_trace(self): - tracer = CustomTracer() - model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN')) - graph = tracer.trace(model) # noqa: F841 - - def test_auto_skip_call_module(self): - pass - - def test_auto_skip_call_method(self): - pass - - def test_configurable_skipped_methods(self): - pass - - -class testUntracedMethodRgistry(TestCase): - - def test_init(self): - self.assertEqual(len(UntracedMethodRegistry.method_dict), 0) - - def test_add_method(self): - pass diff --git a/tests/test_models/test_task_modules/test_graph_utils.py b/tests/test_models/test_task_modules/test_graph_utils.py index 7250bee95..d8f53c03c 100644 --- a/tests/test_models/test_task_modules/test_graph_utils.py +++ b/tests/test_models/test_task_modules/test_graph_utils.py @@ -4,13 +4,21 @@ import torch import torch.nn as nn -from torch.ao.quantization import QConfigMapping -from torch.ao.quantization.fake_quantize import FakeQuantizeBase -from torch.ao.quantization.fx import prepare -from torch.ao.quantization.quantize_fx import _fuse_fx -from mmrazor.models.task_modules import build_graphmodule -from mmrazor.models.task_modules.tracer import CustomTracer +try: + from torch.ao.quantization import QConfigMapping + from torch.ao.quantization.fake_quantize import FakeQuantizeBase + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.quantize_fx import _fuse_fx +except ImportError: + from mmrazor.utils import get_placeholder + QConfigMapping = get_placeholder('torch>=1.13') + FakeQuantizeBase = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.task_modules.tracer import CustomTracer, build_graphmodule from mmrazor.models.task_modules.tracer.fx import ( del_fakequant_after_function, del_fakequant_after_method, del_fakequant_after_module, del_fakequant_after_op, @@ -106,6 +114,9 @@ def forward(self, x): class TestGraphUtils(TestCase): def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.tracer = CustomTracer() self.backend_config = BackendConfigs['native'] self.qconfig = QConfigHander(global_qconfig) @@ -114,6 +125,9 @@ def setUp(self): self.example_inputs = (torch.randn(1, 3, 224, 224), ) def swap_ff_with_fxff(self, model): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + modules_to_swap = [] for name, module in model.named_children(): if isinstance(module, torch.ao.nn.quantized.FloatFunctional): @@ -126,6 +140,9 @@ def swap_ff_with_fxff(self, model): model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() def test_del_fakequant_before_op(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() @@ -170,6 +187,9 @@ def test_del_fakequant_before_op(self): _get_attrs(prepared, args[0].target), FakeQuantizeBase) def test_del_fakequant_after_op(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() @@ -211,6 +231,8 @@ def test_del_fakequant_after_op(self): _get_attrs(prepared, node.next.target), FakeQuantizeBase) def test_del_fakequant_before_method(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') model_to_quantize = ToyModel() model_to_quantize.eval() @@ -259,6 +281,9 @@ def test_del_fakequant_before_method(self): _get_attrs(prepared, args[0].target), FakeQuantizeBase) def test_del_fakequant_after_method(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() @@ -303,6 +328,9 @@ def test_del_fakequant_after_method(self): _get_attrs(prepared, node.next.target), FakeQuantizeBase) def test_del_fakequant_before_function(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() @@ -356,6 +384,9 @@ def test_del_fakequant_before_function(self): _get_attrs(prepared, args[1].target), FakeQuantizeBase) def test_del_fakequant_after_function(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() @@ -400,6 +431,9 @@ def test_del_fakequant_after_function(self): _get_attrs(prepared, node.next.target), FakeQuantizeBase) def test_del_fakequant_before_module(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() @@ -452,6 +486,9 @@ def test_del_fakequant_before_module(self): _get_attrs(prepared, args[0].target), FakeQuantizeBase) def test_del_fakequant_after_module(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + model_to_quantize = ToyModel() model_to_quantize.eval() diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 009640684..c8340f352 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -12,6 +12,8 @@ from mmrazor.models.algorithms.base import BaseAlgorithm from mmrazor.models.mutables import OneShotMutableOP from mmrazor.registry import MODELS +from mmrazor.structures import load_fix_subnet +from mmrazor.utils import ValidFixMutable @MODELS.register_module() @@ -44,13 +46,15 @@ class MockAlgorithm(BaseAlgorithm): def __init__(self, architecture: Union[BaseModel, Dict], - _return_architecture_: Optional[bool] = None): + fix_subnet: Optional[ValidFixMutable] = None): super().__init__(architecture) - if _return_architecture_ is True: - self.return_model = self.architecture + if fix_subnet is not None: + # According to fix_subnet, delete the unchosen part of supernet + load_fix_subnet(self, fix_subnet, prefix='architecture.') + self.is_supernet = False else: - self.return_model = self + self.is_supernet = True class TestRegistry(TestCase): @@ -68,18 +72,34 @@ def test_build_razor_from_cfg(self): # model = MODELS.build(self.arch_cfg_path) # self.assertIsNotNone(model) - # test return architecture + # test fix subnet cfg = Config.fromfile( - 'tests/data/test_registry/registry_architecture_config.py') + 'tests/data/test_registry/registry_subnet_config.py') model = MODELS.build(cfg.model) - self.assertTrue(isinstance(model.return_model, MockModel)) - # test return model + # test return architecture cfg = Config.fromfile( 'tests/data/test_registry/registry_architecture_config.py') - cfg.model.pop('_return_architecture_') model = MODELS.build(cfg.model) - self.assertTrue(isinstance(model.return_model, MockAlgorithm)) + self.assertTrue(isinstance(model, BaseModel)) + + def test_build_subnet_prune_from_cfg(self): + mutator_cfg = fileio.load('tests/data/test_registry/subnet.json') + init_cfg = dict( + type='Pretrained', + checkpoint='tests/data/test_registry/subnet_weight.pth') + # test fix subnet + model_cfg = dict( + # use mmrazor's build_func + type='mmrazor.sub_model', + cfg=dict( + cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', + pretrained=False), + fix_subnet=mutator_cfg, + mode='mutator', + init_cfg=init_cfg) + model = MODELS.build(model_cfg) + self.assertTrue(isinstance(model, BaseModel)) def test_build_subnet_prune_from_cfg_by_mutator(self): mutator_cfg = fileio.load('tests/data/test_registry/subnet.json') diff --git a/tests/test_structures/test_qconfig.py b/tests/test_structures/test_qconfig.py index 045b02c83..4730ab6cc 100644 --- a/tests/test_structures/test_qconfig.py +++ b/tests/test_structures/test_qconfig.py @@ -4,8 +4,14 @@ import torch from mmengine.config import Config -from torch.ao.quantization import QConfig +try: + from torch.ao.quantization import QConfig +except ImportError: + from mmrazor.utils import get_placeholder + QConfig = get_placeholder('torch>=1.13') + +from mmrazor import digit_version from mmrazor.models.fake_quants import register_torch_fake_quants from mmrazor.models.observers import register_torch_observers from mmrazor.structures import QConfigHander, QSchemeHander @@ -17,6 +23,9 @@ class TestQSchemeHander(TestCase): def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + # per_channel qscheme = QSchemeHander(is_symmetry=True, is_per_channel=True) assert qscheme.torch_qscheme is torch.per_channel_symmetric @@ -34,6 +43,9 @@ def test_init(self): assert qscheme.is_symmetric_range is True def test_to_observer_params(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + # qdtype = quint8 ret_params = QSchemeHander(qdtype='quint8').to_observer_params() assert ret_params['dtype'] == torch.quint8 @@ -78,6 +90,9 @@ def setUp(self): self.qconfig = Config(self.qconfig_dict) def test_check_qconfig(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + assert QConfigHander.check_qconfig(self.qconfig_dict) is True assert QConfigHander.check_qconfig(self.qconfig) is True qconfig_dict = copy.copy(self.qconfig_dict) @@ -86,6 +101,9 @@ def test_check_qconfig(self): assert QConfigHander.check_qconfig(qconfig_dict) is False def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + # test dict init qconfig = QConfigHander(self.qconfig_dict) assert hasattr(qconfig, 'w_qscheme') @@ -105,6 +123,9 @@ def test_init(self): assert qconfig.a_qscheme.is_per_channel is True def test_convert(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + qconfig = QConfigHander(self.qconfig) torch_qconfig = qconfig.convert() assert isinstance(torch_qconfig, QConfig) From c0d16ce8b00bbc1326f023896fcf7cafba958268 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Mon, 16 Jan 2023 18:36:32 +0800 Subject: [PATCH 08/44] [Docs] Add docstring and unittest about backendconfig & observer & fakequant (#428) * add ut about backendconfig * add ut about observers and fakequants in torch * fix torch1.13 ci --- .../quantization/backend_config/academic.py | 9 ++- .../quantization/backend_config/native.py | 8 ++- .../quantization/backend_config/openvino.py | 7 ++- .../quantization/backend_config/tensorrt.py | 7 ++- .../test_lsq_fake_quants.py | 0 .../test_torch_fake_quants.py | 18 ++++++ .../test_observers/test_torch_observers.py | 18 ++++++ tests/test_structures/test_backendconfig.py | 62 +++++++++++++++++++ 8 files changed, 123 insertions(+), 6 deletions(-) rename tests/test_models/{test_fake_quantize => test_fake_quants}/test_lsq_fake_quants.py (100%) create mode 100644 tests/test_models/test_fake_quants/test_torch_fake_quants.py create mode 100644 tests/test_models/test_observers/test_torch_observers.py create mode 100644 tests/test_structures/test_backendconfig.py diff --git a/mmrazor/structures/quantization/backend_config/academic.py b/mmrazor/structures/quantization/backend_config/academic.py index 4348e7179..6b4f0d598 100644 --- a/mmrazor/structures/quantization/backend_config/academic.py +++ b/mmrazor/structures/quantization/backend_config/academic.py @@ -17,7 +17,12 @@ def get_academic_backend_config() -> BackendConfig: - """Return the `BackendConfig` for academic reseaching.""" + """Return the `BackendConfig` for academic reseaching. + + Note: + Learn more about BackendConfig, please refer to: + https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501 + """ # =================== # | DTYPE CONFIGS | @@ -34,7 +39,7 @@ def get_academic_backend_config() -> BackendConfig: conv_dtype_configs = [weighted_op_int8_dtype_config] linear_dtype_configs = [weighted_op_int8_dtype_config] - return BackendConfig('native') \ + return BackendConfig('academic') \ .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) diff --git a/mmrazor/structures/quantization/backend_config/native.py b/mmrazor/structures/quantization/backend_config/native.py index 94c35d535..59085a56a 100644 --- a/mmrazor/structures/quantization/backend_config/native.py +++ b/mmrazor/structures/quantization/backend_config/native.py @@ -20,8 +20,12 @@ def get_native_backend_config() -> BackendConfig: - """Return the `BackendConfig` for PyTorch Native backend - (fbgemm/qnnpack).""" + """Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack). + + Note: + Learn more about BackendConfig, please refer to: + https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501 + """ # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK # BackendConfigs diff --git a/mmrazor/structures/quantization/backend_config/openvino.py b/mmrazor/structures/quantization/backend_config/openvino.py index d990d4ef9..5e3051f75 100644 --- a/mmrazor/structures/quantization/backend_config/openvino.py +++ b/mmrazor/structures/quantization/backend_config/openvino.py @@ -20,7 +20,12 @@ def get_openvino_backend_config() -> BackendConfig: - """Return the `BackendConfig` for the OpenVINO backend.""" + """Return the `BackendConfig` for the OpenVINO backend. + + Note: + Learn more about BackendConfig, please refer to: + https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501 + """ # dtype configs weighted_op_qint8_dtype_config = DTypeConfig( input_dtype=torch.quint8, diff --git a/mmrazor/structures/quantization/backend_config/tensorrt.py b/mmrazor/structures/quantization/backend_config/tensorrt.py index 53305f650..791463233 100644 --- a/mmrazor/structures/quantization/backend_config/tensorrt.py +++ b/mmrazor/structures/quantization/backend_config/tensorrt.py @@ -20,7 +20,12 @@ def get_tensorrt_backend_config() -> BackendConfig: - """Return the `BackendConfig` for the TensorRT backend.""" + """Return the `BackendConfig` for the TensorRT backend. + + Note: + Learn more about BackendConfig, please refer to: + https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501 + """ # dtype configs weighted_op_qint8_dtype_config = DTypeConfig( input_dtype=torch.qint8, diff --git a/tests/test_models/test_fake_quantize/test_lsq_fake_quants.py b/tests/test_models/test_fake_quants/test_lsq_fake_quants.py similarity index 100% rename from tests/test_models/test_fake_quantize/test_lsq_fake_quants.py rename to tests/test_models/test_fake_quants/test_lsq_fake_quants.py diff --git a/tests/test_models/test_fake_quants/test_torch_fake_quants.py b/tests/test_models/test_fake_quants/test_torch_fake_quants.py new file mode 100644 index 000000000..485113e90 --- /dev/null +++ b/tests/test_models/test_fake_quants/test_torch_fake_quants.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmrazor import digit_version +from mmrazor.models.fake_quants import register_torch_fake_quants +from mmrazor.registry import MODELS + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_register_torch_fake_quants(): + + TORCH_fake_quants = register_torch_fake_quants() + assert isinstance(TORCH_fake_quants, list) + for fake_quant in TORCH_fake_quants: + assert MODELS.get(fake_quant) diff --git a/tests/test_models/test_observers/test_torch_observers.py b/tests/test_models/test_observers/test_torch_observers.py new file mode 100644 index 000000000..cc32e69d8 --- /dev/null +++ b/tests/test_models/test_observers/test_torch_observers.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmrazor import digit_version +from mmrazor.models.observers import register_torch_observers +from mmrazor.registry import MODELS + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_register_torch_observers(): + + TORCH_observers = register_torch_observers() + assert isinstance(TORCH_observers, list) + for observer in TORCH_observers: + assert MODELS.get(observer) diff --git a/tests/test_structures/test_backendconfig.py b/tests/test_structures/test_backendconfig.py new file mode 100644 index 000000000..24295e391 --- /dev/null +++ b/tests/test_structures/test_backendconfig.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + from torch.ao.quantization.backend_config import BackendConfig +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + +import pytest +import torch + +from mmrazor import digit_version +from mmrazor.structures.quantization.backend_config import ( + BackendConfigs, get_academic_backend_config, + get_academic_backend_config_dict, get_native_backend_config, + get_native_backend_config_dict, get_openvino_backend_config, + get_openvino_backend_config_dict, get_tensorrt_backend_config, + get_tensorrt_backend_config_dict) + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_get_backend_config(): + + # test get_native_backend_config + native_backend_config = get_native_backend_config() + assert isinstance(native_backend_config, BackendConfig) + assert native_backend_config.name == 'native' + native_backend_config_dict = get_native_backend_config_dict() + assert isinstance(native_backend_config_dict, dict) + + # test get_academic_backend_config + academic_backend_config = get_academic_backend_config() + assert isinstance(academic_backend_config, BackendConfig) + assert academic_backend_config.name == 'academic' + academic_backend_config_dict = get_academic_backend_config_dict() + assert isinstance(academic_backend_config_dict, dict) + + # test get_openvino_backend_config + openvino_backend_config = get_openvino_backend_config() + assert isinstance(openvino_backend_config, BackendConfig) + assert openvino_backend_config.name == 'openvino' + openvino_backend_config_dict = get_openvino_backend_config_dict() + assert isinstance(openvino_backend_config_dict, dict) + + # test get_tensorrt_backend_config + tensorrt_backend_config = get_tensorrt_backend_config() + assert isinstance(tensorrt_backend_config, BackendConfig) + assert tensorrt_backend_config.name == 'tensorrt' + tensorrt_backend_config_dict = get_tensorrt_backend_config_dict() + assert isinstance(tensorrt_backend_config_dict, dict) + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_backendconfigs_mapping(): + + mapping = BackendConfigs + assert isinstance(mapping, dict) + assert 'academic' in mapping.keys() + assert isinstance(mapping['academic'], BackendConfig) From 4dcf3f130e76cd6c5933e140a04c9a6687e5245b Mon Sep 17 00:00:00 2001 From: Ivan Zhang <51170394+415905716@users.noreply.github.com> Date: Tue, 17 Jan 2023 11:19:22 +0800 Subject: [PATCH 09/44] [Docs] Add docstring for `MMArchitectureQuant` & `NativeQuantizer` (#425) * add docstring on mm_architecture& native_quantizer * add naive openvino r18 qat config & dist_ptq.sh * Added a more accurate description * unitest&doc * checkpoint url * unitest * passed_pre_commit * unitest on native_quantizer& fix bugs * remove dist_ptq * add get_placeholder&skipTest * complete arg descriptions * fix import bugs * fix pre-commit * add get_placeholder * add typehint and doctring * update docstring&typehint * update docstring * pre-commit * fix some problems * fix bug --- ...tq_openvino_mbv2_8xb32_in1k_calib32xb32.py | 10 +- ...penvino_resnet18_8xb32_in1k_calib32xb32.py | 6 +- ...penvino_resnet50_8xb32_in1k_calib32xb32.py | 5 +- .../minmax_openvino_resnet18_8xb32_in1k.py | 65 +++++ .../quantization/mm_architecture.py | 129 +++++++--- mmrazor/models/quantizers/native_quantizer.py | 116 +++++++-- .../test_algorithms/test_mm_architecture.py | 166 +++++++++++++ .../test_quantizers/test_native_quantizer.py | 228 ++++++++++++++++++ 8 files changed, 679 insertions(+), 46 deletions(-) create mode 100644 configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py create mode 100644 tests/test_models/test_algorithms/test_mm_architecture.py create mode 100644 tests/test_models/test_quantizers/test_native_quantizer.py diff --git a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py index df667c141..d7c9cdf47 100644 --- a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py @@ -17,12 +17,13 @@ qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), ) +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth' # noqa: E501 + model = dict( _delete_=True, type='mmrazor.MMArchitectureQuant', architecture=_base_.model, - float_checkpoint='/tmp/humu/mobilenet_v2_batch256_imagenet' + - '_20200708-3b2dc3af.pth', + float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', global_qconfig=global_qconfig, @@ -32,3 +33,8 @@ 'mmcls.models.heads.ClsHead._get_loss', 'mmcls.models.heads.ClsHead._get_predictions' ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py index 56da13de9..5ba1eec85 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -19,11 +19,13 @@ qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), ) +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + model = dict( _delete_=True, type='mmrazor.MMArchitectureQuant', architecture=_base_.model, - float_checkpoint='/tmp/humu/resnet18_8xb32_in1k_20210831-fbbb1da6.pth', + float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', global_qconfig=global_qconfig, @@ -33,3 +35,5 @@ 'mmcls.models.heads.ClsHead._get_loss', 'mmcls.models.heads.ClsHead._get_predictions' ]))) + +model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', ) diff --git a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py index 09e103bfc..bd734ee40 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py @@ -19,11 +19,13 @@ qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), ) +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501 + model = dict( _delete_=True, type='mmrazor.MMArchitectureQuant', architecture=_base_.model, - float_checkpoint='/tmp/humu/resnet50_8xb32_in1k_20210831-ea4938fc.pth', + float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', global_qconfig=global_qconfig, @@ -33,3 +35,4 @@ 'mmcls.models.heads.ClsHead._get_loss', 'mmcls.models.heads.ClsHead._get_predictions' ]))) +model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', ) diff --git a/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py new file mode 100644 index 000000000..8aa11d6b3 --- /dev/null +++ b/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py @@ -0,0 +1,65 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +train_dataloader = dict(batch_size=1024) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', + bit=8, + is_symmetry=True, + is_symmetric_range=True, + ), + a_qscheme=dict( + qdtype='quint8', + bit=8, + is_symmetry=True, + averaging_constant=0.1, + ), +) + +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + architecture=_base_.model, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, + type='CosineAnnealingLR', + T_max=100, + by_epoch=True, + begin=0, + end=100) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.QATEpochBasedLoop', + max_epochs=100, + val_interval=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') +# test_cfg = val_cfg diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index f5cf30f10..9feb3fb53 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -10,7 +10,7 @@ from mmrazor.models.task_modules.tracer import build_graphmodule from mmrazor.registry import MODEL_WRAPPERS, MODELS -from ..base import BaseAlgorithm +from ..base import BaseAlgorithm, BaseModel try: from torch.ao.quantization import FakeQuantizeBase @@ -29,35 +29,43 @@ class MMArchitectureQuant(BaseAlgorithm): """General quantization. Args: - architecture (dict | :obj:`BaseModel`): The config of - :class:`BaseModel` or built model. - quantizer (dict | :obj:`BaseModel`): The config of - :class:`BaseQuantizer` or built model. - export_mode (str): The mode of the model to be exported. Defaults to - predict. - qmodel_modes (list): The available mode of runner. - data_preprocessor (dict | torch.nn.Module | None): The pre-process + architecture (Union[Dict, BaseModel]): The config of model to be + quantized. + quantizer (Union[Dict, BaseModel]): The quantizer to support different + backend type. + qmodel_modes (List): The available mode of runner. + data_preprocessor (Optional[Dict]): The pre-process config of :class:`BaseDataPreprocessor`. Defaults to None. - pretrained_ckpt (str, Optional): The path of pretrained checkpoint. - Defaults to None. - init_cfg (dict): The weight initialized config for - :class:`BaseModule`. + forward_modes (Tuple): The modes in forward method in OpenMMLab + architecture could be tensor, predict, or loss. It can generate + different graph of quantized model. + float_checkpoint (Optional[str]): The path of pretrained FP checkpoint. + Quantization is different from or task, we recommend to use + `float_checkpoint` as pretrain model. Defaults to None. + init_cfg (Optional[Dict]): The weight initialized config for: + class:`BaseModule`. + + Note: + forward_modes (Tuple): In OpenMMLab architecture, differenet modes + will trace a different graph of quantized model. """ def __init__(self, - architecture, - quantizer, - data_preprocessor=None, - forward_modes=('tensor', 'predict', 'loss'), + architecture: Union[Dict, BaseModel], + quantizer: Union[Dict, BaseModel], + data_preprocessor: Optional[Dict] = None, + forward_modes: Tuple = ('tensor', 'predict', 'loss'), float_checkpoint: Optional[str] = None, - input_shapes=(1, 3, 224, 224), - init_cfg=None): + input_shapes: Tuple = (1, 3, 224, 224), + init_cfg: Optional[Dict] = None): if data_preprocessor is None: data_preprocessor = {} # The build process is in MMEngine, so we need to add scope here. + # Default to mmcls.ClsDataPreprocessor. data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor') super().__init__(architecture, data_preprocessor, init_cfg) + # If we have a float_checkpoint, we load it as pretrain. if float_checkpoint: _ = load_checkpoint(self.architecture, float_checkpoint) self.architecture._is_init = True @@ -70,7 +78,22 @@ def __init__(self, self.sync_qparams('predict') - def sync_qparams(self, src_mode): + def sync_qparams(self, src_mode: str): + """Sync all quantize parameters in different `forward_modes`. We could + have more than one forward mode to generate graphs, each mode will + generate one graph. But in training, only one graph will be update, so + we need to sync qparams in the other graphs. + + Args: + src_mode (str): The modes of forward method. + + Note: + `traverse()` function recursively traverses all module to sync + quantized graph generated from different `forward_modes`. + This is because We have different mode ('tensor', 'predict', + 'loss') in OpenMMLab architecture which have different graph + in some subtle ways, so we need to sync them here. + """ def traverse(module, prefix): for name, child in module._modules.items(): @@ -84,10 +107,10 @@ def traverse(module, prefix): if src_param.shape == param.shape: param.data.copy_(src_param) else: - requirs_grad = param.requires_grad - param.requires_grad = False + # requirs_grad = param.requires_grad + # param.requires_grad = False param.resize_(src_param.shape) - param.requires_grad = requirs_grad + # param.requires_grad = requirs_grad param.data.copy_(src_param) for name, buffer in child.named_buffers(): buffer_name = f'{child_name}.{name}' @@ -106,7 +129,31 @@ def traverse(module, prefix): continue traverse(self.qmodels[mode], '') - def _build_qmodels(self, model): + def _build_qmodels(self, model: BaseModel): + """Build quantized models from the given model. + + Args: + model (BaseModel): the given fp model. + + Example: + The main body of the graph is all the same, but the last one or two + op will have difference, as shown below. + + self.qmodels['tensor'].graph.print_tabular() + opcode target args + call_module head.fc (activation_post_process_38,) + output output (head_fc,) + + self.qmodels['loss'].graph.print_tabular() + opcode target args + call_method _get_loss (head, head_fc, data_samples) + output output (_get_loss,) + + self.qmodels['predict'].graph.print_tabular() + opcode target args + call_method _get_predictions (head, head_fc, data_samples) + output output (_get_predictions,) + """ qmodels = nn.ModuleDict() @@ -137,19 +184,27 @@ def forward(self, else: return self.architecture(inputs, data_samples, mode) - def calibrate_step(self, data): + def calibrate_step(self, data: Union[Dict, Tuple, List]): + """PTQ method need calibrate by cali data.""" + data = self.data_preprocessor(data, False) return self._run_forward(data, mode='predict') @MODEL_WRAPPERS.register_module() class MMArchitectureQuantDDP(MMDistributedDataParallel): - """DDPwapper for GeneralQuant.""" + """DDPwapper for GeneralQuant. + + Args: + device_ids (Optional[Union[List, int, torch.device]]): devices to run + ddp. + """ def __init__(self, *, device_ids: Optional[Union[List, int, torch.device]] = None, **kwargs) -> None: + if device_ids is None: if os.environ.get('LOCAL_RANK') is not None: device_ids = [int(os.environ['LOCAL_RANK'])] @@ -159,8 +214,26 @@ def __init__(self, self.module.qmodels = self.module._build_qmodels( self.module.architecture) - def calibrate_step(self, data): + def calibrate_step(self, data: Union[Dict, Tuple, List]): + """PTQ method need calibrate by cali data.""" + return self.module.calibrate_step(data) - def sync_qparams(self, src): + def sync_qparams(self, src: str): + """Same as in 'MMArchitectureQuant'. Sync all quantize parameters in + different `forward_modes`. We could have several modes to generate + graphs, but in training, only one graph will be update, so we need to + sync qparams on the other graphs. + + Args: + src (str): The src modes of forward method. + + Note: + `traverse()` function recursively traverses all module to sync + quantized graph generated from different `forward_modes`. + This is because We have different mode ('tensor', 'predict', + 'loss') in OpenMMLab architecture which have different graph + in some subtle ways, so we need to sync them here. + """ + self.module.sync_qparams(src) diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index b3f2002e5..d0534d361 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -1,17 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch +import torch.nn as nn +from mmengine.config import Config try: from torch.ao.quantization import enable_fake_quant from torch.ao.quantization.fx import prepare + from torch.ao.quantization.fx.graph_module import ObservedGraphModule from torch.ao.quantization.qconfig_mapping import QConfigMapping from torch.ao.quantization.quantize_fx import _fuse_fx + from torch.fx.graph_module import GraphModule from torch.nn.intrinsic.qat import modules as qat_fused_modules from torch.nn.qat import modules as qat_modules except ImportError: from mmrazor.utils import get_package_placeholder, get_placeholder + GraphModule = get_placeholder('torch>=1.13') + ObservedGraphModule = get_placeholder('torch>=1.13') enable_fake_quant = get_placeholder('torch>=1.13') prepare = get_placeholder('torch>=1.13') QConfigMapping = get_placeholder('torch>=1.13') @@ -56,17 +62,43 @@ @MODELS.register_module() class NativeQuantizer(BaseQuantizer): - """tmp.""" + """Native class for quantizer. + + Args: + global_qconfig (Union[Dict, Config]): Config for quantization details + of weight and activation include observer, quantizer, and qscheme. + no_observer_modules (Optional[List]): Modules don't need observer. + To fit different backend, we need qconfig to determine the modules + which don't need observer. + tracer (Dict): Config for tracer to trace modules for torch fx . + + Raises: + NotImplementedError: _description_ + + Examples: + >>> global_qconfig = dict( + ... w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + ... a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + ... w_fake_quant=dict(type='mmrazor.FakeQuantize'), + ... a_fake_quant=dict(type='mmrazor.FakeQuantize'), + ... w_qscheme=dict( + ... qdtype='qint8', bit=8, is_symmetry=True, + ... is_symmetric_range=True), + ... a_qscheme=dict( + ... qdtype='quint8', bit=8, is_symmetry=True, + ... averaging_constant=0.1), +) + """ # backend: 'native' # support_w_modes = ['per_tensor', 'per_channel'] # support_a_modes = ['per_tensor'] def __init__(self, - global_qconfig, - no_observer_modules=None, - tracer=dict(type='CustomTracer'), - extra_redundant_fakequants=dict( + global_qconfig: Union[Dict, Config], + no_observer_modules: Optional[List] = None, + tracer: Dict = dict(type='CustomTracer'), + extra_redundant_fakequants: Dict = dict( extra_module_prev_wo_fakequant=tuple(), extra_module_next_wo_fakequant=tuple(), extra_function_prev_wo_fakequant=tuple(), @@ -117,7 +149,28 @@ def support_a_modes(self): return ['per_tensor'] def prepare(self, model, graph_module): - """tmp.""" + """prepare graph to ObservedGraphModule. + + Args: + graph_module (_type_): GraphModules before fuse. + + Returns: + ObservedGraphModule: GraphModules after fuse and observer. + + Notes: + 'graph_module' after '_fuse_fx()' function will fuse conv, BN, ReLU + into modules in SUPPORT_QAT_MODULES. + 'graph_module' after 'prepare()' function will become observed. + + Notes: + Keep `is_qat` is True is because in Pytorch when `is_qat` is false, + the `_fuse_fx()` function only fuse module into `nn.Squential`. + In mmrazor, we aim to add more ptq algorithm into our pipeline such + as Adaround, these kind of ptq method have some additional + fake_quant operations that we need it to be fused into our + `SUPPORT_QAT_MODULES` type, which is a tricky way to deal with it. + """ + graph_module = _fuse_fx( graph_module=graph_module, is_qat=True, @@ -134,18 +187,41 @@ def prepare(self, model, graph_module): return prepared def post_process_weight_fakequant(self, - observed_module, - keep_fake_quant=False): - """tmp.""" + observed_module: ObservedGraphModule, + keep_fake_quant: bool = False): + """weight fake-quant for supported QAT modules. + + Args: + observed_module (ObservedGraphModule): Modules after fused and + observed. + keep_fake_quant (bool, optional): Bool to determine whether to keep + fake-quant op, depending on the backend. Defaults to False. + + Note: + `post_process_weight_fakequant()` function is necessary that the + `SUPPORT_QAT_MODULES` will be convert to normal modules, and + BN will be really integrated into conv layers. + """ def traverse(module): for name, child in module.named_children(): + # Trace `SUPPORT_QAT_MODULES` recursively. if isinstance(child, SUPPORT_QAT_MODULES): + # We add w_fakequant once in case some ptq methods have + # specific operations such as Adaround. So we do Quantize + # to perform these operations and do dequantize to + # introduce quantization loss in advance. weight_fakequant = child.weight_fake_quant child.weight.data = weight_fakequant(child.weight.data) + # `to_float()` function fuse BN into conv or conv_relu, and + # also convert a qat module to a normal module. + # source url: https://github.com/pytorch/pytorch/blob/master/torch/nn/intrinsic/qat/modules/conv_fused.py # noqa: E501 float_child = child.to_float() + # This is decided by backend type, some backend need + # explicitly keep the fake quant structure, others don't. + # TODO add deploy doc link if keep_fake_quant: for m in float_child.modules(): setattr(m, 'qconfig', self.qconfig.convert()) @@ -166,12 +242,24 @@ def traverse(module): observed_module.apply(enable_fake_quant) traverse(observed_module) - def prepare_for_mmdeploy(self, model, dummy_input, checkpoint): - """tmp.""" + def prepare_for_mmdeploy(self, model: nn.Module, dummy_input: Tuple, + checkpoint: Optional[str]): + """Prepare model to Observed_model.""" raise NotImplementedError - def del_redundant_fakequant(self, prepared): - """tmp.""" + def del_redundant_fakequant(self, prepared: GraphModule): + """delete redundant fakequant op in prepared model. + + Returns: + prepared (GraphModule): prepared model after delete redundant + fakequant op. + + Notes: + We can configure different ways to delete redundant nodes: + @property + def module_prev_wo_fakequant(self): + return (torch.nn.ReLU6, torch.nn.Identity) + """ extra_module_prev_wo_fakequant = self.extra_redundant_fakequants.get( 'extra_module_prev_wo_fakequant', tuple()) prepared = del_fakequant_before_module( diff --git a/tests/test_models/test_algorithms/test_mm_architecture.py b/tests/test_models/test_algorithms/test_mm_architecture.py new file mode 100644 index 000000000..4862bff91 --- /dev/null +++ b/tests/test_models/test_algorithms/test_mm_architecture.py @@ -0,0 +1,166 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import shutil +import tempfile +from unittest import TestCase + +import torch +import torch.nn as nn + +try: + from torch.fx import GraphModule +except ImportError: + from mmrazor.utils import get_placeholder + GraphModule = get_placeholder('torch>=1.13') + +from mmengine.model import BaseModel + +from mmrazor import digit_version +from mmrazor.models.algorithms import MMArchitectureQuant +from mmrazor.registry import MODELS + + +class BasicBlock(nn.Module): + + def __init__(self, in_channels, out_channels): + super(BasicBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = out_channels + + self.norm1 = nn.BatchNorm2d(self.mid_channels) + self.norm2 = nn.BatchNorm2d(out_channels) + self.conv1 = nn.Conv2d(in_channels, self.mid_channels, 1) + self.conv2 = nn.Conv2d(self.mid_channels, out_channels, 1) + + self.relu = nn.ReLU6() + self.drop_path = nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + out = self.drop_path(out) + + out += identity + + return out + + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@MODELS.register_module() +class ToyQuantModel(BaseModel): + + def __init__(self): + super().__init__() + self.stem_layer = nn.Sequential( + nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.block = BasicBlock(3, 3) + self.block2 = BasicBlock(3, 3) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(3, 4) + + def forward(self, x): + x = self.stem_layer(x) + x = self.maxpool(x) + x = self.block(x) + x = self.block2(x) + x = self.gap(x) + x = x.flatten(1) + x = self.fc(x) + return x + + +class TestMMArchitectureQuant(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.temp_dir = tempfile.mkdtemp() + filename = 'fp_model.pth' + filename = os.path.join(self.temp_dir, filename) + # import pdb; pdb.set_trace() + toymodel = ToyQuantModel() + torch.save(toymodel.state_dict(), filename) + + global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', + bit=8, + is_symmetry=True, + is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', + bit=8, + is_symmetry=True, + averaging_constant=0.1), + ) + alg_kwargs = dict( + type='mmrazor.MMArchitectureQuant', + architecture=dict(type='ToyQuantModel'), + float_checkpoint=filename, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + self.alg_kwargs = alg_kwargs + self.toy_model = MODELS.build(self.alg_kwargs) + + def tearDown(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + shutil.rmtree(self.temp_dir) + + def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + assert isinstance(self.toy_model, MMArchitectureQuant) + assert hasattr(self.toy_model, 'quantizer') + + def test_sync_qparams(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + mode = self.toy_model.forward_modes[0] + self.toy_model.sync_qparams(mode) + w_loss = self.toy_model.qmodels['loss'].block.conv1.state_dict( + )['weight'] + w_tensor = self.toy_model.qmodels['tensor'].block.conv1.state_dict( + )['weight'] + w_pred = self.toy_model.qmodels['predict'].block.conv1.state_dict( + )['weight'] + assert w_loss.equal(w_pred) + assert w_loss.equal(w_tensor) + + def test_build_qmodels(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + for forward_modes in self.toy_model.forward_modes: + qmodels = self.toy_model.qmodels[forward_modes] + assert isinstance(qmodels, GraphModule) + + def test_calibrate_step(self): + # TODO + pass diff --git a/tests/test_models/test_quantizers/test_native_quantizer.py b/tests/test_models/test_quantizers/test_native_quantizer.py new file mode 100644 index 000000000..afd6011ed --- /dev/null +++ b/tests/test_models/test_quantizers/test_native_quantizer.py @@ -0,0 +1,228 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections +from unittest import TestCase + +import torch +import torch.nn as nn + +from mmrazor import digit_version +from mmrazor.models.quantizers.native_quantizer import SUPPORT_QAT_MODULES +from mmrazor.models.task_modules.tracer import CustomTracer +from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ + build_graphmodule +from mmrazor.registry import MODELS +from mmrazor.structures.quantization import BackendConfigs, QConfigHander + +try: + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.fx.graph_module import ObservedGraphModule + from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.quantize_fx import _fuse_fx + from torch.fx import GraphModule +except ImportError: + from mmrazor.utils import get_placeholder + GraphModule = get_placeholder('torch>=1.13') + ObservedGraphModule = get_placeholder('torch>=1.13') + QConfigMapping = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + + +class BasicBlock(nn.Module): + + def __init__(self, in_channels, out_channels): + super(BasicBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = out_channels + + self.norm1 = nn.BatchNorm2d(self.mid_channels) + self.norm2 = nn.BatchNorm2d(out_channels) + self.conv1 = nn.Conv2d(in_channels, self.mid_channels, 1) + self.conv2 = nn.Conv2d(self.mid_channels, out_channels, 1) + + self.relu = nn.ReLU6() + self.drop_path = nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + out = self.drop_path(out) + + out += identity + + return out + + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class ToyQuantModel(nn.Module): + + def __init__(self): + super().__init__() + self.stem_layer = nn.Sequential( + nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.block = BasicBlock(3, 3) + self.block2 = BasicBlock(3, 3) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(3, 4) + + def forward(self, x): + x = self.stem_layer(x) + x = self.maxpool(x) + x = self.block(x) + x = self.block2(x) + x = self.gap(x) + x = x.flatten(1) + x = self.fc(x) + return x + + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1)) + +no_observer_modules = [ + 'torch.nn.Conv2d', +] + +q_kwargs = dict( + type='mmrazor.NativeQuantizer', + global_qconfig=global_qconfig, + no_observer_modules=no_observer_modules, + tracer=dict(type='CustomTracer'), +) + + +class TestNativeQuantizer(TestCase): + """TODO. + + Args: + TestCase (_type_): _description_ + """ + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.q_kwargs = q_kwargs + self.tracer = CustomTracer() + self.backend_config = BackendConfigs['native'] + self.qconfig = QConfigHander(global_qconfig) + self.qconfig_mapping = QConfigMapping().set_global( + self.qconfig.convert()) + self.example_inputs = (torch.randn(1, 3, 224, 224), ) + self.native_quantizer = MODELS.build(self.q_kwargs) + + def tearDown(self): + pass + + def swap_ff_with_fxff(self, model): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.ao.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + self.swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional() + + def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + native_quantizer = MODELS.build(self.q_kwargs) + no_ob_dict = collections.OrderedDict() + no_ob_dict = no_ob_dict.fromkeys(native_quantizer.no_observer_modules, + None) + assert native_quantizer.qconfig_mapping.object_type_qconfigs == \ + no_ob_dict + + def test_prepare(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + toy_model = ToyQuantModel() + toy_model.eval() + + self.swap_ff_with_fxff(toy_model) + traced_graph = self.tracer.trace(toy_model) + graph_module = build_graphmodule(toy_model, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + assert isinstance(graph_module, GraphModule) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + assert isinstance(prepared, ObservedGraphModule) + + prepared = self.native_quantizer.del_redundant_fakequant(prepared) + assert isinstance(prepared, GraphModule) + + def test_post_process_weight_fakequant(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + toy_model = ToyQuantModel() + toy_model.eval() + + self.swap_ff_with_fxff(toy_model) + traced_graph = self.tracer.trace(toy_model) + graph_module = build_graphmodule(toy_model, traced_graph) + + graph_module = _fuse_fx( + graph_module=graph_module, + is_qat=True, + backend_config=self.backend_config) + assert isinstance(graph_module, GraphModule) + prepared = prepare( + model=graph_module, + qconfig_mapping=self.qconfig_mapping, + is_qat=True, + node_name_to_scope=self.tracer.node_name_to_scope, + example_inputs=self.example_inputs, + backend_config=self.backend_config) + assert isinstance(prepared, ObservedGraphModule) + + prepared = self.native_quantizer.del_redundant_fakequant(prepared) + assert isinstance(prepared, GraphModule) + + prepared_no_fq = prepared + + self.native_quantizer.post_process_weight_fakequant(prepared) + for name, child in prepared.named_children(): + if isinstance(child, SUPPORT_QAT_MODULES): + raise ValueError + self.native_quantizer.post_process_weight_fakequant( + prepared_no_fq, True) + for name, child in prepared_no_fq.named_children(): + if isinstance(child, SUPPORT_QAT_MODULES): + raise ValueError From 854f2698b97a654daae3f8a4ef315f59da46c6fe Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Tue, 17 Jan 2023 11:20:13 +0800 Subject: [PATCH 10/44] [Docs] Add docstring and unitest about custom tracer (#427) * rename QConfigHandler and QSchemeHandler * add docstring about custom tracer * add ut about custom tracer * fix torch1.13 ci * fix lint * fix ci * fix ci --- .../models/quantizers/academic_quantizer.py | 13 +- mmrazor/models/quantizers/native_quantizer.py | 4 +- .../task_modules/tracer/fx/custom_tracer.py | 204 ++++++++++++------ mmrazor/structures/quantization/qconfig.py | 29 +-- .../test_task_modules/mmcls_cfg.py | 2 + .../test_task_modules/test_custom_tracer.py | 185 ++++++++++++++++ .../test_task_modules/test_graph_utils.py | 4 +- tests/test_structures/test_qconfig.py | 36 ++-- 8 files changed, 357 insertions(+), 120 deletions(-) create mode 100644 tests/data/test_models/test_task_modules/mmcls_cfg.py create mode 100644 tests/test_models/test_task_modules/test_custom_tracer.py diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py index 768f51c53..09cfc7944 100644 --- a/mmrazor/models/quantizers/academic_quantizer.py +++ b/mmrazor/models/quantizers/academic_quantizer.py @@ -2,7 +2,7 @@ import torch from mmrazor.registry import MODELS -from mmrazor.structures.quantization import BackendConfigs, QConfigHander +from mmrazor.structures.quantization import BackendConfigs, QConfigHandler from .base import BaseQuantizer try: @@ -75,24 +75,25 @@ def gen_qconfig_mapping(self, qconfig_mapping): """tmp.""" conf = QConfigMapping() if GLOBAL_DICT_KEY in qconfig_mapping: - qconfig = QConfigHander(qconfig_mapping[GLOBAL_DICT_KEY]).convert() + qconfig = QConfigHandler( + qconfig_mapping[GLOBAL_DICT_KEY]).convert() conf.set_global(qconfig) for object_type, qconfig in qconfig_mapping.get( OBJECT_TYPE_DICT_KEY, []): - qconfig = QConfigHander(qconfig).convert() + qconfig = QConfigHandler(qconfig).convert() conf.set_object_type(object_type, qconfig) for module_name_regex, qconfig in qconfig_mapping.get( MODULE_NAME_REGEX_DICT_KEY, []): - qconfig = QConfigHander(qconfig).convert() + qconfig = QConfigHandler(qconfig).convert() conf.set_module_name_regex(module_name_regex, qconfig) for module_name, qconfig in qconfig_mapping.get( MODULE_NAME_DICT_KEY, []): - qconfig = QConfigHander(qconfig).convert() + qconfig = QConfigHandler(qconfig).convert() conf.set_module_name(module_name, qconfig) for module_name, object_type, index, qconfig in qconfig_mapping.get( MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []): - qconfig = QConfigHander(qconfig).convert() + qconfig = QConfigHandler(qconfig).convert() conf.set_module_name_object_type_order(module_name, object_type, index, qconfig) diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index d0534d361..2b75cf29c 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -33,7 +33,7 @@ del_fakequant_before_module, del_fakequant_before_op) from mmrazor.models.utils import str2class from mmrazor.registry import MODELS -from mmrazor.structures.quantization import BackendConfigs, QConfigHander +from mmrazor.structures.quantization import BackendConfigs, QConfigHandler from .base import BaseQuantizer if digit_version(torch.__version__) >= digit_version('1.13.0'): @@ -108,7 +108,7 @@ def __init__(self, extra_op_prev_wo_fakequant=tuple(), extra_op_next_wo_fakequant=tuple())): super().__init__(tracer) - self.qconfig = QConfigHander(global_qconfig) + self.qconfig = QConfigHandler(global_qconfig) if self.qconfig.w_qscheme.is_per_channel: w_mode = 'per_channel' else: diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py index 2d33e9875..a3cff1167 100644 --- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import functools -from types import FunctionType, MethodType -from typing import Any, Callable, Dict, List, Optional, Type, Union +from types import FunctionType +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch import torch.nn as nn @@ -34,18 +34,24 @@ class UntracedMethodRegistry: - """A `Descriptor` class which records untraced methods.""" + """A `Descriptor` class which records untraced methods. Thus, when the + class is traced with CustomTracer, the decorated method will be as a leaf + node, not be nested traced. + + Example: + >>> # `imported_cls` is the owner of the untraced method; + >>> # `method_str` is the name of the untraced method. + >>> method_registry = UntracedMethodRegistry(method) + >>> method_registry.__set_name__(imported_cls, method_str) + + Args: + method (FunctionType): Function to be registered. + """ method_dict: Dict = dict() tracer = None - def __init__(self, method): - """_summary_ - - Args: - method (FunctionType): Function to be registered. - """ + def __init__(self, method: FunctionType): self.method = method - self.instances: Dict = dict() self.owner = None def __set_name__(self, owner, name): @@ -54,11 +60,6 @@ def __set_name__(self, owner, name): wrapped = self.method_wrapper() self.method_dict[name] = dict(mod=self.owner, wrapped=wrapped) - def __get__(self, instance, owner): - if instance is None: - return self.method - return MethodType(self.method, instance) - def method_wrapper(self): @functools.wraps(self.method) @@ -73,33 +74,12 @@ def method(*args, **kwargs): return wrapped_method -def custom_symbolic_trace(root: Union[nn.Module, Callable[..., Any]], - concrete_args: Optional[Dict[str, Any]] = None): - """Modified `symbolic_trace` function. - - Args: - root (Union[nn.Module, Callable]): Module or function to be - traced and converted into a Graph representation. - concrete_args (Optional[Dict[str, any]]): Inputs to be partially - specialized. - - Returns: - _type_: _description_ - """ - tracer = CustomTracer() - graph = tracer.trace(root, concrete_args) - name = root.__class__.__name__ if isinstance(root, - nn.Module) else root.__name__ - return GraphModule(tracer.root, graph, name) - - -def _prepare_module_dict(model: nn.Module, fx_graph): +def _prepare_module_dict(model: torch.nn.Module, fx_graph): """If there is a class method that can not be traced by the symbolic tracer, a ``call_method`` ``Node`` will be inserted into the ``Graph`` in ``CustomTracer``. - For example, - ``` + Example: >>> class Model: ... def __init__(self): ... self.head = ClsHead() @@ -123,7 +103,7 @@ def _prepare_module_dict(model: nn.Module, fx_graph): ... xxx ... losses = xxx ... return losses - ``` + As the ``_get_loss`` can not be traced by torch.fx, ``Toy._get_loss`` need to be added to ``skipped_methods`` in ``CustomTracer``. Hence the code above will product the following Graph:: @@ -140,8 +120,10 @@ def _prepare_module_dict(model: nn.Module, fx_graph): the original model. Args: - model (nn.Module): The original model. - fx_graph (Graph): The fx Graph traced by fx tracer. + model (torch.nn.Module): Module or function to be + traced and converted into a Graph representation. + fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. It + contains the nodes this GraphModule should use for code generation. """ def _get_attrs(target, attrs): @@ -170,7 +152,32 @@ def _get_attrs(target, attrs): return module_dict -def build_graphmodule(model: nn.Module, fx_graph, name: str = 'GraphModule'): +def build_graphmodule(model: torch.nn.Module, + fx_graph, + name: str = 'GraphModule'): + """To build GraphModule with the generated graph by CustomTracer. The + implement of skipping methods in CustomTracer will cause the confliction of + that a node is both a leaf node and non-leaf node, which will lead that the + modification to the ``graph`` also change the original ``forward``. + + Args: + model (torch.nn.Module): Module or function to be + traced and converted into a Graph representation. + fx_graph (torch.fx.Graph): The fx Graph traced by fx tracer. It + contains the nodes this GraphModule should use for code generation. + name (str): The name of generated GraphModule. + + Returns: + GraphModule: GraphModule is an nn.Module generated from an fx.Graph. + Graphmodule has a ``graph`` attribute, as well as ``code`` and + ``forward`` attributes generated from that ``graph``. + + .. warning:: + When ``graph`` is reassigned, ``code`` and ``forward`` will be + automatically regenerated. However, if you edit the contents of the + ``graph`` without reassigning the ``graph`` attribute itself, you must + call ``recompile()`` to update the generated code. + """ modules = dict(model.named_modules()) module_dict = _prepare_module_dict(model, fx_graph) modules.update(module_dict) @@ -179,6 +186,18 @@ def build_graphmodule(model: nn.Module, fx_graph, name: str = 'GraphModule'): @TASK_UTILS.register_module() class CustomTracer(QuantizationTracer): + """Custom tracer based on QuantizationTracer of pytorch. It can not only + skip some modules and classes while tracing, but also skip some methods + untraced by torch.fx.Tracer. + + Args: + skipped_methods (List[str], optional): Methods to be skipped while + tracing. Defaults to None. + skipped_module_names (List[str], optional): Modules to be skipped + while tracing. Defaults to None. + skipped_module_classes (List[Callable], optional): Class to be skipped + while tracing. Defaults to None. + """ def __init__(self, skipped_methods: List[str] = [], @@ -186,16 +205,6 @@ def __init__(self, skipped_module_classes: List[Callable] = [], *args, **kwargs): - """_summary_ - - Args: - skipped_methods (List[str], optional): Methods to be skipped while - tracing. Defaults to None. - skipped_module_names (List[str], optional): Modules to be skipped - while tracing. Defaults to None. - skipped_module_classes (List[str], optional): Class to be skipped - while tracing. Defaults to None. - """ super(CustomTracer, self).__init__(skipped_module_names, skipped_module_classes) UntracedMethodRegistry.tracer = self # type: ignore @@ -214,6 +223,7 @@ def _check_valid_source(source): 'source must have at least one `.`' def register_skipped_methods(self): + """Register skipped methods to UntracedMethodRegistry.method_dict.""" if not isinstance(self.skipped_methods, list): self.skipped_methods = [self.skipped_methods] for s_method in self.skipped_methods: @@ -239,7 +249,8 @@ def register_skipped_methods(self): method_registry = UntracedMethodRegistry(method) method_registry.__set_name__(imported_cls, method_str) - def call_method(self, m: nn.Module, name, method, args, kwargs): + def call_method(self, m: torch.nn.Module, name: str, method: Callable, + args: Tuple, kwargs: Dict): """Method that specifies the behavior of this ``Tracer`` when it encounters a call to an ``nn.Module`` instance. @@ -254,15 +265,13 @@ def call_method(self, m: nn.Module, name, method, args, kwargs): ``Module`` boundaries. Args: - - m (Module): The module for which a call is being emitted - forward (Callable): The forward() method of the ``Module`` to be - invoked + m (torch.nn.Module): The module for which a call is being emitted + name (str): The name of proxy to be created. + method (Callable): The method of the ``Module`` to be invoked args (Tuple): args of the module callsite kwargs (Dict): kwargs of the module callsite Return: - The return value from the Module call. In the case that a ``call_module`` node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever value was returned from the ``Module`` @@ -271,16 +280,37 @@ def call_method(self, m: nn.Module, name, method, args, kwargs): # module_qualified_name = self.path_of_module(m) if not self.is_skipped_method(m): return method(*args, **kwargs) - args = list(args) - args.insert(0, m) - args = tuple(args) + args_l = list(args) + args_l.insert(0, m) + args = tuple(args_l) return self.create_proxy('call_method', name, args, kwargs) - def trace(self, root, concrete_args=None): - if isinstance(root, nn.Module): + def trace(self, + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None) -> Graph: + """Trace ``root`` and return the corresponding FX ``Graph`` + representation. ``root`` can either be an ``nn.Module`` instance or a + Python callable. Note that after this call, ``self.root`` may be + different from the ``root`` passed in here. For example, when a free + function is passed to ``trace()``, we will create an ``nn.Module`` + instance to use as the root and add embedded constants to. + + Args: + root (Union[Module, Callable]): Either a ``Module`` or a function + to be traced through. Backwards-compatibility for this + parameter is guaranteed. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that + should not be treated as Proxies. This parameter is + experimental and its backwards-compatibility is *NOT* + guaranteed. + + Returns: + A ``Graph`` representing the semantics of the passed-in ``root``. + """ + if isinstance(root, torch.nn.Module): self.root = root fn = type(root).forward - self.submodule_paths = { + self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = { mod: name for name, mod in root.named_modules() } @@ -364,13 +394,53 @@ def forward(*args, **kwargs): return self.graph - def is_skipped_method(self, m): + def is_skipped_method(self, m: torch.nn.Module): + """Judge if ``m`` is registered skipped method.""" mods = tuple(value['mod'] for value in UntracedMethodRegistry.method_dict.values()) custom = isinstance(m, mods) return custom - def is_leaf_module(self, m: nn.Module, module_qualified_name: str) -> bool: - # return super().is_leaf_module(m, module_qualified_name) + def is_leaf_module(self, m: torch.nn.Module, + module_qualified_name: str) -> bool: + """A method to specify whether a given ``nn.Module`` is a "leaf" + module. Leaf modules are the atomic units that appear in the IR, + referenced by ``call_module`` calls. By default, Modules in the PyTorch + standard library namespace (torch.nn) are leaf modules. All other + modules are traced through and their constituent ops are recorded, + unless specified otherwise via this parameter. + + Args: + m (Module): The module being queried about + module_qualified_name (str): The path to root of this module. + For example, if you have a module hierarchy where submodule + ``foo`` contains submodule ``bar``, which contains submodule + ``baz``, that module will appear with the qualified name + ``foo.bar.baz`` here. + """ leaf = super().is_leaf_module(m, module_qualified_name) return leaf + + +def custom_symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule: + """Modified `symbolic_trace` function in pytorch. Given an ``nn.Module`` or + function instance ``root``, this function will return a ``GraphModule`` + constructed by recording operations seen while tracing through ``root``. + + Args: + root (torch.nn.Module): Module or function to be + traced and converted into a Graph representation. + concrete_args (Optional[Dict[str, any]]): Inputs to be partially + specialized. + + Returns: + GraphModule: a Module created from the recorded operations from + ``root``. + """ + tracer = CustomTracer() + graph = tracer.trace(root, concrete_args) + name = root.__class__.__name__ if isinstance( + root, torch.nn.Module) else root.__name__ + return GraphModule(tracer.root, graph, name) diff --git a/mmrazor/structures/quantization/qconfig.py b/mmrazor/structures/quantization/qconfig.py index e0fdf113d..2a502b8f7 100644 --- a/mmrazor/structures/quantization/qconfig.py +++ b/mmrazor/structures/quantization/qconfig.py @@ -18,7 +18,7 @@ ] -class QConfigHander(): +class QConfigHandler(): """Convert custom user-friendly qconfig format to torch's QConfig. Args: @@ -44,9 +44,9 @@ def __init__(self, qconfig: Union[Dict, Config]): w_is_per_channel = True if 'PerChannel' in a_observer.__name__: a_is_per_channel = True - self.w_qscheme = QSchemeHander( + self.w_qscheme = QSchemeHandler( is_per_channel=w_is_per_channel, **qconfig['w_qscheme']) - self.a_qscheme = QSchemeHander( + self.a_qscheme = QSchemeHandler( is_per_channel=a_is_per_channel, **qconfig['a_qscheme']) w_fake_quant = MODELS.get(qconfig['w_fake_quant']['type']) @@ -79,7 +79,7 @@ def convert(self): return torch_qconfig -class QSchemeHander(object): +class QSchemeHandler(object): """Convert the qscheme of custom user-friendly qconfig to args needed in observers. @@ -149,24 +149,3 @@ def __str__(self): return f'dtype: {self.dtype} / bit: {self.bit} / is_symmetry: {self.is_symmetry} / \ is_per_channel: {self.is_per_channel} \ / extra_kwargs: {self.kwargs}' - - -if __name__ == '__main__': - from mmrazor.models.fake_quants import register_torch_fake_quants - from mmrazor.models.observers import register_torch_observers - register_torch_observers() - register_torch_fake_quants() - - qconfig = dict( - w_observer=dict(type='mmrazor.MovingAveragePerChannelMinMaxObserver'), - a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), - w_fake_quant=dict(type='mmrazor.FakeQuantize'), - a_fake_quant=dict(type='mmrazor.FakeQuantize'), - w_qscheme=dict( - qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), - a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), - ) - from mmengine.config import Config - qconfig = Config(qconfig) - torch_qconfig = QConfigHander(qconfig).convert() - print(torch_qconfig) diff --git a/tests/data/test_models/test_task_modules/mmcls_cfg.py b/tests/data/test_models/test_task_modules/mmcls_cfg.py new file mode 100644 index 000000000..117b9383e --- /dev/null +++ b/tests/data/test_models/test_task_modules/mmcls_cfg.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] \ No newline at end of file diff --git a/tests/test_models/test_task_modules/test_custom_tracer.py b/tests/test_models/test_task_modules/test_custom_tracer.py new file mode 100644 index 000000000..207e9ccad --- /dev/null +++ b/tests/test_models/test_task_modules/test_custom_tracer.py @@ -0,0 +1,185 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import pytest +import torch +from mmcls.models.backbones.resnet import ResLayer +from mmengine.config import Config +from mmengine.registry import MODELS + +try: + from torch.fx import GraphModule + from torch.fx._symbolic_trace import Graph +except ImportError: + from mmrazor.utils import get_placeholder + GraphModule = get_placeholder('torch>=1.13') + Graph = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.task_modules.tracer import (CustomTracer, + UntracedMethodRegistry, + build_graphmodule, + custom_symbolic_trace) +from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ + _prepare_module_dict + + +class ToyModel(torch.nn.Module): + + def __init__(self): + super.__init__() + + def get_loss(self, x): + return x * 0.1 + + def extrac_feature(self, x): + return x * 2 + + def forward(self, x): + x = self.extrac_feature(x) + x = self.get_loss(x) + return x + + +class testUntracedMethodRgistry(TestCase): + + def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + method = ToyModel.get_loss + method_registry = UntracedMethodRegistry(method) + assert hasattr(method_registry, 'method') + assert hasattr(method_registry, 'method_dict') + assert len(method_registry.method_dict) == 0 + + def test_registry_method(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + model = ToyModel + method = ToyModel.get_loss + method_registry = UntracedMethodRegistry(method) + method_registry.__set_name__(model, 'get_loss') + assert 'get_loss' in method_registry.method_dict.keys() + assert method_registry.method_dict['get_loss']['mod'] == model + + +class testCustomTracer(TestCase): + + def setUp(self): + self.cfg = Config.fromfile( + 'tests/data/test_models/test_task_modules/mmcls_cfg.py') + self.skipped_methods = [ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ] + self.skipped_module_names = ['backbone.layer4.0'] + self.skipped_module_classes = [ResLayer] + + def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + # init without skipped_methods + tracer = CustomTracer() + assert hasattr(tracer, 'skipped_methods') + assert len(tracer.skipped_methods) == 0 + # init with skipped_methods(list) + UntracedMethodRegistry.method_dict = dict() + tracer = CustomTracer(skipped_methods=self.skipped_methods) + assert '_get_loss' in UntracedMethodRegistry.method_dict.keys() + assert '_get_predictions' in UntracedMethodRegistry.method_dict.keys() + # init with skipped_methods(str) + UntracedMethodRegistry.method_dict = dict() + tracer = CustomTracer(skipped_methods=self.skipped_methods[0]) + assert '_get_loss' in UntracedMethodRegistry.method_dict.keys() + # init with skipped_methods(int, error) + with self.assertRaises(TypeError): + CustomTracer(skipped_methods=123) + # init with skipped_methods(str, error) + with self.assertRaises(AssertionError): + CustomTracer(skipped_methods='_get_loss') + + def test_trace(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + # test trace with skipped_methods + model = MODELS.build(self.cfg.model) + UntracedMethodRegistry.method_dict = dict() + tracer = CustomTracer(skipped_methods=self.skipped_methods) + graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'}) + graph_loss = tracer.trace(model, concrete_args={'mode': 'loss'}) + graph_predict = tracer.trace(model, concrete_args={'mode': 'predict'}) + assert isinstance(graph_tensor, Graph) + assert isinstance(graph_loss, Graph) + skip_flag_loss = False + for node in graph_loss.nodes: + if node.op == 'call_method' and node.target == '_get_loss': + skip_flag_loss = True + assert isinstance(graph_predict, Graph) + skip_flag_predict = False + for node in graph_predict.nodes: + if node.op == 'call_method' and node.target == '_get_predictions': + skip_flag_predict = True + assert skip_flag_loss and skip_flag_predict + + # test trace with skipped_module_names + model = MODELS.build(self.cfg.model) + UntracedMethodRegistry.method_dict = dict() + tracer = CustomTracer(skipped_module_names=self.skipped_module_names) + graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'}) + skip_flag = False + for node in graph_tensor.nodes: + skipped_module_name = self.skipped_module_names[0] + if node.op == 'call_module' and node.target == skipped_module_name: + skip_flag = True + assert skip_flag + + # test trace with skipped_module_classes + model = MODELS.build(self.cfg.model) + UntracedMethodRegistry.method_dict = dict() + tracer = CustomTracer( + skipped_module_classes=self.skipped_module_classes) + graph_tensor = tracer.trace(model, concrete_args={'mode': 'tensor'}) + skip_flag = False + for node in graph_tensor.nodes: + if node.op == 'call_module' and node.target == 'backbone.layer1': + skip_flag = True + assert skip_flag + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_custom_symbolic_trace(): + cfg = Config.fromfile( + 'tests/data/test_models/test_task_modules/mmcls_cfg.py') + model = MODELS.build(cfg.model) + UntracedMethodRegistry.method_dict = dict() + graph_module = custom_symbolic_trace( + model, concrete_args={'mode': 'tensor'}) + assert isinstance(graph_module, GraphModule) + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_build_graphmodule(): + skipped_methods = ['mmcls.models.heads.ClsHead._get_predictions'] + cfg = Config.fromfile( + 'tests/data/test_models/test_task_modules/mmcls_cfg.py') + model = MODELS.build(cfg.model) + UntracedMethodRegistry.method_dict = dict() + tracer = CustomTracer(skipped_methods=skipped_methods) + graph_predict = tracer.trace(model, concrete_args={'mode': 'predict'}) + graph_module = build_graphmodule(model, graph_predict) + assert isinstance(graph_module, GraphModule) + + # test _prepare_module_dict + modules = dict(model.named_modules()) + module_dict = _prepare_module_dict(model, graph_predict) + for k, v in module_dict.items(): + assert isinstance(v, torch.nn.Module) + assert not isinstance(v, modules[k].__class__) diff --git a/tests/test_models/test_task_modules/test_graph_utils.py b/tests/test_models/test_task_modules/test_graph_utils.py index d8f53c03c..ea7f90565 100644 --- a/tests/test_models/test_task_modules/test_graph_utils.py +++ b/tests/test_models/test_task_modules/test_graph_utils.py @@ -24,7 +24,7 @@ del_fakequant_after_module, del_fakequant_after_op, del_fakequant_before_function, del_fakequant_before_method, del_fakequant_before_module, del_fakequant_before_op) -from mmrazor.structures.quantization import BackendConfigs, QConfigHander +from mmrazor.structures.quantization import BackendConfigs, QConfigHandler def _get_attrs(target, attrs): @@ -119,7 +119,7 @@ def setUp(self): self.tracer = CustomTracer() self.backend_config = BackendConfigs['native'] - self.qconfig = QConfigHander(global_qconfig) + self.qconfig = QConfigHandler(global_qconfig) self.qconfig_mapping = QConfigMapping().set_global( self.qconfig.convert()) self.example_inputs = (torch.randn(1, 3, 224, 224), ) diff --git a/tests/test_structures/test_qconfig.py b/tests/test_structures/test_qconfig.py index 4730ab6cc..d4f98394a 100644 --- a/tests/test_structures/test_qconfig.py +++ b/tests/test_structures/test_qconfig.py @@ -14,32 +14,32 @@ from mmrazor import digit_version from mmrazor.models.fake_quants import register_torch_fake_quants from mmrazor.models.observers import register_torch_observers -from mmrazor.structures import QConfigHander, QSchemeHander +from mmrazor.structures import QConfigHandler, QSchemeHandler register_torch_observers() register_torch_fake_quants() -class TestQSchemeHander(TestCase): +class TestQSchemeHandler(TestCase): def test_init(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') # per_channel - qscheme = QSchemeHander(is_symmetry=True, is_per_channel=True) + qscheme = QSchemeHandler(is_symmetry=True, is_per_channel=True) assert qscheme.torch_qscheme is torch.per_channel_symmetric # per_tensor - qscheme = QSchemeHander(is_symmetry=True, is_per_channel=False) + qscheme = QSchemeHandler(is_symmetry=True, is_per_channel=False) assert qscheme.torch_qscheme is torch.per_tensor_symmetric # qdtype is incorrect - self.assertRaises(AssertionError, QSchemeHander, 'float') + self.assertRaises(AssertionError, QSchemeHandler, 'float') # is_symmetric_range kwargs = {'is_symmetric_range': True} - qscheme = QSchemeHander(**kwargs) + qscheme = QSchemeHandler(**kwargs) assert qscheme.is_symmetric_range is True def test_to_observer_params(self): @@ -47,32 +47,32 @@ def test_to_observer_params(self): self.skipTest('version of torch < 1.13.0') # qdtype = quint8 - ret_params = QSchemeHander(qdtype='quint8').to_observer_params() + ret_params = QSchemeHandler(qdtype='quint8').to_observer_params() assert ret_params['dtype'] == torch.quint8 assert ret_params['quant_min'] == 0 and ret_params['quant_max'] == 255 # qdtype = qint8, is_symmetric_range=False - ret_params = QSchemeHander(qdtype='qint8').to_observer_params() + ret_params = QSchemeHandler(qdtype='qint8').to_observer_params() assert ret_params['dtype'] == torch.qint8 assert ret_params['quant_min'] == -128 and ret_params[ 'quant_max'] == 127 # qdtype = qint8, is_symmetric_range=True - ret_params = QSchemeHander( + ret_params = QSchemeHandler( qdtype='qint8', is_symmetric_range=True).to_observer_params() assert ret_params['quant_min'] == -127 and ret_params[ 'quant_max'] == 127 # per_channel - ret_params = QSchemeHander(is_per_channel=True).to_observer_params() + ret_params = QSchemeHandler(is_per_channel=True).to_observer_params() assert ret_params['ch_axis'] == 0 # per_tensor - ret_params = QSchemeHander(is_per_channel=False).to_observer_params() + ret_params = QSchemeHandler(is_per_channel=False).to_observer_params() assert 'ch_axis' not in ret_params.keys() -class TestQConfigHander(TestCase): +class TestQConfigHandler(TestCase): def setUp(self): self.qconfig_dict = dict( @@ -93,26 +93,26 @@ def test_check_qconfig(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') - assert QConfigHander.check_qconfig(self.qconfig_dict) is True - assert QConfigHander.check_qconfig(self.qconfig) is True + assert QConfigHandler.check_qconfig(self.qconfig_dict) is True + assert QConfigHandler.check_qconfig(self.qconfig) is True qconfig_dict = copy.copy(self.qconfig_dict) print(qconfig_dict) qconfig_dict.pop('w_observer') - assert QConfigHander.check_qconfig(qconfig_dict) is False + assert QConfigHandler.check_qconfig(qconfig_dict) is False def test_init(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') # test dict init - qconfig = QConfigHander(self.qconfig_dict) + qconfig = QConfigHandler(self.qconfig_dict) assert hasattr(qconfig, 'w_qscheme') assert hasattr(qconfig, 'a_qscheme') assert hasattr(qconfig, 'w_fake_quant') assert hasattr(qconfig, 'a_fake_quant') # test mmengine's Config init - qconfig = QConfigHander(self.qconfig) + qconfig = QConfigHandler(self.qconfig) assert hasattr(qconfig, 'w_qscheme') assert hasattr(qconfig, 'a_qscheme') assert hasattr(qconfig, 'w_fake_quant') @@ -126,6 +126,6 @@ def test_convert(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') - qconfig = QConfigHander(self.qconfig) + qconfig = QConfigHandler(self.qconfig) torch_qconfig = qconfig.convert() assert isinstance(torch_qconfig, QConfig) From 2eea077bb0db4a31f44a922238ba826b321a8f63 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Tue, 17 Jan 2023 18:34:32 +0800 Subject: [PATCH 11/44] [Docs & Refactor] Add docstring and UT of other quantizers (#439) * add quantizer docstring and refactor the interface of AcademicQuantizer * add AcademicQuantizer unittest * add TensorRTQuantizer and OpenVINOQuantizer unittest & refactor prepare interface * adapt torch113 ci * fix import * fix lint * update some docstring * fix ci --- .../quantization/mm_architecture.py | 14 +- .../models/quantizers/academic_quantizer.py | 118 +++++++++---- mmrazor/models/quantizers/base.py | 32 +++- mmrazor/models/quantizers/native_quantizer.py | 50 +++--- .../models/quantizers/openvino_quantizer.py | 63 ++++--- .../models/quantizers/tensorrt_quantizer.py | 55 +++--- mmrazor/testing/_fx_models.py | 2 + .../test_academic_quantizer.py | 167 ++++++++++++++++++ .../test_quantizers/test_native_quantizer.py | 4 +- .../test_openvino_quantizer.py | 78 ++++++++ .../test_tensorrt_quantizer.py | 74 ++++++++ .../test_quantizers/test_trt_quantizer.py | 34 ---- .../test_task_modules/test_custom_tracer.py | 1 - 13 files changed, 532 insertions(+), 160 deletions(-) create mode 100644 tests/test_models/test_quantizers/test_academic_quantizer.py create mode 100644 tests/test_models/test_quantizers/test_openvino_quantizer.py create mode 100644 tests/test_models/test_quantizers/test_tensorrt_quantizer.py delete mode 100644 tests/test_models/test_quantizers/test_trt_quantizer.py diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index 9feb3fb53..afdd7799c 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -8,7 +8,6 @@ from mmengine.structures import BaseDataElement from torch import nn -from mmrazor.models.task_modules.tracer import build_graphmodule from mmrazor.registry import MODEL_WRAPPERS, MODELS from ..base import BaseAlgorithm, BaseModel @@ -156,20 +155,10 @@ def _build_qmodels(self, model: BaseModel): """ qmodels = nn.ModuleDict() - - self.quantizer.swap_ff_with_fxff(model) - tracer = self.quantizer.tracer - for mode in self.forward_modes: concrete_args = {'mode': mode} - traced_graph = tracer.trace(model, concrete_args=concrete_args) - graph_mopdule = build_graphmodule(model, traced_graph) - observed_module = self.quantizer.prepare(model, graph_mopdule) + observed_module = self.quantizer.prepare(model, concrete_args) qmodels[mode] = observed_module - # import pdb - # pdb.set_trace() - # dummy_input = torch.randn(self.input_shapes) - # qmodels['predict'](dummy_input, None, 'predict') return qmodels @@ -177,6 +166,7 @@ def forward(self, inputs: torch.Tensor, data_samples: Optional[List[BaseDataElement]] = None, mode: str = 'tensor') -> ForwardResults: + """Forward with qmodels in quantization.""" if mode in self.qmodels: qmodel = self.qmodels[mode] diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py index 09cfc7944..a6cfc257c 100644 --- a/mmrazor/models/quantizers/academic_quantizer.py +++ b/mmrazor/models/quantizers/academic_quantizer.py @@ -1,6 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional + import torch +from mmrazor.models.task_modules.tracer import build_graphmodule +from mmrazor.models.utils import str2class from mmrazor.registry import MODELS from mmrazor.structures.quantization import BackendConfigs, QConfigHandler from .base import BaseQuantizer @@ -10,7 +14,6 @@ from torch.ao.quantization.fx.custom_config import (FuseCustomConfig, PrepareCustomConfig) from torch.ao.quantization.qconfig_mapping import QConfigMapping - from torch.ao.quantization.quant_type import _quant_type_from_str from torch.ao.quantization.quantize_fx import _fuse_fx except ImportError: from mmrazor.utils import get_placeholder @@ -18,37 +21,83 @@ FuseCustomConfig = get_placeholder('torch>=1.13') PrepareCustomConfig = get_placeholder('torch>=1.13') QConfigMapping = get_placeholder('torch>=1.13') - _quant_type_from_str = get_placeholder('torch>=1.13') _fuse_fx = get_placeholder('torch>=1.13') GLOBAL_DICT_KEY = '_global_' OBJECT_TYPE_DICT_KEY = 'object_type' -MODULE_NAME_REGEX_DICT_KEY = 'module_name_regex' MODULE_NAME_DICT_KEY = 'module_name' -MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = 'module_name_object_type_order' +# keys can be used in `prepare_custom_config` of `AcademicQuantizer`. FLOAT_TO_OBSERVED_DICT_KEY = 'float_to_observed_custom_module_class' PRESERVED_ATTRIBUTES_DICT_KEY = 'preserved_attributes' @MODELS.register_module() class AcademicQuantizer(BaseQuantizer): - """tmp.""" + """Quantizer for academic researching. Different from some quantizers for + deploying, `AcademicQuantizer` is without the interfaces for deployment, + but it has more flexible functions for quantizing your model. With its + help, you can custom configuration qconfig for differenet OP by + `qconfig_mapping` to implement customized experiments, including using + custom fakquant, trying mixed precision quantization, comparing different + quantization scheme and so on. + + Args: + qconfig_mapping (Dict): Mapping from model ops to qconfig to configure + how a model is quantized. You can specify qconfigs using the + following keys (in increasing match priority): + ``_global_`` : sets the global (default) qconfig + ``object_type`` : sets the qconfig for a given module type, + function, or method name + ``module_name`` : sets the qconfig for modules matching the + given module name + tracer (Dict): It can be used to trace the float model to generate the + corresponding graph, which contributes to prepare for quantizing + the float model with code-free. Default to + `dict(type='mmrazor.CustomTracer')`. + prepare_custom_config (Optional[Dict]): Custom configuration for + :func:`~torch.ao.quantization.fx.prepare`. You can specify the + follow: + ``float_to_observed_custom_module_class`` : a list of dict that + mapping from float module classes to observed module + classes, e.g. + `[('FloatCustomModule', 'ObservedCustomModule')]` + ``preserved_attributes``: a list of attributes that persist + even if they are not used in ``forward``, e.g. + `['attr1', 'attr2']` + """ def __init__(self, - qconfig_mapping, - tracer=dict(type='mmrazor.CustomTracer'), - prepare_custom_config=None, - backend_config=BackendConfigs['academic']): + qconfig_mapping: Dict, + tracer: Dict = dict(type='mmrazor.CustomTracer'), + prepare_custom_config: Optional[Dict] = None): super().__init__(tracer) self.qconfig_mapping = self.gen_qconfig_mapping(qconfig_mapping) self.prepare_custom_config = self.gen_prepare_custom_config( prepare_custom_config) - self.backend_config = backend_config + self.backend_config = BackendConfigs[self.backend] self.example_inputs = (torch.randn(1, 3, 224, 224), ) - def prepare(self, model, graph_module): - """tmp.""" + @property + def backend(self): + """The key of the corresponding backend config.""" + return 'academic' + + def prepare(self, model, concrete_args=None): + """Prepare for quantizing model, which includes as follows: + + 1. Swap floatfunctional with FXFloatFunctional; + 2. Trace model to generate `GraphModule`; + 2. Fuse some OPs combination, such as conv + bn, conv + relu and so on; + 3. Swap some conv or linear module with QAT Modules which contain + weight fakequant nodes; + 4. Insert required fakequant nodes for activation. + step 3 and step 4 are implemented in + :func:`~torch.ao.quantization.fx.prepare` + """ + self.swap_ff_with_fxff(model) + traced_graph = self.tracer.trace(model, concrete_args=concrete_args) + graph_module = build_graphmodule(model, traced_graph) preserved_attributes = self.prepare_custom_config.preserved_attributes for attr_name in preserved_attributes: setattr(graph_module, attr_name, getattr(model, attr_name)) @@ -71,51 +120,46 @@ def prepare(self, model, graph_module): return prepared - def gen_qconfig_mapping(self, qconfig_mapping): - """tmp.""" + def gen_qconfig_mapping(self, qconfig_mapping: Dict): + """Convert qconfig_mapping in config file to `QConfigMapping`. + + `QConfigMapping` is a custom class for mapping from model ops to + :class:`torch.ao.quantization.QConfig` s. + """ conf = QConfigMapping() if GLOBAL_DICT_KEY in qconfig_mapping: qconfig = QConfigHandler( qconfig_mapping[GLOBAL_DICT_KEY]).convert() conf.set_global(qconfig) + for object_type, qconfig in qconfig_mapping.get( OBJECT_TYPE_DICT_KEY, []): qconfig = QConfigHandler(qconfig).convert() - conf.set_object_type(object_type, qconfig) + conf.set_object_type(str2class(object_type), qconfig) - for module_name_regex, qconfig in qconfig_mapping.get( - MODULE_NAME_REGEX_DICT_KEY, []): - qconfig = QConfigHandler(qconfig).convert() - conf.set_module_name_regex(module_name_regex, qconfig) for module_name, qconfig in qconfig_mapping.get( MODULE_NAME_DICT_KEY, []): qconfig = QConfigHandler(qconfig).convert() conf.set_module_name(module_name, qconfig) - for module_name, object_type, index, qconfig in qconfig_mapping.get( - MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY, []): - qconfig = QConfigHandler(qconfig).convert() - conf.set_module_name_object_type_order(module_name, object_type, - index, qconfig) return conf - def gen_prepare_custom_config(self, prepare_custom_config): - """tmp.""" + def gen_prepare_custom_config(self, prepare_custom_config: Optional[Dict]): + """Convert prepare_custom_config in config file to + `PrepareCustomConfig`. + + `PrepareCustomConfig` is a custom class for custom configurating + :func:`~torch.ao.quantization.fx.prepare`. + """ conf = PrepareCustomConfig() if prepare_custom_config is None: return conf else: - for quant_type_name, custom_module_mapping in \ - prepare_custom_config.get( - FLOAT_TO_OBSERVED_DICT_KEY, {}).items(): - quant_type = _quant_type_from_str(quant_type_name) - mapping_items = custom_module_mapping.items() - for float_class_str, observed_class_str in mapping_items: - float_class = MODELS.get(float_class_str) - observed_class = MODELS.get(observed_class_str) - conf.set_float_to_observed_mapping(float_class, - observed_class, - quant_type) + for float_class_str, observed_class_str in prepare_custom_config.get( # noqa: E501 + FLOAT_TO_OBSERVED_DICT_KEY, []): + float_class = MODELS.get(float_class_str) + observed_class = MODELS.get(observed_class_str) + conf.set_float_to_observed_mapping(float_class, observed_class) conf.set_preserved_attributes( prepare_custom_config.get(PRESERVED_ATTRIBUTES_DICT_KEY, [])) return conf diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py index 0f14917ac..866199735 100644 --- a/mmrazor/models/quantizers/base.py +++ b/mmrazor/models/quantizers/base.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import abstractmethod +from typing import Dict import torch from mmengine.model import BaseModule @@ -8,18 +9,37 @@ class BaseQuantizer(BaseModule): - """tmp.""" - - def __init__(self, tracer): + """Base class for quantizers. Its role for several subclass is as follows: + 1. Provide tracer for tracing model for all subclass. + 2. Define some common abstract methods, such as `prepare`. + 3. Provide some common functional interfaces, such as `swap_ff_with_fxff`. + + Args: + tracer (Dict): It can be used to trace the float model to generate the + corresponding graph, which contributes to prepare for quantizing + the float model with code-free. + """ + + def __init__(self, tracer: Dict): super().__init__() self.tracer = TASK_UTILS.build(tracer) @abstractmethod - def prepare(self, model, graph_module): - """tmp.""" + def prepare(self, model): + """Prepare for quantizing model, which usually includes as follows: + + 1. Swap floatfunctional with FXFloatFunctional; + 2. Trace model to generate `GraphModule`; + 2. Fuse some OPs combination, such as conv + bn, conv + relu and so on; + 3. Swap some conv or linear module with QAT Modules which contain + weight fakequant nodes; + 4. Insert required fakequant nodes for activation. + 5. (Optional) Delete some redundant fakequant nodes according to the + special requirement of the backend for deployment. + """ pass - def swap_ff_with_fxff(self, model): + def swap_ff_with_fxff(self, model: torch.nn.Module): """Swap FloatFunctional with FXFloatFunctional.""" modules_to_swap = [] for name, module in model.named_children(): diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index 2b75cf29c..b5de1c028 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -26,6 +26,7 @@ qat_modules = get_package_placeholder('torch>=1.13') from mmrazor import digit_version +from mmrazor.models.task_modules.tracer import build_graphmodule from mmrazor.models.task_modules.tracer.fx import ( del_fakequant_after_function, del_fakequant_after_method, del_fakequant_after_module, del_fakequant_after_op, @@ -90,10 +91,6 @@ class NativeQuantizer(BaseQuantizer): ) """ - # backend: 'native' - # support_w_modes = ['per_tensor', 'per_channel'] - # support_a_modes = ['per_tensor'] - def __init__(self, global_qconfig: Union[Dict, Config], no_observer_modules: Optional[List] = None, @@ -135,25 +132,24 @@ def __init__(self, @property def backend(self): - """tmp.""" + """The key of the corresponding backend config.""" return 'native' @property def support_w_modes(self): - """tmp.""" - return ['per_tensor', 'per_channel'] + """Supported quantization modes for weight about per_tensor or + per_channel.""" + return ('per_tensor', 'per_channel') @property def support_a_modes(self): - """tmp.""" - return ['per_tensor'] + """Supported quantization modes for activation about per_tensor or + per_channel.""" + return ('per_tensor') - def prepare(self, model, graph_module): + def prepare(self, model, concrete_args=None): """prepare graph to ObservedGraphModule. - Args: - graph_module (_type_): GraphModules before fuse. - Returns: ObservedGraphModule: GraphModules after fuse and observer. @@ -170,7 +166,9 @@ def prepare(self, model, graph_module): fake_quant operations that we need it to be fused into our `SUPPORT_QAT_MODULES` type, which is a tricky way to deal with it. """ - + self.swap_ff_with_fxff(model) + traced_graph = self.tracer.trace(model, concrete_args=concrete_args) + graph_module = build_graphmodule(model, traced_graph) graph_module = _fuse_fx( graph_module=graph_module, is_qat=True, @@ -319,40 +317,48 @@ def module_prev_wo_fakequant(self): @property def module_prev_wo_fakequant(self): - """tmp.""" + """Configurate the modules that their previous nodes are redundant + fakequants.""" return tuple() @property def module_next_wo_fakequant(self): - """tmp.""" + """Configurate the modules that their next nodes are redundant + fakequants.""" return tuple() @property def function_prev_wo_fakequant(self): - """tmp.""" + """Configurate the functions that their previous nodes are redundant + fakequants.""" return tuple() @property def function_next_wo_fakequant(self): - """tmp.""" + """Configurate the functions that their next nodes are redundant + fakequants.""" return tuple() @property def method_prev_wo_fakequant(self): - """tmp.""" + """Configurate the methods that their previous nodes are redundant + fakequants.""" return tuple() @property def method_next_wo_fakequant(self): - """tmp.""" + """Configurate the methods that their next nodes are redundant + fakequants.""" return tuple() @property def op_prev_wo_fakequant(self): - """tmp.""" + """Configurate the OPs that their previous nodes are redundant + fakequants.""" return tuple() @property def op_next_wo_fakequant(self): - """tmp.""" + """Configurate the OPs that their next nodes are redundant + fakequants.""" return tuple() diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index 23abf40da..cb7d3084b 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple import torch @@ -8,43 +9,57 @@ from mmrazor.utils import get_placeholder disable_observer = get_placeholder('torch>=1.13') -from mmrazor.models.task_modules.tracer.fx import build_graphmodule from mmrazor.registry import MODELS from .native_quantizer import NativeQuantizer @MODELS.register_module() class OpenVINOQuantizer(NativeQuantizer): - """Quantizer for Openvino backend.""" + """Quantizer for quantizing and deploying to Openvino backend. - # backend: 'openvino' - # support_w_mode = ['per_tensor', 'per_channel'] - # support_a_mode = ['per_tensor'] + Each backend has its own features, for reducing the gap of quantized + performance between before and after deployment as possible, we should + match the backend's features in quantization. + + Openvino's some important features about quantization is as follows: + * support_w_mode = ('per_tensor', 'per_channel') + * support_a_mode = ('per_tensor') + * weight range should be symmetric, such as int 8 is [-127, 127] rather + than [-128, 127] + """ @property def backend(self): - """tmp.""" + """The backend to deploy, also the key of the corresponding backend + config.""" return 'openvino' @property def support_w_modes(self): - """tmp.""" - return ['per_tensor', 'per_channel'] + """Supported quantization modes for weight about per_tensor or + per_channel.""" + return ('per_tensor', 'per_channel') @property def support_a_modes(self): - """tmp.""" - return ['per_tensor'] + """Supported quantization modes for activation about per_tensor or + per_channel.""" + return ('per_tensor') def prepare_for_mmdeploy(self, - model, - dummy_input=(1, 3, 224, 224), - checkpoint=None): - """tmp.""" - self.swap_ff_with_fxff(model) - graph = self.tracer.trace(model) - graph_module = build_graphmodule(model, graph) - observed_model = self.prepare(model, graph_module) + model: torch.nn.Module, + dummy_input: Tuple = (1, 3, 224, 224), + checkpoint: Optional[str] = None): + """Prepare for deploy to the backend with mmdeploy, which will be used + in mmdeploy, and usually includes as follows: + + 1. prepare for the float model rewritten by mmdeploy. + 2. load checkpoint consists of float weight and quantized params in + mmrazor. + 3. post process weight fakequant for exporting .onnx that meet + the backend's requirement. + """ + observed_model = self.prepare(model) if dummy_input is not None: observed_model(torch.randn(dummy_input)) if checkpoint is not None: @@ -59,20 +74,24 @@ def prepare_for_mmdeploy(self, @property def module_prev_wo_fakequant(self): - """tmp.""" + """Configurate the modules that their previous nodes are redundant + fakequants.""" return (torch.nn.ReLU6, torch.nn.Identity) @property def module_next_wo_fakequant(self): - """tmp.""" + """Configurate the modules that their next nodes are redundant + fakequants.""" return (torch.nn.MaxPool2d, ) @property def method_next_wo_fakequant(self): - """tmp.""" + """Configurate the methods that their next nodes are redundant + fakequants.""" return ('flatten', ) @property def op_prev_wo_fakequant(self): - """tmp.""" + """Configurate the OPs that their previous nodes are redundant + fakequants.""" return ('output', ) diff --git a/mmrazor/models/quantizers/tensorrt_quantizer.py b/mmrazor/models/quantizers/tensorrt_quantizer.py index 36e3f2be7..028c96a8c 100644 --- a/mmrazor/models/quantizers/tensorrt_quantizer.py +++ b/mmrazor/models/quantizers/tensorrt_quantizer.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple + import torch try: @@ -7,50 +9,55 @@ from mmrazor.utils import get_placeholder disable_observer = get_placeholder('torch>=1.13') -from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ - build_graphmodule from mmrazor.registry import MODELS from .native_quantizer import NativeQuantizer @MODELS.register_module() class TensorRTQuantizer(NativeQuantizer): - """Quantizer for TensorRT backend.""" + """Quantizer for quantizing and deploying to TensorRT backend. - # backend: 'tensorrt' - # support_w_mode = ['per_tensor', 'per_channel'] - # support_a_mode = ['per_tensor'] + Each backend has its own features, for reducing the gap of quantized + performance between before and after deployment as possible, we should + match the backend's features in quantization. - def __init__(self, - global_qconfig, - no_observer_modules=None, - tracer=dict(type='CustomTracer')): - super().__init__(global_qconfig, no_observer_modules, tracer) + TensorRT's some important features about quantization is as follows: + * support_w_mode = ('per_tensor', 'per_channel') + * support_a_mode = ('per_tensor') + """ @property def backend(self): - """tmp.""" + """The backend to deploy, also the key of the corresponding backend + config.""" return 'tensorrt' @property def support_w_modes(self): - """tmp.""" - return ['per_tensor', 'per_channel'] + """Supported quantization modes for weight about per_tensor or + per_channel.""" + return ('per_tensor', 'per_channel') @property def support_a_modes(self): - """tmp.""" - return ['per_tensor'] + """Supported quantization modes for activation about per_tensor or + per_channel.""" + return ('per_tensor') def prepare_for_mmdeploy(self, - model, - dummy_input=(1, 3, 224, 224), - checkpoint=None): - """tmp.""" - self.swap_ff_with_fxff(model) - graph = self.tracer.trace(model) - graph_module = build_graphmodule(model, graph) - observed_model = self.prepare(model, graph_module) + model: torch.nn.Module, + dummy_input: Tuple = (1, 3, 224, 224), + checkpoint: Optional[str] = None): + """Prepare for deploy to the backend with mmdeploy, which will be used + in mmdeploy, and usually includes as follows: + + 1. prepare for the float model rewritten by mmdeploy. + 2. load checkpoint consists of float weight and quantized params in + mmrazor. + 3. post process weight fakequant for exporting .onnx that meet + the backend's requirement. + """ + observed_model = self.prepare(model) if dummy_input is not None: observed_model(torch.randn(dummy_input)) if checkpoint is not None: diff --git a/mmrazor/testing/_fx_models.py b/mmrazor/testing/_fx_models.py index 969c4792d..6bf42e16a 100644 --- a/mmrazor/testing/_fx_models.py +++ b/mmrazor/testing/_fx_models.py @@ -34,6 +34,8 @@ def __init__( stride, padding, dilation, groups, bias, conv_cfg, norm_cfg, act_cfg, inplace, with_spectral_norm, padding_mode, order) + self.toy_attr1 = 1 + self.toy_attr2 = 2 def forward(self, x): x = self.conv_module.conv(x) diff --git a/tests/test_models/test_quantizers/test_academic_quantizer.py b/tests/test_models/test_quantizers/test_academic_quantizer.py new file mode 100644 index 000000000..c95060a00 --- /dev/null +++ b/tests/test_models/test_quantizers/test_academic_quantizer.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import copy +from unittest import TestCase + +import torch +from mmengine.model import BaseModule + +try: + from torch.ao.nn.intrinsic import ConvBnReLU2d + from torch.ao.quantization.backend_config import BackendConfig + from torch.ao.quantization.fx.custom_config import PrepareCustomConfig + from torch.ao.quantization.fx.graph_module import ObservedGraphModule + from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.quant_type import QuantType +except ImportError: + from mmrazor.utils import get_placeholder + ConvBnReLU2d = get_placeholder('torch>=1.13') + BackendConfig = get_placeholder('torch>=1.13') + PrepareCustomConfig = get_placeholder('torch>=1.13') + ConObservedGraphModuleBnReLU2d = get_placeholder('torch>=1.13') + QConfigMapping = get_placeholder('torch>=1.13') + QuantType = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.quantizers import AcademicQuantizer +from mmrazor.models.quantizers.academic_quantizer import ( + FLOAT_TO_OBSERVED_DICT_KEY, GLOBAL_DICT_KEY, MODULE_NAME_DICT_KEY, + OBJECT_TYPE_DICT_KEY, PRESERVED_ATTRIBUTES_DICT_KEY) +from mmrazor.registry import MODELS +from mmrazor.testing import ConvBNReLU + + +@MODELS.register_module() +class ToyFloatModel(BaseModule): + + def __init__(self) -> None: + super().__init__() + + +@MODELS.register_module() +class ToyObservedModel(BaseModule): + + def __init__(self) -> None: + super().__init__() + + +class TestAcademicQuantizer(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + self.global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), + ) + self.qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict(qdtype='qint8', bit=4, is_symmetry=True), + a_qscheme=dict(qdtype='quint8', bit=4, is_symmetry=True), + ) + self.model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN')) + + def test_gen_qconfig_mapping(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + # test set GLOBAL_DICT_KEY by QConfigMapping + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'qconfig_mapping') + assert isinstance(quantizer.qconfig_mapping, QConfigMapping) + assert quantizer.qconfig_mapping.global_qconfig + + # test set OBJECT_TYPE_DICT_KEY by QConfigMapping + qconfig = copy(self.qconfig) + qconfig_mapping = { + OBJECT_TYPE_DICT_KEY: + [('torch.ao.nn.intrinsic.ConvBnReLU2d', qconfig)] + } + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'qconfig_mapping') + assert isinstance(quantizer.qconfig_mapping, QConfigMapping) + assert quantizer.qconfig_mapping.object_type_qconfigs.get(ConvBnReLU2d) + + # test set MODULE_NAME_DICT_KEY by QConfigMapping + qconfig = copy(self.qconfig) + qconfig_mapping = { + MODULE_NAME_DICT_KEY: [('conv_module.conv', qconfig)] + } + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'qconfig_mapping') + assert isinstance(quantizer.qconfig_mapping, QConfigMapping) + assert quantizer.qconfig_mapping.module_name_qconfigs.get( + 'conv_module.conv') + + def test_gen_prepare_custom_config(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + # test prepare_custom_config is None + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'prepare_custom_config') + assert isinstance(quantizer.prepare_custom_config, PrepareCustomConfig) + + # test set FLOAT_TO_OBSERVED_DICT_KEY and PRESERVED_ATTRIBUTES_DICT_KEY + # by PrepareCustomConfig + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + flop_to_observed_list = [('ToyFloatModel', 'ToyObservedModel')] + preserved_attributes_list = ['toy_attr1', 'toy_attr2'] + prepare_custom_config = { + FLOAT_TO_OBSERVED_DICT_KEY: flop_to_observed_list, + PRESERVED_ATTRIBUTES_DICT_KEY: preserved_attributes_list + } + quantizer = AcademicQuantizer( + qconfig_mapping=qconfig_mapping, + prepare_custom_config=prepare_custom_config) + + assert hasattr(quantizer, 'prepare_custom_config') + assert isinstance(quantizer.prepare_custom_config, PrepareCustomConfig) + mapping = quantizer.prepare_custom_config.float_to_observed_mapping[ + QuantType.STATIC] + assert mapping.get(ToyFloatModel) + assert mapping[ToyFloatModel] == ToyObservedModel + + attributes = quantizer.prepare_custom_config.preserved_attributes + assert attributes == preserved_attributes_list + + def test_init(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + quantizer = AcademicQuantizer(qconfig_mapping=qconfig_mapping) + assert hasattr(quantizer, 'backend_config') + assert isinstance(quantizer.backend_config, BackendConfig) + + def test_prepare(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + qconfig_mapping = {GLOBAL_DICT_KEY: global_qconfig} + preserved_attributes_list = ['toy_attr1', 'toy_attr2'] + prepare_custom_config = { + PRESERVED_ATTRIBUTES_DICT_KEY: preserved_attributes_list + } + quantizer = AcademicQuantizer( + qconfig_mapping=qconfig_mapping, + prepare_custom_config=prepare_custom_config) + model = copy(self.model) + prepared = quantizer.prepare(model) + assert isinstance(prepared, ObservedGraphModule) + assert hasattr(prepared, 'toy_attr1') + assert hasattr(prepared, 'toy_attr2') diff --git a/tests/test_models/test_quantizers/test_native_quantizer.py b/tests/test_models/test_quantizers/test_native_quantizer.py index afd6011ed..62052f66f 100644 --- a/tests/test_models/test_quantizers/test_native_quantizer.py +++ b/tests/test_models/test_quantizers/test_native_quantizer.py @@ -11,7 +11,7 @@ from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ build_graphmodule from mmrazor.registry import MODELS -from mmrazor.structures.quantization import BackendConfigs, QConfigHander +from mmrazor.structures.quantization import BackendConfigs, QConfigHandler try: from torch.ao.quantization.fx import prepare @@ -127,7 +127,7 @@ def setUp(self): self.q_kwargs = q_kwargs self.tracer = CustomTracer() self.backend_config = BackendConfigs['native'] - self.qconfig = QConfigHander(global_qconfig) + self.qconfig = QConfigHandler(global_qconfig) self.qconfig_mapping = QConfigMapping().set_global( self.qconfig.convert()) self.example_inputs = (torch.randn(1, 3, 224, 224), ) diff --git a/tests/test_models/test_quantizers/test_openvino_quantizer.py b/tests/test_models/test_quantizers/test_openvino_quantizer.py new file mode 100644 index 000000000..24fc81ca4 --- /dev/null +++ b/tests/test_models/test_quantizers/test_openvino_quantizer.py @@ -0,0 +1,78 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import shutil +import tempfile +from copy import copy +from unittest import TestCase + +import torch + +try: + from torch.ao.quantization.fx.graph_module import ObservedGraphModule +except ImportError: + from mmrazor.utils import get_placeholder + ObservedGraphModule = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.quantizers import OpenVINOQuantizer +from mmrazor.testing import ConvBNReLU + + +class TestOpenVINOQuantizer(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + self.global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), + ) + self.temp_dir = tempfile.mkdtemp() + self.model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN')) + + def tearDown(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + shutil.rmtree(self.temp_dir) + + def test_property(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + quantizer = OpenVINOQuantizer(global_qconfig=global_qconfig) + assert quantizer.backend == 'openvino' + assert quantizer.support_w_modes == ('per_tensor', 'per_channel') + assert quantizer.support_a_modes == ('per_tensor') + assert quantizer.module_prev_wo_fakequant + assert quantizer.module_next_wo_fakequant + assert quantizer.method_next_wo_fakequant + assert quantizer.op_prev_wo_fakequant + + def test_prepare_for_mmdeploy(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + quantizer = OpenVINOQuantizer(global_qconfig=global_qconfig) + model = copy(self.model) + + # test checkpoint is None + prepared_deploy = quantizer.prepare_for_mmdeploy(model=model) + assert isinstance(prepared_deploy, ObservedGraphModule) + + # test checkpoint is not None + ckpt_path = os.path.join(self.temp_dir, + 'test_prepare_for_mmdeploy.pth') + model = copy(self.model) + prepared = quantizer.prepare(model) + torch.save({'state_dict': prepared.state_dict()}, ckpt_path) + prepared_deploy = quantizer.prepare_for_mmdeploy( + model=model, checkpoint=ckpt_path) + assert isinstance(prepared_deploy, ObservedGraphModule) diff --git a/tests/test_models/test_quantizers/test_tensorrt_quantizer.py b/tests/test_models/test_quantizers/test_tensorrt_quantizer.py new file mode 100644 index 000000000..aeae311f3 --- /dev/null +++ b/tests/test_models/test_quantizers/test_tensorrt_quantizer.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import shutil +import tempfile +from copy import copy +from unittest import TestCase + +import torch + +try: + from torch.ao.quantization.fx.graph_module import ObservedGraphModule +except ImportError: + from mmrazor.utils import get_placeholder + ObservedGraphModule = get_placeholder('torch>=1.13') + +from mmrazor import digit_version +from mmrazor.models.quantizers import TensorRTQuantizer +from mmrazor.testing import ConvBNReLU + + +class TestTensorRTQuantizer(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + self.global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), + ) + self.temp_dir = tempfile.mkdtemp() + self.model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN')) + + def tearDown(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + shutil.rmtree(self.temp_dir) + + def test_property(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + quantizer = TensorRTQuantizer(global_qconfig=global_qconfig) + assert quantizer.backend == 'tensorrt' + assert quantizer.support_w_modes == ('per_tensor', 'per_channel') + assert quantizer.support_a_modes == ('per_tensor') + + def test_prepare_for_mmdeploy(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + global_qconfig = copy(self.global_qconfig) + quantizer = TensorRTQuantizer(global_qconfig=global_qconfig) + model = copy(self.model) + + # test checkpoint is None + prepared_deploy = quantizer.prepare_for_mmdeploy(model=model) + assert isinstance(prepared_deploy, ObservedGraphModule) + + # test checkpoint is not None + ckpt_path = os.path.join(self.temp_dir, + 'test_prepare_for_mmdeploy.pth') + model = copy(self.model) + prepared = quantizer.prepare(model) + torch.save({'state_dict': prepared.state_dict()}, ckpt_path) + prepared_deploy = quantizer.prepare_for_mmdeploy( + model=model, checkpoint=ckpt_path) + assert isinstance(prepared_deploy, ObservedGraphModule) diff --git a/tests/test_models/test_quantizers/test_trt_quantizer.py b/tests/test_models/test_quantizers/test_trt_quantizer.py deleted file mode 100644 index 9f85d1ecd..000000000 --- a/tests/test_models/test_quantizers/test_trt_quantizer.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest import TestCase - -import torch.nn as nn - - -class ToyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - # TODO - - -class TestTRTQuantizer(TestCase): - """TODO. - - Args: - TestCase (_type_): _description_ - """ - - def test_init(self): - pass - - def test_prepare(self): - pass - - def test_convert(self): - pass - - def test_states(self): - pass - - def test_forward(self): - pass diff --git a/tests/test_models/test_task_modules/test_custom_tracer.py b/tests/test_models/test_task_modules/test_custom_tracer.py index 207e9ccad..fcb02f381 100644 --- a/tests/test_models/test_task_modules/test_custom_tracer.py +++ b/tests/test_models/test_task_modules/test_custom_tracer.py @@ -51,7 +51,6 @@ def test_init(self): method_registry = UntracedMethodRegistry(method) assert hasattr(method_registry, 'method') assert hasattr(method_registry, 'method_dict') - assert len(method_registry.method_dict) == 0 def test_registry_method(self): if digit_version(torch.__version__) < digit_version('1.13.0'): From e78ac797d8e91af49a13c16e31786be00601056b Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Wed, 18 Jan 2023 14:42:27 +0800 Subject: [PATCH 12/44] [Feature&Doc]Modify ptq pipeline and support lsq (#435) * modify ptq pipeline and support lsq * use placeholder * fix lsq && quantloop * add lsq pytest * add quant loop pytest * test lsq observer * fix bug under pt13 * fix reset_min_max_vals * fix bugs under pt13 * fix configs * add get_qconfig_mapping * delete is_qat, add doc and fix pytest * delete useless codes in custom_tracer * skip pytest under pt13 * add todo: check freezebn * fix pytest bugs * fix pytest * fix pytest * fix pytest --- ...tq_openvino_mbv2_8xb32_in1k_calib32xb32.py | 8 + ...penvino_resnet18_8xb32_in1k_calib32xb32.py | 8 + ...penvino_resnet50_8xb32_in1k_calib32xb32.py | 8 + ...openvino_retina_r50_1x_coco_calib32xb32.py | 49 +++ .../lsq_openvino_resnet18_8xb16_cifar10.py | 61 --- .../qat/lsq_openvino_resnet18_8xb32_in1k.py | 64 +++ mmrazor/engine/__init__.py | 10 +- mmrazor/engine/runner/__init__.py | 5 +- mmrazor/engine/runner/quantization_loops.py | 116 ++++- .../quantization/mm_architecture.py | 65 ++- mmrazor/models/fake_quants/__init__.py | 7 +- mmrazor/models/fake_quants/lsq.py | 273 ++++++++++++ mmrazor/models/observers/__init__.py | 6 +- mmrazor/models/observers/lsq.py | 129 ++++++ mmrazor/models/observers/torch_observers.py | 20 + .../models/quantizers/academic_quantizer.py | 5 + mmrazor/models/quantizers/base.py | 34 ++ mmrazor/models/quantizers/native_quantizer.py | 57 ++- .../models/quantizers/openvino_quantizer.py | 1 + .../task_modules/tracer/fx/custom_tracer.py | 31 ++ .../task_modules/tracer/fx/graph_utils.py | 12 +- .../test_algorithms/test_mm_architecture.py | 46 +- .../test_fake_quants/test_lsq_fake_quants.py | 181 +++++++- .../test_observers/test_lsq_observer.py | 77 ++++ .../test_quantizers/test_native_quantizer.py | 8 +- tests/test_runners/test_quantization_loop.py | 413 ++++++++++++++++++ 26 files changed, 1550 insertions(+), 144 deletions(-) create mode 100644 configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py delete mode 100644 configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py create mode 100644 configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py create mode 100644 mmrazor/models/fake_quants/lsq.py create mode 100644 mmrazor/models/observers/lsq.py create mode 100644 tests/test_models/test_observers/test_lsq_observer.py create mode 100644 tests/test_runners/test_quantization_loop.py diff --git a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py index d7c9cdf47..7c919c0fd 100644 --- a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py @@ -22,6 +22,14 @@ model = dict( _delete_=True, type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), architecture=_base_.model, float_checkpoint=float_checkpoint, quantizer=dict( diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py index 5ba1eec85..125f46367 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -24,6 +24,14 @@ model = dict( _delete_=True, type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), architecture=_base_.model, float_checkpoint=float_checkpoint, quantizer=dict( diff --git a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py index bd734ee40..f629337ed 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py @@ -24,6 +24,14 @@ model = dict( _delete_=True, type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), architecture=_base_.model, float_checkpoint=float_checkpoint, quantizer=dict( diff --git a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py new file mode 100644 index 000000000..578f5fe84 --- /dev/null +++ b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py @@ -0,0 +1,49 @@ +_base_ = ['mmdet::retinanet/retinanet_r50_fpn_1x_coco.py'] + +train_dataloader = dict(batch_size=32) + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=train_dataloader, + calibrate_steps=32, +) + +retina = _base_.model +float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + architecture=retina, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.predict_by_feat', # noqa: E501 + 'mmdet.models.dense_heads.anchor_head.AnchorHead.loss_by_feat', + ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py deleted file mode 100644 index 8076769a9..000000000 --- a/configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py +++ /dev/null @@ -1,61 +0,0 @@ -_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] - -resnet = _base_.model -float_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth' # noqa: E501 - -model = dict( - _delete_=True, - _scope_='mmrazor', - type='MMArchitectureQuant', - architecture=resnet, - float_checkpoint=float_ckpt, - quantizer=dict( - type='OpenvinoQuantizer', - skipped_methods=[ - 'mmcls.models.heads.ClsHead._get_loss', - 'mmcls.models.heads.ClsHead._get_predictions' - ], - qconfig=dict( - qtype='affine', - w_observer=dict(type='mmrazor.LSQObserver'), - a_observer=dict(type='mmrazor.LSQObserver'), - w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - w_qscheme=dict( - bit=8, - is_symmetry=True, - is_per_channel=True, - is_pot_scale=False, - ), - a_qscheme=dict( - bit=8, - is_symmetry=False, - is_per_channel=False, - is_pot_scale=False), - ))) - -optim_wrapper = dict( - optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001)) - -# learning policy -param_scheduler = dict( - _delete_=True, - type='CosineAnnealingLR', - T_max=100, - by_epoch=True, - begin=0, - end=100) - -model_wrapper_cfg = dict( - type='mmrazor.MMArchitectureQuantDDP', - broadcast_buffers=False, - find_unused_parameters=True) - -# train, val, test setting -train_cfg = dict( - _delete_=True, - type='mmrazor.QATEpochBasedLoop', - max_epochs=100, - val_interval=1) -val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') -# test_cfg = val_cfg diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py new file mode 100644 index 000000000..0b79232f8 --- /dev/null +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py @@ -0,0 +1,64 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +resnet = _base_.model +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.LSQPerChannelObserver'), + a_observer=dict(type='mmrazor.LSQObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=resnet, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, + type='CosineAnnealingLR', + T_max=100, + by_epoch=True, + begin=0, + end=100) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.LSQEpochBasedLoop', + max_epochs=100, + val_interval=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') +test_cfg = val_cfg diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py index 603aa3d77..d3d8e6981 100644 --- a/mmrazor/engine/__init__.py +++ b/mmrazor/engine/__init__.py @@ -4,14 +4,16 @@ from .optimizers import SeparateOptimWrapperConstructor from .runner import (AutoSlimGreedySearchLoop, DartsEpochBasedTrainLoop, DartsIterBasedTrainLoop, EvolutionSearchLoop, - GreedySamplerTrainLoop, PTQLoop, QATEpochBasedLoop, - SelfDistillValLoop, SingleTeacherDistillValLoop, - SlimmableValLoop, SubnetValLoop) + GreedySamplerTrainLoop, LSQEpochBasedLoop, PTQLoop, + QATEpochBasedLoop, QATValLoop, SelfDistillValLoop, + SingleTeacherDistillValLoop, SlimmableValLoop, + SubnetValLoop) __all__ = [ 'SeparateOptimWrapperConstructor', 'DumpSubnetHook', 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop', - 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'PTQLoop', 'QATEpochBasedLoop' + 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'PTQLoop', + 'QATEpochBasedLoop', 'LSQEpochBasedLoop', 'QATValLoop' ] diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py index 2ca6c0dbb..5fe2fd524 100644 --- a/mmrazor/engine/runner/__init__.py +++ b/mmrazor/engine/runner/__init__.py @@ -4,7 +4,8 @@ from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop from .evolution_search_loop import EvolutionSearchLoop from .iteprune_val_loop import ItePruneValLoop -from .quantization_loops import PTQLoop, QATEpochBasedLoop +from .quantization_loops import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop, + QATValLoop) from .slimmable_val_loop import SlimmableValLoop from .subnet_sampler_loop import GreedySamplerTrainLoop from .subnet_val_loop import SubnetValLoop @@ -14,5 +15,5 @@ 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop', 'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'QATEpochBasedLoop', - 'PTQLoop' + 'PTQLoop', 'LSQEpochBasedLoop', 'QATValLoop' ] diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index e90715910..df0f4f76d 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -18,6 +18,8 @@ from torch.utils.data import DataLoader +from mmrazor.models.fake_quants import (enable_param_learning, + enable_static_estimate, enable_val) from mmrazor.registry import LOOPS @@ -30,13 +32,13 @@ class QATEpochBasedLoop(EpochBasedTrainLoop): dataloader (Dataloader or dict): An iterator to generate one batch of dataset each iteration. max_epochs (int): Total training epochs. - val_begin (int): The epoch that begins validating. - Defaults to 1. + val_begin (int): The epoch that begins validating. Defaults to 1. val_interval (int): Validation interval. Defaults to 1. disable_observer_begin (int): The number of total epochs to update - observers. + observers. Defaults to -1, which means observers are enabled + all the time. freeze_bn_begin (int): The number of total epochs to update batch norm - stats. + stats. Defaults to -1, which means no need to freeze bn. dynamic_intervals (List[Tuple[int, int]], optional): The first element in the tuple is a milestone and the second element is a interval. The interval is used after the @@ -50,8 +52,8 @@ def __init__( max_epochs: int, val_begin: int = 1, val_interval: int = 1, - disable_observer_begin: int = 3, - freeze_bn_begin: int = 3, + disable_observer_begin: int = -1, + freeze_bn_begin: int = -1, dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: super().__init__(runner, dataloader, max_epochs, val_begin, val_interval, dynamic_intervals) @@ -59,14 +61,24 @@ def __init__( self.disable_observer_begin = disable_observer_begin self.freeze_bn_begin = freeze_bn_begin - def run(self) -> torch.nn.Module: + def prepare_for_run_epoch(self): + """Toggle the state of the observers and fake quantizers before qat + training.""" + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(enable_observer) + + def prepare_for_val(self): + """Toggle the state of the observers and fake quantizers before + validation.""" + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(disable_observer) + + def run(self): """Launch training.""" self.runner.call_hook('before_train') while self._epoch < self._max_epochs: - # state: observer_enabled, fakequant_enabled - self.runner.model.apply(enable_fake_quant) - self.runner.model.apply(enable_observer) + self.prepare_for_run_epoch() self.run_epoch() self._decide_current_val_interval() @@ -74,8 +86,8 @@ def run(self) -> torch.nn.Module: and self._epoch >= self.val_begin and self._epoch % self.val_interval == 0): # observer disabled during evaluation - self.runner.model.apply(enable_fake_quant) - self.runner.model.apply(disable_observer) + self.prepare_for_val() + self.runner.model.sync_qparams(src_mode='loss') self.runner.val_loop.run() self.runner.call_hook('after_train') @@ -85,14 +97,88 @@ def run_epoch(self) -> None: self.runner.call_hook('before_train_epoch') self.runner.model.train() - # TODO freeze bn - if self._epoch >= self.disable_observer_begin: + # The initialized _epoch equals to 0 so _epoch + 1 + # equal to the current epoch + if self._epoch + 1 >= self.disable_observer_begin: self.runner.model.apply(disable_observer) - if self._epoch >= self.freeze_bn_begin: + if self._epoch + 1 >= self.freeze_bn_begin: + self.runner.model.apply(freeze_bn_stats) + + for idx, data_batch in enumerate(self.dataloader): + self.run_iter(idx, data_batch) + + self.runner.call_hook('after_train_epoch') + self._epoch += 1 + + +@LOOPS.register_module() +class LSQEpochBasedLoop(QATEpochBasedLoop): + """`EpochBasedLoop` for `LEARNED STEP SIZE QUANTIZATION` + + Paper: Learned Step Size Quantization. + + Args: + runner (Runner): A reference of runner + dataloader (Dataloader or dict): An iterator to generate one batch of + dataset each iteration. + max_epochs (int): Total training epochs. + val_begin (int): The epoch that begins validating. Defaults to 1. + val_interval (int): Validation interval. Defaults to 1. + freeze_bn_begin (int): The number of total epochs to update batch norm + stats. Defaults to -1, which means no need to freeze bn. + dynamic_intervals (List[Tuple[int, int]], optional): The + first element in the tuple is a milestone and the second + element is a interval. The interval is used after the + corresponding milestone. Defaults to None. + """ + + def __init__( + self, + runner, + dataloader: Union[DataLoader, Dict], + max_epochs: int, + val_begin: int = 1, + val_interval: int = 1, + freeze_bn_begin: int = -1, + dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: + super().__init__( + runner, + dataloader, + max_epochs, + val_begin, + val_interval, + freeze_bn_begin=freeze_bn_begin, + dynamic_intervals=dynamic_intervals) + + self.is_first_batch = True + + def prepare_for_run_epoch(self): + """Toggle the state of the observers and fake quantizers before qat + training.""" + pass + + def prepare_for_val(self): + """Toggle the state of the observers and fake quantizers before + validation.""" + self.runner.model.apply(enable_val) + + def run_epoch(self) -> None: + """Iterate one epoch.""" + self.runner.call_hook('before_train_epoch') + self.runner.model.train() + + # TODO freeze bn + if self._epoch + 1 >= self.freeze_bn_begin: self.runner.model.apply(freeze_bn_stats) for idx, data_batch in enumerate(self.dataloader): + if self.is_first_batch: + # lsq init + self.is_first_batch = False + self.runner.model.apply(enable_static_estimate) + else: + self.runner.model.apply(enable_param_learning) self.run_iter(idx, data_batch) self.runner.call_hook('after_train_epoch') diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index afdd7799c..d3b0be089 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -12,10 +12,13 @@ from ..base import BaseAlgorithm, BaseModel try: - from torch.ao.quantization import FakeQuantizeBase + from torch.ao.quantization import (FakeQuantizeBase, MinMaxObserver, + PerChannelMinMaxObserver) except ImportError: from mmrazor.utils import get_placeholder FakeQuantizeBase = get_placeholder('torch>=1.13') + MinMaxObserver = get_placeholder('torch>=1.13') + PerChannelMinMaxObserver = get_placeholder('torch>=1.13') LossResults = Dict[str, torch.Tensor] TensorResults = Union[Tuple[torch.Tensor], torch.Tensor] @@ -58,24 +61,36 @@ def __init__(self, input_shapes: Tuple = (1, 3, 224, 224), init_cfg: Optional[Dict] = None): - if data_preprocessor is None: - data_preprocessor = {} - # The build process is in MMEngine, so we need to add scope here. - # Default to mmcls.ClsDataPreprocessor. - data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor') super().__init__(architecture, data_preprocessor, init_cfg) - # If we have a float_checkpoint, we load it as pretrain. - if float_checkpoint: - _ = load_checkpoint(self.architecture, float_checkpoint) - self.architecture._is_init = True self.quantizer = MODELS.build(quantizer) self.input_shapes = input_shapes self.forward_modes = forward_modes + # Replace syncbn and _BatchNormXd (in mmengine) with batchnorm2d + self.quantizer.convert_batchnorm2d(self.architecture) + + # If we have a float_checkpoint, we load it as pretrain. + if float_checkpoint: + _ = load_checkpoint(self.architecture, float_checkpoint) + self.architecture._is_init = True + self.qmodels = self._build_qmodels(self.architecture) + self.sync_qparams('tensor') + self.reset_observer_and_fakequant_statistics(self) - self.sync_qparams('predict') + def reset_observer_and_fakequant_statistics(self, model): + """Reset the statistics in observers and fake quantizers. + + The forward computation in `_build_qmodels` can modify the original + statistics in observers and fake quantizers. + """ + for module in model.modules(): + if isinstance(module, (MinMaxObserver, PerChannelMinMaxObserver)): + module.reset_min_max_vals() + elif isinstance(module, FakeQuantizeBase): + module.scale.data = torch.ones_like(module.scale) + module.zero_point.data = torch.zeros_like(module.zero_point) def sync_qparams(self, src_mode: str): """Sync all quantize parameters in different `forward_modes`. We could @@ -106,10 +121,10 @@ def traverse(module, prefix): if src_param.shape == param.shape: param.data.copy_(src_param) else: - # requirs_grad = param.requires_grad - # param.requires_grad = False + requirs_grad = param.requires_grad + param.requires_grad = False param.resize_(src_param.shape) - # param.requires_grad = requirs_grad + param.requires_grad = requirs_grad param.data.copy_(src_param) for name, buffer in child.named_buffers(): buffer_name = f'{child_name}.{name}' @@ -160,6 +175,24 @@ def _build_qmodels(self, model: BaseModel): observed_module = self.quantizer.prepare(model, concrete_args) qmodels[mode] = observed_module + # data_samples can not be None in detectors during prediction. + # But we need to make the dummy prediction in _build_qmodels. + # It is more convenient to use `tensor` mode. + is_training = qmodels['tensor'].training + # Avoid random input changing bn's statistics + qmodels['tensor'].eval() + # Originally, the steps to train a qat model is as follows: + # 1. build qmodels 2. convert the model to ddpmodel 3. forward backward + # The shape of `scale` and `zero_point` can be modified during forward. + # We initialize these parameters with per-tensor mode by default for + # convenience. Their shape will be modified during forward if + # per-channel mode is used. It's hacky. Hence we need to input a + # dummy input to make sure the shape has been modified. + device = next(qmodels.parameters()).device + dummy_input = torch.randn(self.input_shapes).to(device) + qmodels['tensor'](dummy_input, None, 'tensor') + qmodels['tensor'].train(mode=is_training) + return qmodels def forward(self, @@ -183,7 +216,7 @@ def calibrate_step(self, data: Union[Dict, Tuple, List]): @MODEL_WRAPPERS.register_module() class MMArchitectureQuantDDP(MMDistributedDataParallel): - """DDPwapper for GeneralQuant. + """DDPwapper for MMArchitectureQuant. Args: device_ids (Optional[Union[List, int, torch.device]]): devices to run @@ -203,6 +236,8 @@ def __init__(self, # (`model.cuda()`), the buffers in model are different. self.module.qmodels = self.module._build_qmodels( self.module.architecture) + self.module.sync_qparams('tensor') + self.module.reset_observer_and_fakequant_statistics(self) def calibrate_step(self, data: Union[Dict, Tuple, List]): """PTQ method need calibrate by cali data.""" diff --git a/mmrazor/models/fake_quants/__init__.py b/mmrazor/models/fake_quants/__init__.py index 9030660f6..950821210 100644 --- a/mmrazor/models/fake_quants/__init__.py +++ b/mmrazor/models/fake_quants/__init__.py @@ -1,5 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import BaseFakeQuantize +from .lsq import (LearnableFakeQuantize, enable_param_learning, + enable_static_estimate, enable_val) from .torch_fake_quants import register_torch_fake_quants -__all__ = ['BaseFakeQuantize', 'register_torch_fake_quants'] +__all__ = [ + 'BaseFakeQuantize', 'register_torch_fake_quants', 'LearnableFakeQuantize', + 'enable_val', 'enable_param_learning', 'enable_static_estimate' +] diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py new file mode 100644 index 000000000..270140b85 --- /dev/null +++ b/mmrazor/models/fake_quants/lsq.py @@ -0,0 +1,273 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.nn.parameter import Parameter + +from mmrazor.registry import MODELS + +try: + from torch.ao.quantization import FakeQuantizeBase +except ImportError: + from mmrazor.utils import get_placeholder + FakeQuantizeBase = get_placeholder('torch>=1.13') + + +def enable_param_learning(mod): + """Enables learning of quantization parameters, if applicable. Example + usage:: + + # model is any PyTorch model model.apply(enable_param_learning) + """ + if isinstance(mod, LearnableFakeQuantize): + mod.enable_param_learning() + + +def enable_static_estimate(mod): + """Enables static observer estimates, if applicable. Example usage:: + + # model is any PyTorch model model.apply(enable_static_estimate) + """ + if isinstance(mod, LearnableFakeQuantize): + mod.enable_static_estimate() + + +def enable_val(mod): + """Enable validation, if applicable. Example usage:: + + # model is any PyTorch model model.apply(enable_val) + """ + if isinstance(mod, LearnableFakeQuantize): + mod.enable_val() + + +@MODELS.register_module() +class LearnableFakeQuantize(FakeQuantizeBase): + """This is an extension of the FakeQuantize module in fake_quantize.py, + which supports learning of the scale and zero point parameters through + backpropagation. + + In addition to the attributes in the original FakeQuantize module, the + LearnableFakeQuantize module also includes the following attributes to + support quantization parameter learning. + + * :attr:`fake_quant_enabled` defines the flag for enabling fake + quantization on the output. + + * :attr:`static_enabled` defines the flag for using observer's static + estimation for scale and zero point. + + * :attr:`learning_enabled` defines the flag for enabling backpropagation + for scale and zero point. + + Args: + observer (module): Module for observing statistics on input tensors and + calculating scale and zero-point. + quant_min (int): Minimum quantization value. If unspecified, it will + follow the 8-bit setup. + quant_max (int): Maximum quantization value. If unspecified, it will + follow the 8-bit setup. + scale (float): The initial value of the floating-point scale factor. + Defaults to 1. + zero_point (float): The initial value of the floating-point zero-point. + Defaults to 0. + use_grad_scaling (bool): Whether the gradients for scale and zero point + are normalized by the constant, which is proportional to the square + root of the number of elements in the tensor. The related + literature justifying the use of this particular constant can be + found here: https://openreview.net/pdf?id=rkgO66VKDS. Defaults to + True. + zero_point_trainable (bool): Whether the zero_point is trainable. + Defaults to False. + observer_kwargs (dict | optional): Arguments for the observer module. + """ + + def __init__(self, + observer, + quant_min=0, + quant_max=255, + scale=1., + zero_point=0., + use_grad_scaling=True, + zero_point_trainable=False, + **observer_kwargs): + super(LearnableFakeQuantize, self).__init__() + assert quant_min < quant_max, \ + 'quant_min must be strictly less than quant_max.' + self.quant_min = quant_min + self.quant_max = quant_max + # also pass quant_min and quant_max to observer + observer_kwargs['quant_min'] = quant_min + observer_kwargs['quant_max'] = quant_max + self.use_grad_scaling = use_grad_scaling + + self.scale = Parameter(torch.tensor([scale])) + self.zero_point_trainable = zero_point_trainable + if zero_point_trainable: + self.zero_point = Parameter(torch.tensor([zero_point])) + else: + self.register_buffer('zero_point', torch.tensor([zero_point])) + + self.activation_post_process = observer(**observer_kwargs) + assert \ + torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \ + 'quant_min out of bound' + assert \ + quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \ + 'quant_max out of bound' + self.dtype = self.activation_post_process.dtype + self.qscheme = self.activation_post_process.qscheme + self.ch_axis = self.activation_post_process.ch_axis \ + if hasattr(self.activation_post_process, 'ch_axis') else -1 + self.register_buffer('fake_quant_enabled', + torch.tensor([1], dtype=torch.uint8)) + self.register_buffer('static_enabled', + torch.tensor([1], dtype=torch.uint8)) + self.register_buffer('learning_enabled', + torch.tensor([0], dtype=torch.uint8)) + + bitrange = torch.tensor(quant_max - quant_min + 1).double() + self.bitwidth = int(torch.log2(bitrange).item()) + self.register_buffer('eps', + torch.tensor([torch.finfo(torch.float32).eps])) + + @torch.jit.export + def enable_param_learning(self): + """Enables learning of quantization parameters and disables static + observer estimates. + + Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=True) \ + .toggle_fake_quant(enabled=True) \ + .toggle_observer_update(enabled=False) + return self + + @torch.jit.export + def enable_static_estimate(self): + """Enables static observer estimates and disables learning of + quantization parameters. + + Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=False) \ + .toggle_fake_quant(enabled=True) \ + .toggle_observer_update(enabled=True) + + @torch.jit.export + def enable_val(self): + """Disables static observer accumulating data from input and doesn't + update the quantization parameters. + + Forward path returns fake quantized X. + """ + self.toggle_qparam_learning(enabled=False) \ + .toggle_fake_quant(enabled=True) \ + .toggle_observer_update(enabled=False) + + @torch.jit.export + def enable_static_observation(self): + """Enables static observer accumulating data from input but doesn't + update the quantization parameters. + + Forward path returns the original X. + """ + self.toggle_qparam_learning(enabled=False) \ + .toggle_fake_quant(enabled=False) \ + .toggle_observer_update(enabled=True) + + @torch.jit.export + def toggle_observer_update(self, enabled=True): + """Toggles whether static observer accumulates data from input.""" + self.static_enabled[0] = int(enabled) + return self + + @torch.jit.export + def enable_observer(self, enabled=True): + """Enables static observer accumulating data from input.""" + self.toggle_observer_update(enabled) + + @torch.jit.export + def toggle_qparam_learning(self, enabled=True): + """Toggles whether the quantization parameters are learnable.""" + self.learning_enabled[0] = int(enabled) + self.scale.requires_grad = enabled + if self.zero_point_trainable: + self.zero_point.requires_grad = enabled + return self + + @torch.jit.export + def toggle_fake_quant(self, enabled=True): + """Toggles whether the fake quantization is enabled.""" + self.fake_quant_enabled[0] = int(enabled) + return self + + @torch.jit.export + def observe_quant_params(self): + """Shows the quantization parameters.""" + print('LearnableFakeQuantize Scale: {}'.format(self.scale.detach())) + print('LearnableFakeQuantize Zero Point: {}'.format( + self.zero_point.detach())) + + @torch.jit.export + def calculate_qparams(self): + """Calculate the quantization parameters.""" + self.scale.data.clamp_(min=self.eps.item()) + scale = self.scale.detach() + zero_point = self.zero_point.detach().round().clamp( + self.quant_min, self.quant_max).long() + return scale, zero_point + + def forward(self, X): + """Forward computation. + + Forward path returns fake quantized X. + """ + if self.static_enabled[0] == 1: + self.activation_post_process(X.detach()) + _scale, _zero_point = \ + self.activation_post_process.calculate_qparams() + _scale = _scale.to(self.scale.device) + _zero_point = _zero_point.to(self.zero_point.device) + + if self.qscheme in (torch.per_channel_symmetric, + torch.per_channel_affine): + self.scale.data = torch.ones_like(_scale) + self.zero_point.data = torch.zeros_like(_zero_point.float()) + + self.scale.data.copy_(_scale) + self.zero_point.data.copy_(_zero_point) + else: + self.scale.data.clamp_(min=self.eps.item()) + + if self.fake_quant_enabled[0] == 1: + + if self.use_grad_scaling: + grad_factor = 1.0 / (X.numel() * self.quant_max)**0.5 + else: + grad_factor = 1.0 + if self.qscheme in (torch.per_channel_symmetric, + torch.per_channel_affine): + X = torch._fake_quantize_learnable_per_channel_affine( + X, self.scale, self.zero_point, self.ch_axis, + self.quant_min, self.quant_max, grad_factor) + else: + if not (self.quant_min <= self.zero_point <= self.quant_max): + print(self.quant_min, self.zero_point, self.quant_max) + X = torch._fake_quantize_learnable_per_tensor_affine( + X, self.scale, self.zero_point, self.quant_min, + self.quant_max, grad_factor) + + return X + + @torch.jit.export + def extra_repr(self): + """The printable representational string.""" + repr_str = f'static_enabled={self.static_enabled}, ' + repr_str += f'fake_quant_enabled={self.fake_quant_enabled}, ' + repr_str += f'quant_min={self.activation_post_process.quant_min}, ' + repr_str += f'quant_max={self.activation_post_process.quant_max}, ' + repr_str += f'dtype={self.dtype}, ' + repr_str += f'qscheme={self.qscheme}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'zero_point={self.zero_point}, ' + repr_str += f'zero_point_trainable={self.zero_point_trainable}' + return repr_str diff --git a/mmrazor/models/observers/__init__.py b/mmrazor/models/observers/__init__.py index c82f902f5..84d1677dd 100644 --- a/mmrazor/models/observers/__init__.py +++ b/mmrazor/models/observers/__init__.py @@ -1,5 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import BaseObserver +from .lsq import LSQObserver, LSQPerChannelObserver from .torch_observers import register_torch_observers -__all__ = ['BaseObserver', 'register_torch_observers'] +__all__ = [ + 'BaseObserver', 'register_torch_observers', 'LSQObserver', + 'LSQPerChannelObserver' +] diff --git a/mmrazor/models/observers/lsq.py b/mmrazor/models/observers/lsq.py new file mode 100644 index 000000000..ccab3b0e6 --- /dev/null +++ b/mmrazor/models/observers/lsq.py @@ -0,0 +1,129 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch +import torch.distributed as dist + +from mmrazor.registry import MODELS + +try: + from torch.ao.quantization.observer import (MinMaxObserver, + PerChannelMinMaxObserver) +except ImportError: + from mmrazor.utils import get_placeholder + MinMaxObserver = get_placeholder('torch>=1.13') + PerChannelMinMaxObserver = get_placeholder('torch>=1.13') + + +def sync_tensor(tensor): + """Synchronize the target tensor during distributed training.""" + if torch.distributed.is_initialized() and tensor.is_cuda: + tensor.data = tensor.data / dist.get_world_size() + dist.all_reduce(tensor.data) + return tensor + + +class LSQObserverMixIn: + """A mixin class for LSQObserver which can provide the initialized + floating-point scale factor.""" + + def __init__(self): + self.tensor_norm = None + + @torch.jit.export + def _calculate_scale(self): + """Calculate the initialized floating-point scale factor. + + Each layer of weights and each layer of activations has a distinct step + size, represented as a fp32 value, initialized to 2<|v|> / sqrt(Q_p), + computed on either the initial weights values or the first batch of + activations, respectively. + """ + scale = 2 * self.tensor_norm / math.sqrt(self.quant_max) + sync_tensor(scale) + return scale + + +@MODELS.register_module() +class LSQObserver(MinMaxObserver, LSQObserverMixIn): + """LSQ observer. + + Paper: Learned Step Size Quantization. + """ + + def __init__(self, *args, **kwargs): + MinMaxObserver.__init__(self, *args, **kwargs) + LSQObserverMixIn.__init__(self) + + def forward(self, x_orig): + """Records the running minimum, maximum and tensor_norm of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + self.tensor_norm = x.abs().mean() + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = torch.min(min_val_cur, self.min_val) + max_val = torch.max(max_val_cur, self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + """Calculates the quantization parameters.""" + _, zero_point = MinMaxObserver.calculate_qparams(self) + scale = LSQObserverMixIn._calculate_scale(self) + return scale, zero_point + + +@MODELS.register_module() +class LSQPerChannelObserver(PerChannelMinMaxObserver, LSQObserverMixIn): + """LSQ per-channel observer. + + Paper: Learned Step Size Quantization. + """ + + def __init__(self, *args, **kwargs): + PerChannelMinMaxObserver.__init__(self, *args, **kwargs) + LSQObserverMixIn.__init__(self) + + def forward(self, x_orig): + """Records the per-channel running minimum, maximum and tensor_norm of + ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + min_val = self.min_val + max_val = self.max_val + x_dim = x.size() + + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + # Need to match dtype of min/max because the updates to buffers + # are done in place and types need to match for comparisons + y = y.to(self.min_val.dtype) + y = torch.flatten(y, start_dim=1) + + self.tensor_norm = y.abs().mean(1) + + if min_val.numel() == 0 or max_val.numel() == 0: + min_val, max_val = torch.aminmax(y, dim=1) + else: + min_val_cur, max_val_cur = torch.aminmax(y, dim=1) + min_val = torch.min(min_val_cur, min_val) + max_val = torch.max(max_val_cur, max_val) + self.min_val.resize_(min_val.shape) + self.max_val.resize_(max_val.shape) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + """Calculates the quantization parameters.""" + _, zero_point = PerChannelMinMaxObserver.calculate_qparams(self) + scale = LSQObserverMixIn._calculate_scale(self) + return scale, zero_point diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py index 5dc24609f..996314d27 100644 --- a/mmrazor/models/observers/torch_observers.py +++ b/mmrazor/models/observers/torch_observers.py @@ -2,13 +2,33 @@ import inspect from typing import List +import torch + from mmrazor.registry import MODELS try: import torch.ao.quantization.observer as torch_observer_src + from torch.ao.quantization.observer import PerChannelMinMaxObserver except ImportError: from mmrazor.utils import get_package_placeholder torch_observer_src = get_package_placeholder('torch>=1.13') + PerChannelMinMaxObserver = get_package_placeholder('torch>=1.13') + + +@torch.jit.export +def reset_min_max_vals(self): + """Resets the min/max values. + + `min_val` and `max_val` are always be on cpu in the pytorch version of this + method. + """ + min_val = torch.rand(0, ) + max_val = torch.rand(0, ) + self.min_val.resize_(min_val.shape).copy_(min_val) + self.max_val.resize_(max_val.shape).copy_(max_val) + + +PerChannelMinMaxObserver.reset_min_max_vals = reset_min_max_vals def register_torch_observers() -> List[str]: diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py index a6cfc257c..0dbe6dcdd 100644 --- a/mmrazor/models/quantizers/academic_quantizer.py +++ b/mmrazor/models/quantizers/academic_quantizer.py @@ -103,6 +103,11 @@ def prepare(self, model, concrete_args=None): setattr(graph_module, attr_name, getattr(model, attr_name)) fuse_custom_config = FuseCustomConfig().set_preserved_attributes( preserved_attributes) + + # set the training modes of all modules to True to `_fuse_fx` correctly + # todo: check freezebn + self.sync_module_training_mode(graph_module, mode=True) + graph_module = _fuse_fx( graph_module=graph_module, is_qat=True, diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py index 866199735..78c8163c7 100644 --- a/mmrazor/models/quantizers/base.py +++ b/mmrazor/models/quantizers/base.py @@ -3,7 +3,9 @@ from typing import Dict import torch +import torch.nn as nn from mmengine.model import BaseModule +from mmengine.model.utils import _BatchNormXd from mmrazor.registry import TASK_UTILS @@ -24,6 +26,38 @@ def __init__(self, tracer: Dict): super().__init__() self.tracer = TASK_UTILS.build(tracer) + def sync_module_training_mode(self, model, mode=True): + """Synchronize the training modes. + + Note that modes of conv and bn must be the same during ``_fuse_fx``. + """ + for module in model.modules(): + module.training = mode + return + + @staticmethod + def convert_batchnorm2d(model): + """Helper function to convert all :attr:`_BatchNormXd` layers and + :class:`torch.nn.SyncBatchNorm` layers in the model to + :class:`torch.nn.BatchNorm2d` layers. + """ + # todo: Convert all `_BatchNormXd` and `SyncBatchNorm` + # layers to `BatchNorm2d` layers but they may be :attr:`BatchNorm*D` + # layers + module_checklist = [nn.modules.batchnorm.SyncBatchNorm, _BatchNormXd] + + def traverse(module: nn.Module): + for child_name, child in module.named_children(): + if isinstance(child, tuple(module_checklist)): + bn = nn.BatchNorm2d(child.num_features, child.eps, + child.momentum, child.affine, + child.track_running_stats) + setattr(module, child_name, bn) + else: + traverse(child) + + traverse(model) + @abstractmethod def prepare(self, model): """Prepare for quantizing model, which usually includes as follows: diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index b5de1c028..1d566b45f 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -9,7 +9,9 @@ from torch.ao.quantization import enable_fake_quant from torch.ao.quantization.fx import prepare from torch.ao.quantization.fx.graph_module import ObservedGraphModule - from torch.ao.quantization.qconfig_mapping import QConfigMapping + from torch.ao.quantization.qconfig_mapping import ( + _FIXED_QPARAMS_OP_TO_OBSERVER, FixedQParamsFakeQuantize, QConfig, + QConfigMapping, default_weight_fake_quant) from torch.ao.quantization.quantize_fx import _fuse_fx from torch.fx.graph_module import GraphModule from torch.nn.intrinsic.qat import modules as qat_fused_modules @@ -24,6 +26,10 @@ _fuse_fx = get_placeholder('torch>=1.13') qat_fused_modules = get_package_placeholder('torch>=1.13') qat_modules = get_package_placeholder('torch>=1.13') + _FIXED_QPARAMS_OP_TO_OBSERVER = get_package_placeholder('torch>=1.13') + FixedQParamsFakeQuantize = get_package_placeholder('torch>=1.13') + QConfig = get_package_placeholder('torch>=1.13') + default_weight_fake_quant = get_package_placeholder('torch>=1.13') from mmrazor import digit_version from mmrazor.models.task_modules.tracer import build_graphmodule @@ -117,19 +123,47 @@ def __init__(self, assert w_mode in self.support_w_modes assert a_mode in self.support_a_modes - self.qconfig_mapping = QConfigMapping().set_global( - self.qconfig.convert()) - if no_observer_modules: - self.no_observer_modules = str2class(no_observer_modules) - for mod in self.no_observer_modules: - self.qconfig_mapping.set_object_type(mod, None) - else: - self.no_observer_modules = no_observer_modules + self.qconfig_mapping = self.gen_qconfig_mapping( + self.qconfig, no_observer_modules) + self.no_observer_modules = no_observer_modules + self.backend_config = BackendConfigs[self.backend] self.example_inputs = (torch.randn(1, 3, 224, 224), ) self.extra_redundant_fakequants = extra_redundant_fakequants + def gen_qconfig_mapping(self, qconfig, no_observer_modules): + """Convert qconfig in config file to `QConfigMapping`. + + `QConfigMapping` is a custom class for mapping from model ops to + :class:`torch.ao.quantization.QConfig` s. + """ + qconfig_mapping = QConfigMapping().set_global(qconfig.convert()) + + if no_observer_modules is not None: + no_observer_modules = str2class(no_observer_modules) + for mod in no_observer_modules: + qconfig_mapping.set_object_type(mod, None) + + fixed_qparams_observer_to_qconfig = {} + for fixed_qparams_op, observer in _FIXED_QPARAMS_OP_TO_OBSERVER.items( + ): + if observer in fixed_qparams_observer_to_qconfig: + fixed_qparams_qconfig = fixed_qparams_observer_to_qconfig[ + observer] + else: + activation = FixedQParamsFakeQuantize.with_args( + observer=observer) + + fixed_qparams_qconfig = QConfig( + activation=activation, weight=default_weight_fake_quant) + fixed_qparams_observer_to_qconfig[ + observer] = fixed_qparams_qconfig + qconfig_mapping.set_object_type(fixed_qparams_op, + fixed_qparams_qconfig) + + return qconfig_mapping + @property def backend(self): """The key of the corresponding backend config.""" @@ -169,6 +203,11 @@ def prepare(self, model, concrete_args=None): self.swap_ff_with_fxff(model) traced_graph = self.tracer.trace(model, concrete_args=concrete_args) graph_module = build_graphmodule(model, traced_graph) + + # set the training modes of all modules to True to `_fuse_fx` correctly + # todo: check freezebn + self.sync_module_training_mode(graph_module, mode=True) + graph_module = _fuse_fx( graph_module=graph_module, is_qat=True, diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index cb7d3084b..f8a25bd56 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -59,6 +59,7 @@ def prepare_for_mmdeploy(self, 3. post process weight fakequant for exporting .onnx that meet the backend's requirement. """ + self.convert_batchnorm2d(model) observed_model = self.prepare(model) if dummy_input is not None: observed_model(torch.randn(dummy_input)) diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py index a3cff1167..68d5f0809 100644 --- a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import functools +from copy import deepcopy from types import FunctionType from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union @@ -152,6 +153,33 @@ def _get_attrs(target, attrs): return module_dict +def duplicate_reused_nodes(graph: Graph, modules: Dict[str, Any] = {}): + """Deepcopy the shared modules (e.g. shared detection head in RetinaNet) to + make sure modules can be fused correctly. + + Modified from https://github.com/ModelTC/MQBench/blob/main/mqbench/prepare_by_platform.py # noqa: E501 + """ + _dup_prefix = '_dup' + target_dict = dict() + dup_modules = dict() + for node in graph.nodes: + if node.op == 'call_module': + if node.target not in target_dict: + target_dict[node.target] = [node] + else: + target_dict[node.target].append(node) + for key in target_dict: + if len(target_dict[key]) > 1: + for idx, node in enumerate(target_dict[key]): + if idx == 0: + continue + module = deepcopy(modules[node.target]) + node.target += _dup_prefix + str(idx) + dup_modules[node.target] = module + graph.lint() + return graph, dup_modules + + def build_graphmodule(model: torch.nn.Module, fx_graph, name: str = 'GraphModule'): @@ -180,7 +208,9 @@ def build_graphmodule(model: torch.nn.Module, """ modules = dict(model.named_modules()) module_dict = _prepare_module_dict(model, fx_graph) + fx_graph, duplicated_modules = duplicate_reused_nodes(fx_graph, modules) modules.update(module_dict) + modules.update(duplicated_modules) return GraphModule(modules, fx_graph, name) @@ -272,6 +302,7 @@ def call_method(self, m: torch.nn.Module, name: str, method: Callable, kwargs (Dict): kwargs of the module callsite Return: + The return value from the Module call. In the case that a ``call_module`` node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever value was returned from the ``Module`` diff --git a/mmrazor/models/task_modules/tracer/fx/graph_utils.py b/mmrazor/models/task_modules/tracer/fx/graph_utils.py index 5e3ddc2f4..ca1291711 100644 --- a/mmrazor/models/task_modules/tracer/fx/graph_utils.py +++ b/mmrazor/models/task_modules/tracer/fx/graph_utils.py @@ -6,9 +6,11 @@ try: from torch.ao.quantization.fake_quantize import FakeQuantizeBase + from torch.fx import Node except ImportError: from mmrazor.utils import get_placeholder FakeQuantizeBase = get_placeholder('torch>=1.13') + Node = get_placeholder('torch>=1.13') def _get_attrs(target: torch.nn.Module, attr: str) -> Any: @@ -61,11 +63,11 @@ def recursive_find_erased_nodes(node, prepared_model): nodes_to_erase = [] for prev_node in node.args: - if isinstance(prev_node, torch.fx.Node): + if isinstance(prev_node, Node): nodes_to_erase.extend( recursive_find_erased_nodes(prev_node, prepared_model)) for prev_node in node.kwargs.values(): - if isinstance(prev_node, torch.fx.Node): + if isinstance(prev_node, Node): nodes_to_erase.extend( recursive_find_erased_nodes(prev_node, prepared_model)) @@ -94,7 +96,7 @@ def del_fakequant_before_op(prepared_model, new_graph = copy.deepcopy(prepared_model.graph) for node in new_graph.nodes: if node.op in target_ops: - nodes_to_erase: List[torch.fx.Node] = recursive_find_erased_nodes( + nodes_to_erase: List[Node] = recursive_find_erased_nodes( node, prepared_model) for to_erase in nodes_to_erase: assert to_erase.op == 'call_module' and isinstance( @@ -172,7 +174,7 @@ def del_fakequant_before_method(prepared_model, new_graph = copy.deepcopy(prepared_model.graph) for node in new_graph.nodes: if node.op == 'call_method' and node.target in method_patterns: - nodes_to_erase: List[torch.fx.Node] = recursive_find_erased_nodes( + nodes_to_erase: List[Node] = recursive_find_erased_nodes( node, prepared_model) for to_erase in nodes_to_erase: assert to_erase.op == 'call_module' and isinstance( @@ -251,7 +253,7 @@ def del_fakequant_before_function(prepared_model, new_graph = copy.deepcopy(prepared_model.graph) for node in new_graph.nodes: if node.op == 'call_function' and node.target in function_patterns: - nodes_to_erase: List[torch.fx.Node] = recursive_find_erased_nodes( + nodes_to_erase: List[Node] = recursive_find_erased_nodes( node, prepared_model) for to_erase in nodes_to_erase: assert to_erase.op == 'call_module' and isinstance( diff --git a/tests/test_models/test_algorithms/test_mm_architecture.py b/tests/test_models/test_algorithms/test_mm_architecture.py index 4862bff91..639e0f492 100644 --- a/tests/test_models/test_algorithms/test_mm_architecture.py +++ b/tests/test_models/test_algorithms/test_mm_architecture.py @@ -61,11 +61,10 @@ def _inner_forward(x): return out -@MODELS.register_module() -class ToyQuantModel(BaseModel): +class ToyModel(nn.Module): def __init__(self): - super().__init__() + super(ToyModel, self).__init__() self.stem_layer = nn.Sequential( nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -85,15 +84,34 @@ def forward(self, x): return x +class ToyQuantModel(BaseModel): + + def __init__(self): + super().__init__() + self.architecture = ToyModel() + + def loss(self, outputs, data_samples): + return dict(loss=outputs.sum() - data_samples.sum()) + + def forward(self, inputs, data_samples, mode: str = 'tensor'): + if isinstance(inputs, list): + inputs = torch.stack(inputs) + outputs = self.architecture(inputs) + + return outputs + + class TestMMArchitectureQuant(TestCase): def setUp(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') + + MODELS.register_module(module=ToyQuantModel, force=True) + self.temp_dir = tempfile.mkdtemp() filename = 'fp_model.pth' filename = os.path.join(self.temp_dir, filename) - # import pdb; pdb.set_trace() toymodel = ToyQuantModel() torch.save(toymodel.state_dict(), filename) @@ -120,18 +138,14 @@ def setUp(self): quantizer=dict( type='mmrazor.OpenVINOQuantizer', global_qconfig=global_qconfig, - tracer=dict( - type='mmrazor.CustomTracer', - skipped_methods=[ - 'mmcls.models.heads.ClsHead._get_loss', - 'mmcls.models.heads.ClsHead._get_predictions' - ]))) + tracer=dict(type='mmrazor.CustomTracer'))) self.alg_kwargs = alg_kwargs self.toy_model = MODELS.build(self.alg_kwargs) def tearDown(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') + MODELS.module_dict.pop('ToyQuantModel') shutil.rmtree(self.temp_dir) def test_init(self): @@ -145,12 +159,12 @@ def test_sync_qparams(self): self.skipTest('version of torch < 1.13.0') mode = self.toy_model.forward_modes[0] self.toy_model.sync_qparams(mode) - w_loss = self.toy_model.qmodels['loss'].block.conv1.state_dict( - )['weight'] - w_tensor = self.toy_model.qmodels['tensor'].block.conv1.state_dict( - )['weight'] - w_pred = self.toy_model.qmodels['predict'].block.conv1.state_dict( - )['weight'] + w_loss = self.toy_model.qmodels[ + 'loss'].architecture.block.conv1.state_dict()['weight'] + w_tensor = self.toy_model.qmodels[ + 'tensor'].architecture.block.conv1.state_dict()['weight'] + w_pred = self.toy_model.qmodels[ + 'predict'].architecture.block.conv1.state_dict()['weight'] assert w_loss.equal(w_pred) assert w_loss.equal(w_tensor) diff --git a/tests/test_models/test_fake_quants/test_lsq_fake_quants.py b/tests/test_models/test_fake_quants/test_lsq_fake_quants.py index d6b670bb5..bd8fcbd50 100644 --- a/tests/test_models/test_fake_quants/test_lsq_fake_quants.py +++ b/tests/test_models/test_fake_quants/test_lsq_fake_quants.py @@ -1,23 +1,186 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest import TestCase +import torch +from torch.nn.parameter import Parameter + +from mmrazor import digit_version +from mmrazor.models import LearnableFakeQuantize + +try: + from torch.ao.quantization import MovingAverageMinMaxObserver +except ImportError: + from mmrazor.utils import get_placeholder + MovingAverageMinMaxObserver = get_placeholder('torch>=1.13') + class TestLearnableFakeQuantize(TestCase): - def test_init(self): - pass + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.zero_point_trainable_fakequant = LearnableFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, + zero_point_trainable=True) + + self.zero_point_untrainable_fakequant = \ + LearnableFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=True, + zero_point_trainable=False) def test_repr(self): - pass + fq_module = self.zero_point_untrainable_fakequant() + repr_str = f'static_enabled={torch.tensor([1], dtype=torch.uint8)}, ' + repr_str += f'fake_quant_enabled=' \ + f'{torch.tensor([1], dtype=torch.uint8)}, ' + repr_str += 'quant_min=0, ' + repr_str += 'quant_max=127, ' + repr_str += f'dtype={torch.quint8}, ' + repr_str += f'qscheme={torch.per_tensor_affine}, ' + repr_str += f'scale={Parameter(torch.tensor([1.0]))}, ' + repr_str += f'zero_point={torch.tensor([0.])}, ' + repr_str += 'zero_point_trainable=False' + self.assertEqual(fq_module.extra_repr(), repr_str) + + fq_module = self.zero_point_trainable_fakequant() + repr_str = f'static_enabled={torch.tensor([1], dtype=torch.uint8)}, ' + repr_str += f'fake_quant_enabled=' \ + f'{torch.tensor([1], dtype=torch.uint8)}, ' + repr_str += 'quant_min=0, ' + repr_str += 'quant_max=127, ' + repr_str += f'dtype={torch.quint8}, ' + repr_str += f'qscheme={torch.per_tensor_affine}, ' + repr_str += f'scale={Parameter(torch.tensor([1.0]))}, ' + repr_str += f'zero_point={Parameter(torch.tensor([0.]))}, ' + repr_str += 'zero_point_trainable=True' + self.assertEqual(fq_module.extra_repr(), repr_str) def test_calculate_qparams(self): - pass + fq_module = self.zero_point_untrainable_fakequant() + scale, zero_point = fq_module.calculate_qparams() + self.assertEqual(scale, 1.) + self.assertEqual(zero_point, 0.) + + fq_module = self.zero_point_trainable_fakequant() + scale, zero_point = fq_module.calculate_qparams() + self.assertEqual(scale, 1.) + self.assertEqual(zero_point, 0.) def test_forward(self): - pass + fq_module = self.zero_point_untrainable_fakequant() + torch.manual_seed(42) + X = torch.rand(20, 10, dtype=torch.float32) + # Output of fake quant is not identical to input + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # self.assertNotEqual(Y, X) + fq_module.toggle_fake_quant(False) + X = torch.rand(20, 10, dtype=torch.float32) + Y = fq_module(X) + # Fake quant is disabled,output is identical to input + self.assertTrue(torch.equal(Y, X)) + + # Explicit copy at this point in time, because FakeQuant keeps internal + # state in mutable buffers. + scale = fq_module.scale.clone().detach() + zero_point = fq_module.zero_point.clone().detach() + + fq_module.toggle_observer_update(False) + fq_module.toggle_fake_quant(True) + X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0 + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # Observer is disabled, scale and zero-point do not change + self.assertEqual(fq_module.scale, scale) + self.assertEqual(fq_module.zero_point, zero_point) + + fq_module.toggle_observer_update(True) + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # Observer is enabled, scale and zero-point are different + self.assertNotEqual(fq_module.scale, scale) + self.assertNotEqual(fq_module.zero_point, zero_point) + + fq_module = self.zero_point_trainable_fakequant() + torch.manual_seed(42) + X = torch.rand(20, 10, dtype=torch.float32) + # Output of fake quant is not identical to input + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # self.assertNotEqual(Y, X) + fq_module.toggle_fake_quant(False) + X = torch.rand(20, 10, dtype=torch.float32) + Y = fq_module(X) + # Fake quant is disabled,output is identical to input + self.assertTrue(torch.equal(Y, X)) + + # Explicit copy at this point in time, because FakeQuant keeps internal + # state in mutable buffers. + scale = fq_module.scale.clone().detach() + zero_point = fq_module.zero_point.clone().detach() + + fq_module.toggle_observer_update(False) + fq_module.toggle_fake_quant(True) + X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0 + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # Observer is disabled, scale and zero-point do not change + self.assertEqual(fq_module.scale, scale) + self.assertEqual(fq_module.zero_point, zero_point) + + fq_module.toggle_observer_update(True) + Y = fq_module(X) + self.assertFalse(torch.equal(Y, X)) + # Observer is enabled, scale and zero-point are different + self.assertNotEqual(fq_module.scale, scale) + self.assertNotEqual(fq_module.zero_point, zero_point) + + def test_state(self): + fq_module = self.zero_point_untrainable_fakequant() + + fq_module.enable_param_learning() + self.assertEqual(fq_module.learning_enabled[0], 1) + self.assertEqual(fq_module.scale.requires_grad, 1) + self.assertEqual(fq_module.zero_point.requires_grad, 0) + self.assertEqual(fq_module.fake_quant_enabled[0], 1) + self.assertEqual(fq_module.static_enabled[0], 0) + + fq_module.enable_static_estimate() + self.assertEqual(fq_module.learning_enabled[0], 0) + self.assertEqual(fq_module.scale.requires_grad, 0) + self.assertEqual(fq_module.zero_point.requires_grad, 0) + self.assertEqual(fq_module.fake_quant_enabled[0], 1) + self.assertEqual(fq_module.static_enabled[0], 1) + + fq_module.enable_val() + self.assertEqual(fq_module.learning_enabled[0], 0) + self.assertEqual(fq_module.scale.requires_grad, 0) + self.assertEqual(fq_module.zero_point.requires_grad, 0) + self.assertEqual(fq_module.fake_quant_enabled[0], 1) + self.assertEqual(fq_module.static_enabled[0], 0) + + fq_module.enable_static_observation() + self.assertEqual(fq_module.learning_enabled[0], 0) + self.assertEqual(fq_module.scale.requires_grad, 0) + self.assertEqual(fq_module.zero_point.requires_grad, 0) + self.assertEqual(fq_module.fake_quant_enabled[0], 0) + self.assertEqual(fq_module.static_enabled[0], 1) - def test_load_state_dict(self): - pass + fq_module = self.zero_point_trainable_fakequant() - def test_save_state_dict(self): - pass + fq_module.enable_param_learning() + self.assertEqual(fq_module.learning_enabled[0], 1) + self.assertEqual(fq_module.scale.requires_grad, 1) + self.assertEqual(fq_module.zero_point.requires_grad, 1) + self.assertEqual(fq_module.fake_quant_enabled[0], 1) + self.assertEqual(fq_module.static_enabled[0], 0) diff --git a/tests/test_models/test_observers/test_lsq_observer.py b/tests/test_models/test_observers/test_lsq_observer.py new file mode 100644 index 000000000..a61f95d7f --- /dev/null +++ b/tests/test_models/test_observers/test_lsq_observer.py @@ -0,0 +1,77 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch + +from mmrazor import digit_version +from mmrazor.models import LSQObserver, LSQPerChannelObserver + + +class TestLSQObserver(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.lsq = LSQObserver.with_args( + dtype=torch.quint8, + qscheme=torch.per_tensor_symmetric, + reduce_range=False, + quant_min=0, + quant_max=255) + + def test_forward(self): + lsq_observer = self.lsq() + torch.manual_seed(42) + X = torch.rand(20, 10, dtype=torch.float32) + Y = lsq_observer(X) + # Output of observer is identical to input + self.assertTrue(torch.equal(Y, X)) + + X = torch.rand(0, dtype=torch.float32) + Y = lsq_observer(X) + # Output of observer is identical to input + self.assertTrue(torch.equal(Y, X)) + + def test_calculate_qparams(self): + lsq_observer = self.lsq() + X = torch.ones(10, dtype=torch.float32) + _ = lsq_observer(X) + scale, zero_point = lsq_observer.calculate_qparams() + # tensor_norm = 1, quant_max = 255 + self.assertEqual(scale, 2 * torch.tensor([1.]) / (255**0.5)) + self.assertEqual(zero_point, 127) + + +class TestLSQPerChannelObserver(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.lsq = LSQPerChannelObserver.with_args( + dtype=torch.qint8, + qscheme=torch.per_channel_symmetric, + reduce_range=False, + quant_min=-127, + quant_max=127) + + def test_forward(self): + lsq_observer = self.lsq() + torch.manual_seed(42) + X = torch.rand(2, 10, dtype=torch.float32) + Y = lsq_observer(X) + # Output of observer is identical to input + self.assertTrue(torch.equal(Y, X)) + + X = torch.rand(0, dtype=torch.float32) + Y = lsq_observer(X) + # Output of observer is identical to input + self.assertTrue(torch.equal(Y, X)) + + def test_calculate_qparams(self): + lsq_observer = self.lsq() + X = torch.ones(2, 10, dtype=torch.float32) + X[0] -= 1 + _ = lsq_observer(X) + scale, zero_point = lsq_observer.calculate_qparams() + self.assertEqual(scale[0], 2 * torch.tensor([0.]) / (127**0.5)) + self.assertEqual(scale[1], 2 * torch.tensor([1.]) / (127**0.5)) diff --git a/tests/test_models/test_quantizers/test_native_quantizer.py b/tests/test_models/test_quantizers/test_native_quantizer.py index 62052f66f..06a12c206 100644 --- a/tests/test_models/test_quantizers/test_native_quantizer.py +++ b/tests/test_models/test_quantizers/test_native_quantizer.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -import collections from unittest import TestCase import torch import torch.nn as nn from mmrazor import digit_version +from mmrazor.models.quantizers import NativeQuantizer from mmrazor.models.quantizers.native_quantizer import SUPPORT_QAT_MODULES from mmrazor.models.task_modules.tracer import CustomTracer from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ @@ -155,11 +155,7 @@ def test_init(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') native_quantizer = MODELS.build(self.q_kwargs) - no_ob_dict = collections.OrderedDict() - no_ob_dict = no_ob_dict.fromkeys(native_quantizer.no_observer_modules, - None) - assert native_quantizer.qconfig_mapping.object_type_qconfigs == \ - no_ob_dict + self.assertIsInstance(native_quantizer, NativeQuantizer) def test_prepare(self): if digit_version(torch.__version__) < digit_version('1.13.0'): diff --git a/tests/test_runners/test_quantization_loop.py b/tests/test_runners/test_quantization_loop.py new file mode 100644 index 000000000..6a300fb91 --- /dev/null +++ b/tests/test_runners/test_quantization_loop.py @@ -0,0 +1,413 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import logging +import shutil +import tempfile +from unittest import TestCase + +import torch +import torch.nn as nn +from mmengine.config import Config, ConfigDict +from mmengine.evaluator import BaseMetric +from mmengine.hooks import Hook +from mmengine.logging import MMLogger +from mmengine.model import BaseModel +from mmengine.optim import OptimWrapper +from mmengine.registry import DATASETS, HOOKS, METRICS, MODELS, OPTIM_WRAPPERS +from mmengine.runner import Runner +from torch.nn.intrinsic.qat import ConvBnReLU2d +from torch.utils.data import Dataset + +from mmrazor import digit_version +from mmrazor.engine import (LSQEpochBasedLoop, PTQLoop, QATEpochBasedLoop, + QATValLoop) + +try: + from torch.ao.nn.quantized import FloatFunctional, FXFloatFunctional + from torch.ao.quantization import QConfigMapping + from torch.ao.quantization.fake_quantize import FakeQuantizeBase + from torch.ao.quantization.fx import prepare + from torch.ao.quantization.qconfig_mapping import \ + get_default_qconfig_mapping + from torch.ao.quantization.quantize_fx import _fuse_fx +except ImportError: + from mmrazor.utils import get_placeholder + QConfigMapping = get_placeholder('torch>=1.13') + FakeQuantizeBase = get_placeholder('torch>=1.13') + prepare = get_placeholder('torch>=1.13') + _fuse_fx = get_placeholder('torch>=1.13') + get_default_qconfig_mapping = get_placeholder('torch>=1.13') + FloatFunctional = get_placeholder('torch>=1.13') + FXFloatFunctional = get_placeholder('torch>=1.13') + + +class ToyDataset(Dataset): + METAINFO = dict() # type: ignore + data = torch.randn(12, 3, 4, 4) + label = torch.ones(12) + + @property + def metainfo(self): + return self.METAINFO + + def __len__(self): + return self.data.size(0) + + def __getitem__(self, index): + return dict(inputs=self.data[index], data_sample=self.label[index]) + + +class MMArchitectureQuant(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor) + self.architecture = ToyModel() + + def calibrate_step(self, data): + data = self.data_preprocessor(data, False) + return self.architecture(**data) + + def sync_qparams(self, src_mode): + pass + + def forward(self, inputs, data_sample, mode='tensor'): + return self.architecture(inputs, data_sample, mode) + + +class ToyModel(BaseModel): + + def __init__(self, data_preprocessor=None): + super().__init__(data_preprocessor=data_preprocessor) + qconfig = get_default_qconfig_mapping().to_dict()[''] + self.architecture = nn.Sequential( + ConvBnReLU2d(3, 3, 1, qconfig=qconfig)) + + def forward(self, inputs, data_sample, mode='tensor'): + if isinstance(inputs, list): + inputs = torch.stack(inputs) + if isinstance(data_sample, list): + data_sample = torch.stack(data_sample) + outputs = self.architecture(inputs) + + if mode == 'tensor': + return outputs + elif mode == 'loss': + loss = data_sample.sum() - outputs.sum() + outputs = dict(loss=loss) + return outputs + elif mode == 'predict': + return outputs + + +class ToyOptimWrapper(OptimWrapper): + ... + + +class ToyMetric1(BaseMetric): + + def __init__(self, collect_device='cpu', dummy_metrics=None): + super().__init__(collect_device=collect_device) + self.dummy_metrics = dummy_metrics + + def process(self, data_batch, predictions): + result = {'acc': 1} + self.results.append(result) + + def compute_metrics(self, results): + return dict(acc=1) + + +DEFAULT_CFG = ConfigDict( + model=dict(type='MMArchitectureQuant'), + train_dataloader=dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=True), + batch_size=3, + num_workers=0), + val_dataloader=dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=False), + batch_size=3, + num_workers=0), + test_dataloader=dict( + dataset=dict(type='ToyDataset'), + sampler=dict(type='DefaultSampler', shuffle=False), + batch_size=3, + num_workers=0), + optim_wrapper=dict( + type='OptimWrapper', optimizer=dict(type='SGD', lr=0.01)), + val_evaluator=dict(type='ToyMetric1'), + test_evaluator=dict(type='ToyMetric1'), + train_cfg=dict(), + val_cfg=dict(), + test_cfg=dict(), + custom_hooks=[], + data_preprocessor=None, + launcher='none', + env_cfg=dict(dist_cfg=dict(backend='nccl')), +) + + +class TestQATEpochBasedLoop(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.temp_dir = tempfile.mkdtemp() + MODELS.register_module(module=MMArchitectureQuant, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + METRICS.register_module(module=ToyMetric1, force=True) + OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True) + + default_cfg = copy.deepcopy(DEFAULT_CFG) + default_cfg = Config(default_cfg) + default_cfg.work_dir = self.temp_dir + default_cfg.train_cfg = ConfigDict( + type='mmrazor.QATEpochBasedLoop', + max_epochs=4, + val_begin=1, + val_interval=1, + disable_observer_begin=-1, + freeze_bn_begin=-1, + dynamic_intervals=None) + self.default_cfg = default_cfg + + def tearDown(self): + MODELS.module_dict.pop('MMArchitectureQuant') + DATASETS.module_dict.pop('ToyDataset') + METRICS.module_dict.pop('ToyMetric1') + OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') + + logging.shutdown() + MMLogger._instance_dict.clear() + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_init_qat_train_loop' + runner = Runner(**cfg) + self.assertIsInstance(runner, Runner) + self.assertIsInstance(runner.train_loop, QATEpochBasedLoop) + + def test_run_epoch(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_train' + runner = Runner.from_cfg(cfg) + runner.train() + + @HOOKS.register_module(force=True) + class TestFreezeBNHook(Hook): + + def __init__(self, freeze_bn_begin): + self.freeze_bn_begin = freeze_bn_begin + + def after_train_epoch(self, runner): + + def check_bn_stats(mod): + if isinstance(mod, ConvBnReLU2d): + assert mod.freeze_bn + assert not mod.bn.training + + if runner.train_loop._epoch + 1 >= self.freeze_bn_begin: + runner.model.apply(check_bn_stats) + + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_freeze_bn' + cfg.custom_hooks = [ + dict(type='TestFreezeBNHook', priority=50, freeze_bn_begin=1) + ] + cfg.train_cfg.freeze_bn_begin = 1 + runner = Runner.from_cfg(cfg) + runner.train() + + @HOOKS.register_module(force=True) + class TestDisableObserverHook(Hook): + + def __init__(self, disable_observer_begin): + self.disable_observer_begin = disable_observer_begin + + def after_train_epoch(self, runner): + + def check_observer_stats(mod): + if isinstance(mod, FakeQuantizeBase): + assert mod.fake_quant_enabled[0] == 0 + + if runner.train_loop._epoch + 1 >= self.disable_observer_begin: + runner.model.apply(check_observer_stats) + + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_disable_observer' + cfg.custom_hooks = [ + dict( + type='TestDisableObserverHook', + priority=50, + disable_observer_begin=1) + ] + cfg.train_cfg.disable_observer_begin = 1 + runner = Runner.from_cfg(cfg) + runner.train() + + +class TestLSQEpochBasedLoop(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.temp_dir = tempfile.mkdtemp() + MODELS.register_module(module=MMArchitectureQuant, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + METRICS.register_module(module=ToyMetric1, force=True) + OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True) + + default_cfg = copy.deepcopy(DEFAULT_CFG) + default_cfg = Config(default_cfg) + default_cfg.work_dir = self.temp_dir + default_cfg.train_cfg = ConfigDict( + type='mmrazor.LSQEpochBasedLoop', + max_epochs=4, + val_begin=1, + val_interval=1, + freeze_bn_begin=-1, + dynamic_intervals=None) + self.default_cfg = default_cfg + + def tearDown(self): + MODELS.module_dict.pop('MMArchitectureQuant') + DATASETS.module_dict.pop('ToyDataset') + METRICS.module_dict.pop('ToyMetric1') + OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') + + logging.shutdown() + MMLogger._instance_dict.clear() + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_init_lsq_train_loop' + runner = Runner(**cfg) + self.assertIsInstance(runner, Runner) + self.assertIsInstance(runner.train_loop, LSQEpochBasedLoop) + + def test_run_epoch(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_train' + runner = Runner.from_cfg(cfg) + runner.train() + + @HOOKS.register_module(force=True) + class TestFreezeBNHook(Hook): + + def __init__(self, freeze_bn_begin): + self.freeze_bn_begin = freeze_bn_begin + + def after_train_epoch(self, runner): + + def check_bn_stats(mod): + if isinstance(mod, ConvBnReLU2d): + assert mod.freeze_bn + assert not mod.bn.training + + if runner.train_loop._epoch + 1 >= self.freeze_bn_begin: + runner.model.apply(check_bn_stats) + + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_freeze_bn' + cfg.custom_hooks = [ + dict(type='TestFreezeBNHook', priority=50, freeze_bn_begin=1) + ] + cfg.train_cfg.freeze_bn_begin = 1 + runner = Runner.from_cfg(cfg) + runner.train() + + +class TestQATValLoop(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.temp_dir = tempfile.mkdtemp() + MODELS.register_module(module=MMArchitectureQuant, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + METRICS.register_module(module=ToyMetric1, force=True) + OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True) + + default_cfg = copy.deepcopy(DEFAULT_CFG) + default_cfg = Config(default_cfg) + default_cfg.work_dir = self.temp_dir + default_cfg.val_cfg = ConfigDict(type='mmrazor.QATValLoop') + self.default_cfg = default_cfg + + def tearDown(self): + MODELS.module_dict.pop('MMArchitectureQuant') + DATASETS.module_dict.pop('ToyDataset') + METRICS.module_dict.pop('ToyMetric1') + OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') + + logging.shutdown() + MMLogger._instance_dict.clear() + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_init_qat_val_loop' + runner = Runner(**cfg) + self.assertIsInstance(runner, Runner) + self.assertIsInstance(runner.val_loop, QATValLoop) + + def test_run(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_qat_val' + cfg.pop('train_dataloader') + cfg.pop('train_cfg') + cfg.pop('optim_wrapper') + cfg.pop('test_dataloader') + cfg.pop('test_cfg') + cfg.pop('test_evaluator') + runner = Runner.from_cfg(cfg) + runner.val() + + +class TestPTQLoop(TestCase): + + def setUp(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + self.temp_dir = tempfile.mkdtemp() + MODELS.register_module(module=MMArchitectureQuant, force=True) + DATASETS.register_module(module=ToyDataset, force=True) + METRICS.register_module(module=ToyMetric1, force=True) + OPTIM_WRAPPERS.register_module(module=ToyOptimWrapper, force=True) + + default_cfg = copy.deepcopy(DEFAULT_CFG) + default_cfg = Config(default_cfg) + default_cfg.work_dir = self.temp_dir + # save_checkpoint in PTQLoop need train_dataloader + default_cfg.train_cfg = ConfigDict(by_epoch=True, max_epochs=3) + default_cfg.test_cfg = ConfigDict( + type='mmrazor.PTQLoop', + calibrate_dataloader=default_cfg.train_dataloader, + calibrate_steps=32) + self.default_cfg = default_cfg + + def tearDown(self): + MODELS.module_dict.pop('MMArchitectureQuant') + DATASETS.module_dict.pop('ToyDataset') + METRICS.module_dict.pop('ToyMetric1') + OPTIM_WRAPPERS.module_dict.pop('ToyOptimWrapper') + + logging.shutdown() + MMLogger._instance_dict.clear() + shutil.rmtree(self.temp_dir) + + def test_init(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_init_ptq_loop' + runner = Runner(**cfg) + self.assertIsInstance(runner, Runner) + self.assertIsInstance(runner.test_loop, PTQLoop) + + def test_run(self): + cfg = copy.deepcopy(self.default_cfg) + cfg.experiment_name = 'test_ptq_run' + runner = Runner.from_cfg(cfg) + runner.test() From 2a61cd0fb985d166e4cbb587dec5e60440f8cdd1 Mon Sep 17 00:00:00 2001 From: Ivan Zhang <51170394+415905716@users.noreply.github.com> Date: Wed, 18 Jan 2023 17:55:28 +0800 Subject: [PATCH 13/44] [Docs] Add customize_quantization_tutorial (#440) --- .../customize_quantization_algorithms.md | 283 ++++++++++++++++++ docs/en/advanced_guides/index.rst | 1 + 2 files changed, 284 insertions(+) create mode 100644 docs/en/advanced_guides/customize_quantization_algorithms.md diff --git a/docs/en/advanced_guides/customize_quantization_algorithms.md b/docs/en/advanced_guides/customize_quantization_algorithms.md new file mode 100644 index 000000000..b9a4d05bb --- /dev/null +++ b/docs/en/advanced_guides/customize_quantization_algorithms.md @@ -0,0 +1,283 @@ +# Customize Quantization algorithms + +Here we show how to develop new QAT algorithms with an example of LSQ on OpenVINO backend. + +This document is mainly aimed at QAT because the ptq process is relatively fixed and the components we provide can meet most of the needs. We will first give an overview of the overall required development components, and then introduce the specific implementation step by step. + +## Overall + +In the mmrazor quantization pipeline, in order to better support the openmmlab environment, we have configured most of the code modules for users. You can configure all the components directly in the config file. How to configure them can be found in our [file](https://github.com/open-mmlab/mmrazor/blob/quantize/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py). + +```Python +global_qconfig = dict( + w_observer=dict(), + a_observer=dict(), + w_fake_quant=dict(), + a_fake_quant=dict(), + w_qscheme=dict(), + a_qscheme=dict(), +) +model = dict( + type='mmrazor.MMArchitectureQuant', + architecture=resnet, + quantizer=dict( + type='mmrazor.OpenvinoQuantizer', + global_qconfig=global_qconfig, + tracer=dict())) +train_cfg = dict(type='mmrazor.LSQEpochBasedLoop') +``` + +For `algorithm` and `tracer`, we recommend that you use the default configurations `MMArchitectureQuant` and `CustomTracer` provided by us. These two module operators are specially built for the openmmlab environment, while other modules can refer to the following steps and choose or develop new operators according to your needs. + +To adapt to different backends, you need to select a different `quantizer`. + +To develop new quantization algorithms, you need to define new `observer` and `fakequant`. + +If the existing `loop` does not meet your needs, you may need to make some changes to the existing `loop` based on your algorithm. + +## Detailed steps + +1. Select a quantization algorithm + +We recommend that you directly use the`MMArchitectureQuant` in `mmrazor/models/algorithms/quantization/mm_architecture.py`.The class `MMArchitectureQuant` inherits from class `BaseAlgorithm`. + +This structure is built for the model in openmmlab. If you have other requirements, you can also refer to this [document](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/customize_architectures.html#develop-common-model-components) to design the overall framework. + +2. Select quantizer + +At present, the quantizers we support are `NativeQuantizer`, `OpenVINOQuantizer`, `TensorRTQuantizer` and `AcademicQuantizer` in `mmrazor/models/quantizers/`. `AcademicQuantizer` and `NativeQuantizer` inherit from class `BaseQuantizer` in `mmrazor/models/quantizers/base.py`: + +```Python +class BaseQuantizer(BaseModule): + def __init__(self, tracer): + super().__init__() + self.tracer = TASK_UTILS.build(tracer) + @abstractmethod + def prepare(self, model, graph_module): + """tmp.""" + pass + def swap_ff_with_fxff(self, model): + pass +``` + +`NativeQuantizer` is the operator we developed to adapt to the environment of mmrazor according to pytorch's official quantization logic. `AcademicQuantizer` is an operator designed for academic research to give users more space to operate. + +The class `OpenVINOQuantizer` and `TensorRTQuantizer` inherits from class `NativeQuantize`. They adapted `OpenVINO` and `TensorRT`backend respectively. You can also try to develop a quantizer based on other backends according to your own needs. + +3. Select tracer + +Tracer we use `CustomTracer` in `mmrazor/models/task_modules/tracer/fx/custom_tracer.py`. You can inherit this class and customize your own tracer. + +4. Develop new fakequant method(optional) + +You can use fakequants provided by pytorch in `mmrazor/models/fake_quants/torch_fake_quants.py` as core functions provider. If you want to use the fakequant methods from other papers, you can also define them yourself. Let's take lsq as an example as follows: + +a.Create a new file `mmrazor/models/fake_quants/lsq.py`, class `LearnableFakeQuantize` inherits from class `FakeQuantizeBase`. + +b. Finish the functions you need, eg: `observe_quant_params`, `calculate_qparams` and so on. + +```Python +from mmrazor.registry import MODELS +from torch.ao.quantization import FakeQuantizeBase + +@MODELS.register_module() +class LearnableFakeQuantize(FakeQuantizeBase): + def __init__(self, + observer, + quant_min=0, + quant_max=255, + scale=1., + zero_point=0., + use_grad_scaling=True, + zero_point_trainable=False, + **observer_kwargs): + super(LearnableFakeQuantize, self).__init__() + pass + + def observe_quant_params(self): + pass + + def calculate_qparams(self): + pass + + def forward(self, X): + pass +``` + +c.Import the module in `mmrazor/models/fake_quants/__init__.py`. + +```Python +from .lsq import LearnableFakeQuantize + +__all__ = ['LearnableFakeQuantize'] +``` + +5. Develop new observer(optional) + +You can directly use observers provided by pytorch in `mmrazor/models/observers/torch_observers.py` or use observers customized by yourself. Let's take `LSQObserver` as follows: + +a.Create a new observer file `mmrazor/models/observers/lsq.py`, class `LSQObserver` inherits from class `MinMaxObserver` and `LSQObserverMixIn`. These two observers can calculate `zero_point` and `scale`, respectively. + +b.Finish the functions you need, eg: `calculate_qparams` and so on. + +```Python +from mmrazor.registry import MODELS +from torch.ao.quantization.observer import MinMaxObserver + +class LSQObserverMixIn: + def __init__(self): + self.tensor_norm = None + + @torch.jit.export + def _calculate_scale(self): + scale = 2 * self.tensor_norm / math.sqrt(self.quant_max) + sync_tensor(scale) + return scale + +@MODELS.register_module() +class LSQObserver(MinMaxObserver, LSQObserverMixIn): + """LSQ observer. + Paper: Learned Step Size Quantization. + """ + def __init__(self, *args, **kwargs): + MinMaxObserver.__init__(self, *args, **kwargs) + LSQObserverMixIn.__init__(self) + + def forward(self, x_orig): + """Records the running minimum, maximum and tensor_norm of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + self.tensor_norm = x.abs().mean() + min_val_cur, max_val_cur = torch.aminmax(x) + min_val = torch.min(min_val_cur, self.min_val) + max_val = torch.max(max_val_cur, self.max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + return x_orig + + @torch.jit.export + def calculate_qparams(self): + """Calculates the quantization parameters.""" + _, zero_point = MinMaxObserver.calculate_qparams(self) + scale = LSQObserverMixIn._calculate_scale(self) + return scale, zero_point +``` + +c.Import the module in `mmrazor/models/observers/__init__.py` + +```Python +from .lsq import LSQObserver + +__all__ = ['LSQObserver'] +``` + +6. Select loop or develop new loop + +At present, the QAT loops we support are `PTQLoop` and `QATEpochBasedLoop`, in `mmrazor/engine/runner/quantization_loops.py`. We can develop a new `LSQEpochBasedLoop` inherits from class `QATEpochBasedLoop` and finish the functions we need in LSQ method. + +```Python +from mmengine.runner import EpochBasedTrainLoop + +@LOOPS.register_module() +class LSQEpochBasedLoop(QATEpochBasedLoop): + def __init__( + self, + runner, + dataloader: Union[DataLoader, Dict], + max_epochs: int, + val_begin: int = 1, + val_interval: int = 1, + freeze_bn_begin: int = -1, + dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None: + super().__init__( + runner, + dataloader, + max_epochs, + val_begin, + val_interval, + freeze_bn_begin=freeze_bn_begin, + dynamic_intervals=dynamic_intervals) + + self.is_first_batch = True + + def prepare_for_run_epoch(self): + pass + + def prepare_for_val(self): + pass + + def run_epoch(self) -> None: + pass +``` + +And then Import the module in `mmrazor/engine/runner/__init__.py` + +```Python +from .quantization_loops import LSQEpochBasedLoop + +__all__ = ['LSQEpochBasedLoop'] +``` + +7. Use the algorithm in your config file + +After completing the above steps, we have all the components of the qat algorithm, and now we can combine them in the config file. + +a.First, `_base_` stores the location of the model that needs to be quantized. + +b.Second, configure observer,fakequant and qscheme in `global_qconfig` in detail. +You can configure the required quantization bit width and quantization methods in `qscheme`, such as symmetric quantization or asymmetric quantization. + +c.Third, build the whole mmrazor model in `model`. + +d.Finally, complete all the remaining required configuration files. + +```Python +_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] + +global_qconfig = dict( + w_observer=dict(type='mmrazor.LSQPerChannelObserver'), + a_observer=dict(type='mmrazor.LSQObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=resnet, + float_checkpoint=float_ckpt, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + is_qat=True, + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +# learning policy +optim_wrapper = dict() +param_scheduler = dict() +model_wrapper_cfg = dict() + +# train, val, test setting +train_cfg = dict(type='mmrazor.LSQEpochBasedLoop') +val_cfg = dict() +test_cfg = val_cfg +``` diff --git a/docs/en/advanced_guides/index.rst b/docs/en/advanced_guides/index.rst index 7d46576ef..349dc5902 100644 --- a/docs/en/advanced_guides/index.rst +++ b/docs/en/advanced_guides/index.rst @@ -20,5 +20,6 @@ Development tutorials customize_nas_algorithms.md customize_pruning_algorithms.md customize_kd_algorithms.md + customize_quantization_algorithms.md customize_mixed_algorithms.md apply_existing_algorithms_to_new_tasks.md From 6cdb394e6fbc160d110f543097875f9570bf1b5d Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Mon, 30 Jan 2023 15:37:22 +0800 Subject: [PATCH 14/44] [Docs] Add quantization user guide (#441) * add quantization user guide * fix layout * fix layout * update README --- README.md | 14 +- docs/en/user_guides/index.rst | 11 +- .../en/user_guides/quantization_user_guide.md | 223 ++++++++++++++++++ 3 files changed, 240 insertions(+), 8 deletions(-) create mode 100644 docs/en/user_guides/quantization_user_guide.md diff --git a/README.md b/README.md index 26a5fe1a9..77e12092b 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ [![PyPI](https://img.shields.io/pypi/v/mmrazor)](https://pypi.org/project/mmrazor) -[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmrazor.readthedocs.io/en/dev-1.x/) +[![docs](https://img.shields.io/badge/docs-latest-blue)](https://mmrazor.readthedocs.io/en/quantize/) [![badge](https://github.com/open-mmlab/mmrazor/workflows/build/badge.svg)](https://github.com/open-mmlab/mmrazor/actions) [![codecov](https://codecov.io/gh/open-mmlab/mmrazor/branch/master/graph/badge.svg)](https://codecov.io/gh/open-mmlab/mmrazor) [![license](https://img.shields.io/github/license/open-mmlab/mmrazor.svg)](https://github.com/open-mmlab/mmrazor/blob/master/LICENSE) @@ -32,9 +32,9 @@ -[📘Documentation](https://mmrazor.readthedocs.io/en/dev-1.x/) | -[🛠️Installation](https://mmrazor.readthedocs.io/en/dev-1.x/get_started/installation.html) | -[👀Model Zoo](https://mmrazor.readthedocs.io/en/dev-1.x/get_started/model_zoo.html) | +[📘Documentation](https://mmrazor.readthedocs.io/en/quantize/) | +[🛠️Installation](https://mmrazor.readthedocs.io/en/quantize/get_started/installation.html) | +[👀Model Zoo](https://mmrazor.readthedocs.io/en/quantize/get_started/model_zoo.html) | [🤔Reporting Issues](https://github.com/open-mmlab/mmrazor/issues/new/choose) @@ -68,7 +68,7 @@ MMRazor is a model compression toolkit for model slimming and AutoML, which incl - Neural Architecture Search (NAS) - Pruning - Knowledge Distillation (KD) -- Quantization (come soon) +- Quantization It is a part of the [OpenMMLab](https://openmmlab.com/) project. @@ -86,7 +86,7 @@ Major features: With better modular design, developers can implement new model compression algorithms with only a few codes, or even by simply modifying config files. -Below is an overview of MMRazor's design and implementation, please refer to [tutorials](https://mmrazor.readthedocs.io/en/dev-1.x/get_started/overview.html) for more details. +Below is an overview of MMRazor's design and implementation, please refer to [tutorials](https://mmrazor.readthedocs.io/en/quantize/get_started/overview.html) for more details.
@@ -164,7 +164,7 @@ Please refer to [installation.md](/docs/en/get_started/installation.md) for more ## Getting Started -Please refer to [user guides](https://mmrazor.readthedocs.io/en/dev-1.x/user_guides/index.html) for the basic usage of MMRazor. There are also [advanced guides](https://mmrazor.readthedocs.io/en/dev-1.x/advanced_guides/index.html): +Please refer to [user guides](https://mmrazor.readthedocs.io/en/quantize/user_guides/index.html) for the basic usage of MMRazor. There are also [advanced guides](https://mmrazor.readthedocs.io/en/quantize/advanced_guides/index.html): ## Contributing diff --git a/docs/en/user_guides/index.rst b/docs/en/user_guides/index.rst index 622987867..96ebc0a6e 100644 --- a/docs/en/user_guides/index.rst +++ b/docs/en/user_guides/index.rst @@ -10,6 +10,15 @@ Train & Test 3_train_with_different_devices.md 4_test_a_model.md +Quantization +************ + +.. toctree:: + :maxdepth: 1 + + quantization_user_guide.md + Useful Tools ************ - please refer to upstream applied repositories' docs + +please refer to upstream applied repositories' docs diff --git a/docs/en/user_guides/quantization_user_guide.md b/docs/en/user_guides/quantization_user_guide.md new file mode 100644 index 000000000..d645d8451 --- /dev/null +++ b/docs/en/user_guides/quantization_user_guide.md @@ -0,0 +1,223 @@ +# Quantization + +## Introduction + +MMRazor's quantization is OpenMMLab's quantization toolkit, which has got through task models and model deployment. With its help, we can quantize and deploy pre-trained models in OpenMMLab to specified backend quickly. Of course, it can also contribute to implementing some custom quantization algorithms easier. + +### Major features + +- **Ease of use**. Benefited from PyTorch fx, we can quantize our model without modifying the original model, but with user-friendly config. +- **Multiple backends deployment support**. Because of the specificity of each backend, a gap in performance usually exists between before and after deployment. We provided some common backend deployment support to reduce the gap as much. +- **Multiple task repos support.** Benefited from OpenMMLab 2.0, our quantization can support all task repos of OpenMMLab without extra code. +- **Be compatible with PyTorch's core module in quantization**. Some core modules in PyTorch can be used directly in mmrazor, such as `Observer`, `FakeQuantize`, `BackendConfig` and so on. + +## Quick run + +```{note} +MMRazor's quantization is based on `torch==1.13`. Other requirements are the same as MMRazor's +``` + +Model quantization is in mmrazor, but quantized model deployment is in mmdeploy. So we need to use two branches as follows: + +mmrazor: https://github.com/open-mmlab/mmrazor/tree/quantize + +mmdeploy: https://github.com/humu789/mmdeploy/tree/adapt_razor_quantize + +1. Quantize the float model in mmrazor. + +```Shell +# For QAT (Quantization Aware Training) +python tools/train.py ${CONFIG_FILE} [optional arguments] + +# For PTQ (Post-training quantization) +python tools/ptq.py ${CONFIG_FILE} [optional arguments] +``` + +2. Convert quantized model checkpoint in mmrazor. (required by model deployment) + +```Shell +python tools/model_converters/convert_quant_ckpt.py ${CKPT_PATH} +``` + +3. Export quantized model to a specific backend in mmdeploy. (required by model deployment) + +```Shell +python ./tools/deploy.py \ + ${DEPLOY_CFG_PATH} \ + ${MODEL_CFG_PATH} \ + ${MODEL_CHECKPOINT_PATH} \ + ${INPUT_IMG} \ + [optional arguments] +``` + +This step is the same as how to export an OpenMMLab model to a specific backend. For more details, please refer to [How to convert model](https://github.com/open-mmlab/mmdeploy/blob/master/docs/en/02-how-to-run/convert_model.md) + +4. Evaluate the exported model. (optional) + +```Shell +python tools/test.py \ + ${DEPLOY_CFG} \ + ${MODEL_CFG} \ + --model ${BACKEND_MODEL_FILES} \ + [optional arguments] +``` + +This step is the same as evaluating backend models. For more details, please refer to [How to evaluate model](https://github.com/open-mmlab/mmdeploy/blob/master/docs/en/02-how-to-run/profile_model.md) + +## How to quantize your own model quickly + +If you want to try quantize your own model quickly, you just need to learn about how to change our provided config. + +**Case 1: If the model you want to quantize is in our provided configs.** + +You can refer to the previous chapter Quick Run. + +**Case 2: If the model you want to quantize is not in our provided configs.** + +Let us take `resnet50` as an example to show how to handle case 2. + +```Python +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +train_dataloader = dict(batch_size=32) + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=train_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=_base_.model, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', ) +``` + +This is a config that quantize `resnet18` with OpenVINO backend. You just need to modify two args: `_base_` and `float_checkpoint`. + +```Python +# before +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' + +# after +_base_ = ['mmcls::resnet/resnet50_8xb32_in1k.py'] +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' +``` + +- `_base_` will be called from mmcls by mmengine, so you can just use mmcls provided configs directly. Other repos are similar. +- `float_checkpoint ` is a pre-trained float checkpoint by OpenMMLab. You can find it in the corresponding repo. + +After modifying required config, we can use it the same as case 1. + +## How to improve your quantization performance + +If you can not be satisfied with quantization performance by applying our provided configs to your own model, you can try to improve it with our provided various quantization schemes by modifying `global_qconfig`. + +```Python +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) +``` + +As shown above, `global_qconfig` contains server common core args as follows: + +- Observes + +In `forward`, they will update the statistics of the observed Tensor. And they should provide a `calculate_qparams` function that computes the quantization parameters given the collected statistics. + +```{note} +Whether it is per channel quantization depends on whether `PerChannel` is in the observer name. +``` + +Because mmrazor's quantization has been compatible with PyTorch's observers, we can use observers in PyTorch and our custom observers. + +Supported observers list in Pytorch. + +```Python +FixedQParamsObserver +HistogramObserver +MinMaxObserver +MovingAverageMinMaxObserver +MovingAveragePerChannelMinMaxObserver +NoopObserver +ObserverBase +PerChannelMinMaxObserver +PlaceholderObserver +RecordingObserver +ReuseInputObserver +UniformQuantizationObserverBase +``` + +- Fake quants + +In `forward`, they will update the statistics of the observed Tensor and fake quantize the input. They should also provide a `calculate_qparams` function that computes the quantization parameters given the collected statistics. + +Because mmrazor's quantization has been compatible with PyTorch's fakequants, we can use fakequants in PyTorch and our custom fakequants. + +Supported fakequants list in Pytorch. + +```Python +FakeQuantize +FakeQuantizeBase +FixedQParamsFakeQuantize +FusedMovingAvgObsFakeQuantize +``` + +- Qschemes + +Include some basic quantization configurations. + +`qdtype`: to specify whether quantized data type is sign or unsign. It can be chosen from \[ 'qint8', 'quint8' \] + +`bit`: to specify the quantized data bit. It can be chosen from \[1 ~ 16\]. + +`is_symmetry`: to specify whether to use symmetry quantization. It can be chosen from \[ True, False \] + +The specified qscheme is actually implemented by observers, so how to configurate other args needs to be based on the given observers, such as `is_symmetric_range` and `averaging_constant`. + +## How to customize your quantization algorithm + +If you try to customize your quantization algorithm, you can refer to the following link for more details. + +[Customize Quantization algorithms](https://github.com/open-mmlab/mmrazor/blob/quantize/docs/en/advanced_guides/customize_quantization_algorithms.md) From 0cd361d20bfaf9c7fcbe93304a255a639e70aa45 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Tue, 7 Feb 2023 18:08:44 +0800 Subject: [PATCH 15/44] [Bug] Fix del redundant fakequant (#447) fix del redundant fakequant --- mmrazor/models/quantizers/native_quantizer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index 1d566b45f..f24e8d538 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -313,28 +313,28 @@ def module_prev_wo_fakequant(self): extra_function_prev_wo_fakequant = self.extra_redundant_fakequants.get( 'extra_function_prev_wo_fakequant', tuple()) - prepared = del_fakequant_before_method( + prepared = del_fakequant_before_function( prepared, self.function_prev_wo_fakequant + extra_function_prev_wo_fakequant, inplace=True) extra_function_next_wo_fakequant = self.extra_redundant_fakequants.get( 'extra_function_next_wo_fakequant', tuple()) - prepared = del_fakequant_after_method( + prepared = del_fakequant_after_function( prepared, self.function_next_wo_fakequant + extra_function_next_wo_fakequant, inplace=True) extra_method_prev_wo_fakequant = self.extra_redundant_fakequants.get( 'extra_method_prev_wo_fakequant', tuple()) - prepared = del_fakequant_before_function( + prepared = del_fakequant_before_method( prepared, self.method_prev_wo_fakequant + extra_method_prev_wo_fakequant, inplace=True) extra_method_next_wo_fakequant = self.extra_redundant_fakequants.get( 'extra_method_next_wo_fakequant', tuple()) - prepared = del_fakequant_after_function( + prepared = del_fakequant_after_method( prepared, self.method_next_wo_fakequant + extra_method_next_wo_fakequant, inplace=True) From d6a7ea5fdd6b61aaa730ec0a31ec935dfc24facd Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Wed, 15 Mar 2023 14:09:40 +0800 Subject: [PATCH 16/44] [Feature] Add onnx exporters (#475) * fix del redundant fakequant * add onnx exporters * fix onnx exporters and add docstring * fix comments * delete useless codes * fix export_onnx in native quantizer --------- Co-authored-by: pppppM --- .../quantization/mm_architecture.py | 32 +- .../models/quantizers/exporters/__init__.py | 5 + .../exporters/base_quantize_exporter.py | 164 ++++++++++ .../exporters/openvino_quantize_exporter.py | 152 ++++++++++ .../quantizers/exporters/optim_utils.py | 281 ++++++++++++++++++ .../exporters/tensorrt_quantize_exporter.py | 44 +++ mmrazor/models/quantizers/native_quantizer.py | 52 +++- .../models/quantizers/openvino_quantizer.py | 48 ++- 8 files changed, 735 insertions(+), 43 deletions(-) create mode 100644 mmrazor/models/quantizers/exporters/__init__.py create mode 100644 mmrazor/models/quantizers/exporters/base_quantize_exporter.py create mode 100644 mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py create mode 100644 mmrazor/models/quantizers/exporters/optim_utils.py create mode 100644 mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index d3b0be089..ff99592d6 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -13,12 +13,14 @@ try: from torch.ao.quantization import (FakeQuantizeBase, MinMaxObserver, - PerChannelMinMaxObserver) + PerChannelMinMaxObserver, + disable_observer) except ImportError: from mmrazor.utils import get_placeholder FakeQuantizeBase = get_placeholder('torch>=1.13') MinMaxObserver = get_placeholder('torch>=1.13') PerChannelMinMaxObserver = get_placeholder('torch>=1.13') + disable_observer = get_placeholder('torch>=1.13') LossResults = Dict[str, torch.Tensor] TensorResults = Union[Tuple[torch.Tensor], torch.Tensor] @@ -213,6 +215,34 @@ def calibrate_step(self, data: Union[Dict, Tuple, List]): data = self.data_preprocessor(data, False) return self._run_forward(data, mode='predict') + def post_process_for_mmdeploy(self, dummy_input: Tuple = (1, 3, 224, 224)): + """Prepare for deploy to the backend with mmdeploy, which will be used + in mmdeploy, and usually includes as follows: + + 1. prepare for the float model rewritten by mmdeploy. + 2. load checkpoint consists of float weight and quantized params in + mmrazor. + 3. post process weight fakequant for exporting .onnx that meet + the backend's requirement. + """ + + quantized_state_dict = self.qmodels['tensor'].state_dict() + fp32_model = self.architecture + self.quantizer.convert_batchnorm2d(fp32_model) + observed_model = self.quantizer.prepare(fp32_model, {'mode': 'tensor'}) + + if dummy_input is not None: + observed_model(torch.randn(dummy_input)) + + observed_model.load_state_dict(quantized_state_dict) + + self.quantizer.post_process_for_deploy( + observed_model, keep_w_fake_quant=True) + + observed_model.apply(disable_observer) + + return observed_model + @MODEL_WRAPPERS.register_module() class MMArchitectureQuantDDP(MMDistributedDataParallel): diff --git a/mmrazor/models/quantizers/exporters/__init__.py b/mmrazor/models/quantizers/exporters/__init__.py new file mode 100644 index 000000000..b8153289d --- /dev/null +++ b/mmrazor/models/quantizers/exporters/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .openvino_quantize_exporter import OpenVinoQuantizeExportor +from .tensorrt_quantize_exporter import TensorRTExplicitExporter + +__all__ = ['OpenVinoQuantizeExportor', 'TensorRTExplicitExporter'] diff --git a/mmrazor/models/quantizers/exporters/base_quantize_exporter.py b/mmrazor/models/quantizers/exporters/base_quantize_exporter.py new file mode 100644 index 000000000..7e1e1f375 --- /dev/null +++ b/mmrazor/models/quantizers/exporters/base_quantize_exporter.py @@ -0,0 +1,164 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import onnx +from mmengine import print_log +from onnx import numpy_helper + +from .optim_utils import ONNXOptimUtils + +SUPPORT_QWEIGHT_NODE = ['Gemm', 'Conv', 'ConvTranspose'] + +PERCHANNEL_FAKEQUANTIZER = [ + 'FakeQuantizeLearnablePerchannelAffine', 'FixedPerChannelAffine' +] +PERTENSOR_FAKEQUANTIZER = ['LearnablePerTensorAffine', 'FixedPerTensorAffine'] + +ALL_FAKEQUANTIZER = PERCHANNEL_FAKEQUANTIZER + PERTENSOR_FAKEQUANTIZER + + +def _parse_attrs(node_attrs): + attrs = {} + for attr in node_attrs: + if attr.type == onnx.AttributeProto.AttributeType.INTS: + attrs[attr.name] = tuple(attr.ints) + elif attr.type == onnx.AttributeProto.AttributeType.INT: + attrs[attr.name] = attr.i + elif attr.type == onnx.AttributeProto.AttributeType.FLOATS: + attrs[attr.name] = tuple(attr.floats) + elif attr.type == onnx.AttributeProto.AttributeType.FLOAT: + attrs[attr.name] = attr.f + elif attr.type == onnx.AttributeProto.AttributeType.TENSOR: + attrs[attr.name] = numpy_helper.to_array(attr.t) + elif attr.type == onnx.AttributeProto.AttributeType.STRING: + attrs[attr.name] = str(attr.s) + elif attr.type == onnx.AttributeProto.AttributeType.STRINGS: + attrs[attr.name] = tuple([str(x) for x in attr.strings]) + else: + raise Exception('ATTR Type [{}] Not Supported!'.format(attr.type)) + return attrs + + +class BaseQuantizeExportor(): + + optimizer = ONNXOptimUtils + + def __init__(self, onnx_model, export_path) -> None: + + if isinstance(onnx_model, str): + self.onnx_model = onnx.load(onnx_model) + elif isinstance(onnx_model, onnx.ModelProto): + self.onnx_model = onnx_model + else: + raise TypeError + + self.export_path = export_path + self._init_mappings_from_onnx(self.onnx_model) + + self.optimizer.remove_fake_pad_op(self.onnx_model, self.name2data, + self.input2node, self.output2node) + + self._remap_input_and_node() + self._remap_output_and_node() + + @property + def graph(self): + """The onnx model's graph.""" + return self.onnx_model.graph + + def _init_mappings_from_onnx(self, onnx_model): + """Build necessary mappings in a onnx model.""" + + self.input2node = self.optimizer.map_input_and_node(onnx_model) + self.output2node = self.optimizer.map_output_and_node(onnx_model) + self.name2data = self.optimizer.map_name_and_data(onnx_model) + + # todo: maybe useless + # self.name2init = self.optimizer.map_name_and_initializer(onnx_model) + + def _remap_input_and_node(self): + """Rebuild the mapping from input name to a (node, input index) + tuple.""" + self.input2node = self.optimizer.map_input_and_node(self.onnx_model) + + def _remap_output_and_node(self): + """Rebuild the mapping from a node's output name to this node.""" + self.output2node = self.optimizer.map_output_and_node(self.onnx_model) + + def parse_qparams(self, node: onnx.NodeProto): + """Parse the quantize-related parameters based on a node.""" + tensor_name, scale, zero_point = node.input[:3] + + scale, zero_point = self.name2data[scale], self.name2data[zero_point] + if len(node.input) > 3: + qmin, qmax = node.input[-2:] + qmin, qmax = self.name2data[qmin], self.name2data[qmax] + elif len(node.attribute) > 0: + qparams = _parse_attrs(node.attribute) + qmin = qparams['quant_min'] + qmax = qparams['quant_max'] + else: + print_log(f'qmin and qmax are not found for <{node.name}>!') + qmax = qmin = None + return tensor_name, scale, zero_point, qmin, qmax + + def collect_symbolic_nodes(self, onnx_model: onnx.ModelProto): + """Collect all the fakequant nodes from a onnx model.""" + symbolic_nodes = list() + for node in onnx_model.graph.node: + if node.op_type in ALL_FAKEQUANTIZER: + symbolic_nodes.append(node) + return symbolic_nodes + + def _get_constant_inputs(self, node: onnx.NodeProto): + """Get the constant input node for the current node.""" + constant_nodes = list() + output2node = self.output2node + for inp in node.input: + if inp in output2node and output2node[inp].op_type == 'Constant': + cnode = output2node[inp] + + constant_nodes.append(cnode) + return constant_nodes + + def _collect_symbolic_constant_inputs(self, symbolic_nodes: List): + """Collect these constant nodes which is the input of all the symbolic + node.""" + + collected_constant_names = set() + constant_inputs = list() + for node in symbolic_nodes: + constant_inputs = self._get_constant_inputs(node) + for constant in constant_inputs: + if constant.name in collected_constant_names: + continue + constant_inputs.append(constant) + collected_constant_names.add(constant.name) + return constant_inputs + + def _remove_symbolic_related_from_onnx(self, symbolic_nodes: List, + symbolic_constant_inputs: List): + """Remove these out of date fakequant nodes and theirs constant input + nodes.""" + for node in symbolic_nodes: + self.onnx_model.graph.node.remove(node) + + # Remove symbolic related constant nodes. The constant node which is + # only used by those symbolic nodes can be removed. + + def _is_standalone_constant_node(constant): + for node in self.onnx_model.graph.node: + for input_name in node.input: + # A constant node always has one output. + if input_name == constant.output[0]: + return False + return True + + for constant in symbolic_constant_inputs: + if _is_standalone_constant_node(constant): + self.onnx_model.graph.node.remove(constant) + + def export(self): + """Export end to end onnx model.""" + # todo: is it a abstract method? + raise NotImplementedError diff --git a/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py b/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py new file mode 100644 index 000000000..6d0df5d36 --- /dev/null +++ b/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py @@ -0,0 +1,152 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from typing import List + +import numpy as np +import onnx +from google.protobuf.internal.containers import RepeatedScalarFieldContainer +from onnx import helper, numpy_helper + +from .base_quantize_exporter import BaseQuantizeExportor + + +class OpenVinoQuantizeExportor(BaseQuantizeExportor): + + def __init__(self, onnx_model, export_path) -> None: + super().__init__(onnx_model, export_path) + + def _build_backend_node_from_symbolic(self, node: onnx.NodeProto, + tensor_name: str, qmin: np.ndarray, + qmax: np.ndarray): + """Build new onnx nodes which can be deployed to the specific backend. + + These nodes will be used to replace those symbolic nodes in the + original onnx model. + """ + qmax = int(qmax) + qmin = int(qmin) + levels = qmax - qmin + 1 + # adjust weight levels + # if levels == 128: + # levels = 256 + # qmax = qmax * 2 + 1 + # qmin = qmin * 2 + output_name = node.output[0] + # Create a node (FakeQuantize) + keys = ['input_min', 'input_max', 'output_min', 'output_max'] + input_names = [f'{tensor_name}_{key}' for key in keys] + backend_node = helper.make_node( + 'FakeQuantize', # node name + [tensor_name, *input_names], # inputs + [output_name], # outputs + levels=levels, # Attributes + domain='org.openvinotoolkit', + name=node.name) + return backend_node + + def _build_backend_initializer(self, + names: RepeatedScalarFieldContainer[str], + scale: np.ndarray, zero_point: np.ndarray, + qmin: np.ndarray, qmax: np.ndarray, + shape: List[int]): + """Build onnx initializers which can be deployed to specific + backend.""" + + scale = np.abs(np.asarray(scale, dtype=np.float64).reshape(-1)) + zero_point = np.clip( + np.asarray(np.round(zero_point), dtype=np.int32).reshape(-1), + a_min=qmin, + a_max=qmax) + + qrange = float(qmax - qmin) + input_range = scale * qrange + input_high = (qmax - zero_point).astype( + np.float64) * input_range / qrange + input_low = input_high - input_range + input_low_size = input_low.size + + if input_low_size != 1: + input_low = input_low.reshape(*shape) + input_high = input_high.reshape(*shape) + + input_low = input_low.astype(np.float32) + input_high = input_high.astype(np.float32) + + initializers = list() + for init_name, value_tensor in zip( + names, [input_low, input_high, input_low, input_high]): + init = numpy_helper.from_array(value_tensor) + init.name = init_name + initializers.append(init) + return initializers + + def build_backend_nodes_and_initializers(self, symbolic_nodes: List): + """Build new onnx nodes and initializers which can be deployed to + specific backend.""" + backend_nodes = list() + backend_initializers = list() + for node in symbolic_nodes: + tensor_name, scale, zero_point, qmin, qmax = self.parse_qparams( + node) + new_node = self._build_backend_node_from_symbolic( + node, tensor_name, qmin, qmax) + backend_nodes.append(new_node) + + try: + # If the successor node (such as a conv node) has weight, + # we need get the length of the weight's shape. And ensure + # the length of the weight's shape and the new node's + # input shape (such as input_low and input_high) is the same. + next_node = self.input2node[node.output[0]][0][0] + # node for save weights + fake_node = self.output2node[next_node.input[1]] + tensor = self.name2data[fake_node.input[0]] + shape_length = len(tensor.shape) + new_shape = [-1] + [1] * (shape_length - 1) + except Exception: + new_shape = [-1] + + # The first element of new_node.input is the tensor name. + new_init_names = new_node.input[1:] + new_initializers = self._build_backend_initializer( + new_init_names, scale, zero_point, qmin, qmax, new_shape) + backend_initializers.extend(new_initializers) + return backend_nodes, backend_initializers + + def _insert_initializers_to_onnx(self, initializers: List): + """Insert onnx initializers to the onnx graph.""" + inserted_init_names = set() + for init in initializers: + if init.name in inserted_init_names: + continue + + self.onnx_model.graph.initializer.append(init) + inserted_init_names.add(init.name) + + def _replace_symbolic_related(self): + """Replacing symbolic related nodes and initializers in the original + onnx model with new nodes and initializers that can be deployed to the + specific backend.""" + + symbolic_nodes = self.collect_symbolic_nodes(self.onnx_model) + + collect_func = self._collect_symbolic_constant_inputs + # Usually different activation fakequants share the same constant + # input, and different weight fakequants share the same constant input. + symbolic_constant_inputs = collect_func(symbolic_nodes) + + build_func = self.build_backend_nodes_and_initializers + new_nodes, new_initializers = build_func(symbolic_nodes) + + self._insert_initializers_to_onnx(new_initializers) + + self._remove_symbolic_related_from_onnx(symbolic_nodes, + symbolic_constant_inputs) + + self.onnx_model.graph.node.extend(new_nodes) + self.optimizer.optimize(self.onnx_model) + + def export(self): + """Export end to end onnx model.""" + self._replace_symbolic_related() + onnx.save(self.onnx_model, self.export_path) diff --git a/mmrazor/models/quantizers/exporters/optim_utils.py b/mmrazor/models/quantizers/exporters/optim_utils.py new file mode 100644 index 000000000..62b348d1c --- /dev/null +++ b/mmrazor/models/quantizers/exporters/optim_utils.py @@ -0,0 +1,281 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Dict, List, Optional + +import onnx +from mmengine import print_log +from onnx import numpy_helper + + +class ONNXOptimUtils(): + + @classmethod + def map_name_and_data(cls, onnx_model: onnx.ModelProto): + """Build the mapping from a data's name to the data itself.""" + params = {} + for init in onnx_model.graph.initializer: + params[init.name] = numpy_helper.to_array(init) + for node in onnx_model.graph.node: + # If two zero_points are identity, one is a reference to the other + # after optimized by onnx. + if node.op_type == 'Identity' and len(node.input) == 1 and \ + node.input[0] in params: + params[node.output[0]] = copy.deepcopy(params[node.input[0]]) + if node.op_type == 'Constant': + for attr in node.attribute: + if attr.name == 'value': + params[node.output[0]] = numpy_helper.to_array(attr.t) + return params + + @classmethod + def map_name_and_initializer(cls, + onnx_model: onnx.ModelProto, + allow_redundant=True): + """Build the mapping from a initializer's output name to this + initializer.""" + + initializers = dict() + + for idx, init in enumerate(onnx_model.graph.initializer): + initializers[init.name] = (init, idx) + + return initializers + + @classmethod + def map_output_and_node(cls, onnx_model: onnx.ModelProto): + """Build the mapping from a node's output name to this node.""" + output2node = dict() + for node in onnx_model.graph.node: + for output_name in node.output: + output2node[output_name] = node + return output2node + + @classmethod + def map_input_and_node(cls, onnx_model: onnx.ModelProto): + """Build the mapping from input name to a (node, input index) tuple.""" + + input2node: Dict[str, List] = dict() + for node in onnx_model.graph.node: + for idx, input_name in enumerate(node.input): + if input_name not in input2node: + input2node[input_name] = [] + input2node[input_name].append([node, idx]) + return input2node + + @classmethod + def get_constant(cls, name, onnx_model): + for node in onnx_model.graph.node: + if node.op_type == 'Constant': + if node.output[0] == name: + return numpy_helper.to_array(node.attribute[0].t).tolist() + + @classmethod + def get_initializer(cls, initializer_name, onnx_model): + return numpy_helper.to_array( + onnx_model.initializer[initializer_name][0]) + + @classmethod + def get_tensor_producer(cls, output_name, output2node): + if output_name not in output2node: + return 'INPUT_TOKEN' + return output2node[output_name] + + @classmethod + def get_tensor_consumer(self, input_name, input2node): + if input_name not in input2node: + return ['OUTPUT_TOKEN'] + return input2node[input_name] + + @classmethod + def remove_node_from_onnx(cls, node: onnx.NodeProto, + onnx_model: onnx.ModelProto): + """Removes a node from node list.""" + onnx_model.graph.node.remove(node) + + @classmethod + def remove_initializer_from_onnx(cls, initializer: onnx.TensorProto, + onnx_model: onnx.ModelProto): + """Inserts the initializer at the specified position.""" + onnx_model.graph.initializer.remove(initializer) + + @classmethod + def remove_fake_pad_op(cls, onnx_model, name2data, inp2node, out2node): + nodes_to_be_removed = [] + for idx, node in enumerate(onnx_model.graph.node): + if node.op_type == 'Pad': + pads = name2data[node.input[1]] + if all([x == 0 for x in pads]): + print_log(f'Remove pad op: <{node.name}>.') + next_nodes = inp2node[node.output[0]] + for next_node, idx in next_nodes: + next_node.input[idx] = node.input[0] + nodes_to_be_removed.append(node) + + for node in nodes_to_be_removed: + onnx_model.graph.node.remove(node) + + @classmethod + def insert_node_to_onnx(cls, + node: onnx.NodeProto, + onnx_model: onnx.ModelProto, + idx: int = 0): + """Inserts the node at the specified position.""" + onnx_model.graph.node.insert(idx, node) + + @classmethod + def find_standalone_nodes(cls, + onnx_model: onnx.ModelProto, + input2node: Optional[Dict] = None, + output2node: Optional[Dict] = None): + """Find unused nodes.""" + + if input2node is None: + input2node = cls.map_input_and_node(onnx_model) + if output2node is None: + output2node = cls.map_output_and_node(onnx_model) + + def _is_standalone_node(node, input2node, output2node): + for input_name in node.input: + if input_name in output2node: + return False + + for out_node in node.output: + if out_node in input2node: + return False + + return True + + standalone_nodes = list() + for node in onnx_model.graph.node: + + if _is_standalone_node(node, input2node, output2node): + standalone_nodes.append(node) + return standalone_nodes + + @classmethod + def find_redundant_initializers(cls, + onnx_model: onnx.ModelProto, + input2node: Optional[Dict] = None): + """Find unused initializers.""" + if input2node is None: + input2node = cls.map_input_and_node(onnx_model) + + initializers = cls.map_name_and_initializer(onnx_model) + redundant_initializers = list() + redundant_set = set() + for name, init_and_idx in initializers.items(): + if name not in input2node and name not in redundant_set: + # init_and_idx[0] is onnx.onnx_ml_pb2.TensorProto + # init_and_idx[1] is a integer index + redundant_initializers.append(init_and_idx[0]) + redundant_set.add(name) + return redundant_initializers + + @classmethod + def topo_sort(cls, + onnx_model: onnx.ModelProto, + initializers: Optional[Dict] = None, + inplace: bool = True): + """Topologically sort the nodes in a directed acyclic graph. + + Note that nodes in a directed acyclic graph may be out of order + after replacing symbolic related nodes with new nodes. + + Args: + onnx_model (onnx.ModelProto): The onnx model to be sorted + topologically. + initializers (Dict | Optional): The mapping from name to + initializers. Default to None. + inplace (bool): Can optionally do the operation in-place. + Defaults to True. + """ + + if inplace: + _onnx_model = onnx_model + else: + _onnx_model = copy.deepcopy(onnx_model) + + if initializers is None: + initializers = cls.map_name_and_initializer( + _onnx_model, allow_redundant=True) + + # A node may have multiple outputs. The first output name of a node + # named `/conv/Conv` is `/conv/Conv_output_0` + output_name2node = {} + for node in _onnx_model.graph.node: + for output_name in node.output: + output_name2node[output_name] = node + for node in _onnx_model.graph.input: + output_name2node[node.name] = node + + name2node = {node.name: node for node in _onnx_model.graph.node} + + graph: Dict[str, + List] = {node.name: [] + for node in _onnx_model.graph.node} + for node in _onnx_model.graph.input: + graph[node.name] = [] + + indegree = {node.name: 0 for node in _onnx_model.graph.node} + + # Build graph + for i, node in enumerate(_onnx_model.graph.node): + for input_name in node.input: + if input_name not in initializers: + indegree[node.name] += 1 + prev_node = output_name2node[input_name] + graph[prev_node.name].append(node) + + graph_input = [node.name for node in _onnx_model.graph.input] + root = graph_input.copy() + sorted_nodes = [] + + # There are some nodes whose input are all initializers. + for node_name, in_degree in indegree.items(): + if in_degree == 0: + root.append(node_name) + + while root: + node_name = root.pop() + # There is no intersection between graph_input and + # _onnx_model.graph.node + if node_name not in graph_input: + node = name2node[node_name] + sorted_nodes.append(node) + for next_node in graph[node_name]: + indegree[next_node.name] -= 1 + if indegree[next_node.name] == 0: + root.append(next_node.name) + + num_nodes = len(_onnx_model.graph.node) + if len(sorted_nodes) != num_nodes: + raise RuntimeError('The graph is not a DAG.') + + for _ in range(num_nodes): + _onnx_model.graph.node.pop() + for node in sorted_nodes: + _onnx_model.graph.node.append(node) + + return _onnx_model + + @classmethod + def optimize(cls, onnx_model): + + input2node = cls.map_input_and_node(onnx_model) + output2node = cls.map_output_and_node(onnx_model) + + standalone_nodes = cls.find_standalone_nodes(onnx_model, input2node, + output2node) + for node in standalone_nodes: + cls.remove_node_from_onnx(node, onnx_model) + print_log(f'Remove node {node.name}') + + redundant_inits = cls.find_redundant_initializers( + onnx_model, input2node) + for init in redundant_inits: + cls.remove_initializer_from_onnx(init, onnx_model) + print_log(f'Remove initializer {init.name}') + + sorted_onnx_model = cls.topo_sort(onnx_model) + + return sorted_onnx_model diff --git a/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py b/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py new file mode 100644 index 000000000..7d05847c1 --- /dev/null +++ b/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import onnx + +from .base_quantize_exporter import BaseQuantizeExportor + + +class TensorRTExplicitExporter(BaseQuantizeExportor): + + def __init__(self, onnx_model, export_path) -> None: + super().__init__(onnx_model, export_path) + + def _build_backend_node_from_symbolic(self, node): + quantize_linear_node = onnx.helper.make_node( + 'QuantizeLinear', node.input[:3], [node.name + '_quantized_out'], + node.name + '_quantized') + dequantize_linear_node = onnx.helper.make_node( + 'DequantizeLinear', + [node.name + '_quantized_out'] + quantize_linear_node.input[1:3], + node.output, node.name + '_dequantized') + return [quantize_linear_node, dequantize_linear_node] + + def build_backend_nodes(self, symbolic_nodes): + backend_nodes = list() + for node in symbolic_nodes: + _, _, zero_point, qmin, qmax = self.parse_qparams(node) + assert qmax - qmin in ( + 2**8 - 1, 2**8 - + 2), 'Only 8 bit quantization support deployment to ONNX.' + assert not np.any(zero_point != 0), \ + 'This pass is only supposed to be used with TensorRT ' \ + 'Backend which does not support asymmetric quantization.' + new_nodes = self._build_backend_node_from_symbolic(node) + backend_nodes.extend(new_nodes) + return backend_nodes + + def export(self): + symbolic_nodes = self.collect_symbolic_nodes(self.onnx_model) + new_nodes = self.build_backend_nodes(symbolic_nodes) + for node in symbolic_nodes: + self.onnx_model.graph.node.remove(node) + self.onnx_model.graph.node.extend(new_nodes) + self.optimizer.optimize(self.onnx_model) + onnx.save(self.onnx_model, self.export_path) diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index f24e8d538..c1edb6fe4 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -1,12 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch -import torch.nn as nn from mmengine.config import Config try: - from torch.ao.quantization import enable_fake_quant + from torch.ao.quantization import disable_observer, enable_fake_quant from torch.ao.quantization.fx import prepare from torch.ao.quantization.fx.graph_module import ObservedGraphModule from torch.ao.quantization.qconfig_mapping import ( @@ -16,11 +15,13 @@ from torch.fx.graph_module import GraphModule from torch.nn.intrinsic.qat import modules as qat_fused_modules from torch.nn.qat import modules as qat_modules + from torch.onnx import register_custom_op_symbolic except ImportError: from mmrazor.utils import get_package_placeholder, get_placeholder GraphModule = get_placeholder('torch>=1.13') ObservedGraphModule = get_placeholder('torch>=1.13') enable_fake_quant = get_placeholder('torch>=1.13') + disable_observer = get_placeholder('torch>=1.13') prepare = get_placeholder('torch>=1.13') QConfigMapping = get_placeholder('torch>=1.13') _fuse_fx = get_placeholder('torch>=1.13') @@ -62,6 +63,23 @@ qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d, qat_fused_modules.LinearBn1d: qat_modules.Linear } + + def fake_quantize_per_channel_affine(g, x, scale, zero_point, ch_axis, + quant_min, quant_max): + return g.op('mmrazor::FixedPerChannelAffine', x, scale, zero_point, + ch_axis, quant_min, quant_max) + + register_custom_op_symbolic('::fake_quantize_per_channel_affine', + fake_quantize_per_channel_affine, 11) + + def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, + quant_max): + return g.op('mmrazor::FixedPerTensorAffine', x, scale, zero_point, + quant_min, quant_max) + + register_custom_op_symbolic('::fake_quantize_per_tensor_affine', + fake_quantize_per_tensor_affine, 11) + else: SUPPORT_QAT_MODULES = () MERGE_BN_MAPPINGS = {} @@ -181,6 +199,13 @@ def support_a_modes(self): per_channel.""" return ('per_tensor') + def export_onnx(self, model: Union[torch.nn.Module, torch.jit.ScriptModule, + torch.jit.ScriptFunction], + args: Union[Tuple[Any, ...], + torch.Tensor], output_path: str, **kwargs): + """Export the onnx model that can be deployed to a native backend.""" + torch.onnx.export(model, args, output_path, **kwargs) + def prepare(self, model, concrete_args=None): """prepare graph to ObservedGraphModule. @@ -223,16 +248,17 @@ def prepare(self, model, concrete_args=None): return prepared - def post_process_weight_fakequant(self, - observed_module: ObservedGraphModule, - keep_fake_quant: bool = False): + def post_process_for_deploy(self, + observed_module: ObservedGraphModule, + keep_w_fake_quant: bool = False): """weight fake-quant for supported QAT modules. Args: observed_module (ObservedGraphModule): Modules after fused and observed. - keep_fake_quant (bool, optional): Bool to determine whether to keep - fake-quant op, depending on the backend. Defaults to False. + keep_w_fake_quant (bool, optional): Bool to determine whether to + keep weight fake-quant op, depending on the backend. Defaults + to False. Note: `post_process_weight_fakequant()` function is necessary that the @@ -259,7 +285,7 @@ def traverse(module): # This is decided by backend type, some backend need # explicitly keep the fake quant structure, others don't. # TODO add deploy doc link - if keep_fake_quant: + if keep_w_fake_quant: for m in float_child.modules(): setattr(m, 'qconfig', self.qconfig.convert()) @@ -276,13 +302,9 @@ def traverse(module): else: traverse(child) - observed_module.apply(enable_fake_quant) traverse(observed_module) - - def prepare_for_mmdeploy(self, model: nn.Module, dummy_input: Tuple, - checkpoint: Optional[str]): - """Prepare model to Observed_model.""" - raise NotImplementedError + observed_module.apply(enable_fake_quant) + observed_module.apply(disable_observer) def del_redundant_fakequant(self, prepared: GraphModule): """delete redundant fakequant op in prepared model. diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index f8a25bd56..94658afbd 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Tuple +from typing import Any, Optional, Tuple, Union import torch @@ -46,32 +46,26 @@ def support_a_modes(self): per_channel.""" return ('per_tensor') - def prepare_for_mmdeploy(self, - model: torch.nn.Module, - dummy_input: Tuple = (1, 3, 224, 224), - checkpoint: Optional[str] = None): - """Prepare for deploy to the backend with mmdeploy, which will be used - in mmdeploy, and usually includes as follows: - - 1. prepare for the float model rewritten by mmdeploy. - 2. load checkpoint consists of float weight and quantized params in - mmrazor. - 3. post process weight fakequant for exporting .onnx that meet - the backend's requirement. - """ - self.convert_batchnorm2d(model) - observed_model = self.prepare(model) - if dummy_input is not None: - observed_model(torch.randn(dummy_input)) - if checkpoint is not None: - observed_model.load_state_dict( - torch.load(checkpoint)['state_dict']) - self.post_process_weight_fakequant( - observed_model, keep_fake_quant=True) - - observed_model.apply(disable_observer) - - return observed_model + def export_onnx(self, + model: Union[torch.nn.Module, torch.jit.ScriptModule, + torch.jit.ScriptFunction], + args: Union[Tuple[Any, ...], torch.Tensor], + output_path: str, + opset_version: Optional[int] = 11, + **kwargs): + """Export the onnx model that can be deployed to OpenVino backend.""" + + symbolic_output_path = f'symbolic_{output_path}' + torch.onnx.export( + model, + args, + symbolic_output_path, + opset_version=opset_version, + **kwargs) + + from .exporters import OpenVinoQuantizeExportor + exporter = OpenVinoQuantizeExportor(symbolic_output_path, output_path) + exporter.export() @property def module_prev_wo_fakequant(self): From 0a755474a2d50c2f21e8ca7129f456e407caa75c Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Mon, 27 Mar 2023 18:19:20 +0800 Subject: [PATCH 17/44] [Feature]Rewrite the origin model during prepare (#488) * add rewriter * add deploy_cfg arg * modify post_process_for_mmdeploy * fix bugs * add det config --- ...classification_openvino_dynamic-224x224.py | 30 ++++ .../detection_openvino_dynamic-800x1344.py | 50 ++++++ ...penvino_resnet18_8xb32_in1k_calib32xb32.py | 6 +- ...openvino_retina_r50_1x_coco_calib32xb32.py | 6 +- .../quantization/mm_architecture.py | 169 +++++++++++++++++- .../models/quantizers/openvino_quantizer.py | 28 ++- 6 files changed, 278 insertions(+), 11 deletions(-) create mode 100644 configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py create mode 100644 configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py diff --git a/configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py b/configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py new file mode 100644 index 000000000..d1fc673c5 --- /dev/null +++ b/configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py @@ -0,0 +1,30 @@ +deploy_cfg = dict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_names=['input'], + output_names=['output'], + input_shape=None, + optimize=True, + dynamic_axes={ + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'output': { + 0: 'batch' + } + }), + backend_config=dict( + type='openvino', + model_inputs=[dict(opt_shapes=dict(input=[1, 3, 224, 224]))]), + codebase_config=dict(type='mmcls', task='Classification'), + function_record_to_pop=[ + 'mmcls.models.classifiers.ImageClassifier.forward', + 'mmcls.models.classifiers.BaseClassifier.forward' + ], +) diff --git a/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py new file mode 100644 index 000000000..f7fc5d064 --- /dev/null +++ b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py @@ -0,0 +1,50 @@ +deploy_cfg = dict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_shape=None, + input_names=['input'], + output_names=['dets', 'labels'], + optimize=True, + dynamic_axes={ + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'dets': { + 0: 'batch', + 1: 'num_dets', + }, + 'labels': { + 0: 'batch', + 1: 'num_dets', + }, + }), + backend_config=dict( + type='openvino', + model_inputs=[dict(opt_shapes=dict(input=[1, 3, 800, 1344]))]), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + model_type='end2end', + post_processing=dict( + score_threshold=0.05, + confidence_threshold=0.005, # for YOLOv3 + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, + )), + function_record_to_pop=[ + 'mmdet.models.detectors.single_stage.SingleStageDetector.forward', + 'mmdet.models.detectors.two_stage.TwoStageDetector.forward', + 'mmdet.models.detectors.single_stage_instance_seg.' + 'SingleStageInstanceSegmentor.forward', + 'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.' + 'predict_by_feat' + ]) diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py index 125f46367..93e3897bc 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -1,4 +1,7 @@ -_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] +_base_ = [ + 'mmcls::resnet/resnet18_8xb32_in1k.py', + '../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' +] train_dataloader = dict(batch_size=32) @@ -33,6 +36,7 @@ # convert image from BGR to RGB to_rgb=True), architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', diff --git a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py index 578f5fe84..109a7ee04 100644 --- a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py @@ -1,4 +1,7 @@ -_base_ = ['mmdet::retinanet/retinanet_r50_fpn_1x_coco.py'] +_base_ = [ + 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py', + '../deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py' +] train_dataloader = dict(batch_size=32) @@ -32,6 +35,7 @@ bgr_to_rgb=True, pad_size_divisor=32), architecture=retina, + deploy_cfg=_base_.deploy_cfg, float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index ff99592d6..d4081d96e 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -1,8 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import os -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch +from mmengine.config import Config from mmengine.model import MMDistributedDataParallel from mmengine.runner import load_checkpoint from mmengine.structures import BaseDataElement @@ -37,6 +39,7 @@ class MMArchitectureQuant(BaseAlgorithm): quantized. quantizer (Union[Dict, BaseModel]): The quantizer to support different backend type. + deploy_cfg (Union[str, Dict]): Deployment config file or Config object. qmodel_modes (List): The available mode of runner. data_preprocessor (Optional[Dict]): The pre-process config of :class:`BaseDataPreprocessor`. Defaults to None. @@ -57,6 +60,7 @@ class MMArchitectureQuant(BaseAlgorithm): def __init__(self, architecture: Union[Dict, BaseModel], quantizer: Union[Dict, BaseModel], + deploy_cfg: Union[str, Dict], data_preprocessor: Optional[Dict] = None, forward_modes: Tuple = ('tensor', 'predict', 'loss'), float_checkpoint: Optional[str] = None, @@ -68,6 +72,9 @@ def __init__(self, self.quantizer = MODELS.build(quantizer) self.input_shapes = input_shapes self.forward_modes = forward_modes + if isinstance(deploy_cfg, str): + deploy_cfg = Config.fromfile(deploy_cfg) + self.deploy_cfg = deploy_cfg # Replace syncbn and _BatchNormXd (in mmengine) with batchnorm2d self.quantizer.convert_batchnorm2d(self.architecture) @@ -104,7 +111,7 @@ def sync_qparams(self, src_mode: str): src_mode (str): The modes of forward method. Note: - `traverse()` function recursively traverses all module to sync + `traverse()` method recursively traverses all modules to sync quantized graph generated from different `forward_modes`. This is because We have different mode ('tensor', 'predict', 'loss') in OpenMMLab architecture which have different graph @@ -145,6 +152,116 @@ def traverse(module, prefix): continue traverse(self.qmodels[mode], '') + def _get_rewriter_context_in_mmdeploy(self, deploy_cfg): + """Get rewriter context in mmdeploy according to the deploy related + config.""" + from mmdeploy.apis.onnx.passes import optimize_onnx + from mmdeploy.codebase import import_codebase + from mmdeploy.core import RewriterContext + from mmdeploy.utils import (IR, Backend, get_backend, get_codebase, + get_dynamic_axes, get_ir_config, + get_onnx_config) + from mmdeploy.utils.config_utils import get_codebase_external_module + + codebase = get_codebase(deploy_cfg) + custom_module_list = get_codebase_external_module(deploy_cfg) + import_codebase(codebase, custom_module_list) + + def _add_or_update(cfg: dict, key: str, val: Any): + if key in cfg and isinstance(cfg[key], dict) and isinstance( + val, dict): + cfg[key].update(val) + else: + cfg[key] = val + + context_info = dict() + deploy_cfg = copy.deepcopy(deploy_cfg) + + backend = get_backend(deploy_cfg).value + + onnx_cfg = get_onnx_config(deploy_cfg) + opset_version = onnx_cfg.get('opset_version', 11) + + input_names = onnx_cfg['input_names'] + output_names = onnx_cfg['output_names'] + axis_names = input_names + output_names + dynamic_axes = get_dynamic_axes(deploy_cfg, axis_names) + + verbose = not onnx_cfg.get('strip_doc_string', True) or onnx_cfg.get( + 'verbose', False) + keep_initializers_as_inputs = onnx_cfg.get( + 'keep_initializers_as_inputs', True) + optimize = onnx_cfg.get('optimize', False) + if backend == Backend.NCNN.value: + """NCNN backend needs a precise blob counts, while using onnx + optimizer will merge duplicate initilizers without reference + count.""" + optimize = False + + ir_config = dict( + type='onnx', + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + verbose=verbose, + keep_initializers_as_inputs=keep_initializers_as_inputs) + + _add_or_update(deploy_cfg, 'ir_config', ir_config) + ir = IR.get(get_ir_config(deploy_cfg)['type']) + if isinstance(backend, Backend): + backend = backend.value + backend_config = dict(type=backend) + _add_or_update(deploy_cfg, 'backend_config', backend_config) + + context_info['cfg'] = deploy_cfg + context_info['ir'] = ir + if 'backend' not in context_info: + context_info['backend'] = backend + if 'opset' not in context_info: + context_info['opset'] = opset_version + + if 'onnx_custom_passes' not in context_info: + onnx_custom_passes = optimize_onnx if optimize else None + context_info['onnx_custom_passes'] = onnx_custom_passes + + rewriter_context = RewriterContext(**context_info) + + # Hard codes to delete user-specific rewriters from + # `RewriterContext._rewriter_manager`. + # We use the model which is rewritten by mmdeploy to build quantized + # models. However not all the modules, functions and symbolic rewritten + # by mmdeploy need to be rewritten in mmrazor. For example, mmdeploy + # rewrite `mmcls.models.classifiers.ImageClassifier.forward` and + # `mmcls.models.classifiers.BaseClassifier.forward` for deployment. + # But they can't be rewritten by mmrazor as ptq and qat are done in + # mmrazor. So to ensure ptq and qat proceed normally, we have to remove + # these record from `RewriterContext._rewriter_manager`. + + # We have to deepcopy rewriter_context here to delete records safely. + rewriter_context = copy.deepcopy(rewriter_context) + module_record_to_pop = deploy_cfg.get('module_record_to_pop', []) + function_record_to_pop = deploy_cfg.get('function_record_to_pop', []) + symbolic_record_to_pop = deploy_cfg.get('symbolic_record_to_pop', []) + for record in module_record_to_pop: + records = rewriter_context._rewriter_manager.module_rewriter.\ + _registry._rewrite_records + if record in records: + records.pop(record) + for record in function_record_to_pop: + records = rewriter_context._rewriter_manager.function_rewriter.\ + _registry._rewrite_records + if record in records: + records.pop(record) + + for record in symbolic_record_to_pop: + records = rewriter_context._rewriter_manager.symbolic_rewriter.\ + _registry._rewrite_records + if record in records: + records.pop(record) + + return rewriter_context + def _build_qmodels(self, model: BaseModel): """Build quantized models from the given model. @@ -171,12 +288,50 @@ def _build_qmodels(self, model: BaseModel): output output (_get_predictions,) """ + rewriter_context = self._get_rewriter_context_in_mmdeploy( + self.deploy_cfg) + + # module_record_to_pop = self.deploy_cfg.get('module_record_to_pop', + # []) + # function_record_to_pop = self.deploy_cfg.get( + # 'function_record_to_pop', []) + # symbolic_record_to_pop = self.deploy_cfg.get( + # 'symbolic_record_to_pop', []) + # module_record_backup = {} + # function_record_backup = {} + # symbolic_record_backup = {} + # for record in module_record_to_pop: + # records = rewriter_context._rewriter_manager.module_rewriter. \ + # _registry._rewrite_records + # if record in records: + # module_record_backup[record] = records.pop(record) + # for record in function_record_to_pop: + # records = rewriter_context._rewriter_manager.function_rewriter. \ + # _registry._rewrite_records + # if record in records: + # function_record_backup[record] = records.pop(record) + # + # for record in symbolic_record_to_pop: + # records = rewriter_context._rewriter_manager.symbolic_rewriter. \ + # _registry._rewrite_records + # if record in records: + # symbolic_record_backup[record] = records.pop(record) + qmodels = nn.ModuleDict() for mode in self.forward_modes: concrete_args = {'mode': mode} - observed_module = self.quantizer.prepare(model, concrete_args) + # todo: support qat. + with rewriter_context: + observed_module = self.quantizer.prepare(model, concrete_args) qmodels[mode] = observed_module + # rewriter_context._rewriter_manager.module_rewriter. \ + # _registry._rewrite_records.update(module_record_backup) + # rewriter_context._rewriter_manager.function_rewriter. \ + # _registry._rewrite_records.update(function_record_backup) + # rewriter_context._rewriter_manager.symbolic_rewriter. \ + # _registry._rewrite_records.update(symbolic_record_backup) + # data_samples can not be None in detectors during prediction. # But we need to make the dummy prediction in _build_qmodels. # It is more convenient to use `tensor` mode. @@ -215,7 +370,7 @@ def calibrate_step(self, data: Union[Dict, Tuple, List]): data = self.data_preprocessor(data, False) return self._run_forward(data, mode='predict') - def post_process_for_mmdeploy(self, dummy_input: Tuple = (1, 3, 224, 224)): + def get_deploy_model(self): """Prepare for deploy to the backend with mmdeploy, which will be used in mmdeploy, and usually includes as follows: @@ -226,13 +381,11 @@ def post_process_for_mmdeploy(self, dummy_input: Tuple = (1, 3, 224, 224)): the backend's requirement. """ - quantized_state_dict = self.qmodels['tensor'].state_dict() + quantized_state_dict = self.qmodels['predict'].state_dict() fp32_model = self.architecture self.quantizer.convert_batchnorm2d(fp32_model) - observed_model = self.quantizer.prepare(fp32_model, {'mode': 'tensor'}) - if dummy_input is not None: - observed_model(torch.randn(dummy_input)) + observed_model = self.quantizer.prepare(fp32_model) observed_model.load_state_dict(quantized_state_dict) diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index 94658afbd..831d991f2 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -46,6 +46,32 @@ def support_a_modes(self): per_channel.""" return ('per_tensor') + def prepare_for_mmdeploy(self, + model: torch.nn.Module, + dummy_input: Tuple = (1, 3, 224, 224), + checkpoint: Optional[str] = None): + """Prepare for deploy to the backend with mmdeploy, which will be used + in mmdeploy, and usually includes as follows: + + 1. prepare for the float model rewritten by mmdeploy. + 2. load checkpoint consists of float weight and quantized params in + mmrazor. + 3. post process weight fakequant for exporting .onnx that meet + the backend's requirement. + """ + self.convert_batchnorm2d(model) + observed_model = self.prepare(model) + if dummy_input is not None: + observed_model(torch.randn(dummy_input)) + if checkpoint is not None: + observed_model.load_state_dict( + torch.load(checkpoint)['state_dict']) + self.post_process_for_deploy(observed_model, keep_w_fake_quant=True) + + observed_model.apply(disable_observer) + + return observed_model + def export_onnx(self, model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction], @@ -55,7 +81,7 @@ def export_onnx(self, **kwargs): """Export the onnx model that can be deployed to OpenVino backend.""" - symbolic_output_path = f'symbolic_{output_path}' + symbolic_output_path = output_path.replace('.onnx', '_symbolic.onnx') torch.onnx.export( model, args, From 019e43b674b65c2239775ba9061684d45ea22be2 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Wed, 29 Mar 2023 20:17:14 +0800 Subject: [PATCH 18/44] [Feature] Using rewriter in mmrazor when building qmodels. (#490) * add rewriter * add deploy_cfg arg * modify post_process_for_mmdeploy * fix bugs * add det config * replace deepcopy * pop detectors' forward --- .../detection_openvino_dynamic-800x1344.py | 4 +- .../quantization/mm_architecture.py | 96 +++++++------------ 2 files changed, 33 insertions(+), 67 deletions(-) diff --git a/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py index f7fc5d064..f8122ecaa 100644 --- a/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py +++ b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py @@ -44,7 +44,5 @@ 'mmdet.models.detectors.single_stage.SingleStageDetector.forward', 'mmdet.models.detectors.two_stage.TwoStageDetector.forward', 'mmdet.models.detectors.single_stage_instance_seg.' - 'SingleStageInstanceSegmentor.forward', - 'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.' - 'predict_by_feat' + 'SingleStageInstanceSegmentor.forward' ]) diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index d4081d96e..53dd0f6cf 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -225,42 +225,35 @@ def _add_or_update(cfg: dict, key: str, val: Any): onnx_custom_passes = optimize_onnx if optimize else None context_info['onnx_custom_passes'] = onnx_custom_passes - rewriter_context = RewriterContext(**context_info) - - # Hard codes to delete user-specific rewriters from - # `RewriterContext._rewriter_manager`. - # We use the model which is rewritten by mmdeploy to build quantized - # models. However not all the modules, functions and symbolic rewritten - # by mmdeploy need to be rewritten in mmrazor. For example, mmdeploy - # rewrite `mmcls.models.classifiers.ImageClassifier.forward` and - # `mmcls.models.classifiers.BaseClassifier.forward` for deployment. - # But they can't be rewritten by mmrazor as ptq and qat are done in - # mmrazor. So to ensure ptq and qat proceed normally, we have to remove - # these record from `RewriterContext._rewriter_manager`. - - # We have to deepcopy rewriter_context here to delete records safely. - rewriter_context = copy.deepcopy(rewriter_context) - module_record_to_pop = deploy_cfg.get('module_record_to_pop', []) - function_record_to_pop = deploy_cfg.get('function_record_to_pop', []) - symbolic_record_to_pop = deploy_cfg.get('symbolic_record_to_pop', []) - for record in module_record_to_pop: - records = rewriter_context._rewriter_manager.module_rewriter.\ - _registry._rewrite_records - if record in records: - records.pop(record) - for record in function_record_to_pop: - records = rewriter_context._rewriter_manager.function_rewriter.\ - _registry._rewrite_records - if record in records: - records.pop(record) + return RewriterContext(**context_info) + + def _pop_function_record_in_rewriter_context(self, rewriter_context): + """Delete user-specific rewriters from + `RewriterContext._rewriter_manager`. We use the model which is + rewritten by mmdeploy to build quantized models. However not all the + functions rewritten by mmdeploy need to be rewritten in mmrazor. For + example, mmdeploy rewrite + `mmcls.models.classifiers.ImageClassifier.forward` and + `mmcls.models.classifiers.BaseClassifier.forward` for deployment. But + they can't be rewritten by mmrazor as ptq and qat are done in mmrazor. + So to ensure ptq and qat proceed normally, we have to remove these + record from `RewriterContext._rewriter_manager`. - for record in symbolic_record_to_pop: - records = rewriter_context._rewriter_manager.symbolic_rewriter.\ + Args: + rewriter_context (RewriterContext): The RewriterContext used in + mmdeploy. + """ + skipped_methods = getattr(self.quantizer.tracer, 'skipped_methods', []) + function_record_to_pop = self.deploy_cfg.get('function_record_to_pop', + []) + function_record_to_pop.extend(skipped_methods) + function_record_backup = {} + for record in function_record_to_pop: + records = rewriter_context._rewriter_manager.function_rewriter. \ _registry._rewrite_records if record in records: - records.pop(record) - - return rewriter_context + function_record_backup[record] = records.pop(record) + return function_record_backup def _build_qmodels(self, model: BaseModel): """Build quantized models from the given model. @@ -291,31 +284,9 @@ def _build_qmodels(self, model: BaseModel): rewriter_context = self._get_rewriter_context_in_mmdeploy( self.deploy_cfg) - # module_record_to_pop = self.deploy_cfg.get('module_record_to_pop', - # []) - # function_record_to_pop = self.deploy_cfg.get( - # 'function_record_to_pop', []) - # symbolic_record_to_pop = self.deploy_cfg.get( - # 'symbolic_record_to_pop', []) - # module_record_backup = {} - # function_record_backup = {} - # symbolic_record_backup = {} - # for record in module_record_to_pop: - # records = rewriter_context._rewriter_manager.module_rewriter. \ - # _registry._rewrite_records - # if record in records: - # module_record_backup[record] = records.pop(record) - # for record in function_record_to_pop: - # records = rewriter_context._rewriter_manager.function_rewriter. \ - # _registry._rewrite_records - # if record in records: - # function_record_backup[record] = records.pop(record) - # - # for record in symbolic_record_to_pop: - # records = rewriter_context._rewriter_manager.symbolic_rewriter. \ - # _registry._rewrite_records - # if record in records: - # symbolic_record_backup[record] = records.pop(record) + # Pop function records in `quantizer.tracer.skipped_method` temporarily + function_record_backup = self._pop_function_record_in_rewriter_context( + rewriter_context) qmodels = nn.ModuleDict() for mode in self.forward_modes: @@ -325,12 +296,9 @@ def _build_qmodels(self, model: BaseModel): observed_module = self.quantizer.prepare(model, concrete_args) qmodels[mode] = observed_module - # rewriter_context._rewriter_manager.module_rewriter. \ - # _registry._rewrite_records.update(module_record_backup) - # rewriter_context._rewriter_manager.function_rewriter. \ - # _registry._rewrite_records.update(function_record_backup) - # rewriter_context._rewriter_manager.symbolic_rewriter. \ - # _registry._rewrite_records.update(symbolic_record_backup) + # Add these popped function records back. + rewriter_context._rewriter_manager.function_rewriter. \ + _registry._rewrite_records.update(function_record_backup) # data_samples can not be None in detectors during prediction. # But we need to make the dummy prediction in _build_qmodels. From 70a91b382a5fb7070e39aae7a2286e2c7d152c1d Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Fri, 7 Apr 2023 16:48:27 +0800 Subject: [PATCH 19/44] [Feature] Quantization global optimization (#491) * add trtquantizer * unify all fakequant before deploy * move to aide * add yolox config * pre-rebase * add unittest * add a arg of post_process_for_deploy * test trt yolox deploy * opt quantizer interface * fix rebase * add trt r50 config * update trt setting * del redundant code * fix lint * fix ut of quantizers * del redundant file * fix lint * fix some comments --- ..._tensorrt-int8-explicit_dynamic-224x224.py | 39 ++++ .../detection_openvino_dynamic-800x1344.py | 3 +- ...int8-explicit_dynamic-320x320-1344x1344.py | 57 ++++++ ...tq_openvino_mbv2_8xb32_in1k_calib32xb32.py | 10 +- ...penvino_resnet18_8xb32_in1k_calib32xb32.py | 4 +- ...penvino_resnet50_8xb32_in1k_calib32xb32.py | 10 +- ...openvino_retina_r50_1x_coco_calib32xb32.py | 7 +- ...vino_yolox_s_8xb8-300e_coco_calib32xb32.py | 57 ++++++ ...tq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py | 54 ++++++ ...ensorrt_resnet18_8xb32_in1k_calib32xb32.py | 51 ++++++ ...ensorrt_resnet50_8xb32_in1k_calib32xb32.py | 51 ++++++ ...tensorrt_retina_r50_1x_coco_calib32xb32.py | 53 ++++++ ...orrt_yolox_s_8xb8-300e_coco_calib32xb32.py | 58 ++++++ .../quantization/mm_architecture.py | 23 ++- mmrazor/models/quantizers/__init__.py | 4 +- mmrazor/models/quantizers/native_quantizer.py | 39 +++- .../models/quantizers/openvino_quantizer.py | 36 +--- .../models/quantizers/tensorrt_quantizer.py | 81 +++++---- .../quantization/backend_config/tensorrt.py | 27 +-- mmrazor/structures/quantization/qconfig.py | 171 +++++++++++------- .../test_quantizers/test_native_quantizer.py | 10 +- .../test_openvino_quantizer.py | 23 --- .../test_tensorrt_quantizer.py | 23 --- tests/test_structures/test_qconfig.py | 43 ++++- 24 files changed, 699 insertions(+), 235 deletions(-) create mode 100644 configs/quantization/deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py create mode 100644 configs/quantization/deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py create mode 100644 configs/quantization/ptq/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py create mode 100644 configs/quantization/ptq/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py create mode 100644 configs/quantization/ptq/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py create mode 100644 configs/quantization/ptq/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py create mode 100644 configs/quantization/ptq/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py create mode 100644 configs/quantization/ptq/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py diff --git a/configs/quantization/deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py b/configs/quantization/deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py new file mode 100644 index 000000000..debdd3ad5 --- /dev/null +++ b/configs/quantization/deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py @@ -0,0 +1,39 @@ +deploy_cfg = dict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_names=['input'], + output_names=['output'], + input_shape=[224, 224], + optimize=True, + dynamic_axes=dict( + input=dict({ + 0: 'batch', + 2: 'height', + 3: 'width' + }), + output=dict({0: 'batch'}))), + codebase_config=dict(type='mmcls', task='Classification'), + backend_config=dict( + type='tensorrt', + common_config=dict( + fp16_mode=False, + max_workspace_size=1073741824, + int8_mode=True, + explicit_quant_mode=True), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 224, 224], + opt_shape=[4, 3, 224, 224], + max_shape=[8, 3, 224, 224]))) + ]), + function_record_to_pop=[ + 'mmcls.models.classifiers.ImageClassifier.forward', + 'mmcls.models.classifiers.BaseClassifier.forward' + ], +) diff --git a/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py index f8122ecaa..c76898d0b 100644 --- a/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py +++ b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py @@ -43,6 +43,5 @@ function_record_to_pop=[ 'mmdet.models.detectors.single_stage.SingleStageDetector.forward', 'mmdet.models.detectors.two_stage.TwoStageDetector.forward', - 'mmdet.models.detectors.single_stage_instance_seg.' - 'SingleStageInstanceSegmentor.forward' + 'mmdet.models.detectors.single_stage_instance_seg.SingleStageInstanceSegmentor.forward' # noqa: E501 ]) diff --git a/configs/quantization/deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py b/configs/quantization/deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py new file mode 100644 index 000000000..f515a6ee9 --- /dev/null +++ b/configs/quantization/deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py @@ -0,0 +1,57 @@ +deploy_cfg = dict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_names=['input'], + output_names=['dets', 'labels'], + input_shape=None, + optimize=True, + dynamic_axes=dict( + input=dict({ + 0: 'batch', + 2: 'height', + 3: 'width' + }), + dets=dict({ + 0: 'batch', + 1: 'num_dets' + }), + labels=dict({ + 0: 'batch', + 1: 'num_dets' + }))), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + model_type='end2end', + post_processing=dict( + score_threshold=0.05, + confidence_threshold=0.005, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1)), + backend_config=dict( + type='tensorrt', + common_config=dict( + fp16_mode=False, + max_workspace_size=1073741824, + int8_mode=True, + explicit_quant_mode=True), + model_inputs=[ + dict( + input_shapes=dict( + input=dict( + min_shape=[1, 3, 320, 320], + opt_shape=[1, 3, 800, 1344], + max_shape=[1, 3, 1344, 1344]))) + ]), + function_record_to_pop=[ + 'mmdet.models.detectors.single_stage.SingleStageDetector.forward', + 'mmdet.models.detectors.two_stage.TwoStageDetector.forward', + 'mmdet.models.detectors.single_stage_instance_seg.SingleStageInstanceSegmentor.forward' # noqa: E501 + ]) diff --git a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py index 7c919c0fd..43f63f208 100644 --- a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py @@ -1,8 +1,13 @@ -_base_ = ['mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py'] +_base_ = [ + 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py', + '../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' +] + +val_dataloader = dict(batch_size=32) test_cfg = dict( type='mmrazor.PTQLoop', - calibrate_dataloader=_base_.train_dataloader, + calibrate_dataloader=val_dataloader, calibrate_steps=32, ) @@ -31,6 +36,7 @@ # convert image from BGR to RGB to_rgb=True), architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py index 93e3897bc..9bd9f55e4 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -3,11 +3,11 @@ '../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' ] -train_dataloader = dict(batch_size=32) +val_dataloader = dict(batch_size=32) test_cfg = dict( type='mmrazor.PTQLoop', - calibrate_dataloader=train_dataloader, + calibrate_dataloader=val_dataloader, calibrate_steps=32, ) diff --git a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py index f629337ed..14f65968e 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py @@ -1,10 +1,13 @@ -_base_ = ['mmcls::resnet/resnet50_8xb32_in1k.py'] +_base_ = [ + 'mmcls::resnet/resnet50_8xb32_in1k.py', + '../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' +] -train_dataloader = dict(batch_size=32) +val_dataloader = dict(batch_size=32) test_cfg = dict( type='mmrazor.PTQLoop', - calibrate_dataloader=train_dataloader, + calibrate_dataloader=val_dataloader, calibrate_steps=32, ) @@ -33,6 +36,7 @@ # convert image from BGR to RGB to_rgb=True), architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', diff --git a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py index 109a7ee04..3ec47e479 100644 --- a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py @@ -3,15 +3,14 @@ '../deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py' ] -train_dataloader = dict(batch_size=32) +val_dataloader = dict(batch_size=32) test_cfg = dict( type='mmrazor.PTQLoop', - calibrate_dataloader=train_dataloader, + calibrate_dataloader=val_dataloader, calibrate_steps=32, ) -retina = _base_.model float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa: E501 global_qconfig = dict( @@ -34,7 +33,7 @@ std=[58.395, 57.12, 57.375], bgr_to_rgb=True, pad_size_divisor=32), - architecture=retina, + architecture=_base_.model, deploy_cfg=_base_.deploy_cfg, float_checkpoint=float_checkpoint, quantizer=dict( diff --git a/configs/quantization/ptq/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py new file mode 100644 index 000000000..4ce17fe69 --- /dev/null +++ b/configs/quantization/ptq/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py @@ -0,0 +1,57 @@ +_base_ = [ + 'mmdet::yolox/yolox_s_8xb8-300e_coco.py', + '../deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py' +] + +val_dataloader = dict(batch_size=32) + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=val_dataloader, + calibrate_steps=32, +) + +float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + pad_size_divisor=32, + batch_augments=[ + dict( + type='mmdet.BatchSyncRandomResize', + random_size_range=(480, 800), + size_divisor=32, + interval=10) + ]), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmdet.models.dense_heads.yolox_head.YOLOXHead.predict_by_feat', # noqa: E501 + 'mmdet.models.dense_heads.yolox_head.YOLOXHead.loss_by_feat', + ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +custom_hooks = [] diff --git a/configs/quantization/ptq/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py new file mode 100644 index 000000000..681b7dabc --- /dev/null +++ b/configs/quantization/ptq/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py @@ -0,0 +1,54 @@ +_base_ = [ + 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py', + '../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501 +] + +val_dataloader = dict(batch_size=32) + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=val_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth' # noqa: E501 + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.TensorRTQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) diff --git a/configs/quantization/ptq/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py new file mode 100644 index 000000000..313dd195e --- /dev/null +++ b/configs/quantization/ptq/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py @@ -0,0 +1,51 @@ +_base_ = [ + 'mmcls::resnet/resnet18_8xb32_in1k.py', + '../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501 +] + +val_dataloader = dict(batch_size=32) + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=val_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.TensorRTQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', ) diff --git a/configs/quantization/ptq/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py new file mode 100644 index 000000000..0bd4a083a --- /dev/null +++ b/configs/quantization/ptq/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py @@ -0,0 +1,51 @@ +_base_ = [ + 'mmcls::resnet/resnet50_8xb32_in1k.py', + '../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501 +] + +val_dataloader = dict(batch_size=32) + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=val_dataloader, + calibrate_steps=32, +) + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth' # noqa: E501 + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.TensorRTQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +model_wrapper_cfg = dict(type='mmrazor.MMArchitectureQuantDDP', ) diff --git a/configs/quantization/ptq/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py new file mode 100644 index 000000000..088e6e043 --- /dev/null +++ b/configs/quantization/ptq/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py @@ -0,0 +1,53 @@ +_base_ = [ + 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py', + '../deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py' # noqa: E501 +] + +val_dataloader = dict(batch_size=32) + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=val_dataloader, + calibrate_steps=32, +) + +float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=32), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.TensorRTQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.predict_by_feat', # noqa: E501 + 'mmdet.models.dense_heads.anchor_head.AnchorHead.loss_by_feat', + ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) diff --git a/configs/quantization/ptq/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py b/configs/quantization/ptq/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py new file mode 100644 index 000000000..01e8cc0b5 --- /dev/null +++ b/configs/quantization/ptq/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py @@ -0,0 +1,58 @@ +_base_ = [ + 'mmdet::yolox/yolox_s_8xb8-300e_coco.py', + '../deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py' # noqa: E501 +] + +val_dataloader = dict(batch_size=32) + +test_cfg = dict( + type='mmrazor.PTQLoop', + calibrate_dataloader=val_dataloader, + calibrate_steps=32, +) + +float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +model = dict( + _delete_=True, + type='mmrazor.MMArchitectureQuant', + data_preprocessor=dict( + type='mmdet.DetDataPreprocessor', + pad_size_divisor=32, + batch_augments=[ + dict( + type='mmdet.BatchSyncRandomResize', + random_size_range=(480, 800), + size_divisor=32, + interval=10) + ]), + architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.TensorRTQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmdet.models.dense_heads.yolox_head.YOLOXHead.predict_by_feat', # noqa: E501 + 'mmdet.models.dense_heads.yolox_head.YOLOXHead.loss_by_feat', + ]))) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +custom_hooks = [] diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index 53dd0f6cf..2f7f3914e 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -11,6 +11,7 @@ from torch import nn from mmrazor.registry import MODEL_WRAPPERS, MODELS +from mmrazor.structures.quantization import QConfigHandler from ..base import BaseAlgorithm, BaseModel try: @@ -32,7 +33,7 @@ @MODELS.register_module() class MMArchitectureQuant(BaseAlgorithm): - """General quantization. + """General quantization for OpenMMLab's models. Args: architecture (Union[Dict, BaseModel]): The config of model to be @@ -60,7 +61,7 @@ class MMArchitectureQuant(BaseAlgorithm): def __init__(self, architecture: Union[Dict, BaseModel], quantizer: Union[Dict, BaseModel], - deploy_cfg: Union[str, Dict], + deploy_cfg: Optional[Union[str, Dict]] = None, data_preprocessor: Optional[Dict] = None, forward_modes: Tuple = ('tensor', 'predict', 'loss'), float_checkpoint: Optional[str] = None, @@ -348,17 +349,27 @@ def get_deploy_model(self): 3. post process weight fakequant for exporting .onnx that meet the backend's requirement. """ - + device = next(self.parameters()).device quantized_state_dict = self.qmodels['predict'].state_dict() fp32_model = self.architecture self.quantizer.convert_batchnorm2d(fp32_model) - observed_model = self.quantizer.prepare(fp32_model) - observed_model.load_state_dict(quantized_state_dict) self.quantizer.post_process_for_deploy( - observed_model, keep_w_fake_quant=True) + observed_model, device=device, keep_w_fake_quant=True) + + # replace various activation fakequant with base fakequant, which + # contributes to deploy our model to various backends. + for node in observed_model.graph.nodes: + if 'activation_post_process_' in node.name: + module_name = node.target + module = getattr(observed_model, module_name) + fakequant_new = QConfigHandler.replace_fakequant( + module, + self.quantizer.qconfig.a_qscheme, + update_qparams=True) + setattr(observed_model, module_name, fakequant_new) observed_model.apply(disable_observer) diff --git a/mmrazor/models/quantizers/__init__.py b/mmrazor/models/quantizers/__init__.py index 9d4fa1a28..a26bb1322 100644 --- a/mmrazor/models/quantizers/__init__.py +++ b/mmrazor/models/quantizers/__init__.py @@ -1,11 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .academic_quantizer import AcademicQuantizer from .base import BaseQuantizer -from .native_quantizer import NativeQuantizer +from .native_quantizer import TorchNativeQuantizer from .openvino_quantizer import OpenVINOQuantizer from .tensorrt_quantizer import TensorRTQuantizer __all__ = [ - 'BaseQuantizer', 'AcademicQuantizer', 'NativeQuantizer', + 'BaseQuantizer', 'AcademicQuantizer', 'TorchNativeQuantizer', 'TensorRTQuantizer', 'OpenVINOQuantizer' ] diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py index c1edb6fe4..7b6f2f9ad 100644 --- a/mmrazor/models/quantizers/native_quantizer.py +++ b/mmrazor/models/quantizers/native_quantizer.py @@ -5,7 +5,8 @@ from mmengine.config import Config try: - from torch.ao.quantization import disable_observer, enable_fake_quant + from torch.ao.quantization import (disable_observer, enable_fake_quant, + enable_observer) from torch.ao.quantization.fx import prepare from torch.ao.quantization.fx.graph_module import ObservedGraphModule from torch.ao.quantization.qconfig_mapping import ( @@ -22,6 +23,7 @@ ObservedGraphModule = get_placeholder('torch>=1.13') enable_fake_quant = get_placeholder('torch>=1.13') disable_observer = get_placeholder('torch>=1.13') + enable_observer = get_placeholder('torch>=1.13') prepare = get_placeholder('torch>=1.13') QConfigMapping = get_placeholder('torch>=1.13') _fuse_fx = get_placeholder('torch>=1.13') @@ -86,7 +88,7 @@ def fake_quantize_per_tensor_affine(g, x, scale, zero_point, quant_min, @MODELS.register_module() -class NativeQuantizer(BaseQuantizer): +class TorchNativeQuantizer(BaseQuantizer): """Native class for quantizer. Args: @@ -250,6 +252,8 @@ def prepare(self, model, concrete_args=None): def post_process_for_deploy(self, observed_module: ObservedGraphModule, + device: str = 'cpu', + update_weight_with_fakequant: bool = False, keep_w_fake_quant: bool = False): """weight fake-quant for supported QAT modules. @@ -275,36 +279,53 @@ def traverse(module): # to perform these operations and do dequantize to # introduce quantization loss in advance. weight_fakequant = child.weight_fake_quant - child.weight.data = weight_fakequant(child.weight.data) # `to_float()` function fuse BN into conv or conv_relu, and # also convert a qat module to a normal module. # source url: https://github.com/pytorch/pytorch/blob/master/torch/nn/intrinsic/qat/modules/conv_fused.py # noqa: E501 float_child = child.to_float() + if update_weight_with_fakequant: + from torch.ao.nn.intrinsic import _FusedModule + if issubclass(type(float_child), _FusedModule): + float_child[0].weight.data = weight_fakequant( + float_child[0].weight.data) + else: + float_child.weight.data = weight_fakequant( + float_child.weight.data) # This is decided by backend type, some backend need # explicitly keep the fake quant structure, others don't. # TODO add deploy doc link if keep_w_fake_quant: + # make weight fakequant fixed as the consistent + # fakequant, it will help to deploy our model to + # various backends. + self.qconfig.fixed_w_fakequant() for m in float_child.modules(): setattr(m, 'qconfig', self.qconfig.convert()) - if type(child) in MERGE_BN_MAPPINGS: cls = MERGE_BN_MAPPINGS[type(child)] - new_child = cls.from_float(float_child) + new_child = cls.from_float(float_child).to(device) else: - new_child = type(child).from_float(float_child) - + new_child = type(child).from_float(float_child).to( + device) + + # because weight fakequants and observers are replaced + # with base fakequants and base observers, some + # initialized args need to be update by running + # weight_fake_quant. + enable_observer(new_child) new_child.weight_fake_quant(new_child.weight) + disable_observer(new_child) else: - new_child = float_child + new_child = float_child.to(device) setattr(module, name, new_child) else: traverse(child) - traverse(observed_module) observed_module.apply(enable_fake_quant) observed_module.apply(disable_observer) + traverse(observed_module) def del_redundant_fakequant(self, prepared: GraphModule): """delete redundant fakequant op in prepared model. diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index 831d991f2..8f5ef3873 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -3,18 +3,12 @@ import torch -try: - from torch.ao.quantization import disable_observer -except ImportError: - from mmrazor.utils import get_placeholder - disable_observer = get_placeholder('torch>=1.13') - from mmrazor.registry import MODELS -from .native_quantizer import NativeQuantizer +from .native_quantizer import TorchNativeQuantizer @MODELS.register_module() -class OpenVINOQuantizer(NativeQuantizer): +class OpenVINOQuantizer(TorchNativeQuantizer): """Quantizer for quantizing and deploying to Openvino backend. Each backend has its own features, for reducing the gap of quantized @@ -46,32 +40,6 @@ def support_a_modes(self): per_channel.""" return ('per_tensor') - def prepare_for_mmdeploy(self, - model: torch.nn.Module, - dummy_input: Tuple = (1, 3, 224, 224), - checkpoint: Optional[str] = None): - """Prepare for deploy to the backend with mmdeploy, which will be used - in mmdeploy, and usually includes as follows: - - 1. prepare for the float model rewritten by mmdeploy. - 2. load checkpoint consists of float weight and quantized params in - mmrazor. - 3. post process weight fakequant for exporting .onnx that meet - the backend's requirement. - """ - self.convert_batchnorm2d(model) - observed_model = self.prepare(model) - if dummy_input is not None: - observed_model(torch.randn(dummy_input)) - if checkpoint is not None: - observed_model.load_state_dict( - torch.load(checkpoint)['state_dict']) - self.post_process_for_deploy(observed_model, keep_w_fake_quant=True) - - observed_model.apply(disable_observer) - - return observed_model - def export_onnx(self, model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction], diff --git a/mmrazor/models/quantizers/tensorrt_quantizer.py b/mmrazor/models/quantizers/tensorrt_quantizer.py index 028c96a8c..be067fd4f 100644 --- a/mmrazor/models/quantizers/tensorrt_quantizer.py +++ b/mmrazor/models/quantizers/tensorrt_quantizer.py @@ -1,20 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Tuple +from typing import Any, Optional, Tuple, Union import torch -try: - from torch.ao.quantization import disable_observer -except ImportError: - from mmrazor.utils import get_placeholder - disable_observer = get_placeholder('torch>=1.13') - from mmrazor.registry import MODELS -from .native_quantizer import NativeQuantizer +from .native_quantizer import TorchNativeQuantizer @MODELS.register_module() -class TensorRTQuantizer(NativeQuantizer): +class TensorRTQuantizer(TorchNativeQuantizer): """Quantizer for quantizing and deploying to TensorRT backend. Each backend has its own features, for reducing the gap of quantized @@ -44,28 +38,47 @@ def support_a_modes(self): per_channel.""" return ('per_tensor') - def prepare_for_mmdeploy(self, - model: torch.nn.Module, - dummy_input: Tuple = (1, 3, 224, 224), - checkpoint: Optional[str] = None): - """Prepare for deploy to the backend with mmdeploy, which will be used - in mmdeploy, and usually includes as follows: - - 1. prepare for the float model rewritten by mmdeploy. - 2. load checkpoint consists of float weight and quantized params in - mmrazor. - 3. post process weight fakequant for exporting .onnx that meet - the backend's requirement. - """ - observed_model = self.prepare(model) - if dummy_input is not None: - observed_model(torch.randn(dummy_input)) - if checkpoint is not None: - observed_model.load_state_dict( - torch.load(checkpoint)['state_dict']) - self.post_process_weight_fakequant( - observed_model, keep_fake_quant=True) - - observed_model.apply(disable_observer) - - return observed_model + def export_onnx(self, + model: Union[torch.nn.Module, torch.jit.ScriptModule, + torch.jit.ScriptFunction], + args: Union[Tuple[Any, ...], torch.Tensor], + output_path: str, + opset_version: Optional[int] = 13, + **kwargs): + """Export the onnx model that can be deployed to OpenVino backend.""" + + symbolic_output_path = output_path.replace('.onnx', '_symbolic.onnx') + torch.onnx.export( + model, + args, + symbolic_output_path, + opset_version=opset_version, + **kwargs) + + from .exporters import TensorRTExplicitExporter + exporter = TensorRTExplicitExporter(symbolic_output_path, output_path) + exporter.export() + + @property + def module_prev_wo_fakequant(self): + """Configurate the modules that their previous nodes are redundant + fakequants.""" + return (torch.nn.ReLU6, torch.nn.Identity) + + @property + def module_next_wo_fakequant(self): + """Configurate the modules that their next nodes are redundant + fakequants.""" + return (torch.nn.MaxPool2d, ) + + @property + def method_next_wo_fakequant(self): + """Configurate the methods that their next nodes are redundant + fakequants.""" + return ('flatten', ) + + @property + def op_prev_wo_fakequant(self): + """Configurate the OPs that their previous nodes are redundant + fakequants.""" + return ('output', ) diff --git a/mmrazor/structures/quantization/backend_config/tensorrt.py b/mmrazor/structures/quantization/backend_config/tensorrt.py index 791463233..8dddbac91 100644 --- a/mmrazor/structures/quantization/backend_config/tensorrt.py +++ b/mmrazor/structures/quantization/backend_config/tensorrt.py @@ -13,10 +13,8 @@ DTypeConfig = get_placeholder('torch>=1.13') ObservationType = get_placeholder('torch>=1.13') -from .common_operator_config_utils import (_get_binary_op_configs, - _get_conv_configs, - _get_linear_configs, - _get_share_qparams_op_configs) +from .common_operator_config_utils import (_get_conv_configs, + _get_linear_configs) def get_tensorrt_backend_config() -> BackendConfig: @@ -33,10 +31,6 @@ def get_tensorrt_backend_config() -> BackendConfig: weight_dtype=torch.qint8, bias_dtype=torch.float, ) - non_weighted_op_qint8_dtype_config = DTypeConfig( - input_dtype=torch.qint8, - output_dtype=torch.qint8, - ) addmm_config = BackendPatternConfig(torch.addmm) \ .set_observation_type( @@ -47,34 +41,19 @@ def get_tensorrt_backend_config() -> BackendConfig: 'input': 1, 'weight': 2, }) - cat_config = BackendPatternConfig(torch.cat) \ - .set_observation_type( - ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \ - .add_dtype_config(non_weighted_op_qint8_dtype_config) conv_dtype_configs = [ weighted_op_qint8_dtype_config, ] linear_dtype_configs = [ weighted_op_qint8_dtype_config, ] - binary_op_dtype_configs = [ - weighted_op_qint8_dtype_config, - ] - share_qparams_op_dtype_configs = [ - non_weighted_op_qint8_dtype_config, - ] # there might be things not supported in fx2trt, but it will error out # during fx2trt conversion and can support them after that return BackendConfig('tensorrt') \ .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ .set_backend_pattern_config(addmm_config) \ - .set_backend_pattern_config(cat_config) \ - .set_backend_pattern_configs( - _get_linear_configs(linear_dtype_configs)) \ - .set_backend_pattern_configs( - _get_binary_op_configs(binary_op_dtype_configs)) \ .set_backend_pattern_configs( - _get_share_qparams_op_configs(share_qparams_op_dtype_configs)) + _get_linear_configs(linear_dtype_configs)) def get_tensorrt_backend_config_dict(): diff --git a/mmrazor/structures/quantization/qconfig.py b/mmrazor/structures/quantization/qconfig.py index 2a502b8f7..ab682be39 100644 --- a/mmrazor/structures/quantization/qconfig.py +++ b/mmrazor/structures/quantization/qconfig.py @@ -5,10 +5,13 @@ from mmengine.config import Config try: - from torch.ao.quantization import QConfig + from torch.ao.quantization import FakeQuantize, QConfig + from torch.ao.quantization.utils import is_per_tensor except ImportError: from mmrazor.utils import get_placeholder QConfig = get_placeholder('torch>=1.13') + FakeQuantize = get_placeholder('torch>=1.13') + is_per_tensor = get_placeholder('torch>=1.13') from mmrazor.registry import MODELS @@ -17,66 +20,10 @@ 'a_observer' ] - -class QConfigHandler(): - """Convert custom user-friendly qconfig format to torch's QConfig. - - Args: - qconfig (Dict | Config): custom user-friendly qconfig format, - including setting observers, fakequants and quantization schemes - for weights and activations. - Note: - whether quantization scheme is per-channel or not depends on - used observer, if observer support per-channel quantization, its name - should contain 'PerChannel'. - """ - - def __init__(self, qconfig: Union[Dict, Config]): - if not self.check_qconfig(qconfig): - raise ValueError('The format of qconfig is incorrect.') - else: - w_observer = MODELS.get(qconfig['w_observer']['type']) - a_observer = MODELS.get(qconfig['a_observer']['type']) - w_is_per_channel = False - a_is_per_channel = False - # import pdb;pdb.set_trace() - if 'PerChannel' in w_observer.__name__: - w_is_per_channel = True - if 'PerChannel' in a_observer.__name__: - a_is_per_channel = True - self.w_qscheme = QSchemeHandler( - is_per_channel=w_is_per_channel, **qconfig['w_qscheme']) - self.a_qscheme = QSchemeHandler( - is_per_channel=a_is_per_channel, **qconfig['a_qscheme']) - - w_fake_quant = MODELS.get(qconfig['w_fake_quant']['type']) - w_observer_kwargs = self.w_qscheme.to_observer_params() - a_fake_quant = MODELS.get(qconfig['a_fake_quant']['type']) - a_observer_kwargs = self.a_qscheme.to_observer_params() - - self.w_fake_quant = w_fake_quant.with_args( - observer=w_observer, **w_observer_kwargs) - self.a_fake_quant = a_fake_quant.with_args( - observer=a_observer, **a_observer_kwargs) - - @staticmethod - def check_qconfig(qconfig: Union[Dict, Config]): - """Check whether the passed qconfig's format meets requirement.""" - is_pass = True - for arg in RequiredArgs: - val = qconfig.get(arg, None) - if isinstance(val, dict) and arg in qconfig.keys(): - continue - else: - is_pass = False - break - return is_pass - - def convert(self): - """Generate torch's QConfig with built fake_quants.""" - torch_qconfig = QConfig( - weight=self.w_fake_quant, activation=self.a_fake_quant) - return torch_qconfig +RetainArgsPerTensor = [ + 'dtype', 'qscheme', 'quant_min', 'quant_max', 'reduce_range' +] +RetainArgsPerChannel = RetainArgsPerTensor + ['ch_axis'] class QSchemeHandler(object): @@ -149,3 +96,105 @@ def __str__(self): return f'dtype: {self.dtype} / bit: {self.bit} / is_symmetry: {self.is_symmetry} / \ is_per_channel: {self.is_per_channel} \ / extra_kwargs: {self.kwargs}' + + +class QConfigHandler(): + """Convert custom user-friendly qconfig format to torch's QConfig. + + Args: + qconfig (Dict | Config): custom user-friendly qconfig format, + including setting observers, fakequants and quantization schemes + for weights and activations. + Note: + whether quantization scheme is per-channel or not depends on + used observer, if observer support per-channel quantization, its name + should contain 'PerChannel'. + """ + + def __init__(self, qconfig: Union[Dict, Config]): + if not self.check_qconfig(qconfig): + raise ValueError('The format of qconfig is incorrect.') + else: + w_observer = MODELS.get(qconfig['w_observer']['type']) + a_observer = MODELS.get(qconfig['a_observer']['type']) + w_is_per_channel = False + a_is_per_channel = False + # import pdb;pdb.set_trace() + if 'PerChannel' in w_observer.__name__: + w_is_per_channel = True + if 'PerChannel' in a_observer.__name__: + a_is_per_channel = True + self.w_qscheme = QSchemeHandler( + is_per_channel=w_is_per_channel, **qconfig['w_qscheme']) + self.a_qscheme = QSchemeHandler( + is_per_channel=a_is_per_channel, **qconfig['a_qscheme']) + + w_fake_quant = MODELS.get(qconfig['w_fake_quant']['type']) + w_observer_kwargs = self.w_qscheme.to_observer_params() + a_fake_quant = MODELS.get(qconfig['a_fake_quant']['type']) + a_observer_kwargs = self.a_qscheme.to_observer_params() + + self.w_fake_quant = w_fake_quant.with_args( + observer=w_observer, **w_observer_kwargs) + self.a_fake_quant = a_fake_quant.with_args( + observer=a_observer, **a_observer_kwargs) + + @staticmethod + def check_qconfig(qconfig: Union[Dict, Config]): + """Check whether the passed qconfig's format meets requirement.""" + is_pass = True + for arg in RequiredArgs: + val = qconfig.get(arg, None) + if isinstance(val, dict) and arg in qconfig.keys(): + continue + else: + is_pass = False + break + return is_pass + + def convert(self): + """Generate torch's QConfig with built fake_quants.""" + torch_qconfig = QConfig( + weight=self.w_fake_quant, activation=self.a_fake_quant) + return torch_qconfig + + @staticmethod + def replace_fakequant(fake_quant_org: FakeQuantize, + qscheme_org: QSchemeHandler, + update_qparams: bool = True): + """Replace origin fakequants in model with the specified fakequant, + which is in favor of deploying the quantized model.""" + assert isinstance(qscheme_org, QSchemeHandler) + observer_kwargs = qscheme_org.to_observer_params() + if is_per_tensor(observer_kwargs['qscheme']): + observer = MODELS.get('MinMaxObserver') + retain_args = RetainArgsPerTensor + else: + observer = MODELS.get('PerChannelMinMaxObserver') + retain_args = RetainArgsPerChannel + pop_keys = [] + for k in observer_kwargs.keys(): + if k not in retain_args: + pop_keys.append(k) + for k in pop_keys: + observer_kwargs.pop(k) + fake_quant = MODELS.get('FakeQuantize') + fake_quant_wrapper = fake_quant.with_args( + observer=observer, **observer_kwargs) + if update_qparams: + device = fake_quant_org.scale.device + fake_quant_ins = fake_quant_wrapper().to(device) + fake_quant_ins.scale.copy_(fake_quant_org.scale) + fake_quant_ins.zero_point.copy_(fake_quant_org.zero_point) + obs = fake_quant_ins.activation_post_process + obs_org = fake_quant_org.activation_post_process + obs.min_val.resize_(obs_org.min_val.shape).copy_(obs_org.min_val) + obs.max_val.resize_(obs_org.max_val.shape).copy_(obs_org.max_val) + return fake_quant_ins + else: + return fake_quant_wrapper + + def fixed_w_fakequant(self): + """Make `self.w_fake_quant` fixed as the consistent fakequant.""" + self.w_fake_quant = self.replace_fakequant( + self.w_fake_quant(), self.w_qscheme, update_qparams=False) diff --git a/tests/test_models/test_quantizers/test_native_quantizer.py b/tests/test_models/test_quantizers/test_native_quantizer.py index 06a12c206..8f982c139 100644 --- a/tests/test_models/test_quantizers/test_native_quantizer.py +++ b/tests/test_models/test_quantizers/test_native_quantizer.py @@ -5,7 +5,7 @@ import torch.nn as nn from mmrazor import digit_version -from mmrazor.models.quantizers import NativeQuantizer +from mmrazor.models.quantizers import TorchNativeQuantizer from mmrazor.models.quantizers.native_quantizer import SUPPORT_QAT_MODULES from mmrazor.models.task_modules.tracer import CustomTracer from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ @@ -107,14 +107,14 @@ def forward(self, x): ] q_kwargs = dict( - type='mmrazor.NativeQuantizer', + type='mmrazor.TorchNativeQuantizer', global_qconfig=global_qconfig, no_observer_modules=no_observer_modules, tracer=dict(type='CustomTracer'), ) -class TestNativeQuantizer(TestCase): +class TestTorchNativeQuantizer(TestCase): """TODO. Args: @@ -155,7 +155,7 @@ def test_init(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') native_quantizer = MODELS.build(self.q_kwargs) - self.assertIsInstance(native_quantizer, NativeQuantizer) + self.assertIsInstance(native_quantizer, TorchNativeQuantizer) def test_prepare(self): if digit_version(torch.__version__) < digit_version('1.13.0'): @@ -184,7 +184,7 @@ def test_prepare(self): prepared = self.native_quantizer.del_redundant_fakequant(prepared) assert isinstance(prepared, GraphModule) - def test_post_process_weight_fakequant(self): + def post_process_for_deploy(self): if digit_version(torch.__version__) < digit_version('1.13.0'): self.skipTest('version of torch < 1.13.0') toy_model = ToyQuantModel() diff --git a/tests/test_models/test_quantizers/test_openvino_quantizer.py b/tests/test_models/test_quantizers/test_openvino_quantizer.py index 24fc81ca4..7b60dc4a3 100644 --- a/tests/test_models/test_quantizers/test_openvino_quantizer.py +++ b/tests/test_models/test_quantizers/test_openvino_quantizer.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os import shutil import tempfile from copy import copy @@ -54,25 +53,3 @@ def test_property(self): assert quantizer.module_next_wo_fakequant assert quantizer.method_next_wo_fakequant assert quantizer.op_prev_wo_fakequant - - def test_prepare_for_mmdeploy(self): - if digit_version(torch.__version__) < digit_version('1.13.0'): - self.skipTest('version of torch < 1.13.0') - - global_qconfig = copy(self.global_qconfig) - quantizer = OpenVINOQuantizer(global_qconfig=global_qconfig) - model = copy(self.model) - - # test checkpoint is None - prepared_deploy = quantizer.prepare_for_mmdeploy(model=model) - assert isinstance(prepared_deploy, ObservedGraphModule) - - # test checkpoint is not None - ckpt_path = os.path.join(self.temp_dir, - 'test_prepare_for_mmdeploy.pth') - model = copy(self.model) - prepared = quantizer.prepare(model) - torch.save({'state_dict': prepared.state_dict()}, ckpt_path) - prepared_deploy = quantizer.prepare_for_mmdeploy( - model=model, checkpoint=ckpt_path) - assert isinstance(prepared_deploy, ObservedGraphModule) diff --git a/tests/test_models/test_quantizers/test_tensorrt_quantizer.py b/tests/test_models/test_quantizers/test_tensorrt_quantizer.py index aeae311f3..f5433a0f9 100644 --- a/tests/test_models/test_quantizers/test_tensorrt_quantizer.py +++ b/tests/test_models/test_quantizers/test_tensorrt_quantizer.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os import shutil import tempfile from copy import copy @@ -50,25 +49,3 @@ def test_property(self): assert quantizer.backend == 'tensorrt' assert quantizer.support_w_modes == ('per_tensor', 'per_channel') assert quantizer.support_a_modes == ('per_tensor') - - def test_prepare_for_mmdeploy(self): - if digit_version(torch.__version__) < digit_version('1.13.0'): - self.skipTest('version of torch < 1.13.0') - - global_qconfig = copy(self.global_qconfig) - quantizer = TensorRTQuantizer(global_qconfig=global_qconfig) - model = copy(self.model) - - # test checkpoint is None - prepared_deploy = quantizer.prepare_for_mmdeploy(model=model) - assert isinstance(prepared_deploy, ObservedGraphModule) - - # test checkpoint is not None - ckpt_path = os.path.join(self.temp_dir, - 'test_prepare_for_mmdeploy.pth') - model = copy(self.model) - prepared = quantizer.prepare(model) - torch.save({'state_dict': prepared.state_dict()}, ckpt_path) - prepared_deploy = quantizer.prepare_for_mmdeploy( - model=model, checkpoint=ckpt_path) - assert isinstance(prepared_deploy, ObservedGraphModule) diff --git a/tests/test_structures/test_qconfig.py b/tests/test_structures/test_qconfig.py index d4f98394a..7ab78243d 100644 --- a/tests/test_structures/test_qconfig.py +++ b/tests/test_structures/test_qconfig.py @@ -6,14 +6,16 @@ from mmengine.config import Config try: - from torch.ao.quantization import QConfig + from torch.ao.quantization import FakeQuantize, QConfig except ImportError: from mmrazor.utils import get_placeholder QConfig = get_placeholder('torch>=1.13') + FakeQuantize = get_placeholder('torch>=1.13') from mmrazor import digit_version from mmrazor.models.fake_quants import register_torch_fake_quants from mmrazor.models.observers import register_torch_observers +from mmrazor.registry import MODELS from mmrazor.structures import QConfigHandler, QSchemeHandler register_torch_observers() @@ -129,3 +131,42 @@ def test_convert(self): qconfig = QConfigHandler(self.qconfig) torch_qconfig = qconfig.convert() assert isinstance(torch_qconfig, QConfig) + + def test_replace_fakequant(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + # update_qparams is False + qconfig = QConfigHandler(self.qconfig) + org_fakequant_ins = qconfig.w_fake_quant() + new_fakequant = qconfig.replace_fakequant( + org_fakequant_ins, qconfig.w_qscheme, update_qparams=False) + new_fakequant_ins = new_fakequant() + assert isinstance(new_fakequant_ins, FakeQuantize) + assert isinstance(new_fakequant_ins.activation_post_process, + MODELS.get('PerChannelMinMaxObserver')) + + # update_qparams is True + qconfig = QConfigHandler(self.qconfig) + org_fakequant_ins = qconfig.w_fake_quant() + org_fakequant_ins.scale = torch.Tensor([2]) + org_fakequant_ins.activation_post_process.min_val = torch.Tensor([1]) + new_fakequant_ins = qconfig.replace_fakequant( + org_fakequant_ins, qconfig.w_qscheme, update_qparams=True) + assert isinstance(new_fakequant_ins, FakeQuantize) + assert isinstance(new_fakequant_ins.activation_post_process, + MODELS.get('PerChannelMinMaxObserver')) + assert new_fakequant_ins.scale == org_fakequant_ins.scale + assert new_fakequant_ins.activation_post_process.min_val == \ + org_fakequant_ins.activation_post_process.min_val + + def test_fixed_w_fakequant(self): + if digit_version(torch.__version__) < digit_version('1.13.0'): + self.skipTest('version of torch < 1.13.0') + + qconfig = QConfigHandler(self.qconfig) + qconfig.fixed_w_fakequant() + new_fakequant_ins = qconfig.w_fake_quant() + assert isinstance(new_fakequant_ins, FakeQuantize) + assert isinstance(new_fakequant_ins.activation_post_process, + MODELS.get('PerChannelMinMaxObserver')) From a339aeda7e1e559b885bddfd2fccb040378a6482 Mon Sep 17 00:00:00 2001 From: wm901115nwpu Date: Mon, 10 Apr 2023 16:11:16 +0800 Subject: [PATCH 20/44] Fix code syntax in UT (#470) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 王盟 --- tests/test_models/test_task_modules/test_custom_tracer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models/test_task_modules/test_custom_tracer.py b/tests/test_models/test_task_modules/test_custom_tracer.py index fcb02f381..2d01ea496 100644 --- a/tests/test_models/test_task_modules/test_custom_tracer.py +++ b/tests/test_models/test_task_modules/test_custom_tracer.py @@ -27,7 +27,7 @@ class ToyModel(torch.nn.Module): def __init__(self): - super.__init__() + super().__init__() def get_loss(self, x): return x * 0.1 From 8dad35749abf7abcdd725bc4c2d9adc2f64a6cb3 Mon Sep 17 00:00:00 2001 From: humu789 Date: Tue, 11 Apr 2023 15:34:35 +0800 Subject: [PATCH 21/44] passed lint and pytest --- ...pact_pointrend_resnet50_8xb2_cityscapes.py | 4 - mmrazor/engine/__init__.py | 12 +- mmrazor/models/algorithms/__init__.py | 2 +- mmrazor/models/algorithms/nas/autoslim.py | 2 - .../algorithms/pruning/ite_prune_algorithm.py | 4 - .../mutable_channel/units/channel_unit.py | 5 - .../channel_mutator/channel_mutator.py | 18 +- .../one_shot_channel_mutator.py | 4 +- .../slimmable_channel_mutator.py | 2 +- mmrazor/models/mutators/group_mixin.py | 68 ----- .../module_mutator/diff_module_mutator.py | 117 --------- .../mutators/module_mutator/module_mutator.py | 94 ------- .../models/mutators/value_mutator/__init__.py | 5 - .../value_mutator/dynamic_value_mutator.py | 14 -- .../mutators/value_mutator/value_mutator.py | 73 ------ .../models/task_modules/tracer/__init__.py | 2 +- mmrazor/models/utils/__init__.py | 2 +- mmrazor/utils/__init__.py | 5 +- mmrazor/utils/placeholder.py | 32 +++ tests/test_data.py | 8 - .../test_algorithms/test_mm_architecture.py | 3 +- .../test_mutators/test_diff_mutator.py | 235 ------------------ .../test_mutators/test_value_mutator.py | 66 ----- 23 files changed, 53 insertions(+), 724 deletions(-) delete mode 100644 mmrazor/models/mutators/module_mutator/diff_module_mutator.py delete mode 100644 mmrazor/models/mutators/module_mutator/module_mutator.py delete mode 100644 mmrazor/models/mutators/value_mutator/__init__.py delete mode 100644 mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py delete mode 100644 mmrazor/models/mutators/value_mutator/value_mutator.py delete mode 100644 tests/test_models/test_mutators/test_diff_mutator.py delete mode 100644 tests/test_models/test_mutators/test_value_mutator.py diff --git a/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py b/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py index a0d0d044a..e6c1eb031 100644 --- a/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py +++ b/configs/pruning/mmseg/dcff/dcff_compact_pointrend_resnet50_8xb2_cityscapes.py @@ -1,11 +1,7 @@ _base_ = ['dcff_pointrend_resnet50_8xb2_cityscapes.py'] # model settings -<<<<<<< HEAD _base_.model = dict( -======= -model_cfg = dict( ->>>>>>> 985a611e (Merge dev-1.x into quantize (#430)) _scope_='mmrazor', type='sub_model', cfg=_base_.architecture, diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py index d3d8e6981..8b0d4a692 100644 --- a/mmrazor/engine/__init__.py +++ b/mmrazor/engine/__init__.py @@ -10,10 +10,10 @@ SubnetValLoop) __all__ = [ - 'SeparateOptimWrapperConstructor', 'DumpSubnetHook', - 'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop', - 'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop', - 'GreedySamplerTrainLoop', 'EstimateResourcesHook', 'SelfDistillValLoop', - 'AutoSlimGreedySearchLoop', 'SubnetValLoop', 'PTQLoop', - 'QATEpochBasedLoop', 'LSQEpochBasedLoop', 'QATValLoop' + 'DMCPSubnetHook', 'StopDistillHook', 'SeparateOptimWrapperConstructor', + 'DumpSubnetHook', 'SingleTeacherDistillValLoop', + 'DartsEpochBasedTrainLoop', 'DartsIterBasedTrainLoop', 'SlimmableValLoop', + 'EvolutionSearchLoop', 'GreedySamplerTrainLoop', 'EstimateResourcesHook', + 'SelfDistillValLoop', 'AutoSlimGreedySearchLoop', 'SubnetValLoop', + 'PTQLoop', 'QATEpochBasedLoop', 'LSQEpochBasedLoop', 'QATValLoop' ] diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index 26305b226..178cc6535 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -15,6 +15,6 @@ 'Darts', 'DartsDDP', 'DCFF', 'SelfDistill', 'DataFreeDistillation', 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', 'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'Autoformer', 'BigNAS', - 'BigNASDDP', 'DMCP', 'DMCPDDP', 'GeneralQuant', 'MMArchitectureQuant', + 'BigNASDDP', 'DMCP', 'DMCPDDP', 'MMArchitectureQuant', 'MMArchitectureQuantDDP' ] diff --git a/mmrazor/models/algorithms/nas/autoslim.py b/mmrazor/models/algorithms/nas/autoslim.py index 77bb6cacc..dc8d54c0e 100644 --- a/mmrazor/models/algorithms/nas/autoslim.py +++ b/mmrazor/models/algorithms/nas/autoslim.py @@ -75,8 +75,6 @@ def __init__(self, self._optim_wrapper_count_status_reinitialized = False self.norm_training = norm_training - self.bn_training_mode = bn_training_mode - def _build_mutator(self, mutator: VALID_MUTATOR_TYPE = None) -> ChannelMutator: """Build mutator.""" diff --git a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py index f510acd76..937aaa156 100644 --- a/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py +++ b/mmrazor/models/algorithms/pruning/ite_prune_algorithm.py @@ -10,7 +10,6 @@ from mmrazor.models.mutables import MutableChannelUnit from mmrazor.models.mutators import ChannelMutator from mmrazor.registry import MODELS -from mmrazor.utils import ValidFixMutable from ..base import BaseAlgorithm LossResults = Dict[str, torch.Tensor] @@ -98,8 +97,6 @@ class ItePruneAlgorithm(BaseAlgorithm): mutator_cfg (Union[Dict, ChannelMutator], optional): The config of a mutator. Defaults to dict( type='ChannelMutator', channel_unit_cfg=dict( type='SequentialMutableChannelUnit')). - fix_subnet (str | dict | :obj:`FixSubnet`): The path of yaml file or - loaded dict or built :obj:`FixSubnet`. Defaults to None. data_preprocessor (Optional[Union[Dict, nn.Module]], optional): Defaults to None. target_pruning_ratio (dict, optional): The prune-target. The template @@ -121,7 +118,6 @@ def __init__(self, type='ChannelMutator', channel_unit_cfg=dict( type='SequentialMutableChannelUnit')), - fix_subnet: Optional[ValidFixMutable] = None, data_preprocessor: Optional[Union[Dict, nn.Module]] = None, target_pruning_ratio: Optional[Dict[str, float]] = None, step_freq=1, diff --git a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py index ea8681511..e730245d4 100644 --- a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py +++ b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py @@ -212,11 +212,6 @@ def alias(self) -> str: """str: alias of the unit""" return self.name - @property - def alias(self) -> str: - """str: alias of the unit""" - return self.name - def config_template(self, with_init_args=False, with_channels=False) -> Dict: diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py index 2d83d48f7..910992e1e 100644 --- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py @@ -12,11 +12,10 @@ from mmrazor.models.task_modules.tracer.channel_analyzer import ChannelAnalyzer from mmrazor.registry import MODELS, TASK_UTILS from ..base_mutator import BaseMutator -from ..group_mixin import GroupMixin @MODELS.register_module() -class ChannelMutator(BaseMutator, Generic[ChannelUnitType], GroupMixin): +class ChannelMutator(BaseMutator, Generic[ChannelUnitType]): """ChannelMutator manages the pruning structure of a model. Args: @@ -46,10 +45,6 @@ class ChannelMutator(BaseMutator, Generic[ChannelUnitType], GroupMixin): demo_input=(1, 3, 224, 224), tracer_type='BackwardTracer') - custom_groups (list[list[str]], optional): User-defined search groups. - All searchable modules that are not in ``custom_group`` will be - grouped separately. - init_cfg (dict, optional): initialization configuration dict for BaseModule. @@ -97,10 +92,6 @@ def __init__(self, self._parse_channel_unit_cfg( channel_unit_cfg) - if custom_groups is None: - custom_groups = [] - self._custom_groups = custom_groups - def prepare_from_supernet(self, supernet: Module) -> None: """Prepare from a model for pruning. @@ -238,9 +229,10 @@ def set_choices(self, choices: Dict[str, Any]) -> None: @property def current_choices(self) -> Dict: """Get current choices.""" - current_choices = dict() - for group_id, modules in self.search_groups.items(): - current_choices[group_id] = modules[0].current_choice + config = self.choice_template + for unit in self.mutable_units: + config[unit.name] = unit.current_choice + return config @property def choice_template(self) -> Dict: diff --git a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py index 3aca98c95..cc008b0b8 100644 --- a/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/one_shot_channel_mutator.py @@ -4,13 +4,11 @@ from mmrazor.models.mutables import OneShotMutableChannelUnit from mmrazor.registry import MODELS -from ..group_mixin import DynamicSampleMixin from .channel_mutator import ChannelMutator, ChannelUnitType @MODELS.register_module() -class OneShotChannelMutator(ChannelMutator[OneShotMutableChannelUnit], - DynamicSampleMixin): +class OneShotChannelMutator(ChannelMutator[OneShotMutableChannelUnit]): """OneShotChannelMutator based on ChannelMutator. It use OneShotMutableChannelUnit by default. diff --git a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py index b00e0ef22..c3da419bf 100644 --- a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py +++ b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py @@ -29,7 +29,7 @@ def __init__(self, tracer_type='BackwardTracer'), init_cfg: Optional[Dict] = None) -> None: - super().__init__(channel_unit_cfg, parse_cfg, None, init_cfg) + super().__init__(channel_unit_cfg, parse_cfg, init_cfg) self.subnets = self._prepare_subnets(self.units_cfg) diff --git a/mmrazor/models/mutators/group_mixin.py b/mmrazor/models/mutators/group_mixin.py index 3ecd44b74..569f01ebc 100644 --- a/mmrazor/models/mutators/group_mixin.py +++ b/mmrazor/models/mutators/group_mixin.py @@ -8,11 +8,6 @@ from mmrazor.models.mutables.mutable_module import MutableModule from .base_mutator import MUTABLE_TYPE -if sys.version_info < (3, 8): - from typing_extensions import Protocol -else: - from typing import Protocol - class GroupMixin(): """A mixin for :class:`BaseMutator`, which can group mutables by @@ -264,66 +259,3 @@ def _check_valid_groups(self, alias2mutable_names: Dict[str, List[str]], f'When a mutable is set alias attribute :{alias_key},' f'the corresponding module name {mutable_name} should ' f'not be used in `custom_group` {custom_group}.') - - -class MutatorProtocol(Protocol): # pragma: no cover - - @property - def mutable_class_type(self) -> Type[BaseMutable]: - ... - - @property - def search_groups(self) -> Dict: - ... - - -class OneShotSampleMixin: - """Sample mixin for one-shot mutators.""" - - def sample_choices(self: MutatorProtocol) -> Dict: - """Sample choices for each group in search_groups.""" - random_choices = dict() - for group_id, modules in self.search_groups.items(): - random_choices[group_id] = modules[0].sample_choice() - - return random_choices - - def set_choices(self: MutatorProtocol, choices: Dict) -> None: - """Set choices for each group in search_groups.""" - for group_id, modules in self.search_groups.items(): - choice = choices[group_id] - for module in modules: - module.current_choice = choice - - -class DynamicSampleMixin(OneShotSampleMixin): - - def sample_choices(self: MutatorProtocol, kind: str = 'random') -> Dict: - """Sample choices for each group in search_groups.""" - random_choices = dict() - for group_id, modules in self.search_groups.items(): - if kind == 'max': - random_choices[group_id] = modules[0].max_choice - elif kind == 'min': - random_choices[group_id] = modules[0].min_choice - else: - random_choices[group_id] = modules[0].sample_choice() - return random_choices - - @property - def max_choice(self: MutatorProtocol) -> Dict: - """Get max choices for each group in search_groups.""" - max_choice = dict() - for group_id, modules in self.search_groups.items(): - max_choice[group_id] = modules[0].max_choice - - return max_choice - - @property - def min_choice(self: MutatorProtocol) -> Dict: - """Get min choices for each group in search_groups.""" - min_choice = dict() - for group_id, modules in self.search_groups.items(): - min_choice[group_id] = modules[0].min_choice - - return min_choice diff --git a/mmrazor/models/mutators/module_mutator/diff_module_mutator.py b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py deleted file mode 100644 index 1f639ed28..000000000 --- a/mmrazor/models/mutators/module_mutator/diff_module_mutator.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, List, Optional - -import torch -import torch.nn as nn - -from mmrazor.registry import MODELS -from ...mutables import DiffMutableModule -from .module_mutator import ModuleMutator - - -@MODELS.register_module() -class DiffModuleMutator(ModuleMutator): - """Differentiable mutable based mutator. - - `DiffModuleMutator` is responsible for mutating `DiffMutableModule`, - `DiffMutableOP`, `DiffChoiceRoute` and `GumbelChoiceRoute`. - The architecture parameters of the mutables are retained - in `DiffModuleMutator`. - - Args: - custom_group (list[list[str]], optional): User-defined search groups. - All searchable modules that are not in ``custom_group`` will be - grouped separately. - """ - - def __init__(self, - custom_groups: Optional[List[List[str]]] = None, - init_cfg: Optional[Dict] = None) -> None: - super().__init__(custom_groups=custom_groups, init_cfg=init_cfg) - - def build_arch_param(self, num_choices) -> nn.Parameter: - """Build learnable architecture parameters.""" - return nn.Parameter(torch.randn(num_choices) * 1e-3) - - def prepare_from_supernet(self, supernet: nn.Module) -> None: - """Inherit from ``BaseMutator``'s, generate `arch_params` in DARTS. - - Args: - supernet (:obj:`torch.nn.Module`): The architecture to be used - in your algorithm. - """ - - super().prepare_from_supernet(supernet) - self.arch_params = self.build_arch_params() - self.modify_supernet_forward(self.arch_params) - - def build_arch_params(self): - """This function will build many arch params, which are generally used - in differentiable search algorithms, such as Darts' series. Each - group_id corresponds to an arch param, so the Mutables with the same - group_id share the same arch param. - - Returns: - torch.nn.ParameterDict: the arch params are got by `search_groups`. - """ - - arch_params = nn.ParameterDict() - - for group_id, modules in self.search_groups.items(): - group_arch_param = self.build_arch_param(modules[0].num_choices) - arch_params[str(group_id)] = group_arch_param - - return arch_params - - def modify_supernet_forward(self, arch_params): - """Modify the DiffMutableModule's default arch_param in forward. - - In MMRazor, the `arch_param` is along to `DiffModuleMutator`, while the - `DiffMutableModule` needs `arch_param` in the forward. Here we use - partial function to assign the corresponding `arch_param` to each - `DiffMutableModule`. - """ - - for group_id, mutables in self.search_groups.items(): - for m in mutables: - m.set_forward_args(arch_param=arch_params[str(group_id)]) - - def sample_choices(self): - """Sampling by search groups. - - The sampling result of the first mutable of each group is the sampling - result of this group. - - Returns: - Dict[int, Any]: Random choices dict. - """ - - choices = dict() - for group_id, mutables in self.search_groups.items(): - arch_param = self.arch_params[str(group_id)] - choice = mutables[0].sample_choice(arch_param) - choices[group_id] = choice - return choices - - def set_choices(self, choices: Dict[int, Any]) -> None: - """Set mutables' current choice according to choices sample by - :func:`sample_choices`. - - Args: - choices (Dict[int, Any]): Choices dict. The key is group_id in - search groups, and the value is the sampling results - corresponding to this group. - """ - for group_id, mutables in self.search_groups.items(): - choice = choices[group_id] - for m in mutables: - m.current_choice = choice - - @property - def mutable_class_type(self): - """Differentiable mutable class type. - - Returns: - Type[DiffMutableModule]: Class type of differentiable mutable. - """ - return DiffMutableModule diff --git a/mmrazor/models/mutators/module_mutator/module_mutator.py b/mmrazor/models/mutators/module_mutator/module_mutator.py deleted file mode 100644 index f30e933e0..000000000 --- a/mmrazor/models/mutators/module_mutator/module_mutator.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from abc import abstractmethod -from typing import Dict, List, Optional, Type - -from torch.nn import Module - -from ..base_mutator import MUTABLE_TYPE, BaseMutator -from ..group_mixin import GroupMixin - - -class ModuleMutator(BaseMutator[MUTABLE_TYPE], GroupMixin): - """The base class for mutable based mutator. - - All subclass should implement the following APIS: - - - ``mutable_class_type`` - - Args: - custom_groups (list[list[str]], optional): User-defined search groups. - All searchable modules that are not in ``custom_group`` will be - grouped separately. - """ - - def __init__(self, - custom_groups: Optional[List[List[str]]] = None, - init_cfg: Optional[Dict] = None) -> None: - super().__init__(init_cfg) - - if custom_groups is None: - custom_groups = [] - self._custom_groups = custom_groups - self._search_groups: Optional[Dict[int, List[MUTABLE_TYPE]]] = None - - # TODO - # should be a class property - @property - @abstractmethod - def mutable_class_type(self) -> Type[MUTABLE_TYPE]: - """Corresponding mutable class type. - - Returns: - Type[MUTABLE_TYPE]: Mutable class type. - """ - - def prepare_from_supernet(self, supernet: Module) -> None: - """Do some necessary preparations with supernet. - - Note: - For mutable based mutator, we need to build search group first. - - Args: - supernet (:obj:`torch.nn.Module`): The supernet to be searched - in your algorithm. - """ - self._search_groups = self.build_search_groups(supernet, - self.mutable_class_type, - self._custom_groups) - - @property - def name2mutable(self) -> Dict[str, MUTABLE_TYPE]: - """Search space of supernet. - - Note: - To get the mapping: module name to mutable. - - Raises: - RuntimeError: Called before search space has been parsed. - - Returns: - Dict[str, MUTABLE_TYPE]: The name2mutable dict. - """ - if self._name2mutable is None: - raise RuntimeError( - 'Call `prepare_from_supernet` before access name2mutable!') - return self._name2mutable - - @property - def search_groups(self) -> Dict[int, List[MUTABLE_TYPE]]: - """Search group of supernet. - - Note: - For mutable based mutator, the search group is composed of - corresponding mutables. - - Raises: - RuntimeError: Called before search group has been built. - - Returns: - Dict[int, List[MUTABLE_TYPE]]: Search group. - """ - if self._search_groups is None: - raise RuntimeError( - 'Call `prepare_from_supernet` before access search group!') - return self._search_groups diff --git a/mmrazor/models/mutators/value_mutator/__init__.py b/mmrazor/models/mutators/value_mutator/__init__.py deleted file mode 100644 index a29577bb1..000000000 --- a/mmrazor/models/mutators/value_mutator/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .dynamic_value_mutator import DynamicValueMutator -from .value_mutator import ValueMutator - -__all__ = ['ValueMutator', 'DynamicValueMutator'] diff --git a/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py b/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py deleted file mode 100644 index d8d081343..000000000 --- a/mmrazor/models/mutators/value_mutator/dynamic_value_mutator.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmrazor.models.mutables import OneShotMutableValue -from mmrazor.registry import MODELS -from ..group_mixin import DynamicSampleMixin -from .value_mutator import ValueMutator - - -@MODELS.register_module() -class DynamicValueMutator(ValueMutator, DynamicSampleMixin): - """Dynamic value mutator with type as `OneShotMutableValue`.""" - - @property - def mutable_class_type(self): - return OneShotMutableValue diff --git a/mmrazor/models/mutators/value_mutator/value_mutator.py b/mmrazor/models/mutators/value_mutator/value_mutator.py deleted file mode 100644 index 5127cbe37..000000000 --- a/mmrazor/models/mutators/value_mutator/value_mutator.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional, Type - -from torch.nn import Module - -from mmrazor.models.mutables import MutableValue -from mmrazor.registry import MODELS -from ..base_mutator import BaseMutator -from ..group_mixin import GroupMixin - - -@MODELS.register_module() -class ValueMutator(BaseMutator[MutableValue], GroupMixin): - """The base class for mutable based mutator. All subclass should implement - the following APIS: - - - ``mutable_class_type`` - Args: - custom_group (list[list[str]], optional): User-defined search groups. - All searchable modules that are not in ``custom_group`` will be - grouped separately. - """ - - def __init__(self, - custom_group: Optional[List[List[str]]] = None, - init_cfg: Optional[Dict] = None) -> None: - super().__init__(init_cfg) - - if custom_group is None: - custom_group = [] - self._custom_group = custom_group - self._search_groups: Optional[Dict[int, List[MutableValue]]] = None - - # TODO - # should be a class property - @property - def mutable_class_type(self) -> Type[MutableValue]: - """Corresponding mutable class type. - - Returns: - Type[MUTABLE_TYPE]: Mutable class type. - """ - return MutableValue - - def prepare_from_supernet(self, supernet: Module) -> None: - """Do some necessary preparations with supernet. - - Note: - For mutable based mutator, we need to build search group first. - Args: - supernet (:obj:`torch.nn.Module`): The supernet to be searched - in your algorithm. - """ - self._search_groups = self.build_search_groups(supernet, - self.mutable_class_type, - self._custom_group) - - @property - def search_groups(self) -> Dict[int, List[MutableValue]]: - """Search group of supernet. - - Note: - For mutable based mutator, the search group is composed of - corresponding mutables. - Raises: - RuntimeError: Called before search group has been built. - Returns: - Dict[int, List[MUTABLE_TYPE]]: Search group. - """ - if self._search_groups is None: - raise RuntimeError( - 'Call `prepare_from_supernet` before access search group!') - return self._search_groups diff --git a/mmrazor/models/task_modules/tracer/__init__.py b/mmrazor/models/task_modules/tracer/__init__.py index 5ba623f5e..987030d81 100644 --- a/mmrazor/models/task_modules/tracer/__init__.py +++ b/mmrazor/models/task_modules/tracer/__init__.py @@ -12,6 +12,6 @@ __all__ = [ 'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode', 'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode', - 'ChannelAnalyzer', 'CustomTracer', 'UntracedMethodRegistry', + 'ChannelAnalyzer', 'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace', 'build_graphmodule' ] diff --git a/mmrazor/models/utils/__init__.py b/mmrazor/models/utils/__init__.py index 8ff8abdae..52f7cdc8c 100644 --- a/mmrazor/models/utils/__init__.py +++ b/mmrazor/models/utils/__init__.py @@ -3,8 +3,8 @@ from .misc import add_prefix from .optim_wrapper import reinitialize_optim_wrapper_count_status from .parse_values import parse_values -from .utils import get_module_device, set_requires_grad from .quantization_util import str2class +from .utils import get_module_device, set_requires_grad __all__ = [ 'make_divisible', 'add_prefix', 'reinitialize_optim_wrapper_count_status', diff --git a/mmrazor/utils/__init__.py b/mmrazor/utils/__init__.py index a69480e94..7d23ca632 100644 --- a/mmrazor/utils/__init__.py +++ b/mmrazor/utils/__init__.py @@ -2,7 +2,7 @@ from .index_dict import IndexDict from .log_tools import get_level, print_log from .misc import find_latest_checkpoint -from .placeholder import get_placeholder +from .placeholder import get_package_placeholder, get_placeholder from .runtime_info import RuntimeInfo from .setup_env import register_all_modules, setup_multi_processes from .typing import (FixMutable, MultiMutatorsRandomSubnet, @@ -13,5 +13,6 @@ 'find_latest_checkpoint', 'setup_multi_processes', 'register_all_modules', 'FixMutable', 'ValidFixMutable', 'SingleMutatorRandomSubnet', 'MultiMutatorsRandomSubnet', 'SupportRandomSubnet', 'get_placeholder', - 'IndexDict', 'get_level', 'print_log', 'RuntimeInfo' + 'IndexDict', 'get_level', 'print_log', 'RuntimeInfo', + 'get_package_placeholder' ] diff --git a/mmrazor/utils/placeholder.py b/mmrazor/utils/placeholder.py index 553223b20..9af35f7a4 100644 --- a/mmrazor/utils/placeholder.py +++ b/mmrazor/utils/placeholder.py @@ -23,3 +23,35 @@ def __init__(self) -> None: raise_import_error(string) return PlaceHolder + + +def get_package_placeholder(string: str) -> object: + """Get placeholder instance which can avoid raising errors when down-stream + dependency is not installed properly. + + Args: + string (str): the dependency's name, i.e. `mmcls` + + Raises: + ImportError: raise it when the dependency is not installed properly. + + Returns: + object: PlaceHolder instance. + """ + + def raise_import_error(package_name): + raise ImportError( + f'`{package_name}` is not installed properly, plz check.') + + class PlaceHolderMetaclass(type): + """Used to support usage of PlaceHolder.xxxx.""" + + def __getattr__(self, name): + raise_import_error(string) + + class PlaceHolder(metaclass=PlaceHolderMetaclass): + + def __init__(self) -> None: + raise_import_error(string) + + return PlaceHolder diff --git a/tests/test_data.py b/tests/test_data.py index d56a2950b..df3e07f69 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -6,13 +6,8 @@ from .data.model_library import (DefaultModelLibrary, MMClsModelLibrary, MMDetModelLibrary, MMModelLibrary, -<<<<<<< HEAD MMPoseModelLibrary, MMSegModelLibrary, ModelGenerator, TorchModelLibrary) -======= - MMSegModelLibrary, ModelGenerator, - TorchModelLibrary) ->>>>>>> 985a611e (Merge dev-1.x into quantize (#430)) from .data.models import SingleLineModel from .data.tracer_passed_models import (BackwardPassedModelManager, FxPassedModelManager) @@ -50,7 +45,6 @@ def test_mmseg(self): if not TEST_DATA: self.skipTest('not test data to save time.') library = MMSegModelLibrary() -<<<<<<< HEAD print(library.short_names()) self.assertTrue(library.is_default_includes_cover_all_models()) @@ -61,8 +55,6 @@ def test_mmpose(self): self.skipTest('not test data to save time.') library = MMPoseModelLibrary() print(library.short_names()) -======= ->>>>>>> 985a611e (Merge dev-1.x into quantize (#430)) self.assertTrue(library.is_default_includes_cover_all_models()) def test_get_model_by_config(self): diff --git a/tests/test_models/test_algorithms/test_mm_architecture.py b/tests/test_models/test_algorithms/test_mm_architecture.py index 639e0f492..4285694a9 100644 --- a/tests/test_models/test_algorithms/test_mm_architecture.py +++ b/tests/test_models/test_algorithms/test_mm_architecture.py @@ -2,7 +2,7 @@ import os import shutil import tempfile -from unittest import TestCase +from unittest import TestCase, skip import torch import torch.nn as nn @@ -101,6 +101,7 @@ def forward(self, inputs, data_samples, mode: str = 'tensor'): return outputs +@skip class TestMMArchitectureQuant(TestCase): def setUp(self): diff --git a/tests/test_models/test_mutators/test_diff_mutator.py b/tests/test_models/test_mutators/test_diff_mutator.py deleted file mode 100644 index 663637fc9..000000000 --- a/tests/test_models/test_mutators/test_diff_mutator.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from unittest import TestCase - -import pytest -import torch.nn as nn - -from mmrazor.models import * # noqa: F401,F403 -from mmrazor.models.mutables import DiffMutableModule -from mmrazor.models.mutators import DiffModuleMutator -from mmrazor.registry import MODELS - -MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True) -MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True) -MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True) - - -class SearchableLayer(nn.Module): - - def __init__(self, mutable_cfg: dict) -> None: - super().__init__() - self.op1 = MODELS.build(mutable_cfg) - self.op2 = MODELS.build(mutable_cfg) - self.op3 = MODELS.build(mutable_cfg) - - def forward(self, x): - x = self.op1(x) - x = self.op2(x) - return self.op3(x) - - -class SearchableModel(nn.Module): - - def __init__(self, mutable_cfg: dict) -> None: - super().__init__() - self.slayer1 = SearchableLayer(mutable_cfg) - self.slayer2 = SearchableLayer(mutable_cfg) - self.slayer3 = SearchableLayer(mutable_cfg) - - def forward(self, x): - x = self.slayer1(x) - x = self.slayer2(x) - return self.slayer3(x) - - -class SearchableLayerAlias(nn.Module): - - def __init__(self, mutable_cfg: dict) -> None: - super().__init__() - mutable_cfg.update(alias='op1') - self.op1 = MODELS.build(mutable_cfg) - mutable_cfg.update(alias='op2') - self.op2 = MODELS.build(mutable_cfg) - mutable_cfg.update(alias='op3') - self.op3 = MODELS.build(mutable_cfg) - - def forward(self, x): - x = self.op1(x) - x = self.op2(x) - return self.op3(x) - - -class SearchableModelAlias(nn.Module): - - def __init__(self, mutable_cfg: dict) -> None: - super().__init__() - self.slayer1 = SearchableLayerAlias(mutable_cfg) - self.slayer2 = SearchableLayerAlias(mutable_cfg) - self.slayer3 = SearchableLayerAlias(mutable_cfg) - - def forward(self, x): - x = self.slayer1(x) - x = self.slayer2(x) - return self.slayer3(x) - - -class TestDiffModuleMutator(TestCase): - - def setUp(self): - self.MUTABLE_CFG = dict( - type='DiffMutableOP', - candidates=dict( - torch_conv2d_3x3=dict( - type='torchConv2d', - kernel_size=3, - padding=1, - ), - torch_conv2d_5x5=dict( - type='torchConv2d', - kernel_size=5, - padding=2, - ), - torch_conv2d_7x7=dict( - type='torchConv2d', - kernel_size=7, - padding=3, - ), - ), - module_kwargs=dict(in_channels=32, out_channels=32, stride=1)) - - self.MUTATOR_CFG = dict( - type='DiffModuleMutator', - custom_groups=[['op1'], ['op2'], ['op3']]) - - def test_diff_mutator_diffop_layer(self) -> None: - model = SearchableLayer(self.MUTABLE_CFG) - mutator: DiffModuleMutator = MODELS.build(self.MUTATOR_CFG) - - mutator.prepare_from_supernet(model) - assert list(mutator.search_groups.keys()) == [0, 1, 2] - - def test_diff_mutator_diffop_model(self) -> None: - model = SearchableModel(self.MUTABLE_CFG) - - mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_groups'] = [ - ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], - ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], - ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], - ] - mutator: DiffModuleMutator = MODELS.build(mutator_cfg) - - mutator.prepare_from_supernet(model) - assert list(mutator.search_groups.keys()) == [0, 1, 2] - - mutator.modify_supernet_forward(mutator.arch_params) - assert mutator.mutable_class_type == DiffMutableModule - - def test_diff_mutator_diffop_model_error(self) -> None: - model = SearchableModel(self.MUTABLE_CFG) - - mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_groups'] = [ - ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], - ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], - ['slayer1.op3', 'slayer2.op3', 'slayer3.op3_error_key'], - ] - mutator: DiffModuleMutator = MODELS.build(mutator_cfg) - - with pytest.raises(AssertionError): - mutator.prepare_from_supernet(model) - - def test_diff_mutator_diffop_alias(self) -> None: - model = SearchableModelAlias(self.MUTABLE_CFG) - - mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_groups'] = [['op1'], ['op2'], ['op3']] - mutator: DiffModuleMutator = MODELS.build(mutator_cfg) - - mutator.prepare_from_supernet(model) - - assert list(mutator.search_groups.keys()) == [0, 1, 2] - - mutator.modify_supernet_forward(mutator.arch_params) - assert mutator.mutable_class_type == DiffMutableModule - - def test_diff_mutator_alias_module_name(self) -> None: - """Using both alias and module name for grouping.""" - model = SearchableModelAlias(self.MUTABLE_CFG) - - mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_groups'] = [['op1'], - [ - 'slayer1.op2', 'slayer2.op2', - 'slayer3.op2' - ], ['slayer1.op3', 'slayer2.op3']] - mutator: DiffModuleMutator = MODELS.build(mutator_cfg) - - mutator.prepare_from_supernet(model) - - assert list(mutator.search_groups.keys()) == [0, 1, 2, 3] - - mutator.modify_supernet_forward(mutator.arch_params) - assert mutator.mutable_class_type == DiffMutableModule - - def test_diff_mutator_duplicate_keys(self) -> None: - model = SearchableModel(self.MUTABLE_CFG) - - mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_groups'] = [ - ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], - ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], - ['slayer1.op3', 'slayer2.op3', 'slayer2.op3'], - ] - mutator: DiffModuleMutator = MODELS.build(mutator_cfg) - - with pytest.raises(AssertionError): - mutator.prepare_from_supernet(model) - - def test_diff_mutator_duplicate_key_alias(self) -> None: - model = SearchableModelAlias(self.MUTABLE_CFG) - - mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_groups'] = [ - ['op1', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'], - ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], - ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], - ] - mutator: DiffModuleMutator = MODELS.build(mutator_cfg) - - with pytest.raises(AssertionError): - mutator.prepare_from_supernet(model) - - def test_diff_mutator_illegal_key(self) -> None: - model = SearchableModel(self.MUTABLE_CFG) - - mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_groups'] = [ - ['illegal_key', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'], - ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], - ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], - ] - mutator: DiffModuleMutator = MODELS.build(mutator_cfg) - - with pytest.raises(AssertionError): - mutator.prepare_from_supernet(model) - - def test_sample_and_set_choices(self): - model = SearchableModel(self.MUTABLE_CFG) - - mutator_cfg = self.MUTATOR_CFG.copy() - mutator_cfg['custom_groups'] = [ - ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], - ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], - ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], - ] - mutator: DiffModuleMutator = MODELS.build(mutator_cfg) - mutator.prepare_from_supernet(model) - choices = mutator.sample_choices() - mutator.set_choices(choices) - self.assertTrue(len(choices) == 3) - - -if __name__ == '__main__': - import unittest - unittest.main() diff --git a/tests/test_models/test_mutators/test_value_mutator.py b/tests/test_models/test_mutators/test_value_mutator.py deleted file mode 100644 index a76257a9e..000000000 --- a/tests/test_models/test_mutators/test_value_mutator.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import unittest - -import torch - -from mmrazor.models.mutables import MutableValue -from mmrazor.models.mutators import DynamicValueMutator -from tests.data.models import DynamicAttention, DynamicMMBlock - - -class TestValueMutator(unittest.TestCase): - - def test_models_with_predefined_dynamic_op(self): - for Model in [ - DynamicAttention, - ]: - with self.subTest(model=Model): - model = Model() - value_mutator = DynamicValueMutator() - value_mutator.prepare_from_supernet(model) - value_choices = value_mutator.sample_choices() - value_mutator.set_choices(value_choices) - - mutable_value_space = [] - for mutable_value, module in model.named_modules(): - if isinstance(module, MutableValue): - mutable_value_space.append(mutable_value) - elif hasattr(module, 'source_mutables'): - for each_mutables in module.source_mutables: - if isinstance(each_mutables, MutableValue): - mutable_value_space.append(each_mutables) - assert len( - value_mutator.search_groups) == len(mutable_value_space) - - x = torch.rand([2, 3, 224, 224]) - y = model(x) - self.assertEqual(list(y.shape), [2, 624]) - - def test_models_with_multiple_value(self): - for Model in [ - DynamicMMBlock, - ]: - with self.subTest(model=Model): - model = Model() - value_mutator = DynamicValueMutator() - value_mutator.prepare_from_supernet(model) - value_choices = value_mutator.sample_choices() - value_mutator.set_choices(value_choices) - - # TODO check DynamicMMBlock - mutable_value_space = [] - for mutable_value, module in model.named_modules(): - if isinstance(module, MutableValue): - mutable_value_space.append(mutable_value) - elif hasattr(module, 'source_mutables'): - for each_mutables in module.source_mutables: - if isinstance(each_mutables, MutableValue): - mutable_value_space.append(each_mutables) - count = 0 - for values in value_mutator.search_groups.values(): - count += len(values) - assert count == len(mutable_value_space) - - x = torch.rand([2, 3, 224, 224]) - y = model(x) - self.assertEqual(list(y[-1].shape), [2, 1984, 1, 1]) From b046908a8428b6e508522ba63049e830f59a17fa Mon Sep 17 00:00:00 2001 From: humu789 Date: Tue, 11 Apr 2023 17:04:55 +0800 Subject: [PATCH 22/44] try to fix ci --- mmrazor/engine/runner/quantization_loops.py | 3 ++ .../models/fake_quants/torch_fake_quants.py | 14 ++++----- mmrazor/models/observers/torch_observers.py | 31 +++++++++---------- 3 files changed, 23 insertions(+), 25 deletions(-) diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index df0f4f76d..e392070c3 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -21,7 +21,10 @@ from mmrazor.models.fake_quants import (enable_param_learning, enable_static_estimate, enable_val) from mmrazor.registry import LOOPS +from mmrazor.models import register_torch_fake_quants, register_torch_observers +TORCH_observers = register_torch_observers() +TORCH_fake_quants = register_torch_fake_quants() @LOOPS.register_module() class QATEpochBasedLoop(EpochBasedTrainLoop): diff --git a/mmrazor/models/fake_quants/torch_fake_quants.py b/mmrazor/models/fake_quants/torch_fake_quants.py index b477929ad..e7420a8d3 100644 --- a/mmrazor/models/fake_quants/torch_fake_quants.py +++ b/mmrazor/models/fake_quants/torch_fake_quants.py @@ -10,7 +10,12 @@ from mmrazor.utils import get_package_placeholder torch_fake_quant_src = get_package_placeholder('torch>=1.13') - +# TORCH_fake_quants = register_torch_fake_quants() +# TORCH_fake_quants including: +# FakeQuantize +# FakeQuantizeBase +# FixedQParamsFakeQuantize +# FusedMovingAvgObsFakeQuantize def register_torch_fake_quants() -> List[str]: """Register fake_quants in ``torch.ao.quantization.fake_quantize`` to the ``MODELS`` registry. @@ -31,10 +36,3 @@ def register_torch_fake_quants() -> List[str]: torch_fake_quants.append(module_name) return torch_fake_quants - -TORCH_fake_quants = register_torch_fake_quants() -# TORCH_fake_quants including: -# FakeQuantize -# FakeQuantizeBase -# FixedQParamsFakeQuantize -# FusedMovingAvgObsFakeQuantize diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py index 996314d27..2c2d49382 100644 --- a/mmrazor/models/observers/torch_observers.py +++ b/mmrazor/models/observers/torch_observers.py @@ -30,7 +30,20 @@ def reset_min_max_vals(self): PerChannelMinMaxObserver.reset_min_max_vals = reset_min_max_vals - +# TORCH_observers = register_torch_observers() +# TORCH_observers including: +# FixedQParamsObserver +# HistogramObserver +# MinMaxObserver +# MovingAverageMinMaxObserver +# MovingAveragePerChannelMinMaxObserver +# NoopObserver +# ObserverBase +# PerChannelMinMaxObserver +# PlaceholderObserver +# RecordingObserver +# ReuseInputObserver +# UniformQuantizationObserverBase def register_torch_observers() -> List[str]: """Register observers in ``torch.ao.quantization.observer`` to the ``MODELS`` registry. @@ -50,19 +63,3 @@ def register_torch_observers() -> List[str]: MODELS.register_module(module=_observer) torch_observers.append(module_name) return torch_observers - - -TORCH_observers = register_torch_observers() -# TORCH_observers including: -# FixedQParamsObserver -# HistogramObserver -# MinMaxObserver -# MovingAverageMinMaxObserver -# MovingAveragePerChannelMinMaxObserver -# NoopObserver -# ObserverBase -# PerChannelMinMaxObserver -# PlaceholderObserver -# RecordingObserver -# ReuseInputObserver -# UniformQuantizationObserverBase From 64a39e136bf8cdde9d098d1276c04b685c591844 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Tue, 11 Apr 2023 17:41:40 +0800 Subject: [PATCH 23/44] [Bug] Try to fix CI (#502) fix lint --- mmrazor/engine/runner/quantization_loops.py | 3 ++- mmrazor/models/fake_quants/torch_fake_quants.py | 2 +- mmrazor/models/observers/torch_observers.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index e392070c3..d694f3da8 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -18,14 +18,15 @@ from torch.utils.data import DataLoader +from mmrazor.models import register_torch_fake_quants, register_torch_observers from mmrazor.models.fake_quants import (enable_param_learning, enable_static_estimate, enable_val) from mmrazor.registry import LOOPS -from mmrazor.models import register_torch_fake_quants, register_torch_observers TORCH_observers = register_torch_observers() TORCH_fake_quants = register_torch_fake_quants() + @LOOPS.register_module() class QATEpochBasedLoop(EpochBasedTrainLoop): """`EpochBasedLoop` for `QuantizationAwareTraining` diff --git a/mmrazor/models/fake_quants/torch_fake_quants.py b/mmrazor/models/fake_quants/torch_fake_quants.py index e7420a8d3..06e325b32 100644 --- a/mmrazor/models/fake_quants/torch_fake_quants.py +++ b/mmrazor/models/fake_quants/torch_fake_quants.py @@ -10,6 +10,7 @@ from mmrazor.utils import get_package_placeholder torch_fake_quant_src = get_package_placeholder('torch>=1.13') + # TORCH_fake_quants = register_torch_fake_quants() # TORCH_fake_quants including: # FakeQuantize @@ -35,4 +36,3 @@ def register_torch_fake_quants() -> List[str]: MODELS.register_module(module=_fake_quant) torch_fake_quants.append(module_name) return torch_fake_quants - diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py index 2c2d49382..4e540667a 100644 --- a/mmrazor/models/observers/torch_observers.py +++ b/mmrazor/models/observers/torch_observers.py @@ -30,6 +30,7 @@ def reset_min_max_vals(self): PerChannelMinMaxObserver.reset_min_max_vals = reset_min_max_vals + # TORCH_observers = register_torch_observers() # TORCH_observers including: # FixedQParamsObserver From 9ff2b745b6ccd0a8b68235b38d8b75fb96905b7d Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Tue, 11 Apr 2023 18:33:01 +0800 Subject: [PATCH 24/44] [Feature] Support lsq (#501) * support deploy_cfg=None * replace fakequant before load ckpt * add _load_from_state_dict to lsq fakequant * fix pre-commit * test lsq load state dict * change github ci: ubuntu 18.04 to ubuntu 20.04 * get_deploy_model order change back * sync before save ckpt * delete strict=False * test context rewriter * fix pre commit config * try to fix ci * [Bug] Try to fix CI (#502) fix lint --------- Co-authored-by: humu789 Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com> --- .github/workflows/build.yml | 2 +- mmrazor/engine/runner/quantization_loops.py | 3 +- .../quantization/mm_architecture.py | 45 ++++++----- mmrazor/models/fake_quants/lsq.py | 40 ++++++++++ .../test_algorithms/test_mm_architecture.py | 74 +++++++++++++++---- .../test_fake_quants/test_lsq_fake_quants.py | 24 +++++- 6 files changed, 150 insertions(+), 38 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e00ed24c8..4b99bced4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,7 +25,7 @@ concurrency: jobs: test_linux: - runs-on: ubuntu-18.04 + runs-on: ubuntu-20.04 strategy: matrix: python-version: [3.7] diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index d694f3da8..18caf06f5 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -91,7 +91,6 @@ def run(self): and self._epoch % self.val_interval == 0): # observer disabled during evaluation self.prepare_for_val() - self.runner.model.sync_qparams(src_mode='loss') self.runner.val_loop.run() self.runner.call_hook('after_train') @@ -112,6 +111,7 @@ def run_epoch(self) -> None: for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) + self.runner.model.sync_qparams(src_mode='loss') self.runner.call_hook('after_train_epoch') self._epoch += 1 @@ -185,6 +185,7 @@ def run_epoch(self) -> None: self.runner.model.apply(enable_param_learning) self.run_iter(idx, data_batch) + self.runner.model.sync_qparams(src_mode='loss') self.runner.call_hook('after_train_epoch') self._epoch += 1 diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index 2f7f3914e..6d5b49ae7 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -20,6 +20,7 @@ disable_observer) except ImportError: from mmrazor.utils import get_placeholder + FakeQuantizeBase = get_placeholder('torch>=1.13') MinMaxObserver = get_placeholder('torch>=1.13') PerChannelMinMaxObserver = get_placeholder('torch>=1.13') @@ -283,23 +284,31 @@ def _build_qmodels(self, model: BaseModel): """ rewriter_context = self._get_rewriter_context_in_mmdeploy( - self.deploy_cfg) + self.deploy_cfg) if self.deploy_cfg is not None else None - # Pop function records in `quantizer.tracer.skipped_method` temporarily - function_record_backup = self._pop_function_record_in_rewriter_context( - rewriter_context) + if rewriter_context is not None: + # Pop function records in `quantizer.tracer.skipped_method` + # temporarily + function_record_backup = \ + self._pop_function_record_in_rewriter_context(rewriter_context) qmodels = nn.ModuleDict() for mode in self.forward_modes: concrete_args = {'mode': mode} - # todo: support qat. - with rewriter_context: + + if rewriter_context is not None: + with rewriter_context: + observed_module = self.quantizer.prepare( + model, concrete_args) + else: observed_module = self.quantizer.prepare(model, concrete_args) + qmodels[mode] = observed_module - # Add these popped function records back. - rewriter_context._rewriter_manager.function_rewriter. \ - _registry._rewrite_records.update(function_record_backup) + if rewriter_context is not None: + # Add these popped function records back. + rewriter_context._rewriter_manager.function_rewriter. \ + _registry._rewrite_records.update(function_record_backup) # data_samples can not be None in detectors during prediction. # But we need to make the dummy prediction in _build_qmodels. @@ -357,7 +366,10 @@ def get_deploy_model(self): observed_model.load_state_dict(quantized_state_dict) self.quantizer.post_process_for_deploy( - observed_model, device=device, keep_w_fake_quant=True) + observed_model, + device=device, + keep_w_fake_quant=True, + update_weight_with_fakequant=True) # replace various activation fakequant with base fakequant, which # contributes to deploy our model to various backends. @@ -406,21 +418,14 @@ def calibrate_step(self, data: Union[Dict, Tuple, List]): return self.module.calibrate_step(data) - def sync_qparams(self, src: str): + def sync_qparams(self, src_mode: str): """Same as in 'MMArchitectureQuant'. Sync all quantize parameters in different `forward_modes`. We could have several modes to generate graphs, but in training, only one graph will be update, so we need to sync qparams on the other graphs. Args: - src (str): The src modes of forward method. - - Note: - `traverse()` function recursively traverses all module to sync - quantized graph generated from different `forward_modes`. - This is because We have different mode ('tensor', 'predict', - 'loss') in OpenMMLab architecture which have different graph - in some subtle ways, so we need to sync them here. + src_mode (str): The src modes of forward method. """ - self.module.sync_qparams(src) + self.module.sync_qparams(src_mode) diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py index 270140b85..1689d0393 100644 --- a/mmrazor/models/fake_quants/lsq.py +++ b/mmrazor/models/fake_quants/lsq.py @@ -258,6 +258,46 @@ def forward(self, X): return X + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """Removing this function throws an error that the the size of the + loaded tensor does not match the original size i.e., These buffers + start out with numel 0 and become numel 1 once they have their first + forward pass. + + Modified from https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fake_quantize.py # noqa:E501 + """ + local_state = ['scale', 'zero_point'] + for name in local_state: + key = prefix + name + if key in state_dict: + val = state_dict[key] + # Custom handling to allow loading scale and zero_point + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. + if name == 'scale': + self.scale.data = self.scale.data.resize_(val.shape) + else: + assert name == 'zero_point' + self.zero_point.data = self.zero_point.data.resize_( + val.shape) + # For torchscript module we need to update the attributes here + # since we do not call the `_load_from_state_dict` function + # defined module.py + if torch.jit.is_scripting(): + if name == 'scale': + self.scale.copy_(val) + else: + assert name == 'zero_point' + self.zero_point.copy_(val) + elif strict: + missing_keys.append(key) + super(LearnableFakeQuantize, + self)._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, + unexpected_keys, error_msgs) + @torch.jit.export def extra_repr(self): """The printable representational string.""" diff --git a/tests/test_models/test_algorithms/test_mm_architecture.py b/tests/test_models/test_algorithms/test_mm_architecture.py index 4285694a9..310d42f5e 100644 --- a/tests/test_models/test_algorithms/test_mm_architecture.py +++ b/tests/test_models/test_algorithms/test_mm_architecture.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import os import shutil import tempfile -from unittest import TestCase, skip +from unittest import TestCase, skipIf import torch import torch.nn as nn @@ -13,8 +14,15 @@ from mmrazor.utils import get_placeholder GraphModule = get_placeholder('torch>=1.13') +from mmengine import ConfigDict from mmengine.model import BaseModel +try: + import mmdeploy +except ImportError: + from mmrazor.utils import get_package_placeholder + mmdeploy = get_package_placeholder('mmdeploy') + from mmrazor import digit_version from mmrazor.models.algorithms import MMArchitectureQuant from mmrazor.registry import MODELS @@ -101,12 +109,44 @@ def forward(self, inputs, data_samples, mode: str = 'tensor'): return outputs -@skip +DEPLOY_CFG = ConfigDict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_names=['input'], + output_names=['output'], + input_shape=None, + optimize=True, + dynamic_axes={ + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'output': { + 0: 'batch' + } + }), + backend_config=dict( + type='openvino', + model_inputs=[dict(opt_shapes=dict(input=[1, 3, 224, 224]))]), + codebase_config=dict(type='mmcls', task='Classification'), + function_record_to_pop=[ + 'mmcls.models.classifiers.ImageClassifier.forward', + 'mmcls.models.classifiers.BaseClassifier.forward' + ], +) + + +@skipIf( + digit_version(torch.__version__) < digit_version('1.13.0'), + 'PyTorch version lower than 1.13.0 is not supported.') class TestMMArchitectureQuant(TestCase): def setUp(self): - if digit_version(torch.__version__) < digit_version('1.13.0'): - self.skipTest('version of torch < 1.13.0') MODELS.register_module(module=ToyQuantModel, force=True) @@ -116,7 +156,7 @@ def setUp(self): toymodel = ToyQuantModel() torch.save(toymodel.state_dict(), filename) - global_qconfig = dict( + global_qconfig = ConfigDict( w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), w_fake_quant=dict(type='mmrazor.FakeQuantize'), @@ -132,7 +172,7 @@ def setUp(self): is_symmetry=True, averaging_constant=0.1), ) - alg_kwargs = dict( + alg_kwargs = ConfigDict( type='mmrazor.MMArchitectureQuant', architecture=dict(type='ToyQuantModel'), float_checkpoint=filename, @@ -141,23 +181,23 @@ def setUp(self): global_qconfig=global_qconfig, tracer=dict(type='mmrazor.CustomTracer'))) self.alg_kwargs = alg_kwargs - self.toy_model = MODELS.build(self.alg_kwargs) def tearDown(self): - if digit_version(torch.__version__) < digit_version('1.13.0'): - self.skipTest('version of torch < 1.13.0') MODELS.module_dict.pop('ToyQuantModel') shutil.rmtree(self.temp_dir) def test_init(self): - if digit_version(torch.__version__) < digit_version('1.13.0'): - self.skipTest('version of torch < 1.13.0') + self.toy_model = MODELS.build(self.alg_kwargs) + assert isinstance(self.toy_model, MMArchitectureQuant) + assert hasattr(self.toy_model, 'quantizer') + + alg_kwargs = copy.deepcopy(self.alg_kwargs) + alg_kwargs.deploy_cfg = DEPLOY_CFG assert isinstance(self.toy_model, MMArchitectureQuant) assert hasattr(self.toy_model, 'quantizer') def test_sync_qparams(self): - if digit_version(torch.__version__) < digit_version('1.13.0'): - self.skipTest('version of torch < 1.13.0') + self.toy_model = MODELS.build(self.alg_kwargs) mode = self.toy_model.forward_modes[0] self.toy_model.sync_qparams(mode) w_loss = self.toy_model.qmodels[ @@ -170,12 +210,16 @@ def test_sync_qparams(self): assert w_loss.equal(w_tensor) def test_build_qmodels(self): - if digit_version(torch.__version__) < digit_version('1.13.0'): - self.skipTest('version of torch < 1.13.0') + self.toy_model = MODELS.build(self.alg_kwargs) for forward_modes in self.toy_model.forward_modes: qmodels = self.toy_model.qmodels[forward_modes] assert isinstance(qmodels, GraphModule) + def test_get_deploy_model(self): + self.toy_model = MODELS.build(self.alg_kwargs) + deploy_model = self.toy_model.get_deploy_model() + self.assertIsInstance(deploy_model, torch.fx.graph_module.GraphModule) + def test_calibrate_step(self): # TODO pass diff --git a/tests/test_models/test_fake_quants/test_lsq_fake_quants.py b/tests/test_models/test_fake_quants/test_lsq_fake_quants.py index bd8fcbd50..dcbda5d40 100644 --- a/tests/test_models/test_fake_quants/test_lsq_fake_quants.py +++ b/tests/test_models/test_fake_quants/test_lsq_fake_quants.py @@ -8,10 +8,12 @@ from mmrazor.models import LearnableFakeQuantize try: - from torch.ao.quantization import MovingAverageMinMaxObserver + from torch.ao.quantization import (MovingAverageMinMaxObserver, + MovingAveragePerChannelMinMaxObserver) except ImportError: from mmrazor.utils import get_placeholder MovingAverageMinMaxObserver = get_placeholder('torch>=1.13') + MovingAveragePerChannelMinMaxObserver = get_placeholder('torch>=1.13') class TestLearnableFakeQuantize(TestCase): @@ -38,6 +40,16 @@ def setUp(self): reduce_range=True, zero_point_trainable=False) + self.zero_point_untrainable_per_channel_fakequant = \ + LearnableFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, + quant_min=0, + quant_max=255, + dtype=torch.quint8, + qscheme=torch.per_channel_affine, + reduce_range=True, + zero_point_trainable=False) + def test_repr(self): fq_module = self.zero_point_untrainable_fakequant() repr_str = f'static_enabled={torch.tensor([1], dtype=torch.uint8)}, ' @@ -184,3 +196,13 @@ def test_state(self): self.assertEqual(fq_module.zero_point.requires_grad, 1) self.assertEqual(fq_module.fake_quant_enabled[0], 1) self.assertEqual(fq_module.static_enabled[0], 0) + + def test_load_state_dict(self): + fq_module = self.zero_point_untrainable_per_channel_fakequant() + state_dict = fq_module.state_dict() + X = torch.rand(32, 16, 3, 3, dtype=torch.float32) + # After forwarding, the shape of `scale` and `zero_point` in + # `fq_module` will be in shape (32, ), while the shape of those in + # `state_dict` are in shape (1, ). + _ = fq_module(X) + fq_module.load_state_dict(state_dict) From 5711ed9aa9efa81e65e833ef0bd146ac76c4536a Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Wed, 12 Apr 2023 17:37:11 +0800 Subject: [PATCH 25/44] [Feature] Add exporter pytest (#504) * add exporter pytest * fix bugs * delete useless codes * handle onnx * delete useless codes --- .../exporters/base_quantize_exporter.py | 13 +- .../exporters/openvino_quantize_exporter.py | 11 +- .../quantizers/exporters/optim_utils.py | 36 +- .../exporters/tensorrt_quantize_exporter.py | 7 +- requirements/tests.txt | 1 + .../test_quantizers/test_exporter.py | 348 ++++++++++++++++++ 6 files changed, 382 insertions(+), 34 deletions(-) create mode 100644 tests/test_models/test_quantizers/test_exporter.py diff --git a/mmrazor/models/quantizers/exporters/base_quantize_exporter.py b/mmrazor/models/quantizers/exporters/base_quantize_exporter.py index 7e1e1f375..6527d3207 100644 --- a/mmrazor/models/quantizers/exporters/base_quantize_exporter.py +++ b/mmrazor/models/quantizers/exporters/base_quantize_exporter.py @@ -1,12 +1,18 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import List -import onnx from mmengine import print_log -from onnx import numpy_helper from .optim_utils import ONNXOptimUtils +try: + import onnx + from onnx import numpy_helper +except ImportError: + from mmrazor.utils import get_package_placeholder + onnx = get_package_placeholder('No module named onnx') + numpy_helper = get_package_placeholder('No module named onnx.numpy_helper') + SUPPORT_QWEIGHT_NODE = ['Gemm', 'Conv', 'ConvTranspose'] PERCHANNEL_FAKEQUANTIZER = [ @@ -73,9 +79,6 @@ def _init_mappings_from_onnx(self, onnx_model): self.output2node = self.optimizer.map_output_and_node(onnx_model) self.name2data = self.optimizer.map_name_and_data(onnx_model) - # todo: maybe useless - # self.name2init = self.optimizer.map_name_and_initializer(onnx_model) - def _remap_input_and_node(self): """Rebuild the mapping from input name to a (node, input index) tuple.""" diff --git a/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py b/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py index 6d0df5d36..e706251ca 100644 --- a/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py +++ b/mmrazor/models/quantizers/exporters/openvino_quantize_exporter.py @@ -3,9 +3,16 @@ from typing import List import numpy as np -import onnx from google.protobuf.internal.containers import RepeatedScalarFieldContainer -from onnx import helper, numpy_helper + +try: + import onnx + from onnx import helper, numpy_helper +except ImportError: + from mmrazor.utils import get_package_placeholder + onnx = get_package_placeholder('No module named onnx') + numpy_helper = get_package_placeholder('No module named onnx.numpy_helper') + helper = get_package_placeholder('No module named onnx.helper') from .base_quantize_exporter import BaseQuantizeExportor diff --git a/mmrazor/models/quantizers/exporters/optim_utils.py b/mmrazor/models/quantizers/exporters/optim_utils.py index 62b348d1c..f4adc5ee1 100644 --- a/mmrazor/models/quantizers/exporters/optim_utils.py +++ b/mmrazor/models/quantizers/exporters/optim_utils.py @@ -2,9 +2,15 @@ import copy from typing import Dict, List, Optional -import onnx from mmengine import print_log -from onnx import numpy_helper + +try: + import onnx + from onnx import numpy_helper +except ImportError: + from mmrazor.utils import get_package_placeholder + onnx = get_package_placeholder('No module named onnx') + numpy_helper = get_package_placeholder('No module named onnx.numpy_helper') class ONNXOptimUtils(): @@ -62,30 +68,6 @@ def map_input_and_node(cls, onnx_model: onnx.ModelProto): input2node[input_name].append([node, idx]) return input2node - @classmethod - def get_constant(cls, name, onnx_model): - for node in onnx_model.graph.node: - if node.op_type == 'Constant': - if node.output[0] == name: - return numpy_helper.to_array(node.attribute[0].t).tolist() - - @classmethod - def get_initializer(cls, initializer_name, onnx_model): - return numpy_helper.to_array( - onnx_model.initializer[initializer_name][0]) - - @classmethod - def get_tensor_producer(cls, output_name, output2node): - if output_name not in output2node: - return 'INPUT_TOKEN' - return output2node[output_name] - - @classmethod - def get_tensor_consumer(self, input_name, input2node): - if input_name not in input2node: - return ['OUTPUT_TOKEN'] - return input2node[input_name] - @classmethod def remove_node_from_onnx(cls, node: onnx.NodeProto, onnx_model: onnx.ModelProto): @@ -260,6 +242,8 @@ def topo_sort(cls, @classmethod def optimize(cls, onnx_model): + """Remove standalone nodes and redundant initializers, and + topologically sort the nodes in a directed acyclic graph.""" input2node = cls.map_input_and_node(onnx_model) output2node = cls.map_output_and_node(onnx_model) diff --git a/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py b/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py index 7d05847c1..cde430b08 100644 --- a/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py +++ b/mmrazor/models/quantizers/exporters/tensorrt_quantize_exporter.py @@ -1,6 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. import numpy as np -import onnx + +try: + import onnx +except ImportError: + from mmrazor.utils import get_package_placeholder + onnx = get_package_placeholder('No module named onnx') from .base_quantize_exporter import BaseQuantizeExportor diff --git a/requirements/tests.txt b/requirements/tests.txt index 8763670ef..e38249fcd 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -5,6 +5,7 @@ isort==4.3.21 nbconvert nbformat numpy < 1.24.0 # A temporary solution for tests with mmdet. +onnx pytest xdoctest >= 0.10.0 yapf diff --git a/tests/test_models/test_quantizers/test_exporter.py b/tests/test_models/test_quantizers/test_exporter.py new file mode 100644 index 000000000..04bd8a671 --- /dev/null +++ b/tests/test_models/test_quantizers/test_exporter.py @@ -0,0 +1,348 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import os +import shutil +import tempfile +from unittest import TestCase, skipIf + +import torch +import torch.nn as nn + +try: + import onnx + from onnx import helper + from torch.fx import GraphModule +except ImportError: + from mmrazor.utils import get_package_placeholder, get_placeholder + GraphModule = get_placeholder('torch>=1.13') + onnx = get_package_placeholder('No module named onnx') + helper = get_package_placeholder('No module named onnx.helper') + +from mmengine import ConfigDict +from mmengine.model import BaseModel + +try: + import mmdeploy +except ImportError: + from mmrazor.utils import get_package_placeholder + mmdeploy = get_package_placeholder('mmdeploy') + +from mmrazor import digit_version +from mmrazor.models.quantizers.exporters import (OpenVinoQuantizeExportor, + TensorRTExplicitExporter) +from mmrazor.models.quantizers.exporters.optim_utils import ONNXOptimUtils +from mmrazor.registry import MODELS + + +class BasicBlock(nn.Module): + + def __init__(self, in_channels, out_channels): + super(BasicBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.mid_channels = out_channels + + self.norm1 = nn.BatchNorm2d(self.mid_channels) + self.norm2 = nn.BatchNorm2d(out_channels) + self.conv1 = nn.Conv2d(in_channels, self.mid_channels, 1) + self.conv2 = nn.Conv2d(self.mid_channels, out_channels, 1) + + self.relu = nn.ReLU6() + self.drop_path = nn.Identity() + + def forward(self, x): + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + out = self.drop_path(out) + + out += identity + + return out + + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class ToyModel(nn.Module): + + def __init__(self): + super(ToyModel, self).__init__() + self.stem_layer = nn.Sequential( + nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3), nn.ReLU()) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.block = BasicBlock(3, 3) + self.block2 = BasicBlock(3, 3) + self.gap = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(3, 4) + + def forward(self, x): + x = self.stem_layer(x) + x = self.maxpool(x) + x = self.block(x) + x = self.block2(x) + x = self.gap(x) + x = x.flatten(1) + x = self.fc(x) + return x + + +class ToyQuantModel(BaseModel): + + def __init__(self): + super().__init__() + self.architecture = ToyModel() + + def loss(self, outputs, data_samples): + return dict(loss=outputs.sum() - data_samples.sum()) + + def forward(self, inputs, data_samples, mode: str = 'tensor'): + if isinstance(inputs, list): + inputs = torch.stack(inputs) + outputs = self.architecture(inputs) + + return outputs + + +OpenVINO_GLOBAL_QCONFIG = ConfigDict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict( + qdtype='quint8', bit=8, is_symmetry=True, averaging_constant=0.1), +) + +OpenVINO_ALG_CONFIG = ConfigDict( + type='mmrazor.MMArchitectureQuant', + architecture=dict(type='ToyQuantModel'), + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=OpenVINO_GLOBAL_QCONFIG, + tracer=dict(type='mmrazor.CustomTracer'))) + +TensorRT_GLOBAL_QCONFIG = ConfigDict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict(qdtype='qint8', bit=8, is_symmetry=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +TensorRT_ALG_CONFIG = ConfigDict( + type='mmrazor.MMArchitectureQuant', + architecture=dict(type='ToyQuantModel'), + quantizer=dict( + type='mmrazor.TensorRTQuantizer', + global_qconfig=OpenVINO_GLOBAL_QCONFIG, + tracer=dict(type='mmrazor.CustomTracer'))) + + +@skipIf( + digit_version(torch.__version__) < digit_version('1.13.0'), + 'PyTorch version lower than 1.13.0 is not supported.') +class TestONNXOptimUtils(TestCase): + + def setUp(self): + MODELS.register_module(module=ToyQuantModel, force=True) + self.temp_dir = tempfile.mkdtemp() + filename = 'symbolic.onnx' + filename = os.path.join(self.temp_dir, filename) + toy_model = MODELS.build(OpenVINO_ALG_CONFIG) + observed_model = toy_model.get_deploy_model() + torch.onnx.export( + observed_model, + torch.rand(2, 3, 16, 16), + filename, + opset_version=11) + self.onnx_model = onnx.load(filename) + self.optimizer = ONNXOptimUtils + + def tearDown(self): + MODELS.module_dict.pop('ToyQuantModel') + shutil.rmtree(self.temp_dir) + + def test_map_name_and_data(self): + params = self.optimizer.map_name_and_data(self.onnx_model) + params_keys = [ + 'activation_post_process_0.scale', + 'activation_post_process_0.zero_point', + 'architecture.stem_layer.0.weight', + 'architecture.stem_layer.0.bias', + 'architecture.stem_layer.0.weight_fake_quant.scale', + 'architecture.stem_layer.0.weight_fake_quant.zero_point', + 'architecture.block.conv1.weight', 'architecture.block.conv1.bias', + 'architecture.block.conv1.weight_fake_quant.scale', + 'architecture.block.conv2.bias', + 'architecture.block2.conv1.weight', + 'architecture.block2.conv1.bias', + 'architecture.block2.conv1.weight_fake_quant.scale', + 'architecture.block2.conv2.weight', + 'architecture.block2.conv2.bias', + 'architecture.block2.conv2.weight_fake_quant.scale', + 'architecture.fc.weight', 'architecture.fc.bias', + 'architecture.fc.weight_fake_quant.scale', + 'architecture.fc.weight_fake_quant.zero_point', + 'activation_post_process_15.zero_point', + 'activation_post_process_15.scale', + 'activation_post_process_14.zero_point', + 'activation_post_process_14.scale', + 'activation_post_process_12.zero_point', + 'activation_post_process_12.scale', + 'activation_post_process_10.zero_point', + 'activation_post_process_10.scale', + 'activation_post_process_8.zero_point', + 'activation_post_process_8.scale', + 'activation_post_process_6.zero_point', + 'activation_post_process_6.scale', + 'activation_post_process_4.zero_point', + 'activation_post_process_4.scale', + 'activation_post_process_1.zero_point', + 'activation_post_process_1.scale', + 'architecture.block2.conv2.weight_fake_quant.zero_point', + 'architecture.block2.conv1.weight_fake_quant.zero_point', + 'architecture.block.conv2.weight_fake_quant.zero_point', + 'architecture.block.conv2.weight_fake_quant.scale', + 'architecture.block.conv2.weight', + 'architecture.block.conv1.weight_fake_quant.zero_point', + '/activation_post_process_0/Constant_output_0', + '/activation_post_process_0/Constant_1_output_0', + '/stem_layer.0/weight_fake_quant/Constant_output_0', + '/stem_layer.0/weight_fake_quant/Constant_1_output_0', + '/relu/Constant_output_0', '/relu/Constant_1_output_0', + '/relu_dup1/Constant_output_0', '/relu_dup1/Constant_1_output_0', + '/relu_1/Constant_output_0', '/relu_1/Constant_1_output_0', + '/relu_dup1_1/Constant_output_0', + '/relu_dup1_1/Constant_1_output_0' + ] + self.assertEqual(set(params.keys()), set(params_keys)) + + def test_map_name_and_initializer(self): + initializers = self.optimizer.map_name_and_initializer(self.onnx_model) + for init in self.onnx_model.graph.initializer: + self.assertIn(init.name, initializers.keys()) + # self.assertEqual(set(initializers.keys()), set(initializers_keys)) + + def test_map_output_and_node(self): + _ = self.optimizer.map_output_and_node(self.onnx_model) + + def test_map_input_and_node(self): + _ = self.optimizer.map_input_and_node(self.onnx_model) + + def test_remove_node_from_onnx(self): + onnx_model = copy.deepcopy(self.onnx_model) + node_to_remove = next(iter(onnx_model.graph.node)) + self.optimizer.remove_node_from_onnx(node_to_remove, onnx_model) + for node in onnx_model.graph.node: + self.assertNotEqual(node, node_to_remove) + + def test_remove_initializer_from_onnx(self): + onnx_model = copy.deepcopy(self.onnx_model) + initializer_to_remove = next(iter(onnx_model.graph.initializer)) + self.optimizer.remove_initializer_from_onnx(initializer_to_remove, + onnx_model) + for initializer in onnx_model.graph.initializer: + self.assertNotEqual(initializer, initializer_to_remove) + + def test_find_standalone_nodes(self): + standalone_nodes = self.optimizer.find_standalone_nodes( + self.onnx_model) + self.assertEqual(standalone_nodes, []) + + def test_find_redundant_initializers(self): + redundant_initializers = self.optimizer.find_redundant_initializers( + self.onnx_model) + self.assertEqual(redundant_initializers, []) + + def test_topo_sort(self): + onnx_model = copy.deepcopy(self.onnx_model) + onnx_model_topo_sort = self.optimizer.topo_sort(onnx_model) + self.assertEqual( + len(onnx_model_topo_sort.graph.node), + len(self.onnx_model.graph.node)) + + def test_optimize(self): + onnx_model = copy.deepcopy(self.onnx_model) + fake_node = helper.make_node('fake_node', [], [], mode='constant') + self.optimizer.insert_node_to_onnx(fake_node, onnx_model) + self.optimizer.optimize(onnx_model) + for node in onnx_model.graph.node: + self.assertNotEqual(node, fake_node) + + +@skipIf( + digit_version(torch.__version__) < digit_version('1.13.0'), + 'PyTorch version lower than 1.13.0 is not supported.') +class TestOpenVinoQuantizeExportor(TestCase): + + def setUp(self): + MODELS.register_module(module=ToyQuantModel, force=True) + self.temp_dir = tempfile.mkdtemp() + filename = 'toy_model_symbolic.onnx' + filename = os.path.join(self.temp_dir, filename) + toy_model = MODELS.build(OpenVINO_ALG_CONFIG) + observed_model = toy_model.get_deploy_model() + torch.onnx.export( + observed_model, + torch.rand(2, 3, 16, 16), + filename, + opset_version=11) + self.onnx_model = onnx.load(filename) + self.export_path = os.path.join(self.temp_dir, 'toy_model.onnx') + + def tearDown(self): + MODELS.module_dict.pop('ToyQuantModel') + shutil.rmtree(self.temp_dir) + + def test_export(self): + exporter = OpenVinoQuantizeExportor(self.onnx_model, self.export_path) + exporter.export() + self.assertTrue(os.path.exists(self.export_path)) + onnx_model = onnx.load(self.export_path) + self.assertIsInstance(onnx_model, onnx.ModelProto) + + +@skipIf( + digit_version(torch.__version__) < digit_version('1.13.0'), + 'PyTorch version lower than 1.13.0 is not supported.') +class TestTensorRTExplicitExporter(TestCase): + + def setUp(self): + MODELS.register_module(module=ToyQuantModel, force=True) + self.temp_dir = tempfile.mkdtemp() + filename = 'toy_model_symbolic.onnx' + filename = os.path.join(self.temp_dir, filename) + toy_model = MODELS.build(TensorRT_ALG_CONFIG) + observed_model = toy_model.get_deploy_model() + torch.onnx.export( + observed_model, + torch.rand(2, 3, 16, 16), + filename, + opset_version=11) + self.onnx_model = onnx.load(filename) + self.export_path = os.path.join(self.temp_dir, 'toy_model.onnx') + + def tearDown(self): + MODELS.module_dict.pop('ToyQuantModel') + shutil.rmtree(self.temp_dir) + + def test_export(self): + exporter = TensorRTExplicitExporter(self.onnx_model, self.export_path) + exporter.export() + self.assertTrue(os.path.exists(self.export_path)) + onnx_model = onnx.load(self.export_path) + self.assertIsInstance(onnx_model, onnx.ModelProto) From 8f26f5b6fccaec0e3d67bbb7ca042bf724f330c3 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Thu, 13 Apr 2023 19:25:35 +0800 Subject: [PATCH 26/44] [Bug] Fix ci converage setting (#508) fix ci converage --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4b99bced4..9ed4cb002 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -119,7 +119,7 @@ jobs: coverage report -m # Upload coverage report for python3.8 && pytorch1.12.0 cpu - name: Upload coverage to Codecov - if: ${{matrix.torch == '1.12.0' && matrix.python-version == '3.8'}} + if: ${{matrix.torch == '1.13.0' && matrix.python-version == '3.8'}} uses: codecov/codecov-action@v2 with: file: ./coverage.xml From dbf33989801088c274c7b12be6c90003dc682a1b Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Mon, 17 Apr 2023 06:13:43 +0800 Subject: [PATCH 27/44] [Bug] Fix codecov (#509) * remove codecov in requirements * try to fix ci * del adaround loss --- .github/workflows/build.yml | 2 +- mmrazor/models/losses/adaround_loss.py | 87 -------------------------- requirements/tests.txt | 2 +- 3 files changed, 2 insertions(+), 89 deletions(-) delete mode 100644 mmrazor/models/losses/adaround_loss.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9ed4cb002..2c2b8ed21 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -120,7 +120,7 @@ jobs: # Upload coverage report for python3.8 && pytorch1.12.0 cpu - name: Upload coverage to Codecov if: ${{matrix.torch == '1.13.0' && matrix.python-version == '3.8'}} - uses: codecov/codecov-action@v2 + uses: codecov/codecov-action@v3 with: file: ./coverage.xml flags: unittests diff --git a/mmrazor/models/losses/adaround_loss.py b/mmrazor/models/losses/adaround_loss.py deleted file mode 100644 index 76c97977d..000000000 --- a/mmrazor/models/losses/adaround_loss.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -import torch.nn as nn -from mmengine.logging import print_log - -from mmrazor.registry import MODELS - -_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear) - - -@MODELS.register_module() -class AdaRoundLoss(nn.Module): - r'''loss function to calculate mse reconstruction loss and relaxation loss - use some tempdecay to balance the two losses. - ''' - - def __init__(self, - weight: float = 1., - iters: int = 10000, - beta_range: tuple = (20, 2), - warm_up: float = 0.0, - p: float = 2.): - self.weight = weight - self.loss_start = iters * warm_up - self.p = p - - self.temp_decay = LinearTempDecay( - iters, - warm_up=warm_up, - start_beta=beta_range[0], - end_beta=beta_range[1]) - self.count = 0 - - def forward(self, subgraph, pred, tgt): - """Compute the total loss for adaptive rounding: rec_loss is the - quadratic output reconstruction loss, round_loss is a regularization - term to optimize the rounding policy. - - :param pred: output from quantized model - :param tgt: output from FP model - :return: total loss function - """ - - def lp_loss(pred, tgt, p=2.0): - """loss function measured in L_p Norm.""" - return (pred - tgt).abs().pow(p).sum(1).mean() - - self.count += 1 - rec_loss = lp_loss(pred, tgt, p=self.p) - - beta = self.temp_decay(self.count) - if self.count < self.loss_start: - round_loss = 0 - else: - round_loss = 0 - for layer in subgraph.modules(): - if isinstance(layer, _ADAROUND_SUPPORT_TYPE): - round_vals = layer.weight_fake_quant.rectified_sigmoid() - round_loss += self.weight * (1 - ( - (round_vals - .5).abs() * 2).pow(beta)).sum() - - total_loss = rec_loss + round_loss - if self.count % 500 == 0: - print_log('Total loss:\t{:.3f} (rec_loss:{:.3f}, ' - 'round_loss:{:.3f})\tbeta={:.2f}\tcount={}'.format( - float(total_loss), float(rec_loss), - float(round_loss), beta, self.count)) - return total_loss - - -class LinearTempDecay: - - def __init__(self, t_max=10000, warm_up=0.2, start_beta=20, end_beta=2): - self.t_max = t_max - self.start_decay = warm_up * t_max - self.start_beta = start_beta - self.end_beta = end_beta - - def __call__(self, t): - if t < self.start_decay: - return self.start_beta - elif t > self.t_max: - return self.end_beta - else: - rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) - return self.end_beta + (self.start_beta - self.end_beta) * \ - max(0.0, (1 - rel_t)) diff --git a/requirements/tests.txt b/requirements/tests.txt index e38249fcd..5980dc303 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,4 +1,4 @@ -codecov +coverage flake8 interrogate isort==4.3.21 From 8ec3655926f7f26b857023308b3bdba6a9867422 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Mon, 17 Apr 2023 15:01:19 +0800 Subject: [PATCH 28/44] [BUG] Fix quantization loop (#507) * fix quantization loop * fix quant loop * fix quant loop * fix qat configs * [Bug] Fix ci converage setting (#508) fix ci converage * [Bug] Fix codecov (#509) * remove codecov in requirements * try to fix ci * del adaround loss * add freeze_bn_begin to lsq * delete useless codes --------- Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com> --- ... lsq_openvino_resnet18_8xb32_100e_in1k.py} | 8 ++- .../lsq_openvino_resnet18_8xb32_10e_in1k.py | 63 +++++++++++++++++++ .../qat_openvino_resnet18_10e_8xb32_in1k.py | 62 ++++++++++++++++++ mmrazor/engine/runner/quantization_loops.py | 61 ++++++++++++------ 4 files changed, 172 insertions(+), 22 deletions(-) rename configs/quantization/qat/{lsq_openvino_resnet18_8xb32_in1k.py => lsq_openvino_resnet18_8xb32_100e_in1k.py} (90%) create mode 100644 configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py create mode 100644 configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py similarity index 90% rename from configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py rename to configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py index 0b79232f8..00e424141 100644 --- a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py @@ -59,6 +59,10 @@ _delete_=True, type='mmrazor.LSQEpochBasedLoop', max_epochs=100, - val_interval=1) + val_interval=1, + freeze_bn_begin=1) val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') -test_cfg = val_cfg + +# Make sure the buffer such as min_val/max_val in saved checkpoint is the same +# among different rank. +default_hooks = dict(sync=dict(type='SyncBuffersHook')) diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py new file mode 100644 index 000000000..f931ddaf5 --- /dev/null +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py @@ -0,0 +1,63 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +resnet = _base_.model +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.LSQPerChannelObserver'), + a_observer=dict(type='mmrazor.LSQObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=resnet, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, type='ConstantLR', factor=1.0, by_epoch=True) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.LSQEpochBasedLoop', + max_epochs=10, + val_interval=1, + freeze_bn_begin=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') + +# Make sure the buffer such as min_val/max_val in saved checkpoint is the same +# among different rank. +default_hooks = dict(sync=dict(type='SyncBuffersHook')) diff --git a/configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py b/configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py new file mode 100644 index 000000000..261af7abb --- /dev/null +++ b/configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py @@ -0,0 +1,62 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +resnet = _base_.model +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=resnet, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, type='ConstantLR', factor=1.0, by_epoch=True) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=False) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.QATEpochBasedLoop', + max_epochs=10, + val_interval=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') + +# Make sure the buffer such as min_val/max_val in saved checkpoint is the same +# among different rank. +default_hooks = dict(sync=dict(type='SyncBuffersHook')) diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index 18caf06f5..764c8605d 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -11,11 +11,13 @@ from torch.nn.intrinsic.qat import freeze_bn_stats except ImportError: from mmrazor.utils import get_placeholder + disable_observer = get_placeholder('torch>=1.13') enable_fake_quant = get_placeholder('torch>=1.13') enable_observer = get_placeholder('torch>=1.13') freeze_bn_stats = get_placeholder('torch>=1.13') +from mmengine.dist import all_reduce_params, is_distributed from torch.utils.data import DataLoader from mmrazor.models import register_torch_fake_quants, register_torch_observers @@ -69,7 +71,18 @@ def prepare_for_run_epoch(self): """Toggle the state of the observers and fake quantizers before qat training.""" self.runner.model.apply(enable_fake_quant) - self.runner.model.apply(enable_observer) + + # The initialized _epoch equals to 0 so _epoch + 1 + # equal to the current epoch + if (self.disable_observer_begin > 0 + and self._epoch + 1 >= self.disable_observer_begin): + self.runner.model.apply(disable_observer) + else: + self.runner.model.apply(enable_observer) + + if (self.freeze_bn_begin > 0 + and self._epoch + 1 >= self.freeze_bn_begin): + self.runner.model.apply(freeze_bn_stats) def prepare_for_val(self): """Toggle the state of the observers and fake quantizers before @@ -89,8 +102,6 @@ def run(self): if (self.runner.val_loop is not None and self._epoch >= self.val_begin and self._epoch % self.val_interval == 0): - # observer disabled during evaluation - self.prepare_for_val() self.runner.val_loop.run() self.runner.call_hook('after_train') @@ -100,18 +111,13 @@ def run_epoch(self) -> None: self.runner.call_hook('before_train_epoch') self.runner.model.train() - # The initialized _epoch equals to 0 so _epoch + 1 - # equal to the current epoch - if self._epoch + 1 >= self.disable_observer_begin: - self.runner.model.apply(disable_observer) - - if self._epoch + 1 >= self.freeze_bn_begin: - self.runner.model.apply(freeze_bn_stats) - for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) self.runner.model.sync_qparams(src_mode='loss') + # Make sure the registered buffer such as `observer_enabled` is + # correct in the saved checkpoint. + self.prepare_for_val() self.runner.call_hook('after_train_epoch') self._epoch += 1 @@ -156,11 +162,16 @@ def __init__( dynamic_intervals=dynamic_intervals) self.is_first_batch = True + self.distributed = is_distributed() def prepare_for_run_epoch(self): """Toggle the state of the observers and fake quantizers before qat training.""" - pass + if (self.freeze_bn_begin > 0 + and self._epoch + 1 >= self.freeze_bn_begin): + self.runner.model.apply(freeze_bn_stats) + + self.runner.model.apply(enable_param_learning) def prepare_for_val(self): """Toggle the state of the observers and fake quantizers before @@ -172,20 +183,30 @@ def run_epoch(self) -> None: self.runner.call_hook('before_train_epoch') self.runner.model.train() - # TODO freeze bn - if self._epoch + 1 >= self.freeze_bn_begin: - self.runner.model.apply(freeze_bn_stats) - for idx, data_batch in enumerate(self.dataloader): if self.is_first_batch: - # lsq init - self.is_first_batch = False + # lsq observer init self.runner.model.apply(enable_static_estimate) - else: - self.runner.model.apply(enable_param_learning) + self.run_iter(idx, data_batch) + if self.is_first_batch: + # In the first batch, scale in LearnableFakeQuantize is + # calculated through lsq observer. As the values of `scale` of + # different observers in different rank are usually different, + # we have to sync the `scale` here. + if self.distributed: + all_reduce_params( + self.runner.model.parameters(), op='mean') + + # Change back to param learning mode + self.is_first_batch = False + self.runner.model.apply(enable_param_learning) + self.runner.model.sync_qparams(src_mode='loss') + # Make sure the registered buffer such as `observer_enabled` is + # correct in the saved checkpoint. + self.prepare_for_val() self.runner.call_hook('after_train_epoch') self._epoch += 1 From 72dc61bb0dea111460133e07b720c527a6c0ef01 Mon Sep 17 00:00:00 2001 From: humu789 Date: Tue, 11 Apr 2023 18:24:58 +0800 Subject: [PATCH 29/44] add test ptq --- mmrazor/engine/runner/quantization_loops.py | 8 +++++++- tools/ptq.py | 6 ++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index 764c8605d..0c9109b1e 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -321,7 +321,8 @@ def __init__(self, evaluator: Union[Evaluator, Dict, List], calibrate_dataloader: Union[DataLoader, Dict], calibrate_steps=32, - fp16: bool = False): + fp16: bool = False, + only_val=False): super().__init__(runner, dataloader, evaluator, fp16) if isinstance(calibrate_dataloader, dict): # Determine whether or not different ranks use different seed. @@ -333,6 +334,7 @@ def __init__(self, self.dataloader = dataloader self.calibrate_steps = calibrate_steps + self.only_val = only_val def run(self) -> dict: """Launch test.""" @@ -343,6 +345,10 @@ def run(self) -> dict: self.runner.model.apply(enable_fake_quant) self.runner.model.apply(enable_observer) + if self.only_val: + self.runner.model.apply(disable_observer) + return self.runner.val_loop.run() + for idx, data_batch in enumerate(self.dataloader): if idx == self.calibrate_steps: break diff --git a/tools/ptq.py b/tools/ptq.py index 2c00c5b11..a6755303b 100644 --- a/tools/ptq.py +++ b/tools/ptq.py @@ -14,7 +14,7 @@ def parse_args(): parser = argparse.ArgumentParser( description='MMRazor test (and eval) a model') parser.add_argument('config', help='test config file path') - # parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument('--ckpt-quant', help='ptq checkpoint file') parser.add_argument( '--work-dir', help='the directory to save the file containing evaluation metrics') @@ -60,7 +60,9 @@ def main(): cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) - # cfg.load_from = args.checkpoint + if vars(args).get('ckpt_quant', None): + cfg.load_from = args.ckpt_quant + cfg.test_cfg.only_val = True # build the runner from config runner = Runner.from_cfg(cfg) From e307823d9396948a6b694e67ff30e3e0b31768b8 Mon Sep 17 00:00:00 2001 From: humu789 Date: Thu, 13 Apr 2023 17:53:28 +0800 Subject: [PATCH 30/44] opt ptq pipeline --- .dev_scripts/benchmark_test.py | 5 +- .../darts/darts_subnet_1xb96_cifar10_2.0.py | 2 +- configs/nas/mmcls/darts/metafile.yml | 2 +- configs/quantization/ptq/README.md | 0 configs/quantization/ptq/metafile.yml | 173 ++++++++++++++++++ configs/quantization/qat/README.md | 0 configs/quantization/qat/metafile.yml | 19 ++ .../minmax_openvino_resnet18_8xb32_in1k.py | 1 - mmrazor/engine/runner/quantization_loops.py | 47 +++-- model-index.yml | 2 + tools/ptq.py | 6 +- tools/test.py | 2 + 12 files changed, 230 insertions(+), 29 deletions(-) create mode 100644 configs/quantization/ptq/README.md create mode 100644 configs/quantization/ptq/metafile.yml create mode 100644 configs/quantization/qat/README.md create mode 100644 configs/quantization/qat/metafile.yml diff --git a/.dev_scripts/benchmark_test.py b/.dev_scripts/benchmark_test.py index a9a208dbb..1af3e4fa4 100644 --- a/.dev_scripts/benchmark_test.py +++ b/.dev_scripts/benchmark_test.py @@ -24,9 +24,9 @@ def parse_args(): parser = argparse.ArgumentParser( description="Test all models' accuracy in model-index.yml") - parser.add_argument( - 'partition', type=str, help='Cluster partition to use.') parser.add_argument('checkpoint_root', help='Checkpoint file root path.') + parser.add_argument( + '--partition', type=str, help='Cluster partition to use.') parser.add_argument( '--job-name', type=str, @@ -148,6 +148,7 @@ def create_test_job_batch(commands, model_info, args, port): if exists: print(f'{checkpoint} already exists.') else: + print(f'start downloading {fname}') wget.download(model_info.weights, str(checkpoint)) print(f'\nSaved in {checkpoint}.') diff --git a/configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py b/configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py index ab9ee6180..c05a3b435 100644 --- a/configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py +++ b/configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py @@ -37,7 +37,7 @@ init_cfg=dict( type='Pretrained', checkpoint= # noqa: E251 - 'https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/darts/darts_subnetnet_1xb96_cifar10_acc-97.27_20211222-17e42600_latest.pth', # noqa: E501 + 'https://download.openmmlab.com/mmrazor/v1/darts/darts_subnetnet_1xb96_cifar10_acc-97.27_20211222-17e42600_latest.pth', # noqa: E501 prefix='architecture.')) model_wrapper_cfg = None diff --git a/configs/nas/mmcls/darts/metafile.yml b/configs/nas/mmcls/darts/metafile.yml index b92f28dd7..b262a6960 100644 --- a/configs/nas/mmcls/darts/metafile.yml +++ b/configs/nas/mmcls/darts/metafile.yml @@ -25,4 +25,4 @@ Models: Top 1 Accuracy: 97.32 Top 5 Accuracy: 99.94 Config: configs/nas/mmcls/darts/darts_subnet_1xb96_cifar10_2.0.py - Weights: https://download.openmmlab.com/mmrazor/v1/darts/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-23ca1e10.pth + Weights: https://download.openmmlab.com/mmrazor/v1/darts/darts_subnetnet_1xb96_cifar10_acc-97.32_20211222-e5727921_latest.pth diff --git a/configs/quantization/ptq/README.md b/configs/quantization/ptq/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/configs/quantization/ptq/metafile.yml b/configs/quantization/ptq/metafile.yml new file mode 100644 index 000000000..89f2fc6b2 --- /dev/null +++ b/configs/quantization/ptq/metafile.yml @@ -0,0 +1,173 @@ +Collections: + - Name: PTQ + +Models: + - Name: ptq_openvino_mbv2_8xb32_in1k_calib32xb32 + In Collection: PTQ + Metadata: + Backend: openvino + Float Model: + Config: mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth + Metrics: + Top 1 Accuracy: 71.86 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 70.272 + Config: configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_mbv2_8xb32_in1k_calib32xb32_20230330_170909-364822ad.pth + + - Name: ptq_openvino_resnet18_8xb32_in1k_calib32xb32 + In Collection: PTQ + Metadata: + Backend: openvino + Float Model: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 69.766 + Config: configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet18_8xb32_in1k_calib32xb32_20230330_163655-2386d965.pth + + - Name: ptq_openvino_resnet50_8xb32_in1k_calib32xb32 + In Collection: PTQ + Metadata: + Backend: openvino + Float Model: + Config: mmcls::resnet/resnet50_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth + Metrics: + Top 1 Accuracy: 76.55 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 76.338 + Config: configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet50_8xb32_in1k_calib32xb32_20230330_170115-2acd6014.pth + + - Name: ptq_openvino_retina_r50_1x_coco_calib32xb32 + In Collection: PTQ + Metadata: + Backend: openvino + Float Model: + Config: mmdet::retinanet/retinanet_r50_fpn_1x_coco.py + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth + Metrics: + box AP: 36.5 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 36.3 + Config: configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_retina_r50_1x_coco_calib32xb32_20230330_172645-80eea5b6.pth + + - Name: ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32 + In Collection: PTQ + Metadata: + Backend: openvino + Float Model: + Config: mmdet::yolox/yolox_s_8xb8-300e_coco.py + Weights: https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth + Metrics: + box AP: 40.5 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 38.5 + Config: configs/quantization/ptq/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32_20230330_175747-f1a0a2f4.pth + + - Name: ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32 + In Collection: PTQ + Metadata: + Backend: tensorrt + Float Model: + Config: mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/mobilenet_v2/mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth + Metrics: + Top 1 Accuracy: 71.86 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 70.324 + Config: configs/quantization/ptq/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32_20230331_153131-335988e4.pth + + - Name: ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32 + In Collection: PTQ + Metadata: + Backend: tensorrt + Float Model: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 69.762 + Config: configs/quantization/ptq/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32_20230331_144323-640b272e.pth + + - Name: ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32 + In Collection: PTQ + Metadata: + Backend: tensorrt + Float Model: + Config: mmcls::resnet/resnet50_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth + Metrics: + Top 1 Accuracy: 76.55 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 76.372 + Config: configs/quantization/ptq/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32_20230331_145011-d2da300f.pth + + - Name: ptq_tensorrt_retina_r50_1x_coco_calib32xb32 + In Collection: PTQ + Metadata: + Backend: tensorrt + Float Model: + Config: mmdet::retinanet/retinanet_r50_fpn_1x_coco.py + Weights: https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth + Metrics: + box AP: 36.5 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 36.2 + Config: configs/quantization/ptq/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_retina_r50_1x_coco_calib32xb32_20230330_205741-4c5c10c4.pth + + - Name: ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32 + In Collection: PTQ + Metadata: + Backend: tensorrt + Float Model: + Config: mmdet::yolox/yolox_s_8xb8-300e_coco.py + Weights: https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth + Metrics: + box AP: 40.5 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 38.8 + Config: configs/quantization/ptq/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32_20230331_155139-f2021e57.pth diff --git a/configs/quantization/qat/README.md b/configs/quantization/qat/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/configs/quantization/qat/metafile.yml b/configs/quantization/qat/metafile.yml new file mode 100644 index 000000000..94e552ecd --- /dev/null +++ b/configs/quantization/qat/metafile.yml @@ -0,0 +1,19 @@ +Collections: + - Name: QAT +Models: + - Name: lsq_openvino_resnet18_8xb32_in1k + In Collection: QAT + Metadata: + Backend: openvino + Float Model: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 70.06 + Config: configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_in1k_20230402_173316-0d441f23.pth diff --git a/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py index 8aa11d6b3..35547d91b 100644 --- a/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py +++ b/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py @@ -62,4 +62,3 @@ max_epochs=100, val_interval=1) val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') -# test_cfg = val_cfg diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index 0c9109b1e..aa917e01e 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -1,8 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os from typing import Dict, List, Optional, Sequence, Tuple, Union import torch from mmengine.evaluator import Evaluator +from mmengine.logging import print_log from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop try: @@ -342,32 +344,37 @@ def run(self) -> dict: self.runner.call_hook('before_test_epoch') self.runner.model.eval() - self.runner.model.apply(enable_fake_quant) - self.runner.model.apply(enable_observer) - if self.only_val: - self.runner.model.apply(disable_observer) - return self.runner.val_loop.run() + if not self.only_val: + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(enable_observer) - for idx, data_batch in enumerate(self.dataloader): - if idx == self.calibrate_steps: - break - self.run_iter(idx, data_batch) + print_log('Star calibratiion...') + for idx, data_batch in enumerate(self.dataloader): + if idx == self.calibrate_steps: + break + self.run_iter(idx, data_batch) + print_log('Finish calibratiion!') - self.runner.save_checkpoint( - self.runner.work_dir, - 'model_ptq.pth', - file_client_args=None, - save_optimizer=False, - save_param_scheduler=False) + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(disable_observer) - self.runner.call_hook('after_test_epoch', metrics=None) + save_dir = os.path.join(self.runner.work_dir, + self.runner.timestamp) + self.runner.save_checkpoint( + save_dir, + 'model_ptq.pth', + file_client_args=None, + save_optimizer=False, + save_param_scheduler=False) + print_log(f'Quantized model is saved in {save_dir}') + + print_log('Start Evaluating quantized model...') + metricts = self.runner.val_loop.run() + self.runner.call_hook('after_test_epoch', metrics=metricts) self.runner.call_hook('after_test') - self.runner.model.apply(enable_fake_quant) - self.runner.model.apply(disable_observer) - - return self.runner.val_loop.run() + return metricts @torch.no_grad() def run_iter(self, idx, data_batch: Sequence[dict]) -> None: diff --git a/model-index.yml b/model-index.yml index 15e4595cb..7aa7ce7e3 100644 --- a/model-index.yml +++ b/model-index.yml @@ -26,3 +26,5 @@ Import: - configs/pruning/mmdet/group_fisher/retinanet/metafile.yml - configs/pruning/mmcls/l1-norm/metafile.yml - configs/pruning/mmcls/dmcp/metafile.yml + - configs/quantization/ptq/metafile.yml + - configs/quantization/qat/metafile.yml diff --git a/tools/ptq.py b/tools/ptq.py index a6755303b..2c00c5b11 100644 --- a/tools/ptq.py +++ b/tools/ptq.py @@ -14,7 +14,7 @@ def parse_args(): parser = argparse.ArgumentParser( description='MMRazor test (and eval) a model') parser.add_argument('config', help='test config file path') - parser.add_argument('--ckpt-quant', help='ptq checkpoint file') + # parser.add_argument('checkpoint', help='checkpoint file') parser.add_argument( '--work-dir', help='the directory to save the file containing evaluation metrics') @@ -60,9 +60,7 @@ def main(): cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) - if vars(args).get('ckpt_quant', None): - cfg.load_from = args.ckpt_quant - cfg.test_cfg.only_val = True + # cfg.load_from = args.checkpoint # build the runner from config runner = Runner.from_cfg(cfg) diff --git a/tools/test.py b/tools/test.py index fb6b00b86..a69133158 100644 --- a/tools/test.py +++ b/tools/test.py @@ -66,6 +66,8 @@ def main(): cfg.load_from = None else: cfg.load_from = args.checkpoint + if 'type' in cfg.test_cfg and cfg.test_cfg.type.endswith('PTQLoop'): + cfg.test_cfg.only_val = True # build the runner from config runner = Runner.from_cfg(cfg) From e16cb6493c90b2f5b90c9fe2bbfa5170c7530531 Mon Sep 17 00:00:00 2001 From: humu789 Date: Thu, 13 Apr 2023 19:21:40 +0800 Subject: [PATCH 31/44] refactor quant configs --- configs/quantization/ptq/README.md | 0 configs/quantization/ptq/base/README.md | 43 ++++++++++++ .../quantization/ptq/{ => base}/metafile.yml | 31 +++------ ...tq_openvino_mbv2_8xb32_in1k_calib32xb32.py | 0 ...penvino_resnet18_8xb32_in1k_calib32xb32.py | 0 ...penvino_resnet50_8xb32_in1k_calib32xb32.py | 0 ...openvino_retina_r50_1x_coco_calib32xb32.py | 0 ...vino_yolox_s_8xb8-300e_coco_calib32xb32.py | 0 ...tq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py | 0 ...ensorrt_resnet18_8xb32_in1k_calib32xb32.py | 0 ...ensorrt_resnet50_8xb32_in1k_calib32xb32.py | 0 ...tensorrt_retina_r50_1x_coco_calib32xb32.py | 0 ...orrt_yolox_s_8xb8-300e_coco_calib32xb32.py | 0 configs/quantization/qat/README.md | 0 configs/quantization/qat/base/README.md | 43 ++++++++++++ configs/quantization/qat/base/metafile.yml | 20 ++++++ .../qat_openvino_resnet18_8xb32_in1k.py} | 0 configs/quantization/qat/lsq/README.md | 43 ++++++++++++ .../lsq/lsq_openvino_resnet18_8xb32_in1k.py | 68 +++++++++++++++++++ .../quantization/qat/{ => lsq}/metafile.yml | 7 +- model-index.yml | 5 +- 21 files changed, 235 insertions(+), 25 deletions(-) delete mode 100644 configs/quantization/ptq/README.md create mode 100644 configs/quantization/ptq/base/README.md rename configs/quantization/ptq/{ => base}/metafile.yml (87%) rename configs/quantization/ptq/{ => base}/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py (100%) rename configs/quantization/ptq/{ => base}/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py (100%) rename configs/quantization/ptq/{ => base}/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py (100%) rename configs/quantization/ptq/{ => base}/ptq_openvino_retina_r50_1x_coco_calib32xb32.py (100%) rename configs/quantization/ptq/{ => base}/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py (100%) rename configs/quantization/ptq/{ => base}/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py (100%) rename configs/quantization/ptq/{ => base}/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py (100%) rename configs/quantization/ptq/{ => base}/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py (100%) rename configs/quantization/ptq/{ => base}/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py (100%) rename configs/quantization/ptq/{ => base}/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py (100%) delete mode 100644 configs/quantization/qat/README.md create mode 100644 configs/quantization/qat/base/README.md create mode 100644 configs/quantization/qat/base/metafile.yml rename configs/quantization/qat/{minmax_openvino_resnet18_8xb32_in1k.py => base/qat_openvino_resnet18_8xb32_in1k.py} (100%) create mode 100644 configs/quantization/qat/lsq/README.md create mode 100644 configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_in1k.py rename configs/quantization/qat/{ => lsq}/metafile.yml (78%) diff --git a/configs/quantization/ptq/README.md b/configs/quantization/ptq/README.md deleted file mode 100644 index e69de29bb..000000000 diff --git a/configs/quantization/ptq/base/README.md b/configs/quantization/ptq/base/README.md new file mode 100644 index 000000000..b526d7274 --- /dev/null +++ b/configs/quantization/ptq/base/README.md @@ -0,0 +1,43 @@ +# Post-Training Quantization (PTQ) + +> [A White Paper on Neural Network Quantization](https://arxiv.org/abs/2106.08295) + + + +## Abstract + +While neural networks have advanced the frontiers in many applications, they often come at a high computational cost. Reducing the power and latency of neural network inference is key if we want to integrate modern networks into edge devices with strict power and compute requirements. Neural network quantization is one of the most effective ways of achieving these savings but the additional noise it induces can lead to accuracy degradation. In this white paper, we introduce state-of-the-art algorithms for mitigating the impact of quantization noise on the network's performance while maintaining low-bit weights and activations. We start with a hardware motivated introduction to quantization and then consider two main classes of algorithms: Post-Training Quantization (PTQ) and Quantization-Aware-Training (QAT). PTQ requires no re-training or labelled data and is thus a lightweight push-button approach to quantization. In most cases, PTQ is sufficient for achieving 8-bit quantization with close to floating-point accuracy. QAT requires fine-tuning and access to labeled training data but enables lower bit quantization with competitive results. For both solutions, we provide tested pipelines based on existing literature and extensive experimentation that lead to state-of-the-art performance for common deep learning models and tasks. + +## Results and models + +### Classification + +| Model | Dataset | Backend | Acc-fp32 | Acc-int8 | Acc-deployed | Download | +| -------- | -------- | -------- | -------- | -------- | ------------ | ------------ | +| resnet18 | ImageNet | openvino | | | | model \| log | + +## Citation + +```latex + @misc{Nagel_Fournarakis_Amjad_Bondarenko_Baalen_Blankevoort_2021, + title={A White Paper on Neural Network Quantization}, + journal={Cornell University - arXiv}, + author={Nagel, Markus and Fournarakis, Marios and Amjad, RanaAli and Bondarenko, Yelysei and Baalen, Martvan and Blankevoort, Tijmen}, + year={2021}, + month={Jun} + } +``` + +## Getting Started + +### Train + +```python +python tools/train.py ${CONFIG_FILE} +``` + +### Test + +```python +python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_PATH} +``` diff --git a/configs/quantization/ptq/metafile.yml b/configs/quantization/ptq/base/metafile.yml similarity index 87% rename from configs/quantization/ptq/metafile.yml rename to configs/quantization/ptq/base/metafile.yml index 89f2fc6b2..8ccd60c2c 100644 --- a/configs/quantization/ptq/metafile.yml +++ b/configs/quantization/ptq/base/metafile.yml @@ -1,6 +1,6 @@ Collections: - Name: PTQ - + README: configs/quantization/ptq/base/README.md Models: - Name: ptq_openvino_mbv2_8xb32_in1k_calib32xb32 In Collection: PTQ @@ -16,9 +16,8 @@ Models: Dataset: ImageNet-1k Metrics: Top 1 Accuracy: 70.272 - Config: configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py + Config: configs/quantization/ptq/base/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_mbv2_8xb32_in1k_calib32xb32_20230330_170909-364822ad.pth - - Name: ptq_openvino_resnet18_8xb32_in1k_calib32xb32 In Collection: PTQ Metadata: @@ -33,9 +32,8 @@ Models: Dataset: ImageNet-1k Metrics: Top 1 Accuracy: 69.766 - Config: configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py + Config: configs/quantization/ptq/base/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet18_8xb32_in1k_calib32xb32_20230330_163655-2386d965.pth - - Name: ptq_openvino_resnet50_8xb32_in1k_calib32xb32 In Collection: PTQ Metadata: @@ -50,9 +48,8 @@ Models: Dataset: ImageNet-1k Metrics: Top 1 Accuracy: 76.338 - Config: configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py + Config: configs/quantization/ptq/base/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet50_8xb32_in1k_calib32xb32_20230330_170115-2acd6014.pth - - Name: ptq_openvino_retina_r50_1x_coco_calib32xb32 In Collection: PTQ Metadata: @@ -67,9 +64,8 @@ Models: Dataset: COCO Metrics: box AP: 36.3 - Config: configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py + Config: configs/quantization/ptq/base/ptq_openvino_retina_r50_1x_coco_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_retina_r50_1x_coco_calib32xb32_20230330_172645-80eea5b6.pth - - Name: ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32 In Collection: PTQ Metadata: @@ -84,9 +80,8 @@ Models: Dataset: COCO Metrics: box AP: 38.5 - Config: configs/quantization/ptq/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py + Config: configs/quantization/ptq/base/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32_20230330_175747-f1a0a2f4.pth - - Name: ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32 In Collection: PTQ Metadata: @@ -101,9 +96,8 @@ Models: Dataset: ImageNet-1k Metrics: Top 1 Accuracy: 70.324 - Config: configs/quantization/ptq/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py + Config: configs/quantization/ptq/base/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32_20230331_153131-335988e4.pth - - Name: ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32 In Collection: PTQ Metadata: @@ -118,9 +112,8 @@ Models: Dataset: ImageNet-1k Metrics: Top 1 Accuracy: 69.762 - Config: configs/quantization/ptq/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py + Config: configs/quantization/ptq/base/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32_20230331_144323-640b272e.pth - - Name: ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32 In Collection: PTQ Metadata: @@ -135,9 +128,8 @@ Models: Dataset: ImageNet-1k Metrics: Top 1 Accuracy: 76.372 - Config: configs/quantization/ptq/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py + Config: configs/quantization/ptq/base/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32_20230331_145011-d2da300f.pth - - Name: ptq_tensorrt_retina_r50_1x_coco_calib32xb32 In Collection: PTQ Metadata: @@ -152,9 +144,8 @@ Models: Dataset: COCO Metrics: box AP: 36.2 - Config: configs/quantization/ptq/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py + Config: configs/quantization/ptq/base/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_retina_r50_1x_coco_calib32xb32_20230330_205741-4c5c10c4.pth - - Name: ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32 In Collection: PTQ Metadata: @@ -169,5 +160,5 @@ Models: Dataset: COCO Metrics: box AP: 38.8 - Config: configs/quantization/ptq/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py + Config: configs/quantization/ptq/base/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32_20230331_155139-f2021e57.pth diff --git a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py similarity index 100% rename from configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py rename to configs/quantization/ptq/base/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py similarity index 100% rename from configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py rename to configs/quantization/ptq/base/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py diff --git a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py similarity index 100% rename from configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py rename to configs/quantization/ptq/base/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py diff --git a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_retina_r50_1x_coco_calib32xb32.py similarity index 100% rename from configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py rename to configs/quantization/ptq/base/ptq_openvino_retina_r50_1x_coco_calib32xb32.py diff --git a/configs/quantization/ptq/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py similarity index 100% rename from configs/quantization/ptq/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py rename to configs/quantization/ptq/base/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py diff --git a/configs/quantization/ptq/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py similarity index 100% rename from configs/quantization/ptq/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py rename to configs/quantization/ptq/base/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py diff --git a/configs/quantization/ptq/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py similarity index 100% rename from configs/quantization/ptq/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py rename to configs/quantization/ptq/base/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py diff --git a/configs/quantization/ptq/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py similarity index 100% rename from configs/quantization/ptq/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py rename to configs/quantization/ptq/base/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py diff --git a/configs/quantization/ptq/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py similarity index 100% rename from configs/quantization/ptq/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py rename to configs/quantization/ptq/base/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py diff --git a/configs/quantization/ptq/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py similarity index 100% rename from configs/quantization/ptq/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py rename to configs/quantization/ptq/base/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py diff --git a/configs/quantization/qat/README.md b/configs/quantization/qat/README.md deleted file mode 100644 index e69de29bb..000000000 diff --git a/configs/quantization/qat/base/README.md b/configs/quantization/qat/base/README.md new file mode 100644 index 000000000..d0b2d2217 --- /dev/null +++ b/configs/quantization/qat/base/README.md @@ -0,0 +1,43 @@ +# Quantization-Aware-Training (QAT) + +> [A White Paper on Neural Network Quantization](https://arxiv.org/abs/2106.08295) + + + +## Abstract + +While neural networks have advanced the frontiers in many applications, they often come at a high computational cost. Reducing the power and latency of neural network inference is key if we want to integrate modern networks into edge devices with strict power and compute requirements. Neural network quantization is one of the most effective ways of achieving these savings but the additional noise it induces can lead to accuracy degradation. In this white paper, we introduce state-of-the-art algorithms for mitigating the impact of quantization noise on the network's performance while maintaining low-bit weights and activations. We start with a hardware motivated introduction to quantization and then consider two main classes of algorithms: Post-Training Quantization (PTQ) and Quantization-Aware-Training (QAT). PTQ requires no re-training or labelled data and is thus a lightweight push-button approach to quantization. In most cases, PTQ is sufficient for achieving 8-bit quantization with close to floating-point accuracy. QAT requires fine-tuning and access to labeled training data but enables lower bit quantization with competitive results. For both solutions, we provide tested pipelines based on existing literature and extensive experimentation that lead to state-of-the-art performance for common deep learning models and tasks. + +## Results and models + +### Classification + +| Model | Dataset | Backend | Acc-fp32 | Acc-int8 | Acc-deployed | Download | +| -------- | -------- | -------- | -------- | -------- | ------------ | ------------ | +| resnet18 | ImageNet | openvino | | | | model \| log | + +## Citation + +```latex + @misc{Nagel_Fournarakis_Amjad_Bondarenko_Baalen_Blankevoort_2021, + title={A White Paper on Neural Network Quantization}, + journal={Cornell University - arXiv}, + author={Nagel, Markus and Fournarakis, Marios and Amjad, RanaAli and Bondarenko, Yelysei and Baalen, Martvan and Blankevoort, Tijmen}, + year={2021}, + month={Jun} + } +``` + +## Getting Started + +### Train + +```python +python tools/train.py ${CONFIG_FILE} +``` + +### Test + +```python +python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_PATH} +``` diff --git a/configs/quantization/qat/base/metafile.yml b/configs/quantization/qat/base/metafile.yml new file mode 100644 index 000000000..959e74299 --- /dev/null +++ b/configs/quantization/qat/base/metafile.yml @@ -0,0 +1,20 @@ +Collections: + - Name: QAT + README: configs/quantization/qat/base/README.md +Models: + - Name: qat_openvino_resnet18_8xb32_in1k + In Collection: QAT + Metadata: + Backend: openvino + Float Model: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 70.06 + Config: configs/quantization/qat/base/qat_openvino_resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_in1k_20230402_173316-0d441f23.pth diff --git a/configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/base/qat_openvino_resnet18_8xb32_in1k.py similarity index 100% rename from configs/quantization/qat/minmax_openvino_resnet18_8xb32_in1k.py rename to configs/quantization/qat/base/qat_openvino_resnet18_8xb32_in1k.py diff --git a/configs/quantization/qat/lsq/README.md b/configs/quantization/qat/lsq/README.md new file mode 100644 index 000000000..99dd9d325 --- /dev/null +++ b/configs/quantization/qat/lsq/README.md @@ -0,0 +1,43 @@ +# Learned Step Size Quantization (LSQ) + +> [Learned Step Size Quantization](https://arxiv.org/abs/1902.08153) + + + +## Abstract + +Deep networks run with low precision operations at inference time offer power and space advantages over high precision alternatives, but need to overcome the challenge of maintaining high accuracy as precision decreases. Here, we present a method for training such networks, Learned Step Size Quantization, that achieves the highest accuracy to date on the ImageNet dataset when using models, from a variety of architectures, with weights and activations quantized to 2-, 3- or 4-bits of precision, and that can train 3-bit models that reach full precision baseline accuracy. Our approach builds upon existing methods for learning weights in quantized networks by improving how the quantizer itself is configured. Specifically, we introduce a novel means to estimate and scale the task loss gradient at each weight and activation layer's quantizer step size, such that it can be learned in conjunction with other network parameters. This approach works using different levels of precision as needed for a given system and requires only a simple modification of existing training code. + +## Results and models + +### Classification + +| Model | Dataset | Backend | Acc-fp32 | Acc-int8 | Acc-deployed | Download | +| -------- | -------- | -------- | -------- | -------- | ------------ | ------------ | +| resnet18 | ImageNet | openvino | | | | model \| log | + +## Citation + +```latex + @misc{Esser_McKinstry_Bablani_Appuswamy_Modha_2019, + title={Learned Step Size Quantization}, + journal={arXiv: Learning}, + author={Esser, StevenK. and McKinstry, JeffreyL. and Bablani, Deepika and Appuswamy, Rathinakumar and Modha, DharmendraS.}, + year={2019}, + month={Feb} + } +``` + +## Getting Started + +### Train + +```python +python tools/train.py ${CONFIG_FILE} +``` + +### Test + +```python +python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_PATH} +``` diff --git a/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_in1k.py new file mode 100644 index 000000000..00e424141 --- /dev/null +++ b/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_in1k.py @@ -0,0 +1,68 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +resnet = _base_.model +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.LSQPerChannelObserver'), + a_observer=dict(type='mmrazor.LSQObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=resnet, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, + type='CosineAnnealingLR', + T_max=100, + by_epoch=True, + begin=0, + end=100) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.LSQEpochBasedLoop', + max_epochs=100, + val_interval=1, + freeze_bn_begin=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') + +# Make sure the buffer such as min_val/max_val in saved checkpoint is the same +# among different rank. +default_hooks = dict(sync=dict(type='SyncBuffersHook')) diff --git a/configs/quantization/qat/metafile.yml b/configs/quantization/qat/lsq/metafile.yml similarity index 78% rename from configs/quantization/qat/metafile.yml rename to configs/quantization/qat/lsq/metafile.yml index 94e552ecd..7d9235cbf 100644 --- a/configs/quantization/qat/metafile.yml +++ b/configs/quantization/qat/lsq/metafile.yml @@ -1,8 +1,9 @@ Collections: - - Name: QAT + - Name: LSQ + README: configs/quantization/qat/lsq/README.md Models: - Name: lsq_openvino_resnet18_8xb32_in1k - In Collection: QAT + In Collection: LSQ Metadata: Backend: openvino Float Model: @@ -15,5 +16,5 @@ Models: Dataset: ImageNet-1k Metrics: Top 1 Accuracy: 70.06 - Config: configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py + Config: configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_in1k.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_in1k_20230402_173316-0d441f23.pth diff --git a/model-index.yml b/model-index.yml index 7aa7ce7e3..efcd3cae5 100644 --- a/model-index.yml +++ b/model-index.yml @@ -26,5 +26,6 @@ Import: - configs/pruning/mmdet/group_fisher/retinanet/metafile.yml - configs/pruning/mmcls/l1-norm/metafile.yml - configs/pruning/mmcls/dmcp/metafile.yml - - configs/quantization/ptq/metafile.yml - - configs/quantization/qat/metafile.yml + - configs/quantization/ptq/base/metafile.yml + - configs/quantization/qat/base/metafile.yml + - configs/quantization/qat/lsq/metafile.yml From 6bfff52faf48a9ac960c2fafedf0964f4a937884 Mon Sep 17 00:00:00 2001 From: humu789 Date: Fri, 14 Apr 2023 14:01:45 +0800 Subject: [PATCH 32/44] update config path --- ...tq_openvino_mbv2_8xb32_in1k_calib32xb32.py | 2 +- ...penvino_resnet18_8xb32_in1k_calib32xb32.py | 2 +- ...penvino_resnet50_8xb32_in1k_calib32xb32.py | 2 +- ...openvino_retina_r50_1x_coco_calib32xb32.py | 2 +- ...vino_yolox_s_8xb8-300e_coco_calib32xb32.py | 2 +- ...tq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py | 2 +- ...ensorrt_resnet18_8xb32_in1k_calib32xb32.py | 2 +- ...ensorrt_resnet50_8xb32_in1k_calib32xb32.py | 2 +- ...tensorrt_retina_r50_1x_coco_calib32xb32.py | 2 +- ...orrt_yolox_s_8xb8-300e_coco_calib32xb32.py | 2 +- mmrazor/engine/runner/quantization_loops.py | 2 + model-index.yml | 56 +++++++++---------- 12 files changed, 40 insertions(+), 38 deletions(-) diff --git a/configs/quantization/ptq/base/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py index 43f63f208..199994378 100644 --- a/configs/quantization/ptq/base/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/base/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py @@ -1,6 +1,6 @@ _base_ = [ 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py', - '../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' + '../../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' ] val_dataloader = dict(batch_size=32) diff --git a/configs/quantization/ptq/base/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py index 9bd9f55e4..f359314fa 100644 --- a/configs/quantization/ptq/base/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/base/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -1,6 +1,6 @@ _base_ = [ 'mmcls::resnet/resnet18_8xb32_in1k.py', - '../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' + '../../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' ] val_dataloader = dict(batch_size=32) diff --git a/configs/quantization/ptq/base/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py index 14f65968e..d6dd6ac4d 100644 --- a/configs/quantization/ptq/base/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/base/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py @@ -1,6 +1,6 @@ _base_ = [ 'mmcls::resnet/resnet50_8xb32_in1k.py', - '../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' + '../../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' ] val_dataloader = dict(batch_size=32) diff --git a/configs/quantization/ptq/base/ptq_openvino_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_retina_r50_1x_coco_calib32xb32.py index 3ec47e479..a831c5edf 100644 --- a/configs/quantization/ptq/base/ptq_openvino_retina_r50_1x_coco_calib32xb32.py +++ b/configs/quantization/ptq/base/ptq_openvino_retina_r50_1x_coco_calib32xb32.py @@ -1,6 +1,6 @@ _base_ = [ 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py', - '../deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py' + '../../deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py' ] val_dataloader = dict(batch_size=32) diff --git a/configs/quantization/ptq/base/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py b/configs/quantization/ptq/base/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py index 4ce17fe69..54839a667 100644 --- a/configs/quantization/ptq/base/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py +++ b/configs/quantization/ptq/base/ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py @@ -1,6 +1,6 @@ _base_ = [ 'mmdet::yolox/yolox_s_8xb8-300e_coco.py', - '../deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py' + '../../deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py' ] val_dataloader = dict(batch_size=32) diff --git a/configs/quantization/ptq/base/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py index 681b7dabc..13a69fbbb 100644 --- a/configs/quantization/ptq/base/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/base/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py @@ -1,6 +1,6 @@ _base_ = [ 'mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py', - '../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501 + '../../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501 ] val_dataloader = dict(batch_size=32) diff --git a/configs/quantization/ptq/base/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py index 313dd195e..80dff12ab 100644 --- a/configs/quantization/ptq/base/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/base/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py @@ -1,6 +1,6 @@ _base_ = [ 'mmcls::resnet/resnet18_8xb32_in1k.py', - '../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501 + '../../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501 ] val_dataloader = dict(batch_size=32) diff --git a/configs/quantization/ptq/base/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py index 0bd4a083a..bf7703255 100644 --- a/configs/quantization/ptq/base/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/base/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py @@ -1,6 +1,6 @@ _base_ = [ 'mmcls::resnet/resnet50_8xb32_in1k.py', - '../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501 + '../../deploy_cfgs/mmcls/classification_tensorrt-int8-explicit_dynamic-224x224.py' # noqa: E501 ] val_dataloader = dict(batch_size=32) diff --git a/configs/quantization/ptq/base/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py index 088e6e043..2151310a6 100644 --- a/configs/quantization/ptq/base/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py +++ b/configs/quantization/ptq/base/ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py @@ -1,6 +1,6 @@ _base_ = [ 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py', - '../deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py' # noqa: E501 + '../../deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py' # noqa: E501 ] val_dataloader = dict(batch_size=32) diff --git a/configs/quantization/ptq/base/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py b/configs/quantization/ptq/base/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py index 01e8cc0b5..89a9243a1 100644 --- a/configs/quantization/ptq/base/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py +++ b/configs/quantization/ptq/base/ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py @@ -1,6 +1,6 @@ _base_ = [ 'mmdet::yolox/yolox_s_8xb8-300e_coco.py', - '../deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py' # noqa: E501 + '../../deploy_cfgs/mmdet/detection_tensorrt-int8-explicit_dynamic-320x320-1344x1344.py' # noqa: E501 ] val_dataloader = dict(batch_size=32) diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index aa917e01e..6a3319aa3 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -370,6 +370,8 @@ def run(self) -> dict: print_log(f'Quantized model is saved in {save_dir}') print_log('Start Evaluating quantized model...') + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(disable_observer) metricts = self.runner.val_loop.run() self.runner.call_hook('after_test_epoch', metrics=metricts) self.runner.call_hook('after_test') diff --git a/model-index.yml b/model-index.yml index efcd3cae5..74994e005 100644 --- a/model-index.yml +++ b/model-index.yml @@ -1,31 +1,31 @@ Import: - - configs/distill/mmseg/cwd/metafile.yml - - configs/distill/mmdet/cwd/metafile.yml - - configs/distill/mmcls/wsld/metafile.yml - - configs/distill/mmcls/rkd/metafile.yml - - configs/nas/mmcls/spos/metafile.yml - - configs/distill/mmcls/abloss/metafile.yml - - configs/distill/mmcls/byot/metafile.yml - - configs/distill/mmcls/dafl/metafile.yml - - configs/distill/mmcls/dfad/metafile.yml - - configs/distill/mmcls/dkd/metafile.yml - - configs/distill/mmcls/fitnets/metafile.yml - - configs/distill/mmcls/kd/metafile.yml - - configs/distill/mmcls/zskt/metafile.yml - - configs/distill/mmdet/fbkd/metafile.yml - - configs/distill/mmcls/factor_transfer/metafile.yml - - configs/distill/mmcls/ofd/metafile.yml - - configs/nas/mmcls/autoslim/metafile.yml - - configs/nas/mmcls/darts/metafile.yml - - configs/nas/mmdet/detnas/metafile.yml - - configs/distill/mmdet/pkd/metafile.yml - - configs/distill/mmdet3d/pkd/metafile.yml - - configs/distill/mmcls/deit/metafile.yml - - configs/pruning/mmcls/group_fisher/mobilenet/metafile.yml - - configs/pruning/mmcls/group_fisher/resnet50/metafile.yml - - configs/pruning/mmdet/group_fisher/retinanet/metafile.yml - - configs/pruning/mmcls/l1-norm/metafile.yml - - configs/pruning/mmcls/dmcp/metafile.yml + # - configs/distill/mmseg/cwd/metafile.yml + # - configs/distill/mmdet/cwd/metafile.yml + # - configs/distill/mmcls/wsld/metafile.yml + # - configs/distill/mmcls/rkd/metafile.yml + # - configs/nas/mmcls/spos/metafile.yml + # - configs/distill/mmcls/abloss/metafile.yml + # - configs/distill/mmcls/byot/metafile.yml + # - configs/distill/mmcls/dafl/metafile.yml + # - configs/distill/mmcls/dfad/metafile.yml + # - configs/distill/mmcls/dkd/metafile.yml + # - configs/distill/mmcls/fitnets/metafile.yml + # - configs/distill/mmcls/kd/metafile.yml + # - configs/distill/mmcls/zskt/metafile.yml + # - configs/distill/mmdet/fbkd/metafile.yml + # - configs/distill/mmcls/factor_transfer/metafile.yml + # - configs/distill/mmcls/ofd/metafile.yml + # - configs/nas/mmcls/autoslim/metafile.yml + # - configs/nas/mmcls/darts/metafile.yml + # - configs/nas/mmdet/detnas/metafile.yml + # - configs/distill/mmdet/pkd/metafile.yml + # - configs/distill/mmdet3d/pkd/metafile.yml + # - configs/distill/mmcls/deit/metafile.yml + # - configs/pruning/mmcls/group_fisher/mobilenet/metafile.yml + # - configs/pruning/mmcls/group_fisher/resnet50/metafile.yml + # - configs/pruning/mmdet/group_fisher/retinanet/metafile.yml + # - configs/pruning/mmcls/l1-norm/metafile.yml + # - configs/pruning/mmcls/dmcp/metafile.yml - configs/quantization/ptq/base/metafile.yml - - configs/quantization/qat/base/metafile.yml + # - configs/quantization/qat/base/metafile.yml - configs/quantization/qat/lsq/metafile.yml From b56cc5d224e4ac09761516e2ee36d8484cffb301 Mon Sep 17 00:00:00 2001 From: humu789 Date: Fri, 14 Apr 2023 22:54:24 +0800 Subject: [PATCH 33/44] add summary analyse tool --- .dev_scripts/benchmark_summary_analyse.py | 66 ++++++++++++++++++++++ configs/quantization/ptq/base/metafile.yml | 6 +- model-index.yml | 52 ++++++++--------- 3 files changed, 95 insertions(+), 29 deletions(-) create mode 100644 .dev_scripts/benchmark_summary_analyse.py diff --git a/.dev_scripts/benchmark_summary_analyse.py b/.dev_scripts/benchmark_summary_analyse.py new file mode 100644 index 000000000..a4896990c --- /dev/null +++ b/.dev_scripts/benchmark_summary_analyse.py @@ -0,0 +1,66 @@ +import argparse +import os + +import mmengine + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Analyse summary.yml generated by benchmark test') + parser.add_argument('file_path', help='Summary.yml path') + args = parser.parse_args() + return args + + +metric_mapping = { + 'Top 1 Accuracy': 'accuracy/top1', + 'Top 5 Accuracy': 'accuracy/top5', + 'box AP': 'coco/bbox_mAP' +} + + +def compare_metric(result, metric): + expect_val = result['expect'][metric] + actual_val = result['actual'].get(metric_mapping[metric], None) + if actual_val is None: + return None, None + if metric == 'box AP': + actual_val *= 100 + decimal_bit = len(str(expect_val).split('.')[-1]) + actual_val = round(actual_val, decimal_bit) + error = round(actual_val - expect_val, decimal_bit) + error_percent = round(abs(error) * 100 / expect_val, 3) + return error, error_percent + + +def main(): + args = parse_args() + file_path = args.file_path + results = mmengine.load(file_path, 'yml') + miss_models = dict() + sort_by_error = dict() + for k, v in results.items(): + valid_keys = v['expect'].keys() + compare_res = dict() + for m in valid_keys: + error, error_percent = compare_metric(v, m) + if error is None: + continue + compare_res[m] = {'error': error, 'error_percent': error_percent} + if error != 0: + miss_models[k] = compare_res + sort_by_error[k] = error + sort_by_error = sorted( + sort_by_error.items(), key=lambda x: abs(x[1]), reverse=True) + miss_models_sort = dict() + miss_models_sort['total error models'] = len(sort_by_error) + for k_v in sort_by_error: + index = k_v[0] + miss_models_sort[index] = miss_models[index] + save_path = os.path.join(os.path.dirname(file_path), 'summary_error.yml') + mmengine.fileio.dump(miss_models_sort, save_path, sort_keys=False) + print(f'Summary analysis result saved in {save_path}') + + +if __name__ == '__main__': + main() diff --git a/configs/quantization/ptq/base/metafile.yml b/configs/quantization/ptq/base/metafile.yml index 8ccd60c2c..1ebceab4b 100644 --- a/configs/quantization/ptq/base/metafile.yml +++ b/configs/quantization/ptq/base/metafile.yml @@ -15,7 +15,7 @@ Models: - Task: Image Classification Dataset: ImageNet-1k Metrics: - Top 1 Accuracy: 70.272 + Top 1 Accuracy: 70.224 Config: configs/quantization/ptq/base/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_mbv2_8xb32_in1k_calib32xb32_20230330_170909-364822ad.pth - Name: ptq_openvino_resnet18_8xb32_in1k_calib32xb32 @@ -31,7 +31,7 @@ Models: - Task: Image Classification Dataset: ImageNet-1k Metrics: - Top 1 Accuracy: 69.766 + Top 1 Accuracy: 69.742 Config: configs/quantization/ptq/base/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet18_8xb32_in1k_calib32xb32_20230330_163655-2386d965.pth - Name: ptq_openvino_resnet50_8xb32_in1k_calib32xb32 @@ -47,7 +47,7 @@ Models: - Task: Image Classification Dataset: ImageNet-1k Metrics: - Top 1 Accuracy: 76.338 + Top 1 Accuracy: 76.374 Config: configs/quantization/ptq/base/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py Weights: https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet50_8xb32_in1k_calib32xb32_20230330_170115-2acd6014.pth - Name: ptq_openvino_retina_r50_1x_coco_calib32xb32 diff --git a/model-index.yml b/model-index.yml index 74994e005..51ce0b009 100644 --- a/model-index.yml +++ b/model-index.yml @@ -1,31 +1,31 @@ Import: - # - configs/distill/mmseg/cwd/metafile.yml - # - configs/distill/mmdet/cwd/metafile.yml - # - configs/distill/mmcls/wsld/metafile.yml - # - configs/distill/mmcls/rkd/metafile.yml - # - configs/nas/mmcls/spos/metafile.yml - # - configs/distill/mmcls/abloss/metafile.yml - # - configs/distill/mmcls/byot/metafile.yml - # - configs/distill/mmcls/dafl/metafile.yml - # - configs/distill/mmcls/dfad/metafile.yml - # - configs/distill/mmcls/dkd/metafile.yml - # - configs/distill/mmcls/fitnets/metafile.yml - # - configs/distill/mmcls/kd/metafile.yml - # - configs/distill/mmcls/zskt/metafile.yml - # - configs/distill/mmdet/fbkd/metafile.yml - # - configs/distill/mmcls/factor_transfer/metafile.yml - # - configs/distill/mmcls/ofd/metafile.yml - # - configs/nas/mmcls/autoslim/metafile.yml - # - configs/nas/mmcls/darts/metafile.yml - # - configs/nas/mmdet/detnas/metafile.yml - # - configs/distill/mmdet/pkd/metafile.yml + - configs/distill/mmseg/cwd/metafile.yml + - configs/distill/mmdet/cwd/metafile.yml + - configs/distill/mmcls/wsld/metafile.yml + - configs/distill/mmcls/rkd/metafile.yml + - configs/nas/mmcls/spos/metafile.yml + - configs/distill/mmcls/abloss/metafile.yml + - configs/distill/mmcls/byot/metafile.yml + - configs/distill/mmcls/dafl/metafile.yml + - configs/distill/mmcls/dfad/metafile.yml + - configs/distill/mmcls/dkd/metafile.yml + - configs/distill/mmcls/fitnets/metafile.yml + - configs/distill/mmcls/kd/metafile.yml + - configs/distill/mmcls/zskt/metafile.yml + - configs/distill/mmdet/fbkd/metafile.yml + - configs/distill/mmcls/factor_transfer/metafile.yml + - configs/distill/mmcls/ofd/metafile.yml + - configs/nas/mmcls/autoslim/metafile.yml + - configs/nas/mmcls/darts/metafile.yml + - configs/nas/mmdet/detnas/metafile.yml + - configs/distill/mmdet/pkd/metafile.yml # - configs/distill/mmdet3d/pkd/metafile.yml - # - configs/distill/mmcls/deit/metafile.yml - # - configs/pruning/mmcls/group_fisher/mobilenet/metafile.yml - # - configs/pruning/mmcls/group_fisher/resnet50/metafile.yml - # - configs/pruning/mmdet/group_fisher/retinanet/metafile.yml - # - configs/pruning/mmcls/l1-norm/metafile.yml - # - configs/pruning/mmcls/dmcp/metafile.yml + - configs/distill/mmcls/deit/metafile.yml + - configs/pruning/mmcls/group_fisher/mobilenet/metafile.yml + - configs/pruning/mmcls/group_fisher/resnet50/metafile.yml + - configs/pruning/mmdet/group_fisher/retinanet/metafile.yml + - configs/pruning/mmcls/l1-norm/metafile.yml + - configs/pruning/mmcls/dmcp/metafile.yml - configs/quantization/ptq/base/metafile.yml # - configs/quantization/qat/base/metafile.yml - configs/quantization/qat/lsq/metafile.yml From 1c95740637cc3d0d8e270caa6fef288c25220975 Mon Sep 17 00:00:00 2001 From: humu789 Date: Fri, 14 Apr 2023 23:14:12 +0800 Subject: [PATCH 34/44] fix benchmark_test:detnas_frcnn_shufflenet_subnet_coco_1x.py --- .../nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py index 0da0388f1..e10daec7d 100644 --- a/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py +++ b/configs/nas/mmdet/detnas/detnas_frcnn_shufflenet_subnet_coco_1x.py @@ -9,7 +9,7 @@ init_cfg=dict( type='Pretrained', checkpoint= # noqa: E251 - 'detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_v1.pth', # noqa: E501 + 'https://download.openmmlab.com/mmrazor/v1/detnas/detnas_subnet_frcnn_shufflenetv2_fpn_1x_coco_bbox_backbone_flops-0.34M_mAP-37.5_20220715-61d2e900_v1.pth', # noqa: E501 prefix='architecture.')) find_unused_parameters = False From bae160a749d4b96841a376defc947eec411d73f3 Mon Sep 17 00:00:00 2001 From: humu789 Date: Mon, 17 Apr 2023 16:04:53 +0800 Subject: [PATCH 35/44] update quantization README.md --- configs/quantization/ptq/base/README.md | 34 ++++++++++++++++++------- configs/quantization/qat/base/README.md | 20 ++++++++------- configs/quantization/qat/lsq/README.md | 21 ++++++++------- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/configs/quantization/ptq/base/README.md b/configs/quantization/ptq/base/README.md index b526d7274..4afcf0548 100644 --- a/configs/quantization/ptq/base/README.md +++ b/configs/quantization/ptq/base/README.md @@ -12,9 +12,23 @@ While neural networks have advanced the frontiers in many applications, they oft ### Classification -| Model | Dataset | Backend | Acc-fp32 | Acc-int8 | Acc-deployed | Download | -| -------- | -------- | -------- | -------- | -------- | ------------ | ------------ | -| resnet18 | ImageNet | openvino | | | | model \| log | +| Model | Dataset | Backend | Top 1 Acc(fp32) | Top 1 Acc(int8) | Top 1 Acc(deployed) | Config | Download | +| ------------ | -------- | -------- | --------------- | --------------- | ------------------- | ----------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| resnet18 | ImageNet | openvino | 69.90 | 69.742 | 69.74 | [config](./ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet18_8xb32_in1k_calib32xb32_20230330_163655-2386d965.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet18_8xb32_in1k_calib32xb32_20230330_163655-2386d965.log) | +| resnet50 | ImageNet | openvino | 76.55 | 76.374 | 76.378 | [config](./ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet50_8xb32_in1k_calib32xb32_20230330_170115-2acd6014.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_resnet50_8xb32_in1k_calib32xb32_20230330_170115-2acd6014.log) | +| mobilenet_v2 | ImageNet | openvino | 71.86 | 70.224 | 70.292 | [config](./ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_mbv2_8xb32_in1k_calib32xb32_20230330_170909-364822ad.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/openvino/ptq_openvino_mbv2_8xb32_in1k_calib32xb32_20230330_170909-364822ad.log) | +| resnet18 | ImageNet | tensorrt | 69.90 | 69.762 | 69.85 | [config](./ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32_20230331_144323-640b272e.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet18_8xb32_in1k_calib32xb32_20230331_144323-640b272e.log) | +| resnet50 | ImageNet | tensorrt | 76.55 | 76.372 | 76.374 | [config](./ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32_20230331_145011-d2da300f.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_resnet50_8xb32_in1k_calib32xb32_20230331_145011-d2da300f.log) | +| mobilenet_v2 | ImageNet | tensorrt | 71.86 | 70.324 | 70.548 | [config](./ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32_20230331_153131-335988e4.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/ptq/tensorrt/ptq_tensorrt_mbv2_8xb32_in1k_calib32xb32_20230331_153131-335988e4.log) | + +### Detection + +| Model | Dataset | Backend | box AP(fp32) | box AP(int8) | box AP(deployed) | Config | Download | +| -------------- | ------- | -------- | ------------ | ------------ | ---------------- | -------------------------------------------------------------- | ------------------------ | +| retina_r50_fpn | COCO | openvino | 36.5 | 36.3 | 36.3 | [config](./ptq_openvino_retina_r50_1x_coco_calib32xb32.py) | [model](<>) \| [log](<>) | +| yolox_s | COCO | openvino | 40.5 | 38.5 | 38.5 | [config](./ptq_openvino_yolox_s_8xb8-300e_coco_calib32xb32.py) | [model](<>) \| [log](<>) | +| retina_r50_fpn | COCO | tensorrt | 36.5 | 36.2 | 36.3 | [config](./ptq_tensorrt_retina_r50_1x_coco_calib32xb32.py) | [model](<>) \| [log](<>) | +| yolox_s | COCO | tensorrt | 40.5 | 38.8 | 39.3 | [config](./ptq_tensorrt_yolox_s_8xb8-300e_coco_calib32xb32.py) | [model](<>) \| [log](<>) | ## Citation @@ -30,14 +44,16 @@ While neural networks have advanced the frontiers in many applications, they oft ## Getting Started -### Train +**PTQ for pretrain model** -```python -python tools/train.py ${CONFIG_FILE} +``` +python tools/ptq.py ${CONFIG} ``` -### Test +**Test for quantized model** -```python -python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_PATH} ``` +python tools/test.py ${CONFIG} ${CKPT} +``` + +For more details, please refer to [Quantization Get Start](<>) diff --git a/configs/quantization/qat/base/README.md b/configs/quantization/qat/base/README.md index d0b2d2217..0693c5221 100644 --- a/configs/quantization/qat/base/README.md +++ b/configs/quantization/qat/base/README.md @@ -12,9 +12,9 @@ While neural networks have advanced the frontiers in many applications, they oft ### Classification -| Model | Dataset | Backend | Acc-fp32 | Acc-int8 | Acc-deployed | Download | -| -------- | -------- | -------- | -------- | -------- | ------------ | ------------ | -| resnet18 | ImageNet | openvino | | | | model \| log | +| Model | Dataset | Backend | Top 1 Acc(fp32) | Top 1 Acc(int8) | Config | Download | +| -------- | -------- | -------- | --------------- | --------------- | ----------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| resnet18 | ImageNet | openvino | 69.90 | 69.742(to do) | [config](./qat_openvino_resnet18_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/qat_openvino_resnet18_8xb32_10e_in1k_20230413_172732-5b9ff01d.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/qat_openvino_resnet18_8xb32_10e_in1k_20230413_172732-5b9ff01d.log) | ## Citation @@ -30,14 +30,16 @@ While neural networks have advanced the frontiers in many applications, they oft ## Getting Started -### Train +**QAT for pretrain model** -```python -python tools/train.py ${CONFIG_FILE} +``` +python tools/train.py ${CONFIG} ``` -### Test +**Test for quantized model** -```python -python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_PATH} ``` +python tools/test.py ${CONFIG} ${CKPT} +``` + +For more details, please refer to [Quantization Get Start](<>) diff --git a/configs/quantization/qat/lsq/README.md b/configs/quantization/qat/lsq/README.md index 99dd9d325..807c2786b 100644 --- a/configs/quantization/qat/lsq/README.md +++ b/configs/quantization/qat/lsq/README.md @@ -12,9 +12,10 @@ Deep networks run with low precision operations at inference time offer power an ### Classification -| Model | Dataset | Backend | Acc-fp32 | Acc-int8 | Acc-deployed | Download | -| -------- | -------- | -------- | -------- | -------- | ------------ | ------------ | -| resnet18 | ImageNet | openvino | | | | model \| log | +| Model | Dataset | Backend | Top 1 Acc(fp32) | Top 1 Acc(int8) | Max Epochs | Config | Download | +| -------- | -------- | -------- | --------------- | --------------- | ---------- | ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| resnet18 | ImageNet | openvino | 69.90 | 69.742 | 10 | [config](<>) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_10e_in1k_20230413_224237-36eac1f1.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_10e_in1k_20230413_224237-36eac1f1.log) | +| resnet18 | ImageNet | openvino | 69.90 | 69.742 | 100 | [config](<>) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_100e_in1k_20230402_173316-ca5993bf.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_100e_in1k_20230402_173316-ca5993bf.log) | ## Citation @@ -30,14 +31,16 @@ Deep networks run with low precision operations at inference time offer power an ## Getting Started -### Train +**QAT for pretrain model** -```python -python tools/train.py ${CONFIG_FILE} +``` +python tools/train.py ${CONFIG} ``` -### Test +**Test for quantized model** -```python -python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_PATH} ``` +python tools/test.py ${CONFIG} ${CKPT} +``` + +For more details, please refer to [Quantization Get Start](<>) From 700bfbd03e2a9f8062b40eaff2f5e07c71365f49 Mon Sep 17 00:00:00 2001 From: humu789 Date: Mon, 17 Apr 2023 17:27:09 +0800 Subject: [PATCH 36/44] update quantization metafile, readme, config path --- configs/quantization/qat/base/README.md | 6 +- configs/quantization/qat/base/metafile.yml | 8 +-- .../qat_openvino_resnet18_10e_8xb32_in1k.py | 0 .../base/qat_openvino_resnet18_8xb32_in1k.py | 64 ----------------- configs/quantization/qat/lsq/README.md | 8 +-- .../lsq_openvino_resnet18_8xb32_100e_in1k.py | 0 .../lsq_openvino_resnet18_8xb32_10e_in1k.py | 0 .../lsq/lsq_openvino_resnet18_8xb32_in1k.py | 68 ------------------- configs/quantization/qat/lsq/metafile.yml | 24 +++++-- model-index.yml | 4 +- 10 files changed, 33 insertions(+), 149 deletions(-) rename configs/quantization/qat/{ => base}/qat_openvino_resnet18_10e_8xb32_in1k.py (100%) delete mode 100644 configs/quantization/qat/base/qat_openvino_resnet18_8xb32_in1k.py rename configs/quantization/qat/{ => lsq}/lsq_openvino_resnet18_8xb32_100e_in1k.py (100%) rename configs/quantization/qat/{ => lsq}/lsq_openvino_resnet18_8xb32_10e_in1k.py (100%) delete mode 100644 configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_in1k.py diff --git a/configs/quantization/qat/base/README.md b/configs/quantization/qat/base/README.md index 0693c5221..1d98178f1 100644 --- a/configs/quantization/qat/base/README.md +++ b/configs/quantization/qat/base/README.md @@ -12,9 +12,9 @@ While neural networks have advanced the frontiers in many applications, they oft ### Classification -| Model | Dataset | Backend | Top 1 Acc(fp32) | Top 1 Acc(int8) | Config | Download | -| -------- | -------- | -------- | --------------- | --------------- | ----------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| resnet18 | ImageNet | openvino | 69.90 | 69.742(to do) | [config](./qat_openvino_resnet18_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/qat_openvino_resnet18_8xb32_10e_in1k_20230413_172732-5b9ff01d.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/qat_openvino_resnet18_8xb32_10e_in1k_20230413_172732-5b9ff01d.log) | +| Model | Dataset | Backend | Top 1 Acc(fp32) | Top 1 Acc(int8) | Config | Download | +| -------- | -------- | -------- | --------------- | --------------- | --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| resnet18 | ImageNet | openvino | 69.90 | 69.98 | [config](./qat_openvino_resnet18_10e_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/qat_openvino_resnet18_8xb32_10e_in1k_20230413_172732-5b9ff01d.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/qat_openvino_resnet18_8xb32_10e_in1k_20230413_172732-5b9ff01d.log) | ## Citation diff --git a/configs/quantization/qat/base/metafile.yml b/configs/quantization/qat/base/metafile.yml index 959e74299..bd4015a50 100644 --- a/configs/quantization/qat/base/metafile.yml +++ b/configs/quantization/qat/base/metafile.yml @@ -2,7 +2,7 @@ Collections: - Name: QAT README: configs/quantization/qat/base/README.md Models: - - Name: qat_openvino_resnet18_8xb32_in1k + - Name: qat_openvino_resnet18_10e_8xb32_in1k.py In Collection: QAT Metadata: Backend: openvino @@ -15,6 +15,6 @@ Models: - Task: Image Classification Dataset: ImageNet-1k Metrics: - Top 1 Accuracy: 70.06 - Config: configs/quantization/qat/base/qat_openvino_resnet18_8xb32_in1k.py - Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_in1k_20230402_173316-0d441f23.pth + Top 1 Accuracy: 69.98 + Config: configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/qat_openvino_resnet18_8xb32_10e_in1k_20230413_172732-5b9ff01d.pth diff --git a/configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py b/configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py similarity index 100% rename from configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py rename to configs/quantization/qat/base/qat_openvino_resnet18_10e_8xb32_in1k.py diff --git a/configs/quantization/qat/base/qat_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/base/qat_openvino_resnet18_8xb32_in1k.py deleted file mode 100644 index 35547d91b..000000000 --- a/configs/quantization/qat/base/qat_openvino_resnet18_8xb32_in1k.py +++ /dev/null @@ -1,64 +0,0 @@ -_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] - -train_dataloader = dict(batch_size=1024) - -global_qconfig = dict( - w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), - a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), - w_fake_quant=dict(type='mmrazor.FakeQuantize'), - a_fake_quant=dict(type='mmrazor.FakeQuantize'), - w_qscheme=dict( - qdtype='qint8', - bit=8, - is_symmetry=True, - is_symmetric_range=True, - ), - a_qscheme=dict( - qdtype='quint8', - bit=8, - is_symmetry=True, - averaging_constant=0.1, - ), -) - -float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 - -model = dict( - _delete_=True, - type='mmrazor.MMArchitectureQuant', - architecture=_base_.model, - float_checkpoint=float_checkpoint, - quantizer=dict( - type='mmrazor.OpenVINOQuantizer', - global_qconfig=global_qconfig, - tracer=dict( - type='mmrazor.CustomTracer', - skipped_methods=[ - 'mmcls.models.heads.ClsHead._get_loss', - 'mmcls.models.heads.ClsHead._get_predictions' - ]))) - -optim_wrapper = dict( - optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001)) - -# learning policy -param_scheduler = dict( - _delete_=True, - type='CosineAnnealingLR', - T_max=100, - by_epoch=True, - begin=0, - end=100) - -model_wrapper_cfg = dict( - type='mmrazor.MMArchitectureQuantDDP', - broadcast_buffers=False, - find_unused_parameters=True) - -# train, val, test setting -train_cfg = dict( - _delete_=True, - type='mmrazor.QATEpochBasedLoop', - max_epochs=100, - val_interval=1) -val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') diff --git a/configs/quantization/qat/lsq/README.md b/configs/quantization/qat/lsq/README.md index 807c2786b..b4d40f0c4 100644 --- a/configs/quantization/qat/lsq/README.md +++ b/configs/quantization/qat/lsq/README.md @@ -12,10 +12,10 @@ Deep networks run with low precision operations at inference time offer power an ### Classification -| Model | Dataset | Backend | Top 1 Acc(fp32) | Top 1 Acc(int8) | Max Epochs | Config | Download | -| -------- | -------- | -------- | --------------- | --------------- | ---------- | ------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| resnet18 | ImageNet | openvino | 69.90 | 69.742 | 10 | [config](<>) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_10e_in1k_20230413_224237-36eac1f1.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_10e_in1k_20230413_224237-36eac1f1.log) | -| resnet18 | ImageNet | openvino | 69.90 | 69.742 | 100 | [config](<>) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_100e_in1k_20230402_173316-ca5993bf.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_100e_in1k_20230402_173316-ca5993bf.log) | +| Model | Dataset | Backend | Top 1 Acc(fp32) | Top 1 Acc(int8) | Max Epochs | Config | Download | +| -------- | -------- | -------- | --------------- | --------------- | ---------- | ---------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| resnet18 | ImageNet | openvino | 69.90 | 69.418 | 10 | [config](./lsq_openvino_resnet18_8xb32_10e_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_10e_in1k_20230413_224237-36eac1f1.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_10e_in1k_20230413_224237-36eac1f1.log) | +| resnet18 | ImageNet | openvino | 69.90 | 69.992 | 100 | [config](./lsq_openvino_resnet18_8xb32_100e_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_100e_in1k_20230402_173316-ca5993bf.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_100e_in1k_20230402_173316-ca5993bf.log) | ## Citation diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py b/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_100e_in1k.py similarity index 100% rename from configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py rename to configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_100e_in1k.py diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py b/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_10e_in1k.py similarity index 100% rename from configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py rename to configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_10e_in1k.py diff --git a/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_in1k.py deleted file mode 100644 index 00e424141..000000000 --- a/configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_in1k.py +++ /dev/null @@ -1,68 +0,0 @@ -_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] - -resnet = _base_.model -float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 - -global_qconfig = dict( - w_observer=dict(type='mmrazor.LSQPerChannelObserver'), - a_observer=dict(type='mmrazor.LSQObserver'), - w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - w_qscheme=dict( - qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), - a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), -) - -model = dict( - _delete_=True, - _scope_='mmrazor', - type='MMArchitectureQuant', - data_preprocessor=dict( - type='mmcls.ClsDataPreprocessor', - num_classes=1000, - # RGB format normalization parameters - mean=[123.675, 116.28, 103.53], - std=[58.395, 57.12, 57.375], - # convert image from BGR to RGB - to_rgb=True), - architecture=resnet, - float_checkpoint=float_checkpoint, - quantizer=dict( - type='mmrazor.OpenVINOQuantizer', - global_qconfig=global_qconfig, - tracer=dict( - type='mmrazor.CustomTracer', - skipped_methods=[ - 'mmcls.models.heads.ClsHead._get_loss', - 'mmcls.models.heads.ClsHead._get_predictions' - ]))) - -optim_wrapper = dict( - optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001)) - -# learning policy -param_scheduler = dict( - _delete_=True, - type='CosineAnnealingLR', - T_max=100, - by_epoch=True, - begin=0, - end=100) - -model_wrapper_cfg = dict( - type='mmrazor.MMArchitectureQuantDDP', - broadcast_buffers=False, - find_unused_parameters=True) - -# train, val, test setting -train_cfg = dict( - _delete_=True, - type='mmrazor.LSQEpochBasedLoop', - max_epochs=100, - val_interval=1, - freeze_bn_begin=1) -val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') - -# Make sure the buffer such as min_val/max_val in saved checkpoint is the same -# among different rank. -default_hooks = dict(sync=dict(type='SyncBuffersHook')) diff --git a/configs/quantization/qat/lsq/metafile.yml b/configs/quantization/qat/lsq/metafile.yml index 7d9235cbf..89308d333 100644 --- a/configs/quantization/qat/lsq/metafile.yml +++ b/configs/quantization/qat/lsq/metafile.yml @@ -2,7 +2,7 @@ Collections: - Name: LSQ README: configs/quantization/qat/lsq/README.md Models: - - Name: lsq_openvino_resnet18_8xb32_in1k + - Name: lsq_openvino_resnet18_8xb32_10e_in1k.py In Collection: LSQ Metadata: Backend: openvino @@ -15,6 +15,22 @@ Models: - Task: Image Classification Dataset: ImageNet-1k Metrics: - Top 1 Accuracy: 70.06 - Config: configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_in1k.py - Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_in1k_20230402_173316-0d441f23.pth + Top 1 Accuracy: 69.418 + Config: configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_10e_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_10e_in1k_20230413_224237-36eac1f1.pth + - Name: lsq_openvino_resnet18_8xb32_100e_in1k.py + In Collection: LSQ + Metadata: + Backend: openvino + Float Model: + Config: mmcls::resnet/resnet18_8xb32_in1k.py + Weights: https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth + Metrics: + Top 1 Accuracy: 69.90 + Results: + - Task: Image Classification + Dataset: ImageNet-1k + Metrics: + Top 1 Accuracy: 69.992 + Config: configs/quantization/qat/lsq/lsq_openvino_resnet18_8xb32_100e_in1k.py + Weights: https://download.openmmlab.com/mmrazor/v1/quantization/qat/openvino/lsq_openvino_resnet18_8xb32_100e_in1k_20230402_173316-ca5993bf.pth diff --git a/model-index.yml b/model-index.yml index 51ce0b009..efcd3cae5 100644 --- a/model-index.yml +++ b/model-index.yml @@ -19,7 +19,7 @@ Import: - configs/nas/mmcls/darts/metafile.yml - configs/nas/mmdet/detnas/metafile.yml - configs/distill/mmdet/pkd/metafile.yml - # - configs/distill/mmdet3d/pkd/metafile.yml + - configs/distill/mmdet3d/pkd/metafile.yml - configs/distill/mmcls/deit/metafile.yml - configs/pruning/mmcls/group_fisher/mobilenet/metafile.yml - configs/pruning/mmcls/group_fisher/resnet50/metafile.yml @@ -27,5 +27,5 @@ Import: - configs/pruning/mmcls/l1-norm/metafile.yml - configs/pruning/mmcls/dmcp/metafile.yml - configs/quantization/ptq/base/metafile.yml - # - configs/quantization/qat/base/metafile.yml + - configs/quantization/qat/base/metafile.yml - configs/quantization/qat/lsq/metafile.yml From d431b7dbe2c149a9f4fc091420665d43774a781a Mon Sep 17 00:00:00 2001 From: humu789 Date: Mon, 17 Apr 2023 18:01:19 +0800 Subject: [PATCH 37/44] update quantization docs --- configs/quantization/ptq/base/README.md | 2 +- configs/quantization/qat/base/README.md | 2 +- configs/quantization/qat/lsq/README.md | 2 +- .../en/user_guides/quantization_user_guide.md | 34 ++++++++++++------- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/configs/quantization/ptq/base/README.md b/configs/quantization/ptq/base/README.md index 4afcf0548..95bfa623b 100644 --- a/configs/quantization/ptq/base/README.md +++ b/configs/quantization/ptq/base/README.md @@ -56,4 +56,4 @@ python tools/ptq.py ${CONFIG} python tools/test.py ${CONFIG} ${CKPT} ``` -For more details, please refer to [Quantization Get Start](<>) +For more details, please refer to [Quantization User Guide](mmrazor/docs/en/user_guides/quantization_user_guide.md) diff --git a/configs/quantization/qat/base/README.md b/configs/quantization/qat/base/README.md index 1d98178f1..f40107be0 100644 --- a/configs/quantization/qat/base/README.md +++ b/configs/quantization/qat/base/README.md @@ -42,4 +42,4 @@ python tools/train.py ${CONFIG} python tools/test.py ${CONFIG} ${CKPT} ``` -For more details, please refer to [Quantization Get Start](<>) +For more details, please refer to [Quantization User Guide](mmrazor/docs/en/user_guides/quantization_user_guide.md) diff --git a/configs/quantization/qat/lsq/README.md b/configs/quantization/qat/lsq/README.md index b4d40f0c4..1d59f8532 100644 --- a/configs/quantization/qat/lsq/README.md +++ b/configs/quantization/qat/lsq/README.md @@ -43,4 +43,4 @@ python tools/train.py ${CONFIG} python tools/test.py ${CONFIG} ${CKPT} ``` -For more details, please refer to [Quantization Get Start](<>) +For more details, please refer to [Quantization User Guide](mmrazor/docs/en/user_guides/quantization_user_guide.md) diff --git a/docs/en/user_guides/quantization_user_guide.md b/docs/en/user_guides/quantization_user_guide.md index d645d8451..503e39913 100644 --- a/docs/en/user_guides/quantization_user_guide.md +++ b/docs/en/user_guides/quantization_user_guide.md @@ -17,31 +17,30 @@ MMRazor's quantization is OpenMMLab's quantization toolkit, which has got throug MMRazor's quantization is based on `torch==1.13`. Other requirements are the same as MMRazor's ``` -Model quantization is in mmrazor, but quantized model deployment is in mmdeploy. So we need to use two branches as follows: +Model quantization is in mmrazor, but quantized model deployment is in mmdeploy. So we need to the another branches as follows if we need to delopy our quantized model: -mmrazor: https://github.com/open-mmlab/mmrazor/tree/quantize - -mmdeploy: https://github.com/humu789/mmdeploy/tree/adapt_razor_quantize +mmdeploy: https://github.com/open-mmlab/mmdeploy/tree/for_mmrazor 1. Quantize the float model in mmrazor. ```Shell # For QAT (Quantization Aware Training) -python tools/train.py ${CONFIG_FILE} [optional arguments] +python tools/train.py ${CONFIG_PATH} [optional arguments] # For PTQ (Post-training quantization) -python tools/ptq.py ${CONFIG_FILE} [optional arguments] +python tools/ptq.py ${CONFIG_PATH} [optional arguments] ``` -2. Convert quantized model checkpoint in mmrazor. (required by model deployment) +2. Evaluate the quantized model. (optional) ```Shell -python tools/model_converters/convert_quant_ckpt.py ${CKPT_PATH} +python tools/test.py ${CONFIG_PATH} ${CHECKPOINT_PATH} ``` 3. Export quantized model to a specific backend in mmdeploy. (required by model deployment) ```Shell +# MODEL_CFG_PATH is the used config in mmrazor. python ./tools/deploy.py \ ${DEPLOY_CFG_PATH} \ ${MODEL_CFG_PATH} \ @@ -52,7 +51,7 @@ python ./tools/deploy.py \ This step is the same as how to export an OpenMMLab model to a specific backend. For more details, please refer to [How to convert model](https://github.com/open-mmlab/mmdeploy/blob/master/docs/en/02-how-to-run/convert_model.md) -4. Evaluate the exported model. (optional) +4. Evaluate the quantized backend model. (optional) ```Shell python tools/test.py \ @@ -77,13 +76,16 @@ You can refer to the previous chapter Quick Run. Let us take `resnet50` as an example to show how to handle case 2. ```Python -_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] +_base_ = [ + 'mmcls::resnet/resnet18_8xb32_in1k.py', + '../../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' +] -train_dataloader = dict(batch_size=32) +val_dataloader = dict(batch_size=32) test_cfg = dict( type='mmrazor.PTQLoop', - calibrate_dataloader=train_dataloader, + calibrate_dataloader=val_dataloader, calibrate_steps=32, ) @@ -112,6 +114,7 @@ model = dict( # convert image from BGR to RGB to_rgb=True), architecture=_base_.model, + deploy_cfg=_base_.deploy_cfg, float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', @@ -210,6 +213,13 @@ Include some basic quantization configurations. `qdtype`: to specify whether quantized data type is sign or unsign. It can be chosen from \[ 'qint8', 'quint8' \] +```{note} +If your model need to be deployed, `qdtype` must be consistent with the dtype in the corresponding backendconfig. Otherwise fakequant will not be inserted in front of the specified OPs. + +backendconfigs dir: +mmrazor/mmrazor/structures/quantization/backend_config +``` + `bit`: to specify the quantized data bit. It can be chosen from \[1 ~ 16\]. `is_symmetry`: to specify whether to use symmetry quantization. It can be chosen from \[ True, False \] From 0608dd4732a4ae2c2a1fc5074b379e8a4c88ce9d Mon Sep 17 00:00:00 2001 From: humu789 Date: Mon, 17 Apr 2023 18:04:58 +0800 Subject: [PATCH 38/44] update git main link in workflow --- .github/workflows/build.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2c2b8ed21..4afb998af 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -103,11 +103,11 @@ jobs: pip install -U openmim mim install 'mmcv >= 2.0.0rc1' - name: Install MMCls - run: pip install git+https://github.com/open-mmlab/mmclassification.git@dev-1.x + run: pip install git+https://github.com/open-mmlab/mmclassification.git@main - name: Install MMDet - run: pip install git+https://github.com/open-mmlab/mmdetection.git@dev-3.x + run: pip install git+https://github.com/open-mmlab/mmdetection.git@main - name: Install MMSeg - run: pip install git+https://github.com/open-mmlab/mmsegmentation.git@dev-1.x + run: pip install git+https://github.com/open-mmlab/mmsegmentation.git@main - name: Install other dependencies run: pip install -r requirements.txt - name: Build and install From bd3449a499bd8eb44f03f9ee3913c2d872a0720c Mon Sep 17 00:00:00 2001 From: humu789 Date: Mon, 17 Apr 2023 18:16:25 +0800 Subject: [PATCH 39/44] update benchmark_summary_analyse.py --- .dev_scripts/benchmark_summary_analyse.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.dev_scripts/benchmark_summary_analyse.py b/.dev_scripts/benchmark_summary_analyse.py index a4896990c..372e1326c 100644 --- a/.dev_scripts/benchmark_summary_analyse.py +++ b/.dev_scripts/benchmark_summary_analyse.py @@ -15,7 +15,8 @@ def parse_args(): metric_mapping = { 'Top 1 Accuracy': 'accuracy/top1', 'Top 5 Accuracy': 'accuracy/top5', - 'box AP': 'coco/bbox_mAP' + 'box AP': 'coco/bbox_mAP', + 'mIoU': 'mIoU' } From 07987a6d1f8feb294f7bd742e3fa1d1e4f895ded Mon Sep 17 00:00:00 2001 From: humu789 Date: Mon, 17 Apr 2023 18:48:52 +0800 Subject: [PATCH 40/44] del dmcp results --- configs/pruning/mmcls/dmcp/README.md | 8 ++++---- configs/pruning/mmcls/dmcp/metafile.yml | 20 ++++++++++---------- model-index.yml | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/configs/pruning/mmcls/dmcp/README.md b/configs/pruning/mmcls/dmcp/README.md index 60dbd9849..3a96b61c3 100644 --- a/configs/pruning/mmcls/dmcp/README.md +++ b/configs/pruning/mmcls/dmcp/README.md @@ -26,7 +26,7 @@ GPUS=32 sh tools/slurm_train.sh $PARTITION $JOB_NAME \ --work-dir $WORK_DIR ``` -## Results and models + @@ -42,12 +42,12 @@ GPUS=32 sh tools/slurm_train.sh $PARTITION $JOB_NAME \ | ImageNet | ResNet50 | 2.07G(Subnet) | 76.11 | 93.01 | [config](./dmcp_resnet50_subnet_32xb64.py) | [model](https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/resnet50/2G/DMCP_R50_2G.pth) / [log](https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/resnet50/2G/dmcp_resnet50_supernet_32xb64_target_flops_2g_20230129_112944.log) | [arch\*](./DMCP_R50_2G.json) | | ImageNet | ResNet50 | 1.05G(Subnet) | 74.12 | 92.33 | [config](./dmcp_resnet50_subnet_32xb64.py) | [model](https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/resnet50/1G/DMCP_R50_1G.pth) / [log](https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/resnet50/1G/dmcp_resnet50_supernet_32xb64_target_flops_1g_20230107_223552.log) | [arch](https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/resnet50/1G/DMCP_R50_1G.json) | --> -**Note** + ## Citation diff --git a/configs/pruning/mmcls/dmcp/metafile.yml b/configs/pruning/mmcls/dmcp/metafile.yml index f20e2488f..131f5c289 100644 --- a/configs/pruning/mmcls/dmcp/metafile.yml +++ b/configs/pruning/mmcls/dmcp/metafile.yml @@ -1,4 +1,4 @@ -Models: +# Models: # - Name: dmcp_resnet50_subnet_32xb64 # In Collection: DMCP # Config: configs/pruning/mmcls/dmcp/dmcp_resnet50_subnet_32xb64.py @@ -8,12 +8,12 @@ Models: # Dataset: ImageNet-1k # Metrics: # Top 1 Accuracy: 76.11 - - Name: dmcp_mbv2_subnet_32xb64 - In Collection: DMCP - Config: configs/pruning/mmcls/dmcp/dmcp_mbv2_subnet_32xb64.py - Weights: https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/mobilenetv2/100M/DMCP_MBV2_100M.pth - Results: - - Task: Image Classification - Dataset: ImageNet-1k - Metrics: - Top 1 Accuracy: 67.22 + # - Name: dmcp_mbv2_subnet_32xb64 + # In Collection: DMCP + # Config: configs/pruning/mmcls/dmcp/dmcp_mbv2_subnet_32xb64.py + # Weights: https://download.openmmlab.com/mmrazor/v1/pruning/dmcp/mobilenetv2/100M/DMCP_MBV2_100M.pth + # Results: + # - Task: Image Classification + # Dataset: ImageNet-1k + # Metrics: + # Top 1 Accuracy: 67.22 diff --git a/model-index.yml b/model-index.yml index efcd3cae5..8087d81fc 100644 --- a/model-index.yml +++ b/model-index.yml @@ -25,7 +25,7 @@ Import: - configs/pruning/mmcls/group_fisher/resnet50/metafile.yml - configs/pruning/mmdet/group_fisher/retinanet/metafile.yml - configs/pruning/mmcls/l1-norm/metafile.yml - - configs/pruning/mmcls/dmcp/metafile.yml + # - configs/pruning/mmcls/dmcp/metafile.yml - configs/quantization/ptq/base/metafile.yml - configs/quantization/qat/base/metafile.yml - configs/quantization/qat/lsq/metafile.yml From 2ba02c7d69af9c50d7a9540d6de611f2d0ad64ea Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Mon, 17 Apr 2023 18:58:46 +0800 Subject: [PATCH 41/44] [Bug] fix a rebase error (#514) fix a rebase error --- configs/pruning/mmpose/dcff/fix_subnet.json | 4 ---- 1 file changed, 4 deletions(-) diff --git a/configs/pruning/mmpose/dcff/fix_subnet.json b/configs/pruning/mmpose/dcff/fix_subnet.json index f7b40f41d..dfdcea758 100644 --- a/configs/pruning/mmpose/dcff/fix_subnet.json +++ b/configs/pruning/mmpose/dcff/fix_subnet.json @@ -54,11 +54,7 @@ "min_value":1, "min_ratio":0.9 }, -<<<<<<< HEAD "choice":0.59375 -======= - "choice":0.59374 ->>>>>>> 985a611e (Merge dev-1.x into quantize (#430)) }, "backbone.layer2.1.conv1_(0, 128)_128":{ "init_args":{ From f9c2cb06da442933a3a3f0dcfe187bfc0b604bb2 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Tue, 18 Apr 2023 01:07:52 +0800 Subject: [PATCH 42/44] [Bug] Fix CI (#515) * fix ci * mmcv2.0 need torch1.8+ --- .github/workflows/build.yml | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4afb998af..88928727e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -29,18 +29,8 @@ jobs: strategy: matrix: python-version: [3.7] - torch: [1.6.0, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0] + torch: [1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0] include: - - torch: 1.6.0 - torch_version: 1.6 - torchvision: 0.7.0 - - torch: 1.7.0 - torch_version: 1.7 - torchvision: 0.8.1 - - torch: 1.7.0 - torch_version: 1.7 - torchvision: 0.8.1 - python-version: 3.8 - torch: 1.8.0 torch_version: 1.8 torchvision: 0.9.0 @@ -103,7 +93,7 @@ jobs: pip install -U openmim mim install 'mmcv >= 2.0.0rc1' - name: Install MMCls - run: pip install git+https://github.com/open-mmlab/mmclassification.git@main + run: pip install 'mmcls>=1.0.0rc0' - name: Install MMDet run: pip install git+https://github.com/open-mmlab/mmdetection.git@main - name: Install MMSeg From 7282a150e4d4eaafdce5f2a7b32963921652e481 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Tue, 18 Apr 2023 02:13:37 +0800 Subject: [PATCH 43/44] Update CI config and Passed (#516) * test ci * update test.yml based on mmcv2.0.0 --- .circleci/test.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.circleci/test.yml b/.circleci/test.yml index 25140a879..38a4a4e3b 100644 --- a/.circleci/test.yml +++ b/.circleci/test.yml @@ -103,9 +103,9 @@ jobs: name: Clone Repos command: | git clone -b main --depth 1 https://github.com/open-mmlab/mmengine.git /home/circleci/mmengine - git clone -b dev-3.x --depth 1 https://github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection - git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmclassification.git /home/circleci/mmclassification - git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmsegmentation.git /home/circleci/mmsegmentation + git clone -b main --depth 1 https://github.com/open-mmlab/mmdetection.git /home/circleci/mmdetection + git clone -b 1.x --depth 1 https://github.com/open-mmlab/mmclassification.git /home/circleci/mmclassification + git clone -b main --depth 1 https://github.com/open-mmlab/mmsegmentation.git /home/circleci/mmsegmentation - run: name: Build Docker image command: | @@ -153,15 +153,15 @@ workflows: - dev-1.x - build_cpu: name: minimum_version_cpu - torch: 1.6.0 - torchvision: 0.7.0 - python: 3.7.9 + torch: 1.8.1 + torchvision: 0.9.1 + python: 3.7.4 requires: - lint - build_cpu: name: maximum_version_cpu - torch: 1.12.1 - torchvision: 0.13.1 + torch: 1.13.1 + torchvision: 0.14.1 python: 3.9.0 requires: - lint @@ -183,7 +183,7 @@ workflows: jobs: - build_cuda: name: minimum_version_gpu - torch: 1.6.0 + torch: 1.8.1 # Use double quotation mark to explicitly specify its type # as string instead of number cuda: "10.1" From ab19a46f04164165659a1e8e462257e551901055 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Tue, 18 Apr 2023 12:34:37 +0800 Subject: [PATCH 44/44] [Docs] Fix cwd test accuary (#517) * test ci * update test.yml based on mmcv2.0.0 * update cwd_logits_pspnet result --- configs/distill/mmcls/ofd/README.md | 14 +++++++------- configs/distill/mmcls/ofd/metafile.yml | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/configs/distill/mmcls/ofd/README.md b/configs/distill/mmcls/ofd/README.md index 74a931b0d..eb789e840 100644 --- a/configs/distill/mmcls/ofd/README.md +++ b/configs/distill/mmcls/ofd/README.md @@ -22,16 +22,16 @@ We investigate the design aspects of feature distillation methods achieving netw #### Vanilla -| Dataset | Model | Top-1 (%) | Top-5 (%) | Download | -| ------- | ----------------------------------------------------------------------- | --------- | --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| CIFAR10 | [WRN16-2](../../../vanilla/mmcls/wide-resnet/wrn16-w2_b16x8_cifar10.py) | 93.43 | 99.75 | [model](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn16_2_b16x8_cifar10_20220831_204709-446b466e.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn16_2_b16x8_cifar10_20220831_204709-446b466e.json) | -| CIFAR10 | [WRN28-4](../../../vanilla/mmcls/wide-resnet/wrn28-w4_b16x8_cifar10.py) | 95.49 | 99.81 | [model](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn28_4_b16x8_cifar10_20220831_173536-d6f8725c.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn28_4_b16x8_cifar10_20220831_173536-d6f8725c.json) | +| Dataset | Model | Top-1 (%) | Download | +| ------- | ----------------------------------------------------------------------- | --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| CIFAR10 | [WRN16-2](../../../vanilla/mmcls/wide-resnet/wrn16-w2_b16x8_cifar10.py) | 93.43 | [model](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn16_2_b16x8_cifar10_20220831_204709-446b466e.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn16_2_b16x8_cifar10_20220831_204709-446b466e.json) | +| CIFAR10 | [WRN28-4](../../../vanilla/mmcls/wide-resnet/wrn28-w4_b16x8_cifar10.py) | 95.49 | [model](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn28_4_b16x8_cifar10_20220831_173536-d6f8725c.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/wide_resnet/wrn28_4_b16x8_cifar10_20220831_173536-d6f8725c.json) | #### Distillation -| Dataset | Model | Flops(M) | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download | -| ------- | ------- | -------- | ------- | --------- | --------- | ----------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| CIFAR10 | WRN16-2 | 101 | WRN28-4 | 95.23 | 99.79 | [config](./ofd_backbone_resnet50_resnet18_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmrazor/v1/overhaul/ofd_backbone_resnet50_resnet18_8xb16_cifar10_20220831_220553-f5d12e61.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/overhaul/ofd_backbone_resnet50_resnet18_8xb16_cifar10_20220831_220553-f5d12e61.json) | +| Dataset | Model | Flops(M) | Teacher | Top-1 (%) | Configs | Download | +| ------- | ------- | -------- | ------- | --------- | ----------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| CIFAR10 | WRN16-2 | 101 | WRN28-4 | 94.21 | [config](./ofd_backbone_resnet50_resnet18_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmrazor/v1/overhaul/ofd_backbone_resnet50_resnet18_8xb16_cifar10_20230417_192216-ace2908f.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/overhaul/ofd_backbone_resnet50_resnet18_8xb16_cifar10_20230417_192216-ace2908f.log) | ## Getting Started diff --git a/configs/distill/mmcls/ofd/metafile.yml b/configs/distill/mmcls/ofd/metafile.yml index 21716fd5c..cb176b1c3 100644 --- a/configs/distill/mmcls/ofd/metafile.yml +++ b/configs/distill/mmcls/ofd/metafile.yml @@ -33,6 +33,6 @@ Models: - Task: Image Classification Dataset: CIFAR-10 Metrics: - Top 1 Accuracy: 95.4400 + Top 1 Accuracy: 94.21 Config: configs/distill/mmcls/ofd/ofd_backbone_resnet50_resnet18_8xb16_cifar10.py - Weights: https://download.openmmlab.com/mmrazor/v1/overhaul/ofd_backbone_resnet50_resnet18_8xb16_cifar10_20220831_220553-f5d12e61.pth + Weights: https://download.openmmlab.com/mmrazor/v1/overhaul/ofd_backbone_resnet50_resnet18_8xb16_cifar10_20230417_192216-ace2908f.pth