## 모델 학습

In [1]:
from main_model import TSB_MIMIC_IV  # uses diff_models.diff_CSDI internally

In [2]:
# ============================================================
# CSDI diffusion(imputation) + GRU mortality prediction (M1 Air friendly)
# - uses your diff_models.py (diff_CSDI) and main_model.py (TSB_eICU)
# - NO external ffill/median imputation. We only build observed_mask and fill NaN->0 inside observed_data.
# - VALID threshold tuning, TEST fixed threshold
# ============================================================

from pathlib import Path
import os
import sys
import random
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import (
    roc_auc_score, average_precision_score,
    precision_score, recall_score, f1_score
)

# ----------------------------
# 0) seed / device (M1: mps)
# ----------------------------
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_everything(42)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("device:", device)

# ----------------------------
# 1) import your modules
# ----------------------------
# (필요하면) 현재 경로를 PYTHONPATH에 추가
sys.path.append(os.getcwd())



# ----------------------------
# 2) load data
# ----------------------------
train_candidates = [
    Path("../data/29757_train_merged.csv"),
    Path("../data/1000_train_merged.csv"),
    Path("../data/10000_train_merged.csv"),
]
test_candidates = [
    Path("../data/29757_test_merged.csv"),
    Path("../data/1000_test_merged.csv"),
    Path("../data/10000_test_merged.csv"),
]

def resolve_path(candidates: list[Path]) -> Path:
    for p in candidates:
        if p.exists():
            return p
    raise FileNotFoundError(f"No dataset found in: {candidates}")

train_path = resolve_path(train_candidates)
test_path = resolve_path(test_candidates)

train_df = pd.read_csv(train_path)
test_df = pd.read_csv(test_path)

target_col = "died_in_icu"
possible_group_cols = ["patientunitstayid", "patient_id"]
possible_time_cols = ["observationoffset"]

group_col = next((c for c in possible_group_cols if c in train_df.columns), None)
time_col = next((c for c in possible_time_cols if c in train_df.columns), None)
if group_col is None or time_col is None:
    raise ValueError("group/time column not found")

# patient_id leakage guard (if exists)
if "patient_id" in train_df.columns and "patient_id" in test_df.columns:
    overlap = set(train_df["patient_id"].unique()) & set(test_df["patient_id"].unique())
    assert len(overlap) == 0, f"patient_id leakage! overlap={len(overlap)}"

numeric_cols = train_df.select_dtypes(include=["number"]).columns
exclude = {target_col, "patient_id", "patientunitstayid"}
feature_cols = [c for c in numeric_cols if c not in exclude]
K = len(feature_cols)
print("n_features(K):", K)

# ----------------------------
# 3) group-stratified train/valid split (by stay)
# ----------------------------
def group_stratified_split(df: pd.DataFrame, group_col: str, target_col: str, valid_ratio=0.2, seed=42):
    stay_y = df.groupby(group_col)[target_col].max()
    pos = stay_y[stay_y == 1].index.to_list()
    neg = stay_y[stay_y == 0].index.to_list()

    rng = np.random.default_rng(seed)
    rng.shuffle(pos)
    rng.shuffle(neg)

    n_pos_val = int(len(pos) * valid_ratio)
    n_neg_val = int(len(neg) * valid_ratio)

    val_groups = set(pos[:n_pos_val] + neg[:n_neg_val])
    tr_groups = set(stay_y.index) - val_groups

    tr = df[df[group_col].isin(tr_groups)].copy()
    va = df[df[group_col].isin(val_groups)].copy()
    return tr, va

train_part, valid_part = group_stratified_split(train_df, group_col, target_col, valid_ratio=0.2, seed=42)
print("rows train/valid:", train_part.shape, valid_part.shape)

# ----------------------------
# 4) build sequences -> batch dict for TSB_eICU.process_data()
#    observed_data: (B, L, K) with NaN replaced by 0
#    observed_mask: (B, L, K) 1 if observed else 0
#    gt_mask: here set to observed_mask (we'll create cond_mask separately during training)
#    offsets: (B, L) padded time values (not used by model core, but required key)
#    seq_length: (B,)
# ----------------------------
MAX_SEQ_LEN = 128  # M1 friendly
TAKE = "last"      # use latest part

def build_stay_arrays(df: pd.DataFrame, group_col: str, time_col: str, feature_cols: list[str], target_col: str,
                      max_seq_len=128, take="last"):
    df = df.sort_values([group_col, time_col]).copy()
    stays = []

    for sid, g in df.groupby(group_col, sort=False):
        t = g[time_col].to_numpy(dtype=np.float32)
        x = g[feature_cols].to_numpy(dtype=np.float32)  # may contain NaN

        if len(g) > max_seq_len:
            if take == "last":
                t = t[-max_seq_len:]
                x = x[-max_seq_len:]
            else:
                t = t[:max_seq_len]
                x = x[:max_seq_len]

        m = (~np.isnan(x)).astype(np.float32)      # observed mask
        x0 = np.nan_to_num(x, nan=0.0).astype(np.float32)  # NaN -> 0 in observed_data

        y = int(g[target_col].max())
        stays.append((sid, x0, m, t, y, len(t)))

    return stays

