In [1]:
# ============================================================
# FULL UPDATED CODE: ConvNeXt-Tiny + Mixup DG (FIXED SHAPES)
# Fixes:
#  - mean/std KeyError (no weights.meta["mean"])
#  - ImageClassification preset has no .transforms attribute
#  - ConvNeXt output is [B, C, 1, 1] -> flatten to [B, C] before Linear
# ============================================================

import os
import random
from typing import Dict, Tuple, Any

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset

import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.models import ConvNeXt_Tiny_Weights

from sklearn.metrics import roc_auc_score, f1_score, roc_curve, average_precision_score
import matplotlib.pyplot as plt
from tqdm import tqdm


# ============================================================
# 1) HARMONIZATION
# ============================================================

def harmonize_odir(split: str) -> pd.DataFrame:
    base_path = "/kaggle/input/odir-clr/ODIR_CLR"

    if split == "train":
        img_dir = f"{base_path}/Training_ Set/train_images"
        label_file = f"{base_path}/Training_ Set/train_annotation.xlsx"
    elif split == "val":
        img_dir = f"{base_path}/Validation_set/val_images"
        label_file = f"{base_path}/Validation_set/val_annotation.xlsx"
    else:
        img_dir = f"{base_path}/Test_Set/test_images"
        label_file = f"{base_path}/Test_Set/test_annotation.xlsx"

    df = pd.read_excel(label_file)
    label_cols = ['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']
    harmonized_rows = []

    for _, row in df.iterrows():
        patient_id = row['ID']
        labels = {col: int(row[col]) for col in label_cols}

        if pd.notna(row['Left-Fundus']) and str(row['Left-Fundus']).strip():
            left_path = os.path.join(img_dir, f"{patient_id}_left.jpg")
            harmonized_rows.append({
                'image_path': left_path,
                'dataset': 'ODIR',
                'split': split,
                'ID': patient_id,
                **labels
            })

        if pd.notna(row['Right-Fundus']) and str(row['Right-Fundus']).strip():
            right_path = os.path.join(img_dir, f"{patient_id}_right.jpg")
            harmonized_rows.append({
                'image_path': right_path,
                'dataset': 'ODIR',
                'split': split,
                'ID': patient_id,
                **labels
            })

    return pd.DataFrame(harmonized_rows)


def harmonize_rfmid_v1(split: str) -> pd.DataFrame:
    base_path = "/kaggle/input/retinal-disease-classification"

    if split == "train":
        img_dir = f"{base_path}/Training_Set/Training_Set/Training"
        label_file = f"{base_path}/Training_Set/Training_Set/RFMiD_Training_Labels.csv"
    elif split == "val":
        img_dir = f"{base_path}/Evaluation_Set/Evaluation_Set/Validation"
        label_file = f"{base_path}/Evaluation_Set/Evaluation_Set/RFMiD_Validation_Labels.csv"
    else:
        img_dir = f"{base_path}/Test_Set/Test_Set/Test"
        label_file = f"{base_path}/Test_Set/Test_Set/RFMiD_Testing_Labels.csv"

    df = pd.read_csv(label_file)

    all_disease_cols = [
        'DR', 'ARMD', 'MH', 'DN', 'MYA', 'BRVO', 'TSLN', 'ERM', 'LS', 'MS', 'CSR',
        'ODC', 'CRVO', 'TV', 'AH', 'ODP', 'ODE', 'ST', 'AION', 'PT', 'RT', 'RS',
        'CRS', 'EDN', 'RPEC', 'MHL', 'RP', 'CWS', 'CB', 'ODPM', 'PRH', 'MNF',
        'HR', 'CRAO', 'TD', 'CME', 'PTCR', 'CF', 'VH', 'MCA', 'VS', 'BRAO',
        'PLQ', 'HPED', 'CL'
    ]

    harmonized_rows = []

    for _, row in df.iterrows():
        image_id = row['ID']

        image_path = None
        for ext in ['.png', '.jpg', '.jpeg']:
            candidate = os.path.join(img_dir, f"{image_id}{ext}")
            if os.path.exists(candidate):
                image_path = candidate
                break
        if image_path is None:
            image_path = os.path.join(img_dir, str(image_id))

        N = 1 if row['Disease_Risk'] == 0 else 0
        D = 1 if row['DR'] == 1 else 0
        G = 1 if row['ODC'] == 1 else 0
        C = 1 if row['MH'] == 1 else 0
        A = 1 if row['ARMD'] == 1 else 0
        H = 1 if row['HR'] == 1 else 0
        M = 1 if row['MYA'] == 1 else 0

        used_for_mapping = ['DR', 'ODC', 'MH', 'ARMD', 'HR', 'MYA']
        other_cols = [col for col in all_disease_cols if col not in used_for_mapping]
        O = 1 if any(row[col] == 1 for col in other_cols) else 0

        harmonized_rows.append({
            'image_path': image_path,
            'dataset': 'RFMiD_v1',
            'split': split,
            'ID': image_id,
            'N': N, 'D': D, 'G': G, 'C': C, 'A': A, 'H': H, 'M': M, 'O': O
        })

    return pd.DataFrame(harmonized_rows)


