Necessary imports

In [None]:
!pip install tensorflow-model-optimization

Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl.metadata (904 bytes)
Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/242.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorflow-model-optimization
Successfully installed tensorflow-model-optimization-0.8.0


In [None]:
import os
import tensorflow as tf
import tensorflow_model_optimization as tfmot
import numpy as np

In [None]:
# # function to create lenet model
# def build_lenet5_mnist():


#     model = tf.keras.Sequential([

#         # 4 layers -> 2 convolutional and 2 pooling
#         tf.keras.layers.Conv2D(6, kernel_size = 5, strides = 1, padding = 'same', activation = 'relu', input_shape = (28, 28, 1)),
#         tf.keras.layers.MaxPooling2D(pool_size = 2, strides = 2),

#         tf.keras.layers.Conv2D(16, kernel_size = 5, strides = 1, activation = 'relu'),
#         tf.keras.layers.MaxPooling2D(pool_size = 2, strides = 2),

#         # 2D -> 1D
#         tf.keras.layers.Flatten(),

#         # 3 fully connected layers
#         tf.keras.layers.Dense(120, activation = 'relu'),
#         tf.keras.layers.Dense(84, activation = 'relu'),
#         tf.keras.layers.Dense(10, activation = 'softmax')
#     ])

#     return model

In [None]:
# Build LeNet-5 with magnitude pruning (polynomial-decay sparsity schedule) on its conv and hidden dense layers
import tensorflow as tf
import tensorflow_model_optimization as tfmot

