In [100]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold
# Required for Imputation/Preprocessing
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix, roc_curve, auc
from scipy.signal import butter, filtfilt, iirnotch, find_peaks
import scipy.io # Required for loading .mat files
import matplotlib.pyplot as plt
from typing import List, Dict, Any
import os
from tqdm import tqdm

# --- GLOBAL PARAMETERS ---
FS = 500  # Sampling Frequency (Hz)
LEAD_II_INDEX = 1
BEAT_LENGTH_SAMPLES = 600
NUMERICAL_FEATURES = ['age_at_exam', 'weight', 'trainning_load', 'BMI', 'BSA']
CATEGORICAL_FEATURES = ['sex', 'sport_classification']
TABULAR_INPUT_SIZE = 9 # Calculated as 5 Numerical + 2 One-Hot (sex) + 2 One-Hot (sport_classification)
N_SPLITS = 5
BATCH_SIZE = 32
NUM_EPOCHS = 10
# NOTE: Set to 5 to run all folds. Change to 1 for quick execution.
NUM_FOLDS_TO_RUN = 5


from google.colab import drive

drive.mount("/content/drive")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [90]:
# --------------------------------------------------------------------------------
# 1. ECG SIGNAL PROCESSING FUNCTIONS
# --------------------------------------------------------------------------------

def apply_bandpass_filter(data, lowcut=1.0, highcut=40.0, fs=FS, order=5):
    """Applies a Butterworth bandpass filter."""
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return filtfilt(b, a, data)

def notch_filter(data, notch_freq=50.0, fs=FS, Q=30):
    """Applies a 50Hz Notch filter to remove powerline noise."""
    nyq = 0.5 * fs
    w0 = notch_freq / nyq
    b, a = iirnotch(w0, Q)
    return filtfilt(b, a, data)

def preprocess_signal(ecg: np.ndarray) -> np.ndarray:
    """Applies per-lead Z-score normalization to the already-filtered signal."""
    ecg = ecg.copy()
    for i in range(ecg.shape[1]):
        ecg[:, i] = (ecg[:, i] - np.mean(ecg[:, i])) / (np.std(ecg[:, i]) + 1e-6)
    return ecg

def r_peak_detection_and_segmentation(ecg_12_leads: np.ndarray, fs: int = FS) -> np.ndarray:
    """Detects R-peaks, segments beats, and returns the average beat morphology."""
    lead_for_detection = ecg_12_leads[:, LEAD_II_INDEX]

    diff_signal = np.diff(lead_for_detection)**2
    window_size = int(0.150 * fs)
    integrated_signal = np.convolve(diff_signal, np.ones(window_size)/window_size, mode='same')

    distance_min = int(0.3 * fs)
    peak_threshold = np.max(integrated_signal) * 0.4
    r_peaks_idx, _ = find_peaks(integrated_signal, height=peak_threshold, distance=distance_min)

    all_beats = []
    half_beat = BEAT_LENGTH_SAMPLES // 2

    for r_idx in r_peaks_idx:
        start_idx = r_idx - half_beat
        end_idx = r_idx + half_beat
        if start_idx >= 0 and end_idx <= ecg_12_leads.shape[0]:
            beat = ecg_12_leads[start_idx:end_idx, :]
            all_beats.append(beat)

    if not all_beats:
        return np.zeros((12, BEAT_LENGTH_SAMPLES), dtype=np.float32)

    segmented_ecg = np.array(all_beats, dtype=np.float32)
    segmented_ecg = np.transpose(segmented_ecg, (0, 2, 1))

    representative_ecg = np.mean(segmented_ecg, axis=0)
    return representative_ecg

In [91]:
# --------------------------------------------------------------------------------
# 2. TABULAR FEATURE ENGINEERING AND PROCESSING
# --------------------------------------------------------------------------------

