# In [None]:  (Jupyter cell marker left over from the notebook export)
import os
import io
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image as PILImage
from openpyxl import Workbook
from openpyxl.utils.dataframe import dataframe_to_rows
from openpyxl.drawing.image import Image as XLImage

from sklearn.metrics import (
    confusion_matrix, classification_report, roc_curve, precision_recall_curve,
    auc, average_precision_score
)
from scipy import stats

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

# -----------------------------
# CONFIG (UPDATE THESE)
# -----------------------------
# Locations of the test images, the checkpoints to evaluate, and the report to write.
TEST_DIR = r"C:/Users/vb11574/Desktop/Salmonella_Project/Models/3_SameLearning/Ame_to_Ame/American_Dataset/test"
PTH_DIR  = r"C:/Users/vb11574/Desktop/Salmonella_Project/Models/3_SameLearning/Ame_to_Ame/Models"  # directory scanned for .pth files
OUTPUT_EXCEL = r"C:/Users/vb11574/Desktop/Salmonella_Project/Models/3_SameLearning/Ame_to_Ame/2_Ame_to_Ame_Metrics_report.xlsx"

# DataLoader / inference settings.
BATCH_SIZE = 64  # raise for faster evaluation if GPU memory allows
NUM_WORKERS = min(os.cpu_count() or 0, 8)  # at most 8; 0 when the CPU count is unknown
PERSISTENT_WORKERS = NUM_WORKERS > 0  # only enabled when there is at least one worker
PIN_MEMORY = True
MIXED_PRECISION = True

# Per the original note: the TF pipeline only rescaled by 1/255, so the
# PyTorch equivalent is ToTensor() alone, without ImageNet normalisation.
USE_IMAGENET_NORMALIZATION = False

# Channel statistics applied only when USE_IMAGENET_NORMALIZATION is True.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Display names used in the report when the test folder has exactly two
# classes; edit these if your folder names differ from the TF-style labels.
DEFAULT_TARGET_NAMES = ["healthy", "salmo"]

# -----------------------------
# Helpers: plots & CI
# -----------------------------
def save_plot_to_image(fig):
    """Render *fig* to an in-memory PNG and wrap it as an openpyxl image.

    The figure is closed after rendering, so it must not be reused by the
    caller. Nothing is written to disk; the PNG bytes live in a BytesIO
    buffer that the returned image object reads from.
    """
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format="png", dpi=160, bbox_inches="tight")
    plt.close(fig)  # release the figure now that its pixels are captured

    # Rewind so PIL reads the PNG from its first byte.
    png_buffer.seek(0)
    rendered = PILImage.open(png_buffer)
    return XLImage(rendered)

def compute_confidence_interval(data, confidence=0.95):
    """Return ``(mean, margin)`` for a two-sided Student-t interval over *data*.

    ``margin`` is the standard error of the mean multiplied by the t critical
    value for ``len(data) - 1`` degrees of freedom, so the interval is
    ``mean - margin`` .. ``mean + margin``. Both values are Python floats.
    """
    degrees_of_freedom = len(data) - 1
    upper_quantile = (1 + confidence) / 2.0
    t_critical = stats.t.ppf(upper_quantile, degrees_of_freedom)

    standard_error = stats.sem(data)
    half_width = float(standard_error * t_critical)
    centre = float(np.mean(data))
    return centre, half_width

