In [8]:
# All imports and setup
import os, math, json, sys, warnings
from datetime import datetime, timezone
import numpy as np
import pandas as pd
from tqdm import tqdm
import h3
import pytz
import holidays
import requests
import folium
from shapely.geometry import Polygon
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.isotonic import IsotonicRegression
from sklearn.metrics import average_precision_score, brier_score_loss
from xgboost import XGBClassifier

# Setup
warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", 120)
os.chdir("..")  # Go to project root
print("Current working directory:", os.getcwd())

Current working directory: /Users/lucchen/Desktop


## 0) Configuration

In [None]:
RAW_CSV = "data/raw/US_Accidents_March23.csv"  #change
OUT_DIR = "nb_artifacts"                       # notebook outputs go here
# Los Angeles
BBOX = {
    "min_lat": 33.5,
    "min_lng": -119.0,
    "max_lat": 34.9,
    "max_lng": -117.0
}

H3_RES = 8

# Time window in UTC (panel range)
START_UTC = "2022-01-01T00:00:00Z"
END_UTC   = "2023-01-01T00:00:00Z"

# Negative sampling fraction (keep this % of negatives)
NEG_FRAC = 0.05
RANDOM_SEED = 42

# Model params
XGB_PARAMS = dict(
    n_estimators=400,
    max_depth=6,
    learning_rate=0.06,
    subsample=0.9,
    colsample_bytree=0.9,
    reg_lambda=1.0,
)

# Evaluation splits (forward chaining)
TRAIN_END = "2022-07-01T00:00:00Z"
VAL_END   = "2022-10-01T00:00:00Z"
TEST_END  = "2023-01-01T00:00:00Z"

TOPK = 50


In [None]:
import os
os.chdir("..")  # go up one level from notebooks/ to project root
print("Current working directory:", os.getcwd())


'/Users/lucchen/Desktop/accident-risk/notebooks'

In [None]:
RAW_CSV

'data/raw/US_Accidents_March23.csv'

## 1) Load & filter US‑Accidents

In [None]:
#assert os.path.exists(RAW_CSV), f"Place the Kaggle CSV at: {RAW_CSV}"

# must-have
base_cols = ["ID","Start_Time","Start_Lat","Start_Lng","City","County","State","Timezone"]

# nice-to-have infra columns (some files may miss a few)
infra_cols = [
    "Amenity","Bump","Crossing","Give_Way","Junction","No_Exit","Railway",
  "Roundabout","Station","Stop","Traffic_Calming","Traffic_Signal","Turning_Loop"
]


# detect what’s present in this CSV
header_cols = pd.read_csv(RAW_CSV, nrows=0).columns.tolist()
available_infra = [c for c in infra_cols if c in header_cols]

usecols = list(dict.fromkeys(base_cols + available_infra))  # de-dup
acc = pd.read_csv(RAW_CSV, usecols=usecols, low_memory=False)
len_before = len(acc)

# bbox filter
acc = acc[(acc.Start_Lat >= BBOX["min_lat"]) & (acc.Start_Lat <= BBOX["max_lat"]) &
          (acc.Start_Lng >= BBOX["min_lng"]) & (acc.Start_Lng <= BBOX["max_lng"])].copy()

# timestamps
acc["Start_Time"] = pd.to_datetime(acc["Start_Time"], utc=True, errors="coerce")
acc = acc.dropna(subset=["Start_Time","Start_Lat","Start_Lng"])

# time window
start_utc = pd.to_datetime(START_UTC, utc=True)
end_utc   = pd.to_datetime(END_UTC, utc=True)
acc = acc[(acc["Start_Time"] >= start_utc) & (acc["Start_Time"] < end_utc)].copy()

print(f"Loaded {len_before:,} rows; after bbox/time filters: {len(acc):,} rows")
print("Infra columns loaded:", available_infra)

FileNotFoundError: [Errno 2] No such file or directory: 'data/raw/US_Accidents_March23.csv'

In [None]:
import requests

city = "Los Angeles, California"
url = f"https://nominatim.openstreetmap.org/search?q={city}&format=json"

headers = {"User-Agent": "accident-risk-demo/1.0 (lucchen@harvard.edu)"}

r = requests.get(url, headers=headers)
r.raise_for_status()
data = r.json()

bbox = data[0]["boundingbox"]
print("Bounding box:", bbox)


## 2) Build H3 grid for bbox

In [None]:
import requests
import h3
import pandas as pd
import folium

# --- 1. Get Los Angeles city boundary polygon from OpenStreetMap ---
url = "https://nominatim.openstreetmap.org/search.php?q=Los+Angeles+California&polygon_geojson=1&format=json"
r = requests.get(url, headers={"User-Agent": "LA-grid-demo"}).json()
geojson_poly = r[0]["geojson"]  # Full GeoJSON

