<a href="https://colab.research.google.com/github/mett29/optimized-fashion-mnist/blob/main/fashion_mnist_solution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Import libraries

In [1]:
import tensorflow as tf
import numpy as np
import os
import time
import tempfile

!pip install tensorflow-model-optimization
import tensorflow_model_optimization as tfmot

print(f'Tensorflow version: {tf.__version__}')

Collecting tensorflow-model-optimization
[?25l  Downloading https://files.pythonhosted.org/packages/55/38/4fd48ea1bfcb0b6e36d949025200426fe9c3a8bfae029f0973d85518fa5a/tensorflow_model_optimization-0.5.0-py2.py3-none-any.whl (172kB)
[K     |██                              | 10kB 21.2MB/s eta 0:00:01[K     |███▉                            | 20kB 15.0MB/s eta 0:00:01[K     |█████▊                          | 30kB 9.4MB/s eta 0:00:01[K     |███████▋                        | 40kB 7.9MB/s eta 0:00:01[K     |█████████▌                      | 51kB 5.3MB/s eta 0:00:01[K     |███████████▍                    | 61kB 6.2MB/s eta 0:00:01[K     |█████████████▎                  | 71kB 6.2MB/s eta 0:00:01[K     |███████████████▏                | 81kB 6.1MB/s eta 0:00:01[K     |█████████████████               | 92kB 6.0MB/s eta 0:00:01[K     |███████████████████             | 102kB 6.5MB/s eta 0:00:01[K     |████████████████████▉           | 112kB 6.5MB/s eta 0:00:01[K     |█████

## Baseline model

In [2]:
# Fetch Fashion-MNIST as (images, labels) pairs for the train and test splits.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

# Divide the pixel values by 255.0 and append a trailing singleton (channel)
# axis in one step, giving a 4D array:
# N x 28 x 28 --> N x 28 x 28 x 1
train_images = (train_images / 255.0)[..., np.newaxis]
test_images = (test_images / 255.0)[..., np.newaxis]

# One-hot encode the integer class labels (10 classes) for both splits.
train_labels, test_labels = (
    tf.keras.utils.to_categorical(split_labels, 10)
    for split_labels in (train_labels, test_labels)
)

# Training hyper-parameters read by train() and the pruning cells.
LR = 1E-3
EPOCHS = 10
BATCH_SIZE = 64

def build_model(input_shape):
    """
    Assemble the (uncompiled) convolutional classifier.

    Three Conv2D(5x5, relu) -> MaxPooling2D(2x2) -> Dropout(0.25) stages with
    32, 64 and 128 filters are followed by Flatten, Dense(128) + ELU,
    Dropout(0.5) and a 10-unit softmax output.

    Args:
        input_shape: shape of a single input sample (no batch dimension).

    Returns:
        The tf.keras Sequential model; compiling it is left to the caller.
    """
    layers = tf.keras.layers

    model = tf.keras.models.Sequential()
    model.add(layers.InputLayer(input_shape=input_shape))

    # Convolutional stages: identical apart from the number of filters.
    for n_filters in (32, 64, 128):
        model.add(layers.Conv2D(n_filters, (5, 5), padding='same', activation='relu'))
        # Explicit strides=(2, 2) matches what Keras uses when strides is
        # omitted (it falls back to pool_size).
        model.add(layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
        model.add(layers.Dropout(0.25))

    # Classification head.
    head = (
        layers.Flatten(),
        layers.Dense(128),
        layers.Activation('elu'),
        layers.Dropout(0.5),
        layers.Dense(10),
        layers.Activation('softmax'),
    )
    for head_layer in head:
        model.add(head_layer)

    return model

def train(train_images, train_labels, model_path="baseline_model.tf"):
    """
    Train the model given the dataset and the global parameters (LR, EPOCHS and BATCH_SIZE).
    The model is automatically saved after the training.

    Args:
        train_images: training inputs; the model input shape is taken from
            `train_images.shape[1:]` and the data is cast to float32 for fitting.
        train_labels: training targets, cast to float32 for fitting. The loss
            is categorical cross-entropy, so one-hot rows are expected.
        model_path: destination passed to `model.save`; an existing save at
            that path is overwritten. Defaults to "baseline_model.tf".
    """
    model = build_model(train_images.shape[1:])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
        loss='categorical_crossentropy',
        metrics=['categorical_accuracy'],
    )

    # perf_counter is a monotonic clock intended for measuring intervals;
    # time.time() can jump if the system clock is adjusted.
    start_time = time.perf_counter()

    model.fit(
        x=train_images.astype(np.float32),
        y=train_labels.astype(np.float32),
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
    )

    elapsed = time.perf_counter() - start_time
    print("Train elapsed time: {} seconds".format(elapsed))

    model.save(model_path, overwrite=True)


def test(test_images, test_labels, model_path):
    """
    Load the saved model and evaluate it against the test set.

    Prints the test loss and accuracy, then the time spent in `evaluate`
    (loading the model is not included in the timing).

    Args:
        test_images: evaluation inputs.
        test_labels: evaluation targets.
        model_path: path of the saved model to load.
    """
    model = tf.keras.models.load_model(model_path)

    # perf_counter is a monotonic clock intended for measuring intervals.
    start_time = time.perf_counter()
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    # Stop the clock before printing so only the evaluation itself is timed.
    elapsed = time.perf_counter() - start_time

    print("Test Loss: {} - Test Accuracy: {}".format(test_loss, test_acc))
    print("Test elapsed time: {} seconds".format(elapsed))

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


Train the baseline model

In [3]:
# Builds, fits and saves the baseline model; the saved copy is reloaded by the cells below.
train(train_images, train_labels)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Train elapsed time: 100.33633255958557 seconds
INFO:tensorflow:Assets written to: baseline_model.tf/assets


Evaluate the baseline model

In [4]:
# Reload the model saved by train() and print its loss/accuracy on the test split.
model_path = "./baseline_model.tf"
test(test_images, test_labels, model_path)

Test Loss: 0.2488774210214615 - Test Accuracy: 0.9111999869346619
Test elapsed time: 1.4511198997497559 seconds


## Weight Pruning

In [5]:
def apply_pruning_to_dense(layer):
    """
    Clone function that makes only the Dense layers prunable.

    A Dense layer is returned wrapped by `prune_low_magnitude`, called with no
    extra arguments so tfmot's default pruning settings apply. Every other
    layer is handed back unchanged.
    """
    if not isinstance(layer, tf.keras.layers.Dense):
        return layer
    return tfmot.sparsity.keras.prune_low_magnitude(layer)

# Load the baseline
model = tf.keras.models.load_model("./baseline_model.tf")

# Fine-tuning settings for the pruning run. `prune_low_magnitude` is called
# without an explicit pruning schedule, so no end step has to be computed here.
batch_size = BATCH_SIZE  # same batch size as the baseline training, not a second literal
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set.

# Use `tf.keras.models.clone_model` to apply `apply_pruning_to_dense` to the layers of the model.
model_for_pruning = tf.keras.models.clone_model(
    model,
    clone_function=apply_pruning_to_dense,
)

# The wrapped model must be compiled again before fine-tuning; the settings
# mirror the ones used for the baseline.
model_for_pruning.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
    loss='categorical_crossentropy',
    metrics=['categorical_accuracy'],
)



Fine-tune model for pruning

In [6]:
# Pruning summaries are logged to a fresh temporary directory.
logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

# The trailing semicolon keeps the returned History object's repr out of the
# cell output.
model_for_pruning.fit(
    train_images.astype(np.float32),
    train_labels.astype(np.float32),
    batch_size=BATCH_SIZE,
    epochs=epochs,
    validation_split=validation_split,
    callbacks=callbacks,
);

Epoch 1/2
Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7f90fca945d0>

Evaluate the weight pruned model

In [7]:
# evaluate() returns the loss followed by the single compiled metric; only the accuracy is kept.
_, model_for_pruning_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)
print("Weight pruned model test accuracy: {}".format(model_for_pruning_accuracy))

Weight pruned model test accuracy: 0.9139999747276306


Save the weight pruned model

In [8]:
# Strip the pruning wrappers before export, then save the result next to the baseline;
# this saved copy is the one converted to TFLite below.
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning) # Necessary step to see the benefits of the pruning
model_for_export.save("weight_pruned_model.tf", overwrite=True)

INFO:tensorflow:Assets written to: weight_pruned_model.tf/assets


Convert to TFLite

In [9]:
def convert_to_tflite(saved_model_dir, output_filename, use_quantization=False):
    """
    Convert a SavedModel to TFLite and write it to `<output_filename>.tflite`.

    Args:
        saved_model_dir: path to the SavedModel directory.
        output_filename: output path without the '.tflite' extension.
        use_quantization: when True, the converter's optimizations are set to
            [tf.lite.Optimize.DEFAULT] before converting.

    Returns:
        The converted model as returned by `converter.convert()`.
    """
    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
    if use_quantization:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()

    # Persist the converted model alongside returning it.
    tflite_path = output_filename + '.tflite'
    with open(tflite_path, 'wb') as tflite_file:
        tflite_file.write(tflite_model)

    return tflite_model

def get_gzipped_model_size(file):
    """
    Return the size of `file` after zip (DEFLATE) compression, in MiB.

    The file is compressed into a temporary .zip archive whose size is
    measured; the archive is deleted before returning.

    Args:
        file: path of the file to compress.

    Returns:
        Size of the compressed archive in bytes divided by 2**20, as a float.
    """
    import os
    import zipfile

    # mkstemp returns an already-open OS-level descriptor; close it right away
    # so it does not leak (ZipFile reopens the path itself).
    fd, zipped_file = tempfile.mkstemp('.zip')
    os.close(fd)
    try:
        with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
            archive.write(file)
        return os.path.getsize(zipped_file) / float(2**20)
    finally:
        # The archive only exists to be measured; do not leave it behind.
        os.remove(zipped_file)

Convert the model and apply quantization

In [10]:
# With use_quantization=True the converter runs with tf.lite.Optimize.DEFAULT (see convert_to_tflite);
# the result is also written to weight_pruned_and_quantized_model.tflite.
output_filename = "weight_pruned_and_quantized_model"
weight_pruned_tflite_model = convert_to_tflite("./weight_pruned_model.tf", output_filename, use_quantization=True)

Convert also the baseline model

In [11]:
# The baseline is converted without quantization (use_quantization defaults to False)
# and written to baseline_model.tflite for the size comparison.
output_filename = "baseline_model"
baseline_tflite_model = convert_to_tflite("./baseline_model.tf", output_filename)

Compare the size

In [12]:
# get_gzipped_model_size returns the zip-compressed file size divided by 2**20, hence the MB unit.
print("Size of gzipped baseline tflite model: %.2f MB" % (get_gzipped_model_size("./baseline_model.tflite")))
print("Size of gzipped weight pruned and quantized tflite model: %.2f MB" % (get_gzipped_model_size("./weight_pruned_and_quantized_model.tflite")))

Size of gzipped baseline tflite model: 1.44 MB
Size of gzipped weight pruned and quantized tflite model: 0.32 MB


## Check that accuracy is preserved

In [13]:
def evaluate_model(interpreter, images=None, labels=None):
    """
    Compute the classification accuracy of a TFLite interpreter.

    Each sample is fed to the interpreter individually (batch of one).

    Args:
        interpreter: TFLite interpreter ready to run; this function does not
            call `allocate_tensors()`. Its first input and first output
            tensors are used.
        images: iterable of samples without a batch dimension. Defaults to
            the notebook-level `test_images`.
        labels: one-hot ground-truth labels aligned with `images`. Defaults
            to the notebook-level `test_labels`.

    Returns:
        Fraction of samples whose arg-max prediction equals the arg-max of
        the corresponding label row.
    """
    # Fall back to the notebook globals so existing calls keep working.
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels

    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # Run predictions on every image in the dataset.
    predicted_classes = []
    for i, image in enumerate(images):
        if i % 1000 == 0:
            print('Evaluated on {n} results so far.'.format(n=i))
        # Pre-processing: add batch dimension and convert to float32 to match with
        # the model's input data format.
        batch = np.expand_dims(image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, batch)

        # Run inference.
        interpreter.invoke()

        # Post-processing: remove batch dimension and find the class with the
        # highest probability.
        output = interpreter.tensor(output_index)
        predicted_classes.append(np.argmax(output()[0]))

    print('\n')
    # Compare prediction results with ground truth labels to calculate accuracy.
    predicted_classes = np.array(predicted_classes)
    # Revert from categorical to numerical
    labels_numerical = np.argmax(labels, axis=-1)
    accuracy = (predicted_classes == labels_numerical).mean()
    return accuracy

Note that this evaluation takes some time: the interpreter processes the test images one at a time and, as explained in the documentation, quantized models tend to run slower on desktop CPUs/GPUs. Quantization is instead beneficial in a mobile setting.

In [14]:
# Build the interpreter directly from the in-memory TFLite model returned by convert_to_tflite;
# tensors must be allocated before evaluate_model feeds it any input.
interpreter = tf.lite.Interpreter(model_content=weight_pruned_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Weight pruned and quantized TFLite test accuracy:', test_accuracy)

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Weight pruned and quantized TFLite test accuracy: 0.9138
