<a href="https://colab.research.google.com/github/kanru-wang/coursera_quantization_pruning_distillation/blob/main/Quantization_and_Pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Quantization and Pruning

## Imports

In [1]:
import tensorflow as tf
import numpy as np
import os
import tempfile
import zipfile

<a name='utilities'></a>

## Utilities and constants

In [2]:
# GLOBAL VARIABLES

# String constants for model filenames.
# All paths are relative, so the files are written to the current working directory.
FILE_WEIGHTS = 'baseline_weights.h5'  # weights of the baseline model, saved before training
FILE_NON_QUANTIZED_H5 = 'non_quantized.h5'  # trained baseline Keras model
FILE_NON_QUANTIZED_TFLITE = 'non_quantized.tflite'  # baseline converted to TF Lite without quantization
FILE_PT_QUANTIZED = 'post_training_quantized.tflite'  # baseline converted with post-training quantization
FILE_QAT_QUANTIZED = 'quant_aware_quantized.tflite'  # quantization aware trained model, quantized
FILE_PRUNED_MODEL_H5 = 'pruned_model.h5'  # pruned Keras model with pruning wrappers stripped
FILE_PRUNED_QUANTIZED_TFLITE = 'pruned_quantized.tflite'  # pruned model converted with quantization
FILE_PRUNED_NON_QUANTIZED_TFLITE = 'pruned_non_quantized.tflite'  # NOTE(review): not referenced in the cells visible here

# Dictionaries to hold measurements, keyed by a human-readable model description.
# Later cells rebind these names to fresh dictionaries to start a new comparison.
MODEL_SIZE = {}
ACCURACY = {}

In [3]:
# UTILITY FUNCTIONS

def print_metric(metric_dict, metric_name):
    '''
    Prints one line per dictionary entry, in the dictionary's own order

    Args:
        metric_dict (dict) - maps a model description to its measured value
        metric_name (string) - label printed at the start of every line

    Returns:
        None
    '''
    for label in metric_dict:
        reading = metric_dict[label]
        print(f'{metric_name} for {label}: {reading}')


def model_builder():
    '''
    Builds the shallow CNN used for every MNIST experiment in this notebook

    The returned model is not compiled; callers compile it themselves.

    Returns:
        a new Keras Sequential model: 28x28 input, reshape to 28x28x1,
        3x3 convolution with 12 filters (ReLU), 2x2 max pooling, flatten,
        and a 10-unit softmax dense layer
    '''
    layers = tf.keras.layers

    # Layers listed in the order in which they are applied.
    stack = [
        layers.InputLayer(input_shape=(28, 28)),
        layers.Reshape(target_shape=(28, 28, 1)),
        layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dense(10, activation='softmax'),
    ]
    return tf.keras.Sequential(stack)


def evaluate_tflite_model(filename, x_test, y_test):
    '''
    Measures the accuracy of a given TF Lite model and test set

    Images are fed to the interpreter one at a time (batch size 1) as float32,
    and the index of the largest output value is taken as the predicted label.

    Args:
        filename (string) - filename of the model to load
        x_test (numpy array) - test images
        y_test (numpy array) - test labels

    Returns
        float showing the accuracy against the test set
    '''

    # Initialize the TF Lite Interpreter and allocate tensors
    interpreter = tf.lite.Interpreter(model_path=filename)
    interpreter.allocate_tensors()

    # Get input and output index (the first input and first output are used)
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # Initialize empty predictions list
    prediction_digits = []

    # Run predictions on every image in the "test" dataset.
    for test_image in x_test:
        # Pre-processing: add batch dimension and convert to float32 to match with
        # the model's input data format.
        test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, test_image)

        # Run inference.
        interpreter.invoke()

        # Post-processing: remove batch dimension and find the digit with highest
        # probability. get_tensor() returns a copy of the output, so no reference
        # to the interpreter's internal buffers is kept alive. Holding such a
        # reference (as interpreter.tensor() does) can make a later invoke() or
        # allocate_tensors() raise "There is at least 1 reference to internal data".
        output = interpreter.get_tensor(output_index)
        prediction_digits.append(np.argmax(output[0]))

    # Compare prediction results with ground truth labels to calculate accuracy.
    prediction_digits = np.array(prediction_digits)
    accuracy = (prediction_digits == y_test).mean()

    return accuracy


