From 5d50314221958e3027af99e640abcb02e1c76dab Mon Sep 17 00:00:00 2001 From: humu789 <88702197+humu789@users.noreply.github.com> Date: Mon, 16 Jan 2023 18:36:32 +0800 Subject: [PATCH] [Docs] Add docstring and unittest about backendconfig & observer & fakequant (#428) * add ut about backendconfig * add ut about observers and fakequants in torch * fix torch1.13 ci --- .../quantization/backend_config/academic.py | 9 ++- .../quantization/backend_config/native.py | 8 ++- .../quantization/backend_config/openvino.py | 7 ++- .../quantization/backend_config/tensorrt.py | 7 ++- .../test_lsq_fake_quants.py | 0 .../test_torch_fake_quants.py | 18 ++++++ .../test_observers/test_torch_observers.py | 18 ++++++ tests/test_structures/test_backendconfig.py | 62 +++++++++++++++++++ 8 files changed, 123 insertions(+), 6 deletions(-) rename tests/test_models/{test_fake_quantize => test_fake_quants}/test_lsq_fake_quants.py (100%) create mode 100644 tests/test_models/test_fake_quants/test_torch_fake_quants.py create mode 100644 tests/test_models/test_observers/test_torch_observers.py create mode 100644 tests/test_structures/test_backendconfig.py diff --git a/mmrazor/structures/quantization/backend_config/academic.py b/mmrazor/structures/quantization/backend_config/academic.py index 4348e7179..6b4f0d598 100644 --- a/mmrazor/structures/quantization/backend_config/academic.py +++ b/mmrazor/structures/quantization/backend_config/academic.py @@ -17,7 +17,12 @@ def get_academic_backend_config() -> BackendConfig: - """Return the `BackendConfig` for academic reseaching.""" + """Return the `BackendConfig` for academic reseaching. + + Note: + Learn more about BackendConfig, please refer to: + https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501 + """ # =================== # | DTYPE CONFIGS | @@ -34,7 +39,7 @@ def get_academic_backend_config() -> BackendConfig: conv_dtype_configs = [weighted_op_int8_dtype_config] linear_dtype_configs = [weighted_op_int8_dtype_config] - return BackendConfig('native') \ + return BackendConfig('academic') \ .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \ .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) diff --git a/mmrazor/structures/quantization/backend_config/native.py b/mmrazor/structures/quantization/backend_config/native.py index 94c35d535..59085a56a 100644 --- a/mmrazor/structures/quantization/backend_config/native.py +++ b/mmrazor/structures/quantization/backend_config/native.py @@ -20,8 +20,12 @@ def get_native_backend_config() -> BackendConfig: - """Return the `BackendConfig` for PyTorch Native backend - (fbgemm/qnnpack).""" + """Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack). + + Note: + Learn more about BackendConfig, please refer to: + https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501 + """ # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK # BackendConfigs diff --git a/mmrazor/structures/quantization/backend_config/openvino.py b/mmrazor/structures/quantization/backend_config/openvino.py index d990d4ef9..5e3051f75 100644 --- a/mmrazor/structures/quantization/backend_config/openvino.py +++ b/mmrazor/structures/quantization/backend_config/openvino.py @@ -20,7 +20,12 @@ def get_openvino_backend_config() -> BackendConfig: - """Return the `BackendConfig` for the OpenVINO backend.""" + """Return the `BackendConfig` for the OpenVINO backend. + + Note: + Learn more about BackendConfig, please refer to: + https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501 + """ # dtype configs weighted_op_qint8_dtype_config = DTypeConfig( input_dtype=torch.quint8, diff --git a/mmrazor/structures/quantization/backend_config/tensorrt.py b/mmrazor/structures/quantization/backend_config/tensorrt.py index 53305f650..791463233 100644 --- a/mmrazor/structures/quantization/backend_config/tensorrt.py +++ b/mmrazor/structures/quantization/backend_config/tensorrt.py @@ -20,7 +20,12 @@ def get_tensorrt_backend_config() -> BackendConfig: - """Return the `BackendConfig` for the TensorRT backend.""" + """Return the `BackendConfig` for the TensorRT backend. + + Note: + Learn more about BackendConfig, please refer to: + https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501 + """ # dtype configs weighted_op_qint8_dtype_config = DTypeConfig( input_dtype=torch.qint8, diff --git a/tests/test_models/test_fake_quantize/test_lsq_fake_quants.py b/tests/test_models/test_fake_quants/test_lsq_fake_quants.py similarity index 100% rename from tests/test_models/test_fake_quantize/test_lsq_fake_quants.py rename to tests/test_models/test_fake_quants/test_lsq_fake_quants.py diff --git a/tests/test_models/test_fake_quants/test_torch_fake_quants.py b/tests/test_models/test_fake_quants/test_torch_fake_quants.py new file mode 100644 index 000000000..485113e90 --- /dev/null +++ b/tests/test_models/test_fake_quants/test_torch_fake_quants.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmrazor import digit_version +from mmrazor.models.fake_quants import register_torch_fake_quants +from mmrazor.registry import MODELS + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_register_torch_fake_quants(): + + TORCH_fake_quants = register_torch_fake_quants() + assert isinstance(TORCH_fake_quants, list) + for fake_quant in TORCH_fake_quants: + assert MODELS.get(fake_quant) diff --git a/tests/test_models/test_observers/test_torch_observers.py b/tests/test_models/test_observers/test_torch_observers.py new file mode 100644 index 000000000..cc32e69d8 --- /dev/null +++ b/tests/test_models/test_observers/test_torch_observers.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmrazor import digit_version +from mmrazor.models.observers import register_torch_observers +from mmrazor.registry import MODELS + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_register_torch_observers(): + + TORCH_observers = register_torch_observers() + assert isinstance(TORCH_observers, list) + for observer in TORCH_observers: + assert MODELS.get(observer) diff --git a/tests/test_structures/test_backendconfig.py b/tests/test_structures/test_backendconfig.py new file mode 100644 index 000000000..24295e391 --- /dev/null +++ b/tests/test_structures/test_backendconfig.py @@ -0,0 +1,62 @@ +# Copyright (c) OpenMMLab. All rights reserved. +try: + from torch.ao.quantization.backend_config import BackendConfig +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') + +import pytest +import torch + +from mmrazor import digit_version +from mmrazor.structures.quantization.backend_config import ( + BackendConfigs, get_academic_backend_config, + get_academic_backend_config_dict, get_native_backend_config, + get_native_backend_config_dict, get_openvino_backend_config, + get_openvino_backend_config_dict, get_tensorrt_backend_config, + get_tensorrt_backend_config_dict) + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_get_backend_config(): + + # test get_native_backend_config + native_backend_config = get_native_backend_config() + assert isinstance(native_backend_config, BackendConfig) + assert native_backend_config.name == 'native' + native_backend_config_dict = get_native_backend_config_dict() + assert isinstance(native_backend_config_dict, dict) + + # test get_academic_backend_config + academic_backend_config = get_academic_backend_config() + assert isinstance(academic_backend_config, BackendConfig) + assert academic_backend_config.name == 'academic' + academic_backend_config_dict = get_academic_backend_config_dict() + assert isinstance(academic_backend_config_dict, dict) + + # test get_openvino_backend_config + openvino_backend_config = get_openvino_backend_config() + assert isinstance(openvino_backend_config, BackendConfig) + assert openvino_backend_config.name == 'openvino' + openvino_backend_config_dict = get_openvino_backend_config_dict() + assert isinstance(openvino_backend_config_dict, dict) + + # test get_tensorrt_backend_config + tensorrt_backend_config = get_tensorrt_backend_config() + assert isinstance(tensorrt_backend_config, BackendConfig) + assert tensorrt_backend_config.name == 'tensorrt' + tensorrt_backend_config_dict = get_tensorrt_backend_config_dict() + assert isinstance(tensorrt_backend_config_dict, dict) + + +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') +def test_backendconfigs_mapping(): + + mapping = BackendConfigs + assert isinstance(mapping, dict) + assert 'academic' in mapping.keys() + assert isinstance(mapping['academic'], BackendConfig)