In [1]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
import tensorflow_addons as tfa
import glob
import wandb
from wandb.keras import WandbCallback
import shutil
import matplotlib.pyplot as plt
import numpy as np


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [2]:
EXPERIMENT_NAME = "mobilenetv3small-average-runs"
DATA_DIR = "dataset-augmented"
AUTOTUNE = tf.data.AUTOTUNE
CONFIG = dict(
    epochs=100,
    learning_rate=1e-4,
    batch_size=32,
    img_shape=(224, 224),
    input_shape=(224, 224, 3),
    num_classes=2,
    dropout_rate=0.5,
    es_patience=10,
    seed_value=42,
)

In [3]:
def train():
    image_count = len(glob.glob(f"{DATA_DIR}/*/*/*.jpg"))
    healthy_count = len(glob.glob(f"{DATA_DIR}/train/healthy/*.jpg"))
    wssv_count = len(glob.glob(f"{DATA_DIR}/train/wssv/*.jpg"))

    train_dir = os.path.join(DATA_DIR, "train")
    valid_dir = os.path.join(DATA_DIR, "val")
    test_dir = os.path.join(DATA_DIR, "test")

    train_set = tf.keras.utils.image_dataset_from_directory(
        train_dir,
        seed=CONFIG["seed_value"],
        image_size=CONFIG["img_shape"],
        batch_size=CONFIG["batch_size"],
        label_mode="categorical",
    ).prefetch(buffer_size=AUTOTUNE)

    valid_set = tf.keras.utils.image_dataset_from_directory(
        valid_dir,
        seed=CONFIG["seed_value"],
        image_size=CONFIG["img_shape"],
        batch_size=CONFIG["batch_size"],
        label_mode="categorical",
    ).prefetch(buffer_size=AUTOTUNE)

    test_set = tf.keras.utils.image_dataset_from_directory(
        test_dir,
        seed=CONFIG["seed_value"],
        image_size=CONFIG["img_shape"],
        batch_size=CONFIG["batch_size"],
        label_mode="categorical",
    ).prefetch(buffer_size=AUTOTUNE)

    base_model = tf.keras.applications.MobileNetV3Small(
        input_shape=CONFIG["input_shape"],
        include_top=False,
        classes=CONFIG["num_classes"],
    )
    base_model.trainable = False

    inputs = tf.keras.Input(shape=CONFIG["input_shape"])
    x = base_model(inputs, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(rate=CONFIG["dropout_rate"], seed=CONFIG["seed_value"])(
        x
    )
    outputs = tf.keras.layers.Dense(units=CONFIG["num_classes"], activation="softmax")(
        x
    )

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=CONFIG["learning_rate"],
        ),
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
        metrics=[
            tf.keras.metrics.Precision(name="precision"),
            tf.keras.metrics.Recall(name="recall"),
            tfa.metrics.F1Score(
                num_classes=CONFIG["num_classes"],
                average="weighted",
                name="f1_score",
                threshold=0.5,
            ),
            tf.keras.metrics.FalseNegatives(name="false_negatives"),
            tf.keras.metrics.TruePositives(name="true_positives"),
            tf.keras.metrics.FalsePositives(name="false_positives"),
            tf.keras.metrics.TrueNegatives(name="true_negatives"),
        ],
    )

    class_0_weight = (1 / healthy_count) * (image_count / 2.0)
    class_1_weight = (1 / wssv_count) * (image_count / 2.0)
    class_weight = {0: class_0_weight, 1: class_1_weight}

    if os.path.exists("checkpoints_mobilenetv3small"):
        shutil.rmtree("checkpoints_mobilenetv3small")
        os.makedirs("checkpoints_mobilenetv3small")

    checkpoint_path = "checkpoints_mobilenetv3small/cp-{epoch:04d}.ckpt"

    model.save_weights(checkpoint_path.format(epoch=0))

    history = model.fit(
        train_set,
        epochs=CONFIG["epochs"],
        validation_data=valid_set,
        class_weight=class_weight,
        callbacks=[
            tf.keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_path,
                save_weights_only=True,
                monitor="val_loss",
                mode="min",
                verbose=1,
                save_best_only=True,
            ),
            WandbCallback(save_model=False),
        ],
    )

    return history

