In [10]:
#import libraries

# from google.colab import drive
from sklearn.model_selection import StratifiedKFold
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import os
import scipy
from scipy.signal import butter, filtfilt, iirnotch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.optim as optim


from tqdm import tqdm
from sklearn.metrics import f1_score, recall_score, accuracy_score, confusion_matrix, balanced_accuracy_score, roc_auc_score,  roc_curve
import matplotlib.pyplot as plt
from google.colab import drive

# Mount Google Drive so the files under /content/drive become readable.
drive.mount("/content/drive")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [12]:
#prepare functions for filtering

def butter_bandpass(lowcut, highcut, fs, order=4):
    """Design a digital Butterworth band-pass filter.

    Parameters
    ----------
    lowcut, highcut : float
        Lower and upper band edges, in the same units as ``fs``.
    fs : float
        Sampling frequency of the signal to be filtered.
    order : int, optional
        Filter order handed to ``scipy.signal.butter`` (default 4).

    Returns
    -------
    tuple
        ``(b, a)`` numerator and denominator coefficients.
    """
    # scipy expects the band edges as fractions of the Nyquist frequency.
    nyquist = 0.5 * fs
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, normalized_band, btype='band')
    return numerator, denominator

def apply_bandpass_filter(data, lowcut=1, highcut=40, fs=500, order=2):
    """Band-pass filter ``data`` with a forward-backward Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.filtfilt``, which runs the filter in both directions
    along the last axis of ``data``.

    Parameters
    ----------
    data : array_like
        Signal to filter.
    lowcut, highcut : float, optional
        Band edges in the same units as ``fs`` (defaults 1 and 40).
    fs : float, optional
        Sampling frequency (default 500).
    order : int, optional
        Order of the underlying Butterworth design (default 2).

    Returns
    -------
    numpy.ndarray
        Filtered signal with the same shape as ``data``.
    """
    coefficients = butter_bandpass(lowcut, highcut, fs, order=order)
    numerator, denominator = coefficients
    return filtfilt(numerator, denominator, data)

def notch_filter(data, freq=50, fs=500, quality_factor=30):
    """Remove a narrow frequency band from ``data`` with an IIR notch filter.

    Parameters
    ----------
    data : array_like
        Signal to filter.
    freq : float, optional
        Frequency to suppress, in the same units as ``fs`` (default 50).
    fs : float, optional
        Sampling frequency (default 500).
    quality_factor : float, optional
        Quality factor passed to ``scipy.signal.iirnotch`` (default 30).

    Returns
    -------
    numpy.ndarray
        Signal filtered forward and backward with ``scipy.signal.filtfilt``.
    """
    # iirnotch takes the notch frequency as a fraction of the Nyquist frequency.
    normalized_freq = freq / (fs / 2)
    numerator, denominator = iirnotch(normalized_freq, quality_factor)
    return filtfilt(numerator, denominator, data)

In [13]:
# Import the data and filter the signals.
# DATA_DIR is the only path to change when the data is read from a local
# download instead of the mounted Google Drive.

DATA_DIR = "/content/drive/MyDrive/WP_02_data"

ECG_folder = os.path.join(DATA_DIR, "1_batch_extracted")
ECG_folder_2batch = os.path.join(DATA_DIR, "2_batch_extracted")

# One spreadsheet of tabular information per batch.
tabular_data = pd.read_excel(os.path.join(DATA_DIR, "VALETUDO_database_1st_batch_en_all_info.xlsx"))
tabular_data_2batch = pd.read_excel(os.path.join(DATA_DIR, "VALETUDO_database_2nd_batch_en_all_info.xlsx"))

# --- List the .mat recordings of both batches ---
ECGs_1 = [f for f in os.listdir(ECG_folder) if f.endswith(".mat")]
ECGs_2 = [f for f in os.listdir(ECG_folder_2batch) if f.endswith(".mat")]

def extract_patient_id(filename):
    """Return the integer written before the first dot of ``filename``.

    Used as a sort key so that recordings are ordered numerically rather
    than lexicographically. Raises ``ValueError`` if that leading part is
    not a valid integer.
    """
    stem, _, _ = filename.partition(".")
    return int(stem)

# Sort the recordings numerically by the id in their filename so their order
# follows the tabular data, which is sorted by "ECG_patient_id" below.
ECGs_1.sort(key=extract_patient_id)
ECGs_2.sort(key=extract_patient_id)


def load_filtered_batch(folder, filenames, n_samples=5000, n_leads=12):
    """Load and filter every recording of one batch.

    Each file is read with ``scipy.io.loadmat`` and its ``'val'`` entry must
    be an ``(n_samples, n_leads)`` array. Every lead is mean-centred, then
    band-pass filtered, then notch filtered, using the default settings of
    ``apply_bandpass_filter`` and ``notch_filter``.

    Parameters
    ----------
    folder : str
        Directory containing the ``.mat`` files.
    filenames : list of str
        File names inside ``folder``, in the order the rows should appear.
    n_samples : int, optional
        Expected number of samples per lead (default 5000).
    n_leads : int, optional
        Expected number of leads per recording (default 12).

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(len(filenames), n_samples, n_leads)``.

    Raises
    ------
    ValueError
        If a recording does not have shape ``(n_samples, n_leads)``.
    """
    batch = np.empty((len(filenames), n_samples, n_leads))
    for index, ecg_path in enumerate(filenames):
        matdata = scipy.io.loadmat(os.path.join(folder, ecg_path))
        # Filter in float64: writing filtered values back into an integer
        # array would silently truncate them.
        ecg = np.asarray(matdata['val'], dtype=np.float64)
        if ecg.shape != (n_samples, n_leads):
            raise ValueError(
                f"{ecg_path}: expected 'val' of shape {(n_samples, n_leads)}, got {ecg.shape}"
            )
        for lead in range(n_leads):
            centred = ecg[:, lead] - np.mean(ecg[:, lead])
            ecg[:, lead] = notch_filter(apply_bandpass_filter(centred))
        batch[index, :, :] = ecg
    return batch