# --- 2. Create an H3 LatLngPoly shape object (required for H3 4.x) ---
coords = geojson_poly["coordinates"][0]  # Outer ring
poly = h3.LatLngPoly([(lat, lng) for lng, lat in coords])  # Note: swap lon→lat order

# --- 3. Generate hex cells ---
cells = list(h3.polygon_to_cells(poly, res=8))  # official new method (pleae dont chaneg this)
cells_df = pd.DataFrame({"h3_id": cells})
print(f"H3 res=8 → {len(cells_df):,} cells inside Los Angeles")

# --- 4. Visualize ---
# --- Visualization ---
m = folium.Map(location=[34.05, -118.25], zoom_start=10, tiles="cartodb positron")

# 
for h in cells_df["h3_id"]:
    boundary = h3.cell_to_boundary(h)
    folium.Polygon(
        locations=[(lat, lng) for lat, lng in boundary],
        color="blue",
        weight=0.5,
        fill=True,
        fill_opacity=0.3,
    ).add_to(m)
m



## 3) Build full cell × hour panel and label `y`

In [None]:
# --- Use the LA polygon-derived cells you already built above ---
H3_RES = 8  # must match the res you used in polygon_to_cells
cells_df = cells_df.drop_duplicates(subset=["h3_id"]).reset_index(drop=True)

print(f"[LA] H3 res {H3_RES}: {len(cells_df):,} cells")

# --- Build hours index and panel ---
hours = pd.date_range(start=start_utc, end=end_utc, freq="H",
                      inclusive="left", tz="UTC")
hours_df = pd.DataFrame({"ts_utc": hours})

cells_df["key"] = 1
hours_df["key"] = 1
panel = cells_df.merge(hours_df, on="key").drop(columns=["key"])

# Ensure tz-aware UTC
panel["ts_utc"] = pd.to_datetime(panel["ts_utc"], utc=True)

# --- Map accidents to (h3, hour) and count them ---
# 1) Compute each accident's h3 cell at the SAME resolution as the grid
acc = acc.copy()
acc["h3_id"] = acc.apply(
    lambda r: h3.latlng_to_cell(r["Start_Lat"], r["Start_Lng"], H3_RES), axis=1
)

# 2) Keep only accidents that fall inside the LA cell set
la_cells = set(cells_df["h3_id"].tolist())
acc = acc[acc["h3_id"].isin(la_cells)].copy()

# 3) Floor times to hour in UTC (assumes acc["Start_Time"] is already UTC-aware)
acc["ts_utc"] = pd.to_datetime(acc["Start_Time"], utc=True).dt.floor("H")

# 4) Count accidents per cell-hour
acc_counts = (acc.groupby(["h3_id", "ts_utc"])
                .size()
                .rename("cnt")
                .reset_index())

# --- Label merge ---
if "y" in panel.columns:
    panel = panel.drop(columns=["y"])

panel = panel.merge(acc_counts, on=["h3_id", "ts_utc"], how="left")
panel["y"] = (panel["cnt"].fillna(0) > 0).astype("int8")
panel = panel.drop(columns=["cnt"])

# --- Optional: static road/infrastructure flags (if present in your CSV) ---
infra_cols = [
    "Amenity","Bump","Crossing","Give_Way","Junction","No_Exit","Railway",
    "Roundabout","Station","Stop","Traffic_Calming","Traffic_Signal","Turning_Loop"
]
available_infra = [c for c in infra_cols if c in acc.columns]
print("Infra cols available:", available_infra)

if available_infra:
    def to01(s):
        return s.map({True:1, False:0, "True":1, "False":0, 1:1, 0:0}).fillna(0).astype("int8")
    for c in available_infra:
        acc[c] = to01(acc[c])

    static_feats = (
        acc.groupby("h3_id")[available_infra]
           .mean()                      # fraction of historical crashes with that attribute
           .reset_index()
    )

    panel = panel.merge(static_feats, on="h3_id", how="left")
    for c in available_infra:
        panel[c] = panel[c].fillna(0).astype("float32")
else:
    print("No infrastructure columns in this CSV; skipping static road features.")

# --- Quick sanity checks ---
print(panel["y"].value_counts(dropna=False))

# Sample one accident row and verify its panel label
if len(acc) > 0:
    row = acc.sample(1, random_state=0).iloc[0]
    mask = (panel["h3_id"] == row["h3_id"]) & (panel["ts_utc"] == row["ts_utc"])
    print("Matches for sampled accident:", mask.sum())
    display(panel.loc[mask, ["h3_id","ts_utc","y"]].head())
