In [1]:
import os

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from tqdm import tqdm

In [2]:
# Directory of the TensorFlow SavedModel that each converter below loads.
SAVED_MODEL_DIR = './saved_model'
# Directory that receives the generated .tflite files.
OUTPUT_DIR = './tflite_model'
# Side length, in pixels, of the square that center_crop_and_resize resizes images to.
INPUT_SIZE = 224
# Passed to the tf.data map/prefetch calls so TensorFlow picks the parallelism/buffer size.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [3]:
# Create the output directory up front; exist_ok=True keeps this cell safe to re-run.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [4]:
# tf_flowers only ships a 'train' split, so carve it 90/10 into train/validation.
# as_supervised=True makes each element an (image, label) pair, which is the
# signature center_crop_and_resize expects; with_info=True also returns metadata.
(train_data_ds, val_data_ds), metadata = tfds.load(
    'tf_flowers',
    split=['train[:90%]', 'train[90%:]'],
    as_supervised=True,
    with_info=True,
)

In [5]:
def center_crop_and_resize(image, label):
    """Turn one raw image into a model-ready square input.

    The image is cast to float32, cropped to the largest centered square,
    resized to ``INPUT_SIZE`` x ``INPUT_SIZE`` and finally passed through the
    MobileNetV2 ``preprocess_input`` function.

    Args:
        image: Image tensor whose first two axes are height and width.
        label: Label for the image; passed through untouched.

    Returns:
        A ``(preprocessed_image, label)`` tuple.
    """
    pixels = tf.cast(image, tf.float32)

    dims = tf.shape(pixels)
    rows, cols = dims[0], dims[1]
    side = tf.minimum(rows, cols)

    # Offsets that center the square crop along each axis.
    top = (rows - side) // 2
    left = (cols - side) // 2

    square = tf.image.crop_to_bounding_box(pixels, top, left, side, side)
    resized = tf.image.resize(square, [INPUT_SIZE, INPUT_SIZE])
    return preprocess_input(resized), label

# Evaluation pipeline: preprocess every validation image, then emit them one
# per batch because evaluate_tflite_model feeds the interpreter a single
# sample at a time.
val_ds = (
    val_data_ds
    .map(map_func=center_crop_and_resize, num_parallel_calls=AUTOTUNE)
    .batch(batch_size=1)
    .prefetch(buffer_size=AUTOTUNE)
)

In [6]:
def evaluate_tflite_model(model_path, val_ds):
    """Measure top-1 accuracy of a TFLite model over a labelled dataset.

    Args:
        model_path: Path of the ``.tflite`` file to load into the interpreter.
        val_ds: Iterable of ``(data, label)`` pairs. ``data`` is written to
            the model's first input tensor unchanged, so it must already have
            the shape and dtype that tensor expects. ``label`` holds the
            expected class index for each row of ``data``.

    Returns:
        The fraction of samples whose arg-max prediction equals the label,
        or ``0.0`` when ``val_ds`` yields no samples.
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']

    total, correct = 0, 0
    # The context manager closes the progress bar even if inference raises.
    with tqdm(val_ds, ascii=True) as loop:
        for data, label in loop:
            interpreter.set_tensor(input_index, data)
            interpreter.invoke()
            predictions = interpreter.get_tensor(output_index)

            # Compare as flat NumPy arrays rather than relying on the
            # truthiness of a one-element tensor.
            predicted = np.argmax(predictions, axis=-1).ravel()
            expected = np.asarray(label).ravel()
            correct += int(np.count_nonzero(predicted == expected))
            total += expected.size

            # max(total, 1) avoids dividing by zero before any label is seen.
            loop.set_postfix(acc="{:.4f}".format(correct / max(total, 1)))

    # An empty dataset reports 0.0 instead of raising ZeroDivisionError.
    return correct / max(total, 1)

## No quantization

In [7]:
# Baseline: convert the SavedModel with no converter optimizations set, so
# the later quantized variants have an accuracy figure to compare against.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
tflite_model = converter.convert()

# Persist the serialized model; the evaluation cell reloads it from this path.
model_path = os.path.join(OUTPUT_DIR, 'model_full.tflite')
with open(model_path, mode='wb') as model_file:
    model_file.write(tflite_model)

In [8]:
# Accuracy of the unquantized model. `model_path` is a shared global that each
# conversion cell reassigns, so this scores whichever model was written last.
evaluate_tflite_model(model_path, val_ds)

367it [00:15, 23.09it/s, acc=0.9755]


0.9754768392370572

## Float16 quantization

In [9]:
# Float16 variant: default optimizations, with float16 declared as the
# supported type on the target spec.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.target_spec.supported_types = [tf.float16]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Persist the serialized model; the evaluation cell reloads it from this path.
model_path = os.path.join(OUTPUT_DIR, 'model_float16.tflite')
with open(model_path, mode='wb') as model_file:
    model_file.write(tflite_model)

In [10]:
# Accuracy of the float16 model. `model_path` is a shared global that each
# conversion cell reassigns, so this scores whichever model was written last.
evaluate_tflite_model(model_path, val_ds)

367it [00:15, 23.00it/s, acc=0.9728]


0.9727520435967303

## Dynamic range quantization

In [11]:
# Dynamic range variant: only the default optimization flag is set. No
# representative dataset and no supported-type override are given.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Persist the serialized model; the evaluation cell reloads it from this path.
model_path = os.path.join(OUTPUT_DIR, 'model_default.tflite')
with open(model_path, mode='wb') as model_file:
    model_file.write(tflite_model)

In [12]:
# Accuracy of the dynamic-range model. `model_path` is a shared global that each
# conversion cell reassigns, so this scores whichever model was written last.
evaluate_tflite_model(model_path, val_ds)

367it [00:34, 10.69it/s, acc=0.7384]


0.7384196185286104

## Integer quantization with float fallback

In [13]:
# Number of images handed to the converter for calibration. Used once as the
# batch size so the calibration set can be pulled out with a single next().
NUM_CALIBRATION_SAMPLES = 100

repr_ds = (
    val_data_ds
    .map(center_crop_and_resize, num_parallel_calls=AUTOTUNE)
    .batch(NUM_CALIBRATION_SAMPLES)
    .prefetch(AUTOTUNE)
)

# Materialize the first batch of images once; the labels are not needed.
# Indexing instead of unpacking avoids binding `_`, which IPython uses for the
# last cell output.
batch_image = next(iter(repr_ds))[0]


def representative_dataset_gen():
    """Yield the calibration images one at a time.

    Yields:
        A one-element list holding ``batch_image[i:i + 1]``: the i-th image
        with its leading batch axis kept.
    """
    # Bound the loop by the batch's real length. The batch is shorter than
    # NUM_CALIBRATION_SAMPLES when the dataset has fewer elements, and slicing
    # past the end would yield empty tensors.
    for i in range(int(batch_image.shape[0])):
        yield [batch_image[i:i + 1]]

In [14]:
# Integer quantization: default optimizations plus a representative dataset
# generator. No supported-ops restriction is set on the target spec.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_DIR)
converter.representative_dataset = representative_dataset_gen
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Persist the serialized model; the evaluation cell reloads it from this path.
model_path = os.path.join(OUTPUT_DIR, 'model_int8.tflite')
with open(model_path, mode='wb') as model_file:
    model_file.write(tflite_model)

In [15]:
# Accuracy of the integer-quantized model. `model_path` is a shared global that each
# conversion cell reassigns, so this scores whichever model was written last.
evaluate_tflite_model(model_path, val_ds)

367it [12:08,  1.99s/it, acc=0.9646]


0.9645776566757494