> CNN Filter Pruning Project for Harvard CS 2420: Computing at Scale (Fall 2024)


---

### **Setup**

---

In [None]:
!nvidia-smi

Thu Oct 10 01:13:25 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   29C    P0              46W / 400W |   1079MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
import sys
import time
import os
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from copy import deepcopy

from torch.autograd import Variable

In [None]:
# Highest test accuracy (in percent) observed so far. test() compares each
# epoch's accuracy against it to decide whether to save a checkpoint, and
# raises it whenever a new best is reached.
best_acc = 0

def moving_average(a, n=100):
    '''
    Smooth a sequence with a sliding window of n samples (plotting helper).

    Entry i of the result is the mean of a[i:i + n], derived from a running
    sum. The result is empty when a holds fewer than n values.
    '''
    running = torch.Tensor(a).cumsum(dim=0)
    windowed = running.clone()
    # Removing the running sum from n steps earlier leaves the window total.
    windowed[n:] -= running[:-n]
    return windowed[n - 1:] / n

def train(net, epoch, loader, criterion, optimizer, loss_tracker=None, acc_tracker=None):
    '''
    Run one training epoch over loader, printing running statistics.

    Args:
        net: model to optimize; it is switched to training mode.
        epoch: epoch index, used only in the progress line.
        loader: iterable of (inputs, targets) batches.
        criterion: loss function, called as criterion(outputs, targets).
        optimizer: optimizer whose gradients are zeroed and which is stepped
            once per batch.
        loss_tracker: optional list; the loss of every batch is appended.
        acc_tracker: optional list; the running accuracy (in percent) at the
            end of the epoch is appended once. An empty loader records 0.0.

    Every batch is moved to the module-level `device` before the forward pass.
    '''
    # Create fresh lists per call: list defaults in the signature would be
    # shared by every call that omits the argument.
    if loss_tracker is None:
        loss_tracker = []
    if acc_tracker is None:
        acc_tracker = []

    net.train()
    train_loss = 0
    correct = 0
    total = 0
    # Defined before the loop so an empty loader does not leave it unbound.
    acc = 0.0
    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        # apply the parameter update
        optimizer.step()
        # read the scalar loss once and reuse it for both statistics
        batch_loss = loss.item()
        train_loss += batch_loss
        loss_tracker.append(batch_loss)
        avg_loss = train_loss / (batch_idx + 1)
        # running accuracy over all samples seen in this epoch
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        acc = 100. * correct / total
        # '\r' rewrites the same console line on every batch
        sys.stdout.write(f'\rEpoch {epoch}: Train Loss: {avg_loss:.3f}' +
                         f'| Train Acc: {acc:.3f}')
        sys.stdout.flush()
    acc_tracker.append(acc)
    sys.stdout.flush()

def test(net, epoch, loader, criterion, loss_tracker=None, acc_tracker=None):
    '''
    Evaluate net on loader and checkpoint it when accuracy improves.

    Args:
        net: model to evaluate; it is switched to eval mode.
        epoch: epoch index, stored in the checkpoint.
        loader: iterable of (inputs, targets) batches.
        criterion: loss function, called as criterion(outputs, targets).
        loss_tracker: optional list; the loss of every batch is appended.
        acc_tracker: optional list; the accuracy (in percent) over the whole
            loader is appended once. An empty loader records 0.0.

    Every batch is moved to the module-level `device`. When the accuracy
    exceeds the module-level `best_acc`, the model state, accuracy and epoch
    are saved to ./checkpoint/ckpt.pth and `best_acc` is updated.
    '''
    global best_acc
    # Create fresh lists per call: list defaults in the signature would be
    # shared by every call that omits the argument.
    if loss_tracker is None:
        loss_tracker = []
    if acc_tracker is None:
        acc_tracker = []

    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    # Defined before the loop so an empty loader reports zeros instead of
    # hitting an unbound name or a division by zero.
    avg_loss = 0.0
    acc = 0.0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            batch_loss = criterion(outputs, targets).item()

            test_loss += batch_loss
            loss_tracker.append(batch_loss)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            avg_loss = test_loss / (batch_idx + 1)
            acc = 100. * correct / total
    sys.stdout.write(f' | Test Loss: {avg_loss:.3f} | Test Acc: {acc:.3f}\n')
    sys.stdout.flush()

    # Save checkpoint
    acc_tracker.append(acc)
    if acc > best_acc:
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir()
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc

