# In [None]:
# -*- coding: utf-8 -*-
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import GroupKFold, GroupShuffleSplit
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import PCA

# ===== General configuration =====
# Seed both torch and numpy so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
# Use the GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fixed hourly window (48h)
WINDOW_HOURS = 48
L_MAX = WINDOW_HOURS  # time steps
TS_CHANNELS = 2       # HR + RR/MAP (or HR + Lactate)
# NOTE(review): D_MODEL, N_HEADS, N_LAYERS and DROPOUT are not referenced in the
# code visible here; presumably they are read by the model classes — confirm.
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 2
DROPOUT = 0.2
# Optimiser / training-loop settings used by run_phase and train_one_epoch.
LR = 2e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 20
GRAD_CLIP = 1.0
# Batch sizes for the train / validation / test DataLoaders.
BATCH_TRAIN = 64
BATCH_VAL = 128
BATCH_TEST = 128

# ===== Carga/transformación de MIMIC =====
def load_mimic_data(mimic_root: Path):
    """Load the pre-extracted MIMIC CSV tables from ``mimic_root``.

    Expected files and columns:
    - chartevents_hourly.csv: [subject_id, hadm_id, icustay_id, chart_time_hour, HR, RR, MAP]
    - labevents_hourly.csv: [subject_id, hadm_id, icustay_id, chart_time_hour, Lactate]
    - admissions_labels.csv: [subject_id, hadm_id, icustay_id, admit_time, label_binary]
    - profiles.csv: [subject_id, hadm_id, icustay_id, age, gender, admission_type, charlson, ...]
    - notes_text.csv: [subject_id, hadm_id, icustay_id, note_time, note_text] (first note or a set)

    Args:
        mimic_root: Directory containing the five CSV files.

    Returns:
        Tuple ``(charts, labs, labels, profiles, notes)`` of DataFrames whose
        ID columns (when present) are strings and whose time columns are
        parsed as datetimes.
    """
    id_cols = ('subject_id', 'hadm_id', 'icustay_id')
    # Read IDs as text from the start. If pandas infers the dtype, a column
    # with a single blank becomes float and astype(str) then yields '10.0',
    # which no longer matches '10' in the other tables when joining.
    id_dtypes = {col: str for col in id_cols}
    tables = (
        ('chartevents_hourly.csv', ['chart_time_hour']),
        ('labevents_hourly.csv', ['chart_time_hour']),
        ('admissions_labels.csv', ['admit_time']),
        ('profiles.csv', None),
        ('notes_text.csv', ['note_time']),
    )

    frames = []
    for filename, date_cols in tables:
        read_kwargs = {'parse_dates': date_cols} if date_cols else {}
        df = pd.read_csv(mimic_root / filename, dtype=id_dtypes, **read_kwargs)
        # Normalise IDs to plain str (missing values become the string 'nan').
        for col in id_cols:
            if col in df.columns:
                df[col] = df[col].astype(str)
        frames.append(df)

    charts, labs, labels, profiles, notes = frames
    return charts, labs, labels, profiles, notes

def build_hourly_grid(start_time: pd.Timestamp, hours=L_MAX):
    """Return a one-column frame of ``hours`` consecutive hourly timestamps.

    The single column is named ``chart_time_hour``; its first value is
    ``start_time`` and each following value is one hour later.
    """
    one_hour = pd.Timedelta(hours=1)
    stamps = []
    for step in range(hours):
        stamps.append(start_time + step * one_hour)
    return pd.DataFrame({'chart_time_hour': stamps})

def _hourly_mean(frame, value_cols):
    """Floor ``chart_time_hour`` to the hour and average ``value_cols`` per hour."""
    present = [c for c in value_cols if c in frame.columns]
    hourly = frame[['chart_time_hour'] + present].copy()
    hourly['chart_time_hour'] = hourly['chart_time_hour'].dt.floor('h')
    if hourly.empty or not present:
        # Nothing to average; keep the (possibly empty) columns so the caller's
        # channel selection still sees which signals exist.
        return hourly.drop_duplicates()
    hourly = hourly.astype({c: float for c in present})
    return hourly.groupby('chart_time_hour', as_index=False)[present].mean()


