In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, MessagePassing, NNConv, EdgeConv
from torch_geometric.nn import global_mean_pool, global_max_pool
from torch_geometric.data import Data, DataLoader
import numpy as np
import pandas as pd
from sklearn.preprocessing import RobustScaler # Keep import if you decide to use pre-scaling later
import traceback # For better error reporting
from sklearn.metrics import confusion_matrix, matthews_corrcoef, accuracy_score, balanced_accuracy_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import RobustScaler
import matplotlib.pyplot as plt
import random
import os

# --- Reproducibility: seed Python, NumPy and PyTorch with one shared value ---
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Also seed the generators of all visible GPUs.
    torch.cuda.manual_seed_all(SEED)

# --- Device selection: CUDA when available, otherwise CPU ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# --- Amino Acid Encodings ---
# The 20 standard residues (used for the GNN node one-hot).
VALID_AA = 'ARNDCQEGHILKMFPSTWYV'
# Vocabulary of the sequence CNN: the standard residues plus the padding char '-'.
AMINO_ACIDS = VALID_AA + '-'
AA_TO_INT = {aa: code for code, aa in enumerate(AMINO_ACIDS)}

# --- Simple GNN Layer similar to TF version ---
class SimpleGNNLayer(MessagePassing):
    """Edge-aware message-passing layer with sum aggregation.

    For every edge ``j -> i`` a message is built from the source node's
    features concatenated with a linear projection of the edge features.
    Messages are summed at the target node and combined with the node's own
    features through a linear layer followed by ReLU.

    Args:
        in_channels: Size of the input node features.
        out_channels: Size of the output node features.
        edge_dim: Size of the edge features passed as ``edge_attr``. Defaults
            to ``in_channels``, which is also the width of the all-zero edge
            features used when ``edge_attr`` is omitted.
    """

    def __init__(self, in_channels, out_channels, edge_dim=None):
        super(SimpleGNNLayer, self).__init__(aggr='add')  # "Add" aggregation (sum)
        if edge_dim is None:
            edge_dim = in_channels
        self.edge_dim = edge_dim
        self.node_update = nn.Linear(in_channels + out_channels, out_channels)
        self.edge_transform = nn.Linear(edge_dim, out_channels)
        # Must not be named ``message``: that name is the MessagePassing hook
        # defined below, and ``self.message`` would resolve to the method, not
        # to this layer. Input = source features (in) + projected edge (out).
        self.message_mlp = nn.Linear(in_channels + out_channels, out_channels)

    def forward(self, x, edge_index, edge_attr=None):
        """Return updated node features of shape ``[num_nodes, out_channels]``.

        Args:
            x: Node features ``[num_nodes, in_channels]``.
            edge_index: Graph connectivity ``[2, num_edges]``.
            edge_attr: Optional edge features ``[num_edges, edge_dim]``;
                zeros are used when omitted.
        """
        if edge_attr is None:
            edge_attr = x.new_zeros(edge_index.size(1), self.edge_dim)

        edge_attr_transformed = self.edge_transform(edge_attr)
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr_transformed)
        out = self.node_update(torch.cat([x, aggregated], dim=1))
        return F.relu(out)

    def message(self, x_j, edge_attr):
        """Build one message per edge from the source node and edge features.

        Args:
            x_j: Source node features ``[num_edges, in_channels]``.
            edge_attr: Projected edge features ``[num_edges, out_channels]``.
        """
        return self.message_mlp(torch.cat([x_j, edge_attr], dim=1))

class GCNNetwork(nn.Module):
    """Stack of ``GCNConv`` layers returning an embedding for every node.

    Each layer applies GCNConv -> BatchNorm1d -> ReLU. Every layer except the
    first then adds a residual connection from its input when the shapes
    match. Dropout is applied last.

    Args:
        input_dim: Feature size of the input node matrix ``x``.
        hidden_dim: Width of every layer and of the returned embeddings.
        dropout: Dropout probability applied after each layer.
        layers: Total number of GCNConv layers (values below 1 still build one).
    """

    def __init__(self, input_dim, hidden_dim=128, dropout=0.4, layers=3):
        super(GCNNetwork, self).__init__()

        # The first layer maps input_dim -> hidden_dim; the rest are hidden -> hidden.
        layer_inputs = [input_dim] + [hidden_dim] * (layers - 1)
        self.convs = nn.ModuleList(GCNConv(width, hidden_dim) for width in layer_inputs)
        self.batch_norms = nn.ModuleList(nn.BatchNorm1d(hidden_dim) for _ in layer_inputs)

        self.dropout = nn.Dropout(dropout)
        self.output_dim = hidden_dim

    def forward(self, x, edge_index):
        """Return node embeddings of shape ``[num_nodes, output_dim]``.

        Args:
            x: Node features ``[num_nodes, input_dim]``.
            edge_index: Graph connectivity ``[2, num_edges]``.
        """
        for depth, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            hidden = conv(x, edge_index)
            # BatchNorm cannot run on a graph batch with no nodes.
            if hidden.shape[0] > 0:
                hidden = norm(hidden)
            hidden = F.relu(hidden)
            if depth > 0 and hidden.shape == x.shape:
                hidden = hidden + x  # residual connection
            x = self.dropout(hidden)
        return x

