-
Notifications
You must be signed in to change notification settings - Fork 220
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Docs] Add docstring and unittest about backendconfig & observer & fa…
…kequant (#428) * add ut about backendconfig * add ut about observers and fakequants in torch * fix torch1.13 ci
- Loading branch information
Showing
8 changed files
with
123 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
18 changes: 18 additions & 0 deletions
18
tests/test_models/test_fake_quants/test_torch_fake_quants.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import pytest | ||
import torch | ||
|
||
from mmrazor import digit_version | ||
from mmrazor.models.fake_quants import register_torch_fake_quants | ||
from mmrazor.registry import MODELS | ||
|
||
|
||
@pytest.mark.skipif( | ||
digit_version(torch.__version__) < digit_version('1.13.0'), | ||
reason='version of torch < 1.13.0') | ||
def test_register_torch_fake_quants(): | ||
|
||
TORCH_fake_quants = register_torch_fake_quants() | ||
assert isinstance(TORCH_fake_quants, list) | ||
for fake_quant in TORCH_fake_quants: | ||
assert MODELS.get(fake_quant) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import pytest | ||
import torch | ||
|
||
from mmrazor import digit_version | ||
from mmrazor.models.observers import register_torch_observers | ||
from mmrazor.registry import MODELS | ||
|
||
|
||
@pytest.mark.skipif( | ||
digit_version(torch.__version__) < digit_version('1.13.0'), | ||
reason='version of torch < 1.13.0') | ||
def test_register_torch_observers(): | ||
|
||
TORCH_observers = register_torch_observers() | ||
assert isinstance(TORCH_observers, list) | ||
for observer in TORCH_observers: | ||
assert MODELS.get(observer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
try: | ||
from torch.ao.quantization.backend_config import BackendConfig | ||
except ImportError: | ||
from mmrazor.utils import get_placeholder | ||
BackendConfig = get_placeholder('torch>=1.13') | ||
|
||
import pytest | ||
import torch | ||
|
||
from mmrazor import digit_version | ||
from mmrazor.structures.quantization.backend_config import ( | ||
BackendConfigs, get_academic_backend_config, | ||
get_academic_backend_config_dict, get_native_backend_config, | ||
get_native_backend_config_dict, get_openvino_backend_config, | ||
get_openvino_backend_config_dict, get_tensorrt_backend_config, | ||
get_tensorrt_backend_config_dict) | ||
|
||
|
||
@pytest.mark.skipif( | ||
digit_version(torch.__version__) < digit_version('1.13.0'), | ||
reason='version of torch < 1.13.0') | ||
def test_get_backend_config(): | ||
|
||
# test get_native_backend_config | ||
native_backend_config = get_native_backend_config() | ||
assert isinstance(native_backend_config, BackendConfig) | ||
assert native_backend_config.name == 'native' | ||
native_backend_config_dict = get_native_backend_config_dict() | ||
assert isinstance(native_backend_config_dict, dict) | ||
|
||
# test get_academic_backend_config | ||
academic_backend_config = get_academic_backend_config() | ||
assert isinstance(academic_backend_config, BackendConfig) | ||
assert academic_backend_config.name == 'academic' | ||
academic_backend_config_dict = get_academic_backend_config_dict() | ||
assert isinstance(academic_backend_config_dict, dict) | ||
|
||
# test get_openvino_backend_config | ||
openvino_backend_config = get_openvino_backend_config() | ||
assert isinstance(openvino_backend_config, BackendConfig) | ||
assert openvino_backend_config.name == 'openvino' | ||
openvino_backend_config_dict = get_openvino_backend_config_dict() | ||
assert isinstance(openvino_backend_config_dict, dict) | ||
|
||
# test get_tensorrt_backend_config | ||
tensorrt_backend_config = get_tensorrt_backend_config() | ||
assert isinstance(tensorrt_backend_config, BackendConfig) | ||
assert tensorrt_backend_config.name == 'tensorrt' | ||
tensorrt_backend_config_dict = get_tensorrt_backend_config_dict() | ||
assert isinstance(tensorrt_backend_config_dict, dict) | ||
|
||
|
||
@pytest.mark.skipif( | ||
digit_version(torch.__version__) < digit_version('1.13.0'), | ||
reason='version of torch < 1.13.0') | ||
def test_backendconfigs_mapping(): | ||
|
||
mapping = BackendConfigs | ||
assert isinstance(mapping, dict) | ||
assert 'academic' in mapping.keys() | ||
assert isinstance(mapping['academic'], BackendConfig) |