def aggregate_signals(charts_user, labs_user, start_time, hours=L_MAX):
    """Resample one stay's signals onto a fixed hourly grid.

    The grid starts at the hour containing ``start_time`` and has ``hours``
    steps. Measurements are floored to the hour and, when several fall into
    the same hour, averaged, so the output always has exactly ``hours`` rows.

    Channel 0 is ``HR``. Channel 1 is the first available column among
    ``RR``, ``MAP`` and ``Lactate``. A channel whose column is absent is
    entirely missing.

    Args:
        charts_user: Chart rows for one stay, with a datetime
            ``chart_time_hour`` column and optionally ``HR``/``RR``/``MAP``.
        labs_user: Lab rows for one stay (``chart_time_hour``, ``Lactate``),
            or ``None``/empty to skip labs.
        start_time: Reference timestamp (e.g. admission time).
        hours: Number of hourly steps in the window.

    Returns:
        ``(X, T, M)`` where ``X`` has shape ``(hours, 2)`` with missing values
        set to 0.0, ``T`` is ``hours`` values evenly spaced in [0, 1], and
        ``M`` is the float observation mask with the same shape as ``X``
        (1.0 = observed).
    """
    # Anchor at the admission hour. Normalising to midnight would shift the
    # window by up to 23 hours and drop the tail of the first 48h.
    grid = build_hourly_grid(pd.Timestamp(start_time).floor('h'), hours=hours)

    merged = grid.merge(
        _hourly_mean(charts_user, ('HR', 'RR', 'MAP')),
        on='chart_time_hour', how='left',
    )

    # Optional lab signal (e.g. Lactate) as a fallback second channel.
    if labs_user is not None and not labs_user.empty:
        merged = merged.merge(
            _hourly_mean(labs_user, ('Lactate',)),
            on='chart_time_hour', how='left',
        )

    def _channel(*candidates):
        # First column that exists wins, even if it holds only NaN.
        for name in candidates:
            if name in merged.columns:
                return merged[name].astype(float).to_numpy()
        return np.full(len(merged), np.nan)

    X = np.stack([_channel('HR'), _channel('RR', 'MAP', 'Lactate')], axis=-1)  # (L, 2)

    observed = ~np.isnan(X)
    X = np.where(observed, X, 0.0)
    # Time axis normalised to 0..1.
    T = np.linspace(0.0, 1.0, len(merged), dtype=float)

    return X.astype(float), T, observed.astype(float)

