In [5]:
# ===========================
# Temporal Transformer - Train & Save Best Model + CSV Embeddings
# Kaggle-ready (uses GPU if available)
# ===========================
import os, math, random, time, copy
from typing import List
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

# PyTorch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# sklearn
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, accuracy_score

# -----------------------
# Reproducibility
# -----------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# -----------------------
# SETTINGS / HYPERPARAMS
# -----------------------
CSV_PATH = "/kaggle/input/transformer-training-dataset/temporal_vitals_6h_labeled.csv"
OUT_MODEL_PATH = "/kaggle/working/best_temporal_transformer.pt"
OUT_CSV_PATH = "/kaggle/working/temporal_embeddings_predictions.csv"
OUT_SCALER_PATH = "/kaggle/working/scaler.npy"
OUT_IMPUTER_PATH = "/kaggle/working/imputer.npy"

ID_COL = "subject_id"
STAY_COL = "stay_id"
TIME_COL = "hour_bin"
LABEL_COL = "ventilation_within_12h"

BATCH_SIZE = 64
EMBED_DIM = 128
NHEAD = 4
NUM_LAYERS = 3
DROPOUT = 0.2
EPOCHS = 100
PATIENCE = 8
LR = 1e-3
WEIGHT_DECAY = 1e-4
CLIP_NORM = 1.0
NUM_WORKERS = 2
MAX_SEQ_LEN_LIMIT = None  # None -> use dataset max

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# -----------------------
# 1) Load data
# -----------------------
df = pd.read_csv(CSV_PATH)
print("Raw shape:", df.shape)
print("Columns:", df.columns.tolist())

# Ensure required cols exist
for c in (ID_COL, STAY_COL, TIME_COL, LABEL_COL):
    assert c in df.columns, f"Missing required column: {c}"

# -----------------------
# 2) Preprocess features
# -----------------------
# exclude ids/time/label from features
exclude = {ID_COL, "hadm_id", STAY_COL, TIME_COL, LABEL_COL}
feature_cols = [c for c in df.columns if c not in exclude and df[c].dtype != object]
print(f"Using {len(feature_cols)} feature columns.")

# sort by stay + time
df = df.sort_values([STAY_COL, TIME_COL]).reset_index(drop=True)

# Per-stay forward/back-fill then global mean imputing
for col in feature_cols:
    # transform keeps index alignment
    df[col] = df.groupby(STAY_COL)[col].transform(lambda s: s.ffill().bfill())

# global mean imputer for any remaining NaNs (e.g., a whole stay had NaN)
imputer = SimpleImputer(strategy="mean")
df[feature_cols] = imputer.fit_transform(df[feature_cols])

# Standard scaling (fit on full dataset here; in strict evaluation use train-only)
scaler = StandardScaler()
df[feature_cols] = scaler.fit_transform(df[feature_cols])

# Save scaler & imputer (numpy)
np.save(OUT_SCALER_PATH, np.array([scaler.mean_, scaler.scale_], dtype=object), allow_pickle=True)
np.save(OUT_IMPUTER_PATH, np.array([imputer.statistics_], dtype=object), allow_pickle=True)

# Determine max seq length
seq_counts = df.groupby(STAY_COL)[TIME_COL].nunique().values
max_seq_len = int(seq_counts.max()) if MAX_SEQ_LEN_LIMIT is None else min(MAX_SEQ_LEN_LIMIT, int(seq_counts.max()))
print("Max seq length used:", max_seq_len)

# -----------------------
# 3) Train-Val split by stay_id (no leakage)
# -----------------------
stays = df[STAY_COL].unique()
train_stays, val_stays = train_test_split(stays, test_size=0.2, random_state=SEED, shuffle=True)
train_df = df[df[STAY_COL].isin(train_stays)].reset_index(drop=True)
val_df = df[df[STAY_COL].isin(val_stays)].reset_index(drop=True)
print(f"Train stays: {len(train_stays)}, Val stays: {len(val_stays)}")

# -----------------------
# 4) Dataset + collate function (padding + mask)
# -----------------------
class TemporalStayDataset(Dataset):
    def __init__(self, df: pd.DataFrame, stay_col: str, id_col: str, time_col: str, feature_cols: List[str], label_col: str, max_len: int):
        self.feature_cols = feature_cols
        self.label_col = label_col
        self.max_len = max_len
        # group stays
        self.rows = []
        for stay_id, g in df.groupby(stay_col):
            g = g.sort_values(time_col)
            feats = g[feature_cols].to_numpy(dtype=np.float32)
            label = int(g[label_col].iloc[0])
            subject_id = g[id_col].iloc[0]
            self.rows.append((subject_id, stay_id, feats, label))
    def __len__(self):
        return len(self.rows)
    def __getitem__(self, idx):
        return self.rows[idx]

