In [1]:
from __future__ import print_function

import os
import time
from collections import OrderedDict
from datetime import datetime

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [2]:
# Training hyper-parameters shared by the cells below.
n_epochs, mini_batch_size = 10, 128    # epochs passed to train_model / DataLoader batch size
learning_rate, momentum = 1e-2, 0.9    # SGD step size / SGD momentum
log_interval = 10                      # log the training loss every this many batches

In [3]:
# Normalization constants. The evaluation pipeline must apply exactly the same
# preprocessing as the training pipeline, otherwise the model is scored on inputs
# with a different scale than the ones it was fitted on. These are the values the
# training pipeline was already using.
NORM_MEAN = (0.1307,)
NORM_STD = (0.3081,)

# Resize to 32x32, convert to a tensor, then standardize.
train_transforms = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Same steps and same statistics as the training pipeline (kept as a separate
# name so existing references to `test_transforms` keep working).
test_transforms = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Training batches are reshuffled every epoch so SGD does not see the samples
# in the same fixed order each time.
train_loader = DataLoader(
    torchvision.datasets.MNIST('data/', train=True, download=True, transform=train_transforms),
    batch_size=mini_batch_size,
    shuffle=True,
)

# Evaluation order does not affect the metrics, so the test loader stays unshuffled.
test_loader = DataLoader(
    torchvision.datasets.MNIST('data/', train=False, download=True, transform=test_transforms),
    batch_size=mini_batch_size,
)

In [4]:
class LeNet(nn.Module):
    """LeNet-style CNN producing log-probabilities over 10 classes.

    The feature extractor (`convnet`) stacks three 5x5 convolutions, the first
    two each followed by ReLU and 2x2 max-pooling, the third by ReLU only. The
    classifier (`fc`) is Linear(120, 84) -> ReLU -> Linear(84, 10) -> LogSoftmax.
    With single-channel 32x32 inputs the feature extractor yields 120 values
    per sample, which is what the first linear layer expects.
    """

    def __init__(self):
        super().__init__()
        kernel_5x5 = (5, 5)
        pool_2x2 = (2, 2)

        # Layers are created in network order; the keys become attribute names
        # (e.g. `convnet.c1`) and state_dict prefixes.
        feature_layers = OrderedDict()
        feature_layers['c1'] = nn.Conv2d(1, 6, kernel_size=kernel_5x5)
        feature_layers['relu1'] = nn.ReLU()
        feature_layers['s2'] = nn.MaxPool2d(kernel_size=pool_2x2, stride=2)
        feature_layers['c3'] = nn.Conv2d(6, 16, kernel_size=kernel_5x5)
        feature_layers['relu2'] = nn.ReLU()
        feature_layers['s4'] = nn.MaxPool2d(kernel_size=pool_2x2, stride=2)
        feature_layers['c5'] = nn.Conv2d(16, 120, kernel_size=kernel_5x5)
        feature_layers['relu3'] = nn.ReLU()
        self.convnet = nn.Sequential(feature_layers)

        classifier_layers = OrderedDict()
        classifier_layers['f6'] = nn.Linear(120, 84)
        classifier_layers['relu6'] = nn.ReLU()
        classifier_layers['f7'] = nn.Linear(84, 10)
        classifier_layers['sig7'] = nn.LogSoftmax(dim=-1)
        self.fc = nn.Sequential(classifier_layers)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images `x`."""
        features = self.convnet(x)
        # Collapse everything except the batch dimension before the linear layers.
        flat = features.reshape(features.shape[0], -1)
        return self.fc(flat)


In [15]:
class ClassifierTraining:
    """Train, evaluate and plot a classifier that outputs log-probabilities.

    The model is moved to CUDA when available (CPU otherwise) and optimized with
    SGD. The optimizer settings and the logging cadence come from the
    notebook-level names ``learning_rate``, ``momentum`` and ``log_interval``.

    Args:
        classifer_model: the ``nn.Module`` to train; its outputs are fed to
            ``F.nll_loss`` and are therefore expected to be log-probabilities.
        train_loader: loader yielding ``(data, target)`` training batches.
        test_loader: loader yielding ``(data, target)`` evaluation batches.
        prune_version: tag used in the names of the files written under
            ``results/``; ``None`` means ``'base'``.
    """

    def __init__(self, classifer_model, train_loader: DataLoader, test_loader: DataLoader, prune_version=None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = classifer_model.to(self.device)
        self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate, momentum=momentum)
        # Logged losses and the number of training samples seen at each log point.
        self.train_losses = []
        self.train_counter = []
        self.test_counter = []
        self.test_losses = []
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.version = 'base' if prune_version is None else prune_version

    def train(self, epoch, verbose=True):
        """Run one training epoch (``epoch`` is 1-based).

        Every ``log_interval`` batches the current loss is recorded, optionally
        printed, and the model and optimizer states are checkpointed to
        ``results/model_<version>.pth`` / ``results/optimizer_<version>.pth``.
        """
        print(self.device)
        os.makedirs('results/', exist_ok=True)

        # test() switches the model to eval mode; switch back before optimizing.
        self.model.train()

        dataset_size = len(self.train_loader.dataset)
        n_batches = len(self.train_loader)

        # Iterate the loader this instance was given, not a notebook global.
        for batch_idx, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = F.nll_loss(output, target)  # NLL on log-probabilities
            loss.backward()
            self.optimizer.step()

            if batch_idx % log_interval == 0:
                # Use the loader's real batch size; fall back to the current
                # batch when the loader was built without a fixed batch_size.
                batch_size = self.train_loader.batch_size or len(data)
                samples_seen = batch_idx * batch_size
                if verbose:
                    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, samples_seen, dataset_size,
                        100. * batch_idx / n_batches, loss.item()))
                self.train_losses.append(loss.item())
                self.train_counter.append(samples_seen + (epoch - 1) * dataset_size)
                torch.save(self.model.state_dict(), f'results/model_{self.version}.pth')
                torch.save(self.optimizer.state_dict(), f'results/optimizer_{self.version}.pth')

    def test(self):
        """Evaluate on ``test_loader``; record and print the mean loss and accuracy."""
        self.model.eval()
        total_loss = 0.0
        correct = 0
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                # Sum per-sample losses so that dividing by the dataset size
                # below yields the true per-sample average.
                total_loss += F.nll_loss(output, target, reduction='sum').item()
                pred = output.argmax(dim=1)
                correct += (pred == target).sum().item()

        n_samples = len(self.test_loader.dataset)
        test_loss = total_loss / n_samples
        self.test_losses.append(test_loss)
        print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, n_samples,
            100. * correct / n_samples))

    def train_model(self, n_epoch):
        """Train for ``n_epoch`` epochs, evaluating after each one, then plot the curves."""
        train_start_dt = datetime.now()

        for ep in range(1, n_epoch + 1):
            self.train(ep)
            self.test()
            # The evaluation happens after `ep` full passes over the training
            # set; appending keeps test_counter aligned with test_losses.
            self.test_counter.append(ep * len(self.train_loader.dataset))
        train_end_dt = datetime.now()
        train_duration = (train_end_dt - train_start_dt)

        print(f"Total Training Time : {train_duration}")
        self.plot_train_test()

    def plot_train_test(self):
        """Plot logged training losses and per-epoch test losses; save the figure under ``results/``."""
        # Allow plotting even if train() has not created the directory yet.
        os.makedirs('results/', exist_ok=True)
        fig = plt.figure()
        plt.plot(self.train_counter, self.train_losses, color='blue')
        plt.scatter(self.test_counter, self.test_losses, color='red')
        plt.legend(['Training Loss', 'Test Loss'], loc='upper right')
        plt.xlabel('No of Training Samples')
        plt.ylabel('Loss Function (NLL)')
        fig.savefig(f'results/training_curve_{self.version}.png')

In [17]:
# Random global pruning of the base LeNet, followed by fine-tuning.

# Tag for the files the trainer writes, so fine-tuning the pruned model does not
# overwrite the base checkpoint that is loaded just below.
prune_version = 'pruned'

prune_model = LeNet()

# map_location lets the checkpoint load on a CPU-only machine even if it was
# saved from a model living on a CUDA device.
prune_model.load_state_dict(torch.load('results/model_base.pth', map_location='cpu'))

# Weight tensors of every conv / linear layer take part in the global pruning.
parameters_to_prune = (
    (prune_model.convnet.c1, 'weight'),
    (prune_model.convnet.c3, 'weight'),
    (prune_model.convnet.c5, 'weight'),
    (prune_model.fc.f6, 'weight'),
    (prune_model.fc.f7, 'weight'),
)

# The mask is drawn at random; seed it so a re-run prunes the same weights.
torch.manual_seed(0)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.RandomUnstructured,
    amount=0.1,
)

# Fine-tune while the pruning masks are still attached: the masked weights
# receive no gradient and therefore stay at zero. Removing the masks before
# training would let the optimizer regrow every pruned weight.
pruned_classifer = ClassifierTraining(
    classifer_model=prune_model,
    train_loader=train_loader,
    test_loader=test_loader,
    prune_version=prune_version,
)
pruned_classifer.train_model(n_epoch=n_epochs)

# Make the pruning permanent only after fine-tuning.
for module, param_name in parameters_to_prune:
    prune.remove(module, param_name)

# Checkpoints written during training hold the masked parametrization
# (`weight_orig` / `weight_mask`); store the finalized weights under the plain
# `weight` keys so the file loads into a fresh LeNet.
torch.save(prune_model.state_dict(), f'results/model_{prune_version}.pth')

# NOTE(review): `get_pruned_parameters_count` and `total_params_count` are not
# defined in the visible cells — presumably they come from an earlier cell; confirm.
# The original stored this result in an unused name and printed a different
# variable; the value computed here is now the one reported (assumes the helper
# returns the pruned model's parameter count — TODO confirm).
pruned_params = get_pruned_parameters_count(prune_model)
print('Original Model paramete count:', total_params_count)
print('Pruned Model parameter count:', pruned_params)
print(f'Compressed Percentage: {(100 - (pruned_params / total_params_count) * 100)}%')