def build_instances(charts, labs, labels, profiles, notes):
    """Build one model instance per (subject_id, hadm_id, icustay_id).

    For every key present in both ``labels`` and ``profiles`` this produces
    the hourly signal arrays from ``aggregate_signals`` (window of ``L_MAX``
    hours starting at ``admit_time``), the binary target, and a standardised
    profile vector made of tabular features plus a 64-dim embedding of the
    stay's first note (TF-IDF followed by PCA; a simple reproducible baseline,
    not a clinical encoder).

    NOTE(review): the TF-IDF/PCA fit and the profile standardisation use all
    rows, i.e. before the caller splits train/validation/test. Consider
    fitting them on training rows only to avoid leakage.

    Args:
        charts: Hourly chart table with the three ID columns.
        labs: Hourly lab table with the three ID columns, or ``None``.
        labels: Table with the ID columns, ``admit_time`` and ``label_binary``.
        profiles: Table with the ID columns plus tabular features;
            ``gender`` and ``admission_type`` are one-hot encoded, every
            other non-ID column is treated as numeric.
        notes: Table with the ID columns, ``note_time`` and ``note_text``.

    Returns:
        ``(X_list, T_list, M_list, y, meta, P)``: lists of per-instance
        ``X``/``T``/``M`` arrays, the int label array, the list of key tuples,
        and the ``(N, p_dim)`` profile matrix aligned with those lists.
    """
    id_cols = ['subject_id', 'hadm_id', 'icustay_id']
    note_dim = 64

    # ---- Note embeddings: first note per stay, TF-IDF then PCA ----
    first_notes = notes.sort_values('note_time').groupby(id_cols, as_index=False).first()
    note_emb = np.zeros((len(first_notes), note_dim), dtype=np.float32)
    if len(first_notes) >= 2:
        try:
            X_tfidf = TfidfVectorizer(max_features=5000).fit_transform(
                first_notes['note_text'].fillna('')
            ).astype(np.float32)
        except ValueError:
            # Empty vocabulary (all notes blank or stop words): keep zero embeddings.
            X_tfidf = None
        if X_tfidf is not None:
            # PCA cannot extract more components than samples or features;
            # cap it and leave the remaining dimensions at zero.
            n_comp = min(note_dim, X_tfidf.shape[0], X_tfidf.shape[1])
            pca = PCA(n_components=n_comp, random_state=SEED)
            note_emb[:, :n_comp] = pca.fit_transform(X_tfidf.toarray())

    # ---- Attach each profile row to its note embedding (zeros when no note) ----
    first_notes['_note_row'] = np.arange(len(first_notes))
    prof = profiles.merge(first_notes[id_cols + ['_note_row']], on=id_cols, how='left')
    note_rows = prof.pop('_note_row')
    has_note = note_rows.notna().to_numpy()
    note_mat = np.zeros((len(prof), note_dim), dtype=np.float32)
    note_mat[has_note] = note_emb[note_rows[has_note].astype(int).to_numpy()]

    # ---- Tabular features: one-hot for categoricals, mean-fill for numerics ----
    cat_cols = [c for c in prof.columns if c in ('gender', 'admission_type')]
    num_cols = [c for c in prof.columns if c not in cat_cols and c not in id_cols]
    prof_num = prof[num_cols].apply(lambda col: col.fillna(col.mean()))
    prof_cat = pd.get_dummies(prof[cat_cols], prefix=cat_cols) if cat_cols else pd.DataFrame(index=prof.index)
    tab_mat = pd.concat([prof_num, prof_cat], axis=1).to_numpy(dtype=np.float32)

    # Column-wise standardisation of the full profile matrix.
    P_mat = np.concatenate([tab_mat, note_mat], axis=1)
    if len(P_mat):
        P_mat = (P_mat - np.nanmean(P_mat, axis=0)) / (np.nanstd(P_mat, axis=0) + 1e-6)

    # ---- Pair labels with profiles ----
    lab = labels[id_cols + ['admit_time', 'label_binary']].copy()
    lab['label_binary'] = lab['label_binary'].astype(int)
    merged = lab.merge(prof[id_cols], on=id_cols, how='inner')

    # Key -> row of P_mat (for duplicate profile keys the last row wins).
    key_to_index = {
        key: idx
        for idx, key in enumerate(zip(prof['subject_id'], prof['hadm_id'], prof['icustay_id']))
    }

    # Group the event tables once instead of masking the full table per stay.
    charts_by_key = {k: g for k, g in charts.groupby(id_cols, sort=False)}
    no_charts = charts.iloc[0:0]
    if labs is not None:
        labs_by_key = {k: g for k, g in labs.groupby(id_cols, sort=False)}
        no_labs = labs.iloc[0:0]
    else:
        labs_by_key, no_labs = None, None

    X_list, T_list, M_list, y_list, meta, p_rows = [], [], [], [], [], []
    rows = zip(merged['subject_id'], merged['hadm_id'], merged['icustay_id'],
               merged['admit_time'], merged['label_binary'])
    for subject_id, hadm_id, icustay_id, admit_time, label in rows:
        key = (subject_id, hadm_id, icustay_id)
        charts_user = charts_by_key.get(key, no_charts)
        labs_user = labs_by_key.get(key, no_labs) if labs_by_key is not None else None
        X, T, M = aggregate_signals(charts_user, labs_user, pd.to_datetime(admit_time), hours=L_MAX)
        X_list.append(X)
        T_list.append(T)
        M_list.append(M)
        y_list.append(int(label))
        meta.append(key)
        p_rows.append(key_to_index[key])

    P_out = P_mat[np.asarray(p_rows, dtype=int)]
    return X_list, T_list, M_list, np.array(y_list, dtype=int), meta, P_out

