In [None]:
# ===========================================
# == Training image classification network ==
# ===========================================

# https://www.tensorflow.org/lite/tutorials/model_maker_image_classification
# with incremental changes

# some ideas here for customising to use checkpoints:
# https://stackoverflow.com/questions/69444878/how-to-continue-training-with-checkpoints-using-object-detector-efficientdetlite

# available model_specs:
# https://github.com/tensorflow/examples/blob/master/tensorflow_examples/lite/model_maker/core/task/model_spec/image_spec.py#L29-L59

# Ensure the notebook kernel is the one from the project's virtual environment.

# Note: if you get the error below, the cached TF Hub modules in the temp folder have
# probably been tidied away; delete that tfhub_modules folder and try again.
# OSError: SavedModel file does not exist at: /var/folders/bt/pk67cdmj12l8t7d5zbkx4hbr0000gn/T/tfhub_modules
        
import os

import numpy as np

import tensorflow as tf
assert tf.__version__.startswith('2')

from tflite_model_maker import model_spec
from tflite_model_maker import image_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.image_classifier import DataLoader

# Extracted functions to enable continued training; copied from:
# https://github.com/tensorflow/examples/blob/master/tensorflow_examples/lite/model_maker/core/task/image_classifier.py
# https://github.com/tensorflow/examples/blob/master/tensorflow_examples/lite/model_maker/core/task/train_image_classifier_lib.py
from tensorflow_examples.lite.model_maker.core.task import model_util

import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

import itertools
import sys
import pathlib
import shutil
import re
import io
import time

import PIL

import pandas as pd

print("TF version:", tf.__version__)

# Optional Apple (Mac-only) feature to use ML Compute resources; left disabled.
# Uncomment to try it:
# try:
#     from tensorflow.python.compiler.mlcompute import mlcompute
#     mlcompute.set_mlc_device(device_name='any')
#     print('Using Apple MLCompute')
# except ImportError:
#     print('Apple MLCompute not found')

# Only let ERROR-level (and above) TensorFlow log records through.
tf.get_logger().setLevel('ERROR')

In [None]:
# ---- Training configuration ----
TRAINING_EPOCHS = 2       # total epochs per model (split into loops of epochs_between_evals)
EXPORT_MODEL    = True    # write <model>.tflite and LABEL_FILENAME into OUTPUT_DIR
RUN_TRAINING    = True
RUN_TEST        = True
NUM_TEST_IMAGES = 2       # None = evaluate the whole test split

# ---- Model selection ----
# Known specs: mobilenet_v2, resnet_50, efficientnet_lite0 .. efficientnet_lite4.
# The larger EfficientNet-Lite variants appear to be intended for inputs larger than 224x224.
# True  -> train/compare every name in MODEL_LIST
# False -> train only MODEL_SPEC_NAME
PROCESS_LIST_OF_MODELS = True

if PROCESS_LIST_OF_MODELS:
    ALL_MODELS = ['mobilenet_v2', 'resnet_50',
                  'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2',
                  'efficientnet_lite3', 'efficientnet_lite4']
    FEW_MODELS = ['mobilenet_v2', 'efficientnet_lite0', 'efficientnet_lite4']
    TWO_MODELS = ['mobilenet_v2', 'efficientnet_lite0']
    ONE_MODEL  = ['efficientnet_lite0']

    MODEL_LIST = ALL_MODELS
else:
    MODEL_SPEC_NAME = 'efficientnet_lite1'

OUTPUT_DIR     = '.'
LABEL_FILENAME = 'labels.txt'

In [None]:
# Fetch the TensorFlow flowers example archive (Keras caches the download) and unpack it.
image_path = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True,
)

# 80% training; the remaining 20% is halved into validation (10%) and test (10%).
data = DataLoader.from_folder(image_path)
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)

print("Show a few example images")
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(data.gen_dataset().unbatch().take(25)):
    plt.subplot(5, 5, i + 1)
    # hide ticks and grid so only the image and its class name show
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image.numpy(), cmap=plt.cm.gray)
    plt.xlabel(data.index_to_label[label.numpy()])
plt.show()

