In [1]:
# ============================================================
# FULL PIPELINE (UPDATED)
# - Avoids the FileNotFoundError on missing harmonized CSVs by generating
#   them into /kaggle/working/harmonized_labels when they do not exist.
# - Uses ViT-B/16 (torchvision) instead of ResNet50.
# - Mixup DG training, metrics, and plots are unchanged.
# ============================================================

import os
import random
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset

import torchvision.transforms as transforms
import torchvision.models as models

from sklearn.metrics import roc_auc_score, f1_score, roc_curve, average_precision_score
import matplotlib.pyplot as plt
from tqdm import tqdm


# ============================================================
# 0) PATHS / CONFIG
# ============================================================

# Seed handed to set_seed() at start-up.
SEED = 42

# Output folder and labels describing this leave-one-domain-out fold.
SAVE_DIR = "./results_lodo_mixup/fold_test_RFMiD_v1"
TEST_DOMAIN = "RFMiD_v1"
TRAIN_DOMAINS = "ODIR + RFMiD_v2"

# Beta(alpha, alpha) parameter used when drawing the mixup coefficient.
MIXUP_ALPHA = 0.2

# Folder the harmonized CSVs are written to and later read from.
HARMONIZED_DIR = "/kaggle/working/harmonized_labels"

# CSV files produced by the harmonization step: one train and one val file
# per training domain, plus the test file of the held-out domain.
TRAIN_CSVS = [
    f"{HARMONIZED_DIR}/{name}_train.csv" for name in ("ODIR", "RFMiD_v2")
]
VAL_CSVS = [
    f"{HARMONIZED_DIR}/{name}_val.csv" for name in ("ODIR", "RFMiD_v2")
]
TEST_CSV = HARMONIZED_DIR + "/RFMiD_v1_test.csv"


# ============================================================
# 1) SMALL HELPERS (ROBUST PATH RESOLUTION)
# ============================================================

def _pick_existing_path(candidates: List[str], what: str) -> str:
    """Return the first candidate path that exists on disk.

    Args:
        candidates: Paths to probe, in priority order.
        what: Human-readable description used in the error message.

    Raises:
        FileNotFoundError: if none of the candidates exists; the message
            lists every path that was tried.
    """
    found = next((path for path in candidates if os.path.exists(path)), None)
    if found is not None:
        return found

    tried = "\n".join(candidates)
    raise FileNotFoundError(f"[ERROR] Could not find {what}. Tried:\n{tried}")


# ============================================================
# 2) HARMONIZATION: ODIR / RFMiD v1 / RFMiD v2
# ============================================================

def harmonize_odir(split: str) -> pd.DataFrame:
    """Build one harmonized row per ODIR fundus image for the given split.

    The annotation workbook is read with ``pd.read_excel`` and must contain
    ``ID``, ``Left-Fundus``, ``Right-Fundus`` and the eight label columns
    ``N D G C A H M O``. Every annotation row yields up to two output rows
    (left eye first, then right eye); both carry that row's labels.

    Args:
        split: ``"train"`` or ``"val"``; any other value selects the test set.

    Returns:
        DataFrame with ``image_path``, ``dataset``, ``split``, ``ID`` and the
        eight label columns.

    Raises:
        FileNotFoundError: if no candidate image directory or annotation
            file exists.
        KeyError: if a required annotation column is missing.
    """
    base_path = "/kaggle/input/odir-clr/ODIR_CLR"

    # Folder spellings differ between dataset versions, so each known variant
    # is probed in order: (folder candidates, file-name stem).
    layouts = {
        "train": (("Training_Set", "Training_ Set"), "train"),
        "val": (("Validation_set", "Validation_Set"), "val"),
    }
    folders, stem = layouts.get(split, (("Test_Set", "Test_set"), "test"))

    img_dir = _pick_existing_path(
        [f"{base_path}/{folder}/{stem}_images" for folder in folders],
        what=f"ODIR {stem}_images directory",
    )
    label_file = _pick_existing_path(
        [f"{base_path}/{folder}/{stem}_annotation.xlsx" for folder in folders],
        what=f"ODIR {stem}_annotation.xlsx",
    )

    df = pd.read_excel(label_file)

    label_cols = ['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']
    required = ['ID', 'Left-Fundus', 'Right-Fundus'] + label_cols
    absent = next((name for name in required if name not in df.columns), None)
    if absent is not None:
        raise KeyError(f"[ERROR] ODIR annotation missing column: {absent}")

    # (annotation column, file-name suffix) for each eye, left before right.
    eyes = (('Left-Fundus', 'left'), ('Right-Fundus', 'right'))

    records = []
    for _, patient in df.iterrows():
        patient_id = patient['ID']
        labels = {name: int(patient[name]) for name in label_cols}

        for column, side in eyes:
            cell = patient[column]
            # An eye is skipped when its annotation cell is NaN or blank.
            if pd.isna(cell) or not str(cell).strip():
                continue
            # The path is derived from the patient ID, not from the cell text.
            records.append({
                'image_path': os.path.join(img_dir, f"{patient_id}_{side}.jpg"),
                'dataset': 'ODIR',
                'split': split,
                'ID': patient_id,
                **labels
            })

    return pd.DataFrame(records)


