<a href="https://colab.research.google.com/github/malraharsh/Pruning-Experiments/blob/master/Pruning_Experiments_on_CIFAR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install -q tensorflow-model-optimization

import tempfile
import os

import tensorflow as tf
import numpy as np

from sklearn.model_selection import train_test_split
import pandas as pd


from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

from tensorflow import keras

# %load_ext tensorboard

os.mkdir('log')

import tensorflow_model_optimization as tfmot


from IPython.display import display

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

SHOW = False

In [None]:
data = datasets.cifar10.load_data()

(Train_images, Train_labels), (Test_images, Test_labels) = data

# pct_data = 0.1
# top = int(np.ceil(Train_images.shape[0] * pct_data))

# (train_images, train_labels), (test_images, test_labels) = (Train_images[:top], Train_labels[:top]), (Test_images[:top], Test_labels[:top])
# print(top)

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

In [8]:
train_images = train_labels = test_images = test_labels = None

def change_pct_data(pct_data):    
    global train_images, train_labels, test_images, test_labels    
    top = int(np.ceil(Train_images.shape[0] * pct_data))
    (train_images, train_labels), (test_images, test_labels) = (Train_images[:top], Train_labels[:top]), (Test_images[:top], Test_labels[:top])
    print(f"No of data - {top}")

In [None]:
def train(train_images, train_labels, test_images, test_labels):

    # Normalize the input image so that each pixel value is between 0 to 1.
    train_images = train_images.copy() / 255.0 #!!!!! CAN REOMVE COPY
    test_images = test_images.copy() / 255.0

    # Define the model architecture.
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10))

    model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

    history = model.fit(train_images, train_labels, epochs=EPOCHS, 
                        validation_data=(test_images, test_labels), verbose=VERBOSE)

    #no test, only val data
    # test_loss, test_accuracy = model.evaluate(test_images, test_labels, verbose=0)
    # history.history['test_loss'] = [test_loss]
    # history.history['test_accuracy'] = [test_accuracy]

    if SHOW:

        plt.plot(history.history['accuracy'], label='accuracy')
        plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.ylim([0.5, 1])
        plt.legend(loc='lower right')

        # test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)

    
    return model, history.history
    

def prune(train_images, train_labels, model):

    # Compute end step to finish pruning after 2 epochs.
    batch_size = 128
    epochs = EPOCHS_PRUNE
    validation_split = 0.1 # 10% of training set will be used for validation set. 

    num_images = train_images.shape[0] * (1 - validation_split)
    end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

    # Define model for pruning.
    pruning_params = {
          'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                                   final_sparsity=0.80,
                                                                   begin_step=0,
                                                                   end_step=end_step)
    }

    model_for_pruning = prune_low_magnitude(model, **pruning_params)

    # `prune_low_magnitude` requires a recompile.
    model_for_pruning.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    # model_for_pruning.summary()

    callbacks = [
      tfmot.sparsity.keras.UpdatePruningStep(),
    #   tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
    ]

    history = model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs, validation_data=(test_images, test_labels),
                      callbacks=callbacks, verbose=VERBOSE)
    
    # test_loss, test_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)
    # history.history['test_loss'] = [test_loss]
    # history.history['test_accuracy'] = [test_accuracy]

    if SHOW:
        print('Pruned test accuracy:', test_accuracy)

    return model_for_pruning, history.history

In [None]:
def do_train(test_pct):
    train_x, val_x, train_y, val_y = train_test_split(train_images, train_labels, test_size=test_pct, stratify=train_labels)
    
    print("TRAINING ---")
    model, info_train = train(train_x, train_y, val_x, val_y)
    
    print("PRUNING ---")
    _, info_prune = prune(train_x, train_y, model)
    
    return info_train, info_prune

In [None]:
def add_info(dic, pct_train, info):
    # print(list(dic.items()), '----')
    dic = {k:v[-1] for k, v in dic.items()}
    dic['percentage'] = pct_train
    return info.append(dic, ignore_index=True)   

def save(df, name, pct_data):
    df.to_csv(f'log/info-{name}-{pct_data*100}%.csv')


# def savefile():
#     df_info_train = add_info(info_train, 1 - pct_test, df_info_train)
#     df_info_prune = add_info(info_prune, 1 - pct_test, df_info_prune)

#     save(df_info_train, 'train', pct_data)
#     save(df_info_prune, 'prune', pct_data)