train_stays = build_stay_arrays(train_part, group_col, time_col, feature_cols, target_col, MAX_SEQ_LEN, TAKE)
valid_stays = build_stay_arrays(valid_part, group_col, time_col, feature_cols, target_col, MAX_SEQ_LEN, TAKE)
test_stays  = build_stay_arrays(test_df,     group_col, time_col, feature_cols, target_col, MAX_SEQ_LEN, TAKE)

print("n_stays train/valid/test:", len(train_stays), len(valid_stays), len(test_stays))

class StayDataset(Dataset):
    def __init__(self, stays):
        self.stays = stays
    def __len__(self):
        return len(self.stays)
    def __getitem__(self, idx):
        sid, x0, m, t, y, L = self.stays[idx]
        return sid, x0, m, t, y, L

def collate_stays(batch):
    # batch: list of (sid, x0[L,K], m[L,K], t[L], y, L)
    B = len(batch)
    Lmax = max(b[5] for b in batch)
    K = batch[0][1].shape[1]

    patient_id = np.zeros((B,), dtype=np.int64)
    observed_data = np.zeros((B, Lmax, K), dtype=np.float32)
    observed_mask = np.zeros((B, Lmax, K), dtype=np.float32)
    gt_mask = np.zeros((B, Lmax, K), dtype=np.float32)
    offsets = np.zeros((B, Lmax), dtype=np.float32)
    status = np.zeros((B,), dtype=np.int64)
    seq_length = np.zeros((B,), dtype=np.int64)

    for i, (sid, x0, m, t, y, L) in enumerate(batch):
        patient_id[i] = int(sid) if np.isscalar(sid) else i
        observed_data[i, :L] = x0
        observed_mask[i, :L] = m
        gt_mask[i, :L] = m
        offsets[i, :L] = t
        status[i] = int(y)
        seq_length[i] = int(L)

    return {
        "patient_id": torch.from_numpy(patient_id),
        "observed_data": torch.from_numpy(observed_data),
        "observed_mask": torch.from_numpy(observed_mask),
        "gt_mask": torch.from_numpy(gt_mask),
        "status": torch.from_numpy(status),
        "offsets": torch.from_numpy(offsets),
        "seq_length": torch.from_numpy(seq_length),
    }

# ----------------------------
# 5) diffusion model config (M1 friendly)
# ----------------------------
config = {
    "model": {
        "featureemb": 16,
        "target_strategy": "random",  # use get_randmask
    },
    "diffusion": {
        "num_steps": 20,                 # M1 friendly
        "schedule": "linear",
        "beta_start": 1e-4,
        "beta_end": 2e-2,
        "channels": 64,
        "diffusion_embedding_dim": 64,
        "nheads": 4,
        "layers": 4,
    }
}

csdi = TSB_MIMIC_IV(config=config, device=device, target_dim=K).to(device)

# ----------------------------
# 6) patch calc_loss bug (train uses randint, eval uses full)
# ----------------------------
def calc_loss_fixed(self, observed_data, cond_mask, observed_mask, side_info, is_train, seq_length, set_t=-1):
    # observed_data: (B,K,L)
    B, K, L = observed_data.shape
    if is_train == 1:
        t = torch.randint(0, self.num_steps, (B,), device=self.device, dtype=torch.long)
    else:
        t = torch.full((B,), int(set_t), device=self.device, dtype=torch.long)

    current_alpha = self.alpha_torch[t]  # (B,1,1)
    noise = torch.randn_like(observed_data)
    noisy_data = (current_alpha ** 0.5) * observed_data + ((1.0 - current_alpha) ** 0.5) * noise

    total_input = self.set_input_to_diffmodel(noisy_data, observed_data, cond_mask)
    predicted = self.diffmodel(total_input, side_info, t, seq_length)

    target_mask = observed_mask - cond_mask
    residual = (noise - predicted) * target_mask
    num_eval = target_mask.sum()
    loss = (residual ** 2).sum() / (num_eval if num_eval > 0 else 1)
    return loss

import types
csdi.calc_loss = types.MethodType(calc_loss_fixed, csdi)

# ----------------------------
# 7) train diffusion (imputation) model
# ----------------------------
batch_train = 8   # M1 friendly
batch_eval = 16

train_loader = DataLoader(StayDataset(train_stays), batch_size=batch_train, shuffle=True,
                          num_workers=0, collate_fn=collate_stays, drop_last=False)