def harmonize_rfmid_v2(split: str) -> pd.DataFrame:
    base_path = "/kaggle/input/rdc-version-2/RFDiM2_0"

    if split == "train":
        img_dir = f"{base_path}/Training_set_2/Train_2"
        label_file = f"{base_path}/Training_set_2/RFMiD_2_Training_labels.csv"
    elif split == "val":
        img_dir = f"{base_path}/Validation_set_2/Validation_2"
        label_file = f"{base_path}/Validation_set_2/RFMiD_2_Validation_labels.csv"
    else:
        img_dir = f"{base_path}/Test_set_2/Test_2"
        label_file = f"{base_path}/Test_set_2/RFMiD_2_Testing_labels.csv"

    try:
        df = pd.read_csv(label_file, encoding="utf-8")
    except UnicodeDecodeError:
        df = pd.read_csv(label_file, encoding="latin1")

    df.columns = df.columns.str.strip()

    potential_disease_cols = [
        'AH', 'AION', 'ARMD', 'BRVO', 'CB', 'CF', 'CL', 'CME', 'CNV', 'CRAO',
        'CRS', 'CRVO', 'CSR', 'CWS', 'CSC', 'DN', 'DR', 'EDN', 'ERM', 'GRT',
        'HPED', 'HR', 'LS', 'MCA', 'ME', 'MH', 'MHL', 'MS', 'MYA', 'ODC',
        'ODE', 'ODP', 'ON', 'OPDM', 'PRH', 'RD', 'RHL', 'RTR', 'RP', 'RPEC',
        'RS', 'RT', 'SOFE', 'ST', 'TD', 'TSLN', 'TV', 'VS', 'HTN', 'IIH'
    ]
    all_disease_cols = [c for c in potential_disease_cols if c in df.columns]

    harmonized_rows = []
    skipped_count, found_count = 0, 0

    for _, row in df.iterrows():
        image_id = int(row['ID'])

        image_path = None
        for ext in ['.jpg', '.JPG', '.png', '.PNG', '.jpeg', '.JPEG']:
            candidate = os.path.join(img_dir, f"{image_id}{ext}")
            if os.path.exists(candidate):
                image_path = candidate
                found_count += 1
                break

        if image_path is None:
            skipped_count += 1
            continue

        wnl = row.get('WNL', 0)
        N = 1 if wnl == 1 else 0
        D = 1 if row.get('DR', 0) == 1 else 0
        G = 1 if row.get('ODC', 0) == 1 else 0
        C = 1 if row.get('MH', 0) == 1 else 0
        A = 1 if row.get('ARMD', 0) == 1 else 0

        H = 1 if row.get('HTN', 0) == 1 else 0
        if H == 0 and 'HR' in df.columns:
            H = 1 if row.get('HR', 0) == 1 else 0

        M = 1 if row.get('MYA', 0) == 1 else 0

        used_for_mapping = ['DR', 'ODC', 'MH', 'ARMD', 'HTN', 'HR', 'MYA', 'WNL']
        other_cols = [c for c in all_disease_cols if c not in used_for_mapping]
        O = 1 if any(row.get(c, 0) == 1 for c in other_cols) else 0

        harmonized_rows.append({
            'image_path': image_path,
            'dataset': 'RFMiD_v2',
            'split': split,
            'ID': image_id,
            'N': N, 'D': D, 'G': G, 'C': C, 'A': A, 'H': H, 'M': M, 'O': O
        })

    if skipped_count > 0:
        print(f"RFMiD_v2 {split}: Found {found_count}, Skipped {skipped_count}")

    return pd.DataFrame(harmonized_rows)


