In [None]:
from pathlib import Path


In [None]:
# Dataset root, relative to the notebook's working directory. Only the final path
# component is kept because the same string doubles as the checkpoint filename stem.
dataset_root = Path("valid-versus-invalid")
temp_folder = dataset_root.name

In [None]:
import os

os.environ["KERAS_BACKEND"] = "jax"
import keras
import numpy as np
import io
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

epochs = 120


# Learning rate scheduler
def cosine_annealing_scheduler(epoch, lr):
    initial_lr = 1e-3
    min_lr = 1e-6
    T_max = int(epochs / 2)

    cosine_decay = 0.5 * (1 + np.cos(np.pi * (epoch % T_max) / T_max))
    new_lr = (initial_lr - min_lr) * cosine_decay + min_lr

    return float(new_lr)


import numpy as np
from keras.utils import image_dataset_from_directory


# Load datasets using Keras utilities
batch_size = 8
img_size = (640, 640)


def _load_split(split, shuffle):
    """Load ``{temp_folder}/{split}`` as batched grayscale images with one-hot labels.

    Images are resized to ``img_size`` with padding (``pad_to_aspect_ratio``)
    rather than stretching.
    """
    return image_dataset_from_directory(
        f"{temp_folder}/{split}",
        label_mode="categorical",
        image_size=img_size,
        batch_size=batch_size,
        color_mode="grayscale",
        pad_to_aspect_ratio=True,
        shuffle=shuffle,
    )


train_ds = _load_split("train", shuffle=True)
# Shuffling does not change validation metrics; a fixed order keeps per-epoch
# evaluation (and the confusion-matrix logging) deterministic.
val_ds = _load_split("valid", shuffle=False)


from sklearn.utils.class_weight import compute_class_weight

# Collect the integer class index of every training sample (labels are one-hot
# because label_mode="categorical"), then derive sklearn "balanced" weights,
# n_samples / (n_classes * count), for model.fit(class_weight=...).
y_train = []
for _, labels in train_ds:
    y_train.extend(np.argmax(labels.numpy(), axis=1))

present_classes = np.unique(y_train)
balanced_weights = compute_class_weight(
    class_weight="balanced", classes=present_classes, y=y_train
)

# Key each weight by its actual class index rather than its enumerate() position.
# If a class had no training samples, positional keys would shift every later
# weight onto the wrong class.
class_weights = {
    int(cls): float(weight) for cls, weight in zip(present_classes, balanced_weights)
}

class_weights

# Get class names (inferred by image_dataset_from_directory from the train subdirectories)
class_names = train_ds.class_names
num_classes = len(class_names)
# Derived from img_size so the model input always matches the loaded images.
# color_mode="grayscale" yields a single channel.
input_shape = (*img_size, 1)




class ConfusionMatrixCallback(keras.callbacks.Callback):
    """Log count- and row-percentage confusion matrices to TensorBoard after each epoch.

    At the end of every epoch the model is run over ``val_data``. Predicted and
    true class indices (both via ``argmax`` over the last axis) are compared and
    rendered as two matplotlib figures. The figures are encoded as PNG and
    written as image summaries under ``{log_dir}/cm``.

    Args:
        val_data: Iterable of ``(images, labels)`` batches whose labels are
            one-hot (e.g. a dataset built with ``label_mode="categorical"``).
        class_names: Display names in label-index order. When given, the
            matrices always have ``len(class_names)`` rows and columns, even if a
            class is absent from the validation data and never predicted.
        log_dir: TensorBoard log directory; summaries go to ``{log_dir}/cm``.
    """

    def __init__(self, val_data, class_names=None, log_dir="logs"):
        super().__init__()
        self.val_data = val_data
        self.class_names = class_names
        self.file_writer = tf.summary.create_file_writer(f"{log_dir}/cm")

    def _collect_predictions(self):
        """Run the model over ``val_data`` and return ``(y_true, y_pred)`` index arrays."""
        y_true, y_pred = [], []
        for images, labels in self.val_data:
            preds = self.model.predict(images, verbose=0)
            y_pred.append(np.argmax(preds, axis=1))
            y_true.append(np.argmax(np.asarray(labels), axis=1))
        if not y_true:
            empty = np.empty(0, dtype=int)
            return empty, empty
        return np.concatenate(y_true), np.concatenate(y_pred)

    @staticmethod
    def _figure_to_image(fig):
        """Encode ``fig`` as PNG, close it, and return a batched ``(1, H, W, 4)`` image tensor."""
        buf = io.BytesIO()
        # Save through the figure object rather than pyplot's implicit "current figure".
        fig.savefig(buf, format="png")
        plt.close(fig)
        image = tf.image.decode_png(buf.getvalue(), channels=4)
        return tf.expand_dims(image, 0)

    def _render(self, matrix, title, cmap, values_format):
        """Plot ``matrix`` with ``ConfusionMatrixDisplay`` and return it as an image tensor."""
        fig, ax = plt.subplots(figsize=(10, 10))
        ConfusionMatrixDisplay(
            confusion_matrix=matrix, display_labels=self.class_names
        ).plot(ax=ax, cmap=cmap, values_format=values_format)
        ax.set_title(title)
        return self._figure_to_image(fig)

    def on_epoch_end(self, epoch, logs=None):
        """Compute both confusion matrices for ``epoch`` and write them as image summaries."""
        y_true, y_pred = self._collect_predictions()
        if y_true.size == 0:
            return  # empty validation set: nothing to summarise

        # Pin the label set so the matrix is always len(class_names) square. Otherwise
        # a class missing from both y_true and y_pred shrinks the matrix, and the
        # tick-label count mismatch makes ConfusionMatrixDisplay.plot fail.
        labels = None if self.class_names is None else np.arange(len(self.class_names))
        cm_counts = confusion_matrix(y_true, y_pred, labels=labels)

        # Row-normalise to percentages; rows with no true samples stay 0 instead of NaN.
        row_totals = cm_counts.sum(axis=1, keepdims=True)
        cm_percent = np.divide(
            cm_counts * 100.0,
            row_totals,
            out=np.zeros(cm_counts.shape, dtype=float),
            where=row_totals > 0,
        )

        image_counts = self._render(
            cm_counts, f"Confusion Matrix (Counts) - Epoch {epoch}", "Blues", "d"
        )
        image_percent = self._render(
            cm_percent, f"Confusion Matrix (Percentage) - Epoch {epoch}", "Oranges", ".1f"
        )

        # Log both images to TensorBoard
        with self.file_writer.as_default():
            tf.summary.image("Confusion Matrix - Counts", image_counts, step=epoch)
            tf.summary.image("Confusion Matrix - Percentage", image_percent, step=epoch)

    def on_train_end(self, logs=None):
        """Flush pending summaries so the final epoch's images reach disk."""
        self.file_writer.flush()


