In [269]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
from sklearn.metrics import confusion_matrix

import utils
import graphing

# Single seed shared by every random number generator used in this notebook.
SEED = 15243

# PYTHONHASHSEED is only read when a Python interpreter starts, so this
# assignment cannot change string hashing in the already-running kernel; it is
# kept so that any subprocess started from this notebook inherits a fixed value.
os.environ["PYTHONHASHSEED"] = str(SEED)

# Seeds Python's built-in `random`, NumPy and TensorFlow in one call, so the
# built-in `random` module is seeded as well as NumPy and TensorFlow.
tf.keras.utils.set_random_seed(SEED)

# Plot styling: hide the right/top axes spines and use a serif font.
plt.rc("axes.spines", right=False, top=False)
plt.rc("font", family="serif")

## Importing the dataset
The code below imports the dataset from the folder at the path `DATA_DIR`. The dataset is then loaded using Keras's `image_dataset_from_directory` function.

### Note
The dataset is not uploaded to the notebook as it is too large. The dataset can be downloaded from [here](https://www.kaggle.com/andrewmvd/face-mask-detection) and uploaded to the notebook.

In [7]:
# Root folder of the image dataset, relative to the notebook's working directory.
# NOTE(review): the loading cell passes explicit class names, so this folder
# presumably needs one sub-folder per class — confirm against the data layout.
DATA_DIR = "data_dir"

In [None]:
# Hyper-parameters shared by the data pipeline, the model inputs and the training loop.
BATCH_SIZE, IMG_SIZE, MAX_EPOCHS = 25, 224, 50

# Class identifiers passed to the loader, and the matching human-readable
# titles (same order) used when plotting.
class_names = ["normal", "polyps", "ulcerative-colitis"]
class_details = ["Normal", "Polyps", "Ulcerative Colitis"]

# The listing is neither stored nor displayed (it is not the last expression
# of the cell); the call still raises early if DATA_DIR does not exist.
os.listdir(DATA_DIR)


def _load_subset(subset):
    """Load one side ("training" or "validation") of the 80/20 split of DATA_DIR."""
    return tf.keras.preprocessing.image_dataset_from_directory(
        DATA_DIR,
        validation_split=0.2,
        subset=subset,
        # Both subsets use the same seed and split fraction.
        seed=45,
        class_names=class_names,
        image_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
    )


train_ds = _load_subset("training")
validation_ds = _load_subset("validation")

## Previewing the dataset
The dataset is then previewed by displaying the first 8 images of the first training batch, along with their labels.

In [None]:
# Draw the first eight images of a single training batch on a 2 x 4 grid,
# each titled with the display name of its class.
fig, axes = plt.subplots(2, 4, figsize=(14, 7))
for batch_images, batch_labels in train_ds.take(1):
    label_ids = batch_labels.numpy()
    for idx, ax in enumerate(axes.flat):
        # Cast to uint8 so imshow treats the values as 0-255 intensities.
        pixels = batch_images[idx].numpy().astype("uint8")
        ax.imshow(pixels, cmap='gist_gray', vmin=0, vmax=255)
        ax.set_title(class_details[label_ids[idx]])
        ax.axis("off")

In [258]:
# Value handed to prefetch() as its buffer size.
AUTOTUNE = tf.data.AUTOTUNE

# Training pipeline: cache, then shuffle, then prefetch.
# NOTE(review): the datasets were created with a batch size, so this shuffle
# presumably reorders whole batches rather than individual images — confirm
# that this is the intended granularity.
train_ds = train_ds.cache()
train_ds = train_ds.shuffle(1000)
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)

# Validation pipeline: cached and prefetched, but never shuffled.
validation_ds = validation_ds.cache()
validation_ds = validation_ds.prefetch(buffer_size=AUTOTUNE)

## Models

The pretrained base models used are:
1. VGG16
2. EfficientNetB0
3. EfficientNetB1
4. EfficientNetV2B0
5. EfficientNetV2S
6. MobileNetV3Small

The models are defined below.

In [None]:
# VGG-16: frozen ImageNet backbone with a small three-class head on top.
pretrained_vgg16_base = tf.keras.applications.vgg16.VGG16(
    weights="imagenet",
    include_top=False,
    pooling=None,
)
pretrained_vgg16_base.trainable = False

