Skip to content

Commit

Permalink
fix torch1.13 ci
Browse files Browse the repository at this point in the history
  • Loading branch information
humu789 committed Jan 16, 2023
1 parent 0863e1c commit b7c934a
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mmrazor/structures/quantization/backend_config/academic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

def get_academic_backend_config() -> BackendConfig:
"""Return the `BackendConfig` for academic reseaching.
Note:
Learn more about BackendConfig, please refer to:
https://github.com/pytorch/pytorch/tree/master/torch/ao/quantization/backend_config # noqa: E501
Expand Down
7 changes: 7 additions & 0 deletions tests/test_models/test_fake_quants/test_torch_fake_quants.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmrazor import digit_version
from mmrazor.models.fake_quants import register_torch_fake_quants
from mmrazor.registry import MODELS


@pytest.mark.skipif(
digit_version(torch.__version__) < digit_version('1.13.0'),
reason='version of torch < 1.13.0')
def test_register_torch_fake_quants():

TORCH_fake_quants = register_torch_fake_quants()
Expand Down
7 changes: 7 additions & 0 deletions tests/test_models/test_observers/test_torch_observers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmrazor import digit_version
from mmrazor.models.observers import register_torch_observers
from mmrazor.registry import MODELS


@pytest.mark.skipif(
digit_version(torch.__version__) < digit_version('1.13.0'),
reason='version of torch < 1.13.0')
def test_register_torch_observers():

TORCH_observers = register_torch_observers()
Expand Down
18 changes: 17 additions & 1 deletion tests/test_structures/test_backendconfig.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch.ao.quantization.backend_config import BackendConfig
try:
from torch.ao.quantization.backend_config import BackendConfig
except ImportError:
from mmrazor.utils import get_placeholder
BackendConfig = get_placeholder('torch>=1.13')

import pytest
import torch

from mmrazor import digit_version
from mmrazor.structures.quantization.backend_config import (
BackendConfigs, get_academic_backend_config,
get_academic_backend_config_dict, get_native_backend_config,
Expand All @@ -9,7 +17,11 @@
get_tensorrt_backend_config_dict)


@pytest.mark.skipif(
digit_version(torch.__version__) < digit_version('1.13.0'),
reason='version of torch < 1.13.0')
def test_get_backend_config():

# test get_native_backend_config
native_backend_config = get_native_backend_config()
assert isinstance(native_backend_config, BackendConfig)
Expand Down Expand Up @@ -39,7 +51,11 @@ def test_get_backend_config():
assert isinstance(tensorrt_backend_config_dict, dict)


@pytest.mark.skipif(
digit_version(torch.__version__) < digit_version('1.13.0'),
reason='version of torch < 1.13.0')
def test_backendconfigs_mapping():

mapping = BackendConfigs
assert isinstance(mapping, dict)
assert 'academic' in mapping.keys()
Expand Down

0 comments on commit b7c934a

Please sign in to comment.