def harmonize_all_datasets() -> Dict[str, pd.DataFrame]:
    results = {}
    print("\n" + "=" * 60)
    print("HARMONIZING DATASETS")
    print("=" * 60)

    for split in ['train', 'val', 'test']:
        print(f"\nProcessing {split} split...")
        results[f'ODIR_{split}'] = harmonize_odir(split)
        results[f'RFMiD_v1_{split}'] = harmonize_rfmid_v1(split)
        results[f'RFMiD_v2_{split}'] = harmonize_rfmid_v2(split)

    return results


def print_statistics(harmonized_data: Dict[str, pd.DataFrame]):
    label_cols = ['N', 'D', 'G', 'C', 'A', 'H', 'M', 'O']

    print("\n" + "=" * 60)
    print("DATASET STATISTICS")
    print("=" * 60)

    print("\nSample Counts:")
    print("-" * 40)
    for key, df in sorted(harmonized_data.items()):
        print(f"{key:20s}: {len(df):5d} images")

    total_train = sum(len(df) for k, df in harmonized_data.items() if 'train' in k)
    total_val = sum(len(df) for k, df in harmonized_data.items() if 'val' in k)
    total_test = sum(len(df) for k, df in harmonized_data.items() if 'test' in k)

    print("-" * 40)
    print(f"{'Total Train':20s}: {total_train:5d} images")
    print(f"{'Total Val':20s}: {total_val:5d} images")
    print(f"{'Total Test':20s}: {total_test:5d} images")
    print(f"{'Grand Total':20s}: {total_train + total_val + total_test:5d} images")


def save_harmonized_data(harmonized_data: Dict[str, pd.DataFrame], output_dir: str = './harmonized_labels'):
    os.makedirs(output_dir, exist_ok=True)

    print("\n" + "=" * 60)
    print("SAVING HARMONIZED DATA")
    print("=" * 60)

    for key, df in harmonized_data.items():
        output_path = os.path.join(output_dir, f"{key}.csv")
        df.to_csv(output_path, index=False)
        print(f"Saved {key:20s}: {len(df):5d} rows -> {output_path}")

    print("\nCreating combined files...")
    for split in ['train', 'val', 'test']:
        split_dfs = [v for k, v in harmonized_data.items() if k.endswith(f'_{split}')]
        if split_dfs:
            combined = pd.concat(split_dfs, ignore_index=True)
            output_path = os.path.join(output_dir, f"combined_{split}.csv")
            combined.to_csv(output_path, index=False)
            print(f"Saved combined_{split:5s}: {len(combined):5d} rows -> {output_path}")

    print(f"\nAll files saved to: {output_dir}")


def verify_images(harmonized_data: Dict[str, pd.DataFrame]):
    print("\n" + "=" * 60)
    print("VERIFYING IMAGE PATHS")
    print("=" * 60)

    all_good = True
    for key, df in sorted(harmonized_data.items()):
        missing = df[~df['image_path'].apply(os.path.exists)]
        if len(missing) > 0:
            print(f"{key:20s}: {len(missing)} missing images")
            all_good = False
        else:
            print(f"{key:20s}: All {len(df)} images found")

    if all_good:
        print("\nAll image paths verified successfully!")
    else:
        print("\nSome images are missing - check the paths above")