def tabular_feature_engineering(df: pd.DataFrame) -> pd.DataFrame:
    """Creates BMI and BSA, and cleans initial outliers."""
    df = df.copy()
    df.loc[(df['age_at_exam'] < 0.0) | (df['age_at_exam'] > 100.0), 'age_at_exam'] = np.nan
    df.loc[(df['trainning_load'] <= 0.0) | (df['trainning_load'] > 4.0), 'trainning_load'] = np.nan

    df['height_m'] = df['height'] / 100.0
    df['BMI'] = df['weight'] / (df['height_m']**2)
    df['BSA'] = np.sqrt((df['height'] * df['weight']) / 3600.0)

    df = df.drop(columns=['height', 'height_m'])
    return df

class TabularProcessor:
    """Handles Iterative Imputation, StandardScaler, and OneHotEncoder."""
    def __init__(self, numerical_features: List[str], categorical_features: List[str]):
        self.numerical_features = numerical_features
        self.categorical_features = categorical_features
        self.imputer = IterativeImputer(max_iter=10, random_state=42)
        self.scaler = StandardScaler()
        self.encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')

    def fit(self, df: pd.DataFrame):
        numerical_data_train = df[self.numerical_features].values
        self.imputer.fit(numerical_data_train)
        imputed_data_train = self.imputer.transform(numerical_data_train)
        self.scaler.fit(imputed_data_train)
        categorical_data_train = df[self.categorical_features].astype(str).values
        self.encoder.fit(categorical_data_train)

    def transform(self, df: pd.DataFrame) -> np.ndarray:
        numerical_data = df[self.numerical_features].values
        imputed_data = self.imputer.transform(numerical_data)
        scaled_data = self.scaler.transform(imputed_data)
        categorical_data = df[self.categorical_features].astype(str).values
        encoded_data = self.encoder.transform(categorical_data)
        return np.concatenate([scaled_data, encoded_data], axis=1)



In [92]:
# --------------------------------------------------------------------------------
# 3. PYTORCH DATASET AND MODEL DEFINITIONS
# --------------------------------------------------------------------------------

class ECGDataset(Dataset):
    """PyTorch Dataset to load ECG and tabular data."""
    def __init__(self, ecg_data: List[np.ndarray], tabular_data: pd.DataFrame, labels: np.ndarray):
        self.ecg_data = ecg_data
        self.tabular_data = tabular_data
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # ECG signal is already filtered and baseline corrected from the loading step (Block 1)
        raw_ecg = self.ecg_data[idx]
        cleaned_ecg = preprocess_signal(raw_ecg)
        segmented_ecg = r_peak_detection_and_segmentation(cleaned_ecg)
        ecg_tensor = torch.from_numpy(segmented_ecg).float()

        tabular_row = self.tabular_data.iloc[idx].values
        tabular_tensor = torch.from_numpy(tabular_row).float()

        label_tensor = torch.tensor(self.labels[idx], dtype=torch.float32)

        return {
            'ecg': ecg_tensor, 'tabular': tabular_tensor, 'label': label_tensor
        }

class ThreeBranchCNNModel(nn.Module):
    """CNN model with fixed feature size (4640)."""
    def __init__(self, tabular_input_size, num_classes=1):
        super(ThreeBranchCNNModel, self).__init__()

        conv_block = nn.Sequential(
            nn.Conv1d(in_channels=6, out_channels=32, kernel_size=16, stride=2, padding=7), nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=2),
            nn.Conv1d(in_channels=32, out_channels=64, kernel_size=8, stride=2, padding=3), nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=2),
        )
        self.ecg_frontal_branch = conv_block
        self.ecg_precordial_branch = conv_block

        output_sequence_length = 36
        feature_size_per_branch = 64 * output_sequence_length

        self.tabular_branch = nn.Sequential(
            nn.Linear(tabular_input_size, 64), nn.ReLU(), nn.Dropout(0.3), nn.Linear(64, 32)
        )

        total_combined_size = feature_size_per_branch * 2 + 32 # 4640

        self.classifier_head = nn.Sequential(
            nn.Linear(total_combined_size, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, num_classes), nn.Sigmoid()
        )

    def forward(self, ecg_tensor, tabular_tensor):
        ecg_frontal = ecg_tensor[:, :6, :]
        ecg_precordial = ecg_tensor[:, 6:, :]

        feat_frontal = torch.flatten(self.ecg_frontal_branch(ecg_frontal), 1)
        feat_precordial = torch.flatten(self.ecg_precordial_branch(ecg_precordial), 1)
        feat_tabular = self.tabular_branch(tabular_tensor)

        combined_features = torch.cat((feat_frontal, feat_precordial, feat_tabular), dim=1)

        return self.classifier_head(combined_features)


