In [None]:
# Colab-only helper that attaches the user's Google Drive to the runtime filesystem.
from google.colab import drive

# Mount point for Google Drive; later cells build every project path from this value.
ROOT = "/content/drive"
# Echo the mount point so it is visible in the cell output (optional).
print(ROOT)

# Mount the drive at ROOT; Colab reports when the drive is already mounted there.
drive.mount(ROOT)

/content/drive
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os

# Project folder inside Google Drive, relative to the mount point held in ROOT.
MY_GOOGLE_DRIVE_PATH = 'My Drive/Capstone_Prasham/'
# Build the data directory with os.path.join: plain string concatenation dropped the
# separator between ROOT and the Drive folder ("/content/driveMy Drive/...") and
# doubled it before "Edge". The trailing '' keeps the final "/" of the original value.
data_dir = os.path.join(ROOT, MY_GOOGLE_DRIVE_PATH, 'Edge', 'data', '')

In [None]:
from os.path import join
# Absolute project folder: the Drive mount point joined with the Drive-relative project path.
PROJECT_PATH = join(ROOT, MY_GOOGLE_DRIVE_PATH)

# Print the resolved path so a wrong mount point or folder name is visible before it is used.
print("PROJECT_PATH: ", PROJECT_PATH)   

PROJECT_PATH:  /content/drive/My Drive/Capstone_Prasham/


In [None]:
# Make the project folder the working directory, then step into its "Edge" subfolder,
# so relative paths used later (such as 'model_artifacts/') resolve from there.
# Keep comments on their own lines: text after a %cd magic is treated as part of the path.
%cd "{PROJECT_PATH}"
%cd "Edge"

/content/drive/My Drive/Capstone_Prasham
/content/drive/My Drive/Capstone_Prasham/Edge


# Model Pruning

In [None]:
# %%writefile model_pruning.py
import copy
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from scipy.stats import rankdata

def prune_model(model_artifact, prune_percentage):
    """Return a magnitude-pruned deep copy of ``model_artifact``.

    Every parameter tensor, except the one stored under the last
    ``state_dict`` key, is pruned independently: its absolute values are
    dense-ranked (equal magnitudes share one rank) and the lowest
    ``prune_percentage`` share of the distinct magnitudes is set to zero.
    Non-parameter buffers (for example batch-norm running statistics) are
    never modified.

    Args:
        model_artifact: ``torch.nn.Module`` to prune. It is deep-copied and
            left untouched.
        prune_percentage: Fraction in ``[0, 1]`` of distinct magnitudes to
            zero out in each tensor. ``0`` leaves the weights unchanged and
            ``1`` zeroes every pruned tensor completely.

    Returns:
        A new module with the same architecture and the pruned weights.

    Raises:
        ValueError: If ``prune_percentage`` is outside ``[0, 1]`` (or NaN).
    """
    if not 0.0 <= prune_percentage <= 1.0:
        raise ValueError(
            "prune_percentage must be within [0, 1], got {!r}".format(prune_percentage))

    pruned_model = copy.deepcopy(model_artifact)

    # The final state_dict entry is kept as it is.
    # NOTE(review): for a model ending in a Linear layer with bias this is only the
    # output bias, so the output layer's weight matrix is still pruned -- confirm
    # that this is the intended meaning of "except the output layer".
    state_names = list(pruned_model.state_dict())
    untouched_name = state_names[-1] if state_names else None

    with torch.no_grad():
        for name, param in pruned_model.named_parameters():
            if name == untouched_name or param.numel() == 0:
                continue
            magnitudes = param.detach().abs().cpu().numpy()
            # Dense, zero-based rank of every element; reshaped back to the tensor shape.
            ranks = (rankdata(magnitudes.ravel(), method='dense') - 1).astype(int)
            ranks = ranks.reshape(magnitudes.shape)
            # Number of distinct magnitudes to drop. Counting ranks (max + 1) instead of
            # thresholding on the maximum rank avoids pruning one extra magnitude, which
            # previously zeroed the smallest weight even for prune_percentage == 0.
            n_distinct = int(ranks.max()) + 1
            n_pruned = int(np.ceil(n_distinct * prune_percentage))
            prune_mask = torch.from_numpy(ranks < n_pruned).to(param.device)
            # In-place on the copy's own parameter, so the result is the same on CPU and GPU.
            param.masked_fill_(prune_mask, 0.0)
    return pruned_model