def build_lenet5_mnist_with_pruning(final_sparsity=0.98, end_step=6000,
                                    initial_sparsity=0.0, begin_step=0):
    """
    Build a LeNet-5 for 28x28x1 MNIST inputs with magnitude pruning applied to the
    two Conv2D layers and the two hidden Dense layers.

    All four pruned layers share one PolynomialDecay schedule. It raises the target
    sparsity from `initial_sparsity` at `begin_step` to `final_sparsity` at `end_step`.
    Steps are advanced by the `UpdatePruningStep` callback during training, so
    `final_sparsity` is only reached if training runs for at least `end_step`
    optimizer steps. The 10-way softmax output layer is not pruned.

    Args:
        final_sparsity: target fraction of prunable weights set to zero at the end
            of the schedule (default 0.98).
        end_step: training step at which `final_sparsity` is reached (default 6000).
        initial_sparsity: target sparsity at `begin_step` (default 0.0, i.e. no pruning).
        begin_step: training step at which pruning starts (default 0).

    Returns:
        An uncompiled tf.keras.Sequential whose pruned layers are wrapped with
        prune_low_magnitude. Strip the wrappers (tfmot.sparsity.keras.strip_pruning)
        after training and before export.
    """
    # One schedule object shared by every pruned layer, so they all follow the same target
    pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=initial_sparsity,
        final_sparsity=final_sparsity,
        begin_step=begin_step,
        end_step=end_step
    )

    def pruned(layer):
        """Wrap `layer` so its weights are magnitude-pruned on the shared schedule."""
        return tfmot.sparsity.keras.prune_low_magnitude(layer, pruning_schedule=pruning_schedule)

    model = tf.keras.Sequential([
        # Two pruned conv blocks, each followed by 2x2 max pooling
        pruned(tf.keras.layers.Conv2D(6, kernel_size=5, strides=1, padding='same',
                                      activation='relu', input_shape=(28, 28, 1))),
        tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),

        pruned(tf.keras.layers.Conv2D(16, kernel_size=5, strides=1, activation='relu')),
        tf.keras.layers.MaxPooling2D(pool_size=2, strides=2),

        # Flatten has no weights, so there is nothing to prune
        tf.keras.layers.Flatten(),

        # Pruned hidden dense layers
        pruned(tf.keras.layers.Dense(120, activation='relu')),
        pruned(tf.keras.layers.Dense(84, activation='relu')),

        # Output layer is not pruned
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    return model



In [None]:
def load_data_for_pruning(validation_split=0.25):
    """
    Load MNIST for the pruning experiments, keeping integer labels.

    Images are scaled to float32 in [0, 1] and given a trailing channel axis,
    (N, 28, 28) -> (N, 28, 28, 1). Labels stay integer class ids (int32), as
    sparse_categorical_crossentropy expects.

    Args:
        validation_split: fraction of the training set, taken from its END, that is
            held out for validation. Must be in [0, 1). If None, the test set is
            returned in the validation slot instead.

    Returns:
        ((x_train, y_train), (x_val, y_val), (x_test, y_test))

    Raises:
        ValueError: if validation_split is not None and lies outside [0, 1).
    """
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # Normalize images to [0,1] and convert to float32
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0

    # Reshape from (28, 28) to (28, 28, 1) to match CNN input format
    x_train = x_train[..., tf.newaxis]
    x_test = x_test[..., tf.newaxis]

    # Keep labels as integers (NOT one-hot encoded) for sparse_categorical_crossentropy
    y_train = y_train.astype(np.int32)
    y_test = y_test.astype(np.int32)

    if validation_split is None:
        return (x_train, y_train), (x_test, y_test), (x_test, y_test)

    if not 0 <= validation_split < 1:
        raise ValueError(f"validation_split must be in [0, 1) or None, got {validation_split}")

    # Slice with an explicit split index: `x[:-n]` is EMPTY when n == 0, which would
    # silently discard the whole training set for small or zero splits.
    num_validation_samples = int(validation_split * x_train.shape[0])
    split_index = x_train.shape[0] - num_validation_samples
    x_train, x_val = x_train[:split_index], x_train[split_index:]
    y_train, y_val = y_train[:split_index], y_train[split_index:]
    return (x_train, y_train), (x_val, y_val), (x_test, y_test)


In [None]:
# function to load MNIST dataset and do some preprocessing
def load_data(validation_split = 0.25):
    """
    Load MNIST with one-hot labels, for training with categorical_crossentropy.

    Images are scaled to float32 in [0, 1] and given a trailing channel axis,
    (N, 28, 28) -> (N, 28, 28, 1). Labels are one-hot encoded to shape (N, 10).

    Args:
        validation_split: fraction of the training set, taken from its END, that is
            held out for validation. Must be in [0, 1). If None, the test set is
            returned in the validation slot instead.

    Returns:
        ((x_train, y_train), (x_val, y_val), (x_test, y_test))

    Raises:
        ValueError: if validation_split is not None and lies outside [0, 1).
    """
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # Normalization can reduce training time significantly and (usually) increases the model's accuracy
    x_train = x_train.astype('float32') / 255.0
    x_test = x_test.astype('float32') / 255.0

    # expand dimensions from (28, 28) to (28, 28, 1) to match the input format of the first convolutional layer of the model
    x_train = x_train[..., tf.newaxis]
    x_test = x_test[..., tf.newaxis]

    # One-hot encode the labels, (N,) -> (N, 10), as categorical_crossentropy expects
    y_train = tf.keras.utils.to_categorical(y_train, 10)
    y_test = tf.keras.utils.to_categorical(y_test, 10)

    # The validation subset monitors behaviour during training (e.g. for early stopping);
    # the network's parameters are not updated on it.
    if validation_split is None:
        return (x_train, y_train), (x_test, y_test), (x_test, y_test)

    if not 0 <= validation_split < 1:
        raise ValueError(f"validation_split must be in [0, 1) or None, got {validation_split}")

    # Slice with an explicit split index: `x[:-n]` is EMPTY when n == 0, which would
    # silently discard the whole training set for small or zero splits.
    num_validation_samples = int(validation_split * x_train.shape[0])
    split_index = x_train.shape[0] - num_validation_samples
    x_train, x_val = x_train[:split_index], x_train[split_index:]
    y_train, y_val = y_train[:split_index], y_train[split_index:]
    return (x_train, y_train), (x_val, y_val), (x_test, y_test)

In [None]:
# function to train the model
def train_model(model, X_train, y_train, X_val, y_val, epochs = 25, learning_rate = 0.001,
                patience = 5, batch_size = 32):
    """
    Compile `model` for one-hot labels and fit it with early stopping.

    Optimizer: SGD with momentum 0.9 at `learning_rate`; loss: categorical
    cross-entropy; metric: accuracy. Training stops once `val_loss` has not improved
    by at least 0.001 for `patience` consecutive epochs, and the best weights seen
    are restored.

    Returns:
        The same model object, trained in place.
    """
    sgd = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
    model.compile(optimizer=sgd, loss='categorical_crossentropy', metrics=['accuracy'])

    # Early stopping shortens training and limits overfitting to the training split
    stop_early = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=patience,
        min_delta=0.001,
        restore_best_weights=True,
    )

    model.fit(
        X_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(X_val, y_val),
        callbacks=[stop_early],
    )
    return model

