# Quantization aware training

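Quantization aware training (QAT) simulates int8 quantization while the model is trained, so a model converted from float to int8 this way usually loses less accuracy than one that is quantized only after training.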

In [1]:
import tempfile
import os

import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot

## Load dataset, define model and train
This is the standard floating-point training workflow; no quantization is involved yet.

In [6]:
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Fix the global TensorFlow seed so weight initialization and batch shuffling
# are repeatable when the notebook is re-run from a fresh kernel.
tf.random.set_seed(42)

# Scale pixel values into [0, 1]. Casting to float32 before dividing avoids
# materialising float64 copies of the full image arrays.
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0

model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Train the baseline model with ordinary floating-point weights. The final
# Dense layer has no activation, so the loss is told it receives logits.
# Quantization aware training is applied to this trained model afterwards.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_images,
          train_labels,
          epochs=1,
          validation_split=0.1)



<tensorflow.python.keras.callbacks.History at 0x1531a5d90>

## Quantize the model
Wrap the trained model with `quantize_model` to make it quantization aware. The resulting model emulates quantization during training but is not yet fully quantized: its weights are still float32 rather than int8. Converting it to a TFLite model (below) produces the fully quantized model.

In [7]:
# Wrap the trained float model so that training emulates quantized inference.
q_aware_model = tfmot.quantization.keras.quantize_model(model)

# Compile the quantization-aware model with the same optimizer, loss and
# metric as the baseline so the two can be compared directly.
q_aware_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

q_aware_model.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
quantize_layer_3 (QuantizeLa (None, 28, 28)            3         
_________________________________________________________________
quant_reshape_3 (QuantizeWra (None, 28, 28, 1)         1         
_________________________________________________________________
quant_conv2d_3 (QuantizeWrap (None, 26, 26, 12)        147       
_________________________________________________________________
quant_max_pooling2d_3 (Quant (None, 13, 13, 12)        1         
_________________________________________________________________
quant_flatten_3 (QuantizeWra (None, 2028)              1         
_________________________________________________________________
quant_dense_3 (QuantizeWrapp (None, 10)                20295     
Total params: 20,448
Trainable params: 20,410
Non-trainable params: 38
_________________________________________________

## Fine-tune with quantization aware training and evaluate against the baseline

The quantization-aware model is fine-tuned briefly on a 1,000-image subset of the training data. Its test accuracy should be close to that of the baseline model, which shows that quantization has little impact on accuracy.

In [10]:
# Fine-tune on the first 1,000 training examples only; 10% of that subset is
# held out for validation by fit().
FINE_TUNE_SAMPLES = 1000
train_images_subset = train_images[:FINE_TUNE_SAMPLES]
train_labels_subset = train_labels[:FINE_TUNE_SAMPLES]

q_aware_model.fit(
    train_images_subset,
    train_labels_subset,
    batch_size=500,
    epochs=1,
    validation_split=0.1,
)

# evaluate() returns [loss, accuracy] for these compiled models; keep accuracy.
baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0)[1]
q_aware_model_accuracy = q_aware_model.evaluate(test_images, test_labels, verbose=0)[1]

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)

Baseline test accuracy: 0.9577000141143799
Quant test accuracy: 0.9545000195503235


## Create a fully quantized TFLite model

In [11]:
# Convert the quantization-aware Keras model to a TFLite flatbuffer.
# Optimize.DEFAULT enables quantization in the converter, which turns the
# quantization-aware model into a fully quantized TFLite model.
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_tflite_model = converter.convert()



INFO:tensorflow:Assets written to: /var/folders/1c/v15_t4sj19x7p9fx5qk_h9hc0000gn/T/tmpo9hxb524/assets


INFO:tensorflow:Assets written to: /var/folders/1c/v15_t4sj19x7p9fx5qk_h9hc0000gn/T/tmpo9hxb524/assets


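## Evaluate the TFLite model

Run the quantized model with the TFLite interpreter over the test set and compare its accuracy with both Keras models.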

In [15]:
import numpy as np

def evaluate_model(interpreter, images=None, labels=None):
    """Compute the top-1 classification accuracy of a TFLite interpreter.

    Each image is run through the interpreter on its own (batch size 1). The
    argmax of the first output tensor is taken as the predicted class.

    Args:
        interpreter: A ``tf.lite.Interpreter`` whose ``allocate_tensors()``
            has already been called.
        images: Images to classify, iterated along the first axis. Defaults
            to the notebook-level ``test_images``, for backward compatibility.
        labels: Integer class labels aligned with ``images``. Defaults to the
            notebook-level ``test_labels``.

    Returns:
        The fraction of images whose predicted class equals the label.
    """
    # Fall back to the notebook globals so existing calls keep working.
    if images is None:
        images = test_images
    if labels is None:
        labels = test_labels

    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']

    prediction_digits = []
    for image in images:
        # Add a batch dimension. The model input is assumed to be float32,
        # as in the original notebook; confirm if the converter settings change.
        batch = np.expand_dims(image, axis=0).astype(np.float32)
        interpreter.set_tensor(input_index, batch)
        interpreter.invoke()
        # get_tensor() returns a copy, so no reference into the interpreter's
        # internal buffers is held across invoke() calls.
        output = interpreter.get_tensor(output_index)
        prediction_digits.append(np.argmax(output[0]))

    prediction_digits = np.array(prediction_digits)
    return (prediction_digits == np.asarray(labels)).mean()

# Load the in-memory TFLite model and allocate its tensors before inference.
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
interpreter.allocate_tensors()
test_accuracy = evaluate_model(interpreter)

# Side-by-side comparison: float Keras model, quantization-aware Keras model,
# and the fully quantized TFLite model.
for caption, score in (
    ('TF accuracy:', baseline_model_accuracy),
    ('Quant TF accuracy:', q_aware_model_accuracy),
    ('Quant TFLite accuracy:', test_accuracy),
):
    print(caption, score)

TF accuracy: 0.9577000141143799
Quant TF accuracy: 0.9545000195503235
Quant TFLite accuracy: 0.9545
