In [None]:
import os

# Run the notebook from the project root so relative paths (checkpoints/, data/, src/)
# resolve. The root can be overridden with the HOLOGRAM_SEG_ROOT environment variable so
# the notebook is not tied to one machine; the original path is kept as the default.
PROJECT_ROOT = os.environ.get('HOLOGRAM_SEG_ROOT', '/Users/renjie.tan/Workspace/hologram_seg')
os.chdir(PROJECT_ROOT)
os.getcwd()

In [None]:
import logging
logging.getLogger("tensorflow").setLevel(logging.DEBUG)


import tensorflow as tf
import numpy as np

# Compare (major, minor) as integers. Slicing the first three characters and casting to
# float misreads two-digit minors: "2.10.0"[:3] == "2.1", which would fail a >= 2.3 check.
_tf_major_minor = tuple(int(part) for part in tf.__version__.split(".")[:2])
assert _tf_major_minor >= (2, 3), "TensorFlow >= 2.3 is required, found " + tf.__version__

In [None]:
from src.train.model_utils import load_model

In [None]:
# Load the trained model from its checkpoint directory and print its layer summary.
# tf_model is the source for every TFLite conversion in this notebook.
tf_model = load_model('checkpoints/unet_try/')
tf_model.summary()

## TFLite model but still using 32-bit float values for parameter data

In [None]:
# converter = tf.lite.TFLiteConverter.from_saved_model('checkpoints/unet_try/')
# converter = tf.lite.TFLiteConverter.from_keras_model(tf_model)
# tflite_model = converter.convert()
# open('checkpoints/unet_try' + "/unet_from_keras.tflite", "wb").write(tflite_model)

## TFLite conversion using dynamic range quantization