# ============================================================
# 2) TRAINING CONFIG
# ============================================================

SEED = 42
SAVE_DIR = './results_lodo_mixup/fold_test_RFMiD_v1'
TEST_DOMAIN = "RFMiD_v1"
TRAIN_DOMAINS = "ODIR + RFMiD_v2"
MIXUP_ALPHA = 0.2

TEST_CSV = '/kaggle/working/harmonized_labels/RFMiD_v1_test.csv'
TRAIN_CSVS = [
    '/kaggle/working/harmonized_labels/ODIR_train.csv',
    '/kaggle/working/harmonized_labels/RFMiD_v2_train.csv'
]
VAL_CSVS = [
    '/kaggle/working/harmonized_labels/ODIR_val.csv',
    '/kaggle/working/harmonized_labels/RFMiD_v2_val.csv'
]


def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)


# ============================================================
# Dataset
# ============================================================

class RetinalDataset(Dataset):
    def __init__(self, csv_path, domain_id, transform=None):
        self.data = pd.read_csv(csv_path)
        self.domain_id = domain_id
        self.transform = transform
        self.label_cols = ["N", "D", "G", "C", "A", "H", "M", "O"]

        dup = self.data.duplicated(subset=["image_path"], keep="first").sum()
        if dup > 0:
            print(f"[WARN] {dup} duplicates found, keeping first")
        self.data = self.data.drop_duplicates(subset=["image_path"]).reset_index(drop=True)

        valid_mask = self.data["image_path"].apply(os.path.exists)
        missing = int((~valid_mask).sum())
        if missing > 0:
            print(f"[WARN] Dropping {missing} missing images")
        self.data = self.data.loc[valid_mask].reset_index(drop=True)

        print(f"[INFO] Domain {domain_id}: Loaded {len(self.data)} valid images")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        try:
            image = Image.open(row["image_path"]).convert("RGB")
        except Exception:
            image = Image.new("RGB", (224, 224), color=(128, 128, 128))

        labels = torch.tensor(row[self.label_cols].values.astype("float32"))

        if self.transform:
            image = self.transform(image)

        return image, labels, self.domain_id


# ============================================================
# Pos weights
# ============================================================

def calculate_pos_weights(datasets, clip_min=0.5, clip_max=50.0):
    label_cols = ["N", "D", "G", "C", "A", "H", "M", "O"]
    combined = np.vstack([ds.data[label_cols].values for ds in datasets])

    pos = combined.sum(axis=0)
    neg = len(combined) - pos
    raw = neg / (pos + 1e-5)
    clipped = np.clip(raw, clip_min, clip_max)

    print("\n[INFO] Positive class weights (from training domains):")
    for i, col in enumerate(label_cols):
        print(f"  {col}: {clipped[i]:.2f}")

    return torch.tensor(clipped, dtype=torch.float32)


# ============================================================
# ConvNeXt preprocessing params (robust across torchvision)
# ============================================================

def _to_int_square(x: Any, default: int) -> int:
    if x is None:
        return default
    if isinstance(x, int):
        return x
    if isinstance(x, (tuple, list)) and len(x) > 0:
        return int(x[0])
    return default


def get_convnext_preprocess_params():
    weights = ConvNeXt_Tiny_Weights.DEFAULT
    preset = weights.transforms()  # callable preset

    mean = getattr(preset, "mean", None)
    std = getattr(preset, "std", None)

    # fallback
    if mean is None or std is None:
        meta = getattr(weights, "meta", {}) or {}
        mean = meta.get("mean", (0.485, 0.456, 0.406))
        std = meta.get("std", (0.229, 0.224, 0.225))

    crop_size = _to_int_square(getattr(preset, "crop_size", None), default=224)
    resize_size = _to_int_square(getattr(preset, "resize_size", None), default=256)

    try:
        interp = transforms.InterpolationMode.BILINEAR
    except Exception:
        interp = 2  # PIL BILINEAR int fallback

    antialias = bool(getattr(preset, "antialias", True))
    return tuple(mean), tuple(std), crop_size, resize_size, interp, antialias


