In [1]:

# --- Imports & Config ---
import os, re, gc, json
import numpy as np
import pandas as pd
from pathlib import Path
from joblib import dump, load

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    roc_auc_score, average_precision_score, accuracy_score, f1_score,
    classification_report, confusion_matrix, precision_recall_curve
)
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.utils.class_weight import compute_class_weight

# Path to the MIMIC-III CSV directory (demo/mirror with lowercase headers).
# Set the MIMIC_DATA_DIR environment variable to point at your own copy; the
# literal below is only the fallback used when the variable is not defined.
DATA_DIR = os.environ.get(
    "MIMIC_DATA_DIR",
    "/home/sraja/.cache/kagglehub/datasets/asjad99/mimiciii/versions/1/mimic-iii-clinical-database-demo-1.4",
)

# Core settings
HOURS_WINDOW = 6          # upper bound (hours after ICU intime) of the observation window
CHUNK_SIZE   = 1_000_000  # rows per CHARTEVENTS chunk; tune for your RAM

# Optional feature blocks
ENABLE_LABS = False   # set True to add LABEVENTS aggregates
ENABLE_GCS  = False   # set True to add GCS aggregates from CHARTEVENTS via D_ITEMS

# Known ITEMIDs (CareVue & MetaVision) for vitals; will be extended via D_ITEMS if present.
# Temperature readings are unit-normalised later from `valueuom`, so this list may
# mix Celsius and Fahrenheit items.
KNOWN_ITEM_MAP = {
    "hr":       [211, 220045],
    "sysbp":    [51, 220179],
    "diabp":    [8368, 220180],
    "meanbp":   [52, 220181],
    "resprate": [618, 220210],
    "spo2":     [220277],
    "tempc":    [676, 223761],
}

def ensure_dir(p):
    """Create directory `p`, including missing parents; a no-op if it already exists."""
    target = Path(p)
    target.mkdir(parents=True, exist_ok=True)
# The model pack is written under ./models later in the notebook.
ensure_dir("models")


