In [None]:
from pathlib import Path

import tensorflow as tf
import tensorflow_datasets as tfds
from keras.applications import resnet50
from keras.src.applications import ResNet50
from tqdm.notebook import tqdm

from resnet50.utils import split_functional_model

In [None]:
def normalize_img(img, lbl):
    """Prepare one (image, label) pair for the Keras ResNet50 ImageNet model.

    Despite the function name, this does not scale pixels to [0, 1]. It:

    1. Letterboxes ``img`` to 224x224 with ``tf.image.resize_with_pad``: the
       aspect ratio is preserved and the remaining area is filled with zeros.
       The resized result is float32.
    2. Applies ``resnet50.preprocess_input`` (Keras "caffe" mode): converts
       RGB to BGR and subtracts the per-channel ImageNet mean, without scaling.

    Args:
        img: Image tensor, presumably HxWx3 uint8 RGB as yielded by TFDS
            ``imagenet2012`` with ``as_supervised=True`` -- TODO confirm.
        lbl: Class label; passed through unchanged.

    Returns:
        Tuple ``(preprocessed_img, lbl)``.
    """
    img = tf.image.resize_with_pad(img, 224, 224)
    # NOTE(review): padding happens before mean subtraction, so padded pixels
    # end up at minus the ImageNet channel means rather than at zero.
    img = resnet50.preprocess_input(img)
    return img, lbl

In [None]:
# Load the ImageNet 2012 validation split as (image, label) pairs together with
# the dataset metadata, then run normalize_img (letterbox resize + ResNet50
# preprocessing) on every element with tf.data-chosen parallelism.
raw_validation_ds, metadata = tfds.load(
    'imagenet2012',
    split='validation',
    as_supervised=True,
    with_info=True,
)
validation_ds = raw_validation_ds.map(
    normalize_img,
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Converts an integer class id into its human-readable class name.
get_label_name = metadata.features['label'].int2str

In [None]:
def representative_dataset(number_of_samples=100):
    """Yield calibration inputs for full-integer post-training quantization.

    Assigned to ``TFLiteConverter.representative_dataset``. The converter
    invokes it with no arguments and feeds each yielded list of input tensors
    through the float model to estimate activation ranges.

    Args:
        number_of_samples: How many single-image batches to yield. The TFLite
            post-training quantization guidance suggests roughly 100 to 500.

    Yields:
        A one-element list containing a batch of one preprocessed image taken
        from ``validation_ds``.
    """
    # NOTE(review): calibration uses the first images of the validation split;
    # if that split is also used to measure accuracy, consider calibrating on a
    # training-split subset instead -- confirm intent.
    for image_batch, _label_batch in validation_ds.batch(1).take(number_of_samples):
        yield [image_batch]

In [None]:
def quantize_and_save_model(model, name, output_dir="models/head"):
    """Convert a Keras model to a full-integer (int8) TFLite model and save it.

    Post-training quantization is calibrated with ``representative_dataset``.
    Only int8 builtin ops are allowed, and the model's input and output tensors
    are int8 as well.

    Args:
        model: Keras model to convert.
        name: File stem of the output; formatted with ``str`` semantics, so
            both ints and strings work (e.g. ``5`` -> ``5.tflite``).
        output_dir: Directory that receives the ``.tflite`` file. It is
            created (including parents) if it does not exist yet.

    Returns:
        ``pathlib.Path`` of the written ``.tflite`` file.
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = representative_dataset
    # Int8-only builtins: the converter raises instead of silently keeping
    # float ops it cannot quantize.
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
    converter.inference_input_type = tf.int8  # or tf.uint8
    converter.inference_output_type = tf.int8  # or tf.uint8
    tflite_model = converter.convert()

    # Convert first and only then touch the filesystem, so a failed conversion
    # leaves nothing behind. open() would fail if the directory were missing.
    output_path = Path(output_dir) / f"{name}.tflite"
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_bytes(tflite_model)
    return output_path

In [None]:
# ResNet50 with ImageNet weights, taken from the public keras.applications.resnet50
# module, the same module whose preprocess_input is used in normalize_img.
# This avoids depending on the private keras.src package layout.
model = resnet50.ResNet50(weights='imagenet')

In [None]:
print("Save full models")
# Split indices 0 and 40 both stand for the unsplit network: head "40" is the
# whole model quantized to TFLite, and tail "0" is the whole unquantized Keras model.
tail_dir = Path("models/tail")
tail_dir.mkdir(parents=True, exist_ok=True)  # make sure the output directory exists
quantize_and_save_model(model, 40)  # same file name as before: <head dir>/40.tflite
model.save(str(tail_dir / "0.keras"))

In [None]:
print("Save partial models")
# Only the interior split points 1..39 are generated here; indices 0 and 40
# are the unsplit model and are saved separately.
tail_dir = Path("models/tail")
tail_dir.mkdir(parents=True, exist_ok=True)  # make sure the output directory exists
for split_index in tqdm(range(1, 40), desc="splitting"):
    # assumes split_functional_model returns (head, tail) Keras sub-models cut
    # at split_index -- see resnet50.utils
    head, tail = split_functional_model(model, split_index)
    quantize_and_save_model(head, split_index)  # -> models/head/<split_index>.tflite
    tail.save(str(tail_dir / f"{split_index}.keras"))