Commit 6fae8dd: delete is_qat, add doc and fix pytest

HIT-cwh committed Jan 18, 2023 (1 parent: db1acb3)
Showing 9 changed files with 17 additions and 27 deletions.
@@ -34,7 +34,6 @@
     float_checkpoint=float_checkpoint,
     quantizer=dict(
         type='mmrazor.OpenVINOQuantizer',
-        is_qat=False,
         global_qconfig=global_qconfig,
         tracer=dict(
             type='mmrazor.CustomTracer',

@@ -36,7 +36,6 @@
     float_checkpoint=float_checkpoint,
     quantizer=dict(
         type='mmrazor.OpenVINOQuantizer',
-        is_qat=False,
         global_qconfig=global_qconfig,
         tracer=dict(
             type='mmrazor.CustomTracer',

@@ -36,7 +36,6 @@
     float_checkpoint=float_checkpoint,
     quantizer=dict(
         type='mmrazor.OpenVINOQuantizer',
-        is_qat=False,
         global_qconfig=global_qconfig,
         tracer=dict(
             type='mmrazor.CustomTracer',

@@ -35,7 +35,6 @@
     float_checkpoint=float_checkpoint,
     quantizer=dict(
         type='mmrazor.OpenVINOQuantizer',
-        is_qat=False,
         global_qconfig=global_qconfig,
         tracer=dict(
             type='mmrazor.CustomTracer',

@@ -29,7 +29,6 @@
     float_checkpoint=float_checkpoint,
     quantizer=dict(
         type='mmrazor.OpenVINOQuantizer',
-        is_qat=True,
         global_qconfig=global_qconfig,
         tracer=dict(
             type='mmrazor.CustomTracer',

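The change is the same across all five config files above: the `is_qat` flag is simply dropped from the quantizer dict. A minimal, hedged sketch of how such a quantizer block reads after this commit (only the keys visible in the diff come from the source; the placeholder values are assumptions):

# Hypothetical excerpt of one of the edited configs; values are placeholders.
float_checkpoint = 'work_dirs/float_model.pth'   # assumed path
global_qconfig = dict()                          # assumed qconfig dict

model = dict(
    float_checkpoint=float_checkpoint,
    quantizer=dict(
        type='mmrazor.OpenVINOQuantizer',
        # 'is_qat' no longer appears here after this commit.
        global_qconfig=global_qconfig,
        tracer=dict(type='mmrazor.CustomTracer')))
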
3 changes: 3 additions & 0 deletions mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -178,6 +178,9 @@ def _build_qmodels(self, model: BaseModel):
             observed_module = self.quantizer.prepare(model, concrete_args)
             qmodels[mode] = observed_module

+        # data_samples can not be None in detectors during prediction.
+        # But we need to make the dummy prediction in _build_qmodels.
+        # It is more convenient to use `tensor` mode.
         is_training = qmodels['tensor'].training
         # Avoid random input changing bn's statistics
         qmodels['tensor'].eval()

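The three added comment lines document why `_build_qmodels` does its sanity-check forward pass in `tensor` mode: detectors require real `data_samples` in `predict` mode, so a plain tensor forward is the easiest way to exercise the traced graph. A hedged sketch of that surrounding logic (the dummy input shape and the restore step below are assumptions, not part of the diff):

import torch

def dummy_forward_in_tensor_mode(qmodels, example_input=None):
    # Hedged sketch: exercise the traced 'tensor'-mode graph once without
    # needing data_samples and without disturbing BatchNorm statistics.
    if example_input is None:
        example_input = torch.randn(1, 3, 224, 224)  # assumed input shape
    is_training = qmodels['tensor'].training
    # Avoid random input changing bn's statistics.
    qmodels['tensor'].eval()
    with torch.no_grad():
        qmodels['tensor'](example_input)
    # Restore the original train/eval state afterwards.
    qmodels['tensor'].train(is_training)
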
3 changes: 2 additions & 1 deletion mmrazor/models/quantizers/academic_quantizer.py
@@ -104,7 +104,8 @@ def prepare(self, model, concrete_args=None):
         fuse_custom_config = FuseCustomConfig().set_preserved_attributes(
             preserved_attributes)

-        self.sync_module_training_mode(graph_module)
+        # set the training modes of all modules to True to `_fuse_fx` correctly
+        self.sync_module_training_mode(graph_module, mode=True)

         graph_module = _fuse_fx(
             graph_module=graph_module,

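Both this file and `native_quantizer.py` below now call `sync_module_training_mode` with an explicit `mode=True` before fusion. A rough sketch of what such a helper could look like, assuming it only forces every submodule into the same train/eval state so that `_fuse_fx` sees a consistent graph (this body is an assumption, not the mmrazor implementation):

import torch.nn as nn

def sync_module_training_mode(graph_module: nn.Module, mode: bool = True) -> None:
    # Assumed behaviour: put every submodule into the same training mode so
    # that BN/conv fusion inside `_fuse_fx` behaves consistently.
    for module in graph_module.modules():
        module.training = mode
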
16 changes: 12 additions & 4 deletions mmrazor/models/quantizers/native_quantizer.py
@@ -123,15 +123,21 @@ def __init__(self,
         assert w_mode in self.support_w_modes
         assert a_mode in self.support_a_modes

-        self.qconfig_mapping = self.get_qconfig_mapping(no_observer_modules)
+        self.qconfig_mapping = self.gen_qconfig_mapping(
+            self.qconfig, no_observer_modules)

         self.backend_config = BackendConfigs[self.backend]
         self.example_inputs = (torch.randn(1, 3, 224, 224), )

         self.extra_redundant_fakequants = extra_redundant_fakequants

-    def get_qconfig_mapping(self, no_observer_modules):
-        qconfig_mapping = QConfigMapping().set_global(self.qconfig.convert())
+    def gen_qconfig_mapping(self, qconfig, no_observer_modules):
+        """Convert qconfig in config file to `QConfigMapping`.
+
+        `QConfigMapping` is a custom class for mapping from model ops to
+        :class:`torch.ao.quantization.QConfig` s.
+        """
+        qconfig_mapping = QConfigMapping().set_global(qconfig.convert())

         if no_observer_modules is not None:
             no_observer_modules = str2class(no_observer_modules)

@@ -197,7 +203,9 @@ def prepare(self, model, concrete_args=None):
         traced_graph = self.tracer.trace(model, concrete_args=concrete_args)
         graph_module = build_graphmodule(model, traced_graph)

-        self.sync_module_training_mode(graph_module)
+        # set the training modes of all modules to True to `_fuse_fx` correctly
+        self.sync_module_training_mode(graph_module, mode=True)
+
         graph_module = _fuse_fx(
             graph_module=graph_module,
             is_qat=True,

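The new `gen_qconfig_mapping` takes the qconfig explicitly and documents that it produces a `QConfigMapping`. For reference, a hedged sketch of how such a mapping can be assembled with PyTorch's FX quantization API (the qconfig choice and the no-observer module type below are illustrative, not mmrazor's defaults):

import torch.nn as nn
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.qconfig_mapping import QConfigMapping

# Start from a global qconfig, then turn observation off for selected module
# types by mapping them to None, which is the role `no_observer_modules`
# plays in the quantizer above.
qconfig = get_default_qconfig('fbgemm')   # illustrative qconfig
no_observer_modules = [nn.Upsample]       # illustrative module list

qconfig_mapping = QConfigMapping().set_global(qconfig)
for module_type in no_observer_modules:
    qconfig_mapping = qconfig_mapping.set_object_type(module_type, None)
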
17 changes: 0 additions & 17 deletions tests/test_runners/test_quantization_loop.py
@@ -39,8 +39,6 @@
     FloatFunctional = get_placeholder('torch>=1.13')
     FXFloatFunctional = get_placeholder('torch>=1.13')

-from mmrazor import digit_version
-

 class ToyDataset(Dataset):
     METAINFO = dict()  # type: ignore

@@ -68,21 +66,6 @@ def calibrate_step(self, data):
         data = self.data_preprocessor(data, False)
         return self.architecture(**data)

-    def swap_ff_with_fxff(self, model):
-        if digit_version(torch.__version__) < digit_version('1.13.0'):
-            self.skipTest('version of torch < 1.13.0')
-
-        modules_to_swap = []
-        for name, module in model.named_children():
-            if isinstance(module, FloatFunctional):
-                modules_to_swap.append(name)
-            else:
-                self.swap_ff_with_fxff(module)
-
-        for name in modules_to_swap:
-            del model._modules[name]
-            model._modules[name] = FXFloatFunctional()
-
     def sync_qparams(self, src_mode):
         pass