In [None]:
def train_prune_model(model, X_train, y_train, X_val, y_val, epochs=25, learning_rate=0.001, patience=5, batch_size=32,
                      log_dir='./pruning_logs', show_summary=True):
    """
    Compile and train a model wrapped with tfmot pruning layers.

    The model is (re)compiled with a fresh Adam optimizer at `learning_rate` and
    sparse categorical cross-entropy, so `y_train`/`y_val` must be integer class
    labels, not one-hot. Because the optimizer is rebuilt on every call, calling this
    repeatedly (one call per training stage) resets the optimizer state each time.

    Args:
        model: Keras model containing prune_low_magnitude wrappers.
        X_train, y_train: training inputs and integer labels.
        X_val, y_val: validation inputs and integer labels (monitored for early stopping).
        epochs: maximum epochs for this call.
        learning_rate: Adam learning rate for this call.
        patience: epochs without val_loss improvement before stopping.
        batch_size: mini-batch size.
        log_dir: directory where PruningSummaries writes its logs (default './pruning_logs').
        show_summary: print model.summary() before training (default True).

    Returns:
        The same model object, trained in place.
    """
    # A fresh optimizer per call lets each training stage use its own learning rate
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    if show_summary:
        # PruneLowMagnitude wrappers in the summary confirm pruning is applied
        model.summary()

    pruning_callbacks = [
        # Required by tfmot: updates the pruning step during training so the schedule advances
        tfmot.sparsity.keras.UpdatePruningStep(),
        # Logs pruning progress under log_dir
        tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir),
        # NOTE(review): restore_best_weights restores *all* model weights, which presumably
        # includes the pruning wrappers' non-trainable masks/step counter — confirm intended.
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience, restore_best_weights=True)
    ]

    # Pruning is applied dynamically during fit
    model.fit(X_train, y_train,
              epochs=epochs,
              batch_size=batch_size,
              validation_data=(X_val, y_val),
              callbacks=pruning_callbacks)

    return model


In [None]:
# function to evaluate the accuracy of the trained model
def evaluate_model(model, X_test, y_test):
    """
    Evaluate a compiled Keras model on the test set and report its accuracy.

    Expects `model.evaluate` to return exactly two values, [loss, accuracy],
    i.e. the model was compiled with a single accuracy metric.

    Returns:
        Test accuracy as a percentage (0-100).
    """
    loss, accuracy = model.evaluate(X_test, y_test)
    accuracy_pct = accuracy * 100
    print(f"Test Accuracy: {accuracy_pct:.2f}%")
    return accuracy_pct

In [None]:
# function to provide a small dataset sample for integer quantization
def representative_dataset(data=None, num_samples=100):
    """
    Representative dataset for integer quantization (calibration data used to scale
    the weights and inputs to the integer domain).

    Yields single-sample batches wrapped in a list, ``[array]``, which is the format
    ``TFLiteConverter.representative_dataset`` consumes.

    Args:
        data: input samples stacked along axis 0. Defaults to the global ``X_train``
            created by the training cell, so that cell must have run first.
        num_samples: maximum number of samples to yield (default 100). If ``data`` is
            shorter, only ``len(data)`` samples are yielded, rather than empty slices.
    """
    if data is None:
        data = X_train  # hidden dependency on the training cell's global
    for i in range(min(num_samples, len(data))):
        yield [data[i:i + 1].astype('float32')]

In [None]:
# function to convert the model to TFLite format (to deploy on our device)
def convert_to_tflite(model, filename = "pruned_model.tflite"):
    """
    Convert a Keras model to TFLite with no optimizations and write it to `filename`.

    Prints the output path and the resulting file size in KB.

    Returns:
        The path the model was written to (`filename`).
    """
    flatbuffer = tf.lite.TFLiteConverter.from_keras_model(model).convert()

    with open(filename, "wb") as out_file:
        out_file.write(flatbuffer)
    print(f"Model converted to TFLite and saved as {filename}")

    size_kb = os.path.getsize(filename) / 1024  # bytes -> KB
    print(f"TFLite Model Size: {size_kb:.2f} KB")

    return filename


