# Pruning Experiments

The Goal of this notebook is to simplify the code down to exactly what we want to work with, rather than looking at an entire benchmark framework where most of the code and analysis is wasted because we are trying to make it do something it's not meant to do. The code found in this notebook is heavily influenced by fasterai.

## Imports and Setting up Data

Below are the libraries and modules required for most of the cells as well as some basic blocks for data preprocessing

In [35]:
## Magic Commands

# Auto reload modules as changes are made
%load_ext autoreload 
%autoreload 2
from IPython.core.debugger import set_trace


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import fasterai

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import Dataset

import os
os.environ['TORCH_HOME'] = "./models"

from utils import dataset_builder, accuracy, correct

from tqdm import tqdm
from online import OnlineStats

## Loading Data: Cifar 10

Initially, we are looking at the CIFAR-10 Dataset to examine how ResNet-18 Architecture is affected by pruning.

#### Importing the Data

In [3]:
def CIFAR10(train=True, download=False):
    """Thin wrapper around torchvision.datasets.CIFAR10"""
    mean, std = [0.491, 0.482, 0.447], [0.247, 0.243, 0.262]
    normalize = transforms.Normalize(mean=mean, std=std)
    if train:
        preproc = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, 4)]
    else:
        preproc = []
    dataset = dataset_builder('CIFAR10', train, normalize, preproc, download)
    dataset.shape = (3, 32, 32)
    return dataset

In [4]:
from torch.utils.data import DataLoader

cifar_10_train = CIFAR10(train=True, download=True)
cifar_10_test = CIFAR10(train=False, download=True)

