# In [None]:
import os
import io
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

from sklearn.metrics import (
    confusion_matrix, classification_report,
    roc_curve, precision_recall_curve, auc,
    average_precision_score
)

from scipy.stats import friedmanchisquare
from sklearn.utils import resample
from statsmodels.stats.contingency_tables import mcnemar

from openpyxl import Workbook
from openpyxl.drawing.image import Image as XLImage
from PIL import Image as PILImage

# ============================
# CONFIG
# ============================
# Root of the Af_to_Af experiment. Every path below hangs off it, so relocating the
# project needs only this one value (or the SALMO_PROJECT_DIR environment variable).
PROJECT_DIR = os.environ.get(
    "SALMO_PROJECT_DIR",
    r"C:/Users/vb11574/Desktop/Salmonella_Project/Models/3_SameLearning/Af_to_Af",
)
TEST_DIR = f"{PROJECT_DIR}/African_Dataset/test"       # ImageFolder root (one sub-folder per class)
PTH_DIR = f"{PROJECT_DIR}/Models"                       # directory scanned for *.pth checkpoints
OUTPUT_EXCEL = f"{PROJECT_DIR}/FINAL_SINGLE_SHEET.xlsx"  # single-sheet report written at the end

BATCH_SIZE = 64
NUM_WORKERS = 8  # DataLoader worker processes
# Expected label order from ImageFolder (sorted folder names); metrics treat index 1 as positive.
CLASS_NAMES = ["healthy", "salmo"]

# Run on CUDA when available, otherwise on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ============================
# HELPERS
# ============================
def fig_to_xl(fig):
    """Rasterise a Matplotlib figure in memory and wrap it as an openpyxl image.

    The figure is saved at 150 dpi with a tight bounding box (format is Matplotlib's
    ``savefig`` default) and then closed, so figures do not pile up across models.

    Returns:
        An ``openpyxl.drawing.image.Image`` built from the rendered bitmap.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, dpi=150, bbox_inches="tight")
    png_buffer.seek(0)
    plt.close(fig)
    # NOTE(review): the buffer is deliberately not closed here — presumably openpyxl
    # reads the PIL image's data again when the workbook is saved; confirm before changing.
    bitmap = PILImage.open(png_buffer)
    return XLImage(bitmap)

def bootstrap_metrics(y_true, y_prob, n_boot=1000, seed=42, threshold=0.5):
    """Bootstrap the sampling distribution of binary-classification metrics.

    Draws ``n_boot`` resamples (with replacement) of the paired ``(y_true, y_prob)``
    arrays and, for each one, computes accuracy, ROC AUC, PR AUC (trapezoidal area
    under the precision-recall curve) and macro-averaged F1.

    Resamples that contain only one class are skipped for every metric: ROC and PR
    curves are undefined for them and would otherwise inject NaN into the mean and CI.
    Indices are still drawn on every iteration, so the random stream (and therefore
    the result) matches the unskipped computation whenever no resample is degenerate.

    Args:
        y_true: 1-D array-like of 0/1 ground-truth labels.
        y_prob: 1-D array-like of predicted probabilities for class 1, aligned with y_true.
        n_boot: number of bootstrap resamples.
        seed: seed for a private ``RandomState`` (the global NumPy RNG is untouched).
        threshold: probabilities strictly greater than this are predicted as class 1.

    Returns:
        dict with keys ``"acc"``, ``"roc"``, ``"pr"`` and ``"f1"``. Each value is a
        ``(mean, std, 2.5th percentile, 97.5th percentile)`` tuple; all four entries
        are NaN when there is no usable resample (empty input or every resample
        single-class).

    Raises:
        ValueError: if ``y_true`` and ``y_prob`` have different shapes.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    if y_true.shape != y_prob.shape:
        raise ValueError(
            f"y_true and y_prob must have the same shape, got {y_true.shape} and {y_prob.shape}"
        )

    rng = np.random.RandomState(seed)
    n = len(y_true)
    accs, aucs, pr_aucs, f1s = [], [], [], []

    # randint(0, 0, 0) raises, so an empty input simply yields no resamples.
    for _ in range(n_boot if n else 0):
        idx = rng.randint(0, n, n)
        yt = y_true[idx]
        if np.unique(yt).size < 2:
            continue  # single-class resample: ROC/PR undefined
        yp = y_prob[idx]
        yhat = (yp > threshold).astype(int)

        accs.append((yhat == yt).mean())

        fpr, tpr, _ = roc_curve(yt, yp)
        aucs.append(auc(fpr, tpr))

        p, r, _ = precision_recall_curve(yt, yp)
        pr_aucs.append(auc(r, p))

        rep = classification_report(yt, yhat, output_dict=True, zero_division=0)
        f1s.append(rep["macro avg"]["f1-score"])

    def _summarize(values):
        """Return (mean, std, 2.5th pct, 97.5th pct), or all-NaN for an empty list."""
        if not values:
            return (float("nan"),) * 4
        arr = np.asarray(values, dtype=float)
        return arr.mean(), arr.std(), np.percentile(arr, 2.5), np.percentile(arr, 97.5)

    return {
        "acc": _summarize(accs),
        "roc": _summarize(aucs),
        "pr": _summarize(pr_aucs),
        "f1": _summarize(f1s),
    }