In [None]:
# Function to generate a representative dataset for quantization
def representative_dataset(data=None, num_samples=100, seed=0, input_shape=(28, 28, 1)):
    """
    Calibration generator for full-integer TFLite quantization.

    Yields single-sample batches wrapped in a list, ``[array]`` (float32), which is the
    format ``TFLiteConverter.representative_dataset`` consumes.

    By default it yields seeded uniform noise in [0, 1) of shape ``(1, *input_shape)``.
    Noise is NOT representative of real inputs, so the activation ranges it calibrates
    can differ from those real data produces. Pass real, preprocessed samples via
    ``data`` whenever possible.

    Args:
        data: optional real input samples stacked along axis 0. When given, up to
            ``num_samples`` of them are yielded in order.
        num_samples: number of samples to yield (default 100).
        seed: RNG seed for the noise fallback, so calibration is reproducible.
        input_shape: per-sample input shape for the noise fallback (default (28, 28, 1)).
    """
    if data is not None:
        for i in range(min(num_samples, len(data))):
            yield [np.asarray(data[i:i + 1], dtype=np.float32)]
        return

    rng = np.random.default_rng(seed)
    for _ in range(num_samples):
        yield [rng.random((1, *input_shape), dtype=np.float32)]

# Convert a pruned model to a quantized TFLite model
def convert_to_quantized_tflite(model, filename="pruned_model.tflite", calibration_data=None,
                                num_calibration_samples=100):
    """
    Convert a Keras model to a fully int8-quantized TFLite model and save it.

    Enables post-training quantization (Optimize.DEFAULT), restricts ops to
    TFLITE_BUILTINS_INT8, and uses int8 input and output tensors.

    Args:
        model: Keras model to convert (strip pruning wrappers first for pruned models).
        filename: output path. NOTE: this default is the same as convert_to_tflite's
            default, so pass a distinct name to avoid overwriting the float model.
        calibration_data: optional real, preprocessed input samples stacked along
            axis 0. When given, the first `num_calibration_samples` are used for
            calibration. When None, the module-level `representative_dataset`
            generator is used.
        num_calibration_samples: maximum number of samples taken from calibration_data.

    Returns:
        The path the model was written to.

    Raises:
        ValueError: if calibration_data is given but contains no samples.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)

    # Enable post-training quantization
    converter.optimizations = [tf.lite.Optimize.DEFAULT]

    # Calibration samples set the activation ranges behind the int8 scales, so real
    # samples give better quantization accuracy than synthetic ones.
    if calibration_data is not None:
        calibration_samples = np.asarray(calibration_data[:num_calibration_samples], dtype=np.float32)
        if len(calibration_samples) == 0:
            raise ValueError("calibration_data must contain at least one sample")

        def calibration_generator():
            for i in range(len(calibration_samples)):
                yield [calibration_samples[i:i + 1]]

        converter.representative_dataset = calibration_generator
    else:
        converter.representative_dataset = representative_dataset

    # Enforce full integer quantization
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8  # Use int8 instead of uint8
    converter.inference_output_type = tf.int8

    tflite_model = converter.convert()

    with open(filename, "wb") as f:
        f.write(tflite_model)

    print(f" Quantized model saved as {filename}")

    model_size = os.path.getsize(filename) / 1024  # Convert bytes to KB
    print(f" Quantized Model Size: {model_size:.2f} KB")

    return filename

In [None]:
# function to perform inference for a tflite model
def _quantize_tflite_input(values, scale, zero_point, dtype):
    """Quantize floats to integer `dtype`: round to nearest, then saturate at the type's limits."""
    info = np.iinfo(dtype)
    # scale == 0 means the tensor carries no quantization parameters; use the values as-is
    scaled = values / scale + zero_point if scale else values
    return np.clip(np.round(scaled), info.min, info.max).astype(dtype)


def _dequantize_tflite_output(values, scale, zero_point):
    """Map quantized integer outputs back to float32 real values."""
    real = values.astype(np.float32)
    if not scale:
        return real
    return (real - zero_point) * scale


