In [18]:
# -*- coding: utf-8 -*-
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import GroupKFold, GroupShuffleSplit
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import PCA
import csv

In [2]:
# ===== Configuración general =====
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Ventana fija por hora (48h)
WINDOW_HOURS = 48
L_MAX = WINDOW_HOURS  # pasos temporales
TS_CHANNELS = 2       # HR + RR/MAP (o HR + Lactato)
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
DROPOUT = 0.2
LR = 2e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 20
GRAD_CLIP = 1.0
BATCH_TRAIN = 64
BATCH_VAL = 128
BATCH_TEST = 128

In [19]:
def read_notes_safely(path, usecols=None):
    chunks = []
    for chunk in pd.read_csv(
            path,
            usecols=usecols,
            engine='python',        # más robusto con comillas rotas
            quoting=csv.QUOTE_MINIMAL,
            on_bad_lines='skip',    # si alguna línea es irrecuperable, sáltala
            encoding='utf-8',       # cambia a 'latin1' si necesitas
            parse_dates=['CHARTDATE','CHARTTIME'],
            chunksize=200_000):     # ajusta según memoria
        # Limpieza básica de texto
        if 'TEXT' in chunk.columns:
            chunk['TEXT'] = chunk['TEXT'].astype(str)
            # Normaliza comillas “rotas”
            chunk['TEXT'] = chunk['TEXT'].str.replace('\r', '\n', regex=False)
            chunk['TEXT'] = chunk['TEXT'].str.replace('"', "'", regex=False)
        chunks.append(chunk)
    notes = pd.concat(chunks, ignore_index=True)
    # Tipos de ID a string
    for col in ('SUBJECT_ID','HADM_ID'):
        if col in notes.columns:
            notes[col] = notes[col].astype(str)
    return notes

In [20]:
import pandas as pd
from pathlib import Path

def load_mimic_data(mimic_root: Path):
    # Lee los CSV originales de MIMIC-III con sus columnas reales
    charts = pd.read_csv(mimic_root / 'CHARTEVENTS.csv',
                         usecols=['SUBJECT_ID','HADM_ID','ICUSTAY_ID','ITEMID','CHARTTIME','VALUENUM','VALUEUOM'],
                         parse_dates=['CHARTTIME'])
    labs = pd.read_csv(mimic_root / 'LABEVENTS.csv',
                       usecols=['SUBJECT_ID','HADM_ID','ITEMID','CHARTTIME','VALUENUM','VALUEUOM','FLAG'],
                       parse_dates=['CHARTTIME'])
    labels = pd.read_csv(mimic_root / 'ADMISSIONS.csv',
                         usecols=['SUBJECT_ID','HADM_ID','ADMITTIME','DISCHTIME','DEATHTIME',
                                  'ADMISSION_TYPE','HOSPITAL_EXPIRE_FLAG','HAS_CHARTEVENTS_DATA'],
                         parse_dates=['ADMITTIME','DISCHTIME','DEATHTIME'])
    profiles = pd.read_csv(mimic_root / 'PATIENTS.csv',
                           usecols=['SUBJECT_ID','GENDER','DOB','DOD','DOD_HOSP','EXPIRE_FLAG'],
                           parse_dates=['DOB','DOD'])
    icustays = pd.read_csv(mimic_root / 'ICUSTAYS.csv',
                           usecols=['SUBJECT_ID','HADM_ID','ICUSTAY_ID','INTIME','OUTTIME'],
                           parse_dates=['INTIME','OUTTIME'])
    notes = read_notes_safely(mimic_root / 'NOTEEVENTS.csv.gz',
                          usecols=['SUBJECT_ID','HADM_ID','CHARTDATE','CHARTTIME','CATEGORY','TEXT'])


    # Normaliza tipos de ID a string
    for df in (charts, labs, labels, profiles, icustays, notes):
        for col in ['SUBJECT_ID','HADM_ID']:
            if col in df.columns:
                df[col] = df[col].astype(str)
        if 'ICUSTAY_ID' in df.columns:
            df['ICUSTAY_ID'] = df['ICUSTAY_ID'].astype(str)

    return charts, labs, labels, profiles, icustays, notes


In [9]:
# Diccionario simple de ITEMIDs (ajústalo a tu versión de MIMIC-III)
ITEMIDS = {
    'HR': [211, 220045],           # Heart Rate (CareVue, MetaVision)
    'RR': [618, 220210],           # Respiratory Rate
    'MAP': [456, 220052],          # Mean Arterial Pressure
    # Si quieres usar laboratorio como segundo canal, tendrás que mapear por D_LABITEMS,
    # pero en muchas setups el label "Lactate" tiene ITEMID conocidos. Si no los tienes,
    # filtra por VALUEUOM y nombre con un diccionario auxiliar.
}