valid_loader = DataLoader(StayDataset(valid_stays), batch_size=batch_eval, shuffle=False,
                          num_workers=0, collate_fn=collate_stays, drop_last=False)
test_loader  = DataLoader(StayDataset(test_stays),  batch_size=batch_eval, shuffle=False,
                          num_workers=0, collate_fn=collate_stays, drop_last=False)

opt = torch.optim.AdamW(csdi.parameters(), lr=1e-3, weight_decay=1e-5)

@torch.no_grad()
def valid_diffusion_loss(model, loader, t_list=(0, 5, 10, 15)):
    model.eval()
    losses = []
    for batch in loader:
        # to device
        obs_data = batch["observed_data"].to(device).float()   # (B,L,K)
        obs_mask = batch["observed_mask"].to(device).float()
        seq_len  = batch["seq_length"].to(device).long()

        # (B,K,L)
        obs_data = obs_data.permute(0, 2, 1)
        obs_mask = obs_mask.permute(0, 2, 1)

        # random cond_mask even in valid (to evaluate recovery objective consistently)
        cond_mask = model.get_randmask(obs_mask)
        side_info = model.get_side_info(cond_mask)

        # average few fixed t's (cheap)
        l = 0.0
        for tt in t_list:
            l = l + float(model.calc_loss(obs_data, cond_mask, obs_mask, side_info, is_train=0, seq_length=seq_len, set_t=tt).item())
        losses.append(l / len(t_list))
    return float(np.mean(losses)) if losses else float("nan")

EPOCHS_DIFF = 10       # M1 friendly
PATIENCE = 2

best_state = None
best_vloss = float("inf")
pat = 0

for ep in range(1, EPOCHS_DIFF + 1):
    csdi.train()
    tr_losses = []
    for batch in train_loader:
        obs_data = batch["observed_data"].to(device).float()   # (B,L,K)
        obs_mask = batch["observed_mask"].to(device).float()
        seq_len  = batch["seq_length"].to(device).long()

        obs_data = obs_data.permute(0, 2, 1)  # (B,K,L)
        obs_mask = obs_mask.permute(0, 2, 1)

        cond_mask = csdi.get_randmask(obs_mask)
        side_info = csdi.get_side_info(cond_mask)

        loss = csdi.calc_loss(obs_data, cond_mask, obs_mask, side_info, is_train=1, seq_length=seq_len)
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(csdi.parameters(), 1.0)
        opt.step()

        tr_losses.append(loss.item())

    vloss = valid_diffusion_loss(csdi, valid_loader)
    print(f"[DIFF {ep:02d}] train_loss={np.mean(tr_losses):.4f} valid_loss={vloss:.4f}")

    if vloss < best_vloss:
        best_vloss = vloss
        best_state = {k: v.detach().cpu().clone() for k, v in csdi.state_dict().items()}
        pat = 0
    else:
        pat += 1
        if pat >= PATIENCE:
            print("Diffusion early stop.")
            break

assert best_state is not None
csdi.load_state_dict(best_state)
csdi.to(device)
csdi.eval()

# ----------------------------
# 8) GRU classifier on imputed sequences (generated by diffusion)
# ----------------------------
class GRUClassifier(nn.Module):
    def __init__(self, input_dim, hidden=64):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden, batch_first=True)
        self.fc = nn.Linear(hidden, 1)

    def forward(self, x, lengths):
        # x: (B,L,K)
        lengths_sorted, perm = lengths.sort(descending=True)
        x_sorted = x[perm]
        packed = nn.utils.rnn.pack_padded_sequence(x_sorted, lengths_sorted.cpu(), batch_first=True, enforce_sorted=True)
        _, h = self.gru(packed)
        h_last = h[-1]
        _, inv = perm.sort()
        h_last = h_last[inv]
        return self.fc(h_last).squeeze(-1)

@torch.no_grad()
def impute_batch_median(model, batch, n_samples=1):
    # returns imputed_x: (B,L,K), y: (B,), lengths: (B,)
    obs_data = batch["observed_data"].to(device).float()   # (B,L,K)
    obs_mask = batch["observed_mask"].to(device).float()
    y = batch["status"].to(device).float()
    lengths = batch["seq_length"].to(device).long()

    # (B,K,L)
    obs_data_KL = obs_data.permute(0, 2, 1)
    obs_mask_KL = obs_mask.permute(0, 2, 1)

    # condition on all observed values
    cond_mask = obs_mask_KL
    side_info = model.get_side_info(cond_mask)

    samples = model.impute(obs_data_KL, cond_mask, side_info, n_samples=n_samples, seq_length=lengths)  # (B,n,K,L)
    # median across samples -> (B,K,L)
    med = torch.median(samples, dim=1).values
    # fill missing positions with generated, keep observed as original
    imputed_KL = cond_mask * obs_data_KL + (1.0 - cond_mask) * med
    imputed_LK = imputed_KL.permute(0, 2, 1).contiguous()  # (B,L,K)
    return imputed_LK, y, lengths