def collate_fn(batch):
    # batch: list of (subject_id, stay_id, feats (T,F), label)
    batch_size = len(batch)
    n_feat = batch[0][2].shape[1]
    seq_lens = [min(x[2].shape[0], max_seq_len) for x in batch]
    T = max(seq_lens)
    X = np.zeros((batch_size, T, n_feat), dtype=np.float32)
    pad_mask = np.ones((batch_size, T), dtype=bool)  # True indicates padding for transformer.mask
    labels = np.zeros((batch_size,), dtype=np.float32)
    subject_ids = []
    stay_ids = []
    for i, (subject_id, stay_id, feats, label) in enumerate(batch):
        L = min(feats.shape[0], T)
        X[i, :L, :] = feats[:L]
        pad_mask[i, :L] = False
        labels[i] = label
        subject_ids.append(subject_id)
        stay_ids.append(stay_id)
    return {
        "X": torch.from_numpy(X),             # (B, T, F)
        "pad_mask": torch.from_numpy(pad_mask), # (B, T) bool
        "y": torch.from_numpy(labels).unsqueeze(1), # (B,1)
        "subject_ids": subject_ids,
        "stay_ids": stay_ids
    }

train_dataset = TemporalStayDataset(train_df, STAY_COL, ID_COL, TIME_COL, feature_cols, LABEL_COL, max_seq_len)
val_dataset = TemporalStayDataset(val_df, STAY_COL, ID_COL, TIME_COL, feature_cols, LABEL_COL, max_seq_len)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=True)

print("Train samples:", len(train_dataset), "Val samples:", len(val_dataset))

# -----------------------
# 5) Model (Transformer Encoder + classification head)
# -----------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model:int, max_len:int=1024):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
    def forward(self, x: torch.Tensor):
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :].to(x.device)

class TemporalTransformerClassifier(nn.Module):
    def __init__(self, in_dim:int, emb_dim:int=EMBED_DIM, nhead:int=NHEAD, nlayers:int=NUM_LAYERS, dropout:float=DROPOUT):
        super().__init__()
        self.input_fc = nn.Linear(in_dim, emb_dim)
        self.pos = PositionalEncoding(emb_dim, max_len=max_seq_len)
        enc_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead, dim_feedforward=emb_dim*4, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=nlayers)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(emb_dim, 1)
    def forward(self, x: torch.Tensor, src_key_padding_mask: torch.Tensor):
        # x: (B, T, F)
        x = self.input_fc(x)              # (B, T, D)
        x = self.pos(x)                   # add positional
        # src_key_padding_mask: (B, T) True for padding positions
        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)  # (B, T, D)
        # mean pool over valid steps
        valid_mask = (~src_key_padding_mask).unsqueeze(-1).to(x.device)  # (B, T, 1)
        sum_hidden = (x * valid_mask).sum(dim=1)                         # (B, D)
        lengths = valid_mask.sum(dim=1).clamp(min=1).to(x.device)        # (B,1)
        pooled = sum_hidden / lengths                                    # (B, D)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)                                 # (B,1)
        return logits, pooled

model = TemporalTransformerClassifier(len(feature_cols), EMBED_DIM, NHEAD, NUM_LAYERS, DROPOUT).to(DEVICE)

# -----------------------
# 6) Loss, optimizer, scheduler
# Use class weighting to handle imbalance
# -----------------------
# compute pos weight for BCEWithLogitsLoss
all_labels = train_df[LABEL_COL].astype(int).values
pos_count = all_labels.sum()
neg_count = len(all_labels) - pos_count
pos_weight = torch.tensor([(neg_count / (pos_count + 1e-9))]).to(DEVICE)  # avoid zero divide
print("Pos weight:", float(pos_weight))

criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

# -----------------------
# 7) Training / Evaluation functions
# -----------------------
def evaluate_model(loader):
    model.eval()
    losses, preds_list, labels_list = [], [], []
    with torch.no_grad():
        for batch in loader:
            X = batch["X"].to(DEVICE)
            pad_mask = batch["pad_mask"].to(DEVICE)
            y = batch["y"].to(DEVICE)
            logits, _ = model(X, src_key_padding_mask=pad_mask)
            loss = criterion(logits, y)
            losses.append(loss.item() * y.size(0))
            probs = torch.sigmoid(logits).cpu().numpy().flatten()
            preds_list.extend(probs.tolist())
            labels_list.extend(y.cpu().numpy().flatten().tolist())
    if len(labels_list) == 0:
        return float("nan"), {}
    loss_epoch = sum(losses) / len(loader.dataset)
    try:
        auc = roc_auc_score(labels_list, preds_list)
    except Exception:
        auc = float("nan")
    aupr = average_precision_score(labels_list, preds_list) if len(set(labels_list))>1 else float("nan")
    preds_binary = [1 if p >= 0.5 else 0 for p in preds_list]
    acc = accuracy_score(labels_list, preds_binary)
    f1 = f1_score(labels_list, preds_binary, zero_division=0)
    metrics = {"loss": loss_epoch, "auc": auc, "aupr": aupr, "acc": acc, "f1": f1}
    return loss_epoch, metrics