def filter_charts_for_signals(charts_raw, signals=('HR','RR')):
    # Quedarse solo con ITEMIDs de interés y columnas necesarias
    mask = charts_raw['ITEMID'].isin(sum([ITEMIDS[s] for s in signals if s in ITEMIDS], []))
    charts = charts_raw.loc[mask, ['SUBJECT_ID','HADM_ID','ICUSTAY_ID','ITEMID','CHARTTIME','VALUENUM','VALUEUOM']].copy()
    charts = charts.dropna(subset=['VALUENUM'])
    charts['VALUENUM'] = charts['VALUENUM'].astype(float)
    return charts

def pick_first_note(notes_raw):
    # Primera nota por HADM_ID (puedes restringir CATEGORY a 'Nursing'/'Physician' si quieres notas tempranas)
    notes = notes_raw.sort_values(['HADM_ID','CHARTTIME']).groupby(['SUBJECT_ID','HADM_ID'], as_index=False).first()
    notes = notes[['SUBJECT_ID','HADM_ID','CHARTTIME','CATEGORY','TEXT']].copy()
    notes['TEXT'] = notes['TEXT'].fillna('')
    return notes

def build_profiles(admissions, patients, notes_first):
    # Perfil: edad, sexo, tipo de admisión + texto opcional
    prof = admissions[['SUBJECT_ID','HADM_ID','ADMITTIME','ADMISSION_TYPE','HOSPITAL_EXPIRE_FLAG']].copy()
    prof = prof.merge(patients[['SUBJECT_ID','GENDER','DOB']], on='SUBJECT_ID', how='left')
    # Edad aproximada (MIMIC-III tiene edades >89 enmascaradas; considera cap a 90)
    prof['age'] = ((prof['ADMITTIME'] - prof['DOB']).dt.days / 365.25).clip(lower=0, upper=90)
    prof['gender'] = prof['GENDER'].astype(str)
    prof['admission_type'] = prof['ADMISSION_TYPE'].astype(str)
    prof['label_binary'] = prof['HOSPITAL_EXPIRE_FLAG'].astype(int)

    # Adjuntar primera nota
    notes_small = notes_first[['SUBJECT_ID','HADM_ID','TEXT']]
    prof = prof.merge(notes_small, on=['SUBJECT_ID','HADM_ID'], how='left')
    prof['TEXT'] = prof['TEXT'].fillna('')

    return prof[['SUBJECT_ID','HADM_ID','ADMITTIME','age','gender','admission_type','label_binary','TEXT']]

def align_charts_to_icu_window(charts, icu, window_hours=48):
    # Usa INTIME como t0, mantiene solo 48h
    icu_small = icu[['SUBJECT_ID','HADM_ID','ICUSTAY_ID','INTIME']].copy()
    charts = charts.merge(icu_small, on=['SUBJECT_ID','HADM_ID','ICUSTAY_ID'], how='inner')
    # Diferencia temporal en horas desde INTIME
    dt = (charts['CHARTTIME'] - charts['INTIME']).dt.total_seconds() / 3600.0
    charts['HOUR_IDX'] = np.floor(dt).astype(int)
    charts = charts[(charts['HOUR_IDX'] >= 0) & (charts['HOUR_IDX'] < window_hours)]
    return charts

def hourly_aggregate(charts_aligned, signals=('HR','RR'), window_hours=48):
    # Construye una tabla (HOUR_IDX, canal) con medianas por hora
    # Mapea ITEMID → señal
    item_to_signal = {}
    for s in signals:
        for it in ITEMIDS.get(s, []):
            item_to_signal[it] = s
    charts_aligned['signal'] = charts_aligned['ITEMID'].map(item_to_signal)

    # Mediana por (SUBJECT_ID,HADM_ID,ICUSTAY_ID,HOUR_IDX,signal)
    gcols = ['SUBJECT_ID','HADM_ID','ICUSTAY_ID','HOUR_IDX','signal']
    agg = charts_aligned.groupby(gcols)['VALUENUM'].median().reset_index()

    # Pivot a matriz por hora: columnas = señales
    pivot = agg.pivot_table(index=['SUBJECT_ID','HADM_ID','ICUSTAY_ID','HOUR_IDX'],
                            columns='signal', values='VALUENUM')
    pivot = pivot.reset_index()

    # Rellenar rejilla completa 0..window_hours-1
    def fill_grid(df_one):
        # df_one: filas para un (subject, hadm, icu)
        grid = pd.DataFrame({'HOUR_IDX': np.arange(window_hours)})
        df = grid.merge(df_one, on='HOUR_IDX', how='left')
        return df

    frames = []
    for key, dfk in pivot.groupby(['SUBJECT_ID','HADM_ID','ICUSTAY_ID']):
        filled = fill_grid(dfk.drop(columns=['SUBJECT_ID','HADM_ID','ICUSTAY_ID']))
        filled['SUBJECT_ID'], filled['HADM_ID'], filled['ICUSTAY_ID'] = key
        frames.append(filled)
    hourly = pd.concat(frames, axis=0, ignore_index=True)

    # Orden de canales acorde a signals
    X_cols = list(signals)
    return hourly[['SUBJECT_ID','HADM_ID','ICUSTAY_ID','HOUR_IDX'] + X_cols]

