In [1]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import tensorflow as tf
import tensorflow_addons as tfa
import glob
import wandb
from wandb.keras import WandbCallback
import shutil
import matplotlib.pyplot as plt
import numpy as np


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [2]:
EXPERIMENT_NAME = "efficientnetv2b0-average-runs"
DATA_DIR = "dataset-augmented"
AUTOTUNE = tf.data.AUTOTUNE
CONFIG = dict(
    epochs=100,
    learning_rate=1e-4,
    batch_size=64,
    img_shape=(224, 224),
    input_shape=(224, 224, 3),
    num_classes=2,
    dropout_rate=0.5,
    es_patience=10,
    seed_value=42,
)

In [3]:
def train():
    image_count = len(glob.glob(f"{DATA_DIR}/*/*/*.jpg"))
    healthy_count = len(glob.glob(f"{DATA_DIR}/train/healthy/*.jpg"))
    wssv_count = len(glob.glob(f"{DATA_DIR}/train/wssv/*.jpg"))

    train_dir = os.path.join(DATA_DIR, "train")
    valid_dir = os.path.join(DATA_DIR, "val")
    test_dir = os.path.join(DATA_DIR, "test")

    train_set = tf.keras.utils.image_dataset_from_directory(
        train_dir,
        seed=CONFIG["seed_value"],
        image_size=CONFIG["img_shape"],
        batch_size=CONFIG["batch_size"],
        label_mode="categorical",
    ).prefetch(buffer_size=AUTOTUNE)

    valid_set = tf.keras.utils.image_dataset_from_directory(
        valid_dir,
        seed=CONFIG["seed_value"],
        image_size=CONFIG["img_shape"],
        batch_size=CONFIG["batch_size"],
        label_mode="categorical",
    ).prefetch(buffer_size=AUTOTUNE)

    test_set = tf.keras.utils.image_dataset_from_directory(
        test_dir,
        seed=CONFIG["seed_value"],
        image_size=CONFIG["img_shape"],
        batch_size=CONFIG["batch_size"],
        label_mode="categorical",
    ).prefetch(buffer_size=AUTOTUNE)

    base_model = tf.keras.applications.EfficientNetV2B0(
        input_shape=CONFIG["input_shape"],
        include_top=False,
        classes=CONFIG["num_classes"],
    )
    base_model.trainable = False

    inputs = tf.keras.Input(shape=CONFIG["input_shape"])
    x = base_model(inputs, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(rate=CONFIG["dropout_rate"], seed=CONFIG["seed_value"])(
        x
    )
    outputs = tf.keras.layers.Dense(units=CONFIG["num_classes"], activation="softmax")(
        x
    )

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=CONFIG["learning_rate"]),
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
        metrics=[
            tf.keras.metrics.Precision(name="precision"),
            tf.keras.metrics.Recall(name="recall"),
            tfa.metrics.F1Score(
                num_classes=CONFIG["num_classes"],
                average="weighted",
                name="f1_score",
                threshold=0.5,
            ),
            tf.keras.metrics.FalseNegatives(name="false_negatives"),
            tf.keras.metrics.TruePositives(name="true_positives"),
            tf.keras.metrics.FalsePositives(name="false_positives"),
            tf.keras.metrics.TrueNegatives(name="true_negatives"),
        ],
    )

    class_0_weight = (1 / healthy_count) * (image_count / 2.0)
    class_1_weight = (1 / wssv_count) * (image_count / 2.0)
    class_weight = {0: class_0_weight, 1: class_1_weight}

    if os.path.exists("checkpoints_efficientnetv2b0"):
        shutil.rmtree("checkpoints_efficientnetv2b0")
        os.makedirs("checkpoints_efficientnetv2b0")

    checkpoint_path = "checkpoints_efficientnetv2b0/cp-{epoch:04d}.ckpt"

    model.save_weights(checkpoint_path.format(epoch=0))

    model.fit(
        train_set,
        epochs=CONFIG["epochs"],
        validation_data=valid_set,
        class_weight=class_weight,
        callbacks=[
            tf.keras.callbacks.ModelCheckpoint(
                filepath=checkpoint_path,
                save_weights_only=True,
                monitor="val_loss",
                mode="min",
                verbose=1,
                save_best_only=True,
            ),
            WandbCallback(save_model=False),
        ],
    )