# -----------------------
# 8) Training loop with early stopping
# -----------------------
best_val_loss = float("inf")
best_state = None
patience = 0
history = []

for epoch in range(1, EPOCHS+1):
    model.train()
    epoch_loss = 0.0
    n_samples = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=False)
    for batch in pbar:
        X = batch["X"].to(DEVICE)
        pad_mask = batch["pad_mask"].to(DEVICE)
        y = batch["y"].to(DEVICE)
        optimizer.zero_grad()
        logits, _ = model(X, src_key_padding_mask=pad_mask)
        loss = criterion(logits, y)
        loss.backward()
        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
        optimizer.step()
        epoch_loss += loss.item() * y.size(0)
        n_samples += y.size(0)
    train_loss = epoch_loss / n_samples if n_samples>0 else float("nan")
    val_loss, val_metrics = evaluate_model(val_loader)
    scheduler.step(val_loss)
    history.append({"epoch": epoch, "train_loss": train_loss, **{f"val_{k}":v for k,v in val_metrics.items()}})
    print(f"Epoch {epoch:03d} | Train Loss: {train_loss:.4f} | Val Loss: {val_metrics['loss']:.4f} | Val AUROC: {val_metrics['auc']:.4f} | Val AUPR: {val_metrics['aupr']:.4f} | Val Acc: {val_metrics['acc']:.4f} | Val F1: {val_metrics['f1']:.4f}")
    # save best
    if val_metrics["loss"] < best_val_loss - 1e-6:
        best_val_loss = val_metrics["loss"]
        best_state = copy.deepcopy(model.state_dict())
        torch.save(best_state, OUT_MODEL_PATH)
        patience = 0
        print("  -> New best model saved.")
    else:
        patience += 1
        if patience >= PATIENCE:
            print("Early stopping triggered. Patience exceeded.")
            break

# -----------------------
# 9) Load best model and create embeddings + predictions for ALL stays (train+val or entire dataset)
# -----------------------
if best_state is None:
    best_state = model.state_dict()
model.load_state_dict(best_state)
model.eval()

# create a DataLoader for entire dataset (train+val) to produce embeddings; use df grouped by stay
all_dataset = TemporalStayDataset(df, STAY_COL, ID_COL, TIME_COL, feature_cols, LABEL_COL, max_seq_len)
all_loader = DataLoader(all_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=NUM_WORKERS, pin_memory=True)

subject_list, stay_list, label_list, prob_list, emb_list = [], [], [], [], []
with torch.no_grad():
    for batch in tqdm(all_loader, desc="Generating embeddings for all stays"):
        X = batch["X"].to(DEVICE)
        pad_mask = batch["pad_mask"].to(DEVICE)
        logits, emb = model(X, src_key_padding_mask=pad_mask)
        probs = torch.sigmoid(logits).squeeze(1).cpu().numpy()
        emb_np = emb.cpu().numpy()
        subject_list.extend(batch["subject_ids"])
        stay_list.extend(batch["stay_ids"])
        # labels: get from df (first row for stay)
        for sid in batch["stay_ids"]:
            lab = int(df[df[STAY_COL] == sid][LABEL_COL].iloc[0])
            label_list.append(lab)
        prob_list.extend(probs.tolist())
        emb_list.extend(emb_np.tolist())

# Build output DataFrame
emb_dim = len(emb_list[0]) if len(emb_list)>0 else EMBED_DIM
emb_cols = [f"emb_{i}" for i in range(emb_dim)]
out_df = pd.DataFrame(emb_list, columns=emb_cols)
out_df.insert(0, "stay_id", stay_list)
out_df.insert(0, "subject_id", subject_list)
out_df["true_label"] = label_list
out_df["pred_prob"] = prob_list

out_df.to_csv(OUT_CSV_PATH, index=False)
print("Saved embeddings & predictions to:", OUT_CSV_PATH)
print("Output shape:", out_df.shape)

# -----------------------
# 10) Save training history
# -----------------------
hist_df = pd.DataFrame(history)
hist_df.to_csv("/kaggle/working/training_history.csv", index=False)
print("Training history saved to /kaggle/working/training_history.csv")
print("Best model path:", OUT_MODEL_PATH)


