In [1]:
import time
import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.autograd as autograd

import torchvision.models as models

from agents.base import BaseAgent
from graphs.models.vgg import *
from prune.channel import *
from datasets.cifar100 import *

from utils.metrics import AverageMeter, cls_accuracy
from utils.misc import timeit, print_cuda_statistics


import argparse
import easydict
import matplotlib.pyplot as plt

from utils.config import *
from utils.count_model import *
from agents import *
from scipy.spatial import distance


In [2]:
config = easydict.EasyDict()
config.exp_name = "assemblenet_imagenet"
config.log_dir = os.path.join("experiments", config.exp_name, "logs/")
create_dirs([config.log_dir])
setup_logging(config.log_dir)

config.load_file = os.path.join("experiments", "vgg16_exp_imagenet_0", "checkpoints/checkpoint.pth")
config.cuda = True
config.gpu_device = 0
config.seed = 1
config.milestones =[5,10,15,20,25,30,35,40,45,50,55,60,65,70]
config.gamma = 0.95
config.img_size = 224
config.num_classes = 1000
# config.data_mode = "download"
config.data_mode = "image_folder"
config.data_dir = "C:/Users/USER001/"
config.data_loader_workers = 4
config.pin_memory = True
config.async_loading = True
config.batch_size = 1
config.async_loading = True
config.max_epoch = 10

cls_i = None
channel_importance = dict()

all_list = list()

named_modules_list = dict()
named_conv_list = dict()
original_conv_output = dict()

model_size = {}
model_length = {}
compress_rate = {}
distance_rate = {}
mat = {}
model = models.vgg16(pretrained = True)
filter_small_index = {}
filter_large_index = {}
similar_matrix = {}

# init graph
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0005,
                            nesterov=True)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 10],
                                           gamma=0.9)

device = torch.device("cuda")
model = model.to(device)
loss_fn = loss_fn.to(device)

# current_epoch = 0


current_iteration = 0
best_valid_acc = 0

named_modules_idx_list = dict()
named_modules_list = dict()
named_conv_list = dict()
all_module_list = dict()
named_conv_idx_list = dict()

original_conv_output = dict()

stayed_channels = dict()

data_loader = ImagenetDataLoader(config=config)  # data loader

i = 0
for idx, m in enumerate(model.features):
    if isinstance(m, torch.nn.Conv2d):
        named_modules_idx_list['{}.conv'.format(i)] = idx
        named_modules_list['{}.conv'.format(i)] = m
        named_conv_idx_list['{}.conv'.format(i)] = idx
        named_conv_list['{}.conv'.format(i)] = m
        i += 1

cls_i = 30
sub_data_loader = SpecializedImagenetDataLoader(config, cls_i)

# record conv output and make channel output size
# def record_conv_output(inputs):
inputs_data, _ = next(iter(sub_data_loader.part_train_loader))
inputs_data = inputs_data.cuda(non_blocking=config.async_loading)

i = 0
x = inputs_data
for m in model.features:
    x = m(x)
    if isinstance(m, torch.nn.Conv2d):
        original_conv_output['{}.conv'.format(i)] = x.data
        channel_importance['{}.conv'.format(i)] = torch.zeros(x.size())
        i += 1



[INFO]: Loading DATA.....
[INFO]: Loading DATA.....


In [3]:
model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [4]:
named_conv_list

