In [118]:
import torch
from torchvision import models

import pprint


##############################################################################################################
# def getModelLayers(model): is a function that returns the layers of a model in a cleaner, more organized form
#                            that is more efficient and easier to work with for pruning purposes
#
#
# The format of the returned data will be a LIST of TUPLES, as follows:
#    cleanedLayers[i] represents the i_th layer (that needs to be pruned) counting from the output
#         ---cleanedLayers[0] is the last layer of the model that needs to be pruned:
#            -Therefore if the last layer is a convolutional layer, it will be stored...
#            -However, if the last layer is a batch norm layer (doesn't need to be pruned) it will not store it.
#            -If the last layer is batch norm and the second-to-last layer is a convolutional layer, cleanedLayers[0] will 
#            store the second-to-last layer and so forth.
#
#    At cleanedLayers[i] will be a tuple of the following data:
#         ---cleanedLayers[i][0]: Name of the i_th to last layer to be pruned (string)
#         ---cleanedLayers[i][1]: weights of the i_th to last layer, of type tensor
#         ---cleanedLayers[i][2]: biases of the i_th to last layer, of type tensor
#
# NOTES:
#         ---memory is shared between the model that is passed into getModelLayers and the returned list,
#            therefore any changed values are permanently changed unless a deep copy is made

def getModelLayers(model, excluded_keywords=('bn', 'downsample', 'pool')):
    """Collect the prunable layers of ``model`` as a list ordered from output to input.

    The model's ``state_dict()`` is scanned for entries whose final dotted
    component is exactly ``weight`` or ``bias``. Entries whose full name contains
    any of ``excluded_keywords`` are skipped. Every remaining ``weight`` entry
    becomes one layer, paired with the ``bias`` entry that shares its prefix.

    Args:
        model: any object exposing ``state_dict()`` that returns an ordered
            mapping of parameter name -> tensor (e.g. a ``torch.nn.Module``).
        excluded_keywords: substrings that mark a parameter as "not to be
            pruned". The match is a plain substring test on the whole parameter
            name, so ``'pool'`` also skips a layer called ``branch_pool.conv``.
            Pass a different tuple for architectures where the default is too
            aggressive.

    Returns:
        A list of ``(name, weights, biases)`` tuples where
          * ``name`` is the parameter prefix including its trailing dot
            (e.g. ``'classifier.6.'``),
          * ``weights`` is the layer's weight tensor,
          * ``biases`` is the layer's bias tensor, or an empty ``FloatTensor``
            when the layer has no bias.
        Index 0 is the layer closest to the output.

    Note:
        The tensors are the ones handed out by ``state_dict()``; they are not
        copied here, so in-place edits affect the model unless a deep copy is made.
    """
    weights_by_prefix = dict()
    biases_by_prefix = dict()

    for param_name, tensor in model.state_dict().items():
        #skip everything that belongs to a layer we must not prune
        if any(keyword in param_name for keyword in excluded_keywords):
            continue

        #only the last dotted component decides whether this is a weight or a bias;
        #buffers such as running_mean / num_batches_tracked fall through here
        leaf = param_name.rpartition('.')[2]
        if leaf not in ('weight', 'bias'):
            continue

        #keep the trailing dot so the returned names match the historical format
        prefix = param_name[:-len(leaf)]
        if leaf == 'weight':
            weights_by_prefix[prefix] = tensor
        else:
            biases_by_prefix[prefix] = tensor

    cleanedLayers = []
    for prefix, weights in weights_by_prefix.items():
        if prefix in biases_by_prefix:
            biases = biases_by_prefix[prefix]
        else:
            #layer has no bias: hand back an empty tensor so tuples always have 3 items
            biases = torch.FloatTensor([])
        cleanedLayers.append((prefix, weights, biases))

    #reverse the list so cleanedLayers[0] is the last layer of the network -- aka the layer closest to output
    cleanedLayers.reverse()
    return cleanedLayers


##############################################################################################################

#Let's use vgg16 because it is small
model = models.vgg16()

#modelLayers is a notebook-level name that the following cells reuse
modelLayers = getModelLayers(model)

#An example of how to iterate through modelLayers:
#each entry unpacks into (name, weights, biases); only the layer name is printed
for name, weights, biases in modelLayers:
        print(name)  #weights and biases are left out to keep the output short

classifier.6.
classifier.3.
classifier.0.
features.28.
features.26.
features.24.
features.21.
features.19.
features.17.
features.14.
features.12.
features.10.
features.7.
features.5.
features.2.
features.0.


In [119]:
#Here is another way: walk the list by index and unpack each tuple explicitly
for i in range(len(modelLayers)):
    name, weights, biases = modelLayers[i]          
    print(name)  #weights and biases are left out to keep the output short