# ===== Datasets =====
class TimeDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-built time-series arrays.

    Each item is ``(X, T, M, y, user_id)`` where ``X``, ``T``, ``M`` and ``y``
    are float32 tensors sliced along the first axis.

    Args:
        X: Signal values, indexable by sample along axis 0.
        T: Time axis per sample.
        M: Observation mask per sample.
        y: Targets, one per sample.
        user_ids: Optional sequence of per-sample identifiers. When omitted,
            the positional index of each sample is used.
    """

    def __init__(self, X, T, M, y, user_ids=None):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.T = torch.tensor(T, dtype=torch.float32)
        self.M = torch.tensor(M, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        # DataLoader's default collate cannot batch None, so a None placeholder
        # per sample would crash on the first batch. Fall back to the sample
        # index, which collates into an integer tensor.
        self.user_ids = user_ids if user_ids is not None else list(range(len(y)))

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.T[idx], self.M[idx], self.y[idx], self.user_ids[idx]

class TimeDatasetWithProfile(TimeDataset):
    """TimeDataset variant that also serves a per-sample profile vector.

    Items are the parent's five fields followed by the float32 profile row
    ``P[idx]``.
    """

    def __init__(self, X, T, M, y, P, user_ids=None):
        super().__init__(X, T, M, y, user_ids)
        self.P = torch.tensor(P, dtype=torch.float32)

    def __getitem__(self, idx):
        base_fields = super().__getitem__(idx)
        profile_row = self.P[idx]
        return (*base_fields, profile_row)

# ===== Model (reusing your existing blocks) =====
# Paste TimeAttentionBlock, MTANBackbone, FiLMGenerator, HeadMLP, ModelPhase1_TS, ModelPhase2_TSProfile and ModelPhase3_FiLM here
# (identical to your original code, except that in_channels=TS_CHANNELS). run_phase() below needs the three ModelPhase* classes to be defined.

# ===== Utilidades =====
def compute_channel_stats(X_train, M_train):
    """Per-channel mean and std over observed entries only.

    ``X_train`` must be 3-D with channels on the last axis; an entry counts as
    observed where ``M_train`` equals 1.0. Channels with no observed entry get
    mean 0.0 and std 1.0. A small epsilon (1e-6) is added to every computed
    std.

    Returns:
        ``(means, stds)``, two float arrays of length ``C``.
    """
    _, _, n_channels = X_train.shape
    means = np.zeros((n_channels,), dtype=float)
    stds = np.ones((n_channels,), dtype=float)
    for ch in range(n_channels):
        observed = X_train[..., ch][M_train[..., ch] == 1.0]
        if observed.size > 0:
            means[ch] = float(np.mean(observed))
            stds[ch] = float(np.std(observed) + 1e-6)
        else:
            means[ch] = 0.0
            stds[ch] = 1.0
    return means, stds

def standardize_by_stats(X, M, means, stds):
    """Z-score ``X`` per channel and zero out unobserved entries.

    ``means`` and ``stds`` are 1-D per-channel arrays broadcast over the two
    leading axes of ``X``. Positions where ``M`` is not 1.0 are set to 0.0.
    """
    mu = means[np.newaxis, np.newaxis, :]
    sigma = stds[np.newaxis, np.newaxis, :]
    scaled = (X - mu) / sigma
    is_observed = M == 1.0
    return np.where(is_observed, scaled, 0.0)

# Entrenamiento/Evaluación (igual que tu código)
def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader`` and report training metrics.

    Batches with five fields are ``(X, T, M, y, user_id)`` and the model is
    called as ``model(X, T, M)``. Batches with six fields carry a trailing
    profile tensor ``P`` and the model is called as ``model(X, T, M, P)``.
    Gradients are clipped to ``GRAD_CLIP`` before each optimiser step.

    Args:
        model: Module to train; moved tensors live on ``DEVICE``.
        loader: Iterable of batches as described above.
        optimizer: Optimiser updating ``model``'s parameters.
        criterion: Loss taking ``(logits, y)``.

    Returns:
        ``(mean_loss, auroc, auprc)``. A metric is NaN when scikit-learn
        rejects the inputs with ``ValueError``.
    """
    model.train()
    total_loss = 0.0
    logits_all, targets_all = [], []
    for batch in loader:
        optimizer.zero_grad()
        if len(batch) == 5:
            X, T, M, y, _ = batch
            profile_args = ()
        else:
            X, T, M, y, _, P = batch
            profile_args = (P.to(DEVICE),)
        X, T, M, y = X.to(DEVICE), T.to(DEVICE), M.to(DEVICE), y.to(DEVICE)
        logits = model(X, T, M, *profile_args)
        loss = criterion(logits, y)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
        optimizer.step()
        # Weight by batch size so the epoch mean is per-sample.
        total_loss += loss.item() * X.size(0)
        logits_all.append(logits.detach().cpu().numpy())
        targets_all.append(y.detach().cpu().numpy())

    yhat = np.concatenate(logits_all)
    yt = np.concatenate(targets_all)

    # Catch only ValueError: a bare except would also hide KeyboardInterrupt
    # and genuine programming errors.
    scores = []
    for metric_fn in (roc_auc_score, average_precision_score):
        try:
            scores.append(metric_fn(yt, yhat))
        except ValueError:
            scores.append(np.nan)
    auroc, auprc = scores
    return total_loss / len(yt), auroc, auprc