def get_gzipped_model_size(file):
    '''
    Returns size of the compressed model, in bytes.

    The file is written into a temporary zip archive using DEFLATE compression
    (the same algorithm gzip uses) and the size of that archive is returned.
    The temporary archive is deleted before returning.

    Args:
        file (string) - path of the model file to compress

    Returns:
        int size of the zip archive in bytes
    '''
    fd, zipped_file = tempfile.mkstemp('.zip')
    # mkstemp() hands back an already-open descriptor. The archive is written
    # by path below, so close the descriptor right away instead of leaking it.
    os.close(fd)

    try:
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
            f.write(file)

        return os.path.getsize(zipped_file)
    finally:
        # Do not leave one stray archive in the temp directory per call.
        os.remove(zipped_file)

## Download and Prepare the Dataset

In [4]:
# Load MNIST dataset
# load_data() returns a (images, labels) pair for training and another for testing.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
# (assumes the raw pixel values are in the range 0-255 — TODO confirm)
# Only the images are rescaled; the label arrays are used as loaded.
train_images = train_images / 255.0
test_images = test_images / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


## Baseline Model

In [5]:
# Create the baseline model
baseline_model = model_builder()

# Save the initial weights for use later.
# This happens before any training, so the file holds the untrained weights.
# The quantization aware training section loads them into a fresh model.
baseline_model.save_weights(FILE_WEIGHTS)

# Print the model summary
baseline_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 28, 28, 1)         0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 12)       0         
 )                                                               
                                                                 
 flatten (Flatten)           (None, 2028)              0         
                                                                 
 dense (Dense)               (None, 10)                20290     
                                                                 
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
____________________________________________________

In [6]:
# Setup the model for training
# 'accuracy' is the only metric, so evaluate() later returns (loss, accuracy).
baseline_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model for a single epoch, with shuffling of the training data disabled
baseline_model.fit(train_images, train_labels, epochs=1, shuffle=False)



<keras.callbacks.History at 0x7fdd89657af0>

In [7]:
# Get the baseline accuracy
# evaluate() returns (loss, accuracy); the loss is discarded.
_, ACCURACY['baseline Keras model'] = baseline_model.evaluate(test_images, test_labels)



In [8]:
# Save the Keras model (without optimizer state)
baseline_model.save(FILE_NON_QUANTIZED_H5, include_optimizer=False)

# Record the size on disk of the saved model file
MODEL_SIZE['baseline h5'] = os.path.getsize(FILE_NON_QUANTIZED_H5)

# Print records so far
print_metric(ACCURACY, "test accuracy")
print_metric(MODEL_SIZE, "model size in bytes")

test accuracy for baseline Keras model: 0.9620000123977661
model size in bytes for baseline h5: 98968


### Convert the model to TF Lite format

In [9]:
def convert_tflite(model, filename, quantize=False):
    '''
    Serializes a Keras model to TF Lite format and stores it on disk

    Args:
        model (Keras model) - model to convert to TF Lite
        filename (string) - path of the file that receives the converted model
        quantize (bool) - when True, the converter's default optimizations
            are enabled before converting

    Returns:
        None
    '''
    tflite_converter = tf.lite.TFLiteConverter.from_keras_model(model)

    # The converter is left at its defaults unless quantization is requested.
    if quantize:
        tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]

    serialized_model = tflite_converter.convert()

    # Write the serialized model as raw bytes.
    with open(filename, 'wb') as out_file:
        out_file.write(serialized_model)

This is not yet quantized

In [10]:
# Convert baseline model to TF Lite; quantize is left at its default (False)
convert_tflite(baseline_model, FILE_NON_QUANTIZED_TFLITE)



Slight decrease in model size when converting to `.tflite` format.

In [11]:
# Record the size on disk of the non-quantized TF Lite file
MODEL_SIZE['non quantized tflite'] = os.path.getsize(FILE_NON_QUANTIZED_TFLITE)

print_metric(MODEL_SIZE, 'model size in bytes')

model size in bytes for baseline h5: 98968
model size in bytes for non quantized tflite: 85012


*If you see `RuntimeError: There is at least 1 reference to internal data in the interpreter in the form of a numpy array or slice.`, try re-running the cell.*

In [12]:
# Measure test accuracy of the non-quantized TF Lite model (one image per inference call)
ACCURACY['non quantized tflite'] = evaluate_tflite_model(FILE_NON_QUANTIZED_TFLITE, test_images, test_labels)

In [13]:
# Compare the Keras baseline with its TF Lite conversion
print_metric(ACCURACY, 'test accuracy')

test accuracy for baseline Keras model: 0.9620000123977661
test accuracy for non quantized tflite: 0.962


### Post-Training Quantization

Convert 32-bit floating-point representations into 8-bit integers to reduce model size and speed up computation.