def mcnemar_p(y_true, y_pred, y_ref=None, seed=42):
    """Exact McNemar p-value comparing ``y_pred`` against a reference classifier.

    Builds the 2x2 table of per-sample correctness (rows: ``y_pred`` correct / wrong;
    columns: reference correct / wrong) and passes it to statsmodels' ``mcnemar`` with
    ``exact=True``.

    Args:
        y_true: 1-D array-like of ground-truth labels.
        y_pred: 1-D array-like of predictions from the model under test.
        y_ref: optional predictions of a second classifier on the same samples. When
            omitted, a uniformly random 0/1 baseline is drawn (the original
            "model vs. coin flip" comparison).
        seed: seed for that random baseline. A private ``RandomState`` is used so the
            p-value is reproducible run-to-run and the global NumPy RNG is not consumed
            (previously the unseeded global RNG made the reported p-value change on
            every run).

    Returns:
        The McNemar test p-value.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_ref is None:
        y_ref = np.random.RandomState(seed).randint(0, 2, size=len(y_true))
    else:
        y_ref = np.asarray(y_ref)

    pred_ok = y_pred == y_true
    ref_ok = y_ref == y_true
    table = [
        [np.sum(pred_ok & ref_ok), np.sum(pred_ok & ~ref_ok)],
        [np.sum(~pred_ok & ref_ok), np.sum(~pred_ok & ~ref_ok)],
    ]
    return mcnemar(table, exact=True).pvalue

# ============================
# MODEL FACTORY
# ============================
def binary_head(in_f):
    """Build the single-logit head that replaces the torchvision classifiers.

    Layout: Linear(in_f -> 512) -> ReLU -> Dropout(0.5) -> Linear(512 -> 1).
    The layer order fixes the state_dict indices (weights live at ``0.*`` and ``3.*``),
    so it must stay identical to the head the checkpoints were saved with.

    Args:
        in_f: number of input features coming from the backbone.
    """
    hidden_units = 512
    layers = [
        nn.Linear(in_f, hidden_units),
        nn.ReLU(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(hidden_units, 1),
    ]
    return nn.Sequential(*layers)

def build_model(name):
    """Instantiate an untrained backbone whose output is a single logit.

    torchvision backbones get their classifier swapped for ``binary_head``; Xception
    comes from timm with its own ``num_classes=1`` classifier (not ``binary_head``).
    All weights are randomly initialised (``weights=None`` / ``pretrained=False``).

    Args:
        name: case-insensitive architecture id: ``vgg16``, ``vgg19``, ``densenet169``,
            ``resnet50``, ``inceptionv3`` or ``xception``.

    Raises:
        ValueError: if ``name`` is not a supported architecture.
    """
    def _vgg(constructor):
        net = constructor(weights=None)
        net.classifier = binary_head(net.classifier[0].in_features)
        return net

    def _densenet169():
        net = models.densenet169(weights=None)
        net.classifier = binary_head(net.classifier.in_features)
        return net

    def _replace_fc(net):
        net.fc = binary_head(net.fc.in_features)
        return net

    def _xception():
        import timm  # optional dependency, only needed for this architecture
        return timm.create_model("xception", pretrained=False, num_classes=1)

    builders = {
        "vgg16": lambda: _vgg(models.vgg16),
        "vgg19": lambda: _vgg(models.vgg19),
        "densenet169": _densenet169,
        "resnet50": lambda: _replace_fc(models.resnet50(weights=None)),
        "inceptionv3": lambda: _replace_fc(models.inception_v3(weights=None, aux_logits=False)),
        "xception": _xception,
    }

    builder = builders.get(name.lower())
    if builder is None:
        raise ValueError(f"Unsupported model: {name}")
    return builder()

# ============================
# DATA
# ============================
# NOTE(review): no transforms.Normalize is applied, so the network sees pixels in [0, 1].
# This must match the preprocessing used at training time — confirm against the
# training script before trusting the numbers.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder(TEST_DIR, transform=transform)

# ImageFolder assigns labels from the sorted sub-folder names, and every metric in this
# script treats label 1 as the positive class. Flag a mismatch with CLASS_NAMES instead
# of silently reporting metrics for the wrong positive class.
if [c.lower() for c in dataset.classes] != [c.lower() for c in CLASS_NAMES]:
    print(
        f"WARNING: class folders {dataset.classes} do not match CLASS_NAMES {CLASS_NAMES}; "
        "check that label 1 really is the positive class."
    )

# shuffle=False keeps loader order aligned with dataset.targets, which y_true relies on.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
y_true = np.asarray(dataset.targets)

# ============================
# EXCEL INIT
# ============================
wb = Workbook()
ws = wb.active
ws.title = "All_Models_Stats"

# One row per evaluated checkpoint is appended under these headers by the main loop.
headers = [
    "Model",
    "Acc_mean", "Acc_std", "Acc_CI_low", "Acc_CI_high",
    "ROC_AUC_mean", "ROC_AUC_std",
    "PR_AUC_mean", "PR_AUC_std",
    "mAP",
    "F1_macro_mean", "F1_CI_low", "F1_CI_high",
    "McNemar_p"
]
ws.append(headers)

# The stats table fills row 1 (headers) plus one row per *.pth checkpoint, which is the
# same filter the main loop applies. Start the plots one blank row below the table; a
# fixed row 3 made the images cover the stats of the 2nd and later models.
n_checkpoints = sum(1 for f in os.listdir(PTH_DIR) if f.endswith(".pth"))
plot_row = n_checkpoints + 3

# ============================
# MAIN LOOP
# ============================
def _load_weights(model, ckpt_path):
    """Load checkpoint weights into ``model`` and report keys that did not line up.

    Accepts a bare state_dict, a dict wrapping one under ``"state_dict"`` or
    ``"model_state_dict"``, or a pickled ``nn.Module``. A leading ``"module."`` prefix
    is stripped from each key (only as a prefix, never mid-key).
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if isinstance(ckpt, nn.Module):
        state = ckpt.state_dict()
    elif "state_dict" in ckpt:
        state = ckpt["state_dict"]
    elif "model_state_dict" in ckpt:
        state = ckpt["model_state_dict"]
    else:
        state = ckpt

    prefix = "module."
    state = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state.items()}

    # strict=False tolerates extra keys, but missing keys mean those layers keep their
    # random initialisation — surface both instead of silently evaluating garbage.
    result = model.load_state_dict(state, strict=False)
    ckpt_file = os.path.basename(ckpt_path)
    if result.missing_keys:
        print(
            f"WARNING: {len(result.missing_keys)} parameter(s) missing from {ckpt_file} "
            f"(left randomly initialised), e.g. {result.missing_keys[:5]}"
        )
    if result.unexpected_keys:
        print(
            f"WARNING: {len(result.unexpected_keys)} unused key(s) in {ckpt_file}, "
            f"e.g. {result.unexpected_keys[:5]}"
        )


