In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# PyG imports
from torch_geometric.nn import GCNConv, GATv2Conv, PNAConv, GINEConv
from torch_geometric.nn import global_mean_pool, global_max_pool
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import degree

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score, balanced_accuracy_score
from sklearn.model_selection import StratifiedKFold
# from sklearn.preprocessing import RobustScaler # Not used directly in this version
import matplotlib.pyplot as plt # Keep if plotting results later
import random
import os
import traceback # For detailed error reporting

# Reproducibility and device setup: every RNG is seeded exactly once for this run.
SEED = 42 # Change here to run with a different seed

# Resolve the compute device first so the CUDA branch below can key off it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if device.type != 'cuda':
    print("CUDA not available, running on CPU.")
else:
    torch.cuda.manual_seed_all(SEED)
    # Optional: For stricter reproducibility (can slow down training)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

print(f"Using device: {device}")
print(f"Using Seed: {SEED}")

Using device: cuda
Using Seed: 42


In [None]:
# Residue alphabet and integer encoding
VALID_AA = 'ARNDCQEGHILKMFPSTWYV' # The 20 standard residues (one-hot columns)
AMINO_ACIDS = VALID_AA + '-' # Same alphabet with the padding character appended last
AA_TO_INT = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS))))

# Width of every edge feature vector:
# 4 distance bins + 1 inverse distance + 2 sequential + 2 K-relative + 6 SS pairs + 2 SASA
EXPECTED_EDGE_FEATURE_DIM = 17

# Function to create detailed Edge Features (Needed for GATv2 w/ EF, GINE)
def create_edge_features(i_orig, j_orig, dist_map, sequence, ss_string, sasa_vals, K_POS=16):
    """
    Build the fixed-size feature vector for the edge between two window positions.

    Layout of the returned vector (EXPECTED_EDGE_FEATURE_DIM = 17 entries):
        [0:4]   one-hot distance bin: (0, 4], (4, 8], (8, 12], > 12
                (all zero when the distance is <= 0 or NaN)
        [4]     inverse distance (0 when the distance is ~0 or NaN)
        [5]     1.0 if the two positions are sequence neighbours (|i - j| == 1)
        [6]     |i - j| divided by the largest possible separation in the window
        [7]     1.0 if either endpoint is K_POS
        [8]     smaller of the two endpoint-to-K_POS separations, normalised
        [9:15]  one-hot unordered secondary-structure pair, in the order
                H-H, H-E, H-L, E-E, E-L, L-L
        [15]    |SASA_i - SASA_j|
        [16]    mean SASA of the two residues

    Args:
        i_orig, j_orig: Positions in the original (unfiltered) sequence window.
        dist_map: 2-D numpy array of pairwise distances indexed by window position.
        sequence: Window sequence string; '-' marks a padded position.
        ss_string: Secondary-structure string aligned with `sequence`
            ('H', 'E', 'L'; anything else is treated as 'L').
        sasa_vals: Per-position SASA values aligned with `sequence`.
        K_POS: Window index of the central lysine.

    Returns:
        np.ndarray of shape (17,), dtype float32. All zeros when an index is out
        of range for `sequence`/`dist_map` or either endpoint is padding.
    """
    edge_features = np.zeros(EXPECTED_EDGE_FEATURE_DIM, dtype=np.float32)

    # Guard clauses: invalid or padded endpoints yield the all-zero vector.
    seq_len = len(sequence)
    if not (0 <= i_orig < seq_len and 0 <= j_orig < seq_len):
        return edge_features
    if sequence[i_orig] == '-' or sequence[j_orig] == '-':
        return edge_features
    if i_orig >= dist_map.shape[0] or j_orig >= dist_map.shape[1]:
        return edge_features
    distance = dist_map[i_orig, j_orig]

    feature_idx = 0

    # 1. Distance features (4 bins + 1 continuous).
    # A NaN distance is "unknown": without the explicit check it would fall
    # through every comparison and be labelled as the "> 12" bin.
    if np.isnan(distance) or distance <= 0:
        pass # Leave all four bins at 0
    elif distance <= 4.0:
        edge_features[feature_idx] = 1.0
    elif distance <= 8.0:
        edge_features[feature_idx + 1] = 1.0
    elif distance <= 12.0:
        edge_features[feature_idx + 2] = 1.0
    else:
        edge_features[feature_idx + 3] = 1.0
    feature_idx += 4

    # Inverse distance; the epsilon guards the division (NaN compares False -> 0).
    edge_features[feature_idx] = (1.0 / distance) if distance > 1e-6 else 0.0
    feature_idx += 1

    # 2. Sequential features (2).
    seq_dist = abs(i_orig - j_orig)
    edge_features[feature_idx] = float(seq_dist == 1)
    edge_features[feature_idx + 1] = seq_dist / max(1, seq_len - 1)
    feature_idx += 2

    # 3. K-relative features (2).
    edge_features[feature_idx] = float(i_orig == K_POS or j_orig == K_POS)
    # Normalise by the farthest a window position can be from K.
    max_dist_from_k = max(1, K_POS, seq_len - 1 - K_POS)
    edge_features[feature_idx + 1] = min(abs(i_orig - K_POS), abs(j_orig - K_POS)) / max_dist_from_k
    feature_idx += 2

    # 4. Secondary-structure pair (6).
    if i_orig < len(ss_string) and j_orig < len(ss_string):
        # Keys are spelled in sorted character order ('E' < 'H' < 'L') because the
        # pair key below is built with sorted(); the helix/strand slot must be
        # 'EH' -- spelled 'HE' it could never match and stayed 0 for every edge.
        # Slot order is unchanged: H-H, H-E, H-L, E-E, E-L, L-L.
        ss_pairs = ['HH', 'EH', 'HL', 'EE', 'EL', 'LL']
        ss_map = {'H': 'H', 'E': 'E', 'L': 'L', '-': 'L'}
        ss_i = ss_map.get(ss_string[i_orig], 'L')
        ss_j = ss_map.get(ss_string[j_orig], 'L')
        ss_pair = ''.join(sorted([ss_i, ss_j])) # Order-independent key
        for idx, pair in enumerate(ss_pairs):
            edge_features[feature_idx + idx] = float(ss_pair == pair)
    feature_idx += 6 # Advance even when skipped so the layout stays fixed

    # 5. SASA interaction (2).
    if i_orig < len(sasa_vals) and j_orig < len(sasa_vals):
        sasa_i = np.nan_to_num(sasa_vals[i_orig])
        sasa_j = np.nan_to_num(sasa_vals[j_orig])
        edge_features[feature_idx] = abs(sasa_i - sasa_j)
        edge_features[feature_idx + 1] = (sasa_i + sasa_j) / 2.0
    feature_idx += 2

    # Scrub any remaining NaN/Inf so downstream tensors are always finite.
    return np.nan_to_num(edge_features, nan=0.0, posinf=0.0, neginf=0.0)

