In [15]:
from google.colab import drive
import os
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot
import numpy as np

# Mount Google Drive; the model files used later in this notebook are
# read from and written to paths under this mount point.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Loading and Preprocessing Data
We load and preprocess the CIFAR-10 dataset: images are resized and normalized so that every model is trained and evaluated on data in the same format. The cell below defines `load_or_train_model`, a small caching helper: if a saved model exists at `model_path` it is loaded, otherwise a new model is trained and saved there.

In [29]:
import tensorflow_model_optimization as tfmot

def load_or_train_model(train_ds, test_ds, model_path, is_pruned=False):
    """Load a cached Keras model from ``model_path``, or train and cache a new one.

    Args:
        train_ds: Training dataset, forwarded to the training helpers.
        test_ds: Validation dataset, forwarded to the training helpers.
        model_path: File the model is loaded from if it exists, and saved to
            after training otherwise.
        is_pruned: If True, the model is a pruned model. An existing file is
            loaded inside ``tfmot.sparsity.keras.prune_scope()`` so its pruning
            wrapper layers can be deserialized. A newly trained model is pruned
            with ``apply_pruning_to_model`` before being saved.

    Returns:
        The loaded or newly trained ``tf.keras.Model``.
    """
    if os.path.exists(model_path):
        print("Loading saved model...")
        if is_pruned:
            # Pruned models contain pruning wrapper layers that load_model can
            # only resolve inside prune_scope.
            with tfmot.sparsity.keras.prune_scope():
                return tf.keras.models.load_model(model_path)
        return tf.keras.models.load_model(model_path)

    print("Training new model...")
    model = create_and_train_model(train_ds, test_ds)
    if is_pruned:
        # Without this step the file written to model_path would be an ordinary,
        # unpruned model and the "pruned" comparison would be meaningless.
        model = apply_pruning_to_model(model, train_ds, test_ds)
    model.save(model_path)
    return model


# Creating and Training the Model
Here we create a MobileNetV2 model pretrained on ImageNet and use transfer learning for better performance. The pretrained base layers are frozen to preserve their learned features. A new classification head (global average pooling followed by a 10-way softmax layer) is added for the CIFAR-10 task. The function then trains the model on the prepared dataset.

In [17]:
def create_and_train_model(train_ds, test_ds, epochs=5, num_classes=10, input_shape=(160, 160, 3)):
    """Build a MobileNetV2 transfer-learning classifier and train its new head.

    The ImageNet-pretrained MobileNetV2 backbone is frozen. Only the added head
    (global average pooling + softmax Dense layer) is trained.

    Args:
        train_ds: Training dataset for ``model.fit``. Labels must be integer
            class ids, because the loss is sparse categorical cross-entropy.
        test_ds: Validation dataset for ``model.fit``.
        epochs: Number of training epochs (default 5).
        num_classes: Number of output classes of the softmax head
            (default 10, i.e. CIFAR-10).
        input_shape: Backbone input shape (default 160x160 RGB). It must match
            the image size produced by the data pipeline.

    Returns:
        The compiled and trained ``tf.keras.Sequential`` model.

    NOTE(review): MobileNetV2's ImageNet weights expect inputs scaled by
    ``tf.keras.applications.mobilenet_v2.preprocess_input`` (range [-1, 1]).
    Confirm the data pipeline normalizes to that range rather than to [0, 1].
    """
    base_model = tf.keras.applications.MobileNetV2(input_shape=input_shape, include_top=False, weights='imagenet')
    # Freeze the pretrained backbone so only the new head is trained.
    base_model.trainable = False

    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])

    # Sparse loss: labels are integer class ids, not one-hot vectors.
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    model.fit(train_ds, epochs=epochs, validation_data=test_ds)
    return model

# Converting to TensorFlow Lite with Quantization
In this section we convert the trained model to TensorFlow Lite format and apply post-training quantization (`tf.lite.Optimize.DEFAULT`). Quantization reduces model size and can speed up inference, which makes the model suitable for deployment on resource-constrained devices.

In [18]:
def convert_to_tflite(model):
    """Convert a Keras model into a quantized TensorFlow Lite flatbuffer.

    Args:
        model: A trained ``tf.keras.Model``.

    Returns:
        The serialized TFLite model, suitable for
        ``tf.lite.Interpreter(model_content=...)``.
    """
    tflite_converter = tf.lite.TFLiteConverter.from_keras_model(model)
    # Optimize.DEFAULT enables post-training quantization. No representative
    # dataset is supplied here, so only dynamic-range quantization applies.
    tflite_converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_flatbuffer = tflite_converter.convert()
    return tflite_flatbuffer

