# Load Pretrained Model 

In [1]:
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from nni.compression.pytorch.speedup import ModelSpeedup
from nni.compression.pytorch.utils import count_flops_params
import time

from mnist_model import Net, train, test, device, optimizer_scheduler_generator, trainer, loaders

# Load pretrained model
model = torch.load("mnist_cnn.pt")
model.eval()

# show the model stbructure, note that pruner will wrap the model layer.
print(model)

Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)


### Performance and statistics of pre-trained model 

In [2]:
start = time.time()

pre_best_acc = test(model, device)
pre_test_time = time.time() - start

pre_flops, pre_params, _ = count_flops_params(model, torch.randn([3, 1, 28, 28]).to(device))
print(f'Pretrained model FLOPs {pre_flops/1e6:.2f} M, #Params: {pre_params/1e6:.2f}M, Accuracy: {pre_best_acc: .2f}%, Test-time: {pre_test_time: .4f}s')


Test set: Average loss: 0.0267, Accuracy: 9919/10000 (99.19%)

+-------+-------+--------+----------------+-----------------+-----------------+----------+---------+
| Index | Name  |  Type  |  Weight Shape  |    Input Size   |   Output Size   |  FLOPs   | #Params |
+-------+-------+--------+----------------+-----------------+-----------------+----------+---------+
|   0   | conv1 | Conv2d | (32, 1, 3, 3)  |  (3, 1, 28, 28) | (3, 32, 26, 26) |  194688  |   320   |
|   1   | conv2 | Conv2d | (64, 32, 3, 3) | (3, 32, 26, 26) | (3, 64, 24, 24) | 10616832 |  18496  |
|   2   | fc1   | Linear |  (128, 9216)   |    (3, 9216)    |     (3, 128)    | 1179648  | 1179776 |
|   3   | fc2   | Linear |   (10, 128)    |     (3, 128)    |     (3, 10)     |   1280   |   1290  |
+-------+-------+--------+----------------+-----------------+-----------------+----------+---------+
FLOPs total: 11992448
#Params total: 1199882
Pretrained model FLOPs 11.99 M, #Params: 1.20M, Accuracy:  99.19%, Test-time:  1.48

# Pruning Model  Activation Mean Rank

In [3]:
from nni.compression.pytorch.pruning import ADMMPruner
from nni.compression.pytorch.pruning import ActivationMeanRankPruner
from nni.compression.pytorch.speedup import ModelSpeedup
import nni

import torch.nn.functional as F

def pruner_function(config_list):

    model = torch.load("mnist_cnn.pt")
    model.eval()

    traced_optimizer = nni.trace(optim.Adadelta)(model.parameters(), lr=1.0)
    criterion = F.nll_loss
    
    # Using ADMMPruner to prune the model and generate the masks.
    #pruner = ADMMPruner(model, config_list, trainer, traced_optimizer, criterion, iterations=5, training_epochs=1, granularity='coarse-grained')
    
    pruner = ActivationMeanRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
    
    # show the wrapped model structure, `PrunerModuleWrapper` have wrapped the layers that configured in the config_list.
    #print(model)

    # compress the model and generate the masks
    _, masks = pruner.compress()

    # show the masks sparsity
    print("Showing the masks sparsity")
    for name, mask in masks.items():
        print(name, ' sparsity : ', '{:.2}'.format(mask['weight'].sum() / mask['weight'].numel()))


    # need to unwrap the model, if the model is wrapped before speedup
    pruner._unwrap_model()

    # speedup the model, for more information about speedup, please refer :doc:`pruning_speedup`.
    ModelSpeedup(model, torch.rand(3, 1, 28, 28).to(device), masks).speedup_model()

    #print("Model after speedup")
    #print(model)

    optimizer, scheduler = optimizer_scheduler_generator(model)
    
    # fine- tuning model compacted model
    # tuning and evaluate the model on MNIST dataset
    total_epoch = 3
    
    for epoch in range(1, total_epoch + 1):
        train(model, device, optimizer=optimizer, epoch=epoch)
        test(model, device)
        scheduler.step()
        
    return model