In [None]:
def build_model(model_spec_name, epochs_between_evals, validation_data, do_data_augmentation,
                do_train=False, training_data=None):
    """Create a Model Maker image classifier for ``model_spec_name``.

    Args:
        model_spec_name: name passed to ``model_spec.get`` (e.g. 'efficientnet_lite0').
        epochs_between_evals: forwarded as ``epochs`` to ``image_classifier.create``
            (presumably becomes the hparams' ``train_epochs`` used by model_train_local).
        validation_data: DataLoader forwarded to ``image_classifier.create``.
        do_data_augmentation: forwarded as ``use_augmentation``.
        do_train: forwarded as ``do_train``; callers in this notebook pass False and
            drive training themselves.
        training_data: DataLoader the classifier is created from. When omitted, the
            notebook-level ``train_data`` split is used, so existing callers behave as before.

    Returns:
        The object returned by ``image_classifier.create``.
    """
    if training_data is None:
        # Fallback keeps the original implicit dependency on the global split.
        training_data = train_data
    return image_classifier.create(training_data,
                                   model_spec=model_spec.get(model_spec_name),
                                   epochs=epochs_between_evals,
                                   validation_data=validation_data,
                                   use_augmentation=do_data_augmentation,
                                   do_train=do_train)
    
def prepare_model_for_training(model, train_data, validation_data=None, hparams=None, steps_per_epoch=None):
    """Build the tf.data pipelines for ``model`` and compile its Keras model.

    The Keras model itself is assumed to exist already (built by ``build_model``).

    Args:
        model: Model Maker image classifier exposing ``model`` (Keras), ``shuffle``,
            ``preprocess`` and ``_get_hparams_or_default``.
        train_data: DataLoader for training.
        validation_data: optional DataLoader for validation.
        hparams: optional hyper-parameters; the model's own are used when None.
        steps_per_epoch: optional cap on training batches; resolved through
            ``model_util.get_steps_per_epoch`` and applied with ``take``.

    Returns:
        ``(train_ds, validation_ds)``; ``validation_ds`` is None without validation_data.
    """
    hparams = model._get_hparams_or_default(hparams)

    train_ds = train_data.gen_dataset(
        hparams.batch_size,
        is_training=True,
        shuffle=model.shuffle,
        preprocess=model.preprocess)

    steps_per_epoch = model_util.get_steps_per_epoch(steps_per_epoch,
                                                     hparams.batch_size,
                                                     train_data)
    if steps_per_epoch is not None:
        train_ds = train_ds.take(steps_per_epoch)

    validation_ds = None
    if validation_data is not None:
        validation_ds = validation_data.gen_dataset(
            hparams.batch_size, is_training=False, preprocess=model.preprocess)

    loss = tf.keras.losses.CategoricalCrossentropy(
        label_smoothing=hparams.label_smoothing)

    # `learning_rate` is the supported keyword; the old `lr` alias is deprecated in
    # tf.keras optimizers and rejected by newer Keras releases.
    model.model.compile(
        optimizer=tf.keras.optimizers.SGD(
            learning_rate=hparams.learning_rate, momentum=hparams.momentum),
        loss=loss,
        metrics=["accuracy"])

    return train_ds, validation_ds

def model_train_local(model, train_ds, validation_ds=None, hparams=None, steps_per_epoch=None):
    """Run one ``fit`` call on the wrapped Keras model and keep its History on ``model.history``.

    The epoch count comes from the resolved hyper-parameters' ``train_epochs``;
    ``hparams=None`` falls back to the model's defaults. Returns nothing.
    """
    resolved_hparams = model._get_hparams_or_default(hparams)
    fit_options = dict(epochs=resolved_hparams.train_epochs,
                       steps_per_epoch=steps_per_epoch,
                       validation_data=validation_ds)
    model.history = model.model.fit(train_ds, **fit_options)
    
