# In [None]:
import os
import math
import json
import copy
import random
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import mean_squared_error, r2_score

import psycopg2
from psycopg2.extras import execute_values

# =====================================================================================
# CONFIG / CONSTANTS
# =====================================================================================

# Postgres connection settings. Every value can be overridden through an
# EWARS_PG_* environment variable; the literals are fallbacks kept so existing
# runs behave exactly as before.
# NOTE(review): the fallback password is a secret committed to source control.
# Rotate it and remove the literal once all deployments set EWARS_PG_PASS.
PG_HOST    = os.environ.get("EWARS_PG_HOST", "119.148.17.102")
PG_PORT    = int(os.environ.get("EWARS_PG_PORT", "5432"))
PG_DB      = os.environ.get("EWARS_PG_DB", "ewarsdb")
PG_USER    = os.environ.get("EWARS_PG_USER", "ewars")
PG_PASS    = os.environ.get("EWARS_PG_PASS", "Iedcr@Ewars2025")
PG_SSLMODE = os.environ.get("EWARS_PG_SSLMODE", "require")

# Output directory, created eagerly at import time.
OUT = "/content/diarrhoea_out"
os.makedirs(OUT, exist_ok=True)

SEED=42
DEVICE="cuda" if torch.cuda.is_available() else "cpu"

# sequence/training
SEQ=30              # lookback window length (days)
BATCH=128
EPOCHS=300
PATIENCE=25         # epochs without validation improvement before stopping
LSTM_UNITS=160
HEADS=8
DROP=0.2
LR=8e-4
WD=1e-4
CLIP=1.0            # max gradient norm
TF_START=1.0        # teacher-forcing weight at the first epoch
TF_END=0.35         # teacher-forcing weight at the last epoch
Q=(0.1,0.5,0.9)     # quantile levels predicted by the model heads

# forecast horizon for per-district roll-forward (days)
HORIZON=15

# feature engineering params
CASE_LAGS = (1,2,3,7,14,21,28)
WEA_LAGS  = (1,2,3,7)
CASE_ROLL = (7,14)
WEA_ROLL  = (7,)

# director table logic / thresholds (reuse dengue style)
DEFAULT_CAPACITY_PER_DISTRICT = 25
RAPID_GROWTH_THRESH = 0.30       # >=30% wow growth => red
MODERATE_GROWTH_THRESH = 0.10    # >=10% wow growth => orange

PEAK_LOOKAHEAD_DAYS = 15         # look-ahead window for "peak projection" (use horizon)
CONF_HIGH = 5        # width<=5 => high confidence
CONF_MED  = 15       # width<=15 => medium

# random seeding for reproducibility
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Trade cuDNN autotuning for deterministic kernels.
    torch.backends.cudnn.benchmark=False
    torch.backends.cudnn.deterministic=True


# =====================================================================================
# PG HELPERS
# =====================================================================================

def get_db_conn():
    """Open and return a new psycopg2 connection built from the PG_* settings.

    The caller owns the connection and is responsible for closing it.
    """
    # Keyword arguments instead of a hand-assembled "key=value" DSN string:
    # in a DSN, values containing spaces, quotes or backslashes must be quoted
    # and escaped, which the previous f-string did not do.
    return psycopg2.connect(
        host=PG_HOST,
        port=PG_PORT,
        dbname=PG_DB,
        user=PG_USER,
        password=PG_PASS,
        sslmode=PG_SSLMODE,
    )

