In [None]:
import ast
import os
import random
import re
import traceback # For detailed error reporting

import torch
import torch.nn as nn
import torch.nn.functional as F
# PyG imports
from torch_geometric.nn import GCNConv, GATv2Conv, PNAConv, GINEConv
from torch_geometric.nn import global_mean_pool, global_max_pool
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import degree
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score, balanced_accuracy_score
from sklearn.model_selection import StratifiedKFold
# from sklearn.preprocessing import RobustScaler # Not used directly in this version
import matplotlib.pyplot as plt # Keep if plotting results later

# --- Reproducibility: seed the Python, NumPy and PyTorch RNGs exactly once. ---
SEED = 42  # edit to rerun the whole notebook under a different seed
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Device selection (CUDA seeding only matters when a GPU is present) ---
if not torch.cuda.is_available():
    print("CUDA not available, running on CPU.")
else:
    torch.cuda.manual_seed_all(SEED)
    # Stricter cuDNN determinism (slower), enable if bit-exact reruns are needed:
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
print(f"Using Seed: {SEED}")

# Amino Acid Encodings
# The 20 canonical residues, used for the GNN one-hot node features.
VALID_AA = 'ARNDCQEGHILKMFPSTWYV'
# Same alphabet with the '-' padding symbol appended (index 20), used by the CNN.
AMINO_ACIDS = VALID_AA + '-'
AA_TO_INT = {residue: code for code, residue in enumerate(AMINO_ACIDS)}

# Length of the vector returned by create_edge_features (rich edge features).
EXPECTED_EDGE_FEATURE_DIM = 17

# Helper function for sequence encoding (needed by CNN)
def integer_encode_sequence(sequence):
    """Return one integer code per character of *sequence*.

    Codes are looked up in AA_TO_INT. Any character missing from that table
    falls back to the code of the padding symbol '-'. The returned list is what
    prepare_graph_data turns into the CNN's ``sequence`` tensor.
    """
    pad_code = AA_TO_INT['-']
    lookup = AA_TO_INT.get
    return [lookup(residue, pad_code) for residue in sequence]

# Function to create detailed Edge Features (Needed for GATv2 w/ EF, GINE)
def create_edge_features(i_orig, j_orig, dist_map, sequence, ss_string, sasa_vals, K_POS=16):
    """
    Create rich edge features (17 dims) between original indices i_orig and j_orig.

    Layout of the returned vector:
        0-3   one-hot distance bin: (0,4], (4,8], (8,12], >12. All zero if the
              distance is non-positive or NaN.
        4     inverse distance (0 when the distance is missing / non-positive)
        5     1.0 if the residues are sequence neighbours (|i-j| == 1)
        6     |i-j| / (seq_len - 1)
        7     1.0 if either endpoint is K_POS
        8     residue offset of the endpoint closer to K_POS, normalised by the
              largest possible offset from K_POS
        9-14  one-hot SS pair, order-independent: HH, EH, HL, EE, EL, LL
              (SS codes other than H/E, including '-', count as L)
        15    |SASA_i - SASA_j|
        16    mean SASA of the two residues

    An all-zero vector is returned when either index is outside the sequence,
    either residue is padding ('-'), or an index falls outside ``dist_map``.

    Args:
        i_orig, j_orig: Original 0-32 indices in the sequence window.
        dist_map: Full 33x33 distance map numpy array.
        sequence: Full 33-char sequence string.
        ss_string: Full 33-char secondary structure string.
        sasa_vals: Full 33-element SASA numpy array.
        K_POS: Absolute index (0-32) of the central Lysine.

    Returns:
        A fixed-size numpy array (17,) dtype=np.float32 with NaN/inf replaced by 0.
    """
    edge_features = np.zeros(EXPECTED_EDGE_FEATURE_DIM, dtype=np.float32)
    seq_len = len(sequence)  # Should be 33

    # Basic validity and padding checks
    if not (0 <= i_orig < seq_len and 0 <= j_orig < seq_len):
        return edge_features
    if sequence[i_orig] == '-' or sequence[j_orig] == '-':
        return edge_features
    if i_orig >= dist_map.shape[0] or j_orig >= dist_map.shape[1]:
        return edge_features  # Index out of bounds for dist_map

    distance = float(dist_map[i_orig, j_orig])
    # A NaN distance fails every `<=` test below and used to land in the
    # ">12 A" bin. Treat it as missing instead (no bin, zero inverse).
    if np.isnan(distance):
        distance = 0.0

    feature_idx = 0

    # 1. Distance Features (5 features) - 4 one-hot bins + inverse distance
    if distance <= 0:
        pass  # Missing / non-positive distance: leave every bin at zero
    elif distance <= 4.0:
        edge_features[feature_idx] = 1.0
    elif distance <= 8.0:
        edge_features[feature_idx + 1] = 1.0
    elif distance <= 12.0:
        edge_features[feature_idx + 2] = 1.0
    else:
        edge_features[feature_idx + 3] = 1.0
    feature_idx += 4
    edge_features[feature_idx] = (1.0 / distance) if distance > 1e-6 else 0.0
    feature_idx += 1

    # 2. Sequential Features (2 features) - Neighbor flag + normalized sequence distance
    seq_dist = abs(i_orig - j_orig)
    edge_features[feature_idx] = float(seq_dist == 1)
    edge_features[feature_idx + 1] = seq_dist / max(1, seq_len - 1)
    feature_idx += 2

    # 3. K-relative Features (2 features) - Touches K + normalized offset to K
    edge_features[feature_idx] = float(i_orig == K_POS or j_orig == K_POS)
    max_dist_from_k = max(1, K_POS, seq_len - 1 - K_POS)
    edge_features[feature_idx + 1] = min(abs(i_orig - K_POS), abs(j_orig - K_POS)) / max_dist_from_k
    feature_idx += 2

    # 4. Secondary Structure Interaction (6 features)
    # Pair keys are built with sorted(), i.e. alphabetical ('E' < 'H' < 'L'),
    # so the helix/strand key is 'EH'. The list previously spelled it 'HE',
    # which never matched and zeroed the SS block for every helix-strand edge.
    # The slot order is kept unchanged so the feature layout stays stable.
    if i_orig < len(ss_string) and j_orig < len(ss_string):
        ss_pairs = ['HH', 'EH', 'HL', 'EE', 'EL', 'LL']
        ss_map = {'H': 'H', 'E': 'E', 'L': 'L', '-': 'L'}  # Map padding to Loop
        ss_i = ss_map.get(ss_string[i_orig], 'L')
        ss_j = ss_map.get(ss_string[j_orig], 'L')
        ss_pair = ''.join(sorted([ss_i, ss_j]))
        for offset, pair in enumerate(ss_pairs):
            edge_features[feature_idx + offset] = float(ss_pair == pair)
    feature_idx += 6

    # 5. SASA Interaction (2 features) - Absolute difference + average
    if i_orig < len(sasa_vals) and j_orig < len(sasa_vals):
        sasa_i = np.nan_to_num(sasa_vals[i_orig])
        sasa_j = np.nan_to_num(sasa_vals[j_orig])
        edge_features[feature_idx] = abs(sasa_i - sasa_j)
        edge_features[feature_idx + 1] = (sasa_i + sasa_j) / 2.0
    feature_idx += 2

    return np.nan_to_num(edge_features, nan=0.0, posinf=0.0, neginf=0.0)

# Matches bare non-finite float tokens ("nan", "-inf", "Infinity", ...) inside a
# stringified numeric list, so they can be rewritten into a literal_eval-able form.
_NONFINITE_TOKEN_RE = re.compile(r"(?<![\w.])(-?)(nan|inf(?:inity)?)(?![\w.])", re.IGNORECASE)

# Errors meaning "this cell is not a parseable numeric list"; such rows are skipped quietly.
_PARSE_ERRORS = (ValueError, TypeError, SyntaxError, OverflowError, RecursionError)


def _parse_float_array(text):
    """Parse a stringified (possibly nested) numeric list into a float32 array.

    This replaces a bare ``eval`` on DataFrame cells. ``ast.literal_eval`` only
    accepts Python literals, so a malformed or hostile cell cannot execute code.

    Bare ``nan``/``inf`` tokens are rewritten before parsing. These are what
    ``str()`` of a float list emits, and ``eval`` rejected them with NameError,
    which silently dropped the whole row. The rewrites are:
      - ``nan`` -> ``None``; NumPy turns None into NaN for float dtypes.
      - ``inf`` -> ``1e999``; this overflows to infinity.
    """
    def _rewrite(match):
        if match.group(2).lower() == 'nan':
            return 'None'
        return '-1e999' if match.group(1) else '1e999'

    cleaned = _NONFINITE_TOKEN_RE.sub(_rewrite, str(text).strip())
    return np.array(ast.literal_eval(cleaned), dtype=np.float32)


