In [1]:
import torch
import torch.optim
import torch.utils.data
import torchvision
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import random
import numpy as np
import time

# Import CIFAR-10 dataset

In [2]:
def load_dataset(batch_size):
    mean = [x/255 for x in [125.30691805, 122.95039414, 113.86538318]]
    std = [x/255 for x in [62.99321928, 62.08870764, 66.70489964]]
    
    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
    
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)
    
    return train_loader, test_loader

# Mask class for pruning

In [3]:
# TODO: pruning epoch parameter
# save model and record time taken to predict/validate on test set (inference time)
# plot training curves

class Mask:
    """
    Used to determine the least imporant filters and to prune them
    """
    
    def __init__(self, model, rate):
        self.model = model
        self.rate = rate
        self.model_size = {}
        self.model_length = {}
        self.masks = {}
        
    
    def init_length(self):
        """
        Initialize size of parameters for each layer
        model_size: shape of weights
        model_length: number of parameters
        """
        for index, item in enumerate(self.model.parameters()):
            self.model_size[index] = item.size()
            self.model_length[index] = np.prod(list(item.size()))

    
    def get_mask(self, weights, length):
        """
        Get mask for one conv layer
        """
        mask = np.ones(length)
        if len(weights.size())==4: # check if it is conv layer
            num_prune = int(weights.size()[0]*self.rate)
            kernel_params = weights.view(weights.size()[0],-1)
            kernel_length = kernel_params.size()[1]
            
            norm = torch.norm(kernel_params, 2, 1).cpu().numpy()
            sorted_index = norm.argsort()[:num_prune]
                       
            for i in range(len(sorted_index)):
                index = sorted_index[i]
                # set elements corresponding to all weights in that kernel to zero
                mask[index*kernel_length:(index+1)*kernel_length] = 0    
        
        return mask
            

    def init_mask(self):
        """
        Get masks for each conv layer
        """
        for index, item in enumerate(self.model.parameters()):
            self.masks[index] = self.get_mask(item.data, self.model_length[index])
            self.masks[index] = torch.FloatTensor(self.masks[index])
            
    def do_mask(self):
        """
        Set weights of pruned kernels to zero by multiplying weights with the mask
        """
        for index, item in enumerate(self.model.parameters()):
            weights = item.data.view(self.model_length[index])
            masked_weights = self.masks[index] * weights
            item.data = masked_weights.view(self.model_size[index])
            

    def count_zero(self): # should we change this to sum across all layers?
        """
        Prints number of zero and non-zero weights in each layer
        """
        for index, item in enumerate(self.model.parameters()):
            if len(item.data.size())==4: # check if is conv layer
                weights = item.data.view(self.model_length[index])
                weights = weights.cpu().numpy()
                non_zero = np.count_nonzero(weights)
                zero = len(weights) - non_zero

                print("Layer: {} - No. of zero weights = {} - No. of non-zero weights = {}".format(
                    index, zero, non_zero))

            

# Model

In [136]:
model = models.resnet18(pretrained=True)

In [138]:
m = Mask(model, 0.2)
m.init_length()
m.init_mask()
m.do_mask()

In [139]:
# check that masking function is working
for index, item in enumerate(model.parameters()):
    for i in range(item.data.size()[0]): 
        print(i)
        print(item.data[i])
    break