# Pruning the Model
We now apply magnitude-based pruning, which zeroes out the lowest-magnitude weights to reduce the model's effective size and complexity. A polynomial-decay pruning schedule controls how much of the model is pruned over time. By default, sparsity ramps from 0% to 50% over the first 1,000 training steps. The schedule balances model size against accuracy.

In [19]:
def apply_pruning_to_model(model, train_ds, test_ds, final_sparsity=0.5, end_step=1000, epochs=5):
    """Wrap ``model`` for magnitude-based pruning and fine-tune it.

    Sparsity follows a polynomial-decay schedule from 0% at step 0 to
    ``final_sparsity`` at ``end_step``.

    Args:
        model: Trained ``tf.keras.Model`` to prune.
        train_ds: Training dataset used for the pruning fine-tune.
        test_ds: Validation dataset passed to ``fit``.
        final_sparsity: Target fraction of zeroed weights (default 0.5).
        end_step: Training step at which ``final_sparsity`` is reached
            (default 1000). NOTE(review): choose this relative to
            ``epochs * steps_per_epoch``. A fixed value may be reached early in
            the first epoch, or not at all, depending on dataset and batch size.
        epochs: Number of fine-tuning epochs (default 5).

    Returns:
        The pruned model, still containing pruning wrapper layers. Load a saved
        copy inside ``tfmot.sparsity.keras.prune_scope()``, and strip the
        wrappers (``tfmot.sparsity.keras.strip_pruning``) before export.

    NOTE(review): if the backbone was frozen (create_and_train_model sets
    ``trainable = False``), the pruned backbone weights may get no fine-tuning
    to recover accuracy. Confirm, and consider unfreezing before pruning.
    """
    pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0,
        final_sparsity=final_sparsity,
        begin_step=0,
        end_step=end_step,
    )
    pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, pruning_schedule=pruning_schedule)

    # prune_low_magnitude returns a new model, which must be compiled before fit.
    pruned_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    pruned_model.fit(
        train_ds,
        epochs=epochs,
        validation_data=test_ds,
        # Required by tfmot to advance the pruning schedule every training step.
        callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
    )
    return pruned_model

# Evaluating the TensorFlow Lite Model
We set up a TensorFlow Lite interpreter and use it to evaluate the quantized model. Each test image is run through the interpreter individually, and the model's top-1 accuracy is measured.

In [31]:
def evaluate_tflite_model(interpreter, test_ds):
    """Compute the top-1 accuracy of a TFLite model on ``test_ds``.

    Examples are fed one at a time, because the interpreter input has a batch
    dimension of 1. Each image is resized to the interpreter's input
    height/width and cast to its input dtype. If the input tensor is
    integer-quantized, the image is first quantized with the tensor's
    (scale, zero_point).

    Args:
        interpreter: A ``tf.lite.Interpreter``; its tensors are allocated here.
        test_ds: Batched ``tf.data.Dataset`` of (image, label) pairs. It is
            unbatched internally. Labels are integer class ids.

    Returns:
        Fraction of correctly classified examples, as a Python float.

    Raises:
        ValueError: If a preprocessed image does not match the interpreter's
            input shape, or if ``test_ds`` yields no examples.
    """
    interpreter.allocate_tensors()
    input_detail = interpreter.get_input_details()[0]
    output_index = interpreter.get_output_details()[0]['index']

    # Loop-invariant input properties, read once instead of per example.
    expected_shape = input_detail['shape']
    height, width = int(expected_shape[1]), int(expected_shape[2])
    input_dtype = input_detail['dtype']
    input_scale, input_zero_point = input_detail['quantization']
    # Float inputs report scale 0.0; only integer inputs with a real scale need quantizing.
    quantize_input = np.issubdtype(input_dtype, np.integer) and input_scale != 0
    dtype_limits = np.iinfo(input_dtype) if quantize_input else None

    total_seen, total_correct = 0, 0
    for img, label in test_ds.unbatch():
        img = tf.image.resize(img, [height, width])
        img = tf.expand_dims(img, axis=0)
        if quantize_input:
            img = tf.clip_by_value(
                tf.round(img / input_scale + input_zero_point),
                float(dtype_limits.min),
                float(dtype_limits.max),
            )
        img = tf.cast(img, input_dtype)

        if tuple(img.shape) != tuple(expected_shape):
            raise ValueError(f"Expected input shape {expected_shape}, but got {img.shape}")

        interpreter.set_tensor(input_detail['index'], img.numpy())
        interpreter.invoke()
        output_data = interpreter.get_tensor(output_index)
        # argmax is unaffected by (positive-scale) output dequantization.
        predicted = int(np.argmax(output_data, axis=1)[0])
        # Flatten so scalar and shape-(1,) labels both compare as a plain int.
        true_label = int(np.asarray(label).reshape(-1)[0])

        total_seen += 1
        total_correct += int(predicted == true_label)

    if total_seen == 0:
        raise ValueError("test_ds contained no examples; cannot compute accuracy")
    return total_correct / total_seen