def harmonize_rfmid_v1(split: str) -> pd.DataFrame:
    """Map RFMiD v1 labels for one split onto the N/D/G/C/A/H/M/O schema.

    Mapping (a target is 1 when the condition holds, else 0):
    ``N`` <- ``Disease_Risk == 0``, ``D`` <- ``DR == 1``, ``G`` <- ``ODC == 1``,
    ``C`` <- ``MH == 1``, ``A`` <- ``ARMD == 1``, ``H`` <- ``HR == 1``,
    ``M`` <- ``MYA == 1``, and ``O`` <- any other known disease column that
    is present in the CSV equals 1.

    Args:
        split: ``"train"`` or ``"val"``; any other value selects the test set.

    Returns:
        DataFrame with ``image_path``, ``dataset``, ``split``, ``ID`` and the
        eight label columns, one row per CSV row. When no image file with a
        known extension exists, ``image_path`` is ``<img_dir>/<ID>`` without
        an extension.

    Raises:
        FileNotFoundError: if no candidate image directory or label CSV exists.
        KeyError: if a required label column is missing.
    """
    base_path = "/kaggle/input/retinal-disease-classification"

    # (set folder, image sub-folder, label file) per split; any split other
    # than "train"/"val" is treated as the test split.
    layouts = {
        "train": ("Training_Set", "Training", "RFMiD_Training_Labels.csv"),
        "val": ("Evaluation_Set", "Validation", "RFMiD_Validation_Labels.csv"),
    }
    tag = split if split in layouts else "test"
    set_dir, image_leaf, label_name = layouts.get(
        split, ("Test_Set", "Test", "RFMiD_Testing_Labels.csv")
    )

    # The set folder may be nested inside itself; probe the nested form first.
    roots = [f"{base_path}/{set_dir}/{set_dir}", f"{base_path}/{set_dir}"]
    img_dir = _pick_existing_path(
        [f"{root}/{image_leaf}" for root in roots],
        what=f"RFMiD_v1 {tag} image directory",
    )
    label_file = _pick_existing_path(
        [f"{root}/{label_name}" for root in roots],
        what=f"RFMiD_v1 {label_name}",
    )

    df = pd.read_csv(label_file)

    required = ['ID', 'Disease_Risk', 'DR', 'ODC', 'MH', 'ARMD', 'HR', 'MYA']
    for c in required:
        if c not in df.columns:
            raise KeyError(f"[ERROR] RFMiD_v1 labels missing column: {c}")

    all_disease_cols = ['DR', 'ARMD', 'MH', 'DN', 'MYA', 'BRVO', 'TSLN',
                        'ERM', 'LS', 'MS', 'CSR', 'ODC', 'CRVO', 'TV', 'AH',
                        'ODP', 'ODE', 'ST', 'AION', 'PT', 'RT', 'RS', 'CRS',
                        'EDN', 'RPEC', 'MHL', 'RP', 'CWS', 'CB', 'ODPM',
                        'PRH', 'MNF', 'HR', 'CRAO', 'TD', 'CME', 'PTCR', 'CF',
                        'VH', 'MCA', 'VS', 'BRAO', 'PLQ', 'HPED', 'CL']
    all_disease_cols = [c for c in all_disease_cols if c in df.columns]

    # Columns that feed a dedicated target are excluded from "O" (other).
    # Both lists are row-independent, so they are computed once.
    used_for_mapping = ['DR', 'ODC', 'MH', 'ARMD', 'HR', 'MYA']
    other_cols = [col for col in all_disease_cols if col not in used_for_mapping]

    extensions = ['.png', '.jpg', '.jpeg', '.JPG', '.PNG', '.JPEG']

    harmonized_rows = []
    for _, row in df.iterrows():
        image_id = row['ID']
        # iterrows() yields one Series per row and does not preserve dtypes:
        # if any column is float (e.g. contains NaN) an integer ID arrives as
        # 12.0 and would be probed as "12.0.png". Restore whole-number floats
        # to int so the file name matches.
        if isinstance(image_id, float) and image_id.is_integer():
            image_id = int(image_id)

        image_path = None
        for ext in extensions:
            candidate = os.path.join(img_dir, f"{image_id}{ext}")
            if os.path.exists(candidate):
                image_path = candidate
                break
        if image_path is None:
            # No file found: keep the row with an extension-less path.
            image_path = os.path.join(img_dir, str(image_id))

        N = 1 if row['Disease_Risk'] == 0 else 0
        D = 1 if row.get('DR', 0) == 1 else 0
        G = 1 if row.get('ODC', 0) == 1 else 0
        C = 1 if row.get('MH', 0) == 1 else 0
        A = 1 if row.get('ARMD', 0) == 1 else 0
        H = 1 if row.get('HR', 0) == 1 else 0
        M = 1 if row.get('MYA', 0) == 1 else 0
        O = 1 if any(row.get(col, 0) == 1 for col in other_cols) else 0

        harmonized_rows.append({
            'image_path': image_path,
            'dataset': 'RFMiD_v1',
            'split': split,
            'ID': image_id,
            'N': N, 'D': D, 'G': G, 'C': C, 'A': A, 'H': H, 'M': M, 'O': O
        })

    return pd.DataFrame(harmonized_rows)


def harmonize_rfmid_v2(split: str) -> pd.DataFrame:
    """Map RFMiD v2 labels for one split onto the N/D/G/C/A/H/M/O schema.

    Mapping (a target is 1 when the condition holds, else 0):
    ``N`` <- ``WNL == 1``, ``D`` <- ``DR == 1``, ``G`` <- ``ODC == 1``,
    ``C`` <- ``MH == 1``, ``A`` <- ``ARMD == 1``, ``H`` <- ``HTN == 1`` or
    ``HR == 1``, ``M`` <- ``MYA == 1``, and ``O`` <- any other known disease
    column that is present in the CSV equals 1. Columns absent from the CSV
    count as 0.

    Rows whose image cannot be found under any known extension are skipped;
    when that happens a found/skipped summary is printed.

    Args:
        split: ``"train"`` or ``"val"``; any other value selects the test set.

    Returns:
        DataFrame with ``image_path``, ``dataset``, ``split``, ``ID`` and the
        eight label columns.

    Raises:
        FileNotFoundError: if no candidate image directory or label CSV exists.
        KeyError: if the ``ID`` column is missing.
    """
    base_path = "/kaggle/input/rdc-version-2/RFDiM2_0"

    # (set folder, image folder stem, word used in the label file name);
    # any split other than "train"/"val" is treated as the test split.
    layouts = {
        "train": ("Training_set_2", "Train", "Training"),
        "val": ("Validation_set_2", "Validation", "Validation"),
    }
    tag = split if split in layouts else "test"
    set_dir, image_stem, word = layouts.get(split, ("Test_set_2", "Test", "Testing"))
    set_root = f"{base_path}/{set_dir}"

    img_dir = _pick_existing_path(
        [f"{set_root}/{image_stem}_2", f"{set_root}/{image_stem}"],
        what=f"RFMiD_v2 {tag} image directory",
    )
    label_file = _pick_existing_path(
        [f"{set_root}/RFMiD_2_{word}_labels.csv"],
        what=f"RFMiD_v2 {word.lower()} labels csv",
    )

    # Try UTF-8 first and fall back to Latin-1 if decoding fails.
    try:
        df = pd.read_csv(label_file, encoding='utf-8')
    except UnicodeDecodeError:
        df = pd.read_csv(label_file, encoding='latin1')

    # Header names are stripped of surrounding whitespace before lookup.
    df.columns = df.columns.str.strip()

    if 'ID' not in df.columns:
        raise KeyError("[ERROR] RFMiD_v2 labels missing column: ID")

    potential_disease_cols = ['AH', 'AION', 'ARMD', 'BRVO', 'CB', 'CF', 'CL', 'CME',
                              'CNV', 'CRAO', 'CRS', 'CRVO', 'CSR', 'CWS', 'CSC', 'DN',
                              'DR', 'EDN', 'ERM', 'GRT', 'HPED', 'HR', 'LS', 'MCA',
                              'ME', 'MH', 'MHL', 'MS', 'MYA', 'ODC', 'ODE', 'ODP',
                              'ON', 'OPDM', 'PRH', 'RD', 'RHL', 'RTR', 'RP', 'RPEC',
                              'RS', 'RT', 'SOFE', 'ST', 'TD', 'TSLN', 'TV', 'VS',
                              'HTN', 'IIH', 'WNL']
    present_cols = [name for name in potential_disease_cols if name in df.columns]

    # Everything not consumed by a dedicated target contributes to "O".
    dedicated = ['DR', 'ODC', 'MH', 'ARMD', 'HTN', 'HR', 'MYA', 'WNL']
    other_cols = [name for name in present_cols if name not in dedicated]

    extensions = ['.jpg', '.JPG', '.png', '.PNG', '.jpeg', '.JPEG']

    def is_set(record, column):
        """Return 1 when ``column`` equals 1 in ``record`` (absent -> 0)."""
        return 1 if record.get(column, 0) == 1 else 0

    records = []
    n_found = 0
    n_skipped = 0

    for _, record in df.iterrows():
        # Prefer an integer ID; keep the raw value when it cannot be converted.
        try:
            image_id = int(record['ID'])
        except Exception:
            image_id = record['ID']

        image_path = next(
            (
                path
                for path in (os.path.join(img_dir, f"{image_id}{ext}") for ext in extensions)
                if os.path.exists(path)
            ),
            None,
        )
        if image_path is None:
            n_skipped += 1
            continue
        n_found += 1

        # Hypertension: HTN takes precedence, HR is the fallback.
        hypertension = 1 if (record.get('HTN', 0) == 1 or record.get('HR', 0) == 1) else 0
        other = 1 if any(record.get(name, 0) == 1 for name in other_cols) else 0

        records.append({
            'image_path': image_path,
            'dataset': 'RFMiD_v2',
            'split': split,
            'ID': image_id,
            'N': is_set(record, 'WNL'),
            'D': is_set(record, 'DR'),
            'G': is_set(record, 'ODC'),
            'C': is_set(record, 'MH'),
            'A': is_set(record, 'ARMD'),
            'H': hypertension,
            'M': is_set(record, 'MYA'),
            'O': other,
        })

    if n_skipped > 0:
        print(f"[INFO] RFMiD_v2 {split}: Found {n_found}, Skipped {n_skipped}")

    return pd.DataFrame(records)


