In [None]:
import os

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, precision_recall_curve, auc, roc_curve
)
import matplotlib.pyplot as plt

from utils import load_feature_table, make_pairs_and_labels
from embeddings import build_role2vec_embeddings, build_enhanced_X_from_embeddings
from model import EnhancedGNNWithAttention
from training import FocalLoss, train_enhanced, find_best_threshold

# =========================
# Paths (EDIT TO YOUR DATASET)
# =========================
# Every input file sits under a common root directory; edit the root or the
# individual paths below to point the script at a different dataset.
_data_root = "data"
circ_path = f"{_data_root}/circRNA_Extractedfeatures.csv"  # circRNA feature table
mir_path = f"{_data_root}/miRNA_Extractedfeatures.csv"     # miRNA feature table
cmi_path = f"{_data_root}/CMI-9589/9589pairs.csv"          # pair table read by pandas

# =========================
# Load data
# =========================
circ_names, circ_feats = load_feature_table(circ_path)
mir_names,  mir_feats  = load_feature_table(mir_path)

# Read the pair table exactly once with pandas. The earlier np.recfromcsv /
# torch.tensor([]) attempts were dead code (immediately overwritten), and
# np.recfromcsv no longer exists in NumPy >= 2.0, so it crashed the script.
cmi_df = pd.read_csv(cmi_path)

pos_pairs, neg_pairs, pair_circ_names, pair_mir_names, y = make_pairs_and_labels(
    circ_names, mir_names, cmi_df, circ_feats, mir_feats
)

pair_indices = np.arange(len(y))
# Derive the positive mask from the labels rather than assuming that
# make_pairs_and_labels places all positive pairs first. Label 1 is the
# positive class, consistent with the probs[:, 1] column used at evaluation
# and sklearn's default pos_label=1 in the metrics.
is_positive = np.asarray(y) == 1

# =========================
# Cross-Validation Setup
# =========================
# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 5 shuffled folds with a fixed seed so fold membership is reproducible.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

EPOCHS = 100

# Per-fold accumulators: scalar metrics plus the raw ROC / PR curve points.
fold_metrics = []
fpr_dict = {}
tpr_dict = {}
precision_dict = {}
recall_dict = {}

# =========================
# Run Cross-Validation
# =========================
for fold, (train_idx, test_idx) in enumerate(kf.split(pair_indices)):
    print(f"\n===== Fold {fold+1} =====")

    # Boolean per-pair masks for this fold (KFold assigns every index to
    # exactly one of train_idx / test_idx).
    train_mask_np = np.zeros(len(y), dtype=bool)
    train_mask_np[train_idx] = True
    test_mask_np = np.zeros(len(y), dtype=bool)
    test_mask_np[test_idx] = True

    # Only training positives are handed to the embedding builder, so test
    # interactions are not used to build this fold's embeddings/edges.
    train_pos_mask = is_positive & train_mask_np

    emb_dict, edge_index, emb_dim = build_role2vec_embeddings(
        train_pos_mask, pair_circ_names, pair_mir_names, emb_dim_fallback=16
    )

    X = build_enhanced_X_from_embeddings(
        circ_feats, mir_feats, circ_names, mir_names,
        emb_dict, emb_dim, pair_circ_names, pair_mir_names
    )

    # Standard scaling: statistics are fitted on the training rows only.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X[train_idx])
    X_test  = scaler.transform(X[test_idx])
    # Allocate with the scaler's float dtype; np.zeros_like(X) would silently
    # truncate the standardized values if X were an integer array.
    X_scaled = np.zeros(X.shape, dtype=X_train.dtype)
    X_scaled[train_idx] = X_train
    X_scaled[test_idx] = X_test

    x = torch.tensor(X_scaled, dtype=torch.float)
    y_tensor = torch.tensor(y, dtype=torch.long)

    data = Data(x=x, edge_index=edge_index, y=y_tensor)
    data.train_mask = torch.tensor(train_mask_np, dtype=torch.bool)
    data.test_mask  = torch.tensor(test_mask_np, dtype=torch.bool)
    data = data.to(device)

    # Train multiple seeds → ensemble
    models = []
    for seed in [42, 123, 456]:
        torch.manual_seed(seed); np.random.seed(seed)
        model = EnhancedGNNWithAttention(input_dim=X.shape[1], hidden_dim=512, dropout=0.3).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
        # steps_per_epoch=1 presumes train_enhanced steps the scheduler once
        # per call (full-batch training) — NOTE(review): confirm in training.py.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-3, steps_per_epoch=1, epochs=EPOCHS
        )
        criterion = FocalLoss(alpha=0.7, gamma=2.0)

        for _ in range(EPOCHS):
            train_enhanced(model, optimizer, data, criterion, scheduler)
        models.append(model)

    # Inference (Ensemble Averaging)
    all_probs = []
    for model in models:
        model.eval()
        with torch.no_grad():
            logits = model(data)
            # torch.exp implies the model emits log-probabilities (e.g.
            # log_softmax) — NOTE(review): confirm in model.py.
            probs = torch.exp(logits)
            all_probs.append(probs.cpu().numpy())
    avg_probs = np.mean(all_probs, axis=0)

    # Positive-class probability and labels for every pair.
    pos_probs = avg_probs[:, 1]
    y_all = data.y.cpu().numpy()

    y_probs = pos_probs[test_mask_np]
    y_true = y_all[test_mask_np]

    # Choose the decision threshold on the training pairs only. Tuning it on
    # the test fold's own labels leaks them into acc/prec/rec/F1 and makes
    # those numbers optimistic. (An inner validation split would be better
    # still, since training-set probabilities are themselves optimistic.)
    best_thr = find_best_threshold(y_all[train_mask_np], pos_probs[train_mask_np])
    y_pred = (y_probs > best_thr).astype(int)

    # Metrics
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    auroc = roc_auc_score(y_true, y_probs)

    # Compute each curve once; the PR curve feeds both AUPR and the plots.
    fpr, tpr, _ = roc_curve(y_true, y_probs)
    prec_curve, rec_curve, _ = precision_recall_curve(y_true, y_probs)
    aupr = auc(rec_curve, prec_curve)

    fpr_dict[fold] = fpr; tpr_dict[fold] = tpr
    precision_dict[fold] = prec_curve; recall_dict[fold] = rec_curve

    print(f"Fold {fold+1} -- Acc: {acc:.4f}, Prec: {prec:.4f}, Rec: {rec:.4f}, "
          f"F1: {f1:.4f}, AUROC: {auroc:.4f}, AUPR: {aupr:.4f}, Thr: {best_thr:.2f}")
    fold_metrics.append({'accuracy': acc, 'precision': prec, 'recall': rec,
                         'f1': f1, 'roc_auc': auroc, 'aupr': aupr})

