In [1]:
import numpy as np
import torch  
import torch.nn as nn

def compress_matrix(x):

    if len(x.shape) != 2:
        A, B, C, D = x.shape
        x = x.reshape(A * B,  C * D)
        ind = np.argwhere((np.sum(np.abs(x), axis=1)) == 0.)
        # remove non-necessary filters and rows
        x = x[:, (x != 0).any(axis=0)]
        x = x[(x != 0).any(axis=1), :]
        x = x.reshape(-1,B,C,D)
    else:
        # remove unnecessary rows, columns
        x = x[(x != 0).any(axis=1), :]
        x = x[:, (x != 0).any(axis=0)]
        ind = numpy.array([0.]) # for now
                          
    return x, ind





model = sparsify(LeNet())

new_channels = []

for m in model.children():
    if isinstance(m, nn.Conv2d):
        num_channels = compress_layer(m)
        new_channels.append(num_channels)

compressed = TinyNet(channels=new_channels)

for layer, compressed_layer in zip(model.children(), compressed.children()):
    compressed_layer.weight = layer.weight
    if(layer.bias):
        compressed_layer.bias   = layer.bias







example = nn.Conv2d(3,6,5)

print(example)

x = example.weight.data.numpy()
x[1,:] = 0.

x, ind = compress_matrix(x)
out_, in_, height_, width_ = x.shape

new_layer = nn.Conv2d(in_, out_, height_)
new_layer.weight.data = torch.Tensor(x)

print(new_layer)

## remove the bias


# now need to remove the inputs corresponding to the channel we just pruned in the following layer
# could be conv, bn, pool or fc
# layer 2 of LeNet
next_layer = nn.Conv2d(6,16,5)



NameError: name 'sparsify' is not defined

In [2]:
l1 = nn.Conv2d(2,3,2)
print(l1)
l1 = l1.weight.data.numpy()
l1[1,:] = 0. 
l1, ind = compress_matrix(l1)
out_, in_, height_, width_ = l1.shape

new_layer = nn.Conv2d(in_, out_, height_)
new_layer.weight.data = torch.Tensor(l1)

print(new_layer)



# follow up 
x = nn.Conv2d(3,3,2)
x = x.weight.data.numpy()

print(x)

print("\n\n\n")

print(ind)
x = np.delete(x, ind, 0)


print(x)

Conv2d(2, 3, kernel_size=(2, 2), stride=(1, 1))
Conv2d(2, 2, kernel_size=(2, 2), stride=(1, 1))
[[[[-0.07139529  0.04683082]
   [ 0.14715445  0.10146303]]

  [[ 0.20623404 -0.20414823]
   [ 0.01558907  0.11479891]]

  [[ 0.26737356  0.16508116]
   [ 0.01479433  0.11537937]]]


 [[[-0.03518109  0.1855519 ]
   [ 0.25363138 -0.07190578]]

  [[ 0.0926738   0.14208406]
   [-0.14850877  0.00415868]]

  [[ 0.05830792 -0.13324894]
   [ 0.10483494  0.07671355]]]


 [[[-0.23851964  0.1788559 ]
   [ 0.11026009 -0.01322468]]

  [[ 0.15628926 -0.22298165]
   [ 0.02409538 -0.24568257]]

  [[ 0.07694096  0.00656321]
   [-0.2602202  -0.14715835]]]]




[[2]
 [3]]
