In [43]:
import os
import json
import random
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

In [47]:
SEED = 999
# Apply the seed to every RNG in play (Python, NumPy, TensorFlow) so a
# Restart & Run All reproduces the same results.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

DATA_DIR = Path("../data")
PREPARED_CSV = DATA_DIR / "training_prepared_data.csv"
IMAGE_PATH = DATA_DIR.joinpath("images", "images")  # relative image_path values are resolved against this
OUTDIR = Path("outputs/simple_twohead_b1") # Directory where the trained model is saved
MODEL_PATH = OUTDIR / "best_model.h5"

IMG_SIZE = 224  # final square side length of model inputs
BATCH_SIZE = 32

# --- Class Definitions (MUST match training) ---
# Class index == position in these lists, so DX_CLASSES is in sorted order; the
# head1_idx values in the prepared CSV are assumed to use the same ordering — TODO confirm.
# FIXME: this list has 9 names, but the saved checkpoint's fine_output kernel is
# (1280, 11) and head1_idx in the prepared CSV takes values up to 10, so loading
# best_model.h5 fails with a shape mismatch. Restore the exact 11-class list used
# at training time (ideally load it from a label map saved next to the model).
DX_CLASSES = sorted(['nv', 'mel', 'bkl', 'bcc', 'scc_akiec', 'vasc', 'df', 'other', 'no_lesion'])
LESION_TYPE_CLASSES = ["benign", "malignant", "no_lesion"]
N_DX_CLASSES = len(DX_CLASSES)
N_LESION_TYPE_CLASSES = len(LESION_TYPE_CLASSES)

In [48]:
def build_augmenter(is_training):
    """Return the deterministic evaluation-time image preprocessor.

    The pipeline resizes to 256x256 and centre-crops to IMG_SIZE x IMG_SIZE.
    Only the evaluation path exists in this notebook.

    Raises:
        ValueError: if ``is_training`` is truthy.
    """
    if not is_training:
        eval_steps = [
            layers.Resizing(256, 256),
            layers.CenterCrop(IMG_SIZE, IMG_SIZE),
        ]
        return keras.Sequential(eval_steps, name="preprocessor")
    raise ValueError("build_augmenter should not be called with is_training=True during evaluation.")

