# PDRQ_Reliability_Enhancement
Calibration: ECE, Brier score, and reliability diagrams for every model.
Reporting: results CSV, metric bar plots, cost-performance chart, λ-sensitivity sweep, confusion matrices, per-class reports, and a side-by-side reliability grid.

In [None]:
import os 
import math 
import random 
import time 
import warnings 
from typing import Dict, List, Tuple 
 
import numpy as np 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim 
from PIL import Image 
from sklearn.metrics import ( 
    accuracy_score, precision_score, recall_score, f1_score, 
    confusion_matrix, classification_report 
) 
from torch.utils.data import DataLoader, Dataset 
from torchvision import models, transforms 
import matplotlib.pyplot as plt 
from tqdm import tqdm 
import pandas as pd 
 
# Silence every Python warning for the rest of the notebook.
# NOTE(review): this also hides deprecation warnings, e.g. from the `pretrained=True` fallback below.
warnings.filterwarnings("ignore")
 
# ------------------------------- 
# 1. Config 
# ------------------------------- 
class Config:
    """Global experiment configuration, read as class attributes (never instantiated).

    Paths, image geometry, optimisation hyper-parameters, triplet/PPDRQ settings,
    uncertainty-estimation settings and regularisation toggles live here so every
    section of the notebook reads the same values.
    """
    # Dataset root; expected layout: <data_dir>/{train,validation,test,unseen}/{AD,MCI,NC}
    data_dir = 'dataset'
    # Directory that receives checkpoints, plots and CSV reports
    output_dir = 'Save results here'

    image_size = 299               # square input resolution fed to the InceptionV3 backbone
    num_classes = 3                # AD / MCI / NC

    batch_size = 16                # slightly smaller to stabilize BN
    epochs = 60                    # upper bound; early stopping may end training sooner
    learning_rate = 3e-4
    weight_decay = 5e-4            # stronger regularization
    patience_lr_scheduler = 3      # ReduceLROnPlateau patience (epochs)
    patience_early_stopping = 10   # epochs without val-loss improvement before stopping
    grad_clip = 1.0                # max global grad norm; None disables clipping

    # Triplet & PPDRQ
    lambda_triplet = 0.1           # weight of the triplet term in CombinedLoss
    triplet_margin = 0.2
    epsilon_p = 0.1                # positives: same label and |PPDRQ gap| < epsilon_p
    epsilon_n = 0.2                # negatives: other label and |PPDRQ gap| >= epsilon_n

    # MC-Dropout & Ensemble
    mc_dropout_passes = 30         # stochastic forward passes per batch
    num_ensemble_models = 5

    # Regularization toggles
    # NOTE: when True, every training batch is mixed, and train_model then uses
    # mix_criterion instead of CombinedLoss (no triplet / PPDRQ weighting).
    use_mixup_cutmix = True
    mixup_alpha = 0.2              # Beta(alpha, alpha) for the MixUp coefficient
    cutmix_alpha = 0.2             # Beta(alpha, alpha) for the CutMix box area
    label_smoothing = 0.05

    # Data safety
    allow_make_dummy = False       # IMPORTANT: keep False for real training

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and pin cuDNN to
    deterministic, non-benchmarking kernels so runs are repeatable."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
 
 
# Import-time setup: ensure the results directory exists, report the device, seed all RNGs.
os.makedirs(Config.output_dir, exist_ok=True)
print(f"Using device: {Config.device}")
set_seed(42)
 
# ------------------------------- 
# 2. Dataset 
# ------------------------------- 
class MRIDataset(Dataset):
    """Image-folder dataset for one split: ``<data_dir>/<phase>/{AD,MCI,NC}/*.{png,jpg,jpeg}``.

    Items are ``(image, label, index)``, where ``index`` is the position in
    ``image_paths`` so callers can trace predictions back to files. File names are
    sorted so sample order does not depend on the filesystem. Missing or empty class
    folders are skipped with a warning, or filled with random grey placeholder images
    when ``Config.allow_make_dummy`` is set (demo only).
    """

    _IMG_EXTS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir: str, phase: str, transform=None):
        self.data_dir = os.path.join(data_dir, phase)
        self.transform = transform
        self.class_to_idx = {'AD': 0, 'MCI': 1, 'NC': 2}
        self.image_paths: List[str] = []
        self.labels: List[int] = []

        for cname, idx in self.class_to_idx.items():
            cdir = os.path.join(self.data_dir, cname)
            if not os.path.exists(cdir):
                if not Config.allow_make_dummy:
                    print(f"[WARN] Missing class dir: {cdir}. Skipping.")
                    continue
                print(f"[WARN] Missing {cdir}. Creating dummy images for demo only.")
                os.makedirs(cdir, exist_ok=True)
                self._make_dummy_images(cdir, phase)

            files = self._list_images(cdir)
            if not files and Config.allow_make_dummy:
                print(f"[WARN] Empty {cdir}. Creating dummy images for demo only.")
                self._make_dummy_images(cdir, phase)
                files = self._list_images(cdir)

            for name in files:
                self.image_paths.append(os.path.join(cdir, name))
                self.labels.append(idx)

    @classmethod
    def _list_images(cls, cdir: str) -> List[str]:
        """Return the image file names in ``cdir``, sorted for a filesystem-independent order."""
        return sorted(f for f in os.listdir(cdir) if f.lower().endswith(cls._IMG_EXTS))

    @staticmethod
    def _make_dummy_images(cdir: str, phase: str) -> None:
        """Write uniform-grey placeholder PNGs into ``cdir`` (64 for 'train', 8 otherwise)."""
        n = 8 if phase != 'train' else 64
        for i in range(n):
            img = Image.new('L', (Config.image_size, Config.image_size), color=random.randint(0, 255))
            img.convert('RGB').save(os.path.join(cdir, f'dummy_{i}.png'))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        p = self.image_paths[idx]
        y = self.labels[idx]
        # The context manager closes the file handle; convert() returns an in-memory copy.
        with Image.open(p) as im:
            img = im.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, y, idx
 
# Medically sensible augs (avoid strong color jitter for brain MRIs)
# Training pipeline order matters: PIL-level geometric/blur augmentations first, then
# ToTensor + ImageNet mean/std normalisation (the backbone loads ImageNet weights), and
# RandomErasing last because it operates on the normalised tensor.
train_tfms = transforms.Compose([
    transforms.Resize(int(Config.image_size*1.1)),  # shorter side -> ~1.1x image_size, leaving room to crop
    transforms.RandomResizedCrop(Config.image_size, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.RandomRotation(10),  # up to +/-10 degrees
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.GaussianBlur(kernel_size=3),  # NOTE(review): no probability gate, so every training image is blurred
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.1), ratio=(0.3, 3.3), value='random')  # erased patch filled with random values
])