def tflite_inference(tflite_model_path, X_test, y_test):
    """
    Perform inference with a (float or integer-quantized) TFLite model and report accuracy.

    Samples are fed one at a time. For int8/uint8 input tensors, the float inputs
    are quantized with the tensor's (scale, zero_point) as
    q = round(x / scale + zero_point), clipped to the integer type's range. Plain
    truncation with astype would bias the values and wrap out-of-range ones around.
    For int8/uint8 output tensors, outputs are dequantized before the argmax.

    Args:
        tflite_model_path: path to the .tflite file.
        X_test: float input samples stacked along axis 0; each X_test[i:i+1] must
            match the interpreter's input shape.
        y_test: integer class labels of shape (N,), or one-hot labels of shape (N, C).

    Returns:
        Top-1 accuracy in percent.

    Raises:
        ValueError: if X_test contains no samples.
    """
    interpreter = tf.lite.Interpreter(model_path=tflite_model_path)
    interpreter.allocate_tensors()

    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]

    # For a fully quantized model these should be int8
    input_dtype = input_detail['dtype']
    output_dtype = output_detail['dtype']
    print(f"Input dtype: {input_dtype}, Output dtype: {output_dtype}")

    total = X_test.shape[0]
    if total == 0:
        raise ValueError("X_test contains no samples")

    quantized_types = (np.uint8, np.int8)
    if input_dtype in quantized_types:
        input_scale, input_zero_point = input_detail['quantization']
        model_inputs = _quantize_tflite_input(X_test, input_scale, input_zero_point, input_dtype)
    else:
        # set_tensor rejects dtype mismatches (e.g. float64 data for a float32 model)
        model_inputs = np.asarray(X_test, dtype=input_dtype)

    # Decode labels once: one-hot rows -> class index, integer labels pass through
    y_test = np.asarray(y_test)
    true_labels = np.argmax(y_test, axis=1) if y_test.ndim > 1 else y_test

    correct = 0
    for i in range(total):
        interpreter.set_tensor(input_detail['index'], model_inputs[i:i + 1])
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_detail['index'])[0]

        if output_dtype in quantized_types:
            output_scale, output_zero_point = output_detail['quantization']
            output_data = _dequantize_tflite_output(output_data, output_scale, output_zero_point)

        # Softmax is monotonic, so argmax over the raw output already gives the predicted class
        if np.argmax(output_data) == true_labels[i]:
            correct += 1

    accuracy = (correct / total) * 100
    print(f"TFLite Model Accuracy: {accuracy:.2f}%")
    return accuracy


In [None]:
# Specify device (CPU or GPU)
device_name = "/GPU:0" if tf.config.list_physical_devices('GPU') else "/CPU:0"
print(f"Training on: {device_name}")

with tf.device(device_name):

    # LeNet-5 wrapped with magnitude-pruning layers
    pruned_model = build_lenet5_mnist_with_pruning()
    pruned_model.summary()

    # Training runs in stages: one train_prune_model call per (epochs, learning_rate) pair
    epochs = [1, 1]
    learning_rate = [0.001, 0.0005]
    BATCH_SZ = 32
    patience = 3

    # load mnist (integer labels, for sparse_categorical_crossentropy)
    (X_train, y_train), (X_val, y_val), (X_test, y_test) = load_data_for_pruning()

    # The pruning schedule advances per optimizer step. Report the maximum number of steps
    # the stages can run (early stopping may end a stage sooner). If this is below the
    # schedule's end_step, the final target sparsity will not be reached.
    steps_per_epoch = int(np.ceil(len(X_train) / BATCH_SZ))
    print(f"Planned optimizer steps (upper bound): {steps_per_epoch * sum(epochs)}")

    # perform training
    for i, (e, lr) in enumerate(zip(epochs, learning_rate)):
        print(f"\nStarting training iteration {i + 1} with {e} epochs and learning rate {lr}")
        pruned_model = train_prune_model(pruned_model, X_train, y_train, X_val, y_val, e, lr, patience, BATCH_SZ)

    print("\n\nFinal Evaluation on Test Data:")
    # Remove the pruning wrappers (keeping the pruned weights) before evaluation and export
    pruned1_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
    # Recompile the stripped model before evaluate(), with the same loss used in training
    pruned1_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    initial_model_accuracy = evaluate_model(pruned1_model, X_test, y_test)

    # Verify the sparsity actually achieved: the fraction of exactly-zero entries in each
    # layer's first weight tensor (the kernel for Conv2D/Dense); weightless layers are skipped.
    print("\nAchieved kernel sparsity per layer:")
    for layer in pruned1_model.layers:
        layer_weights = layer.get_weights()
        if layer_weights:
            print(f"  {layer.name}: {np.mean(layer_weights[0] == 0):.2%}")

    ########## Let's convert the model to TFLITE FORMAT ###########
    tflite_model_path = convert_to_tflite(pruned1_model)

    # Without applying any optimizations to our model, the accuracy should remain the same ...
    print("\n\nPerforming inference with TFLite model...")
    tflite_no_opt_accuracy = tflite_inference(tflite_model_path, X_test, y_test)

    ##### QUANTIZATION #####
    # Quantize the stripped model. Keep a distinct filename: the converter's default is also
    # "pruned_model.tflite" and would overwrite the float model written above.

    # integer_tflite_path = convert_to_quantized_tflite(pruned1_model, filename = "quantized_model.tflite")

    # print(f"\nInference with Integer quantization...")
    # tflite_int_quant_acc = tflite_inference(integer_tflite_path, X_test, y_test)

