### Pruning: A complete guide for beginners

Goals
In this assignment, you will practice pruning a classical neural network model to reduce both model size and latency. The goals of this assignment are as follows:

* Understand the basic concept of pruning
* Implement and apply fine-grained pruning
* Implement and apply channel pruning
* Get a basic understanding of performance improvement (such as speedup) from pruning
* Understand the differences and tradeoffs between these pruning approaches

In [2]:
!pip install torchprofile 

Collecting torchprofile
  Downloading torchprofile-0.0.4-py3-none-any.whl.metadata (303 bytes)
Downloading torchprofile-0.0.4-py3-none-any.whl (7.7 kB)
Installing collected packages: torchprofile
Successfully installed torchprofile-0.0.4


In [3]:
import copy
import math
import random
import time
from collections import OrderedDict, defaultdict
from typing import Union, List

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm

assert torch.cuda.is_available()

In [4]:
# Seed every random number generator used in this notebook (Python's `random`,
# NumPy and PyTorch) with the same fixed value so that runs are repeatable.
random.seed(0)
np.random.seed(0)
# Kept last on purpose: as the final expression of the cell, its return value
# (the torch Generator) is what the notebook displays.
torch.manual_seed(0)

<torch._C.Generator at 0x23301aba090>

In [14]:

def download_url(url, model_dir='.', overwrite=False):
    """
    download a file into `model_dir` (unless it is already there) and return its path
    :param url: address of the file; the text after the last '/' is used as the file name
    :param model_dir: directory that stores the file ('~' is expanded, created if missing)
    :param overwrite: download again even when the file already exists locally
    :return: path of the local file, whether it was just downloaded or already cached
    :raises: whatever `urlretrieve` raises when the download fails
    """
    import os
    import sys
    from urllib.request import urlretrieve

    filename = url.split('/')[-1]
    model_dir = os.path.expanduser(model_dir)
    os.makedirs(model_dir, exist_ok=True)
    cached_file = os.path.join(model_dir, filename)

    if overwrite or not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # Download next to the final location and rename afterwards, so an
        # interrupted transfer never leaves a truncated file that a later call
        # would mistake for a valid cached copy.
        partial_file = cached_file + '.part'
        try:
            urlretrieve(url, partial_file)
            os.replace(partial_file, cached_file)
        except BaseException:
            if os.path.exists(partial_file):
                os.remove(partial_file)
            raise

    # Always hand back the path: callers pass it straight to a loader, so the
    # cached case must not return None.
    return cached_file

In [6]:
class VGG(nn.Module):
    """
    VGG-style convolutional classifier.

    The backbone is built from ARCH: an integer adds a conv(3x3) -> batchnorm ->
    relu group with that many output channels, and 'M' adds a 2x2 max-pool.
    The classifier is a single linear layer with 10 outputs.
    """
    ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        layers = []
        counts = defaultdict(int)

        def add(name: str, layer: nn.Module) -> None:
            # Number layers per kind: conv0, conv1, ..., pool0, ...
            layers.append((f"{name}{counts[name]}", layer))
            counts[name] += 1

        # NOTE(review): these names become the state_dict keys
        # (e.g. "backbone.batchnorm0.weight"); confirm they match the keys of
        # any checkpoint that is loaded into this model.
        in_channels = 3
        for x in self.ARCH:
            if x != 'M':
                add('conv', nn.Conv2d(in_channels, x, 3, padding=1, bias=False))
                add('batchnorm', nn.BatchNorm2d(x))
                add('relu', nn.ReLU(True))
                # The next conv consumes what this one produced.
                in_channels = x
            else:
                add('pool', nn.MaxPool2d(2))

        self.backbone = nn.Sequential(OrderedDict(layers))
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        """Map a batch of images [N, 3, H, W] to class scores [N, 10]."""
        # backbone: [N, 3, 32, 32] => [N, 512, 2, 2]
        x = self.backbone(x)

        # avgpool: [N, 512, 2, 2] => [N, 512]
        x = x.mean([2, 3])
        x = self.classifier(x)
        return x