def full_pct_data(pct_data, pct_test=0.2): #pct of full data
    change_pct_data(pct_data) 
    df_info_train = pd.DataFrame()
    df_info_prune = pd.DataFrame()

    print(f'\n Percentage of Whole data {pct_data*100}% Test data {pct_test*100}% \n')

    info_train, info_prune = do_train(pct_test)
    
    df_info_train = add_info(info_train, 1 - pct_test, df_info_train)
    df_info_prune = add_info(info_prune, 1 - pct_test, df_info_prune)

    # save(df_info_train, 'train', pct_data)
    # save(df_info_prune, 'prune', pct_data)

    # df_info.plot.scatter(x='percentage', y='accuracy')
    return df_info_train, df_info_prune    

In [74]:
def on_pct_data(p, ptest):
    global x, y, dft, dfp
    x, y = full_pct_data(p, ptest)
    x['epochs'] = EPOCHS
    x['pct_data'] = p
    x['pct_test'] = ptest

    y['epochs'] = EPOCHS_PRUNE
    y['pct_data'] = p
    y['pct_test'] = ptest

    dft = dft.append(x, ignore_index=True, verify_integrity=True) #trained
    dfp = dfp.append(y, ignore_index=True, verify_integrity=True)

    print()
    display('Trained', x)
    display('Pruned', y)
    print()
    print('Training Acc. Diff', (x.accuracy[0] - x.val_accuracy[0])*100)
    print('Pruned Acc. Diff', (y.accuracy[0] - y.val_accuracy[0])*100)
    # display('Difference', x - y)

In [53]:
EPOCHS = 10
EPOCHS_PRUNE = 3
VERBOSE = 1
SHOW = 0

In [75]:
dft = pd.DataFrame()
dfp = pd.DataFrame()

In [76]:
on_pct_data(0.1, 0.5)

No of data - 5000

 Percentage of Whole data 10.0% Test data 50.0% 

TRAINING ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
PRUNING ---
Epoch 1/3
Epoch 2/3
Epoch 3/3



'Trained'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.588,1.148828,0.5,0.4948,1.398556,10,0.1,0.5


'Pruned'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.5852,12.945297,0.5,0.4112,19.885075,3,0.1,0.5



Training Acc. Diff 9.319999814033508
Pruned Acc. Diff 17.400002479553223


In [77]:
on_pct_data(0.4, 0.5)

No of data - 20000

 Percentage of Whole data 40.0% Test data 50.0% 

TRAINING ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
PRUNING ---
Epoch 1/3
Epoch 2/3
Epoch 3/3



'Trained'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.6893,0.885093,0.5,0.5931,1.191358,10,0.4,0.5


'Pruned'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.2195,3.097613,0.5,0.1292,2.382054,3,0.4,0.5



Training Acc. Diff 9.619998931884766
Pruned Acc. Diff 9.030000865459442


In [78]:
on_pct_data(0.8, 0.5)

No of data - 40000

 Percentage of Whole data 80.0% Test data 50.0% 

TRAINING ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
PRUNING ---
Epoch 1/3
Epoch 2/3
Epoch 3/3



'Trained'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.7534,0.697526,0.5,0.65135,1.040999,10,0.8,0.5


'Pruned'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.1037,2.330379,0.5,0.1063,2.318101,3,0.8,0.5



Training Acc. Diff 10.205000638961792
Pruned Acc. Diff -0.25999993085861206


In [80]:
on_pct_data(1.0, 0.5)

No of data - 50000

 Percentage of Whole data 100.0% Test data 50.0% 

TRAINING ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
PRUNING ---
Epoch 1/3
Epoch 2/3
Epoch 3/3



'Trained'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.76232,0.667134,0.5,0.63868,1.115362,10,1.0,0.5


'Pruned'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.10596,2.312976,0.5,0.1092,2.31258,3,1.0,0.5



Training Acc. Diff 12.364000082015991
Pruned Acc. Diff -0.32400041818618774


In [81]:
on_pct_data(0.1, 0.9)

No of data - 5000

 Percentage of Whole data 10.0% Test data 90.0% 

TRAINING ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
PRUNING ---
Epoch 1/3
Epoch 2/3
Epoch 3/3



'Trained'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.39,1.670102,0.1,0.333111,1.862452,10,0.1,0.9


'Pruned'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.476,38.303448,0.1,0.298,53.598587,3,0.1,0.9



Training Acc. Diff 5.68888783454895
Pruned Acc. Diff 17.800000309944153


In [82]:
on_pct_data(0.1, 0.6)

No of data - 5000

 Percentage of Whole data 10.0% Test data 60.0% 

TRAINING ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
PRUNING ---
Epoch 1/3
Epoch 2/3
Epoch 3/3



'Trained'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.5715,1.19727,0.4,0.465667,1.477077,10,0.1,0.6


'Pruned'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.561,12.313752,0.4,0.3944,20.094696,3,0.1,0.6