def _predict_probs(model, data_loader):
    """Return sigmoid probabilities for every sample, in ``data_loader`` order."""
    model.eval()
    chunks = []
    with torch.no_grad():
        for x, _ in data_loader:
            logits = model(x.to(DEVICE))
            chunks.append(torch.sigmoid(logits).cpu().numpy().ravel())
    return np.concatenate(chunks)


def _confusion_figure(conf, model_name):
    """Annotated confusion-matrix heatmap (rows: true label, columns: predicted)."""
    fig, ax = plt.subplots()
    sns.heatmap(conf, annot=True, fmt="d", cmap="Blues",
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES, ax=ax)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title(f"{model_name} confusion matrix")
    return fig


def _curve_figure(x, y, title, xlabel, ylabel, diagonal=False):
    """Line plot of one curve; ``diagonal`` adds the chance line used for ROC plots."""
    fig, ax = plt.subplots()
    ax.plot(x, y)
    if diagonal:
        ax.plot([0, 1], [0, 1], linestyle="--", color="grey")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    return fig


for fname in sorted(os.listdir(PTH_DIR)):
    if not fname.endswith(".pth"):
        continue

    model_name = fname.split("_")[0]
    print(f"Evaluating {model_name}")

    model = build_model(model_name).to(DEVICE)
    _load_weights(model, os.path.join(PTH_DIR, fname))
    y_prob = _predict_probs(model, loader)

    # Free the GPU before the next architecture is built.
    del model
    if DEVICE.type == "cuda":
        torch.cuda.empty_cache()

    y_pred = (y_prob > 0.5).astype(int)

    # labels=[0, 1] keeps the matrix 2x2 even if a class is absent from both arrays.
    conf = confusion_matrix(y_true, y_pred, labels=[0, 1])
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    prec, rec, _ = precision_recall_curve(y_true, y_prob)

    roc_auc = auc(fpr, tpr)
    pr_auc = auc(rec, prec)
    mAP = average_precision_score(y_true, y_prob)

    boot = bootstrap_metrics(y_true, y_prob)
    p_mcn = mcnemar_p(y_true, y_pred)

    ws.append([
        model_name,
        boot["acc"][0], boot["acc"][1], boot["acc"][2], boot["acc"][3],
        boot["roc"][0], boot["roc"][1],
        boot["pr"][0], boot["pr"][1],
        mAP,
        boot["f1"][0], boot["f1"][2], boot["f1"][3],
        p_mcn
    ])

    # ===== PLOTS ===== (point-estimate AUCs shown in the titles)
    ws.add_image(fig_to_xl(_confusion_figure(conf, model_name)), f"A{plot_row}")
    ws.add_image(
        fig_to_xl(_curve_figure(fpr, tpr, f"{model_name} ROC (AUC = {roc_auc:.3f})",
                                "False positive rate", "True positive rate", diagonal=True)),
        f"I{plot_row}",
    )
    ws.add_image(
        fig_to_xl(_curve_figure(rec, prec, f"{model_name} PR (AUC = {pr_auc:.3f})",
                                "Recall", "Precision")),
        f"Q{plot_row}",
    )

    plot_row += 20  # vertical spacing between models' plot rows

# Create the destination folder if needed so a fresh setup does not fail at the very end.
os.makedirs(os.path.dirname(OUTPUT_EXCEL) or ".", exist_ok=True)
wb.save(OUTPUT_EXCEL)
print(f"✅ SINGLE-SHEET EXCEL REPORT GENERATED: {OUTPUT_EXCEL}")
