# ED-AI Modular Triage Pipeline
This notebook demonstrates a fully modular pipeline for Emergency Department triage prediction using Temporal Fusion Transformer (TFT) and XGBoost.
We use a baseline of triage, static vitals, demographics, and BERT embeddings, then optionally add diagnosis, home medications, and medication administrations as pluggable modules.
You can run the baseline immediately, and enrich with new features by toggling simple flags — no preprocessing rewrites required.


In [1]:
%pip install pandas numpy scikit-learn xgboost pytorch-forecasting torch transformers mlflow joblib matplotlib seaborn tqdm


Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.0 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import classification_report, roc_auc_score, average_precision_score, brier_score_loss
from sklearn.impute import KNNImputer
import xgboost as xgb
import lightgbm as lgb
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.metrics import QuantileLoss
import torch
from transformers import AutoTokenizer, AutoModel
import mlflow
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')


In [3]:
# Which modules to include
use_diagnosis = False
use_medhome = False
use_pyxis = False


In [8]:
# Load core datasets
triage_df = pd.read_csv('data/diagnosis.csv')
vitals_df = pd.read_csv('data/vitalsign.csv')
print(f'Triage records: {len(triage_df)}')
print(f'Vitals records: {len(vitals_df)}')


Triage records: 899050
Vitals records: 1564610


In [17]:
def preprocess_core(triage, vitals):
    # Attach charttime to triage
    if 'intime' in triage.columns:
        triage['charttime'] = pd.to_datetime(triage['intime'])
    else:
        triage['charttime'] = pd.to_datetime(triage['charttime'])
    vitals['charttime'] = pd.to_datetime(vitals['charttime'])
    # Merge nearest vitals (within 5min)
    triage = triage.sort_values('charttime')
    vitals = vitals.sort_values('charttime')
    merged = pd.merge_asof(triage, vitals, on='charttime', by='subject_id', direction='nearest', tolerance=pd.Timedelta('5min'))
    # Impute missing vitals
    numeric_cols = ['temperature', 'heart_rate', 'blood_pressure_systolic', 'blood_pressure_diastolic', 'respiratory_rate', 'oxygen_saturation', 'pain_score']
    imputer = KNNImputer(n_neighbors=5)
    merged[numeric_cols] = imputer.fit_transform(merged[numeric_cols])
    return merged


