![alt text](../img/general/header_workflow.png)


# Model Compression 



### Libraries

In [None]:
import os
import numpy as np
from numpy import array
import matplotlib.pyplot as plt
import seaborn as sn
import pandas as pd

## Tensorflow + Keras libraries
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Sequential
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import SGD, Adam
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.sparsity.keras import prune, pruning_callbacks, pruning_schedule
from tensorflow_model_optimization.sparsity.keras import strip_pruning

## Quantization
from qkeras import *

## Datasets
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.datasets import cifar100
from tensorflow.keras.datasets import mnist
from tensorflow.keras.datasets import fashion_mnist

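# Local module (not part of Keras) that provides the Distiller class used in the Knowledge Distillation section
from distillationClassKeras import *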
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

### GPU 

In [None]:
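# GPU configuration
# TF_XLA_FLAGS must be set before TensorFlow initializes its devices;
# setting it before importing TensorFlow (in the Libraries cell) is the safest option
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'

gpus = tf.config.list_physical_devices('GPU')
print("GPUs: ", len(gpus))

# Allocate GPU memory on demand instead of reserving all of it at start-up
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before the GPUs have been initialized
        print(e)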

### Load pre-trained model

In [None]:
# Load a previously trained MNIST fully connected (FC) model
model = load_model('../models/mnistModel_FC.h5')


In [None]:
# Display a summary of the model architecture and parameters
model.summary()


### Dataset Loading

For this laboratory, we will work with the MNIST dataset of handwritten digits, which is commonly used for image classification tasks.

In [None]:
# Load the MNIST dataset, split into training and testing sets (data and labels)
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixel values to the range [0, 1]
x_train_norm = x_train / 255.0
x_test_norm = x_test / 255.0


In [None]:
# Number of classes in the dataset
n_classes = 10

# One-hot encode the class labels
y_train = to_categorical(y_train, num_classes=n_classes)
y_test = to_categorical(y_test, num_classes=n_classes)


## Compression Techniques

### Pruning

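Pruning is a technique used to reduce the size and complexity of a deep learning model by removing weights or neurons that contribute little to its output. Its main goal is to improve efficiency by reducing memory usage and computation while preserving accuracy as much as possible. In this laboratory we use magnitude-based pruning (prune_low_magnitude from TensorFlow Model Optimization), which progressively sets the weights with the smallest absolute values to zero during training. The resulting zeros only reduce storage or latency when they are exploited, for example by compressing the weights or by hardware that skips multiplications by zero.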
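
To make the magnitude criterion concrete, the NumPy sketch below applies one-shot magnitude pruning to a single weight matrix. It zeroes the fraction `sparsity` of entries that have the smallest absolute value. `magnitude_prune` is a hypothetical helper written for illustration only. prune_low_magnitude applies the same idea gradually during training and keeps an internal mask.

```python
import numpy as np

def magnitude_prune(weights, sparsity):
    """Return a copy of `weights` with the smallest-magnitude entries set to zero.

    `sparsity` is the target fraction of zeros (0.0 keeps everything, 0.9 removes 90%).
    """
    flat = np.abs(weights).ravel()
    k = int(round(sparsity * flat.size))  # number of weights to remove
    if k == 0:
        return weights.copy()
    # Indices of the k smallest magnitudes
    prune_idx = np.argpartition(flat, k - 1)[:k]
    mask = np.ones(flat.size, dtype=bool)
    mask[prune_idx] = False
    return np.where(mask, weights.ravel(), 0.0).reshape(weights.shape)

w = np.array([[0.8, -0.05, 0.3],
              [-0.6, 0.01, -0.2]])
w_pruned = magnitude_prune(w, sparsity=0.5)
print(w_pruned)
```

With `sparsity=0.5`, the three smallest magnitudes (0.01, 0.05 and 0.2) are removed. The large weights survive unchanged.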


In [None]:
import math

epochs = 16
batch = 64
val_split = 0.2

# Number of training samples after validation split
n_train_samples = int(x_train.shape[0] * (1 - val_split))

# Steps per epoch
steps_per_epoch = math.ceil(n_train_samples / batch)

# Total number of pruning steps
# If end_step is too small, pruning happens too aggressively and may degrade accuracy. If it is too large, the model may not reach the target sparsity.
end_step = epochs * steps_per_epoch

final_sparsity = 0.3

# begin_step = 0
# → Pruning starts from the beginning of training

# end_step = epochs × steps_per_epoch
# → The model reaches the target sparsity at the end of training

# final_sparsity = 0.3
# → 30% of the prunable weights (the Dense kernels; biases are not pruned) will be zero


pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0,
        final_sparsity=final_sparsity,
        begin_step=0,
        end_step=end_step
    )
}
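
The PolynomialDecay schedule raises sparsity quickly at first and then levels off. As documented for TensorFlow Model Optimization, the target sparsity at step t is s(t) = s_f + (s_i - s_f) * (1 - (t - t_b) / (t_e - t_b))^p. The default exponent is `power=3`, and the masks are only updated every `frequency=100` steps by default.

The plain-Python sketch below reproduces that formula for the values used in this cell: 750 steps per epoch, so end_step = 16 * 750 = 12000. `polynomial_sparsity` is a hypothetical helper, not part of the library.

```python
def polynomial_sparsity(step, initial_sparsity=0.0, final_sparsity=0.3,
                        begin_step=0, end_step=12000, power=3):
    """Target sparsity at a given training step (polynomial decay schedule)."""
    # Clamp the step into [begin_step, end_step]
    step = min(max(step, begin_step), end_step)
    progress = (step - begin_step) / (end_step - begin_step)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** power

# Target sparsity at the quarter points of training (750 steps per epoch, 16 epochs)
for epoch in (0, 4, 8, 12, 16):
    step = epoch * 750
    print(f"epoch {epoch:2d} (step {step:5d}): sparsity = {polynomial_sparsity(step):.3f}")
```

Half of the final sparsity is already reached after roughly the first fifth of training. This is why a too-short end_step can prune aggressively before the network has had time to adapt.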


In [None]:
# Apply magnitude-based pruning to the original model
modelP = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

# Training configuration
lr = 0.001
loss = 'categorical_crossentropy'
op = Adam(learning_rate=lr)
metrics = ['accuracy']

# Compile the pruned model
modelP.compile(
    optimizer=op,
    loss=loss,
    metrics=metrics
)


In [None]:
# Train the pruned model
historyP = modelP.fit(
    x_train_norm,
    y_train,
    validation_split=val_split,
    epochs=epochs,
    batch_size=batch,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],  # Required to update pruning step during training
    verbose=1
)


In [None]:
# Plot accuracy over epochs for the pruned model
plt.figure(figsize=(10, 3))

plt.plot(historyP.history['accuracy'], label='Train Accuracy')
plt.plot(historyP.history['val_accuracy'], label='Validation Accuracy')

plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()


#### Metrics

In [None]:
# Perform inference on the test set using the pruned model
y_pred_probs = modelP.predict(x_test_norm)

# Convert predicted probabilities to class indices
y_pred = np.argmax(y_pred_probs, axis=1)

# Since y_test is one-hot encoded, convert it back to class indices
y_true = np.argmax(y_test, axis=1)

# Compute the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Display the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap="Purples")

plt.title('Confusion Matrix for MNIST – Pruned Model')
plt.show()


#### Saving the Model

Once the model has been trained and evaluated, it can be saved to disk for later use. Saving the model allows it to be reloaded without retraining, preserving both the learned weights and the model architecture.

In [None]:
# Remove pruning wrappers to obtain the final optimized model
modelP = strip_pruning(modelP)

# Save the trained model to disk (create the output folder if it does not exist)
os.makedirs('models', exist_ok=True)
modelP.save('models/mnistModel_FC_P.h5')
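
strip_pruning removes the pruning wrappers, but the pruned weights are still stored as ordinary dense arrays. The .h5 file therefore does not shrink by itself. The benefit appears when the weights are compressed, or when the target hardware skips the zero multiplications.

The NumPy sketch below uses two hypothetical helpers, `kernel_sparsity` and `gzipped_size`. They measure the fraction of zeros and the gzip-compressed size of a list of weight arrays. The demonstration uses a synthetic 784x100 matrix pruned to 70% sparsity. For the stripped model from this cell, the same helpers could be applied to the 2-D Dense kernels, e.g. `[w for w in modelP.get_weights() if w.ndim == 2]` (an assumption about how you select the kernels).

```python
import gzip
import numpy as np

def kernel_sparsity(weights):
    """Fraction of entries that are exactly zero across a list of weight arrays."""
    total = sum(w.size for w in weights)
    zeros = sum(int(np.count_nonzero(w == 0)) for w in weights)
    return zeros / total

def gzipped_size(weights):
    """Size in bytes of the float32 weights after gzip compression."""
    raw = b"".join(np.asarray(w, dtype=np.float32).tobytes() for w in weights)
    return len(gzip.compress(raw))

# Synthetic example: a dense 784x100 kernel and a copy with 70% of its weights pruned
rng = np.random.default_rng(0)
dense = rng.normal(size=(784, 100)).astype(np.float32)
threshold = np.quantile(np.abs(dense), 0.7)
pruned = np.where(np.abs(dense) < threshold, 0.0, dense).astype(np.float32)

print(f"sparsity: {kernel_sparsity([pruned]):.2f}")
print(f"gzip size dense:  {gzipped_size([dense])} bytes")
print(f"gzip size pruned: {gzipped_size([pruned])} bytes")
```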


#### Exercise:

Modify the value of final_sparsity (0.1, 0.3, 0.5, 0.7, 0.9) and report the evaluation metrics for each pruned model.

- What conclusions can you draw regarding the relationship between sparsity level, model performance, and efficiency?


#### Pruning Results Summary

| Experiment | Final Sparsity | Train Accuracy | Validation Accuracy | Test Accuracy | Observations |
|-----------|----------------|----------------|---------------------|---------------|--------------|
| 1 | 0.1 | | | | |
| 2 | 0.3 | | | | |
| 3 | 0.5 | | | | |
| 4 | 0.7 | | | | |
| 5 | 0.9 | | | | |


**Guiding questions:**
- At what sparsity level does accuracy start to degrade significantly?
- Is there a sparsity range where efficiency improves with minimal loss in performance?
- How does aggressive pruning affect training stability and convergence?



--- 

### Quantization 
Quantization is a technique that reduces the numerical precision of a neural network’s parameters by converting floating-point values (e.g., 32-bit) into lower-precision representations, such as 16-bit or even 8-bit. The main goal is to reduce model size and accelerate inference, especially on resource-constrained devices such as mobile phones or microcontrollers.
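
As a back-of-the-envelope example, consider the fully connected architecture defined later in this section (784 → 100 → 50 → 20 → 10). The sketch below counts its parameters (weights plus biases) and estimates the storage needed at several bit-widths. It ignores file-format overhead and any per-tensor scaling factors. `dense_param_count` is a hypothetical helper.

```python
def dense_param_count(layer_sizes):
    """Number of weights + biases in a stack of fully connected layers."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))

layer_sizes = [28 * 28, 100, 50, 20, 10]
n_params = dense_param_count(layer_sizes)
print(f"parameters: {n_params}")  # parameters: 84780

for bits in (32, 16, 8, 4):
    n_bytes = n_params * bits / 8
    print(f"{bits:2d}-bit: {n_bytes / 1024:8.1f} KiB")
```

Going from 32-bit floats to 8-bit values divides the parameter storage by four. Almost all of it sits in the first layer (78,500 of the 84,780 parameters).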

In this laboratory, we will use Quantization-Aware Training (QAT) and Quantization-Aware Pruning (QAP).

- QAT is a training technique in which the model learns to adapt to quantization before being deployed on hardware. Instead of training the model at full precision (32-bit floating point) and quantizing it afterward, quantization effects are simulated during training.

- QAP combines pruning (removal of unnecessary connections in the neural network) with quantization-aware training. The goal is to remove weights while the model is also adapting to reduced precision, resulting in a smaller and more efficient quantized network without significantly sacrificing accuracy.
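
In QAT, the forward pass uses quantized values while the optimizer keeps updating full-precision weights. Gradients are typically passed through the rounding step with a straight-through estimator.

The NumPy sketch below shows only the forward "quantize-dequantize" step for a signed fixed-point format with `bits` total bits and `integer` integer bits. It is written to mirror the arithmetic of QKeras quantized_bits(bits, integer) with default settings. This is an assumption based on its documented behavior, not a drop-in replacement.

```python
import numpy as np

def fake_quantize(x, bits=8, integer=4):
    """Quantize-dequantize x to a signed fixed-point grid.

    One bit is used for the sign, `integer` bits for the integer part and the
    remaining (bits - integer - 1) bits for the fractional part.
    """
    frac_bits = bits - integer - 1
    scale = 2.0 ** frac_bits                        # number of quantization steps per unit
    q_min, q_max = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    q = np.clip(np.round(np.asarray(x) * scale), q_min, q_max)
    return q / scale

w = np.array([0.3, -1.26, 3.14159, 100.0, -100.0])
w_q = fake_quantize(w)          # 8 bits, 4 integer bits: step 0.125, range [-16, 15.875]
print(w_q)
print("max abs error inside range:", np.max(np.abs(w[:3] - w_q[:3])))
```

Values inside the range are off by at most half a step (0.0625). Values outside it saturate at the limits. Seeing these errors during training is what lets QAT compensate for them.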

#### Quantization-aware training

The model is redefined using QKeras, an extension of Keras designed to create and train quantized neural network models. Its main objective is to optimize models for hardware platforms with limited resources, such as FPGAs, microcontrollers, and embedded accelerators.

- It allows defining weights and activations with different precision levels (e.g., 8-bit, 4-bit, or ternary {-1, 0, 1}).

- By reducing numerical precision, it decreases memory usage and computational cost.

- It facilitates the conversion of quantized models into efficient FPGA implementations, ensuring compatibility with hls4ml.

- It is compatible with standard Keras layers, while providing additional support for low-precision configurations.

Once the model has been redefined with QKeras layers, it is compiled and trained as usual with the standard Keras API.

**QKeras reference:** 

Coelho, C. N., Kuusela, A., Zhuang, H., Aarrestad, T., Loncar, V., Ngadiuba, J., ... & Summers, S. (2020). _Ultra low-latency, low-area inference accelerators using heterogeneous deep quantization with QKeras and hls4ml_. arXiv preprint arXiv:2006.10159.

In [None]:
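# Definition of the number of bits for kernel, bias, and activations (8-bit quantization)
# quantized_bits(8, 4): 8 bits in total = 1 sign bit + 4 integer bits + 3 fractional bits,
# i.e. a step of 2**-3 = 0.125 and a range of [-16, 15.875]; alpha=1 fixes the scaling factor to 1
kernelQ = "quantized_bits(8, 4, alpha=1)"
biasQ = "quantized_bits(8, 4, alpha=1)"
# The hidden layers are named relu1-relu3, so use a quantized ReLU: quantized_bits alone
# is a linear quantizer and would leave the network without a ReLU non-linearity
activationQ = "quantized_relu(8, 4)"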

# Definition of a Quantization-Aware Training (QAT) model using QKeras
modelQAT = Sequential(
    [
        # Flatten the 2D input image into a 1D vector
        Flatten(input_shape=(28, 28)),

        # First quantized fully connected layer
        QDense(
            100,
            name="fc1",
            kernel_quantizer=kernelQ,
            bias_quantizer=biasQ,
            kernel_initializer="lecun_uniform",
        ),
        QActivation(activation=activationQ, name="relu1"),
        Dropout(0.1),

        # Second quantized fully connected layer
        QDense(
            50,
            name="fc2",
            kernel_quantizer=kernelQ,
            bias_quantizer=biasQ,
            kernel_initializer="lecun_uniform",
        ),
        QActivation(activation=activationQ, name="relu2"),
        Dropout(0.1),

        # Third quantized fully connected layer
        QDense(
            20,
            name="fc3",
            kernel_quantizer=kernelQ,
            bias_quantizer=biasQ,
            kernel_initializer="lecun_uniform",
        ),
        QActivation(activation=activationQ, name="relu3"),
        Dropout(0.1),

        # Output quantized layer
        QDense(
            10,
            name="output",
            kernel_quantizer=kernelQ,
            bias_quantizer=biasQ,
            kernel_initializer="lecun_uniform",
        ),

        # Softmax activation kept in full precision
        Activation(activation="softmax", name="softmax"),
    ],
    name="quantizedModel",
)


# In QKeras-based QAT, weights and activations are quantized during training, allowing the model to learn robustness to reduced numerical precision.
# The final softmax layer is typically kept in full precision to preserve numerical stability.

In [None]:
modelQAT.summary()

In [None]:
# Training configuration
epochs = 16
lr = 0.001
loss = 'categorical_crossentropy'
op = Adam(learning_rate=lr)
metrics = ['accuracy']
batch = 64
val_split = 0.2

# Compile the QAT model
modelQAT.compile(
    optimizer=op,
    loss=loss,
    metrics=metrics
)

# Train the QAT model
historyQAT = modelQAT.fit(
    x_train_norm,
    y_train,
    validation_split=val_split,
    epochs=epochs,
    batch_size=batch,
    verbose=1
)


In [None]:
# Plot accuracy over epochs
plt.figure(figsize=(10,3))
plt.plot(historyQAT.history['accuracy'], label='Train Accuracy')
plt.plot(historyQAT.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

#### Metrics

In [None]:
# Perform inference on the test set using the QAT model
y_pred_probs = modelQAT.predict(x_test_norm)

# Convert predicted probabilities to class indices
y_pred = np.argmax(y_pred_probs, axis=1)

# Since y_test is one-hot encoded, convert it back to class indices
y_true = np.argmax(y_test, axis=1)

# Compute the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Display the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap="Purples")

plt.title('Confusion Matrix for MNIST – QAT Model')
plt.show()


#### Quantization-aware pruning

Quantization-aware pruning (QAP) combines pruning with quantization-aware training. The goal is to further reduce the size of the quantized model, resulting in a more efficient neural network without significantly sacrificing accuracy.

In QAP, pruning and quantization effects are applied simultaneously during training. A poorly chosen pruning schedule (for example, one that reaches a high sparsity too early) may therefore lead to excessive accuracy degradation.

In [None]:
import math

epochs = 16
batch = 64
val_split = 0.2

# Number of training samples after validation split
n_train_samples = int(x_train_norm.shape[0] * (1 - val_split))

# Steps per epoch
steps_per_epoch = math.ceil(n_train_samples / batch)

# Total number of pruning steps
end_step = epochs * steps_per_epoch

# begin_step = 0
# → Pruning starts from the beginning of training

# end_step = epochs × steps_per_epoch
# → The model reaches the target sparsity at the end of training

# final_sparsity = 0.3
# → 30% of the prunable weights (the QDense kernels) will be pruned


In [None]:
# Pruning strategy
final_sparsity = 0.3

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0,
        final_sparsity=final_sparsity,
        begin_step=0,
        end_step=end_step
    )
}


In [None]:
# Quantization strategy (for QAP)

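# Definition of the number of bits for kernel, bias, and activations (8-bit quantization)
# quantized_bits(8, 4): 1 sign bit + 4 integer bits + 3 fractional bits (step 0.125)
kernelQ = "quantized_bits(8, 4, alpha=1)"
biasQ = "quantized_bits(8, 4, alpha=1)"
# Quantized ReLU so that the hidden layers (relu1-relu3) keep their ReLU non-linearity
activationQ = "quantized_relu(8, 4)"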

# Definition of a quantized model (QKeras) to be used for Quantization-Aware Pruning (QAP)
modelQ_QAP = Sequential(
    [
        # Flatten the 2D input image into a 1D vector
        Flatten(input_shape=(28, 28)),

        # First quantized fully connected layer
        QDense(
            100,
            name="fc1",
            kernel_quantizer=kernelQ,
            bias_quantizer=biasQ,
            kernel_initializer="lecun_uniform",
        ),
        QActivation(activation=activationQ, name="relu1"),
        Dropout(0.1),

        # Second quantized fully connected layer
        QDense(
            50,
            name="fc2",
            kernel_quantizer=kernelQ,
            bias_quantizer=biasQ,
            kernel_initializer="lecun_uniform",
        ),
        QActivation(activation=activationQ, name="relu2"),
        Dropout(0.1),

        # Third quantized fully connected layer
        QDense(
            20,
            name="fc3",
            kernel_quantizer=kernelQ,
            bias_quantizer=biasQ,
            kernel_initializer="lecun_uniform",
        ),
        QActivation(activation=activationQ, name="relu3"),
        Dropout(0.1),

        # Output quantized layer
        QDense(
            10,
            name="output",
            kernel_quantizer=kernelQ,
            bias_quantizer=biasQ,
            kernel_initializer="lecun_uniform",
        ),

        # Softmax activation kept in full precision for numerical stability
        Activation(activation="softmax", name="softmax"),
    ],
    name="quantizedModel",
)


In [None]:
# Training configuration
epochs = 16
lr = 0.001
loss = 'categorical_crossentropy'
op = Adam(learning_rate=lr)
metrics = ['accuracy']
batch = 64
val_split = 0.2

# Apply magnitude-based pruning to the quantized model (QAP)
modelQAP = tfmot.sparsity.keras.prune_low_magnitude(
    modelQ_QAP, **pruning_params
)

# Compile the QAP model
modelQAP.compile(
    optimizer=op,
    loss=loss,
    metrics=metrics
)

# Train the QAP model
history_QAP = modelQAP.fit(
    x_train_norm,
    y_train,
    validation_split=val_split,
    epochs=epochs,
    batch_size=batch,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],  # Required for pruning schedule updates
    verbose=1
)


In [None]:
# Plot accuracy during training for the QAP model
plt.figure(figsize=(10, 3))

plt.plot(history_QAP.history['accuracy'], label='Train Accuracy')
plt.plot(history_QAP.history['val_accuracy'], label='Validation Accuracy')

plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()


In [None]:
# Plot loss during training for the QAP model
plt.figure(figsize=(10, 3))

plt.plot(history_QAP.history['loss'], label='Train Loss')
plt.plot(history_QAP.history['val_loss'], label='Validation Loss')

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()


**Note:** When using QAP, convergence may be slightly slower than for the full-precision model because of the combined effects of quantization and pruning.
A stable validation curve that plateaus close to the accuracy of the unpruned QAT model indicates that the efficiency gains are achieved without a significant loss in accuracy.

#### Metrics

In [None]:
# Perform inference on the test set using the QAP model
y_pred_probs = modelQAP.predict(x_test_norm)

# Convert predicted probabilities to class indices
y_pred = np.argmax(y_pred_probs, axis=1)

# Since y_test is one-hot encoded, convert it back to class indices
y_true = np.argmax(y_test, axis=1)

# Compute the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Display the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap="Purples")

plt.title('Confusion Matrix for MNIST – QAP Model')
plt.show()



#### Exercise:

- Modify the bit-width (4, 8, 16, 32) and report the metrics for each model, considering both QAT and QAP. What conclusions can you draw?

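- Modify the MLP architecture by increasing or decreasing the number of layers, and replace the Flatten layer with a Dense input layer. The images must then be flattened beforehand into 784-element vectors, as done in the Knowledge Distillation section.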
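
When changing the bit-width for this exercise, the number of integer bits usually has to change too. With quantized_bits(bits, integer), the step is 2^-(bits - integer - 1) and the signed range is [-2^integer, 2^integer - step].

The plain-Python sketch below tabulates these values for a few candidate configurations. The (bits, integer) pairs are illustrative choices, not recommendations from this lab. The sketch assumes QKeras' default signed quantizer and no extra scaling (alpha=1). `fixed_point_grid` is a hypothetical helper.

```python
def fixed_point_grid(bits, integer):
    """Step size and representable range of a signed fixed-point format."""
    step = 2.0 ** -(bits - integer - 1)
    low, high = -(2.0 ** integer), 2.0 ** integer - step
    return step, low, high

for bits, integer in [(4, 0), (4, 1), (8, 1), (8, 4), (16, 6)]:
    step, low, high = fixed_point_grid(bits, integer)
    print(f"quantized_bits({bits}, {integer}): step={step:g}, range=[{low:g}, {high:g}]")
```

Each extra integer bit doubles the range but also doubles the step. The trade-off therefore depends on the typical magnitude of the weights and activations in each layer.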

--- 

### Knowledge Distillation

This technique focuses on transferring knowledge from a large network (teacher) to a smaller and faster target network (distilled or student). The student model learns to reproduce the behavior of the teacher architecture while being computationally more efficient.

In knowledge distillation, the teacher model provides **soft labels**, which are probability distributions over classes rather than hard class labels. These soft labels contain richer information about class similarities and decision boundaries.

A **temperature** parameter is introduced in the softmax function to control the smoothness of the output probabilities. Higher temperature values produce softer probability distributions, allowing the student model to better capture the teacher’s knowledge during training. During inference, the temperature is typically set back to 1.
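
The effect of the temperature can be checked numerically. The NumPy sketch below applies a temperature-scaled softmax, softmax(z / T), to one hypothetical vector of teacher logits; the numbers are made up for illustration. It then reports the entropy of the result. As T grows, probability mass moves from the top class to the others, exposing which wrong classes the teacher considers similar. The predicted class (the argmax) does not change.

```python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    z = np.asarray(logits, dtype=float) / temperature
    z = z - z.max()                      # subtract the max to avoid overflow in exp
    e = np.exp(z)
    return e / e.sum()

def entropy(p):
    """Shannon entropy in nats."""
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())

# Hypothetical teacher logits for an image of a "7" (classes 0..9)
logits = np.array([0.5, 2.0, 1.0, 0.0, 1.5, -1.0, -2.0, 9.0, 0.5, 3.0])

for T in (1, 3, 10):
    p = softmax_with_temperature(logits, T)
    print(f"T={T:2d}: p(7)={p[7]:.3f}, p(9)={p[9]:.3f}, entropy={entropy(p):.3f}")
```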

**Knowledge Distillation**: Hinton, G., Vinyals, O., & Dean, J. (2015). _Distilling the Knowledge in a Neural Network_. arXiv preprint arXiv:1503.02531.

![alt text](../img/lab02/distillation.png)


In [None]:
# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize pixel values to the range [0, 1]
x_train, x_test = x_train / 255.0, x_test / 255.0

# Flatten images into 1D vectors
x_train = x_train.reshape(-1, 28 * 28)
x_test = x_test.reshape(-1, 28 * 28)

# Convert labels to one-hot encoded vectors
y_train = to_categorical(y_train, num_classes=10, dtype=int)
y_test = to_categorical(y_test, num_classes=10, dtype=int)


In [None]:
# Define the Teacher model (large MLP)
# The output layer returns raw logits (no softmax). This assumes that the local
# Distiller follows the Keras knowledge-distillation example, which applies the
# temperature-scaled softmax itself and uses from_logits=True for the student loss.

def build_teacher():
    model = keras.Sequential([
        Dense(512, activation="relu", input_shape=(28 * 28,)),
        Dense(256, activation="relu"),
        Dense(10)  # Logits (no softmax)
    ])
    return model

# Compile and train the Teacher model
teacher = build_teacher()

# Display model architecture
teacher.summary()

teacher.compile(
    optimizer="adam",
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"]
)

# Train the teacher model; validation data is split from the training set
# so that the test set is only used for the final evaluation
history = teacher.fit(
    x_train,
    y_train,
    epochs=16,
    batch_size=128,
    validation_split=0.2
)


In [None]:
# Define the Student model (smaller MLP)

def build_student():
    model = keras.Sequential([
        Dense(5, activation="relu", input_shape=(28 * 28,)),
        Dense(7, activation="relu"),
        Dense(10)  # Logits (no softmax), consistent with from_logits=True in the Distiller losses
    ])
    return model

# Instantiate the student model
student = build_student()


In [None]:
# Convert one-hot encoded labels back to integer class indices
# (required because the student loss used below is SparseCategoricalCrossentropy)
train_labels = np.argmax(y_train, axis=1)

# Create the distillation wrapper (student + teacher)
distilledMLP = Distiller(student=student, teacher=teacher)


In [None]:
distilledMLP.student.summary()

In [None]:
# Knowledge distillation process

adam = Adam(learning_rate=0.0001)

# Compile the distillation model
distilledMLP.compile(
    optimizer=adam,
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    
    # Loss for hard labels (ground truth)
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    
    # Loss for soft labels (teacher predictions)
    distillation_loss_fn=keras.losses.KLDivergence(),
    
    # Weighting factor between hard-label loss and distillation loss
    alpha=0.1,
    
    # Temperature used to soften the teacher and student outputs
    temperature=10,
)

# Train the student model using knowledge distillation
history_distilledMLP = distilledMLP.fit(
    x_train,
    train_labels,
    validation_split=0.2,
    batch_size=64,
    epochs=32
)


**Note:** A higher temperature produces softer probability distributions, allowing the student model to better capture the teacher’s knowledge.
The parameter alpha controls the trade-off between learning from ground-truth labels and mimicking the teacher’s behavior.
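
For reference, the sketch below computes the combined objective in NumPy for a single example. It assumes `Distiller` follows the Keras knowledge-distillation example, where loss = alpha * CE(y, softmax(s)) + (1 - alpha) * KL(softmax(t/T) || softmax(s/T)) * T^2. Here s and t are the student and teacher logits. The T^2 factor (used in that example and suggested by Hinton et al.) keeps the gradient scale of the soft term roughly independent of T. The actual distillationClassKeras module may differ, and all names below are hypothetical.

```python
import numpy as np

def softmax(z, temperature=1.0):
    """Numerically stable softmax of z / temperature."""
    z = np.asarray(z, dtype=float) / temperature
    e = np.exp(z - z.max())
    return e / e.sum()

def distillation_objective(student_logits, teacher_logits, label, alpha=0.1, temperature=10.0):
    """alpha * hard-label CE + (1 - alpha) * T^2 * KL(teacher_T || student_T)."""
    # Hard-label cross-entropy at temperature 1 (label is an integer class index)
    student_loss = -np.log(softmax(student_logits)[label])
    # KL divergence between the temperature-softened teacher and student distributions
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = float(np.sum(p_t * np.log(p_t / p_s)))
    distill_loss = kl * temperature ** 2
    return alpha * student_loss + (1 - alpha) * distill_loss, student_loss, distill_loss

teacher_logits = np.array([0.5, 2.0, 1.0, 0.0, 1.5, -1.0, -2.0, 9.0, 0.5, 3.0])
student_logits = np.array([0.0, 1.0, 0.5, 0.0, 0.5, 0.0, -1.0, 4.0, 0.0, 1.0])

total, hard, soft = distillation_objective(student_logits, teacher_logits, label=7)
print(f"student_loss={hard:.4f}, distillation_loss={soft:.4f}, total={total:.4f}")
```

With alpha=0.1, as in the cell above, 90% of the weight goes to matching the teacher's softened outputs.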

In [None]:
# Plot accuracy over epochs for the distilled student model
plt.figure(figsize=(15, 3))
plt.plot(
    history_distilledMLP.history['sparse_categorical_accuracy'],
    label='Train Accuracy'
)
plt.plot(
    history_distilledMLP.history['val_sparse_categorical_accuracy'],
    label='Validation Accuracy'
)

plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Over Epochs')
plt.show()



In [None]:
# Plot loss over epochs for the distilled student model
plt.figure(figsize=(15, 3))
plt.plot(
    history_distilledMLP.history['student_loss'],
    label='Train Loss'
)
plt.plot(
    history_distilledMLP.history['val_student_loss'],
    label='Validation Loss'
)

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Over Epochs')
plt.show()


**Note:** In knowledge distillation, the student model often achieves higher accuracy than a directly trained small model, as it benefits from the soft targets provided by the teacher.

#### Metrics

In [None]:
# Perform inference on the test set using the distilled student model
# (argmax gives the same predicted class whether the outputs are logits or probabilities)
y_pred_scores = distilledMLP.student.predict(x_test)

# Convert the output scores to class indices
y_pred = np.argmax(y_pred_scores, axis=1)

# Since y_test is one-hot encoded, convert it back to class indices
y_true = np.argmax(y_test, axis=1)

# Compute the confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Display the confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Purples')

plt.title('Confusion Matrix for MNIST – Distilled Student Model')
plt.show()


#### Saving the Model

Once the model has been trained and evaluated, it can be saved to disk for later use. Saving the model allows it to be reloaded without retraining, preserving both the learned weights and the model architecture.

In [None]:
# Save the trained model to disk (create the output folder if it does not exist)
os.makedirs("models", exist_ok=True)
distilledMLP.student.save("models/mnistKD.h5")

--------

#### Exercises

- Vary the **student model architecture** and analyze the performance of the distilled model using the appropriate evaluation metrics. Consider the following scenarios:
  
  - Decrease the number of layers.
  - Vary the number of neurons in each layer.
  - Increase the number of layers while reducing the number of neurons per layer.

- Repeat the **model compression processes** for the **Fashion-MNIST** and **CIFAR-10** datasets.
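
For the Fashion-MNIST and CIFAR-10 repetitions, two things mainly change:

- **Input shape:** 28x28 grayscale images versus 32x32x3 colour images, i.e. 784 versus 3072 inputs after flattening.
- **Label format:** CIFAR-10 labels are loaded with shape (N, 1).

The NumPy sketch below is a hypothetical `prepare_split` helper that normalizes, flattens and one-hot encodes such arrays. It runs on random stand-in arrays with the same shapes, so no data download is needed. The input_shape arguments of the models would still have to be adapted to the new sizes.

```python
import numpy as np

def prepare_split(x, y, n_classes=10, flatten=True):
    """Scale uint8 images to [0, 1], optionally flatten them, and one-hot encode labels."""
    x = x.astype("float32") / 255.0
    if flatten:
        x = x.reshape(len(x), -1)            # (N, H, W[, C]) -> (N, H*W[*C])
    y = np.asarray(y).reshape(-1)            # CIFAR-10 labels come as (N, 1)
    y_onehot = np.eye(n_classes, dtype="float32")[y]
    return x, y_onehot

# Random stand-ins with the shapes returned by keras.datasets
rng = np.random.default_rng(0)
x_fashion = rng.integers(0, 256, size=(4, 28, 28)).astype(np.uint8)
y_fashion = rng.integers(0, 10, size=(4,))
x_cifar = rng.integers(0, 256, size=(4, 32, 32, 3)).astype(np.uint8)
y_cifar = rng.integers(0, 10, size=(4, 1))

xf, yf = prepare_split(x_fashion, y_fashion)
xc, yc = prepare_split(x_cifar, y_cifar)
print(xf.shape, yf.shape)   # (4, 784) (4, 10)
print(xc.shape, yc.shape)   # (4, 3072) (4, 10)
```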


### Results Summary Table

| Experiment | Dataset | Model Type | # Layers | Neurons per Layer | Compression Method | Bit-width | Final Sparsity | Train Accuracy | Validation Accuracy | Test Accuracy | Model Size / Params | Observations |
|-----------|---------|------------|----------|-------------------|--------------------|-----------|----------------|----------------|---------------------|---------------|---------------------|--------------|
| 1 | MNIST | Student (baseline) | | | None | FP32 | – | | | | | |
| 2 | MNIST | Student | | | QAT | 8 | – | | | | | |
| 3 | MNIST | Student | | | QAP | 8 | 0.3 | | | | | |
| 4 | MNIST | Student | | | QAT + Distillation | 8 | – | | | | | |
| 5 | Fashion-MNIST | Student | | | QAT | 8 | – | | | | | |
| 6 | Fashion-MNIST | Student | | | QAP | 8 | 0.3 | | | | | |
| 7 | CIFAR-10 | Student | | | QAT | 8 | – | | | | | |
| 8 | CIFAR-10 | Student | | | QAT + Distillation | 8 | – | | | | | |


**Guiding Questions:**
- How does reducing the number of layers affect accuracy and generalization?
- Is it more effective to reduce depth or width when compressing the model?
- At what point does compression significantly degrade performance?
- Does knowledge distillation help recover accuracy lost due to quantization or pruning?
- Which configuration offers the best trade-off between accuracy and efficiency?



- To continue with the next Machine Learning laboratory:

  - **For the MNIST dataset, define a binary classifier** that discriminates only between digits **6 and 9**, or between **7 and 5**.  
    Apply **Quantization-Aware Pruning (QAP)** as the training and compression method.

  - **For the same binary classification scenario**, apply **Quantization-Aware Training (QAT)** combined with **Knowledge Distillation**.
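
As a starting point for the binary task, the NumPy sketch below uses a hypothetical helper, `select_digits`. It keeps only two digit classes and relabels them as 0 and 1, so that a model with a single sigmoid output (or two softmax outputs) can be trained on them. It expects integer labels, i.e. the arrays returned by mnist.load_data() before one-hot encoding.

```python
import numpy as np

def select_digits(x, y, negative=6, positive=9):
    """Keep only two classes and map them to binary labels (negative -> 0, positive -> 1)."""
    y = np.asarray(y).reshape(-1)
    keep = (y == negative) | (y == positive)
    return x[keep], (y[keep] == positive).astype("int32")

# Tiny synthetic example with labels 0..9
x_demo = np.arange(10 * 4).reshape(10, 4)
y_demo = np.arange(10)
xb, yb = select_digits(x_demo, y_demo, negative=6, positive=9)
print(xb.shape, yb)   # (2, 4) [0 1]
```

The same call with negative=7 and positive=5 gives the second scenario. The two resulting classes are usually not perfectly balanced, so it is worth checking the class counts before training.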