In [None]:
def prepare_graph_data(df, distance_threshold=8.0, use_ss=True, include_edge_features=False):
    """
    Convert each DataFrame row into a torch_geometric ``Data`` graph.

    One node is created per non-padded ('-') position of the 33-character
    sequence window. Two nodes are connected when their distance-map entry is
    strictly between 0 and ``distance_threshold``; if a row yields no such edge,
    consecutive valid residues are chained (both directions) instead.

    Node features as currently assembled (27 columns):
        * amino-acid one-hot over VALID_AA (20)
        * central-lysine indicator (1)
        * sin/cos of phi, psi, omega (3 * 2 = 6)
    SASA, secondary structure and pLDDT are extracted below but their
    ``append`` calls are commented out, so they are NOT part of ``Data.x``.
    Consequently ``use_ss`` currently has no effect on the produced graphs.

    A row is skipped (and counted) when: the sequence is missing, is not 33
    characters long or does not have 'K' at index 16; any column listed in
    ``feature_names_to_parse`` is missing/NaN or fails to parse; the distance
    map does not hold 33*33 values; no edge exists even after the fallback; or
    any other exception is raised while processing the row (traceback printed).

    Args:
        df (pd.DataFrame): Must provide 'sequence', 'label' and every column in
            ``feature_names_to_parse``.
        distance_threshold (float): Exclusive upper bound on the distance for an edge.
        use_ss (bool): Only controls whether the SS one-hot is computed; see the
            note above -- it is not appended to the node features.
        include_edge_features (bool): When True, attach a 17-dim ``edge_attr``
            built by ``create_edge_features`` for every edge.

    Returns:
        tuple: (list of Data objects, list of labels), aligned by position.
        Each Data carries x, edge_index, y, sequence, central_node_idx and,
        optionally, edge_attr.
    """
    graph_list = []
    labels = []
    skipped = 0
    expected_seq_len = 33
    central_k_pos_abs = 16 # 0-based index in the 33-residue window

    # Columns that must be present; all but 'ss' are parsed into float32 arrays
    feature_names_to_parse = [
        'phi', 'psi', 'omega', 'tau', 'chi1', 'chi2', 'chi3', 'chi4',
        'sasa', 'ss', 'plDDT', 'distance_map'
    ]
    print(f"Preparing data (Edge Features: {include_edge_features}, SS Features: {use_ss}, Threshold: {distance_threshold}Å)...")

    for idx, row in df.iterrows():
        try:
            sequence = row['sequence']
            label = row['label']

            # --- Initial Validation ---
            # Requires a full-length window with the lysine at its centre.
            if pd.isna(sequence) or len(sequence) != expected_seq_len or sequence[central_k_pos_abs] != 'K':
                skipped += 1
                continue

            # --- Parse Required Data from row ---
            parsed_data = {}
            valid_row = True
            for name in feature_names_to_parse:
                if name not in row or pd.isna(row[name]):
                    # Missing column/value: the whole row is dropped.
                    valid_row = False; break
                try:
                    # SS stays a string; everything else becomes a float32 numpy array
                    if name == 'ss':
                        parsed_data[name] = str(row[name])
                        if len(parsed_data[name]) != expected_seq_len: raise ValueError("SS length mismatch")
                    else:
                        # SECURITY NOTE(review): eval() executes whatever expression the cell
                        # contains, so this is only safe for trusted data files. If the cells
                        # are plain list literals, ast.literal_eval would be a safer parser --
                        # confirm the stored format before switching.
                        parsed_data[name] = np.array(eval(str(row[name])), dtype=np.float32)
                except Exception as e:
                    # Any parse failure silently drops the row (counted in `skipped`).
                    valid_row = False; break
            if not valid_row: skipped += 1; continue

            # Reshape distance map and check size
            if parsed_data['distance_map'].size == expected_seq_len * expected_seq_len:
                 distance_map = parsed_data['distance_map'].reshape(expected_seq_len, expected_seq_len)
            else: skipped += 1; continue # Skip if distance map wrong size

            # --- Identify Valid (Non-padded) Positions ---
            # valid_pos_indices maps node index -> original window position.
            valid_pos_indices = [i for i, aa in enumerate(sequence) if aa != '-']
            if not valid_pos_indices: skipped += 1; continue # Skip if no valid residues
            num_nodes = len(valid_pos_indices)
            valid_sequence = ''.join([sequence[i] for i in valid_pos_indices])

            # Find new 0-based index of central K within the valid nodes
            try:
                central_k_new_idx = valid_pos_indices.index(central_k_pos_abs)
            except ValueError: # Defensive: position 16 was already checked to be 'K' above
                skipped += 1; continue

            # --- Node Feature Extraction ---
            node_features_list = []
            # 1. AA One-hot (20 features); residues outside VALID_AA stay all-zero
            aa_onehot = np.zeros((num_nodes, len(VALID_AA)), dtype=np.float32)
            for i, aa in enumerate(valid_sequence):
                 aa_idx = VALID_AA.find(aa)
                 if aa_idx >= 0: aa_onehot[i, aa_idx] = 1.0
            node_features_list.append(aa_onehot)

            # 2. Central K indicator (1 feature)
            is_central_k = np.zeros((num_nodes, 1), dtype=np.float32)
            is_central_k[central_k_new_idx, 0] = 1.0
            node_features_list.append(is_central_k)

            # 3. Angles -> sin/cos pairs.
            # NOTE: the 8-angle list is immediately overridden by the 3-angle list on the
            # next line, so only phi/psi/omega are used (3 * 2 = 6 features).
            angle_keys = ['phi', 'psi', 'omega', 'tau', 'chi1', 'chi2', 'chi3', 'chi4']
            angle_keys = ['phi', 'psi', 'omega']

            for key in angle_keys:
                # Extract angles for valid positions
                valid_angles = parsed_data[key][valid_pos_indices]
                # Replace NaNs with 0 before the trigonometric transform
                valid_angles = np.nan_to_num(valid_angles)
                # Degrees -> radians, then sin/cos
                angle_rad = np.pi * valid_angles / 180.0
                sin_cos_features = np.stack([np.sin(angle_rad), np.cos(angle_rad)], axis=-1)
                node_features_list.append(sin_cos_features.astype(np.float32))

            # 4. SASA -- extracted but currently NOT appended (append is commented out)
            valid_sasa = parsed_data['sasa'][valid_pos_indices].reshape(-1, 1)
            # node_features_list.append(np.nan_to_num(valid_sasa).astype(np.float32)) # Use raw, handle NaN

            # 5. SS one-hot (H=0, E=1, L=2) -- computed when use_ss=True but currently
            #    NOT appended (append is commented out)
            if use_ss:
                ss_string = parsed_data['ss']
                valid_ss_chars = [ss_string[i] for i in valid_pos_indices]
                ss_onehot = np.zeros((num_nodes, 3), dtype=np.float32) # H=0, E=1, L=2
                ss_map = {'H': 0, 'E': 1, 'L': 2, '-': 2} # Map '-' to 'L' just in case
                for i, ss_char in enumerate(valid_ss_chars):
                    ss_onehot[i, ss_map.get(ss_char, 2)] = 1.0 # Default to 'L' if unexpected char
                # node_features_list.append(ss_onehot)

            # 6. pLDDT -- extracted but currently NOT appended (append is commented out)
            valid_plddt = parsed_data['plDDT'][valid_pos_indices].reshape(-1, 1)
            # node_features_list.append(np.nan_to_num(valid_plddt).astype(np.float32)) # Use raw, handle NaN

            # --- Concatenate all node features ---
            try:
                node_features = np.concatenate(node_features_list, axis=1)
            except ValueError as e: # Shape mismatch between feature groups
                 print(f"Error concatenating node features for row {idx}: {e}")
                 skipped += 1; continue

            # --- Edge Construction (Structure only for edge_index) ---
            # Restrict the distance map to valid positions; indices below are node indices.
            valid_distance_map = distance_map[np.ix_(valid_pos_indices, valid_pos_indices)]
            adj = valid_distance_map < distance_threshold
            adj &= valid_distance_map > 0 # Distance must be positive
            np.fill_diagonal(adj, False) # No self-loops from distance threshold
            edge_list_valid = np.argwhere(adj) # Indices relative to valid_pos_indices

            edges = []
            edge_features_list = [] # Only populated if include_edge_features is True

            # Create distance-based edges
            if edge_list_valid.shape[0] > 0:
                edges = edge_list_valid.tolist() # Edge indices relative to valid nodes
                # Optionally create rich edge features
                if include_edge_features:
                    # create_edge_features works on the full (unfiltered) arrays
                    sasa_full = parsed_data['sasa']
                    ss_full = parsed_data['ss']
                    for i_valid, j_valid in edges:
                        # Map node indices back to original window positions
                        i_orig = valid_pos_indices[i_valid]
                        j_orig = valid_pos_indices[j_valid]
                        # One 17-dim vector per directed edge, in edge order
                        edge_feat_vec = create_edge_features(
                            i_orig, j_orig, distance_map, sequence, ss_full, sasa_full
                        )
                        edge_features_list.append(edge_feat_vec)

            # --- Handle Fallback Sequential Edges ---
            if not edges and num_nodes > 1: # Add seq edges ONLY if no distance edges were found
                for i_valid in range(num_nodes - 1):
                    j_valid = i_valid + 1
                    # Both directions, so the chain is effectively undirected
                    edges.extend([[i_valid, j_valid], [j_valid, i_valid]])
                    # Optionally create features for these sequential edges
                    if include_edge_features:
                        i_orig = valid_pos_indices[i_valid]
                        j_orig = valid_pos_indices[j_valid]
                        sasa_full = parsed_data['sasa']
                        ss_full = parsed_data['ss']
                        feat1 = create_edge_features(i_orig, j_orig, distance_map, sequence, ss_full, sasa_full)
                        feat2 = create_edge_features(j_orig, i_orig, distance_map, sequence, ss_full, sasa_full)
                        edge_features_list.extend([feat1, feat2])

            # Skip graph if it has no edges after fallback (single-node graphs)
            if not edges:
                skipped += 1
                continue

            # --- Convert to PyTorch Tensors ---
            x_tensor = torch.tensor(node_features, dtype=torch.float)
            edge_index_tensor = torch.tensor(edges, dtype=torch.long).t().contiguous() # -> [2, E]
            y_tensor = torch.tensor([label], dtype=torch.float)
            # Integer-encoded full window (padding included), kept alongside the graph
            sequence_tensor = torch.tensor(integer_encode_sequence(sequence), dtype=torch.long)
            # Index of the central K among this graph's own nodes
            central_node_idx_tensor = torch.tensor([central_k_new_idx], dtype=torch.long)

            # Create Data object, conditionally adding edge_attr
            data_dict = {
                'x': x_tensor,
                'edge_index': edge_index_tensor,
                'y': y_tensor,
                'sequence': sequence_tensor,
                'central_node_idx': central_node_idx_tensor
            }
            if include_edge_features and edge_features_list:
                edge_attr_tensor = torch.tensor(np.array(edge_features_list), dtype=torch.float)
                # edge_attr must have one row per edge and the expected width
                if edge_attr_tensor.shape[0] == edge_index_tensor.shape[1] and \
                   edge_attr_tensor.shape[1] == EXPECTED_EDGE_FEATURE_DIM:
                    data_dict['edge_attr'] = edge_attr_tensor
                else:
                    print(f"Warning: Edge feature dimension mismatch row {idx}. Attr:{edge_attr_tensor.shape}, Idx:{edge_index_tensor.shape}. Skipping edge_attr.")
                    # NOTE(review): the graph is still emitted, just without edge_attr, so
                    # the returned list could mix graphs with and without that field.
            elif include_edge_features and not edge_features_list and edge_index_tensor.shape[1] > 0:
                 print(f"Warning: include_edge_features=True but edge_features_list is empty for row {idx} with {edge_index_tensor.shape[1]} edges.")


            data = Data(**data_dict)
            graph_list.append(data)
            labels.append(label)

        except Exception as e:
            # Catch-all so one malformed row cannot abort the whole conversion.
            print(f"--- Critical Error processing row {idx}: {e} ---")
            traceback.print_exc() # Print full traceback for critical errors
            skipped += 1
            continue # Skip to next row

    # Summary; dimensions are reported from the first graph only.
    print(f"Created {len(graph_list)} graphs, skipped {skipped} rows.")
    if graph_list:
         print(f" Node feature dimension: {graph_list[0].x.shape[1]}")
         if 'edge_attr' in graph_list[0]:
              print(f" Edge feature dimension: {graph_list[0].edge_attr.shape[1]}")
         else:
              print(f" Edge features not included in Data objects.")
    return graph_list, labels


