In [5]:
from itertools import combinations

def leave_one_out_splits(patients, val_count=2):
    """Build one train/val/test split per patient (leave-one-patient-out).

    Each patient in turn becomes the sole test patient. From the remaining
    patients (kept in input order) the first ``val_count``-sized combination
    produced by ``itertools.combinations`` is used for validation, and every
    remaining patient that is not in that validation group is used for
    training.

    Args:
        patients: Sequence of patient identifiers.
        val_count: Number of patients held out for validation in each split.

    Returns:
        A list of dicts with the keys ``"train"``, ``"val"`` and ``"test"``,
        each mapping to a list of patient identifiers. A test patient for
        which fewer than ``val_count`` other patients remain produces no
        split at all.

    Raises:
        ValueError: If ``val_count`` is negative (raised by
            ``itertools.combinations``).
    """
    splits = []
    for i, test_patient in enumerate(patients):
        remaining = [p for j, p in enumerate(patients) if j != i]
        # Only the first validation combination is ever used, so fetch it
        # lazily instead of materializing every combination (which grows
        # combinatorially with the number of patients and val_count).
        val_patients = next(combinations(remaining, val_count), None)
        if val_patients is None:
            # Not enough remaining patients to form a validation group.
            continue
        train_patients = [p for p in remaining if p not in val_patients]
        splits.append({
            "train": train_patients,
            "val": list(val_patients),
            "test": [test_patient]
        })
    return splits

In [7]:
import torch
import pickle
from scipy.signal import resample
import os
import numpy as np

class CHBMITLoader(torch.utils.data.Dataset):
    """Dataset that reads pickled segments and returns normalized tensors.

    Every file is a pickle holding a mapping with an ``"X"`` entry (the
    signal, resampled and normalized along its last axis) and a ``"y"``
    entry (returned unchanged as the label).
    """

    def __init__(self, root, files, sampling_rate=200):
        """Store the file list and the resampling configuration.

        Args:
            root: Directory that the entries of ``files`` are relative to.
            files: Sequence of pickle file paths, relative to ``root``.
            sampling_rate: Target sampling rate; when it differs from the
                default rate (256) each segment is resampled on load.
        """
        self.root = root
        self.files = files
        self.default_rate = 256
        self.sampling_rate = sampling_rate

    def __len__(self):
        """Return the number of segment files."""
        return len(self.files)

    def __getitem__(self, index):
        """Load, resample and normalize the segment at ``index``.

        Returns:
            A tuple ``(X, Y)`` where ``X`` is a ``torch.FloatTensor`` and
            ``Y`` is the ``"y"`` value stored in the pickle.
        """
        path = os.path.join(self.root, self.files[index])
        # Use a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector.
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only use this loader on trusted segment files.
        with open(path, "rb") as handle:
            sample = pickle.load(handle)
        X = sample["X"]
        # Resample the last axis to 10 * sampling_rate samples, e.g.
        # 2560 -> 2000 when going from 256 Hz to 200 Hz.
        # NOTE(review): the factor 10 presumably encodes 10-second segments;
        # confirm against the segment extraction code.
        if self.sampling_rate != self.default_rate:
            X = resample(X, 10 * self.sampling_rate, axis=-1)

        # Scale by the 95th percentile of the absolute values along the last
        # axis; the small epsilon avoids division by zero.
        X = X / (
            np.quantile(np.abs(X), q=0.95, method="linear", axis=-1, keepdims=True)
            + 1e-8
        )
        Y = sample["y"]
        X = torch.FloatTensor(X)
        return X, Y

In [8]:
import os

# Every entry directly under the segments root is treated as one patient.
all_patients = os.listdir("CHB-MIT/clean_segments")
all_patients.sort()

# NOTE(review): the output recorded for this cell lists 'test', 'train' and
# 'val' as the "patients" and every "train" list is empty, so the directory
# apparently holds split folders rather than one folder per patient. Confirm
# the on-disk layout before relying on these splits.
splits = leave_one_out_splits(patients=all_patients, val_count=2)

print(splits)

[{'train': [], 'val': ['train', 'val'], 'test': ['test']}, {'train': [], 'val': ['test', 'val'], 'test': ['train']}, {'train': [], 'val': ['test', 'train'], 'test': ['val']}]


In [9]:
import os
import torch

