diff --git a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py
index e0f2128cc..7c919c0fd 100644
--- a/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py
+++ b/configs/quantization/ptq/ptq_openvino_mbv2_8xb32_in1k_calib32xb32.py
@@ -34,7 +34,6 @@
     float_checkpoint=float_checkpoint,
     quantizer=dict(
         type='mmrazor.OpenVINOQuantizer',
-        is_qat=False,
         global_qconfig=global_qconfig,
         tracer=dict(
             type='mmrazor.CustomTracer',
diff --git a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py
index 84d757552..125f46367 100644
--- a/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py
+++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb32_in1k_calib32xb32.py
@@ -36,7 +36,6 @@
     float_checkpoint=float_checkpoint,
     quantizer=dict(
         type='mmrazor.OpenVINOQuantizer',
-        is_qat=False,
         global_qconfig=global_qconfig,
         tracer=dict(
             type='mmrazor.CustomTracer',
diff --git a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py
index 03f8a4e22..f629337ed 100644
--- a/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py
+++ b/configs/quantization/ptq/ptq_openvino_resnet50_8xb32_in1k_calib32xb32.py
@@ -36,7 +36,6 @@
     float_checkpoint=float_checkpoint,
     quantizer=dict(
         type='mmrazor.OpenVINOQuantizer',
-        is_qat=False,
         global_qconfig=global_qconfig,
         tracer=dict(
             type='mmrazor.CustomTracer',
diff --git a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py
index 59fb9f9df..578f5fe84 100644
--- a/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py
+++ b/configs/quantization/ptq/ptq_openvino_retina_r50_1x_coco_calib32xb32.py
@@ -35,7 +35,6 @@
     float_checkpoint=float_checkpoint,
     quantizer=dict(
         type='mmrazor.OpenVINOQuantizer',
-        is_qat=False,
         global_qconfig=global_qconfig,
         tracer=dict(
             type='mmrazor.CustomTracer',
diff --git a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py
index 3a8a65bb8..0b79232f8 100644
--- a/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py
+++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb32_in1k.py
@@ -29,7 +29,6 @@
     float_checkpoint=float_checkpoint,
     quantizer=dict(
         type='mmrazor.OpenVINOQuantizer',
-        is_qat=True,
         global_qconfig=global_qconfig,
         tracer=dict(
             type='mmrazor.CustomTracer',
diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py
index 767d7c4ce..06580cbb3 100644
--- a/mmrazor/models/algorithms/quantization/mm_architecture.py
+++ b/mmrazor/models/algorithms/quantization/mm_architecture.py
@@ -178,6 +178,9 @@ def _build_qmodels(self, model: BaseModel):
             observed_module = self.quantizer.prepare(model, concrete_args)
             qmodels[mode] = observed_module

+        # `data_samples` cannot be None in detectors during prediction, but
+        # we need to make a dummy prediction in `_build_qmodels`, so it is
+        # more convenient to use `tensor` mode here.
         is_training = qmodels['tensor'].training
         # Avoid random input changing bn's statistics
         qmodels['tensor'].eval()
diff --git a/mmrazor/models/quantizers/academic_quantizer.py b/mmrazor/models/quantizers/academic_quantizer.py
index c8824e512..2d56be6c5 100644
--- a/mmrazor/models/quantizers/academic_quantizer.py
+++ b/mmrazor/models/quantizers/academic_quantizer.py
@@ -104,7 +104,8 @@ def prepare(self, model, concrete_args=None):
         fuse_custom_config = FuseCustomConfig().set_preserved_attributes(
             preserved_attributes)

-        self.sync_module_training_mode(graph_module)
+        # set all modules to training mode so that `_fuse_fx` works correctly
+        self.sync_module_training_mode(graph_module, mode=True)

         graph_module = _fuse_fx(
             graph_module=graph_module,
diff --git a/mmrazor/models/quantizers/native_quantizer.py b/mmrazor/models/quantizers/native_quantizer.py
index 102a6710f..1fe620b7e 100644
--- a/mmrazor/models/quantizers/native_quantizer.py
+++ b/mmrazor/models/quantizers/native_quantizer.py
@@ -123,15 +123,21 @@ def __init__(self,
         assert w_mode in self.support_w_modes
         assert a_mode in self.support_a_modes

-        self.qconfig_mapping = self.get_qconfig_mapping(no_observer_modules)
+        self.qconfig_mapping = self.gen_qconfig_mapping(
+            self.qconfig, no_observer_modules)
         self.backend_config = BackendConfigs[self.backend]
         self.example_inputs = (torch.randn(1, 3, 224, 224), )

         self.extra_redundant_fakequants = extra_redundant_fakequants

-    def get_qconfig_mapping(self, no_observer_modules):
-        qconfig_mapping = QConfigMapping().set_global(self.qconfig.convert())
+    def gen_qconfig_mapping(self, qconfig, no_observer_modules):
+        """Convert the qconfig in the config file to a `QConfigMapping`.
+
+        `QConfigMapping` maps model ops to
+        :class:`torch.ao.quantization.QConfig` instances.
+        """
+        qconfig_mapping = QConfigMapping().set_global(qconfig.convert())

         if no_observer_modules is not None:
             no_observer_modules = str2class(no_observer_modules)
@@ -197,7 +203,9 @@ def prepare(self, model, concrete_args=None):
         traced_graph = self.tracer.trace(model, concrete_args=concrete_args)
         graph_module = build_graphmodule(model, traced_graph)

-        self.sync_module_training_mode(graph_module)
+        # set all modules to training mode so that `_fuse_fx` works correctly
+        self.sync_module_training_mode(graph_module, mode=True)
+
         graph_module = _fuse_fx(
             graph_module=graph_module,
             is_qat=True,
diff --git a/tests/test_runners/test_quantization_loop.py b/tests/test_runners/test_quantization_loop.py
index 0ddc578bf..bafeb203e 100644
--- a/tests/test_runners/test_quantization_loop.py
+++ b/tests/test_runners/test_quantization_loop.py
@@ -39,8 +39,6 @@
     FloatFunctional = get_placeholder('torch>=1.13')
     FXFloatFunctional = get_placeholder('torch>=1.13')

-from mmrazor import digit_version
-

 class ToyDataset(Dataset):
     METAINFO = dict()  # type: ignore
@@ -68,21 +66,6 @@ def calibrate_step(self, data):
         data = self.data_preprocessor(data, False)
         return self.architecture(**data)

-    def swap_ff_with_fxff(self, model):
-        if digit_version(torch.__version__) < digit_version('1.13.0'):
-            self.skipTest('version of torch < 1.13.0')
-
-        modules_to_swap = []
-        for name, module in model.named_children():
-            if isinstance(module, FloatFunctional):
-                modules_to_swap.append(name)
-            else:
-                self.swap_ff_with_fxff(module)
-
-        for name in modules_to_swap:
-            del model._modules[name]
-            model._modules[name] = FXFloatFunctional()
-
     def sync_qparams(self, src_mode):
         pass
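
As a rough sketch of what the renamed `gen_qconfig_mapping` builds (not part of the patch itself): it produces a `torch.ao.quantization.QConfigMapping` with one global qconfig and with quantization disabled for any module types listed in `no_observer_modules`. In the snippet below, `get_default_qconfig('fbgemm')` and `torch.nn.Upsample` are stand-in values for illustration; in mmrazor they would come from `global_qconfig` and `no_observer_modules` in the config file.

    # Minimal sketch of the mapping gen_qconfig_mapping produces: a global
    # QConfig applied to every op, with selected module types excluded.
    import torch
    from torch.ao.quantization import QConfigMapping, get_default_qconfig

    qconfig = get_default_qconfig('fbgemm')  # stand-in for qconfig.convert()
    qconfig_mapping = QConfigMapping().set_global(qconfig)

    # Mirrors the no_observer_modules branch: mapping a module type to None
    # means no observers/fake-quants are inserted for that op type.
    no_observer_modules = [torch.nn.Upsample]  # hypothetical example list
    for mod in no_observer_modules:
        qconfig_mapping.set_object_type(mod, None)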