diff --git a/configs/quantization/ptq/demo.py b/configs/quantization/ptq/demo.py deleted file mode 100644 index af6a0a5df..000000000 --- a/configs/quantization/ptq/demo.py +++ /dev/null @@ -1 +0,0 @@ -_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] diff --git a/configs/quantization/ptq/adaround.py b/configs/quantization/ptq/ptq_openvino_resnet18_8xb16_cifar10.py similarity index 68% rename from configs/quantization/ptq/adaround.py rename to configs/quantization/ptq/ptq_openvino_resnet18_8xb16_cifar10.py index 78157c61a..bb6dbc778 100644 --- a/configs/quantization/ptq/adaround.py +++ b/configs/quantization/ptq/ptq_openvino_resnet18_8xb16_cifar10.py @@ -2,22 +2,14 @@ test_cfg = dict( type='mmrazor.PTQLoop', - - # reconstruction_cfg=dict( - # pattern='layer', - # loss=dict( - # type='mmrazor.AdaRoundLoss', - # iters=20000 - # ) - # ) ) model = dict( _delete_=True, - type='mmrazor.GeneralQuant', + type='mmrazor.MMArchitectureQuant', architecture=_base_.model, quantizer=dict( - type='mmrazor.CustomQuantizer', + type='mmrazor.OpenvinoQuantizer', is_qat=False, skipped_methods=[ 'mmcls.models.heads.ClsHead._get_loss', @@ -27,16 +19,16 @@ qtype='affine', w_observer=dict(type='mmrazor.MSEObserver'), a_observer=dict(type='mmrazor.EMAMSEObserver'), - w_fake_quant=dict(type='mmrazor.AdaRoundFakeQuantize'), + w_fake_quant=dict(type='mmrazor.FakeQuantize'), a_fake_quant=dict(type='mmrazor.FakeQuantize'), w_qscheme=dict( - bit=2, - is_symmetry=False, + bit=8, + is_symmetry=True, is_per_channel=True, is_pot_scale=False, ), a_qscheme=dict( - bit=4, + bit=8, is_symmetry=False, is_per_channel=False, is_pot_scale=False), diff --git a/configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py b/configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py similarity index 68% rename from configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py rename to configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py index 412a6fd87..8076769a9 100644 --- a/configs/quantization/qat/lsq_resnet18_8xb16_cifar10.py +++ b/configs/quantization/qat/lsq_openvino_resnet18_8xb16_cifar10.py @@ -1,24 +1,16 @@ _base_ = ['mmcls::resnet/resnet18_8xb16_cifar10.py'] resnet = _base_.model -pretrained_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth' # noqa: E501 +float_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_b16x8_cifar10_20210528-bd6371c8.pth' # noqa: E501 model = dict( _delete_=True, _scope_='mmrazor', - type='GeneralQuant', - data_preprocessor=dict( - type='mmcls.ClsDataPreprocessor', - num_classes=10, - # RGB format normalization parameters - mean=[125.307, 122.961, 113.8575], - std=[51.5865, 50.847, 51.255], - # loaded images are already RGB format - to_rgb=False), + type='MMArchitectureQuant', architecture=resnet, - pretrained_ckpt=pretrained_ckpt, + float_checkpoint=float_ckpt, quantizer=dict( - type='CustomQuantizer', + type='OpenvinoQuantizer', skipped_methods=[ 'mmcls.models.heads.ClsHead._get_loss', 'mmcls.models.heads.ClsHead._get_predictions' @@ -31,8 +23,8 @@ a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), w_qscheme=dict( bit=8, - is_symmetry=False, - is_per_channel=False, + is_symmetry=True, + is_per_channel=True, is_pot_scale=False, ), a_qscheme=dict( @@ -55,7 +47,7 @@ end=100) model_wrapper_cfg = dict( - type='mmrazor.GeneralQuantDDP', + type='mmrazor.MMArchitectureQuantDDP', broadcast_buffers=False, find_unused_parameters=True) @@ -63,8 +55,7 @@ train_cfg = dict( _delete_=True, type='mmrazor.QATEpochBasedLoop', 
- by_epoch=True, max_epochs=100, val_interval=1) val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') -test_cfg = val_cfg +# test_cfg = val_cfg diff --git a/configs/quantization/qat/lsq_resnet18_8xb32_in1k.py b/configs/quantization/qat/lsq_resnet18_8xb32_in1k.py deleted file mode 100644 index a0885a52a..000000000 --- a/configs/quantization/qat/lsq_resnet18_8xb32_in1k.py +++ /dev/null @@ -1,75 +0,0 @@ -_base_ = ['mmcls::resnet/resnet18_8xb32_in1k.py'] - -train_cfg = dict( - _delete_=True, - type='mmrazor.QATEpochBasedLoop', - max_epochs=_base_.train_cfg.max_epochs) - -resnet = _base_.model -ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth' # noqa: E501 -resnet.init_cfg = dict(type='Pretrained', checkpoint=ckpt) - -model = dict( - _delete_=True, - _scope_='mmrazor', - type='GeneralQuant', - # data_preprocessor = dict( - # num_classes=1000, - # # RGB format normalization parameters - # mean=[123.675, 116.28, 103.53], - # std=[58.395, 57.12, 57.375], - # # convert image from BGR to RGB - # to_rgb=True, - # ), - architecture=resnet, - quantizer=dict( - type='CustomQuantizer', - skipped_methods=[ - 'mmcls.models.heads.ClsHead._get_loss', - 'mmcls.models.heads.ClsHead._get_predictions' - ], - qconfig=dict( - qtype='affine', - w_observer=dict(type='mmrazor.MinMaxObserver'), - a_observer=dict(type='mmrazor.EMAMinMaxObserver'), - w_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - a_fake_quant=dict(type='mmrazor.LearnableFakeQuantize'), - w_qscheme=dict( - bit=8, - is_symmetry=False, - is_per_channel=False, - is_pot_scale=False, - ), - a_qscheme=dict( - bit=8, - is_symmetry=False, - is_per_channel=False, - is_pot_scale=False), - ))) - -optim_wrapper = dict( - optimizer=dict(type='SGD', lr=0.004, momentum=0.9, weight_decay=0.0001)) - -# learning policy -param_scheduler = dict( - _delete_=True, - type='CosineAnnealingLR', - T_max=100, - by_epoch=True, - begin=0, - end=100) - -default_hooks = dict( - checkpoint=dict( - type='CheckpointHook', - interval=5, - max_keep_ckpts=3, - out_dir='/mnt/petrelfs/caoweihan.p/training_ckpt/quant')) - -model_wrapper_cfg = dict( - type='mmrazor.GeneralQuantDDP', - broadcast_buffers=False, - find_unused_parameters=False) - -val_cfg = dict(_delete_=True, type='mmrazor.QATValLoop') -test_cfg = val_cfg diff --git a/mmrazor/engine/runner/quantization_loops.py b/mmrazor/engine/runner/quantization_loops.py index a2d5d383b..bca61a563 100644 --- a/mmrazor/engine/runner/quantization_loops.py +++ b/mmrazor/engine/runner/quantization_loops.py @@ -1,24 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import copy -import os from typing import Dict, List, Optional, Sequence, Tuple, Union -import numpy as np import torch from mmengine.evaluator import Evaluator -from mmengine.registry import MODELS -from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop, autocast -from torch.ao.quantization import disable_observer +from mmengine.runner import EpochBasedTrainLoop, TestLoop, ValLoop +from torch.ao.quantization import (disable_observer, enable_fake_quant, + enable_observer) from torch.nn.intrinsic.qat import freeze_bn_stats from torch.utils.data import DataLoader -from mmrazor.models.task_modules import (ModuleInputsRecorder, - ModuleOutputsRecorder, - RecorderManager) from mmrazor.registry import LOOPS -from .utils import extract_blocks, extract_layers, extract_subgraph - -_ADAROUND_SUPPORT_TYPE = (torch.nn.Conv2d, torch.nn.Linear) @LOOPS.register_module() @@ -59,18 +50,14 @@ def __init__( self.disable_observer_begin = disable_observer_begin self.freeze_bn_begin = freeze_bn_begin - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - qat_metrics = dict() - for key, value in metrics.items(): - qat_key = 'qat.' + key - ori_key = 'original.' + key - qat_metrics[qat_key] = value - self.runner.message_hub.log_scalars.pop(f'val/{ori_key}', None) + def run(self) -> torch.nn.Module: + """Launch training.""" + self.runner.call_hook('before_train') while self._epoch < self._max_epochs: # state: observer_enabled, fakequant_enabled - self.runner.model.state = (True, True) + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(enable_observer) self.run_epoch() self._decide_current_val_interval() @@ -78,8 +65,8 @@ def __init__( and self._epoch >= self.val_begin and self._epoch % self.val_interval == 0): # observer disabled during evaluation - self.runner.model.state = (False, True) - self.runner.model.sync_param() + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(disable_observer) self.runner.val_loop.run() self.runner.call_hook('after_train') @@ -185,8 +172,8 @@ def run_iter(self, idx, data_batch: Sequence[dict], model): self.runner.call_hook( 'before_val_iter', batch_idx=idx, data_batch=data_batch) # outputs should be sequence of BaseDataElement - with autocast(enabled=self.fp16): - outputs = model.val_step(data_batch) + + outputs = model.val_step(data_batch) self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( 'after_val_iter', @@ -219,15 +206,13 @@ def run(self) -> dict: self.runner.call_hook('before_test') self.runner.call_hook('before_test_epoch') self.runner.model.eval() - self.runner.model.state = (True, False) + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(enable_observer) for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - - self.runner.call_hook('after_test_epoch', metrics=metrics) + self.runner.call_hook('after_test_epoch', metrics=None) self.runner.call_hook('after_test') # todo: hard code to save checkpoint on disk @@ -238,80 +223,10 @@ def run(self) -> dict: save_optimizer=False, save_param_scheduler=False) - return metrics - - @torch.no_grad() - def run_iter(self, idx, data_batch: Sequence[dict]) -> None: - """Iterate one mini-batch. - - Args: - data_batch (Sequence[dict]): Batch of data from dataloader. 
- """ - self.runner.call_hook( - 'before_test_iter', batch_idx=idx, data_batch=data_batch) - # predictions should be sequence of BaseDataElement - - outputs = self.runner.model.calibrate_step(data_batch) - - self.runner.call_hook( - 'after_test_iter', - batch_idx=idx, - data_batch=data_batch, - outputs=outputs) - - -# TODO refactor to supoort DDP -@LOOPS.register_module() -class AdaRoundLoop(TestLoop): - """`TestLoop` for Post Training Quantization. + self.runner.model.apply(enable_fake_quant) + self.runner.model.apply(disable_observer) - Args: - runner (Runner): A reference of runner - dataloader (Dataloader or dict): An iterator to generate one batch of - dataset each iteration. - evaluator (Evaluator or dict or list): Used for computing metrics. - calibrate_dataloader (Dataloader or dict, optional): A dataloader - object or a dict to build a dataloader for calibration. Defaults - to None. - batch_num (Optional[int], optional): Total calibration batches. - Defaults to None. - reconstruction_cfg (Optional[Dict], optional): Model reconstruction - configuration. Defaults to None. - fp16 (bool, optional): Enable FP16 training mode. Defaults to False. - """ - - def __init__(self, - runner, - dataloader: Union[DataLoader, Dict], - evaluator: Union[Evaluator, Dict, List], - fp16: bool = False): - super().__init__(runner, dataloader, evaluator, fp16) - - def run(self) -> None: - """Launch test.""" - self.runner.call_hook('before_test') - self.runner.call_hook('before_test_epoch') - self.runner.model.eval() - self.runner.model.state = (1, 0) - - for idx, data_batch in enumerate(self.dataloader): - self.run_iter(idx, data_batch) - - # compute metrics - metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) - - self.runner.call_hook('after_test_epoch', metrics=metrics) - self.runner.call_hook('after_test') - - # todo: hard code to save checkpoint on disk - self.runner.save_checkpoint( - self.runner.work_dir, - 'checkpoint_after_ptq.pth', - file_client_args=None, - save_optimizer=False, - save_param_scheduler=False) - - return metrics + return self.runner.val_loop.run() @torch.no_grad() def run_iter(self, idx, data_batch: Sequence[dict]) -> None: @@ -322,208 +237,11 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: """ self.runner.call_hook( 'before_test_iter', batch_idx=idx, data_batch=data_batch) - # predictions should be sequence of BaseDataElement - outputs = self.runner.model.calibrate_step(data_batch) + _ = self.runner.model.calibrate_step(data_batch) self.runner.call_hook( 'after_test_iter', batch_idx=idx, data_batch=data_batch, - outputs=outputs) - - -# TODO refactor to supoort DDP -@LOOPS.register_module() -class AdaRoundLoop(TestLoop): - """`TestLoop` for Post Training Quantization. - - Args: - runner (Runner): A reference of runner - dataloader (Dataloader or dict): An iterator to generate one batch of - dataset each iteration. - evaluator (Evaluator or dict or list): Used for computing metrics. - calibrate_dataloader (Dataloader or dict, optional): A dataloader - object or a dict to build a dataloader for calibration. Defaults - to None. - batch_num (Optional[int], optional): Total calibration batches. - Defaults to None. - reconstruction_cfg (Optional[Dict], optional): Model reconstruction - configuration. Defaults to None. - fp16 (bool, optional): Enable FP16 training mode. Defaults to False. 
- """ - - def __init__(self, - runner, - dataloader: Union[DataLoader, Dict], - evaluator: Union[Evaluator, Dict, List], - calibrate_dataloader: Optional[Union[DataLoader, - Dict]] = None, - batch_num: Optional[int] = None, - reconstruction_cfg: Optional[Dict] = None, - fp16: bool = False): - super().__init__(runner, dataloader, evaluator, fp16) - if isinstance(calibrate_dataloader, dict): - # Determine whether or not different ranks use different seed. - diff_rank_seed = runner._randomness_cfg.get( - 'diff_rank_seed', False) - self.calibrate_dataloader = runner.build_dataloader( - calibrate_dataloader, - seed=runner.seed, - diff_rank_seed=diff_rank_seed) - else: - self.calibrate_dataloader = calibrate_dataloader - - self.is_calibrate = True if calibrate_dataloader is not None else False - - if self.runner.distributed: - self.model = runner.model.module - else: - self.model = runner.model - - self.batch_num = batch_num - self.config = reconstruction_cfg - - def calibrate(self, calibrate_dataloader) -> None: - self.model.eval() - with torch.no_grad(): - for i, batch_data in enumerate(calibrate_dataloader): - if self.batch_num and i >= self.batch_num: - break - self.model.calib_step(batch_data) - - def _save_inter_result(self, - model, - dataloader, - slices, - store_input=True, - store_output=True): - recorders = {} - for s in slices: - node_l, node_r = s[:2] - if store_input: - recorders[node_l.target + '_input'] = ModuleInputsRecorder( - node_l.target) - if store_output: - recorders[node_r.target + '_output'] = ModuleOutputsRecorder( - node_r.target) - manager = RecorderManager(recorders) - manager.initialize(model) - - with torch.no_grad(): - with manager: - for i, batch_data in enumerate(dataloader): - if self.batch_num and i >= self.batch_num: - break - batch_data = self.model.data_preprocessor( - batch_data, False) - model(**batch_data) - return manager - - def sub_reconstruction(self, graphmodule, input_recorder, output_recorder, - config): - w_para = [] - for layer in graphmodule.modules(): - # import pdb - # pdb.set_trace() - if isinstance(layer, _ADAROUND_SUPPORT_TYPE): - weight_fake_quant = layer.weight_fake_quant - weight_fake_quant.init(layer.weight.data) - w_para += [weight_fake_quant.alpha] - - w_opt = torch.optim.Adam(w_para) - loss_func = MODELS.build(config.loss) - - for _ in range(config.loss.iters): - w_opt.zero_grad() - - data_size = len(input_recorder.data_buffer) - data_index = np.random.randint(0, data_size) - out_quant = graphmodule( - input_recorder.get_recorder_data(data_index)) - out_fp = output_recorder.get_recorder_data(data_index) - err = loss_func(graphmodule, out_quant, out_fp) - err.backward() - w_opt.step() - - for layer in graphmodule.modules(): - if isinstance(layer, _ADAROUND_SUPPORT_TYPE): - weight_fake_quant = layer.weight_fake_quant - layer.weight.data = weight_fake_quant.get_hard_value( - layer.weight.data) - weight_fake_quant.adaround = False - if isinstance(layer, torch.quantization.FakeQuantize) and hasattr( - layer, 'prob'): - # recover to promise that drop activation quantization only - # occurs at reconstruction phase - layer.prob = 1.0 - - def reconstruction(self, graphmodule, calibrate_dataloader, config): - assert isinstance(graphmodule, torch.fx.GraphModule) - graphmodule_fp = graphmodule - graphmodule_quant = copy.deepcopy(graphmodule) - - # get layers/blocks need to reconstructe - slices = [] - if config.pattern == 'layer': - slices = extract_layers( - graphmodule, layer_types=_ADAROUND_SUPPORT_TYPE) - elif config.pattern == 'block': - 
slices = extract_blocks(graphmodule) - else: - # TODO: add remind - raise NotImplementedError - - # save fp inputs and outputs of each layers - manager_fp = self._save_inter_result(graphmodule_fp, - self.calibrate_dataloader, slices) - - # extract subgraph_module - for s in slices: - sub_graphmodule = extract_subgraph(graphmodule_quant, s) - manager_quant = self._save_inter_result( - graphmodule_quant, - self.calibrate_dataloader, [s], - store_output=False) - input_index = s[0].target + '_input' - output_index = s[1].target + '_output' - input_recorder = manager_quant.get_recorder(input_index) - output_recorder = manager_fp.get_recorder(output_index) - self.sub_reconstruction(sub_graphmodule, input_recorder, - output_recorder, config) - - return graphmodule_quant - - def run(self) -> None: - """Launch test.""" - self.runner.call_hook('before_test') - self.runner.call_hook('before_test_epoch') - - self.model.eval() - self.model.prepare() - - if self.is_calibrate: - self.model.state = (1, 0) - self.calibrate(self.calibrate_dataloader) - - self.model.state = (1, 1) - - if self.config is not None: - self.model.architecture = self.reconstruction( - self.model.architecture, self.calibrate_dataloader, - self.config) - - self.model.convert() - - self.model.eval() - from torch.onnx import OperatorExportTypes - dummy_input = torch.randn([1, 3, 224, 224]) - onnx_path = os.path.join(self.runner.work_dir, 'quantizied.onnx') - torch.onnx.export( - self.model.architecture, - dummy_input, - onnx_path, - opset_version=11, - operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK) - - self.runner.call_hook('after_test') + outputs=None) diff --git a/mmrazor/models/algorithms/__init__.py b/mmrazor/models/algorithms/__init__.py index 0d694f203..214b9212c 100644 --- a/mmrazor/models/algorithms/__init__.py +++ b/mmrazor/models/algorithms/__init__.py @@ -6,12 +6,13 @@ from .nas import DSNAS, DSNASDDP, SPOS, AutoSlim, AutoSlimDDP, Darts, DartsDDP from .pruning import SlimmableNetwork, SlimmableNetworkDDP from .pruning.ite_prune_algorithm import ItePruneAlgorithm -from .quantization import GeneralQuant +from .quantization import MMArchitectureQuant, MMArchitectureQuantDDP __all__ = [ 'SingleTeacherDistill', 'BaseAlgorithm', 'FpnTeacherDistill', 'SPOS', 'SlimmableNetwork', 'SlimmableNetworkDDP', 'AutoSlim', 'AutoSlimDDP', 'Darts', 'DartsDDP', 'SelfDistill', 'DataFreeDistillation', 'DAFLDataFreeDistillation', 'OverhaulFeatureDistillation', - 'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'GeneralQuant' + 'ItePruneAlgorithm', 'DSNAS', 'DSNASDDP', 'MMArchitectureQuant', + 'MMArchitectureQuantDDP' ] diff --git a/mmrazor/models/algorithms/quantization/__init__.py b/mmrazor/models/algorithms/quantization/__init__.py index 84c25bbc0..337717c01 100644 --- a/mmrazor/models/algorithms/quantization/__init__.py +++ b/mmrazor/models/algorithms/quantization/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base import GeneralQuant +from .base import MMArchitectureQuant, MMArchitectureQuantDDP -__all__ = ['GeneralQuant'] +__all__ = ['MMArchitectureQuant', 'MMArchitectureQuantDDP'] diff --git a/mmrazor/models/algorithms/quantization/base.py b/mmrazor/models/algorithms/quantization/base.py index c97d832ff..ad3c16340 100644 --- a/mmrazor/models/algorithms/quantization/base.py +++ b/mmrazor/models/algorithms/quantization/base.py @@ -20,7 +20,7 @@ @MODELS.register_module() -class GeneralQuant(BaseAlgorithm): +class MMArchitectureQuant(BaseAlgorithm): """General quantization. 
Args: @@ -39,13 +39,15 @@ class GeneralQuant(BaseAlgorithm): :class:`BaseModule`. """ + + def __init__(self, architecture, quantizer, - export_mode: str = 'predict', - qmodel_modes: List[str] = ['tensor', 'predict', 'loss'], data_preprocessor=None, - pretrained_ckpt: Optional[str] = None, + forward_modes=('tensor', 'predict', 'loss'), + float_checkpoint: Optional[str] = None, + input_shapes=(1, 3, 224, 224), init_cfg=None): if data_preprocessor is None: @@ -53,35 +55,50 @@ def __init__(self, # The build process is in MMEngine, so we need to add scope here. data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor') super().__init__(architecture, data_preprocessor, init_cfg) - if pretrained_ckpt: - _ = load_checkpoint(self.architecture, pretrained_ckpt) + if float_checkpoint: + _ = load_checkpoint(self.architecture, float_checkpoint) self.architecture._is_init = True self.quantizer = MODELS.build(quantizer) - self._observers_enabled = True - self._fake_quants_enabled = True - self.export_mode = export_mode - self.qmodel_modes = qmodel_modes + self.input_shapes = input_shapes + self.forward_modes = forward_modes + self.qmodels = self._build_qmodels(self.architecture) - def sync_param(self): + self.sync_param('tensor') + + def sync_param(self, src_mode): def traverse(module, prefix): for name, child in module._modules.items(): if module is None: continue - module_name = f'{prefix}{name}' + child_name = f'{prefix}{name}' if isinstance(child, FakeQuantizeBase): for name, param in child.named_parameters(): - param.data.copy_(self.qmodels['loss'].state_dict() - [f'{module_name}.{name}']) + param_name = f'{child_name}.{name}' + src_param = src_state_dict[param_name] + if src_param.shape == param.shape: + param.data.copy_(src_param) + else: + requires_grad = param.requires_grad + param.requires_grad = False + param.resize_(src_param.shape) + param.requires_grad = requires_grad + param.data.copy_(src_param) for name, buffer in child.named_buffers(): - buffer.data.copy_(self.qmodels['loss'].state_dict() - [f'{module_name}.{name}']) + buffer_name = f'{child_name}.{name}' + src_buffer = src_state_dict[buffer_name] + if src_buffer.shape == buffer.shape: + buffer.data.copy_(src_buffer) + else: + buffer.resize_(src_buffer.shape) + buffer.data.copy_(src_buffer) else: - traverse(child, f'{module_name}.') + traverse(child, f'{child_name}.') - for mode in self.qmodel_modes: - if mode == 'loss': + src_state_dict = self.qmodels[src_mode].state_dict() + for mode in self.forward_modes: + if mode == src_mode: continue traverse(self.qmodels[mode], '') @@ -92,12 +109,17 @@ def _build_qmodels(self, model): self.quantizer._swap_ff_with_fxff(model) tracer = self.quantizer.tracer - for mode in self.qmodel_modes: + for mode in self.forward_modes: concrete_args = {'mode': mode} traced_graph = tracer.trace(model, concrete_args=concrete_args) - qmodel = build_graphmodule(model, traced_graph) - qmodels[mode] = self.quantizer.prepare(model, qmodel) + graph_module = build_graphmodule(model, traced_graph) + observed_module = self.quantizer.prepare(model, graph_module) + + qmodels[mode] = observed_module + + dummy_input = torch.randn(self.input_shapes) + qmodels['predict'](dummy_input, None, 'predict') return qmodels @@ -114,39 +136,11 @@ def forward(self, def calibrate_step(self, data): data = self.data_preprocessor(data, False) - self.state = (1, 0) return self._run_forward(data, mode='tensor') - def convert(self, mode='predict'): - qmodel = self.qmodels[self.export_mode] - self.qmodels[mode] =
self.quantizer.convert(qmodel) - - @property - def state(self): - return (self._observers_enabled, self._fake_quants_enabled) - - @state.setter - def state(self, state: Tuple[bool, bool]): - observers_enabled, fake_quants_enabled = state - qmodel = self.qmodels[self.export_mode] - for submodule in qmodel.modules(): - if isinstance(submodule, torch.quantization.FakeQuantize): - if observers_enabled: - submodule.enable_observer() - else: - submodule.disable_observer() - - if fake_quants_enabled: - submodule.enable_fake_quant() - else: - submodule.disable_fake_quant() - - self._observers_enabled = observers_enabled - self._fake_quants_enabled = fake_quants_enabled - @MODEL_WRAPPERS.register_module() -class GeneralQuantDDP(MMDistributedDataParallel): +class MMArchitectureQuantDDP(MMDistributedDataParallel): """DDPwapper for GeneralQuant.""" def __init__(self, @@ -165,18 +159,5 @@ def __init__(self, def calibrate_step(self, data): return self.module.calibrate_step(data) - @property - def state(self): - return (self.module._observers_enabled, - self.module._fake_quants_enabled) - - @state.setter - def state(self, state: Tuple[bool]): - self.module.state = state - - def convert(self, mode='predict'): - self.module.convert(mode) - self.module.qmodels[mode].cuda() - - def sync_param(self): - self.module.sync_param() + def sync_param(self, src): + self.module.sync_param(src) diff --git a/mmrazor/models/fake_quants/lsq.py b/mmrazor/models/fake_quants/lsq.py index 3e26631eb..a24898442 100644 --- a/mmrazor/models/fake_quants/lsq.py +++ b/mmrazor/models/fake_quants/lsq.py @@ -144,10 +144,5 @@ def forward(ctx, x, scale, zero_point, ch_axis, quant_min, quant_max, @staticmethod def symbolic(g, x, scale, zero_point, ch_axis, quant_min, quant_max, grad_factor): - return g.op( - '::FakeQuantizeLearnablePerchannelAffine', - x, - scale, - zero_point, - quant_min_i=quant_min, - quant_max_i=quant_max) + return g.op('::FakeQuantizeLearnablePerchannelAffine', x, scale, + zero_point, ch_axis, quant_min, quant_max) diff --git a/mmrazor/models/observers/lsq_observer.py b/mmrazor/models/observers/lsq_observer.py index d9b96d7a8..b543efe17 100644 --- a/mmrazor/models/observers/lsq_observer.py +++ b/mmrazor/models/observers/lsq_observer.py @@ -33,7 +33,7 @@ def forward(self, x_orig): x = x_orig.to(self.min_val.dtype) if self.ch_axis == -1: self.tensor_norm = x.abs().mean() - self.min_val, self.max_val = torch._aminmax(x) + self.min_val, self.max_val = torch.aminmax(x) else: # compute channel-wise mean x_dim = x.size() diff --git a/mmrazor/models/observers/mse.py b/mmrazor/models/observers/mse.py index f85abd902..a2b65a3a6 100644 --- a/mmrazor/models/observers/mse.py +++ b/mmrazor/models/observers/mse.py @@ -71,9 +71,8 @@ def mse_perchannel(self, new_max = x_max * (1.0 - (i * 0.01)) scale, zero_point = self._calculate_qparams(new_min, new_max) x_q = torch.fake_quantize_per_channel_affine( - x, scale, - zero_point.long() if _version_under_1100 else zero_point, - ch_axis, self.quant_min, self.quant_max) + x, scale, zero_point.int(), ch_axis, self.quant_min, + self.quant_max) score = self.lp_loss(x_q, x, reduce_dim) update_idx = (score < best_score) best_score[update_idx] = score[update_idx] @@ -87,7 +86,7 @@ def forward(self, x_orig): return x_orig x = x_orig.clone().detach().to(self.min_val.dtype) if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) min_val_cur, max_val_cur = self.mse( x, min_val_cur, max_val_cur, iter=95) else: @@ -131,7 +130,7 @@ def 
forward(self, x_orig): return x_orig x = x_orig.clone().detach().to(self.min_val.dtype) if self.ch_axis == -1: - min_val_cur, max_val_cur = torch._aminmax(x) + min_val_cur, max_val_cur = torch.aminmax(x) min_val_cur, max_val_cur = self.mse( x, min_val_cur, max_val_cur, iter=95) else: diff --git a/mmrazor/models/quantizers/__init__.py b/mmrazor/models/quantizers/__init__.py index e56902eba..2741e2fd1 100644 --- a/mmrazor/models/quantizers/__init__.py +++ b/mmrazor/models/quantizers/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import CustomQuantizer +from .openvino_quantizer import OpenvinoQuantizer from .trt_quantizer import TensorRTQuantizer -__all__ = ['CustomQuantizer', 'TensorRTQuantizer'] +__all__ = ['CustomQuantizer', 'TensorRTQuantizer', 'OpenvinoQuantizer'] diff --git a/mmrazor/models/quantizers/base.py b/mmrazor/models/quantizers/base.py index 6f1fb4e31..2dd3930fc 100644 --- a/mmrazor/models/quantizers/base.py +++ b/mmrazor/models/quantizers/base.py @@ -3,9 +3,11 @@ import torch from mmengine.model import BaseModule -from torch.ao.quantization import QConfig +from torch.ao.quantization import QConfig, enable_fake_quant from torch.ao.quantization.fx import prepare from torch.ao.quantization.quantize_fx import _convert_fx, _fuse_fx +from torch.nn.intrinsic.qat import modules as qat_fused_modules +from torch.nn.qat import modules as qat_modules from mmrazor.models.task_modules.tracer import CustomTracer from mmrazor.models.utils import (check_is_valid_convert_custom_config_dict, @@ -16,6 +18,25 @@ from mmrazor.structures.quantization import (CheckArgs, DefaultQconfigs, QuantizeScheme, SupportQtypes) +SUPPORT_QAT_MODULES = ( + qat_fused_modules.ConvBn1d, qat_fused_modules.ConvBn2d, + qat_fused_modules.ConvBn3d, qat_fused_modules.ConvBnReLU1d, + qat_fused_modules.ConvBnReLU2d, qat_fused_modules.ConvBnReLU3d, + qat_fused_modules.ConvReLU1d, qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvReLU3d, qat_fused_modules.LinearBn1d, + qat_fused_modules.LinearReLU, qat_modules.Conv1d, qat_modules.Conv2d, + qat_modules.Conv3d, qat_modules.Linear) + +MERGE_BN_MAPPINGS = { + qat_fused_modules.ConvBn1d: qat_modules.Conv1d, + qat_fused_modules.ConvBn2d: qat_modules.Conv2d, + qat_fused_modules.ConvBn3d: qat_modules.Conv3d, + qat_fused_modules.ConvBnReLU1d: qat_fused_modules.ConvReLU1d, + qat_fused_modules.ConvBnReLU2d: qat_fused_modules.ConvReLU2d, + qat_fused_modules.ConvBnReLU3d: qat_fused_modules.ConvReLU3d, + qat_fused_modules.LinearBn1d: qat_modules.Linear +} + @MODELS.register_module() class CustomQuantizer(BaseModule): @@ -23,7 +44,6 @@ class CustomQuantizer(BaseModule): Args: qconfig (Dict, optional): QConfig. Defaults to DefaultQconfigs['default']. # noqa: E501 - is_qat (bool, optional): Is QAT ro not. Defaults to True. skipped_methods (List, optional): Skipped methods list for tracer. Defaults to None. 
prepare_custom_config_dict (Dict, optional): `PrepareCustomConfig` @@ -39,7 +59,6 @@ class CustomQuantizer(BaseModule): def __init__(self, qconfig: Dict = DefaultQconfigs['default'], - is_qat: bool = True, skipped_methods: List = None, prepare_custom_config_dict: Dict = None, convert_custom_config_dict: Dict = None, @@ -73,7 +92,6 @@ def __init__(self, self.convert_custom_config_dict) check_is_valid_qconfig_dict(self.equalization_qconfig_dict) - self.is_qat = is_qat self.skipped_methods = skipped_methods self._remove_qconfig = _remove_qconfig self.tracer = self.build_tracer() @@ -90,9 +108,8 @@ def prepare(self, model, graph_module): prepared = prepare( graph_module, self.qconfig_dict, - self.is_qat, + True, self.tracer.node_name_to_scope, - prepare_custom_config_dict=self.prepare_custom_config_dict, equalization_qconfig_dict=self.equalization_qconfig_dict ) # type: ignore[operator] @@ -189,9 +206,41 @@ def build_tracer(self): return tracer def fuse_model(self, graph_module): - if not self.is_qat: - graph_module.eval() - - graph_module = _fuse_fx(graph_module, self.is_qat, + graph_module = _fuse_fx(graph_module, True, self.prepare_custom_config_dict) return graph_module + + def post_process_weight_fakequant(self, + observed_module, + keep_fake_quant=False): + + def traverse(module): + + for name, child in module.named_children(): + if isinstance(child, SUPPORT_QAT_MODULES): + weight_fakequant = child.weight_fake_quant + child.weight.data = weight_fakequant(child.weight.data) + + float_child = child.to_float() + + if keep_fake_quant: + for m in float_child.modules(): + setattr(m, 'qconfig', self.qconfig_dict['']) + + if type(child) in MERGE_BN_MAPPINGS: + cls = MERGE_BN_MAPPINGS[type(child)] + new_child = cls.from_float(float_child) + else: + new_child = child.from_float(float_child) + + new_child.weight_fake_quant(new_child.weight) + else: + new_child = float_child + setattr(module, name, new_child) + else: + traverse(child) + observed_module.apply(enable_fake_quant) + traverse(observed_module) + + def prepare_for_mmdeploy(self, model): + raise NotImplementedError diff --git a/mmrazor/models/quantizers/openvino_quantizer.py b/mmrazor/models/quantizers/openvino_quantizer.py new file mode 100644 index 000000000..7bf067593 --- /dev/null +++ b/mmrazor/models/quantizers/openvino_quantizer.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch +from torch.ao.quantization import disable_observer +from mmrazor.registry import MODELS +from mmrazor.structures.quantization import DefaultQconfigs +from mmrazor.models.task_modules.tracer.fx.custom_tracer import build_graphmodule +from .base import CustomQuantizer + + +@MODELS.register_module() +class OpenvinoQuantizer(CustomQuantizer): + """Quantizer for OpenVINO backend.""" + + support_bits = [8] + support_w_mode = ['per_channel'] + support_a_mode = ['per_tensor'] + + + def __init__(self, + qconfig, + is_qat=True, + skipped_methods=None, + prepare_custom_config_dict=None, + convert_custom_config_dict=None, + equalization_qconfig_dict=None, + _remove_qconfig=True, + init_cfg=None): + super().__init__(qconfig, is_qat, skipped_methods, + prepare_custom_config_dict, + convert_custom_config_dict, equalization_qconfig_dict, + _remove_qconfig, init_cfg) + + + + + def prepare_for_mmdeploy(self, model, dummy_input=None, checkpoint=None): + + + self._swap_ff_with_fxff(model) + graph = self.tracer.trace(model) + graph_module = build_graphmodule(model, graph) + observed_model = self.prepare(model, graph_module) + + + self.post_process_weight_fakequant(observed_model, keep_fake_quant=True) + + if dummy_input is not None: + observed_model(dummy_input) + + observed_model.apply(disable_observer) + + + return observed_model + + + + + + diff --git a/mmrazor/models/quantizers/trt_quantizer.py b/mmrazor/models/quantizers/trt_quantizer.py index 9dbe9f594..f2c46844a 100644 --- a/mmrazor/models/quantizers/trt_quantizer.py +++ b/mmrazor/models/quantizers/trt_quantizer.py @@ -1,4 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.ao.quantization import disable_observer + +from mmrazor.models.task_modules.tracer.fx.custom_tracer import \ + build_graphmodule from mmrazor.registry import MODELS from mmrazor.structures.quantization import DefaultQconfigs from .base import CustomQuantizer @@ -8,6 +13,10 @@ class TensorRTQuantizer(CustomQuantizer): """Quantizer for TensorRT backend.""" + support_bits = [8] + support_w_mode = ['per_channel'] + support_a_mode = ['per_tensor'] + def __init__(self, qconfig=DefaultQconfigs['tensorrt'], is_qat=True, @@ -21,3 +30,19 @@ def __init__(self, prepare_custom_config_dict, convert_custom_config_dict, equalization_qconfig_dict, _remove_qconfig, init_cfg) + + def prepare_for_mmdeploy(self, model, dummy_input=None, checkpoint=None): + + graph = self.tracer.trace(model) + graph_module = build_graphmodule(model, graph) + observed_model = self.prepare(model, graph_module) + + observed_model(torch.randn(1, 3, 224, 224)) + + self.post_process_weight_fakequant(observed_model) + if dummy_input is not None: + observed_model(dummy_input) + + observed_model.apply(disable_observer) + + return observed_model diff --git a/mmrazor/models/task_modules/tracer/__init__.py b/mmrazor/models/task_modules/tracer/__init__.py index dafd98e05..b93f7a876 100644 --- a/mmrazor/models/task_modules/tracer/__init__.py +++ b/mmrazor/models/task_modules/tracer/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved.
from .backward_tracer import BackwardTracer -from .fx import (CustomTracer, UntracedMethodRegistry, custom_symbolic_trace, - prepare_graph_module) +from .fx import (CustomTracer, UntracedMethodRegistry, build_graphmodule, + custom_symbolic_trace) from .loss_calculator import * # noqa: F401,F403 from .parsers import * # noqa: F401,F403 from .path import (Path, PathConcatNode, PathConvNode, PathDepthWiseConvNode, @@ -11,5 +11,5 @@ 'BackwardTracer', 'PathConvNode', 'PathLinearNode', 'PathNormNode', 'PathConcatNode', 'Path', 'PathList', 'PathNode', 'PathDepthWiseConvNode', 'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace', - 'prepare_graph_module' + 'build_graphmodule' ] diff --git a/mmrazor/models/task_modules/tracer/fx/__init__.py b/mmrazor/models/task_modules/tracer/fx/__init__.py index 998e9ffe1..1190f945c 100644 --- a/mmrazor/models/task_modules/tracer/fx/__init__.py +++ b/mmrazor/models/task_modules/tracer/fx/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .custom_tracer import (CustomTracer, UntracedMethodRegistry, - custom_symbolic_trace, build_graphmodule) + build_graphmodule, custom_symbolic_trace) __all__ = [ 'CustomTracer', 'UntracedMethodRegistry', 'custom_symbolic_trace', diff --git a/mmrazor/models/utils/quantization_util.py b/mmrazor/models/utils/quantization_util.py index 376096b67..5593572ce 100644 --- a/mmrazor/models/utils/quantization_util.py +++ b/mmrazor/models/utils/quantization_util.py @@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Set import torch +import torch.distributed as dist class PerChannelLoadHook: @@ -40,25 +41,15 @@ def close(self): self.hook.remove() -USE_LINK = False USE_DDP = False -try: - import spring.linklink as link - assert link.is_initialized() - USE_LINK = True -except (ModuleNotFoundError, AssertionError): - import torch.distributed as dist - if torch.distributed.is_initialized(): - USE_DDP = True +if torch.distributed.is_initialized(): + USE_DDP = True def sync_tensor(tensor): - if USE_LINK: - if tensor.is_cuda is True: - tensor.data = tensor.data / link.get_world_size() - link.allreduce(tensor.data) - elif USE_DDP: + + if USE_DDP: tensor.data = tensor.data / dist.get_world_size() dist.all_reduce(tensor.data) return tensor diff --git a/mmrazor/registry/registry.py b/mmrazor/registry/registry.py index 60b0a1e16..2f6efb185 100644 --- a/mmrazor/registry/registry.py +++ b/mmrazor/registry/registry.py @@ -30,7 +30,6 @@ from mmengine.registry import \ WEIGHT_INITIALIZERS as MMENGINE_WEIGHT_INITIALIZERS from mmengine.registry import Registry, build_from_cfg -from mmengine.runner import load_checkpoint def build_razor_model_from_cfg( @@ -44,26 +43,7 @@ def build_razor_model_from_cfg( model = get_model(**cfg) # type: ignore return model - from mmrazor.structures import load_fix_subnet - model = build_from_cfg(cfg, registry, default_args) - if cfg.get('_fix_subnet_', None): - fix_subnet = cfg.pop('_fix_subnet_') - # model is a mutable model - model = build_from_cfg(cfg, registry, default_args) - # after load_fix_subnet, model is a fixed model - load_fix_subnet(model, fix_subnet) - - if cfg.get('_export_compressed_', False): - - if cfg.get('_compressed_checkpoint_', None): - _ = load_checkpoint(model, cfg.get('_compressed_checkpoint_')) - - from mmrazor.models import GeneralQuant - if isinstance(model, GeneralQuant): - model = model.convert() - - model = model.architecture return model diff --git a/tools/debug.py b/tools/debug.py new file mode 100644 index 000000000..5d594cff8 --- /dev/null +++ 
b/tools/debug.py @@ -0,0 +1,162 @@ +import os +import sys +import time +import numpy as np +from tqdm import tqdm + +import torch +from torch.ao.quantization import get_default_qconfig +from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx +import torch.nn as nn +from torch.utils.data import DataLoader + +import torchvision +from torchvision import datasets +from torchvision.models import resnet18, ResNet18_Weights +import torchvision.transforms as transforms + +# Set up warnings +import warnings +warnings.filterwarnings( + action='ignore', + category=DeprecationWarning, + module=r'.*' +) +warnings.filterwarnings( + action='default', + module=r'torch.ao.quantization' +) + +# Specify random seed for repeatable results +_ = torch.manual_seed(191009) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def evaluate(model, criterion, data_loader, use_cuda=False): + model.eval() + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + cnt = 0 + with torch.no_grad(): + for image, target in data_loader: + if use_cuda: + image = image.cuda() + target = target.cuda() + output = model(image) + loss = criterion(output, target) + cnt += 1 + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + top1.update(acc1[0], image.size(0)) + top5.update(acc5[0], image.size(0)) + print('') + + return top1, top5 + +def print_size_of_model(model): + if isinstance(model, torch.jit.RecursiveScriptModule): + torch.jit.save(model, "temp.p") + else: + torch.jit.save(torch.jit.script(model), "temp.p") + print("Size (MB):", os.path.getsize("temp.p")/1e6) + os.remove("temp.p") + +def prepare_data_loaders(data_path, batch_size=4): + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + trainset = torchvision.datasets.CIFAR10(data_path, train=True, transform=transform) + trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) + testset = torchvision.datasets.CIFAR10(data_path, train=False, transform=transform) + testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2) + + return trainloader, testloader + +data_path = '/nvme/dataset/cifar10/' +saved_model_dir = '../experiments/valid_convert' +os.makedirs(saved_model_dir, exist_ok=True) + +_, data_loader_test = prepare_data_loaders(data_path) +criterion = nn.CrossEntropyLoss() +weights = ResNet18_Weights.DEFAULT +float_model = resnet18(weights=weights).cuda() +float_model.eval() + + +# deepcopy the model since we need to keep the original 
model around +import copy +model_to_quantize = copy.deepcopy(float_model) +model_to_quantize.eval() +qconfig = get_default_qconfig("fbgemm") +qconfig_dict = {"": qconfig} +prepared_model = prepare_fx(model_to_quantize, qconfig_dict) +# print(prepared_model.graph) + +def calibrate(model, data_loader): + model.eval() + with torch.no_grad(): + for image, target in tqdm(data_loader): + model(image.cuda()) +calibrate(prepared_model, data_loader_test) + +quantized_model = convert_fx(prepared_model) +# print(quantized_model) + +print("Size of model before quantization") +print_size_of_model(float_model) +print("Size of model after quantization") +print_size_of_model(quantized_model) +top1, top5 = evaluate(float_model, criterion, data_loader_test, use_cuda=True) +print("[float model] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg)) +top1, top5 = evaluate(quantized_model, criterion, data_loader_test) +print("[before serialization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg)) + +fx_graph_mode_model_file_path = os.path.join(saved_model_dir, "resnet18_fx_graph_mode_quantized.pth") + +# save with state_dict +torch.save(quantized_model.state_dict(), fx_graph_mode_model_file_path) +import copy +model_to_quantize = copy.deepcopy(float_model) +prepared_model = prepare_fx(model_to_quantize, {"": qconfig}) +loaded_quantized_model = convert_fx(prepared_model) +loaded_quantized_model.load_state_dict(torch.load(fx_graph_mode_model_file_path)) +top1, top5 = evaluate(loaded_quantized_model, criterion, data_loader_test) +print("[after serialization] Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top5.avg)) +print('pipeline finish') \ No newline at end of file
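
The loops changed in mmrazor/engine/runner/quantization_loops.py drop the custom `model.state = (observers_enabled, fakequants_enabled)` tuple and instead toggle quantization state with the stock `torch.ao.quantization` helpers applied through `Module.apply`. Below is a minimal sketch of that pattern outside MMRazor, assuming PyTorch >= 1.10; the toy module and random data are made up for illustration, and eager-mode `prepare_qat` is used only for brevity (the patch itself prepares models through FX tracing and `CustomQuantizer.prepare`):

import torch
import torch.nn as nn
from torch.ao.quantization import (disable_observer, enable_fake_quant,
                                   enable_observer, get_default_qat_qconfig,
                                   prepare_qat)

# Toy model for illustration only (not the MMRazor architecture).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model.train()
model.qconfig = get_default_qat_qconfig('fbgemm')
prepare_qat(model, inplace=True)  # inserts FakeQuantize modules with observers

# Calibration / QAT phase: observers collect ranges, fake quant is active.
model.apply(enable_fake_quant)
model.apply(enable_observer)
for _ in range(4):
    model(torch.randn(2, 3, 32, 32))

# Evaluation phase: keep fake quant on, freeze qparams by disabling observers.
model.apply(enable_fake_quant)
model.apply(disable_observer)
model.eval()
with torch.no_grad():
    print(model(torch.randn(2, 3, 32, 32)).shape)

During calibration and QAT both switches are on; before evaluation only the observers are turned off so the collected scales and zero points stay frozen, which mirrors what `QATEpochBasedLoop.run` and `PTQLoop.run` now do before handing off to `val_loop.run()`.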