def pruning_multiple(model_name, prune_percentage=None):
    """Prune one saved model at several pruning levels and collect the results.

    The model is loaded from ``'model_artifacts/' + model_name`` (relative to
    the current working directory), pruned once per requested level with
    ``prune_model``, and the original and pruned ``state_dict`` contents are
    printed for inspection.

    Args:
        model_name: File name of the saved model inside ``model_artifacts/``.
        prune_percentage: Sequence of pruning fractions to apply. ``None`` or
            an empty sequence selects the default sweep from 0 to 0.99.

    Returns:
        A ``pandas.DataFrame`` with one row per pruning level. The columns
        ``model``, ``pruning_percentage``, ``model_artifact`` and
        ``pruned_model_artifact`` are filled; ``train_loss``, ``train_acc``,
        ``test_loss`` and ``test_acc`` are left empty (NaN) by this function.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model_object = 'model_artifacts/' + model_name
    result_columns = ['model', 'pruning_percentage', 'model_artifact', 'pruned_model_artifact',
                      'train_loss', 'train_acc', 'test_loss', 'test_acc']
    if prune_percentage is None or len(prune_percentage) == 0:
        prune_percentage = [.0, .25, .50, .60, .70, .80, .90, .95, .97, .99]
    # SECURITY: torch.load unpickles the file, which can execute arbitrary code -- only
    # load artifacts from a trusted source. Recent PyTorch releases default to
    # weights_only=True, which refuses whole-module pickles like this one; on those
    # versions weights_only=False has to be passed explicitly for trusted files.
    model = torch.load(model_object, map_location=device)
    weights = model.state_dict()
    print(weights)
    # Collect plain records and build the frame once: DataFrame.append was removed in
    # pandas 2.0 and copied the whole frame on every iteration.
    records = []
    for p in prune_percentage:
        pruned_model = prune_model(model_artifact=model, prune_percentage=p)
        print("------------------------------------------------------------")
        print("Prune Percentage:", p)
        print("------------------------------------------------------------")
        pruned_weights = pruned_model.state_dict()
        print(pruned_weights)
        print("------------------------------------------------------------")

        records.append({'model': model_name,
                        'pruning_percentage': p,
                        'model_artifact': model,
                        'pruned_model_artifact': pruned_model})
    # Passing the full column list keeps the metric columns present (as NaN) and in order.
    return pd.DataFrame(records, columns=result_columns)

In [None]:
# Run the sweep with the default pruning percentages on the saved 'california_simple.pt'
# model; the returned DataFrame is shown as the cell output.
pruning_multiple('california_simple.pt')

OrderedDict([('net.0.weight', tensor([[-1.2034e-01,  1.0730e-01, -4.4607e-01, -4.0890e-01, -2.6614e-01,
         -5.6145e-02],
        [-1.0627e-01,  2.1744e-01, -1.4799e-01,  1.8995e-04, -2.3152e-01,
         -1.8809e-01],
        [-3.9002e-01, -2.7038e-01, -1.6829e-01,  1.5123e-02,  1.6139e-01,
          2.4496e-01],
        [ 2.5584e+00,  2.9751e+00,  2.3699e+00,  1.3530e+00,  2.3748e+00,
          9.6539e+00],
        [ 2.7513e+00,  3.1886e+00,  2.5843e+00,  6.2999e-01,  2.1943e+00,
          9.2278e+00],
        [-1.5914e-01,  3.5273e-01, -2.6462e-01, -1.8793e-01, -2.8522e-01,
         -3.8235e-01],
        [ 2.8045e+00,  3.6887e+00,  2.5928e+00,  1.4011e+00,  2.6690e+00,
          9.3341e+00],
        [-4.4704e-02, -4.8060e-01, -3.8933e-01, -3.0601e-01,  1.6289e-01,
          5.4286e-02],
        [ 2.9037e+00,  3.3945e+00,  2.7374e+00,  1.6738e+00,  2.8764e+00,
          9.6718e+00],
        [ 3.0754e+00,  2.9053e+00,  2.2818e+00,  6.7785e-01,  2.1624e+00,
          9.1411e+00]],

Unnamed: 0,model,pruning_percentage,model_artifact,pruned_model_artifact,train_loss,train_acc,test_loss,test_acc
0,california_simple.pt,0.0,DenseNeuralNet(\n (net): Sequential(\n (0)...,DenseNeuralNet(\n (net): Sequential(\n (0)...,,,,
1,california_simple.pt,0.25,DenseNeuralNet(\n (net): Sequential(\n (0)...,DenseNeuralNet(\n (net): Sequential(\n (0)...,,,,
2,california_simple.pt,0.5,DenseNeuralNet(\n (net): Sequential(\n (0)...,DenseNeuralNet(\n (net): Sequential(\n (0)...,,,,
3,california_simple.pt,0.6,DenseNeuralNet(\n (net): Sequential(\n (0)...,DenseNeuralNet(\n (net): Sequential(\n (0)...,,,,
4,california_simple.pt,0.7,DenseNeuralNet(\n (net): Sequential(\n (0)...,DenseNeuralNet(\n (net): Sequential(\n (0)...,,,,
5,california_simple.pt,0.8,DenseNeuralNet(\n (net): Sequential(\n (0)...,DenseNeuralNet(\n (net): Sequential(\n (0)...,,,,
6,california_simple.pt,0.9,DenseNeuralNet(\n (net): Sequential(\n (0)...,DenseNeuralNet(\n (net): Sequential(\n (0)...,,,,
7,california_simple.pt,0.95,DenseNeuralNet(\n (net): Sequential(\n (0)...,DenseNeuralNet(\n (net): Sequential(\n (0)...,,,,
8,california_simple.pt,0.97,DenseNeuralNet(\n (net): Sequential(\n (0)...,DenseNeuralNet(\n (net): Sequential(\n (0)...,,,,
9,california_simple.pt,0.99,DenseNeuralNet(\n (net): Sequential(\n (0)...,DenseNeuralNet(\n (net): Sequential(\n (0)...,,,,