Device: cuda
Raw shape: (43080, 23)
Columns: ['subject_id', 'hadm_id', 'stay_id', 'hour_bin', 'mean_value_dbp', 'mean_value_fio2', 'mean_value_gcs_motor', 'mean_value_gcs_total', 'mean_value_heart_rate', 'mean_value_resp_rate', 'mean_value_sbp', 'mean_value_spo2', 'mean_value_temperature', 'std_value_dbp', 'std_value_fio2', 'std_value_gcs_motor', 'std_value_gcs_total', 'std_value_heart_rate', 'std_value_resp_rate', 'std_value_sbp', 'std_value_spo2', 'std_value_temperature', 'ventilation_within_12h']
Using 18 feature columns.
Max seq length used: 6
Train stays: 5744, Val stays: 1436
Train samples: 5744 Val samples: 1436
Pos weight: 6.450064850841666




Epoch 1/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 001 | Train Loss: 0.7895 | Val Loss: 0.9075 | Val AUROC: 0.8412 | Val AUPR: 0.6898 | Val Acc: 0.9060 | Val F1: 0.6650
  -> New best model saved.


  output = torch._nested_tensor_from_mask(


Epoch 2/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 002 | Train Loss: 0.7193 | Val Loss: 0.8659 | Val AUROC: 0.8450 | Val AUPR: 0.7120 | Val Acc: 0.9269 | Val F1: 0.7273
  -> New best model saved.


Epoch 3/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 003 | Train Loss: 0.6580 | Val Loss: 0.7985 | Val AUROC: 0.8508 | Val AUPR: 0.7117 | Val Acc: 0.9241 | Val F1: 0.7241
  -> New best model saved.


Epoch 4/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 004 | Train Loss: 0.6400 | Val Loss: 0.9905 | Val AUROC: 0.8524 | Val AUPR: 0.7086 | Val Acc: 0.9178 | Val F1: 0.7050


Epoch 5/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 005 | Train Loss: 0.6440 | Val Loss: 0.8530 | Val AUROC: 0.8528 | Val AUPR: 0.7005 | Val Acc: 0.9262 | Val F1: 0.7022


Epoch 6/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 006 | Train Loss: 0.6344 | Val Loss: 0.8363 | Val AUROC: 0.8512 | Val AUPR: 0.7393 | Val Acc: 0.9318 | Val F1: 0.7407


Epoch 7/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 007 | Train Loss: 0.6417 | Val Loss: 0.7979 | Val AUROC: 0.8577 | Val AUPR: 0.7273 | Val Acc: 0.9123 | Val F1: 0.6912
  -> New best model saved.


Epoch 8/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 008 | Train Loss: 0.6368 | Val Loss: 0.9394 | Val AUROC: 0.8470 | Val AUPR: 0.7144 | Val Acc: 0.9102 | Val F1: 0.6892


Epoch 9/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 009 | Train Loss: 0.6219 | Val Loss: 0.9058 | Val AUROC: 0.8403 | Val AUPR: 0.6859 | Val Acc: 0.9032 | Val F1: 0.6651


Epoch 10/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 010 | Train Loss: 0.6179 | Val Loss: 0.9194 | Val AUROC: 0.8507 | Val AUPR: 0.7022 | Val Acc: 0.9345 | Val F1: 0.7500


Epoch 11/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 011 | Train Loss: 0.5975 | Val Loss: 0.8598 | Val AUROC: 0.8431 | Val AUPR: 0.6988 | Val Acc: 0.9185 | Val F1: 0.7068


Epoch 12/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 012 | Train Loss: 0.5882 | Val Loss: 0.8794 | Val AUROC: 0.8286 | Val AUPR: 0.7349 | Val Acc: 0.9046 | Val F1: 0.6821


Epoch 13/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 013 | Train Loss: 0.5759 | Val Loss: 0.8315 | Val AUROC: 0.8453 | Val AUPR: 0.7443 | Val Acc: 0.8726 | Val F1: 0.6273


Epoch 14/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 014 | Train Loss: 0.5834 | Val Loss: 0.8382 | Val AUROC: 0.8625 | Val AUPR: 0.7617 | Val Acc: 0.9220 | Val F1: 0.7172


Epoch 15/100:   0%|          | 0/90 [00:00<?, ?it/s]

Epoch 015 | Train Loss: 0.5668 | Val Loss: 0.8421 | Val AUROC: 0.8603 | Val AUPR: 0.7217 | Val Acc: 0.9032 | Val F1: 0.6790
Early stopping triggered. Patience exceeded.


Generating embeddings for all stays:   0%|          | 0/113 [00:00<?, ?it/s]

Saved embeddings & predictions to: /kaggle/working/temporal_embeddings_predictions.csv
Output shape: (7180, 132)
Training history saved to /kaggle/working/training_history.csv
Best model path: /kaggle/working/best_temporal_transformer.pt