MEAN, STD, CROP, RESIZE, INTERP, ANTIALIAS = get_convnext_preprocess_params()


def _safe_resize(size):
    try:
        return transforms.Resize(size, interpolation=INTERP, antialias=ANTIALIAS)
    except TypeError:
        return transforms.Resize(size, interpolation=INTERP)


def get_transforms(is_train=False):
    if is_train:
        return transforms.Compose([
            _safe_resize((RESIZE, RESIZE)),
            transforms.RandomCrop(CROP),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
            transforms.ToTensor(),
            transforms.Normalize(MEAN, STD),
        ])
    else:
        return transforms.Compose([
            _safe_resize((RESIZE, RESIZE)),
            transforms.CenterCrop(CROP),
            transforms.ToTensor(),
            transforms.Normalize(MEAN, STD),
        ])


# ============================================================
# Model: ConvNeXt-Tiny (FIXED FORWARD)
# ============================================================

class ConvNeXtTinyMultiLabel(nn.Module):
    def __init__(self, num_classes=8, dropout=0.3, pretrained=True):
        super().__init__()
        weights = ConvNeXt_Tiny_Weights.DEFAULT if pretrained else None
        self.backbone = models.convnext_tiny(weights=weights)

        # convnext_tiny final feature dim is usually 768
        self.feature_dim = self.backbone.classifier[-1].in_features  # safe way

        # We'll NOT use the built-in classifier; we use features + avgpool ourselves
        self.backbone.classifier = nn.Identity()

        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(self.feature_dim, 512),
            nn.GELU(),
            nn.BatchNorm1d(512),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        # IMPORTANT: use features -> avgpool -> flatten -> classifier
        x = self.backbone.features(x)      # [B, C, H, W]
        x = self.backbone.avgpool(x)       # [B, C, 1, 1]
        x = torch.flatten(x, 1)            # [B, C]
        return self.classifier(x)          # [B, num_classes]


# ============================================================
# Metrics
# ============================================================

def compute_metrics(labels, probs, thresholds=None):
    n_classes = labels.shape[1]

    aucs = []
    for i in range(n_classes):
        if len(np.unique(labels[:, i])) > 1:
            aucs.append(roc_auc_score(labels[:, i], probs[:, i]))
        else:
            aucs.append(np.nan)

    aps = []
    for i in range(n_classes):
        if len(np.unique(labels[:, i])) > 1:
            aps.append(average_precision_score(labels[:, i], probs[:, i]))
        else:
            aps.append(np.nan)

    if thresholds is None:
        thresholds = np.full(n_classes, 0.5)

    preds = (probs >= thresholds).astype(int)
    f1 = f1_score(labels, preds, average="macro", zero_division=0)

    return {
        "mAUC": float(np.nanmean(aucs)),
        "mAP": float(np.nanmean(aps)),
        "per_class_auc": aucs,
        "per_class_ap": aps,
        "macro_f1": float(f1),
    }


def find_optimal_thresholds(labels, probs):
    n_classes = labels.shape[1]
    thresholds = []
    search_range = np.linspace(0.05, 0.95, 91)

    for i in range(n_classes):
        best_f1, best_t = 0.0, 0.5
        if len(np.unique(labels[:, i])) > 1:
            for t in search_range:
                preds = (probs[:, i] >= t).astype(int)
                f1 = f1_score(labels[:, i], preds, zero_division=0)
                if f1 > best_f1:
                    best_f1, best_t = f1, t
        thresholds.append(best_t)

    return np.array(thresholds)


# ============================================================
# Mixup Training + Validation
# ============================================================