In [4]:
for i in range(5):
    wandb.init(
        project="wssv-recognition",
        config=CONFIG,
        group=EXPERIMENT_NAME,
        job_type="train",
    )
    train()
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mrenzo_querol[0m. Use [1m`wandb login --relogin`[0m to force relogin


Found 1023 files belonging to 2 classes.
Found 49 files belonging to 2 classes.
Found 13 files belonging to 2 classes.
Epoch 1/100
Epoch 1: val_loss improved from inf to 0.65036, saving model to checkpoints_mobilenetv3small/cp-0001.ckpt
Epoch 2/100
Epoch 2: val_loss did not improve from 0.65036
Epoch 3/100
Epoch 3: val_loss improved from 0.65036 to 0.61656, saving model to checkpoints_mobilenetv3small/cp-0003.ckpt
Epoch 4/100
Epoch 4: val_loss did not improve from 0.61656
Epoch 5/100
Epoch 5: val_loss improved from 0.61656 to 0.60458, saving model to checkpoints_mobilenetv3small/cp-0005.ckpt
Epoch 6/100
Epoch 6: val_loss improved from 0.60458 to 0.59733, saving model to checkpoints_mobilenetv3small/cp-0006.ckpt
Epoch 7/100
Epoch 7: val_loss improved from 0.59733 to 0.57137, saving model to checkpoints_mobilenetv3small/cp-0007.ckpt
Epoch 8/100
Epoch 8: val_loss improved from 0.57137 to 0.55063, saving model to checkpoints_mobilenetv3small/cp-0008.ckpt
Epoch 9/100
Epoch 9: val_loss impro

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
f1_score,▁▂▂▃▃▄▄▄▄▅▅▆▆▅▆▆▆▆▆▇▆▇▇▇▇▇██▇▇██████████
false_negatives,█▇▇▇▆▅▅▅▄▄▄▃▃▄▃▃▃▃▃▂▃▂▂▂▂▂▁▁▂▂▁▁▁▁▁▁▁▁▁▁
false_positives,█▇▇▇▆▅▅▅▄▄▄▃▃▄▃▃▃▃▃▂▃▂▂▂▂▂▁▁▂▂▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▆▅▅▄▄▄▄▃▃▃▃▂▂▃▂▂▂▂▂▂▁▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁
precision,▁▂▂▂▃▄▄▄▅▅▅▆▆▅▆▆▆▆▆▇▆▇▇▇▇▇██▇▇██████████
recall,▁▂▂▂▃▄▄▄▅▅▅▆▆▅▆▆▆▆▆▇▆▇▇▇▇▇██▇▇██████████
true_negatives,▁▂▂▂▃▄▄▄▅▅▅▆▆▅▆▆▆▆▆▇▆▇▇▇▇▇██▇▇██████████
true_positives,▁▂▂▂▃▄▄▄▅▅▅▆▆▅▆▆▆▆▆▇▆▇▇▇▇▇██▇▇██████████
val_f1_score,▁▄▂▃▅▅▅▅▅▅▆▆▅█▅▇▇▇▇▇▇▇▇▆▇▆▆▇▆▆▆▆▆▆▆▆▆▆▆▆

0,1
best_epoch,97.0
best_val_loss,0.36716
epoch,99.0
f1_score,0.87792
false_negatives,130.0
false_positives,130.0
loss,0.29982
precision,0.87292
recall,0.87292
true_negatives,893.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669749650001602, max=1.0…

Found 1023 files belonging to 2 classes.
Found 49 files belonging to 2 classes.
Found 13 files belonging to 2 classes.
Epoch 1/100
Epoch 1: val_loss improved from inf to 0.49712, saving model to checkpoints_mobilenetv3small/cp-0001.ckpt
Epoch 2/100
Epoch 2: val_loss did not improve from 0.49712
Epoch 3/100
Epoch 3: val_loss did not improve from 0.49712
Epoch 4/100
Epoch 4: val_loss did not improve from 0.49712
Epoch 5/100
Epoch 5: val_loss improved from 0.49712 to 0.47633, saving model to checkpoints_mobilenetv3small/cp-0005.ckpt
Epoch 6/100
Epoch 6: val_loss improved from 0.47633 to 0.46019, saving model to checkpoints_mobilenetv3small/cp-0006.ckpt
Epoch 7/100
Epoch 7: val_loss improved from 0.46019 to 0.43706, saving model to checkpoints_mobilenetv3small/cp-0007.ckpt
Epoch 8/100
Epoch 8: val_loss did not improve from 0.43706
Epoch 9/100
Epoch 9: val_loss improved from 0.43706 to 0.42945, saving model to checkpoints_mobilenetv3small/cp-0009.ckpt
Epoch 10/100
Epoch 10: val_loss improve

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
f1_score,▁▁▂▂▃▃▃▄▅▅▅▅▆▆▆▆▆▇▆▆▇▇▇▇▇▇▇▇█▇██████████
false_negatives,▇█▇▇▆▆▆▅▄▄▄▄▃▃▃▃▃▂▃▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁
false_positives,▇█▇▇▆▆▆▅▄▄▄▄▃▃▃▃▃▂▃▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▆▅▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▁▂▁▁▁▁▁▁▁
precision,▂▁▂▂▃▃▃▄▅▅▅▅▆▆▆▆▆▇▆▆▇▇▇▇▇▇▇▇█▇██████████
recall,▂▁▂▂▃▃▃▄▅▅▅▅▆▆▆▆▆▇▆▆▇▇▇▇▇▇▇▇█▇██████████
true_negatives,▂▁▂▂▃▃▃▄▅▅▅▅▆▆▆▆▆▇▆▆▇▇▇▇▇▇▇▇█▇██████████
true_positives,▂▁▂▂▃▃▃▄▅▅▅▅▆▆▆▆▆▇▆▆▇▇▇▇▇▇▇▇█▇██████████
val_f1_score,▄▁▃▄▆▆▇▆▆▆▆▆▇▇▇██▇▇▇███████▇█████▇████▇█

0,1
best_epoch,40.0
best_val_loss,0.33631
epoch,99.0
f1_score,0.88543
false_negatives,122.0
false_positives,122.0
loss,0.28063
precision,0.88074
recall,0.88074
true_negatives,901.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670328266673095, max=1.0…

Found 1023 files belonging to 2 classes.
Found 49 files belonging to 2 classes.
Found 13 files belonging to 2 classes.
Epoch 1/100
Epoch 1: val_loss improved from inf to 0.70524, saving model to checkpoints_mobilenetv3small/cp-0001.ckpt
Epoch 2/100
Epoch 2: val_loss did not improve from 0.70524
Epoch 3/100
Epoch 3: val_loss did not improve from 0.70524
Epoch 4/100
Epoch 4: val_loss did not improve from 0.70524
Epoch 5/100
Epoch 5: val_loss did not improve from 0.70524
Epoch 6/100
Epoch 6: val_loss did not improve from 0.70524
Epoch 7/100
Epoch 7: val_loss did not improve from 0.70524
Epoch 8/100
Epoch 8: val_loss did not improve from 0.70524
Epoch 9/100
Epoch 9: val_loss improved from 0.70524 to 0.66742, saving model to checkpoints_mobilenetv3small/cp-0009.ckpt
Epoch 10/100
Epoch 10: val_loss improved from 0.66742 to 0.63724, saving model to checkpoints_mobilenetv3small/cp-0010.ckpt
Epoch 11/100
Epoch 11: val_loss did not improve from 0.63724
Epoch 12/100
Epoch 12: val_loss improved fr

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
f1_score,▂▁▁▂▃▃▃▄▄▅▅▅▆▅▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇████▇██████
false_negatives,▇██▇▆▇▆▅▅▄▄▄▃▄▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁
false_positives,▇██▇▆▇▆▅▅▄▄▄▃▄▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁
loss,█▇▆▅▅▅▄▄▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
precision,▂▁▁▂▃▂▃▄▄▅▅▅▆▅▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇████▇██████
recall,▂▁▁▂▃▂▃▄▄▅▅▅▆▅▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇████▇██████
true_negatives,▂▁▁▂▃▂▃▄▄▅▅▅▆▅▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇████▇██████
true_positives,▂▁▁▂▃▂▃▄▄▅▅▅▆▅▇▆▆▇▇▇▇▇▇▇▇▇▇▇▇████▇██████
val_f1_score,▂▁▂▂▅▅▇▇▇▇▇▇▇▇▇▇█████▇██████████████████

0,1
best_epoch,97.0
best_val_loss,0.39527
epoch,99.0
f1_score,0.87812
false_negatives,130.0
false_positives,130.0
loss,0.29877
precision,0.87292
recall,0.87292
true_negatives,893.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669555249995936, max=1.0…

Found 1023 files belonging to 2 classes.
Found 49 files belonging to 2 classes.
Found 13 files belonging to 2 classes.
Epoch 1/100
Epoch 1: val_loss improved from inf to 0.54665, saving model to checkpoints_mobilenetv3small/cp-0001.ckpt
Epoch 2/100
Epoch 2: val_loss did not improve from 0.54665
Epoch 3/100
Epoch 3: val_loss did not improve from 0.54665
Epoch 4/100
Epoch 4: val_loss did not improve from 0.54665
Epoch 5/100
Epoch 5: val_loss did not improve from 0.54665
Epoch 6/100
Epoch 6: val_loss did not improve from 0.54665
Epoch 7/100
Epoch 7: val_loss did not improve from 0.54665
Epoch 8/100
Epoch 8: val_loss did not improve from 0.54665
Epoch 9/100
Epoch 9: val_loss did not improve from 0.54665
Epoch 10/100
Epoch 10: val_loss improved from 0.54665 to 0.54414, saving model to checkpoints_mobilenetv3small/cp-0010.ckpt
Epoch 11/100
Epoch 11: val_loss improved from 0.54414 to 0.54338, saving model to checkpoints_mobilenetv3small/cp-0011.ckpt
Epoch 12/100
Epoch 12: val_loss improved fr

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
f1_score,▂▁▁▂▂▃▄▄▄▅▅▅▅▅▆▆▆▆▇▇▆▇▇▇▇▇▇▇▇▇▇███▇█████
false_negatives,▆██▇▇▆▅▅▅▄▄▄▄▄▃▃▃▂▂▂▃▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁
false_positives,▆██▇▇▆▅▅▅▄▄▄▄▄▃▃▃▂▂▂▃▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁
loss,█▆▅▅▄▄▃▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁
precision,▃▁▁▂▂▃▄▄▄▅▅▅▅▅▆▆▆▇▇▇▆▇▇▇▇▇▇▇▇▇▇███▇█████
recall,▃▁▁▂▂▃▄▄▄▅▅▅▅▅▆▆▆▇▇▇▆▇▇▇▇▇▇▇▇▇▇███▇█████
true_negatives,▃▁▁▂▂▃▄▄▄▅▅▅▅▅▆▆▆▇▇▇▆▇▇▇▇▇▇▇▇▇▇███▇█████
true_positives,▃▁▁▂▂▃▄▄▄▅▅▅▅▅▆▆▆▇▇▇▆▇▇▇▇▇▇▇▇▇▇███▇█████
val_f1_score,▄▁▁▃▄▄▅▄▆▆▆▆▇███▇██████████▇██████▇▇████

0,1
best_epoch,88.0
best_val_loss,0.35472
epoch,99.0
f1_score,0.8717
false_negatives,137.0
false_positives,137.0
loss,0.29364
precision,0.86608
recall,0.86608
true_negatives,886.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666955793333121, max=1.0)…

Found 1023 files belonging to 2 classes.
Found 49 files belonging to 2 classes.
Found 13 files belonging to 2 classes.
Epoch 1/100
Epoch 1: val_loss improved from inf to 0.68593, saving model to checkpoints_mobilenetv3small/cp-0001.ckpt
Epoch 2/100
Epoch 2: val_loss did not improve from 0.68593
Epoch 3/100
Epoch 3: val_loss did not improve from 0.68593
Epoch 4/100
Epoch 4: val_loss improved from 0.68593 to 0.65069, saving model to checkpoints_mobilenetv3small/cp-0004.ckpt
Epoch 5/100
Epoch 5: val_loss improved from 0.65069 to 0.61981, saving model to checkpoints_mobilenetv3small/cp-0005.ckpt
Epoch 6/100
Epoch 6: val_loss improved from 0.61981 to 0.60183, saving model to checkpoints_mobilenetv3small/cp-0006.ckpt
Epoch 7/100
Epoch 7: val_loss improved from 0.60183 to 0.56898, saving model to checkpoints_mobilenetv3small/cp-0007.ckpt
Epoch 8/100
Epoch 8: val_loss improved from 0.56898 to 0.54660, saving model to checkpoints_mobilenetv3small/cp-0008.ckpt
Epoch 9/100
Epoch 9: val_loss impro

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
f1_score,▁▁▁▂▃▃▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇█▇▇▇▇▇█▇█▇▇▇██████
false_negatives,▇██▇▆▆▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▁▂▂▂▂▂▁▂▁▂▂▂▁▁▁▁▁▁
false_positives,▇██▇▆▆▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▁▂▂▂▂▂▁▂▁▂▂▂▁▁▁▁▁▁
loss,█▇▆▅▅▅▄▄▃▃▃▃▃▃▂▂▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▂▁▁▁▁▁▁▁
precision,▂▁▁▂▃▃▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇█▇▇▇▇▇█▇█▇▇▇██████
recall,▂▁▁▂▃▃▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇█▇▇▇▇▇█▇█▇▇▇██████
true_negatives,▂▁▁▂▃▃▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇█▇▇▇▇▇█▇█▇▇▇██████
true_positives,▂▁▁▂▃▃▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇█▇▇▇▇▇█▇█▇▇▇██████
val_f1_score,▁▃▅▅▅▆▇▇▆▇██████████████████████████████

0,1
best_epoch,80.0
best_val_loss,0.37604
epoch,99.0
f1_score,0.88543
false_negatives,122.0
false_positives,122.0
loss,0.28336
precision,0.88074
recall,0.88074
true_negatives,901.0
