In [6]:
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
import torch_pruning as tp
from functools import reduce

def get_module_by_name(model, access_string):
    names = access_string.split(sep='.')
    return reduce(getattr, names, model)

model = resnet50().eval()
bigger_model = resnet50(weights=ResNet50_Weights).eval()

state = torch.load('exp_2_model_resnet50_prune_0.125_single_epochs_10_epoch9pruned.pth', map_location='cpu')
tp.load_state_dict(model, state_dict=state)
model.eval()

model = model.module

# Iterate over model parameters
for name, params in model.named_parameters():
    new_tensor = torch.ones_like(params.data)
    params.data = new_tensor

# # Iterate over model parameters
# for name, params in model.named_parameters():
#     if name == "conv1.weight":
#         new_tensor = torch.ones_like(params.data[1, :, :, :])
#         params.data[1, :, :, :] = new_tensor 

# model.conv1.weight.data[1, :, :, :]



In [118]:
import pickle
with open('0.125_history_exp2', 'rb') as file:
    history = pickle.load(file)

In [119]:
layers_affected = len(history)
layers_affected_per_step = int(layers_affected / 1)
step_history = [history[i:i+layers_affected_per_step] for i in range(0, layers_affected, layers_affected_per_step)]

In [None]:
# # Iterate over model parameters
# for name, params in model.named_parameters():
#     if name == "conv1.weight":
#         new_tensor = torch.ones_like(params.data[1, :, :, :])
#         params.data[1, :, :, :] = new_tensor 

# model.conv1.weight.data[1, :, :, :]

def rebuild_model(tuned_model, bigger_model, step_history):
    for i, history in enumerate(reversed(step_history[0])):

        # loop through each layer changed in pruning
        for pruned_layer_name, b, channels_removed in reversed(history):

            # loop through the layers of the larger model (same number of layers, different channel width)
            for layer_name, bigger_layer_params in bigger_model.named_parameters():

                skipped = 0

                if"module."+layer_name == pruned_layer_name+".weight" and layer_name == "conv1.weight":

                        # get copy of layers
                        tuned_layer = get_module_by_name(tuned_model, layer_name[:-7])
                        bigger_layer = get_module_by_name(bigger_model, layer_name[:-7])
                        print(channels_removed)

                        # loop throught the channels of the bigger model
                        for idx in range(bigger_layer.out_channels):

                            # check if the channel has been dropped
                            if idx in channels_removed:
                                # if channel was dropped, do not copy weights from smaller tuned model
                                # print("Channel was skipped")
                                skipped += 1

                            else:
                                # copy weights from tuned model to larger model
                                if "layer" not in layer_name:
                                    print(layer_name, idx, idx-skipped)
                                    bigger_layer_params.data[idx,:, : ,:] = tuned_layer.weight.data[idx-skipped,:, : ,:]

                                else: # for conv layers with reshape of both input and output

                                    # bigger_layer_params.requires_grad_(False)
                                    skipped_j = 0

                                    if (bigger_layer.in_channels - tuned_layer.in_channels) == len(channels_removed):

                                        for idx_j in range(bigger_layer.in_channels):

                                            if idx_j in channels_removed:
                                                # if channel was dropped, do not copy weights from smaller tuned model
                                                skipped_j += 1
                                            else:
                                                bigger_layer_params.data[idx,idx_j, : ,:] = tuned_layer.weight.data[idx-skipped,idx_j-skipped_j, : ,:]
    return bigger_model

bm = rebuild_model(model, bigger_model, step_history)

In [5]:
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
import torch_pruning as tp
from functools import reduce
import pickle

def get_module_by_name(model, access_string):
    names = access_string.split(sep='.')
    return reduce(getattr, names, model)


model = resnet50().eval()
bigger_model = resnet50(weights=ResNet50_Weights).eval()

state = torch.load(
    '../src/exp_2_model_resnet50_pruned_125_july.pth', map_location='cpu')
tp.load_state_dict(model, state_dict=state)
model.eval()
model = model.module

with open('../src/history_exp2_125_july', 'rb') as file:
    history_steps = pickle.load(file)


def get_layer_history(layer, history_steps):
    for step in history_steps:
        for name, b, channels in step:
            if name == "module."+layer:
                return channels