In [96]:

# --------------------------------------------------------------------------------
# 4. METRICS AND PLOTTING FUNCTIONS
# --------------------------------------------------------------------------------

def compute_metrics(y_true, y_pred_proba):
    """Computes AUC, Accuracy, Sensitivity, and Specificity."""
    if len(np.unique(y_true)) < 2:
        auc_score = 0.5
        fpr, tpr = np.array([0.0, 1.0]), np.array([0.0, 1.0])
    else:
        fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
        auc_score = auc(fpr, tpr)

    y_pred = (y_pred_proba >= 0.5).astype(int)
    accuracy = accuracy_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    return {
        'AUC': auc_score, 'Accuracy': accuracy,
        'Sensitivity': sensitivity, 'Specificity': specificity,
        'FPR': fpr, 'TPR': tpr, 'CM': cm, 'Y_PRED_BINARY': y_pred
    }

def plot_confusion_matrix(cm, fold_num, file_path):
    """Plots and saves the Confusion Matrix for a specific fold."""
    plt.figure(figsize=(6, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title(f'Confusion Matrix - Fold {fold_num}')
    plt.colorbar(fraction=0.046, pad=0.04)
    tick_marks = np.arange(2)
    plt.xticks(tick_marks, ['Negative (0)', 'Positive (1)'])
    plt.yticks(tick_marks, ['Negative (0)', 'Positive (1)'])

    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                     ha="center", va="center", fontsize=16,
                     color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig(file_path)
    plt.close()

def plot_aggregate_roc(all_fprs, all_tprs, all_aucs, file_path):
    """Plots the aggregate ROC curve across all folds."""
    plt.figure(figsize=(8, 8))
    mean_fpr = np.linspace(0, 1, 100)
    tprs_interp = []

    for i in range(len(all_fprs)):
        tprs_interp.append(np.interp(mean_fpr, all_fprs[i], all_tprs[i]))
        plt.plot(all_fprs[i], all_tprs[i], alpha=0.3, label=f'ROC Fold {i+1} (AUC = {all_aucs[i]:.2f})')

    plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Chance', alpha=.8)

    mean_tpr = np.mean(tprs_interp, axis=0)
    mean_tpr[0] = 0.0
    mean_tpr[-1] = 1.0
    mean_auc = auc(mean_fpr, mean_tpr)
    std_auc = np.std(all_aucs)
    std_tpr = np.std(tprs_interp, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)

    plt.plot(mean_fpr, mean_tpr, color='blue',
             label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
             lw=2, alpha=.8)

    plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='skyblue', alpha=.2,
                     label=r'$\pm$ 1 standard dev.')

    plt.xlim([-0.01, 1.01])
    plt.ylim([-0.01, 1.01])
    plt.xlabel('False Positive Rate (FPR)')
    plt.ylabel('True Positive Rate (TPR)')
    plt.title('Aggregate Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right", fontsize='small')
    plt.grid(True)
    plt.savefig(file_path)
    plt.close()


In [101]:

print("--- Starting Block 1: Data Loading and Preparation ---")
try:
    # --- ASSUMED GOOGLE DRIVE PATHS (Required for .mat files) ---
    ECG_folder_1batch = "/content/drive/MyDrive/WP_02_data/1_batch_extracted"
    ECG_folder_2batch = "/content/drive/MyDrive/WP_02_data/2_batch_extracted"

    # Load Tabular Data from accessible CSV files
    tabular_data_1 = pd.read_excel("/content/drive/MyDrive/WP_02_data/VALETUDO_database_1st_batch_en_all_info.xlsx")
    tabular_data_2 = pd.read_excel("/content/drive/MyDrive/WP_02_data/VALETUDO_database_2nd_batch_en_all_info.xlsx")

    # Concatenate and sort tabular data
    tabular_df = pd.concat([tabular_data_1, tabular_data_2], ignore_index=True)
    tabular_df = tabular_df.sort_values(by="ECG_patient_id").reset_index(drop=True)

    # --- Load ECG File Paths and Sort ---
    ECGs_1 = [f for f in os.listdir(ECG_folder_1batch) if f.endswith(".mat")]
    ECGs_2 = [f for f in os.listdir(ECG_folder_2batch) if f.endswith(".mat")]

    def extract_patient_id(filename):
        return int(filename.split(".")[0])

    ECGs_1.sort(key=extract_patient_id)
    ECGs_2.sort(key=extract_patient_id)

    # NOTE: The list of ECG files should match the number of subjects in the tabular data.
    if (len(ECGs_1) + len(ECGs_2)) != len(tabular_df):
        print("⚠️ Warning: Number of ECG files does not match tabular entries. Check file consistency.")
        # Attempt to proceed using the tabular size as the ground truth
        N_SUBJECTS = len(tabular_df)
    else:
        N_SUBJECTS = len(tabular_df)

    # --- Load ECG Signals and Apply Filtering ---
    # The expected shape is 5000 samples, 12 leads
    raw_ecg_list = []

    # Load Batch 1
    for ecg_path in tqdm(ECGs_1, desc="Loading ECG Batch 1"):
        filepath = os.path.join(ECG_folder_1batch, ecg_path)
        matdata = scipy.io.loadmat(filepath)
        ecg = matdata['val'].T # ECGs typically come in (12, 5000) or (5000, 12). Transpose if needed. Assume (Samples, Leads) for consistency.
        if ecg.shape[0] != 5000: ecg = ecg.T # Ensure (5000, 12)

        # Apply filtering and baseline correction during loading
        filtered_ecg = np.empty_like(ecg)
        for i in range(12):
            lead_data = ecg[:, i]
            lead_data = lead_data - np.mean(lead_data) # Baseline Correction
            lead_data = apply_bandpass_filter(lead_data)
            lead_data = notch_filter(lead_data)
            filtered_ecg[:, i] = lead_data
        raw_ecg_list.append(filtered_ecg.astype(np.float32))

    # Load Batch 2
    for ecg_path in tqdm(ECGs_2, desc="Loading ECG Batch 2"):
        filepath = os.path.join(ECG_folder_2batch, ecg_path)
        matdata = scipy.io.loadmat(filepath)
        ecg = matdata['val'].T
        if ecg.shape[0] != 5000: ecg = ecg.T

        # Apply filtering and baseline correction during loading
        filtered_ecg = np.empty_like(ecg)
        for i in range(12):
            lead_data = ecg[:, i]
            lead_data = lead_data - np.mean(lead_data) # Baseline Correction
            lead_data = apply_bandpass_filter(lead_data)
            lead_data = notch_filter(lead_data)
            filtered_ecg[:, i] = lead_data
        raw_ecg_list.append(filtered_ecg.astype(np.float32))

    print(f"Successfully loaded {N_SUBJECTS} subjects.")
    print(f"First ECG signal shape: {raw_ecg_list[0].shape}")

except Exception as e:
    print(f"\n❌ FATAL ERROR: Data Loading Failed ❌")
    print("The script failed to load data, likely because the Google Drive paths were not mounted or the files were inaccessible.")
    print(f"Error details: {e}")
    # Exit or raise error to stop further execution if real data is mandatory
    raise RuntimeError("Failed to load real ECG data from disk. Execution stopped.")


# 1. Feature Engineering (Applied once)
labels = tabular_df['sport_ability'].values
processed_tabular_df = tabular_feature_engineering(tabular_df)

# 2. Tabular Preprocessing to get input size
temp_processor = TabularProcessor(NUMERICAL_FEATURES, CATEGORICAL_FEATURES)
temp_processor.fit(processed_tabular_df)
TABULAR_INPUT_SIZE = temp_processor.transform(processed_tabular_df.head(1)).shape[1]

print(f"Block 1: Data preparation complete. TABULAR_INPUT_SIZE: {TABULAR_INPUT_SIZE}")

--- Starting Block 1: Data Loading and Preparation ---


Loading ECG Batch 1: 100%|██████████| 191/191 [00:07<00:00, 23.88it/s]
Loading ECG Batch 2: 100%|██████████| 335/335 [00:10<00:00, 32.17it/s]


Successfully loaded 526 subjects.
First ECG signal shape: (5000, 12)
Block 1: Data preparation complete. TABULAR_INPUT_SIZE: 9


In [102]:
# --------------------------------------------------------------------------------
# BLOCK 2: K-FOLD EXECUTION WITH TRAINING AND METRICS
# --------------------------------------------------------------------------------
print("\n--- Starting Block 2: K-Fold Cross-Validation ---")
print(f"Running {NUM_FOLDS_TO_RUN} Fold(s), {NUM_EPOCHS} Epochs.")

skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
indices = np.arange(N_SUBJECTS)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Stratified K-Fold with {N_SPLITS} splits initialized. Using device: {device}")

# Lists to store metrics across folds
all_fprs = []
all_tprs = []
all_aucs = []
fold_metrics = []

for fold, (train_index, val_index) in enumerate(skf.split(indices, labels)):
    if fold >= NUM_FOLDS_TO_RUN:
        break

    print(f"\n--- Starting Fold {fold+1}/{N_SPLITS} ---")

    # --- A. SPLIT DATA ---
    train_df_raw = processed_tabular_df.iloc[train_index].reset_index(drop=True)
    val_df_raw = processed_tabular_df.iloc[val_index].reset_index(drop=True)
    train_labels = labels[train_index]
    val_labels = labels[val_index]
    train_ecg_list = [raw_ecg_list[i] for i in train_index]
    val_ecg_list = [raw_ecg_list[i] for i in val_index]

    # --- B. TABULAR PREPROCESSING (Fit only on training data) ---
    tabular_processor = TabularProcessor(NUMERICAL_FEATURES, CATEGORICAL_FEATURES)
    tabular_processor.fit(train_df_raw)
    train_tab_processed_arr = tabular_processor.transform(train_df_raw)
    val_tab_processed_arr = tabular_processor.transform(val_df_raw)

    train_tab_processed_df = pd.DataFrame(train_tab_processed_arr)
    val_tab_processed_df = pd.DataFrame(val_tab_processed_arr)

    # --- C. DATASET AND DATALOADER CREATION ---
    train_dataset = ECGDataset(
        ecg_data=train_ecg_list, tabular_data=train_tab_processed_df, labels=train_labels
    )
    val_dataset = ECGDataset(
        ecg_data=val_ecg_list, tabular_data=val_tab_processed_df, labels=val_labels
    )

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # --- D. MODEL TRAINING SETUP ---
    model = ThreeBranchCNNModel(tabular_input_size=TABULAR_INPUT_SIZE).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print(f"  Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")

    # --- E. TRAINING LOOP ---
    for epoch in range(NUM_EPOCHS):

        # Training Phase
        model.train()
        train_loss = 0.0

        for batch in train_loader:
            ecg_data = batch['ecg'].to(device)
            tabular_data = batch['tabular'].to(device)
            labels_tensor = batch['label'].to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(ecg_data, tabular_data)
            loss = criterion(outputs, labels_tensor)

            loss.backward()
            optimizer.step()
            train_loss += loss.item() * ecg_data.size(0)

        avg_train_loss = train_loss / len(train_dataset)

        # Validation Phase
        model.eval()
        val_loss = 0.0
        val_y_true = []
        val_y_pred_proba = []

        with torch.no_grad():
            for batch in val_loader:
                ecg_data = batch['ecg'].to(device)
                tabular_data = batch['tabular'].to(device)
                labels_tensor = batch['label'].to(device).unsqueeze(1)

                outputs = model(ecg_data, tabular_data)
                loss = criterion(outputs, labels_tensor)
                val_loss += loss.item() * ecg_data.size(0)

                # Store predictions and true labels for metrics
                val_y_true.extend(labels_tensor.cpu().numpy().flatten())
                val_y_pred_proba.extend(outputs.cpu().numpy().flatten())

        avg_val_loss = val_loss / len(val_dataset)

        # --- F. METRICS CALCULATION AND REPORTING (Only at the end of the last epoch) ---
        if epoch == NUM_EPOCHS - 1:
            val_metrics = compute_metrics(np.array(val_y_true), np.array(val_y_pred_proba))

            print(f"  Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
            print(f"    Validation Metrics (Final Epoch):")
            print(f"      AUC: {val_metrics['AUC']:.4f}")
            print(f"      Accuracy: {val_metrics['Accuracy']:.4f}")
            print(f"      Sensitivity: {val_metrics['Sensitivity']:.4f}")
            print(f"      Specificity: {val_metrics['Specificity']:.4f}")

            # Store for aggregate plots
            all_fprs.append(val_metrics['FPR'])
            all_tprs.append(val_metrics['TPR'])
            all_aucs.append(val_metrics['AUC'])
            fold_metrics.append({
                'Fold': fold + 1,
                'AUC': val_metrics['AUC'],
                'Accuracy': val_metrics['Accuracy'],
                'Sensitivity': val_metrics['Sensitivity'],
                'Specificity': val_metrics['Specificity']
            })

            # Save Confusion Matrix
            cm_filename = f"confusion_matrix_fold_{fold+1}.png"
            plot_confusion_matrix(val_metrics['CM'], fold+1, cm_filename)
            print(f"  Confusion Matrix saved as {cm_filename}")

        else:
            print(f"  Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    print(f"--- Fold {fold+1} Completed ({NUM_EPOCHS} Epochs) ---")

# --- G. FINAL AGGREGATE METRICS & PLOT (if running multiple folds) ---
if NUM_FOLDS_TO_RUN > 1:
    agg_df = pd.DataFrame(fold_metrics)
    mean_metrics = agg_df.drop(columns=['Fold']).mean().to_dict()
    std_metrics = agg_df.drop(columns=['Fold']).std().to_dict()

    print("\n--- Aggregate Cross-Validation Results ---")
    print(f"Mean AUC: {mean_metrics['AUC']:.4f} \u00B1 {std_metrics.get('AUC', 0):.4f}")
    print(f"Mean Accuracy: {mean_metrics['Accuracy']:.4f} \u00B1 {std_metrics.get('Accuracy', 0):.4f}")
    print(f"Mean Sensitivity: {mean_metrics['Sensitivity']:.4f} \u00B1 {std_metrics.get('Sensitivity', 0):.4f}")
    print(f"Mean Specificity: {mean_metrics['Specificity']:.4f} \u00B1 {std_metrics.get('Specificity', 0):.4f}")

    roc_filename = "aggregate_roc_curve.png"
    plot_aggregate_roc(all_fprs, all_tprs, all_aucs, roc_filename)
    print(f"Aggregate ROC Curve saved as {roc_filename}")

print("\n--- Block 2 Execution Finished ---")


--- Starting Block 2: K-Fold Cross-Validation ---
Running 5 Fold(s), 10 Epochs.
Stratified K-Fold with 5 splits initialized. Using device: cpu

--- Starting Fold 1/5 ---
  Train size: 420, Validation size: 106
  Epoch 1/10, Train Loss: 0.6435, Val Loss: 0.6117
  Epoch 2/10, Train Loss: 0.5830, Val Loss: 0.5824
  Epoch 3/10, Train Loss: 0.5591, Val Loss: 0.5581
  Epoch 4/10, Train Loss: 0.5289, Val Loss: 0.5278
  Epoch 5/10, Train Loss: 0.4910, Val Loss: 0.5409
  Epoch 6/10, Train Loss: 0.4284, Val Loss: 0.5675
  Epoch 7/10, Train Loss: 0.4700, Val Loss: 0.6241
  Epoch 8/10, Train Loss: 0.4044, Val Loss: 0.5259
  Epoch 9/10, Train Loss: 0.3585, Val Loss: 0.5708
  Epoch 10/10, Train Loss: 0.3310, Val Loss: 0.5846
    Validation Metrics (Final Epoch):
      AUC: 0.7639
      Accuracy: 0.7264
      Sensitivity: 0.9167
      Specificity: 0.3235
  Confusion Matrix saved as confusion_matrix_fold_1.png
--- Fold 1 Completed (10 Epochs) ---

--- Starting Fold 2/5 ---
  Train size: 421, Validati