# HGTDR + TransE Embeddings - 5-Fold CV Comparison

This notebook extends the original HGTDR implementation so that TransE
embeddings (trained separately and loaded from disk) are used as node features
alongside the BioBERT and ChemBERTa embeddings. It then follows the paper's
5-fold cross-validation structure so that results can be compared fairly.

## 1. Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HGTConv, Linear
from torch_geometric.loader import HGTLoader
from torch_geometric.data import HeteroData
import pandas as pd
import numpy as np
import pickle
import random
import copy
import os
import time
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, auc
from matplotlib import pyplot as plt

## 2. Configuration

In [None]:
# --- File Paths ---
# Relative paths; they resolve against the directory the notebook is run from.
PRIMKG_CSV_PATH = '../data/kg.csv'  # PrimeKG triplets, read with pd.read_csv
BIOBERT_EMBEDDINGS_PATH = '../data/entities_embeddings.pkl'  # read with pd.read_pickle
CHEMBERTA_EMBEDDINGS_PATH = '../data/smiles_embeddings.pkl'  # read with pd.read_pickle
TRANSE_EMBEDDINGS_PATH = '../data/entity_embeddings.npy' # Adjust as needed
PYKEEN_MAPPING_PATH = '../data/entity_to_id_mapping.pkl'  # pickled entity -> row-id mapping
CV_DATA_DIR = '../data/CV data/'  # holds train{fold}.pkl / val{fold}.pkl split files
OUT_DIR = '../output'  # logs and plots are written here

# --- Embedding Dimensions ---
# These are only the configured defaults: the loading cell overwrites each one
# with the dimension inferred from the loaded data when the two disagree.
BIOBERT_DIM = 768
CHEMBERTA_DIM = 768
TRANSE_DIM = 50 # From TransE training

# --- Embedding Selection ---
# Each flag is switched off later in the notebook if its data cannot be loaded.
USE_BIOBERT = True
USE_CHEMBERTA = True
USE_TRANSE = True

# --- HGTDR Config (Match hgtdrOG.py) ---
HGTDR_CONFIG = {
    "num_samples": 512,      # Per layer in HGTLoader
    "batch_size": 164,
    "dropout": 0.5,          # Dropout in HGT input projection and Predictor
    "epochs": 300,           
    "hidden_channels": [64, 64, 64, 64], # Input projection + 3 layers
    "out_channels": 64,      # GNN output embedding size
    "num_heads": [8, 8, 8],  # Heads per HGT layer
    "num_layers": 3,
    "learning_rate": 0.001,  # Default AdamW LR, adjust if needed
    "weight_decay": 0.01,    # Default AdamW WD, adjust if needed
    "predictor_dropout": 0.2 # Specific dropout for the MLP predictor
}
# NOTE(review): HGT_Combined asserts len(hidden_channels) == num_layers + 1 and
# len(num_heads) == num_layers, so these three entries must be changed together.

# --- Node Types and Relation ---
NODE_TYPE_DRUG = 'drug'
NODE_TYPE_DISEASE = 'disease'
INDICATION_RELATION = 'indication'
# Edge types are (source node type, relation, destination node type) triples.
INDICATION_EDGE_TYPE = (NODE_TYPE_DRUG, INDICATION_RELATION, NODE_TYPE_DISEASE)
REV_INDICATION_EDGE_TYPE = (NODE_TYPE_DISEASE, INDICATION_RELATION, NODE_TYPE_DRUG) # If exists

# --- Device Setup ---
# Number of CUDA devices visible to this process (0 on a CPU-only machine).
NUM_GPUS = torch.cuda.device_count()
print(f"Found {NUM_GPUS} GPUs.")
# Simple fold assignment (adjust if more GPUs)
def get_device_for_fold(fold_num):
    """Return the torch device on which fold ``fold_num`` should run.

    Device selection is currently overridden for debugging: every fold is
    placed on the CPU and ``fold_num`` is ignored.

    Args:
        fold_num: Index of the cross-validation fold (unused while the
            override is active).

    Returns:
        torch.device: always ``torch.device('cpu')``.
    """
    # Disabled GPU assignment policy, kept for reference:
    #   NUM_GPUS == 0 -> warn and use the CPU
    #   NUM_GPUS == 1 -> cuda:0
    #   otherwise     -> cuda:0 for folds <= 3, cuda:1 for the rest
    cpu_device = torch.device('cpu')
    print("Forcing CPU execution for debugging.")
    return cpu_device

# --- Output Directory ---
# exist_ok=True makes the call idempotent and removes the check-then-create
# race of the previous `os.path.exists` + `os.makedirs` pair (e.g. when several
# fold runs start at the same time).
os.makedirs(OUT_DIR, exist_ok=True)
print(f"Output directory: {OUT_DIR}")

## 3. Utility Functions

In [None]:
def write_to_out(text, filename="out.txt", fold_num=None):
    """Echo ``text`` to the console and append it as one line to a log file.

    Args:
        text: Value to log; it is converted with ``str`` before being written.
        filename: Base name of the log file inside ``OUT_DIR``.
        fold_num: When not ``None``, the file name is prefixed with
            ``Fold_<fold_num>_`` so each fold gets its own log.

    An ``IOError`` while writing is reported on the console and not re-raised.
    """
    if fold_num is None:
        log_name = filename
    else:
        log_name = f"Fold_{fold_num}_" + filename
    log_path = os.path.join(OUT_DIR, log_name)

    # Console output happens first, so the message is visible even if the
    # file write below fails.
    print(text)
    try:
        with open(log_path, 'a') as log_file:
            line = str(text) + '\n'
            log_file.write(line)
    except IOError as err:
        print(f"Error writing to {log_path}: {err}")

def plot_losses(losses, val_losses, filename="losses.png", fold_num=None):
    """Plot training and validation loss curves and save them as an image.

    Args:
        losses: Sequence of per-epoch training losses.
        val_losses: Sequence of per-epoch validation losses.
        filename: Base name of the image file inside ``OUT_DIR``.
        fold_num: When not ``None``, the file name is prefixed with
            ``Fold_<fold_num>_`` and the fold number appears in the title.

    Any exception raised while plotting or saving is reported on the console
    and not re-raised, so a plotting problem never aborts a training run.
    """
    has_fold = fold_num is not None
    prefix = f"Fold_{fold_num}_" if has_fold else ""
    filepath = os.path.join(OUT_DIR, prefix + filename)
    # Use the same `is not None` test as the file prefix so that a fold
    # numbered 0 still gets a labelled title.
    if has_fold:
        title = f'Fold {fold_num} - Training and Validation Loss'
    else:
        title = 'Training and Validation Loss'
    try:
        fig = plt.figure(figsize=(10, 5))
        try:
            plt.plot(losses, label='Training Loss')
            plt.plot(val_losses, label='Validation Loss')
            plt.title(title)
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.legend()
            plt.grid(True)
            plt.savefig(filepath, dpi=200)
        finally:
            # Always release the figure, even when plotting or saving fails;
            # otherwise figures accumulate in pyplot's global registry.
            plt.close(fig)
        print(f"Saved loss plot to {filepath}")
    except Exception as e:
        print(f"Error plotting/saving losses to {filepath}: {e}")


def calculate_metrics(scores, labels, fold_num):
    """Compute AUROC and AUPR for one fold and log both values.

    Args:
        scores: Tensor of predicted scores.
        labels: Tensor of ground-truth labels.
        fold_num: Fold identifier used in the log messages and log file name.

    Returns:
        tuple: ``(auroc, aupr)``. ``(0.0, 0.0)`` is returned when the labels
        contain non-finite values or fewer than two distinct classes.
    """
    y_score = scores.detach().cpu().numpy()
    y_true = labels.detach().cpu().numpy()

    # Sanitise non-finite scores; non-finite labels cannot be repaired.
    if not np.isfinite(y_score).all():
        print("Warning: Non-finite scores detected. Replacing with 0.")
        y_score = np.nan_to_num(y_score)
    if not np.isfinite(y_true).all():
        print("Error: Non-finite labels detected. Cannot calculate metrics.")
        return 0.0, 0.0

    # Both metrics are undefined unless two classes are present.
    if np.unique(y_true).size < 2:
        write_to_out(f"Warning: Only one class present in labels for Fold {fold_num}. Metrics are undefined.", fold_num=fold_num)
        return 0.0, 0.0

    auroc = roc_auc_score(y_true, y_score)
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    # AUPR is the area under the precision-recall curve (recall on the x axis).
    aupr = auc(recall, precision)

    write_to_out(f'Fold {fold_num} Final Val AUROC: {auroc:.4f}', fold_num=fold_num)
    write_to_out(f'Fold {fold_num} Final Val AUPR: {aupr:.4f}', fold_num=fold_num)

    # Optional: plot_metric_curves(y_score, y_true, fold_num) can be called
    # here to save the curves (one pair of files per fold).

    return auroc, aupr

# Optional function to plot curves if desired
def plot_metric_curves(scores_np, labels_np, fold_num):
    """Save the ROC and precision-recall curves of one fold as PNG files.

    Args:
        scores_np: NumPy array of predicted scores.
        labels_np: NumPy array of ground-truth labels.
        fold_num: Fold identifier used in titles and output file names.

    Any exception is logged through ``write_to_out`` and not re-raised.
    """
    try:
        # --- ROC curve, with a constant-score baseline for reference ---
        baseline_scores = [0] * len(labels_np)
        base_fpr, base_tpr, _ = roc_curve(labels_np, baseline_scores)
        model_fpr, model_tpr, _ = roc_curve(labels_np, scores_np)
        roc_path = os.path.join(OUT_DIR, f'Fold_{fold_num}_AUROC.png')
        plt.figure(figsize=(6, 6))
        plt.plot(base_fpr, base_tpr, linestyle='--', label='No Skill')
        plt.plot(model_fpr, model_tpr, marker='.', label='HGTDR+TransE')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'Fold {fold_num} AUROC')
        plt.legend()
        plt.savefig(roc_path, dpi=150)
        plt.close()

        # --- Precision-recall curve; the baseline is the positive rate ---
        precision, recall, _ = precision_recall_curve(labels_np, scores_np)
        positive_rate = len(labels_np[labels_np==1]) / len(labels_np)
        pr_path = os.path.join(OUT_DIR, f'Fold_{fold_num}_AUPR.png')
        plt.figure(figsize=(6, 6))
        plt.plot([0, 1], [positive_rate, positive_rate], linestyle='--', label='No Skill')
        plt.plot(recall, precision, marker='.', label='HGTDR+TransE')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title(f'Fold {fold_num} AUPR')
        plt.legend()
        plt.savefig(pr_path, dpi=150)
        plt.close()
    except Exception as e:
        write_to_out(f"Error plotting curves for Fold {fold_num}: {e}", fold_num=fold_num)

def compute_loss(scores, labels):
    """Return class-balanced binary cross-entropy (with logits) for a batch.

    The positive class is re-weighted by ``neg_num / pos_num`` computed from
    the labels of this batch, so positives and negatives contribute equally
    in total. If the batch contains only one class the plain unweighted mean
    loss is returned instead, because the ratio would be undefined.

    Args:
        scores: Tensor of raw logits.
        labels: Tensor of 0/1 labels with the same shape as ``scores``.

    Returns:
        torch.Tensor: scalar loss.
    """
    targets = labels.float()
    pos_num = (labels == 1).sum().item()
    neg_num = (labels == 0).sum().item()

    # One class missing: fall back to the unweighted loss.
    if pos_num == 0 or neg_num == 0:
        return F.binary_cross_entropy_with_logits(scores, targets, reduction='mean')

    # pos_num > 0 is guaranteed by the guard above, so the ratio is defined.
    pos_weight = torch.tensor([neg_num / pos_num], device=scores.device)
    return F.binary_cross_entropy_with_logits(scores, targets, pos_weight=pos_weight)

def edge_exists(edges, edge):
    """Tell whether ``edge`` appears as a column of ``edges``.

    Args:
        edges: Tensor of shape ``[2, N]`` holding N edges, or ``None``.
        edge: Tensor holding a single edge, shaped ``[2]`` or ``[2, 1]``.

    Returns:
        bool: ``True`` when some column of ``edges`` equals ``edge``;
        ``False`` when ``edges`` is ``None`` or empty.
    """
    # Nothing to search in.
    if edges is None:
        return False
    if edges.numel() == 0:
        return False

    query = edge.view(2, 1)  # column vector, broadcast against every column
    candidates = edges.to(query.device)
    column_matches = torch.all(candidates == query, dim=0)
    return torch.any(column_matches).item()

## 4. Data Loading (Raw KG and Embeddings)

In [None]:
# --- Load PrimeKG Raw Data ---
# The knowledge graph is mandatory: nothing downstream can run without it, so
# a missing file stops execution here.
try:
    df_primekg_raw = pd.read_csv(PRIMKG_CSV_PATH, sep=",")
    print(f"Loaded PrimeKG raw data: {df_primekg_raw.shape[0]} triplets")
except FileNotFoundError:
    print(f"Error: PrimeKG file not found at {PRIMKG_CSV_PATH}")
    # `exit()` is a convenience injected by the `site` module for the
    # interactive shell and would report success (status 0). Raise SystemExit
    # explicitly with a non-zero status so callers see the failure.
    raise SystemExit(1)

# --- Load Base Embeddings (BioBERT/ChemBERTa) ---
# Each embedding source is optional. A missing file only disables the matching
# USE_* flag; any other load error propagates. When the dimension inferred from
# the first row differs from the configured one, the configured constant is
# overwritten with the inferred value.
# NOTE(review): pd.read_pickle unpickles arbitrary objects - only load files
# from a trusted source.
biobert_embeddings_df = None
if USE_BIOBERT:
    try:
        biobert_embeddings_df = pd.read_pickle(BIOBERT_EMBEDDINGS_PATH)
        print(f"Loaded BioBERT embeddings: {biobert_embeddings_df.shape[0]} entities")
        # The dimension check is skipped when there is no 'embedding' column or no rows.
        if 'embedding' in biobert_embeddings_df.columns and len(biobert_embeddings_df) > 0:
            inferred_dim = len(biobert_embeddings_df['embedding'].iloc[0])
            if inferred_dim != BIOBERT_DIM:
                print(f"Warning: Inferred BioBERT dim ({inferred_dim}) != configured dim ({BIOBERT_DIM}). Using inferred.")
                BIOBERT_DIM = inferred_dim
    except FileNotFoundError:
        print(f"Warning: BioBERT embeddings file not found at {BIOBERT_EMBEDDINGS_PATH}. Skipping.")
        USE_BIOBERT = False

# Same procedure for the ChemBERTa embeddings.
chemberta_embeddings_df = None
if USE_CHEMBERTA:
    try:
        chemberta_embeddings_df = pd.read_pickle(CHEMBERTA_EMBEDDINGS_PATH)
        print(f"Loaded ChemBERTa embeddings: {chemberta_embeddings_df.shape[0]} drugs")
        if 'embedding' in chemberta_embeddings_df.columns and len(chemberta_embeddings_df) > 0:
            inferred_dim = len(chemberta_embeddings_df['embedding'].iloc[0])
            if inferred_dim != CHEMBERTA_DIM:
                print(f"Warning: Inferred ChemBERTa dim ({inferred_dim}) != configured dim ({CHEMBERTA_DIM}). Using inferred.")
                CHEMBERTA_DIM = inferred_dim
    except FileNotFoundError:
        print(f"Warning: ChemBERTa embeddings file not found at {CHEMBERTA_EMBEDDINGS_PATH}. Skipping.")
        USE_CHEMBERTA = False

# --- Load TransE Embeddings ---
# The array is indexed as [entity row, embedding dimension]: shape[1] is taken
# as the embedding size and overrides TRANSE_DIM when they differ.
transe_embeddings_npy = None
if USE_TRANSE:
    try:
        transe_embeddings_npy = np.load(TRANSE_EMBEDDINGS_PATH)
        print(f"Loaded TransE embeddings: shape {transe_embeddings_npy.shape}")
        inferred_dim = transe_embeddings_npy.shape[1]
        if inferred_dim != TRANSE_DIM:
            print(f"Warning: Loaded TransE dim ({inferred_dim}) != configured dim ({TRANSE_DIM}). Using loaded.")
            TRANSE_DIM = inferred_dim
    except FileNotFoundError:
        print(f"Warning: TransE embeddings file not found at {TRANSE_EMBEDDINGS_PATH}. Skipping.")
        USE_TRANSE = False

## 5. PyKEEN Mapping

In [None]:
# Mapping from entity label to row index in the TransE embedding matrix.
# It is first looked for on disk; if no file exists it is rebuilt from PyKEEN's
# PrimeKG dataset. Any failure disables TransE instead of stopping the notebook.
# NOTE(review): pickle.load executes arbitrary code from the file - only load
# mappings from a trusted source.
pykeen_entity_to_id = None
if USE_TRANSE:
    try:
        # Option 1: Load saved mapping (if you saved it during TransE training)
        if PYKEEN_MAPPING_PATH and os.path.exists(PYKEEN_MAPPING_PATH):
            with open(PYKEEN_MAPPING_PATH, 'rb') as f:
                pykeen_entity_to_id = pickle.load(f)
            print(f"Loaded PyKEEN entity mapping from {PYKEEN_MAPPING_PATH}")
        else:
            # Option 2: Recreate from PyKEEN dataset (Requires pykeen installed)
            try:
                # Imported lazily so PyKEEN is only required on this fallback path.
                from pykeen.datasets import PrimeKG
                print("Recreating PyKEEN mapping from PrimeKG dataset...")
                pykeen_dataset = PrimeKG()
                pykeen_entity_to_id = pykeen_dataset.training.entity_to_id
                print(f"Recreated PyKEEN entity mapping: {len(pykeen_entity_to_id)} entities")
                # Optional: Save the mapping for future use
                # with open('pykeen_primekg_entity_mapping.pkl', 'wb') as f:
                #     pickle.dump(pykeen_entity_to_id, f)
            except ImportError:
                print("Error: PyKEEN library not installed. Cannot recreate mapping.")
                print("Please install PyKEEN or provide a path to a saved mapping (PYKEEN_MAPPING_PATH).")
                USE_TRANSE = False
            except Exception as e:
                print(f"Error recreating PyKEEN mapping: {e}. Skipping TransE.")
                USE_TRANSE = False

    # Catches failures of the on-disk load (Option 1); the fallback path
    # handles its own errors above.
    except Exception as e:
        print(f"Error loading or recreating PyKEEN mapping: {e}")
        USE_TRANSE = False

# Final safety net: TransE cannot be used without a mapping.
if USE_TRANSE and pykeen_entity_to_id is None:
    print("Error: Could not load or create PyKEEN mapping. Disabling TransE usage.")
    USE_TRANSE = False

## 6. Preprocessing Function 

In [None]:
def preprocess_fold_data(fold_num, df_kg, biobert_df, chemberta_df, transe_npy, pykeen_map,
                        inference_mode=False): 
    """
    Build the heterogeneous graph for one cross-validation fold (or for inference).

    The full graph (`full_data`) is always built: the knowledge graph is
    filtered, nodes are indexed per type by display name, node features are
    assembled by concatenating the enabled embeddings (TransE, BioBERT,
    ChemBERTa for drugs), and edges are added. Unless `inference_mode` is set,
    the pickled train/validation splits of the fold are then loaded and used
    to derive `train_data` and `val_data` from copies of `full_data`.

    Args:
        fold_num: Fold identifier; selects `train{fold_num}.pkl` and
            `val{fold_num}.pkl` inside CV_DATA_DIR and appears in log output.
        df_kg: Knowledge-graph DataFrame; the columns read here are
            `relation`, `x_type`, `x_index`, `x_name`, `y_type`, `y_index`
            and `y_name`.
        biobert_df: DataFrame with `id` and `embedding` columns, or None.
        chemberta_df: DataFrame with `id` and `embedding` columns, or None.
        transe_npy: Array indexed by PyKEEN entity id, or None.
        pykeen_map: Mapping from entity display name to PyKEEN id, or None.
        inference_mode: When True, skip split loading and masking.

    Returns:
        A 5-tuple `(train_data, val_data, full_data, entity_dictionary,
        node_feature_dims)`. In inference mode the first two items are None.
        If the split files cannot be loaded or validated, all five items are
        None.

    Also reads the module-level USE_* flags and *_DIM constants.
    """
    print(f"\n--- Preprocessing Fold {fold_num} {'(Inference Mode)' if inference_mode else ''} ---")
    start_time = time.time()

    # --- Steps 1-5: Build the full_data object (Always Run) ---

    # 1. Filter DataFrame
    print("Filtering nodes based on indication involvement...")
    # Collect the drug and disease indices that take part in at least one
    # indication triplet (on either side of the triplet).
    indication_pairs = df_kg[df_kg['relation'] == INDICATION_RELATION]; drugs_in_indication = set(indication_pairs[indication_pairs['x_type'] == NODE_TYPE_DRUG]['x_index']) | set(indication_pairs[indication_pairs['y_type'] == NODE_TYPE_DRUG]['y_index']); diseases_in_indication = set(indication_pairs[indication_pairs['x_type'] == NODE_TYPE_DISEASE]['x_index']) | set(indication_pairs[indication_pairs['y_type'] == NODE_TYPE_DISEASE]['y_index'])
    # Keep a row only if each drug/disease endpoint belongs to those sets;
    # endpoints of any other node type are not restricted.
    df_filtered = df_kg[((df_kg['x_type'] != NODE_TYPE_DRUG) | df_kg['x_index'].isin(drugs_in_indication)) & ((df_kg['x_type'] != NODE_TYPE_DISEASE) | df_kg['x_index'].isin(diseases_in_indication)) & ((df_kg['y_type'] != NODE_TYPE_DRUG) | df_kg['y_index'].isin(drugs_in_indication)) & ((df_kg['y_type'] != NODE_TYPE_DISEASE) | df_kg['y_index'].isin(diseases_in_indication))].copy()
    print(f"Filtered DataFrame rows: {len(df_filtered)}")

    # 2. Create Global Entity/Edge Dictionaries (using display names)
    print("Creating global entity and edge dictionaries (using display names)...")
    # entity_dictionary: node type -> {display name -> per-type node id}
    # edge_dictionary:   (src type, relation, dst type) -> [(src id, dst id), ...]
    # node_type_map:     display name -> node type under which it was first seen
    entity_dictionary = {}; edge_dictionary = {}; node_type_map = {}
    def insert_entry_by_name(display_name, ent_type, entity_dict, type_map):
        # Assign the next free per-type id to a name seen for the first time.
        if ent_type not in entity_dict: entity_dict[ent_type] = {}
        if display_name not in entity_dict[ent_type]: entity_dict[ent_type][display_name] = len(entity_dict[ent_type]);
        if display_name not in type_map: type_map[display_name] = ent_type
        return entity_dict, type_map
    valid_triplets_named = []; processed_rows = 0; skipped_rows = 0
    # First pass: register every endpoint; rows with a missing name are skipped.
    for _, row in df_filtered.iterrows():
        src_name, src_type = row['x_name'], row['x_type']; dst_name, dst_type = row['y_name'], row['y_type']; relation = row['relation']
        if pd.isna(src_name) or pd.isna(dst_name): skipped_rows += 1; continue
        entity_dictionary, node_type_map = insert_entry_by_name(src_name, src_type, entity_dictionary, node_type_map); entity_dictionary, node_type_map = insert_entry_by_name(dst_name, dst_type, entity_dictionary, node_type_map)
        valid_triplets_named.append((src_name, relation, dst_name)); processed_rows += 1
    print(f"Processed {processed_rows} rows for dictionaries, skipped {skipped_rows} rows with NaN names.")
    # Second pass: translate named triplets into typed (src id, dst id) pairs.
    # NOTE(review): the node type is recovered from node_type_map, which keeps
    # only the first type seen for a name. If one display name occurs under two
    # node types, its edges are all attributed to the first type - confirm that
    # names are unique across types in the input data.
    for src_name, rel, dst_name in valid_triplets_named:
        src_type = node_type_map.get(src_name); dst_type = node_type_map.get(dst_name)
        if src_type is None or dst_type is None: continue
        src_hgt_id = entity_dictionary[src_type][src_name]; dst_hgt_id = entity_dictionary[dst_type][dst_name]
        etype = (src_type, rel, dst_type); pair = (src_hgt_id, dst_hgt_id)
        if etype not in edge_dictionary: edge_dictionary[etype] = []
        edge_dictionary[etype].append(pair)
    print("Entity and edge dictionaries created.")

    # 3. Create Base HeteroData Object (`full_data`)
    print("Creating base HeteroData object...")
    full_data = HeteroData(); node_feature_dims = {}
    # NOTE(review): combined_dim is computed but never read afterwards in this function.
    combined_dim = 0;
    if USE_BIOBERT: combined_dim += BIOBERT_DIM
    if USE_TRANSE: combined_dim += TRANSE_DIM
    for node_type, mapping in entity_dictionary.items():
        num_nodes = len(mapping); full_data[node_type].num_nodes = num_nodes; full_data[node_type].global_ids = torch.arange(num_nodes)
        # Feature width is decided by the USE_* flags only; ChemBERTa adds to drugs only.
        final_node_dim = 0
        if USE_BIOBERT: final_node_dim += BIOBERT_DIM
        if node_type == NODE_TYPE_DRUG and USE_CHEMBERTA: final_node_dim += CHEMBERTA_DIM
        if USE_TRANSE: final_node_dim += TRANSE_DIM
        node_feature_dims[node_type] = final_node_dim
        # Features start as small random noise; any slot for which no embedding
        # is found below keeps that noise.
        if final_node_dim > 0: full_data[node_type].x = torch.randn((num_nodes, final_node_dim), dtype=torch.float) * 0.01
        else: full_data[node_type].x = None

    # 4. Populate Embeddings in `full_data`
    # Column layout per node type: [TransE | BioBERT | ChemBERTa (drugs only)].
    # current_offset tracks the next free column for each node type.
    # NOTE(review): an offset only advances when the matching block below runs.
    # If a USE_* flag is True but its data argument is None, the width from
    # step 3 still includes that embedding and the trailing columns stay noise.
    print("Populating combined embeddings (TransE First)...")
    current_offset = {nt: 0 for nt in entity_dictionary.keys()}
    if USE_TRANSE and transe_npy is not None and pykeen_map is not None:
        print("  Adding TransE...") # TransE rows are looked up by display name
        missing_in_pykeen = 0; added_prints_per_type = {nt: 0 for nt in entity_dictionary.keys()}; max_added_prints = 3
        for node_type, mapping in entity_dictionary.items():
            if full_data[node_type].x is None: continue
            # `count` is incremented per copied embedding but never reported.
            count = 0; type_offset = current_offset[node_type]
            for display_name, hgt_id in mapping.items():
                entity_name_pykeen = display_name
                if entity_name_pykeen in pykeen_map:
                    pykeen_id = pykeen_map[entity_name_pykeen]
                    # Ids outside the array and wrong-sized vectors are skipped silently.
                    if 0 <= pykeen_id < transe_npy.shape[0]:
                        emb = torch.tensor(transe_npy[pykeen_id], dtype=torch.float)
                        if emb.shape[0] == TRANSE_DIM:
                            full_data[node_type].x[hgt_id, type_offset : type_offset + TRANSE_DIM] = emb; count += 1
                            # Print only the first few successes per node type as a sanity check.
                            if added_prints_per_type[node_type] < max_added_prints: print(f"    OK: Added TransE for {node_type}: '{display_name}' (PyKEEN ID: {pykeen_id})"); added_prints_per_type[node_type] += 1
                else: missing_in_pykeen += 1
            current_offset[node_type] += TRANSE_DIM
        if missing_in_pykeen > 0: print(f"    Warning: {missing_in_pykeen} entities (by name) not found in PyKEEN mapping for TransE.")
        else: print("    Success: All entities found in PyKEEN mapping for TransE.")
    # This message is printed even when BioBERT is disabled.
    print("  Adding BioBERT...")
    # BioBERT/ChemBERTa rows are matched on the string "<type>::<index>", so
    # build a display-name -> "<type>::<index>" map from the unfiltered graph
    # (source side first; the target side only fills names not yet present).
    name_to_type_index_map = {}; print("    Building name -> type::index map for BioBERT/ChemBERTa...")
    unique_names_df = df_kg[['x_name', 'x_type', 'x_index']].dropna(subset=['x_name']).drop_duplicates(subset=['x_name']); name_to_type_index_map.update({row['x_name']: f"{row['x_type']}::{row['x_index']}" for _, row in unique_names_df.iterrows()})
    unique_names_df = df_kg[['y_name', 'y_type', 'y_index']].dropna(subset=['y_name']).drop_duplicates(subset=['y_name']); name_to_type_index_map.update({row['y_name']: f"{row['y_type']}::{row['y_index']}" for _, row in unique_names_df.iterrows() if row['y_name'] not in name_to_type_index_map})
    print(f"    Built map with {len(name_to_type_index_map)} entries.")
    if USE_BIOBERT and biobert_df is not None: # BioBERT is added for every node type
        missing_biobert_lookup = 0
        for node_type, mapping in entity_dictionary.items():
            if full_data[node_type].x is None: continue
            count = 0; type_offset = current_offset[node_type]
            for display_name, hgt_id in mapping.items():
                type_index_key = name_to_type_index_map.get(display_name);
                if type_index_key is None: missing_biobert_lookup += 1; continue
                # NOTE(review): this boolean-mask filter scans the whole DataFrame
                # once per entity; a dict built once from the `id` column would
                # avoid the quadratic cost (keep first-match semantics).
                biobert_row = biobert_df[biobert_df['id'] == type_index_key]
                if not biobert_row.empty:
                    emb = torch.tensor(biobert_row['embedding'].iloc[0], dtype=torch.float)
                    if emb.shape[0] == BIOBERT_DIM: full_data[node_type].x[hgt_id, type_offset : type_offset + BIOBERT_DIM] = emb; count += 1
            current_offset[node_type] += BIOBERT_DIM
        if missing_biobert_lookup > 0: print(f"    Warning: Could not map back {missing_biobert_lookup} display names to type::index for BioBERT.")
    if USE_CHEMBERTA and chemberta_df is not None and NODE_TYPE_DRUG in entity_dictionary: # ChemBERTa is added for drug nodes only
        print("  Adding ChemBERTa...")
        missing_chemberta_lookup = 0; node_type = NODE_TYPE_DRUG; mapping = entity_dictionary[node_type]
        if full_data[node_type].x is not None:
            count = 0; type_offset = current_offset[node_type]
            for display_name, hgt_id in mapping.items():
                type_index_key = name_to_type_index_map.get(display_name);
                if type_index_key is None: missing_chemberta_lookup += 1; continue
                # Same per-entity DataFrame scan as in the BioBERT loop.
                chemberta_row = chemberta_df[chemberta_df['id'] == type_index_key]
                if not chemberta_row.empty:
                    emb = torch.tensor(chemberta_row['embedding'].iloc[0], dtype=torch.float)
                    if emb.shape[0] == CHEMBERTA_DIM: full_data[node_type].x[hgt_id, type_offset : type_offset + CHEMBERTA_DIM] = emb; count += 1
            current_offset[node_type] += CHEMBERTA_DIM
        if missing_chemberta_lookup > 0: print(f"    Warning: Could not map back {missing_chemberta_lookup} drug display names to type::index for ChemBERTa.")


    # 5. Add Edges to `full_data`
    print("Adding edges to base HeteroData...")
    # Each list of (src, dst) pairs becomes a [2, num_edges] long tensor.
    for etype, pairs_list in edge_dictionary.items():
        if pairs_list: edge_index = torch.tensor(pairs_list, dtype=torch.long).t().contiguous(); full_data[etype].edge_index = edge_index
    print("Base HeteroData object created and embeddings populated.")

    # --- Conditional Split Loading and Masking ---
    if not inference_mode:
        # --- Steps 6 & 7: Only run for actual CV folds ---
        # 6. Create Fold-Specific Data (Train/Val) using LOADED HeteroData splits
        print(f"Loading CV split data for Fold {fold_num}...")
        train_split_path = os.path.join(CV_DATA_DIR, f'train{fold_num}.pkl'); val_split_path = os.path.join(CV_DATA_DIR, f'val{fold_num}.pkl')
        # NOTE(review): pickle.load executes arbitrary code from the file -
        # only load split files from a trusted source.
        try:
            with open(train_split_path, 'rb') as f: loaded_train_data = pickle.load(f)
            with open(val_split_path, 'rb') as f: loaded_val_data = pickle.load(f)
            if not isinstance(loaded_train_data, HeteroData) or not isinstance(loaded_val_data, HeteroData): raise TypeError("Loaded CV split files are not HeteroData objects.")
            if INDICATION_EDGE_TYPE not in loaded_train_data.edge_types: raise KeyError(f"Indication edge type {INDICATION_EDGE_TYPE} not found in loaded train data.")
            # Validation labels may be stored under the forward or the reverse
            # indication edge type; the forward one is preferred.
            val_label_edge_type = None
            if INDICATION_EDGE_TYPE in loaded_val_data.edge_types and hasattr(loaded_val_data[INDICATION_EDGE_TYPE], 'edge_label_index'): val_label_edge_type = INDICATION_EDGE_TYPE
            elif REV_INDICATION_EDGE_TYPE in loaded_val_data.edge_types and hasattr(loaded_val_data[REV_INDICATION_EDGE_TYPE], 'edge_label_index'): val_label_edge_type = REV_INDICATION_EDGE_TYPE; print(f"Note: Validation labels found in reverse edge type {REV_INDICATION_EDGE_TYPE}. Will flip.")
            else: raise KeyError(f"Indication edge type {INDICATION_EDGE_TYPE} or its reverse has no 'edge_label_index' in loaded validation data.")
        # Any load or validation problem is reported and signalled with five Nones.
        except Exception as e: print(f"Error loading or validating CV split files for Fold {fold_num}: {e}"); return None, None, None, None, None
        # Both splits start as full copies of the graph (features and all edges).
        train_data = copy.deepcopy(full_data); val_data = copy.deepcopy(full_data)
        # Training graph: indication edges are replaced by the fold's training
        # edges; the reverse edge type, when present, gets the flipped copy.
        # NOTE(review): assumes the node ids in the split files share the id
        # space of entity_dictionary built above - confirm how the splits were made.
        train_indication_edges = loaded_train_data[INDICATION_EDGE_TYPE].edge_index; train_data[INDICATION_EDGE_TYPE].edge_index = train_indication_edges
        if REV_INDICATION_EDGE_TYPE in train_data.edge_types: train_data[REV_INDICATION_EDGE_TYPE].edge_index = train_indication_edges[[1, 0], :]
        # Validation graph: only the supervision edges (edge_label_index, all
        # labelled 1) are set; labels found under the reverse type are flipped
        # back to (drug, disease) order.
        # NOTE(review): val_data keeps the indication edge_index copied from
        # full_data, i.e. every indication edge stays available for message
        # passing - verify this does not leak validation edges downstream.
        val_indication_edges_to_label = loaded_val_data[val_label_edge_type].edge_label_index
        if val_label_edge_type == REV_INDICATION_EDGE_TYPE: val_data[INDICATION_EDGE_TYPE].edge_label_index = val_indication_edges_to_label[[1, 0], :]
        else: val_data[INDICATION_EDGE_TYPE].edge_label_index = val_indication_edges_to_label
        val_data[INDICATION_EDGE_TYPE].edge_label = torch.ones(val_indication_edges_to_label.shape[1])

        # 7. Apply Training Mask
        print("Applying training mask...")
        num_train_indic = train_data[INDICATION_EDGE_TYPE].edge_index.shape[1]
        if num_train_indic == 0: # no training edges: attach empty mask/label tensors
            print("Warning: No training indication edges found for this fold.")
            train_data[INDICATION_EDGE_TYPE]['mask'] = torch.empty(0, dtype=torch.bool); train_data[INDICATION_EDGE_TYPE]['edge_label_index'] = torch.empty((2, 0), dtype=torch.long); train_data[INDICATION_EDGE_TYPE]['edge_label'] = torch.empty(0, dtype=torch.float)
            if REV_INDICATION_EDGE_TYPE in train_data.edge_types: train_data[REV_INDICATION_EDGE_TYPE]['mask'] = torch.empty(0, dtype=torch.bool)
        else:
            # A random 80% of the training indication edges become supervision
            # edges (edge_label_index, labelled 1); the remaining 20% stay in
            # edge_index. `random` is not seeded inside this function, so the
            # selection depends on the caller's RNG state.
            mask_indices = random.sample(range(num_train_indic), int(num_train_indic * 0.8)); mask = torch.zeros(num_train_indic, dtype=torch.bool); mask[mask_indices] = True
            train_data[INDICATION_EDGE_TYPE]['mask'] = mask; train_data[INDICATION_EDGE_TYPE]['edge_label_index'] = train_data[INDICATION_EDGE_TYPE].edge_index[:, mask]; train_data[INDICATION_EDGE_TYPE]['edge_label'] = torch.ones(len(mask_indices))
            train_data[INDICATION_EDGE_TYPE].edge_index = train_data[INDICATION_EDGE_TYPE].edge_index[:, ~mask]
            # Remove the same edges from the reverse direction, which is only
            # possible when both directions hold the same number of edges.
            if REV_INDICATION_EDGE_TYPE in train_data.edge_types:
                if train_data[REV_INDICATION_EDGE_TYPE].edge_index.shape[1] == num_train_indic: train_data[REV_INDICATION_EDGE_TYPE]['mask'] = mask; train_data[REV_INDICATION_EDGE_TYPE].edge_index = train_data[REV_INDICATION_EDGE_TYPE].edge_index[:, ~mask]
                else: print(f"Warning: Size mismatch between forward/reverse indication edges during masking. Skipping reverse edge filtering.")

        print(f"Fold {fold_num} preprocessing finished in {time.time() - start_time:.2f} seconds.")
        # Return fold-specific data
        return train_data, val_data, full_data, entity_dictionary, node_feature_dims

    else: # --- inference_mode is True ---
        print(f"Fold {fold_num} preprocessing (Inference Mode) finished in {time.time() - start_time:.2f} seconds.")
        return None, None, full_data, entity_dictionary, node_feature_dims

