Skip to content

Commit

Permalink
fix reset_min_max_vals
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh committed Jan 13, 2023
1 parent 296930b commit afc2048
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
7 changes: 1 addition & 6 deletions mmrazor/models/algorithms/quantization/mm_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,8 @@ def reset_observer_and_fakequant_statistics(self, model):
statistics in observers and fake quantizers.
"""
for module in model.modules():
if isinstance(module, MinMaxObserver):
if isinstance(module, (MinMaxObserver, PerChannelMinMaxObserver)):
module.reset_min_max_vals()
elif isinstance(module, PerChannelMinMaxObserver):
min_val = torch.rand(0, )
max_val = torch.rand(0, )
module.min_val.resize_(min_val.shape).copy_(min_val)
module.max_val.resize_(max_val.shape).copy_(max_val)
elif isinstance(module, FakeQuantizeBase):
module.scale.data = torch.ones_like(module.scale)
module.zero_point.data = torch.zeros_like(module.zero_point)
Expand Down
20 changes: 20 additions & 0 deletions mmrazor/models/observers/torch_observers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,33 @@
import inspect
from typing import List

import torch

from mmrazor.registry import MODELS

try:
import torch.ao.quantization.observer as torch_observer_src
from torch.ao.quantization.observer import PerChannelMinMaxObserver
except ImportError:
from mmrazor.utils import get_package_placeholder
torch_observer_src = get_package_placeholder('torch>=1.13')
UniformQuantizationObserverBase = get_package_placeholder('torch>=1.13')


@torch.jit.export
def reset_min_max_vals(self):
"""Resets the min/max values.
`min_val` and `max_val` are always be on cpu in the pytorch version of this
method.
"""
min_val = torch.rand(0, )
max_val = torch.rand(0, )
self.min_val.resize_(min_val.shape).copy_(min_val)
self.max_val.resize_(max_val.shape).copy_(max_val)


PerChannelMinMaxObserver.reset_min_max_vals = reset_min_max_vals


def register_torch_observers() -> List[str]:
Expand Down

0 comments on commit afc2048

Please sign in to comment.