# Evaluation pipeline: deterministic square resize plus the same normalisation, no augmentation.
val_test_tfms = transforms.Compose([
    transforms.Resize((Config.image_size, Config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
 
# One dataset per split, all rooted at Config.data_dir.
train_dataset = MRIDataset(Config.data_dir, 'train', train_tfms)
val_dataset   = MRIDataset(Config.data_dir, 'validation', val_test_tfms)
test_dataset  = MRIDataset(Config.data_dir, 'test', val_test_tfms)
unseen_dataset= MRIDataset(Config.data_dir, 'unseen', val_test_tfms)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}, Unseen: {len(unseen_dataset)}")
if len(train_dataset) == 0:
    # A shuffled DataLoader rejects an empty dataset with an opaque sampler error, so stop
    # here with an actionable message naming the directory that is actually read.
    raise RuntimeError(
        "[ERROR] Train dataset is empty. Please prepare data under "
        f"{os.path.join(Config.data_dir, 'train')}/{{AD,MCI,NC}}."
    )
if len(train_dataset) < Config.batch_size:
    print(f"[WARN] Train set has {len(train_dataset)} images < batch_size={Config.batch_size}; "
          "with drop_last=True no training batches will be produced.")

# Dataloaders (num_workers=0 for portability). drop_last on train skips an incomplete
# final batch (presumably to keep tiny batches away from BatchNorm).
train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=0, drop_last=True)
val_loader   = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=0)
unseen_loader= DataLoader(unseen_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=0)
 
# ------------------------------- 
# 3. Model 
# ------------------------------- 
class CustomInceptionV3(nn.Module):
    """ImageNet-pretrained InceptionV3 backbone with a BN + dropout classification head
    and a 1024 -> 512 embedding branch (used for triplet mining).

    ``forward`` returns ``(logits, embedding, layer_dict)``, where ``layer_dict`` exposes
    intermediate activations for layer-wise PPDRQ analysis.

    Args:
        num_classes: number of output classes.
        dropout_rate: dropout probability applied after BatchNorm on the pooled features.
        include_aux_logits: keep torchvision's auxiliary classifier. Its output is never
            used here; pass False to remove it (skips its train-mode compute). The default
            True keeps the original module structure and checkpoint keys.
        train_last_blocks_only: freeze the backbone except the ``Mixed_7*`` blocks.
    """

    def __init__(self, num_classes=3, dropout_rate=0.6, include_aux_logits=True, train_last_blocks_only=True):
        super().__init__()
        import torchvision
        try:
            # Torchvision requires aux_logits=True when using pretrained weights
            self.inception_base = torchvision.models.inception_v3(
                weights=torchvision.models.Inception_V3_Weights.IMAGENET1K_V1,
                aux_logits=True
            )
        except (AttributeError, TypeError):
            # Older torchvision without the ``weights=`` API.
            self.inception_base = models.inception_v3(pretrained=True, aux_logits=True)
        if not include_aux_logits:
            # Weights had to be loaded with the aux branch; drop it now so train-mode
            # forwards return a plain tensor and skip the auxiliary head entirely.
            self.inception_base.aux_logits = False
            self.inception_base.AuxLogits = None
        # Expose 2048-d pooled features instead of ImageNet logits
        self.inception_base.fc = nn.Identity()

        if train_last_blocks_only:
            for p in self.inception_base.parameters():
                p.requires_grad = False
            for name, m in self.inception_base.named_modules():
                if name.startswith('Mixed_7'):
                    for p in m.parameters():
                        p.requires_grad = True

        self.bn = nn.BatchNorm1d(2048)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.hidden1 = nn.Linear(2048, 1024)
        self.hidden2 = nn.Linear(1024, 512)
        self.fc_head = nn.Linear(2048, num_classes)

        # Forward hook caches the Mixed_7c output of the most recent forward pass.
        self.mixed_7c_output = None
        def hook_fn(_m, _in, out):
            self.mixed_7c_output = out
        self.inception_base.Mixed_7c.register_forward_hook(hook_fn)

    def _extract_features(self, x):
        """Run the backbone and return its main (non-auxiliary) output."""
        out = self.inception_base(x)
        # In train mode with aux_logits=True, inception_v3 returns InceptionOutputs
        if hasattr(out, 'logits'):
            return out.logits
        if isinstance(out, (tuple, list)) and len(out) > 0:
            return out[0]
        return out

    def forward(self, x):
        flatten = self._extract_features(x)
        # Apply BatchNorm exactly once: a second self.bn() call in train mode would update
        # the running statistics twice per step.
        z_bn = self.bn(flatten)
        z = self.dropout(z_bn)
        logits = self.fc_head(z)
        e = F.relu(self.hidden1(z))
        e = F.relu(self.hidden2(e))
        layer_dict = {
            'base_model_output': self.mixed_7c_output,
            # Dropout-free view of the first hidden layer, for layer-wise analysis.
            'hidden_layer1': F.relu(self.hidden1(z_bn)),
            'hidden_layer2': e,
            'flatten': flatten,
            'final_logits': logits
        }
        return logits, e, layer_dict
 
# ------------------------------- 
# 4. PPDRQ & Losses 
# ------------------------------- 
 
def compute_ppdrq_from_logits(logits: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Per-sample PPDRQ: the normalised sum of pairwise gaps between class probabilities.

    ``PPDRQ = sum_{i<j} |p_i - p_j| / (K - 1)`` with ``p = softmax(logits, dim=1)`` and
    ``K = num_classes`` (only the first K columns are used). It equals 1 for a one-hot
    prediction and 0 for a uniform one, and is clamped to [0, 1].

    Args:
        logits: ``(B, C)`` unnormalised scores, ``C >= num_classes``.
        num_classes: K; must be at least 2.

    Returns:
        ``(B,)`` tensor of PPDRQ values.

    Raises:
        ValueError: if ``num_classes < 2`` (the normaliser ``K - 1`` would be zero).
    """
    if num_classes < 2:
        raise ValueError(f"PPDRQ needs at least 2 classes, got num_classes={num_classes}")
    probs = torch.softmax(logits, dim=1)[:, :num_classes]
    # All |p_i - p_j| at once, shape (B, K, K); the full matrix counts each unordered pair twice.
    pairwise = (probs.unsqueeze(2) - probs.unsqueeze(1)).abs()
    sumdiff = pairwise.sum(dim=(1, 2)) / 2
    pp = sumdiff / (num_classes - 1)
    return torch.clamp(pp, 0, 1)
 
class CombinedLoss(nn.Module):
    """PPDRQ-weighted, label-smoothed cross-entropy plus an optional PPDRQ-aware triplet term.

    Each sample's CE is weighted by ``2 - PPDRQ`` (in [1, 2]): less decisive predictions
    count up to twice as much. For the triplet term every anchor in the batch draws one
    positive (same label, PPDRQ within ``epsilon_p``) and one negative (different label,
    PPDRQ differing by at least ``epsilon_n``) from a memory bank of embeddings.
    """

    def __init__(self, num_classes, lambda_triplet=Config.lambda_triplet, margin=Config.triplet_margin,
                 epsilon_p=Config.epsilon_p, epsilon_n=Config.epsilon_n, smoothing=Config.label_smoothing):
        super().__init__()
        self.num_classes = num_classes
        self.lambda_triplet = lambda_triplet
        self.epsilon_p = epsilon_p
        self.epsilon_n = epsilon_n
        self.triplet = nn.TripletMarginLoss(margin=margin, p=2)
        self.smoothing = smoothing

    def forward(self, logits, embeddings, labels,
                all_feats=None, all_labels=None, all_pp=None):
        """Compute the combined loss.

        Args:
            logits: ``(B, K)`` class scores.
            embeddings: ``(B, D)`` anchor embeddings.
            labels: ``(B,)`` integer class labels.
            all_feats, all_labels, all_pp: optional lists of tensors forming the triplet
                memory bank (each concatenated along dim 0). When omitted the triplet term is 0.

        Returns:
            ``(total, ce_loss, triplet_loss, pp)``; ``pp`` is the per-sample PPDRQ of ``logits``.
        """
        pp = compute_ppdrq_from_logits(logits, self.num_classes)
        per_sample_ce = F.cross_entropy(logits, labels, reduction='none', label_smoothing=self.smoothing)
        # Weight 2 for a uniform prediction (pp=0), 1 for a one-hot one (pp=1).
        # NOTE(review): pp is not detached, so gradient also flows through the weight.
        weights = 1 + (1 - pp)
        ce_loss = torch.mean(per_sample_ce * weights)

        triplet_loss = torch.tensor(0.0, device=logits.device)
        if all_feats is not None and len(all_feats) > 0:
            allF = torch.cat(all_feats, dim=0).to(embeddings.device)
            allY = torch.cat(all_labels, dim=0).to(embeddings.device)
            allP = torch.cat(all_pp, dim=0).to(embeddings.device)
            pp_mine = pp.detach()  # mining only needs the values, not their gradients
            count = 0
            for i in range(labels.size(0)):
                a = embeddings[i]
                y = labels[i]
                gap = (allP - pp_mine[i]).abs()
                # Use the instance thresholds so the constructor arguments take effect.
                pos_idx = torch.where((allY == y) & (gap < self.epsilon_p))[0]
                neg_idx = torch.where((allY != y) & (gap >= self.epsilon_n))[0]
                # NOTE(review): requiring >1 positive presumably avoids the anchor's own bank entry being the only one.
                if pos_idx.numel() > 1 and neg_idx.numel() > 0:
                    p = allF[random.choice(pos_idx.tolist())]
                    n = allF[random.choice(neg_idx.tolist())]
                    triplet_loss = triplet_loss + self.triplet(a.unsqueeze(0), p.unsqueeze(0), n.unsqueeze(0))
                    count += 1
            if count > 0:
                triplet_loss = triplet_loss / count
        total = ce_loss + self.lambda_triplet * triplet_loss
        return total, ce_loss, triplet_loss, pp
 
# ------------------------------- 
# 5. Mixup/CutMix utils 
# ------------------------------- 
 
def rand_bbox(W, H, lam):
    """Sample a CutMix box covering roughly ``1 - lam`` of a ``W x H`` image.

    The box is centred on a uniformly random pixel (x drawn before y from ``np.random``)
    and clipped to the image, so its true area can be smaller near the borders.

    Returns:
        ``(x1, y1, x2, y2)`` with ``0 <= x1 <= x2 <= W`` and ``0 <= y1 <= y2 <= H``.
    """
    side_ratio = math.sqrt(1. - lam)
    box_w, box_h = int(W * side_ratio), int(H * side_ratio)
    centre_x, centre_y = np.random.randint(W), np.random.randint(H)
    half_w, half_h = box_w // 2, box_h // 2
    left, right = (np.clip(centre_x + d, 0, W) for d in (-half_w, half_w))
    top, bottom = (np.clip(centre_y + d, 0, H) for d in (-half_h, half_h))
    return left, top, right, bottom
 
 
def apply_mixup_cutmix(x, y):
    """Apply MixUp or CutMix (chosen 50/50) to a batch.

    Returns ``(x, y, None)`` unchanged when ``Config.use_mixup_cutmix`` is False.
    Otherwise returns ``(mixed_x, (y_a, y_b, lam), mode)`` with ``mode`` in
    {'mixup', 'cutmix'}; the matching loss is ``lam * CE(y_a) + (1 - lam) * CE(y_b)``.
    For CutMix, ``lam`` is recomputed as the exact fraction of pixels outside the
    (clipped) pasted box.
    """
    if not Config.use_mixup_cutmix:
        return x, y, None

    use_mixup = random.random() < 0.5
    alpha = Config.mixup_alpha if use_mixup else Config.cutmix_alpha
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0)).to(x.device)
    partner_labels = y[perm]

    if use_mixup:
        blended = lam * x + (1 - lam) * x[perm]
        return blended, (y, partner_labels, lam), 'mixup'

    height, width = x.size(2), x.size(3)
    left, top, right, bottom = rand_bbox(width, height, lam)
    patched = x.clone()
    patched[:, :, top:bottom, left:right] = x[perm, :, top:bottom, left:right]
    kept = 1 - ((right - left) * (bottom - top) / (x.size(-1) * x.size(-2)))
    return patched, (y, partner_labels, kept), 'cutmix'
 
 
def mix_criterion(logits, target_tuple):
    """Label-smoothed cross-entropy interpolated between two targets.

    ``target_tuple`` is ``(y_a, y_b, lam)`` as produced by ``apply_mixup_cutmix``;
    returns ``lam * CE(logits, y_a) + (1 - lam) * CE(logits, y_b)``.
    """
    y_a, y_b, lam = target_tuple
    smoothing = Config.label_smoothing
    ce_a, ce_b = (F.cross_entropy(logits, target, label_smoothing=smoothing) for target in (y_a, y_b))
    return lam * ce_a + (1 - lam) * ce_b
 
# ------------------------------- 
# 6. Train/Eval 
# ------------------------------- 
 
def _collect_triplet_bank(model, loader):
    """Embed every training batch once (no grad, eval mode) to build the triplet memory bank.

    Eval mode keeps this extra pass from updating BatchNorm running statistics and yields
    dropout-free embeddings; the model's previous train/eval mode is restored afterwards.

    Returns:
        ``([features], [labels], [ppdrq])`` as single-element lists of CPU tensors.
    """
    was_training = model.training
    model.eval()
    feats, labels, pps = [], [], []
    try:
        with torch.no_grad():
            for xb, yb, _ in loader:
                xb = xb.to(Config.device); yb = torch.as_tensor(yb, device=Config.device)
                logits, emb, _ = model(xb)
                feats.append(emb.cpu()); labels.append(yb.cpu())
                pps.append(compute_ppdrq_from_logits(logits, Config.num_classes).cpu())
    finally:
        model.train(was_training)

    def _cat(chunks):
        return torch.cat(chunks, 0) if chunks else torch.tensor([])

    return [_cat(feats)], [_cat(labels)], [_cat(pps)]


def _run_train_epoch(model, loader, criterion, optimizer, bank):
    """One optimisation pass over ``loader``; returns per-sample (loss, accuracy, PPDRQ) means."""
    model.train()
    t_loss = 0.0; t_correct = 0.0; t_total = 0; t_pp = 0.0
    for xb, yb, _ in loader:
        xb = xb.to(Config.device); yb = torch.as_tensor(yb, device=Config.device)
        xb_m, yb_m, aug = apply_mixup_cutmix(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        logits, emb, _ = model(xb_m)

        if aug is None:
            if bank is None:
                total, _, _, pp = criterion(logits, emb, yb)
            else:
                total, _, _, pp = criterion(logits, emb, yb, *bank)
        else:
            # Mixed targets: interpolated CE only; the triplet term is skipped for stability.
            total = mix_criterion(logits, yb_m)
            pp = compute_ppdrq_from_logits(logits, Config.num_classes)

        total.backward()
        if Config.grad_clip is not None:
            nn.utils.clip_grad_norm_(model.parameters(), Config.grad_clip)
        optimizer.step()

        with torch.no_grad():
            pred = logits.argmax(1)
            if aug is None:
                t_correct += (pred == yb).sum().item()
            else:
                # Mixed batch: credit each of the two targets by its mixing weight.
                y_a, y_b, lam = yb_m
                lam = float(lam)
                t_correct += lam * (pred == y_a).sum().item() + (1 - lam) * (pred == y_b).sum().item()
            t_total += yb.size(0)
            t_loss += total.item() * xb.size(0)
            t_pp += pp.sum().item()
    n = max(t_total, 1)
    return t_loss / n, t_correct / n, t_pp / n


def _run_val_epoch(model, loader, criterion):
    """Eval-mode pass over ``loader``; returns per-sample (loss, accuracy, PPDRQ) means."""
    model.eval()
    v_loss = 0.0; v_cor = 0; v_tot = 0; v_pp = 0.0
    with torch.no_grad():
        for xb, yb, _ in loader:
            xb = xb.to(Config.device); yb = torch.as_tensor(yb, device=Config.device)
            logits, emb, _ = model(xb)
            total, _, _, pp = criterion(logits, emb, yb)
            v_cor += (logits.argmax(1) == yb).sum().item()
            v_tot += yb.size(0)
            v_loss += total.item() * xb.size(0)
            v_pp += pp.sum().item()
    n = max(v_tot, 1)
    return v_loss / n, v_cor / n, v_pp / n


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, model_name, num_epochs=Config.epochs):
    """Train ``model`` with ReduceLROnPlateau-style scheduling and early stopping on val loss.

    Each epoch optionally refreshes the triplet memory bank, runs one training pass
    (MixUp/CutMix batches use ``mix_criterion``; plain batches use ``criterion``), evaluates
    on ``val_loader``, steps ``scheduler`` with the validation loss and checkpoints improved
    weights to ``<output_dir>/<model_name>_best.pth``. When training ends the best checkpoint
    is loaded back into ``model``, so callers get the best weights rather than the weights
    from up to ``patience_early_stopping`` epochs past the best.

    Returns:
        ``(history, duration_seconds)``.
    """
    best_val = float('inf')
    epochs_no_improve = 0
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'train_ppdrq': [], 'val_ppdrq': [], 'epochs_trained': 0}
    best_path = os.path.join(Config.output_dir, f"{model_name}_best.pth")
    saved_best = False
    # The bank is consumed only on un-mixed batches (which exist only when MixUp/CutMix is
    # off) and only matters when the triplet weight is non-zero; otherwise skip the extra pass.
    need_bank = (not Config.use_mixup_cutmix) and getattr(criterion, 'lambda_triplet', 1.0) != 0
    start = time.time()

    print(f"\n--- Training {model_name} ---")
    for epoch in range(num_epochs):
        bank = _collect_triplet_bank(model, train_loader) if need_bank else None
        tr_loss, tr_acc, tr_pp = _run_train_epoch(model, train_loader, criterion, optimizer, bank)
        va_loss, va_acc, va_pp = _run_val_epoch(model, val_loader, criterion)

        history['train_loss'].append(tr_loss)
        history['val_loss'].append(va_loss)
        history['train_acc'].append(tr_acc)
        history['val_acc'].append(va_acc)
        history['train_ppdrq'].append(tr_pp)
        history['val_ppdrq'].append(va_pp)
        history['epochs_trained'] = epoch + 1

        if (epoch+1) == 1 or (epoch+1) % 5 == 0 or (epoch+1) == num_epochs:
            print(f"Epoch {epoch+1:03d}/{num_epochs} | Train: loss {tr_loss:.4f} acc {tr_acc:.4f} pp {tr_pp:.3f} | Val: loss {va_loss:.4f} acc {va_acc:.4f} pp {va_pp:.3f}")

        scheduler.step(va_loss)
        if va_loss < best_val:
            best_val = va_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_path)
            saved_best = True
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= Config.patience_early_stopping:
                print(f"Early stopping at epoch {epoch+1}")
                break

    dur = time.time() - start
    print(f"Training time for {model_name}: {dur:.1f}s")
    if saved_best:
        model.load_state_dict(torch.load(best_path, map_location=Config.device))
    return history, dur
 
 
def calculate_calibration_metrics(probabilities: torch.Tensor, labels: torch.Tensor, num_bins=10):
    """Expected Calibration Error and multi-class Brier score.

    ECE uses ``num_bins`` equal-width bins ``(lo, hi]`` over the top-1 probability:
    ``sum_b |acc_b - conf_b| * n_b / N``. Brier is the mean over samples of the squared
    L2 distance between the probability vector and the one-hot label; the number of
    classes is taken from ``probabilities.size(1)``.

    Args:
        probabilities: ``(N, K)`` predicted class probabilities.
        labels: ``(N,)`` integer class labels in ``[0, K)``.
        num_bins: number of confidence bins for ECE.

    Returns:
        ``(ece, brier)`` as Python floats; ``(0.0, nan)`` for empty input.
    """
    labels = labels.long()
    if labels.numel() == 0:
        return 0.0, float('nan')
    conf, pred = probabilities.max(1)
    bins = torch.linspace(0, 1, num_bins + 1, device=conf.device)
    acc = (pred == labels).float()
    # Start from a tensor so .item() works even if no bin is populated.
    ece = torch.zeros((), dtype=conf.dtype, device=conf.device)
    for lo, hi in zip(bins[:-1], bins[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            ece += torch.abs(acc[in_bin].mean() - conf[in_bin].mean()) * in_bin.float().mean()
    one_hot = F.one_hot(labels, num_classes=probabilities.size(1)).to(probabilities.dtype)
    brier = torch.mean(torch.sum((probabilities - one_hot) ** 2, dim=1))
    return ece.item(), brier.item()
 
 
def plot_reliability_diagram(conf, correct, ece, model_name, num_bins=10):
    """Save a reliability diagram to ``<output_dir>/<model_name>_reliability.png``.

    Bars show per-bin accuracy over equal-width confidence bins ``(lo, hi]`` (the same
    float32 edges and half-open binning as ``calculate_calibration_metrics``, so each
    sample lands in exactly one bin), annotated with per-bin sample counts. The dotted
    diagonal marks perfect calibration.

    Args:
        conf: ``(N,)`` top-1 confidences (tensor or array-like).
        correct: ``(N,)`` boolean / 0-1 correctness flags aligned with ``conf``.
        ece: ECE value shown in the title.
        model_name: used in the title and the file name.
        num_bins: number of confidence bins.
    """
    conf = torch.as_tensor(conf).detach().cpu()
    correct = torch.as_tensor(correct).detach().cpu().float()
    edges = torch.linspace(0, 1, num_bins + 1)
    mids = ((edges[:-1] + edges[1:]) / 2).numpy()
    bin_acc = []
    counts = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        m = (conf > lo) & (conf <= hi)
        n = int(m.sum())
        counts.append(n)
        bin_acc.append(correct[m].mean().item() if n > 0 else 0.0)
    fig = plt.figure(figsize=(6,6))
    try:
        plt.plot([0,1],[0,1],'k:')
        plt.bar(mids, bin_acc, width=1/num_bins*0.9, alpha=0.7, edgecolor='black')
        for x, acc, c in zip(mids, bin_acc, counts):
            if c > 0:
                plt.text(x, acc+0.02, str(c), ha='center', fontsize=8)
        plt.title(f"Reliability: {model_name} (ECE={ece:.3f})")
        plt.xlabel('Confidence'); plt.ylabel('Accuracy'); plt.ylim(0,1)
        plt.grid(True)
        plt.savefig(os.path.join(Config.output_dir, f"{model_name}_reliability.png"))
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
 
 
def evaluate_model(model, loader, model_name):
    """Deterministic (eval-mode) evaluation of ``model`` on ``loader``.

    Computes accuracy, macro precision/recall/F1, mean PPDRQ, ECE and Brier score,
    saves a reliability diagram and returns the metrics as a dict (all zeros when the
    loader's dataset is empty). ``inference_time`` covers only the forward passes.
    """
    model.eval()
    prob_batches = []
    label_batches = []
    started = time.time()
    if len(loader.dataset) == 0:
        print(f"[WARN] {model_name} loader empty.")
        return dict.fromkeys(['accuracy','precision','recall','f1_score','mean_ppdrq','ece','brier_score','inference_time'], 0.0)
    with torch.no_grad():
        for inputs, targets, _ in tqdm(loader, desc=f"Evaluating {model_name}"):
            inputs = inputs.to(Config.device)
            targets = torch.as_tensor(targets, device=Config.device)
            batch_logits, _, _ = model(inputs)
            prob_batches.append(torch.softmax(batch_logits, dim=1).cpu())
            label_batches.append(targets.cpu())
    finished = time.time()

    probs = torch.cat(prob_batches, 0)
    y = torch.cat(label_batches, 0)
    pred = probs.argmax(1)
    y_np, pred_np = y.numpy(), pred.numpy()
    acc = accuracy_score(y_np, pred_np)
    prec, rec, f1 = (score_fn(y_np, pred_np, average='macro', zero_division=0)
                     for score_fn in (precision_score, recall_score, f1_score))
    # PPDRQ expects logits; log-probabilities reproduce (approximately) the same softmax.
    mean_pp = compute_ppdrq_from_logits(torch.log(probs + 1e-8), Config.num_classes).mean().item()
    ece, brier = calculate_calibration_metrics(probs, y)
    plot_reliability_diagram(probs.max(1)[0], (pred == y), ece, model_name)
    print(f"\n[{model_name}] Acc {acc:.4f} Prec {prec:.4f} Rec {rec:.4f} F1 {f1:.4f} PPDRQ {mean_pp:.4f} ECE {ece:.4f} Brier {brier:.4f}")
    return {
        'accuracy': acc, 'precision': prec, 'recall': rec, 'f1_score': f1,
        'mean_ppdrq': mean_pp, 'ece': ece, 'brier_score': brier,
        'inference_time': finished - started,
    }
 
# ------------------------------- 
# 7. Layer-wise PPDRQ 
# ------------------------------- 
 
def get_and_plot_layerwise_ppdrq(model, loader, model_name, title_suffix=""):
    """Plot normalised mean PPDRQ per layer to ``<output_dir>/<model_name>_layer_ppdrq.png``.

    Activations of each ``layer_dict`` entry are collected over ``loader`` (the Mixed_7c
    feature map is spatially averaged first). Non-logit layers are mapped to
    ``num_classes`` scores by a linear probe before PPDRQ is computed, and each layer's
    mean is divided by the largest one for plotting.

    NOTE(review): the probes are randomly initialised and never trained, so the
    non-logit bars reflect random projections rather than learned read-outs. They are
    initialised from a fixed seed inside a forked CPU RNG, so the plot is reproducible
    and the global RNG stream used by later training is left untouched.
    """
    model.eval()
    collect = {k:[] for k in ['base_model_output','hidden_layer1','hidden_layer2','flatten','final_logits']}
    if len(loader.dataset)==0:
        print(f"[WARN] No data for layer-wise PPDRQ.")
        return
    with torch.no_grad():
        for xb, _, _ in loader:
            xb = xb.to(Config.device)
            _, _, d = model(xb)
            for k, v in d.items():
                if k == 'base_model_output':
                    v = v.mean(dim=(2, 3))  # average the spatial dims of the 4-D feature map
                collect[k].append(v.cpu())

    # Input width of each probe; 'final_logits' are already class scores and need none.
    probe_in_dims = {'base_model_output': 2048, 'hidden_layer1': 1024, 'hidden_layer2': 512, 'flatten': 2048}
    with torch.random.fork_rng(devices=[]):
        torch.default_generator.manual_seed(0)
        probes = {k: nn.Linear(d_in, Config.num_classes) for k, d_in in probe_in_dims.items()}

    mean_vals = {}
    with torch.no_grad():
        for k, arr in collect.items():
            if not arr:
                mean_vals[k] = 0.0
                continue
            X = torch.cat(arr, 0)
            scores = probes[k](X) if k in probes else X
            mean_vals[k] = compute_ppdrq_from_logits(scores, Config.num_classes).mean().item()

    mmax = max(mean_vals.values()) if mean_vals else 1.0
    norm = {k: (v/mmax if mmax > 0 else 0.0) for k, v in mean_vals.items()}
    layers = list(norm.keys()); vals = [norm[k] for k in layers]
    fig = plt.figure(figsize=(9,5))
    try:
        plt.bar(layers, vals)
        for i, v in enumerate(vals):
            plt.text(i, v+0.02, f"{v:.2f}", ha='center')
        plt.ylim(0,1); plt.ylabel('Normalized Mean PPDRQ'); plt.title(f"Layer PPDRQ {title_suffix}")
        plt.savefig(os.path.join(Config.output_dir, f"{model_name}_layer_ppdrq.png"))
    finally:
        plt.close(fig)
 
# ------------------------------- 
# 8. MC-Dropout & Ensemble 
# ------------------------------- 
 
def _set_dropout_mode_only(model: nn.Module, training: bool = True): 
    """Toggle ONLY Dropout layers' mode without touching BN or others.""" 
    for m in model.modules(): 
        if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)): 
            m.train(training)  # enable stochastic masks 
        # do NOT touch BatchNorm or the rest 
 
def mc_dropout_predict(model, loader, num_passes=Config.mc_dropout_passes):
    """MC-Dropout evaluation: average the softmax of ``num_passes`` stochastic forward passes.

    The model is put in eval mode (BatchNorm uses running statistics) and only dropout
    layers are switched to train mode; they are switched back to eval mode afterwards,
    even if inference raises. Returns the same metric dict as ``evaluate_model`` (all
    zeros for an empty loader).

    Raises:
        ValueError: if ``num_passes < 1``.
    """
    metric_keys = ['accuracy','precision','recall','f1_score','mean_ppdrq','ece','brier_score','inference_time']
    if num_passes < 1:
        raise ValueError(f"num_passes must be >= 1, got {num_passes}")
    if len(loader.dataset) == 0:
        print("[WARN] MC-Dropout loader empty.")
        return {k: 0.0 for k in metric_keys}
    # --- FIXED STOCHASTIC INFERENCE PROTOCOL ---
    # Keep global eval() to freeze BatchNorm running stats;
    # enable randomness ONLY in Dropout layers.
    model.eval()
    _set_dropout_mode_only(model, True)

    all_prob=[]; all_y=[]; t0=time.time()
    try:
        with torch.no_grad():
            for xb, yb, _ in tqdm(loader, desc=f"MC-Dropout {num_passes} passes"):
                xb = xb.to(Config.device); yb = torch.as_tensor(yb, device=Config.device)
                passes = [torch.softmax(model(xb)[0], dim=1) for _ in range(num_passes)]
                mean_prob = torch.stack(passes, 0).mean(0)
                all_prob.append(mean_prob.cpu()); all_y.append(yb.cpu())
        t1 = time.time()
    finally:
        # Restore dropout to eval mode even if inference failed part-way.
        _set_dropout_mode_only(model, False)

    prob=torch.cat(all_prob,0); y=torch.cat(all_y,0)
    pred=prob.argmax(1)
    acc=accuracy_score(y.numpy(), pred.numpy())
    prec=precision_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    rec=recall_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    f1=f1_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    ece,brier=calculate_calibration_metrics(prob, y)
    plot_reliability_diagram(prob.max(1)[0], (pred==y), ece, 'MC-Dropout')
    pp = compute_ppdrq_from_logits(torch.log(prob+1e-8), Config.num_classes)
    print(f"\n[MC-Dropout] Acc {acc:.4f} F1 {f1:.4f} PPDRQ {pp.mean().item():.4f} ECE {ece:.4f}")
    return {'accuracy':acc,'precision':prec,'recall':rec,'f1_score':f1,
            'mean_ppdrq':pp.mean().item(),'ece':ece,'brier_score':brier,
            'inference_time':t1-t0}
 
 
def train_ensemble_member(i, train_loader, val_loader):
    """Train the ``i``-th (0-based) deep-ensemble member from a fresh pretrained backbone.

    Uses CombinedLoss with the triplet term disabled (i.e. PPDRQ-weighted, label-smoothed
    CE), AdamW over trainable parameters and ReduceLROnPlateau, for
    ``max(10, epochs // num_ensemble_models)`` epochs. The model's state dict after
    ``train_model`` returns is saved to ``<output_dir>/ensemble_member_<i+1>.pth``.

    Returns:
        ``(model, training_time_seconds, history)``.
    """
    member_no = i + 1
    print(f"\n--- Training Ensemble Member {member_no} ---")
    member = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
    loss_fn = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=0.0)
    trainable = filter(lambda p: p.requires_grad, member.parameters())
    optimizer = optim.AdamW(trainable, lr=Config.learning_rate, weight_decay=Config.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=Config.patience_lr_scheduler)
    member_epochs = max(10, Config.epochs // Config.num_ensemble_models)
    history, seconds = train_model(member, train_loader, val_loader, loss_fn, optimizer, scheduler,
                                   f"Ensemble_{member_no}", num_epochs=member_epochs)
    torch.save(member.state_dict(), os.path.join(Config.output_dir, f"ensemble_member_{member_no}.pth"))
    return member, seconds, history
 
 
def evaluate_deep_ensemble(models_list: List[nn.Module], loader):
    """Evaluate a deep ensemble by averaging its members' softmax outputs (eval mode).

    Returns the same metric dict as ``evaluate_model`` (all zeros for an empty loader)
    and saves a 'Deep-Ensemble' reliability diagram.

    Raises:
        ValueError: if ``models_list`` is empty.
    """
    if not models_list:
        raise ValueError("evaluate_deep_ensemble needs at least one model")
    if len(loader.dataset) == 0:
        print("[WARN] Deep-Ensemble loader empty.")
        return {k: 0.0 for k in ['accuracy','precision','recall','f1_score','mean_ppdrq','ece','brier_score','inference_time']}
    # Mode is loop-invariant: set it once instead of on every batch.
    for m in models_list:
        m.eval()
    all_prob=[]; all_y=[]; t0=time.time()
    with torch.no_grad():
        for xb, yb, _ in tqdm(loader, desc="Evaluating Ensemble"):
            xb = xb.to(Config.device)
            member_probs = [torch.softmax(m(xb)[0], dim=1) for m in models_list]
            mean_prob = torch.stack(member_probs, 0).mean(0)
            all_prob.append(mean_prob.cpu()); all_y.append(torch.as_tensor(yb).cpu())
    t1=time.time()
    prob=torch.cat(all_prob,0); y=torch.cat(all_y,0)
    pred=prob.argmax(1)
    acc=accuracy_score(y.numpy(), pred.numpy())
    prec=precision_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    rec=recall_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    f1=f1_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    ece,brier=calculate_calibration_metrics(prob, y)
    plot_reliability_diagram(prob.max(1)[0], (pred==y), ece, 'Deep-Ensemble')
    pp = compute_ppdrq_from_logits(torch.log(prob+1e-8), Config.num_classes)
    print(f"\n[Ensemble] Acc {acc:.4f} F1 {f1:.4f} PPDRQ {pp.mean().item():.4f} ECE {ece:.4f}")
    return {'accuracy':acc,'precision':prec,'recall':rec,'f1_score':f1,
            'mean_ppdrq':pp.mean().item(),'ece':ece,'brier_score':brier,
            'inference_time':t1-t0}
 
# ------------------------------- 
# 9. Paper-Ready Reporting Utilities 
# ------------------------------- 
 
def save_results_table(results: Dict[str, Dict], times: Dict[str, float], path_csv: str):
    """Write one row per model (metrics plus train/inference time) to ``<output_dir>/<path_csv>``.

    Train time is looked up as ``times[name]``, then ``times[name + '_Train']``. Inference
    time prefers the metric dict's ``inference_time`` and falls back to
    ``times[name + '_Infer']``. Missing times become NaN. Returns the DataFrame.
    """
    def _as_row(name: str, metrics: Dict) -> Dict:
        return {
            'Model': name,
            'Accuracy': metrics['accuracy'], 'Precision': metrics['precision'],
            'Recall': metrics['recall'], 'F1': metrics['f1_score'],
            'Mean_PPDRQ': metrics['mean_ppdrq'], 'ECE': metrics['ece'], 'Brier': metrics['brier_score'],
            'Train_Time_s': times.get(name, times.get(f"{name}_Train", np.nan)),
            'Infer_Time_s': metrics.get('inference_time', times.get(f"{name}_Infer", np.nan)),
        }

    table = pd.DataFrame([_as_row(name, metrics) for name, metrics in results.items()])
    csv_path = os.path.join(Config.output_dir, path_csv)
    table.to_csv(csv_path, index=False)
    print(f"Saved results table -> {csv_path}")
    return table
 
 
def plot_metric_bars(df: pd.DataFrame, metric: str, fname: str):
    """Save a bar chart comparing ``metric`` across models, highest first, to ``<output_dir>/<fname>``."""
    plt.figure(figsize=(8,4))
    ranked_models = df.sort_values(metric, ascending=False)['Model']
    heights = df.set_index('Model').loc[ranked_models, metric]
    plt.bar(ranked_models, heights)
    plt.xticks(rotation=25, ha='right')
    plt.ylabel(metric)
    plt.title(f"Model comparison: {metric}")
    target = os.path.join(Config.output_dir, fname)
    plt.tight_layout()
    plt.savefig(target)
    plt.close()
    print(f"Saved {metric} bar plot -> {target}")
 
 
def plot_cost_performance(df: pd.DataFrame, perf_metric: str='Accuracy', cost_metric: str='Infer_Time_s', fname: str='cost_perf.png'):
    """Scatter each model's ``perf_metric`` against ``cost_metric``, labelled by model name.

    Rows whose cost or performance value is non-finite (e.g. the NaN times that
    ``save_results_table`` writes for missing entries) are skipped, since they cannot be
    placed on the axes. The plot is saved to ``<output_dir>/<fname>``.
    """
    plt.figure(figsize=(5,5))
    for _, row in df.iterrows():
        cost, perf = row[cost_metric], row[perf_metric]
        if not (np.isfinite(cost) and np.isfinite(perf)):
            continue
        plt.scatter(cost, perf)
        plt.text(cost, perf+1e-3, row['Model'], fontsize=8)
    plt.xlabel(cost_metric); plt.ylabel(perf_metric)
    plt.title(f"Performance-Cost Trade-off ({perf_metric} vs {cost_metric})")
    out = os.path.join(Config.output_dir, fname)
    plt.grid(True); plt.tight_layout(); plt.savefig(out); plt.close()
    print(f"Saved cost-performance plot -> {out}")
 
 
def lambda_sensitivity_sweep(lambdas=(0.0, 0.01, 0.1, 0.3, 1.0), seed: int = 42):
    """Re-train small runs with different λ (triplet weight) and plot test Accuracy/ECE vs λ.

    MixUp/CutMix is switched off for the duration of the sweep. While it is on,
    train_model routes every batch through mix_criterion and never calls CombinedLoss, so
    λ would have no effect. The flag is restored afterwards, even on error. Each run is
    re-seeded with ``seed`` so runs start from the same head initialisation and differ
    mainly in λ.

    Warning: This re-trains models. Reduce epochs or subset data for quick runs.
    """
    lambdas = list(lambdas)  # iterated twice (training + plotting); supports generators
    acc_list=[]; ece_list=[]
    mix_was_on = Config.use_mixup_cutmix
    Config.use_mixup_cutmix = False
    try:
        for lam in lambdas:
            print(f"\n[λ-sweep] Training with λ={lam}")
            set_seed(seed)
            model = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
            crit = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=lam)
            opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=Config.learning_rate, weight_decay=Config.weight_decay)
            sch = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.2, patience=Config.patience_lr_scheduler)
            train_model(model, train_loader, val_loader, crit, opt, sch, f'Lambda_{lam}', num_epochs=max(5, Config.epochs//6))
            res = evaluate_model(model, test_loader, f'Lambda_{lam}')
            acc_list.append(res['accuracy']); ece_list.append(res['ece'])
    finally:
        Config.use_mixup_cutmix = mix_was_on
    plt.figure(figsize=(6,4))
    plt.plot(lambdas, acc_list, marker='o', label='Accuracy')
    plt.plot(lambdas, ece_list, marker='o', label='ECE')
    plt.xlabel('λ (triplet weight)'); plt.ylabel('Score'); plt.legend(); plt.grid(True)
    out = os.path.join(Config.output_dir, 'lambda_sensitivity.png')
    plt.tight_layout(); plt.savefig(out); plt.close()
    print(f"Saved λ sensitivity plot -> {out}")
 
# ------------------------------- 
# 9b. EXTRA visualizations (added; nothing replaced) 
# ------------------------------- 
# Display names indexed by class id; order must match MRIDataset.class_to_idx (AD=0, MCI=1, NC=2).
CLASS_NAMES = ['AD','MCI','NC']
 
def collect_predictions(model, loader):
    """Run ``model`` over ``loader`` without gradients and gather its predictions.

    The loader may yield ``(x, y, idx)`` or ``(x, y, idx, patient_id)`` batches; only the first
    two fields are used, so both dataloader formats in this notebook work.

    Args:
        model: module whose forward returns ``(logits, embeddings, layer_dict)``.
        loader: iterable of batches as described above.

    Returns:
        ``(y_true, y_pred, confidence, probabilities)`` as NumPy arrays; ``probabilities`` is
        ``(N, C)``. For an empty loader, empty arrays are returned instead of raising.
    """
    y_true = []; y_pred = []; conf = []; probs = []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            xb, yb = batch[0], batch[1]
            xb = xb.to(Config.device); yb = torch.as_tensor(yb, device=Config.device)
            logits, _, _ = model(xb)
            p = torch.softmax(logits, dim=1)
            pr = p.argmax(1)
            y_true.append(yb.cpu()); y_pred.append(pr.cpu())
            conf.append(p.max(1)[0].cpu()); probs.append(p.cpu())
    if not y_true:
        # torch.cat([]) would raise; return well-typed empty arrays instead.
        return (np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.float32), np.empty((0, Config.num_classes), dtype=np.float32))
    return torch.cat(y_true).numpy(), torch.cat(y_pred).numpy(), torch.cat(conf).numpy(), torch.cat(probs, 0).numpy()
 
def plot_confusion(y_true, y_pred, classes, title, fname, normalize=True):
    """Plot and save a confusion matrix over all ``len(classes)`` labels.

    Args:
        y_true, y_pred: integer label sequences.
        classes: display names; index ``i`` names label ``i``.
        title: figure title.
        fname: file name written inside ``Config.output_dir``.
        normalize: if True, each row is divided by its count (rows with no samples stay 0)
            and cells are printed with two decimals; otherwise raw counts are shown.
    """
    n = len(classes)
    matrix = confusion_matrix(y_true, y_pred, labels=list(range(n)))
    if normalize:
        row_totals = matrix.sum(axis=1, keepdims=True).clip(min=1)
        matrix = matrix.astype('float') / row_totals

    plt.figure(figsize=(5, 4))
    plt.imshow(matrix, interpolation='nearest', cmap='Blues')
    plt.title(title)
    plt.colorbar()
    ticks = np.arange(n)
    plt.xticks(ticks, classes, rotation=45)
    plt.yticks(ticks, classes)

    # Switch text colour on dark cells so the annotation stays readable.
    midpoint = matrix.max() / 2.
    for (row, col), value in np.ndenumerate(matrix):
        label = f"{value:.2f}" if normalize else int(value)
        plt.text(col, row, label, ha="center", va="center",
                 color="white" if value > midpoint else "black")

    plt.ylabel('True')
    plt.xlabel('Predicted')
    plt.tight_layout()
    out = os.path.join(Config.output_dir, fname)
    plt.savefig(out)
    plt.close()
    print(f"Saved confusion matrix -> {out}")
 
def export_per_class_report(y_true, y_pred, classes, fname_csv):
    """Write sklearn's per-class precision/recall/F1 report to a CSV in ``Config.output_dir``.

    ``labels`` is passed explicitly so the report always has one row per entry of ``classes``.
    Without it, sklearn raises ValueError when a class is absent from both ``y_true`` and
    ``y_pred`` (e.g. a small or skewed split), because ``target_names`` would then be longer
    than the labels it infers.

    Args:
        y_true, y_pred: integer label sequences.
        classes: display names; index ``i`` names label ``i``.
        fname_csv: output file name.
    """
    rep = classification_report(
        y_true, y_pred,
        labels=list(range(len(classes))),
        target_names=classes,
        output_dict=True,
        zero_division=0,
    )
    df = pd.DataFrame(rep).transpose()
    out = os.path.join(Config.output_dir, fname_csv)
    df.to_csv(out)
    print(f"Saved per-class report -> {out}")
 
def reliability_grid(models_dict, loader, fname='reliability_grid.png', num_bins=10):
    """Draw one reliability diagram per model side by side and save the figure.

    Each panel shows per-bin accuracy against confidence (max softmax probability) with the
    identity diagonal. The title shows the ECE from ``calculate_calibration_metrics``.

    Bins are half-open ``(lo, hi]``, with the first bin also admitting 0. This matches the ECE
    computation, so a confidence exactly on an edge is counted in one bar only.

    Args:
        models_dict: mapping display name -> model returning ``(logits, emb, layer_dict)``.
        loader: yields ``(x, y, idx)`` or ``(x, y, idx, patient_id)`` batches.
        fname: output file name inside ``Config.output_dir``.
        num_bins: number of confidence bins for the bars.
    """
    cols = len(models_dict)
    if cols == 0:
        print("[WARN] reliability_grid: no models given; nothing to plot.")
        return
    plt.figure(figsize=(5 * cols, 5))
    bins = np.linspace(0, 1, num_bins + 1); mids = (bins[:-1] + bins[1:]) / 2
    for idx, (name, model) in enumerate(models_dict.items(), start=1):
        model.eval(); all_prob = []; all_y = []
        with torch.no_grad():
            for batch in loader:
                xb, yb = batch[0], batch[1]
                xb = xb.to(Config.device); yb = torch.as_tensor(yb, device=Config.device)
                logits, _, _ = model(xb); pr = torch.softmax(logits, dim=1)
                all_prob.append(pr.cpu()); all_y.append(yb.cpu())
        prob = torch.cat(all_prob, 0); y = torch.cat(all_y, 0); pred = prob.argmax(1)
        conf = prob.max(1)[0]; correct = (pred == y)
        # compute ECE for title
        # NOTE(review): uses calculate_calibration_metrics' own default bin count, not num_bins.
        ece, _ = calculate_calibration_metrics(prob, y)
        bin_acc = []
        for i in range(num_bins):
            lower = (conf >= bins[i]) if i == 0 else (conf > bins[i])
            m = lower & (conf <= bins[i + 1])
            bin_acc.append(correct[m].float().mean().item() if m.sum() > 0 else 0.0)
        ax = plt.subplot(1, cols, idx)
        ax.plot([0, 1], [0, 1], 'k:'); ax.bar(mids, bin_acc, width=1 / num_bins * 0.9, alpha=0.7, edgecolor='black')
        ax.set_title(f"{name} (ECE={ece:.3f})"); ax.set_xlabel('Confidence'); ax.set_ylabel('Accuracy'); ax.set_ylim(0, 1); ax.grid(True)
    out = os.path.join(Config.output_dir, fname)
    plt.tight_layout(); plt.savefig(out); plt.close(); print(f"Saved reliability grid -> {out}")
 
# ------------------------------- 
# 10. Main 
# ------------------------------- 
if __name__ == '__main__':
    results: Dict[str, Dict] = {}
    times: Dict[str, float] = {}

    # Baseline (with label smoothing + anti-overfitting)
    # NOTE(review): confirm CombinedLoss(lambda_triplet=0.0) is really plain CE here. If it still
    # applies the per-sample PPDRQ weight 1 + (1 - pp) (as the CombinedLoss defined later in this
    # file does), this "baseline" optimizes the same loss as PPDRQ_CE below.
    baseline = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
    crit_base = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=0.0)  # CE only
    opt_base = optim.AdamW(filter(lambda p: p.requires_grad, baseline.parameters()), lr=Config.learning_rate, weight_decay=Config.weight_decay)
    sch_base = optim.lr_scheduler.ReduceLROnPlateau(opt_base, mode='min', factor=0.2, patience=Config.patience_lr_scheduler)
    hist_b, t_b = train_model(baseline, train_loader, val_loader, crit_base, opt_base, sch_base, 'Baseline')
    res_b = evaluate_model(baseline, test_loader, 'Baseline')
    results['Baseline'] = res_b; times['Baseline'] = t_b

    # PPDRQ-weighted CE only
    class PPDRQWeightedCE(nn.Module):
        """Cross-entropy weighted per sample by 1 + (1 - PPDRQ); no triplet term."""
        def __init__(self, num_classes):
            super().__init__(); self.num_classes = num_classes

        def forward(self, logits, _emb, labels, *args, **kwargs):
            # Extra inputs (e.g. a triplet feature bank passed positionally or by keyword) are
            # accepted and ignored, so this is a drop-in replacement for CombinedLoss.
            pp = compute_ppdrq_from_logits(logits, self.num_classes)
            per_ce = F.cross_entropy(logits, labels, reduction='none', label_smoothing=Config.label_smoothing)
            w = 1 + (1 - pp)
            loss = torch.mean(per_ce * w)
            return loss, loss, torch.tensor(0.0, device=logits.device), pp

    ppdrq_model = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
    crit_pp = PPDRQWeightedCE(num_classes=Config.num_classes)
    opt_pp = optim.AdamW(filter(lambda p: p.requires_grad, ppdrq_model.parameters()), lr=Config.learning_rate, weight_decay=Config.weight_decay)
    sch_pp = optim.lr_scheduler.ReduceLROnPlateau(opt_pp, mode='min', factor=0.2, patience=Config.patience_lr_scheduler)
    hist_p, t_p = train_model(ppdrq_model, train_loader, val_loader, crit_pp, opt_pp, sch_pp, 'PPDRQ_CE')
    res_p = evaluate_model(ppdrq_model, test_loader, 'PPDRQ_CE')
    results['PPDRQ_CE'] = res_p; times['PPDRQ_CE'] = t_p

    # Triplet model
    triplet_model = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
    crit_trip = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=Config.lambda_triplet)
    opt_trip = optim.AdamW(filter(lambda p: p.requires_grad, triplet_model.parameters()), lr=Config.learning_rate, weight_decay=Config.weight_decay)
    sch_trip = optim.lr_scheduler.ReduceLROnPlateau(opt_trip, mode='min', factor=0.2, patience=Config.patience_lr_scheduler)
    hist_t, t_t = train_model(triplet_model, train_loader, val_loader, crit_trip, opt_trip, sch_trip, 'Triplet')
    res_t = evaluate_model(triplet_model, test_loader, 'Triplet')
    results['Triplet'] = res_t; times['Triplet'] = t_t

    # MC-Dropout (train normally; eval with stochastic passes)
    mc_model = CustomInceptionV3(num_classes=Config.num_classes, dropout_rate=0.6).to(Config.device)
    crit_mc = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=0.0)
    opt_mc = optim.AdamW(filter(lambda p: p.requires_grad, mc_model.parameters()), lr=Config.learning_rate, weight_decay=Config.weight_decay)
    sch_mc = optim.lr_scheduler.ReduceLROnPlateau(opt_mc, mode='min', factor=0.2, patience=Config.patience_lr_scheduler)
    hist_m, t_m = train_model(mc_model, train_loader, val_loader, crit_mc, opt_mc, sch_mc, 'MC_Dropout')
    res_m = mc_dropout_predict(mc_model, test_loader, num_passes=Config.mc_dropout_passes)
    results['MC_Dropout'] = res_m; times['MC_Dropout_Train'] = t_m; times['MC_Dropout_Infer'] = res_m['inference_time']

    # Deep Ensemble (5 members)
    members = []; t_mem = []
    for i in range(Config.num_ensemble_models):
        m, t, _ = train_ensemble_member(i, train_loader, val_loader)
        members.append(m); t_mem.append(t)
    res_e = evaluate_deep_ensemble(members, test_loader)
    results['Ensemble'] = res_e; times['Ensemble_Train'] = sum(t_mem); times['Ensemble_Infer'] = res_e['inference_time']

    # NOTE(review): unless train_model restores the best checkpoint, the test-set numbers above
    # come from final-epoch weights, while the unseen evaluation below reloads *_best.pth.

    # Summary to console
    print("\n--- Summary ---")
    print(f"{'Model':<15}{'Acc':>8}{'Prec':>8}{'Rec':>8}{'F1':>8}{'PPDRQ':>9}{'ECE':>8}{'Brier':>8}{'Train(s)':>10}{'Infer(s)':>10}")
    for k, v in results.items():
        tr = times.get(k, times.get(f"{k}_Train", 0.0))
        inf = v.get('inference_time', times.get(f"{k}_Infer", 0.0))
        print(f"{k:<15}{v['accuracy']:>8.3f}{v['precision']:>8.3f}{v['recall']:>8.3f}{v['f1_score']:>8.3f}{v['mean_ppdrq']:>9.3f}{v['ece']:>8.3f}{v['brier_score']:>8.3f}{tr:>10.1f}{inf:>10.1f}")

    # Save curves for baseline/PPDRQ/Triplet
    for name, hist in [('Baseline', hist_b), ('PPDRQ_CE', hist_p), ('Triplet', hist_t)]:
        plt.figure(figsize=(10, 4))
        plt.subplot(1, 2, 1)
        plt.plot(hist['train_acc'], label='Train'); plt.plot(hist['val_acc'], label='Val'); plt.title(f'{name} Acc')
        plt.legend(); plt.subplot(1, 2, 2)
        plt.plot(hist['train_loss'], label='Train'); plt.plot(hist['val_loss'], label='Val'); plt.title(f'{name} Loss')
        plt.legend(); plt.tight_layout()
        plt.savefig(os.path.join(Config.output_dir, f"{name}_curves.png"))
        plt.close()

    # Unified table + plots for paper
    df = save_results_table(results, times, "results_all_models.csv")
    plot_metric_bars(df, "Accuracy", "cmp_accuracy.png")
    plot_metric_bars(df, "F1", "cmp_f1.png")
    plot_metric_bars(df, "ECE", "cmp_ece.png")
    plot_metric_bars(df, "Brier", "cmp_brier.png")
    plot_cost_performance(df, perf_metric="Accuracy", cost_metric="Infer_Time_s", fname="cost_perf_acc_vs_infer.png")

    # EXTRA visuals: confusion matrices, per-class reports, side-by-side reliability grid (test set)
    # NOTE(review): collect_predictions/reliability_grid call model.eval(), so the MC_Dropout
    # entries below presumably reflect a single deterministic pass, not MC averaging.
    model_objs = {
        'Baseline': baseline,
        'PPDRQ_CE': ppdrq_model,
        'Triplet': triplet_model,
        'MC_Dropout': mc_model
    }

    class EnsembleWrapper(nn.Module):
        """Average member logits so the ensemble can be used like a single model."""
        def __init__(self, members):
            super().__init__(); self.members = members

        def forward(self, x):
            outs = []
            for m in self.members:
                m.eval(); lo, _, _ = m(x); outs.append(lo.unsqueeze(0))
            logits = torch.mean(torch.cat(outs, 0), dim=0)
            # Placeholder 512-d embeddings keep the (logits, emb, layer_dict) contract.
            return logits, torch.zeros(x.size(0), 512, device=logits.device), {'final_logits': logits}

    ens_wrapper = EnsembleWrapper(members)
    model_objs['Ensemble'] = ens_wrapper

    for name, m in model_objs.items():
        y_t, y_p, confs, prob = collect_predictions(m, test_loader)
        plot_confusion(y_t, y_p, CLASS_NAMES, f"Confusion: {name}", f"cm_{name}.png", normalize=True)
        export_per_class_report(y_t, y_p, CLASS_NAMES, f"per_class_{name}.csv")

    reliability_grid(model_objs, test_loader, fname='reliability_grid_test.png', num_bins=10)

    # Unseen (domain shift) evaluation if available
    if len(unseen_dataset) > 0:
        print("\n--- Evaluating on UNSEEN domain ---")
        # Load best weights if present
        for tag, model in [('Baseline', baseline), ('PPDRQ_CE', ppdrq_model), ('Triplet', triplet_model), ('MC_Dropout', mc_model)]:
            pth = os.path.join(Config.output_dir, f"{tag}_best.pth")
            if os.path.exists(pth):
                model.load_state_dict(torch.load(pth, map_location=Config.device))
        # Evaluate (previously these results were computed and then discarded)
        unseen_results = {
            'Baseline': evaluate_model(baseline, unseen_loader, 'Baseline_Unseen'),
            'PPDRQ_CE': evaluate_model(ppdrq_model, unseen_loader, 'PPDRQ_CE_Unseen'),
            'Triplet': evaluate_model(triplet_model, unseen_loader, 'Triplet_Unseen'),
            'MC_Dropout': mc_dropout_predict(mc_model, unseen_loader, num_passes=Config.mc_dropout_passes),
        }
        if len(members) == Config.num_ensemble_models:
            unseen_results['Ensemble'] = evaluate_deep_ensemble(members, unseen_loader)

        # Persist the scalar metrics so the domain-shift numbers are available after the run.
        metric_keys = ['accuracy', 'precision', 'recall', 'f1_score', 'mean_ppdrq', 'ece', 'brier_score']
        unseen_rows = [
            {'Model': name, **{key: float(res.get(key, float('nan'))) for key in metric_keys}}
            for name, res in unseen_results.items()
        ]
        out_unseen = os.path.join(Config.output_dir, 'results_unseen_domain.csv')
        pd.DataFrame(unseen_rows).to_csv(out_unseen, index=False)
        print(f"Saved unseen-domain results -> {out_unseen}")
        print("--- Unseen evaluation complete ---")

    print(f"\nDone. Check {Config.output_dir}/ for saved plots, CSVs, and weights (including EXTRA figures).")


