# Sandbox for model pruning
Follows the official PyTorch pruning tutorial:
- https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

In [14]:
import os
import numpy as np
import pandas as pd
import pickle
import seaborn as sns
import torch 
import torchvision.models as models
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import matplotlib.pyplot as plt
from collections import OrderedDict

In [15]:
def prune_model(model, prune_type, prune_percent, module_types=(torch.nn.Conv2d,)):
    '''Sparsify ``model`` in place with L1-magnitude unstructured pruning.

    Args:
        model: ``torch.nn.Module`` whose submodules of ``module_types`` have
            their ``weight`` pruned.
        prune_type: ``'global'`` pools every selected weight tensor and zeroes the
            smallest-magnitude ``prune_percent`` of that pool, so individual layers
            may end up more or less sparse than the target; ``'per_layer'`` zeroes
            ``prune_percent`` of each selected tensor independently.
        prune_percent: amount forwarded to ``torch.nn.utils.prune`` -- a float in
            [0, 1] is a fraction of weights, an int is an absolute count.
        module_types: module class, or tuple of classes, whose ``weight`` is
            pruned. Defaults to ``Conv2d`` only.

    Returns:
        The same ``model`` object. Pruning stays a mask reparametrization until
        ``prune.remove(module, 'weight')`` is called on each pruned module.

    Raises:
        ValueError: if ``prune_type`` is neither ``'global'`` nor ``'per_layer'``.
    '''
    if isinstance(module_types, type):
        module_types = (module_types,)
    type_names = '/'.join(t.__name__ for t in module_types)
    targets = [m for m in model.modules() if isinstance(m, module_types)]

    if prune_type == 'global':
        print('Globally pruning all {} layers with {} sparsity'.format(type_names, prune_percent))
        # global_unstructured cannot handle an empty parameter list, so skip when nothing matches.
        if targets:
            prune.global_unstructured(
                [(m, 'weight') for m in targets],
                pruning_method=prune.L1Unstructured,
                amount=prune_percent,
            )
    elif prune_type == 'per_layer':
        print('Layerwise pruning all {} layers with {} sparsity'.format(type_names, prune_percent))
        for m in targets:
            prune.l1_unstructured(m, name='weight', amount=prune_percent)
    else:
        # Fail loudly: returning an unpruned model here would make later sparsity reports misleading.
        raise ValueError(
            "Unknown pruning method: {!r} (expected 'global' or 'per_layer')".format(prune_type)
        )

    return model

In [11]:
# Load the U-Net architecture from torch.hub with pretrained=False (random init).
# torch.hub needs network access or an already-populated hub cache.
model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
                       in_channels=3, out_channels=1, init_features=32, pretrained=False)

prune_percent = 0.1        # fraction of Conv2d weights to zero
prune_type = 'per_layer'   # 'per_layer' or 'global'
make_permanent = False     # True -> prune.remove() folds each mask into `weight`
model = prune_model(model, prune_type, prune_percent)

# Report the share of exactly-zero weights per Conv2d layer, plus the overall figure
# (the overall number is the one that matches the target in 'global' mode).
zero_total, weight_total = 0, 0
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        n_zero = int(torch.sum(module.weight == 0))
        n_weight = module.weight.nelement()
        zero_total += n_zero
        weight_total += n_weight
        print("Sparsity in {}: {:.2f}%".format(name, 100. * n_zero / n_weight))
        if make_permanent:
            prune.remove(module, 'weight')

if weight_total:
    print("Overall Conv2d sparsity: {:.2f}%".format(100. * zero_total / weight_total))


Using cache found in /home/nikhil/.cache/torch/hub/mateuszbuda_brain-segmentation-pytorch_master
Layerwise pruning all Conv2d layers with 0.1 sparsity
Sparsity in encoder1.enc1conv1: 9.95%
Sparsity in encoder1.enc1conv2: 10.00%
Sparsity in encoder2.enc2conv1: 10.00%
Sparsity in encoder2.enc2conv2: 10.00%
Sparsity in encoder3.enc3conv1: 10.00%
Sparsity in encoder3.enc3conv2: 10.00%
Sparsity in encoder4.enc4conv1: 10.00%
Sparsity in encoder4.enc4conv2: 10.00%
Sparsity in bottleneck.bottleneckconv1: 10.00%
Sparsity in bottleneck.bottleneckconv2: 10.00%
Sparsity in decoder4.dec4conv1: 10.00%
Sparsity in decoder4.dec4conv2: 10.00%
Sparsity in decoder3.dec3conv1: 10.00%
Sparsity in decoder3.dec3conv2: 10.00%
Sparsity in decoder2.dec2conv1: 10.00%
Sparsity in decoder2.dec2conv2: 10.00%
Sparsity in decoder1.dec1conv1: 10.00%
Sparsity in decoder1.dec1conv2: 10.00%
Sparsity in conv: 9.38%