cifar_10_train_dl = DataLoader(cifar_10_train, batch_size=128, num_workers=4, shuffle=True)
cifar_10_test_dl = DataLoader(cifar_10_test, batch_size=128, num_workers=4, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


#### Importing the Models

We are using the pretrained model from pytorch

In [5]:
import torchvision.models as models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## From PyTorch (Not Trained for Cifar10)
#resnet18 = models.resnet18().to(device)

## From Trained models 
resnet18 = torch.load("./saved/model/ResNet_10.pt").to(device)

In [8]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Running on device {device}")
print(f"ResNet18: {count_parameters(resnet18):,} parameters\n")

# print(f"Running on device {device}")
# print(f"VGG16: {count_parameters(vgg16):,} parameters")

Running on device cuda
ResNet18: 11,689,512 parameters



#### Training the Model

In [9]:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [10]:
def train_model(model, num_epochs=5, criterion=nn.CrossEntropyLoss(), optimizer=optim.SGD(resnet18.parameters(), lr=0.0005, momentum=0.9, weight_decay=5e-4)):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for epoch in range(num_epochs):
        trainloader = tqdm(cifar_10_train_dl)
        trainloader.set_description(f"Train Epoch {epoch+1}/{epochs}")

        total_loss = OnlineStats()
        acc1 = OnlineStats()
        acc5 = OnlineStats()

        for i, (inputs, labels) in enumerate(trainloader, 0):
            inputs, labels = inputs.to(device), labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = resnet18(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            ## Print Statistics
            c1, c5 = correct(outputs, labels, (1, 5))
            acc1.add(c1 / cifar_10_train_dl.batch_size)
            acc5.add(c5 / cifar_10_train_dl.batch_size)
            total_loss.add(loss.item() / cifar_10_train_dl.batch_size)

            trainloader.set_postfix(loss=total_loss.mean, top1=acc1.mean, top5=acc5.mean)

#### Saving the Model

In [11]:
def checkpoint_model(model):
    current_model_num = len([name for name in os.listdir('./saved/model/.')]) + 1
    model_name = model.__class__.__name__

    torch.save(model, f"./saved/model/{model_name}_{current_model_num}.pt")
    torch.save(model.state_dict(), f"./saved/state/{model_name}.pt_{current_model_num}")

    print(f"Saved Model Version {current_model_num}!")

### Evaluating The Model

In [12]:
# Get the accuracy of the model
print(f"ResNet18 Accuracy: {accuracy(resnet18, cifar_10_test_dl)}")

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


ResNet18 Accuracy: [0.7484]


### Looking at Fasterai Pruning

Using the Sparsifier class allows us to specify a model, the granularity for which to prune, the method of pruning, and also the criteria used to calculate the 'importance' of each connection.

In [39]:
import torch
import torch.nn as nn
from fastcore.basics import store_attr
from fasterai.sparse.criteria import *

In [51]:
class Sparsifier():

    def __init__(self, model, granularity, method, criteria):
        store_attr()
        self._save_weights() # Save the original weights

    def prune_layer(self, module, sparsity):
        weight = self.criteria(module, self.granularity)
        mask = self._compute_mask(self.model, weight, sparsity)
        module.register_buffer("_mask", mask) # Put the mask into a buffer
        self._apply(module)

    def prune_model(self, sparsity):
        for k, m in enumerate(self.model.modules()):
            if isinstance(m, nn.Conv2d):
                self.prune_layer(m, sparsity)

    def _apply(self, module):
        mask = getattr(module, "_mask")
        module.weight.data.mul_(mask)

        if self.granularity == 'filter': # If we remove complete filters, we want to remove the bias as well
            if module.bias is not None:
                module.bias.data.mul_(mask.squeeze())

    def _mask_grad(self):
        for k, m in enumerate(self.model.modules()):
            if isinstance(m, nn.Conv2d) and hasattr(m, '_mask'):
                mask = getattr(m, "_mask")
                if m.weight.grad is not None: # In case some layers are freezed
                    m.weight.grad.mul_(mask)

                if self.granularity == 'filter': # If we remove complete filters, we want to remove the bias as well
                        if m.bias.grad is not None: # In case some layers are freezed
                            m.bias.grad.mul_(mask.squeeze())

    def _reset_weights(self):
        for k, m in enumerate(self.model.modules()):
            if isinstance(m, nn.Linear):
                init_weights = getattr(m, "_init_weights")
                m.weight.data = init_weights.clone()
            if isinstance(m, nn.Conv2d):
                init_weights = getattr(m, "_init_weights")
                m.weight.data = init_weights.clone()
                self._apply(m) # Reset the weights and apply the current mask

    def _save_weights(self):
        for k, m in enumerate(self.model.modules()):
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                m.register_buffer("_init_weights", m.weight.clone())

    def _clean_buffers(self):
        for k, m in enumerate(self.model.modules()):
            if isinstance(m, nn.Conv2d) and hasattr(m, '_mask'):
                del m._buffers["_mask"]

            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                del m._buffers["_init_weights"]


    def _compute_mask(self, model, weight, sparsity):
        if self.method == 'global':
            global_weight = torch.cat([self.criteria(m, self.granularity).view(-1) for m in model.modules() if isinstance(m, nn.Conv2d)])
            threshold = torch.quantile(global_weight, sparsity/100) # Compute the threshold globally

        elif self.method == 'local':
            threshold = torch.quantile(weight.view(-1), sparsity/100) # Compute the threshold locally

        else: raise NameError('Invalid Method')

        if threshold > weight.max(): threshold = weight.max() # Make sure we don't remove every weight of a given layer

        mask = weight.ge(threshold).to(dtype=weight.dtype)

        return mask

## Now looking at the criterias

In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastcore.basics import *
from fastcore.imports import *

In [42]:
class Criteria():
    def __init__(self, f, needs_init=False, needs_update=False, output_f=None, return_init=False):
        store_attr()
        assert (needs_init and needs_update)==False, "The init values will be overwritten by the updating ones."

    def __call__(self, m, granularity):
        if self.needs_update and hasattr(m, '_old_weights') == False:
            m.register_buffer("_old_weights", m._init_weights.clone()) # If the previous value of weights is not known, take the initial value

        if granularity == 'weight':
            wf = self.f(m.weight)
            if self.needs_init: wi = self.f(m._init_weights)
            elif self.needs_update: wi = self.f(m._old_weights)

        elif granularity in granularities:
            dim = granularities[granularity]
            wf = self.f(m.weight).mean(dim=dim, keepdim=True)
            if self.needs_init: wi = self.f(m._init_weights).mean(dim=dim, keepdim=True)
            elif self.needs_update: wi = self.f(m._old_weights).mean(dim=dim, keepdim=True)

        else: raise NameError('Invalid Granularity')

        if self.needs_update: m._old_weights = m.weight.clone() # The current value becomes the old one for the next iteration

        if self.output_f: return self.output_f(wf, wi)
        elif self.return_init: return wi
        else: return wf

In [43]:
## Gradient Criteria
def grad_crit(m, granularity):
    if m.weight.grad is not None:
        if granularity == 'weight':
            w = (m.weight*m.weight.grad).pow(2)

        elif granularity in granularities:
            dim = granularities[granularity]
            w = (m.weight*m.weight.grad).pow(2).mean(dim=dim, keepdim=True)

        else: raise NameError('Invalid Granularity') 

        return w

In [46]:
large_final = Criteria(torch.abs)
weight_pruner = Sparsifier(resnet18, 'weight', 'global', large_final)

grad_pruner = Sparsifier(resnet18, 'weight', 'global', grad_crit)

In [50]:


print(f"Loading Model...")
resnet18 = torch.load("./saved/model/ResNet_10.pt")

print(f"accuracy before pruning: {accuracy(resnet18, cifar_10_test_dl)}")
print(f"Before Pruning: ResNet18: {count_parameters(resnet18):,} parameters\n")

print("pruning with weight pruner...\n")
weight_pruner.prune_model(sparsity=80)

print(f"accuracy after pruning: {accuracy(resnet18, cifar_10_test_dl)}")
print(f"After weight pruning: ResNet18: {count_parameters(resnet18):,} parameters\n")

####################################################################################

print(f"Loading Model...")
resnet18 = torch.load("./saved/model/ResNet_10.pt")

print(f"accuracy before pruning: {accuracy(resnet18, cifar_10_test_dl)}")
print(f"Before Pruning: ResNet18: {count_parameters(resnet18):,} parameters\n")

print("pruning with gradient pruner...\n")
grad_pruner.prune_model(sparsity=80)

print(f"accuracy after pruning: {accuracy(resnet18, cifar_10_test_dl)}")
print(f"After weight pruning: ResNet18: {count_parameters(resnet18):,} parameters\n")

Loading Model...
accuracy before pruning: [0.7484]
Before Pruning: ResNet18: 11,689,512 parameters

pruning with weight pruner...

> [0;32m<ipython-input-40-4107b39a3564>[0m(65)[0;36m_compute_mask[0;34m()[0m
[0;32m     63 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mmethod[0m [0;34m==[0m [0;34m'global'[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m            [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m            [0mglobal_weight[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0;34m[[0m[0mself[0m[0;34m.[0m[0mcriteria[0m[0;34m([0m[0mm[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mgranularity[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m [0;32mfor[0m [0mm[0m [0;32min[0m [0mmodel[0m[0;34m.[0m[0mmodules[0m[0;34m([0m[0;34m)[0m [0;32mif[0m [0misinstance[0m[0;34m([0m[0mm[0m[0;34m,[0m [0mnn[0m[0;34m.[0

ipdb>  r


--Return--
tensor([[[[0....vice='cuda:0')
> [0;32m<ipython-input-40-4107b39a3564>[0m(77)[0;36m_compute_mask[0;34m()[0m
[0;32m     73 [0;31m        [0;32mif[0m [0mthreshold[0m [0;34m>[0m [0mweight[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m:[0m [0mthreshold[0m [0;34m=[0m [0mweight[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m [0;31m# Make sure we don't remove every weight of a given layer[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m[0;34m[0m[0m
[0m[0;32m     75 [0;31m        [0mmask[0m [0;34m=[0m [0mweight[0m[0;34m.[0m[0mge[0m[0;34m([0m[0mthreshold[0m[0;34m)[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mdtype[0m[0;34m=[0m[0mweight[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m[0;34m[0m[0m
[0m[0;32m---> 77 [0;31m        [0;32mreturn[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  r


> [0;32m<ipython-input-40-4107b39a3564>[0m(11)[0;36mprune_layer[0;34m()[0m
[0;32m      9 [0;31m        [0mweight[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcriteria[0m[0;34m([0m[0mmodule[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mgranularity[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     10 [0;31m        [0mmask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_compute_mask[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mmodel[0m[0;34m,[0m [0mweight[0m[0;34m,[0m [0msparsity[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 11 [0;31m        [0mmodule[0m[0;34m.[0m[0mregister_buffer[0m[0;34m([0m[0;34m"_mask"[0m[0;34m,[0m [0mmask[0m[0;34m)[0m [0;31m# Put the mask into a buffer[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     12 [0;31m        [0mself[0m[0;34m.[0m[0m_apply[0m[0;34m([0m[0mmodule[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     13 [0;31m[0;34m[0m[0m
[0m


ipdb>  r


--Return--
None
> [0;32m<ipython-input-40-4107b39a3564>[0m(12)[0;36mprune_layer[0;34m()[0m
[0;32m     10 [0;31m        [0mmask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_compute_mask[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mmodel[0m[0;34m,[0m [0mweight[0m[0;34m,[0m [0msparsity[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m        [0mmodule[0m[0;34m.[0m[0mregister_buffer[0m[0;34m([0m[0;34m"_mask"[0m[0;34m,[0m [0mmask[0m[0;34m)[0m [0;31m# Put the mask into a buffer[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 12 [0;31m        [0mself[0m[0;34m.[0m[0m_apply[0m[0;34m([0m[0mmodule[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     13 [0;31m[0;34m[0m[0m
[0m[0;32m     14 [0;31m    [0;32mdef[0m [0mprune_model[0m[0;34m([0m[0mself[0m[0;34m,[0m [0msparsity[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  r


> [0;32m<ipython-input-40-4107b39a3564>[0m(15)[0;36mprune_model[0;34m()[0m
[0;32m     13 [0;31m[0;34m[0m[0m
[0m[0;32m     14 [0;31m    [0;32mdef[0m [0mprune_model[0m[0;34m([0m[0mself[0m[0;34m,[0m [0msparsity[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 15 [0;31m        [0;32mfor[0m [0mk[0m[0;34m,[0m [0mm[0m [0;32min[0m [0menumerate[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mmodel[0m[0;34m.[0m[0mmodules[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     16 [0;31m            [0;32mif[0m [0misinstance[0m[0;34m([0m[0mm[0m[0;34m,[0m [0mnn[0m[0;34m.[0m[0mConv2d[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m                [0mself[0m[0;34m.[0m[0mprune_layer[0m[0;34m([0m[0mm[0m[0;34m,[0m [0msparsity[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  r


> [0;32m<ipython-input-40-4107b39a3564>[0m(65)[0;36m_compute_mask[0;34m()[0m
[0;32m     63 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mmethod[0m [0;34m==[0m [0;34m'global'[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     64 [0;31m            [0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 65 [0;31m            [0mglobal_weight[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mcat[0m[0;34m([0m[0;34m[[0m[0mself[0m[0;34m.[0m[0mcriteria[0m[0;34m([0m[0mm[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mgranularity[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m [0;32mfor[0m [0mm[0m [0;32min[0m [0mmodel[0m[0;34m.[0m[0mmodules[0m[0;34m([0m[0;34m)[0m [0;32mif[0m [0misinstance[0m[0;34m([0m[0mm[0m[0;34m,[0m [0mnn[0m[0;34m.[0m[0mConv2d[0m[0;34m)[0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     66 [0;31m            [0mthreshold[

ipdb>  r


--Return--
tensor([[[[1....vice='cuda:0')
> [0;32m<ipython-input-40-4107b39a3564>[0m(77)[0;36m_compute_mask[0;34m()[0m
[0;32m     73 [0;31m        [0;32mif[0m [0mthreshold[0m [0;34m>[0m [0mweight[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m[0;34m:[0m [0mthreshold[0m [0;34m=[0m [0mweight[0m[0;34m.[0m[0mmax[0m[0;34m([0m[0;34m)[0m [0;31m# Make sure we don't remove every weight of a given layer[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     74 [0;31m[0;34m[0m[0m
[0m[0;32m     75 [0;31m        [0mmask[0m [0;34m=[0m [0mweight[0m[0;34m.[0m[0mge[0m[0;34m([0m[0mthreshold[0m[0;34m)[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mdtype[0m[0;34m=[0m[0mweight[0m[0;34m.[0m[0mdtype[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     76 [0;31m[0;34m[0m[0m
[0m[0;32m---> 77 [0;31m        [0;32mreturn[0m [0mmask[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  r


> [0;32m<ipython-input-40-4107b39a3564>[0m(11)[0;36mprune_layer[0;34m()[0m
[0;32m      9 [0;31m        [0mweight[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcriteria[0m[0;34m([0m[0mmodule[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mgranularity[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     10 [0;31m        [0mmask[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_compute_mask[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mmodel[0m[0;34m,[0m [0mweight[0m[0;34m,[0m [0msparsity[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 11 [0;31m        [0mmodule[0m[0;34m.[0m[0mregister_buffer[0m[0;34m([0m[0;34m"_mask"[0m[0;34m,[0m [0mmask[0m[0;34m)[0m [0;31m# Put the mask into a buffer[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     12 [0;31m        [0mself[0m[0;34m.[0m[0m_apply[0m[0;34m([0m[0mmodule[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     13 [0;31m[0;34m[0m[0m
[0m


ipdb>  run


Restart: 

## Viisualizing the Pruning

In [11]:
def plot_kernels(layer, save=None):
    kernels = layer.weight.detach().clone()
    kernels = kernels - kernels.min()
    kernels = kernels/kernels.max()
    
    plt.figure(figsize=(10,10))
    img = make_grid(kernels, nrow=8, padding=1, pad_value=1)
    plt.axis('off')
    plt.imshow(img.detach().permute(1,2,0).cpu())
    if save: plt.savefig(f'{save}.pdf')

## Visualizing The Data

Code to view 9 random images in the training data, as well as use the Dataloader to ensure everything is working as expected

In [None]:
import matplotlib.pyplot as plt
import numpy as np

labels_map = {
    0: "plane",
    1: "car",
    2: "bird",
    3: "cat",
    4: "deer",
    5: "dog",
    6: "frog",
    7: "horse",
    8: "ship",
    9: "truck",
}

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(cifar_10_train), size=(1,)).item()
    img, label = cifar_10_train[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    imshow(img)
plt.show()


# Display image and label.
train_features, train_labels = next(iter(cifar_10_train_dl))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
imshow(img)
plt.show()
print(f"Label: {label}")