def make_sequences(hourly_df, profiles_df, signals=('HR','RR'), window_hours=48):
    # Construye X(L,C), M(L,C), T(L)
    instances = []
    for key, dfk in hourly_df.groupby(['SUBJECT_ID','HADM_ID','ICUSTAY_ID']):
        dfk = dfk.sort_values('HOUR_IDX')
        X = dfk[list(signals)].to_numpy(dtype=float)  # (L, C)
        M = ~np.isnan(X)
        X = np.where(M, X, 0.0)
        M = M.astype(float)
        T = np.linspace(0.0, 1.0, window_hours, dtype=float)

        # Etiqueta y perfil
        subj, hadm, icu = key
        prof = profiles_df[(profiles_df['SUBJECT_ID']==subj) & (profiles_df['HADM_ID']==hadm)].iloc[0]
        y = int(prof['label_binary'])
        instances.append({
            'key': key, 'X': X, 'M': M, 'T': T,
            'age': prof['age'], 'gender': prof['gender'], 'admission_type': prof['admission_type'],
            'text': prof['TEXT'], 'label': y
        })
    return instances


In [11]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import PCA

def build_P_matrix(instances):
    # Tabular básico
    genders = [x['gender'] for x in instances]
    adm_types = [x['admission_type'] for x in instances]
    ages = np.array([x['age'] for x in instances], dtype=np.float32).reshape(-1,1)

    # One-hot
    df_cat = pd.DataFrame({'gender': genders, 'admission_type': adm_types})
    cat_oh = pd.get_dummies(df_cat, prefix=['gender','admtype'])
    cat_oh = cat_oh.to_numpy(dtype=np.float32)

    # Texto → TF-IDF → PCA(64)
    texts = [x['text'] for x in instances]
    tfidf = TfidfVectorizer(max_features=5000)
    X_tfidf = tfidf.fit_transform(texts)
    pca = PCA(n_components=64, random_state=42)
    X_text = pca.fit_transform(X_tfidf.toarray()).astype(np.float32)

    # Concatenación y estandarización
    P = np.concatenate([ages, cat_oh, X_text], axis=1)
    P = (P - np.nanmean(P, axis=0)) / (np.nanstd(P, axis=0) + 1e-6)
    return P

In [12]:
def pack_arrays(instances, P):
    X_list = [x['X'] for x in instances]
    M_list = [x['M'] for x in instances]
    T_list = [x['T'] for x in instances]
    y = np.array([x['label'] for x in instances], dtype=int)
    meta = [x['key'] for x in instances]
    X_all = np.stack(X_list, axis=0)
    M_all = np.stack(M_list, axis=0)
    T_all = np.stack(T_list, axis=0)
    return X_all, M_all, T_all, y, meta, P

In [4]:
def build_hourly_grid(start_time: pd.Timestamp, hours=L_MAX):
    # Rejilla por hora desde start_time
    tgrid = [start_time + pd.Timedelta(hours=h) for h in range(hours)]
    return pd.DataFrame({'chart_time_hour': tgrid})

def aggregate_signals(charts_user, labs_user, start_time, hours=L_MAX):
    # Construye la rejilla y hace merge por hora
    grid = build_hourly_grid(start_time.normalize(), hours=hours)  # puedes usar admit_time como base
    # Merge charts (HR, RR/MAP)
    charts_user = charts_user.copy()
    charts_user['chart_time_hour'] = charts_user['chart_time_hour'].dt.floor('H')
    merged = grid.merge(charts_user, on='chart_time_hour', how='left')

    # Si usas labs (ej. Lactato) como segunda señal, mézclalo aquí
    if labs_user is not None and not labs_user.empty:
        labs_user = labs_user.copy()
        labs_user['chart_time_hour'] = labs_user['chart_time_hour'].dt.floor('H')
        merged = merged.merge(labs_user[['chart_time_hour','Lactate']], on='chart_time_hour', how='left')

    # Selección de canales
    # Opción A: HR + RR
    v1 = merged['HR'].astype(float) if 'HR' in merged.columns else pd.Series([np.nan]*len(merged))
    v2 = merged['RR'].astype(float) if 'RR' in merged.columns else (
        merged['MAP'].astype(float) if 'MAP' in merged.columns else (
            merged['Lactate'].astype(float) if 'Lactate' in merged.columns else pd.Series([np.nan]*len(merged))
        )
    )

    X = np.stack([
        v1.fillna(np.nan).to_numpy(),
        v2.fillna(np.nan).to_numpy()
    ], axis=-1)  # shape (L, 2)

    M = ~np.isnan(X)
    X = np.where(M, X, 0.0)
    # T normalizado 0..1
    T = np.linspace(0.0, 1.0, len(merged), dtype=float)
    M = M.astype(float)

    return X.astype(float), T.astype(float), M.astype(float)

def build_instances(charts, labs, labels, profiles, notes):
    """
    Construye instancias por (subject_id, hadm_id, icustay_id).
    Toma primeras 48h desde admit_time (o icu_in_time si lo tienes).
    Genera X (L, C=2), T (L), M (L, C), y target binario; y perfil tabular P.
    También produce embeddings de notas (opcional) para P.
    """
    # Prepara embeddings simples de notas (TF-IDF + PCA → 64 dims)
    # Nota: en producción usarías un encoder clínico; esto es para tener una baseline reproducible.
    # Selecciona la primera nota por admisión
    first_notes = notes.sort_values('note_time').groupby(['subject_id','hadm_id','icustay_id'], as_index=False).first()
    text_corpus = first_notes['note_text'].fillna('')
    tfidf = TfidfVectorizer(max_features=5000)
    X_tfidf = tfidf.fit_transform(text_corpus).astype(np.float32)
    pca = PCA(n_components=64, random_state=SEED)
    note_emb = pca.fit_transform(X_tfidf.toarray()) if X_tfidf.shape[0] > 0 else np.zeros((0,64), dtype=np.float32)
    first_notes['note_emb'] = list(note_emb)

    # Perfiles num/cat
    prof = profiles.copy()
    prof = prof.merge(first_notes[['subject_id','hadm_id','icustay_id','note_emb']], on=['subject_id','hadm_id','icustay_id'], how='left')
    prof['note_emb'] = prof['note_emb'].apply(lambda x: np.array(x, dtype=np.float32) if isinstance(x, (list,np.ndarray)) else np.zeros((64,), dtype=np.float32))

    # One-hot para categóricas, fillna mean para numéricas
    cat_cols = [c for c in prof.columns if c in ('gender','admission_type')]
    num_cols = [c for c in prof.columns if c not in cat_cols and c not in ('subject_id','hadm_id','icustay_id','note_emb')]
    prof_num = prof[num_cols].apply(lambda col: col.fillna(col.mean()))
    prof_cat = pd.get_dummies(prof[cat_cols], prefix=cat_cols) if len(cat_cols)>0 else pd.DataFrame(index=prof.index)
    P_tab = pd.concat([prof_num, prof_cat], axis=1)

    # Concatena note_emb a P_tab
    P_all = []
    for i in range(len(prof)):
        vec_tab = P_tab.iloc[i].to_numpy(dtype=np.float32)
        vec_note = prof.iloc[i]['note_emb']
        P_all.append(np.concatenate([vec_tab, vec_note], axis=0))
    # Estandariza P_all canal a canal
    P_mat = np.stack(P_all, axis=0)
    P_mat = (P_mat - np.nanmean(P_mat, axis=0)) / (np.nanstd(P_mat, axis=0) + 1e-6)

    # Empareja con labels
    lab = labels[['subject_id','hadm_id','icustay_id','admit_time','label_binary']].copy()
    lab['label_binary'] = lab['label_binary'].astype(int)
    merged = lab.merge(prof[['subject_id','hadm_id','icustay_id']], on=['subject_id','hadm_id','icustay_id'], how='inner')
    # Construye X/T/M por instancia
    X_list, T_list, M_list, y_list, meta, P_list = [], [], [], [], [], []

    # Índices para reindexar P_mat
    key_to_index = {tuple(row[['subject_id','hadm_id','icustay_id']]): idx for idx, row in prof.reset_index(drop=True).iterrows()}

    for _, row in merged.iterrows():
        key = (row['subject_id'], row['hadm_id'], row['icustay_id'])
        admit_time = pd.to_datetime(row['admit_time'])
        # Filtra datos del usuario/admisión
        charts_user = charts[(charts['subject_id']==key[0]) & (charts['hadm_id']==key[1]) & (charts['icustay_id']==key[2])]
        labs_user = labs[(labs['subject_id']==key[0]) & (labs['hadm_id']==key[1]) & (labs['icustay_id']==key[2])] if labs is not None else None
        X, T, M = aggregate_signals(charts_user, labs_user, admit_time, hours=L_MAX)
        y = int(row['label_binary'])
        X_list.append(X); T_list.append(T); M_list.append(M); y_list.append(y); meta.append(key)
        P_list.append(P_mat[key_to_index[key]])

    return X_list, T_list, M_list, np.array(y_list, dtype=int), meta, np.stack(P_list, axis=0)