# Final Evaluation and Comparison
Finally, we evaluate and compare the original, quantized, and pruned models. This comparison helps show the trade-offs between model size, speed, and accuracy. Note that the cells below report accuracy only; model size and inference speed are not measured here.

In [21]:
def load_or_train_model(train_ds, test_ds, model_path, is_pruned=False):
    """Load a cached Keras model from ``model_path``, or train and cache a new one.

    NOTE: this cell redefines ``load_or_train_model``. Its signature must
    include ``is_pruned``, because the main execution cell calls it with
    ``is_pruned=True``. A definition without that parameter makes a
    top-to-bottom Run All fail with ``TypeError``. Prefer keeping a single
    definition in the notebook.

    Args:
        train_ds: Training dataset, forwarded to the training helpers.
        test_ds: Validation dataset, forwarded to the training helpers.
        model_path: File the model is loaded from if it exists, and saved to
            after training otherwise.
        is_pruned: If True, the model is a pruned model. An existing file is
            loaded inside ``tfmot.sparsity.keras.prune_scope()`` so its pruning
            wrapper layers can be deserialized. A newly trained model is pruned
            with ``apply_pruning_to_model`` before being saved.

    Returns:
        The loaded or newly trained ``tf.keras.Model``.
    """
    if os.path.exists(model_path):
        print("Loading saved model...")
        if is_pruned:
            # Pruned models contain pruning wrapper layers that load_model can
            # only resolve inside prune_scope.
            with tfmot.sparsity.keras.prune_scope():
                return tf.keras.models.load_model(model_path)
        return tf.keras.models.load_model(model_path)

    print("Training new model...")
    model = create_and_train_model(train_ds, test_ds)
    if is_pruned:
        # Without this step the file written to model_path would be an ordinary,
        # unpruned model and the "pruned" comparison would be meaningless.
        model = apply_pruning_to_model(model, train_ds, test_ds)
    model.save(model_path)
    return model

In [32]:
# Specify paths in Google Drive for models
# Each path acts as a cache: load_or_train_model loads the file if it exists,
# otherwise it trains a model and saves it to that path.
model_path = '/content/drive/My Drive/EdgeAI/model.h5'
pruned_model_path = '/content/drive/My Drive/EdgeAI/pruned_model.h5'

# Main execution flow
# NOTE(review): load_and_preprocess_data is not defined in the cells shown
# here. TODO: confirm it is defined (and executed) earlier in the notebook.
train_ds, test_ds = load_and_preprocess_data()
model = load_or_train_model(train_ds, test_ds, model_path)
# Quantized TFLite flatbuffer built from the baseline model.
tflite_model_quant = convert_to_tflite(model)
# NOTE(review): requires the load_or_train_model definition that accepts
# is_pruned to be the one in effect when this cell runs.
model_for_pruning = load_or_train_model(train_ds, test_ds, pruned_model_path, is_pruned=True)

interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)
quant_accuracy = evaluate_tflite_model(interpreter, test_ds)

# evaluate() presumably returns [loss, accuracy], since the models were
# compiled with metrics=['accuracy']; index 1 below is read as the accuracy.
original_eval = model.evaluate(test_ds)
pruned_eval = model_for_pruning.evaluate(test_ds)

print(f"Original Model Accuracy: {original_eval[1]}")
print(f"Quantized Model Accuracy: {quant_accuracy}")
print(f"Pruned Model Accuracy: {pruned_eval[1]}")

Loading saved model...
Loading saved model...
Original Model Accuracy: 0.8147000074386597
Quantized Model Accuracy: 0.7944
Pruned Model Accuracy: 0.538100004196167