# Helper function (ensure it's defined)
def integer_encode_sequence(sequence):
    """Map each character to its AA_TO_INT code; unknown characters get the padding code."""
    pad_code = AA_TO_INT['-']
    encoded = []
    for residue in sequence:
        encoded.append(AA_TO_INT.get(residue, pad_code))
    return encoded

In [None]:
class GCNNetwork(nn.Module):
    """Node-embedding backbone built from GCNConv layers; edge features are not used."""
    def __init__(self, input_dim, hidden_dim=128, dropout=0.4, layers=3):
        super().__init__()
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        # The first layer maps input_dim -> hidden_dim; the other (layers - 1)
        # layers keep hidden_dim. At least one layer is always created.
        layer_in_dims = [input_dim] + [hidden_dim] * (layers - 1)
        for in_channels in layer_in_dims:
            self.convs.append(GCNConv(in_channels, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.dropout_layer = nn.Dropout(dropout)
        self.output_dim = hidden_dim # Read by the downstream model

    def forward(self, x, edge_index):
        """Return per-node embeddings [N, hidden_dim] for x [N, input_dim], edge_index [2, E]."""
        for depth, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            skip = x # Input of this layer, kept for the residual

            x = conv(x, edge_index)
            if x.shape[0] == 0:
                return x # No nodes: hand back the empty tensor untouched

            x = F.relu(norm(x))

            # Residual from the second layer on, and only when shapes agree
            if depth > 0 and skip.shape == x.shape:
                x = x + skip

            x = self.dropout_layer(x)

        return x

class GATv2Network(nn.Module):
    """Node-embedding backbone built from GATv2Conv layers; edge features are optional."""
    def __init__(self, input_dim, hidden_dim=128, heads=4, dropout=0.4, layers=3, edge_dim=None): # edge_dim=None disables edge features
        super().__init__()
        if hidden_dim % heads != 0:
            raise ValueError("GATv2Network: hidden_dim must be divisible by heads")

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        per_head_dim = hidden_dim // heads # heads * per_head_dim == hidden_dim after concat
        self.has_edge_features = edge_dim is not None

        # Settings shared by every attention layer
        conv_options = dict(heads=heads, concat=True, dropout=dropout,
                            edge_dim=edge_dim, add_self_loops=True)

        # First layer consumes input_dim, the remaining (layers - 1) consume hidden_dim
        for in_channels in [input_dim] + [hidden_dim] * (layers - 1):
            self.convs.append(GATv2Conv(in_channels, per_head_dim, **conv_options))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.dropout_layer = nn.Dropout(dropout)
        self.output_dim = hidden_dim # Read by the downstream model

    def forward(self, x, edge_index, edge_attr=None):
        """
        Return per-node embeddings [N, hidden_dim].

        edge_attr ([E, edge_dim]) is mandatory when the network was built with
        an edge_dim and is ignored otherwise.
        """
        for depth, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            skip = x # Input of this layer, kept for the residual

            if not self.has_edge_features:
                x = conv(x, edge_index)
            elif edge_attr is None:
                raise ValueError("GATv2Network configured with edge_dim but edge_attr not provided to forward.")
            else:
                x = conv(x, edge_index, edge_attr=edge_attr)

            if x.shape[0] == 0:
                return x # No nodes: hand back the empty tensor untouched

            x = F.elu(norm(x)) # ELU is the activation used for the GAT variant

            # Residual from the second layer on, and only when shapes agree
            if depth > 0 and skip.shape == x.shape:
                x = x + skip

            x = self.dropout_layer(x)

        return x

class PNANetwork(nn.Module):
    """Node-embedding backbone built from PNAConv layers; edge features are not used."""
    def __init__(self, input_dim, hidden_dim=128, layers=3, dropout=0.4, deg=None):
        super().__init__()
        if deg is None:
            raise ValueError("PNANetwork requires the degree histogram 'deg' argument.")

        # Settings shared by every PNA layer (tunable)
        pna_options = dict(
            aggregators=['mean', 'min', 'max', 'std'],
            scalers=['identity', 'amplification', 'attenuation'],
            deg=deg,
            towers=4,
            post_layers=1,
        )

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        # First layer consumes input_dim, the remaining (layers - 1) consume hidden_dim
        for in_channels in [input_dim] + [hidden_dim] * (layers - 1):
            self.convs.append(PNAConv(in_channels, hidden_dim, **pna_options))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.dropout_layer = nn.Dropout(dropout)
        self.output_dim = hidden_dim # Read by the downstream model

    def forward(self, x, edge_index):
        """Return per-node embeddings [N, hidden_dim] for x [N, input_dim], edge_index [2, E]."""
        for depth, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            skip = x # Input of this layer, kept for the residual

            x = conv(x, edge_index)
            if x.shape[0] == 0:
                return x # No nodes: hand back the empty tensor untouched

            x = F.relu(norm(x))

            # Residual from the second layer on, and only when shapes agree
            if depth > 0 and skip.shape == x.shape:
                x = x + skip

            x = self.dropout_layer(x)

        return x

class GINENetwork(nn.Module):
    """Node-embedding backbone built from GINEConv layers; edge features are mandatory."""
    def __init__(self, input_dim, hidden_dim=128, edge_dim=EXPECTED_EDGE_FEATURE_DIM, dropout=0.4, layers=3):
        super().__init__()
        if edge_dim is None:
            raise ValueError("GINENetwork requires a valid 'edge_dim'.")

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.edge_dim = edge_dim

        def build_update_mlp(width_in, width_out):
            # Node-update network handed to GINEConv: expand to 2x, then contract.
            widened = width_out * 2
            return nn.Sequential(
                nn.Linear(width_in, widened),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(widened, width_out),
            )

        # First layer consumes input_dim, the remaining (layers - 1) consume hidden_dim.
        # train_eps=True makes the GIN epsilon a learnable parameter.
        for in_channels in [input_dim] + [hidden_dim] * (layers - 1):
            update_mlp = build_update_mlp(in_channels, hidden_dim)
            self.convs.append(GINEConv(nn=update_mlp, edge_dim=self.edge_dim, train_eps=True))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.dropout_layer = nn.Dropout(dropout)
        self.output_dim = hidden_dim # Read by the downstream model

    def forward(self, x, edge_index, edge_attr):
        """
        Return per-node embeddings [N, hidden_dim].

        Raises ValueError when edge_attr is None or its second dimension
        differs from the edge_dim the network was built with.
        """
        if edge_attr is None:
            raise ValueError("GINENetwork requires edge_attr for forward pass.")
        if edge_attr.shape[1] != self.edge_dim:
            raise ValueError(f"GINENetwork edge_attr dimension mismatch. Expected {self.edge_dim}, got {edge_attr.shape[1]}")

        for depth, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            skip = x # Input of this layer, kept for the residual

            x = conv(x, edge_index, edge_attr=edge_attr)
            if x.shape[0] == 0:
                return x # No nodes: hand back the empty tensor untouched

            x = F.relu(norm(x))

            # Residual from the second layer on, and only when shapes agree
            if depth > 0 and skip.shape == x.shape:
                x = x + skip

            x = self.dropout_layer(x)

        return x

In [None]:
class StandaloneGNNModel(nn.Module):
    """
    Graph-level binary classifier: GNN backbone -> readout -> small MLP head.

    For every graph the readout concatenates the backbone embedding of the
    central node with the mean-pooled and max-pooled embeddings of all its
    nodes (3 * backbone.output_dim features). There is no sequence-CNN track.

    Args:
        gnn_type (str): One of 'gcn', 'gatv2', 'pna', 'gine'.
        node_feature_dim (int): Width of the node feature matrix ``x``.
        hidden_dim (int): Hidden/output width of the backbone.
        deg: Degree histogram, required for 'pna'.
        edge_dim (int | None): Edge feature width; required for 'gine',
            optional for 'gatv2', ignored by the other backbones.
        heads (int): Attention heads for 'gatv2'.
        layers (int): Number of backbone layers.
        dropout (float): Dropout inside the backbone.
        classifier_dropout (float): Dropout inside the classifier head.

    Raises:
        ValueError: Unknown gnn_type, 'pna' without deg, or 'gine' without edge_dim.
    """
    def __init__(self, gnn_type, node_feature_dim, hidden_dim=128,
                 # GNN specific args (needed for instantiation)
                 deg=None, # Required for PNA
                 edge_dim=None, # Required for GINE, Optional for GATv2
                 heads=4, # Required for GATv2
                 # Common GNN hypers (passed via kwargs below)
                 layers=3,
                 dropout=0.4,
                 # Classifier hypers (can be customized)
                 classifier_dropout=0.5):
        super().__init__()
        self.gnn_type = gnn_type
        # Edge features are forwarded only to backbones that can consume them.
        self.use_edge_features = edge_dim is not None and gnn_type in ['gatv2', 'gine']

        # Arguments common to every backbone
        gnn_kwargs = {
            'input_dim': node_feature_dim,
            'hidden_dim': hidden_dim,
            'dropout': dropout,
            'layers': layers
        }

        if gnn_type == 'gcn':
            self.gnn = GCNNetwork(**gnn_kwargs)
        elif gnn_type == 'gatv2':
            self.gnn = GATv2Network(**gnn_kwargs, heads=heads, edge_dim=edge_dim)
        elif gnn_type == 'pna':
            if deg is None: raise ValueError("PNA requires 'deg' histogram.")
            self.gnn = PNANetwork(**gnn_kwargs, deg=deg)
        elif gnn_type == 'gine':
            if edge_dim is None: raise ValueError("GINE requires 'edge_dim'.")
            self.gnn = GINENetwork(**gnn_kwargs, edge_dim=edge_dim)
        else:
            raise ValueError(f"Unsupported GNN type: {gnn_type}")

        # Readout and Final Classifier Layers
        gnn_output_dim = self.gnn.output_dim

        # [Central Node Features] + [Global Mean Pool] + [Global Max Pool]
        classifier_input_dim = gnn_output_dim + gnn_output_dim + gnn_output_dim

        # Simple classifier MLP
        self.fc1 = nn.Linear(classifier_input_dim, 64)
        self.bn1 = nn.BatchNorm1d(64) # Named differently from the backbone batch norms
        self.dropout_layer = nn.Dropout(classifier_dropout)
        self.fc2 = nn.Linear(64, 1) # Single output for binary classification

    def forward(self, data):
        """
        Score a batch of graphs (or one un-batched graph).

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``central_node_idx``.
                ``batch``, ``ptr`` and ``num_graphs`` are used when present and
                reconstructed otherwise (all nodes are then treated as one graph).
                ``edge_attr`` is required when the model uses edge features.

        Returns:
            torch.Tensor: Sigmoid probabilities of shape (num_graphs, 1).

        Raises:
            ValueError: The model uses edge features but ``edge_attr`` is missing.

        NOTE(review): ``bn1`` is a BatchNorm1d, which rejects a single-row input in
        training mode; a trailing training batch holding one graph will raise
        unless the loader drops it.
        """
        x, edge_index = data.x, data.edge_index
        batch = getattr(data, 'batch', None)
        central_node_idx = data.central_node_idx
        ptr = getattr(data, 'ptr', None)

        # Rebuild batch/ptr for un-batched input (e.g. single-graph inference)
        if batch is None:
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        if ptr is None:
            if batch.numel() > 0:
                counts = torch.bincount(batch)
                ptr = torch.cat([torch.tensor([0], device=batch.device), counts.cumsum(0)])
            else: # Empty batch
                ptr = torch.tensor([0], device=batch.device)

        # An un-batched graph has no num_graphs; without this fallback the
        # reconstruction above would be pointless because the lookup would raise.
        batch_size = getattr(data, 'num_graphs', None)
        if batch_size is None:
            batch_size = ptr.numel() - 1

        # --- GNN track ---
        if self.use_edge_features:
             edge_attr = data.edge_attr if hasattr(data, 'edge_attr') else None
             if edge_attr is None:
                 raise ValueError(f"Model '{self.gnn_type}' configured to use edge features, but 'edge_attr' not found in data batch.")
             # GATv2 (with edge_dim) and GINE take edge_attr
             gnn_node_features = self.gnn(x, edge_index, edge_attr)
        else:
             # GCN, PNA and GATv2 without edge_dim
             gnn_node_features = self.gnn(x, edge_index)

        # --- Central Node Extraction & Global Pooling (Readout) ---
        gnn_out_dim = self.gnn.output_dim
        # Zero-initialised so that any readout that cannot be computed stays at 0
        central_node_features = torch.zeros((batch_size, gnn_out_dim), device=x.device)
        global_avg_features = torch.zeros((batch_size, gnn_out_dim), device=x.device)
        global_max_features = torch.zeros((batch_size, gnn_out_dim), device=x.device)

        nodes_exist = gnn_node_features.shape[0] > 0
        indices_valid = ptr.numel() > 1 and central_node_idx.numel() == batch_size

        if nodes_exist and indices_valid:
            # central_node_idx is local to each graph; ptr[:-1] holds each graph's
            # first row in the concatenated node matrix.
            graph_starts = ptr[:-1]
            absolute_central_node_indices = graph_starts + central_node_idx
            # Gather only when every index is in range; otherwise keep the zeros
            if absolute_central_node_indices.max() < gnn_node_features.shape[0] and absolute_central_node_indices.min() >= 0:
                 central_node_features = gnn_node_features[absolute_central_node_indices]

            # Pool only when the batch vector is consistent with batch_size
            if batch.numel() > 0 and batch.max() < batch_size:
                 global_avg_features = global_mean_pool(gnn_node_features, batch)
                 global_max_features = global_max_pool(gnn_node_features, batch)

        # --- Combine readout features ---
        combined_features = torch.cat([central_node_features, global_avg_features, global_max_features], dim=1)

        # --- Final classification MLP ---
        x_out = self.fc1(combined_features)
        if x_out.shape[0] > 0:
            x_out = self.bn1(x_out)
        else: # Empty batch: nothing to normalise
             return torch.zeros((batch_size, 1), device=x.device)

        x_out = F.relu(x_out)
        x_out = self.dropout_layer(x_out)
        x_out = self.fc2(x_out)

        # Probabilities, not logits
        return torch.sigmoid(x_out)

In [None]:
def train_model(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer, device: torch.device, class_weights: dict = None):
    """Train ``model`` for one epoch over ``loader``.

    Args:
        model: Network called as ``model(batch)``; its output is fed straight
            into ``F.binary_cross_entropy``, so it must already be a
            probability in [0, 1].
        loader: Iterable of batches exposing ``to(device)``, ``num_graphs``,
            ``x`` and ``y``.
        optimizer: Optimizer stepped once per processed batch.
        device: Device every batch is moved to before the forward pass.
        class_weights: Optional ``{0: w_neg, 1: w_pos}`` mapping used as
            per-sample BCE weights. Ignored for a batch whose labels are not
            all 0/1 (a warning is printed instead).

    Returns:
        tuple: ``(avg_loss, avg_acc)`` averaged over the samples that were
        actually processed; both are 0 when every batch was skipped.
    """
    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for batch in loader:
        batch = batch.to(device)
        # Skip potentially empty batches from DataLoader/filtering issues
        if batch.num_graphs == 0 or batch.x.shape[0] == 0:
            continue

        optimizer.zero_grad()
        output = model(batch)
        # BCE needs a floating-point target of shape [batch_size, 1]; integer
        # labels would otherwise raise and silently skip every batch.
        target = batch.y.view(-1, 1).to(output.dtype)

        if output.shape[0] != target.shape[0]:
            print(f"Warning: Mismatch between output ({output.shape[0]}) and target ({target.shape[0]}) shapes. Skipping batch.")
            continue
        if target.numel() == 0:
            continue

        try:
            weight = None
            if class_weights is not None:
                labels = target.long()
                valid_targets = (labels == 0) | (labels == 1)
                if not valid_targets.all():
                    # Weighting is skipped (not the batch) when labels are unexpected.
                    print(f"Warning: Invalid target values found for class weights: {labels[~valid_targets].cpu()}")
                else:
                    # Vectorised lookup: one weight per sample, built on-device.
                    weight = torch.where(
                        labels == 1,
                        torch.full_like(target, float(class_weights[1])),
                        torch.full_like(target, float(class_weights[0])),
                    )

            current_loss = F.binary_cross_entropy(output, target, weight=weight)
        except Exception as e:
            # Best-effort: report and move on rather than abort the epoch.
            print(f"Error during loss calculation: {e}")
            print(f"Output range: {output.min()} - {output.max()}")
            print(f"Target values: {target.unique()}")
            continue

        current_loss.backward()
        optimizer.step()

        # The loss is a per-sample mean, so re-weight by the sample count to
        # get a correct dataset-level average across uneven batches.
        n_samples = target.size(0)
        total_loss += current_loss.item() * n_samples

        with torch.no_grad():
            pred = (output > 0.5).float()
            correct_predictions += (pred == target).sum().item()
        total_samples += n_samples

    # Average over processed samples only; dividing by len(loader.dataset)
    # under-reports the loss whenever a batch was skipped.
    avg_loss = total_loss / total_samples if total_samples > 0 else 0
    avg_acc = correct_predictions / total_samples if total_samples > 0 else 0
    return avg_loss, avg_acc


def evaluate_model(model: nn.Module, loader: DataLoader, device: torch.device):
    """Evaluate ``model`` on ``loader`` and return binary-classification metrics.

    Args:
        model: Network called as ``model(batch)``; expected to return
            probabilities in [0, 1] (its output goes into
            ``F.binary_cross_entropy``).
        loader: Iterable of batches exposing ``to(device)``, ``num_graphs``,
            ``x`` and ``y``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        dict with keys ``loss`` (mean BCE per sample), ``accuracy``,
        ``balanced_acc``, ``mcc``, ``sensitivity``, ``specificity``,
        ``confusion_matrix`` (2x2, rows = true 0/1, cols = predicted 0/1),
        ``predictions`` (outputs thresholded at 0.5), ``probabilities``
        (un-thresholded model outputs) and ``targets``. When no batch could
        be evaluated, all scalars are 0.0 and the arrays are empty.
    """
    model.eval()
    total_loss = 0.0
    all_probs = []
    all_targets = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            if batch.num_graphs == 0 or batch.x.shape[0] == 0:
                continue

            output = model(batch)
            # BCE requires a floating-point target of shape [batch_size, 1].
            target = batch.y.view(-1, 1).to(output.dtype)

            if output.shape[0] != target.shape[0] or target.numel() == 0:
                continue

            # reduction='sum' so the final division yields a true per-sample mean
            total_loss += F.binary_cross_entropy(output, target, reduction='sum').item()

            all_probs.append(output.cpu().numpy())
            all_targets.append(target.cpu().numpy())

    # Loader was empty or every batch was skipped
    if not all_targets:
        print("Warning: Evaluation yielded no targets/predictions.")
        return {
            'loss': 0.0, 'accuracy': 0.0, 'balanced_acc': 0.0, 'mcc': 0.0,
            'sensitivity': 0.0, 'specificity': 0.0,
            'confusion_matrix': np.zeros((2, 2), dtype=int),
            'predictions': np.array([]), 'probabilities': np.array([]),
            'targets': np.array([])
        }

    all_probs = np.concatenate(all_probs).flatten()
    all_targets = np.concatenate(all_targets).flatten()
    all_preds = (all_probs > 0.5).astype(np.float32)  # hard 0/1 decisions
    num_samples = len(all_targets)

    accuracy = accuracy_score(all_targets, all_preds)
    # With a single class in the targets sklearn only warns here; it does not raise.
    balanced_acc = balanced_accuracy_score(all_targets, all_preds, adjusted=False)
    mcc = matthews_corrcoef(all_targets, all_preds)

    # labels=[0, 1] pins the matrix to 2x2 even if a class is absent
    cm = confusion_matrix(all_targets, all_preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    return {
        'loss': total_loss / num_samples,
        'accuracy': accuracy,
        'balanced_acc': balanced_acc,
        'mcc': mcc,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'confusion_matrix': cm,
        'predictions': all_preds,     # thresholded at 0.5
        'probabilities': all_probs,   # raw model outputs
        'targets': all_targets
    }

def print_metrics(metrics: dict):
    """Print a formatted summary of a metrics dictionary.

    Args:
        metrics: Mapping that may contain ``loss``, ``accuracy``,
            ``balanced_acc``, ``mcc``, ``sensitivity``, ``specificity`` and
            ``confusion_matrix``. Missing score keys print as 0; a missing
            ``loss`` or a non-ndarray confusion matrix prints as ``N/A``.
    """
    loss = metrics.get('loss')
    if loss is None:
        # Some metric dicts (e.g. ensembled test metrics) carry no loss;
        # say so instead of printing "nan", which looks like a numeric failure.
        print("  Loss: N/A")
    else:
        print(f"  Loss: {loss:.4f}")
    print(f"  Accuracy: {metrics.get('accuracy', 0):.4f}, Balanced Acc: {metrics.get('balanced_acc', 0):.4f}, MCC: {metrics.get('mcc', 0):.4f}")
    print(f"  Sensitivity: {metrics.get('sensitivity', 0):.4f}, Specificity: {metrics.get('specificity', 0):.4f}")
    cm = metrics.get('confusion_matrix', None)
    if isinstance(cm, np.ndarray):
        print(f"  Confusion Matrix:\n{cm}")
    else:
        print("  Confusion Matrix: N/A")

In [None]:
def train_and_evaluate_model(train_df: pd.DataFrame, test_df: pd.DataFrame,
                             model_config: dict, distance_threshold: float = 8.0):
    """
    Trains and evaluates a single Standalone GNN model configuration using 5-fold CV.

    For every fold a fresh model is trained with early stopping on validation
    loss, the best-epoch weights are restored, and that model scores both the
    fold's validation split and the full test set. Test predictions from all
    folds are combined by majority vote (mean of the thresholded 0/1
    predictions, cut at 0.5).

    Args:
        train_df (pd.DataFrame): Training data.
        test_df (pd.DataFrame): Test data.
        model_config (dict): Configuration dictionary. Required keys:
                             'gnn_type' (str) and 'use_edge_features' (bool).
                             Optional keys (with defaults): 'hidden_dim' (128),
                             'heads' (4), 'layers' (3), 'dropout' (0.4),
                             'lr' (0.001), 'weight_decay' (0.01),
                             'epochs' (50), 'patience' (10).
        distance_threshold (float): Distance cutoff for graph edges.

    Returns:
        tuple: (list_of_fold_validation_metrics, dict_of_final_test_metrics or None).
               (None, None) is returned when no training graphs could be built.
    """
    gnn_type = model_config['gnn_type']
    use_edge_features = model_config['use_edge_features']
    run_label = f"{gnn_type.upper()} (Edge Feats: {use_edge_features})" # Label for printing
    print(f"\n{'='*15} Evaluating: {run_label} {'='*15}")

    # --- Data Preparation ---
    train_graphs, train_labels = prepare_graph_data(train_df, distance_threshold,
                                                    use_ss=True, include_edge_features=use_edge_features)
    test_graphs, test_labels = prepare_graph_data(test_df, distance_threshold,
                                                  use_ss=True, include_edge_features=use_edge_features)

    if not train_graphs:
        print(f"ERROR: No training graphs created for {run_label}. Aborting.")
        return None, None

    # --- Degree histogram is only needed by PNA ---
    deg_histogram = _compute_degree_histogram(train_graphs, device) if gnn_type == 'pna' else None

    # --- Class Weights: inverse-frequency, balanced so both classes contribute equally ---
    total_train = len(train_labels)
    pos_train = sum(train_labels)
    neg_train = total_train - pos_train
    class_weights = {
        0: total_train / (2 * neg_train) if neg_train > 0 else 1.0,
        1: total_train / (2 * pos_train) if pos_train > 0 else 1.0
    }

    # --- Fold-invariant setup (hoisted out of the CV loop) ---
    node_feature_dim = train_graphs[0].x.shape[1]  # train_graphs is non-empty here
    edge_dim_arg = EXPECTED_EDGE_FEATURE_DIM if use_edge_features else None
    test_loader = DataLoader(test_graphs, batch_size=32, shuffle=False) if test_graphs else None
    epochs = model_config.get('epochs', 50)
    patience = model_config.get('patience', 10)

    # --- Cross-validation (Single Run with fixed SEED for split) ---
    kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
    fold_metrics = []         # Validation metrics of each fold's best model
    test_predictions = []     # Thresholded test predictions of each fold's best model

    for fold, (train_idx, val_idx) in enumerate(kfold.split(np.zeros(len(train_graphs)), train_labels), 1):
        print(f"\n--- Fold {fold}/5 ---")
        train_fold = [train_graphs[i] for i in train_idx]
        val_fold = [train_graphs[i] for i in val_idx]
        train_loader = DataLoader(train_fold, batch_size=32, shuffle=True, num_workers=0, pin_memory=False)
        val_loader = DataLoader(val_fold, batch_size=32, shuffle=False, num_workers=0, pin_memory=False)

        # Fresh model per fold so no weights leak between folds
        model = StandaloneGNNModel(
            gnn_type=gnn_type,
            node_feature_dim=node_feature_dim,
            hidden_dim=model_config.get('hidden_dim', 128),
            deg=deg_histogram,
            edge_dim=edge_dim_arg,
            heads=model_config.get('heads', 4),
            layers=model_config.get('layers', 3),
            dropout=model_config.get('dropout', 0.4)
        ).to(device)

        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=model_config.get('lr', 0.001),
                                      weight_decay=model_config.get('weight_decay', 0.01))
        # `verbose` is deprecated in recent PyTorch; its default is already silent.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

        # --- Training Loop with Early Stopping ---
        best_val_loss = float('inf')
        epochs_no_improve = 0
        best_state_dict = None

        for epoch in range(epochs):
            train_loss, train_acc = train_model(model, train_loader, optimizer, device, class_weights=class_weights)
            val_metrics = evaluate_model(model, val_loader, device)
            val_loss = val_metrics['loss']
            scheduler.step(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0
                # state_dict() tensors alias the live parameters, so a plain
                # dict.copy() would keep changing as training continues.
                # Clone every tensor to freeze a real snapshot.
                best_state_dict = {name: tensor.detach().clone()
                                   for name, tensor in model.state_dict().items()}
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f"  Early stopping triggered at epoch {epoch+1}")
                    break

        if best_state_dict is not None:
            model.load_state_dict(best_state_dict)
        else:
            print("Warning: No best model state saved (perhaps training ended early or failed). Evaluating last state.")

        # Validation metrics of the restored best model
        final_val_metrics = evaluate_model(model, val_loader, device)
        fold_metrics.append(final_val_metrics)
        print(f"  Fold {fold} Final Validation Metrics:")
        print_metrics(final_val_metrics)

        # Test-set predictions of the restored best model, kept for ensembling
        if test_loader is not None:
            test_fold_metrics = evaluate_model(model, test_loader, device)
            test_predictions.append(test_fold_metrics['predictions'])
        else:
            print("  Skipping test set evaluation as no test graphs provided.")

    # --- Aggregate and Report CV Results ---
    print(f"\n--- Cross-validation Summary for {run_label} ---")
    metrics_to_report = ['accuracy', 'balanced_acc', 'mcc', 'sensitivity', 'specificity', 'loss']
    if fold_metrics:
        for metric in metrics_to_report:
            values = [m.get(metric, float('nan')) for m in fold_metrics]
            values = [v for v in values if not np.isnan(v)]
            if values:
                print(f"  Avg Val {metric:<15}: {np.mean(values):.4f} ± {np.std(values):.4f}")
            else:
                print(f"  Avg Val {metric:<15}: N/A")
    else:
        print("  No fold metrics recorded.")

    # --- Final Test Set Evaluation (Ensembled) ---
    final_test_metrics = None
    if test_predictions and test_labels is not None:
        print("\n--- Final Test Set Performance (Ensembled from CV Folds) ---")
        # Each entry is already 0/1, so mean > 0.5 is a majority vote across folds.
        test_pred_avg = np.mean(np.stack(test_predictions), axis=0)
        test_pred_binary = (test_pred_avg > 0.5).astype(int)

        # labels=[0, 1] guarantees a 2x2 matrix even if one class is absent
        cm = confusion_matrix(test_labels, test_pred_binary, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()
        final_test_metrics = {
            'accuracy': accuracy_score(test_labels, test_pred_binary),
            'balanced_acc': balanced_accuracy_score(test_labels, test_pred_binary),
            'mcc': matthews_corrcoef(test_labels, test_pred_binary),
            'confusion_matrix': cm,
            'sensitivity': tp / (tp + fn) if (tp + fn) > 0 else 0.0,
            'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0.0
        }

        print_metrics(final_test_metrics)
    else:
        print("\n--- Test set evaluation skipped or failed ---")

    return fold_metrics, final_test_metrics


def _compute_degree_histogram(graphs, target_device):
    """Return the in-degree histogram of ``graphs`` (as required by PNAConv) on ``target_device``."""
    print("Calculating node degree histogram for PNA...")
    degrees = []
    for data in graphs:
        num_nodes = data.num_nodes
        deg_list = torch.zeros(num_nodes, dtype=torch.long)
        if data.edge_index.numel() > 0:
            # Drop edges pointing outside the graph before counting
            valid_edge_mask = (data.edge_index[0] < num_nodes) & (data.edge_index[1] < num_nodes)
            valid_edge_index = data.edge_index[:, valid_edge_mask]
            if valid_edge_index.numel() > 0:
                # Target nodes (row 1) give the in-degree
                deg_list = degree(valid_edge_index[1], num_nodes=num_nodes, dtype=torch.long)
        degrees.append(deg_list)

    all_degrees = torch.cat(degrees, dim=0)
    max_degree = int(all_degrees.max().item()) if all_degrees.numel() > 0 else 0

    if all_degrees.numel() > 0:
        # bincount with minlength yields exactly max_degree + 1 bins
        deg_histogram = torch.bincount(all_degrees, minlength=max_degree + 1)
    else:
        deg_histogram = torch.zeros(max_degree + 1, dtype=torch.long)

    print(f"Max degree found: {max_degree}")
    deg_histogram = deg_histogram.to(target_device)
    if deg_histogram.sum() == 0 and deg_histogram.numel() > 0:
        print("Warning: Degree histogram is all zeros. PNA might fail. Providing minimal histogram.")
        deg_histogram[0] = 1 # Give count 1 to degree 0
    return deg_histogram

In [None]:
# Driver: loads the train/test CSVs, runs 5-fold CV plus ensembled test
# evaluation for every entry in `model_configs`, then prints a summary.
if __name__ == "__main__":
    # --- Configuration ---
    DISTANCE_THRESHOLD = 8.0 # Distance threshold (in Å) for graph construction

    # Each dict specifies the GNN type and whether it should use edge features.
    # Optional hyperparameters (e.g. 'layers', 'hidden_dim', 'heads') override
    # the defaults applied in train_and_evaluate_model.
    model_configs = [
        {'gnn_type': 'gcn',   'use_edge_features': False},
        {'gnn_type': 'pna',   'use_edge_features': False},
        {'gnn_type': 'gatv2', 'use_edge_features': False}, # GATv2 without edge features
        {'gnn_type': 'gatv2', 'use_edge_features': True},  # GATv2 with edge features
        {'gnn_type': 'gine',  'use_edge_features': True},  # GINE requires edge features
        # --- Add more configurations below if needed ---
        # Example: Tune GCN layers
        # {'gnn_type': 'gcn', 'use_edge_features': False, 'layers': 4, 'hidden_dim': 256},
        # Example: Tune GATv2 heads
        # {'gnn_type': 'gatv2', 'use_edge_features': True, 'heads': 8},
    ]

    # --- Data Loading ---
    # On failure raise SystemExit(1) rather than calling exit(): the latter is
    # an interactive helper that shuts down a Jupyter kernel and reports
    # success (status 0) when run as a script.
    try:
        # Update paths if necessary
        train_df = pd.read_csv("../../../../data/train/structure/processed_features_train.csv")
        test_df = pd.read_csv("../../../../data/test/structure/processed_features_test.csv")
        print("Data loaded successfully.")
        print("Train class distribution (%):", (train_df['label'].value_counts(normalize=True) * 100).round(2).to_dict())
        print("Test class distribution (%):", (test_df['label'].value_counts(normalize=True) * 100).round(2).to_dict())
    except FileNotFoundError:
        print("ERROR: Training or test data file not found. Please check the paths.")
        raise SystemExit(1) from None
    except Exception as e:
        print(f"Error loading data: {e}")
        traceback.print_exc()
        raise SystemExit(1) from None

    # --- Store Results ---
    all_results = {} # config_name -> {'config', 'fold_metrics', 'test_metrics'}

    # --- Loop through configurations and run training/evaluation ---
    for config in model_configs:
        config_name = f"{config['gnn_type']}_{'EF' if config['use_edge_features'] else 'noEF'}"

        # Runs the full CV loop and prints per-fold / ensembled results
        fold_metrics, test_metrics = train_and_evaluate_model(
            train_df,
            test_df,
            model_config=config,
            distance_threshold=DISTANCE_THRESHOLD
        )

        all_results[config_name] = {
            'config': config,
            'fold_metrics': fold_metrics,
            'test_metrics': test_metrics
        }
        print(f"\nFinished evaluation for: {config_name}")
        print("-" * 60)

    # --- Final Summary ---
    print("\n\n" + "="*25 + " Overall Run Summary " + "="*25)
    for config_name, results in all_results.items():
        print(f"\n--- Results for Configuration: {config_name} ---")
        test_metrics = results.get('test_metrics')
        fold_metrics = results.get('fold_metrics')

        if test_metrics:
            print("  Final Test Set Performance:")
            print(f"    Acc: {test_metrics.get('accuracy', 0):.4f}, "
                  f"BAcc: {test_metrics.get('balanced_acc', 0):.4f}, "
                  f"MCC: {test_metrics.get('mcc', 0):.4f}, "
                  f"Sens: {test_metrics.get('sensitivity', 0):.4f}, "
                  f"Spec: {test_metrics.get('specificity', 0):.4f}")
        elif fold_metrics:
            # Fall back to the mean validation balanced accuracy
            baccs = [m.get('balanced_acc', float('nan')) for m in fold_metrics]
            valid_baccs = [b for b in baccs if not np.isnan(b)]
            if valid_baccs:
                avg_val_bacc = np.mean(valid_baccs)
                std_val_bacc = np.std(valid_baccs)
                print(f"  Avg Validation Balanced Acc: {avg_val_bacc:.4f} +/- {std_val_bacc:.4f} (Test metrics unavailable)")
            else:
                print("  No valid validation metrics recorded.")
        else:
            print("  No results recorded for this configuration.")

    print("\n" + "="*60)
    print("All configurations evaluated.")

Data loaded successfully.
Train class distribution (%): {1: 51.87, 0: 48.13}
Test class distribution (%): {0: 91.23, 1: 8.77}

Preparing data (Edge Features: False, SS Features: True, Threshold: 8.0Å)...
Created 8853 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge features not included in Data objects.
Preparing data (Edge Features: False, SS Features: True, Threshold: 8.0Å)...
Created 2737 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge features not included in Data objects.

--- Fold 1/5 ---




  Early stopping triggered at epoch 19
  Fold 1 Final Validation Metrics:
  Loss: 0.6354
  Accuracy: 0.6471, Balanced Acc: 0.6533, MCC: 0.3228
  Sensitivity: 0.4897, Specificity: 0.8169
  Confusion Matrix:
[[696 156]
 [469 450]]





--- Fold 2/5 ---




  Early stopping triggered at epoch 22
  Fold 2 Final Validation Metrics:
  Loss: 0.5847
  Accuracy: 0.6804, Balanced Acc: 0.6749, MCC: 0.3671
  Sensitivity: 0.8215, Specificity: 0.5282
  Confusion Matrix:
[[450 402]
 [164 755]]





--- Fold 3/5 ---




  Early stopping triggered at epoch 16
  Fold 3 Final Validation Metrics:
  Loss: 0.6704
  Accuracy: 0.6070, Balanced Acc: 0.6164, MCC: 0.2697
  Sensitivity: 0.3595, Specificity: 0.8734
  Confusion Matrix:
[[745 108]
 [588 330]]





--- Fold 4/5 ---




  Early stopping triggered at epoch 29
  Fold 4 Final Validation Metrics:
  Loss: 0.6059
  Accuracy: 0.6633, Balanced Acc: 0.6635, MCC: 0.3268
  Sensitivity: 0.6569, Specificity: 0.6702
  Confusion Matrix:
[[571 281]
 [315 603]]





--- Fold 5/5 ---




  Early stopping triggered at epoch 21
  Fold 5 Final Validation Metrics:
  Loss: 0.6055
  Accuracy: 0.6627, Balanced Acc: 0.6638, MCC: 0.3277
  Sensitivity: 0.6351, Specificity: 0.6925
  Confusion Matrix:
[[590 262]
 [335 583]]





--- Cross-validation Summary for GCN (Edge Feats: False) ---
  Avg Val accuracy       : 0.6521 ± 0.0249
  Avg Val balanced_acc   : 0.6544 ± 0.0202
  Avg Val mcc            : 0.3228 ± 0.0311
  Avg Val sensitivity    : 0.5925 ± 0.1570
  Avg Val specificity    : 0.7162 ± 0.1207
  Avg Val loss           : 0.6204 ± 0.0298

--- Final Test Set Performance (Ensembled from CV Folds) ---
  Loss: nan
  Accuracy: 0.7256, Balanced Acc: 0.6481, MCC: 0.1858
  Sensitivity: 0.5542, Specificity: 0.7421
  Confusion Matrix:
[[1853  644]
 [ 107  133]]

Finished evaluation for: gcn_noEF
------------------------------------------------------------

Preparing data (Edge Features: False, SS Features: True, Threshold: 8.0Å)...
Created 8853 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge features not included in Data objects.
Preparing data (Edge Features: False, SS Features: True, Threshold: 8.0Å)...
Created 2737 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge features not included in Data 



  Early stopping triggered at epoch 22
  Fold 1 Final Validation Metrics:
  Loss: 0.7152
  Accuracy: 0.6358, Balanced Acc: 0.6411, MCC: 0.2927
  Sensitivity: 0.5016, Specificity: 0.7805
  Confusion Matrix:
[[665 187]
 [458 461]]





--- Fold 2/5 ---




  Early stopping triggered at epoch 19
  Fold 2 Final Validation Metrics:
  Loss: 0.6135
  Accuracy: 0.6844, Balanced Acc: 0.6818, MCC: 0.3675
  Sensitivity: 0.7497, Specificity: 0.6138
  Confusion Matrix:
[[523 329]
 [230 689]]





--- Fold 3/5 ---




  Early stopping triggered at epoch 19
  Fold 3 Final Validation Metrics:
  Loss: 0.7054
  Accuracy: 0.6426, Balanced Acc: 0.6490, MCC: 0.3166
  Sensitivity: 0.4739, Specificity: 0.8242
  Confusion Matrix:
[[703 150]
 [483 435]]





--- Fold 4/5 ---




  Early stopping triggered at epoch 17
  Fold 4 Final Validation Metrics:
  Loss: 0.7100
  Accuracy: 0.6175, Balanced Acc: 0.6242, MCC: 0.2647
  Sensitivity: 0.4455, Specificity: 0.8028
  Confusion Matrix:
[[684 168]
 [509 409]]





--- Fold 5/5 ---




  Early stopping triggered at epoch 24
  Fold 5 Final Validation Metrics:
  Loss: 0.6482
  Accuracy: 0.6593, Balanced Acc: 0.6579, MCC: 0.3168
  Sensitivity: 0.6961, Specificity: 0.6197
  Confusion Matrix:
[[528 324]
 [279 639]]





--- Cross-validation Summary for PNA (Edge Feats: False) ---
  Avg Val accuracy       : 0.6479 ± 0.0226
  Avg Val balanced_acc   : 0.6508 ± 0.0191
  Avg Val mcc            : 0.3117 ± 0.0338
  Avg Val sensitivity    : 0.5734 ± 0.1245
  Avg Val specificity    : 0.7282 ± 0.0920
  Avg Val loss           : 0.6785 ± 0.0405

--- Final Test Set Performance (Ensembled from CV Folds) ---
  Loss: nan
  Accuracy: 0.7362, Balanced Acc: 0.6426, MCC: 0.1820
  Sensitivity: 0.5292, Specificity: 0.7561
  Confusion Matrix:
[[1888  609]
 [ 113  127]]

Finished evaluation for: pna_noEF
------------------------------------------------------------

Preparing data (Edge Features: False, SS Features: True, Threshold: 8.0Å)...
Created 8853 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge features not included in Data objects.
Preparing data (Edge Features: False, SS Features: True, Threshold: 8.0Å)...
Created 2737 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge features not included in Data 



  Early stopping triggered at epoch 35
  Fold 1 Final Validation Metrics:
  Loss: 0.6260
  Accuracy: 0.6573, Balanced Acc: 0.6488, MCC: 0.3344
  Sensitivity: 0.8716, Specificity: 0.4261
  Confusion Matrix:
[[363 489]
 [118 801]]





--- Fold 2/5 ---




  Fold 2 Final Validation Metrics:
  Loss: 0.5691
  Accuracy: 0.7058, Balanced Acc: 0.7000, MCC: 0.4221
  Sensitivity: 0.8531, Specificity: 0.5469
  Confusion Matrix:
[[466 386]
 [135 784]]





--- Fold 3/5 ---




  Early stopping triggered at epoch 43
  Fold 3 Final Validation Metrics:
  Loss: 0.5999
  Accuracy: 0.6725, Balanced Acc: 0.6694, MCC: 0.3443
  Sensitivity: 0.7538, Specificity: 0.5850
  Confusion Matrix:
[[499 354]
 [226 692]]





--- Fold 4/5 ---




  Early stopping triggered at epoch 33
  Fold 4 Final Validation Metrics:
  Loss: 0.6246
  Accuracy: 0.6593, Balanced Acc: 0.6579, MCC: 0.3168
  Sensitivity: 0.6950, Specificity: 0.6209
  Confusion Matrix:
[[529 323]
 [280 638]]





--- Fold 5/5 ---




  Early stopping triggered at epoch 36
  Fold 5 Final Validation Metrics:
  Loss: 0.6305
  Accuracy: 0.6565, Balanced Acc: 0.6460, MCC: 0.3561
  Sensitivity: 0.9270, Specificity: 0.3650
  Confusion Matrix:
[[311 541]
 [ 67 851]]





--- Cross-validation Summary for GATV2 (Edge Feats: False) ---
  Avg Val accuracy       : 0.6703 ± 0.0187
  Avg Val balanced_acc   : 0.6644 ± 0.0196
  Avg Val mcc            : 0.3547 ± 0.0361
  Avg Val sensitivity    : 0.8201 ± 0.0839
  Avg Val specificity    : 0.5088 ± 0.0973
  Avg Val loss           : 0.6100 ± 0.0231

--- Final Test Set Performance (Ensembled from CV Folds) ---
  Loss: nan
  Accuracy: 0.5532, Balanced Acc: 0.6685, MCC: 0.1906
  Sensitivity: 0.8083, Specificity: 0.5286
  Confusion Matrix:
[[1320 1177]
 [  46  194]]

Finished evaluation for: gatv2_noEF
------------------------------------------------------------

Preparing data (Edge Features: True, SS Features: True, Threshold: 8.0Å)...
Created 8853 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge feature dimension: 17
Preparing data (Edge Features: True, SS Features: True, Threshold: 8.0Å)...
Created 2737 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge feature dimension: 17

--- Fold 1/5 ---




  Fold 1 Final Validation Metrics:
  Loss: 0.6103
  Accuracy: 0.6725, Balanced Acc: 0.6673, MCC: 0.3490
  Sensitivity: 0.8041, Specificity: 0.5305
  Confusion Matrix:
[[452 400]
 [180 739]]





--- Fold 2/5 ---




  Fold 2 Final Validation Metrics:
  Loss: 0.5925
  Accuracy: 0.6917, Balanced Acc: 0.6932, MCC: 0.3869
  Sensitivity: 0.6540, Specificity: 0.7324
  Confusion Matrix:
[[624 228]
 [318 601]]





--- Fold 3/5 ---




  Fold 3 Final Validation Metrics:
  Loss: 0.6148
  Accuracy: 0.6584, Balanced Acc: 0.6534, MCC: 0.3196
  Sensitivity: 0.7887, Specificity: 0.5182
  Confusion Matrix:
[[442 411]
 [194 724]]





--- Fold 4/5 ---




  Fold 4 Final Validation Metrics:
  Loss: 0.6123
  Accuracy: 0.6678, Balanced Acc: 0.6699, MCC: 0.3413
  Sensitivity: 0.6133, Specificity: 0.7265
  Confusion Matrix:
[[619 233]
 [355 563]]





--- Fold 5/5 ---




  Fold 5 Final Validation Metrics:
  Loss: 0.5973
  Accuracy: 0.6751, Balanced Acc: 0.6746, MCC: 0.3492
  Sensitivity: 0.6895, Specificity: 0.6596
  Confusion Matrix:
[[562 290]
 [285 633]]





--- Cross-validation Summary for GATV2 (Edge Feats: True) ---
  Avg Val accuracy       : 0.6731 ± 0.0109
  Avg Val balanced_acc   : 0.6717 ± 0.0129
  Avg Val mcc            : 0.3492 ± 0.0217
  Avg Val sensitivity    : 0.7099 ± 0.0748
  Avg Val specificity    : 0.6334 ± 0.0928
  Avg Val loss           : 0.6054 ± 0.0089

--- Final Test Set Performance (Ensembled from CV Folds) ---
  Loss: nan
  Accuracy: 0.6803, Balanced Acc: 0.6704, MCC: 0.2024
  Sensitivity: 0.6583, Specificity: 0.6824
  Confusion Matrix:
[[1704  793]
 [  82  158]]

Finished evaluation for: gatv2_EF
------------------------------------------------------------

Preparing data (Edge Features: True, SS Features: True, Threshold: 8.0Å)...
Created 8853 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge feature dimension: 17
Preparing data (Edge Features: True, SS Features: True, Threshold: 8.0Å)...
Created 2737 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge feature dimension: 17

--- Fold 1/5 ---




  Early stopping triggered at epoch 14
  Fold 1 Final Validation Metrics:
  Loss: 0.6930
  Accuracy: 0.5014, Balanced Acc: 0.5171, MCC: 0.0610
  Sensitivity: 0.1023, Specificity: 0.9319
  Confusion Matrix:
[[794  58]
 [825  94]]





--- Fold 2/5 ---




  Fold 2 Final Validation Metrics:
  Loss: 0.6547
  Accuracy: 0.6076, Balanced Acc: 0.6137, MCC: 0.2398
  Sensitivity: 0.4505, Specificity: 0.7770
  Confusion Matrix:
[[662 190]
 [505 414]]





--- Fold 3/5 ---




  Fold 3 Final Validation Metrics:
  Loss: 0.6685
  Accuracy: 0.5906, Balanced Acc: 0.5814, MCC: 0.1891
  Sensitivity: 0.8333, Specificity: 0.3294
  Confusion Matrix:
[[281 572]
 [153 765]]





--- Fold 4/5 ---




  Early stopping triggered at epoch 47
  Fold 4 Final Validation Metrics:
  Loss: 0.6573
  Accuracy: 0.6153, Balanced Acc: 0.6077, MCC: 0.2364
  Sensitivity: 0.8105, Specificity: 0.4049
  Confusion Matrix:
[[345 507]
 [174 744]]





--- Fold 5/5 ---




  Early stopping triggered at epoch 32
  Fold 5 Final Validation Metrics:
  Loss: 0.6763
  Accuracy: 0.5621, Balanced Acc: 0.5692, MCC: 0.1488
  Sensitivity: 0.3813, Specificity: 0.7570
  Confusion Matrix:
[[645 207]
 [568 350]]





--- Cross-validation Summary for GINE (Edge Feats: True) ---
  Avg Val accuracy       : 0.5754 ± 0.0412
  Avg Val balanced_acc   : 0.5778 ± 0.0345
  Avg Val mcc            : 0.1750 ± 0.0661
  Avg Val sensitivity    : 0.5156 ± 0.2760
  Avg Val specificity    : 0.6401 ± 0.2321
  Avg Val loss           : 0.6699 ± 0.0139

--- Final Test Set Performance (Ensembled from CV Folds) ---
  Loss: nan
  Accuracy: 0.6712, Balanced Acc: 0.5995, MCC: 0.1196
  Sensitivity: 0.5125, Specificity: 0.6864
  Confusion Matrix:
[[1714  783]
 [ 117  123]]

Finished evaluation for: gine_EF
------------------------------------------------------------



--- Results for Configuration: gcn_noEF ---
  Final Test Set Performance:
    Acc: 0.7256, BAcc: 0.6481, MCC: 0.1858, Sens: 0.5542, Spec: 0.7421

--- Results for Configuration: pna_noEF ---
  Final Test Set Performance:
    Acc: 0.7362, BAcc: 0.6426, MCC: 0.1820, Sens: 0.5292, Spec: 0.7561

--- Results for Configuration: gatv2_noEF ---
  Final Test Set Performan