else:
    print("Warning: No accidents landed inside the LA H3 grid for the given time window.")

## 4) Time/holiday features

In [None]:
ts = pd.to_datetime(panel["ts_utc"], utc=True)

panel["hour"]       = ts.dt.hour.astype("int16")
panel["dow"]        = ts.dt.dayofweek.astype("int16")
panel["month"]      = ts.dt.month.astype("int16")
panel["is_weekend"] = (panel["dow"]>=5).astype("int8")

years = sorted(set(ts.dt.year.tolist()))
us_holidays = holidays.UnitedStates(years=years)
panel["is_holiday"] = ts.dt.date.astype("O").map(lambda d: 1 if d in us_holidays else 0).astype("int8")

## 5) Lag features (computed **strictly** from past values)

In [None]:
panel = panel.sort_values(["h3_id","ts_utc"])
def compute_lags(g):
    y = g["y"].astype(int)
    g["lag_1h"] = y.shift(1).fillna(0)
    g["lag_3h"] = y.shift(1).rolling(3).sum().fillna(0)
    g["lag_24h"]= y.shift(1).rolling(24).sum().fillna(0)
    g["lag_7d_sum"]   = y.shift(1).rolling(168, min_periods=1).sum()
    g["lag_30d_sum"]  = y.shift(1).rolling(720, min_periods=1).sum()

    return g

panel = panel.groupby("h3_id", group_keys=False).apply(compute_lags)

## 6 get weather data using meteostat integration

temp_c: Air temperature, measured 2 m above ground level.  Represents the ambient air temperature for that hour.

precip_mm: Hourly total precipitation (liquid equivalent of rain, drizzle, snow, etc.). 0 = no precipitation during that hour.

wind_kph: Mean wind speed over the hour, measured at 10 m height.

pressure_hpa: Mean sea-level air pressure.  Standard atmospheric pressure ≈ 1013 hPa.

humidity_pct: Relative humidity – ratio of current air moisture to the maximum possible at that temperature (0–100 %).



In [None]:

# we use meteostat (LA)
from meteostat import Stations, Hourly
import pandas as pd
import numpy as np

# 1) Use naive start/end for Meteostat, but keep everything UTC in our panel
start_dt = pd.to_datetime(START_UTC, utc=True).tz_convert("UTC").tz_localize(None)
end_dt   = pd.to_datetime(END_UTC,   utc=True).tz_convert("UTC").tz_localize(None)

# 2) One centroid per cell to find nearest station
cell_centroids = (
    panel[["h3_id"]].drop_duplicates()
    .assign(
        lat=lambda df: df["h3_id"].map(lambda h: h3.cell_to_latlng(h)[0]),
        lng=lambda df: df["h3_id"].map(lambda h: h3.cell_to_latlng(h)[1]),
    )
)

def nearest_station_id(lat, lng, k=20):
    st = Stations().nearby(lat, lng).fetch(k)   # fetch more; LA is dense
    return None if st.empty else st.index[0]

cell_centroids["station_id"] = [
    nearest_station_id(r.lat, r.lng) for r in cell_centroids.itertuples(index=False)
]
cell_centroids = cell_centroids.dropna(subset=["station_id"]).reset_index(drop=True)

# 3) Fetch hourly weather for each station and normalize columns
station_ids = sorted(cell_centroids["station_id"].unique().tolist())
weather_frames = []
for sid in station_ids:
    try:
        w = Hourly(sid, start_dt, end_dt, model=True).fetch()
        if w.empty:
            continue

        # robust tz: localize if naive, else convert
        if getattr(w.index, "tz", None) is None:
            w.index = w.index.tz_localize("UTC")
        else:
            w = w.tz_convert("UTC")

        w = w.reset_index().rename(columns={"time": "ts_utc"})

        # Map raw Meteostat names -> our names
        rename_map = {
            "temp": "temp_c",       # °C
            "prcp": "precip_mm",    # mm
            "wspd": "wind_kph",     # km/h
            "pres": "pressure_hpa", # hPa
            "rhum": "humidity_pct", # %
            "vsby": "vis_km"        # km
        }
        for src, dst in rename_map.items():
            if src in w.columns:
                w[dst] = pd.to_numeric(w[src], errors="coerce")

        keep = ["ts_utc"] + [v for v in rename_map.values() if v in w.columns]
        w = w[keep].copy()
        w["station_id"] = sid
        weather_frames.append(w)
    except Exception as e:
        print(f"[weather] skip {sid}: {e}")

