In [None]:
# default_exp sparse.pruner

In [None]:
#all_slow

In [None]:
#hide
from fastai.vision.all import *
from fastai.callback.all import *

from fasterai.sparse.all import *

import torch
import torch.nn as nn
import torch.nn.functional as F

%config InlineBackend.figure_format = 'retina'

> Important: The Pruner currently works only on fully feed-forward ConvNets, e.g. VGG16. Support for residual connections (e.g. ResNets) is under development.

When our network has filters made entirely of zero values, there is an additional step that we can take: those zero filters can be **physically** removed from the network, giving us a new, smaller, dense architecture.

This is done by re-expressing each layer with fewer filters, so that it keeps only its non-zero filters. However, removing a filter from a layer also removes the activation map it produces, and that map is an input to every filter of the next layer. So we must not only physically remove the filter, but also the corresponding kernel in each filter of the next layer (see the figure below).

![Pruning filters](imgs/pruning_filters.pdf "Pruning Filters")

Let's illustrate this with an example:

In [None]:
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

def label_func(f):
    "True when the first character of `f` is uppercase (`f` is presumably the file name given by `from_name_func`)."
    return f[0].isupper()

In [None]:
# Fixed seed so the random train/validation split is identical on every run of the notebook.
dls = ImageDataLoaders.from_name_func(path, files, label_func, seed=42, item_tfms=Resize(64))

In [None]:
#export
import torch
import torch.nn as nn
import copy
import numpy as np

In [None]:
#export
class Pruner():
    """Physically remove all-zero filters from a (sparsified) feed-forward ConvNet.

    Zeroing filters (e.g. with `SparsifyCallback(granularity='filter')`) leaves them in the
    model. `prune_model` returns a copy in which, for every `Conv2d`, the all-zero filters are
    deleted together with the matching input kernels of the next `Conv2d`, the matching
    features of the following `BatchNorm2d` layers, and the matching input columns of the
    first `nn.Linear`. The result is a smaller, dense network.

    Only plain feed-forward architectures (e.g. VGG) are supported: layers are paired purely
    by their order in `model.modules()`, so residual/branching connections are not handled.
    Grouped convolutions (`groups != 1`) are not supported either.
    """
    def __init__(self):
        super().__init__()

    def filters_to_keep(self, layer, nxt_layer):
        """Select the non-zero filters of `layer` and the matching input kernels of `nxt_layer`.

        Args:
            layer: the `Conv2d` being pruned.
            nxt_layer: the next `Conv2d` (whose input channels are `layer`'s outputs), or `None`.

        Returns:
            `(filters, biases, next_filters)`: the kept output filters of `layer`, its kept
            biases (`None` if `layer` has no bias), and the weights of `nxt_layer` restricted
            to the kept input channels (`None` if `nxt_layer` is `None`).
        """
        keep = self._get_nz_ixs(layer)[0]

        filters_keep = layer.weight.data.index_select(0, keep)  # keep only the non-zero filters
        # Convs built with `bias=False` have no bias to slice.
        biases_keep = layer.bias.data.index_select(0, keep) if layer.bias is not None else None
        # Drop the kernels of the next layer that read the removed activation maps.
        nxt_filters_keep = (nxt_layer.weight.data.index_select(1, keep.to(nxt_layer.weight.device))
                            if nxt_layer is not None else None)

        return filters_keep, biases_keep, nxt_filters_keep

    def prune_conv(self, layer, nxt_layer):
        """Shrink `layer` to its non-zero filters and `nxt_layer` to the matching inputs, in place.

        Args:
            layer: the `Conv2d` to prune.
            nxt_layer: the following `Conv2d`, or `None` if `layer` is the last conv.

        Returns:
            `(layer, nxt_layer)`: the same module objects, modified in place.
        """
        assert isinstance(layer, nn.Conv2d)

        new_weights, new_biases, new_next_weights = self.filters_to_keep(layer, nxt_layer)

        layer.out_channels = new_weights.shape[0]
        # weight.shape[1] equals the input channel count only because groups == 1 is assumed.
        layer.in_channels = new_weights.shape[1]
        layer.weight = nn.Parameter(new_weights)
        if new_biases is not None:
            layer.bias = nn.Parameter(new_biases)

        if new_next_weights is not None:
            nxt_layer.weight = nn.Parameter(new_next_weights)
            nxt_layer.in_channels = new_next_weights.shape[1]

        return layer, nxt_layer

    def prune_bn(self, layer, prev_conv):
        """Keep only the `BatchNorm2d` features fed by the non-zero filters of `prev_conv`.

        `prev_conv` is expected to be the *unpruned* conv preceding `layer`, so that its zero
        filters can still be located. Affine parameters and running statistics are sliced
        when they exist (`affine=False` / `track_running_stats=False` leave them as `None`).

        Returns:
            `layer`, modified in place.
        """
        keep = self._get_nz_ixs(prev_conv)[0]

        def _select(t): return t.data.index_select(0, keep.to(t.device))

        if layer.weight is not None:
            layer.weight = nn.Parameter(_select(layer.weight))
        if layer.bias is not None:
            layer.bias = nn.Parameter(_select(layer.bias))
        if layer.running_mean is not None:
            layer.running_mean = _select(layer.running_mean)
        if layer.running_var is not None:
            layer.running_var = _select(layer.running_var)
        layer.num_features = len(keep)

        return layer

    def delete_fc_weights(self, layer, last_conv, pool_shape):
        """Drop the input columns of `layer` (an `nn.Linear`) fed by removed filters of `last_conv`.

        Assumes the conv features are pooled to a fixed spatial size and flattened
        channel-major, so channel `c` feeds inputs `c*area` ... `(c+1)*area - 1`.

        Args:
            layer: the first `nn.Linear` after the conv stack.
            last_conv: the *unpruned* last `Conv2d`, used to locate its zero filters.
            pool_shape: pooled spatial size. Either an int `s` (an `s x s` map), an `(h, w)`
                pair, or `None`/0 when each channel feeds exactly one input.

        Returns:
            `layer`, modified in place.
        """
        keep = self._get_nz_ixs(last_conv)[0].to(layer.weight.device)
        area = self._pool_area(pool_shape)
        if area:
            # Expand each kept channel into its `area` consecutive flattened positions.
            offsets = torch.arange(area, device=keep.device)
            keep = (keep[:, None] * area + offsets).flatten()

        weights_keep = layer.weight.data.index_select(1, keep)

        layer.in_features = weights_keep.shape[1]
        layer.weight = nn.Parameter(weights_keep)

        return layer

    @staticmethod
    def _pool_area(pool_shape):
        "Spatial positions per channel for `pool_shape` (int side, `(h, w)` pair, or None/0 -> None)."
        if not pool_shape: return None
        if isinstance(pool_shape, int): return pool_shape ** 2
        h, w = pool_shape
        return h * w

    def _get_nz_ixs(self, layer):
        """Indices of the filters (dim 0 of `layer.weight`) that hold at least one non-zero weight.

        Returns a `(1, k)` long tensor on the same device as `layer.weight`.
        """
        # Sum of absolute values: a plain sum could cancel to 0 for a filter with mixed signs.
        magnitudes = layer.weight.data.abs().flatten(1).sum(dim=1)
        return torch.nonzero(magnitudes).T

    def _find_next_conv(self, model, conv_ix):
        "Index in `model.modules()` of the first `Conv2d` after position `conv_ix`, or None."
        return next((k for k, m in enumerate(model.modules())
                     if k > conv_ix and isinstance(m, nn.Conv2d)), None)

    def _find_previous_conv(self, model, layer_ix):
        "Index in `model.modules()` of the last `Conv2d` before position `layer_ix`, or None."
        return next((k for k, m in reversed(list(enumerate(model.modules())))
                     if k < layer_ix and isinstance(m, nn.Conv2d)), None)

    def _get_last_conv_ix(self, model):
        "Index in `model.modules()` of the last `Conv2d`, or None if the model has none."
        conv_ixs = [k for k, m in enumerate(model.modules()) if isinstance(m, nn.Conv2d)]
        return conv_ixs[-1] if conv_ixs else None

    def _get_first_fc_ix(self, model):
        "Index in `model.modules()` of the first `nn.Linear`, or None if the model has none."
        return next((k for k, m in enumerate(model.modules()) if isinstance(m, nn.Linear)), None)

    def _find_pool_shape(self, model):
        "`output_size` of the first `nn.AdaptiveAvgPool2d` in `model`, or None if there is none."
        return next((m.output_size for m in model.modules()
                     if isinstance(m, nn.AdaptiveAvgPool2d)), None)

    def prune_model(self, model):
        """Return a pruned deep copy of `model`; `model` itself is left untouched.

        Every `Conv2d` loses its all-zero filters (and the next `Conv2d` the matching input
        kernels). Every `BatchNorm2d` loses the features of the removed filters of the conv
        preceding it. The first `nn.Linear` loses the inputs fed by removed filters of the
        last conv. For BatchNorm/Linear layers the zero filters are located on the original,
        unpruned `model`. If the model has no `AdaptiveAvgPool2d`, each channel of the last
        conv is assumed to feed a single input of the first `nn.Linear`.
        """
        pruned_model = copy.deepcopy(model)

        # Both dicts share the same names and iteration order as `model.modules()`.
        layers = dict(pruned_model.named_modules())
        layer_names = list(layers.keys())
        old_layers = dict(model.named_modules())

        last_conv_ix = self._get_last_conv_ix(pruned_model)
        first_fc_ix = self._get_first_fc_ix(pruned_model)

        for k, m in enumerate(list(pruned_model.modules())):

            if isinstance(m, nn.Conv2d):
                next_conv_ix = self._find_next_conv(model, k)
                # The last conv has no successor whose inputs need shrinking.
                nxt = layers[layer_names[next_conv_ix]] if next_conv_ix is not None else None
                self.prune_conv(m, nxt)

            elif isinstance(m, nn.BatchNorm2d):
                prev_conv_ix = self._find_previous_conv(model, k)
                if prev_conv_ix is not None:
                    self.prune_bn(m, old_layers[layer_names[prev_conv_ix]])

            elif isinstance(m, nn.Linear) and k == first_fc_ix and last_conv_ix is not None:
                pool_shape = self._find_pool_shape(model)
                self.delete_fc_weights(m, old_layers[layer_names[last_conv_ix]], pool_shape)

        return pruned_model

In [None]:
# VGG16 (with batch norm) whose classifier outputs 2 classes.
learn = Learner(dls, model=vgg16_bn(num_classes=2), metrics=accuracy)

In [None]:
#hide
def count_parameters(model):
    "Total number of scalar values stored in the parameters of `model`."
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

In [None]:
count_parameters(model=learn.model)

134277186

Our initial model, a VGG16, has more than 134 million parameters. Let's see what happens when we make it sparse at the filter level:

In [None]:
# Progressively zero out whole filters (local criterion) until 50% sparsity is reached.
sp_cb = SparsifyCallback(
    end_sparsity=50, granularity='filter', method='local',
    criteria=large_final, sched_func=sched_onecycle,
)

In [None]:
learn.fit_one_cycle(n_epoch=3, lr_max=1e-3, cbs=sp_cb)

Pruning of filter until a sparsity of 50%


epoch,train_loss,valid_loss,accuracy,time
0,0.630026,0.764953,0.685386,00:13
1,0.609164,0.562719,0.719892,00:14
2,0.541365,0.497277,0.746279,00:13


Saving Weights at epoch 0
Sparsity at the end of epoch 0: 10.43%
Sparsity at the end of epoch 1: 48.29%
Sparsity at the end of epoch 2: 50.00%
Final Sparsity: 50.00


In [None]:
count_parameters(model=learn.model)

134277186

The total number of parameters hasn't changed! This is because we only replaced the pruned values with zeros: the model is sparse, but those zeros are still stored as parameters.

The `Pruner` will take care of removing those useless filters.

In [None]:
pruner = Pruner()
# `prune_model` works on a deep copy, so `learn.model` itself is left unchanged.
pruned_model = pruner.prune_model(model=learn.model)

Done! Let's check whether the performance is still the same:

In [None]:
# Put the pruned model on the same device as the data instead of hard-coding CUDA,
# so the notebook also runs on CPU-only machines.
pruned_learn = Learner(dls, pruned_model.to(dls.device), metrics=accuracy)

In [None]:
pruned_learn.validate(ds_idx=1)

(#2) [0.4975821375846863,0.7435724139213562]

In [None]:
count_parameters(model=pruned_learn.model)

71858210

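Now we have about 72 million parameters (71,858,210), roughly 53% of the original 134 million. Removing half of the filters, together with the weights that consumed their outputs, nearly halved the size of the model!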