def harmonize_all_datasets() -> Dict[str, pd.DataFrame]:
    """Harmonize every dataset for the train, val and test splits.

    Returns:
        Dict keyed ``"<dataset>_<split>"`` (e.g. ``"ODIR_train"``) holding
        the harmonized DataFrame for that dataset and split.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("HARMONIZING DATASETS")
    print(banner)

    # Per split, datasets are processed in this fixed order.
    harmonizers = (
        ("ODIR", harmonize_odir),
        ("RFMiD_v1", harmonize_rfmid_v1),
        ("RFMiD_v2", harmonize_rfmid_v2),
    )

    tables = {}
    for split in ('train', 'val', 'test'):
        print(f"\nProcessing {split} split...")
        for dataset_name, harmonize in harmonizers:
            tables[f"{dataset_name}_{split}"] = harmonize(split)

    return tables


def save_harmonized_data(harmonized_data: Dict[str, pd.DataFrame], output_dir: str):
    """Write each DataFrame to ``<output_dir>/<key>.csv`` without the index.

    Args:
        harmonized_data: Mapping from file stem to DataFrame.
        output_dir: Destination folder; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    banner = "=" * 60
    print("\n" + banner)
    print("SAVING HARMONIZED DATA")
    print(banner)

    for key, frame in harmonized_data.items():
        destination = os.path.join(output_dir, f"{key}.csv")
        frame.to_csv(destination, index=False)
        n_rows = len(frame)
        print(f"✅ Saved {key:20s}: {n_rows:5d} rows → {destination}")

    print(f"\n✨ All files saved to: {output_dir}")


def ensure_harmonized_csvs():
    """Generate the harmonized CSVs unless every required file already exists.

    The required files are ``TRAIN_CSVS``, ``VAL_CSVS`` and ``TEST_CSV``. If
    any is missing, all datasets are harmonized and saved to
    ``HARMONIZED_DIR``, then the check is repeated.

    Raises:
        FileNotFoundError: if a required CSV is still missing after generation.
    """
    required = TRAIN_CSVS + VAL_CSVS + [TEST_CSV]

    def absent_files():
        """Return the required CSV paths that are not on disk right now."""
        return [path for path in required if not os.path.exists(path)]

    missing = absent_files()
    if not missing:
        print(f"[INFO] Found harmonized CSVs in: {HARMONIZED_DIR}")
        return

    print("\n[INFO] Harmonized CSVs missing. Will generate them now...")
    print("[INFO] Missing files:")
    for path in missing:
        print("  -", path)

    save_harmonized_data(harmonize_all_datasets(), output_dir=HARMONIZED_DIR)

    # Re-check after generation so a partial write fails loudly.
    still_missing = absent_files()
    if still_missing:
        details = "\n".join(still_missing)
        raise FileNotFoundError(
            "[ERROR] Tried to generate harmonized CSVs but some are still missing:\n"
            + details
        )
    print("[INFO] ✓ Harmonized CSV generation complete.")


# ============================================================
# 3) TRAINING UTILITIES
# ============================================================

def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also stores the seed in the ``PYTHONHASHSEED`` environment variable.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)


