diff --git a/mmrazor/models/pruners/ratio_pruning.py b/mmrazor/models/pruners/ratio_pruning.py
index 218b142d3..151677d99 100644
--- a/mmrazor/models/pruners/ratio_pruning.py
+++ b/mmrazor/models/pruners/ratio_pruning.py
@@ -2,6 +2,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.nn.modules import GroupNorm
 
 from mmrazor.models.builder import PRUNERS
 from .structure_pruning import StructurePruner
@@ -31,6 +32,22 @@ def __init__(self, ratios, **kwargs):
         self.ratios = ratios
         self.min_ratio = ratios[0]
 
+    def _check_pruner(self, supernet):
+        for module in supernet.model.modules():
+            if isinstance(module, GroupNorm):
+                num_channels = module.num_channels
+                num_groups = module.num_groups
+                for ratio in self.ratios:
+                    new_channels = int(round(num_channels * ratio))
+                    assert (num_channels * ratio) % num_groups == 0, \
+                        f'Expected number of channels in input of GroupNorm ' \
+                        f'to be divisible by num_groups, but number of ' \
+                        f'channels may be {new_channels} according to ' \
+                        f'ratio {ratio} and num_groups={num_groups}'
+
+    def prepare_from_supernet(self, supernet):
+        self._check_pruner(supernet)
+        super(RatioPruner, self).prepare_from_supernet(supernet)
+
     def get_channel_mask(self, out_mask):
         """Randomly choose a width ratio of a layer from ``ratios``"""
         out_channels = out_mask.size(1)
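The new `_check_pruner` hook rejects any ratio that would leave a `GroupNorm` with a channel count its `num_groups` does not divide, because `torch.nn.GroupNorm` fails at runtime in that case. A minimal standalone sketch of the failure this guards against (illustrative values, not part of the patch):

```python
import torch
import torch.nn as nn

num_groups, num_channels, ratio = 32, 64, 0.75
new_channels = int(round(num_channels * ratio))    # 48 channels after pruning
print((num_channels * ratio) % num_groups == 0)    # False -> the assert above fires

gn = nn.GroupNorm(num_groups, num_channels, affine=False)
try:
    gn(torch.rand(1, new_channels, 8, 8))          # 48 channels, 32 groups
except Exception as e:
    print(e)  # channels in input must be divisible by num_groups
```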
diff --git a/mmrazor/models/pruners/structure_pruning.py b/mmrazor/models/pruners/structure_pruning.py
index af1eae2b0..b755c2e5d 100644
--- a/mmrazor/models/pruners/structure_pruning.py
+++ b/mmrazor/models/pruners/structure_pruning.py
@@ -6,9 +6,12 @@
 import torch
 import torch.nn as nn
+from mmcv import digit_version
 from mmcv.runner import BaseModule
 from ordered_set import OrderedSet
+from torch.nn.modules import GroupNorm
 from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.instancenorm import _InstanceNorm
 
 from mmrazor.models.builder import PRUNERS
 from .utils import SwitchableBatchNorm2d
@@ -19,14 +22,13 @@
 FC = ('ThAddmmBackward', 'AddmmBackward', 'MmBackward')
 BN = ('ThnnBatchNormBackward', 'CudnnBatchNormBackward',
       'NativeBatchNormBackward')
+GN = ('NativeGroupNormBackward', )
 CONCAT = ('CatBackward', )
 
 # the modules which contains NON_PASS grad_fn need to change the parameter size
 # according to channels after pruning
 NON_PASS = CONV + FC
-NON_PASS_MODULE = (nn.Conv2d, nn.Linear)
-
-PASS = BN
-PASS_MODULE = (_BatchNorm)
+PASS = BN + GN
+NORM = BN + GN
 
 BACKWARD_PARSER_DICT = dict()
 MAKE_GROUP_PARSER_DICT = dict()
@@ -122,6 +124,12 @@ def prepare_from_supernet(self, supernet):
         tmp_shared_module_hook_handles = list()
 
         for name, module in supernet.model.named_modules():
+            if isinstance(module, nn.GroupNorm):
+                min_required_version = '1.6.0'
+                assert digit_version(torch.__version__) >= digit_version(
+                    min_required_version
+                ), f'Requires pytorch>={min_required_version} to auto-trace ' \
+                    f'GroupNorm correctly.'
             if hasattr(module, 'weight'):
                 # trace shared modules
                 module.cnt = 0
@@ -172,10 +180,10 @@ def prepare_from_supernet(self, supernet):
         self.trace_non_pass_path(pseudo_loss.grad_fn, module2name, var2module,
                                  cur_non_pass_path, non_pass_paths, visited)
 
-        bn_conv_links = dict()
-        self.trace_bn_conv_links(pseudo_loss.grad_fn, module2name, var2module,
-                                 bn_conv_links, visited)
-        self.bn_conv_links = bn_conv_links
+        norm_conv_links = dict()
+        self.trace_norm_conv_links(pseudo_loss.grad_fn, module2name,
+                                   var2module, norm_conv_links, visited)
+        self.norm_conv_links = norm_conv_links
 
         # a node can be the name of a conv module or a str like 'concat_{id}'
         node2parents = self.find_node_parents(non_pass_paths)
@@ -268,12 +276,12 @@ def set_subnet(self, subnet_dict):
             module = self.name2module[module_name]
             module.out_mask = subnet_dict[space_id].to(module.out_mask.device)
 
-        for bn, conv in self.bn_conv_links.items():
-            module = self.name2module[bn]
+        for norm, conv in self.norm_conv_links.items():
+            module = self.name2module[norm]
             conv_space_id = self.get_space_id(conv)
             # conv_space_id is None means the conv layer in front of
-            # this bn module can not be pruned. So we should not set
-            # the out_mask of this bn layer
+            # this normalization module can not be pruned. So we should not set
+            # the out_mask of this normalization layer
             if conv_space_id is not None:
                 module.out_mask = subnet_dict[conv_space_id].to(
                     module.out_mask.device)
@@ -458,7 +466,9 @@ def add_pruning_attrs(self, module):
             module.register_buffer(
                 'out_mask', module.weight.new_ones((1, module.out_features), ))
             module.forward = self.modify_fc_forward(module)
-        if isinstance(module, nn.modules.batchnorm._BatchNorm):
+        if (isinstance(module, _BatchNorm)
+                or isinstance(module, _InstanceNorm)
+                or isinstance(module, GroupNorm)):
             module.register_buffer(
                 'out_mask', module.weight.new_ones((1, len(module.weight), 1, 1), ))
@@ -625,15 +635,18 @@ def trace_non_pass_path(self, grad_fn, module2name, var2module, cur_path,
         else:
             result_paths.append(copy.deepcopy(cur_path))
 
-    def trace_bn_conv_links(self, grad_fn, module2name, var2module,
-                            bn_conv_links, visited):
-        """Get the convolutional layer placed before a bn layer in the model.
+    def trace_norm_conv_links(self, grad_fn, module2name, var2module,
+                              norm_conv_links, visited):
+        """Get the convolutional layer placed before a normalization layer in
+        the model.
 
         Example:
             >>> conv = nn.Conv2d(3, 3, 3)
-            >>> bn = nn.BatchNorm2d(3)
+            >>> norm = nn.BatchNorm2d(3)
             >>> pseudo_img = torch.rand(1, 3, 224, 224)
-            >>> out = bn(conv(pseudo_img))
+            >>> out = norm(conv(pseudo_img))
+            >>> print(out.grad_fn)
+            <NativeBatchNormBackward object at 0x...>
+            >>> print(out.grad_fn.next_functions)
             ((<ThnnConv2DBackward object at 0x...>, 0),
              (<AccumulateGrad object at 0x...>, 0),
@@ -641,23 +654,60 @@ def trace_bn_conv_links(self, grad_fn, module2name, var2module,
              (<AccumulateGrad object at 0x...>, 0))
             >>> # op.next_functions[0][0] is ThnnConv2DBackward means
             >>> # the parent of this NativeBatchNormBackward op is
             >>> # ThnnConv2DBackward
-            >>> # op.next_functions[1][0].variable is the weight of this bn
-            >>> # module
-            >>> # op.next_functions[2][0].variable is the bias of this bn
-            >>> # module
+            >>> # op.next_functions[1][0].variable is the weight of this
+            >>> # normalization module
+            >>> # op.next_functions[2][0].variable is the bias of this
+            >>> # normalization module
+
+            >>> # Things are different in InstanceNorm
+            >>> conv = nn.Conv2d(3, 3, 3)
+            >>> norm = nn.InstanceNorm2d(3, affine=True)
+            >>> out = norm(conv(pseudo_img))
+            >>> print(out.grad_fn)
+            <ViewBackward object at 0x...>
+            >>> print(out.grad_fn.next_functions)
+            ((<NativeBatchNormBackward object at 0x...>, 0),)
+            >>> print(out.grad_fn.next_functions[0][0].next_functions)
+            ((<ViewBackward object at 0x...>, 0),
+             (<RepeatBackward object at 0x...>, 0),
+             (<RepeatBackward object at 0x...>, 0))
+            >>> # Hence, a dfs is necessary.
         """
-        grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn
-        if grad_fn is not None:
-            is_bn_grad_fn = False
-            for fn_name in BN:
+
+        def is_norm_grad_fn(grad_fn):
+            for fn_name in NORM:
                 if type(grad_fn).__name__.startswith(fn_name):
-                    is_bn_grad_fn = True
-                    break
+                    return True
+            return False
+
+        def is_conv_grad_fn(grad_fn):
+            for fn_name in CONV:
+                if type(grad_fn).__name__.startswith(fn_name):
+                    return True
+            return False
 
-            if is_bn_grad_fn:
+        def is_leaf_grad_fn(grad_fn):
+            if type(grad_fn).__name__ == 'AccumulateGrad':
+                return True
+            return False
+
+        grad_fn = grad_fn[0] if isinstance(grad_fn, (list, tuple)) else grad_fn
+        if grad_fn is not None:
+            if is_norm_grad_fn(grad_fn):
                 conv_grad_fn = grad_fn.next_functions[0][0]
-                conv_var = conv_grad_fn.next_functions[1][0].variable
-                bn_var = grad_fn.next_functions[1][0].variable
+                while not is_conv_grad_fn(conv_grad_fn):
+                    conv_grad_fn = conv_grad_fn.next_functions[0][0]
+
+                leaf_grad_fn = conv_grad_fn.next_functions[1][0]
+                while not is_leaf_grad_fn(leaf_grad_fn):
+                    leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
+                conv_var = leaf_grad_fn.variable
+
+                leaf_grad_fn = grad_fn.next_functions[1][0]
+                while not is_leaf_grad_fn(leaf_grad_fn):
+                    leaf_grad_fn = leaf_grad_fn.next_functions[0][0]
+                bn_var = leaf_grad_fn.variable
+
                 conv_module = var2module[id(conv_var)]
                 bn_module = var2module[id(bn_var)]
                 conv_name = module2name[conv_module]
@@ -666,20 +716,20 @@ def trace_bn_conv_links(self, grad_fn, module2name, var2module,
                     pass
                 else:
                     visited[bn_name] = True
-                    bn_conv_links[bn_name] = conv_name
+                    norm_conv_links[bn_name] = conv_name
 
-                self.trace_bn_conv_links(conv_grad_fn, module2name,
-                                         var2module, bn_conv_links,
-                                         visited)
+                self.trace_norm_conv_links(conv_grad_fn, module2name,
+                                           var2module, norm_conv_links,
+                                           visited)
             else:
                 # If the op is AccumulateGrad, parents is (),
                 parents = grad_fn.next_functions
                 if parents is not None:
                     for parent in parents:
-                        self.trace_bn_conv_links(parent, module2name,
-                                                 var2module, bn_conv_links,
-                                                 visited)
+                        self.trace_norm_conv_links(parent, module2name,
+                                                   var2module, norm_conv_links,
+                                                   visited)
 
     def find_backward_parser(self, grad_fn):
         for name, parser in BACKWARD_PARSER_DICT.items():
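The nested `is_*_grad_fn` helpers and the leaf walk replace the old assumption that a normalization op's conv input and weight are direct parents in the autograd graph; for `InstanceNorm` they sit behind extra view/repeat nodes, which is why the docstring ends with "Hence, a dfs is necessary." A small standalone sketch of the same walk (assumes the PyTorch >= 1.6 graph layout shown in the docstring; `walk_to_leaf` is an illustrative helper, not part of the patch):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 3, 3)
norm = nn.InstanceNorm2d(3, affine=True)
out = norm(conv(torch.rand(1, 3, 224, 224)))

def walk_to_leaf(grad_fn):
    # Mirror is_leaf_grad_fn: follow next_functions[0][0] until an
    # AccumulateGrad node (a leaf parameter) is reached.
    while type(grad_fn).__name__ != 'AccumulateGrad':
        grad_fn = grad_fn.next_functions[0][0]
    return grad_fn

# The norm op is not out.grad_fn itself; descend until a batch/group norm node.
node = out.grad_fn
while not type(node).__name__.startswith(('NativeBatchNorm', 'NativeGroupNorm')):
    node = node.next_functions[0][0]

weight_leaf = walk_to_leaf(node.next_functions[1][0])
print(weight_leaf.variable is norm.weight)  # True: reached only via the extra hops
```

The same walk on the conv branch eventually reaches the conv weight, which is how `trace_norm_conv_links` pairs each normalization layer with the conv in front of it.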
diff --git a/tests/test_models/test_pruner.py b/tests/test_models/test_pruner.py
index 835c4894f..12323022a 100644
--- a/tests/test_models/test_pruner.py
+++ b/tests/test_models/test_pruner.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
-from mmcv import ConfigDict
+from mmcv import ConfigDict, digit_version
 
 from mmrazor.models.builder import ARCHITECTURES, PRUNERS
@@ -86,7 +86,7 @@ def test_ratio_pruner():
     losses = architecture(imgs, return_loss=True, gt_label=label)
     assert losses['loss'].item() > 0
 
-    # test making groups logic when there are shared modules in the model
+    # test models with shared modules
     model_cfg = ConfigDict(
         type='mmdet.RetinaNet',
         backbone=dict(
@@ -159,13 +159,127 @@ def test_ratio_pruner():
     pruner = PRUNERS.build(pruner_cfg)
     pruner.prepare_from_supernet(architecture)
     subnet_dict = pruner.sample_subnet()
-    assert isinstance(subnet_dict, dict)
     pruner.set_subnet(subnet_dict)
     subnet_dict = pruner.export_subnet()
-    assert isinstance(subnet_dict, dict)
     pruner.deploy_subnet(architecture, subnet_dict)
     architecture.forward_dummy(imgs)
 
+    # test models with concat operations
+    model_cfg = ConfigDict(
+        type='mmdet.YOLOX',
+        input_size=(640, 640),
+        random_size_range=(15, 25),
+        random_size_interval=10,
+        backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
+        neck=dict(
+            type='YOLOXPAFPN',
+            in_channels=[128, 256, 512],
+            out_channels=128,
+            num_csp_blocks=1),
+        bbox_head=dict(
+            type='YOLOXHead',
+            num_classes=80,
+            in_channels=128,
+            feat_channels=128),
+        train_cfg=dict(
+            assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
+        # In order to align the source code, the threshold of the val phase is
+        # 0.01, and the threshold of the test phase is 0.001.
+        test_cfg=dict(
+            score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
+
+    architecture_cfg = dict(
+        type='MMDetArchitecture',
+        model=model_cfg,
+    )
+
+    architecture = ARCHITECTURES.build(architecture_cfg)
+    pruner.prepare_from_supernet(architecture)
+    subnet_dict = pruner.sample_subnet()
+    pruner.set_subnet(subnet_dict)
+    subnet_dict = pruner.export_subnet()
+    pruner.deploy_subnet(architecture, subnet_dict)
+    architecture.forward_dummy(imgs)
+
+    # test models with groupnorm
+    model_cfg = ConfigDict(
+        type='mmdet.ATSS',
+        backbone=dict(
+            type='ResNet',
+            depth=50,
+            num_stages=4,
+            out_indices=(0, 1, 2, 3),
+            frozen_stages=1,
+            norm_cfg=dict(type='BN', requires_grad=True),
+            norm_eval=True,
+            style='pytorch',
+            init_cfg=dict(
+                type='Pretrained', checkpoint='torchvision://resnet50')),
+        neck=dict(
+            type='FPN',
+            in_channels=[256, 512, 1024, 2048],
+            out_channels=256,
+            start_level=1,
+            add_extra_convs='on_output',
+            num_outs=5),
+        bbox_head=dict(
+            type='ATSSHead',
+            num_classes=80,
+            in_channels=256,
+            stacked_convs=4,
+            feat_channels=256,
+            anchor_generator=dict(
+                type='AnchorGenerator',
+                ratios=[1.0],
+                octave_base_scale=8,
+                scales_per_octave=1,
+                strides=[8, 16, 32, 64, 128]),
+            bbox_coder=dict(
+                type='DeltaXYWHBBoxCoder',
+                target_means=[.0, .0, .0, .0],
+                target_stds=[0.1, 0.1, 0.2, 0.2]),
+            loss_cls=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=1.0),
+            loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
+            loss_centerness=dict(
+                type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
+        # training and testing settings
+        train_cfg=dict(
+            assigner=dict(type='ATSSAssigner', topk=9),
+            allowed_border=-1,
+            pos_weight=-1,
+            debug=False),
+        test_cfg=dict(
+            nms_pre=1000,
+            min_bbox_size=0,
+            score_thr=0.05,
+            nms=dict(type='nms', iou_threshold=0.6),
+            max_per_img=100))
+
+    architecture_cfg = dict(
+        type='MMDetArchitecture',
+        model=model_cfg,
+    )
+
+    architecture = ARCHITECTURES.build(architecture_cfg)
+    # ``StructurePruner`` requires pytorch>=1.6.0 to auto-trace GroupNorm
+    # correctly
+    min_required_version = '1.6.0'
+    if digit_version(torch.__version__) < digit_version(min_required_version):
+        with pytest.raises(AssertionError):
+            pruner.prepare_from_supernet(architecture)
+    else:
+        pruner.prepare_from_supernet(architecture)
+        subnet_dict = pruner.sample_subnet()
+        pruner.set_subnet(subnet_dict)
+        subnet_dict = pruner.export_subnet()
+        pruner.deploy_subnet(architecture, subnet_dict)
+        architecture.forward_dummy(imgs)
+
 
 def _test_reset_bn_running_stats(architecture_cfg, pruner_cfg, should_fail):
     import os
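For reference, the widened `isinstance` branch in `add_pruning_attrs` (structure_pruning.py above) now gives BatchNorm, InstanceNorm and GroupNorm layers the same per-channel `out_mask` buffer that `set_subnet` writes into. A minimal standalone sketch of just that buffer shape (illustrative only, not mmrazor API):

```python
import torch.nn as nn

gn = nn.GroupNorm(num_groups=2, num_channels=8)
# Same shape that add_pruning_attrs registers for normalization layers:
# one mask entry per channel, broadcastable over an (N, C, H, W) feature map.
gn.register_buffer('out_mask', gn.weight.new_ones((1, len(gn.weight), 1, 1)))
print(gn.out_mask.shape)  # torch.Size([1, 8, 1, 1])
```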