{'0.conv': Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '1.conv': Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '2.conv': Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '3.conv': Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '4.conv': Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '5.conv': Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '6.conv': Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '7.conv': Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '8.conv': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '9.conv': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '10.conv': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '11.conv': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '12.conv': Conv2d(512, 512, kernel_size=(3,

In [5]:
for i in channel_importance.values():
    print(i.size())

torch.Size([1, 64, 224, 224])
torch.Size([1, 64, 224, 224])
torch.Size([1, 128, 112, 112])
torch.Size([1, 128, 112, 112])
torch.Size([1, 256, 56, 56])
torch.Size([1, 256, 56, 56])
torch.Size([1, 256, 56, 56])
torch.Size([1, 512, 28, 28])
torch.Size([1, 512, 28, 28])
torch.Size([1, 512, 28, 28])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])
torch.Size([1, 512, 14, 14])


In [6]:
def save_grad(idx):
    global grads

    def hook(grad):
        grads[idx] = grad

    return hook


def cal_importance(grads_list, outputs_list):
    global channel_importance
    for n, m in named_conv_list.items():
        # print(m)
        if isinstance(m, torch.nn.Conv2d):
            grad = grads_list[n]
            output = outputs_list[n]
            importance = (grad * output)
            channel_importance[n] += importance.data.cpu()


# cacculate importance using output * grad except for last data_loader because to skip dimension mismatch problem
iteration = 1
for inputs, labels in sub_data_loader.part_train_loader:
    if iteration < sub_data_loader.part_train_iterations:
        num_batch = inputs.size(0)
        outputs, grads = {}, {}
        inputs = inputs.cuda(non_blocking=config.async_loading)
        inputs.requires_grad = True

        x = inputs
        i = 0
        for m in model.features:
            x = m(x)
            if isinstance(m, torch.nn.Conv2d):
                outputs['{}.conv'.format(i)] = x
                outputs['{}.conv'.format(i)].register_hook(save_grad('{}.conv'.format(i)))
                i += 1
        else:
            x = x.view(num_batch, -1)
        x = model.classifier(x)

        y_hat = x
        y_hat[:, cls_i].backward(gradient=torch.ones_like(y_hat[:, cls_i]))

        cal_importance(grads, outputs)

    iteration += 1

In [7]:
grads.keys()

dict_keys(['12.conv', '11.conv', '10.conv', '9.conv', '8.conv', '7.conv', '6.conv', '5.conv', '4.conv', '3.conv', '2.conv', '1.conv', '0.conv'])

In [8]:
channel_importance.keys()

dict_keys(['0.conv', '1.conv', '2.conv', '3.conv', '4.conv', '5.conv', '6.conv', '7.conv', '8.conv', '9.conv', '10.conv', '11.conv', '12.conv'])

In [9]:
named_conv_list

{'0.conv': Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '1.conv': Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '2.conv': Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '3.conv': Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '4.conv': Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '5.conv': Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '6.conv': Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '7.conv': Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '8.conv': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '9.conv': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '10.conv': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '11.conv': Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 '12.conv': Conv2d(512, 512, kernel_size=(3,

In [29]:
def get_channel_similar(channel_importance_list, compress_rate, distance_rate):
    indice_stayed_list = {}
    for key, channel_weight in channel_importance_list.items():
        print('channel_pruned_num is', int(channel_weight.size()[1] * (1-compress_rate))) # 6
        print('similar_pruned_num is', int(channel_weight.size()[1] * distance_rate)) # 51 

        channel_weight = channel_weight.cuda()
        channel_pruned_num = int(channel_weight.size()[1] * (1-compress_rate))
        similar_pruned_num = int(channel_weight.size()[1] * distance_rate)
        channel_weight_vec = channel_weight.view(channel_weight.size(1), -1) # each channel
        norm2 = torch.norm(channel_weight_vec, 2, 1) # L2norm return 1 vector
        norm2_np = norm2.cpu().numpy()
        print(norm2_np)
        channel_large_index = norm2_np.argsort()[channel_pruned_num:]
        channel_small_index = norm2_np.argsort()[:channel_pruned_num]
        print('channel_large_index : ', channel_large_index, len(channel_large_index) )
        # distance using numpy function
        indices = torch.LongTensor(channel_large_index).cuda()
        channel_weight_vec_after_norm = torch.index_select(channel_weight_vec, 0, indices).cpu().numpy() # select large channel using index
        print('channel_weight_vec_after_norm is ' , channel_weight_vec_after_norm.shape)
        
        #for euclidean distance
        similar_matrix = distance.cdist(channel_weight_vec_after_norm,channel_weight_vec_after_norm,'euclidean' )
        print('similar_matrix is ', similar_matrix) # (58,58)
        similar_sum = np.sum(np.abs(similar_matrix), axis = 0)  
        print('similar_sum is ', similar_sum.shape) # 58

        # for distance similar : get the filter index with largest similarity = small distance
        similar_large_index = similar_sum.argsort()[similar_pruned_num : ] # [51:58]
        similar_small_index = similar_sum.argsort()[: similar_pruned_num]
        print('similar_large_index is ', similar_large_index)
        similar_index_for_channel = [channel_large_index[i] for i in similar_large_index]
        print('similar_large_index is ',len(similar_large_index))
        print('similar_small_index is',len(similar_small_index))
        print()
        indice_stayed_list[key] = similar_index_for_channel
    return indice_stayed_list

indice_stayed_list = get_channel_similar(channel_importance, 0.9, 0.8)
    

channel_pruned_num is 6
similar_pruned_num is 51
norm size is  (64,)
[ 8.843054   16.4678     20.465994    2.1409807  25.955065    1.4569151
 20.467567    3.6182199   5.4505134  10.542476    7.0324416   5.465562
  8.386919    1.7527492  12.133921    1.5574309  12.515965    6.7304826
  5.484535   23.187952   13.716554   13.823424    5.594946   10.095293
  3.9778738   4.0980024   2.3475268   3.2530458  15.189196    5.2836237
  3.000581    3.6320965   7.1418276  10.641338   23.222857    2.386313
 13.863937    2.6403618  11.402835    7.8919635   7.4937873   2.1737514
 12.850243    1.2160556   5.6978054   2.449034    9.564865    7.828643
  9.971663    2.4093535   0.63696325 22.200565    2.034619   14.695812
  1.0921729   9.197013    8.030674   29.293228   27.775875   19.304482
  4.3623586   2.892663    0.9491922   7.288437  ]
channel_large_index :  [13 52  3 41 26 35 49 45 37 61 30 27  7 31 24 25 60 29  8 11 18 22 44 17
 10 32 63 40 47 39 56 12  0 55 46 48 23  9 33 38 14 16 42 20 21 36 53 2

norm size is  (256,)
[ 6.138709   4.3469095  6.70702    5.1417217  2.6298952  6.6414156
  6.543254   1.9184171  8.517323   5.7260976  6.728763  13.266199
  4.602842  12.519095   3.7278178 13.500653  21.631618   8.823098
  9.917969  17.15443    4.2480607  6.528798  11.75193    9.008531
  5.4222503  4.1144485  3.6032007  8.485818  12.5564     8.541261
 60.530075  12.883948   5.6982965  9.523921   6.882387   7.2125716
  9.652905  11.689086   6.0770655  6.3261166  7.134543   9.084133
 14.662647  22.485643  15.719734   1.1214216  6.2966437  9.198606
  4.5737276  5.1935806 50.968235   5.5034285  6.765223   6.5646396
  3.7192242  8.084684   5.792927   2.9418     4.684833   3.1090105
  4.824629   4.751178   6.439912   3.1218286 16.52713    5.250758
 34.12794    2.998714  13.288203   5.7153964  3.32313   14.3116
  8.457389  15.362759   4.4844885  2.654795   4.098627   3.882608
 13.648926   2.1757476  2.762608   9.215157   7.6309614  6.4094367
 13.256072   4.657688   8.374654   3.3370187  7.6918

similar_matrix is  [[ 0.          5.71526563  5.53274147 ... 39.11500024 42.70163452
  86.71643867]
 [ 5.71526563  0.          5.1815611  ... 41.39207749 42.95054551
  86.63649671]
 [ 5.53274147  5.1815611   0.         ... 40.66736637 42.86104784
  87.2096842 ]
 ...
 [39.11500024 41.39207749 40.66736637 ...  0.         50.33754031
  87.4225745 ]
 [42.70163452 42.95054551 42.86104784 ... 50.33754031  0.
  76.96320525]
 [86.71643867 86.63649671 87.2096842  ... 87.4225745  76.96320525
   0.        ]]
similar_sum is  (231,)
similar_large_index is  [199 194 208 204 209 207 210 214 211 216 215 219 218 212 213 220 217 221
 222 223 224 225 226 227 228 229 230]
similar_large_index is  27
similar_small_index is 204

channel_pruned_num is 25
similar_pruned_num is 204
norm size is  (256,)
[ 10.084057    5.515793    7.9231424  13.730294   16.42403     5.949347
   6.792806    7.495636    6.262456    8.552223    8.533479    8.187577
   6.985281   14.03162   228.13712     5.645352    6.8072786  11.469

similar_matrix is  [[  0.           3.76049852   3.57430445 ...  65.94901982 185.02462511
  234.02900378]
 [  3.76049852   0.           3.0470156  ...  66.27335349 183.87962667
  234.1686389 ]
 [  3.57430445   3.0470156    0.         ...  66.26809041 184.78056769
  233.93927216]
 ...
 [ 65.94901982  66.27335349  66.26809041 ...   0.         181.95839117
  235.06908312]
 [185.02462511 183.87962667 184.78056769 ... 181.95839117   0.
  269.381722  ]
 [234.02900378 234.1686389  233.93927216 ... 235.06908312 269.381722
    0.        ]]
similar_sum is  (461,)
similar_large_index is  [413 414 418 417 415 409 411 419 410 422 420 412 421 425 416 426 424 427
 429 428 430 423 431 433 434 435 432 438 436 440 442 443 437 439 444 441
 445 446 447 449 448 450 451 452 453 454 455 456 457 458 459 460]
similar_large_index is  52
similar_small_index is 409

channel_pruned_num is 51
similar_pruned_num is 409
norm size is  (512,)
[  3.344481    5.508268    5.897479   21.456322   13.432796    2.2649255
   4

similar_matrix is  [[  0.           3.86997526   3.36483257 ...  93.48312569  96.42057924
  103.1947742 ]
 [  3.86997526   0.           4.15274311 ...  93.59112904  95.75717097
  102.81547136]
 [  3.36483257   4.15274311   0.         ...  91.74035585  97.95136116
  102.39089215]
 ...
 [ 93.48312569  93.59112904  91.74035585 ...   0.         181.90327013
  139.56382061]
 [ 96.42057924  95.75717097  97.95136116 ... 181.90327013   0.
  133.72073753]
 [103.1947742  102.81547136 102.39089215 ... 139.56382061 133.72073753
    0.        ]]
similar_sum is  (461,)
similar_large_index is  [406 411 413 407 414 415 416 412 417 419 418 420 423 425 422 421 424 427
 426 428 430 432 433 429 431 434 435 436 439 441 437 438 440 442 444 443
 445 446 447 448 449 450 451 452 453 454 455 456 457 459 458 460]
similar_large_index is  52
similar_small_index is 409

channel_pruned_num is 51
similar_pruned_num is 409
norm size is  (512,)
[  4.84364      4.346143    14.320311     3.7209709   18.60027
   5.2408156

In [10]:
indice_stayed_list = {}
for key, channel_weight in channel_importance.items():

    channel_weight = channel_weight.cuda()
    channel_weight_vec = channel_weight.view(channel_weight.size(1), -1) # each channel
    norm2 = torch.norm(channel_weight_vec, 2, 1) # L2norm return 1 vector
    norm2_np = norm2.cpu().numpy()

    # distance using numpy function
    indices = torch.LongTensor(norm2_np).cuda()
    channel_weight_vec_after_norm = torch.index_select(channel_weight_vec, 0, indices).cpu().numpy() # select large channel using index
    print('channel_weight_vec_after_norm is ' , channel_weight_vec_after_norm.shape)

    #for euclidean distance
    similar_matrix = distance.cdist(channel_weight_vec_after_norm,channel_weight_vec_after_norm,'euclidean' )
    print('similar_matrix is ', similar_matrix) # (58,58)
    similar_sum = np.sum(np.abs(similar_matrix), axis = 0)  
    print('similar_sum is ', similar_sum.shape) # 58
    print()
    indice_stayed_list[key] = similar_index_for_channel
    break


channel_weight_vec_after_norm is  (64, 50176)
similar_matrix is  [[ 0.         13.75109252 14.72776517 ... 21.18812208  8.89975291
   6.08517054]
 [13.75109252  0.         18.33925446 ... 24.30470696 15.35590649
  12.68703811]
 [14.72776517 18.33925446  0.         ... 17.09473424 14.67529452
  14.23844201]
 ...
 [21.18812208 24.30470696 17.09473424 ...  0.         21.42707064
  20.666945  ]
 [ 8.89975291 15.35590649 14.67529452 ... 21.42707064  0.
   8.82143317]
 [ 6.08517054 12.68703811 14.23844201 ... 20.666945    8.82143317
   0.        ]]
similar_sum is  (64,)



NameError: name 'similar_index_for_channel' is not defined

In [11]:
similar_matrix.shape

(64, 64)

In [None]:
k = 6
stayed_channel_idx = []

first_random = np.random.choice(similar_matrix.shape()[0], 1) # 3
stayed_channel_idx.append(first_random) # [3]
max_idx = similar_matrix[first_random].argsort()[-1] # 6
stayed_channel_idx.append(max_idx) # [3,6]

similar_matrix[first_random][max_idx] = 0
similar_matrix[max_idx][first_random] = 0

while len(stayed_channel_idx) < k:
    next_idx = similar_matrix[max_idx].argsort()[-1] # max_idx = 6, next_idx = 4
    similar_matrix[first_random][next_idx] = 0
    similar_matrix[next_idx][first_random] = 0

    

In [273]:
first_val

53

In [15]:
def get_channel_similar_using_kmeans(channel_importance_list, compress_rate, distance_rate):
    indice_stayed_list = {}
    for key, channel_weight in channel_importance_list.items():
        print('channel_pruned_num is', int(channel_weight.size()[1] * (1-compress_rate))) # 6
        print('similar_pruned_num is', int(channel_weight.size()[1] * distance_rate)) # 51 

        channel_weight = channel_weight.cuda()
        channel_pruned_num = int(channel_weight.size()[1] * (1-compress_rate))
        similar_pruned_num = int(channel_weight.size()[1] * distance_rate)
        channel_weight_vec = channel_weight.view(channel_weight.size(1), -1) # each channel
        norm2 = torch.norm(channel_weight_vec, 2, 1) # L2norm return 1 vector
        norm2_np = norm2.cpu().numpy()
        channel_large_index = norm2_np.argsort()[channel_pruned_num:]
        channel_small_index = norm2_np.argsort()[:channel_pruned_num]
        print('channel_large_index : ', channel_large_index, len(channel_large_index))
        
        # distance using numpy function
        indices = torch.LongTensor(channel_large_index).cuda()
        channel_weight_vec_after_norm = torch.index_select(channel_weight_vec, 0, indices).cpu().numpy() # select large channel using index
        print('channel_weight_vec_after_norm is ' , channel_weight_vec_after_norm.shape)
        
        #for euclidean distance
        similar_matrix = distance.cdist(channel_weight_vec_after_norm,channel_weight_vec_after_norm,'euclidean' )
        print('similar_matrix is ', similar_matrix.shape) # (58,58)
        similar_sum = np.sum(np.abs(similar_matrix), axis = 0)  
        print('similar_sum is ', similar_sum.shape) # 58

        # for distance similar : get the filter index with largest similarity = small distance
        similar_large_index = similar_sum.argsort()[similar_pruned_num : ] # [51:58]
        similar_small_index = similar_sum.argsort()[: similar_pruned_num]
        print('similar_large_index is ', similar_large_index)
        similar_index_for_channel = [channel_large_index[i] for i in similar_large_index]
        print('similar_large_index is ',len(similar_large_index))
        print('similar_small_index is',len(similar_small_index))
        print()
        indice_stayed_list[key] = similar_index_for_channel
        
    return indice_stayed_list

indice_stayed_list = get_channel_similar_using_kmean(channel_importance, 0.9, 0.8)
    

TypeError: unsupported operand type(s) for -: 'int' and 'dict'

In [11]:
print(indice_stayed_list)

{'0.conv': [50, 62, 54, 43, 15, 5, 13, 52, 41, 3, 35, 26, 45, 37, 49, 61, 30, 27, 7, 31, 24, 25, 60, 29, 11, 8, 44, 22, 18, 10, 32, 40], '1.conv': [15, 33, 26, 23, 21, 1, 29, 51, 27, 61, 3, 28, 37, 43, 17, 19, 6, 12, 30, 55, 31, 0, 54, 60, 44, 46, 38, 9, 24, 52, 7, 13], '2.conv': [102, 9, 88, 45, 54, 68, 28, 33, 51, 15, 24, 66, 65, 38, 111, 104, 61, 18, 40, 109, 37, 22, 14, 27, 86, 5, 106, 62, 97, 78, 103, 13, 1, 8, 119, 34, 44, 55, 98, 110, 79, 82, 80, 91, 74, 20, 42, 21, 32, 96, 120, 90, 52, 127, 43, 60, 112, 6, 73, 81, 46, 10, 39, 72], '3.conv': [15, 77, 36, 14, 29, 99, 61, 90, 126, 39, 63, 107, 84, 31, 68, 73, 71, 9, 5, 50, 88, 82, 108, 54, 16, 0, 106, 3, 25, 69, 59, 112, 86, 57, 114, 40, 47, 104, 30, 18, 123, 35, 95, 122, 113, 62, 67, 53, 13, 45, 93, 37, 20, 92, 94, 87, 115, 34, 4, 105, 89, 12, 51, 66], '4.conv': [45, 93, 171, 110, 7, 149, 89, 194, 79, 167, 125, 75, 253, 189, 95, 252, 236, 80, 4, 180, 152, 59, 106, 136, 201, 67, 57, 213, 163, 70, 217, 26, 87, 63, 102, 145, 220, 10

In [12]:
for i,(name, m) in enumerate(named_conv_list.items()):
    if isinstance(m, torch.nn.Conv2d):
        if str(i + 1) + '.conv' in named_conv_list:
            next_m = named_modules_list[str(i + 1) + '.conv']
        else:
            next_m = model.classifier[0]
        indices_stayed = indice_stayed_list[name]
        module_surgery(m, next_m, indices_stayed)
        if not isinstance(next_m, torch.nn.Linear):
            next_output_features = original_conv_output[str(i + 1) + '.conv']
            next_m_idx = named_conv_idx_list[str(i + 1) + '.conv']
            pruned_next_inputs_features = model.features[:next_m_idx](inputs_data)
            #weight_reconstruction(next_m, pruned_next_inputs_features, next_output_features,use_gpu=cuda)
    stayed_channels[str(i) + '.conv'] = set(indices_stayed)

In [13]:
stayed_channels

{'0.conv': {3,
  5,
  7,
  8,
  10,
  11,
  13,
  15,
  18,
  22,
  24,
  25,
  26,
  27,
  29,
  30,
  31,
  32,
  35,
  37,
  40,
  41,
  43,
  44,
  45,
  49,
  50,
  52,
  54,
  60,
  61,
  62},
 '1.conv': {0,
  1,
  3,
  6,
  7,
  9,
  12,
  13,
  15,
  17,
  19,
  21,
  23,
  24,
  26,
  27,
  28,
  29,
  30,
  31,
  33,
  37,
  38,
  43,
  44,
  46,
  51,
  52,
  54,
  55,
  60,
  61},
 '2.conv': {1,
  5,
  6,
  8,
  9,
  10,
  13,
  14,
  15,
  18,
  20,
  21,
  22,
  24,
  27,
  28,
  32,
  33,
  34,
  37,
  38,
  39,
  40,
  42,
  43,
  44,
  45,
  46,
  51,
  52,
  54,
  55,
  60,
  61,
  62,
  65,
  66,
  68,
  72,
  73,
  74,
  78,
  79,
  80,
  81,
  82,
  86,
  88,
  90,
  91,
  96,
  97,
  98,
  102,
  103,
  104,
  106,
  109,
  110,
  111,
  112,
  119,
  120,
  127},
 '3.conv': {0,
  3,
  4,
  5,
  9,
  12,
  13,
  14,
  15,
  16,
  18,
  20,
  25,
  29,
  30,
  31,
  34,
  35,
  36,
  37,
  39,
  40,
  45,
  47,
  50,
  51,
  53,
  54,
  57,
  59,
  61,
  62,
  63,


In [None]:
volatile gpu-util