def prepare_graph_data(df, distance_threshold=8.0, use_ss_node_feature=True, include_edge_features=False):
    """
    Prepare graph data for PyTorch Geometric, suitable for the Hybrid Model.

    Each row describes a 33-residue window whose index 16 must be 'K'. One graph
    is built per usable row:
      - Nodes are the non-padded residues.
      - Edges join residue pairs whose distance-map entry lies in
        (0, distance_threshold).
      - If no such pair exists, edges fall back to sequential neighbours
        (both directions).
    Rows failing any validation or parsing step are counted as skipped.

    Node features, as currently configured (27 dims):
      - 20-dim amino-acid one-hot
      - central-K indicator
      - sin/cos of phi, psi and omega
    SASA, SS and pLDDT node features are computed but their ``append`` calls are
    commented out.

    Args:
        df (pd.DataFrame): Input DataFrame. Needs 'sequence', 'label', and every
            column in ``feature_names_to_parse``. Those columns hold stringified
            numeric lists, except 'ss', which is a plain 33-char string.
        distance_threshold (float): Max distance (exclusive) for graph edges.
        use_ss_node_feature (bool): Intended to toggle the SS one-hot node
            feature. NOTE(review): the SS append is currently commented out, so
            this flag only affects the log line.
        include_edge_features (bool): Whether to generate 17-dim ``edge_attr``
            via ``create_edge_features``.

    Returns:
        tuple: (list_of_Data_objects, list_of_labels), aligned index by index.
    """
    graph_list = []
    labels = []
    skipped = 0
    expected_seq_len = 33
    central_k_pos_abs = 16  # 0-based index in the 33-residue window

    # Columns expected to be lists/arrays represented as strings ('ss' is a plain string)
    feature_names_to_parse = [
        'phi', 'psi', 'omega', 'tau', 'chi1', 'chi2', 'chi3', 'chi4',
        'sasa', 'ss', 'plDDT', 'distance_map'
    ]
    # Reduced list for testing if full features cause issues:
    # feature_names_to_parse = ['phi', 'psi', 'omega', 'sasa', 'ss', 'plDDT', 'distance_map']

    print(f"Preparing data (Edge Features: {include_edge_features}, SS Node Feat: {use_ss_node_feature}, Threshold: {distance_threshold}Å)...")

    for idx, row in df.iterrows():
        try:
            sequence = row['sequence']  # Full 33-char sequence
            label = row['label']

            # --- Initial Validation ---
            if pd.isna(sequence) or len(sequence) != expected_seq_len or sequence[central_k_pos_abs] != 'K':
                skipped += 1
                continue
            if pd.isna(label):
                # A NaN label would become y=NaN and poison the training loss.
                skipped += 1
                continue

            # --- Parse Required Data from row ---
            parsed_data = {}
            valid_row = True
            for name in feature_names_to_parse:
                value = row[name] if name in row else None
                # pd.isna is element-wise on list-like cells, and using that array
                # in an `if` raises. Only ask it about scalar cells.
                is_list_like = isinstance(value, (list, tuple, np.ndarray))
                if value is None or (not is_list_like and pd.isna(value)):
                    valid_row = False
                    break
                try:
                    if name == 'ss':
                        ss_value = str(value)  # Keep SS as string
                        if len(ss_value) != expected_seq_len:
                            raise ValueError("SS length mismatch")
                        parsed_data[name] = ss_value
                    else:
                        parsed_data[name] = _parse_float_array(value)
                except _PARSE_ERRORS:
                    valid_row = False
                    break
            if not valid_row:
                skipped += 1
                continue

            # Reshape distance map; skip rows whose map has the wrong size
            if parsed_data['distance_map'].size != expected_seq_len * expected_seq_len:
                skipped += 1
                continue
            distance_map = parsed_data['distance_map'].reshape(expected_seq_len, expected_seq_len)

            # --- Identify Valid (Non-padded) Positions ---
            valid_pos_indices = [i for i, aa in enumerate(sequence) if aa != '-']
            if not valid_pos_indices:
                skipped += 1
                continue
            num_nodes = len(valid_pos_indices)
            valid_sequence = ''.join(sequence[i] for i in valid_pos_indices)

            # Index of the central K *within the valid nodes*. This cannot fail
            # after the sequence[16] == 'K' check above; kept defensively.
            try:
                central_k_new_idx = valid_pos_indices.index(central_k_pos_abs)
            except ValueError:
                skipped += 1
                continue

            # --- Node Feature Extraction (Only for Valid Nodes) ---
            node_features_list = []
            # 1. AA One-hot (20 features); letters outside VALID_AA stay all-zero
            aa_onehot = np.zeros((num_nodes, len(VALID_AA)), dtype=np.float32)
            for i, aa in enumerate(valid_sequence):
                aa_idx = VALID_AA.find(aa)
                if aa_idx >= 0:
                    aa_onehot[i, aa_idx] = 1.0
            node_features_list.append(aa_onehot)

            # 2. Central K indicator (1 feature) - Relative to valid nodes
            is_central_k = np.zeros((num_nodes, 1), dtype=np.float32)
            is_central_k[central_k_new_idx, 0] = 1.0
            node_features_list.append(is_central_k)

            # 3. Angles -> sin/cos (2 features each), NaNs mapped to 0 degrees first
            # angle_keys = ['phi', 'psi', 'omega', 'tau', 'chi1', 'chi2', 'chi3', 'chi4'] # All angles
            angle_keys = ['phi', 'psi', 'omega']  # Reduced set
            for key in angle_keys:
                if key in parsed_data:
                    valid_angles = np.nan_to_num(parsed_data[key][valid_pos_indices])
                    angle_rad = np.pi * valid_angles / 180.0
                    sin_cos_features = np.stack([np.sin(angle_rad), np.cos(angle_rad)], axis=-1)
                    node_features_list.append(sin_cos_features.astype(np.float32))

            # 4-6. SASA / SS / pLDDT node features: currently disabled (appends
            # commented out). The slicing still runs, so rows whose arrays are too
            # short are still rejected as before.
            if 'sasa' in parsed_data:
                valid_sasa = parsed_data['sasa'][valid_pos_indices].reshape(-1, 1)
                # node_features_list.append(np.nan_to_num(valid_sasa).astype(np.float32))

            if use_ss_node_feature and 'ss' in parsed_data:
                ss_string = parsed_data['ss']
                valid_ss_chars = [ss_string[i] for i in valid_pos_indices]
                ss_onehot = np.zeros((num_nodes, 3), dtype=np.float32)  # H=0, E=1, L=2
                ss_map = {'H': 0, 'E': 1, 'L': 2, '-': 2}  # Map '-' to L just in case
                for i, ss_char in enumerate(valid_ss_chars):
                    ss_onehot[i, ss_map.get(ss_char, 2)] = 1.0
                # node_features_list.append(ss_onehot)

            if 'plDDT' in parsed_data:
                valid_plddt = parsed_data['plDDT'][valid_pos_indices].reshape(-1, 1)
                # node_features_list.append(np.nan_to_num(valid_plddt).astype(np.float32))

            # --- Concatenate all node features ---
            try:
                if not node_features_list:  # Should not happen: AA one-hot is always added
                    print(f"Warning: No node features generated for row {idx}. Skipping.")
                    skipped += 1
                    continue
                node_features = np.concatenate(node_features_list, axis=1)
            except ValueError as e:  # e.g. per-feature row counts disagree
                print(f"Error concatenating node features for row {idx}: {e}")
                skipped += 1
                continue

            # --- Edge Construction (indices relative to valid_pos_indices) ---
            valid_distance_map = distance_map[np.ix_(valid_pos_indices, valid_pos_indices)]
            adj = (valid_distance_map < distance_threshold) & (valid_distance_map > 0)
            np.fill_diagonal(adj, False)  # No self-loops from distance threshold
            edges = np.argwhere(adj).tolist()

            # Fallback: sequential edges (both directions) ONLY if no distance edges
            if not edges and num_nodes > 1:
                for i_valid in range(num_nodes - 1):
                    edges.extend([[i_valid, i_valid + 1], [i_valid + 1, i_valid]])

            # Skip graph if it still has no edges
            if not edges:
                skipped += 1
                continue

            # One feature vector per edge, in edge order (distance or fallback edges alike)
            edge_features_list = []
            if include_edge_features:
                sasa_full = parsed_data.get('sasa', np.full(expected_seq_len, np.nan))  # Handle missing sasa
                ss_full = parsed_data.get('ss', '-' * expected_seq_len)  # Handle missing ss
                edge_features_list = [
                    create_edge_features(valid_pos_indices[i_valid], valid_pos_indices[j_valid],
                                         distance_map, sequence, ss_full, sasa_full)
                    for i_valid, j_valid in edges
                ]

            # --- Convert to PyTorch Tensors ---
            x_tensor = torch.tensor(node_features, dtype=torch.float)
            edge_index_tensor = torch.tensor(edges, dtype=torch.long).t().contiguous()
            y_tensor = torch.tensor([label], dtype=torch.float)
            # Integer encode the FULL sequence (with padding) for the CNN
            sequence_tensor = torch.tensor(integer_encode_sequence(sequence), dtype=torch.long)
            # Index of the central K relative to the VALID nodes
            central_node_idx_tensor = torch.tensor([central_k_new_idx], dtype=torch.long)

            data_dict = {
                'x': x_tensor,
                'edge_index': edge_index_tensor,
                'y': y_tensor,
                'sequence': sequence_tensor,  # Full sequence for CNN
                'central_node_idx': central_node_idx_tensor  # Index within valid nodes
            }
            if include_edge_features and edge_features_list:
                edge_attr_tensor = torch.tensor(np.array(edge_features_list), dtype=torch.float)
                if edge_attr_tensor.shape[0] == edge_index_tensor.shape[1] and \
                   edge_attr_tensor.shape[1] == EXPECTED_EDGE_FEATURE_DIM:
                    data_dict['edge_attr'] = edge_attr_tensor
                else:
                    print(f"Warning: Edge feature dim mismatch row {idx}. Attr:{edge_attr_tensor.shape}, Idx:{edge_index_tensor.shape}. Skipping edge_attr.")

            graph_list.append(Data(**data_dict))
            labels.append(label)

        except Exception as e:
            # Top-level per-row boundary: report loudly, skip the row, keep going.
            print(f"--- Critical Error processing row {idx}: {e} ---")
            traceback.print_exc()
            skipped += 1
            continue

    print(f"Created {len(graph_list)} graphs, skipped {skipped} rows.")
    if graph_list:
        print(f" Node feature dimension: {graph_list[0].x.shape[1]}")
        if 'edge_attr' in graph_list[0]:
            print(f" Edge feature dimension: {graph_list[0].edge_attr.shape[1]}")
        else:
            print(f" Edge features not included or generated.")
    return graph_list, labels

