# Quantization Aware Training

> In this post, we will learn about quantization aware training. This is a summary of the lecture "Applications of TinyML" from edX.

- toc: true 
- badges: true
- comments: true
- author: Chanseok Kang
- categories: [Python, edX, Deep_Learning, Tensorflow, tinyML]
- image: 

```python
import numpy as np
import matplotlib.pyplot as plt
import pathlib
import tempfile
import os
from tqdm import tqdm

import tensorflow as tf
import tensorflow_model_optimization as tfmot
import tensorflow_hub as hub
import tensorflow_datasets as tfds

print("Tensorflow : v" + tf.__version__)
```

```
Tensorflow : v2.3.1
```


There are two primary forms of quantization. The first is post-training quantization, which you saw previously: as part of the conversion process, your model’s internal weights and ops are converted to int8 and uint8. It is the easier of the two to use, and a good way to get started.
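To make "converted to int8" concrete, here is a minimal NumPy sketch of symmetric, per-tensor int8 quantization of a weight array. It is an illustration under simplifying assumptions, not the exact scheme the TFLite converter applies (which can, for example, use per-channel scales for convolution weights); the function names are hypothetical.

```python
import numpy as np

def quantize_symmetric_int8(w):
    """Map float weights to int8 codes with one scale per tensor (zero point fixed at 0)."""
    max_abs = float(np.max(np.abs(w)))
    # Use [-127, 127] so the grid is symmetric and 0.0 maps exactly to 0.
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    """Recover approximate float32 values from the int8 codes."""
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
# A random weight tensor shaped like a 3x3 Conv2D kernel with 1 input and 12 output channels.
weights = rng.normal(0.0, 0.1, size=(3, 3, 1, 12)).astype(np.float32)

q, scale = quantize_symmetric_int8(weights)
recovered = dequantize(q, scale)

print("storage: %d bytes -> %d bytes" % (weights.nbytes, q.nbytes))
print("max abs rounding error: %.6f (bound: scale / 2 = %.6f)"
      % (np.max(np.abs(weights - recovered)), scale / 2))
```

Each weight now takes one byte instead of four, and the rounding error per weight is at most half a quantization step.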

If you want to optimize your model further, you can explore the second form, quantization aware training (QAT). You got a taste of it in the previous video, where you passed a simple model to an API and got a quantized model back. There is much more fine-grained control available, though. To get some exposure to those controls, work through the Colab below, provided by the TensorFlow team, on quantization aware training. When you are done, if you would like to go deeper, including a comprehensive guide to all the APIs available in the toolkit, read through the [model optimization site](https://www.tensorflow.org/model_optimization/guide/quantization/training) on TensorFlow.org.

In particular, note the results the Google team reported when comparing the accuracy of models before and after this kind of quantization: the differences are small, at most about 0.8 percentage points for the models below.

| Model | Non-quantized Top-1 Accuracy | 8-bit Quantized Top-1 Accuracy |
| -- | -- | -- |
| MobileNetV1 224 | 71.03% | 71.06% |
| ResNetV1 50 | 76.3% | 76.1% |
| MobileNetV2 224 | 70.77% | 70.01% |
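The size of those accuracy changes can be read straight off the table; this small sketch just computes the difference in percentage points for each of the three rows.

```python
# Top-1 accuracy (%) copied from the table above: (non-quantized, 8-bit quantized).
results = {
    "MobileNetV1 224": (71.03, 71.06),
    "ResNetV1 50": (76.3, 76.1),
    "MobileNetV2 224": (70.77, 70.01),
}

deltas = {}
for name, (float_acc, quant_acc) in results.items():
    # Positive means the quantized model scored higher than the float one.
    deltas[name] = round(quant_acc - float_acc, 2)
    print(f"{name}: {deltas[name]:+.2f} percentage points")
```

The largest drop is MobileNetV2, at roughly three quarters of a point, and MobileNetV1 even gains slightly.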

## Train a model for MNIST without quantization aware training

```python
# Load MNIST dataset
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)
```




## Clone and fine-tune pre-trained model with quantization aware training

### Define the model

Apply quantization aware training to the whole model. In the model summary, every layer is now prefixed with "quant".

Note that the resulting model is quantization aware but not quantized (e.g., the weights are still float32 rather than int8). The following sections show how to create a quantized model from the quantization aware one.
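One way to see what "quantization aware but not quantized" means is to simulate the fake quantization step that QAT inserts into the forward pass: values are rounded onto an 8-bit grid and immediately mapped back to float32. The NumPy sketch below assumes a simple asymmetric (min/max) 8-bit scheme; the quantizers in `tfmot` track ranges and handle details in their own way, and the helper name `fake_quant` is hypothetical.

```python
import numpy as np

def fake_quant(x, x_min, x_max, num_bits=8):
    """Quantize then dequantize: the result is float32 but restricted to 2**num_bits levels."""
    # Widen the range to include 0.0 so that zero is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    if x_max == x_min:  # all-zero input: nothing to quantize
        return np.asarray(x, dtype=np.float32)
    q_max = 2 ** num_bits - 1
    scale = (x_max - x_min) / q_max
    zero_point = int(round(-x_min / scale))
    # Integer codes in [0, q_max], as an unsigned 8-bit tensor would hold them.
    q = np.clip(np.round(x / scale) + zero_point, 0, q_max)
    # Map the codes straight back to float, which is what the rest of the forward pass sees.
    return ((q - zero_point) * scale).astype(np.float32)

rng = np.random.default_rng(1)
w = rng.uniform(-0.5, 0.5, size=10_000).astype(np.float32)
w_fq = fake_quant(w, float(w.min()), float(w.max()))

print("dtype:", w_fq.dtype)
print("distinct values:", np.unique(w_fq).size)
```

The output is still float32, which is why the QAT model's weights show up as float32, but it can take at most 256 distinct values, so training learns to cope with the rounding it will face after conversion. Because rounding has zero gradient almost everywhere, QAT implementations typically pass gradients straight through this step (a straight-through estimator).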

The [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md) shows how to quantize only some of the layers, which can help improve model accuracy.

```python
quantize_model = tfmot.quantization.keras.quantize_model

# quantization aware
q_aware_model = quantize_model(model)
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

q_aware_model.summary()
```

```
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
quantize_layer (QuantizeLaye (None, 28, 28)            3
_________________________________________________________________
quant_reshape (QuantizeWrapp (None, 28, 28, 1)         1
_________________________________________________________________
quant_conv2d (QuantizeWrappe (None, 26, 26, 12)        147
_________________________________________________________________
quant_max_pooling2d (Quantiz (None, 13, 13, 12)        1
_________________________________________________________________
quant_flatten (QuantizeWrapp (None, 2028)              1
_________________________________________________________________
quant_dense (QuantizeWrapper (None, 10)                20295
=================================================================
Total params: 20,448
Trainable params: 20,410
Non-trainable params: 38
_________________________________________________________________
```

### Train and evaluate the model against baseline

To demonstrate fine-tuning after training the model for just one epoch, fine-tune it with quantization aware training on a subset of the training data.

```python
train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]

q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)
```




For this example, there is minimal to no loss in test accuracy after quantization aware training, compared to the baseline.

```python
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
```

```
Baseline test accuracy: 0.9660000205039978
Quant test accuracy: 0.9653000235557556
```


## Create quantized model for TFLite backend

After this, you have an actually quantized model with int8 weights and uint8 activations.
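To illustrate what "actually quantized" means at inference time, here is a NumPy sketch of the affine scheme real ≈ scale * (q - zero_point) applied to a dense layer: uint8 activations with a zero point, int8 weights with zero point 0, int32 accumulation, and one rescale at the end. The shapes and helper names are hypothetical, and TFLite's kernels rescale with fixed-point multipliers rather than the float multiply used here.

```python
import numpy as np

def quantize_affine_uint8(x):
    """Asymmetric uint8 quantization: x ~= scale * (q - zero_point). Assumes x is not all zero."""
    x_min, x_max = min(float(x.min()), 0.0), max(float(x.max()), 0.0)
    scale = (x_max - x_min) / 255.0
    zero_point = int(round(-x_min / scale))
    q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)
    return q, scale, zero_point

def quantize_symmetric_int8(w):
    """Symmetric int8 quantization: w ~= scale * q, with zero point 0."""
    scale = float(np.max(np.abs(w))) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

rng = np.random.default_rng(2)
x = rng.uniform(0.0, 1.0, size=(4, 64)).astype(np.float32)   # e.g. non-negative post-ReLU activations
w = rng.normal(0.0, 0.1, size=(64, 10)).astype(np.float32)   # dense-layer weights

q_x, s_x, z_x = quantize_affine_uint8(x)
q_w, s_w = quantize_symmetric_int8(w)

# Integer part: subtract the zero point and accumulate in int32 so the sums cannot overflow 8 bits.
acc = (q_x.astype(np.int32) - z_x) @ q_w.astype(np.int32)
# Rescale the int32 accumulator back to real values with the product of the two scales.
y_int = acc.astype(np.float32) * (s_x * s_w)

y_float = x @ w
print("max abs difference vs float matmul:", float(np.max(np.abs(y_int - y_float))))
```

All of the multiply-accumulate work happens on small integers; only the final rescale touches the scales, which is what makes integer-only hardware able to run the model.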

```python
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model = converter.convert()
```

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: C:\Users\kcsgo\AppData\Local\Temp\tmp3p53shgo\assets


## See persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

```python
def evaluate_model(interpreter):
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]

    # Run predictions on every image in the "test" dataset.
    prediction_digits = []
    for i, test_image in enumerate(test_images):
        if i % 1000 == 0:
            print('Evaluated on {n} results so far.'.format(n=i))
        # Pre-processing: add batch dimension and convert to float32 to match
        # the model's input data format.
        test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, test_image)

        # Run inference.
        interpreter.invoke()

        # Post-processing: remove batch dimension and find the digit with the
        # highest score (the model outputs logits, so argmax gives the predicted class).
        output = interpreter.tensor(output_index)
        digit = np.argmax(output()[0])
        prediction_digits.append(digit)

    print('\n')

    # Compare prediction results with ground truth labels to calculate accuracy.
    prediction_digits = np.array(prediction_digits)
    accuracy = (prediction_digits == test_labels).mean()
    return accuracy
```
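The helper can feed float32 images because, with the converter settings used in this post, the model's input tensor stays float32. If you instead converted with an integer input type, you would have to quantize each image yourself with the input tensor's `(scale, zero_point)` pair, which `interpreter.get_input_details()[0]["quantization"]` reports. The NumPy function below is a sketch of that step; the name `quantize_input` and the example scale and zero point are hypothetical, and the float-input model above does not need it.

```python
import numpy as np

def quantize_input(image, scale, zero_point, dtype=np.int8):
    """Map a float image to the integer domain: q = round(x / scale) + zero_point."""
    info = np.iinfo(dtype)
    q = np.round(image / scale) + zero_point
    # Clip to the representable range of the target integer type before casting.
    return np.clip(q, info.min, info.max).astype(dtype)

# Hypothetical parameters: the [0, 1] pixel range mapped onto int8.
scale, zero_point = 1.0 / 255.0, -128
pixels = np.array([0.0, 0.2, 1.0, 1.5], dtype=np.float32)  # 1.5 is out of range on purpose
print(quantize_input(pixels, scale, zero_point).tolist())  # [-128, -77, 127, 127]
```

Values outside the calibrated range saturate at the ends of the integer type instead of wrapping around.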

```python
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)
```

```
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Quant TFLite test_accuracy: 0.9653
Quant TF test accuracy: 0.9653000235557556
```


## See 4x smaller model from quantization

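Create a float TFLite model and compare its size with the quantized TFLite model. Storing weights as 8-bit integers instead of 32-bit floats makes the model roughly 4x smaller; the sizes measured here give a ratio of about 3.4x.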
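The 4x figure comes from storage width: a float32 weight takes 4 bytes and an int8 weight takes 1. As a back-of-the-envelope check (an illustration, not a model of how TFLite lays out its file), this sketch counts the baseline model's parameters from its layer shapes and compares the ideal ratio with the one measured in this notebook.

```python
# Parameters of the float baseline: Conv2D (3*3*1*12 kernel + 12 biases)
# plus Dense (2028*10 kernel + 10 biases).
conv_params = 3 * 3 * 1 * 12 + 12
dense_params = 2028 * 10 + 10
params = conv_params + dense_params

MiB = 2 ** 20
float_weight_bytes = params * 4  # float32: 4 bytes per parameter
# int8: 1 byte per parameter; for simplicity biases are counted as 1 byte too,
# although TFLite typically stores quantized biases as int32.
int8_weight_bytes = params * 1

print("parameters:", params)
print("float32: %.4f MiB, int8: %.4f MiB, ideal ratio %.1fx"
      % (float_weight_bytes / MiB, int8_weight_bytes / MiB,
         float_weight_bytes / int8_weight_bytes))

# File sizes printed by the notebook below, in MiB (labelled "Mb" there).
measured_float, measured_quant = 0.08053970336914062, 0.02336883544921875
print("measured ratio: %.2fx" % (measured_float / measured_quant))
```

The parameter count matches the 20,410 trainable parameters in the summary above. The measured ratio falls short of exactly 4x, which is consistent with fixed file overhead and with some tensors kept at higher precision; the sketch does not model those details.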

```python
# Create float TFLite model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()

# Measure sizes of models.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
    f.write(quantized_tflite_model)

with open(float_file, 'wb') as f:
    f.write(float_tflite_model)

print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))
```

INFO:tensorflow:Assets written to: C:\Users\kcsgo\AppData\Local\Temp\tmp2tjqrzvy\assets




```
Float model in Mb: 0.08053970336914062
Quantized model in Mb: 0.02336883544921875
```


## Summary

Through this, you saw how to create quantization aware models with the TensorFlow Model Optimization Toolkit API and then quantized models for the TFLite backend.

You saw roughly a 4x reduction in model size for an MNIST model (about 3.4x in this run), with minimal accuracy difference. To see the latency benefits on mobile, try out the TFLite examples [in the TFLite app repository](https://www.tensorflow.org/lite/models).