if weather_frames:
    weather = pd.concat(weather_frames, ignore_index=True)

    # 4) attach station → cell, then merge to panel on (h3_id, ts_utc)
    weather_full = (
        cell_centroids[["h3_id","station_id"]]
        .merge(weather, on="station_id", how="left")
        .drop(columns=["station_id"])
    )

    # ensure UTC & same dtype as panel
    weather_full["ts_utc"] = pd.to_datetime(weather_full["ts_utc"], utc=True)

    panel = panel.merge(weather_full, on=["h3_id","ts_utc"], how="left")

    # 5) Fill reasonable gaps: per-cell ffill/bfill → per-hour citywide median
    w_cols = [c for c in ["temp_c","precip_mm","wind_kph","pressure_hpa",
                          "humidity_pct","vis_km"] if c in panel.columns]
    panel = panel.sort_values(["h3_id","ts_utc"])
    for c in w_cols:
        panel[c] = (panel.groupby("h3_id")[c]
                        .apply(lambda s: s.ffill().bfill())
                        .reset_index(level=0, drop=True))
        panel[c] = panel[c].fillna(panel.groupby("ts_utc")[c].transform("median"))

    # 6) SAFE lags (only from the past) — avoid nowcast leakage
    def add_weather_lags(g):
        for c in w_cols:
            g[f"{c}_lag1"] = g[c].shift(1)
            if c == "precip_mm":
                g[f"{c}_lag3_sum"]  = g[c].shift(1).rolling(3).sum()
                g[f"{c}_lag24_sum"] = g[c].shift(1).rolling(24).sum()
            else:
                g[f"{c}_lag3_mean"]  = g[c].shift(1).rolling(3).mean()
                g[f"{c}_lag24_mean"] = g[c].shift(1).rolling(24).mean()
        return g

    panel = panel.groupby("h3_id", group_keys=False).apply(add_weather_lags)

    # drop contemporaneous weather to keep only lags
    #panel = panel.drop(columns=w_cols, errors="ignore")

    # quick check
    lag_cols = [c for c in panel.columns if c.startswith(("temp_c_","precip_mm_","wind_kph_",
                                                          "pressure_hpa_","humidity_pct_","vis_km_"))]
    print(f"[weather] added lag features: {len(lag_cols)} cols")
else:
    print("[weather] no station data fetched; skipping weather features.")

## 8) Train XGBoost (weighted) + 9) Calibrate with Isotonic

In [None]:
import numpy as np, pandas as pd

# 1️⃣ Forward time splits
TRAIN_END = "2022-07-01T00:00:00Z"
VAL_END   = "2022-10-01T00:00:00Z"
TEST_END  = "2023-01-01T00:00:00Z"

ts = pd.to_datetime(panel["ts_utc"], utc=True)
train_end = pd.to_datetime(TRAIN_END, utc=True)
val_end   = pd.to_datetime(VAL_END,   utc=True)
test_end  = pd.to_datetime(TEST_END,  utc=True)

def split_name(t):
    if t < train_end: return "train"
    elif t < val_end: return "val"
    elif t < test_end: return "test"
    else: return "ignore"

panel["split"] = [split_name(t) for t in ts]
panel = panel[panel["split"]!="ignore"].copy()

# we can do neg sampling later
NEG_FRAC = 1
np.random.seed(42)

train_df = panel[panel["split"]=="train"].copy()
val_df   = panel[panel["split"]=="val"].copy()
test_df  = panel[panel["split"]=="test"].copy()

pos = train_df[train_df["y"]==1]
neg = train_df[train_df["y"]==0].sample(frac=NEG_FRAC, random_state=42)
train_df = pd.concat([pos, neg], ignore_index=True)

# weights: balance the sampling
for df in (train_df, val_df, test_df):
    df["weight"] = 1.0
train_df.loc[train_df["y"]==0, "weight"] = 1.0/NEG_FRAC

print(f"Train rows: {len(train_df):,},  Val: {len(val_df):,},  Test: {len(test_df):,}")
print("Positives in train:", (train_df['y']==1).sum())

In [None]:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.isotonic import IsotonicRegression
from sklearn.metrics import average_precision_score, brier_score_loss
from xgboost import XGBClassifier

NUM_COLS = ["lag_1h","lag_3h","lag_24h"]
infra_cols_present = [c for c in [
    "Amenity","Bump","Crossing","Give_Way","Junction","No_Exit","Railway",
    "Roundabout","Station","Stop","Traffic_Calming","Traffic_Signal","Turning_Loop"
] if c in panel.columns]
NUM_COLS += infra_cols_present

CAT_COLS = ["hour","dow","month","is_weekend","is_holiday"]

def make_xyw(df):
    X = df[NUM_COLS + CAT_COLS].copy()
    y = df["y"].astype(int).values
    w = df["weight"].astype(float).values
    return X, y, w

