# Neural Architecture Search with Knowledge Distillation for TinyML

## Overview

This notebook implements a Neural Architecture Search (NAS) framework combined with Knowledge Distillation, aimed at finding efficient model architectures for TinyML applications. The primary goal is to discover compact, high-performing neural network architectures that can be deployed on resource-constrained devices.

Key components and techniques:

1. **Neural Architecture Search (NAS)**: Systematically explores various model architectures, focusing on configurations suitable for small devices.

2. **Knowledge Distillation**: Utilises a larger, pre-trained teacher model to guide the training of smaller student models, potentially improving their performance.

3. **Model Pruning**: Applies network pruning to reduce model size and potentially improve generalization.

4. **Quantization**: Applies quantization-aware training and converts the result to TensorFlow Lite, further reducing model size and improving inference speed.

5. **CIFAR-10 Dataset**: Uses the CIFAR-10 dataset (converted to grayscale) as a benchmark for evaluating model performance.

6. **Efficiency Metrics**: Considers both model accuracy and size to identify the best architectures for TinyML applications.



# Setup and Imports

This cell imports the necessary libraries and modules for our neural architecture search (NAS) with knowledge distillation.

In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import itertools
import tempfile
import os
import tensorflow_model_optimization as tfmot
from sklearn.model_selection import train_test_split
from collections import deque
import itertools
import matplotlib.pyplot as plt

# Distiller Class

This class implements the knowledge distillation process. It combines the loss from the student model's predictions and the distillation loss from comparing the student's softened predictions to the teacher's softened predictions.

In [None]:
# Distiller class
class Distiller(keras.Model):
    """Train a student model by knowledge distillation from a frozen teacher.

    The training loss is
    ``alpha * student_loss(y, student) + (1 - alpha) * T**2 * distillation_loss(soft_teacher, soft_student)``
    where ``soft_*`` are temperature-softened output distributions. Only the
    student's variables are updated; ``call`` runs the student alone.
    """

    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn, alpha=0.1, temperature=3,
                from_logits=False):
        """Configure the optimizer, metrics, losses and distillation hyper-parameters.

        Args:
            optimizer: optimizer applied to the student's trainable variables.
            metrics: metrics updated from the student's predictions.
            student_loss_fn: loss between the true labels and the student's outputs.
            distillation_loss_fn: loss between softened teacher and student distributions
                (called as ``fn(teacher_soft, student_soft)``).
            alpha: weight of the student (hard-label) loss; ``1 - alpha`` weights distillation.
            temperature: softening temperature T (> 0).
            from_logits: set True when teacher and student output raw logits. The default
                (False) treats outputs as probabilities, matching the softmax-terminated
                models built in this notebook.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature
        self.from_logits = from_logits

    def _soften(self, predictions):
        """Return the temperature-softened distribution corresponding to ``predictions``."""
        logits = predictions
        if not self.from_logits:
            # Dividing probabilities (values in [0, 1]) by T and re-applying softmax
            # pushes every distribution toward uniform. Log-probabilities are logits up
            # to an additive constant, which softmax ignores, so soften those instead.
            logits = tf.math.log(tf.clip_by_value(predictions, 1e-7, 1.0))
        return tf.nn.softmax(logits / self.temperature, axis=1)

    def train_step(self, data):
        x, y = data
        # The teacher is frozen: run it in inference mode, outside the tape.
        teacher_predictions = self.teacher(x, training=False)
        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)
            student_loss = self.student_loss_fn(y, student_predictions)
            # T**2 keeps the distillation gradients on the same scale as the hard-label loss.
            distillation_loss = self.distillation_loss_fn(
                self._soften(teacher_predictions),
                self._soften(student_predictions),
            ) * (self.temperature ** 2)
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        self.compiled_metrics.update_state(y, student_predictions)
        results = {m.name: m.result() for m in self.metrics}
        results.update(student_loss=student_loss, distillation_loss=distillation_loss)
        return results

    def test_step(self, data):
        """Evaluate the student alone against the true labels."""
        x, y = data
        student_predictions = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, student_predictions)
        self.compiled_metrics.update_state(y, student_predictions)
        results = {m.name: m.result() for m in self.metrics}
        results.update(student_loss=student_loss)
        return results

    def call(self, inputs, training=False):
        return self.student(inputs, training=training)

# Search Space Definition

This cell sketches a generic search space of layer types and sizes. It is commented out and unused; the search space that is actually explored is defined inside `search_architectures` below.

In [None]:
# # Define the search space
# layer_types = ['Conv2D', 'Dense', 'Flatten', 'MaxPool2D', 'GlobalAveragePooling2D']
# conv_filters = [16, 32, 64]
# dense_units = [64, 128, 256]
# max_layers = 5

# Model Creation Function

This function creates a Keras model based on a given configuration. It ensures that the model structure is valid, handling the transition from convolutional to dense layers appropriately.

In [None]:
# Function to create a model given a configuration
def create_model(config, input_shape, num_classes):
    """Build a ``keras.Sequential`` classifier from a list of layer specs.

    Each spec is indexed as ``spec[0]`` (layer type) and ``spec[1]`` (filters or
    units where relevant). The LAST spec is treated as the output layer and is not
    built from its contents; a ``Dense(num_classes, softmax)`` head is always
    appended instead.

    Rules applied while building:
      * a ``Flatten`` is inserted before the first ``Dense`` if the feature map is
        still spatial;
      * once the feature map has been collapsed (Flatten / GlobalAveragePooling2D /
        first Dense), any further Conv2D, MaxPool2D, Flatten or
        GlobalAveragePooling2D spec is skipped;
      * unknown layer types are ignored.
    """
    model = keras.Sequential()
    model.add(keras.Input(shape=input_shape))

    collapsed = False  # True once the output is a vector rather than a feature map
    for spec in config[:-1]:
        kind = spec[0]

        if kind == 'Dense':
            if not collapsed:
                model.add(keras.layers.Flatten())
                collapsed = True
            model.add(keras.layers.Dense(spec[1], activation='relu'))
            continue

        if collapsed:
            # All remaining layer types operate on spatial feature maps only.
            continue

        if kind == 'Conv2D':
            model.add(keras.layers.Conv2D(spec[1], (3, 3), activation='relu', padding='same'))
        elif kind == 'MaxPool2D':
            model.add(keras.layers.MaxPool2D((2, 2)))
        elif kind == 'Flatten':
            model.add(keras.layers.Flatten())
            collapsed = True
        elif kind == 'GlobalAveragePooling2D':
            model.add(keras.layers.GlobalAveragePooling2D())
            collapsed = True

    # The classification head needs a vector input.
    if not collapsed:
        model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(num_classes, activation='softmax'))

    return model

# Quick Evaluation Function

This function performs a rapid evaluation of a model using a small number of epochs and early stopping. It's used to quickly filter out poorly performing architectures.

In [None]:
def _final_val_accuracy(history):
    """Return the last validation-accuracy value recorded in a Keras ``History.history`` dict."""
    # The key depends on how the metric was declared at compile time:
    # metrics=['accuracy'] -> 'val_accuracy'; CategoricalAccuracy() -> 'val_categorical_accuracy'.
    for key in ('val_accuracy', 'val_categorical_accuracy', 'val_acc'):
        if key in history:
            return history[key][-1]
    raise KeyError(
        f"No validation accuracy found in history keys {sorted(history)}; "
        "compile the model with an accuracy metric."
    )


def quick_evaluate(model, x_train, y_train, x_val, y_val, epochs=5, patience=2):
    """Briefly train a compiled ``model`` and return its final validation accuracy.

    Training stops early when ``val_loss`` has not improved for ``patience`` epochs.
    Weights are not restored, so the returned value is the accuracy of the model in
    the state it is left in (the last epoch run).

    Args:
        model: compiled Keras model with an accuracy metric.
        x_train, y_train: training data.
        x_val, y_val: validation data.
        epochs: maximum number of epochs.
        patience: early-stopping patience on ``val_loss``.

    Returns:
        The validation accuracy after the last epoch run.

    Raises:
        KeyError: if the model was compiled without an accuracy metric.
    """
    early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience)
    history = model.fit(
        x_train, y_train,
        epochs=epochs,
        validation_data=(x_val, y_val),
        callbacks=[early_stop],
        verbose=0
    )
    return _final_val_accuracy(history.history)

# Neural Architecture Search Function

This function performs a comprehensive neural architecture search with knowledge distillation, pruning, and quantization.

1. Search Space Definition:
   - Defines the range of architectural choices (e.g., number of convolutional blocks, layers per block, initial filters, dense layers).

2. Data Splitting:
   - Holds out 20% of the training data as a validation set; the remaining 80% is used to train candidates during quick evaluation.

3. Architecture Generation:
   - Iterates through different combinations of architectural parameters to generate various model configurations.

4. Quick Evaluation:
   - For each generated architecture, performs a rapid evaluation using a subset of the data.
   - This step helps to quickly filter out poorly performing architectures without spending time on full training.

5. Knowledge Distillation:
   - For promising architectures (those passing the quick evaluation threshold), applies knowledge distillation.
   - Uses a pre-trained teacher model to guide the training of the student (generated) model.
   - Combines the standard cross-entropy loss with a distillation loss that encourages the student to mimic the teacher's softened outputs.

6. Pruning:
   - Applies network pruning to reduce model size and potentially improve generalization.
   - Uses a polynomial-decay pruning schedule that gradually increases sparsity from 50% toward 80% during fine-tuning.

7. Quantization:
   - Applies quantization-aware training to further reduce model size and improve inference speed.
   - Converts the model to TensorFlow Lite format with default optimizations.

8. Final Evaluation:
   - Evaluates the pruned, quantization-aware Keras model on the test set to obtain the final accuracy (the converted TFLite model is only measured for size).

9. Model Tracking:
   - Keeps the top 5 models, ranked by test accuracy with smaller model size breaking ties.

In [None]:
def _build_config(n_blocks, layers_per_block, init_filters, n_dense, dense_units, num_classes):
    """Return the layer-spec list consumed by ``create_model`` for one grid point."""
    config = []
    filters = init_filters
    for _ in range(n_blocks):
        config.extend([('Conv2D', filters)] * layers_per_block)
        config.append(('MaxPool2D', None))
        filters *= 2  # double the number of filters after every block
    config.append(('GlobalAveragePooling2D', None))
    config.extend([('Dense', dense_units)] * n_dense)
    # Final classification layer (create_model always builds it from num_classes).
    config.append(('Dense', num_classes))
    return config


def _distill(student, teacher, x_fit, y_fit, x_val, y_val, epochs, batch_size):
    """Train ``student`` in place by distillation from ``teacher`` and return it."""
    distiller = Distiller(student=student, teacher=teacher)
    distiller.compile(
        optimizer=keras.optimizers.Adam(),
        metrics=[keras.metrics.CategoricalAccuracy()],
        student_loss_fn=keras.losses.CategoricalCrossentropy(),
        distillation_loss_fn=keras.losses.KLDivergence(),
        alpha=0.1,
        temperature=3,
    )
    distiller.fit(x_fit, y_fit, epochs=epochs, batch_size=batch_size,
                  validation_data=(x_val, y_val), verbose=0)
    return distiller.student


def _prune(model, x_fit, y_fit, x_val, y_val, epochs, batch_size):
    """Fine-tune ``model`` with low-magnitude pruning and return it with the pruning wrappers stripped."""
    # The pruning schedule counts optimizer steps (batches), not samples. end_step
    # must fall within the steps actually run, or the final sparsity is never reached.
    steps_per_epoch = -(-len(x_fit) // batch_size)  # ceiling division
    pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.50,
        final_sparsity=0.80,
        begin_step=0,
        end_step=max(1, steps_per_epoch * epochs),
    )
    pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=pruning_schedule)
    pruned_model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.CategoricalCrossentropy(),
        metrics=[keras.metrics.CategoricalAccuracy()],
    )
    with tf.device('/cpu:0'):
        pruned_model.fit(x_fit, y_fit, epochs=epochs, batch_size=batch_size,
                         validation_data=(x_val, y_val),
                         callbacks=[tfmot.sparsity.keras.UpdatePruningStep()], verbose=0)
    return tfmot.sparsity.keras.strip_pruning(pruned_model)


def _quantization_aware_finetune(model, x_fit, y_fit, x_val, y_val, epochs, batch_size):
    """Wrap ``model`` for quantization-aware training, fine-tune it and return the QAT model."""
    q_aware_model = tfmot.quantization.keras.quantize_model(model)
    q_aware_model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.CategoricalCrossentropy(),
        metrics=[keras.metrics.CategoricalAccuracy()],
    )
    q_aware_model.fit(x_fit, y_fit, epochs=epochs, batch_size=batch_size,
                      validation_data=(x_val, y_val), verbose=0)
    return q_aware_model


def _tflite_size_mb(model):
    """Convert ``model`` to TFLite with default optimizations and return its size in MB (2**20 bytes)."""
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # The flatbuffer's length is exactly the size of the .tflite file, so no temp file is needed.
    return len(converter.convert()) / float(2**20)


def search_architectures(x_train, y_train, x_test, y_test, input_shape, num_classes, teacher,
                         *, quick_threshold=0.6, top_k=5, batch_size=32,
                         distill_epochs=10, prune_epochs=1, qat_epochs=5):
    """Grid-search compact CNNs; distill, prune and quantize the promising ones.

    Every architecture in the grid is first trained briefly with ``quick_evaluate``.
    Candidates whose quick validation accuracy exceeds ``quick_threshold`` are then:
      1. distilled from ``teacher``,
      2. fine-tuned with low-magnitude pruning (50% -> 80% sparsity),
      3. fine-tuned with quantization-aware training,
      4. evaluated on the test set, and converted to TFLite to measure their size.

    Args:
        x_train, y_train: training data (one-hot labels, as required by the
            categorical losses used here). A single random 80/20 split of it is used
            for training/validation in every phase.
        x_test, y_test: held-out data for the final accuracy.
        input_shape: shape of one input sample.
        num_classes: number of output classes.
        teacher: trained Keras model used as the distillation teacher.
        quick_threshold: minimum quick-evaluation accuracy needed for full training.
        top_k: number of models kept.
        batch_size: batch size for the distillation, pruning and QAT phases.
        distill_epochs, prune_epochs, qat_epochs: epochs for each phase.

    Returns:
        A deque of up to ``top_k`` tuples
        ``(q_aware_model, accuracy, size_mb, num_params, config)``, sorted by
        descending test accuracy, then ascending model size.
    """
    # Search space
    conv_blocks = [1, 2, 3]  # Number of convolutional blocks
    conv_layers_per_block = [1, 2]  # Conv layers in each block
    initial_filters = [16, 32, 64]  # Initial number of filters
    dense_layers = [0, 1, 2]  # Number of dense layers before the final layer
    dense_units_options = [64, 128, 256, 512]  # Units in dense layers

    # One random split shared by every phase. ``validation_split`` would instead
    # always hold out the LAST 20% of the arrays. The notebook sorts x_train by
    # teacher entropy (ascending), so that slice is the highest-entropy samples,
    # which would then never be trained on.
    x_fit, x_val, y_fit, y_val = train_test_split(
        x_train, y_train, test_size=0.2, random_state=42
    )

    ranked = []
    for n_blocks, layers_per_block, init_filters, n_dense in itertools.product(
        conv_blocks, conv_layers_per_block, initial_filters, dense_layers
    ):
        # dense_units has no effect without hidden dense layers; iterating it anyway
        # would train the same architecture several times.
        units_choices = dense_units_options if n_dense else dense_units_options[:1]
        for dense_units in units_choices:
            config = _build_config(n_blocks, layers_per_block, init_filters,
                                   n_dense, dense_units, num_classes)

            try:
                student = create_model(config, input_shape, num_classes)
            except ValueError as err:
                print(f"Configuration: {config}")
                print(f"Invalid architecture ({err}) - Skipped")
                print("--------------------")
                continue
            student.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

            quick_accuracy = quick_evaluate(student, x_fit, y_fit, x_val, y_val)

            # Only proceed with full training if the model shows promise.
            if quick_accuracy <= quick_threshold:
                print(f"Configuration: {config}")
                print(f"Quick Accuracy: {quick_accuracy:.4f} - Skipped")
                print("--------------------")
                continue

            student = _distill(student, teacher, x_fit, y_fit, x_val, y_val,
                               distill_epochs, batch_size)
            pruned = _prune(student, x_fit, y_fit, x_val, y_val, prune_epochs, batch_size)
            q_aware_model = _quantization_aware_finetune(pruned, x_fit, y_fit, x_val, y_val,
                                                         qat_epochs, batch_size)

            # Accuracy of the QAT Keras model; the TFLite model is only measured for size.
            _, accuracy = q_aware_model.evaluate(x_test, y_test, verbose=0)
            model_size = _tflite_size_mb(q_aware_model)
            num_params = q_aware_model.count_params()

            print(f"Configuration: {config}")
            print(f"Quick Accuracy: {quick_accuracy:.4f}")
            print(f"Final Accuracy: {accuracy:.4f}")
            print(f"Model size: {model_size:.2f} MB")
            print(f"Number of parameters: {num_params}")
            print("--------------------")

            ranked.append((q_aware_model, accuracy, model_size, num_params, config))
            # Higher accuracy first; smaller size breaks ties.
            ranked.sort(key=lambda entry: (entry[1], -entry[2]), reverse=True)
            del ranked[top_k:]

    return deque(ranked)

# Data Preparation and Teacher Model Creation

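This section loads and preprocesses the CIFAR-10 dataset, converting it to grayscale. It then defines and trains the teacher model used for knowledge distillation and orders the training samples by the entropy of the teacher's predictions (a simple curriculum). Finally, it runs the architecture search and reports the best model.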

In [None]:
# Load CIFAR-10, collapse RGB to one luminance channel, scale to [0, 1] and
# one-hot encode the 10 class labels.
_GRAYSCALE_WEIGHTS = [0.2989, 0.5870, 0.1140]

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

x_train, x_test = (
    np.reshape(np.dot(images[..., :3], _GRAYSCALE_WEIGHTS).astype("float32") / 255.0, (-1, 32, 32, 1))
    for images in (x_train, x_test)
)
y_train, y_test = (keras.utils.to_categorical(labels, 10) for labels in (y_train, y_test))

# Teacher: three conv stages (two Conv+BN pairs, max-pool, dropout), each doubling
# the filters, followed by a 512-256-128 dense head and a softmax output.
_teacher_layers = [keras.Input(shape=(32, 32, 1))]
for _stage_filters, _stage_dropout in ((32, 0.2), (64, 0.3), (128, 0.4)):
    for _conv_index in range(2):
        _teacher_layers.append(
            keras.layers.Conv2D(_stage_filters, (3, 3), activation='relu',
                                kernel_initializer='he_uniform', padding='same')
        )
        _teacher_layers.append(keras.layers.BatchNormalization())
    _teacher_layers.append(keras.layers.MaxPool2D((2, 2)))
    _teacher_layers.append(keras.layers.Dropout(_stage_dropout))
_teacher_layers.append(keras.layers.Flatten())
for _units in (512, 256, 128):
    _teacher_layers.append(keras.layers.Dense(_units, activation='relu', kernel_initializer='he_uniform'))
_teacher_layers.append(keras.layers.Dropout(0.5))
_teacher_layers.append(keras.layers.Dense(10, activation='softmax'))

teacher = keras.Sequential(_teacher_layers)

# Compile and train the teacher
teacher.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
teacher.fit(x_train, y_train, epochs=20, validation_split=0.2, verbose=1)

def curriculum_sort(x_train, y_train, teacher):
    """Reorder samples from most to least confident teacher prediction.

    Confidence is measured by the Shannon entropy of each row returned by
    ``teacher.predict``; lower entropy comes first. ``1e-10`` guards ``log(0)``.

    Returns:
        ``(x_sorted, y_sorted)``, both indexed by the same permutation.
    """
    probs = teacher.predict(x_train)
    entropy = -(probs * np.log(probs + 1e-10)).sum(axis=1)
    order = entropy.argsort()
    return x_train[order], y_train[order]

# Order training samples from lowest to highest teacher-prediction entropy.
# NOTE(review): Keras Model.fit shuffles training data by default (shuffle=True),
# so this ordering alone does not enforce a curriculum during training — confirm intent.
x_train, y_train = curriculum_sort(x_train, y_train, teacher)


# Run the architecture search. It returns a ranked collection of
# (model, accuracy, size_mb, num_params, config) tuples, best first.
top_models = search_architectures(
    x_train, y_train, x_test, y_test, input_shape=(32, 32, 1), num_classes=10, teacher=teacher
)
# The plotting cell calls plot_top_models(best_model), so the full ranked
# collection is also kept under that name.
best_model = top_models

if top_models:
    _, best_accuracy, best_size, best_params, best_config = top_models[0]
    print("Best model found:")
    print(f"Size: {best_size:.2f} MB")
    print(f"Parameters: {best_params}")
    print(f"Accuracy: {best_accuracy:.4f}")
    print(f"Configuration: {best_config}")
else:
    print("No architecture passed the quick-evaluation threshold.")

# Visualisation Function

This function creates a scatter plot of the top models, comparing their accuracy, number of parameters, and model size. It also prints detailed information about each top model.

In [None]:
def plot_top_models(top_models):
    """Plot and print a summary of ranked models.

    Each entry must be a ``(model, accuracy, size_mb, num_params, config)`` tuple,
    as produced by ``search_architectures``. The scatter plot shows accuracy against
    parameter count (log scale), coloured by model size. Entries are labelled by
    rank, then detailed per-model information is printed.

    If ``top_models`` is empty, a message is printed and nothing is plotted.
    """
    entries = list(top_models)  # iterate once, so generators work too
    if not entries:
        print("No models to plot.")
        return

    accuracies = [entry[1] for entry in entries]
    sizes = [entry[2] for entry in entries]
    params = [entry[3] for entry in entries]

    plt.figure(figsize=(10, 6))
    scatter = plt.scatter(params, accuracies, c=sizes, s=100, cmap='viridis')
    plt.colorbar(scatter, label='Model Size (MB)')

    plt.xscale('log')
    plt.xlabel('Number of Parameters')
    plt.ylabel('Accuracy')
    plt.title(f'Top {len(entries)} Models: Accuracy vs Number of Parameters')

    for rank, (n_params, accuracy) in enumerate(zip(params, accuracies), 1):
        plt.annotate(f"Model {rank}", (n_params, accuracy), xytext=(5, 5), textcoords='offset points')

    plt.tight_layout()
    plt.show()

    # Print detailed information about each model
    for rank, (_, accuracy, size, num_params, config) in enumerate(entries, 1):
        print(f"Model {rank}:")
        print(f"  Accuracy: {accuracy:.4f}")
        print(f"  Size: {size:.2f} MB")
        print(f"  Parameters: {num_params}")
        print(f"  Configuration: {config}")
        print()


# Plot the ranked models produced by the architecture search.
# NOTE(review): plot_top_models expects the full ranked collection of
# (model, accuracy, size_mb, num_params, config) tuples, so `best_model` must be
# bound to that collection (not to a single model) — confirm in the search cell.

# Plot the results
plot_top_models(best_model)