In [5]:
# ===== Datasets =====
class TimeDataset(torch.utils.data.Dataset):
    def __init__(self, X, T, M, y, user_ids=None):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.T = torch.tensor(T, dtype=torch.float32)
        self.M = torch.tensor(M, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        self.user_ids = user_ids if user_ids is not None else [None]*len(y)
    def __len__(self): return len(self.y)
    def __getitem__(self, idx): return self.X[idx], self.T[idx], self.M[idx], self.y[idx], self.user_ids[idx]

class TimeDatasetWithProfile(TimeDataset):
    def __init__(self, X, T, M, y, P, user_ids=None):
        super().__init__(X, T, M, y, user_ids)
        self.P = torch.tensor(P, dtype=torch.float32)
    def __getitem__(self, idx):
        X, T, M, y, uid = super().__getitem__(idx)
        return X, T, M, y, uid, self.P[idx]

In [6]:
# ===== Modelo (reuso de tus bloques) =====
# Copia aquí TimeAttentionBlock, MTANBackbone, FiLMGenerator, HeadMLP, ModelPhase1_TS, ModelPhase2_TSProfile, ModelPhase3_FiLM
# (idénticos a tu código original, salvo que in_channels=TS_CHANNELS)

# mTAN-like + FiLM (modelos)
class TimeAttentionBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.2):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.dk = d_model // n_heads
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4*d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4*d_model, d_model),
            nn.Dropout(dropout),
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.time_decay = nn.Parameter(torch.tensor(0.1, dtype=torch.float32))

    def forward(self, x, T, M, gamma=None, beta=None):
        B, L, d = x.size()
        Q = self.Wq(x); K = self.Wk(x); V = self.Wv(x)
        def split_heads(t): return t.view(B, L, self.n_heads, self.dk).transpose(1, 2)
        Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V)
        scores = torch.matmul(Qh, Kh.transpose(-2, -1)) / np.sqrt(self.dk)
        Ti = T.unsqueeze(1).unsqueeze(-1)
        Tj = T.unsqueeze(1).unsqueeze(-2)
        time_dist = torch.abs(Ti - Tj)
        scores = scores - self.time_decay * time_dist
        step_valid = (M.sum(dim=-1) > 0).unsqueeze(1).unsqueeze(2)
        scores = scores.masked_fill(~step_valid, float('-inf'))
        A = torch.softmax(scores, dim=-1)
        A = self.dropout(A)
        Zh = torch.matmul(A, Vh)
        Z = Zh.transpose(1, 2).contiguous().view(B, L, d)
        h_attn = self.out(Z)
        y1 = self.ln1(x + h_attn)
        if gamma is not None and beta is not None:
            y1 = y1 * gamma + beta
        f = self.ffn(y1)
        y2 = self.ln2(y1 + f)
        if gamma is not None and beta is not None:
            y2 = y2 * gamma + beta
        return y2

