In [None]:
from pathlib import Path

# The two class labels compared in this experiment. They are passed to
# prepare_data() and reused later as `class_names`.
first_label = "circle"
second_label = "lack_of_fusion"

# Root folder that prepare_data() searches recursively for .jpg images.
src_folder = "dataset_by_aspect_ratio"
import shutil


def prepare_data(src_folder: str | Path, *labels) -> str:
    if isinstance(src_folder, str):
        src_folder = Path(src_folder)
    temporarily_folder = "-versus-".join(labels)
    temporarily_folder = Path(temporarily_folder)
    temporarily_folder.mkdir(parents=True, exist_ok=True)
    for file in src_folder.rglob("**/*.jpg"):
        if file.parent.name in labels:
            rel_path = file.relative_to(src_folder)
            tgt_path = temporarily_folder / rel_path
            tgt_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy(file, tgt_path)
    return temporarily_folder.name

In [2]:
# Build the working copy containing only the two selected classes;
# `temp_folder` holds the folder name used by the following cells.
temp_folder = prepare_data(src_folder, first_label, second_label)

In [None]:
import os

# Select the JAX backend. This is set before the `import keras` line below so
# the variable is already in place when Keras is imported.
os.environ["KERAS_BACKEND"] = "jax"
import keras
import numpy as np
import io
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Number of training epochs; also read by cosine_annealing_scheduler() and
# used to derive the EarlyStopping patience.
epochs = 120


# Learning rate scheduler
def cosine_annealing_scheduler(epoch, lr):
    """Cosine-annealed learning rate that restarts every ``epochs // 2`` epochs.

    Within each cycle the rate starts at ``1e-3`` and follows half a cosine
    wave down towards ``1e-6``. Because the epoch index is taken modulo the
    cycle length, the schedule jumps back to ``1e-3`` at the start of each new
    cycle.

    Args:
        epoch: Epoch index supplied by the scheduler callback.
        lr: Current learning rate. Accepted to match the callback signature
            but not used; the schedule depends on ``epoch`` only.

    Returns:
        The learning rate for this epoch as a Python ``float``.
    """
    initial_lr = 1e-3
    min_lr = 1e-6
    # Clamp to 1 so very short runs (epochs < 2) do not divide by zero.
    T_max = max(1, int(epochs / 2))

    cosine_decay = 0.5 * (1 + np.cos(np.pi * (epoch % T_max) / T_max))
    new_lr = (initial_lr - min_lr) * cosine_decay + min_lr

    return float(new_lr)


import numpy as np
from keras.utils import image_dataset_from_directory
from utils import create_images_dataset


# Load the train/validation splits with the project helper
# `create_images_dataset`.
batch_size = 32
img_size = (512, 512)

class_names = [first_label, second_label]

train_ds = create_images_dataset(
    f"{temp_folder}/train",
    class_names=class_names,
    target_size=img_size,
    train=True,
)

val_ds = create_images_dataset(
    f"{temp_folder}/valid",
    class_names=class_names,
    target_size=img_size,
    train=False,
)

# Batch with the shared `batch_size` constant rather than a repeated literal,
# so changing it in one place affects both pipelines.
train_ds = train_ds.shuffle(100).batch(batch_size).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

from sklearn.utils.class_weight import compute_class_weight

# Collect integer class indices from the training labels (argmax over the
# class axis; assumes one-hot / per-class label vectors, matching the
# CategoricalCrossentropy loss used for training).
# The weights are applied to the *training* loss through
# `model.fit(class_weight=...)`, so they are derived from the training split
# rather than from the validation split.
y_train = []
for _, batch_labels in train_ds:
    y_train.extend(np.argmax(batch_labels.numpy(), axis=1))

present_classes = np.unique(y_train)
balanced_weights = compute_class_weight(
    class_weight="balanced", classes=present_classes, y=np.asarray(y_train)
)

# Key every weight by its actual class index; enumerate() would attach the
# weights to the wrong classes whenever a class is missing from the split.
class_weights = {
    int(class_index): float(weight)
    for class_index, weight in zip(present_classes, balanced_weights)
}

# Model dimensions derived from the dataset configuration.
num_classes = len(class_names)
# Single-channel input at the resolution configured in `img_size`.
input_shape = (*img_size, 1)

# model = new_model




