In [21]:
import torch
from torch.utils.data import DataLoader, RandomSampler, Subset
from torch import nn
import torch.nn.utils.prune as prune
import numpy as np
import sys
import copy
from utils import data
from utils.data import COCODataset, getLabelMap
import matplotlib.pyplot as plt
import math
use_gpu = True

In [4]:
class MTLUNet(nn.Module):
    """Multi-task U-Net style network: one shared encoder, two task heads.

    The encoder has four levels. At every level the level's input is
    concatenated with the level's output along the channel axis (a dense
    connection) before 2x2 max-pooling, so channel counts grow cumulatively.
    A shared bottleneck then feeds:

    * a segmentation decoder that upsamples four times with stride-2
      transposed convolutions and concatenates the matching encoder feature
      map at each level (skip connections);
    * a classification head built from unpadded convolutions applied to the
      bottleneck features.

    Args:
        num_channels: number of channels of the input image tensor.
        num_classes: number of output channels of both heads.

    NOTE(review): ``class_base`` is constructed and registered (so it appears
    in ``state_dict``) but ``forward`` never calls it. It is kept so that the
    set of parameters stays unchanged; confirm whether it was meant to sit in
    front of ``class3``.
    """

    @staticmethod
    def _conv_block(in_ch, out_ch, kernel_size=3, padding=1):
        """Return the layers [Conv2d, ReLU, BatchNorm2d, Dropout(0.5)] as a list."""
        return [
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
                      kernel_size=kernel_size, stride=1, padding=padding),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
            nn.Dropout(p=0.5),
        ]

    @staticmethod
    def _upsample(in_ch, out_ch):
        """Return a 2x2, stride-2 transposed convolution (doubles H and W)."""
        return nn.ConvTranspose2d(in_channels=in_ch, out_channels=out_ch,
                                  kernel_size=2, stride=2, padding=0)

    def __init__(self, num_channels=3, num_classes=60):
        super().__init__()

        block = self._conv_block
        up = self._upsample

        # Channel counts of the dense encoder features after each level's
        # concatenation (level input channels + level output channels).
        dense0 = num_channels + 64
        dense1 = dense0 + 128
        dense2 = dense1 + 256
        dense3 = dense2 + 512

        # NOTE: modules are created in a fixed order on purpose; the order
        # determines both the state_dict key order and the order in which
        # random initial weights are drawn.
        self.maxpool = nn.MaxPool2d(kernel_size=2)

        # Encoding #
        self.enc0 = nn.Sequential(*block(num_channels, 64), *block(64, 64))
        self.enc1 = nn.Sequential(*block(dense0, 128), *block(128, 128))
        self.enc2 = nn.Sequential(*block(dense1, 256), *block(256, 256))
        self.enc3 = nn.Sequential(*block(dense2, 512), *block(512, 512))

        # Shared Base #
        self.shared_base = nn.Sequential(*block(dense3, 1024))

        # Task-specific Bases
        self.seg_base = nn.Sequential(*block(1024, 1024),
                                      up(1024, 512),
                                      nn.ReLU(),
                                      nn.BatchNorm2d(512),
                                      nn.Dropout(p=0.5))
        self.class_base = nn.Sequential(*block(1024, 1024))  # not used by forward()

        # Task 1: Segmentation #
        # Each decoder level consumes "encoder skip channels + channels coming
        # up from the level below".
        self.seg3 = nn.Sequential(*block(dense3 + 512, 512),
                                  up(512, 256))
        self.seg2 = nn.Sequential(*block(dense2 + 256, 256),
                                  *block(256, 256),
                                  up(256, 128))
        self.seg1 = nn.Sequential(*block(dense1 + 128, 128),
                                  *block(128, 128),
                                  up(128, 64))
        self.seg0 = nn.Sequential(*block(dense0 + 64, 64),
                                  *block(64, 64),
                                  nn.Conv2d(in_channels=64, out_channels=num_classes,
                                            kernel_size=3, stride=1, padding=1))

        # Task 2: Classification #
        self.class3 = nn.Sequential(*block(1024, 1024), *block(1024, 512))
        # The remaining classification convolutions are unpadded, so every one
        # of them shrinks the feature map (5x5 -> -4, 2x2 -> -1 per dimension).
        self.class2 = nn.Sequential(*block(1024 + 512, 256, kernel_size=5, padding=0),
                                    *block(256, 256, kernel_size=5, padding=0))
        self.class1 = nn.Sequential(*block(256, 128, kernel_size=5, padding=0),
                                    *block(128, 128, kernel_size=2, padding=0))
        self.class0 = nn.Conv2d(in_channels=128, out_channels=num_classes,
                                kernel_size=1, stride=1, padding=0)

    def forward(self, X):
        """Run both heads on a batch of images.

        Args:
            X: tensor laid out as (batch, num_channels, H, W); channels are
                concatenated on dim=1. The skip connections concatenate
                encoder maps with upsampled decoder maps, so H and W are
                expected to survive four halvings/doublings unchanged
                (presumably multiples of 16 -- TODO confirm against the data
                pipeline).

        Returns:
            Tuple ``(seg_output, class_output)``:
            ``seg_output`` has ``num_classes`` channels at the resolution of
            ``X``; ``class_output`` has ``num_classes`` channels on a map that
            is 13 pixels smaller per spatial dimension than the bottleneck
            (e.g. 1x1 for a 14x14 bottleneck, i.e. a 224x224 input).
        """
        # Encoder: dense connection, remember the result for the decoder,
        # then downsample for the next level.
        skips = []
        features = X
        for encoder in (self.enc0, self.enc1, self.enc2, self.enc3):
            dense = torch.cat((features, encoder(features)), dim=1)
            skips.append(dense)
            features = self.maxpool(dense)

        # Bottleneck shared by both tasks.
        bottleneck = self.shared_base(features)

        # Task 1: Segmentation -- walk back up, deepest skip first.
        seg_output = self.seg_base(bottleneck)
        decoders = (self.seg3, self.seg2, self.seg1, self.seg0)
        for skip, decoder in zip(reversed(skips), decoders):
            seg_output = decoder(torch.cat((skip, seg_output), dim=1))

        # Task 2: Classification -- dense connection around class3, then the
        # unpadded convolution stack.
        class_output = torch.cat((bottleneck, self.class3(bottleneck)), dim=1)
        for head in (self.class2, self.class1, self.class0):
            class_output = head(class_output)

        return seg_output, class_output
        