In [2]:
# --- Helper functions ---
def normalize_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Lower-case every column label of `df` in place and return that same frame."""
    lowered = []
    for name in df.columns:
        lowered.append(name.lower())
    df.columns = lowered
    return df

def safe_read_csv(data_dir, name, usecols=None, dtype=None):
    """Read `name` from `data_dir` with lower-cased column names.

    Parameters
    ----------
    data_dir, name : path components joined with ``os.path.join``.
    usecols : optional iterable of column names, matched case-insensitively.
        Requested columns that are absent from the file are silently skipped.
        When empty/None every column is returned.
    dtype : forwarded unchanged to ``pandas.read_csv``.

    Returns
    -------
    DataFrame whose columns are lower-case; when `usecols` is given the columns
    appear in the order requested.
    """
    path = os.path.join(data_dir, name)
    wanted = [u.lower() for u in usecols] if usecols else []
    if wanted:
        wanted_set = set(wanted)
        # Select columns inside the parser so unused (often wide, free-text)
        # columns are never materialised in memory.
        df = pd.read_csv(
            path,
            dtype=dtype,
            low_memory=False,
            usecols=lambda col: str(col).lower() in wanted_set,
        )
    else:
        df = pd.read_csv(path, dtype=dtype, low_memory=False)
    df = normalize_cols(df)
    if wanted:
        # The parser returns file order; restore the caller's requested order.
        df = df[[c for c in wanted if c in df.columns]]
    return df

def is_stroke_icd9(code: str) -> bool:
    """Return True when `code` is an ICD-9 diagnosis code this notebook labels as stroke.

    Matches 430/431/432 (any sub-code) and 433/434 codes whose last digit is 1.
    Dots and other separators are ignored ("434.91" == "43491").

    Codes that do not start with a digit (V-codes, E-codes, empty/NaN values)
    are rejected up front: stripping the letter would otherwise turn e.g.
    "V431" into "431" and mislabel the admission as a stroke.
    """
    if code is None:
        return False
    raw = str(code).strip()
    if not raw or not raw[0].isdigit():
        return False
    c = re.sub(r"[^0-9]", "", raw)
    if c.startswith(("430", "431", "432")):  # hemorrhagic incl. SAH
        return True
    if c.startswith(("433", "434")):
        return c.endswith("1")              # infarction present
    return False

def build_item_map(d_items_df: pd.DataFrame | None, base_map: dict) -> dict:
    item_map = {k: set(v) for k, v in base_map.items()}
    if d_items_df is None or "label" not in d_items_df.columns:
        return {k: sorted(v) for k, v in item_map.items()}
    di = d_items_df.copy()
    di["label_l"] = di["label"].str.lower()

    ITEM_KEYWORDS = {
        "hr":       ["heart rate"],
        "sysbp":    ["non invasive systolic", "systolic blood pressure"],
        "diabp":    ["non invasive diastolic", "diastolic blood pressure"],
        "meanbp":   ["non invasive mean", "mean blood pressure"],
        "resprate": ["respiratory rate"],
        "spo2":     ["spo2", "oxygen saturation"],
        "tempc":    ["temperature c", "temperature celsius", "temperature f"],
    }
    for var, kws in ITEM_KEYWORDS.items():
        for kw in kws:
            m = di[di["label_l"].str.contains(kw, na=False)]
            for iid in m["itemid"].tolist():
                item_map[var].add(int(iid))
    return {k: sorted(v) for k, v in item_map.items()}

def normalize_temp_to_c(values, uoms):
    """Return a copy of `values` with Fahrenheit readings converted to Celsius.

    A reading is treated as Fahrenheit when the string form of its unit in
    `uoms` contains the letter "f" (case-insensitive). When `uoms` is None the
    copy is returned unchanged. The input Series is never modified.
    """
    converted = values.copy()
    if uoms is not None:
        is_fahrenheit = uoms.astype(str).str.contains("f", case=False, na=False)
        converted.loc[is_fahrenheit] = (converted.loc[is_fahrenheit] - 32.0) * (5.0/9.0)
    return converted

def aggregate_first6h_vitals(vitals_chunk, icu_df, item_map, hours_window=None):
    """Aggregate early-stay vitals into one wide row per admission.

    Parameters
    ----------
    vitals_chunk : DataFrame of chart events with at least ``icustay_id``,
        ``charttime``, ``itemid`` and ``valuenum``; ``hadm_id`` and ``valueuom``
        are used when present.
    icu_df : DataFrame with ``icustay_id``, ``hadm_id`` and ``intime``.
    item_map : dict mapping a vital name to an iterable of ITEMIDs.
    hours_window : length of the observation window in hours, counted from the
        ICU ``intime`` (inclusive on both ends). Defaults to the notebook-level
        ``HOURS_WINDOW``.

    Returns
    -------
    DataFrame with ``hadm_id`` plus ``{stat}_{vital}`` columns for
    stat in mean/min/max/std/last/slope. ``std`` is the population standard
    deviation (0.0 for a single reading); ``slope`` is (last - first) value per
    hour (0.0 when it cannot be computed). When no row survives filtering, an
    empty frame with only a ``hadm_id`` column is returned.
    """
    if hours_window is None:
        hours_window = HOURS_WINDOW

    # With suffixes=("", "_icu") the left-hand hadm_id keeps its name; when the
    # chunk has no hadm_id the ICU one is merged in un-suffixed. Either way a
    # plain "hadm_id" column exists afterwards.
    merged = vitals_chunk.merge(
        icu_df[["icustay_id","hadm_id","intime"]],
        on="icustay_id",
        how="inner",
        suffixes=("", "_icu")
    )

    merged["charttime"] = pd.to_datetime(merged["charttime"])
    merged["intime"]    = pd.to_datetime(merged["intime"])
    merged["hours_since_intime"] = (merged["charttime"] - merged["intime"]).dt.total_seconds()/3600.0
    merged = merged[(merged["hours_since_intime"] >= 0) & (merged["hours_since_intime"] <= hours_window)]

    # Map ITEMID → item
    inv = {int(iid): name for name, ids in item_map.items() for iid in ids}
    merged["item"] = merged["itemid"].map(inv)

    merged = merged.dropna(subset=["item","valuenum"]).copy()
    merged["valuenum"] = pd.to_numeric(merged["valuenum"], errors="coerce")
    merged = merged.dropna(subset=["valuenum"])

    if "valueuom" in merged.columns:
        tmask = merged["item"].eq("tempc")
        if tmask.any():
            merged.loc[tmask, "valuenum"] = normalize_temp_to_c(merged.loc[tmask, "valuenum"], merged.loc[tmask, "valueuom"])

    def agg_one(df):
        """Summary statistics for one (admission, vital) group, in time order."""
        df = df.sort_values("charttime")
        vals = df["valuenum"].values.astype(float)
        times = df["hours_since_intime"].values
        res = {
            "mean": float(np.mean(vals)),
            "min":  float(np.min(vals)),
            "max":  float(np.max(vals)),
            "std":  float(np.std(vals)) if vals.size > 1 else 0.0,
            "last": float(vals[-1]),
            "slope": 0.0
        }
        if vals.size >= 2 and (times[-1]-times[0]) > 0:
            res["slope"] = float((vals[-1] - vals[0]) / (times[-1] - times[0]))
        return pd.Series(res)

    if merged.empty:
        return pd.DataFrame(columns=["hadm_id"])

    # Hand agg_one only the columns it reads. Applying over the whole frame
    # (grouping columns included) is deprecated in recent pandas and emits a
    # warning on every call.
    stat_inputs = ["charttime", "valuenum", "hours_since_intime"]
    agg = merged.groupby(["hadm_id","item"])[stat_inputs].apply(agg_one)
    wide = agg.unstack("item")
    wide.columns = [f"{stat}_{item}" for stat, item in wide.columns]
    wide = wide.reset_index()
    return wide

In [3]:

# --- Load base tables (lowercase columns) ---
print("Loading ADMISSIONS, ICUSTAYS, DIAGNOSES_ICD, D_ITEMS ...")
adm  = safe_read_csv(DATA_DIR, "ADMISSIONS.csv",   usecols=["subject_id","hadm_id"])
icu  = safe_read_csv(DATA_DIR, "ICUSTAYS.csv",     usecols=["icustay_id","hadm_id","intime"])
diag = safe_read_csv(DATA_DIR, "DIAGNOSES_ICD.csv",usecols=["subject_id","hadm_id","icd9_code"])

# optional D_ITEMS: a missing/unreadable file is tolerated (the known ITEMIDs are
# used alone), but the reason is reported and unrelated errors are not hidden.
try:
    d_items = safe_read_csv(DATA_DIR, "D_ITEMS.csv", usecols=["itemid","label","dbsource"])
except (OSError, pd.errors.EmptyDataError, pd.errors.ParserError) as exc:
    print(f"D_ITEMS.csv unavailable ({exc!r}) - using KNOWN_ITEM_MAP only.")
    d_items = None

# stroke labels by hadm_id: an admission is positive if ANY of its codes matches
diag["label"] = diag["icd9_code"].apply(is_stroke_icd9)
labels = diag.groupby("hadm_id", as_index=False)["label"].max()
print("Labels:", labels["label"].value_counts(dropna=False).to_dict())

# Build item map (extend known ids using D_ITEMS labels if available)
item_map = build_item_map(d_items, KNOWN_ITEM_MAP)
all_itemids = sorted({iid for ids in item_map.values() for iid in ids})
print("Item map (first few):", {k: (list(v)[:5] + (['...'] if len(v) > 5 else [])) for k, v in item_map.items()})


Loading ADMISSIONS, ICUSTAYS, DIAGNOSES_ICD, D_ITEMS ...
Labels: {False: 123, True: 6}
Item map (first few): {'hr': [211, 3494, 220045, 220046, 220047], 'sysbp': [51, 220179], 'diabp': [8368, 220180], 'meanbp': [52, 220181], 'resprate': [618, 619, 220210, 224688, 224689, '...'], 'spo2': [646, 5820, 6719, 8554, 220277, '...'], 'tempc': [676, 677, 678, 679, 223761, '...']}


In [4]:
# --- Extract vitals from CHARTEVENTS (chunked) ---
# Each chunk is reduced to the rows that matter (selected ITEMIDs, inside the
# observation window) and the survivors are aggregated ONCE at the end.
# Aggregating per chunk and de-duplicating afterwards would throw away part of
# the readings of any admission whose rows straddle a chunk boundary.
chartevents_path = os.path.join(DATA_DIR, "CHARTEVENTS.csv")
usecols = ["row_id","subject_id","hadm_id","icustay_id","charttime","itemid","valuenum","valueuom"]

# ICU admission time per stay, used only to pre-filter chunks to the window.
intime_by_stay = pd.to_datetime(
    icu.drop_duplicates("icustay_id").set_index("icustay_id")["intime"]
)

window_rows = []
print("Scanning CHARTEVENTS.csv in chunks ...")
chunk_iter = pd.read_csv(chartevents_path, usecols=usecols, chunksize=CHUNK_SIZE,
                         dtype={"itemid":"int32","icustay_id":"float64"}, low_memory=True)
for i, chunk in enumerate(chunk_iter, 1):
    chunk = normalize_cols(chunk)
    chunk = chunk[chunk["itemid"].isin(all_itemids)]
    chunk = chunk.dropna(subset=["icustay_id","hadm_id"])
    if not chunk.empty:
        chunk = chunk.assign(icustay_id=chunk["icustay_id"].astype(int))
        # Stays unknown to ICUSTAYS map to NaT -> NaN hours -> dropped below,
        # matching the inner join performed inside aggregate_first6h_vitals.
        hours = (
            pd.to_datetime(chunk["charttime"]) - chunk["icustay_id"].map(intime_by_stay)
        ).dt.total_seconds() / 3600.0
        chunk = chunk[(hours >= 0) & (hours <= HOURS_WINDOW)]
        if not chunk.empty:
            window_rows.append(chunk)

    del chunk
    if i % 5 == 0:
        print(f" processed {i} chunks ...")
        gc.collect()

if not window_rows:
    raise RuntimeError("No vitals found for selected ITEMIDs; check DATA_DIR and item_map.")
features_wide = aggregate_first6h_vitals(pd.concat(window_rows, ignore_index=True), icu, item_map)
if features_wide.empty:
    raise RuntimeError("No vitals found for selected ITEMIDs; check DATA_DIR and item_map.")
features_wide = features_wide.sort_values("hadm_id").reset_index(drop=True)
print("Vitals wide shape:", features_wide.shape)


Scanning CHARTEVENTS.csv in chunks ...


  for i, chunk in enumerate(chunk_iter, 1):
  agg = merged.groupby(["hadm_id","item"], as_index=True).apply(agg_one)


Vitals wide shape: (125, 43)


In [5]:
# --- (Optional) Add LABEVENTS aggregates in first 6h ---
if ENABLE_LABS:
    labs_path = os.path.join(DATA_DIR, "LABEVENTS.csv")
    if os.path.exists(labs_path):
        # Reuse the shared reader instead of a second, near-identical helper.
        labs = safe_read_csv(DATA_DIR, "LABEVENTS.csv",
                             usecols=["subject_id","hadm_id","itemid","charttime","valuenum","valueuom"])
        labs = labs.dropna(subset=["hadm_id","valuenum"]).copy()
        # Join ICU intime (by hadm_id) to compute window. An admission can have
        # several ICU stays; anchor the window on the EARLIEST one rather than
        # on whichever row happens to come first in ICUSTAYS.
        icu_first = (
            icu[["hadm_id","intime"]]
            .assign(intime=lambda d: pd.to_datetime(d["intime"]))
            .sort_values("intime")
            .drop_duplicates("hadm_id", keep="first")
        )
        labs = labs.merge(icu_first, on="hadm_id", how="left")
        labs["charttime"] = pd.to_datetime(labs["charttime"])
        labs["h"] = (labs["charttime"] - labs["intime"]).dt.total_seconds()/3600.0
        labs = labs[(labs["h"]>=0) & (labs["h"]<=HOURS_WINDOW)].copy()
        labs["valuenum"] = pd.to_numeric(labs["valuenum"], errors="coerce")
        labs = labs.dropna(subset=["valuenum"])

        def agg_lab(df):
            """Summary statistics for one (admission, lab item) group, in time order."""
            df = df.sort_values("charttime")
            v = df["valuenum"].values
            t = df["h"].values
            out = {
                "lab_mean": float(np.mean(v)),
                "lab_min":  float(np.min(v)),
                "lab_max":  float(np.max(v)),
                "lab_std":  float(np.std(v)) if v.size > 1 else 0.0,
                "lab_last": float(v[-1]),
                "lab_slope": 0.0
            }
            if v.size >= 2 and (t[-1]-t[0]) > 0:
                out["lab_slope"] = float((v[-1]-v[0])/(t[-1]-t[0]))
            return pd.Series(out)

        if labs.empty:
            print("No LABEVENTS rows inside the window — skipping.")
        else:
            # Pass only the columns agg_lab reads; applying over the grouping
            # columns is deprecated in recent pandas.
            lab_wide = (
                labs.groupby(["hadm_id","itemid"])[["charttime","valuenum","h"]]
                .apply(agg_lab)
                .unstack("itemid")
            )
            lab_wide.columns = [f"{stat}_lab_{itemid}" for stat, itemid in lab_wide.columns]
            lab_wide = lab_wide.reset_index()
            # Drop lab columns left by a previous run of this cell so re-running
            # it does not produce _x/_y duplicates.
            stale = [c for c in lab_wide.columns if c != "hadm_id" and c in features_wide.columns]
            features_wide = features_wide.drop(columns=stale).merge(lab_wide, on="hadm_id", how="left")
            print("After labs →", features_wide.shape)
    else:
        print("LABEVENTS.csv not found — skipping.")


In [6]:
# --- (Optional) Add GCS features from CHARTEVENTS via D_ITEMS ---
if ENABLE_GCS:
    d_items_path = os.path.join(DATA_DIR, "D_ITEMS.csv")
    if os.path.exists(d_items_path):
        di = pd.read_csv(d_items_path, low_memory=False)
        di = normalize_cols(di)
        di["label_l"] = di["label"].str.lower()

        def find_ids(kws):
            """Sorted ITEMIDs whose lower-cased label contains any of the literal keywords."""
            m = di[di["label_l"].str.contains("|".join([re.escape(k) for k in kws]), na=False)]
            return sorted(set(m["itemid"].astype(int).tolist()))

        ids_gcs_total = find_ids(["glasgow coma scale total","gcs total","gcs - total"])
        ids_gcs_eye   = find_ids(["glasgow coma scale eye","gcs eye","gcs - eye opening"])
        ids_gcs_verbal= find_ids(["glasgow coma scale verbal","gcs verbal","gcs - verbal response"])
        ids_gcs_motor = find_ids(["glasgow coma scale motor","gcs motor","gcs - motor response"])

        gcs_ids_all = set(ids_gcs_total + ids_gcs_eye + ids_gcs_verbal + ids_gcs_motor)
        usecols_ce = ["subject_id","hadm_id","icustay_id","charttime","itemid","valuenum"]
        gcs_parts = []

        chunk_iter = pd.read_csv(chartevents_path, usecols=usecols_ce, chunksize=CHUNK_SIZE, low_memory=True)
        for ch in chunk_iter:
            ch = normalize_cols(ch)
            ch = ch[ch["itemid"].isin(gcs_ids_all)]
            if ch.empty: continue
            # Bring in ONLY intime: the chunk already carries hadm_id, and merging
            # a second hadm_id would rename both to hadm_id_x / hadm_id_y and break
            # the groupby on "hadm_id" below.
            ch = ch.merge(icu[["icustay_id","intime"]], on="icustay_id", how="inner")
            ch["charttime"] = pd.to_datetime(ch["charttime"])
            ch["intime"]    = pd.to_datetime(ch["intime"])
            ch["h"] = (ch["charttime"] - ch["intime"]).dt.total_seconds()/3600.0
            ch = ch[(ch["h"]>=0) & (ch["h"]<=HOURS_WINDOW)].copy()
            ch["valuenum"] = pd.to_numeric(ch["valuenum"], errors="coerce")
            ch = ch.dropna(subset=["valuenum"])
            if not ch.empty:
                gcs_parts.append(ch)

        if gcs_parts:
            gcs_long = pd.concat(gcs_parts, ignore_index=True)

            # ITEMID -> component. Filled lowest priority first so that, should an
            # id appear in several lists, total > eye > verbal > motor wins.
            comp_by_id = {}
            for comp_name, comp_ids in (
                ("gcs_motor", ids_gcs_motor),
                ("gcs_verbal", ids_gcs_verbal),
                ("gcs_eye", ids_gcs_eye),
                ("gcs_total", ids_gcs_total),
            ):
                for iid in comp_ids:
                    comp_by_id[iid] = comp_name

            gcs_long["comp"] = gcs_long["itemid"].map(comp_by_id)
            gcs_long = gcs_long.dropna(subset=["comp"])

            def agg_gcs(df):
                """Summary statistics for one (admission, GCS component) group, in time order."""
                df = df.sort_values("charttime")
                v = df["valuenum"].values
                return pd.Series({
                    "mean": float(np.mean(v)),
                    "min":  float(np.min(v)),
                    "max":  float(np.max(v)),
                    "last": float(v[-1]),
                })

            # Pass only the columns agg_gcs reads; applying over the grouping
            # columns is deprecated in recent pandas.
            gcs_wide = (
                gcs_long.groupby(["hadm_id","comp"])[["charttime","valuenum"]]
                .apply(agg_gcs)
                .unstack("comp")
            )
            gcs_wide.columns = [f"gcs_{stat}_{comp}" for stat, comp in gcs_wide.columns]
            gcs_wide = gcs_wide.reset_index()
            # Drop GCS columns left by a previous run of this cell so re-running
            # it does not produce _x/_y duplicates.
            stale = [c for c in gcs_wide.columns if c != "hadm_id" and c in features_wide.columns]
            features_wide = features_wide.drop(columns=stale).merge(gcs_wide, on="hadm_id", how="left")
            print("After GCS →", features_wide.shape)
        else:
            print("No GCS rows found — check D_ITEMS keywords.")
    else:
        print("D_ITEMS.csv not found — skipping GCS.")


In [7]:
# --- Build ML dataset (merge labels + subject ids) ---
dataset = features_wide.merge(labels, on="hadm_id", how="inner")
dataset = dataset.merge(adm, on="hadm_id", how="left")  # adds subject_id
dataset = dataset.dropna(subset=["label"]).copy()
dataset["label"] = dataset["label"].astype(int)
print("Dataset shape:", dataset.shape, "positives:", int(dataset["label"].sum()))

# --- Subject-level stratified split ---
# Split on subjects (not admissions) so one patient never sits on both sides;
# a subject is positive if any of their admissions is.
subj_lab = (
    dataset[["subject_id","label"]]
    .groupby("subject_id", as_index=False)["label"]
    .max()
    .rename(columns={"label":"subj_label"})
)

VAL_SIZE = 0.2
sss = StratifiedShuffleSplit(n_splits=1, test_size=VAL_SIZE, random_state=42)
(subj_tr_idx, subj_va_idx), = sss.split(subj_lab["subject_id"], subj_lab["subj_label"])
train_subjects = set(subj_lab.iloc[subj_tr_idx]["subject_id"])
val_subjects   = set(subj_lab.iloc[subj_va_idx]["subject_id"])

feature_cols = [c for c in dataset.columns if c not in ("hadm_id","subject_id","label")]
X_train = dataset.loc[dataset["subject_id"].isin(train_subjects), feature_cols].copy()
y_train = dataset.loc[dataset["subject_id"].isin(train_subjects), "label"].astype(int).values
X_val   = dataset.loc[dataset["subject_id"].isin(val_subjects),   feature_cols].copy()
y_val   = dataset.loc[dataset["subject_id"].isin(val_subjects),   "label"].astype(int).values

# impute with training medians; a feature that is entirely missing in the
# training split has a NaN median, so fall back to 0.0 to guarantee that no
# NaN reaches the model.
fill_values = X_train.median().fillna(0.0).to_dict()
X_train = X_train.fillna(fill_values)
X_val   = X_val.fillna(fill_values)

print("Train balance:", np.bincount(y_train) if len(y_train)>0 else "[]")
print("Val   balance:", np.bincount(y_val) if len(y_val)>0 else "[]")

# --- Class weights for imbalance ---
# Informational only: the forest below re-weights per bootstrap sample via
# class_weight="balanced_subsample" and does not consume this dict.
classes = np.array([0, 1])
weights = compute_class_weight(class_weight="balanced", classes=classes, y=y_train)
class_weight = {int(c): float(w) for c, w in zip(classes, weights)}
print("Class weight:", class_weight)

# --- Train RandomForest ---
rf = RandomForestClassifier(
    n_estimators=600,
    max_depth=None,
    min_samples_leaf=3,
    max_features="sqrt",
    n_jobs=-1,
    class_weight="balanced_subsample",
    random_state=42
)
rf.fit(X_train, y_train)

# --- Evaluate at default 0.5 ---
val_prob = rf.predict_proba(X_val)[:, 1]
val_pred = (val_prob >= 0.5).astype(int)
has_both_classes = len(np.unique(y_val)) == 2  # ranking metrics are undefined otherwise
metrics = {
    "AUROC": roc_auc_score(y_val, val_prob) if has_both_classes else float("nan"),
    "AUPRC": average_precision_score(y_val, val_prob) if has_both_classes else float("nan"),
    "Accuracy": accuracy_score(y_val, val_pred),
    "F1": f1_score(y_val, val_pred, zero_division=0),
}
# Cast to built-in float so the dict prints as plain numbers, not np.float64(...).
print("\nValidation metrics @0.5:", {k: (None if np.isnan(v) else round(float(v), 4)) for k, v in metrics.items()})
# labels=[0,1] keeps the report/matrix well-formed even if only one class is
# present in the validation split (otherwise target_names would not match).
print("\nClassification report @0.5:\n", classification_report(y_val, val_pred, labels=[0,1], target_names=["no-stroke(0)","stroke(1)"], zero_division=0))
print("Confusion matrix @0.5:\n", confusion_matrix(y_val, val_pred, labels=[0,1]))


Dataset shape: (125, 45) positives: 6
Train balance: [95  5]
Val   balance: [24  1]
Class weight: {0: 0.5263157894736842, 1: 10.0}

Validation metrics @0.5: {'AUROC': np.float64(1.0), 'AUPRC': np.float64(1.0), 'Accuracy': 0.96, 'F1': 0.0}

Classification report @0.5:
               precision    recall  f1-score   support

no-stroke(0)       0.96      1.00      0.98        24
   stroke(1)       0.00      0.00      0.00         1

    accuracy                           0.96        25
   macro avg       0.48      0.50      0.49        25
weighted avg       0.92      0.96      0.94        25

Confusion matrix @0.5:
 [[24  0]
 [ 1  0]]


In [8]:
# --- Threshold tuning via PR curve + save model pack ---
# NOTE: the threshold is chosen AND reported on the same validation split, so
# the figures printed below are optimistic; confirm on held-out data.
if len(np.unique(y_val)) == 2:
    prec, rec, th = precision_recall_curve(y_val, val_prob)
    # precision/recall have one more entry than the thresholds; pad with 1.0 so
    # the arrays line up index-for-index.
    thr_all = np.r_[th, 1.0]
    f1s = (2*prec*rec)/(prec+rec+1e-12)
    best_idx = int(np.nanargmax(f1s))
    tuned_thr = float(thr_all[best_idx])
    threshold_note = "best_F1_on_validation"
    print(f"Best-F1 threshold: {tuned_thr:.4f}  (P={prec[best_idx]:.3f}, R={rec[best_idx]:.3f})")
else:
    # A PR curve is undefined without both classes; keep the default cut-off.
    tuned_thr = 0.5
    threshold_note = "default_0.5_single_class_validation"
    print("Validation split has a single class - keeping threshold 0.5")

y_val_opt = (val_prob >= tuned_thr).astype(int)
print("\nClassification report @ tuned threshold:\n",
      classification_report(y_val, y_val_opt, labels=[0,1], target_names=["no-stroke(0)","stroke(1)"], zero_division=0))
print("Confusion matrix @ tuned threshold:\n", confusion_matrix(y_val, y_val_opt, labels=[0,1]))

pack = {
    "model": rf,
    "feature_cols": feature_cols,
    "fill_values": fill_values,
    "hours_window": HOURS_WINDOW,
    "item_map": {k:list(v) for k,v in KNOWN_ITEM_MAP.items()},  # base ids (for reference)
    "item_map_used": {k:list(v) for k,v in item_map.items()},   # ids the features were actually built from
    "threshold": tuned_thr,
    "threshold_note": threshold_note,
}
dump(pack, "models/mimiciii_vitals_rf.joblib")
print("\nSaved → models/mimiciii_vitals_rf.joblib (threshold saved)")


Best-F1 threshold: 0.2206  (P=1.000, R=1.000)

Classification report @ tuned threshold:
               precision    recall  f1-score   support

no-stroke(0)       1.00      1.00      1.00        24
   stroke(1)       1.00      1.00      1.00         1

    accuracy                           1.00        25
   macro avg       1.00      1.00      1.00        25
weighted avg       1.00      1.00      1.00        25

Confusion matrix @ tuned threshold:
 [[24  0]
 [ 0  1]]

Saved → models/mimiciii_vitals_rf.joblib (threshold saved)


In [9]:
# --- Inference helper ---
def predict_vitals_stroke_prob(vitals_row: dict | pd.Series,
                               model_path: str = "models/mimiciii_vitals_rf.joblib"):
    """Score one admission with the saved model pack.

    Parameters
    ----------
    vitals_row : mapping (dict or Series) from feature name to value. Features
        the model expects but the row lacks are imputed with the fill values
        stored in the pack; extra keys are ignored.
    model_path : location of the joblib pack; defaults to the path this
        notebook saves to.

    Returns
    -------
    (probability, predicted_label) where the label applies the threshold
    stored in the pack (0.5 when the pack has none).
    """
    # joblib.load unpickles arbitrary objects: only point this at trusted files.
    pk = load(model_path)
    model = pk["model"]
    cols  = pk["feature_cols"]
    fill  = pk["fill_values"]
    thr   = float(pk.get("threshold", 0.5))
    # columns=cols fixes both the feature set and its order to what was trained.
    x = pd.DataFrame([vitals_row], columns=cols).fillna(fill)
    prob = float(model.predict_proba(x)[:, 1][0])
    pred = int(prob >= thr)
    return prob, pred

# Example (fill real values from a new admission's first-6h aggregates):
# example_row = {c: 0.0 for c in feature_cols}
# prob, pred = predict_vitals_stroke_prob(example_row)
# print(prob, pred)