# Layers in call order: input -> VGG preprocessing -> backbone -> head.
vgg16_stack = [
    layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
    layers.Lambda(tf.keras.applications.vgg16.preprocess_input),
    pretrained_vgg16_base,
    layers.GlobalAveragePooling2D(),
    layers.BatchNormalization(),
    layers.Dropout(0.3),
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.3),
    # One softmax output per class.
    layers.Dense(3, activation="softmax"),
]
vgg16_model = tf.keras.Sequential(vgg16_stack)

vgg16_model.summary()

In [None]:
# EfficientNet-B0: frozen ImageNet backbone with a small three-class head.
# No explicit preprocessing layer is placed in front of this backbone.
pretrained_efficientnetB0_base = tf.keras.applications.efficientnet.EfficientNetB0(
    weights="imagenet",
    include_top=False,
    pooling=None,
)
pretrained_efficientnetB0_base.trainable = False

efficientnetB0_model = tf.keras.Sequential(
    layers=[
        layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
        pretrained_efficientnetB0_base,
        # Classification head.
        layers.GlobalAveragePooling2D(),
        layers.BatchNormalization(),
        layers.Dropout(0.3),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.3),
        layers.Dense(3, activation="softmax"),
    ]
)

efficientnetB0_model.summary()

In [None]:
# EfficientNet-B1: frozen ImageNet backbone with a small three-class head.
# No explicit preprocessing layer is placed in front of this backbone.
pretrained_efficientnetB1_base = tf.keras.applications.efficientnet.EfficientNetB1(
    weights="imagenet",
    include_top=False,
    pooling=None,
)
pretrained_efficientnetB1_base.trainable = False

# Layers in call order: input -> backbone -> head.
efficientnetB1_stack = [
    layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
    pretrained_efficientnetB1_base,
    layers.GlobalAveragePooling2D(),
    layers.BatchNormalization(),
    layers.Dropout(0.3),
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.3),
    layers.Dense(3, activation="softmax"),
]
efficientnetB1_model = tf.keras.Sequential(efficientnetB1_stack)

efficientnetB1_model.summary()

In [None]:
# EfficientNetV2-B0: frozen ImageNet backbone with a small three-class head.
pretrained_efficientnetV2B0_base = tf.keras.applications.EfficientNetV2B0(
    weights="imagenet",
    include_top=False,
    pooling=None,
)
pretrained_efficientnetV2B0_base.trainable = False

efficientnetV2B0_model = tf.keras.Sequential(
    layers=[
        layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
        layers.Lambda(tf.keras.applications.efficientnet_v2.preprocess_input),
        pretrained_efficientnetV2B0_base,
        layers.GlobalAveragePooling2D(),
        # layers.BatchNormalization(),  # disabled for this model
        layers.Dropout(0.3),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.3),
        layers.Dense(3, activation="softmax"),
    ]
)

efficientnetV2B0_model.summary()

In [None]:
# EfficientNetV2-S: frozen ImageNet backbone with a small three-class head.
# The backbone is created with include_preprocessing=True, and no separate
# preprocessing layer is added in front of it.
pretrained_efficientnetV2S_base = tf.keras.applications.EfficientNetV2S(
    weights='imagenet',
    include_top=False,
    include_preprocessing=True,
    pooling=None,
)
pretrained_efficientnetV2S_base.trainable = False

# Layers in call order: input -> backbone -> head.
efficientnetV2S_stack = [
    layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
    pretrained_efficientnetV2S_base,
    layers.GlobalAveragePooling2D(),
    layers.BatchNormalization(),
    layers.Dropout(0.3),
    layers.Dense(128, activation="relu"),
    layers.Dropout(0.3),
    layers.Dense(3, activation="softmax"),
]
efficientnetV2S_model = tf.keras.Sequential(efficientnetV2S_stack)

efficientnetV2S_model.summary()

In [None]:
# MobileNet V3S: ImageNet backbone whose last 30 layers are fine-tuned, plus a
# small three-class head with L2 regularisation on the hidden Dense layer.
pretrained_mobilenetV3S_base = tf.keras.applications.MobileNetV3Small(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    include_top=False,
    weights='imagenet',
    pooling=None
)

# Partial fine-tuning. The base itself must be trainable: a layer whose own
# `trainable` flag is False reports no trainable weights, whatever the flags
# of its sub-layers are, so re-enabling sub-layers of a frozen base would
# leave the whole backbone frozen. Unfreeze the base first, then freeze
# everything that is not meant to be updated.
pretrained_mobilenetV3S_base.trainable = True