X_tr, y_tr, w_tr = make_xyw(train_df)
X_va, y_va, w_va = make_xyw(val_df)
X_te, y_te, w_te = make_xyw(test_df)

pre = ColumnTransformer([("cat", OneHotEncoder(handle_unknown="ignore"), CAT_COLS)],
                        remainder="passthrough")

clf = XGBClassifier(
    objective="binary:logistic",
    n_estimators=400, max_depth=6, learning_rate=0.06,
    subsample=0.9, colsample_bytree=0.9, reg_lambda=1.0,
    eval_metric="aucpr", random_state=42, n_jobs=4
)

pipe = Pipeline([("pre", pre), ("model", clf)])
pipe.fit(X_tr, y_tr, model__sample_weight=w_tr)

# Calibrate on validation
p_va_raw = pipe.predict_proba(X_va)[:,1]
iso = IsotonicRegression(out_of_bounds="clip").fit(p_va_raw, y_va, sample_weight=w_va)

# Evaluate on test (nowcast of current hour)
p_te_raw = pipe.predict_proba(X_te)[:,1]
p_te_cal = iso.predict(p_te_raw)

from sklearn.metrics import (
    accuracy_score, confusion_matrix,
    precision_score, recall_score, f1_score
)

# --- Choose a classification threshold ---
# Since accidents are rare, use a small threshold instead of 0.5
threshold = 0.01# ujst a quick check
y_pred = (p_te_cal >= threshold).astype(int)

# --- Compute standard metrics ---
acc = accuracy_score(y_te, y_pred)
prec = precision_score(y_te, y_pred, zero_division=0)
rec = recall_score(y_te, y_pred, zero_division=0)
f1 = f1_score(y_te, y_pred, zero_division=0)
cm = confusion_matrix(y_te, y_pred)

print("\n--- Classification Metrics ---")
print(f"Threshold        : {threshold}")
print(f"Accuracy         : {acc:.6f}")
print(f"Precision        : {prec:.6f}")
print(f"Recall           : {rec:.6f}")
print(f"F1-score         : {f1:.6f}")
print("Confusion Matrix :")
print(pd.DataFrame(cm,
                   index=["Actual 0","Actual 1"],
                   columns=["Pred 0","Pred 1"]))


In [None]:
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt

# --- ROC–AUC score ---
roc_auc = roc_auc_score(y_te, p_te_cal, sample_weight=w_te)
print(f"ROC–AUC: {roc_auc:.6f}")

# --- Optional: visualize ROC curve ---
fpr, tpr, thresholds = roc_curve(y_te, p_te_cal, sample_weight=w_te)
plt.figure(figsize=(6, 5))
plt.plot(fpr, tpr, label=f"ROC curve (AUC = {roc_auc:.3f})")
plt.plot([0, 1], [0, 1], 'k--', label="Random chance")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate (Recall)")
plt.title("Receiver Operating Characteristic (ROC)")
plt.legend(loc="lower right")
plt.grid(True)
plt.show()


In [None]:
# ============================================================
# Best-of (LSTM / GRU / TCN) sequence model — pick by val PR-AUC
# ============================================================
import numpy as np
import pandas as pd
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import average_precision_score, precision_recall_curve, f1_score

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------
# 1) Config
# ---------------------
W = 168                 # sequence window (hours); try 168 for a week
BATCH_SIZE = 256
EPOCHS = 18
LR = 1e-3
BETA = 2.0              # F-beta for threshold picking on val

# Choose features you already have in your panel:
NUM_FEATS = [c for c in [
    # history (you can keep as-is; no shift policing here)
    "lag_1h","lag_3h","lag_24h","lag_7d_sum","lag_30d_sum","lag_ewm_24h","lag_ewm_7d",
    # neighbor spillover (if created; ok to omit if not in panel)
    "k1_lag_1h","k1_lag_3h","k1_lag_24h","k1_lag_7d_sum",
    "k2_lag_1h","k2_lag_3h","k2_lag_24h","k2_lag_7d_sum",
    # weather (current+lags ok, we’re ignoring leakage)
    "temp_c","temp_c_lag1","temp_c_lag3_mean","temp_c_lag24_mean","temp_c_diff1","temp_c_anom24",
    "wind_kph","wind_kph_lag1","wind_kph_lag3_mean","wind_kph_lag24_mean","wind_kph_diff1","wind_kph_anom24",
    "pressure_hpa","pressure_hpa_lag1","pressure_hpa_lag3_mean","pressure_hpa_lag24_mean","pressure_hpa_diff1","pressure_hpa_anom24",
    "humidity_pct","humidity_pct_lag1","humidity_pct_lag3_mean","humidity_pct_lag24_mean","humidity_pct_diff1","humidity_pct_anom24",
    "precip_mm","precip_mm_lag3_sum","precip_mm_lag24_sum","rain_flag_3h","rain_flag_24h",
] if c in panel.columns]

