In [None]:
# --- Imports ---
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import tables
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import (
    roc_auc_score, average_precision_score,
    f1_score, precision_score, recall_score
)

# --- Global configuration ---
# Use the GPU when PyTorch can see one; otherwise everything runs on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# HDF5 file opened by HDF5WindowDataset (window index, features and labels).
H5_PATH = "output/dataset.h5"
# Tab-separated file; the code below reads its 'split' and 'stay_id' columns.
SPLIT_CSV = "split_all.tsv"
# Directory that receives the saved state dict of every trained model.
MODEL_DIR = "modelos"
# Directory created below; no code visible in this section writes to it.
IMG_DIR = "imgs"
# Create both output directories up front so later writes cannot fail on a missing folder.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(IMG_DIR, exist_ok=True)

EMB = 231           # Transformer d_model and LSTM/GRU hidden size
FF_MULT = 2         # Transformer feed-forward width is EMB * FF_MULT
MAX_SEQ_LEN = 2016  # Windows longer than this many rows are truncated by the dataset
# Number of windows per mini-batch for the train, val and test loaders.
BATCH_SIZE = 16
# Upper bound on training epochs; early stopping normally ends training sooner.
EPOCHS = 1000
# Learning rate passed to the Adam optimizer.
LR = 3e-4
# Number of consecutive epochs without validation-loss improvement before stopping.
PATIENCE = 10
# Minimum decrease in validation loss that counts as an improvement.
MIN_DELTA = 1e-4
# Each model variant is trained once per seed.
SEEDS = [42, 43, 44]

# --- Split dictionary ---
# Map each split name to the set of stay_id values assigned to it in SPLIT_CSV.
# A split that has no rows in the file maps to an empty set.
df_split = pd.read_csv(SPLIT_CSV, sep='\t')
split_dict = {
    split: set(df_split.loc[df_split['split'] == split, 'stay_id'].values)
    for split in ('train', 'val', 'test')
}

In [None]:
# --- Custom Dataset class for windowed data stored in HDF5 ---
class HDF5WindowDataset(Dataset):
    """Map-style dataset that serves one window of an HDF5 file per item.

    The window index and the stay ids of a split are loaded into memory once,
    in the constructor. Feature and label rows are read from disk lazily in
    ``__getitem__``; the file is opened and closed on every access, so no file
    handle is stored on the instance.
    """

    def __init__(self, h5_path, split):
        """
        Initializes the dataset by loading window indices and corresponding stay_ids.

        Args:
            h5_path (str): Path to the HDF5 file
            split (str): One of ['train', 'val', 'test']
        """
        self.h5_path = h5_path
        self.split = split

        with tables.open_file(h5_path, mode='r') as f:
            # Each row of `windows` unpacks as (start, stop, <unused third value>).
            self.windows = f.root.patient_windows[split][:]
            self.stay_ids = f.root.patient_windows[f"{split}_stay_ids"][:]

        print(f"[INFO] Loaded {len(self.windows)} windows for split: {split}")

    def __len__(self):
        """Returns the number of windows available for this split."""
        return len(self.windows)

    def __getitem__(self, idx):
        """
        Returns a tuple (X, y) of time series data and binary label sequence.

        Both tensors are float32 and hold at most MAX_SEQ_LEN rows: the first
        MAX_SEQ_LEN rows of the window when it is longer than that.
        """
        start, stop, _ = self.windows[idx]
        # Plain Python ints keep the arithmetic below independent of the
        # integer dtype used to store the window index.
        start, stop = int(start), int(stop)
        # Clamp the slice *before* touching the file so rows beyond
        # MAX_SEQ_LEN are never read from disk only to be thrown away.
        stop = min(stop, start + MAX_SEQ_LEN)

        with tables.open_file(self.h5_path, mode='r') as f:
            x = f.root.data[self.split][start:stop]
            y = f.root.labels[self.split][start:stop]

        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

# --- Collate function for DataLoader (padding to equal length) ---
def collate_fn(batch):
    """Zero-pad a list of (x, y) pairs to a common length and stack them.

    Args:
        batch: Sequence of (x, y) tensor pairs. Both tensors are indexed as
            2-D (``shape[0]`` is the length, ``shape[1]`` the width).

    Returns:
        Tuple ``(padded_x, padded_y, pad_mask)``:
            padded_x: float tensor (batch, max_len, x_width), zeros after each sequence.
            padded_y: padded labels with a trailing size-1 axis squeezed away.
            pad_mask: bool tensor (batch, max_len), True at padded positions.
    """
    features, labels = zip(*batch)
    n_items = len(features)
    lengths = torch.tensor([seq.shape[0] for seq in features])
    longest = max(seq.shape[0] for seq in features)

    x_out = torch.zeros(n_items, longest, features[0].shape[1])
    y_out = torch.zeros(len(labels), longest, labels[0].shape[1])
    for row, (seq, lab) in enumerate(zip(features, labels)):
        x_out[row, :seq.shape[0]] = seq
        y_out[row, :lab.shape[0]] = lab

    # A position is padding exactly when its index is not below the sequence length.
    is_pad = torch.arange(longest).unsqueeze(0) >= lengths.unsqueeze(1)
    return x_out, y_out.squeeze(-1), is_pad