In [None]:
# Dynamic range quantization: only Optimize.DEFAULT is set, no representative dataset.
converter = tf.lite.TFLiteConverter.from_keras_model(tf_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# The converted model is later written to a .tflite file in binary mode.
tflite_model_quant = converter.convert()

In [None]:
# Write through a context manager so the file is flushed and closed deterministically
# (the bare open(...).write(...) never closed its handle).
with open('checkpoints/unet_try' + "/unet_dynamic_range_quant.tflite", "wb") as model_file:
    n_bytes_written = model_file.write(tflite_model_quant)
# Show the model size in bytes, as the original bare write() expression did.
n_bytes_written

## TFLite conversion using float16 quantization

In [None]:
# Float16 quantization: Optimize.DEFAULT plus tf.float16 as the only supported type.
converter = tf.lite.TFLiteConverter.from_keras_model(tf_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
# The converted model is later written to a .tflite file in binary mode.
tflite_fp16_model = converter.convert()

In [None]:
# Write through a context manager so the file is flushed and closed deterministically
# (the bare open(...).write(...) never closed its handle).
with open('checkpoints/unet_try' + "/unet_f16_quant.tflite", "wb") as model_file:
    n_bytes_written = model_file.write(tflite_fp16_model)
# Show the model size in bytes, as the original bare write() expression did.
n_bytes_written

## Convert using float fallback quantization
To quantize the variable data (such as model input/output and intermediates between layers), you need to provide a RepresentativeDataset. This is a generator function that provides a set of input data that's large enough to represent typical values. It allows the converter to estimate a dynamic range for all the variable data. 

The dataset does not need to be unique compared to the training or evaluation dataset.

To support multiple inputs, each representative data point is a list, and the elements of the list are fed to the model according to their indices.

In [None]:
from src.train.dataloader import HologramDataGenerator

In [None]:
# Unshuffled, un-augmented generator yielding one 128x128 sample per batch; it is used
# both for converter calibration and for picking a test sample.
holo_gen = HologramDataGenerator(
    data_dir='data/cvat_hologram',
    input_shape=(128, 128),
    batch_size=1,
    augmentations=None,
    shuffle=False,
)

In [None]:
def representative_data_gen():
    """Yield calibration samples for the TFLite converter.

    Iterates over the module-level ``holo_gen`` and yields each input batch as a
    one-element list holding a float32 array. The target half of every
    ``(input, target)`` pair is ignored.
    """
    for images, _masks in holo_gen:
        calibration_sample = np.array(images, dtype=np.float32)
        yield [calibration_sample]

In [None]:
# Float-fallback quantization: Optimize.DEFAULT plus a representative dataset for
# calibration. No op-set or input/output type restriction is set on this converter.
converter = tf.lite.TFLiteConverter.from_keras_model(tf_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen

# The converted model is later written to a .tflite file in binary mode.
tflite_floatfallback_model = converter.convert()

In [None]:
# Write through a context manager so the file is flushed and closed deterministically
# (the bare open(...).write(...) never closed its handle).
with open('checkpoints/unet_try' + "/unet_floatfallback.tflite", "wb") as model_file:
    n_bytes_written = model_file.write(tflite_floatfallback_model)
# Show the model size in bytes, as the original bare write() expression did.
n_bytes_written

## Convert using integer-only quantization

In [None]:
# Integer-only quantization: same calibration setup as the float-fallback conversion,
# with the op set restricted to int8 builtins and uint8 model input/output.
converter = tf.lite.TFLiteConverter.from_keras_model(tf_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

# The converted model is later written to a .tflite file in binary mode.
tflite_intonly_model = converter.convert()

In [None]:
# Write through a context manager so the file is flushed and closed deterministically
# before the interpreter loads it from disk (the bare open(...).write(...) never closed
# its handle).
with open('checkpoints/unet_try' + "/unet_intonly.tflite", "wb") as model_file:
    n_bytes_written = model_file.write(tflite_intonly_model)
# Show the model size in bytes, as the original bare write() expression did.
n_bytes_written

# Run TF Lite models
https://www.tensorflow.org/lite/performance/post_training_float16_quant

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Load the integer-only model from disk into a TFLite interpreter and allocate its tensors.
interpreter = tf.lite.Interpreter(model_path="checkpoints/unet_try/unet_intonly.tflite")
# interpreter = tf.lite.Interpreter(model_path="checkpoints/unet_try/unet.tflite")
interpreter.allocate_tensors()

# Keep only the first entry of the input and output detail lists.
input_details, output_details = (
    interpreter.get_input_details()[0],
    interpreter.get_output_details()[0],
)


In [None]:
# Take the batch at index 40 of the generator as the test sample (input, target).
batch_x, batch_y = holo_gen[40]

In [None]:
# Inspect the index used to fetch the output tensor from the interpreter.
output_details['index']

In [None]:
# Inspect the input tensor metadata (its 'dtype', 'index' and 'quantization' are used below).
input_details

In [None]:
# Inspect the output tensor metadata (its 'index' is used to read the result below).
output_details

In [None]:
# Quantize the float batch when the model expects integer input (uint8 or int8).
# NOTE: this overwrites batch_x, so reload the batch before re-running this cell;
# otherwise the data is scaled twice.
input_dtype = input_details['dtype']
input_scale, input_zero_point = input_details["quantization"]
# A scale of 0 means the tensor carries no quantization parameters; skip to avoid
# dividing by zero.
if np.issubdtype(input_dtype, np.integer) and input_scale != 0:
    print('input is integer quantize')
    dtype_limits = np.iinfo(input_dtype)
    # q = round(x / scale) + zero_point, clipped to the dtype range so the later
    # astype() neither truncates fractions nor wraps out-of-range values.
    batch_x = np.clip(
        np.round(batch_x / input_scale) + input_zero_point,
        dtype_limits.min,
        dtype_limits.max,
    )

In [None]:
# Cast the prepared batch to the dtype the interpreter expects and run one inference.
test_batch = batch_x.astype(input_details['dtype'])
interpreter.set_tensor(input_details["index"], test_batch)
interpreter.invoke()
# [0] drops the leading batch axis of the output tensor.
raw_output = interpreter.get_tensor(output_details["index"])[0]

# An integer-quantized model returns integer codes, not real values. Map them back with
# real = (q - zero_point) * scale so downstream post-processing sees values on the same
# scale as the float model's output.
# NOTE(review): assumes create_mask_from_prediction expects float predictions - confirm.
output_scale, output_zero_point = output_details["quantization"]
if np.issubdtype(output_details['dtype'], np.integer) and output_scale != 0:
    output = (raw_output.astype(np.float32) - output_zero_point) * output_scale
else:
    # Float model (or no quantization parameters): use the output unchanged.
    output = raw_output

In [None]:
# Display the first sample of the batch fed to the interpreter, cast to uint8 for imshow.
plt.imshow(test_batch[0].astype('uint8'))

In [None]:
from src.predictions.utils import create_mask_from_prediction

In [None]:
# Re-add a leading batch axis before handing the prediction to the mask helper.
mask = create_mask_from_prediction(output[np.newaxis, ...])

In [None]:
# Display the mask produced from the TFLite model's prediction.
plt.imshow(mask)