From 3d211ff839d792da23c88661a4963aebf8cfc306 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Thu, 13 Apr 2023 14:13:14 +0800 Subject: [PATCH 1/8] fix quantization loop --- mmrazor/engine/runner/quantization_loops.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index 18caf06f5..7fb3d5fd9 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -90,7 +90,7 @@ def run(self): and self._epoch >= self.val_begin and self._epoch % self.val_interval == 0): # observer disabled during evaluation - self.prepare_for_val() + self.runner.val_loop.run() self.runner.call_hook('after_train') @@ -102,16 +102,21 @@ def run_epoch(self) -> None: # The initialized _epoch equals to 0 so _epoch + 1 # equal to the current epoch - if self._epoch + 1 >= self.disable_observer_begin: + if (self.disable_observer_begin > 0 + and self._epoch + 1 >= self.disable_observer_begin): self.runner.model.apply(disable_observer) - if self._epoch + 1 >= self.freeze_bn_begin: + if (self.freeze_bn_begin > 0 + and self._epoch + 1 >= self.freeze_bn_begin): self.runner.model.apply(freeze_bn_stats) for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) self.runner.model.sync_qparams(src_mode='loss') + # Make sure the registered buffer such as `observer_enabled` is + # correct in the saved checkpoint. + self.prepare_for_val() self.runner.call_hook('after_train_epoch') self._epoch += 1 From eaa84f4103e623bc3adea8e14d38412669325b60 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 14 Apr 2023 15:10:47 +0800 Subject: [PATCH 2/8] fix quant loop --- .../qat/lsq_openvino_resnet18_8xb32_in1k.py | 13 ++-- mmrazor/engine/runner/quantization_loops.py | 59 ++++++++++++------- 2 files changed, 42 insertions(+), 30 deletions(-) diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py index 0b79232f8..ea3485f27 100644 --- a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py @@ -42,12 +42,7 @@ # learning policy param_scheduler = dict( - _delete_=True, - type='CosineAnnealingLR', - T_max=100, - by_epoch=True, - begin=0, - end=100) + _delete_=True, type='ConstantLR', factor=1.0, by_epoch=True) model_wrapper_cfg = dict( type='mmrazor.MMArchitectureQuantDDP', @@ -58,7 +53,9 @@ train_cfg = dict( _delete_=True, type='mmrazor.LSQEpochBasedLoop', - max_epochs=100, + max_epochs=10, val_interval=1) val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') -test_cfg = val_cfg +# test_cfg = val_cfg + +default_hooks = dict(sync=dict(type='SyncBuffersHook')) diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index 7fb3d5fd9..86448bdbb 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -11,11 +11,13 @@ from torch.nn.intrinsic.qat import freeze_bn_stats except ImportError: from mmrazor.utils import get_placeholder + disable_observer = get_placeholder('torch>=1.13') enable_fake_quant = get_placeholder('torch>=1.13') enable_observer = get_placeholder('torch>=1.13') freeze_bn_stats = get_placeholder('torch>=1.13') +from mmengine.dist import all_reduce_params from torch.utils.data import DataLoader from mmrazor.models import register_torch_fake_quants, 
register_torch_observers @@ -69,7 +71,18 @@ def prepare_for_run_epoch(self): """Toggle the state of the observers and fake quantizers before qat training.""" self.runner.model.apply(enable_fake_quant) - self.runner.model.apply(enable_observer) + + # The initialized _epoch equals to 0 so _epoch + 1 + # equal to the current epoch + if (self.disable_observer_begin > 0 + and self._epoch + 1 >= self.disable_observer_begin): + self.runner.model.apply(disable_observer) + else: + self.runner.model.apply(enable_observer) + + if (self.freeze_bn_begin > 0 + and self._epoch + 1 >= self.freeze_bn_begin): + self.runner.model.apply(freeze_bn_stats) def prepare_for_val(self): """Toggle the state of the observers and fake quantizers before @@ -89,8 +102,6 @@ def run(self): if (self.runner.val_loop is not None and self._epoch >= self.val_begin and self._epoch % self.val_interval == 0): - # observer disabled during evaluation - self.runner.val_loop.run() self.runner.call_hook('after_train') @@ -100,16 +111,6 @@ def run_epoch(self) -> None: self.runner.call_hook('before_train_epoch') self.runner.model.train() - # The initialized _epoch equals to 0 so _epoch + 1 - # equal to the current epoch - if (self.disable_observer_begin > 0 - and self._epoch + 1 >= self.disable_observer_begin): - self.runner.model.apply(disable_observer) - - if (self.freeze_bn_begin > 0 - and self._epoch + 1 >= self.freeze_bn_begin): - self.runner.model.apply(freeze_bn_stats) - for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) @@ -165,7 +166,11 @@ def __init__( def prepare_for_run_epoch(self): """Toggle the state of the observers and fake quantizers before qat training.""" - pass + if (self.freeze_bn_begin > 0 + and self._epoch + 1 >= self.freeze_bn_begin): + self.runner.model.apply(freeze_bn_stats) + + self.runner.model.apply(enable_param_learning) def prepare_for_val(self): """Toggle the state of the observers and fake quantizers before @@ -177,20 +182,30 @@ def run_epoch(self) -> None: self.runner.call_hook('before_train_epoch') self.runner.model.train() - # TODO freeze bn - if self._epoch + 1 >= self.freeze_bn_begin: - self.runner.model.apply(freeze_bn_stats) - for idx, data_batch in enumerate(self.dataloader): if self.is_first_batch: - # lsq init - self.is_first_batch = False + # lsq observer init self.runner.model.apply(enable_static_estimate) - else: - self.runner.model.apply(enable_param_learning) + self.run_iter(idx, data_batch) + if self.is_first_batch: + # In the first batch, scale in LearnableFakeQuantize is + # calculated through lsq observer. As the values of `scale` of + # different observers in different rank are usually different, + # we have to sync the `scale` here. + all_reduce_params(self.runner.model.parameters(), op='mean') + # Change back to param learning mode + self.is_first_batch = False + self.runner.model.apply(enable_param_learning) + + if idx > 100: + break + self.runner.model.sync_qparams(src_mode='loss') + # Make sure the registered buffer such as `observer_enabled` is + # correct in the saved checkpoint. 
+ self.prepare_for_val() self.runner.call_hook('after_train_epoch') self._epoch += 1 From 1c82c86763a50f68711b1adff0ac8a929c4169c0 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 14 Apr 2023 15:20:39 +0800 Subject: [PATCH 3/8] fix quant loop --- mmrazor/engine/runner/quantization_loops.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index 86448bdbb..ec078c42f 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -17,7 +17,7 @@ enable_observer = get_placeholder('torch>=1.13') freeze_bn_stats = get_placeholder('torch>=1.13') -from mmengine.dist import all_reduce_params +from mmengine.dist import all_reduce_params, is_distributed from torch.utils.data import DataLoader from mmrazor.models import register_torch_fake_quants, register_torch_observers @@ -162,6 +162,7 @@ def __init__( dynamic_intervals=dynamic_intervals) self.is_first_batch = True + self.distributed = is_distributed() def prepare_for_run_epoch(self): """Toggle the state of the observers and fake quantizers before qat @@ -194,7 +195,10 @@ def run_epoch(self) -> None: # calculated through lsq observer. As the values of `scale` of # different observers in different rank are usually different, # we have to sync the `scale` here. - all_reduce_params(self.runner.model.parameters(), op='mean') + if self.distributed: + all_reduce_params( + self.runner.model.parameters(), op='mean') + # Change back to param learning mode self.is_first_batch = False self.runner.model.apply(enable_param_learning) From c36dfb88c28c39d8707c88b0a80e446b4b2b4912 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 14 Apr 2023 20:28:55 +0800 Subject: [PATCH 4/8] fix qat configs --- .../lsq_openvino_resnet18_8xb32_100e_in1k.py | 67 +++++++++++++++++++ ...> lsq_openvino_resnet18_8xb32_10e_in1k.py} | 3 +- .../qat_openvino_resnet18_10e_8xb32_in1k.py | 62 +++++++++++++++++ 3 files changed, 131 insertions(+), 1 deletion(-) create mode 100644 configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py rename configs/quantization/qat/{lsq_openvino_resnet18_8xb32_in1k.py => lsq_openvino_resnet18_8xb32_10e_in1k.py} (95%) create mode 100644 configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py new file mode 100644 index 000000000..55896c2c9 --- /dev/null +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py @@ -0,0 +1,67 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +resnet = _base_.model +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.LSQPerChannelObserver'), + a_observer=dict(type='mmrazor.LSQObserver'), + w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 
57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=resnet, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, + type='CosineAnnealingLR', + T_max=100, + by_epoch=True, + begin=0, + end=100) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=True) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.LSQEpochBasedLoop', + max_epochs=100, + val_interval=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') + +# Make sure the buffer such as min_val/max_val in saved checkpoint is the same +# among different rank. +default_hooks = dict(sync=dict(type='SyncBuffersHook')) diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py similarity index 95% rename from configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py rename to configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py index ea3485f27..287a36fc9 100644 --- a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py @@ -56,6 +56,7 @@ max_epochs=10, val_interval=1) val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') -# test_cfg = val_cfg +# Make sure the buffer such as min_val/max_val in saved checkpoint is the same +# among different rank. 
default_hooks = dict(sync=dict(type='SyncBuffersHook')) diff --git a/configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py b/configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py new file mode 100644 index 000000000..261af7abb --- /dev/null +++ b/configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py @@ -0,0 +1,62 @@ +_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] + +resnet = _base_.model +float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 + +global_qconfig = dict( + w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'), + a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), + a_fake_quant=dict(type='mmrazor.FakeQuantize'), + w_qscheme=dict( + qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True), + a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True), +) + +model = dict( + _delete_=True, + _scope_='mmrazor', + type='MMArchitectureQuant', + data_preprocessor=dict( + type='mmcls.ClsDataPreprocessor', + num_classes=1000, + # RGB format normalization parameters + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + # convert image from BGR to RGB + to_rgb=True), + architecture=resnet, + float_checkpoint=float_checkpoint, + quantizer=dict( + type='mmrazor.OpenVINOQuantizer', + global_qconfig=global_qconfig, + tracer=dict( + type='mmrazor.CustomTracer', + skipped_methods=[ + 'mmcls.models.heads.ClsHead._get_loss', + 'mmcls.models.heads.ClsHead._get_predictions' + ]))) + +optim_wrapper = dict( + optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001)) + +# learning policy +param_scheduler = dict( + _delete_=True, type='ConstantLR', factor=1.0, by_epoch=True) + +model_wrapper_cfg = dict( + type='mmrazor.MMArchitectureQuantDDP', + broadcast_buffers=False, + find_unused_parameters=False) + +# train, val, test setting +train_cfg = dict( + _delete_=True, + type='mmrazor.QATEpochBasedLoop', + max_epochs=10, + val_interval=1) +val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') + +# Make sure the buffer such as min_val/max_val in saved checkpoint is the same +# among different rank. 
+default_hooks = dict(sync=dict(type='SyncBuffersHook')) From aa9ef39c11965dccc567721f4e96ea679c249170 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Thu, 13 Apr 2023 19:25:35 +0800 Subject: [PATCH 5/8] [Bug] Fix ci converage setting (#508) fix ci converage --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4b99bced4..9ed4cb002 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -119,7 +119,7 @@ jobs: coverage report -m # Upload coverage report for python3.8 && pytorch1.12.0 cpu - name: Upload coverage to Codecov - if: ${{matrix.torch == '1.12.0' && matrix.python-version == '3.8'}} + if: ${{matrix.torch == '1.13.0' && matrix.python-version == '3.8'}} uses: codecov/codecov-action@v2 with: file: ./coverage.xml From b238341eff90dcec22e2e2b04e2f47346f946920 Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Mon, 17 Apr 2023 06:13:43 +0800 Subject: [PATCH 6/8] [Bug] Fix codecov (#509) * remove codecov in requirements * try to fix ci * del adaround loss --- .github/workflows/build.yml | 2 +- mmrazor/models/losses/adaround_loss.py | 87 -------------------------- requirements/tests.txt | 2 +- 3 files changed, 2 insertions(+), 89 deletions(-) delete mode 100644 mmrazor/models/losses/adaround_loss.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9ed4cb002..2c2b8ed21 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -120,7 +120,7 @@ jobs: # Upload coverage report for python3.8 && pytorch1.12.0 cpu - name: Upload coverage to Codecov if: ${{matrix.torch == '1.13.0' && matrix.python-version == '3.8'}} - uses: codecov/codecov-action@v2 + uses: codecov/codecov-action@v3 with: file: ./coverage.xml flags: unittests diff --git a/mmrazor/models/losses/adaround_loss.py b/mmrazor/models/losses/adaround_loss.py deleted file mode 100644 index 76c97977d..000000000 --- a/mmrazor/models/losses/adaround_loss.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -import torch.nn as nn -from mmengine.logging import print_log - -from mmrazor.registry import MODELS - -_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear) - - -@MODELS.register_module() -class AdaRoundLoss(nn.Module): - r'''loss function to calculate mse reconstruction loss and relaxation loss - use some tempdecay to balance the two losses. - ''' - - def __init__(self, - weight: float = 1., - iters: int = 10000, - beta_range: tuple = (20, 2), - warm_up: float = 0.0, - p: float = 2.): - self.weight = weight - self.loss_start = iters * warm_up - self.p = p - - self.temp_decay = LinearTempDecay( - iters, - warm_up=warm_up, - start_beta=beta_range[0], - end_beta=beta_range[1]) - self.count = 0 - - def forward(self, subgraph, pred, tgt): - """Compute the total loss for adaptive rounding: rec_loss is the - quadratic output reconstruction loss, round_loss is a regularization - term to optimize the rounding policy. 
- - :param pred: output from quantized model - :param tgt: output from FP model - :return: total loss function - """ - - def lp_loss(pred, tgt, p=2.0): - """loss function measured in L_p Norm.""" - return (pred - tgt).abs().pow(p).sum(1).mean() - - self.count += 1 - rec_loss = lp_loss(pred, tgt, p=self.p) - - beta = self.temp_decay(self.count) - if self.count < self.loss_start: - round_loss = 0 - else: - round_loss = 0 - for layer in subgraph.modules(): - if isinstance(layer, _ADAROUND_SUPPORT_TYPE): - round_vals = layer.weight_fake_quant.rectified_sigmoid() - round_loss += self.weight * (1 - ( - (round_vals - .5).abs() * 2).pow(beta)).sum() - - total_loss = rec_loss + round_loss - if self.count % 500 == 0: - print_log('Total loss:\t{:.3f} (rec_loss:{:.3f}, ' - 'round_loss:{:.3f})\tbeta={:.2f}\tcount={}'.format( - float(total_loss), float(rec_loss), - float(round_loss), beta, self.count)) - return total_loss - - -class LinearTempDecay: - - def __init__(self, t_max=10000, warm_up=0.2, start_beta=20, end_beta=2): - self.t_max = t_max - self.start_decay = warm_up * t_max - self.start_beta = start_beta - self.end_beta = end_beta - - def __call__(self, t): - if t < self.start_decay: - return self.start_beta - elif t > self.t_max: - return self.end_beta - else: - rel_t = (t - self.start_decay) / (self.t_max - self.start_decay) - return self.end_beta + (self.start_beta - self.end_beta) * \ - max(0.0, (1 - rel_t)) diff --git a/requirements/tests.txt b/requirements/tests.txt index e38249fcd..5980dc303 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,4 +1,4 @@ -codecov +coverage flake8 interrogate isort==4.3.21 From af4e6b5396520e54099c9f3d41d9e41e1ff3fa43 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Fri, 14 Apr 2023 20:38:31 +0800 Subject: [PATCH 7/8] add freeze_bn_begin to lsq --- .../quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py | 3 ++- .../quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py index 55896c2c9..00e424141 100644 --- a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_100e_in1k.py @@ -59,7 +59,8 @@ _delete_=True, type='mmrazor.LSQEpochBasedLoop', max_epochs=100, - val_interval=1) + val_interval=1, + freeze_bn_begin=1) val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') # Make sure the buffer such as min_val/max_val in saved checkpoint is the same diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py index 287a36fc9..f931ddaf5 100644 --- a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py @@ -54,7 +54,8 @@ _delete_=True, type='mmrazor.LSQEpochBasedLoop', max_epochs=10, - val_interval=1) + val_interval=1, + freeze_bn_begin=1) val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') # Make sure the buffer such as min_val/max_val in saved checkpoint is the same From 6f8302f784a31b8037bdacaad5535b8039db38e1 Mon Sep 17 00:00:00 2001 From: HIT-cwh <2892770585@qq.com> Date: Mon, 17 Apr 2023 14:57:20 +0800 Subject: [PATCH 8/8] delete useless codes --- mmrazor/engine/runner/quantization_loops.py | 3 --- 1 file changed, 3 deletions(-) diff --git 
a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index ec078c42f..764c8605d 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -203,9 +203,6 @@ def run_epoch(self) -> None: self.is_first_batch = False self.runner.model.apply(enable_param_learning) - if idx > 100: - break - self.runner.model.sync_qparams(src_mode='loss') # Make sure the registered buffer such as `observer_enabled` is # correct in the saved checkpoint.
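
Taken together, the series converges on three behaviours: a `disable_observer_begin` / `freeze_bn_begin` value of 0 or less now means "never"; observers are switched off again (via `prepare_for_val`) before validation and before the checkpoint is written, so registered buffers such as `observer_enabled` are saved in their eval-time state; and the LSQ loop runs its first batch under static estimation, averages the freshly initialised scales across ranks when distributed, and only then enables parameter learning. The sketch below is a minimal, self-contained illustration of that control flow, not the mmrazor implementation: every name in it (`_toggle`, `prepare_for_run_epoch`, `prepare_for_val`, `lsq_epoch`, the dict-based "state") is a stand-in invented for the example, replacing torch.ao / mmengine helpers such as `enable_fake_quant`, `disable_observer`, `freeze_bn_stats` and `all_reduce_params`.

# Minimal sketch of the control flow the patches converge on (illustration
# only; the real loops call model.apply(...) with torch.ao/mmengine helpers).


def _toggle(state, key, value):
    # Record the on/off switch a real loop would apply via model.apply(...).
    state[key] = value


def prepare_for_run_epoch(state, epoch, disable_observer_begin=-1,
                          freeze_bn_begin=-1):
    # `epoch` is 0-based, so `epoch + 1` is the current 1-based epoch;
    # a begin value <= 0 means "never".
    _toggle(state, 'fake_quant', True)
    if disable_observer_begin > 0 and epoch + 1 >= disable_observer_begin:
        _toggle(state, 'observer', False)
    else:
        _toggle(state, 'observer', True)
    if freeze_bn_begin > 0 and epoch + 1 >= freeze_bn_begin:
        _toggle(state, 'bn_stats_frozen', True)


def prepare_for_val(state):
    # Observers are switched off before validation *and* before checkpointing,
    # so saved buffers such as `observer_enabled` describe the eval-time state.
    _toggle(state, 'fake_quant', True)
    _toggle(state, 'observer', False)


def lsq_epoch(state, batches, distributed=False):
    # First batch: static estimation initialises the LSQ scales; they are then
    # (conceptually) all-reduced across ranks before parameter learning starts.
    first_batch = True
    for _batch in batches:
        if first_batch:
            _toggle(state, 'static_estimate', True)
        # ... forward / backward / optimizer step would happen here ...
        if first_batch:
            if distributed:
                # stand-in for mmengine.dist.all_reduce_params(params, op='mean')
                _toggle(state, 'scales_synced', True)
            first_batch = False
            _toggle(state, 'param_learning', True)
    prepare_for_val(state)  # keep checkpointed buffers consistent (PATCH 1/8)
    return state


if __name__ == '__main__':
    state = {}
    prepare_for_run_epoch(state, epoch=0, freeze_bn_begin=1)
    print(lsq_epoch(state, batches=range(4)))

On the config side, PATCH 4/8 turns the former lsq_openvino_resnet18_8xb32_in1k.py into the 10-epoch variant (ConstantLR, max_epochs=10), adds a separate *_100e_* config that keeps the 100-epoch cosine schedule, and introduces a plain QAT 10-epoch config; all three register SyncBuffersHook so buffers such as min_val/max_val agree across ranks before a checkpoint is saved, and PATCH 7/8 sets freeze_bn_begin=1 for both LSQ configs.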