<a href="https://colab.research.google.com/github/malraharsh/Pruning-Experiments/blob/master/Pruning_Experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install -q tensorflow-model-optimization

In [None]:
import tempfile
import os

import tensorflow as tf
import numpy as np

from sklearn.model_selection import train_test_split
import pandas as pd

from tensorflow import keras

# %load_ext tensorboard

# os.mkdir('log')

import tensorflow_model_optimization as tfmot


from IPython.display import display

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

In [3]:
SHOW = False

In [None]:
# Load MNIST dataset
mnist = keras.datasets.mnist
(Train_images, Train_labels), (Test_images, Test_labels) = mnist.load_data()

# pct_data = 0.1
# top = int(np.ceil(Train_images.shape[0] * pct_data))

# (train_images, train_labels), (test_images, test_labels) = (Train_images[:top], Train_labels[:top]), (Test_images[:top], Test_labels[:top])
# print(top)

In [17]:
train_images = train_labels = test_images = test_labels = None

def change_pct_data(pct_data):    
    global train_images, train_labels, test_images, test_labels    
    top = int(np.ceil(Train_images.shape[0] * pct_data))
    (train_images, train_labels), (test_images, test_labels) = (Train_images[:top], Train_labels[:top]), (Test_images[:top], Test_labels[:top])
    print(f"No of data - {top}")

In [None]:
def train(train_images, train_labels, test_images, test_labels):

    # Normalize the input image so that each pixel value is between 0 to 1.
    train_images = train_images.copy() / 255.0 #!!!!! CAN REOMVE COPY
    test_images = test_images.copy() / 255.0

    # Define the model architecture.
    model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
    ])

    # Train the digit classification model
    model.compile(optimizer='adam',
                loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

    history = model.fit(
    train_images,
    train_labels,
    epochs=EPOCHS,
    validation_split=0.1,
    verbose=1
    )

    test_loss, test_accuracy = model.evaluate(test_images, test_labels, verbose=0)
    history.history['test_loss'] = [test_loss]
    history.history['test_accuracy'] = [test_accuracy]

    if SHOW:
        # _, baseline_train_accuracy = model.evaluate(
        # train_images, train_labels, verbose=0)

        # print('Baseline train accuracy:', baseline_train_accuracy*100)

        print('Baseline test accuracy:', baseline_test_accuracy*100)

        # print('Baseline difference:', (baseline_train_accuracy - baseline_test_accuracy)*100)

    return model, history.history
    

def prune(train_images, train_labels, model):

    # Compute end step to finish pruning after 2 epochs.
    batch_size = 128
    epochs = EPOCHS_PRUNE
    validation_split = 0.1 # 10% of training set will be used for validation set. 

    num_images = train_images.shape[0] * (1 - validation_split)
    end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

    # Define model for pruning.
    pruning_params = {
          'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                                   final_sparsity=0.80,
                                                                   begin_step=0,
                                                                   end_step=end_step)
    }

    model_for_pruning = prune_low_magnitude(model, **pruning_params)

    # `prune_low_magnitude` requires a recompile.
    model_for_pruning.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    # model_for_pruning.summary()

    callbacks = [
      tfmot.sparsity.keras.UpdatePruningStep(),
    #   tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
    ]

    history = model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                      callbacks=callbacks)
    
    test_loss, test_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)
    history.history['test_loss'] = [test_loss]
    history.history['test_accuracy'] = [test_accuracy]

    if SHOW:
        print('Pruned test accuracy:', model_for_pruning_accuracy)

    return model_for_pruning, history.history

In [None]:
def do_train(test_pct):
    train_x, val_x, train_y, val_y = train_test_split(train_images, train_labels, test_size=test_pct, stratify=train_labels)
    
    print("TRAINING ---")
    model, info_train = train(train_x, train_y, val_x, val_y)
    
    print("PRUNING ---")
    _, info_prune = prune(train_x, train_y, model)
    
    return info_train, info_prune

In [None]:
def add_info(dic, pct, info):
    # print(list(dic.items()), '----')
    dic = {k:v[-1] for k, v in dic.items()}
    dic['percentage'] = pct
    return info.append(dic, ignore_index=True)   

def save(df, name, pct_data):
    df.to_csv(f'log/info-{name}-{pct_data*100}%.csv')

def full_pct_data(pct_data): #pct of full data
    change_pct_data(pct_data) 
    df_info_train = pd.DataFrame()
    df_info_prune = pd.DataFrame()

    for p in range(10, 99, 90):
        print(f'\n Percentage is {pct_data} \n')

        test_p = 1 - p/100
        info_train, info_prune = do_train(test_p)
        z = info_train

        df_info_train = add_info(info_train, 1 - test_p, df_info_train)
        df_info_prune = add_info(info_prune, 1 - test_p, df_info_prune)

    save(df_info_train, 'train', pct_data)
    save(df_info_prune, 'prune', pct_data)

    # df_info.plot.scatter(x='percentage', y='accuracy')
    return df_info_train, df_info_prune    

In [148]:
def on_pct_data(p):
    x, y = full_pct_data(p)

    print()
    display('Trained', x)
    print(x.accuracy - x.test_accuracy)
    display('Pruned', y)
    print(y.accuracy - y.test_accuracy)
    # display('Difference', x - y)

In [152]:
EPOCHS = 20
EPOCHS_PRUNE = 1

on_pct_data(0.1)

No of data - 6000

 Percentage is 0.1 

TRAINING ---
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
PRUNING ---



'Trained'

Unnamed: 0,accuracy,loss,percentage,test_accuracy,test_loss,val_accuracy,val_loss
0,0.988889,0.081956,0.1,0.879074,0.426488,0.9,0.278398


0    0.109815
dtype: float64


'Pruned'

Unnamed: 0,accuracy,loss,percentage,test_accuracy,test_loss,val_accuracy,val_loss
0,0.981481,3.143262,0.1,0.8525,69.221054,0.883333,37.213779


0    0.128981
dtype: float64
