[Feature] Support lsq (#501)
* support deploy_cfg=None

* replace fakequant before load ckpt

* add _load_from_state_dict to lsq fakequant

* fix pre-commit

* test lsq load state dict

* change github ci: ubuntu 18.04 to ubuntu 20.04

* get_deploy_model order change back

* sync before save ckpt

* delete strict=False

* test context rewriter

* fix pre commit config

* try to fix ci

* [Bug] Try to fix CI (#502)

fix lint

---------

Co-authored-by: humu789 <humu@pjlab.org.cn>
Co-authored-by: humu789 <88702197+humu789@users.noreply.github.com>
3 people committed Apr 11, 2023
1 parent 5aff276 commit 05da6f5
Showing 6 changed files with 150 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -25,7 +25,7 @@ concurrency:

jobs:
test_linux:
runs-on: ubuntu-18.04
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: [3.7]
3 changes: 2 additions & 1 deletion mmrazor/engine/runner/quantization_loops.py
@@ -91,7 +91,6 @@ def run(self):
and self._epoch % self.val_interval == 0):
# observer disabled during evaluation
self.prepare_for_val()
self.runner.model.sync_qparams(src_mode='loss')
self.runner.val_loop.run()

self.runner.call_hook('after_train')
@@ -112,6 +111,7 @@ def run_epoch(self) -> None:
for idx, data_batch in enumerate(self.dataloader):
self.run_iter(idx, data_batch)

self.runner.model.sync_qparams(src_mode='loss')
self.runner.call_hook('after_train_epoch')
self._epoch += 1

@@ -185,6 +185,7 @@ def run_epoch(self) -> None:
self.runner.model.apply(enable_param_learning)
self.run_iter(idx, data_batch)

self.runner.model.sync_qparams(src_mode='loss')
self.runner.call_hook('after_train_epoch')
self._epoch += 1

45 changes: 25 additions & 20 deletions mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -20,6 +20,7 @@
disable_observer)
except ImportError:
from mmrazor.utils import get_placeholder

FakeQuantizeBase = get_placeholder('torch>=1.13')
MinMaxObserver = get_placeholder('torch>=1.13')
PerChannelMinMaxObserver = get_placeholder('torch>=1.13')
@@ -283,23 +284,31 @@ def _build_qmodels(self, model: BaseModel):
"""

rewriter_context = self._get_rewriter_context_in_mmdeploy(
self.deploy_cfg)
self.deploy_cfg) if self.deploy_cfg is not None else None

# Pop function records in `quantizer.tracer.skipped_method` temporarily
function_record_backup = self._pop_function_record_in_rewriter_context(
rewriter_context)
if rewriter_context is not None:
# Pop function records in `quantizer.tracer.skipped_method`
# temporarily
function_record_backup = \
self._pop_function_record_in_rewriter_context(rewriter_context)

qmodels = nn.ModuleDict()
for mode in self.forward_modes:
concrete_args = {'mode': mode}
# todo: support qat.
with rewriter_context:

if rewriter_context is not None:
with rewriter_context:
observed_module = self.quantizer.prepare(
model, concrete_args)
else:
observed_module = self.quantizer.prepare(model, concrete_args)

qmodels[mode] = observed_module

# Add these popped function records back.
rewriter_context._rewriter_manager.function_rewriter. \
_registry._rewrite_records.update(function_record_backup)
if rewriter_context is not None:
# Add these popped function records back.
rewriter_context._rewriter_manager.function_rewriter. \
_registry._rewrite_records.update(function_record_backup)

# data_samples can not be None in detectors during prediction.
# But we need to make the dummy prediction in _build_qmodels.
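The branches above make the mmdeploy rewriter context optional when `deploy_cfg` is None. A more compact way to express the same idea (a sketch only, with a hypothetical helper name, not the code in this commit) is to substitute a no-op context:

from contextlib import nullcontext


def prepare_with_optional_rewriter(quantizer, model, concrete_args,
                                   rewriter_context=None):
    # Enter the mmdeploy rewriter context only when one was built from a
    # deploy_cfg; otherwise fall back to a do-nothing context manager.
    ctx = rewriter_context if rewriter_context is not None else nullcontext()
    with ctx:
        return quantizer.prepare(model, concrete_args)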
@@ -357,7 +366,10 @@ def get_deploy_model(self):
observed_model.load_state_dict(quantized_state_dict)

self.quantizer.post_process_for_deploy(
observed_model, device=device, keep_w_fake_quant=True)
observed_model,
device=device,
keep_w_fake_quant=True,
update_weight_with_fakequant=True)

# replace various activation fakequant with base fakequant, which
# helps deploy our model to various backends.
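A hedged usage sketch of the deploy path (the helper name, input shape, and ONNX export call are illustrative assumptions, not part of this commit); `algorithm` stands for a trained MMArchitectureQuant instance:

import torch


def export_deploy_onnx(algorithm, onnx_path='end2end.onnx'):
    # get_deploy_model() returns a torch.fx GraphModule whose weights have
    # already been updated with their weight fake-quant parameters.
    deploy_model = algorithm.get_deploy_model()
    assert isinstance(deploy_model, torch.fx.graph_module.GraphModule)
    dummy_input = torch.rand(1, 3, 224, 224)
    torch.onnx.export(deploy_model, dummy_input, onnx_path, opset_version=11)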
@@ -406,21 +418,14 @@ def calibrate_step(self, data: Union[Dict, Tuple, List]):

return self.module.calibrate_step(data)

def sync_qparams(self, src: str):
def sync_qparams(self, src_mode: str):
"""Same as in 'MMArchitectureQuant'. Sync all quantize parameters in
different `forward_modes`. We could have several modes to generate
graphs, but in training, only one graph will be updated, so we need to
sync qparams on the other graphs.
Args:
src (str): The src modes of forward method.
Note:
`traverse()` function recursively traverses all modules to sync
quantized graphs generated from different `forward_modes`.
This is because we have different modes ('tensor', 'predict',
'loss') in OpenMMLab architecture which have different graphs
in some subtle ways, so we need to sync them here.
src_mode (str): The src mode of the forward method.
"""

self.module.sync_qparams(src)
self.module.sync_qparams(src_mode)
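Conceptually, syncing amounts to copying the learned quantization parameters from the source mode's graph into the graphs of the other forward modes. A rough sketch (not the actual MMArchitectureQuant implementation, which traverses the traced graphs; matching module names and shapes are assumed):

def sync_qparams_sketch(qmodels, src_mode='loss'):
    # qmodels: dict mapping forward mode -> traced GraphModule.
    src_modules = dict(qmodels[src_mode].named_modules())
    for mode, graph in qmodels.items():
        if mode == src_mode:
            continue
        for name, module in graph.named_modules():
            src = src_modules.get(name)
            if src is None or not hasattr(module, 'scale') \
                    or not hasattr(module, 'zero_point'):
                continue
            # Copy scale / zero_point from the fake-quant module in the
            # source graph into its counterpart in this graph.
            module.scale.data.copy_(src.scale.data)
            module.zero_point.data.copy_(src.zero_point.data)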
40 changes: 40 additions & 0 deletions mmrazor/models/fake_quants/lsq.py
@@ -258,6 +258,46 @@ def forward(self, X):

return X

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""Removing this function throws an error that the the size of the
loaded tensor does not match the original size i.e., These buffers
start out with numel 0 and become numel 1 once they have their first
forward pass.
Modified from https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fake_quantize.py # noqa:E501
"""
local_state = ['scale', 'zero_point']
for name in local_state:
key = prefix + name
if key in state_dict:
val = state_dict[key]
# Custom handling to allow loading scale and zero_point
# of size N into uninitialized buffers of size 0. The
# buffers are resized here, and the values are copied in
# the default state_dict loading code of the parent.
if name == 'scale':
self.scale.data = self.scale.data.resize_(val.shape)
else:
assert name == 'zero_point'
self.zero_point.data = self.zero_point.data.resize_(
val.shape)
# For torchscript modules we need to update the attributes here
# since we do not call the `_load_from_state_dict` function
# defined in module.py
if torch.jit.is_scripting():
if name == 'scale':
self.scale.copy_(val)
else:
assert name == 'zero_point'
self.zero_point.copy_(val)
elif strict:
missing_keys.append(key)
super(LearnableFakeQuantize,
self)._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys,
unexpected_keys, error_msgs)

@torch.jit.export
def extra_repr(self):
"""The printable representational string."""
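A minimal, self-contained illustration (a toy module, not mmrazor code) of the resize-before-load pattern that the `_load_from_state_dict` override above relies on: a tensor that starts out empty and only gets its real shape on the first forward pass must be resized before the parent class copies checkpoint values into it.

import torch
import torch.nn as nn


class LazyScale(nn.Module):
    """Toy example: `scale` starts as an empty buffer and takes its
    per-channel shape on the first forward pass."""

    def __init__(self):
        super().__init__()
        self.register_buffer('scale', torch.tensor([]))

    def forward(self, x):
        if self.scale.numel() == 0:
            # First forward: one scale per channel.
            self.scale.data = torch.ones(x.shape[1])
        return x * self.scale.view(1, -1, 1, 1)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys,
                              error_msgs):
        key = prefix + 'scale'
        if key in state_dict:
            # Resize the (possibly empty) buffer so the default loading
            # logic of the parent class can copy the checkpoint tensor.
            self.scale.data = self.scale.data.resize_(state_dict[key].shape)
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

With such an override in place, a checkpoint whose `scale` already has its per-channel shape loads cleanly into a freshly constructed module whose buffer is still empty.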
74 changes: 59 additions & 15 deletions tests/test_models/test_algorithms/test_mm_architecture.py
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import shutil
import tempfile
from unittest import TestCase, skip
from unittest import TestCase, skipIf

import torch
import torch.nn as nn
@@ -13,8 +14,15 @@
from mmrazor.utils import get_placeholder
GraphModule = get_placeholder('torch>=1.13')

from mmengine import ConfigDict
from mmengine.model import BaseModel

try:
import mmdeploy
except ImportError:
from mmrazor.utils import get_package_placeholder
mmdeploy = get_package_placeholder('mmdeploy')

from mmrazor import digit_version
from mmrazor.models.algorithms import MMArchitectureQuant
from mmrazor.registry import MODELS
@@ -101,12 +109,44 @@ def forward(self, inputs, data_samples, mode: str = 'tensor'):
return outputs


@skip
DEPLOY_CFG = ConfigDict(
onnx_config=dict(
type='onnx',
export_params=True,
keep_initializers_as_inputs=False,
opset_version=11,
save_file='end2end.onnx',
input_names=['input'],
output_names=['output'],
input_shape=None,
optimize=True,
dynamic_axes={
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'output': {
0: 'batch'
}
}),
backend_config=dict(
type='openvino',
model_inputs=[dict(opt_shapes=dict(input=[1, 3, 224, 224]))]),
codebase_config=dict(type='mmcls', task='Classification'),
function_record_to_pop=[
'mmcls.models.classifiers.ImageClassifier.forward',
'mmcls.models.classifiers.BaseClassifier.forward'
],
)


@skipIf(
digit_version(torch.__version__) < digit_version('1.13.0'),
'PyTorch version lower than 1.13.0 is not supported.')
class TestMMArchitectureQuant(TestCase):

def setUp(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')

MODELS.register_module(module=ToyQuantModel, force=True)

@@ -116,7 +156,7 @@ def setUp(self):
toymodel = ToyQuantModel()
torch.save(toymodel.state_dict(), filename)

global_qconfig = dict(
global_qconfig = ConfigDict(
w_observer=dict(type='mmrazor.PerChannelMinMaxObserver'),
a_observer=dict(type='mmrazor.MovingAverageMinMaxObserver'),
w_fake_quant=dict(type='mmrazor.FakeQuantize'),
@@ -132,7 +172,7 @@ def setUp(self):
is_symmetry=True,
averaging_constant=0.1),
)
alg_kwargs = dict(
alg_kwargs = ConfigDict(
type='mmrazor.MMArchitectureQuant',
architecture=dict(type='ToyQuantModel'),
float_checkpoint=filename,
@@ -141,23 +181,23 @@
global_qconfig=global_qconfig,
tracer=dict(type='mmrazor.CustomTracer')))
self.alg_kwargs = alg_kwargs
self.toy_model = MODELS.build(self.alg_kwargs)

def tearDown(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')
MODELS.module_dict.pop('ToyQuantModel')
shutil.rmtree(self.temp_dir)

def test_init(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')
self.toy_model = MODELS.build(self.alg_kwargs)
assert isinstance(self.toy_model, MMArchitectureQuant)
assert hasattr(self.toy_model, 'quantizer')

alg_kwargs = copy.deepcopy(self.alg_kwargs)
alg_kwargs.deploy_cfg = DEPLOY_CFG
assert isinstance(self.toy_model, MMArchitectureQuant)
assert hasattr(self.toy_model, 'quantizer')

def test_sync_qparams(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')
self.toy_model = MODELS.build(self.alg_kwargs)
mode = self.toy_model.forward_modes[0]
self.toy_model.sync_qparams(mode)
w_loss = self.toy_model.qmodels[
@@ -170,12 +210,16 @@ def test_sync_qparams(self):
assert w_loss.equal(w_tensor)

def test_build_qmodels(self):
if digit_version(torch.__version__) < digit_version('1.13.0'):
self.skipTest('version of torch < 1.13.0')
self.toy_model = MODELS.build(self.alg_kwargs)
for forward_modes in self.toy_model.forward_modes:
qmodels = self.toy_model.qmodels[forward_modes]
assert isinstance(qmodels, GraphModule)

def test_get_deploy_model(self):
self.toy_model = MODELS.build(self.alg_kwargs)
deploy_model = self.toy_model.get_deploy_model()
self.assertIsInstance(deploy_model, torch.fx.graph_module.GraphModule)

def test_calibrate_step(self):
# TODO
pass
24 changes: 23 additions & 1 deletion tests/test_models/test_fake_quants/test_lsq_fake_quants.py
@@ -8,10 +8,12 @@
from mmrazor.models import LearnableFakeQuantize

try:
from torch.ao.quantization import MovingAverageMinMaxObserver
from torch.ao.quantization import (MovingAverageMinMaxObserver,
MovingAveragePerChannelMinMaxObserver)
except ImportError:
from mmrazor.utils import get_placeholder
MovingAverageMinMaxObserver = get_placeholder('torch>=1.13')
MovingAveragePerChannelMinMaxObserver = get_placeholder('torch>=1.13')


class TestLearnableFakeQuantize(TestCase):
@@ -38,6 +40,16 @@ def setUp(self):
reduce_range=True,
zero_point_trainable=False)

self.zero_point_untrainable_per_channel_fakequant = \
LearnableFakeQuantize.with_args(
observer=MovingAveragePerChannelMinMaxObserver,
quant_min=0,
quant_max=255,
dtype=torch.quint8,
qscheme=torch.per_channel_affine,
reduce_range=True,
zero_point_trainable=False)

def test_repr(self):
fq_module = self.zero_point_untrainable_fakequant()
repr_str = f'static_enabled={torch.tensor([1], dtype=torch.uint8)}, '
@@ -184,3 +196,13 @@ def test_state(self):
self.assertEqual(fq_module.zero_point.requires_grad, 1)
self.assertEqual(fq_module.fake_quant_enabled[0], 1)
self.assertEqual(fq_module.static_enabled[0], 0)

def test_load_state_dict(self):
fq_module = self.zero_point_untrainable_per_channel_fakequant()
state_dict = fq_module.state_dict()
X = torch.rand(32, 16, 3, 3, dtype=torch.float32)
# After forwarding, `scale` and `zero_point` in `fq_module` will have
# shape (32, ), while those in `state_dict` still have shape (1, ).
_ = fq_module(X)
fq_module.load_state_dict(state_dict)
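A hedged sketch of the complementary direction the override is meant to cover (a hypothetical companion test, not part of this commit): a state dict saved after a forward pass carries per-channel (32, ) tensors and is loaded into a freshly built module whose `scale` and `zero_point` still have their initial shape.

def test_load_trained_state_dict_sketch(self):
    fq_module = self.zero_point_untrainable_per_channel_fakequant()
    _ = fq_module(torch.rand(32, 16, 3, 3, dtype=torch.float32))
    # scale/zero_point now have shape (32, ) in the saved state dict.
    trained_state = fq_module.state_dict()
    fresh_module = self.zero_point_untrainable_per_channel_fakequant()
    # `_load_from_state_dict` resizes the fresh module's tensors to match.
    fresh_module.load_state_dict(trained_state)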