In [10]:
def get_bert_embeddings(texts):
    tokenizer = AutoTokenizer.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
    model = AutoModel.from_pretrained('emilyalsentzer/Bio_ClinicalBERT')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()
    embeddings = []
    for i in tqdm(range(0, len(texts), 8), desc='BERT embedding'):
        batch_texts = texts[i:i+8]
        encodings = tokenizer(batch_texts, padding=True, truncation=True, max_length=128, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = model(**encodings)
            embeddings.append(outputs.last_hidden_state[:,0,:].cpu().numpy())
    return np.vstack(embeddings)


In [11]:
def diagnosis_features(stay_ids):
    # load and process diagnosis.csv, embed ICD codes if use_diagnosis else return None
    if not use_diagnosis: return None
    diagnosis = pd.read_csv('../data/diagnosis.csv')
    # Example: mean ICD code embedding per stay_id (use your embedding logic)
    icd_map = {} # load your ICD embedding model
    stay_features = []
    for sid in stay_ids:
        codes = diagnosis[diagnosis['stay_id']==sid]['icd_code'].tolist()
        vec = np.mean([icd_map.get(c, np.zeros(128)) for c in codes], axis=0)
        stay_features.append(vec)
    return np.vstack(stay_features)

def medhome_features(stay_ids):
    if not use_medhome: return None
    medhome = pd.read_csv('../data/medrecon.csv')
    # Example: average embedding per stay_id
    med_map = {} # load med embedding model
    stay_features = []
    for sid in stay_ids:
        meds = medhome[medhome['stay_id']==sid]['med'].tolist()
        vec = np.mean([med_map.get(m, np.zeros(128)) for m in meds], axis=0)
        stay_features.append(vec)
    return np.vstack(stay_features)

def pyxis_features(pyxis_df, stay_ids, charttimes):
    if not use_pyxis: return None
    # Example: time-varying med administration
    pyxis = pd.read_csv('../data/pyxis.csv')
    # Map meds by time
    features = []
    for sid, ct in zip(stay_ids, charttimes):
        meds = pyxis[(pyxis['stay_id']==sid) & (pyxis['charttime']==ct)]['med'].tolist()
        # map to category IDs or embed as needed
        features.append(meds)
    return features


In [12]:
def build_tft_dataset(core_df, bert_vecs, diagnosis_vecs=None, medhome_vecs=None, pyxis_seq=None):
    # time-varying features: vitals
    # static covariates: age, gender, chief complaint embed (BERT), optional diagnosis and medhome
    static_features = ['age', 'gender']
    static_vec = np.concatenate([bert_vecs], axis=1)
    if diagnosis_vecs is not None:
        static_vec = np.concatenate([static_vec, diagnosis_vecs], axis=1)
    if medhome_vecs is not None:
        static_vec = np.concatenate([static_vec, medhome_vecs], axis=1)
    # Assemble TFT TimeSeriesDataSet (see pytorch_forecasting docs)
    # Placeholder: implement as per your use-case
    print(f'Static feature vector shape: {static_vec.shape}')
    # For demonstration, return the core_df and static_vec
    return core_df, static_vec

def build_xgb_dataset(core_df, bert_vecs, diagnosis_vecs=None, medhome_vecs=None):
    static_features = ['age', 'temperature', 'heart_rate', 'respiratory_rate', 'oxygen_saturation', 'blood_pressure_systolic', 'blood_pressure_diastolic', 'pain_score']
    X = core_df[static_features]
    X = pd.concat([X, pd.DataFrame(bert_vecs)], axis=1)
    if diagnosis_vecs is not None:
        X = pd.concat([X, pd.DataFrame(diagnosis_vecs)], axis=1)
    if medhome_vecs is not None:
        X = pd.concat([X, pd.DataFrame(medhome_vecs)], axis=1)
    return X


In [None]:
# Full updated cell: Merge triage with edstays, preprocess, embed, optional modules, build datasets with fix for missing vitals columns

# Load ED stays dataset (needed for 'intime' → 'charttime')
edstays_df = pd.read_csv('data/edstays.csv')
edstays_df['intime'] = pd.to_datetime(edstays_df['intime'])

# Merge triage_df with edstays_df ONCE to get 'intime' for charttime; handle suffixes to avoid conflicts
triage_df = triage_df.merge(
    edstays_df[['stay_id', 'subject_id', 'intime']],
    on=['stay_id', 'subject_id'],
    how='left',
    suffixes=('', '_edstays')
)

# Assign 'charttime' from whichever 'intime' column is present after merge
if 'intime' in triage_df.columns:
    triage_df['charttime'] = pd.to_datetime(triage_df['intime'])
elif 'intime_edstays' in triage_df.columns:
    triage_df['charttime'] = pd.to_datetime(triage_df['intime_edstays'])
else:
    raise KeyError("No 'intime' column found in merged triage_df")

def preprocess_core(triage, vitals):
    # Ensure datetime conversion
    triage['charttime'] = pd.to_datetime(triage['charttime'])
    vitals['charttime'] = pd.to_datetime(vitals['charttime'])

    # Sort by 'charttime' only for merge_asof
    triage = triage.sort_values('charttime')
    vitals = vitals.sort_values('charttime')

    # Merge asof nearest within 5 minutes and by subject_id
    merged = pd.merge_asof(
        triage, vitals,
        on='charttime',
        by='subject_id',
        direction='nearest',
        tolerance=pd.Timedelta('5min')
    )
    
    # Define vitals cols to impute; use those present in merged columns only
    numeric_cols = [
        'temperature', 'heart_rate', 'respiratory_rate', 'oxygen_saturation',
        'blood_pressure_systolic', 'blood_pressure_diastolic', 'pain_score'
    ]
    present_numeric_cols = [c for c in numeric_cols if c in merged.columns]
    
    if present_numeric_cols:
        imputer = KNNImputer(n_neighbors=5)
        merged[present_numeric_cols] = imputer.fit_transform(merged[present_numeric_cols])
    else:
        print("⚠️ Warning: No numeric vitals columns found for imputation.")
    
    return merged

# Run preprocessing core function with fix for missing vitals columns
core = preprocess_core(triage_df, vitals_df)

# Generate ClinicalBERT embeddings from chief complaint text safely
if 'chief_complaint' in core.columns:
    bert_vecs = get_bert_embeddings(core['chief_complaint'].fillna('').astype(str).values)
else:
    print("⚠️ Warning: 'chief_complaint' column not found; using zeros for embeddings.")
    bert_vecs = np.zeros((len(core), 768))  # fallback shape for ClinicalBERT embeddings

# Optionally run feature modules
diagnosis_vecs = diagnosis_features(core['stay_id']) if use_diagnosis else None
medhome_vecs = medhome_features(core['stay_id']) if use_medhome else None

if use_pyxis:
    pyxis_df = pd.read_csv('data/pyxis.csv')
    pyxis_seq = pyxis_features(pyxis_df, core['stay_id'], core['charttime'])
else:
    pyxis_seq = None

# Build TFT and XGBoost datasets
core_tft, static_vec = build_tft_dataset(core, bert_vecs, diagnosis_vecs, medhome_vecs, pyxis_seq)
X_xgb = build_xgb_dataset(core, bert_vecs, diagnosis_vecs, medhome_vecs)

# Create binary target variable safely; check for 'acuity' presence
if 'acuity' in core.columns:
    y = (core['acuity'] <= 2).astype(int)
else:
    print("⚠️ Warning: 'acuity' column not found; generating dummy target variable.")
    y = pd.Series(np.zeros(len(core), dtype=int))

print(f"Processed dataset shape: {core.shape}")
print(f"Target distribution:\n{y.value_counts()}")