def to_python(obj):
    """Recursively convert numpy/pandas values into plain Python objects.

    Conversions:
      * numpy floats / Python floats -> float, with NaN mapped to None
      * numpy integers -> int, numpy bools -> bool
      * numpy arrays, lists and tuples -> list (elements converted)
      * dicts -> dict with converted values (keys are left untouched)
      * pd.Timestamp -> ISO-8601 string, pd.NaT -> None

    Anything else is returned unchanged.
    """
    # pd.NaT is not an instance of pd.Timestamp, so it must be caught
    # explicitly; otherwise it falls through and is returned as-is.
    if obj is pd.NaT:
        return None

    if isinstance(obj, np.floating):
        val = float(obj)
        return None if math.isnan(val) else val
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)

    if isinstance(obj, np.ndarray):
        return [to_python(x) for x in obj.tolist()]

    if isinstance(obj, pd.Timestamp):
        return obj.isoformat()

    if isinstance(obj, float):
        return None if math.isnan(obj) else obj

    if isinstance(obj, dict):
        return {k: to_python(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_python(v) for v in obj]

    return obj

def py(v):
    """Convert a single numpy scalar to its plain Python equivalent.

    NaN floats (numpy or built-in) become None; values of any other type are
    returned unchanged.
    """
    if isinstance(v, np.bool_):
        return bool(v)
    if isinstance(v, np.integer):
        return int(v)
    if isinstance(v, np.floating):
        v = float(v)
    # Reached by converted numpy floats as well as built-in floats.
    if isinstance(v, float) and math.isnan(v):
        return None
    return v

def ensure_output_tables_diarrhoea(conn):
    """Create the diarrhoea_* output tables if they do not exist yet.

    Runs one ``CREATE TABLE IF NOT EXISTS`` per table on a single cursor and
    commits once at the end.

    Args:
        conn: DB-API connection exposing ``cursor()`` and ``commit()``.

    Raises:
        Whatever the driver raises for a failing statement. The cursor is
        closed in that case too; nothing is committed.
    """
    # identical schemas to dengue_* but prefixed diarrhoea_*
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS diarrhoea_watchlist (
            district TEXT NOT NULL,
            year INT,
            epi_week INT,
            expected_cases_next_week NUMERIC,
            high_scenario_p90 NUMERIC,
            status TEXT,
            created_at TIMESTAMP DEFAULT now(),
            PRIMARY KEY (district, year, epi_week)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS diarrhoea_overflow_risk (
            district TEXT NOT NULL,
            year INT,
            epi_week INT,
            capacity_threshold_beds_per_week INT,
            forecast_median_next_week NUMERIC,
            high_scenario_p90 NUMERIC,
            breach_risk_flag TEXT,
            created_at TIMESTAMP DEFAULT now(),
            PRIMARY KEY (district, year, epi_week)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS diarrhoea_acceleration_alerts (
            district TEXT NOT NULL,
            year INT,
            epi_week INT,
            last_week_cases NUMERIC,
            this_week_actual NUMERIC,
            this_week_predicted NUMERIC,
            next_week_forecast NUMERIC,
            growth_rate_wow NUMERIC,
            growth_flag TEXT,
            created_at TIMESTAMP DEFAULT now(),
            PRIMARY KEY (district, year, epi_week)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS diarrhoea_confidence (
            district TEXT NOT NULL,
            year INT,
            epi_week INT,
            forecast_next_week NUMERIC,
            uncertainty_width NUMERIC,
            confidence_flag TEXT,
            created_at TIMESTAMP DEFAULT now(),
            PRIMARY KEY (district, year, epi_week)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS diarrhoea_surveillance_quality (
            district TEXT NOT NULL,
            year INT,
            epi_week INT,
            reporting_continuity_pct NUMERIC,
            data_quality_flag TEXT,
            created_at TIMESTAMP DEFAULT now(),
            PRIMARY KEY (district, year, epi_week)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS diarrhoea_peak_projection (
            district TEXT NOT NULL,
            year INT,
            epi_week INT,
            peak_lead_time_weeks NUMERIC,
            peak_cases_median NUMERIC,
            peak_cases_high_p90 NUMERIC,
            peak_when TEXT,
            created_at TIMESTAMP DEFAULT now(),
            PRIMARY KEY (district, year, epi_week)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS diarrhoea_isochrone_spread (
            district TEXT NOT NULL,
            year INT,
            epi_week INT,
            forecast_next_week_median NUMERIC,
            forecast_next_week_hi NUMERIC,
            growth_flag TEXT,
            created_at TIMESTAMP DEFAULT now(),
            PRIMARY KEY (district, year, epi_week)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS diarrhoea_nowcast_gap (
            district TEXT NOT NULL,
            year INT,
            epi_week INT,
            current_week_actual NUMERIC,
            current_week_predicted_from_prev NUMERIC,
            nowcast_gap_percent NUMERIC,
            created_at TIMESTAMP DEFAULT now(),
            PRIMARY KEY (district, year, epi_week)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS diarrhoea_climate_influence (
            feature TEXT,
            base_var TEXT,
            lag_info TEXT,
            pearson_corr_with_next_week_forecast NUMERIC,
            abs_corr NUMERIC,
            created_at TIMESTAMP DEFAULT now()
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS diarrhoea_exec_summary (
            id SERIAL PRIMARY KEY,
            summary JSONB,
            created_at TIMESTAMP DEFAULT now()
        );
        """,
    )

    cur = conn.cursor()
    try:
        for ddl in ddl_statements:
            cur.execute(ddl)
        conn.commit()
    finally:
        # Release the cursor even when a statement fails.
        cur.close()

def _dedupe_by_pk(rows, cols, pk_cols):
    """Collapse rows that share the same primary-key values.

    For duplicated keys the last row wins; output order follows the first
    appearance of each key. With no ``pk_cols`` the input is returned as-is.
    """
    if not pk_cols:
        return rows
    key_positions = [cols.index(name) for name in pk_cols]
    # A dict keeps first-insertion order while later rows overwrite the value.
    latest_by_key = {
        tuple(record[pos] for pos in key_positions): record
        for record in rows
    }
    return list(latest_by_key.values())

def upsert_table(conn, table_name, cols, rows, pk_cols=None, wipe_first=False):
    """Bulk-insert ``rows`` into ``table_name``, upserting on ``pk_cols``.

    Args:
        conn: psycopg2 connection; committed after each step.
        table_name: target table. Interpolated into the SQL text unquoted, so
            it must be a trusted identifier.
        cols: column names matching the order of values in each row.
        rows: sequence of row tuples. Nothing is inserted when empty.
        pk_cols: conflict-target columns. When falsy a plain INSERT is issued.
        wipe_first: TRUNCATE the table before inserting (runs even when
            ``rows`` is empty).
    """
    if wipe_first:
        with conn.cursor() as cur:
            cur.execute(f"TRUNCATE TABLE {table_name};")
        conn.commit()
    if not rows:
        return

    insert_sql = f"INSERT INTO {table_name} ({', '.join(cols)}) VALUES %s"
    if pk_cols:
        # Two rows with the same key in one INSERT ... ON CONFLICT statement
        # are rejected by Postgres, so keep only the last row per key.
        rows = _dedupe_by_pk(rows, cols, pk_cols)
        conflict_target = ", ".join(pk_cols)
        set_updates = ", ".join(
            f"{c}=EXCLUDED.{c}" for c in cols if c not in pk_cols
        )
        if set_updates:
            action = f"DO UPDATE SET {set_updates}"
        else:
            # Every column is part of the key: an empty SET list is invalid
            # SQL and there is nothing to update anyway.
            action = "DO NOTHING"
        insert_sql = f"{insert_sql} ON CONFLICT ({conflict_target}) {action};"

    with conn.cursor() as cur:
        execute_values(cur, insert_sql, rows)
    conn.commit()


# =====================================================================================
# LOAD / PREP DAILY DATA FROM DB
# =====================================================================================

# Columns that load_daily_from_db() requires in the query result; a missing
# one makes it raise ValueError.
REQ = ["division","district","date","daily_cases","temperature","humidity","rainfall"]

def load_daily_from_db():
    """Load and clean the daily diarrhoea + weather series from Postgres.

    Steps:
      1. Read ``awd_weather`` (rows with a NULL date are excluded in SQL).
      2. Normalise column names and types; unparseable dates are dropped.
      3. Aggregate duplicate (division, district, date) rows: cases are
         summed, weather variables averaged.
      4. Pad each (division, district) group to a gap-free daily timeline
         between its first and last date; weather is forward/back-filled and
         missing case counts become 0.
      5. Add day-of-year / day-of-week indices and their sin/cos encodings.

    Returns:
        pandas.DataFrame with the REQ columns plus ``doy``, ``dow``,
        ``doy_sin``, ``doy_cos``, ``dow_sin`` and ``dow_cos``.

    Raises:
        ValueError: if a required column is missing or no usable rows remain.
    """
    conn = get_db_conn()
    try:
        sql = """
            SELECT
                division,
                district,
                date,
                daily_cases,
                temperature,
                humidity,
                rainfall
            FROM awd_weather
            WHERE date IS NOT NULL
            ORDER BY district, date;
        """
        df = pd.read_sql(sql, conn)
    finally:
        conn.close()

    # normalize cols
    df.columns = [c.strip().lower() for c in df.columns]

    # sanity
    missing = [c for c in REQ if c not in df.columns]
    if missing:
        raise ValueError(f"DB missing columns {missing}")

    # types
    for cat in ["division","district"]:
        df[cat] = df[cat].astype(str).str.strip()
    # infer_datetime_format is deprecated (pandas >= 2.0) and has no effect
    # there, so it is no longer passed.
    df["date"] = pd.to_datetime(df["date"], errors="coerce", utc=False)
    for c in ["daily_cases","temperature","humidity","rainfall"]:
        df[c] = pd.to_numeric(df[c], errors="coerce")

    # Rows whose date could not be parsed were previously dropped implicitly
    # by the groupby below (NaT keys are excluded); make that explicit.
    df = df.dropna(subset=["date"])

    # aggregate duplicates (sum cases, mean weather)
    df = df.groupby(["division","district","date"], as_index=False).agg({
        "daily_cases":"sum",
        "temperature":"mean",
        "humidity":"mean",
        "rainfall":"mean",
    })

    # pad per district to continuous daily timeline
    parts = []
    for (div, dis), g in df.groupby(["division","district"], sort=False):
        g = g.sort_values("date")
        full_idx = pd.date_range(g["date"].min(), g["date"].max(), freq="D")
        g = g.set_index("date").reindex(full_idx)
        g.index.name = "date"
        g = g.reset_index()
        # reindex leaves the identifier columns NaN on padded days
        g["division"] = div
        g["district"] = dis

        # fill exogenous forward/back
        for c in ["temperature","humidity","rainfall"]:
            g[c] = g[c].ffill().bfill()

        # missing cases -> 0
        g["daily_cases"] = g["daily_cases"].fillna(0.0)

        parts.append(g)

    if not parts:
        # pd.concat([]) would raise an opaque "No objects to concatenate".
        raise ValueError("awd_weather returned no usable rows")

    df = pd.concat(parts, ignore_index=True)

    # time features
    df["doy"] = df["date"].dt.dayofyear.astype(int)         # 1..366
    df["dow"] = df["date"].dt.dayofweek.astype(int)         # 0..6
    df["doy_sin"] = np.sin(2*np.pi*df["doy"].astype(float)/366.0).astype(np.float32)
    df["doy_cos"] = np.cos(2*np.pi*df["doy"].astype(float)/366.0).astype(np.float32)
    df["dow_sin"] = np.sin(2*np.pi*df["dow"].astype(float)/7.0).astype(np.float32)
    df["dow_cos"] = np.cos(2*np.pi*df["dow"].astype(float)/7.0).astype(np.float32)

    return df


# =====================================================================================
# FEATURE ENGINEERING (LAGS / ROLLS)
# =====================================================================================

def add_lags_rolls_daily(df):
    """Add per-district lag and rolling-window features.

    Works on a copy sorted by (district, date). Adds, in this order:
      * ``daily_cases_lag{L}`` for L in CASE_LAGS
      * ``daily_cases_rmean{R}`` / ``daily_cases_rstd{R}`` for R in CASE_ROLL
      * for each of temperature, humidity, rainfall: ``{var}_lag{L}`` for L in
        WEA_LAGS, then ``{var}_rmean{R}`` for R in WEA_ROLL

    NaNs in the new columns are replaced by 0.0.

    Returns:
        (DataFrame with the new columns, list of new column names in creation
        order).
    """
    out = df.sort_values(["district","date"]).copy()
    added = []

    def per_district(column):
        # Fresh groupby each time so it always reflects the current frame.
        return out.groupby("district")[column]

    def unstack_group_level(rolled):
        # groupby().rolling() prepends the group key to the index; drop it so
        # the result aligns with ``out`` on the original row index.
        return rolled.reset_index(level=0, drop=True)

    # case lags
    for lag in CASE_LAGS:
        name = f"daily_cases_lag{lag}"
        out[name] = per_district("daily_cases").shift(lag)
        added.append(name)

    # case rolling mean / std
    # NOTE(review): these windows are not shifted, so they include the current
    # day's daily_cases value -- confirm this is intended when daily_cases is
    # also the prediction target.
    for window in CASE_ROLL:
        mean_name = f"daily_cases_rmean{window}"
        std_name = f"daily_cases_rstd{window}"
        roller = per_district("daily_cases").rolling(window, min_periods=1)
        out[mean_name] = unstack_group_level(roller.mean())
        # std of a single observation is NaN -> treat as zero spread
        out[std_name] = unstack_group_level(roller.std()).fillna(0.0)
        added.extend([mean_name, std_name])

    # weather lags, then rolling means, variable by variable
    for var in ["temperature","humidity","rainfall"]:
        for lag in WEA_LAGS:
            name = f"{var}_lag{lag}"
            out[name] = per_district(var).shift(lag)
            added.append(name)

        for window in WEA_ROLL:
            name = f"{var}_rmean{window}"
            roller = per_district(var).rolling(window, min_periods=1)
            out[name] = unstack_group_level(roller.mean())
            added.append(name)

    out[added] = out[added].fillna(0.0)
    return out, added


# =====================================================================================
# MATRIX BUILDERS / DATASETS
# =====================================================================================

class Split:
    """Plain container bundling the arrays of one data split.

    Attributes are stored exactly as passed in:
      X, y     -- feature and target matrices
      c        -- district ids
      doy, dow -- day-of-year / day-of-week indices
      df       -- the DataFrame the arrays were derived from
    """
    def __init__(self, X, y, c, doy, dow, df):
        self.X = X
        self.y = y
        self.c = c
        self.doy = doy
        self.dow = dow
        self.df = df

def build_mats(df, feat, target="daily_cases"):
    """Fit encoders/scalers on ``df`` and return the scaled training matrices.

    * Districts are label-encoded into ``district_id``.
    * Features ``feat`` are standardised with one global StandardScaler.
    * The target is clipped at 0, log1p-transformed, then standardised with a
      separate StandardScaler per district.

    Args:
        df: frame containing ``district``, ``doy``, ``dow``, the ``feat``
            columns and ``target``. Any index is accepted.
        feat: list of feature column names.
        target: name of the target column.

    Returns:
        (Split, LabelEncoder, feature StandardScaler,
         dict mapping int district id -> target StandardScaler)
    """
    le = LabelEncoder().fit(df["district"].values)

    d = df.copy()
    d["district_id"] = le.transform(d["district"])
    district_ids = d["district_id"].values.astype(np.int64)

    X = d[feat].values.astype(np.float32)
    feat_scaler = StandardScaler().fit(X)
    Xs = feat_scaler.transform(X)

    y = np.log1p(np.clip(d[target].values.reshape(-1,1), 0, None)).astype(np.float32)
    Ys = np.zeros_like(y, np.float32)
    y_scalers = {}
    for cid in np.unique(district_ids):
        # Positional row numbers, not index labels: ``y`` is a plain numpy
        # array, so labels are only valid positions when the frame happens to
        # carry a 0..n-1 index.
        rows = np.flatnonzero(district_ids == cid)
        sc = StandardScaler().fit(y[rows])
        y_scalers[int(cid)] = sc
        Ys[rows] = sc.transform(y[rows])

    doy = (d["doy"].values.astype(int)-1).astype(np.int64)  # 0..365
    dow = d["dow"].values.astype(np.int64)                  # 0..6
    return Split(Xs, Ys, district_ids, doy, dow, d), le, feat_scaler, y_scalers

def apply_mats(df, feat, target, le, feat_scaler, y_scalers):
    """Transform ``df`` with already-fitted encoders/scalers.

    Mirrors build_mats() without refitting: districts go through ``le``,
    features through ``feat_scaler``, and the clipped log1p target through the
    per-district scaler in ``y_scalers``. A district id without a stored
    scaler gets one fitted on its own rows of ``df``.

    Args:
        df: frame with the same columns build_mats() expects. Any index is
            accepted.
        feat, target: feature column names and target column name.
        le, feat_scaler, y_scalers: objects returned by build_mats().

    Returns:
        Split holding the transformed arrays and the annotated frame copy.
    """
    d = df.copy()
    d["district_id"] = le.transform(d["district"])
    district_ids = d["district_id"].values.astype(np.int64)

    X = d[feat].values.astype(np.float32)
    Xs = feat_scaler.transform(X)

    y = np.log1p(np.clip(d[target].values.reshape(-1,1), 0, None)).astype(np.float32)
    Ys = np.zeros_like(y, np.float32)
    for cid in np.unique(district_ids):
        # Positional row numbers: a filtered split (e.g. a validation slice)
        # keeps its original index labels, which are not valid positions into
        # the numpy arrays.
        rows = np.flatnonzero(district_ids == cid)
        sc = y_scalers.get(int(cid))
        if sc is None:
            # Fit the fallback only when needed instead of on every district.
            sc = StandardScaler().fit(y[rows])
        Ys[rows] = sc.transform(y[rows])

    doy = (d["doy"].values.astype(int)-1).astype(np.int64)
    dow = d["dow"].values.astype(np.int64)
    return Split(Xs, Ys, district_ids, doy, dow, d)

def to_seq(split, L=SEQ, prev_y=True):
    """Cut a Split into sliding windows of ``L`` consecutive days per district.

    Windows containing a date gap are skipped. With ``prev_y`` an extra
    feature column is appended holding the previous step's scaled target
    (0 for the first step of each window).

    Returns:
        Tuple of arrays (X windows, Y windows, district id per window,
        day-of-year windows, day-of-week windows).
    """
    feats, targets, dist_ids = split.X, split.y, split.c
    doy_all, dow_all = split.doy, split.dow

    # Day number of every row, used to verify that a window has no gaps.
    day_no = split.df["date"].values.astype("datetime64[D]").astype(np.int64)

    win_x, win_y, win_c, win_doy, win_dow = [], [], [], [], []
    for cid in np.unique(dist_ids):
        rows = np.where(dist_ids == cid)[0]
        rows = rows[np.argsort(day_no[rows])]
        days = day_no[rows]

        for start in range(len(rows) - L + 1):
            stop = start + L
            # require strictly consecutive dates
            if not np.all(np.diff(days[start:stop]) == 1):
                continue

            window = rows[start:stop]
            x_win = feats[window]
            y_win = targets[window]
            if prev_y:
                # Targets shifted by one step, zero-padded at the front.
                shifted = np.vstack([np.zeros((1,1),np.float32), y_win[:-1]])
                x_win = np.concatenate([x_win, shifted], axis=1)

            win_x.append(x_win)
            win_y.append(y_win)
            win_c.append(cid)
            win_doy.append(doy_all[window])
            win_dow.append(dow_all[window])

    return (
        np.asarray(win_x, np.float32),
        np.asarray(win_y, np.float32),
        np.asarray(win_c, np.int64),
        np.asarray(win_doy, np.int64),
        np.asarray(win_dow, np.int64)
    )


# =====================================================================================
# MODEL
# =====================================================================================

class Forecaster(nn.Module):
    """
    Sequence model: district / day-of-year / day-of-week embeddings are
    concatenated to the conditioning features, passed through a single-layer
    LSTM and a masked multi-head self-attention block, then projected to a
    mean, a clamped log-sigma and one output per quantile in ``qu``.
    """
    def __init__(self, cond_dim, n_district,
                 emb_dist=16, emb_doy=12, emb_dow=6,
                 lstm=LSTM_UNITS, heads=HEADS, drop=DROP, qu=Q):
        super().__init__()
        self.q = qu

        # Embedding tables: one row per district, per day of year (366) and
        # per weekday (7).
        self.ed = nn.Embedding(n_district, emb_dist)
        self.et_doy = nn.Embedding(366, emb_doy)
        self.et_dow = nn.Embedding(7, emb_dow)

        lstm_in = cond_dim + emb_dist + emb_doy + emb_dow
        self.lstm = nn.LSTM(lstm_in, lstm, 1, batch_first=True)
        self.mha = nn.MultiheadAttention(lstm, heads, batch_first=True, dropout=drop)
        self.ln = nn.LayerNorm(lstm)
        self.drop = nn.Dropout(drop)

        # Output heads.
        self.mu = nn.Linear(lstm, 1)
        self.ls = nn.Linear(lstm, 1)
        self.qh = nn.ModuleList([nn.Linear(lstm, 1) for _ in qu])

        # Attention weights of the most recent forward pass (detached).
        self.last_attn = None

    def forward(self, cond, cid, doy, dow):
        """Return (mu, log_sigma, [one tensor per quantile]) for every step."""
        _, L, _ = cond.shape

        # The district embedding is constant along the time axis.
        dist_emb = self.ed(cid).unsqueeze(1).expand(-1, L, -1)  # [B,L,emb_dist]
        inputs = torch.cat(
            [cond, dist_emb, self.et_doy(doy), self.et_dow(dow)],
            dim=-1,
        )

        hidden, _ = self.lstm(inputs)

        # True above the diagonal: step t may not attend to steps after t.
        future = torch.ones(L, L, device=hidden.device, dtype=torch.bool).triu(diagonal=1)
        attended, weights = self.mha(hidden, hidden, hidden,
                                     attn_mask=future, need_weights=True)
        self.last_attn = weights.detach()

        hidden = self.drop(self.ln(attended))

        mu = self.mu(hidden)
        ls = self.ls(hidden).clamp(-5.0, 3.0)
        qs = [head(hidden) for head in self.qh]
        return mu, ls, qs


# =====================================================================================
# TRAINING / VALIDATION
# =====================================================================================

def smape(y, p):
    """Symmetric mean absolute percentage error (in percent) of p against y.

    Both arrays are flattened first; a 1e-8 term keeps the denominator
    non-zero.
    """
    actual = y.flatten()
    predicted = p.flatten()
    numerator = 2*np.abs(predicted-actual)
    denominator = np.abs(actual)+np.abs(predicted)+1e-8
    return 100*np.mean(numerator/denominator)

def pinball(pred, target, quantile):
    """Mean pinball (quantile) loss of ``pred`` against ``target``."""
    residual = target - pred
    over = quantile * residual           # active when the target is above pred
    under = (quantile - 1) * residual    # active when the target is below pred
    return torch.maximum(over, under).mean()

def tv_l1_on_mu(mu):
    """Return two regularisers for ``mu``, indexed as [batch, step, channel].

    First value: mean absolute difference between neighbouring steps
    (total variation). Second value: mean absolute value (L1).
    """
    step_change = mu[:,1:,:] - mu[:,:-1,:]
    total_variation = step_change.abs().mean()
    l1_magnitude = mu.abs().mean()
    return total_variation, l1_magnitude

class SeqDS(Dataset):
    """Dataset over pre-built windows; item i is the i-th entry of each array."""
    def __init__(self,X,Y,C,DoY,DoW):
        self.X = X
        self.Y = Y
        self.C = C
        self.DoY = DoY
        self.DoW = DoW

    def __len__(self):
        # Number of windows = length of the leading axis of X.
        return self.X.shape[0]

    def __getitem__(self,i):
        fields = (self.X, self.Y, self.C, self.DoY, self.DoW)
        return tuple(field[i] for field in fields)

def validate_model(M,Xv,Yv,Cv,DvY,DvW):
    """Score ``M`` on a validation set in a single forward pass.

    Returns:
        (SMAPE between the mean head and the targets,
         fraction of targets lying between the first and third quantile
         heads). Both are computed on the values as given in ``Yv``.
    """
    M.eval()
    with torch.no_grad():
        feats    = torch.tensor(Xv,  dtype=torch.float32, device=DEVICE)
        dist_ids = torch.tensor(Cv,  dtype=torch.long,    device=DEVICE)
        doy      = torch.tensor(DvY, dtype=torch.long,    device=DEVICE)
        dow      = torch.tensor(DvW, dtype=torch.long,    device=DEVICE)

        mu, _, qs = M(feats, dist_ids, doy, dow)

        mean_pred = mu[...,0].cpu().numpy()
        lower     = qs[0][...,0].cpu().numpy()
        upper     = qs[2][...,0].cpu().numpy()

    truth = Yv[...,0]
    inside_band = (truth>=lower) & (truth<=upper)
    coverage = float(np.mean(inside_band))
    score = smape(truth.reshape(-1), mean_pred.reshape(-1))
    return score, coverage

def train_model(Xtr,Ytr,Ctr,DtrY,DtrW,
                Xva,Yva,Cva,DvaY,DvaW,
                cond_dim,nc):
    """Train a Forecaster with early stopping and return the best model.

    Per batch:
      * the last feature channel (previous-step target) is blended with the
        model's own shifted mean prediction; the weight on the true value
        moves linearly from TF_START to TF_END over EPOCHS;
      * loss = Gaussian NLL + mean pinball loss over Q
               + 0.05 * total variation of mu + 0.02 * L1 of mu;
      * gradients are clipped to norm CLIP.

    After each epoch the model is validated and scored with
    ``SMAPE + 10 * |coverage - 0.9|``. Training stops after PATIENCE epochs
    without improvement and the best weights are restored.

    Args:
        Xtr..DtrW: training windows (features, targets, district ids,
            day-of-year, day-of-week).
        Xva..DvaW: validation windows in the same layout.
        cond_dim: number of feature channels per step.
        nc: number of districts (embedding table size).

    Returns:
        The trained Forecaster, on DEVICE.
    """
    M = Forecaster(cond_dim=cond_dim, n_district=nc).to(DEVICE)
    opt = torch.optim.AdamW(M.parameters(),
                            lr=LR,
                            betas=(0.9,0.999),
                            weight_decay=WD)

    train_ds = SeqDS(Xtr,Ytr,Ctr,DtrY,DtrW)
    # Drop the ragged last batch only when at least one full batch exists.
    # With fewer than BATCH windows, drop_last=True would yield no batches
    # and the model would be returned untrained.
    dl = DataLoader(train_ds,
                    batch_size=BATCH,
                    shuffle=True,
                    drop_last=len(train_ds) >= BATCH)

    best=float("inf")
    best_sd=None
    wait=0

    for e in range(EPOCHS):
        # scheduled sampling factor for prev_y channel
        t  = e / max(1,(EPOCHS-1))
        tf = TF_START + (TF_END-TF_START)*t

        M.train()
        for Xb,Yb,Cb,DYb,DWb in dl:
            Xb  = Xb.to(torch.float32).to(DEVICE)
            Yb  = Yb.to(torch.float32).to(DEVICE)
            Cb  = Cb.to(torch.long).to(DEVICE)
            DYb = DYb.to(torch.long).to(DEVICE)
            DWb = DWb.to(torch.long).to(DEVICE)

            B = Xb.shape[0]

            # teacher forcing bleed for prev_y (last feature in X)
            with torch.no_grad():
                mu0,_,_ = M(Xb,Cb,DYb,DWb)
                # shift predictions one step right, zero for the first step
                prev = torch.cat(
                    [torch.zeros(B,1,device=DEVICE), mu0[:,:-1,0]],
                    dim=1
                )
            Xb[:,:,-1] = tf*Xb[:,:,-1] + (1-tf)*prev

            opt.zero_grad(set_to_none=True)

            mu, ls, qs = M(Xb,Cb,DYb,DWb)

            sig = (ls.exp()).clamp(1e-3,50.0)
            nll = 0.5*(((Yb-mu)/sig)**2 + 2*ls + math.log(2*math.pi)).mean()

            ql = 0.0
            for i, qv in enumerate(Q):
                ql += pinball(qs[i],Yb,qv)
            ql /= len(Q)

            tv, l1 = tv_l1_on_mu(mu)
            loss = 1.0*nll + 1.0*ql + 0.05*tv + 0.02*l1

            loss.backward()
            torch.nn.utils.clip_grad_norm_(M.parameters(),CLIP)
            opt.step()

        sm, cov = validate_model(M,Xva,Yva,Cva,DvaY,DvaW)
        # lower is better: accuracy plus a penalty for missing 90% coverage
        comp = sm + 10*abs(cov-0.9)

        if comp < best-1e-6:
            best    = comp
            best_sd = copy.deepcopy(M.state_dict())
            wait    = 0
        else:
            wait += 1
            if wait >= PATIENCE:
                break

    if best_sd is not None:
        M.load_state_dict(best_sd)
    return M


# =====================================================================================
# EVALUATION HELPERS FOR DIRECTOR TABLES
# =====================================================================================

def inverse_seq(arr_scaled, scaler):
    """Map scaled values back to non-negative counts.

    Undoes the scaler via ``scaler.inverse_transform`` and then the log1p
    transform; negative results are clipped to 0.

    Args:
        arr_scaled: values shaped (L,) or (L,1).
        scaler: object exposing ``inverse_transform`` for an (L,1) array.

    Returns:
        1-D numpy array of length L.
    """
    as_column = np.array(arr_scaled).reshape(-1,1)
    log_values = scaler.inverse_transform(as_column).reshape(-1)
    counts = np.expm1(log_values)
    return np.clip(counts, 0, None)

def classify_growth(g):
    """Label a growth rate (a fraction, e.g. 0.3 for +30%) with a traffic light."""
    if g >= RAPID_GROWTH_THRESH:
        return "🔴 Rapid growth"
    if g >= MODERATE_GROWTH_THRESH:
        return "🟠 Moderate growth"
    # Everything below the moderate threshold ends up here.
    return "🟢 Stable/decline"

def classify_risk(high_scenario):
    """Label a high-scenario case count: >=40 red, >=20 orange, otherwise green."""
    if high_scenario >= 40:
        return "🔴 Surge likely"
    if high_scenario >= 20:
        return "🟠 Elevated"
    return "🟢 Low"

def classify_conf(width):
    """Label an uncertainty-band width using the CONF_HIGH / CONF_MED cut-offs."""
    if width <= CONF_HIGH:
        return "✅ High confidence"
    if width <= CONF_MED:
        return "🟡 Medium confidence"
    # Wider than CONF_MED.
    return "⚠ Low confidence"

def classify_data_quality(continuity_pct):
    """Label a reporting-continuity percentage: >=90 reliable, >=70 watch, else red."""
    if continuity_pct >= 90:
        return "✅ Reliable"
    if continuity_pct >= 70:
        return "🟡 Watch"
    return "🔴 Needs field check"

def calc_reporting_continuity_daily(df_full_raw, lookback_days=SEQ):
    """Per-district share of recent rows that carry a case report.

    For every district the last ``lookback_days`` rows (ordered by date) are
    inspected; a row counts as reported when ``daily_cases`` is not NaN, so an
    explicit 0 counts as reported.

    NOTE(review): load_daily_from_db() replaces missing daily_cases with 0.
    If this function is fed that frame, every district will score 100 --
    confirm which frame callers pass.

    Args:
        df_full_raw: frame with ``district``, ``date`` and ``daily_cases``.
        lookback_days: number of most recent rows per district to inspect.

    Returns:
        DataFrame with columns ``district`` and
        ``Reporting_Continuity_Last{lookback_days}d(%)``. Both columns are
        present even when the input has no rows.
    """
    pct_col = f"Reporting_Continuity_Last{lookback_days}d(%)"
    out_rows = []
    for dist, g in df_full_raw.groupby("district"):
        recent = g.sort_values("date").tail(lookback_days)
        # "reported" means not NaN
        continuity_pct = 100.0 * recent["daily_cases"].notna().mean()
        out_rows.append({"district": dist, pct_col: continuity_pct})
    # Explicit columns: an empty input would otherwise give a column-less
    # frame, breaking callers that look the percentage column up by name.
    return pd.DataFrame(out_rows, columns=["district", pct_col])

def evaluate_and_build_products(
    M,
    Xseq, Yseq, Cseq, DoYseq, DoWseq,
    y_scalers, le, df_full_raw,
    feat_cols
):
    """
    Build director tables using last two forecast steps in each sequence:
      - "current day" ~ second last day in sequence
      - "next day"    ~ last day in sequence
    We'll aggregate to pseudo-week by scaling these to 7d equivalence to match dengue-style tables.
    The epi_week/year values come from ISO calendar of the last observed date per district.
    """

    M.eval()

    X_t = torch.tensor(Xseq, dtype=torch.float32, device=DEVICE)
    C_t = torch.tensor(Cseq, dtype=torch.long, device=DEVICE)
    DY_t= torch.tensor(DoYseq,dtype=torch.long, device=DEVICE)
    DW_t= torch.tensor(DoWseq,dtype=torch.long, device=DEVICE)

    with torch.no_grad():
        mu, ls, qs = M(X_t, C_t, DY_t, DW_t)
        mu_np  = mu.cpu().numpy()[...,0]      # [N,L]
        q10_np = qs[0].cpu().numpy()[...,0]
        q50_np = qs[1].cpu().numpy()[...,0]
        q90_np = qs[2].cpu().numpy()[...,0]

    # indexes
    idx_curr = -2  # "this period"
    idx_next = -1  # "forecast next period"

    # we'll track per-district aggregates
    district_names = []
    last_actual_list = []
    current_actual_list = []
    current_pred_list = []
    next_day_pred_list = []
    next_day_hi_list = []
    next_day_lo_list = []
    growth_rate_list = []
    conf_width_list = []
    peak_day_list = []
    peak_value_list = []
    peak_hi_list = []
    peak_leadtime_list = []
    epiweek_meta_rows = []

    # metrics arrays for global calibration
    all_y_real = []
    all_mu_real = []
    all_lo_real = []
    all_hi_real = []

    # We'll also need last epi-week/year per district
    last_epi_meta = {}

    for i, cid in enumerate(Cseq):
        sc = y_scalers[int(cid)]
        dist_name = le.inverse_transform([cid])[0]

        # inverse scale full sequence
        y_real_full   = inverse_seq(Yseq[i].reshape(-1),        sc)
        mu_real_full  = inverse_seq(mu_np[i].reshape(-1),       sc)
        lo_real_full  = inverse_seq(q10_np[i].reshape(-1),      sc)
        hi_real_full  = inverse_seq(q90_np[i].reshape(-1),      sc)

        all_y_real.append(y_real_full)
        all_mu_real.append(mu_real_full)
        all_lo_real.append(lo_real_full)
        all_hi_real.append(hi_real_full)

        safe_idx_curr = idx_curr if idx_curr >= -len(y_real_full) else -1
        safe_idx_next = idx_next if idx_next >= -len(y_real_full) else -1

        # observations/preds
        actual_curr  = y_real_full[safe_idx_curr]
        pred_curr    = mu_real_full[safe_idx_curr]
        actual_prev  = y_real_full[safe_idx_curr-1] if (safe_idx_curr-1)>=-len(y_real_full) else y_real_full[safe_idx_curr]

        pred_next    = mu_real_full[safe_idx_next]
        hi_next      = hi_real_full[safe_idx_next]
        lo_next      = lo_real_full[safe_idx_next]

        # growth % from prev actual to next predicted
        growth = (pred_next - actual_prev) / (actual_prev + 1e-6)

        # predictive band width
        conf_width = hi_next - lo_next

        # peak projection over the last PEAK_LOOKAHEAD_DAYS of mu_real_full
        look_slice = mu_real_full[-PEAK_LOOKAHEAD_DAYS:]
        hi_slice   = hi_real_full[-PEAK_LOOKAHEAD_DAYS:]
        if len(look_slice)==0:
            peak_when="NA"; peak_val=np.nan; peak_hi=np.nan; lead_days=np.nan
        else:
            local_max_idx = int(np.argmax(look_slice))
            peak_val = look_slice[local_max_idx]
            peak_hi  = hi_slice[local_max_idx]
            # lead_days from "now" (end)
            lead_days = (len(look_slice)-1) - local_max_idx
            peak_when = f"t+{lead_days}d"

        # epi_year/week from the last actual date we have for this district in df_full_raw
        gdist = df_full_raw[df_full_raw["district"]==dist_name].sort_values("date")
        if len(gdist) > 0:
            lastrow = gdist.iloc[-1]
            last_date = pd.to_datetime(lastrow["date"])
            iso_year, iso_week, _ = last_date.isocalendar()
            last_year = int(iso_year)
            last_epi  = int(iso_week)
            last_actual_week_cases = float(
                gdist.tail(7)["daily_cases"].sum()
            )
            this_week_actual = float(
                gdist.tail(1)["daily_cases"].sum()
            )
        else:
            last_year = None
            last_epi  = None
            last_date = None
            last_actual_week_cases = float(actual_prev)  # fallback
            this_week_actual       = float(actual_curr)

        # convert "next day forecast" to a pseudo "next week forecast"
        # scale by 7 for dashboard-style weekly numbers
        next_week_forecast = pred_next * 7.0
        next_week_hi       = hi_next   * 7.0
        next_week_lo       = lo_next   * 7.0
        this_week_pred     = pred_curr * 7.0

        # store for tables
        district_names.append(dist_name)
        last_actual_list.append(last_actual_week_cases)
        current_actual_list.append(this_week_actual)
        current_pred_list.append(this_week_pred)
        next_day_pred_list.append(next_week_forecast)
        next_day_hi_list.append(next_week_hi)
        next_day_lo_list.append(next_week_lo)
        growth_rate_list.append(growth)
        conf_width_list.append(next_week_hi - next_week_lo)
        peak_day_list.append(peak_when)
        peak_value_list.append(peak_val * 7.0)
        peak_hi_list.append(peak_hi * 7.0)
        peak_leadtime_list.append(lead_days/7.0 if not isinstance(lead_days,float) or not math.isnan(lead_days) else np.nan)

        last_epi_meta[dist_name] = {
            "Year": last_year,
            "Epi_Week": last_epi
        }

        epiweek_meta_rows.append({
            "District": dist_name,
            "Year": last_year,
            "Epi_Week": last_epi,
            "Forecast_next_week_median": float(next_week_forecast),
            "Forecast_next_week_hi": float(next_week_hi),
            "Growth_flag": classify_growth(growth)
        })

    # global calibration metrics
    all_y_real = np.concatenate(all_y_real)
    all_mu_real = np.concatenate(all_mu_real)
    all_lo_real = np.concatenate(all_lo_real)
    all_hi_real = np.concatenate(all_hi_real)

    overall_metrics = {
        "SMAPE": smape(all_y_real, all_mu_real),
        "MSE": mean_squared_error(all_y_real, all_mu_real),
        "RMSE": math.sqrt(mean_squared_error(all_y_real, all_mu_real)),
        "R2": r2_score(all_y_real, all_mu_real),
        "Coverage90": float(np.mean((all_y_real >= all_lo_real) & (all_y_real <= all_hi_real))),
    }

    # data quality
    quality_df_raw = calc_reporting_continuity_daily(df_full_raw, lookback_days=SEQ)
    rep_col = [c for c in quality_df_raw.columns if "Reporting_Continuity_Last" in c][0]
    quality_df = pd.merge(
        pd.DataFrame({"District": district_names}),
        quality_df_raw.rename(columns={"district":"District"}),
        on="District",
        how="left"
    )
    quality_df["Data_quality_flag"] = quality_df[rep_col].apply(classify_data_quality)

    # watchlist table
    watchlist_df = pd.DataFrame({
        "District": district_names,
        "Expected_cases_next_week": np.round(next_day_pred_list,2),
        "High_scenario_p90":       np.round(next_day_hi_list,2),
    })
    watchlist_df["Status"] = watchlist_df["High_scenario_p90"].apply(classify_risk)
    watchlist_df = watchlist_df.sort_values("High_scenario_p90", ascending=False)

    # overflow risk
    capacity_lookup = {d: DEFAULT_CAPACITY_PER_DISTRICT for d in district_names}
    overflow_df = pd.DataFrame({
        "District": district_names,
        "Capacity_threshold_beds_per_week": [
            capacity_lookup[d] for d in district_names
        ],
        "Forecast_median_next_week": np.round(next_day_pred_list,2),
        "High_scenario_p90":        np.round(next_day_hi_list,2),
    })
    overflow_df["Breach_risk_flag"] = [
        "YES" if hi > capacity_lookup[d] else "NO"
        for d, hi in zip(district_names, next_day_hi_list)
    ]
    overflow_df = overflow_df.sort_values("High_scenario_p90", ascending=False)

    # acceleration alerts
    accel_df = pd.DataFrame({
        "District": district_names,
        "Last_week_cases":        np.round(last_actual_list,2),
        "This_week_actual":       np.round(current_actual_list,2),
        "This_week_predicted":    np.round(current_pred_list,2),
        "Next_week_forecast":     np.round(next_day_pred_list,2),
        "Growth_rate_WoW(%)":     np.round(np.array(growth_rate_list)*100,1),
    })
    accel_df["Growth_flag"] = accel_df["Growth_rate_WoW(%)"].apply(lambda pct: classify_growth(pct/100.0))
    accel_df = accel_df.sort_values("Growth_rate_WoW(%)", ascending=False)

    # forecast confidence
    confidence_df = pd.DataFrame({
        "District": district_names,
        "Forecast_next_week":            np.round(next_day_pred_list,2),
        "Uncertainty_width(p90-p10)":    np.round(conf_width_list,2),
    })
    confidence_df["Confidence_flag"] = confidence_df["Uncertainty_width(p90-p10)"].apply(classify_conf)
    confidence_df = confidence_df.sort_values("Uncertainty_width(p90-p10)", ascending=True)

    # peak projection
    peak_df = pd.DataFrame({
        "District": district_names,
        "Peak_lead_time_weeks": peak_leadtime_list,
        "Peak_cases_median":    np.round(peak_value_list,2),
        "Peak_cases_high(p90)": np.round(peak_hi_list,2),
        "Peak_when":            peak_day_list
    }).sort_values("Peak_cases_median", ascending=False)

    # isochrone-style spread table
    isochrone_df = pd.DataFrame(epiweek_meta_rows)

    # nowcast gap
    gap_pct_list = []
    for a, p in zip(current_actual_list, current_pred_list):
        gap_pct = (p - a)/(a + 1e-6)*100.0
        gap_pct_list.append(gap_pct)
    nowcast_gap_df = pd.DataFrame({
        "District": district_names,
        "Current_week_actual":               np.round(current_actual_list,2),
        "Current_week_predicted_from_prev":  np.round(current_pred_list,2),
        "Nowcast_gap_percent":               np.round(gap_pct_list,1),
    }).sort_values("Nowcast_gap_percent", ascending=False)

    # ---------- Climate / lag influence with real variable names ----------
    def parse_feature_name(raw_name: str):
        """
        Split an engineered feature column name into (base_var, lag_info).

        Examples:
          rainfall_lag7      -> base_var='rainfall', lag_info='lag 7d'
          temperature_rmean7 -> base_var='temperature', lag_info='7d rolling mean'
          daily_cases_rstd14 -> base_var='daily_cases', lag_info='14d rolling std'
          temperature        -> base_var='temperature', lag_info='same day'
          doy_sin            -> base_var='season(doy)', lag_info=''
          dow_cos            -> base_var='weekday(dow)', lag_info=''
          rain_lag3_lag7     -> base_var='rain_lag3', lag_info='lag 7d'

        Args:
            raw_name: feature column name.

        Returns:
            Tuple (base_var, lag_info) of two strings.
        """
        # seasonal encodings
        if raw_name.startswith("doy_"):
            return "season(doy)", ""
        if raw_name.startswith("dow_"):
            return "weekday(dow)", ""

        # Markers are tested in this fixed order (rolling mean, rolling std, lag).
        # The name is cut at the LAST occurrence of the marker, so a name that
        # repeats a marker no longer fails tuple-unpacking the way a plain
        # two-target `split()` would; names with a single marker are unaffected.
        suffix_patterns = (
            ("_rmean", "{}d rolling mean"),
            ("_rstd", "{}d rolling std"),
            ("_lag", "lag {}d"),
        )
        for marker, template in suffix_patterns:
            if marker in raw_name:
                base, _, tail = raw_name.rpartition(marker)
                return base, template.format(tail)

        # plain same-day vars
        return raw_name, "same day"

    # use the feature vector from the "current" timestep (-2), excluding prev_y channel at end
    last_feats_matrix = Xseq[:, -2, :-1]  # shape: [num_seq, num_features]
    preds_arr = np.array(next_day_pred_list)

    feat_import_rows = []
    for fi in range(last_feats_matrix.shape[1]):
        col_vals = last_feats_matrix[:, fi]

        if np.std(col_vals) < 1e-8:
            corr = 0.0
        else:
            corr = np.corrcoef(col_vals, preds_arr)[0, 1]

        raw_feat_name = feat_cols[fi]
        base_var, lag_info = parse_feature_name(raw_feat_name)

        feat_import_rows.append({
            "feature": raw_feat_name,
            "base_var": base_var,
            "lag_info": lag_info,
            "pearson_corr_with_next_week_forecast": float(corr),
        })

    climate_influence_df = pd.DataFrame(feat_import_rows)
    climate_influence_df["abs_corr"] = (
        climate_influence_df["pearson_corr_with_next_week_forecast"].abs()
    )
    climate_influence_df = climate_influence_df.sort_values(
        "abs_corr", ascending=False
    )
    # ---------------------------------------------------------------------

    # Executive summary
    top5 = watchlist_df.head(5)[["District","Expected_cases_next_week","High_scenario_p90","Status"]].to_dict(orient="records")
    rapid = accel_df[accel_df["Growth_flag"]=="🔴 Rapid growth"]["District"].tolist()
    overflow_risk = overflow_df[overflow_df["Breach_risk_flag"]=="YES"]["District"].tolist()
    dq_bad = quality_df[quality_df["Data_quality_flag"]=="🔴 Needs field check"]["District"].tolist()
    summary_text = {
        "Top5_high_risk_next_week": to_python(top5),
        "Districts_with_rapid_growth": to_python(rapid),
        "Districts_with_capacity_breach_risk": to_python(overflow_risk),
        "Districts_with_data_quality_issues": to_python(dq_bad),
        "Model_calibration": {
            "Coverage90": round(float(overall_metrics["Coverage90"]),3),
            "SMAPE": round(float(overall_metrics["SMAPE"]),2),
            "R2": round(float(overall_metrics["R2"]),3),
        }
    }

    products = {
        "watchlist_df": watchlist_df,
        "overflow_df": overflow_df,
        "accel_df": accel_df,
        "confidence_df": confidence_df,
        "quality_df": quality_df.rename(columns={rep_col:"reporting_continuity_pct"}),
        "peak_df": peak_df,
        "isochrone_df": isochrone_df.rename(columns={
            "District":"District",
            "Year":"Year",
            "Epi_Week":"Epi_Week",
            "Forecast_next_week_median":"Forecast_next_week_median",
            "Forecast_next_week_hi":"Forecast_next_week_hi",
            "Growth_flag":"Growth_flag",
        }),
        "nowcast_gap_df": nowcast_gap_df,
        "climate_influence_df": climate_influence_df,
        "executive_summary": summary_text,
        "overall_metrics": overall_metrics,
        "last_epi_meta": last_epi_meta,
    }

    return overall_metrics, products


# =====================================================================================
# PUSH TO diarrhoea_* TABLES
# =====================================================================================

def push_products_to_db_diarrhoea(conn, products):
    """
    Send every decision table held in `products` to its diarrhoea_* table.

    Per-district tables are keyed on (district, year, epi_week); year and
    epi_week are looked up in products["last_epi_meta"] by district name and
    are None when the district has no entry there. The isochrone table takes
    year / epi_week from its own rows. The climate-influence and
    executive-summary tables are written with wipe_first=True and no key
    columns. All writes are delegated to upsert_table().

    Args:
        conn: database connection handed through to upsert_table().
        products: dict of DataFrames / metadata; the keys read here are
            "last_epi_meta", "watchlist_df", "overflow_df", "accel_df",
            "confidence_df", "quality_df", "peak_df", "isochrone_df",
            "nowcast_gap_df", "climate_influence_df" and "executive_summary".
    """
    epi_meta_by_district = products["last_epi_meta"]
    epi_key_cols = ("district", "year", "epi_week")

    def district_rows(product_name, payload):
        # One tuple per record: (district, year, epi_week, *payload(record)).
        built = []
        for _, rec in products[product_name].iterrows():
            meta = epi_meta_by_district.get(rec["District"], {})
            built.append((
                rec["District"],
                meta.get("Year"),
                meta.get("Epi_Week"),
                *payload(rec),
            ))
        return built

    def overflow_payload(rec):
        # Capacity is stored as an integer, or NULL when it is missing.
        capacity = (
            None
            if pd.isna(rec["Capacity_threshold_beds_per_week"])
            else int(rec["Capacity_threshold_beds_per_week"])
        )
        return (
            capacity,
            py(rec["Forecast_median_next_week"]),
            py(rec["High_scenario_p90"]),
            rec["Breach_risk_flag"],
        )

    # watchlist
    upsert_table(
        conn,
        "diarrhoea_watchlist",
        [*epi_key_cols, "expected_cases_next_week", "high_scenario_p90", "status"],
        district_rows(
            "watchlist_df",
            lambda rec: (
                py(rec["Expected_cases_next_week"]),
                py(rec["High_scenario_p90"]),
                rec["Status"],
            ),
        ),
        pk_cols=list(epi_key_cols),
        wipe_first=False
    )

    # overflow risk
    upsert_table(
        conn,
        "diarrhoea_overflow_risk",
        [
            *epi_key_cols,
            "capacity_threshold_beds_per_week",
            "forecast_median_next_week",
            "high_scenario_p90",
            "breach_risk_flag",
        ],
        district_rows("overflow_df", overflow_payload),
        pk_cols=list(epi_key_cols),
        wipe_first=False
    )

    # acceleration alerts
    upsert_table(
        conn,
        "diarrhoea_acceleration_alerts",
        [
            *epi_key_cols,
            "last_week_cases",
            "this_week_actual",
            "this_week_predicted",
            "next_week_forecast",
            "growth_rate_wow",
            "growth_flag",
        ],
        district_rows(
            "accel_df",
            lambda rec: (
                py(rec["Last_week_cases"]),
                py(rec["This_week_actual"]),
                py(rec["This_week_predicted"]),
                py(rec["Next_week_forecast"]),
                py(rec["Growth_rate_WoW(%)"]),
                rec["Growth_flag"],
            ),
        ),
        pk_cols=list(epi_key_cols),
        wipe_first=False
    )

    # confidence
    upsert_table(
        conn,
        "diarrhoea_confidence",
        [*epi_key_cols, "forecast_next_week", "uncertainty_width", "confidence_flag"],
        district_rows(
            "confidence_df",
            lambda rec: (
                py(rec["Forecast_next_week"]),
                py(rec["Uncertainty_width(p90-p10)"]),
                rec["Confidence_flag"],
            ),
        ),
        pk_cols=list(epi_key_cols),
        wipe_first=False
    )

    # surveillance quality
    upsert_table(
        conn,
        "diarrhoea_surveillance_quality",
        [*epi_key_cols, "reporting_continuity_pct", "data_quality_flag"],
        district_rows(
            "quality_df",
            lambda rec: (
                py(rec["reporting_continuity_pct"]),
                rec["Data_quality_flag"],
            ),
        ),
        pk_cols=list(epi_key_cols),
        wipe_first=False
    )

    # peak projection
    upsert_table(
        conn,
        "diarrhoea_peak_projection",
        [
            *epi_key_cols,
            "peak_lead_time_weeks",
            "peak_cases_median",
            "peak_cases_high_p90",
            "peak_when",
        ],
        district_rows(
            "peak_df",
            lambda rec: (
                py(rec["Peak_lead_time_weeks"]),
                py(rec["Peak_cases_median"]),
                py(rec["Peak_cases_high(p90)"]),
                rec["Peak_when"],
            ),
        ),
        pk_cols=list(epi_key_cols),
        wipe_first=False
    )

    # isochrone-style spread: year / epi_week come from the table's own rows
    iso_rows = [
        (
            rec["District"],
            rec["Year"],
            rec["Epi_Week"],
            py(rec["Forecast_next_week_median"]),
            py(rec["Forecast_next_week_hi"]),
            rec["Growth_flag"],
        )
        for _, rec in products["isochrone_df"].iterrows()
    ]
    upsert_table(
        conn,
        "diarrhoea_isochrone_spread",
        [*epi_key_cols, "forecast_next_week_median", "forecast_next_week_hi", "growth_flag"],
        iso_rows,
        pk_cols=list(epi_key_cols),
        wipe_first=False
    )

    # nowcast gap
    upsert_table(
        conn,
        "diarrhoea_nowcast_gap",
        [
            *epi_key_cols,
            "current_week_actual",
            "current_week_predicted_from_prev",
            "nowcast_gap_percent",
        ],
        district_rows(
            "nowcast_gap_df",
            lambda rec: (
                py(rec["Current_week_actual"]),
                py(rec["Current_week_predicted_from_prev"]),
                py(rec["Nowcast_gap_percent"]),
            ),
        ),
        pk_cols=list(epi_key_cols),
        wipe_first=False
    )

    # climate influence: no key columns, written with wipe_first=True
    clim_rows = [
        (
            rec["feature"],
            rec["base_var"],
            rec["lag_info"],
            py(rec["pearson_corr_with_next_week_forecast"]),
            py(rec["abs_corr"]),
        )
        for _, rec in products["climate_influence_df"].iterrows()
    ]
    upsert_table(
        conn,
        "diarrhoea_climate_influence",
        ["feature", "base_var", "lag_info", "pearson_corr_with_next_week_forecast", "abs_corr"],
        clim_rows,
        pk_cols=None,
        wipe_first=True
    )

    # exec summary: a single row holding the JSON-serialised summary
    summary_payload = to_python(products["executive_summary"])
    upsert_table(
        conn,
        "diarrhoea_exec_summary",
        ["summary"],
        [(json.dumps(summary_payload),)],
        pk_cols=None,
        wipe_first=True
    )


# =====================================================================================
# MAIN PIPELINE
# =====================================================================================

def main():
    """
    Run the end-to-end diarrhoea forecasting pipeline.

    Steps (matching the printed progress messages):
      1) load the daily data with load_daily_from_db()
      2) add lag / rolling features with add_lags_rolls_daily()
      3) split every district chronologically into train / val / test, fit the
         encoders and scalers on TRAIN only, and build sequences with to_seq()
      4) train the model and save its state_dict under OUT
      5) build the decision tables with evaluate_and_build_products()
      6) write CSV / JSON artifacts under OUT and upsert the diarrhoea_* tables

    Raises:
        RuntimeError: if no training sequences can be built.
    """
    print("1) Load diarrhoea daily data from Postgres")
    df_daily = load_daily_from_db()

    print("2) Add lags / rolling stats")
    df_daily, new_cols = add_lags_rolls_daily(df_daily)

    # feature columns
    feat_cols = [
        "temperature","humidity","rainfall",
        "doy_sin","doy_cos","dow_sin","dow_cos"
    ] + new_cols

    target_col = "daily_cases"

    # 3) Train/Val/Test split per district (daily horizon logic)
    # The last TEST_H rows of a district are test, the VAL_H rows before them
    # are validation, everything earlier is train. Districts too short for a
    # split contribute to train only (or to train + test).
    # NOTE(review): VAL_H and TEST_H are both shorter than SEQ (30). If to_seq()
    # needs at least SEQ rows per district, the val / test slices can never
    # yield a sequence and the fallbacks below always trigger -- confirm
    # against to_seq().
    VAL_H = 14
    TEST_H = 21
    tr_parts, va_parts, te_parts = [], [], []
    for dist, g in df_daily.groupby("district"):
        g = g.sort_values("date")
        if len(g) > (VAL_H + TEST_H):
            te_parts.append(g.iloc[-TEST_H:])
            va_parts.append(g.iloc[-(VAL_H + TEST_H):-TEST_H])
            tr_parts.append(g.iloc[:-(VAL_H + TEST_H)])
        elif len(g) > TEST_H:
            te_parts.append(g.iloc[-TEST_H:])
            tr_parts.append(g.iloc[:-TEST_H])
        else:
            tr_parts.append(g)
    tr_df = pd.concat(tr_parts).reset_index(drop=True)
    # Empty frames keep the train columns so later len() checks still work.
    va_df = pd.concat(va_parts).reset_index(drop=True) if va_parts else tr_df.iloc[0:0].copy()
    te_df = pd.concat(te_parts).reset_index(drop=True) if te_parts else tr_df.iloc[0:0].copy()

    df_full_raw = df_daily.copy()

    print("3) Fit label encoders / scalers on TRAIN")
    split_tr, le_tr, feat_scaler_tr, y_scalers_tr = build_mats(tr_df, feat_cols, target_col)

    X_tr, Y_tr, C_tr, DoY_tr, DoW_tr = to_seq(split_tr, SEQ, prev_y=True)
    if X_tr.shape[0] == 0:
        raise RuntimeError("No train sequences. Reduce SEQ or check data continuity.")
    print(f"Train sequences: {X_tr.shape}")

    # build val sequences (fall back to the TRAIN sequences when none exist)
    val_is_train = True
    if len(va_df) > 0:
        split_va = apply_mats(va_df, feat_cols, target_col, le_tr, feat_scaler_tr, y_scalers_tr)
        X_va, Y_va, C_va, DoY_va, DoW_va = to_seq(split_va, SEQ, prev_y=True)
        val_is_train = X_va.shape[0] == 0
    if val_is_train:
        # Previously silent: early stopping then monitors in-sample loss.
        print("WARNING: no validation sequences available; early stopping will use TRAIN sequences.")
        X_va, Y_va, C_va, DoY_va, DoW_va = X_tr, Y_tr, C_tr, DoY_tr, DoW_tr

    # build test sequences (fall back to the validation sequences when none exist)
    test_is_val = True
    if len(te_df) > 0:
        split_te = apply_mats(te_df, feat_cols, target_col, le_tr, feat_scaler_tr, y_scalers_tr)
        X_te, Y_te, C_te, DoY_te, DoW_te = to_seq(split_te, SEQ, prev_y=True)
        test_is_val = X_te.shape[0] == 0
    if test_is_val:
        # Previously silent: the reported metrics are then not held-out metrics.
        fallback_name = "TRAIN" if val_is_train else "VALIDATION"
        print(f"WARNING: no test sequences available; metrics and tables will use {fallback_name} sequences.")
        X_te, Y_te, C_te, DoY_te, DoW_te = X_va, Y_va, C_va, DoY_va, DoW_va

    nc = int(C_tr.max()) + 1      # passed to train_model together with cond_dim
    cond_dim = X_tr.shape[2]      # size of the last axis of the train sequences

    print("4) Train diarrhoea model + early stop on val")
    M = train_model(
        X_tr, Y_tr, C_tr, DoY_tr, DoW_tr,
        X_va, Y_va, C_va, DoY_va, DoW_va,
        cond_dim, nc
    )
    torch.save(M.state_dict(), os.path.join(OUT,"diarrhoea_forecaster.pt"))

    print("5) Build director tables (using TEST split, like dengue dashboard style)")
    overall_metrics, products = evaluate_and_build_products(
        M,
        X_te, Y_te, C_te, DoY_te, DoW_te,
        y_scalers_tr, le_tr, df_full_raw,
        feat_cols
    )

    # save local artifacts for inspection
    products["watchlist_df"].to_csv(os.path.join(OUT,"watchlist.csv"), index=False)
    products["overflow_df"].to_csv(os.path.join(OUT,"overflow_risk.csv"), index=False)
    products["accel_df"].to_csv(os.path.join(OUT,"acceleration_alerts.csv"), index=False)
    products["confidence_df"].to_csv(os.path.join(OUT,"forecast_confidence.csv"), index=False)
    products["quality_df"].to_csv(os.path.join(OUT,"surveillance_quality.csv"), index=False)
    products["peak_df"].to_csv(os.path.join(OUT,"peak_projection.csv"), index=False)
    products["isochrone_df"].to_csv(os.path.join(OUT,"isochrone_spread.csv"), index=False)
    products["nowcast_gap_df"].to_csv(os.path.join(OUT,"nowcast_gap.csv"), index=False)
    products["climate_influence_df"].to_csv(os.path.join(OUT,"climate_lag_influence.csv"), index=False)

    dash_manifest = {
        "executive_summary": products["executive_summary"],
        "overall_metrics": products["overall_metrics"],
        "notes": "Auto-generated diarrhoea decision tables (daily model projected to weekly scale)."
    }
    with open(os.path.join(OUT,"DASHBOARD_SUMMARY.json"), "w") as f:
        json.dump(to_python(dash_manifest), f, indent=2)

    print("6) Upsert diarrhoea_* tables in Postgres")
    conn = get_db_conn()
    try:
        ensure_output_tables_diarrhoea(conn)
        push_products_to_db_diarrhoea(conn, products)
    finally:
        # Always release the connection, even when a table write fails.
        conn.close()

    print("\n=== Executive Summary Preview ===")
    print(json.dumps(to_python(dash_manifest["executive_summary"]), indent=2))
    print("\nArtifacts saved to", OUT, "and pushed to Postgres.\nDone.")


# Entry point: run the full pipeline when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()