# PPDRQ with Extended Validation: Summary of Outputs
1. DATASET VERIFICATION
   - Patient-level split verification completed
   - See: patient_split_verification.csv
   
2. MODEL PERFORMANCE
   All models evaluated on test set with comprehensive metrics:
   - See: results_all_models.csv
   
3. STATISTICAL SIGNIFICANCE
   Bootstrap confidence intervals and McNemar tests computed:
   - See: statistical_significance_results.csv
   - Visualizations: *_confidence_intervals.png
   
4. CALIBRATION ANALYSIS
   - Reliability diagrams for all models
   - ECE and Brier scores computed
   - See: reliability_grid_test.png
   
5. PER-CLASS PERFORMANCE
   - Confusion matrices: cm_*.png
   - Detailed reports: per_class_*.csv
   
6. CLINICAL VALIDATION
   - Cases selected for neurologist review
   - See: clinical_validation_*.csv
   - Template provided for clinical observations
   
7. DOMAIN SHIFT EVALUATION
   - Completed on unseen dataset
   - See: results_unseen_domain.csv
   - Limitation acknowledged in results

In [None]:
import os
import math
import random
import time
import warnings
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report
)
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
from scipy import stats
from scipy.stats import wilcoxon

# Silence every warning for cleaner notebook output.
# NOTE(review): this also hides deprecation and numerical warnings (e.g. from torchvision or
# sklearn); consider narrowing the filter to specific categories/modules.
warnings.filterwarnings("ignore")