CAT_FEATS = ["hour","dow","month","is_weekend","is_holiday"]
CAT_FEATS = [c for c in CAT_FEATS if c in panel.columns]

STATIC_FEATS = [c for c in [
    "Amenity","Bump","Crossing","Give_Way","Junction","No_Exit","Railway",
    "Roundabout","Station","Stop","Traffic_Calming","Traffic_Signal","Turning_Loop"
] if c in panel.columns]

# ---------------------
# 2) Prepare DataFrame
# ---------------------
need_cols = ["h3_id","ts_utc","y"] + NUM_FEATS + CAT_FEATS + STATIC_FEATS + ["split"]
df = panel[need_cols].copy()
df["ts_utc"] = pd.to_datetime(df["ts_utc"], utc=True)
df = df.sort_values(["h3_id","ts_utc"]).reset_index(drop=True)

# map cells
h3_list = df["h3_id"].drop_duplicates().tolist()
h3_to_idx = {h:i for i,h in enumerate(h3_list)}
df["h3_idx"] = df["h3_id"].map(h3_to_idx).astype("int32")

# normalize numeric (fit on train only)
train_mask = df["split"]=="train"
num_stats = {}
for c in NUM_FEATS:
    m = pd.to_numeric(df.loc[train_mask, c], errors="coerce").astype("float32")
    mu, sd = float(m.mean()), float(m.std() if m.std()>1e-6 else 1.0)
    num_stats[c] = (mu, sd)
    df[c] = ((pd.to_numeric(df[c], errors="coerce").astype("float32") - mu)/sd).fillna(0.0)

for c in CAT_FEATS:
    df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0).astype("int64")
for c in STATIC_FEATS:
    df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0).astype("float32")

groups = {gid: g for gid, g in df.groupby("h3_idx", sort=True)}

# ---------------------
# 3) Dataset
# ---------------------
class CellSeqDataset(Dataset):
    def __init__(self, groups, split, W, num_cols, cat_cols, static_cols):
        self.num_cols = num_cols
        self.cat_cols = cat_cols
        self.static_cols = static_cols
        self.W = W
        self.samples = []
        for h3_idx, g in groups.items():
            g = g[g["split"]==split]
            if len(g) < W: 
                continue
            arr_num = g[num_cols].to_numpy(np.float32)
            arr_cat = g[cat_cols].to_numpy(np.int64) if cat_cols else None
            arr_static = g[static_cols].to_numpy(np.float32) if static_cols else None
            y = g["y"].astype("int64").to_numpy()
            for t in range(W-1, len(g)):
                Xn = arr_num[t-W+1:t+1]
                Xc = arr_cat[t-W+1:t+1] if arr_cat is not None else None
                Xs = arr_static[t] if arr_static is not None else None
                self.samples.append((h3_idx, Xn, Xc, Xs, y[t]))

    def __len__(self): return len(self.samples)
    def __getitem__(self, i):
        h3_idx, Xn, Xc, Xs, y = self.samples[i]
        return (
            torch.tensor(h3_idx, dtype=torch.long),
            torch.tensor(Xn, dtype=torch.float32),
            (torch.tensor(Xc, dtype=torch.long) if Xc is not None else torch.empty(0, dtype=torch.long)),
            (torch.tensor(Xs, dtype=torch.float32) if Xs is not None else torch.empty(0)),
            torch.tensor(y, dtype=torch.float32),
        )

def collate(batch):
    h_idx, Xn, Xc, Xs, y = zip(*batch)
    Xn = torch.stack(Xn)  # (B,W,Fnum)
    if Xc[0].numel() == 0:
        Xc = None
    else:
        Xc = torch.stack(Xc)  # (B,W,C)
    if Xs[0].numel() == 0:
        Xs = None
    else:
        Xs = torch.stack(Xs)  # (B,Fstatic)
    y = torch.stack(y)
    return (torch.tensor(h_idx, dtype=torch.long), Xn, Xc, Xs, y)

train_ds = CellSeqDataset(groups, "train", W, NUM_FEATS, CAT_FEATS, STATIC_FEATS)
val_ds   = CellSeqDataset(groups, "val",   W, NUM_FEATS, CAT_FEATS, STATIC_FEATS)
test_ds  = CellSeqDataset(groups, "test",  W, NUM_FEATS, CAT_FEATS, STATIC_FEATS)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)

# ---------------------
# 4) Models: LSTM / GRU / TCN
# ---------------------
class CatEmb(nn.Module):
    def __init__(self, cardinals, emb_dim=8):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(c, min(emb_dim, max(2,c))) for c in cardinals])
        self.out_dim = sum(e.embedding_dim for e in self.embs)
    def forward(self, Xc):  # (B,W,C)
        if Xc is None: return None
        embs = [emb(Xc[:,:,i]) for i, emb in enumerate(self.embs)]  # list of (B,W,Di)
        return torch.cat(embs, dim=-1)  # (B,W,sumD)

class RNNSeq(nn.Module):
    def __init__(self, kind, num_in, cat_card=None, static_in=0, hid=128, layers=1):
        super().__init__()
        self.cat = CatEmb(cat_card) if cat_card else None
        in_dim = num_in + (self.cat.out_dim if self.cat else 0)
        rnn = {"lstm": nn.LSTM, "gru": nn.GRU}[kind]
        self.rnn = rnn(in_dim, hid, num_layers=layers, batch_first=True)
        head_in = hid + static_in
        self.head = nn.Sequential(
            nn.Linear(head_in, 128), nn.ReLU(), nn.BatchNorm1d(128), nn.Dropout(0.25),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1)
        )
    def forward(self, Xn, Xc=None, Xs=None):
        # Xn: (B,W,F)
        if self.cat:
            Ec = self.cat(Xc)                       # (B,W,De)
            x = torch.cat([Xn, Ec], dim=-1)
        else:
            x = Xn
        seq, _ = self.rnn(x)                        # (B,W,H)
        h = seq[:,-1,:]                             # (B,H)
        if Xs is not None: h = torch.cat([h, Xs], dim=1)
        logit = self.head(h).squeeze(1)
        return torch.sigmoid(logit)

class Chomp1d(nn.Module):
    def __init__(self, chomp_size): super().__init__(); self.chomp_size=chomp_size
    def forward(self, x): return x[:,:,:-self.chomp_size]

class TCNBlock(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, d=1, p=0.2):
        super().__init__()
        pad = (k-1)*d
        self.net = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, k, padding=pad, dilation=d),
            Chomp1d(pad), nn.ReLU(), nn.BatchNorm1d(out_ch), nn.Dropout(p),
            nn.Conv1d(out_ch, out_ch, k, padding=pad, dilation=d),
            Chomp1d(pad), nn.ReLU(), nn.BatchNorm1d(out_ch), nn.Dropout(p),
        )
        self.down = nn.Conv1d(in_ch, out_ch, 1) if in_ch!=out_ch else nn.Identity()
    def forward(self, x):
        out = self.net(x)
        return out + self.down(x)

class TCNSeq(nn.Module):
    def __init__(self, num_in, cat_card=None, static_in=0, hid=128, levels=4, k=3):
        super().__init__()
        self.cat = CatEmb(cat_card) if cat_card else None
        in_dim = num_in + (self.cat.out_dim if self.cat else 0)
        chs = [in_dim] + [hid]*levels
        blocks = []
        for i in range(levels):
            blocks.append(TCNBlock(chs[i], chs[i+1], k=k, d=2**i, p=0.2))
        self.tcn = nn.Sequential(*blocks)
        head_in = hid + static_in
        self.head = nn.Sequential(
            nn.Linear(head_in, 128), nn.ReLU(), nn.BatchNorm1d(128), nn.Dropout(0.25),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 1)
        )
    def forward(self, Xn, Xc=None, Xs=None):
        # Xn: (B,W,F) -> transpose to (B,F,W)
        if self.cat:
            Ec = self.cat(Xc)                # (B,W,De)
            x = torch.cat([Xn, Ec], dim=-1)  # (B,W,F+De)
        else:
            x = Xn
        x = x.transpose(1,2)                 # (B, Fch, W)
        y = self.tcn(x)                      # (B, hid, W)
        h = y[:,:,-1]                        # (B, hid)
        if Xs is not None: h = torch.cat([h, Xs], dim=1)
        logit = self.head(h).squeeze(1)
        return torch.sigmoid(logit)