Training on: /GPU:0
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d  (None, 28, 28, 6)         308       
  (PruneLowMagnitude)                                            
                                                                 
 max_pooling2d (MaxPooling2  (None, 14, 14, 6)         0         
 D)                                                              
                                                                 
 prune_low_magnitude_conv2d  (None, 10, 10, 16)        4818      
 _1 (PruneLowMagnitude)                                          
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 5, 5, 16)          0         
 g2D)                                                            
                                                                 
 flatten (Flatten)           (None, 




Starting training iteration 2 with 1 epochs and learning rate 0.0005
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 prune_low_magnitude_conv2d  (None, 28, 28, 6)         308       
  (PruneLowMagnitude)                                            
                                                                 
 max_pooling2d (MaxPooling2  (None, 14, 14, 6)         0         
 D)                                                              
                                                                 
 prune_low_magnitude_conv2d  (None, 10, 10, 16)        4818      
 _1 (PruneLowMagnitude)                                          
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 5, 5, 16)          0         
 g2D)                                                            
                                                    

In [None]:
def load_mnist(i = 0):
    """
    Return MNIST test image `i` as a single-sample batch, together with its label.

    The pixels are left exactly as tf.keras.datasets.mnist.load_data returns them
    (not normalized). This lets callers normalize for a float model, or quantize
    for an integer model, themselves.

    Args:
        i: index into the MNIST test split.

    Returns:
        (image, label): image has shape (1, 28, 28, 1); label is test_labels[i].
    """
    # Only the test split is needed; the training split is discarded
    _, (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

    # (28, 28) -> (1, 28, 28, 1): leading batch axis, trailing channel axis
    image = test_images[i][np.newaxis, :, :, np.newaxis]
    return image, test_labels[i]

In [None]:

#RUN MODEL
import time
import numpy as np
import tensorflow as tf

# Load the TFLite model
tflite_model_path = 'pruned_model.tflite'  # Replace with your model's path
interpreter = tf.lite.Interpreter(model_path=tflite_model_path)

# Allocate tensors (this will initialize the interpreter and load the model)
interpreter.allocate_tensors()

# Get input and output details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Shape and dtype the model expects (float32 for a float model, int8/uint8 if quantized)
input_shape = input_details[0]['shape']
input_dtype = input_details[0]['dtype']
output_dtype = output_details[0]['dtype']
inference_time = 0

# One MNIST test sample, as un-normalized pixels (see load_mnist)
test_images, test_labels = load_mnist(1)
label = test_labels  # Corresponding label

# Normalize to [0, 1], matching the preprocessing used for training
image = np.reshape(test_images.astype(np.float32) / 255.0, input_shape)

if np.issubdtype(input_dtype, np.integer):
    # Integer-quantized input: q = round(x / scale + zero_point), saturated to the dtype's range
    input_scale, input_zero_point = input_details[0]['quantization']
    if input_scale:
        image = image / input_scale + input_zero_point
    dtype_info = np.iinfo(input_dtype)
    image = np.clip(np.round(image), dtype_info.min, dtype_info.max)

# set_tensor requires exactly the dtype the model declares
image = image.astype(input_dtype)

# Set the input tensor
interpreter.set_tensor(input_details[0]['index'], image)

# Measure inference time; perf_counter is a monotonic high-resolution clock for short intervals
start_time = time.perf_counter()
interpreter.invoke()
end_time = time.perf_counter()
inference_time += (end_time - start_time) * 1000  # Convert to milliseconds

# Get the output tensor, dequantizing it if the model emits integers
output_data = interpreter.get_tensor(output_details[0]['index'])
if np.issubdtype(output_dtype, np.integer):
    output_scale, output_zero_point = output_details[0]['quantization']
    output_data = output_data.astype(np.float32)
    if output_scale:
        output_data = (output_data - output_zero_point) * output_scale

# The LeNet models built in this notebook end in a softmax Dense layer, so the output
# already holds class probabilities; applying softmax again would only flatten them.
probabilities = output_data[0]

predicted_class = np.argmax(probabilities)
print(f"Predicted Class: {predicted_class}")
print(f"Ground Truth Label: {label}")

# Print inference time
print(f"Inference Time: {inference_time:.2f} ms")

Predicted Class: 2
Ground Truth Label: 2
Inference Time: 0.18 ms