def train_model(model, total_training_epochs, epochs_between_evals, steps_per_epoch=None,
                train_ds=None, validation_ds=None, eval_data=None):
    """Train in loops of ``epochs_between_evals`` epochs, recording metrics after each loop.

    Validation metrics come from the ``val_loss`` / ``val_accuracy`` that Keras computes
    on ``validation_ds`` during ``fit``. They are only computed explicitly with
    ``model.evaluate`` when fit produced none (no ``validation_ds``). In that case
    ``eval_data`` is used, falling back to the notebook-level ``test_data``.

    Args:
        model: Model Maker classifier, already prepared by ``prepare_model_for_training``.
        total_training_epochs: total epochs; the loop count is
            ``int(total_training_epochs / epochs_between_evals)``.
        epochs_between_evals: must be positive.
        steps_per_epoch, train_ds, validation_ds: forwarded to ``model_train_local``.
        eval_data: optional DataLoader for the no-validation fallback.

    Returns:
        dict with list values under 'loss', 'accuracy', 'val_loss', 'val_accuracy',
        one entry per loop; the val_* entries are rounded to 2 decimal places.

    Raises:
        ValueError: if ``epochs_between_evals`` is not positive.
    """
    if epochs_between_evals <= 0:
        raise ValueError(
            'epochs_between_evals must be positive, got {!r}'.format(epochs_between_evals))
    num_training_loops = int(total_training_epochs / epochs_between_evals)

    history = {'loss': [], 'accuracy': [], 'val_loss': [], 'val_accuracy': []}
    for loop in range(num_training_loops):
        print("Loop: {} of {}".format(loop + 1, num_training_loops))

        model_train_local(model, train_ds, validation_ds=validation_ds, hparams=None,
                          steps_per_epoch=steps_per_epoch)
        hist = model.history.history  # per-epoch metric lists from the latest fit()

        # keep only the last epoch of this loop
        history['loss'].append(hist['loss'][-1])
        history['accuracy'].append(hist['accuracy'][-1])

        if 'val_loss' in hist and 'val_accuracy' in hist:
            # fit() already scored validation_ds; reuse it rather than scoring the test split.
            val_loss, val_accuracy = hist['val_loss'][-1], hist['val_accuracy'][-1]
        else:
            data_to_evaluate = test_data if eval_data is None else eval_data
            val_loss, val_accuracy = model.evaluate(data_to_evaluate)

        history['val_loss'].append(round(val_loss, 2))
        history['val_accuracy'].append(round(val_accuracy, 2))

    return history

def plot_training_metrics(model_name, hist, width=16, height=4):
    """Plot loss and accuracy curves (one point per training loop) for a single model.

    Args:
        model_name: shown as the figure title.
        hist: dict as returned by ``train_model`` with list values under
            'loss', 'val_loss', 'accuracy' and 'val_accuracy'.
        width, height: figure size in inches.
    """
    fig, (ax_loss, ax_acc) = plt.subplots(2, 1, figsize=(width, height))
    fig.subplots_adjust(hspace=0.7)
    fig.suptitle(model_name)

    ax_loss.plot(hist["loss"], label='Training loss')
    ax_loss.plot(hist["val_loss"], label='Validation loss')
    # default=0 avoids a ValueError on an empty history; only apply a positive bound,
    # since identical lower/upper limits make matplotlib warn.
    y_max = max(list(hist["loss"]) + list(hist["val_loss"]), default=0)
    if y_max > 0:
        ax_loss.set_ylim([0, y_max])
    ax_loss.set_ylabel("Loss")
    ax_loss.set_xlabel("Training Steps")
    ax_loss.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax_loss.legend(loc='upper right')
    ax_loss.set_title('Loss')

    ax_acc.plot(hist["accuracy"], label='Training accuracy')
    ax_acc.plot(hist["val_accuracy"], label='Validation accuracy')
    ax_acc.set_ylim([0, 1])
    ax_acc.set_ylabel("Accuracy")
    ax_acc.set_xlabel("Training Steps")
    ax_acc.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax_acc.legend(loc='lower right')
    ax_acc.set_title('Accuracy')

    plt.show()
    
def export_tflite(model, output_dir, model_filename, labels_filename):
    """Write ``model`` to ``output_dir`` as a TFLite file plus a label file."""
    requested_formats = [ExportFormat.TFLITE, ExportFormat.LABEL]
    model.export(
        export_dir=output_dir,
        tflite_filename=model_filename,
        label_filename=labels_filename,
        export_format=requested_formats,
    )

def load_tflite_model_and_labels(model_file, labels_file):
    """Read a .tflite file and its label file.

    Returns:
        ``(interpreter, label_names)``: a ``tf.lite.Interpreter`` built from the model
        bytes (tensors not yet allocated) and the label file split on '\\n'.
        A trailing newline yields a final empty string.
    """
    with tf.io.gfile.GFile(model_file, 'rb') as model_fh:
        model_bytes = model_fh.read()

    with tf.io.gfile.GFile(labels_file, 'r') as labels_fh:
        labels_text = labels_fh.read()

    tflite_interpreter = tf.lite.Interpreter(model_content=model_bytes)
    return tflite_interpreter, labels_text.split('\n')

