In [21]:
# imports

import os
import torch
import numpy as np
import torch.nn as nn
import pandas as pd 

import torch.backends.cudnn
import torchvision.transforms
import torchvision.datasets as datasets
import torch.nn.utils.prune as prune
import torch.nn.functional as F

from torch.utils.data import  DataLoader
# from torchprofile import profile_macs

import torch
from enum import Enum
from functools import reduce
from torch.utils.data import Subset
import time
import copy
import torch.distributed as dist
import argparse
import os
import random
import warnings
import pickle
from torchvision.models import ResNet50_Weights, ResNet18_Weights, resnet18, resnet50
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR
import torch_pruning as tp
from torchinfo import summary
from copy import copy, deepcopy

torch.backends.cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [30]:
# define utility methods
output_path = "../checkpoints/results"


def load_out_history(filename):
    with open(output_path + 'out_'+filename+'.pkl', 'rb') as file:
        history_steps = pickle.load(file)
    return history_steps


def load_in_history(filename):
    with open(output_path + 'in_'+filename+'.pkl', 'rb') as f:
        loaded_dict = pickle.load(f)
    return loaded_dict


def save_out_history(history_steps, filename):
    with open(output_path + 'out_'+filename+'.pkl', 'wb') as temp:
        pickle.dump(history_steps, temp)


def save_in_history(d, filename):
    with open(output_path + 'in_'+filename+'.pkl', 'wb') as f:
        pickle.dump(d, f)
        
def load_model(filename, pruned=False):
    model = resnet18()
    if pruned:
        state = torch.load(output_path +filename+".pth", map_location='cpu')
        tp.load_state_dict(model, state_dict=state)

    else:
        model.load_state_dict(torch.load(output_path +filename+".pth"))
    return model


def save_model(model, filename, pruned=False):
    if pruned:
        state_dict = tp.state_dict(model)  # the pruned model
        torch.save(state_dict, output_path +filename+".pth")
    else:
        torch.save(model.state_dict(), output_path +filename+".pth")

def get_in_channel_history(original_model, pruned_model, step_history):
    pruned_in_channels_history_dict = {}
    print("=> Start history generation")
    for i, history in enumerate(reversed(step_history)):
        for pruned_layer_name, b, out_channels_removed in reversed(history):
            print(pruned_layer_name)
            pruned_layer = get_module_by_name(pruned_model, pruned_layer_name)
            original_layer = get_module_by_name(original_model, pruned_layer_name)
            in_history = get_index_in_channel_history(original_layer, pruned_layer, out_channels_removed)
            pruned_in_channels_history_dict[pruned_layer_name] = in_history
            print(pruned_layer_name, in_history)

    return pruned_in_channels_history_dict

def get_index_in_channel_history(original_layer, pruned_layer, pruned_out_channels):
    skipped = 0  # adjustment to match out_channel between original and pruned model of different shapes
    pruned_in_channels_history = []

    for out_channel_idx in range(original_layer.out_channels):
        not_pruned_in_channels = []  # in channels pruned per out channel
        if out_channel_idx in pruned_out_channels:
            # the out_channel is completely pruned
            skipped += 1
        else:
            for in_channel_i in range(original_layer.in_channels):
                # the out_channel is partially pruned, loop through the in channels
                # and find which idx have been pruned for each non-pruned out channel
                for in_channel_j in range(pruned_layer.in_channels):
                    # the output channel exists in both pruned and original model
                    if torch.equal(original_layer.weight.data[out_channel_idx, in_channel_i, :, :],
                                   pruned_layer.weight.data[out_channel_idx - skipped, in_channel_j, :, :]):
                        not_pruned_in_channels.append(in_channel_i)
                        continue
                        # in_channel_j of the pruned layer matches weights in the original layer, i.e not pruned

        all_channels = list(range(original_layer.in_channels))
        pruned_in_channels = [x for x in all_channels if x not in not_pruned_in_channels]
        # print(out_channel_idx, pruned_in_channels)
        # pruned_in_channels_history.append([out_channel_idx, pruned_in_channels])
        break  # the input channels dropped are the same for each output channel
    return pruned_in_channels

class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)

        return fmtstr.format(**self.__dict__)

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
    
def get_module_by_name(model, access_string):
    names = access_string.split(sep='.')
    return reduce(getattr, names, model)