def train_epoch_mixup(model, loaders_dict, criterion, optimizer, device, mixup_alpha=0.2):
    model.train()
    losses = []
    mixup_stats = {'total_batches': 0, 'mixup_batches': 0}

    domain_iters = {k: iter(v) for k, v in loaders_dict.items()}
    domain_ids = list(loaders_dict.keys())

    max_batches = max(len(loader) for loader in loaders_dict.values())
    pbar = tqdm(range(max_batches), desc="Train (Mixup)", leave=False)

    for _ in pbar:
        do_mixup = (len(domain_ids) >= 2) and (np.random.rand() > 0.5)

        if do_mixup:
            d1, d2 = np.random.choice(domain_ids, size=2, replace=False)

            try:
                imgs1, labels1, _ = next(domain_iters[d1])
            except StopIteration:
                domain_iters[d1] = iter(loaders_dict[d1])
                imgs1, labels1, _ = next(domain_iters[d1])

            try:
                imgs2, labels2, _ = next(domain_iters[d2])
            except StopIteration:
                domain_iters[d2] = iter(loaders_dict[d2])
                imgs2, labels2, _ = next(domain_iters[d2])

            min_size = min(imgs1.size(0), imgs2.size(0))
            imgs1, labels1 = imgs1[:min_size], labels1[:min_size]
            imgs2, labels2 = imgs2[:min_size], labels2[:min_size]

            lam = np.random.beta(mixup_alpha, mixup_alpha)

            mixed_imgs = lam * imgs1 + (1 - lam) * imgs2
            mixed_labels = lam * labels1 + (1 - lam) * labels2

            mixed_imgs = mixed_imgs.to(device)
            mixed_labels = mixed_labels.to(device)

            optimizer.zero_grad()
            logits = model(mixed_imgs)
            loss = criterion(logits, mixed_labels)

            mixup_stats['mixup_batches'] += 1

        else:
            d = np.random.choice(domain_ids)
            try:
                imgs, labels, _ = next(domain_iters[d])
            except StopIteration:
                domain_iters[d] = iter(loaders_dict[d])
                imgs, labels, _ = next(domain_iters[d])

            imgs = imgs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(imgs)
            loss = criterion(logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        losses.append(loss.item())
        mixup_stats['total_batches'] += 1
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})

    return {
        "loss": float(np.mean(losses)),
        "mixup_ratio": mixup_stats['mixup_batches'] / mixup_stats['total_batches']
    }


@torch.no_grad()
def validate(model, loader, criterion, device, thresholds=None):
    model.eval()
    losses, all_probs, all_labels = [], [], []

    for batch in tqdm(loader, desc="Val", leave=False):
        if len(batch) == 3:
            images, labels, _ = batch
        else:
            images, labels = batch

        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = criterion(logits, labels)

        losses.append(loss.item())
        all_probs.append(torch.sigmoid(logits).cpu().numpy())
        all_labels.append(labels.cpu().numpy())

    probs = np.vstack(all_probs)
    labels = np.vstack(all_labels)
    metrics = compute_metrics(labels, probs, thresholds)
    metrics["loss"] = float(np.mean(losses))
    return metrics, probs, labels


# ============================================================
# MAIN
# ============================================================

