From b7c934aa5976df172cf7db7622115dd81faf113b Mon Sep 17 00:00:00 2001 From: humu789 Date: Mon, 16 Jan 2023 18:20:55 +0800 Subject: [PATCH] fix torch1.13 ci --- .../quantization/backend_config/academic.py | 2 +- .../test_fake_quants/test_torch_fake_quants.py | 7 +++++++ .../test_observers/test_torch_observers.py | 7 +++++++ tests/test_structures/test_backendconfig.py | 18 +++++++++++++++++- 4 files changed, 32 insertions(+), 2 deletions(-) diff --git a/mmrazor/structures/quantization/backend_config/academic.py b/mmrazor/structures/quantization/backend_config/academic.py index 297e49e87..6b4f0d598 100644 --- a/mmrazor/structures/quantization/backend_config/academic.py +++ b/mmrazor/structures/quantization/backend_config/academic.py @@ -18,7 +18,7 @@ def get_academic_backend_config() -> BackendConfig: """Return the `BackendConfig` for academic reseaching. - + Note: Learn more about BackendConfig, please refer to: https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501 diff --git a/tests/test_models/test_fake_quants/test_torch_fake_quants.py b/tests/test_models/test_fake_quants/test_torch_fake_quants.py index d1ded6c6a..485113e90 100644 --- a/tests/test_models/test_fake_quants/test_torch_fake_quants.py +++ b/tests/test_models/test_fake_quants/test_torch_fake_quants.py @@ -1,8 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmrazor import digit_version from mmrazor.models.fake_quants import register_torch_fake_quants from mmrazor.registry import MODELS +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') def test_register_torch_fake_quants(): TORCH_fake_quants = register_torch_fake_quants() diff --git a/tests/test_models/test_observers/test_torch_observers.py b/tests/test_models/test_observers/test_torch_observers.py index 44ecbeb1d..cc32e69d8 100644 --- a/tests/test_models/test_observers/test_torch_observers.py +++ b/tests/test_models/test_observers/test_torch_observers.py @@ -1,8 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmrazor import digit_version from mmrazor.models.observers import register_torch_observers from mmrazor.registry import MODELS +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') def test_register_torch_observers(): TORCH_observers = register_torch_observers() diff --git a/tests/test_structures/test_backendconfig.py b/tests/test_structures/test_backendconfig.py index 62b865bff..24295e391 100644 --- a/tests/test_structures/test_backendconfig.py +++ b/tests/test_structures/test_backendconfig.py @@ -1,6 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -from torch.ao.quantization.backend_config import BackendConfig +try: + from torch.ao.quantization.backend_config import BackendConfig +except ImportError: + from mmrazor.utils import get_placeholder + BackendConfig = get_placeholder('torch>=1.13') +import pytest +import torch + +from mmrazor import digit_version from mmrazor.structures.quantization.backend_config import ( BackendConfigs, get_academic_backend_config, get_academic_backend_config_dict, get_native_backend_config, @@ -9,7 +17,11 @@ get_tensorrt_backend_config_dict) +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') def test_get_backend_config(): + # test get_native_backend_config native_backend_config = get_native_backend_config() assert isinstance(native_backend_config, BackendConfig) @@ -39,7 +51,11 @@ def test_get_backend_config(): assert isinstance(tensorrt_backend_config_dict, dict) +@pytest.mark.skipif( + digit_version(torch.__version__) < digit_version('1.13.0'), + reason='version of torch < 1.13.0') def test_backendconfigs_mapping(): + mapping = BackendConfigs assert isinstance(mapping, dict) assert 'academic' in mapping.keys()