In [1]:
from __future__ import annotations
import os, json, math, warnings, string
from typing import List, Tuple, Optional, Dict, Any
from dataclasses import dataclass

import numpy as np
import pandas as pd

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.impute import SimpleImputer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    classification_report, confusion_matrix, f1_score, mean_squared_error,
    accuracy_score, precision_recall_fscore_support
)
from sklearn.ensemble import HistGradientBoostingRegressor
import joblib
from textwrap import wrap

# CONFIG: input workbook and output locations (relative to the working directory).
RAW_EXCEL_PATH = "dataset final.xlsx"
OUT_DIR = "content/outputs"
MODELS_DIR = "content/models"
# Make sure both output directories exist before anything is written.
for _required_dir in (OUT_DIR, MODELS_DIR):
    os.makedirs(_required_dir, exist_ok=True)

@dataclass
class PipelineOutputs:
    """Everything ``main()`` produces, bundled for interactive/notebook use.

    Holds the final datasets, the aggregated risk tiles, and the test-split
    evaluation of the cause-of-incident classifier plus the RMSE of the
    segment incident-count regressor.
    """
    df_final: pd.DataFrame  # all columns of the SPI-enriched, feature-engineered frame
    df_final_min: pd.DataFrame  # df_final restricted to the curated keep-list columns that exist
    risk_tiles: pd.DataFrame  # per-(segment_id, lat/lon bin, hour, dow, is_wet, Vehicle) aggregates
    cause_report: Dict[str, Any]  # sklearn classification_report(..., output_dict=True) on the test split
    cause_macro_f1: float  # macro F1 via f1_score (same value as cause_f1_macro; kept for compatibility)
    cause_conf_mat: np.ndarray  # confusion matrix, rows = true label, columns = predicted label
    cause_classes: List[str]  # label order of cause_conf_mat's rows/columns
    gbr_rmse: float  # held-out RMSE of the segment-rate HistGradientBoostingRegressor
    cause_accuracy: float  # test-split accuracy
    cause_precision_macro: float  # unweighted mean of per-class precision
    cause_recall_macro: float  # unweighted mean of per-class recall
    cause_f1_macro: float  # unweighted mean of per-class F1
    cause_precision_weighted: float  # support-weighted mean of per-class precision
    cause_recall_weighted: float  # support-weighted mean of per-class recall
    cause_f1_weighted: float  # support-weighted mean of per-class F1

# UTILITIES
def _titlecase(s: str) -> str:
    if s is None:
        return s
    s = str(s).strip()
    if not s:
        return s
    return string.capwords(s)

def _read_excel_first_sheet(path: str) -> pd.DataFrame:
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")
    return pd.read_excel(path, sheet_name=0)

def _drop_slope_elevation(df: pd.DataFrame) -> pd.DataFrame:
    to_drop = [c for c in df.columns if ("slope" in c.lower()) or ("elev" in c.lower())]
    return df.drop(columns=to_drop, errors="ignore")

def _build_timestamp(df: pd.DataFrame) -> pd.Series:
    ts = None
    for cand in ["Datetime","DateTime","Timestamp","DATE_TIME","date_time"]:
        if cand in df.columns:
            ts = pd.to_datetime(df[cand], errors="coerce")
            break
    if ts is None:
        if "Date" in df.columns and "Time" in df.columns:
            ts = pd.to_datetime(df["Date"].astype(str) + " " + df["Time"].astype(str), errors="coerce")
        elif "Date" in df.columns:
            ts = pd.to_datetime(df["Date"], errors="coerce")
        else:
            ts = pd.Series([pd.NaT]*len(df), index=df.index)
    return ts

def _clean_reason(series: pd.Series) -> pd.Series:
    s = series.astype(str)
    missing_like = s.str.strip().str.lower().isin({"nan","na","n/a","none","-",""})
    s = s.mask(missing_like, np.nan)
    s = s.where(s.isna(), s.str.strip().str.title())
    return s

