diff --git a/mmrazor/models/algorithms/quantization/mm_architecture.py b/mmrazor/models/algorithms/quantization/mm_architecture.py index 20e6dbdb9..9e1d22b68 100644 --- a/mmrazor/models/algorithms/quantization/mm_architecture.py +++ b/mmrazor/models/algorithms/quantization/mm_architecture.py @@ -83,13 +83,8 @@ def reset_observer_and_fakequant_statistics(self, model): statistics in observers and fake quantizers. """ for module in model.modules(): - if isinstance(module, MinMaxObserver): + if isinstance(module, (MinMaxObserver, PerChannelMinMaxObserver)): module.reset_min_max_vals() - elif isinstance(module, PerChannelMinMaxObserver): - min_val = torch.rand(0, ) - max_val = torch.rand(0, ) - module.min_val.resize_(min_val.shape).copy_(min_val) - module.max_val.resize_(max_val.shape).copy_(max_val) elif isinstance(module, FakeQuantizeBase): module.scale.data = torch.ones_like(module.scale) module.zero_point.data = torch.zeros_like(module.zero_point) diff --git a/mmrazor/models/observers/torch_observers.py b/mmrazor/models/observers/torch_observers.py index 5dc24609f..0de628a9a 100644 --- a/mmrazor/models/observers/torch_observers.py +++ b/mmrazor/models/observers/torch_observers.py @@ -2,13 +2,33 @@ import inspect from typing import List +import torch + from mmrazor.registry import MODELS try: import torch.ao.quantization.observer as torch_observer_src + from torch.ao.quantization.observer import PerChannelMinMaxObserver except ImportError: from mmrazor.utils import get_package_placeholder torch_observer_src = get_package_placeholder('torch>=1.13') + UniformQuantizationObserverBase = get_package_placeholder('torch>=1.13') + + +@torch.jit.export +def reset_min_max_vals(self): + """Resets the min/max values. + + `min_val` and `max_val` are always be on cpu in the pytorch version of this + method. + """ + min_val = torch.rand(0, ) + max_val = torch.rand(0, ) + self.min_val.resize_(min_val.shape).copy_(min_val) + self.max_val.resize_(max_val.shape).copy_(max_val) + + +PerChannelMinMaxObserver.reset_min_max_vals = reset_min_max_vals def register_torch_observers() -> List[str]: