In [1]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import random
from pathlib import Path
from sklearn.metrics import roc_auc_score, average_precision_score

from models_gat import HeteroGAT


In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE
# Notebook-safe root
ROOT = Path.cwd().parent

DATA_DIR = ROOT / "data" / "data_cleaned"
GRAPH_PATH = ROOT / "outputs" / "data.pt"
OUT_DIR = ROOT / "outputs"
OUT_DIR.mkdir(exist_ok=True)

SEEDS = range(42, 47)   # 42, 43, 44, 45, 46
EPOCHS = 50
LR = 1e-3

print("ROOT:", ROOT)
print("DATA_DIR:", DATA_DIR)
print("GRAPH_PATH:", GRAPH_PATH)


ROOT: C:\Users\ayish\OneDrive\Documents\circRNA-disease-gnn
DATA_DIR: C:\Users\ayish\OneDrive\Documents\circRNA-disease-gnn\data\data_cleaned
GRAPH_PATH: C:\Users\ayish\OneDrive\Documents\circRNA-disease-gnn\outputs\data.pt


### Reproducibility

In [3]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


### Load Node Preserving Splits

In [4]:
def load_split(name, le_circ, le_dis):
    df = pd.read_csv(DATA_DIR / name)

    circ_ids = le_circ.transform(df["circRNA"].astype(str))
    dis_ids  = le_dis.transform(df["disease"].astype(str))

    edges = torch.from_numpy(
        np.vstack([circ_ids, dis_ids])
    ).long()

    labels = torch.tensor(df["label"].values, dtype=torch.float)

    return edges, labels


In [5]:
encoders = torch.load(
    OUT_DIR / "label_encoders.pt",
    weights_only=False
)

le_circ = encoders["circRNA"]
le_dis  = encoders["disease"]

train_edges, train_labels = load_split(
    "circRNA_disease_train.csv", le_circ, le_dis
)
val_edges, val_labels = load_split(
    "circRNA_disease_val.csv", le_circ, le_dis
)
test_edges, test_labels = load_split(
    "circRNA_disease_test.csv", le_circ, le_dis
)
train_edges.shape, train_labels.shape

(torch.Size([2, 929]), torch.Size([929]))

### Move Splits to Device

In [6]:
train_edges, train_labels = train_edges.to(DEVICE), train_labels.to(DEVICE)
val_edges, val_labels     = val_edges.to(DEVICE), val_labels.to(DEVICE)
test_edges, test_labels   = test_edges.to(DEVICE), test_labels.to(DEVICE)


In [7]:
print("Loading heterogeneous graph...")
data = torch.load(
    GRAPH_PATH,
    map_location=DEVICE,
    weights_only=False
)
data


Loading heterogeneous graph...


HeteroData(
  circRNA={ x=[828, 6] },
  miRNA={ x=[521, 6] },
  disease={ x=[122, 6] },
  (circRNA, interacts, miRNA)={ edge_index=[2, 896] },
  (miRNA, interacts, disease)={ edge_index=[2, 828] },
  (circRNA, associated, disease)={ edge_index=[2, 985] },
  (circRNA, gip_sim, circRNA)={
    edge_index=[2, 685584],
    edge_weight=[685584],
  },
  (miRNA, gip_sim, miRNA)={
    edge_index=[2, 271441],
    edge_weight=[271441],
  },
  (miRNA, rev_interacts, circRNA)={ edge_index=[2, 896] },
  (disease, rev_interacts, miRNA)={ edge_index=[2, 828] },
  (disease, rev_associated, circRNA)={ edge_index=[2, 985] }
)

In [8]:
print(data.edge_types)


[('circRNA', 'interacts', 'miRNA'), ('miRNA', 'interacts', 'disease'), ('circRNA', 'associated', 'disease'), ('circRNA', 'gip_sim', 'circRNA'), ('miRNA', 'gip_sim', 'miRNA'), ('miRNA', 'rev_interacts', 'circRNA'), ('disease', 'rev_interacts', 'miRNA'), ('disease', 'rev_associated', 'circRNA')]


### Multi Seed Training Loop

In [9]:
results = []