In [7]:
def train(
        model: nn.Module,
        dataloader: DataLoader,
        criterion: nn.Module,
        optimizer: Optimizer,
        scheduler: LambdaLR,
        callbacks = None
) -> None:
    """
    run one pass over `dataloader`, updating `model` after every batch
    :param model: network to optimize (switched to training mode)
    :param dataloader: yields (inputs, targets) batches
    :param criterion: loss called as criterion(outputs, targets)
    :param optimizer: stepped once per batch
    :param scheduler: stepped once per batch, right after the optimizer
    :param callbacks: optional iterable of zero-argument callables, each invoked
        after every batch
    """
    model.train()
    hooks = callbacks if callbacks is not None else ()

    for batch_inputs, batch_targets in tqdm(dataloader, desc='train'):
        # move the batch to the GPU
        batch_inputs, batch_targets = batch_inputs.cuda(), batch_targets.cuda()

        # clear gradients left over from the previous batch
        optimizer.zero_grad()

        # forward pass and loss
        loss = criterion(model(batch_inputs), batch_targets)

        # backpropagation
        loss.backward()

        # update the weights, then the learning rate
        optimizer.step()
        scheduler.step()

        for hook in hooks:
            hook()

In [8]:
@torch.no_grad()
def evaluate(
        model: nn.Module,
        dataloader: DataLoader,
        verbose=True
) -> float:
    """
    measure the top-1 accuracy of `model` on `dataloader`
    :param model: network to evaluate (switched to eval mode)
    :param dataloader: yields (inputs, targets) batches
    :param verbose: show the progress bar when True
    :return: accuracy in percent
    """
    # Gradients are disabled by the decorator: evaluation never calls
    # backward(), so building the autograd graph would only cost memory.
    model.eval()

    num_samples = 0
    num_correct = 0

    for inputs, targets in tqdm(dataloader, desc='eval', disable=not verbose):
        inputs = inputs.cuda()
        targets = targets.cuda()

        outputs = model(inputs)

        # convert the class scores to predicted class indices
        predictions = outputs.argmax(dim=1)

        num_samples += targets.size(0)
        num_correct += (predictions == targets).sum()

    return (num_correct / num_samples * 100).item()

Helper functions

In [9]:
def get_model_macs(model, inputs) -> int:
    """
    return the value reported by `profile_macs` for `model` run on `inputs`
    """
    macs = profile_macs(model, inputs)
    return macs


def get_sparsity(tensor: torch.Tensor) -> float:
    """
    fraction of the entries of `tensor` that are exactly zero
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    nonzero_fraction = float(tensor.count_nonzero()) / tensor.numel()
    return 1 - nonzero_fraction


def get_model_sparsity(model: nn.Module) -> float:
    """
    fraction of zero entries over all parameters of `model`
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    params = list(model.parameters())
    nonzero_total = sum(p.count_nonzero() for p in params)
    element_total = sum(p.numel() for p in params)
    return 1 - float(nonzero_total) / element_total