class MTANBackbone(nn.Module):
    def __init__(self, in_channels, d_model=128, n_layers=2, n_heads=4, dropout=0.2):
        super().__init__()
        self.input_proj = nn.Linear(in_channels, d_model)
        self.layers = nn.ModuleList([
            TimeAttentionBlock(d_model, n_heads, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.d_model = d_model
        self.n_layers = n_layers
    def forward(self, X, T, M, gammas=None, betas=None):
        h = self.input_proj(X)
        for l, layer in enumerate(self.layers):
            gamma_l = None if gammas is None else gammas[:, l, :].unsqueeze(1)
            beta_l  = None if betas  is None else betas[:, l, :].unsqueeze(1)
            h = layer(h, T, M, gamma=gamma_l, beta=beta_l)
        return h

class FiLMGenerator(nn.Module):
    def __init__(self, p_dim, d_model, n_layers, hidden=64, dropout=0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(p_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 2*n_layers*d_model)
        )
        nn.init.zeros_(self.mlp[-1].weight)
        nn.init.zeros_(self.mlp[-1].bias)
        self.d_model = d_model
        self.n_layers = n_layers
    def forward(self, P):
        B = P.size(0)
        out = self.mlp(P).view(B, 2, self.n_layers, self.d_model)
        gammas = out[:, 0, :, :] + 1.0
        betas  = out[:, 1, :, :]
        return gammas, betas

class HeadMLP(nn.Module):
    def __init__(self, d_in, d_hidden=64, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_hidden, 1)
        )
    def forward(self, x): return self.net(x).squeeze(-1)

# Modelos por fase
class ModelPhase1_TS(nn.Module):
    def __init__(self, in_channels, d_model=D_MODEL, n_layers=N_LAYERS, n_heads=N_HEADS, dropout=DROPOUT):
        super().__init__()
        self.backbone = MTANBackbone(in_channels, d_model, n_layers, n_heads, dropout)
        self.head = HeadMLP(d_model, d_hidden=64, dropout=dropout)
    def forward(self, X, T, M):
        h = self.backbone(X, T, M, gammas=None, betas=None)
        step_valid = (M.sum(dim=-1) > 0).float()
        denom = torch.clamp(step_valid.sum(dim=1, keepdim=True), min=1.0)
        h_seq = (h * step_valid.unsqueeze(-1)).sum(dim=1) / denom
        return self.head(h_seq)

class ModelPhase2_TSProfile(nn.Module):
    def __init__(self, in_channels, p_dim, d_model=D_MODEL, n_layers=N_LAYERS, n_heads=N_HEADS, dropout=DROPOUT):
        super().__init__()
        self.backbone = MTANBackbone(in_channels, d_model, n_layers, n_heads, dropout)
        self.profile_mlp = nn.Sequential(
            nn.Linear(p_dim, 64),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        self.head = HeadMLP(d_model + 64, d_hidden=64, dropout=dropout)
    def forward(self, X, T, M, P):
        h = self.backbone(X, T, M, gammas=None, betas=None)
        step_valid = (M.sum(dim=-1) > 0).float()
        denom = torch.clamp(step_valid.sum(dim=1, keepdim=True), min=1.0)
        h_seq = (h * step_valid.unsqueeze(-1)).sum(dim=1) / denom
        h_prof = self.profile_mlp(P)
        return self.head(torch.cat([h_seq, h_prof], dim=-1))

class ModelPhase3_FiLM(nn.Module):
    def __init__(self, in_channels, p_dim, d_model=D_MODEL, n_layers=N_LAYERS, n_heads=N_HEADS, dropout=DROPOUT, film_hidden=64):
        super().__init__()
        self.backbone = MTANBackbone(in_channels, d_model, n_layers, n_heads, dropout)
        self.film = FiLMGenerator(p_dim, d_model, n_layers, hidden=film_hidden, dropout=0.1)
        self.head = HeadMLP(d_model, d_hidden=64, dropout=dropout)
    def forward(self, X, T, M, P):
        gammas, betas = self.film(P)
        h = self.backbone(X, T, M, gammas=gammas, betas=betas)
        step_valid = (M.sum(dim=-1) > 0).float()
        denom = torch.clamp(step_valid.sum(dim=1, keepdim=True), min=1.0)
        h_seq = (h * step_valid.unsqueeze(-1)).sum(dim=1) / denom
        return self.head(h_seq)

In [7]:
# ===== Utilidades =====
def compute_channel_stats(X_train, M_train):
    B, L, C = X_train.shape
    means = np.zeros((C,), dtype=float)
    stds = np.ones((C,), dtype=float)
    for c in range(C):
        vals = X_train[..., c][M_train[..., c] == 1.0]
        means[c] = float(np.mean(vals)) if vals.size>0 else 0.0
        stds[c] = float(np.std(vals) + 1e-6) if vals.size>0 else 1.0
    return means, stds

def standardize_by_stats(X, M, means, stds):
    X_std = (X - means[None, None, :]) / stds[None, None, :]
    return np.where(M == 1.0, X_std, 0.0)

# Entrenamiento/Evaluación (igual que tu código)
def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0.0
    logits_all, targets_all = [], []
    for batch in loader:
        optimizer.zero_grad()
        if len(batch) == 5:
            X, T, M, y, _ = batch
            X, T, M, y = X.to(DEVICE), T.to(DEVICE), M.to(DEVICE), y.to(DEVICE)
            logits = model(X, T, M)
        else:
            X, T, M, y, _, P = batch
            X, T, M, y, P = X.to(DEVICE), T.to(DEVICE), M.to(DEVICE), y.to(DEVICE), P.to(DEVICE)
            logits = model(X, T, M, P)
        loss = criterion(logits, y)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        total_loss += loss.item() * X.size(0)
        logits_all.append(logits.detach().cpu().numpy())
        targets_all.append(y.detach().cpu().numpy())
    yhat = np.concatenate(logits_all); yt = np.concatenate(targets_all)
    try: auroc = roc_auc_score(yt, yhat)
    except: auroc = np.nan
    try: auprc = average_precision_score(yt, yhat)
    except: auprc = np.nan
    return total_loss / len(yt), auroc, auprc

@torch.no_grad()
def eval_one_epoch(model, loader, criterion):
    model.eval()
    total_loss = 0.0
    logits_all, targets_all = [], []
    for batch in loader:
        if len(batch) == 5:
            X, T, M, y, _ = batch
            X, T, M, y = X.to(DEVICE), T.to(DEVICE), M.to(DEVICE), y.to(DEVICE)
            logits = model(X, T, M)
        else:
            X, T, M, y, _, P = batch
            X, T, M, y, P = X.to(DEVICE), T.to(DEVICE), M.to(DEVICE), y.to(DEVICE), P.to(DEVICE)
            logits = model(X, T, M, P)
        loss = criterion(logits, y)
        total_loss += loss.item() * X.size(0)
        logits_all.append(logits.detach().cpu().numpy())
        targets_all.append(y.detach().cpu().numpy())
    yhat = np.concatenate(logits_all); yt = np.concatenate(targets_all)
    try: auroc = roc_auc_score(yt, yhat)
    except: auroc = np.nan
    try: auprc = average_precision_score(yt, yhat)
    except: auprc = np.nan
    return total_loss / len(yt), auroc, auprc

def run_phase(phase_name, train_loader, val_loader, in_channels, p_dim=None):
    if phase_name == 'P1':
        model = ModelPhase1_TS(in_channels).to(DEVICE)
    elif phase_name == 'P2':
        assert p_dim is not None
        model = ModelPhase2_TSProfile(in_channels, p_dim).to(DEVICE)
    elif phase_name == 'P3':
        assert p_dim is not None
        model = ModelPhase3_FiLM(in_channels, p_dim).to(DEVICE)
    else:
        raise ValueError("phase_name must be P1, P2 or P3")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    criterion = nn.BCEWithLogitsLoss()

    best_val = -np.inf
    best_state = None

    for ep in range(1, EPOCHS+1):
        tl, tr_auc, tr_pr = train_one_epoch(model, train_loader, optimizer, criterion)
        vl, va_auc, va_pr = eval_one_epoch(model, val_loader, criterion)
        scheduler.step()
        score = np.nanmean([va_auc, va_pr])
        if score > best_val:
            best_val = score
            best_state = { 'model': model.state_dict() }
        print(f"{phase_name} | Epoch {ep:02d} | train_loss={tl:.4f} val_loss={vl:.4f} | "
              f"AUROC train/val={tr_auc:.3f}/{va_auc:.3f} | AUPRC train/val={tr_pr:.3f}/{va_pr:.3f}")

    if best_state is not None:
        model.load_state_dict(best_state['model'])
    return model

In [21]:
# ===== Pipeline principal =====
def main_mimic(mimic_root: Path):
    charts, labs, labels, profiles, icustays, notes = load_mimic_data(mimic_root)

    X_list, T_list, M_list, y, meta, P_vals = build_instances(charts, labs, labels, profiles, notes)
    # Reindex y preparar arrays
    X_all = np.stack(X_list, axis=0)  # (N, L, C=2)
    T_all = np.stack(T_list, axis=0)  # (N, L)
    M_all = np.stack(M_list, axis=0)  # (N, L, C)
    user_ids = [f"{s}-{h}-{i}" for (s,h,i) in meta]
    in_channels = X_all.shape[-1]
    p_dim = P_vals.shape[1]

    # Holdout test por HADM_ID (o SUBJECT_ID si prefieres)
    groups = np.array([m[1] for m in meta])  # hadm_id
    gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
    trainval_idx, test_idx = next(gss.split(X_all, y, groups=groups))
    X_tv, T_tv, M_tv, y_tv = X_all[trainval_idx], T_all[trainval_idx], M_all[trainval_idx], y[trainval_idx]
    X_te, T_te, M_te, y_te = X_all[test_idx], T_all[test_idx], M_all[test_idx], y[test_idx]
    P_tv, P_te = P_vals[trainval_idx], P_vals[test_idx]
    groups_tv = groups[trainval_idx]

    # Validación interna por grupos
    gkf = GroupKFold(n_splits=3)
    te_results = {'P1': [], 'P2': [], 'P3': []}

    for fold, (tr_idx_rel, va_idx_rel) in enumerate(gkf.split(X_tv, y_tv, groups=groups_tv), start=1):
        print(f"=== Fold {fold} ===")
        X_tr, T_tr, M_tr, y_tr = X_tv[tr_idx_rel], T_tv[tr_idx_rel], M_tv[tr_idx_rel], y_tv[tr_idx_rel]
        X_va, T_va, M_va, y_va = X_tv[va_idx_rel], T_tv[va_idx_rel], M_tv[va_idx_rel], y_tv[va_idx_rel]
        P_tr, P_va = P_tv[tr_idx_rel], P_tv[va_idx_rel]

        means, stds = compute_channel_stats(X_tr, M_tr)
        X_tr_std = standardize_by_stats(X_tr, M_tr, means, stds)
        X_va_std = standardize_by_stats(X_va, M_va, means, stds)
        X_te_std = standardize_by_stats(X_te, M_te, means, stds)

        ds_tr_p1 = TimeDataset(X_tr_std, T_tr, M_tr, y_tr)
        ds_va_p1 = TimeDataset(X_va_std, T_va, M_va, y_va)
        ds_te_p1 = TimeDataset(X_te_std, T_te, M_te, y_te)

        ds_tr_p2 = TimeDatasetWithProfile(X_tr_std, T_tr, M_tr, y_tr, P_tr)
        ds_va_p2 = TimeDatasetWithProfile(X_va_std, T_va, M_va, y_va, P_va)
        ds_te_p2 = TimeDatasetWithProfile(X_te_std, T_te, M_te, y_te, P_te)

        dl_tr_p1 = DataLoader(ds_tr_p1, batch_size=BATCH_TRAIN, shuffle=True)
        dl_va_p1 = DataLoader(ds_va_p1, batch_size=BATCH_VAL, shuffle=False)
        dl_te_p1 = DataLoader(ds_te_p1, batch_size=BATCH_TEST, shuffle=False)

        dl_tr_p2 = DataLoader(ds_tr_p2, batch_size=BATCH_TRAIN, shuffle=True)
        dl_va_p2 = DataLoader(ds_va_p2, batch_size=BATCH_VAL, shuffle=False)
        dl_te_p2 = DataLoader(ds_te_p2, batch_size=BATCH_TEST, shuffle=False)

        # P3 reutiliza los mismos loaders que P2
        dl_tr_p3, dl_va_p3, dl_te_p3 = dl_tr_p2, dl_va_p2, dl_te_p2

        model_p1 = run_phase('P1', dl_tr_p1, dl_va_p1, in_channels=TS_CHANNELS)
        model_p2 = run_phase('P2', dl_tr_p2, dl_va_p2, in_channels=TS_CHANNELS, p_dim=p_dim)
        model_p3 = run_phase('P3', dl_tr_p3, dl_va_p3, in_channels=TS_CHANNELS, p_dim=p_dim)

        # Test
        _, te_auc_p1, te_pr_p1 = eval_one_epoch(model_p1, dl_te_p1, nn.BCEWithLogitsLoss())
        _, te_auc_p2, te_pr_p2 = eval_one_epoch(model_p2, dl_te_p2, nn.BCEWithLogitsLoss())
        _, te_auc_p3, te_pr_p3 = eval_one_epoch(model_p3, dl_te_p3, nn.BCEWithLogitsLoss())
        te_results['P1'].append((te_auc_p1, te_pr_p1))
        te_results['P2'].append((te_auc_p2, te_pr_p2))
        te_results['P3'].append((te_auc_p3, te_pr_p3))
        print(f"Fold {fold} | Test AUROC P1/P2/P3 = {te_auc_p1:.3f}/{te_auc_p2:.3f}/{te_auc_p3:.3f} | "
              f"AUPRC = {te_pr_p1:.3f}/{te_pr_p2:.3f}/{te_pr_p3:.3f}")

    # Resumen final
    for p in ['P1','P2','P3']:
        arr = np.array(te_results[p])
        print(f"{p} | Test mean AUROC={np.nanmean(arr[:,0]):.3f} | AUPRC={np.nanmean(arr[:,1]):.3f}")

if __name__ == "__main__":
    main_mimic(Path("/home/gmartinez/Tesis/Datasets/MIMIC-III"))

FileNotFoundError: [Errno 2] No such file or directory: '/home/gmartinez/Tesis/Datasets/MIMIC-III/NOTEEVENTS.csv.gz'