In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# PyG imports
from torch_geometric.nn import GCNConv, GATv2Conv, PNAConv, GINEConv
from torch_geometric.nn import global_mean_pool, global_max_pool
from torch_geometric.data import Data, DataLoader, Batch # Explicitly import Batch
from torch_geometric.utils import degree # Keep degree for PNA

import numpy as np
import pandas as pd
import traceback # For detailed error reporting
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score, balanced_accuracy_score
from sklearn.model_selection import StratifiedKFold
# from sklearn.preprocessing import RobustScaler # Not used directly in this version
# import matplotlib.pyplot as plt # Keep if visualizing later
import random
import os
import time # Optional: for timing execution
from typing import Optional, Tuple, Dict, List # Type hints for clarity

# --- Reproducibility Setup ---
# A single seed is handed to every random number generator used in this file:
# PyTorch, NumPy and Python's built-in `random` module.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    # Seed the generators of all visible GPUs as well.
    torch.cuda.manual_seed_all(SEED)
    # For stricter (but potentially slower) determinism, additionally enable:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
print(f"Reproducibility seeds set with SEED={SEED}")

# --- Device Configuration ---
# Use the GPU whenever CUDA is usable; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the name of the first CUDA device for the run log.
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")

# --- Constants ---
# The 20 standard amino acids, in the order used for one-hot encoding.
VALID_AA = 'ARNDCQEGHILKMFPSTWYV'
# Same alphabet followed by '-' (the padding character) as the last symbol.
AMINO_ACIDS = VALID_AA + '-'
# Lookup from residue character to its integer position in AMINO_ACIDS.
AA_TO_INT = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS))))

# Width of a single ProtT5 embedding vector.
PROT_T5_DIM = 1024
print(f"ProtT5 embedding dimension set to: {PROT_T5_DIM}")

# Width of the rich edge feature vector (only relevant when edge features are used).
EXPECTED_EDGE_FEATURE_DIM = 17
print(f"Expected rich edge feature dimension set to: {EXPECTED_EDGE_FEATURE_DIM}")


# Sequence window length and the 0-based position of the central K inside it
# (both are used by the data preparation validation).
EXPECTED_SEQ_LEN = 33
CENTRAL_K_POS_ABS = 16

# --- ProtT5 Data Loading and Alignment Functions ---

def load_prot_t5_data(pos_file: str, neg_file: str) -> Tuple[Optional[Dict], Optional[Dict]]:
    """
    Read ProtT5 embeddings for positive and negative samples from two text files.

    A usable line has the form ``entry,pos,v1,v2,...`` where ``pos`` parses as
    an int and exactly PROT_T5_DIM float values follow. Lines with fewer than
    three comma-separated fields, unparsable numbers, or a different number of
    values are skipped and counted; they never abort the load.

    Args:
        pos_file (str): Path to the file holding the positive-sample embeddings.
        neg_file (str): Path to the file holding the negative-sample embeddings.

    Returns:
        tuple: (pos_dict, neg_dict). Each dictionary maps an (entry, pos) tuple
               to the embedding as a list of floats. (None, None) is returned
               when either file is missing or cannot be read.
    """
    t_start = time.time()
    print("Loading ProtT5 data...")
    print(f"  Positive file: {pos_file}")
    print(f"  Negative file: {neg_file}")

    def _parse_emb_file(filepath):
        """Parse one file into (records, n_loaded, n_skipped); records is None on I/O failure."""
        records = []
        n_skipped = 0
        try:
            with open(filepath, 'r') as handle:
                for line_no, raw_line in enumerate(handle, start=1):
                    try:
                        fields = raw_line.strip().split(',')
                        if len(fields) >= 3:
                            entry = fields[0]
                            pos = int(fields[1])
                            vector = [float(token) for token in fields[2:]]
                            if len(vector) == PROT_T5_DIM:
                                records.append((entry, pos, vector))
                                continue
                        # Reached for lines with too few fields or a wrong vector length.
                        n_skipped += 1
                    except ValueError as e:
                        print(f"Error converting line {line_no} in {filepath}: {e}. Skipping.")
                        n_skipped += 1
                    except Exception as e_inner:
                        print(f"Error processing line {line_no} in {filepath}: {e_inner}. Skipping.")
                        n_skipped += 1
        except FileNotFoundError:
            print(f"Error: ProtT5 file not found: {filepath}")
            return None, 0, 0
        except Exception as e_outer:
            print(f"Error reading {filepath}: {e_outer}")
            return None, 0, 0
        return records, len(records), n_skipped

    # Both files are always parsed, so problems in either one get reported.
    pos_data, loaded_pos, skipped_pos = _parse_emb_file(pos_file)
    neg_data, loaded_neg, skipped_neg = _parse_emb_file(neg_file)

    if pos_data is None or neg_data is None:
        return None, None

    # Later records overwrite earlier ones that share the same (entry, pos) key.
    pos_dict = {}
    for entry, pos, emb in pos_data:
        pos_dict[(entry, pos)] = emb
    neg_dict = {}
    for entry, pos, emb in neg_data:
        neg_dict[(entry, pos)] = emb

    elapsed = time.time() - t_start
    print(f"Loaded {loaded_pos} positive (skipped {skipped_pos}) and {loaded_neg} negative (skipped {skipped_neg}) ProtT5 embeddings.")
    print(f"ProtT5 Loading finished in {elapsed:.2f} seconds.")

    return pos_dict, neg_dict


def prepare_aligned_data(seq_struct_df: pd.DataFrame, pos_dict: Optional[Dict], neg_dict: Optional[Dict],
                         embedding_dim: Optional[int] = None) -> Tuple[np.ndarray, pd.DataFrame]:
    """
    Align ProtT5 embeddings with the main sequence/structure DataFrame.

    Keeps only the rows for which an embedding of the expected length exists.
    The embedding is looked up by (str(entry), int(pos)) in `pos_dict` when the
    row's label is 1 and in `neg_dict` for any other label.

    Rows are selected by position rather than by index label, so a DataFrame
    whose index contains repeated labels (for example after concatenating
    tables without resetting the index) is aligned correctly. The index of the
    returned DataFrame is left exactly as it was in the input.

    Args:
        seq_struct_df (pd.DataFrame): DataFrame containing sequence/structure info.
                                      Must include 'entry', 'pos', and 'label' columns.
        pos_dict (Optional[Dict]): Dictionary mapping (entry, pos) to positive embeddings.
        neg_dict (Optional[Dict]): Dictionary mapping (entry, pos) to negative embeddings.
        embedding_dim (Optional[int]): Required embedding length. Defaults to the
                                       module constant PROT_T5_DIM when None.

    Returns:
        tuple: (X_prot_t5, aligned_df) where X_prot_t5 is a float32 NumPy array of
               shape [num_aligned, embedding_dim] and aligned_df is a copy of the
               matching rows, in the same order. An empty array of shape
               [0, embedding_dim] and an empty DataFrame with the input's columns
               are returned when the inputs are invalid or nothing can be aligned.
    """
    start_time = time.time()
    print("Aligning ProtT5 embeddings with main DataFrame...")
    # Resolved at call time so the module constant acts only as the default.
    emb_dim = PROT_T5_DIM if embedding_dim is None else int(embedding_dim)

    def _empty_result() -> Tuple[np.ndarray, pd.DataFrame]:
        """Return the 'nothing aligned' value shared by every failure path."""
        return np.array([]).reshape(0, emb_dim), pd.DataFrame(columns=seq_struct_df.columns)

    # --- Input Validation ---
    if pos_dict is None or neg_dict is None:
        print("Error: Invalid ProtT5 dictionaries provided. Cannot align.")
        return _empty_result()

    required_cols = ['entry', 'pos', 'label']
    if not all(col in seq_struct_df.columns for col in required_cols):
        print(f"Error: Input DataFrame must contain columns: {required_cols}")
        return _empty_result()

    # --- Alignment Loop ---
    embeddings = []
    aligned_positions = []  # Positional (0-based) row numbers of the rows that are kept
    skipped_count = 0

    # Walking the three columns directly is much cheaper than DataFrame.iterrows().
    rows = zip(seq_struct_df.index, seq_struct_df['entry'], seq_struct_df['pos'], seq_struct_df['label'])
    for row_pos, (idx, entry, pos, label) in enumerate(rows):
        try:
            key = (str(entry), int(pos))
            emb_dict = pos_dict if int(label) == 1 else neg_dict
            emb = emb_dict.get(key)

            if emb is not None and len(emb) == emb_dim:
                embeddings.append(emb)
                aligned_positions.append(row_pos)
            else:
                skipped_count += 1  # Key not found or wrong embedding length

        except (TypeError, ValueError) as e:
            print(f"Warning: Skipping row index {idx} due to type error during key creation (entry='{entry}', pos='{pos}'): {e}")
            skipped_count += 1
        except Exception as e_inner:
            print(f"Warning: Unexpected error processing row index {idx}: {e_inner}")
            skipped_count += 1

    if skipped_count > 0:
        print(f"Alignment: Skipped {skipped_count} rows due to missing/invalid ProtT5 embeddings or data errors.")
    if not aligned_positions:
        print("Error: No data points could be aligned with ProtT5 embeddings.")
        return _empty_result()

    # --- Final Output Creation ---
    X_prot_t5 = np.array(embeddings, dtype=np.float32)
    # .iloc picks exactly one row per stored position. Label-based .loc would
    # return every row sharing a repeated index label and break the 1:1 pairing
    # between embeddings and DataFrame rows.
    aligned_df = seq_struct_df.iloc[aligned_positions].copy()

    end_time = time.time()
    print(f"Alignment complete. Kept {len(aligned_df)} out of {len(seq_struct_df)} original rows.")
    print(f"Alignment finished in {end_time - start_time:.2f} seconds.")

    return X_prot_t5, aligned_df

def create_edge_features(i_orig, j_orig, dist_map, sequence, ss_string, sasa_vals, K_POS=16):
    """
    Build the rich feature vector for the edge between window positions i_orig and j_orig.

    Layout of the returned vector (17 values):
        [0:4]   one-hot distance bin: (0, 4], (4, 8], (8, 12], > 12
                (no bin is set for a distance <= 0)
        [4]     inverse distance (0 when the distance is <= 1e-6)
        [5]     1 if the two positions are sequence neighbours (|i - j| == 1)
        [6]     |i - j| divided by (len(sequence) - 1)
        [7]     1 if either endpoint is K_POS
        [8]     min(|i - K_POS|, |j - K_POS|) divided by the largest possible
                offset from K_POS inside the window
        [9:15]  one-hot secondary-structure pair, order-independent:
                HH, HE, HL, EE, EL, LL (any symbol other than H/E is treated as L)
        [15]    |SASA_i - SASA_j|
        [16]    (SASA_i + SASA_j) / 2

    An all-zero vector is returned when an index lies outside the sequence,
    when either residue is the padding character '-', or when an index lies
    outside `dist_map`.

    Args:
        i_orig, j_orig: 0-based positions of the two residues in the sequence window.
        dist_map: 2-D NumPy array of pairwise distances, indexed by window position.
        sequence: Sequence string of the window ('-' marks padding).
        ss_string: Secondary structure string, one symbol per window position.
        sasa_vals: Per-position SASA values (NaN entries are treated as 0).
        K_POS: 0-based position of the central Lysine in the window.

    Returns:
        NumPy array of shape (EXPECTED_EDGE_FEATURE_DIM,), dtype float32, with
        NaN and infinite values replaced by 0.
    """
    edge_features = np.zeros(EXPECTED_EDGE_FEATURE_DIM, dtype=np.float32)
    seq_len = len(sequence)

    # Basic validity and padding checks
    if not (0 <= i_orig < seq_len and 0 <= j_orig < seq_len):
        return edge_features
    if sequence[i_orig] == '-' or sequence[j_orig] == '-':
        return edge_features

    # The distance map may be smaller than the sequence window
    if not (i_orig < dist_map.shape[0] and j_orig < dist_map.shape[1]):
        return edge_features
    distance = dist_map[i_orig, j_orig]

    feature_idx = 0

    # 1. Distance Features (5 features) - 4 bins + inverse distance
    if distance <= 0:
        pass  # Zero/negative distance: leave every bin unset
    elif distance <= 4.0:
        edge_features[feature_idx] = 1.0
    elif distance <= 8.0:
        edge_features[feature_idx + 1] = 1.0
    elif distance <= 12.0:
        edge_features[feature_idx + 2] = 1.0
    else:
        edge_features[feature_idx + 3] = 1.0
    feature_idx += 4
    edge_features[feature_idx] = (1.0 / distance) if distance > 1e-6 else 0.0
    feature_idx += 1

    # 2. Sequential Features (2 features) - neighbour flag + normalized separation
    seq_dist = abs(i_orig - j_orig)
    edge_features[feature_idx] = float(seq_dist == 1)
    edge_features[feature_idx + 1] = seq_dist / max(1, seq_len - 1)
    feature_idx += 2

    # 3. K-relative Features (2 features) - touches K + normalized distance to K
    edge_features[feature_idx] = float(i_orig == K_POS or j_orig == K_POS)
    max_dist_from_k = max(1, K_POS, seq_len - 1 - K_POS)
    edge_features[feature_idx + 1] = min(abs(i_orig - K_POS), abs(j_orig - K_POS)) / max_dist_from_k
    feature_idx += 2

    # 4. Secondary Structure Interaction (6 features)
    if i_orig < len(ss_string) and j_orig < len(ss_string):
        ss_pairs = ['HH', 'HE', 'HL', 'EE', 'EL', 'LL']
        ss_map = {'H': 'H', 'E': 'E', 'L': 'L', '-': 'L'}  # Map padding to Loop
        ss_i = ss_map.get(ss_string[i_orig], 'L')
        ss_j = ss_map.get(ss_string[j_orig], 'L')
        # Order the two symbols as H < E < L so the pair matches the spelling
        # used in ss_pairs. Plain alphabetical sorting would turn a
        # helix/strand pair into 'EH', which is not in the list, and the 'HE'
        # feature would never be set.
        ss_pair = ''.join(sorted((ss_i, ss_j), key='HEL'.index))
        for offset, pair in enumerate(ss_pairs):
            edge_features[feature_idx + offset] = float(ss_pair == pair)
    feature_idx += 6

    # 5. SASA Interaction (2 features) - absolute difference + average
    if i_orig < len(sasa_vals) and j_orig < len(sasa_vals):
        sasa_i = np.nan_to_num(sasa_vals[i_orig])
        sasa_j = np.nan_to_num(sasa_vals[j_orig])
        edge_features[feature_idx] = abs(sasa_i - sasa_j)
        edge_features[feature_idx + 1] = (sasa_i + sasa_j) / 2.0

    return np.nan_to_num(edge_features, nan=0.0, posinf=0.0, neginf=0.0)

# from torch_geometric.nn import GCNConv, GATv2Conv, PNAConv, GINEConv
# from torch_geometric.utils import degree