class RetinalDataset(Dataset):
    """Image dataset backed by a harmonized CSV.

    The CSV must provide an ``image_path`` column and the label columns
    ``N D G C A H M O``. On construction, rows with a repeated
    ``image_path`` are dropped (first occurrence kept), followed by rows
    whose image file does not exist. Each item is
    ``(image, labels, domain_id)``.
    """

    def __init__(self, csv_path, domain_id, transform=None):
        """Load and clean the CSV.

        Args:
            csv_path: Path to the harmonized CSV.
            domain_id: Value returned as the third element of every item.
            transform: Optional callable applied to the PIL image.

        Raises:
            FileNotFoundError: if ``csv_path`` does not exist.
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"[ERROR] CSV not found: {csv_path}")

        frame = pd.read_csv(csv_path)
        self.data = frame
        self.domain_id = domain_id
        self.transform = transform
        self.label_cols = ["N", "D", "G", "C", "A", "H", "M", "O"]

        # Step 1: de-duplicate on image path.
        n_duplicates = frame.duplicated(subset=["image_path"], keep="first").sum()
        if n_duplicates > 0:
            print(f"[WARN] {n_duplicates} duplicates found in {csv_path}, keeping first")
        frame = frame.drop_duplicates(subset=["image_path"]).reset_index(drop=True)

        # Step 2: keep only rows whose image is actually on disk.
        on_disk = frame["image_path"].apply(os.path.exists)
        n_missing = (~on_disk).sum()
        if n_missing > 0:
            print(f"[WARN] Dropping {n_missing} missing images from {csv_path}")
        self.data = frame.loc[on_disk].reset_index(drop=True)

        print(f"[INFO] Domain {domain_id}: Loaded {len(self.data)} valid images from {os.path.basename(csv_path)}")

    def __len__(self):
        """Number of rows left after cleaning."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, float32 label tensor, domain_id)`` for row ``idx``."""
        record = self.data.iloc[idx]

        # Best effort: an image that cannot be opened is replaced by a
        # mid-gray 224x224 placeholder instead of raising.
        try:
            image = Image.open(record["image_path"]).convert("RGB")
        except Exception:
            image = Image.new("RGB", (224, 224), color=(128, 128, 128))

        label_values = record[self.label_cols].values.astype("float32")
        targets = torch.tensor(label_values)

        if self.transform:
            image = self.transform(image)
        return image, targets, self.domain_id


def calculate_pos_weights(datasets, clip_min=0.5, clip_max=50.0):
    """Compute clipped negative/positive ratios per label across datasets.

    Args:
        datasets: Objects exposing a ``data`` DataFrame with the label
            columns ``N D G C A H M O``.
        clip_min: Lower bound applied to every ratio.
        clip_max: Upper bound applied to every ratio.

    Returns:
        float32 tensor with one weight per label, in column order.
    """
    label_cols = ["N", "D", "G", "C", "A", "H", "M", "O"]

    stacked = np.vstack([ds.data[label_cols].values for ds in datasets])
    positives = stacked.sum(axis=0)
    negatives = len(stacked) - positives
    # The small epsilon keeps the ratio finite for labels with no positives.
    weights = np.clip(negatives / (positives + 1e-5), clip_min, clip_max)

    print("\n[INFO] Positive class weights (from training domains):")
    for name, weight in zip(label_cols, weights):
        print(f"  {name}: {weight:.2f}")

    return torch.tensor(weights, dtype=torch.float32)


# ============================================================
# 4) MODEL: ViT-B/16
# ============================================================