In [14]:
# Convert and quantize the baseline model (no retraining involved)
convert_tflite(baseline_model, FILE_PT_QUANTIZED, quantize=True)



In [15]:
# Get the model size (bytes on disk) of the post-training quantized file
MODEL_SIZE['post training quantized tflite'] = os.path.getsize(FILE_PT_QUANTIZED)

print_metric(MODEL_SIZE, 'model size')

model size for baseline h5: 98968
model size for non quantized tflite: 85012
model size for post training quantized tflite: 24256


About 4X reduction in model size in the quantized version


In [16]:
# Measure test accuracy of the post-training quantized TF Lite model
ACCURACY['post training quantized tflite'] = evaluate_tflite_model(FILE_PT_QUANTIZED, test_images, test_labels)

In [17]:
# Compare accuracy before and after post-training quantization
print_metric(ACCURACY, 'test accuracy')

test accuracy for baseline Keras model: 0.9620000123977661
test accuracy for non quantized tflite: 0.962
test accuracy for post training quantized tflite: 0.9616


## Quantization Aware Training

Perform quantization aware training before quantizing the model in order to preserve more accuracy. It simulates the loss of precision by inserting fake-quant nodes into the model during training, so the model learns to adapt to the loss of precision and gives more accurate predictions once quantized.