# -----------------------------
# DataLoader builder (input-size aware)
# -----------------------------
def make_test_loader(test_dir: str, img_size: int):
    """Build the test dataset and a deterministic-order DataLoader for it.

    Args:
        test_dir: Root folder in ``ImageFolder`` layout (one sub-folder per class).
        img_size: Images are resized to ``img_size`` x ``img_size``.

    Returns:
        ``(dataset, loader)``. The loader never shuffles, so predictions come
        back in the same order as ``dataset.samples``.
    """
    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),  # scales pixels to [0, 1], like TF rescale=1/255
    ]
    if USE_IMAGENET_NORMALIZATION:
        steps.append(transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD))

    ds = datasets.ImageFolder(test_dir, transform=transforms.Compose(steps))

    loader_kwargs = {
        "batch_size": BATCH_SIZE,
        "shuffle": False,
        "num_workers": NUM_WORKERS,
        # Pinned host memory only helps host->GPU copies; requesting it on a
        # CPU-only machine just produces a warning.
        "pin_memory": PIN_MEMORY and torch.cuda.is_available(),
    }
    if NUM_WORKERS > 0:
        # These options are only meaningful with worker processes. Passing
        # them (even prefetch_factor=None) with num_workers=0 is rejected by
        # some PyTorch releases, so they are omitted in that case.
        loader_kwargs["persistent_workers"] = PERSISTENT_WORKERS
        loader_kwargs["prefetch_factor"] = 2

    loader = DataLoader(ds, **loader_kwargs)
    return ds, loader

# -----------------------------
# Model builders (must match how your .pth was trained)
# Outputs logits [B,1]
# -----------------------------
def binary_head(in_features: int):
    """Return the classifier head: Linear -> ReLU -> Dropout -> Linear(1).

    The head emits a single logit per sample. The layers sit at Sequential
    positions 0..3, which determines the parameter names (``0.weight``,
    ``3.weight``, ...) that a checkpoint must contain for this head.
    """
    hidden_units = 512
    layers = [
        nn.Linear(in_features, hidden_units),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5),
        nn.Linear(hidden_units, 1),
    ]
    return nn.Sequential(*layers)

def build_model(model_name: str):
    """Create the architecture named by *model_name* with a one-logit output.

    The name is matched case-insensitively. Networks are created without
    pretrained weights (``weights=None`` / ``pretrained=False``); the caller
    is expected to load a checkpoint afterwards, so the head built here must
    match the one used when that checkpoint was trained.

    Raises:
        ValueError: if *model_name* is not one of the supported architectures.
    """
    arch = model_name.lower()

    # Architecture families that differ only in which torchvision factory is
    # called. Factories are stored by name and resolved lazily so that only
    # the requested one is ever looked up on ``models``.
    vgg_family = {"vgg16": "vgg16", "vgg19": "vgg19"}
    densenet_family = {
        "densenet121": "densenet121",
        "densenet169": "densenet169",
        "densenet201": "densenet201",
    }
    resnet_family = {
        "resnet50": "resnet50",
        "resnet101": "resnet101",
        "resnet152": "resnet152",
    }
    efficientnet_family = {
        "efficientnetb0": "efficientnet_b0",
        "efficientnetb3": "efficientnet_b3",
        "efficientnetb7": "efficientnet_b7",
    }
    mobilenet_v3_aliases = ("mobilenetv3", "mobilenetv3large", "mobilenetv3_large")

    if arch in vgg_family:
        net = getattr(models, vgg_family[arch])(weights=None)
        net.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        net.classifier = binary_head(net.classifier[0].in_features)
        return net

    if arch == "mobilenetv2":
        net = models.mobilenet_v2(weights=None)
        width = net.classifier[1].in_features
        net.classifier = nn.Sequential(nn.Dropout(p=0.2), nn.Linear(width, 1))
        return net

    if arch in mobilenet_v3_aliases:
        net = models.mobilenet_v3_large(weights=None)
        width = net.classifier[0].in_features
        # Same shape as binary_head but with Hardswish instead of ReLU.
        net.classifier = nn.Sequential(
            nn.Linear(width, 512),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(512, 1),
        )
        return net

    if arch in densenet_family:
        net = getattr(models, densenet_family[arch])(weights=None)
        net.classifier = binary_head(net.classifier.in_features)
        return net

    if arch in resnet_family:
        net = getattr(models, resnet_family[arch])(weights=None)
        net.fc = binary_head(net.fc.in_features)
        return net

    if arch == "inceptionv3":
        # The auxiliary classifier is kept (aux_logits=True), so its final
        # layer is also replaced with a one-logit Linear.
        net = models.inception_v3(weights=None, aux_logits=True, transform_input=False)
        net.fc = binary_head(net.fc.in_features)
        if net.AuxLogits is not None:
            aux_width = net.AuxLogits.fc.in_features
            net.AuxLogits.fc = nn.Linear(aux_width, 1)
        return net

    if arch in efficientnet_family:
        net = getattr(models, efficientnet_family[arch])(weights=None)
        width = net.classifier[1].in_features
        net.classifier = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(width, 1))
        return net

    if arch == "xception":
        # Xception is not part of torchvision; it is taken from timm
        # (pip install timm), imported here so timm stays optional.
        import timm
        return timm.create_model("xception", pretrained=False, num_classes=1)

    raise ValueError(f"Unsupported model: {model_name}")