# --- Sequence CNN Module ---
class SequenceCNN(nn.Module):
    """2-D CNN over the embedded residue sequence, producing a 32-dim vector.

    Pipeline: Embedding -> unsqueeze to one channel -> Conv2d ('valid'
    padding) -> BatchNorm2d -> ReLU -> Dropout -> MaxPool2d(2, 2) -> Flatten
    -> Linear(32) -> BatchNorm1d -> ReLU -> Dropout.

    Args:
        vocab_size: Number of distinct token indices.
        embed_dim: Embedding width (also the conv "image" width).
        out_channels: Number of conv filters.
        kernel_height: Conv kernel size along the sequence axis.
        kernel_width: Conv kernel size along the embedding axis.
        dropout: Dropout probability for both dropout layers.
        seq_len: Length of the input sequences; with the kernel sizes it
            determines the flattened size fed to ``fc1``.
        padding_idx: Embedding padding index. Defaults to the index of
            ``'-'`` in ``AA_TO_INT``.

    Raises:
        ValueError: if the conv/pool stages would leave no spatial output.
    """

    def __init__(self, vocab_size, embed_dim=21, out_channels=32, kernel_height=17,
                 kernel_width=3, dropout=0.4, seq_len=33, padding_idx=None):
        super(SequenceCNN, self).__init__()
        if padding_idx is None:
            padding_idx = AA_TO_INT['-']
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)

        # Conv2d over a single-channel [seq_len, embed_dim] "image" (as in the TF version).
        self.conv2d = nn.Conv2d(
            in_channels=1,
            out_channels=out_channels,
            kernel_size=(kernel_height, kernel_width),
            padding='valid'
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.dropout1 = nn.Dropout(dropout)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
        self.flatten = nn.Flatten()

        # 'valid' conv shrinks each axis by (kernel - 1); the pool halves it (floor).
        # Defaults: (33-17+1)//2 = 8 rows, (21-3+1)//2 = 9 cols -> 32*8*9 = 2304.
        pooled_h = (seq_len - kernel_height + 1) // 2
        pooled_w = (embed_dim - kernel_width + 1) // 2
        if pooled_h < 1 or pooled_w < 1:
            raise ValueError(
                f"SequenceCNN: seq_len={seq_len}, embed_dim={embed_dim} are too small "
                f"for kernel ({kernel_height}, {kernel_width}) followed by 2x2 pooling"
            )
        flat_size = out_channels * pooled_h * pooled_w

        self.output_dim = 32
        self.fc1 = nn.Linear(flat_size, self.output_dim)
        self.bn2 = nn.BatchNorm1d(self.output_dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, seq_indices):
        """Encode integer sequences ``[batch, seq_len]`` into ``[batch, 32]`` features."""
        x = self.embedding(seq_indices)  # [batch, seq_len, embed_dim]
        x = x.unsqueeze(1)               # [batch, 1, seq_len, embed_dim]

        x = self.conv2d(x)               # [batch, out_channels, H_out, W_out]
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout1(x)
        x = self.pool(x)

        x = self.flatten(x)
        x = self.fc1(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.dropout2(x)

        return x


# --- Hybrid Model (GNN + CNN) ---
class HybridModel(nn.Module):
    """Binary classifier fusing a GCN graph track with a sequence CNN track.

    For every graph in a batch four vectors are concatenated:
      * the GCN embedding of the graph's central node,
      * the global mean-pooled GCN node embeddings,
      * the global max-pooled GCN node embeddings,
      * the ``SequenceCNN`` embedding of the integer-encoded sequence.
    A small MLP (Linear -> BatchNorm -> ReLU -> Dropout -> Linear) maps them to
    one probability per graph.

    Args:
        gnn_type: Graph encoder to use. Only ``'gcn'`` is supported; any other
            value raises ``ValueError``.
        node_feature_dim: Number of features per node in ``data.x``.
        edge_feature_dim: Accepted for interface compatibility but unused; the
            GCN track does not consume edge attributes.
        hidden_dim: Hidden and output width of the GCN.
        seq_len: Number of sequence positions stored per graph in
            ``data.sequence``; used to reshape the flattened batch tensor.
            NOTE: the ``SequenceCNN`` built here is not told this value, so
            anything other than 33 also requires a matching CNN configuration.
    """

    def __init__(self, gnn_type, node_feature_dim, edge_feature_dim=18,
                 hidden_dim=128, seq_len=33):
        super(HybridModel, self).__init__()

        self.gnn_type = gnn_type
        if gnn_type != 'gcn':
            raise ValueError(f"Unsupported GNN type: {gnn_type}")
        self.gnn = GCNNetwork(node_feature_dim, hidden_dim)
        self.seq_len = seq_len

        self.sequence_cnn = SequenceCNN(
            vocab_size=len(AMINO_ACIDS),
            embed_dim=21,
            out_channels=32
        )

        gnn_output_dim = hidden_dim
        cnn_output_dim = 32  # width of SequenceCNN.fc1
        # Central node + global mean pool + global max pool + CNN features.
        combined_input_dim = 3 * gnn_output_dim + cnn_output_dim

        self.fc1 = nn.Linear(combined_input_dim, 64)
        self.bn = nn.BatchNorm1d(64)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, data):
        """Return sigmoid probabilities of shape ``[num_graphs, 1]``.

        ``data`` must provide ``x``, ``edge_index``, ``batch`` (may be None for
        a single graph), ``central_node_idx`` (index of the central node
        *within* each graph), ``sequence`` (flattened, expected to hold
        ``num_graphs * seq_len`` indices) and ``num_graphs``. ``ptr`` is
        rebuilt from ``batch`` when absent.

        Raises:
            IndexError: if a central node index points past the node tensor.
        """
        x, edge_index = data.x, data.edge_index
        batch = data.batch  # graph id of every node; required for pooling
        central_node_idx = data.central_node_idx
        ptr = data.ptr if hasattr(data, 'ptr') else None

        # Single-graph inference: every node belongs to graph 0.
        if batch is None:
            batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)

        # Rebuild per-graph node offsets; assumes each graph's nodes are contiguous.
        if ptr is None:
            if batch.numel() > 0:
                counts = torch.bincount(batch)
                ptr = torch.cat([torch.tensor([0], device=batch.device), counts.cumsum(0)])
            else:
                ptr = torch.tensor([0], device=batch.device)

        # --- Sequence input: flat [num_graphs * seq_len] -> [num_graphs, seq_len] ---
        seq_flat = data.sequence
        batch_size = data.num_graphs
        seq_len = self.seq_len
        if seq_flat.numel() != batch_size * seq_len:
            print(f"Warning: Sequence tensor size mismatch: {seq_flat.numel()} vs expected {batch_size * seq_len}. Using zeros.")
            seq_tensor = torch.zeros((batch_size, seq_len), dtype=torch.long, device=seq_flat.device)
        else:
            seq_tensor = seq_flat.view(batch_size, seq_len)

        # --- GNN track: embeddings for every node ---
        gnn_node_features = self.gnn(x, edge_index)

        # Absolute central node index = first node of its graph + local index.
        absolute_central_node_indices = ptr[:-1] + central_node_idx
        if absolute_central_node_indices.numel() > 0:
            if absolute_central_node_indices.max() >= gnn_node_features.shape[0]:
                raise IndexError(f"Hybrid absolute central node index out of bounds: Max index {absolute_central_node_indices.max()} vs shape {gnn_node_features.shape[0]}")
            central_node_features = gnn_node_features[absolute_central_node_indices]
        else:
            print("Warning: No central node indices found.")
            central_node_features = torch.zeros((0, self.gnn.output_dim), device=gnn_node_features.device)

        global_avg_features = global_mean_pool(gnn_node_features, batch)
        global_max_features = global_max_pool(gnn_node_features, batch)

        # --- CNN track ---
        seq_features = self.sequence_cnn(seq_tensor)

        # --- Combine features (each part should have num_graphs rows) ---
        parts = [central_node_features, global_avg_features, global_max_features, seq_features]
        if not (central_node_features.shape[0] == global_avg_features.shape[0] == seq_features.shape[0]):
            print("Warning: Feature dimension mismatch before concatenation.")
            print(f"Central: {central_node_features.shape}, Global: {global_avg_features.shape}, Seq: {seq_features.shape}")
            # Best effort: truncate every part to the smallest row count.
            min_batch_size = min(central_node_features.shape[0], global_avg_features.shape[0], seq_features.shape[0])
            if min_batch_size <= 0:
                return torch.zeros((batch_size, 1), device=x.device)  # dummy output
            parts = [part[:min_batch_size] for part in parts]
        combined = torch.cat(parts, dim=1)

        # --- Final classification ---
        out = self.fc1(combined)
        if out.shape[0] > 0:  # BatchNorm1d cannot run on an empty batch
            out = self.bn(out)
        out = F.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)

        return torch.sigmoid(out)


def integer_encode_sequence(sequence):
    """Encode each residue as its index in ``AMINO_ACIDS`` for the CNN track.

    Characters missing from ``AA_TO_INT`` are encoded with the index of the
    padding character ``'-'``.

    Args:
        sequence: Iterable of single-character residue codes.

    Returns:
        list[int]: One integer code per input character.
    """
    padding_code = AA_TO_INT['-']
    return [AA_TO_INT.get(residue, padding_code) for residue in sequence]

