From 4a4f7ce181a789eeefd78f5003c268a86af21658 Mon Sep 17 00:00:00 2001 From: whcao <2892770585@qq.com> Date: Mon, 2 May 2022 20:46:27 +0800 Subject: [PATCH] fix pytest --- .../test_algorithms/test_algorithm.py | 19 ++++++++++++++++++- tests/test_models/test_pruner.py | 10 +++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/tests/test_models/test_algorithms/test_algorithm.py b/tests/test_models/test_algorithms/test_algorithm.py index 1869c94e3..839629f7f 100644 --- a/tests/test_models/test_algorithms/test_algorithm.py +++ b/tests/test_models/test_algorithms/test_algorithm.py @@ -5,8 +5,9 @@ import mmcv import numpy as np +import pytest import torch -from mmcv import Config, ConfigDict +from mmcv import Config, ConfigDict, digit_version from mmrazor.models.builder import ALGORITHMS @@ -95,6 +96,14 @@ def test_autoslim_pretrain(): pruner=pruner_cfg, distiller=distiller_cfg) + # ``StructurePruner`` requires pytorch>=1.6.0 to + # auto-trace correctly + min_required_version = '1.6.0' + if digit_version(torch.__version__) < digit_version(min_required_version): + with pytest.raises(AssertionError): + model = ALGORITHMS.build(algorithm_cfg) + return + imgs = torch.randn(16, 3, 224, 224) label = torch.randint(0, 1000, (16, )) @@ -147,6 +156,14 @@ def test_autoslim_retrain(): retraining=True, channel_cfg=channel_cfg) + # ``StructurePruner`` requires pytorch>=1.6.0 to + # auto-trace correctly + min_required_version = '1.6.0' + if digit_version(torch.__version__) < digit_version(min_required_version): + with pytest.raises(AssertionError): + model = ALGORITHMS.build(algorithm_cfg) + return + imgs = torch.randn(16, 3, 224, 224) label = torch.randint(0, 1000, (16, )) diff --git a/tests/test_models/test_pruner.py b/tests/test_models/test_pruner.py index 94bc49c6e..f16970022 100644 --- a/tests/test_models/test_pruner.py +++ b/tests/test_models/test_pruner.py @@ -3,7 +3,7 @@ import pytest import torch -from mmcv import ConfigDict +from mmcv import ConfigDict, digit_version from mmrazor.models.builder import ARCHITECTURES, PRUNERS @@ -35,6 +35,14 @@ def test_ratio_pruner(): type='RatioPruner', ratios=[1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0]) + # ``StructurePruner`` requires pytorch>=1.6.0 to + # auto-trace correctly + min_required_version = '1.6.0' + if digit_version(torch.__version__) < digit_version(min_required_version): + with pytest.raises(AssertionError): + pruner = PRUNERS.build(pruner_cfg) + return + _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, False) with pytest.raises(AssertionError): _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, True)