# Import Libraries


In [10]:
import os
import shutil

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf

import numpy as np
import pandas as pd

import wandb
from wandb.keras import WandbCallback

from sklearn.model_selection import StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight

from utils import create_model, augment_images, flatten_datasets
from config import config

# Prepare Training Data

Our data is split into 80% for training, with the remaining 20% held out for testing later on. The training set is used for 5-fold cross-validation during model fitting, so it is further split into training and validation folds on each iteration.


In [11]:
# Root folder of the cropped images; image_dataset_from_directory infers one
# class per sub-directory (the cell output reports 2 classes).
data_dir = "dataset/2-cropped-v3"

# subset="both" yields the (training, validation) pair from a single call. The
# 20% "validation" subset is kept aside as the held-out test set; the fixed seed
# keeps that split identical across runs.
train_set, test_set = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    subset="both",
    validation_split=0.2,
    seed=config["seed_value"],
    label_mode="categorical",
    image_size=config["img_shape"],
    batch_size=config["batch_size"],
)

# Turn the batched training split into indexable arrays for K-fold splitting.
# NOTE(review): assumes flatten_datasets (from utils, not shown) returns
# (images, one-hot labels) as arrays that support fancy indexing.
train_images, train_labels = flatten_datasets(train_set)

Found 447 files belonging to 2 classes.
Using 358 files for training.
Using 89 files for validation.


# Defining the Base Model

This project compares EfficientNetV2-B0 and MobileNetV3-Small, both of which are available in `tf.keras.applications`. The `flag` variable selects which backbone to train: `1` for MobileNetV3-Small, `2` for EfficientNetV2-B0.


In [None]:
# Select the backbone to train: 1 = MobileNetV3-Small, 2 = EfficientNetV2-B0.
flag = 1

# Each entry stores the Keras application constructor rather than a built model,
# so only the selected backbone is instantiated (and only its ImageNet weights
# are downloaded).
models = {
    1: {
        "application": tf.keras.applications.MobileNetV3Small,
        "model_name": "mobilenetv3small",
    },
    2: {
        # Must be EfficientNetV2B0 so results saved under this name match the
        # architecture that was actually trained.
        "application": tf.keras.applications.EfficientNetV2B0,
        "model_name": "efficientnetv2b0",
    },
}


def build_backbone(application):
    """Instantiate a head-less, ImageNet-pretrained backbone.

    Args:
        application: a `tf.keras.applications` constructor.

    Returns:
        The Keras model with the classification top removed and global average
        pooling applied to the final feature map, sized to config["input_shape"].
    """
    return application(
        weights="imagenet",
        input_shape=config["input_shape"],
        include_top=False,
        pooling="avg",
    )


if flag not in models:
    raise ValueError(f"flag must be one of {sorted(models)}, got {flag!r}")

base_model = build_backbone(models[flag]["application"])
model_name = models[flag]["model_name"]

# Fitting the Model


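We use `StratifiedKFold` with 5 splits as our cross-validation strategy. Stratification keeps the class proportions of every fold equal to those of the full training set, so validation scores remain comparable across folds despite the class imbalance. The imbalance itself is addressed during training with per-fold class weights.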


In [13]:
# 5-fold stratified CV over the training split only; the held-out test set is not used here.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=config["seed_value"])

histories = []

# Integer class ids from the one-hot labels: used for stratification and to
# enumerate the classes for the class-weight computation (instead of a
# hard-coded [0, 1]).
train_class_ids = np.argmax(train_labels, axis=1)
class_ids = np.arange(train_labels.shape[1])

# The same `base_model` object is handed to create_model in every fold. If
# create_model (defined in utils, not shown) leaves any of its layers trainable,
# weights learned in one fold would carry into the next and leak that fold's
# training images into later validation scores. Snapshot the pretrained weights
# once and restore them at the start of each fold.
# NOTE: re-run the model-definition cell before re-running this cell; otherwise
# the snapshot is taken from already fine-tuned weights.
initial_base_weights = base_model.get_weights()

# Clear stale checkpoints from a previous run once, up front. Each fold writes to
# its own sub-directory so earlier folds' best weights are not deleted.
checkpoint_root = f"checkpoints/{model_name}"
if os.path.exists(checkpoint_root):
    shutil.rmtree(checkpoint_root)

# Loop over the dataset to create separate folds
for i, (train_idx, valid_idx) in enumerate(cv.split(train_images, train_class_ids)):
    print(f"Fold {i + 1}")

    # Start every fold from the same pretrained backbone weights
    base_model.set_weights(initial_base_weights)

    # Create a new model instance
    model = create_model(base_model, config)

    # Get the training and validation data
    X_train, y_train = train_images[train_idx], train_labels[train_idx]
    X_valid, y_valid = train_images[valid_idx], train_labels[valid_idx]

    # Augment ONLY training data so validation metrics reflect unmodified images
    X_train, y_train = augment_images(X_train, y_train)

    # "balanced" weights scale each class inversely to its frequency in this fold
    weights = compute_class_weight(
        class_weight="balanced", classes=class_ids, y=y_train.argmax(axis=1)
    )
    weights = dict(zip(class_ids.tolist(), weights))

    # Per-fold checkpoint directory (makedirs also creates checkpoint_root)
    fold_checkpoint_dir = os.path.join(checkpoint_root, f"fold_{i + 1}")
    os.makedirs(fold_checkpoint_dir)
    checkpoint_path = os.path.join(fold_checkpoint_dir, "cp-{epoch:04d}.ckpt")

    # Save the initial weights as epoch 0 using the `checkpoint_path` format
    model.save_weights(checkpoint_path.format(epoch=0))

    # Keep only the best-val_loss weights; stop after 10 epochs without improvement
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(
            filepath=checkpoint_path,
            monitor="val_loss",
            save_best_only=True,
            save_weights_only=True,
            verbose=1,
        ),
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=10),
    ]

    # Fit the model on the train folds and evaluate on the validation fold
    history = model.fit(
        X_train,
        y_train,
        batch_size=config["batch_size"],
        epochs=config["epochs"],
        class_weight=weights,
        validation_data=(X_valid, y_valid),
        verbose=1,
        callbacks=callbacks,
    )

    # Append to histories
    histories.append(history)

Fold 1
Epoch 1/100
Epoch 1: val_loss improved from inf to 2847491.25000, saving model to checkpoints/mobilenet_v3_large_100_224/cp-0001.ckpt
Epoch 2/100
Epoch 2: val_loss improved from 2847491.25000 to 1665138.87500, saving model to checkpoints/mobilenet_v3_large_100_224/cp-0002.ckpt
Epoch 3/100
Epoch 3: val_loss improved from 1665138.87500 to 1467762.62500, saving model to checkpoints/mobilenet_v3_large_100_224/cp-0003.ckpt
Epoch 4/100
Epoch 4: val_loss improved from 1467762.62500 to 1310816.37500, saving model to checkpoints/mobilenet_v3_large_100_224/cp-0004.ckpt
Epoch 5/100
Epoch 5: val_loss did not improve from 1310816.37500
Epoch 6/100
Epoch 6: val_loss did not improve from 1310816.37500
Epoch 7/100
Epoch 7: val_loss improved from 1310816.37500 to 1246245.50000, saving model to checkpoints/mobilenet_v3_large_100_224/cp-0007.ckpt
Epoch 8/100
Epoch 8: val_loss improved from 1246245.50000 to 1192676.00000, saving model to checkpoints/mobilenet_v3_large_100_224/cp-0008.ckpt
Epoch 9/1

# Saving to CSV


We save the training history of each fold to a CSV file so that it can be used for visualization later on.


In [6]:
# Persist each fold's per-epoch metrics (history.history) as its own CSV under
# results/<model_name>/, replacing results from a previous run.
path = f"results/{model_name}"

if os.path.exists(path):
    shutil.rmtree(path, ignore_errors=True)

# exist_ok: rmtree(ignore_errors=True) can leave the directory behind (e.g. a
# locked file), which would make a plain makedirs raise FileExistsError.
os.makedirs(path, exist_ok=True)

for i, history in enumerate(histories, start=1):
    pd.DataFrame(history.history).to_csv(
        os.path.join(path, f"train_fold_{i}.csv"), index=False
    )