def prepare_graph_data(df, distance_threshold=8.0, use_ss=True, include_extra_features=False):
    """
    Convert rows of ``df`` into PyTorch Geometric graphs, one per 33-residue window.

    Each row needs a 33-character ``sequence`` whose centre (index 16) is 'K',
    a ``label``, and the columns 'phi', 'psi', 'omega', 'tau', 'chi1'-'chi4',
    'sasa', 'ss', 'plDDT' and 'distance_map' (33*33 values, reshaped to 33x33).
    All of these columns must be present and parseable even though only some
    become features; rows failing any check are skipped and counted.

    Nodes are the non-padding residues ('-' marks padding). Node features, in
    order:
      * amino-acid one-hot over ``VALID_AA`` (20; unknown letters stay all-zero),
      * central-K indicator (1),
      * sin/cos of phi, psi and omega given in degrees (6; NaN -> 0 first),
      * only if ``include_extra_features``: SASA (1), SS one-hot H/E/L
        (3, only if ``use_ss``) and pLDDT (1), appended as raw values.

    Edges join residue pairs whose distance-map entry lies in
    (0, distance_threshold). If a graph has none, consecutive residues are
    linked in both directions instead. Edge features are a 4-bin one-hot of
    the distance (<=4, <=8, <=12, >12); fallback sequential edges use bin 0.

    Args:
        df: Input DataFrame as described above.
        distance_threshold: Exclusive upper bound for distance-based edges.
        use_ss: Include the SS one-hot; only has an effect together with
            ``include_extra_features``.
        include_extra_features: Append SASA / SS / pLDDT node features. The
            default (False) reproduces the 27-dim node features this function
            has always produced.

    Returns:
        tuple: ``(graph_list, labels)`` - the ``Data`` objects (with ``x``,
        ``edge_index``, ``edge_attr``, ``y``, ``sequence`` and
        ``central_node_idx``) and the matching raw labels.
    """
    graph_list = []
    labels = []
    skipped = 0
    expected_seq_len = 33
    central_k_pos_abs = 16  # 0-based index of the central K in the window

    # Every one of these columns must parse for a row to be kept.
    feature_names_to_parse = [
        'phi', 'psi', 'omega', 'tau', 'chi1', 'chi2', 'chi3', 'chi4',
        'sasa', 'ss', 'plDDT', 'distance_map'
    ]
    # Angles encoded as node features (tau and chi1-4 are parsed but unused).
    angle_keys = ['phi', 'psi', 'omega']
    # SS one-hot columns: H=0, E=1, L=2; '-' and unexpected characters count as L.
    ss_map = {'H': 0, 'E': 1, 'L': 2, '-': 2}

    for idx, row in df.iterrows():
        try:
            sequence = row['sequence']
            label = row['label']

            # --- Initial validation ---
            if pd.isna(sequence) or len(sequence) != expected_seq_len or sequence[central_k_pos_abs] != 'K':
                skipped += 1
                continue

            # --- Parse required columns ---
            parsed_data = {}
            valid_row = True
            for name in feature_names_to_parse:
                if name not in row or pd.isna(row[name]):
                    print(f"Warning: Missing or NaN data for '{name}' in row {idx}. Skipping row.")
                    valid_row = False
                    break
                try:
                    if name == 'ss':
                        parsed_data[name] = str(row[name])  # kept as a string
                        if len(parsed_data[name]) != expected_seq_len:
                            raise ValueError("SS sequence length mismatch")
                    else:
                        # SECURITY: eval() executes arbitrary code contained in the
                        # data file. Only use this with trusted input; for plain
                        # list literals ast.literal_eval is a safe replacement.
                        parsed_data[name] = np.array(eval(str(row[name])), dtype=np.float32)
                except Exception as e:
                    print(f"Warning: Error parsing '{name}' in row {idx}: {e}. Skipping row.")
                    valid_row = False
                    break
            if not valid_row:
                skipped += 1
                continue

            # The distance map arrives flattened and must be exactly 33x33.
            if parsed_data['distance_map'].size != expected_seq_len * expected_seq_len:
                print(f"Warning: Unexpected distance_map size ({parsed_data['distance_map'].size}) in row {idx}. Skipping.")
                skipped += 1
                continue
            distance_map = parsed_data['distance_map'].reshape(expected_seq_len, expected_seq_len)

            # --- Non-padding positions become nodes ---
            valid_pos_indices = [i for i, aa in enumerate(sequence) if aa != '-']
            if not valid_pos_indices:
                skipped += 1
                continue
            num_nodes = len(valid_pos_indices)
            valid_sequence = ''.join(sequence[i] for i in valid_pos_indices)

            # Node index of the central K (fails only if that position is padding).
            try:
                central_k_new_idx = valid_pos_indices.index(central_k_pos_abs)
            except ValueError:
                skipped += 1
                continue

            # --- Node features ---
            node_features_list = []

            # 1. Amino-acid one-hot (20).
            aa_onehot = np.zeros((num_nodes, len(VALID_AA)), dtype=np.float32)
            for i, aa in enumerate(valid_sequence):
                aa_idx = VALID_AA.find(aa)
                if aa_idx >= 0:
                    aa_onehot[i, aa_idx] = 1.0
            node_features_list.append(aa_onehot)

            # 2. Central-K indicator (1).
            is_central_k = np.zeros((num_nodes, 1), dtype=np.float32)
            is_central_k[central_k_new_idx, 0] = 1.0
            node_features_list.append(is_central_k)

            # 3. Angles in degrees -> (sin, cos) pairs (2 per angle).
            for key in angle_keys:
                valid_angles = np.nan_to_num(parsed_data[key][valid_pos_indices])
                angle_rad = np.pi * valid_angles / 180.0
                sin_cos_features = np.stack([np.sin(angle_rad), np.cos(angle_rad)], axis=-1)
                node_features_list.append(sin_cos_features.astype(np.float32))

            # 4-6. Optional extras. SASA and pLDDT are extracted for every row
            # (so rows whose arrays cannot be indexed are still rejected) but
            # are only appended when requested.
            valid_sasa = np.nan_to_num(parsed_data['sasa'][valid_pos_indices].reshape(-1, 1))
            valid_plddt = np.nan_to_num(parsed_data['plDDT'][valid_pos_indices].reshape(-1, 1))
            if include_extra_features:
                node_features_list.append(valid_sasa.astype(np.float32))
                if use_ss:
                    ss_string = parsed_data['ss']
                    ss_onehot = np.zeros((num_nodes, 3), dtype=np.float32)
                    for i, pos in enumerate(valid_pos_indices):
                        ss_onehot[i, ss_map.get(ss_string[pos], 2)] = 1.0
                    node_features_list.append(ss_onehot)
                node_features_list.append(valid_plddt.astype(np.float32))

            try:
                node_features = np.concatenate(node_features_list, axis=1)
            except ValueError as e:
                print(f"Error concatenating features for row {idx}: {e}")
                for i, feat in enumerate(node_features_list):
                    print(f"  Feature {i} shape: {feat.shape}")
                skipped += 1
                continue

            # --- Edges from the distance map restricted to node positions ---
            valid_distance_map = distance_map[np.ix_(valid_pos_indices, valid_pos_indices)]
            adj = (valid_distance_map < distance_threshold) & (valid_distance_map > 0)
            np.fill_diagonal(adj, False)  # no self-loops
            edge_list = np.argwhere(adj)  # [num_edges, 2] (i, j) pairs

            edges = edge_list.tolist()
            if edges:
                dists = valid_distance_map[edge_list[:, 0], edge_list[:, 1]]
                # Bin index: 0 for d<=4, 1 for 4<d<=8, 2 for 8<d<=12, 3 otherwise.
                bins = np.digitize(dists, [4.0, 8.0, 12.0], right=True)
                edge_features = list(np.eye(4, dtype=np.float32)[bins])
            else:
                edge_features = []

            # Fallback: link consecutive residues when no distance edge exists.
            if not edges and num_nodes > 1:
                default_edge_feat = np.array([1.0, 0.0, 0.0, 0.0], dtype=np.float32)
                for i in range(num_nodes - 1):
                    edges.extend([[i, i + 1], [i + 1, i]])
                    edge_features.extend([default_edge_feat, default_edge_feat])

            if not edges:
                skipped += 1
                continue

            # --- Convert to tensors ---
            x = torch.tensor(node_features, dtype=torch.float)
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
            # Stack first: building a tensor from a list of ndarrays is very slow.
            edge_attr = torch.tensor(np.stack(edge_features), dtype=torch.float)
            y = torch.tensor([label], dtype=torch.float)
            # The CNN sees the full 33-char window, padding included.
            sequence_tensor = torch.tensor(integer_encode_sequence(sequence), dtype=torch.long)

            data = Data(
                x=x,
                edge_index=edge_index,
                edge_attr=edge_attr,
                y=y,
                sequence=sequence_tensor,
                central_node_idx=torch.tensor([central_k_new_idx], dtype=torch.long)  # index within nodes
            )

            graph_list.append(data)
            labels.append(label)

        except Exception as e:
            print(f"--- Critical Error processing row {idx}: {e} ---")
            traceback.print_exc()
            skipped += 1

    print(f"\nCreated {len(graph_list)} graphs, skipped {skipped} rows during data preparation.")
    if graph_list:
        print(f"Example graph node feature dimension: {graph_list[0].x.shape[1]}")
        # Must mirror exactly the features appended above.
        extra_dim = (1 + (3 if use_ss else 0) + 1) if include_extra_features else 0
        expected_dim = len(VALID_AA) + 1 + len(angle_keys) * 2 + extra_dim
        print(f"Expected dimension based on config: {expected_dim}")
    return graph_list, labels


def train_model(model, loader, optimizer, device, class_weights=None):
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    Args:
        model: Module mapping a batch to probabilities of shape ``[B, 1]``;
            outputs must already lie in [0, 1] (``F.binary_cross_entropy``).
        loader: Iterable of batches exposing ``.to(device)``, ``.y`` and
            ``.num_graphs``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.
        class_weights: Optional mapping ``{0: w_neg, 1: w_pos}``; each sample's
            loss is scaled by the weight of its integer label.

    Returns:
        tuple: ``(loss, accuracy)`` where ``loss`` is the per-batch mean loss
        averaged over samples and ``accuracy`` uses a 0.5 threshold.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()

    weight_table = None
    if class_weights is not None:
        # Label -> weight lookup on the device, replacing a per-sample
        # ``.item()`` loop that forced a host sync for every element.
        weight_table = torch.tensor(
            [class_weights[0], class_weights[1]], dtype=torch.float, device=device
        )

    total_loss = 0.0
    correct = 0
    total = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        output = model(batch)
        target = batch.y.view(-1, 1)

        weight = None if weight_table is None else weight_table[target.long()]
        loss = F.binary_cross_entropy(output, target, weight=weight)

        loss.backward()
        optimizer.step()

        total_loss += loss.item() * batch.num_graphs

        pred = (output > 0.5).float()
        correct += (pred == target).sum().item()
        total += target.size(0)

    if total == 0:
        raise ValueError("train_model: loader yielded no samples")
    return total_loss / total, correct / total


