In [None]:
#| default_exp prune.pruner

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from torch_pruning.pruner import function

import numpy as np
import torch
import torch.nn as nn
import pickle
from itertools import cycle
from fastcore.basics import store_attr, listify, true
from fasterai.core.criteria import *
from fastai.vision.all import *

  warn(f"Failed to load image Python extension: {e}")


In [None]:
#| include: false
from nbdev.showdoc import *

In [None]:
#| export
class Pruner():
    """Structured pruner that removes whole channels using a `torch_pruning` dependency graph.

    Parameters
    ----------
    model : the network to prune; it is modified in place by `prune_model`.
    context : `'global'` ranks every channel of every group against a single
        network-wide threshold; any other value (e.g. `'local'`) prunes each
        group independently with the same sparsity.
    criteria : callable invoked as `criteria(module, granularity)` that returns
        a tensor of importance scores.
    layer_type : module type used as the root of the pruning groups.
    example_inputs : tensor used to trace the dependency graph; it is moved to
        the device of the model's first parameter before tracing.
    """
    def __init__(self, model, context, criteria, layer_type=nn.Conv2d, example_inputs=torch.randn(1,3,224,224)):
        store_attr()
        self.DG = tp.DependencyGraph()
        self.DG.build_dependency(self.model, example_inputs=example_inputs.to(next(model.parameters()).device))
        self._save_init_state()
        self._reset_threshold()
        # Total number of scored channels seen by the first global pass (set lazily).
        self.init_num_groups = None

    def compute_threshold(self, sparsity):
        """Compute and store the network-wide importance threshold for `sparsity` (in percent).

        Fills `self.global_importance` (group index -> importance tensor) and
        sets `self.global_threshold` to the smallest score among the channels
        that are kept.
        """
        self.global_importance = {
            ix: self.group_importance(grp)
            for ix, grp in enumerate(self.DG.get_all_groups(root_module_types=[self.layer_type]))
        }
        global_imp = torch.cat(list(self.global_importance.values()), dim=0)

        # The kept fraction is always relative to the channel count of the first call.
        self.init_num_groups = self.init_num_groups or len(global_imp)
        # Plain-int clamp to [1, len(global_imp)] so `topk` always gets a valid k.
        n_keep = min(max(int((1-sparsity/100)*self.init_num_groups), 1), len(global_imp))
        self.global_threshold = torch.topk(global_imp, n_keep)[0].min()

    def prune_group(self, group, ix, sparsity, round_to):
        "Select the channels to remove from `group` and prune them through the dependency graph."
        module = group[0][0].target.module
        pruning_fn = group[0][0].handler
        pruning_idxs = self.prune_method(group, ix, sparsity, round_to)
        # Rebuild the group restricted to the selected indices before pruning.
        group = self.DG.get_pruning_group(module, pruning_fn, pruning_idxs.tolist())
        group.prune()

    def prune_model(self, sparsity, round_to=None):
        """Prune every group rooted at a `layer_type` module.

        `sparsity` is a percentage (it is divided by 100). When `round_to` is
        truthy, the number of channels kept per group is rounded down to a
        multiple of `round_to` (and never below `round_to`).
        """
        if self.context=='global': self.compute_threshold(sparsity)

        for ix, group in enumerate(self.DG.get_all_groups(root_module_types=[self.layer_type])):
            self.prune_group(group, ix, sparsity, round_to)

    def prune_method(self, group, ix, sparsity, round_to):
        """Return a 1-D tensor with the indices of the channels to remove from `group`.

        Channels whose importance is strictly below the score of the last kept
        channel are selected, so ties with that score are kept.
        """
        if self.context=='global':
            imp = self.global_importance[ix]
            n_keep = max(1, int(imp.ge(self.global_threshold).sum()))
        else:
            imp = self.group_importance(group)
            n_keep = max(1, int((1-sparsity/100)*group[0].dep.target.module._init_out_channels))

        if round_to: n_keep = self._rounded_sparsity(n_keep, round_to)
        # Never ask `topk` for more channels than the group currently has
        # (e.g. lower sparsity than a previous pass, or `round_to` > channel count).
        n_keep = min(int(n_keep), len(imp))
        threshold = torch.topk(imp, n_keep)[0].min()
        return imp.lt(threshold).nonzero().view(-1)

    def updated_sparsity(self, m, sparsity):
        "Return the sparsity to apply to module `m`; currently `sparsity` is passed through unchanged."
        return sparsity

    def _save_init_state(self):
        "Record on every module that has a `weight` its output channel count as reported by the dependency graph."
        for m in self.model.modules():
            if hasattr(m, 'weight'):
                m._init_out_channels = self.DG.get_out_channels(m)

    def _rounded_sparsity(self, n_to_prune, round_to):
        "Round `n_to_prune` (int or 0-dim tensor) down to a multiple of `round_to`, with `round_to` as the minimum."
        return max(round_to*(int(n_to_prune)//round_to), round_to)

    def _reset_threshold(self):
        "Forget any previously computed global threshold."
        self.global_threshold=None

    def group_importance(self, group):
        """Average the per-channel importance of the layers in `group`.

        Only dependencies whose handler appears in `handler_map` contribute.
        Each entry maps a pruning handler to the granularity passed to
        `self.criteria` and the dimensions to squeeze from the returned scores
        (`None` means "flatten to 1-D").
        """
        handler_map = {
            function.prune_conv_out_channels: ('filter', (1, 2, 3)),
            function.prune_linear_out_channels: ('row', None),
            function.prune_conv_in_channels: ('shared_kernel', (0, 2, 3)),
            # This key used to repeat `prune_linear_out_channels`, which silently
            # replaced the 'row' entry above and left linear in-channels unscored.
            function.prune_linear_in_channels: ('column', None),
        }

        scores = []
        for dep, _ in group:
            if dep.handler not in handler_map: continue
            granularity, squeeze_dims = handler_map[dep.handler]
            imp = self.criteria(dep.target.module, granularity)
            imp = imp.flatten() if squeeze_dims is None else imp.squeeze(squeeze_dims)
            # Layers whose score vector does not line up one-to-one with the first
            # scored layer (e.g. a Linear fed through a flatten) cannot be averaged
            # channel-wise, so they are left out instead of breaking `torch.stack`.
            if scores and imp.shape != scores[0].shape: continue
            scores.append(imp)

        if not scores:
            raise RuntimeError("group_importance: no layer in the group has a supported pruning handler")
        return torch.stack(scores).mean(0)

In [None]:
# Display the generated documentation (signature) for `Pruner.prune_model`.
show_doc(Pruner.prune_model)

---

### Pruner.prune_model

>      Pruner.prune_model (sparsity, round_to=None)

Let's try the `Pruner` with a VGG16 model

In [None]:
# Build a VGG16 model with batch normalisation and display its architecture.
model = vgg16_bn()
model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

The `Pruner` can either remove filters based on `local` criteria (i.e. each layer will be trimmed of the same % of filters)

In [None]:
# 'local' context: every Conv2d group is pruned independently with the same
# sparsity (50 is a percentage). `prune_model` modifies `model` in place.
pruner = Pruner(model, 'local', large_final)
pruner.prune_model(50)
model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(64, 128, kern

The `Pruner` can also remove filters based on `global` criteria (i.e. each layer will be trimmed of a different % of filters, but we specify the sparsity of the whole network)

In [None]:
# Start again from an unpruned network: `prune_model` modifies the model in place,
# so reusing `model` here would prune the network already halved by the previous example.
model = vgg16_bn()
# 'global' context: channels of all groups are ranked against one network-wide threshold.
pruner = Pruner(model, 'global', large_final)
pruner.prune_model(50)
model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(64, 128, kern