def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param count_nonzero_only: only count nonzero weights
    :return: the count as a plain Python int in both modes
    """
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            num_counted_elements += param.count_nonzero()
        else:
            num_counted_elements += param.numel()
    # count_nonzero() yields 0-dim tensors, so the running sum is a tensor in
    # that mode; convert once at the end to honor the declared return type.
    return int(num_counted_elements)


def get_model_size(model: nn.Module, data_width=32, count_nonzero_only=False) -> int:
    """
    size of the model in bits: number of counted parameters times `data_width`
    :param data_width: #bits per element
    :param count_nonzero_only: only count nonzero weights
    """
    num_elements = get_num_parameters(model, count_nonzero_only)
    return num_elements * data_width

# Size units expressed in bits, so a size in bits divided by one of these
# constants gives the size in that unit (e.g. size_in_bits / MiB).
Byte = 8
KiB = Byte * 1024
MiB = KiB * 1024
GiB = MiB * 1024

In [10]:
def test_fine_grained_prune(
    test_tensor=torch.tensor([[-0.46, -0.40, 0.39, 0.19, 0.37],
                              [0.00, 0.40, 0.17, -0.15, 0.16],
                              [-0.20, -0.23, 0.36, 0.25, 0.03],
                              [0.24, 0.41, 0.07, 0.13, -0.15],
                              [0.48, -0.09, -0.36, 0.12, 0.45]]),
    test_mask=torch.tensor([[True, True, False, False, False],
                            [False, True, False, False, False],
                            [False, False, False, False, False],
                            [False, True, False, False, False],
                            [True, False, False, False, True]]),
    target_sparsity=0.75, target_nonzeros=None):
    """
    visual and numeric sanity check for fine_grained_prune()
    :param test_tensor: 2-D tensor to prune (a clone is pruned, the argument is untouched)
    :param test_mask: expected mask, compared only when target_nonzeros is None
    :param target_sparsity: sparsity passed to fine_grained_prune()
    :param target_nonzeros: when given, only the number of True entries in the
        returned mask is checked against this value
    """
    def plot_matrix(tensor, ax, title):
        # Color each cell by whether it is zero and print its value on top.
        ax.imshow(tensor.cpu().numpy() == 0, vmin=0, vmax=1, cmap='tab20c')
        ax.set_title(title)
        ax.set_yticklabels([])
        ax.set_xticklabels([])
        num_rows, num_cols = tensor.shape
        for i in range(num_rows):
            for j in range(num_cols):
                # ax.text takes (x, y), i.e. (column, row)
                ax.text(j, i, f'{tensor[i, j].item():.2f}',
                        ha="center", va="center", color="k")

    test_tensor = test_tensor.clone()
    fig, axes = plt.subplots(1,2, figsize=(6, 10))
    ax_left, ax_right = axes.ravel()
    plot_matrix(test_tensor, ax_left, 'dense tensor')

    sparsity_before_pruning = get_sparsity(test_tensor)
    mask = fine_grained_prune(test_tensor, target_sparsity)
    sparsity_after_pruning = get_sparsity(test_tensor)
    sparsity_of_mask = get_sparsity(mask)

    plot_matrix(test_tensor, ax_right, 'sparse tensor')
    fig.tight_layout()
    plt.show()

    print('* Test fine_grained_prune()')
    print(f'    target sparsity: {target_sparsity:.2f}')
    print(f'        sparsity before pruning: {sparsity_before_pruning:.2f}')
    print(f'        sparsity after pruning: {sparsity_after_pruning:.2f}')
    print(f'        sparsity of pruning mask: {sparsity_of_mask:.2f}')

    if target_nonzeros is None:
        passed = test_mask.equal(mask)
    else:
        passed = mask.count_nonzero() == target_nonzeros

    if passed:
        print('* Test passed.')
    else:
        print('* Test failed.')

Load the pretrained model and the CIFAR-10 dataset

In [None]:
# Location of the pretrained VGG checkpoint.
checkpoint_url = "https://hanlab.mit.edu/files/course/labs/vgg.cifar.pretrained.pth"
# download_url() is expected to give back the local path of the checkpoint;
# map_location="cpu" loads the tensors on the CPU regardless of where they were saved.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(download_url(checkpoint_url), map_location="cpu")
model = VGG().cuda()
print(f"=> loading checkpoint '{checkpoint_url}'")
model.load_state_dict(checkpoint['state_dict'])
# Calling recover_model() reloads the checkpoint weights into `model`,
# e.g. to undo pruning experiments.
recover_model = lambda: model.load_state_dict(checkpoint['state_dict'])

In [12]:
# Side length of the square crop taken by the training augmentation.
image_size = 32

# Per-split preprocessing: random crop + horizontal flip for training,
# plain tensor conversion for testing.
transforms = {
    "train": Compose([
        RandomCrop(image_size, padding=4),
        RandomHorizontalFlip(),
        ToTensor(),
    ]),
    "test": ToTensor(),
}

# CIFAR-10 datasets keyed by split name (downloaded on first use).
dataset = {}
for split in ("train", "test"):
    dataset[split] = CIFAR10(
        root="data/cifar10",
        train=(split == "train"),
        download=True,
        transform=transforms[split],
    )

# Matching data loaders; only the training split is shuffled.
dataloader = {}
for split in ("train", "test"):
    dataloader[split] = DataLoader(
        dataset[split],
        batch_size=512,
        shuffle=(split == "train"),
        num_workers=0,
        pin_memory=True,
    )
     

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar10\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [01:38<00:00, 1724360.55it/s]


Extracting data/cifar10\cifar-10-python.tar.gz to data/cifar10
Files already downloaded and verified


First, evaluate the dense model

In [None]:
# Baseline numbers for the unpruned (dense) model: size first, then test accuracy.
dense_model_size = get_model_size(model)
dense_model_accuracy = evaluate(model, dataloader['test'])
print(f"dense model has accuracy={dense_model_accuracy:.2f}%")
print(f"dense model has size={dense_model_size/MiB:.2f} MiB")

### Fine-grained Prune

In this section, we will implement and perform fine-grained pruning.

Fine-grained pruning removes the synapses with lowest importance. The weight tensor $W$ will become sparse after fine-grained pruning, which can be described with **sparsity**:

> $\mathrm{sparsity} := \#\mathrm{Zeros} / \#W = 1 - \#\mathrm{Nonzeros} / \#W$

where $\#W$ is the number of elements in $W$.

In practice, given the target sparsity $s$, the weight tensor $W$ is multiplied element-wise by a binary mask $M$ so that the removed weights become zero:

> $v_{\mathrm{thr}} = \texttt{kthvalue}(Importance, \#W \cdot s)$
>
> $M = Importance > v_{\mathrm{thr}}$
>
> $W = W \cdot M$

where $Importance$ is the importance tensor with the same shape as $W$, $\texttt{kthvalue}(X, k)$ finds the $k$-th smallest value of tensor $X$, and $v_{\mathrm{thr}}$ is the threshold value.

* In step 1, we calculate the number of zeros (num_zeros) after pruning. Note that num_zeros must be an integer. You could use either round() or int() to convert a floating-point number into an integer. Here we use round().
* In step 2, we calculate the importance of weight tensor. Pytorch provides torch.abs(), torch.Tensor.abs(), torch.Tensor.abs_() APIs.
* In step 3, we calculate the pruning threshold so that all synapses with importance smaller than threshold will be removed. Pytorch provides torch.kthvalue(), torch.Tensor.kthvalue(), torch.topk() APIs.
* In step 4, we calculate the pruning mask based on the threshold. 1 in the mask indicates the synapse will be kept, and 0 in the mask indicates the synapse will be removed. mask = importance > threshold. Pytorch provides torch.gt() API.

In [18]:
@torch.no_grad()
def fine_grained_prune(tensor: torch.Tensor, sparsity: float) -> torch.Tensor:
    """
    magnitude-based pruning of a single tensor, applied in place
    :param tensor: weight of layer; the pruned entries are set to zero in place
    :param sparsity: float, pruning sparsity (clamped to [0, 1])
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    :return: mask with the shape of `tensor`; nonzero/True marks a kept entry.
        The mask is boolean in the general case; when nothing or everything is
        pruned it is an all-ones / all-zeros tensor with the dtype of `tensor`.
    """
    # Runs under no_grad (decorator) so the in-place updates below are allowed
    # on parameters that require gradients.
    sparsity = min(max(0.0, sparsity), 1.0)

    if sparsity == 1.0:
        tensor.zero_()
        return torch.zeros_like(tensor)
    elif sparsity == 0.0:
        return torch.ones_like(tensor)

    num_elements = tensor.numel()

    # step 1: number of entries to zero out
    num_zeros = round(sparsity * num_elements)
    if num_zeros == 0:
        # The tensor is too small for this sparsity to remove anything;
        # kthvalue() does not accept k = 0.
        return torch.ones_like(tensor)

    # step 2: importance of each weight is its magnitude
    importance = torch.abs(tensor)

    # step 3: threshold is the num_zeros-th smallest importance
    threshold = torch.kthvalue(importance.flatten(), num_zeros).values

    # step 4: keep only entries strictly above the threshold (entries tied
    # with the threshold are pruned as well)
    mask = importance > threshold

    # step 5: apply the mask in place so the caller's tensor is actually pruned
    tensor.mul_(mask)
    return mask

In [19]:
class FineGrainedPrune:
    """
    Prunes a model once at construction and remembers the masks so that the
    same sparsity pattern can be re-applied later through apply().
    """

    def __init__(self, model, sparsity_dict) -> None:
        """
        :param model: model whose weights are pruned in place
        :param sparsity_dict: maps parameter name -> target sparsity
        """
        # parameter name -> mask returned by fine_grained_prune()
        self.masks = FineGrainedPrune.prune(model, sparsity_dict)

    @torch.no_grad()
    def apply(self, model):
        """
        re-apply the stored masks to the parameters of `model`, in place
        (parameters without a stored mask are left untouched)
        """
        for name, param in model.named_parameters():
            if name in self.masks:
                param *= self.masks[name]

    @staticmethod
    @torch.no_grad()
    def prune(model, sparsity_dict):
        """
        prune every parameter of `model` that has more than one dimension
        :param model: model whose weights are pruned in place
        :param sparsity_dict: maps parameter name -> target sparsity; a missing
            entry for a pruned parameter raises KeyError
        :return: dict mapping parameter name -> pruning mask
        """
        masks = dict()
        for name, param in model.named_parameters():
            if param.dim() > 1:  # prune conv and fc weights only
                masks[name] = fine_grained_prune(param, sparsity_dict[name])
        return masks