def make_loader(patients_list, split, config):
    """Build a DataLoader over the segment files of the given patients.

    For each patient, the files inside
    ``CHB-MIT/clean_segments/<patient>/<split>`` are collected; patients
    without such a directory are skipped.

    Args:
        patients_list: Patient directory names under the segments root.
        split: Sub-directory name to read for each patient. The value
            ``"train"`` additionally enables shuffling and dropping of the
            last incomplete batch.
        config: Mapping providing ``"sampling_rate"``, ``"batch_size"`` and
            ``"num_workers"``.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping a ``CHBMITLoader``.
    """
    root = "CHB-MIT/clean_segments"
    segment_files = []
    for patient in patients_list:
        path = os.path.join(root, patient, split)
        # isdir (rather than exists) so a stray regular file with this name
        # is skipped instead of making os.listdir fail.
        if os.path.isdir(path):
            # os.listdir returns entries in arbitrary order; sort them so the
            # unshuffled loaders are reproducible across runs and machines.
            segment_files.extend(
                os.path.join(patient, split, f) for f in sorted(os.listdir(path))
            )
    dataset = CHBMITLoader(root, segment_files, config["sampling_rate"])

    is_train = split == "train"
    num_workers = config["num_workers"]
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=config["batch_size"],
        shuffle=is_train,
        # Dropping the last incomplete batch is only appropriate for
        # training; for evaluation it would silently discard samples.
        drop_last=is_train,
        num_workers=num_workers,
        # DataLoader raises ValueError for persistent_workers=True when
        # num_workers is 0, so only enable it when workers are used.
        persistent_workers=num_workers > 0,
    )

In [None]:

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from model import LitModel_finetune  # Assuming LitModel_finetune is defined in model.py
def supervised(config, train_loader, val_loader, test_loader, iteration_idx):
    """Fine-tune a model on one split and evaluate it on the test loader.

    Logs and checkpoints are written below ``log-finetuning/run-<iteration_idx>``.
    Training stops early when ``val_auroc`` has not improved for 5 validation
    checks, and testing uses the best checkpoint as selected by the
    checkpoint callback.

    Args:
        config: Configuration passed to ``LitModel_finetune``; must provide
            ``"epochs"``.
        train_loader: DataLoader used for training.
        val_loader: DataLoader used for validation.
        test_loader: DataLoader used for the final test run.
        iteration_idx: Identifier of the split, used in directory names and
            in the printed summary.

    Returns:
        The first entry of the list returned by ``trainer.test`` (the metrics
        of the single test dataloader), so callers can aggregate results
        across splits.
    """
    lightning_model = LitModel_finetune(config)
    os.makedirs("log-finetuning", exist_ok=True)

    logger = TensorBoardLogger(save_dir="log-finetuning", name=f"run-{iteration_idx}")
    # NOTE(review): early stopping monitors "val_auroc" while the checkpoint
    # callback (and therefore ckpt_path="best") monitors "val_loss". Both
    # metrics are presumably logged by LitModel_finetune; confirm that this
    # mismatch is intended.
    early_stop_callback = EarlyStopping(monitor="val_auroc", patience=5, mode="max")
    checkpoint_callback = ModelCheckpoint(
        dirpath=f"log-finetuning/run-{iteration_idx}/checkpoints",
        filename="model-{epoch:02d}-{val_loss:.4f}",
        monitor="val_loss",
        mode="min",
        save_last=True,
        save_top_k=3,
    )

    trainer = pl.Trainer(
        accelerator="cpu",
        max_epochs=config["epochs"],
        callbacks=[checkpoint_callback, early_stop_callback],
        logger=logger,
    )

    trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    # trainer.test returns one metrics dict per test dataloader; only one
    # dataloader is passed, so take the first entry.
    result = trainer.test(lightning_model, dataloaders=test_loader, ckpt_path="best")[0]
    print(f"=== Split {iteration_idx} Result ===")
    print(result)
    # Return the metrics instead of discarding them after printing.
    return result

In [None]:
if __name__ == "__main__":
    # NOTE(review): load_config is not defined or imported in the cells
    # visible here; it must already be in scope when this cell runs.
    config = load_config("configs/finetuning.yml")

    # One leave-one-out split per entry found under the segments root.
    all_patients = os.listdir("CHB-MIT/clean_segments")
    all_patients.sort()
    splits = leave_one_out_splits(all_patients, val_count=2)
    total_splits = len(splits)

    for idx, split in enumerate(splits):
        run_number = idx + 1  # runs are reported and stored 1-based
        print(f"\n--- Running Split {run_number}/{total_splits} ---")
        # Build the loaders in train, val, test order; the split name doubles
        # as the key into the split dict and as the sub-directory to read.
        train_loader, val_loader, test_loader = (
            make_loader(split[part], part, config)
            for part in ("train", "val", "test")
        )
        supervised(config, train_loader, val_loader, test_loader, run_number)