[BUG] Support pruning of models containing GroupNorm or InstanceNorm. #144

Merged 5 commits on May 4, 2022
17 changes: 17 additions & 0 deletions mmrazor/models/pruners/ratio_pruning.py
@@ -2,6 +2,7 @@
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules import GroupNorm

from mmrazor.models.builder import PRUNERS
from .structure_pruning import StructurePruner
@@ -31,6 +32,22 @@ def __init__(self, ratios, **kwargs):
self.ratios = ratios
self.min_ratio = ratios[0]

def _check_pruner(self, supernet):
for module in supernet.model.modules():
if isinstance(module, GroupNorm):
num_channels = module.num_channels
num_groups = module.num_groups
for ratio in self.ratios:
new_channels = int(round(num_channels * ratio))
assert (num_channels * ratio) % num_groups == 0, \
f'Expected number of channels in input of GroupNorm ' \
f'to be divisible by num_groups, but number of ' \
f'channels may be {new_channels} according to ' \
f'ratio {ratio} and num_groups={num_groups}'

def prepare_from_supernet(self, supernet):
super(RatioPruner, self).prepare_from_supernet(supernet)

def get_channel_mask(self, out_mask):
"""Randomly choose a width ratio of a layer from ``ratios``"""
out_channels = out_mask.size(1)
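To see the constraint that _check_pruner guards against, here is a minimal plain-PyTorch sketch (not part of the patch; the ratios and channel counts are illustrative): nn.GroupNorm rejects channel counts that are not divisible by num_groups, so a width ratio that breaks divisibility would make the pruned layer unconstructible.

# Example (not part of the diff): why pruned channel counts must stay
# divisible by num_groups. Values below are illustrative.
import torch
import torch.nn as nn

num_groups = 8
num_channels = 32                 # channels before pruning
ratios = [0.375, 0.5]             # candidate width ratios

for ratio in ratios:
    new_channels = int(round(num_channels * ratio))
    if new_channels % num_groups != 0:
        # nn.GroupNorm would raise: num_channels must be divisible by num_groups
        print(f'ratio={ratio} -> {new_channels} channels: invalid for '
              f'num_groups={num_groups}')
        continue
    gn = nn.GroupNorm(num_groups=num_groups, num_channels=new_channels)
    out = gn(torch.rand(1, new_channels, 4, 4))
    print(f'ratio={ratio} -> {new_channels} channels: ok, output shape '
          f'{tuple(out.shape)}')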
126 changes: 88 additions & 38 deletions mmrazor/models/pruners/structure_pruning.py
@@ -6,9 +6,12 @@

import torch
import torch.nn as nn
from mmcv import digit_version
from mmcv.runner import BaseModule
from ordered_set import OrderedSet
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm

from mmrazor.models.builder import PRUNERS
from .utils import SwitchableBatchNorm2d
@@ -19,14 +22,13 @@
FC = ('ThAddmmBackward', 'AddmmBackward', 'MmBackward')
BN = ('ThnnBatchNormBackward', 'CudnnBatchNormBackward',
'NativeBatchNormBackward')
GN = ('NativeGroupNormBackward', )
CONCAT = ('CatBackward', )
# the modules which contain NON_PASS grad_fn need to change their parameter size
# according to the channels after pruning
NON_PASS = CONV + FC
NON_PASS_MODULE = (nn.Conv2d, nn.Linear)

PASS = BN
PASS_MODULE = (_BatchNorm)
PASS = BN + GN
NORM = BN + GN

BACKWARD_PARSER_DICT = dict()
MAKE_GROUP_PARSER_DICT = dict()
@@ -122,6 +124,12 @@ def prepare_from_supernet(self, supernet):
tmp_shared_module_hook_handles = list()

for name, module in supernet.model.named_modules():
if isinstance(module, nn.GroupNorm):
min_required_version = '1.6.0'
assert digit_version(torch.__version__) >= digit_version(
min_required_version
), f'Requires pytorch>={min_required_version} to auto-trace ' \
f'GroupNorm correctly.'
if hasattr(module, 'weight'):
# trace shared modules
module.cnt = 0
@@ -172,10 +180,10 @@ def prepare_from_supernet(self, supernet):
self.trace_non_pass_path(pseudo_loss.grad_fn, module2name, var2module,
cur_non_pass_path, non_pass_paths, visited)

bn_conv_links = dict()
self.trace_bn_conv_links(pseudo_loss.grad_fn, module2name, var2module,
bn_conv_links, visited)
self.bn_conv_links = bn_conv_links
norm_conv_links = dict()
self.trace_norm_conv_links(pseudo_loss.grad_fn, module2name,
var2module, norm_conv_links, visited)
self.norm_conv_links = norm_conv_links

# a node can be the name of a conv module or a str like 'concat_{id}'
node2parents = self.find_node_parents(non_pass_paths)
@@ -268,12 +276,12 @@ def set_subnet(self, subnet_dict):
module = self.name2module[module_name]
module.out_mask = subnet_dict[space_id].to(module.out_mask.device)

for bn, conv in self.bn_conv_links.items():
module = self.name2module[bn]
for norm, conv in self.norm_conv_links.items():
module = self.name2module[norm]
conv_space_id = self.get_space_id(conv)
# conv_space_id is None means the conv layer in front of
# this bn module can not be pruned. So we should not set
# the out_mask of this bn layer
# this normalization module can not be pruned. So we should not set
# the out_mask of this normalization layer
if conv_space_id is not None:
module.out_mask = subnet_dict[conv_space_id].to(
module.out_mask.device)
@@ -458,7 +466,9 @@ def add_pruning_attrs(self, module):
module.register_buffer(
'out_mask', module.weight.new_ones((1, module.out_features), ))
module.forward = self.modify_fc_forward(module)
if isinstance(module, nn.modules.batchnorm._BatchNorm):
if (isinstance(module, _BatchNorm)
or isinstance(module, _InstanceNorm)
or isinstance(module, GroupNorm)):
module.register_buffer(
'out_mask',
module.weight.new_ones((1, len(module.weight), 1, 1), ))
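For context, a minimal sketch of what the (1, C, 1, 1) out_mask buffer registered above is for (plain PyTorch; mmrazor patches the module's forward elsewhere in this file rather than multiplying manually as done here): zeroed entries mark pruned channels of the normalization layer's output.

# Example (not part of the diff): masking a norm layer's output channels
# with a (1, C, 1, 1) buffer. mmrazor's actual forward patch differs.
import torch
import torch.nn as nn

gn = nn.GroupNorm(num_groups=2, num_channels=8)
gn.register_buffer('out_mask', gn.weight.new_ones((1, len(gn.weight), 1, 1)))
gn.out_mask[:, 4:] = 0            # pretend the last 4 channels were pruned

x = torch.rand(1, 8, 5, 5)
y = gn(x) * gn.out_mask           # pruned channels are zeroed
print(y[:, 4:].abs().sum())       # tensor(0.)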
@@ -625,39 +635,79 @@ def trace_non_pass_path(self, grad_fn, module2name, var2module, cur_path,
else:
result_paths.append(copy.deepcopy(cur_path))

def trace_bn_conv_links(self, grad_fn, module2name, var2module,
bn_conv_links, visited):
"""Get the convolutional layer placed before a bn layer in the model.
def trace_norm_conv_links(self, grad_fn, module2name, var2module,
norm_conv_links, visited):
"""Get the convolutional layer placed before a normalization layer in
the model.

Example:
>>> conv = nn.Conv2d(3, 3, 3)
>>> bn = nn.BatchNorm2d(3)
>>> norm = nn.BatchNorm2d(3)
>>> pseudo_img = torch.rand(1, 3, 224, 224)
>>> out = bn(conv(pseudo_img))
>>> out = norm(conv(pseudo_img))
>>> print(out.grad_fn)
<NativeBatchNormBackward object at 0x0000022BC709DB08>
>>> print(out.grad_fn.next_functions)
((<ThnnConv2DBackward object at 0x0000020E40639688>, 0),
(<AccumulateGrad object at 0x0000020E40639208>, 0),
(<AccumulateGrad object at 0x0000020E406398C8>, 0))
>>> # op.next_functions[0][0] is ThnnConv2DBackward means
>>> # the parent of this NativeBatchNormBackward op is
>>> # ThnnConv2DBackward
>>> # op.next_functions[1][0].variable is the weight of this bn
>>> # module
>>> # op.next_functions[2][0].variable is the bias of this bn
>>> # module
>>> # op.next_functions[1][0].variable is the weight of this
>>> # normalization module
>>> # op.next_functions[2][0].variable is the bias of this
>>> # normalization module

>>> # Things are different in InstanceNorm
>>> conv = nn.Conv2d(3, 3, 3)
>>> norm = nn.InstanceNorm2d(3, affine=True)
>>> out = norm(conv(pseudo_img))
>>> print(out.grad_fn)
<ViewBackward object at 0x0000022BC709DD48>
>>> print(out.grad_fn.next_functions)
((<NativeBatchNormBackward object at 0x0000022BC81E8A08>, 0),)
>>> print(out.grad_fn.next_functions[0][0].next_functions)
((<ViewBackward object at 0x0000022BC81E8DC8>, 0),
(<RepeatBackward object at 0x0000022BC81E8D08>, 0),
(<RepeatBackward object at 0x0000022BC81E81C8>, 0))
>>> # Hence, a DFS (depth-first search) is necessary.
"""
grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn
if grad_fn is not None:
is_bn_grad_fn = False
for fn_name in BN:

def is_norm_grad_fn(grad_fn):
for fn_name in NORM:
if type(grad_fn).__name__.startswith(fn_name):
is_bn_grad_fn = True
break
return True
return False

def is_conv_grad_fn(grad_fn):
for fn_name in CONV:
if type(grad_fn).__name__.startswith(fn_name):
return True
return False

if is_bn_grad_fn:
def is_leaf_grad_fn(grad_fn):
if type(grad_fn).__name__ == 'AccumulateGrad':
return True
return False

grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn
if grad_fn is not None:
if is_norm_grad_fn(grad_fn):
conv_grad_fn = grad_fn.next_functions[0][0]
conv_var = conv_grad_fn.next_functions[1][0].variable
bn_var = grad_fn.next_functions[1][0].variable
while not is_conv_grad_fn(conv_grad_fn):
conv_grad_fn = conv_grad_fn.next_functions[0][0]

leaf_grad_fn = conv_grad_fn.next_functions[1][0]
while not is_leaf_grad_fn(leaf_grad_fn):
leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
conv_var = leaf_grad_fn.variable

leaf_grad_fn = grad_fn.next_functions[1][0]
while not is_leaf_grad_fn(leaf_grad_fn):
leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
bn_var = leaf_grad_fn.variable

conv_module = var2module[id(conv_var)]
bn_module = var2module[id(bn_var)]
conv_name = module2name[conv_module]
@@ -666,20 +716,20 @@ def trace_bn_conv_links(self, grad_fn, module2name, var2module,
pass
else:
visited[bn_name] = True
bn_conv_links[bn_name] = conv_name
norm_conv_links[bn_name] = conv_name

self.trace_bn_conv_links(conv_grad_fn, module2name,
var2module, bn_conv_links,
visited)
self.trace_norm_conv_links(conv_grad_fn, module2name,
var2module, norm_conv_links,
visited)

else:
# If the op is AccumulateGrad, parents is (),
parents = grad_fn.next_functions
if parents is not None:
for parent in parents:
self.trace_bn_conv_links(parent, module2name,
var2module, bn_conv_links,
visited)
self.trace_norm_conv_links(parent, module2name,
var2module, norm_conv_links,
visited)

def find_backward_parser(self, grad_fn):
for name, parser in BACKWARD_PARSER_DICT.items():
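The docstring above can be reproduced as a short standalone script. This sketch (plain PyTorch, independent of mmrazor; exact grad_fn class names vary across PyTorch versions) shows why a fixed next_functions lookup is enough for BatchNorm but a depth-first walk is needed for InstanceNorm, whose backward graph inserts View/Repeat nodes between the norm op and the conv and parameter leaves.

# Example (not part of the diff): inspecting the autograd graph that
# trace_norm_conv_links walks. grad_fn class names differ across versions.
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, 3)
pseudo_img = torch.rand(1, 3, 8, 8)
bn_out = nn.BatchNorm2d(3)(conv(pseudo_img))
in_out = nn.InstanceNorm2d(3, affine=True)(conv(pseudo_img))

# BatchNorm: the conv backward node is a direct parent of the norm node.
print(type(bn_out.grad_fn).__name__)
print([type(fn).__name__ for fn, _ in bn_out.grad_fn.next_functions])

# InstanceNorm: View/Repeat nodes sit between the norm node and the conv /
# AccumulateGrad leaves, so a fixed next_functions index is not enough.
print(type(in_out.grad_fn).__name__)
print([type(fn).__name__ for fn, _ in in_out.grad_fn.next_functions])

def find_leaf_variables(grad_fn, leaves):
    """Depth-first walk collecting AccumulateGrad leaves (the parameters)."""
    if grad_fn is None:
        return
    if type(grad_fn).__name__ == 'AccumulateGrad':
        leaves.append(grad_fn.variable)
        return
    for parent, _ in grad_fn.next_functions:
        find_leaf_variables(parent, leaves)

leaves = []
find_leaf_variables(in_out.grad_fn, leaves)
print(any(leaf is conv.weight for leaf in leaves))   # True: the conv is reachable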
122 changes: 118 additions & 4 deletions tests/test_models/test_pruner.py
@@ -3,7 +3,7 @@

import pytest
import torch
from mmcv import ConfigDict
from mmcv import ConfigDict, digit_version

from mmrazor.models.builder import ARCHITECTURES, PRUNERS

@@ -86,7 +86,7 @@ def test_ratio_pruner():
losses = architecture(imgs, return_loss=True, gt_label=label)
assert losses['loss'].item() > 0

# test making groups logic when there are shared modules in the model
# test models with shared module
model_cfg = ConfigDict(
type='mmdet.RetinaNet',
backbone=dict(
@@ -159,13 +159,127 @@ def test_ratio_pruner():
pruner = PRUNERS.build(pruner_cfg)
pruner.prepare_from_supernet(architecture)
subnet_dict = pruner.sample_subnet()
assert isinstance(subnet_dict, dict)
pruner.set_subnet(subnet_dict)
subnet_dict = pruner.export_subnet()
assert isinstance(subnet_dict, dict)
pruner.deploy_subnet(architecture, subnet_dict)
architecture.forward_dummy(imgs)

# test models with concat operations
model_cfg = ConfigDict(
type='mmdet.YOLOX',
input_size=(640, 640),
random_size_range=(15, 25),
random_size_interval=10,
backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
neck=dict(
type='YOLOXPAFPN',
in_channels=[128, 256, 512],
out_channels=128,
num_csp_blocks=1),
bbox_head=dict(
type='YOLOXHead',
num_classes=80,
in_channels=128,
feat_channels=128),
train_cfg=dict(
assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
# In order to align with the source code, the threshold of the val phase is
# 0.01, and the threshold of the test phase is 0.001.
test_cfg=dict(
score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))

architecture_cfg = dict(
type='MMDetArchitecture',
model=model_cfg,
)

architecture = ARCHITECTURES.build(architecture_cfg)
pruner.prepare_from_supernet(architecture)
subnet_dict = pruner.sample_subnet()
pruner.set_subnet(subnet_dict)
subnet_dict = pruner.export_subnet()
pruner.deploy_subnet(architecture, subnet_dict)
architecture.forward_dummy(imgs)

# test models with groupnorm
model_cfg = ConfigDict(
type='mmdet.ATSS',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(
type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
bbox_head=dict(
type='ATSSHead',
num_classes=80,
in_channels=256,
stacked_convs=4,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(type='ATSSAssigner', topk=9),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))

architecture_cfg = dict(
type='MMDetArchitecture',
model=model_cfg,
)

architecture = ARCHITECTURES.build(architecture_cfg)
# ``StructurePruner`` requires pytorch>=1.6.0 to auto-trace GroupNorm
# correctly
min_required_version = '1.6.0'
if digit_version(torch.__version__) < digit_version(min_required_version):
with pytest.raises(AssertionError):
pruner.prepare_from_supernet(architecture)
else:
pruner.prepare_from_supernet(architecture)
subnet_dict = pruner.sample_subnet()
pruner.set_subnet(subnet_dict)
subnet_dict = pruner.export_subnet()
pruner.deploy_subnet(architecture, subnet_dict)
architecture.forward_dummy(imgs)


def _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, should_fail):
import os
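To make the version guard used in prepare_from_supernet and in the test above concrete, here is a small standalone sketch (the printed class name is what the GN tuple in structure_pruning.py matches; it reflects recent PyTorch releases and may differ on older ones):

# Example (not part of the diff): checking the minimum PyTorch version and
# the GroupNorm backward op name that the GN tuple is expected to match.
import torch
import torch.nn as nn
from mmcv import digit_version

min_required_version = '1.6.0'
assert digit_version(torch.__version__) >= digit_version(min_required_version), \
    f'Requires pytorch>={min_required_version} to auto-trace GroupNorm correctly.'

out = nn.GroupNorm(num_groups=2, num_channels=4)(torch.rand(1, 4, 8, 8))
print(type(out.grad_fn).__name__)   # expected to start with 'NativeGroupNormBackward'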