In [4]:
def Perfomance_function(model):
    print("Model after speedup")
    print(model)
    
    start = time.time()
    best_acc = test(model, device)
    test_time = time.time() - start

    flops, params, _ = count_flops_params(model, torch.randn([3, 1, 28, 28]).to(device))

    print(f'Pretrained model FLOPs {pre_flops/1e6:.2f} M, #Params: {pre_params/1e6:.2f}M, Accuracy: {pre_best_acc: .2f}%, , Test-time: {pre_test_time: .4f}s')
    print(f'Finetuned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}%, Test-time: {test_time: .4f}s, Speed-up: {pre_test_time/test_time: .2f}x')

## ADMM Configuration 2

In [5]:
config_list = [{
    'op_types': ['Conv2d'],
    'total_sparsity': 0.85
    }, {
    'op_names': ['Linear'],
    'total_sparsity': 0.85
    },
    {
    'exclude': True,
    'op_names': ['fc2']
}]


pruned_model = pruner_function(config_list=config_list)

Showing the masks sparsity
conv1  sparsity :  0.16
conv2  sparsity :  0.16
[2022-11-20 13:14:56] [32mstart to speedup the model[0m
[2022-11-20 13:14:56] [32minfer module masks...[0m
[2022-11-20 13:14:56] [32mUpdate mask for conv1[0m
[2022-11-20 13:14:56] [32mUpdate mask for .aten::relu.6[0m
[2022-11-20 13:14:56] [32mUpdate mask for conv2[0m
[2022-11-20 13:14:56] [32mUpdate mask for .aten::relu.7[0m
[2022-11-20 13:14:56] [32mUpdate mask for .aten::max_pool2d.8[0m
[2022-11-20 13:14:56] [32mUpdate mask for dropout1[0m
[2022-11-20 13:14:56] [32mUpdate mask for .aten::flatten.9[0m
[2022-11-20 13:14:56] [32mUpdate mask for fc1[0m
[2022-11-20 13:14:56] [32mUpdate mask for .aten::relu.10[0m
[2022-11-20 13:14:56] [32mUpdate mask for dropout2[0m
[2022-11-20 13:14:56] [32mUpdate mask for fc2[0m
[2022-11-20 13:14:56] [32mUpdate mask for .aten::log_softmax.11[0m
[2022-11-20 13:14:56] [32mUpdate the indirect sparsity for the .aten::log_softmax.11[0m
[2022-11-20 13:14:56

  return self._grad



Test set: Average loss: 0.0596, Accuracy: 9792/10000 (97.92%)




Test set: Average loss: 0.0488, Accuracy: 9844/10000 (98.44%)


Test set: Average loss: 0.0391, Accuracy: 9871/10000 (98.71%)



In [6]:
Perfomance_function(pruned_model)

Model after speedup
Net(
  (conv1): Conv2d(1, 5, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(5, 10, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=1440, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

Test set: Average loss: 0.0391, Accuracy: 9871/10000 (98.71%)

+-------+-------+--------+---------------+----------------+-----------------+--------+---------+
| Index | Name  |  Type  |  Weight Shape |   Input Size   |   Output Size   | FLOPs  | #Params |
+-------+-------+--------+---------------+----------------+-----------------+--------+---------+
|   0   | conv1 | Conv2d |  (5, 1, 3, 3) | (3, 1, 28, 28) |  (3, 5, 26, 26) | 30420  |    50   |
|   1   | conv2 | Conv2d | (10, 5, 3, 3) | (3, 5, 26, 26) | (3, 10, 24, 24) | 259200 |   460   |
|   2   | fc1   | Linear |  (128, 1440)  |   (3, 1440)    |     (3, 128)    | 184320 |  1844

## Knowledge Distillation on NNI

In [7]:
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR

from copy import deepcopy

In [8]:
class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s/self.T, dim=1)
        p_t = F.softmax(y_t/self.T, dim=1)
        loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
        return loss
    
def get_dummy_input(device):
    dummy_input = torch.randn([3, 1, 28, 28]).to(device)
    return dummy_input

def get_model_optimizer_scheduler(model_t, model_s):
    module_list = nn.ModuleList([])
    module_list.append(model_s)
    module_list.append(model_t)

    # setup opotimizer for fine-tuning studeng model
    #optimizer = torch.optim.SGD(model_s.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    #scheduler = MultiStepLR(optimizer, milestones=[int(160*0.5), int(160*0.75)], gamma=0.1)
    optimizer, scheduler = optimizer_scheduler_generator(model_s)
    
    return module_list, optimizer, scheduler



def train(models, device, train_loader, criterion, optimizer, epoch):
    model_s = models[0].train()
    model_t = models[-1].eval()
    cri_cls = criterion
    cri_kd = DistillKL(4)  # (args.kd_T)


    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output_s = model_s(data)
        output_t = model_t(data)

        loss_cls = cri_cls(output_s, target)
        loss_kd = cri_kd(output_s, output_t)
        loss = loss_cls + loss_kd
        loss.backward()

        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

            
def test(model, device, criterion, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    acc = 100 * correct / len(test_loader.dataset)

    print('Test Loss: {}  Accuracy: {}%\n'.format(
        test_loss, acc))
    return acc

In [9]:
train_loader, test_loader = loaders()
criterion = F.nll_loss
models, optimizer, scheduler = get_model_optimizer_scheduler(model, pruned_model)

In [10]:
## Test models
print(test(models[0], device, criterion, test_loader))
print(test(models[1], device, criterion, test_loader))

Test Loss: 3.9109499566257e-05  Accuracy: 98.71%

98.71
Test Loss: 2.6686942018568516e-05  Accuracy: 99.19%

99.19


In [11]:
best_top1 = 0
fine_tune_epochs = 20 

print('start fine-tuning...')

for epoch in range(fine_tune_epochs):
    print('# Epoch {} #'.format(epoch))
    train(models, device, train_loader, criterion, optimizer, epoch)
    scheduler.step()
    
    # test student only
    top1 = test(models[0], device, criterion, test_loader)
    if top1 > best_top1:
        best_top1 = top1
        updated_model = deepcopy(models[0])
        torch.save(updated_model, "mnist_prunned.pt")

start fine-tuning...
# Epoch 0 #




Test Loss: 5.233486592769623e-05  Accuracy: 98.56%

# Epoch 1 #
Test Loss: 4.277927167713642e-05  Accuracy: 98.77%

# Epoch 2 #
Test Loss: 3.803734350949526e-05  Accuracy: 98.91%

# Epoch 3 #
Test Loss: 3.740927577018738e-05  Accuracy: 98.92%

# Epoch 4 #
Test Loss: 3.66814985871315e-05  Accuracy: 98.84%

# Epoch 5 #
Test Loss: 3.575578182935715e-05  Accuracy: 98.91%

# Epoch 6 #
Test Loss: 3.677289746701717e-05  Accuracy: 98.87%

# Epoch 7 #
Test Loss: 3.691333644092083e-05  Accuracy: 98.87%

# Epoch 8 #
Test Loss: 3.656118661165237e-05  Accuracy: 98.87%

# Epoch 9 #
Test Loss: 3.576974701136351e-05  Accuracy: 98.89%

# Epoch 10 #
Test Loss: 3.5620678775012496e-05  Accuracy: 98.9%

# Epoch 11 #
Test Loss: 3.639510851353407e-05  Accuracy: 98.88%

# Epoch 12 #
Test Loss: 3.564784917980433e-05  Accuracy: 98.88%

# Epoch 13 #
Test Loss: 3.5611553769558667e-05  Accuracy: 98.89%

# Epoch 14 #


Test Loss: 3.578967303037643e-05  Accuracy: 98.88%

# Epoch 15 #
Test Loss: 3.5690776258707045e-05  Accuracy: 98.88%

# Epoch 16 #
Test Loss: 3.560706460848451e-05  Accuracy: 98.88%

# Epoch 17 #
Test Loss: 3.5614733211696146e-05  Accuracy: 98.88%

# Epoch 18 #
Test Loss: 3.556640092283487e-05  Accuracy: 98.88%

# Epoch 19 #
Test Loss: 3.560533616691828e-05  Accuracy: 98.89%



In [12]:
print(test(updated_model, device, criterion, test_loader))

Test Loss: 3.740927744656801e-05  Accuracy: 98.92%

98.92


In [17]:
torch.save(updated_model.state_dict(), 'model_trained.pth')

In [18]:
compiled_model = torch.jit.script(updated_model)
torch.jit.save(compiled_model, 'updated_model.pt')