Skip to content

Commit

Permalink
[Docs] Add docstring and unittest about backendconfig & observer & fa…
Browse files Browse the repository at this point in the history
…kequant (#428)

* add ut about backendconfig

* add ut about observers and fakequants in torch

* fix torch1.13 ci
  • Loading branch information
humu789 committed Jan 16, 2023
1 parent 985a611 commit 5d50314
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 6 deletions.
9 changes: 7 additions & 2 deletions mmrazor/structures/quantization/backend_config/academic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@


def get_academic_backend_config() -> BackendConfig:
"""Return the `BackendConfig` for academic reseaching."""
"""Return the `BackendConfig` for academic reseaching.
Note:
Learn more about BackendConfig, please refer to:
https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501
"""

# ===================
# | DTYPE CONFIGS |
Expand All @@ -34,7 +39,7 @@ def get_academic_backend_config() -> BackendConfig:
conv_dtype_configs = [weighted_op_int8_dtype_config]
linear_dtype_configs = [weighted_op_int8_dtype_config]

return BackendConfig('native') \
return BackendConfig('academic') \
.set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
.set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs))

Expand Down
8 changes: 6 additions & 2 deletions mmrazor/structures/quantization/backend_config/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@


def get_native_backend_config() -> BackendConfig:
"""Return the `BackendConfig` for PyTorch Native backend
(fbgemm/qnnpack)."""
"""Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack).
Note:
Learn more about BackendConfig, please refer to:
https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501
"""
# TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK
# BackendConfigs

Expand Down
7 changes: 6 additions & 1 deletion mmrazor/structures/quantization/backend_config/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@


def get_openvino_backend_config() -> BackendConfig:
"""Return the `BackendConfig` for the OpenVINO backend."""
"""Return the `BackendConfig` for the OpenVINO backend.
Note:
Learn more about BackendConfig, please refer to:
https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501
"""
# dtype configs
weighted_op_qint8_dtype_config = DTypeConfig(
input_dtype=torch.quint8,
Expand Down
7 changes: 6 additions & 1 deletion mmrazor/structures/quantization/backend_config/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@


def get_tensorrt_backend_config() -> BackendConfig:
"""Return the `BackendConfig` for the TensorRT backend."""
"""Return the `BackendConfig` for the TensorRT backend.
Note:
Learn more about BackendConfig, please refer to:
https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501
"""
# dtype configs
weighted_op_qint8_dtype_config = DTypeConfig(
input_dtype=torch.qint8,
Expand Down
18 changes: 18 additions & 0 deletions tests/test_models/test_fake_quants/test_torch_fake_quants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmrazor import digit_version
from mmrazor.models.fake_quants import register_torch_fake_quants
from mmrazor.registry import MODELS


@pytest.mark.skipif(
digit_version(torch.__version__) < digit_version('1.13.0'),
reason='version of torch < 1.13.0')
def test_register_torch_fake_quants():

TORCH_fake_quants = register_torch_fake_quants()
assert isinstance(TORCH_fake_quants, list)
for fake_quant in TORCH_fake_quants:
assert MODELS.get(fake_quant)
18 changes: 18 additions & 0 deletions tests/test_models/test_observers/test_torch_observers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmrazor import digit_version
from mmrazor.models.observers import register_torch_observers
from mmrazor.registry import MODELS


@pytest.mark.skipif(
digit_version(torch.__version__) < digit_version('1.13.0'),
reason='version of torch < 1.13.0')
def test_register_torch_observers():

TORCH_observers = register_torch_observers()
assert isinstance(TORCH_observers, list)
for observer in TORCH_observers:
assert MODELS.get(observer)
62 changes: 62 additions & 0 deletions tests/test_structures/test_backendconfig.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
try:
from torch.ao.quantization.backend_config import BackendConfig
except ImportError:
from mmrazor.utils import get_placeholder
BackendConfig = get_placeholder('torch>=1.13')

import pytest
import torch

from mmrazor import digit_version
from mmrazor.structures.quantization.backend_config import (
BackendConfigs, get_academic_backend_config,
get_academic_backend_config_dict, get_native_backend_config,
get_native_backend_config_dict, get_openvino_backend_config,
get_openvino_backend_config_dict, get_tensorrt_backend_config,
get_tensorrt_backend_config_dict)


@pytest.mark.skipif(
digit_version(torch.__version__) < digit_version('1.13.0'),
reason='version of torch < 1.13.0')
def test_get_backend_config():

# test get_native_backend_config
native_backend_config = get_native_backend_config()
assert isinstance(native_backend_config, BackendConfig)
assert native_backend_config.name == 'native'
native_backend_config_dict = get_native_backend_config_dict()
assert isinstance(native_backend_config_dict, dict)

# test get_academic_backend_config
academic_backend_config = get_academic_backend_config()
assert isinstance(academic_backend_config, BackendConfig)
assert academic_backend_config.name == 'academic'
academic_backend_config_dict = get_academic_backend_config_dict()
assert isinstance(academic_backend_config_dict, dict)

# test get_openvino_backend_config
openvino_backend_config = get_openvino_backend_config()
assert isinstance(openvino_backend_config, BackendConfig)
assert openvino_backend_config.name == 'openvino'
openvino_backend_config_dict = get_openvino_backend_config_dict()
assert isinstance(openvino_backend_config_dict, dict)

# test get_tensorrt_backend_config
tensorrt_backend_config = get_tensorrt_backend_config()
assert isinstance(tensorrt_backend_config, BackendConfig)
assert tensorrt_backend_config.name == 'tensorrt'
tensorrt_backend_config_dict = get_tensorrt_backend_config_dict()
assert isinstance(tensorrt_backend_config_dict, dict)


@pytest.mark.skipif(
digit_version(torch.__version__) < digit_version('1.13.0'),
reason='version of torch < 1.13.0')
def test_backendconfigs_mapping():

mapping = BackendConfigs
assert isinstance(mapping, dict)
assert 'academic' in mapping.keys()
assert isinstance(mapping['academic'], BackendConfig)

0 comments on commit 5d50314

Please sign in to comment.