def forward_logits(model, x, model_name: str):
    """Call ``model(x)`` and return the main logits.

    For ``inceptionv3`` (matched case-insensitively) the model may return a
    tuple of ``(logits, aux_logits)``; in that case only the first element is
    returned. Every other output is passed through unchanged.
    """
    output = model(x)
    is_inception = model_name.lower() == "inceptionv3"
    if is_inception and isinstance(output, tuple):
        return output[0]
    return output

# -----------------------------
# Robust checkpoint loader
# -----------------------------
def load_pth_weights(model: nn.Module, pth_path: str):
    """Load weights from a ``.pth`` checkpoint into *model* (non-strict).

    Accepted checkpoint layouts:
      * a dict with a ``"state_dict"`` or ``"model_state_dict"`` entry,
      * a bare state dict,
      * a whole pickled ``nn.Module`` (its ``state_dict()`` is used).

    A leading ``"module."`` (as written by ``nn.DataParallel``) is removed
    from each key that carries it.

    Args:
        model: Network to load the weights into; modified in place.
        pth_path: Path to the checkpoint file.

    Returns:
        ``(missing_keys, unexpected_keys)`` as reported by
        ``load_state_dict(strict=False)``. Mismatches do not raise, so the
        caller should inspect both lists.

    Raises:
        ValueError: if the checkpoint is neither a dict nor an ``nn.Module``.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only point this at checkpoints you created or trust.
    ckpt = torch.load(pth_path, map_location="cpu")

    if isinstance(ckpt, nn.Module):
        # Saved via torch.save(model). NOTE(review): recent PyTorch releases
        # refuse to unpickle full modules by default (weights_only=True), in
        # which case torch.load above raises before this branch is reached.
        state = ckpt.state_dict()
    elif isinstance(ckpt, dict):
        state = ckpt  # may already be a bare state dict
        for wrapper_key in ("state_dict", "model_state_dict"):
            if wrapper_key in ckpt:
                state = ckpt[wrapper_key]
                break
    else:
        raise ValueError(
            f"Unsupported checkpoint type {type(ckpt).__name__!r} in {pth_path}; "
            "expected a state dict, a dict containing 'state_dict' / "
            "'model_state_dict', or a pickled nn.Module."
        )

    # Strip the DataParallel prefix only where it really is a prefix; keys
    # that merely contain "module." further along are left untouched.
    prefix = "module."
    if any(key.startswith(prefix) for key in state):
        state = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state.items()
        }

    missing, unexpected = model.load_state_dict(state, strict=False)
    return missing, unexpected

# -----------------------------
# Inference
# -----------------------------
@torch.no_grad()
def predict_proba(model, loader, device, model_name: str):
    """Return the positive-class probability for every sample in *loader*.

    The model is switched to eval mode and run batch by batch; its single
    logit per sample is passed through a sigmoid. Results are concatenated in
    loader order into a 1-D float32 NumPy array (empty if the loader yields
    no batches).

    Args:
        model: Network producing one logit per sample.
        loader: Iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: ``torch.device`` the model lives on.
        model_name: Architecture name, forwarded to ``forward_logits``.
    """
    model.eval()
    use_amp = MIXED_PRECISION and device.type == "cuda"
    batch_probs = []

    for x, _ in loader:
        x = x.to(device, non_blocking=True)
        with torch.amp.autocast("cuda", enabled=use_amp):
            logits = forward_logits(model, x, model_name)
        # Autocast can yield half-precision logits. Cast up before the
        # sigmoid so probabilities are not quantised or saturated to exactly
        # 0/1, which would create artificial ties in the ROC/PR curves.
        probs = torch.sigmoid(logits.float()).reshape(-1)
        batch_probs.append(probs.detach().cpu().numpy())

    if not batch_probs:
        # np.concatenate rejects an empty list; report "no predictions".
        return np.empty(0, dtype=np.float32)
    return np.concatenate(batch_probs, axis=0)

# -----------------------------
# MAIN: evaluate all .pth in folder
# -----------------------------
def _get_test_data(img_size, cache):
    """Return ``(dataset, loader)`` for *img_size*, building each size only once."""
    if img_size not in cache:
        cache[img_size] = make_test_loader(TEST_DIR, img_size)
    return cache[img_size]


def _line_plot_image(xs, ys, title, xlabel, ylabel):
    """Plot *ys* against *xs* and return the figure as an Excel-embeddable image."""
    fig, ax = plt.subplots()
    ax.plot(xs, ys)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    return save_plot_to_image(fig)


def _evaluate_checkpoint(wb, fname, device, loader_cache):
    """Evaluate one checkpoint file and add its report sheet to *wb*.

    All metrics and plots are computed before the sheet is created, so a
    failure part-way through does not leave a half-filled sheet behind.
    Any exception propagates to the caller.
    """
    pth_path = os.path.join(PTH_DIR, fname)

    # The architecture is taken from the part of the file name before the
    # first underscore, e.g. "VGG16_best.pth" -> "VGG16".
    model_name = os.path.splitext(fname)[0].split("_")[0]
    print(f"\n[EVAL] Evaluating {fname}  (model={model_name})")

    # InceptionV3 / Xception are evaluated at 299x299, everything else at 224x224.
    img_size = 299 if model_name.lower() in ("inceptionv3", "xception") else 224
    test_ds, test_loader = _get_test_data(img_size, loader_cache)

    class_names = test_ds.classes
    if len(class_names) != 2:
        # The models emit a single logit, and the report below is 2x2.
        raise ValueError(
            f"Expected exactly 2 class folders in {TEST_DIR}, found {len(class_names)}: {class_names}"
        )
    # NOTE(review): these display names are applied in ImageFolder's class
    # order (test_ds.classes); confirm that order matches DEFAULT_TARGET_NAMES.
    target_names = list(DEFAULT_TARGET_NAMES)

    model = build_model(model_name).to(device)
    missing, unexpected = load_pth_weights(model, pth_path)
    if missing:
        print(f"[WARN] Missing keys (first 10): {missing[:10]}")
    if unexpected:
        print(f"[WARN] Unexpected keys (first 10): {unexpected[:10]}")

    y_true = np.array([label for _, label in test_ds.samples], dtype=int)
    y_prob = predict_proba(model, test_loader, device, model_name)
    y_pred = (y_prob > 0.5).astype(int)

    # Fix the label set so the matrix is always 2x2 and the report always has
    # both class rows, even if one class never occurs in y_true / y_pred.
    labels = [0, 1]
    conf = confusion_matrix(y_true, y_pred, labels=labels)
    report = classification_report(
        y_true, y_pred, labels=labels, target_names=target_names, output_dict=True
    )
    # Computed directly: the report's "accuracy" entry is not guaranteed to
    # be present when explicit labels are supplied.
    accuracy = float(np.mean(y_pred == y_true))

    fpr, tpr, _ = roc_curve(y_true, y_prob)
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)
    pr_auc = auc(recall, precision)
    mAP = average_precision_score(y_true, y_prob)
    ci_mean, ci_margin = compute_confidence_interval(y_prob)

    metrics = {
        "Test Accuracy": accuracy,
        "ROC AUC": roc_auc,
        "PR AUC": pr_auc,
        "mAP": mAP,
        "CI Mean": ci_mean,
        "CI Margin": ci_margin,
    }
    for class_label in target_names:
        if class_label in report:
            metrics[f"F1_{class_label}"] = report[class_label]["f1-score"]

    # Render every plot before touching the workbook.
    roc_img = _line_plot_image(
        fpr, tpr, "ROC Curve", "False Positive Rate", "True Positive Rate"
    )
    pr_img = _line_plot_image(recall, precision, "PR Curve", "Recall", "Precision")
    fig, ax = plt.subplots()
    sns.heatmap(conf, annot=True, fmt="d", cmap="Blues", ax=ax)
    ax.set_title("Confusion Matrix")
    conf_img = save_plot_to_image(fig)

    # Excel limits sheet titles to 31 characters.
    ws = wb.create_sheet(model_name[:31])

    ws.append(["Checkpoint"])
    ws.append(["File", fname])
    ws.append([])

    ws.append(["Summary Metrics"])
    for key, value in metrics.items():
        ws.append([key, round(float(value), 6)])
    ws.append([])

    ws.append(["Confusion Matrix"])
    conf_df = pd.DataFrame(
        conf,
        index=[f"Actual_{name}" for name in target_names],
        columns=[f"Pred_{name}" for name in target_names],
    )
    for row in dataframe_to_rows(conf_df, index=True, header=True):
        ws.append(row)

    # NOTE(review): anchors are kept from the original layout; depending on
    # rendered image size the pictures may overlap each other or the table.
    ws.add_image(roc_img, "J2")
    ws.add_image(pr_img, "J18")
    ws.add_image(conf_img, "A18")

    print(
        f"[OK] Done: {model_name} | Acc={accuracy:.4f} | "
        f"ROC_AUC={roc_auc:.4f} | PR_AUC={pr_auc:.4f}"
    )


def main():
    """Evaluate every ``.pth`` file in PTH_DIR and save one Excel report.

    Each checkpoint gets its own sheet. A checkpoint that fails is reported
    and skipped so the remaining ones are still evaluated.

    Raises:
        RuntimeError: if PTH_DIR has no ``.pth`` files, or if every
            checkpoint failed and there is nothing to save.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        torch.backends.cudnn.benchmark = True

    pth_files = sorted(f for f in os.listdir(PTH_DIR) if f.lower().endswith(".pth"))
    if not pth_files:
        raise RuntimeError(f"No .pth files found in: {PTH_DIR}")

    wb = Workbook()
    del wb["Sheet"]  # drop the default empty sheet

    # One dataset/loader per input size, shared by all checkpoints, so the
    # folder scan and worker start-up are not repeated for every model.
    loader_cache = {}
    failed = []

    for fname in pth_files:
        try:
            _evaluate_checkpoint(wb, fname, device, loader_cache)
        except Exception as e:
            # Best-effort batch run: report the failure and move on.
            print(f"[FAIL] Failed for {fname}: {type(e).__name__}: {e}")
            failed.append(fname)

    if not wb.sheetnames:
        # Saving a workbook without sheets fails with an unhelpful error.
        raise RuntimeError(
            f"All {len(failed)} checkpoint(s) failed to evaluate; no report was written."
        )

    wb.save(OUTPUT_EXCEL)
    print(f"\n[OK] Excel report saved to: {OUTPUT_EXCEL}")
    if failed:
        print(f"[WARN] {len(failed)} checkpoint(s) failed: {failed}")


# Guarded so that DataLoader worker processes started with the "spawn"
# method (the default on Windows), which re-import this module, do not
# re-run the whole evaluation.
if __name__ == "__main__":
    main()