class MixupMultiLabelViT(nn.Module):
    """torchvision ViT-B/16 backbone followed by a small MLP classifier.

    The backbone is loaded with ``ViT_B_16_Weights.IMAGENET1K_V1`` and its
    original ``heads`` module is replaced by ``nn.Identity`` so it returns
    features; ``classifier`` maps those features to ``num_classes`` logits.
    """

    def __init__(self, num_classes=8, dropout=0.3):
        """
        Args:
            num_classes: Number of output logits.
            dropout: Dropout probability used twice inside the classifier.
        """
        super().__init__()

        vit = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
        # Read the feature width from the stock head before discarding it.
        feature_dim = vit.heads.head.in_features
        vit.heads = nn.Identity()
        self.backbone = vit

        # Layer order is kept as-is so state_dict keys stay the same.
        head_layers = [
            nn.Dropout(dropout),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw logits for an image batch ``x``."""
        return self.classifier(self.backbone(x))


def get_transforms(is_train=False):
    """Return the torchvision preprocessing pipeline.

    Training: resize to 256x256, random 224 crop, random rotation of up to
    10 degrees and brightness/contrast jitter. Evaluation: resize to 224x224
    only. Both end with ``ToTensor`` and the same ``Normalize`` step.
    """
    if is_train:
        front = [
            transforms.Resize((256, 256)),
            transforms.RandomCrop(224),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ]
    else:
        front = [transforms.Resize((224, 224))]

    # Shared tail: tensor conversion followed by per-channel normalization.
    tail = [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(front + tail)


# ============================================================
# 5) METRICS
# ============================================================

def compute_metrics(labels, probs, thresholds=None):
    """Compute per-class and mean AUC/AP plus macro F1.

    Args:
        labels: 2-D array of ground-truth labels, one column per class.
        probs: Array of predicted scores with the same layout.
        thresholds: Per-class decision thresholds for F1; 0.5 for every
            class when omitted.

    Returns:
        Dict with ``mAUC``, ``mAP``, ``per_class_auc``, ``per_class_ap`` and
        ``macro_f1``. A class whose labels take a single value gets NaN for
        AUC and AP and is ignored by the NaN-aware means.
    """
    n_classes = labels.shape[1]

    per_class_auc, per_class_ap = [], []
    for class_idx in range(n_classes):
        y_true = labels[:, class_idx]
        y_score = probs[:, class_idx]
        # Ranking metrics need both classes present in the ground truth.
        if len(np.unique(y_true)) > 1:
            per_class_auc.append(roc_auc_score(y_true, y_score))
            per_class_ap.append(average_precision_score(y_true, y_score))
        else:
            per_class_auc.append(np.nan)
            per_class_ap.append(np.nan)

    if thresholds is None:
        thresholds = np.full(n_classes, 0.5)

    decisions = (probs >= thresholds).astype(int)
    macro_f1 = f1_score(labels, decisions, average="macro", zero_division=0)

    return {
        "mAUC": float(np.nanmean(per_class_auc)),
        "mAP": float(np.nanmean(per_class_ap)),
        "per_class_auc": per_class_auc,
        "per_class_ap": per_class_ap,
        "macro_f1": float(macro_f1),
    }


def find_optimal_thresholds(labels, probs):
    """Pick, per class, the threshold in [0.05, 0.95] that maximizes F1.

    The grid has 91 evenly spaced points (step 0.01). A class whose labels
    take a single value, or for which no threshold gives F1 above zero,
    keeps the default 0.5. Ties keep the lowest threshold.

    Returns:
        1-D array with one threshold per class.
    """
    grid = np.linspace(0.05, 0.95, 91)
    chosen = []

    for class_idx in range(labels.shape[1]):
        y_true = labels[:, class_idx]
        best_score, best_threshold = 0.0, 0.5

        if len(np.unique(y_true)) > 1:
            y_score = probs[:, class_idx]
            for candidate in grid:
                decisions = (y_score >= candidate).astype(int)
                score = f1_score(y_true, decisions, zero_division=0)
                # Strict ">" keeps the earliest threshold on ties.
                if score > best_score:
                    best_score, best_threshold = score, candidate

        chosen.append(best_threshold)

    return np.array(chosen)


# ============================================================
# 6) MIXUP TRAINING
# ============================================================

def train_epoch_mixup(model, loaders_dict, criterion, optimizer, device, mixup_alpha=0.2):
    """Run one training epoch, mixing batches across domains at random.

    The number of steps equals the length of the longest loader. At each
    step, when at least two domains exist and a uniform draw exceeds 0.5,
    one batch from each of two distinct domains is blended with a
    ``Beta(mixup_alpha, mixup_alpha)`` coefficient (images and labels
    alike). Otherwise a plain batch from one random domain is used.
    Exhausted loaders are restarted. Gradients are clipped to norm 1.0.

    Args:
        model: Network being trained.
        loaders_dict: Mapping from domain id to a loader yielding
            ``(images, labels, domain_id)`` batches.
        criterion: Loss taking ``(logits, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.
        mixup_alpha: Beta distribution parameter for the mixing coefficient.

    Returns:
        Dict with the mean ``loss`` and the ``mixup_ratio`` (fraction of
        steps that used a mixed batch).
    """
    model.train()
    step_losses = []
    n_steps = 0
    n_mixed = 0

    iterators = {key: iter(loader) for key, loader in loaders_dict.items()}
    domain_ids = list(loaders_dict.keys())

    def draw(domain):
        """Fetch the next (images, labels) batch, restarting when exhausted."""
        try:
            images, targets, _ = next(iterators[domain])
        except StopIteration:
            iterators[domain] = iter(loaders_dict[domain])
            images, targets, _ = next(iterators[domain])
        return images, targets

    max_batches = max(len(loader) for loader in loaders_dict.values())
    progress = tqdm(range(max_batches), desc="Train (Mixup)", leave=False)

    for _ in progress:
        if len(domain_ids) >= 2 and np.random.rand() > 0.5:
            first, second = np.random.choice(domain_ids, size=2, replace=False)
            imgs_a, labels_a = draw(first)
            imgs_b, labels_b = draw(second)

            # Batches may differ in size (e.g. a final short batch): truncate
            # both to the smaller one before blending.
            keep = min(imgs_a.size(0), imgs_b.size(0))
            imgs_a, labels_a = imgs_a[:keep], labels_a[:keep]
            imgs_b, labels_b = imgs_b[:keep], labels_b[:keep]

            lam = np.random.beta(mixup_alpha, mixup_alpha)
            inputs = lam * imgs_a + (1 - lam) * imgs_b
            targets = lam * labels_a + (1 - lam) * labels_b
            n_mixed += 1
        else:
            inputs, targets = draw(np.random.choice(domain_ids))

        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_value = loss.item()
        step_losses.append(loss_value)
        n_steps += 1
        progress.set_postfix({'loss': f"{loss_value:.4f}"})

    return {
        "loss": float(np.mean(step_losses)),
        "mixup_ratio": float(n_mixed / n_steps)
    }


@torch.no_grad()
def validate(model, loader, criterion, device, thresholds=None):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Batches may be ``(images, labels, domain_id)`` or ``(images, labels)``.

    Returns:
        Tuple ``(metrics, probs, labels)``: the dict from ``compute_metrics``
        extended with the mean ``loss``, the stacked sigmoid probabilities
        and the stacked ground-truth labels (both NumPy arrays).
    """
    model.eval()
    batch_losses = []
    prob_chunks = []
    label_chunks = []

    for batch in tqdm(loader, desc="Val", leave=False):
        # Discard the trailing domain id when the loader provides one.
        if len(batch) == 3:
            batch = batch[:2]
        images, labels = batch

        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)

        batch_losses.append(criterion(logits, labels).item())
        prob_chunks.append(torch.sigmoid(logits).cpu().numpy())
        label_chunks.append(labels.cpu().numpy())

    probs = np.vstack(prob_chunks)
    labels = np.vstack(label_chunks)

    metrics = compute_metrics(labels, probs, thresholds)
    metrics["loss"] = float(np.mean(batch_losses))
    return metrics, probs, labels


# ============================================================
# 7) PLOTS
# ============================================================

def plot_training_curves(history, save_dir):
    """Save ``training_curves.png``: loss curves (left) and val mAUC (right).

    Args:
        history: Dict with equally long ``train_loss``, ``val_loss`` and
            ``val_auc`` sequences, one entry per epoch.
        save_dir: Output folder; created if it does not exist.
    """
    fig, (loss_ax, auc_ax) = plt.subplots(1, 2, figsize=(14, 5))
    epochs = range(1, len(history['train_loss']) + 1)

    # Left panel: train vs. validation loss.
    loss_series = (
        ('train_loss', 'Train Loss', 'o'),
        ('val_loss', 'Val Loss', 's'),
    )
    for key, legend_label, marker in loss_series:
        loss_ax.plot(epochs, history[key], linewidth=2, label=legend_label, marker=marker)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss (Mixup)')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Right panel: validation mAUC on a fixed 0-1 scale.
    auc_ax.plot(epochs, history['val_auc'], linewidth=2, label='Val mAUC', marker='s')
    auc_ax.set_xlabel('Epoch')
    auc_ax.set_ylabel('mAUC')
    auc_ax.set_title('Validation mAUC (Mixup)')
    auc_ax.legend()
    auc_ax.grid(True, alpha=0.3)
    auc_ax.set_ylim([0, 1.0])

    plt.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(f'{save_dir}/training_curves.png', dpi=300, bbox_inches='tight')
    plt.close()
    print("[INFO] ✓ Saved training curves")


def plot_per_class_roc(labels, probs, test_domain, save_dir):
    """Save ``per_class_roc_curves.png``: a 2x4 grid with one ROC per class.

    A class whose labels take a single value gets a "Single class" note
    instead of a curve.

    Args:
        labels: 2-D array of ground-truth labels; columns follow the order
            N, D, G, C, A, H, M, O.
        probs: Array of predicted scores with the same layout.
        test_domain: Name shown in the figure title.
        save_dir: Output folder; created if it does not exist.
    """
    class_names = ["N", "D", "G", "C", "A", "H", "M", "O"]

    fig, grid = plt.subplots(2, 4, figsize=(20, 10))
    panels = grid.flatten()

    for idx, name in enumerate(class_names):
        panel = panels[idx]
        y_true = labels[:, idx]

        if len(np.unique(y_true)) <= 1:
            # ROC is undefined without both positives and negatives.
            panel.text(0.5, 0.5, 'Single class\n(No ROC)', ha='center', va='center')
            panel.set_title(f'Class {name}')
            continue

        y_score = probs[:, idx]
        fpr, tpr, _ = roc_curve(y_true, y_score)
        class_auc = roc_auc_score(y_true, y_score)
        panel.plot(fpr, tpr, linewidth=2, label=f'AUC = {class_auc:.3f}')
        panel.plot([0, 1], [0, 1], 'k--', linewidth=1, alpha=0.5)
        panel.set_title(f'Class {name}')
        panel.set_xlabel('FPR')
        panel.set_ylabel('TPR')
        panel.legend(loc='lower right')
        panel.grid(True, alpha=0.3)

    plt.suptitle(f'Per-Class ROC Curves - Test on {test_domain} (Mixup)', y=0.995)
    plt.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(f'{save_dir}/per_class_roc_curves.png', dpi=300, bbox_inches='tight')
    plt.close()
    print("[INFO] ✓ Saved per-class ROC curves")


def plot_macro_roc(labels, probs, test_domain, save_dir):
    """Save ``macro_roc_curve.png``: the macro-averaged ROC over all classes.

    Each class with both positives and negatives contributes its ROC curve,
    interpolated onto 100 evenly spaced FPR points; the curves are averaged
    and the area under the mean curve is reported in the legend.

    Args:
        labels: 2-D array of ground-truth labels, one column per class.
        probs: Array of predicted scores with the same layout.
        test_domain: Name shown in the figure title.
        save_dir: Output folder; created if it does not exist.

    Returns:
        None. When no class has both positives and negatives a warning is
        printed and nothing is drawn or written.
    """
    # Use the actual number of label columns rather than assuming eight.
    n_classes = labels.shape[1]

    curves = []
    for i in range(n_classes):
        if len(np.unique(labels[:, i])) > 1:
            fpr, tpr, _ = roc_curve(labels[:, i], probs[:, i])
            curves.append((fpr, tpr))

    # Bail out before any figure exists, so the early return cannot leave an
    # open (unclosed) matplotlib figure behind.
    if not curves:
        print("[WARN] Cannot plot macro ROC: no class has both positives & negatives.")
        return

    mean_fpr = np.linspace(0, 1, 100)
    mean_tpr = np.mean([np.interp(mean_fpr, fpr, tpr) for fpr, tpr in curves], axis=0)

    # NumPy 2.0 renamed ``trapz`` to ``trapezoid``; prefer the new name and
    # fall back to the old one on earlier releases.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    macro_auc = integrate(mean_tpr, mean_fpr)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(mean_fpr, mean_tpr, linewidth=3, label=f'Macro-avg ROC (AUC = {macro_auc:.3f})')
    ax.plot([0, 1], [0, 1], 'k--', linewidth=2, alpha=0.5, label='Random')
    ax.set_xlabel('FPR')
    ax.set_ylabel('TPR')
    ax.set_title(f'Macro-Average ROC - Test on {test_domain} (Mixup)')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(f'{save_dir}/macro_roc_curve.png', dpi=300, bbox_inches='tight')
    plt.close()
    print("[INFO] ✓ Saved macro-average ROC curve")


def plot_per_class_metrics(test_metrics, save_dir):
    """Save a two-panel bar chart of per-class AUC and per-class AP.

    Args:
        test_metrics: Mapping that provides ``'per_class_auc'`` and
            ``'per_class_ap'`` sequences, one value per class.
        save_dir: Output directory for ``per_class_metrics.png``; created
            if it does not exist.
    """
    class_names = ["N", "D", "G", "C", "A", "H", "M", "O"]
    raw_auc = test_metrics['per_class_auc']
    raw_ap = test_metrics['per_class_ap']

    def _nan_as_zero(values):
        # NaN entries are drawn as zero-height bars.
        return [0 if np.isnan(v) else v for v in values]

    # (title, bar heights) for the left and right panel respectively.
    panels = (
        ('Per-Class AUC (Mixup)', _nan_as_zero(raw_auc)),
        ('Per-Class AP (Mixup)', _nan_as_zero(raw_ap)),
    )

    _fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    positions = np.arange(len(class_names))

    for panel_ax, (title, heights) in zip(axes, panels):
        panel_ax.bar(positions, heights, alpha=0.85, edgecolor='black', linewidth=1.2)
        panel_ax.set_title(title)
        panel_ax.set_xticks(positions)
        panel_ax.set_xticklabels(class_names)
        panel_ax.set_ylim([0, 1.0])
        panel_ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    out_path = f'{save_dir}/per_class_metrics.png'
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close()
    print("[INFO] ✓ Saved per-class metrics chart")


# ============================================================
# 8) MAIN
# ============================================================

# Script entry point. Stages, in order:
#   1. seed RNGs and make sure the harmonized CSVs exist,
#   2. build train / val / test datasets and loaders,
#   3. train with train_epoch_mixup, checkpointing on validation mAUC,
#   4. reload the best checkpoint and pick per-class thresholds on validation,
#   5. evaluate on the held-out test CSV, then write metrics and plots to SAVE_DIR.
if __name__ == "__main__":
    set_seed(SEED)
    os.makedirs(SAVE_DIR, exist_ok=True)

    print("="*80)
    print(f"LODO FOLD (SWAPPED) with MIXUP DG: Test on {TEST_DOMAIN}")
    print(f"Training on: {TRAIN_DOMAINS}")
    print("="*80)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")
    print(f"Seed: {SEED}")
    print(f"Mixup Alpha: {MIXUP_ALPHA}")
    print("="*80)

    # ---- FIX: ensure harmonized CSVs exist BEFORE loading datasets
    ensure_harmonized_csvs()

    print("\n[INFO] Loading datasets...")
    # Train/val datasets get domain_id = their position in TRAIN_CSVS / VAL_CSVS.
    # Only the training datasets use get_transforms(True); val and test use
    # get_transforms(False).
    train_datasets = [RetinalDataset(csv, domain_id=i, transform=get_transforms(True))
                      for i, csv in enumerate(TRAIN_CSVS)]
    val_datasets = [RetinalDataset(csv, domain_id=i, transform=get_transforms(False))
                    for i, csv in enumerate(VAL_CSVS)]
    # 999 is a domain id that cannot collide with the enumerate() indices above.
    # NOTE(review): how RetinalDataset uses domain_id is not visible here — confirm.
    test_dataset = RetinalDataset(TEST_CSV, domain_id=999, transform=get_transforms(False))

    print(f"\n[INFO] Total train: {sum(len(ds) for ds in train_datasets)} images")
    print(f"[INFO] Total val: {sum(len(ds) for ds in val_datasets)} images")
    print(f"[INFO] Test: {len(test_dataset)} images")

    # Positive-class weights are derived from the training datasets only and
    # later handed to the loss as pos_weight.
    pos_weights = calculate_pos_weights(train_datasets).to(device)

    # One seeded generator is passed to every training loader so shuffling
    # is driven by SEED.
    g = torch.Generator().manual_seed(SEED)
    # One loader per training dataset, keyed by its index in train_datasets;
    # the whole dict is passed to train_epoch_mixup.
    train_loaders = {
        i: DataLoader(ds, batch_size=32, shuffle=True, num_workers=4,
                      pin_memory=True, generator=g)
        for i, ds in enumerate(train_datasets)
    }

    # Validation runs over all validation datasets concatenated, unshuffled.
    combined_val = ConcatDataset(val_datasets)
    val_loader = DataLoader(combined_val, batch_size=32, shuffle=False,
                            num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                             num_workers=4, pin_memory=True)

    print("\n[INFO] Initializing ViT-B/16 model...")
    model = MixupMultiLabelViT().to(device)

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)

    # ViT usually prefers smaller LRs than ResNet
    # Two parameter groups: backbone at 1e-5, classifier at 1e-4; weight decay
    # applies to both groups.
    optimizer = optim.Adam([
        {'params': model.backbone.parameters(), 'lr': 1e-5},
        {'params': model.classifier.parameters(), 'lr': 1e-4},
    ], weight_decay=1e-4)

    # mode='max' because the scheduler is stepped with validation mAUC
    # (higher is better); LRs are halved when it stops improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3
    )

    # Per-epoch curves consumed by plot_training_curves at the end.
    history = {'train_loss': [], 'val_loss': [], 'val_auc': []}

    best_val_auc, patience_counter = 0.0, 0
    print("\n[INFO] Training started with Mixup DG (ViT-B/16)...")
    print("-"*80)

    # Up to 50 epochs; the early-stopping check below may end the loop sooner.
    for epoch in range(50):
        train_metrics = train_epoch_mixup(
            model, train_loaders, criterion, optimizer, device, mixup_alpha=MIXUP_ALPHA
        )
        # validate returns (metrics, probs, labels); only metrics are needed here.
        val_metrics, _, _ = validate(model, val_loader, criterion, device)
        scheduler.step(val_metrics['mAUC'])

        history['train_loss'].append(train_metrics['loss'])
        history['val_loss'].append(val_metrics['loss'])
        history['val_auc'].append(val_metrics['mAUC'])

        print(f"Epoch {epoch+1:02d} | Train Loss: {train_metrics['loss']:.4f} "
              f"| Val mAUC: {val_metrics['mAUC']:.4f} | Mixup: {train_metrics['mixup_ratio']:.1%}")

        # Checkpoint only on a strict improvement of validation mAUC.
        # NOTE(review): the running best starts at 0.0, so if val mAUC never
        # strictly exceeds it (e.g. it is NaN every epoch) no checkpoint is
        # written by this run and the torch.load below has nothing fresh to read.
        if val_metrics['mAUC'] > best_val_auc:
            best_val_auc = val_metrics['mAUC']
            torch.save(model.state_dict(), f'{SAVE_DIR}/best_model.pth')
            print(f"  ✓ Saved best model (val mAUC: {best_val_auc:.4f})")
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping: 10 consecutive epochs without a new best val mAUC.
        if patience_counter >= 10:
            print(f"\n[INFO] Early stopping at epoch {epoch+1}")
            break

    # Restore the weights from the best-validation epoch before any evaluation.
    print(f"\n[INFO] Loading best model...")
    model.load_state_dict(torch.load(f'{SAVE_DIR}/best_model.pth', map_location=device))
    print(f"[INFO] Best validation mAUC: {best_val_auc:.4f}")

    # Thresholds are chosen from validation predictions only; the test set is
    # not touched until they are fixed.
    print(f"\n[INFO] Finding optimal thresholds on validation...")
    _, val_probs, val_labels = validate(model, val_loader, criterion, device)
    thresholds = find_optimal_thresholds(val_labels, val_probs)

    class_names = ["N", "D", "G", "C", "A", "H", "M", "O"]
    print("[INFO] Optimal thresholds:")
    for c, t in zip(class_names, thresholds):
        print(f"  {c}: {t:.3f}")

    print(f"\n[INFO] Testing on {TEST_DOMAIN}...")
    print("-"*80)
    # The validation-derived thresholds are passed explicitly for the test run.
    # NOTE(review): the earlier validate() calls omit this argument, so
    # validate presumably has a default thresholding behaviour — confirm.
    test_metrics, test_probs, test_labels = validate(
        model, test_loader, criterion, device, thresholds
    )

    print("\n" + "="*80)
    print(f"TEST RESULTS - {TEST_DOMAIN} (Mixup DG, ViT-B/16)")
    print("="*80)
    print(f"mAUC:      {test_metrics['mAUC']:.4f}")
    print(f"mAP:       {test_metrics['mAP']:.4f}")
    print(f"Macro F1:  {test_metrics['macro_f1']:.4f}")
    print("="*80)

    # Per-class table; NaN metrics are printed as "N/A".
    print(f"\n{'Class':<8} {'AUC':<10} {'AP':<10}")
    print("-"*30)
    for i, cls in enumerate(class_names):
        auc = test_metrics['per_class_auc'][i]
        ap = test_metrics['per_class_ap'][i]
        auc_str = f"{auc:.4f}" if not np.isnan(auc) else "N/A"
        ap_str = f"{ap:.4f}" if not np.isnan(ap) else "N/A"
        print(f"{cls:<8} {auc_str:<10} {ap_str:<10}")
    print("-"*30)

    # Single-row summary CSV for this fold (aggregate metrics + run config).
    print("\n[INFO] Saving results...")
    results_df = pd.DataFrame([{
        'method': 'Mixup',
        'backbone': 'vit_b_16',
        'test_domain': TEST_DOMAIN,
        'train_domains': TRAIN_DOMAINS,
        'mAUC': test_metrics['mAUC'],
        'mAP': test_metrics['mAP'],
        'macro_f1': test_metrics['macro_f1'],
        'best_val_auc': best_val_auc,
        'mixup_alpha': MIXUP_ALPHA
    }])
    results_df.to_csv(f'{SAVE_DIR}/test_results.csv', index=False)
    print(f"[INFO] ✓ Saved test_results.csv")

    # All plots are written into SAVE_DIR; the ROC plots use test-set predictions.
    print("\n[INFO] Generating visualizations...")
    plot_training_curves(history, SAVE_DIR)
    plot_per_class_roc(test_labels, test_probs, TEST_DOMAIN, SAVE_DIR)
    plot_macro_roc(test_labels, test_probs, TEST_DOMAIN, SAVE_DIR)
    plot_per_class_metrics(test_metrics, SAVE_DIR)

    print("\n" + "="*80)
    print("✓ MIXUP LODO FOLD (SWAPPED) COMPLETE! (ViT-B/16)")
    print(f"✓ Results saved to: {SAVE_DIR}/")
    print("="*80)


