# Network Pruning

When you sparsify your network, you don't actually remove the parameters. The zeroed weights are still stored and still take part in every computation, so you get no speed or memory benefit from the sparsity. A lot of research is being done to accelerate computation on sparse matrices, and we expect it to be implemented in PyTorch soon: [see `torch.sparse`](https://pytorch.org/docs/stable/sparse.html)

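What I mean here by pruning is the process of completely removing the sparsified parameters. This can only be done when the granularity is at the level of complete filters. A single zero weight cannot be removed from a dense weight tensor without leaving a hole, whereas removing a whole filter simply makes the layer output one channel fewer, and the next layer take one input channel fewer.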
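
To make this concrete, here is a minimal sketch of what physically removing zeroed filters involves for a Conv2d, BatchNorm2d, Conv2d chain like the one in the network below. The helper `prune_conv_bn_conv` is hypothetical and written only for illustration; it is not fasterai's `Pruner` implementation. It assumes plain convolutions (`groups=1`, default dilation) and a BatchNorm layer with affine parameters and running statistics.

```python
import torch
import torch.nn as nn


def prune_conv_bn_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d, next_conv: nn.Conv2d):
    """Hypothetical helper: remove the all-zero filters of `conv`, the matching
    BatchNorm channels, and the matching input channels of `next_conv`.
    Assumes groups=1, default dilation, affine BatchNorm with running stats."""
    # One L1 norm per output filter; a filter is dead if all its weights are 0
    norms = conv.weight.detach().abs().sum(dim=(1, 2, 3))
    idx = (norms != 0).nonzero().flatten()  # indices of the filters we keep
    n = idx.numel()

    new_conv = nn.Conv2d(conv.in_channels, n, conv.kernel_size, conv.stride,
                         conv.padding, bias=conv.bias is not None)
    new_bn = nn.BatchNorm2d(n, eps=bn.eps, momentum=bn.momentum)
    new_next = nn.Conv2d(n, next_conv.out_channels, next_conv.kernel_size,
                         next_conv.stride, next_conv.padding,
                         bias=next_conv.bias is not None)

    with torch.no_grad():
        # Keep only the surviving output filters (dim 0) of the first conv
        new_conv.weight.copy_(conv.weight[idx])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias[idx])
        # BatchNorm has one set of parameters and statistics per channel
        new_bn.weight.copy_(bn.weight[idx])
        new_bn.bias.copy_(bn.bias[idx])
        new_bn.running_mean.copy_(bn.running_mean[idx])
        new_bn.running_var.copy_(bn.running_var[idx])
        # The next conv loses the corresponding input channels (dim 1)
        new_next.weight.copy_(next_conv.weight[:, idx])
        if next_conv.bias is not None:
            new_next.bias.copy_(next_conv.bias)
    return new_conv, new_bn, new_next


# Demo: zero 6 of 20 filters (30%), mimicking filter-level sparsification
conv1, bn1, conv2 = nn.Conv2d(3, 20, 5), nn.BatchNorm2d(20), nn.Conv2d(20, 50, 5)
with torch.no_grad():
    conv1.weight[:6] = 0
p1, pbn1, p2 = prune_conv_bn_conv(conv1, bn1, conv2)
print(tuple(p1.weight.shape), pbn1.num_features, tuple(p2.weight.shape))
# (14, 3, 5, 5) 14 (50, 14, 5, 5)
```

The sketch keeps a filter only if its L1 norm is non-zero, matching the "all weights are 0" condition. Note that a dead filter's bias still yields a constant activation that flows through ReLU and BatchNorm into the next layer; this sketch simply drops it, so downstream outputs can shift slightly unless that constant is folded into `next_conv`'s bias, which this sketch does not do.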

 > Note: only Sequential feed-forward networks are supported for now.

In [1]:
from fastai.vision import *
from fastai.callbacks import *

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
path = untar_data(URLs.IMAGENETTE_160)

In [4]:
data = (ImageList.from_folder(path)
                .split_by_folder(train='train', valid='val')
                .label_from_folder()
                .transform(get_transforms(), size=128)
                .databunch(bs=64)
                .normalize(imagenet_stats))

In [5]:
from fasterai.sparsifier import *

## Network

In [6]:
class Net(nn.Module):
    def __init__(self, mnist=True):
        super().__init__()
          
        self.conv1 = nn.Conv2d(3, 12, 5, 1)
        self.conv2 = nn.Conv2d(12, 24, 5, 1)
        self.conv3 = nn.Conv2d(24,36, 5, 1)
        self.pool = nn.AdaptiveAvgPool2d((1))
        self.fc1 = nn.Linear(36, 18)
        self.fc2 = nn.Linear(18, 10)
    
    def forward(self, x):

        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.pool(x)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [7]:
class Net(nn.Module):
    def __init__(self, mnist=True):
        super().__init__()
          
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.bn1 = nn.BatchNorm2d(20)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.bn2 = nn.BatchNorm2d(50)
        self.conv3 = nn.Conv2d(50, 100, 5, 1)
        self.bn3 = nn.BatchNorm2d(100)
        self.pool = nn.AdaptiveAvgPool2d((1))
        self.fc1 = nn.Linear(100, 50)
        self.fc2 = nn.Linear(50, 10)
    
    def forward(self, x):

        x = F.relu(self.conv1(x))
        x = self.bn1(x)
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = self.bn2(x)
        x = F.relu(self.conv3(x))
        x = self.bn3(x)
        x = self.pool(x)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
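
The `Output Shape` column of the summaries further down can be checked by hand with the usual output-size formula for a convolution or pooling layer, floor((H + 2p - k) / s) + 1. A small sketch, assuming the 128×128 input size set in the DataBunch and no padding (as in `Net`):

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of a conv or pooling layer: floor((H + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# (layer name, kernel size, stride) in the order used by Net.forward
layers = [("conv1", 5, 1), ("max_pool2d", 2, 2), ("conv2", 5, 1), ("conv3", 5, 1)]

h = 128  # images are resized to 128x128 by the DataBunch
trace = []
for name, k, s in layers:
    h = conv_out(h, k, s)
    trace.append((name, h))
print(trace)
# [('conv1', 124), ('max_pool2d', 62), ('conv2', 58), ('conv3', 54)]
```

Because `AdaptiveAvgPool2d((1))` then collapses the 54×54 maps to 1×1, the input size of `fc1` equals the number of `conv3` filters (100), so pruning filters of `conv3` also means shrinking the inputs of `fc1`.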

In [8]:
learn = Learner(data, Net().cuda(), metrics=[accuracy])

In [10]:
learn.fit_one_cycle(3, 1e-3, callbacks=[SparsifyCallback(learn, sparsity=30, granularity='filter', method='local', criteria='l1', sched_func=annealing_cos)])

Pruning of filter until a sparsity of 30%


epoch,train_loss,valid_loss,accuracy,time
0,1.846674,1.647472,0.437707,00:08
1,1.510728,1.368574,0.556943,00:08
2,1.357512,1.366837,0.556178,00:08


Sparsity at epoch 0: 7.59%
Sparsity at epoch 1: 22.59%
Sparsity at epoch 2: 30.00%
Final Sparsity: 30.00
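
The logged sparsity grows along a cosine curve rather than linearly. A minimal sketch of that schedule, assuming `SparsifyCallback` evaluates `sched_func(0, sparsity, pct)` with `pct` the fraction of training completed (an assumption about fasterai's internals, not documented behaviour), and using the same formula as fastai v1's `annealing_cos`:

```python
import math


def annealing_cos(start: float, end: float, pct: float) -> float:
    """Cosine anneal from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)


target, n_epochs = 30, 3
# Sparsity at the end of each epoch, i.e. pct = (epoch + 1) / n_epochs
schedule = [annealing_cos(0, target, (e + 1) / n_epochs) for e in range(n_epochs)]
for epoch, s in enumerate(schedule):
    print(f"Sparsity at epoch {epoch}: {s:.2f}%")
# Sparsity at epoch 0: 7.50%
# Sparsity at epoch 1: 22.50%
# Sparsity at epoch 2: 30.00%
```

These values are close to, but not identical with, the logged 7.59% and 22.59%, so the callback probably samples training progress slightly differently (for example per batch); treat the sketch as an approximation of the schedule.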


So now only 70% of our network's filters are in use (30% of them are entirely zero), and it is still able to achieve almost the same accuracy as when it was using 100% of them!

Let's double-check that the expected fraction of weights is now zero in each convolution:

In [11]:
for k,m in enumerate(learn.model.modules()):
    if isinstance(m, nn.Conv2d):
        print(f"Sparsity in {m.__class__.__name__} {k}: {100. * float(torch.sum(m.weight == 0))/ float(m.weight.nelement()):.2f}%")

Sparsity in Conv2d 1: 30.00%
Sparsity in Conv2d 3: 30.00%
Sparsity in Conv2d 5: 30.00%
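
The check above counts individual zero weights. With `granularity='filter'`, those zeros should come in whole filters, so a complementary sketch counts, per convolution, the output filters whose weights are all zero. The helper `dead_filters` is hypothetical, and the demo runs on a small stand-in model rather than the trained `learn.model`. On `learn.model` one would expect (6, 20), (15, 50) and (30, 100), which is consistent with the 14, 35 and 70 filters left in the pruned summary further down.

```python
import torch
import torch.nn as nn


def dead_filters(model: nn.Module) -> dict:
    """For every Conv2d, return (number of all-zero filters, total filters)."""
    report = {}
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            # One L1 norm per output filter: shape (out_channels,)
            norms = m.weight.detach().abs().sum(dim=(1, 2, 3))
            report[name] = (int((norms == 0).sum()), m.out_channels)
    return report


# Stand-in model with 30% of the filters zeroed in each conv
model = nn.Sequential(nn.Conv2d(3, 20, 5), nn.ReLU(), nn.Conv2d(20, 50, 5))
with torch.no_grad():
    model[0].weight[:6] = 0   # 6 of 20 filters
    model[2].weight[:15] = 0  # 15 of 50 filters
report = dead_filters(model)
print(report)
# {'0': (6, 20), '2': (15, 50)}
```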


In [12]:
from fasterai.pruner import *

In [13]:
pruner = Pruner()

In [14]:
pruned_model = pruner.prune_model(learn.model)

In [15]:
pruned_learn = Learner(data, pruned_model, metrics =[accuracy])

In [16]:
print(f'The original model had {100*learn.validate()[1]:.2f} % accuracy')

The original model had 55.62 % accuracy


In [17]:
print(f'The pruned model has {100*pruned_learn.validate()[1]:.2f} % accuracy')

The pruned model has 57.25 % accuracy


Surprisingly, the pruned model even has slightly better accuracy (57.25% vs 55.62%).

In [18]:
learn.summary()

Net

| Layer (type) | Output Shape | Param # | Trainable |
|---|---|---|---|
| Conv2d | [20, 124, 124] | 1,520 | True |
| BatchNorm2d | [20, 124, 124] | 40 | True |
| Conv2d | [50, 58, 58] | 25,050 | True |
| BatchNorm2d | [50, 58, 58] | 100 | True |
| Conv2d | [100, 54, 54] | 125,100 | True |
| BatchNorm2d | [100, 54, 54] | 200 | True |
| AdaptiveAvgPool2d | [100, 1, 1] | 0 | False |

In [19]:
pruned_learn.summary()

Net

| Layer (type) | Output Shape | Param # | Trainable |
|---|---|---|---|
| Conv2d | [14, 124, 124] | 1,064 | True |
| BatchNorm2d | [14, 124, 124] | 28 | True |
| Conv2d | [35, 58, 58] | 12,285 | True |
| BatchNorm2d | [35, 58, 58] | 70 | True |
| Conv2d | [70, 54, 54] | 61,320 | True |
| BatchNorm2d | [70, 54, 54] | 140 | True |
| AdaptiveAvgPool2d | [70, 1, 1] | 0 | False |

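We can now see that a lot of parameters have been removed from our network: each convolution keeps 70% of its filters (14/20, 35/50, 70/100), because we removed the filters that were no longer useful (all of their weights were 0). Since removing a filter also removes the corresponding input channel of the next convolution, the second and third convolutions shrink to roughly half of their original parameter count.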
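
The parameter counts in the two summaries follow from the convolution formula c_out × c_in × k × k + c_out (weights plus one bias per filter). A short sketch that reproduces them for the three convolutions of `Net` (5×5 kernels) and shows the overall reduction; keeping 70% of both output and input channels gives roughly 0.7 × 0.7 ≈ 49% of the weights:

```python
def conv_params(c_in: int, c_out: int, k: int) -> int:
    """Weights (c_out * c_in * k * k) plus one bias per output filter."""
    return c_out * c_in * k * k + c_out


original = [(3, 20), (20, 50), (50, 100)]  # (in_channels, out_channels)
pruned = [(3, 14), (14, 35), (35, 70)]     # 70% of the filters kept in each layer

n_orig = sum(conv_params(i, o, 5) for i, o in original)
n_pruned = sum(conv_params(i, o, 5) for i, o in pruned)
print(n_orig, n_pruned, f"{n_pruned / n_orig:.1%}")
# 151670 74669 49.2%
```

The first convolution shrinks less (to 70%) because its 3 input channels are the RGB image and cannot be pruned.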

### To-Do: generalize to other feed-forward models (e.g. VGG16)