In [46]:

def validate(val_loader, model, criterion, device):
    def run_validate(loader, base_progress=0):
        with torch.no_grad():
            end = time.time()
            for i, (images, target) in enumerate(loader):
                i = base_progress + i
                images = images.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)

                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % 100 == 0:
                    progress.display(i + 1)

    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Validation: ')

    # switch to evaluate mode
    model.eval()

    run_validate(val_loader)
    progress.display_summary()
    
def apply_channel_prune(model, original_model, sparsity, example_inputs):
        print("=> Applying pruning: '{}'".format(sparsity))
        # Importance criteria
        imp = tp.importance.TaylorImportance()

        ignored_layers = []
        for m in model.modules():
            if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
                ignored_layers.append(m)  # DO NOT prune the final classifier!

        iterative_steps = 1  # progressive pruning
        current_step = 1
        prune_amounts = [x / 64 for x in range(48)]

        pruner = tp.pruner.MagnitudePruner(
            model,
            example_inputs,
            importance=imp,
            iterative_steps=iterative_steps,
            ch_sparsity=sparsity,
            ignored_layers=ignored_layers,
        )

        base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)

        for i in range(iterative_steps):
            if isinstance(imp, tp.importance.TaylorImportance):
                # Taylor expansion requires gradients for importance estimation
                loss = model(example_inputs).sum()  # a dummy loss for TaylorImportance
                loss.backward()  # before pruner.step()
            pruner.step()
            macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
            print("Pruning step:", current_step, "multiply–accumulate (macs):", macs, "number of parameters", nparams)
            current_step += 1
            print(
                "  Iter %d/%d, Params: %.2f M => %.2f M"
                % (i + 1, iterative_steps, base_nparams / 1e6, nparams / 1e6)
            )
            print(
                "  Iter %d/%d, MACs: %.2f G => %.2f G"
                % (i + 1, iterative_steps, base_macs / 1e9, macs / 1e9)
            )


        model_statistics = summary(model, (1, 3, 224, 224), depth=3,
                                   col_names=["kernel_size", "input_size", "output_size", "num_params", "mult_adds"], )
        model_statistics_str = str(model_statistics)

        history = pruner.pruning_history()
        layers_affected = len(history)
        layers_affected_per_step = int(layers_affected / iterative_steps)
        out_history = [history[i:i + layers_affected_per_step] for i in
                        range(0, layers_affected, layers_affected_per_step)]

        in_history_dict = {}
        in_history_dict = get_in_channel_history(original_model, model, out_history)

        return model, out_history, in_history_dict
    

