# Import the necessary libraries

In [1]:
#Ref https://github.com/wanglouis49/pytorch-weights_pruning/tree/master
#Ref https://pytorch.org/tutorials/intermediate/pruning_tutorial.html
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
from torch.autograd import Variable
import matplotlib.pyplot as plt
from pathlib import Path
import os
import numpy as np

# Utility Functions

In [2]:
def print_size_of_model(model):
    """Print and return the on-disk size of ``model``'s state dict in KB.

    The state dict is serialized with ``torch.save`` to a file called
    ``temp.p`` inside a private temporary directory, measured with
    ``os.path.getsize`` and discarded together with that directory. Working
    in a scratch directory means a ``temp.p`` that already exists in the
    current working directory is never overwritten or deleted, and the
    scratch file is cleaned up even if saving or measuring raises.

    Args:
        model: object exposing ``state_dict()`` (e.g. an ``nn.Module``).

    Returns:
        float: size of the serialized state dict in kilobytes (bytes / 1e3).
    """
    import tempfile  # local import: only this helper needs it

    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), scratch_path)
        model_size = os.path.getsize(scratch_path)/1e3
    print('Size (KB):', model_size)
    return model_size

In [25]:
class MaskedLinear(nn.Linear):
    """``nn.Linear`` whose weight can be gated element-wise by a fixed mask.

    Until ``set_mask`` is called the layer behaves exactly like ``nn.Linear``.
    Afterwards every forward pass uses ``weight * mask``, so masked entries
    contribute nothing to the output and receive zero gradient.

    The mask is held in a non-persistent buffer: it follows the module through
    ``.to()`` / ``.double()`` / ``.cuda()`` but is not written to
    ``state_dict()``, so checkpoints keep the plain ``nn.Linear`` key set.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(MaskedLinear, self).__init__(in_features, out_features, bias)
        self.mask_flag = False
        self.register_buffer('mask', None, persistent=False)

    def set_mask(self, mask):
        """Install ``mask`` and zero the currently masked weights.

        Args:
            mask: tensor (or anything ``torch.as_tensor`` accepts) that can be
                multiplied element-wise with ``self.weight``. It is copied and
                converted to the weight's device and dtype, so later changes
                to the caller's object do not affect the layer.
        """
        self.mask = torch.as_tensor(mask).detach().to(
            device=self.weight.device, dtype=self.weight.dtype, copy=True)
        # Zero the stored weights too, so code inspecting the parameter
        # (e.g. counting zeros) sees the pruning.
        self.weight.data = self.weight.data*self.mask
        self.mask_flag = True

    def get_mask(self):
        """Return the installed mask tensor, or ``None`` if none was set."""
        return self.mask

    def forward(self, x):
        """Apply the linear map, using the masked weight once a mask is set."""
        if self.mask_flag:
            return F.linear(x, self.weight*self.mask, self.bias)
        return F.linear(x, self.weight, self.bias)

In [26]:
class MaskedConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose kernel can be gated element-wise by a fixed mask.

    Until ``set_mask`` is called the layer behaves like a regular
    ``nn.Conv2d``. Afterwards every forward pass convolves with
    ``weight * mask``, so masked kernel entries contribute nothing to the
    output and receive zero gradient.

    The mask is held in a non-persistent buffer: it follows the module through
    ``.to()`` / ``.double()`` / ``.cuda()`` but is not written to
    ``state_dict()``, so checkpoints stay interchangeable with a plain
    ``nn.Conv2d`` of the same shape.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(MaskedConv2d, self).__init__(in_channels, out_channels,
            kernel_size, stride, padding, dilation, groups, bias)
        self.mask_flag = False
        self.register_buffer('mask', None, persistent=False)

    def set_mask(self, mask):
        """Install ``mask`` and zero the currently masked kernel weights.

        Args:
            mask: tensor (or anything ``torch.as_tensor`` accepts) that can be
                multiplied element-wise with ``self.weight``. It is copied and
                converted to the weight's device and dtype, so later changes
                to the caller's object do not affect the layer.
        """
        self.mask = torch.as_tensor(mask).detach().to(
            device=self.weight.device, dtype=self.weight.dtype, copy=True)
        # Zero the stored weights too, so code inspecting the parameter
        # (e.g. counting zeros) sees the pruning.
        self.weight.data = self.weight.data*self.mask
        self.mask_flag = True

    def get_mask(self):
        """Return the installed mask tensor, or ``None`` if none was set."""
        return self.mask

    def forward(self, x):
        """Convolve ``x``, using the masked kernel once a mask is set."""
        kernel = self.weight*self.mask if self.mask_flag else self.weight
        return F.conv2d(x, kernel, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

In [5]:
def prune_one_filter(model, masks):
    '''
    Pruning one least ``important'' feature map by the scaled l2norm of 
    kernel weights
    arXiv:1611.06440

    Every 4-D parameter of ``model`` is treated as a conv weight with the
    filters along axis 0. Each filter is scored by its mean squared weight,
    scores are normalised per layer by their l2 norm, and the filter with the
    smallest nonzero score over all layers has its mask set to zero.

    Filters that are already pruned (zero score, or an all-zero mask row) are
    never chosen again, and a layer whose weights are all zero is skipped
    instead of producing NaN scores.

    Args:
        model: object whose ``parameters()`` yields torch tensors.
        masks: list with one numpy mask per conv layer (same order as the
            4-D parameters), or an empty list / ``None`` to start from
            all-ones float32 masks. Assumes each mask has one leading entry
            per filter. A non-empty list is modified in place.

    Returns:
        The list of masks with one additional filter zeroed.

    Raises:
        ValueError: if no conv filter is left to prune.
    '''
    conv_weights = [p.data.cpu().numpy() for p in model.parameters()
                    if len(p.data.size()) == 4]  # nasty way of selecting conv layer

    # construct masks if there is not yet
    if not masks:
        masks = [np.ones(w.shape).astype('float32') for w in conv_weights]

    assert len(masks) == len(conv_weights), "something wrong here"

    best_value = np.inf
    to_prune_layer_ind = None
    to_prune_filter_ind = None

    for layer_ind, w in enumerate(conv_weights):
        # find the scaled l2 norm for each filter this layer
        value_this_layer = np.square(w).sum(axis=(1, 2, 3)) / \
            (w.shape[1]*w.shape[2]*w.shape[3])
        layer_norm = np.sqrt(np.square(value_this_layer).sum())
        if not layer_norm > 0:
            # fully pruned (or NaN) layer: nothing to choose from, and
            # dividing by the norm would only yield NaN scores
            continue
        # normalization (important)
        value_this_layer = value_this_layer / layer_norm

        # a filter is a candidate only if it still has weight and mask
        mask_alive = np.asarray(masks[layer_ind]) \
            .reshape(len(value_this_layer), -1).any(axis=1)
        candidates = np.flatnonzero((value_this_layer != 0) & mask_alive)
        if candidates.size == 0:
            continue

        filter_ind = int(candidates[np.argmin(value_this_layer[candidates])])
        if value_this_layer[filter_ind] < best_value:
            best_value = value_this_layer[filter_ind]
            to_prune_layer_ind = layer_ind
            to_prune_filter_ind = filter_ind

    if to_prune_layer_ind is None:
        raise ValueError("no conv filter left to prune")

    # set mask corresponding to the filter to prune
    masks[to_prune_layer_ind][to_prune_filter_ind] = 0.

    print('Prune filter #{} in layer #{}'.format(
        to_prune_filter_ind, 
        to_prune_layer_ind))

    return masks


def filter_prune(model, pruning_perc):
    '''
    Prune filters one by one until reach pruning_perc
    (not iterative pruning)

    Each round picks one more filter with ``prune_one_filter``, applies the
    masks through ``model.set_masks`` and re-measures the pruned percentage
    with ``prune_rate``. ``prune_rate`` divides by the number of ALL
    parameters, so a target larger than the conv layers' share can never be
    reached; the loop stops with a warning as soon as a round fails to raise
    the pruned percentage instead of spinning forever.

    Args:
        model: network exposing ``parameters()`` and ``set_masks(masks)``.
        pruning_perc: target percentage (0-100) of zeroed parameters.

    Returns:
        The list of per-layer numpy masks that was last applied to ``model``.
    '''
    masks = []
    current_pruning_perc = 0.

    while current_pruning_perc < pruning_perc:
        masks = prune_one_filter(model, masks)
        model.set_masks(masks)
        updated_pruning_perc = prune_rate(model, verbose=False)
        print('{:.2f} pruned'.format(updated_pruning_perc))

        if not updated_pruning_perc > current_pruning_perc:
            # no progress: another round would select the same filter again
            print('Warning: pruning stalled at {:.2f}%, target {:.2f}% '
                  'cannot be reached'.format(updated_pruning_perc,
                                             pruning_perc))
            break
        current_pruning_perc = updated_pruning_perc

    return masks

In [6]:
def prune_rate(model, verbose=True):
    """
    Report the share of exactly-zero weights, per layer and overall.

    Every parameter counts towards the total size. Zeros are only counted in
    parameters that are not 1-D; those are numbered from 1 in iteration order
    and labelled 'Conv' when 4-D and 'Linear' otherwise.

    Args:
        model: object whose ``parameters()`` yields torch tensors.
        verbose: when true, print one line per counted layer and a summary.

    Returns:
        Overall pruned percentage: 100 * counted zeros / total parameters.
    """
    n_params_total = 0
    n_zeros_total = 0
    counted_layers = 0

    for tensor in model.parameters():
        shape = tensor.data.size()
        n_params = tensor.data.numel()
        n_params_total += n_params

        # 1-D parameters only add to the total; they are never pruned
        if len(shape) == 1:
            continue

        counted_layers += 1
        n_zeros = np.count_nonzero(tensor.cpu().data.numpy() == 0)
        n_zeros_total += n_zeros

        if verbose:
            layer_kind = 'Conv' if len(shape) == 4 else 'Linear'
            print("Layer {} | {} layer | {:.2f}% parameters pruned".format(
                counted_layers, layer_kind, 100.*n_zeros/n_params))

    pruning_perc = 100.*n_zeros_total/n_params_total
    if verbose:
        print("Final pruning rate: {:.2f}%".format(pruning_perc))
    return pruning_perc

In [7]:
def arg_nonzero_min(a):
    """
    nonzero argmin of a non-negative array

    Args:
        a: sequence (list or 1-D array) of non-negative numbers.

    Returns:
        ``None`` when ``a`` is ``None`` or empty; ``(np.inf, np.inf)`` (after
        printing a warning) when every entry is zero; otherwise
        ``(min_value, min_index)`` of the smallest nonzero entry. On ties the
        earliest index wins, except that the last nonzero entry is kept when
        it is itself a minimum.
    """

    if a is None or len(a) == 0:
        return

    nonzero = [(i, e) for i, e in enumerate(a) if e != 0]
    # Test for "no nonzero entry" explicitly: a truthiness test on the index
    # would mistake a valid index 0 for "not found".
    if not nonzero:
        print('Warning: all zero')
        return np.inf, np.inf

    # start from the last nonzero entry, then search for the smallest one
    min_ix, min_v = nonzero[-1]
    for i, e in nonzero:
        if e < min_v:
            min_v = e
            min_ix = i

    return min_v, min_ix

# Load the MNIST dataset

In [8]:
# Seed torch's global random number generator so runs are repeatable.
# The result is bound to `_` so the cell has no value to echo.
_ = torch.manual_seed(0)

In [9]:
# Convert images to tensors, then normalise with mean 0.1307 / std 0.3081
# (presumably the MNIST pixel statistics — TODO confirm).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST training split from ./data (download=True is requested)
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
# NOTE(review): the batch size is the literal 10 here; the `batch_size`
# hyperparameter is only defined in a later cell and is not referenced.
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# NOTE(review): the test loader is shuffled as well.
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device; read as a global by train(), test() and the model cells
device = "cpu"

# Define the non-pruned model

In [10]:
class ConvNet(nn.Module):
    """Baseline (unpruned) CNN: three 3x3 conv stages and one linear layer.

    The 3x3 convolutions use padding 1 and stride 1 and the first two stages
    end in a 2x2 max-pool, so the ``7*7*64`` input size of the linear layer
    corresponds to 28x28 single-channel images. The output has 10 values per
    sample.
    """

    def __init__(self):
        super().__init__()

        # stage 1: 1 -> 32 channels, then 2x2 max-pool
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32,
                               kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        # stage 2: 32 -> 64 channels, then 2x2 max-pool
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,
                               kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        # stage 3: 64 -> 64 channels, no pooling
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64,
                               kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.ReLU(inplace=True)

        # classifier on the flattened feature maps
        self.linear1 = nn.Linear(in_features=7*7*64, out_features=10)

    def forward(self, x):
        """Return the ``(batch, 10)`` outputs for a batch of images ``x``."""
        stage1 = self.maxpool1(self.relu1(self.conv1(x)))
        stage2 = self.maxpool2(self.relu2(self.conv2(stage1)))
        stage3 = self.relu3(self.conv3(stage2))
        flat = stage3.view(stage3.size(0), -1)
        return self.linear1(flat)

# Hyperparameters and setting the model

In [11]:
# Hyperparameters
# NOTE(review): `batch_size` is not referenced in the visible cells; the
# DataLoaders are built with the literal batch_size=10.
batch_size = 10
# Read as a global by train() when it builds the Adam optimizer.
learning_rate = 0.001
# Passed to train() for the baseline run; train() also reads it as a global
# in its progress message.
num_epochs = 1

In [12]:
# Instantiate the baseline (unpruned) network on `device`.
model = ConvNet().to(device)

In [13]:
def train(train_loader, model, epochs=5, log_every=500):
    """Train ``model`` in place with Adam and cross-entropy loss.

    Reads the module-level ``learning_rate`` (optimizer step size) and
    ``device`` (where each batch is moved). A fresh optimizer is created on
    every call.

    Args:
        train_loader: iterable of ``(images, labels)`` batches that also
            supports ``len()``.
        model: network to optimise; switched to training mode.
        epochs: number of passes over ``train_loader``.
        log_every: print the current loss every this many steps; a falsy
            value disables progress output.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.train()

    steps_per_epoch = len(train_loader)

    for epoch in range(epochs):
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels)

            if log_every and (i+1) % log_every == 0:
                # report against this call's `epochs`, not the global
                # `num_epochs`, so the total is right for any argument
                print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{steps_per_epoch}], Loss: {loss.item():.4f}')

            loss.backward()
            optimizer.step()
           