Training Acc. Diff 10.583332180976868
Pruned Acc. Diff 16.659998893737793


In [83]:
on_pct_data(0.1, 0.3)

No of data - 5000

 Percentage of Whole data 10.0% Test data 30.0% 

TRAINING ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
PRUNING ---
Epoch 1/3
Epoch 2/3
Epoch 3/3



'Trained'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.628857,1.046791,0.7,0.486,1.519135,10,0.1,0.3


'Pruned'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.636571,15.862948,0.7,0.4712,29.236944,3,0.1,0.3



Training Acc. Diff 14.28571343421936
Pruned Acc. Diff 16.537141799926758


In [84]:
on_pct_data(0.1, 0.1)

No of data - 5000

 Percentage of Whole data 10.0% Test data 10.0% 

TRAINING ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
PRUNING ---
Epoch 1/3
Epoch 2/3
Epoch 3/3



'Trained'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.604222,1.1205,0.9,0.514,1.304809,10,0.1,0.1


'Pruned'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.576889,10.227122,0.9,0.423,13.741801,3,0.1,0.1



Training Acc. Diff 9.022223949432373
Pruned Acc. Diff 15.388885140419006


In [85]:
for i in range(10, 100, 30):
    on_pct_data(1, i/100)

No of data - 50000

 Percentage of Whole data 100% Test data 10.0% 

TRAINING ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
PRUNING ---
Epoch 1/3
Epoch 2/3
Epoch 3/3



'Trained'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.778667,0.627731,0.9,0.7082,0.888582,10,1,0.1


'Pruned'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.115,2.307281,0.9,0.1001,2.308947,3,1,0.1



Training Acc. Diff 7.0466697216033936
Pruned Acc. Diff 1.4899998903274536
No of data - 50000

 Percentage of Whole data 100% Test data 40.0% 

TRAINING ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
PRUNING ---
Epoch 1/3
Epoch 2/3
Epoch 3/3



'Trained'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.766333,0.669397,0.6,0.66735,1.015857,10,1,0.4


'Pruned'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.111067,2.325177,0.6,0.1039,2.314982,3,1,0.4



Training Acc. Diff 9.898334741592407
Pruned Acc. Diff 0.7166668772697449
No of data - 50000

 Percentage of Whole data 100% Test data 70.0% 

TRAINING ---
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
PRUNING ---
Epoch 1/3
Epoch 2/3
Epoch 3/3



'Trained'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.720267,0.79033,0.3,0.620629,1.118874,10,1,0.7


'Pruned'

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.098933,2.323755,0.3,0.1,2.308626,3,1,0.7



Training Acc. Diff 9.963804483413696
Pruned Acc. Diff -0.10666698217391968


In [86]:
dft

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.588,1.148828,0.5,0.4948,1.398556,10,0.1,0.5
1,0.6893,0.885093,0.5,0.5931,1.191358,10,0.4,0.5
2,0.7534,0.697526,0.5,0.65135,1.040999,10,0.8,0.5
3,0.76232,0.667134,0.5,0.63868,1.115362,10,1.0,0.5
4,0.39,1.670102,0.1,0.333111,1.862452,10,0.1,0.9
5,0.5715,1.19727,0.4,0.465667,1.477077,10,0.1,0.6
6,0.628857,1.046791,0.7,0.486,1.519135,10,0.1,0.3
7,0.604222,1.1205,0.9,0.514,1.304809,10,0.1,0.1
8,0.778667,0.627731,0.9,0.7082,0.888582,10,1.0,0.1
9,0.766333,0.669397,0.6,0.66735,1.015857,10,1.0,0.4


In [87]:
dfp

Unnamed: 0,accuracy,loss,percentage,val_accuracy,val_loss,epochs,pct_data,pct_test
0,0.5852,12.945297,0.5,0.4112,19.885075,3,0.1,0.5
1,0.2195,3.097613,0.5,0.1292,2.382054,3,0.4,0.5
2,0.1037,2.330379,0.5,0.1063,2.318101,3,0.8,0.5
3,0.10596,2.312976,0.5,0.1092,2.31258,3,1.0,0.5
4,0.476,38.303448,0.1,0.298,53.598587,3,0.1,0.9
5,0.561,12.313752,0.4,0.3944,20.094696,3,0.1,0.6
6,0.636571,15.862948,0.7,0.4712,29.236944,3,0.1,0.3
7,0.576889,10.227122,0.9,0.423,13.741801,3,0.1,0.1
8,0.115,2.307281,0.9,0.1001,2.308947,3,1.0,0.1
9,0.111067,2.325177,0.6,0.1039,2.314982,3,1.0,0.4
