In [None]:
import os

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation, Dropout, Input, Dense, Flatten, BatchNormalization, Conv2D, MaxPool2D, AveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import categorical_crossentropy
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

## Creating batches from the dataset

In [None]:
# Dataset locations, one folder per split. The data itself is confidential
# and is therefore not published alongside this notebook.
train_path, valid_path, test_path = (
    "train_data/train",
    "train_data/valid",
    "train_data/test",
)

In [None]:
# Build one batch iterator per split. Every split uses EfficientNet's own
# preprocessing; BATCH_SIZE can be adapted to the memory of the available GPU.
CLASS_NAMES = ['Fissure', 'Racines', 'Normal']
TARGET_SIZE = (240, 240)  # must match the spatial input shape of the model
BATCH_SIZE = 64
SEED = 42  # makes the shuffling order of the train/valid iterators reproducible


def make_batches(directory, shuffle=True):
    """Return an iterator over preprocessed, one-hot-labelled image batches.

    Args:
        directory: folder expected to contain one sub-folder per entry of
            CLASS_NAMES.
        shuffle: whether samples are shuffled; the test split passes False so
            its samples come out in a fixed order.
    """
    generator = ImageDataGenerator(
        preprocessing_function=tf.keras.applications.efficientnet.preprocess_input)
    return generator.flow_from_directory(
        directory=directory,
        target_size=TARGET_SIZE,
        classes=CLASS_NAMES,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        seed=SEED,
    )


train_batches = make_batches(train_path)
valid_batches = make_batches(valid_path)
test_batches = make_batches(test_path, shuffle=False)

## Model creation (requires an internet connection to download the pretrained weights)

In [None]:
# Download EfficientNetB1 with ImageNet weights. The classification top is
# excluded so the convolutional base can be reused for transfer learning.
# The input is 240x240 (B1's native resolution); it must equal the
# target_size used by the batch generators, otherwise fit() fails on a shape
# mismatch.
efficientnet_model = tf.keras.applications.efficientnet.EfficientNetB1(
    include_top=False,
    weights="imagenet",
    input_tensor=Input(shape=(240, 240, 3)),
)

In [None]:
# Print the layer-by-layer architecture of the downloaded base model.
efficientnet_model.summary()

In [None]:
# Classification head for transfer learning. The base model's output feature
# maps are flattened so the pretrained network acts as a feature extractor,
# followed by a dense layer, dropout for regularisation and a 3-way softmax.
top_layers = efficientnet_model.output
top_layers = Flatten(name="flatten_top")(top_layers)
# "relu" is the documented activation identifier; "ReLU" is not a registered
# activation name and is rejected by newer Keras versions (Keras 3).
top_layers = Dense(1024, activation="relu", name="first_dense_top")(top_layers)
top_layers = Dropout(0.5, name="dropout_top")(top_layers)
# One unit per class in the generators' class list. Despite its name this
# layer applies a softmax, not a linear activation; the name is left unchanged
# to avoid breaking anything that refers to the layer by name.
top_layers = Dense(units=3, activation="softmax", name="linear_output")(top_layers)

In [None]:
# Join the pretrained base and the new head into one functional model that
# maps the base's input tensor to the softmax predictions.
model = Model(efficientnet_model.input, top_layers)

In [None]:
# Freeze every layer of the pretrained base so that only the new head is
# trained during the first (transfer-learning) phase.
for base_layer in efficientnet_model.layers:
    base_layer.trainable = False

In [None]:
# Candidate learning rates from an earlier search: four log-spaced values
# between 1e-4 and 1e-2 (10 raised to evenly spaced exponents).
learning_rate = np.power(10.0, np.linspace(-4, -2, 4))
# The second candidate (~4.6e-4) was previously determined to be optimal.
lr = learning_rate[1]

In [None]:
# Early stopping based on the monitoring of validation accuracy. An epoch
# only counts as an improvement if val_accuracy rises by at least min_delta;
# accuracy is a fraction in [0, 1], so 0.01 means one percentage point.
# After `patience` epochs without improvement training stops, and
# restore_best_weights=True asks Keras to restore the best epoch's weights.
stop = EarlyStopping(monitor='val_accuracy', mode='max', patience=2, min_delta=0.01, restore_best_weights=True)

In [None]:
# Phase 1 - transfer learning: with the base frozen, only the new head is
# trained, using Adam at the selected learning rate and early stopping.
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(learning_rate=lr),
    metrics=['accuracy'],
)
model.fit(
    x=train_batches,
    epochs=10,
    steps_per_epoch=len(train_batches),
    validation_data=valid_batches,
    validation_steps=len(valid_batches),
    callbacks=[stop],
)

In [None]:
# Where the fine-tuned weights are written. The checkpoint monitors
# validation accuracy and only overwrites the file when it improves, so the
# saved weights are the best ones seen rather than those of the last epoch.
# NOTE(review): Keras 3 requires weight-only files to end in ".weights.h5";
# this path assumes the TF 2.x checkpoint format - confirm for your version.
cp_path = "training/weights.ckpt"
cp_directory = os.path.dirname(cp_path)

checkpoint = ModelCheckpoint(filepath=cp_path,
                             monitor='val_accuracy',
                             mode='max',
                             save_best_only=True,
                             save_weights_only=True,
                             verbose=1)

In [None]:
# Unfreeze the last 10 layers of the combined model for fine-tuning, but keep
# BatchNormalization layers frozen. isinstance is required here: comparing a
# layer object to the BatchNormalization class itself is always "not equal".
# NOTE(review): the last four of these are presumably the head layers added
# above (already trainable), so only ~6 base layers are actually unfrozen.
for layer in model.layers[-10:]:
    if not isinstance(layer, BatchNormalization):
        layer.trainable = True

In [None]:
# Show the combined model; the trainable / non-trainable parameter counts
# reflect the layers unfrozen in the previous cell.
model.summary()

In [None]:
# Phase 2 - fine-tuning of the partly unfrozen model to obtain the final
# weights. Keras only applies changes to `trainable` when the model is
# compiled, so it is recompiled here. As recommended in the Keras fine-tuning
# guide, a much lower learning rate (10x smaller than phase 1) is used to
# avoid destroying the pretrained features. The checkpoint saves the best
# weights to disk.
fine_tune_lr = lr / 10
model.compile(optimizer=Adam(learning_rate=fine_tune_lr), loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(train_batches,
          steps_per_epoch=len(train_batches),
          validation_data=valid_batches,
          validation_steps=len(valid_batches),
          epochs=10,
          callbacks=[stop, checkpoint]
)