# Neural Architecture Search with Knowledge Distillation for ACAM

## Overview

This notebook implements a Neural Architecture Search (NAS) framework combined with Knowledge Distillation, aimed at finding efficient model architectures for TinyML applications. The primary goal is to discover compact, high-performing neural network architectures that can be deployed on resource-constrained devices.

Key components and techniques:

1. **Neural Architecture Search (NAS)**: Systematically explores various model architectures, focusing on configurations suitable for small devices.

2. **Knowledge Distillation**: Utilises a larger, pre-trained teacher model to guide the training of smaller student models, potentially improving their performance.

3. **Model Pruning**: Applies network pruning to reduce model size and potentially improve generalization.

4. **Quantization**: Implements quantization to further reduce model size and improve inference speed.

5. **CIFAR-10 Dataset**: Uses the CIFAR-10 dataset (converted to grayscale) as a benchmark for evaluating model performance.

6. **Efficiency Metrics**: Considers both model accuracy and size to identify the best architectures for TinyML applications.



# Setup and Imports

This cell imports the necessary libraries and modules for our neural architecture search (NAS) with knowledge distillation.

In [None]:
import tensorflow as tf
from tensorflow import keras
import numpy as np
import itertools
import tempfile
import os
import tensorflow_model_optimization as tfmot
from sklearn.model_selection import train_test_split
from collections import deque
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
from joblib import Parallel, delayed
from collections import defaultdict


# Distiller Class

This class implements the knowledge distillation process. It combines the loss from the student model's predictions and the distillation loss from comparing the student's softened predictions to the teacher's softened predictions.

In [None]:
# Distiller class
class Distiller(keras.Model):
    """Train a student network against the outputs of a fixed teacher network.

    Only the student's variables are updated. Each training step mixes two
    terms:

    * the student loss between the targets and the student's raw outputs;
    * the distillation loss between temperature-softened teacher and student
      outputs, scaled by ``temperature ** 2``.

    No ``test_step`` is overridden in this class, so evaluation/validation
    uses whatever ``keras.Model`` does by default.
    """

    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(self, optimizer, metrics, student_loss_fn, distillation_loss_fn, alpha=0.1, temperature=3):
        """Store the distillation settings on top of the regular Keras compile.

        Args:
            optimizer: Optimizer applied to the student's trainable variables.
            metrics: Metrics updated with (targets, student outputs).
            student_loss_fn: Loss between targets and student outputs.
            distillation_loss_fn: Loss between softened teacher and student outputs.
            alpha: Weight of the student loss; the distillation loss gets ``1 - alpha``.
            temperature: Divisor applied to both outputs before the softmax.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        """Run one optimisation step on the student and report the metrics."""
        inputs, targets = data

        # The teacher runs in inference mode, outside the tape.
        teacher_out = self.teacher(inputs, training=False)

        with tf.GradientTape() as tape:
            student_out = self.student(inputs, training=True)
            hard_loss = self.student_loss_fn(targets, student_out)

            # The softmax is applied to whatever the two models emit, divided
            # by the temperature.
            soft_teacher = tf.nn.softmax(teacher_out / self.temperature, axis=1)
            soft_student = tf.nn.softmax(student_out / self.temperature, axis=1)
            soft_loss = self.distillation_loss_fn(soft_teacher, soft_student) * (self.temperature ** 2)

            total_loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss

        student_vars = self.student.trainable_variables
        grads = tape.gradient(total_loss, student_vars)
        self.optimizer.apply_gradients(zip(grads, student_vars))

        self.compiled_metrics.update_state(targets, student_out)
        results = {}
        for metric in self.metrics:
            results[metric.name] = metric.result()
        return results

    def call(self, inputs, training=False):
        """Forward pass: delegate to the student."""
        return self.student(inputs, training=training)



# Search Space Definition

Here we define the possible layers and parameters that our neural architecture search will explore.

In [None]:
# Define the search space
# NOTE(review): create_model does not read these names and
# search_architectures defines its own option lists, so within this notebook
# these module-level values appear to be informational only — confirm before
# relying on them.
layer_types = ['Conv2D', 'Dense', 'Flatten', 'MaxPool2D', 'GlobalAveragePooling2D']  # layer kinds handled by create_model
conv_filters = [16, 32, 64]  # candidate Conv2D filter counts
dense_units = [64, 128, 256]  # candidate Dense layer widths
max_layers = 5



# Model Creation Function

This function creates a Keras model based on a given configuration. It ensures that the model structure is valid, handling the transition from convolutional to dense layers appropriately.

In [None]:
# Function to create a model given a configuration
def create_model(config, input_shape, num_classes, add_softmax=False):
    """Build a ``keras.Sequential`` model from a list of layer specs.

    Args:
        config: Sequence of ``(layer_type, parameter)`` entries. The LAST
            entry is never built (the loop runs over ``config[:-1]``).
            NOTE(review): callers that do not append a dedicated output entry
            lose their final real layer — confirm this is intended.
        input_shape: Shape passed to ``keras.Input``.
        num_classes: Width of the optional softmax output layer.
        add_softmax: When true, append ``Dense(num_classes, softmax)``. When
            false, no output layer is appended and the model ends at the last
            layer built from ``config``.

    Returns:
        The uncompiled Sequential model.

    Layer handling: once the tensor has been flattened (by a Dense, Flatten
    or GlobalAveragePooling2D entry), later Conv2D, MaxPool2D, Flatten and
    GlobalAveragePooling2D entries are silently ignored. Unrecognised layer
    types are ignored as well.
    """
    net = keras.Sequential()
    net.add(keras.Input(shape=input_shape))

    flattened = False
    for spec in config[:-1]:  # every entry except the last one
        kind = spec[0]

        if kind == 'Dense':
            # Dense layers need a flat input; flatten on first use.
            if not flattened:
                net.add(keras.layers.Flatten())
                flattened = True
            net.add(keras.layers.Dense(spec[1], activation='relu'))
            continue

        if flattened:
            # Spatial layers (and a redundant Flatten) make no sense after flattening.
            continue

        if kind == 'Conv2D':
            net.add(keras.layers.Conv2D(spec[1], (3, 3), activation='relu', padding='same'))
        elif kind == 'MaxPool2D':
            net.add(keras.layers.MaxPool2D((2, 2)))
        elif kind == 'Flatten':
            net.add(keras.layers.Flatten())
            flattened = True
        elif kind == 'GlobalAveragePooling2D':
            net.add(keras.layers.GlobalAveragePooling2D())
            flattened = True

    # Ensure the model is flattened before the (optional) final layer
    if not flattened:
        net.add(keras.layers.Flatten())

    # Final layer: only the softmax variant gets a classification head.
    if add_softmax:
        net.add(keras.layers.Dense(num_classes, activation='softmax'))

    return net



# ACAM functions

In [None]:
# ACAM functions
def normalize_similarity(similarities):
    """Min-max scale every entry's 'similarity' into 'normalized_similarity'.

    Args:
        similarities: Mapping of template key -> dict holding a 'similarity'
            value. The inner dicts are updated in place.

    Returns:
        The same mapping, with 'normalized_similarity' added to each entry.
        An empty mapping is returned unchanged. When all similarities are
        equal (including the single-entry case) the range is zero, so every
        normalized value is set to 0.0 instead of dividing by zero.
    """
    if not similarities:
        return similarities

    raw_values = [entry['similarity'] for entry in similarities.values()]
    lowest = min(raw_values)
    span = max(raw_values) - lowest

    for entry in similarities.values():
        if span:
            entry['normalized_similarity'] = (entry['similarity'] - lowest) / span
        else:
            # Degenerate range: nothing distinguishes the entries.
            entry['normalized_similarity'] = 0.0
    return similarities

def acam_match_window_search(templates, query):
    """Score a query vector against the [lower, upper] window of every template.

    Args:
        templates: Mapping of key -> dict with 'lower_bounds' and
            'upper_bounds' (array-likes broadcastable against ``query``).
            Other keys in the template dict are not read.
        query: Array-like feature vector.

    Returns:
        Dict keyed like ``templates``; each value holds
        'distance' (sum of squared amounts by which the query falls outside
        the window), 'similarity' (``1 / (1 + distance)``) and 'hit_ratio'
        (fraction of elements inside the window, bounds inclusive).
    """
    query = np.asarray(query)
    results = {}
    for key, template in templates.items():
        lower = np.asarray(template["lower_bounds"])
        upper = np.asarray(template["upper_bounds"])

        inside = (query >= lower) & (query <= upper)
        hit_ratio = np.mean(inside)

        # Only the part of the query that sticks out of the window counts.
        shortfall = np.clip(lower - query, 0, None)
        excess = np.clip(query - upper, 0, None)
        distance = np.sum(np.square(shortfall) + np.square(excess))

        results[key] = {
            "similarity": 1 / (1 + distance),
            "distance": distance,
            "hit_ratio": hit_ratio,
        }

    return results

def generate_binary_templates_with_bounds(feature_maps, labels, num_classes, bound_width=0.5, threshold=0.0):
    """Build one binary template (center plus tolerance window) per class.

    Args:
        feature_maps: Array-like of per-sample feature vectors; values above
            ``threshold`` become 1, the rest 0.
        labels: Per-sample label rows; the class is taken as ``argmax`` along
            axis 1.
        num_classes: Class indices ``0 .. num_classes - 1`` are considered.
        bound_width: Multiplier on the per-feature standard deviation that
            sets the half-width of the window around the center.
        threshold: Binarisation threshold.

    Returns:
        Dict keyed by ``(class_idx, 0)`` with 'center' (rounded per-feature
        mean of the binary maps), 'lower_bounds' and 'upper_bounds' (both
        clipped to [0, 1]). Classes without samples get no entry.
    """
    binary_maps = (np.asarray(feature_maps) > threshold).astype(int)
    # The decoded class ids do not depend on the class being processed, so
    # compute them once instead of once per class.
    class_ids = np.argmax(labels, axis=1)

    templates = {}
    for class_idx in range(num_classes):
        members = binary_maps[class_ids == class_idx]
        if len(members) == 0:
            continue

        center = np.round(np.mean(members, axis=0)).astype(int)
        half_width = bound_width * np.std(members, axis=0)

        templates[(class_idx, 0)] = {
            "center": center,
            "lower_bounds": np.clip(center - half_width, 0, 1),
            "upper_bounds": np.clip(center + half_width, 0, 1),
        }

    return templates

def extract_features(model, x):
    """Return the result of ``model.predict(x)`` unchanged."""
    predictions = model.predict(x)
    return predictions

def extract_feature_maps(interpreter, input_details, output_details, input_data):
    """Run one inference on an interpreter object and return its first output.

    Args:
        interpreter: Object exposing ``set_tensor``, ``invoke`` and
            ``get_tensor`` (presumably a TFLite interpreter — confirm with
            the caller).
        input_details: Sequence whose first item has an 'index' entry naming
            the input tensor.
        output_details: Sequence whose first item has an 'index' entry naming
            the output tensor.
        input_data: Value written to the input tensor.

    Returns:
        Whatever ``interpreter.get_tensor`` returns for the output index.
    """
    # The input 'shape' entry was previously read into an unused local; only
    # the tensor indices are needed.
    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    return interpreter.get_tensor(output_details[0]['index'])

def feature_count_matching(templates, query):
    """Count, per template, how many query elements equal the template center.

    Returns a dict keyed like ``templates``; each value holds the match
    'count' and the 'class', taken from the first element of the key.
    """
    return {
        key: {"count": np.sum(query == template["center"]), "class": key[0]}
        for key, template in templates.items()
    }

def extract_feature_map_parallel(model, x_train_sample, idx):
    """Log progress and extract features for one training sample.

    The sample gets a leading batch axis before it is handed to
    ``extract_features``.

    NOTE(review): the progress message reads ``x_train`` from module scope
    rather than from an argument, so the function only works where that
    global exists — confirm this is acceptable for worker processes.
    """
    print(f"Extracting feature map for training sample {idx + 1}/{len(x_train)}")
    single_sample_batch = x_train_sample[np.newaxis, ...]
    return extract_features(model, single_sample_batch)

def binarize_features(features, threshold=0.0):
    """Return an integer array with 1 where ``features > threshold``, else 0.

    ``features`` may be any array-like (NumPy array, nested list, ...); it is
    converted with ``np.asarray`` so plain Python sequences work too.
    """
    return (np.asarray(features) > threshold).astype(int)

def pattern_match(binary_features, templates):
    """Return the class of the template most similar to ``binary_features``.

    Similarities come from ``acam_match_window_search``; the class is the
    first element of the winning template key.
    """
    scores = acam_match_window_search(templates, binary_features)

    def similarity_of(key):
        return scores[key]['similarity']

    best_key = max(scores, key=similarity_of)
    return best_key[0]

def evaluate_with_pattern_matching(model, templates, x_val, y_val):
    """Return the fraction of validation samples classified correctly.

    Each sample is pushed through ``model`` on its own (batch of one), its
    features are binarised, and the best-matching template gives the
    predicted class. The true class is the ``argmax`` of the label row.
    """
    hits = 0
    for sample, target in zip(x_val, y_val):
        batch_features = extract_features(model, sample[np.newaxis, ...])
        binary = binarize_features(batch_features[0])
        guess = pattern_match(binary, templates)
        if guess == np.argmax(target):
            hits += 1
    return hits / len(y_val)

def quick_evaluate(model, x_train, y_train, x_val, y_val, epochs=5, patience=2):
    """Briefly fit ``model`` and return its final validation accuracy.

    Training stops early when 'val_accuracy' has not improved for
    ``patience`` epochs. The value returned is the 'val_accuracy' of the last
    epoch that ran, which is not necessarily the best epoch.
    """
    stopper = keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=patience)
    fit_options = {
        'epochs': epochs,
        'validation_data': (x_val, y_val),
        'callbacks': [stopper],
        'verbose': 0,
    }
    run = model.fit(x_train, y_train, **fit_options)
    val_accuracy_curve = run.history['val_accuracy']
    return val_accuracy_curve[-1]



# Neural Architecture Search Function

This function performs a comprehensive neural architecture search with knowledge distillation, pruning, and quantization.

1. Search Space Definition:
   - Defines the range of architectural choices (e.g., number of convolutional blocks, layers per block, initial filters, dense layers).

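2. Data Splitting:
   - The search function does not carve out a smaller subset itself: each candidate is quick-trained on the training set passed in and validated on the test set passed in. The later training stages hold out 20% of the training data via `validation_split=0.2`.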

3. Architecture Generation:
   - Iterates through different combinations of architectural parameters to generate various model configurations.

4. Quick Evaluation:
   - For each generated architecture, performs a rapid evaluation using a subset of the data.
   - This step helps to quickly filter out poorly performing architectures without spending time on full training.

5. Knowledge Distillation:
   - For promising architectures (those passing the quick evaluation threshold), applies knowledge distillation.
   - Uses a pre-trained teacher model to guide the training of the student (generated) model.
   - Combines a student loss on the true labels with a distillation loss that encourages the student to mimic the teacher's softened outputs. In this notebook both losses are configured as mean squared error rather than cross-entropy.

6. Pruning:
   - Applies network pruning to reduce model size and potentially improve generalization.
   - Uses polynomial decay pruning schedule, gradually increasing sparsity over time.

7. Quantization:
   - Applies quantization to further reduce model size and improve inference speed.
   - Converts the model to TensorFlow Lite format with default optimizations.

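8. Final Evaluation:
   - Builds per-class binary templates from the pruned, quantization-aware model's outputs on the training set, then classifies the test set by ACAM pattern matching (similarity-based and feature-count-based) to get the final accuracy.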

9. Model Tracking:
   - Keeps track of the top 5 models based on accuracy and model size.

In [None]:
def search_architectures(x_train, y_train, x_test, y_test, input_shape, num_classes, teacher):
    """Grid-search student architectures and rank them by ACAM matching accuracy.

    For every point of the search grid the function:

    1. builds a conv/pool/dense layer configuration and quick-trains a
       softmax version of it with ``quick_evaluate`` (validated on
       ``x_test``/``y_test``); candidates below 0.7 are skipped;
    2. rebuilds the model without the softmax head, trains it, and distils
       it against ``teacher``;
    3. prunes it, runs quantization-aware training, and converts it to
       TFLite to measure the serialized size;
    4. builds per-class binary templates from the model outputs on the
       training set and classifies the test set with both similarity-based
       and feature-count-based matching.

    Args:
        x_train, y_train: Training inputs and label rows (``argmax`` gives the class).
        x_test, y_test: Evaluation inputs and label rows.
        input_shape: Input shape handed to ``create_model``.
        num_classes: Number of classes.
        teacher: Trained teacher model used for distillation.

    Returns:
        A deque of at most five tuples
        ``(model, feature_count_accuracy, size_mb, num_params, config)``,
        sorted best first (highest accuracy, ties broken by smaller size).

    NOTE(review): ``create_model`` skips the last entry of ``config`` and no
    dedicated output entry is appended here, so the last layer listed in the
    printed configuration is not part of the built model — confirm intent.
    """
    best_models = deque(maxlen=5)

    # Define the search space
    conv_blocks = [1, 2, 3]
    conv_layers_per_block = [1, 2]
    initial_filters = [16, 32, 64]
    dense_layers = [0, 1, 2, 3]
    dense_units_options = [64, 128, 256, 512]

    # itertools.product visits the grid in the same order as five nested loops.
    search_grid = list(itertools.product(
        conv_blocks, conv_layers_per_block, initial_filters, dense_layers, dense_units_options
    ))
    total_configs = len(search_grid)

    # The pruning schedule counts optimizer steps (batches), not samples, so
    # derive end_step from the batches the pruning fit will actually run.
    # Otherwise the schedule ends long after training and the final sparsity
    # is never reached.
    validation_split = 0.2
    pruning_epochs = 1
    pruning_batch_size = 32
    pruning_fit_samples = int(len(x_train) * (1 - validation_split))
    pruning_end_step = max(
        1, int(np.ceil(pruning_fit_samples / pruning_batch_size)) * pruning_epochs
    )

    for config_count, grid_point in enumerate(search_grid, 1):
        n_blocks, layers_per_block, init_filters, n_dense, dense_units = grid_point
        print(f"\nTesting configuration {config_count}/{total_configs}")

        # Build the model configuration
        config = []
        filters = init_filters
        for _ in range(n_blocks):
            config.extend([('Conv2D', filters)] * layers_per_block)
            config.append(('MaxPool2D', None))
            filters *= 2

        config.append(('GlobalAveragePooling2D', None))

        units = dense_units
        for _ in range(n_dense):
            config.append(('Dense', units))
            units = max(units // 2, 64)  # halve the width, but never below 64

        try:
            # Create model with softmax for quick evaluation
            student = create_model(config, input_shape, num_classes, add_softmax=True)
            student.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
        except Exception as exc:  # keep searching, but do not swallow Ctrl-C / SystemExit
            print("Failed to create model with this configuration. Skipping...")
            print(f"  Reason: {exc}")
            continue

        # Quick evaluation with softmax
        quick_accuracy = quick_evaluate(student, x_train, y_train, x_test, y_test)
        print(f"Quick evaluation accuracy: {quick_accuracy:.4f}")

        if quick_accuracy < 0.7:
            print("Accuracy below threshold. Skipping...")
            continue

        # Create model without softmax for further processing
        student = create_model(config, input_shape, num_classes, add_softmax=False)
        student.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
        student.fit(x_train, y_train, epochs=10, validation_split=0.2, verbose=0)

        # Apply knowledge distillation
        distiller = Distiller(student=student, teacher=teacher)
        distiller.compile(
            optimizer=keras.optimizers.Adam(),
            metrics=[keras.metrics.MeanSquaredError()],
            student_loss_fn=keras.losses.MeanSquaredError(),
            distillation_loss_fn=keras.losses.MeanSquaredError(),
            alpha=0.1,
            temperature=3,
        )
        distiller.fit(x_train, y_train, epochs=10, validation_split=0.2, verbose=0)

        # Apply pruning
        pruning_params = {
            'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
                initial_sparsity=0.30,
                final_sparsity=0.50,
                begin_step=0,
                end_step=pruning_end_step,
            )
        }
        pruned_model = tfmot.sparsity.keras.prune_low_magnitude(distiller.student, **pruning_params)
        pruned_model.compile(
            optimizer=keras.optimizers.Adam(),
            loss=keras.losses.MeanSquaredError(),
            metrics=[keras.metrics.MeanSquaredError()],
        )
        with tf.device('/cpu:0'):
            pruned_model.fit(
                x_train, y_train,
                epochs=pruning_epochs,
                batch_size=pruning_batch_size,
                validation_split=validation_split,
                callbacks=[
                    tfmot.sparsity.keras.UpdatePruningStep(),
                    tfmot.sparsity.keras.PruningSummaries(log_dir=tempfile.mkdtemp()),
                ],
                verbose=0,
            )

        # Apply quantization-aware training
        stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)
        q_aware_model = tfmot.quantization.keras.quantize_model(stripped_pruned_model)
        q_aware_model.compile(
            optimizer=keras.optimizers.Adam(),
            loss=keras.losses.MeanSquaredError(),
            metrics=[keras.metrics.MeanSquaredError()],
        )
        q_aware_model.fit(x_train, y_train, epochs=5, validation_split=0.2, verbose=0)

        # Generate templates and evaluate using pattern matching
        print("Generating templates...")

        print("Extracting feature maps for all training samples...")
        train_feature_maps = q_aware_model.predict(x_train)
        train_feature_maps = train_feature_maps.reshape(len(x_train), -1)
        print("Feature map extraction completed.")

        def process_feature_map(feature_map, idx):
            print(f"Processing feature map for training sample {idx + 1}/{len(x_train)}")
            return feature_map  # You can add any additional processing here if needed

        # NOTE(review): this parallel pass currently returns each feature map
        # unchanged; it is kept as the hook for future per-sample processing.
        print("Starting parallel processing of feature maps...")
        processed_feature_maps = Parallel(n_jobs=-1)(
            delayed(process_feature_map)(train_feature_maps[i], i) for i in range(len(x_train))
        )
        processed_feature_maps = np.array(processed_feature_maps)
        print("Parallel processing of feature maps completed.")

        templates = generate_binary_templates_with_bounds(
            processed_feature_maps, y_train, num_classes, bound_width=0.5
        )

        print("Evaluating with pattern matching...")

        correct_predictions_similarity = 0
        correct_predictions_count = 0
        num_test_samples = len(x_test)

        print("Starting testing on test samples...")
        for i, test_sample in enumerate(x_test):
            test_sample_feature_map = extract_features(q_aware_model, test_sample[np.newaxis, ...])
            test_sample_vector = test_sample_feature_map.flatten()

            # Binarize the test sample
            binary_test_sample = (test_sample_vector > 0).astype(int)

            # Similarity-based method
            similarities = acam_match_window_search(templates, binary_test_sample)
            predicted_label_similarity = max(
                similarities, key=lambda k: similarities[k]['similarity']
            )[0]

            # Feature count method
            match_counts = feature_count_matching(templates, binary_test_sample)
            predicted_template_count = max(match_counts, key=lambda k: match_counts[k]['count'])
            predicted_label_count = match_counts[predicted_template_count]['class']

            true_label = y_test[i].argmax()

            if predicted_label_similarity == true_label:
                correct_predictions_similarity += 1

            if predicted_label_count == true_label:
                correct_predictions_count += 1

            print(f"Test sample {i + 1}/{num_test_samples}:")
            print(f"True label: {true_label}")
            print(f"Predicted label (similarity): {predicted_label_similarity}")
            print(f"Predicted label (feature count): {predicted_label_count}")

        accuracy_similarity = correct_predictions_similarity / num_test_samples
        accuracy_count = correct_predictions_count / num_test_samples
        print("Testing completed.")
        print(f"Accuracy (similarity-based): {accuracy_similarity:.4f}")
        print(f"Accuracy (feature count-based): {accuracy_count:.4f}")

        # Convert to TFLite
        converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        quantized_tflite_model = converter.convert()

        # Save and get model size. mkstemp returns an already-open descriptor;
        # wrap it so it is closed instead of leaking one per configuration.
        tflite_fd, tflite_file = tempfile.mkstemp('.tflite')
        with os.fdopen(tflite_fd, 'wb') as f:
            f.write(quantized_tflite_model)

        model_size = os.path.getsize(tflite_file) / float(2**20)  # Size in MB
        num_params = q_aware_model.count_params()

        print(f"Configuration: {config}")
        print(f"Pattern matching accuracy: {accuracy_count:.4f}")
        print(f"Model size: {model_size:.2f} MB")
        print(f"Number of parameters: {num_params}")
        print("--------------------")

        # Rank the previous best plus the new candidate, then keep the top
        # five. Sorting before truncating means a bounded deque never evicts
        # its current best entry on append.
        new_entry = (q_aware_model, accuracy_count, model_size, num_params, config)
        ranked = sorted([*best_models, new_entry], key=lambda x: (x[1], -x[2]), reverse=True)
        best_models = deque(ranked[:5], maxlen=5)

    return best_models



# Data Preparation and Teacher Model Creation

This section loads and preprocesses the CIFAR-10 dataset, converting it to grayscale. It also defines and trains the teacher model that will be used for knowledge distillation.

In [None]:
# Data Preprocessing
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Convert images to greyscale: weighted sum of the first three channels
x_train = np.dot(x_train[..., :3], [0.2989, 0.5870, 0.1140])
x_test = np.dot(x_test[..., :3], [0.2989, 0.5870, 0.1140])

# Normalize to [0, 1] and reshape to (N, 32, 32, 1) single-channel images
x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 32, 32, 1))
x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 32, 32, 1))

# Convert labels to categorical (one-hot rows of length num_classes)
num_classes = 10
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Keep a stratified subset of each split (the first part returned by
# train_test_split); the held-out remainder is discarded.
# NOTE(review): an earlier comment said "only 1000 samples", but no
# train_size/test_size is passed here, so scikit-learn's default split
# proportions apply and the subsets are not limited to 1000 — confirm the
# intended subset size.
x_train, _, y_train, _ = train_test_split(x_train, y_train, stratify=y_train, random_state=42)
x_test, _, y_test, _ = train_test_split(x_test, y_test, stratify=y_test, random_state=42)



In [None]:
# Create the teacher model: three Conv-BN-Conv-BN-MaxPool-Dropout stages
# (32, 64, 128 filters) followed by a 512-256-128 dense stack.
# NOTE(review): the final layer applies softmax, so the teacher emits
# probabilities; Distiller.train_step divides the teacher output by the
# temperature and applies softmax again — confirm that this is intended.
teacher = keras.Sequential([
    keras.Input(shape=(32, 32, 1)),
    # Stage 1: 32 filters
    keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D((2, 2)),
    keras.layers.Dropout(0.2),
    # Stage 2: 64 filters
    keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D((2, 2)),
    keras.layers.Dropout(0.3),
    # Stage 3: 128 filters
    keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_uniform', padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D((2, 2)),
    keras.layers.Dropout(0.4),
    # Classifier head
    keras.layers.Flatten(),
    keras.layers.Dense(512, activation='relu', kernel_initializer='he_uniform'),
    keras.layers.Dense(256, activation='relu', kernel_initializer='he_uniform'),
    keras.layers.Dense(128, activation='relu', kernel_initializer='he_uniform'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10, activation='softmax'),
])

# Compile and train the teacher (20% of the training data is held out for validation)
teacher.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
teacher.fit(x_train, y_train, epochs=20, validation_split=0.2, verbose=1)



Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x155dbc03760>

In [None]:
# Run the architecture search
# best_models holds up to five (model, accuracy, size_mb, params, config)
# tuples; it stays empty when no configuration passes the quick-evaluation
# threshold inside search_architectures.
best_models = search_architectures(x_train, y_train, x_test, y_test, input_shape=(32, 32, 1), num_classes=num_classes, teacher=teacher)
# Report each retained model, numbered from 1.
for i, (model, accuracy, size, params, config) in enumerate(best_models, 1):
    print(f"Model {i}:")
    print(f"  Accuracy: {accuracy:.4f}")
    print(f"  Size: {size:.2f} MB")
    print(f"  Parameters: {params}")
    print(f"  Configuration: {config}")
    print()




Testing configuration 1/288
Quick evaluation accuracy: 0.5391
Accuracy below threshold. Skipping...

Testing configuration 2/288
Quick evaluation accuracy: 0.5215
Accuracy below threshold. Skipping...

Testing configuration 3/288
Quick evaluation accuracy: 0.5368
Accuracy below threshold. Skipping...

Testing configuration 4/288
Quick evaluation accuracy: 0.5204
Accuracy below threshold. Skipping...

Testing configuration 5/288
Quick evaluation accuracy: 0.2244
Accuracy below threshold. Skipping...

Testing configuration 6/288
Quick evaluation accuracy: 0.2093
Accuracy below threshold. Skipping...

Testing configuration 7/288
Quick evaluation accuracy: 0.2143
Accuracy below threshold. Skipping...

Testing configuration 8/288
Quick evaluation accuracy: 0.2369
Accuracy below threshold. Skipping...

Testing configuration 9/288
Quick evaluation accuracy: 0.2853
Accuracy below threshold. Skipping...

Testing configuration 10/288
Quick evaluation accuracy: 0.2845
Accuracy below threshold. S

Quick evaluation accuracy: 0.6159
Accuracy below threshold. Skipping...

Testing configuration 83/288
Quick evaluation accuracy: 0.6145
Accuracy below threshold. Skipping...

Testing configuration 84/288
Quick evaluation accuracy: 0.6197
Accuracy below threshold. Skipping...

Testing configuration 85/288
Quick evaluation accuracy: 0.3825
Accuracy below threshold. Skipping...

Testing configuration 86/288
Quick evaluation accuracy: 0.3801
Accuracy below threshold. Skipping...

Testing configuration 87/288
Quick evaluation accuracy: 0.3849
Accuracy below threshold. Skipping...

Testing configuration 88/288
Quick evaluation accuracy: 0.3627
Accuracy below threshold. Skipping...

Testing configuration 89/288
Quick evaluation accuracy: 0.4297
Accuracy below threshold. Skipping...

Testing configuration 90/288
Quick evaluation accuracy: 0.4333
Accuracy below threshold. Skipping...

Testing configuration 91/288
Quick evaluation accuracy: 0.4459
Accuracy below threshold. Skipping...

Testing c

# Visualisation Function



In [None]:
# Visualization function
def plot_top_models(top_models):
    """Scatter-plot accuracy against parameter count for the given models.

    Each entry of ``top_models`` is indexed positionally: [1] accuracy,
    [2] size in MB (used for the point colour) and [3] parameter count
    (log-scaled x axis). Points are annotated "Model 1", "Model 2", ... in
    iteration order.
    """
    accuracies, sizes, params = [], [], []
    for entry in top_models:
        accuracies.append(entry[1])
        sizes.append(entry[2])
        params.append(entry[3])

    plt.figure(figsize=(10, 6))
    points = plt.scatter(params, accuracies, c=sizes, s=100, cmap='viridis')
    plt.colorbar(points, label='Model Size (MB)')

    plt.xscale('log')
    plt.xlabel('Number of Parameters')
    plt.ylabel('Accuracy')
    plt.title('Top 5 Models: Accuracy vs Number of Parameters')

    for position, (param_count, accuracy) in enumerate(zip(params, accuracies)):
        plt.annotate(
            f"Model {position+1}",
            (param_count, accuracy),
            xytext=(5, 5),
            textcoords='offset points',
        )

    plt.tight_layout()
    plt.show()

# Plot the results (best_models is the collection returned by
# search_architectures in the cell above)
plot_top_models(best_models)



In [None]:
# Evaluate the best model
# NOTE(review): best_models[0] raises IndexError when the search kept no
# model (every configuration skipped) — guard before running this cell.
best_model = best_models[0][0]
# Templates are rebuilt here from the best model's outputs on the training set
# (generate_binary_templates_with_bounds binarises the features itself).
train_features = extract_features(best_model, x_train)
templates = generate_binary_templates_with_bounds(train_features, y_train, num_classes)

y_true = []
y_pred = []

# Classify the test samples one at a time (each predict call gets a batch of one).
for x, y in zip(x_test, y_test):
    features = extract_features(best_model, x[np.newaxis, ...])
    binary_features = binarize_features(features[0])
    predicted_class = pattern_match(binary_features, templates)
    y_true.append(np.argmax(y))
    y_pred.append(predicted_class)



In [None]:
# Confusion matrix
# Rows are true labels and columns are predicted labels, matching the axis
# labels set below; y_true / y_pred come from the evaluation cell above.
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()



In [None]:
# Classification report
# Text report for the pattern-matching predictions collected in y_true / y_pred.
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred))



In [None]:
# Print top misclassifications
def print_top_misclassifications(cm, class_names, top_n=5):
    """Print the ``top_n`` largest off-diagonal cells of a confusion matrix.

    ``cm`` is indexed as ``cm[true, predicted]`` and is treated as square
    (both indices run over ``len(cm)``). ``class_names`` maps an index to the
    label shown in the output. Ties keep row-major order.
    """
    size = len(cm)
    off_diagonal = [
        (row, col, cm[row, col])
        for row in range(size)
        for col in range(size)
        if row != col
    ]
    ranked = sorted(off_diagonal, key=lambda cell: cell[2], reverse=True)

    print(f"Top {top_n} misclassifications:")
    for true, pred, count in ranked[:top_n]:
        print(f"True: {class_names[true]}, Predicted: {class_names[pred]}, Count: {count}")

# Class labels are simply the class indices as strings; class_names is reused
# by the per-class accuracy plot in a later cell.
class_names = [str(i) for i in range(num_classes)]
print_top_misclassifications(cm, class_names)



In [None]:
# Plot accuracy by class
# Per-class accuracy = diagonal cell / row sum of the confusion matrix.
# NOTE(review): this indexes cm up to num_classes - 1 and divides by the row
# sum, so it assumes every class index is present in cm and has at least one
# true sample — confirm for small test subsets.
class_accuracy = {}
for i in range(num_classes):
    class_accuracy[i] = cm[i, i] / np.sum(cm[i])

plt.figure(figsize=(12, 6))
plt.bar(class_names, list(class_accuracy.values()))
plt.xlabel('Class')
plt.ylabel('Accuracy')
plt.title('Accuracy by Class')
plt.ylim(0, 1)
# Write the value just above each bar.
for i, v in enumerate(class_accuracy.values()):
    plt.text(i, v + 0.01, f'{v:.2f}', ha='center')
plt.tight_layout()
plt.show()



In [None]:
# Template overlap analysis
def calculate_template_overlap(templates):
    """Compute the pairwise overlap between the windows of class templates.

    Args:
        templates: Mapping keyed by ``(class_idx, 0)`` whose values hold
            'lower_bounds' and 'upper_bounds' arrays.

    Returns:
        Square matrix whose cell ``[i, j]`` is the mean, over features, of
        ``overlap / (union + 1e-10)`` between the windows of classes ``i``
        and ``j``. Rows and columns of classes without a template stay 0.
        An empty mapping yields a ``(0, 0)`` matrix.

    The matrix is sized by the largest class index present, not only by the
    number of templates. Template generation skips classes without samples,
    so sizing by ``len(templates)`` alone would leave classes above a gap
    out of the comparison.
    """
    if not templates:
        return np.zeros((0, 0))

    highest_class = max(key[0] for key in templates)
    num_classes = max(len(templates), highest_class + 1)
    overlap_matrix = np.zeros((num_classes, num_classes))

    present = [idx for idx in range(num_classes) if (idx, 0) in templates]
    for i in present:
        lower_i = templates[(i, 0)]['lower_bounds']
        upper_i = templates[(i, 0)]['upper_bounds']
        for j in present:
            lower_j = templates[(j, 0)]['lower_bounds']
            upper_j = templates[(j, 0)]['upper_bounds']

            # Length of the intersection of the two windows, per feature.
            overlap = np.maximum(0, np.minimum(upper_i, upper_j) - np.maximum(lower_i, lower_j))
            # Length of the span covered by either window, per feature.
            total_range = np.maximum(upper_i, upper_j) - np.minimum(lower_i, lower_j)

            # The epsilon avoids 0/0 when both windows are single points.
            overlap_matrix[i, j] = np.mean(overlap / (total_range + 1e-10))

    return overlap_matrix

# Heatmap of the pairwise template overlap (templates come from the
# best-model evaluation cell).
overlap_matrix = calculate_template_overlap(templates)
plt.figure(figsize=(12, 10))
sns.heatmap(overlap_matrix, annot=True, cmap='YlOrRd', vmin=0, vmax=1, fmt='.2f')
plt.title('Overlap between Class Templates')
plt.xlabel('Template Class')
plt.ylabel('Template Class')
plt.show()

# Row means include the diagonal, i.e. each class's overlap with itself.
print("Average overlap for each class:")
for i in range(len(overlap_matrix)):
    avg_overlap = np.mean(overlap_matrix[i, :])
    print(f"Class {i}: {avg_overlap:.4f}")

# Rank the distinct class pairs (i < j) by their overlap, highest first.
print("\nClasses with highest mutual overlap:")
class_pairs = [(i, j) for i in range(len(overlap_matrix)) for j in range(i+1, len(overlap_matrix))]
class_pairs.sort(key=lambda x: overlap_matrix[x[0], x[1]], reverse=True)
for i, j in class_pairs[:5]:
    print(f"Class {i} and Class {j}: {overlap_matrix[i, j]:.4f}")

avg_overlaps = np.mean(overlap_matrix, axis=1)
most_overlapping_class = np.argmax(avg_overlaps)
print(f"\nClass with highest average overlap: {most_overlapping_class}")
print(f"Average overlap: {avg_overlaps[most_overlapping_class]:.4f}")