# =========================
# Final Results
# =========================
# Mean and population standard deviation (np.std, ddof=0) of every metric
# across folds.
avg = {m: np.mean([f[m] for f in fold_metrics]) for m in fold_metrics[0]}
std = {m: np.std([f[m] for f in fold_metrics]) for m in fold_metrics[0]}
# Report the number of folds actually run instead of a hard-coded "5".
print(f"\nFINAL RESULTS ACROSS {len(fold_metrics)} FOLDS:")
for k in avg:
    print(f"{k}: {avg[k]:.4f} ± {std[k]:.4f}")

# =========================
# Save & Plot Curves
# =========================
save_dir = "Figures"
os.makedirs(save_dir, exist_ok=True)

# Iterate over the folds that were actually recorded instead of a hard-coded
# range(5), so the plots and titles stay correct if n_splits changes.
plotted_folds = sorted(fpr_dict)
n_folds = len(plotted_folds)

# ROC
plt.figure()
for fold in plotted_folds:
    plt.plot(fpr_dict[fold], tpr_dict[fold], lw=2,
             label=f'Fold {fold + 1} ROC (AUC = {auc(fpr_dict[fold], tpr_dict[fold]):.4f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance diagonal
plt.xlabel('False Positive Rate'); plt.ylabel('True Positive Rate')
plt.title(f'ROC Curves ({n_folds}-Fold)')
plt.legend(loc="lower right"); plt.tight_layout()
plt.savefig(f"{save_dir}/roc_curves.png", dpi=300); plt.close()

# PR (legend AUPR uses the same trapezoidal auc(recall, precision) as the
# per-fold AUPR metric)
plt.figure()
for fold in plotted_folds:
    plt.plot(recall_dict[fold], precision_dict[fold], lw=2,
             label=f'Fold {fold + 1} PR (AUPR = {auc(recall_dict[fold], precision_dict[fold]):.4f})')
plt.xlabel('Recall'); plt.ylabel('Precision')
plt.title(f'Precision-Recall Curves ({n_folds}-Fold)')
plt.legend(loc="lower left"); plt.tight_layout()
plt.savefig(f"{save_dir}/pr_curves.png", dpi=300); plt.close()

print(f"\nFigures saved in: {save_dir}/")