for seed in SEEDS:
    print("\n===================================")
    print(f"Running GAT experiment | SEED = {seed}")
    print("===================================")

    set_seed(seed)

    model = HeteroGAT(
        in_channels=data["circRNA"].x.size(1),
        hidden_channels=64,
        out_channels=64,
        heads=4,
        dropout=0.2
    ).to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    loss_fn = nn.BCEWithLogitsLoss()

    best_val_aupr = 0.0

    # -------- Training --------
    for epoch in range(1, EPOCHS + 1):
        model.train()
        optimizer.zero_grad()

        emb = model(data.x_dict, data.edge_index_dict)
        circ_emb, dis_emb = emb["circRNA"], emb["disease"]

        logits = (circ_emb[train_edges[0]] * dis_emb[train_edges[1]]).sum(dim=1)
        loss = loss_fn(logits, train_labels)

        loss.backward()
        optimizer.step()

        # -------- Validation --------
        model.eval()
        with torch.no_grad():
            emb = model(data.x_dict, data.edge_index_dict)
            circ_emb, dis_emb = emb["circRNA"], emb["disease"]

            val_logits = (circ_emb[val_edges[0]] * dis_emb[val_edges[1]]).sum(dim=1)
            val_scores = torch.sigmoid(val_logits).cpu().numpy()
            val_true   = val_labels.cpu().numpy()

            auc  = roc_auc_score(val_true, val_scores)
            aupr = average_precision_score(val_true, val_scores)

        print(
            f"Epoch {epoch:03d} | "
            f"Loss {loss.item():.4f} | "
            f"Val AUC {auc:.4f} | "
            f"Val AUPR {aupr:.4f}"
        )

        if aupr > best_val_aupr:
            best_val_aupr = aupr
            torch.save(
                model.state_dict(),
                OUT_DIR / f"gat_best_model_seed{seed}.pth"
            )
            print("   → Saved best model")

    # -------- Test --------
    model.load_state_dict(
        torch.load(OUT_DIR / f"gat_best_model_seed{seed}.pth", map_location=DEVICE)
    )
    model.eval()

    with torch.no_grad():
        emb = model(data.x_dict, data.edge_index_dict)
        circ_emb, dis_emb = emb["circRNA"], emb["disease"]

        test_logits = (circ_emb[test_edges[0]] * dis_emb[test_edges[1]]).sum(dim=1)
        test_scores = torch.sigmoid(test_logits).cpu().numpy()
        test_true   = test_labels.cpu().numpy()

        test_auc  = roc_auc_score(test_true, test_scores)
        test_aupr = average_precision_score(test_true, test_scores)

    print(f"SEED {seed} | Test AUC {test_auc:.4f} | Test AUPR {test_aupr:.4f}")
    results.append((seed, test_auc, test_aupr))



Running GAT experiment | SEED = 42
Epoch 001 | Loss 0.7528 | Val AUC 0.5167 | Val AUPR 0.1132
   → Saved best model
Epoch 002 | Loss 0.6458 | Val AUC 0.7669 | Val AUPR 0.2361
   → Saved best model
Epoch 003 | Loss 0.5601 | Val AUC 0.8367 | Val AUPR 0.3449
   → Saved best model
Epoch 004 | Loss 0.5045 | Val AUC 0.8724 | Val AUPR 0.4228
   → Saved best model
Epoch 005 | Loss 0.4643 | Val AUC 0.8941 | Val AUPR 0.4984
   → Saved best model
Epoch 006 | Loss 0.4429 | Val AUC 0.9099 | Val AUPR 0.6018
   → Saved best model
Epoch 007 | Loss 0.4290 | Val AUC 0.9181 | Val AUPR 0.6369
   → Saved best model
Epoch 008 | Loss 0.4200 | Val AUC 0.9212 | Val AUPR 0.6554
   → Saved best model
Epoch 009 | Loss 0.4165 | Val AUC 0.9224 | Val AUPR 0.6636
   → Saved best model
Epoch 010 | Loss 0.4118 | Val AUC 0.9207 | Val AUPR 0.6562
Epoch 011 | Loss 0.4103 | Val AUC 0.9185 | Val AUPR 0.6437
Epoch 012 | Loss 0.4067 | Val AUC 0.9159 | Val AUPR 0.6284
Epoch 013 | Loss 0.4044 | Val AUC 0.9140 | Val AUPR 0.6127

In [10]:
print("\n========== FINAL SUMMARY ==========")

for seed, auc, aupr in results:
    print(f"Seed {seed}: AUC={auc:.4f}, AUPR={aupr:.4f}")

aucs  = [r[1] for r in results]
auprs = [r[2] for r in results]

print("\nMean ± Std over seeds")
print(f"AUC  : {np.mean(aucs):.4f} ± {np.std(aucs):.4f}")
print(f"AUPR : {np.mean(auprs):.4f} ± {np.std(auprs):.4f}")



Seed 42: AUC=0.9243, AUPR=0.5617
Seed 43: AUC=0.9403, AUPR=0.6851
Seed 44: AUC=0.9227, AUPR=0.6061
Seed 45: AUC=0.9349, AUPR=0.5920
Seed 46: AUC=0.9223, AUPR=0.6335

Mean ± Std over seeds
AUC  : 0.9289 ± 0.0073
AUPR : 0.6157 ± 0.0417
