Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Using rewriter in mmrazor when building qmodels. #490

Merged
merged 8 commits into from
Mar 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,5 @@
'mmdet.models.detectors.single_stage.SingleStageDetector.forward',
'mmdet.models.detectors.two_stage.TwoStageDetector.forward',
'mmdet.models.detectors.single_stage_instance_seg.'
'SingleStageInstanceSegmentor.forward',
'mmdet.models.dense_heads.base_dense_head.BaseDenseHead.'
'predict_by_feat'
'SingleStageInstanceSegmentor.forward'
])
96 changes: 32 additions & 64 deletions mmrazor/models/algorithms/quantization/mm_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,42 +225,35 @@ def _add_or_update(cfg: dict, key: str, val: Any):
onnx_custom_passes = optimize_onnx if optimize else None
context_info['onnx_custom_passes'] = onnx_custom_passes

rewriter_context = RewriterContext(**context_info)

# Hard codes to delete user-specific rewriters from
# `RewriterContext._rewriter_manager`.
# We use the model which is rewritten by mmdeploy to build quantized
# models. However not all the modules, functions and symbolic rewritten
# by mmdeploy need to be rewritten in mmrazor. For example, mmdeploy
# rewrite `mmcls.models.classifiers.ImageClassifier.forward` and
# `mmcls.models.classifiers.BaseClassifier.forward` for deployment.
# But they can't be rewritten by mmrazor as ptq and qat are done in
# mmrazor. So to ensure ptq and qat proceed normally, we have to remove
# these record from `RewriterContext._rewriter_manager`.

# We have to deepcopy rewriter_context here to delete records safely.
rewriter_context = copy.deepcopy(rewriter_context)
module_record_to_pop = deploy_cfg.get('module_record_to_pop', [])
function_record_to_pop = deploy_cfg.get('function_record_to_pop', [])
symbolic_record_to_pop = deploy_cfg.get('symbolic_record_to_pop', [])
for record in module_record_to_pop:
records = rewriter_context._rewriter_manager.module_rewriter.\
_registry._rewrite_records
if record in records:
records.pop(record)
for record in function_record_to_pop:
records = rewriter_context._rewriter_manager.function_rewriter.\
_registry._rewrite_records
if record in records:
records.pop(record)
return RewriterContext(**context_info)

def _pop_function_record_in_rewriter_context(self, rewriter_context):
"""Delete user-specific rewriters from
`RewriterContext._rewriter_manager`. We use the model which is
rewritten by mmdeploy to build quantized models. However not all the
functions rewritten by mmdeploy need to be rewritten in mmrazor. For
example, mmdeploy rewrite
`mmcls.models.classifiers.ImageClassifier.forward` and
`mmcls.models.classifiers.BaseClassifier.forward` for deployment. But
they can't be rewritten by mmrazor as ptq and qat are done in mmrazor.
So to ensure ptq and qat proceed normally, we have to remove these
record from `RewriterContext._rewriter_manager`.

for record in symbolic_record_to_pop:
records = rewriter_context._rewriter_manager.symbolic_rewriter.\
Args:
rewriter_context (RewriterContext): The RewriterContext used in
mmdeploy.
"""
skipped_methods = getattr(self.quantizer.tracer, 'skipped_methods', [])
function_record_to_pop = self.deploy_cfg.get('function_record_to_pop',
[])
function_record_to_pop.extend(skipped_methods)
function_record_backup = {}
for record in function_record_to_pop:
records = rewriter_context._rewriter_manager.function_rewriter. \
_registry._rewrite_records
if record in records:
records.pop(record)

return rewriter_context
function_record_backup[record] = records.pop(record)
return function_record_backup

def _build_qmodels(self, model: BaseModel):
"""Build quantized models from the given model.
Expand Down Expand Up @@ -291,31 +284,9 @@ def _build_qmodels(self, model: BaseModel):
rewriter_context = self._get_rewriter_context_in_mmdeploy(
self.deploy_cfg)

# module_record_to_pop = self.deploy_cfg.get('module_record_to_pop',
# [])
# function_record_to_pop = self.deploy_cfg.get(
# 'function_record_to_pop', [])
# symbolic_record_to_pop = self.deploy_cfg.get(
# 'symbolic_record_to_pop', [])
# module_record_backup = {}
# function_record_backup = {}
# symbolic_record_backup = {}
# for record in module_record_to_pop:
# records = rewriter_context._rewriter_manager.module_rewriter. \
# _registry._rewrite_records
# if record in records:
# module_record_backup[record] = records.pop(record)
# for record in function_record_to_pop:
# records = rewriter_context._rewriter_manager.function_rewriter. \
# _registry._rewrite_records
# if record in records:
# function_record_backup[record] = records.pop(record)
#
# for record in symbolic_record_to_pop:
# records = rewriter_context._rewriter_manager.symbolic_rewriter. \
# _registry._rewrite_records
# if record in records:
# symbolic_record_backup[record] = records.pop(record)
# Pop function records in `quantizer.tracer.skipped_method` temporarily
function_record_backup = self._pop_function_record_in_rewriter_context(
rewriter_context)

qmodels = nn.ModuleDict()
for mode in self.forward_modes:
Expand All @@ -325,12 +296,9 @@ def _build_qmodels(self, model: BaseModel):
observed_module = self.quantizer.prepare(model, concrete_args)
qmodels[mode] = observed_module

# rewriter_context._rewriter_manager.module_rewriter. \
# _registry._rewrite_records.update(module_record_backup)
# rewriter_context._rewriter_manager.function_rewriter. \
# _registry._rewrite_records.update(function_record_backup)
# rewriter_context._rewriter_manager.symbolic_rewriter. \
# _registry._rewrite_records.update(symbolic_record_backup)
# Add these popped function records back.
rewriter_context._rewriter_manager.function_rewriter. \
_registry._rewrite_records.update(function_record_backup)

# data_samples can not be None in detectors during prediction.
# But we need to make the dummy prediction in _build_qmodels.
Expand Down