In [None]:
import os
import random
import copy
import json
from dataclasses import dataclass
from typing import List, Tuple, Dict, Any

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import StandardScaler

from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_undirected, coalesce

# ----------------- BASIC CONFIG & UTILS ----------------- #

def set_seed(seed: int = 42):
    """Seed every RNG used by the pipeline and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG, the torch CPU generator and
    the generators of all CUDA devices. ``cudnn.benchmark`` is switched off as
    well: with benchmarking on, cuDNN may select a different (possibly
    non-deterministic) algorithm per run, which defeats ``cudnn.deterministic``.

    Note: this does not enable ``torch.use_deterministic_algorithms``; other
    CUDA ops (e.g. atomic scatter/index_add kernels) may still be
    non-deterministic on GPU.

    Args:
        seed: seed value applied to all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Global compute device: the default CUDA device when one is visible, else CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)

@dataclass
class TrainConfig:
    """Training budget, hyper-parameter search space and I/O locations.

    One instance is built in ``main`` and passed through the whole nested-CV
    pipeline (pretraining, grid search, final training).
    """
    # General training
    num_epochs: int = 300  # max fine-tuning epochs per model
    patience: int = 30  # stop after this many epochs without a val-AUPRC improvement
    kfold: int = 5  # number of outer StratifiedKFold folds
    warmup: int = 10  # train_one_epoch_dual optimises only the y1 loss while epoch < warmup

    # Pretrain (cross-disease)
    pretrain_epochs: int = 300  # max epochs of cross-disease pretraining
    pretrain_patience: int = 30  # early stop on the pretraining *training* loss (no held-out set)
    pretrain_lr: float = 1e-2  # Adam learning rate used only for pretraining

    # Search space
    depth_options: Tuple[int, ...] = (1, 2)  # 1 -> one GCN layer of width hidden[0]; otherwise both widths are used
    hidden_options: Tuple[Tuple[int, int], ...] = ((64, 64), (64, 128), (128, 128))  # (layer-1, layer-2) widths
    dropout_options: Tuple[float, ...] = (0.3, 0.4, 0.5)
    lr_options: Tuple[float, ...] = (1e-2, 3e-3, 1e-3)  # fine-tuning learning rates
    weight_decay_options: Tuple[float, ...] = (1e-4, 5e-4, 1e-3)  # used for both pretraining and fine-tuning

    # IO
    # NOTE(review): looks like a placeholder path ("...."); must point at the directory
    # holding PPI_CPDB.csv, features_for_<disease>.csv, the <disease>_labels(0_1).csv files
    # and dataset/label_telomere.csv.
    base_dir: str = "/content/drive/MyDrive/...."
    target_disease: str = "BRCA"  # disease whose labels are y1; excluded from pretraining diseases
    results_dir: str = "./results_mtdriver"  # JSON results are written here

# ----------------- DATA & LABEL LOADING ----------------- #

def _impute_missing_from_neighbors(x_mat: np.ndarray, has_feat: np.ndarray,
                                   edge_index: torch.Tensor) -> None:
    """Fill feature-less rows of ``x_mat`` in place with their neighbours' mean.

    A row whose ``has_feat`` entry is False and that has at least one
    neighbour with ``has_feat`` True receives the arithmetic mean of those
    neighbours' original feature rows. All other rows are left untouched
    (feature-less nodes with no feature-bearing neighbour stay as they are).
    ``edge_index`` must be symmetric, i.e. contain both directions of every
    undirected edge.
    """
    src, dst = edge_index.cpu().numpy()
    # Each directed pair (u -> v) lists v as a neighbour of u. Only pairs going
    # from a feature-less node to a feature-bearing node contribute; a self-loop
    # never qualifies because both of its ends share the same has_feat value.
    use = ~has_feat[src] & has_feat[dst]
    sums = np.zeros(x_mat.shape, dtype=np.float64)
    counts = np.zeros(x_mat.shape[0], dtype=np.int64)
    # np.add.at is unbuffered, so repeated row indices accumulate correctly.
    np.add.at(sums, src[use], x_mat[dst[use]])
    np.add.at(counts, src[use], 1)
    filled = counts > 0
    x_mat[filled] = sums[filled] / counts[filled][:, None]


def load_graph_and_features(cfg: TrainConfig):
    """Build the PPI graph and the standardised node-feature matrix.

    Nodes are the union of the genes in ``PPI_CPDB.csv`` (its first two
    columns are read as the interacting gene pair) and the genes indexing
    ``features_for_<target_disease>.csv``. Feature columns are z-scored over
    all genes that have a feature row. Genes without one are imputed with the
    mean of their feature-bearing graph neighbours, or stay all-zero when
    there is no such neighbour.

    Args:
        cfg: configuration providing ``base_dir`` and ``target_disease``.

    Returns:
        ``(data, node_names, node_to_idx)``:
        - ``data``: PyG ``Data`` with float32 ``x`` of shape
          ``[num_nodes, num_features]`` and an undirected, coalesced
          ``edge_index``.
        - ``node_names``: sorted array of gene names (``np.unique`` output).
        - ``node_to_idx``: mapping from gene name to row index.
    """
    ppi_df = pd.read_csv(os.path.join(cfg.base_dir, "PPI_CPDB.csv"))
    col1, col2 = ppi_df.columns[:2]
    g1 = ppi_df[col1].astype(str).values
    g2 = ppi_df[col2].astype(str).values
    genes_edge = np.unique(np.concatenate([g1, g2]))

    feat_path = os.path.join(cfg.base_dir, f"features_for_{cfg.target_disease}.csv")
    feat_df = pd.read_csv(feat_path, index_col=0)
    feat_df.index = feat_df.index.astype(str)

    # np.unique sorts, so node indices follow alphabetical gene order.
    node_names = np.unique(np.concatenate([genes_edge, feat_df.index.values]))
    num_nodes = len(node_names)
    print(f"[INFO] #nodes (CPDB union features) = {num_nodes}")
    node_to_idx: Dict[str, int] = {g: i for i, g in enumerate(node_names)}

    src = np.array([node_to_idx[g] for g in g1], dtype=np.int64)
    dst = np.array([node_to_idx[g] for g in g2], dtype=np.int64)
    edge_index = torch.tensor(np.stack([src, dst], axis=0), dtype=torch.long)
    edge_index = to_undirected(edge_index, num_nodes=num_nodes)
    # PyG's coalesce is (edge_index, edge_attr, num_nodes, reduce, ...): pass
    # num_nodes by keyword instead of the torch_sparse-style (index, value, m, n)
    # call, which would bind the node count to `reduce`.
    edge_index, _ = coalesce(edge_index, None, num_nodes=num_nodes)
    print(f"[INFO] #edges after undirected+coalesce = {edge_index.size(1)}")

    feature_dim = feat_df.shape[1]
    x_mat = np.zeros((num_nodes, feature_dim), dtype=np.float32)
    has_feat = np.zeros(num_nodes, dtype=bool)

    genes_in_both = feat_df.index.intersection(pd.Index(node_names))
    print(f"[INFO] #genes with features ∩ graph = {len(genes_in_both)}")

    # Column-wise standardisation over every gene that has a feature row
    # (unsupervised: no label information is involved).
    feat_scaled = StandardScaler().fit_transform(feat_df.loc[genes_in_both].values)
    rows = np.array([node_to_idx[g] for g in genes_in_both], dtype=np.int64)
    x_mat[rows] = feat_scaled
    has_feat[rows] = True

    _impute_missing_from_neighbors(x_mat, has_feat, edge_index)

    data = Data(x=torch.tensor(x_mat, dtype=torch.float32), edge_index=edge_index)
    return data, node_names, node_to_idx

def _read_gene_labels(csv_path: str, node_to_idx: Dict[str, int], num_nodes: int) -> torch.Tensor:
    """Read a ``Gene``/``Labels`` CSV into a per-node long tensor.

    Nodes absent from the file keep -1. Genes missing from ``node_to_idx`` are
    ignored. When a gene is listed more than once, its last row wins.
    """
    table = pd.read_csv(csv_path)
    table["Gene"] = table["Gene"].astype(str)
    labels = torch.full((num_nodes,), -1, dtype=torch.long)
    for gene, label in dict(zip(table["Gene"], table["Labels"])).items():
        if gene in node_to_idx:
            labels[node_to_idx[gene]] = int(label)
    return labels


def load_labels(cfg: TrainConfig, data: Data, node_to_idx: Dict[str, int]):
    """Attach the main (y) and auxiliary (y2) node labels and move ``data`` to DEVICE.

    - ``y`` comes from ``<target_disease>_labels(0_1).csv``.
    - ``y2`` comes from ``dataset/label_telomere.csv``.
    In both tensors, -1 marks an unlabelled node.

    Returns:
        ``(data, labeled_idx)`` where ``labeled_idx`` holds the node indices
        with ``y != -1``, on DEVICE.
    """
    n_nodes = data.num_nodes
    main_path = os.path.join(cfg.base_dir, f"{cfg.target_disease}_labels(0_1).csv")
    aux_path = os.path.join(cfg.base_dir, "dataset", "label_telomere.csv")
    y_main = _read_gene_labels(main_path, node_to_idx, n_nodes)
    y_aux = _read_gene_labels(aux_path, node_to_idx, n_nodes)

    data = data.to(DEVICE)
    data.y = y_main.to(DEVICE)
    data.y2 = y_aux.to(DEVICE)
    labeled_idx = torch.nonzero(data.y != -1, as_tuple=True)[0]
    print(f"[INFO] #labeled (y1 != -1) = {len(labeled_idx)}")
    return data, labeled_idx

# ----------------- PRETRAIN META ----------------- #

def build_disease_dict(base_dir: str) -> Dict[str, Dict[str, str]]:
    """Map each cancer-type code to its label-file paths.

    Each entry has a single key, ``"Y1"``, pointing to
    ``<base_dir>/<code>_labels(0_1).csv``. Codes keep the listed order.
    """
    codes = (
        "BRCA", "LUAD", "CESC", "BLCA", "LIHC", "THCA",
        "ESCA", "PRAD", "STAD", "COAD", "UCEC", "LUSC",
    )
    paths: Dict[str, Dict[str, str]] = {}
    for code in codes:
        paths[code] = {"Y1": os.path.join(base_dir, f"{code}_labels(0_1).csv")}
    return paths

# ----------------- MODEL DEFINITIONS ----------------- #

class ResidualGCNEncoder(nn.Module):
    """Stack of GCNConv layers with a single residual connection around the stack.

    Each layer is followed by ReLU and dropout. The raw input ``x`` is then
    added to the last hidden representation. The skip goes through a
    bias-free linear projection when ``in_dim`` differs from
    ``hidden_dims[-1]``, and is used unchanged otherwise. An optional
    LayerNorm is applied to the sum.
    """

    def __init__(self, in_dim: int, hidden_dims: List[int],
                 dropout: float = 0.5, use_layernorm: bool = False):
        super().__init__()
        assert len(hidden_dims) >= 1
        widths = [in_dim] + list(hidden_dims)
        # Layers are created in order, so RNG consumption matches a plain loop.
        self.convs = nn.ModuleList(
            GCNConv(width_in, width_out)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        )
        out_dim = hidden_dims[-1]
        if in_dim == out_dim:
            self.res_proj = nn.Identity()
        else:
            self.res_proj = nn.Linear(in_dim, out_dim, bias=False)
        self.ln = nn.LayerNorm(out_dim) if use_layernorm else nn.Identity()
        self.dropout = dropout

        # Xavier-uniform re-initialisation: conv weights first, then the projection.
        for layer in self.convs:
            nn.init.xavier_uniform_(layer.lin.weight)
        if isinstance(self.res_proj, nn.Linear):
            nn.init.xavier_uniform_(self.res_proj.weight)

    def forward(self, x, edge_index):
        out = x
        for layer in self.convs:
            out = F.dropout(F.relu(layer(out, edge_index)),
                            p=self.dropout, training=self.training)
        skip = self.res_proj(x) if isinstance(self.res_proj, nn.Linear) else x
        return self.ln(out + skip)

class GCN_Residual_TwoHeads(nn.Module):
    """Residual GCN encoder, one shared Linear-ReLU-Dropout layer, and two logit heads.

    ``out_y1`` scores the main task and ``out_y2`` the auxiliary task. Each
    head emits one logit per node; the trailing dimension is squeezed away.
    ``head_hidden`` defaults to the encoder's output width.
    """

    def __init__(self, in_dim: int, hidden_dims: List[int],
                 dropout: float = 0.5, use_layernorm: bool = False,
                 head_hidden: int = None):
        super().__init__()
        self.encoder = ResidualGCNEncoder(in_dim, hidden_dims,
                                          dropout=dropout,
                                          use_layernorm=use_layernorm)
        enc_dim = hidden_dims[-1]
        mid_dim = enc_dim if head_hidden is None else head_hidden
        self.shared = nn.Sequential(
            nn.Linear(enc_dim, mid_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.out_y1 = nn.Linear(mid_dim, 1)
        self.out_y2 = nn.Linear(mid_dim, 1)

        # Kaiming-uniform weights: ReLU gain for the shared layer, sigmoid gain for the heads.
        for lin in [m for m in self.shared if isinstance(m, nn.Linear)]:
            nn.init.kaiming_uniform_(lin.weight, nonlinearity="relu")
        for head in (self.out_y1, self.out_y2):
            nn.init.kaiming_uniform_(head.weight, nonlinearity="sigmoid")

    def forward(self, x, edge_index, return_h: bool = False):
        z = self.shared(self.encoder(x, edge_index))
        logit1 = self.out_y1(z).squeeze(-1)
        logit2 = self.out_y2(z).squeeze(-1)
        return (logit1, logit2, z) if return_h else (logit1, logit2)

class LearnableAlpha(nn.Module):
    """Scalar task-mixing weight ``alpha`` kept strictly inside (0, 1).

    The module stores ``logit(alpha)`` as its only parameter. Calling it
    returns ``sigmoid(logit_alpha)``, so unconstrained gradient updates can
    never push ``alpha`` outside (0, 1).

    Args:
        init_alpha: initial value of alpha. Must lie strictly between 0 and 1.

    Raises:
        ValueError: if ``init_alpha`` is not strictly inside (0, 1). Its logit
            would be infinite or NaN and the parameter untrainable.
    """

    def __init__(self, init_alpha: float = 0.7):
        super().__init__()
        init_alpha = float(init_alpha)
        # Written as `not (0 < a < 1)` so that NaN is rejected too.
        if not 0.0 < init_alpha < 1.0:
            raise ValueError(
                f"init_alpha must be strictly between 0 and 1, got {init_alpha}"
            )
        self.logit_alpha = nn.Parameter(torch.logit(torch.tensor(init_alpha)))

    def forward(self) -> torch.Tensor:
        return torch.sigmoid(self.logit_alpha)

# ----------------- LOSS / METRIC UTILS ----------------- #

@torch.no_grad()
def auprc_on_mask(logits: torch.Tensor, y_long: torch.Tensor, mask: torch.Tensor) -> float:
    """Average precision of ``sigmoid(logits)`` against ``y_long`` on the masked nodes.

    Returns 0.0 instead of raising when the mask selects nothing, or when the
    selected labels do not contain both a 1 and a 0.
    """
    if not mask.sum().item():
        return 0.0
    y_true = y_long[mask].detach().cpu().numpy()
    n_pos = (y_true == 1).sum()
    n_neg = (y_true == 0).sum()
    if n_pos == 0 or n_neg == 0:
        return 0.0
    scores = logits[mask].sigmoid().detach().cpu().numpy()
    return float(average_precision_score(y_true, scores))

def bce_pos_weight_from_mask(y_long: torch.Tensor, mask: torch.Tensor, device) -> torch.Tensor:
    """Compute ``#neg / #pos`` over the masked labels as a 1-element float tensor.

    The result is intended as ``pos_weight`` for ``BCEWithLogitsLoss``.
    Returns None when the mask selects no labels or no positive label.
    """
    selected = y_long[mask]
    if selected.numel() == 0:
        return None
    n_pos = int((selected == 1).sum())
    if n_pos == 0:
        return None
    n_neg = int((selected == 0).sum())
    return torch.tensor([n_neg / float(n_pos)], dtype=torch.float, device=device)

# ----------------- CROSS-DISEASE PRETRAIN ----------------- #

def build_cross_disease_pretrain_labels(
    node_to_idx: Dict[str, int],
    target_name: str,
    data_y: torch.Tensor,
    base_dir: str
):
    """Derive pretraining pseudo-labels from every non-target disease's label file.

    Labels are drawn from the files of ``build_disease_dict(base_dir)``,
    skipping ``target_name``:
    - a gene listed with label 1 in any of those files becomes positive (1);
    - a gene never listed in any of them becomes negative (0);
    - a gene listed only with other labels keeps -1 and is ignored.

    Nodes labelled in the target task (``data_y != -1``) are removed from the
    returned mask, so pretraining never uses target-labelled nodes.

    Returns:
        ``(y_pre, pretrain_mask)`` on DEVICE: long labels in {-1, 0, 1} and a
        boolean mask.
    """
    # NOTE(review): explicit 0-labels from other diseases are *excluded* rather
    # than used as negatives, while never-listed genes are negatives — confirm
    # this is the intended negative-sampling scheme.
    diseases = build_disease_dict(base_dir)
    num_nodes = len(node_to_idx)
    target_labeled = data_y.detach().cpu() != -1

    listed, positive = set(), set()
    for disease, meta in diseases.items():
        if disease == target_name:
            continue
        table = pd.read_csv(meta["Y1"])
        genes = table["Gene"].astype(str)
        for gene, label in zip(genes, table["Labels"]):
            if gene not in node_to_idx:
                continue
            node = node_to_idx[gene]
            listed.add(node)
            if int(label) == 1:
                positive.add(node)

    y_pre = torch.full((num_nodes,), -1, dtype=torch.long)
    y_pre[torch.tensor(sorted(positive), dtype=torch.long)] = 1
    never_listed = torch.ones(num_nodes, dtype=torch.bool)
    never_listed[torch.tensor(sorted(listed), dtype=torch.long)] = False
    y_pre[never_listed] = 0

    has_pseudo_label = y_pre != -1
    pretrain_mask = has_pseudo_label & ~target_labeled

    pos_ct = int(((y_pre == 1) & pretrain_mask).sum().item())
    neg_ct = int(((y_pre == 0) & pretrain_mask).sum().item())
    excl = int((target_labeled & has_pseudo_label).sum().item())
    print(f"[Pretrain-XDisease] Pos: {pos_ct} | Neg: {neg_ct} | Excluded target-labeled: {excl}")
    return y_pre.to(DEVICE), pretrain_mask.to(DEVICE)

def pretrain_on_cross_disease(
    data: Data,
    y_pre: torch.Tensor,
    pretrain_mask: torch.Tensor,
    hidden_dims: List[int],
    cfg: TrainConfig,
    dropout: float,
    weight_decay: float,
):
    """Pretrain the encoder, shared layer and y1 head on cross-disease pseudo-labels.

    Training is full-batch with a class-balanced BCE loss
    (``pos_weight = #neg / #pos`` over the masked nodes). Early stopping
    watches the training loss itself, since there is no held-out set: it
    stops after ``cfg.pretrain_patience`` epochs without a drop of more than
    1e-6.

    Returns:
        The lowest-loss state dict restricted to the ``encoder.``, ``shared.``
        and ``out_y1.`` keys, or None when the mask contains no positive node.
    """
    net = GCN_Residual_TwoHeads(
        in_dim=data.num_features,
        hidden_dims=hidden_dims,
        dropout=dropout,
        use_layernorm=False,
        head_hidden=None,
    ).to(DEVICE)
    opt = torch.optim.Adam(net.parameters(), lr=cfg.pretrain_lr, weight_decay=weight_decay)

    targets = y_pre[pretrain_mask]
    n_pos = int((targets == 1).sum().item())
    n_neg = int((targets == 0).sum().item())
    if n_pos == 0:
        print("[Pretrain-XDisease] No positive samples; skipping pretraining.")
        return None

    balance = torch.tensor([n_neg / float(n_pos)], dtype=torch.float, device=DEVICE)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=balance)
    targets_f = targets.float()

    best_state = copy.deepcopy(net.state_dict())
    best_loss = float("inf")
    stale = 0
    for epoch in range(1, cfg.pretrain_epochs + 1):
        net.train()
        opt.zero_grad()
        logits, _ = net(data.x, data.edge_index)
        loss = loss_fn(logits[pretrain_mask], targets_f)
        loss.backward()
        opt.step()

        value = float(loss.item())
        if value < best_loss - 1e-6:
            best_loss = value
            best_state = copy.deepcopy(net.state_dict())
            stale = 0
            continue
        stale += 1
        if stale >= cfg.pretrain_patience:
            print(f"[Pretrain-XDisease] Early stop @ epoch {epoch}, best_loss={best_loss:.4f}")
            break

    transferable = ("encoder.", "shared.", "out_y1.")
    return {name: tensor for name, tensor in best_state.items() if name.startswith(transferable)}

# ----------------- TRAIN / EVAL (MULTITASK) ----------------- #

def train_one_epoch_dual(
    model: nn.Module,
    data: Data,
    train_mask: torch.Tensor,
    y2_mask_train: torch.Tensor,
    opt: torch.optim.Optimizer,
    crit_y1,
    crit_y2,
    alpha_module: LearnableAlpha,
    epoch: int,
    warmup: int = 10,
):
    """Run one full-batch optimisation step on the main (y1) and auxiliary (y2) tasks.

    While ``epoch < warmup`` only the y1 loss is optimised. Afterwards the
    objective is ``alpha * loss1 + (1 - alpha) * loss2`` with
    ``alpha = alpha_module()``. A task whose criterion is None, or whose mask
    is empty, contributes a constant 0. When the resulting objective carries
    no gradient at all, the parameter update is skipped.

    Returns:
        ``(loss1, loss2, total, alpha)`` as Python floats. ``alpha`` is 1.0
        during warmup.
    """
    model.train()
    opt.zero_grad()
    logit1, logit2 = model(data.x, data.edge_index)

    loss1 = (
        crit_y1(logit1[train_mask], data.y[train_mask].float())
        if (crit_y1 is not None and train_mask.sum().item() > 0)
        else torch.tensor(0.0, device=logit1.device)
    )
    loss2 = (
        crit_y2(logit2[y2_mask_train], data.y2[y2_mask_train].float())
        if (crit_y2 is not None and y2_mask_train.sum().item() > 0)
        else torch.tensor(0.0, device=logit2.device)
    )

    # NOTE(review): the callers in this file count epochs from 1, so `epoch < warmup`
    # yields warmup - 1 y1-only epochs; confirm whether `<=` was intended.
    if epoch < warmup:
        total = loss1
        alpha_val = 1.0
    else:
        alpha = alpha_module()
        # d(total)/d(alpha) = loss1 - loss2, so gradient descent shifts weight
        # toward whichever task currently has the smaller loss.
        total = alpha * loss1 + (1.0 - alpha) * loss2
        alpha_val = float(alpha.item())

    # With no supervised term (e.g. crit_y1 is None during warmup) `total` is a
    # constant; backward() would raise "element 0 of tensors does not require grad".
    if total.requires_grad:
        total.backward()
        opt.step()
    return float(loss1.item()), float(loss2.item()), float(total.item()), alpha_val

@torch.no_grad()
def evaluate_y1(model: nn.Module, data: Data, mask: torch.Tensor, criterion_y1):
    """Evaluate the main-task head on the masked nodes, in eval mode and without autograd.

    Returns:
        ``(loss, auprc)`` as floats. ``loss`` is 0.0 when there is no
        criterion or the mask is empty.

    Side effect: leaves ``model`` in eval mode.
    """
    model.eval()
    scores, _ = model(data.x, data.edge_index)
    can_score_loss = criterion_y1 is not None and mask.sum().item() > 0
    loss = criterion_y1(scores[mask], data.y[mask].float()) if can_score_loss else 0.0
    return float(loss), float(auprc_on_mask(scores, data.y, mask))

# ----------------- GRID SEARCH ----------------- #

def _grid_rng_snapshot():
    """Return the current Python, NumPy and torch (CPU and all CUDA devices) RNG states."""
    cuda_states = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
    return random.getstate(), np.random.get_state(), torch.get_rng_state(), cuda_states


def _grid_rng_restore(snapshot):
    """Reinstate RNG states previously captured by ``_grid_rng_snapshot``."""
    py_state, np_state, cpu_state, cuda_states = snapshot
    random.setstate(py_state)
    np.random.set_state(np_state)
    torch.set_rng_state(cpu_state)
    if cuda_states is not None:
        torch.cuda.set_rng_state_all(cuda_states)


def _grid_finetune_best_val_auprc(data, hidden_dims, dr, lr, wd, pretrained_state,
                                  train_mask, val_mask, y2_mask_train,
                                  crit_y1, crit_y2, cfg):
    """Fine-tune one grid configuration and return its best inner-validation AUPRC."""
    model = GCN_Residual_TwoHeads(
        in_dim=data.num_features,
        hidden_dims=hidden_dims,
        dropout=dr,
        use_layernorm=False,
        head_hidden=None,
    ).to(DEVICE)
    if pretrained_state is not None:
        model.load_state_dict(pretrained_state, strict=False)

    alpha_module = LearnableAlpha(init_alpha=0.7).to(DEVICE)
    optimizer = torch.optim.Adam(
        list(model.parameters()) + list(alpha_module.parameters()),
        lr=lr,
        weight_decay=wd,
    )

    best_val_auc = -1.0
    wait = 0
    for ep in range(1, cfg.num_epochs + 1):
        train_one_epoch_dual(
            model, data,
            train_mask, y2_mask_train,
            optimizer, crit_y1, crit_y2,
            alpha_module, ep, warmup=cfg.warmup,
        )
        _, v_auc = evaluate_y1(model, data, val_mask, crit_y1)
        if v_auc > best_val_auc:
            best_val_auc = v_auc
            wait = 0
        else:
            wait += 1
            if wait >= cfg.patience:
                break
    return best_val_auc


def grid_search_for_outer_fold(
    data: Data,
    outer_trainval_idx: torch.Tensor,
    outer_test_idx: torch.Tensor,
    y_pre: torch.Tensor,
    pretrain_mask: torch.Tensor,
    cfg: TrainConfig,
    base_seed: int,
):
    """Select hyper-parameters for one outer fold using a single inner hold-out split.

    The outer train/val nodes are split once (stratified 80/20, seed
    ``base_seed + 1234``) into inner-train and inner-val nodes. Outer test
    nodes are excluded from the auxiliary y2 loss and never used for y1
    supervision or selection.

    For each configuration (architecture × dropout × lr × weight decay) the
    search:
    1. runs cross-disease pretraining, seeded with ``base_seed + 1000``;
    2. fine-tunes with early stopping on inner-val AUPRC;
    3. scores the configuration by its best inner-val AUPRC.

    Efficiency, without changing results (up to GPU kernel non-determinism):
    - the split, masks and loss criteria are built once, not per configuration;
    - ``depth == 1`` keeps only ``hd_pair[0]``, so duplicate architectures
      are evaluated once;
    - pretraining does not depend on lr, so it runs once per (architecture,
      dropout, weight decay). For reuse, the RNG state observed right after
      that pretraining is restored, so fine-tuning sees the same random
      stream as with a fresh pretrain.

    Returns:
        Dict with keys ``depth``, ``hidden_dims``, ``dropout``, ``lr`` and
        ``weight_decay`` of the best configuration. Ties keep the earliest.

    Raises:
        ValueError: if the configured search space is empty.
    """
    y_tv = data.y[outer_trainval_idx].detach().cpu().numpy()
    sss = StratifiedShuffleSplit(
        n_splits=1,
        test_size=0.2,
        random_state=base_seed + 1234,
    )
    tr_sub, va_sub = next(sss.split(outer_trainval_idx.cpu(), y_tv))
    inner_train_nodes = outer_trainval_idx[tr_sub].to(DEVICE)
    inner_val_nodes = outer_trainval_idx[va_sub].to(DEVICE)

    N = data.num_nodes
    train_mask = torch.zeros(N, dtype=torch.bool, device=DEVICE)
    val_mask = torch.zeros(N, dtype=torch.bool, device=DEVICE)
    test_mask = torch.zeros(N, dtype=torch.bool, device=DEVICE)
    train_mask[inner_train_nodes] = True
    val_mask[inner_val_nodes] = True
    test_mask[outer_test_idx.to(DEVICE)] = True

    y2_mask_train = (data.y2 != -1) & (~test_mask)
    pw_y1 = bce_pos_weight_from_mask(data.y, train_mask, DEVICE)
    pw_y2 = bce_pos_weight_from_mask(data.y2, y2_mask_train, DEVICE)
    crit_y1 = nn.BCEWithLogitsLoss(pos_weight=pw_y1) if pw_y1 is not None else None
    crit_y2 = nn.BCEWithLogitsLoss(pos_weight=pw_y2) if pw_y2 is not None else None

    # Distinct (depth, hidden_dims) pairs in first-seen order.
    architectures = []
    for depth in cfg.depth_options:
        for hd_pair in cfg.hidden_options:
            hidden_dims = [hd_pair[0]] if depth == 1 else list(hd_pair)
            if (depth, hidden_dims) not in architectures:
                architectures.append((depth, hidden_dims))

    best_hp = None
    best_val_auc = -1.0
    for depth, hidden_dims in architectures:
        # (dropout, weight_decay) -> (pretrained_state, RNG state right after pretraining)
        pretrain_cache = {}
        for dr in cfg.dropout_options:
            for lr in cfg.lr_options:
                for wd in cfg.weight_decay_options:
                    print("\n[GRID-FOLD] depth={}, hidden={}, drop={}, lr={}, wd={}".format(
                        depth, hidden_dims, dr, lr, wd))

                    key = (dr, wd)
                    if key not in pretrain_cache:
                        set_seed(base_seed + 1000)
                        state = pretrain_on_cross_disease(
                            data=data,
                            y_pre=y_pre,
                            pretrain_mask=pretrain_mask,
                            hidden_dims=hidden_dims,
                            cfg=cfg,
                            dropout=dr,
                            weight_decay=wd,
                        )
                        pretrain_cache[key] = (state, _grid_rng_snapshot())
                    pretrained_state, rng_after_pretrain = pretrain_cache[key]
                    _grid_rng_restore(rng_after_pretrain)

                    val_auc = _grid_finetune_best_val_auprc(
                        data, hidden_dims, dr, lr, wd, pretrained_state,
                        train_mask, val_mask, y2_mask_train,
                        crit_y1, crit_y2, cfg,
                    )
                    print(f"[GRID-FOLD] Val AUPRC = {val_auc:.4f}")

                    if val_auc > best_val_auc:
                        best_val_auc = val_auc
                        best_hp = {
                            "depth": depth,
                            "hidden_dims": hidden_dims,
                            "dropout": dr,
                            "lr": lr,
                            "weight_decay": wd,
                        }

    if best_hp is None:
        raise ValueError("grid_search_for_outer_fold: the hyper-parameter search space is empty")
    print(f"[GRID-FOLD] Best HP for this outer fold: {best_hp}, Val AUPRC={best_val_auc:.4f}")
    return best_hp

# ----------------- FINAL TRAIN + TEST FOR 1 OUTER FOLD ----------------- #

def train_and_test_one_outer_fold(
    data: Data,
    outer_trainval_idx: torch.Tensor,
    outer_test_idx: torch.Tensor,
    y_pre: torch.Tensor,
    pretrain_mask: torch.Tensor,
    cfg: TrainConfig,
    hp: Dict[str, Any],
    base_seed: int,
):
    """Retrain with the selected hyper-parameters and report outer-test AUPRC.

    The procedure:
    1. Pretrain on cross-disease labels, seeded with ``base_seed + 2000``.
    2. Split the outer train/val nodes once (stratified 80/20, seed
       ``base_seed + 2345``).
    3. Fine-tune on the inner-train part, with early stopping on inner-val
       AUPRC.
    4. Score the outer test nodes with the best-validation weights.

    Args:
        hp: dict with ``hidden_dims``, ``dropout``, ``lr`` and
            ``weight_decay``, as returned by the grid search.

    Returns:
        Outer-test AUPRC as a float.

    Raises:
        ValueError: if ``hp`` is None (no configuration was selected).
    """
    if hp is None:
        raise ValueError("train_and_test_one_outer_fold: hp is None (no hyper-parameters selected)")

    hidden_dims = hp["hidden_dims"]
    dr = hp["dropout"]
    lr = hp["lr"]
    wd = hp["weight_decay"]

    set_seed(base_seed + 2000)

    pretrained_state = pretrain_on_cross_disease(
        data=data,
        y_pre=y_pre,
        pretrain_mask=pretrain_mask,
        hidden_dims=hidden_dims,
        cfg=cfg,
        dropout=dr,
        weight_decay=wd,
    )

    y_tv = data.y[outer_trainval_idx].detach().cpu().numpy()
    sss = StratifiedShuffleSplit(
        n_splits=1,
        test_size=0.2,
        random_state=base_seed + 2345,
    )
    tr_sub, va_sub = next(sss.split(outer_trainval_idx.cpu(), y_tv))
    inner_train_nodes = outer_trainval_idx[tr_sub].to(DEVICE)
    inner_val_nodes = outer_trainval_idx[va_sub].to(DEVICE)

    N = data.num_nodes
    train_mask = torch.zeros(N, dtype=torch.bool, device=DEVICE)
    val_mask = torch.zeros(N, dtype=torch.bool, device=DEVICE)
    test_mask = torch.zeros(N, dtype=torch.bool, device=DEVICE)
    train_mask[inner_train_nodes] = True
    val_mask[inner_val_nodes] = True
    test_mask[outer_test_idx.to(DEVICE)] = True

    # Outer test nodes never contribute to either loss.
    y2_mask_train = (data.y2 != -1) & (~test_mask)
    pw_y1 = bce_pos_weight_from_mask(data.y, train_mask, DEVICE)
    pw_y2 = bce_pos_weight_from_mask(data.y2, y2_mask_train, DEVICE)
    crit_y1 = nn.BCEWithLogitsLoss(pos_weight=pw_y1) if pw_y1 is not None else None
    crit_y2 = nn.BCEWithLogitsLoss(pos_weight=pw_y2) if pw_y2 is not None else None

    model = GCN_Residual_TwoHeads(
        in_dim=data.num_features,
        hidden_dims=hidden_dims,
        dropout=dr,
        use_layernorm=False,
        head_hidden=None,
    ).to(DEVICE)

    if pretrained_state is not None:
        model.load_state_dict(pretrained_state, strict=False)

    alpha_module = LearnableAlpha(init_alpha=0.7).to(DEVICE)
    optimizer = torch.optim.Adam(
        list(model.parameters()) + list(alpha_module.parameters()),
        lr=lr,
        weight_decay=wd,
    )

    # CPU copies so later optimizer steps cannot mutate the snapshot.
    best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    best_val_auc = -1.0
    wait = 0

    for ep in range(1, cfg.num_epochs + 1):
        train_one_epoch_dual(
            model, data,
            train_mask, y2_mask_train,
            optimizer, crit_y1, crit_y2,
            alpha_module, ep, warmup=cfg.warmup,
        )
        _, v_auc = evaluate_y1(model, data, val_mask, crit_y1)

        if v_auc > best_val_auc:
            best_val_auc = v_auc
            best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1
            if wait >= cfg.patience:
                break

    model.load_state_dict(best_state, strict=True)
    # Score explicitly in eval mode (dropout off) and without building a graph,
    # instead of relying on evaluate_y1 having left the model in eval mode.
    model.eval()
    with torch.no_grad():
        logit1, _ = model(data.x, data.edge_index)
    test_auprc = auprc_on_mask(logit1, data.y, test_mask)
    print(f"[OUTER-FOLD] Test AUPRC = {test_auprc:.4f}")
    return test_auprc


def nested_cv_one_run(
    data: Data,
    labeled_idx: torch.Tensor,
    y_pre: torch.Tensor,
    pretrain_mask: torch.Tensor,
    cfg: TrainConfig,
    base_seed: int,
):
    """Run one repetition of nested cross-validation over the labelled nodes.

    The labelled nodes are split by a stratified, shuffled ``cfg.kfold``-fold
    split seeded with ``base_seed``. For each outer fold, hyper-parameters
    are chosen by ``grid_search_for_outer_fold``, then the model is retrained
    and tested by ``train_and_test_one_outer_fold``.

    Returns:
        ``(fold_test_scores, mean, std)``, with std the population standard
        deviation (``np.std``).
    """
    outer_cv = StratifiedKFold(n_splits=cfg.kfold, shuffle=True, random_state=base_seed)
    labels_np = data.y[labeled_idx].detach().cpu().numpy()
    scores = []

    splits = outer_cv.split(labeled_idx.cpu(), labels_np)
    for fold, (tr_idx, te_idx) in enumerate(splits, start=1):
        print(f"\n========== RUN seed={base_seed} | OUTER FOLD {fold}/{cfg.kfold} ==========")
        trainval_nodes = labeled_idx[tr_idx]
        test_nodes = labeled_idx[te_idx]

        chosen_hp = grid_search_for_outer_fold(
            data, trainval_nodes, test_nodes,
            y_pre, pretrain_mask,
            cfg=cfg,
            base_seed=base_seed + fold * 10,
        )
        score = train_and_test_one_outer_fold(
            data, trainval_nodes, test_nodes,
            y_pre, pretrain_mask,
            cfg=cfg,
            hp=chosen_hp,
            base_seed=base_seed + fold * 20,
        )
        scores.append(score)

    mean_auc = float(np.mean(scores))
    std_auc = float(np.std(scores))
    print(f"\n[RUN seed={base_seed}] Mean Test AUPRC over {cfg.kfold} folds: {mean_auc:.4f} ± {std_auc:.4f}")
    return scores, mean_auc, std_auc

# ----------------- MAIN: NESTED CV OVER SEEDS (full protocol: 10 runs, seeds 42..51) ----------------- #

def main():
    """Run nested CV once per seed and save all per-run results as JSON.

    The graph, features, labels and cross-disease pretraining labels are
    built once and shared by every run. Results are written to
    ``<results_dir>/nested_cv_runs_<target_disease>.json``.
    """
    cfg = TrainConfig()
    os.makedirs(cfg.results_dir, exist_ok=True)

    data, _node_names, node_to_idx = load_graph_and_features(cfg)
    data, labeled_idx = load_labels(cfg, data, node_to_idx)

    y_pre, pretrain_mask = build_cross_disease_pretrain_labels(
        node_to_idx=node_to_idx,
        target_name=cfg.target_disease,
        data_y=data.y,
        base_dir=cfg.base_dir,
    )

    # The full protocol uses range(42, 52) (10 runs); only seed 42 is enabled here.
    # All messages below derive the run count from `seeds`, so they stay accurate.
    seeds = list(range(42, 43))
    n_runs = len(seeds)
    all_run_results = []

    for i, seed in enumerate(seeds, start=1):
        print("\n=========================================")
        print(f"======== NESTED CV RUN {i}/{n_runs} — seed = {seed} =========")
        print("=========================================")

        set_seed(seed)
        fold_scores, mean_auc, std_auc = nested_cv_one_run(
            data, labeled_idx,
            y_pre, pretrain_mask,
            cfg=cfg,
            base_seed=seed,
        )

        all_run_results.append({
            "seed": seed,
            "fold_test_auprc": fold_scores,
            "mean_test_auprc": mean_auc,
            "std_test_auprc": std_auc,
        })

    out_path = os.path.join(cfg.results_dir, f"nested_cv_runs_{cfg.target_disease}.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(all_run_results, f, indent=2)
    print(f"\n[INFO] Nested CV {n_runs}-run results saved to: {out_path}")

    all_means = [r["mean_test_auprc"] for r in all_run_results]
    print(f"\n========== {n_runs}-RUN SUMMARY (Nested CV) ==========")
    print(f"Mean of run-wise mean Test AUPRC: {np.mean(all_means):.4f} ± {np.std(all_means):.4f}")

if __name__ == "__main__":
    main()