In [None]:
# --- Simple Transformer Model ---
class SimpleTransformer(nn.Module):
    """One-layer, single-head Transformer encoder with a per-timestep sigmoid head.

    Inputs are linearly projected to EMB channels, passed through a single
    batch-first encoder layer, and mapped to one probability per timestep.
    No positional encoding is added anywhere in this module.
    """

    def __init__(self, input_dim):
        """
        Args:
            input_dim (int): Number of features per timestep.
        """
        super().__init__()
        self.input_proj = nn.Linear(input_dim, EMB)
        layer = nn.TransformerEncoderLayer(
            d_model=EMB,
            nhead=1,
            dim_feedforward=EMB * FF_MULT,
            dropout=0.1,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=1)
        # Built in the same order as listed so parameter names
        # (classifier.0, classifier.3) match existing checkpoints.
        head_layers = [
            nn.Linear(EMB, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x, mask):
        """
        Args:
            x: Batch-first input of shape (batch, seq, input_dim).
            mask: Bool key-padding mask of shape (batch, seq); True marks
                positions the attention should ignore.

        Returns:
            Tensor (batch, seq) of probabilities in [0, 1].
        """
        projected = self.input_proj(x)
        encoded = self.transformer(projected, src_key_padding_mask=mask)
        probs = self.classifier(encoded)
        return probs.squeeze(-1)

# --- Simple LSTM Model ---
class SimpleLSTM(nn.Module):
    """Batch-first LSTM with EMB hidden units and a per-timestep sigmoid head."""

    def __init__(self, input_dim):
        """
        Args:
            input_dim (int): Number of features per timestep.
        """
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=EMB, batch_first=True)
        # Same layer order as before so classifier.0 / classifier.3 keep their names.
        head_layers = [
            nn.Linear(EMB, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x, mask):
        """
        Args:
            x: Batch-first input of shape (batch, seq, input_dim).
            mask: Accepted so all models share one call signature; not used here.

        Returns:
            Tensor (batch, seq) of probabilities in [0, 1].
        """
        hidden_states = self.lstm(x)[0]
        probs = self.classifier(hidden_states)
        return probs.squeeze(-1)

# --- Simple GRU Model ---
class SimpleGRU(nn.Module):
    """Batch-first GRU with EMB hidden units and a per-timestep sigmoid head."""

    def __init__(self, input_dim):
        """
        Args:
            input_dim (int): Number of features per timestep.
        """
        super().__init__()
        self.gru = nn.GRU(input_size=input_dim, hidden_size=EMB, batch_first=True)
        # Same layer order as before so classifier.0 / classifier.3 keep their names.
        head_layers = [
            nn.Linear(EMB, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x, mask):
        """
        Args:
            x: Batch-first input of shape (batch, seq, input_dim).
            mask: Accepted so all models share one call signature; not used here.

        Returns:
            Tensor (batch, seq) of probabilities in [0, 1].
        """
        hidden_states = self.gru(x)[0]
        probs = self.classifier(hidden_states)
        return probs.squeeze(-1)


In [None]:
# --- Evaluation metrics for binary sequence classification ---
def evaluate(model, loader, criterion):
    """
    Evaluates the model on every unpadded timestep yielded by a DataLoader.

    Args:
        model: PyTorch model called as ``model(x, mask)`` and returning one
            probability per timestep.
        loader: DataLoader yielding ``(x, y, mask)`` batches, where ``mask``
            is True at padded positions.
        criterion: Element-wise loss (e.g. ``nn.BCELoss(reduction='none')``).
            It must return one value per timestep, because padded positions
            are zeroed out before summing.

    Returns:
        Dictionary with keys 'auc', 'auprc', 'f1', 'precision', 'recall' and
        'loss' (mean loss per unpadded timestep).

    Raises:
        ValueError: If the loader yields no unpadded timesteps.
    """
    model.eval()
    pred_chunks, target_chunks = [], []
    total_loss, total_samples = 0.0, 0

    with torch.no_grad():
        for x, y, mask in loader:
            x, y, mask = x.to(DEVICE), y.to(DEVICE), mask.to(DEVICE)
            valid = ~mask
            preds = model(x, mask)
            # Padded positions must not contribute to the summed loss.
            loss = criterion(preds, y).masked_fill(mask, 0.0)
            total_loss += loss.sum().item()
            total_samples += int(valid.sum().item())
            # Keep one array per batch; extending a Python list element by
            # element is far slower and produces the same values.
            pred_chunks.append(preds[valid].cpu().numpy())
            target_chunks.append(y[valid].cpu().numpy())

    if total_samples == 0:
        raise ValueError("evaluate() received no unpadded timesteps to score")

    all_preds = np.concatenate(pred_chunks)
    all_targets = np.concatenate(target_chunks)
    # Threshold once and reuse it for the three label-based metrics.
    hard_preds = np.round(all_preds)

    metrics = {
        'auc': roc_auc_score(all_targets, all_preds),
        'auprc': average_precision_score(all_targets, all_preds),
        'f1': f1_score(all_targets, hard_preds),
        'precision': precision_score(all_targets, hard_preds),
        'recall': recall_score(all_targets, hard_preds),
        'loss': total_loss / total_samples
    }
    return metrics


In [None]:
# --- Training function with early stopping and model checkpointing ---
def train_and_evaluate(model_class, name, input_dim, seed):
    """
    Trains a model with early stopping, restores the best weights, and
    evaluates them on the test split.

    Training runs for at most EPOCHS epochs. After every epoch the model is
    scored on the validation split; when the validation loss has not dropped
    by more than MIN_DELTA for PATIENCE consecutive epochs, training stops.
    The weights from the epoch with the best validation loss are reloaded,
    saved to ``{MODEL_DIR}/{name}_seed{seed}.pt`` and used for the test metrics.

    Args:
        model_class: PyTorch model class to instantiate as ``model_class(input_dim)``
        name (str): Name identifier for saving model
        input_dim (int): Number of input features
        seed (int): Random seed for reproducibility

    Returns:
        Tuple: (metrics_dict, trained_model, test_loader)
    """
    # Seed every RNG in play so a (model, seed) pair is repeatable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Only the training split is shuffled; val/test keep file order.
    loaders = {
        split: DataLoader(
            HDF5WindowDataset(H5_PATH, split),
            batch_size=BATCH_SIZE,
            shuffle=(split == 'train'),
            collate_fn=collate_fn
        )
        for split in ('train', 'val', 'test')
    }
    train_loader = loaders['train']
    val_loader = loaders['val']
    test_loader = loaders['test']

    model = model_class(input_dim).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-6)
    # Element-wise loss so padded timesteps can be masked out before averaging.
    criterion = nn.BCELoss(reduction='none')

    best_loss = float('inf')
    patience_counter = 0
    best_state = None

    for epoch in range(EPOCHS):
        model.train()
        for x, y, mask in train_loader:
            x, y, mask = x.to(DEVICE), y.to(DEVICE), mask.to(DEVICE)
            optimizer.zero_grad()
            preds = model(x, mask)
            loss = criterion(preds, y).masked_fill(mask, 0.0)
            # Mean over real (unpadded) timesteps only.
            valid = (~mask).sum()
            loss_val = loss.sum() / valid
            loss_val.backward()
            optimizer.step()

        val_metrics = evaluate(model, val_loader, criterion)
        print(f"Epoch {epoch+1}: AUC={val_metrics['auc']:.4f}, F1={val_metrics['f1']:.4f}, AUPRC={val_metrics['auprc']:.4f}")

        if val_metrics['loss'] + MIN_DELTA < best_loss:
            best_loss = val_metrics['loss']
            patience_counter = 0
            # state_dict() hands back references to the live parameters, which
            # later optimizer steps overwrite in place. Clone them so this
            # really is a frozen snapshot of the best epoch.
            best_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    if best_state is None:
        # No epoch ever improved on +inf (e.g. NaN validation loss, or
        # EPOCHS == 0): fall back to the current weights instead of crashing
        # inside load_state_dict(None).
        best_state = {
            key: value.detach().clone()
            for key, value in model.state_dict().items()
        }

    model.load_state_dict(best_state)
    torch.save(best_state, f"{MODEL_DIR}/{name}_seed{seed}.pt")
    test_metrics = evaluate(model, test_loader, criterion)
    print(f"{name} | Test AUC: {test_metrics['auc']:.4f}, F1: {test_metrics['f1']:.4f}")
    return {'model': name, 'seed': seed, **test_metrics}, model, test_loader


In [None]:
# --- Define model variants ---
# Display name -> model class. The name is used in log lines, in the saved
# checkpoint filename and in the 'model' column of the results table.
models = {
    'Transformer': SimpleTransformer,
    'LSTM': SimpleLSTM,
    'GRU': SimpleGRU
}

# One metrics dict per (model, seed) run.
results = []
# "<name>_seed<seed>" -> (trained model, test loader). Every trained model and
# its loader stay referenced here after the loop finishes.
all_models = {}

# --- Train and evaluate each model with multiple random seeds ---
# NOTE(review): input_dim=620 is presumably the number of feature columns in
# the HDF5 data — confirm it matches the file before running.
for name, cls in models.items():
    for seed in SEEDS:
        res, mod, loader = train_and_evaluate(cls, name, input_dim=620, seed=seed)
        results.append(res)
        all_models[f"{name}_seed{seed}"] = (mod, loader)

# --- Export final results to CSV ---
# One row per (model, seed) run, written without the index column.
df = pd.DataFrame(results)
df.to_csv("metrics_completas.csv", index=False)

# --- Summary printout ---
# Mean of each test metric across seeds, one row per model.
print("\nSummary Results:")
print(df.groupby("model")[["auc", "f1", "auprc", "precision", "recall"]].mean())