LODO FOLD (SWAPPED) with MIXUP DG: Test on RFMiD_v1
Training on: ODIR + RFMiD_v2
Device: cuda
Seed: 42
Mixup Alpha: 0.2

[INFO] Harmonized CSVs missing. Will generate them now...
[INFO] Missing files:
  - /kaggle/working/harmonized_labels/ODIR_train.csv
  - /kaggle/working/harmonized_labels/RFMiD_v2_train.csv
  - /kaggle/working/harmonized_labels/ODIR_val.csv
  - /kaggle/working/harmonized_labels/RFMiD_v2_val.csv
  - /kaggle/working/harmonized_labels/RFMiD_v1_test.csv

HARMONIZING DATASETS

Processing train split...
[INFO] RFMiD_v2 train: Found 507, Skipped 2

Processing val split...

Processing test split...
[INFO] RFMiD_v2 test: Found 170, Skipped 4

SAVING HARMONIZED DATA
✅ Saved ODIR_train          :  7000 rows → /kaggle/working/harmonized_labels/ODIR_train.csv
✅ Saved RFMiD_v1_train      :  1920 rows → /kaggle/working/harmonized_labels/RFMiD_v1_train.csv
✅ Saved RFMiD_v2_train      :   507 rows → /kaggle/working/harmonized_labels/RFMiD_v2_train.csv
✅ Saved ODIR_val            :  1