# These GNNs return ALL final node features, suitable for pooling/readout in HybridModel
class GCNNetwork(nn.Module):
    """Stack of GCNConv layers, each followed by BatchNorm, ReLU and dropout.

    Every layer after the first adds a residual connection when the input and
    output shapes match. ``forward`` returns the embedding of every node.
    """

    def __init__(self, input_dim, hidden_dim=128, dropout=0.4, layers=3):
        super().__init__()
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        # One layer is always built; `layers - 1` extra hidden->hidden layers follow.
        in_dims = [input_dim] + [hidden_dim] * max(0, layers - 1)
        for in_dim in in_dims:
            self.convs.append(GCNConv(in_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.dropout = nn.Dropout(dropout)
        self.output_dim = hidden_dim

    def forward(self, x, edge_index):
        """Return final node features, shape [num_nodes, hidden_dim].

        Args:
            x: node feature matrix, presumably [num_nodes, input_dim].
            edge_index: graph connectivity, presumably [2, num_edges].
        """
        for layer_no, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            residual = x
            x = conv(x, edge_index)
            # BatchNorm is skipped when the node set is empty
            if x.shape[0] > 0:
                x = norm(x)
            x = F.relu(x)
            if layer_no > 0 and residual.shape == x.shape:
                x = x + residual
            x = self.dropout(x)
        return x

class GATv2Network(nn.Module):
    """Stack of multi-head GATv2Conv layers returning every node's final embedding.

    Heads are concatenated (``concat=True``). Each head is ``hidden_dim // heads``
    wide, so every layer outputs ``hidden_dim`` features. When ``edge_dim`` is
    given, each layer also consumes ``edge_attr``. Each layer is followed by
    BatchNorm, ELU and dropout. Layers after the first add a residual
    connection when the input and output shapes match.
    """

    def __init__(self, input_dim, hidden_dim=128, heads=4, dropout=0.4, layers=3, edge_dim=None):
        super().__init__()
        if hidden_dim % heads != 0:
            raise ValueError("GATv2Network: hidden_dim must be divisible by heads")

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.has_edge_features = edge_dim is not None
        per_head = hidden_dim // heads

        # One layer is always built; `layers - 1` extra hidden->hidden layers follow.
        for in_dim in [input_dim] + [hidden_dim] * max(0, layers - 1):
            self.convs.append(GATv2Conv(in_dim, per_head, heads=heads, concat=True,
                                        dropout=dropout, edge_dim=edge_dim, add_self_loops=True))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.dropout_layer = nn.Dropout(dropout)
        self.output_dim = hidden_dim

    def forward(self, x, edge_index, edge_attr=None):
        """Return final node features, shape [num_nodes, hidden_dim].

        ``edge_attr`` is required when the network was built with ``edge_dim``,
        and ignored otherwise. If a layer yields zero nodes, that empty tensor
        is returned immediately.
        """
        if self.has_edge_features and edge_attr is None:
            raise ValueError("GATv2Network configured with edge_dim but edge_attr not provided.")
        extra = {'edge_attr': edge_attr} if self.has_edge_features else {}

        for layer_no, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            residual = x
            x = conv(x, edge_index, **extra)
            if x.shape[0] == 0:
                return x
            x = F.elu(norm(x))
            if layer_no > 0 and residual.shape == x.shape:
                x = x + residual
            x = self.dropout_layer(x)
        return x

class PNANetwork(nn.Module):
    """Stack of PNAConv layers (no edge features) returning every node's embedding.

    Every layer uses mean/min/max/std aggregation, identity/amplification/
    attenuation scalers, the degree histogram ``deg``, 4 towers and one
    post-layer. Each layer is followed by BatchNorm, ReLU and dropout. Layers
    after the first add a residual connection when the input and output
    shapes match.
    """

    def __init__(self, input_dim, hidden_dim=128, layers=3, dropout=0.4, deg=None):
        super().__init__()
        if deg is None:
            raise ValueError("PNANetwork requires the degree histogram 'deg' argument.")

        pna_options = dict(
            aggregators=['mean', 'min', 'max', 'std'],
            scalers=['identity', 'amplification', 'attenuation'],
            deg=deg, towers=4, post_layers=1,
        )
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        # One layer is always built; `layers - 1` extra hidden->hidden layers follow.
        for in_dim in [input_dim] + [hidden_dim] * max(0, layers - 1):
            self.convs.append(PNAConv(in_dim, hidden_dim, **pna_options))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.dropout_layer = nn.Dropout(dropout)
        self.output_dim = hidden_dim

    def forward(self, x, edge_index, edge_attr=None):
        """Return final node features, shape [num_nodes, hidden_dim].

        ``edge_attr`` is accepted for interface parity and ignored. If a layer
        yields zero nodes, that empty tensor is returned immediately.
        """
        for layer_no, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            residual = x
            x = conv(x, edge_index)
            if x.shape[0] == 0:
                return x
            x = F.relu(norm(x))
            if layer_no > 0 and residual.shape == x.shape:
                x = x + residual
            x = self.dropout_layer(x)
        return x

class GINENetwork(nn.Module):
    """Stack of GINEConv layers (edge features required) returning every node's embedding.

    Each GINEConv wraps a small MLP (Linear -> ReLU -> Dropout -> Linear, with a
    2x hidden width) and has a trainable epsilon. Each layer is followed by
    BatchNorm, ReLU and dropout. Layers after the first add a residual
    connection when the input and output shapes match.
    """

    def __init__(self, input_dim, hidden_dim=128, edge_dim=EXPECTED_EDGE_FEATURE_DIM, dropout=0.4, layers=3):
        super().__init__()
        if edge_dim is None:
            raise ValueError("GINENetwork requires a valid 'edge_dim'.")

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.edge_dim = edge_dim

        # One layer is always built; `layers - 1` extra hidden->hidden layers follow.
        for in_dim in [input_dim] + [hidden_dim] * max(0, layers - 1):
            inner_mlp = nn.Sequential(
                nn.Linear(in_dim, hidden_dim * 2), nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim * 2, hidden_dim),
            )
            self.convs.append(GINEConv(nn=inner_mlp, edge_dim=self.edge_dim, train_eps=True))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.dropout_layer = nn.Dropout(dropout)
        self.output_dim = hidden_dim

    def forward(self, x, edge_index, edge_attr):
        """Return final node features, shape [num_nodes, hidden_dim].

        Raises ValueError if ``edge_attr`` is missing or its second dimension is
        not ``edge_dim``. If a layer yields zero nodes, that empty tensor is
        returned immediately.
        """
        if edge_attr is None:
            raise ValueError("GINENetwork requires edge_attr for forward pass.")
        if edge_attr.shape[1] != self.edge_dim:
            raise ValueError(f"GINENetwork edge_attr dim mismatch. Expected {self.edge_dim}, got {edge_attr.shape[1]}")

        for layer_no, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            residual = x
            x = conv(x, edge_index, edge_attr=edge_attr)
            if x.shape[0] == 0:
                return x
            x = F.relu(norm(x))
            if layer_no > 0 and residual.shape == x.shape:
                x = x + residual
            x = self.dropout_layer(x)
        return x


class SequenceCNN(nn.Module):
    """2-D CNN over the embedded residue sequence, emitting 32 features per sample.

    The embedded sequence is treated as a one-channel image of size
    (seq_len, embed_dim). The pipeline is:
      1. Embedding ('-' is the padding index)
      2. 'valid' Conv2d -> BatchNorm -> ReLU -> Dropout -> 2x2 max-pool
      3. Flatten -> Linear(32) -> BatchNorm -> ReLU -> Dropout
    """

    def __init__(self, vocab_size=len(AMINO_ACIDS), embed_dim=21, out_channels=32,
                 kernel_height=17, kernel_width=3, dropout=0.4, seq_len=33):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=AA_TO_INT['-'])
        # 'valid' padding: output height = seq_len - kernel_height + 1
        self.conv2d = nn.Conv2d(
            in_channels=1, out_channels=out_channels,
            kernel_size=(kernel_height, kernel_width), padding='valid',
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.dropout1 = nn.Dropout(dropout)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.flatten = nn.Flatten()

        # Run one dummy sample through conv + pool to size the dense layer.
        with torch.no_grad():
            probe = torch.randn(1, 1, seq_len, embed_dim)
            flat_size = self.pool(self.conv2d(probe)).numel()

        head_width = 32  # output dimension of the CNN track
        self.fc1 = nn.Linear(flat_size, head_width)
        self.bn2 = nn.BatchNorm1d(head_width)
        self.dropout2 = nn.Dropout(dropout)
        self.output_dim = head_width

        print(f"SequenceCNN initialized. Calculated flat size: {flat_size}")

    def forward(self, seq_indices):
        """Map integer-coded sequences (presumably [batch, seq_len]) to [batch, 32]."""
        feats = self.embedding(seq_indices).unsqueeze(1)  # add the single image channel
        feats = self.conv2d(feats)
        if feats.shape[0] > 0:  # BatchNorm only on a non-empty batch
            feats = self.bn1(feats)
        feats = self.pool(self.dropout1(F.relu(feats)))

        feats = self.fc1(self.flatten(feats))
        if feats.shape[0] > 0:
            feats = self.bn2(feats)
        return self.dropout2(F.relu(feats))


class HybridModel(nn.Module):
    """
    Combines a GNN backbone with a Sequence CNN track.
    GNN output is processed via Central Node + Global Mean + Global Max pooling.

    The three GNN readout vectors and the CNN features are concatenated and
    passed through a two-layer MLP ending in a sigmoid, so ``forward`` returns
    per-graph probabilities of shape ``[num_graphs, 1]``.
    """
    def __init__(self, gnn_type, node_feature_dim, hidden_dim=128,
                 # GNN specific args
                 deg=None, # Required for PNA
                 edge_dim=None, # Required for GINE, Optional for GATv2
                 heads=4, # Required for GATv2
                 # Common GNN hypers
                 layers=3,
                 dropout=0.4,
                 # Sequence CNN hypers (can be customized or defaults used)
                 vocab_size=len(AMINO_ACIDS),
                 embed_dim=21,
                 cnn_out_channels=32,
                 cnn_kernel_height=17,
                 cnn_kernel_width=3,
                 cnn_dropout=0.4,
                 seq_len=33,
                 # Final Classifier hypers
                 classifier_dropout=0.5):
        super().__init__()
        self.gnn_type = gnn_type
        # Edge features are only forwarded to backbones that consume them.
        self.use_edge_features = edge_dim is not None and gnn_type in ['gatv2', 'gine']
        # Remembered so forward() reshapes the flat sequence tensor with the same
        # length the CNN was built for (previously hard-coded to 33 in forward).
        self.seq_len = seq_len

        # --- GNN Backbone Instantiation ---
        gnn_kwargs = {
            'input_dim': node_feature_dim, 'hidden_dim': hidden_dim,
            'dropout': dropout, 'layers': layers
        }
        if gnn_type == 'gcn':
            self.gnn = GCNNetwork(**gnn_kwargs)
        elif gnn_type == 'gatv2':
            self.gnn = GATv2Network(**gnn_kwargs, heads=heads, edge_dim=edge_dim)
        elif gnn_type == 'pna':
            if deg is None: raise ValueError("PNA requires 'deg' histogram.")
            self.gnn = PNANetwork(**gnn_kwargs, deg=deg)
        elif gnn_type == 'gine':
            if edge_dim is None: raise ValueError("GINE requires 'edge_dim'.")
            # Pass the actual edge_dim value expected by GINE
            self.gnn = GINENetwork(**gnn_kwargs, edge_dim=edge_dim)
        else:
            raise ValueError(f"Unsupported GNN type: {gnn_type}")

        # --- Sequence CNN Instantiation ---
        self.sequence_cnn = SequenceCNN(
            vocab_size=vocab_size, embed_dim=embed_dim, out_channels=cnn_out_channels,
            kernel_height=cnn_kernel_height, kernel_width=cnn_kernel_width,
            dropout=cnn_dropout, seq_len=seq_len
        )

        # --- Combination Layers ---
        gnn_output_dim = self.gnn.output_dim
        cnn_output_dim = self.sequence_cnn.output_dim

        # Input: Central Node + Global Mean Pool + Global Max Pool + CNN Features
        combined_input_dim = gnn_output_dim + gnn_output_dim + gnn_output_dim + cnn_output_dim
        print(f"HybridModel Combined Input Dim: {combined_input_dim} ({gnn_output_dim}+{gnn_output_dim}+{gnn_output_dim}+{cnn_output_dim})")

        # Simple classifier MLP
        self.fc1 = nn.Linear(combined_input_dim, 64) # Example hidden size
        self.bn_combine = nn.BatchNorm1d(64) # Use unique name
        self.dropout_combine = nn.Dropout(classifier_dropout)
        self.fc2 = nn.Linear(64, 1) # Final output

    def forward(self, data):
        """Run the GNN and sequence-CNN tracks on a graph batch.

        Args:
            data: graph batch providing ``x``, ``edge_index`` and ``num_graphs``.
                ``central_node_idx`` holds one index per graph, relative to the
                graph's first node. ``sequence`` is the flattened token tensor of
                ``num_graphs * self.seq_len`` ids. ``edge_attr``, ``batch`` and
                ``ptr`` are optional.

        Returns:
            Sigmoid outputs of shape ``[num_graphs, 1]``. If readout or the
            sequence input is unusable, a warning is printed and zero features
            are used for that part instead of raising.

        Raises:
            ValueError: if the backbone uses edge features but ``edge_attr`` is absent.
        """
        # --- Extract Data (optional attributes default to None) ---
        x, edge_index = data.x, data.edge_index
        edge_attr = getattr(data, 'edge_attr', None)
        batch = getattr(data, 'batch', None)  # Crucial for pooling
        central_node_idx = getattr(data, 'central_node_idx', None)
        ptr = getattr(data, 'ptr', None)  # For calculating absolute indices
        sequence = getattr(data, 'sequence', None)  # Flat [batch_size * seq_len]
        batch_size = data.num_graphs
        # Allocate fallback tensors on the model's own device rather than a global.
        out_device = next(self.parameters()).device

        # Handle missing batch/ptr (e.g., single graph inference)
        if batch is None and x is not None:
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        if ptr is None and batch is not None:
            if batch.numel() > 0:
                # minlength keeps one slot per graph even if trailing graphs have no nodes.
                counts = torch.bincount(batch, minlength=batch_size)
                ptr = torch.cat([torch.zeros(1, dtype=counts.dtype, device=batch.device),
                                 counts.cumsum(0)])
            else:
                ptr = torch.zeros(1, dtype=torch.long, device=batch.device)  # Empty batch

        # --- GNN Track: features for ALL nodes [N_total, gnn_hidden_dim] ---
        if self.use_edge_features:
            if edge_attr is None:
                raise ValueError(f"Model '{self.gnn_type}' uses edge features, but 'edge_attr' not in data.")
            gnn_node_features = self.gnn(x, edge_index, edge_attr)
        else:
            gnn_node_features = self.gnn(x, edge_index)

        # --- GNN Readout (Central Node + Pooling) ---
        gnn_out_dim = self.gnn.output_dim
        # Zero fallbacks, used when readout cannot be performed.
        central_node_features = torch.zeros((batch_size, gnn_out_dim), device=out_device)
        global_avg_features = torch.zeros((batch_size, gnn_out_dim), device=out_device)
        global_max_features = torch.zeros((batch_size, gnn_out_dim), device=out_device)

        nodes_exist = gnn_node_features is not None and gnn_node_features.shape[0] > 0
        indices_valid = (ptr is not None and ptr.numel() > 1
                         and central_node_idx is not None
                         and central_node_idx.numel() == batch_size
                         and batch is not None and batch.numel() > 0)

        if nodes_exist and indices_valid:
            try:
                # Central node: graph start offset + per-graph relative index.
                graph_starts = ptr[:-1]
                absolute_central_node_indices = graph_starts + central_node_idx

                # Boundary check before gathering
                max_idx = absolute_central_node_indices.max().item() if absolute_central_node_indices.numel() > 0 else -1
                min_idx = absolute_central_node_indices.min().item() if absolute_central_node_indices.numel() > 0 else 0

                if max_idx < gnn_node_features.shape[0] and min_idx >= 0:
                    central_node_features = gnn_node_features[absolute_central_node_indices]
                elif batch_size > 0: # Only warn if expected nodes but index failed
                    print(f"Warning: Central node index out of bounds (Max: {max_idx}, Min: {min_idx} vs Size: {gnn_node_features.shape[0]}). Using zeros.")

                # Global Pooling (requires valid batch vector)
                if batch.max() < batch_size: # Check batch indices are valid
                    # size=batch_size yields exactly one row per graph, even when
                    # trailing graphs contribute no nodes.
                    global_avg_features = global_mean_pool(gnn_node_features, batch, size=batch_size)
                    global_max_features = global_max_pool(gnn_node_features, batch, size=batch_size)
                elif batch_size > 0:
                    print(f"Warning: Batch index issue during pooling (Max: {batch.max()} vs BatchSize: {batch_size}). Using zeros.")

            except IndexError as e:
                print(f"Error during GNN readout: {e}")
                # Fallback to zeros already initialized
            except Exception as e: # Catch other potential errors
                print(f"Unexpected error during GNN readout: {e}")
                # Fallback to zeros

        elif batch_size > 0 : # Only warn if batch_size > 0 but readout failed
            print("Warning: Skipping GNN readout due to missing nodes or invalid indices.")

        # --- Sequence CNN Track ---
        seq_len = self.seq_len
        expected_numel = batch_size * seq_len
        if sequence is not None and sequence.numel() == expected_numel:
            seq_features = self.sequence_cnn(sequence.view(batch_size, seq_len))
        elif batch_size > 0: # Sequence missing or wrong size, but batch exists
            found_numel = sequence.numel() if sequence is not None else 'missing'
            print(f"Warning: Sequence tensor issue (Size: {found_numel} vs Expected: {expected_numel}). Using zeros for CNN.")
            seq_features = torch.zeros((batch_size, self.sequence_cnn.output_dim), device=out_device)
        else: # No batch, no sequence processing needed
            seq_features = torch.zeros((0, self.sequence_cnn.output_dim), device=out_device)

        # --- Combine Features ---
        # Guard against row-count disagreement left by a failed readout/CNN step.
        current_batch_size = central_node_features.shape[0]
        if not (current_batch_size == global_avg_features.shape[0] ==
                global_max_features.shape[0] == seq_features.shape[0]):
            print(f"Warning: Feature batch size mismatch before concat. Central:{central_node_features.shape}, Avg:{global_avg_features.shape}, Max:{global_max_features.shape}, Seq:{seq_features.shape}. Attempting concat with min size.")
            min_b = min(central_node_features.shape[0], global_avg_features.shape[0], global_max_features.shape[0], seq_features.shape[0])
            if min_b <= 0: # Cannot combine if any feature set is empty
                return torch.zeros((batch_size, 1), device=out_device)
            combined_features = torch.cat([
                central_node_features[:min_b],
                global_avg_features[:min_b],
                global_max_features[:min_b],
                seq_features[:min_b]
            ], dim=1)
        else:
            combined_features = torch.cat([
                central_node_features,
                global_avg_features,
                global_max_features,
                seq_features
            ], dim=1)

        # --- Final Classification MLP ---
        x_out = self.fc1(combined_features)
        if x_out.shape[0] > 0: # Apply BN only if batch is not empty
            x_out = self.bn_combine(x_out)
        else: # Handle empty batch case after combination
            return torch.zeros((batch_size, 1), device=out_device)

        x_out = F.relu(x_out)
        x_out = self.dropout_combine(x_out)
        x_out = self.fc2(x_out)

        return torch.sigmoid(x_out)


def train_model(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer, device: torch.device, class_weights: dict = None):
    """Train `model` for one epoch over `loader`.

    Args:
        model: module whose ``forward(batch)`` returns probabilities of shape
            ``[batch_size, 1]`` (they are fed to ``binary_cross_entropy``).
        loader: iterable of batches exposing ``num_graphs``, ``x``, ``y`` and ``.to()``.
        optimizer: optimizer stepped once per successfully processed batch.
        device: device every batch is moved to.
        class_weights: optional ``{0: w0, 1: w1}``; each sample's BCE term is
            scaled by the weight of its label. A batch whose labels fall outside
            {0, 1}, or whose label is missing from the dict, is trained unweighted.

    Returns:
        ``(avg_loss, avg_acc)`` over the samples that were actually trained on.
        Skipped batches (empty, output/target size mismatch, or raising) count
        toward neither. Both values are 0 if nothing was trained.
    """
    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for batch in loader:
        batch = batch.to(device)
        if batch.num_graphs == 0 or (hasattr(batch, 'x') and batch.x is not None and batch.x.shape[0] == 0):
            continue # Skip empty/invalid batches

        optimizer.zero_grad()

        try:
            output = model(batch) # Get model predictions
            target = batch.y.view(-1, 1) # Ensure target shape is [batch_size, 1]

            # Check for shape mismatch AFTER model forward pass (in case model returns empty)
            if output.shape[0] != target.shape[0]:
                print(f"Warning: Train - Mismatch output ({output.shape[0]}) vs target ({target.shape[0]}). Skipping batch.")
                continue
            if target.numel() == 0: continue # Skip if target is empty

            # Per-sample class weights
            weight = None
            if class_weights is not None:
                try:
                    target_long = target.long()
                    is_binary = (target_long == 0) | (target_long == 1)
                    if is_binary.all():
                        # One masked fill per class instead of a per-sample .item() loop
                        # (each .item() forces a host/device sync).
                        weight = torch.empty(target.shape, dtype=torch.float, device=target.device)
                        for label in (0, 1):
                            label_mask = target_long == label
                            if label_mask.any():  # only look up labels present in the batch
                                weight[label_mask] = float(class_weights[label])
                    else:
                        print(f"Warning: Invalid target values found for class weights: {target[~is_binary].unique()}")
                        # Skipping weights for safety.
                except KeyError as e:
                    weight = None  # discard a partially filled weight tensor
                    print(f"Warning: Target label {e} not found in class_weights {class_weights.keys()}. Skipping weights for this batch.")
                except Exception as e:
                    weight = None
                    print(f"Error preparing class weights: {e}. Skipping weights.")

            current_loss = F.binary_cross_entropy(output, target, weight=weight)

            # Backpropagation
            current_loss.backward()
            # Optional: Gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Accumulate loss (batch mean * batch size) and the sample count together,
            # so the epoch average only covers batches that were actually trained on.
            batch_n = target.size(0)
            total_loss += current_loss.item() * batch_n
            total_samples += batch_n

            # Calculate training accuracy
            with torch.no_grad():
                pred = (output > 0.5).float()
                correct_predictions += (pred == target).sum().item()

        except Exception as e:
            print(f"--- Error during training batch: {e} ---")
            traceback.print_exc()
            continue # Skip batch on error

    # Averages over processed samples (not len(loader.dataset), which would
    # silently include skipped batches and deflate the loss).
    avg_loss = total_loss / total_samples if total_samples > 0 else 0
    avg_acc = correct_predictions / total_samples if total_samples > 0 else 0
    return avg_loss, avg_acc


def evaluate_model(model: nn.Module, loader: DataLoader, device: torch.device):
    """Evaluate the HybridModel on `loader` without updating weights.

    The model output is treated as a probability of shape ``[batch_size, 1]``
    (it goes straight into ``binary_cross_entropy``). Predictions are obtained by
    thresholding at 0.5. Batches that are empty, whose output/target sizes
    disagree, or that raise are skipped.

    Returns:
        dict with keys:
            - 'loss': mean BCE per evaluated sample.
            - 'accuracy', 'balanced_acc', 'mcc', 'sensitivity', 'specificity'.
            - 'confusion_matrix': 2x2, labels [0, 1].
            - 'predictions': thresholded 0/1 values.
            - 'targets'.
            - 'probabilities': raw model outputs aligned with 'predictions',
              suitable for averaging across models.
        If no batch could be evaluated, every metric is 0 and the arrays are empty.
    """
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_probs = []
    all_targets = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            if batch.num_graphs == 0 or (hasattr(batch, 'x') and batch.x is not None and batch.x.shape[0] == 0):
                continue

            try:
                output = model(batch)
                target = batch.y.view(-1, 1)

                if output.shape[0] != target.shape[0]:
                    print(f"Warning: Eval - Mismatch output ({output.shape[0]}) vs target ({target.shape[0]}). Skipping batch.")
                    continue
                if target.numel() == 0: continue

                # reduction='sum' so the final division gives a true per-sample mean
                loss = F.binary_cross_entropy(output, target, reduction='sum')
                total_loss += loss.item()

                pred = (output > 0.5).float() # Threshold predictions
                all_probs.append(output.cpu().numpy())
                all_preds.append(pred.cpu().numpy())
                all_targets.append(target.cpu().numpy())

            except Exception as e:
                print(f"--- Error during evaluation batch: {e} ---")
                traceback.print_exc()
                continue # Skip batch on error

    # Handle cases where loader was empty or all batches skipped
    if not all_targets:
        default_metrics = {
            'loss': 0.0, 'accuracy': 0.0, 'balanced_acc': 0.0, 'mcc': 0.0,
            'sensitivity': 0.0, 'specificity': 0.0,
            'confusion_matrix': np.zeros((2, 2), dtype=int),
            'predictions': np.array([]), 'targets': np.array([]),
            'probabilities': np.array([])
        }
        print("Warning: Evaluation yielded no targets/predictions.")
        return default_metrics

    all_preds = np.concatenate(all_preds).flatten()
    all_probs = np.concatenate(all_probs).flatten()
    all_targets = np.concatenate(all_targets).flatten()
    num_samples = len(all_targets)

    accuracy = accuracy_score(all_targets, all_preds)
    balanced_acc = balanced_accuracy_score(all_targets, all_preds)
    mcc = matthews_corrcoef(all_targets, all_preds)

    # labels=[0, 1] forces a 2x2 matrix even if only one class appears
    cm = confusion_matrix(all_targets, all_preds, labels=[0, 1])

    sensitivity, specificity = 0.0, 0.0
    if cm.shape == (2,2):
        tn, fp, fn, tp = cm.ravel()
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    metrics = {
        'loss': total_loss / num_samples if num_samples > 0 else 0, # Avg loss per sample
        'accuracy': accuracy, 'balanced_acc': balanced_acc, 'mcc': mcc,
        'sensitivity': sensitivity, 'specificity': specificity,
        'confusion_matrix': cm,
        'predictions': all_preds, 'targets': all_targets,
        'probabilities': all_probs
    }
    return metrics

def print_metrics(metrics: dict):
    """Print an evaluation-metrics dictionary as indented lines.

    Missing scalar entries are shown as 0 (NaN for the loss). A confusion matrix
    is printed on its own lines only when it is a numpy array; otherwise "N/A".
    """
    get = metrics.get
    print(f"  Loss: {get('loss', float('nan')):.4f}")
    print(f"  Accuracy: {get('accuracy', 0):.4f}, "
          f"Balanced Acc: {get('balanced_acc', 0):.4f}, "
          f"MCC: {get('mcc', 0):.4f}")
    print(f"  Sensitivity: {get('sensitivity', 0):.4f}, "
          f"Specificity: {get('specificity', 0):.4f}")
    matrix = get('confusion_matrix')
    if not isinstance(matrix, np.ndarray):
        print("  Confusion Matrix: N/A")
        return
    print(f"  Confusion Matrix:\n{matrix}")


def _compute_pna_degree_histogram(graphs, target_device):
    """Return the in-degree histogram of `graphs` (entry d = #nodes with in-degree d) on `target_device`."""
    print("Calculating node degree histogram for PNA...")
    max_degree = 0
    per_graph_degrees = []
    for data in graphs:
        num_nodes = data.num_nodes
        if num_nodes == 0:
            continue  # Skip empty graphs
        node_degrees = torch.zeros(num_nodes, dtype=torch.long)  # default: no edges
        edge_index = getattr(data, 'edge_index', None)
        if edge_index is not None and edge_index.numel() > 0:
            # Ignore edges referencing nodes beyond num_nodes before counting.
            in_range = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
            valid_edge_index = edge_index[:, in_range]
            if valid_edge_index.numel() > 0:
                node_degrees = degree(valid_edge_index[1], num_nodes=num_nodes, dtype=torch.long)  # In-degree
        per_graph_degrees.append(node_degrees)
        max_degree = max(max_degree, int(node_degrees.max()))

    if per_graph_degrees:
        histogram = torch.bincount(torch.cat(per_graph_degrees), minlength=max_degree + 1)
    else:
        histogram = torch.zeros(max_degree + 1, dtype=torch.long)

    print(f"Max degree found: {max_degree}")
    histogram = histogram.to(target_device)
    if histogram.sum() == 0 and histogram.numel() > 0:
        print("Warning: Degree histogram is all zeros. Providing minimal histogram.")
        histogram[0] = 1  # Avoid PNA error with all-zero histogram
    return histogram


def _binary_classification_metrics(labels, preds):
    """Return accuracy / balanced acc / MCC / 2x2 confusion matrix / sensitivity / specificity for 0-1 labels."""
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()  # labels=[0, 1] guarantees a 2x2 matrix
    return {
        'accuracy': accuracy_score(labels, preds),
        'balanced_acc': balanced_accuracy_score(labels, preds),
        'mcc': matthews_corrcoef(labels, preds),
        'confusion_matrix': cm,
        'sensitivity': tp / (tp + fn) if (tp + fn) > 0 else 0.0,
        'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0.0,
    }


def train_and_evaluate_hybrid_model(train_df: pd.DataFrame, test_df: pd.DataFrame,
                                     model_config: dict, distance_threshold: float = 8.0):
    """
    Trains and evaluates a single Hybrid GNN+CNN model configuration using
    stratified k-fold CV (5 folds unless ``model_config['n_folds']`` is set).

    For each fold, the weights with the lowest validation loss are restored,
    re-evaluated on the validation fold, and used to score the test set. The
    per-fold test probabilities are averaged and thresholded at 0.5 to give the
    ensembled test metrics.

    Args:
        train_df (pd.DataFrame): Training data.
        test_df (pd.DataFrame): Test data.
        model_config (dict): Configuration dictionary including 'gnn_type',
                             'use_edge_features', and other hyperparameters for GNN, CNN, and training.
        distance_threshold (float): Distance cutoff for graph edges.

    Returns:
        tuple: (list_of_fold_validation_metrics, dict_of_final_test_metrics or None).
        Returns (None, None) if no training graphs could be built.
    """
    import copy  # local: deep-copy the best checkpoint

    gnn_type = model_config['gnn_type']
    use_edge_features = model_config['use_edge_features'] # Determines if edge_attr needed
    run_label = f"Hybrid {gnn_type.upper()} + CNN (Edge Feats: {use_edge_features})"
    print(f"\n{'='*15} Evaluating: {run_label} {'='*15}")

    # --- Data Preparation ---
    # use_ss_node_feature controls if SS is part of GNN node features
    # include_edge_features controls if edge_attr is generated/added to Data
    use_ss = model_config.get('use_ss_node_feature', True)
    train_graphs, train_labels = prepare_graph_data(train_df, distance_threshold,
                                                    use_ss_node_feature=use_ss,
                                                    include_edge_features=use_edge_features)
    test_graphs, test_labels = prepare_graph_data(test_df, distance_threshold,
                                                  use_ss_node_feature=use_ss,
                                                  include_edge_features=use_edge_features)

    if not train_graphs:
        print(f"ERROR: No training graphs created for {run_label}. Aborting.")
        return None, None

    # --- Degree histogram for PNA (training graphs only) ---
    deg_histogram = _compute_pna_degree_histogram(train_graphs, device) if gnn_type == 'pna' else None

    # --- Class Weights (inverse class frequency) ---
    total_train = len(train_labels); pos_train = sum(train_labels); neg_train = total_train - pos_train
    class_weights = {
        0: total_train / (2 * neg_train) if neg_train > 0 else 1.0,
        1: total_train / (2 * pos_train) if pos_train > 0 else 1.0
    }

    # --- Loader settings shared by all folds ---
    n_folds = model_config.get('n_folds', 5)
    batch_size = model_config.get('batch_size', 32)
    pin_memory = torch.cuda.is_available()  # Pin memory can speed up GPU transfer
    num_workers = 2 if pin_memory else 0
    test_loader = (DataLoader(test_graphs, batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, pin_memory=pin_memory)
                   if test_graphs else None)

    # --- Cross-validation ---
    kfold = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=SEED)
    fold_metrics = []
    test_predictions = [] # Per-fold test scores from each fold's best model

    for fold, (train_idx, val_idx) in enumerate(kfold.split(np.zeros(len(train_graphs)), train_labels), 1):
        print(f"\n--- Fold {fold}/{n_folds} ---")
        train_fold = [train_graphs[i] for i in train_idx]
        val_fold = [train_graphs[i] for i in val_idx]

        # A trailing single-graph batch makes BatchNorm1d raise in training mode
        # ("Expected more than 1 value per channel"); drop it rather than have
        # that batch fail every epoch.
        drop_last = len(train_fold) % batch_size == 1
        train_loader = DataLoader(train_fold, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                  pin_memory=pin_memory, drop_last=drop_last)
        val_loader = DataLoader(val_fold, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                                pin_memory=pin_memory)

        # --- Feature dimensions from the first training graph ---
        try:
            node_feature_dim = train_graphs[0].x.shape[1]
            first_edge_attr = getattr(train_graphs[0], 'edge_attr', None)
            edge_dim_arg = first_edge_attr.shape[1] if (use_edge_features and first_edge_attr is not None) else None
            if use_edge_features and edge_dim_arg is None:
                print("Warning: use_edge_features=True but no edge_attr found in data[0]. Setting edge_dim to None for GNN.")
            elif use_edge_features and edge_dim_arg != EXPECTED_EDGE_FEATURE_DIM:
                print(f"Warning: Expected edge_dim {EXPECTED_EDGE_FEATURE_DIM}, but found {edge_dim_arg} in data[0]. Using found dimension.")
        except (IndexError, AttributeError):
            print("ERROR: Cannot determine feature dimensions from train_graphs. Aborting fold.")
            continue # Skip this fold

        # --- Instantiate the Hybrid Model ---
        try:
            model = HybridModel(
                gnn_type=gnn_type,
                node_feature_dim=node_feature_dim,
                hidden_dim=model_config.get('hidden_dim', 128),
                deg=deg_histogram if gnn_type == 'pna' else None,
                edge_dim=edge_dim_arg,
                heads=model_config.get('heads', 4),
                layers=model_config.get('layers', 3),
                dropout=model_config.get('dropout', 0.4),
                embed_dim=model_config.get('embed_dim', 21),
                cnn_out_channels=model_config.get('cnn_out_channels', 32),
                cnn_kernel_height=model_config.get('cnn_kernel_height', 17),
                cnn_kernel_width=model_config.get('cnn_kernel_width', 3),
                cnn_dropout=model_config.get('cnn_dropout', 0.4),
                classifier_dropout=model_config.get('classifier_dropout', 0.5)
            ).to(device)
        except Exception as model_init_error:
            print(f"--- ERROR initializing Hybrid Model for fold {fold}: {model_init_error} ---")
            traceback.print_exc()
            continue # Skip fold if model init fails

        # --- Optimizer and Scheduler ---
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=model_config.get('lr', 0.001),
                                     weight_decay=model_config.get('weight_decay', 0.01))
        # `verbose` is deprecated in recent PyTorch; its default was already False.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

        # --- Training Loop with Early Stopping on validation loss ---
        epochs = model_config.get('epochs', 50)
        patience = model_config.get('patience', 10)
        best_val_loss = float('inf')
        epochs_no_improve = 0
        best_state_dict = None

        for epoch in range(epochs):
            train_loss, train_acc = train_model(model, train_loader, optimizer, device, class_weights=class_weights)
            val_metrics = evaluate_model(model, val_loader, device)
            val_loss = val_metrics['loss']

            scheduler.step(val_loss)

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0
                # state_dict() tensors share storage with the live parameters, so a
                # shallow dict copy would keep tracking later updates. Deep-copy
                # to snapshot the actual best weights.
                best_state_dict = copy.deepcopy(model.state_dict())
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f"  Early stopping triggered at epoch {epoch+1}")
                    break

        # Load best model state for final fold evaluation
        if best_state_dict is not None:
            model.load_state_dict(best_state_dict)
        else:
            print(f"Warning: No best model state saved for fold {fold}. Evaluating last state.")

        # Evaluate on validation set (best model for this fold)
        final_val_metrics = evaluate_model(model, val_loader, device)
        fold_metrics.append(final_val_metrics)
        print(f"  Fold {fold} Final Validation Metrics:")
        print_metrics(final_val_metrics)

        # Score the test set with this fold's best model
        if test_loader is not None:
            test_fold_metrics = evaluate_model(model, test_loader, device)
            # Prefer raw probabilities so the ensemble averages soft scores; fall
            # back to thresholded predictions if they are not provided.
            fold_scores = test_fold_metrics.get('probabilities')
            if fold_scores is None:
                fold_scores = test_fold_metrics['predictions']
            test_predictions.append(fold_scores)

    # --- Aggregate and Report CV Results ---
    print(f"\n--- Cross-validation Summary for {run_label} ---")
    metrics_to_report = ['accuracy', 'balanced_acc', 'mcc', 'sensitivity', 'specificity', 'loss']
    avg_metrics = {}
    if fold_metrics:
        for metric in metrics_to_report:
            values = [m.get(metric, float('nan')) for m in fold_metrics if m is not None]
            values = [v for v in values if not np.isnan(v)] # Filter NaNs
            if values:
                avg_metrics[f'avg_val_{metric}'] = np.mean(values)
                avg_metrics[f'std_val_{metric}'] = np.std(values)
                print(f"  Avg Val {metric:<15}: {avg_metrics[f'avg_val_{metric}']:.4f} ± {avg_metrics[f'std_val_{metric}']:.4f}")
            else: print(f"  Avg Val {metric:<15}: N/A")
    else: print("  No fold metrics recorded.")

    # --- Final Test Set Evaluation (Ensembled) ---
    final_test_metrics = None
    # Explicit None/len check: `not test_labels` is ambiguous for numpy arrays.
    has_test_labels = test_labels is not None and len(test_labels) > 0
    if test_predictions and has_test_labels:
        print("\n--- Final Test Set Performance (Ensembled from CV Folds) ---")
        try:
            pred_lengths = [len(p) for p in test_predictions]
            if len(set(pred_lengths)) > 1:
                print(f"Warning: Test predictions from folds have different lengths: {pred_lengths}. Cannot ensemble.")
            elif pred_lengths[0] != len(test_labels):
                print(f"Warning: Length of test predictions ({pred_lengths[0]}) doesn't match test labels ({len(test_labels)}). Cannot evaluate ensemble.")
            else:
                test_pred_avg = np.mean(np.stack(test_predictions), axis=0) # Average probability across folds
                test_pred_binary = (test_pred_avg > 0.5).astype(int) # Threshold
                final_test_metrics = _binary_classification_metrics(test_labels, test_pred_binary)
                print_metrics(final_test_metrics) # Print the ensembled test results
        except Exception as e:
            print(f"--- Error during test set ensembling/evaluation: {e} ---")
            traceback.print_exc()
    elif not has_test_labels:
        print("\n--- Test set evaluation skipped (no test labels) ---")
    else:
        print("\n--- Test set evaluation skipped (no test predictions generated from folds) ---")

    # Return CV fold metrics and final ensembled test metrics
    return fold_metrics, final_test_metrics


