Skip to content

Commit

Permalink
fix pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh committed May 2, 2022
1 parent efab980 commit 4a4f7ce
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
19 changes: 18 additions & 1 deletion tests/test_models/test_algorithms/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import mmcv
import numpy as np
import pytest
import torch
from mmcv import Config, ConfigDict
from mmcv import Config, ConfigDict, digit_version

from mmrazor.models.builder import ALGORITHMS

Expand Down Expand Up @@ -95,6 +96,14 @@ def test_autoslim_pretrain():
pruner=pruner_cfg,
distiller=distiller_cfg)

# ``StructurePruner`` requires pytorch>=1.6.0 to
# auto-trace correctly
min_required_version = '1.6.0'
if digit_version(torch.__version__) < digit_version(min_required_version):
with pytest.raises(AssertionError):
model = ALGORITHMS.build(algorithm_cfg)
return

imgs = torch.randn(16, 3, 224, 224)
label = torch.randint(0, 1000, (16, ))

Expand Down Expand Up @@ -147,6 +156,14 @@ def test_autoslim_retrain():
retraining=True,
channel_cfg=channel_cfg)

# ``StructurePruner`` requires pytorch>=1.6.0 to
# auto-trace correctly
min_required_version = '1.6.0'
if digit_version(torch.__version__) < digit_version(min_required_version):
with pytest.raises(AssertionError):
model = ALGORITHMS.build(algorithm_cfg)
return

imgs = torch.randn(16, 3, 224, 224)
label = torch.randint(0, 1000, (16, ))

Expand Down
10 changes: 9 additions & 1 deletion tests/test_models/test_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest
import torch
from mmcv import ConfigDict
from mmcv import ConfigDict, digit_version

from mmrazor.models.builder import ARCHITECTURES, PRUNERS

Expand Down Expand Up @@ -35,6 +35,14 @@ def test_ratio_pruner():
type='RatioPruner',
ratios=[1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0])

# ``StructurePruner`` requires pytorch>=1.6.0 to
# auto-trace correctly
min_required_version = '1.6.0'
if digit_version(torch.__version__) < digit_version(min_required_version):
with pytest.raises(AssertionError):
pruner = PRUNERS.build(pruner_cfg)
return

_test_reset_bn_running_stats(architecture_cfg, pruner_cfg, False)
with pytest.raises(AssertionError):
_test_reset_bn_running_stats(architecture_cfg, pruner_cfg, True)
Expand Down

0 comments on commit 4a4f7ce

Please sign in to comment.