From 8198f27d0ddb2780be91568eb8b9a1b8ffcbff57 Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Mon, 17 Apr 2023 15:01:19 +0800
Subject: [PATCH] [BUG] Fix quantization loop (#507)

* fix quantization loop

* fix quant loop

* fix quant loop

* fix qat configs

* [Bug] Fix ci coverage setting (#508)

fix ci coverage

* [Bug] Fix codecov (#509)

* remove codecov in requirements

* try to fix ci

* del adaround loss

* add freeze_bn_begin to lsq

* delete useless codes

---------

Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com>
---
 ... lsq_openvino_resnet18_8xb32_100e_in1k.py} |  8 ++-
 .../lsq_openvino_resnet18_8xb32_10e_in1k.py   | 63 +++++++++++++++++++
 .../qat_openvino_resnet18_10e_8xb32_in1k.py   | 62 ++++++++++++++++++
 mmrazor/engine/runner/quantization_loops.py   | 61 ++++++++++++------
 4 files changed, 172 insertions(+), 22 deletions(-)
 rename configs/quantization/qat/{lsq_openvino_resnet18_8xb32_in1k.py => lsq_openvino_resnet18_8xb32_100e_in1k.py} (90%)
 create mode 100644 configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py
 create mode 100644 configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py

diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py
similarity index 90%
rename from configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py
rename to configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py
index 0b79232f8..00e424141 100644
--- a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py
+++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py
@@ -59,6 +59,10 @@
     _delete_=True,
     type='mmrazor.LSQEpochBasedLoop',
     max_epochs=100,
-    val_interval=1)
+    val_interval=1,
+    freeze_bn_begin=1)
 val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
-test_cfg = val_cfg
+
+# Make sure buffers such as min_val/max_val in the saved checkpoint are the
+# same across different ranks.
+default_hooks = dict(sync=dict(type='SyncBuffersHook'))
diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py
new file mode 100644
index 000000000..f931ddaf5
--- /dev/null
+++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py
@@ -0,0 +1,63 @@
+_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']
+
+resnet = _base_.model
+float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth'  # noqa: E501
+
+global_qconfig = dict(
+    w_observer=dict(type='mmrazor.LSQPerChannelObserver'),
+    a_observer=dict(type='mmrazor.LSQObserver'),
+    w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
+    a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
+    w_qscheme=dict(
+        qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True),
+    a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True),
+)
+
+model = dict(
+    _delete_=True,
+    _scope_='mmrazor',
+    type='MMArchitectureQuant',
+    data_preprocessor=dict(
+        type='mmcls.ClsDataPreprocessor',
+        num_classes=1000,
+        # RGB format normalization parameters
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        # convert image from BGR to RGB
+        to_rgb=True),
+    architecture=resnet,
+    float_checkpoint=float_checkpoint,
+    quantizer=dict(
+        type='mmrazor.OpenVINOQuantizer',
+        global_qconfig=global_qconfig,
+        tracer=dict(
+            type='mmrazor.CustomTracer',
+            skipped_methods=[
+                'mmcls.models.heads.ClsHead._get_loss',
+                'mmcls.models.heads.ClsHead._get_predictions'
+            ])))
+
+optim_wrapper = dict(
+    optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001))
+
+# learning policy
+param_scheduler = dict(
+    _delete_=True, type='ConstantLR', factor=1.0, by_epoch=True)
+
+model_wrapper_cfg = dict(
+    type='mmrazor.MMArchitectureQuantDDP',
+    broadcast_buffers=False,
+    find_unused_parameters=True)
+
+# train, val, test setting
+train_cfg = dict(
+    _delete_=True,
+    type='mmrazor.LSQEpochBasedLoop',
+    max_epochs=10,
+    val_interval=1,
+    freeze_bn_begin=1)
+val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
+
+# Make sure buffers such as min_val/max_val in the saved checkpoint are the
+# same across different ranks.
+default_hooks = dict(sync=dict(type='SyncBuffersHook'))
diff --git a/configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py b/configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py
new file mode 100644
index 000000000..261af7abb
--- /dev/null
+++ b/configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py
@@ -0,0 +1,62 @@
+_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']
+
+resnet = _base_.model
+float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth'  # noqa: E501
+
+global_qconfig = dict(
+    w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
+    a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
+    w_fake_quant=dict(type='mmrazor.FakeQuantize'),
+    a_fake_quant=dict(type='mmrazor.FakeQuantize'),
+    w_qscheme=dict(
+        qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True),
+    a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True),
+)
+
+model = dict(
+    _delete_=True,
+    _scope_='mmrazor',
+    type='MMArchitectureQuant',
+    data_preprocessor=dict(
+        type='mmcls.ClsDataPreprocessor',
+        num_classes=1000,
+        # RGB format normalization parameters
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        # convert image from BGR to RGB
+        to_rgb=True),
+    architecture=resnet,
+    float_checkpoint=float_checkpoint,
+    quantizer=dict(
+        type='mmrazor.OpenVINOQuantizer',
+        global_qconfig=global_qconfig,
+        tracer=dict(
+            type='mmrazor.CustomTracer',
+            skipped_methods=[
+                'mmcls.models.heads.ClsHead._get_loss',
+                'mmcls.models.heads.ClsHead._get_predictions'
+            ])))
+
+optim_wrapper = dict(
+    optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001))
+
+# learning policy
+param_scheduler = dict(
+    _delete_=True, type='ConstantLR', factor=1.0, by_epoch=True)
+
+model_wrapper_cfg = dict(
+    type='mmrazor.MMArchitectureQuantDDP',
+    broadcast_buffers=False,
+    find_unused_parameters=False)
+
+# train, val, test setting
+train_cfg = dict(
+    _delete_=True,
+    type='mmrazor.QATEpochBasedLoop',
+    max_epochs=10,
+    val_interval=1)
+val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
+
+# Make sure buffers such as min_val/max_val in the saved checkpoint are the
+# same across different ranks.
+default_hooks = dict(sync=dict(type='SyncBuffersHook'))
diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py
index 18caf06f5..764c8605d 100644
--- a/mmrazor/engine/runner/quantization_loops.py
+++ b/mmrazor/engine/runner/quantization_loops.py
@@ -11,11 +11,13 @@
     from torch.nn.intrinsic.qat import freeze_bn_stats
 except ImportError:
     from mmrazor.utils import get_placeholder
+
     disable_observer = get_placeholder('torch>=1.13')
     enable_fake_quant = get_placeholder('torch>=1.13')
     enable_observer = get_placeholder('torch>=1.13')
     freeze_bn_stats = get_placeholder('torch>=1.13')
 
+from mmengine.dist import all_reduce_params, is_distributed
 from torch.utils.data import DataLoader
 
 from mmrazor.models import register_torch_fake_quants, register_torch_observers
@@ -69,7 +71,18 @@ def prepare_for_run_epoch(self):
         """Toggle the state of the observers and fake quantizers before
         qat training."""
         self.runner.model.apply(enable_fake_quant)
-        self.runner.model.apply(enable_observer)
+
+        # The initial value of _epoch is 0, so _epoch + 1
+        # equals the current epoch.
+        if (self.disable_observer_begin > 0
+                and self._epoch + 1 >= self.disable_observer_begin):
+            self.runner.model.apply(disable_observer)
+        else:
+            self.runner.model.apply(enable_observer)
+
+        if (self.freeze_bn_begin > 0
+                and self._epoch + 1 >= self.freeze_bn_begin):
+            self.runner.model.apply(freeze_bn_stats)
 
     def prepare_for_val(self):
         """Toggle the state of the observers and fake quantizers before
@@ -89,8 +102,6 @@ def run(self):
             if (self.runner.val_loop is not None
                     and self._epoch >= self.val_begin
                     and self._epoch % self.val_interval == 0):
-                # observer disabled during evaluation
-                self.prepare_for_val()
                 self.runner.val_loop.run()
 
         self.runner.call_hook('after_train')
@@ -100,18 +111,13 @@ def run_epoch(self) -> None:
         self.runner.call_hook('before_train_epoch')
         self.runner.model.train()
 
-        # The initialized _epoch equals to 0 so _epoch + 1
-        # equal to the current epoch
-        if self._epoch + 1 >= self.disable_observer_begin:
-            self.runner.model.apply(disable_observer)
-
-        if self._epoch + 1 >= self.freeze_bn_begin:
-            self.runner.model.apply(freeze_bn_stats)
-
         for idx, data_batch in enumerate(self.dataloader):
             self.run_iter(idx, data_batch)
 
         self.runner.model.sync_qparams(src_mode='loss')
+        # Make sure registered buffers such as `observer_enabled` are
+        # correct in the saved checkpoint.
+        self.prepare_for_val()
         self.runner.call_hook('after_train_epoch')
         self._epoch += 1
@@ -156,11 +162,16 @@ def __init__(
             dynamic_intervals=dynamic_intervals)
 
         self.is_first_batch = True
+        self.distributed = is_distributed()
 
     def prepare_for_run_epoch(self):
         """Toggle the state of the observers and fake quantizers before
         qat training."""
-        pass
+        if (self.freeze_bn_begin > 0
+                and self._epoch + 1 >= self.freeze_bn_begin):
+            self.runner.model.apply(freeze_bn_stats)
+
+        self.runner.model.apply(enable_param_learning)
 
     def prepare_for_val(self):
         """Toggle the state of the observers and fake quantizers before
@@ -172,20 +183,30 @@ def run_epoch(self) -> None:
         self.runner.call_hook('before_train_epoch')
         self.runner.model.train()
 
-        # TODO freeze bn
-        if self._epoch + 1 >= self.freeze_bn_begin:
-            self.runner.model.apply(freeze_bn_stats)
-
         for idx, data_batch in enumerate(self.dataloader):
             if self.is_first_batch:
-                # lsq init
-                self.is_first_batch = False
+                # lsq observer init
                 self.runner.model.apply(enable_static_estimate)
-            else:
-                self.runner.model.apply(enable_param_learning)
+
             self.run_iter(idx, data_batch)
+
+            if self.is_first_batch:
+                # In the first batch, scale in LearnableFakeQuantize is
+                # calculated through the lsq observer. As the values of
+                # `scale` of different observers in different ranks are
+                # usually different, we have to sync the `scale` here.
+                if self.distributed:
+                    all_reduce_params(
+                        self.runner.model.parameters(), op='mean')
+
+                # Change back to param learning mode
+                self.is_first_batch = False
+                self.runner.model.apply(enable_param_learning)
 
         self.runner.model.sync_qparams(src_mode='loss')
+        # Make sure registered buffers such as `observer_enabled` are
+        # correct in the saved checkpoint.
+        self.prepare_for_val()
         self.runner.call_hook('after_train_epoch')
         self._epoch += 1
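
For reviewers who want to exercise the fixed loops locally, the sketch below shows one way to launch the new 10-epoch LSQ config through MMEngine's `Runner`. It is only an illustration and not part of the patch: the config path comes from this change, while the work directory, the single-process launch, and the assumption that `mmrazor` and `mmcls` are importable are mine.

```python
# Hedged usage sketch (not part of the patch): run the new LSQ QAT config with
# MMEngine's Runner. Assumes mmrazor and mmcls are installed so the `mmcls::`
# base config and the `mmrazor.*` registry types can be resolved.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py')
cfg.work_dir = './work_dirs/lsq_openvino_resnet18_10e'  # hypothetical path
cfg.launcher = 'none'  # single-process run; use a dist launcher for 8xb32

runner = Runner.from_cfg(cfg)
# Uses mmrazor.LSQEpochBasedLoop: BN stats are frozen from epoch 1
# (freeze_bn_begin=1) and the LSQ observer estimates `scale` only on the
# first batch, synced across ranks when training is distributed.
runner.train()
```

The repository's standard `tools/train.py` entry point should drive the same config in the usual way; the snippet above just makes the loop selection explicit for review.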