In [None]:
#Basic setup and imports
# main_experiments.ipynb

import os
import sys
from pathlib import Path

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import optimizers
import matplotlib.pyplot as plt

# Make plots a bit nicer
%matplotlib inline

print("TensorFlow version:", tf.__version__)


In [None]:
#Make src/ importable and import models
# Locate the repo root: when the notebook is launched from notebooks/, the root
# is one level up; otherwise the current working directory is taken as the root.
root = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()
src_dir = root / "src"

# Register src/ on sys.path exactly once. Re-running this cell must not pile up
# duplicate entries, and putting it first makes the generic module name
# `models` resolve to src/models ahead of any same-named installed package.
if str(src_dir) not in sys.path:
    sys.path.insert(0, str(src_dir))

print("Repo root:", root)
print("src dir:", src_dir)

from models import build_resnet50_base, build_resnet50_modified


In [None]:
#Data paths and dataset loading
# Root of the processed HAM10000 images; train/, val/ and test/ are read from it
DATA_ROOT = root / "data_ham10000"  # change to "data" if that's what you used

train_dir, val_dir, test_dir = (DATA_ROOT / split for split in ("train", "val", "test"))

print(f"Train dir: {train_dir}")
print(f"Val dir: {val_dir}")
print(f"Test dir: {test_dir}")

IMG_SIZE = (224, 224)
BATCH_SIZE = 32
SEED = 42

# Settings shared by all three splits. No label_mode is passed, so the loader's
# default label encoding applies (integer class indices per the Keras docs —
# confirm for the installed version).
_loader_kwargs = dict(image_size=IMG_SIZE, batch_size=BATCH_SIZE)

# Train and validation splits are shuffled with a fixed seed
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir, shuffle=True, seed=SEED, **_loader_kwargs
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    val_dir, shuffle=True, seed=SEED, **_loader_kwargs
)

# The test split is read without shuffling
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    test_dir, shuffle=False, **_loader_kwargs
)

# Class list is taken from the training split
class_names = train_ds.class_names
num_classes = len(class_names)
print("Classes:", class_names, " | num_classes =", num_classes)


In [None]:
#Prefetch and basic normalization

AUTOTUNE = tf.data.AUTOTUNE


def configure_ds(ds, shuffle=False):
    """Return ``ds`` with images cast to float32 and divided by 255, then prefetched.

    Args:
        ds: dataset yielding ``(images, labels)`` pairs; labels pass through unchanged.
        shuffle: when true, shuffle the mapped elements through a 1000-element
            buffer. The buffer counts dataset elements, so on an already
            batched dataset it reorders whole batches.

    Returns:
        The transformed dataset with ``prefetch(AUTOTUNE)`` applied last.
    """
    def _rescale(images, labels):
        return tf.cast(images, tf.float32) / 255.0, labels

    prepared = ds.map(_rescale, num_parallel_calls=AUTOTUNE)
    if shuffle:
        prepared = prepared.shuffle(1000)
    return prepared.prefetch(AUTOTUNE)


# Only the training split gets the extra shuffle
train_ds_prep = configure_ds(train_ds, shuffle=True)
val_ds_prep = configure_ds(val_ds)
test_ds_prep = configure_ds(test_ds)


In [None]:
#Build both models and give a summary

# Instantiate the baseline and the modified network, both sized to the dataset's class count
resnet50_base = build_resnet50_base(num_classes=num_classes)
resnet50_mod = build_resnet50_modified(num_classes=num_classes)

# Print each architecture, baseline first
for _model in (resnet50_base, resnet50_mod):
    _model.summary()


In [None]:
#Compile both models
# The datasets are loaded without a label_mode argument, so labels arrive as
# integer class indices (Keras default label_mode="int"), not one-hot vectors.
# The sparse loss matches that encoding; CategoricalCrossentropy requires
# one-hot targets and fails with a shape mismatch on integer labels.
# from_logits=False assumes the models end in a softmax — TODO confirm against
# src/models.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
metrics = ["accuracy"]

# Each model gets its own Adam instance so optimizer state is never shared.
resnet50_base.compile(
    optimizer=optimizers.Adam(learning_rate=1e-4),
    loss=loss_fn,
    metrics=metrics,
)

resnet50_mod.compile(
    optimizer=optimizers.Adam(learning_rate=1e-4),
    loss=loss_fn,
    metrics=metrics,
)


In [None]:
#Callbacks to save the best models

# Best checkpoints are written under <repo root>/models
models_dir = root / "models"
models_dir.mkdir(exist_ok=True)

base_ckpt_path = models_dir / "best_resnet50_base.keras"
mod_ckpt_path = models_dir / "best_resnet50_modified.keras"

# Both checkpoints follow the same policy: save only when val_accuracy improves
_ckpt_policy = dict(monitor="val_accuracy", save_best_only=True, mode="max", verbose=1)

checkpoint_base = keras.callbacks.ModelCheckpoint(filepath=str(base_ckpt_path), **_ckpt_policy)
checkpoint_mod = keras.callbacks.ModelCheckpoint(filepath=str(mod_ckpt_path), **_ckpt_policy)

# Stop once val_accuracy has not improved for 5 epochs; restore_best_weights=True
# asks Keras to put the best-epoch weights back on the model.
early_stop = keras.callbacks.EarlyStopping(
    restore_best_weights=True,
    mode="max",
    patience=5,
    monitor="val_accuracy",
)


In [None]:
#Train baseline ResNet50 