def build_dataset(df, is_training=False, batch_size=None):
    """Build a batched, prefetched ``tf.data`` evaluation pipeline from a DataFrame.

    Rows missing ``image_path`` or ``head2_idx`` are dropped. A missing
    ``head1_idx`` is encoded as ``-1`` so the fine head can be masked downstream.
    Relative ``image_path`` values are resolved against ``IMAGE_PATH``.

    Args:
        df: DataFrame with ``image_path``, ``head1_idx`` and ``head2_idx`` columns.
        is_training: must be False; only evaluation pipelines are built here.
        batch_size: batch size; ``None`` (default) uses the global ``BATCH_SIZE``.

    Returns:
        ``tf.data.Dataset`` yielding
        ``(images, {"fine_output": fine_label, "coarse_output": coarse_label})``.

    Raises:
        ValueError: if ``is_training`` is truthy.

    NOTE(review): keras.applications EfficientNet models are documented to take
    raw [0, 255] pixels and normalize internally; the extra rescale + ImageNet
    normalization below is only correct if training used the identical
    pipeline — confirm before changing either side.
    """
    if is_training:
        raise ValueError("build_dataset should not be called with is_training=True during evaluation.")
    if batch_size is None:
        batch_size = BATCH_SIZE

    df = df.dropna(subset=['image_path', 'head2_idx']).copy()
    # -1 marks "no fine diagnosis available" for this row.
    fine_labels = df['head1_idx'].fillna(-1).astype('int32').values
    coarse_labels = df['head2_idx'].astype('int32').values

    def resolve_path(p):
        p = str(p)
        return p if os.path.isabs(p) else str(IMAGE_PATH / p)

    img_paths = df['image_path'].astype(str).apply(resolve_path).tolist()

    ds = tf.data.Dataset.from_tensor_slices((img_paths, fine_labels, coarse_labels))

    # Build the layers once, outside the map fn, so they are not re-created per element.
    augmenter = build_augmenter(is_training)
    rescale = layers.Rescaling(1./255)
    normalization_layer = layers.Normalization(
        mean=[0.485, 0.456, 0.406],
        variance=[0.229**2, 0.224**2, 0.225**2]
    )

    def load_and_preprocess(path, label_fine, label_coarse):
        img = tf.io.read_file(path)
        img = tf.image.decode_jpeg(img, channels=3)
        img = augmenter(img)
        img = rescale(img)
        img = normalization_layer(img)
        return img, {"fine_output": label_fine, "coarse_output": label_coarse}

    ds = ds.map(load_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

def masked_sparse_categorical_crossentropy(y_true, y_pred):
    """Mean sparse categorical cross-entropy over samples whose label is not -1.

    Args:
        y_true: integer class indices; -1 marks "no label" (as written by
            build_dataset). Accepted with or without a trailing size-1 axis.
        y_pred: unnormalized logits with classes on the last axis.

    Returns:
        Scalar tensor: sum of per-sample losses over labelled samples divided by
        the number of labelled samples. The 1e-8 keeps an all-masked batch at 0
        instead of dividing by zero.
    """
    y_true = tf.cast(y_true, tf.int32)
    # Align labels with the per-sample loss shape (drops a trailing size-1 axis)
    # so the mask multiplies element-wise instead of broadcasting to a matrix.
    y_true = tf.reshape(y_true, tf.shape(y_pred)[:-1])
    is_labelled = tf.not_equal(y_true, -1)
    mask = tf.cast(is_labelled, dtype=tf.float32)
    # -1 is not a valid class index: sparse CE raises on CPU and returns NaN on
    # GPU, and NaN * 0 is still NaN. Substitute a dummy in-range class (0) and
    # let the mask zero those rows out.
    safe_labels = tf.where(is_labelled, y_true, tf.zeros_like(y_true))
    loss = keras.losses.sparse_categorical_crossentropy(safe_labels, y_pred, from_logits=True)
    masked_loss = loss * mask
    return tf.reduce_sum(masked_loss) / (tf.reduce_sum(mask) + 1e-8)

In [54]:
PRETRAINED = False  # default backbone init: ImageNet weights when True, random init when False


def create_two_head_model(n_fine, n_coarse, img_size=IMG_SIZE, dropout=0.2, pretrained=None):
    """Creates the two-headed model using the Keras Functional API.

    The model is an EfficientNetB1 backbone without its top, followed by global
    average pooling, dropout, and two Dense heads. The heads have no activation,
    so they emit logits.

    Args:
        n_fine: number of units in the "fine_output" (diagnosis) head.
        n_coarse: number of units in the "coarse_output" (lesion type) head.
        img_size: square input resolution.
        dropout: dropout rate applied to the pooled features before both heads.
        pretrained: initialise the backbone with ImageNet weights. ``None``
            (default) falls back to the module-level ``PRETRAINED`` flag, read
            at call time.

    Returns:
        keras.Model named "EffB1TwoHead" with outputs [fine_logits, coarse_logits].
    """
    if pretrained is None:
        pretrained = PRETRAINED

    inputs = keras.Input(shape=(img_size, img_size, 3), name="input")

    backbone = keras.applications.EfficientNetB1(
        include_top=False,
        weights="imagenet" if pretrained else None,
        input_tensor=inputs
    )

    x = layers.GlobalAveragePooling2D(name="avg_pool")(backbone.output)
    x = layers.Dropout(dropout, name="top_dropout")(x)

    # Layer names must match the checkpoint and the label dict keys from build_dataset.
    output_fine = layers.Dense(n_fine, name="fine_output")(x)
    output_coarse = layers.Dense(n_coarse, name="coarse_output")(x)

    model = keras.Model(inputs=inputs, outputs=[output_fine, output_coarse], name="EffB1TwoHead")
    return model

model = create_two_head_model(N_DX_CLASSES, N_LESION_TYPE_CLASSES)
try:
    # Use the configured checkpoint path instead of re-deriving it here.
    model.load_weights(str(MODEL_PATH))
except ValueError as err:
    # Keras reports head-size disagreements (e.g. a fine_output kernel of
    # (1280, 9) vs. a saved (1280, 11)) as ValueError; the likely cause is a
    # class list that no longer matches the one used at training time.
    raise ValueError(
        f"Failed to load weights from {MODEL_PATH} into a model with "
        f"{N_DX_CLASSES} fine and {N_LESION_TYPE_CLASSES} coarse classes; "
        "check DX_CLASSES / LESION_TYPE_CLASSES against the training configuration."
    ) from err


ValueError: Shape mismatch in layer #185 (named fine_output)for weight fine_output/kernel. Weight expects shape (1280, 9). Received saved weight with shape (1280, 11)

In [None]:
df = pd.read_csv(PREPARED_CSV)

# In-distribution test split and the held-out OOD split, as independent copies.
test_df = df.loc[df.split.eq("test")].copy()
ood_df = df.loc[df.split.eq("test_ood")].copy()

test_ds, ood_ds = build_dataset(test_df), build_dataset(ood_df)

def get_predictions_and_labels(model, dataset):
    """Run ``model`` over ``dataset`` and collect labels and raw logits for both heads.

    Args:
        model: two-output model whose ``predict_on_batch`` returns
            ``(fine_logits, coarse_logits)``.
        dataset: iterable of ``(images, labels)`` batches, where ``labels`` is a
            dict with ``"fine_output"`` and ``"coarse_output"`` tensors exposing
            ``.numpy()`` (as produced by ``build_dataset``).

    Returns:
        ``(labels_h1, logits_h1, labels_h2, logits_h2)`` as NumPy arrays
        concatenated along the batch axis. Head-1 labels keep any ``-1``
        "unlabelled" marker.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    labels_h1, labels_h2 = [], []
    logits_h1, logits_h2 = [], []

    for images, labels in dataset:
        batch_logits_h1, batch_logits_h2 = model.predict_on_batch(images)
        logits_h1.append(np.asarray(batch_logits_h1))
        logits_h2.append(np.asarray(batch_logits_h2))
        labels_h1.append(labels['fine_output'].numpy())
        labels_h2.append(labels['coarse_output'].numpy())

    if not logits_h1:
        # np.concatenate([]) would fail with an opaque message; say what happened.
        raise ValueError("get_predictions_and_labels: dataset yielded no batches (empty split?)")

    return (
        np.concatenate(labels_h1, axis=0),
        np.concatenate(logits_h1, axis=0),
        np.concatenate(labels_h2, axis=0),
        np.concatenate(logits_h2, axis=0),
    )

# Labels + raw logits for both heads: in-distribution test split first, then OOD.
(id_labels_h1, id_logits_h1,
 id_labels_h2, id_logits_h2) = get_predictions_and_labels(model, test_ds)
(ood_labels_h1, ood_logits_h1,
 ood_labels_h2, ood_logits_h2) = get_predictions_and_labels(model, ood_ds)

In [None]:
def plot_confusion_matrix(labels, preds, class_names, title):
    """Plot an annotated confusion-matrix heatmap.

    Args:
        labels: true class indices.
        preds: predicted class indices.
        class_names: display names; index ``i`` names class ``i``.
        title: figure title.

    Samples whose true or predicted label is outside ``range(len(class_names))``
    (e.g. the ``-1`` "unlabelled" marker) are ignored by sklearn.
    """
    # Pin the label set to every class index. Without it sklearn only uses the
    # labels that actually occur, so an absent class shrinks the matrix and
    # shifts the tick labels onto the wrong rows/columns.
    cm = confusion_matrix(labels, preds, labels=np.arange(len(class_names)))
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set(title=title, ylabel='Actual', xlabel='Predicted')
    plt.show()

# --- Head 1: Diagnosis Evaluation ---
# build_dataset encodes a missing head1_idx as -1; those rows have no diagnosis
# ground truth, so head 1 is scored only on labelled rows.
id_preds_h1 = np.argmax(id_logits_h1, axis=1)
id_known_h1 = id_labels_h1 != -1
print(classification_report(
    id_labels_h1[id_known_h1], id_preds_h1[id_known_h1],
    labels=np.arange(N_DX_CLASSES), target_names=DX_CLASSES, zero_division=0,
))
plot_confusion_matrix(id_labels_h1[id_known_h1], id_preds_h1[id_known_h1], DX_CLASSES,
                      "Head 1 (Diagnosis) Confusion Matrix - ID")

# --- Head 2: Lesion Type Evaluation ---
# Pin labels so the report lines up with LESION_TYPE_CLASSES even if a class is absent.
id_preds_h2 = np.argmax(id_logits_h2, axis=1)
print(classification_report(
    id_labels_h2, id_preds_h2,
    labels=np.arange(N_LESION_TYPE_CLASSES), target_names=LESION_TYPE_CLASSES, zero_division=0,
))
plot_confusion_matrix(id_labels_h2, id_preds_h2, LESION_TYPE_CLASSES, "Head 2 (Lesion Type) Confusion Matrix - ID")

In [None]:
def get_msp_scores(logits):
    """Return the maximum softmax probability (MSP) for each row of ``logits``.

    Softmax is taken over axis 1 (the class axis), and the largest class
    probability per row is returned as a NumPy array.
    """
    probabilities = tf.nn.softmax(logits, axis=1).numpy()
    return probabilities.max(axis=1)

id_msp_scores = get_msp_scores(id_logits_h1)
ood_msp_scores = get_msp_scores(ood_logits_h1)
# NOTE(review): id_msp_scores covers every "test" row, including any whose
# head1_idx was missing (encoded as -1 by build_dataset) — confirm those rows
# should count as in-distribution for the AUROC below.

# --- Head 1: OOD Detection ---
fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(id_msp_scores, color='blue', label='ID (Known Diagnoses)', stat='density', bins=50, kde=True, ax=ax)
sns.histplot(ood_msp_scores, color='red', label='OOD (Unknown Diagnoses)', stat='density', bins=50, kde=True, ax=ax)
ax.set(title='Confidence Score (MSP) Distributions', xlabel='Maximum Softmax Probability')
ax.legend()
plt.show()

# ID is the positive class (1): a higher MSP should mean "more in-distribution".
labels_id = np.ones_like(id_msp_scores)
labels_ood = np.zeros_like(ood_msp_scores)
all_scores = np.concatenate([id_msp_scores, ood_msp_scores])
all_labels = np.concatenate([labels_id, labels_ood])
auroc = roc_auc_score(all_labels, all_scores)
print(f"OOD Detection AUROC (using MSP): {auroc:.4f}")

# --- Head 2: Performance on OOD data ---
# Pin labels: if the OOD split lacks a lesion type, an unpinned report would see
# fewer classes than target_names and raise.
ood_preds_h2 = np.argmax(ood_logits_h2, axis=1)
print(classification_report(
    ood_labels_h2, ood_preds_h2,
    labels=np.arange(N_LESION_TYPE_CLASSES), target_names=LESION_TYPE_CLASSES, zero_division=0,
))
plot_confusion_matrix(ood_labels_h2, ood_preds_h2, LESION_TYPE_CLASSES, "Head 2 (Lesion Type) Confusion Matrix - OOD")

In [None]:
# df = pd.read_csv(PREPARED_CSV)
# # prefer "test" if available, else fall back to "val"
# test_df = df[df.split == "test"].copy()

# test_ds = build_dataset(test_df, is_training=False)

# model = create_two_head_model(N_DX_CLASSES, N_LESION_TYPE_CLASSES)
# model.load_weights(str(OUTDIR / "best_model.h5"))

# results = evaluate_model(model, test_ds)
# print("\n== Aggregate metrics ==")
# for k, v in results.items():
#     print(f"{k}: {v:.4f}")

[1m64/64[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 901ms/step - coarse_output_acc: 0.9360 - coarse_output_loss: 843.2475 - fine_output_acc: 0.6915 - fine_output_loss: 1626.9020 - loss: 2458.8330

== Aggregate metrics ==
coarse_output_acc: 0.9360
coarse_output_loss: 843.2475
fine_output_acc: 0.6915
fine_output_loss: 1626.9020
loss: 2458.8330


In [None]:
# # --- OOD using your unknowns (fine_output == -1) in test_df ---
# test_ood_df = df[df.split == "test_ood"].copy()
# test_ood_ds = build_dataset(test_df, is_training=False)

# z_id, z_ood = collect_fine_logits_by_mask(model, test_ood_ds)

# if z_id.shape[0] == 0 or z_ood.shape[0] == 0:
#     print("\n[OOD] Need both ID (fine label != -1) and OOD (fine label == -1) samples in the test split.")
# # else:
# #     # Energy (recommended) and MSP baselines
# #     s_id_en  = energy_score(z_id, T=1.0)
# #     s_ood_en = energy_score(z_ood, T=1.0)
# #     s_id_msp  = msp_score(z_id)
# #     s_ood_msp = msp_score(z_ood)

# #     res_en  = ood_metrics(s_id_en,  s_ood_en,  tpr=0.95)
# #     res_msp = ood_metrics(s_id_msp, s_ood_msp, tpr=0.95)

# #     print("\n== OOD (Energy, T=1.0) ==")
# #     for k, v in res_en.items(): print(f"{k}: {v:.4f}")
# #     print("\n== OOD (MSP) ==")
# #     for k, v in res_msp.items(): print(f"{k}: {v:.4f}")

# #     # Optional: save the energy threshold for gating at inference
# #     # import json
# #     # OUTDIR.mkdir(parents=True, exist_ok=True)
# #     # with open(OUTDIR / "ood_thresholds.json", "w") as f:
# #     #     json.dump({"energy_T1": res_en, "msp": res_msp}, f, indent=2)
# #     # print(f"\nSaved OOD thresholds to {OUTDIR/'ood_thresholds.json'}")

# #     # Optional quick viz: score histograms
# #     try:
# #         plt.figure(figsize=(5,3))
# #         sns.kdeplot(s_id_en, label="ID (energy)")
# #         sns.kdeplot(s_ood_en, label="OOD (energy)")
# #         plt.title("Energy score distributions")
# #         plt.legend(); plt.tight_layout()
# #         plt.show()
# #     except Exception as e:
# #         print("Plot skipped:", e)



[OOD] Need both ID (fine label != -1) and OOD (fine label == -1) samples in the test split.


In [None]:
# z_id = []
# for (x, _) in test_ds:
#     z_fine, _ = model(x, training=False)
#     z_id.append(z_fine.numpy())
# z_id = np.concatenate(z_id, axis=0) if z_id else np.empty((0, N_DX_CLASSES))

# # OOD logits from the unlabeled ds_ood (built with make_image_only_ds as above)
# # z_ood = logits_over_image_ds(model, ds_ood)

KeyboardInterrupt: 

In [None]:
# print(df["head1_idx"].value_counts(dropna=False).head(10))
# # print(df.loc[df["head1_idx"].isna()]) # , ["image_path"]].head())
# # print(df.loc[df["head1_idx"].fillna(-1).eq(-1), ["split"]].value_counts())

head1_idx
NaN     26698
0.0     20468
3.0      8453
1.0      6165
2.0      4142
6.0      3082
10.0     1750
8.0       389
7.0       386
9.0       182
Name: count, dtype: int64
