In [None]:
#| default_exp sparse.pruner

In [None]:
#| include: false
from nbdev.showdoc import *
%config InlineBackend.figure_format = 'retina'

:::{.callout-important}

The `Pruner` currently only works on fully feed-forward ConvNets, e.g. VGG16. Support for residual connections, e.g. ResNets, is under development.

:::

When our network has filters containing only zero values, there is an additional step that we may take: those zero filters can be **physically** removed from the network, giving us a new, smaller, dense architecture.

This can be done by re-expressing each layer with fewer filters, so that it only keeps its non-zero filters. However, when we remove a filter from a layer, the activation map it produced disappears, and that map was an input to every filter of the next layer. So, not only should we physically remove the filter, but also its corresponding kernel in each of the filters of the next layer (see Fig. below). For the same reason, we must also remove the matching channels of a BatchNorm layer that follows the convolution. Likewise, the matching input columns of the first fully-connected layer after the last convolution must go.

![](imgs/pruning_filters.png "Pruning Filters")

In [None]:
#| export
import torch
import torch.nn as nn
import copy
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export
class Pruner():
    "Remove zero filters from a model"
    # NOTE(review): a zero filter still outputs its bias (plus whatever a following BatchNorm adds), so removing
    # it preserves the network's output exactly only when those terms are zero too.
    # Only plain, fully feed-forward ConvNets with `groups=1` convolutions are supported.
    def __init__(self):
        super().__init__()

    def filters_to_keep(self, layer, nxt_layer, ixs=None):
        "Select the non-zero filters of `layer` and the matching input kernels of `nxt_layer`"
        # `ixs` optionally overrides which filters are kept. It can be the `(1, n)` output of `_get_nz_ixs`
        # or a flat index tensor. When omitted, the non-zero filters of `layer` itself are used.
        keep = (self._get_nz_ixs(layer) if ixs is None else ixs).reshape(-1)

        filters_keep = self._select(layer.weight, 0, keep)  # keep only the non-zero filters
        biases_keep = self._select(layer.bias, 0, keep) if layer.bias is not None else None  # conv may have bias=False
        # The next conv reads one input channel per filter of `layer`, so drop the same indices along dim 1
        nxt_filters_keep = self._select(nxt_layer.weight, 1, keep) if nxt_layer is not None else None

        return filters_keep, biases_keep, nxt_filters_keep

    def prune_conv(self, layer, nxt_layer, ixs=None):
        "Physically remove the zero filters of `layer` (and the matching kernels of `nxt_layer`), in place"
        # Returns `(layer, nxt_layer)`. `ixs` optionally overrides the filters to keep (see `filters_to_keep`).
        assert isinstance(layer, nn.Conv2d)

        new_weights, new_biases, new_next_weights = self.filters_to_keep(layer, nxt_layer, ixs)

        layer.out_channels = new_weights.shape[0]
        layer.in_channels = new_weights.shape[1]
        layer.weight = nn.Parameter(new_weights)
        if new_biases is not None:
            layer.bias = nn.Parameter(new_biases)

        if new_next_weights is not None:
            nxt_layer.weight = nn.Parameter(new_next_weights)
            nxt_layer.in_channels = new_next_weights.shape[1]

        return layer, nxt_layer

    def prune_bn(self, layer, prev_conv):
        "Keep only the BatchNorm channels that follow a non-zero filter of `prev_conv`, in place"
        keep = self._get_nz_ixs(prev_conv).reshape(-1)

        layer.num_features = keep.numel()
        # `affine=False` / `track_running_stats=False` layers have no weight/bias or running stats to slice
        if layer.weight is not None:
            layer.weight = nn.Parameter(self._select(layer.weight, 0, keep))
        if layer.bias is not None:
            layer.bias = nn.Parameter(self._select(layer.bias, 0, keep))
        if layer.running_mean is not None:
            layer.running_mean = self._select(layer.running_mean, 0, keep)
        if layer.running_var is not None:
            layer.running_var = self._select(layer.running_var, 0, keep)

        return layer

    def delete_fc_weights(self, layer, last_conv, pool_shape):
        "Remove the input columns of the Linear `layer` that are fed by zero filters of `last_conv`, in place"
        # `pool_shape` is the spatial size of each flattened channel: an int (square), an `(h, w)` tuple, or
        # `None`. With `None`, the size is inferred from the Linear fan-in, which must not be pruned yet.
        keep = self._get_nz_ixs(last_conv).reshape(-1)
        n_spatial = self._n_spatial_positions(pool_shape, layer, last_conv)

        # Flattened conv features are channel-major: channel `c` owns columns [c*n_spatial, (c+1)*n_spatial)
        offsets = torch.arange(n_spatial, device=keep.device)
        cols = (keep.unsqueeze(1) * n_spatial + offsets).reshape(-1)

        weights_keep = self._select(layer.weight, 1, cols)
        layer.in_features = weights_keep.shape[1]
        layer.weight = nn.Parameter(weights_keep)

        return layer

    def _n_spatial_positions(self, pool_shape, layer, last_conv):
        "Number of flattened spatial positions per channel feeding `layer`"
        if isinstance(pool_shape, int) and pool_shape:
            return pool_shape ** 2
        if isinstance(pool_shape, (tuple, list)) and all(s is not None for s in pool_shape):
            return int(np.prod(pool_shape))
        # Unknown pool output: infer it from the (still unpruned) Linear fan-in
        return max(layer.in_features // last_conv.weight.shape[0], 1)

    def _select(self, tensor, dim, ixs):
        "Detached slice of `tensor` along `dim`; indices are moved to the tensor's device first"
        return tensor.data.index_select(dim, ixs.to(tensor.device))

    def _get_nz_ixs(self, layer):
        "Indices of the non-zero output filters of `layer`, as a `(1, n)` tensor on the weight's device"
        w = layer.weight.data
        # L1 norm rather than a plain sum: a filter with mixed-sign weights can sum to 0 without being zero
        norms = w.abs().sum(dim=tuple(range(1, w.dim())))
        return torch.nonzero(norms).T

    def _find_next_conv(self, model, conv_ix):
        "Index (in `model.modules()` order) of the first `nn.Conv2d` after `conv_ix`, or `None`"
        return next((k for k, m in enumerate(model.modules())
                     if k > conv_ix and isinstance(m, nn.Conv2d)), None)

    def _find_previous_conv(self, model, layer_ix):
        "Index (in `model.modules()` order) of the last `nn.Conv2d` before `layer_ix`, or `None`"
        return next((k for k, m in reversed(list(enumerate(model.modules())))
                     if k < layer_ix and isinstance(m, nn.Conv2d)), None)

    def _get_last_conv_ix(self, model):
        "Index of the last `nn.Conv2d` in `model.modules()`, or `None` if there is none"
        conv_ixs = [k for k, m in enumerate(model.modules()) if isinstance(m, nn.Conv2d)]
        return conv_ixs[-1] if conv_ixs else None

    def _get_first_fc_ix(self, model):
        "Index of the first `nn.Linear` in `model.modules()`, or `None` if there is none"
        return next((k for k, m in enumerate(model.modules()) if isinstance(m, nn.Linear)), None)

    def _find_pool_shape(self, model):
        "`output_size` of the first `nn.AdaptiveAvgPool2d` in `model` (int or tuple), or `None`"
        return next((m.output_size for m in model.modules() if isinstance(m, nn.AdaptiveAvgPool2d)), None)

    def prune_model(self, model):
        "Return a copy of `model` where the zero filters of every `nn.Conv2d` are physically removed"
        pruned_model = copy.deepcopy(model)

        # `deepcopy` preserves the module tree, so the k-th entry of both lists is the same layer
        old_modules = list(model.modules())
        pruned_modules = list(pruned_model.modules())

        last_conv_ix = self._get_last_conv_ix(model)
        first_fc_ix = self._get_first_fc_ix(model)

        for k, m in enumerate(pruned_modules):
            if isinstance(m, nn.Conv2d):
                # Decide from the ORIGINAL weights, like the BN/FC pruning below. The pruned conv has already
                # lost input channels, which could otherwise make the three disagree on the filters to keep.
                keep = self._get_nz_ixs(old_modules[k])
                next_conv_ix = self._find_next_conv(model, k)
                nxt = pruned_modules[next_conv_ix] if next_conv_ix is not None else None
                self.prune_conv(m, nxt, ixs=keep)

            elif isinstance(m, nn.BatchNorm2d):
                prev_conv_ix = self._find_previous_conv(model, k)
                if prev_conv_ix is not None:
                    self.prune_bn(m, old_modules[prev_conv_ix])

            elif isinstance(m, nn.Linear) and k == first_fc_ix and last_conv_ix is not None:
                pool_shape = self._find_pool_shape(model)
                self.delete_fc_weights(m, old_modules[last_conv_ix], pool_shape)

        return pruned_model

In [None]:
show_doc(Pruner.prune_model)

---

### Pruner.prune_model

>      Pruner.prune_model (model)