Skip to content

Commit

Permalink
[Feature] Using rewriter in mmrazor when building qmodels. (#490)
Browse files Browse the repository at this point in the history
* add rewriter

* add deploy_cfg arg

* modify post_process_for_mmdeploy

* fix bugs

* add det config

* replace deepcopy

* pop detectors' forward
  • Loading branch information
HIT-cwh committed Mar 29, 2023
1 parent 639c291 commit 2dbb633
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,5 @@
'mmdet.models.detectors.single_stage.SingleStageDetector.forward',
'mmdet.models.detectors.two_stage.TwoStageDetector.forward',
'mmdet.models.detectors.single_stage_instance_seg.'
'SingleStageInstanceSegmentor.forward',
'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.'
'predict_by_feat'
'SingleStageInstanceSegmentor.forward'
])
96 changes: 32 additions & 64 deletions mmrazor/models/algorithms/quantization/mm_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,42 +225,35 @@ def _add_or_update(cfg: dict, key: str, val: Any):
onnx_custom_passes = optimize_onnx if optimize else None
context_info['onnx_custom_passes'] = onnx_custom_passes

rewriter_context = RewriterContext(**context_info)

# Hard codes to delete user-specific rewriters from
# `RewriterContext._rewriter_manager`.
# We use the model which is rewritten by mmdeploy to build quantized
# models. However not all the modules, functions and symbolic rewritten
# by mmdeploy need to be rewritten in mmrazor. For example, mmdeploy
# rewrite `mmcls.models.classifiers.ImageClassifier.forward` and
# `mmcls.models.classifiers.BaseClassifier.forward` for deployment.
# But they can't be rewritten by mmrazor as ptq and qat are done in
# mmrazor. So to ensure ptq and qat proceed normally, we have to remove
# these record from `RewriterContext._rewriter_manager`.

# We have to deepcopy rewriter_context here to delete records safely.
rewriter_context = copy.deepcopy(rewriter_context)
module_record_to_pop = deploy_cfg.get('module_record_to_pop', [])
function_record_to_pop = deploy_cfg.get('function_record_to_pop', [])
symbolic_record_to_pop = deploy_cfg.get('symbolic_record_to_pop', [])
for record in module_record_to_pop:
records = rewriter_context._rewriter_manager.module_rewriter.\
_registry._rewrite_records
if record in records:
records.pop(record)
for record in function_record_to_pop:
records = rewriter_context._rewriter_manager.function_rewriter.\
_registry._rewrite_records
if record in records:
records.pop(record)
return RewriterContext(**context_info)

def _pop_function_record_in_rewriter_context(self, rewriter_context):
"""Delete user-specific rewriters from
`RewriterContext._rewriter_manager`. We use the model which is
rewritten by mmdeploy to build quantized models. However not all the
functions rewritten by mmdeploy need to be rewritten in mmrazor. For
example, mmdeploy rewrite
`mmcls.models.classifiers.ImageClassifier.forward` and
`mmcls.models.classifiers.BaseClassifier.forward` for deployment. But
they can't be rewritten by mmrazor as ptq and qat are done in mmrazor.
So to ensure ptq and qat proceed normally, we have to remove these
record from `RewriterContext._rewriter_manager`.
for record in symbolic_record_to_pop:
records = rewriter_context._rewriter_manager.symbolic_rewriter.\
Args:
rewriter_context (RewriterContext): The RewriterContext used in
mmdeploy.
"""
skipped_methods = getattr(self.quantizer.tracer, 'skipped_methods', [])
function_record_to_pop = self.deploy_cfg.get('function_record_to_pop',
[])
function_record_to_pop.extend(skipped_methods)
function_record_backup = {}
for record in function_record_to_pop:
records = rewriter_context._rewriter_manager.function_rewriter. \
_registry._rewrite_records
if record in records:
records.pop(record)

return rewriter_context
function_record_backup[record] = records.pop(record)
return function_record_backup

def _build_qmodels(self, model: BaseModel):
"""Build quantized models from the given model.
Expand Down Expand Up @@ -291,31 +284,9 @@ def _build_qmodels(self, model: BaseModel):
rewriter_context = self._get_rewriter_context_in_mmdeploy(
self.deploy_cfg)

# module_record_to_pop = self.deploy_cfg.get('module_record_to_pop',
# [])
# function_record_to_pop = self.deploy_cfg.get(
# 'function_record_to_pop', [])
# symbolic_record_to_pop = self.deploy_cfg.get(
# 'symbolic_record_to_pop', [])
# module_record_backup = {}
# function_record_backup = {}
# symbolic_record_backup = {}
# for record in module_record_to_pop:
# records = rewriter_context._rewriter_manager.module_rewriter. \
# _registry._rewrite_records
# if record in records:
# module_record_backup[record] = records.pop(record)
# for record in function_record_to_pop:
# records = rewriter_context._rewriter_manager.function_rewriter. \
# _registry._rewrite_records
# if record in records:
# function_record_backup[record] = records.pop(record)
#
# for record in symbolic_record_to_pop:
# records = rewriter_context._rewriter_manager.symbolic_rewriter. \
# _registry._rewrite_records
# if record in records:
# symbolic_record_backup[record] = records.pop(record)
# Pop function records in `quantizer.tracer.skipped_method` temporarily
function_record_backup = self._pop_function_record_in_rewriter_context(
rewriter_context)

qmodels = nn.ModuleDict()
for mode in self.forward_modes:
Expand All @@ -325,12 +296,9 @@ def _build_qmodels(self, model: BaseModel):
observed_module = self.quantizer.prepare(model, concrete_args)
qmodels[mode] = observed_module

# rewriter_context._rewriter_manager.module_rewriter. \
# _registry._rewrite_records.update(module_record_backup)
# rewriter_context._rewriter_manager.function_rewriter. \
# _registry._rewrite_records.update(function_record_backup)
# rewriter_context._rewriter_manager.symbolic_rewriter. \
# _registry._rewrite_records.update(symbolic_record_backup)
# Add these popped function records back.
rewriter_context._rewriter_manager.function_rewriter. \
_registry._rewrite_records.update(function_record_backup)

# data_samples can not be None in detectors during prediction.
# But we need to make the dummy prediction in _build_qmodels.
Expand Down

0 comments on commit 2dbb633

Please sign in to comment.