In [18]:
# Install the toolkit
!pip install tensorflow_model_optimization

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow_model_optimization
  Downloading tensorflow_model_optimization-0.7.3-py2.py3-none-any.whl (238 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/238.9 KB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m238.9/238.9 KB[0m [31m19.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorflow_model_optimization
Successfully installed tensorflow_model_optimization-0.7.3


If we pass in a model that is already trained, we need to recompile it before continuing training.

In [19]:
# Imported here rather than at the top because the package is installed by the
# preceding cell.
import tensorflow_model_optimization as tfmot

# method to quantize a Keras model
quantize_model = tfmot.quantization.keras.quantize_model

# Define the model architecture. Still the baseline architecture, but a fresh,
# untrained instance.
model_to_quantize = model_builder()

# Reinitialize weights with saved file: these are the weights the baseline
# model had before it was trained.
model_to_quantize.load_weights(FILE_WEIGHTS)

# Quantize the model
q_aware_model = quantize_model(model_to_quantize)

# `quantize_model` requires a recompile.
# Same optimizer, loss and metric as the baseline model.
q_aware_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

q_aware_model.summary()

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer (QuantizeLay  (None, 28, 28)           3         
 er)                                                             
                                                                 
 quant_reshape_1 (QuantizeWr  (None, 28, 28, 1)        1         
 apperV2)                                                        
                                                                 
 quant_conv2d_1 (QuantizeWra  (None, 26, 26, 12)       147       
 pperV2)                                                         
                                                                 
 quant_max_pooling2d_1 (Quan  (None, 13, 13, 12)       1         
 tizeWrapperV2)                                                  
                                                                 
 quant_flatten_1 (QuantizeWr  (None, 2028)            

The number of model params increased because of the nodes added by the `quantize_model()` method.

In [20]:
# Train the model with the same settings as the baseline (1 epoch, no shuffling)
q_aware_model.fit(train_images, train_labels, epochs=1, shuffle=False)



<keras.callbacks.History at 0x7fdd895d89a0>

In [21]:
# Reinitialize the dictionary: the accuracies recorded so far are dropped so
# that only the quantization aware results are printed below.
ACCURACY = {}

# Get the accuracy of the quantization aware trained model (not yet quantized)
# evaluate() returns (loss, accuracy); the loss is discarded.
_, ACCURACY['quantization aware non-quantized'] = q_aware_model.evaluate(test_images, test_labels, verbose=0)
print_metric(ACCURACY, 'test accuracy')

test accuracy for quantization aware non-quantized: 0.9614999890327454


In [22]:
# Convert and quantize the model.
convert_tflite(q_aware_model, FILE_QAT_QUANTIZED, quantize=True)

# Get the accuracy of the quantized model (evaluated through the TF Lite interpreter)
ACCURACY['quantization aware quantized'] = evaluate_tflite_model(FILE_QAT_QUANTIZED, test_images, test_labels)
print_metric(ACCURACY, 'test accuracy')



test accuracy for quantization aware non-quantized: 0.9614999890327454
test accuracy for quantization aware quantized: 0.9617


## Pruning

Zero out insignificant low magnitude weights. Making the weights sparse helps in compressing the model.

Pass in the baseline model trained earlier. The number of model params increased because of the wrapper layers added by the pruning method.

In [23]:
# Get the pruning method
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
# NOTE(review): end_step is only the true number of training steps if fit()
# is run with this batch_size, these epochs and this validation_split.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set. 

# Images left for training once the validation split is removed
num_images = train_images.shape[0] * (1 - validation_split)
# Steps per epoch (rounded up) times the number of epochs
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define pruning schedule.
# Target sparsity goes from 50% at step 0 to 80% at end_step.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        end_step=end_step
    )
}

# Pass in the trained baseline model
model_for_pruning = prune_low_magnitude(baseline_model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

model_for_pruning.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_reshape  (None, 28, 28, 1)        1         
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_conv2d   (None, 26, 26, 12)       230       
 (PruneLowMagnitude)                                             
                                                                 
 prune_low_magnitude_max_poo  (None, 13, 13, 12)       1         
 ling2d (PruneLowMagnitude)                                      
                                                                 
 prune_low_magnitude_flatten  (None, 2028)             1         
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_dense (  (None, 10)               4

Peek at the weights of one layer before pruning. After pruning, many of these will be zeroed out.

In [24]:
# Preview model weights before pruning.
# Index 1 is the convolution kernel (the output shows it as 'conv2d/kernel:0').
model_for_pruning.weights[1]

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 12) dtype=float32, numpy=
array([[[[ 0.21670069,  0.14448787, -0.3921924 , -0.5839371 ,
           0.27797788,  0.11577412, -0.39421183, -0.5164737 ,
           0.03561811,  0.23253371,  0.43321967,  0.17293948]],

        [[ 0.35277528,  0.47311407, -0.05976715, -0.52713066,
           0.07559928,  0.22698393, -0.06401199, -0.3694725 ,
          -0.11109661,  0.27704132,  0.13577738,  0.29972312]],

        [[-0.00256522,  0.7086048 ,  0.37773892, -0.3767449 ,
           0.08378106,  0.05992218, -0.09731199,  0.32304454,
           0.03348943, -0.07085457, -0.7070701 ,  0.1516765 ]]],


       [[[ 0.15074071, -0.2181239 , -0.4557513 , -0.07979532,
           0.27309716,  0.22540969, -0.18514894, -0.6064833 ,
          -0.20642167,  0.15235803,  0.5228574 , -0.05335016]],

        [[-0.10594414,  0.13728   ,  0.14186494,  0.23395713,
          -0.0799045 ,  0.1545502 , -0.41689992,  0.03875222,
           0.2576298 ,  0.17784038,  0.173412

Start re-training

In [25]:
# Callback to update pruning wrappers at each step
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
]

# Train and prune the model
model_for_pruning.fit(
    train_images,
    train_labels,
    epochs=epochs,
    validation_split=validation_split,
    callbacks=callbacks
)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7fdd8901da30>

Weights in the same layer after pruning

In [26]:
# Preview model weights after pruning.
# Same variable as in the earlier preview (index 1, the convolution kernel).
model_for_pruning.weights[1]

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 12) dtype=float32, numpy=
array([[[[-0.        ,  0.        , -0.        , -1.0433143 ,
           0.        , -0.        , -0.9726265 , -1.0005678 ,
          -0.        ,  0.        ,  0.74551606,  0.        ]],

        [[ 0.93137616,  0.82501125, -0.        , -0.31208834,
           0.        , -0.        , -0.        , -0.        ,
           0.        ,  0.        , -0.        ,  0.        ]],

        [[-0.        ,  1.1437249 ,  0.6497642 , -0.        ,
           0.        , -0.        , -0.        , -0.        ,
           0.        , -0.        , -1.0637197 ,  0.        ]]],


       [[[-0.        ,  0.        ,  0.        , -0.        ,
           0.        , -0.        , -0.        , -1.0609056 ,
           0.        ,  0.        ,  0.9937551 ,  0.        ]],

        [[-0.        ,  0.        , -0.        , -0.        ,
           0.        , -0.        , -0.        , -0.        ,
           0.        ,  0.        , -0.      

After pruning, remove the wrapper layers to have the same layers and params as the baseline model.

In [27]:
# Remove pruning wrappers so that only the original layers remain;
# the summary below shows the same layers and parameter counts as the baseline.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
model_for_export.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape (Reshape)           (None, 28, 28, 1)         0         
                                                                 
 conv2d (Conv2D)             (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 13, 13, 12)       0         
 )                                                               
                                                                 
 flatten (Flatten)           (None, 2028)              0         
                                                                 
 dense (Dense)               (None, 10)                20290     
                                                                 
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
____________________________________________________

For the same model weights, the index is now different, because the wrappers were removed.

In [28]:
# Preview model weights (index 1 earlier is now 0 because pruning wrappers were removed)
# This is still the convolution kernel, with the pruned values unchanged.
model_for_export.weights[0]

<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 12) dtype=float32, numpy=
array([[[[-0.        ,  0.        , -0.        , -1.0433143 ,
           0.        , -0.        , -0.9726265 , -1.0005678 ,
          -0.        ,  0.        ,  0.74551606,  0.        ]],

        [[ 0.93137616,  0.82501125, -0.        , -0.31208834,
           0.        , -0.        , -0.        , -0.        ,
           0.        ,  0.        , -0.        ,  0.        ]],

        [[-0.        ,  1.1437249 ,  0.6497642 , -0.        ,
           0.        , -0.        , -0.        , -0.        ,
           0.        , -0.        , -1.0637197 ,  0.        ]]],


       [[[-0.        ,  0.        ,  0.        , -0.        ,
           0.        , -0.        , -0.        , -1.0609056 ,
           0.        ,  0.        ,  0.9937551 ,  0.        ]],

        [[-0.        ,  0.        , -0.        , -0.        ,
           0.        , -0.        , -0.        , -0.        ,
           0.        ,  0.        , -0.      

The pruned model has the same file size as the baseline_model when saved as H5, which is to be expected. The improvement will be noticeable after compressing.

In [29]:
# Save Keras model (stripped of pruning wrappers, without optimizer state)
model_for_export.save(FILE_PRUNED_MODEL_H5, include_optimizer=False)

# Get uncompressed model size of baseline and pruned models
# The dictionary is reset, so the sizes recorded earlier are dropped.
MODEL_SIZE = {}
MODEL_SIZE['baseline h5'] = os.path.getsize(FILE_NON_QUANTIZED_H5)
MODEL_SIZE['pruned non quantized h5'] = os.path.getsize(FILE_PRUNED_MODEL_H5)

print_metric(MODEL_SIZE, 'model_size in bytes')



model_size in bytes for baseline h5: 98968
model_size in bytes for pruned non quantized h5: 98968


After compression, the pruned model is about 3 times smaller than the baseline, because the zeros can be compressed much more efficiently than the low magnitude weights they replaced.

In [30]:
# Get compressed size of baseline and pruned models
# The dictionary is reset again: from here on it holds compressed sizes only.
MODEL_SIZE = {}
MODEL_SIZE['baseline h5'] = get_gzipped_model_size(FILE_NON_QUANTIZED_H5)
MODEL_SIZE['pruned non quantized h5'] = get_gzipped_model_size(FILE_PRUNED_MODEL_H5)

print_metric(MODEL_SIZE, "gzipped model size in bytes")

gzipped model size in bytes for baseline h5: 78044
gzipped model size in bytes for pruned non quantized h5: 25768


Can make the model even more lightweight by quantizing the pruned model. About 10X reduction in compressed model size as compared to the baseline.

In [31]:
# Convert and quantize the pruned model.
# convert_tflite() writes the file and returns None, so there is no result to keep.
convert_tflite(model_for_export, FILE_PRUNED_QUANTIZED_TFLITE, quantize=True)

# Compress and get the model size
MODEL_SIZE['pruned quantized tflite'] = get_gzipped_model_size(FILE_PRUNED_QUANTIZED_TFLITE)
print_metric(MODEL_SIZE, "gzipped model size in bytes")



gzipped model size in bytes for baseline h5: 78044
gzipped model size in bytes for pruned non quantized h5: 25768
gzipped model size in bytes for pruned quantized tflite: 8401


Accuracy remains good

In [33]:
# Get accuracy of pruned Keras and TF Lite models
# The dictionary is reset so only the pruning results are printed.
ACCURACY = {}

# The Keras figure comes from `model_for_pruning` (the compiled model that still
# has its pruning wrappers), not from `model_for_export` or the saved h5 file.
# NOTE(review): presumably because the stripped model is not compiled — confirm.
_, ACCURACY['pruned model h5'] = model_for_pruning.evaluate(test_images, test_labels)
ACCURACY['pruned and quantized tflite'] = evaluate_tflite_model(FILE_PRUNED_QUANTIZED_TFLITE, test_images, test_labels)

print_metric(ACCURACY, 'accuracy')

accuracy for pruned model h5: 0.9715999960899353
accuracy for pruned and quantized tflite: 0.9717