def tune_threshold_f1(y_true, probs, grid=401):
    best_thr, best_f1 = 0.5, -1.0
    for thr in np.linspace(0, 1, grid):
        f1 = f1_score(y_true, (probs >= thr).astype(int), zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_thr = float(thr)
    return best_thr, best_f1

@torch.no_grad()
def eval_gru(csdi_model, gru_model, loader, n_samples=1):
    csdi_model.eval()
    gru_model.eval()
    probs_all, y_all = [], []
    for batch in loader:
        x_imp, y, lengths = impute_batch_median(csdi_model, batch, n_samples=n_samples)
        logits = gru_model(x_imp, lengths)
        probs = torch.sigmoid(logits).detach().cpu().numpy()
        probs_all.append(probs)
        y_all.append(y.detach().cpu().numpy())
    return np.concatenate(probs_all), np.concatenate(y_all)

def metrics(y_true, probs, thr):
    y_pred = (probs >= thr).astype(int)
    return {
        "AUC": roc_auc_score(y_true, probs) if len(np.unique(y_true)) > 1 else np.nan,
        "AP": average_precision_score(y_true, probs) if len(np.unique(y_true)) > 1 else np.nan,
        "P": precision_score(y_true, y_pred, zero_division=0),
        "R": recall_score(y_true, y_pred, zero_division=0),
        "F1": f1_score(y_true, y_pred, zero_division=0),
    }

gru = GRUClassifier(input_dim=K, hidden=64).to(device)
gru_opt = torch.optim.AdamW(gru.parameters(), lr=1e-3, weight_decay=1e-5)

# pos_weight based on TRAIN stays labels
y_train_stay = np.array([s[4] for s in train_stays], dtype=np.int64)
pos = float(y_train_stay.sum())
neg = float(len(y_train_stay) - y_train_stay.sum())
pos_weight = torch.tensor([neg / max(pos, 1.0)], dtype=torch.float32).to(device)
bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

EPOCHS_GRU = 12      # M1 friendly
PATIENCE_GRU = 2
N_SAMPLES_IMPUTE = 1 # M1 friendly (increase to 3 if you want, but slower)

best_gru = None
best_val_ap = -1.0
pat = 0

for ep in range(1, EPOCHS_GRU + 1):
    gru.train()
    losses = []
    for batch in train_loader:
        # on-the-fly imputation (diffusion frozen)
        x_imp, y, lengths = impute_batch_median(csdi, batch, n_samples=N_SAMPLES_IMPUTE)

        logits = gru(x_imp, lengths)
        loss = bce(logits, y)

        gru_opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(gru.parameters(), 1.0)
        gru_opt.step()

        losses.append(loss.item())

    # VALID evaluate (threshold-free metric: AP)
    va_probs, va_y = eval_gru(csdi, gru, valid_loader, n_samples=N_SAMPLES_IMPUTE)
    va_ap = average_precision_score(va_y, va_probs) if len(np.unique(va_y)) > 1 else np.nan
    va_auc = roc_auc_score(va_y, va_probs) if len(np.unique(va_y)) > 1 else np.nan
    print(f"[GRU {ep:02d}] train_loss={np.mean(losses):.4f} | VALID AUC={va_auc:.4f} AP={va_ap:.4f}")

    if va_ap > best_val_ap:
        best_val_ap = va_ap
        best_gru = {k: v.detach().cpu().clone() for k, v in gru.state_dict().items()}
        pat = 0
    else:
        pat += 1
        if pat >= PATIENCE_GRU:
            print("GRU early stop.")
            break

assert best_gru is not None
gru.load_state_dict(best_gru)
gru.to(device)
gru.eval()

# ----------------------------
# 9) VALID threshold tuning -> TEST fixed
# ----------------------------
va_probs, va_y = eval_gru(csdi, gru, valid_loader, n_samples=N_SAMPLES_IMPUTE)
thr, _ = tune_threshold_f1(va_y, va_probs, grid=401)
print("\n[VALID] tuned thr:", thr)
print("[VALID]", metrics(va_y, va_probs, thr))

te_probs, te_y = eval_gru(csdi, gru, test_loader, n_samples=N_SAMPLES_IMPUTE)
print("\n[TEST] thr(from VALID):", thr)
print("[TEST]", metrics(te_y, te_probs, thr))

device: mps
n_features(K): 37
rows train/valid: (3903369, 40) (965199, 40)
n_stays train/valid/test: 23595 5898 7355




RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)