In [None]:
!pip install -q medmnist scikit-learn torchmetrics

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchmetrics.classification import MulticlassCalibrationError
from medmnist import PathMNIST
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, balanced_accuracy_score, mean_absolute_error, brier_score_loss
from sklearn.preprocessing import label_binarize
import seaborn as sns
from models import NCA, CNNBaseline
import torch.nn.functional as F
import numpy as np



# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
import os
import shutil

# Stage the pre-downloaded PathMNIST archives into ~/.medmnist so that the
# PathMNIST(..., download=False) loader used later can find them without a download.
# NOTE(review): ~/.medmnist is assumed to be medmnist's default data root — confirm
# for the installed medmnist version.
# The unsuffixed archive is presumably the 28x28 variant; the others carry a
# "_<size>" suffix for the larger resolutions.
drive_folder = "./data/"
cache_dir = os.path.expanduser("~/.medmnist")
os.makedirs(cache_dir, exist_ok=True)

resolutions = ["", "_64", "_128", "_224"]
for res in resolutions:
    filename = f"pathmnist{res}.npz"
    src = os.path.join(drive_folder, filename)
    dst = os.path.join(cache_dir, filename)

    if not os.path.exists(src):
        # Name the directory actually searched: the source is a local folder,
        # not necessarily a mounted Google Drive.
        print(f"File not found in {drive_folder}: {filename}")
        continue

    shutil.copyfile(src, dst)
    print(f"Copied {filename} to cache.")

In [None]:
# Rebuild both architectures and restore their trained weights.
# map_location=device remaps tensors that were saved on a GPU onto the current
# device, so the checkpoints also load when the device cell fell back to CPU.
# NOTE(review): torch.load unpickles the file; on torch >= 1.13 consider
# weights_only=True if these checkpoints could ever come from an untrusted source.
nca = NCA().to(device)
nca.load_state_dict(torch.load("./models/best_nca_pathmnist.pth", map_location=device))
nca.eval()

cnn = CNNBaseline().to(device)
cnn.load_state_dict(torch.load("./models/best_cnn_pathmnist.pth", map_location=device))
cnn.eval();  # trailing ';' suppresses the module repr in the cell output

In [None]:
def get_loader(size, batch_size=64):
    """Build a DataLoader over the PathMNIST *test* split.

    Args:
        size: image side length, passed straight through to ``PathMNIST``.
        batch_size: number of samples per batch (default 64).

    Returns:
        A ``DataLoader`` that yields the test set in a fixed order (no shuffling),
        with images converted by ``transforms.ToTensor``. Because
        ``download=False``, the archive for ``size`` is expected to be available
        locally already.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    test_set = PathMNIST(
        split="test",
        size=size,
        download=False,
        transform=to_tensor,
    )
    return DataLoader(test_set, shuffle=False, batch_size=batch_size)

In [None]:
@torch.no_grad()
def _collect_outputs(model, loader, is_NCA=False):
    """Run ``model`` over every batch of ``loader`` and gather the results on the CPU.

    Returns:
        ``(logits, labels)``. ``logits`` is a float32 tensor made by concatenating
        the per-batch model outputs. ``labels`` is a flat int64 numpy array.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    batch_logits, batch_labels = [], []
    for x, y in loader:
        x = x.to(device)
        if is_NCA:
            # The NCA is called with an extra positional flag and returns a pair;
            # only the first element (the class scores) is needed here.
            out, _ = model(x, True)
        else:
            out = model(x)
        batch_logits.append(out.float().cpu())
        # reshape(-1) rather than squeeze(): squeeze() collapses a batch of size 1
        # into a 0-d tensor, which can be neither iterated nor concatenated.
        batch_labels.append(y.reshape(-1).long().cpu())
    if not batch_logits:
        raise ValueError("loader yielded no batches; nothing to evaluate")
    return torch.cat(batch_logits), torch.cat(batch_labels).numpy()


