diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py
index 0de628a9a..996314d27 100644
--- a/mmrazor/models/observers/torch_observers.py
+++ b/mmrazor/models/observers/torch_observers.py
@@ -12,7 +12,7 @@
 except ImportError:
     from mmrazor.utils import get_package_placeholder
     torch_observer_src = get_package_placeholder('torch>=1.13')
-    UniformQuantizationObserverBase = get_package_placeholder('torch>=1.13')
+    PerChannelMinMaxObserver = get_package_placeholder('torch>=1.13')
 
 
 @torch.jit.export
diff --git a/tests/test_runners/test_quantization_loop.py b/tests/test_runners/test_quantization_loop.py
index 7a15a5ccb..bc46e0cca 100644
--- a/tests/test_runners/test_quantization_loop.py
+++ b/tests/test_runners/test_quantization_loop.py
@@ -15,6 +15,7 @@
 from mmengine.optim import OptimWrapper
 from mmengine.registry import DATASETS, HOOKS, METRICS, MODELS, OPTIM_WRAPPERS
 from mmengine.runner import Runner
+from torch.ao.nn.quantized import FloatFunctional, FXFloatFunctional
 from torch.nn.intrinsic.qat import ConvBnReLU2d
 from torch.utils.data import Dataset
 
@@ -71,14 +72,14 @@
     def swap_ff_with_fxff(self, model):
         modules_to_swap = []
         for name, module in model.named_children():
-            if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
+            if isinstance(module, FloatFunctional):
                 modules_to_swap.append(name)
             else:
                 self.swap_ff_with_fxff(module)
 
         for name in modules_to_swap:
             del model._modules[name]
-            model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()
+            model._modules[name] = FXFloatFunctional()
 
     def sync_qparams(self, src_mode):
         pass
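
For reviewers, a minimal standalone sketch of the pattern the last hunk exercises, assuming torch>=1.13 where torch.ao.nn.quantized provides both classes; the module-level import is exactly what the hunk adds, the recursion is unchanged from the test helper, and the toy model at the end is hypothetical usage, not part of the patch:

# Sketch only: mirrors the swap_ff_with_fxff helper above with the
# new module-level import (assumes torch>=1.13).
import torch.nn as nn
from torch.ao.nn.quantized import FloatFunctional, FXFloatFunctional


def swap_ff_with_fxff(model: nn.Module) -> None:
    """Recursively replace every FloatFunctional child with FXFloatFunctional."""
    modules_to_swap = []
    for name, module in model.named_children():
        if isinstance(module, FloatFunctional):
            modules_to_swap.append(name)
        else:
            swap_ff_with_fxff(module)  # descend into non-matching submodules

    # Deleting then reassigning swaps the entry in model._modules in place,
    # matching what the test helper does.
    for name in modules_to_swap:
        del model._modules[name]
        model._modules[name] = FXFloatFunctional()


# Hypothetical usage: a toy container holding one FloatFunctional.
toy = nn.Sequential(nn.Conv2d(3, 3, 1), FloatFunctional())
swap_ff_with_fxff(toy)
assert isinstance(toy[1], FXFloatFunctional)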