classifier.6.
classifier.3.
classifier.0.
features.28.
features.26.
features.24.
features.21.
features.19.
features.17.
features.14.
features.12.
features.10.
features.7.
features.5.
features.2.
features.0.


In [120]:
#and another: enumerate() gives the position (0 = closest to the output) next to the unpacked tuple
for i, (name, weights, biases) in enumerate(modelLayers):
    print(i, name)  #weights and biases are left out to keep the output short

0 classifier.6.
1 classifier.3.
2 classifier.0.
3 features.28.
4 features.26.
5 features.24.
6 features.21.
7 features.19.
8 features.17.
9 features.14.
10 features.12.
11 features.10.
12 features.7.
13 features.5.
14 features.2.
15 features.0.


In [121]:
#Finally, run the same check on several architectures to see whether the function still works

#vgg16: reuse the list that was already extracted into modelLayers
print(
    'number of layers: {}'.format(
        len(modelLayers)
    )
)

number of layers: 16


In [122]:
#resnet50: count the layers selected for pruning
print(
    'number of layers: {}'.format(
        len(getModelLayers(models.resnet50()))
    )
)

number of layers: 50


In [123]:
#resnet101: count the layers selected for pruning
print(
    'number of layers: {}'.format(
        len(getModelLayers(models.resnet101()))
    )
)

number of layers: 101


In [124]:
#resnet152: count the layers selected for pruning
print(
    'number of layers: {}'.format(
        len(getModelLayers(models.resnet152()))
    )
)

number of layers: 152


In [125]:
#alexnet: count the layers selected for pruning
print(
    'number of layers: {}'.format(
        len(getModelLayers(models.alexnet()))
    )
)

number of layers: 8


In [126]:
#inception_v3: count the layers selected for pruning
print(
    'number of layers: {}'.format(
        len(getModelLayers(models.inception_v3()))
    )
)

number of layers: 89


In [127]:
#I think the inception_v3 result is wrong. To show that my function probably doesn't work
#across all models, let's take a look at the layer names it returns.
#NOTE(review): the printed names contain no `branch_pool` entries -- presumably the 'pool'
#substring filter in getModelLayers drops those convolutions; confirm against the model definition.
#This rebinds modelLayers, so it no longer refers to the vgg16 layers from the earlier cells.
modelLayers = getModelLayers(models.inception_v3())
for name, _, _ in modelLayers:
    print(name)

fc.
Mixed_7c.branch3x3dbl_3b.conv.
Mixed_7c.branch3x3dbl_3a.conv.
Mixed_7c.branch3x3dbl_2.conv.
Mixed_7c.branch3x3dbl_1.conv.
Mixed_7c.branch3x3_2b.conv.
Mixed_7c.branch3x3_2a.conv.
Mixed_7c.branch3x3_1.conv.
Mixed_7c.branch1x1.conv.
Mixed_7b.branch3x3dbl_3b.conv.
Mixed_7b.branch3x3dbl_3a.conv.
Mixed_7b.branch3x3dbl_2.conv.
Mixed_7b.branch3x3dbl_1.conv.
Mixed_7b.branch3x3_2b.conv.
Mixed_7b.branch3x3_2a.conv.
Mixed_7b.branch3x3_1.conv.
Mixed_7b.branch1x1.conv.
Mixed_7a.branch7x7x3_4.conv.
Mixed_7a.branch7x7x3_3.conv.
Mixed_7a.branch7x7x3_2.conv.
Mixed_7a.branch7x7x3_1.conv.
Mixed_7a.branch3x3_2.conv.
Mixed_7a.branch3x3_1.conv.
AuxLogits.fc.
AuxLogits.conv1.conv.
AuxLogits.conv0.conv.
Mixed_6e.branch7x7dbl_5.conv.
Mixed_6e.branch7x7dbl_4.conv.
Mixed_6e.branch7x7dbl_3.conv.
Mixed_6e.branch7x7dbl_2.conv.
Mixed_6e.branch7x7dbl_1.conv.
Mixed_6e.branch7x7_3.conv.
Mixed_6e.branch7x7_2.conv.
Mixed_6e.branch7x7_1.conv.
Mixed_6e.branch1x1.conv.
Mixed_6d.branch7x7dbl_5.conv.
Mixed_6d.branch7x7dbl_

In [114]:
#These layers all look like valid layers but I am fairly certain that they aren't.
#      --- Some of these layers shouldn't be pruned

#I'm pointing this out because we have to remember not to blindly use this function on any model architecture.

#For our purposes it should work, as we are only using resnet101, alexnet, and vgg16,
#      --- and this function works for them