def _uncertainty_metrics(logits, labels):
    """Return ``(brier, nll, mean_entropy, ece)`` for softmax(logits) vs. integer labels.

    Args:
        logits: float tensor of shape (N, C) holding raw (pre-softmax) scores.
        labels: int numpy array of shape (N,) with values in [0, C).
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1).numpy()  # (N, C)
    labels_t = torch.from_numpy(labels).long()

    # Brier score: per-sample squared distance to the one-hot target, averaged.
    # Indexing an identity matrix (unlike label_binarize) also yields an (N, C)
    # matrix when C == 2.
    one_hot = np.eye(num_classes)[labels]
    brier = np.mean(np.sum((probs - one_hot) ** 2, axis=1))

    # Negative log-likelihood; cross_entropy applies log-softmax to the raw logits.
    nll = F.cross_entropy(logits, labels_t, reduction="mean").item()

    # Predictive entropy per sample; the epsilon guards against log(0).
    mean_entropy = np.mean(-np.sum(probs * np.log(probs + 1e-12), axis=1))

    # Expected calibration error with 15 confidence bins and an L1 gap.
    ece_metric = MulticlassCalibrationError(num_classes=num_classes, n_bins=15, norm="l1")
    ece = ece_metric(torch.from_numpy(probs).float(), labels_t).item()

    return brier, nll, mean_entropy, ece


def _save_confusion_matrix(cm, name, size, save_dir):
    """Render ``cm`` as an annotated heatmap, write it as PNG under ``save_dir``, return the path."""
    os.makedirs(save_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_title(f"Confusion Matrix: {name} @ {size}x{size}")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fname = f"{save_dir}/cm_{name.replace(' ', '_')}_{size}x{size}.png"
    fig.savefig(fname)
    plt.close(fig)  # close this specific figure so a loop of evaluations doesn't leak memory
    return fname


@torch.no_grad()
def evaluate(model, loader, name="Model", size=28, save_dir="./results", is_NCA=False):
    """Evaluate a classifier on ``loader``, print a report, and save its confusion matrix.

    Args:
        model: network expected to already live on ``device`` (inputs are moved
            there). It is called as ``model(x)``. When ``is_NCA`` is True it is
            called as ``model(x, True)`` and must return a pair whose first
            element is the class scores.
        loader: iterable of ``(images, labels)`` batches. Labels are flattened, so
            a trailing singleton dimension is accepted — assumed to be the
            MedMNIST label layout; TODO confirm for other loaders.
        name: model name used in printed headers and the confusion-matrix filename.
        size: image resolution; used only for labelling output.
        save_dir: directory for the confusion-matrix PNG (created if missing).
        is_NCA: select the NCA calling convention described above.

    Returns:
        dict with keys ``overall_acc``, ``bal_acc``, ``mae``, ``brier``, ``nll``,
        ``entropy`` (mean predictive entropy) and ``ece``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    logits, all_labels = _collect_outputs(model, loader, is_NCA)
    all_preds = logits.argmax(dim=1).numpy()
    # Report every class the model can output, even if one is absent from both the
    # labels and the predictions; otherwise the confusion matrix silently shrinks.
    class_ids = list(range(logits.shape[1]))

    # Classification metrics
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    overall_acc = accuracy_score(all_labels, all_preds)
    bal_acc = balanced_accuracy_score(all_labels, all_preds)
    # NOTE(review): MAE over class indices is only meaningful if the labels are
    # ordinal; for nominal tissue classes treat it as a rough diagnostic at best.
    mae = mean_absolute_error(all_labels, all_preds)
    report = classification_report(all_labels, all_preds, labels=class_ids, digits=4)

    # Uncertainty quantification
    brier, nll, mean_entropy, ece = _uncertainty_metrics(logits, all_labels)

    # Logging
    print(f"\n{name} @ {size}x{size}")
    print(f"Overall Accuracy: {overall_acc:.4f}")
    print("Balanced Accuracy:", f"{bal_acc:.4f}")
    print("Mean Absolute Error (MAE):", f"{mae:.4f}")
    print("")
    print("Uncertainty Quantification:")
    print(f"Brier Score: {brier:.4f}")
    print(f"NLL (Cross-Entropy)  : {nll:.4f}")
    print(f"Mean Predictive Entropy: {mean_entropy:.4f}")
    print(f"Expected Calibration Error (ECE): {ece:.4f}")

    print(report)

    fname = _save_confusion_matrix(cm, name, size, save_dir)
    print(f"Confusion matrix saved to: {fname}")

    return {
        "overall_acc": overall_acc,
        "bal_acc": bal_acc,
        "mae": mae,
        "brier": brier,
        "nll": nll,
        "entropy": mean_entropy,
        "ece": ece,
    }

