In [1]:
import itertools
import os

import matplotlib.pylab as plt
import numpy as np

import tensorflow as tf
from tensorflow import keras
import tensorflow_hub as hub


In [None]:
# Root folder of the training images; fill this in before running the cell.
# flow_from_directory() reads one sub-folder per class beneath it.
data_dir = data_path = ''  # path to the training dataset

# Settings shared by both ImageDataGenerators: multiply every pixel value by
# 1/255 and reserve 20% of the images for the "validation" subset.
datagen_kwargs = {"rescale": 1.0 / 255, "validation_split": 0.20}

# Side length (in pixels) every image is resized to; also used as the model's
# input size further down.
img_size = 224

# Settings shared by both flow_from_directory() calls.
dataflow_kwargs = {
    "target_size": (img_size, img_size),
    "batch_size": 64,
    "interpolation": "bilinear",
}

# Validation data: rescaling only, no augmentation, and a fixed (unshuffled)
# order so evaluation is repeatable.
valid_datagen = tf.keras.preprocessing.image.ImageDataGenerator(**datagen_kwargs)
valid_generator = valid_datagen.flow_from_directory(
    data_dir,
    subset="validation",
    shuffle=False,
    **dataflow_kwargs)

# Training data: same rescaling/split, plus random geometric augmentation.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    **datagen_kwargs)
train_generator = train_datagen.flow_from_directory(
    data_dir,
    subset="training",
    shuffle=True,
    **dataflow_kwargs)

In [None]:
# Number of passes over the training data (Colab form field).
num_epochs = 5  #@param {type:"integer"}

# Drop any graph/state left over from a previous run of this cell.
tf.keras.backend.clear_session()

# ResNet-v2-101 feature extractor from TF-Hub (fine-tuned, since
# trainable=True), followed by dropout and a linear classification head.
# The head outputs raw logits; the loss below is configured accordingly.
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=[img_size, img_size, 3]),
    hub.KerasLayer(
        "https://tfhub.dev/google/imagenet/resnet_v2_101/feature_vector/5",
        trainable=True,
        arguments=dict(batch_norm_momentum=0.997)),
    tf.keras.layers.Dropout(rate=0.2),
    tf.keras.layers.Dense(train_generator.num_classes,
                          kernel_regularizer=tf.keras.regularizers.l2(0.0001))
])
model.build((None, img_size, img_size, 3))
model.summary()

model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9),
    # from_logits=True because the Dense head has no softmax activation.
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
    metrics=['accuracy'])

# Use the generators' own batch counts (ceil(samples / batch_size)) rather
# than floor division. Flooring silently drops the last partial batch; for the
# unshuffled validation generator that means the same trailing images are
# never evaluated, and it yields 0 steps whenever a subset holds fewer images
# than one batch.
steps_per_epoch = len(train_generator)
validation_steps = len(valid_generator)

# NOTE: with patience=8, early stopping can only trigger when num_epochs is
# larger than 8; at the default of 5 epochs training always runs to the end.
hist = model.fit(
    train_generator,
    epochs=num_epochs,
    steps_per_epoch=steps_per_epoch,
    callbacks=[tf.keras.callbacks.EarlyStopping(patience=8, verbose=1, restore_best_weights=True)],
    validation_data=valid_generator,
    validation_steps=validation_steps).history

In [None]:
# Destination for the trained model; fill this in before running the cell.
saved_model_path = ""  # insert the path where you want to save the model

# Persist the model together with its optimizer state.
model.save(
    filepath=saved_model_path,
    include_optimizer=True,
)