In [35]:
def prune_model(model, PRUNING_PERCENT=0.2, n=2, dims = 0):
    """Apply permanent Ln-structured pruning to every Conv2d/Linear layer.

    Walks all sub-modules of ``model``; for each ``torch.nn.Conv2d`` or
    ``torch.nn.Linear`` it calls ``prune.ln_structured`` on the ``weight``
    parameter and immediately ``prune.remove`` so the pruning
    re-parametrization is dropped again. Other layer types (including
    ``ConvTranspose2d``, which is not a ``Conv2d`` subclass) are left alone.

    Args:
        model: module to prune. It is modified in place; pass a copy if the
            original must be preserved.
        PRUNING_PERCENT: forwarded as ``amount`` to ``prune.ln_structured``.
        n: forwarded as ``n`` (the norm order) to ``prune.ln_structured``.
        dims: forwarded as ``dim`` (the weight axis whose slices are pruned).
            NOTE(review): the same value is used for every layer, so it must
            be a valid axis for all matched weights -- verify before using
            ``dims=2`` on a model that contains Linear layers.

    Returns:
        The same ``model`` object that was passed in.
    """
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    for layer in model.modules():
        if not isinstance(layer, prunable_types):
            continue
        prune.ln_structured(layer, name='weight', amount=PRUNING_PERCENT, n=n, dim=dims)
        # Bake the mask into `weight` and drop the pruning hook/buffers.
        prune.remove(layer, 'weight')
    return model

In [36]:
PATH = 'models/test_model.pt'
ITER_PRUNING = 10
PRUNING_STEP = 0.05             # how much the pruning amount grows per outer iteration
PRUNING_PERCENT = PRUNING_STEP  # pruning amount used by the first iteration
##Load trained model
# mtlunet = torch.load(PATH)

## placeholder for now --
mtlunet = MTLUNet()
use_gpu = False  # NOTE: deliberately overrides the flag set in the import cell
if use_gpu and torch.cuda.is_available():
    mtlunet = mtlunet.cuda()
# mtlunet.load_state_dict(torch.load(PATH))
mtlunet.eval()

# Sweep grid: which weight axis to prune along, and which Ln norm ranks the slices.
# Both are loop-invariant, so they are defined once instead of on every iteration.
dims = [0, 1, 2]
norms = [1, 2]

for idx_prune in range(ITER_PRUNING):
    # `:.1f` keeps the log readable (no "15.000000000000002%").
    print(f"\n\nIteration {idx_prune+1} - Pruning {PRUNING_PERCENT*100:.1f}% of the least important neurons/filters...")
    for dim in dims:
        for n in norms:
            print(f"Dimension {dim}, norms {n}")
            # prune_model mutates its argument, so every configuration starts
            # from a fresh copy of the unpruned network.
            model = copy.deepcopy(mtlunet)
            pruned_model = prune_model(model, PRUNING_PERCENT, n=n, dims=dim)

            ## Fine-tuning

            ## Testing to calculate accuracy and latency

            ## Save model and update DB with accuracy and latency
    # Round after each step so binary floating-point error does not accumulate
    # across iterations (0.05 + 0.05 + 0.05 would otherwise be 0.15000000000000002).
    PRUNING_PERCENT = round(PRUNING_PERCENT + PRUNING_STEP, 10)
    



Iteration 1 - Pruning 5.0% of the least important neurons/filters...
Dimension 0, norms 1
Dimension 0, norms 2
Dimension 1, norms 1
Dimension 1, norms 2
Dimension 2, norms 1
Dimension 2, norms 2


Iteration 2 - Pruning 10.0% of the least important neurons/filters...
Dimension 0, norms 1
Dimension 0, norms 2
Dimension 1, norms 1
Dimension 1, norms 2
Dimension 2, norms 1
Dimension 2, norms 2


Iteration 3 - Pruning 15.000000000000002% of the least important neurons/filters...
Dimension 0, norms 1
Dimension 0, norms 2
Dimension 1, norms 1
Dimension 1, norms 2
Dimension 2, norms 1
Dimension 2, norms 2


Iteration 4 - Pruning 20.0% of the least important neurons/filters...
Dimension 0, norms 1
Dimension 0, norms 2
Dimension 1, norms 1
Dimension 1, norms 2
Dimension 2, norms 1
Dimension 2, norms 2


Iteration 5 - Pruning 25.0% of the least important neurons/filters...
Dimension 0, norms 1
Dimension 0, norms 2
Dimension 1, norms 1
Dimension 1, norms 2
Dimension 2, norms 1
Dimension 2, norms