In [None]:
def plot_comparison(results, save_dir="./results"):
    """Save one line plot per metric comparing every model in ``results`` across resolutions.

    Args:
        results: mapping of model name -> ``{"size": [...], <metric>: [...], ...}``,
            where each metric list is aligned element-wise with ``"size"``.
            Models are plotted in the mapping's iteration order.
        save_dir: output directory for ``comparison_<metric>.png`` files
            (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)

    metric_keys = ["overall_acc", "bal_acc", "mae", "brier", "nll", "entropy", "ece"]
    # Readable labels; str.title() would render acronyms as "Mae", "Nll", "Ece".
    metric_labels = {
        "overall_acc": "Overall Accuracy",
        "bal_acc": "Balanced Accuracy",
        "mae": "MAE",
        "brier": "Brier Score",
        "nll": "NLL",
        "entropy": "Mean Predictive Entropy",
        "ece": "ECE",
    }
    # First model gets 'o', second 's', and so on (cycled if there are more models).
    markers = ["o", "s", "^", "D", "v", "P", "X"]
    # Clean ticks: one per resolution evaluated by any model.
    all_sizes = sorted({s for series in results.values() for s in series["size"]})

    for metric in metric_keys:
        label = metric_labels.get(metric, metric.replace('_', ' ').title())
        fig, ax = plt.subplots(figsize=(8, 5))
        for i, (model_name, series) in enumerate(results.items()):
            ax.plot(series["size"], series[metric],
                    marker=markers[i % len(markers)], label=model_name)

        ax.set_title(f"{label} vs Resolution")
        ax.set_xlabel("Image Size")
        ax.set_ylabel(label)
        ax.set_xticks(all_sizes)
        ax.legend()
        ax.grid(True)

        fname = os.path.join(save_dir, f"comparison_{metric}.png")
        fig.savefig(fname)
        plt.close(fig)
        print(f"Saved: {fname}")

In [None]:
# Models to compare, with the flag that selects evaluate()'s NCA calling convention.
models_to_eval = {"CNN": (cnn, False), "NCA": (nca, True)}
metric_names = ["overall_acc", "bal_acc", "mae", "brier", "nll", "entropy", "ece"]
RESOLUTIONS = [28, 64, 128, 224]

# results[model_name]["size"][i] is the resolution at which results[model_name][metric][i] was measured.
results = {
    model_name: {"size": [], **{metric: [] for metric in metric_names}}
    for model_name in models_to_eval
}

for size in RESOLUTIONS:
    print("\n==============================")
    print(f"Resolution: {size}x{size}")
    loader = get_loader(size)  # one loader per resolution, shared by all models

    for model_name, (eval_model, is_nca) in models_to_eval.items():
        print(f"{model_name}:")
        metrics = evaluate(eval_model, loader, name=model_name, size=size, is_NCA=is_nca)
        for key, value in metrics.items():
            results[model_name][key].append(value)
        results[model_name]["size"].append(size)

In [None]:
# Write one CNN-vs-NCA comparison figure per metric into ./results.
plot_comparison(results, save_dir="./results")