[[[[-0.07139529  0.04683082]
   [ 0.14715445  0.10146303]]

  [[ 0.20623404 -0.20414823]
   [ 0.01558907  0.11479891]]

  [[ 0.26737356  0.16508116]
   [ 0.01479433  0.11537937]]]


 [[[-0.03518109  0.1855519 ]
   [ 0.25363138 -0.07190578]]

  [[ 0.0926738   0.14208406]
   [-0.14850877  0.00415868]]

  [[ 0.05830792 -0.13324894]
   [ 0.104834



In [15]:
x = nn.Conv2d(3,2,2).weight.data.numpy()


print(x, "\n\n--------")


print(x[1,:,:,:], "\n\n---------")
print(x[:,1,:,:], "\n\n---------")

[[[[-0.11309339 -0.13252267]
   [ 0.10189015  0.23333943]]

  [[ 0.18838969 -0.20607772]
   [-0.04675874 -0.2579743 ]]

  [[-0.22457686 -0.12578286]
   [ 0.10464424 -0.02261609]]]


 [[[ 0.21070072  0.12532869]
   [-0.2725599  -0.10313283]]

  [[-0.13886876 -0.2513148 ]
   [-0.25815955  0.2037502 ]]

  [[ 0.18243036  0.06070128]
   [ 0.0784964  -0.14904013]]]] 

--------
[[[ 0.21070072  0.12532869]
  [-0.2725599  -0.10313283]]

 [[-0.13886876 -0.2513148 ]
  [-0.25815955  0.2037502 ]]

 [[ 0.18243036  0.06070128]
  [ 0.0784964  -0.14904013]]] 

---------
[[[ 0.18838969 -0.20607772]
  [-0.04675874 -0.2579743 ]]

 [[-0.13886876 -0.2513148 ]
  [-0.25815955  0.2037502 ]]] 

---------


# Channel Pruning in PyTorch


In [3]:
import numpy as np
import torch  

import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16, 16, 5)
        self.conv4 = nn.Conv2d(16, 10, 5)
        self.fc1   = nn.Linear(10*5*5, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv3(out))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv4(out))
        out = F.max_pool2d(out, 2)
        
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        return out
    
    
class TinyNet(nn.Module):
    def __init__(self, channels):
        super(TinyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, channels[0], 5)
        self.conv2 = nn.Conv2d(channels[0], channels[1], 5)
        self.conv3 = nn.Conv2d(channels[1], channels[2], 5)
        self.conv4 = nn.Conv2d(channels[2], channels[3], 5)
        self.fc1   = nn.Linear(channels[3] * 5 * 5, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv3(out))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv4(out))
        out = F.max_pool2d(out, 2)
        
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        return out
    

def calculate_threshold(weights, ratio):
    return np.percentile(np.array(torch.abs(weights)), ratio)

def sparsify(model, threshold=50.):
    for name, param in model.named_parameters():
        if 'weight' in name:
            threshold  = calculate_threshold(param.data, threshold)
            mask       = torch.gt(torch.abs(param), threshold).float()
            param.data = (param * mask).data
    return model

def sparsify_on_bn(model):
    '''
    Here we zero out whole planes where their batchnorm weight is 0
    1. Consider lists in pairs
    2. If conv followed by batchnorm - get nonzeros from batchnorm 
    3. Zero out whole conv filters
    '''
    
    for l1, l2 in zip(list(model.children()), list(model.children())[1:]):
        if isinstance(l1, nn.Conv2d) and isinstance(l2, bn.BatchNorm2dEx):
            zeros = argwhere_nonzero(l2.weight, batchnorm=True)
            l1[zeros] = 0.
            

def argwhere_nonzero(layer, batchnorm=False):
    indices=[]
    
              
    # for batchnorms we want to do the opposite
    if batchnorm:
        x = layer.data.cpu().numpy()
        indices = np.argwhere(x, x==0.) # <<- not sure about syntax
    else:
        for idx,w in enumerate(layer):
            if torch.sum(torch.abs(w)) != 0.:
                indices.append(idx)
        
    return indices


def prune_conv(indices, layer, follow=False):
    # follow tells us whether we need to prune input channels or output channels
    if not follow:
        # prune output channels
        layer.weight.data = layer.weight[indices].data
        layer.bias.data   = layer.bias[indices].data
    else: 
        # prune input channels
        layer.weight.data = layer.weight[:,indices].data
        
def prune_fc(indices, channel_size, layer, follow_conv=True):
    if follow_conv:
        # if we are following a conv layer we need to expand each index by the size of the plane
        indices = [item for sublist in list((map(lambda i : np.arange(i, (i+channel_size)), indices))) for item in sublist]
    
    fc_layer[indices]
        
def compress_convs(model):
    
    ls = list(model.children())

    channels = []
    nonzeros = []
    for l1, l2 in zip(ls, ls[1:]):
        # so now we have pairs of layers
        
        if isinstance(l1, nn.Conv2d):
            nonzeros = argwhere_nonzero(l1.weight)
            channels.append(len(nonzeros))
            
            prune_conv(nonzeros, l1)
            
            if isinstance(l2, nn.Conv2d):
                prune_conv(nonzeros, l2, follow=True)
            elif isinstance(l2, nn.Linear):
                channel_size = l1.kernel_size[0] * l1.kernel_size[1]
                prune_fc(indices, channel_size, l2, follow_conv=True)
            elif isinstance(l2, nn.BatchNorm2d):
                prune_fc(indices, 0, l2, follow_conv=False)
    
    
    
    new_model = TinyNet(channels)
    
    for original, compressed in zip(model.children(), new_model.children()):
        compressed.weight = original.weight
        compressed.bias   = original.bias
    
    return new_model
            

model = LeNet()    
#sparsify(model)
channels = compress_convs(model)


new_model = TinyNet(channels)

for original, compressed in zip(model.children(), new_model.children()):
    compressed.weight = original.weight
    compressed.bias   = original.bias
    
new_model

AttributeError: 'Conv2d' object has no attribute 'kernel_width'

In [36]:
fclayer = [0,1,2,3,4,5]

indices = [0]
channel_size=3

indices = [item for sublist in list((map(lambda i : np.arange(i, (i+channel_size)), indices))) for item in sublist]

indices

[0, 1, 2]