if __name__ == "__main__":
    # If you already created harmonized CSVs, comment these out.
    harmonized_data = harmonize_all_datasets()
    print_statistics(harmonized_data)
    verify_images(harmonized_data)
    save_harmonized_data(harmonized_data, output_dir="./harmonized_labels")

    set_seed(SEED)
    os.makedirs(SAVE_DIR, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("=" * 80)
    print(f"LODO FOLD with MIXUP DG: Test on {TEST_DOMAIN}")
    print(f"Training on: {TRAIN_DOMAINS}")
    print("=" * 80)
    print(f"Device: {device}")
    print(f"Seed: {SEED}")
    print(f"Mixup Alpha: {MIXUP_ALPHA}")
    print("=" * 80)

    print("\n[INFO] Loading datasets...")
    train_datasets = [RetinalDataset(csv, domain_id=i, transform=get_transforms(True))
                      for i, csv in enumerate(TRAIN_CSVS)]
    val_datasets = [RetinalDataset(csv, domain_id=i, transform=get_transforms(False))
                    for i, csv in enumerate(VAL_CSVS)]
    test_dataset = RetinalDataset(TEST_CSV, domain_id=999, transform=get_transforms(False))

    print(f"\n[INFO] Total train: {sum(len(ds) for ds in train_datasets)}")
    print(f"[INFO] Total val:   {sum(len(ds) for ds in val_datasets)}")
    print(f"[INFO] Test:        {len(test_dataset)}")

    pos_weights = calculate_pos_weights(train_datasets).to(device)

    g = torch.Generator().manual_seed(SEED)
    train_loaders = {
        i: DataLoader(ds, batch_size=32, shuffle=True, num_workers=4,
                      pin_memory=True, generator=g)
        for i, ds in enumerate(train_datasets)
    }

    combined_val = ConcatDataset(val_datasets)
    val_loader = DataLoader(combined_val, batch_size=32, shuffle=False,
                            num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                             num_workers=4, pin_memory=True)

    print("\n[INFO] Initializing ConvNeXt-Tiny model...")
    model = ConvNeXtTinyMultiLabel(num_classes=8, dropout=0.3, pretrained=True).to(device)

    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
    optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

    history = {'train_loss': [], 'val_loss': [], 'val_auc': []}
    best_val_auc, patience_counter = 0.0, 0

    print("\n[INFO] Training started...")
    print("-" * 80)

    for epoch in range(50):
        train_metrics = train_epoch_mixup(model, train_loaders, criterion, optimizer, device,
                                          mixup_alpha=MIXUP_ALPHA)
        val_metrics, _, _ = validate(model, val_loader, criterion, device)
        scheduler.step(val_metrics['mAUC'])

        history['train_loss'].append(train_metrics['loss'])
        history['val_loss'].append(val_metrics['loss'])
        history['val_auc'].append(val_metrics['mAUC'])

        print(f"Epoch {epoch+1:02d} | Train Loss: {train_metrics['loss']:.4f} "
              f"| Val Loss: {val_metrics['loss']:.4f} | Val mAUC: {val_metrics['mAUC']:.4f} "
              f"| Mixup: {train_metrics['mixup_ratio']:.1%}")

        if val_metrics['mAUC'] > best_val_auc:
            best_val_auc = val_metrics['mAUC']
            torch.save(model.state_dict(), f'{SAVE_DIR}/best_model.pth')
            print(f"  [OK] Saved best model (val mAUC: {best_val_auc:.4f})")
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= 10:
            print(f"\n[INFO] Early stopping at epoch {epoch+1}")
            break

    print(f"\n[INFO] Loading best model...")
    model.load_state_dict(torch.load(f'{SAVE_DIR}/best_model.pth', map_location=device))
    print(f"[INFO] Best validation mAUC: {best_val_auc:.4f}")

    print("\n[INFO] Finding optimal thresholds on validation...")
    _, val_probs, val_labels = validate(model, val_loader, criterion, device)
    thresholds = find_optimal_thresholds(val_labels, val_probs)

    class_names = ["N", "D", "G", "C", "A", "H", "M", "O"]
    print("[INFO] Optimal thresholds:")
    for c, t in zip(class_names, thresholds):
        print(f"  {c}: {t:.3f}")

    print(f"\n[INFO] Testing on {TEST_DOMAIN}...")
    test_metrics, _, _ = validate(model, test_loader, criterion, device, thresholds)

    print("\n" + "=" * 80)
    print(f"TEST RESULTS - {TEST_DOMAIN} (ConvNeXt-Tiny + Mixup)")
    print("=" * 80)
    print(f"mAUC:      {test_metrics['mAUC']:.4f}")
    print(f"mAP:       {test_metrics['mAP']:.4f}")
    print(f"Macro F1:  {test_metrics['macro_f1']:.4f}")
    print("=" * 80)


HARMONIZING DATASETS

Processing train split...
RFMiD_v2 train: Found 507, Skipped 2

Processing val split...

Processing test split...
RFMiD_v2 test: Found 170, Skipped 4

DATASET STATISTICS

Sample Counts:
----------------------------------------
ODIR_test           :  2000 images
ODIR_train          :  7000 images
ODIR_val            :  1000 images
RFMiD_v1_test       :   640 images
RFMiD_v1_train      :  1920 images
RFMiD_v1_val        :   640 images
RFMiD_v2_test       :   170 images
RFMiD_v2_train      :   507 images
RFMiD_v2_val        :   177 images
----------------------------------------
Total Train         :  9427 images
Total Val           :  1817 images
Total Test          :  2810 images
Grand Total         : 14054 images

VERIFYING IMAGE PATHS
ODIR_test           : All 2000 images found
ODIR_train          : All 7000 images found
ODIR_val            : All 1000 images found
RFMiD_v1_test       : All 640 images found
RFMiD_v1_train      : All 1920 images found
RFMiD_v1_val

100%|██████████| 109M/109M [00:00<00:00, 205MB/s] 



[INFO] Training started...
--------------------------------------------------------------------------------


                                                                             

Epoch 01 | Train Loss: 0.9446 | Val Loss: 0.9234 | Val mAUC: 0.8163 | Mixup: 53.9%
  [OK] Saved best model (val mAUC: 0.8163)


                                                                             

Epoch 02 | Train Loss: 0.7124 | Val Loss: 0.9093 | Val mAUC: 0.8096 | Mixup: 44.3%


                                                                             

Epoch 03 | Train Loss: 0.6028 | Val Loss: 0.8906 | Val mAUC: 0.8099 | Mixup: 51.6%


                                                                             

Epoch 04 | Train Loss: 0.5717 | Val Loss: 0.8895 | Val mAUC: 0.8349 | Mixup: 48.4%
  [OK] Saved best model (val mAUC: 0.8349)


                                                                             

Epoch 05 | Train Loss: 0.5602 | Val Loss: 0.9278 | Val mAUC: 0.8150 | Mixup: 50.2%


                                                                             

Epoch 06 | Train Loss: 0.5532 | Val Loss: 0.9737 | Val mAUC: 0.8334 | Mixup: 46.6%


                                                                             

Epoch 07 | Train Loss: 0.5029 | Val Loss: 1.1131 | Val mAUC: 0.8181 | Mixup: 53.4%


                                                                             

Epoch 08 | Train Loss: 0.4517 | Val Loss: 1.0811 | Val mAUC: 0.8156 | Mixup: 46.6%


                                                                             

Epoch 09 | Train Loss: 0.4816 | Val Loss: 1.0919 | Val mAUC: 0.8327 | Mixup: 52.1%


                                                                             

Epoch 10 | Train Loss: 0.3712 | Val Loss: 1.1169 | Val mAUC: 0.8328 | Mixup: 52.5%


                                                                             

Epoch 11 | Train Loss: 0.3890 | Val Loss: 1.1644 | Val mAUC: 0.8325 | Mixup: 48.4%


                                                                             

Epoch 12 | Train Loss: 0.3965 | Val Loss: 1.2922 | Val mAUC: 0.8189 | Mixup: 51.1%


                                                                             

Epoch 13 | Train Loss: 0.3155 | Val Loss: 1.2657 | Val mAUC: 0.8218 | Mixup: 47.5%


                                                                             

Epoch 14 | Train Loss: 0.3315 | Val Loss: 1.2757 | Val mAUC: 0.8327 | Mixup: 52.1%

[INFO] Early stopping at epoch 14

[INFO] Loading best model...
[INFO] Best validation mAUC: 0.8349

[INFO] Finding optimal thresholds on validation...


                                                    

[INFO] Optimal thresholds:
  N: 0.400
  D: 0.200
  G: 0.770
  C: 0.950
  A: 0.870
  H: 0.420
  M: 0.830
  O: 0.450

[INFO] Testing on RFMiD_v1...


                                                    


TEST RESULTS - RFMiD_v1 (ConvNeXt-Tiny + Mixup)
mAUC:      0.8414
mAP:       0.4879
Macro F1:  0.3586


