# CIFAR10 Training and quantization

Project inspired by this [Kaggle competition](https://www.kaggle.com/ektasharma/simple-cifar10-cnn-keras-code-with-88-accuracy#A-Simple-Keras-CNN-trained-on-CIFAR-10-dataset-with-over-88%-accuracy-(Without-Data-Augmentation)). In this Colab notebook you will be guided through the training of a CNN for the CIFAR-10 classification task. The model is then exported for inference with TFLite and quantized, ready for the *GAPflow*.

### Used python modules

In [20]:
from keras_model import *
from tqdm import tqdm
import tensorflow as tf
import keras
from keras import datasets
from keras.utils import to_categorical
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import os
from PIL import Image

In [None]:
import tensorflow as tf

# Report the installed TensorFlow version and whether at least one GPU is
# visible. `tf.test.is_gpu_available()` is deprecated, so query the list of
# physical GPU devices instead and reduce it to a boolean.
tf.__version__, bool(tf.config.list_physical_devices("GPU"))

# Reading the CIFAR-10 dataset from Keras datasets

In [22]:
# Fetch the CIFAR-10 train/test splits through the Keras dataset helper
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Work on a reduced subset: the first 10000 training samples and the
# first 1000 test samples
train_images = train_images[:10000]
train_labels = train_labels[:10000]
test_images = test_images[:1000]
test_labels = test_labels[:1000]

In [None]:
# Print the shape of every array, one per line (the first axis is the number
# of samples in the split)
print(
    train_images.shape,
    train_labels.shape,
    test_images.shape,
    test_labels.shape,
    sep="\n",
)

In [None]:
# List the distinct label values present in the training and test subsets
print(np.unique(train_labels), np.unique(test_labels), sep="\n")

In [25]:
# Human-readable class names; the position in the list is the integer label
class_names = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

In [None]:
# Show the first 25 training images in a 5x5 grid, each captioned with its class
plt.figure(figsize=[10, 10])
for idx in range(25):
    ax = plt.subplot(5, 5, idx + 1)
    # Strip ticks and grid lines so only the picture and its caption remain
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    ax.imshow(train_images[idx], cmap=plt.cm.binary)
    # Each label is stored as a one-element row, hence the extra [0]
    ax.set_xlabel(class_names[train_labels[idx][0]])

plt.show()

# Data Preprocessing

Convert the input pixels from [0, 255] to floats in [-1, 1)

Convert labels to one hot encoding

In [27]:
# Cast the pixel data to float32 and rescale it: dividing by 128 and
# subtracting 1 maps the [0, 255] range onto [-1, 1)
train_images = train_images.astype('float32') / 128 - 1.0
test_images = test_images.astype('float32') / 128 - 1.0

# Replace the integer class labels with one-hot vectors of length num_classes
num_classes = 10
train_labels, test_labels = (
    to_categorical(labels, num_classes) for labels in (train_labels, test_labels)
)

## Model Design

In [None]:
# Choose which architecture to build. The model_vN factories are presumably
# provided by the `keras_model` star import at the top of the notebook.
MODEL_VERSION = 4

if MODEL_VERSION == 1:
    model = model_v1()
    model_name = "v1"
elif MODEL_VERSION == 2:
    model = model_v2()
    model_name = "v2"
elif MODEL_VERSION == 3:
    model = model_v3()
    model_name = "v3"
elif MODEL_VERSION == 4:
    model = model_v4()
    model_name = "v4"
# elif MODEL_VERSION == 5:
#     model = model_v5()
#     model_name = "v5"
else:
    # Fail here with a clear message instead of a NameError on `model` below
    raise ValueError(
        f"Unsupported MODEL_VERSION: {MODEL_VERSION!r} (expected 1, 2, 3 or 4)"
    )

# Checking the model summary
model.summary()

# Compile and Train the model (or load the pretrained one)

In [49]:
# Folder holding the saved model for the selected architecture
# (alternative Google Drive location: "gdrive/MyDrive/cifar10/saved_model/my_model")
checkpoint_path = f"./checkpoints/saved_model_{model_name}/"
# Set to True to retrain even when a saved model already exists
train_again = False

if not os.path.exists(checkpoint_path) or train_again:
    # No usable checkpoint (or retraining requested): train and save
    model.compile(
        optimizer='adam',
        loss=keras.losses.categorical_crossentropy,
        metrics=['accuracy'],
    )
    history = model.fit(
        train_images,
        train_labels,
        batch_size=128,
        epochs=10,  # Add more epochs to get better results
        validation_data=(test_images, test_labels),
    )
    model.save(checkpoint_path)
else:
    # Reuse the saved model; there is no training history in this case
    model = tf.keras.models.load_model(checkpoint_path)
    history = None

In [None]:
# Run the model on the test subset and compare the arg-max of each
# prediction with the arg-max of the one-hot label
pred = model.predict(test_images)
hits = np.argmax(pred, axis=1) == np.argmax(test_labels, axis=1)
accuracy = 100 * np.sum(hits) / len(test_labels)
print(f"Trained model Accuracy: {accuracy}%")

# Visualize training results

Only available if you have trained the model in this session

In [None]:
if history:
    # Training vs. validation loss per epoch (skipped when the model was loaded)
    plt.figure(figsize=[6, 4])
    for metric_key, line_color in (('loss', 'black'), ('val_loss', 'green')):
        plt.plot(history.history[metric_key], line_color, linewidth=2.0)
    plt.legend(['Training Loss', 'Validation Loss'], fontsize=14)
    plt.xlabel('Epochs', fontsize=10)
    plt.ylabel('Loss', fontsize=10)
    plt.title('Loss Curves', fontsize=12)

In [None]:
if history:
    # Training vs. validation accuracy per epoch (skipped when the model was loaded)
    plt.figure(figsize=[6, 4])
    for metric_key, line_color in (('accuracy', 'black'), ('val_accuracy', 'blue')):
        plt.plot(history.history[metric_key], line_color, linewidth=2.0)
    plt.legend(['Training Accuracy', 'Validation Accuracy'], fontsize=14)
    plt.xlabel('Epochs', fontsize=10)
    plt.ylabel('Accuracy', fontsize=10)
    plt.title('Accuracy Curves', fontsize=12)

# See the model at work

In [None]:
# Plot actual vs. predicted classes for the first 25 test images and export
# each of those images as a PPM sample file.
# Converting the predictions into label index
pred_classes = np.argmax(pred, axis=1)

# Image.save does not create missing folders, so make sure the target exists
os.makedirs("samples", exist_ok=True)

fig, axes = plt.subplots(5, 5, figsize=(15, 15))
axes = axes.ravel()

for i in np.arange(0, 25):
    true_class = np.argmax(test_labels[i])
    # Undo the (x / 128) - 1 preprocessing to get displayable 8-bit pixels back
    pixels = (128 * (test_images[i] + 1)).astype(np.uint8)

    axes[i].imshow(pixels)
    # Red title for a wrong prediction, blue for a correct one
    axes[i].set_title(
        f"True: {class_names[true_class]}\nPredict: {class_names[pred_classes[i]]}",
        color="r" if true_class != pred_classes[i] else "b",
    )
    axes[i].axis('off')

    img = Image.fromarray(pixels)
    img.save(f"samples/cifar_test_{i}_{true_class}.ppm")

# Subplot spacing is a figure-wide setting, so it only needs to be applied once
plt.subplots_adjust(wspace=1)


# Convert to tflite

In [43]:
# Helper function to run inference on a TFLite model
def test_tflite_model(tflite_file, test_images, test_labels):
    """Run a TFLite model over a labelled set and return its accuracy.

    Args:
        tflite_file: Path of the ``.tflite`` file; it is passed through
            ``str()`` before being handed to the interpreter.
        test_images: Sequence of float images; each one is fed to the
            interpreter as a batch of size 1.
        test_labels: One-hot encoded labels, one row per image.

    Returns:
        The percentage of images whose arg-max prediction matches the
        arg-max of the corresponding label.
    """
    # Initialize the interpreter
    interpreter = tf.lite.Interpreter(model_path=str(tflite_file))
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    # 8-bit integer inputs (uint8 or int8) expect quantized values:
    # q = round(x / scale) + zero_point. Look the parameters up once.
    input_dtype = input_details["dtype"]
    quantized_input = input_dtype in (np.uint8, np.int8)
    if quantized_input:
        input_scale, input_zero_point = input_details["quantization"]
        dtype_range = np.iinfo(input_dtype)

    predictions = np.zeros((len(test_images),), dtype=int)
    for i, test_image in enumerate(tqdm(test_images, total=len(test_images))):
        if quantized_input:
            # Round explicitly (a bare integer cast truncates) and clip so a
            # value just outside the dtype range cannot wrap around.
            test_image = np.clip(
                np.round(test_image / input_scale + input_zero_point),
                dtype_range.min,
                dtype_range.max,
            )

        # Add the batch dimension and cast to the dtype the model expects
        test_image = np.expand_dims(test_image, axis=0).astype(input_dtype)
        interpreter.set_tensor(input_details["index"], test_image)
        interpreter.invoke()
        output = interpreter.get_tensor(output_details["index"])[0]

        predictions[i] = output.argmax()

    test_labels_not_one_hot = np.argmax(test_labels, 1)
    accuracy = (np.sum(test_labels_not_one_hot == predictions) * 100) / len(test_images)
    return accuracy

In [44]:
# Destination of the float (fp32) TFLite model, stored inside the checkpoint folder
tflite_model_file = pathlib.Path(f"{checkpoint_path}/cifar10_model_{model_name}_fp32.tflite")

In [None]:
# Converting a tf.Keras model to a TensorFlow Lite model.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the unquantized/float model:
tflite_model_file.write_bytes(tflite_model)

In [None]:
# Accuracy (in percent) of the float TFLite model on the test subset
fp32_accuracy = test_tflite_model(tflite_model_file, test_images, test_labels)
print(f"\nFloat model accuracy: {fp32_accuracy}")

## Quantize to int8

We use post training integer quantization to quantize the model (https://www.tensorflow.org/lite/performance/post_training_integer_quant).

Weights are quantized directly from their values (they are constants), activations on the other hand depend on the input data. Hence we need to provide a calibration dataset to the quantizer so that it can run inference on it and collect the statistics of each layer in order to quantize the values in those ranges, i.e. Layer1 -> [-3.0, 6.0], Layer2 -> [-1.0, 2.5], ...

As a calibration dataset we need data that is representative of our use case. It cannot be the test set: we are "learning" the statistics, so using the test dataset would be cheating. A subset of the training set is typically used.

In [None]:
def representative_data_gen():
    """Yield 100 single-image batches of training data for calibration."""
    calibration_set = tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100)
    for input_value in calibration_set:
        yield [input_value]


converter = tf.lite.TFLiteConverter.from_keras_model(model)
# Restrict the converter to int8 builtin ops so that it raises an error when
# an op cannot be quantized
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Enable the default optimizations and provide the calibration samples
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Make the model's input and output tensors uint8 as well (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

tflite_model_quant = converter.convert()


In [20]:
# Destination of the post-training-quantized (uint8) TFLite model
tflite_quant_model_file = pathlib.Path(f"{checkpoint_path}/cifar10_model_{model_name}_uint8.tflite")

In [None]:
# Save the quantized (uint8) model:
tflite_quant_model_file.write_bytes(tflite_model_quant)

In [None]:
# Accuracy (in percent) of the post-training-quantized model. The test arrays
# were already cut to 1000 samples when the dataset was loaded, so the
# [:1000] slices leave them unchanged here.
quant_accuracy = test_tflite_model(tflite_quant_model_file, test_images[:1000], test_labels[:1000])
print(f"Quantized model accuracy: {quant_accuracy}")

## Quantize to INT8 using QAT

After performing post-training quantization, we might notice an accuracy drop between the FP32 and INT8 networks. As the weights are kept constant during PTQ and the activations' range is analysed only with respect to the input (calibration) data, such a quantization strategy, albeit quick, can noticeably degrade the model's accuracy.

Quantization-aware training (QAT) updates the model's weights whilst reducing the bitwidth. A comprehensive guide on TensorFlow-based QAT is available [here](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide).

In [None]:
import tensorflow_model_optimization as tfmot
# Build the quantization-aware counterpart of the trained model; the result
# is bound to a new name, and `model` keeps referring to the original.

tf_model_qat = tfmot.quantization.keras.quantize_model(model)

# The new model must be compiled before it can be trained; the same
# optimizer, loss and metric as in the float training are used
tf_model_qat.compile(optimizer='adam', loss=keras.losses.categorical_crossentropy, metrics=['accuracy'])

# Print the layer overview of the quantization-aware model
tf_model_qat.summary()


In [None]:
# Echo the original (float) model object as the cell output
model

In [None]:
# Fine-tune the fake-quantized model on the (reduced) training data.
# NOTE: this rebinds `checkpoint_path` and `history`, so cells executed after
# this one see the QAT folder and the QAT training history.

checkpoint_path = f"./checkpoints/saved_model_qat_{model_name}/"

history = tf_model_qat.fit(train_images, train_labels, batch_size=64, epochs=1, # Add more epochs to get better results
                  validation_data=(test_images, test_labels))
# Save the fine-tuned QAT model into its own checkpoint folder
tf_model_qat.save(checkpoint_path)



In [None]:
# Accuracy of the fine-tuned QAT model (still a Keras model at this point)
# on the test subset

pred = tf_model_qat.predict(test_images)
hits = np.argmax(pred, axis=1) == np.argmax(test_labels, axis=1)
accuracy = 100 * np.sum(hits) / len(test_labels)
print(f"Trained model Accuracy: {accuracy}%")

In [None]:
# Convert the QAT model to a TFLite model with uint8 input/output and save it

def representative_data_gen():
    """Yield 100 single-image batches of training data for calibration."""
    for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
        yield [input_value]


converter_qat = tf.lite.TFLiteConverter.from_keras_model(tf_model_qat)
converter_qat.optimizations = [tf.lite.Optimize.DEFAULT]
converter_qat.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter_qat.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter_qat.inference_input_type = tf.uint8
converter_qat.inference_output_type = tf.uint8

tflite_model_qat_quant = converter_qat.convert()

tflite_quant_qat_model_file = pathlib.Path(f"{checkpoint_path}/model/cifar10_qat_{model_name}_uint8.tflite")
# write_bytes does not create missing folders, and nothing else in this
# notebook creates the "model" sub-directory, so create it first
tflite_quant_qat_model_file.parent.mkdir(parents=True, exist_ok=True)
tflite_quant_qat_model_file.write_bytes(tflite_model_qat_quant)

In [None]:
# Evaluate the QAT model
quant_qat_accuracy = test_tflite_model(tflite_quant_qat_model_file, test_images, test_labels)
print(f"Quantized model accuracy: {quant_qat_accuracy}")

Did the accuracy increase compared to the PTQ solution? By how much?
<br>
Is this due to the benefits of QAT? How can you tell?