In [None]:
def conv_block(in_channels, out_channels, kernel_size=3, stride=1,
               padding=1):
    '''
    Build the basic unit that ConvNet stacks: a bias-free Conv2d, a
    BatchNorm2d over the produced channels, and an in-place ReLU. The three
    layers are wrapped in a single nn.Sequential, so they execute in that
    order.
    '''
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                     stride=stride, padding=padding, bias=False)
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)

class ConvNet(nn.Module):
    '''
    Baseline classifier: nine conv_block stages, global average pooling and
    one linear layer with 10 outputs.

    Channel width grows with depth (32 -> 64 -> 128 -> 256) while the two
    stride-2 stages each halve the spatial resolution (32x32 -> 16x16 -> 8x8
    for a 32x32 input). nn.AdaptiveAvgPool2d(1) then averages every channel
    down to a single value, and that 256-vector feeds the classifier.
    '''
    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) of each conv_block, in order.
        stage_specs = [
            (3, 32, 1),
            (32, 32, 1),
            (32, 64, 2),
            (64, 64, 1),
            (64, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 1),
            (256, 256, 1),
        ]
        stages = [conv_block(c_in, c_out, stride=step)
                  for c_in, c_out, step in stage_specs]
        self.model = nn.Sequential(*stages, nn.AdaptiveAvgPool2d(1))

        self.classifier = nn.Linear(256, 10)

    def forward(self, x):
        '''
        Apply the nine convolution stages and the pooling layer, drop the
        two 1x1 spatial dimensions, and return the classifier output.
        '''
        features = self.model(x)
        batch, channels, _, _ = features.shape
        return self.classifier(features.view(batch, channels))

---

### **Structured and Non-structured Filter Pruning**

---