# Upper bound on epochs; the early-stopping callback may end training sooner
EPOCHS = 15

# Fit the baseline model; history_base keeps the per-epoch metrics
history_base = resnet50_base.fit(
    x=train_ds_prep,
    epochs=EPOCHS,
    validation_data=val_ds_prep,
    callbacks=[checkpoint_base, early_stop],
)


In [None]:
#Train modified ResNet50 with squeeze-and-excitation (SE)

# A dedicated EarlyStopping instance for this run (same settings as the
# baseline's); change patience here if the modified model needs a different one.
early_stop_mod = keras.callbacks.EarlyStopping(
    restore_best_weights=True,
    mode="max",
    patience=5,
    monitor="val_accuracy",
)

# Fit the modified model; history_mod keeps the per-epoch metrics
history_mod = resnet50_mod.fit(
    x=train_ds_prep,
    epochs=EPOCHS,
    validation_data=val_ds_prep,
    callbacks=[checkpoint_mod, early_stop_mod],
)


In [None]:
#Find and save Confusion Matrix

import numpy as np
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

figures_dir = root / "figures"
figures_dir.mkdir(exist_ok=True)

def save_confusion_matrix(model, dataset, class_names, output_path):
    """Compute the confusion matrix of ``model`` over ``dataset`` and save its plot.

    Args:
        model: Keras model exposing ``predict_on_batch``; its output is reduced
            with argmax over axis 1, so it is assumed to be per-class scores of
            shape (batch, num_classes) — TODO confirm for the models used.
        dataset: iterable of ``(images, labels)`` batches. Labels may be integer
            class indices or one-hot rows of width ``len(class_names)``.
        class_names: class labels, in class-index order, used for the axis ticks.
        output_path: file the figure is written to.

    Returns:
        The raw confusion matrix (rows = true class, columns = predicted class),
        always ``len(class_names)`` x ``len(class_names)``.
    """
    n_classes = len(class_names)
    true_batches = []
    pred_batches = []

    for batch_images, batch_labels in dataset:
        # One forward pass per batch. Model.predict is documented as unsuited to
        # being called inside a loop over small batches.
        scores = np.asarray(model.predict_on_batch(batch_images))
        pred_batches.append(np.argmax(scores, axis=1))

        labels = np.asarray(batch_labels)
        # One-hot labels are collapsed to class indices; integer labels pass through.
        if labels.ndim > 1 and labels.shape[-1] == n_classes:
            labels = np.argmax(labels, axis=1)
        true_batches.append(labels.reshape(-1))

    y_true = np.concatenate(true_batches) if true_batches else np.array([], dtype=int)
    y_pred = np.concatenate(pred_batches) if pred_batches else np.array([], dtype=int)

    # Pin the label set: without it, a class that never occurs in y_true/y_pred
    # is dropped from the matrix and no longer lines up with display_labels.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(n_classes))
    print("Confusion matrix (raw):\n", cm)

    fig, ax = plt.subplots(figsize=(6, 6))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(ax=ax, values_format="d", xticks_rotation=45)
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)

    return cm


In [None]:
#Evaluate on test set and save confusion matrices
import time

# Evaluate baseline on the test split and time the evaluation pass
print("Baseline ResNet50 on test set: ")
base_start = time.perf_counter()
test_loss_base, test_acc_base = resnet50_base.evaluate(test_ds_prep)
base_end = int(time.perf_counter() - base_start)
print(f"Test accuracy (base): {test_acc_base:.4f}")
# The timer brackets evaluate() only, so this is evaluation time, not training time
print(f"Evaluation time (base): {base_end} s")

# Evaluate modified model the same way
print("Modified ResNet50 + SE on test set: ")
mod_start = time.perf_counter()
test_loss_mod, test_acc_mod = resnet50_mod.evaluate(test_ds_prep)
mod_end = int(time.perf_counter() - mod_start)
print(f"Test accuracy (modified): {test_acc_mod:.4f}")
print(f"Evaluation time (modified): {mod_end} s")

# Save confusion matrices
cm_base_path = figures_dir / "confusion_matrix_resnet50_base.png"
cm_mod_path  = figures_dir / "confusion_matrix_resnet50_modified.png"

save_confusion_matrix(resnet50_base, test_ds_prep, class_names, cm_base_path)
save_confusion_matrix(resnet50_mod,  test_ds_prep, class_names, cm_mod_path)

print("Saved confusion matrices to:")
print("  ", cm_base_path)
print("  ", cm_mod_path)


In [None]:
#Plot the training curves

def plot_history(history, title, output_path):
    """Plot training vs. validation accuracy per epoch, save the figure, and show it.

    Args:
        history: object whose ``history`` dict has "accuracy" and "val_accuracy" series.
        title: figure title.
        output_path: file the figure is written to (300 dpi).
    """
    fig, ax = plt.subplots(figsize=(6, 4))

    curves = history.history
    ax.plot(curves["accuracy"], label="train_acc")
    ax.plot(curves["val_accuracy"], label="val_acc")

    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_title(title)
    ax.legend()

    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.show()

# Output files for each model's accuracy curves
train_base_path = figures_dir / "training_curves_resnet50_base.png"
train_mod_path = figures_dir / "training_curves_resnet50_modified.png"

# Draw and save one figure per training run, baseline first
for _history, _title, _path in (
    (history_base, "Baseline ResNet50", train_base_path),
    (history_mod, "Modified ResNet50 + SE", train_mod_path),
):
    plot_history(_history, _title, _path)
