Skip to content

Commit

Permalink
throw an error when tracing groupnorm with torch version under 1.6.0
Browse files Browse the repository at this point in the history
  • Loading branch information
HIT-cwh committed May 3, 2022
1 parent 4a4f7ce commit 29e1f4d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 41 deletions.
15 changes: 7 additions & 8 deletions mmrazor/models/pruners/structure_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch
import torch.nn as nn
from mmcv import digit_version
from mmcv.runner import BaseModule
from ordered_set import OrderedSet
from torch.nn.modules import GroupNorm
Expand Down Expand Up @@ -83,14 +84,6 @@ class StructurePruner(BaseModule, metaclass=ABCMeta):
"""

def __init__(self, except_start_keys=['head.fc']):

from mmcv import digit_version

min_required_version = '1.6.0'
assert digit_version(torch.__version__) >= digit_version(
min_required_version
), f'Requires to install pytorch>={min_required_version}'

super(StructurePruner, self).__init__()
if except_start_keys is None:
self.except_start_keys = list()
Expand Down Expand Up @@ -131,6 +124,12 @@ def prepare_from_supernet(self, supernet):
tmp_shared_module_hook_handles = list()

for name, module in supernet.model.named_modules():
if isinstance(module, nn.GroupNorm):
min_required_version = '1.6.0'
assert digit_version(torch.__version__) >= digit_version(
min_required_version
), f'Requires pytorch>={min_required_version} to auto-trace' \
f'GroupNorm correctly.'
if hasattr(module, 'weight'):
# trace shared modules
module.cnt = 0
Expand Down
19 changes: 1 addition & 18 deletions tests/test_models/test_algorithms/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

import mmcv
import numpy as np
import pytest
import torch
from mmcv import Config, ConfigDict, digit_version
from mmcv import Config, ConfigDict

from mmrazor.models.builder import ALGORITHMS

Expand Down Expand Up @@ -96,14 +95,6 @@ def test_autoslim_pretrain():
pruner=pruner_cfg,
distiller=distiller_cfg)

# ``StructurePruner`` requires pytorch>=1.6.0 to
# auto-trace correctly
min_required_version = '1.6.0'
if digit_version(torch.__version__) < digit_version(min_required_version):
with pytest.raises(AssertionError):
model = ALGORITHMS.build(algorithm_cfg)
return

imgs = torch.randn(16, 3, 224, 224)
label = torch.randint(0, 1000, (16, ))

Expand Down Expand Up @@ -156,14 +147,6 @@ def test_autoslim_retrain():
retraining=True,
channel_cfg=channel_cfg)

# ``StructurePruner`` requires pytorch>=1.6.0 to
# auto-trace correctly
min_required_version = '1.6.0'
if digit_version(torch.__version__) < digit_version(min_required_version):
with pytest.raises(AssertionError):
model = ALGORITHMS.build(algorithm_cfg)
return

imgs = torch.randn(16, 3, 224, 224)
label = torch.randint(0, 1000, (16, ))

Expand Down
29 changes: 14 additions & 15 deletions tests/test_models/test_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,6 @@ def test_ratio_pruner():
type='RatioPruner',
ratios=[1 / 8, 2 / 8, 3 / 8, 4 / 8, 5 / 8, 6 / 8, 7 / 8, 1.0])

# ``StructurePruner`` requires pytorch>=1.6.0 to
# auto-trace correctly
min_required_version = '1.6.0'
if digit_version(torch.__version__) < digit_version(min_required_version):
with pytest.raises(AssertionError):
pruner = PRUNERS.build(pruner_cfg)
return

_test_reset_bn_running_stats(architecture_cfg, pruner_cfg, False)
with pytest.raises(AssertionError):
_test_reset_bn_running_stats(architecture_cfg, pruner_cfg, True)
Expand Down Expand Up @@ -273,13 +265,20 @@ def test_ratio_pruner():
model=model_cfg,
)

architecture = ARCHITECTURES.build(architecture_cfg)
pruner.prepare_from_supernet(architecture)
subnet_dict = pruner.sample_subnet()
pruner.set_subnet(subnet_dict)
subnet_dict = pruner.export_subnet()
pruner.deploy_subnet(architecture, subnet_dict)
architecture.forward_dummy(imgs)
# ``StructurePruner`` requires pytorch>=1.6.0 to auto-trace GroupNorm
# correctly
min_required_version = '1.6.0'
if digit_version(torch.__version__) < digit_version(min_required_version):
with pytest.raises(AssertionError):
pruner = PRUNERS.build(pruner_cfg)
else:
architecture = ARCHITECTURES.build(architecture_cfg)
pruner.prepare_from_supernet(architecture)
subnet_dict = pruner.sample_subnet()
pruner.set_subnet(subnet_dict)
subnet_dict = pruner.export_subnet()
pruner.deploy_subnet(architecture, subnet_dict)
architecture.forward_dummy(imgs)


def _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, should_fail):
Expand Down

0 comments on commit 29e1f4d

Please sign in to comment.