class ConfusionMatrixCallback(keras.callbacks.Callback):
    """Log confusion matrices of the validation data to TensorBoard each epoch.

    After every epoch the model predicts all batches of ``val_data``. Two
    figures are rendered: absolute counts, and percentages normalised per
    true-class row. Both are written as image summaries below
    ``<log_dir>/cm``.

    Args:
        val_data: Iterable yielding ``(images, labels)`` batches. Labels must
            expose ``.numpy()`` and are reduced with ``argmax`` over axis 1
            (i.e. one-hot / per-class vectors are expected).
        class_names: Optional tick labels. When given, the matrix is always
            built for ``len(class_names)`` classes.
        log_dir: Base directory for the TensorBoard event files.
    """

    def __init__(self, val_data, class_names=None, log_dir="logs"):
        super().__init__()
        self.val_data = val_data
        self.class_names = class_names
        self.file_writer = tf.summary.create_file_writer(f"{log_dir}/cm")

    @staticmethod
    def _render_to_image(matrix, display_labels, cmap, values_format, title):
        """Plot one confusion matrix and return it as a (1, H, W, 4) image tensor."""
        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            display = ConfusionMatrixDisplay(
                confusion_matrix=matrix, display_labels=display_labels
            )
            display.plot(ax=ax, cmap=cmap, values_format=values_format)
            ax.set_title(title)
            buffer = io.BytesIO()
            # Save through the figure object instead of relying on pyplot's
            # implicit "current figure".
            fig.savefig(buffer, format="png")
        finally:
            # Always release the figure; one is created per matrix per epoch.
            plt.close(fig)
        image = tf.image.decode_png(buffer.getvalue(), channels=4)
        return tf.expand_dims(image, 0)

    def on_epoch_end(self, epoch, logs=None):
        """Predict the validation data and write both confusion-matrix images."""
        y_true = []
        y_pred = []

        for images, labels in self.val_data:
            preds = self.model.predict(images, verbose=0)
            y_pred.extend(np.argmax(preds, axis=1))
            y_true.extend(np.argmax(labels.numpy(), axis=1))

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)

        # Fix the label set so the matrix shape always matches `class_names`,
        # even when a class appears in neither the targets nor the predictions.
        label_ids = (
            None if self.class_names is None else np.arange(len(self.class_names))
        )
        cm_counts = confusion_matrix(y_true, y_pred, labels=label_ids)

        # Row-normalise to percentages; rows without any true samples stay at
        # 0 instead of producing NaN from a division by zero.
        row_totals = cm_counts.sum(axis=1, keepdims=True)
        cm_percent = np.divide(
            cm_counts * 100.0,
            row_totals,
            out=np.zeros(cm_counts.shape, dtype=float),
            where=row_totals != 0,
        )

        counts_image = self._render_to_image(
            cm_counts,
            self.class_names,
            cmap="Blues",
            values_format="d",
            title=f"Confusion Matrix (Counts) - Epoch {epoch}",
        )
        percent_image = self._render_to_image(
            cm_percent,
            self.class_names,
            cmap="Oranges",
            values_format=".1f",
            title=f"Confusion Matrix (Percentage) - Epoch {epoch}",
        )

        # Log both images to TensorBoard
        with self.file_writer.as_default():
            tf.summary.image("Confusion Matrix - Counts", counts_image, step=epoch)
            tf.summary.image(
                "Confusion Matrix - Percentage", percent_image, step=epoch
            )
        # Flush so the images show up while training is still running.
        self.file_writer.flush()


In [None]:

# Checkpoints: one file tracks the monitored validation accuracy, the other
# the monitored validation loss.
best_acc_checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=f"models/{temp_folder}_best_val_acc.keras",
    monitor="val_acc",
    save_best_only=True,
)
best_loss_checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=f"models/{temp_folder}_best_val_loss.keras",
    monitor="val_loss",
    save_best_only=True,
)

# Stop when val_loss has not improved for a quarter of the planned epochs.
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=int(epochs / 4),
)

tensorboard_logger = keras.callbacks.TensorBoard(log_dir="logs")
lr_schedule = keras.callbacks.LearningRateScheduler(
    cosine_annealing_scheduler,
    verbose=1,
)
confusion_matrix_logger = ConfusionMatrixCallback(val_ds, class_names=class_names)

# Same callbacks, in the same order, as passed to model.fit().
callbacks = [
    best_acc_checkpoint,
    best_loss_checkpoint,
    early_stopping,
    tensorboard_logger,
    lr_schedule,
    confusion_matrix_logger,
]


In [None]:
# Display the class-weight mapping that is later passed to model.fit().
class_weights

In [None]:

from layers import build_resnet, build_simple_cnn

# Build the classifier with the project helper `build_resnet`.
# NOTE(review): the preset name suggests ImageNet-pretrained weights, while
# `input_shape` is single-channel — confirm build_resnet handles this.
model = build_resnet(
    input_shape=input_shape,
    num_classes=num_classes,
    preset="resnet_18_imagenet",
)

In [None]:
# Show the architecture summary of the model built above.
model.summary()

In [None]:
# Categorical cross-entropy with a categorical-accuracy metric reported as "acc"
# (this name is what the "val_acc" checkpoint monitor refers to).
model.compile(
    loss=keras.losses.CategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3, weight_decay=1e-6),
    metrics=[keras.metrics.CategoricalAccuracy(name="acc")],
)

model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=callbacks,
    class_weight=class_weights,
)

# The scores below are computed on the validation data (no separate test split
# is used here), so they are labelled as validation scores. They reflect the
# weights at the end of training; the best checkpoints are saved under models/.
score = model.evaluate(val_ds, verbose=0)
print(f"Validation loss: {score[0]}")
print(f"Validation accuracy: {score[1]}")





In [4]:
# Clean up: delete the working folder created by prepare_data(). Errors
# (e.g. the folder is already gone) are ignored.
shutil.rmtree(Path(temp_folder), ignore_errors=True)