def feature_engineer(df_raw: pd.DataFrame) -> pd.DataFrame:
    """Clean the raw incident sheet and derive the model features.

    Steps:
      * drop slope/elevation columns;
      * normalise Vehicle/Place/Position text (missing cells stay NaN);
      * clean ``Reason``;
      * build ``ts`` plus ``hour``, ``dow`` (Monday=0) and ``is_weekend``;
      * set ``is_wet`` from the first column whose name contains "precip"
        (> 0.1 -> 1);
      * bin coordinates to 3 decimals (``lat_bin``/``lon_bin``/``segment_id``);
      * flag reasons mentioning "Excessive Speed" (``is_speed_reason``);
      * coerce the weather/coordinate columns to numeric.

    Args:
        df_raw: the raw sheet as read from Excel. Not modified.

    Returns:
        A new DataFrame with the derived columns appended.
    """
    df = _drop_slope_elevation(df_raw.copy())

    # Categories: normalise only present values. Casting the whole column to
    # str first would turn missing cells into the literal category "Nan",
    # which then leaks into SPI groups, risk tiles and the one-hot encoder.
    for col in ["Vehicle", "Place", "Position"]:
        if col in df.columns:
            df[col] = df[col].map(lambda v: v if pd.isna(v) else _titlecase(v))

    # Reason
    if "Reason" in df.columns:
        df["Reason"] = _clean_reason(df["Reason"])

    # Timestamp and calendar features (NaT rows get NaN hour/dow).
    df["ts"] = _build_timestamp(df)
    ts = pd.to_datetime(df["ts"])
    df["hour"] = ts.dt.hour
    df["dow"] = ts.dt.dayofweek
    df["is_weekend"] = df["dow"].isin([5, 6]).astype(int)

    # Wet flag from the first precipitation-like column; str() guards non-str labels.
    precip_col = next((c for c in df.columns if "precip" in str(c).lower()), None)
    if precip_col is not None:
        df["is_wet"] = (pd.to_numeric(df[precip_col], errors="coerce").fillna(0) > 0.1).astype(int)
    else:
        df["is_wet"] = 0

    # Lat/Lon bins (~110 m at 3 decimals) + segment_id
    if "Latitude" in df.columns and "Longitude" in df.columns:
        df["lat_bin"] = pd.to_numeric(df["Latitude"], errors="coerce").round(3)
        df["lon_bin"] = pd.to_numeric(df["Longitude"], errors="coerce").round(3)
        df["segment_id"] = df["lat_bin"].astype(str) + "_" + df["lon_bin"].astype(str)
    else:
        df["lat_bin"] = np.nan
        df["lon_bin"] = np.nan
        df["segment_id"] = "NA"

    # Speed-related indicator. Without a Reason column fall back to an empty
    # Series (df.get(..., "") would return a plain str that has no .astype).
    reason = df["Reason"] if "Reason" in df.columns else pd.Series("", index=df.index)
    df["is_speed_reason"] = (reason.astype(str)
                             .str.contains("Excessive Speed", case=False, na=False)
                             .astype(int))

    # Ensure numeric types
    for c in ["Temperature (C)", "Humidity (%)", "Precipitation (mm)", "Wind Speed (km/h)", "Latitude", "Longitude"]:
        if c in df.columns:
            df[c] = pd.to_numeric(df[c], errors="coerce")

    return df

def build_spi(df: pd.DataFrame, alpha: float = 20.0) -> pd.DataFrame:
    """Attach a smoothed Speed Propensity Index (SPI) to every row.

    SPI is the mean of ``is_speed_reason`` within each
    (lat_bin, lon_bin, hour, is_wet, Vehicle) group, shrunk toward the global
    mean with pseudo-count *alpha*:

        SPI_smoothed = (n * spi + alpha * global_spi) / (n + alpha)

    Args:
        df: feature-engineered rows. NOT modified. Missing key columns are
            added as NaN, and a missing ``is_speed_reason`` is treated as all
            zeros, on an internal copy.
        alpha: smoothing strength; larger values pull small groups harder
            toward the global mean.

    Returns:
        A new frame: *df*'s columns plus ``SPI_smoothed``. Rows follow *df*'s
        order via a left merge, so the index is a fresh RangeIndex. NaN keys
        form their own group (``dropna=False``).
    """
    work = df.copy()  # never mutate the caller's frame
    if "is_speed_reason" not in work.columns:
        # The global mean already defaulted to 0.0 here; make agg() consistent with that.
        work["is_speed_reason"] = 0
    global_spi = float(work["is_speed_reason"].mean()) if len(work) else 0.0

    keys = ["lat_bin", "lon_bin", "hour", "is_wet", "Vehicle"]
    for k in keys:
        if k not in work.columns:
            work[k] = np.nan
    grp = (work.groupby(keys, dropna=False)
               .agg(spi=("is_speed_reason", "mean"), n=("is_speed_reason", "size"))
               .reset_index())
    grp["SPI_smoothed"] = (grp["n"] * grp["spi"] + alpha * global_spi) / (grp["n"] + alpha)
    df_spi = work.merge(grp[keys + ["SPI_smoothed"]], on=keys, how="left")
    df_spi["SPI_smoothed"] = df_spi["SPI_smoothed"].fillna(global_spi)
    return df_spi

