[BUG] Support to prune models containing GroupNorm or InstanceNorm. (#144)

* support GN and IN

* test pruner

* limit pytorch version

* fix pytest

* throw an error when tracing groupnorm with torch version under 1.6.0

Co-authored-by: caoweihan <caoweihan@sensetime.com>
HIT-cwh and caoweihan committed May 4, 2022
1 parent 33f8425 commit 16681a0
Showing 3 changed files with 223 additions and 42 deletions.
17 changes: 17 additions & 0 deletions mmrazor/models/pruners/ratio_pruning.py
@@ -2,6 +2,7 @@
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules import GroupNorm

from mmrazor.models.builder import PRUNERS
from .structure_pruning import StructurePruner
@@ -31,6 +32,22 @@ def __init__(self, ratios, **kwargs):
self.ratios = ratios
self.min_ratio = ratios[0]

def _check_pruner(self, supernet):
for module in supernet.model.modules():
if isinstance(module, GroupNorm):
num_channels = module.num_channels
num_groups = module.num_groups
for ratio in self.ratios:
new_channels = int(round(num_channels * ratio))
assert (num_channels * ratio) % num_groups == 0, \
f'Expected number of channels in input of GroupNorm ' \
f'to be divisible by num_groups, but number of ' \
f'channels may be {new_channels} according to ' \
f'ratio {ratio} and num_groups={num_groups}'

def prepare_from_supernet(self, supernet):
super(RatioPruner, self).prepare_from_supernet(supernet)

def get_channel_mask(self, out_mask):
"""Randomly choose a width ratio of a layer from ``ratios``"""
out_channels = out_mask.size(1)
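A minimal standalone sketch of the constraint `_check_pruner` enforces (the helper name `check_groupnorm_ratios` and the toy model are illustrative, not part of this commit): after pruning with a given ratio, the channel count seen by every GroupNorm layer must stay divisible by its num_groups, otherwise the groups cannot be formed.

import torch.nn as nn

def check_groupnorm_ratios(model, ratios):
    # Assumed helper mirroring the intent of RatioPruner._check_pruner.
    for module in model.modules():
        if isinstance(module, nn.GroupNorm):
            for ratio in ratios:
                # channels kept after pruning with this ratio
                new_channels = int(round(module.num_channels * ratio))
                assert new_channels % module.num_groups == 0, (
                    f'ratio {ratio} keeps {new_channels} channels, which is '
                    f'not divisible by num_groups={module.num_groups}')

# GroupNorm(8, 32): ratios 0.5/0.75/1.0 keep 16/24/32 channels (all divisible
# by 8), whereas a ratio such as 0.4 would keep 13 channels and fail the check.
model = nn.Sequential(nn.Conv2d(3, 32, 3), nn.GroupNorm(8, 32))
check_groupnorm_ratios(model, ratios=[0.5, 0.75, 1.0])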
126 changes: 88 additions & 38 deletions mmrazor/models/pruners/structure_pruning.py
@@ -6,9 +6,12 @@

import torch
import torch.nn as nn
from mmcv import digit_version
from mmcv.runner import BaseModule
from ordered_set import OrderedSet
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm

from mmrazor.models.builder import PRUNERS
from .utils import SwitchableBatchNorm2d
@@ -19,14 +22,13 @@
FC = ('ThAddmmBackward', 'AddmmBackward', 'MmBackward')
BN = ('ThnnBatchNormBackward', 'CudnnBatchNormBackward',
'NativeBatchNormBackward')
GN = ('NativeGroupNormBackward', )
CONCAT = ('CatBackward', )
# the modules which contain NON_PASS grad_fn need to change the parameter size
# according to channels after pruning
NON_PASS = CONV + FC
NON_PASS_MODULE = (nn.Conv2d, nn.Linear)

PASS = BN
PASS_MODULE = (_BatchNorm)
PASS = BN + GN
NORM = BN + GN

BACKWARD_PARSER_DICT = dict()
MAKE_GROUP_PARSER_DICT = dict()
@@ -122,6 +124,12 @@ def prepare_from_supernet(self, supernet):
tmp_shared_module_hook_handles = list()

for name, module in supernet.model.named_modules():
if isinstance(module, nn.GroupNorm):
min_required_version = '1.6.0'
assert digit_version(torch.__version__) >= digit_version(
min_required_version
), f'Requires pytorch>={min_required_version} to auto-trace ' \
f'GroupNorm correctly.'
if hasattr(module, 'weight'):
# trace shared modules
module.cnt = 0
@@ -172,10 +180,10 @@ def prepare_from_supernet(self, supernet):
self.trace_non_pass_path(pseudo_loss.grad_fn, module2name, var2module,
cur_non_pass_path, non_pass_paths, visited)

bn_conv_links = dict()
self.trace_bn_conv_links(pseudo_loss.grad_fn, module2name, var2module,
bn_conv_links, visited)
self.bn_conv_links = bn_conv_links
norm_conv_links = dict()
self.trace_norm_conv_links(pseudo_loss.grad_fn, module2name,
var2module, norm_conv_links, visited)
self.norm_conv_links = norm_conv_links

# a node can be the name of a conv module or a str like 'concat_{id}'
node2parents = self.find_node_parents(non_pass_paths)
@@ -268,12 +276,12 @@ def set_subnet(self, subnet_dict):
module = self.name2module[module_name]
module.out_mask = subnet_dict[space_id].to(module.out_mask.device)

for bn, conv in self.bn_conv_links.items():
module = self.name2module[bn]
for norm, conv in self.norm_conv_links.items():
module = self.name2module[norm]
conv_space_id = self.get_space_id(conv)
# conv_space_id is None means the conv layer in front of
# this bn module can not be pruned. So we should not set
# the out_mask of this bn layer
# this normalization module can not be pruned. So we should not set
# the out_mask of this normalization layer
if conv_space_id is not None:
module.out_mask = subnet_dict[conv_space_id].to(
module.out_mask.device)
@@ -458,7 +466,9 @@ def add_pruning_attrs(self, module):
module.register_buffer(
'out_mask', module.weight.new_ones((1, module.out_features), ))
module.forward = self.modify_fc_forward(module)
if isinstance(module, nn.modules.batchnorm._BatchNorm):
if (isinstance(module, _BatchNorm)
or isinstance(module, _InstanceNorm)
or isinstance(module, GroupNorm)):
module.register_buffer(
'out_mask',
module.weight.new_ones((1, len(module.weight), 1, 1), ))
@@ -625,39 +635,79 @@ def trace_non_pass_path(self, grad_fn, module2name, var2module, cur_path,
else:
result_paths.append(copy.deepcopy(cur_path))

def trace_bn_conv_links(self, grad_fn, module2name, var2module,
bn_conv_links, visited):
"""Get the convolutional layer placed before a bn layer in the model.
def trace_norm_conv_links(self, grad_fn, module2name, var2module,
norm_conv_links, visited):
"""Get the convolutional layer placed before a normalization layer in
the model.
Example:
>>> conv = nn.Conv2d(3, 3, 3)
>>> bn = nn.BatchNorm2d(3)
>>> norm = nn.BatchNorm2d(3)
>>> pseudo_img = torch.rand(1, 3, 224, 224)
>>> out = bn(conv(pseudo_img))
>>> out = norm(conv(pseudo_img))
>>> print(out.grad_fn)
<NativeBatchNormBackward object at 0x0000022BC709DB08>
>>> print(out.grad_fn.next_functions)
((<ThnnConv2DBackward object at 0x0000020E40639688>, 0),
(<AccumulateGrad object at 0x0000020E40639208>, 0),
(<AccumulateGrad object at 0x0000020E406398C8>, 0))
>>> # op.next_functions[0][0] is ThnnConv2DBackward means
>>> # the parent of this NativeBatchNormBackward op is
>>> # ThnnConv2DBackward
>>> # op.next_functions[1][0].variable is the weight of this bn
>>> # module
>>> # op.next_functions[2][0].variable is the bias of this bn
>>> # module
>>> # op.next_functions[1][0].variable is the weight of this
>>> # normalization module
>>> # op.next_functions[2][0].variable is the bias of this
>>> # normalization module
>>> # Things are different in InstanceNorm
>>> conv = nn.Conv2d(3, 3, 3)
>>> norm = nn.InstanceNorm2d(3, affine=True)
>>> out = norm(conv(pseudo_img))
>>> print(out.grad_fn)
<ViewBackward object at 0x0000022BC709DD48>
>>> print(out.grad_fn.next_functions)
((<NativeBatchNormBackward object at 0x0000022BC81E8A08>, 0),)
>>> print(out.grad_fn.next_functions[0][0].next_functions)
((<ViewBackward object at 0x0000022BC81E8DC8>, 0),
(<RepeatBackward object at 0x0000022BC81E8D08>, 0),
(<RepeatBackward object at 0x0000022BC81E81C8>, 0))
>>> # Hence, a dfs is necessary.
"""
grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn
if grad_fn is not None:
is_bn_grad_fn = False
for fn_name in BN:

def is_norm_grad_fn(grad_fn):
for fn_name in NORM:
if type(grad_fn).__name__.startswith(fn_name):
is_bn_grad_fn = True
break
return True
return False

def is_conv_grad_fn(grad_fn):
for fn_name in CONV:
if type(grad_fn).__name__.startswith(fn_name):
return True
return False

if is_bn_grad_fn:
def is_leaf_grad_fn(grad_fn):
if type(grad_fn).__name__ == 'AccumulateGrad':
return True
return False

grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn
if grad_fn is not None:
if is_norm_grad_fn(grad_fn):
conv_grad_fn = grad_fn.next_functions[0][0]
conv_var = conv_grad_fn.next_functions[1][0].variable
bn_var = grad_fn.next_functions[1][0].variable
while not is_conv_grad_fn(conv_grad_fn):
conv_grad_fn = conv_grad_fn.next_functions[0][0]

leaf_grad_fn = conv_grad_fn.next_functions[1][0]
while not is_leaf_grad_fn(leaf_grad_fn):
leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
conv_var = leaf_grad_fn.variable

leaf_grad_fn = grad_fn.next_functions[1][0]
while not is_leaf_grad_fn(leaf_grad_fn):
leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
bn_var = leaf_grad_fn.variable

conv_module = var2module[id(conv_var)]
bn_module = var2module[id(bn_var)]
conv_name = module2name[conv_module]
@@ -666,20 +716,20 @@ def trace_bn_conv_links(self, grad_fn, module2name, var2module,
pass
else:
visited[bn_name] = True
bn_conv_links[bn_name] = conv_name
norm_conv_links[bn_name] = conv_name

self.trace_bn_conv_links(conv_grad_fn, module2name,
var2module, bn_conv_links,
visited)
self.trace_norm_conv_links(conv_grad_fn, module2name,
var2module, norm_conv_links,
visited)

else:
# If the op is AccumulateGrad, parents is (),
parents = grad_fn.next_functions
if parents is not None:
for parent in parents:
self.trace_bn_conv_links(parent, module2name,
var2module, bn_conv_links,
visited)
self.trace_norm_conv_links(parent, module2name,
var2module, norm_conv_links,
visited)

def find_backward_parser(self, grad_fn):
for name, parser in BACKWARD_PARSER_DICT.items():
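The behaviour described in the trace_norm_conv_links docstring can be reproduced with a short probe. A hedged sketch, assuming a recent PyTorch build (the exact backward op names, e.g. NativeBatchNormBackward0 vs NativeBatchNormBackward, vary across versions):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, 3)
bn = nn.BatchNorm2d(3)
inorm = nn.InstanceNorm2d(3, affine=True)
img = torch.rand(1, 3, 32, 32)

# BatchNorm: the conv backward op is a direct parent of the norm backward op,
# and the weight/bias AccumulateGrad leaves are direct parents as well.
out = bn(conv(img))
print(type(out.grad_fn).__name__)
print([type(fn).__name__ for fn, _ in out.grad_fn.next_functions])

# InstanceNorm: View/Repeat ops sit between the norm op and both the conv op
# and the weight/bias leaves, so trace_norm_conv_links walks down the graph
# (is_conv_grad_fn / is_leaf_grad_fn) instead of reading a fixed depth.
out = inorm(conv(img))
print(type(out.grad_fn).__name__)
print([type(fn).__name__ for fn, _ in out.grad_fn.next_functions])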
122 changes: 118 additions & 4 deletions tests/test_models/test_pruner.py
@@ -3,7 +3,7 @@

import pytest
import torch
from mmcv import ConfigDict
from mmcv import ConfigDict, digit_version

from mmrazor.models.builder import ARCHITECTURES, PRUNERS

@@ -86,7 +86,7 @@ def test_ratio_pruner():
losses = architecture(imgs, return_loss=True, gt_label=label)
assert losses['loss'].item() > 0

# test making groups logic when there are shared modules in the model
# test models with shared module
model_cfg = ConfigDict(
type='mmdet.RetinaNet',
backbone=dict(
@@ -159,13 +159,127 @@ def test_ratio_pruner():
pruner = PRUNERS.build(pruner_cfg)
pruner.prepare_from_supernet(architecture)
subnet_dict = pruner.sample_subnet()
assert isinstance(subnet_dict, dict)
pruner.set_subnet(subnet_dict)
subnet_dict = pruner.export_subnet()
assert isinstance(subnet_dict, dict)
pruner.deploy_subnet(architecture, subnet_dict)
architecture.forward_dummy(imgs)

# test models with concat operations
model_cfg = ConfigDict(
type='mmdet.YOLOX',
input_size=(640, 640),
random_size_range=(15, 25),
random_size_interval=10,
backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
neck=dict(
type='YOLOXPAFPN',
in_channels=[128, 256, 512],
out_channels=128,
num_csp_blocks=1),
bbox_head=dict(
type='YOLOXHead',
num_classes=80,
in_channels=128,
feat_channels=128),
train_cfg=dict(
assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
# In order to align the source code, the threshold of the val phase is
# 0.01, and the threshold of the test phase is 0.001.
test_cfg=dict(
score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))

architecture_cfg = dict(
type='MMDetArchitecture',
model=model_cfg,
)

architecture = ARCHITECTURES.build(architecture_cfg)
pruner.prepare_from_supernet(architecture)
subnet_dict = pruner.sample_subnet()
pruner.set_subnet(subnet_dict)
subnet_dict = pruner.export_subnet()
pruner.deploy_subnet(architecture, subnet_dict)
architecture.forward_dummy(imgs)

# test models with groupnorm
model_cfg = ConfigDict(
type='mmdet.ATSS',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(
type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
bbox_head=dict(
type='ATSSHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(type='ATSSAssigner', topk=9),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))

architecture_cfg = dict(
type='MMDetArchitecture',
model=model_cfg,
)

architecture = ARCHITECTURES.build(architecture_cfg)
# ``StructurePruner`` requires pytorch>=1.6.0 to auto-trace GroupNorm
# correctly
min_required_version = '1.6.0'
if digit_version(torch.__version__) < digit_version(min_required_version):
with pytest.raises(AssertionError):
pruner.prepare_from_supernet(architecture)
else:
pruner.prepare_from_supernet(architecture)
subnet_dict = pruner.sample_subnet()
pruner.set_subnet(subnet_dict)
subnet_dict = pruner.export_subnet()
pruner.deploy_subnet(architecture, subnet_dict)
architecture.forward_dummy(imgs)


def _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, should_fail):
import os
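The version gate this test exercises can be reproduced in isolation; a minimal sketch of the check that prepare_from_supernet now performs (error text paraphrased):

import torch
from mmcv import digit_version

# The commit requires PyTorch >= 1.6.0 before tracing GroupNorm, since the
# tracer looks for the NativeGroupNormBackward op listed in GN above.
min_required_version = '1.6.0'
assert digit_version(torch.__version__) >= digit_version(
    min_required_version), (
        f'Requires pytorch>={min_required_version} to auto-trace GroupNorm '
        f'correctly, got {torch.__version__}.')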