# -------------------------------
# 1. Config
# -------------------------------
class Config:
    """Global experiment settings, read as class attributes (``Config.batch_size`` etc.)."""

    # Data root; expected layout: <data_dir>/{train,validation,test,unseen}/{AD,MCI,NC}/*.png|jpg|jpeg
    data_dir: str = 'dataset'
    output_dir: str = 'result save directory'   # placeholder: set to your results folder

    image_size: int = 299                # InceptionV3 input resolution
    num_classes: int = 3                 # AD / MCI / NC

    batch_size: int = 16
    epochs: int = 60
    learning_rate: float = 3e-4
    weight_decay: float = 5e-4
    patience_lr_scheduler: int = 3       # ReduceLROnPlateau patience (epochs)
    patience_early_stopping: int = 10    # epochs without val-loss improvement before stopping
    grad_clip: Optional[float] = 1.0     # max grad norm; None disables clipping

    # Triplet & PPDRQ
    lambda_triplet: float = 0.1          # weight of the triplet term in the total loss
    triplet_margin: float = 0.2
    epsilon_p: float = 0.1               # positives: |PPDRQ diff| < epsilon_p (same class)
    epsilon_n: float = 0.2               # negatives: |PPDRQ diff| >= epsilon_n (other class)

    # MC-Dropout & Ensemble
    mc_dropout_passes: int = 30
    # NOTE(review): presumably meant as the MC-Dropout model's dropout_rate; CustomInceptionV3
    # defaults to 0.6, so confirm this value is actually passed where the model is built.
    mc_dropout_rate: float = 0.3
    num_ensemble_models: int = 5

    # Statistical testing
    bootstrap_iterations: int = 1000
    confidence_level: float = 0.95

    # Regularization toggles
    use_mixup_cutmix: bool = True
    mixup_alpha: float = 0.2
    cutmix_alpha: float = 0.2
    label_smoothing: float = 0.05

    # Data safety: when True, missing/empty class folders are filled with random dummy images.
    allow_make_dummy: bool = False

    # Device
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def set_seed(seed: int = 42):
    """Seed the PyTorch (CPU and every CUDA device), NumPy and Python RNGs, and force deterministic cuDNN.

    Args:
        seed: value handed to each seeding function.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seed_fn(seed)
    # Deterministic kernels and no autotuner: reproducible runs at some cost in speed.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Module-level setup: ensure the output folder exists and seed all RNGs before the datasets
# below are built (dummy-image generation draws from `random`).
os.makedirs(Config.output_dir, exist_ok=True)
set_seed(42)
print(f"Using device: {Config.device}")

# -------------------------------
# 2. Dataset with Patient ID Support
# -------------------------------
class MRIDataset(Dataset):
    """Image-folder dataset for one split laid out as ``<data_dir>/<phase>/{AD,MCI,NC}/*``.

    Items are ``(image, label, index, patient_id)``. ``patient_id`` is ``"<class>_<prefix>"``,
    where ``prefix`` is the file name up to the first ``_`` (or the stem if there is none).

    File lists are sorted so the sample order, and therefore any index-based output, is the
    same on every filesystem. ``os.listdir`` order is arbitrary.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir: str, phase: str, transform=None):
        self.data_dir = os.path.join(data_dir, phase)
        self.transform = transform
        self.class_to_idx = {'AD': 0, 'MCI': 1, 'NC': 2}
        self.image_paths: List[str] = []
        self.labels: List[int] = []
        self.patient_ids: List[str] = []

        for cname, idx in self.class_to_idx.items():
            cdir = os.path.join(self.data_dir, cname)
            if not os.path.exists(cdir):
                if not Config.allow_make_dummy:
                    print(f"[WARN] Missing class dir: {cdir}. Skipping.")
                    continue
                print(f"[WARN] Missing {cdir}. Creating dummy images for demo only.")
                os.makedirs(cdir, exist_ok=True)
                self._write_dummy_images(cdir, phase)

            files = self._list_images(cdir)
            if not files and Config.allow_make_dummy:
                print(f"[WARN] Empty {cdir}. Creating dummy images for demo only.")
                self._write_dummy_images(cdir, phase)
                files = self._list_images(cdir)

            for name in files:
                self.image_paths.append(os.path.join(cdir, name))
                self.labels.append(idx)
                # Patient ID = file-name prefix before the first '_' (assumes "patientID_xxx.png");
                # adapt to your naming convention.
                # NOTE(review): the class name is prefixed, so the same subject filed under
                # different diagnoses in different splits will not register as a split overlap.
                patient_id = name.split('_')[0] if '_' in name else name.split('.')[0]
                self.patient_ids.append(f"{cname}_{patient_id}")

    @classmethod
    def _list_images(cls, cdir: str) -> List[str]:
        """Return the image file names in ``cdir`` in sorted (deterministic) order."""
        return sorted(f for f in os.listdir(cdir) if f.lower().endswith(cls.IMAGE_EXTENSIONS))

    @staticmethod
    def _write_dummy_images(cdir: str, phase: str) -> None:
        """Write flat random-grey RGB images into ``cdir`` (64 for train, 8 otherwise); demo only."""
        n = 8 if phase != 'train' else 64
        for i in range(n):
            img = Image.new('L', (Config.image_size, Config.image_size), color=random.randint(0, 255))
            img.convert('RGB').save(os.path.join(cdir, f'dummy_{i}.png'))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        p = self.image_paths[idx]
        y = self.labels[idx]
        pid = self.patient_ids[idx]
        # convert() loads the pixels into a new image, so the file can be closed right away.
        with Image.open(p) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, y, idx, pid

# Transforms
# Training: mild geometric/photometric augmentation. Images are first resized to 1.1x the target
# size and then randomly cropped back to image_size. RandomErasing comes after ToTensor/Normalize
# because it operates on tensors.
train_tfms = transforms.Compose([
    transforms.Resize(int(Config.image_size*1.1)),
    transforms.RandomResizedCrop(Config.image_size, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.GaussianBlur(kernel_size=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.1), ratio=(0.3, 3.3), value='random')
])

# Validation / test / unseen: deterministic square resize plus the same ImageNet mean/std normalization.
val_test_tfms = transforms.Compose([
    transforms.Resize((Config.image_size, Config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# One dataset per split: only the training split gets augmentation; all other splits use the
# deterministic evaluation transforms.
train_dataset = MRIDataset(Config.data_dir, 'train', train_tfms)
val_dataset   = MRIDataset(Config.data_dir, 'validation', val_test_tfms)
test_dataset  = MRIDataset(Config.data_dir, 'test', val_test_tfms)
unseen_dataset= MRIDataset(Config.data_dir, 'unseen', val_test_tfms)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}, Unseen: {len(unseen_dataset)}")

# -------------------------------
# NEW: Patient-Level Split Verification
# -------------------------------
def verify_patient_level_split():
    """Check that no patient ID occurs in more than one of the train/validation/test splits.

    Prints per-split patient and image counts plus a warning for every overlapping pair of
    splits. Writes the counts to ``patient_split_verification.csv`` in ``Config.output_dir``.
    The unseen split is not checked.
    """
    splits = {'Train': train_dataset, 'Validation': val_dataset, 'Test': test_dataset}
    patients = {split: set(ds.patient_ids) for split, ds in splits.items()}
    everyone = patients['Train'] | patients['Validation'] | patients['Test']
    rule = "=" * 60

    print("\n" + rule)
    print("PATIENT-LEVEL SPLIT VERIFICATION")
    print(rule)
    for split, ds in splits.items():
        print(f"{split} set: {len(patients[split])} unique patients, {len(ds)} images")
    print(f"Total unique patients: {len(everyone)}")

    # (leading text, split A, split B, label used in the warning)
    pair_checks = [
        ("\n", 'Train', 'Validation', "TRAIN and VAL"),
        ("", 'Train', 'Test', "TRAIN and TEST"),
        ("", 'Validation', 'Test', "VAL and TEST"),
    ]
    leaked = False
    for prefix, left, right, tag in pair_checks:
        shared = patients[left] & patients[right]
        if shared:
            leaked = True
            print(f"{prefix}⚠️  WARNING: {len(shared)} patients appear in both {tag}")
    if not leaked:
        print("\n✓ No data leakage detected: All splits are patient-level disjoint")

    print(rule + "\n")

    # Save verification report
    csv_path = os.path.join(Config.output_dir, 'patient_split_verification.csv')
    image_counts = [len(ds) for ds in splits.values()]
    pd.DataFrame({
        'Split': ['Train', 'Validation', 'Test', 'Total Unique'],
        'Patients': [len(group) for group in patients.values()] + [len(everyone)],
        'Images': image_counts + [sum(image_counts)],
    }).to_csv(csv_path, index=False)
    print(f"Saved split verification -> {csv_path}")

# Run the leakage check when there is training data; otherwise tell the user where data is
# expected. The path comes from Config.data_dir rather than being hard-coded.
if len(train_dataset) > 0:
    verify_patient_level_split()
else:
    print(f"[ERROR] Train dataset is empty. Please prepare data under "
          f"{os.path.join(Config.data_dir, 'train')}/{{AD,MCI,NC}}.")

# Dataloaders
def collate_fn(batch):
    """Collate ``(image, label, index, patient_id)`` samples into one batch.

    Returns:
        ``(images, labels, indices, patient_ids)``: stacked image tensor, label tensor,
        index tensor, and a plain list of patient-ID strings (strings cannot be tensors).
    """
    img_col, label_col, idx_col, pid_col = zip(*batch)
    images = torch.stack(img_col)
    return images, torch.tensor(label_col), torch.tensor(idx_col), list(pid_col)

# Dataloaders. The train loader shuffles and drops the last incomplete batch, presumably so the
# BatchNorm1d head never sees a batch of one in train mode. Evaluation loaders keep every sample
# in file order. num_workers=0 keeps loading in the main process.
train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, 
                          num_workers=0, drop_last=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=False, 
                          num_workers=0, collate_fn=collate_fn)
test_loader  = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False, 
                          num_workers=0, collate_fn=collate_fn)
unseen_loader= DataLoader(unseen_dataset, batch_size=Config.batch_size, shuffle=False, 
                          num_workers=0, collate_fn=collate_fn)

# -------------------------------
# 3. Model
# -------------------------------
class CustomInceptionV3(nn.Module):
    """ImageNet-pretrained InceptionV3 backbone with a BN/dropout classifier and a 512-d embedding head.

    ``forward(x)`` returns ``(logits, embeddings, layer_dict)``:
      * logits: ``fc_head(dropout(bn(features)))``
      * embeddings: ``relu(hidden2(relu(hidden1(dropout(bn(features))))))``
      * layer_dict: intermediate activations (Mixed_7c output, hidden layers, raw features, logits).

    Args:
        num_classes: number of output classes.
        dropout_rate: dropout applied after the BatchNorm on the 2048-d features.
        include_aux_logits: NOTE(review) currently unused; the backbone is always built with
            ``aux_logits=True`` (which also keeps AuxLogits keys in saved state_dicts).
        train_last_blocks_only: freeze the backbone except the ``Mixed_7*`` blocks.
    """

    def __init__(self, num_classes=3, dropout_rate=0.6, include_aux_logits=True, train_last_blocks_only=True):
        super().__init__()
        import torchvision
        try:
            self.inception_base = torchvision.models.inception_v3(
                weights=torchvision.models.Inception_V3_Weights.IMAGENET1K_V1,
                aux_logits=True
            )
        except Exception:
            # Presumably for older torchvision without the `weights=` enum API.
            self.inception_base = models.inception_v3(pretrained=True, aux_logits=True)

        # Replace the ImageNet classifier so the backbone outputs its 2048-d pooled features.
        self.inception_base.fc = nn.Identity()

        if train_last_blocks_only:
            for p in self.inception_base.parameters():
                p.requires_grad = False
            for name, m in self.inception_base.named_modules():
                if name.startswith('Mixed_7'):
                    for p in m.parameters():
                        p.requires_grad = True

        self.bn = nn.BatchNorm1d(2048)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.hidden1 = nn.Linear(2048, 1024)
        self.hidden2 = nn.Linear(1024, 512)
        self.fc_head = nn.Linear(2048, num_classes)

        # Forward hook caches the latest Mixed_7c activation for layer_dict.
        self.mixed_7c_output = None
        def hook_fn(_m, _in, out):
            self.mixed_7c_output = out
        self.inception_base.Mixed_7c.register_forward_hook(hook_fn)

    def _extract_features(self, x):
        """Return the backbone's main output, unwrapping torchvision's train-mode InceptionOutputs/tuples."""
        out = self.inception_base(x)
        if hasattr(out, 'logits'):
            return out.logits
        if isinstance(out, (tuple, list)) and len(out) > 0:
            return out[0]
        return out

    def forward(self, x):
        flatten = self._extract_features(x)
        # Normalize once and reuse it. Calling self.bn a second time (as before) updated the
        # running statistics and num_batches_tracked twice per training step.
        normed = self.bn(flatten)
        z = self.dropout(normed)
        logits = self.fc_head(z)
        e = F.relu(self.hidden1(z))
        e = F.relu(self.hidden2(e))
        layer_dict = {
            'base_model_output': self.mixed_7c_output,
            'hidden_layer1': F.relu(self.hidden1(normed)),   # pre-dropout view of hidden layer 1
            'hidden_layer2': e,
            'flatten': flatten,
            'final_logits': logits
        }
        return logits, e, layer_dict

# -------------------------------
# 4. PPDRQ & Losses
# -------------------------------

def compute_ppdrq_from_logits(logits: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Per-sample PPDRQ score: normalized sum of pairwise absolute probability differences.

    ``pp = sum_{i<j} |p_i - p_j| / (num_classes - 1)``, clamped to [0, 1], where ``p`` is the
    softmax of ``logits`` along dim 1. It is 1 for a one-hot prediction and 0 for a uniform one.
    For 3 classes this equals the previous special case ``1.5 * mean(|d12|, |d13|, |d23|)``.

    Args:
        logits: ``(batch, C)`` scores; C is assumed to equal ``num_classes``.
        num_classes: number of classes used for normalization.

    Returns:
        ``(batch,)`` tensor of scores in [0, 1].

    Raises:
        ValueError: if ``num_classes < 2`` (the normalization is undefined).
    """
    if num_classes < 2:
        raise ValueError(f"PPDRQ needs at least 2 classes, got num_classes={num_classes}")
    probs = torch.softmax(logits, dim=1)
    # |p_i - p_j| for every ordered pair; the full matrix counts each unordered pair twice.
    pairwise = (probs.unsqueeze(2) - probs.unsqueeze(1)).abs()
    sumdiff = pairwise.sum(dim=(1, 2)) / 2
    pp = sumdiff / (num_classes - 1)
    return torch.clamp(pp, 0, 1)

class CombinedLoss(nn.Module):
    """PPDRQ-weighted cross-entropy plus an optional PPDRQ-aware triplet term.

    ``total = mean(CE_i * (1 + (1 - pp_i))) + lambda_triplet * triplet``

    For each anchor, the triplet term draws a random positive (same class, |pp diff| <
    ``epsilon_p``) and a random negative (other class, |pp diff| >= ``epsilon_n``) from a
    caller-supplied bank of embeddings.

    Args:
        num_classes: number of classes.
        lambda_triplet: weight of the triplet term.
        margin: triplet margin (L2 distance).
        epsilon_p, epsilon_n: PPDRQ closeness thresholds for positives / negatives.
        smoothing: label smoothing for the cross-entropy.
    """

    def __init__(self, num_classes, lambda_triplet=Config.lambda_triplet, margin=Config.triplet_margin,
                 epsilon_p=Config.epsilon_p, epsilon_n=Config.epsilon_n, smoothing=Config.label_smoothing):
        super().__init__()
        self.num_classes = num_classes
        self.lambda_triplet = lambda_triplet
        self.epsilon_p = epsilon_p
        self.epsilon_n = epsilon_n
        self.triplet = nn.TripletMarginLoss(margin=margin, p=2)
        self.smoothing = smoothing

    def forward(self, logits, embeddings, labels,
                all_feats=None, all_labels=None, all_pp=None):
        """Return ``(total, weighted_ce, triplet_loss, pp)``.

        ``all_feats``/``all_labels``/``all_pp`` are lists of tensors forming the mining bank; they
        are concatenated along dim 0. Without a bank, the triplet loss is 0.
        """
        pp = compute_ppdrq_from_logits(logits, self.num_classes)
        per_sample_ce = F.cross_entropy(logits, labels, reduction='none', label_smoothing=self.smoothing)
        # Low-PPDRQ (uncertain) samples get up to 2x weight.
        # NOTE(review): pp is not detached, so gradients also flow through the weights.
        weights = 1 + (1 - pp)
        ce_loss = torch.mean(per_sample_ce * weights)

        triplet_loss = torch.tensor(0.0, device=logits.device)
        if all_feats is not None and len(all_feats) > 0:
            allF = torch.cat(all_feats, dim=0).to(embeddings.device)
            allY = torch.cat(all_labels, dim=0).to(embeddings.device)
            allP = torch.cat(all_pp, dim=0).to(embeddings.device)
            count = 0
            for i in range(labels.size(0)):
                a = embeddings[i]
                y = labels[i]
                ppa = pp[i]
                # Use the instance thresholds (previously Config.* silently overrode the
                # constructor arguments).
                pos_idx = torch.where((allY == y) & ((allP - ppa).abs() < self.epsilon_p))[0]
                neg_idx = torch.where((allY != y) & ((allP - ppa).abs() >= self.epsilon_n))[0]
                # Require >1 positive candidate, presumably to allow one other than the anchor's
                # own bank entry (random.choice may still pick it).
                if pos_idx.numel() > 1 and neg_idx.numel() > 0:
                    p = allF[random.choice(pos_idx.tolist())]
                    n = allF[random.choice(neg_idx.tolist())]
                    triplet_loss = triplet_loss + self.triplet(a.unsqueeze(0), p.unsqueeze(0), n.unsqueeze(0))
                    count += 1
            if count > 0:
                triplet_loss = triplet_loss / count
        total = ce_loss + self.lambda_triplet * triplet_loss
        return total, ce_loss, triplet_loss, pp

# -------------------------------
# 5. Mixup/CutMix utils
# -------------------------------

def rand_bbox(W, H, lam):
    """Sample a CutMix box for a ``W x H`` image.

    The box side lengths are ``int(W * sqrt(1 - lam))`` and ``int(H * sqrt(1 - lam))``. It is
    centred on a uniformly random pixel (drawn with ``np.random``) and clipped to the image.

    Returns:
        ``(x1, y1, x2, y2)`` corner coordinates.
    """
    ratio = math.sqrt(1. - lam)
    half_w = int(W * ratio) // 2
    half_h = int(H * ratio) // 2
    center_x, center_y = np.random.randint(W), np.random.randint(H)
    x1, x2 = (np.clip(center_x + offset, 0, W) for offset in (-half_w, half_w))
    y1, y2 = (np.clip(center_y + offset, 0, H) for offset in (-half_h, half_h))
    return x1, y1, x2, y2


def apply_mixup_cutmix(x, y):
    """Apply MixUp or CutMix (50/50) to a batch when ``Config.use_mixup_cutmix`` is set.

    Every call mixes when the toggle is on. A caller that wants some clean batches must gate the
    call itself.

    Returns:
        ``(x, y, None)`` when disabled. Otherwise ``(mixed_x, (y_a, y_b, lam), kind)``, where
        ``kind`` is ``'mixup'`` or ``'cutmix'``, ``y_b`` is ``y`` under the random permutation
        used for mixing, and ``lam`` is the weight of ``y_a``. For CutMix, ``lam`` is recomputed
        from the actual pasted area.
    """
    if not Config.use_mixup_cutmix:
        return x, y, None
    use_mixup = random.random() < 0.5
    alpha = Config.mixup_alpha if use_mixup else Config.cutmix_alpha
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0)).to(x.device)
    if use_mixup:
        return lam * x + (1 - lam) * x[perm], (y, y[perm], lam), 'mixup'

    bx1, by1, bx2, by2 = rand_bbox(x.size(3), x.size(2), lam)
    mixed = x.clone()
    mixed[:, :, by1:by2, bx1:bx2] = x[perm, :, by1:by2, bx1:bx2]
    pasted_fraction = (bx2 - bx1) * (by2 - by1) / (x.size(-1) * x.size(-2))
    return mixed, (y, y[perm], 1 - pasted_fraction), 'cutmix'


def mix_criterion(logits, target_tuple):
    """Label-smoothed cross-entropy for a mixed batch: ``lam * CE(y_a) + (1 - lam) * CE(y_b)``.

    Args:
        logits: model outputs for the mixed batch.
        target_tuple: ``(y_a, y_b, lam)`` as returned by ``apply_mixup_cutmix``.
    """
    targets_a, targets_b, weight = target_tuple

    def _smoothed_ce(targets):
        return F.cross_entropy(logits, targets, label_smoothing=Config.label_smoothing)

    return weight * _smoothed_ce(targets_a) + (1 - weight) * _smoothed_ce(targets_b)

# -------------------------------
# 6. Train/Eval
# -------------------------------

def _collect_triplet_bank(model, loader):
    """Embed every batch of ``loader`` without gradients; return ``(embeddings, labels, ppdrq)`` on CPU.

    The model's mode is left untouched. ``train_model`` calls this right after ``model.train()``,
    so dropout is active and BatchNorm running statistics are also updated by this pass.
    """
    feats, labels, pps = [], [], []
    with torch.no_grad():
        for xb, yb, _, _ in loader:
            xb = xb.to(Config.device); yb = torch.as_tensor(yb, device=Config.device)
            logits, emb, _ = model(xb)
            pp = compute_ppdrq_from_logits(logits, Config.num_classes)
            feats.append(emb.detach().cpu()); labels.append(yb.detach().cpu()); pps.append(pp.detach().cpu())
    if not feats:
        return torch.tensor([]), torch.tensor([]), torch.tensor([])
    return torch.cat(feats, 0), torch.cat(labels, 0), torch.cat(pps, 0)


def _evaluate_epoch(model, loader, criterion):
    """Return ``(mean loss, accuracy, mean PPDRQ)`` of ``model`` on ``loader`` (eval mode, no grad)."""
    model.eval()
    v_loss = 0.0; v_cor = 0; v_tot = 0; v_pp = 0.0
    with torch.no_grad():
        for xb, yb, _, _ in loader:
            xb = xb.to(Config.device); yb = torch.as_tensor(yb, device=Config.device)
            logits, emb, _ = model(xb)
            total, _ce, _trip, pp = criterion(logits, emb, yb)
            pred = torch.softmax(logits, dim=1).argmax(1)
            v_cor += (pred == yb).sum().item()
            v_tot += yb.size(0)
            v_loss += total.item() * xb.size(0)
            v_pp += pp.sum().item()
    denom = max(v_tot, 1)
    return v_loss / denom, v_cor / denom, v_pp / denom


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, model_name,
                num_epochs=Config.epochs, mix_prob=0.5):
    """Train ``model`` with per-epoch validation, LR scheduling, checkpointing and early stopping.

    Each batch is mixed (MixUp/CutMix) with probability ``mix_prob`` and trained with
    ``mix_criterion``. All other batches use the model-specific ``criterion`` (PPDRQ weighting
    and triplet loss). Previously, every batch was mixed whenever ``Config.use_mixup_cutmix``
    was on, so ``criterion`` never shaped training. Pass ``mix_prob=1.0`` to reproduce that.

    The best validation-loss weights are saved to ``<output_dir>/<model_name>_best.pth``. The
    returned model keeps its last-epoch weights.

    Args:
        model, train_loader, val_loader, criterion, optimizer, scheduler: training components.
            ``scheduler.step`` is called with the validation loss.
        model_name: tag used in logs and the checkpoint file name.
        num_epochs: maximum number of epochs.
        mix_prob: probability that a training batch goes through MixUp/CutMix.

    Returns:
        ``(history, duration_seconds)``, where ``history`` holds per-epoch train/val loss,
        accuracy and mean PPDRQ plus ``epochs_trained``.
    """
    best_val = float('inf')
    epochs_no_improve = 0
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'train_ppdrq': [], 'val_ppdrq': [], 'epochs_trained': 0}
    start = time.time()
    # Only the triplet term reads the bank; skip the extra full pass when the criterion
    # explicitly disables it. Criteria without the attribute still get a bank.
    needs_bank = getattr(criterion, 'lambda_triplet', None) != 0

    print(f"\n--- Training {model_name} ---")
    for epoch in range(num_epochs):
        model.train()
        t_loss = 0.0; t_correct = 0.0; t_total = 0; t_pp = 0.0

        bank_kwargs = {}
        if needs_bank:
            allF, allY, allP = _collect_triplet_bank(model, train_loader)
            # Keywords, not positionals, so criteria taking only **kwargs also work.
            bank_kwargs = dict(all_feats=[allF], all_labels=[allY], all_pp=[allP])

        for xb, yb, _, _ in train_loader:
            xb = xb.to(Config.device); yb = torch.as_tensor(yb, device=Config.device)
            if random.random() < mix_prob:
                xb_m, yb_m, aug = apply_mixup_cutmix(xb, yb)
            else:
                xb_m, yb_m, aug = xb, yb, None

            optimizer.zero_grad(set_to_none=True)
            logits, emb, _ = model(xb_m)

            if aug is None:
                total, _ce, _trip, pp = criterion(logits, emb, yb, **bank_kwargs)
            else:
                total = mix_criterion(logits, yb_m)
                pp = compute_ppdrq_from_logits(logits, Config.num_classes)

            total.backward()
            if Config.grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), Config.grad_clip)
            optimizer.step()

            with torch.no_grad():
                pred = torch.softmax(logits, dim=1).argmax(1)
                if aug is None:
                    t_correct += (pred == yb).sum().item()
                else:
                    # Standard mixup accuracy: credit both targets by their mixing weights.
                    y_a, y_b, lam = yb_m
                    t_correct += lam * (pred == y_a).sum().item() + (1 - lam) * (pred == y_b).sum().item()
                t_total += yb.size(0)
                t_loss += total.item() * xb.size(0)
                t_pp += pp.sum().item()

        tr_loss = t_loss / max(t_total, 1)
        tr_acc = t_correct / max(t_total, 1)
        tr_pp = t_pp / max(t_total, 1)

        va_loss, va_acc, va_pp = _evaluate_epoch(model, val_loader, criterion)

        history['train_loss'].append(tr_loss)
        history['val_loss'].append(va_loss)
        history['train_acc'].append(tr_acc)
        history['val_acc'].append(va_acc)
        history['train_ppdrq'].append(tr_pp)
        history['val_ppdrq'].append(va_pp)
        history['epochs_trained'] = epoch + 1

        if (epoch + 1) == 1 or (epoch + 1) % 5 == 0 or (epoch + 1) == num_epochs:
            print(f"Epoch {epoch+1:03d}/{num_epochs} | Train: loss {tr_loss:.4f} acc {tr_acc:.4f} pp {tr_pp:.3f} | Val: loss {va_loss:.4f} acc {va_acc:.4f} pp {va_pp:.3f}")

        scheduler.step(va_loss)
        if va_loss < best_val:
            best_val = va_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), os.path.join(Config.output_dir, f"{model_name}_best.pth"))
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= Config.patience_early_stopping:
                print(f"Early stopping at epoch {epoch+1}")
                break

    dur = time.time() - start
    print(f"Training time for {model_name}: {dur:.1f}s")
    return history, dur


def calculate_calibration_metrics(probabilities: torch.Tensor, labels: torch.Tensor, num_bins=10, num_classes=None):
    """Compute Expected Calibration Error (ECE) and the multi-class Brier score.

    ECE splits the top-class confidence range into ``num_bins`` equal-width bins.
    Bins are open on the left and closed on the right, ``(lo, hi]``; the first
    bin also keeps confidences of exactly 0, so every sample lands in exactly
    one bin. ECE is the average of ``|accuracy - mean confidence|`` per bin,
    weighted by the fraction of samples in that bin. The Brier score is the
    mean, over samples, of the squared L2 distance between the probability row
    and the one-hot label.

    Args:
        probabilities: ``(N, C)`` tensor of class probabilities.
        labels: ``(N,)`` tensor of integer class indices.
        num_bins: number of equal-width confidence bins used for ECE.
        num_classes: width of the one-hot encoding. Defaults to
            ``probabilities.size(1)``.

    Returns:
        ``(ece, brier)`` as Python floats. Both are ``nan`` for empty input.
    """
    if probabilities.numel() == 0:
        return float('nan'), float('nan')
    if num_classes is None:
        num_classes = probabilities.size(1)
    labels = labels.to(device=probabilities.device, dtype=torch.long)

    conf, pred = probabilities.max(1)
    acc = (pred == labels).to(conf.dtype)
    edges = torch.linspace(0, 1, num_bins + 1, device=conf.device)

    # Accumulate as a Python float so the result is well-defined even when no
    # bin is populated (e.g. NaN confidences).
    ece = 0.0
    for i in range(num_bins):
        lower = conf >= edges[i] if i == 0 else conf > edges[i]
        in_bin = lower & (conf <= edges[i + 1])
        if in_bin.any():
            gap = (acc[in_bin].mean() - conf[in_bin].mean()).abs()
            ece += gap.item() * in_bin.float().mean().item()

    one_hot = F.one_hot(labels, num_classes=num_classes).to(probabilities.dtype)
    brier = ((probabilities - one_hot) ** 2).sum(dim=1).mean()
    return float(ece), brier.item()


def plot_reliability_diagram(conf, correct, ece, model_name, num_bins=10):
    """Save a reliability diagram (per-bin accuracy vs. confidence) for one model.

    Args:
        conf: 1-D tensor/array of top-class confidences.
        correct: 1-D boolean tensor/array, True where the prediction was right.
        ece: ECE value shown in the title.
        model_name: used in the title and in the file name
            ``{model_name}_reliability.png`` under ``Config.output_dir``.
        num_bins: number of equal-width confidence bins.

    Bins are ``(lo, hi]`` on float32 ``torch.linspace`` edges, as in
    calculate_calibration_metrics. A confidence lying exactly on an interior
    edge is therefore counted in one bin only; closed-on-both-ends bins would
    count it twice. The first bin also keeps confidences of exactly 0. Empty
    bins are drawn with height 0 and no count label.
    """
    conf = torch.as_tensor(conf).detach().cpu()
    correct = torch.as_tensor(correct).detach().cpu()
    edges = torch.linspace(0, 1, num_bins + 1)
    mids = ((edges[:-1] + edges[1:]) / 2).numpy()

    bin_acc = []
    counts = []
    for i in range(num_bins):
        lower = conf >= edges[i] if i == 0 else conf > edges[i]
        in_bin = lower & (conf <= edges[i + 1])
        n_in = int(in_bin.sum())
        counts.append(n_in)
        bin_acc.append(correct[in_bin].float().mean().item() if n_in else 0.0)

    plt.figure(figsize=(6, 6))
    plt.plot([0, 1], [0, 1], 'k:')
    plt.bar(mids, bin_acc, width=1 / num_bins * 0.9, alpha=0.7, edgecolor='black')
    for i, c in enumerate(counts):
        if c > 0:
            plt.text(mids[i], bin_acc[i] + 0.02, str(c), ha='center', fontsize=8)
    plt.title(f"Reliability: {model_name} (ECE={ece:.3f})")
    plt.xlabel('Confidence'); plt.ylabel('Accuracy'); plt.ylim(0, 1)
    plt.grid(True)
    plt.savefig(os.path.join(Config.output_dir, f"{model_name}_reliability.png"))
    plt.close()


def evaluate_model(model, loader, model_name):
    """Evaluate ``model`` with a single deterministic pass over ``loader``.

    Computes accuracy, macro precision/recall/F1, mean PPDRQ (from the
    log-probabilities), ECE and Brier score. Saves a reliability diagram
    through plot_reliability_diagram and prints a one-line summary.
    ``inference_time`` covers only the forward-pass loop. An empty loader
    short-circuits with every metric set to 0.0.
    """
    metric_names = ['accuracy', 'precision', 'recall', 'f1_score',
                    'mean_ppdrq', 'ece', 'brier_score', 'inference_time']
    model.eval()
    prob_chunks = []
    label_chunks = []
    start = time.time()
    if len(loader.dataset) == 0:
        print(f"[WARN] {model_name} loader empty.")
        return dict.fromkeys(metric_names, 0.0)

    with torch.no_grad():
        for xb, yb, _, _ in tqdm(loader, desc=f"Evaluating {model_name}"):
            xb = xb.to(Config.device)
            yb = torch.as_tensor(yb, device=Config.device)
            logits, _, _ = model(xb)
            prob_chunks.append(torch.softmax(logits, dim=1).cpu())
            label_chunks.append(yb.cpu())
    elapsed = time.time() - start

    prob = torch.cat(prob_chunks, 0)
    y = torch.cat(label_chunks, 0)
    pred = prob.argmax(1)
    y_np, pred_np = y.numpy(), pred.numpy()

    macro = dict(average='macro', zero_division=0)
    acc = accuracy_score(y_np, pred_np)
    prec = precision_score(y_np, pred_np, **macro)
    rec = recall_score(y_np, pred_np, **macro)
    f1 = f1_score(y_np, pred_np, **macro)

    mean_pp = compute_ppdrq_from_logits(torch.log(prob + 1e-8), Config.num_classes).mean().item()
    ece, brier = calculate_calibration_metrics(prob, y)
    plot_reliability_diagram(prob.max(1)[0], (pred == y), ece, model_name)
    print(f"\n[{model_name}] Acc {acc:.4f} Prec {prec:.4f} Rec {rec:.4f} F1 {f1:.4f} PPDRQ {mean_pp:.4f} ECE {ece:.4f} Brier {brier:.4f}")

    return {
        'accuracy': acc, 'precision': prec, 'recall': rec, 'f1_score': f1,
        'mean_ppdrq': mean_pp, 'ece': ece, 'brier_score': brier,
        'inference_time': elapsed,
    }

# -------------------------------
# NEW: Statistical Significance Testing
# -------------------------------

def bootstrap_confidence_interval(y_true, y_pred, metric_fn, n_bootstrap=None, confidence_level=None, random_state=None):
    """Estimate a percentile-bootstrap confidence interval for a metric.

    Resamples the (y_true, y_pred) pairs with replacement ``n_bootstrap``
    times, evaluates ``metric_fn(y_true_sample, y_pred_sample)`` on each
    resample, and returns the mean and the two-sided percentile bounds.

    Args:
        y_true, y_pred: equal-length array-likes (lists are accepted).
        metric_fn: callable ``(y_true, y_pred) -> float``.
        n_bootstrap: number of resamples. Looked up at call time from
            ``Config.bootstrap_iterations`` (fallback 1000) when None. It was
            previously frozen at import time.
        confidence_level: e.g. 0.95. Looked up from
            ``Config.confidence_level`` (fallback 0.95) when None.
        random_state: optional int seed for reproducible resampling. None uses
            the global NumPy RNG, as before.

    Returns:
        ``(mean_score, ci_lower, ci_upper)``. All are ``nan`` for empty input
        or ``n_bootstrap <= 0``.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` differ in length.
    """
    if n_bootstrap is None:
        n_bootstrap = getattr(Config, 'bootstrap_iterations', 1000)
    if confidence_level is None:
        confidence_level = getattr(Config, 'confidence_level', 0.95)

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if len(y_true) != len(y_pred):
        raise ValueError(f"y_true and y_pred lengths differ: {len(y_true)} vs {len(y_pred)}")
    n_samples = len(y_true)
    if n_samples == 0 or n_bootstrap <= 0:
        nan = float('nan')
        return nan, nan, nan

    sampler = np.random if random_state is None else np.random.RandomState(random_state)
    scores = []
    for _ in range(n_bootstrap):
        indices = sampler.choice(n_samples, n_samples, replace=True)
        scores.append(metric_fn(y_true[indices], y_pred[indices]))

    alpha = 1 - confidence_level
    ci_lower = np.percentile(scores, alpha / 2 * 100)
    ci_upper = np.percentile(scores, (1 - alpha / 2) * 100)
    return float(np.mean(scores)), float(ci_lower), float(ci_upper)


def mcnemar_test(y_true, pred_model1, pred_model2):
    """McNemar's test (chi-square with continuity correction) for paired predictions.

    Counts the discordant pairs: ``n01`` (model 1 right, model 2 wrong) and
    ``n10`` (model 1 wrong, model 2 right). The test statistic is
    ``(|n01 - n10| - 1)^2 / (n01 + n10)``. It is compared against a
    chi-square distribution with one degree of freedom, whose survival
    function is exactly ``erfc(sqrt(s / 2))``. Using ``erfc`` avoids a scipy
    dependency and the precision loss of ``1 - cdf`` for small p-values.

    Args:
        y_true, pred_model1, pred_model2: equal-shape array-likes of class
            labels (lists are accepted).

    Returns:
        float p-value. Returns 1.0 when there are no discordant pairs.

    Raises:
        ValueError: if the three inputs differ in shape.
    """
    y_true = np.asarray(y_true)
    pred_model1 = np.asarray(pred_model1)
    pred_model2 = np.asarray(pred_model2)
    if not (y_true.shape == pred_model1.shape == pred_model2.shape):
        raise ValueError(
            f"shape mismatch: y_true {y_true.shape}, pred_model1 {pred_model1.shape}, pred_model2 {pred_model2.shape}"
        )

    correct1 = pred_model1 == y_true
    correct2 = pred_model2 == y_true

    # Discordant cells of the 2x2 contingency table
    n01 = int(np.sum(correct1 & ~correct2))  # Model1 correct, Model2 wrong
    n10 = int(np.sum(~correct1 & correct2))  # Model1 wrong, Model2 correct
    discordant = n01 + n10
    if discordant == 0:
        return 1.0  # No difference

    statistic = (abs(n01 - n10) - 1) ** 2 / discordant
    return math.erfc(math.sqrt(statistic / 2.0))


def compute_statistical_comparisons(models_predictions: Dict[str, Tuple], baseline_name='Baseline'):
    """
    Compare all models against baseline with statistical tests.

    Each model gets bootstrap confidence intervals for accuracy and macro
    F1, precision and recall. Each non-baseline model is also compared with
    the baseline using McNemar's test (significant when p < 0.05).

    McNemar's test is paired, so it only runs when the model's ``y_true``
    equals the baseline's element-wise. This is a necessary condition for both
    prediction arrays to refer to the same samples in the same order.
    Otherwise a warning is printed and NaN is stored.

    Results are saved to ``statistical_significance_results.csv`` in
    ``Config.output_dir``.

    Args:
        models_predictions: Dict[model_name] = (y_true, y_pred, probabilities)
        baseline_name: Name of baseline model (a key of models_predictions)

    Returns:
        pandas DataFrame with one row per model.

    Raises:
        KeyError: if ``baseline_name`` is not in ``models_predictions``.
    """
    print("\n" + "="*80)
    print("STATISTICAL SIGNIFICANCE TESTING")
    print("="*80)

    if baseline_name not in models_predictions:
        raise KeyError(f"Baseline '{baseline_name}' not found; available models: {list(models_predictions)}")
    baseline_y_true, baseline_y_pred, _ = models_predictions[baseline_name]
    baseline_y_true = np.asarray(baseline_y_true)
    baseline_y_pred = np.asarray(baseline_y_pred)

    n_bootstrap = getattr(Config, 'bootstrap_iterations', 1000)
    # Label matches the confidence level actually used by the bootstrap (e.g. "95% CI").
    ci_label = f"{getattr(Config, 'confidence_level', 0.95) * 100:g}% CI"

    # Metrics to test
    metrics = {
        'Accuracy': accuracy_score,
        'F1-Score': lambda yt, yp: f1_score(yt, yp, average='macro', zero_division=0),
        'Precision': lambda yt, yp: precision_score(yt, yp, average='macro', zero_division=0),
        'Recall': lambda yt, yp: recall_score(yt, yp, average='macro', zero_division=0)
    }

    results = []
    for model_name, (y_true, y_pred, _probs) in models_predictions.items():
        print(f"\n--- {model_name} vs {baseline_name} ---")
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        model_results = {'Model': model_name}

        # Bootstrap CI for each metric
        for metric_name, metric_fn in metrics.items():
            mean_score, ci_lower, ci_upper = bootstrap_confidence_interval(
                y_true, y_pred, metric_fn, n_bootstrap=n_bootstrap
            )
            print(f"{metric_name}: {mean_score:.4f} ({ci_label}: [{ci_lower:.4f}, {ci_upper:.4f}])")
            model_results[f'{metric_name}_Mean'] = mean_score
            model_results[f'{metric_name}_CI_Lower'] = ci_lower
            model_results[f'{metric_name}_CI_Upper'] = ci_upper

        # McNemar's test against baseline (paired: requires aligned samples)
        if model_name == baseline_name:
            p_value, is_significant = np.nan, np.nan
        elif y_true.shape != baseline_y_true.shape or not np.array_equal(y_true, baseline_y_true):
            print(f"\n[WARN] {model_name}: labels are not aligned with {baseline_name}'s; "
                  "skipping McNemar's test (it needs paired predictions on the same samples).")
            p_value, is_significant = np.nan, np.nan
        else:
            p_value = mcnemar_test(y_true, y_pred, baseline_y_pred)
            is_significant = p_value < 0.05
            print(f"\nMcNemar's Test p-value: {p_value:.4f} {'(Significant)' if is_significant else '(Not Significant)'}")
        model_results['McNemar_p_value'] = p_value
        model_results['Significant_vs_Baseline'] = is_significant

        results.append(model_results)

    # Save results
    df = pd.DataFrame(results)
    csv_path = os.path.join(Config.output_dir, 'statistical_significance_results.csv')
    df.to_csv(csv_path, index=False)
    print(f"\n✓ Saved statistical results -> {csv_path}")
    print("="*80 + "\n")

    return df


def plot_confidence_intervals(stat_df: pd.DataFrame, metric='Accuracy'):
    """Plot each model's bootstrap mean for ``metric`` with its confidence interval.

    Reads the ``{metric}_Mean``, ``{metric}_CI_Lower`` and ``{metric}_CI_Upper``
    columns produced by compute_statistical_comparisons. Saves
    ``{metric.lower()}_confidence_intervals.png`` in ``Config.output_dir``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    models = stat_df['Model'].values
    means = stat_df[f'{metric}_Mean'].to_numpy(dtype=float)
    ci_lowers = stat_df[f'{metric}_CI_Lower'].to_numpy(dtype=float)
    ci_uppers = stat_df[f'{metric}_CI_Upper'].to_numpy(dtype=float)

    y_pos = np.arange(len(models))
    # errorbar rejects negative error lengths. The bootstrap *mean* is not
    # guaranteed to lie inside its percentile interval (skewed resamples),
    # so clip at 0.
    errors = np.clip(np.vstack([means - ci_lowers, ci_uppers - means]), 0.0, None)

    ax.errorbar(means, y_pos, xerr=errors, fmt='o', markersize=8, capsize=5, capthick=2)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(models)
    ax.set_xlabel(f'{metric} Score')
    level = getattr(Config, 'confidence_level', 0.95)
    ax.set_title(f'{metric} with {level * 100:g}% Confidence Intervals')
    ax.grid(True, axis='x', alpha=0.3)

    fig.tight_layout()
    out_path = os.path.join(Config.output_dir, f'{metric.lower()}_confidence_intervals.png')
    fig.savefig(out_path)
    plt.close(fig)
    print(f"Saved confidence interval plot -> {out_path}")


# -------------------------------
# 7. Layer-wise PPDRQ
# -------------------------------

def get_and_plot_layerwise_ppdrq(model, loader, model_name, title_suffix="", probe_seed=0):
    """Estimate and plot a normalised mean PPDRQ for several intermediate layers.

    The model's forward pass returns ``(logits, emb, layer_dict)``. Features
    from ``layer_dict`` for the keys in ``layer_keys`` are collected over
    ``loader``. ``base_model_output`` is averaged over dims 2 and 3.
    ``final_logits`` is scored directly. Every other feature is mapped to
    ``Config.num_classes`` scores by a freshly initialised linear probe whose
    input width is inferred from the data. Each layer's mean PPDRQ is
    normalised by the maximum across layers. The bar chart is saved as
    ``{model_name}_layer_ppdrq.png`` in ``Config.output_dir``.

    NOTE(review): the probes are untrained random projections. Treat the
    per-layer values as a rough diagnostic, not a calibrated layer metric.

    Args:
        probe_seed: seed for probe initialisation, so the plot is
            reproducible. A forked RNG is used, so the global RNG stream used
            by training and augmentation is not advanced.

    Returns:
        dict of raw (unnormalised) mean PPDRQ per layer, or None for an empty
        loader.
    """
    layer_keys = ['base_model_output', 'hidden_layer1', 'hidden_layer2', 'flatten', 'final_logits']
    model.eval()
    collect = {k: [] for k in layer_keys}
    if len(loader.dataset) == 0:
        print("[WARN] No data for layer-wise PPDRQ.")
        return

    with torch.no_grad():
        for xb, _, _, _ in loader:
            _, _, layer_outputs = model(xb.to(Config.device))
            for k, v in layer_outputs.items():
                if k not in collect:
                    continue  # ignore extra intermediate outputs instead of raising KeyError
                if k == 'base_model_output':
                    # average over dims 2,3 — presumably spatial (N, C, H, W); TODO confirm
                    v = v.mean(dim=(2, 3))
                collect[k].append(v.float().cpu())

        mean_vals = {}
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(probe_seed)
            for k in layer_keys:
                chunks = collect[k]
                if not chunks:
                    mean_vals[k] = 0.0
                    continue
                X = torch.cat(chunks, 0)
                if k == 'final_logits':
                    scores = X
                else:
                    scores = nn.Linear(X.size(1), Config.num_classes)(X)
                mean_vals[k] = compute_ppdrq_from_logits(scores, Config.num_classes).mean().item()

    peak = max(mean_vals.values())
    norm = {k: (v / peak if peak > 0 else 0.0) for k, v in mean_vals.items()}
    layers = list(norm)
    vals = [norm[k] for k in layers]

    plt.figure(figsize=(9, 5))
    plt.bar(layers, vals)
    for i, v in enumerate(vals):
        plt.text(i, v + 0.02, f"{v:.2f}", ha='center')
    # headroom so the label above the tallest (=1.0) bar stays inside the axes
    plt.ylim(0, 1.1); plt.ylabel('Normalized Mean PPDRQ'); plt.title(f"Layer PPDRQ {title_suffix}")
    plt.savefig(os.path.join(Config.output_dir, f"{model_name}_layer_ppdrq.png"))
    plt.close()
    return mean_vals

# -------------------------------
# 8. FIXED MC-Dropout Implementation
# -------------------------------

def enable_dropout(model):
    """Switch every dropout layer in ``model`` to training mode (for MC-Dropout).

    Call this after ``model.eval()``. BatchNorm and all other layers keep
    their current mode; only dropout modules become stochastic. Matches
    ``nn.Dropout`` and also the 1d/2d/3d, AlphaDropout and
    FeatureAlphaDropout variants that torch provides (each looked up with
    ``getattr``, so older torch versions without e.g. ``Dropout1d`` still work).

    Returns:
        int: number of dropout modules switched on. 0 means MC-Dropout passes
        would be deterministic.
    """
    dropout_names = ('Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout')
    dropout_types = tuple(cls for cls in (getattr(nn, n, None) for n in dropout_names) if cls is not None)
    n_enabled = 0
    for module in model.modules():
        if isinstance(module, dropout_types):
            module.train()
            n_enabled += 1
    return n_enabled


def mc_dropout_predict(model, loader, num_passes=Config.mc_dropout_passes):
    """Evaluate ``model`` with Monte-Carlo Dropout.

    The model is put in eval mode, so BatchNorm uses running statistics.
    Only its dropout layers are then switched back to training mode via
    ``enable_dropout``. For every batch, ``num_passes`` stochastic forward
    passes are run and their softmax outputs averaged. Metrics, a
    reliability diagram ('MC-Dropout') and mean PPDRQ are computed on the
    averaged probabilities. The model is always returned to plain eval mode,
    even if a forward pass raises.

    Returns:
        dict with the same keys as evaluate_model. All values are 0.0 for an
        empty loader, consistent with evaluate_model.

    Raises:
        ValueError: if ``num_passes`` < 1.
    """
    print(f"\n[MC-Dropout] Running {num_passes} stochastic forward passes...")
    metric_keys = ['accuracy', 'precision', 'recall', 'f1_score',
                   'mean_ppdrq', 'ece', 'brier_score', 'inference_time']
    if num_passes < 1:
        raise ValueError(f"num_passes must be >= 1, got {num_passes}")
    if len(loader.dataset) == 0:
        print("[WARN] MC-Dropout loader empty.")
        return dict.fromkeys(metric_keys, 0.0)

    # Eval mode first (BatchNorm frozen), then re-enable dropout only
    model.eval()
    enable_dropout(model)

    all_prob = []
    all_y = []
    t0 = time.time()
    try:
        with torch.no_grad():
            for xb, yb, _, _ in tqdm(loader, desc=f"MC-Dropout {num_passes} passes"):
                xb = xb.to(Config.device)
                pass_probs = []
                for _ in range(num_passes):
                    logits, _, _ = model(xb)
                    pass_probs.append(torch.softmax(logits, dim=1))
                # Average predictions across stochastic passes
                all_prob.append(torch.stack(pass_probs).mean(0).cpu())
                all_y.append(torch.as_tensor(yb).cpu())
        t1 = time.time()
    finally:
        # Never leave dropout active for later (deterministic) evaluations
        model.eval()

    prob = torch.cat(all_prob, 0)
    y = torch.cat(all_y, 0)
    pred = prob.argmax(1)

    acc = accuracy_score(y.numpy(), pred.numpy())
    prec = precision_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    rec = recall_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    f1 = f1_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    ece, brier = calculate_calibration_metrics(prob, y)
    plot_reliability_diagram(prob.max(1)[0], (pred == y), ece, 'MC-Dropout')
    pp = compute_ppdrq_from_logits(torch.log(prob + 1e-8), Config.num_classes)

    print(f"\n[MC-Dropout] Acc {acc:.4f} Prec {prec:.4f} Rec {rec:.4f} F1 {f1:.4f} PPDRQ {pp.mean().item():.4f} ECE {ece:.4f} Brier {brier:.4f}")

    return {
        'accuracy': acc, 'precision': prec, 'recall': rec, 'f1_score': f1,
        'mean_ppdrq': pp.mean().item(), 'ece': ece, 'brier_score': brier,
        'inference_time': t1 - t0
    }


def train_ensemble_member(i, train_loader, val_loader):
    """Train the ``i``-th (0-based) deep-ensemble member and save its weights.

    Each member is a fresh CustomInceptionV3 trained with CombinedLoss
    (``lambda_triplet=0.0``), AdamW and ReduceLROnPlateau for
    ``max(10, Config.epochs // Config.num_ensemble_models)`` epochs.

    train_model checkpoints the lowest-validation-loss weights to
    ``Ensemble_{i+1}_best.pth`` but returns with the last-epoch weights still
    loaded. Those best weights are reloaded here, so the saved and returned
    member matches its best validation checkpoint.

    Returns:
        (model, training_time_seconds, history)
    """
    member_name = f"Ensemble_{i+1}"
    print(f"\n--- Training Ensemble Member {i+1} ---")
    m = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
    crit = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=0.0)
    opt = optim.AdamW(filter(lambda p: p.requires_grad, m.parameters()), lr=Config.learning_rate, weight_decay=Config.weight_decay)
    sch = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.2, patience=Config.patience_lr_scheduler)
    hist, t = train_model(m, train_loader, val_loader, crit, opt, sch, member_name,
                          num_epochs=max(10, Config.epochs // Config.num_ensemble_models))

    best_path = os.path.join(Config.output_dir, f"{member_name}_best.pth")
    if os.path.exists(best_path):
        m.load_state_dict(torch.load(best_path, map_location=Config.device))

    torch.save(m.state_dict(), os.path.join(Config.output_dir, f"ensemble_member_{i+1}.pth"))
    return m, t, hist


def evaluate_deep_ensemble(models_list: List[nn.Module], loader):
    """Evaluate a deep ensemble by averaging member softmax outputs.

    Every member is put in eval mode once. For each batch, the members'
    softmax probabilities are averaged, and metrics, a reliability diagram
    ('Deep-Ensemble') and mean PPDRQ are computed on the averaged
    probabilities.

    Returns:
        dict with the same keys as evaluate_model. All values are 0.0 for an
        empty loader, consistent with evaluate_model.

    Raises:
        ValueError: if ``models_list`` is empty.
    """
    if not models_list:
        raise ValueError("evaluate_deep_ensemble needs at least one model.")
    metric_keys = ['accuracy', 'precision', 'recall', 'f1_score',
                   'mean_ppdrq', 'ece', 'brier_score', 'inference_time']
    if len(loader.dataset) == 0:
        print("[WARN] Ensemble loader empty.")
        return dict.fromkeys(metric_keys, 0.0)

    # Loop-invariant: set eval mode once, not per member per batch
    for m in models_list:
        m.eval()

    all_prob = []
    all_y = []
    t0 = time.time()
    with torch.no_grad():
        for xb, yb, _, _ in tqdm(loader, desc="Evaluating Ensemble"):
            xb = xb.to(Config.device)
            mem_probs = []
            for m in models_list:
                logits, _, _ = m(xb)
                mem_probs.append(torch.softmax(logits, dim=1))
            all_prob.append(torch.stack(mem_probs).mean(0).cpu())
            all_y.append(torch.as_tensor(yb).cpu())
    t1 = time.time()

    prob = torch.cat(all_prob, 0)
    y = torch.cat(all_y, 0)
    pred = prob.argmax(1)
    acc = accuracy_score(y.numpy(), pred.numpy())
    prec = precision_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    rec = recall_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    f1 = f1_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    ece, brier = calculate_calibration_metrics(prob, y)
    plot_reliability_diagram(prob.max(1)[0], (pred == y), ece, 'Deep-Ensemble')
    pp = compute_ppdrq_from_logits(torch.log(prob + 1e-8), Config.num_classes)
    print(f"\n[Ensemble] Acc {acc:.4f} F1 {f1:.4f} PPDRQ {pp.mean().item():.4f} ECE {ece:.4f}")
    return {'accuracy': acc, 'precision': prec, 'recall': rec, 'f1_score': f1,
            'mean_ppdrq': pp.mean().item(), 'ece': ece, 'brier_score': brier,
            'inference_time': t1 - t0}

# -------------------------------
# 9. Paper-Ready Reporting Utilities
# -------------------------------

def save_results_table(results: Dict[str, Dict], times: Dict[str, float], path_csv: str):
    """Write one summary row per model to ``Config.output_dir/path_csv``.

    Training time is taken from ``times[name]``, falling back to
    ``times[f"{name}_Train"]`` and then NaN. Inference time prefers the
    metrics' own ``inference_time``, then ``times[f"{name}_Infer"]``, then NaN.
    Returns the written DataFrame.
    """
    def _row(name, metrics):
        train_fallback = times.get(f"{name}_Train", np.nan)
        infer_fallback = times.get(f"{name}_Infer", np.nan)
        return {
            'Model': name,
            'Accuracy': metrics['accuracy'],
            'Precision': metrics['precision'],
            'Recall': metrics['recall'],
            'F1': metrics['f1_score'],
            'Mean_PPDRQ': metrics['mean_ppdrq'],
            'ECE': metrics['ece'],
            'Brier': metrics['brier_score'],
            'Train_Time_s': times.get(name, train_fallback),
            'Infer_Time_s': metrics.get('inference_time', infer_fallback),
        }

    table = pd.DataFrame([_row(name, metrics) for name, metrics in results.items()])
    destination = os.path.join(Config.output_dir, path_csv)
    table.to_csv(destination, index=False)
    print(f"Saved results table -> {destination}")
    return table


def plot_metric_bars(df: pd.DataFrame, metric: str, fname: str):
    """Save a bar chart of ``metric`` per model, sorted from highest to lowest.

    Bar heights are taken straight from the sorted rows and placed at integer
    positions. A ``Model``-indexed ``.loc`` lookup would return several rows
    per duplicated label and mismatch the number of bars. Categorical x
    values would also stack duplicate labels onto one position.
    """
    ranked = df.sort_values(metric, ascending=False)
    positions = np.arange(len(ranked))
    plt.figure(figsize=(8, 4))
    plt.bar(positions, ranked[metric].to_numpy())
    plt.xticks(positions, ranked['Model'].astype(str).tolist(), rotation=25, ha='right')
    plt.ylabel(metric)
    plt.title(f"Model comparison: {metric}")
    out = os.path.join(Config.output_dir, fname)
    plt.tight_layout(); plt.savefig(out); plt.close()
    print(f"Saved {metric} bar plot -> {out}")


def plot_cost_performance(df: pd.DataFrame, perf_metric: str='Accuracy', cost_metric: str='Infer_Time_s', fname: str='cost_perf.png'):
    """Scatter each model's ``perf_metric`` against its ``cost_metric`` and save the plot.

    Each model is drawn with its own ``scatter`` call, so points take
    successive colours. Each point is labelled with the model name. The label
    is offset by a few screen points (``textcoords='offset points'``), so its
    placement does not depend on the scale of ``perf_metric``.
    """
    plt.figure(figsize=(5, 5))
    for _, row in df.iterrows():
        cost, perf = row[cost_metric], row[perf_metric]
        plt.scatter(cost, perf)
        plt.annotate(str(row['Model']), (cost, perf), xytext=(0, 4),
                     textcoords='offset points', fontsize=8)
    plt.xlabel(cost_metric); plt.ylabel(perf_metric)
    plt.title(f"Performance-Cost Trade-off ({perf_metric} vs {cost_metric})")
    out = os.path.join(Config.output_dir, fname)
    plt.grid(True); plt.tight_layout(); plt.savefig(out); plt.close()
    print(f"Saved cost-performance plot -> {out}")


def lambda_sensitivity_sweep(lambdas=(0.0, 0.01, 0.1, 0.3, 1.0), train_loader=None, val_loader=None, test_loader=None):
    """Re-train with different λ (triplet weight) and plot Accuracy/ECE vs λ.

    For each λ, a fresh CustomInceptionV3 is trained with
    ``CombinedLoss(lambda_triplet=λ)`` for ``max(5, Config.epochs // 6)``
    epochs. It is then evaluated with evaluate_model on ``test_loader``. The
    plot is saved as ``lambda_sensitivity.png`` in ``Config.output_dir``.

    Loaders default to the module-level ``train_loader``, ``val_loader`` and
    ``test_loader`` globals, so zero-argument calls behave as before.

    NOTE(review): if λ is *chosen* from this sweep, pass a validation loader
    as ``test_loader`` to avoid selecting a hyper-parameter on the test set.

    Returns:
        dict with keys 'lambda', 'accuracy', 'ece' (lists aligned with ``lambdas``).
    """
    module_ns = globals()

    def _resolve(name, value):
        # Fall back to the module-level loader of the same name.
        if value is not None:
            return value
        try:
            return module_ns[name]
        except KeyError:
            raise NameError(f"lambda_sensitivity_sweep: pass {name}= or define a module-level {name}") from None

    train_loader = _resolve('train_loader', train_loader)
    val_loader = _resolve('val_loader', val_loader)
    test_loader = _resolve('test_loader', test_loader)

    lambdas = list(lambdas)
    acc_list = []
    ece_list = []
    for lam in lambdas:
        print(f"\n[λ-sweep] Training with λ={lam}")
        model = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
        crit = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=lam)
        opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=Config.learning_rate, weight_decay=Config.weight_decay)
        sch = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.2, patience=Config.patience_lr_scheduler)
        train_model(model, train_loader, val_loader, crit, opt, sch, f'Lambda_{lam}', num_epochs=max(5, Config.epochs // 6))
        res = evaluate_model(model, test_loader, f'Lambda_{lam}')
        acc_list.append(res['accuracy'])
        ece_list.append(res['ece'])

    plt.figure(figsize=(6, 4))
    plt.plot(lambdas, acc_list, marker='o', label='Accuracy')
    plt.plot(lambdas, ece_list, marker='o', label='ECE')
    plt.xlabel('λ (triplet weight)'); plt.legend(); plt.grid(True)
    out = os.path.join(Config.output_dir, 'lambda_sensitivity.png')
    plt.tight_layout(); plt.savefig(out); plt.close()
    print(f"Saved λ sensitivity plot -> {out}")
    return {'lambda': lambdas, 'accuracy': acc_list, 'ece': ece_list}

# -------------------------------
# 9b. Visualization Functions
# -------------------------------
# Display names indexed by integer class label: 0 -> 'AD', 1 -> 'MCI', 2 -> 'NC'.
CLASS_NAMES = ["AD", "MCI", "NC"]

def collect_predictions(model, loader):
    """Run ``model`` once over ``loader`` in eval mode and gather NumPy arrays.

    Returns:
        ``(y_true, y_pred, conf, probs)``: the labels, argmax predictions,
        top-class softmax confidences, and full softmax probability rows.
    """
    model.eval()
    labels, preds, confs, prob_rows = [], [], [], []
    with torch.no_grad():
        for xb, yb, _, _ in loader:
            logits, _, _ = model(xb.to(Config.device))
            softmax = torch.softmax(logits, dim=1)
            labels.append(torch.as_tensor(yb).cpu())
            preds.append(softmax.argmax(1).cpu())
            confs.append(softmax.max(1)[0].cpu())
            prob_rows.append(softmax.cpu())
    return (
        torch.cat(labels).numpy(),
        torch.cat(preds).numpy(),
        torch.cat(confs).numpy(),
        torch.cat(prob_rows, 0).numpy(),
    )


def collect_predictions_mc_dropout(model, loader, num_passes=Config.mc_dropout_passes):
    """Collect predictions using MC-Dropout.

    The model is put in eval mode with only its dropout layers re-enabled
    (``enable_dropout``). Softmax outputs of ``num_passes`` stochastic passes
    are averaged per batch. The model is always returned to plain eval mode,
    even if a pass raises.

    Note: the passes are stochastic, so these predictions are not the same
    draws as the ones scored by mc_dropout_predict.

    Returns:
        ``(y_true, y_pred, conf, probs)`` NumPy arrays, like collect_predictions.

    Raises:
        ValueError: if ``num_passes`` < 1.
    """
    if num_passes < 1:
        raise ValueError(f"num_passes must be >= 1, got {num_passes}")
    y_true = []; y_pred = []; conf = []; probs = []

    model.eval()
    enable_dropout(model)
    try:
        with torch.no_grad():
            for xb, yb, _, _ in loader:
                xb = xb.to(Config.device)
                pass_probs = []
                for _ in range(num_passes):
                    logits, _, _ = model(xb)
                    pass_probs.append(torch.softmax(logits, dim=1))
                p = torch.stack(pass_probs).mean(0)
                y_true.append(torch.as_tensor(yb).cpu())
                y_pred.append(p.argmax(1).cpu())
                conf.append(p.max(1)[0].cpu())
                probs.append(p.cpu())
    finally:
        # Never leave dropout active for later (deterministic) evaluations
        model.eval()

    return torch.cat(y_true).numpy(), torch.cat(y_pred).numpy(), torch.cat(conf).numpy(), torch.cat(probs, 0).numpy()


def plot_confusion(y_true, y_pred, classes, title, fname, normalize=True):
    """Save a confusion-matrix heatmap over labels ``0..len(classes)-1``.

    With ``normalize=True``, each row is divided by its true-class total
    (totals are clipped to at least 1) and cells are annotated with two
    decimals. Otherwise raw integer counts are shown. Cell text is white
    above half the matrix maximum and black otherwise.
    """
    n_classes = len(classes)
    matrix = confusion_matrix(y_true, y_pred, labels=list(range(n_classes)))
    if normalize:
        row_totals = matrix.sum(axis=1, keepdims=True).clip(min=1)
        matrix = matrix.astype('float') / row_totals

    plt.figure(figsize=(5, 4))
    plt.imshow(matrix, interpolation='nearest', cmap='Blues')
    plt.title(title)
    plt.colorbar()
    ticks = np.arange(n_classes)
    plt.xticks(ticks, classes, rotation=45)
    plt.yticks(ticks, classes)

    midpoint = matrix.max() / 2.
    for (row, col), value in np.ndenumerate(matrix):
        label = f"{value:.2f}" if normalize else int(value)
        plt.text(col, row, label,
                 ha="center", va="center",
                 color="white" if value > midpoint else "black")

    plt.ylabel('True')
    plt.xlabel('Predicted')
    plt.tight_layout()
    destination = os.path.join(Config.output_dir, fname)
    plt.savefig(destination)
    plt.close()
    print(f"Saved confusion matrix -> {destination}")


def export_per_class_report(y_true, y_pred, classes, fname_csv):
    """Save sklearn's per-class precision/recall/F1/support report as CSV.

    ``classes[i]`` names integer label ``i``. Passing
    ``labels=range(len(classes))`` pins the report to every class, so a class
    absent from both ``y_true`` and ``y_pred`` (common on small splits) gets
    a zero row. Without it, sklearn raises because ``target_names`` would be
    longer than the observed label set.

    Returns:
        The report as a pandas DataFrame (rows = classes and averages).
    """
    rep = classification_report(
        y_true, y_pred,
        labels=list(range(len(classes))),
        target_names=classes,
        output_dict=True,
        zero_division=0,
    )
    df = pd.DataFrame(rep).transpose()
    out = os.path.join(Config.output_dir, fname_csv)
    df.to_csv(out)
    print(f"Saved per-class report -> {out}")
    return df


def reliability_grid(models_dict, loader, fname='reliability_grid.png', num_bins=10):
    """Draw side-by-side reliability diagrams (one subplot per model) and save them.

    Each model is run once in eval mode over ``loader``. Its subplot title
    shows the ECE from calculate_calibration_metrics. Confidences are binned
    into ``(lo, hi]`` intervals on float32 ``torch.linspace`` edges, as in
    calculate_calibration_metrics. A sample on an interior edge is counted in
    exactly one bin; closed-on-both-ends bins would count it twice. The first
    bin also keeps confidences of exactly 0.
    """
    if not models_dict:
        print("[WARN] reliability_grid: no models given.")
        return
    cols = len(models_dict)
    plt.figure(figsize=(5 * cols, 5))
    edges = torch.linspace(0, 1, num_bins + 1)
    mids = ((edges[:-1] + edges[1:]) / 2).numpy()

    for idx, (name, model) in enumerate(models_dict.items(), start=1):
        model.eval()
        all_prob = []
        all_y = []
        with torch.no_grad():
            for xb, yb, _, _ in loader:
                logits, _, _ = model(xb.to(Config.device))
                all_prob.append(torch.softmax(logits, dim=1).cpu())
                all_y.append(torch.as_tensor(yb).cpu())
        prob = torch.cat(all_prob, 0)
        y = torch.cat(all_y, 0)
        pred = prob.argmax(1)
        conf = prob.max(1)[0]
        correct = (pred == y)
        ece, _ = calculate_calibration_metrics(prob, y)

        bin_acc = []
        for i in range(num_bins):
            lower = conf >= edges[i] if i == 0 else conf > edges[i]
            in_bin = lower & (conf <= edges[i + 1])
            bin_acc.append(correct[in_bin].float().mean().item() if in_bin.any() else 0.0)

        ax = plt.subplot(1, cols, idx)
        ax.plot([0, 1], [0, 1], 'k:')
        ax.bar(mids, bin_acc, width=1 / num_bins * 0.9, alpha=0.7, edgecolor='black')
        ax.set_title(f"{name} (ECE={ece:.3f})"); ax.set_xlabel('Confidence'); ax.set_ylabel('Accuracy')
        ax.set_ylim(0, 1); ax.grid(True)

    out = os.path.join(Config.output_dir, fname)
    plt.tight_layout(); plt.savefig(out); plt.close(); print(f"Saved reliability grid -> {out}")


# -------------------------------
# NEW: Clinical Validation Framework
# -------------------------------

class ClinicalValidationFramework:
    """Framework for systematic clinical validation.

    Workflow: ``select_cases_for_review()`` picks representative test cases.
    ``generate_validation_report()`` writes them to CSV for neurologist
    review. ``print_example_observations()`` prints a free-text template for
    recording the review.
    """

    def __init__(self, model, test_loader, model_name):
        self.model = model
        self.test_loader = test_loader
        self.model_name = model_name
        # category name -> list of case dicts, filled by select_cases_for_review().
        # A dict (not a list) so generate_validation_report() works before selection.
        self.selected_cases = {}

    def select_cases_for_review(self, n_high_conf=10, n_low_conf=10, n_errors=10):
        """Select representative cases for neurologist review.

        Runs the model over ``test_loader`` and records, per sample: patient
        id, true and predicted class names, top-class confidence, the PPDRQ
        value from ``compute_ppdrq_from_logits``, correctness and dataset index.
        It then picks:

        - ``high_confidence_correct``: correct, confidence > 0.8, most confident first;
        - ``low_confidence_correct``: correct, PPDRQ > 0.3, highest PPDRQ first;
        - ``errors``: misclassified, most confident first.

        The loader is assumed to yield ``(x, label, index, patient_id)`` batches.

        Returns:
            dict mapping the three category names to lists of case dicts.
        """
        print(f"\n[Clinical Validation] Selecting cases for {self.model_name}...")

        self.model.eval()
        all_cases = []

        with torch.no_grad():
            for xb, yb, indices, patient_ids in self.test_loader:
                logits, _, _ = self.model(xb.to(Config.device))
                probs = torch.softmax(logits, dim=1)
                preds = probs.argmax(1)
                confs = probs.max(1)[0]

                # Compute uncertainty (PPDRQ)
                pp = compute_ppdrq_from_logits(logits, Config.num_classes)

                # One host transfer per batch instead of a device sync per element
                labels = torch.as_tensor(yb).tolist()
                pred_list = preds.tolist()
                conf_list = confs.tolist()
                pp_list = pp.reshape(-1).tolist()
                idx_list = torch.as_tensor(indices).tolist()

                for pid, true_idx, pred_idx, conf, unc, idx in zip(
                        patient_ids, labels, pred_list, conf_list, pp_list, idx_list):
                    all_cases.append({
                        'patient_id': pid,
                        'true_label': CLASS_NAMES[true_idx],
                        'pred_label': CLASS_NAMES[pred_idx],
                        'confidence': conf,
                        'uncertainty_ppdrq': unc,
                        'correct': pred_idx == true_idx,
                        'index': idx
                    })

        # Categorize cases
        correct_high_conf = sorted([c for c in all_cases if c['correct'] and c['confidence'] > 0.8],
                                   key=lambda x: -x['confidence'])[:n_high_conf]

        correct_low_conf = sorted([c for c in all_cases if c['correct'] and c['uncertainty_ppdrq'] > 0.3],
                                  key=lambda x: -x['uncertainty_ppdrq'])[:n_low_conf]

        errors = sorted([c for c in all_cases if not c['correct']],
                        key=lambda x: -x['confidence'])[:n_errors]

        self.selected_cases = {
            'high_confidence_correct': correct_high_conf,
            'low_confidence_correct': correct_low_conf,
            'errors': errors
        }

        total = len(correct_high_conf) + len(correct_low_conf) + len(errors)
        print(f"Selected {total} cases for review:")
        print(f"  - {len(correct_high_conf)} high-confidence correct predictions")
        print(f"  - {len(correct_low_conf)} low-confidence correct predictions (high uncertainty)")
        print(f"  - {len(errors)} misclassifications")

        return self.selected_cases

    def generate_validation_report(self, output_file='clinical_validation_cases.csv'):
        """Generate report for neurologist review.

        Flattens ``self.selected_cases`` into one row per case, adding a
        ``category`` column. The result is written to
        ``Config.output_dir/output_file``. Before any selection, this writes an
        empty table.

        Returns:
            the written DataFrame.
        """
        all_cases_flat = []

        for category, cases in self.selected_cases.items():
            for case in cases:
                case_copy = case.copy()
                case_copy['category'] = category
                all_cases_flat.append(case_copy)

        df = pd.DataFrame(all_cases_flat)
        csv_path = os.path.join(Config.output_dir, output_file)
        df.to_csv(csv_path, index=False)
        print(f"\n✓ Saved clinical validation cases -> {csv_path}")
        print(f"Total cases for neurologist review: {len(all_cases_flat)}")

        return df

    def print_example_observations(self):
        """Print template for clinical observations."""
        print("\n" + "="*80)
        print("CLINICAL VALIDATION TEMPLATE")
        print("="*80)
        print("""
Number of cases inspected by neurologist: [TO BE FILLED]

Clinical Observations:

1. High Uncertainty Concordance:
   In X out of Y cases with high model uncertainty (PPDRQ > 0.3), these corresponded
   to clinically ambiguous features such as:
   - Boundary ambiguity in hippocampal atrophy assessment
   - Subtle white matter hyperintensities difficult to classify
   - [Add specific observations]

2. Error Analysis:
   Model misclassifications in Z cases showed patterns of:
   - Confusion between MCI and early AD in cases with moderate atrophy
   - Motion artifacts affecting model confidence appropriately
   - [Add specific observations]

3. Clinical Concordance:
   Model showed appropriate high uncertainty in cases that were also
   challenging for clinical assessment, particularly in:
   - [Region/pathology 1]
   - [Region/pathology 2]
   
Note: This template should be filled by reviewing neurologist after examining
the selected cases listed in clinical_validation_cases.csv
        """)
        print("="*80 + "\n")


# -------------------------------
# 10. Main Execution
# -------------------------------
if __name__ == '__main__':
    results: Dict[str, Dict] = {}
    times: Dict[str, float] = {}
    predictions_for_stats: Dict[str, Tuple] = {}  # For statistical testing

    # ============================================
    # Train Baseline Model
    # ============================================
    print("\n" + "="*80)
    print("TRAINING BASELINE MODEL")
    print("="*80)
    baseline = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
    crit_base = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=0.0)
    opt_base = optim.AdamW(filter(lambda p: p.requires_grad, baseline.parameters()), 
                           lr=Config.learning_rate, weight_decay=Config.weight_decay)
    sch_base = optim.lr_scheduler.ReduceLROnPlateau(opt_base, mode='min', factor=0.2, 
                                                     patience=Config.patience_lr_scheduler)
    hist_b, t_b = train_model(baseline, train_loader, val_loader, crit_base, opt_base, sch_base, 'Baseline')
    res_b = evaluate_model(baseline, test_loader, 'Baseline')
    results['Baseline'] = res_b
    times['Baseline'] = t_b

    # Collect predictions for statistical testing
    y_true_b, y_pred_b, conf_b, prob_b = collect_predictions(baseline, test_loader)
    predictions_for_stats['Baseline'] = (y_true_b, y_pred_b, prob_b)

    # ============================================
    # Train PPDRQ-weighted CE Model
    # ============================================
    print("\n" + "="*80)
    print("TRAINING PPDRQ-WEIGHTED CE MODEL")
    print("="*80)
    
    class PPDRQWeightedCE(nn.Module):
        def __init__(self, num_classes):
            super().__init__()
            self.num_classes = num_classes
        def forward(self, logits, _emb, labels, **kwargs):
            pp = compute_ppdrq_from_logits(logits, self.num_classes)
            per_ce = F.cross_entropy(logits, labels, reduction='none', label_smoothing=Config.label_smoothing)
            w = 1 + (1 - pp)
            loss = torch.mean(per_ce * w)
            return loss, loss, torch.tensor(0.0, device=logits.device), pp

    ppdrq_model = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
    crit_pp = PPDRQWeightedCE(num_classes=Config.num_classes)
    opt_pp = optim.AdamW(filter(lambda p: p.requires_grad, ppdrq_model.parameters()), 
                         lr=Config.learning_rate, weight_decay=Config.weight_decay)
    sch_pp = optim.lr_scheduler.ReduceLROnPlateau(opt_pp, mode='min', factor=0.2, 
                                                   patience=Config.patience_lr_scheduler)
    hist_p, t_p = train_model(ppdrq_model, train_loader, val_loader, crit_pp, opt_pp, sch_pp, 'PPDRQ_CE')
    res_p = evaluate_model(ppdrq_model, test_loader, 'PPDRQ_CE')
    results['PPDRQ_CE'] = res_p
    times['PPDRQ_CE'] = t_p

    y_true_p, y_pred_p, conf_p, prob_p = collect_predictions(ppdrq_model, test_loader)
    predictions_for_stats['PPDRQ_CE'] = (y_true_p, y_pred_p, prob_p)

    # ============================================
    # Train Triplet Model
    # ============================================
    print("\n" + "="*80)
    print("TRAINING TRIPLET MODEL")
    print("="*80)
    triplet_model = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
    crit_trip = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=Config.lambda_triplet)
    opt_trip = optim.AdamW(filter(lambda p: p.requires_grad, triplet_model.parameters()), 
                           lr=Config.learning_rate, weight_decay=Config.weight_decay)
    sch_trip = optim.lr_scheduler.ReduceLROnPlateau(opt_trip, mode='min', factor=0.2, 
                                                     patience=Config.patience_lr_scheduler)
    hist_t, t_t = train_model(triplet_model, train_loader, val_loader, crit_trip, opt_trip, sch_trip, 'Triplet')
    res_t = evaluate_model(triplet_model, test_loader, 'Triplet')
    results['Triplet'] = res_t
    times['Triplet'] = t_t

    y_true_t, y_pred_t, conf_t, prob_t = collect_predictions(triplet_model, test_loader)
    predictions_for_stats['Triplet'] = (y_true_t, y_pred_t, prob_t)

    # ============================================
    # Train and Evaluate MC-Dropout Model (FIXED)
    # ============================================
    print("\n" + "="*80)
    print("TRAINING MC-DROPOUT MODEL (FIXED IMPLEMENTATION)")
    print("="*80)
    mc_model = CustomInceptionV3(num_classes=Config.num_classes, 
                                 dropout_rate=Config.mc_dropout_rate).to(Config.device)
    crit_mc = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=0.0)
    opt_mc = optim.AdamW(filter(lambda p: p.requires_grad, mc_model.parameters()), 
                         lr=Config.learning_rate, weight_decay=Config.weight_decay)
    sch_mc = optim.lr_scheduler.ReduceLROnPlateau(opt_mc, mode='min', factor=0.2, 
                                                   patience=Config.patience_lr_scheduler)
    hist_m, t_m = train_model(mc_model, train_loader, val_loader, crit_mc, opt_mc, sch_mc, 'MC_Dropout')
    
    # FIXED: Use proper MC-Dropout evaluation
    res_m = mc_dropout_predict(mc_model, test_loader, num_passes=Config.mc_dropout_passes)
    results['MC_Dropout'] = res_m
    times['MC_Dropout_Train'] = t_m
    times['MC_Dropout_Infer'] = res_m['inference_time']

    y_true_m, y_pred_m, conf_m, prob_m = collect_predictions_mc_dropout(mc_model, test_loader)
    predictions_for_stats['MC_Dropout'] = (y_true_m, y_pred_m, prob_m)

    # ============================================
    # Train Deep Ensemble
    # ============================================
    print("\n" + "="*80)
    print("TRAINING DEEP ENSEMBLE")
    print("="*80)
    members = []
    t_mem = []
    for i in range(Config.num_ensemble_models):
        m, t, _ = train_ensemble_member(i, train_loader, val_loader)
        members.append(m)
        t_mem.append(t)
    
    res_e = evaluate_deep_ensemble(members, test_loader)
    results['Ensemble'] = res_e
    times['Ensemble_Train'] = sum(t_mem)
    times['Ensemble_Infer'] = res_e['inference_time']

    # Collect ensemble predictions
    all_prob_e = []
    all_y_e = []
    with torch.no_grad():
        for xb, yb, _, _ in test_loader:
            xb = xb.to(Config.device)
            mem_probs = []
            for m in members:
                m.eval()
                logits, _, _ = m(xb)
                mem_probs.append(torch.softmax(logits, dim=1).unsqueeze(0))
            mean_prob = torch.cat(mem_probs, 0).mean(0)
            all_prob_e.append(mean_prob.cpu())
            all_y_e.append(torch.as_tensor(yb).cpu())
    prob_e = torch.cat(all_prob_e, 0)
    y_e = torch.cat(all_y_e, 0)
    pred_e = prob_e.argmax(1)
    predictions_for_stats['Ensemble'] = (y_e.numpy(), pred_e.numpy(), prob_e.numpy())

    # ============================================
    # Statistical Significance Testing (NEW)
    # ============================================
    print("\n" + "="*80)
    print("PERFORMING STATISTICAL SIGNIFICANCE TESTING")
    print("="*80)
    
    stat_df = compute_statistical_comparisons(predictions_for_stats, baseline_name='Ensemble')
    
    # Plot confidence intervals
    for metric in ['Accuracy', 'F1-Score', 'Precision', 'Recall']:
        plot_confidence_intervals(stat_df, metric=metric)

    # ============================================
    # Results Summary
    # ============================================
    print("\n" + "="*80)
    print("RESULTS SUMMARY")
    print("="*80)
    print(f"{'Model':<15}{'Acc':>8}{'Prec':>8}{'Rec':>8}{'F1':>8}{'PPDRQ':>9}{'ECE':>8}{'Brier':>8}{'Train(s)':>10}{'Infer(s)':>10}")
    print("-"*110)
    for k, v in results.items():
        tr = times.get(k, times.get(f"{k}_Train", 0.0))
        inf = v.get('inference_time', times.get(f"{k}_Infer", 0.0))
        print(f"{k:<15}{v['accuracy']:>8.3f}{v['precision']:>8.3f}{v['recall']:>8.3f}{v['f1_score']:>8.3f}{v['mean_ppdrq']:>9.3f}{v['ece']:>8.3f}{v['brier_score']:>8.3f}{tr:>10.1f}{inf:>10.1f}")

    # ============================================
    # Save Training Curves
    # ============================================
    for name, hist in [('Baseline', hist_b), ('PPDRQ_CE', hist_p), ('Triplet', hist_t)]:
        plt.figure(figsize=(10,4))
        plt.subplot(1,2,1)
        plt.plot(hist['train_acc'], label='Train')
        plt.plot(hist['val_acc'], label='Val')
        plt.title(f'{name} Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True)
        
        plt.subplot(1,2,2)
        plt.plot(hist['train_loss'], label='Train')
        plt.plot(hist['val_loss'], label='Val')
        plt.title(f'{name} Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        
        plt.tight_layout()
        plt.savefig(os.path.join(Config.output_dir, f"{name}_curves.png"))
        plt.close()

    # ============================================
    # Generate Paper-Ready Outputs
    # ============================================
    print("\n" + "="*80)
    print("GENERATING PAPER-READY VISUALIZATIONS")
    print("="*80)
    
    # Unified results table
    df = save_results_table(results, times, "results_all_models.csv")
    
    # Metric comparison plots
    plot_metric_bars(df, "Accuracy", "cmp_accuracy.png")
    plot_metric_bars(df, "F1", "cmp_f1.png")
    plot_metric_bars(df, "ECE", "cmp_ece.png")
    plot_metric_bars(df, "Mean_PPDRQ", "cmp_ppdrq.png")
    
    # Cost-performance analysis
    plot_cost_performance(df, perf_metric="Accuracy", cost_metric="Infer_Time_s", 
                         fname="cost_perf_acc_vs_infer.png")

    # ============================================
    # Confusion Matrices and Per-Class Reports
    # ============================================
    print("\n" + "="*80)
    print("GENERATING CONFUSION MATRICES AND PER-CLASS REPORTS")
    print("="*80)
    
    model_objs = {
        'Baseline': baseline,
        'PPDRQ_CE': ppdrq_model,
        'Triplet': triplet_model,
        'MC_Dropout': mc_model
    }
    
    # Ensemble wrapper
    class EnsembleWrapper(nn.Module):
        def __init__(self, members):
            super().__init__()
            self.members = members
        def forward(self, x):
            outs = []
            for m in self.members:
                m.eval()
                lo, _, _ = m(x)
                outs.append(lo.unsqueeze(0))
            logits = torch.mean(torch.cat(outs, 0), dim=0)
            return logits, torch.zeros(x.size(0), 512, device=logits.device), {'final_logits': logits}
    
    ens_wrapper = EnsembleWrapper(members)
    model_objs['Ensemble'] = ens_wrapper

    for name, m in model_objs.items():
        if name == 'MC_Dropout':
            y_t, y_p, confs, prob = collect_predictions_mc_dropout(m, test_loader)
        else:
            y_t, y_p, confs, prob = collect_predictions(m, test_loader)
        
        plot_confusion(y_t, y_p, CLASS_NAMES, f"Confusion Matrix: {name}", 
                      f"cm_{name}.png", normalize=True)
        export_per_class_report(y_t, y_p, CLASS_NAMES, f"per_class_{name}.csv")

    # ============================================
    # Reliability Grid
    # ============================================
    print("\n" + "="*80)
    print("GENERATING RELIABILITY GRID")
    print("="*80)
    reliability_grid(model_objs, test_loader, fname='reliability_grid_test.png', num_bins=10)

    # ============================================
    # Layer-wise PPDRQ Analysis
    # ============================================
    print("\n" + "="*80)
    print("GENERATING LAYER-WISE PPDRQ ANALYSIS")
    print("="*80)
    get_and_plot_layerwise_ppdrq(ppdrq_model, test_loader, 'PPDRQ_CE', 'PPDRQ Model')

    # ============================================
    # Clinical Validation Framework (NEW)
    # ============================================
    print("\n" + "="*80)
    print("CLINICAL VALIDATION FRAMEWORK")
    print("="*80)
    
    # Generate clinical validation reports for key models
    for model_name, model in [('PPDRQ_CE', ppdrq_model), ('Ensemble', ens_wrapper)]:
        clinical_validator = ClinicalValidationFramework(model, test_loader, model_name)
        selected_cases = clinical_validator.select_cases_for_review(
            n_high_conf=10, n_low_conf=10, n_errors=10
        )
        clinical_validator.generate_validation_report(
            output_file=f'clinical_validation_{model_name}.csv'
        )
    
    # Print template for neurologist
    clinical_validator.print_example_observations()

    # ============================================
    # Domain Shift Evaluation (Unseen Dataset)
    # ============================================
    if len(unseen_dataset) > 0:
        print("\n" + "="*80)
        print("DOMAIN SHIFT EVALUATION (UNSEEN DATASET)")
        print("="*80)
        print("Note: DermNet or other domain provides cross-domain evaluation.")
        print("Limitation: Substantial domain differences from MRI neuroimaging.")
        print("="*80)
        
        # Load best weights
        for tag, model in [('Baseline', baseline), ('PPDRQ_CE', ppdrq_model), 
                          ('Triplet', triplet_model), ('MC_Dropout', mc_model)]:
            pth = os.path.join(Config.output_dir, f"{tag}_best.pth")
            if os.path.exists(pth):
                model.load_state_dict(torch.load(pth, map_location=Config.device))
        
        # Evaluate on unseen
        res_b_u = evaluate_model(baseline, unseen_loader, 'Baseline_Unseen')
        res_p_u = evaluate_model(ppdrq_model, unseen_loader, 'PPDRQ_CE_Unseen')
        res_t_u = evaluate_model(triplet_model, unseen_loader, 'Triplet_Unseen')
        res_m_u = mc_dropout_predict(mc_model, unseen_loader, num_passes=Config.mc_dropout_passes)
        
        if len(members) == Config.num_ensemble_models:
            res_e_u = evaluate_deep_ensemble(members, unseen_loader)
        
        # Save unseen results
        unseen_results = {
            'Baseline_Unseen': res_b_u,
            'PPDRQ_CE_Unseen': res_p_u,
            'Triplet_Unseen': res_t_u,
            'MC_Dropout_Unseen': res_m_u,
        }
        if len(members) == Config.num_ensemble_models:
            unseen_results['Ensemble_Unseen'] = res_e_u
        
        df_unseen = pd.DataFrame([
            {'Model': k, **v} for k, v in unseen_results.items()
        ])
        df_unseen.to_csv(os.path.join(Config.output_dir, 'results_unseen_domain.csv'), index=False)
        print("\n✓ Saved unseen domain results")
        
        print("\n⚠️  LIMITATION ACKNOWLEDGMENT:")
        print("The unseen dataset evaluation (e.g., DermNet) provides useful cross-domain")
        print("robustness testing but differs substantially from MRI neuroimaging in:")
        print("  - Imaging modality (dermoscopic vs MRI)")
        print("  - Anatomical structures (skin vs brain)")
        print("  - Pathology types (dermatological vs neurodegenerative)")
        print("\nFor clinical deployment, site-held-out or scanner-specific MRI robustness")
        print("testing would provide more clinically relevant OOD evaluation.")

    # ============================================
    # Final Summary Document
    # ============================================
    print("\n" + "="*80)
    print("GENERATING FINAL SUMMARY DOCUMENT")
    print("="*80)
    
    summary_text = f"""
PPDRQ RELIABILITY ENHANCEMENT - EXPERIMENTAL RESULTS SUMMARY
{'='*80}

1. DATASET VERIFICATION
   - Patient-level split verification completed
   - See: patient_split_verification.csv
   
2. MODEL PERFORMANCE
   All models evaluated on test set with comprehensive metrics:
   - See: results_all_models.csv
   
3. STATISTICAL SIGNIFICANCE
   Bootstrap confidence intervals and McNemar tests computed:
   - See: statistical_significance_results.csv
   - Visualizations: *_confidence_intervals.png
   
4. CALIBRATION ANALYSIS
   - Reliability diagrams for all models
   - ECE and Brier scores computed
   - See: reliability_grid_test.png
   
5. PER-CLASS PERFORMANCE
   - Confusion matrices: cm_*.png
   - Detailed reports: per_class_*.csv
   
6. CLINICAL VALIDATION
   - Cases selected for neurologist review
   - See: clinical_validation_*.csv
   - Template provided for clinical observations
   
7. DOMAIN SHIFT EVALUATION
   {'- Completed on unseen dataset' if len(unseen_dataset) > 0 else '- Not performed (no unseen data)'}
   {'- See: results_unseen_domain.csv' if len(unseen_dataset) > 0 else ''}
   {'- Limitation acknowledged in results' if len(unseen_dataset) > 0 else ''}

8. KEY FINDINGS
   Best performing model: {max(results.keys(), key=lambda k: results[k]['accuracy'])}
   Accuracy: {max(results.values(), key=lambda v: v['accuracy'])['accuracy']:.4f}
   
   Statistical significance vs Baseline:
   {stat_df[stat_df['Significant_vs_Baseline'] == True]['Model'].tolist()}

{'='*80}

All results saved to: {Config.output_dir}/

REVIEWER REQUIREMENTS ADDRESSED:
✓ 1. MC-Dropout implementation fixed and verified
✓ 2. Statistical significance testing completed
✓ 3. Patient-level split verification documented
✓ 4. Clinical validation framework implemented
✓ 5. Domain shift limitation acknowledged
"""

    with open(os.path.join(Config.output_dir, 'EXPERIMENTAL_SUMMARY.txt'), 'w') as f:
        f.write(summary_text)
    
    print(summary_text)
    
    print("\n" + "="*80)
    print("✓ ALL EXPERIMENTS COMPLETED SUCCESSFULLY")
    print("="*80)
    print(f"\nAll results, visualizations, and reports saved to: {Config.output_dir}/")
    print("\nKey files for paper:")
    print("  - results_all_models.csv")
    print("  - statistical_significance_results.csv")
    print("  - patient_split_verification.csv")
    print("  - clinical_validation_*.csv")
    print("  - All visualization PNG files")
    print("\n" + "="*80)

# PPDRQ
Enhanced analysis: λ-sensitivity analysis with multiple metrics, detailed timing analysis, and a comprehensive methodological-transparency report.

In [None]:
import os
import math
import random
import time
import warnings
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report
)
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
from tqdm import tqdm
import pandas as pd

warnings.filterwarnings("ignore")

# Start from matplotlib's default style and use seaborn's "husl" colour cycle.
plt.style.use('default')
sns.set_palette("husl")

# Named hex colours shared by every figure in this notebook cell.
COLORS = dict(
    primary='#2E86AB',
    secondary='#A23B72',
    accent='#F18F01',
    success='#C73E1D',
    info='#7209B7',
    warning='#F4A261',
    light='#E9C46A',
    dark='#264653',
)

# Fixed colour per model so the same model looks identical across plots.
MODEL_COLORS = dict(zip(
    ('Baseline', 'PPDRQ_CE', 'Triplet', 'MC_Dropout', 'Ensemble'),
    (COLORS[key] for key in ('primary', 'secondary', 'accent', 'success', 'info')),
))

# -------------------------------
# 1. Enhanced Config with Methodological Transparency
# -------------------------------
class Config:
    """Experiment configuration, accessed through class attributes (``Config.x``)."""

    data_dir = 'dataset'              # expects <data_dir>/{train,validation,test,unseen}/{AD,MCI,NC}
    # Leading '.' removed: '.Save results here' silently created a hidden
    # directory and diverged from the results directory used elsewhere.
    output_dir = 'Save results here'

    image_size = 299
    num_classes = 3                   # AD / MCI / NC

    batch_size = 16
    epochs = 60
    learning_rate = 3e-4
    weight_decay = 5e-4
    patience_lr_scheduler = 3         # ReduceLROnPlateau patience (epochs)
    patience_early_stopping = 10      # epochs without val-loss improvement before stopping
    grad_clip = 1.0                   # max grad norm; None disables clipping

    # METHODOLOGICAL TRANSPARENCY: PPDRQ Parameters
    lambda_triplet = 0.1              # λ: Triplet loss weight
    triplet_margin = 0.2
    epsilon_p = 0.1                   # εp: Positive sample threshold
    epsilon_n = 0.2                   # εn: Negative sample threshold

    # MC-Dropout & Ensemble
    mc_dropout_passes = 30
    num_ensemble_models = 5

    # Regularization toggles
    use_mixup_cutmix = True
    mixup_alpha = 0.2
    cutmix_alpha = 0.2
    label_smoothing = 0.05

    # Data safety: when True, missing/empty class dirs are filled with random dummy images
    allow_make_dummy = False

    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Professional plotting settings
    FIGURE_DPI = 300
    FONT_SIZE = 12
    TITLE_SIZE = 14
    LABEL_SIZE = 10


def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) and make cuDNN deterministic."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


def setup_professional_plots():
    """Apply publication-style matplotlib rcParams: DPI, font sizes, light grid, tight saving."""
    dpi = Config.FIGURE_DPI
    title_size = Config.TITLE_SIZE
    label_size = Config.LABEL_SIZE

    settings = {
        # resolution for on-screen and saved figures
        'figure.dpi': dpi,
        'savefig.dpi': dpi,
        # typography
        'font.size': Config.FONT_SIZE,
        'axes.titlesize': title_size,
        'figure.titlesize': title_size,
        'axes.labelsize': label_size,
        'xtick.labelsize': label_size,
        'ytick.labelsize': label_size,
        'legend.fontsize': label_size,
        # faint grid drawn behind the data
        'axes.grid': True,
        'grid.alpha': 0.3,
        'axes.axisbelow': True,
        # crop whitespace when saving
        'savefig.bbox': 'tight',
        'savefig.pad_inches': 0.1,
    }
    plt.rcParams.update(settings)


# One-time notebook-cell setup: results dir, RNG seeds, plotting defaults.
os.makedirs(Config.output_dir, exist_ok=True)
set_seed(42)
setup_professional_plots()
print(f"Using device: {Config.device}")

# Echo the PPDRQ hyper-parameters so each run log records them.
print("\n".join([
    "\n" + "=" * 60,
    "METHODOLOGICAL TRANSPARENCY",
    "=" * 60,
    "PPDRQ Parameters:",
    f"  λ (lambda_triplet): {Config.lambda_triplet}",
    f"  εp (epsilon_p): {Config.epsilon_p}",
    f"  εn (epsilon_n): {Config.epsilon_n}",
    f"  Triplet margin: {Config.triplet_margin}",
    "=" * 60,
]))

# -------------------------------
# 2. Dataset (unchanged core logic)
# -------------------------------
class MRIDataset(Dataset):
    """Image-folder dataset laid out as ``<data_dir>/<phase>/{AD,MCI,NC}/*.{png,jpg,jpeg}``.

    Labels: AD=0, MCI=1, NC=2. Samples are ordered by class, then by filename,
    so indices are reproducible across filesystems. ``__getitem__`` returns
    ``(image, label, index)``, with ``image`` passed through ``transform`` if given.

    A missing class directory is skipped with a warning unless
    ``Config.allow_make_dummy`` is True. In that case the directory is created
    and filled with random grey placeholder images (64 for 'train', 8 otherwise).
    The same happens for an existing but empty directory.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir: str, phase: str, transform=None):
        self.data_dir = os.path.join(data_dir, phase)
        self.transform = transform
        self.class_to_idx = {'AD': 0, 'MCI': 1, 'NC': 2}
        self.image_paths: List[str] = []
        self.labels: List[int] = []

        for cname, idx in self.class_to_idx.items():
            cdir = os.path.join(self.data_dir, cname)
            if not os.path.exists(cdir):
                if Config.allow_make_dummy:
                    print(f"[WARN] Missing {cdir}. Creating dummy images for demo only.")
                    os.makedirs(cdir, exist_ok=True)
                    self._write_dummy_images(cdir, phase)
                else:
                    print(f"[WARN] Missing class dir: {cdir}. Skipping.")
                    continue

            files = self._list_images(cdir)
            if not files and Config.allow_make_dummy:
                print(f"[WARN] Empty {cdir}. Creating dummy images for demo only.")
                self._write_dummy_images(cdir, phase)
                files = self._list_images(cdir)

            for name in files:
                self.image_paths.append(os.path.join(cdir, name))
                self.labels.append(idx)

    @classmethod
    def _list_images(cls, cdir: str) -> List[str]:
        """Return image filenames in ``cdir``, sorted (os.listdir order is arbitrary)."""
        return sorted(f for f in os.listdir(cdir) if f.lower().endswith(cls._IMAGE_EXTENSIONS))

    @staticmethod
    def _write_dummy_images(cdir: str, phase: str) -> None:
        """Write uniform random-grey RGB placeholder PNGs into ``cdir`` (demo use only)."""
        n = 8 if phase != 'train' else 64
        for i in range(n):
            img = Image.new('L', (Config.image_size, Config.image_size), color=random.randint(0, 255))
            img.convert('RGB').save(os.path.join(cdir, f'dummy_{i}.png'))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        p = self.image_paths[idx]
        y = self.labels[idx]
        # Close the file handle promptly; convert() returns a fully loaded copy.
        with Image.open(p) as im:
            img = im.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, y, idx

# Data transforms.
# Training: resize the shorter side to 1.1 * image_size, random-resized-crop back to
# image_size (80-100% of the area, aspect 0.9-1.1), +/-10 degree rotation, horizontal
# flip and Gaussian blur; then tensor conversion, normalisation with the ImageNet
# mean/std (matching the ImageNet-pretrained backbone) and RandomErasing (p=0.25).
# Order matters: RandomErasing operates on the normalised tensor.
train_tfms = transforms.Compose([
    transforms.Resize(int(Config.image_size*1.1)),
    transforms.RandomResizedCrop(Config.image_size, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.RandomRotation(10),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.GaussianBlur(kernel_size=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.1), ratio=(0.3, 3.3), value='random')
])

# Validation / test / unseen: deterministic square resize + the same normalisation.
val_test_tfms = transforms.Compose([
    transforms.Resize((Config.image_size, Config.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Datasets for each split; all except 'train' use deterministic transforms.
train_dataset = MRIDataset(Config.data_dir, 'train', train_tfms)
val_dataset   = MRIDataset(Config.data_dir, 'validation', val_test_tfms)
test_dataset  = MRIDataset(Config.data_dir, 'test', val_test_tfms)
unseen_dataset= MRIDataset(Config.data_dir, 'unseen', val_test_tfms)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}, Unseen: {len(unseen_dataset)}")
if len(train_dataset) == 0:
    # Report the directory that was actually searched (Config.data_dir), not a fixed 'data/'.
    # NOTE(review): DataLoader(shuffle=True) below is expected to reject an empty dataset.
    print(f"[ERROR] Train dataset is empty. Please prepare data under "
          f"{os.path.join(Config.data_dir, 'train')}/{{AD,MCI,NC}}.")

# Dataloaders: only training is shuffled; drop_last avoids a tiny final batch for BatchNorm.
train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, num_workers=0, drop_last=True)
val_loader   = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=0)
unseen_loader= DataLoader(unseen_dataset, batch_size=Config.batch_size, shuffle=False, num_workers=0)

# -------------------------------
# 3. Model (unchanged)
# -------------------------------
class CustomInceptionV3(nn.Module):
    """ImageNet-pretrained InceptionV3 backbone with a classification head and an embedding branch.

    The backbone's ``fc`` is replaced by ``nn.Identity`` so it yields 2048-d pooled
    features. ``forward`` returns ``(logits, embedding, layer_dict)``:

    * ``logits``: ``fc_head`` applied to BatchNorm + dropout features.
    * ``embedding``: 512-d ReLU output of ``hidden1`` -> ``hidden2`` (used by the triplet loss).
    * ``layer_dict``: intermediate activations for layer-wise analysis.

    Args:
        num_classes: number of output classes.
        dropout_rate: dropout probability applied after the 2048-d BatchNorm.
        include_aux_logits: accepted but currently unused (the backbone is always
            built with ``aux_logits=True``).
        train_last_blocks_only: if True, freeze every backbone parameter except
            those in the ``Mixed_7*`` blocks.
    """

    def __init__(self, num_classes=3, dropout_rate=0.6, include_aux_logits=True, train_last_blocks_only=True):
        super().__init__()
        import torchvision
        try:
            self.inception_base = torchvision.models.inception_v3(
                weights=torchvision.models.Inception_V3_Weights.IMAGENET1K_V1,
                aux_logits=True
            )
        except Exception:
            # Fallback for older torchvision releases without the ``weights=`` API.
            self.inception_base = models.inception_v3(pretrained=True, aux_logits=True)
        self.inception_base.fc = nn.Identity()

        if train_last_blocks_only:
            for p in self.inception_base.parameters():
                p.requires_grad = False
            for name, m in self.inception_base.named_modules():
                if name.startswith('Mixed_7'):
                    for p in m.parameters():
                        p.requires_grad = True

        self.bn = nn.BatchNorm1d(2048)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.hidden1 = nn.Linear(2048, 1024)
        self.hidden2 = nn.Linear(1024, 512)
        self.fc_head = nn.Linear(2048, num_classes)

        # Filled on every forward pass by the hook below with the Mixed_7c output.
        self.mixed_7c_output = None
        def hook_fn(_m, _in, out):
            self.mixed_7c_output = out
        self.inception_base.Mixed_7c.register_forward_hook(hook_fn)

    def _extract_features(self, x):
        """Return the backbone's pooled features, whichever output container it uses.

        Handles an object with a ``.logits`` attribute, a tuple/list (first element),
        or a plain tensor.
        """
        out = self.inception_base(x)
        if hasattr(out, 'logits'):
            return out.logits
        if isinstance(out, (tuple, list)) and len(out) > 0:
            return out[0]
        return out

    def forward(self, x):
        flatten = self._extract_features(x)
        # Run BatchNorm exactly once. In training mode every BatchNorm1d call
        # updates running_mean/var, so re-applying self.bn for 'hidden_layer1'
        # updated the running statistics twice per step. The normalised values
        # are the same either way, so reuse them.
        normed = self.bn(flatten)
        z = self.dropout(normed)
        logits = self.fc_head(z)
        e = F.relu(self.hidden1(z))
        e = F.relu(self.hidden2(e))
        layer_dict = {
            'base_model_output': self.mixed_7c_output,
            'hidden_layer1': F.relu(self.hidden1(normed)),  # dropout-free activation
            'hidden_layer2': e,
            'flatten': flatten,
            'final_logits': logits
        }
        return logits, e, layer_dict

# -------------------------------
# 4. PPDRQ & Losses (unchanged core logic)
# -------------------------------
def compute_ppdrq_from_logits(logits: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Per-sample PPDRQ: normalised sum of pairwise class-probability gaps.

    PPDRQ = sum_{i<j} |p_i - p_j| / (K - 1), clamped to [0, 1]. It is 0 for a
    uniform softmax and 1 for a one-hot one. The K == 3 path computes the same
    quantity as (mean gap) * 3/2.
    """
    probs = torch.softmax(logits, dim=1)
    if num_classes != 3:
        pair_gaps = [
            (probs[:, i] - probs[:, j]).abs()
            for i in range(num_classes)
            for j in range(i + 1, num_classes)
        ]
        gap_sum = torch.stack(pair_gaps, dim=1).sum(dim=1)
        return torch.clamp(gap_sum / (num_classes - 1), 0, 1)

    a, b, c = probs[:, 0], probs[:, 1], probs[:, 2]
    mean_gap = ((a - b).abs() + (a - c).abs() + (b - c).abs()) / 3.0
    return torch.clamp(1.5 * mean_gap, 0, 1)

class CombinedLoss(nn.Module):
    """PPDRQ-weighted cross-entropy plus an optional PPDRQ-aware triplet term.

    total = mean_i(CE_i * (2 - PPDRQ_i)) + lambda_triplet * triplet

    CE is label-smoothed by ``smoothing``. The triplet term is computed only when
    a reference pool (``all_feats``, ``all_labels``, ``all_pp``: lists of tensors
    to concatenate) is supplied. For each anchor it samples:

    * a positive: same label, with |PPDRQ difference| < ``epsilon_p``;
    * a negative: different label, with |PPDRQ difference| >= ``epsilon_n``.

    The triplet losses are averaged over anchors that found both.

    Returns ``(total, ce_loss, triplet_loss, pp)``.
    """

    def __init__(self, num_classes, lambda_triplet=Config.lambda_triplet, margin=Config.triplet_margin,
                 epsilon_p=Config.epsilon_p, epsilon_n=Config.epsilon_n, smoothing=Config.label_smoothing):
        super().__init__()
        self.num_classes = num_classes
        self.lambda_triplet = lambda_triplet
        self.epsilon_p = epsilon_p
        self.epsilon_n = epsilon_n
        self.triplet = nn.TripletMarginLoss(margin=margin, p=2)
        self.smoothing = smoothing

    def forward(self, logits, embeddings, labels,
                all_feats=None, all_labels=None, all_pp=None):
        pp = compute_ppdrq_from_logits(logits, self.num_classes)
        per_sample_ce = F.cross_entropy(logits, labels, reduction='none', label_smoothing=self.smoothing)
        # Ambiguous (low-PPDRQ) samples get up to 2x weight.
        # NOTE(review): pp is not detached, so gradients also flow through the weights.
        weights = 1 + (1 - pp)
        ce_loss = torch.mean(per_sample_ce * weights)

        triplet_loss = torch.tensor(0.0, device=logits.device)
        if all_feats is not None and len(all_feats) > 0:
            allF = torch.cat(all_feats, dim=0).to(embeddings.device)
            allY = torch.cat(all_labels, dim=0).to(embeddings.device)
            allP = torch.cat(all_pp, dim=0).to(embeddings.device)
            count = 0
            for i in range(labels.size(0)):
                a = embeddings[i]
                y = labels[i]
                pp_gap = (allP - pp[i]).abs()
                # Use the instance thresholds so constructor arguments (e.g. in a
                # sensitivity sweep) take effect, rather than the global Config values.
                pos_idx = torch.where((allY == y) & (pp_gap < self.epsilon_p))[0]
                neg_idx = torch.where((allY != y) & (pp_gap >= self.epsilon_n))[0]
                # >1 positives required: the anchor's own pool entry may be among them.
                if pos_idx.numel() > 1 and neg_idx.numel() > 0:
                    p = allF[random.choice(pos_idx.tolist())]
                    n = allF[random.choice(neg_idx.tolist())]
                    triplet_loss = triplet_loss + self.triplet(a.unsqueeze(0), p.unsqueeze(0), n.unsqueeze(0))
                    count += 1
            if count > 0:
                triplet_loss = triplet_loss / count
        total = ce_loss + self.lambda_triplet * triplet_loss
        return total, ce_loss, triplet_loss, pp

# -------------------------------
# 5. Mixup/CutMix utils (unchanged)
# -------------------------------
def rand_bbox(W, H, lam):
    """Sample a CutMix box whose side lengths are sqrt(1 - lam) of the image's.

    The centre is drawn uniformly with ``np.random.randint`` (x first, then y).
    Corners are clipped to the W x H image. Returns ``(x1, y1, x2, y2)``.
    """
    ratio = math.sqrt(1. - lam)
    half_w = int(W * ratio) // 2
    half_h = int(H * ratio) // 2
    centre_x = np.random.randint(W)
    centre_y = np.random.randint(H)
    left, right = (np.clip(centre_x + off, 0, W) for off in (-half_w, half_w))
    top, bottom = (np.clip(centre_y + off, 0, H) for off in (-half_h, half_h))
    return left, top, right, bottom

def apply_mixup_cutmix(x, y):
    """Apply Mixup or CutMix (50/50) to a batch.

    Returns ``(x, y, None)`` unchanged when ``Config.use_mixup_cutmix`` is False.
    Otherwise returns ``(mixed_x, (y_a, y_b, lam), kind)``, where ``kind`` is
    'mixup' or 'cutmix' and ``lam`` is the weight of ``y_a``. For CutMix, ``lam``
    is recomputed from the clipped box area.

    NOTE(review): when enabled, every batch is mixed (never ``None``). Callers
    that only use their full criterion for unmixed batches therefore never use it.
    """
    if not Config.use_mixup_cutmix:
        return x, y, None

    use_mixup = random.random() < 0.5
    if use_mixup:
        lam = np.random.beta(Config.mixup_alpha, Config.mixup_alpha)
        perm = torch.randperm(x.size(0)).to(x.device)
        blended = lam * x + (1 - lam) * x[perm]
        return blended, (y, y[perm], lam), 'mixup'

    lam = np.random.beta(Config.cutmix_alpha, Config.cutmix_alpha)
    perm = torch.randperm(x.size(0)).to(x.device)
    x1, y1, x2, y2 = rand_bbox(x.size(3), x.size(2), lam)
    patched = x.clone()
    patched[:, :, y1:y2, x1:x2] = x[perm, :, y1:y2, x1:x2]
    # Actual pasted-area fraction (the box may have been clipped at the borders).
    pasted_frac = (x2 - x1) * (y2 - y1) / (x.size(-1) * x.size(-2))
    return patched, (y, y[perm], 1 - pasted_frac), 'cutmix'

def mix_criterion(logits, target_tuple):
    """Label-smoothed CE for mixed targets: lam * CE(y_a) + (1 - lam) * CE(y_b)."""
    y_a, y_b, lam = target_tuple
    smoothing = Config.label_smoothing
    loss_a, loss_b = (
        F.cross_entropy(logits, target, label_smoothing=smoothing) for target in (y_a, y_b)
    )
    return lam * loss_a + (1 - lam) * loss_b

# -------------------------------
# 6. Enhanced Training/Evaluation with Professional Plots
# -------------------------------
def _collect_triplet_pool(model, loader):
    """Embed every sample in ``loader`` and return CPU (embeddings, labels, PPDRQ) for triplet mining.

    Runs in eval mode under ``torch.no_grad()``, so dropout does not add noise to
    the reference embeddings and this extra pass does not update BatchNorm running
    statistics. The caller must switch the model back to train mode.
    """
    model.eval()
    feats, labels, pps = [], [], []
    with torch.no_grad():
        for xb, yb, _ in loader:
            xb = xb.to(Config.device)
            yb = torch.as_tensor(yb, device=Config.device)
            logits, emb, _ = model(xb)
            pp = compute_ppdrq_from_logits(logits, Config.num_classes)
            feats.append(emb.cpu())
            labels.append(yb.cpu())
            pps.append(pp.cpu())
    if not feats:
        return torch.tensor([]), torch.tensor([]), torch.tensor([])
    return torch.cat(feats, 0), torch.cat(labels, 0), torch.cat(pps, 0)


def _validate_epoch(model, loader, criterion):
    """Evaluate ``criterion`` over ``loader`` without gradients; return (mean loss, accuracy, mean PPDRQ)."""
    model.eval()
    v_loss = 0.0
    v_cor = 0
    v_tot = 0
    v_pp = 0.0
    with torch.no_grad():
        for xb, yb, _ in loader:
            xb = xb.to(Config.device)
            yb = torch.as_tensor(yb, device=Config.device)
            logits, emb, _ = model(xb)
            total, _, _, pp = criterion(logits, emb, yb)
            pred = torch.softmax(logits, dim=1).argmax(1)
            v_cor += (pred == yb).sum().item()
            v_tot += yb.size(0)
            v_loss += total.item() * xb.size(0)
            v_pp += pp.sum().item()
    denom = max(v_tot, 1)
    return v_loss / denom, v_cor / denom, v_pp / denom


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, model_name, num_epochs=Config.epochs):
    """Train ``model`` in place with LR scheduling and early stopping on validation loss.

    Each epoch:
    1. embeds the training set (eval mode) to build the triplet-mining pool;
    2. runs one optimisation pass over ``train_loader``. Mixed batches (from
       ``apply_mixup_cutmix``) use ``mix_criterion``; unmixed batches use the full
       ``criterion`` with the pool;
    3. validates with ``criterion`` and calls ``scheduler.step(val_loss)``.

    When validation loss improves, the weights are saved to
    ``<Config.output_dir>/<model_name>_best.pth``. Training stops after
    ``Config.patience_early_stopping`` epochs without improvement. On return,
    ``model`` holds the weights of the last epoch run, not the best checkpoint.

    Returns:
        (history, seconds). ``history`` holds per-epoch lists 'train_loss',
        'val_loss', 'train_acc', 'val_acc', 'train_ppdrq', 'val_ppdrq' and the
        int 'epochs_trained'.
    """
    best_val = float('inf')
    epochs_no_improve = 0
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'train_ppdrq': [], 'val_ppdrq': [], 'epochs_trained': 0}
    start = time.time()

    print(f"\n--- Training {model_name} ---")
    for epoch in range(num_epochs):
        allF, allY, allP = _collect_triplet_pool(model, train_loader)

        model.train()
        t_loss = 0.0
        t_correct = 0.0
        t_total = 0
        t_pp = 0.0
        for xb, yb, _ in train_loader:
            xb = xb.to(Config.device)
            yb = torch.as_tensor(yb, device=Config.device)
            xb_m, yb_m, aug = apply_mixup_cutmix(xb, yb)

            optimizer.zero_grad(set_to_none=True)
            logits, emb, _ = model(xb_m)

            if aug is None:
                total, _, _, pp = criterion(logits, emb, yb, [allF], [allY], [allP])
            else:
                total = mix_criterion(logits, yb_m)
                pp = compute_ppdrq_from_logits(logits, Config.num_classes)

            total.backward()
            if Config.grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), Config.grad_clip)
            optimizer.step()

            with torch.no_grad():
                pred = torch.softmax(logits, dim=1).argmax(1)
                if aug is None:
                    t_correct += (pred == yb).sum().item()
                    t_total += yb.size(0)
                else:
                    # A mixed sample is y_a with weight lam and y_b with weight 1 - lam;
                    # credit a correct prediction of either target accordingly.
                    y_a, y_b, lam = yb_m
                    t_correct += lam * (pred == y_a).sum().item() + (1 - lam) * (pred == y_b).sum().item()
                    t_total += y_a.size(0)
                t_loss += total.item() * xb.size(0)
                t_pp += pp.sum().item()

        tr_loss = t_loss / max(t_total, 1)
        tr_acc = t_correct / max(t_total, 1)
        tr_pp = t_pp / max(t_total, 1)

        va_loss, va_acc, va_pp = _validate_epoch(model, val_loader, criterion)

        history['train_loss'].append(tr_loss)
        history['val_loss'].append(va_loss)
        history['train_acc'].append(tr_acc)
        history['val_acc'].append(va_acc)
        history['train_ppdrq'].append(tr_pp)
        history['val_ppdrq'].append(va_pp)
        history['epochs_trained'] = epoch + 1

        if (epoch+1) == 1 or (epoch+1) % 5 == 0 or (epoch+1) == num_epochs:
            print(f"Epoch {epoch+1:03d}/{num_epochs} | Train: loss {tr_loss:.4f} acc {tr_acc:.4f} pp {tr_pp:.3f} | Val: loss {va_loss:.4f} acc {va_acc:.4f} pp {va_pp:.3f}")

        scheduler.step(va_loss)
        if va_loss < best_val:
            best_val = va_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), os.path.join(Config.output_dir, f"{model_name}_best.pth"))
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= Config.patience_early_stopping:
                print(f"Early stopping at epoch {epoch+1}")
                break

    dur = time.time() - start
    print(f"Training time for {model_name}: {dur:.1f}s")
    return history, dur

def calculate_calibration_metrics(probabilities: torch.Tensor, labels: torch.Tensor, num_bins=10):
    """Compute Expected Calibration Error (ECE) and the multi-class Brier score.

    ECE uses ``num_bins`` equal-width confidence bins over [0, 1]. Bin ``i``
    covers ``(edge_i, edge_{i+1}]``; the first bin is closed on the left so a
    confidence of exactly 0 is not dropped. Each non-empty bin contributes
    ``|bin accuracy - bin mean confidence|`` weighted by its share of samples.

    The Brier score is the mean over samples of the squared L2 distance between
    the probability vector and the one-hot encoded label.

    Args:
        probabilities: ``(N, C)`` tensor of per-class probabilities.
        labels: ``(N,)`` tensor of integer class indices in ``[0, C)``.
        num_bins: number of equal-width confidence bins used for ECE.

    Returns:
        ``(ece, brier)`` as Python floats. Empty input returns ``(0.0, 0.0)``.
    """
    probabilities = probabilities.detach().float()
    labels = labels.detach().long().to(probabilities.device)
    if probabilities.size(0) == 0:
        # Nothing to calibrate; mirrors the zero-filled metrics used for empty loaders.
        return 0.0, 0.0

    conf, pred = probabilities.max(dim=1)
    acc = (pred == labels).float()
    edges = torch.linspace(0, 1, num_bins + 1, device=probabilities.device)
    # Keep ECE a tensor from the start so `.item()` works even if no bin is hit.
    ece = torch.zeros((), device=probabilities.device)
    for i in range(num_bins):
        lower_ok = conf >= edges[i] if i == 0 else conf > edges[i]
        in_bin = lower_ok & (conf <= edges[i + 1])
        if in_bin.any():
            ece += torch.abs(acc[in_bin].mean() - conf[in_bin].mean()) * in_bin.float().mean()

    # Size the one-hot from the tensor itself so it always matches the probability width.
    one_hot = F.one_hot(labels, num_classes=probabilities.size(1)).float()
    brier = torch.mean(torch.sum((probabilities - one_hot) ** 2, dim=1))
    return ece.item(), brier.item()

def plot_reliability_diagram(conf, correct, ece, model_name, num_bins=10):
    """Save a reliability diagram (per-bin accuracy vs. confidence) for one model.

    Confidences are bucketed into ``num_bins`` equal-width bins using the same
    half-open ``(lo, hi]`` convention as ``calculate_calibration_metrics``
    (first bin closed on the left). Every sample is therefore counted exactly
    once, and the bars agree with the ECE shown in the title. Each non-empty
    bar is labelled with its sample count.

    Args:
        conf: 1-D tensor/array of top-class confidences.
        correct: 1-D boolean (or 0/1) tensor/array, True where the prediction
            was correct; same length as ``conf``.
        ece: pre-computed ECE displayed in the title.
        model_name: used for the title, the colour lookup and the output filename.
        num_bins: number of confidence bins.

    Writes ``<Config.output_dir>/<model_name>_reliability.png``.
    """
    conf_t = torch.as_tensor(conf).detach().float().cpu().flatten()
    correct_t = torch.as_tensor(correct).detach().float().cpu().flatten()

    # Edges built exactly like calculate_calibration_metrics so bars match the ECE.
    edges = torch.linspace(0, 1, num_bins + 1)
    bins = np.linspace(0, 1, num_bins + 1)
    mids = (bins[:-1] + bins[1:]) / 2
    bin_acc = []
    counts = []

    for i in range(num_bins):
        above_lo = conf_t >= edges[i] if i == 0 else conf_t > edges[i]
        m = above_lo & (conf_t <= edges[i + 1])
        n_in_bin = int(m.sum())
        counts.append(n_in_bin)
        bin_acc.append(correct_t[m].mean().item() if n_in_bin > 0 else 0.0)

    # Create professional plot
    fig, ax = plt.subplots(figsize=(8, 6))

    # Perfect calibration line
    ax.plot([0,1], [0,1], 'k--', linewidth=2, alpha=0.7, label='Perfect Calibration')

    # Calibration bars with model-specific color
    color = MODEL_COLORS.get(model_name, COLORS['primary'])
    bars = ax.bar(mids, bin_acc, width=1/num_bins*0.8, alpha=0.8,
                  color=color, edgecolor='white', linewidth=1.5)

    # Add count labels on bars
    for bar, count in zip(bars, counts):
        if count > 0:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                   str(count), ha='center', va='bottom', fontweight='bold',
                   fontsize=10)

    # Styling
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1.1)
    ax.set_xlabel('Confidence', fontweight='bold')
    ax.set_ylabel('Accuracy', fontweight='bold')
    ax.set_title(f'Reliability Diagram: {model_name}\nECE = {ece:.3f}',
                fontweight='bold', pad=20)
    ax.legend(loc='upper left')
    ax.grid(True, alpha=0.3)

    # Save with high quality
    plt.tight_layout()
    plt.savefig(os.path.join(Config.output_dir, f"{model_name}_reliability.png"),
                dpi=Config.FIGURE_DPI, bbox_inches='tight')
    plt.close()

def evaluate_model(model, loader, model_name):
    """Score ``model`` on ``loader`` with classification and calibration metrics.

    A single deterministic pass is made (``model.eval()``, no gradients). From
    the softmax outputs it computes:
      * accuracy and macro precision/recall/F1;
      * mean PPDRQ, computed from the log-probabilities;
      * ECE and Brier score.
    It also saves a reliability diagram named after ``model_name``.

    Returns a dict with keys ``accuracy``, ``precision``, ``recall``,
    ``f1_score``, ``mean_ppdrq``, ``ece``, ``brier_score`` and
    ``inference_time``. ``inference_time`` is the number of seconds spent in
    the forward-pass loop. An empty loader returns the same keys, all 0.0.
    """
    metric_keys = ['accuracy','precision','recall','f1_score','mean_ppdrq','ece','brier_score','inference_time']
    model.eval()
    prob_chunks = []
    label_chunks = []
    started = time.time()
    if len(loader.dataset) == 0:
        print(f"[WARN] {model_name} loader empty.")
        return dict.fromkeys(metric_keys, 0.0)

    with torch.no_grad():
        for images, targets, _ in tqdm(loader, desc=f"Evaluating {model_name}"):
            images = images.to(Config.device)
            targets = torch.as_tensor(targets, device=Config.device)
            logits, _, _ = model(images)
            prob_chunks.append(torch.softmax(logits, dim=1).cpu())
            label_chunks.append(targets.cpu())
    finished = time.time()

    prob = torch.cat(prob_chunks, 0)
    y = torch.cat(label_chunks, 0)
    pred = prob.argmax(1)
    y_np = y.numpy()
    pred_np = pred.numpy()

    acc = accuracy_score(y_np, pred_np)
    prec, rec, f1 = (
        scorer(y_np, pred_np, average='macro', zero_division=0)
        for scorer in (precision_score, recall_score, f1_score)
    )
    mean_pp = compute_ppdrq_from_logits(torch.log(prob + 1e-8), Config.num_classes).mean().item()
    ece, brier = calculate_calibration_metrics(prob, y)
    plot_reliability_diagram(prob.max(1)[0], pred == y, ece, model_name)
    print(f"\n[{model_name}] Acc {acc:.4f} Prec {prec:.4f} Rec {rec:.4f} F1 {f1:.4f} PPDRQ {mean_pp:.4f} ECE {ece:.4f} Brier {brier:.4f}")

    return {
        'accuracy': acc, 'precision': prec, 'recall': rec, 'f1_score': f1,
        'mean_ppdrq': mean_pp, 'ece': ece, 'brier_score': brier,
        'inference_time': finished - started,
    }

# -------------------------------
# 7. Enhanced Layer-wise PPDRQ Analysis
# -------------------------------
def get_and_plot_layerwise_ppdrq(model, loader, model_name, title_suffix="", probe_seed=0):
    """Plot the mean PPDRQ of each network layer, normalised to the largest value.

    The model's third forward output is a dict of named intermediate tensors.
    These tensors are collected over ``loader``; a 4-D ``base_model_output``
    is global-average-pooled over dims (2, 3). Every layer except
    ``final_logits`` is then mapped to ``Config.num_classes`` logits by an
    *untrained*, randomly initialised linear probe. PPDRQ is computed from
    those logits with ``compute_ppdrq_from_logits``. Because the probes are
    never trained, the values reflect random projections of the features,
    so interpret them with care.

    Args:
        model: network whose forward returns ``(logits, embedding, layer_dict)``.
        loader: yields ``(images, labels, extra)`` batches; labels are unused.
        model_name: used in the plot title and the output filename.
        title_suffix: optional text appended to the plot title.
        probe_seed: seed for the random probe weights, so the figure is
            reproducible. The global CPU RNG state is restored afterwards.

    Writes ``<Config.output_dir>/<model_name>_layer_ppdrq.png``; returns None.
    """
    model.eval()
    layer_order = ['base_model_output','hidden_layer1','hidden_layer2','flatten','final_logits']
    collect = {k: [] for k in layer_order}
    if len(loader.dataset)==0:
        print(f"[WARN] No data for layer-wise PPDRQ.")
        return

    with torch.no_grad():
        for xb, _, _ in loader:
            xb = xb.to(Config.device)
            _, _, feats = model(xb)
            for k, v in feats.items():
                if k not in collect:
                    continue  # only probe the layers this plot knows about
                if k == 'base_model_output' and v.dim() == 4:
                    v = v.mean(dim=(2, 3))  # global average pool over dims (2, 3)
                collect[k].append(v.detach().flatten(1).cpu())

    mean_vals = {}
    # no_grad: probe outputs are only scored, never back-propagated.
    # fork_rng + local seed: deterministic probes without disturbing global RNG.
    with torch.no_grad(), torch.random.fork_rng(devices=[]):
        torch.default_generator.manual_seed(probe_seed)
        for k in layer_order:
            arr = collect[k]
            if not arr:
                mean_vals[k] = 0.0
                continue
            X = torch.cat(arr, 0).float()
            if k == 'final_logits':
                layer_logits = X
            else:
                # In-features taken from the data rather than hard-coded sizes.
                probe = nn.Linear(X.size(1), Config.num_classes)
                layer_logits = probe(X)
            pp = compute_ppdrq_from_logits(layer_logits, Config.num_classes)
            mean_vals[k] = pp.mean().item()

    # Create professional bar plot
    fig, ax = plt.subplots(figsize=(12, 6))

    # Normalize values for better visualization
    mmax = max(mean_vals.values()) if mean_vals else 1.0
    norm = {k: (v/mmax if mmax > 0 else 0.0) for k, v in mean_vals.items()}

    layers = list(norm.keys())
    vals = [norm[k] for k in layers]

    # Create gradient colors
    colors = plt.cm.viridis(np.linspace(0, 1, len(layers)))

    bars = ax.bar(range(len(layers)), vals, color=colors, alpha=0.8,
                  edgecolor='white', linewidth=2)

    # Add value labels on bars
    for bar, val in zip(bars, vals):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
               f"{val:.3f}", ha='center', va='bottom', fontweight='bold')

    # Styling
    ax.set_xticks(range(len(layers)))
    ax.set_xticklabels([name.replace('_', ' ').title() for name in layers], rotation=45, ha='right')
    ax.set_ylim(0, 1.1)
    ax.set_ylabel('Normalized Mean PPDRQ', fontweight='bold')
    ax.set_title(f'Layer-wise PPDRQ Analysis: {model_name} {title_suffix}',
                fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(Config.output_dir, f"{model_name}_layer_ppdrq.png"),
                dpi=Config.FIGURE_DPI, bbox_inches='tight')
    plt.close()

# -------------------------------
# 8. MC-Dropout & Ensemble (unchanged core logic)
# -------------------------------
def _set_dropout_mode_only(model: nn.Module, training: bool = True):
    """Toggle ONLY Dropout layers' mode; keep everything else (e.g., BatchNorm) unchanged."""
    for m in model.modules():
        if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout)):
            m.train(training)

def mc_dropout_predict(model, loader, num_passes=Config.mc_dropout_passes):
    """Monte-Carlo Dropout evaluation: average the softmax over stochastic passes.

    The model is put in ``eval()`` mode, which freezes BatchNorm. Only its
    ``nn.Dropout*`` modules are then switched back to training mode, so each
    of the ``num_passes`` forward passes per batch samples a new dropout mask.
    This assumes dropout is implemented with those module types; confirm
    against the model definition.

    The averaged probabilities are scored like ``evaluate_model``, and a
    reliability diagram named ``MC_Dropout`` is saved. Dropout layers are
    always returned to eval mode, even if a forward pass raises.

    Returns the same metric dict as ``evaluate_model``. An empty loader
    returns all-zero metrics. Raises ``ValueError`` if ``num_passes < 1``.
    """
    metric_keys = ['accuracy','precision','recall','f1_score','mean_ppdrq','ece','brier_score','inference_time']
    if num_passes < 1:
        raise ValueError(f"num_passes must be >= 1, got {num_passes}")
    if len(loader.dataset) == 0:
        print("[WARN] MC-Dropout loader empty.")
        return {k: 0.0 for k in metric_keys}

    # Keep BN frozen by staying in eval(); enable randomness only in Dropout.
    model.eval()
    _set_dropout_mode_only(model, True)

    all_prob = []; all_y = []
    t0 = time.time()
    try:
        with torch.no_grad():
            for xb, yb, _ in tqdm(loader, desc=f"MC-Dropout {num_passes} passes"):
                xb = xb.to(Config.device)
                yb = torch.as_tensor(yb, device=Config.device)
                # Running sum instead of storing every pass keeps memory flat in num_passes.
                prob_sum = None
                for _ in range(num_passes):
                    logits, _, _ = model(xb)
                    p = torch.softmax(logits, dim=1)
                    prob_sum = p if prob_sum is None else prob_sum + p
                all_prob.append((prob_sum / num_passes).cpu())
                all_y.append(yb.cpu())
        t1 = time.time()
    finally:
        # Restore Dropout layers to eval mode no matter how the loop exits.
        _set_dropout_mode_only(model, False)

    prob = torch.cat(all_prob, 0); y = torch.cat(all_y, 0)
    pred = prob.argmax(1)
    acc = accuracy_score(y.numpy(), pred.numpy())
    prec = precision_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    rec = recall_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    f1 = f1_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    ece, brier = calculate_calibration_metrics(prob, y)
    plot_reliability_diagram(prob.max(1)[0], (pred==y), ece, 'MC_Dropout')
    pp = compute_ppdrq_from_logits(torch.log(prob+1e-8), Config.num_classes)
    print(f"\n[MC-Dropout] Acc {acc:.4f} F1 {f1:.4f} PPDRQ {pp.mean().item():.4f} ECE {ece:.4f}")
    return {'accuracy':acc,'precision':prec,'recall':rec,'f1_score':f1,
            'mean_ppdrq':pp.mean().item(),'ece':ece,'brier_score':brier,
            'inference_time':t1-t0}

def train_ensemble_member(i, train_loader, val_loader):
    """Train deep-ensemble member ``i`` (0-based) and save its weights.

    The member is a fresh ``CustomInceptionV3`` trained with these settings:
      * ``CombinedLoss`` with ``lambda_triplet=0.0``, i.e. no triplet term;
      * AdamW on the trainable parameters;
      * a ReduceLROnPlateau scheduler;
      * ``max(10, Config.epochs // Config.num_ensemble_models)`` epochs.

    The in-memory weights left by ``train_model`` are saved as
    ``ensemble_member_<i+1>.pth`` in ``Config.output_dir``.

    Returns ``(model, training_seconds, history)``.
    """
    member_no = i + 1
    print(f"\n--- Training Ensemble Member {member_no} ---")
    member = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
    loss_fn = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=0.0)  # CE only with smoothing
    trainable = (param for param in member.parameters() if param.requires_grad)
    optimizer = optim.AdamW(trainable, lr=Config.learning_rate, weight_decay=Config.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.2, patience=Config.patience_lr_scheduler
    )
    n_epochs = max(10, Config.epochs // Config.num_ensemble_models)
    history, seconds = train_model(
        member, train_loader, val_loader, loss_fn, optimizer, scheduler,
        f"Ensemble_{member_no}", num_epochs=n_epochs,
    )
    save_path = os.path.join(Config.output_dir, f"ensemble_member_{member_no}.pth")
    torch.save(member.state_dict(), save_path)
    return member, seconds, history

def evaluate_deep_ensemble(models_list: List[nn.Module], loader):
    """Evaluate a deep ensemble by averaging its members' softmax probabilities.

    All members are switched to ``eval()`` once, up front. For each batch the
    members' softmax outputs are then averaged. The averaged probabilities are
    scored with accuracy, macro precision/recall/F1, ECE, Brier score and
    mean PPDRQ, and a reliability diagram named ``Deep_Ensemble`` is saved.

    Returns the same metric dict as ``evaluate_model``. An empty loader
    returns all-zero metrics. Raises ``ValueError`` if ``models_list`` is empty.
    """
    if not models_list:
        raise ValueError("evaluate_deep_ensemble requires at least one model.")
    if len(loader.dataset) == 0:
        print("[WARN] Deep_Ensemble loader empty.")
        return {k: 0.0 for k in ['accuracy','precision','recall','f1_score','mean_ppdrq','ece','brier_score','inference_time']}

    # Mode is loop-invariant: set it once rather than per member per batch.
    for m in models_list:
        m.eval()

    all_prob = []; all_y = []; t0 = time.time()
    with torch.no_grad():
        for xb, yb, _ in tqdm(loader, desc="Evaluating Ensemble"):
            xb = xb.to(Config.device)
            mem_probs = []
            for m in models_list:
                logits, _, _ = m(xb)
                mem_probs.append(torch.softmax(logits, dim=1))
            mean_prob = torch.stack(mem_probs, 0).mean(0)
            all_prob.append(mean_prob.cpu()); all_y.append(torch.as_tensor(yb).cpu())
    t1 = time.time()
    prob = torch.cat(all_prob, 0); y = torch.cat(all_y, 0)
    pred = prob.argmax(1)
    acc = accuracy_score(y.numpy(), pred.numpy())
    prec = precision_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    rec = recall_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    f1 = f1_score(y.numpy(), pred.numpy(), average='macro', zero_division=0)
    ece, brier = calculate_calibration_metrics(prob, y)
    plot_reliability_diagram(prob.max(1)[0], (pred==y), ece, 'Deep_Ensemble')
    pp = compute_ppdrq_from_logits(torch.log(prob+1e-8), Config.num_classes)
    print(f"\n[Ensemble] Acc {acc:.4f} F1 {f1:.4f} PPDRQ {pp.mean().item():.4f} ECE {ece:.4f}")
    return {'accuracy':acc,'precision':prec,'recall':rec,'f1_score':f1,
            'mean_ppdrq':pp.mean().item(),'ece':ece,'brier_score':brier,
            'inference_time':t1-t0}

# -------------------------------
# 9. Enhanced Paper-Ready Reporting with Professional Visualizations
# -------------------------------
def save_results_table(results: Dict[str, Dict], times: Dict[str, float], path_csv: str):
    """Write one CSV row per model with its metrics, timings and key hyperparameters.

    Training time is looked up in ``times`` under ``<name>``, then under
    ``<name>_Train``; it is NaN if neither exists. Inference time prefers
    ``results[name]['inference_time']``, falls back to ``times['<name>_Infer']``,
    and is NaN if both are missing.

    The CSV is written to ``Config.output_dir/path_csv``, and the DataFrame
    is returned.
    """
    params_used = f"λ={Config.lambda_triplet}, εp={Config.epsilon_p}, εn={Config.epsilon_n}"

    def _row(name, metrics):
        # Build one table row; column order defines the CSV header order.
        return {
            'Model': name,
            'Accuracy': metrics['accuracy'],
            'Precision': metrics['precision'],
            'Recall': metrics['recall'],
            'F1_Score': metrics['f1_score'],
            'Mean_PPDRQ': metrics['mean_ppdrq'],
            'ECE': metrics['ece'],
            'Brier_Score': metrics['brier_score'],
            'Train_Time_s': times.get(name, times.get(f"{name}_Train", np.nan)),
            'Inference_Time_s': metrics.get('inference_time', times.get(f"{name}_Infer", np.nan)),
            'Params_Used': params_used,
        }

    df = pd.DataFrame([_row(name, metrics) for name, metrics in results.items()])
    csv_path = os.path.join(Config.output_dir, path_csv)
    df.to_csv(csv_path, index=False)
    print(f"Saved enhanced results table -> {csv_path}")
    return df

def plot_enhanced_metric_bars(df: pd.DataFrame, metric: str, fname: str):
    """Save a bar chart comparing ``metric`` across models, highest value first.

    Bars are coloured per model via ``MODEL_COLORS``, falling back to
    ``COLORS['primary']``, and each bar is annotated with its value to three
    decimals. The figure is written to ``Config.output_dir/fname``.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    ranked = df.sort_values(metric, ascending=False)
    names = ranked['Model'].values
    scores = ranked[metric].values
    pretty = metric.replace('_', ' ').title()

    bar_colors = [MODEL_COLORS.get(name, COLORS['primary']) for name in names]
    bars = ax.bar(names, scores, color=bar_colors, alpha=0.8,
                  edgecolor='white', linewidth=2)

    for bar, score in zip(bars, scores):
        x_center = bar.get_x() + bar.get_width() / 2.
        y_top = bar.get_height() + max(scores) * 0.01
        ax.text(x_center, y_top, f'{score:.3f}', ha='center', va='bottom', fontweight='bold')

    ax.set_ylabel(pretty, fontweight='bold', fontsize=14)
    ax.set_title(f'Model Comparison: {pretty}', fontweight='bold', fontsize=16, pad=20)
    ax.tick_params(axis='x', rotation=25, labelsize=12)
    ax.tick_params(axis='y', labelsize=12)
    ax.grid(True, alpha=0.3, axis='y')
    ax.set_axisbelow(True)
    ax.set_facecolor('#f8f9fa')  # subtle background

    plt.tight_layout()
    out = os.path.join(Config.output_dir, fname)
    plt.savefig(out, dpi=Config.FIGURE_DPI, bbox_inches='tight')
    plt.close()
    print(f"Saved enhanced {metric} bar plot -> {out}")

def plot_enhanced_cost_performance(df: pd.DataFrame, perf_metric: str='Accuracy',
                                 cost_metric: str='Inference_Time_s', fname: str='cost_perf.png'):
    """Save a scatter plot of ``perf_metric`` (y) against ``cost_metric`` (x).

    Each model is drawn as one labelled point, coloured via ``MODEL_COLORS``
    with a fallback to ``COLORS['primary']``. The legend sits outside the
    axes. The figure is written to ``Config.output_dir/fname``.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    costs = df[cost_metric].values
    perfs = df[perf_metric].values
    names = df['Model'].values

    for cost, perf, name in zip(costs, perfs, names):
        shade = MODEL_COLORS.get(name, COLORS['primary'])
        ax.scatter(cost, perf, c=shade, s=200, alpha=0.8,
                   edgecolors='white', linewidth=2, label=name)
        ax.annotate(name, (cost, perf), xytext=(5, 5),
                    textcoords='offset points', fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.8))

    x_title = cost_metric.replace('_', ' ').title()
    y_title = perf_metric.replace('_', ' ').title()
    ax.set_xlabel(x_title, fontweight='bold', fontsize=14)
    ax.set_ylabel(y_title, fontweight='bold', fontsize=14)
    ax.set_title(f'Performance vs Cost Trade-off\n({perf_metric} vs {cost_metric})',
                 fontweight='bold', fontsize=16, pad=20)
    ax.grid(True, alpha=0.3)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    out = os.path.join(Config.output_dir, fname)
    plt.savefig(out, dpi=Config.FIGURE_DPI, bbox_inches='tight')
    plt.close()
    print(f"Saved enhanced cost-performance plot -> {out}")

def enhanced_lambda_sensitivity_sweep(lambdas=(0.0, 0.01, 0.05, 0.1, 0.2, 0.3, 0.5, 1.0)):
    """Sweep the triplet-loss weight λ and report its effect on accuracy and calibration.

    For each λ a fresh ``CustomInceptionV3`` is trained for a shortened
    schedule of ``max(5, Config.epochs // 6)`` epochs. Training uses the
    module-level (global) ``train_loader`` and ``val_loader``; the model is
    then evaluated on the module-level ``test_loader``.

    Outputs:
      * a 2×2 figure of accuracy, ECE, F1 and Brier score vs λ;
      * a CSV of per-λ results;
      * a printed line naming the selected λ.

    The "optimal" λ is the one with the best final-epoch *validation*
    accuracy recorded by ``train_model``. Those are the same weights that are
    then scored on the test set, so the test split is never used for
    selection. The chosen λ's test accuracy is reported alongside.

    Returns:
        DataFrame with one row per λ and columns Lambda, Accuracy, F1_Score,
        ECE, Brier_Score, Train_Time, Val_Accuracy. It is empty when no λ
        values are given.
    """
    lambda_vals = list(lambdas)
    print("\n" + "="*60)
    print("LAMBDA SENSITIVITY ANALYSIS")
    print("="*60)
    print("Testing different values of λ (triplet loss weight)")
    print(f"λ values to test: {lambdas}")
    print(f"Fixed parameters: εp={Config.epsilon_p}, εn={Config.epsilon_n}")
    print("="*60)
    if not lambda_vals:
        print("[WARN] No λ values supplied; nothing to sweep.")
        return pd.DataFrame()

    results_data = []
    acc_list, ece_list, f1_list, brier_list, val_acc_list = [], [], [], [], []

    for lam in lambda_vals:
        print(f"\n[λ-sweep] Training with λ={lam}")
        model = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
        crit = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=lam)
        opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                          lr=Config.learning_rate, weight_decay=Config.weight_decay)
        sch = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.2,
                                                   patience=Config.patience_lr_scheduler)

        # Quick training for sensitivity analysis
        epochs = max(5, Config.epochs//6)
        hist, train_time = train_model(model, train_loader, val_loader, crit, opt, sch,
                                       f'Lambda_{lam}', num_epochs=epochs)
        # Selection metric: last-epoch validation accuracy (same weights as the test eval).
        val_acc = float((hist.get('val_acc') or [0.0])[-1])
        res = evaluate_model(model, test_loader, f'Lambda_{lam}')

        acc_list.append(res['accuracy'])
        ece_list.append(res['ece'])
        f1_list.append(res['f1_score'])
        brier_list.append(res['brier_score'])
        val_acc_list.append(val_acc)

        results_data.append({
            'Lambda': lam,
            'Accuracy': res['accuracy'],
            'F1_Score': res['f1_score'],
            'ECE': res['ece'],
            'Brier_Score': res['brier_score'],
            'Train_Time': train_time,
            'Val_Accuracy': val_acc,
        })

    # Comprehensive 2x2 sensitivity plot: one panel per test metric.
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    panels = [
        (acc_list, 'o', COLORS['primary'], 'Accuracy', 'Accuracy', 'Accuracy vs λ'),
        (ece_list, 's', COLORS['secondary'], 'ECE', 'Expected Calibration Error', 'ECE vs λ'),
        (f1_list, '^', COLORS['accent'], 'F1 Score', 'F1 Score', 'F1 Score vs λ'),
        (brier_list, 'd', COLORS['success'], 'Brier Score', 'Brier Score', 'Brier Score vs λ'),
    ]
    for ax, (series, marker, color, label, ylabel, title) in zip(axes.flat, panels):
        ax.plot(lambda_vals, series, marker=marker, linewidth=3, markersize=8,
                color=color, label=label)
        ax.set_xlabel('λ (Triplet Loss Weight)', fontweight='bold')
        ax.set_ylabel(ylabel, fontweight='bold')
        ax.set_title(title, fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.legend()

    plt.suptitle('Comprehensive λ Sensitivity Analysis\n' +
                 f'εp={Config.epsilon_p}, εn={Config.epsilon_n}',
                 fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()

    out = os.path.join(Config.output_dir, 'enhanced_lambda_sensitivity.png')
    plt.savefig(out, dpi=Config.FIGURE_DPI, bbox_inches='tight')
    plt.close()
    print(f"Saved enhanced λ sensitivity plot -> {out}")

    # Save detailed results
    sensitivity_df = pd.DataFrame(results_data)
    sensitivity_csv = os.path.join(Config.output_dir, 'lambda_sensitivity_results.csv')
    sensitivity_df.to_csv(sensitivity_csv, index=False)
    print(f"Saved λ sensitivity results -> {sensitivity_csv}")

    # Select λ on validation data only; report its test accuracy for reference.
    best_idx = int(np.argmax(val_acc_list))
    best_lambda = lambda_vals[best_idx]
    print(f"\nOptimal λ (selected on validation): {best_lambda} "
          f"(Val Accuracy: {val_acc_list[best_idx]:.4f}, Test Accuracy: {acc_list[best_idx]:.4f})")

    return sensitivity_df

def create_methodological_transparency_report():
    """Write a plain-text summary of the run's hyperparameters and setup.

    Numeric settings are read from ``Config`` at call time. The architecture,
    dropout-rate, calibration-bin and augmentation lines are fixed text.
    NOTE(review): those fixed lines are not derived from the model or the
    transforms, so keep them in sync by hand when either changes.

    The report contains non-ASCII symbols (λ, ε, ≥, →, °), so it is written
    as UTF-8 explicitly. Relying on the platform default encoding can raise
    ``UnicodeEncodeError``, e.g. with cp1252 on Windows.

    Writes ``<Config.output_dir>/methodological_transparency_report.txt``.
    """
    report_content = f"""
# METHODOLOGICAL TRANSPARENCY REPORT
Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}

## Hyperparameter Settings

### PPDRQ-Specific Parameters:
- λ (lambda_triplet): {Config.lambda_triplet}
  * Weight for triplet loss component
  * Controls balance between classification and embedding quality
  
- εp (epsilon_p): {Config.epsilon_p}
  * Threshold for positive sample selection in triplet loss
  * Samples with PPDRQ difference < εp considered similar
  
- εn (epsilon_n): {Config.epsilon_n}
  * Threshold for negative sample selection in triplet loss
  * Samples with PPDRQ difference ≥ εn considered dissimilar

### Training Parameters:
- Learning Rate: {Config.learning_rate}
- Weight Decay: {Config.weight_decay}
- Batch Size: {Config.batch_size}
- Max Epochs: {Config.epochs}
- Label Smoothing: {Config.label_smoothing}
- Gradient Clipping: {Config.grad_clip}

### Regularization:
- Mixup Alpha: {Config.mixup_alpha}
- CutMix Alpha: {Config.cutmix_alpha}
- Dropout Rate: 0.6 (in CustomInceptionV3)

### Evaluation Parameters:
- MC-Dropout Passes: {Config.mc_dropout_passes}
- Ensemble Members: {Config.num_ensemble_models}
- Calibration Bins: 10

## Model Architecture:
- Base: InceptionV3 (pretrained on ImageNet)
- Fine-tuning: Last blocks only (Mixed_7x layers)
- Custom Head: 2048 → BN → Dropout → Linear(num_classes)
- Embedding Layers: 2048 → 1024 → 512

## Data Augmentation:
- RandomResizedCrop(scale=(0.8, 1.0), ratio=(0.9, 1.1))
- RandomRotation(10°)
- RandomHorizontalFlip(p=0.5)
- GaussianBlur(kernel_size=3)
- RandomErasing(p=0.25)

## Computational Environment:
- Device: {Config.device}
- Image Size: {Config.image_size}x{Config.image_size}
- Number of Classes: {Config.num_classes}
"""

    report_path = os.path.join(Config.output_dir, 'methodological_transparency_report.txt')
    # Explicit UTF-8: the report contains non-ASCII symbols.
    with open(report_path, 'w', encoding='utf-8') as f:
        f.write(report_content)
    print(f"Saved methodological transparency report -> {report_path}")

def create_enhanced_timing_analysis(results: Dict, times: Dict):
    """Build a per-model timing table and a 2x2 timing figure.

    Training time comes from ``times[<name>]`` or ``times[<name>_Train]``
    (0 if absent). Inference time comes from ``results[name]['inference_time']``
    or ``times[<name>_Infer]`` (0 if absent). ``Time_per_Accuracy`` is total
    time divided by ``max(accuracy, 0.001)``.

    The figure's four panels are:
      * training-time bars;
      * inference-time bars;
      * time-per-accuracy bars;
      * an accuracy vs. total-time scatter.

    Writes ``enhanced_timing_analysis.png`` and
    ``comprehensive_timing_analysis.csv`` to ``Config.output_dir`` and returns
    the timing DataFrame.
    """
    records = []
    for name, metrics in results.items():
        t_train = times.get(name, times.get(f"{name}_Train", 0))
        t_infer = metrics.get('inference_time', times.get(f"{name}_Infer", 0))
        acc = metrics.get('accuracy', 0)
        t_total = t_train + t_infer
        records.append({
            'Model': name,
            'Training_Time_s': t_train,
            'Inference_Time_s': t_infer,
            'Total_Time_s': t_total,
            'Accuracy': acc,
            'Time_per_Accuracy': t_total / max(acc, 0.001),
        })

    timing_df = pd.DataFrame(records)

    fig, grid = plt.subplots(2, 2, figsize=(16, 12))
    (ax_train, ax_infer), (ax_eff, ax_scatter) = grid

    models = timing_df['Model']
    colors = [MODEL_COLORS.get(name, COLORS['primary']) for name in models]

    # Three bar panels share the same layout; only column, labels and text format differ.
    bar_panels = [
        (ax_train, 'Training_Time_s', 'Training Time (seconds)', 'Training Time Comparison', '{:.1f}s'),
        (ax_infer, 'Inference_Time_s', 'Inference Time (seconds)', 'Inference Time Comparison', '{:.1f}s'),
        (ax_eff, 'Time_per_Accuracy', 'Time per Accuracy Unit (s)', 'Model Efficiency (Lower is Better)', '{:.1f}'),
    ]
    for ax, column, ylabel, title, fmt in bar_panels:
        series = timing_df[column]
        bars = ax.bar(models, series, color=colors, alpha=0.8)
        ax.set_ylabel(ylabel, fontweight='bold')
        ax.set_title(title, fontweight='bold')
        ax.tick_params(axis='x', rotation=45)
        pad = max(series) * 0.01
        for bar, val in zip(bars, series):
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + pad,
                    fmt.format(val), ha='center', va='bottom', fontweight='bold')

    # Scatter: accuracy against total computational cost.
    for i, name in enumerate(models):
        row = timing_df.iloc[i]
        point = (row['Total_Time_s'], row['Accuracy'])
        ax_scatter.scatter(point[0], point[1], c=colors[i], s=200, alpha=0.8,
                           edgecolors='white', linewidth=2)
        ax_scatter.annotate(name, point, xytext=(5, 5), textcoords='offset points',
                            fontweight='bold',
                            bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.8))

    ax_scatter.set_xlabel('Total Time (seconds)', fontweight='bold')
    ax_scatter.set_ylabel('Accuracy', fontweight='bold')
    ax_scatter.set_title('Accuracy vs Total Computational Cost', fontweight='bold')
    ax_scatter.grid(True, alpha=0.3)

    plt.suptitle('Comprehensive Timing Analysis', fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()

    out = os.path.join(Config.output_dir, 'enhanced_timing_analysis.png')
    plt.savefig(out, dpi=Config.FIGURE_DPI, bbox_inches='tight')
    plt.close()
    print(f"Saved enhanced timing analysis -> {out}")

    timing_csv = os.path.join(Config.output_dir, 'comprehensive_timing_analysis.csv')
    timing_df.to_csv(timing_csv, index=False)
    print(f"Saved timing analysis data -> {timing_csv}")

    return timing_df

# -------------------------------
# 10. Enhanced Visualization Functions
# -------------------------------
# Display names for class indices 0, 1, 2 (same order as the "AD / MCI / NC" note on Config.num_classes).
# NOTE(review): assumes the dataset assigns indices in this order; confirm against the Dataset class.
CLASS_NAMES = ["AD", "MCI", "NC"]

def collect_predictions(model, loader):
    """Run ``model`` over ``loader`` in eval mode and gather its outputs as NumPy arrays.

    Returns ``(y_true, y_pred, confidence, probabilities)``:
      * ``y_true``: the concatenated labels;
      * ``y_pred``: the softmax argmax;
      * ``confidence``: the top softmax value per sample;
      * ``probabilities``: the full softmax rows, concatenated along dim 0.
    """
    truths, preds, confs, prob_rows = [], [], [], []
    model.eval()
    with torch.no_grad():
        for images, targets, _ in loader:
            images = images.to(Config.device)
            targets = torch.as_tensor(targets, device=Config.device)
            logits, _, _ = model(images)
            softmax = torch.softmax(logits, dim=1)
            truths.append(targets.cpu())
            preds.append(softmax.argmax(1).cpu())
            confs.append(softmax.max(1)[0].cpu())
            prob_rows.append(softmax.cpu())

    y_true = torch.cat(truths).numpy()
    y_pred = torch.cat(preds).numpy()
    confidence = torch.cat(confs).numpy()
    probabilities = torch.cat(prob_rows, 0).numpy()
    return y_true, y_pred, confidence, probabilities

def plot_enhanced_confusion(y_true, y_pred, classes, title, fname, normalize=True):
    """Save a confusion-matrix heat map for ``classes`` (index i ↔ ``classes[i]``).

    With ``normalize=True`` each row is divided by its true-class count
    (rows with no samples stay all-zero) and cells show 3 decimals; otherwise
    cells show raw integer counts. Cell text switches to white above half the
    matrix maximum. The figure is written to ``Config.output_dir/fname``.
    """
    n_cls = len(classes)
    matrix = confusion_matrix(y_true, y_pred, labels=list(range(n_cls)))
    if normalize:
        row_totals = matrix.sum(axis=1, keepdims=True).clip(min=1)
        matrix = matrix.astype('float') / row_totals

    fig, ax = plt.subplots(figsize=(8, 6))

    heat = ax.imshow(matrix, interpolation='nearest', cmap=plt.cm.Blues)
    colorbar = plt.colorbar(heat, ax=ax)
    colorbar.ax.tick_params(labelsize=12)

    ax.set_title(title, fontweight='bold', fontsize=16, pad=20)
    ticks = np.arange(n_cls)
    ax.set_xticks(ticks)
    ax.set_xticklabels(classes, fontsize=12, fontweight='bold')
    ax.set_yticks(ticks)
    ax.set_yticklabels(classes, fontsize=12, fontweight='bold')

    half_max = matrix.max() / 2.
    n_rows, n_cols = matrix.shape
    for row in range(n_rows):
        for col in range(n_cols):
            cell = matrix[row, col]
            label = f"{cell:.3f}" if normalize else f"{int(cell)}"
            ax.text(col, row, label, ha="center", va="center", fontweight='bold',
                    color="white" if cell > half_max else "black", fontsize=14)

    ax.set_ylabel('True Label', fontweight='bold', fontsize=14)
    ax.set_xlabel('Predicted Label', fontweight='bold', fontsize=14)

    plt.tight_layout()
    out = os.path.join(Config.output_dir, fname)
    plt.savefig(out, dpi=Config.FIGURE_DPI, bbox_inches='tight')
    plt.close()
    print(f"Saved enhanced confusion matrix -> {out}")

def export_per_class_report(y_true, y_pred, classes, fname_csv):
    """Write scikit-learn's per-class precision/recall/F1/support table to CSV.

    ``labels`` is pinned to ``range(len(classes))``, so the report always has
    one row per entry of ``classes``. Without it, scikit-learn raises
    ``ValueError`` whenever a split contains, or the model predicts, fewer
    distinct classes than there are ``target_names``. Absent classes get
    zero scores (``zero_division=0``) and support 0.

    The CSV is written to ``Config.output_dir/fname_csv``.
    """
    rep = classification_report(
        y_true, y_pred,
        labels=list(range(len(classes))),
        target_names=classes,
        output_dict=True,
        zero_division=0,
    )
    df = pd.DataFrame(rep).transpose()
    out = os.path.join(Config.output_dir, fname_csv)
    df.to_csv(out)
    print(f"Saved per-class report -> {out}")

def enhanced_reliability_grid(models_dict, loader, fname='enhanced_reliability_grid.png', num_bins=10):
    """Save side-by-side reliability diagrams, one column per model.

    Each model is run once over ``loader`` in eval mode. Its top-class
    confidences are bucketed with the same ``(lo, hi]`` bins as
    ``calculate_calibration_metrics`` (first bin closed on the left), so no
    sample is double-counted and the bars match the ECE in each title.

    Args:
        models_dict: mapping of display name to model; each model's forward
            returns ``(logits, embedding, layer_dict)``.
        loader: yields ``(images, labels, extra)`` batches.
        fname: output filename inside ``Config.output_dir``.
        num_bins: number of confidence bins.

    Does nothing beyond a warning if ``models_dict`` or the loader is empty.
    """
    if not models_dict:
        print("[WARN] No models given for reliability grid.")
        return
    if len(loader.dataset) == 0:
        print("[WARN] No data for reliability grid.")
        return

    cols = len(models_dict)
    # squeeze=False always yields a 2-D array, so no special case for one column.
    fig, axes = plt.subplots(1, cols, figsize=(6*cols, 6), squeeze=False)
    axes = axes[0]

    edges = torch.linspace(0, 1, num_bins + 1)  # same edges as calculate_calibration_metrics
    bins = np.linspace(0, 1, num_bins + 1)
    mids = (bins[:-1] + bins[1:]) / 2

    for idx, (name, model) in enumerate(models_dict.items()):
        model.eval()
        all_prob = []
        all_y = []

        with torch.no_grad():
            for xb, yb, _ in loader:
                xb = xb.to(Config.device)
                yb = torch.as_tensor(yb, device=Config.device)
                logits, _, _ = model(xb)
                all_prob.append(torch.softmax(logits, dim=1).cpu())
                all_y.append(yb.cpu())

        prob = torch.cat(all_prob, 0)
        y = torch.cat(all_y, 0)
        pred = prob.argmax(1)
        conf = prob.max(1)[0]
        correct = (pred == y).float()

        # Compute ECE for title
        ece, _ = calculate_calibration_metrics(prob, y)

        bin_acc = []
        counts = []
        for i in range(num_bins):
            above_lo = conf >= edges[i] if i == 0 else conf > edges[i]
            m = above_lo & (conf <= edges[i + 1])
            n_in_bin = int(m.sum())
            counts.append(n_in_bin)
            bin_acc.append(correct[m].mean().item() if n_in_bin > 0 else 0.0)

        ax = axes[idx]

        # Perfect calibration line
        ax.plot([0, 1], [0, 1], 'k--', linewidth=2, alpha=0.7, label='Perfect Calibration')

        # Bars with model-specific color
        color = MODEL_COLORS.get(name, COLORS['primary'])
        bars = ax.bar(mids, bin_acc, width=1/num_bins*0.8, alpha=0.8,
                      color=color, edgecolor='white', linewidth=1.5)

        # Add count labels
        for bar, count in zip(bars, counts):
            if count > 0:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height + 0.02,
                        str(count), ha='center', va='bottom', fontweight='bold', fontsize=9)

        ax.set_title(f'{name}\nECE = {ece:.3f}', fontweight='bold', fontsize=12)
        ax.set_xlabel('Confidence', fontweight='bold')
        if idx == 0:
            ax.set_ylabel('Accuracy', fontweight='bold')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1.1)
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper left', fontsize=8)

    plt.suptitle('Reliability Comparison Across Models', fontsize=16, fontweight='bold', y=0.98)
    plt.tight_layout()

    out = os.path.join(Config.output_dir, fname)
    plt.savefig(out, dpi=Config.FIGURE_DPI, bbox_inches='tight')
    plt.close()
    print(f"Saved enhanced reliability grid -> {out}")

def plot_enhanced_training_curves(histories_dict, fname='enhanced_training_curves.png'):
    """Plot train/validation curves for several models plus a best-accuracy summary.

    Draws a 2x2 figure:
      * accuracy, loss and PPDRQ curves (solid = train, dashed = validation,
        one colour per model from ``MODEL_COLORS``);
      * a bar chart of each model's best validation accuracy.
    The figure is saved to ``os.path.join(Config.output_dir, fname)``.

    Args:
        histories_dict: Mapping of model name -> history dict. A history may
            hold per-epoch lists under 'train_acc'/'val_acc',
            'train_loss'/'val_loss' and 'train_ppdrq'/'val_ppdrq'; a metric
            pair is drawn only when both keys are present. An optional
            'epochs_trained' entry is read for the summary rows.
        fname: Output file name inside ``Config.output_dir``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    try:
        curve_specs = [
            ('train_acc', 'val_acc', 'Accuracy', axes[0, 0]),
            ('train_loss', 'val_loss', 'Loss', axes[0, 1]),
            ('train_ppdrq', 'val_ppdrq', 'PPDRQ', axes[1, 0]),
        ]

        for train_key, val_key, title, ax in curve_specs:
            drew_any = False
            for model_name, history in histories_dict.items():
                if train_key not in history or val_key not in history:
                    continue
                color = MODEL_COLORS.get(model_name, COLORS['primary'])
                train_vals = history[train_key]
                val_vals = history[val_key]
                # Each series gets its own epoch axis so a train/val length
                # mismatch does not make ax.plot raise a ValueError.
                ax.plot(range(1, len(train_vals) + 1), train_vals, '-', color=color,
                        alpha=0.7, linewidth=2, label=f'{model_name} Train')
                ax.plot(range(1, len(val_vals) + 1), val_vals, '--', color=color,
                        alpha=0.9, linewidth=2, label=f'{model_name} Val')
                drew_any = True

            ax.set_xlabel('Epoch', fontweight='bold')
            ax.set_ylabel(title, fontweight='bold')
            ax.set_title(f'{title} Curves', fontweight='bold')
            if drew_any:
                # Skip the legend when nothing was plotted; otherwise it would be empty.
                ax.legend()
            ax.grid(True, alpha=0.3)

        # Fourth panel: best validation accuracy per model.
        summary_ax = axes[1, 1]
        summary_rows = []
        for model_name, history in histories_dict.items():
            if 'val_acc' not in history:
                continue
            val_acc = history['val_acc']
            summary_rows.append({
                'Model': model_name,
                'Final_Val_Acc': val_acc[-1] if val_acc else 0,
                'Best_Val_Acc': max(val_acc) if val_acc else 0,
                'Epochs_Trained': history.get('epochs_trained', len(val_acc)),
            })

        if summary_rows:
            summary_df = pd.DataFrame(summary_rows)
            # Named 'model_names' so the file-level torchvision `models` import
            # is not shadowed.
            model_names = list(summary_df['Model'])
            bar_colors = [MODEL_COLORS.get(n, COLORS['primary']) for n in model_names]
            best_accs = summary_df['Best_Val_Acc']

            bars = summary_ax.bar(model_names, best_accs, color=bar_colors, alpha=0.8)
            summary_ax.set_ylabel('Best Validation Accuracy', fontweight='bold')
            summary_ax.set_title('Training Summary', fontweight='bold')
            summary_ax.tick_params(axis='x', rotation=45)

            for bar, val in zip(bars, best_accs):
                summary_ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                                f'{val:.3f}', ha='center', va='bottom', fontweight='bold')

        fig.suptitle('Enhanced Training Analysis', fontsize=16, fontweight='bold', y=0.98)
        fig.tight_layout()

        out = os.path.join(Config.output_dir, fname)
        fig.savefig(out, dpi=Config.FIGURE_DPI, bbox_inches='tight')
        print(f"Saved enhanced training curves -> {out}")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

# -------------------------------
# 11. Main Execution with Enhanced Features
# -------------------------------
def _make_optimizer_and_scheduler(model):
    """Create the AdamW optimizer and plateau LR scheduler shared by each single model.

    Only parameters with ``requires_grad`` set are optimized, using
    ``Config.learning_rate`` and ``Config.weight_decay``. The scheduler is a
    ``ReduceLROnPlateau`` with ``mode='min'``, ``factor=0.2`` and
    ``patience=Config.patience_lr_scheduler``.

    Returns:
        ``(optimizer, scheduler)``.
    """
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                            lr=Config.learning_rate, weight_decay=Config.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2,
                                                     patience=Config.patience_lr_scheduler)
    return optimizer, scheduler


def _print_results_table(results_dict, times_dict):
    """Print a fixed-width metrics table with one row per entry of *results_dict*.

    Each result dict must provide 'accuracy', 'precision', 'recall',
    'f1_score', 'mean_ppdrq', 'ece' and 'brier_score'.

    Times are looked up as follows; any missing time prints as 0.0.
      * Training time: *times_dict* under ``key`` or ``f"{key}_Train"``.
      * Inference time: the result's 'inference_time', else
        ``times_dict[f"{key}_Infer"]``.
    """
    print(f"{'Model':<15}{'Acc':>8}{'Prec':>8}{'Rec':>8}{'F1':>8}{'PPDRQ':>9}{'ECE':>8}{'Brier':>8}{'Train(s)':>10}{'Infer(s)':>10}")
    print("-" * 95)
    for k, v in results_dict.items():
        tr = times_dict.get(k, times_dict.get(f"{k}_Train", 0.0))
        inf = v.get('inference_time', times_dict.get(f"{k}_Infer", 0.0))
        print(f"{k:<15}{v['accuracy']:>8.3f}{v['precision']:>8.3f}{v['recall']:>8.3f}{v['f1_score']:>8.3f}{v['mean_ppdrq']:>9.3f}{v['ece']:>8.3f}{v['brier_score']:>8.3f}{tr:>10.1f}{inf:>10.1f}")


# Scalar metrics written to the unseen-domain CSV (the same keys the tables print).
_UNSEEN_CSV_METRICS = ('accuracy', 'precision', 'recall', 'f1_score',
                       'mean_ppdrq', 'ece', 'brier_score')


if __name__ == '__main__':
    # Create methodological transparency report
    create_methodological_transparency_report()

    results: Dict[str, Dict] = {}
    times: Dict[str, float] = {}
    histories: Dict[str, Dict] = {}

    # Baseline (with label smoothing + anti-overfitting)
    print("\n" + "="*60)
    print("TRAINING BASELINE MODEL")
    print("="*60)
    baseline = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
    crit_base = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=0.0)  # CE only
    opt_base, sch_base = _make_optimizer_and_scheduler(baseline)
    hist_b, t_b = train_model(baseline, train_loader, val_loader, crit_base, opt_base, sch_base, 'Baseline')
    res_b = evaluate_model(baseline, test_loader, 'Baseline')
    results['Baseline'] = res_b
    times['Baseline'] = t_b
    histories['Baseline'] = hist_b

    # PPDRQ-weighted CE only
    print("\n" + "="*60)
    print("TRAINING PPDRQ-WEIGHTED MODEL")
    print("="*60)
    class PPDRQWeightedCE(nn.Module):
        """Label-smoothed cross-entropy with per-sample weights ``1 + (1 - PPDRQ)``.

        The forward pass returns a 4-tuple
        ``(total_loss, ce_loss, triplet_loss, ppdrq)``. The triplet term is
        always 0 here, and ``total_loss`` and ``ce_loss`` are the same tensor.
        """
        def __init__(self, num_classes):
            super().__init__()
            self.num_classes = num_classes
        def forward(self, logits, _emb, labels, **kwargs):
            pp = compute_ppdrq_from_logits(logits, self.num_classes)
            per_ce = F.cross_entropy(logits, labels, reduction='none', label_smoothing=Config.label_smoothing)
            # NOTE(review): pp is not detached, so gradients also flow through
            # the weight term. Confirm this is intended rather than a fixed
            # per-sample weighting.
            w = 1 + (1 - pp)
            loss = torch.mean(per_ce * w)
            return loss, loss, torch.tensor(0.0, device=logits.device), pp

    ppdrq_model = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
    crit_pp = PPDRQWeightedCE(num_classes=Config.num_classes)
    opt_pp, sch_pp = _make_optimizer_and_scheduler(ppdrq_model)
    hist_p, t_p = train_model(ppdrq_model, train_loader, val_loader, crit_pp, opt_pp, sch_pp, 'PPDRQ_CE')
    res_p = evaluate_model(ppdrq_model, test_loader, 'PPDRQ_CE')
    results['PPDRQ_CE'] = res_p
    times['PPDRQ_CE'] = t_p
    histories['PPDRQ_CE'] = hist_p

    # Triplet model
    print("\n" + "="*60)
    print("TRAINING TRIPLET MODEL")
    print("="*60)
    triplet_model = CustomInceptionV3(num_classes=Config.num_classes).to(Config.device)
    crit_trip = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=Config.lambda_triplet)
    opt_trip, sch_trip = _make_optimizer_and_scheduler(triplet_model)
    hist_t, t_t = train_model(triplet_model, train_loader, val_loader, crit_trip, opt_trip, sch_trip, 'Triplet')
    res_t = evaluate_model(triplet_model, test_loader, 'Triplet')
    results['Triplet'] = res_t
    times['Triplet'] = t_t
    histories['Triplet'] = hist_t

    # MC-Dropout
    print("\n" + "="*60)
    print("TRAINING MC-DROPOUT MODEL")
    print("="*60)
    mc_model = CustomInceptionV3(num_classes=Config.num_classes, dropout_rate=0.6).to(Config.device)
    crit_mc = CombinedLoss(num_classes=Config.num_classes, lambda_triplet=0.0)
    opt_mc, sch_mc = _make_optimizer_and_scheduler(mc_model)
    hist_m, t_m = train_model(mc_model, train_loader, val_loader, crit_mc, opt_mc, sch_mc, 'MC_Dropout')
    res_m = mc_dropout_predict(mc_model, test_loader, num_passes=Config.mc_dropout_passes)
    results['MC_Dropout'] = res_m
    # Training and inference times are stored separately so the summary
    # table can find them via the "_Train" / "_Infer" suffix fallbacks.
    times['MC_Dropout_Train'] = t_m
    times['MC_Dropout_Infer'] = res_m['inference_time']
    histories['MC_Dropout'] = hist_m

    # Deep Ensemble
    print("\n" + "="*60)
    print("TRAINING DEEP ENSEMBLE")
    print("="*60)
    members = []
    t_mem = []
    for i in range(Config.num_ensemble_models):
        # Per-member histories are not added to `histories`.
        m, t, hist = train_ensemble_member(i, train_loader, val_loader)
        members.append(m)
        t_mem.append(t)
    res_e = evaluate_deep_ensemble(members, test_loader)
    results['Ensemble'] = res_e
    times['Ensemble_Train'] = sum(t_mem)
    times['Ensemble_Infer'] = res_e['inference_time']

    # Enhanced Summary
    print("\n" + "="*80)
    print("COMPREHENSIVE RESULTS SUMMARY")
    print("="*80)
    _print_results_table(results, times)

    # Enhanced visualizations and reports
    print("\n" + "="*60)
    print("GENERATING ENHANCED VISUALIZATIONS")
    print("="*60)

    # Plot enhanced training curves
    plot_enhanced_training_curves(histories, 'enhanced_training_curves.png')

    # Layer-wise analysis for main models
    for model_name, model in [('Baseline', baseline), ('PPDRQ_CE', ppdrq_model), ('Triplet', triplet_model)]:
        get_and_plot_layerwise_ppdrq(model, test_loader, model_name, "- Test Set")

    # Unified enhanced table + plots
    df = save_results_table(results, times, "enhanced_results_all_models.csv")
    plot_enhanced_metric_bars(df, "Accuracy", "enhanced_accuracy_comparison.png")
    plot_enhanced_metric_bars(df, "F1_Score", "enhanced_f1_comparison.png")
    plot_enhanced_metric_bars(df, "ECE", "enhanced_ece_comparison.png")
    plot_enhanced_metric_bars(df, "Brier_Score", "enhanced_brier_comparison.png")

    # Enhanced cost-performance analysis
    plot_enhanced_cost_performance(df, perf_metric="Accuracy", cost_metric="Inference_Time_s",
                                   fname="enhanced_cost_perf_analysis.png")

    # Comprehensive timing analysis
    timing_df = create_enhanced_timing_analysis(results, times)

    # Enhanced confusion matrices and reports
    print("\nGenerating enhanced confusion matrices and reports...")
    # NOTE(review): the MC_Dropout entry is the raw model. If
    # collect_predictions runs a single forward pass, these plots do not
    # reflect MC averaging. Confirm.
    model_objs = {
        'Baseline': baseline,
        'PPDRQ_CE': ppdrq_model,
        'Triplet': triplet_model,
        'MC_Dropout': mc_model
    }

    # Ensemble wrapper
    class EnsembleWrapper(nn.Module):
        """Expose the ensemble members through the single-model forward interface.

        Returns ``(logits, embedding, extras)``. ``logits`` is the mean of the
        members' logits. The embedding is a zero tensor of shape
        ``(batch, 512)``, because there is no shared ensemble embedding.
        NOTE(review): this averages logits, not probabilities. Confirm it
        matches evaluate_deep_ensemble, so the tables and the plots describe
        the same predictor.
        """
        def __init__(self, members):
            super().__init__()
            # ModuleList registers the members as submodules, so .eval(),
            # .to() and .parameters() on the wrapper reach them.
            self.members = nn.ModuleList(members)
        def forward(self, x):
            outs = []
            for m in self.members:
                m.eval()
                lo, _, _ = m(x)
                outs.append(lo.unsqueeze(0))
            logits = torch.mean(torch.cat(outs, 0), dim=0)
            return logits, torch.zeros(x.size(0), 512, device=logits.device), {'final_logits': logits}

    ens_wrapper = EnsembleWrapper(members)
    model_objs['Ensemble'] = ens_wrapper

    for name, m in model_objs.items():
        y_t, y_p, confs, prob = collect_predictions(m, test_loader)
        plot_enhanced_confusion(y_t, y_p, CLASS_NAMES, f"Confusion Matrix: {name}",
                                f"enhanced_cm_{name}.png", normalize=True)
        export_per_class_report(y_t, y_p, CLASS_NAMES, f"enhanced_per_class_{name}.csv")

    # Enhanced reliability grid
    enhanced_reliability_grid(model_objs, test_loader, fname='enhanced_reliability_grid.png', num_bins=10)

    # Enhanced λ sensitivity analysis
    print("\n" + "="*60)
    print("CONDUCTING ENHANCED λ SENSITIVITY ANALYSIS")
    print("="*60)
    sensitivity_df = enhanced_lambda_sensitivity_sweep()

    # Unseen domain evaluation if available
    if len(unseen_dataset) > 0:
        print("\n" + "="*60)
        print("EVALUATING ON UNSEEN DOMAIN")
        print("="*60)
        # Load best weights if present.
        # NOTE(review): the test-set metrics above were computed before this
        # reload. If train_model does not already restore the best weights,
        # the test and unseen numbers come from different checkpoints.
        for tag, model in [('Baseline', baseline), ('PPDRQ_CE', ppdrq_model),
                           ('Triplet', triplet_model), ('MC_Dropout', mc_model)]:
            pth = os.path.join(Config.output_dir, f"{tag}_best.pth")
            if os.path.exists(pth):
                model.load_state_dict(torch.load(pth, map_location=Config.device))

        # Evaluate on unseen data. The res_*_u names are kept, and the results
        # are also collected so they are reported instead of discarded.
        res_b_u = evaluate_model(baseline, unseen_loader, 'Baseline_Unseen')
        res_p_u = evaluate_model(ppdrq_model, unseen_loader, 'PPDRQ_CE_Unseen')
        res_t_u = evaluate_model(triplet_model, unseen_loader, 'Triplet_Unseen')
        res_m_u = mc_dropout_predict(mc_model, unseen_loader, num_passes=Config.mc_dropout_passes)
        unseen_results: Dict[str, Dict] = {
            'Baseline': res_b_u,
            'PPDRQ_CE': res_p_u,
            'Triplet': res_t_u,
            'MC_Dropout': res_m_u,
        }
        if len(members) == Config.num_ensemble_models:
            res_e_u = evaluate_deep_ensemble(members, unseen_loader)
            unseen_results['Ensemble'] = res_e_u

        print("\nUNSEEN DOMAIN RESULTS")
        _print_results_table(unseen_results, times)

        unseen_df = pd.DataFrame(
            [{'Model': tag, **{key: res.get(key) for key in _UNSEEN_CSV_METRICS}}
             for tag, res in unseen_results.items()]
        )
        unseen_csv = os.path.join(Config.output_dir, "enhanced_results_unseen.csv")
        unseen_df.to_csv(unseen_csv, index=False)
        print(f"Saved unseen-domain results -> {unseen_csv}")
        print("Unseen domain evaluation complete.")

    print("\n" + "="*80)
    print("ENHANCED ANALYSIS COMPLETE")
    print("="*80)
    print("Generated files include:")
    print("- Enhanced visualizations with professional styling")
    print("- Comprehensive methodological transparency report")
    print("- Detailed timing analysis")
    print("- λ sensitivity analysis with multiple metrics")
    print("- Enhanced confusion matrices and reliability diagrams")
    print("- Publication-ready figures with high DPI")
    print(f"\nAll results saved to: {Config.output_dir}/")
    print("="*80)