class GCNNetwork(nn.Module):
    """
    GNN backbone built from stacked GCNConv layers (edge features are not used).

    Every layer applies GCNConv -> BatchNorm1d -> ReLU -> residual add (from the
    second layer on) -> Dropout. The forward pass returns one `hidden_dim`-sized
    vector per node; any pooling is left to the caller.
    """
    def __init__(self, input_dim: int, hidden_dim: int = 128, dropout: float = 0.4, layers: int = 3):
        """
        Args:
            input_dim: Size of each input node feature vector.
            hidden_dim: Output size of every GCNConv layer.
            dropout: Dropout probability applied at the end of each layer.
            layers: Number of GCNConv layers; must be >= 1.

        Raises:
            ValueError: If `layers` is smaller than 1.
        """
        super().__init__()
        if layers < 1:
            raise ValueError("GCN layers must be >= 1.")
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        current_dim = input_dim
        for _ in range(layers):
            self.convs.append(GCNConv(current_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
            current_dim = hidden_dim  # Every layer after the first maps hidden -> hidden
        self.dropout_layer = nn.Dropout(dropout)
        self.num_layers = layers
        self.output_dim = hidden_dim
        print(f"Initialized GCNNetwork: Layers={layers}, Hidden={hidden_dim}, Dropout={dropout}")

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """
        Run all GCN layers and return the per-node features.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: Graph connectivity passed straight to GCNConv.

        Returns:
            Tensor with one row per node and `hidden_dim` columns. If a layer
            produces zero rows (empty graph) that empty tensor is returned at once.
        """
        for i in range(self.num_layers):
            # The first layer has no residual input (its input width may differ).
            x_input = x if i > 0 else None

            x = self.convs[i](x, edge_index)

            if x.shape[0] == 0:
                return x  # Empty graph: nothing left to normalize

            # BatchNorm1d cannot compute batch statistics from a single row while
            # training, so it is skipped only in that case. In eval mode it uses
            # its running statistics and is applied for any number of rows, so a
            # one-node input is scaled the same way as during training.
            if x.shape[0] > 1 or not self.training:
                x = self.batch_norms[i](x)

            x = F.relu(x)

            # Residual connection, only when the shapes line up
            if x_input is not None and x.shape == x_input.shape:
                x = x + x_input

            x = self.dropout_layer(x)

        return x  # Shape: [num_nodes, hidden_dim]

class GATv2Network(nn.Module):
    """
    GNN backbone built from stacked GATv2Conv layers; edge features are optional.

    Every layer applies GATv2Conv -> BatchNorm1d -> ELU -> residual add (from the
    second layer on) -> Dropout. The forward pass returns one `hidden_dim`-sized
    vector per node; any pooling is left to the caller.
    """
    def __init__(self, input_dim: int, hidden_dim: int = 128, heads: int = 4, dropout: float = 0.4, layers: int = 3, edge_dim: Optional[int] = None):
        """
        Args:
            input_dim: Size of each input node feature vector.
            hidden_dim: Output size of every layer (heads * per-head size).
            heads: Number of attention heads; must divide `hidden_dim`.
            dropout: Dropout probability, used both inside GATv2Conv and at the
                     end of each layer.
            layers: Number of GATv2Conv layers; must be >= 1.
            edge_dim: Edge feature width. When None or not positive, `forward`
                      does not pass edge attributes to the convolutions.

        Raises:
            ValueError: If `layers` < 1 or `hidden_dim` is not divisible by `heads`.
        """
        super().__init__()
        if layers < 1:
            raise ValueError("GATv2 layers must be >= 1.")
        if hidden_dim % heads != 0:
            raise ValueError(f"GATv2 hidden_dim ({hidden_dim}) must be divisible by heads ({heads})")

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        head_dim = hidden_dim // heads
        # Kept so forward() can check the width of incoming edge attributes.
        self.edge_dim = edge_dim
        # Decides whether forward() hands edge_attr to the convolutions.
        self.has_edge_features = edge_dim is not None and edge_dim > 0
        current_dim = input_dim

        for _ in range(layers):
            # concat=True joins the heads, so each layer outputs heads * head_dim = hidden_dim
            self.convs.append(GATv2Conv(current_dim, head_dim, heads=heads, concat=True,
                                        dropout=dropout, edge_dim=edge_dim, add_self_loops=True))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
            current_dim = hidden_dim  # Every layer after the first maps hidden -> hidden

        self.dropout_layer = nn.Dropout(dropout)
        self.num_layers = layers
        self.output_dim = hidden_dim
        print(f"Initialized GATv2Network: Layers={layers}, Hidden={hidden_dim}, Heads={heads}, Dropout={dropout}, EdgeDim={edge_dim}")

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Run all GATv2 layers and return the per-node features.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: Graph connectivity passed straight to GATv2Conv.
            edge_attr: Edge features; required when the network was built with
                       a positive `edge_dim`, ignored otherwise.

        Returns:
            Tensor with one row per node and `hidden_dim` columns. If a layer
            produces zero rows (empty graph) that empty tensor is returned at once.

        Raises:
            ValueError: If edge features are expected but `edge_attr` is None,
                        or a 2-D `edge_attr` has a width other than `edge_dim`.
        """
        if self.has_edge_features:
            if edge_attr is None:
                raise ValueError("GATv2Network expects edge_attr but received None.")
            # Same early, readable failure that GINENetwork gives for a width mismatch.
            if edge_attr.dim() == 2 and edge_attr.shape[1] != self.edge_dim:
                raise ValueError(f"GATv2Network edge_attr dim mismatch: Expected {self.edge_dim}, Got {edge_attr.shape[1]}")

        for i in range(self.num_layers):
            # The first layer has no residual input (its input width may differ).
            x_input = x if i > 0 else None

            if self.has_edge_features:
                x = self.convs[i](x, edge_index, edge_attr=edge_attr)
            else:
                x = self.convs[i](x, edge_index)

            if x.shape[0] == 0:
                return x  # Empty graph: nothing left to normalize

            # BatchNorm1d cannot compute batch statistics from a single row while
            # training, so it is skipped only in that case. In eval mode it uses
            # its running statistics and is applied for any number of rows.
            if x.shape[0] > 1 or not self.training:
                x = self.batch_norms[i](x)

            x = F.elu(x)

            # Residual connection, only when the shapes line up
            if x_input is not None and x.shape == x_input.shape:
                x = x + x_input

            x = self.dropout_layer(x)

        return x  # Shape: [num_nodes, hidden_dim]

class PNANetwork(nn.Module):
    """
    GNN backbone built from stacked PNAConv layers (edge features are not used).

    Every layer applies PNAConv -> BatchNorm1d -> ReLU -> residual add (from the
    second layer on) -> Dropout. The forward pass returns one `hidden_dim`-sized
    vector per node; any pooling is left to the caller.
    """
    def __init__(self, input_dim: int, hidden_dim: int = 128, layers: int = 3, dropout: float = 0.4, deg: Optional[torch.Tensor] = None):
        """
        Args:
            input_dim: Size of each input node feature vector.
            hidden_dim: Output size of every PNAConv layer.
            layers: Number of PNAConv layers; must be >= 1.
            dropout: Dropout probability applied at the end of each layer.
            deg: Degree histogram handed to every PNAConv; required.

        Raises:
            ValueError: If `layers` < 1 or `deg` is None.
        """
        super().__init__()
        if layers < 1:
            raise ValueError("PNA layers must be >= 1.")
        if deg is None:
            raise ValueError("PNANetwork requires the degree histogram 'deg' argument.")

        # PNA aggregation/scaling configuration (tunable)
        aggregators = ['mean', 'min', 'max', 'std']
        scalers = ['identity', 'amplification', 'attenuation']
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        current_dim = input_dim

        for _ in range(layers):
            self.convs.append(PNAConv(current_dim, hidden_dim, aggregators=aggregators,
                                      scalers=scalers, deg=deg, towers=4, post_layers=1))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
            current_dim = hidden_dim  # Every layer after the first maps hidden -> hidden

        self.dropout_layer = nn.Dropout(dropout)
        self.num_layers = layers
        self.output_dim = hidden_dim
        print(f"Initialized PNANetwork: Layers={layers}, Hidden={hidden_dim}, Dropout={dropout}")

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """
        Run all PNA layers and return the per-node features.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: Graph connectivity passed straight to PNAConv.

        Returns:
            Tensor with one row per node and `hidden_dim` columns. If a layer
            produces zero rows (empty graph) that empty tensor is returned at once.
        """
        for i in range(self.num_layers):
            # The first layer has no residual input (its input width may differ).
            x_input = x if i > 0 else None

            x = self.convs[i](x, edge_index)

            if x.shape[0] == 0:
                return x  # Empty graph: nothing left to normalize

            # BatchNorm1d cannot compute batch statistics from a single row while
            # training, so it is skipped only in that case. In eval mode it uses
            # its running statistics and is applied for any number of rows.
            if x.shape[0] > 1 or not self.training:
                x = self.batch_norms[i](x)

            x = F.relu(x)

            # Residual connection, only when the shapes line up
            if x_input is not None and x.shape == x_input.shape:
                x = x + x_input

            x = self.dropout_layer(x)

        return x  # Shape: [num_nodes, hidden_dim]


class GINENetwork(nn.Module):
    """
    GNN backbone built from stacked GINEConv layers (edge features are required).

    Every layer applies GINEConv -> BatchNorm1d -> ReLU -> residual add (from the
    second layer on) -> Dropout. The forward pass returns one `hidden_dim`-sized
    vector per node; any pooling is left to the caller.
    """
    def __init__(self, input_dim: int, hidden_dim: int = 128, edge_dim: int = EXPECTED_EDGE_FEATURE_DIM, dropout: float = 0.4, layers: int = 3):
        """
        Args:
            input_dim: Size of each input node feature vector.
            hidden_dim: Output size of every GINEConv layer.
            edge_dim: Width of the edge feature vectors; must be positive.
            dropout: Dropout probability, used inside each layer's MLP and at
                     the end of each layer.
            layers: Number of GINEConv layers; must be >= 1.

        Raises:
            ValueError: If `layers` < 1 or `edge_dim` is None or not positive.
        """
        super().__init__()
        if layers < 1:
            raise ValueError("GINE layers must be >= 1.")
        if edge_dim is None or edge_dim <= 0:
            raise ValueError("GINENetwork requires a valid positive 'edge_dim'.")

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.edge_dim = edge_dim  # Checked against edge_attr in forward()
        current_dim = input_dim

        def create_mlp(mlp_input_dim, mlp_output_dim):
            """Two-layer MLP (with a 2x wider hidden layer) used as GINEConv's node update network."""
            return nn.Sequential(
                nn.Linear(mlp_input_dim, mlp_output_dim * 2),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(mlp_output_dim * 2, mlp_output_dim)
            )

        for _ in range(layers):
            mlp = create_mlp(current_dim, hidden_dim)
            # train_eps=True makes the epsilon of the GIN update a learnable parameter
            self.convs.append(GINEConv(nn=mlp, edge_dim=self.edge_dim, train_eps=True))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
            current_dim = hidden_dim  # Every layer after the first maps hidden -> hidden

        self.dropout_layer = nn.Dropout(dropout)
        self.num_layers = layers
        self.output_dim = hidden_dim
        print(f"Initialized GINENetwork: Layers={layers}, Hidden={hidden_dim}, Dropout={dropout}, EdgeDim={edge_dim}")

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """
        Run all GINE layers and return the per-node features.

        Args:
            x: Node feature matrix, one row per node.
            edge_index: Graph connectivity passed straight to GINEConv.
            edge_attr: Edge feature matrix whose second dimension must equal `edge_dim`.

        Returns:
            Tensor with one row per node and `hidden_dim` columns. If a layer
            produces zero rows (empty graph) that empty tensor is returned at once.

        Raises:
            ValueError: If `edge_attr` is None or its width differs from `edge_dim`.
        """
        if edge_attr is None:
            raise ValueError("GINENetwork requires edge_attr for forward pass.")
        if edge_attr.shape[1] != self.edge_dim:
            raise ValueError(f"GINENetwork edge_attr dim mismatch: Expected {self.edge_dim}, Got {edge_attr.shape[1]}")

        for i in range(self.num_layers):
            # The first layer has no residual input (its input width may differ).
            x_input = x if i > 0 else None

            x = self.convs[i](x, edge_index, edge_attr=edge_attr)

            if x.shape[0] == 0:
                return x  # Empty graph: nothing left to normalize

            # BatchNorm1d cannot compute batch statistics from a single row while
            # training, so it is skipped only in that case. In eval mode it uses
            # its running statistics and is applied for any number of rows.
            if x.shape[0] > 1 or not self.training:
                x = self.batch_norms[i](x)

            x = F.relu(x)

            # Residual connection, only when the shapes line up
            if x_input is not None and x.shape == x_input.shape:
                x = x + x_input

            x = self.dropout_layer(x)

        return x  # Shape: [num_nodes, hidden_dim]


class HybridModel(nn.Module):
    """
    Hybrid model combining features from a chosen GNN backbone
    (using GNN readout: Central + Mean + Max) and a processed ProtT5 embedding
    via late fusion for binary classification.
    """
    def __init__(self, gnn_type: str, node_feature_dim: int, hidden_dim: int = 128,
                 prot_t5_dim: int = PROT_T5_DIM, # From constants
                 # --- GNN specific args (passed via model_config in train function) ---
                 deg: Optional[torch.Tensor] = None, # Required for PNA
                 edge_dim: Optional[int] = None,     # Used by GATv2/GINE if provided
                 heads: int = 4,                     # Default heads for GATv2
                 # --- Common GNN hyperparams ---
                 layers: int = 3,
                 gnn_dropout: float = 0.4,           # Dropout for GNN layers
                 # --- ProtT5 MLP hyperparams ---
                 pt5_mlp_dropout: float = 0.4,
                 # --- Final Classifier hyperparams ---
                 classifier_dropout: float = 0.5):
        """
        Initializes the Hybrid GNN+ProtT5 Model.

        Args:
            gnn_type (str): Type of GNN backbone ('gcn', 'gatv2', 'pna', 'gine'), case-insensitive.
            node_feature_dim (int): Dimensionality of input node features for GNN.
            hidden_dim (int): Hidden dimension for GNN layers.
            prot_t5_dim (int): Dimensionality of input ProtT5 embeddings. `forward`
                               validates incoming embeddings against this value.
            deg (Optional[torch.Tensor]): Degree histogram tensor, REQUIRED for 'pna'.
            edge_dim (Optional[int]): Edge feature dimension. Used by 'gatv2'/'gine' if provided.
                                     Should match EXPECTED_EDGE_FEATURE_DIM if used.
            heads (int): Number of attention heads for 'gatv2'.
            layers (int): Number of layers in the GNN backbone.
            gnn_dropout (float): Dropout rate within the GNN backbone layers.
            pt5_mlp_dropout (float): Dropout rate within the ProtT5 MLP.
            classifier_dropout (float): Dropout rate in the final classifier MLP.

        Raises:
            ValueError: For an unsupported `gnn_type`, 'pna' without `deg`,
                        'gine' without `edge_dim`, or any argument rejected by
                        the chosen backbone.
        """
        super().__init__()
        self.gnn_type = gnn_type.lower()
        # Remembered so forward() validates against the width this instance was
        # built for, not against the module-level default.
        self.prot_t5_dim = prot_t5_dim
        # Edge features are passed to the backbone only for GNN types that accept them
        self.use_edge_features = edge_dim is not None and self.gnn_type in ['gatv2', 'gine']
        print(f"Initializing HybridModel (GNN={self.gnn_type.upper()} + ProtT5)")
        if self.use_edge_features:
            print(f"  GNN configured to use edge features (dim={edge_dim})")
        else:
            print("  GNN configured NOT to use edge features.")

        # --- GNN Backbone Instantiation ---
        gnn_common_args = {
            'input_dim': node_feature_dim, 'hidden_dim': hidden_dim,
            'dropout': gnn_dropout, 'layers': layers
        }
        try:
            if self.gnn_type == 'gcn':
                self.gnn = GCNNetwork(**gnn_common_args)
            elif self.gnn_type == 'gatv2':
                # edge_dim may be None here; GATv2 then runs without edge features
                self.gnn = GATv2Network(**gnn_common_args, heads=heads, edge_dim=edge_dim)
            elif self.gnn_type == 'pna':
                if deg is None:
                    raise ValueError("PNA requires the 'deg' histogram argument.")
                self.gnn = PNANetwork(**gnn_common_args, deg=deg)
            elif self.gnn_type == 'gine':
                if edge_dim is None:
                    raise ValueError("GINE requires the 'edge_dim' argument.")
                if edge_dim != EXPECTED_EDGE_FEATURE_DIM:
                    print(f"Warning: GINE edge_dim ({edge_dim}) differs from expected ({EXPECTED_EDGE_FEATURE_DIM}).")
                self.gnn = GINENetwork(**gnn_common_args, edge_dim=edge_dim)
            else:
                raise ValueError(f"Unsupported GNN type: {gnn_type}")
        except Exception as e:
            print(f"--- ERROR during GNN Backbone ({self.gnn_type.upper()}) Initialization ---")
            print(e)
            traceback.print_exc()
            raise  # Re-raise the same exception with its original traceback

        # Output width of the instantiated backbone
        gnn_output_dim = self.gnn.output_dim

        # --- ProtT5 Track Initialization ---
        # Small MLP that compresses the ProtT5 embedding before fusion
        pt5_mlp_hidden = 256
        prot_t5_output_dim = 128
        self.prot_t5_mlp = nn.Sequential(
            nn.Linear(prot_t5_dim, pt5_mlp_hidden),
            nn.ReLU(),
            nn.Dropout(pt5_mlp_dropout),
            nn.Linear(pt5_mlp_hidden, prot_t5_output_dim),
            nn.ReLU(),
            nn.Dropout(pt5_mlp_dropout)
        )
        print(f"  ProtT5 MLP initialized: {prot_t5_dim} -> ... -> {prot_t5_output_dim}")

        # --- Combination and Final Classifier ---
        # Input dimension = GNN Central + GNN Mean + GNN Max + Processed ProtT5
        combined_input_dim = 3 * gnn_output_dim + prot_t5_output_dim
        print(f"  Combined input dimension for final classifier: {combined_input_dim}")

        # Final classification layers (MLP head)
        classifier_hidden = 64
        self.fc1 = nn.Linear(combined_input_dim, classifier_hidden)
        self.bn_combine = nn.BatchNorm1d(classifier_hidden)
        self.dropout_combine = nn.Dropout(classifier_dropout)
        self.fc2 = nn.Linear(classifier_hidden, 1)  # Final binary output

        print("HybridModel initialization complete.")

    def forward(self, data: Batch) -> torch.Tensor:
        """
        Forward pass of the HybridModel (GNN + ProtT5).

        Reads `x`, `edge_index`, `batch`, `ptr`, `central_node_idx`,
        `prot_t5_embedding` and, when edge features are enabled, `edge_attr`
        from `data`. `ptr` is rebuilt from `batch` when it is absent.

        Args:
            data: Batched graph object carrying the attributes listed above.
                  `central_node_idx` holds, per graph, the index of the central
                  node relative to that graph's first node.

        Returns:
            Tensor of shape [num_graphs, 1] with sigmoid probabilities. A
            [0, 1] tensor is returned for an empty batch, and an all-zero
            tensor when the readout features cannot be concatenated.

        Raises:
            AttributeError: If `data` has no 'prot_t5_embedding'.
            ValueError: If 'batch', 'ptr' or 'central_node_idx' is unavailable,
                        if the ProtT5 embedding is not of shape
                        [num_graphs, prot_t5_dim], or if required edge
                        attributes are missing or miscounted.
        """
        # --- Input Extraction ---
        x, edge_index = data.x, data.edge_index
        # getattr with a default lets the explicit checks below produce the
        # error message instead of a bare attribute lookup failure.
        edge_attr = getattr(data, 'edge_attr', None) if self.use_edge_features else None
        batch = getattr(data, 'batch', None)  # Maps each node to its graph index in the batch
        central_node_idx = getattr(data, 'central_node_idx', None)
        ptr = getattr(data, 'ptr', None)  # Graph start/end offsets into the node dimension

        if not hasattr(data, 'prot_t5_embedding'):
            raise AttributeError("Batch object missing 'prot_t5_embedding'. Check data preparation.")
        prot_t5_emb = data.prot_t5_embedding

        # Reconstruct ptr from batch if missing
        if ptr is None and batch is not None:
            if batch.numel() > 0:
                counts = torch.bincount(batch)
                ptr = torch.cat([torch.tensor([0], device=batch.device), counts.cumsum(0)])
            else:
                ptr = torch.tensor([0], device=batch.device)

        # Required by the pooling/indexing logic further down
        if batch is None or ptr is None or central_node_idx is None:
            raise ValueError("HybridModel forward requires 'batch', 'ptr', and 'central_node_idx' attributes.")

        batch_size = data.num_graphs
        if batch_size == 0:  # Empty batch: nothing to classify
            return torch.zeros((0, 1), device=x.device if x is not None else device, dtype=torch.float)

        # Validate against the width this model was built with (self.prot_t5_dim),
        # so models created with a non-default prot_t5_dim remain usable.
        if prot_t5_emb.dim() != 2 or prot_t5_emb.shape[0] != batch_size or prot_t5_emb.shape[1] != self.prot_t5_dim:
            raise ValueError(f"ProtT5 embedding shape mismatch! Expected ({batch_size}, {self.prot_t5_dim}), Got {prot_t5_emb.shape}")

        # --- ProtT5 Track ---
        prot_t5_features = self.prot_t5_mlp(prot_t5_emb)  # Shape: [batch_size, prot_t5_output_dim]

        # --- GNN Track ---
        if self.use_edge_features:
            if edge_attr is None:
                raise ValueError(f"Model '{self.gnn_type}' requires edge_attr but not found in data.")
            # One attribute row per edge (only meaningful when edges exist)
            if edge_index.shape[1] > 0 and edge_attr.shape[0] != edge_index.shape[1]:
                raise ValueError(f"Edge attribute count ({edge_attr.shape[0]}) != edge index count ({edge_index.shape[1]})")
            gnn_node_features = self.gnn(x, edge_index, edge_attr=edge_attr)
        else:
            # GCN, PNA, or GATv2 without edge_dim
            gnn_node_features = self.gnn(x, edge_index)

        # --- GNN Readout (Central Node + Global Pooling) ---
        # Start from zeros so a failed or skipped readout still yields valid shapes
        gnn_output_dim = self.gnn.output_dim
        central_node_features = torch.zeros((batch_size, gnn_output_dim), device=x.device, dtype=x.dtype)
        global_avg_features = torch.zeros((batch_size, gnn_output_dim), device=x.device, dtype=x.dtype)
        global_max_features = torch.zeros((batch_size, gnn_output_dim), device=x.device, dtype=x.dtype)

        num_nodes = gnn_node_features.shape[0] if gnn_node_features is not None else 0
        nodes_exist = num_nodes > 0
        # Index tensors must be consistent with the batch before they are used
        indices_valid = ptr.numel() > 1 and central_node_idx.numel() == batch_size and \
                        batch.numel() > 0 and batch.max() < batch_size

        if nodes_exist and indices_valid:
            try:
                # Central node of each graph: graph start offset + relative index
                graph_starts = ptr[:-1]
                absolute_central_node_indices = graph_starts + central_node_idx
                if absolute_central_node_indices.max() < num_nodes and absolute_central_node_indices.min() >= 0:
                    central_node_features = gnn_node_features[absolute_central_node_indices]

                # Global pooling requires one batch entry per node
                if batch.numel() == num_nodes:
                    global_avg_features = global_mean_pool(gnn_node_features, batch)
                    global_max_features = global_max_pool(gnn_node_features, batch)

            except Exception as e_readout:
                # Best effort: keep whatever readout tensors are currently set
                print(f"Error during GNN readout: {e_readout}. Using zero features.")

        # --- Combine Features ---
        # Concatenate: GNN Central + GNN Mean + GNN Max + Processed ProtT5
        feature_list = [central_node_features, global_avg_features, global_max_features, prot_t5_features]
        for i, feat in enumerate(feature_list):
            if feat.shape[0] != batch_size:
                print(f"Error: Feature {i} has wrong batch size before concat. Expected {batch_size}, Got {feat.shape}")
                print("FATAL: Cannot combine features due to shape mismatch. Returning zeros.")
                return torch.zeros((batch_size, 1), device=x.device, dtype=torch.float)

        combined = torch.cat(feature_list, dim=1)

        # --- Final Classification MLP ---
        x_out = self.fc1(combined)
        # BatchNorm1d cannot compute batch statistics from a single row while
        # training, so it is skipped only in that case. In eval mode it uses its
        # running statistics, so single-sample inference gives the same result
        # as the same sample inside a larger batch.
        if x_out.shape[0] > 1 or not self.training:
            x_out = self.bn_combine(x_out)

        x_out = F.relu(x_out)
        x_out = self.dropout_combine(x_out)
        x_out = self.fc2(x_out)  # Shape: [batch_size, 1]

        # Sigmoid turns the logit into a binary classification probability
        return torch.sigmoid(x_out)


def prepare_graph_data(
    df_aligned: pd.DataFrame,
    prot_t5_embeddings: np.ndarray,
    distance_threshold: float = 8.0,
    use_ss_node_feature: bool = True, # Flag to include SS one-hot in node features
    include_edge_features: bool = False # Flag to create/include 17-dim edge_attr
    ) -> Tuple[List[Data], List[int]]:
    """
    Prepare graph Data objects for PyTorch Geometric using the ALIGNED DataFrame
    and corresponding ProtT5 embeddings, tailored for the GNN+ProtT5 Hybrid Model.

    One graph is built per usable row. Nodes are the sequence positions whose
    residue is in VALID_AA (padding characters are dropped). Node features are,
    in column order:

      1. amino-acid one-hot over VALID_AA        (len(VALID_AA) columns)
      2. central-K indicator                     (1 column)
      3. sin/cos of phi, psi and omega (degrees) (6 columns)

    The SASA, secondary-structure and pLDDT columns are parsed, but their node
    features are currently disabled (see the commented-out section below), so
    `use_ss_node_feature` has no effect on the produced node features at the
    moment. `ss` and `sasa` are still forwarded to `create_edge_features`.

    Edges connect every ordered pair of valid nodes with
    0 < distance < distance_threshold. If no such pair exists, consecutive
    valid nodes are chained in both directions as a fallback.

    Each Data object carries: x, edge_index, y (float, shape [1]),
    central_node_idx (shape [1]), prot_t5_embedding (shape [1, PROT_T5_DIM]) and,
    when requested and well-formed, edge_attr. The sequence string is NOT stored.

    Args:
        df_aligned (pd.DataFrame): DataFrame ALIGNED with prot_t5_embeddings.
                                   Requires 'sequence', 'label' and 'distance_map';
                                   the other structural columns are optional.
        prot_t5_embeddings (np.ndarray): NumPy array of ProtT5 embeddings, ordered
                                          corresponding to df_aligned rows.
                                          Shape: [num_aligned, PROT_T5_DIM].
        distance_threshold (float): Max distance (exclusive) for graph edges.
        use_ss_node_feature (bool): Reserved for the SS one-hot node feature;
                                    only echoed in the log line while that
                                    feature is disabled.
        include_edge_features (bool): Whether to generate and include edge_attr.

    Returns:
        tuple: (graph_list, labels) - the Data objects and their integer labels,
               in matching order. Rows that fail any step are skipped and counted
               in the printed summary.

    Raises:
        ValueError: If the embedding array is not [len(df_aligned), PROT_T5_DIM].
    """
    start_time = time.time()
    print(f"Preparing PyG graph data (GNN+ProtT5 format: Edge Feats={include_edge_features}, SS Node={use_ss_node_feature}) for {len(df_aligned)} aligned samples...")
    graph_list = []
    labels = []
    skipped_count = 0
    parse_errors, concat_errors, edge_errors, validation_errors = 0, 0, 0, 0

    # --- Pre-check Alignment ---
    # The ndim check makes a 1-D array fail with the intended ValueError
    # instead of an IndexError on shape[1].
    if (prot_t5_embeddings.ndim != 2
            or len(df_aligned) != prot_t5_embeddings.shape[0]
            or prot_t5_embeddings.shape[1] != PROT_T5_DIM):
        raise ValueError("Critical Setup Error: Mismatch between aligned df and ProtT5 embeddings dimensions.")

    # --- Define features to parse from the DataFrame ---
    # Ensure all columns needed for node features AND edge features (if used) are listed
    feature_names_to_parse = [
        'phi', 'psi', 'omega', 'tau', 'chi1', 'chi2', 'chi3', 'chi4', # All angles
        'sasa', 'ss', 'plDDT', # SASA, SS, pLDDT
        'distance_map', # Essential for edges
        'sequence', # Essential for AA features, edge features, validation
    ]
    print(f"  Will attempt to parse features: {feature_names_to_parse}")

    # Column index of each residue letter in the one-hot block (built once).
    aa_to_column = {aa: col for col, aa in enumerate(VALID_AA)}

    # --- Iterate through aligned data ---
    # Using enumerate to get index 'i' for accessing prot_t5_embeddings[i]
    for i, (original_idx, row) in enumerate(df_aligned.iterrows()):
        try:
            # --- Basic Data Retrieval & Validation ---
            sequence = row.get('sequence') # Needed for AA features, padding checks, edge features
            label = row.get('label')
            if (pd.isna(sequence)
                    or len(sequence) != EXPECTED_SEQ_LEN
                    or sequence[CENTRAL_K_POS_ABS] != 'K'
                    or pd.isna(label)):
                skipped_count += 1
                continue
            label = int(label)

            # Get the corresponding ProtT5 embedding (pre-aligned)
            current_prot_t5_embedding_np = prot_t5_embeddings[i] # Shape (PROT_T5_DIM,)

            # --- Parse Structural Features ---
            parsed_data = {}
            valid_row = True
            for name in feature_names_to_parse:
                if name == 'sequence': # Already have sequence
                    parsed_data[name] = sequence
                    continue
                if name not in row or pd.isna(row[name]):
                    if name == 'distance_map': # Essential
                        valid_row = False
                        break
                    parsed_data[name] = np.nan # Placeholder for optional features
                    continue
                try: # Parse data safely
                    feature_str = str(row[name])
                    if name == 'ss':
                        # Explicit raise instead of assert: asserts vanish under -O.
                        if len(feature_str) != EXPECTED_SEQ_LEN:
                            raise ValueError("SS string length mismatch")
                        parsed_data[name] = feature_str
                    else:
                        # SECURITY NOTE(review): eval() executes arbitrary code from the
                        # DataFrame cell. Only safe for trusted data files; consider
                        # ast.literal_eval / json if the stored format allows it.
                        parsed_arr = np.array(eval(feature_str), dtype=np.float32)
                        if name == 'distance_map':
                            if parsed_arr.size != EXPECTED_SEQ_LEN * EXPECTED_SEQ_LEN:
                                raise ValueError("Distmap size mismatch")
                            parsed_data[name] = parsed_arr.reshape(EXPECTED_SEQ_LEN, EXPECTED_SEQ_LEN)
                        elif parsed_arr.ndim == 1 and len(parsed_arr) != EXPECTED_SEQ_LEN:
                            # Handle 1D length mismatch: truncate or NaN-pad to the window length
                            padded = np.full(EXPECTED_SEQ_LEN, np.nan, dtype=np.float32)
                            n_keep = min(len(parsed_arr), EXPECTED_SEQ_LEN)
                            padded[:n_keep] = parsed_arr[:n_keep]
                            parsed_data[name] = padded
                        else:
                            parsed_data[name] = parsed_arr
                except Exception:
                    # print(f"Parsing error for {name}, row {original_idx}") # Debug
                    valid_row = False
                    parse_errors += 1
                    break
            if not valid_row:
                skipped_count += 1
                continue

            # Ensure essential parsed data exists
            distance_map = parsed_data.get('distance_map')
            if distance_map is None:
                skipped_count += 1
                continue
            # Missing optional features are stored above as np.nan placeholders, so a
            # plain dict.get default would never trigger; fall back by type instead.
            ss_full = parsed_data.get('ss')
            if not isinstance(ss_full, str):
                ss_full = '-' * EXPECTED_SEQ_LEN # Default if missing
            sasa_full = parsed_data.get('sasa')
            if not isinstance(sasa_full, np.ndarray):
                sasa_full = np.full(EXPECTED_SEQ_LEN, np.nan) # Default if missing

            # --- Identify Valid Nodes ---
            valid_pos_indices = [k for k, aa in enumerate(sequence) if aa in VALID_AA]
            if not valid_pos_indices:
                skipped_count += 1
                continue
            num_nodes = len(valid_pos_indices)
            try:
                central_k_new_idx = valid_pos_indices.index(CENTRAL_K_POS_ABS)
            except ValueError: # Central K was padded
                skipped_count += 1
                continue

            # --- Node Feature Extraction ---
            # Collect all node features required by the GNN track
            node_features_list = []
            # 1. AA One-hot (len(VALID_AA) features).
            #    Every valid position holds a VALID_AA residue, so each row gets exactly one 1.
            aa_onehot = np.zeros((num_nodes, len(VALID_AA)), dtype=np.float32)
            aa_columns = [aa_to_column[sequence[pos]] for pos in valid_pos_indices]
            aa_onehot[np.arange(num_nodes), aa_columns] = 1.0
            node_features_list.append(aa_onehot)
            # 2. Central K indicator (1 feature)
            is_central_k = np.zeros((num_nodes, 1), dtype=np.float32)
            is_central_k[central_k_new_idx, 0] = 1.0
            node_features_list.append(is_central_k)
            # 3. Backbone angles (phi, psi, omega) -> sin/cos (3 * 2 = 6 features).
            #    tau and chi1-4 are parsed but not used as node features.
            angle_keys = ['phi', 'psi', 'omega']
            for key in angle_keys:
                angle_values = parsed_data.get(key)
                if isinstance(angle_values, np.ndarray):
                    valid_angles = np.nan_to_num(angle_values[valid_pos_indices], nan=0.0) # Fill NaN angles with 0
                    angle_rad = np.pi * valid_angles / 180.0
                    sin_cos = np.stack([np.sin(angle_rad), np.cos(angle_rad)], axis=-1)
                    node_features_list.append(sin_cos.astype(np.float32))
                else:
                    node_features_list.append(np.zeros((num_nodes, 2), dtype=np.float32)) # Zeros if missing
            # NOTE: the SASA / SS / pLDDT node features below are currently disabled.
            # Re-enabling them changes the node feature dimension seen by the model.
            # 4. SASA (1 feature)
            # if 'sasa' in parsed_data and isinstance(parsed_data['sasa'], np.ndarray):
            #     valid_sasa = np.nan_to_num(parsed_data['sasa'][valid_pos_indices], nan=0.0).reshape(-1, 1)
            #     node_features_list.append(valid_sasa.astype(np.float32))
            # else: node_features_list.append(np.zeros((num_nodes, 1), dtype=np.float32))
            # # 5. SS (Secondary Structure) (3 features if use_ss_node_feature=True)
            # if use_ss_node_feature and 'ss' in parsed_data and isinstance(parsed_data['ss'], str):
            #      ss_string_feat = parsed_data['ss']; valid_ss_chars = [ss_string_feat[k] for k in valid_pos_indices]
            #      ss_onehot = np.zeros((num_nodes, 3), dtype=np.float32); ss_map = {'H': 0, 'E': 1, 'L': 2, '-': 2} # Add other SS codes if present
            #      for k, ss_char in enumerate(valid_ss_chars): ss_onehot[k, ss_map.get(ss_char.upper(), 2)] = 1.0
            #      node_features_list.append(ss_onehot)
            # elif use_ss_node_feature: node_features_list.append(np.zeros((num_nodes, 3), dtype=np.float32))
            # # 6. pLDDT (1 feature, scaled 0-1)
            # if 'plDDT' in parsed_data and isinstance(parsed_data['plDDT'], np.ndarray):
            #     valid_plddt = np.nan_to_num(parsed_data['plDDT'][valid_pos_indices], nan=50.0).reshape(-1, 1) # Fill NaN pLDDT?
            #     node_features_list.append((valid_plddt / 100.0).astype(np.float32)) # Scale 0-1
            # else: node_features_list.append(np.zeros((num_nodes, 1), dtype=np.float32))

            # --- Concatenate Node Features ---
            try:
                if not node_features_list:
                    raise ValueError("Node feature list empty.")
                node_features = np.concatenate(node_features_list, axis=1)
                if np.isnan(node_features).any() or np.isinf(node_features).any():
                    node_features = np.nan_to_num(node_features, nan=0.0, posinf=0.0, neginf=0.0)
            except ValueError:
                skipped_count += 1
                concat_errors += 1
                continue

            # --- Edge Construction & Optional Feature Extraction ---
            edge_features_list = [] # Reset for each graph
            edge_attr_tensor = None
            try:
                valid_distance_map = distance_map[np.ix_(valid_pos_indices, valid_pos_indices)]
                adj = (valid_distance_map < distance_threshold) & (valid_distance_map > 0)
                np.fill_diagonal(adj, False)
                # Indices relative to valid nodes; an empty result yields [].
                edges = np.argwhere(adj).tolist()
                if include_edge_features:
                    for i_valid, j_valid in edges:
                        i_orig = valid_pos_indices[i_valid]
                        j_orig = valid_pos_indices[j_valid]
                        # Pass sequence string parsed earlier
                        edge_feat_vec = create_edge_features(i_orig, j_orig, distance_map, sequence, ss_full, sasa_full)
                        edge_features_list.append(edge_feat_vec)

                # Fallback Sequential Edges
                if not edges and num_nodes > 1:
                    for k_valid in range(num_nodes - 1):
                        edges.extend([[k_valid, k_valid + 1], [k_valid + 1, k_valid]])
                        if include_edge_features:
                            i_orig = valid_pos_indices[k_valid]
                            j_orig = valid_pos_indices[k_valid + 1]
                            feat1 = create_edge_features(i_orig, j_orig, distance_map, sequence, ss_full, sasa_full)
                            feat2 = create_edge_features(j_orig, i_orig, distance_map, sequence, ss_full, sasa_full)
                            edge_features_list.extend([feat1, feat2])

                if not edges: # Skip if still no edges
                    skipped_count += 1
                    edge_errors += 1
                    continue

                # Convert edge features to tensor if needed and generated
                if include_edge_features and edge_features_list:
                    edge_attr_np = np.array(edge_features_list, dtype=np.float32)
                    if edge_attr_np.shape[0] == len(edges) and edge_attr_np.shape[1] == EXPECTED_EDGE_FEATURE_DIM:
                        # Same policy as node features: never let NaN/Inf reach the model.
                        if not np.isfinite(edge_attr_np).all():
                            edge_attr_np = np.nan_to_num(edge_attr_np, nan=0.0, posinf=0.0, neginf=0.0)
                        edge_attr_tensor = torch.tensor(edge_attr_np, dtype=torch.float)
                    # else: shape mismatch -> graph is kept without edge_attr

            except Exception:
                skipped_count += 1
                edge_errors += 1
                continue

            # --- Convert to Tensors ---
            try:
                x_tensor = torch.tensor(node_features, dtype=torch.float)
                edge_index_tensor = torch.tensor(edges, dtype=torch.long).t().contiguous()
                y_tensor = torch.tensor([label], dtype=torch.float)
                central_node_idx_tensor = torch.tensor([central_k_new_idx], dtype=torch.long)
                # ProtT5 tensor - unsqueeze to add batch-like dimension [1, Dim]
                prot_t5_tensor = torch.tensor(current_prot_t5_embedding_np, dtype=torch.float).unsqueeze(0)
                # --- Final Check for ProtT5 Shape ---
                if prot_t5_tensor.shape != (1, PROT_T5_DIM):
                    print(f"Warning: ProtT5 tensor shape incorrect ({prot_t5_tensor.shape}) for row {original_idx}. Skipping.")
                    skipped_count += 1
                    continue
            except Exception:
                skipped_count += 1
                continue

            # --- Create PyG Data object ---
            data_dict = {
                'x': x_tensor,
                'edge_index': edge_index_tensor,
                'y': y_tensor,
                'central_node_idx': central_node_idx_tensor,
                'prot_t5_embedding': prot_t5_tensor # ADD T5 embedding here
                # NO 'sequence' attribute added
            }
            # Conditionally add edge_attr tensor if it was successfully created
            if edge_attr_tensor is not None:
                data_dict['edge_attr'] = edge_attr_tensor

            data = Data(**data_dict)

            # --- Validate Data Object ---
            if not data.validate(raise_on_error=False): # Raise error for debugging if needed
                skipped_count += 1
                validation_errors += 1
                continue

            graph_list.append(data)
            labels.append(label)

        # --- Catch errors for the entire row processing ---
        except Exception as e_outer:
            print(f"--- Critical Error processing aligned row index {i} (original index {original_idx}): {e_outer} ---")
            traceback.print_exc(limit=2)
            skipped_count += 1
            continue

    # --- Final Summary ---
    end_time = time.time()
    print("\nGraph Preparation Summary:")
    print(f"  Created {len(graph_list)} graphs from {len(df_aligned)} aligned samples.")
    print(f"  Skipped {skipped_count} rows (ParseErr={parse_errors}, ConcatErr={concat_errors}, EdgeErr={edge_errors}, ValidErr={validation_errors}, Other={skipped_count-(parse_errors+concat_errors+edge_errors+validation_errors)}).")
    if graph_list:
        print(f"  Example graph node feature dimension: {graph_list[0].num_node_features}")
        edge_feat_dim_str = str(graph_list[0].num_edge_features) if graph_list[0].edge_attr is not None else 'N/A'
        print(f"  Example graph edge feature dimension: {edge_feat_dim_str}")
        print(f"  Example ProtT5 embedding dimension: {graph_list[0].prot_t5_embedding.shape}")
    print(f"Graph preparation finished in {end_time - start_time:.2f} seconds.")
    return graph_list, labels


def train_model(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer,
                device: torch.device, class_weights: Optional[Dict] = None) -> Tuple[float, float]:
    """
    Run one training epoch of the GNN+ProtT5 HybridModel.

    Every batch is moved to `device`, checked for the attributes the model
    needs, pushed through the model, and optimised against a summed binary
    cross-entropy loss (the model output is treated as a probability).
    Batches that are malformed, empty, yield a non-finite loss, or raise are
    skipped and counted; they contribute to neither loss nor accuracy.

    Args:
        model (nn.Module): The PyTorch model (HybridModel) to train.
        loader (DataLoader): Iterable of training batches.
        optimizer (torch.optim.Optimizer): Optimizer stepped once per good batch.
        device (torch.device): Device the batches are moved to.
        class_weights (Optional[Dict]): Maps class index (0, 1) to a per-sample
            loss weight. Ignored for a batch whose labels are not all 0/1 or
            when the lookup fails.

    Returns:
        tuple: (average_loss_per_sample, accuracy) for the epoch; (0.0, 0.0)
               when no sample was processed.
    """
    # Attributes the HybridModel forward pass relies on; edge_attr is optional.
    essential = ('x', 'edge_index', 'prot_t5_embedding', 'y', 'batch', 'central_node_idx')

    model.train() # Set the model to training mode
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    n_skipped = 0

    for batch in loader:
        try:
            batch = batch.to(device)

            # --- Batch Validation ---
            missing = [attr for attr in essential if getattr(batch, attr, None) is None]
            if missing:
                print(f"Warning: Skipping training batch due to missing essential attributes: {missing}")
                n_skipped += 1
                continue
            # No nodes or no graphs can happen with filtering/edge cases.
            if batch.num_nodes == 0 or batch.num_graphs == 0:
                n_skipped += 1
                continue

            # --- Forward Pass ---
            optimizer.zero_grad() # Clear previous gradients
            output = model(batch) # Probabilities (sigmoid already applied by the model)

            # --- Target Preparation ---
            target = batch.y.view(-1, 1).float() # Shape [batch_size, 1]

            if output.shape[0] != target.shape[0] or output.shape[1] != 1:
                print(f"Warning: Shape mismatch output ({output.shape}) vs target ({target.shape}). Skipping batch.")
                n_skipped += 1
                continue
            if target.numel() == 0: # Nothing to learn from
                n_skipped += 1
                continue

            # --- Loss Calculation ---
            weight_tensor = None
            if class_weights is not None:
                try:
                    labels_int = target.long()
                    only_binary = torch.logical_or(labels_int == 0, labels_int == 1).all()
                    if only_binary:
                        # One host transfer for the whole batch, then a dict lookup per label.
                        per_sample = [class_weights[lbl] for lbl in labels_int.view(-1).tolist()]
                        weight_tensor = torch.tensor(per_sample, device=device, dtype=torch.float).view(-1, 1)
                    # Labels outside {0, 1}: fall through unweighted.
                except KeyError as e:
                    print(f"Warning: Invalid target label {e} for class weights.")
                except Exception as e_w:
                    print(f"Warning: Error creating class weights: {e_w}")

            # Summed (not averaged) BCE so the epoch average can be taken per sample.
            current_loss = F.binary_cross_entropy(output, target, weight=weight_tensor, reduction='sum')

            if not torch.isfinite(current_loss):
                print(f"Warning: Non-finite loss ({current_loss.item()}) detected. Skipping batch.")
                n_skipped += 1
                continue

            # --- Backward Pass and Optimization ---
            current_loss.backward()
            # Optional: Gradient Clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # --- Accumulate Metrics ---
            loss_sum += current_loss.item()
            with torch.no_grad():
                hard_pred = (output > 0.5).float()
                n_correct += (hard_pred == target).sum().item()
                n_seen += target.size(0)

        # --- Error Handling for Batch ---
        except Exception as e:
            print(f"--- Error during training batch: {type(e).__name__} - {e} ---")
            traceback.print_exc(limit=1) # Show where error occurred
            n_skipped += 1
            continue # Skip to next batch

    # --- Epoch Summary ---
    if n_skipped > 0:
        print(f"Skipped {n_skipped} batches during training epoch.")

    if n_seen == 0:
        print("Warning: No samples processed in training epoch.")
        return 0.0, 0.0

    # Average loss PER SAMPLE and accuracy for the epoch
    return loss_sum / n_seen, n_correct / n_seen


def evaluate_model(model: nn.Module, loader: DataLoader, device: torch.device) -> dict:
    """
    Score the HybridModel on a validation or test loader.

    Runs the model in eval mode without gradients, accumulates a summed binary
    cross-entropy loss, and thresholds the output probabilities at 0.5 to
    obtain hard predictions. Batches that are malformed, empty, produce a
    non-finite loss, or raise are skipped and do not contribute to any metric.

    Args:
        model (nn.Module): The PyTorch model to evaluate.
        loader (DataLoader): Iterable of evaluation batches.
        device (torch.device): Device the batches are moved to.

    Returns:
        dict: Keys 'loss' (per-sample average, NaN if nothing was scored),
              'accuracy', 'balanced_acc', 'mcc', 'sensitivity', 'specificity',
              'confusion_matrix' (2x2, labels [0, 1]), 'predictions' (flat
              array of 0/1 predictions) and 'targets' (flat array of labels).
    """
    needed_fields = ('x', 'edge_index', 'prot_t5_embedding', 'y', 'batch', 'central_node_idx')

    model.eval() # Set the model to evaluation mode
    loss_sum = 0.0
    pred_chunks = []   # Per-batch prediction arrays
    target_chunks = [] # Per-batch target arrays
    n_samples = 0
    n_skipped = 0

    # Disable gradient calculations during evaluation
    with torch.no_grad():
        for batch in loader:
            try:
                batch = batch.to(device)

                # --- Batch Validation ---
                if any(getattr(batch, field, None) is None for field in needed_fields):
                    n_skipped += 1
                    continue
                if batch.num_nodes == 0 or batch.num_graphs == 0:
                    n_skipped += 1
                    continue

                # --- Forward Pass ---
                probs = model(batch) # Prediction probabilities
                truth = batch.y.view(-1, 1).float() # Shape [batch_size, 1]

                if probs.shape[0] != truth.shape[0] or probs.shape[1] != 1:
                    n_skipped += 1
                    continue
                if truth.numel() == 0:
                    n_skipped += 1
                    continue

                # --- Loss Calculation ---
                batch_loss = F.binary_cross_entropy(probs, truth, reduction='sum')
                if not torch.isfinite(batch_loss):
                    print(f"Warning: Non-finite loss ({batch_loss.item()}) during evaluation. Skipping batch metrics.")
                    n_skipped += 1
                    continue

                loss_sum += batch_loss.item()

                # --- Store Predictions and Targets ---
                hard_pred = (probs > 0.5).float() # Threshold at 0.5 for metrics
                pred_chunks.append(hard_pred.cpu().numpy())
                target_chunks.append(truth.cpu().numpy())
                n_samples += truth.size(0)

            # --- Error Handling for Batch ---
            except Exception as e:
                print(f"--- Error during evaluation batch: {type(e).__name__} - {e} ---")
                traceback.print_exc(limit=1)
                n_skipped += 1
                continue

    if n_skipped > 0:
        print(f"Skipped {n_skipped} batches during evaluation.")

    # --- Calculate Metrics ---
    # Defaults returned as-is when nothing could be scored.
    metrics = {
        'loss': float('nan'),
        'accuracy': 0.0,
        'balanced_acc': 0.0,
        'mcc': 0.0,
        'sensitivity': 0.0,
        'specificity': 0.0,
        'confusion_matrix': np.zeros((2, 2), dtype=int),
        'predictions': np.array([]),
        'targets': np.array([]),
    }

    if not pred_chunks: # No batch was successfully processed
        print("Warning: No predictions collected during evaluation.")
        return metrics

    # Concatenate results from all batches
    all_preds = np.concatenate(pred_chunks).flatten()
    all_targets = np.concatenate(target_chunks).flatten()
    metrics['predictions'] = all_preds # Final binary predictions
    metrics['targets'] = all_targets

    if n_samples <= 0:
        print("Warning: Processed samples count is zero after evaluation loop.")
        return metrics

    metrics['loss'] = loss_sum / n_samples # Average loss per sample
    try:
        metrics['accuracy'] = accuracy_score(all_targets, all_preds)
        metrics['balanced_acc'] = balanced_accuracy_score(all_targets, all_preds)
        # Handle MCC undefined case (e.g., constant predictions)
        try:
            metrics['mcc'] = matthews_corrcoef(all_targets, all_preds)
        except ValueError:
            metrics['mcc'] = 0.0

        cm = confusion_matrix(all_targets, all_preds, labels=[0, 1])
        metrics['confusion_matrix'] = cm
        if cm.shape == (2, 2):
            tn, fp, fn, tp = cm.ravel()
            n_pos = tp + fn
            n_neg = tn + fp
            metrics['sensitivity'] = tp / n_pos if n_pos > 0 else 0.0
            metrics['specificity'] = tn / n_neg if n_neg > 0 else 0.0
    except Exception as e_metrics:
        print(f"Error calculating evaluation metrics: {e_metrics}")
        # Metrics not yet assigned keep their default values

    return metrics

def print_metrics(metrics: dict, prefix: str = ""):
    """
    Print the scalar evaluation metrics and the confusion matrix.

    Each scalar is printed as '<prefix><Name>: <value>' with four decimals;
    a metric missing from `metrics` is shown as nan. A 2x2 NumPy confusion
    matrix is laid out as labelled TN/FP/FN/TP cells; anything else is
    printed as-is.
    """
    missing = float('nan')
    scalar_rows = (
        ("Loss", 'loss'),
        ("Accuracy", 'accuracy'),
        ("Balanced Acc", 'balanced_acc'),
        ("MCC", 'mcc'),
        ("Sensitivity", 'sensitivity'),
        ("Specificity", 'specificity'),
    )
    for title, key in scalar_rows:
        print(f"{prefix}{title}: {metrics.get(key, missing):.4f}")

    print(f"{prefix}Confusion Matrix:")
    cm = metrics.get('confusion_matrix')
    is_standard_cm = isinstance(cm, np.ndarray) and cm.shape == (2, 2)
    if not is_standard_cm:
        print(f"  {cm}") # Print CM as is if not standard 2x2 numpy array
        return

    (tn, fp), (fn, tp) = cm
    print(f"  [[TN={tn:<5d} FP={fp:<5d}]")
    print(f"   [FN={fn:<5d} TP={tp:<5d}]]")


def train_and_evaluate_hybrid_model(
    train_df_aligned: pd.DataFrame, train_prot_t5: np.ndarray,
    test_df_aligned: pd.DataFrame, test_prot_t5: np.ndarray,
    model_config: Dict, # Contains GNN type, hypers, flags like use_edge_features
    distance_threshold: float = 8.0,
    n_splits: int = 5
    ) -> Tuple[Optional[Dict], Optional[Dict]]:
    """
    Trains and evaluates a single Hybrid GNN+ProtT5 model configuration using K-fold CV.

    For every stratified fold a fresh HybridModel is trained with early stopping
    on the validation loss, the best-validation-loss weights are restored, and
    that model is evaluated on the fold's validation split and (when available)
    used to predict the test set. Test predictions of all folds are combined by
    majority vote.

    Args:
        train_df_aligned (pd.DataFrame): Aligned training DataFrame.
        train_prot_t5 (np.ndarray): ProtT5 embeddings for the aligned training data.
        test_df_aligned (pd.DataFrame): Aligned testing DataFrame (may be empty).
        test_prot_t5 (np.ndarray): ProtT5 embeddings for the aligned testing data.
                                   None is tolerated and disables test evaluation.
        model_config (dict): Configuration dictionary including 'gnn_type',
                             'use_edge_features', and hyperparameters
                             ('hidden_dim', 'layers', 'heads', the dropout rates,
                             'lr', 'weight_decay', 'epochs', 'patience',
                             'scheduler_patience', 'batch_size',
                             'use_ss_node_feature'); every hyperparameter has a
                             default.
        distance_threshold (float): Distance cutoff for graph edges.
        n_splits (int): Number of folds for cross-validation.

    Returns:
        tuple: (avg_cv_metrics, final_test_metrics)
            - avg_cv_metrics (Optional[Dict]): Dictionary of average validation metrics across CV folds
              (a metric with no finite fold value is stored as None).
            - final_test_metrics (Optional[Dict]): Dictionary of metrics on the test set using ensemble.
              Its 'loss' entry is the average CV validation loss (NaN if
              unavailable), not a loss measured on the test set.
        (None, None) is returned when no training graphs can be built, feature
        dimensions cannot be determined, or GINE is requested without edge features.
    """
    gnn_type = model_config['gnn_type']
    # Compare one normalised value everywhere so 'PNA' / 'GINE' behave like
    # their lower-case spellings; the original string is still what is passed on.
    gnn_kind = gnn_type.lower()
    use_edge_features = model_config.get('use_edge_features', False)
    run_label = f"Hybrid {gnn_type.upper()}+ProtT5 (Edge Feats: {use_edge_features})"
    print(f"\n{'='*15} Evaluating: {run_label} {'='*15}")
    cv_start_time = time.time()

    # --- Data Preparation (using function from Block 5) ---
    # Pass use_ss_node_feature based on config or default
    use_ss_node = model_config.get('use_ss_node_feature', True)
    print(f"Preparing graph data (SS Node Features: {use_ss_node}, Edge Features: {use_edge_features})...")
    train_graphs, train_labels = prepare_graph_data(
        train_df_aligned, train_prot_t5, distance_threshold,
        use_ss_node_feature=use_ss_node,
        include_edge_features=use_edge_features
    )
    # Prepare test graphs if test data exists (embeddings may legitimately be None)
    has_test_data = (
        test_prot_t5 is not None
        and not test_df_aligned.empty
        and len(test_prot_t5) > 0
    )
    test_graphs: List[Data] = []
    test_labels: List[int] = [] # Get labels from aligned df for evaluation
    if has_test_data:
        test_graphs, test_labels = prepare_graph_data(
            test_df_aligned, test_prot_t5, distance_threshold,
            use_ss_node_feature=use_ss_node,
            include_edge_features=use_edge_features
        )
        if not test_graphs:
            has_test_data = False # Update flag if failed
    if not train_graphs:
        print(f"ERROR: No training graphs created for {run_label}. Aborting run.")
        return None, None

    # --- Compute Degree Histogram for PNA (if needed) ---
    deg_histogram = None
    if gnn_kind == 'pna':
        print("Calculating node degree histogram for PNA (from training graphs)...")
        max_degree = -1
        degrees = []
        for data in train_graphs:
            num_nodes = data.num_nodes
            if num_nodes == 0:
                continue
            edge_index = getattr(data, 'edge_index', None)
            # Default: every node has degree 0 (no edges / no valid edges).
            deg_list = torch.zeros(num_nodes, dtype=torch.long)
            if edge_index is not None and edge_index.numel() > 0:
                # Ignore edges that reference node indices outside this graph.
                valid_mask = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
                valid_idx = edge_index[:, valid_mask]
                if valid_idx.numel() > 0:
                    deg_list = degree(valid_idx[1], num_nodes=num_nodes, dtype=torch.long)
            degrees.append(deg_list)
            if deg_list.numel() > 0:
                max_degree = max(max_degree, deg_list.max().item())
        if max_degree == -1:
            max_degree = 0
        # Ensure histogram size is at least 1
        deg_histogram = torch.zeros(max_degree + 1, dtype=torch.long)
        if degrees:
            all_degrees = torch.cat(degrees, dim=0)
            if all_degrees.numel() > 0:
                counts = torch.bincount(all_degrees, minlength=max_degree + 1)
                copy_len = min(deg_histogram.numel(), counts.numel())
                deg_histogram[:copy_len] = counts[:copy_len]
        print(f"PNA Max degree: {max_degree}")
        if deg_histogram.sum() == 0 and deg_histogram.numel() > 0:
            print("Warning: PNA Degree histogram is all zeros. Providing minimal.")
            deg_histogram[0] = 1 # Assign count 1 to degree 0 if all else fails
        deg_histogram = deg_histogram.to(device) # Move histogram to target device

    # --- Class Weights ---
    # Inverse-frequency weights: total / (2 * class_count), 1.0 for an absent class.
    total_train = len(train_labels)
    pos_train = sum(train_labels)
    neg_train = total_train - pos_train
    class_weights_dict = {0: total_train / (2 * neg_train) if neg_train > 0 else 1.0,
                          1: total_train / (2 * pos_train) if pos_train > 0 else 1.0}
    print(f"Calculated Class weights: {class_weights_dict}")

    # --- Cross-validation Setup ---
    kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    fold_metrics_list: List[Dict] = []
    test_predictions_list: List[np.ndarray] = [] # Store binary predictions from test set

    # Determine feature dimensions from data
    try:
        node_feature_dim = train_graphs[0].num_node_features
        # Determine edge_dim based on data only if model requires it
        edge_dim_arg = None
        if use_edge_features:
            first_graph_with_edges = next((g for g in train_graphs if hasattr(g, 'edge_attr') and g.edge_attr is not None), None)
            if first_graph_with_edges is not None:
                edge_dim_arg = first_graph_with_edges.num_edge_features
                if edge_dim_arg != EXPECTED_EDGE_FEATURE_DIM:
                    print(f"Warning: Data edge_dim ({edge_dim_arg}) != Expected ({EXPECTED_EDGE_FEATURE_DIM}). Using data dim.")
            else: # Edge features requested but none found in data
                if gnn_kind == 'gine': # GINE requires edge features
                    print("ERROR: GINE requires edge features, but none generated/found in training data.")
                    return None, None
                print(f"Warning: use_edge_features=True for {gnn_type} but no 'edge_attr' found. Proceeding without edge features.")
                edge_dim_arg = None # Fallback to no edge features
        print(f"Feature dimensions used for model init: Nodes={node_feature_dim}, Edges={edge_dim_arg}")
    except (IndexError, AttributeError) as e:
        print(f"ERROR: Cannot determine feature dimensions from train_graphs: {e}. Aborting.")
        return None, None

    # Fold-invariant loader / training settings
    batch_size = model_config.get('batch_size', 32)
    num_workers = 2 if device.type == 'cuda' else 0
    pin_memory = torch.cuda.is_available()
    epochs = model_config.get('epochs', 50)
    patience = model_config.get('patience', 10)

    # --- K-Fold Loop ---
    for fold, (train_idx, val_idx) in enumerate(kfold.split(np.zeros(len(train_graphs)), train_labels), 1):
        fold_start_time = time.time()
        print(f"\n===== Fold {fold}/{n_splits} =====")
        train_fold_graphs = [train_graphs[i] for i in train_idx]
        val_fold_graphs = [train_graphs[i] for i in val_idx]
        print(f"  Training samples: {len(train_fold_graphs)}, Validation samples: {len(val_fold_graphs)}")

        # DataLoaders
        train_loader = DataLoader(train_fold_graphs, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
        val_loader = DataLoader(val_fold_graphs, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

        # --- Model Initialization ---
        try:
            model = HybridModel(
                gnn_type=gnn_type,
                node_feature_dim=node_feature_dim,
                hidden_dim=model_config.get('hidden_dim', 128),
                prot_t5_dim=PROT_T5_DIM, # Constant
                # GNN specific args
                deg=deg_histogram if gnn_kind == 'pna' else None,
                edge_dim=edge_dim_arg, # Pass determined/required edge dim
                heads=model_config.get('heads', 4),
                # Common GNN hypers
                layers=model_config.get('layers', 3),
                gnn_dropout=model_config.get('gnn_dropout', 0.4),
                # ProtT5 MLP hypers
                pt5_mlp_dropout=model_config.get('pt5_mlp_dropout', 0.4),
                # Classifier hypers
                classifier_dropout=model_config.get('classifier_dropout', 0.5)
            ).to(device)
        except Exception as model_init_error:
            print(f"--- ERROR initializing Hybrid Model for fold {fold}: {model_init_error} ---")
            traceback.print_exc()
            continue # Skip this fold

        # --- Optimizer and Scheduler ---
        optimizer = torch.optim.Adam(model.parameters(), # Match old code
                              lr=model_config.get('lr', 0.001),
                              weight_decay=model_config.get('weight_decay', 0.01))
        # 'verbose' is omitted: False was already the effective default and the
        # argument is deprecated in recent PyTorch releases.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5,
            patience=model_config.get('scheduler_patience', 5), min_lr=1e-6)

        # --- Training Loop ---
        best_val_loss = float('inf')
        epochs_no_improve = 0
        best_state_dict = None
        print(f"  Starting training for max {epochs} epochs (Patience: {patience})...")
        for epoch in range(epochs):
            train_loss, train_acc = train_model(model, train_loader, optimizer, device, class_weights=class_weights_dict)
            val_metrics = evaluate_model(model, val_loader, device)
            val_loss = val_metrics.get('loss', float('inf')) # Get loss safely

            scheduler.step(val_loss)
            current_lr = optimizer.param_groups[0]['lr']

            # Optional: Print epoch progress less frequently
            if (epoch + 1) % 5 == 0 or epoch == 0 or epoch == epochs - 1:
                print(f"  Epoch {epoch+1:02d}/{epochs} | Tr L:{train_loss:.4f}, A:{train_acc:.4f} | V L:{val_loss:.4f}, BAcc:{val_metrics['balanced_acc']:.4f} | LR:{current_lr:.1e}")

            # Early stopping check
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0
                # state_dict() hands out references to the live parameter
                # tensors, and dict.copy() is shallow, so the tensors must be
                # cloned; otherwise later optimizer steps overwrite the "best"
                # snapshot and the last-epoch weights get restored instead.
                best_state_dict = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f"  Early stopping triggered at epoch {epoch+1}. Best val loss: {best_val_loss:.4f}")
                    break
        # --- End Epoch Loop ---

        # Load best model state for fold evaluation
        if best_state_dict is not None:
            model.load_state_dict(best_state_dict)
        else:
            # Should only happen if training fails instantly
            print("  Warning: No improvement found / No best model state saved. Using last state.")

        # Evaluate best model on validation set
        print(f"\n  Evaluating best model on validation set for Fold {fold}...")
        final_val_metrics = evaluate_model(model, val_loader, device)
        fold_metrics_list.append(final_val_metrics)
        print_metrics(final_val_metrics, prefix=f"  Fold {fold} Validation ")

        # Evaluate best model on test set (if data exists)
        if has_test_data and test_graphs:
            print(f"  Predicting on test set using Fold {fold}'s best model...")
            # Use a potentially larger batch size for inference
            test_loader = DataLoader(test_graphs, batch_size=batch_size*2, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)
            test_fold_eval_metrics = evaluate_model(model, test_loader, device)
            fold_test_preds = test_fold_eval_metrics.get('predictions')
            if isinstance(fold_test_preds, np.ndarray):
                test_predictions_list.append(fold_test_preds) # Store binary predictions

        fold_end_time = time.time()
        print(f"===== Fold {fold} completed in {fold_end_time - fold_start_time:.2f} seconds =====")
    # --- End Fold Loop ---

    # --- Aggregate and Report CV Results ---
    print(f"\n--- Cross-validation Summary for {run_label} (Avg +/- Std Dev on Val Folds) ---")
    avg_cv_metrics: Optional[Dict] = None
    if fold_metrics_list:
        metrics_to_report = ['loss', 'accuracy', 'balanced_acc', 'mcc', 'sensitivity', 'specificity']
        avg_metrics_calc = {}
        for metric in metrics_to_report:
            # Only finite, present values contribute to the fold average.
            values = [m.get(metric) for m in fold_metrics_list if m and m.get(metric) is not None and np.isfinite(m.get(metric))]
            if values:
                avg_metrics_calc[metric] = np.mean(values)
                std_dev = np.std(values)
                print(f"  Avg Val {metric:<15}: {avg_metrics_calc[metric]:.4f} ± {std_dev:.4f}")
            else:
                avg_metrics_calc[metric] = None
                print(f"  Avg Val {metric:<15}: N/A")
        avg_cv_metrics = avg_metrics_calc # Store the calculated dict
    else:
        print("  No metrics recorded from cross-validation folds.")

    # --- Final Test Set Evaluation (Ensembled) ---
    final_test_metrics: Optional[Dict] = None
    if has_test_data and test_predictions_list and test_labels is not None and len(test_labels) > 0:
        print(f"\n--- Final Test Set Performance for {run_label} (Ensemble: Majority Vote) ---")
        try:
            num_test = len(test_labels)
            valid_preds = [p for p in test_predictions_list if isinstance(p, np.ndarray) and len(p) == num_test]
            if len(valid_preds) != n_splits:
                print(f"Warning: Number of valid test predictions ({len(valid_preds)}) != n_splits ({n_splits}).")

            if valid_preds: # Proceed only if valid predictions exist
                test_pred_stack = np.stack(valid_preds, axis=0) # [n_folds, n_samples]
                # Majority vote based on binary predictions; a strict '>' means
                # an exact tie (even number of voters) resolves to class 0.
                summed_preds = np.sum(test_pred_stack, axis=0)
                threshold = len(valid_preds) / 2.0
                test_pred_ensemble_binary = (summed_preds > threshold).astype(int)

                # Calculate final metrics
                final_test_metrics = {}
                final_test_metrics['targets'] = np.array(test_labels) # Store labels for reference
                final_test_metrics['predictions'] = test_pred_ensemble_binary # Store final ensemble prediction
                final_test_metrics['accuracy'] = accuracy_score(test_labels, test_pred_ensemble_binary)
                final_test_metrics['balanced_acc'] = balanced_accuracy_score(test_labels, test_pred_ensemble_binary)
                try:
                    final_test_metrics['mcc'] = matthews_corrcoef(test_labels, test_pred_ensemble_binary)
                except ValueError:
                    final_test_metrics['mcc'] = 0.0
                cm = confusion_matrix(test_labels, test_pred_ensemble_binary, labels=[0, 1])
                final_test_metrics['confusion_matrix'] = cm
                sens, spec = 0.0, 0.0
                if cm.shape == (2, 2):
                    tn, fp, fn, tp = cm.ravel()
                    sens = tp / (tp + fn) if (tp + fn) > 0 else 0.0
                    spec = tn / (tn + fp) if (tn + fp) > 0 else 0.0
                final_test_metrics['sensitivity'] = sens
                final_test_metrics['specificity'] = spec
                # Assign loss from avg CV metrics for reporting consistency.
                # NOTE: this is the average *validation* loss, not a test loss.
                # A None average (no finite fold loss) is mapped to NaN so the
                # value stays formattable.
                cv_loss = avg_cv_metrics.get('loss') if avg_cv_metrics else None
                final_test_metrics['loss'] = float('nan') if cv_loss is None else cv_loss

                print_metrics(final_test_metrics, prefix="Ensemble Test ")
            else:
                print("  Skipping test ensemble: No valid predictions available from folds.")
        except Exception as e_ensemble:
            print(f"Error during test set ensembling/evaluation: {e_ensemble}")
            traceback.print_exc(limit=2)
    # else: print reason why test eval skipped (handled earlier)

    cv_end_time = time.time()
    print(f"\n--- Evaluation for {run_label} completed in {cv_end_time - cv_start_time:.2f} seconds ---")

    # Return the aggregated results for this run
    return avg_cv_metrics, final_test_metrics

Reproducibility seeds set with SEED=42
Using device: cuda
CUDA Device Name: Tesla V100-PCIE-32GB
ProtT5 embedding dimension set to: 1024
Expected rich edge feature dimension set to: 17


In [None]:
if __name__ == "__main__":
    # Driver: load and align the CSV / ProtT5 data once, run the K-fold
    # evaluation for every entry in `model_configs`, then print a comparison
    # table of validation and test balanced accuracy / MCC.
    overall_start_time = time.time()

    def _metric_or_nan(metrics, key):
        """Return metrics[key], or NaN when the dict is empty/None or the value is missing/None."""
        value = metrics.get(key) if metrics else None
        return float('nan') if value is None else value

    # --- Configuration ---
    DISTANCE_THRESHOLD = 8.0 # Angstrom cutoff for graph edges

    # Default hyperparameters (can be overridden in specific configs)
    default_hypers = {
        'hidden_dim': 128,
        'layers': 3,          # Number of GNN layers
        'heads': 4,           # Number of heads for GATv2
        'gnn_dropout': 0.4,   # Dropout within GNN layers
        'pt5_mlp_dropout': 0.4, # Dropout within ProtT5 MLP
        'classifier_dropout': 0.5, # Dropout in final classifier
        'lr': 0.001,          # Learning rate
        'weight_decay': 0.01, # Weight decay passed to the optimizer
        'epochs': 50,         # Max epochs per fold
        'patience': 10,         # Early stopping patience
        'batch_size': 32,       # Batch size for training/evaluation
        'use_ss_node_feature': True # Include SS one-hot in node features
    }

    # Define model configurations to test
    model_configs = [
        # --- Models NOT using Edge Features ---
        {
            'config_label': 'GCN_noEF', # Descriptive label
            'gnn_type': 'gcn',
            'use_edge_features': False,
            **default_hypers
        },
        {
            'config_label': 'PNA_noEF',
            'gnn_type': 'pna',
            'use_edge_features': False,
            **default_hypers
        },
        {
            'config_label': 'GATv2_noEF',
            'gnn_type': 'gatv2',
            'use_edge_features': False, # Explicitly no edge features
            **default_hypers # Includes default heads=4
        },
        # --- Models USING Edge Features ---
        {
            'config_label': 'GATv2_EF',
            'gnn_type': 'gatv2',
            'use_edge_features': True, # Will use edge_attr if created
            **default_hypers # Includes default heads=4
        },
        {
            'config_label': 'GINE_EF',
            'gnn_type': 'gine',
            'use_edge_features': True, # GINE requires edge features
            **default_hypers
        },
        # --- Add more specific hyperparameter variations below if needed ---
        # e.g., {'config_label': 'GCN_L4_H256', 'gnn_type': 'gcn', 'use_edge_features': False, **default_hypers, 'layers': 4, 'hidden_dim': 256},
    ]

    # --- File Paths (!!! IMPORTANT: Update these paths !!!) ---
    # Adjust relative paths based on where we run the script
    base_data_path = "../../data/"

    try:
        train_csv_path = os.path.join(base_data_path, "train/structure/processed_features_train.csv")
        test_csv_path = os.path.join(base_data_path, "test/structure/processed_features_test.csv")
        train_pos_prot_t5_path = os.path.join(base_data_path, 'train/PLM/train_positive_ProtT5-XL-UniRef50.csv')
        train_neg_prot_t5_path = os.path.join(base_data_path, 'train/PLM/train_negative_ProtT5-XL-UniRef50.csv')
        test_pos_prot_t5_path = os.path.join(base_data_path, 'test/PLM/test_positive_ProtT5-XL-UniRef50.csv')
        test_neg_prot_t5_path = os.path.join(base_data_path, 'test/PLM/test_negative_ProtT5-XL-UniRef50.csv')

        print("\n--- File Paths Used ---")
        print(f"Train CSV: {os.path.abspath(train_csv_path)}")
        print(f"Test CSV: {os.path.abspath(test_csv_path)}")
        print(f"Train ProtT5 (+): {os.path.abspath(train_pos_prot_t5_path)}")
        # ... print other paths if desired ...
        print("-----------------------")

        # --- Data Loading ---
        print("\nLoading main feature CSV data...")
        if not os.path.exists(train_csv_path):
            raise FileNotFoundError(f"Train CSV: {train_csv_path}")
        if not os.path.exists(test_csv_path):
            raise FileNotFoundError(f"Test CSV: {test_csv_path}")
        train_df_orig = pd.read_csv(train_csv_path)
        test_df_orig = pd.read_csv(test_csv_path)
        print(f"Loaded {len(train_df_orig)} training, {len(test_df_orig)} test samples from CSV.")
        if train_df_orig.empty:
            raise ValueError("Training DataFrame empty.")

        # --- Load ProtT5 Data ---
        train_pos_dict, train_neg_dict = load_prot_t5_data(train_pos_prot_t5_path, train_neg_prot_t5_path)
        test_pos_dict, test_neg_dict = load_prot_t5_data(test_pos_prot_t5_path, test_neg_prot_t5_path)

        if train_pos_dict is None or train_neg_dict is None:
            raise RuntimeError("Failed to load Training ProtT5 embeddings.")
        # Handle potentially missing test embeddings gracefully
        can_test = not test_df_orig.empty and test_pos_dict is not None and test_neg_dict is not None

        # --- Align Data ---
        print("\nAligning training data...")
        X_train_prot_t5, train_df_aligned = prepare_aligned_data(train_df_orig, train_pos_dict, train_neg_dict)

        if train_df_aligned.empty or X_train_prot_t5.size == 0:
            raise ValueError("Training data alignment resulted in empty data.")

        # Defaults used when the test set cannot be evaluated
        X_test_prot_t5, test_df_aligned = None, pd.DataFrame(columns=test_df_orig.columns)
        if can_test:
            print("\nAligning test data...")
            X_test_prot_t5, test_df_aligned = prepare_aligned_data(test_df_orig, test_pos_dict, test_neg_dict)
            if test_df_aligned.empty or X_test_prot_t5.size == 0:
                print("Warning: Test data alignment resulted in empty data. Test evaluation skipped.")
                can_test = False
        else:
            print("\nTest data alignment skipped.")

        print("\n--- Aligned Data Shapes ---")
        print(f"Train DF: {train_df_aligned.shape}, Train ProtT5: {X_train_prot_t5.shape}")
        if can_test:
            print(f"Test DF : {test_df_aligned.shape}, Test ProtT5 : {X_test_prot_t5.shape}")
        else:
            print("Test Data: Not available or alignment failed.")
        print("--------------------------\n")

    # --- Error handling for data loading/alignment ---
    # Setup failures are fatal. SystemExit(1) is raised instead of calling the
    # site-provided exit() helper so the process reports a non-zero status.
    except FileNotFoundError as e:
        print(f"\nFATAL ERROR: Data file not found.\n {e}\nPlease check paths.")
        raise SystemExit(1)
    except ValueError as e:
        print(f"\nFATAL ERROR: Problem with data content or alignment.\n {e}")
        raise SystemExit(1)
    except RuntimeError as e:
        print(f"\nFATAL ERROR: Runtime issue during data loading.\n {e}")
        raise SystemExit(1)
    except Exception as e:
        print(f"FATAL ERROR during data setup: {e}")
        traceback.print_exc()
        raise SystemExit(1)


    # --- Store Results ---
    all_results_summary = {} # Store key metrics for final comparison

    # --- Loop through configurations and run training/evaluation ---
    for config in model_configs:
        config_label = config.get('config_label', f"{config['gnn_type']}_{'EF' if config.get('use_edge_features', False) else 'noEF'}")

        # --- Run evaluation ---
        try:
            # Call the main training/CV function with ALIGNED data
            avg_cv_metrics, final_test_metrics = train_and_evaluate_hybrid_model(
                train_df_aligned=train_df_aligned,
                train_prot_t5=X_train_prot_t5,
                test_df_aligned=test_df_aligned,
                test_prot_t5=X_test_prot_t5,
                model_config=config, # Pass the specific config for this run
                distance_threshold=DISTANCE_THRESHOLD,
                n_splits=5 # Fixed 5 folds
            )
            # Store key results for summary. Missing or None metrics become NaN
            # so the np.isnan checks in the summary table below cannot fail.
            all_results_summary[config_label] = {
                'avg_val_bacc': _metric_or_nan(avg_cv_metrics, 'balanced_acc'),
                'avg_val_mcc': _metric_or_nan(avg_cv_metrics, 'mcc'),
                'test_bacc': _metric_or_nan(final_test_metrics, 'balanced_acc'),
                'test_mcc': _metric_or_nan(final_test_metrics, 'mcc')
            }
        except Exception as e:
            # One failing configuration must not abort the remaining ones.
            print(f"\n--- !!! CRITICAL ERROR during evaluation for {config_label} !!! ---")
            print(f"Error: {e}")
            traceback.print_exc()
            all_results_summary[config_label] = {'error': str(e)} # Log the error
        # --- End of loop for one configuration ---


    # --- Final Summary ---
    print("\n\n" + "="*25 + " Overall Run Summary " + "="*25)
    print(f"{'Configuration':<20} | {'AvgVal BAcc':<12} | {'AvgVal MCC':<12} | {'Test BAcc':<12} | {'Test MCC':<12}")
    print("-" * 75)
    # Rows are listed alphabetically by configuration label
    sorted_labels = sorted(all_results_summary.keys())
    for config_label in sorted_labels:
        results = all_results_summary[config_label]
        if 'error' in results:
            print(f"{config_label:<20} | {'ERROR':<12} | {'ERROR':<12} | {'ERROR':<12} | {'ERROR':<12}")
            print(f"  Error: {results['error']}")
        else:
            val_bacc_str = f"{results['avg_val_bacc']:.4f}" if not np.isnan(results['avg_val_bacc']) else "N/A"
            val_mcc_str = f"{results['avg_val_mcc']:.4f}" if not np.isnan(results['avg_val_mcc']) else "N/A"
            test_bacc_str = f"{results['test_bacc']:.4f}" if not np.isnan(results['test_bacc']) else "N/A"
            test_mcc_str = f"{results['test_mcc']:.4f}" if not np.isnan(results['test_mcc']) else "N/A"
            print(f"{config_label:<20} | {val_bacc_str:<12} | {val_mcc_str:<12} | {test_bacc_str:<12} | {test_mcc_str:<12}")

    print("\n" + "="*75)
    overall_end_time = time.time()
    print(f"Total execution finished in {(overall_end_time - overall_start_time)/60:.2f} minutes.")
    print("===== Evaluation Complete =====")


--- File Paths Used ---
Train CSV: /home/ubuntu/data/hai/thesis/training_code/data/processed_features_fixed_train_contactmap.csv
Test CSV: /home/ubuntu/data/hai/thesis/training_code/data/processed_features_fixed_test_contactmap.csv
Train ProtT5 (+): /home/ubuntu/data/hai/thesis/data/train/features/train_positive_ProtT5-XL-UniRef50.csv
-----------------------

Loading main feature CSV data...
Loaded 8853 training, 2737 test samples from CSV.
Loading ProtT5 data...
  Positive file: ../../data/train/features/train_positive_ProtT5-XL-UniRef50.csv
  Negative file: ../../data/train/features/train_negative_ProtT5-XL-UniRef50.csv
Loaded 4750 positive (skipped 0) and 4750 negative (skipped 0) ProtT5 embeddings.
ProtT5 Loading finished in 2.16 seconds.
Loading ProtT5 data...
  Positive file: ../../data/test/features/test_positive_ProtT5-XL-UniRef50.csv
  Negative file: ../../data/test/features/test_negative_ProtT5-XL-UniRef50.csv
Loaded 253 positive (skipped 0) and 2972 negative (skipped 0) Pro



  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6432, A:0.6257 | V L:0.5736, BAcc:0.6966 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5016, A:0.7584 | V L:0.5510, BAcc:0.7147 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3473, A:0.8544 | V L:0.6177, BAcc:0.7120 | LR:1.0e-03




  Epoch 15/50 | Tr L:0.1479, A:0.9438 | V L:0.8627, BAcc:0.7083 | LR:5.0e-04




  Early stopping triggered at epoch 16. Best val loss: 0.5507

  Evaluating best model on validation set for Fold 1...
  Fold 1 Validation Loss: 0.7951
  Fold 1 Validation Accuracy: 0.7233
  Fold 1 Validation Balanced Acc: 0.7238
  Fold 1 Validation MCC: 0.4473
  Fold 1 Validation Sensitivity: 0.7116
  Fold 1 Validation Specificity: 0.7359
  Fold 1 Validation Confusion Matrix:
  [[TN=627   FP=225  ]
   [FN=265   TP=654  ]]
  Predicting on test set using Fold 1's best model...




===== Fold 1 completed in 56.10 seconds =====

===== Fold 2/5 =====
  Training samples: 7082, Validation samples: 1771
Initializing HybridModel (GNN=GCN + ProtT5)
  GNN configured NOT to use edge features.
Initialized GCNNetwork: Layers=3, Hidden=128, Dropout=0.4
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6490, A:0.6227 | V L:0.5785, BAcc:0.7002 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5052, A:0.7544 | V L:0.5469, BAcc:0.7115 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3569, A:0.8505 | V L:0.5914, BAcc:0.7122 | LR:1.0e-03




  Epoch 15/50 | Tr L:0.1585, A:0.9404 | V L:0.8781, BAcc:0.7103 | LR:5.0e-04
  Early stopping triggered at epoch 15. Best val loss: 0.5469

  Evaluating best model on validation set for Fold 2...
  Fold 2 Validation Loss: 0.8781
  Fold 2 Validation Accuracy: 0.7098
  Fold 2 Validation Balanced Acc: 0.7103
  Fold 2 Validation MCC: 0.4203
  Fold 2 Validation Sensitivity: 0.6964
  Fold 2 Validation Specificity: 0.7242
  Fold 2 Validation Confusion Matrix:
  [[TN=617   FP=235  ]
   [FN=279   TP=640  ]]
  Predicting on test set using Fold 2's best model...




===== Fold 2 completed in 52.20 seconds =====

===== Fold 3/5 =====
  Training samples: 7082, Validation samples: 1771
Initializing HybridModel (GNN=GCN + ProtT5)
  GNN configured NOT to use edge features.
Initialized GCNNetwork: Layers=3, Hidden=128, Dropout=0.4
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6394, A:0.6340 | V L:0.5749, BAcc:0.6988 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.4953, A:0.7691 | V L:0.5627, BAcc:0.7061 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3607, A:0.8486 | V L:0.6439, BAcc:0.7102 | LR:1.0e-03




  Epoch 15/50 | Tr L:0.1538, A:0.9446 | V L:0.9505, BAcc:0.7029 | LR:5.0e-04
  Early stopping triggered at epoch 15. Best val loss: 0.5627

  Evaluating best model on validation set for Fold 3...
  Fold 3 Validation Loss: 0.9505
  Fold 3 Validation Accuracy: 0.7024
  Fold 3 Validation Balanced Acc: 0.7029
  Fold 3 Validation MCC: 0.4056
  Fold 3 Validation Sensitivity: 0.6895
  Fold 3 Validation Specificity: 0.7163
  Fold 3 Validation Confusion Matrix:
  [[TN=611   FP=242  ]
   [FN=285   TP=633  ]]
  Predicting on test set using Fold 3's best model...




===== Fold 3 completed in 53.03 seconds =====

===== Fold 4/5 =====
  Training samples: 7083, Validation samples: 1770
Initializing HybridModel (GNN=GCN + ProtT5)
  GNN configured NOT to use edge features.
Initialized GCNNetwork: Layers=3, Hidden=128, Dropout=0.4
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6475, A:0.6187 | V L:0.5683, BAcc:0.7072 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5054, A:0.7642 | V L:0.5449, BAcc:0.7277 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3093, A:0.8793 | V L:0.7057, BAcc:0.7021 | LR:5.0e-04




  Early stopping triggered at epoch 12. Best val loss: 0.5446

  Evaluating best model on validation set for Fold 4...
  Fold 4 Validation Loss: 0.7378
  Fold 4 Validation Accuracy: 0.7203
  Fold 4 Validation Balanced Acc: 0.7213
  Fold 4 Validation MCC: 0.4426
  Fold 4 Validation Sensitivity: 0.6950
  Fold 4 Validation Specificity: 0.7477
  Fold 4 Validation Confusion Matrix:
  [[TN=637   FP=215  ]
   [FN=280   TP=638  ]]
  Predicting on test set using Fold 4's best model...




===== Fold 4 completed in 42.89 seconds =====

===== Fold 5/5 =====
  Training samples: 7083, Validation samples: 1770
Initializing HybridModel (GNN=GCN + ProtT5)
  GNN configured NOT to use edge features.
Initialized GCNNetwork: Layers=3, Hidden=128, Dropout=0.4
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6444, A:0.6213 | V L:0.5779, BAcc:0.6988 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5079, A:0.7572 | V L:0.5506, BAcc:0.7290 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3643, A:0.8396 | V L:0.6362, BAcc:0.7243 | LR:1.0e-03




  Epoch 15/50 | Tr L:0.1645, A:0.9372 | V L:0.9097, BAcc:0.7271 | LR:5.0e-04
  Early stopping triggered at epoch 15. Best val loss: 0.5506

  Evaluating best model on validation set for Fold 5...
  Fold 5 Validation Loss: 0.9097
  Fold 5 Validation Accuracy: 0.7282
  Fold 5 Validation Balanced Acc: 0.7271
  Fold 5 Validation MCC: 0.4553
  Fold 5 Validation Sensitivity: 0.7582
  Fold 5 Validation Specificity: 0.6960
  Fold 5 Validation Confusion Matrix:
  [[TN=593   FP=259  ]
   [FN=222   TP=696  ]]
  Predicting on test set using Fold 5's best model...




===== Fold 5 completed in 52.24 seconds =====

--- Cross-validation Summary for Hybrid GCN+ProtT5 (Edge Feats: False) (Avg +/- Std Dev on Val Folds) ---
  Avg Val loss           : 0.8542 ± 0.0774
  Avg Val accuracy       : 0.7168 ± 0.0094
  Avg Val balanced_acc   : 0.7171 ± 0.0091
  Avg Val mcc            : 0.4342 ± 0.0184
  Avg Val sensitivity    : 0.7102 ± 0.0251
  Avg Val specificity    : 0.7240 ± 0.0176

--- Final Test Set Performance for Hybrid GCN+ProtT5 (Edge Feats: False) (Ensemble: Majority Vote) ---
Ensemble Test Loss: 0.8542
Ensemble Test Accuracy: 0.7095
Ensemble Test Balanced Acc: 0.7165
Ensemble Test MCC: 0.2605
Ensemble Test Sensitivity: 0.7250
Ensemble Test Specificity: 0.7080
Ensemble Test Confusion Matrix:
  [[TN=1768  FP=729  ]
   [FN=66    TP=174  ]]

--- Evaluation for Hybrid GCN+ProtT5 (Edge Feats: False) completed in 296.11 seconds ---

Preparing graph data (SS Node Features: True, Edge Features: False)...
Preparing PyG graph data (GNN+ProtT5 format: Edge Feats=F



  Epoch 01/50 | Tr L:0.6469, A:0.6200 | V L:0.5765, BAcc:0.6973 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5004, A:0.7591 | V L:0.5973, BAcc:0.7125 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3459, A:0.8546 | V L:0.6518, BAcc:0.7162 | LR:5.0e-04




  Early stopping triggered at epoch 14. Best val loss: 0.5471

  Evaluating best model on validation set for Fold 1...
  Fold 1 Validation Loss: 0.8212
  Fold 1 Validation Accuracy: 0.7092
  Fold 1 Validation Balanced Acc: 0.7100
  Fold 1 Validation MCC: 0.4198
  Fold 1 Validation Sensitivity: 0.6899
  Fold 1 Validation Specificity: 0.7300
  Fold 1 Validation Confusion Matrix:
  [[TN=622   FP=230  ]
   [FN=285   TP=634  ]]
  Predicting on test set using Fold 1's best model...




===== Fold 1 completed in 134.14 seconds =====

===== Fold 2/5 =====
  Training samples: 7082, Validation samples: 1771
Initializing HybridModel (GNN=PNA + ProtT5)
  GNN configured NOT to use edge features.
Initialized PNANetwork: Layers=3, Hidden=128, Dropout=0.4
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6444, A:0.6224 | V L:0.5786, BAcc:0.6974 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5017, A:0.7616 | V L:0.5850, BAcc:0.7131 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3491, A:0.8587 | V L:0.6619, BAcc:0.6965 | LR:1.0e-03




  Epoch 15/50 | Tr L:0.1546, A:0.9451 | V L:0.8493, BAcc:0.7057 | LR:5.0e-04




  Early stopping triggered at epoch 16. Best val loss: 0.5705

  Evaluating best model on validation set for Fold 2...
  Fold 2 Validation Loss: 0.8561
  Fold 2 Validation Accuracy: 0.7115
  Fold 2 Validation Balanced Acc: 0.7105
  Fold 2 Validation MCC: 0.4216
  Fold 2 Validation Sensitivity: 0.7356
  Fold 2 Validation Specificity: 0.6854
  Fold 2 Validation Confusion Matrix:
  [[TN=584   FP=268  ]
   [FN=243   TP=676  ]]
  Predicting on test set using Fold 2's best model...




===== Fold 2 completed in 163.34 seconds =====

===== Fold 3/5 =====
  Training samples: 7082, Validation samples: 1771
Initializing HybridModel (GNN=PNA + ProtT5)
  GNN configured NOT to use edge features.
Initialized PNANetwork: Layers=3, Hidden=128, Dropout=0.4
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6412, A:0.6212 | V L:0.5868, BAcc:0.6903 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5049, A:0.7656 | V L:0.5922, BAcc:0.7063 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.2941, A:0.8859 | V L:0.7476, BAcc:0.7104 | LR:5.0e-04




  Early stopping triggered at epoch 12. Best val loss: 0.5666

  Evaluating best model on validation set for Fold 3...
  Fold 3 Validation Loss: 0.8723
  Fold 3 Validation Accuracy: 0.6838
  Fold 3 Validation Balanced Acc: 0.6889
  Fold 3 Validation MCC: 0.3917
  Fold 3 Validation Sensitivity: 0.5490
  Fold 3 Validation Specificity: 0.8288
  Fold 3 Validation Confusion Matrix:
  [[TN=707   FP=146  ]
   [FN=414   TP=504  ]]
  Predicting on test set using Fold 3's best model...




===== Fold 3 completed in 133.47 seconds =====

===== Fold 4/5 =====
  Training samples: 7083, Validation samples: 1770
Initializing HybridModel (GNN=PNA + ProtT5)
  GNN configured NOT to use edge features.
Initialized PNANetwork: Layers=3, Hidden=128, Dropout=0.4
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6472, A:0.6215 | V L:0.5637, BAcc:0.7056 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5004, A:0.7582 | V L:0.5645, BAcc:0.7201 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.2899, A:0.8817 | V L:0.7257, BAcc:0.7157 | LR:5.0e-04




  Early stopping triggered at epoch 12. Best val loss: 0.5544

  Evaluating best model on validation set for Fold 4...
  Fold 4 Validation Loss: 0.7935
  Fold 4 Validation Accuracy: 0.6972
  Fold 4 Validation Balanced Acc: 0.6979
  Fold 4 Validation MCC: 0.3957
  Fold 4 Validation Sensitivity: 0.6776
  Fold 4 Validation Specificity: 0.7183
  Fold 4 Validation Confusion Matrix:
  [[TN=612   FP=240  ]
   [FN=296   TP=622  ]]
  Predicting on test set using Fold 4's best model...




===== Fold 4 completed in 153.79 seconds =====

===== Fold 5/5 =====
  Training samples: 7083, Validation samples: 1770
Initializing HybridModel (GNN=PNA + ProtT5)
  GNN configured NOT to use edge features.
Initialized PNANetwork: Layers=3, Hidden=128, Dropout=0.4
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6491, A:0.6205 | V L:0.5914, BAcc:0.6822 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.4985, A:0.7697 | V L:0.6676, BAcc:0.6879 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3416, A:0.8614 | V L:0.6734, BAcc:0.7163 | LR:5.0e-04




  Early stopping triggered at epoch 14. Best val loss: 0.5569

  Evaluating best model on validation set for Fold 5...
  Fold 5 Validation Loss: 0.8424
  Fold 5 Validation Accuracy: 0.7198
  Fold 5 Validation Balanced Acc: 0.7210
  Fold 5 Validation MCC: 0.4422
  Fold 5 Validation Sensitivity: 0.6885
  Fold 5 Validation Specificity: 0.7535
  Fold 5 Validation Confusion Matrix:
  [[TN=642   FP=210  ]
   [FN=286   TP=632  ]]
  Predicting on test set using Fold 5's best model...




===== Fold 5 completed in 183.31 seconds =====

--- Cross-validation Summary for Hybrid PNA+ProtT5 (Edge Feats: False) (Avg +/- Std Dev on Val Folds) ---
  Avg Val loss           : 0.8371 ± 0.0275
  Avg Val accuracy       : 0.7043 ± 0.0125
  Avg Val balanced_acc   : 0.7057 ± 0.0111
  Avg Val mcc            : 0.4142 ± 0.0185
  Avg Val sensitivity    : 0.6681 ± 0.0628
  Avg Val specificity    : 0.7432 ± 0.0481

--- Final Test Set Performance for Hybrid PNA+ProtT5 (Edge Feats: False) (Ensemble: Majority Vote) ---
Ensemble Test Loss: 0.8371
Ensemble Test Accuracy: 0.7336
Ensemble Test Balanced Acc: 0.7128
Ensemble Test MCC: 0.2629
Ensemble Test Sensitivity: 0.6875
Ensemble Test Specificity: 0.7381
Ensemble Test Confusion Matrix:
  [[TN=1843  FP=654  ]
   [FN=75    TP=165  ]]

--- Evaluation for Hybrid PNA+ProtT5 (Edge Feats: False) completed in 809.51 seconds ---

Preparing graph data (SS Node Features: True, Edge Features: False)...
Preparing PyG graph data (GNN+ProtT5 format: Edge Feats=



  Epoch 01/50 | Tr L:0.6433, A:0.6226 | V L:0.6539, BAcc:0.6545 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5077, A:0.7591 | V L:0.5440, BAcc:0.7232 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3745, A:0.8379 | V L:0.6152, BAcc:0.7185 | LR:1.0e-03




  Epoch 15/50 | Tr L:0.1730, A:0.9343 | V L:0.7986, BAcc:0.7228 | LR:5.0e-04
  Early stopping triggered at epoch 15. Best val loss: 0.5440

  Evaluating best model on validation set for Fold 1...
  Fold 1 Validation Loss: 0.7986
  Fold 1 Validation Accuracy: 0.7233
  Fold 1 Validation Balanced Acc: 0.7228
  Fold 1 Validation MCC: 0.4458
  Fold 1 Validation Sensitivity: 0.7356
  Fold 1 Validation Specificity: 0.7101
  Fold 1 Validation Confusion Matrix:
  [[TN=605   FP=247  ]
   [FN=243   TP=676  ]]
  Predicting on test set using Fold 1's best model...




===== Fold 1 completed in 120.80 seconds =====

===== Fold 2/5 =====
  Training samples: 7082, Validation samples: 1771
Initializing HybridModel (GNN=GATV2 + ProtT5)
  GNN configured NOT to use edge features.
Initialized GATv2Network: Layers=3, Hidden=128, Heads=4, Dropout=0.4, EdgeDim=None
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6479, A:0.6209 | V L:0.5737, BAcc:0.6932 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5028, A:0.7581 | V L:0.5576, BAcc:0.7131 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3594, A:0.8499 | V L:0.6098, BAcc:0.7131 | LR:5.0e-04




  Early stopping triggered at epoch 14. Best val loss: 0.5482

  Evaluating best model on validation set for Fold 2...
  Fold 2 Validation Loss: 0.7736
  Fold 2 Validation Accuracy: 0.7002
  Fold 2 Validation Balanced Acc: 0.7011
  Fold 2 Validation MCC: 0.4021
  Fold 2 Validation Sensitivity: 0.6768
  Fold 2 Validation Specificity: 0.7254
  Fold 2 Validation Confusion Matrix:
  [[TN=618   FP=234  ]
   [FN=297   TP=622  ]]
  Predicting on test set using Fold 2's best model...




===== Fold 2 completed in 101.63 seconds =====

===== Fold 3/5 =====
  Training samples: 7082, Validation samples: 1771
Initializing HybridModel (GNN=GATV2 + ProtT5)
  GNN configured NOT to use edge features.
Initialized GATv2Network: Layers=3, Hidden=128, Heads=4, Dropout=0.4, EdgeDim=None
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6408, A:0.6243 | V L:0.5807, BAcc:0.6869 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.4935, A:0.7676 | V L:0.5759, BAcc:0.7119 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3066, A:0.8763 | V L:0.6919, BAcc:0.7064 | LR:5.0e-04




  Early stopping triggered at epoch 12. Best val loss: 0.5633

  Evaluating best model on validation set for Fold 3...
  Fold 3 Validation Loss: 0.8026
  Fold 3 Validation Accuracy: 0.7024
  Fold 3 Validation Balanced Acc: 0.7019
  Fold 3 Validation MCC: 0.4039
  Fold 3 Validation Sensitivity: 0.7157
  Fold 3 Validation Specificity: 0.6882
  Fold 3 Validation Confusion Matrix:
  [[TN=587   FP=266  ]
   [FN=261   TP=657  ]]
  Predicting on test set using Fold 3's best model...




===== Fold 3 completed in 87.35 seconds =====

===== Fold 4/5 =====
  Training samples: 7083, Validation samples: 1770
Initializing HybridModel (GNN=GATV2 + ProtT5)
  GNN configured NOT to use edge features.
Initialized GATv2Network: Layers=3, Hidden=128, Heads=4, Dropout=0.4, EdgeDim=None
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6429, A:0.6264 | V L:0.5723, BAcc:0.7037 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5041, A:0.7597 | V L:0.5375, BAcc:0.7279 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3456, A:0.8566 | V L:0.6500, BAcc:0.7094 | LR:1.0e-03




  Epoch 15/50 | Tr L:0.1529, A:0.9445 | V L:0.8900, BAcc:0.7066 | LR:5.0e-04
  Early stopping triggered at epoch 15. Best val loss: 0.5375

  Evaluating best model on validation set for Fold 4...
  Fold 4 Validation Loss: 0.8900
  Fold 4 Validation Accuracy: 0.7062
  Fold 4 Validation Balanced Acc: 0.7066
  Fold 4 Validation MCC: 0.4130
  Fold 4 Validation Sensitivity: 0.6950
  Fold 4 Validation Specificity: 0.7183
  Fold 4 Validation Confusion Matrix:
  [[TN=612   FP=240  ]
   [FN=280   TP=638  ]]
  Predicting on test set using Fold 4's best model...




===== Fold 4 completed in 106.97 seconds =====

===== Fold 5/5 =====
  Training samples: 7083, Validation samples: 1770
Initializing HybridModel (GNN=GATV2 + ProtT5)
  GNN configured NOT to use edge features.
Initialized GATv2Network: Layers=3, Hidden=128, Heads=4, Dropout=0.4, EdgeDim=None
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6472, A:0.6180 | V L:0.5654, BAcc:0.6962 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5033, A:0.7615 | V L:0.5513, BAcc:0.7152 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3640, A:0.8437 | V L:0.6455, BAcc:0.7254 | LR:1.0e-03




  Epoch 15/50 | Tr L:0.1709, A:0.9384 | V L:0.8487, BAcc:0.7234 | LR:5.0e-04
  Early stopping triggered at epoch 15. Best val loss: 0.5513

  Evaluating best model on validation set for Fold 5...
  Fold 5 Validation Loss: 0.8487
  Fold 5 Validation Accuracy: 0.7226
  Fold 5 Validation Balanced Acc: 0.7234
  Fold 5 Validation MCC: 0.4466
  Fold 5 Validation Sensitivity: 0.7026
  Fold 5 Validation Specificity: 0.7441
  Fold 5 Validation Confusion Matrix:
  [[TN=634   FP=218  ]
   [FN=273   TP=645  ]]
  Predicting on test set using Fold 5's best model...




===== Fold 5 completed in 92.66 seconds =====

--- Cross-validation Summary for Hybrid GATV2+ProtT5 (Edge Feats: False) (Avg +/- Std Dev on Val Folds) ---
  Avg Val loss           : 0.8227 ± 0.0415
  Avg Val accuracy       : 0.7109 ± 0.0100
  Avg Val balanced_acc   : 0.7112 ± 0.0099
  Avg Val mcc            : 0.4223 ± 0.0199
  Avg Val sensitivity    : 0.7051 ± 0.0197
  Avg Val specificity    : 0.7172 ± 0.0184

--- Final Test Set Performance for Hybrid GATV2+ProtT5 (Edge Feats: False) (Ensemble: Majority Vote) ---
Ensemble Test Loss: 0.8227
Ensemble Test Accuracy: 0.7033
Ensemble Test Balanced Acc: 0.7169
Ensemble Test MCC: 0.2594
Ensemble Test Sensitivity: 0.7333
Ensemble Test Specificity: 0.7004
Ensemble Test Confusion Matrix:
  [[TN=1749  FP=748  ]
   [FN=64    TP=176  ]]

--- Evaluation for Hybrid GATV2+ProtT5 (Edge Feats: False) completed in 553.82 seconds ---

Preparing graph data (SS Node Features: True, Edge Features: True)...
Preparing PyG graph data (GNN+ProtT5 format: Edge Fe



  Epoch 01/50 | Tr L:0.6447, A:0.6221 | V L:0.6759, BAcc:0.5925 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5004, A:0.7635 | V L:0.5721, BAcc:0.7085 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3382, A:0.8642 | V L:0.8276, BAcc:0.6763 | LR:1.0e-03




  Epoch 15/50 | Tr L:0.1651, A:0.9373 | V L:1.0332, BAcc:0.6933 | LR:5.0e-04
  Early stopping triggered at epoch 15. Best val loss: 0.5721

  Evaluating best model on validation set for Fold 1...
  Fold 1 Validation Loss: 1.0332
  Fold 1 Validation Accuracy: 0.6872
  Fold 1 Validation Balanced Acc: 0.6933
  Fold 1 Validation MCC: 0.4060
  Fold 1 Validation Sensitivity: 0.5321
  Fold 1 Validation Specificity: 0.8545
  Fold 1 Validation Confusion Matrix:
  [[TN=728   FP=124  ]
   [FN=430   TP=489  ]]
  Predicting on test set using Fold 1's best model...




===== Fold 1 completed in 132.02 seconds =====

===== Fold 2/5 =====
  Training samples: 7082, Validation samples: 1771
Initializing HybridModel (GNN=GATV2 + ProtT5)
  GNN configured to use edge features (dim=17)
Initialized GATv2Network: Layers=3, Hidden=128, Heads=4, Dropout=0.4, EdgeDim=17
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6437, A:0.6241 | V L:0.7450, BAcc:0.5911 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5021, A:0.7645 | V L:0.6404, BAcc:0.6938 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3038, A:0.8810 | V L:0.7410, BAcc:0.6769 | LR:5.0e-04




  Early stopping triggered at epoch 12. Best val loss: 0.6314

  Evaluating best model on validation set for Fold 2...
  Fold 2 Validation Loss: 0.7895
  Fold 2 Validation Accuracy: 0.7041
  Fold 2 Validation Balanced Acc: 0.7072
  Fold 2 Validation MCC: 0.4185
  Fold 2 Validation Sensitivity: 0.6268
  Fold 2 Validation Specificity: 0.7876
  Fold 2 Validation Confusion Matrix:
  [[TN=671   FP=181  ]
   [FN=343   TP=576  ]]
  Predicting on test set using Fold 2's best model...




===== Fold 2 completed in 104.21 seconds =====

===== Fold 3/5 =====
  Training samples: 7082, Validation samples: 1771
Initializing HybridModel (GNN=GATV2 + ProtT5)
  GNN configured to use edge features (dim=17)
Initialized GATv2Network: Layers=3, Hidden=128, Heads=4, Dropout=0.4, EdgeDim=17
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6339, A:0.6324 | V L:0.6508, BAcc:0.6413 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5014, A:0.7646 | V L:0.5732, BAcc:0.7152 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3616, A:0.8522 | V L:0.7010, BAcc:0.6845 | LR:1.0e-03




  Epoch 15/50 | Tr L:0.1640, A:0.9380 | V L:0.8534, BAcc:0.7038 | LR:5.0e-04
  Early stopping triggered at epoch 15. Best val loss: 0.5732

  Evaluating best model on validation set for Fold 3...
  Fold 3 Validation Loss: 0.8534
  Fold 3 Validation Accuracy: 0.7030
  Fold 3 Validation Balanced Acc: 0.7038
  Fold 3 Validation MCC: 0.4074
  Fold 3 Validation Sensitivity: 0.6830
  Fold 3 Validation Specificity: 0.7245
  Fold 3 Validation Confusion Matrix:
  [[TN=618   FP=235  ]
   [FN=291   TP=627  ]]
  Predicting on test set using Fold 3's best model...




===== Fold 3 completed in 125.10 seconds =====

===== Fold 4/5 =====
  Training samples: 7083, Validation samples: 1770
Initializing HybridModel (GNN=GATV2 + ProtT5)
  GNN configured to use edge features (dim=17)
Initialized GATv2Network: Layers=3, Hidden=128, Heads=4, Dropout=0.4, EdgeDim=17
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6443, A:0.6266 | V L:0.6121, BAcc:0.6689 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.4953, A:0.7663 | V L:0.5888, BAcc:0.6960 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3403, A:0.8611 | V L:0.6842, BAcc:0.6817 | LR:1.0e-03




  Epoch 15/50 | Tr L:0.1523, A:0.9458 | V L:0.7554, BAcc:0.7244 | LR:5.0e-04




  Early stopping triggered at epoch 17. Best val loss: 0.5762

  Evaluating best model on validation set for Fold 4...
  Fold 4 Validation Loss: 0.8291
  Fold 4 Validation Accuracy: 0.7130
  Fold 4 Validation Balanced Acc: 0.7136
  Fold 4 Validation MCC: 0.4269
  Fold 4 Validation Sensitivity: 0.6983
  Fold 4 Validation Specificity: 0.7289
  Fold 4 Validation Confusion Matrix:
  [[TN=621   FP=231  ]
   [FN=277   TP=641  ]]
  Predicting on test set using Fold 4's best model...




===== Fold 4 completed in 117.49 seconds =====

===== Fold 5/5 =====
  Training samples: 7083, Validation samples: 1770
Initializing HybridModel (GNN=GATV2 + ProtT5)
  GNN configured to use edge features (dim=17)
Initialized GATv2Network: Layers=3, Hidden=128, Heads=4, Dropout=0.4, EdgeDim=17
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6447, A:0.6274 | V L:0.7398, BAcc:0.5837 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5052, A:0.7666 | V L:0.6633, BAcc:0.6776 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3182, A:0.8722 | V L:0.7840, BAcc:0.6944 | LR:5.0e-04




  Early stopping triggered at epoch 12. Best val loss: 0.5759

  Evaluating best model on validation set for Fold 5...
  Fold 5 Validation Loss: 0.7226
  Fold 5 Validation Accuracy: 0.7164
  Fold 5 Validation Balanced Acc: 0.7165
  Fold 5 Validation MCC: 0.4326
  Fold 5 Validation Sensitivity: 0.7146
  Fold 5 Validation Specificity: 0.7183
  Fold 5 Validation Confusion Matrix:
  [[TN=612   FP=240  ]
   [FN=262   TP=656  ]]
  Predicting on test set using Fold 5's best model...




===== Fold 5 completed in 80.41 seconds =====

--- Cross-validation Summary for Hybrid GATV2+ProtT5 (Edge Feats: True) (Avg +/- Std Dev on Val Folds) ---
  Avg Val loss           : 0.8456 ± 0.1037
  Avg Val accuracy       : 0.7047 ± 0.0102
  Avg Val balanced_acc   : 0.7068 ± 0.0081
  Avg Val mcc            : 0.4183 ± 0.0105
  Avg Val sensitivity    : 0.6509 ± 0.0664
  Avg Val specificity    : 0.7627 ± 0.0522

--- Final Test Set Performance for Hybrid GATV2+ProtT5 (Edge Feats: True) (Ensemble: Majority Vote) ---
Ensemble Test Loss: 0.8456
Ensemble Test Accuracy: 0.7391
Ensemble Test Balanced Acc: 0.7007
Ensemble Test MCC: 0.2508
Ensemble Test Sensitivity: 0.6542
Ensemble Test Specificity: 0.7473
Ensemble Test Confusion Matrix:
  [[TN=1866  FP=631  ]
   [FN=83    TP=157  ]]

--- Evaluation for Hybrid GATV2+ProtT5 (Edge Feats: True) completed in 706.71 seconds ---

Preparing graph data (SS Node Features: True, Edge Features: True)...
Preparing PyG graph data (GNN+ProtT5 format: Edge Feats



  Epoch 01/50 | Tr L:0.6447, A:0.6278 | V L:0.6213, BAcc:0.6519 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5115, A:0.7608 | V L:0.5483, BAcc:0.7211 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3677, A:0.8421 | V L:0.6457, BAcc:0.6978 | LR:1.0e-03




  Epoch 15/50 | Tr L:0.1640, A:0.9396 | V L:0.9239, BAcc:0.7174 | LR:5.0e-04
  Early stopping triggered at epoch 15. Best val loss: 0.5483

  Evaluating best model on validation set for Fold 1...
  Fold 1 Validation Loss: 0.9239
  Fold 1 Validation Accuracy: 0.7137
  Fold 1 Validation Balanced Acc: 0.7174
  Fold 1 Validation MCC: 0.4413
  Fold 1 Validation Sensitivity: 0.6213
  Fold 1 Validation Specificity: 0.8134
  Fold 1 Validation Confusion Matrix:
  [[TN=693   FP=159  ]
   [FN=348   TP=571  ]]
  Predicting on test set using Fold 1's best model...




===== Fold 1 completed in 78.12 seconds =====

===== Fold 2/5 =====
  Training samples: 7082, Validation samples: 1771
Initializing HybridModel (GNN=GINE + ProtT5)
  GNN configured to use edge features (dim=17)
Initialized GINENetwork: Layers=3, Hidden=128, Dropout=0.4, EdgeDim=17
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6457, A:0.6223 | V L:0.5841, BAcc:0.6998 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5066, A:0.7518 | V L:0.5697, BAcc:0.7055 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3641, A:0.8462 | V L:0.6232, BAcc:0.7114 | LR:5.0e-04




  Early stopping triggered at epoch 14. Best val loss: 0.5522

  Evaluating best model on validation set for Fold 2...
  Fold 2 Validation Loss: 0.8367
  Fold 2 Validation Accuracy: 0.7103
  Fold 2 Validation Balanced Acc: 0.7100
  Fold 2 Validation MCC: 0.4199
  Fold 2 Validation Sensitivity: 0.7193
  Fold 2 Validation Specificity: 0.7007
  Fold 2 Validation Confusion Matrix:
  [[TN=597   FP=255  ]
   [FN=258   TP=661  ]]
  Predicting on test set using Fold 2's best model...




===== Fold 2 completed in 73.12 seconds =====

===== Fold 3/5 =====
  Training samples: 7082, Validation samples: 1771
Initializing HybridModel (GNN=GINE + ProtT5)
  GNN configured to use edge features (dim=17)
Initialized GINENetwork: Layers=3, Hidden=128, Dropout=0.4, EdgeDim=17
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6421, A:0.6228 | V L:0.5846, BAcc:0.6813 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5035, A:0.7592 | V L:0.5829, BAcc:0.7020 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3625, A:0.8482 | V L:0.6815, BAcc:0.6877 | LR:5.0e-04




  Early stopping triggered at epoch 14. Best val loss: 0.5703

  Evaluating best model on validation set for Fold 3...
  Fold 3 Validation Loss: 0.8198
  Fold 3 Validation Accuracy: 0.7047
  Fold 3 Validation Balanced Acc: 0.7053
  Fold 3 Validation MCC: 0.4103
  Fold 3 Validation Sensitivity: 0.6895
  Fold 3 Validation Specificity: 0.7210
  Fold 3 Validation Confusion Matrix:
  [[TN=615   FP=238  ]
   [FN=285   TP=633  ]]
  Predicting on test set using Fold 3's best model...




===== Fold 3 completed in 71.17 seconds =====

===== Fold 4/5 =====
  Training samples: 7083, Validation samples: 1770
Initializing HybridModel (GNN=GINE + ProtT5)
  GNN configured to use edge features (dim=17)
Initialized GINENetwork: Layers=3, Hidden=128, Dropout=0.4, EdgeDim=17
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6441, A:0.6322 | V L:0.5773, BAcc:0.7092 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5074, A:0.7587 | V L:0.5762, BAcc:0.7124 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3090, A:0.8753 | V L:0.7069, BAcc:0.7090 | LR:5.0e-04




  Early stopping triggered at epoch 13. Best val loss: 0.5547

  Evaluating best model on validation set for Fold 4...
  Fold 4 Validation Loss: 0.8251
  Fold 4 Validation Accuracy: 0.7034
  Fold 4 Validation Balanced Acc: 0.7040
  Fold 4 Validation MCC: 0.4078
  Fold 4 Validation Sensitivity: 0.6874
  Fold 4 Validation Specificity: 0.7207
  Fold 4 Validation Confusion Matrix:
  [[TN=614   FP=238  ]
   [FN=287   TP=631  ]]
  Predicting on test set using Fold 4's best model...




===== Fold 4 completed in 67.63 seconds =====

===== Fold 5/5 =====
  Training samples: 7083, Validation samples: 1770
Initializing HybridModel (GNN=GINE + ProtT5)
  GNN configured to use edge features (dim=17)
Initialized GINENetwork: Layers=3, Hidden=128, Dropout=0.4, EdgeDim=17
  ProtT5 MLP initialized: 1024 -> ... -> 128
  Combined input dimension for final classifier: 512
HybridModel initialization complete.
  Starting training for max 50 epochs (Patience: 10)...




  Epoch 01/50 | Tr L:0.6444, A:0.6295 | V L:0.5942, BAcc:0.6908 | LR:1.0e-03




  Epoch 05/50 | Tr L:0.5111, A:0.7552 | V L:0.5645, BAcc:0.7281 | LR:1.0e-03




  Epoch 10/50 | Tr L:0.3331, A:0.8609 | V L:0.7449, BAcc:0.7091 | LR:5.0e-04




  Early stopping triggered at epoch 13. Best val loss: 0.5614

  Evaluating best model on validation set for Fold 5...
  Fold 5 Validation Loss: 0.8446
  Fold 5 Validation Accuracy: 0.7023
  Fold 5 Validation Balanced Acc: 0.7055
  Fold 5 Validation MCC: 0.4157
  Fold 5 Validation Sensitivity: 0.6198
  Fold 5 Validation Specificity: 0.7911
  Fold 5 Validation Confusion Matrix:
  [[TN=674   FP=178  ]
   [FN=349   TP=569  ]]
  Predicting on test set using Fold 5's best model...




===== Fold 5 completed in 67.25 seconds =====

--- Cross-validation Summary for Hybrid GINE+ProtT5 (Edge Feats: True) (Avg +/- Std Dev on Val Folds) ---
  Avg Val loss           : 0.8500 ± 0.0380
  Avg Val accuracy       : 0.7069 ± 0.0044
  Avg Val balanced_acc   : 0.7084 ± 0.0049
  Avg Val mcc            : 0.4190 ± 0.0119
  Avg Val sensitivity    : 0.6675 ± 0.0399
  Avg Val specificity    : 0.7494 ± 0.0444

--- Final Test Set Performance for Hybrid GINE+ProtT5 (Edge Feats: True) (Ensemble: Majority Vote) ---
Ensemble Test Loss: 0.8500
Ensemble Test Accuracy: 0.7296
Ensemble Test Balanced Acc: 0.7162
Ensemble Test MCC: 0.2656
Ensemble Test Sensitivity: 0.7000
Ensemble Test Specificity: 0.7325
Ensemble Test Confusion Matrix:
  [[TN=1829  FP=668  ]
   [FN=72    TP=168  ]]

--- Evaluation for Hybrid GINE+ProtT5 (Edge Feats: True) completed in 502.88 seconds ---


Configuration        | AvgVal BAcc  | AvgVal MCC   | Test BAcc    | Test MCC    
----------------------------------------------