def input_preprocess(image, model_input_details=None, mean_rgb=0.0, stddev_rgb=255.0):
    """Resize ``image`` to the interpreter's input size and convert it to the input dtype.

    Args:
        image: a single image tensor/array (no batch dimension).
        model_input_details: one entry of ``interpreter.get_input_details()``; its
            'shape' is unpacked as (batch, height, width, channels) and its 'dtype'
            selects the branch below. Required despite the None default.
        mean_rgb, stddev_rgb: float-model normalisation, ``(x - mean) / stddev``,
            broadcast over a [1, 1, 3] constant.

    Returns:
        float32 normalised image for float32 models, otherwise the resized image cast to uint8.
    """
    is_float_input = model_input_details['dtype'] == np.float32
    _, target_height, target_width, _ = model_input_details['shape']

    resized = tf.compat.v1.image.resize(image, [target_height, target_width])

    if not is_float_input:
        return tf.cast(resized, tf.uint8)

    as_float = tf.cast(resized, tf.float32)
    mean = tf.constant(mean_rgb, shape=[1, 1, 3], dtype=as_float.dtype)
    stddev = tf.constant(stddev_rgb, shape=[1, 1, 3], dtype=as_float.dtype)
    return (as_float - mean) / stddev

def test_model(model, interpreter, test_data, label_names, num_test_images=1, inf_time_ms_list=[]):
    print('Number of test images =', test_data.size)
    if num_test_images is None:
        num_test_images=test_data.size
    else:
        print("Limiting to:", num_test_images, " for test")

    model_input_details = interpreter.get_input_details()[0]
    interpreter.allocate_tensors()
    input_index = interpreter.get_input_details()[0]['index']
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
    
    # check the type of the input tensor
    floating_model = interpreter.get_input_details()[0]['dtype'] == np.float32

    # Run predictions on each test image data and calculate accuracy.
    accurate_count = 0
    total_count = 0
    for i, (image, label) in enumerate(test_data.gen_dataset().unbatch().take(num_test_images)):
        # Pre-process input image
        image = input_preprocess(image, model_input_details=model_input_details)
        input_data = np.expand_dims(image, axis=0) # add batch dimension

        # Run inference (with timing)
        interpreter.set_tensor(input_index, input_data)
        start_time = time.time()
        interpreter.invoke()
        finish_time = time.time()
        inference_time_ms = int((finish_time - start_time)*1000)
        inf_time_ms_list.append(inference_time_ms)

        # Post-processing: remove batch dimension and find the label with highest probability.
        predict_label = np.argmax(output()[0])
        # Get label name with label index.
        predict_label_name = label_names[predict_label]
        prediction_is_correct = (predict_label == label.numpy())
        if prediction_is_correct:
            format_string = '{:d}: Correct prediction ({:s} == {:s})'
        else:
            format_string = '{:d}: Incorrect: predicted {:s} but should have been {:s}'
        print(format_string.format(i+1, predict_label_name, data.index_to_label[label.numpy()]))

        accurate_count += prediction_is_correct
        total_count += 1

    accuracy = accurate_count * 1.0 / total_count
    print('TensorFlow Lite model accuracy = {:.1f}%'.format(accuracy*100))
    print('Inference time (ms): min={:,.0f}, mean={:,.0f}, max={:,.0f}'.format(min(inf_time_ms_list), np.mean(inf_time_ms_list), max(inf_time_ms_list)))

def _report_final_metrics(history):
    """Print the final training/validation metrics of a train_model history.

    Returns the final validation accuracy rounded to 2 decimal places.
    """
    last_train_acc = float(history['accuracy'][-1])
    last_val_acc = float(history['val_accuracy'][-1])
    last_train_loss = float(history['loss'][-1])
    last_val_loss = float(history['val_loss'][-1])
    print('Training:     accuracy={0:.2f}'.format(last_train_acc) + '; loss={0:.2f}'.format(last_train_loss))
    print('Verification: accuracy={0:.2f}'.format(last_val_acc) + '; loss={0:.2f}'.format(last_val_loss))
    return round(last_val_acc, 2)