In [14]:
# Train the baseline network for `num_epochs` passes over the training loader.
train(train_loader, model, epochs=num_epochs)

Epoch [1/1], Step [500/6000], Loss: 0.0115
Epoch [1/1], Step [1000/6000], Loss: 0.0255
Epoch [1/1], Step [1500/6000], Loss: 0.0517
Epoch [1/1], Step [2000/6000], Loss: 0.0784
Epoch [1/1], Step [2500/6000], Loss: 0.0009
Epoch [1/1], Step [3000/6000], Loss: 0.0062
Epoch [1/1], Step [3500/6000], Loss: 0.0125
Epoch [1/1], Step [4000/6000], Loss: 0.0015
Epoch [1/1], Step [4500/6000], Loss: 0.0003
Epoch [1/1], Step [5000/6000], Loss: 0.0885
Epoch [1/1], Step [5500/6000], Loss: 0.0506
Epoch [1/1], Step [6000/6000], Loss: 0.0358


In [15]:
def test(test_loader,model):
    """Print the classification accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch is moved to the module-level ``device``. The predicted class
    of a sample is the argmax over dimension 1 of the model output, so the
    output is assumed to be ``(batch, classes)``.

    Args:
        test_loader: iterable of ``(images, labels)`` batches.
        model: network to evaluate.

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            # count matches for the whole batch at once
            correct += (predicted == labels).sum().item()
            total += predicted.size(0)

    print(f'Accuracy: {round(correct/total, 3)}')

In [16]:
# Accuracy of the trained baseline network on the test loader.
test(test_loader,model)

Accuracy: 0.989


In [17]:
# Checkpoint the baseline weights (state dict only); a later cell loads
# 'CNN.pt' into the masked network.
torch.save(model.state_dict(), 'CNN.pt')

In [27]:
class ConvNet(nn.Module):
    """Prunable variant of the CNN: the three convs are ``MaskedConv2d``.

    The layer attribute names and shapes are the same as in the baseline
    definition (three 3x3 conv stages, ``7*7*64 -> 10`` linear layer); only
    the conv layers gain mask support. The linear layer stays a plain
    ``nn.Linear`` and is never masked.

    NOTE(review): this re-uses the name ``ConvNet`` and therefore shadows the
    baseline class defined earlier in the notebook.
    """

    def __init__(self):
        super(ConvNet, self).__init__()

        self.conv1 = MaskedConv2d(1, 32, kernel_size=3, padding=1, stride=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(2)

        self.conv2 = MaskedConv2d(32, 64, kernel_size=3, padding=1, stride=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(2)

        self.conv3 = MaskedConv2d(64, 64, kernel_size=3, padding=1, stride=1)
        self.relu3 = nn.ReLU(inplace=True)

        self.linear1 = nn.Linear(7*7*64, 10)
        
    def forward(self, x):
        """Return the ``(batch, 10)`` outputs for a batch of images ``x``."""
        out = self.maxpool1(self.relu1(self.conv1(x)))
        out = self.maxpool2(self.relu2(self.conv2(out)))
        out = self.relu3(self.conv3(out))
        out = out.view(out.size(0), -1)
        out = self.linear1(out)
        return out

    def set_masks(self, masks):
        """Apply one mask per ``MaskedConv2d`` layer, in registration order.

        Args:
            masks: sequence of masks (numpy arrays or tensors), one per
                masked conv layer in the order conv1, conv2, conv3. Extra
                trailing entries are ignored.

        Raises:
            IndexError: if fewer masks than masked conv layers are given.
        """
        masked_layers = [m for m in self.modules()
                         if isinstance(m, MaskedConv2d)]
        if len(masks) < len(masked_layers):
            raise IndexError(
                'expected {} masks, got {}'.format(len(masked_layers),
                                                   len(masks)))
        for layer, mask in zip(masked_layers, masks):
            # as_tensor wraps numpy arrays like from_numpy and passes
            # tensors through unchanged
            layer.set_mask(torch.as_tensor(mask))

In [28]:
# Build the masked network and initialise it from the baseline checkpoint.
# load_state_dict relies on both ConvNet definitions using the same
# parameter names (conv1..conv3, linear1).
newmodel = ConvNet().to(device)
newmodel.load_state_dict(torch.load('CNN.pt'))
print('Loaded model from disk')

Loaded model from disk


In [29]:
# Sanity check: the reloaded, not yet pruned network on the test loader.
test(test_loader,newmodel)

Accuracy: 0.989


In [32]:
# prune the weights
# Target percentage of zeroed parameters, as measured by prune_rate().
pruning_perc =  50.
masks = filter_prune(newmodel, pruning_perc)
# filter_prune() already calls newmodel.set_masks() after every pruned filter;
# this call applies the returned masks once more.
newmodel.set_masks(masks)
print("--- {}% parameters pruned ---".format(pruning_perc))
# Accuracy right after pruning, before any retraining.
test(test_loader, newmodel)

Prune filter #1 in layer #1
0.99 pruned
Prune filter #61 in layer #1
1.32 pruned
Prune filter #11 in layer #1
1.65 pruned
Prune filter #59 in layer #1
1.98 pruned
Prune filter #16 in layer #1
2.31 pruned
Prune filter #26 in layer #1
2.64 pruned
Prune filter #23 in layer #0
2.66 pruned
Prune filter #42 in layer #1
2.99 pruned
Prune filter #31 in layer #2
3.65 pruned
Prune filter #29 in layer #0
3.66 pruned
Prune filter #14 in layer #2
4.32 pruned
Prune filter #52 in layer #1
4.65 pruned
Prune filter #37 in layer #1
4.98 pruned
Prune filter #15 in layer #2
5.64 pruned
Prune filter #55 in layer #1
5.97 pruned
Prune filter #56 in layer #2
6.63 pruned
Prune filter #15 in layer #0
6.64 pruned
Prune filter #24 in layer #2
7.30 pruned
Prune filter #13 in layer #0
7.31 pruned
Prune filter #32 in layer #2
7.98 pruned
Prune filter #12 in layer #2
8.64 pruned
Prune filter #7 in layer #1
8.97 pruned
Prune filter #34 in layer #2
9.63 pruned
Prune filter #63 in layer #1
9.96 pruned
Prune filter #32 i

In [33]:
# Retrain the pruned network for one epoch; the masked conv layers multiply
# their weights by the mask in forward().
train(train_loader, newmodel, epochs=1)

Epoch [1/1], Step [500/6000], Loss: 0.0629
Epoch [1/1], Step [1000/6000], Loss: 0.0038
Epoch [1/1], Step [1500/6000], Loss: 0.0302
Epoch [1/1], Step [2000/6000], Loss: 0.1966
Epoch [1/1], Step [2500/6000], Loss: 0.0041
Epoch [1/1], Step [3000/6000], Loss: 0.0045
Epoch [1/1], Step [3500/6000], Loss: 0.1324
Epoch [1/1], Step [4000/6000], Loss: 0.0145
Epoch [1/1], Step [4500/6000], Loss: 0.0699
Epoch [1/1], Step [5000/6000], Loss: 0.1469
Epoch [1/1], Step [5500/6000], Loss: 0.0438
Epoch [1/1], Step [6000/6000], Loss: 0.0004


In [34]:
# Check accuracy and the share of zeroed weights in each layer
print("--- After retraining ---")
test(test_loader,newmodel)
# Last expression of the cell: besides printing the per-layer report, the
# returned overall percentage is echoed as the cell's value.
prune_rate(newmodel)

--- After retraining ---
Accuracy: 0.988
Layer 1 | Conv layer | 59.38% parameters pruned
Layer 2 | Conv layer | 81.25% parameters pruned
Layer 3 | Conv layer | 78.12% parameters pruned
Layer 4 | Linear layer | 0.00% parameters pruned
Final pruning rate: 50.45%


50.44768923479578

In [35]:
# Save the pruned and retrained weights (state dict only, not the whole model).
# `newmodel` is the masked network; saving `model` here would store the
# unpruned baseline under the pruned file name.
torch.save(newmodel.state_dict(), 'CNN_Pruned.pt')