[BUG] Fix quantization loop (#507)
* fix quantization loop

* fix quant loop

* fix quant loop

* fix qat configs

* [Bug] Fix CI coverage setting (#508)

fix CI coverage

* [Bug] Fix codecov (#509)

* remove codecov in requirements

* try to fix ci

* del adaround loss

* add freeze_bn_begin to lsq

* delete useless codes

---------

Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com>
HIT-cwh and humu789 committed Apr 17, 2023
1 parent 2efb327 commit 8198f27
Showing 4 changed files with 172 additions and 22 deletions.
@@ -59,6 +59,10 @@
     _delete_=True,
     type='mmrazor.LSQEpochBasedLoop',
     max_epochs=100,
-    val_interval=1)
+    val_interval=1,
+    freeze_bn_begin=1)
 val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
 test_cfg = val_cfg
+
+# Make sure buffers such as min_val/max_val in the saved checkpoint are the
+# same across ranks.
+default_hooks = dict(sync=dict(type='SyncBuffersHook'))
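The SyncBuffersHook registered above addresses a distributed-training pitfall: observer buffers such as min_val/max_val evolve independently on each rank, so checkpoints saved on different ranks can disagree. As a minimal sketch of the idea (assumed semantics for illustration, not the hook's actual implementation):

import torch.distributed as dist

def sync_buffers(model) -> None:
    # Broadcast every registered buffer (e.g. an observer's min_val/max_val)
    # from rank 0 so all ranks hold, and therefore save, identical values.
    for buf in model.buffers():
        dist.broadcast(buf, src=0)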
configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py (new file, 63 additions)
@@ -0,0 +1,63 @@
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']

resnet = _base_.model
float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth'  # noqa: E501

global_qconfig = dict(
    w_observer=dict(type='mmrazor.LSQPerChannelObserver'),
    a_observer=dict(type='mmrazor.LSQObserver'),
    w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
    a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
    w_qscheme=dict(
        qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True),
    a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True),
)

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='MMArchitectureQuant',
    data_preprocessor=dict(
        type='mmcls.ClsDataPreprocessor',
        num_classes=1000,
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        to_rgb=True),
    architecture=resnet,
    float_checkpoint=float_checkpoint,
    quantizer=dict(
        type='mmrazor.OpenVINOQuantizer',
        global_qconfig=global_qconfig,
        tracer=dict(
            type='mmrazor.CustomTracer',
            skipped_methods=[
                'mmcls.models.heads.ClsHead._get_loss',
                'mmcls.models.heads.ClsHead._get_predictions'
            ])))

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
    _delete_=True, type='ConstantLR', factor=1.0, by_epoch=True)

model_wrapper_cfg = dict(
    type='mmrazor.MMArchitectureQuantDDP',
    broadcast_buffers=False,
    find_unused_parameters=True)

# train, val, test setting
train_cfg = dict(
    _delete_=True,
    type='mmrazor.LSQEpochBasedLoop',
    max_epochs=10,
    val_interval=1,
    freeze_bn_begin=1)
val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')

# Make sure buffers such as min_val/max_val in the saved checkpoint are the
# same across ranks.
default_hooks = dict(sync=dict(type='SyncBuffersHook'))
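As a usage note, not part of this commit: a config like the one above is normally consumed through mmengine's standard entry points. A minimal sketch, assuming the file sits at the path shown in the header:

from mmengine.config import Config
from mmengine.runner import Runner

# Build a runner from the LSQ QAT config and start training; the
# mmrazor.LSQEpochBasedLoop configured above drives each epoch.
cfg = Config.fromfile(
    'configs/quantization/qat/lsq_openvino_resnet18_8xb32_10e_in1k.py')
runner = Runner.from_cfg(cfg)
runner.train()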
configs/quantization/qat/qat_openvino_resnet18_10e_8xb32_in1k.py (new file, 62 additions)
@@ -0,0 +1,62 @@
_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py']

resnet = _base_.model
float_checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth'  # noqa: E501

global_qconfig = dict(
    w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
    a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
    w_fake_quant=dict(type='mmrazor.FakeQuantize'),
    a_fake_quant=dict(type='mmrazor.FakeQuantize'),
    w_qscheme=dict(
        qdtype='qint8', bit=8, is_symmetry=True, is_symmetric_range=True),
    a_qscheme=dict(qdtype='quint8', bit=8, is_symmetry=True),
)

model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='MMArchitectureQuant',
    data_preprocessor=dict(
        type='mmcls.ClsDataPreprocessor',
        num_classes=1000,
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        to_rgb=True),
    architecture=resnet,
    float_checkpoint=float_checkpoint,
    quantizer=dict(
        type='mmrazor.OpenVINOQuantizer',
        global_qconfig=global_qconfig,
        tracer=dict(
            type='mmrazor.CustomTracer',
            skipped_methods=[
                'mmcls.models.heads.ClsHead._get_loss',
                'mmcls.models.heads.ClsHead._get_predictions'
            ])))

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.0001, momentum=0.9, weight_decay=0.0001))

# learning policy
param_scheduler = dict(
    _delete_=True, type='ConstantLR', factor=1.0, by_epoch=True)

model_wrapper_cfg = dict(
    type='mmrazor.MMArchitectureQuantDDP',
    broadcast_buffers=False,
    find_unused_parameters=False)

# train, val, test setting
train_cfg = dict(
    _delete_=True,
    type='mmrazor.QATEpochBasedLoop',
    max_epochs=10,
    val_interval=1)
val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')

# Make sure buffers such as min_val/max_val in the saved checkpoint are the
# same across ranks.
default_hooks = dict(sync=dict(type='SyncBuffersHook'))
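This config differs from the LSQ one above mainly in the fake-quantize modules: FakeQuantize with min-max observers estimates quantization parameters from running statistics, while LearnableFakeQuantize treats them as trainable. As a rough sketch of what any fake-quantize step computes (an illustrative assumption, not mmrazor's implementation), the tensor stays in float but passes through a quantize-dequantize round trip so training sees the quantization error:

import torch

def fake_quantize(x: torch.Tensor, scale: float, zero_point: int,
                  quant_min: int = -128, quant_max: int = 127) -> torch.Tensor:
    # Round onto the integer grid defined by scale/zero_point, clamp to the
    # representable range, then map back to float.
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    return (q - zero_point) * scale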
mmrazor/engine/runner/quantization_loops.py (41 additions, 20 deletions)
@@ -11,11 +11,13 @@
     from torch.nn.intrinsic.qat import freeze_bn_stats
 except ImportError:
     from mmrazor.utils import get_placeholder
+
     disable_observer = get_placeholder('torch>=1.13')
     enable_fake_quant = get_placeholder('torch>=1.13')
     enable_observer = get_placeholder('torch>=1.13')
     freeze_bn_stats = get_placeholder('torch>=1.13')
 
+from mmengine.dist import all_reduce_params, is_distributed
 from torch.utils.data import DataLoader
 
 from mmrazor.models import register_torch_fake_quants, register_torch_observers
@@ -69,7 +71,18 @@ def prepare_for_run_epoch(self):
         """Toggle the state of the observers and fake quantizers before qat
         training."""
         self.runner.model.apply(enable_fake_quant)
-        self.runner.model.apply(enable_observer)
+
+        # The initialized _epoch equals 0, so _epoch + 1 is the
+        # current epoch.
+        if (self.disable_observer_begin > 0
+                and self._epoch + 1 >= self.disable_observer_begin):
+            self.runner.model.apply(disable_observer)
+        else:
+            self.runner.model.apply(enable_observer)
+
+        if (self.freeze_bn_begin > 0
+                and self._epoch + 1 >= self.freeze_bn_begin):
+            self.runner.model.apply(freeze_bn_stats)
 
     def prepare_for_val(self):
         """Toggle the state of the observers and fake quantizers before
@@ -89,8 +102,6 @@ def run(self):
             if (self.runner.val_loop is not None
                     and self._epoch >= self.val_begin
                     and self._epoch % self.val_interval == 0):
-                # observer disabled during evaluation
-                self.prepare_for_val()
                 self.runner.val_loop.run()
 
         self.runner.call_hook('after_train')
@@ -100,18 +111,13 @@ def run_epoch(self) -> None:
         self.runner.call_hook('before_train_epoch')
         self.runner.model.train()
 
-        # The initialized _epoch equals to 0 so _epoch + 1
-        # equal to the current epoch
-        if self._epoch + 1 >= self.disable_observer_begin:
-            self.runner.model.apply(disable_observer)
-
-        if self._epoch + 1 >= self.freeze_bn_begin:
-            self.runner.model.apply(freeze_bn_stats)
-
         for idx, data_batch in enumerate(self.dataloader):
             self.run_iter(idx, data_batch)
 
         self.runner.model.sync_qparams(src_mode='loss')
+        # Make sure the registered buffer such as `observer_enabled` is
+        # correct in the saved checkpoint.
+        self.prepare_for_val()
         self.runner.call_hook('after_train_epoch')
         self._epoch += 1

@@ -156,11 +162,16 @@ def __init__(
             dynamic_intervals=dynamic_intervals)
 
         self.is_first_batch = True
+        self.distributed = is_distributed()
 
     def prepare_for_run_epoch(self):
         """Toggle the state of the observers and fake quantizers before qat
         training."""
-        pass
+        if (self.freeze_bn_begin > 0
+                and self._epoch + 1 >= self.freeze_bn_begin):
+            self.runner.model.apply(freeze_bn_stats)
+
+        self.runner.model.apply(enable_param_learning)
 
     def prepare_for_val(self):
         """Toggle the state of the observers and fake quantizers before
@@ -172,20 +183,30 @@ def run_epoch(self) -> None:
         self.runner.call_hook('before_train_epoch')
         self.runner.model.train()
 
-        # TODO freeze bn
-        if self._epoch + 1 >= self.freeze_bn_begin:
-            self.runner.model.apply(freeze_bn_stats)
-
         for idx, data_batch in enumerate(self.dataloader):
             if self.is_first_batch:
-                # lsq init
-                self.is_first_batch = False
+                # lsq observer init
                 self.runner.model.apply(enable_static_estimate)
-            else:
-                self.runner.model.apply(enable_param_learning)
+
             self.run_iter(idx, data_batch)
 
+            if self.is_first_batch:
+                # In the first batch, scale in LearnableFakeQuantize is
+                # calculated through the lsq observer. As the values of
+                # `scale` of different observers in different ranks are
+                # usually different, we have to sync the `scale` here.
+                if self.distributed:
+                    all_reduce_params(
+                        self.runner.model.parameters(), op='mean')
+
+                # Change back to param learning mode
+                self.is_first_batch = False
+                self.runner.model.apply(enable_param_learning)
+
         self.runner.model.sync_qparams(src_mode='loss')
+        # Make sure the registered buffer such as `observer_enabled` is
+        # correct in the saved checkpoint.
+        self.prepare_for_val()
         self.runner.call_hook('after_train_epoch')
         self._epoch += 1
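The all_reduce_params call above is what makes LSQ's first-batch scale initialization consistent: each rank initializes scale from its own batch, then the values are averaged across ranks before switching to learned updates. A minimal sketch of a mean all-reduce over parameters (assumed semantics for illustration; mmengine supplies the real all_reduce_params helper):

import torch.distributed as dist

def all_reduce_params_mean(params) -> None:
    # Sum each parameter across all ranks, then divide by the world size,
    # leaving every rank with the average (e.g. of lsq-initialized scales).
    world_size = dist.get_world_size()
    for p in params:
        dist.all_reduce(p.data, op=dist.ReduceOp.SUM)
        p.data.div_(world_size)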