def train(train_loader, model, epoch, device, optimizer, criterion):
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        progress = ProgressMeter(
            len(train_loader),
            [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        # switch to train mode
        model.train()
        end = time.time()
        for i, (images, target) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            # move data to the same device as model
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 500 == 0 and i != 0:
                progress.display(i + 1)
                print(model.conv1.weight.data[0])
                break

def get_layers(model: torch.nn.Module, parent_name=''):
    layers = {}
    for name, module in model.named_children():
        layer_name = f"{parent_name}.{name}" if parent_name else name
        if len(list(module.children())) == 0:
            layers[layer_name] = module
        else:
            layers.update(get_layers(module, parent_name=layer_name))
    return layers

def rebuild_model(tuned_model, bigger_model, verification_model, device, out_history, in_history):
    print("=> Starting rebuilding...")
    layers_total = 0
    layers_rebuilt_count = 0

    tuned_model = tuned_model.to(device)
    bigger_model = bigger_model.to(device)
    verification_layers = get_layers(verification_model)


    for i, history in enumerate(reversed(out_history)):
        for pruned_layer_name, b, out_channels_removed in reversed(history):
            layers_total += 1

    for i, history in enumerate(reversed(out_history)):
        # loop through each layer changed in pruning

        for pruned_layer_name, b, out_channels_removed in reversed(history):

            # loop through the layers of the larger model (same number of layers, different channel width)
            for layer_name, bigger_layer_params in bigger_model.named_parameters():

                skipped_out_channels = 0
                # if"module."+layer_name == pruned_layer_name+".weight":
                # if layer_name == pruned_layer_name + ".weight" and pruned_layer_name == "layer1.2.conv2":
                if layer_name == pruned_layer_name + ".weight":
                    # get copy of layers
                    tuned_layer = get_module_by_name(tuned_model, pruned_layer_name)
                    bigger_layer = get_module_by_name(bigger_model, pruned_layer_name)
                    verification_layer = get_module_by_name(verification_model, pruned_layer_name)

                    # in_channels_removed = get_index_in_channel_history(bigger_layer, tuned_layer, out_channels_removed)
                    in_channels_removed = in_history[pruned_layer_name]

                    # loop throughout the channels of the bigger model
                    for out_channel_idx in range(bigger_layer.out_channels):

                        # check if the channel has been dropped
                        if out_channel_idx in out_channels_removed:
                            # if channel was dropped, do not copy weights from smaller tuned model
                            skipped_out_channels += 1

                        else:
                            # copy weights from tuned model to larger model
                            if (bigger_layer.in_channels - tuned_layer.in_channels) == 0:
                                bigger_layer_params.data[out_channel_idx, :, :, :] =\
                                    tuned_layer.weight.data[out_channel_idx - skipped_out_channels, :, :, :]


                            else:  # for conv layers with reshape of both input and output
                                skipped_in_channels = 0
                                for in_channel_idx in range(bigger_layer.in_channels):

                                    if in_channel_idx in in_channels_removed:
                                        # if channel was dropped, do not copy weights from smaller tuned model
                                        skipped_in_channels += 1
                                    else:
                                        bigger_layer_params.data[out_channel_idx, in_channel_idx, :, :] = \
                                            tuned_layer.weight.data[out_channel_idx - skipped_out_channels, in_channel_idx - skipped_in_channels, :, :]


            layers_rebuilt_count += 1
            print("("+str(layers_rebuilt_count), "of", str(layers_total)+")", pruned_layer_name, "has been rebuilt.")
#             if torch.equal(bigger_layer_params.data,tuned_layer.weight.data):
#                 print("Same:", pruned_layer_name, bigger_layer)
#             else:
#                 print("Different (Correct):", pruned_layer_name)

    output_model = deepcopy(bigger_model)
    return output_model

In [33]:
# data loaders

model_str = "resnet18"
device = torch.device("cuda:0")
example_inputs = torch.randn(1, 3, 224, 224).to(device)

traindir = '../data/train/'
valdir = '../data/val/'
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_dataset = datasets.ImageFolder(
    traindir,
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))

val_dataset = datasets.ImageFolder(
    valdir,
    transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]))

batch_size = 16
workers = 1

# train_loader = torch.utils.data.DataLoader(
#     train_dataset, batch_size=batch_size, shuffle=(None is None),
#     num_workers=workers, pin_memory=True, sampler=None)

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False,
    num_workers=workers, pin_memory=True, sampler=None)

criterion = nn.CrossEntropyLoss().to(device)

In [47]:
verification_model = resnet18(weights=ResNet18_Weights)
verification_model = verification_model.to(device)
print("=> evaluate model")
# validate(val_loader, verification_model, criterion, device)
verification_model.conv1.weight.data[0]

=> evaluate model