def evaluate_model(model, loader, device):
    """Evaluate ``model`` on every batch of ``loader`` without gradient tracking.

    Predictions are the model's probability outputs thresholded at 0.5.

    Args:
        model: Module mapping a batch to probabilities of shape ``[B, 1]``.
        loader: Iterable of batches exposing ``.to(device)`` and ``.y``.
        device: Device the batches are moved to.

    Returns:
        dict: 'accuracy', 'balanced_acc', 'mcc', 'sensitivity'
        (TP / (TP + FN)), 'specificity' (TN / (TN + FP)), 'confusion_matrix'
        (always 2x2 for labels 0/1, unpacked as tn, fp, fn, tp), 'loss' (mean
        BCE per sample), and the flat 'predictions' / 'targets' arrays.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    pred_batches = []
    target_batches = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            output = model(batch)
            target = batch.y.view(-1, 1)

            # Summed here so the final division yields a true per-sample mean.
            total_loss += F.binary_cross_entropy(output, target, reduction='sum').item()

            pred_batches.append((output > 0.5).float().cpu().numpy())
            target_batches.append(target.cpu().numpy())

    if sum(t.shape[0] for t in target_batches) == 0:
        raise ValueError("evaluate_model: loader yielded no samples")

    all_preds = np.concatenate(pred_batches).flatten()
    all_targets = np.concatenate(target_batches).flatten()

    accuracy = accuracy_score(all_targets, all_preds)
    balanced_acc = balanced_accuracy_score(all_targets, all_preds)
    mcc = matthews_corrcoef(all_targets, all_preds)

    # Pin the label set so the matrix stays 2x2 even when a split contains a
    # single class; otherwise sensitivity/specificity would be reported as 0.
    cm = confusion_matrix(all_targets.astype(int), all_preds.astype(int), labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    return {
        'accuracy': accuracy,
        'balanced_acc': balanced_acc,
        'mcc': mcc,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'confusion_matrix': cm,
        'loss': total_loss / len(all_targets),
        'predictions': all_preds,
        'targets': all_targets
    }


def print_metrics(metrics):
    """Print the metrics dictionary produced by ``evaluate_model``.

    Args:
        metrics: Mapping with the scalar keys 'accuracy', 'balanced_acc',
            'mcc', 'sensitivity', 'specificity' (printed to 4 decimals) and
            'confusion_matrix' (printed as-is on its own line).
    """
    scalar_rows = (
        ("Accuracy", 'accuracy'),
        ("Balanced Accuracy", 'balanced_acc'),
        ("MCC", 'mcc'),
        ("Sensitivity", 'sensitivity'),
        ("Specificity", 'specificity'),
    )
    for title, key in scalar_rows:
        print(f"{title}: {metrics[key]:.4f}")
    print("Confusion Matrix:")
    print(metrics['confusion_matrix'])


def _binary_metrics_from_predictions(targets, preds):
    """Compute the binary-classification summary consumed by ``print_metrics``.

    Args:
        targets: Ground-truth 0/1 labels (any array-like).
        preds: Hard 0/1 predictions aligned with ``targets``.

    Returns:
        dict with ``accuracy``, ``balanced_acc``, ``mcc``, ``sensitivity``,
        ``specificity`` and a 2x2 ``confusion_matrix`` (rows = true label 0/1).
    """
    y_true = np.asarray(targets).astype(int).ravel()
    y_pred = np.asarray(preds).astype(int).ravel()
    # labels=[0, 1] keeps the matrix 2x2 even when only one class is present,
    # so sensitivity/specificity are always defined.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'balanced_acc': balanced_accuracy_score(y_true, y_pred),
        'mcc': matthews_corrcoef(y_true, y_pred),
        'sensitivity': tp / (tp + fn) if (tp + fn) > 0 else 0,
        'specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'confusion_matrix': cm,
    }


def train_with_cv(train_df, test_df, gnn_type='gcn', distance_threshold=8.0,
                  n_splits=5, epochs=50, patience=10, batch_size=32):
    """Train ``HybridModel`` with stratified k-fold CV and ensemble it on the test set.

    For each fold a fresh model is trained with early stopping on validation
    loss, and the weights from the best-validation epoch are restored before
    the fold is evaluated. Every fold model then scores the test graphs and
    the per-fold outputs are averaged into one ensemble prediction.

    Args:
        train_df: DataFrame passed to ``prepare_graph_data`` for training/CV.
        test_df: DataFrame passed to ``prepare_graph_data`` for the held-out test.
        gnn_type: GNN variant forwarded to ``HybridModel``.
        distance_threshold: Distance cutoff (Å) forwarded to ``prepare_graph_data``.
        n_splits: Number of stratified CV folds.
        epochs: Maximum training epochs per fold.
        patience: Epochs without validation-loss improvement before stopping.
        batch_size: Batch size for every DataLoader.

    Returns:
        The test-set ensemble metrics dict (same keys ``print_metrics`` reads),
        or ``None`` when no training graphs, test graphs or test labels exist.
    """
    print(f"\n--- Training {gnn_type.upper()} model with distance threshold {distance_threshold}Å ---")

    # Prepare data
    train_graphs, train_labels = prepare_graph_data(train_df, distance_threshold)
    test_graphs, test_labels = prepare_graph_data(test_df, distance_threshold)

    if not train_graphs:
        print("Error: No training graphs created.")
        return None

    # Inverse-frequency class weights (a perfectly balanced set gives 1.0 / 1.0).
    total = len(train_labels)
    pos = sum(train_labels)
    neg = total - pos
    class_weights = {
        0: total / (2 * neg) if neg > 0 else 1.0,
        1: total / (2 * pos) if pos > 0 else 1.0,
    }
    print(f"Class weights: {class_weights}")

    # Feature dimensions and the test loader are the same for every fold.
    node_feature_dim = train_graphs[0].x.shape[1]
    edge_feature_dim = train_graphs[0].edge_attr.shape[1]
    test_loader = (
        DataLoader(test_graphs, batch_size=batch_size, shuffle=False) if test_graphs else None
    )

    kfold = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    fold_metrics = []
    test_fold_scores = []

    # StratifiedKFold only uses X for its length; folds depend on labels + seed.
    fold_splits = kfold.split(np.zeros(len(train_labels)), train_labels)
    for fold, (train_idx, val_idx) in enumerate(fold_splits, 1):
        print(f"\n--- Fold {fold}/{n_splits} ---")

        train_fold = [train_graphs[i] for i in train_idx]
        val_fold = [train_graphs[i] for i in val_idx]
        train_loader = DataLoader(train_fold, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_fold, batch_size=batch_size, shuffle=False)

        model = HybridModel(
            gnn_type=gnn_type,
            node_feature_dim=node_feature_dim,
            edge_feature_dim=edge_feature_dim,
            hidden_dim=128,
        ).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)
        # `verbose` is deprecated for LR schedulers; LR drops are logged manually below.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=5
        )

        best_val_loss = float('inf')
        counter = 0
        best_state_dict = None

        for epoch in range(epochs):
            train_loss, train_acc = train_model(
                model, train_loader, optimizer, device, class_weights
            )

            val_metrics = evaluate_model(model, val_loader, device)
            val_loss = val_metrics['loss']
            val_acc = val_metrics['accuracy']

            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_loss)
            lr_after = optimizer.param_groups[0]['lr']

            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
            if lr_after < lr_before:
                print(f"  -> Learning rate reduced to {lr_after:.2e}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                counter = 0
                # state_dict() holds references to the live parameter tensors and
                # dict.copy() is shallow, so without cloning the "best" snapshot
                # would silently keep tracking the weights of the final epoch.
                best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()}
                print(f"  -> New best validation loss: {best_val_loss:.4f}")
            else:
                counter += 1
                if counter >= patience:
                    print("Early stopping triggered")
                    break

        # Restore the best-validation weights before evaluating this fold.
        if best_state_dict is not None:
            model.load_state_dict(best_state_dict)

        val_metrics = evaluate_model(model, val_loader, device)
        print("\nValidation metrics:")
        print_metrics(val_metrics)
        fold_metrics.append(val_metrics)

        if test_loader is not None:
            test_fold_metrics = evaluate_model(model, test_loader, device)
            # Average probabilities (soft voting) when available: hard votes
            # discard confidence and tie at exactly 0.5 for even fold counts.
            test_fold_scores.append(
                test_fold_metrics.get('probabilities', test_fold_metrics['predictions'])
            )

    print("\n--- Cross-validation summary ---")
    for metric in ('accuracy', 'balanced_acc', 'mcc', 'sensitivity', 'specificity'):
        values = [m[metric] for m in fold_metrics]
        print(f"{metric}: {np.mean(values):.4f} ± {np.std(values):.4f}")

    has_test_labels = test_labels is not None and len(test_labels) > 0
    if test_fold_scores and has_test_labels:
        print("\n--- Test set evaluation (ensemble) ---")
        test_pred_avg = np.mean(np.stack(test_fold_scores), axis=0)
        test_pred_binary = (test_pred_avg > 0.5).astype(int)
        test_metrics = _binary_metrics_from_predictions(test_labels, test_pred_binary)
        print_metrics(test_metrics)
        return test_metrics

    return None


if __name__ == "__main__":
    # Script entry: load the processed structure-feature CSVs and run
    # cross-validated training for each configured GNN type.
    try:
        train_df = pd.read_csv("../../data/train/structure/processed_features_train.csv")
        test_df = pd.read_csv("../../data/test/structure/processed_features_test.csv")

        # Print class distribution
        print("Train class distribution:", train_df['label'].value_counts().to_dict())
        print("Test class distribution:", test_df['label'].value_counts().to_dict())

        # Train models
        gnn_types = ['gcn']
        threshold = 8.0  # distance threshold (Å) handed to train_with_cv; try other values

        # Keep every run's test metrics, not only the last one; `results`
        # still holds the most recent run for later notebook cells.
        all_results = {}
        for gnn_type in gnn_types:
            results = train_with_cv(train_df, test_df, gnn_type=gnn_type, distance_threshold=threshold)
            all_results[gnn_type] = results

    except Exception as e:
        # Top-level boundary: report the message AND the full traceback so a
        # failure deep inside data preparation or training can be located.
        print(f"Error occurred: {e}")
        traceback.print_exc()

Using device: cuda
Train class distribution: {1: 4592, 0: 4261}
Test class distribution: {0: 2497, 1: 240}

--- Training GCN model with distance threshold 8.0Å ---

Created 8853 graphs, skipped 0 rows during data preparation.
Example graph node feature dimension: 27
Expected dimension based on config: 32

Created 2737 graphs, skipped 0 rows during data preparation.
Example graph node feature dimension: 27
Expected dimension based on config: 32
Class weights: {0: 1.0388406477352734, 1: 0.9639590592334495}

--- Fold 1/5 ---




Epoch 1/50, Train Loss: 0.6812, Train Acc: 0.5716, Val Loss: 0.6731, Val Acc: 0.5697
  -> New best validation loss: 0.6731




Epoch 2/50, Train Loss: 0.6471, Train Acc: 0.6330, Val Loss: 0.6552, Val Acc: 0.6047
  -> New best validation loss: 0.6552




Epoch 3/50, Train Loss: 0.6322, Train Acc: 0.6445, Val Loss: 0.6451, Val Acc: 0.6211
  -> New best validation loss: 0.6451




Epoch 4/50, Train Loss: 0.6167, Train Acc: 0.6638, Val Loss: 0.6389, Val Acc: 0.6307
  -> New best validation loss: 0.6389




Epoch 5/50, Train Loss: 0.6051, Train Acc: 0.6803, Val Loss: 0.6237, Val Acc: 0.6499
  -> New best validation loss: 0.6237




Epoch 6/50, Train Loss: 0.6042, Train Acc: 0.6757, Val Loss: 0.6315, Val Acc: 0.6403




Epoch 7/50, Train Loss: 0.5970, Train Acc: 0.6888, Val Loss: 0.6006, Val Acc: 0.6968
  -> New best validation loss: 0.6006




Epoch 8/50, Train Loss: 0.5874, Train Acc: 0.6903, Val Loss: 0.5880, Val Acc: 0.6945
  -> New best validation loss: 0.5880




Epoch 9/50, Train Loss: 0.5805, Train Acc: 0.7025, Val Loss: 0.5991, Val Acc: 0.6765




Epoch 10/50, Train Loss: 0.5780, Train Acc: 0.7019, Val Loss: 0.6105, Val Acc: 0.6392




Epoch 11/50, Train Loss: 0.5681, Train Acc: 0.7135, Val Loss: 0.6190, Val Acc: 0.6533




Epoch 12/50, Train Loss: 0.5609, Train Acc: 0.7177, Val Loss: 0.5755, Val Acc: 0.7064
  -> New best validation loss: 0.5755




Epoch 13/50, Train Loss: 0.5496, Train Acc: 0.7265, Val Loss: 0.5532, Val Acc: 0.7205
  -> New best validation loss: 0.5532




Epoch 14/50, Train Loss: 0.5435, Train Acc: 0.7348, Val Loss: 0.5578, Val Acc: 0.7301




Epoch 15/50, Train Loss: 0.5410, Train Acc: 0.7333, Val Loss: 0.5489, Val Acc: 0.7284
  -> New best validation loss: 0.5489




Epoch 16/50, Train Loss: 0.5355, Train Acc: 0.7402, Val Loss: 0.5457, Val Acc: 0.7312
  -> New best validation loss: 0.5457




Epoch 17/50, Train Loss: 0.5340, Train Acc: 0.7398, Val Loss: 0.5451, Val Acc: 0.7284
  -> New best validation loss: 0.5451




Epoch 18/50, Train Loss: 0.5265, Train Acc: 0.7492, Val Loss: 0.5353, Val Acc: 0.7397
  -> New best validation loss: 0.5353




Epoch 19/50, Train Loss: 0.5280, Train Acc: 0.7482, Val Loss: 0.5388, Val Acc: 0.7391




Epoch 20/50, Train Loss: 0.5179, Train Acc: 0.7484, Val Loss: 0.5369, Val Acc: 0.7357




Epoch 21/50, Train Loss: 0.5207, Train Acc: 0.7454, Val Loss: 0.5546, Val Acc: 0.7137




Epoch 22/50, Train Loss: 0.5174, Train Acc: 0.7553, Val Loss: 0.5381, Val Acc: 0.7318




Epoch 23/50, Train Loss: 0.5106, Train Acc: 0.7544, Val Loss: 0.5203, Val Acc: 0.7431
  -> New best validation loss: 0.5203




Epoch 24/50, Train Loss: 0.5151, Train Acc: 0.7557, Val Loss: 0.5376, Val Acc: 0.7307




Epoch 25/50, Train Loss: 0.5171, Train Acc: 0.7509, Val Loss: 0.5389, Val Acc: 0.7216




Epoch 26/50, Train Loss: 0.5023, Train Acc: 0.7609, Val Loss: 0.5304, Val Acc: 0.7465




Epoch 27/50, Train Loss: 0.5077, Train Acc: 0.7588, Val Loss: 0.5408, Val Acc: 0.7290




Epoch 28/50, Train Loss: 0.4990, Train Acc: 0.7598, Val Loss: 0.5410, Val Acc: 0.7284




Epoch 29/50, Train Loss: 0.5000, Train Acc: 0.7657, Val Loss: 0.5372, Val Acc: 0.7301




Epoch 30/50, Train Loss: 0.4752, Train Acc: 0.7810, Val Loss: 0.5275, Val Acc: 0.7470




Epoch 31/50, Train Loss: 0.4612, Train Acc: 0.7912, Val Loss: 0.5292, Val Acc: 0.7295




Epoch 32/50, Train Loss: 0.4500, Train Acc: 0.7917, Val Loss: 0.5322, Val Acc: 0.7374




Epoch 33/50, Train Loss: 0.4509, Train Acc: 0.7948, Val Loss: 0.5300, Val Acc: 0.7363
Early stopping triggered

Validation metrics:
Accuracy: 0.7363
Balanced Accuracy: 0.7349
MCC: 0.4715
Sensitivity: 0.7715
Specificity: 0.6984
Confusion Matrix:
[[595 257]
 [210 709]]





--- Fold 2/5 ---




Epoch 1/50, Train Loss: 0.6873, Train Acc: 0.5638, Val Loss: 0.7061, Val Acc: 0.5212
  -> New best validation loss: 0.7061




Epoch 2/50, Train Loss: 0.6599, Train Acc: 0.6127, Val Loss: 0.6595, Val Acc: 0.6138
  -> New best validation loss: 0.6595




Epoch 3/50, Train Loss: 0.6366, Train Acc: 0.6456, Val Loss: 0.6550, Val Acc: 0.5839
  -> New best validation loss: 0.6550




Epoch 4/50, Train Loss: 0.6204, Train Acc: 0.6644, Val Loss: 0.6304, Val Acc: 0.6951
  -> New best validation loss: 0.6304




Epoch 5/50, Train Loss: 0.6176, Train Acc: 0.6704, Val Loss: 0.6118, Val Acc: 0.6945
  -> New best validation loss: 0.6118




Epoch 6/50, Train Loss: 0.6044, Train Acc: 0.6837, Val Loss: 0.6319, Val Acc: 0.6132




Epoch 7/50, Train Loss: 0.5982, Train Acc: 0.6851, Val Loss: 0.6199, Val Acc: 0.6516




Epoch 8/50, Train Loss: 0.5895, Train Acc: 0.6958, Val Loss: 0.6089, Val Acc: 0.7019
  -> New best validation loss: 0.6089




Epoch 9/50, Train Loss: 0.5911, Train Acc: 0.6963, Val Loss: 0.5960, Val Acc: 0.7047
  -> New best validation loss: 0.5960




Epoch 10/50, Train Loss: 0.5777, Train Acc: 0.7057, Val Loss: 0.5876, Val Acc: 0.7143
  -> New best validation loss: 0.5876




Epoch 11/50, Train Loss: 0.5796, Train Acc: 0.7026, Val Loss: 0.5674, Val Acc: 0.7307
  -> New best validation loss: 0.5674




Epoch 12/50, Train Loss: 0.5717, Train Acc: 0.7071, Val Loss: 0.5789, Val Acc: 0.6985




Epoch 13/50, Train Loss: 0.5701, Train Acc: 0.7052, Val Loss: 0.5725, Val Acc: 0.7177




Epoch 14/50, Train Loss: 0.5583, Train Acc: 0.7169, Val Loss: 0.5710, Val Acc: 0.7132




Epoch 15/50, Train Loss: 0.5494, Train Acc: 0.7323, Val Loss: 0.5524, Val Acc: 0.7329
  -> New best validation loss: 0.5524




Epoch 16/50, Train Loss: 0.5397, Train Acc: 0.7324, Val Loss: 0.5677, Val Acc: 0.7103




Epoch 17/50, Train Loss: 0.5336, Train Acc: 0.7347, Val Loss: 0.5556, Val Acc: 0.7340




Epoch 18/50, Train Loss: 0.5317, Train Acc: 0.7379, Val Loss: 0.5334, Val Acc: 0.7493
  -> New best validation loss: 0.5334




Epoch 19/50, Train Loss: 0.5250, Train Acc: 0.7444, Val Loss: 0.5505, Val Acc: 0.7228




Epoch 20/50, Train Loss: 0.5235, Train Acc: 0.7471, Val Loss: 0.5310, Val Acc: 0.7516
  -> New best validation loss: 0.5310




Epoch 21/50, Train Loss: 0.5165, Train Acc: 0.7454, Val Loss: 0.5277, Val Acc: 0.7470
  -> New best validation loss: 0.5277




Epoch 22/50, Train Loss: 0.5150, Train Acc: 0.7559, Val Loss: 0.5247, Val Acc: 0.7431
  -> New best validation loss: 0.5247




Epoch 23/50, Train Loss: 0.5075, Train Acc: 0.7581, Val Loss: 0.5192, Val Acc: 0.7493
  -> New best validation loss: 0.5192




Epoch 24/50, Train Loss: 0.5071, Train Acc: 0.7591, Val Loss: 0.5163, Val Acc: 0.7691
  -> New best validation loss: 0.5163




Epoch 25/50, Train Loss: 0.5025, Train Acc: 0.7600, Val Loss: 0.5404, Val Acc: 0.7380




Epoch 26/50, Train Loss: 0.5034, Train Acc: 0.7632, Val Loss: 0.5338, Val Acc: 0.7414




Epoch 27/50, Train Loss: 0.5037, Train Acc: 0.7557, Val Loss: 0.5320, Val Acc: 0.7459




Epoch 28/50, Train Loss: 0.4969, Train Acc: 0.7663, Val Loss: 0.5272, Val Acc: 0.7391




Epoch 29/50, Train Loss: 0.4919, Train Acc: 0.7667, Val Loss: 0.5292, Val Acc: 0.7391




Epoch 30/50, Train Loss: 0.4892, Train Acc: 0.7669, Val Loss: 0.5290, Val Acc: 0.7448




Epoch 31/50, Train Loss: 0.4575, Train Acc: 0.7907, Val Loss: 0.5166, Val Acc: 0.7465




Epoch 32/50, Train Loss: 0.4548, Train Acc: 0.7943, Val Loss: 0.5328, Val Acc: 0.7233




Epoch 33/50, Train Loss: 0.4367, Train Acc: 0.7978, Val Loss: 0.5417, Val Acc: 0.7154




Epoch 34/50, Train Loss: 0.4359, Train Acc: 0.8040, Val Loss: 0.5197, Val Acc: 0.7470
Early stopping triggered

Validation metrics:
Accuracy: 0.7470
Balanced Accuracy: 0.7466
MCC: 0.4933
Sensitivity: 0.7573
Specificity: 0.7359
Confusion Matrix:
[[627 225]
 [223 696]]





--- Fold 3/5 ---




Epoch 1/50, Train Loss: 0.6827, Train Acc: 0.5686, Val Loss: 0.6719, Val Acc: 0.5567
  -> New best validation loss: 0.6719




Epoch 2/50, Train Loss: 0.6497, Train Acc: 0.6274, Val Loss: 0.6640, Val Acc: 0.5968
  -> New best validation loss: 0.6640




Epoch 3/50, Train Loss: 0.6277, Train Acc: 0.6471, Val Loss: 0.6775, Val Acc: 0.5686




Epoch 4/50, Train Loss: 0.6158, Train Acc: 0.6709, Val Loss: 0.6264, Val Acc: 0.6539
  -> New best validation loss: 0.6264




Epoch 5/50, Train Loss: 0.6061, Train Acc: 0.6759, Val Loss: 0.6505, Val Acc: 0.5963




Epoch 6/50, Train Loss: 0.5980, Train Acc: 0.6833, Val Loss: 0.6438, Val Acc: 0.6126




Epoch 7/50, Train Loss: 0.5914, Train Acc: 0.6865, Val Loss: 0.6293, Val Acc: 0.6341




Epoch 8/50, Train Loss: 0.5834, Train Acc: 0.7014, Val Loss: 0.6099, Val Acc: 0.6759
  -> New best validation loss: 0.6099




Epoch 9/50, Train Loss: 0.5815, Train Acc: 0.7005, Val Loss: 0.5995, Val Acc: 0.6877
  -> New best validation loss: 0.5995




Epoch 10/50, Train Loss: 0.5754, Train Acc: 0.7059, Val Loss: 0.6025, Val Acc: 0.6719




Epoch 11/50, Train Loss: 0.5692, Train Acc: 0.7169, Val Loss: 0.6272, Val Acc: 0.6324




Epoch 12/50, Train Loss: 0.5660, Train Acc: 0.7170, Val Loss: 0.6378, Val Acc: 0.6307




Epoch 13/50, Train Loss: 0.5678, Train Acc: 0.7095, Val Loss: 0.6338, Val Acc: 0.6239




Epoch 14/50, Train Loss: 0.5602, Train Acc: 0.7155, Val Loss: 0.6050, Val Acc: 0.6601




Epoch 15/50, Train Loss: 0.5575, Train Acc: 0.7193, Val Loss: 0.6003, Val Acc: 0.6765




Epoch 16/50, Train Loss: 0.5281, Train Acc: 0.7463, Val Loss: 0.5640, Val Acc: 0.7115
  -> New best validation loss: 0.5640




Epoch 17/50, Train Loss: 0.5171, Train Acc: 0.7465, Val Loss: 0.6245, Val Acc: 0.6431




Epoch 18/50, Train Loss: 0.5060, Train Acc: 0.7609, Val Loss: 0.5979, Val Acc: 0.6787




Epoch 19/50, Train Loss: 0.4941, Train Acc: 0.7691, Val Loss: 0.5899, Val Acc: 0.6838




Epoch 20/50, Train Loss: 0.4869, Train Acc: 0.7725, Val Loss: 0.5743, Val Acc: 0.6945




Epoch 21/50, Train Loss: 0.4773, Train Acc: 0.7794, Val Loss: 0.5844, Val Acc: 0.6753




Epoch 22/50, Train Loss: 0.4676, Train Acc: 0.7883, Val Loss: 0.5741, Val Acc: 0.6900




Epoch 23/50, Train Loss: 0.4301, Train Acc: 0.8104, Val Loss: 0.5690, Val Acc: 0.6900




Epoch 24/50, Train Loss: 0.4151, Train Acc: 0.8171, Val Loss: 0.5919, Val Acc: 0.6815




Epoch 25/50, Train Loss: 0.4001, Train Acc: 0.8242, Val Loss: 0.5726, Val Acc: 0.6945




Epoch 26/50, Train Loss: 0.3993, Train Acc: 0.8276, Val Loss: 0.5650, Val Acc: 0.7019
Early stopping triggered

Validation metrics:
Accuracy: 0.7019
Balanced Accuracy: 0.7044
MCC: 0.4117
Sensitivity: 0.6351
Specificity: 0.7737
Confusion Matrix:
[[660 193]
 [335 583]]





--- Fold 4/5 ---




Epoch 1/50, Train Loss: 0.6847, Train Acc: 0.5581, Val Loss: 0.6786, Val Acc: 0.5520
  -> New best validation loss: 0.6786




Epoch 2/50, Train Loss: 0.6581, Train Acc: 0.6054, Val Loss: 0.6578, Val Acc: 0.6056
  -> New best validation loss: 0.6578




Epoch 3/50, Train Loss: 0.6326, Train Acc: 0.6468, Val Loss: 0.6506, Val Acc: 0.6147
  -> New best validation loss: 0.6506




Epoch 4/50, Train Loss: 0.6222, Train Acc: 0.6527, Val Loss: 0.6385, Val Acc: 0.6554
  -> New best validation loss: 0.6385




Epoch 5/50, Train Loss: 0.6100, Train Acc: 0.6754, Val Loss: 0.6506, Val Acc: 0.5949




Epoch 6/50, Train Loss: 0.6039, Train Acc: 0.6744, Val Loss: 0.6090, Val Acc: 0.6859
  -> New best validation loss: 0.6090




Epoch 7/50, Train Loss: 0.5913, Train Acc: 0.6967, Val Loss: 0.6142, Val Acc: 0.6746




Epoch 8/50, Train Loss: 0.5945, Train Acc: 0.6893, Val Loss: 0.6281, Val Acc: 0.6215




Epoch 9/50, Train Loss: 0.5815, Train Acc: 0.6983, Val Loss: 0.6201, Val Acc: 0.6469




Epoch 10/50, Train Loss: 0.5815, Train Acc: 0.7018, Val Loss: 0.5952, Val Acc: 0.7006
  -> New best validation loss: 0.5952




Epoch 11/50, Train Loss: 0.5779, Train Acc: 0.7045, Val Loss: 0.5837, Val Acc: 0.7028
  -> New best validation loss: 0.5837




Epoch 12/50, Train Loss: 0.5753, Train Acc: 0.7062, Val Loss: 0.5823, Val Acc: 0.7124
  -> New best validation loss: 0.5823




Epoch 13/50, Train Loss: 0.5738, Train Acc: 0.7066, Val Loss: 0.5859, Val Acc: 0.6887




Epoch 14/50, Train Loss: 0.5629, Train Acc: 0.7171, Val Loss: 0.5785, Val Acc: 0.7073
  -> New best validation loss: 0.5785




Epoch 15/50, Train Loss: 0.5517, Train Acc: 0.7253, Val Loss: 0.5743, Val Acc: 0.7107
  -> New best validation loss: 0.5743




Epoch 16/50, Train Loss: 0.5452, Train Acc: 0.7327, Val Loss: 0.5537, Val Acc: 0.7266
  -> New best validation loss: 0.5537




Epoch 17/50, Train Loss: 0.5374, Train Acc: 0.7363, Val Loss: 0.5622, Val Acc: 0.7333




Epoch 18/50, Train Loss: 0.5355, Train Acc: 0.7406, Val Loss: 0.5819, Val Acc: 0.6966




Epoch 19/50, Train Loss: 0.5312, Train Acc: 0.7414, Val Loss: 0.5495, Val Acc: 0.7294
  -> New best validation loss: 0.5495




Epoch 20/50, Train Loss: 0.5246, Train Acc: 0.7470, Val Loss: 0.5565, Val Acc: 0.7158




Epoch 21/50, Train Loss: 0.5216, Train Acc: 0.7518, Val Loss: 0.5583, Val Acc: 0.7175




Epoch 22/50, Train Loss: 0.5198, Train Acc: 0.7495, Val Loss: 0.5622, Val Acc: 0.7203




Epoch 23/50, Train Loss: 0.5121, Train Acc: 0.7558, Val Loss: 0.5524, Val Acc: 0.7333




Epoch 24/50, Train Loss: 0.5152, Train Acc: 0.7560, Val Loss: 0.5386, Val Acc: 0.7418
  -> New best validation loss: 0.5386




Epoch 25/50, Train Loss: 0.5030, Train Acc: 0.7606, Val Loss: 0.5414, Val Acc: 0.7215




Epoch 26/50, Train Loss: 0.5111, Train Acc: 0.7563, Val Loss: 0.5494, Val Acc: 0.7356




Epoch 27/50, Train Loss: 0.5040, Train Acc: 0.7574, Val Loss: 0.5442, Val Acc: 0.7362




Epoch 28/50, Train Loss: 0.4997, Train Acc: 0.7625, Val Loss: 0.5533, Val Acc: 0.7220




Epoch 29/50, Train Loss: 0.4999, Train Acc: 0.7589, Val Loss: 0.5341, Val Acc: 0.7395
  -> New best validation loss: 0.5341




Epoch 30/50, Train Loss: 0.5010, Train Acc: 0.7607, Val Loss: 0.5518, Val Acc: 0.7192




Epoch 31/50, Train Loss: 0.4980, Train Acc: 0.7668, Val Loss: 0.5282, Val Acc: 0.7384
  -> New best validation loss: 0.5282




Epoch 32/50, Train Loss: 0.4983, Train Acc: 0.7630, Val Loss: 0.5415, Val Acc: 0.7260




Epoch 33/50, Train Loss: 0.4946, Train Acc: 0.7686, Val Loss: 0.5427, Val Acc: 0.7249




Epoch 34/50, Train Loss: 0.4978, Train Acc: 0.7639, Val Loss: 0.5420, Val Acc: 0.7384




Epoch 35/50, Train Loss: 0.4891, Train Acc: 0.7696, Val Loss: 0.5463, Val Acc: 0.7209




Epoch 36/50, Train Loss: 0.4840, Train Acc: 0.7761, Val Loss: 0.5261, Val Acc: 0.7492
  -> New best validation loss: 0.5261




Epoch 37/50, Train Loss: 0.4873, Train Acc: 0.7689, Val Loss: 0.5421, Val Acc: 0.7362




Epoch 38/50, Train Loss: 0.4847, Train Acc: 0.7709, Val Loss: 0.5399, Val Acc: 0.7441




Epoch 39/50, Train Loss: 0.4746, Train Acc: 0.7855, Val Loss: 0.5466, Val Acc: 0.7356




Epoch 40/50, Train Loss: 0.4798, Train Acc: 0.7747, Val Loss: 0.5444, Val Acc: 0.7316




Epoch 41/50, Train Loss: 0.4792, Train Acc: 0.7774, Val Loss: 0.5514, Val Acc: 0.7203




Epoch 42/50, Train Loss: 0.4759, Train Acc: 0.7790, Val Loss: 0.5560, Val Acc: 0.7282




Epoch 43/50, Train Loss: 0.4517, Train Acc: 0.7950, Val Loss: 0.5396, Val Acc: 0.7316




Epoch 44/50, Train Loss: 0.4368, Train Acc: 0.8008, Val Loss: 0.5395, Val Acc: 0.7345




Epoch 45/50, Train Loss: 0.4319, Train Acc: 0.8066, Val Loss: 0.5453, Val Acc: 0.7243




Epoch 46/50, Train Loss: 0.4207, Train Acc: 0.8125, Val Loss: 0.5523, Val Acc: 0.7271
Early stopping triggered

Validation metrics:
Accuracy: 0.7271
Balanced Accuracy: 0.7257
MCC: 0.4530
Sensitivity: 0.7625
Specificity: 0.6890
Confusion Matrix:
[[587 265]
 [218 700]]





--- Fold 5/5 ---




Epoch 1/50, Train Loss: 0.6752, Train Acc: 0.5865, Val Loss: 0.6658, Val Acc: 0.5774
  -> New best validation loss: 0.6658




Epoch 2/50, Train Loss: 0.6433, Train Acc: 0.6342, Val Loss: 0.6445, Val Acc: 0.6226
  -> New best validation loss: 0.6445




Epoch 3/50, Train Loss: 0.6230, Train Acc: 0.6606, Val Loss: 0.6293, Val Acc: 0.6542
  -> New best validation loss: 0.6293




Epoch 4/50, Train Loss: 0.6045, Train Acc: 0.6780, Val Loss: 0.7300, Val Acc: 0.5463




Epoch 5/50, Train Loss: 0.6031, Train Acc: 0.6801, Val Loss: 0.6238, Val Acc: 0.6746
  -> New best validation loss: 0.6238




Epoch 6/50, Train Loss: 0.5920, Train Acc: 0.6888, Val Loss: 0.6211, Val Acc: 0.6695
  -> New best validation loss: 0.6211




Epoch 7/50, Train Loss: 0.5851, Train Acc: 0.6919, Val Loss: 0.6004, Val Acc: 0.6836
  -> New best validation loss: 0.6004




Epoch 8/50, Train Loss: 0.5861, Train Acc: 0.6891, Val Loss: 0.5961, Val Acc: 0.6819
  -> New best validation loss: 0.5961




Epoch 9/50, Train Loss: 0.5787, Train Acc: 0.7034, Val Loss: 0.5884, Val Acc: 0.6994
  -> New best validation loss: 0.5884




Epoch 10/50, Train Loss: 0.5782, Train Acc: 0.7055, Val Loss: 0.6012, Val Acc: 0.6780




Epoch 11/50, Train Loss: 0.5726, Train Acc: 0.7085, Val Loss: 0.5952, Val Acc: 0.6960




Epoch 12/50, Train Loss: 0.5709, Train Acc: 0.7127, Val Loss: 0.6039, Val Acc: 0.6944




Epoch 13/50, Train Loss: 0.5682, Train Acc: 0.7069, Val Loss: 0.6126, Val Acc: 0.6520




Epoch 14/50, Train Loss: 0.5647, Train Acc: 0.7181, Val Loss: 0.6330, Val Acc: 0.6090




Epoch 15/50, Train Loss: 0.5607, Train Acc: 0.7189, Val Loss: 0.5819, Val Acc: 0.6927
  -> New best validation loss: 0.5819




Epoch 16/50, Train Loss: 0.5574, Train Acc: 0.7190, Val Loss: 0.5714, Val Acc: 0.7186
  -> New best validation loss: 0.5714




Epoch 17/50, Train Loss: 0.5537, Train Acc: 0.7306, Val Loss: 0.5790, Val Acc: 0.6955




Epoch 18/50, Train Loss: 0.5466, Train Acc: 0.7332, Val Loss: 0.5769, Val Acc: 0.6977




Epoch 19/50, Train Loss: 0.5430, Train Acc: 0.7395, Val Loss: 0.5616, Val Acc: 0.7226
  -> New best validation loss: 0.5616




Epoch 20/50, Train Loss: 0.5380, Train Acc: 0.7332, Val Loss: 0.5677, Val Acc: 0.7073




Epoch 21/50, Train Loss: 0.5380, Train Acc: 0.7418, Val Loss: 0.5719, Val Acc: 0.7102




Epoch 22/50, Train Loss: 0.5328, Train Acc: 0.7409, Val Loss: 0.5715, Val Acc: 0.7023




Epoch 23/50, Train Loss: 0.5306, Train Acc: 0.7466, Val Loss: 0.5544, Val Acc: 0.7254
  -> New best validation loss: 0.5544




Epoch 24/50, Train Loss: 0.5204, Train Acc: 0.7450, Val Loss: 0.5572, Val Acc: 0.7158




Epoch 25/50, Train Loss: 0.5185, Train Acc: 0.7497, Val Loss: 0.5624, Val Acc: 0.7153




Epoch 26/50, Train Loss: 0.5158, Train Acc: 0.7507, Val Loss: 0.5506, Val Acc: 0.7266
  -> New best validation loss: 0.5506




Epoch 27/50, Train Loss: 0.5146, Train Acc: 0.7525, Val Loss: 0.5456, Val Acc: 0.7373
  -> New best validation loss: 0.5456




Epoch 28/50, Train Loss: 0.5090, Train Acc: 0.7617, Val Loss: 0.5650, Val Acc: 0.7141




Epoch 29/50, Train Loss: 0.5123, Train Acc: 0.7601, Val Loss: 0.5409, Val Acc: 0.7311
  -> New best validation loss: 0.5409




Epoch 30/50, Train Loss: 0.5064, Train Acc: 0.7642, Val Loss: 0.5536, Val Acc: 0.7186




Epoch 31/50, Train Loss: 0.5061, Train Acc: 0.7652, Val Loss: 0.5448, Val Acc: 0.7345




Epoch 32/50, Train Loss: 0.4923, Train Acc: 0.7737, Val Loss: 0.5546, Val Acc: 0.7130




Epoch 33/50, Train Loss: 0.5022, Train Acc: 0.7596, Val Loss: 0.5402, Val Acc: 0.7339
  -> New best validation loss: 0.5402




Epoch 34/50, Train Loss: 0.4873, Train Acc: 0.7735, Val Loss: 0.5504, Val Acc: 0.7254




Epoch 35/50, Train Loss: 0.4954, Train Acc: 0.7644, Val Loss: 0.5545, Val Acc: 0.7254




Epoch 36/50, Train Loss: 0.4912, Train Acc: 0.7720, Val Loss: 0.5418, Val Acc: 0.7367




Epoch 37/50, Train Loss: 0.4887, Train Acc: 0.7693, Val Loss: 0.5624, Val Acc: 0.7119




Epoch 38/50, Train Loss: 0.4879, Train Acc: 0.7687, Val Loss: 0.5499, Val Acc: 0.7294




Epoch 39/50, Train Loss: 0.4816, Train Acc: 0.7765, Val Loss: 0.5595, Val Acc: 0.7102




Epoch 40/50, Train Loss: 0.4606, Train Acc: 0.7885, Val Loss: 0.5408, Val Acc: 0.7328




Epoch 41/50, Train Loss: 0.4463, Train Acc: 0.7990, Val Loss: 0.5487, Val Acc: 0.7254




Epoch 42/50, Train Loss: 0.4425, Train Acc: 0.8054, Val Loss: 0.5464, Val Acc: 0.7339




Epoch 43/50, Train Loss: 0.4264, Train Acc: 0.8067, Val Loss: 0.5617, Val Acc: 0.7141
Early stopping triggered

Validation metrics:
Accuracy: 0.7141
Balanced Accuracy: 0.7153
MCC: 0.4307
Sensitivity: 0.6841
Specificity: 0.7465
Confusion Matrix:
[[636 216]
 [290 628]]





--- Cross-validation summary ---
accuracy: 0.7253 ± 0.0159
balanced_acc: 0.7254 ± 0.0147
mcc: 0.4520 ± 0.0289
sensitivity: 0.7221 ± 0.0535
specificity: 0.7287 ± 0.0313

--- Test set evaluation (ensemble) ---
Accuracy: 0.7391
Balanced Accuracy: 0.7648
MCC: 0.3230
Sensitivity: 0.7958
Specificity: 0.7337
Confusion Matrix:
[[1832  665]
 [  49  191]]