In [None]:

# Shared by TensorBoard and the confusion-matrix callback so all summaries land in one run.
log_dir = "logs"

callbacks = [
    # `mode` is explicit so the "best" direction never depends on Keras inferring it
    # from the metric name.
    keras.callbacks.ModelCheckpoint(
        filepath=f"models/{temp_folder}_best_val_acc.keras",
        save_best_only=True,
        monitor="val_acc",
        mode="max",
    ),
    keras.callbacks.ModelCheckpoint(
        filepath=f"models/{temp_folder}_best_val_loss.keras",
        save_best_only=True,
        monitor="val_loss",
        mode="min",
    ),
    # Patience equals the cosine-schedule cycle length (epochs // 2). The best weights
    # are not restored on stop (restore_best_weights defaults to False).
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=epochs // 2),
    keras.callbacks.TensorBoard(log_dir=log_dir),
    keras.callbacks.LearningRateScheduler(cosine_annealing_scheduler, verbose=1),
    ConfusionMatrixCallback(val_ds, class_names=class_names, log_dir=log_dir),
]


In [None]:
# Display the class-weight mapping that model.fit(class_weight=...) will receive.
class_weights

In [None]:
from layers import (
    build_resnet,
    build_simple_cnn,
    RandomSmoothingModel,
    RandomSmoothingLoss,
)

# Load the previously saved valid-vs-invalid classifier (uncompiled) and take its
# "resnet_18_imagenet" sub-layer.
# NOTE(review): `base_backbone` is not used by any cell shown here. The freshly built
# `base_model` below presumably starts from the preset weights, not from this
# checkpoint. If fine-tuning from the saved model was intended, its weights still
# need to be transferred — confirm.
pretrained_classifier = keras.models.load_model(
    "models/valid-v-invalid.keras", compile=False
)
base_backbone = pretrained_classifier.get_layer("resnet_18_imagenet")

# Build a new classifier sized for this dataset, then re-wrap the same graph in the
# RandomSmoothingModel class from the local `layers` module.
base_model = build_resnet(
    preset="resnet_18_imagenet",
    input_shape=input_shape,
    num_classes=num_classes,
    should_scale=True,
)
model = RandomSmoothingModel(outputs=base_model.outputs, inputs=base_model.inputs)

In [None]:
# Print the layer-by-layer architecture and parameter counts of the wrapped model.
model.summary()

In [None]:
# RandomSmoothingLoss comes from the local `layers` module (presumably randomised
# label smoothing within `smoothing_range` — confirm). The learning rate given to
# Adam is replaced every epoch by the LearningRateScheduler in `callbacks`.
model.compile(
    loss=RandomSmoothingLoss(keras.losses.CategoricalCrossentropy(), smoothing_range=(0.0, 0.1)),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3, weight_decay=1e-6),
    metrics=[keras.metrics.CategoricalAccuracy(name="acc")],
)

# Keep the History object so per-epoch metrics can be inspected or plotted afterwards.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=callbacks,
    class_weight=class_weights,
)

# This scores the model's final (not best-checkpoint) weights on the *validation*
# split; no separate test split is loaded in this notebook's visible cells.
score = model.evaluate(val_ds, verbose=0)
print(f"Validation loss: {score[0]}")
print(f"Validation accuracy: {score[1]}")