def build_and_train(model_spec_name='efficientnet_lite0',
                    run_training=True,
                    total_training_epochs=1, epochs_between_evals=1,
                    validation_data=None,
                    test_data=None,
                    export_model=False, output_dir='.',
                    run_test=True, num_test_images=None):
    """Build one classifier and, depending on the flags, train, export and test it.

    * Always: build via ``build_model`` (do_train=False) and print the input shape.
    * run_training: train on the notebook-level ``train_data`` split, plot the curves
      and print the final metrics.
    * run_training and export_model: write ``<model_spec_name>.tflite`` and
      LABEL_FILENAME into ``output_dir`` and record the file size in MB.
    * run_test: load that .tflite/label pair from ``output_dir`` and evaluate it on
      ``test_data``. If export was skipped, the files are expected to exist already.

    Returns:
        dict with 'model_spec_name', 'history', 'val_accuracy', 'model_size_mb' and
        'inference_time_ms'. Entries whose step did not run are None.
    """
    print('\nBuild model:', model_spec_name)
    model = build_model(model_spec_name,
                        epochs_between_evals=epochs_between_evals,
                        validation_data=validation_data,
                        do_data_augmentation=True,
                        do_train=False)
    print('Model input shape:', model.model_spec.input_image_shape)

    history = None
    val_accuracy = None
    if run_training:
        print('\nPrepare model for retraining')
        train_ds, validation_ds = prepare_model_for_training(
            model, train_data, validation_data=validation_data, hparams=None, steps_per_epoch=None)

        print('\nTrain model:')
        history = train_model(model, total_training_epochs, epochs_between_evals,
                              steps_per_epoch=None, train_ds=train_ds, validation_ds=validation_ds)

        print('\nTraining Metrics:')
        plot_training_metrics(model_spec_name, history)
        val_accuracy = _report_final_metrics(history)

    model_filename = model_spec_name + '.tflite'
    model_file = os.path.join(output_dir, model_filename)
    labels_file = os.path.join(output_dir, LABEL_FILENAME)

    file_size_mb = None
    if export_model and run_training:
        print('\nExport model file:')
        export_tflite(model=model, output_dir=output_dir,
                      model_filename=model_filename, labels_filename=LABEL_FILENAME)
        if os.path.isfile(model_file):
            file_size_mb = round(os.path.getsize(model_file) / (1024 * 1024), 2)

    average_inf_time_ms = None
    if run_test:
        print('\nTest model file:')
        timings_ms = []
        interpreter, label_names = load_tflite_model_and_labels(model_file, labels_file)
        test_model(model, interpreter, test_data, label_names,
                   num_test_images=num_test_images, inf_time_ms_list=timings_ms)
        average_inf_time_ms = round(np.mean(timings_ms), 2)

    return {'model_spec_name': model_spec_name,
            'history': history,
            'val_accuracy': val_accuracy,
            'model_size_mb': file_size_mb,
            'inference_time_ms': average_inf_time_ms}

def display_model_metrics_summary(model_metrics):
    """Print ``model_metrics`` as a Key/Value table, leaving out the 'history' entry.

    Works on a copy, so the caller's dict is not modified. Raises KeyError if
    'history' is missing.
    """
    rows = dict(model_metrics)
    del rows['history']

    print("{:<20} {:<10}".format('Key', 'Value'))
    row_template = "{!s:<20} {!s:<10}"
    for name, value in rows.items():
        print(row_template.format(name, value))

In [None]:
## =========================== ##
## Build, Train and Test Model ##
## =========================== ##
if not PROCESS_LIST_OF_MODELS:
    # Single-model run: build, train, export and test MODEL_SPEC_NAME, then summarise it.
    model_metrics = build_and_train(
        model_spec_name=MODEL_SPEC_NAME,
        run_training=RUN_TRAINING,
        total_training_epochs=TRAINING_EPOCHS,
        validation_data=validation_data,
        test_data=test_data,
        export_model=EXPORT_MODEL,
        output_dir=OUTPUT_DIR,
        run_test=RUN_TEST,
        num_test_images=NUM_TEST_IMAGES,
    )
    display_model_metrics_summary(model_metrics)
else:
    print('Not configured for single model; continuing to list of models')

