# Import libraries


In [None]:
#import libraries

# Suppress noisy library warnings (UserWarning and FutureWarning)
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

SEED = 42

PATH = "/content/drive/MyDrive/PW02-Neuroengineering/"


from google.colab import drive
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.utils.class_weight import compute_class_weight
from sklearn.utils import resample
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import os
from collections import Counter
import math
import scipy
from scipy.signal import butter, filtfilt, iirnotch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.optim as optim
import seaborn as sns
import gc
import optuna

import time

from tqdm import tqdm
from sklearn.metrics import f1_score, recall_score, accuracy_score, confusion_matrix, balanced_accuracy_score, roc_auc_score,  roc_curve, classification_report
import matplotlib.pyplot as plt
from google.colab import drive



In [None]:
#drive.mount("/content/drive")
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data Exploration and Data Analysis

In [None]:
#prepare functions for filtering

def butter_bandpass(lowcut, highcut, fs, order=4):
    """Design a Butterworth band-pass filter.

    Both cut-off frequencies are divided by the Nyquist frequency
    (``fs / 2``) before being passed to ``scipy.signal.butter``.

    Args:
        lowcut: lower cut-off frequency, in the same unit as ``fs``.
        highcut: upper cut-off frequency, in the same unit as ``fs``.
        fs: sampling frequency.
        order: filter order handed to ``butter``.

    Returns:
        The ``(b, a)`` coefficient arrays of the filter.
    """
    nyquist = 0.5 * fs
    band_edges = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, band_edges, btype='band')
    return numerator, denominator

def apply_bandpass_filter(data, lowcut=1, highcut=40, fs=500, order=2):
    """Band-pass ``data`` with a forward-backward Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.filtfilt``.

    Returns:
        The filtered signal as returned by ``filtfilt``.
    """
    coeff_b, coeff_a = butter_bandpass(lowcut, highcut, fs, order=order)
    filtered = filtfilt(coeff_b, coeff_a, data)
    return filtered

def notch_filter(data, freq=50, fs=500, quality_factor=30):
    """Remove a narrow band around ``freq`` (50 Hz noise by default).

    The notch is designed with ``scipy.signal.iirnotch`` on the frequency
    normalised by ``fs / 2`` and applied with ``filtfilt``.

    Returns:
        The filtered signal as returned by ``filtfilt``.
    """
    normalized_freq = freq / (fs / 2)
    coeff_b, coeff_a = iirnotch(normalized_freq, quality_factor)
    return filtfilt(coeff_b, coeff_a, data)

In [None]:
# Kaggle only: point PATH at the dataset mounted under /kaggle/input.
# This overrides the Google Drive location assigned to PATH earlier in the notebook.

PATH = "/kaggle/input/test-ecg-pw02-1/drive-download-20251124T104021Z-1-001/"

In [None]:
# Locate the data: ECG folders, tabular spreadsheets and the .mat file lists

# Folders with the extracted recordings of each batch (PATH ends with "/")
ECG_folder = f"{PATH}1_batch_extracted"
ECG_folder_2batch = f"{PATH}2_batch_extracted"


# One spreadsheet of tabular patient information per batch
tabular_data = pd.read_excel(f"{PATH}VALETUDO_database_1st_batch_en_all_info.xlsx")
tabular_data_2batch = pd.read_excel(f"{PATH}VALETUDO_database_2nd_batch_en_all_info.xlsx")

# --- Load and filter both batches ---
# Keep only the MATLAB (.mat) files found in each folder
ECGs_1 = [f for f in os.listdir(ECG_folder) if f.endswith(".mat")]
ECGs_2 = [f for f in os.listdir(ECG_folder_2batch) if f.endswith(".mat")]

def extract_patient_id(filename):
    """Return the integer that precedes the first dot of ``filename``.

    Raises:
        ValueError: if that leading part is not a valid integer.
    """
    stem, _, _ = filename.partition(".")
    return int(stem)

# Order both file lists by the numeric patient id taken from the file name
ECGs_1.sort(key=extract_patient_id)
ECGs_2.sort(key=extract_patient_id)

# Pre-allocate one 3-D array per batch: (recordings, 5000, 12)
#   5000 --> time length (samples) / 12 --> leads
signals_1 = np.empty((len(ECGs_1), 5000, 12))
signals_2 = np.empty((len(ECGs_2), 5000, 12))


In [None]:

def load_filtered_batch(folder, filenames, out):
    """Load every recording of one batch, filter each lead and store it in ``out``.

    For every lead the signal is centred on zero, band-pass filtered and then
    notch filtered, in that order.

    Args:
        folder: directory that contains the ``.mat`` files.
        filenames: file names inside ``folder``; their order defines the row
            order in ``out``.
        out: pre-allocated array of shape (len(filenames), samples, leads)
            that receives the filtered recordings.

    Returns:
        ``out``, filled in place.
    """
    n_leads = out.shape[2]
    for index, ecg_path in enumerate(filenames):
        matdata = scipy.io.loadmat(os.path.join(folder, ecg_path))
        # loadmat keeps the dtype stored in the file. Work on a float copy so
        # that writing the filtered values back cannot truncate them when the
        # samples are stored as integers.
        ecg = np.array(matdata['val'], dtype=np.float64)
        for lead in range(n_leads):
            centered = ecg[:, lead] - np.mean(ecg[:, lead])   # signal centred on 0
            ecg[:, lead] = notch_filter(apply_bandpass_filter(centered))
        out[index, :, :] = ecg
    return out


# Same processing for both batches
load_filtered_batch(ECG_folder, ECGs_1, signals_1)
load_filtered_batch(ECG_folder_2batch, ECGs_2, signals_2)


In [None]:
# --- Concatenate signals and tabular data ---
# Batch 1 comes first and batch 2 second, in both the signal array and the table.
signals = np.concatenate([signals_1, signals_2], axis=0)
# Each table is sorted by ECG_patient_id before stacking.
# NOTE(review): this presumably keeps row i of tabular_data paired with signals[i]
# (the recordings were sorted by the id in their file name) - confirm that every
# patient has exactly one table row and one .mat file.
tabular_data = pd.concat([
    tabular_data.sort_values(by="ECG_patient_id").reset_index(drop=True),
    tabular_data_2batch.sort_values(by="ECG_patient_id").reset_index(drop=True)
], ignore_index=True)

print("Combined signals shape:", signals.shape)
print("Combined tabular shape:", tabular_data.shape)

In [None]:
# Number and percentage of rows whose sport_ability equals 1
print(f"nb pos: {np.sum(tabular_data['sport_ability']==1)}")
print(f"% pos: {np.sum(tabular_data['sport_ability']==1)/len(tabular_data['sport_ability'])*100:.2f}%")

The dataset is imbalanced: roughly 70% of the samples belong to class 1.

In [None]:
# Show the first five rows of the combined table
tabular_data.head(5)

In [None]:
# Column names, dtypes and non-null counts of the combined table
tabular_data.info()

No encoding is needed: every column is already numerical.

In [None]:
#tabular_data.isnull().sum()

In [None]:
# Drop every column that contains at least one missing value (rows are kept)
tabular_data = tabular_data.dropna(axis=1)
tabular_data.info()

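For now, the weight, height and training_load columns are dropped because they contain missing values.

In a future test we could drop the affected rows instead (at least for training_load, which has only one missing value).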



# Data Preprocessing

In [None]:

TARGET_COL = 'sport_ability'
ID_COL = 'ECG_patient_id'

# The boolean masks built from tabular_data are also applied to `signals`,
# so both containers must hold the same number of rows.
if len(signals) != len(tabular_data):
    raise ValueError(
        f"signals has {len(signals)} rows but tabular_data has {len(tabular_data)}"
    )

# One label per patient: the split is done on patients, not on rows
patient_labels = tabular_data.groupby(ID_COL)[TARGET_COL].first()
unique_patient_ids = patient_labels.index.values
unique_patient_targets = patient_labels.values

# 80% of the patients go to train
train_ids, temp_ids, train_targets, temp_targets = train_test_split(
    unique_patient_ids,
    unique_patient_targets,
    test_size=0.2,
    stratify=unique_patient_targets,
    random_state=SEED
)

# The remaining 20% is halved: 10% val, 10% test
val_ids, test_ids, val_targets, test_targets = train_test_split(
    temp_ids,
    temp_targets,
    test_size=0.5,
    stratify=temp_targets,
    random_state=SEED
)


def _select_split(patient_ids):
    """Return (row mask, signals, tabular features, labels) for ``patient_ids``."""
    mask = tabular_data[ID_COL].isin(patient_ids)
    rows = tabular_data[mask]
    # NOTE(review): only TARGET_COL is dropped, so ID_COL stays among the
    # tabular features - confirm the patient id is meant to be a model input.
    features = rows.drop(columns=[TARGET_COL])
    return mask, signals[mask.to_numpy()], features, rows[TARGET_COL]


# --- TRAINING / VALIDATION / TEST SETS ---
train_mask, X_sig_train, X_tab_train, y_train = _select_split(train_ids)
val_mask, X_sig_val, X_tab_val, y_val = _select_split(val_ids)
test_mask, X_sig_test, X_tab_test, y_test = _select_split(test_ids)


# --- CHECK ---
print(f"Patients: {len(unique_patient_ids)}")
print(f"Train IDs: {len(train_ids)}, Val IDs: {len(val_ids)}, Test IDs: {len(test_ids)}")
print("-" * 30)
print(f'Training set:   {X_sig_train.shape} {X_tab_train.shape} -> {y_train.shape}')
print(f'Validation set: {X_sig_val.shape}   {X_tab_val.shape}   -> {y_val.shape}')
print(f'Test set:       {X_sig_test.shape}  {X_tab_test.shape}  -> {y_test.shape}')

# No patient may appear in more than one split (all three pairs are checked)
split_patients = {
    "train": set(tabular_data.loc[train_mask, ID_COL]),
    "val": set(tabular_data.loc[val_mask, ID_COL]),
    "test": set(tabular_data.loc[test_mask, ID_COL]),
}
for first, second in (("train", "val"), ("train", "test"), ("val", "test")):
    shared = split_patients[first] & split_patients[second]
    assert not shared, f"Patients shared by {first} and {second}: {sorted(shared)[:5]}"
print("Integrity OK --> no patient split")

In [None]:
# Bar plot of how many training rows fall in each sport_ability class
plt.figure(figsize=(16, 3))

sns.countplot(x=y_train, palette='tab10')

plt.title("Train set sample distribution")
plt.xlabel("Sport Ability (0 -> NotAble // 1 --> Able)")
plt.show()

## Normalization


In [None]:
COLS_TO_SCALE = ['age_at_exam']  # only this tabular column is scaled; add weight/height here if they are kept

# Normalize tabular data: the scaler is fitted on the training rows only and
# then applied to all three splits.

scaler_tab = MinMaxScaler()
scaler_tab.fit(X_tab_train[COLS_TO_SCALE])

# Assign whole columns (df[cols] = ...) instead of df.loc[:, cols] = ...:
# the .loc form writes into the existing column and keeps its dtype, which is
# an incompatible-dtype assignment when the column holds integers.
for split_frame in (X_tab_train, X_tab_val, X_tab_test):
    split_frame[COLS_TO_SCALE] = scaler_tab.transform(split_frame[COLS_TO_SCALE])

In [None]:
# func to normalize signals z-score
def normalize_instance_wise(signals):
    """Z-score ``signals`` along axis 1.

    The mean and standard deviation are computed over axis 1 with
    ``keepdims=True``. A small constant is added to the standard deviation
    so that constant signals do not divide by zero.

    Returns:
        An array with the same shape as ``signals``.
    """
    eps = 1e-8
    centre = np.mean(signals, axis=1, keepdims=True)
    spread = np.std(signals, axis=1, keepdims=True)
    scaled = (signals - centre) / (spread + eps)
    return scaled

In [None]:
# Apply the same z-score normalization to the train, validation and test signals
X_sig_train_norm, X_sig_val_norm, X_sig_test_norm = (
    normalize_instance_wise(split_signals)
    for split_signals in (X_sig_train, X_sig_val, X_sig_test)
)


### Data Augmentation

In [None]:
# Length, in samples, of each segment cut from a recording
WINDOW_SIZE = 2500
# NOTE(review): STRIDE is not referenced by the visible segmentation calls,
# which pass stride=1250 for train and stride=WINDOW_SIZE for val/test - confirm intent.
STRIDE = 250        

BATCH_SIZE = 32
# Multiplier of the Gaussian noise added to training segments
JITTER_STRENGTH = 0.1

# Probability that one lead of a training segment is set to zero
CHANNEL_MASK = 0.4

In [None]:
def build_sequences_numpy(signals, tabular, labels, window=2500, stride=1250):
    """Cut every recording into fixed-length windows.

    Each window is paired with a copy of the tabular row and label of the
    recording it comes from. Trailing samples that do not fill a whole
    window are discarded.

    Args:
        signals: array of shape (recordings, samples, channels).
        tabular: per-recording features, either a pandas DataFrame or an
            array-like indexable by position.
        labels: per-recording labels, either a pandas Series or an
            array-like indexable by position.
        window: number of samples in each window.
        stride: step, in samples, between the starts of two windows.

    Returns:
        Tuple ``(X_sig_seq, X_tab_seq, y_seq)`` of numpy arrays with one entry
        per window.

    Raises:
        ValueError: if ``window`` or ``stride`` is not positive (a
            non-positive stride would never advance through the signal).
    """
    if window <= 0 or stride <= 0:
        raise ValueError("window and stride must be positive integers")

    signal_len = signals.shape[1]
    # Every start index whose window still fits inside the recording
    window_starts = range(0, signal_len - window + 1, stride)

    tabular_is_pandas = hasattr(tabular, 'iloc')
    labels_is_pandas = hasattr(labels, 'iloc')

    X_sig_seq = []
    X_tab_seq = []
    y_seq = []

    for i, curr_sig in enumerate(signals):
        curr_tab = tabular.iloc[i].values if tabular_is_pandas else np.asarray(tabular[i])
        curr_label = labels.iloc[i] if labels_is_pandas else labels[i]

        for start in window_starts:
            X_sig_seq.append(curr_sig[start:start + window, :])
            X_tab_seq.append(curr_tab)
            y_seq.append(curr_label)

    return np.array(X_sig_seq), np.array(X_tab_seq), np.array(y_seq)

In [None]:

# Training: windows overlap because the stride (1250) is half of WINDOW_SIZE (2500)
# NOTE(review): the stride is hard-coded here rather than taken from STRIDE - confirm intent.
X_sig_train_seq, X_tab_train_seq, y_train_seq = build_sequences_numpy(
    X_sig_train_norm, X_tab_train, y_train, window=WINDOW_SIZE, stride=1250
)

# Validation and test: stride equals the window, so windows do not overlap
# (no augmentation of these splits)
X_sig_val_seq, X_tab_val_seq, y_val_seq = build_sequences_numpy(
    X_sig_val_norm, X_tab_val, y_val, window=WINDOW_SIZE, stride=WINDOW_SIZE
)

X_sig_test_seq, X_tab_test_seq, y_test_seq = build_sequences_numpy(
    X_sig_test_norm, X_tab_test, y_test, window=WINDOW_SIZE, stride=WINDOW_SIZE
)

# --- check ---
print(f"Original Train Patients: {X_sig_train_norm.shape[0]}")
print(f"Augmented Train Segments: {X_sig_train_seq.shape}")
print(f"Augmented Tabular Shape: {X_tab_train_seq.shape}")
print(f"Augmented Labels Shape: {y_train_seq.shape}")

In [None]:
class ECGThreeBranchDataset(Dataset):
    """Dataset that serves each ECG segment as two 6-channel groups plus tabular data.

    Segments are stored channel-first. Each item is split into the first six
    channels (the limb leads I, II, III, aVR, aVL, aVF in the expected lead
    order) and the remaining channels (the precordial leads V1-V6).

    When ``is_train`` is true, every item is augmented on access with a random
    circular time shift, optional Gaussian noise and optional zeroing of one
    channel.
    """

    def __init__(self, signals, tabular, labels=None, is_train=False, JITTER_STRENGTH=0.05, CHANNEL_MASK=0.3):
        """
        signals: (N, time, 12), stored internally as (N, 12, time)
        tabular: (N, features)
        labels: optional (N,) integer labels; when omitted, items carry no label
        JITTER_STRENGTH: multiplier of the Gaussian noise (0 disables it)
        CHANNEL_MASK: probability of zeroing one channel (0 disables it)
        """
        # move channels in front of time: (N, time, ch) -> (N, ch, time)
        self.signals = torch.tensor(signals, dtype=torch.float32).permute(0, 2, 1)
        self.tabular = torch.tensor(tabular, dtype=torch.float32)
        if labels is None:
            self.labels = None
        else:
            self.labels = torch.tensor(labels, dtype=torch.long)

        self.is_train = is_train
        self.jitter_strength = JITTER_STRENGTH
        self.channel_mask_prob = CHANNEL_MASK

    def __len__(self):
        return len(self.signals)

    def _augment(self, sig):
        """Apply the training-time augmentations to one (channels, time) tensor."""
        # time shifting: circular roll by a random offset in [-250, 250)
        offset = torch.randint(low=-250, high=250, size=(1,)).item()
        sig = torch.roll(sig, shifts=offset, dims=1)

        # jitter: additive Gaussian noise, skipped when the strength is 0
        if self.jitter_strength > 0:
            sig = sig + torch.randn_like(sig) * self.jitter_strength

        # channel masking: zero one random lead so the model cannot rely on a single one
        if self.channel_mask_prob > 0:
            if torch.rand(1) < self.channel_mask_prob:
                dropped = torch.randint(0, 12, (1,)).item()
                sig[dropped, :] = 0
        return sig

    def __getitem__(self, idx):
        full_sig = self.signals[idx].clone()
        if self.is_train:
            full_sig = self._augment(full_sig)

        # the model has three branches: limb leads, precordial leads, tabular data
        sample = (full_sig[:6, :], full_sig[6:, :], self.tabular[idx])
        if self.labels is None:
            return sample
        return sample + (self.labels[idx],)

In [None]:
def make_loader(ds, batch_size, shuffle, drop_last=False):
    """Wrap ``ds`` in a ``DataLoader`` with multi-worker loading and pinned memory.

    The worker count is the number of CPU cores clamped to the range [2, 4]
    (2 is assumed when the core count is unknown). Pinned memory targets
    CUDA when it is available.
    """
    available_cores = os.cpu_count() or 2
    worker_count = max(2, min(4, available_cores))
    pin_target = "cuda" if torch.cuda.is_available() else ""

    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "drop_last": drop_last,
        "num_workers": worker_count,
        "pin_memory": True,
        "pin_memory_device": pin_target,
        "prefetch_factor": 2,
    }
    return DataLoader(ds, **loader_options)

In [None]:
# Create the train dataset: augmentation is enabled and uses the notebook-level settings
train_ds = ECGThreeBranchDataset(
    X_sig_train_seq, 
    X_tab_train_seq, 
    y_train_seq, 
    is_train=True,
    JITTER_STRENGTH = JITTER_STRENGTH,   
    CHANNEL_MASK = CHANNEL_MASK)
# Create the val and test datasets: is_train=False, so no augmentation is applied
val_ds   = ECGThreeBranchDataset(X_sig_val_seq, X_tab_val_seq, y_val_seq, is_train=False)
test_ds  = ECGThreeBranchDataset(X_sig_test_seq, X_tab_test_seq, y_test_seq, is_train=False)

# Create the loaders: only the training loader shuffles and drops the last incomplete batch
train_loader = make_loader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader   = make_loader(val_ds,   batch_size=BATCH_SIZE, shuffle=False)
test_loader  = make_loader(test_ds,  batch_size=BATCH_SIZE, shuffle=False)

print("Dataloaders created!")

In [None]:
# --- SANITY CHECK ---
# Pull a single batch from the training loader and verify the tensor shapes

for limb_batch, prec_batch, tab_batch, label_batch in train_loader:

    print("--- Batch Shapes ---")
    print(f"Limb Leads Batch:        {limb_batch.shape}  -> (Batch, 6, Time)")
    print(f"Precordial Leads Batch:  {prec_batch.shape}  -> (Batch, 6, Time)")
    print(f"Tabular Batch:           {tab_batch.shape}   -> (Batch, Features)")
    print(f"Labels Batch:            {label_batch.shape} -> (Batch)")

    # Both signal groups must have 6 channels and the same number of time steps
    assert limb_batch.shape[1] == 6, "Error: Limb Leads should be 6"
    assert prec_batch.shape[1] == 6, "Error: Precordial Leads should be 6"
    assert limb_batch.shape[2] == prec_batch.shape[2], "Error time"

    # one batch is enough for the check
    break

# --- Network Configuration Parameters ---
# All sizes are read from the training arrays
n_timesteps = X_sig_train_seq.shape[1] # 2500
total_channels = X_sig_train_seq.shape[2] # 12
n_tab_feats = X_tab_train_seq.shape[1]
n_classes   = len(np.unique(y_train_seq))

print("\n\n--- Network Configuration ---")
print(f"Total Input Channels: {total_channels} (Split into 6 Limb + 6 Precordial)")
print(f"Input Tabular Features: {n_tab_feats}")
print(f"Number of Classes:      {n_classes}")

In [None]:
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate for (batch, channels, length) feature maps.

    Every channel is averaged over its length, the averages pass through a
    two-layer bottleneck ending in a sigmoid, and the resulting per-channel
    weights rescale the input.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        bottleneck = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _ = x.shape
        # squeeze: one average per channel -> (batch, channels)
        pooled = self.avg_pool(x).reshape(batch, channels)
        # excite: per-channel weights broadcast over the length axis
        gate = self.fc(pooled).unsqueeze(-1)
        return x * gate

In [None]:

'''
convolution source:
      Cardiologist-level arrhythmia detection and classification in ambulatory electrocardiograms using a deep neural network
      Hannun, A. Y., Rajpurkar, P., Ng, A. Y., et al.
      Nature Medicine (2019)
      Link: Nature Medicine Article | https://arxiv.org/abs/1707.01836
'''

class MicroResNetBlock(nn.Module):
    """Residual block of two 1-D convolutions followed by an SE gate.

    Each convolution is followed by batch normalization. ``stride`` is applied
    by the first convolution only. When the stride is larger than 1 or the
    channel count changes, the shortcut is a 1x1 convolution with batch
    normalization; otherwise it is the identity.
    """

    def __init__(self, in_channels, out_channels, kernel_size=7, stride=1):
        super().__init__()
        half_kernel = kernel_size // 2

        self.conv1 = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=half_kernel, bias=False,
        )
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(
            out_channels, out_channels, kernel_size,
            stride=1, padding=half_kernel, bias=False,
        )
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.se = SEBlock(out_channels, reduction=4)

        needs_projection = stride > 1 or in_channels != out_channels
        if needs_projection:
            # match the main path in both channel count and length
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        skip = self.shortcut(x)
        features = self.relu(self.bn1(self.conv1(x)))
        features = self.se(self.bn2(self.conv2(features)))
        features += skip
        return self.relu(features)

# Model Building

In [None]:
class ThreeBranchSimpleCNN(nn.Module):
    """Classifier that fuses two 6-channel signal branches with a tabular branch.

    The limb and precordial inputs go through separate but identically built
    stacks of residual blocks ending in global average pooling. The tabular
    input goes through one linear layer. The three feature vectors are
    concatenated and classified by a two-layer head.
    """

    def __init__(self, n_tabular_features, n_classes, dropout=0.5, base_channels=8):
        super().__init__()

        # channel widths of the three residual stages
        widths = (base_channels, base_channels * 2, base_channels * 4)
        _, mid_width, top_width = widths

        self.branch_limb = self._signal_branch(widths)
        self.branch_prec = self._signal_branch(widths)

        self.branch_tab = nn.Sequential(
            nn.Linear(n_tabular_features, mid_width),
            nn.BatchNorm1d(mid_width),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # two signal vectors of top_width plus the tabular vector of mid_width
        fused_width = top_width + top_width + mid_width
        hidden_width = top_width * 2

        self.classifier = nn.Sequential(
            nn.Linear(fused_width, hidden_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, n_classes),
        )

    @staticmethod
    def _signal_branch(widths):
        """Build the residual stack used for one 6-channel signal group."""
        first, second, third = widths
        return nn.Sequential(
            MicroResNetBlock(in_channels=6, out_channels=first, kernel_size=15, stride=1),
            nn.MaxPool1d(2),
            MicroResNetBlock(in_channels=first, out_channels=second, kernel_size=7, stride=2),
            nn.MaxPool1d(2),
            MicroResNetBlock(in_channels=second, out_channels=third, kernel_size=5, stride=2),
            nn.AdaptiveAvgPool1d(1),
        )

    def forward(self, limb, prec, tab):
        branch_features = [
            self.branch_limb(limb).squeeze(-1),
            self.branch_prec(prec).squeeze(-1),
            self.branch_tab(tab),
        ]
        fused = torch.cat(branch_features, dim=1)
        return self.classifier(fused)

In [None]:
# --- model instance ---

# Sizes are read from the windowed training data
N_TAB_FEATURES = X_tab_train_seq.shape[1]
N_CLASSES = len(np.unique(y_train_seq))

# dropout and base_channels keep the constructor defaults (0.5 and 8)
model = ThreeBranchSimpleCNN(
    n_tabular_features=N_TAB_FEATURES,
    n_classes=N_CLASSES
)

# Use the GPU when one is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device) 

print(model)  # show the full architecture of the current model
print(f"\nModel loaded on: {device}")

# HYPERPARAMETERS

In [None]:
LEARNING_RATE = 1e-4
EPOCHS = 500
PATIENCE = 100
# Directory where fit() writes the best-model checkpoint.
# NOTE(review): this is a Colab-style path while PATH was switched to Kaggle - confirm it exists.
model_dir = "/content/"
VERBOSE = 1

# --- Regularization ---
# NOTE(review): DROPOUT_RATE is not passed to the model constructor in the visible code.
DROPOUT_RATE = 0.5
L1_LAMBDA = 0
L2_LAMBDA = 5e-2   # used as the AdamW weight_decay
LABEL_SMOOTHING = 0.15

# Class weights: 'balanced' weights computed on the windowed training labels
class_weights_array = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train_seq),
    y=y_train_seq
)

class_weights = torch.tensor(class_weights_array, dtype=torch.float32).to(device)

print(f"Wheights: {class_weights_array} for classes: {np.unique(y_train_seq)}")

# Loss function: class-weighted cross-entropy with label smoothing
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=LABEL_SMOOTHING)

## Functions & Main Loop

In [None]:
# train function
def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, device, l1_lambda=0, l2_lambda=0):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: network called as ``model(limb, prec, tab)``.
        train_loader: yields ``(limb, prec, tab, targets)`` batches.
        criterion: loss applied to the logits and targets.
        optimizer: optimizer stepped once per batch through ``scaler``.
        scheduler: optional LR scheduler, stepped once per batch.
        scaler: gradient scaler used for mixed-precision training.
        device: ``torch.device`` the batches are moved to; autocast is
            enabled only on CUDA.
        l1_lambda: weight of the optional L1 penalty (0 disables it).
        l2_lambda: unused here; L2 regularization is handled by the
            optimizer's weight decay.

    Returns:
        Tuple ``(epoch_loss, epoch_f1)``: the mean per-sample loss and the
        weighted F1 score over the samples seen during the epoch.

    Raises:
        ValueError: if the loader yields no batch.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0
    all_preds = []
    all_targets = []

    use_amp = device.type == 'cuda'
    device_type = 'cuda' if use_amp else 'cpu'

    for limb_in, prec_in, tab_in, targets in train_loader:
        limb_in = limb_in.to(device)
        prec_in = prec_in.to(device)
        tab_in = tab_in.to(device)
        targets = targets.to(device).long()

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type=device_type, enabled=use_amp):
            logits = model(limb_in, prec_in, tab_in)
            loss = criterion(logits, targets)

            # L1 is not needed right now, but the logic is kept so it can be switched on
            if l1_lambda > 0:
                l1_norm = sum(p.abs().sum() for p in model.parameters() if p.requires_grad)
                loss = loss + l1_lambda * l1_norm

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if scheduler is not None:
            scheduler.step()

        batch_samples = targets.size(0)
        running_loss += loss.item() * batch_samples
        samples_seen += batch_samples
        all_preds.append(logits.argmax(dim=1).cpu().numpy())
        all_targets.append(targets.cpu().numpy())

    if samples_seen == 0:
        raise ValueError("train_loader produced no batches")

    # Divide by the samples actually processed: with drop_last=True the loader
    # skips the last incomplete batch, so len(dataset) would bias the mean low.
    epoch_loss = running_loss / samples_seen
    epoch_f1 = f1_score(np.concatenate(all_targets), np.concatenate(all_preds), average='weighted')
    return epoch_loss, epoch_f1

In [None]:
# validate function
def validate_one_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Returns:
        Tuple ``(epoch_loss, epoch_f1)``: the summed per-sample loss divided by
        the dataset size, and the weighted F1 score of the predictions.
    """
    model.eval()

    on_cuda = device.type == 'cuda'
    amp_device = 'cuda' if on_cuda else 'cpu'

    loss_sum = 0.0
    pred_chunks, target_chunks = [], []

    with torch.no_grad():
        for batch in val_loader:
            limb_in, prec_in, tab_in, targets = (part.to(device) for part in batch)
            targets = targets.long()

            with torch.amp.autocast(device_type=amp_device, enabled=on_cuda):
                logits = model(limb_in, prec_in, tab_in)
                batch_loss = criterion(logits, targets)

            loss_sum += batch_loss.item() * targets.size(0)
            pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
            target_chunks.append(targets.cpu().numpy())

    y_true = np.concatenate(target_chunks)
    y_pred = np.concatenate(pred_chunks)
    mean_loss = loss_sum / len(val_loader.dataset)
    return mean_loss, f1_score(y_true, y_pred, average='weighted')

In [None]:

def fit(model, train_loader, val_loader, epochs, criterion, optimizer, scheduler, scaler, device,
        l1_lambda=0, l2_lambda=0, patience=0, evaluation_metric="val_f1", mode='max',
        restore_best_weights=True, verbose=1, experiment_name="best_model"):
    """Train ``model`` with optional early stopping and best-weight restore.

    Args:
        model, criterion, optimizer, scheduler, scaler, device: forwarded to
            ``train_one_epoch`` and ``validate_one_epoch``.
        train_loader, val_loader: loaders for the two phases.
        epochs: maximum number of epochs.
        l1_lambda, l2_lambda: forwarded to ``train_one_epoch``.
        patience: epochs without improvement before stopping; 0 disables
            early stopping and checkpointing.
        evaluation_metric: ``"val_f1"`` monitors the validation F1; any other
            value monitors the validation loss.
        mode: ``'max'`` if a larger metric is better, otherwise smaller is better.
        restore_best_weights: reload the best checkpoint once training ends.
        verbose: print every ``verbose`` epochs (0 silences the epoch log).
        experiment_name: checkpoint file name, without extension, inside the
            notebook-level ``model_dir``.

    Returns:
        Tuple ``(model, training_history)`` where the history maps
        ``train_loss``, ``val_loss``, ``train_f1`` and ``val_f1`` to per-epoch lists.
    """
    training_history = {
        'train_loss': [], 'val_loss': [],
        'train_f1': [], 'val_f1': []
    }

    use_early_stopping = patience > 0
    patience_counter = 0
    best_metric = float('-inf') if mode == 'max' else float('inf')
    best_epoch = 0          # 0 means "no checkpoint written yet"
    checkpoint_path = None

    if use_early_stopping:
        # Build the path once and make sure the target directory exists, so
        # the first save cannot fail on a missing folder.
        checkpoint_path = os.path.join(model_dir, f"{experiment_name}.pt")
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)

    print(f"Start Training: {epochs} epochs on {device}... \n")
    start_time = time.time()

    for epoch in range(1, epochs + 1):

        train_loss, train_f1 = train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler, scaler, device, l1_lambda, l2_lambda
        )

        val_loss, val_f1 = validate_one_epoch(
            model, val_loader, criterion, device
        )

        training_history['train_loss'].append(train_loss)
        training_history['val_loss'].append(val_loss)
        training_history['train_f1'].append(train_f1)
        training_history['val_f1'].append(val_f1)

        # epoch log
        if verbose > 0 and (epoch % verbose == 0 or epoch == 1):
            current_lr = scheduler.get_last_lr()[0] if scheduler else optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch:3d}/{epochs} | "
                  f"Train Loss: {train_loss:.4f} F1: {train_f1:.4f} | "
                  f"Val Loss: {val_loss:.4f} F1: {val_f1:.4f} | "
                  f"LR: {current_lr:.6f}")

        # early stopping
        if not use_early_stopping:
            continue

        current_metric = val_f1 if evaluation_metric == "val_f1" else val_loss
        if mode == 'max':
            is_improvement = current_metric > best_metric
        else:
            is_improvement = current_metric < best_metric

        if is_improvement:
            best_metric = current_metric
            best_epoch = epoch
            torch.save(model.state_dict(), checkpoint_path)
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nEarly stopping triggered after {epoch} epochs.")
                break

    total_time = time.time() - start_time
    print(f"\nTraining completed in {total_time/60:.2f} minutes.")

    if restore_best_weights and use_early_stopping:
        if best_epoch > 0:
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
            print(f"Recovered best model: Epoch {best_epoch} with {evaluation_metric}: {best_metric:.4f}")
        else:
            # e.g. epochs == 0 or the metric was NaN at every epoch
            print("No checkpoint was saved (the monitored metric never improved); keeping the final weights.")

    return model, training_history

# Model Training

In [None]:
# Placeholders for the best model found so far and its score
best_model = None
best_performance = float('-inf')

In [None]:
# --- Experiment name, optimizer, scheduler and AMP scaler ---
experiment_name = "ThreeBranchSimpleCNN  -  -  -  OPTUNA"

# Read the dimensions back from the prepared data
n_tabular_features = X_tab_train_seq.shape[1] # tabular features
n_classes = len(np.unique(y_train_seq))       # classes


# optimizer: AdamW, with L2_LAMBDA as weight decay
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=L2_LAMBDA)

# scheduler --> maybe also test ReduceLROnPlateau
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=2e-4,
    steps_per_epoch=len(train_loader),
    epochs=EPOCHS,
    pct_start=0.15,          # LR rises for the first 15% of the steps, then falls
    anneal_strategy='cos'
)

# Gradient scaler for mixed precision: enabled only on CUDA.
# Falls back to the older torch.cuda.amp API when torch.amp.GradScaler is missing.
try:
    scaler = torch.amp.GradScaler(device='cuda', enabled=(device.type == 'cuda'))
except AttributeError:
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))

In [None]:


def objective(trial):
    """Optuna objective: train one sampled configuration and score it.

    A fresh dataset pair, model, optimizer and one-cycle scheduler are built
    from the sampled hyper-parameters, trained for a fixed number of epochs,
    and the validation F1 is reported to Optuna after every epoch so that the
    pruner can stop weak trials.

    Args:
        trial: the ``optuna`` trial used to sample hyper-parameters.

    Returns:
        The best validation F1 reached during the trial.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner stops the trial.
    """
    # ==========================================
    # 1. SEARCH SPACE
    # ==========================================
    # Optimizer & Scheduler
    lr = trial.suggest_float("lr", 1e-5, 5e-2, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-4, 5e-2, log=True)
    pct_start = trial.suggest_float("pct_start", 0.1, 0.5)

    # Architecture
    dropout = trial.suggest_float("dropout", 0.2, 0.6)
    base_channels = trial.suggest_categorical("base_channels", [8, 16, 32])

    # Training
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64, 128])

    # Augmentation
    jitter = trial.suggest_float("jitter", 0.0, 0.2)
    mask_prob = trial.suggest_float("mask_prob", 0.0, 0.5)

    # Loss
    label_smoothing = trial.suggest_float("label_smoothing", 0.0, 0.2)

    # ==========================================
    # 2. DATASET & LOADER
    # ==========================================
    train_ds = ECGThreeBranchDataset(
        X_sig_train_seq, X_tab_train_seq, y_train_seq,
        is_train=True,
        JITTER_STRENGTH=jitter,
        CHANNEL_MASK=mask_prob
    )
    val_ds = ECGThreeBranchDataset(X_sig_val_seq, X_tab_val_seq, y_val_seq, is_train=False)

    train_loader_opt = make_loader(train_ds, batch_size, shuffle=True, drop_last=True)
    val_loader_opt = make_loader(val_ds, batch_size, shuffle=False)

    # ==========================================
    # 3. MODEL
    # ==========================================
    # N_CLASSES is derived from the training labels, like the main model
    model = ThreeBranchSimpleCNN(
        n_tabular_features=N_TAB_FEATURES,
        n_classes=N_CLASSES,
        dropout=dropout,
        base_channels=base_channels
    ).to(device)

    # ==========================================
    # 4. OPTIMIZER & SCHEDULER & LOSS
    # ==========================================
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    EPOCHS_OPT = 20
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        steps_per_epoch=len(train_loader_opt),
        epochs=EPOCHS_OPT,
        pct_start=pct_start
    )

    # Same class weighting as the main training criterion, so the search
    # optimizes the loss that the final model is trained with.
    criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=label_smoothing)
    # Gradient scaling is only meaningful on CUDA
    scaler_amp = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))

    # ==========================================
    # 5. TRAINING LOOP
    # ==========================================
    best_f1 = 0.0

    try:
        for epoch in range(1, EPOCHS_OPT + 1):

            train_one_epoch(
                model, train_loader_opt, criterion, optimizer, scheduler, scaler_amp, device
            )

            _, val_f1 = validate_one_epoch(
                model, val_loader_opt, criterion, device
            )

            # Optuna pruning
            trial.report(val_f1, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

            best_f1 = max(best_f1, val_f1)
    finally:
        # Release this trial's objects (also when it is pruned) so that memory
        # does not pile up across the trials of the study.
        del model, optimizer, scheduler, train_loader_opt, val_loader_opt
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return best_f1



In [None]:
# ==========================================
# STUDY
# ==========================================
# direction="maximize"  --> the objective returns an F1 score, which is maximized
# MedianPruner is the pruner consulted by trial.should_prune() inside the objective
study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())

study.optimize(objective, n_trials=250) 

# Summary of the best trial
print("-" * 50)
print("BEST PARAMETERS:")
print(study.best_params)
print(f"BEST F1: {study.best_value}")

In [None]:
import plotly.io as pio
pio.renderers.default = "iframe"

from optuna.visualization import plot_optimization_history, plot_param_importances, plot_slice


# 1. Which hyper-parameters matter most?
print("Importanza degli Iperparametri:")
try:
    fig1 = plot_param_importances(study)
    fig1.show()
except Exception as exc:
    # Best effort: the importance plot may be unavailable (e.g. too few
    # completed trials). Report the reason instead of hiding it.
    print("Impossibile mostrare importanza (forse pochi trial completati?)")
    print(f"  reason: {exc!r}")

# 2. How did the model improve over time? (line chart)
# The points should rise as Optuna learns
print("Storia dell'Ottimizzazione:")
fig2 = plot_optimization_history(study)
fig2.show()

# 3. Detail for every parameter (scatter plot)
# Shows, for example, that "all the good results have a low LR"
print("Dettaglio distribuzione parametri:")
fig3 = plot_slice(study)
fig3.show()

# 4. Text summary of the best trial
print("\n" + "="*40)
print(f"MIGLIOR TRIAL (#{study.best_trial.number})")
print("="*40)
print(f"Value (F1): {study.best_value:.4f}")
print("Params: ")
for key, value in study.best_params.items():
    print(f"  {key}: {value}")

----------------------------------------------------------------------------------------------------------------------------------------------------------------