@torch.no_grad()
def eval_one_epoch(model, loader, criterion):
    """Evaluate ``model`` over ``loader`` without gradient tracking.

    Batches with five fields are ``(X, T, M, y, user_id)`` and the model is
    called as ``model(X, T, M)``. Batches with six fields carry a trailing
    profile tensor ``P`` and the model is called as ``model(X, T, M, P)``.

    Args:
        model: Module to evaluate; moved tensors live on ``DEVICE``.
        loader: Iterable of batches as described above.
        criterion: Loss taking ``(logits, y)``.

    Returns:
        ``(mean_loss, auroc, auprc)``. A metric is NaN when scikit-learn
        rejects the inputs with ``ValueError``.
    """
    model.eval()
    total_loss = 0.0
    logits_all, targets_all = [], []
    for batch in loader:
        if len(batch) == 5:
            X, T, M, y, _ = batch
            profile_args = ()
        else:
            X, T, M, y, _, P = batch
            profile_args = (P.to(DEVICE),)
        X, T, M, y = X.to(DEVICE), T.to(DEVICE), M.to(DEVICE), y.to(DEVICE)
        logits = model(X, T, M, *profile_args)
        loss = criterion(logits, y)
        # Weight by batch size so the epoch mean is per-sample.
        total_loss += loss.item() * X.size(0)
        logits_all.append(logits.detach().cpu().numpy())
        targets_all.append(y.detach().cpu().numpy())

    yhat = np.concatenate(logits_all)
    yt = np.concatenate(targets_all)

    # Catch only ValueError: a bare except would also hide KeyboardInterrupt
    # and genuine programming errors.
    scores = []
    for metric_fn in (roc_auc_score, average_precision_score):
        try:
            scores.append(metric_fn(yt, yhat))
        except ValueError:
            scores.append(np.nan)
    auroc, auprc = scores
    return total_loss / len(yt), auroc, auprc

def run_phase(phase_name, train_loader, val_loader, in_channels, p_dim=None):
    """Train one model variant and return it with its best validation weights.

    The model is trained for ``EPOCHS`` epochs with AdamW (``LR``,
    ``WEIGHT_DECAY``), cosine learning-rate annealing and
    ``BCEWithLogitsLoss``. After each epoch the validation score is the
    NaN-ignoring mean of AUROC and AUPRC, and the weights of the best-scoring
    epoch are restored before returning.

    Args:
        phase_name: ``'P1'`` (time series only), ``'P2'`` or ``'P3'``
            (both also take the profile vector).
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``eval_one_epoch``.
        in_channels: Number of time-series channels given to the model.
        p_dim: Profile dimensionality; required for ``'P2'`` and ``'P3'``.

    Returns:
        The trained model on ``DEVICE``.

    Raises:
        ValueError: If ``phase_name`` is unknown, or ``p_dim`` is missing for
            a phase that needs it.
    """
    if phase_name == 'P1':
        model = ModelPhase1_TS(in_channels).to(DEVICE)
    elif phase_name in ('P2', 'P3'):
        # Explicit check instead of assert, which is stripped under -O.
        if p_dim is None:
            raise ValueError(f"p_dim is required for phase {phase_name}")
        if phase_name == 'P2':
            model = ModelPhase2_TSProfile(in_channels, p_dim).to(DEVICE)
        else:
            model = ModelPhase3_FiLM(in_channels, p_dim).to(DEVICE)
    else:
        raise ValueError("phase_name must be P1, P2 or P3")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    criterion = nn.BCEWithLogitsLoss()

    best_val = -np.inf
    best_state = None

    for ep in range(1, EPOCHS+1):
        tl, tr_auc, tr_pr = train_one_epoch(model, train_loader, optimizer, criterion)
        vl, va_auc, va_pr = eval_one_epoch(model, val_loader, criterion)
        scheduler.step()
        score = np.nanmean([va_auc, va_pr])
        if score > best_val:
            best_val = score
            # state_dict() returns references to the live parameters, so the
            # tensors must be cloned. Otherwise later epochs overwrite the
            # "best" snapshot and restoring it does nothing.
            best_state = {
                'model': {k: v.detach().clone() for k, v in model.state_dict().items()}
            }
        print(f"{phase_name} | Epoch {ep:02d} | train_loss={tl:.4f} val_loss={vl:.4f} | "
              f"AUROC train/val={tr_auc:.3f}/{va_auc:.3f} | AUPRC train/val={tr_pr:.3f}/{va_pr:.3f}")

    if best_state is not None:
        model.load_state_dict(best_state['model'])
    return model