## 7. Model Definitions (HGT + Predictor)

In [None]:
class HGT_Combined(nn.Module):
    """HGT encoder producing drug and disease embeddings.

    Each node type's input features are projected to a common width, passed
    through a stack of ``HGTConv`` layers, and the outputs of all layers are
    concatenated per node type before a shared final linear layer.

    Args:
        hidden_channels: Widths ``[projection, layer 1, ..., layer L]``; must
            have ``num_layers + 1`` entries.
        out_channels: Width of the returned embeddings.
        num_heads: Attention heads per HGT layer; must have ``num_layers``
            entries.
        num_layers: Number of ``HGTConv`` layers.
        dropout: Dropout probability applied after the input projection.
        node_feature_dims: Mapping node type -> input feature width. Types
            with width 0 get no input projection.
        metadata: Graph metadata forwarded to ``HGTConv``.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers, dropout, node_feature_dims, metadata):
        super().__init__()
        assert len(hidden_channels) == num_layers + 1
        assert len(num_heads) == num_layers

        # One input projection per node type that actually has features.
        self.lin_dict = nn.ModuleDict()
        for node_type, in_dim in node_feature_dims.items():
            if in_dim > 0:
                self.lin_dict[node_type] = Linear(in_dim, hidden_channels[0])
                continue
            print(f"Note: Node type '{node_type}' has 0 input dimension. Skipping input linear layer.")

        # Layer i maps hidden_channels[i] -> hidden_channels[i + 1].
        self.convs = nn.ModuleList([
            HGTConv(in_channels=hidden_channels[layer_idx],
                    out_channels=hidden_channels[layer_idx + 1],
                    metadata=metadata,
                    heads=num_heads[layer_idx])
            for layer_idx in range(num_layers)
        ])

        # The final layer consumes the concatenation of every layer's output.
        self.out_lin = Linear(sum(hidden_channels[1:]), out_channels)

        self.dropout_rate = dropout
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _concat_layers(node_type, layer_tensors):
        """Concatenate per-layer outputs of one node type along dim 1.

        Returns None when there is nothing to concatenate or when
        ``torch.cat`` raises a RuntimeError (the shapes are printed first).
        """
        if not layer_tensors:
            return None
        try:
            return torch.cat(layer_tensors, dim=1)
        except RuntimeError as e:
            print(f"Error concatenating layer outputs for {node_type}: {e}")
            for layer_no, tensor in enumerate(layer_tensors, start=1):
                print(f"  Layer {layer_no} shape: {tensor.shape}")
            return None

    def forward(self, x_dict, edge_index_dict):
        """Encode the graph and return ``(drug_embed, disease_embed)``.

        Args:
            x_dict: Mapping node type -> feature tensor (or None).
            edge_index_dict: Mapping edge type -> edge index tensor.

        Returns:
            tuple: embeddings of the drug and disease node types; an entry is
            None when no output could be produced for that type.
        """
        # Stage 1: project -> GELU -> dropout. Types without a projection are
        # passed through unchanged (and dropped entirely when they are None).
        hidden = {}
        for node_type, feats in x_dict.items():
            if node_type not in self.lin_dict:
                if feats is not None:
                    hidden[node_type] = feats
                continue
            hidden[node_type] = self.dropout(F.gelu(self.lin_dict[node_type](feats)))

        # Stage 2: run the HGT stack, keeping every layer's output for the
        # node types that entered the stack.
        per_layer = {node_type: [] for node_type in hidden}
        for conv in self.convs:
            hidden = conv(hidden, edge_index_dict)
            for node_type, layer_out in hidden.items():
                if node_type in per_layer:
                    per_layer[node_type].append(layer_out)

        # Stage 3: concatenate the layer outputs, then apply the shared output
        # layer followed by GELU.
        stacked = {
            node_type: self._concat_layers(node_type, outputs)
            for node_type, outputs in per_layer.items()
        }
        out_dict = {}
        for node_type, joined in stacked.items():
            out_dict[node_type] = None if joined is None else F.gelu(self.out_lin(joined))

        return out_dict.get(NODE_TYPE_DRUG), out_dict.get(NODE_TYPE_DISEASE)


class MLPPredictor(nn.Module):
    """Two-layer MLP that scores (drug, disease) embedding pairs.

    The two embeddings are concatenated and passed through
    Linear -> BatchNorm -> ReLU -> Dropout -> Linear, yielding one raw logit
    per pair.

    Args:
        channel_num: Width of each input embedding (and of the hidden layer).
        dropout: Dropout probability applied to the hidden activation.
    """

    def __init__(self, channel_num, dropout):
        super().__init__()
        pair_width = channel_num * 2  # drug and disease embeddings side by side
        self.L1 = nn.Linear(pair_width, channel_num)
        self.L2 = nn.Linear(channel_num, 1)
        self.bn = nn.BatchNorm1d(num_features=channel_num)
        self.dropout = nn.Dropout(dropout)

    def forward(self, drug_embeddings, disease_embeddings):
        """Return a ``[num_pairs, 1]`` tensor of logits for the given pairs."""
        pair = torch.cat([drug_embeddings, disease_embeddings], dim=1)
        hidden = F.relu(self.bn(self.L1(pair)))
        return self.L2(self.dropout(hidden))

## 8. Batching Functions

In [None]:
def make_batch(batch, device):
    """Prepares a training batch with positive/negative samples.

    Moves ``batch`` to ``device``, maps the indication supervision edges (which
    this function treats as global node ids) to batch-local indices through
    ``n_id``, and samples up to one random negative (drug, disease) pair per
    kept positive. Positives and negatives, with labels 1 and 0, are written
    back to ``edge_label_index`` / ``edge_label`` of the indication edge type.

    Args:
        batch: Heterogeneous mini-batch whose drug/disease node stores expose
            ``n_id`` and whose drug store exposes ``batch_size`` (presumably an
            HGTLoader output -- confirm against the loader setup).
        device: Device for the batch and for every tensor created here.

    Returns:
        The batch on ``device``. If no supervision pair can be built, the
        indication ``edge_label_index`` / ``edge_label`` are empty tensors,
        which the training loop uses as the signal to skip the batch.
    """
    batch = batch.to(device)

    indication_edge_type = (NODE_TYPE_DRUG, INDICATION_RELATION, NODE_TYPE_DISEASE)
    rev_indication_edge_type = (NODE_TYPE_DISEASE, INDICATION_RELATION, NODE_TYPE_DRUG)

    def _without_supervision():
        """Blank the supervision fields (and drop the mask) on the indication edges."""
        batch[indication_edge_type].edge_label_index = torch.empty((2, 0), dtype=torch.long, device=device)
        batch[indication_edge_type].edge_label = torch.empty(0, dtype=torch.float, device=device)
        if 'mask' in batch[indication_edge_type]:
            del batch[indication_edge_type].mask
        return batch

    if (indication_edge_type not in batch.edge_types
            or not hasattr(batch[indication_edge_type], 'edge_label_index')
            or not hasattr(batch[indication_edge_type], 'edge_label')):
        return _without_supervision()

    msg_edge_index = batch[indication_edge_type].edge_index
    supervision_edge_label_index = batch[indication_edge_type].edge_label_index

    if NODE_TYPE_DRUG not in batch.node_types or NODE_TYPE_DISEASE not in batch.node_types:
        return _without_supervision()

    batch_size = batch[NODE_TYPE_DRUG].batch_size

    # One bulk .tolist() per tensor instead of one .item() call per element.
    drug_global_ids = batch[NODE_TYPE_DRUG].n_id.tolist()
    disease_global_ids = batch[NODE_TYPE_DISEASE].n_id.tolist()
    global_to_local_drug = {gid: i for i, gid in enumerate(drug_global_ids)}
    global_to_local_disease = {gid: i for i, gid in enumerate(disease_global_ids)}

    # Keep only supervision edges whose endpoints are both present in the batch
    # and whose drug sits within the first `batch_size` local drug nodes.
    local_pos_src, local_pos_dst = [], []
    global_src_ids = supervision_edge_label_index[0].tolist()
    global_dst_ids = supervision_edge_label_index[1].tolist()
    for g_src, g_dst in zip(global_src_ids, global_dst_ids):
        local_src_idx = global_to_local_drug.get(g_src)
        local_dst_idx = global_to_local_disease.get(g_dst)
        if local_src_idx is None or local_dst_idx is None:
            continue
        if local_src_idx < batch_size:
            local_pos_src.append(local_src_idx)
            local_pos_dst.append(local_dst_idx)

    if not local_pos_src:
        return _without_supervision()

    pos_edge_label_index_local = torch.tensor([local_pos_src, local_pos_dst], dtype=torch.long, device=device)
    pos_num = pos_edge_label_index_local.shape[1]
    pos_edge_label = torch.ones(pos_num, device=device)

    # Rejection-sample negatives in local index space: a candidate is accepted
    # only if it is neither a message-passing edge nor a kept positive.
    neg_edges_source_local = []
    neg_edges_dest_local = []
    num_neg_needed = pos_num
    attempts = 0
    max_attempts = num_neg_needed * 20  # bounded so sampling always terminates
    num_disease_nodes_in_batch = len(disease_global_ids)

    while (len(neg_edges_source_local) < num_neg_needed
           and attempts < max_attempts
           and num_disease_nodes_in_batch > 0):
        attempts += 1
        source_local = random.randint(0, batch_size - 1)
        dest_local = random.randint(0, num_disease_nodes_in_batch - 1)
        neg_edge_local = torch.tensor([[source_local], [dest_local]], dtype=torch.long, device=device)

        is_in_msg_edges = edge_exists(msg_edge_index, neg_edge_local)
        is_in_pos_supervision = edge_exists(pos_edge_label_index_local, neg_edge_local)

        if not is_in_msg_edges and not is_in_pos_supervision:
            neg_edges_source_local.append(source_local)
            neg_edges_dest_local.append(dest_local)

    if len(neg_edges_source_local) < num_neg_needed and attempts >= max_attempts:
        print(f"Warning: Could only sample {len(neg_edges_source_local)}/{num_neg_needed} neg edges after {max_attempts} attempts.")

    neg_edge_label_index_local = torch.tensor([neg_edges_source_local, neg_edges_dest_local], dtype=torch.long, device=device)
    neg_edge_label = torch.zeros(len(neg_edges_source_local), device=device)

    # Positives first, then negatives; labels follow the same order.
    batch[indication_edge_type].edge_label_index = torch.cat(
        (pos_edge_label_index_local, neg_edge_label_index_local), dim=1)
    batch[indication_edge_type].edge_label = torch.cat((pos_edge_label, neg_edge_label), dim=0)

    if 'mask' in batch[indication_edge_type]:
        del batch[indication_edge_type].mask
    if rev_indication_edge_type in batch.edge_types and 'mask' in batch[rev_indication_edge_type]:
        del batch[rev_indication_edge_type].mask

    return batch

def make_test_batch(batch, full_data, device):
    """Prepares a validation/test batch with positive/negative samples.

    Moves ``batch`` to ``device``, maps the indication supervision edges (which
    this function treats as global node ids) to batch-local indices through
    ``n_id`` while keeping their original labels, and samples up to one random
    negative pair per kept edge. A candidate negative is rejected when its
    global (drug, disease) pair appears among the indication edges of
    ``full_data``.

    Args:
        batch: Heterogeneous mini-batch whose drug/disease node stores expose
            ``n_id`` and whose drug store exposes ``batch_size`` (presumably an
            HGTLoader output -- confirm against the loader setup).
        full_data: Graph whose indication ``edge_index`` lists every known
            indication edge in global ids; used only to reject negatives.
        device: Device for the batch and for every tensor created here.

    Returns:
        The batch on ``device``. If no supervision pair can be built, the
        indication ``edge_label_index`` / ``edge_label`` are empty tensors,
        which the evaluation loop uses as the signal to skip the batch.
    """
    batch = batch.to(device)
    indication_edge_type = (NODE_TYPE_DRUG, INDICATION_RELATION, NODE_TYPE_DISEASE)

    def _without_supervision():
        """Blank the supervision fields on the indication edges."""
        batch[indication_edge_type].edge_label_index = torch.empty((2, 0), dtype=torch.long, device=device)
        batch[indication_edge_type].edge_label = torch.empty(0, dtype=torch.float, device=device)
        return batch

    if (indication_edge_type not in batch.edge_types
            or not hasattr(batch[indication_edge_type], 'edge_label_index')
            or not hasattr(batch[indication_edge_type], 'edge_label')):
        return _without_supervision()

    if NODE_TYPE_DRUG not in batch.node_types or NODE_TYPE_DISEASE not in batch.node_types:
        return _without_supervision()

    batch_size = batch[NODE_TYPE_DRUG].batch_size
    pos_edge_label_index_global = batch[indication_edge_type].edge_label_index
    pos_edge_label_global = batch[indication_edge_type].edge_label

    # One bulk .tolist() per tensor instead of one .item() call per element.
    batch_global_drug_ids = batch[NODE_TYPE_DRUG].n_id.tolist()
    batch_global_disease_ids = batch[NODE_TYPE_DISEASE].n_id.tolist()
    global_to_local_drug = {gid: i for i, gid in enumerate(batch_global_drug_ids)}
    global_to_local_disease = {gid: i for i, gid in enumerate(batch_global_disease_ids)}

    # Keep only supervision edges whose endpoints are both present in the batch
    # and whose drug sits within the first `batch_size` local drug nodes.
    local_pos_src, local_pos_dst, kept_labels = [], [], []
    global_src_ids = pos_edge_label_index_global[0].tolist()
    global_dst_ids = pos_edge_label_index_global[1].tolist()
    label_values = pos_edge_label_global.tolist()
    for i, (g_src, g_dst) in enumerate(zip(global_src_ids, global_dst_ids)):
        local_src_idx = global_to_local_drug.get(g_src)
        local_dst_idx = global_to_local_disease.get(g_dst)
        if local_src_idx is None or local_dst_idx is None:
            continue
        if local_src_idx < batch_size:
            local_pos_src.append(local_src_idx)
            local_pos_dst.append(local_dst_idx)
            kept_labels.append(label_values[i])

    if not local_pos_src:
        return _without_supervision()

    pos_edge_label_index_local = torch.tensor([local_pos_src, local_pos_dst], dtype=torch.long, device=device)
    pos_edge_label_local = torch.tensor(kept_labels, dtype=torch.float, device=device)
    pos_num = pos_edge_label_index_local.shape[1]

    neg_edges_source_local = []
    neg_edges_dest_local = []
    num_neg_needed = pos_num
    attempts = 0
    max_attempts = num_neg_needed * 20  # bounded so sampling always terminates

    original_indication_edges = full_data[indication_edge_type].edge_index.to(device)
    num_disease_nodes_in_batch = len(batch_global_disease_ids)

    # Rejection-sample negatives: candidates are drawn in local index space but
    # checked in global id space against the indication edges of full_data.
    while (len(neg_edges_source_local) < num_neg_needed
           and attempts < max_attempts
           and num_disease_nodes_in_batch > 0):
        attempts += 1
        source_local = random.randint(0, batch_size - 1)
        dest_local = random.randint(0, num_disease_nodes_in_batch - 1)
        global_drug_id = batch_global_drug_ids[source_local]
        global_disease_id = batch_global_disease_ids[dest_local]
        neg_edge_global = torch.tensor([[global_drug_id], [global_disease_id]], dtype=torch.long, device=device)
        if not edge_exists(original_indication_edges, neg_edge_global):
            neg_edges_source_local.append(source_local)
            neg_edges_dest_local.append(dest_local)

    if len(neg_edges_source_local) < num_neg_needed and attempts >= max_attempts:
        print(f"Warning: Test/Val could only sample {len(neg_edges_source_local)}/{num_neg_needed} neg edges after {max_attempts} attempts.")

    neg_edge_label_index_local = torch.tensor([neg_edges_source_local, neg_edges_dest_local], dtype=torch.long, device=device)
    neg_edge_label_local = torch.zeros(len(neg_edges_source_local), device=device)

    # Kept supervision edges first, then negatives; labels follow the same order.
    batch[indication_edge_type].edge_label_index = torch.cat(
        (pos_edge_label_index_local, neg_edge_label_index_local), dim=1)
    batch[indication_edge_type].edge_label = torch.cat((pos_edge_label_local, neg_edge_label_local), dim=0)

    return batch

## 9. Training and Evaluation Functions

In [None]:
def train_epoch(gnn, predictor, loader, optimizer, device):
    """Runs a single training epoch.

    For every batch from ``loader``: build supervision pairs with
    ``make_batch``, embed drugs and diseases with ``gnn``, score the pairs
    with ``predictor``, and take one optimizer step on ``compute_loss``.
    Batches with no supervision pairs, missing/empty embeddings,
    out-of-range indices, or a NaN loss are skipped.

    Args:
        gnn: Model called as ``gnn(x_dict, edge_index_dict)`` and returning
            ``(drug_embed, disease_embed)``; either may be ``None``.
        predictor: Model called on row-aligned drug/disease embeddings.
        loader: Iterable of mini-batches.
        optimizer: Optimizer over the parameters of ``gnn`` and ``predictor``.
        device: Device each batch is moved to.

    Returns:
        The example-weighted mean loss over the batches that were stepped,
        or ``0.0`` if no batch contributed.
    """
    gnn.train()
    predictor.train()
    total_examples = total_loss = 0.0

    for batch in loader:
        optimizer.zero_grad()
        # Work on a deep copy so the object handed out by the loader is not mutated.
        processed_batch = make_batch(copy.deepcopy(batch), device)

        # Supervision pairs and labels exist only after make_batch has run.
        edge_label_index = processed_batch[INDICATION_EDGE_TYPE].edge_label_index
        edge_label = processed_batch[INDICATION_EDGE_TYPE].edge_label

        if edge_label.shape[0] == 0:  # no positive/negative pairs in this batch
            continue

        drug_embed, disease_embed = gnn(processed_batch.x_dict, processed_batch.edge_index_dict)

        if drug_embed is None or disease_embed is None:
            print("Warning: Skipping batch due to missing drug/disease embeddings.")
            continue
        if drug_embed.shape[0] == 0 or disease_embed.shape[0] == 0:
            print("Warning: Skipping batch due to empty drug/disease embeddings.")
            continue
        if edge_label_index.numel() == 0:  # nothing to compute a loss on
            continue

        # Guard the gather below against indices outside the embedding tables.
        max_drug_idx = edge_label_index[0].max().item()
        max_disease_idx = edge_label_index[1].max().item()
        if max_drug_idx >= drug_embed.shape[0] or max_disease_idx >= disease_embed.shape[0]:
            print(f"Warning: Index out of bounds. Max Drug Idx: {max_drug_idx} (Embed shape: {drug_embed.shape}), Max Disease Idx: {max_disease_idx} (Embed shape: {disease_embed.shape}). Skipping batch.")
            continue

        drug_nodes_for_pred = drug_embed[edge_label_index[0]]
        disease_nodes_for_pred = disease_embed[edge_label_index[1]]

        # Drop the trailing channel dimension so predictions line up with the labels.
        out = predictor(drug_nodes_for_pred, disease_nodes_for_pred).squeeze(-1)
        loss = compute_loss(out, edge_label)

        if torch.isnan(loss):
            print("Warning: NaN loss detected. Skipping backward pass for this batch.")
        elif loss.requires_grad:
            loss.backward()
            optimizer.step()
            # Weight each batch by its number of examples for the epoch average.
            current_batch_examples = edge_label_index.shape[1]
            total_examples += current_batch_examples
            total_loss += float(loss) * current_batch_examples
        else:
            print("Warning: Loss does not require grad. Skipping backward/step.")

    return total_loss / total_examples if total_examples > 0 else 0.0


@torch.no_grad()
def evaluate(gnn, predictor, loader, full_data, device):
    """Score every batch of a validation/test loader without tracking gradients.

    Each batch is prepared by ``make_test_batch`` (local supervision pairs plus
    sampled negatives), embedded by ``gnn`` and scored by ``predictor``.

    Returns:
        ``(loss, predictions, labels)``: the loss from ``compute_loss`` over all
        collected predictions, and the concatenated prediction / label tensors.
        If no batch yields predictions, returns ``(tensor(0.0), empty, empty)``.
    """
    gnn.eval()
    predictor.eval()
    score_chunks = []
    label_chunks = []

    for raw_batch in loader:
        # A deep copy keeps the loader's own batch object untouched.
        prepared = make_test_batch(copy.deepcopy(raw_batch), full_data, device)

        pair_index = prepared[INDICATION_EDGE_TYPE].edge_label_index
        pair_labels = prepared[INDICATION_EDGE_TYPE].edge_label

        # No evaluation pairs in this batch.
        if pair_labels.shape[0] == 0:
            continue

        drug_embed, disease_embed = gnn(prepared.x_dict, prepared.edge_index_dict)

        if drug_embed is None or disease_embed is None:
            continue
        if drug_embed.shape[0] == 0 or disease_embed.shape[0] == 0:
            continue
        if not pair_index.numel() > 0:
            continue

        drug_rows = pair_index[0]
        disease_rows = pair_index[1]
        max_drug_idx = drug_rows.max().item() if drug_rows.numel() > 0 else -1
        max_disease_idx = disease_rows.max().item() if disease_rows.numel() > 0 else -1

        # Skip batches whose pair indices fall outside the embedding tables.
        if max_drug_idx >= drug_embed.shape[0] or max_disease_idx >= disease_embed.shape[0]:
            print(f"Warning (Eval): Index out of bounds. Max Drug Idx: {max_drug_idx}, Max Disease Idx: {max_disease_idx}. Skipping batch.")
            continue

        scores = predictor(drug_embed[drug_rows], disease_embed[disease_rows]).squeeze(-1)
        score_chunks.append(scores)
        label_chunks.append(pair_labels)

    if not score_chunks:
        print("Warning: No predictions made during evaluation.")
        return torch.tensor(0.0), torch.tensor([]), torch.tensor([])  # loss, preds, labels

    all_preds_tensor = torch.cat(score_chunks)
    all_labels_tensor = torch.cat(label_chunks)

    return compute_loss(all_preds_tensor, all_labels_tensor), all_preds_tensor, all_labels_tensor

## 10. Main Cross-Validation Loop

In [None]:
all_fold_results = []  # One metrics dict per fold that reached final evaluation

print("\n=== Starting 5-Fold Cross-Validation ===")

for fold_num in range(1, 6):
    fold_start_time = time.time()
    device = get_device_for_fold(fold_num)
    write_to_out(f"\n--- Starting Fold {fold_num}/5 on device {device} ---", fold_num=fold_num)

    # 1. Preprocess data for the current fold (inference_mode is left at its default)
    train_data, val_data, full_data, entity_dictionary, node_feature_dims = preprocess_fold_data(
        fold_num,
        df_kg=df_primekg_raw,
        biobert_df=biobert_embeddings_df,
        chemberta_df=chemberta_embeddings_df,
        transe_npy=transe_embeddings_npy,
        pykeen_map=pykeen_entity_to_id
    )

    if train_data is None:  # Preprocessing signalled failure
        write_to_out(f"Error: Preprocessing failed for Fold {fold_num}. Skipping.", fold_num=fold_num)
        continue

    # 2. Define Model and Optimizer
    try:
        gnn = HGT_Combined(
            hidden_channels=HGTDR_CONFIG['hidden_channels'],
            out_channels=HGTDR_CONFIG['out_channels'],
            num_heads=HGTDR_CONFIG['num_heads'],
            num_layers=HGTDR_CONFIG['num_layers'],
            dropout=HGTDR_CONFIG['dropout'],
            node_feature_dims=node_feature_dims,
            metadata=train_data.metadata()
        ).to(device)

        predictor = MLPPredictor(
            channel_num=HGTDR_CONFIG['out_channels'],
            dropout=HGTDR_CONFIG['predictor_dropout']
        ).to(device)

        # A single optimizer trains the GNN and the predictor jointly.
        parameters = list(gnn.parameters()) + list(predictor.parameters())
        optimizer = torch.optim.AdamW(
            parameters,
            lr=HGTDR_CONFIG['learning_rate'],
            weight_decay=HGTDR_CONFIG['weight_decay']
        )
        # Cosine annealing from the initial learning rate down to 0 over all epochs.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=HGTDR_CONFIG['epochs'],
            eta_min=0
        )
    except Exception as e:
        write_to_out(f"Error initializing model/optimizer for Fold {fold_num}: {e}", fold_num=fold_num)
        del train_data, val_data, full_data, entity_dictionary, node_feature_dims
        if torch.cuda.is_available(): torch.cuda.empty_cache()
        continue  # Skip fold

    # 3. Define Data Loaders
    # torch.device(...) normalises the device (also when it is given as a string),
    # so the CPU check does not depend on how get_device_for_fold spells it.
    pin_memory = torch.device(device).type != 'cpu'
    loader_kwargs = {'batch_size': HGTDR_CONFIG['batch_size'], 'num_workers': 2, 'persistent_workers': False, 'pin_memory': pin_memory}

    try:
        # Seed batches from drug nodes when available, else from the first node type.
        # Done inside the try so an empty graph skips the fold instead of crashing.
        if NODE_TYPE_DRUG in train_data.node_types:
            primary_input_node = NODE_TYPE_DRUG
        else:
            available_node_types = list(train_data.node_types)
            if not available_node_types:
                raise ValueError(f"train_data has no node types for Fold {fold_num}.")
            primary_input_node = available_node_types[0]
        print(f"Using '{primary_input_node}' as primary input node for HGTLoader.")

        train_loader = HGTLoader(
            train_data,
            num_samples=[HGTDR_CONFIG['num_samples']] * HGTDR_CONFIG['num_layers'],
            shuffle=True,
            input_nodes=(primary_input_node, None),
            **loader_kwargs
        )
        val_loader = HGTLoader(
            val_data,
            num_samples=[HGTDR_CONFIG['num_samples']] * HGTDR_CONFIG['num_layers'],
            shuffle=False,  # No shuffle for validation
            input_nodes=(primary_input_node, None),
            **loader_kwargs
        )
    except Exception as e:
        write_to_out(f"Error creating loaders for Fold {fold_num}: {e}", fold_num=fold_num)
        del gnn, predictor, optimizer, scheduler
        del train_data, val_data, full_data, entity_dictionary, node_feature_dims
        if torch.cuda.is_available(): torch.cuda.empty_cache()
        continue  # Skip to next fold

    # 4. Training Loop with Checkpointing
    train_losses, val_losses = [], []
    # Best-so-far validation state for this fold, selected by AUPR.
    best_val_aupr_fold = -1.0
    best_gnn_state_dict_fold = None
    best_predictor_state_dict_fold = None
    best_epoch_fold = -1

    write_to_out(f"Starting training for Fold {fold_num} ({HGTDR_CONFIG['epochs']} epochs)...", fold_num=fold_num)
    for epoch in range(HGTDR_CONFIG['epochs']):
        epoch_start_time = time.time()
        loss = train_epoch(gnn, predictor, train_loader, optimizer, device)

        current_val_loss = float('nan')
        current_val_aupr = float('nan')
        current_val_auroc = float('nan')

        # Validate often early on and less often later; always on the last epoch.
        if epoch < 20:
            eval_freq = 5
        elif epoch < 100:
            eval_freq = 10
        else:
            eval_freq = 25
        should_evaluate = (epoch + 1) % eval_freq == 0 or epoch == HGTDR_CONFIG['epochs'] - 1

        if should_evaluate:
            val_loss_tensor, val_preds, val_labels = evaluate(gnn, predictor, val_loader, full_data, device)
            current_val_loss = val_loss_tensor.item()

            if val_labels.numel() > 0:
                current_val_auroc, current_val_aupr = calculate_metrics(val_preds, val_labels, fold_num)
            else:
                current_val_auroc, current_val_aupr = 0.0, 0.0
                write_to_out(f"Fold {fold_num} Epoch {epoch+1}: No validation samples processed.", fold_num=fold_num)

            # Both curves are recorded only on evaluated epochs so they stay aligned.
            val_losses.append(current_val_loss)
            train_losses.append(loss)

            epoch_duration = time.time() - epoch_start_time
            lr = optimizer.param_groups[0]['lr']
            write_to_out(f'F{fold_num} E: {epoch+1:03d}, TrL: {loss:.4f}, VaL: {current_val_loss:.4f}, VaAUPR: {current_val_aupr:.4f}, LR: {lr:.6f}, T: {epoch_duration:.2f}s', fold_num=fold_num)

            # Checkpoint in memory when validation AUPR improves; deep copies
            # detach the snapshot from the still-training modules.
            if current_val_aupr > best_val_aupr_fold:
                best_val_aupr_fold = current_val_aupr
                best_epoch_fold = epoch + 1
                best_gnn_state_dict_fold = copy.deepcopy(gnn.state_dict())
                best_predictor_state_dict_fold = copy.deepcopy(predictor.state_dict())
                print(f"    ** New best AUPR for Fold {fold_num}: {best_val_aupr_fold:.4f} at epoch {best_epoch_fold} **")

            # Save intermediate loss plot
            if train_losses and val_losses:
                plot_losses(train_losses, val_losses, filename=f"losses_epoch_{epoch+1}.png", fold_num=fold_num)

        elif (epoch + 1) % 10 == 0:
            # Lightweight progress line on non-evaluated epochs.
            epoch_duration = time.time() - epoch_start_time
            write_to_out(f'F{fold_num} E: {epoch+1:03d}, TrL: {loss:.4f}, T: {epoch_duration:.2f}s', fold_num=fold_num)

        # Step the scheduler after each epoch
        scheduler.step()

    # 5. Final Evaluation & Reporting for this Fold (uses the last-epoch weights)
    write_to_out(f"--- Evaluating Fold {fold_num} with FINAL model state ---", fold_num=fold_num)
    final_val_loss, final_val_preds, final_val_labels = evaluate(gnn, predictor, val_loader, full_data, device)

    if final_val_labels.numel() > 0:
        final_auroc, final_aupr = calculate_metrics(final_val_preds, final_val_labels, fold_num)
    else:
        write_to_out(f"Fold {fold_num}: No validation samples processed for final evaluation.", fold_num=fold_num)
        final_auroc, final_aupr = 0.0, 0.0

    # Record both the final-state metrics and the best AUPR seen during training.
    all_fold_results.append({
        'fold': fold_num,
        'final_auroc': final_auroc,
        'final_aupr': final_aupr,
        'final_loss': final_val_loss.item(),
        'best_val_aupr': best_val_aupr_fold,
        'best_epoch': best_epoch_fold
    })

    # 6. Save the BEST Model State for this Fold
    if best_gnn_state_dict_fold is not None and best_predictor_state_dict_fold is not None:
        gnn_save_path = os.path.join(OUT_DIR, f'Fold_{fold_num}_best_gnn.pth')
        predictor_save_path = os.path.join(OUT_DIR, f'Fold_{fold_num}_best_predictor.pth')
        try:
            torch.save(best_gnn_state_dict_fold, gnn_save_path)
            torch.save(best_predictor_state_dict_fold, predictor_save_path)
            write_to_out(f"Saved BEST model for Fold {fold_num} (Epoch: {best_epoch_fold}, AUPR: {best_val_aupr_fold:.4f}) to {OUT_DIR}", fold_num=fold_num)
        except Exception as e:
            write_to_out(f"Error saving best model for Fold {fold_num}: {e}", fold_num=fold_num)
    else:
        write_to_out(f"Warning: No best model state was saved for Fold {fold_num} (perhaps no improvement).", fold_num=fold_num)

    # 7. Plot final losses for this fold
    if train_losses and val_losses:
        plot_losses(train_losses, val_losses, filename="final_losses.png", fold_num=fold_num)

    fold_duration = time.time() - fold_start_time
    write_to_out(f"--- Fold {fold_num} Finished in {fold_duration:.2f} seconds ---", fold_num=fold_num)

    # 8. Clean up memory before the next fold
    del gnn, predictor, optimizer, scheduler, train_loader, val_loader
    del best_gnn_state_dict_fold, best_predictor_state_dict_fold
    del train_data, val_data, full_data, entity_dictionary, node_feature_dims
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n=== Cross-Validation Finished ===")

## 11. Aggregate and Report Results

In [None]:
# Aggregate the per-fold metric dicts into a table, print and persist a summary.
if all_fold_results:
    results_df = pd.DataFrame(all_fold_results)
    print("\n--- Aggregated 5-Fold CV Results ---")
    try:
        # Rounded for cleaner display
        print(results_df.round(4).to_markdown(index=False))
    except ImportError:
        # DataFrame.to_markdown needs the optional 'tabulate' package;
        # fall back to a plain-text table when it is not installed.
        print(results_df.round(4).to_string(index=False))

    # Metrics of the final (end-of-training) model state, across folds
    mean_auroc = results_df['final_auroc'].mean()
    std_auroc = results_df['final_auroc'].std()
    mean_aupr = results_df['final_aupr'].mean()
    std_aupr = results_df['final_aupr'].std()

    # Best validation AUPR reached during training, across folds
    mean_best_aupr = results_df['best_val_aupr'].mean()
    std_best_aupr = results_df['best_val_aupr'].std()

    summary = f"""
    -----------------------------------------
    Cross-Validation Summary (HGTDR + TransE):

    Final Model Performance (End of Training):
      Mean AUROC: {mean_auroc:.4f} +/- {std_auroc:.4f}
      Mean AUPR:  {mean_aupr:.4f} +/- {std_aupr:.4f}

    Best Validation Performance (During Training):
      Mean Best Val AUPR: {mean_best_aupr:.4f} +/- {std_best_aupr:.4f}
      (Based on checkpoints saved during training)
    -----------------------------------------
    """
    print(summary)
    # Write summary to a general output file
    write_to_out(summary, filename="summary_results.txt")

    # Save detailed per-fold results to CSV (rounded, like the printed table)
    results_save_path = os.path.join(OUT_DIR, "detailed_results.csv")
    try:
        results_df.round(4).to_csv(results_save_path, index=False)
        print(f"Saved detailed results to: {results_save_path}")
    except Exception as e:
        print(f"Error saving detailed results CSV: {e}")

else:
    print("No results generated. Please check for errors in the cross-validation loop.")

print("\n--- Script Finished ---")

## 11.5. Prepare Full Data for Candidate Generation

In [None]:
# Build the full graph and entity dictionary used later for candidate generation.
# On any failure both result variables are reset to None so later cells can check them.
print("\n--- Preparing Full Data for Inference ---")
try:
    # inference_mode=True asks preprocess_fold_data to skip split loading/masking;
    # fold_num=0 is only a placeholder in this mode.
    _, _, full_data_for_inference, entity_dictionary_for_inference, _ = preprocess_fold_data(
        fold_num=0,
        inference_mode=True,
        df_kg=df_primekg_raw,
        biobert_df=biobert_embeddings_df,
        chemberta_df=chemberta_embeddings_df,
        transe_npy=transe_embeddings_npy,
        pykeen_map=pykeen_entity_to_id,
    )

    if any(item is None for item in (full_data_for_inference, entity_dictionary_for_inference)):
        raise RuntimeError("Preprocessing failed to return necessary data in inference mode.")

    print("Full data prepared successfully for inference.")

except Exception as e:
    if isinstance(e, NameError):
        # A missing name means an earlier cell that defines the inputs was not run.
        print(f"Error: Required dataframes/mappings not found. Ensure previous cells were run. ({e})")
    else:
        print(f"An error occurred during final preprocessing: {e}")
    full_data_for_inference = None
    entity_dictionary_for_inference = None

## 12. Generate Drug Repurposing Candidates

In [None]:
import torch
import pandas as pd
import itertools
import os

print("\n--- Generating Drug Repurposing Candidates ---")

# --- Configuration ---
NUM_CANDIDATES_TO_SHOW = 50   # Rows printed in the final table
INFERENCE_BATCH_SIZE = 512    # Drug-disease pairs scored per predictor call
FOLD_TO_LOAD = 5  # <<< CHOOSE WHICH FOLD'S BEST MODEL TO USE (e.g., 5 or the overall best)

# --- Use the prepared data ---
# Both inputs must exist in this namespace AND be non-None; a name that was
# never defined is treated the same as one that was set to None.
_scope = locals()
_inputs_ready = (
    _scope.get('full_data_for_inference') is not None
    and _scope.get('entity_dictionary_for_inference') is not None
)
if not _inputs_ready:
    print("Error: Full data/entity dictionary not prepared. Please run the 'Prepare Full Data' cell first.")
else:
    print("Using pre-prepared full_data and entity_dictionary for inference.")
    full_data = full_data_for_inference
    entity_dictionary = entity_dictionary_for_inference

    # --- Instantiate Model Architecture (needed before loading state_dict) ---
    # Both networks are created on CPU; they are moved to the inference device
    # only after their weights have been loaded.
    try:
        if 'node_feature_dims' not in locals():
            # No feature-dimension map is in scope: derive one from the full
            # graph, recording 0 for node types without an `x` feature matrix.
            print("Recalculating node_feature_dims...")
            node_feature_dims = {}
            for node_type in full_data.node_types:
                node_features = getattr(full_data[node_type], 'x', None)
                node_feature_dims[node_type] = (
                    0 if node_features is None else node_features.shape[1]
                )

        temp_device = torch.device('cpu')

        gnn = HGT_Combined(
            hidden_channels=HGTDR_CONFIG['hidden_channels'],
            out_channels=HGTDR_CONFIG['out_channels'],
            num_heads=HGTDR_CONFIG['num_heads'],
            num_layers=HGTDR_CONFIG['num_layers'],
            dropout=HGTDR_CONFIG['dropout'],
            node_feature_dims=node_feature_dims,
            metadata=full_data.metadata(),
        ).to(temp_device)

        predictor = MLPPredictor(
            channel_num=HGTDR_CONFIG['out_channels'],
            dropout=HGTDR_CONFIG['predictor_dropout'],
        ).to(temp_device)
    except NameError as e:
        # A required name (config, model class, data) is not defined.
        print(f"Error: HGTDR_CONFIG or other necessary components not defined. {e}")
    except Exception as e:
        # Any other construction failure is reported; the cell keeps running.
        print(f"Error instantiating model architecture: {e}")


    # --- Load the BEST model state dict ---
    # Checkpoint files are looked up in OUT_DIR under the fold-specific names
    # built below.
    gnn_load_path = os.path.join(OUT_DIR, f'Fold_{FOLD_TO_LOAD}_best_gnn.pth')
    predictor_load_path = os.path.join(OUT_DIR, f'Fold_{FOLD_TO_LOAD}_best_predictor.pth')

    # Track success explicitly so a failed load cannot leave a freshly
    # constructed (randomly initialised) model available to later steps.
    model_loaded = False
    try:
        print(f"Loading best model state from Fold {FOLD_TO_LOAD}...")
        # map_location='cpu' lets checkpoints saved on a GPU load on any machine.
        gnn.load_state_dict(torch.load(gnn_load_path, map_location='cpu'))
        predictor.load_state_dict(torch.load(predictor_load_path, map_location='cpu'))
        print("Model state loaded successfully.")

        # --- Determine Inference Device and Move Model ---
        # Use a GPU if available, otherwise CPU
        model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        gnn.to(model_device)
        predictor.to(model_device)
        # eval() switches off dropout so scoring is deterministic.
        gnn.eval()
        predictor.eval()
        print(f"Using model on device: {model_device}")
        model_loaded = True

    except FileNotFoundError:
        print(f"Error: Model state dict files not found for Fold {FOLD_TO_LOAD}.")
        print(f"Looked for: {gnn_load_path} and {predictor_load_path}")
    except Exception as e:
        print(f"Error loading model state dict: {e}")

    if not model_loaded:
        # Without trained weights any ranking would be meaningless. Dropping the
        # references makes the embedding step fail fast (and report
        # "Could not generate predictions.") instead of silently scoring pairs
        # with an untrained model, e.g. when `model_device` is still defined
        # from an earlier run of this cell.
        gnn = None
        predictor = None
        print("Trained weights were not loaded; candidate scoring will be skipped.")

    # --- Get all Drug and Disease Nodes ---
    # Builds the drug/disease id lists, the reverse id->name lookups, and the
    # set of (drug_id, disease_id) pairs that already exist as indication edges.
    try:
        drug_name_to_hgt_id = entity_dictionary.get(NODE_TYPE_DRUG, {})
        disease_name_to_hgt_id = entity_dictionary.get(NODE_TYPE_DISEASE, {})
        if not (drug_name_to_hgt_id and disease_name_to_hgt_id):
            raise ValueError("Entity dictionary missing types.")

        all_drug_hgt_ids = list(drug_name_to_hgt_id.values())
        all_disease_hgt_ids = list(disease_name_to_hgt_id.values())

        # Reverse lookups used later to print human-readable names.
        hgt_id_to_drug_name = {
            hgt_id: name for name, hgt_id in drug_name_to_hgt_id.items()
        }
        hgt_id_to_disease_name = {
            hgt_id: name for name, hgt_id in disease_name_to_hgt_id.items()
        }

        num_drugs = len(all_drug_hgt_ids)
        num_diseases = len(all_disease_hgt_ids)
        print(f"Found {num_drugs} drugs and {num_diseases} diseases.")

        known_indications_hgt_ids = set()
        if INDICATION_EDGE_TYPE in full_data.edge_types:
            # edge_index is transposed so each row is one [source, destination] pair.
            known_edges = full_data[INDICATION_EDGE_TYPE].edge_index.t().tolist()
            known_indications_hgt_ids.update(
                (src_id, dst_id) for src_id, dst_id in known_edges
            )
        print(f"Found {len(known_indications_hgt_ids)} known indications.")
    except Exception as e:
        print(f"Error getting node lists or known indications: {e}")
        predicted_scores, predicted_pairs_hgt_ids = None, None

    # --- Perform Inference ---
    @torch.no_grad()
    def predict_scores(drug_ids, disease_ids):
        """Score drug-disease pairs with the loaded GNN and predictor.

        The full graph is moved to ``model_device`` and passed through ``gnn``
        once; the two values it returns are used as per-node drug and disease
        embedding tables. Pairs are then scored by ``predictor`` in chunks of
        ``INFERENCE_BATCH_SIZE``.

        Args:
            drug_ids: Sequence of drug node indices, one per pair.
            disease_ids: Sequence of disease node indices, aligned with
                ``drug_ids``.

        Returns:
            ``(scores, pairs)`` where ``scores`` is a CPU tensor of
            concatenated batch scores and ``pairs`` is the list of
            ``(drug_id, disease_id)`` tuples that were scored, in the same
            order. Returns ``(None, None)`` if embeddings could not be
            produced or no batch was scored.
        """
        try:
            print("Generating node embeddings for full graph...")
            graph_on_device = full_data.to(model_device)
            drug_embed_all, disease_embed_all = gnn(
                graph_on_device.x_dict, graph_on_device.edge_index_dict
            )
            print("Node embeddings generated.")
        except Exception as e:
            print(f"Error generating full graph embeddings: {e}.")
            return None, None

        if drug_embed_all is None or disease_embed_all is None:
            print("Error: Failed to get embeddings.")
            return None, None

        drug_ids_tensor = torch.tensor(drug_ids, dtype=torch.long, device=model_device)
        disease_ids_tensor = torch.tensor(disease_ids, dtype=torch.long, device=model_device)
        num_pairs = len(drug_ids)
        drug_table_rows = drug_embed_all.shape[0]
        disease_table_rows = disease_embed_all.shape[0]

        score_chunks = []
        pairs = []
        for batch_no, start in enumerate(range(0, num_pairs, INFERENCE_BATCH_SIZE)):
            stop = start + INFERENCE_BATCH_SIZE
            batch_drug_ids = drug_ids_tensor[start:stop]
            batch_disease_ids = disease_ids_tensor[start:stop]

            # A batch containing an index beyond either embedding table is
            # skipped as a whole; its pairs are not recorded, so `pairs`
            # stays aligned with the concatenated scores.
            if (batch_drug_ids.max() >= drug_table_rows
                    or batch_disease_ids.max() >= disease_table_rows):
                print(f"Warning: Skipping inference batch {batch_no} due to index out of bounds.")
                continue

            batch_scores = predictor(
                drug_embed_all[batch_drug_ids], disease_embed_all[batch_disease_ids]
            ).squeeze(-1)
            score_chunks.append(batch_scores.cpu())
            pairs.extend(zip(drug_ids[start:stop], disease_ids[start:stop]))

            # Progress line every 100 batches.
            if batch_no % 100 == 0:
                print(f"  Processed {start + len(batch_drug_ids)} / {num_pairs} pairs...")

        if not score_chunks:
            return None, None
        return torch.cat(score_chunks), pairs

    # --- Generate Candidate Pairs and Predict ---
    # Candidates are all drug x disease combinations that are not already a
    # known indication, kept as two parallel id lists.
    if 'all_drug_hgt_ids' not in locals():
        # The node lists were never built, so there is nothing to score.
        predicted_scores, predicted_pairs_hgt_ids = None, None
    else:
        candidate_drug_ids = []
        candidate_disease_ids = []
        print("Generating candidate drug-disease pairs (excluding known indications)...")
        count = 0
        for count, drug_hgt_id in enumerate(all_drug_hgt_ids, start=1):
            unlinked_diseases = [
                candidate_disease
                for candidate_disease in all_disease_hgt_ids
                if (drug_hgt_id, candidate_disease) not in known_indications_hgt_ids
            ]
            candidate_drug_ids.extend([drug_hgt_id] * len(unlinked_diseases))
            candidate_disease_ids.extend(unlinked_diseases)
            if count % 100 == 0:
                print(f"  Generated pairs for {count}/{num_drugs} drugs...")
        print(f"Generated {len(candidate_drug_ids)} candidate pairs.")

        if candidate_drug_ids:
            predicted_scores, predicted_pairs_hgt_ids = predict_scores(
                candidate_drug_ids, candidate_disease_ids
            )
        else:
            predicted_scores, predicted_pairs_hgt_ids = None, None

    # --- Rank and Display Candidates ---
    # Sorts all scored pairs by descending score, keeps the top
    # 2 * NUM_CANDIDATES_TO_SHOW rows, prints the first NUM_CANDIDATES_TO_SHOW
    # and writes every kept row to OUT_DIR/repurposing_candidates.csv.
    # NOTE: `results_df` and `results_save_path` are also used by the
    # cross-validation summary cell earlier in this notebook; they are rebound here.
    if predicted_scores is not None and predicted_pairs_hgt_ids is not None:
        print("\nRanking candidates by predicted score...")
        scores_np = predicted_scores.numpy()
        sorted_indices = np.argsort(scores_np)[::-1]  # highest score first

        num_rows_to_keep = min(len(sorted_indices), NUM_CANDIDATES_TO_SHOW * 2)
        results = []
        for rank, idx in enumerate(sorted_indices[:num_rows_to_keep], start=1):
            drug_hgt_id, disease_hgt_id = predicted_pairs_hgt_ids[idx]
            results.append({
                "Rank": rank,
                # Fall back to the raw id when no name is known for it.
                "Drug Name": hgt_id_to_drug_name.get(drug_hgt_id, f"ID:{drug_hgt_id}"),
                "Disease Name": hgt_id_to_disease_name.get(disease_hgt_id, f"ID:{disease_hgt_id}"),
                "Predicted Score (Logit)": scores_np[idx],
                "Drug HGT ID": drug_hgt_id,
                "Disease HGT ID": disease_hgt_id,
            })
        results_df = pd.DataFrame(results)

        print(f"\n--- Top {NUM_CANDIDATES_TO_SHOW} Drug Repurposing Candidates ---")
        top_rows = results_df.head(NUM_CANDIDATES_TO_SHOW)
        try:
            print(top_rows.to_markdown(index=False))
        except ImportError:
            # to_markdown needs the optional `tabulate` package. Fall back to a
            # plain-text table so a missing display dependency does not abort
            # the cell before the CSV below is written.
            print(top_rows.to_string(index=False))

        results_save_path = os.path.join(OUT_DIR, "repurposing_candidates.csv")
        try:
            results_df.to_csv(results_save_path, index=False, float_format='%.4f')
            print(f"\nSaved top candidates to: {results_save_path}")
        except Exception as e:
            print(f"\nError saving results to CSV: {e}")
    else:
        print("\nCould not generate predictions.")

print("\n--- Candidate Generation Finished ---")