signals_1 = load_filtered_batch(ECG_folder, ECGs_1)
signals_2 = load_filtered_batch(ECG_folder_2batch, ECGs_2)

# --- Concatenate signals and tabular data ---
signals = np.concatenate([signals_1, signals_2], axis=0)
tabular_data = pd.concat([
    tabular_data.sort_values(by="ECG_patient_id").reset_index(drop=True),
    tabular_data_2batch.sort_values(by="ECG_patient_id").reset_index(drop=True)
], ignore_index=True)

# Row i of `signals` is meant to describe the same patient as row i of
# `tabular_data`; report any mismatch instead of silently pairing the wrong
# records. Assumes the number in each filename is the ECG_patient_id -- TODO confirm.
signal_ids = [extract_patient_id(name) for name in ECGs_1 + ECGs_2]
if signal_ids != tabular_data["ECG_patient_id"].tolist():
    print("WARNING: patient ids from the .mat filenames do not match 'ECG_patient_id' row by row.")

print("Combined signals shape:", signals.shape)
print("Combined tabular shape:", tabular_data.shape)


Combined signals shape: (526, 5000, 12)
Combined tabular shape: (526, 18)


array([[[ 4.50956743e-03,  3.80522812e-04, -4.12896013e-03, ...,
         -8.89387505e-02, -1.00410555e-01, -7.33718458e-02],
        [ 6.56458429e-03, -5.93895223e-05, -6.62396086e-03, ...,
         -8.92465207e-02, -1.02408333e-01, -7.56250467e-02],
        [ 8.60942900e-03, -4.24155269e-04, -9.03363677e-03, ...,
         -8.95468145e-02, -1.04145592e-01, -7.75775363e-02],
        ...,
        [-4.70671610e-03, -2.05793271e-02, -1.58732963e-02, ...,
         -6.70482275e-01, -5.06983539e-01, -3.53964836e-02],
        [-3.92982256e-03, -1.87085712e-02, -1.47793132e-02, ...,
         -6.15810302e-01, -4.60681550e-01, -9.46843458e-03],
        [-3.00378176e-03, -1.75008414e-02, -1.44974871e-02, ...,
         -5.60657223e-01, -4.13499363e-01,  1.63508109e-02]],

       [[-7.09383101e-03, -3.74490081e-02, -2.98189056e-02, ...,
          5.35597694e-02,  2.05006513e-02,  5.75837114e-03],
        [-7.09169202e-03, -3.79165242e-02, -3.02929560e-02, ...,
          5.34601851e-02,  2.03410368e

In [16]:
# Count NaNs in the filtered signal array and in each tabular column.
n_missing_signals = np.isnan(signals).sum()
missing_per_column = tabular_data.isnull().sum()

print("Missing values in signals (NumPy array):", n_missing_signals)
print("\nMissing values per column in tabular_data (Pandas DataFrame):\n", missing_per_column)

Missing values in signals (NumPy array): 0

Missing values per column in tabular_data (Pandas DataFrame):
 ECG_patient_id                   0
age_at_exam                      0
sex                              0
weight                          74
height                          74
trainning_load                   1
sport_classification             0
sport_ability                    0
AV block                         0
ST abnormality                   0
Complete BBB                     0
Prolonged QTc                    0
Uncontrolled hypertension        0
Supraventricular arrhythmias     0
Ventricular arrhythmias          0
Baseline ECG abnormalities       0
Valvular heart diseases          0
Symptomatic patients             0
dtype: int64


In [15]:
# Display the combined tabular data of both batches as the cell output.
tabular_data

Unnamed: 0,ECG_patient_id,age_at_exam,sex,weight,height,trainning_load,sport_classification,sport_ability,AV block,ST abnormality,Complete BBB,Prolonged QTc,Uncontrolled hypertension,Supraventricular arrhythmias,Ventricular arrhythmias,Baseline ECG abnormalities,Valvular heart diseases,Symptomatic patients
0,3,20.095825,0,60.0,166.0,2.0,1,0,0,0,0,0,0,0,1,0,0,0
1,4,51.646817,0,84.0,180.0,2.0,1,0,0,1,0,0,1,0,0,0,0,0
2,5,40.936345,0,104.0,180.0,1.0,1,0,0,0,0,0,0,0,1,0,0,0
3,6,14.201232,0,80.0,174.0,2.0,1,1,0,0,0,0,0,0,0,0,0,0
4,7,16.607803,1,47.0,148.0,2.0,1,0,0,1,0,0,0,0,0,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
521,549,15.167693,1,60.0,163.0,2.0,1,0,0,0,0,0,0,0,1,0,0,0
522,550,27.764545,0,89.0,176.0,4.0,1,1,0,0,0,0,0,0,0,0,0,0
523,551,15.780972,0,72.0,183.0,3.0,1,1,0,0,0,0,0,0,0,0,0,0
524,552,22.661191,0,72.0,172.0,2.0,1,1,0,0,0,0,0,0,0,0,0,0