def get_module_by_name(model, access_string):
    names = access_string.split(sep='.')
    return reduce(getattr, names, model)

    # print(original_layer.weight.data.shape)
    # print(pruned_layer.weight.data.shape)

    # todo add loop for all out channels

    # out 256 => 224 32 out channels, these removed channels use original weights when rebuilding
    # in 256 => 224 32 in channels removed on non pruned channels, for each non pruned channel, find in channels that are pruned also


def get_layer_in_channel_history(original, pruned, layer, pruned_out_channels):
    original_layer = get_module_by_name(original, layer)
    pruned_layer = get_module_by_name(pruned, layer)
    skipped = 0  # adjustment to match out_channel between original and pruned model of different shapes
    pruned_in_channels_history = []

    for out_channel_idx in range(original_layer.out_channels):
        not_pruned_in_channels = []  # in channels pruned per out channel
        if out_channel_idx in pruned_out_channels:
            # the out_channel is completely pruned
            skipped += 1
        else:
            for in_channel_i in range(original_layer.in_channels):
                # the out_channel is partially pruned, loop through the in channels 
                # and find which idx have been pruned for each non-pruned out channel
                for in_channel_j in range(pruned_layer.in_channels):
                    # the output channel exists in both pruned and original model
                    if torch.equal(original_layer.weight.data[out_channel_idx, in_channel_i, :, :], original_layer.weight.data[out_channel_idx-skipped, in_channel_j, :, :]):
                        not_pruned_in_channels.append(in_channel_j)
                        continue
                        # in_channel_j of the pruned layer matches weights in the original layer, i.e not pruned
        
        all_channels = list(range(original_layer.in_channels))
        pruned_in_channels= [x for x in all_channels if x not in not_pruned_in_channels]
        pruned_in_channels_history.append([layer, pruned_in_channels])
        break # the input channels dropped are the same for each output channel
    return pruned_in_channels_history

pruned_out_channels = get_layer_history('layer2.2.conv1', history_steps)
in_history = get_layer_in_channel_history(bigger_model, model, 'layer2.2.conv1', pruned_out_channels)

for step in history_steps:
    for layer_name, b, out_channels_pruned in step:
        # print(layer_name[7:])
        in_history = get_layer_in_channel_history(bigger_model, model, layer_name[7:], pruned_out_channels)
        for idx, in_channels in in_history:
            print(layer_name, idx, in_channels)