# Everything before the last 30 layers stays frozen.
for layer in pretrained_mobilenetV3S_base.layers[:-30]:
    layer.trainable = False

# Within the last 30 layers only the BatchNormalization layers stay frozen.
for layer in pretrained_mobilenetV3S_base.layers[-30:]:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False

mobilenetV3S_model = tf.keras.Sequential([
    layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3)),
    layers.Lambda(tf.keras.applications.mobilenet_v3.preprocess_input),
    pretrained_mobilenetV3S_base,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation="relu", kernel_regularizer=tf.keras.regularizers.l2(0.01)),
    layers.Dense(3, activation="softmax")
])

mobilenetV3S_model.summary()

## Training the models

We consolidate the training of the models below.

In [20]:
# Registry of the models to train: a list of dictionaries, one per model.
#   "name"           - human-readable name used in progress messages
#   "id"             - short identifier used to build output file names
#   "model"          - the Keras model object defined in the cells above
#   "custom_objects" - (optional) the preprocessing function wrapped by the
#                      model's Lambda layer, keyed by "preprocess_input";
#                      present for every model that contains such a layer.
#                      NOTE(review): presumably needed when the saved model is
#                      loaded again — confirm where this key is consumed.
models = [
    {
        "name": "VGG-16",
        "id": "vgg16",
        "model": vgg16_model,
        "custom_objects": {"preprocess_input": tf.keras.applications.vgg16.preprocess_input},
    },
    {
        "name": "EfficientNet B0",
        "id": "efficientnetB0",
        "model": efficientnetB0_model,
    },
    {
        "name": "EfficientNet B1",
        "id": "efficientnetB1",
        "model": efficientnetB1_model,
    },
    {
        "name": "EfficientNet V2-B0",
        "id": "efficientnetV2B0",
        "model": efficientnetV2B0_model,
        # This model also wraps preprocess_input in a Lambda layer, so it
        # declares the function like the VGG-16 and MobileNet entries do.
        "custom_objects": {"preprocess_input": tf.keras.applications.efficientnet_v2.preprocess_input},
    },
    {
        "name": "EfficientNet V2S",
        "id": "efficientnetV2S",
        "model": efficientnetV2S_model,
    },
    {
        "name": "MobileNet V3S",
        "id": "mobilenetV3S",
        "model": mobilenetV3S_model,
        "custom_objects": {"preprocess_input": tf.keras.applications.mobilenet_v3.preprocess_input},
    }
]

In [None]:
# Copy the above models to a new list.
# Each entry is copied as its own dictionary: a plain `models.copy()` is
# shallow, so both lists would share the same dictionaries and any key added
# to an entry of `trained_models` would also appear in `models`. The model
# objects themselves are still shared, not duplicated.
trained_models = [dict(entry) for entry in models]

In [None]:
# Train every registered model in turn. For each one, the object returned by
# utils.compile_and_fit_model is stored under the "history" key of the matching
# `trained_models` entry, and a figure built from it is saved next to the
# model's checkpoint file.
checkpoint_path = f"./models/{MAX_EPOCHS}e"

# Create the output folder up front so that writing the checkpoint or the
# figure cannot fail because the directory does not exist yet.
os.makedirs(checkpoint_path, exist_ok=True)

for i, model in enumerate(models):
    # Common path prefix of both artifacts of this model (".keras" and ".svg").
    artifact_stem = f"{checkpoint_path}/{model['id']}-{IMG_SIZE}"

    print(f"Starting training for model #{i + 1}: {model['name']}")
    # NOTE(review): `patience` is presumably ignored while early_stopping is
    # False — confirm in utils.compile_and_fit_model.
    model_history = utils.compile_and_fit_model(
        model['model'],
        train_ds=train_ds,
        validation_ds=validation_ds,
        epochs=MAX_EPOCHS,
        early_stopping=False,
        patience=2,
        checkpoint_path=f"{artifact_stem}.keras",
    )
    trained_models[i]['history'] = model_history

    print(f"Saving figures for model #{i + 1}: {model['name']}")
    utils.save_figure_from_history(model_history, f"{artifact_stem}.svg")

In [None]:
# Confusion heatmap of the MobileNet V3S model on the validation set, drawn
# with raw counts (percentage=False) and written to an SVG file.
mobilenet_predictions = utils.get_predictions(mobilenetV3S_model, ds=validation_ds)
graphing.confusion_heatmap(
    *mobilenet_predictions,
    class_details,
    percentage=False,
    path="./mobilenet-heatmap.svg",
)