In [10]:
!pip install torchprofile



In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchprofile

class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    The first linear layer takes 16 * 5 * 5 flattened features, which is what
    two unpadded 5x5 convolutions, each followed by a 2x2 max-pool, produce
    for a 3x32x32 input (32 -> 28 -> 14 -> 10 -> 5). The network outputs 10
    scores per sample.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor; a single pooling module is shared
        # by both stages.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Run the network on a batch ``x`` and return the raw output scores."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten everything except the batch dimension.
        features = torch.flatten(x, start_dim=1)
        hidden = F.relu(self.fc1(features))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


class VGG16(nn.Module):
    """Truncated VGG16-style classifier with batch normalisation.

    Only seven convolutional blocks are built (64, 64, 128, 128, 256, 256, 256
    output channels); the 512-channel blocks 8-13 of the full layout are
    intentionally left out. Blocks 2, 4 and 7 end with a 2x2 max-pool, so a
    3x32x32 input is reduced to 256x4x4 = 4096 features, which is the input
    width expected by the first fully connected layer.

    Args:
        num_classes: Number of scores produced by the final linear layer.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels, pool=False):
        """Build Conv3x3(padding 1) -> BatchNorm -> ReLU, plus an optional 2x2 max-pool."""
        modules = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        if pool:
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*modules)

    def __init__(self, num_classes=10):
        super().__init__()
        block = self._conv_block

        # Feature extractor. Index 0 of every block is the convolution and
        # index 1 its batch norm.
        self.layer1 = block(3, 64)
        self.layer2 = block(64, 64, pool=True)
        self.layer3 = block(64, 128)
        self.layer4 = block(128, 128, pool=True)
        self.layer5 = block(128, 256)
        self.layer6 = block(256, 256)
        self.layer7 = block(256, 256, pool=True)

        # Classifier head: two dropout-regularised hidden layers and the output layer.
        self.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(4096, 4096), nn.ReLU())
        self.fc1 = nn.Sequential(nn.Dropout(0.5), nn.Linear(4096, 4096), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Linear(4096, num_classes))

    def forward(self, x):
        """Return the ``num_classes`` output scores for every sample in ``x``."""
        out = x
        conv_stages = (
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.layer5, self.layer6, self.layer7,
        )
        for stage in conv_stages:
            out = stage(out)
        # Collapse channel and spatial dimensions into one feature vector per sample.
        out = out.reshape(out.size(0), -1)
        for head in (self.fc, self.fc1, self.fc2):
            out = head(out)
        return out

In [35]:
# Baseline: build the unpruned network and profile it on a single random
# 3x32x32 input.
# NOTE: despite its name, `flops_1` stores the value returned by
# `profile_macs` (a multiply-accumulate count), and `input` shadows the Python
# builtin. Both names are reused by later cells, so they are kept as they are.
net = VGG16()
input = torch.randn(1, 3, 32, 32)
flops_1 = torchprofile.profile_macs(net, input)
flops_1



224354304

In [36]:
def _leading_slice(tensor, count, dim=0):
    """Return an independent, contiguous copy of the first ``count`` entries of ``tensor`` along ``dim``."""
    index = (slice(None),) * dim + (slice(None, count),)
    # Cloning gives the result its own storage: the full-size tensor can be freed and
    # serialising the result does not drag the original storage along.
    return tensor.detach()[index].clone(memory_format=torch.contiguous_format)


def _shrunk_parameter(param, count, dim=0):
    """Wrap the leading slice of ``param`` in a new Parameter, keeping its ``requires_grad`` flag."""
    return nn.Parameter(_leading_slice(param, count, dim), requires_grad=param.requires_grad)


def prune_layer(layer,next_layer,batch_norm_layer, sellection_rate, compression):
    """Shrink a convolution in place by dropping its trailing output channels.

    The number of channels to drop is derived as follows: ``sellection_rate`` of
    the filters are "selected", the selected filters are notionally merged into
    ``selected // compression`` groups (at least one group when anything is
    selected), and ``selected - groups`` filters are removed. No actual
    selection or merging of weights takes place: the first remaining channels
    are kept unchanged and the rest are discarded.

    ``layer``, ``next_layer`` and (if given) ``batch_norm_layer`` are modified in
    place so that their channel counts stay consistent. All replaced tensors
    are fresh contiguous copies rather than views of the original storage.

    Args:
        layer: Convolution whose output channels are reduced.
        next_layer: Convolution consuming ``layer``'s output; its input
            channels (dimension 1 of its weight) are reduced to match.
        batch_norm_layer: Batch norm applied to ``layer``'s output, or ``None``.
        sellection_rate: Fraction of filters considered for merging, in [0, 1].
        compression: Number of selected filters merged into one group.

    Raises:
        ValueError: If ``sellection_rate`` is outside [0, 1].
        ZeroDivisionError: If ``compression`` is 0.
    """
    if not 0 <= sellection_rate <= 1:
        raise ValueError(f"sellection_rate must be in [0, 1], got {sellection_rate!r}")

    nk = layer.weight.shape[0]
    nktbs = round(nk * sellection_rate)
    # At least one group survives when something was selected, but never more
    # groups than selected filters; otherwise the prune count would go negative
    # (e.g. nothing selected) and the layer would claim more channels than it has.
    ng = int(min(max(nktbs // compression, 1), nktbs))
    number_of_kernels_to_be_pruned = nktbs - ng

    print(f"From {nk} sellected {nktbs} filters merged into {ng} groups deleting {number_of_kernels_to_be_pruned}")

    kept = layer.out_channels - number_of_kernels_to_be_pruned

    # Output side of the pruned convolution.
    layer.out_channels = kept
    layer.weight = _shrunk_parameter(layer.weight, kept)
    if layer.bias is not None:
        layer.bias = _shrunk_parameter(layer.bias, kept)

    # Input side of the consumer: channels live on dimension 1 of its weight.
    next_layer.in_channels = kept
    next_layer.weight = _shrunk_parameter(next_layer.weight, kept, dim=1)

    if batch_norm_layer is not None:
        batch_norm_layer.num_features = kept
        # Running statistics are buffers, the affine terms are parameters; any
        # of them may be absent depending on how the batch norm was configured.
        if batch_norm_layer.running_mean is not None:
            batch_norm_layer.running_mean = _leading_slice(batch_norm_layer.running_mean, kept)
        if batch_norm_layer.running_var is not None:
            batch_norm_layer.running_var = _leading_slice(batch_norm_layer.running_var, kept)
        if batch_norm_layer.weight is not None:
            batch_norm_layer.weight = _shrunk_parameter(batch_norm_layer.weight, kept)
        if batch_norm_layer.bias is not None:
            batch_norm_layer.bias = _shrunk_parameter(batch_norm_layer.bias, kept)



# Shrink the output of the first six convolutions; layer7's own output width is
# left untouched (only its input channels shrink, as the consumer of layer6).
# Each entry: (conv to shrink, conv consuming its output, batch norm after the
# shrunk conv, compression factor).
# The trailing percentages appear to be the nominal share of filters removed,
# i.e. sellection_rate * (1 - 1 / compression).
_SELLECTION_RATE = 0.9
_prune_plan = [
    (net.layer1[0], net.layer2[0], net.layer1[1], 6),   # 75 %
    (net.layer2[0], net.layer3[0], net.layer2[1], 6),
    (net.layer3[0], net.layer4[0], net.layer3[1], 9),   # 80 %
    (net.layer4[0], net.layer5[0], net.layer4[1], 9),
    (net.layer5[0], net.layer6[0], net.layer5[1], 12),  # 82.5 %
    (net.layer6[0], net.layer7[0], net.layer6[1], 12),
]
for _conv, _consumer, _bn, _compression in _prune_plan:
    prune_layer(_conv, _consumer, _bn, _SELLECTION_RATE, _compression)

From 64 sellected 58 filters merged into 9 groups deleting 49
From 64 sellected 58 filters merged into 9 groups deleting 49
From 128 sellected 115 filters merged into 12 groups deleting 103
From 128 sellected 115 filters merged into 12 groups deleting 103
From 256 sellected 230 filters merged into 19 groups deleting 211
From 256 sellected 230 filters merged into 19 groups deleting 211


In [37]:
# Profile the same `net` object again, now that it has been pruned in place,
# using the same `input` tensor as the baseline measurement.
flops_2 = torchprofile.profile_macs(net, input)
flops_2



46903296

In [38]:
# Fraction of the baseline profile_macs count that remains after pruning.
flops_2/flops_1

0.209059042611458