[Feature] Add prepare_for_mmdeploy interface (#365)
* remove useless code

* fix build graph module import bug

* refactor general quant

* rename GeneralQuant to MMArchitectureQuant

* fix some dtype bugs

* add prepare_for_mmdeploy interface

* update prepare for mmdeploy args

* fix some comments

Co-authored-by: humu789 <humu@pjlab.org.cn>
pppppM and humu789 committed Dec 1, 2022
1 parent 22b8075 commit 71ad718
Showing 20 changed files with 409 additions and 541 deletions.
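
The headline change is the new prepare_for_mmdeploy interface on MMArchitectureQuant (the renamed GeneralQuant), which hands a quantized model over to MMDeploy. A minimal usage sketch, assuming the model is built from one of the quantization configs below; the config path is a placeholder, and the call's arguments (added and then updated in this commit) are omitted rather than guessed:

from mmengine.config import Config
from mmrazor.registry import MODELS

# Build the quantization wrapper from a config (placeholder path).
cfg = Config.fromfile('configs/quantization/ptq/ptq_config.py')
model = MODELS.build(cfg.model)  # an MMArchitectureQuant instance

# Convert to a deployable model for MMDeploy. The exact arguments of
# prepare_for_mmdeploy are defined by this commit and not shown here.
deploy_model = model.prepare_for_mmdeploy()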
1 change: 0 additions & 1 deletion configs/quantization/ptq/demo.py

This file was deleted.

@@ -2,22 +2,14 @@

 test_cfg = dict(
     type='mmrazor.PTQLoop',
-
-    # reconstruction_cfg=dict(
-    #     pattern='layer',
-    #     loss=dict(
-    #         type='mmrazor.AdaRoundLoss',
-    #         iters=20000
-    #     )
-    # )
 )

 model = dict(
     _delete_=True,
-    type='mmrazor.GeneralQuant',
+    type='mmrazor.MMArchitectureQuant',
     architecture=_base_.model,
     quantizer=dict(
-        type='mmrazor.CustomQuantizer',
+        type='mmrazor.OpenvinoQuantizer',
         is_qat=False,
         skipped_methods=[
             'mmcls.models.heads.ClsHead._get_loss',
@@ -27,16 +19,16 @@
         qtype='affine',
         w_observer=dict(type='mmrazor.MSEObserver'),
         a_observer=dict(type='mmrazor.EMAMSEObserver'),
-        w_fake_quant=dict(type='mmrazor.AdaRoundFakeQuantize'),
+        w_fake_quant=dict(type='mmrazor.FakeQuantize'),
         a_fake_quant=dict(type='mmrazor.FakeQuantize'),
         w_qscheme=dict(
-            bit=2,
-            is_symmetry=False,
+            bit=8,
+            is_symmetry=True,
             is_per_channel=True,
             is_pot_scale=False,
         ),
         a_qscheme=dict(
-            bit=4,
+            bit=8,
             is_symmetry=False,
             is_per_channel=False,
             is_pot_scale=False),
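
Both qschemes now use 8 bits, and the weight scheme switches to symmetric per-channel quantization, matching what the OpenVINO backend expects. As a plain-PyTorch illustration (not MMRazor code) of what bit=8, is_symmetry=True, is_per_channel=True means for a conv weight:

import torch
from torch.ao.quantization import PerChannelMinMaxObserver

# 8-bit symmetric per-channel observer over a conv weight
# (out_channels along dim 0), mirroring the w_qscheme above.
obs = PerChannelMinMaxObserver(
    qscheme=torch.per_channel_symmetric, dtype=torch.qint8)
w = torch.randn(16, 3, 3, 3)
obs(w)  # observers record min/max on forward
scale, zero_point = obs.calculate_qparams()
print(scale.shape)          # torch.Size([16]) -- one scale per channel
print(zero_point.unique())  # tensor([0]) -- symmetric => zero point is 0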
@@ -1,24 +1,16 @@
 _base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py']

 resnet = _base_.model
-pretrained_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth' # noqa: E501
+float_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth' # noqa: E501

 model = dict(
     _delete_=True,
     _scope_='mmrazor',
-    type='GeneralQuant',
-    data_preprocessor=dict(
-        type='mmcls.ClsDataPreprocessor',
-        num_classes=10,
-        # RGB format normalization parameters
-        mean=[125.307, 122.961, 113.8575],
-        std=[51.5865, 50.847, 51.255],
-        # loaded images are already RGB format
-        to_rgb=False),
+    type='MMArchitectureQuant',
     architecture=resnet,
-    pretrained_ckpt=pretrained_ckpt,
+    float_checkpoint=float_ckpt,
     quantizer=dict(
-        type='CustomQuantizer',
+        type='OpenvinoQuantizer',
         skipped_methods=[
             'mmcls.models.heads.ClsHead._get_loss',
             'mmcls.models.heads.ClsHead._get_predictions'
@@ -31,8 +23,8 @@
         a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'),
         w_qscheme=dict(
             bit=8,
-            is_symmetry=False,
-            is_per_channel=False,
+            is_symmetry=True,
+            is_per_channel=True,
             is_pot_scale=False,
         ),
         a_qscheme=dict(
@@ -55,16 +47,15 @@
     end=100)

 model_wrapper_cfg = dict(
-    type='mmrazor.GeneralQuantDDP',
+    type='mmrazor.MMArchitectureQuantDDP',
     broadcast_buffers=False,
     find_unused_parameters=True)

 # train, val, test setting
 train_cfg = dict(
     _delete_=True,
     type='mmrazor.QATEpochBasedLoop',
-    by_epoch=True,
     max_epochs=100,
     val_interval=1)
 val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop')
-test_cfg = val_cfg
+# test_cfg = val_cfg
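
With QATEpochBasedLoop and QATValLoop wired into train_cfg/val_cfg, training runs through the standard MMEngine Runner. A minimal launch sketch, with a placeholder config path:

from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/quantization/qat/qat_config.py')  # placeholder path
cfg.work_dir = './work_dirs/qat_resnet18'

runner = Runner.from_cfg(cfg)  # picks up the QAT train/val loops from the config
runner.train()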
75 changes: 0 additions & 75 deletions configs/quantization/qat/lsq_resnet18_8xb32_in1k.py

This file was deleted.