module.layer4.0.downsample.0 0 [896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023]
module.layer3.0.downsample.0 0 [448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 50

In [None]:

# histories overview
# out_channel history layer_name, [channels pruned]
# in_channel history layer_name, out_channel, [channels_pruned] (do not include out_channels that have been pruned)

def find_missing_channel_indices(original, pruned, layer):
    conv_layer1 = get_module_by_name(original, layer)
    conv_layer2 = get_module_by_name(pruned, layer)

    print(conv_layer1.weight.data.shape)
    print(conv_layer2.weight.data.shape)

    not_found = []
    for channel_i in range(conv_layer1.out_channels):
        found_match = False
        for channel_j in range(conv_layer2.out_channels):
            if torch.equal(conv_layer1.weight.data[channel_i], conv_layer2.weight.data[channel_j]):
                found_match = True
                break

        if not found_match:
            not_found.append(channel_i)

    return not_found

# for step in history:
#     for name, b, channels in step:
#         pruned_channels = find_missing_channel_indices(bigger_model, model, name[7:])
#         print(name, len(pruned_channels), len(channels))

# pruned_channels = find_missing_channel_indices(bigger_model, model, "layer1.0.conv2")
# check input channels


# layer_name = "layer3.5.conv1"
out_history = get_layer_history(layer_name, history)
print(len(out_history), out_history)

# find the history of how the input channels are removed for each layer
found = find_missing_channel_indices_input(
    bigger_model, model, layer_name, out_history)
len(found)

# make input channel history per output channel

for step in history_steps:
    for layer_name, b, channels in step:
        pruned_channels = find_missing_channel_indices_input(
            bigger_model, model, layer_name, channels)
        print(name, len(pruned_channels), len(channels))

# rebuilt model with output history and input history (change needed for in channels part)


In [None]:
h = "module.layer4.0.downsample.0,2,61,93,100,108,112,143,275,311,338,468,471,583,605,619,653,695,831,839,854,892,898,991,1058,1099,1154,1257,1258,1288,1312,1427,1462,1465,1477,1542,1561,1748,1772,1827,1880,1925,1993,2040,2035,351,305,715,1727,1014,1694,327,1201,1532,570,1530,1555,167,751,1193,1736,1888,1344,1077,2046,1492,745,15,1009,449,1578,309,1906,847,1980,1850,1253,616,1929,273,1444,1156,623,1568,1592,558,2020,861,1126,560,456,794,1441,22,66,663,706,526,478,32,1868,675,660,1158,610,1508,1819,1653,929,106,1428,1262,1144,1326,497,195,1817,94,362,218,1974,968,1349,16,1693,1322,833,1051,950,1743,1900,1292,204,1570,563,1371,955,303,743,80,43,1192,734,622,801,1843,1460,1087,1496,296,808,1604,975,1510,1932,1820,158,617,301,1124,1211,1834,850,1625,1613,1971,1718,1829,7,1898,331,144,726,54,1393,427,1595,128,382,2003,307,101,1686,1645,1055,1307,194,771,868,453,688,77,1270,472,415,1440,203,1681,705,1429,909,1676,51,1680,1674,939,2042,353,974,1093,1120,412,1081,474,1692,1769,1608,486,687,1187,672,2022,685,1770,247,1357,1600,279,1003,1347,1249,1028,863,691,521,1598,895,910,1849,293,461,83,107,300,809,1699,737,1944,782,754,1167,1658,1046,207,1414,302,1301,module.layer3.0.downsample.0,962,1023,853,1018,544,657,197,486,170,412,644,303,491,18,778,91,729,501,906,151,500,980,51,60,759,444,297,194,45,741,96,1017,38,609,419,966,1011,803,16,123,390,662,923,927,685,487,514,786,83,149,511,656,590,595,118,728,825,928,482,744,407,132,881,352,732,709,873,747,388,331,541,92,536,772,105,478,365,556,647,37,475,579,610,863,801,202,340,800,329,985,619,99,71,325,205,776,47,836,766,817,665,457,326,819,93,371,601,589,620,114,706,433,199,576,701,330,748,678,716,683,4,14,900,376,543,72,311,533,module.layer2.0.downsample.0,387,246,243,404,176,412,347,34,100,11,163,24,114,98,371,424,66,96,363,53,379,43,495,361,438,326,207,509,166,311,85,33,353,188,382,473,442,309,58,328,192,372,339,60,0,131,462,184,147,279,465,72,142,10,81,285,269,177,102,381,504,447,265,483,module.layer1.0.downsample.0,144,77,108,214,150,58,112,131,103,92,156,29,235,233,31,10,36,225,231,168,63,245,68,0,196,97,218,135,78,60,21,105,module.conv1,13,42,44,52,57,11,31,59,module.layer1.0.conv2,0,1,9,14,19,22,28,59,module.layer1.0.conv1,1,7,18,20,21,23,27,33,module.layer1.1.conv2,6,17,25,41,44,38,40,45,module.layer1.1.conv1,2,4,9,13,15,19,23,30,module.layer1.2.conv2,11,4,7,57,8,24,51,43,module.layer1.2.conv1,39,49,2,11,63,19,45,52,module.layer2.0.conv2,35,119,90,3,12,29,4,69,53,120,99,51,111,110,115,38,module.layer2.0.conv1,112,92,63,89,104,68,8,14,122,61,66,2,110,12,121,86,module.layer2.1.conv2,9,10,26,29,46,54,77,93,94,100,109,120,121,112,38,126,module.layer2.1.conv1,1,3,5,14,28,29,45,56,58,63,69,76,83,102,108,111,module.layer2.2.conv2,56,29,18,39,101,75,34,86,32,2,35,12,53,96,40,97,module.layer2.2.conv1,38,63,75,112,117,120,40,12,102,6,76,88,123,107,67,124,module.layer2.3.conv2,127,124,36,76,77,17,105,67,113,0,95,19,116,107,104,21,module.layer2.3.conv1,9,74,120,6,79,49,114,4,52,23,30,37,33,105,119,127,module.layer3.0.conv2,2,4,13,54,106,141,207,221,248,250,35,157,151,194,133,152,222,247,211,115,254,111,160,116,127,126,219,136,232,192,150,210,module.layer3.0.conv1,4,18,48,89,90,136,139,148,220,245,254,134,236,92,110,178,44,225,39,79,196,135,86,160,244,1,189,235,103,197,194,253,module.layer3.1.conv2,7,8,9,16,23,24,42,56,72,74,82,88,100,108,111,123,127,130,132,133,146,155,164,191,209,213,218,221,228,229,235,238,module.layer3.1.conv1,11,13,32,41,43,50,54,73,75,76,77,79,85,92,95,97,103,122,131,142,144,149,152,162,167,172,173,181,187,188,194,201,module.layer3.2.conv2,35,79,108,117,149,196,43,157,156,166,99,190,34,47,87,241,3,240,228,17,153,14,233,250,130,112,30,182,48,81,169,131,module.layer3.2.conv1,1,20,23,29,30,47,51,56,97,99,100,105,134,153,155,161,164,167,216,242,243,62,8,0,181,148,196,208,210,59,84,52,module.layer3.3.conv2,34,35,127,148,190,247,100,146,37,82,131,197,161,157,199,43,85,124,186,71,218,76,183,169,255,38,221,217,119,149,173,231,module.layer3.3.conv1,9,19,24,31,58,59,100,114,124,180,181,189,195,214,219,240,244,254,144,128,217,213,71,41,79,188,209,252,110,81,167,73,module.layer3.4.conv2,19,41,42,53,126,144,150,169,186,242,244,252,35,116,105,50,213,79,165,212,117,222,164,174,231,84,193,16,198,67,210,86,module.layer3.4.conv1,6,10,14,20,26,48,55,60,71,73,87,100,108,110,190,204,111,168,227,76,1,240,17,21,164,122,65,19,153,49,8,104,module.layer3.5.conv2,8,38,59,91,94,105,135,142,143,183,186,203,215,221,239,219,141,77,206,137,52,224,222,227,127,53,249,163,128,85,114,0,module.layer3.5.conv1,1,3,5,10,16,29,70,84,95,112,114,133,190,200,204,220,246,249,251,250,126,55,199,166,205,81,77,203,153,254,12,105,module.layer4.0.conv2,12,25,38,46,48,61,65,72,78,84,103,111,127,130,142,143,157,160,163,171,178,192,203,231,236,238,244,266,268,269,272,284,286,290,303,306,326,365,371,378,381,383,386,388,409,422,452,461,506,511,453,499,405,342,211,22,91,319,126,267,335,50,80,253,module.layer4.0.conv1,8,10,12,15,16,18,34,53,55,124,132,143,188,191,225,230,234,251,252,254,268,280,291,304,309,312,314,316,318,326,329,338,383,388,406,410,416,420,429,431,432,436,447,454,457,467,471,483,485,487,488,510,203,426,421,433,259,0,128,105,336,114,253,394,module.layer4.1.conv2,12,18,35,36,40,46,48,51,54,57,58,60,61,63,67,73,76,78,86,87,88,91,95,96,104,116,118,119,120,128,131,134,144,148,153,155,162,163,169,170,172,175,177,178,183,185,193,194,198,202,204,205,214,223,224,229,237,248,249,251,256,258,261,263,module.layer4.1.conv1,7,9,10,12,17,19,22,25,27,29,30,31,32,35,41,44,45,46,47,51,52,54,56,58,60,61,63,64,65,68,72,73,74,75,78,80,84,86,88,93,94,96,99,100,104,105,111,114,117,118,120,122,124,126,128,133,135,139,141,142,143,145,147,151,module.layer4.2.conv2,0,7,8,11,15,26,33,37,43,46,47,54,62,70,80,84,87,92,96,99,100,125,150,155,156,158,160,166,169,171,173,176,178,183,187,188,201,213,219,235,243,245,266,270,275,277,281,283,284,294,295,302,307,308,309,315,331,333,336,337,344,348,350,353,module.layer4.2.conv1,3,5,6,7,8,14,15,18,19,24,26,29,33,34,37,38,44,46,48,49,50,54,57,59,64,65,66,69,76,77,78,79,80,87,88,89,90,91,92,93,95,98,100,102,103,105,106,108,109,110,111,112,113,114,115,117,118,120,121,125,126,127,129,135"
# h = h.replace("'", "")

delimiter = 'module'
h_list = [delimiter+x for x in h.split(delimiter) if x]
h_list


#split names from indexes
hl = []
for i in h_list:
    i = i.split(",", 1)
    # print(i)
    hl.append(i)


def remove_values_from_list(the_list, val):
   return [value for value in the_list if value != val]

fl = []
for name, indexes in hl:
    s_indexes = indexes.split(",")
    s_indexes = remove_values_from_list(s_indexes, "")
    n_indexes = [int(i) for i in s_indexes]
    indexes = n_indexes

for name, indexes in hl:
    print(name, indexes)