In this section, we will implement a simplified version of structured filter pruning proposed in [Pruning Filters for Efficient ConvNets](https://openreview.net/pdf?id=rJqFGTslg). Instead of pruning weights, this paper describes removing whole filters from each convolutional layer in a CNN. Compared to pruning weights across the network, filter pruning is a naturally structured pruning method that does not introduce irregular sparsity. Therefore, it does not require using sparse libraries or specialized hardware.
For each convolutional layer, we measure each filter’s relative importance by its absolute weight sum $\sum|\mathcal{F}_{i,j}|$ (i.e., its $\ell_1$-norm). When pruning a layer, $m$ filters with the smallest relative importance will be pruned, where $m$ = (prune percentage $\times$ total number of filters in this layer).

Besides structured pruning, we will also implement the non-structured pruning proposed in [Learning both Weights and Connections for Efficient Neural Networks](https://arxiv.org/abs/1506.02626) for comparison. Non-structured pruning is more flexible than structured pruning and allows irregular sparsity in the weight tensor.

In [None]:
def _make_pair(x):
    if hasattr(x, '__len__'):
        return x
    else:
        return (x, x)

class SparseConv2d(nn.Module):
    '''
    Bias-free 2D convolution whose filters can be switched off by a mask.

    `_weight` is the dense kernel of shape
    (out_channels, in_channels, kernel_size, kernel_size). `_mask` is a
    buffer with one entry per output filter. The kernel actually used, exposed
    as `weight`, is `_weight` with each filter scaled by its mask entry, so a
    zero entry silences the whole filter.
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                     padding=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = _make_pair(stride)
        self.padding = _make_pair(padding)

        # Dense kernel, overwritten with uniform values in [-bound, bound]
        # where bound = 1 / sqrt(in_channels).
        kernel_shape = [out_channels, in_channels, kernel_size, kernel_size]
        self._weight = nn.Parameter(torch.randn(kernel_shape))
        bound = 1. / math.sqrt(in_channels)
        self._weight.data.uniform_(-bound, bound)

        # Whole filters are zeroed at once, so the mask needs only one entry
        # per filter; every filter starts enabled.
        self.register_buffer('_mask', torch.ones(out_channels))

    def forward(self, x):
        masked_kernel = self.weight
        return F.conv2d(x, masked_kernel, stride=self.stride,
                        padding=self.padding)

    @property
    def weight(self):
        # Add three trailing axes so each mask entry broadcasts over its filter.
        per_filter = self._mask[:, None, None, None]
        return per_filter * self._weight

In [None]:
def sparse_conv_block(in_channels, out_channels, kernel_size=3, stride=1,
                      padding=1):
    '''
    Conv -> BatchNorm -> ReLU unit in which a SparseConv2d takes the place
    of nn.Conv2d, so the convolution can be masked.
    '''
    layers = [
        SparseConv2d(in_channels, out_channels, kernel_size, stride, padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

class SparseConvNet(nn.Module):
    '''
    Nine sparse_conv_block stages, global average pooling and one linear
    layer with 10 outputs.

    Channel widths and strides follow the table in __init__; every
    convolution is a SparseConv2d, so it can be masked for pruning.
    '''
    def __init__(self):
        super(SparseConvNet, self).__init__()

        # (in_channels, out_channels, stride) of each sparse_conv_block, in order.
        stage_specs = [
            (3, 32, 1),
            (32, 32, 1),
            (32, 64, 2),
            (64, 64, 1),
            (64, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 1),
            (256, 256, 1),
        ]
        stages = [sparse_conv_block(c_in, c_out, stride=step)
                  for c_in, c_out, step in stage_specs]
        self.model = nn.Sequential(*stages, nn.AdaptiveAvgPool2d(1))

        self.classifier = nn.Linear(256, 10)

    def forward(self, x):
        '''
        Apply the convolution stages and the pooling layer, drop the two 1x1
        spatial dimensions, and return the classifier output.
        '''
        h = self.model(x)
        B, C, _, _ = h.shape
        h = h.view(B, C)
        return self.classifier(h)

In [None]:
# Seed PyTorch's global RNG so the random weight initialization below is
# repeatable from run to run.
torch.manual_seed(43)

def get_sparse_conv2d_layers(net):
    '''
    Collect every SparseConv2d found below net, in depth-first order.
    net itself is never included; only its descendants are inspected.
    Use this to apply pruning layer by layer.
    '''
    found = []
    for child in net.children():
        if not isinstance(child, SparseConv2d):
            # Not a sparse convolution itself: search inside it instead.
            found += get_sparse_conv2d_layers(child)
            continue
        found.append(child)
    return found

def filter_l1_pruning(net, prune_percent):
    '''
    Mask out the lowest-L1-norm filters of every SparseConv2d layer in net.

    Args:
        net: module whose SparseConv2d descendants are pruned in place.
        prune_percent: cumulative fraction of each layer's filters that should
            be masked once the call returns. Filters that are already masked
            count toward this target and are never selected again, so a layer
            at or beyond the target is left untouched.

    Norms are taken over the raw `_weight` of each filter. Ties are broken in
    favor of the lower filter index.
    '''
    # Only the mask buffers change; no autograd graph is needed for the norms.
    with torch.no_grad():
        for layer in get_sparse_conv2d_layers(net):
            mask = layer._mask
            num_total = len(mask)

            # A filter is active while any of its mask entries is nonzero.
            # This covers a one-entry-per-filter mask and a weight-shaped one.
            is_active = mask != 0
            if is_active.dim() > 1:
                is_active = is_active.flatten(1).any(dim=1)
            active = torch.nonzero(is_active).flatten()

            num_target = round(num_total * prune_percent)
            num_already = num_total - active.numel()
            # Never ask for more filters than are still active.
            num_extra = min(num_target - num_already, active.numel())
            if num_extra <= 0:
                continue

            # L1 norm of each active filter, computed in one batched call.
            norms = layer._weight[active].abs().flatten(1).sum(dim=1)
            # Stable sort so equal norms keep ascending filter order.
            order = torch.sort(norms, stable=True).indices
            mask[active[order[:num_extra]]] = 0

# Fall back to the CPU when no GPU is visible instead of failing in .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = SparseConvNet()
net = net.to(device)

lr = 0.1
# Epochs at which the scheduler multiplies the learning rate by gamma.
milestones = [24, 49, 74, 99]

# Cumulative fraction of filters to have pruned after each listed epoch.
prune_percentages = [0.1, 0.2, 0.3, 0.4, 0.5]
prune_epochs = [10, 20, 30, 40, 50]
# epoch -> target fraction, so the loop needs a single lookup per epoch
prune_schedule = dict(zip(prune_epochs, prune_percentages))

epochs = 100

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,
                            weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=milestones,
                                                 gamma=0.1)

train_loss_tracker, train_acc_tracker = [], []
test_loss_tracker, test_acc_tracker = [], []

print('Training for {} epochs, with learning rate {} and milestones {}'.format(
      epochs, lr, milestones))

start_time = time.time()
for epoch in range(epochs):
    train(net=net, epoch=epoch, loader=trainloader, criterion=criterion,
          optimizer=optimizer, loss_tracker=train_loss_tracker,
          acc_tracker=train_acc_tracker)

    # Prune after training and before evaluation, so the test accuracy of a
    # pruning epoch already reflects the masked filters.
    if epoch in prune_schedule:
        prune_percentage = prune_schedule[epoch]
        print('Pruning at epoch {}'.format(epoch))
        filter_l1_pruning(net, prune_percentage)

    test(net=net, epoch=epoch, loader=testloader, criterion=criterion,
         loss_tracker=test_loss_tracker, acc_tracker=test_acc_tracker)
    scheduler.step()


total_time = time.time() - start_time
print('Total training time: {} seconds'.format(total_time))

In [None]:
class SparseConv2d(nn.Module):
    '''
    Bias-free 2D convolution whose individual weights can be switched off.

    `_weight` is the dense kernel of shape
    (out_channels, in_channels, kernel_size, kernel_size). `_mask` is a
    buffer of exactly the same shape. The kernel actually used, exposed as
    `weight`, is the element-wise product of the two, so any single weight
    can be zeroed independently of the others.
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                     padding=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = _make_pair(stride)
        self.padding = _make_pair(padding)

        # Dense kernel, overwritten with uniform values in [-bound, bound]
        # where bound = 1 / sqrt(in_channels).
        kernel_shape = [out_channels, in_channels, kernel_size, kernel_size]
        self._weight = nn.Parameter(torch.randn(kernel_shape))
        bound = 1. / math.sqrt(in_channels)
        self._weight.data.uniform_(-bound, bound)

        # Weights are pruned one by one here, so the mask carries one entry
        # per weight; every weight starts enabled.
        self.register_buffer('_mask', torch.ones_like(self._weight))

    def forward(self, x):
        masked_kernel = self.weight
        return F.conv2d(x, masked_kernel, stride=self.stride,
                        padding=self.padding)

    @property
    def weight(self):
        keep = self._mask
        return keep * self._weight

def unstructured_pruning(net, prune_percent):
    '''
    Mask out the smallest-magnitude weights of every SparseConv2d layer in net.

    Args:
        net: module whose SparseConv2d descendants are pruned in place.
        prune_percent: cumulative fraction of each layer's weights that should
            be masked once the call returns. Weights that are already masked
            count toward this target and are never selected again, so a layer
            at or beyond the target is left untouched.

    Magnitudes are taken from the raw `_weight`. The mask is expected to have
    one entry per weight.
    '''
    # Only the mask buffers change; no autograd graph is needed.
    with torch.no_grad():
        for layer in get_sparse_conv2d_layers(net):
            # view() shares storage, so writes below land in the buffer itself.
            mask_flat = layer._mask.view(-1)
            num_total = mask_flat.numel()
            num_already = num_total - int((mask_flat != 0).sum().item())

            num_extra = round(num_total * prune_percent) - num_already
            if num_extra <= 0:
                # A negative count must never reach the slice below:
                # order[:negative] would select almost every remaining weight.
                continue

            magnitudes = layer._weight.detach().abs().reshape(-1)
            # Ascending magnitude, restricted to weights that are still active.
            order = torch.argsort(magnitudes)
            order = order[mask_flat[order] != 0]

            mask_flat[order[:num_extra]] = 0

In [None]:
def sparse_conv_block(in_channels, out_channels, kernel_size=3, stride=1,
                      padding=1):
    '''
    Conv -> BatchNorm -> ReLU unit in which a SparseConv2d takes the place
    of nn.Conv2d, so the convolution can be masked.
    '''
    layers = [
        SparseConv2d(in_channels, out_channels, kernel_size, stride, padding),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

class SparseConvNet(nn.Module):
    '''
    Nine sparse_conv_block stages, global average pooling and one linear
    layer with 10 outputs.

    Channel widths and strides follow the table in __init__; every
    convolution is a SparseConv2d, so it can be masked for pruning.
    '''
    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride) of each sparse_conv_block, in order.
        stage_specs = [
            (3, 32, 1),
            (32, 32, 1),
            (32, 64, 2),
            (64, 64, 1),
            (64, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 1),
            (256, 256, 1),
        ]
        stages = [sparse_conv_block(c_in, c_out, stride=step)
                  for c_in, c_out, step in stage_specs]
        self.model = nn.Sequential(*stages, nn.AdaptiveAvgPool2d(1))

        self.classifier = nn.Linear(256, 10)

    def forward(self, x):
        '''
        Apply the convolution stages and the pooling layer, drop the two 1x1
        spatial dimensions, and return the classifier output.
        '''
        features = self.model(x)
        batch, channels, _, _ = features.shape
        return self.classifier(features.view(batch, channels))

In [None]:
# Seed PyTorch's global RNG so the random weight initialization below is
# repeatable from run to run.
torch.manual_seed(43)

def get_sparse_conv2d_layers(net):
    '''
    Collect every SparseConv2d found below net, in depth-first order.
    net itself is never included; only its descendants are inspected.
    Use this to apply pruning layer by layer.
    '''
    found = []
    for child in net.children():
        if not isinstance(child, SparseConv2d):
            # Not a sparse convolution itself: search inside it instead.
            found += get_sparse_conv2d_layers(child)
            continue
        found.append(child)
    return found

# Fall back to the CPU when no GPU is visible instead of failing in .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = SparseConvNet()
net = net.to(device)

lr = 0.1
# Epochs at which the scheduler multiplies the learning rate by gamma.
milestones = [24, 49, 74, 99]

# Cumulative fraction of weights to have pruned after each listed epoch.
prune_percentages = [0.1, 0.2, 0.3, 0.4, 0.5]
prune_epochs = [10, 20, 30, 40, 50]
# epoch -> target fraction, so the loop needs a single lookup per epoch
prune_schedule = dict(zip(prune_epochs, prune_percentages))

epochs = 100

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9,
                            weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=milestones,
                                                 gamma=0.1)

train_loss_tracker, train_acc_tracker = [], []
test_loss_tracker, test_acc_tracker = [], []

print('Training for {} epochs, with learning rate {} and milestones {}'.format(
      epochs, lr, milestones))

# NOTE(review): test() gates checkpointing on the module-level best_acc, which
# is not reset here. If an earlier run in the same session raised it, this run
# only saves a checkpoint after beating that value. Confirm this is intended.
start_time = time.time()
for epoch in range(epochs):
    train(net=net, epoch=epoch, loader=trainloader, criterion=criterion,
          optimizer=optimizer, loss_tracker=train_loss_tracker,
          acc_tracker=train_acc_tracker)

    # Prune after training and before evaluation, so the test accuracy of a
    # pruning epoch already reflects the masked weights.
    if epoch in prune_schedule:
        prune_percentage = prune_schedule[epoch]
        print('Pruning at epoch {}'.format(epoch))
        unstructured_pruning(net, prune_percentage)

    test(net=net, epoch=epoch, loader=testloader, criterion=criterion,
         loss_tracker=test_loss_tracker, acc_tracker=test_acc_tracker)
    scheduler.step()


total_time = time.time() - start_time
print('Total training time: {} seconds'.format(total_time))