From 1caf08bc42cf4a55c27fc6178c67058a4a5b2b5e Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Thu, 16 Mar 2023 16:36:56 +0800 Subject: [PATCH 1/7] add rewriter --- .../quantization/mm_architecture.py | 81 ++++++++++++++++++- .../models/quantizers/openvino_quantizer.py | 28 ++++++- 2 files changed, 106 insertions(+), 3 deletions(-) diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index ff99592d6..44c90bc17 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import os -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from mmengine.model import MMDistributedDataParallel @@ -37,6 +38,7 @@ class MMArchitectureQuant(BaseAlgorithm): quantized. quantizer (Union[Dict, BaseModel]): The quantizer to support different backend type. + deploy_cfg (Union[str, Dict]): Deployment config file or Config object. qmodel_modes (List): The available mode of runner. data_preprocessor (Optional[Dict]): The pre-process config of :class:`BaseDataPreprocessor`. Defaults to None. @@ -57,6 +59,7 @@ class MMArchitectureQuant(BaseAlgorithm): def __init__(self, architecture: Union[Dict, BaseModel], quantizer: Union[Dict, BaseModel], + deploy_cfg: Union[str, Dict], data_preprocessor: Optional[Dict] = None, forward_modes: Tuple = ('tensor', 'predict', 'loss'), float_checkpoint: Optional[str] = None, @@ -68,6 +71,7 @@ def __init__(self, self.quantizer = MODELS.build(quantizer) self.input_shapes = input_shapes self.forward_modes = forward_modes + self.deploy_cfg = deploy_cfg # Replace syncbn and _BatchNormXd (in mmengine) with batchnorm2d self.quantizer.convert_batchnorm2d(self.architecture) @@ -145,6 +149,75 @@ def traverse(module, prefix): continue traverse(self.qmodels[mode], '') + def _get_rewriter_context_in_mmdeploy(self, deploy_cfg): + """Get rewriter context in mmdeploy according to the deploy related + config.""" + from mmdeploy.apis.onnx.passes import optimize_onnx + from mmdeploy.core import RewriterContext + from mmdeploy.utils import (IR, Backend, get_backend, get_dynamic_axes, + get_ir_config, get_onnx_config) + + def _add_or_update(cfg: dict, key: str, val: Any): + if key in cfg and isinstance(cfg[key], dict) and isinstance( + val, dict): + cfg[key].update(val) + else: + cfg[key] = val + + context_info = dict() + deploy_cfg = copy.deepcopy(deploy_cfg) + context_info['deploy_cfg'] = deploy_cfg + + backend = get_backend(deploy_cfg).value + + onnx_cfg = get_onnx_config(deploy_cfg) + opset_version = onnx_cfg.get('opset_version', 11) + + input_names = onnx_cfg['input_names'] + output_names = onnx_cfg['output_names'] + axis_names = input_names + output_names + dynamic_axes = get_dynamic_axes(deploy_cfg, axis_names) + + verbose = not onnx_cfg.get('strip_doc_string', True) or onnx_cfg.get( + 'verbose', False) + keep_initializers_as_inputs = onnx_cfg.get( + 'keep_initializers_as_inputs', True) + optimize = onnx_cfg.get('optimize', False) + if backend == Backend.NCNN.value: + """NCNN backend needs a precise blob counts, while using onnx + optimizer will merge duplicate initilizers without reference + count.""" + optimize = False + + ir_config = dict( + type='onnx', + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + verbose=verbose, + keep_initializers_as_inputs=keep_initializers_as_inputs) + + _add_or_update(deploy_cfg, 'ir_config', ir_config) + ir = IR.get(get_ir_config(deploy_cfg)['type']) + if isinstance(backend, Backend): + backend = backend.value + backend_config = dict(type=backend) + _add_or_update(deploy_cfg, 'backend_config', backend_config) + + context_info['cfg'] = deploy_cfg + context_info['ir'] = ir + if 'backend' not in context_info: + context_info['backend'] = backend + if 'opset' not in context_info: + context_info['opset'] = opset_version + + if 'onnx_custom_passes' not in context_info: + onnx_custom_passes = optimize_onnx if optimize else None + context_info['onnx_custom_passes'] = onnx_custom_passes + + return RewriterContext(**context_info) + def _build_qmodels(self, model: BaseModel): """Build quantized models from the given model. @@ -171,10 +244,14 @@ def _build_qmodels(self, model: BaseModel): output output (_get_predictions,) """ + rewriter_context = self._get_rewriter_context_in_mmdeploy( + self.deploy_cfg) + qmodels = nn.ModuleDict() for mode in self.forward_modes: concrete_args = {'mode': mode} - observed_module = self.quantizer.prepare(model, concrete_args) + with rewriter_context: + observed_module = self.quantizer.prepare(model, concrete_args) qmodels[mode] = observed_module # data_samples can not be None in detectors during prediction. diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py index 94658afbd..831d991f2 100644 --- a/mmrazor/models/quantizers/openvino_quantizer.py +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -46,6 +46,32 @@ def support_a_modes(self): per_channel.""" return ('per_tensor') + def prepare_for_mmdeploy(self, + model: torch.nn.Module, + dummy_input: Tuple = (1, 3, 224, 224), + checkpoint: Optional[str] = None): + """Prepare for deploy to the backend with mmdeploy, which will be used + in mmdeploy, and usually includes as follows: + + 1. prepare for the float model rewritten by mmdeploy. + 2. load checkpoint consists of float weight and quantized params in + mmrazor. + 3. post process weight fakequant for exporting .onnx that meet + the backend's requirement. + """ + self.convert_batchnorm2d(model) + observed_model = self.prepare(model) + if dummy_input is not None: + observed_model(torch.randn(dummy_input)) + if checkpoint is not None: + observed_model.load_state_dict( + torch.load(checkpoint)['state_dict']) + self.post_process_for_deploy(observed_model, keep_w_fake_quant=True) + + observed_model.apply(disable_observer) + + return observed_model + def export_onnx(self, model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction], @@ -55,7 +81,7 @@ def export_onnx(self, **kwargs): """Export the onnx model that can be deployed to OpenVino backend.""" - symbolic_output_path = f'symbolic_{output_path}' + symbolic_output_path = output_path.replace('.onnx', '_symbolic.onnx') torch.onnx.export( model, args, From 3dac52ddb35ec6e4c52b786f5a8c8dd52211081d Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 17 Mar 2023 16:34:13 +0800 Subject: [PATCH 2/7] add deploy_cfg arg --- .../ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py | 2 ++ mmrazor/models/algorithms/quantization/mm_architecture.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py index 578f5fe84..31d42bd83 100644 --- a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py @@ -10,6 +10,7 @@ retina = _base_.model float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa: E501 +deploy_cfg = r'G:\projects\openmmlab\mmdeploy\configs\mmdet\detection\detection_openvino_dynamic-800x1344-quantize.py', # noqa: E501 global_qconfig = dict( w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), @@ -32,6 +33,7 @@ bgr_to_rgb=True, pad_size_divisor=32), architecture=retina, + deploy_cfg=deploy_cfg, float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index 44c90bc17..306b6cbd2 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -155,7 +155,8 @@ def _get_rewriter_context_in_mmdeploy(self, deploy_cfg): from mmdeploy.apis.onnx.passes import optimize_onnx from mmdeploy.core import RewriterContext from mmdeploy.utils import (IR, Backend, get_backend, get_dynamic_axes, - get_ir_config, get_onnx_config) + get_ir_config, get_onnx_config, + load_config) def _add_or_update(cfg: dict, key: str, val: Any): if key in cfg and isinstance(cfg[key], dict) and isinstance( @@ -164,6 +165,8 @@ def _add_or_update(cfg: dict, key: str, val: Any): else: cfg[key] = val + if isinstance(deploy_cfg, str): + deploy_cfg, = load_config(deploy_cfg) context_info = dict() deploy_cfg = copy.deepcopy(deploy_cfg) context_info['deploy_cfg'] = deploy_cfg From e67bf7bc4bc6d938d689a0eb0d47a9bf66a94430 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Tue, 21 Mar 2023 17:22:15 +0800 Subject: [PATCH 3/7] modify post_process_for_mmdeploy --- ...penvino_resnet18_8xb32_in1k_calib32xb32.py | 2 + .../quantization/mm_architecture.py | 83 +++++++++++-------- 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py index 125f46367..161fe3eb1 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -20,6 +20,7 @@ ) float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 +deploy_cfg = r'G:\projects\openmmlab\mmdeploy\configs\mmcls\classification_openvino_dynamic-224x224.py' # noqa: E501 model = dict( _delete_=True, @@ -33,6 +34,7 @@ # convert image from BGR to RGB to_rgb=True), architecture=_base_.model, + deploy_cfg=deploy_cfg, float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index 306b6cbd2..627ad3d3a 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -98,6 +98,41 @@ def reset_observer_and_fakequant_statistics(self, model): module.scale.data = torch.ones_like(module.scale) module.zero_point.data = torch.zeros_like(module.zero_point) + def _load(self, module, prefix, src_state_dict): + """Copies parameters and buffers from :attr:`src_state_dict` into this + module and its descendants. + + If the shape of the parameters and buffers between + :attr:`src_state_dict` and this module are different, we will reshape + the tensor shape of the parameters and buffers. + """ + for name, child in module._modules.items(): + if module is None: + continue + child_name = f'{prefix}{name}' + if isinstance(child, FakeQuantizeBase): + for name, param in child.named_parameters(): + param_name = f'{child_name}.{name}' + src_param = src_state_dict[param_name] + if src_param.shape == param.shape: + param.data.copy_(src_param) + else: + requirs_grad = param.requires_grad + param.requires_grad = False + param.resize_(src_param.shape) + param.requires_grad = requirs_grad + param.data.copy_(src_param) + for name, buffer in child.named_buffers(): + buffer_name = f'{child_name}.{name}' + src_buffer = src_state_dict[buffer_name] + if src_buffer.shape == buffer.shape: + buffer.data.copy_(src_buffer) + else: + buffer.resize_(src_buffer.shape) + buffer.data.copy_(src_buffer) + else: + self._load(child, f'{child_name}.', src_state_dict) + def sync_qparams(self, src_mode: str): """Sync all quantize parameters in different `forward_modes`. We could have more than one forward mode to generate graphs, each mode will @@ -108,46 +143,18 @@ def sync_qparams(self, src_mode: str): src_mode (str): The modes of forward method. Note: - `traverse()` function recursively traverses all module to sync + `_load()` method recursively traverses all module to sync quantized graph generated from different `forward_modes`. This is because We have different mode ('tensor', 'predict', 'loss') in OpenMMLab architecture which have different graph in some subtle ways, so we need to sync them here. """ - def traverse(module, prefix): - for name, child in module._modules.items(): - if module is None: - continue - child_name = f'{prefix}{name}' - if isinstance(child, FakeQuantizeBase): - for name, param in child.named_parameters(): - param_name = f'{child_name}.{name}' - src_param = src_state_dict[param_name] - if src_param.shape == param.shape: - param.data.copy_(src_param) - else: - requirs_grad = param.requires_grad - param.requires_grad = False - param.resize_(src_param.shape) - param.requires_grad = requirs_grad - param.data.copy_(src_param) - for name, buffer in child.named_buffers(): - buffer_name = f'{child_name}.{name}' - src_buffer = src_state_dict[buffer_name] - if src_buffer.shape == buffer.shape: - buffer.data.copy_(src_buffer) - else: - buffer.resize_(src_buffer.shape) - buffer.data.copy_(src_buffer) - else: - traverse(child, f'{child_name}.') - src_state_dict = self.qmodels[src_mode].state_dict() for mode in self.forward_modes: if mode == src_mode: continue - traverse(self.qmodels[mode], '') + self._load(self.qmodels[mode], '', src_state_dict) def _get_rewriter_context_in_mmdeploy(self, deploy_cfg): """Get rewriter context in mmdeploy according to the deploy related @@ -306,13 +313,23 @@ def post_process_for_mmdeploy(self, dummy_input: Tuple = (1, 3, 224, 224)): the backend's requirement. """ - quantized_state_dict = self.qmodels['tensor'].state_dict() + quantized_state_dict = self.qmodels['predict'].state_dict() fp32_model = self.architecture self.quantizer.convert_batchnorm2d(fp32_model) - observed_model = self.quantizer.prepare(fp32_model, {'mode': 'tensor'}) + + observed_model = self.quantizer.prepare(fp32_model, + {'mode': 'predict'}) if dummy_input is not None: - observed_model(torch.randn(dummy_input)) + # modify the tensor shape of parameters and buffers in + # observed_model + tensor_model = self.quantizer.prepare(fp32_model, + {'mode': 'tensor'}) + device = next(tensor_model.parameters()).device + dummy_input = torch.randn(dummy_input).to(device) + tensor_model(dummy_input, None, 'tensor') + src_state_dict = tensor_model.state_dict() + self._load(observed_model, '', src_state_dict) observed_model.load_state_dict(quantized_state_dict) From ca4fb4a8051b0c9a8982b476d1c2f2d5678d21f8 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 24 Mar 2023 15:29:47 +0800 Subject: [PATCH 4/7] fix bugs --- ...classification_openvino_dynamic-224x224.py | 30 ++++ ...penvino_resnet18_8xb32_in1k_calib32xb32.py | 8 +- .../quantization/mm_architecture.py | 140 ++++++++++-------- 3 files changed, 117 insertions(+), 61 deletions(-) create mode 100644 configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py diff --git a/configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py b/configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py new file mode 100644 index 000000000..d1fc673c5 --- /dev/null +++ b/configs/quantization/deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py @@ -0,0 +1,30 @@ +deploy_cfg = dict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_names=['input'], + output_names=['output'], + input_shape=None, + optimize=True, + dynamic_axes={ + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'output': { + 0: 'batch' + } + }), + backend_config=dict( + type='openvino', + model_inputs=[dict(opt_shapes=dict(input=[1, 3, 224, 224]))]), + codebase_config=dict(type='mmcls', task='Classification'), + function_record_to_pop=[ + 'mmcls.models.classifiers.ImageClassifier.forward', + 'mmcls.models.classifiers.BaseClassifier.forward' + ], +) diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py index 161fe3eb1..93e3897bc 100644 --- a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py @@ -1,4 +1,7 @@ -_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] +_base_ = [ + 'mmcls::resnet/resnet18_8xb32_in1k.py', + '../deploy_cfgs/mmcls/classification_openvino_dynamic-224x224.py' +] train_dataloader = dict(batch_size=32) @@ -20,7 +23,6 @@ ) float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 -deploy_cfg = r'G:\projects\openmmlab\mmdeploy\configs\mmcls\classification_openvino_dynamic-224x224.py' # noqa: E501 model = dict( _delete_=True, @@ -34,7 +36,7 @@ # convert image from BGR to RGB to_rgb=True), architecture=_base_.model, - deploy_cfg=deploy_cfg, + deploy_cfg=_base_.deploy_cfg, float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index 627ad3d3a..4207ff20e 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +from mmengine.config import Config from mmengine.model import MMDistributedDataParallel from mmengine.runner import load_checkpoint from mmengine.structures import BaseDataElement @@ -98,41 +99,6 @@ def reset_observer_and_fakequant_statistics(self, model): module.scale.data = torch.ones_like(module.scale) module.zero_point.data = torch.zeros_like(module.zero_point) - def _load(self, module, prefix, src_state_dict): - """Copies parameters and buffers from :attr:`src_state_dict` into this - module and its descendants. - - If the shape of the parameters and buffers between - :attr:`src_state_dict` and this module are different, we will reshape - the tensor shape of the parameters and buffers. - """ - for name, child in module._modules.items(): - if module is None: - continue - child_name = f'{prefix}{name}' - if isinstance(child, FakeQuantizeBase): - for name, param in child.named_parameters(): - param_name = f'{child_name}.{name}' - src_param = src_state_dict[param_name] - if src_param.shape == param.shape: - param.data.copy_(src_param) - else: - requirs_grad = param.requires_grad - param.requires_grad = False - param.resize_(src_param.shape) - param.requires_grad = requirs_grad - param.data.copy_(src_param) - for name, buffer in child.named_buffers(): - buffer_name = f'{child_name}.{name}' - src_buffer = src_state_dict[buffer_name] - if src_buffer.shape == buffer.shape: - buffer.data.copy_(src_buffer) - else: - buffer.resize_(src_buffer.shape) - buffer.data.copy_(src_buffer) - else: - self._load(child, f'{child_name}.', src_state_dict) - def sync_qparams(self, src_mode: str): """Sync all quantize parameters in different `forward_modes`. We could have more than one forward mode to generate graphs, each mode will @@ -143,27 +109,64 @@ def sync_qparams(self, src_mode: str): src_mode (str): The modes of forward method. Note: - `_load()` method recursively traverses all module to sync + `traverse()` method recursively traverses all modules to sync quantized graph generated from different `forward_modes`. This is because We have different mode ('tensor', 'predict', 'loss') in OpenMMLab architecture which have different graph in some subtle ways, so we need to sync them here. """ + def traverse(module, prefix): + for name, child in module._modules.items(): + if module is None: + continue + child_name = f'{prefix}{name}' + if isinstance(child, FakeQuantizeBase): + for name, param in child.named_parameters(): + param_name = f'{child_name}.{name}' + src_param = src_state_dict[param_name] + if src_param.shape == param.shape: + param.data.copy_(src_param) + else: + requirs_grad = param.requires_grad + param.requires_grad = False + param.resize_(src_param.shape) + param.requires_grad = requirs_grad + param.data.copy_(src_param) + for name, buffer in child.named_buffers(): + buffer_name = f'{child_name}.{name}' + src_buffer = src_state_dict[buffer_name] + if src_buffer.shape == buffer.shape: + buffer.data.copy_(src_buffer) + else: + buffer.resize_(src_buffer.shape) + buffer.data.copy_(src_buffer) + else: + traverse(child, f'{child_name}.') + src_state_dict = self.qmodels[src_mode].state_dict() for mode in self.forward_modes: if mode == src_mode: continue - self._load(self.qmodels[mode], '', src_state_dict) + traverse(self.qmodels[mode], '') def _get_rewriter_context_in_mmdeploy(self, deploy_cfg): """Get rewriter context in mmdeploy according to the deploy related config.""" from mmdeploy.apis.onnx.passes import optimize_onnx + from mmdeploy.codebase import import_codebase from mmdeploy.core import RewriterContext - from mmdeploy.utils import (IR, Backend, get_backend, get_dynamic_axes, - get_ir_config, get_onnx_config, - load_config) + from mmdeploy.utils import (IR, Backend, get_backend, get_codebase, + get_dynamic_axes, get_ir_config, + get_onnx_config) + from mmdeploy.utils.config_utils import get_codebase_external_module + + if isinstance(deploy_cfg, str): + deploy_cfg = Config.fromfile(deploy_cfg) + + codebase = get_codebase(deploy_cfg) + custom_module_list = get_codebase_external_module(deploy_cfg) + import_codebase(codebase, custom_module_list) def _add_or_update(cfg: dict, key: str, val: Any): if key in cfg and isinstance(cfg[key], dict) and isinstance( @@ -172,11 +175,8 @@ def _add_or_update(cfg: dict, key: str, val: Any): else: cfg[key] = val - if isinstance(deploy_cfg, str): - deploy_cfg, = load_config(deploy_cfg) context_info = dict() deploy_cfg = copy.deepcopy(deploy_cfg) - context_info['deploy_cfg'] = deploy_cfg backend = get_backend(deploy_cfg).value @@ -226,7 +226,42 @@ def _add_or_update(cfg: dict, key: str, val: Any): onnx_custom_passes = optimize_onnx if optimize else None context_info['onnx_custom_passes'] = onnx_custom_passes - return RewriterContext(**context_info) + rewriter_context = RewriterContext(**context_info) + + # Hard codes to delete user-specific rewriters from + # `RewriterContext._rewriter_manager`. + # We use the model which is rewritten by mmdeploy to build quantized + # models. However not all the modules, functions and symbolic rewritten + # by mmdeploy need to be rewritten in mmrazor. For example, mmdeploy + # rewrite `mmcls.models.classifiers.ImageClassifier.forward` and + # `mmcls.models.classifiers.BaseClassifier.forward` for deployment. + # But they can't be rewritten by mmrazor as ptq and qat are done in + # mmrazor. So to ensure ptq and qat proceed normally, we have to remove + # these record from `RewriterContext._rewriter_manager`. + + # We have to deepcopy rewriter_context here to delete records safely. + rewriter_context = copy.deepcopy(rewriter_context) + module_record_to_pop = deploy_cfg.get('module_record_to_pop', []) + function_record_to_pop = deploy_cfg.get('function_record_to_pop', []) + symbolic_record_to_pop = deploy_cfg.get('symbolic_record_to_pop', []) + for record in module_record_to_pop: + records = rewriter_context._rewriter_manager.module_rewriter.\ + _registry._rewrite_records + if record in records: + records.pop(record) + for record in function_record_to_pop: + records = rewriter_context._rewriter_manager.function_rewriter.\ + _registry._rewrite_records + if record in records: + records.pop(record) + + for record in symbolic_record_to_pop: + records = rewriter_context._rewriter_manager.symbolic_rewriter.\ + _registry._rewrite_records + if record in records: + records.pop(record) + + return rewriter_context def _build_qmodels(self, model: BaseModel): """Build quantized models from the given model. @@ -260,6 +295,7 @@ def _build_qmodels(self, model: BaseModel): qmodels = nn.ModuleDict() for mode in self.forward_modes: concrete_args = {'mode': mode} + # todo: support qat. with rewriter_context: observed_module = self.quantizer.prepare(model, concrete_args) qmodels[mode] = observed_module @@ -302,7 +338,7 @@ def calibrate_step(self, data: Union[Dict, Tuple, List]): data = self.data_preprocessor(data, False) return self._run_forward(data, mode='predict') - def post_process_for_mmdeploy(self, dummy_input: Tuple = (1, 3, 224, 224)): + def get_deploy_model(self): """Prepare for deploy to the backend with mmdeploy, which will be used in mmdeploy, and usually includes as follows: @@ -317,19 +353,7 @@ def post_process_for_mmdeploy(self, dummy_input: Tuple = (1, 3, 224, 224)): fp32_model = self.architecture self.quantizer.convert_batchnorm2d(fp32_model) - observed_model = self.quantizer.prepare(fp32_model, - {'mode': 'predict'}) - - if dummy_input is not None: - # modify the tensor shape of parameters and buffers in - # observed_model - tensor_model = self.quantizer.prepare(fp32_model, - {'mode': 'tensor'}) - device = next(tensor_model.parameters()).device - dummy_input = torch.randn(dummy_input).to(device) - tensor_model(dummy_input, None, 'tensor') - src_state_dict = tensor_model.state_dict() - self._load(observed_model, '', src_state_dict) + observed_model = self.quantizer.prepare(fp32_model) observed_model.load_state_dict(quantized_state_dict) From dc90f324a01ed16552e14f1f8ab47c8ff9885d9d Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 24 Mar 2023 21:54:30 +0800 Subject: [PATCH 5/7] add det config --- .../detection_openvino_dynamic-800x1344.py | 50 +++++++++++++++++++ ...openvino_retina_r50_1x_coco_calib32xb32.py | 8 +-- .../quantization/mm_architecture.py | 38 ++++++++++++-- 3 files changed, 90 insertions(+), 6 deletions(-) create mode 100644 configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py diff --git a/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py new file mode 100644 index 000000000..f7fc5d064 --- /dev/null +++ b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py @@ -0,0 +1,50 @@ +deploy_cfg = dict( + onnx_config=dict( + type='onnx', + export_params=True, + keep_initializers_as_inputs=False, + opset_version=11, + save_file='end2end.onnx', + input_shape=None, + input_names=['input'], + output_names=['dets', 'labels'], + optimize=True, + dynamic_axes={ + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'dets': { + 0: 'batch', + 1: 'num_dets', + }, + 'labels': { + 0: 'batch', + 1: 'num_dets', + }, + }), + backend_config=dict( + type='openvino', + model_inputs=[dict(opt_shapes=dict(input=[1, 3, 800, 1344]))]), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + model_type='end2end', + post_processing=dict( + score_threshold=0.05, + confidence_threshold=0.005, # for YOLOv3 + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, + )), + function_record_to_pop=[ + 'mmdet.models.detectors.single_stage.SingleStageDetector.forward', + 'mmdet.models.detectors.two_stage.TwoStageDetector.forward', + 'mmdet.models.detectors.single_stage_instance_seg.' + 'SingleStageInstanceSegmentor.forward', + 'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.' + 'predict_by_feat' + ]) diff --git a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py index 31d42bd83..109a7ee04 100644 --- a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py +++ b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py @@ -1,4 +1,7 @@ -_base_ = ['mmdet::retinanet/retinanet_r50_fpn_1x_coco.py'] +_base_ = [ + 'mmdet::retinanet/retinanet_r50_fpn_1x_coco.py', + '../deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py' +] train_dataloader = dict(batch_size=32) @@ -10,7 +13,6 @@ retina = _base_.model float_checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_r50_fpn_1x_coco/retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth' # noqa: E501 -deploy_cfg = r'G:\projects\openmmlab\mmdeploy\configs\mmdet\detection\detection_openvino_dynamic-800x1344-quantize.py', # noqa: E501 global_qconfig = dict( w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), @@ -33,7 +35,7 @@ bgr_to_rgb=True, pad_size_divisor=32), architecture=retina, - deploy_cfg=deploy_cfg, + deploy_cfg=_base_.deploy_cfg, float_checkpoint=float_checkpoint, quantizer=dict( type='mmrazor.OpenVINOQuantizer', diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index 4207ff20e..d4081d96e 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -72,6 +72,8 @@ def __init__(self, self.quantizer = MODELS.build(quantizer) self.input_shapes = input_shapes self.forward_modes = forward_modes + if isinstance(deploy_cfg, str): + deploy_cfg = Config.fromfile(deploy_cfg) self.deploy_cfg = deploy_cfg # Replace syncbn and _BatchNormXd (in mmengine) with batchnorm2d @@ -161,9 +163,6 @@ def _get_rewriter_context_in_mmdeploy(self, deploy_cfg): get_onnx_config) from mmdeploy.utils.config_utils import get_codebase_external_module - if isinstance(deploy_cfg, str): - deploy_cfg = Config.fromfile(deploy_cfg) - codebase = get_codebase(deploy_cfg) custom_module_list = get_codebase_external_module(deploy_cfg) import_codebase(codebase, custom_module_list) @@ -292,6 +291,32 @@ def _build_qmodels(self, model: BaseModel): rewriter_context = self._get_rewriter_context_in_mmdeploy( self.deploy_cfg) + # module_record_to_pop = self.deploy_cfg.get('module_record_to_pop', + # []) + # function_record_to_pop = self.deploy_cfg.get( + # 'function_record_to_pop', []) + # symbolic_record_to_pop = self.deploy_cfg.get( + # 'symbolic_record_to_pop', []) + # module_record_backup = {} + # function_record_backup = {} + # symbolic_record_backup = {} + # for record in module_record_to_pop: + # records = rewriter_context._rewriter_manager.module_rewriter. \ + # _registry._rewrite_records + # if record in records: + # module_record_backup[record] = records.pop(record) + # for record in function_record_to_pop: + # records = rewriter_context._rewriter_manager.function_rewriter. \ + # _registry._rewrite_records + # if record in records: + # function_record_backup[record] = records.pop(record) + # + # for record in symbolic_record_to_pop: + # records = rewriter_context._rewriter_manager.symbolic_rewriter. \ + # _registry._rewrite_records + # if record in records: + # symbolic_record_backup[record] = records.pop(record) + qmodels = nn.ModuleDict() for mode in self.forward_modes: concrete_args = {'mode': mode} @@ -300,6 +325,13 @@ def _build_qmodels(self, model: BaseModel): observed_module = self.quantizer.prepare(model, concrete_args) qmodels[mode] = observed_module + # rewriter_context._rewriter_manager.module_rewriter. \ + # _registry._rewrite_records.update(module_record_backup) + # rewriter_context._rewriter_manager.function_rewriter. \ + # _registry._rewrite_records.update(function_record_backup) + # rewriter_context._rewriter_manager.symbolic_rewriter. \ + # _registry._rewrite_records.update(symbolic_record_backup) + # data_samples can not be None in detectors during prediction. # But we need to make the dummy prediction in _build_qmodels. # It is more convenient to use `tensor` mode. From db4e966172760ba45198c55b97a8454aa8584533 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Tue, 28 Mar 2023 20:09:22 +0800 Subject: [PATCH 6/7] replace deepcopy --- .../detection_openvino_dynamic-800x1344.py | 10 +- .../quantization/mm_architecture.py | 93 ++++++------------- 2 files changed, 30 insertions(+), 73 deletions(-) diff --git a/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py index f7fc5d064..3417f9476 100644 --- a/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py +++ b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py @@ -39,12 +39,4 @@ pre_top_k=5000, keep_top_k=100, background_label_id=-1, - )), - function_record_to_pop=[ - 'mmdet.models.detectors.single_stage.SingleStageDetector.forward', - 'mmdet.models.detectors.two_stage.TwoStageDetector.forward', - 'mmdet.models.detectors.single_stage_instance_seg.' - 'SingleStageInstanceSegmentor.forward', - 'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.' - 'predict_by_feat' - ]) + ))) diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index d4081d96e..41ea0e2be 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -225,42 +225,32 @@ def _add_or_update(cfg: dict, key: str, val: Any): onnx_custom_passes = optimize_onnx if optimize else None context_info['onnx_custom_passes'] = onnx_custom_passes - rewriter_context = RewriterContext(**context_info) - - # Hard codes to delete user-specific rewriters from - # `RewriterContext._rewriter_manager`. - # We use the model which is rewritten by mmdeploy to build quantized - # models. However not all the modules, functions and symbolic rewritten - # by mmdeploy need to be rewritten in mmrazor. For example, mmdeploy - # rewrite `mmcls.models.classifiers.ImageClassifier.forward` and - # `mmcls.models.classifiers.BaseClassifier.forward` for deployment. - # But they can't be rewritten by mmrazor as ptq and qat are done in - # mmrazor. So to ensure ptq and qat proceed normally, we have to remove - # these record from `RewriterContext._rewriter_manager`. - - # We have to deepcopy rewriter_context here to delete records safely. - rewriter_context = copy.deepcopy(rewriter_context) - module_record_to_pop = deploy_cfg.get('module_record_to_pop', []) - function_record_to_pop = deploy_cfg.get('function_record_to_pop', []) - symbolic_record_to_pop = deploy_cfg.get('symbolic_record_to_pop', []) - for record in module_record_to_pop: - records = rewriter_context._rewriter_manager.module_rewriter.\ - _registry._rewrite_records - if record in records: - records.pop(record) - for record in function_record_to_pop: - records = rewriter_context._rewriter_manager.function_rewriter.\ - _registry._rewrite_records - if record in records: - records.pop(record) + return RewriterContext(**context_info) + + def _pop_function_record_in_rewriter_context(self, rewriter_context): + """Delete user-specific rewriters from + `RewriterContext._rewriter_manager`. We use the model which is + rewritten by mmdeploy to build quantized models. However not all the + functions rewritten by mmdeploy need to be rewritten in mmrazor. For + example, mmdeploy rewrite + `mmcls.models.classifiers.ImageClassifier.forward` and + `mmcls.models.classifiers.BaseClassifier.forward` for deployment. But + they can't be rewritten by mmrazor as ptq and qat are done in mmrazor. + So to ensure ptq and qat proceed normally, we have to remove these + record from `RewriterContext._rewriter_manager`. - for record in symbolic_record_to_pop: - records = rewriter_context._rewriter_manager.symbolic_rewriter.\ + Args: + rewriter_context (RewriterContext): The RewriterContext used in + mmdeploy. + """ + skipped_methods = getattr(self.quantizer.tracer, 'skipped_methods', []) + function_record_backup = {} + for record in skipped_methods: + records = rewriter_context._rewriter_manager.function_rewriter. \ _registry._rewrite_records if record in records: - records.pop(record) - - return rewriter_context + function_record_backup[record] = records.pop(record) + return function_record_backup def _build_qmodels(self, model: BaseModel): """Build quantized models from the given model. @@ -291,31 +281,9 @@ def _build_qmodels(self, model: BaseModel): rewriter_context = self._get_rewriter_context_in_mmdeploy( self.deploy_cfg) - # module_record_to_pop = self.deploy_cfg.get('module_record_to_pop', - # []) - # function_record_to_pop = self.deploy_cfg.get( - # 'function_record_to_pop', []) - # symbolic_record_to_pop = self.deploy_cfg.get( - # 'symbolic_record_to_pop', []) - # module_record_backup = {} - # function_record_backup = {} - # symbolic_record_backup = {} - # for record in module_record_to_pop: - # records = rewriter_context._rewriter_manager.module_rewriter. \ - # _registry._rewrite_records - # if record in records: - # module_record_backup[record] = records.pop(record) - # for record in function_record_to_pop: - # records = rewriter_context._rewriter_manager.function_rewriter. \ - # _registry._rewrite_records - # if record in records: - # function_record_backup[record] = records.pop(record) - # - # for record in symbolic_record_to_pop: - # records = rewriter_context._rewriter_manager.symbolic_rewriter. \ - # _registry._rewrite_records - # if record in records: - # symbolic_record_backup[record] = records.pop(record) + # Pop function records in `quantizer.tracer.skipped_method` temporarily + function_record_backup = self._pop_function_record_in_rewriter_context( + rewriter_context) qmodels = nn.ModuleDict() for mode in self.forward_modes: @@ -325,12 +293,9 @@ def _build_qmodels(self, model: BaseModel): observed_module = self.quantizer.prepare(model, concrete_args) qmodels[mode] = observed_module - # rewriter_context._rewriter_manager.module_rewriter. \ - # _registry._rewrite_records.update(module_record_backup) - # rewriter_context._rewriter_manager.function_rewriter. \ - # _registry._rewrite_records.update(function_record_backup) - # rewriter_context._rewriter_manager.symbolic_rewriter. \ - # _registry._rewrite_records.update(symbolic_record_backup) + # Add these popped function records back. + rewriter_context._rewriter_manager.function_rewriter. \ + _registry._rewrite_records.update(function_record_backup) # data_samples can not be None in detectors during prediction. # But we need to make the dummy prediction in _build_qmodels. From a7dcc139a2bc0bb62db9601323cfd4b4f61052d4 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Tue, 28 Mar 2023 20:16:07 +0800 Subject: [PATCH 7/7] pop detectors' forward --- .../mmdet/detection_openvino_dynamic-800x1344.py | 8 +++++++- mmrazor/models/algorithms/quantization/mm_architecture.py | 5 ++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py index 3417f9476..f8122ecaa 100644 --- a/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py +++ b/configs/quantization/deploy_cfgs/mmdet/detection_openvino_dynamic-800x1344.py @@ -39,4 +39,10 @@ pre_top_k=5000, keep_top_k=100, background_label_id=-1, - ))) + )), + function_record_to_pop=[ + 'mmdet.models.detectors.single_stage.SingleStageDetector.forward', + 'mmdet.models.detectors.two_stage.TwoStageDetector.forward', + 'mmdet.models.detectors.single_stage_instance_seg.' + 'SingleStageInstanceSegmentor.forward' + ]) diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index 41ea0e2be..53dd0f6cf 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -244,8 +244,11 @@ def _pop_function_record_in_rewriter_context(self, rewriter_context): mmdeploy. """ skipped_methods = getattr(self.quantizer.tracer, 'skipped_methods', []) + function_record_to_pop = self.deploy_cfg.get('function_record_to_pop', + []) + function_record_to_pop.extend(skipped_methods) function_record_backup = {} - for record in skipped_methods: + for record in function_record_to_pop: records = rewriter_context._rewriter_manager.function_rewriter. \ _registry._rewrite_records if record in records: