In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import datasets
import matplotlib.pyplot as plt
from model import myModel
import os
import copy

# Notebook-wide settings shared by the cells below.
BATCH_SIZE = 16  # batch size for the test dataloader and the model.build() input shape
os.environ.update({"TF_FORCE_GPU_ALLOW_GROWTH": "true"})  # TensorFlow GPU memory-growth switch
tf.random.set_seed(0)  # fixed TensorFlow global seed

In [None]:
# Load the testing dataset and construct a dataloader.
# The .npz archive is opened in a `with` block so its file handle is closed
# as soon as both arrays have been read into memory.
with np.load('./data/testset.npz') as data:
    test_x, test_y = data['test_x'], data['test_y']
test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_dataset = test_dataset.shuffle(buffer_size=10000).batch(BATCH_SIZE)

# Load the pre-trained model.
model = myModel()
baseline = tf.train.latest_checkpoint('./models/')
if baseline is None:
    # latest_checkpoint returns None when the directory holds no checkpoint;
    # fail here with a clear message instead of deep inside load_weights().
    raise FileNotFoundError("No checkpoint found in './models/'")
model.load_weights(baseline).expect_partial()
opt = tf.keras.optimizers.Adam(learning_rate = 1e-3)
model.compile(
    optimizer=opt,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy']
)
model.build(input_shape=(BATCH_SIZE,10))

# Baseline parameters that the pruning sweep copies and modifies.
# NOTE(review): presumably dicts of numpy arrays sharing the same keys — confirm in model.py.
w_dict, b_dict = model.get_params()

# Flatten every trainable variable into one 1-D vector and sort the magnitudes
# ascending; the pruning threshold for a given ratio is read from this array.
w = np.concatenate([item.numpy().ravel() for item in model.trainable_variables])
w_sort = np.sort(np.abs(w), kind='mergesort')

In [None]:
# Pruning sweep: for each ratio, zero every weight/bias whose magnitude is
# below the matching global threshold, rebuild a model from the baseline
# checkpoint, apply the pruned (and quantized) parameters, and record the
# test-set metrics keyed by pruning ratio.

# Alternative dense sweep: prune_ratios = np.linspace(0,1.0,50)
prune_ratios = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6] # USER PARAMETER (PRUNING RATIO)

intbit = None  # USER PARAMETER (QUANTIZATION)
fracbit = None  # USER PARAMETER (QUANTIZATION)

thres_all = {}     # prune_ratio -> magnitude threshold used
accuracy_all = {}  # prune_ratio -> test accuracy
loss_all = {}      # prune_ratio -> test loss

# Loop-invariant values: resolve the checkpoint and count parameters once.
baseline = tf.train.latest_checkpoint('./models/')
n_params = len(w)

# Steps are numbered from 1 so the progress line reads "1/6 ... 6/6".
for i, prune_ratio in enumerate(prune_ratios, start=1):
    cut = int(n_params*prune_ratio)
    if cut >= len(w_sort):
        # No sorted magnitude exists at this index (e.g. prune_ratio == 1.0).
        # Skip only this ratio so an unsorted list does not lose later entries.
        print(f"Skipping prune ratio {prune_ratio}: threshold index {cut} is out of range for {len(w_sort)} parameters")
        continue

    thres = w_sort[cut]

    # Work on copies so the baseline dictionaries stay untouched across iterations.
    w_dict_new = copy.deepcopy(w_dict)
    b_dict_new = copy.deepcopy(b_dict)
    for m in w_dict_new.keys():
        # Strict '<': the element sitting exactly at the threshold is kept.
        w_dict_new[m][np.abs(w_dict_new[m])<thres] = 0.
        b_dict_new[m][np.abs(b_dict_new[m])<thres] = 0.

    # Fresh model restored from the baseline checkpoint for every ratio.
    my_model = myModel()
    my_model.load_weights(baseline).expect_partial()
    opt = tf.keras.optimizers.Adam(learning_rate = 1e-3)
    my_model.compile(
        optimizer=opt,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy']
    )
    my_model.build(input_shape=(BATCH_SIZE,10))

    # Optimize the model
    my_model.update_params(w_dict_new, b_dict_new)
    my_model.quantize_params(intbit=intbit, fracbit=fracbit)

    size, size_base, drop_ratio = my_model.compute_model_size()
    loss, accuracy = my_model.evaluate(test_dataset, verbose=3)

    print(f"Testing step {i:3d}/{len(prune_ratios):d} \t Prune {prune_ratio*100:3.2f}% (threshold {thres:.12f}) - Dropped {drop_ratio*100:3.2f}% - Accuracy {accuracy*100:3.2f}%")

    thres_all[prune_ratio] = thres
    accuracy_all[prune_ratio] = accuracy
    loss_all[prune_ratio] = loss  # previously computed but never stored


In [None]:
# Plot test accuracy (top) and the magnitude threshold (bottom) against the
# pruning ratio. Each panel takes x and y from the same dictionary so the
# two sequences always correspond.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))

ax1.scatter(list(accuracy_all.keys()), list(accuracy_all.values()), marker='o', s=5)
ax1.set_title("Test accuracy vs. pruning rate")
ax1.set_xlabel("Pruning rate")
ax1.set_ylabel("Accuracy")

ax2.scatter(list(thres_all.keys()), list(thres_all.values()), marker='o', s=5)
ax2.set_title("Pruning threshold vs. pruning rate")
ax2.set_xlabel("Pruning rate")
ax2.set_ylabel("Threshold")

fig.tight_layout()
plt.show()  # render the figure without echoing the last call's Text repr