if __name__ == "__main__":
    print("Starting Hybrid GNN+CNN Model Training...")

    # --- Load Data ---
    try:
        # IMPORTANT: Update these paths to actual data files
        train_csv_path = "../../data/train/structure/processed_features_train.csv"
        test_csv_path = "../../data/test/structure/processed_features_test.csv"

        if not os.path.exists(train_csv_path):
             raise FileNotFoundError(f"Training data not found at: {train_csv_path}")
        if not os.path.exists(test_csv_path):
             raise FileNotFoundError(f"Test data not found at: {test_csv_path}")

        train_df = pd.read_csv(train_csv_path)
        test_df = pd.read_csv(test_csv_path)

        print(f"Loaded {len(train_df)} training samples and {len(test_df)} test samples.")
        print("Train class distribution:", train_df['label'].value_counts().to_dict())
        print("Test class distribution:", test_df['label'].value_counts().to_dict())

    except FileNotFoundError as e:
        print(e)
        print("Please update the file paths in Block 9.")
        exit()
    except Exception as e:
        print(f"Error loading data: {e}")
        traceback.print_exc()
        exit()

    # --- Define Model Configurations to Test ---
    # You can define multiple configurations here
    model_configs = [
        # GCN (No Edge Features)
        {
            'gnn_type': 'gcn',
            'use_edge_features': False,
            'hidden_dim': 128, 'layers': 3, 'dropout': 0.4, # GNN hypers
            'classifier_dropout': 0.5, # Final classifier dropout
            'lr': 0.001, 'weight_decay': 0.01, 'epochs': 50, 'patience': 10, 'batch_size': 32, # Training hypers
            # Add CNN hypers if you want to override defaults
        },
        # GATv2 (No Edge Features)
        {
            'gnn_type': 'gatv2',
            'use_edge_features': False,
            'hidden_dim': 128, 'layers': 3, 'dropout': 0.4, 'heads': 4, # GNN hypers
            'classifier_dropout': 0.5,
            'lr': 0.001, 'weight_decay': 0.01, 'epochs': 50, 'patience': 10, 'batch_size': 32,
        },
         # GATv2 (WITH Edge Features - requires edge_attr in data)
         {
             'gnn_type': 'gatv2',
             'use_edge_features': True, # Will use edge_attr if created by prepare_graph_data
             'hidden_dim': 128, 'layers': 3, 'dropout': 0.4, 'heads': 4, # GNN hypers
             'classifier_dropout': 0.5,
             'lr': 0.001, 'weight_decay': 0.01, 'epochs': 50, 'patience': 10, 'batch_size': 32,
         },
        # PNA (No Edge Features - requires degree calculation)
        {
            'gnn_type': 'pna',
            'use_edge_features': False,
            'hidden_dim': 128, 'layers': 3, 'dropout': 0.4, # GNN hypers
            'classifier_dropout': 0.5,
            'lr': 0.001, 'weight_decay': 0.01, 'epochs': 50, 'patience': 10, 'batch_size': 32,
        },
        # GINE (WITH Edge Features - requires edge_attr in data)
        {
            'gnn_type': 'gine',
            'use_edge_features': True, # Will use edge_attr if created by prepare_graph_data
            'hidden_dim': 128, 'layers': 3, 'dropout': 0.4, # GNN hypers
            'classifier_dropout': 0.5,
            'lr': 0.001, 'weight_decay': 0.01, 'epochs': 50, 'patience': 10, 'batch_size': 32,
        },
    ]

    # --- Run Training and Evaluation for each config ---
    results_summary = {}
    distance_threshold = 8.0 # Set your desired distance threshold

    for config in model_configs:
        gnn_type = config['gnn_type']
        use_ef = config['use_edge_features']
        config_label = f"{gnn_type}_EF={use_ef}"

        try:
            fold_metrics_list, final_test_metrics_dict = train_and_evaluate_hybrid_model(
                train_df,
                test_df,
                model_config=config,
                distance_threshold=distance_threshold
            )
            results_summary[config_label] = {
                'cv_metrics': fold_metrics_list,
                'test_metrics': final_test_metrics_dict
            }
        except Exception as e:
            print(f"\n--- !!! CRITICAL ERROR during evaluation for {config_label} !!! ---")
            print(f"Error: {e}")
            traceback.print_exc()
            results_summary[config_label] = {'error': str(e)} # Log the error

    # --- Optional: Print final summary ---
    print("\n\n===== Final Results Summary =====")
    for label, results in results_summary.items():
        print(f"\n--- Configuration: {label} ---")
        if 'error' in results:
            print(f"  ERROR: {results['error']}")
        elif results['test_metrics']:
            print("  Ensembled Test Metrics:")
            print_metrics(results['test_metrics'])
        else:
            print("  Test metrics not available.")
        # Optionally print average CV metrics again here if needed

    print("\n===== Evaluation Complete =====")

