# Lab 1: pruning AI models

## Intro

In this lab you will learn to **optimize** and **prune AI models** using the **LiteRT** library (formerly TensorFlow Lite \[for Microcontrollers]) together with the TensorFlow Model Optimization Toolkit. <br />


To run the scripts in this lab, you need access to a GPU. You can either **use your own GPU** (on Linux, or on Windows through WSL, with a GPU-enabled TensorFlow 2.18.0 installed) **or use Google Colab**. <br /><br />

**To run this notebook in Colab, you need the zip from Lab 0: place this file in the extracted folder.** <br /><br />Instructions from Lab 0: download the lab folder from Ufora, **unzip it and put it on your Google Drive** (the folder is only a few MB in size). You can **drag and drop** the unzipped folder into your Google Drive.<br />


Next, **double-click the provided .ipynb file** of each lab to open it in Google Colab. <br />From there, fill in the necessary variables (such as the path to your Google Drive) and you will be able to **run and write the necessary code. Be sure to select a GPU under Runtime > Change runtime type.**

In [None]:
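# tensorflow-model-optimization (tfmot) provides the pruning API used in Part 2
%pip install --user --upgrade tensorflow-model-optimization
# tf_keras is the Keras 2 package that tfmot relies on (tfmot does not support Keras 3)
%pip install tf_keras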

# Click Runtime > Restart session
# This ensures the above installed libraries are correctly imported

Collecting tensorflow-model-optimization
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl.metadata (904 bytes)
Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
Installing collected packages: tensorflow-model-optimization
Successfully installed tensorflow-model-optimization-0.8.0


In [None]:
import pdb

In [None]:
# Run this code to connect your Google Drive to Colab

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Change to your project directory
path_to_lab = "drive/MyDrive/embedded-ML/labs"

## Functions
Below you can find **functions** which can be used to complete the lab. <br />
_Note: when running the code below for the first time on Google Colab, you will get a warning that you need to restart your runtime session. This is expected, because the kernel needs to load the expected TensorFlow version._

In [None]:
import os
# tfmot needs Keras 2 (tf_keras): make tf.keras resolve to tf_keras.
# This must be set before TensorFlow is imported (restart the session if it already was).
os.environ["TF_USE_LEGACY_KERAS"] = "1"

import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd

def mnist_model(train=False):
    """Train the MNIST CNN from scratch (train=True) or load the pre-trained model."""
    model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(28, 28)),
        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(filters=64, kernel_size=(6, 6), activation=tf.nn.relu, name="conv1"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation=tf.nn.relu, name="conv2"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(16, activation=tf.nn.relu, name="dense1"),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax, name="dense2")
    ])

    # dense2 already applies softmax, so the loss receives probabilities, not logits
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['accuracy'])

    if train:
        model.fit(x=train_images, y=train_labels, batch_size=64, epochs=50,
                  validation_data=(test_images, test_labels))
    else:
        # model = tf.keras.models.load_model("Models/mnist.keras")
        # Pre-trained SavedModel from the lab folder on Google Drive
        model = tf.keras.models.load_model(f"{path_to_lab}/Models/mnist")
    return model
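To get a feel for the size of the network in `mnist_model`, its weight count can be worked out by hand. The sketch below assumes the Keras defaults used in that definition (stride 1, `'valid'` padding, pooling stride equal to the pool size) and 4 bytes per float32 weight; it counts raw weights only, not the extra bytes a `.tflite` file spends on graph structure and metadata. `conv_out` and `mnist_param_counts` are hypothetical helper names.

```python
def conv_out(size, kernel):
    """Spatial output size of a stride-1 convolution with 'valid' padding."""
    return size - kernel + 1

def mnist_param_counts():
    """Per-layer parameter counts (kernel + bias) of mnist_model."""
    h = conv_out(28, 6)      # conv1: 28 -> 23
    h = h // 2               # MaxPooling2D((2, 2)): 23 -> 11
    h = conv_out(h, 3)       # conv2: 11 -> 9
    flat = h * h * 32        # Flatten: 9 * 9 * 32 = 2592 inputs for dense1
    return {
        "conv1": 6 * 6 * 1 * 64 + 64,
        "conv2": 3 * 3 * 64 * 32 + 32,
        "dense1": flat * 16 + 16,
        "dense2": 16 * 10 + 10,
    }

counts = mnist_param_counts()
total = sum(counts.values())
print(counts)
print(f"total parameters: {total}, float32 weights: {total * 4 / 1024:.1f} KiB")
```

Reshape, pooling, Dropout and Flatten carry no weights, so about two thirds of the parameters sit in `dense1`. If the assumptions hold, `model.summary()` or `model.count_params()` should report the same total.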

In [None]:
import shutil

# Optional: uncomment to delete previously saved pruned models before re-running Part 2
# shutil.rmtree(f'{path_to_lab}/Models/mnist_pruned_50pct')
# shutil.rmtree(f'{path_to_lab}/Models/mnist_pruned_95pct')

## Part 1: convert models using LiteRT

1) Load the MNIST dataset and the pre-trained model. For this exercise we use a pre-trained digit-recognition model trained on MNIST.
2) Evaluate the model. To obtain a baseline, evaluate the model without any LiteRT optimizations applied.
3) Convert the model to the LiteRT format and evaluate whether this affects its accuracy.

**Q1: Did you need to change anything about your (test) dataset?**

**Q2: Do you see any difference in accuracy compared to the baseline?**

In [None]:
# Load dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Load pre-trained model (stored as a SavedModel in {path_to_lab}/Models/mnist)
model = mnist_model(train=False)
# Only needed after retraining (train=True): save to the folder the converter reads from
# model.save(f"{path_to_lab}/Models/mnist")

# Verify performance by inserting your code below
# ---- see lab 0

# Perform LiteRT model conversion
converter = tf.lite.TFLiteConverter.from_saved_model(f"{path_to_lab}/Models/mnist") # path to the SavedModel directory
tflite_model = converter.convert()

with open(f'{path_to_lab}/Models/mnist.tflite', 'wb') as f:
    f.write(tflite_model)

In [None]:
# Verify performance of lite model

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path=f"{path_to_lab}/Models/mnist.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on the first test image (cast to float32, with a batch dimension).
input_shape = input_details[0]['shape']
test_image = test_images[0].astype(np.float32)
test_image = np.expand_dims(test_image, axis=0)

print(input_shape)
print(test_image.shape)

interpreter.set_tensor(input_details[0]['index'], test_image)
interpreter.invoke()

output_data = interpreter.get_tensor(output_details[0]['index'])
predicted_label = np.argmax(output_data)
print(test_labels[0] == predicted_label)

[ 1 28 28]
(1, 28, 28)
True


In [None]:
correct = 0
for i in range(len(test_images)):

    # change the dtype of the array elements from uint8 to float32
    test_image = test_images[i].astype(np.float32)
    # add a batch dimension: shape (28, 28) -> (1, 28, 28)
    test_image = np.expand_dims(test_image, axis=0)

    # input test_image
    interpreter.set_tensor(input_details[0]['index'], test_image)

    # run model
    interpreter.invoke()

    # get result
    output_data = interpreter.get_tensor(output_details[0]['index'])

    if np.argmax(output_data) == test_labels[i]:
        correct += 1

accuracy = correct / len(test_images)
print(f"TFLite Model Accuracy: {accuracy:.4f}")

TFLite Model Accuracy: 0.9912


## Part 2: prune optimized model

4) Prune all layers of the model, once at 50\% and once at 95\% sparsity. You can find information on how to prune Keras models [here](https://www.tensorflow.org/model_optimization/guide/pruning/pruning_with_keras).
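Low-magnitude pruning, as the name `prune_low_magnitude` suggests, zeroes the weights with the smallest absolute values until the target sparsity is reached. The pure-Python sketch below illustrates that idea on a single kernel stored as a flat list. It is a simplified illustration (one threshold per kernel, no block pooling, no schedule, no fine-tuning), not TFMOT's actual implementation, and `magnitude_prune` and `sparsity_of` are hypothetical names.

```python
import random

def magnitude_prune(weights, sparsity):
    """Return a copy of `weights` in which the smallest-|w| fraction is set to zero.

    Illustrative only: keeps the round(n * (1 - sparsity)) largest-magnitude weights.
    """
    n = len(weights)
    n_keep = round(n * (1 - sparsity))
    # Indices ordered from largest to smallest magnitude
    order = sorted(range(n), key=lambda i: abs(weights[i]), reverse=True)
    keep = set(order[:n_keep])
    return [w if i in keep else 0.0 for i, w in enumerate(weights)]

def sparsity_of(weights, atol=1e-10):
    """Fraction of weights whose magnitude is (numerically) zero."""
    return sum(abs(w) < atol for w in weights) / len(weights)

random.seed(0)
kernel = [random.gauss(0.0, 0.1) for _ in range(160)]  # synthetic 16x10 dense kernel
for target in (0.5, 0.95):
    pruned = magnitude_prune(kernel, target)
    print(f"target {target:.2f} -> measured sparsity {sparsity_of(pruned):.2f}")
```

During fine-tuning the surviving weights can adapt to compensate for the removed ones, which is why the lab trains the pruned model further instead of only masking it once.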

In [None]:
def prune_model(model, sparsity, epochs=15):
    model_copy = tf.keras.models.clone_model(model)
    model_copy.set_weights(model.get_weights())

    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
            target_sparsity=sparsity,
            begin_step=0,
            end_step=int(train_images.shape[0] / 64 * epochs)  # batches of 64
        )
    }

    # Apply pruning to the copy
    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
    model_for_pruning = prune_low_magnitude(model_copy, **pruning_params)

    model_for_pruning.compile(
        optimizer='adam',
        # dense2 ends in softmax, so the model outputs probabilities, not logits
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=['accuracy']
    )

    return model_for_pruning

def train_pruned_model(model, sparsity_level, epochs=15):

    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir=f'logs/pruning_{int(sparsity_level*100)}pct'),
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
    ]

    model.fit(
        train_images, train_labels,
        batch_size=64,
        epochs=epochs,
        validation_data=(test_images, test_labels),
        callbacks=callbacks
    )

    # remove the pruning wrappers; the pruned (zeroed) weights stay in the kernels
    model = tfmot.sparsity.keras.strip_pruning(model)

    # save pruned model
    model.save(f'{path_to_lab}/Models/mnist_pruned_{int(sparsity_level*100)}pct')

    # convert to TFLite
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = converter.convert()

    with open(f'{path_to_lab}/Models/mnist_pruned_{int(sparsity_level*100)}pct.tflite', 'wb') as f:
        f.write(tflite_model)

    return model
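`prune_model` derives `end_step` of the `ConstantSparsity` schedule from the dataset size and batch size. The sketch below redoes that arithmetic and lists the steps at which the pruning mask would be recomputed. It assumes `ConstantSparsity`'s default `frequency=100` and assumes its rule is "update when `begin_step <= step <= end_step` and `(step - begin_step) % frequency == 0`"; both are worth checking against the TFMOT API reference. `schedule_summary` is a hypothetical helper.

```python
import math

def schedule_summary(n_train, batch_size, epochs, begin_step=0, frequency=100):
    """Compare prune_model's end_step with the actual number of training steps."""
    end_step = int(n_train / batch_size * epochs)       # as computed in prune_model
    steps_per_epoch = math.ceil(n_train / batch_size)   # the last, partial batch also counts
    total_steps = steps_per_epoch * epochs
    # Steps at which the (assumed) ConstantSparsity rule would recompute the mask
    update_steps = [s for s in range(total_steps)
                    if begin_step <= s <= end_step and (s - begin_step) % frequency == 0]
    return end_step, steps_per_epoch, total_steps, update_steps

end_step, per_epoch, total, updates = schedule_summary(60_000, 64, 15)
print(end_step, per_epoch, total, len(updates), updates[-1])
```

With 60000 training images and batches of 64, Keras runs 938 batches per epoch, so `end_step` (14062) falls a few steps short of the 14070 steps of 15 epochs. Note also that the 95% run calls `prune_model(model, 0.95)` with its default `epochs=15` while training for 20 epochs, so under the assumed rule its mask is no longer recomputed during the last epochs.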

In [None]:
# model_50 is evaluated in the next cell, so both runs are needed
model_50 = prune_model(model, 0.5)
train_pruned_model(model_50, 0.5)

model_95 = prune_model(model, 0.95)
train_pruned_model(model_95, 0.95, epochs=20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20




<tf_keras.src.engine.sequential.Sequential at 0x7dd0e57204d0>

In [None]:
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=0)
print(f'Baseline test accuracy: {test_acc:.4f}')

test_loss, test_acc = model_50.evaluate(test_images, test_labels, verbose=0)
print(f'50% Pruned test accuracy: {test_acc:.4f}')

test_loss, test_acc = model_95.evaluate(test_images, test_labels, verbose=0)
print(f'95% Pruned test accuracy: {test_acc:.4f}')

Baseline test accuracy: 0.9912
50% Pruned test accuracy: 0.9921
95% Pruned test accuracy: 0.5549


5) Verify that all pruned layers have the expected sparsity.

**Q3: What difference do you see in accuracy?**

In [None]:
def calculate_sparsity(weights):
    """Calculate the sparsity (percentage of zeros) in the weights tensor."""
    total_params = tf.size(weights).numpy()
    zero_params = tf.math.count_nonzero(tf.abs(weights) < 1e-10).numpy()
    return (zero_params / total_params)

def verify_layer_sparsity(model):
    results = []
    for i, layer in enumerate(model.layers):
        layer_name = layer.name

        # skip layers without weights or non-prunable layers
        if not layer.weights or isinstance(layer, (tf.keras.layers.Dropout,
                                                tf.keras.layers.MaxPooling2D,
                                                tf.keras.layers.Reshape)):
            continue

        # process each weight in the layer
        for j, weight in enumerate(layer.weights):
            weight_name = weight.name

            if 'kernel' in weight_name or 'weight' in weight_name:
                sparsity = calculate_sparsity(weight)

                results.append({
                    'layer_index': i,
                    'layer_name': layer_name,
                    'weight_name': weight_name,
                    'sparsity': sparsity,
                })

    return pd.DataFrame(results)

print(verify_layer_sparsity(model_50))
print(verify_layer_sparsity(model_95))

   layer_index                  layer_name      weight_name  sparsity
0            1   prune_low_magnitude_conv1   conv1/kernel:0       0.5
1            4   prune_low_magnitude_conv2   conv2/kernel:0       0.5
2            7  prune_low_magnitude_dense1  dense1/kernel:0       0.5
3            8  prune_low_magnitude_dense2  dense2/kernel:0       0.5
   layer_index                  layer_name      weight_name  sparsity
0            1   prune_low_magnitude_conv1   conv1/kernel:0  0.950087
1            4   prune_low_magnitude_conv2   conv2/kernel:0  0.949978
2            7  prune_low_magnitude_dense1  dense1/kernel:0  0.949990
3            8  prune_low_magnitude_dense2  dense2/kernel:0  0.950000
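The 95% values are not exactly 0.95 because a kernel can only contain a whole number of zeros. Assuming the pruning step keeps `round(n * (1 - sparsity))` weights per kernel, the kernel sizes of `mnist_model` (6×6×1×64, 3×3×64×32, 2592×16 and 16×10) reproduce the values printed above. `KERNEL_SIZES` and `expected_sparsity` are hypothetical names used only for this check.

```python
KERNEL_SIZES = {
    "conv1": 6 * 6 * 1 * 64,    # 2304 weights
    "conv2": 3 * 3 * 64 * 32,   # 18432 weights
    "dense1": 2592 * 16,        # 41472 weights (9 * 9 * 32 flattened inputs)
    "dense2": 16 * 10,          # 160 weights
}

def expected_sparsity(n_weights, target):
    """Achievable sparsity when round(n * (1 - target)) weights are kept."""
    n_keep = round(n_weights * (1 - target))
    return (n_weights - n_keep) / n_weights

for name, n in KERNEL_SIZES.items():
    print(f"{name:7s} 50%: {expected_sparsity(n, 0.5):.6f}  95%: {expected_sparsity(n, 0.95):.6f}")
```

A value slightly above or below the target is therefore expected and does not mean a layer was skipped.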


6) Prune each of the following layers individually: _["conv1","conv2","dense1","dense2"]_, at each of the following sparsity levels: _[0.5, 0.85, 0.95, 0.99]_.
   
   _Hint: You can find more information about layer-based pruning [here](https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide)_


In [None]:
def prune_model_layer(model, layer_name, sparsity, epochs=15):
    """Prune only the layer named `layer_name`; all other layers stay dense."""
    # Independent copy, so fine-tuning never modifies the baseline model's weights
    model_copy = tf.keras.models.clone_model(model)
    model_copy.set_weights(model.get_weights())

    # Define the pruning schedule
    pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(
        target_sparsity=sparsity,
        begin_step=0,
        end_step=int(train_images.shape[0] / 64 * epochs)  # batches of 64
    )

    # prune_low_magnitude(model, ...) wraps *every* layer of the model and has no
    # per-layer config argument. To prune a single layer, wrap only that layer
    # while cloning the model (see the comprehensive pruning guide).
    def apply_pruning(layer):
        if layer.name == layer_name:
            return tfmot.sparsity.keras.prune_low_magnitude(
                layer, pruning_schedule=pruning_schedule)
        return layer  # reuse model_copy's layer (and its trained weights) as-is

    model_for_pruning = tf.keras.models.clone_model(
        model_copy, clone_function=apply_pruning)

    # Compile the pruned model (dense2 outputs softmax probabilities, not logits)
    model_for_pruning.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=['accuracy']
    )

    return model_for_pruning

def train_pruned_layer_model(model, layer_name, sparsity_level, epochs=15):
    """Train a model with a specific layer pruned."""
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir=f'logs/pruning_{layer_name}_{int(sparsity_level*100)}pct'),
        tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
    ]

    model.fit(
        train_images, train_labels,
        batch_size=64,
        epochs=epochs,
        validation_data=(test_images, test_labels),
        callbacks=callbacks
    )

    # Strip final sparsity mask
    stripped_model = tfmot.sparsity.keras.strip_pruning(model)

    # Recompile the stripped model (softmax output, so from_logits=False)
    stripped_model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=['accuracy']
    )

    # Save pruned model
    model_path = f'{path_to_lab}/Models/mnist_pruned_{layer_name}_{int(sparsity_level*100)}pct'
    stripped_model.save(model_path)

    # Convert to TFLite
    converter = tf.lite.TFLiteConverter.from_keras_model(stripped_model)
    tflite_model = converter.convert()

    with open(f'{model_path}.tflite', 'wb') as f:
        f.write(tflite_model)

    return stripped_model

In [None]:
layers = ["conv1","conv2","dense1","dense2"]
sparsity_levels = [0.5, 0.85, 0.95, 0.99]


results = {}

# First evaluate the baseline model
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=0)
print(f'Baseline test accuracy: {test_acc:.4f}')
results['baseline'] = test_acc

# Iterate through each layer and sparsity level
for layer_name in layers:
    for sparsity in sparsity_levels:
        print(f"\nPruning layer '{layer_name}' with {sparsity*100:.0f}% sparsity")

        model_pruned = prune_model_layer(model, layer_name, sparsity)

        # train
        pruned_model = train_pruned_layer_model(model_pruned, layer_name, sparsity)

        # evaluate
        test_loss, test_acc = pruned_model.evaluate(test_images, test_labels, verbose=0)
        print(f"Layer '{layer_name}' @ {sparsity*100:.0f}% pruned test accuracy: {test_acc:.4f}")
        results[f"{layer_name}_{int(sparsity*100)}pct"] = test_acc



Baseline test accuracy: 0.9912

Pruning layer 'conv1' with 50% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15

Layer 'conv1' @ 50% pruned test accuracy: 0.0980

Pruning layer 'conv1' with 85% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15

Layer 'conv1' @ 85% pruned test accuracy: 0.0981

Pruning layer 'conv1' with 95% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15

Layer 'conv1' @ 95% pruned test accuracy: 0.0979

Pruning layer 'conv1' with 99% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15

Layer 'conv1' @ 99% pruned test accuracy: 0.0981

Pruning layer 'conv2' with 50% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15

Layer 'conv2' @ 50% pruned test accuracy: 0.0984

Pruning layer 'conv2' with 85% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15

Layer 'conv2' @ 85% pruned test accuracy: 0.0984

Pruning layer 'conv2' with 95% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15

Layer 'conv2' @ 95% pruned test accuracy: 0.0982

Pruning layer 'conv2' with 99% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15

Layer 'conv2' @ 99% pruned test accuracy: 0.0987

Pruning layer 'dense1' with 50% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15

Layer 'dense1' @ 50% pruned test accuracy: 0.0984

Pruning layer 'dense1' with 85% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15

Layer 'dense1' @ 85% pruned test accuracy: 0.0985

Pruning layer 'dense1' with 95% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15

Layer 'dense1' @ 95% pruned test accuracy: 0.0982

Pruning layer 'dense1' with 99% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15

Layer 'dense1' @ 99% pruned test accuracy: 0.0984

Pruning layer 'dense2' with 50% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15

Layer 'dense2' @ 50% pruned test accuracy: 0.0979

Pruning layer 'dense2' with 85% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15

Layer 'dense2' @ 85% pruned test accuracy: 0.0980

Pruning layer 'dense2' with 95% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15

Layer 'dense2' @ 95% pruned test accuracy: 0.0981

Pruning layer 'dense2' with 99% sparsity
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


   
**Q4: Using the results above, generate the following figure (see also the Ufora question):**

![PruningFigure.png](PruningFigure.png)
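The `results` dictionary filled by the layer-wise loop uses keys such as `"conv1_50pct"` plus a `"baseline"` entry. Assuming the figure plots test accuracy against sparsity with one line per layer (compare with PruningFigure.png and adapt as needed), a sketch of regrouping and plotting these results could look like this. `results_to_series` and `plot_series` are hypothetical helpers, and matplotlib is imported only inside the plotting function.

```python
from collections import defaultdict

def results_to_series(results):
    """Group {'conv1_50pct': acc, ...} into {'conv1': [(50, acc), ...]}, sorted by sparsity."""
    series = defaultdict(list)
    for key, acc in results.items():
        if key == "baseline":
            continue
        layer, pct = key.rsplit("_", 1)        # e.g. 'dense1', '95pct'
        series[layer].append((int(pct[:-3]), acc))  # strip the 'pct' suffix
    return {layer: sorted(points) for layer, points in series.items()}

def plot_series(series, baseline=None):
    """Accuracy vs. sparsity, one line per layer (requires matplotlib)."""
    import matplotlib.pyplot as plt
    for layer, points in series.items():
        xs, ys = zip(*points)
        plt.plot(xs, ys, marker="o", label=layer)
    if baseline is not None:
        plt.axhline(baseline, linestyle="--", color="gray", label="baseline")
    plt.xlabel("Sparsity (%)")
    plt.ylabel("Test accuracy")
    plt.legend()
    plt.show()
```

In the notebook this would be called as `plot_series(results_to_series(results), baseline=results["baseline"])`.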

**Q5: How much would you prune the model for an embedded ML application?**  

**Q6: Does the LiteRT model decrease in size? By how much? Tip: compare the zipped file sizes.**
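Zeroed weights in a dense tensor still take up 4 bytes each, but long runs of zeros compress very well. The stdlib-only sketch below shows the effect on synthetic float32 data shaped like the `dense1` kernel, and includes a hypothetical `zipped_size` helper, similar in spirit to the size check in the TFMOT pruning guide, for measuring the DEFLATE-compressed size of a file.

```python
import array
import os
import random
import tempfile
import zipfile
import zlib

def zipped_size(path):
    """Size in bytes of `path` after DEFLATE compression into a temporary .zip."""
    fd, zip_path = tempfile.mkstemp(suffix=".zip")
    os.close(fd)
    try:
        with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            zf.write(path, arcname=os.path.basename(path))
        return os.path.getsize(zip_path)
    finally:
        os.remove(zip_path)

# Synthetic kernel: 41472 float32 values (the size of dense1 in mnist_model)
random.seed(0)
dense = array.array("f", (random.gauss(0.0, 0.1) for _ in range(41472)))
# Same values with the 95% smallest magnitudes set to zero
threshold = sorted(abs(w) for w in dense)[int(0.95 * len(dense))]
sparse = array.array("f", (w if abs(w) >= threshold else 0.0 for w in dense))

print("raw bytes (dense, sparse):", len(dense.tobytes()), len(sparse.tobytes()))
print("compressed dense: ", len(zlib.compress(dense.tobytes(), 9)))
print("compressed sparse:", len(zlib.compress(sparse.tobytes(), 9)))
```

In the notebook, the same comparison on the real files would be `zipped_size(f"{path_to_lab}/Models/mnist.tflite")` versus `zipped_size(f"{path_to_lab}/Models/mnist_pruned_95pct.tflite")`, next to their unzipped sizes from `os.path.getsize`.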