torch.quantization --> torch.ao.quantization in vision for facebookresearch/fvcore (#86)

Summary:
Pull Request resolved: #86

Sets up vision for the AO migration; uses TORCH_VERSION to handle backward compatibility (BC).

Reviewed By: z-a-f

Differential Revision: D31436846

fbshipit-source-id: 0a13c5832bc6b45ec8fe7d3721ada8b2f9d420f3
HDCharles authored and facebook-github-bot committed Oct 15, 2021
1 parent 4525b81 commit 4a39fce
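
For context on the BC gate described above: it compares only the major/minor components of torch.__version__, which is what lets local-build version suffixes parse cleanly. A minimal standalone sketch (parse_torch_version is a hypothetical helper name, not part of this diff):

from typing import Tuple

def parse_torch_version(version: str) -> Tuple[int, ...]:
    # Keep only major and minor: the third component may carry
    # non-numeric suffixes (e.g. "2+cu113", "0a0") and is not needed.
    return tuple(int(x) for x in version.split(".")[:2])

assert parse_torch_version("1.10.2+cu113") == (1, 10)  # < (1, 11): torch.quantization path
assert parse_torch_version("1.11.0a0") >= (1, 11)      # takes the torch.ao.quantization path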
Showing 2 changed files with 37 additions and 10 deletions.
16 changes: 12 additions & 4 deletions fvcore/common/checkpoint.py
@@ -14,6 +14,14 @@
 from torch.nn.parallel import DataParallel, DistributedDataParallel
 
 
+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 11):
+    from torch.ao import quantization
+    from torch.ao.quantization import ObserverBase, FakeQuantizeBase
+else:
+    from torch import quantization
+    from torch.quantization import ObserverBase, FakeQuantizeBase
+
 __all__ = ["Checkpointer", "PeriodicCheckpointer"]
 
 
@@ -282,8 +290,8 @@ def _load_model(self, checkpoint: Any) -> _IncompatibleKeys:

         has_observer_base_classes = (
             TORCH_VERSION >= (1, 8)
-            and hasattr(torch.quantization, "ObserverBase")
-            and hasattr(torch.quantization, "FakeQuantizeBase")
+            and hasattr(quantization, "ObserverBase")
+            and hasattr(quantization, "FakeQuantizeBase")
         )
         if has_observer_base_classes:
             # Handle the special case of quantization per channel observers,
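
The special case exists because per-channel observers allocate their min/max buffers lazily: the buffers start empty and are resized to the number of observed channels, so a freshly constructed model and a trained checkpoint can legitimately disagree on buffer shapes. A hedged illustration using the real MovingAveragePerChannelMinMaxObserver class (shapes assume the default ch_axis=0):

import torch
from torch.quantization import MovingAveragePerChannelMinMaxObserver

obs = MovingAveragePerChannelMinMaxObserver()
print(obs.min_val.shape)      # torch.Size([0]): empty before any data is seen
obs(torch.randn(8, 4, 3, 3))  # observe a weight tensor with 8 output channels
print(obs.min_val.shape)      # torch.Size([8]): one entry per channel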
@@ -299,8 +307,8 @@ def _get_module_for_key(
                 return cur_module
 
             cls_to_skip = (
-                torch.quantization.ObserverBase,
-                torch.quantization.FakeQuantizeBase,
+                ObserverBase,
+                FakeQuantizeBase,
             )
             target_module = _get_module_for_key(self.model, k)
             if isinstance(target_module, cls_to_skip):
31 changes: 25 additions & 6 deletions tests/test_checkpoint.py
@@ -8,12 +8,31 @@
 import unittest
 from collections import OrderedDict
 from tempfile import TemporaryDirectory
+from typing import Tuple
 from unittest.mock import MagicMock
 
 import torch
 from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
 from torch import nn
 
+TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
+if TORCH_VERSION >= (1, 11):
+    from torch.ao import quantization
+    from torch.ao.quantization import (
+        get_default_qat_qconfig,
+        prepare_qat,
+        disable_observer,
+        enable_fake_quant,
+    )
+else:
+    from torch import quantization
+    from torch.quantization import (
+        get_default_qat_qconfig,
+        prepare_qat,
+        disable_observer,
+        enable_fake_quant,
+    )
+
 
class TestCheckpointer(unittest.TestCase):
def _create_model(self) -> nn.Module:
@@ -46,15 +65,15 @@ def _create_complex_model(
         return m, state_dict
 
     @unittest.skipIf(  # pyre-fixme[56]
-        (not hasattr(torch.quantization, "ObserverBase"))
-        or (not hasattr(torch.quantization, "FakeQuantizeBase")),
+        (not hasattr(quantization, "ObserverBase"))
+        or (not hasattr(quantization, "FakeQuantizeBase")),
         "quantization per-channel observer base classes not supported",
     )
     def test_loading_objects_with_expected_shape_mismatches(self) -> None:
         def _get_model() -> torch.nn.Module:
             m = nn.Sequential(nn.Conv2d(2, 2, 1))
-            m.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
-            m = torch.quantization.prepare_qat(m)
+            m.qconfig = get_default_qat_qconfig("fbgemm")
+            m = prepare_qat(m)
             return m
 
         m1, m2 = _get_model(), _get_model()
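
For readers unfamiliar with the QAT flow exercised by this test: prepare_qat swaps eligible modules for their QAT variants and attaches FakeQuantize modules (subclasses of FakeQuantizeBase wrapping observers), which is why the checkpointer's observer handling matters here. A sketch under the same version-gated imports as above; the printed module names are indicative only:

import torch
from torch import nn
from torch.quantization import FakeQuantizeBase, get_default_qat_qconfig, prepare_qat

m = nn.Sequential(nn.Conv2d(2, 2, 1))
m.qconfig = get_default_qat_qconfig("fbgemm")
m = prepare_qat(m)  # requires training mode, which is the default
print([n for n, mod in m.named_modules() if isinstance(mod, FakeQuantizeBase)])
# e.g. ['0.weight_fake_quant', '0.activation_post_process']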
@@ -77,8 +96,8 @@ def _get_model() -> torch.nn.Module:
         # Run the expected input through the network with observers
         # disabled and fake_quant enabled. If buffers were loaded correctly
         # into per-channel observers, this line will not crash.
-        m2.apply(torch.quantization.disable_observer)
-        m2.apply(torch.quantization.enable_fake_quant)
+        m2.apply(disable_observer)
+        m2.apply(enable_fake_quant)
         m2(torch.randn(4, 2, 4, 4))
 
     def test_from_last_checkpoint_model(self) -> None:
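
The two .apply() calls in the hunk above are the standard torch.quantization toggles: disable_observer freezes the min/max statistics restored from the checkpoint, and enable_fake_quant switches on simulated quantize/dequantize in the forward pass so the restored scale/zero_point buffers are actually exercised. A sketch of the flags they flip (fake_quant_enabled and observer_enabled are real FakeQuantize buffers; the exact printed values are an assumption):

import torch
from torch import nn
from torch.quantization import (
    disable_observer, enable_fake_quant, get_default_qat_qconfig, prepare_qat,
)

m = nn.Sequential(nn.Conv2d(2, 2, 1))
m.qconfig = get_default_qat_qconfig("fbgemm")
m = prepare_qat(m)
m.apply(disable_observer)   # sets observer_enabled to 0 on every FakeQuantize
m.apply(enable_fake_quant)  # sets fake_quant_enabled to 1 on every FakeQuantize
fq = m[0].weight_fake_quant
print(fq.observer_enabled, fq.fake_quant_enabled)  # e.g. tensor([0], ...) tensor([1], ...)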