tensor([[[-1.0419e-02, -6.1356e-03, -1.8098e-03,  7.4841e-02,  5.6615e-02,
           1.7083e-02, -1.2694e-02],
         [ 1.1083e-02,  9.5276e-03, -1.0993e-01, -2.8050e-01, -2.7124e-01,
          -1.2907e-01,  3.7424e-03],
         [-6.9434e-03,  5.9089e-02,  2.9548e-01,  5.8720e-01,  5.1972e-01,
           2.5632e-01,  6.3573e-02],
         [ 3.0505e-02, -6.7018e-02, -2.9841e-01, -4.3868e-01, -2.7085e-01,
          -6.1282e-04,  5.7602e-02],
         [-2.7535e-02,  1.6045e-02,  7.2595e-02, -5.4102e-02, -3.3285e-01,
          -4.2058e-01, -2.5781e-01],
         [ 3.0613e-02,  4.0960e-02,  6.2850e-02,  2.3897e-01,  4.1384e-01,
           3.9359e-01,  1.6606e-01],
         [-1.3736e-02, -3.6746e-03, -2.4084e-02, -6.5877e-02, -1.5070e-01,
          -8.2230e-02, -5.7828e-03]],

        [[-1.1397e-02, -2.6619e-02, -3.4641e-02,  3.6812e-02,  3.2521e-02,
           6.6221e-04, -2.5743e-02],
         [ 4.5687e-02,  3.3603e-02, -1.0453e-01, -3.0885e-01, -3.1253e-01,
          -1.6051e-01, -1.2

# Pruning

In [58]:
model = resnet18(weights=ResNet18_Weights)
verification_model = resnet18(weights=ResNet18_Weights)
original_model = resnet18(weights=ResNet18_Weights)
model = model.to(device)

original_model = original_model.to(device)

prune = 0.05
prune_02 = 0.05



model, out_history, in_history = apply_channel_prune(model, original_model, prune, example_inputs)
model = model.to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), 0.001, momentum=0.9)
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

# # save model 01, out_channels, in_channels
# save_model(pruned_model_01, model_str+"_pruned_" + str(prune), pruned=True)
# save_out_history(out_history, model_str+"_pruned_" + str(prune))
# save_in_history(in_history, model_str+"_pruned_" + str(prune))

=> Applying pruning: '0.05'
layer4.0.downsample.0 [9, 60, 92, 100, 130, 144, 146, 154, 172, 179, 187, 230, 252, 266, 281, 283, 290, 364, 407, 421, 422, 437, 441, 458, 465, 475]
layer3.0.downsample.0 [40, 130, 133, 70, 86, 190, 102, 58, 34, 211, 61, 2, 137]
layer2.0.downsample.0 [53, 24, 55, 18, 76, 28, 22]
conv1 [18, 4, 7, 54]
layer1.0.conv1 [4, 21, 22, 33]
layer1.1.conv1 [3, 18, 38, 31]
layer2.0.conv1 [71, 79, 30, 22, 8, 62, 43]
layer2.1.conv1 [126, 94, 76, 84, 95, 39, 17]
layer3.0.conv1 [21, 24, 25, 26, 69, 89, 107, 108, 110, 124, 125, 127, 130]
layer3.1.conv1 [2, 19, 38, 44, 60, 71, 131, 143, 155, 158, 169, 171, 172]
layer4.0.conv1 [2, 5, 7, 9, 13, 14, 16, 18, 19, 23, 26, 27, 32, 34, 37, 44, 45, 46, 52, 54, 59, 60, 64, 70, 71, 81]
layer4.1.conv1 [0, 6, 17, 19, 20, 22, 34, 41, 43, 49, 53, 60, 63, 64, 71, 74, 92, 94, 110, 114, 115, 116, 128, 134, 135, 138]
Pruning step: 1 multiply–accumulate (macs): 1632482862.0 number of parameters 10549052
  Iter 1/1, Params: 11.69 M => 10.55 M
  It

# Tuning

In [53]:
print(model.conv1.weight.shape, model.conv1.weight.data[0])

torch.Size([60, 3, 7, 7]) tensor([[[-1.0419e-02, -6.1356e-03, -1.8098e-03,  7.4841e-02,  5.6615e-02,
           1.7083e-02, -1.2694e-02],
         [ 1.1083e-02,  9.5276e-03, -1.0993e-01, -2.8050e-01, -2.7124e-01,
          -1.2907e-01,  3.7424e-03],
         [-6.9434e-03,  5.9089e-02,  2.9548e-01,  5.8720e-01,  5.1972e-01,
           2.5632e-01,  6.3573e-02],
         [ 3.0505e-02, -6.7018e-02, -2.9841e-01, -4.3868e-01, -2.7085e-01,
          -6.1282e-04,  5.7602e-02],
         [-2.7535e-02,  1.6045e-02,  7.2595e-02, -5.4102e-02, -3.3285e-01,
          -4.2058e-01, -2.5781e-01],
         [ 3.0613e-02,  4.0960e-02,  6.2850e-02,  2.3897e-01,  4.1384e-01,
           3.9359e-01,  1.6606e-01],
         [-1.3736e-02, -3.6746e-03, -2.4084e-02, -6.5877e-02, -1.5070e-01,
          -8.2230e-02, -5.7828e-03]],

        [[-1.1397e-02, -2.6619e-02, -3.4641e-02,  3.6812e-02,  3.2521e-02,
           6.6221e-04, -2.5743e-02],
         [ 4.5687e-02,  3.3603e-02, -1.0453e-01, -3.0885e-01, -3.1253e-01,
 

In [54]:
# print("=> evaluate pruned model 01")
# validate(val_loader, model, criterion, device)

# finetune model
for epoch in range(1):
    train(val_loader, model, epoch, device, optimizer, criterion)
    scheduler.step()

print(model.conv1.weight.data[0])


Epoch: [0][ 501/2779]	Time  0.079 ( 0.095)	Data  0.047 ( 0.063)	Loss 6.6779e+00 (8.4837e+00)	Acc@1   0.00 (  0.47)	Acc@5   0.00 (  3.01)
tensor([[[ 0.3482,  0.3032,  0.3009,  0.3205,  0.2572,  0.2315,  0.2727],
         [ 0.3290,  0.2686,  0.1016, -0.0944, -0.1090,  0.0795,  0.3137],
         [ 0.2485,  0.2650,  0.4686,  0.6777,  0.6084,  0.4446,  0.3904],
         [ 0.3105,  0.1884, -0.0620, -0.2504, -0.0516,  0.3021,  0.4780],
         [ 0.2684,  0.2516,  0.2658,  0.1361, -0.1011, -0.0750,  0.1505],
         [ 0.3506,  0.2714,  0.2130,  0.3361,  0.5446,  0.6411,  0.4680],
         [ 0.2805,  0.2480,  0.2015,  0.0781,  0.0315,  0.2053,  0.3512]],

        [[ 0.8136,  0.7306,  0.7099,  0.7497,  0.6659,  0.5869,  0.5890],
         [ 0.8487,  0.7708,  0.5786,  0.3390,  0.2600,  0.4059,  0.6363],
         [ 0.7818,  0.8287,  1.0736,  1.3572,  1.2423,  0.9413,  0.8016],
         [ 0.7914,  0.6564,  0.3262,  0.1113,  0.2926,  0.6572,  0.8186],
         [ 0.7461,  0.7657,  0.7603,  0.5455,  

In [57]:
print(model.conv1.weight.data[0][0])
print(verification_model.conv1.weight.data[0][0])



tensor([[ 0.3482,  0.3032,  0.3009,  0.3205,  0.2572,  0.2315,  0.2727],
        [ 0.3290,  0.2686,  0.1016, -0.0944, -0.1090,  0.0795,  0.3137],
        [ 0.2485,  0.2650,  0.4686,  0.6777,  0.6084,  0.4446,  0.3904],
        [ 0.3105,  0.1884, -0.0620, -0.2504, -0.0516,  0.3021,  0.4780],
        [ 0.2684,  0.2516,  0.2658,  0.1361, -0.1011, -0.0750,  0.1505],
        [ 0.3506,  0.2714,  0.2130,  0.3361,  0.5446,  0.6411,  0.4680],
        [ 0.2805,  0.2480,  0.2015,  0.0781,  0.0315,  0.2053,  0.3512]],
       device='cuda:0')
tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],
        [ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],
        [-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],
        [ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],
        [-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],
        [ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],
        [-0.0137, -0.0037,

In [None]:

print("=> evaluate pruned model 02")
validate(val_loader, model, criterion, device)

save_model(model, model_str+"_tuned_" + str(prune), pruned=True)

# Rebuilding 01

In [20]:
bigger_model = resnet18(weights=ResNet18_Weights)
bigger_model = bigger_model.to(device)
model = model.to(device)
verification_model = verification_model.to(device)

verification_layers = get_layers(verification_model)
print("=> Starting rebuilding...")
layers_total = 0
layers_rebuilt_count = 0


for i, history in enumerate(reversed(out_history)):
    for pruned_layer_name, b, out_channels_removed in reversed(history):
        layers_total += 1

for i, history in enumerate(reversed(out_history)):
    # loop through each layer changed in pruning

    for pruned_layer_name, b, out_channels_removed in reversed(history):
        
        verification_layer = get_module_by_name(verification_model, pruned_layer_name)

        # loop through the layers of the larger model (same number of layers, different channel width)
        for layer_name, bigger_layer_params in bigger_model.named_parameters():

            skipped_out_channels = 0
            # if"module."+layer_name == pruned_layer_name+".weight":
            # if layer_name == pruned_layer_name + ".weight" and pruned_layer_name == "layer1.2.conv2":
            if layer_name == pruned_layer_name + ".weight":
                # get copy of layers
                tuned_layer = get_module_by_name(model, pruned_layer_name)
                bigger_layer = get_module_by_name(bigger_model, pruned_layer_name)
                # in_channels_removed = get_index_in_channel_history(bigger_layer, tuned_layer, out_channels_removed)
                in_channels_removed = in_history[pruned_layer_name]

                # loop throughout the channels of the bigger model
                for out_channel_idx in range(bigger_layer.out_channels):

                    # check if the channel has been dropped
                    if out_channel_idx in out_channels_removed:
                        # if channel was dropped, do not copy weights from smaller tuned model
                        skipped_out_channels += 1

                    else:
                        # copy weights from tuned model to larger model
                        if (bigger_layer.in_channels - tuned_layer.in_channels) == 0:
                            bigger_layer_params.data[out_channel_idx, :, :, :] =\
                                tuned_layer.weight.data[out_channel_idx - skipped_out_channels, :, :, :]


                        else:  # for conv layers with reshape of both input and output
                            skipped_in_channels = 0
                            for in_channel_idx in range(bigger_layer.in_channels):

                                if in_channel_idx in in_channels_removed:
                                    # if channel was dropped, do not copy weights from smaller tuned model
                                    skipped_in_channels += 1
                                else:
                                    bigger_layer_params.data[out_channel_idx, in_channel_idx, :, :] = \
                                        tuned_layer.weight.data[out_channel_idx - skipped_out_channels, in_channel_idx - skipped_in_channels, :, :]


        layers_rebuilt_count += 1
        print("("+str(layers_rebuilt_count), "of", str(layers_total)+")", pruned_layer_name, "has been rebuilt.")
        if torch.equal(bigger_layer_params.data, verification_layer.weight.data):
            print("Same:", pruned_layer_name, bigger_layer)
        else:
            print("Different (Correct):", pruned_layer_name)

save_model(bigger_model, model_str+"_rebuilt_01_"+str(prune))

print("=> evaluate rebuilt model 01 (no fine-tuning)")
validate(val_loader, bigger_model, criterion,  device)

optimizer = torch.optim.SGD(bigger_model.parameters(), 0.1,momentum=0.9)

scheduler = StepLR(optimizer, step_size=5, gamma=0.1)
for epoch in range(1):
    train(val_loader, bigger_model, epoch, device, optimizer, criterion)
    scheduler.step()

# print("=> Starting fine-tuning of rebuilt model 01 (Only train tune pruned channels)...")
# tune(val_loader, rebuilt_model_01, criterion, out_history, in_history, device)

print("=> evaluate rebuilt/tuned 01")
validate(val_loader, bigger_model, criterion, device)

# print("=> Starting fine-tuning of rebuilt model (Only train tune pruned channels)...")
# tune(val_loader, rebuilt_model_02, criterion, out_history, in_history, device)



=> Starting rebuilding...
(1 of 12) layer4.1.conv1 has been rebuilt.
Different (Correct): layer4.1.conv1
(2 of 12) layer4.0.conv1 has been rebuilt.
Different (Correct): layer4.0.conv1
(3 of 12) layer3.1.conv1 has been rebuilt.
Different (Correct): layer3.1.conv1
(4 of 12) layer3.0.conv1 has been rebuilt.
Different (Correct): layer3.0.conv1
(5 of 12) layer2.1.conv1 has been rebuilt.
Different (Correct): layer2.1.conv1
(6 of 12) layer2.0.conv1 has been rebuilt.
Different (Correct): layer2.0.conv1
(7 of 12) layer1.1.conv1 has been rebuilt.
Different (Correct): layer1.1.conv1
(8 of 12) layer1.0.conv1 has been rebuilt.
Different (Correct): layer1.0.conv1
(9 of 12) conv1 has been rebuilt.
Different (Correct): conv1
(10 of 12) layer2.0.downsample.0 has been rebuilt.
Different (Correct): layer2.0.downsample.0
(11 of 12) layer3.0.downsample.0 has been rebuilt.
Different (Correct): layer3.0.downsample.0
(12 of 12) layer4.0.downsample.0 has been rebuilt.
Different (Correct): layer4.0.downsample.0

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:
    with record_function("model_inference"):
        with torch.no_grad():
            print("")

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
               

# CIFAR10