From 85ed93ec5e1db67ee2aa6b6ea632c1a90f529f89 Mon Sep 17 00:00:00 2001
From: "P.Huang" <37200926+FreakieHuang@users.noreply.github.com>
Date: Fri, 11 Nov 2022 10:18:20 +0800
Subject: [PATCH] [FEATURE] add quant algo `Learned Step Size Quantization`
 (#346)

* update
* Fix a bug in make_divisible. (#333)
fix bug in make_divisible
Co-authored-by: liukai
* [Fix] Fix counter mapping bug (#331)
* fix counter mapping bug
* move judgment into get_counter_type & update UT
* [Docs] Add MMYOLO projects link (#334)
* [Doc] fix typos in en/usr_guides (#299)
* Update README.md
* Update README_zh-CN.md
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
* [Features] Support `MethodInputsRecorder` and `FunctionInputsRecorder` (#320)
* support MethodInputsRecorder and FunctionInputsRecorder
* fix bugs that the model can not be pickled
* WIP: add pytest for ema model
* fix bugs in recorder and delivery when ema_hook is used
* don't register the DummyDataset
* fix pytest
* updated
* retina loss & predict & tensor DONE
* [Feature] Add deit-base (#332)
* WIP: support deit
* WIP: add deithead
* WIP: fix checkpoint hook
* fix data preprocessor
* fix cfg
* WIP: add readme
* reset single_teacher_distill
* add metafile
* add model to model-index
* fix configs and readme
* [Feature] Feature map visualization (#293)
* WIP: vis
* WIP: add visualization
* WIP: add visualization hook
* WIP: support razor visualizer
* WIP
* WIP: wrap draw_featmap
* support feature map visualization
* add a demo image for visualization
* fix typos
* change eps to 1e-6
* add pytest for visualization
* fix vis hook
* fix arguments' name
* fix img path
* support draw inference results
* add visualization doc
* fix figure url
* move files
Co-authored-by: weihan cao
* [Feature] Add kd examples (#305)
* support kd for mbv2 and shufflenetv2
* WIP: fix ckpt path
* WIP: fix kd r34-r18
* add metafile
* fix metafile
* delete
* [Doc] add documents about pruning. (#313)
* init
* update user guide
* update images
* update
* update How to prune your model
* update how_to_use_config_tool_of_pruning.md
* update doc
* move location
* update
* update
* update
* add mutablechannels.md
* add references
Co-authored-by: liukai
Co-authored-by: jacky
* [Feature] PyTorch version of `PKD: General Distillation Framework for Object Detectors via Pearson Correlation Coefficient`. (#304)
* add pkd
* add pytest for pkd
* fix cfg
* WIP: support fcos3d
* WIP: support fcos3d pkd
* support mmdet3d
* fix cfgs
* change eps to 1e-6 and add some comments
* fix docstring
* fix cfg
* add assert
* add type hint
* WIP: add readme and metafile
* fix readme
* update metafiles and readme
* fix metafile
* fix pipeline figure
* for RFC
* Custom FX initialize
* add UT init
* [Refactor] Refactor Mutables and Mutators (#324)
* refactor mutables
* update load fix subnet
* add DumpChosen Typehint
* adapt UTs
* fix lint
* Add GroupMixin to ChannelMutator (temporarily)
* fix type hints
* add GroupMixin doc-string
* modified by comments
* fix type hints
* update subnet format
* fix channel group bugs and add UTs
* fix doc string
* fix comments
* refactor diff module forward
* fix error in channel mutator doc
* fix comments
Co-authored-by: liukai
* [Fix] Update readme (#341)
* update kl readme
* update dsnas readme
* fix url
* Bump version to 1.0.0rc1 (#338)
update version
* init demo
* add custom_tracer
* add quantizer
* add fake_quant, loop, config
* remove CPatcher in custom_tracer
* demo_try
* init version
* modified base.py
* pre-rebase
* wip of adaround series
* adaround experiment
* transfer to s2
* update api
* point at sub_reconstruction
* pre-checkout
* export onnx
* add custom tracer
* fix lint
* move custom tracer
* fix import
* TODO: UTs
* Successfully RUN
* update loop
* update loop docstrings
* update quantizer docstrings
* update qscheme docstrings
* update qobserver docstrings
* update tracer docstrings
* update UTs init
* update UTs init
* fix review comments
* fix CI
* fix UTs
* update torch requirements

Co-authored-by: huangpengsheng
Co-authored-by: LKJacky <108643365+LKJacky@users.noreply.github.com>
Co-authored-by: liukai
Co-authored-by: Yang Gao
Co-authored-by: kitecats <90194592+kitecats@users.noreply.github.com>
Co-authored-by: Sheffield <49406546+SheffieldCao@users.noreply.github.com>
Co-authored-by: whcao <41630003+HIT-cwh@users.noreply.github.com>
Co-authored-by: jacky
Co-authored-by: pppppM <67539920+pppppM@users.noreply.github.com>
Co-authored-by: humu789
---
 .github/workflows/build.yml                   |  38 ---
 configs/quantization/ptq/adaround.py          |  47 +++
 configs/quantization/ptq/demo.py              |   1 +
 configs/quantization/qat/demo.py              |   1 +
 .../qat/lsq_resnet50_8xb16_cifar10.py         |  37 +++
 mmrazor/engine/runner/__init__.py             |   4 +-
 mmrazor/engine/runner/quantization_loops.py   | 298 ++++++++++++++++++
 mmrazor/engine/runner/utils/state.py          |  19 ++
 mmrazor/engine/runner/utils/subgraph.py       |  61 ++++
 mmrazor/models/__init__.py                    |   4 +
 mmrazor/models/algorithms/__init__.py         |   3 +-
 .../algorithms/quantization/__init__.py       |   4 +
 .../models/algorithms/quantization/base.py    | 116 +++++++
 mmrazor/models/fake_quants/__init__.py        |  10 +
 mmrazor/models/fake_quants/adaround.py        |  98 ++++++
 mmrazor/models/fake_quants/base.py            | 124 ++++++++
 mmrazor/models/fake_quants/lsq.py             | 137 ++++++++
 mmrazor/models/fake_quants/qdrop.py           |  44 +++
 mmrazor/models/losses/__init__.py             |   1 +
 mmrazor/models/losses/adaround_loss.py        |  87 +++++
 .../mutable_channel/units/channel_unit.py     |   5 +
 .../channel_mutator/channel_mutator.py        |  18 +-
 .../slimmable_channel_mutator.py              |   2 +-
 .../module_mutator/diff_module_mutator.py     | 117 +++++++
 .../mutators/module_mutator/module_mutator.py |  94 ++++++
 mmrazor/models/observers/__init__.py          |   5 +
 mmrazor/models/observers/base.py              |  74 +++++
 mmrazor/models/observers/lsq_observer.py      |  59 ++++
 mmrazor/models/observers/minmax.py            |  97 ++++++
 mmrazor/models/observers/mse.py               | 156 +++++++++
 mmrazor/models/quantizers/__init__.py         |   5 +
 mmrazor/models/quantizers/base.py             | 194 ++++++++++++
 mmrazor/models/quantizers/trt_quantizer.py    |  23 ++
 .../models/task_modules/tracer/__init__.py    |   4 +-
 .../models/task_modules/tracer/fx/__init__.py |   5 +
 .../task_modules/tracer/fx/custom_tracer.py   | 281 +++++++++++++++++
 mmrazor/models/utils/quantization_util.py     | 217 +++++++++++++
 mmrazor/structures/quantization/__init__.py   |   5 +
 .../quantization/backend_default_qconfigs.py  |  46 +++
 mmrazor/structures/quantization/qscheme.py    |  68 ++++
 mmrazor/testing/__init__.py                   |   1 +
 mmrazor/testing/_fx_models.py                 |  42 +++
 .../test_algorithms/test_general_quant.py     |  34 ++
 .../test_lsq_fake_quants.py                   |  23 ++
 .../test_mutators/test_diff_mutator.py        | 235 ++++++++++++++
 .../test_observers/test_observer.py           |  38 +++
 .../test_quantizers/test_trt_quantizer.py     |  34 ++
 .../test_task_modules/test_custom_tracer.py   |  35 ++
 tools/ckpt_demo.py                            |  13 +
 tools/slurm_test.sh                           |  26 +-
 tools/tracer_demo.py                          |  93 ++++++
 51 files changed, 3116 insertions(+), 67 deletions(-)
 create mode 100644 configs/quantization/ptq/adaround.py
 create mode 100644 configs/quantization/ptq/demo.py
 create mode 100644 configs/quantization/qat/demo.py
 create mode 100644 configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py
 create mode 100644 mmrazor/engine/runner/quantization_loops.py
 create mode 100644 mmrazor/engine/runner/utils/state.py
 create mode 100644 mmrazor/engine/runner/utils/subgraph.py
 create mode 100644 mmrazor/models/algorithms/quantization/__init__.py
 create mode 100644 mmrazor/models/algorithms/quantization/base.py
 create mode 100644 mmrazor/models/fake_quants/__init__.py
 create mode 100644 mmrazor/models/fake_quants/adaround.py
 create mode 100644 mmrazor/models/fake_quants/base.py
 create mode 100644 mmrazor/models/fake_quants/lsq.py
 create mode 100644 mmrazor/models/fake_quants/qdrop.py
 create mode 100644 mmrazor/models/losses/adaround_loss.py
 create mode 100644 mmrazor/models/mutators/module_mutator/diff_module_mutator.py
 create mode 100644 mmrazor/models/mutators/module_mutator/module_mutator.py
 create mode 100644 mmrazor/models/observers/__init__.py
 create mode 100644 mmrazor/models/observers/base.py
 create mode 100644 mmrazor/models/observers/lsq_observer.py
 create mode 100644 mmrazor/models/observers/minmax.py
 create mode 100644 mmrazor/models/observers/mse.py
 create mode 100644 mmrazor/models/quantizers/__init__.py
 create mode 100644 mmrazor/models/quantizers/base.py
 create mode 100644 mmrazor/models/quantizers/trt_quantizer.py
 create mode 100644 mmrazor/models/task_modules/tracer/fx/__init__.py
 create mode 100644 mmrazor/models/task_modules/tracer/fx/custom_tracer.py
 create mode 100644 mmrazor/models/utils/quantization_util.py
 create mode 100644 mmrazor/structures/quantization/__init__.py
 create mode 100644 mmrazor/structures/quantization/backend_default_qconfigs.py
 create mode 100644 mmrazor/structures/quantization/qscheme.py
 create mode 100644 mmrazor/testing/_fx_models.py
 create mode 100644 tests/test_models/test_algorithms/test_general_quant.py
 create mode 100644 tests/test_models/test_fake_quantize/test_lsq_fake_quants.py
 create mode 100644 tests/test_models/test_mutators/test_diff_mutator.py
 create mode 100644 tests/test_models/test_observers/test_observer.py
 create mode 100644 tests/test_models/test_quantizers/test_trt_quantizer.py
 create mode 100644 tests/test_models/test_task_modules/test_custom_tracer.py
 create mode 100644 tools/ckpt_demo.py
 create mode 100644 tools/tracer_demo.py
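The configs added below are consumed by the standard MMEngine entry points. A minimal launch sketch (illustrative only, not part of this patch; the work_dir value is an assumption):

from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/quantization/ptq/adaround.py')
cfg.work_dir = './work_dirs/ptq_adaround'  # assumed output directory
runner = Runner.from_cfg(cfg)
# test_cfg.type is 'mmrazor.PTQLoop', so test() runs:
# calibrate -> (optional) reconstruct -> convert -> ONNX export.
runner.test()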
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index e00ed24c8..53a184a3d 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -31,44 +31,6 @@ jobs:
         python-version: [3.7]
         torch: [1.6.0, 1.7.0, 1.8.0, 1.9.0, 1.10.0, 1.11.0, 1.12.0, 1.13.0]
         include:
-          - torch: 1.6.0
-            torch_version: 1.6
-            torchvision: 0.7.0
-          - torch: 1.7.0
-            torch_version: 1.7
-            torchvision: 0.8.1
-          - torch: 1.7.0
-            torch_version: 1.7
-            torchvision: 0.8.1
-            python-version: 3.8
-          - torch: 1.8.0
-            torch_version: 1.8
-            torchvision: 0.9.0
-          - torch: 1.8.0
-            torch_version: 1.8
-            torchvision: 0.9.0
-            python-version: 3.8
-          - torch: 1.9.0
-            torch_version: 1.9
-            torchvision: 0.10.0
-          - torch: 1.9.0
-            torch_version: 1.9
-            torchvision: 0.10.0
-            python-version: 3.8
-          - torch: 1.10.0
-            torch_version: 1.10
-            torchvision: 0.11.0
-          - torch: 1.10.0
-            torch_version: 1.10
-            torchvision: 0.11.0
-            python-version: 3.8
-          - torch: 1.11.0
-            torch_version: 1.11
-            torchvision: 0.12.0
-          - torch: 1.11.0
-            torch_version: 1.11
-            torchvision: 0.12.0
-            python-version: 3.8
           - torch: 1.12.0
             torch_version: 1.12
             torchvision: 0.13.0
diff --git a/configs/quantization/ptq/adaround.py b/configs/quantization/ptq/adaround.py
new file mode 100644
index 000000000..389575dc6
--- /dev/null
+++ b/configs/quantization/ptq/adaround.py
@@ -0,0 +1,47 @@
+_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']
+
+test_cfg = dict(
+    _delete_=True,
+    type='mmrazor.PTQLoop',
+    dataloader=_base_.test_dataloader,
+    evaluator=_base_.test_evaluator,
+    calibrate_dataloader=_base_.train_dataloader,
+    batch_num=32,
+    # reconstruction_cfg=dict(
+    #     pattern='layer',
+    #     loss=dict(
+    #         type='mmrazor.AdaRoundLoss',
+    #         iters=20000
+    #     )
+    # )
+)
+
+model = dict(
+    _delete_=True,
+    type='mmrazor.GeneralQuant',
+    architecture=_base_.model,
+    quantizer=dict(
+        type='mmrazor.CustomQuantizer',
+        is_qat=False,
+        skipped_methods=[
+            'mmcls.models.heads.ClsHead._get_loss',
+            'mmcls.models.heads.ClsHead._get_predictions'
+        ],
+        qconfig=dict(
+            qtype='affine',
+            w_observer=dict(type='mmrazor.MSEObserver'),
+            a_observer=dict(type='mmrazor.EMAMSEObserver'),
+            w_fake_quant=dict(type='mmrazor.AdaRoundFakeQuantize'),
+            a_fake_quant=dict(type='mmrazor.FakeQuantize'),
+            w_qscheme=dict(
+                bit=2,
+                is_symmetry=False,
+                is_per_channel=True,
+                is_pot_scale=False,
+            ),
+            a_qscheme=dict(
+                bit=4,
+                is_symmetry=False,
+                is_per_channel=False,
+                is_pot_scale=False),
+        )))
diff --git a/configs/quantization/ptq/demo.py b/configs/quantization/ptq/demo.py
new file mode 100644
index 000000000..af6a0a5df
--- /dev/null
+++ b/configs/quantization/ptq/demo.py
@@ -0,0 +1 @@
+_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']
diff --git a/configs/quantization/qat/demo.py b/configs/quantization/qat/demo.py
new file mode 100644
index 000000000..be3ec6013
--- /dev/null
+++ b/configs/quantization/qat/demo.py
@@ -0,0 +1 @@
+_base_ = ['./lsq_resnet50_8xb16_cifar10.py']
diff --git a/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py b/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py
new file mode 100644
index 000000000..a246bc265
--- /dev/null
+++ b/configs/quantization/qat/lsq_resnet50_8xb16_cifar10.py
@@ -0,0 +1,37 @@
+_base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py']
+
+train_cfg = dict(
+    _delete_=True,
+    type='mmrazor.QATEpochBasedLoop',
+    max_epochs=_base_.train_cfg.max_epochs,
+)
+
+model = dict(
+    _delete_=True,
+    _scope_='mmrazor',
+    type='GeneralQuant',
+    architecture={{_base_.model}},
+    quantizer=dict(
+        type='TensorRTQuantizer',
+        skipped_methods=[
+            'mmcls.models.heads.ClsHead._get_loss',
+            'mmcls.models.heads.ClsHead._get_predictions'
+        ],
+        qconfig=dict(
+            qtype='affine',
+            w_observer=dict(type='mmrazor.MinMaxObserver'),
+            a_observer=dict(type='mmrazor.EMAMinMaxObserver'),
+            w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
+            a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
+            w_qscheme=dict(
+                bit=2,
+                is_symmetry=False,
+                is_per_channel=True,
+                is_pot_scale=False,
+            ),
+            a_qscheme=dict(
+                bit=4,
+                is_symmetry=False,
+                is_per_channel=False,
+                is_pot_scale=False),
+        )))
diff --git a/mmrazor/engine/runner/__init__.py b/mmrazor/engine/runner/__init__.py
index 10eb2b598..647d8b410 100644
--- a/mmrazor/engine/runner/__init__.py
+++ b/mmrazor/engine/runner/__init__.py
@@ -4,6 +4,7 @@
 from .distill_val_loop import SelfDistillValLoop, SingleTeacherDistillValLoop
 from .evolution_search_loop import EvolutionSearchLoop
 from .iteprune_val_loop import ItePruneValLoop
+from .quantization_loops import PTQLoop, QATEpochBasedLoop
 from .slimmable_val_loop import SlimmableValLoop
 from .subnet_sampler_loop import GreedySamplerTrainLoop
 from .subnet_val_loop import SubnetValLoop
@@ -12,5 +13,6 @@
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
     'GreedySamplerTrainLoop', 'SubnetValLoop', 'SelfDistillValLoop',
-    'ItePruneValLoop', 'AutoSlimGreedySearchLoop'
+    'ItePruneValLoop', 'AutoSlimGreedySearchLoop', 'PTQLoop',
+    'QATEpochBasedLoop'
 ]
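For orientation, the `bit`/`is_symmetry` fields in the qscheme dicts of the configs above determine the integer range that the observers and fake quantizers below operate on. A hedged sketch of the assumed mapping (the authoritative version lives in mmrazor/structures/quantization/qscheme.py, added by this patch):

def quant_range(bit: int, is_symmetry: bool):
    # Asymmetric: [0, 2^bit - 1]; symmetric: [-2^(bit-1), 2^(bit-1) - 1].
    if is_symmetry:
        return -(2**(bit - 1)), 2**(bit - 1) - 1
    return 0, 2**bit - 1

print(quant_range(4, False))  # (0, 15) for the 4-bit activation scheme
print(quant_range(2, False))  # (0, 3) for the 2-bit weight scheme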
diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py
new file mode 100644
index 000000000..2f15f5deb
--- /dev/null
+++ b/mmrazor/engine/runner/quantization_loops.py
@@ -0,0 +1,298 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from mmengine.evaluator import Evaluator
+from mmengine.registry import MODELS
+from mmengine.runner import EpochBasedTrainLoop, TestLoop
+from torch.utils.data import DataLoader
+
+from mmrazor.models.task_modules import (ModuleInputsRecorder,
+                                         ModuleOutputsRecorder,
+                                         RecorderManager)
+from mmrazor.registry import LOOPS
+from .utils import extract_blocks, extract_layers, extract_subgraph
+
+_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear)
+
+
+@LOOPS.register_module()
+class QATEpochBasedLoop(EpochBasedTrainLoop):
+    """`EpochBasedLoop` for `QuantizationAwareTraining`.
+
+    Args:
+        runner (Runner): A reference of runner.
+        dataloader (Dataloader or dict): An iterator to generate one batch of
+            dataset each iteration.
+        max_epochs (int): Total training epochs.
+        calibrate_dataloader (Dataloader or dict, optional): A dataloader
+            object or a dict to build a dataloader for calibration. Defaults
+            to None.
+        val_begin (int): The epoch that begins validating. Defaults to 1.
+        val_interval (int): Validation interval. Defaults to 1.
+        dynamic_intervals (List[Tuple[int, int]], optional): The first
+            element in the tuple is a milestone and the second element is an
+            interval. The interval is used after the corresponding milestone.
+            Defaults to None.
+    """
+
+    def __init__(
+            self,
+            runner,
+            dataloader: Union[DataLoader, Dict],
+            max_epochs: int,
+            calibrate_dataloader: Union[DataLoader, Dict] = None,
+            val_begin: int = 1,
+            val_interval: int = 1,
+            dynamic_intervals: Optional[List[Tuple[int, int]]] = None) -> None:
+        super().__init__(runner, dataloader, max_epochs, val_begin,
+                         val_interval, dynamic_intervals)
+        if isinstance(calibrate_dataloader, dict):
+            # Determine whether or not different ranks use different seed.
+            diff_rank_seed = runner._randomness_cfg.get(
+                'diff_rank_seed', False)
+            self.calibrate_dataloader = runner.build_dataloader(
+                calibrate_dataloader,
+                seed=runner.seed,
+                diff_rank_seed=diff_rank_seed)
+        else:
+            self.calibrate_dataloader = calibrate_dataloader
+
+        self.is_calibrate = True if calibrate_dataloader is not None else False
+
+        if self.runner.distributed:
+            self.model = runner.model.module
+        else:
+            self.model = runner.model
+
+    def calibrate(self, calibrate_dataloader) -> None:
+        self.model.eval()
+        with torch.no_grad():
+            for batch_data in calibrate_dataloader:
+                self.model(batch_data)
+
+    def run(self) -> None:
+        """Launch training."""
+        self.runner.call_hook('before_train')
+
+        self.model.prepare()
+
+        if self.is_calibrate:
+            self.model.state = (1, 0)
+            self.calibrate(self.calibrate_dataloader)
+
+        self.model.state = (1, 1)
+
+        while self._epoch < self._max_epochs:
+            self.run_epoch()
+
+            self._decide_current_val_interval()
+            if (self.runner.val_loop is not None
+                    and self._epoch >= self.val_begin
+                    and self._epoch % self.val_interval == 0):
+                self.runner.val_loop.run()
+
+        self.model.convert()
+
+        # self.runner.val_loop.run()
+
+        self.runner.call_hook('after_train')
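The `(observer, fake_quant)` tuples assigned to `model.state` above follow the protocol implemented by `GeneralQuant.state` later in this patch: each flag toggles the corresponding sub-modules on (1) or off (0). Schematically (illustrative, assuming `model` is a built `GeneralQuant`):

model.state = (1, 0)  # calibration: observers collect stats, no fake quant
model.state = (1, 1)  # training: stats keep updating and fake quant applies
model.state = (0, 1)  # evaluation: freeze stats, keep fake quant active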
+
+
+@LOOPS.register_module()
+class PTQLoop(TestLoop):
+    """`TestLoop` for Post Training Quantization.
+
+    Args:
+        runner (Runner): A reference of runner.
+        dataloader (Dataloader or dict): An iterator to generate one batch of
+            dataset each iteration.
+        evaluator (Evaluator or dict or list): Used for computing metrics.
+        calibrate_dataloader (Dataloader or dict, optional): A dataloader
+            object or a dict to build a dataloader for calibration. Defaults
+            to None.
+        batch_num (Optional[int], optional): Total calibration batches.
+            Defaults to None.
+        reconstruction_cfg (Optional[Dict], optional): Model reconstruction
+            configuration. Defaults to None.
+        fp16 (bool, optional): Enable FP16 testing mode. Defaults to False.
+    """
+
+    def __init__(self,
+                 runner,
+                 dataloader: Union[DataLoader, Dict],
+                 evaluator: Union[Evaluator, Dict, List],
+                 calibrate_dataloader: Optional[Union[DataLoader,
+                                                      Dict]] = None,
+                 batch_num: Optional[int] = None,
+                 reconstruction_cfg: Optional[Dict] = None,
+                 fp16: bool = False):
+        super().__init__(runner, dataloader, evaluator, fp16)
+        if isinstance(calibrate_dataloader, dict):
+            # Determine whether or not different ranks use different seed.
+            diff_rank_seed = runner._randomness_cfg.get(
+                'diff_rank_seed', False)
+            self.calibrate_dataloader = runner.build_dataloader(
+                calibrate_dataloader,
+                seed=runner.seed,
+                diff_rank_seed=diff_rank_seed)
+        else:
+            self.calibrate_dataloader = calibrate_dataloader
+
+        self.is_calibrate = True if calibrate_dataloader is not None else False
+
+        if self.runner.distributed:
+            self.model = runner.model.module
+        else:
+            self.model = runner.model
+
+        self.batch_num = batch_num
+        self.config = reconstruction_cfg
+
+    def calibrate(self, calibrate_dataloader) -> None:
+        self.model.eval()
+        with torch.no_grad():
+            for i, batch_data in enumerate(calibrate_dataloader):
+                if self.batch_num and i >= self.batch_num:
+                    break
+                self.model.calib_step(batch_data)
+
+    def _save_inter_result(self,
+                           model,
+                           dataloader,
+                           slices,
+                           store_input=True,
+                           store_output=True):
+        recorders = {}
+        for s in slices:
+            node_l, node_r = s[:2]
+            if store_input:
+                recorders[node_l.target + '_input'] = ModuleInputsRecorder(
+                    node_l.target)
+            if store_output:
+                recorders[node_r.target + '_output'] = ModuleOutputsRecorder(
+                    node_r.target)
+        manager = RecorderManager(recorders)
+        manager.initialize(model)
+
+        with torch.no_grad():
+            with manager:
+                for i, batch_data in enumerate(dataloader):
+                    if self.batch_num and i >= self.batch_num:
+                        break
+                    batch_data = self.model.data_preprocessor(
+                        batch_data, False)
+                    model(**batch_data)
+        return manager
+
+    def sub_reconstruction(self, graphmodule, input_recorder, output_recorder,
+                           config):
+        w_para = []
+        for layer in graphmodule.modules():
+            if isinstance(layer, _ADAROUND_SUPPORT_TYPE):
+                weight_fake_quant = layer.weight_fake_quant
+                weight_fake_quant.init(layer.weight.data)
+                w_para += [weight_fake_quant.alpha]
+
+        w_opt = torch.optim.Adam(w_para)
+        loss_func = MODELS.build(config.loss)
+
+        for _ in range(config.loss.iters):
+            w_opt.zero_grad()
+
+            data_size = len(input_recorder.data_buffer)
+            data_index = np.random.randint(0, data_size)
+            out_quant = graphmodule(
+                input_recorder.get_recorder_data(data_index))
+            out_fp = output_recorder.get_recorder_data(data_index)
+            err = loss_func(graphmodule, out_quant, out_fp)
+            err.backward()
+            w_opt.step()
+
+        for layer in graphmodule.modules():
+            if isinstance(layer, _ADAROUND_SUPPORT_TYPE):
+                weight_fake_quant = layer.weight_fake_quant
+                layer.weight.data = weight_fake_quant.get_hard_value(
+                    layer.weight.data)
+                weight_fake_quant.adaround = False
+            if isinstance(layer, torch.quantization.FakeQuantize) and hasattr(
+                    layer, 'prob'):
+                # Recover so that dropping activation quantization only
+                # occurs during the reconstruction phase.
+                layer.prob = 1.0
+
+    def reconstruction(self, graphmodule, calibrate_dataloader, config):
+        assert isinstance(graphmodule, torch.fx.GraphModule)
+        graphmodule_fp = graphmodule
+        graphmodule_quant = copy.deepcopy(graphmodule)
+
+        # Get the layers/blocks that need to be reconstructed.
+        slices = []
+        if config.pattern == 'layer':
+            slices = extract_layers(
+                graphmodule, layer_types=_ADAROUND_SUPPORT_TYPE)
+        elif config.pattern == 'block':
+            slices = extract_blocks(graphmodule)
+        else:
+            # TODO: add a reminder message
+            raise NotImplementedError
+
+        # Save fp inputs and outputs of each layer.
+        manager_fp = self._save_inter_result(graphmodule_fp,
+                                             self.calibrate_dataloader,
+                                             slices)
+
+        # Extract subgraph_module.
+        for s in slices:
+            sub_graphmodule = extract_subgraph(graphmodule_quant, s)
+            manager_quant = self._save_inter_result(
+                graphmodule_quant,
+                self.calibrate_dataloader, [s],
+                store_output=False)
+            input_index = s[0].target + '_input'
+            output_index = s[1].target + '_output'
+            input_recorder = manager_quant.get_recorder(input_index)
+            output_recorder = manager_fp.get_recorder(output_index)
+            self.sub_reconstruction(sub_graphmodule, input_recorder,
+                                    output_recorder, config)
+
+        return graphmodule_quant
+
+    def run(self) -> None:
+        """Launch test."""
+        self.runner.call_hook('before_test')
+        self.runner.call_hook('before_test_epoch')
+
+        self.model.eval()
+        self.model.prepare()
+
+        if self.is_calibrate:
+            self.model.state = (1, 0)
+            self.calibrate(self.calibrate_dataloader)
+
+        self.model.state = (1, 1)
+
+        if self.config is not None:
+            self.model.architecture = self.reconstruction(
+                self.model.architecture, self.calibrate_dataloader,
+                self.config)
+
+        self.model.convert()
+
+        self.model.eval()
+        from torch.onnx import OperatorExportTypes
+        dummy_input = torch.randn([1, 3, 224, 224])
+        onnx_path = os.path.join(self.runner.work_dir, 'quantized.onnx')
+        torch.onnx.export(
+            self.model.architecture,
+            dummy_input,
+            onnx_path,
+            opset_version=11,
+            operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
+
+        self.runner.call_hook('after_test')
diff --git a/mmrazor/engine/runner/utils/state.py b/mmrazor/engine/runner/utils/state.py
new file mode 100644
index 000000000..2f6d602a5
--- /dev/null
+++ b/mmrazor/engine/runner/utils/state.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.logging import print_log
+from torch.ao.quantization import FakeQuantize
+
+
+# TODO: may be removed
+def set_quant_state(model, enable_observer=True, enable_fake_quant=True):
+    for name, submodule in model.named_modules():
+        if isinstance(submodule, FakeQuantize):
+            if enable_observer:
+                submodule.enable_observer()
+            else:
+                submodule.disable_observer()
+            if enable_fake_quant:
+                submodule.enable_fake_quant()
+            else:
+                submodule.disable_fake_quant()
+    print_log(f'Enable observer: {enable_observer}; \
+        Enable fake quant: {enable_fake_quant}')
diff --git a/mmrazor/engine/runner/utils/subgraph.py b/mmrazor/engine/runner/utils/subgraph.py
new file mode 100644
index 000000000..ea0f8837f
--- /dev/null
+++ b/mmrazor/engine/runner/utils/subgraph.py
@@ -0,0 +1,61 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch.fx as fx
+
+
+def extract_subgraph(graphmodule, block_slice):
+    subgraph = copy.deepcopy(graphmodule.graph)
+    block_start, block_end = block_slice[:2]
+    for node in subgraph.nodes:
+        if node.name == 'inputs':
+            input_node = node
+        if node.name == block_start.name:
+            node.replace_input_with(node.prev, input_node)
+        if node.name == block_end.name:
+            output_node = node
+        if node.op == 'output':
+            node.replace_input_with(node.prev, output_node)
+    subgraph.lint()
+    subgraph_module = fx.GraphModule(graphmodule, subgraph)
+    subgraph_module.graph.eliminate_dead_code()
+    subgraph_module.recompile()
+    return subgraph_module
+
+
+def extract_blocks(graph, key_word='layer'):
+    block_slices = []
+    block_slice = []
+    pre_stage_index, pre_block_index = 0, 0
+    cur_stage_index, cur_block_index = 0, 0
+    for node in graph.nodes:
+        if key_word not in node.name:
+            continue
+        else:
+            items = node.name.split('_')
+            for i, item in enumerate(items):
+                if key_word in item:
+                    cur_stage_index = int(item[5:])
+                    cur_block_index = int(items[i + 1])
+                    break
+
+            if (cur_block_index != pre_block_index) or (cur_stage_index !=
+                                                        pre_stage_index):
+                block_slice.append(node.prev)
+                if len(block_slice) == 2:
+                    block_slices.append(block_slice)
+                block_slice = []
+                block_slice.append(node)
+
+            pre_stage_index, pre_block_index = cur_stage_index, cur_block_index
+
+    return block_slices
+
+
+def extract_layers(graphmodule, layer_types):
+    layer_slices = []
+    for node in graphmodule.graph.nodes:
+        if node.op == 'call_module':
+            m = graphmodule.get_submodule(node.target)
+            if isinstance(m, layer_types):
+                layer_slices.append((node, node))
+    return layer_slices
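The subgraph utilities above feed `PTQLoop.reconstruction`. A sketch of how `extract_layers` is typically driven (assumes `gm` is an FX-traced `GraphModule`; the variable names are illustrative):

import torch

slices = extract_layers(gm, layer_types=(torch.nn.Conv2d, torch.nn.Linear))
for node_l, node_r in slices:
    # For pattern='layer' each slice starts and ends at the same node, so
    # reconstruction optimizes one Conv/Linear at a time.
    assert node_l is node_r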
diff --git a/mmrazor/models/__init__.py b/mmrazor/models/__init__.py
index f5295aa9e..e5b9ec451 100644
--- a/mmrazor/models/__init__.py
+++ b/mmrazor/models/__init__.py
@@ -2,7 +2,11 @@
 from .algorithms import *  # noqa: F401,F403
 from .architectures import *  # noqa: F401,F403
 from .distillers import *  # noqa: F401,F403
+from .fake_quants import *  # noqa: F401,F403
 from .losses import *  # noqa: F401,F403
 from .mutables import *  # noqa: F401,F403
 from .mutators import *  # noqa: F401,F403
+from .observers import *  # noqa: F401,F403
+from .quantizers import *  # noqa: F401,F403
 from .task_modules import *  # noqa: F401,F403
+from .utils import *  # noqa: F401,F403
diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py
index 3cef96dfe..89a11b899 100644
--- a/mmrazor/models/algorithms/__init__.py
+++ b/mmrazor/models/algorithms/__init__.py
@@ -7,6 +7,7 @@
                    BigNAS, BigNASDDP, Darts, DartsDDP)
 from .pruning import DCFF, DMCP, DMCPDDP, SlimmableNetwork, SlimmableNetworkDDP
 from .pruning.ite_prune_algorithm import ItePruneAlgorithm
+from .quantization import GeneralQuant
 
 __all__ = [
     'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS',
@@ -14,5 +15,5 @@
     'Darts', 'DartsDDP', 'DCFF', 'SelfDistill', 'DataFreeDistillation',
     'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation',
     'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'Autoformer', 'BigNAS',
-    'BigNASDDP', 'DMCP', 'DMCPDDP'
+    'BigNASDDP', 'DMCP', 'DMCPDDP', 'GeneralQuant'
 ]
diff --git a/mmrazor/models/algorithms/quantization/__init__.py b/mmrazor/models/algorithms/quantization/__init__.py
new file mode 100644
index 000000000..84c25bbc0
--- /dev/null
+++ b/mmrazor/models/algorithms/quantization/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import GeneralQuant
+
+__all__ = ['GeneralQuant']
diff --git a/mmrazor/models/algorithms/quantization/base.py b/mmrazor/models/algorithms/quantization/base.py
new file mode 100644
index 000000000..718b08725
--- /dev/null
+++ b/mmrazor/models/algorithms/quantization/base.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from mmengine.structures import BaseDataElement
+from torch.fx import GraphModule
+
+from mmrazor.registry import MODELS
+from ..base import BaseAlgorithm
+
+LossResults = Dict[str, torch.Tensor]
+TensorResults = Union[Tuple[torch.Tensor], torch.Tensor]
+PredictResults = List[BaseDataElement]
+ForwardResults = Union[LossResults, TensorResults, PredictResults]
+
+
+@MODELS.register_module()
+class GeneralQuant(BaseAlgorithm):
+    """General quantization.
+
+    Args:
+        architecture (dict | :obj:`BaseModel`): The config of
+            :class:`BaseModel` or built model.
+        quantizer (dict | :obj:`BaseModel`): The config of
+            :class:`BaseQuantizer` or built model.
+        data_preprocessor (dict | torch.nn.Module | None): The pre-process
+            config of :class:`BaseDataPreprocessor`. Defaults to None.
+        init_cfg (dict): The weight initialization config for
+            :class:`BaseModule`.
+    """
+
+    def __init__(self,
+                 architecture,
+                 quantizer,
+                 data_preprocessor=None,
+                 init_cfg=None):
+        if data_preprocessor is None:
+            data_preprocessor = {}
+        # The build process is in MMEngine, so we need to add scope here.
+        data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
+
+        super().__init__(architecture, data_preprocessor, init_cfg)
+        self.quantizer = MODELS.build(quantizer)
+        self.observers_enabled = True
+        self.fake_quants_enabled = True
+        self.gen_graphs(self.architecture)
+
+    def gen_graphs(self, model):
+        self.quantizer._swap_ff_with_fxff(model)
+        tracer = self.quantizer.tracer
+        for mode in ['tensor', 'loss', 'predict']:
+            concrete_args = {'mode': mode}
+            if mode == 'tensor':
+                self.graph_tensor = GraphModule(
+                    model, tracer.trace(model, concrete_args=concrete_args))
+            if mode == 'loss':
+                self.graph_loss = GraphModule(
+                    model, tracer.trace(model, concrete_args=concrete_args))
+            if mode == 'predict':
+                self.graph_predict = GraphModule(
+                    model, tracer.trace(model, concrete_args=concrete_args))
+
+    def forward(self,
+                inputs: torch.Tensor,
+                data_samples: Optional[List[BaseDataElement]] = None,
+                mode: str = 'tensor') -> ForwardResults:
+
+        if mode == 'loss':
+            return self.graph_loss(inputs, data_samples, mode)
+        elif mode == 'tensor':
+            return self.graph_tensor(inputs, data_samples, mode)
+        elif mode == 'predict':
+            return self.graph_predict(inputs, data_samples, mode)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}". '
+                               'Only supports loss, predict and tensor mode')
+
+    def calib_step(self, data):
+        data = self.data_preprocessor(data, False)
+        return self._run_forward(data, mode='tensor')
+
+    def prepare(self, mode='tensor'):
+        assert mode in ['tensor', 'loss', 'predict']
+        if mode == 'tensor':
+            graph = self.graph_tensor
+        elif mode == 'loss':
+            graph = self.graph_loss
+        else:
+            graph = self.graph_predict
+        self.architecture = self.quantizer.prepare(self.architecture, graph)
+
+    def convert(self):
+        self.architecture = self.quantizer.convert(self.architecture)
+
+    @property
+    def state(self):
+        return (self.observers_enabled, self.fake_quants_enabled)
+
+    @state.setter
+    def state(self, state):
+        observers_enabled, fake_quants_enabled = state
+        for name, submodule in self.architecture.named_modules():
+            if isinstance(submodule, torch.quantization.FakeQuantize):
+                if observers_enabled:
+                    submodule.enable_observer()
+                else:
+                    submodule.disable_observer()
+
+                if fake_quants_enabled:
+                    submodule.enable_fake_quant()
+                else:
+                    submodule.disable_fake_quant()
+
+        self.observers_enabled = observers_enabled
+        self.fake_quants_enabled = fake_quants_enabled
diff --git a/mmrazor/models/fake_quants/__init__.py b/mmrazor/models/fake_quants/__init__.py
new file mode 100644
index 000000000..cea7708a2
--- /dev/null
+++ b/mmrazor/models/fake_quants/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .adaround import AdaRoundFakeQuantize
+from .base import FakeQuantize
+from .lsq import LearnableFakeQuantize
+from .qdrop import QDropFakeQuantize
+
+__all__ = [
+    'FakeQuantize', 'AdaRoundFakeQuantize', 'QDropFakeQuantize',
+    'LearnableFakeQuantize'
+]
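Before the AdaRound fake quantizer below, a numeric sanity check of its rounding parameterization may help. With gamma = -0.1 and zeta = 1.1 (the values set in `init`), `init_alpha` solves sigmoid(alpha) = rest, so the rectified sigmoid initially reproduces the fractional remainder of w / scale:

import torch

gamma, zeta = -0.1, 1.1
rest = torch.tensor([0.3])  # fractional part of w / scale
alpha = -torch.log((zeta - gamma) / (rest - gamma) - 1)
h = ((zeta - gamma) * torch.sigmoid(alpha) + gamma).clamp(0, 1)
print(h)  # tensor([0.3000]): the soft rounding mask starts at `rest`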
diff --git a/mmrazor/models/fake_quants/adaround.py b/mmrazor/models/fake_quants/adaround.py
new file mode 100644
index 000000000..9388f1aa4
--- /dev/null
+++ b/mmrazor/models/fake_quants/adaround.py
@@ -0,0 +1,98 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parameter import Parameter
+
+from mmrazor.registry import MODELS
+from .base import FakeQuantize
+
+_version_under_1100 = int(torch.__version__.split('.')[1]) < 10
+
+
+@MODELS.register_module()
+class AdaRoundFakeQuantize(FakeQuantize):
+
+    def __init__(self, observer, **observer_kwargs):
+        super().__init__(observer, **observer_kwargs)
+        self.adaround = False
+
+    def init(self, weight_tensor: torch.Tensor):
+        self.adaround = True
+        self.observer_enabled[0] = 0
+        self.fake_quant_enabled[0] = 1
+
+        self.gamma = -0.1
+        self.zeta = 1.1
+        self.init_alpha(x=weight_tensor.data.clone())
+
+    def init_alpha(self, x: torch.Tensor):
+        if self.ch_axis != -1:
+            new_shape = [1] * len(x.shape)
+            new_shape[self.ch_axis] = x.shape[self.ch_axis]
+            scale = self.scale.data.reshape(new_shape)
+        else:
+            scale = self.scale.data
+        x_floor = torch.floor(x / scale)
+        rest = (x / scale) - x_floor  # rest of rounding [0, 1)
+        alpha = -torch.log((self.zeta - self.gamma) /
+                           (rest - self.gamma) - 1)  # => sigmoid(alpha) = rest
+        self.alpha = Parameter(alpha)
+
+    def rectified_sigmoid(self):
+        """Generate the rounding mask.
+
+        Returns:
+            torch.Tensor: Soft rounding mask in ``[0, 1]``.
+        """
+        return ((self.zeta - self.gamma) * torch.sigmoid(self.alpha) +
+                self.gamma).clamp(0, 1)
+
+    def adaround_forward(self, x, hard_value=False):
+        if self.ch_axis != -1:
+            new_shape = [1] * len(x.shape)
+            new_shape[self.ch_axis] = x.shape[self.ch_axis]
+            scale = self.scale.reshape(new_shape)
+            zero_point = self.zero_point.reshape(new_shape)
+        x = torch.floor(x / scale)
+        if hard_value:
+            x += (self.alpha >= 0).float()
+        else:
+            x += self.rectified_sigmoid()
+        x += zero_point
+        x = torch.clamp(x, self.quant_min, self.quant_max)
+        x = (x - zero_point) * scale
+        return x
+
+    def forward(self, X):
+        if self.observer_enabled[0] == 1:
+            self.activation_post_process(X.detach())
+            _scale, _zero_point = self.calculate_qparams()
+            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
+                self.zero_point.device)
+            if self.scale.shape != _scale.shape:
+                self.scale.resize_(_scale.shape)
+                self.zero_point.resize_(_zero_point.shape)
+            self.scale.copy_(_scale)
+            self.zero_point.copy_(_zero_point)
+
+        if self.fake_quant_enabled[0] == 1:
+            if not self.adaround:
+                if self.is_per_channel:
+                    X = torch.fake_quantize_per_channel_affine(
+                        X, self.scale,
+                        self.zero_point.long()
+                        if _version_under_1100 else self.zero_point,
+                        self.ch_axis, self.quant_min, self.quant_max)
+                else:
+                    X = torch.fake_quantize_per_tensor_affine(
+                        X, self.scale.item(), int(self.zero_point.item()),
+                        self.quant_min, self.quant_max)
+            else:
+                if not hasattr(self, 'alpha'):
+                    raise NotImplementedError
+                X = self.adaround_forward(X)
+        return X
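The fake quantizers in this directory ultimately defer to the same PyTorch kernels. As a quick reminder of what those compute, q = clamp(round(x / scale) + zero_point, quant_min, quant_max), followed by dequantization:

import torch

x = torch.tensor([0.49])
x_hat = torch.fake_quantize_per_tensor_affine(x, 0.1, 0, -128, 127)
print(x_hat)  # tensor([0.5000]): round(4.9) = 5, then 5 * 0.1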
diff --git a/mmrazor/models/fake_quants/base.py b/mmrazor/models/fake_quants/base.py
new file mode 100644
index 000000000..13f8a1e43
--- /dev/null
+++ b/mmrazor/models/fake_quants/base.py
@@ -0,0 +1,124 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.ao.quantization import FakeQuantizeBase
+
+from mmrazor.models.utils import (_is_float_qparams, _is_per_channel,
+                                  _is_per_tensor, _is_symmetric_quant)
+from mmrazor.registry import MODELS
+
+
+@MODELS.register_module()
+class FakeQuantize(FakeQuantizeBase):
+
+    scale: torch.Tensor
+    zero_point: torch.Tensor
+
+    def __init__(self, observer, **observer_kwargs):
+        super().__init__()
+        self.activation_post_process = observer(**observer_kwargs)
+        self.quant_min = self.activation_post_process.quant_min
+        self.quant_max = self.activation_post_process.quant_max
+        if _is_float_qparams(self.activation_post_process.qscheme):
+            zero_point_dtype = torch.float
+        else:
+            zero_point_dtype = torch.int
+        self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
+        self.register_buffer('zero_point',
+                             torch.tensor([0], dtype=zero_point_dtype))
+        self.dtype = self.activation_post_process.dtype
+        self.qscheme = self.activation_post_process.qscheme
+        self.ch_axis = self.activation_post_process.ch_axis \
+            if hasattr(self.activation_post_process, 'ch_axis') else -1
+        assert _is_per_channel(self.qscheme) or \
+            _is_per_tensor(self.qscheme), \
+            'Only per channel and per tensor quantization are supported in ' \
+            'fake quantize' + ' got qscheme: ' + str(self.qscheme)
+        self.is_per_channel = _is_per_channel(self.qscheme)
+
+        bitrange = torch.tensor(self.quant_max - self.quant_min + 1).double()
+        self.bitwidth = int(torch.log2(bitrange).item())
+        self.is_pot_scale = self.activation_post_process.is_pot_scale
+        self.is_symmetric_quant = _is_symmetric_quant(self.qscheme)
+
+    @torch.jit.export
+    def calculate_qparams(self):
+        return self.activation_post_process.calculate_qparams()
+
+    def forward(self, X):
+        if self.observer_enabled[0] == 1:
+            self.activation_post_process(X.detach())
+            _scale, _zero_point = self.calculate_qparams()
+            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
+                self.zero_point.device)
+            if self.scale.shape != _scale.shape:
+                self.scale.resize_(_scale.shape)
+                self.zero_point.resize_(_zero_point.shape)
+            self.scale.copy_(_scale)
+            self.zero_point.copy_(_zero_point)
+
+        if self.fake_quant_enabled[0] == 1:
+            if self.is_per_channel:
+                X = torch.fake_quantize_per_channel_affine(
+                    X, self.scale, self.zero_point, self.ch_axis,
+                    self.activation_post_process.quant_min,
+                    self.activation_post_process.quant_max)
+            else:
+                X = torch.fake_quantize_per_tensor_affine(
+                    X, self.scale, self.zero_point,
+                    self.activation_post_process.quant_min,
+                    self.activation_post_process.quant_max)
+        return X
+
+    @torch.jit.export
+    def extra_repr(self):
+        return 'fake_quant_enabled={}, observer_enabled={}, ' \
+               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ' \
+               'ch_axis={}, scale={}, zero_point={}'.format(
+                   self.fake_quant_enabled, self.observer_enabled,
+                   self.activation_post_process.quant_min,
+                   self.activation_post_process.quant_max, self.dtype,
+                   self.qscheme, self.ch_axis, self.scale, self.zero_point)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        # We cannot currently register scalar values as buffers, so we need
+        # to manually specify serialization here.
+        super(FakeQuantize, self)._save_to_state_dict(destination, prefix,
+                                                      keep_vars)
+        destination[prefix + 'scale'] = self.scale
+        destination[prefix + 'zero_point'] = self.zero_point
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        # Removing this function throws an error that the size of the
+        # loaded tensor does not match the original size, i.e., these buffers
+        # start out with numel 0 and become numel 1 once they have their
+        # first forward pass.
+        local_state = ['scale', 'zero_point']
+        for name in local_state:
+            key = prefix + name
+            if key in state_dict:
+                val = state_dict[key]
+                # Custom handling to allow loading scale and zero_point
+                # of size N into uninitialized buffers of size 0. The
+                # buffers are resized here, and the values are copied in
+                # the default state_dict loading code of the parent.
+                if name == 'scale':
+                    self.scale.resize_(val.shape)
+                else:
+                    assert name == 'zero_point'
+                    self.zero_point.resize_(val.shape)
+                # For torchscript module we need to update the attributes
+                # here since we do not call the `_load_from_state_dict`
+                # function defined in module.py
+                if torch.jit.is_scripting():
+                    if name == 'scale':
+                        self.scale.copy_(val)
+                    else:
+                        assert name == 'zero_point'
+                        self.zero_point.copy_(val)
+            elif strict:
+                missing_keys.append(key)
+        super(FakeQuantize,
+              self)._load_from_state_dict(state_dict, prefix, local_metadata,
+                                          strict, missing_keys,
+                                          unexpected_keys, error_msgs)
diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py
new file mode 100644
index 000000000..10970a6a3
--- /dev/null
+++ b/mmrazor/models/fake_quants/lsq.py
@@ -0,0 +1,137 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parameter import Parameter
+
+from mmrazor.registry import MODELS
+from ..utils import PerChannelLoadHook, _is_symmetric_quant, is_tracing_state
+from .base import FakeQuantize
+
+
+@MODELS.register_module()
+class LearnableFakeQuantize(FakeQuantize):
+    r"""This is an extension of the FakeQuantize module in fake_quantize.py,
+    which supports more generalized lower-bit quantization and supports
+    learning of the scale and zero point parameters through backpropagation.
+    For literature references, please see the class
+    `_LearnableFakeQuantizePerTensorOp`. In addition to the attributes in the
+    original FakeQuantize module, the `_LearnableFakeQuantize` module also
+    includes the following attributes to support quantization parameter
+    learning.
+    """
+
+    def __init__(self,
+                 observer,
+                 scale=1.,
+                 zero_point=0.,
+                 use_grad_scaling=True,
+                 **observer_kwargs):
+        super(LearnableFakeQuantize, self).__init__(observer,
+                                                    **observer_kwargs)
+        self.use_grad_scaling = use_grad_scaling
+        self.scale = Parameter(torch.tensor([scale]))
+        self.zero_point = Parameter(torch.tensor([zero_point]))
+        self.register_buffer('eps',
+                             torch.tensor([torch.finfo(torch.float32).eps]))
+        # Check whether the module will load a state dict;
+        # initialize the shape of the per-channel 'scale' and
+        # 'zero-point' before copying values.
+        self.load_state_dict_hook = PerChannelLoadHook(self)
+
+    @torch.jit.export
+    def extra_repr(self):
+        return 'fake_quant_enabled={}, observer_enabled={}, ' \
+               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ' \
+               'ch_axis={}, scale={}, zero_point={}'.format(
+                   self.fake_quant_enabled, self.observer_enabled,
+                   self.quant_min, self.quant_max,
+                   self.dtype, self.qscheme, self.ch_axis,
+                   self.scale if self.ch_axis == -1 else 'List[%s]' % str(self.scale.shape),  # noqa: E501
+                   self.zero_point if self.ch_axis == -1 else 'List')
+
+    def forward(self, X):
+        # A learnable fake quantize has to call zero_point.float()
+        # to make zero_point learnable.
+        if self.observer_enabled[0] == 1:
+            self.activation_post_process(X.detach())
+            _scale, _zero_point = \
+                self.activation_post_process.calculate_qparams()
+            _scale = _scale.to(self.scale.device)
+            _zero_point = _zero_point.to(self.zero_point.device)
+
+            if self.ch_axis != -1:
+                self.scale.data = torch.ones_like(_scale)
+                self.zero_point.data = torch.zeros_like(_zero_point.float())
+
+            self.scale.data.copy_(_scale)
+            self.zero_point.data.copy_(_zero_point.float())
+        else:
+            self.scale.data.abs_()
+            self.scale.data.clamp_(min=self.eps.item())
+
+        if self.fake_quant_enabled[0] == 1:
+            if _is_symmetric_quant(self.qscheme):
+                self.zero_point.data.zero_()
+            else:
+                self.zero_point.data.clamp_(self.quant_min,
+                                            self.quant_max).float()
+
+            if self.is_per_channel:
+                if self.use_grad_scaling:
+                    grad_factor = 1.0 / (X.numel() / X.shape[self.ch_axis] *
+                                         self.quant_max)**0.5
+                else:
+                    grad_factor = 1.0
+                if is_tracing_state():
+                    X = FakeQuantizeLearnablePerchannelAffine.apply(
+                        X, self.scale, self.zero_point, self.ch_axis,
+                        self.quant_min, self.quant_max, grad_factor)
+                else:
+                    X = _fake_quantize_learnable_per_channel_affine_training(
+                        X, self.scale, self.zero_point, self.ch_axis,
+                        self.quant_min, self.quant_max, grad_factor)
+            else:
+                if self.use_grad_scaling:
+                    grad_factor = 1.0 / (X.numel() * self.quant_max)**0.5
+                else:
+                    grad_factor = 1.0
+                X = torch._fake_quantize_learnable_per_tensor_affine(
+                    X, self.scale, self.zero_point, self.quant_min,
+                    self.quant_max, grad_factor)
+        return X
+
+
+def _fake_quantize_learnable_per_channel_affine_training(
+        x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor):
+    zero_point = (zero_point.round() - zero_point).detach() + zero_point
+    new_shape = [1] * len(x.shape)
+    new_shape[ch_axis] = x.shape[ch_axis]
+    scale = grad_scale(scale, grad_factor).reshape(new_shape)
+    zero_point = grad_scale(zero_point, grad_factor).reshape(new_shape)
+    x = x / scale + zero_point
+    x = (x.round() - x).detach() + x
+    x = torch.clamp(x, quant_min, quant_max)
+    return (x - zero_point) * scale
+
+
+def grad_scale(t, scale):
+    return (t - (t * scale)).detach() + (t * scale)
+
+
+class FakeQuantizeLearnablePerchannelAffine(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, scale, zero_point, ch_axis, quant_min, quant_max,
+                grad_factor):
+        return _fake_quantize_learnable_per_channel_affine_training(
+            x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor)
+
+    @staticmethod
+    def symbolic(g, x, scale, zero_point, ch_axis, quant_min, quant_max,
+                 grad_factor):
+        return g.op(
+            '::FakeQuantizeLearnablePerchannelAffine',
+            x,
+            scale,
+            zero_point,
+            quant_min_i=quant_min,
+            quant_max_i=quant_max)
diff --git a/mmrazor/models/fake_quants/qdrop.py b/mmrazor/models/fake_quants/qdrop.py
new file mode 100644
index 000000000..e2e13bfc0
--- /dev/null
+++ b/mmrazor/models/fake_quants/qdrop.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from torch.nn.parameter import Parameter
+
+from mmrazor.registry import MODELS
+from .base import FakeQuantize
+
+
+@MODELS.register_module()
+class QDropFakeQuantize(FakeQuantize):
+
+    def __init__(self, observer, **observer_kwargs):
+        super().__init__(observer, **observer_kwargs)
+        self.scale = Parameter(torch.tensor([1.0], dtype=torch.float))
+        self.prob = 1.0
+
+    def forward(self, X):
+        if self.observer_enabled[0] == 1:
+            self.activation_post_process(X.detach())
+            _scale, _zero_point = self.calculate_qparams()
+            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
+                self.zero_point.device)
+            if self.scale.shape != _scale.shape:
+                self.scale.resize_(_scale.shape)
+                self.zero_point.resize_(_zero_point.shape)
+            self.scale.copy_(_scale)
+            self.zero_point.copy_(_zero_point)
+
+        if self.fake_quant_enabled[0] == 1:
+            x_orig = X
+            if self.is_per_channel:
+                X = torch.fake_quantize_per_channel_affine(
+                    X, self.scale, self.zero_point, self.ch_axis,
+                    self.activation_post_process.quant_min,
+                    self.activation_post_process.quant_max)
+            else:
+                X = torch.fake_quantize_per_tensor_affine(
+                    X, self.scale, self.zero_point,
+                    self.activation_post_process.quant_min,
+                    self.activation_post_process.quant_max)
+            if self.prob < 1.0:
+                x_prob = torch.where(torch.rand_like(X) < self.prob, X, x_orig)
+                return x_prob
+        return X
diff --git a/mmrazor/models/losses/__init__.py b/mmrazor/models/losses/__init__.py
index 65e2108fd..3509acd5c 100644
--- a/mmrazor/models/losses/__init__.py
+++ b/mmrazor/models/losses/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .ab_loss import ABLoss
+from .adaround_loss import AdaRoundLoss
 from .at_loss import ATLoss
 from .crd_loss import CRDLoss
 from .cross_entropy_loss import CrossEntropyLoss
diff --git a/mmrazor/models/losses/adaround_loss.py b/mmrazor/models/losses/adaround_loss.py
new file mode 100644
index 000000000..76c97977d
--- /dev/null
+++ b/mmrazor/models/losses/adaround_loss.py
@@ -0,0 +1,87 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+from mmengine.logging import print_log
+
+from mmrazor.registry import MODELS
+
+_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear)
+
+
+@MODELS.register_module()
+class AdaRoundLoss(nn.Module):
+    r'''Loss function that combines an MSE reconstruction loss with a
+    rounding relaxation loss, and uses a temperature decay to balance the
+    two terms.
+    '''
+
+    def __init__(self,
+                 weight: float = 1.,
+                 iters: int = 10000,
+                 beta_range: tuple = (20, 2),
+                 warm_up: float = 0.0,
+                 p: float = 2.):
+        super().__init__()
+        self.weight = weight
+        self.loss_start = iters * warm_up
+        self.p = p
+
+        self.temp_decay = LinearTempDecay(
+            iters,
+            warm_up=warm_up,
+            start_beta=beta_range[0],
+            end_beta=beta_range[1])
+        self.count = 0
+
+    def forward(self, subgraph, pred, tgt):
+        """Compute the total loss for adaptive rounding: rec_loss is the
+        quadratic output reconstruction loss, round_loss is a regularization
+        term to optimize the rounding policy.
+
+        :param pred: output from quantized model
+        :param tgt: output from FP model
+        :return: total loss function
+        """
+
+        def lp_loss(pred, tgt, p=2.0):
+            """loss function measured in L_p Norm."""
+            return (pred - tgt).abs().pow(p).sum(1).mean()
+
+        self.count += 1
+        rec_loss = lp_loss(pred, tgt, p=self.p)
+
+        beta = self.temp_decay(self.count)
+        if self.count < self.loss_start:
+            round_loss = 0
+        else:
+            round_loss = 0
+            for layer in subgraph.modules():
+                if isinstance(layer, _ADAROUND_SUPPORT_TYPE):
+                    round_vals = layer.weight_fake_quant.rectified_sigmoid()
+                    round_loss += self.weight * (1 - (
+                        (round_vals - .5).abs() * 2).pow(beta)).sum()
+
+        total_loss = rec_loss + round_loss
+        if self.count % 500 == 0:
+            print_log('Total loss:\t{:.3f} (rec_loss:{:.3f}, '
+                      'round_loss:{:.3f})\tbeta={:.2f}\tcount={}'.format(
+                          float(total_loss), float(rec_loss),
+                          float(round_loss), beta, self.count))
+        return total_loss
+
+
+class LinearTempDecay:
+
+    def __init__(self, t_max=10000, warm_up=0.2, start_beta=20, end_beta=2):
+        self.t_max = t_max
+        self.start_decay = warm_up * t_max
+        self.start_beta = start_beta
+        self.end_beta = end_beta
+
+    def __call__(self, t):
+        if t < self.start_decay:
+            return self.start_beta
+        elif t > self.t_max:
+            return self.end_beta
+        else:
+            rel_t = (t - self.start_decay) / (self.t_max - self.start_decay)
+            return self.end_beta + (self.start_beta - self.end_beta) * \
+                max(0.0, (1 - rel_t))
diff --git a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py
index e730245d4..ea8681511 100644
--- a/mmrazor/models/mutables/mutable_channel/units/channel_unit.py
+++ b/mmrazor/models/mutables/mutable_channel/units/channel_unit.py
@@ -212,6 +212,11 @@ def name(self) -> str:
+    @property
+    def alias(self) -> str:
+        """str: alias of the unit"""
+        return self.name
+
     def config_template(self,
                         with_init_args=False,
                         with_channels=False) -> Dict:
diff --git a/mmrazor/models/mutators/channel_mutator/channel_mutator.py b/mmrazor/models/mutators/channel_mutator/channel_mutator.py
index 910992e1e..2d83d48f7 100644
--- a/mmrazor/models/mutators/channel_mutator/channel_mutator.py
+++ b/mmrazor/models/mutators/channel_mutator/channel_mutator.py
@@ -12,10 +12,11 @@
 from mmrazor.models.task_modules.tracer.channel_analyzer import ChannelAnalyzer
 from mmrazor.registry import MODELS, TASK_UTILS
 from ..base_mutator import BaseMutator
+from ..group_mixin import GroupMixin
 
 
 @MODELS.register_module()
-class ChannelMutator(BaseMutator, Generic[ChannelUnitType]):
+class ChannelMutator(BaseMutator, Generic[ChannelUnitType], GroupMixin):
     """ChannelMutator manages the pruning structure of a model.
 
     Args:
@@ -45,6 +46,10 @@ class ChannelMutator(BaseMutator, Generic[ChannelUnitType]):
             demo_input=(1, 3, 224, 224),
             tracer_type='BackwardTracer')
 
+        custom_groups (list[list[str]], optional): User-defined search groups.
+            All searchable modules that are not in ``custom_groups`` will be
+            grouped separately.
+
         init_cfg (dict, optional): initialization configuration dict for
             BaseModule.
 
@@ -92,6 +97,10 @@ def __init__(self,
 
         self._parse_channel_unit_cfg(channel_unit_cfg)
 
+        if custom_groups is None:
+            custom_groups = []
+        self._custom_groups = custom_groups
+
     def prepare_from_supernet(self, supernet: Module) -> None:
         """Prepare from a model for pruning.
 
@@ -229,10 +238,10 @@ def set_choices(self, choices: Dict[str, Any]) -> None:
     @property
     def current_choices(self) -> Dict:
         """Get current choices."""
-        config = self.choice_template
-        for unit in self.mutable_units:
-            config[unit.name] = unit.current_choice
-        return config
+        current_choices = dict()
+        for group_id, modules in self.search_groups.items():
+            current_choices[group_id] = modules[0].current_choice
+        return current_choices
 
     @property
     def choice_template(self) -> Dict:
diff --git a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py
index c3da419bf..b00e0ef22 100644
--- a/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py
+++ b/mmrazor/models/mutators/channel_mutator/slimmable_channel_mutator.py
@@ -29,7 +29,7 @@ def __init__(self,
                  tracer_type='BackwardTracer'),
              init_cfg: Optional[Dict] = None) -> None:
 
-        super().__init__(channel_unit_cfg, parse_cfg, init_cfg)
+        super().__init__(channel_unit_cfg, parse_cfg, None, init_cfg)
 
         self.subnets = self._prepare_subnets(self.units_cfg)
diff --git a/mmrazor/models/mutators/module_mutator/diff_module_mutator.py b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py
new file mode 100644
index 000000000..1f639ed28
--- /dev/null
+++ b/mmrazor/models/mutators/module_mutator/diff_module_mutator.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.nn as nn
+
+from mmrazor.registry import MODELS
+from ...mutables import DiffMutableModule
+from .module_mutator import ModuleMutator
+
+
+@MODELS.register_module()
+class DiffModuleMutator(ModuleMutator):
+    """Differentiable mutable based mutator.
+
+    `DiffModuleMutator` is responsible for mutating `DiffMutableModule`,
+    `DiffMutableOP`, `DiffChoiceRoute` and `GumbelChoiceRoute`.
+    The architecture parameters of the mutables are retained
+    in `DiffModuleMutator`.
+
+    Args:
+        custom_groups (list[list[str]], optional): User-defined search
+            groups. All searchable modules that are not in ``custom_groups``
+            will be grouped separately.
+    """
+
+    def __init__(self,
+                 custom_groups: Optional[List[List[str]]] = None,
+                 init_cfg: Optional[Dict] = None) -> None:
+        super().__init__(custom_groups=custom_groups, init_cfg=init_cfg)
+
+    def build_arch_param(self, num_choices) -> nn.Parameter:
+        """Build learnable architecture parameters."""
+        return nn.Parameter(torch.randn(num_choices) * 1e-3)
+
+    def prepare_from_supernet(self, supernet: nn.Module) -> None:
+        """Inherit from ``BaseMutator``'s preparation and additionally
+        generate the `arch_params`, as in DARTS.
+
+        Args:
+            supernet (:obj:`torch.nn.Module`): The architecture to be used
+                in your algorithm.
+        """
+
+        super().prepare_from_supernet(supernet)
+        self.arch_params = self.build_arch_params()
+        self.modify_supernet_forward(self.arch_params)
+
+    def build_arch_params(self):
+        """This function builds the arch params that are generally used in
+        differentiable search algorithms, such as the DARTS series. Each
+        group_id corresponds to an arch param, so the mutables with the same
+        group_id share the same arch param.
+
+        Returns:
+            torch.nn.ParameterDict: the arch params built according to
+                `search_groups`.
+        """
+
+        arch_params = nn.ParameterDict()
+
+        for group_id, modules in self.search_groups.items():
+            group_arch_param = self.build_arch_param(modules[0].num_choices)
+            arch_params[str(group_id)] = group_arch_param
+
+        return arch_params
+
+    def modify_supernet_forward(self, arch_params):
+        """Modify the DiffMutableModule's default arch_param in forward.
+
+        In MMRazor, the `arch_param` belongs to `DiffModuleMutator`, while
+        the `DiffMutableModule` needs `arch_param` in the forward. Here we
+        use a partial function to assign the corresponding `arch_param` to
+        each `DiffMutableModule`.
+        """
+
+        for group_id, mutables in self.search_groups.items():
+            for m in mutables:
+                m.set_forward_args(arch_param=arch_params[str(group_id)])
+
+    def sample_choices(self):
+        """Sampling by search groups.
+
+        The sampling result of the first mutable of each group is the
+        sampling result of this group.
+
+        Returns:
+            Dict[int, Any]: Random choices dict.
+        """
+
+        choices = dict()
+        for group_id, mutables in self.search_groups.items():
+            arch_param = self.arch_params[str(group_id)]
+            choice = mutables[0].sample_choice(arch_param)
+            choices[group_id] = choice
+        return choices
+
+    def set_choices(self, choices: Dict[int, Any]) -> None:
+        """Set mutables' current choice according to choices sampled by
+        :func:`sample_choices`.
+
+        Args:
+            choices (Dict[int, Any]): Choices dict. The key is group_id in
+                search groups, and the value is the sampling results
+                corresponding to this group.
+        """
+        for group_id, mutables in self.search_groups.items():
+            choice = choices[group_id]
+            for m in mutables:
+                m.current_choice = choice
+
+    @property
+    def mutable_class_type(self):
+        """Differentiable mutable class type.
+
+        Returns:
+            Type[DiffMutableModule]: Class type of differentiable mutable.
+        """
+        return DiffMutableModule
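The convention implemented by `DiffModuleMutator.build_arch_params` above is one small learnable vector per search group, sized by that group's number of choices. Reduced to its essentials (the group ids and sizes here are illustrative):

import torch
import torch.nn as nn

arch_params = nn.ParameterDict({
    '0': nn.Parameter(torch.randn(3) * 1e-3),  # group 0: 3 candidate ops
    '1': nn.Parameter(torch.randn(2) * 1e-3),  # group 1: 2 candidate ops
})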
+ """ + + def prepare_from_supernet(self, supernet: Module) -> None: + """Do some necessary preparations with supernet. + + Note: + For mutable based mutator, we need to build search group first. + + Args: + supernet (:obj:`torch.nn.Module`): The supernet to be searched + in your algorithm. + """ + self._search_groups = self.build_search_groups(supernet, + self.mutable_class_type, + self._custom_groups) + + @property + def name2mutable(self) -> Dict[str, MUTABLE_TYPE]: + """Search space of supernet. + + Note: + To get the mapping: module name to mutable. + + Raises: + RuntimeError: Called before search space has been parsed. + + Returns: + Dict[str, MUTABLE_TYPE]: The name2mutable dict. + """ + if self._name2mutable is None: + raise RuntimeError( + 'Call `prepare_from_supernet` before access name2mutable!') + return self._name2mutable + + @property + def search_groups(self) -> Dict[int, List[MUTABLE_TYPE]]: + """Search group of supernet. + + Note: + For mutable based mutator, the search group is composed of + corresponding mutables. + + Raises: + RuntimeError: Called before search group has been built. + + Returns: + Dict[int, List[MUTABLE_TYPE]]: Search group. + """ + if self._search_groups is None: + raise RuntimeError( + 'Call `prepare_from_supernet` before access search group!') + return self._search_groups diff --git a/mmrazor/models/observers/__init__.py b/mmrazor/models/observers/__init__.py new file mode 100644 index 000000000..22af9bae9 --- /dev/null +++ b/mmrazor/models/observers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .minmax import EMAMinMaxObserver, MinMaxObserver +from .mse import MSEObserver + +__all__ = ['MinMaxObserver', 'MSEObserver', 'EMAMinMaxObserver'] diff --git a/mmrazor/models/observers/base.py b/mmrazor/models/observers/base.py new file mode 100644 index 000000000..e10738664 --- /dev/null +++ b/mmrazor/models/observers/base.py @@ -0,0 +1,74 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from torch.ao.quantization.observer import UniformQuantizationObserverBase + +from mmrazor.models.utils import pot_quantization, sync_tensor + +# from mmengine.model import BaseModule + + +class BaseObserver(UniformQuantizationObserverBase): + """Modified torch quantization observer. + + Args: + dtype: dtype argument to the `quantize` node needed to implement the + reference model spec. + qscheme: Quantization scheme to be used. + reduce_range: Reduces the range of the quantized data type by 1 bit. + This is sometimes required to avoid instruction overflow. + quant_min: Minimum quantization value. If unspecified, it will follow + the 8-bit setup. + quant_max: Maximum quantization value. If unspecified, it will follow + the 8-bit setup. + ch_axis (int, optional): Channel axis index. Defaults to -1. + is_pot_scale (bool, optional): Indicate whether scale is power of two. + Defaults to False. + eps: Epsilon value for float32. + Defaults to `torch.finfo(torch.float32).eps`. 
+ """ + + min_val: torch.Tensor + max_val: torch.Tensor + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + ch_axis=-1, + is_pot_scale=False, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps) -> None: + super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, + factory_kwargs, eps) + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer('min_val', + torch.tensor(float('inf'), **factory_kwargs)) + self.register_buffer('max_val', + torch.tensor(float('-inf'), **factory_kwargs)) + self.ch_axis = ch_axis + self.is_pot_scale = is_pot_scale + + @torch.jit.export + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Calculates the quantization parameters.""" + scale, zero_point = self._calculate_qparams(self.min_val, self.max_val) + scale.data = sync_tensor(scale).data + zero_point.data = sync_tensor(zero_point).data + if self.is_pot_scale: + scale = pot_quantization(scale) + return scale, zero_point + + @torch.jit.export + def extra_repr(self): + return 'min_val={}, max_val={} ch_axis={} is_pot_scale={}'.format( + self.min_val, self.max_val, self.ch_axis, self.is_pot_scale) + + @torch.jit.export + def reset_min_max_vals(self): + """Resets the min/max values.""" + self.min_val.copy_(torch.tensor(float('inf'))) + self.max_val.copy_(torch.tensor(float('-inf'))) diff --git a/mmrazor/models/observers/lsq_observer.py b/mmrazor/models/observers/lsq_observer.py new file mode 100644 index 000000000..d9b96d7a8 --- /dev/null +++ b/mmrazor/models/observers/lsq_observer.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import torch + +from mmrazor.registry import MODELS +from ..utils import _is_symmetric_quant, pot_quantization, sync_tensor +from .base import BaseObserver + + +@MODELS.register_module() +class LSQObserver(BaseObserver): + """Observer for `LEARNED STEP SIZE QUANTIZATION`""" + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + ch_axis=-1, + is_pot_scale=False, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps) -> None: + super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, + ch_axis, is_pot_scale, factory_kwargs, eps) + + self.tensor_norm = None + + def forward(self, x_orig): + if x_orig.numel() == 0: + return x_orig + x = x_orig.to(self.min_val.dtype) + if self.ch_axis == -1: + self.tensor_norm = x.abs().mean() + self.min_val, self.max_val = torch._aminmax(x) + else: + # compute channel-wise mean + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + self.tensor_norm = y.abs().mean(1) + self.min_val, self.max_val = torch._aminmax(y, 1) + + return x + + def calculate_qparams(self): + scale = 2 * self.tensor_norm / math.sqrt(self.quant_max) + zero_point = torch.zeros_like(self.tensor_norm) + sync_tensor(scale) + sync_tensor(zero_point) + if self.is_pot_scale: + scale = pot_quantization(scale) + if not _is_symmetric_quant(self.qscheme): + zero_point = self.quant_min - torch.round(self.min_val / scale) + return scale, zero_point diff --git a/mmrazor/models/observers/minmax.py b/mmrazor/models/observers/minmax.py new file mode 100644 index 000000000..099296536 --- /dev/null +++ b/mmrazor/models/observers/minmax.py @@ -0,0 +1,97 @@ +# Copyright (c) 
OpenMMLab. All rights reserved. +import torch + +from mmrazor.registry import MODELS +from .base import BaseObserver + + +@MODELS.register_module() +class MinMaxObserver(BaseObserver): + """Min max observer.""" + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + ch_axis=-1, + is_pot_scale=False, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps) -> None: + super(MinMaxObserver, self).__init__(dtype, qscheme, reduce_range, + quant_min, quant_max, ch_axis, + is_pot_scale, factory_kwargs, eps) + if (self.qscheme == torch.per_tensor_symmetric and self.reduce_range + and self.dtype == torch.quint8): + raise NotImplementedError('Cannot reduce range for symmetric \ + quantization for quint8') + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.detach() # avoid keeping autograd tape + x = x.to(self.min_val.dtype) + if self.ch_axis == -1: + min_val_cur, max_val_cur = torch._aminmax(x) + else: + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + min_val_cur, max_val_cur = torch._aminmax(y, 1) + min_val = torch.min(self.min_val, min_val_cur) + max_val = torch.max(self.max_val, max_val_cur) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + return x + + +@MODELS.register_module() +class EMAMinMaxObserver(BaseObserver): + """Moving average min/max among batches.""" + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + ch_axis=-1, + is_pot_scale=False, + ema_ratio=0.9, + factory_kwargs=None): + super(EMAMinMaxObserver, + self).__init__(dtype, qscheme, reduce_range, quant_min, + quant_max, ch_axis, is_pot_scale, factory_kwargs) + self.ema_ratio = ema_ratio + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.to(self.min_val.dtype) + if self.ch_axis == -1: + min_val_cur, max_val_cur = torch._aminmax(x) + else: + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] # noqa: C416 + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + y = x.permute(new_axis_list) + y = torch.flatten(y, start_dim=1) + min_val_cur, max_val_cur = torch._aminmax(y, 1) + + if self.max_val.numel() <= 1 and self.max_val.isinf(): + self.min_val = min_val_cur + self.max_val = max_val_cur + else: + self.min_val = self.min_val * self.ema_ratio + min_val_cur * ( + 1.0 - self.ema_ratio) + self.max_val = self.max_val * self.ema_ratio + max_val_cur * ( + 1.0 - self.ema_ratio) + return x diff --git a/mmrazor/models/observers/mse.py b/mmrazor/models/observers/mse.py new file mode 100644 index 000000000..f85abd902 --- /dev/null +++ b/mmrazor/models/observers/mse.py @@ -0,0 +1,156 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
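A minimal usage sketch for the min/max observers above (assuming the import path exported by `observers/__init__.py` in this patch, and a torch version that still provides `torch._aminmax`): the EMA variant decays the running range instead of taking a hard min/max.

import torch

from mmrazor.models.observers import EMAMinMaxObserver

observer = EMAMinMaxObserver(ema_ratio=0.9)
for _ in range(4):
    observer(torch.randn(8, 3, 32, 32))
# After the first batch, each update is
# min_val = 0.9 * min_val + 0.1 * batch_min (and likewise for max_val).
scale, zero_point = observer.calculate_qparams()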
+import torch + +from mmrazor.registry import MODELS +from .base import BaseObserver + +_version_under_1100 = int(torch.__version__.split('.')[1]) < 10 + + +@MODELS.register_module() +class MSEObserver(BaseObserver): + """MSE observer.""" + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + ch_axis=-1, + is_pot_scale=False, + p=2.0, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps) -> None: + super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, + ch_axis, is_pot_scale, factory_kwargs, eps) + self.p = p + + def lp_loss(self, pred, tgt, dim=None): + """loss function measured in L_p Norm.""" + return (pred - tgt).abs().pow( + self.p).mean(dim) if dim else (pred - + tgt).abs().pow(self.p).mean() + + def mse(self, + x: torch.Tensor, + x_min: torch.Tensor, + x_max: torch.Tensor, + iter=80): + best_score = 1e+10 + best_min, best_max = torch.tensor( + [1.0], dtype=torch.float), torch.tensor([1.0], dtype=torch.float) + best_min.copy_(x_min) + best_max.copy_(x_max) + for i in range(iter): + new_min = x_min * (1.0 - (i * 0.01)) + new_max = x_max * (1.0 - (i * 0.01)) + scale, zero_point = self._calculate_qparams(new_min, new_max) + x_q = torch.fake_quantize_per_tensor_affine( + x, scale.item(), int(zero_point.item()), self.quant_min, + self.quant_max) + score = self.lp_loss(x_q, x) + if score < best_score: + best_score = score + best_min, best_max = new_min, new_max + return best_min, best_max + + def mse_perchannel(self, + x: torch.Tensor, + x_min: torch.Tensor, + x_max: torch.Tensor, + iter=80, + ch_axis=0): + assert x_min.shape == x_max.shape + assert ch_axis >= 0, f'{ch_axis}' + best_score = 1e+10 * torch.ones_like(x_min) + best_min, best_max = x_min.clone(), x_max.clone() + reduce_dim = tuple([i for i in range(len(x.shape)) if i != ch_axis]) + for i in range(iter): + new_min = x_min * (1.0 - (i * 0.01)) + new_max = x_max * (1.0 - (i * 0.01)) + scale, zero_point = self._calculate_qparams(new_min, new_max) + x_q = torch.fake_quantize_per_channel_affine( + x, scale, + zero_point.long() if _version_under_1100 else zero_point, + ch_axis, self.quant_min, self.quant_max) + score = self.lp_loss(x_q, x, reduce_dim) + update_idx = (score < best_score) + best_score[update_idx] = score[update_idx] + best_min[update_idx] = new_min[update_idx] + best_max[update_idx] = new_max[update_idx] + return best_min, best_max + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.clone().detach().to(self.min_val.dtype) + if self.ch_axis == -1: + min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = self.mse( + x, min_val_cur, max_val_cur, iter=95) + else: + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + x_channel = x.permute(new_axis_list) + y = torch.flatten(x_channel, start_dim=1) + min_val_cur, max_val_cur = torch._aminmax(y, 1) + min_val_cur, max_val_cur = self.mse_perchannel( + x, min_val_cur, max_val_cur, iter=80, ch_axis=self.ch_axis) + + self.min_val = torch.min(self.min_val, min_val_cur) + self.max_val = torch.max(self.max_val, max_val_cur) + return x + + +@MODELS.register_module() +class EMAMSEObserver(MSEObserver): + + def __init__(self, + dtype=torch.quint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + ch_axis=-1, + is_pot_scale=False, + p=2.0, + 
ema_ratio=0.9, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps) -> None: + super().__init__(dtype, qscheme, reduce_range, quant_min, quant_max, + ch_axis, is_pot_scale, p, factory_kwargs, eps) + self.ema_ratio = ema_ratio + + def forward(self, x_orig): + r"""Records the running minimum and maximum of ``x``.""" + if x_orig.numel() == 0: + return x_orig + x = x_orig.clone().detach().to(self.min_val.dtype) + if self.ch_axis == -1: + min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = self.mse( + x, min_val_cur, max_val_cur, iter=95) + else: + x_dim = x.size() + new_axis_list = [i for i in range(len(x_dim))] + new_axis_list[self.ch_axis] = 0 + new_axis_list[0] = self.ch_axis + x_channel = x.permute(new_axis_list) + y = torch.flatten(x_channel, start_dim=1) + min_val_cur, max_val_cur = torch._aminmax(y, 1) + min_val_cur, max_val_cur = self.mse_perchannel( + x, min_val_cur, max_val_cur, iter=80, ch_axis=self.ch_axis) + + if self.max_val.numel() <= 1 and self.max_val.isinf(): + self.min_val = min_val_cur + self.max_val = max_val_cur + else: + self.min_val = self.min_val * self.ema_ratio + min_val_cur * ( + 1.0 - self.ema_ratio) + self.max_val = self.max_val * self.ema_ratio + max_val_cur * ( + 1.0 - self.ema_ratio) + return x diff --git a/mmrazor/models/quantizers/__init__.py b/mmrazor/models/quantizers/__init__.py new file mode 100644 index 000000000..e56902eba --- /dev/null +++ b/mmrazor/models/quantizers/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .base import CustomQuantizer +from .trt_quantizer import TensorRTQuantizer + +__all__ = ['CustomQuantizer', 'TensorRTQuantizer'] diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py new file mode 100644 index 000000000..ab4cf190a --- /dev/null +++ b/mmrazor/models/quantizers/base.py @@ -0,0 +1,194 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List + +import torch +from mmengine.model import BaseModule +from torch.ao.quantization import QConfig +from torch.ao.quantization.fx import prepare +from torch.ao.quantization.quantize_fx import _convert_fx, _fuse_fx + +from mmrazor.models.task_modules.tracer import CustomTracer +from mmrazor.models.utils import (check_is_valid_convert_custom_config_dict, + check_is_valid_prepare_custom_config_dict, + check_is_valid_qconfig_dict, + get_custom_module_class_keys) +from mmrazor.registry import MODELS +from mmrazor.structures.quantization import (CheckArgs, DefalutQconfigs, + QuantizeScheme, SupportQtypes) + + +@MODELS.register_module() +class CustomQuantizer(BaseModule): + """Configurable quantizer, base class of quantizers. + + Args: + qconfig (Dict, optional): QConfig. Defaults to DefalutQconfigs['default']. # noqa: E501 + is_qat (bool, optional): Is QAT ro not. Defaults to True. + skipped_methods (List, optional): Skipped methods list for tracer. + Defaults to None. + prepare_custom_config_dict (Dict, optional): `PrepareCustomConfig` + from `torch.quantization.fx`. Defaults to None. + convert_custom_config_dict (Dict, optional): `ConvertCustomConfig` + from `torch.quantization.fx`. Defaults to None. + equalization_qconfig_dict (Dict, optional): Custom `QConfig` effects + on all modules. Defaults to None. + _remove_qconfig (Dict, optional): Remove qconfig at the end of + `_convert_fx`. Defaults to True. + init_cfg (dict, optional): Initialization config dict. 
+ """ + + def __init__(self, + qconfig: Dict = DefalutQconfigs['default'], + is_qat: bool = True, + skipped_methods: List = None, + prepare_custom_config_dict: Dict = None, + convert_custom_config_dict: Dict = None, + equalization_qconfig_dict: Dict = None, + _remove_qconfig: bool = True, + init_cfg: Dict = None): + super().__init__(init_cfg) + if self.check_qconfig(qconfig): + qconfig = self.qconfig_convert(qconfig) + self.qconfig_dict = {'': qconfig} + else: + raise ValueError('qconfig is incorrect!') + + if prepare_custom_config_dict is None: + self.prepare_custom_config_dict = {} + else: + self.prepare_custom_config_dict = prepare_custom_config_dict + if convert_custom_config_dict is None: + self.convert_custom_config_dict = {} + else: + self.convert_custom_config_dict = convert_custom_config_dict + if equalization_qconfig_dict is None: + self.equalization_qconfig_dict = {} + else: + self.equalization_qconfig_dict = equalization_qconfig_dict + + check_is_valid_qconfig_dict(self.qconfig_dict) + check_is_valid_prepare_custom_config_dict( + self.prepare_custom_config_dict) + check_is_valid_convert_custom_config_dict( + self.convert_custom_config_dict) + check_is_valid_qconfig_dict(self.equalization_qconfig_dict) + + self.is_qat = is_qat + self.skipped_methods = skipped_methods + self._remove_qconfig = _remove_qconfig + self.tracer = self.build_tracer() + + def prepare(self, model, graph_module): + + preserved_attributes = self.prepare_custom_config_dict.get( + 'preserved_attributes', []) + for attr_name in preserved_attributes: + setattr(graph_module, attr_name, getattr(model, attr_name)) + + graph_module = self.fuse_model(graph_module) + + prepared = prepare( + graph_module, + self.qconfig_dict, + self.is_qat, + self.tracer.node_name_to_scope, + prepare_custom_config_dict=self.prepare_custom_config_dict, + equalization_qconfig_dict=self.equalization_qconfig_dict + ) # type: ignore[operator] + + for attr_name in preserved_attributes: + setattr(prepared, attr_name, getattr(model, attr_name)) + return prepared + + def convert(self, graph_module): + quantized = _convert_fx( + graph_module, + is_reference=False, + convert_custom_config_dict=self.convert_custom_config_dict, + _remove_qconfig=self._remove_qconfig, + qconfig_dict=self.qconfig_dict) + return quantized + + def check_qconfig(self, qconfig): + is_pass = True + for arg in CheckArgs: + if arg == 'qtype': + if qconfig[arg] in SupportQtypes and arg in qconfig.keys(): + continue + else: + is_pass = False + break + else: + if isinstance(qconfig[arg], dict) and arg in qconfig.keys(): + continue + else: + is_pass = False + break + return is_pass + + def qconfig_convert(self, qconfig): + self.w_qscheme = QuantizeScheme(**qconfig['w_qscheme']) + self.a_qscheme = QuantizeScheme(**qconfig['a_qscheme']) + w_observer = MODELS.get(qconfig['w_observer']['type']) + w_observer_kwargs = self.w_qscheme.to_observer_params() + a_observer = MODELS.get(qconfig['a_observer']['type']) + a_observer_kwargs = self.a_qscheme.to_observer_params() + self.w_observer = MODELS.get(qconfig['w_observer']['type']).with_args( + **self.w_qscheme.to_observer_params()) + self.a_observer = MODELS.get(qconfig['a_observer']['type']).with_args( + **self.a_qscheme.to_observer_params()) + self.w_fake_quant = MODELS.get( + qconfig['w_fake_quant']['type']).with_args( + observer=w_observer, **w_observer_kwargs) + self.a_fake_quant = MODELS.get( + qconfig['a_fake_quant']['type']).with_args( + observer=a_observer, **a_observer_kwargs) + + torch_qconfig = QConfig( + 
weight=self.w_fake_quant, activation=self.a_fake_quant) + return torch_qconfig + + def _swap_ff_with_fxff(self, model: torch.nn.Module) -> None: + r""" Swap FloatFunctional with FXFloatFunctional + """ + modules_to_swap = [] + for name, module in model.named_children(): + if isinstance(module, torch.nn.quantized.FloatFunctional): + modules_to_swap.append(name) + else: + self._swap_ff_with_fxff(module) + + for name in modules_to_swap: + del model._modules[name] + model._modules[name] = torch.nn.quantized.FXFloatFunctional() + + def build_tracer(self): + skipped_module_names = self.prepare_custom_config_dict.get( + 'non_traceable_module_name', []) + skipped_module_classes = self.prepare_custom_config_dict.get( + 'non_traceable_module_class', []) + standalone_module_name_configs = self.prepare_custom_config_dict.get( + 'standalone_module_name', []) + skipped_module_names += [ + config[0] for config in standalone_module_name_configs + ] + + standalone_module_class_configs = self.prepare_custom_config_dict.get( + 'standalone_module_class', []) + skipped_module_classes += [ + config[0] for config in standalone_module_class_configs + ] + float_custom_module_classes = get_custom_module_class_keys( + self.prepare_custom_config_dict, + 'float_to_observed_custom_module_class') + skipped_module_classes += float_custom_module_classes + tracer = CustomTracer(self.skipped_methods, skipped_module_names, + skipped_module_classes) + # tracer = QuantizationTracer(skipped_module_names, + # skipped_module_classes) + return tracer + + def fuse_model(self, graph_module): + graph_module = _fuse_fx(graph_module, self.is_qat, + self.prepare_custom_config_dict) + return graph_module diff --git a/mmrazor/models/quantizers/trt_quantizer.py b/mmrazor/models/quantizers/trt_quantizer.py new file mode 100644 index 000000000..cc8532a53 --- /dev/null +++ b/mmrazor/models/quantizers/trt_quantizer.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
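Before the TensorRT subclass below, a sketch of building the configurable quantizer from the `Default` qconfig defined later in `backend_default_qconfigs.py` (a sketch, assuming the registry entries land as in this patch):

from mmrazor.registry import MODELS
from mmrazor.structures.quantization import DefalutQconfigs

# `check_qconfig` validates the dict, then `qconfig_convert` turns it into a
# `torch.ao.quantization.QConfig` holding observer/fake-quant constructors.
quantizer = MODELS.build(
    dict(type='CustomQuantizer', qconfig=DefalutQconfigs['default']))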
+from mmrazor.registry import MODELS +from mmrazor.structures.quantization import DefalutQconfigs +from .base import CustomQuantizer + + +@MODELS.register_module() +class TensorRTQuantizer(CustomQuantizer): + """Quantizer for TensorRT backend.""" + + def __init__(self, + qconfig=DefalutQconfigs['tensorrt'], + is_qat=True, + skipped_methods=None, + prepare_custom_config_dict=None, + convert_custom_config_dict=None, + equalization_qconfig_dict=None, + _remove_qconfig=True, + init_cfg=None): + super().__init__(qconfig, is_qat, skipped_methods, + prepare_custom_config_dict, + convert_custom_config_dict, equalization_qconfig_dict, + _remove_qconfig, init_cfg) diff --git a/mmrazor/models/task_modules/tracer/__init__.py b/mmrazor/models/task_modules/tracer/__init__.py index c3ff8dd66..aab26fa40 100644 --- a/mmrazor/models/task_modules/tracer/__init__.py +++ b/mmrazor/models/task_modules/tracer/__init__.py @@ -2,6 +2,7 @@ from .backward_tracer import BackwardTracer from .channel_analyzer import ChannelAnalyzer # from .razor_tracer import RazorFxTracer +from .fx import CustomTracer, UntracedMethodRegistry, custom_symbolic_trace from .loss_calculator import * # noqa: F401,F403 from .parsers import * # noqa: F401,F403 from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode, @@ -10,5 +11,6 @@ __all__ = [ 'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode', 'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode', - 'ChannelAnalyzer' + 'ChannelAnalyzer', 'CustomTracer', 'UntracedMethodRegistry', + 'custom_symbolic_trace' ] diff --git a/mmrazor/models/task_modules/tracer/fx/__init__.py b/mmrazor/models/task_modules/tracer/fx/__init__.py new file mode 100644 index 000000000..29c93f83a --- /dev/null +++ b/mmrazor/models/task_modules/tracer/fx/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .custom_tracer import (CustomTracer, UntracedMethodRegistry, + custom_symbolic_trace) + +__all__ = ['CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace'] diff --git a/mmrazor/models/task_modules/tracer/fx/custom_tracer.py b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py new file mode 100644 index 000000000..f69ec2269 --- /dev/null +++ b/mmrazor/models/task_modules/tracer/fx/custom_tracer.py @@ -0,0 +1,281 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import functools +from types import FunctionType, MethodType +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import torch +from mmengine.utils import import_modules_from_strings +from torch._C import ScriptObject # type: ignore[attr-defined] +from torch.ao.quantization.quantize_fx import QuantizationTracer +from torch.fx import GraphModule, Tracer +from torch.fx._symbolic_trace import (Graph, _autowrap_check, + _patch_wrapped_functions, _Patcher) +from torch.fx.proxy import Proxy + +_orig_module_call: Callable = torch.nn.Module.__call__ +_orig_module_getattr: Callable = torch.nn.Module.__getattr__ +# _orig_module_forward_train: Callable = models.BaseDenseHead.forward_train + + +class UntracedMethodRegistry: + """A `Descriptor` class which records untraced methods.""" + method_dict: Dict = dict() + tracer = None + + def __init__(self, method): + """_summary_ + + Args: + method (FunctionType): Function to be registered. 
+        """
+        self.method = method
+        self.instances: Dict = dict()
+        self.owner = None
+
+    def __set_name__(self, owner, name):
+        self.owner = owner
+        self.name = name
+        wrapped = self.method_wrapper()
+        self.method_dict[name] = dict(mod=self.owner, wrapped=wrapped)
+
+    def __get__(self, instance, owner):
+        if instance is None:
+            return self.method
+        return MethodType(self.method, instance)
+
+    def method_wrapper(self):
+
+        @functools.wraps(self.method)
+        def wrapped_method(mod, *args, **kwargs):
+
+            def method(*args, **kwargs):
+                return self.method(mod, *args, **kwargs)
+
+            return self.tracer.call_method(mod, self.name, method, args,
+                                           kwargs)
+
+        return wrapped_method
+
+
+def custom_symbolic_trace(
+        root: Union[torch.nn.Module, Callable[..., Any]],
+        concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:
+    """Modified `symbolic_trace` function.
+
+    Args:
+        root (Union[torch.nn.Module, Callable]): Module or function to be
+            traced and converted into a Graph representation.
+        concrete_args (Optional[Dict[str, any]]): Inputs to be partially
+            specialized.
+
+    Returns:
+        GraphModule: The traced ``GraphModule``.
+    """
+    tracer = CustomTracer()
+    graph = tracer.trace(root, concrete_args)
+    name = root.__class__.__name__ if isinstance(
+        root, torch.nn.Module) else root.__name__
+    return GraphModule(tracer.root, graph, name)
+
+
+class CustomTracer(QuantizationTracer):
+
+    def __init__(self,
+                 skipped_methods: List[str] = [],
+                 skipped_module_names: List[str] = [],
+                 skipped_module_classes: List[Callable] = [],
+                 *args,
+                 **kwargs):
+        """Tracer that can skip configurable methods, modules and classes
+        while tracing.
+
+        Args:
+            skipped_methods (List[str], optional): Methods to be skipped
+                while tracing. Defaults to an empty list.
+            skipped_module_names (List[str], optional): Modules to be skipped
+                while tracing. Defaults to an empty list.
+            skipped_module_classes (List[str], optional): Classes to be
+                skipped while tracing. Defaults to an empty list.
+        """
+        super(CustomTracer, self).__init__(skipped_module_names,
+                                           skipped_module_classes)
+        UntracedMethodRegistry.tracer = self  # type: ignore
+        self.skipped_methods = skipped_methods
+        if self.skipped_methods:
+            self.register_skipped_methods()
+
+    @staticmethod
+    def _check_valid_source(source):
+        """Check if the source's format is valid."""
+        if not isinstance(source, str):
+            raise TypeError(f'source should be a str '
+                            f'instance, but got {type(source)}')
+
+        assert len(source.split('.')) > 1, \
+            'source must have at least one `.`'
+
+    def register_skipped_methods(self):
+        if not isinstance(self.skipped_methods, list):
+            self.skipped_methods = [self.skipped_methods]
+        for s_method in self.skipped_methods:
+            self._check_valid_source(s_method)
+            mod_str = '.'.join(s_method.split('.')[:-2])
+            cls_str = s_method.split('.')[-2]
+            method_str = s_method.split('.')[-1]
+
+            try:
+                mod = import_modules_from_strings(mod_str)
+            except ImportError:
+                raise ImportError(f'{mod_str} is not imported correctly.')
+
+            imported_cls: type = getattr(mod, cls_str)
+            if not isinstance(imported_cls, type):
+                raise TypeError(f'{cls_str} should be a type '
+                                f'instance, but got {type(imported_cls)}')
+            assert hasattr(imported_cls, method_str), \
+                f'{method_str} is not in {mod_str}.'
+
+            method = getattr(imported_cls, method_str)
+
+            method_registry = UntracedMethodRegistry(method)
+            method_registry.__set_name__(imported_cls, method_str)
+
+    def call_method(self, m: torch.nn.Module, name, method, args, kwargs):
+        """Method that specifies the behavior of this ``Tracer`` when it
+        encounters a call to an ``nn.Module`` instance.
+ + By default, the behavior is to check if the called module is a leaf + module via ``is_leaf_module``. If it is, emit a ``call_module`` + node referring to ``m`` in the ``Graph``. Otherwise, call the + ``Module`` normally, tracing through the operations in its ``forward`` + function. + + This method can be overridden to--for example--create nested traced + GraphModules, or any other behavior you would want while tracing across + ``Module`` boundaries. + + Args: + + m (Module): The module for which a call is being emitted + forward (Callable): The forward() method of the ``Module`` to be + invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a + ``call_module`` node was emitted, this is a ``Proxy`` value. + Otherwise, it is whatever value was returned from the ``Module`` + invocation. + """ + # module_qualified_name = self.path_of_module(m) + if not self.is_skipped_method(m): + return method(*args, **kwargs) + args = list(args) + args.insert(0, m) + args = tuple(args) + return self.create_proxy('call_method', name, args, kwargs) + + def trace(self, root, concrete_args=None): + if isinstance(root, torch.nn.Module): + self.root = root + fn = type(root).forward + self.submodule_paths = { + mod: name + for name, mod in root.named_modules() + } + else: + self.root = torch.nn.Module() + fn = root + + tracer_cls: Optional[Type['Tracer']] = getattr(self, '__class__', None) + self.graph = Graph(tracer_cls=tracer_cls) + + # When we encounter a Tensor value that's not a parameter, we look if + # it is some other attribute on the model. Construct a dict mapping + # Tensor values to the qualified name here for efficiency. This is + # used downstream in create_arg + self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} + + def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): + for k, v in m.__dict__.items(): + if isinstance(v, (torch.Tensor, ScriptObject)): + self.tensor_attrs[v] = '.'.join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + assert isinstance(fn, FunctionType) + + fn_globals = fn.__globals__ # run before it gets patched + fn, args = self.create_args_for_root(fn, + isinstance(root, torch.nn.Module), + concrete_args) + + # Reduce number of get_attr calls + parameter_proxy_cache: Dict[str, Proxy] = {} + + # Method dispatch on parameters is not recorded unless it's directly + # used. Thus, we need to insert a proxy when __getattr__ requests a + # parameter. 
+ @functools.wraps(_orig_module_getattr) + def module_getattr_wrapper(mod, attr): + attr_val = _orig_module_getattr(mod, attr) + return self._module_getattr(attr, attr_val, parameter_proxy_cache) + + @functools.wraps(_orig_module_call) + def module_call_wrapper(mod, *args, **kwargs): + + def forward(*args, **kwargs): + return _orig_module_call(mod, *args, **kwargs) + + _autowrap_check( + patcher, + getattr(getattr(mod, 'forward', mod), '__globals__', {}), + self._autowrap_function_ids) + return self.call_module(mod, forward, args, kwargs) + + with _Patcher() as patcher: + # allow duplicate patches to support the case of nested calls + patcher.patch_method( + torch.nn.Module, + '__getattr__', + module_getattr_wrapper, + deduplicate=False) + patcher.patch_method( + torch.nn.Module, + '__call__', + module_call_wrapper, + deduplicate=False) + + for name, value in UntracedMethodRegistry.method_dict.items(): + wrapped = value['wrapped'] + patcher.patch_method( + value['mod'], name, wrapped, deduplicate=False) + + _patch_wrapped_functions(patcher) + _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) + for module in self._autowrap_search: + _autowrap_check(patcher, module.__dict__, + self._autowrap_function_ids) + self.create_node( + 'output', + 'output', (self.create_arg(fn(*args)), ), {}, + type_expr=fn.__annotations__.get('return', None)) + + self.submodule_paths = None + + return self.graph + + def is_skipped_method(self, m): + mods = tuple(value['mod'] + for value in UntracedMethodRegistry.method_dict.values()) + custom = isinstance(m, mods) + return custom + + def is_leaf_module(self, m: torch.nn.Module, + module_qualified_name: str) -> bool: + # return super().is_leaf_module(m, module_qualified_name) + leaf = super().is_leaf_module(m, module_qualified_name) + return leaf diff --git a/mmrazor/models/utils/quantization_util.py b/mmrazor/models/utils/quantization_util.py new file mode 100644 index 000000000..376096b67 --- /dev/null +++ b/mmrazor/models/utils/quantization_util.py @@ -0,0 +1,217 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
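To make the `skipped_methods` plumbing above concrete, a hypothetical sketch; the dotted path is illustrative and must name a real importable class method, as enforced by `_check_valid_source` and `register_skipped_methods`:

from mmrazor.models.task_modules.tracer import CustomTracer

# 'some_pkg.SomeHead.loss' is a made-up path. When it resolves, the method is
# wrapped via UntracedMethodRegistry and emitted as a single `call_method`
# node rather than being traced through.
tracer = CustomTracer(skipped_methods=['some_pkg.SomeHead.loss'])
# graph = tracer.trace(model)  # model: an nn.Module that calls SomeHead.loss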
+from functools import partial +from typing import Any, Dict, List, Optional, Set + +import torch + + +class PerChannelLoadHook: + + def __init__(self, module, hook_param=['scale', 'zero_point']): + self.hook = module._register_load_state_dict_pre_hook( + partial(self.hook_fn, module=module)) + self.hook_param = hook_param + + def hook_fn(self, state_dict, prefix, local_metadata, strict, missing_keys, + unexpected_keys, error_msgs, module): + if module.ch_axis == -1: + # no per-channel parameters + return + for module_key, param in module._parameters.items(): + if module_key not in self.hook_param: + continue + candidate = prefix + module_key + if candidate in state_dict: + input_param = state_dict[candidate] + if param.shape != input_param.shape: + param.data = torch.ones_like( + input_param, dtype=param.dtype, device=param.device) + for module_key, param in module._buffers.items(): + if module_key not in self.hook_param: + continue + candidate = prefix + module_key + if candidate in state_dict: + input_param = state_dict[candidate] + if param.shape != input_param.shape: + param.data = torch.ones_like( + input_param, dtype=param.dtype, device=param.device) + + def close(self): + self.hook.remove() + + +USE_LINK = False +USE_DDP = False + +try: + import spring.linklink as link + assert link.is_initialized() + USE_LINK = True +except (ModuleNotFoundError, AssertionError): + import torch.distributed as dist + if torch.distributed.is_initialized(): + USE_DDP = True + + +def sync_tensor(tensor): + if USE_LINK: + if tensor.is_cuda is True: + tensor.data = tensor.data / link.get_world_size() + link.allreduce(tensor.data) + elif USE_DDP: + tensor.data = tensor.data / dist.get_world_size() + dist.all_reduce(tensor.data) + return tensor + + +def pot_quantization(tensor: torch.Tensor, mode='round'): + log2t = torch.log2(tensor) + if mode == 'round': + log2t = (torch.round(log2t) - log2t).detach() + log2t + else: + assert mode == 'floor' + log2t = (torch.floor(log2t) - log2t).detach() + log2t + return 2**log2t + + +def _is_per_channel(qscheme: 'torch.qscheme') -> bool: + return qscheme in [ + torch.per_channel_symmetric, torch.per_channel_affine, + torch.per_channel_affine_float_qparams + ] + + +def _is_per_tensor(qscheme: 'torch.qscheme') -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine] + + +def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool: + return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric] + + +def is_tracing_state(): + return torch._C._get_tracing_state() + + +def _is_float_qparams(qscheme: 'torch.qscheme') -> bool: + return qscheme in [ + torch.per_channel_affine_float_qparams, + ] + + +def check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], + dict_name: str) -> None: + r""" Checks if the given config_dict has the correct keys + Args: + `config_dict`: dictionary whose keys we want to check + """ + + for k in config_dict.keys(): + if k not in allowed_keys: + raise ValueError('Expected ' + dict_name + + ' to have the following keys: ' + + str(allowed_keys) + '. 
But found \'' + k + + '\' instead.') + + +def check_is_valid_qconfig_dict(qconfig_dict: Any) -> None: + r""" Checks if the given qconfig_dict has the correct keys + Args: + `qconfig_dict`: dictionary whose keys we want to check + """ + + qconfig_dict_allowed_keys = { + '', 'object_type', 'module_name_regex', 'module_name', + 'module_name_object_type_order' + } + check_is_valid_config_dict(qconfig_dict, qconfig_dict_allowed_keys, + 'qconfig_dict') + + +def check_is_valid_prepare_custom_config_dict( + prepare_custom_config_dict: Optional[Dict[str, Any]] = None) -> None: + r""" Checks if the given prepare_custom_config_dict has the correct keys + Args: + `prepare_custom_config_dict`: customization configuration dictionary for + quantization tool + """ + if not prepare_custom_config_dict: + return + + prepare_custom_config_dict_allowed_keys = { + 'standalone_module_name', 'standalone_module_class', + 'float_to_observed_custom_module_class', 'non_traceable_module_name', + 'non_traceable_module_class', 'input_quantized_idxs', + 'output_quantized_idxs', 'preserved_attributes' + } + check_is_valid_config_dict(prepare_custom_config_dict, + prepare_custom_config_dict_allowed_keys, + 'prepare_custom_config_dict') + + +def check_is_valid_convert_custom_config_dict( + convert_custom_config_dict: Optional[Dict[str, Any]] = None) -> None: + r""" Checks if the given convert_custom_config_dict has the correct keys + Args: + `convert_custom_config_dict`: dictionary for custom configurations for + convert function + """ + if not convert_custom_config_dict: + return + + convert_custom_config_dict_allowed_keys = { + 'observed_to_quantized_custom_module_class', 'preserved_attributes' + } + check_is_valid_config_dict(convert_custom_config_dict, + convert_custom_config_dict_allowed_keys, + 'convert_custom_config_dict') + + +def check_is_valid_fuse_custom_config_dict( + fuse_custom_config_dict: Optional[Dict[str, Any]] = None) -> None: + r""" Checks if the given fuse_custom_config_dict has the correct keys + Args: + `fuse_custom_config_dict`: dictionary for custom configurations for + fuse_fx + """ + if not fuse_custom_config_dict: + return + + fuse_custom_config_dict_allowed_keys = {'preserved_attributes'} + check_is_valid_config_dict(fuse_custom_config_dict, + fuse_custom_config_dict_allowed_keys, + 'fuse_custom_config_dict') + + +def get_custom_module_class_keys(custom_config_dict, + custom_config_dict_key) -> List[Any]: + r""" Get all the unique custom module keys in the custom config dict + e.g. 
+    Input:
+        custom_config_dict = {
+            "float_to_observed_custom_module_class": {
+                "static": {
+                    CustomModule1: ObservedCustomModule
+                },
+                "dynamic": {
+                    CustomModule2: DynamicObservedCustomModule
+                },
+                "weight_only": {
+                    CustomModule3: WeightOnlyObservedCustomModule
+                },
+            },
+        }
+    Output:
+        # extract all the keys in "static", "dynamic" and "weight_only" dict
+        [CustomModule1, CustomModule2, CustomModule3]
+    """
+    # using set to dedup
+    float_custom_module_classes: Set[Any] = set()
+    custom_module_mapping = custom_config_dict.get(custom_config_dict_key, {})
+    for quant_mode in ['static', 'dynamic', 'weight_only']:
+        quant_mode_custom_module_config = custom_module_mapping.get(
+            quant_mode, {})
+        quant_mode_custom_module_classes = set(
+            quant_mode_custom_module_config.keys())
+        float_custom_module_classes |= quant_mode_custom_module_classes
+    return list(float_custom_module_classes)
diff --git a/mmrazor/structures/quantization/__init__.py b/mmrazor/structures/quantization/__init__.py
new file mode 100644
index 000000000..fc2133bf2
--- /dev/null
+++ b/mmrazor/structures/quantization/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .backend_default_qconfigs import CheckArgs, DefalutQconfigs, SupportQtypes
+from .qscheme import QuantizeScheme
+
+__all__ = ['QuantizeScheme', 'DefalutQconfigs', 'SupportQtypes', 'CheckArgs']
diff --git a/mmrazor/structures/quantization/backend_default_qconfigs.py b/mmrazor/structures/quantization/backend_default_qconfigs.py
new file mode 100644
index 000000000..6a1fde183
--- /dev/null
+++ b/mmrazor/structures/quantization/backend_default_qconfigs.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+SupportQtypes = ('affine', )
+CheckArgs = [
+    'qtype', 'w_qscheme', 'a_qscheme', 'w_fake_quant', 'a_fake_quant',
+    'w_observer', 'a_observer'
+]
+
+Default = dict(
+    qtype='affine',  # noqa: E241
+    w_qscheme=dict(
+        is_symmetry=True,
+        is_per_channel=True,
+        is_pot_scale=False,
+        bit=8,
+        is_symmetric_range=True),
+    a_qscheme=dict(
+        is_symmetry=True,
+        is_per_channel=False,
+        is_pot_scale=False,
+        bit=8,
+        is_symmetric_range=True),
+    w_fake_quant=dict(type='BaseFakeQuantize'),
+    a_fake_quant=dict(type='BaseFakeQuantize'),
+    w_observer=dict(type='MinMaxObserver'),
+    a_observer=dict(type='MinMaxObserver'))
+
+TensorRT = dict(
+    qtype='affine',  # noqa: E241
+    w_qscheme=dict(
+        is_symmetry=True,
+        is_per_channel=True,
+        is_pot_scale=False,
+        bit=8,
+        is_symmetric_range=True),
+    a_qscheme=dict(
+        is_symmetry=True,
+        is_per_channel=False,
+        is_pot_scale=False,
+        bit=8,
+        is_symmetric_range=True),
+    w_fake_quant=dict(type='LearnableFakeQuantize'),
+    a_fake_quant=dict(type='LearnableFakeQuantize'),
+    w_observer=dict(type='MinMaxObserver'),
+    a_observer=dict(type='EMAMinMaxObserver'))
+
+DefalutQconfigs = dict(default=Default, tensorrt=TensorRT)
diff --git a/mmrazor/structures/quantization/qscheme.py b/mmrazor/structures/quantization/qscheme.py
new file mode 100644
index 000000000..24c41832e
--- /dev/null
+++ b/mmrazor/structures/quantization/qscheme.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
+class QuantizeScheme(object):
+    """Custom QScheme. Refer to:
+    https://github.com/pytorch/pytorch/blob/master/c10/core/QScheme.h.
+
+    Args:
+        bit (int, optional): Bit number. Defaults to 8.
+        is_symmetry (bool, optional): Is symmetric quantization or not.
+            Defaults to True.
+        is_per_channel (bool, optional): Is per-channel quantization or not.
+            Defaults to False.
+ is_pot_scale (bool, optional): Indicate whether scale is power of two. + Defaults to False. + """ + + def __init__(self, + bit=8, + is_symmetry=True, + is_per_channel=False, + is_pot_scale=False, + **kwargs): + self.bit = bit + self.is_symmetry = is_symmetry + self.is_per_channel = is_per_channel + self.is_pot_scale = is_pot_scale + + if self.is_per_channel: + self.torch_qscheme = torch.per_channel_symmetric \ + if self.is_symmetry else torch.per_channel_affine + else: + self.torch_qscheme = torch.per_tensor_symmetric \ + if self.is_symmetry else torch.per_tensor_affine + if 'is_symmetric_range' in kwargs: + self.is_symmetric_range = kwargs['is_symmetric_range'] + del kwargs['is_symmetric_range'] + else: + self.is_symmetric_range = False + self.kwargs = kwargs + + def to_observer_params(self): + quant_min = 0 + quant_max = 2**self.bit - 1 + if self.is_symmetry: + quant_max = 2**(self.bit - 1) - 1 + if self.is_symmetric_range: + quant_min = -2**(self.bit - 1) + 1 + else: + quant_min = -2**(self.bit - 1) + + naive_para = { + 'quant_min': quant_min, + 'quant_max': quant_max, + 'dtype': torch.qint8 if self.is_symmetry else torch.quint8, + 'is_pot_scale': self.is_pot_scale, + 'qscheme': self.torch_qscheme, + 'reduce_range': False, + 'ch_axis': 0 if self.is_per_channel else -1 + } + naive_para.update(self.kwargs) + return naive_para + + def __str__(self): + return f'bit: {self.bit} / is_symmetry: {self.is_symmetry} / \ + is_per_channel: {self.is_per_channel} / is_pot_scale: \ + {self.is_pot_scale} / extra_kwargs: {self.kwargs}' diff --git a/mmrazor/testing/__init__.py b/mmrazor/testing/__init__.py index 009dd844d..54dfd30ed 100644 --- a/mmrazor/testing/__init__.py +++ b/mmrazor/testing/__init__.py @@ -1,2 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. from ._fast_stop_training_hook import FastStopTrainingHook # noqa: F401,F403 +from ._fx_models import * # noqa: F401, F403 diff --git a/mmrazor/testing/_fx_models.py b/mmrazor/testing/_fx_models.py new file mode 100644 index 000000000..969c4792d --- /dev/null +++ b/mmrazor/testing/_fx_models.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
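A quick sanity check of `to_observer_params` for the symmetric 8-bit weight scheme used by the default qconfigs (a sketch; the values follow directly from the arithmetic above):

import torch

from mmrazor.structures.quantization import QuantizeScheme

w_qscheme = QuantizeScheme(
    bit=8, is_symmetry=True, is_per_channel=True, is_symmetric_range=True)
params = w_qscheme.to_observer_params()
assert params['quant_min'] == -127 and params['quant_max'] == 127
assert params['dtype'] == torch.qint8
assert params['qscheme'] == torch.per_channel_symmetric
assert params['ch_axis'] == 0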
+from typing import Dict, Optional, Tuple, Union + +import torch.nn as nn +from mmcv.cnn import ConvModule + +from mmrazor.registry import MODELS + + +@MODELS.register_module() +class ConvBNReLU(nn.Module): + + def __init__( + self, + in_channel: int, + out_channel: int, + kernel_size: Union[int, Tuple[int, int]] = 1, + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: Union[str, bool] = 'auto', + conv_cfg: Optional[Dict] = None, + norm_cfg: Optional[Dict] = None, + act_cfg: Dict = dict(type='ReLU'), + inplace: bool = True, + with_spectral_norm: bool = False, + padding_mode: str = 'zeros', + order: tuple = ('conv', 'norm', 'act'), + init_cfg: Optional[Dict] = None, + ) -> None: + super().__init__() + self.conv_module = ConvModule(in_channel, out_channel, kernel_size, + stride, padding, dilation, groups, bias, + conv_cfg, norm_cfg, act_cfg, inplace, + with_spectral_norm, padding_mode, order) + + def forward(self, x): + x = self.conv_module.conv(x) + x = self.conv_module.norm(x) + x = self.conv_module.activate(x) + return x diff --git a/tests/test_models/test_algorithms/test_general_quant.py b/tests/test_models/test_algorithms/test_general_quant.py new file mode 100644 index 000000000..94a2485bc --- /dev/null +++ b/tests/test_models/test_algorithms/test_general_quant.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch.nn as nn + + +class ToyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + # TODO + + +class TestGeneralQuant(TestCase): + """TODO. + + Args: + TestCase (_type_): _description_ + """ + + def test_init(self): + pass + + def test_prepare(self): + pass + + def test_convert(self): + pass + + def test_states(self): + pass + + def test_forward(self): + pass diff --git a/tests/test_models/test_fake_quantize/test_lsq_fake_quants.py b/tests/test_models/test_fake_quantize/test_lsq_fake_quants.py new file mode 100644 index 000000000..d6b670bb5 --- /dev/null +++ b/tests/test_models/test_fake_quantize/test_lsq_fake_quants.py @@ -0,0 +1,23 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + + +class TestLearnableFakeQuantize(TestCase): + + def test_init(self): + pass + + def test_repr(self): + pass + + def test_calculate_qparams(self): + pass + + def test_forward(self): + pass + + def test_load_state_dict(self): + pass + + def test_save_state_dict(self): + pass diff --git a/tests/test_models/test_mutators/test_diff_mutator.py b/tests/test_models/test_mutators/test_diff_mutator.py new file mode 100644 index 000000000..663637fc9 --- /dev/null +++ b/tests/test_models/test_mutators/test_diff_mutator.py @@ -0,0 +1,235 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
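`ConvBNReLU` gives the FX-related tests a small traceable model; a minimal sketch of tracing it with the custom tracer added in this patch:

from mmrazor.models.task_modules.tracer import custom_symbolic_trace
from mmrazor.testing import ConvBNReLU

model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN'))
traced = custom_symbolic_trace(model)
print(traced.graph)  # conv, norm and activate appear as call_module nodes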
+from unittest import TestCase + +import pytest +import torch.nn as nn + +from mmrazor.models import * # noqa: F401,F403 +from mmrazor.models.mutables import DiffMutableModule +from mmrazor.models.mutators import DiffModuleMutator +from mmrazor.registry import MODELS + +MODELS.register_module(name='torchConv2d', module=nn.Conv2d, force=True) +MODELS.register_module(name='torchMaxPool2d', module=nn.MaxPool2d, force=True) +MODELS.register_module(name='torchAvgPool2d', module=nn.AvgPool2d, force=True) + + +class SearchableLayer(nn.Module): + + def __init__(self, mutable_cfg: dict) -> None: + super().__init__() + self.op1 = MODELS.build(mutable_cfg) + self.op2 = MODELS.build(mutable_cfg) + self.op3 = MODELS.build(mutable_cfg) + + def forward(self, x): + x = self.op1(x) + x = self.op2(x) + return self.op3(x) + + +class SearchableModel(nn.Module): + + def __init__(self, mutable_cfg: dict) -> None: + super().__init__() + self.slayer1 = SearchableLayer(mutable_cfg) + self.slayer2 = SearchableLayer(mutable_cfg) + self.slayer3 = SearchableLayer(mutable_cfg) + + def forward(self, x): + x = self.slayer1(x) + x = self.slayer2(x) + return self.slayer3(x) + + +class SearchableLayerAlias(nn.Module): + + def __init__(self, mutable_cfg: dict) -> None: + super().__init__() + mutable_cfg.update(alias='op1') + self.op1 = MODELS.build(mutable_cfg) + mutable_cfg.update(alias='op2') + self.op2 = MODELS.build(mutable_cfg) + mutable_cfg.update(alias='op3') + self.op3 = MODELS.build(mutable_cfg) + + def forward(self, x): + x = self.op1(x) + x = self.op2(x) + return self.op3(x) + + +class SearchableModelAlias(nn.Module): + + def __init__(self, mutable_cfg: dict) -> None: + super().__init__() + self.slayer1 = SearchableLayerAlias(mutable_cfg) + self.slayer2 = SearchableLayerAlias(mutable_cfg) + self.slayer3 = SearchableLayerAlias(mutable_cfg) + + def forward(self, x): + x = self.slayer1(x) + x = self.slayer2(x) + return self.slayer3(x) + + +class TestDiffModuleMutator(TestCase): + + def setUp(self): + self.MUTABLE_CFG = dict( + type='DiffMutableOP', + candidates=dict( + torch_conv2d_3x3=dict( + type='torchConv2d', + kernel_size=3, + padding=1, + ), + torch_conv2d_5x5=dict( + type='torchConv2d', + kernel_size=5, + padding=2, + ), + torch_conv2d_7x7=dict( + type='torchConv2d', + kernel_size=7, + padding=3, + ), + ), + module_kwargs=dict(in_channels=32, out_channels=32, stride=1)) + + self.MUTATOR_CFG = dict( + type='DiffModuleMutator', + custom_groups=[['op1'], ['op2'], ['op3']]) + + def test_diff_mutator_diffop_layer(self) -> None: + model = SearchableLayer(self.MUTABLE_CFG) + mutator: DiffModuleMutator = MODELS.build(self.MUTATOR_CFG) + + mutator.prepare_from_supernet(model) + assert list(mutator.search_groups.keys()) == [0, 1, 2] + + def test_diff_mutator_diffop_model(self) -> None: + model = SearchableModel(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [ + ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], + ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], + ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], + ] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + mutator.prepare_from_supernet(model) + assert list(mutator.search_groups.keys()) == [0, 1, 2] + + mutator.modify_supernet_forward(mutator.arch_params) + assert mutator.mutable_class_type == DiffMutableModule + + def test_diff_mutator_diffop_model_error(self) -> None: + model = SearchableModel(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [ + ['slayer1.op1', 
'slayer2.op1', 'slayer3.op1'], + ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], + ['slayer1.op3', 'slayer2.op3', 'slayer3.op3_error_key'], + ] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + with pytest.raises(AssertionError): + mutator.prepare_from_supernet(model) + + def test_diff_mutator_diffop_alias(self) -> None: + model = SearchableModelAlias(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [['op1'], ['op2'], ['op3']] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + mutator.prepare_from_supernet(model) + + assert list(mutator.search_groups.keys()) == [0, 1, 2] + + mutator.modify_supernet_forward(mutator.arch_params) + assert mutator.mutable_class_type == DiffMutableModule + + def test_diff_mutator_alias_module_name(self) -> None: + """Using both alias and module name for grouping.""" + model = SearchableModelAlias(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [['op1'], + [ + 'slayer1.op2', 'slayer2.op2', + 'slayer3.op2' + ], ['slayer1.op3', 'slayer2.op3']] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + mutator.prepare_from_supernet(model) + + assert list(mutator.search_groups.keys()) == [0, 1, 2, 3] + + mutator.modify_supernet_forward(mutator.arch_params) + assert mutator.mutable_class_type == DiffMutableModule + + def test_diff_mutator_duplicate_keys(self) -> None: + model = SearchableModel(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [ + ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], + ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], + ['slayer1.op3', 'slayer2.op3', 'slayer2.op3'], + ] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + with pytest.raises(AssertionError): + mutator.prepare_from_supernet(model) + + def test_diff_mutator_duplicate_key_alias(self) -> None: + model = SearchableModelAlias(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [ + ['op1', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'], + ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], + ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], + ] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + with pytest.raises(AssertionError): + mutator.prepare_from_supernet(model) + + def test_diff_mutator_illegal_key(self) -> None: + model = SearchableModel(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [ + ['illegal_key', 'slayer1.op1', 'slayer2.op1', 'slayer3.op1'], + ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], + ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], + ] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + + with pytest.raises(AssertionError): + mutator.prepare_from_supernet(model) + + def test_sample_and_set_choices(self): + model = SearchableModel(self.MUTABLE_CFG) + + mutator_cfg = self.MUTATOR_CFG.copy() + mutator_cfg['custom_groups'] = [ + ['slayer1.op1', 'slayer2.op1', 'slayer3.op1'], + ['slayer1.op2', 'slayer2.op2', 'slayer3.op2'], + ['slayer1.op3', 'slayer2.op3', 'slayer3.op3'], + ] + mutator: DiffModuleMutator = MODELS.build(mutator_cfg) + mutator.prepare_from_supernet(model) + choices = mutator.sample_choices() + mutator.set_choices(choices) + self.assertTrue(len(choices) == 3) + + +if __name__ == '__main__': + import unittest + unittest.main() diff --git a/tests/test_models/test_observers/test_observer.py b/tests/test_models/test_observers/test_observer.py new file mode 100644 index 000000000..ca39ecfbd --- /dev/null +++ 
b/tests/test_models/test_observers/test_observer.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch.nn as nn + + +class ToyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + # TODO + + +class TestMinMaxObserver(TestCase): + """TODO. + + Args: + TestCase (_type_): _description_ + """ + + def test_init(self): + pass + + def test_prepare(self): + pass + + def test_convert(self): + pass + + def test_states(self): + pass + + def test_forward(self): + pass + + +class TestLSQObserver(TestMinMaxObserver): + pass diff --git a/tests/test_models/test_quantizers/test_trt_quantizer.py b/tests/test_models/test_quantizers/test_trt_quantizer.py new file mode 100644 index 000000000..9f85d1ecd --- /dev/null +++ b/tests/test_models/test_quantizers/test_trt_quantizer.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import torch.nn as nn + + +class ToyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + # TODO + + +class TestTRTQuantizer(TestCase): + """TODO. + + Args: + TestCase (_type_): _description_ + """ + + def test_init(self): + pass + + def test_prepare(self): + pass + + def test_convert(self): + pass + + def test_states(self): + pass + + def test_forward(self): + pass diff --git a/tests/test_models/test_task_modules/test_custom_tracer.py b/tests/test_models/test_task_modules/test_custom_tracer.py new file mode 100644 index 000000000..671922f69 --- /dev/null +++ b/tests/test_models/test_task_modules/test_custom_tracer.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +from mmrazor.models.task_modules import CustomTracer, UntracedMethodRegistry +from mmrazor.testing import ConvBNReLU + + +class testCustomTracer(TestCase): + + def test_init(self): + tracer = CustomTracer() + assert tracer.skipped_methods.__len__() == 0 + + def test_trace(self): + tracer = CustomTracer() + model = ConvBNReLU(3, 3, norm_cfg=dict(type='BN')) + graph = tracer.trace(model) # noqa: F841 + + def test_auto_skip_call_module(self): + pass + + def test_auto_skip_call_method(self): + pass + + def test_configurable_skipped_methods(self): + pass + + +class testUntracedMethodRgistry(TestCase): + + def test_init(self): + self.assertEqual(len(UntracedMethodRegistry.method_dict), 0) + + def test_add_method(self): + pass diff --git a/tools/ckpt_demo.py b/tools/ckpt_demo.py new file mode 100644 index 000000000..ee257390c --- /dev/null +++ b/tools/ckpt_demo.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch
+
+ckpt_path = '/mnt/lustre/humu/experiments/adaround/quantizied.pth'
+# ckpt_path =
+# '/mnt/petrelfs/humu/share/resnet18_8xb32_in1k_20210831-fbbb1da6.pth'
+# ckpt_path = '/tmp/humu/resnet18_uniform8/checkpoint.pth.tar'
+# ckpt_path = '/tmp/humu/resnet18_uniform8/quantized_checkpoint.pth.tar'
+
+state_dict = torch.load(ckpt_path, map_location='cpu')
+
+for k in state_dict['state_dict']:
+    print(k)
diff --git a/tools/slurm_test.sh b/tools/slurm_test.sh
index 6dd67e574..3c74ec6ec 100644
--- a/tools/slurm_test.sh
+++ b/tools/slurm_test.sh
@@ -1,24 +1,10 @@
 #!/usr/bin/env bash
 
-set -x
-
-PARTITION=$1
-JOB_NAME=$2
-CONFIG=$3
-CHECKPOINT=$4
-GPUS=${GPUS:-8}
-GPUS_PER_NODE=${GPUS_PER_NODE:-8}
-CPUS_PER_TASK=${CPUS_PER_TASK:-5}
-PY_ARGS=${@:5}
-SRUN_ARGS=${SRUN_ARGS:-""}
+CONFIG=$1
+CHECKPOINT=$2
+GPUS=$3
+PORT=${PORT:-29500}
 
 PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
-srun -p ${PARTITION} \
-    --job-name=${JOB_NAME} \
-    --gres=gpu:${GPUS_PER_NODE} \
-    --ntasks=${GPUS} \
-    --ntasks-per-node=${GPUS_PER_NODE} \
-    --cpus-per-task=${CPUS_PER_TASK} \
-    --kill-on-bad-exit=1 \
-    ${SRUN_ARGS} \
-    python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
+python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
+    $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
diff --git a/tools/tracer_demo.py b/tools/tracer_demo.py
new file mode 100644
index 000000000..88334d6aa
--- /dev/null
+++ b/tools/tracer_demo.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+
+import torch
+import torch.fx as fx
+from mmengine.config import Config
+from mmengine.registry import MODELS
+
+from mmrazor.models.task_modules.tracer import custom_symbolic_trace
+
+cfg_path = 'configs/quantization/ptq/demo.py'
+_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear)
+
+
+def extract_subgraph(graphmodule, block_slice):
+    """Extract the nodes in ``block_slice`` into a new GraphModule."""
+    subgraph = copy.deepcopy(graphmodule.graph)
+    block_start, block_end = block_slice[:2]
+    for node in subgraph.nodes:
+        if node.name == 'inputs':
+            input_node = node
+        if node.name == block_start.name:
+            node.replace_input_with(node.prev, input_node)
+        if node.name == block_end.name:
+            output_node = node
+        if node.op == 'output':
+            node.replace_input_with(node.prev, output_node)
+    subgraph.lint()
+    subgraph_module = fx.GraphModule(graphmodule, subgraph)
+    subgraph_module.graph.eliminate_dead_code()
+    subgraph_module.recompile()
+    return subgraph_module
+
+
+def extract_blocks(graphmodule, key_word='layer'):
+    """Group fx nodes into (first, last) node pairs, one per block."""
+    block_slices = []
+    block_slice = []
+    pre_stage_index, pre_block_index = 0, 0
+    cur_stage_index, cur_block_index = 0, 0
+    for node in graphmodule.graph.nodes:
+        if key_word not in node.name:
+            continue
+        else:
+            items = node.name.split('_')
+            for i, item in enumerate(items):
+                if key_word in item:
+                    cur_stage_index = int(item[len(key_word):])
+                    cur_block_index = int(items[i + 1])
+                    break
+        if (cur_block_index != pre_block_index) or (cur_stage_index !=
                                                    pre_stage_index):
+            block_slice.append(node.prev)
+            if len(block_slice) == 2:
+                block_slices.append(block_slice)
+                block_slice = []
+            block_slice.append(node)
+
+        pre_stage_index, pre_block_index = cur_stage_index, cur_block_index
+
+    return block_slices
+
+
+def extract_layers(graphmodule, layer_types):
+    """Return a (node, node) pair for each module of a type in layer_types."""
+    layer_slices = []
+    for node in graphmodule.graph.nodes:
+        if node.op == 'call_module':
+            m = node.graph.owning_module.get_submodule(node.target)
+            if isinstance(m, layer_types):
+                layer_slices.append((node, node))
+    return layer_slices
+
+
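+# Note, added for clarity (not part of the original script): torch.fx derives
+# node names from module paths by replacing '.' with '_', so a call to
+# `backbone.layer1.0.conv1` is traced as a node named
+# `backbone_layer1_0_conv1`. `extract_blocks` above parses the stage index (1)
+# and block index (0) from exactly this convention, so it is an assumption
+# that only fits ResNet-style backbones whose stage names contain `key_word`.
+
+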
+def main():
+    # load config
+    cfg = Config.fromfile(cfg_path)
+    model = MODELS.build(cfg.model)
+    symbolic_traced = custom_symbolic_trace(
+        model, concrete_args={'mode': 'tensor'})
+    # block_slices = extract_blocks(symbolic_traced)
+    block_slices = extract_layers(
+        symbolic_traced, layer_types=_ADAROUND_SUPPORT_TYPE)
+
+    for b in block_slices:
+        print(b[0].name, b[1].name)
+
+    print('#' * 100)
+    subgraph = extract_subgraph(symbolic_traced, block_slices[0])
+    print(subgraph.code)
+    for name, layer in subgraph.named_modules():
+        print(name, layer)
+
+
+if __name__ == '__main__':
+    main()
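+
+# Usage sketch (illustrative; assumes an installed mmrazor and the demo
+# config referenced by `cfg_path`):
+#   python tools/tracer_demo.py
+# This prints the (start, end) node names of every extracted layer, then the
+# generated forward code and the submodules of the first extracted subgraph.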