In [1]:
import torch
from torchvision.models import resnet18
import torch_pruning as tp
from torchsummary import summary
import numpy as np
import torch.nn as nn

# Load the trained VGG11. The checkpoint is a whole pickled nn.Module (it is used
# directly with .eval() and displays as VGG(...)), not a state_dict.
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# the pruning step below runs on CPU (prune_model calls model.cpu()).
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True, which
# rejects full pickled modules; pass weights_only=False there (trusted files only).
model = torch.load('../../vgg/vgg11.pth', map_location='cpu')
# Inference mode (e.g. dropout disabled) before tracing the graph for pruning.
model.eval()


VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 

In [2]:
import random
def random_int_list(start, stop, length):
    """Return ``length`` random integers drawn from the closed range [start, stop].

    The two bounds may be passed in either order and are truncated with
    ``int``. ``length`` is made non-negative with ``abs`` and truncated; any
    falsy ``length`` (``0``, ``None``) gives an empty list. Draws come from the
    module-level ``random`` generator, so they follow its current seed.
    """
    if not start <= stop:
        start, stop = stop, start
    low, high = int(start), int(stop)
    count = int(abs(length)) if length else 0
    return [random.randint(low, high) for _ in range(count)]

In [3]:
def _smallest_l1_filter_indices(conv, prune_prob):
    """Return indices of the ``prune_prob`` fraction of ``conv``'s output filters with the smallest L1 norm."""
    weight = conv.weight.detach().cpu().numpy()
    out_channels = weight.shape[0]
    # L1 norm of each output filter: sum of |w| over every axis except the first.
    l1_norm = np.sum(np.abs(weight), axis=tuple(range(1, weight.ndim)))
    # Never remove every filter: a conv with zero outputs breaks the network.
    num_pruned = min(int(out_channels * prune_prob), out_channels - 1)
    return np.argsort(l1_norm)[:num_pruned].tolist()


def prune_model(model, conv_prune_probs=None,
                example_input_shape=(1, 3, 224, 224), verbose=True):
    """Structurally prune the output filters of every Conv2d in ``model.features``.

    For each convolution, in order, the filters with the smallest L1 norm are
    removed through a torch_pruning dependency graph. The graph also shrinks
    dependent layers; the printed log shows the next conv's input channels
    shrinking accordingly.

    Args:
        model: network exposing a ``features`` container (a VGG in this
            notebook). It is moved to CPU and pruned in place.
        conv_prune_probs: fraction of filters to remove from each Conv2d in
            ``model.features``, one entry per conv in order, each in [0, 1).
            Defaults to the VGG11 schedule ``[0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]``.
            Extra entries beyond the number of convs are ignored.
        example_input_shape: shape of the random dummy input used to trace the
            dependency graph.
        verbose: print each conv and the filter indices removed from it.

    Returns:
        The same ``model`` object, now pruned.

    Raises:
        ValueError: if there are fewer probabilities than Conv2d layers, or a
            probability lies outside [0, 1). Both checks run before any pruning,
            so the model is never left half-pruned.
    """
    if conv_prune_probs is None:
        conv_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]

    convs = [m for m in model.features if isinstance(m, nn.Conv2d)]
    if len(conv_prune_probs) < len(convs):
        raise ValueError(
            f"got {len(conv_prune_probs)} prune probabilities for {len(convs)} Conv2d layers"
        )
    used_probs = list(conv_prune_probs[:len(convs)])
    bad_probs = [p for p in used_probs if not 0 <= p < 1]
    if bad_probs:
        raise ValueError(f"prune probabilities must lie in [0, 1), got {bad_probs}")

    DG = tp.DependencyGraph(model.cpu(), fake_input=torch.randn(*example_input_shape))

    # The conv modules are pruned in place, so the references collected above
    # stay valid (and reflect earlier pruning) as we go.
    for conv, prob in zip(convs, used_probs):
        if verbose:
            print(conv)
        prune_index = _smallest_l1_filter_indices(conv, prob)
        if verbose:
            print(prune_index)
        if not prune_index:
            continue  # nothing to remove for this layer
        plan = DG.get_pruning_plan(conv, tp.prune_conv, prune_index)
        plan.exec()

    # TODO: the fully connected classifier is not pruned directly; only changes
    # propagated by the dependency graph reach it.
    return model


In [4]:
# prune_model mutates and returns the very same module, so re-running this cell
# without first reloading the checkpoint (cell 1) would prune it a second time.
model = prune_model(model=model)

Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[22, 30, 52, 31, 38, 15]
Conv2d(58, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[8, 0, 42, 110, 22, 47, 5, 107, 88, 41, 37, 124]
Conv2d(116, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[108, 48, 136, 201, 31, 86, 76, 179, 87, 175, 73, 251, 52, 137, 186, 191, 17, 176, 77, 7, 212, 13, 6, 229, 29, 217, 246, 50, 93, 218, 67, 90, 252, 63, 133, 49, 167, 100, 46, 157, 54, 66, 190, 222, 142, 232, 249, 245, 239, 145, 171]
Conv2d(205, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[76, 206, 133, 231, 207, 175, 145, 31, 186, 115, 136, 179, 47, 104, 38, 184, 2, 122, 239, 34, 140, 217, 40, 50, 66, 81, 216, 188, 144, 79, 53, 196, 208, 215, 102, 48, 134, 236, 187, 151, 132, 101, 194, 64, 108, 181, 46, 116, 71, 109, 118]
Conv2d(205, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
[73, 355, 381, 436, 481, 321, 12, 254, 309, 243, 407, 131, 319, 212, 428, 78, 343, 287, 177, 367, 405, 292, 147, 372, 3

In [5]:
# Confirm the pruned network is still internally consistent before saving it.
# A forward pass raises here if any layer's input width no longer matches its
# predecessor's output, e.g. a classifier not adapted to the last pruned conv.
# The input shape matches the dummy input used to trace the dependency graph.
model.eval()
with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))
model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(58, 116, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(116, 205, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(205, 205, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(205, 410, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(410, 410, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 

In [7]:
import os

# Save the whole module, not just a state_dict. Pruning changed the layer shapes
# (e.g. Conv2d(3, 64) -> Conv2d(3, 58)), so the weights would no longer load into
# the unpruned VGG11 definition.
pruned_model_path = '../../vgg/pruning/vgg11.pth'
# torch.save does not create missing directories; make sure the target exists.
os.makedirs(os.path.dirname(pruned_model_path), exist_ok=True)
torch.save(model, pruned_model_path)