In [None]:
def plot_training_metrics_list(model_metrics_list, width=16, height=4):
    """Overlay per-loop training curves of several models in a 2x2 grid.

    Panels: training accuracy, validation accuracy, training loss, validation loss;
    one line per model, labelled with its 'model_spec_name'. Entries without a
    recorded 'history' (e.g. built with run_training=False) are skipped.

    Args:
        model_metrics_list: dicts as returned by ``build_and_train``.
        width, height: figure size in inches.
    """
    plottable = [metrics for metrics in model_metrics_list if metrics.get('history')]
    if not plottable:
        print('No training history to plot')
        return

    fig, ((ax_acc, ax_val_acc), (ax_loss, ax_val_loss)) = plt.subplots(
        2, 2, figsize=(width, height))
    fig.subplots_adjust(hspace=0.7)

    # (axes, history key, y label, title, legend location)
    panels = [
        (ax_acc, 'accuracy', 'Accuracy', 'Training Accuracy', 'lower right'),
        (ax_val_acc, 'val_accuracy', 'Accuracy', 'Validation Accuracy', 'lower right'),
        (ax_loss, 'loss', 'Loss', 'Training Loss', 'upper right'),
        (ax_val_loss, 'val_loss', 'Loss', 'Validation Loss', 'upper right'),
    ]

    for metrics in plottable:
        model_name = metrics.get('model_spec_name', '')
        for ax, key, _, _, _ in panels:
            ax.plot(metrics['history'][key], label=model_name)

    # One shared loss bound across all models; setting it per model let the last
    # model's limits clip the curves of the others.
    loss_max = max((value
                    for metrics in plottable
                    for key in ('loss', 'val_loss')
                    for value in metrics['history'][key]),
                   default=0)

    for ax, _, y_label, title, legend_loc in panels:
        ax.set_xlabel("Training Steps")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.legend(loc=legend_loc)
        if y_label == 'Accuracy':
            ax.set_ylim([0, 1])
        elif loss_max > 0:
            ax.set_ylim([0, loss_max])

    plt.show()

def display_model_metrics_list_summary(model_metrics_list):
    """Display one row per model (without the 'history' column) and return that DataFrame."""
    summary = pd.DataFrame(model_metrics_list).drop(columns=['history'])
    display(summary)
    return summary

def build_and_train_model_list(model_list=None,
                               run_training=True,
                               total_training_epochs=1, epochs_between_evals=1,
                               validation_data=None,
                               test_data=None,
                               export_model=False, output_dir='.',
                               run_test=True, num_test_images=None):
    """Run ``build_and_train`` for every spec name in ``model_list`` and compare the results.

    Each model goes through exactly the same build/train/export/test steps as the
    single-model path. The combined training curves and a summary table are shown
    afterwards.

    Args:
        model_list: spec names to process; defaults to ['efficientnet_lite0'].
        Remaining arguments: forwarded unchanged to ``build_and_train``.

    Returns:
        list of the per-model metric dicts, in ``model_list`` order.
    """
    if model_list is None:
        model_list = ['efficientnet_lite0']

    print('Model list:', model_list)
    model_metrics_list = []
    # Iterate the argument, not the global MODEL_LIST, so callers control what runs.
    for model_spec_name in model_list:
        model_metrics_list.append(build_and_train(
            model_spec_name=model_spec_name,
            run_training=run_training,
            total_training_epochs=total_training_epochs,
            epochs_between_evals=epochs_between_evals,
            validation_data=validation_data,
            test_data=test_data,
            export_model=export_model,
            output_dir=output_dir,
            run_test=run_test,
            num_test_images=num_test_images))

    print('\nTraining History')
    plot_training_metrics_list(model_metrics_list)

    print('\nSummary Metrics')
    display_model_metrics_list_summary(model_metrics_list)

    return model_metrics_list

In [None]:
## =========================== ##
## Build, Train and Test Model ##
## =========================== ##

if PROCESS_LIST_OF_MODELS:
    # Multi-model run: train and compare every spec in MODEL_LIST.
    model_metrics_list = build_and_train_model_list(
        model_list=MODEL_LIST,
        run_training=RUN_TRAINING,
        total_training_epochs=TRAINING_EPOCHS,
        validation_data=validation_data,
        test_data=test_data,
        export_model=EXPORT_MODEL,
        output_dir=OUTPUT_DIR,
        run_test=RUN_TEST,
        num_test_images=NUM_TEST_IMAGES,
    )
else:
    print('Not configured for list of models')