# ===== Pipeline principal =====
def main_mimic(mimic_root: Path):
    """Run the full experiment: load data, split, train P1/P2/P3, report test metrics.

    A grouped 80/20 hold-out split (grouped by ``hadm_id``) sets aside the
    test set. The remainder goes through 3-fold grouped cross-validation.
    In each fold the channel statistics are computed on the training part
    only, the three phases are trained, and each resulting model is scored on
    the hold-out test set. Per-fold and mean test AUROC/AUPRC are printed.

    NOTE(review): grouping by ``hadm_id`` still lets different admissions of
    the same patient land on both sides of a split; group by ``subject_id``
    if patient-level separation is required.

    Args:
        mimic_root: Directory with the preprocessed CSVs read by
            ``load_mimic_data``.
    """
    charts, labs, labels, profiles, notes = load_mimic_data(mimic_root)

    X_list, T_list, M_list, y, meta, P_vals = build_instances(charts, labs, labels, profiles, notes)
    # Stack per-instance arrays.
    X_all = np.stack(X_list, axis=0)  # (N, L, C=2)
    T_all = np.stack(T_list, axis=0)  # (N, L)
    M_all = np.stack(M_list, axis=0)  # (N, L, C)
    # Real string IDs are passed to every dataset below: leaving user_ids
    # unset yields None placeholders, which DataLoader's default collate
    # cannot batch.
    user_ids = np.array([f"{s}-{h}-{i}" for (s,h,i) in meta])
    in_channels = X_all.shape[-1]
    p_dim = P_vals.shape[1]

    # Hold-out test split grouped by HADM_ID (or SUBJECT_ID if preferred).
    groups = np.array([m[1] for m in meta])  # hadm_id
    gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
    trainval_idx, test_idx = next(gss.split(X_all, y, groups=groups))
    X_tv, T_tv, M_tv, y_tv = X_all[trainval_idx], T_all[trainval_idx], M_all[trainval_idx], y[trainval_idx]
    X_te, T_te, M_te, y_te = X_all[test_idx], T_all[test_idx], M_all[test_idx], y[test_idx]
    P_tv, P_te = P_vals[trainval_idx], P_vals[test_idx]
    ids_tv, ids_te = user_ids[trainval_idx], user_ids[test_idx]
    groups_tv = groups[trainval_idx]

    # Inner grouped cross-validation.
    gkf = GroupKFold(n_splits=3)
    te_results = {'P1': [], 'P2': [], 'P3': []}
    test_criterion = nn.BCEWithLogitsLoss()

    for fold, (tr_idx_rel, va_idx_rel) in enumerate(gkf.split(X_tv, y_tv, groups=groups_tv), start=1):
        print(f"=== Fold {fold} ===")
        X_tr, T_tr, M_tr, y_tr = X_tv[tr_idx_rel], T_tv[tr_idx_rel], M_tv[tr_idx_rel], y_tv[tr_idx_rel]
        X_va, T_va, M_va, y_va = X_tv[va_idx_rel], T_tv[va_idx_rel], M_tv[va_idx_rel], y_tv[va_idx_rel]
        P_tr, P_va = P_tv[tr_idx_rel], P_tv[va_idx_rel]
        ids_tr = ids_tv[tr_idx_rel].tolist()
        ids_va = ids_tv[va_idx_rel].tolist()
        ids_test = ids_te.tolist()

        # Channel statistics come from the training part of this fold only.
        means, stds = compute_channel_stats(X_tr, M_tr)
        X_tr_std = standardize_by_stats(X_tr, M_tr, means, stds)
        X_va_std = standardize_by_stats(X_va, M_va, means, stds)
        X_te_std = standardize_by_stats(X_te, M_te, means, stds)

        ds_tr_p1 = TimeDataset(X_tr_std, T_tr, M_tr, y_tr, user_ids=ids_tr)
        ds_va_p1 = TimeDataset(X_va_std, T_va, M_va, y_va, user_ids=ids_va)
        ds_te_p1 = TimeDataset(X_te_std, T_te, M_te, y_te, user_ids=ids_test)

        ds_tr_p2 = TimeDatasetWithProfile(X_tr_std, T_tr, M_tr, y_tr, P_tr, user_ids=ids_tr)
        ds_va_p2 = TimeDatasetWithProfile(X_va_std, T_va, M_va, y_va, P_va, user_ids=ids_va)
        ds_te_p2 = TimeDatasetWithProfile(X_te_std, T_te, M_te, y_te, P_te, user_ids=ids_test)

        dl_tr_p1 = DataLoader(ds_tr_p1, batch_size=BATCH_TRAIN, shuffle=True)
        dl_va_p1 = DataLoader(ds_va_p1, batch_size=BATCH_VAL, shuffle=False)
        dl_te_p1 = DataLoader(ds_te_p1, batch_size=BATCH_TEST, shuffle=False)

        dl_tr_p2 = DataLoader(ds_tr_p2, batch_size=BATCH_TRAIN, shuffle=True)
        dl_va_p2 = DataLoader(ds_va_p2, batch_size=BATCH_VAL, shuffle=False)
        dl_te_p2 = DataLoader(ds_te_p2, batch_size=BATCH_TEST, shuffle=False)

        # P3 reuses the same loaders as P2.
        dl_tr_p3, dl_va_p3, dl_te_p3 = dl_tr_p2, dl_va_p2, dl_te_p2

        model_p1 = run_phase('P1', dl_tr_p1, dl_va_p1, in_channels)
        model_p2 = run_phase('P2', dl_tr_p2, dl_va_p2, in_channels, p_dim=p_dim)
        model_p3 = run_phase('P3', dl_tr_p3, dl_va_p3, in_channels, p_dim=p_dim)

        # Test
        _, te_auc_p1, te_pr_p1 = eval_one_epoch(model_p1, dl_te_p1, test_criterion)
        _, te_auc_p2, te_pr_p2 = eval_one_epoch(model_p2, dl_te_p2, test_criterion)
        _, te_auc_p3, te_pr_p3 = eval_one_epoch(model_p3, dl_te_p3, test_criterion)
        te_results['P1'].append((te_auc_p1, te_pr_p1))
        te_results['P2'].append((te_auc_p2, te_pr_p2))
        te_results['P3'].append((te_auc_p3, te_pr_p3))
        print(f"Fold {fold} | Test AUROC P1/P2/P3 = {te_auc_p1:.3f}/{te_auc_p2:.3f}/{te_auc_p3:.3f} | "
              f"AUPRC = {te_pr_p1:.3f}/{te_pr_p2:.3f}/{te_pr_p3:.3f}")

    # Final summary
    for p in ['P1','P2','P3']:
        arr = np.array(te_results[p])
        print(f"{p} | Test mean AUROC={np.nanmean(arr[:,0]):.3f} | AUPRC={np.nanmean(arr[:,1]):.3f}")

# Script entry point. The path below is a placeholder: point it at the
# directory holding the preprocessed CSVs that load_mimic_data reads.
if __name__ == "__main__":
    main_mimic(Path("/path/to/mimic_preprocessed"))
