# Example of filter pruning with fastai

## Import the libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from fasterai.Pruner import *
from fasterai.Sparsifier import *

## Get the data

In [3]:
# Fetch and extract the MNIST image folders with fastai's `untar_data`;
# `path` is the extracted dataset directory used by the DataBunch below.
path = untar_data(url=URLs.MNIST)

In [4]:
# Build the DataBunch with fastai's data block API:
# read images from the folders under `path`, use the 'training' / 'testing'
# sub-folders as the train / validation split, label each image from its
# folder, then create the DataBunch and normalize it.
data = (
    ImageList
    .from_folder(path)
    .split_by_folder(train='training', valid='testing')
    .label_from_folder()
    .databunch()
    .normalize()
)

In [5]:
# Batch size and number of training epochs.
# NOTE(review): `bs` is not passed to `.databunch()` in the cell above, so the
# DataBunch uses its own default batch size — confirm the two agree.
bs = 64
epochs = 3

## Create the CNN

In [6]:
class Net(nn.Module):
    """Small CNN used as the pruning target in this example.

    Architecture (every conv uses a 5x5 kernel with stride 1):
        conv1 (in_channels -> 6) -> ReLU -> 2x2 max-pool
        -> conv2 (6 -> 8) -> ReLU -> conv3 (8 -> 12) -> ReLU
        -> global average pool -> fc1 (12 -> 6) -> ReLU -> fc2 (6 -> n_classes)

    Args:
        mnist: Unused; kept only so existing ``Net(mnist=...)`` calls keep working.
        in_channels: Number of channels of the input images. Defaults to 3,
            the value this notebook was originally written for.
        n_classes: Number of output scores. Defaults to 10.
    """

    def __init__(self, mnist=True, in_channels=3, n_classes=10):
        super().__init__()
        # `mnist` is intentionally ignored (it never affected the architecture).
        self.conv1 = nn.Conv2d(in_channels, 6, 5, 1)
        self.conv2 = nn.Conv2d(6, 8, 5, 1)
        self.conv3 = nn.Conv2d(8, 12, 5, 1)
        # Global average pooling collapses whatever spatial extent remains to 1x1,
        # so fc1 always receives exactly 12 features (one per conv3 filter).
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(12, 6)
        self.fc2 = nn.Linear(6, n_classes)

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, n_classes).

        Assumes ``x`` is a (batch, in_channels, H, W) image batch with H and W
        large enough to survive three 5x5 convs and one 2x2 max-pool
        (28x28 MNIST images work).
        """
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.pool(x)
        # (batch, 12, 1, 1) -> (batch, 12)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [7]:
# No explicit `.cuda()`: fastai's Learner moves the model to the DataBunch's
# device itself, so the same cell also runs on CPU-only machines.
learn = Learner(data, Net(), metrics=[accuracy])

## Define the pruning schedule

In [8]:
# Pruning type supported: "weights" and "filters".
# `ending_step` is expressed in training batches (as the original
# `epochs * ceil(len(train_ds) / bs)` formula implied). Taking the length of the
# training DataLoader uses the DataBunch's real batch count instead of
# recomputing it from `bs`, which is never passed to `.databunch()`. It also
# reflects any dropped last partial batch and yields an int rather than a float.
# "final_sparsity" / "initial_sparsity" are percentages (cf. "until a sparsity of 50%").
# NOTE(review): the exact meaning of "span" is defined by SparsifyCallback — not visible here.
prune_meta = {
        "pruning_type": "filters",
        "starting_step" : 0,
        "current_step": 0,
        "ending_step": epochs * len(data.train_dl),
        "final_sparsity": 50,
        "initial_sparsity": 0,
        "span": 100
    }

In [9]:
# Train for `epochs` epochs at a learning rate of 1e-3 while SparsifyCallback
# progressively zeroes filters following the `prune_meta` schedule.
learn.fit(
    epochs,
    lr=1e-3,
    callbacks=SparsifyCallback(learn, meta=prune_meta),
)

Pruning of filters until a sparsity of 50%


epoch,train_loss,valid_loss,accuracy,time
0,1.090852,1.024179,0.7106,00:14
1,0.339771,0.315391,0.9023,00:14
2,0.256382,0.258534,0.9178,00:16


Sparsity: 38.58%
Sparsity: 49.45%
Sparsity: 50.00%


By printing the weights of a convolutional layer, we can see that some filters are entirely zero.

In [10]:
# Full weight tensor of conv1: 6 filters x 3 input channels x 5 x 5 kernels.
# Filters removed by sparsification appear as all-zero 3x5x5 slices.
conv1_weight = learn.model.conv1.weight
print(conv1_weight)
# The printed tensor is long (and usually truncated when displayed), so also
# list which filters are entirely zero to make the pruning effect explicit.
zero_filters = (conv1_weight.abs().sum(dim=(1, 2, 3)) == 0).nonzero().flatten().tolist()
print(f'All-zero filters in conv1: {zero_filters}')

Parameter containing:
tensor([[[[-0.5228, -0.0126,  0.3633,  0.4228,  0.0974],
          [-0.5749, -0.2330,  0.4947,  0.3954,  0.1764],
          [-0.1482,  0.1135,  0.5132,  0.4862,  0.2257],
          [-0.0302,  0.2682,  0.3962,  0.1128,  0.2342],
          [ 0.1774,  0.1704,  0.1390,  0.0294,  0.1387]],

         [[-0.4536, -0.0539,  0.4892,  0.4711,  0.0133],
          [-0.3654, -0.1947,  0.4054,  0.4770,  0.2275],
          [-0.1587, -0.0174,  0.5705,  0.3663,  0.1356],
          [-0.1126,  0.2630,  0.4412,  0.1609,  0.2611],
          [ 0.0327,  0.1720,  0.2241,  0.0225, -0.0215]],

         [[-0.4995, -0.0284,  0.5097,  0.4110,  0.0322],
          [-0.4508, -0.2233,  0.3995,  0.4642,  0.1835],
          [-0.2945,  0.0252,  0.5992,  0.4894,  0.2228],
          [ 0.0905,  0.3047,  0.2982,  0.1535,  0.2217],
          [ 0.1498,  0.1724,  0.1488,  0.0428,  0.0609]]],


        [[[-0.3981, -0.3902, -0.3447, -0.4797, -0.6005],
          [-0.3788, -0.2374, -0.3164, -0.4427, -0.3264],
 

## Remove the zero filters

In [11]:
# Pruner (brought in by `from fasterai.Pruner import *`) is used below to turn the
# zeroed-out filters into a physically smaller network.
pruner = Pruner()

In [12]:
# Build a model without the all-zero filters. Downstream layers shrink to match:
# compare the two summaries at the end of the notebook, where conv2's, conv3's and
# fc1's inputs are smaller. learn.model keeps its original shapes, as the later
# learn.summary() output shows.
pruned_model = pruner.prune_model(learn.model)

In [13]:
# Wrap the pruned model in a fresh Learner over the same data so it can be
# validated with the same metric as the original model.
new_learn = Learner(
    data,
    pruned_model,
    metrics=[accuracy],
)

In [14]:
# Index 1 of validate()'s result is the accuracy metric (it matches the 0.9178
# reported in the training log); scale it to a percentage for display.
print(f'The original model had {learn.validate()[1] * 100:.2f} % accuracy')

The original model had 91.78 % accuracy


In [15]:
# Same measurement as for the original model, now for the pruned one.
# (The message previously said "original" here — a copy-paste slip.)
print(f'The pruned model had {100*new_learn.validate()[1]:.2f} % accuracy')

The original model had 91.78 % accuracy


In [16]:
# Layer-by-layer summary of the original (sparsified but not pruned) model.
# Its channel counts are unchanged by pruning.
learn.summary()

Layer (type)         Output Shape         Param #    Trainable 
Conv2d               [6, 24, 24]          456        True      
______________________________________________________________________
Conv2d               [8, 8, 8]            1,208      True      
______________________________________________________________________
Conv2d               [12, 4, 4]           2,412      True      
______________________________________________________________________
AdaptiveAvgPool2d    [12, 1, 1]           0          False     
______________________________________________________________________
Linear               [6]                  78         True      
______________________________________________________________________
Linear               [10]                 70         True      
______________________________________________________________________

Total params: 4,224
Total trainable params: 4,224
Total non-trainable params: 0

In [17]:
# Summary of the pruned model: compare channel counts and total parameters with
# learn.summary() above (in the saved run: 4,224 -> 1,250 parameters).
new_learn.summary()

Layer (type)         Output Shape         Param #    Trainable 
Conv2d               [3, 24, 24]          228        True      
______________________________________________________________________
Conv2d               [4, 8, 8]            304        True      
______________________________________________________________________
Conv2d               [6, 4, 4]            606        True      
______________________________________________________________________
AdaptiveAvgPool2d    [6, 1, 1]            0          False     
______________________________________________________________________
Linear               [6]                  42         True      
______________________________________________________________________
Linear               [10]                 70         True      
______________________________________________________________________

Total params: 1,250
Total trainable params: 1,250
Total non-trainable params: 0