100%|██████████| 330M/330M [00:01<00:00, 180MB/s]  



[INFO] Training started with Mixup DG (ViT-B/16)...
--------------------------------------------------------------------------------


                                                                             

Epoch 01 | Train Loss: 0.9637 | Val mAUC: 0.8046 | Mixup: 53.9%
  ✓ Saved best model (val mAUC: 0.8046)


                                                                             

Epoch 02 | Train Loss: 0.8047 | Val mAUC: 0.7935 | Mixup: 44.3%


                                                                             

Epoch 03 | Train Loss: 0.7324 | Val mAUC: 0.8004 | Mixup: 51.6%


                                                                             

Epoch 04 | Train Loss: 0.6861 | Val mAUC: 0.8035 | Mixup: 48.4%


                                                                             

Epoch 05 | Train Loss: 0.6535 | Val mAUC: 0.8035 | Mixup: 50.2%


                                                                             

Epoch 06 | Train Loss: 0.6066 | Val mAUC: 0.8174 | Mixup: 46.6%
  ✓ Saved best model (val mAUC: 0.8174)


                                                                             

Epoch 07 | Train Loss: 0.5684 | Val mAUC: 0.8198 | Mixup: 53.4%
  ✓ Saved best model (val mAUC: 0.8198)


                                                                             

