Commit 4aad895
fix bugs under pt13
HIT-cwh committed Jan 14, 2023
1 parent afc2048 commit 4aad895
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion mmrazor/models/observers/torch_observers.py
@@ -12,7 +12,7 @@
 except ImportError:
     from mmrazor.utils import get_package_placeholder
     torch_observer_src = get_package_placeholder('torch>=1.13')
-    UniformQuantizationObserverBase = get_package_placeholder('torch>=1.13')
+    PerChannelMinMaxObserver = get_package_placeholder('torch>=1.13')


 @torch.jit.export
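
Note: the hunk above follows the guarded-import pattern, where a placeholder is bound when the torch>=1.13 observer classes cannot be imported. A minimal self-contained sketch of that pattern, assuming the placeholder simply raises on use; the _missing_dependency helper is a hypothetical stand-in for mmrazor's get_package_placeholder, whose exact behavior is not shown in this diff:

# Guarded-import sketch: fall back to a placeholder when the observer
# class introduced in torch>=1.13 is unavailable. `_missing_dependency`
# is a hypothetical stand-in for mmrazor's get_package_placeholder.
try:
    from torch.ao.quantization.observer import PerChannelMinMaxObserver
except ImportError:

    def _missing_dependency(requirement):
        class _Placeholder:
            def __init__(self, *args, **kwargs):
                # Fail loudly only when the placeholder is actually used,
                # so the module itself still imports on older torch.
                raise ImportError(f'This feature requires {requirement}.')

        return _Placeholder

    PerChannelMinMaxObserver = _missing_dependency('torch>=1.13')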
5 changes: 3 additions & 2 deletions tests/test_runners/test_quantization_loop.py
@@ -15,6 +15,7 @@
 from mmengine.optim import OptimWrapper
 from mmengine.registry import DATASETS, HOOKS, METRICS, MODELS, OPTIM_WRAPPERS
 from mmengine.runner import Runner
+from torch.ao.nn.quantized import FloatFunctional, FXFloatFunctional
 from torch.nn.intrinsic.qat import ConvBnReLU2d
 from torch.utils.data import Dataset
@@ -71,14 +72,14 @@ def swap_ff_with_fxff(self, model):

         modules_to_swap = []
         for name, module in model.named_children():
-            if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
+            if isinstance(module, FloatFunctional):
                 modules_to_swap.append(name)
             else:
                 self.swap_ff_with_fxff(module)

         for name in modules_to_swap:
             del model._modules[name]
-            model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()
+            model._modules[name] = FXFloatFunctional()

     def sync_qparams(self, src_mode):
         pass
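
Note: the two replacements above switch fully-qualified references to the names imported from torch.ao.nn.quantized, the module path used by torch>=1.13. A standalone sketch of what swap_ff_with_fxff does, with a hypothetical Toy module added purely for demonstration:

import torch.nn as nn
from torch.ao.nn.quantized import FloatFunctional, FXFloatFunctional


def swap_ff_with_fxff(model: nn.Module) -> None:
    """Recursively replace FloatFunctional children with FXFloatFunctional."""
    modules_to_swap = []
    for name, module in model.named_children():
        if isinstance(module, FloatFunctional):
            modules_to_swap.append(name)
        else:
            swap_ff_with_fxff(module)
    # Swap after iteration so we never mutate the dict mid-traversal.
    for name in modules_to_swap:
        del model._modules[name]
        model._modules[name] = FXFloatFunctional()


# Hypothetical toy model, for illustration only.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = FloatFunctional()


m = Toy()
swap_ff_with_fxff(m)
assert isinstance(m.ff, FXFloatFunctional)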