Using device: cuda
Using Seed: 42
Starting Hybrid GNN+CNN Model Training...
Loaded 8853 training samples and 2737 test samples.
Train class distribution: {1: 4592, 0: 4261}
Test class distribution: {0: 2497, 1: 240}

Preparing data (Edge Features: False, SS Node Feat: True, Threshold: 8.0Å)...
Created 8853 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge features not included or generated.
Preparing data (Edge Features: False, SS Node Feat: True, Threshold: 8.0Å)...
Created 2737 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge features not included or generated.

--- Fold 1/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 40
  Fold 1 Final Validation Metrics:
  Loss: 0.5408
  Accuracy: 0.7391, Balanced Acc: 0.7368, MCC: 0.4780
  Sensitivity: 0.7976, Specificity: 0.6761
  Confusion Matrix:
[[576 276]
 [186 733]]





--- Fold 2/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 41
  Fold 2 Final Validation Metrics:
  Loss: 0.5535
  Accuracy: 0.7453, Balanced Acc: 0.7446, MCC: 0.4897
  Sensitivity: 0.7639, Specificity: 0.7254
  Confusion Matrix:
[[618 234]
 [217 702]]





--- Fold 3/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 40
  Fold 3 Final Validation Metrics:
  Loss: 0.5691
  Accuracy: 0.7024, Balanced Acc: 0.7028, MCC: 0.4054
  Sensitivity: 0.6917, Specificity: 0.7140
  Confusion Matrix:
[[609 244]
 [283 635]]





--- Fold 4/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 39
  Fold 4 Final Validation Metrics:
  Loss: 0.5570
  Accuracy: 0.7119, Balanced Acc: 0.7111, MCC: 0.4226
  Sensitivity: 0.7309, Specificity: 0.6913
  Confusion Matrix:
[[589 263]
 [247 671]]





--- Fold 5/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 40
  Fold 5 Final Validation Metrics:
  Loss: 0.5674
  Accuracy: 0.7068, Balanced Acc: 0.7092, MCC: 0.4210
  Sensitivity: 0.6438, Specificity: 0.7746
  Confusion Matrix:
[[660 192]
 [327 591]]





--- Cross-validation Summary for Hybrid GCN + CNN (Edge Feats: False) ---
  Avg Val accuracy       : 0.7211 ± 0.0176
  Avg Val balanced_acc   : 0.7209 ± 0.0166
  Avg Val mcc            : 0.4433 ± 0.0338
  Avg Val sensitivity    : 0.7256 ± 0.0539
  Avg Val specificity    : 0.7163 ± 0.0339
  Avg Val loss           : 0.5576 ± 0.0103

--- Final Test Set Performance (Ensembled from CV Folds) ---
  Loss: nan
  Accuracy: 0.7209, Balanced Acc: 0.7547, MCC: 0.3062
  Sensitivity: 0.7958, Specificity: 0.7137
  Confusion Matrix:
[[1782  715]
 [  49  191]]

Preparing data (Edge Features: False, SS Node Feat: True, Threshold: 8.0Å)...
Created 8853 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge features not included or generated.
Preparing data (Edge Features: False, SS Node Feat: True, Threshold: 8.0Å)...
Created 2737 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge features not included or generated.

--- Fold 1/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)



  Early stopping triggered at epoch 37
  Fold 1 Final Validation Metrics:
  Loss: 0.5508
  Accuracy: 0.7199, Balanced Acc: 0.7168, MCC: 0.4406
  Sensitivity: 0.7987, Specificity: 0.6350
  Confusion Matrix:
[[541 311]
 [185 734]]





--- Fold 2/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 44
  Fold 2 Final Validation Metrics:
  Loss: 0.5582
  Accuracy: 0.7436, Balanced Acc: 0.7407, MCC: 0.4886
  Sensitivity: 0.8194, Specificity: 0.6620
  Confusion Matrix:
[[564 288]
 [166 753]]





--- Fold 3/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 34
  Fold 3 Final Validation Metrics:
  Loss: 0.6017
  Accuracy: 0.7047, Balanced Acc: 0.6988, MCC: 0.4216
  Sensitivity: 0.8584, Specificity: 0.5393
  Confusion Matrix:
[[460 393]
 [130 788]]





--- Fold 4/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 42
  Fold 4 Final Validation Metrics:
  Loss: 0.5584
  Accuracy: 0.7266, Balanced Acc: 0.7243, MCC: 0.4526
  Sensitivity: 0.7854, Specificity: 0.6631
  Confusion Matrix:
[[565 287]
 [197 721]]





--- Fold 5/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 38
  Fold 5 Final Validation Metrics:
  Loss: 0.5503
  Accuracy: 0.7282, Balanced Acc: 0.7241, MCC: 0.4611
  Sensitivity: 0.8344, Specificity: 0.6138
  Confusion Matrix:
[[523 329]
 [152 766]]





--- Cross-validation Summary for Hybrid GATV2 + CNN (Edge Feats: False) ---
  Avg Val accuracy       : 0.7246 ± 0.0126
  Avg Val balanced_acc   : 0.7209 ± 0.0135
  Avg Val mcc            : 0.4529 ± 0.0223
  Avg Val sensitivity    : 0.8193 ± 0.0258
  Avg Val specificity    : 0.6226 ± 0.0455
  Avg Val loss           : 0.5639 ± 0.0192

--- Final Test Set Performance (Ensembled from CV Folds) ---
  Loss: nan
  Accuracy: 0.6222, Balanced Acc: 0.7270, MCC: 0.2587
  Sensitivity: 0.8542, Specificity: 0.5999
  Confusion Matrix:
[[1498  999]
 [  35  205]]

Preparing data (Edge Features: True, SS Node Feat: True, Threshold: 8.0Å)...
Created 8853 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge feature dimension: 17
Preparing data (Edge Features: True, SS Node Feat: True, Threshold: 8.0Å)...
Created 2737 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge feature dimension: 17

--- Fold 1/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)



  Early stopping triggered at epoch 42
  Fold 1 Final Validation Metrics:
  Loss: 0.5535
  Accuracy: 0.7267, Balanced Acc: 0.7245, MCC: 0.4528
  Sensitivity: 0.7835, Specificity: 0.6655
  Confusion Matrix:
[[567 285]
 [199 720]]





--- Fold 2/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Fold 2 Final Validation Metrics:
  Loss: 0.5442
  Accuracy: 0.7205, Balanced Acc: 0.7220, MCC: 0.4444
  Sensitivity: 0.6834, Specificity: 0.7606
  Confusion Matrix:
[[648 204]
 [291 628]]





--- Fold 3/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 39
  Fold 3 Final Validation Metrics:
  Loss: 0.5515
  Accuracy: 0.7233, Balanced Acc: 0.7190, MCC: 0.4522
  Sensitivity: 0.8366, Specificity: 0.6014
  Confusion Matrix:
[[513 340]
 [150 768]]





--- Fold 4/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 34
  Fold 4 Final Validation Metrics:
  Loss: 0.5567
  Accuracy: 0.7288, Balanced Acc: 0.7270, MCC: 0.4567
  Sensitivity: 0.7756, Specificity: 0.6784
  Confusion Matrix:
[[578 274]
 [206 712]]





--- Fold 5/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 42
  Fold 5 Final Validation Metrics:
  Loss: 0.5664
  Accuracy: 0.7153, Balanced Acc: 0.7160, MCC: 0.4319
  Sensitivity: 0.6950, Specificity: 0.7371
  Confusion Matrix:
[[628 224]
 [280 638]]





--- Cross-validation Summary for Hybrid GATV2 + CNN (Edge Feats: True) ---
  Avg Val accuracy       : 0.7229 ± 0.0048
  Avg Val balanced_acc   : 0.7217 ± 0.0039
  Avg Val mcc            : 0.4476 ± 0.0088
  Avg Val sensitivity    : 0.7548 ± 0.0577
  Avg Val specificity    : 0.6886 ± 0.0562
  Avg Val loss           : 0.5545 ± 0.0072

--- Final Test Set Performance (Ensembled from CV Folds) ---
  Loss: nan
  Accuracy: 0.7033, Balanced Acc: 0.7357, MCC: 0.2805
  Sensitivity: 0.7750, Specificity: 0.6964
  Confusion Matrix:
[[1739  758]
 [  54  186]]

Preparing data (Edge Features: False, SS Node Feat: True, Threshold: 8.0Å)...
Created 8853 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge features not included or generated.
Preparing data (Edge Features: False, SS Node Feat: True, Threshold: 8.0Å)...
Created 2737 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge features not included or generated.
Calculating node degree histogram for PNA...
Max degree found: 18

--- Fold 1/5 ---



  Early stopping triggered at epoch 32
  Fold 1 Final Validation Metrics:
  Loss: 0.6924
  Accuracy: 0.6307, Balanced Acc: 0.6403, MCC: 0.3227
  Sensitivity: 0.3874, Specificity: 0.8932
  Confusion Matrix:
[[761  91]
 [563 356]]





--- Fold 2/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 43
  Fold 2 Final Validation Metrics:
  Loss: 0.6138
  Accuracy: 0.6714, Balanced Acc: 0.6791, MCC: 0.3893
  Sensitivity: 0.4755, Specificity: 0.8826
  Confusion Matrix:
[[752 100]
 [482 437]]





--- Fold 3/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 35
  Fold 3 Final Validation Metrics:
  Loss: 0.5752
  Accuracy: 0.7075, Balanced Acc: 0.7099, MCC: 0.4225
  Sensitivity: 0.6438, Specificity: 0.7761
  Confusion Matrix:
[[662 191]
 [327 591]]





--- Fold 4/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 29
  Fold 4 Final Validation Metrics:
  Loss: 0.6095
  Accuracy: 0.6605, Balanced Acc: 0.6684, MCC: 0.3691
  Sensitivity: 0.4564, Specificity: 0.8803
  Confusion Matrix:
[[750 102]
 [499 419]]





--- Fold 5/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 23
  Fold 5 Final Validation Metrics:
  Loss: 0.6064
  Accuracy: 0.6689, Balanced Acc: 0.6767, MCC: 0.3859
  Sensitivity: 0.4684, Specificity: 0.8850
  Confusion Matrix:
[[754  98]
 [488 430]]





--- Cross-validation Summary for Hybrid PNA + CNN (Edge Feats: False) ---
  Avg Val accuracy       : 0.6678 ± 0.0246
  Avg Val balanced_acc   : 0.6749 ± 0.0223
  Avg Val mcc            : 0.3779 ± 0.0326
  Avg Val sensitivity    : 0.4863 ± 0.0848
  Avg Val specificity    : 0.8634 ± 0.0439
  Avg Val loss           : 0.6195 ± 0.0389

--- Final Test Set Performance (Ensembled from CV Folds) ---
  Loss: nan
  Accuracy: 0.8440, Balanced Acc: 0.6810, MCC: 0.2843
  Sensitivity: 0.4833, Specificity: 0.8787
  Confusion Matrix:
[[2194  303]
 [ 124  116]]

Preparing data (Edge Features: True, SS Node Feat: True, Threshold: 8.0Å)...
Created 8853 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge feature dimension: 17
Preparing data (Edge Features: True, SS Node Feat: True, Threshold: 8.0Å)...
Created 2737 graphs, skipped 0 rows.
 Node feature dimension: 27
 Edge feature dimension: 17

--- Fold 1/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)



  Early stopping triggered at epoch 41
  Fold 1 Final Validation Metrics:
  Loss: 0.5473
  Accuracy: 0.7386, Balanced Acc: 0.7392, MCC: 0.4780
  Sensitivity: 0.7236, Specificity: 0.7547
  Confusion Matrix:
[[643 209]
 [254 665]]





--- Fold 2/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 43
  Fold 2 Final Validation Metrics:
  Loss: 0.5525
  Accuracy: 0.7261, Balanced Acc: 0.7245, MCC: 0.4511
  Sensitivity: 0.7682, Specificity: 0.6808
  Confusion Matrix:
[[580 272]
 [213 706]]





--- Fold 3/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 38
  Fold 3 Final Validation Metrics:
  Loss: 0.5737
  Accuracy: 0.7081, Balanced Acc: 0.7094, MCC: 0.4191
  Sensitivity: 0.6743, Specificity: 0.7444
  Confusion Matrix:
[[635 218]
 [299 619]]





--- Fold 4/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 26
  Fold 4 Final Validation Metrics:
  Loss: 0.5586
  Accuracy: 0.7209, Balanced Acc: 0.7185, MCC: 0.4414
  Sensitivity: 0.7832, Specificity: 0.6538
  Confusion Matrix:
[[557 295]
 [199 719]]





--- Fold 5/5 ---
SequenceCNN initialized. Calculated flat size: 2304
HybridModel Combined Input Dim: 416 (128+128+128+32)




  Early stopping triggered at epoch 36
  Fold 5 Final Validation Metrics:
  Loss: 0.5629
  Accuracy: 0.7266, Balanced Acc: 0.7265, MCC: 0.4527
  Sensitivity: 0.7288, Specificity: 0.7242
  Confusion Matrix:
[[617 235]
 [249 669]]





--- Cross-validation Summary for Hybrid GINE + CNN (Edge Feats: True) ---
  Avg Val accuracy       : 0.7240 ± 0.0099
  Avg Val balanced_acc   : 0.7236 ± 0.0098
  Avg Val mcc            : 0.4485 ± 0.0190
  Avg Val sensitivity    : 0.7356 ± 0.0382
  Avg Val specificity    : 0.7116 ± 0.0384
  Avg Val loss           : 0.5590 ± 0.0091

--- Final Test Set Performance (Ensembled from CV Folds) ---
  Loss: nan
  Accuracy: 0.7081, Balanced Acc: 0.7402, MCC: 0.2867
  Sensitivity: 0.7792, Specificity: 0.7012
  Confusion Matrix:
[[1751  746]
 [  53  187]]


===== Final Results Summary =====

--- Configuration: gcn_EF=False ---
  Ensembled Test Metrics:
  Loss: nan
  Accuracy: 0.7209, Balanced Acc: 0.7547, MCC: 0.3062
  Sensitivity: 0.7958, Specificity: 0.7137
  Confusion Matrix:
[[1782  715]
 [  49  191]]

--- Configuration: gatv2_EF=False ---
  Ensembled Test Metrics:
  Loss: nan
  Accuracy: 0.6222, Balanced Acc: 0.7270, MCC: 0.2587
  Sensitivity: 0.8542, Specificity: 0.5999
  Confusion Matrix:
