In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, MessagePassing # Removed unused imports like NNConv, EdgeConv
from torch_geometric.nn import global_mean_pool, global_max_pool
from torch_geometric.data import Data, DataLoader
import numpy as np
import pandas as pd
# from sklearn.preprocessing import RobustScaler # Keep import if you decide to use pre-scaling later
import traceback # For better error reporting
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score, balanced_accuracy_score
from sklearn.model_selection import StratifiedKFold
# from sklearn.preprocessing import RobustScaler # Already imported if needed
import matplotlib.pyplot as plt
import random
import os

# Reproducibility: seed every random number generator used in this file.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device selection; the CUDA generators are seeded in the same branch that
# picks the GPU so the availability check is made only once.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Consider adding these for stricter reproducibility on GPU, might affect performance
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# --- Amino Acid Encodings ---
# The 20 residue letters used for one-hot node features.
VALID_AA = 'ARNDCQEGHILKMFPSTWYV'
# Alphabet for the sequence CNN track: the same letters plus the padding char '-'.
AMINO_ACIDS = VALID_AA + '-'
# Letter -> integer index, in alphabet order (so '-' maps to the last index).
AA_TO_INT = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS))))
PROT_T5_DIM = 1024  # Length of one ProtT5 embedding vector

# --- ProtT5 Data Loading and Alignment ---
def _read_prot_t5_file(path, label, expected_dim):
    """Parse one ProtT5 embedding file into a ``{(entry, pos): embedding}`` dict.

    Each line is expected to look like ``entry,pos,v1,v2,...``. Lines with
    fewer than three comma-separated fields, unparsable numbers, or the wrong
    number of embedding values are reported on stdout and skipped.

    Args:
        path: File to read.
        label: "Positive" or "Negative"; only used in the not-found message.
        expected_dim: Required number of embedding values per line.

    Returns:
        The dictionary, or ``None`` if the file is missing or cannot be read.
        When a key occurs on several lines the last one wins.
    """
    records = {}
    try:
        with open(path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                text = line.strip()
                parts = text.split(',')
                if len(parts) < 3:
                    print(f"Warning: Skipping malformed line {line_no} in {path}: {text}")
                    continue
                try:
                    pos = int(parts[1])
                    embedding = [float(x) for x in parts[2:]]
                except ValueError as e:
                    # A bad number only costs this line, not the whole file.
                    print(f"Error processing line {line_no} in {path}: {text} - {e}")
                    continue
                # Ensure embeddings have the expected dimension
                if len(embedding) != expected_dim:
                    print(f"Warning: Incorrect embedding dim ({len(embedding)} vs {expected_dim}) in {path} line {line_no}. Skipping.")
                    continue
                records[(parts[0], pos)] = embedding
    except FileNotFoundError:
        print(f"Error: {label} ProtT5 file not found: {path}")
        return None
    except (OSError, UnicodeError) as e:
        print(f"Error reading {path}: {e}")
        return None
    return records


def load_prot_t5_data(pos_file, neg_file, expected_dim=None):
    """Load ProtT5 embeddings and return dictionaries keyed by (entry, pos).

    Args:
        pos_file: Path of the file holding embeddings of positive samples.
        neg_file: Path of the file holding embeddings of negative samples.
        expected_dim: Number of values every embedding must have. Defaults to
            the module-level ``PROT_T5_DIM`` when left as ``None``.

    Returns:
        ``(pos_dict, neg_dict)`` mapping ``(entry, pos)`` to a list of floats,
        or ``(None, None)`` if either file is missing or unreadable. The
        negative file is not opened when the positive file fails.
    """
    if expected_dim is None:
        expected_dim = PROT_T5_DIM

    print(f"Loading ProtT5 data from {pos_file} and {neg_file}")

    pos_dict = _read_prot_t5_file(pos_file, "Positive", expected_dim)
    if pos_dict is None:
        return None, None

    neg_dict = _read_prot_t5_file(neg_file, "Negative", expected_dim)
    if neg_dict is None:
        return None, None

    print(f"Loaded {len(pos_dict)} positive and {len(neg_dict)} negative ProtT5 embeddings.")
    return pos_dict, neg_dict

def prepare_aligned_data(seq_struct_df, pos_dict, neg_dict):
    """Align ProtT5 embeddings with sequence+structure data.

    For every row the key ``(row['entry'], row['pos'])`` is looked up in
    ``pos_dict`` when ``row['label'] == 1`` and in ``neg_dict`` otherwise.
    Rows without an embedding are dropped.

    Args:
        seq_struct_df (pd.DataFrame): Must contain 'entry', 'pos' and 'label'
            columns. The index does not need to be unique.
        pos_dict (dict | None): ``(entry, pos) -> embedding`` for positives.
        neg_dict (dict | None): ``(entry, pos) -> embedding`` for negatives.

    Returns:
        tuple: ``(X_prot_t5, aligned_df)`` where ``X_prot_t5`` is a float32
        array with one embedding per kept row and ``aligned_df`` is a copy of
        the kept rows in their original order and with their original index
        labels. Both are empty (zero rows) when either dictionary is ``None``
        or when nothing could be aligned, so the two outputs always have the
        same length.

    Raises:
        KeyError: If a non-empty frame lacks one of the required columns.
    """
    print("Aligning ProtT5 embeddings with main DataFrame...")

    if pos_dict is None or neg_dict is None:
        print("Error: ProtT5 dictionaries not loaded. Cannot align.")
        # Return an empty frame together with the empty embedding matrix: a
        # full frame paired with zero embeddings would be rejected by the
        # length check in prepare_graph_data.
        return np.array([]).reshape(0, PROT_T5_DIM), pd.DataFrame(columns=seq_struct_df.columns)

    embeddings = []
    kept_positions = []  # Positional (not label-based) indices of aligned rows

    if len(seq_struct_df) > 0:
        # Validate once instead of once per row.
        if 'entry' not in seq_struct_df.columns or 'pos' not in seq_struct_df.columns:
            raise KeyError("DataFrame must contain 'entry' and 'pos' columns for alignment.")

        rows = zip(seq_struct_df['entry'], seq_struct_df['pos'], seq_struct_df['label'])
        for position, (entry, pos, label) in enumerate(rows):
            lookup = pos_dict if label == 1 else neg_dict
            emb = lookup.get((entry, pos))
            if emb is not None:
                embeddings.append(emb)
                kept_positions.append(position)

    skipped_count = len(seq_struct_df) - len(kept_positions)
    if skipped_count > 0:
        print(f"Warning: Skipped {skipped_count} rows due to missing ProtT5 embeddings.")

    if not kept_positions:
        print("Error: No data points could be aligned with ProtT5 embeddings.")
        return np.array([]).reshape(0, PROT_T5_DIM), pd.DataFrame(columns=seq_struct_df.columns)

    # Convert to numpy array
    X_prot_t5 = np.array(embeddings, dtype=np.float32)

    # Select by position: label-based selection returns every row sharing a
    # label, which breaks the one-row-per-embedding pairing whenever the index
    # contains duplicates.
    aligned_df = seq_struct_df.iloc[kept_positions].copy()
    print(f"Alignment complete. Kept {len(aligned_df)} out of {len(seq_struct_df)} original rows.")

    return X_prot_t5, aligned_df


# --- GNN Architectures
class GCNNetwork(nn.Module):
    """Stack of GCNConv layers producing per-node embeddings.

    Every layer is GCNConv -> BatchNorm1d -> ReLU -> Dropout. All layers after
    the first add a residual connection before dropout.

    Args:
        input_dim: Number of input features per node.
        hidden_dim: Width of every layer and of the returned node embeddings.
        dropout: Dropout probability applied after every layer.
        layers: Total number of GCN layers; values below 1 still build one layer.
    """

    def __init__(self, input_dim, hidden_dim=128, dropout=0.4, layers=3):
        super().__init__()

        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        # First layer maps the raw node features to the hidden width
        self.convs.append(GCNConv(input_dim, hidden_dim))
        self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # Remaining layers keep the hidden width
        for _ in range(1, layers):
            self.convs.append(GCNConv(hidden_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.dropout = nn.Dropout(dropout)
        self.output_dim = hidden_dim  # Store output dimension

    def forward(self, x, edge_index):
        """Return the final features of ALL nodes.

        Args:
            x: Node features [num_nodes, input_dim].
            edge_index: Graph connectivity [2, num_edges].

        Returns:
            Tensor of shape [num_nodes, output_dim].
        """
        for layer_idx, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            residual = x
            x = conv(x, edge_index)
            # BatchNorm cannot compute batch statistics from a single row
            # while training; in eval mode it uses running statistics, so it
            # must always be applied to keep the feature scale consistent.
            if not self.training or x.shape[0] > 1:
                x = norm(x)
            x = F.relu(x)
            # The first layer changes the width, so it has no residual link.
            if layer_idx > 0 and residual.shape == x.shape:
                x = x + residual
            x = self.dropout(x)

        return x

# --- Sequence CNN Module ---
class SequenceCNN(nn.Module):
    """2-D CNN over an embedded, integer-encoded residue window.

    Pipeline: Embedding -> Conv2d ('valid') -> BatchNorm2d -> ReLU -> Dropout
    -> MaxPool2d(2x2) -> Flatten -> Linear(32) -> BatchNorm1d -> ReLU -> Dropout.

    Args:
        vocab_size: Number of distinct integer codes.
        embed_dim: Embedding width (the "width" axis seen by the convolution).
        out_channels: Number of convolution filters.
        kernel_height: Kernel extent along the sequence axis.
        kernel_width: Kernel extent along the embedding axis.
        dropout: Dropout probability used after both blocks.
        seq_len: Length of the input windows; sizes the dense layer.
    """

    def __init__(self, vocab_size, embed_dim=21, out_channels=32, kernel_height=17,
                 kernel_width=3, dropout=0.4, seq_len=33):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=AA_TO_INT['-'])

        self.conv2d = nn.Conv2d(
            in_channels=1,
            out_channels=out_channels,
            kernel_size=(kernel_height, kernel_width),
            padding='valid'
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.dropout1 = nn.Dropout(dropout)

        # Size of the flattened feature map, needed to build the dense layer
        self._calculate_flat_size(seq_len, embed_dim, out_channels, kernel_height, kernel_width)

        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.flatten = nn.Flatten()

        self.fc1 = nn.Linear(self.flat_size, 32)
        self.bn2 = nn.BatchNorm1d(32)
        self.dropout2 = nn.Dropout(dropout)
        self.output_dim = 32  # Store output dimension

    def _calculate_flat_size(self, seq_len, embed_dim, out_channels, kernel_h, kernel_w):
        """Set ``self.flat_size`` to the feature count after conv + 2x2 pooling."""
        # Output size of an unpadded ('valid') convolution with stride 1
        h_out = seq_len - kernel_h + 1
        w_out = embed_dim - kernel_w + 1
        # Output size of the 2x2 / stride-2 max pool (floor division)
        h_pool = (h_out // 2)
        w_pool = (w_out // 2)
        self.flat_size = out_channels * h_pool * w_pool
        print(f"SequenceCNN calculated flat size: {self.flat_size}")

    def forward(self, seq_indices):
        """Encode a batch of windows.

        Args:
            seq_indices: Integer tensor of shape [batch_size, seq_len].

        Returns:
            Tensor of shape [batch_size, 32].
        """
        # Batch statistics are skipped for a single-sample batch only while
        # training. In eval mode BatchNorm uses running statistics and is
        # always applied, so one sample gives the same result alone or in a batch.
        use_batch_norm = (not self.training) or seq_indices.shape[0] > 1

        x = self.embedding(seq_indices)  # [batch_size, seq_len, embed_dim]
        x = x.unsqueeze(1)  # [batch_size, 1, seq_len, embed_dim]

        x = self.conv2d(x)
        if use_batch_norm:
            x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout1(x)
        x = self.pool(x)

        x = self.flatten(x)
        x = self.fc1(x)
        if use_batch_norm:
            x = self.bn2(x)
        x = F.relu(x)
        x = self.dropout2(x)

        return x


# --- *** Hybrid Model (GNN + CNN + ProtT5) *** ---
class HybridModel(nn.Module):
    """Binary classifier fusing a GNN, a sequence CNN and ProtT5 embeddings.

    Per graph, five feature vectors are concatenated: the GNN embedding of the
    central node, the mean- and max-pooled GNN node embeddings, the sequence
    CNN output and the ProtT5 MLP output. They pass through
    Linear(64) -> BatchNorm1d -> ReLU -> Dropout -> Linear(1) -> sigmoid.

    Args:
        gnn_type: Only 'gcn' is supported; anything else raises ValueError.
        node_feature_dim: Number of input features per graph node.
        edge_feature_dim: Accepted for interface compatibility; not used.
        hidden_dim: Hidden/output width of the GNN.
        seq_len: Window length used to reshape the batched sequence tensor.
            The SequenceCNN is built with its default window length, so values
            other than that default are not supported end to end.
        prot_t5_dim: Length of one ProtT5 embedding vector.
    """

    def __init__(self, gnn_type, node_feature_dim, edge_feature_dim=18, # edge_feature_dim might vary
                 hidden_dim=128, seq_len=33, prot_t5_dim=PROT_T5_DIM):
        super().__init__()

        # Kept so forward() reshapes with the same sizes the layers were built for
        self.seq_len = seq_len
        self.prot_t5_dim = prot_t5_dim

        # GNN track
        self.gnn_type = gnn_type
        if gnn_type == 'gcn':
            self.gnn = GCNNetwork(node_feature_dim, hidden_dim)
        else:
            raise ValueError(f"Unsupported GNN type: {gnn_type}")
        gnn_output_dim = self.gnn.output_dim

        # Sequence CNN track
        self.sequence_cnn = SequenceCNN(
            vocab_size=len(AMINO_ACIDS),
            embed_dim=21,
            out_channels=32
        )
        cnn_output_dim = self.sequence_cnn.output_dim

        # --- ProtT5 track ---
        # NOTE(review): there is no activation between the two Linear layers,
        # so apart from dropout they compose to a single affine map. Left
        # unchanged because inserting a module would shift the Sequential
        # indices and therefore the state_dict keys of saved checkpoints.
        self.prot_t5_mlp = nn.Sequential(
            nn.Linear(prot_t5_dim, 256),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.4)
        )
        prot_t5_output_dim = 128

        # --- Combination layers ---
        # central node + global mean pool + global max pool (each gnn_output_dim wide)
        combined_input_dim = 3 * gnn_output_dim + cnn_output_dim + prot_t5_output_dim

        print(f"HybridModel combined input dimension for fc1: {combined_input_dim}")

        self.fc1 = nn.Linear(combined_input_dim, 64)
        self.bn = nn.BatchNorm1d(64)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, data):
        """Return one sigmoid probability per graph, shape [num_graphs, 1].

        ``data`` must provide ``x``, ``edge_index``, ``batch``,
        ``central_node_idx``, ``sequence``, ``prot_t5_embedding`` and
        ``num_graphs``; ``ptr`` is rebuilt from ``batch`` when absent.

        If the sequence or ProtT5 tensors cannot be reshaped, or the feature
        blocks cannot be combined, a message is printed and an all-zero tensor
        (detached from the parameters) is returned instead of raising.

        Raises:
            AttributeError: If ``data`` has no ``prot_t5_embedding``.
            IndexError: If a central-node index falls outside the node tensor.
        """
        # --- Input Extraction ---
        x, edge_index = data.x, data.edge_index
        batch = data.batch
        central_node_idx = data.central_node_idx
        ptr = data.ptr if hasattr(data, 'ptr') else None

        # Handle missing batch/ptr
        if batch is None:
            # A single un-batched graph: every node belongs to graph 0
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)
        if ptr is None:
            if batch.numel() > 0:
                # ptr[g] is the index of the first node of graph g
                counts = torch.bincount(batch)
                ptr = torch.cat([torch.tensor([0], device=batch.device), counts.cumsum(0)])
            else:
                ptr = torch.tensor([0], device=batch.device)

        # Sequence Input
        seq_flat = data.sequence
        batch_size = data.num_graphs # Use num_graphs as the reliable batch size
        seq_len = self.seq_len
        try:
            seq_tensor = seq_flat.view(batch_size, seq_len)
        except RuntimeError as e:
            print(f"Error reshaping sequence tensor (shape: {seq_flat.shape}) to ({batch_size}, {seq_len}): {e}. Returning zeros.")
            return torch.zeros((batch_size, 1), device=x.device, dtype=torch.float)

        # ProtT5 Input
        if not hasattr(data, 'prot_t5_embedding'):
            raise AttributeError("Data object missing 'prot_t5_embedding'. Check data preparation.")
        prot_t5_emb_flat = data.prot_t5_embedding # This is the potentially flattened tensor

        try:
            # Reshape with the dimension the ProtT5 MLP was constructed for
            prot_t5_emb = prot_t5_emb_flat.view(batch_size, self.prot_t5_dim)
        except RuntimeError as e:
            print(f"Error reshaping ProtT5 embeddings: {e}")
            print(f"Original flat shape: {prot_t5_emb_flat.shape}, Target shape: ({batch_size}, {self.prot_t5_dim})")
            # If reshape fails, likely indicates a deeper issue in data loading/batching
            return torch.zeros((batch_size, 1), device=x.device, dtype=torch.float) # Return dummy output

        # --- GNN Track Processing ---
        central_node_features = None
        global_avg_features = None
        global_max_features = None
        if self.gnn_type == 'gcn':
            gnn_node_features = self.gnn(x, edge_index)
            # Extract Central Node Features: central_node_idx is relative to
            # each graph, so shift it by the graph's first node index.
            graph_starts = ptr[:-1]
            absolute_central_node_indices = graph_starts + central_node_idx
            if absolute_central_node_indices.numel() > 0:
                if absolute_central_node_indices.max() >= gnn_node_features.shape[0]:
                    raise IndexError(f"Hybrid GCN absolute central node index out of bounds: Max index {absolute_central_node_indices.max()} vs shape {gnn_node_features.shape[0]}")
                central_node_features = gnn_node_features[absolute_central_node_indices]
            elif gnn_node_features.shape[0] > 0:
                print("Warning: GCN central node indices empty despite nodes existing.")
                central_node_features = torch.zeros((batch_size, self.gnn.output_dim), device=gnn_node_features.device)
            else:
                central_node_features = torch.zeros((batch_size, self.gnn.output_dim), device=gnn_node_features.device)
            # Perform Global Pooling
            if gnn_node_features.shape[0] > 0:
                global_avg_features = global_mean_pool(gnn_node_features, batch)
                global_max_features = global_max_pool(gnn_node_features, batch)
            else:
                global_avg_features = torch.zeros((batch_size, self.gnn.output_dim), device=gnn_node_features.device)
                global_max_features = torch.zeros((batch_size, self.gnn.output_dim), device=gnn_node_features.device)

        # --- CNN Track ---
        seq_features = self.sequence_cnn(seq_tensor)

        # --- ProtT5 Track ---
        prot_t5_features = self.prot_t5_mlp(prot_t5_emb)

        # --- Combine features ---
        features_to_combine = []
        if self.gnn_type == 'gcn':
            features_to_combine = [central_node_features, global_avg_features, global_max_features, seq_features, prot_t5_features]

        # Every block must exist and have one row per graph before concatenation
        valid_features = []
        all_shapes_match = True
        expected_batch_size = batch_size
        for i, feat in enumerate(features_to_combine):
            if feat is None:
                print(f"Error: Feature at index {i} is None before concatenation.")
                all_shapes_match = False
                break
            if feat.shape[0] != expected_batch_size:
                print(f"Warning: Feature dimension mismatch at index {i}. Expected batch size {expected_batch_size}, got {feat.shape[0]}.")
                all_shapes_match = False
                break
            valid_features.append(feat)

        if not all_shapes_match or not valid_features:
            print("Error during feature combination due to shape mismatch or missing features. Returning zeros.")
            return torch.zeros((batch_size, 1), device=x.device, dtype=torch.float)

        try:
            combined = torch.cat(valid_features, dim=1)
        except Exception as e:
            print(f"Error during torch.cat: {e}")
            traceback.print_exc()
            return torch.zeros((batch_size, 1), device=x.device, dtype=torch.float)

        # --- Final classification ---
        x = self.fc1(combined)
        # Skip BatchNorm only for a single-sample batch while training; in
        # eval mode it uses running statistics and is always applied.
        if not self.training or x.shape[0] > 1:
            x = self.bn(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return torch.sigmoid(x)


def integer_encode_sequence(sequence):
    """Map each character of ``sequence`` to its integer code for the CNN track.

    Characters that are not in ``AA_TO_INT`` are encoded as the padding
    character '-'.
    """
    pad_code = AA_TO_INT['-']
    encoded = []
    for residue in sequence:
        encoded.append(AA_TO_INT.get(residue, pad_code))
    return encoded


# --- *** prepare_graph_data *** ---
def prepare_graph_data(df_aligned, prot_t5_embeddings, distance_threshold=8.0, use_ss=True):
    """
    Prepare graph data for PyTorch Geometric using the ALIGNED DataFrame
    and corresponding ProtT5 embeddings. Adds ProtT5 embedding to Data object.

    Each kept row becomes one graph whose nodes are the non-padded ('-')
    positions of the 33-residue window. Node features are a 20-way amino-acid
    one-hot, a central-K indicator and sin/cos of phi, psi and omega
    (27 columns). Edges connect node pairs whose distance-map entry is
    strictly between 0 and ``distance_threshold``; if no such pair exists,
    consecutive nodes are linked in both directions instead. Edge attributes
    are a 4-bin one-hot of the distance.

    Rows are skipped (and counted) when the sequence is missing or not 33
    long, the central residue is not 'K', the distance map is missing, any
    feature fails to parse, or the resulting graph fails validation.

    Args:
        df_aligned (pd.DataFrame): DataFrame already filtered to include only rows
                                   with corresponding ProtT5 embeddings.
        prot_t5_embeddings (np.ndarray): NumPy array of ProtT5 embeddings,
                                        ordered corresponding to df_aligned rows.
        distance_threshold (float): Max distance for graph edges.
        use_ss (bool): Flag to include secondary structure features (if available and uncommented).

    Returns:
        tuple: ``(graph_list, labels)`` - the created ``Data`` objects and the
        label of each, in the same order.

    Raises:
        ValueError: If the DataFrame and the embedding array differ in length.
    """
    print(f"Preparing PyG graph data for {len(df_aligned)} aligned samples...")
    graph_list = []
    labels = []
    skipped = 0
    expected_seq_len = 33
    central_k_pos_abs = 16 # 0-based index in the 33-residue window
    # Width of the one-hot edge attribute. Defined before the loop because the
    # sequential-edge fallback needs it and runs exactly when the
    # distance-edge branch (which used to define it) does not.
    num_dist_bins = 4

    # Ensure ProtT5 embeddings match DataFrame length
    if len(df_aligned) != prot_t5_embeddings.shape[0]:
        raise ValueError(f"Mismatch between aligned DataFrame length ({len(df_aligned)}) and ProtT5 embeddings count ({prot_t5_embeddings.shape[0]})")

    feature_names_to_parse = [
        'phi', 'psi', 'omega', 'tau', 'chi1', 'chi2', 'chi3', 'chi4',
        'sasa', 'ss', 'plDDT', 'distance_map'
        # Add other features from CSV if needed
    ]

    # 'i' is the row position (used for prot_t5_embeddings), 'idx' the index label
    for i, (idx, row) in enumerate(df_aligned.iterrows()):
        try:
            sequence = row['sequence']
            label = row['label']
            current_prot_t5_embedding = prot_t5_embeddings[i] # Get embedding by order

            # --- Initial Validation ---
            sequence_missing = pd.isna(sequence)
            if sequence_missing or len(sequence) != expected_seq_len:
                # len() is not defined for a missing (NaN) sequence
                observed_len = 'missing' if sequence_missing else len(sequence)
                print(f"Warning: Invalid sequence (len={observed_len}) at aligned index {i} (original index {idx}). Skipping.")
                skipped += 1
                continue
            if sequence[central_k_pos_abs] != 'K':
                # This check might be redundant if alignment/filtering already ensured this
                # Keep it as a safeguard
                print(f"Warning: Central residue is not K ('{sequence[central_k_pos_abs]}') at aligned index {i} (original index {idx}). Skipping.")
                skipped += 1
                continue

            # --- Parse Required Structural Data ---
            parsed_data = {}
            valid_row = True
            for name in feature_names_to_parse:
                if name not in row or pd.isna(row[name]):
                    # Only the distance map is essential; other features may be absent.
                    if name == 'distance_map':
                        print(f"Warning: Missing or NaN data for critical feature '{name}' in row index {idx}. Skipping.")
                        valid_row = False
                        break
                    # Handle non-critical missing data with a NaN placeholder
                    parsed_data[name] = np.nan
                    continue

                try:
                    if name == 'ss':
                        parsed_data[name] = str(row[name])
                        if len(parsed_data[name]) != expected_seq_len:
                            raise ValueError(f"SS sequence length mismatch ({len(parsed_data[name])} vs {expected_seq_len})")
                    elif name == 'distance_map':
                        # Assuming distance map is stored as a flat string list
                        # NOTE(security): eval() executes arbitrary code found in
                        # the cell text. Only use trusted input files, or switch to
                        # ast.literal_eval once the stored format is confirmed to
                        # be plain Python literals.
                        parsed_data[name] = np.array(eval(str(row[name])), dtype=np.float32)
                        if parsed_data[name].size == expected_seq_len * expected_seq_len:
                            parsed_data[name] = parsed_data[name].reshape(expected_seq_len, expected_seq_len)
                        else:
                            raise ValueError(f"Unexpected distance_map size ({parsed_data[name].size})")
                    else:
                        # Assuming other features are numeric lists/arrays stored as strings
                        # NOTE(security): same eval() caveat as for distance_map above.
                        parsed_data[name] = np.array(eval(str(row[name])), dtype=np.float32)
                        # Basic shape check for vector features
                        if parsed_data[name].ndim == 1 and len(parsed_data[name]) != expected_seq_len:
                            # Pad with NaN or truncate so the vector has one value per position
                            print(f"Warning: Feature '{name}' length mismatch ({len(parsed_data[name])} vs {expected_seq_len}) at index {idx}. Check data source. Using NaN padding/truncation.")
                            temp_arr = np.full(expected_seq_len, np.nan, dtype=np.float32)
                            n_copy = min(len(parsed_data[name]), expected_seq_len)
                            temp_arr[:n_copy] = parsed_data[name][:n_copy]
                            parsed_data[name] = temp_arr
                except Exception as e:
                    print(f"Warning: Error parsing '{name}' in row index {idx}: {e}. Skipping row.")
                    valid_row = False
                    break

            if not valid_row:
                skipped += 1
                continue

            distance_map = parsed_data.get('distance_map')
            if distance_map is None: # Double check after parsing loop
                print(f"Error: distance_map is missing or failed to parse for row index {idx}. Skipping.")
                skipped += 1
                continue

            # --- Identify Valid (Non-padded) Positions ---
            valid_pos_indices = [k for k, aa in enumerate(sequence) if aa != '-']
            if not valid_pos_indices:
                skipped += 1
                continue

            num_nodes = len(valid_pos_indices)

            # Find new 0-based index of central K within the valid nodes
            try:
                central_k_new_idx = valid_pos_indices.index(central_k_pos_abs)
            except ValueError:
                skipped += 1
                continue

            # --- Node Feature Extraction ---
            node_features_list = []

            # 1. One-hot encode amino acids (20 features)
            aa_onehot = np.zeros((num_nodes, len(VALID_AA)), dtype=np.float32)
            for k, node_idx in enumerate(valid_pos_indices): # Use k for node index, node_idx for sequence index
                aa_idx_lookup = VALID_AA.find(sequence[node_idx])
                if aa_idx_lookup >= 0:
                    # Letters outside VALID_AA keep an all-zero row
                    aa_onehot[k, aa_idx_lookup] = 1.0
            node_features_list.append(aa_onehot)

            # 2. Central K indicator (1 feature)
            is_central_k = np.zeros((num_nodes, 1), dtype=np.float32)
            is_central_k[central_k_new_idx, 0] = 1.0
            node_features_list.append(is_central_k)

            # 3. Process Angles (phi, psi, omega) -> sin/cos encoding (2 features each)
            angle_keys = ['phi', 'psi', 'omega'] # Define angles to use
            for key in angle_keys:
                if key in parsed_data and isinstance(parsed_data[key], np.ndarray):
                    # Extract angles for valid positions using valid_pos_indices
                    valid_angles = parsed_data[key][valid_pos_indices]
                    # Handle potential NaNs (e.g., from parsing or missing data) -> replace with 0
                    valid_angles = np.nan_to_num(valid_angles, nan=0.0)
                    angle_rad = np.pi * valid_angles / 180.0
                    sin_cos_features = np.stack([np.sin(angle_rad), np.cos(angle_rad)], axis=-1)
                    node_features_list.append(sin_cos_features.astype(np.float32))
                else:
                    # Append zeros if angle data is missing entirely
                    print(f"Warning: Missing or invalid data for angle '{key}' at index {idx}. Using zero features.")
                    node_features_list.append(np.zeros((num_nodes, 2), dtype=np.float32))

            # 4. Process SASA (Optional - Uncomment if needed)
            # if 'sasa' in parsed_data and isinstance(parsed_data['sasa'], np.ndarray):
            #      valid_sasa = parsed_data['sasa'][valid_pos_indices].reshape(-1, 1)
            #      valid_sasa = np.nan_to_num(valid_sasa, nan=0.0)
            #      # Consider scaling SASA values if the range is large
            #      # scaler = RobustScaler() # Or StandardScaler
            #      # valid_sasa = scaler.fit_transform(valid_sasa)
            #      node_features_list.append(valid_sasa.astype(np.float32))
            # else:
            #      node_features_list.append(np.zeros((num_nodes, 1), dtype=np.float32))

            # 5. Process SS (Secondary Structure) (Optional - Uncomment if needed)
            # if use_ss and 'ss' in parsed_data and isinstance(parsed_data['ss'], str):
            #     ss_string = parsed_data['ss']
            #     valid_ss = [ss_string[k] for k in valid_pos_indices]
            #     ss_onehot = np.zeros((num_nodes, 3), dtype=np.float32) # H=0, E=1, L=2
            #     ss_map = {'H': 0, 'E': 1, 'L': 2, '-': 2, 'C': 2} # Map '-', 'C' to 'L'
            #     for k, ss_char in enumerate(valid_ss):
            #         ss_idx = ss_map.get(ss_char.upper(), 2) # Default to 'L'
            #         ss_onehot[k, ss_idx] = 1.0
            #     node_features_list.append(ss_onehot)
            # elif use_ss: # If use_ss is True but data missing
            #      node_features_list.append(np.zeros((num_nodes, 3), dtype=np.float32))

            # 6. Process plDDT (Optional - Uncomment if needed)
            # if 'plDDT' in parsed_data and isinstance(parsed_data['plDDT'], np.ndarray):
            #      valid_plddt = parsed_data['plDDT'][valid_pos_indices].reshape(-1, 1)
            #      valid_plddt = np.nan_to_num(valid_plddt, nan=50.0) # Replace NaN pLDDT with 50?
            #      # Scale pLDDT (usually 0-100) -> maybe to 0-1
            #      # valid_plddt = valid_plddt / 100.0
            #      node_features_list.append(valid_plddt.astype(np.float32))
            # else:
            #      node_features_list.append(np.zeros((num_nodes, 1), dtype=np.float32))

            # --- Concatenate all node features ---
            try:
                if not node_features_list:
                    raise ValueError("No node features were generated.")
                node_features = np.concatenate(node_features_list, axis=1)
            except ValueError as e:
                print(f"Error concatenating node features for row index {idx}: {e}")
                # Print shapes for debugging
                for k, feat in enumerate(node_features_list):
                    print(f"  Feature {k} shape: {feat.shape if isinstance(feat, np.ndarray) else 'Invalid Type'}")
                skipped += 1
                continue

            # --- Edge Construction (based on distance map) ---
            valid_distance_map = distance_map[np.ix_(valid_pos_indices, valid_pos_indices)]

            # Create edges based on distance threshold (zero distances are excluded)
            adj = (valid_distance_map < distance_threshold) & (valid_distance_map > 0)
            np.fill_diagonal(adj, False) # No self-loops from distance
            edges = np.argwhere(adj).tolist()  # [] when no pair qualifies

            # Basic edge features: one-hot distance bin per edge
            edge_features = []
            for row_idx, col_idx in edges:
                dist = valid_distance_map[row_idx, col_idx]
                edge_feature = np.zeros(num_dist_bins, dtype=np.float32)
                if dist <= 4.0:
                    edge_feature[0] = 1.0
                elif dist <= 8.0:
                    edge_feature[1] = 1.0
                elif dist <= 12.0:
                    edge_feature[2] = 1.0
                else:
                    edge_feature[3] = 1.0 # > 12 Angstrom bin
                edge_features.append(edge_feature)

            # Add sequential edges if NO distance-based edges exist (or always add them)
            add_sequential_always = False # Set to True to always add sequential edges
            if (add_sequential_always or not edges) and num_nodes > 1:
                default_edge_feat = np.array([1.0] + [0.0]*(num_dist_bins-1), dtype=np.float32) # Assume sequential are close
                # Set for O(1) duplicate checks instead of scanning the edge list
                existing_edges = {tuple(edge) for edge in edges}
                for k in range(num_nodes - 1):
                    for pair in ((k, k + 1), (k + 1, k)):
                        # Avoid adding duplicate edges if they already exist from distance
                        if pair not in existing_edges:
                            existing_edges.add(pair)
                            edges.append(list(pair))
                            edge_features.append(default_edge_feat)

            # Skip graph if it has no nodes or no edges at all
            if num_nodes == 0 or not edges:
                skipped += 1
                continue

            # --- Convert to PyTorch Tensors ---
            x_tensor = torch.tensor(node_features, dtype=torch.float)
            edge_index_tensor = torch.tensor(edges, dtype=torch.long).t().contiguous()
            if edge_features:
                # Stack first: building a tensor from a list of arrays is slow
                edge_attr_tensor = torch.tensor(np.stack(edge_features), dtype=torch.float)
            else:
                # Handle case with edges but no features defined
                edge_attr_tensor = torch.empty((0, num_dist_bins), dtype=torch.float)
            y_tensor = torch.tensor([label], dtype=torch.float)

            # Integer-encode full sequence (padding included) for CNN
            sequence_tensor = torch.tensor(integer_encode_sequence(sequence), dtype=torch.long)

            # ProtT5 embedding tensor
            prot_t5_tensor = torch.tensor(current_prot_t5_embedding, dtype=torch.float)

            # Create PyG Data object
            data = Data(
                x=x_tensor,
                edge_index=edge_index_tensor,
                edge_attr=edge_attr_tensor,
                y=y_tensor,
                sequence=sequence_tensor,
                central_node_idx=torch.tensor([central_k_new_idx], dtype=torch.long), # Index within valid nodes
                prot_t5_embedding=prot_t5_tensor
            )

            # Validate data object (optional but recommended)
            if not data.validate(raise_on_error=False): # Set True to hard fail
                print(f"Warning: Data validation failed for graph at aligned index {i} (original index {idx}). Skipping.")
                skipped += 1
                continue

            graph_list.append(data)
            labels.append(label) # Keep track of labels for stratification

        except Exception as e:
            # Best effort: one bad row must not abort the whole preparation
            print(f"--- Critical Error processing aligned row index {i} (original index {idx}): {e} ---")
            traceback.print_exc()
            skipped += 1
            continue

    print(f"\nCreated {len(graph_list)} graphs from aligned data, skipped {skipped} rows during graph preparation.")
    if graph_list:
        print(f"Example graph node feature dimension: {graph_list[0].x.shape[1]}")
        print(f"Example graph edge feature dimension: {graph_list[0].edge_attr.shape[1] if graph_list[0].edge_attr is not None else 'N/A'}")
        print(f"Example sequence length: {graph_list[0].sequence.shape[0]}")
        print(f"Example ProtT5 embedding dimension: {graph_list[0].prot_t5_embedding.shape[0]}")
    else:
        print("Warning: No graphs were created. Check alignment and preparation steps.")

    return graph_list, labels


# --- Training and Evaluation Functions ---

def train_model(model, loader, optimizer, device, class_weights=None):
    """Run one training epoch and return ``(avg_loss, accuracy)``.

    Args:
        model: Module called as ``model(batch)``. Its output is passed to
            ``F.binary_cross_entropy`` against ``batch.y`` reshaped to
            ``(-1, 1)``, so it must yield probabilities of that shape.
        loader: Iterable of batches. Each batch must provide ``.to(device)``
            and the attributes checked below.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device every batch is moved to.
        class_weights: Optional mapping ``{0: w_neg, 1: w_pos}`` used as
            per-sample weights in the loss.

    Returns:
        Tuple ``(avg_loss, accuracy)`` averaged over the samples that were
        actually trained on; ``(0.0, 0.0)`` when no batch could be processed.

    Raises:
        RuntimeError: Re-raised only for CUDA out-of-memory errors; any other
            per-batch error is logged and the batch is skipped.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    skipped_batches = 0
    required_attrs = ('x', 'edge_index', 'sequence', 'prot_t5_embedding',
                      'y', 'batch', 'central_node_idx')

    for batch in loader:
        try:
            batch = batch.to(device)
            # Basic check for essential attributes before passing to model
            if not all(hasattr(batch, attr) for attr in required_attrs):
                print("Warning: Skipping batch due to missing attributes.")
                skipped_batches += 1
                continue

            if batch.x.shape[0] == 0:  # Skip empty batches
                skipped_batches += 1
                continue

            optimizer.zero_grad()
            output = model(batch)

            # BCE needs a float target with the same shape as the output
            target = batch.y.view(-1, 1).float()

            if class_weights is not None:
                # Vectorised lookup: avoids one Python-level .item() call
                # (and device sync) per sample.
                weight_tensor = torch.where(
                    target > 0.5,
                    torch.full_like(target, float(class_weights[1])),
                    torch.full_like(target, float(class_weights[0])),
                )
                loss = F.binary_cross_entropy(output, target, weight=weight_tensor)
            else:
                loss = F.binary_cross_entropy(output, target)

            # Check for NaN loss
            if torch.isnan(loss):
                print("Warning: NaN loss detected. Skipping batch.")
                skipped_batches += 1
                continue

            loss.backward()
            # Optional: Gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # loss.item() is the mean over the batch's targets, so scale it
            # back to a per-batch sum before accumulating.
            batch_samples = target.size(0)
            total_loss += loss.item() * batch_samples

            # Calculate accuracy
            pred = (output > 0.5).float()
            correct += (pred == target).sum().item()
            total += batch_samples

        except IndexError as e:
            print(f"IndexError during training batch: {e}. Skipping batch.")
            print(f"Batch info: Nodes={batch.num_nodes if hasattr(batch, 'num_nodes') else 'N/A'}, Graphs={batch.num_graphs if hasattr(batch, 'num_graphs') else 'N/A'}")
            traceback.print_exc()
            skipped_batches += 1
            continue
        except RuntimeError as e:
            print(f"RuntimeError during training batch: {e}. Skipping batch.")
            traceback.print_exc()
            skipped_batches += 1
            # Out-of-memory will not fix itself on the next batch: propagate.
            if "CUDA out of memory" in str(e):
                raise
            continue
        except Exception as e:
            print(f"Generic error during training batch: {e}. Skipping batch.")
            traceback.print_exc()
            skipped_batches += 1
            continue

    if skipped_batches > 0:
        print(f"Skipped {skipped_batches} batches during training epoch.")

    # `total` counts exactly the samples whose loss was accumulated, so it is
    # the correct denominator even with skipped or partial batches.
    if total == 0:
        print("Warning: No graphs processed or no targets found in training epoch.")
        return 0.0, 0.0

    avg_loss = total_loss / total  # Average loss per trained sample
    accuracy = correct / total
    return avg_loss, accuracy


def evaluate_model(model, loader, device):
    """Evaluate ``model`` on ``loader`` and return a dict of binary metrics.

    Args:
        model: Module called as ``model(batch)``; its output is treated as
            probabilities and compared with ``batch.y`` reshaped to ``(-1, 1)``.
        loader: Iterable of batches providing ``.to(device)`` and the
            attributes checked below.
        device: Device every batch is moved to.

    Returns:
        Dict with keys ``accuracy``, ``balanced_acc``, ``mcc``,
        ``sensitivity``, ``specificity``, ``confusion_matrix`` (always 2x2,
        rows = true class, columns = predicted class), ``loss`` (mean BCE per
        sample), ``predictions`` (outputs thresholded at 0.5), ``targets``
        and ``probabilities`` (raw model outputs, in the same order as
        ``predictions``).

    Raises:
        RuntimeError: Re-raised only for CUDA out-of-memory errors; any other
            per-batch error is logged and the batch is skipped.
    """
    model.eval()
    total_loss = 0.0
    prob_chunks = []
    target_chunks = []
    skipped_batches = 0
    required_attrs = ('x', 'edge_index', 'sequence', 'prot_t5_embedding',
                      'y', 'batch', 'central_node_idx')

    with torch.no_grad():
        for batch in loader:
            try:
                batch = batch.to(device)
                # Basic check for essential attributes
                if not all(hasattr(batch, attr) for attr in required_attrs):
                    print("Warning: Skipping evaluation batch due to missing attributes.")
                    skipped_batches += 1
                    continue

                if batch.x.shape[0] == 0:  # Skip empty batches
                    skipped_batches += 1
                    continue

                output = model(batch)
                target = batch.y.view(-1, 1).float()  # BCE needs float targets

                # reduction='sum' so the epoch mean can be taken over samples
                loss = F.binary_cross_entropy(output, target, reduction='sum')

                # Check for NaN loss
                if torch.isnan(loss):
                    print("Warning: NaN loss detected during evaluation. Skipping batch contribution.")
                    skipped_batches += 1
                    continue

                total_loss += loss.item()
                prob_chunks.append(output.cpu().numpy())
                target_chunks.append(target.cpu().numpy())

            except IndexError as e:
                print(f"IndexError during evaluation batch: {e}. Skipping batch.")
                print(f"Batch info: Nodes={batch.num_nodes if hasattr(batch, 'num_nodes') else 'N/A'}, Graphs={batch.num_graphs if hasattr(batch, 'num_graphs') else 'N/A'}")
                traceback.print_exc()
                skipped_batches += 1
                continue
            except RuntimeError as e:
                print(f"RuntimeError during evaluation batch: {e}. Skipping batch.")
                traceback.print_exc()
                skipped_batches += 1
                # Out-of-memory will not fix itself on the next batch: propagate.
                if "CUDA out of memory" in str(e):
                    raise
                continue
            except Exception as e:
                print(f"Generic error during evaluation batch: {e}. Skipping batch.")
                traceback.print_exc()
                skipped_batches += 1
                continue

    if skipped_batches > 0:
        print(f"Skipped {skipped_batches} batches during evaluation.")

    if not prob_chunks:
        print("Warning: No predictions were made during evaluation.")
        return {
            'accuracy': 0.0, 'balanced_acc': 0.0, 'mcc': 0.0,
            'sensitivity': 0.0, 'specificity': 0.0, 'confusion_matrix': np.zeros((2, 2)),
            'loss': float('inf'), 'predictions': np.array([]), 'targets': np.array([]),
            'probabilities': np.array([])
        }

    all_probs = np.concatenate(prob_chunks).flatten()
    all_targets = np.concatenate(target_chunks).flatten()
    # Hard decisions use the same 0.5 threshold as training accuracy.
    all_preds = (all_probs > 0.5).astype(np.float32)

    # Avoid division by zero if no targets exist
    if len(all_targets) == 0:
        print("Warning: No targets found during evaluation.")
        avg_loss = float('inf')
    else:
        avg_loss = total_loss / len(all_targets)  # Average loss per sample

    # Calculate metrics
    accuracy = accuracy_score(all_targets, all_preds)
    balanced_acc = balanced_accuracy_score(all_targets, all_preds)
    try:
        mcc = matthews_corrcoef(all_targets, all_preds)
    except ValueError:
        mcc = 0.0

    # Fixing the label set guarantees a 2x2 matrix even when only one class
    # is present or predicted, so no shape patching is needed.
    cm = confusion_matrix(all_targets, all_preds, labels=[0, 1])

    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    return {
        'accuracy': accuracy,
        'balanced_acc': balanced_acc,
        'mcc': mcc,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'confusion_matrix': cm,
        'loss': avg_loss,
        'predictions': all_preds,
        'targets': all_targets,
        'probabilities': all_probs
    }


def print_metrics(metrics, prefix=""):
    """Print the scalar metrics and the confusion matrix held in ``metrics``.

    Each scalar is written as ``<prefix><label>: <value>`` with four decimal
    places; the confusion matrix is printed on its own line after a header.
    """
    scalar_rows = (
        ("Accuracy", 'accuracy'),
        ("Balanced Accuracy", 'balanced_acc'),
        ("MCC", 'mcc'),
        ("Sensitivity (Recall)", 'sensitivity'),
        ("Specificity", 'specificity'),
        ("Loss", 'loss'),
    )
    for label, key in scalar_rows:
        value = metrics[key]
        print(f"{prefix}{label}: {value:.4f}")

    print(f"{prefix}Confusion Matrix:")
    print(metrics['confusion_matrix'])


# --- *** train_with_cv *** ---
def train_with_cv(train_df_aligned, train_prot_t5,
                  test_df_aligned, test_prot_t5,
                  gnn_type='gcn', distance_threshold=8.0, batch_size=32, epochs=50, lr=0.001,
                  n_splits=5, patience=10):
    """
    Train model with cross-validation using ALIGNED data and ProtT5 embeddings.

    For every stratified fold a fresh ``HybridModel`` is trained with early
    stopping on the validation loss, the best checkpoint is restored, and that
    model is evaluated on the fold's validation split and on the full test
    set. Test scores from all folds are averaged into an ensemble prediction.

    Args:
        train_df_aligned, train_prot_t5: Inputs forwarded to
            ``prepare_graph_data`` to build the training graphs.
        test_df_aligned, test_prot_t5: Same, for the test graphs.
        gnn_type: Passed to ``HybridModel``.
        distance_threshold: Passed to ``prepare_graph_data``.
        batch_size: Batch size of every data loader.
        epochs: Maximum number of epochs per fold.
        lr: Adam learning rate.
        n_splits: Number of stratified CV folds.
        patience: Epochs without validation-loss improvement before a fold
            stops early.

    Returns:
        ``(avg_metrics, final_test_metrics)``: mean validation metrics across
        folds and the ensembled test metrics. Either may be ``None`` when the
        corresponding data is missing; ``(None, None)`` when no training
        graphs could be built.
    """
    print(f"\n--- Training {gnn_type.upper()} model with ProtT5 ---")
    print(f"Params: distance_threshold={distance_threshold}Å, batch_size={batch_size}, epochs={epochs}, lr={lr}")

    # Prepare graph data using the aligned inputs
    train_graphs, train_labels = prepare_graph_data(train_df_aligned, train_prot_t5, distance_threshold)
    test_graphs, test_labels = prepare_graph_data(test_df_aligned, test_prot_t5, distance_threshold)

    if not train_graphs:
        print("Error: No training graphs created from aligned data. Aborting.")
        return None, None

    # Check if test data exists
    has_test_data = bool(test_graphs)
    if not has_test_data:
        print("Warning: No test graphs created. Test set evaluation will be skipped.")

    # Balanced class weights computed from the training labels
    total = len(train_labels)
    pos = sum(train_labels)
    neg = total - pos
    if neg == 0 or pos == 0:
        print("Warning: Training data contains only one class. Class weights set to 1.0.")
        class_weights = {0: 1.0, 1: 1.0}
    else:
        class_weights = {
            0: total / (2 * neg),
            1: total / (2 * pos)
        }
    print(f"Calculated Class weights: {class_weights}")

    # Feature dimensions are taken from the first graph (assumed consistent
    # across graphs); they do not change between folds.
    first_graph = train_graphs[0]
    node_feature_dim = first_graph.x.shape[1]
    has_edge_attr = first_graph.edge_attr is not None and first_graph.edge_attr.numel() > 0
    edge_feature_dim = first_graph.edge_attr.shape[1] if has_edge_attr else 0

    use_pinned = torch.cuda.is_available()
    # The test loader is identical for every fold, so build it once.
    test_loader = None
    if has_test_data:
        test_loader = DataLoader(test_graphs, batch_size=batch_size, shuffle=False,
                                 num_workers=0, pin_memory=use_pinned)

    # StratifiedKFold on train_labels, which follows the train_graphs order
    kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    fold_metrics = []
    test_scores_list = []  # Per-fold test scores used for ensembling
    test_losses = []       # Per-fold test loss

    for fold, (train_idx, val_idx) in enumerate(kfold.split(train_graphs, train_labels), 1):
        print(f"\n--- Fold {fold}/{n_splits} ---")

        # Split graph list into train/validation for this fold
        train_fold = [train_graphs[i] for i in train_idx]
        val_fold = [train_graphs[i] for i in val_idx]

        train_loader = DataLoader(train_fold, batch_size=batch_size, shuffle=True,
                                  num_workers=0, pin_memory=use_pinned)
        val_loader = DataLoader(val_fold, batch_size=batch_size, shuffle=False,
                                num_workers=0, pin_memory=use_pinned)

        # --- Model Initialization for the fold ---
        model = HybridModel(
            gnn_type=gnn_type,
            node_feature_dim=node_feature_dim,
            edge_feature_dim=edge_feature_dim,
            hidden_dim=128,  # GNN hidden dim
            seq_len=33,
            prot_t5_dim=PROT_T5_DIM
        ).to(device)

        # Optimizer and scheduler. The deprecated `verbose` flag is omitted;
        # the current LR is printed every epoch below anyway.
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5
        )

        # Training loop with early stopping
        best_val_loss = float('inf')
        epochs_no_improve = 0
        best_model_state = None  # Snapshot of the best weights seen so far

        for epoch in range(epochs):
            train_loss, train_acc = train_model(
                model, train_loader, optimizer, device, class_weights
            )

            val_metrics = evaluate_model(model, val_loader, device)
            val_loss = val_metrics['loss']
            val_acc = val_metrics['accuracy']

            # Step the scheduler based on validation loss
            scheduler.step(val_loss)

            print(f"Epoch {epoch+1}/{epochs} | "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f} | "
                  f"LR: {optimizer.param_groups[0]['lr']:.1e}")

            # Early stopping check
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_no_improve = 0
                # state_dict() tensors alias the live parameters, and a plain
                # dict.copy() keeps those aliases, so later optimizer steps
                # would overwrite the snapshot. Clone every tensor instead.
                best_model_state = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
                print(f"  -> New best validation loss: {best_val_loss:.4f}. Saving model state.")
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= patience:
                    print(f"Early stopping triggered after {epoch+1} epochs.")
                    break

        # Load the best model state found during training for this fold
        if best_model_state:
            print("Loading best model state for final fold evaluation.")
            model.load_state_dict(best_model_state)
        else:
            print("Warning: No best model state saved (possibly due to errors or short training). Using last state.")

        # Final evaluation on the validation set for this fold using the best model
        print("\nEvaluating best model on validation set for Fold", fold)
        final_val_metrics = evaluate_model(model, val_loader, device)
        print_metrics(final_val_metrics, prefix="Fold Validation ")
        fold_metrics.append(final_val_metrics)  # Store metrics for averaging later

        # Predict on the TEST set using the best model from this fold
        if has_test_data:
            print("Predicting on test set for Fold", fold)
            test_fold_eval = evaluate_model(model, test_loader, device)
            # Prefer raw probabilities when evaluate_model provides them;
            # otherwise fall back to hard 0/1 predictions, which turns the
            # ensemble into a majority vote.
            fold_scores = np.asarray(
                test_fold_eval.get('probabilities', test_fold_eval['predictions'])
            ).flatten()
            if len(fold_scores) != len(test_labels):
                # Happens when evaluation skipped batches; such scores cannot
                # be aligned with test_labels, so leave this fold out.
                print(f"Warning: Fold {fold} produced {len(fold_scores)} test predictions "
                      f"for {len(test_labels)} test graphs. Excluding fold from ensemble.")
            else:
                test_scores_list.append(fold_scores)
                test_losses.append(test_fold_eval['loss'])

    # --- Cross-validation summary ---
    print("\n--- Cross-validation Summary (Validation Sets) ---")
    if fold_metrics:
        metrics_to_avg = ['accuracy', 'balanced_acc', 'mcc', 'sensitivity', 'specificity', 'loss']
        avg_metrics = {}
        for metric in metrics_to_avg:
            values = [m[metric] for m in fold_metrics if metric in m]
            if values:
                mean_val = np.mean(values)
                std_val = np.std(values)
                print(f"Avg {metric}: {mean_val:.4f} ± {std_val:.4f}")
                avg_metrics[metric] = mean_val
            else:
                print(f"Avg {metric}: N/A (metric not found in fold results)")
                avg_metrics[metric] = None
    else:
        print("No fold metrics recorded.")
        avg_metrics = None

    # --- Final Test Set Evaluation (Ensemble) ---
    final_test_metrics = None
    if has_test_data and test_scores_list:
        print("\n--- Final Test Set Evaluation (Ensemble Predictions) ---")
        # Average the per-fold scores, then threshold once
        test_score_avg = np.mean(np.stack(test_scores_list, axis=0), axis=0)
        test_pred_binary = (test_score_avg > 0.5).astype(int)

        # labels=[0, 1] keeps the matrix 2x2 even if one class is absent
        cm = confusion_matrix(test_labels, test_pred_binary, labels=[0, 1])
        tn, fp, fn, tp = cm.ravel()
        final_test_metrics = {
            'accuracy': accuracy_score(test_labels, test_pred_binary),
            'balanced_acc': balanced_accuracy_score(test_labels, test_pred_binary),
            'mcc': matthews_corrcoef(test_labels, test_pred_binary),
            'confusion_matrix': cm,
            'sensitivity': tp / (tp + fn) if (tp + fn) > 0 else 0.0,
            'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0.0,
            # Mean test loss of the individual fold models (the validation
            # loss would say nothing about the test set).
            'loss': float(np.mean(test_losses)) if test_losses else float('inf'),
        }

        print_metrics(final_test_metrics, prefix="Ensemble Test ")

    elif has_test_data:
        print("\n--- Test set evaluation skipped (no predictions gathered) ---")
    else:
        print("\n--- Test set evaluation skipped (no test data) ---")

    # Return average CV metrics and final test metrics
    return avg_metrics, final_test_metrics


# --- Main Execution Block ---
if __name__ == "__main__":
    # Script entry point: load the feature CSVs and ProtT5 embeddings, align
    # them, run cross-validated training and print the resulting metrics.

    # --- Hyper-parameters ---
    GNN_TYPE = 'gcn'
    DISTANCE_THRESHOLD = 8.0
    BATCH_SIZE = 32
    EPOCHS = 50  # upper bound on epochs per fold
    LEARNING_RATE = 0.001

    # --- Input locations (adjust to the local data layout) ---
    base_data_path = "../../data/"
    train_csv_path = os.path.join(base_data_path, "train/structure/processed_features_train.csv")
    test_csv_path = os.path.join(base_data_path, "test/structure/processed_features_test.csv")

    prot_t5_base_path = '../../data/'
    train_pos_prot_t5_path = os.path.join(prot_t5_base_path, 'train/PLM/train_positive_ProtT5-XL-UniRef50.csv')
    train_neg_prot_t5_path = os.path.join(prot_t5_base_path, 'train/PLM/train_negative_ProtT5-XL-UniRef50.csv')
    test_pos_prot_t5_path = os.path.join(prot_t5_base_path, 'test/PLM/test_positive_ProtT5-XL-UniRef50.csv')
    test_neg_prot_t5_path = os.path.join(prot_t5_base_path, 'test/PLM/test_negative_ProtT5-XL-UniRef50.csv')

    try:
        # --- Main feature tables ---
        print("Loading main feature data...")
        train_df_orig = pd.read_csv(train_csv_path)
        test_df_orig = pd.read_csv(test_csv_path)
        print(f"Loaded {len(train_df_orig)} training samples, {len(test_df_orig)} test samples.")

        # Both tables must expose these columns; training is checked first.
        required_cols = ['entry', 'pos', 'label', 'sequence', 'distance_map']
        if any(col not in train_df_orig.columns for col in required_cols):
            raise ValueError(f"Training CSV missing one or more required columns: {required_cols}")
        if any(col not in test_df_orig.columns for col in required_cols):
            raise ValueError(f"Test CSV missing one or more required columns: {required_cols}")

        # --- ProtT5 embeddings ---
        print("\nLoading ProtT5 embeddings...")
        train_pos_dict, train_neg_dict = load_prot_t5_data(train_pos_prot_t5_path, train_neg_prot_t5_path)
        test_pos_dict, test_neg_dict = load_prot_t5_data(test_pos_prot_t5_path, test_neg_prot_t5_path)

        embeddings_missing = train_pos_dict is None or test_pos_dict is None
        if embeddings_missing:
            raise RuntimeError("Failed to load ProtT5 embeddings. Check file paths and formats.")

        # --- Alignment of embeddings with the feature tables ---
        print("\nAligning training data...")
        X_train_prot_t5, train_df_aligned = prepare_aligned_data(train_df_orig, train_pos_dict, train_neg_dict)

        print("\nAligning test data...")
        X_test_prot_t5, test_df_aligned = prepare_aligned_data(test_df_orig, test_pos_dict, test_neg_dict)

        # Training cannot proceed without aligned training rows.
        if train_df_aligned.empty:
            raise ValueError("Training data alignment resulted in an empty DataFrame. Cannot proceed.")
        print(f"\nAligned Training Data Shape: {train_df_aligned.shape}")
        print(f"Aligned Training ProtT5 Shape: {X_train_prot_t5.shape}")
        print(f"Aligned Test Data Shape: {test_df_aligned.shape}")
        print(f"Aligned Test ProtT5 Shape: {X_test_prot_t5.shape}")

        # Class balance of the aligned data
        print("\nAligned Train class distribution:", train_df_aligned['label'].value_counts().to_dict())
        if test_df_aligned.empty:
            print("Aligned Test set is empty.")
        else:
            print("Aligned Test class distribution:", test_df_aligned['label'].value_counts().to_dict())

        # --- Cross-validated training and test evaluation ---
        avg_cv_metrics, final_test_metrics = train_with_cv(
            train_df_aligned=train_df_aligned,
            train_prot_t5=X_train_prot_t5,
            test_df_aligned=test_df_aligned,
            test_prot_t5=X_test_prot_t5,
            gnn_type=GNN_TYPE,
            distance_threshold=DISTANCE_THRESHOLD,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
            lr=LEARNING_RATE
        )

        # --- Report ---
        print("\n--- Overall Results ---")
        if avg_cv_metrics:
            print("\nAverage Cross-Validation Metrics:")
            for k, v in avg_cv_metrics.items():
                if v is not None:
                    print(f"  {k}: {v:.4f}")
                else:
                    print(f"  {k}: N/A")
        if final_test_metrics:
            print("\nFinal Ensembled Test Metrics:")
            for k, v in final_test_metrics.items():
                if k == 'confusion_matrix':
                    continue  # printed separately below
                if v is not None:
                    print(f"  {k}: {v:.4f}")
                else:
                    print(f"  {k}: N/A")
            print(f"  Confusion Matrix:\n{final_test_metrics.get('confusion_matrix', 'N/A')}")

    except FileNotFoundError as e:
        print(f"Error: Data file not found. {e}")
        print("Please check the file paths specified in the script.")
    except KeyError as e:
        print(f"Error: Missing expected column in DataFrame: {e}")
        print("Please ensure your CSV files contain the required columns ('entry', 'pos', 'label', 'sequence', 'distance_map', etc.).")
    except ValueError as e:
        print(f"Error during data preparation or alignment: {e}")
    except RuntimeError as e:
        print(f"Runtime Error (potentially CUDA memory issue): {e}")
    except Exception:
        # Last-resort handler: show the full traceback for anything unforeseen.
        print("\n--- An unexpected error occurred ---")
        traceback.print_exc()

Using device: cuda
Loading main feature data...
Loaded 8853 training samples, 2737 test samples.

Loading ProtT5 embeddings...
Loading ProtT5 data from /home/ubuntu/data/hai/thesis/data/train/features/train_positive_ProtT5-XL-UniRef50.csv and /home/ubuntu/data/hai/thesis/data/train/features/train_negative_ProtT5-XL-UniRef50.csv
Loaded 4750 positive and 4529 negative ProtT5 embeddings.
Loading ProtT5 data from /home/ubuntu/data/hai/thesis/data/test/features/test_positive_ProtT5-XL-UniRef50.csv and /home/ubuntu/data/hai/thesis/data/test/features/test_negative_ProtT5-XL-UniRef50.csv
Loaded 253 positive and 2972 negative ProtT5 embeddings.

Aligning training data...
Aligning ProtT5 embeddings with main DataFrame...
Alignment complete. Kept 8853 out of 8853 original rows.

Aligning test data...
Aligning ProtT5 embeddings with main DataFrame...
Alignment complete. Kept 2737 out of 2737 original rows.

Aligned Training Data Shape: (8853, 16)
Aligned Training ProtT5 Shape: (8853, 1024)
Aligned



Epoch 1/50 | Train Loss: 0.6783, Train Acc: 0.5685 | Val Loss: 0.6632, Val Acc: 0.6087 | LR: 1.0e-03
  -> New best validation loss: 0.6632. Saving model state.




Epoch 2/50 | Train Loss: 0.6476, Train Acc: 0.6296 | Val Loss: 0.6558, Val Acc: 0.6025 | LR: 1.0e-03
  -> New best validation loss: 0.6558. Saving model state.




Epoch 3/50 | Train Loss: 0.6329, Train Acc: 0.6446 | Val Loss: 0.6418, Val Acc: 0.6318 | LR: 1.0e-03
  -> New best validation loss: 0.6418. Saving model state.




Epoch 4/50 | Train Loss: 0.6204, Train Acc: 0.6641 | Val Loss: 0.6330, Val Acc: 0.6352 | LR: 1.0e-03
  -> New best validation loss: 0.6330. Saving model state.




Epoch 5/50 | Train Loss: 0.6123, Train Acc: 0.6744 | Val Loss: 0.6196, Val Acc: 0.6669 | LR: 1.0e-03
  -> New best validation loss: 0.6196. Saving model state.




Epoch 6/50 | Train Loss: 0.5990, Train Acc: 0.6830 | Val Loss: 0.6394, Val Acc: 0.6296 | LR: 1.0e-03




Epoch 7/50 | Train Loss: 0.5961, Train Acc: 0.6834 | Val Loss: 0.5982, Val Acc: 0.6866 | LR: 1.0e-03
  -> New best validation loss: 0.5982. Saving model state.




Epoch 8/50 | Train Loss: 0.5907, Train Acc: 0.6942 | Val Loss: 0.6141, Val Acc: 0.6736 | LR: 1.0e-03




Epoch 9/50 | Train Loss: 0.5907, Train Acc: 0.6910 | Val Loss: 0.6190, Val Acc: 0.6426 | LR: 1.0e-03




Epoch 10/50 | Train Loss: 0.5816, Train Acc: 0.6957 | Val Loss: 0.6030, Val Acc: 0.6934 | LR: 1.0e-03




Epoch 11/50 | Train Loss: 0.5758, Train Acc: 0.7067 | Val Loss: 0.5764, Val Acc: 0.7132 | LR: 1.0e-03
  -> New best validation loss: 0.5764. Saving model state.




Epoch 12/50 | Train Loss: 0.5747, Train Acc: 0.7073 | Val Loss: 0.5819, Val Acc: 0.7126 | LR: 1.0e-03




Epoch 13/50 | Train Loss: 0.5660, Train Acc: 0.7159 | Val Loss: 0.5671, Val Acc: 0.7165 | LR: 1.0e-03
  -> New best validation loss: 0.5671. Saving model state.




Epoch 14/50 | Train Loss: 0.5502, Train Acc: 0.7302 | Val Loss: 0.5452, Val Acc: 0.7420 | LR: 1.0e-03
  -> New best validation loss: 0.5452. Saving model state.




Epoch 15/50 | Train Loss: 0.5273, Train Acc: 0.7495 | Val Loss: 0.5317, Val Acc: 0.7199 | LR: 1.0e-03
  -> New best validation loss: 0.5317. Saving model state.




Epoch 16/50 | Train Loss: 0.5141, Train Acc: 0.7544 | Val Loss: 0.5059, Val Acc: 0.7702 | LR: 1.0e-03
  -> New best validation loss: 0.5059. Saving model state.




Epoch 17/50 | Train Loss: 0.5038, Train Acc: 0.7595 | Val Loss: 0.5052, Val Acc: 0.7538 | LR: 1.0e-03
  -> New best validation loss: 0.5052. Saving model state.




Epoch 18/50 | Train Loss: 0.4997, Train Acc: 0.7624 | Val Loss: 0.4936, Val Acc: 0.7668 | LR: 1.0e-03
  -> New best validation loss: 0.4936. Saving model state.




Epoch 19/50 | Train Loss: 0.4937, Train Acc: 0.7665 | Val Loss: 0.4899, Val Acc: 0.7713 | LR: 1.0e-03
  -> New best validation loss: 0.4899. Saving model state.




Epoch 20/50 | Train Loss: 0.4884, Train Acc: 0.7715 | Val Loss: 0.4919, Val Acc: 0.7679 | LR: 1.0e-03




Epoch 21/50 | Train Loss: 0.4795, Train Acc: 0.7804 | Val Loss: 0.4763, Val Acc: 0.7758 | LR: 1.0e-03
  -> New best validation loss: 0.4763. Saving model state.




Epoch 22/50 | Train Loss: 0.4754, Train Acc: 0.7817 | Val Loss: 0.5185, Val Acc: 0.7442 | LR: 1.0e-03




Epoch 23/50 | Train Loss: 0.4782, Train Acc: 0.7813 | Val Loss: 0.5019, Val Acc: 0.7561 | LR: 1.0e-03




Epoch 24/50 | Train Loss: 0.4695, Train Acc: 0.7866 | Val Loss: 0.4886, Val Acc: 0.7679 | LR: 1.0e-03




Epoch 25/50 | Train Loss: 0.4694, Train Acc: 0.7859 | Val Loss: 0.4811, Val Acc: 0.7820 | LR: 1.0e-03




Epoch 26/50 | Train Loss: 0.4656, Train Acc: 0.7858 | Val Loss: 0.4778, Val Acc: 0.7787 | LR: 1.0e-03




Epoch 27/50 | Train Loss: 0.4580, Train Acc: 0.7906 | Val Loss: 0.4672, Val Acc: 0.7770 | LR: 1.0e-03
  -> New best validation loss: 0.4672. Saving model state.




Epoch 28/50 | Train Loss: 0.4617, Train Acc: 0.7857 | Val Loss: 0.4767, Val Acc: 0.7804 | LR: 1.0e-03




Epoch 29/50 | Train Loss: 0.4589, Train Acc: 0.7923 | Val Loss: 0.4802, Val Acc: 0.7685 | LR: 1.0e-03




Epoch 30/50 | Train Loss: 0.4627, Train Acc: 0.7876 | Val Loss: 0.4746, Val Acc: 0.7787 | LR: 1.0e-03




Epoch 31/50 | Train Loss: 0.4522, Train Acc: 0.7960 | Val Loss: 0.4624, Val Acc: 0.7832 | LR: 1.0e-03
  -> New best validation loss: 0.4624. Saving model state.




Epoch 32/50 | Train Loss: 0.4502, Train Acc: 0.7955 | Val Loss: 0.4785, Val Acc: 0.7798 | LR: 1.0e-03




Epoch 33/50 | Train Loss: 0.4492, Train Acc: 0.7996 | Val Loss: 0.4867, Val Acc: 0.7674 | LR: 1.0e-03




Epoch 34/50 | Train Loss: 0.4451, Train Acc: 0.8022 | Val Loss: 0.4898, Val Acc: 0.7668 | LR: 1.0e-03




Epoch 35/50 | Train Loss: 0.4409, Train Acc: 0.8105 | Val Loss: 0.4876, Val Acc: 0.7708 | LR: 1.0e-03




Epoch 36/50 | Train Loss: 0.4432, Train Acc: 0.7985 | Val Loss: 0.4928, Val Acc: 0.7595 | LR: 1.0e-03




Epoch 37/50 | Train Loss: 0.4445, Train Acc: 0.8006 | Val Loss: 0.4823, Val Acc: 0.7764 | LR: 5.0e-04




Epoch 38/50 | Train Loss: 0.4102, Train Acc: 0.8187 | Val Loss: 0.4700, Val Acc: 0.7747 | LR: 5.0e-04




Epoch 39/50 | Train Loss: 0.4004, Train Acc: 0.8291 | Val Loss: 0.4684, Val Acc: 0.7798 | LR: 5.0e-04




Epoch 40/50 | Train Loss: 0.3945, Train Acc: 0.8294 | Val Loss: 0.4778, Val Acc: 0.7804 | LR: 5.0e-04




Epoch 41/50 | Train Loss: 0.3807, Train Acc: 0.8409 | Val Loss: 0.4740, Val Acc: 0.7792 | LR: 5.0e-04
Early stopping triggered after 41 epochs.
Loading best model state for final fold evaluation.

Evaluating best model on validation set for Fold 1
Fold Validation Accuracy: 0.7792
Fold Validation Balanced Accuracy: 0.7754
Fold Validation MCC: 0.5645
Fold Validation Sensitivity (Recall): 0.8760
Fold Validation Specificity: 0.6749
Fold Validation Loss: 0.4740
Fold Validation Confusion Matrix:
[[575 277]
 [114 805]]
Predicting on test set for Fold 1





--- Fold 2/5 ---
SequenceCNN calculated flat size: 2304
HybridModel combined input dimension for fc1: 544




Epoch 1/50 | Train Loss: 0.6826, Train Acc: 0.5733 | Val Loss: 0.6601, Val Acc: 0.6104 | LR: 1.0e-03
  -> New best validation loss: 0.6601. Saving model state.




Epoch 2/50 | Train Loss: 0.6549, Train Acc: 0.6188 | Val Loss: 0.6562, Val Acc: 0.6036 | LR: 1.0e-03
  -> New best validation loss: 0.6562. Saving model state.




Epoch 3/50 | Train Loss: 0.6327, Train Acc: 0.6490 | Val Loss: 0.6489, Val Acc: 0.6036 | LR: 1.0e-03
  -> New best validation loss: 0.6489. Saving model state.




Epoch 4/50 | Train Loss: 0.6201, Train Acc: 0.6649 | Val Loss: 0.6289, Val Acc: 0.6815 | LR: 1.0e-03
  -> New best validation loss: 0.6289. Saving model state.




Epoch 5/50 | Train Loss: 0.6131, Train Acc: 0.6720 | Val Loss: 0.6390, Val Acc: 0.6364 | LR: 1.0e-03




Epoch 6/50 | Train Loss: 0.6001, Train Acc: 0.6800 | Val Loss: 0.6201, Val Acc: 0.6629 | LR: 1.0e-03
  -> New best validation loss: 0.6201. Saving model state.




Epoch 7/50 | Train Loss: 0.5952, Train Acc: 0.6868 | Val Loss: 0.6159, Val Acc: 0.6804 | LR: 1.0e-03
  -> New best validation loss: 0.6159. Saving model state.




Epoch 8/50 | Train Loss: 0.5906, Train Acc: 0.6903 | Val Loss: 0.6302, Val Acc: 0.6381 | LR: 1.0e-03




Epoch 9/50 | Train Loss: 0.5862, Train Acc: 0.7001 | Val Loss: 0.5920, Val Acc: 0.6973 | LR: 1.0e-03
  -> New best validation loss: 0.5920. Saving model state.




Epoch 10/50 | Train Loss: 0.5824, Train Acc: 0.6978 | Val Loss: 0.5960, Val Acc: 0.6838 | LR: 1.0e-03




Epoch 11/50 | Train Loss: 0.5811, Train Acc: 0.6977 | Val Loss: 0.5933, Val Acc: 0.7069 | LR: 1.0e-03




Epoch 12/50 | Train Loss: 0.5750, Train Acc: 0.7016 | Val Loss: 0.5863, Val Acc: 0.7075 | LR: 1.0e-03
  -> New best validation loss: 0.5863. Saving model state.




Epoch 13/50 | Train Loss: 0.5568, Train Acc: 0.7166 | Val Loss: 0.5700, Val Acc: 0.6968 | LR: 1.0e-03
  -> New best validation loss: 0.5700. Saving model state.




Epoch 14/50 | Train Loss: 0.5365, Train Acc: 0.7341 | Val Loss: 0.5428, Val Acc: 0.7261 | LR: 1.0e-03
  -> New best validation loss: 0.5428. Saving model state.




Epoch 15/50 | Train Loss: 0.5266, Train Acc: 0.7485 | Val Loss: 0.5158, Val Acc: 0.7685 | LR: 1.0e-03
  -> New best validation loss: 0.5158. Saving model state.




Epoch 16/50 | Train Loss: 0.5163, Train Acc: 0.7516 | Val Loss: 0.5567, Val Acc: 0.7081 | LR: 1.0e-03




Epoch 17/50 | Train Loss: 0.5150, Train Acc: 0.7511 | Val Loss: 0.5136, Val Acc: 0.7628 | LR: 1.0e-03
  -> New best validation loss: 0.5136. Saving model state.




Epoch 18/50 | Train Loss: 0.5052, Train Acc: 0.7573 | Val Loss: 0.5171, Val Acc: 0.7470 | LR: 1.0e-03




Epoch 19/50 | Train Loss: 0.5026, Train Acc: 0.7561 | Val Loss: 0.5017, Val Acc: 0.7708 | LR: 1.0e-03
  -> New best validation loss: 0.5017. Saving model state.




Epoch 20/50 | Train Loss: 0.4896, Train Acc: 0.7744 | Val Loss: 0.4927, Val Acc: 0.7741 | LR: 1.0e-03
  -> New best validation loss: 0.4927. Saving model state.




Epoch 21/50 | Train Loss: 0.4816, Train Acc: 0.7783 | Val Loss: 0.5007, Val Acc: 0.7651 | LR: 1.0e-03




Epoch 22/50 | Train Loss: 0.4788, Train Acc: 0.7749 | Val Loss: 0.5024, Val Acc: 0.7549 | LR: 1.0e-03




Epoch 23/50 | Train Loss: 0.4664, Train Acc: 0.7842 | Val Loss: 0.4952, Val Acc: 0.7696 | LR: 1.0e-03




Epoch 24/50 | Train Loss: 0.4666, Train Acc: 0.7838 | Val Loss: 0.5121, Val Acc: 0.7476 | LR: 1.0e-03




Epoch 25/50 | Train Loss: 0.4664, Train Acc: 0.7876 | Val Loss: 0.4877, Val Acc: 0.7685 | LR: 1.0e-03
  -> New best validation loss: 0.4877. Saving model state.




Epoch 26/50 | Train Loss: 0.4597, Train Acc: 0.7912 | Val Loss: 0.4844, Val Acc: 0.7685 | LR: 1.0e-03
  -> New best validation loss: 0.4844. Saving model state.




Epoch 27/50 | Train Loss: 0.4632, Train Acc: 0.7878 | Val Loss: 0.4792, Val Acc: 0.7888 | LR: 1.0e-03
  -> New best validation loss: 0.4792. Saving model state.




Epoch 28/50 | Train Loss: 0.4653, Train Acc: 0.7912 | Val Loss: 0.4900, Val Acc: 0.7702 | LR: 1.0e-03




Epoch 29/50 | Train Loss: 0.4552, Train Acc: 0.7917 | Val Loss: 0.4959, Val Acc: 0.7662 | LR: 1.0e-03




Epoch 30/50 | Train Loss: 0.4593, Train Acc: 0.7879 | Val Loss: 0.4882, Val Acc: 0.7668 | LR: 1.0e-03




Epoch 31/50 | Train Loss: 0.4541, Train Acc: 0.8034 | Val Loss: 0.5036, Val Acc: 0.7572 | LR: 1.0e-03




Epoch 32/50 | Train Loss: 0.4508, Train Acc: 0.7914 | Val Loss: 0.4739, Val Acc: 0.7792 | LR: 1.0e-03
  -> New best validation loss: 0.4739. Saving model state.




Epoch 33/50 | Train Loss: 0.4412, Train Acc: 0.7982 | Val Loss: 0.4882, Val Acc: 0.7651 | LR: 1.0e-03




Epoch 34/50 | Train Loss: 0.4508, Train Acc: 0.7986 | Val Loss: 0.4894, Val Acc: 0.7719 | LR: 1.0e-03




Epoch 35/50 | Train Loss: 0.4400, Train Acc: 0.8050 | Val Loss: 0.5006, Val Acc: 0.7555 | LR: 1.0e-03




Epoch 36/50 | Train Loss: 0.4395, Train Acc: 0.8005 | Val Loss: 0.4826, Val Acc: 0.7843 | LR: 1.0e-03




Epoch 37/50 | Train Loss: 0.4331, Train Acc: 0.8033 | Val Loss: 0.4983, Val Acc: 0.7617 | LR: 1.0e-03




Epoch 38/50 | Train Loss: 0.4379, Train Acc: 0.8043 | Val Loss: 0.4851, Val Acc: 0.7764 | LR: 5.0e-04




Epoch 39/50 | Train Loss: 0.3938, Train Acc: 0.8315 | Val Loss: 0.4742, Val Acc: 0.7883 | LR: 5.0e-04




Epoch 40/50 | Train Loss: 0.3828, Train Acc: 0.8396 | Val Loss: 0.4758, Val Acc: 0.7815 | LR: 5.0e-04




Epoch 41/50 | Train Loss: 0.3754, Train Acc: 0.8385 | Val Loss: 0.4691, Val Acc: 0.7787 | LR: 5.0e-04
  -> New best validation loss: 0.4691. Saving model state.




Epoch 42/50 | Train Loss: 0.3762, Train Acc: 0.8421 | Val Loss: 0.4764, Val Acc: 0.7724 | LR: 5.0e-04




Epoch 43/50 | Train Loss: 0.3630, Train Acc: 0.8444 | Val Loss: 0.4794, Val Acc: 0.7753 | LR: 5.0e-04




Epoch 44/50 | Train Loss: 0.3613, Train Acc: 0.8451 | Val Loss: 0.4805, Val Acc: 0.7736 | LR: 5.0e-04




Epoch 45/50 | Train Loss: 0.3644, Train Acc: 0.8448 | Val Loss: 0.4882, Val Acc: 0.7696 | LR: 5.0e-04




Epoch 46/50 | Train Loss: 0.3571, Train Acc: 0.8474 | Val Loss: 0.4999, Val Acc: 0.7578 | LR: 5.0e-04




Epoch 47/50 | Train Loss: 0.3408, Train Acc: 0.8565 | Val Loss: 0.4933, Val Acc: 0.7696 | LR: 2.5e-04




Epoch 48/50 | Train Loss: 0.3139, Train Acc: 0.8714 | Val Loss: 0.4915, Val Acc: 0.7730 | LR: 2.5e-04




Epoch 49/50 | Train Loss: 0.3002, Train Acc: 0.8818 | Val Loss: 0.4893, Val Acc: 0.7804 | LR: 2.5e-04




Epoch 50/50 | Train Loss: 0.2878, Train Acc: 0.8880 | Val Loss: 0.5050, Val Acc: 0.7640 | LR: 2.5e-04
Loading best model state for final fold evaluation.

Evaluating best model on validation set for Fold 2
Fold Validation Accuracy: 0.7640
Fold Validation Balanced Accuracy: 0.7630
Fold Validation MCC: 0.5270
Fold Validation Sensitivity (Recall): 0.7900
Fold Validation Specificity: 0.7359
Fold Validation Loss: 0.5050
Fold Validation Confusion Matrix:
[[627 225]
 [193 726]]
Predicting on test set for Fold 2





--- Fold 3/5 ---
SequenceCNN calculated flat size: 2304
HybridModel combined input dimension for fc1: 544




Epoch 1/50 | Train Loss: 0.6838, Train Acc: 0.5569 | Val Loss: 0.6901, Val Acc: 0.5471 | LR: 1.0e-03
  -> New best validation loss: 0.6901. Saving model state.




Epoch 2/50 | Train Loss: 0.6519, Train Acc: 0.6212 | Val Loss: 0.6848, Val Acc: 0.5522 | LR: 1.0e-03
  -> New best validation loss: 0.6848. Saving model state.




Epoch 3/50 | Train Loss: 0.6298, Train Acc: 0.6505 | Val Loss: 0.6548, Val Acc: 0.6200 | LR: 1.0e-03
  -> New best validation loss: 0.6548. Saving model state.




Epoch 4/50 | Train Loss: 0.6193, Train Acc: 0.6652 | Val Loss: 0.6542, Val Acc: 0.5997 | LR: 1.0e-03
  -> New best validation loss: 0.6542. Saving model state.




Epoch 5/50 | Train Loss: 0.6148, Train Acc: 0.6666 | Val Loss: 0.6392, Val Acc: 0.6533 | LR: 1.0e-03
  -> New best validation loss: 0.6392. Saving model state.




Epoch 6/50 | Train Loss: 0.5958, Train Acc: 0.6913 | Val Loss: 0.6257, Val Acc: 0.6527 | LR: 1.0e-03
  -> New best validation loss: 0.6257. Saving model state.




Epoch 7/50 | Train Loss: 0.5932, Train Acc: 0.6942 | Val Loss: 0.6184, Val Acc: 0.6578 | LR: 1.0e-03
  -> New best validation loss: 0.6184. Saving model state.




Epoch 8/50 | Train Loss: 0.5890, Train Acc: 0.6910 | Val Loss: 0.6143, Val Acc: 0.6663 | LR: 1.0e-03
  -> New best validation loss: 0.6143. Saving model state.




Epoch 9/50 | Train Loss: 0.5827, Train Acc: 0.7014 | Val Loss: 0.6025, Val Acc: 0.6838 | LR: 1.0e-03
  -> New best validation loss: 0.6025. Saving model state.




Epoch 10/50 | Train Loss: 0.5746, Train Acc: 0.7110 | Val Loss: 0.5957, Val Acc: 0.6855 | LR: 1.0e-03
  -> New best validation loss: 0.5957. Saving model state.




Epoch 11/50 | Train Loss: 0.5746, Train Acc: 0.7036 | Val Loss: 0.6009, Val Acc: 0.6708 | LR: 1.0e-03




Epoch 12/50 | Train Loss: 0.5693, Train Acc: 0.7112 | Val Loss: 0.5855, Val Acc: 0.6917 | LR: 1.0e-03
  -> New best validation loss: 0.5855. Saving model state.




Epoch 13/50 | Train Loss: 0.5622, Train Acc: 0.7136 | Val Loss: 0.5876, Val Acc: 0.6877 | LR: 1.0e-03




Epoch 14/50 | Train Loss: 0.5484, Train Acc: 0.7275 | Val Loss: 0.5569, Val Acc: 0.7335 | LR: 1.0e-03
  -> New best validation loss: 0.5569. Saving model state.




Epoch 15/50 | Train Loss: 0.5318, Train Acc: 0.7416 | Val Loss: 0.5458, Val Acc: 0.7199 | LR: 1.0e-03
  -> New best validation loss: 0.5458. Saving model state.




Epoch 16/50 | Train Loss: 0.5220, Train Acc: 0.7554 | Val Loss: 0.5362, Val Acc: 0.7436 | LR: 1.0e-03
  -> New best validation loss: 0.5362. Saving model state.




Epoch 17/50 | Train Loss: 0.5239, Train Acc: 0.7494 | Val Loss: 0.5320, Val Acc: 0.7380 | LR: 1.0e-03
  -> New best validation loss: 0.5320. Saving model state.




Epoch 18/50 | Train Loss: 0.5035, Train Acc: 0.7626 | Val Loss: 0.5425, Val Acc: 0.7273 | LR: 1.0e-03




Epoch 19/50 | Train Loss: 0.4917, Train Acc: 0.7667 | Val Loss: 0.5169, Val Acc: 0.7504 | LR: 1.0e-03
  -> New best validation loss: 0.5169. Saving model state.




Epoch 20/50 | Train Loss: 0.4874, Train Acc: 0.7748 | Val Loss: 0.5176, Val Acc: 0.7436 | LR: 1.0e-03




Epoch 21/50 | Train Loss: 0.4840, Train Acc: 0.7776 | Val Loss: 0.5037, Val Acc: 0.7583 | LR: 1.0e-03
  -> New best validation loss: 0.5037. Saving model state.




Epoch 22/50 | Train Loss: 0.4812, Train Acc: 0.7768 | Val Loss: 0.5064, Val Acc: 0.7612 | LR: 1.0e-03




Epoch 23/50 | Train Loss: 0.4737, Train Acc: 0.7833 | Val Loss: 0.5073, Val Acc: 0.7555 | LR: 1.0e-03




Epoch 24/50 | Train Loss: 0.4723, Train Acc: 0.7824 | Val Loss: 0.5097, Val Acc: 0.7572 | LR: 1.0e-03




Epoch 25/50 | Train Loss: 0.4672, Train Acc: 0.7834 | Val Loss: 0.4950, Val Acc: 0.7662 | LR: 1.0e-03
  -> New best validation loss: 0.4950. Saving model state.




Epoch 26/50 | Train Loss: 0.4696, Train Acc: 0.7855 | Val Loss: 0.5000, Val Acc: 0.7538 | LR: 1.0e-03




Epoch 27/50 | Train Loss: 0.4646, Train Acc: 0.7888 | Val Loss: 0.4979, Val Acc: 0.7566 | LR: 1.0e-03




Epoch 28/50 | Train Loss: 0.4545, Train Acc: 0.7991 | Val Loss: 0.5310, Val Acc: 0.7340 | LR: 1.0e-03




Epoch 29/50 | Train Loss: 0.4627, Train Acc: 0.7882 | Val Loss: 0.4883, Val Acc: 0.7623 | LR: 1.0e-03
  -> New best validation loss: 0.4883. Saving model state.




Epoch 30/50 | Train Loss: 0.4489, Train Acc: 0.7972 | Val Loss: 0.5003, Val Acc: 0.7589 | LR: 1.0e-03




Epoch 31/50 | Train Loss: 0.4485, Train Acc: 0.7981 | Val Loss: 0.5029, Val Acc: 0.7493 | LR: 1.0e-03




Epoch 32/50 | Train Loss: 0.4527, Train Acc: 0.7912 | Val Loss: 0.4916, Val Acc: 0.7679 | LR: 1.0e-03




Epoch 33/50 | Train Loss: 0.4410, Train Acc: 0.8015 | Val Loss: 0.5040, Val Acc: 0.7674 | LR: 1.0e-03




Epoch 34/50 | Train Loss: 0.4455, Train Acc: 0.8005 | Val Loss: 0.4979, Val Acc: 0.7595 | LR: 1.0e-03




Epoch 35/50 | Train Loss: 0.4446, Train Acc: 0.8036 | Val Loss: 0.4952, Val Acc: 0.7691 | LR: 5.0e-04




Epoch 36/50 | Train Loss: 0.4138, Train Acc: 0.8149 | Val Loss: 0.4986, Val Acc: 0.7645 | LR: 5.0e-04




Epoch 37/50 | Train Loss: 0.3872, Train Acc: 0.8330 | Val Loss: 0.4813, Val Acc: 0.7764 | LR: 5.0e-04
  -> New best validation loss: 0.4813. Saving model state.




Epoch 38/50 | Train Loss: 0.3949, Train Acc: 0.8317 | Val Loss: 0.4821, Val Acc: 0.7764 | LR: 5.0e-04




Epoch 39/50 | Train Loss: 0.3815, Train Acc: 0.8362 | Val Loss: 0.4941, Val Acc: 0.7674 | LR: 5.0e-04




Epoch 40/50 | Train Loss: 0.3791, Train Acc: 0.8372 | Val Loss: 0.4799, Val Acc: 0.7781 | LR: 5.0e-04
  -> New best validation loss: 0.4799. Saving model state.




Epoch 41/50 | Train Loss: 0.3755, Train Acc: 0.8430 | Val Loss: 0.4913, Val Acc: 0.7640 | LR: 5.0e-04




Epoch 42/50 | Train Loss: 0.3652, Train Acc: 0.8467 | Val Loss: 0.4875, Val Acc: 0.7724 | LR: 5.0e-04




Epoch 43/50 | Train Loss: 0.3677, Train Acc: 0.8447 | Val Loss: 0.5023, Val Acc: 0.7606 | LR: 5.0e-04




Epoch 44/50 | Train Loss: 0.3569, Train Acc: 0.8503 | Val Loss: 0.5009, Val Acc: 0.7668 | LR: 5.0e-04




Epoch 45/50 | Train Loss: 0.3434, Train Acc: 0.8587 | Val Loss: 0.5035, Val Acc: 0.7662 | LR: 5.0e-04




Epoch 46/50 | Train Loss: 0.3465, Train Acc: 0.8519 | Val Loss: 0.4985, Val Acc: 0.7583 | LR: 2.5e-04




Epoch 47/50 | Train Loss: 0.3123, Train Acc: 0.8773 | Val Loss: 0.4886, Val Acc: 0.7691 | LR: 2.5e-04




Epoch 48/50 | Train Loss: 0.3014, Train Acc: 0.8793 | Val Loss: 0.4952, Val Acc: 0.7685 | LR: 2.5e-04




Epoch 49/50 | Train Loss: 0.2920, Train Acc: 0.8822 | Val Loss: 0.5034, Val Acc: 0.7696 | LR: 2.5e-04




Epoch 50/50 | Train Loss: 0.2896, Train Acc: 0.8834 | Val Loss: 0.5014, Val Acc: 0.7691 | LR: 2.5e-04
Early stopping triggered after 50 epochs.
Loading best model state for final fold evaluation.

Evaluating best model on validation set for Fold 3
Fold Validation Accuracy: 0.7691
Fold Validation Balanced Accuracy: 0.7681
Fold Validation MCC: 0.5372
Fold Validation Sensitivity (Recall): 0.7930
Fold Validation Specificity: 0.7433
Fold Validation Loss: 0.5014
Fold Validation Confusion Matrix:
[[634 219]
 [190 728]]
Predicting on test set for Fold 3





--- Fold 4/5 ---
SequenceCNN calculated flat size: 2304
HybridModel combined input dimension for fc1: 544




Epoch 1/50 | Train Loss: 0.6891, Train Acc: 0.5611 | Val Loss: 0.6585, Val Acc: 0.6322 | LR: 1.0e-03
  -> New best validation loss: 0.6585. Saving model state.




Epoch 2/50 | Train Loss: 0.6570, Train Acc: 0.6154 | Val Loss: 0.6575, Val Acc: 0.6367 | LR: 1.0e-03
  -> New best validation loss: 0.6575. Saving model state.




Epoch 3/50 | Train Loss: 0.6380, Train Acc: 0.6394 | Val Loss: 0.6567, Val Acc: 0.6124 | LR: 1.0e-03
  -> New best validation loss: 0.6567. Saving model state.




Epoch 4/50 | Train Loss: 0.6179, Train Acc: 0.6631 | Val Loss: 0.6441, Val Acc: 0.5966 | LR: 1.0e-03
  -> New best validation loss: 0.6441. Saving model state.




Epoch 5/50 | Train Loss: 0.6037, Train Acc: 0.6847 | Val Loss: 0.6260, Val Acc: 0.6458 | LR: 1.0e-03
  -> New best validation loss: 0.6260. Saving model state.




Epoch 6/50 | Train Loss: 0.5961, Train Acc: 0.6898 | Val Loss: 0.6155, Val Acc: 0.6593 | LR: 1.0e-03
  -> New best validation loss: 0.6155. Saving model state.




Epoch 7/50 | Train Loss: 0.5875, Train Acc: 0.6950 | Val Loss: 0.6018, Val Acc: 0.6847 | LR: 1.0e-03
  -> New best validation loss: 0.6018. Saving model state.




Epoch 8/50 | Train Loss: 0.5818, Train Acc: 0.6991 | Val Loss: 0.6008, Val Acc: 0.6949 | LR: 1.0e-03
  -> New best validation loss: 0.6008. Saving model state.




Epoch 9/50 | Train Loss: 0.5769, Train Acc: 0.7020 | Val Loss: 0.5882, Val Acc: 0.7017 | LR: 1.0e-03
  -> New best validation loss: 0.5882. Saving model state.




Epoch 10/50 | Train Loss: 0.5701, Train Acc: 0.7058 | Val Loss: 0.5827, Val Acc: 0.7096 | LR: 1.0e-03
  -> New best validation loss: 0.5827. Saving model state.




Epoch 11/50 | Train Loss: 0.5704, Train Acc: 0.7065 | Val Loss: 0.5796, Val Acc: 0.7034 | LR: 1.0e-03
  -> New best validation loss: 0.5796. Saving model state.




Epoch 12/50 | Train Loss: 0.5624, Train Acc: 0.7168 | Val Loss: 0.5564, Val Acc: 0.7181 | LR: 1.0e-03
  -> New best validation loss: 0.5564. Saving model state.




Epoch 13/50 | Train Loss: 0.5444, Train Acc: 0.7316 | Val Loss: 0.5366, Val Acc: 0.7407 | LR: 1.0e-03
  -> New best validation loss: 0.5366. Saving model state.




Epoch 14/50 | Train Loss: 0.5243, Train Acc: 0.7505 | Val Loss: 0.5391, Val Acc: 0.7322 | LR: 1.0e-03




Epoch 15/50 | Train Loss: 0.5199, Train Acc: 0.7504 | Val Loss: 0.5508, Val Acc: 0.7254 | LR: 1.0e-03




Epoch 16/50 | Train Loss: 0.5048, Train Acc: 0.7563 | Val Loss: 0.5515, Val Acc: 0.7260 | LR: 1.0e-03




Epoch 17/50 | Train Loss: 0.5002, Train Acc: 0.7637 | Val Loss: 0.5294, Val Acc: 0.7299 | LR: 1.0e-03
  -> New best validation loss: 0.5294. Saving model state.




Epoch 18/50 | Train Loss: 0.4838, Train Acc: 0.7750 | Val Loss: 0.5094, Val Acc: 0.7576 | LR: 1.0e-03
  -> New best validation loss: 0.5094. Saving model state.




Epoch 19/50 | Train Loss: 0.4756, Train Acc: 0.7816 | Val Loss: 0.5144, Val Acc: 0.7503 | LR: 1.0e-03




Epoch 20/50 | Train Loss: 0.4772, Train Acc: 0.7796 | Val Loss: 0.5028, Val Acc: 0.7548 | LR: 1.0e-03
  -> New best validation loss: 0.5028. Saving model state.




Epoch 21/50 | Train Loss: 0.4584, Train Acc: 0.7885 | Val Loss: 0.5068, Val Acc: 0.7571 | LR: 1.0e-03




Epoch 22/50 | Train Loss: 0.4597, Train Acc: 0.7899 | Val Loss: 0.5433, Val Acc: 0.7220 | LR: 1.0e-03




Epoch 23/50 | Train Loss: 0.4611, Train Acc: 0.7875 | Val Loss: 0.5049, Val Acc: 0.7576 | LR: 1.0e-03




Epoch 24/50 | Train Loss: 0.4595, Train Acc: 0.7878 | Val Loss: 0.4994, Val Acc: 0.7701 | LR: 1.0e-03
  -> New best validation loss: 0.4994. Saving model state.




Epoch 25/50 | Train Loss: 0.4541, Train Acc: 0.7932 | Val Loss: 0.5101, Val Acc: 0.7599 | LR: 1.0e-03




Epoch 26/50 | Train Loss: 0.4430, Train Acc: 0.7985 | Val Loss: 0.4940, Val Acc: 0.7706 | LR: 1.0e-03
  -> New best validation loss: 0.4940. Saving model state.




Epoch 27/50 | Train Loss: 0.4510, Train Acc: 0.7939 | Val Loss: 0.4900, Val Acc: 0.7819 | LR: 1.0e-03
  -> New best validation loss: 0.4900. Saving model state.




Epoch 28/50 | Train Loss: 0.4482, Train Acc: 0.7939 | Val Loss: 0.4973, Val Acc: 0.7644 | LR: 1.0e-03




Epoch 29/50 | Train Loss: 0.4454, Train Acc: 0.7988 | Val Loss: 0.4932, Val Acc: 0.7667 | LR: 1.0e-03




Epoch 30/50 | Train Loss: 0.4406, Train Acc: 0.8056 | Val Loss: 0.4986, Val Acc: 0.7576 | LR: 1.0e-03




Epoch 31/50 | Train Loss: 0.4392, Train Acc: 0.8046 | Val Loss: 0.5004, Val Acc: 0.7548 | LR: 1.0e-03




Epoch 32/50 | Train Loss: 0.4373, Train Acc: 0.7988 | Val Loss: 0.4922, Val Acc: 0.7763 | LR: 1.0e-03




Epoch 33/50 | Train Loss: 0.4371, Train Acc: 0.8016 | Val Loss: 0.5067, Val Acc: 0.7548 | LR: 5.0e-04




Epoch 34/50 | Train Loss: 0.4005, Train Acc: 0.8320 | Val Loss: 0.4937, Val Acc: 0.7621 | LR: 5.0e-04




Epoch 35/50 | Train Loss: 0.3815, Train Acc: 0.8426 | Val Loss: 0.5000, Val Acc: 0.7514 | LR: 5.0e-04




Epoch 36/50 | Train Loss: 0.3765, Train Acc: 0.8410 | Val Loss: 0.4822, Val Acc: 0.7808 | LR: 5.0e-04
  -> New best validation loss: 0.4822. Saving model state.




Epoch 37/50 | Train Loss: 0.3705, Train Acc: 0.8450 | Val Loss: 0.4981, Val Acc: 0.7678 | LR: 5.0e-04




Epoch 38/50 | Train Loss: 0.3605, Train Acc: 0.8453 | Val Loss: 0.5017, Val Acc: 0.7627 | LR: 5.0e-04




Epoch 39/50 | Train Loss: 0.3505, Train Acc: 0.8511 | Val Loss: 0.5030, Val Acc: 0.7655 | LR: 5.0e-04




Epoch 40/50 | Train Loss: 0.3473, Train Acc: 0.8550 | Val Loss: 0.5141, Val Acc: 0.7520 | LR: 5.0e-04




Epoch 41/50 | Train Loss: 0.3443, Train Acc: 0.8530 | Val Loss: 0.5034, Val Acc: 0.7667 | LR: 5.0e-04




Epoch 42/50 | Train Loss: 0.3291, Train Acc: 0.8660 | Val Loss: 0.5212, Val Acc: 0.7588 | LR: 2.5e-04




Epoch 43/50 | Train Loss: 0.3097, Train Acc: 0.8753 | Val Loss: 0.5316, Val Acc: 0.7373 | LR: 2.5e-04




Epoch 44/50 | Train Loss: 0.2891, Train Acc: 0.8869 | Val Loss: 0.5230, Val Acc: 0.7565 | LR: 2.5e-04




Epoch 45/50 | Train Loss: 0.2731, Train Acc: 0.8933 | Val Loss: 0.5306, Val Acc: 0.7542 | LR: 2.5e-04




Epoch 46/50 | Train Loss: 0.2706, Train Acc: 0.8971 | Val Loss: 0.5420, Val Acc: 0.7475 | LR: 2.5e-04
Early stopping triggered after 46 epochs.
Loading best model state for final fold evaluation.

Evaluating best model on validation set for Fold 4
Fold Validation Accuracy: 0.7475
Fold Validation Balanced Accuracy: 0.7471
Fold Validation MCC: 0.4942
Fold Validation Sensitivity (Recall): 0.7571
Fold Validation Specificity: 0.7371
Fold Validation Loss: 0.5420
Fold Validation Confusion Matrix:
[[628 224]
 [223 695]]
Predicting on test set for Fold 4





--- Fold 5/5 ---
SequenceCNN calculated flat size: 2304
HybridModel combined input dimension for fc1: 544




Epoch 1/50 | Train Loss: 0.6831, Train Acc: 0.5760 | Val Loss: 0.6596, Val Acc: 0.6062 | LR: 1.0e-03
  -> New best validation loss: 0.6596. Saving model state.




Epoch 2/50 | Train Loss: 0.6442, Train Acc: 0.6325 | Val Loss: 0.6546, Val Acc: 0.6441 | LR: 1.0e-03
  -> New best validation loss: 0.6546. Saving model state.




Epoch 3/50 | Train Loss: 0.6260, Train Acc: 0.6523 | Val Loss: 0.6456, Val Acc: 0.6254 | LR: 1.0e-03
  -> New best validation loss: 0.6456. Saving model state.




Epoch 4/50 | Train Loss: 0.6152, Train Acc: 0.6708 | Val Loss: 0.6388, Val Acc: 0.6582 | LR: 1.0e-03
  -> New best validation loss: 0.6388. Saving model state.




Epoch 5/50 | Train Loss: 0.6034, Train Acc: 0.6789 | Val Loss: 0.6235, Val Acc: 0.6655 | LR: 1.0e-03
  -> New best validation loss: 0.6235. Saving model state.




Epoch 6/50 | Train Loss: 0.5928, Train Acc: 0.6888 | Val Loss: 0.6165, Val Acc: 0.6650 | LR: 1.0e-03
  -> New best validation loss: 0.6165. Saving model state.




Epoch 7/50 | Train Loss: 0.5872, Train Acc: 0.6928 | Val Loss: 0.6018, Val Acc: 0.7079 | LR: 1.0e-03
  -> New best validation loss: 0.6018. Saving model state.




Epoch 8/50 | Train Loss: 0.5811, Train Acc: 0.7011 | Val Loss: 0.5973, Val Acc: 0.6944 | LR: 1.0e-03
  -> New best validation loss: 0.5973. Saving model state.




Epoch 9/50 | Train Loss: 0.5743, Train Acc: 0.7048 | Val Loss: 0.5986, Val Acc: 0.6977 | LR: 1.0e-03




Epoch 10/50 | Train Loss: 0.5648, Train Acc: 0.7196 | Val Loss: 0.6027, Val Acc: 0.6740 | LR: 1.0e-03




Epoch 11/50 | Train Loss: 0.5707, Train Acc: 0.7063 | Val Loss: 0.5841, Val Acc: 0.7102 | LR: 1.0e-03
  -> New best validation loss: 0.5841. Saving model state.




Epoch 12/50 | Train Loss: 0.5610, Train Acc: 0.7154 | Val Loss: 0.5782, Val Acc: 0.7040 | LR: 1.0e-03
  -> New best validation loss: 0.5782. Saving model state.




Epoch 13/50 | Train Loss: 0.5514, Train Acc: 0.7342 | Val Loss: 0.5609, Val Acc: 0.7209 | LR: 1.0e-03
  -> New best validation loss: 0.5609. Saving model state.




Epoch 14/50 | Train Loss: 0.5318, Train Acc: 0.7439 | Val Loss: 0.5600, Val Acc: 0.7028 | LR: 1.0e-03
  -> New best validation loss: 0.5600. Saving model state.




Epoch 15/50 | Train Loss: 0.5174, Train Acc: 0.7534 | Val Loss: 0.5401, Val Acc: 0.7333 | LR: 1.0e-03
  -> New best validation loss: 0.5401. Saving model state.




Epoch 16/50 | Train Loss: 0.5069, Train Acc: 0.7603 | Val Loss: 0.5367, Val Acc: 0.7215 | LR: 1.0e-03
  -> New best validation loss: 0.5367. Saving model state.




Epoch 17/50 | Train Loss: 0.4945, Train Acc: 0.7764 | Val Loss: 0.4948, Val Acc: 0.7757 | LR: 1.0e-03
  -> New best validation loss: 0.4948. Saving model state.




Epoch 18/50 | Train Loss: 0.4822, Train Acc: 0.7772 | Val Loss: 0.5065, Val Acc: 0.7565 | LR: 1.0e-03




Epoch 19/50 | Train Loss: 0.4807, Train Acc: 0.7762 | Val Loss: 0.5026, Val Acc: 0.7610 | LR: 1.0e-03




Epoch 20/50 | Train Loss: 0.4790, Train Acc: 0.7790 | Val Loss: 0.4956, Val Acc: 0.7576 | LR: 1.0e-03




Epoch 21/50 | Train Loss: 0.4670, Train Acc: 0.7848 | Val Loss: 0.4908, Val Acc: 0.7644 | LR: 1.0e-03
  -> New best validation loss: 0.4908. Saving model state.




Epoch 22/50 | Train Loss: 0.4634, Train Acc: 0.7884 | Val Loss: 0.4939, Val Acc: 0.7684 | LR: 1.0e-03




Epoch 23/50 | Train Loss: 0.4666, Train Acc: 0.7827 | Val Loss: 0.4763, Val Acc: 0.7791 | LR: 1.0e-03
  -> New best validation loss: 0.4763. Saving model state.




Epoch 24/50 | Train Loss: 0.4574, Train Acc: 0.7882 | Val Loss: 0.4847, Val Acc: 0.7655 | LR: 1.0e-03




Epoch 25/50 | Train Loss: 0.4546, Train Acc: 0.7908 | Val Loss: 0.4794, Val Acc: 0.7780 | LR: 1.0e-03




Epoch 26/50 | Train Loss: 0.4540, Train Acc: 0.7937 | Val Loss: 0.4878, Val Acc: 0.7712 | LR: 1.0e-03




Epoch 27/50 | Train Loss: 0.4457, Train Acc: 0.8001 | Val Loss: 0.5040, Val Acc: 0.7627 | LR: 1.0e-03




Epoch 28/50 | Train Loss: 0.4446, Train Acc: 0.7977 | Val Loss: 0.5068, Val Acc: 0.7582 | LR: 1.0e-03




Epoch 29/50 | Train Loss: 0.4477, Train Acc: 0.7940 | Val Loss: 0.4762, Val Acc: 0.7712 | LR: 1.0e-03
  -> New best validation loss: 0.4762. Saving model state.




Epoch 30/50 | Train Loss: 0.4449, Train Acc: 0.8038 | Val Loss: 0.4770, Val Acc: 0.7661 | LR: 1.0e-03




Epoch 31/50 | Train Loss: 0.4461, Train Acc: 0.7984 | Val Loss: 0.4878, Val Acc: 0.7644 | LR: 1.0e-03




Epoch 32/50 | Train Loss: 0.4361, Train Acc: 0.8098 | Val Loss: 0.4888, Val Acc: 0.7650 | LR: 1.0e-03




Epoch 33/50 | Train Loss: 0.4303, Train Acc: 0.8081 | Val Loss: 0.4883, Val Acc: 0.7644 | LR: 1.0e-03




Epoch 34/50 | Train Loss: 0.4371, Train Acc: 0.8057 | Val Loss: 0.4864, Val Acc: 0.7734 | LR: 1.0e-03




Epoch 35/50 | Train Loss: 0.4259, Train Acc: 0.8108 | Val Loss: 0.4875, Val Acc: 0.7644 | LR: 5.0e-04




Epoch 36/50 | Train Loss: 0.3912, Train Acc: 0.8306 | Val Loss: 0.4812, Val Acc: 0.7757 | LR: 5.0e-04




Epoch 37/50 | Train Loss: 0.3794, Train Acc: 0.8391 | Val Loss: 0.4780, Val Acc: 0.7797 | LR: 5.0e-04




Epoch 38/50 | Train Loss: 0.3684, Train Acc: 0.8436 | Val Loss: 0.4901, Val Acc: 0.7678 | LR: 5.0e-04




Epoch 39/50 | Train Loss: 0.3595, Train Acc: 0.8482 | Val Loss: 0.4829, Val Acc: 0.7689 | LR: 5.0e-04
Early stopping triggered after 39 epochs.
Loading best model state for final fold evaluation.

Evaluating best model on validation set for Fold 5
Fold Validation Accuracy: 0.7689
Fold Validation Balanced Accuracy: 0.7681
Fold Validation MCC: 0.5369
Fold Validation Sensitivity (Recall): 0.7908
Fold Validation Specificity: 0.7453
Fold Validation Loss: 0.4829
Fold Validation Confusion Matrix:
[[635 217]
 [192 726]]
Predicting on test set for Fold 5





--- Cross-validation Summary (Validation Sets) ---
Avg accuracy: 0.7657 ± 0.0104
Avg balanced_acc: 0.7643 ± 0.0095
Avg mcc: 0.5320 ± 0.0226
Avg sensitivity: 0.8014 ± 0.0396
Avg specificity: 0.7273 ± 0.0264
Avg loss: 0.5011 ± 0.0235

--- Final Test Set Evaluation (Ensemble Predictions) ---
Ensemble Test Accuracy: 0.7410
Ensemble Test Balanced Accuracy: 0.7827
Ensemble Test MCC: 0.3435
Ensemble Test Sensitivity (Recall): 0.8333
Ensemble Test Specificity: 0.7321
Ensemble Test Loss: 0.5011
Ensemble Test Confusion Matrix:
[[1828  669]
 [  40  200]]

--- Overall Results ---

Average Cross-Validation Metrics:
  accuracy: 0.7657
  balanced_acc: 0.7643
  mcc: 0.5320
  sensitivity: 0.8014
  specificity: 0.7273
  loss: 0.5011

Final Ensembled Test Metrics:
  accuracy: 0.7410
  balanced_acc: 0.7827
  mcc: 0.3435
  sensitivity: 0.8333
  specificity: 0.7321
  loss: 0.5011
  Confusion Matrix:
[[1828  669]
 [  40  200]]