# ---------------------
# 5) Train & Select
# ---------------------
def run_train(model, train_loader, val_loader, epochs=EPOCHS, lr=LR):
    model = model.to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    bce = nn.BCELoss()
    best = {"pr": -1, "state": None}
    for ep in range(1, epochs+1):
        model.train()
        for _, Xn, Xc, Xs, y in train_loader:
            Xn, y = Xn.to(DEVICE), y.to(DEVICE)
            if Xc is not None: Xc = Xc.to(DEVICE)
            if Xs is not None: Xs = Xs.to(DEVICE)
            p = model(Xn, Xc, Xs)
            loss = bce(p, y)
            opt.zero_grad(); loss.backward(); opt.step()
        # val PR-AUC
        model.eval(); pv, yv = [], []
        with torch.no_grad():
            for _, Xn, Xc, Xs, y in val_loader:
                Xn = Xn.to(DEVICE)
                if Xc is not None: Xc = Xc.to(DEVICE)
                if Xs is not None: Xs = Xs.to(DEVICE)
                p = model(Xn, Xc, Xs).detach().cpu().numpy()
                pv.append(p); yv.append(y.numpy())
        pv = np.concatenate(pv) if pv else np.array([])
        yv = np.concatenate(yv) if yv else np.array([])
        pr = average_precision_score(yv, pv) if (yv.size and len(np.unique(yv))>1) else 0.0
        if pr > best["pr"]:
            best = {"pr": pr, "state": {k:v.cpu().clone() for k,v in model.state_dict().items()}}
        print(f"[{model.__class__.__name__}] epoch {ep:02d}  val PR-AUC={pr:.6f}")
    model.load_state_dict({k:v.to(DEVICE) for k,v in best["state"].items()})
    return model, best["pr"]

cat_card = [24,7,12,2,2] if CAT_FEATS else None
c_static = len(STATIC_FEATS)

candidates = {
    "LSTM": RNNSeq(kind="lstm", num_in=len(NUM_FEATS), cat_card=cat_card, static_in=c_static, hid=160, layers=1),
    "GRU" : RNNSeq(kind="gru",  num_in=len(NUM_FEATS), cat_card=cat_card, static_in=c_static, hid=160, layers=1),
    "TCN" : TCNSeq(num_in=len(NUM_FEATS), cat_card=cat_card, static_in=c_static, hid=192, levels=4, k=3),
}

results = {}
best_name, best_model, best_val = None, None, -1
for name, mdl in candidates.items():
    print(f"\n==== Train {name} ====")
    m, val_pr = run_train(mdl, train_loader, val_loader, epochs=EPOCHS, lr=LR)
    results[name] = val_pr
    if val_pr > best_val:
        best_val, best_name, best_model = val_pr, name, m

print("\nValidation PR-AUC per model:", {k: round(v,6) for k,v in results.items()})
print("→ Selected:", best_name, "with PR-AUC", round(best_val,6))

# ---------------------
# 6) Test evaluation
# ---------------------
best_model.eval(); pt, yt = [], []
with torch.no_grad():
    for _, Xn, Xc, Xs, y in test_loader:
        Xn = Xn.to(DEVICE)
        if Xc is not None: Xc = Xc.to(DEVICE)
        if Xs is not None: Xs = Xs.to(DEVICE)
        p = best_model(Xn, Xc, Xs).detach().cpu().numpy()
        pt.append(p); yt.append(y.numpy())
pt = np.concatenate(pt) if pt else np.array([])
yt = np.concatenate(yt) if yt else np.array([])

test_pr = average_precision_score(yt, pt) if (yt.size and len(np.unique(yt))>1) else 0.0
print("Test PR-AUC (prob ranking):", round(float(test_pr), 6))

# pick threshold on VAL by F-beta
best_model.eval(); pv, yv = [], []
with torch.no_grad():
    for _, Xn, Xc, Xs, y in val_loader:
        Xn = Xn.to(DEVICE)
        if Xc is not None: Xc = Xc.to(DEVICE)
        if Xs is not None: Xs = Xs.to(DEVICE)
        pv.append(best_model(Xn, Xc, Xs).detach().cpu().numpy()); yv.append(y.numpy())
pv = np.concatenate(pv) if pv else np.array([])
yv = np.concatenate(yv) if yv else np.array([])

prec, rec, thr = precision_recall_curve(yv, pv)
fb = (1+BETA**2)*prec*rec/(BETA**2*prec + rec + 1e-12)
best_t = float(thr[np.nanargmax(fb[:-1])]) if thr.size else 0.5
print("Chosen threshold (val, Fβ):", best_t)

from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, brier_score_loss
y_pred = (pt >= best_t).astype(int)
print("Test precision:", round(float(precision_score(yt, y_pred, zero_division=0)),6))
print("Test recall   :", round(float(recall_score(yt, y_pred, zero_division=0)),6))
print("Test F1       :", round(float(f1_score(yt, y_pred, zero_division=0)),6))
print("Brier score   :", round(float(brier_score_loss(yt, np.clip(pt,1e-6,1-1e-6))),6))
print("Confusion:\n", pd.DataFrame(confusion_matrix(yt, y_pred),
                                    index=["Actual 0","Actual 1"],
                                    columns=["Pred 0","Pred 1"]))