Epoch 08 | Train Loss: 0.5231 | Val mAUC: 0.8138 | Mixup: 46.6%


                                                                             

Epoch 09 | Train Loss: 0.5684 | Val mAUC: 0.8223 | Mixup: 52.1%
  ✓ Saved best model (val mAUC: 0.8223)


                                                                             

Epoch 10 | Train Loss: 0.4889 | Val mAUC: 0.8175 | Mixup: 52.5%


                                                                             

Epoch 11 | Train Loss: 0.5202 | Val mAUC: 0.8165 | Mixup: 48.4%


                                                                             

Epoch 12 | Train Loss: 0.5410 | Val mAUC: 0.8152 | Mixup: 51.1%


                                                                             

Epoch 13 | Train Loss: 0.4783 | Val mAUC: 0.8133 | Mixup: 47.5%


                                                                             

Epoch 14 | Train Loss: 0.5026 | Val mAUC: 0.8125 | Mixup: 52.1%


                                                                             

Epoch 15 | Train Loss: 0.4723 | Val mAUC: 0.8161 | Mixup: 42.5%


                                                                             

Epoch 16 | Train Loss: 0.5063 | Val mAUC: 0.8120 | Mixup: 46.6%


                                                                             

Epoch 17 | Train Loss: 0.4978 | Val mAUC: 0.8166 | Mixup: 50.7%


                                                                             

Epoch 18 | Train Loss: 0.4467 | Val mAUC: 0.8138 | Mixup: 46.1%


                                                                             

Epoch 19 | Train Loss: 0.5087 | Val mAUC: 0.8172 | Mixup: 53.4%

[INFO] Early stopping at epoch 19

[INFO] Loading best model...
[INFO] Best validation mAUC: 0.8223

[INFO] Finding optimal thresholds on validation...


                                                    

[INFO] Optimal thresholds:
  N: 0.500
  D: 0.220
  G: 0.550
  C: 0.910
  A: 0.730
  H: 0.640
  M: 0.820
  O: 0.480

[INFO] Testing on RFMiD_v1...
--------------------------------------------------------------------------------


                                                    


TEST RESULTS - RFMiD_v1 (Mixup DG, ViT-B/16)
mAUC:      0.8293
mAP:       0.4709
Macro F1:  0.4102

Class    AUC        AP        
------------------------------
N        0.8856     0.6902    
D        0.8261     0.5735    
G        0.7426     0.3669    
C        0.8926     0.7821    
A        0.8343     0.2636    
H        0.9061     0.0164    
M        0.9274     0.5711    
O        0.6195     0.5039    
------------------------------

[INFO] Saving results...
[INFO] ✓ Saved test_results.csv

[INFO] Generating visualizations...
[INFO] ✓ Saved training curves
[INFO] ✓ Saved per-class ROC curves


  macro_auc = np.trapz(mean_tpr, mean_fpr)


[INFO] ✓ Saved macro-average ROC curve
[INFO] ✓ Saved per-class metrics chart

✓ MIXUP LODO FOLD (SWAPPED) COMPLETE! (ViT-B/16)
✓ Results saved to: ./results_lodo_mixup/fold_test_RFMiD_v1/
