In [None]:
import os
import sys

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tqdm import tqdm

In [None]:
# Notebook-wide settings.
# INPUT_SIZE: side length (pixels) that images are resized to before inference.
# SAVED_MODEL_DIR: directory the SavedModel is loaded from for every conversion.
# OUTPUT_DIR: directory every converted .tflite file is written to.
INPUT_SIZE = 224
SAVED_MODEL_DIR = './saved_model'
OUTPUT_DIR = './tflite_model'

In [None]:
# Make sure the output directory exists before any model file is written;
# exist_ok=True keeps a re-run from raising when it is already there.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Load the tf_flowers 'train' split as two slices: the first 90% and the
# remaining 10%. The 10% slice (`val_data_ds`) feeds the evaluation pipeline
# built further down; `metadata` is the dataset info returned by with_info=True.
(train_data_ds, val_data_ds), metadata = tfds.load(
    name='tf_flowers',
    split=['train[:90%]', 'train[90%:]'],
    with_info=True,
    as_supervised=True,
)

In [None]:
def center_crop(image, label):
    """Cast `image` to float32 and cut out its largest centered square.

    The first two dimensions of `image` are treated as height and width.
    `label` is passed through untouched so the function can be mapped over
    (image, label) pairs.
    """
    image = tf.cast(image, tf.float32)
    dims = tf.shape(image)
    img_h = dims[0]
    img_w = dims[1]
    # Side of the square is the shorter edge; offsets center it on the longer one.
    side = tf.minimum(img_h, img_w)
    top = (img_h - side) // 2
    left = (img_w - side) // 2
    square = tf.image.crop_to_bounding_box(image, top, left, side, side)
    return square, label

def resize_and_rescale(image, label):
    """Resize `image` to INPUT_SIZE x INPUT_SIZE, then apply `preprocess_input`.

    `preprocess_input` is the MobileNetV2 preprocessing function imported at
    the top of the notebook. `label` is returned unchanged.
    """
    target_size = [INPUT_SIZE, INPUT_SIZE]
    resized = tf.image.resize(image, target_size)
    return preprocess_input(resized), label

# Evaluation pipeline: square center crop, then resize + preprocessing,
# delivered as batches of exactly one example.
val_ds = val_data_ds.map(center_crop).map(resize_and_rescale).batch(1)

In [None]:
def evaluate_tflite_model(model_path, val_ds):
    """Measure top-1 accuracy of a TFLite model over `val_ds`.

    The model's input/output quantization parameters are printed, inputs are
    quantized when the input tensor carries a non-zero scale, and outputs are
    dequantized when the output tensor carries a non-zero scale.

    Args:
        model_path: Path to the `.tflite` file to load.
        val_ds: Iterable of `(image, label)` tensor pairs. The prediction is
            the argmax over the whole output tensor, so every batch must hold
            exactly one example.

    Returns:
        Fraction of examples whose predicted class equals the label, or 0.0
        when `val_ds` yields nothing.
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()[0]
    input_scale, input_zero_point = input_details["quantization"]
    print(f"[input] dtype: {input_details['dtype']}, scale: {input_scale}, zero_point: {input_zero_point}")

    output_details = interpreter.get_output_details()[0]
    output_scale, output_zero_point = output_details["quantization"]
    print(f"[output] dtype: {output_details['dtype']}, scale: {output_scale}, zero_point: {output_zero_point}")

    input_dtype = input_details['dtype']
    # A scale of 0 is how the interpreter reports an unquantized (float) tensor.
    input_is_quantized = input_scale != 0
    output_is_quantized = output_scale != 0

    count, acc = 0, 0
    # The context manager closes the progress bar even if inference raises.
    with tqdm(val_ds, ascii=True, file=sys.stdout) as loop:
        for image, label in loop:
            image = image.numpy()
            if input_is_quantized:
                # real -> quantized: q = round(x / scale) + zero_point.
                # Round instead of letting astype() truncate, and clip so
                # values just outside the range cannot wrap around in the cast.
                image = np.round(image / input_scale) + input_zero_point
                if np.issubdtype(input_dtype, np.integer):
                    limits = np.iinfo(input_dtype)
                    image = np.clip(image, limits.min, limits.max)
            image = image.astype(input_dtype)

            interpreter.set_tensor(input_details['index'], image)
            interpreter.invoke()
            predictions = interpreter.get_tensor(output_details['index'])
            if output_is_quantized:
                # quantized -> real: x = (q - zero_point) * scale.
                predictions = (predictions.astype(np.float32) - output_zero_point) * output_scale

            count += 1
            if np.argmax(predictions) == label:
                acc += 1
            loop.set_postfix(acc="{:.4f}".format(acc / count))

    # Guard against an empty dataset instead of dividing by zero.
    return acc / count if count else 0.0

## No quantization

In [None]:
# Baseline conversion: no optimizations are set on the converter, so this file
# is the reference the quantized variants below are compared against.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
tflite_model = converter.convert()

# `model_path` is read by the evaluation cell that follows.
model_path = os.path.join(OUTPUT_DIR, 'model.tflite')
with open(model_path, 'wb') as f:
    f.write(tflite_model)

In [None]:
# Score the unquantized model on val_ds; `model_path` comes from the cell above.
evaluate_tflite_model(model_path, val_ds)

[input] dtype: <class 'numpy.float32'>, scale: 0.0, zero_point: 0
[output] dtype: <class 'numpy.float32'>, scale: 0.0, zero_point: 0
367it [00:20, 18.08it/s, acc=0.9755]


0.9754768392370572

## Float16 quantization

In [None]:
# Float16 variant: default optimizations with float16 as the only entry in
# target_spec.supported_types.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_model = converter.convert()

# `model_path` is read by the evaluation cell that follows.
model_path = os.path.join(OUTPUT_DIR, 'model_float16.tflite')
with open(model_path, 'wb') as f:
    f.write(tflite_model)

In [None]:
# Score the float16 model on val_ds. In the recorded run its input and output
# tensors are still reported as float32 with scale 0.
evaluate_tflite_model(model_path, val_ds)

[input] dtype: <class 'numpy.float32'>, scale: 0.0, zero_point: 0
[output] dtype: <class 'numpy.float32'>, scale: 0.0, zero_point: 0
367it [00:20, 18.31it/s, acc=0.9755]


0.9754768392370572

## Dynamic range quantization

In [None]:
# Dynamic range variant: default optimizations only, with no representative
# dataset and no restriction on supported types or ops.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# `model_path` is read by the evaluation cell that follows.
model_path = os.path.join(OUTPUT_DIR, 'model_default.tflite')
with open(model_path, 'wb') as f:
    f.write(tflite_model)

In [None]:
# Score the dynamic-range model on val_ds.
# NOTE(review): the recorded run shows 0.8856 here versus 0.9755 for every
# other variant in this notebook; worth investigating before relying on it.
evaluate_tflite_model(model_path, val_ds)

[input] dtype: <class 'numpy.float32'>, scale: 0.0, zero_point: 0
[output] dtype: <class 'numpy.float32'>, scale: 0.0, zero_point: 0
367it [00:41,  8.83it/s, acc=0.8856]


0.885558583106267

## Integer quantization with float fallback

In [None]:
def representative_dataset_gen(num_samples=100):
    """Yield calibration inputs for post-training integer quantization.

    Samples are drawn from the training split (`train_data_ds`) rather than
    from `val_ds`, so the images used for calibration are not the same ones
    the quantized models are later scored on. They go through the same
    center-crop / resize / preprocess steps as the evaluation pipeline.

    Args:
        num_samples: Number of single-image batches to yield. Defaults to 100,
            so the converter can still call this function with no arguments.

    Yields:
        A one-element list holding a batch of one preprocessed image.
    """
    calibration_ds = (
        train_data_ds
        .map(center_crop)
        .map(resize_and_rescale)
        .batch(1)
        .take(num_samples)
    )
    for image, _ in calibration_ds:
        yield [image]

In [None]:
# Integer quantization with float fallback: default optimizations plus a
# representative dataset. No supported_ops restriction and no inference
# input/output type is set on the converter.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

converter.representative_dataset = representative_dataset_gen
tflite_model = converter.convert()

# `model_path` is read by the evaluation cell that follows.
model_path = os.path.join(OUTPUT_DIR, 'model_int8.tflite')
with open(model_path, 'wb') as f:
    f.write(tflite_model)

In [None]:
# Score the integer-with-float-fallback model on val_ds. The recorded run
# reports float32 input/output tensors and takes roughly 1.9 s per image.
evaluate_tflite_model(model_path, val_ds)

[input] dtype: <class 'numpy.float32'>, scale: 0.0, zero_point: 0
[output] dtype: <class 'numpy.float32'>, scale: 0.0, zero_point: 0
367it [11:38,  1.90s/it, acc=0.9755]


0.9754768392370572

## Integer-only quantization

In [None]:
# Integer-only variant: default optimizations plus a representative dataset,
# with supported ops restricted to TFLITE_BUILTINS_INT8 and both the model
# input and output types set to uint8.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

tflite_model = converter.convert()

# `model_path` is read by the evaluation cell that follows.
model_path = os.path.join(OUTPUT_DIR, 'model_int8_only.tflite')
with open(model_path, 'wb') as f:
    f.write(tflite_model)

In [None]:
# Score the integer-only model on val_ds. This is the one variant whose
# recorded input/output tensors are uint8 with a non-zero scale, so it is the
# only run that goes through evaluate_tflite_model's quantize/dequantize branches.
evaluate_tflite_model(model_path, val_ds)

[input] dtype: <class 'numpy.uint8'>, scale: 0.007843137718737125, zero_point: 127
[output] dtype: <class 'numpy.uint8'>, scale: 0.00390625, zero_point: 0
367it [11:39,  1.91s/it, acc=0.9755]


0.9754768392370572