In [4]:
for i in range(5):
    wandb.init(
        project="wssv-recognition",
        config=CONFIG,
        group=EXPERIMENT_NAME,
        job_type="train",
    )
    train()
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mrenzo_querol[0m. Use [1m`wandb login --relogin`[0m to force relogin


Found 1023 files belonging to 2 classes.
Found 49 files belonging to 2 classes.
Found 13 files belonging to 2 classes.
Epoch 1/100
Epoch 1: val_loss improved from inf to 0.66560, saving model to checkpoints_efficientnetv2b0/cp-0001.ckpt
Epoch 2/100
Epoch 2: val_loss improved from 0.66560 to 0.65926, saving model to checkpoints_efficientnetv2b0/cp-0002.ckpt
Epoch 3/100
Epoch 3: val_loss improved from 0.65926 to 0.65334, saving model to checkpoints_efficientnetv2b0/cp-0003.ckpt
Epoch 4/100
Epoch 4: val_loss improved from 0.65334 to 0.63806, saving model to checkpoints_efficientnetv2b0/cp-0004.ckpt
Epoch 5/100
Epoch 5: val_loss improved from 0.63806 to 0.61954, saving model to checkpoints_efficientnetv2b0/cp-0005.ckpt
Epoch 6/100
Epoch 6: val_loss improved from 0.61954 to 0.59959, saving model to checkpoints_efficientnetv2b0/cp-0006.ckpt
Epoch 7/100
Epoch 7: val_loss improved from 0.59959 to 0.57930, saving model to checkpoints_efficientnetv2b0/cp-0007.ckpt
Epoch 8/100
Epoch 8: val_loss i

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
f1_score,▁▂▂▄▄▅▅▆▆▆▆▆▆▇▇▇▇▇█▇▇▇▇▇██▇██▇██████████
false_negatives,█▇▇▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▁▂▂▂▂▂▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁
false_positives,█▇▇▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▁▂▂▂▂▂▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▅▅▅▄▄▄▄▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
precision,▁▂▂▄▄▅▅▆▆▆▆▆▆▇▇▇▇▇█▇▇▇▇▇██▇██▇██████████
recall,▁▂▂▄▄▅▅▆▆▆▆▆▆▇▇▇▇▇█▇▇▇▇▇██▇██▇██████████
true_negatives,▁▂▂▄▄▅▅▆▆▆▆▆▆▇▇▇▇▇█▇▇▇▇▇██▇██▇██████████
true_positives,▁▂▂▄▄▅▅▆▆▆▆▆▆▇▇▇▇▇█▇▇▇▇▇██▇██▇██████████
val_f1_score,▁▂▄▅▅▆▆▇▇▇▇▇▇▇▇▇▇█▇█████████████████████

0,1
best_epoch,99.0
best_val_loss,0.32086
epoch,99.0
f1_score,0.94081
false_negatives,62.0
false_positives,62.0
loss,0.19684
precision,0.93939
recall,0.93939
true_negatives,961.0


Found 1023 files belonging to 2 classes.
Found 49 files belonging to 2 classes.
Found 13 files belonging to 2 classes.
Epoch 1/100
Epoch 1: val_loss improved from inf to 0.69140, saving model to checkpoints_efficientnetv2b0/cp-0001.ckpt
Epoch 2/100
Epoch 2: val_loss improved from 0.69140 to 0.66475, saving model to checkpoints_efficientnetv2b0/cp-0002.ckpt
Epoch 3/100
Epoch 3: val_loss improved from 0.66475 to 0.64956, saving model to checkpoints_efficientnetv2b0/cp-0003.ckpt
Epoch 4/100
Epoch 4: val_loss improved from 0.64956 to 0.63004, saving model to checkpoints_efficientnetv2b0/cp-0004.ckpt
Epoch 5/100
Epoch 5: val_loss improved from 0.63004 to 0.61032, saving model to checkpoints_efficientnetv2b0/cp-0005.ckpt
Epoch 6/100
Epoch 6: val_loss improved from 0.61032 to 0.58656, saving model to checkpoints_efficientnetv2b0/cp-0006.ckpt
Epoch 7/100
Epoch 7: val_loss improved from 0.58656 to 0.56077, saving model to checkpoints_efficientnetv2b0/cp-0007.ckpt
Epoch 8/100
Epoch 8: val_loss i

VBox(children=(Label(value='0.004 MB of 0.060 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.060089…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
f1_score,▁▂▃▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇███▇█▇▇▇█████████
false_negatives,█▇▆▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▂▁▂▁▁▁▁▁▁▁▁▁
false_positives,█▇▆▅▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▂▁▂▁▁▁▁▁▁▁▁▁
loss,█▇▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁
precision,▁▂▃▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇███▇█▇█▇█████████
recall,▁▂▃▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇███▇█▇█▇█████████
true_negatives,▁▂▃▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇███▇█▇█▇█████████
true_positives,▁▂▃▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇███▇█▇█▇█████████
val_f1_score,▁▃▅▆▇▇▇▇▇▇▇▇████████████████████████████

0,1
best_epoch,99.0
best_val_loss,0.32389
epoch,99.0
f1_score,0.94081
false_negatives,62.0
false_positives,62.0
loss,0.19265
precision,0.93939
recall,0.93939
true_negatives,961.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669614666670895, max=1.0…

Found 1023 files belonging to 2 classes.
Found 49 files belonging to 2 classes.
Found 13 files belonging to 2 classes.
Epoch 1/100
Epoch 1: val_loss improved from inf to 0.63588, saving model to checkpoints_efficientnetv2b0/cp-0001.ckpt
Epoch 2/100
Epoch 2: val_loss improved from 0.63588 to 0.60374, saving model to checkpoints_efficientnetv2b0/cp-0002.ckpt
Epoch 3/100
Epoch 3: val_loss improved from 0.60374 to 0.58335, saving model to checkpoints_efficientnetv2b0/cp-0003.ckpt
Epoch 4/100
Epoch 4: val_loss improved from 0.58335 to 0.56414, saving model to checkpoints_efficientnetv2b0/cp-0004.ckpt
Epoch 5/100
Epoch 5: val_loss improved from 0.56414 to 0.53662, saving model to checkpoints_efficientnetv2b0/cp-0005.ckpt
Epoch 6/100
Epoch 6: val_loss improved from 0.53662 to 0.51706, saving model to checkpoints_efficientnetv2b0/cp-0006.ckpt
Epoch 7/100
Epoch 7: val_loss improved from 0.51706 to 0.49764, saving model to checkpoints_efficientnetv2b0/cp-0007.ckpt
Epoch 8/100
Epoch 8: val_loss i

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
f1_score,▁▂▄▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇█▇▇▇████████████████
false_negatives,█▇▅▅▄▄▄▃▃▂▃▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
false_positives,█▇▅▅▄▄▄▃▃▂▃▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▆▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁
precision,▁▂▄▄▅▅▅▆▆▇▆▇▇▇▇▇▇▇▇▇█▇▇▇████████████████
recall,▁▂▄▄▅▅▅▆▆▇▆▇▇▇▇▇▇▇▇▇█▇▇▇████████████████
true_negatives,▁▂▄▄▅▅▅▆▆▇▆▇▇▇▇▇▇▇▇▇█▇▇▇████████████████
true_positives,▁▂▄▄▅▅▅▆▆▇▆▇▇▇▇▇▇▇▇▇█▇▇▇████████████████
val_f1_score,▁▅▇▇██▇▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇

0,1
best_epoch,99.0
best_val_loss,0.27566
epoch,99.0
f1_score,0.93626
false_negatives,67.0
false_positives,67.0
loss,0.19769
precision,0.93451
recall,0.93451
true_negatives,956.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670198133336575, max=1.0…

Found 1023 files belonging to 2 classes.
Found 49 files belonging to 2 classes.
Found 13 files belonging to 2 classes.
Epoch 1/100
Epoch 1: val_loss improved from inf to 0.54414, saving model to checkpoints_efficientnetv2b0/cp-0001.ckpt
Epoch 2/100
Epoch 2: val_loss improved from 0.54414 to 0.50656, saving model to checkpoints_efficientnetv2b0/cp-0002.ckpt
Epoch 3/100
Epoch 3: val_loss improved from 0.50656 to 0.48490, saving model to checkpoints_efficientnetv2b0/cp-0003.ckpt
Epoch 4/100
Epoch 4: val_loss improved from 0.48490 to 0.46677, saving model to checkpoints_efficientnetv2b0/cp-0004.ckpt
Epoch 5/100
Epoch 5: val_loss improved from 0.46677 to 0.44865, saving model to checkpoints_efficientnetv2b0/cp-0005.ckpt
Epoch 6/100
Epoch 6: val_loss improved from 0.44865 to 0.43379, saving model to checkpoints_efficientnetv2b0/cp-0006.ckpt
Epoch 7/100
Epoch 7: val_loss improved from 0.43379 to 0.42483, saving model to checkpoints_efficientnetv2b0/cp-0007.ckpt
Epoch 8/100
Epoch 8: val_loss i

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
f1_score,▁▃▄▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇█▇██████████████
false_negatives,█▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
false_positives,█▆▅▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▆▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
precision,▁▃▄▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇█▇██████████████
recall,▁▃▄▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇█▇██████████████
true_negatives,▁▃▄▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇█▇██████████████
true_positives,▁▃▄▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇█▇██████████████
val_f1_score,▁▇▇▇██▇▇▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅

0,1
best_epoch,95.0
best_val_loss,0.27041
epoch,99.0
f1_score,0.93069
false_negatives,73.0
false_positives,73.0
loss,0.20152
precision,0.92864
recall,0.92864
true_negatives,950.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669349266673333, max=1.0…

Found 1023 files belonging to 2 classes.
Found 49 files belonging to 2 classes.
Found 13 files belonging to 2 classes.
Epoch 1/100
Epoch 1: val_loss improved from inf to 0.82902, saving model to checkpoints_efficientnetv2b0/cp-0001.ckpt
Epoch 2/100
Epoch 2: val_loss improved from 0.82902 to 0.74039, saving model to checkpoints_efficientnetv2b0/cp-0002.ckpt
Epoch 3/100
Epoch 3: val_loss improved from 0.74039 to 0.69579, saving model to checkpoints_efficientnetv2b0/cp-0003.ckpt
Epoch 4/100
Epoch 4: val_loss improved from 0.69579 to 0.66804, saving model to checkpoints_efficientnetv2b0/cp-0004.ckpt
Epoch 5/100
Epoch 5: val_loss improved from 0.66804 to 0.64165, saving model to checkpoints_efficientnetv2b0/cp-0005.ckpt
Epoch 6/100
Epoch 6: val_loss improved from 0.64165 to 0.61775, saving model to checkpoints_efficientnetv2b0/cp-0006.ckpt
Epoch 7/100
Epoch 7: val_loss improved from 0.61775 to 0.59184, saving model to checkpoints_efficientnetv2b0/cp-0007.ckpt
Epoch 8/100
Epoch 8: val_loss i

VBox(children=(Label(value='0.004 MB of 0.061 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.059770…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
f1_score,▁▃▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████████████
false_negatives,█▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
false_positives,█▆▅▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▇▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
precision,▁▃▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████████████
recall,▁▃▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████████████
true_negatives,▁▃▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████████████
true_positives,▁▃▄▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████████████
val_f1_score,▁▅▅▅▆▇▇▇▇▇▇▇▇███████████████████████████

0,1
best_epoch,97.0
best_val_loss,0.29464
epoch,99.0
f1_score,0.94191
false_negatives,61.0
false_positives,61.0
loss,0.19409
precision,0.94037
recall,0.94037
true_negatives,962.0
