In [None]:
# TF_CPP_MIN_LOG_LEVEL has to be in the environment *before* TensorFlow is
# imported for the first time: the native logger reads it when it emits its
# first message, which normally happens during import. '3' filters out
# INFO, WARNING and ERROR messages from the C++ backend.
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import pathlib
import shutil

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import tf2onnx
import wandb
from wandb.keras import WandbCallback

In [None]:
# Log in to Weights & Biases; the WandbCallback passed to model.fit() in the
# training cell reports to it.
wandb.login()

In [None]:
# Relative path to the dataset folder. It is not read by any cell visible in
# this notebook — presumably the dataset-loading step consumes it (confirm).
DATA_DIR = pathlib.Path("augmented_dataset")
# Used as buffer_size in the prefetch() calls of the split cell, so tf.data
# picks the prefetch depth itself instead of using a fixed number.
AUTOTUNE = tf.data.AUTOTUNE

# Single place for the hyperparameters and data settings used by the notebook.
CONFIG = {
    # Optimisation
    "epochs": 100,
    "learning_rate": 1e-3,
    "batch_size": 16,
    # Image / model geometry
    "img_shape": (224, 224),
    "input_shape": (224, 224, 3),
    "num_classes": 2,
    # Regularisation and stopping
    "dropout_rate": 0.2,
    "es_patience": 20,
    # Data split and reproducibility
    "valid_split": 0.4,
    "seed_value": 42,
}

In [None]:
# from preprocessing import preprocess_image

# NOTE(review): `train_set` and `valid_set` are not created in any cell visible
# here; this cell assumes an earlier step built them as batched tf.data
# datasets — confirm it survives Restart & Run All.
validation_batches = tf.data.experimental.cardinality(valid_set)

# An unknown (-2) or infinite (-1) cardinality floor-divides to -1, and
# skip(-1)/take(-1) mean "the whole dataset": validation would silently end up
# empty and the test split would be all of valid_set. Fail loudly instead.
if int(validation_batches) < 0:
    raise ValueError(
        "valid_set has unknown or infinite cardinality; "
        "cannot carve a test split out of it"
    )

# The first fifth of the validation batches becomes the held-out test split,
# the remaining batches stay as validation data.
# NOTE(review): if valid_set reshuffles on every iteration, take()/skip() will
# not select the same batches each time and the two splits can overlap —
# confirm how valid_set is built.
test_batches = validation_batches // 5
validation_dataset = valid_set.skip(test_batches)
test_dataset = valid_set.take(test_batches)

# Prefetch for faster loading
train_dataset = train_set.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

# Get number of batches per split
num_training_batches = tf.data.experimental.cardinality(train_dataset)
num_validation_batches = tf.data.experimental.cardinality(validation_dataset)
num_test_batches = tf.data.experimental.cardinality(test_dataset)

print(f"\nNumber of training batches: {num_training_batches}")
print(f"Number of validation batches: {num_validation_batches}")
print(f"Number of test batches: {num_test_batches}")

In [None]:
def create_model():
    """Build and compile the transfer-learning classifier.

    A frozen MobileNetV3Small backbone (created without its classification
    top) feeds global average pooling, dropout and a softmax ``Dense`` head.
    Input shape, class count, dropout rate/seed and learning rate are all read
    from the module-level ``CONFIG`` so the model stays in sync with the rest
    of the notebook.

    Returns:
        tf.keras.Model: model compiled with Adam, categorical cross-entropy
        and the precision / recall / F1 / FN / TP metrics defined below.
    """
    # Taken from CONFIG instead of a literal so the head and the F1 metric
    # cannot drift apart from the configured number of classes.
    num_classes = CONFIG["num_classes"]

    # NOTE(review): none of the tf.keras metrics below is given a class_id, so
    # with a multi-unit softmax output they aggregate over every output column
    # rather than a single "positive" class — confirm this is intended.
    metrics = [
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.Recall(name='recall'),
        tfa.metrics.F1Score(num_classes=num_classes, average="weighted", name="f1_score", threshold=0.5),
        tf.keras.metrics.FalseNegatives(name='false_negatives'),
        tf.keras.metrics.TruePositives(name='true_positives')
    ]

    # Backbone used as a fixed feature extractor: its weights are not updated.
    base_model = tf.keras.applications.MobileNetV3Small(input_shape=CONFIG["input_shape"], include_top=False)
    base_model.trainable = False

    inputs = tf.keras.Input(shape=CONFIG["input_shape"])
    # training=False makes the backbone run in inference mode even during fit().
    x = base_model(inputs, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(rate=CONFIG["dropout_rate"], seed=CONFIG["seed_value"])(x)
    outputs = tf.keras.layers.Dense(units=num_classes, activation='softmax')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=CONFIG["learning_rate"]),
        # The head already applies softmax, hence from_logits=False.
        loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
        metrics=metrics,
    )

    return model

In [None]:
# Class weights inversely proportional to class frequency: each class gets
# image_count / (2 * class_count), so the class with fewer images gets the
# larger weight.
# NOTE(review): healthy_count, wssv_count and image_count are not defined in
# any cell visible here — confirm they are computed before this cell runs.
class_0_weight = (1 / healthy_count) * (image_count / 2.0)
class_1_weight = (1 / wssv_count) * (image_count / 2.0)
class_weight = {0: class_0_weight, 1: class_1_weight}

callbacks = [
    # Disabled for now; reads its patience from CONFIG when re-enabled.
    # tf.keras.callbacks.EarlyStopping(patience=CONFIG["es_patience"]),
    # tf.keras.callbacks.ModelCheckpoint(filepath="checkpoints_training/cp-{epoch:04d}.ckpt",
    #                                    save_weights_only=True,
    #                                    monitor="val_loss",
    #                                    mode="min",
    #                                    verbose=1,
    #                                    save_best_only=True),
    WandbCallback(save_model=False),
    tfa.callbacks.TQDMProgressBar()
]

# NOTE(review): `model` is not assigned in any cell visible here; presumably it
# comes from create_model() — confirm.
# The epoch count is read from CONFIG (the notebook's single source of
# hyperparameters) instead of a bare `epochs` name that no visible cell defines.
model.fit(train_dataset,
          epochs=CONFIG["epochs"],
          validation_data=validation_dataset,
          class_weight=class_weight,
          callbacks=callbacks)