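Adapted from the TensorFlow Hub tutorial https://www.tensorflow.org/hub/tutorials/tf2_image_retraining (this version uses Keras' built-in EfficientNetV2-S backbone instead of a TF Hub module).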

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
from tensorflow import keras
import matplotlib.pylab as plt
import numpy as np

In [None]:
# Report the runtime environment before doing any work.
print("TF version:", tf.__version__)
gpu_devices = tf.config.list_physical_devices('GPU')
gpu_status = "available" if gpu_devices else "NOT AVAILABLE"
print("GPU is", gpu_status)

# (height, width) that every image is resized to when it is loaded.
image_size = (384, 384)
print(f"Input size {image_size}")

batch_size = 32
# Root folder with one sub-directory per class (labels are inferred from it).
data_dir = 'training_data/'
# NOTE(review): not read by any of the visible cells.
do_fine_tuning = False

In [None]:
# IPython shell cell: delete every file named "._*" (per the original note,
# created by get_files) and every ".DS*" file (e.g. .DS_Store) anywhere under
# data_dir. NOTE(review): presumably done so the directory loader below never
# tries to read these metadata files as images — confirm.
!find $data_dir  -name ._\* -delete
!find $data_dir  -name .DS\* -delete

In [None]:
def build_dataset(subset, image_size, batch_size, data_dir):
    """Load one split of ``data_dir`` as a dataset of single-image batches.

    Args:
        subset: 'training' or 'validation'; selects a side of a fixed 80/20
            split of the directory.
        image_size: (height, width) every image is resized to.
        batch_size: not used here. The dataset is always built with
            batch_size=1 so that ``cardinality()`` equals the number of
            images; the callers in this notebook unbatch and re-batch.
        data_dir: root directory with one sub-directory per class.

    Returns:
        A dataset of (image, one-hot label) pairs of RGB images, one image
        per element, exposing ``class_names``.
    """
    loader_options = dict(
        labels='inferred',
        label_mode='categorical',
        color_mode='rgb',
        validation_split=0.20,
        # A seed is required when validation_split is combined with
        # shuffling; keeping it fixed makes the train/validation partition
        # identical on every run.
        seed=123,
        image_size=image_size,
        # Deliberately 1 (see docstring), not ``batch_size``.
        batch_size=1,
    )
    return tf.keras.preprocessing.image_dataset_from_directory(
        directory=data_dir,
        subset=subset,
        **loader_options,
    )


In [None]:
def get_train_and_validation_dataset(batch_size, data_dir, image_size, do_data_augmentation: bool):
    """Build the training and validation input pipelines from ``data_dir``.

    Both splits are loaded one image at a time by ``build_dataset`` and then
    re-batched to ``batch_size``. Pixels are rescaled to [0, 1] in BOTH splits.
    The training split repeats indefinitely (so ``model.fit`` needs an explicit
    ``steps_per_epoch``) and is optionally augmented; the validation split is
    finite and never augmented.

    Args:
        batch_size: number of images per batch in both returned datasets.
        data_dir: root directory with one sub-directory per class.
        image_size: (height, width) every image is resized to.
        do_data_augmentation: if True, apply random rotation, translation,
            zoom and horizontal flip to the training images.

    Returns:
        ``(train_ds, train_size, val_ds, val_size, class_names)``. The sizes
        are image counts (not batch counts); ``class_names`` is a tuple in
        label-index order.
    """
    train_ds = build_dataset('training', image_size, batch_size, data_dir)
    class_names = tuple(train_ds.class_names)
    # build_dataset yields one image per element, so cardinality == image count.
    train_size = train_ds.cardinality().numpy()
    train_ds = train_ds.unbatch().batch(batch_size)
    train_ds = train_ds.repeat()

    normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)
    preprocessing_model = tf.keras.Sequential([normalization_layer])
    if do_data_augmentation:
        # Append to (never replace) the model: the rescaling layer must stay
        # first so augmented training images are in [0, 1] just like the
        # validation images below.
        # NOTE(review): RandomRotation's factor is a fraction of 2*pi, so 40
        # allows effectively any angle — confirm this is intended.
        preprocessing_model.add(tf.keras.layers.experimental.preprocessing.RandomRotation(40))
        preprocessing_model.add(tf.keras.layers.experimental.preprocessing.RandomTranslation(0, 0.2))
        preprocessing_model.add(tf.keras.layers.experimental.preprocessing.RandomTranslation(0.2, 0))
        # Like the old tf.keras.preprocessing.image.ImageDataGenerator(),
        # image sizes are fixed when reading, and then a random zoom is applied.
        # If all training inputs are larger than image_size, one could also use
        # RandomCrop with a batch size of 1 and rebatch later.
        preprocessing_model.add(tf.keras.layers.experimental.preprocessing.RandomZoom(0.2, 0.2))
        preprocessing_model.add(tf.keras.layers.experimental.preprocessing.RandomFlip(mode='horizontal'))
    # training=True guarantees the random layers actually augment, since this
    # model is called inside the tf.data pipeline rather than by model.fit.
    train_ds = train_ds.map(
        lambda images, labels: (preprocessing_model(images, training=True), labels)
    )

    val_ds = build_dataset('validation', image_size, batch_size, data_dir)
    val_size = val_ds.cardinality().numpy()
    val_ds = val_ds.unbatch().batch(batch_size)
    val_ds = val_ds.map(lambda images, labels: (normalization_layer(images), labels))
    return train_ds, train_size, val_ds, val_size, class_names


In [None]:
# Build the (un-augmented) training and validation pipelines.
train_ds, train_size, val_ds, val_size, class_names = get_train_and_validation_dataset(
    data_dir=data_dir,
    image_size=image_size,
    batch_size=batch_size,
    do_data_augmentation=False,
)


In [None]:
image_width = image_size[0]
image_height = image_size[1]
input_shape = image_size + (3,)

# The input pipeline produces images rescaled to [0, 1]. Keras' EfficientNetV2
# models, however, include their own input preprocessing by default
# (include_preprocessing=True) and expect raw [0, 255] pixel values; feeding
# [0, 1] data straight in would preprocess it a second time and squash every
# image into a narrow, almost constant range. Scale back to [0, 255] in front
# of the backbone so the model (and the saved model) keep accepting the same
# [0, 1] images the pipeline produces.
model_input = tf.keras.Input(shape=input_shape)
pixel_input = tf.keras.layers.experimental.preprocessing.Rescaling(
    255.0, name="to_pixel_range"
)(model_input)

# Building the backbone on `input_tensor` (instead of calling it as a nested
# sub-model) keeps all backbone layers directly inside the final model graph.
base_model = tf.keras.applications.efficientnet_v2.EfficientNetV2S(
    weights='imagenet',
    input_shape=input_shape,
    input_tensor=pixel_input,
    include_top=False,
)

# Freeze the pretrained backbone; only the new classification head is trained.
base_model.trainable = False

x = base_model.output
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(rate=0.2)(x)
outputs = tf.keras.layers.Dense(
    len(class_names),
    activation=tf.keras.activations.softmax,
    kernel_regularizer=tf.keras.regularizers.l2(0.01),
)(x)
model = tf.keras.Model(inputs=model_input, outputs=outputs, name="my_interpetable_model")

model.build((None,)+image_size+(3,))

model.summary()


In [None]:
# SGD with momentum; the loss takes probabilities (the head ends in softmax)
# and applies 0.1 label smoothing.
sgd_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
smoothed_crossentropy = tf.keras.losses.CategoricalCrossentropy(
    from_logits=False,
    label_smoothing=0.1,
)
model.compile(optimizer=sgd_optimizer, loss=smoothed_crossentropy, metrics=['accuracy'])

In [None]:
activate_early_stopping = True
# Stop once val_loss has not improved for 10 consecutive epochs and roll the
# weights back to the best epoch observed.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)
keras_callbacks = [early_stopping] if activate_early_stopping else []

In [None]:
# The training dataset repeats forever, so an "epoch" is defined explicitly as
# one pass over the training images; use at least one step even when there
# are fewer training images than batch_size.
steps_per_epoch = max(1, train_size // batch_size)
# The validation dataset is finite and its last batch may be partial. Round up
# so every validation image is evaluated each epoch, instead of silently
# dropping the remainder (or running zero steps when val_size < batch_size).
validation_steps = (val_size + batch_size - 1) // batch_size
hist = model.fit(
    train_ds,
    epochs=50,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_ds,
    validation_steps=validation_steps,
    callbacks=keras_callbacks,
).history

In [None]:
# Show the first image of the first validation batch (already rescaled to
# [0, 1] by the validation pipeline) together with its true label.
x, y = next(iter(val_ds))
image = x[0, :, :, :]
true_index = np.argmax(y[0])
plt.imshow(image)
plt.axis('off')
plt.show()

# Add a batch axis, giving shape (1, height, width, 3) with height/width from
# image_size, before predicting the label.
prediction_scores = model.predict(np.expand_dims(image, axis=0))
predicted_index = np.argmax(prediction_scores)
print("True label: " + class_names[true_index])
print("Predicted label: " + class_names[predicted_index])

In [None]:
# Persist the trained model; the path has no placeholders, so it is a plain
# string rather than an f-string.
saved_model_path = "saved_model_for_interpretability"
model.save(saved_model_path)