# Customize Your Own Pruners

Torch-Pruning is a flexible tool that lets you build your own pruners with customized importance criteria and pruning schemes. For instance, you can use it to implement the [Slimming pruner](https://arxiv.org/abs/1708.06519), which uses the scaling parameters of Batch Normalization (BN) layers to identify and remove unimportant channels.

```python
import warnings
warnings.filterwarnings('ignore')
import sys, os
sys.path.append(os.path.abspath("../"))

import torch
import torch.nn as nn
from torchvision.models import resnet18
import torch_pruning as tp
```

### 1. Pruner Definition

The Slimming pruner uses the scaling factors of Batch Normalization (BN) layers as channel importance scores. It follows a "training, pruning, fine-tuning" pipeline whose first stage is sparse training of the original model, i.e., training with an L1 penalty on the BN scaling factors. In Torch-Pruning, the base class ``tp.pruner.MetaPruner`` provides a ``.regularize(model)`` method as the hook for sparse training, so our first task is to implement this interface to regularize the BN parameters.
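For reference, the sparse-training objective in the Slimming paper adds an L1 penalty over the set $\Gamma$ of all BN scaling factors $\gamma$ to the usual training loss. Because $|\gamma|$ is not differentiable at 0, the paper optimizes it with subgradient descent, so a ``regularize`` implementation only has to add $\lambda \, \mathrm{sign}(\gamma)$ to each scaling factor's gradient ($\lambda$ plays the role of the ``reg`` argument used in the training loop):

$$
L = \sum_{(x, y)} l\big(f(x, W), y\big) + \lambda \sum_{\gamma \in \Gamma} |\gamma|,
\qquad
\frac{\partial \, \lambda |\gamma|}{\partial \gamma} = \lambda \, \mathrm{sign}(\gamma)
$$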

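```python
class MySimplePruner(tp.pruner.MetaPruner):
    def regularize(self, model, reg=1e-5):
        # Add the L1 subgradient reg * sign(gamma) to every affine BN scale; call after loss.backward()
        for m in model.modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) and m.affine and m.weight.grad is not None:
                m.weight.grad.data.add_(reg * torch.sign(m.weight.data))
```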

### 2. Importance function
Next, we need a new importance criterion for slimming that compares the magnitudes of the BN scaling parameters. In Torch-Pruning, an importance criterion is a callable function or object that accepts a group (``tp.PruningGroup``) as input. A ``tp.PruningGroup`` records all coupled layers together with their pruning indices, so we can iterate over the group to design our own importance function:

```python
class MySimplePrunerImportance(tp.importance.Importance):
    def __call__(self, group, **kwargs):
        # A group may contain several coupled BN layers, so we collect
        # layer-wise scores in a list and then reduce them to a single result.
        group_imp = []  # (num_bns, num_channels)
        # 1. Iterate over the group to estimate importance.
        for dep, idxs in group:
            layer = dep.target.module  # the target module of this dependency
            prune_fn = dep.handler     # its pruning function (unused in this example)
            if isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)) and layer.affine:
                local_imp = torch.abs(layer.weight.data)
                group_imp.append(local_imp)
        if len(group_imp) == 0:
            return None  # the group contains no BN layer, so it gets no score
        # 2. Reduce the group importance to a 1-D score vector; here, the average across layers.
        group_imp = torch.stack(group_imp, dim=0).mean(dim=0)
        return group_imp  # (num_channels,)

# Any importance function works, as long as it maps a group to a 1-D score vector.
# This one simply assigns random scores, which makes a useful baseline.
class MinimumChannelImportance(tp.importance.Importance):
    @torch.no_grad()
    def __call__(self, group, **kwargs):
        _, idxs = group[0]  # pruning indices of the first dependency in the group
        return torch.rand(len(idxs))
```
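The same mean reduction can be tried outside Torch-Pruning on plain tensors. In this sketch, ``group_importance`` and ``lowest_k`` are hypothetical helpers (not library functions), and the two weight vectors are made-up stand-ins for BN layers coupled by a residual connection. It averages ``|γ|`` per channel and returns the lowest-scoring channels, which we assume are the ones the pruner would remove.

```python
import torch

def group_importance(bn_weights):
    """Average |gamma| across the coupled BN layers of one group (one score per channel)."""
    if len(bn_weights) == 0:
        return None  # mirrors the notebook: a group without BN gets no score
    scores = torch.stack([w.abs() for w in bn_weights], dim=0)  # (num_bns, num_channels)
    return scores.mean(dim=0)                                   # (num_channels,)

def lowest_k(scores, ratio):
    """Indices of the int(ratio * C) least important channels, lowest score first."""
    k = int(scores.numel() * ratio)
    return torch.argsort(scores)[:k].tolist()

# Two BN layers joined by a residual add must drop the same channels,
# so they are scored jointly rather than one layer at a time.
bn_a = torch.tensor([0.9, -0.05, 0.4, 0.01])
bn_b = torch.tensor([0.7, 0.10, -0.2, 0.03])
imp = group_importance([bn_a, bn_b])
print(imp)                  # per-channel mean |gamma|
print(lowest_k(imp, 0.5))   # channels 3 and 1 have the smallest scaling factors
```

Averaging is only one choice; a max or sum over the stacked scores would also produce a valid 1-D vector.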

### 3. Pruning
Now let's use the customized pruner to slim a ResNet-18 model.

```python
model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)
```

```python
# 0. Importance criterion: the BN-based slimming importance defined above
#    (MinimumChannelImportance() would give a random baseline instead).
imp = MySimplePrunerImportance()

# 1. Ignore layers that should not be pruned, e.g., the final classifier layer.
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m)  # DO NOT prune the final classifier!

# 2. Pruner initialization
iterative_steps = 5  # prune the model to the target sparsity over several steps
pruner = MySimplePruner(
    model,
    example_inputs,
    global_pruning=False,  # if False, the same sparsity is assigned to every layer
    importance=imp,  # importance criterion for channel selection
    iterative_steps=iterative_steps,  # number of steps to reach the target sparsity
    ch_sparsity=0.5,  # remove 50% of channels: ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)
```
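With ``iterative_steps=5`` and ``ch_sparsity=0.5``, the 50% target is reached gradually rather than in a single step. The sketch below estimates how many channels each ResNet-18 stage keeps after every ``pruner.step()``. It assumes a linear schedule (10%, 20%, ..., 50%) and that each step keeps ``int(C * (1 - s))`` of a layer's original ``C`` channels; ``channels_per_step`` is a hypothetical helper, and the exact rounding may differ between Torch-Pruning versions.

```python
def channels_per_step(init_channels, ch_sparsity=0.5, iterative_steps=5):
    """Channels kept after each step, ASSUMING a linear sparsity schedule
    and truncation of init_channels * (1 - sparsity) to an integer."""
    kept = []
    for i in range(1, iterative_steps + 1):
        target = ch_sparsity * i / iterative_steps  # linear ramp: 0.1, 0.2, ..., 0.5
        kept.append(int(init_channels * (1 - target)))
    return kept

for c in (64, 128, 256, 512):  # ResNet-18 stage widths
    print(c, channels_per_step(c))
```

Under these assumptions, the final step lands exactly on the halved widths listed in the ``ch_sparsity`` comment.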

Sparse training with ``pruner.regularize``. Remember to call it after ``loss.backward()`` and before ``optimizer.step()``.

```python
# Training
for _ in range(100):
    pass
    # optimizer.zero_grad()
    # ...
    # loss.backward()
    # pruner.regularize(model, reg=1e-5)
    # optimizer.step()
```
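The commented-out loop can be fleshed out as follows. This is a self-contained sketch on a hypothetical toy model with random data, so it runs without Torch-Pruning; in the notebook, the inlined BN update is the job of ``pruner.regularize(model, reg=1e-5)``. The point it illustrates is the ordering: the L1 subgradient is added after ``loss.backward()`` has populated the gradients and before ``optimizer.step()`` consumes them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy setup standing in for ResNet-18 and a real dataset.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
reg = 1e-5  # strength of the L1 penalty on BN scaling factors

losses = []
for _ in range(20):
    x = torch.randn(16, 3, 32, 32)      # synthetic images
    y = torch.randint(0, 10, (16,))      # synthetic labels
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    # Sparse-training step: add reg * sign(gamma) to every BN scale gradient,
    # AFTER backward() (so the gradients exist) and BEFORE step() (so it takes effect).
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d) and m.affine:
            m.weight.grad.add_(reg * torch.sign(m.weight.detach()))
    optimizer.step()
    losses.append(loss.item())
```

Calling the regularizer after ``optimizer.step()`` instead would silently do nothing for that iteration, because the next ``zero_grad()`` discards the modified gradients.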

Pruning and fine-tuning: prune the model step by step, and fine-tune it after each step.

```python
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()

    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(model)
    print(model(example_inputs).shape)
    print(
        "  Iter %d/%d, Params: %.2f M => %.2f M"
        % (i+1, iterative_steps, base_nparams / 1e6, nparams / 1e6)
    )
    print(
        "  Iter %d/%d, MACs: %.2f G => %.2f G"
        % (i+1, iterative_steps, base_macs / 1e9, macs / 1e9)
    )
    print("="*16)
    # finetune your model here
    # finetune(model)
    # ...
```
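To see why removing 50% of the channels removes far more than 50% of the parameters, consider a single 3×3 convolution inside a residual stage: its weight has C_in × C_out × 3 × 3 entries, and both C_in and C_out are halved, so the count (and, at a fixed spatial size, the MACs) drops by about 4×. The sketch below checks this in plain PyTorch; ``n_params`` is a hypothetical helper. Layers whose inputs or outputs are not pruned, such as the 3-channel stem convolution and the ignored 1000-way classifier, shrink only about 2×, so the whole-model reduction reported by the loop should be somewhat smaller than 4×.

```python
import torch.nn as nn

def n_params(m: nn.Module) -> int:
    """Total number of learnable parameters in a module."""
    return sum(p.numel() for p in m.parameters())

full = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)  # e.g., a layer1 conv in ResNet-18
half = nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False)  # the same conv after 50% channel pruning
print(n_params(full), n_params(half), n_params(half) / n_params(full))
```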