0
tensor([[[-1.0419e-02, -6.1356e-03, -1.8098e-03,  7.4841e-02,  5.6615e-02,
           1.7083e-02, -1.2694e-02],
         [ 1.1083e-02,  9.5276e-03, -1.0993e-01, -2.8050e-01, -2.7124e-01,
          -1.2907e-01,  3.7424e-03],
         [-6.9434e-03,  5.9089e-02,  2.9548e-01,  5.8720e-01,  5.1972e-01,
           2.5632e-01,  6.3573e-02],
         [ 3.0505e-02, -6.7018e-02, -2.9841e-01, -4.3868e-01, -2.7085e-01,
          -6.1282e-04,  5.7602e-02],
         [-2.7535e-02,  1.6045e-02,  7.2595e-02, -5.4102e-02, -3.3285e-01,
          -4.2058e-01, -2.5781e-01],
         [ 3.0613e-02,  4.0960e-02,  6.2850e-02,  2.3897e-01,  4.1384e-01,
           3.9359e-01,  1.6606e-01],
         [-1.3736e-02, -3.6746e-03, -2.4084e-02, -6.5877e-02, -1.5070e-01,
          -8.2230e-02, -5.7828e-03]],

        [[-1.1397e-02, -2.6619e-02, -3.4641e-02,  3.6812e-02,  3.2521e-02,
           6.6221e-04, -2.5743e-02],
         [ 4.5687e-02,  3.3603e-02, -1.0453e-01, -3.0885e-01, -3.1253e-01,
          -1.6051e-01, -1

In [4]:
def metrics(model, device, criterion, data_loader):
    """
    Returns loss and accuracy
    """
    model.eval()
    loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            data = torch.autograd.Variable(data)
            target = torch.autograd.Variable(target)
            
            output = model(data)
            loss += criterion(output, target)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    loss /= len(data_loader)
    acc = 100. * correct / len(data_loader.dataset)
    
    
    return loss, acc
            

In [5]:
def train(model, device, criterion, train_loader, test_loader, optimizer, epoch):
    """
    Forward and backward pass
    """
    model.train()
    start_time = time.time()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        data = torch.autograd.Variable(data)
        target = torch.autograd.Variable(target)
        
        output = model(data)
        loss = criterion(output, target)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    end_time = time.time()
    time_taken = end_time - start_time
    train_loss, train_acc = metrics(model, device, criterion, train_loader)
    print("Train Epoch: {} - Time taken: {:.1f}s\nTrain Loss: {:.4f} - Train Accuracy: {:.1f}%".format(epoch, time_taken, train_loss, train_acc))
    
    return train_loss, train_acc

In [6]:
def run(arch='resnet18', pruning_rate=0.1, epochs=10, batch_size=128, learning_rate=0.1, momentum=0.9, decay=0.0005):
    """
    Do training and pruning
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = load_dataset(batch_size)
    num_classes = 10 # cifar-10
    model = models.__dict__[arch](num_classes)
    
    model = model.to(device)
    
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=decay, nesterov=True)
    criterion = torch.nn.CrossEntropyLoss()
    
    _, test_acc = metrics(model, device, criterion, test_loader)
    print("Test Accuracy before pruning: {:.1f}%".format(test_acc))
    
    m = Mask(model.cpu(), pruning_rate)
    m.init_length()
    
    # prune the filters
    m.init_mask()
    m.do_mask()
    model = m.model
    model.to(device)
    
    _, test_acc = metrics(model, device, criterion, test_loader)
    print("Test Accuracy after pruning: {:.1f}%".format(test_acc))
    
    
    # Main training loop
    for epoch in range(1, epochs+1):
        # train for one epoch
        train_loss, train_acc = train(model, device, criterion, train_loader, test_loader, optimizer, epoch)
        
        _, test_acc = metrics(model, device, criterion, test_loader)
        print("Test Accuracy before pruning: {:.1f}%".format(test_acc))
        
        # pruning and comparing
        m.model = model.cpu()
#         print("Before pruning:")
#         m.count_zero()
        m.init_mask()
        m.do_mask()
#         print("After pruning:")
#         m.count_zero()
        model = m.model
        model.to(device)
        
        _, test_acc = metrics(model, device, criterion, test_loader)
        print("Test Accuracy after pruning: {:.1f}%".format(test_acc))


In [10]:
run(arch='resnet18', pruning_rate=0, epochs=50, batch_size=128, learning_rate=0.1)

Files already downloaded and verified
Files already downloaded and verified
Test Accuracy before pruning: 8.5%
Test Accuracy after pruning: 8.5%
Train Epoch: 1 - Time taken: 107.6s
Train Loss: 2.0730 - Train Accuracy: 24.4%
Test Accuracy before pruning: 25.8%
Test Accuracy after pruning: 25.8%
Train Epoch: 2 - Time taken: 117.3s
Train Loss: 1.7402 - Train Accuracy: 34.0%
Test Accuracy before pruning: 36.0%
Test Accuracy after pruning: 36.0%
Train Epoch: 3 - Time taken: 118.8s
Train Loss: 1.5959 - Train Accuracy: 41.3%
Test Accuracy before pruning: 42.8%
Test Accuracy after pruning: 42.8%
Train Epoch: 4 - Time taken: 120.1s
Train Loss: 1.4500 - Train Accuracy: 46.7%
Test Accuracy before pruning: 49.6%
Test Accuracy after pruning: 49.6%
Train Epoch: 5 - Time taken: 117.9s
Train Loss: 1.4674 - Train Accuracy: 48.0%
Test Accuracy before pruning: 50.5%
Test Accuracy after pruning: 50.5%
Train Epoch: 6 - Time taken: 170.0s
Train Loss: 1.2931 - Train Accuracy: 53.9%
Test Accuracy before pruni

In [11]:
run(arch='resnet18', pruning_rate=0.1, epochs=50, batch_size=128, learning_rate=0.1)

Files already downloaded and verified
Files already downloaded and verified
Test Accuracy before pruning: 8.3%
Test Accuracy after pruning: 9.4%
Train Epoch: 1 - Time taken: 118.1s
Train Loss: 1.8748 - Train Accuracy: 32.9%
Test Accuracy before pruning: 35.4%
Test Accuracy after pruning: 34.7%
Train Epoch: 2 - Time taken: 109.9s
Train Loss: 1.6857 - Train Accuracy: 42.6%
Test Accuracy before pruning: 44.4%
Test Accuracy after pruning: 43.5%
Train Epoch: 3 - Time taken: 111.2s
Train Loss: 1.4273 - Train Accuracy: 48.1%
Test Accuracy before pruning: 49.5%
Test Accuracy after pruning: 49.6%
Train Epoch: 4 - Time taken: 110.1s
Train Loss: 1.4467 - Train Accuracy: 49.2%
Test Accuracy before pruning: 51.6%
Test Accuracy after pruning: 51.9%
Train Epoch: 5 - Time taken: 112.0s
Train Loss: 1.2173 - Train Accuracy: 57.4%
Test Accuracy before pruning: 59.0%
Test Accuracy after pruning: 59.1%
Train Epoch: 6 - Time taken: 113.7s
Train Loss: 1.2479 - Train Accuracy: 57.1%
Test Accuracy before pruni

In [None]:
run(arch='googlenet', pruning_rate=0.1, epochs=50, batch_size=128, learning_rate=0.1)

Files already downloaded and verified
Files already downloaded and verified


Downloading: "https://download.pytorch.org/models/googlenet-1378be20.pth" to C:\Users\65842/.cache\torch\checkpoints\googlenet-1378be20.pth


  0%|          | 0.00/49.7M [00:00<?, ?B/s]

Test Accuracy before pruning: 0.0%
Test Accuracy after pruning: 0.1%
Train Epoch: 1 - Time taken: 114.5s
Train Loss: 1.6657 - Train Accuracy: 41.7%
Test Accuracy before pruning: 44.7%
Test Accuracy after pruning: 44.8%
Train Epoch: 2 - Time taken: 149.5s
Train Loss: 1.6584 - Train Accuracy: 39.3%
Test Accuracy before pruning: 39.1%
Test Accuracy after pruning: 38.9%
Train Epoch: 3 - Time taken: 125.7s
Train Loss: 1.4551 - Train Accuracy: 47.9%
Test Accuracy before pruning: 48.5%
Test Accuracy after pruning: 48.4%
Train Epoch: 4 - Time taken: 142.2s
Train Loss: 1.2022 - Train Accuracy: 57.0%
Test Accuracy before pruning: 57.2%
Test Accuracy after pruning: 57.2%
Train Epoch: 5 - Time taken: 119.2s
Train Loss: 1.0932 - Train Accuracy: 61.4%
Test Accuracy before pruning: 62.0%
Test Accuracy after pruning: 62.0%
Train Epoch: 6 - Time taken: 123.7s
Train Loss: 1.0518 - Train Accuracy: 62.8%
Test Accuracy before pruning: 63.2%
Test Accuracy after pruning: 63.2%
Train Epoch: 7 - Time taken: 12