def make_final_datasets(df_spi: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Return ``(full, minimal)`` copies of the SPI-enriched frame.

    ``full`` is a copy with every column. ``minimal`` keeps only those of the
    curated columns below that are present, in the listed order.
    """
    preferred = (
        "ts", "Date", "Time", "Latitude", "Longitude", "lat_bin", "lon_bin", "segment_id",
        "Vehicle", "Place", "Position",
        "Reason", "Description",
        "Temperature (C)", "Humidity (%)", "Precipitation (mm)", "Wind Speed (km/h)",
        "hour", "dow", "is_weekend", "is_wet", "is_speed_reason", "SPI_smoothed",
    )
    full = df_spi.copy()
    present = [name for name in preferred if name in full.columns]
    return full, full[present].copy()

def build_risk_tiles(df_spi: pd.DataFrame) -> pd.DataFrame:
    """Aggregate incidents into risk tiles.

    A tile is one (segment_id, lat_bin, lon_bin, hour, dow, is_wet, Vehicle)
    combination. NaN keys form their own tile (``dropna=False``); missing key
    columns are treated as all-NaN.

    Output columns: the keys, ``incident_count`` (rows in the tile),
    ``speed_reason_rate`` (mean of ``is_speed_reason``), ``n`` (same row
    count, kept for backward compatibility) and ``SPI_tile`` (mean
    ``SPI_smoothed``).

    Args:
        df_spi: output of ``build_spi``; must contain ``is_speed_reason`` and
            ``SPI_smoothed``. Not modified.
    """
    keys = ["segment_id", "lat_bin", "lon_bin", "hour", "dow", "is_wet", "Vehicle"]
    # assign() returns a new frame, so the caller's df_spi is left untouched.
    work = df_spi.assign(**{k: np.nan for k in keys if k not in df_spi.columns})
    agg = (work.groupby(keys, dropna=False)
               # Row count via a column the aggregation needs anyway; counting
               # via "Reason" raised KeyError when that column was absent.
               .agg(incident_count=("SPI_smoothed", "size"),
                    speed_reason_rate=("is_speed_reason", "mean"),
                    n=("is_speed_reason", "size"),
                    SPI_tile=("SPI_smoothed", "mean"))
               .reset_index())
    return agg

def plot_bar_counts(series: pd.Series, title: str, out_path: str, top_n: int = 15):
    """Save a bar chart of the *top_n* most frequent values of *series* to *out_path*.

    Nothing is written when *series* has no values. Callers pass
    ``dropna()``-ed columns that can legitimately be empty, and a zero-bar
    pandas bar plot is useless and can raise.
    """
    counts = series.value_counts().head(top_n)
    if counts.empty:
        return
    plt.figure()
    counts.plot(kind="bar")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()

def plot_hourly_counts(df: pd.DataFrame, out_path: str):
    """Save a bar chart of incident counts per hour of day to *out_path*.

    Rows with an unknown hour are excluded. Hours are cast to int before
    counting: ``hour`` becomes float whenever any timestamp failed to parse,
    which would otherwise produce tick labels like "13.0". Nothing is written
    when no row has a known hour.
    """
    hours = pd.to_numeric(df["hour"], errors="coerce").dropna().astype(int)
    counts = hours.value_counts().sort_index()
    if counts.empty:
        return
    plt.figure()
    counts.plot(kind="bar")
    plt.title("Incidents by Hour of Day")
    plt.xlabel("Hour")
    plt.ylabel("Incidents")
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()

def plot_geo_scatter(df: pd.DataFrame, out_path: str):
    """Save a longitude/latitude scatter of incident locations to *out_path*.

    Does nothing (and writes no file) unless both "Longitude" and "Latitude"
    columns are present.
    """
    if {"Longitude", "Latitude"}.issubset(df.columns):
        plt.figure()
        plt.scatter(df["Longitude"], df["Latitude"], s=10)
        plt.title("Incident Locations (Lon/Lat)")
        plt.xlabel("Longitude")
        plt.ylabel("Latitude")
        plt.tight_layout()
        plt.savefig(out_path)
        plt.close()

def plot_confusion_matrix_pretty(cm, classes, title="Cause Classifier - Confusion Matrix",
                                 normalize=False, out_path=None, dpi=220):
    """
    Draw a confusion-matrix heatmap annotated with counts and row percentages.

    cm: array-like (n_classes x n_classes); rows are true labels, columns predictions.
    classes: list[str] labelling rows/columns in the same order as 'cm'.
    normalize: if True the colour scale uses row-normalised values; the cell
        text always shows the raw count and its percentage of the row.
    out_path: file to save to; when falsy the figure is built and discarded.
    dpi: resolution used for both the figure and the saved file.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    counts = np.array(cm, dtype=float)
    size = counts.shape[0]

    # Row totals with empty rows clamped to 1 so every division stays finite.
    safe_totals = counts.sum(axis=1, keepdims=True)
    safe_totals[safe_totals == 0] = 1.0

    shown = counts / safe_totals if normalize else counts.copy()
    tick_text = ["\n".join(wrap(label, 18)) for label in classes]

    plt.figure(figsize=(9, 7), dpi=dpi)
    heat = plt.imshow(shown, interpolation="nearest", cmap="Blues")
    plt.title(title, pad=10)
    plt.colorbar(heat, fraction=0.046, pad=0.04)

    plt.xticks(range(size), tick_text, rotation=35, ha="right")
    plt.yticks(range(size), tick_text)

    # Thin minor-grid lines along the cell borders.
    axes = plt.gca()
    borders = np.arange(-.5, size, 1)
    axes.set_xticks(borders, minor=True)
    axes.set_yticks(borders, minor=True)
    axes.grid(which='minor', linestyle='-', linewidth=0.5, alpha=0.4)
    axes.tick_params(which='minor', bottom=False, left=False)

    # Dark cells (above half the colour range) get white text for contrast.
    peak = shown.max()
    cutoff = peak / 2.0 if peak > 0 else 0.5

    for row in range(size):
        for col in range(size):
            value = counts[row, col]
            share = (value / safe_totals[row, 0]) * 100.0
            ink = "white" if shown[row, col] > cutoff else "black"
            plt.text(col, row, f"{int(round(value))}\n{share:.0f}%",
                     ha="center", va="center", color=ink, fontsize=9)

    plt.ylabel("True")
    plt.xlabel("Predicted")
    plt.tight_layout()
    if out_path:
        plt.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close()

def _make_tfidf() -> TfidfVectorizer:
    """Create the TF-IDF vectorizer used for the Description text feature."""
    return TfidfVectorizer(min_df=3, ngram_range=(1, 2))


def _text_column_usable(texts: pd.Series) -> bool:
    """Return True if ``_make_tfidf()`` can build a non-empty vocabulary from *texts*."""
    try:
        _make_tfidf().fit(texts)
    except ValueError:
        # sklearn raises ValueError when no term survives tokenisation and
        # min_df pruning (e.g. all descriptions blank or very sparse).
        return False
    return True


def _build_cause_pipeline(numeric_cols: List[str], cat_cols: List[str],
                          text_col: Optional[str]) -> Pipeline:
    """Build the preprocessing + class-balanced logistic-regression pipeline."""
    num_pipe = Pipeline([
        ("impute", SimpleImputer(strategy="median")),
        ("scale", StandardScaler())
    ])
    cat_pipe = Pipeline([
        ("impute", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(handle_unknown="ignore"))
    ])
    transformers = [("num", num_pipe, numeric_cols), ("cat", cat_pipe, cat_cols)]
    if text_col:
        # A scalar column name (not a list) so the vectorizer receives 1-D text.
        transformers.append(("txt", _make_tfidf(), text_col))
    pre = ColumnTransformer(transformers=transformers, remainder="drop", sparse_threshold=0.3)
    clf = LogisticRegression(max_iter=1000, class_weight="balanced", solver="liblinear")
    return Pipeline([("pre", pre), ("clf", clf)])


def train_cause_classifier(df_final_min: pd.DataFrame):
    """Train and evaluate the cause-of-incident (``Reason``) classifier.

    Rows lacking ``Reason`` or ``ts`` are dropped. The rest are sorted by time
    and split chronologically, with the first 80% used for training and the
    last 20% for testing. Both confusion-matrix PNGs (counts and
    row-normalised) are written to OUT_DIR.

    Returns:
        (pipe, report_dict, macro_f1, conf_mat, classes, accuracy,
         precision_macro, recall_macro, f1_macro,
         precision_weighted, recall_weighted, f1_weighted)
        ``classes`` labels the rows/columns of ``conf_mat``. It is the model's
        ``classes_`` followed by any test-window reasons the model never saw
        in training, so those rows are not silently dropped from the matrix.

    Raises:
        RuntimeError: if fewer than 10 rows have both a Reason and a timestamp.
    """
    work = df_final_min.dropna(subset=["Reason", "ts"]).copy()
    work = work.sort_values("ts")
    if "Description" in work.columns:
        work["Description"] = work["Description"].fillna("").astype(str)

    n = len(work)
    if n < 10:
        raise RuntimeError("Not enough records with timestamp + reason to train the classifier.")
    split_idx = max(int(0.8 * n), 1)
    train_df = work.iloc[:split_idx]
    test_df = work.iloc[split_idx:]

    numeric_cols = [c for c in ["Temperature (C)", "Humidity (%)", "Precipitation (mm)", "Wind Speed (km/h)",
                                "hour", "dow", "is_weekend", "is_wet", "Latitude", "Longitude", "SPI_smoothed"]
                    if c in work.columns]
    cat_cols = [c for c in ["Vehicle", "Place", "Position"] if c in work.columns]
    text_col = "Description" if "Description" in work.columns else None
    # Drop the text feature when it cannot yield a vocabulary on the training
    # window; TfidfVectorizer would otherwise abort the whole fit.
    if text_col and not _text_column_usable(train_df[text_col]):
        text_col = None

    pipe = _build_cause_pipeline(numeric_cols, cat_cols, text_col)

    feature_cols = numeric_cols + cat_cols + ([text_col] if text_col else [])
    X_train, y_train = train_df[feature_cols], train_df["Reason"]
    X_test, y_test = test_df[feature_cols], test_df["Reason"]

    pipe.fit(X_train, y_train)
    y_pred = pipe.predict(X_test)

    # Full classification report (dict) + macro F1 (legacy kept)
    report_dict = classification_report(y_test, y_pred, output_dict=True, zero_division=0)
    macro_f1 = f1_score(y_test, y_pred, average="macro", zero_division=0)
    # Chronological split: the test window may contain reasons absent from
    # training. Append them so their (all-misclassified) rows stay visible.
    model_classes = list(pipe.classes_)
    unseen = sorted(set(y_test) - set(model_classes))
    classes = model_classes + unseen
    conf_mat = confusion_matrix(y_test, y_pred, labels=classes)

    # Explicit metrics
    accuracy = float(accuracy_score(y_test, y_pred))
    prec_macro, rec_macro, f1_macro, _ = precision_recall_fscore_support(
        y_test, y_pred, average="macro", zero_division=0
    )
    prec_w, rec_w, f1_w, _ = precision_recall_fscore_support(
        y_test, y_pred, average="weighted", zero_division=0
    )

    # Save matrices
    plot_confusion_matrix_pretty(
        conf_mat, classes,
        normalize=False,
        out_path=os.path.join(OUT_DIR, "confusion_matrix_counts.png")
    )
    plot_confusion_matrix_pretty(
        conf_mat, classes,
        normalize=True,
        out_path=os.path.join(OUT_DIR, "confusion_matrix_normalized.png")
    )

    return (pipe, report_dict, macro_f1, conf_mat, classes,
            accuracy, float(prec_macro), float(rec_macro), float(f1_macro),
            float(prec_w), float(rec_w), float(f1_w))

def train_segment_rate_model(risk_tiles: pd.DataFrame):
    """Fit a gradient-boosted regressor for per-tile incident counts.

    Features are hour, dow and is_wet (filled with 0 when the column is
    absent), plus one-hot Vehicle dummies including a NaN indicator. The
    model is evaluated on a random 80/20 split with ``random_state=42``.

    Returns:
        (fitted HistGradientBoostingRegressor, test RMSE as float)
    """
    # assign() returns a new frame, so the caller's tiles are not modified.
    defaults = {col: 0 for col in ("hour", "dow", "is_wet") if col not in risk_tiles.columns}
    tiles = risk_tiles.assign(**defaults)

    design = pd.get_dummies(tiles[["hour", "dow", "is_wet", "Vehicle"]],
                            columns=["Vehicle"], dummy_na=True)
    target = tiles["incident_count"].astype(float)
    X_tr, X_te, y_tr, y_te = train_test_split(design, target, test_size=0.2, random_state=42)

    model = HistGradientBoostingRegressor(max_depth=3, random_state=42)
    model.fit(X_tr, y_tr)
    rmse = float(np.sqrt(mean_squared_error(y_te, model.predict(X_te))))
    return model, rmse

def save_metrics_json(report_dict: Dict[str, Any], macro_f1: float, gbr_rmse: float, out_path: str,
                      extra_cls: Dict[str, float]):
    """Write combined metrics for both models to *out_path* as indented UTF-8 JSON.

    ``per_class`` holds every classification-report entry except the summary
    rows ("accuracy", "macro avg", "weighted avg"). Metrics missing from
    *extra_cls* are written as null.
    """
    summary_rows = {"accuracy", "macro avg", "weighted avg"}
    classifier: Dict[str, Any] = {"macro_f1": macro_f1}
    classifier["per_class"] = {label: stats for label, stats in report_dict.items()
                               if label not in summary_rows}
    classifier["macro_avg"] = report_dict.get("macro avg", {})
    classifier["weighted_avg"] = report_dict.get("weighted avg", {})
    classifier["accuracy"] = report_dict.get("accuracy", None)
    classifier["accuracy_explicit"] = extra_cls.get("accuracy")
    for metric in ("precision_macro", "recall_macro", "f1_macro",
                   "precision_weighted", "recall_weighted", "f1_weighted"):
        classifier[metric] = extra_cls.get(metric)

    payload = {"cause_classifier": classifier, "segment_rate_model": {"rmse": gbr_rmse}}
    with open(out_path, "w", encoding="utf-8") as fh:
        json.dump(payload, fh, indent=2)

def main() -> PipelineOutputs:
    """Run the full training pipeline end to end and return its artifacts.

    Steps:
      * read RAW_EXCEL_PATH and engineer features plus the SPI;
      * write the final datasets and risk tiles as CSV, and the EDA plots,
        to OUT_DIR;
      * train both models and save them with joblib to MODELS_DIR;
      * write classification_metrics.json and metrics.json;
      * print a console summary and return everything as PipelineOutputs.
    """
    df_raw = _read_excel_first_sheet(RAW_EXCEL_PATH)
    df_spi = build_spi(feature_engineer(df_raw), alpha=20.0)

    df_full, df_min = make_final_datasets(df_spi)
    risk_tiles = build_risk_tiles(df_spi)

    # Tabular outputs.
    final_full_path = os.path.join(OUT_DIR, "final_dataset.csv")
    final_min_path = os.path.join(OUT_DIR, "final_dataset_min.csv")
    risk_tiles_path = os.path.join(OUT_DIR, "risk_tiles.csv")
    for frame, csv_path in ((df_full, final_full_path),
                            (df_min, final_min_path),
                            (risk_tiles, risk_tiles_path)):
        frame.to_csv(csv_path, index=False)

    # Exploratory plots; each is skipped when its source column is absent.
    for column, chart_title, file_name in (("Reason", "Top Reasons", "top_reasons.png"),
                                           ("Vehicle", "Top Vehicles", "top_vehicles.png")):
        if column in df_min.columns:
            plot_bar_counts(df_min[column].dropna(), chart_title, os.path.join(OUT_DIR, file_name))
    if "hour" in df_min.columns:
        plot_hourly_counts(df_min, os.path.join(OUT_DIR, "incidents_by_hour.png"))
    plot_geo_scatter(df_min, os.path.join(OUT_DIR, "geo_scatter.png"))

    # Cause-of-incident classifier.
    (pipe, report_dict, macro_f1, conf_mat, classes,
     acc, p_macro, r_macro, f_macro, p_w, r_w, f_w) = train_cause_classifier(df_min)
    cause_model_path = os.path.join(MODELS_DIR, "cause_classifier.joblib")
    joblib.dump(pipe, cause_model_path)

    print("\n=== Cause Classifier — Test Metrics ===")
    print(f"Accuracy        : {acc:.4f}")
    print(f"Precision (macro): {p_macro:.4f} | Recall (macro): {r_macro:.4f} | F1 (macro): {f_macro:.4f}")
    print(f"Precision (weighted): {p_w:.4f} | Recall (weighted): {r_w:.4f} | F1 (weighted): {f_w:.4f}")

    # One metrics dict feeds both JSON outputs.
    cls_metrics = {
        "accuracy": acc,
        "precision_macro": p_macro,
        "recall_macro": r_macro,
        "f1_macro": f_macro,
        "precision_weighted": p_w,
        "recall_weighted": r_w,
        "f1_weighted": f_w,
    }
    cls_metrics_path = os.path.join(OUT_DIR, "classification_metrics.json")
    with open(cls_metrics_path, "w", encoding="utf-8") as fh:
        json.dump({**cls_metrics, "classes": classes}, fh, indent=2)
    print(f"[Info] Classification metrics saved to: {cls_metrics_path}")

    # Segment incident-rate regressor.
    gbr, rmse = train_segment_rate_model(risk_tiles)
    rate_model_path = os.path.join(MODELS_DIR, "segment_gbr.joblib")
    joblib.dump(gbr, rate_model_path)

    save_metrics_json(report_dict, macro_f1, rmse, os.path.join(OUT_DIR, "metrics.json"),
                      extra_cls=cls_metrics)

    print("\n=== Summary ===")
    print(f"Rows (raw): {len(df_raw)}")
    print(f"Rows (final): {len(df_min)} | Columns (final): {df_min.shape[1]}")
    print(f"Cause classifier Macro-F1: {macro_f1:.3f}")
    print(f"Segment rate RMSE: {rmse:.3f}")
    print("Saved:")
    for saved_path in (final_full_path, final_min_path, risk_tiles_path,
                       cause_model_path, rate_model_path):
        print(f" - {saved_path}")
    print(f" - Plots in {OUT_DIR}")
    return PipelineOutputs(
        df_final=df_full,
        df_final_min=df_min,
        risk_tiles=risk_tiles,
        cause_report=report_dict,
        cause_macro_f1=macro_f1,
        cause_conf_mat=conf_mat,
        cause_classes=classes,
        gbr_rmse=rmse,
        cause_accuracy=acc,
        cause_precision_macro=p_macro,
        cause_recall_macro=r_macro,
        cause_f1_macro=f_macro,
        cause_precision_weighted=p_w,
        cause_recall_weighted=r_w,
        cause_f1_weighted=f_w
    )

if __name__ == "__main__":
    # Run the pipeline with every warning muted, restoring the filters afterwards.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        main()



=== Cause Classifier — Test Metrics ===
Accuracy        : 0.9412
Precision (macro): 0.6548 | Recall (macro): 0.7308 | F1 (macro): 0.6839
Precision (weighted): 0.9356 | Recall (weighted): 0.9412 | F1 (weighted): 0.9347
[Info] Classification metrics saved to: content/outputs/classification_metrics.json

=== Summary ===
Rows (raw): 315
Rows (final): 315 | Columns (final): 23
Cause classifier Macro-F1: 0.684
Segment rate RMSE: 0.109
Saved:
 - content/outputs/final_dataset.csv
 - content/outputs/final_dataset_min.csv
 - content/outputs/risk_tiles.csv
 - content/models/cause_classifier.joblib
 - content/models/segment_gbr.joblib
 - Plots in content/outputs


In [2]:
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Any, Optional
import math
import numpy as np
import pandas as pd
import joblib
from datetime import datetime

# Config: artifact locations written by the training cell.
CAUSE_MODEL_PATH: str = "content/models/cause_classifier.joblib"  # fitted sklearn Pipeline
RATE_MODEL_PATH: str = "content/models/segment_gbr.joblib"        # fitted segment-rate regressor
RISK_TILES_CSV: str = "content/outputs/risk_tiles.csv"            # aggregated tiles incl. SPI_tile

def _title(s: Optional[str]) -> str:
    return (s or "").strip().title()

def _sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

@dataclass
class RiskInputs:
    """One scoring query for ``RiskModel``.

    The weather fields are copied unchanged into the classifier's feature row
    ("Temperature (C)", "Humidity (%)", "Precipitation (mm)",
    "Wind Speed (km/h)"). ``None`` values are presumably imputed by the
    fitted pipeline (verify against the saved model).
    """
    lat: float  # latitude; rounded to 3 decimals for the SPI tile lookup
    lon: float  # longitude; rounded to 3 decimals for the SPI tile lookup
    dt: datetime  # query time; hour and weekday() (Monday=0) are derived from it
    is_wet: int  # truthy when the road is wet; also applies the 1.25 weather multiplier in score_risk
    vehicle: str  # vehicle category, e.g. "Bus"; case-normalised by RiskModel before matching
    place: Optional[str] = None  # place name; None becomes "" in the classifier row
    position: Optional[str] = None  # road position, e.g. "Bend"; None becomes ""
    temperature_c: Optional[float] = None  # air temperature in °C
    humidity_pct: Optional[float] = None  # relative humidity in %
    precip_mm: Optional[float] = None  # precipitation in mm
    wind_kmh: Optional[float] = None  # wind speed in km/h
    description: Optional[str] = ""  # free-text description fed to the classifier; None becomes ""

class RiskModel:
    """Combine the trained cause classifier, the segment-rate regressor and the
    risk-tile SPI table into a single 0-100 risk score.

    Category strings (vehicle, place, position) are normalised with
    ``_canon``, which reproduces the training-side ``_titlecase``
    (``string.capwords`` after stripping). Plain ``str.title`` disagrees
    with training for hyphenated, apostrophised or multi-space values
    (e.g. "Three-wheeler" vs "Three-Wheeler"). That mismatch would silently
    zero the classifier's one-hot features and the rate model's Vehicle
    dummies, and would miss SPI tiles.
    """

    def __init__(self,
                 cause_model_path: str = CAUSE_MODEL_PATH,
                 rate_model_path: str = RATE_MODEL_PATH,
                 risk_tiles_csv: str = RISK_TILES_CSV):
        """Load both joblib models and the risk-tile CSV, and precompute calibrators.

        ``global_spi`` is the mean SPI_tile (0.1 when the column is absent).
        ``rate_scale`` is the 95th percentile of incident_count (5.0 when
        absent, clamped to 1.0 if non-positive) and normalises predicted
        rates into [0, 1].
        """
        # Load models
        self.cause_pipe = joblib.load(cause_model_path)
        self.rate_model = joblib.load(rate_model_path)    # Histogram based Gradient Boosting Regressor
        # Load risk tiles for SPI lookup & calibration
        self.tiles = pd.read_csv(risk_tiles_csv)
        # Precompute a few calibrators
        self.global_spi = float(self.tiles["SPI_tile"].mean()) if "SPI_tile" in self.tiles else 0.1
        # scale for rate normalization
        self.rate_scale = float(np.quantile(self.tiles["incident_count"], 0.95)) if "incident_count" in self.tiles else 5.0
        if self.rate_scale <= 0:
            self.rate_scale = 1.0

    @staticmethod
    def _canon(s: Optional[str]) -> str:
        """Normalise a category label exactly like training does; ``None`` -> ``""``."""
        if s is None:
            return ""
        # Equivalent to string.capwords(str(s).strip()): split on whitespace,
        # capitalise each word, re-join with single spaces.
        return " ".join(word.capitalize() for word in str(s).split())

    # SPI lookup
    def _spi_lookup(self, lat: float, lon: float, hour: int, dow: int, is_wet: int, vehicle: str) -> float:
        """Return the mean SPI_tile of the most specific matching tile group.

        Fallback order: cell+hour+dow -> cell+hour -> cell (any hour) ->
        vehicle+wetness anywhere -> ``global_spi``. Every level requires the
        same ``is_wet`` and vehicle. Cells are lat/lon rounded to 3 decimals.
        """
        t = self.tiles
        v = self._canon(vehicle)
        latb, lonb = round(float(lat), 3), round(float(lon), 3)

        # Build each base condition once instead of re-deriving it (and
        # re-normalising the whole Vehicle column) at every fallback level.
        same_vehicle_wet = t["is_wet"].eq(is_wet) & t["Vehicle"].astype(str).map(self._canon).eq(v)
        same_cell = t["lat_bin"].round(3).eq(latb) & t["lon_bin"].round(3).eq(lonb)
        same_hour = t["hour"].eq(hour)
        same_dow = t["dow"].eq(dow)

        levels = (
            same_cell & same_hour & same_dow & same_vehicle_wet,  # exact match incl. day of week
            same_cell & same_hour & same_vehicle_wet,             # ignore day of week
            same_cell & same_vehicle_wet,                         # ignore hour
            same_vehicle_wet,                                     # vehicle + wet anywhere
        )
        for mask in levels:
            cand = t.loc[mask, "SPI_tile"]
            if not cand.empty:
                return float(cand.mean())

        # global average
        return self.global_spi

    def _build_classifier_row(self, x: RiskInputs) -> pd.DataFrame:
        """Build the one-row feature frame expected by the cause classifier."""
        hour = int(x.dt.hour)
        dow = int(x.dt.weekday())
        is_weekend = 1 if dow in (5, 6) else 0
        spi = self._spi_lookup(x.lat, x.lon, hour, dow, x.is_wet, x.vehicle)

        row = {
            "Temperature (C)": x.temperature_c,
            "Humidity (%)": x.humidity_pct,
            "Precipitation (mm)": x.precip_mm,
            "Wind Speed (km/h)": x.wind_kmh,
            "Latitude": x.lat,
            "Longitude": x.lon,
            "hour": hour,
            "dow": dow,
            "is_weekend": is_weekend,
            "is_wet": int(x.is_wet),
            "SPI_smoothed": spi,
            "Vehicle": self._canon(x.vehicle),
            "Place": self._canon(x.place),
            "Position": self._canon(x.position),
            "Description": (x.description or "")
        }
        return pd.DataFrame([row])

    def predict_cause(self, x: RiskInputs) -> Dict[str, Any]:
        """Return the most likely cause, its probability, and all class probabilities."""
        row = self._build_classifier_row(x)
        # Predict probability with the fitted pipeline
        proba = self.cause_pipe.predict_proba(row)[0]
        classes = list(self.cause_pipe.classes_)
        idx = int(np.argmax(proba))
        return {
            "top_cause": classes[idx],
            "p_top_cause": float(proba[idx]),
            "probs": {c: float(p) for c, p in zip(classes, proba)}
        }

    def predict_segment_rate(self, x: RiskInputs) -> float:
        """Predict the tile incident count for this context, clipped at 0."""
        # Rate model features: hour, DoW, is_wet, Vehicle
        base = pd.DataFrame([{
            "hour": int(x.dt.hour),
            "dow": int(x.dt.weekday()),
            "is_wet": int(x.is_wet),
            "Vehicle": self._canon(x.vehicle)
        }])
        X = pd.get_dummies(base, columns=["Vehicle"], dummy_na=True)
        # Align columns to the training design: add absent dummies as 0 and
        # drop dummies the model never saw (e.g. an unknown vehicle).
        needed = list(self.rate_model.feature_names_in_)
        for col in needed:
            if col not in X.columns:
                X[col] = 0
        X = X[needed]
        rate = float(max(0.0, self.rate_model.predict(X)[0]))
        return rate

    def score_risk(self, x: RiskInputs) -> Dict[str, Any]:
        """Combine cause confidence and predicted rate into a 0-100 risk score.

        risk = 100 * (0.6 * sigmoid(5 * (p_top - 0.5)) + 0.4 * min(1, rate / rate_scale))
               * vehicle multiplier * weather multiplier, clipped to [0, 100].
        """
        c = self.predict_cause(x)
        r = self.predict_segment_rate(x)

        # Normalize and combine
        cause_component = _sigmoid(5.0 * (c["p_top_cause"] - 0.5))
        rate_component = min(1.0, r / max(1e-6, self.rate_scale))

        # Light context weights + vehicle/weather multipliers
        base = 0.6 * cause_component + 0.4 * rate_component
        S_vehicle = {"Motor Cycle": 1.2, "Three Wheeler": 1.1}.get(self._canon(x.vehicle), 1.0)
        W_weather = 1.25 if x.is_wet else 1.0

        risk = 100.0 * base * S_vehicle * W_weather
        return {
            "risk_0_100": float(min(100.0, max(0.0, round(risk, 1)))),
            "top_cause": c["top_cause"],
            "p_top_cause": round(c["p_top_cause"], 3),
            "rate_pred": round(r, 3),
            "components": {
                "cause_component": round(cause_component, 3),
                "rate_component": round(rate_component, 3),
                "S_vehicle": S_vehicle,
                "W_weather": W_weather
            }
        }

# Example
if __name__ == "__main__":
    model = RiskModel()
    # Demo query: a bus on a wet road, scored at the current time.
    ctx = RiskInputs(
        lat=7.0192, lon=80.4943,
        dt=datetime.now(),
        is_wet=1,
        vehicle="Bus",
        place="Ginigathena",
        position="Bend",
        temperature_c=24.0,
        humidity_pct=92.0,
        precip_mm=0.6,
        wind_kmh=12.0,
        description=""
    )
    for label, scorer in (("Cause:", model.predict_cause),
                          ("Rate :", model.predict_segment_rate),
                          ("Risk :", model.score_risk)):
        print(label, scorer(ctx))


Cause: {'top_cause': 'Excessive Speed', 'p_top_cause': 0.7911969378894055, 'probs': {'Excessive Speed': 0.7911969378894055, 'Mechanical Error': 0.03227734344442165, 'Slipped': 0.068024801995568, 'Slipped, Excessive Speed': 0.10850091667060487}}
Rate : 1.0366130175658315
Risk : {'risk_0_100': 100.0, 'top_cause': 'Excessive Speed', 'p_top_cause': 0.791, 'rate_pred': 1.037, 'components': {'cause_component': 0.811, 'rate_component': 1.0, 'S_vehicle': 1.0, 'W_weather': 1.25}}
