# Project 1: DNN Pruning via NNI
Oliver Fowler, Brian Park

This notebook contains a draft of the models we experimented with.

In [1]:
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import SGD
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from nni.compression.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
from tqdm import tqdm
from torchviz import make_dot, make_dot_from_trace
import sys
import os

device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

## MNIST CNN with PyTorch
For the first exploration, we're going to use NNI to prune one of the examples from PyTorch that uses a simple CNN to train the MNIST dataset.

In [2]:
batch_size = 64
test_batch_size = 1000
epochs = 14
lr = 1.0
gamma = 0.7
seed = 1
arc_env = False

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

In [4]:
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()

    for batch_idx, (data, target) in tqdm(enumerate(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_dataset_length = len(test_loader.dataset)
    test_loss /= test_dataset_length
    accuracy = 100. * correct / test_dataset_length
    print('Average test loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(test_loss, correct, test_dataset_length, accuracy))


In [5]:
train_kwargs = {'batch_size': batch_size}
test_kwargs = {'batch_size': test_batch_size}

# If we're using NVIDIA, we can apply some more software/hardware optimizations if available
if device.type == "cuda":
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True,
                   'shuffle': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)
    model = torch.nn.DataParallel(model)
    torch.backends.cudnn.benchmark = True

In [6]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

if arc_env:
    dir = "/mnt/beegfs/$USER/data/"
else:
    dir = "data"

dataset1 = datasets.MNIST(dir, train=True, download=True,
                   transform=transform)
dataset2 = datasets.MNIST(dir, train=False,
                   transform=transform)

train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

model = Net().to(device)

optimizer = optim.Adadelta(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
criterion = F.nll_loss

In [7]:
# Unfortunatley, doesn't work on Mac
if device.type != "mps":
    batch = next(iter(train_loader))
    yhat = model(batch[0])
    make_dot(yhat, params=dict(list(model.named_parameters()))).render("figures/mnist_cnn", format="png")

In [8]:
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, criterion, epoch)
    scheduler.step()

938it [00:15, 61.12it/s]
938it [00:15, 62.23it/s]
938it [00:15, 61.34it/s]
938it [00:14, 62.59it/s]
938it [00:15, 62.05it/s]
938it [00:15, 61.53it/s]
938it [00:15, 62.32it/s]
938it [00:15, 62.38it/s]
938it [00:14, 62.69it/s]
938it [00:14, 63.13it/s]
938it [00:15, 61.61it/s]
938it [00:14, 62.56it/s]
938it [00:15, 62.00it/s]
938it [00:15, 62.23it/s]


In [9]:
if not os.path.isfile("models/mnist_cnn.pt"):
    torch.save(model, "models/mnist_cnn.pt")

In [31]:
test(model, device, test_loader)

Average test loss: 0.0267, Accuracy: 9918/10000 (99%)


## NNI Pruning
We've finished training and evaluating the model. Now we can prune the model and see what speedups and improvements on accuracy that we can get.

In [36]:
model = torch.load("models/mnist_cnn.pt")
device = torch.device("cpu")
model = model.to(device)

In [37]:
config_list = [{
    'sparsity_per_layer': 0.5,
    'op_types': ['Conv2d', 'Linear']
}, {
    'exclude': True,
    'op_names': ['fc2']
}]

In [38]:
pruner = L1NormPruner(model, config_list)

In [39]:
print(model)

Net(
  (conv1): PrunerModuleWrapper(
    (module): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  )
  (conv2): PrunerModuleWrapper(
    (module): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  )
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): PrunerModuleWrapper(
    (module): Linear(in_features=9216, out_features=128, bias=True)
  )
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


In [40]:
# compress the model and generate the masks
_, masks = pruner.compress()
# # show the masks sparsity
for name, mask in masks.items():
    print(name, ' sparsity : ', '{:.2}'.format(mask['weight'].sum() / mask['weight'].numel()))

conv1  sparsity :  0.5
conv2  sparsity :  0.5
fc1  sparsity :  0.5


In [42]:
# need to unwrap the model, if the model is wrapped before speedup
pruner._unwrap_model()
ModelSpeedup(model, torch.rand(3, 1, 28, 28).to(device), masks).speedup_model()

[2022-09-24 21:36:02] [32mstart to speedup the model[0m
[2022-09-24 21:36:02] [32minfer module masks...[0m
[2022-09-24 21:36:02] [32mUpdate mask for conv1[0m
[2022-09-24 21:36:02] [32mUpdate mask for .aten::relu.6[0m
[2022-09-24 21:36:02] [32mUpdate mask for conv2[0m
[2022-09-24 21:36:02] [32mUpdate mask for .aten::relu.7[0m
[2022-09-24 21:36:02] [32mUpdate mask for .aten::max_pool2d.8[0m
[2022-09-24 21:36:02] [32mUpdate mask for dropout1[0m
[2022-09-24 21:36:02] [32mUpdate mask for .aten::flatten.9[0m
[2022-09-24 21:36:02] [32mUpdate mask for fc1[0m
[2022-09-24 21:36:02] [32mUpdate mask for .aten::relu.10[0m
[2022-09-24 21:36:02] [32mUpdate mask for dropout2[0m
[2022-09-24 21:36:02] [32mUpdate mask for fc2[0m
[2022-09-24 21:36:02] [32mUpdate mask for .aten::log_softmax.11[0m
[2022-09-24 21:36:02] [32mUpdate the indirect sparsity for the .aten::log_softmax.11[0m
[2022-09-24 21:36:02] [32mUpdate the indirect sparsity for the fc2[0m
[2022-09-24 21:36:02] 

  return self._grad


In [43]:
print(model)

Net(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=4608, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=10, bias=True)
)


In [44]:
model.parameters()

<generator object Module.parameters at 0x1685cfed0>

In [46]:
optimizer = SGD(model.parameters(), 1e-2)
for epoch in range(3):
    train(model, device, train_loader, optimizer, criterion, epoch)

100%|█████████████████████████████████████████████| 3/3 [01:23<00:00, 27.93s/it]


In [47]:
test(model, device, test_loader)

Average test loss: 0.0298, Accuracy: 9900/10000 (99%)
