# In [None]:
import os
import math
import json
import copy
import random
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import mean_squared_error, r2_score

import psycopg2
from psycopg2.extras import execute_values

# ----------------- Helpers for serialization / DB safety -----------------

def to_python(obj):
    """
    Recursively convert numpy / pandas values into plain Python objects so
    the result can be passed to ``json.dumps`` (and stored as JSONB).

    Conversions:
      * numpy bool / integer / floating scalars -> bool / int / float;
      * NaN, +/-inf, ``pd.NaT`` and ``pd.NA`` -> None (none of them is valid
        JSON; ``json.dumps`` would emit ``NaN``/``Infinity`` otherwise);
      * ``pd.Timestamp`` -> ISO-8601 string;
      * numpy arrays, lists and tuples -> lists (recursively);
      * dicts -> dicts (recursively; numpy-scalar keys are converted too).

    Anything else is returned unchanged.
    """
    # pd.NaT is *not* an instance of pd.Timestamp, so it has to be matched
    # explicitly; pd.NA likewise has its own singleton type.
    if obj is pd.NaT or obj is pd.NA:
        return None

    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, (np.floating, float)):
        val = float(obj)
        return val if math.isfinite(val) else None

    if isinstance(obj, np.ndarray):
        return [to_python(x) for x in obj.tolist()]

    if isinstance(obj, pd.Timestamp):
        return obj.isoformat()

    if isinstance(obj, dict):
        return {
            (to_python(k) if isinstance(k, np.generic) else k): to_python(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [to_python(v) for v in obj]

    return obj

def py(v):
    """
    Coerce a single value into something psycopg2 can bind as a parameter.

    numpy bool / integer / floating scalars become their Python equivalents
    and float NaN becomes None (SQL NULL). Every other value, including
    +/-inf and non-numpy objects, is passed through untouched. Table schemas
    are not affected.
    """
    if isinstance(v, np.bool_):
        return bool(v)
    if isinstance(v, np.integer):
        return int(v)
    if isinstance(v, np.floating):
        v = float(v)
    is_nan = isinstance(v, float) and math.isnan(v)
    return None if is_nan else v

# ----------------- Config -----------------

# Database connection settings are read from the environment so credentials
# are not committed to source control. Non-secret settings fall back to the
# previous defaults.
# NOTE(review): PG_PASS deliberately has no default -- export it before
# running. The password that used to be hard-coded here should be treated
# as leaked and rotated.
PG_HOST    = os.environ.get("PG_HOST", "119.148.17.102")
PG_PORT    = int(os.environ.get("PG_PORT", "5432"))
PG_DB      = os.environ.get("PG_DB", "ewarsdb")
PG_USER    = os.environ.get("PG_USER", "ewars")
PG_PASS    = os.environ.get("PG_PASS")
PG_SSLMODE = os.environ.get("PG_SSLMODE", "require")

# Directory for saved model weights; created at import time.
OUT = "/content/weekly_dengue_out"
os.makedirs(OUT, exist_ok=True)

SEED = 42
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEQ = 16            # sequence window length (weeks)
BATCH = 64
EPOCHS_GAN = 400
PATIENCE = 25       # early-stopping patience (epochs)
NOISE = 12          # generator noise width
LSTM_UNITS = 128
HEADS = 8
DROP = 0.25
LR_G = 8e-4
LR_D = 3e-4
WD = 1e-4           # AdamW weight decay
CLIP = 1.0          # gradient-norm clip
TTUR = (1.0, 1.0)   # learning-rate multipliers for (generator, critic)
TF_START = 1.0      # teacher-forcing mix at the first epoch
TF_END = 0.3        # teacher-forcing mix at the last epoch
ALPHA = 0.5         # Lookahead interpolation factor
K_SYNC = 5          # Lookahead sync period (steps)

Q = (0.1, 0.5, 0.9)
ABL = {"adv": True, "hetero": True, "quant": True}

NSIG = (0.1, 1.2)   # noise scale at the first / last epoch
K_MC = 50           # Monte-Carlo samples per sequence in copula_bands
VAL_H_WEEKS = 6
TEST_H_WEEKS = 6

DEFAULT_CAPACITY_PER_DISTRICT = 25
RAPID_GROWTH_THRESH = 0.30
MODERATE_GROWTH_THRESH = 0.10
PEAK_LOOKAHEAD_WEEKS = 6

# widen fallback Gaussian bands in case we still call copula_bands
SIG_CALIBRATION_FACTOR = 1.5

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# ----------------- DB helpers (unchanged schemas) -----------------

def get_db_conn():
    """
    Open a new psycopg2 connection using the module-level PG_* settings.

    The settings are passed as keyword arguments instead of being interpolated
    into a DSN string. psycopg2 then quotes each value itself, whereas a
    hand-built ``key=value`` DSN breaks when a value (typically the password)
    contains a space, a quote or a backslash. Keyword arguments that are
    None are omitted by psycopg2.

    Returns:
        A new psycopg2 connection; the caller is responsible for closing it.
    """
    return psycopg2.connect(
        host=PG_HOST,
        port=PG_PORT,
        dbname=PG_DB,
        user=PG_USER,
        password=PG_PASS,
        sslmode=PG_SSLMODE,
    )

def ensure_output_tables(conn):
    cur = conn.cursor()

    cur.execute("""
    CREATE TABLE IF NOT EXISTS dengue_watchlist (
        district TEXT NOT NULL,
        year INT,
        epi_week INT,
        expected_cases_next_week NUMERIC,
        high_scenario_p90 NUMERIC,
        status TEXT,
        created_at TIMESTAMP DEFAULT now(),
        PRIMARY KEY (district, year, epi_week)
    );
    """)

    cur.execute("""
    CREATE TABLE IF NOT EXISTS dengue_overflow_risk (
        district TEXT NOT NULL,
        year INT,
        epi_week INT,
        capacity_threshold_beds_per_week INT,
        forecast_median_next_week NUMERIC,
        high_scenario_p90 NUMERIC,
        breach_risk_flag TEXT,
        created_at TIMESTAMP DEFAULT now(),
        PRIMARY KEY (district, year, epi_week)
    );
    """)

    cur.execute("""
    CREATE TABLE IF NOT EXISTS dengue_acceleration_alerts (
        district TEXT NOT NULL,
        year INT,
        epi_week INT,
        last_week_cases NUMERIC,
        this_week_actual NUMERIC,
        this_week_predicted NUMERIC,
        next_week_forecast NUMERIC,
        growth_rate_wow NUMERIC,
        growth_flag TEXT,
        created_at TIMESTAMP DEFAULT now(),
        PRIMARY KEY (district, year, epi_week)
    );
    """)

    cur.execute("""
    CREATE TABLE IF NOT EXISTS dengue_confidence (
        district TEXT NOT NULL,
        year INT,
        epi_week INT,
        forecast_next_week NUMERIC,
        uncertainty_width NUMERIC,
        confidence_flag TEXT,
        created_at TIMESTAMP DEFAULT now(),
        PRIMARY KEY (district, year, epi_week)
    );
    """)

    cur.execute("""
    CREATE TABLE IF NOT EXISTS dengue_surveillance_quality (
        district TEXT NOT NULL,
        year INT,
        epi_week INT,
        reporting_continuity_pct NUMERIC,
        data_quality_flag TEXT,
        created_at TIMESTAMP DEFAULT now(),
        PRIMARY KEY (district, year, epi_week)
    );
    """)

    cur.execute("""
    CREATE TABLE IF NOT EXISTS dengue_peak_projection (
        district TEXT NOT NULL,
        year INT,
        epi_week INT,
        peak_lead_time_weeks NUMERIC,
        peak_cases_median NUMERIC,
        peak_cases_high_p90 NUMERIC,
        peak_when TEXT,
        created_at TIMESTAMP DEFAULT now(),
        PRIMARY KEY (district, year, epi_week)
    );
    """)

    cur.execute("""
    CREATE TABLE IF NOT EXISTS dengue_isochrone_spread (
        district TEXT NOT NULL,
        year INT,
        epi_week INT,
        forecast_next_week_median NUMERIC,
        forecast_next_week_hi NUMERIC,
        growth_flag TEXT,
        created_at TIMESTAMP DEFAULT now(),
        PRIMARY KEY (district, year, epi_week)
    );
    """)

    cur.execute("""
    CREATE TABLE IF NOT EXISTS dengue_nowcast_gap (
        district TEXT NOT NULL,
        year INT,
        epi_week INT,
        current_week_actual NUMERIC,
        current_week_predicted_from_prev NUMERIC,
        nowcast_gap_percent NUMERIC,
        created_at TIMESTAMP DEFAULT now(),
        PRIMARY KEY (district, year, epi_week)
    );
    """)

    cur.execute("""
    CREATE TABLE IF NOT EXISTS dengue_climate_influence (
        feature TEXT,
        base_var TEXT,
        lag_info TEXT,
        pearson_corr_with_next_week_forecast NUMERIC,
        abs_corr NUMERIC,
        created_at TIMESTAMP DEFAULT now()
    );
    """)

    cur.execute("""
    CREATE TABLE IF NOT EXISTS dengue_exec_summary (
        id SERIAL PRIMARY KEY,
        summary JSONB,
        created_at TIMESTAMP DEFAULT now()
    );
    """)

    conn.commit()
    cur.close()

def _dedupe_by_pk(rows, cols, pk_cols):
    if not pk_cols:
        return rows
    pk_idx = [cols.index(pk) for pk in pk_cols]
    dedup_map = {}
    for row in rows:
        key = tuple(row[i] for i in pk_idx)
        dedup_map[key] = row
    return list(dedup_map.values())

def upsert_table(conn, table_name, cols, rows, pk_cols=None, wipe_first=False):
    """
    Write ``rows`` into ``table_name`` within a single transaction.

    Args:
        conn: open psycopg2 connection.
        table_name: target table. It is interpolated into the SQL text, so it
            must be a trusted identifier (it is not escaped).
        cols: column names, in the same order as each row's values (also
            interpolated; must be trusted).
        rows: sequence of row tuples/lists.
        pk_cols: primary-key columns. If given, rows are de-duplicated on these
            columns (last occurrence wins) and written with
            ``INSERT ... ON CONFLICT (pk) DO UPDATE`` (``DO NOTHING`` when every
            column is part of the key). If empty/None, rows are plain-inserted.
        wipe_first: TRUNCATE the table before inserting.

    The optional TRUNCATE and the INSERT share one transaction. A failing
    insert therefore no longer leaves the table emptied. On any error the
    transaction is rolled back and the exception re-raised.
    """
    if not rows and not wipe_first:
        return  # nothing to write; leave the connection's transaction untouched

    col_list = ", ".join(cols)
    try:
        with conn.cursor() as cur:
            if wipe_first:
                cur.execute(f"TRUNCATE TABLE {table_name};")

            if rows:
                if pk_cols:
                    # Postgres rejects an ON CONFLICT DO UPDATE statement that
                    # touches the same key twice, so collapse duplicates first.
                    payload = _dedupe_by_pk(rows, cols, pk_cols)
                    conflict_target = ", ".join(pk_cols)
                    non_pk = [c for c in cols if c not in pk_cols]
                    if non_pk:
                        set_updates = ", ".join(f"{c}=EXCLUDED.{c}" for c in non_pk)
                        action = f"DO UPDATE SET {set_updates}"
                    else:
                        # "DO UPDATE SET" with an empty list is a syntax error.
                        action = "DO NOTHING"
                    insert_sql = (
                        f"INSERT INTO {table_name} ({col_list}) VALUES %s "
                        f"ON CONFLICT ({conflict_target}) {action};"
                    )
                else:
                    payload = rows
                    insert_sql = f"INSERT INTO {table_name} ({col_list}) VALUES %s"
                execute_values(cur, insert_sql, payload)
        conn.commit()
    except Exception:
        conn.rollback()
        raise


# ----------------- Feature engineering / model code -----------------

def _iso_week_start(year, week):
    iso_year = str(int(year))
    iso_week = str(int(week)).zfill(2)
    return pd.to_datetime(iso_year + iso_week + "1", format="%G%V%u", errors="coerce")

def fetch_dengue_weather_from_db():
    """
    Load the weekly dengue + weather input rows from Postgres.

    The ``vw_dengue_weekly_input`` view is tried first. If that query raises
    (e.g. the view does not exist), the raw ``dengue_weather`` table is
    queried instead, with its lower-case columns aliased to the view's names
    and rows with NULL year/epi_week excluded. The connection is always
    closed.

    Returns:
        pandas DataFrame with columns District, Division, Year, Epi_Week,
        weekly_hospitalised_cases, Total_Rainfall, Avg_Humidity,
        Avg_Temperature, ordered by district, year and epi week.
    """
    conn = get_db_conn()
    try:
        try_query = """
            SELECT
                "District",
                "Division",
                "Year",
                "Epi_Week",
                weekly_hospitalised_cases,
                "Total_Rainfall",
                "Avg_Humidity",
                "Avg_Temperature"
            FROM vw_dengue_weekly_input
            ORDER BY "District","Year","Epi_Week";
        """
        try:
            df = pd.read_sql(try_query, conn)
            return df
        except Exception:
            # A failed statement puts the psycopg2 transaction into the
            # aborted state. Without a rollback, every later statement on
            # this connection (including the fallback below) fails with
            # "current transaction is aborted".
            conn.rollback()
            fallback_query = """
                SELECT
                    district      AS "District",
                    division      AS "Division",
                    year          AS "Year",
                    epi_week      AS "Epi_Week",
                    weekly_hospitalised_cases,
                    total_rainfall    AS "Total_Rainfall",
                    avg_humidity      AS "Avg_Humidity",
                    avg_temperature   AS "Avg_Temperature"
                FROM dengue_weather
                WHERE year IS NOT NULL
                  AND epi_week IS NOT NULL
                ORDER BY district, year, epi_week;
            """
            df = pd.read_sql(fallback_query, conn)
            return df
    finally:
        conn.close()

def load_weekly_from_db():
    """
    Fetch the weekly dengue/weather table and clean it for feature engineering.

    Steps:
      * strip District/Division text and coerce Year/Epi_Week to nullable
        ints. Rows where either is missing are dropped, matching the
        fallback SQL, which already filters NULL year/epi_week.
      * coerce the case and weather columns to numeric (bad values -> NaN).
      * derive ``week_start`` (the ISO-week Monday, falling back to
        Jan-1 + 7*(week-1) when ISO parsing fails), ``Month`` and
        ``WeekOfYear``.
      * sort by District/Year/Epi_Week, forward- then back-fill gaps in the
        numeric columns within each district, and set anything still
        missing (or infinite) to 0.0.

    Returns:
        pandas DataFrame with a fresh RangeIndex.
    """
    df = fetch_dengue_weather_from_db()

    df["District"] = df["District"].astype(str).str.strip()
    df["Division"] = df["Division"].astype(str).str.strip()
    df["Year"]     = pd.to_numeric(df["Year"], errors="coerce").astype("Int64")
    df["Epi_Week"] = pd.to_numeric(df["Epi_Week"], errors="coerce").astype("Int64")
    # Rows without a year/week cannot be placed in time. Previously they
    # crashed _iso_week_start (int(<NA>)) or the WeekOfYear astype(int) below.
    df = df.dropna(subset=["Year", "Epi_Week"]).reset_index(drop=True)

    numeric_cols = [
        "weekly_hospitalised_cases",
        "Total_Rainfall",
        "Avg_Humidity",
        "Avg_Temperature",
    ]
    optional_numeric = ["Avg_NDVI", "Avg_NDWI"]
    for col in numeric_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce")

    # Build as a DatetimeIndex so the column is datetime-typed even when empty.
    df["week_start"] = pd.to_datetime([
        _iso_week_start(y, w)
        for (y, w) in zip(df["Year"], df["Epi_Week"])
    ])
    mask = df["week_start"].isna()
    if mask.any():
        tmp = pd.to_datetime(df.loc[mask, "Year"].astype(str) + "-01-01", errors="coerce")
        df.loc[mask, "week_start"] = tmp + pd.to_timedelta(
            (df.loc[mask, "Epi_Week"].astype(float) - 1) * 7,
            unit="D"
        )

    df["Month"] = df["week_start"].dt.month.astype("Int64")
    df["WeekOfYear"] = df["week_start"].dt.isocalendar().week.astype(int)

    df = df.sort_values(["District", "Year", "Epi_Week"]).reset_index(drop=True)

    fill_cols = numeric_cols + [c for c in optional_numeric if c in df.columns]
    df[fill_cols] = df[fill_cols].replace([np.inf, -np.inf], np.nan)
    # transform() returns a frame aligned to df's index, so the per-district
    # fill cannot be mis-assigned the way a positional
    # apply(...).reset_index(drop=True) can when group order != row order.
    df[fill_cols] = (
        df.groupby("District")[fill_cols]
          .transform(lambda s: s.ffill().bfill())
    )
    df[fill_cols] = df[fill_cols].fillna(0.0)

    return df

def add_time_features(df):
    """
    Add cyclic calendar encodings and a numeric year to ``df`` in place.

    New columns (in this order): ``week_sin``, ``week_cos`` (WeekOfYear over
    a 52-week cycle), ``month_sin``, ``month_cos`` (Month over a 12-month
    cycle), all float32, and ``Year_num`` (Year as float).

    Returns:
        The same DataFrame object that was passed in.
    """
    two_pi = 2 * np.pi
    week_angle = two_pi * df["WeekOfYear"].astype(float) / 52.0
    month_angle = two_pi * df["Month"].astype(float) / 12.0

    df["week_sin"] = np.sin(week_angle).astype(np.float32)
    df["week_cos"] = np.cos(week_angle).astype(np.float32)
    df["month_sin"] = np.sin(month_angle).astype(np.float32)
    df["month_cos"] = np.cos(month_angle).astype(np.float32)
    df["Year_num"] = df["Year"].astype(float)
    return df

def add_incidence_derivs(df):
    """
    Add per-district surge-velocity features on a sorted copy of ``df``.

    With ``x`` = weekly_hospitalised_cases and rows ordered by
    District/Year/Epi_Week:
      * ``case_growth_1w`` = (x_t - x_{t-1}) / (x_{t-1} + 1e-6)
      * ``case_growth_2w`` = (x_t - x_{t-2}) / (x_{t-2} + 1e-6)
      * ``acceleration``   = case_growth_1w - case_growth_2w
    Values that are undefined because no earlier week exists are set to 0.0.

    Returns:
        A new DataFrame; the input is not modified.
    """
    out = df.sort_values(["District", "Year", "Epi_Week"]).copy()
    cases = out["weekly_hospitalised_cases"]
    by_district = out.groupby("District")["weekly_hospitalised_cases"]

    lag1 = by_district.shift(1)
    lag2 = by_district.shift(2)

    out["case_growth_1w"] = ((cases - lag1) / (lag1 + 1e-6)).fillna(0.0)
    out["case_growth_2w"] = ((cases - lag2) / (lag2 + 1e-6)).fillna(0.0)
    out["acceleration"] = (out["case_growth_1w"] - out["case_growth_2w"]).fillna(0.0)
    return out

def add_static_district_feats(df):
    """
    Return a copy of ``df`` with a ``district_capacity_proxy`` float column.

    Intended to carry static per-district context (beds, population, ...).
    For now every district is mapped to DEFAULT_CAPACITY_PER_DISTRICT so the
    model has a "scale" input.
    """
    capacity_by_district = {
        name: DEFAULT_CAPACITY_PER_DISTRICT for name in df["District"].unique()
    }
    out = df.copy()
    out["district_capacity_proxy"] = out["District"].map(capacity_by_district).astype(float)
    return out

def add_lags_rolls(df, base_cols, lags=(1, 3, 6, 12), rolls=(3, 6, 12)):
    """
    Add per-district lag and trailing rolling-window features on a sorted copy.

    For each column ``col`` in ``base_cols`` (rows ordered by
    District/Year/Epi_Week):
      * ``{col}_lag{k}`` for every k in ``lags``: value k rows earlier;
      * ``{col}_rmean{w}`` / ``{col}_rstd{w}`` for every w in ``rolls``:
        mean / sample std over the last w rows, including the current row
        (min_periods=1).
    All generated columns have their missing values filled with 0.0.

    Returns:
        A new DataFrame; the input is not modified.
    """
    out = df.sort_values(["District", "Year", "Epi_Week"]).copy()
    created = []
    for col in base_cols:
        for lag in lags:
            name = f"{col}_lag{lag}"
            out[name] = out.groupby("District")[col].shift(lag)
            created.append(name)
        for window in rolls:
            rolling = out.groupby("District")[col].rolling(window, min_periods=1)
            mean_name = f"{col}_rmean{window}"
            std_name = f"{col}_rstd{window}"
            # Dropping the District level restores the original row labels for alignment.
            out[mean_name] = rolling.mean().reset_index(level=0, drop=True)
            out[std_name] = rolling.std().reset_index(level=0, drop=True).fillna(0.0)
            created += [mean_name, std_name]
    out[created] = out[created].fillna(0.0)
    return out


# ----------------- dataset prep -----------------

class Split:
    """
    Bundle of row-aligned matrices for one data split.

    Attributes:
        X: feature matrix.
        y: target matrix.
        c: district-id vector.
        df: the DataFrame the rows came from.
    """
    def __init__(self, X, y, c, df):
        self.X, self.y, self.c, self.df = X, y, c, df

def build_mats(df, feat, target):
    """
    Fit encoders/scalers on ``df`` and return its scaled matrices.

    * District names are label-encoded into ``district_id``.
    * Features ``feat`` are standardized with one global StandardScaler.
    * The target is clipped at 0, log1p-transformed, then standardized with
      a separate StandardScaler per district.

    Args:
        df: input frame; any index is accepted.
        feat: list of feature column names.
        target: target column name.

    Returns:
        (Split, LabelEncoder, feature StandardScaler, dict district_id -> target StandardScaler)
    """
    le = LabelEncoder().fit(df["District"].values)

    d = df.copy()
    d["district_id"] = le.transform(d["District"])

    X = d[feat].values.astype(np.float32)
    feat_scaler = StandardScaler().fit(X)
    Xs = feat_scaler.transform(X)

    y = np.log1p(np.clip(d[target].values.reshape(-1, 1), 0, None)).astype(np.float32)
    Ys = np.zeros_like(y, np.float32)
    y_scalers = {}
    # GroupBy.indices gives *positional* row numbers. Index labels (the
    # previous g.index.values) are only valid positions into these numpy
    # arrays when df has a default RangeIndex.
    positions_by_id = d.groupby("district_id").indices
    for cid in sorted(positions_by_id):
        pos = positions_by_id[cid]
        sc = StandardScaler().fit(y[pos])
        y_scalers[cid] = sc
        Ys[pos] = sc.transform(y[pos])

    return Split(Xs, Ys, d["district_id"].values.astype(np.int64), d), le, feat_scaler, y_scalers

def apply_mats(df, feat, target, le, feat_scaler, y_scalers):
    """
    Transform ``df`` with encoders/scalers previously fitted by ``build_mats``.

    Args:
        df: frame to transform; any index is accepted.
        feat: feature column names (same order as at fit time).
        target: target column name.
        le: fitted LabelEncoder for District.
        feat_scaler: fitted feature StandardScaler.
        y_scalers: dict district_id -> fitted target StandardScaler. A
            district missing from it is scaled with a scaler fitted on its
            own rows here.

    Returns:
        Split with scaled features, scaled log1p target and district ids.

    Raises:
        ValueError: from ``le.transform`` if df contains an unseen district.
    """
    d = df.copy()
    d["district_id"] = le.transform(d["District"])

    X = d[feat].values.astype(np.float32)
    Xs = feat_scaler.transform(X)

    y = np.log1p(np.clip(d[target].values.reshape(-1, 1), 0, None)).astype(np.float32)
    Ys = np.zeros_like(y, np.float32)

    # Positional row numbers per district (safe for non-RangeIndex frames).
    for cid, pos in d.groupby("district_id").indices.items():
        sc = y_scalers.get(int(cid))
        if sc is None:
            # Fit only when needed; the old dict.get default fitted eagerly
            # for every district even when a stored scaler existed.
            sc = StandardScaler().fit(y[pos])
        Ys[pos] = sc.transform(y[pos])

    return Split(Xs, Ys, d["district_id"].values.astype(np.int64), d)

def to_seq(split, L=SEQ, prev_y=True):
    """
    Cut a Split into overlapping length-``L`` windows, district by district.

    Rows of each district are ordered by ``Year * 100 + Epi_Week``. Every run
    of ``L`` consecutive rows (stride 1) becomes one sample.

    Args:
        split: Split whose X, y, c and df rows are aligned.
        L: window length.
        prev_y: if True, append one feature channel holding the previous
            step's target within the window (0 at the first step).

    Returns:
        (SX, SY, SC): float32 (N, L, F[+1]), float32 (N, L, *y.shape[1:]) and
        int64 (N,). When no district has at least ``L`` rows, these are
        correctly shaped empty arrays rather than 1-D ``(0,)`` arrays.
    """
    X, y, c, df = split.X, split.y, split.c, split.df
    iso_idx = (df["Year"].astype(int) * 100 + df["Epi_Week"].astype(int)).values
    SX, SY, SC = [], [], []
    for cid in np.unique(c):
        idx = np.where(c == cid)[0]
        order = np.argsort(iso_idx[idx])
        idx = idx[order]
        for i in range(len(idx) - L + 1):
            sl = idx[i : i + L]
            Xi = X[sl]
            Yi = y[sl]
            if prev_y:
                # Targets shifted right by one step; assumes y has one column.
                prev = np.vstack([np.zeros((1, 1), np.float32), Yi[:-1]])
                Xi = np.concatenate([Xi, prev], 1)
            SX.append(Xi)
            SY.append(Yi)
            SC.append(cid)

    if not SX:
        n_feat = X.shape[1] + (1 if prev_y else 0)
        return (
            np.zeros((0, L, n_feat), np.float32),
            np.zeros((0, L) + tuple(y.shape[1:]), np.float32),
            np.zeros((0,), np.int64),
        )
    return np.asarray(SX, np.float32), np.asarray(SY, np.float32), np.asarray(SC, np.int64)


# ----------------- model -----------------

class CausalTCN(nn.Module):
    """
    Stack of dilated 1-D convolutions with GELU, made causal by cropping.

    Level ``i`` uses dilation ``2**i`` and padding ``(k - 1) * 2**i``.
    Keeping only the first ``L`` outputs means each output step sees only
    current and earlier inputs.

    Input and output are (batch, time, channels); the output has ``hid``
    channels.
    """
    def __init__(self, in_ch, hid=64, levels=3, k=3):
        super().__init__()
        self.blocks = nn.ModuleList()
        for level in range(levels):
            dilation = 2 ** level
            conv = nn.Conv1d(
                in_ch if level == 0 else hid,
                hid,
                kernel_size=k,
                dilation=dilation,
                padding=(k - 1) * dilation,
            )
            self.blocks.append(nn.Sequential(conv, nn.GELU()))

    def forward(self, x):
        seq_len = x.transpose(1, 2).size(-1)
        h = x.transpose(1, 2)  # -> (batch, channels, time) for Conv1d
        for block in self.blocks:
            h = block(h)[..., :seq_len]  # drop the right-hand (future) padding
        return h.transpose(1, 2)

class Generator(nn.Module):
    """
    Conditional sequence generator with mean, log-scale and quantile heads.

    Forward pipeline:
      1. A district embedding and the linearly projected noise are
         concatenated with the conditioning features.
      2. The result goes through a CausalTCN and a one-layer LSTM.
      3. Multi-head self-attention follows, with a mask that hides future
         steps.
      4. LayerNorm + dropout, then linear read-out heads.

    Args:
        cond_dim: conditioning features per step.
        noise_dim: width of the noise input (projected to ``cond_dim``).
        nc: number of districts (embedding table size).
        emb: district-embedding width.
        lstm: hidden width shared by the TCN, LSTM and attention.
        heads: attention heads.
        drop: dropout rate.
        qu: quantile levels; one linear head per entry.
    """
    def __init__(self, cond_dim, noise_dim, nc, emb=8, lstm=LSTM_UNITS, heads=HEADS, drop=DROP, qu=Q):
        super().__init__()
        # Sub-modules are created in a fixed order so parameter init (RNG)
        # and state_dict keys stay stable.
        self.q = qu
        self.ce = nn.Embedding(nc, emb)
        self.pn = nn.Linear(noise_dim, cond_dim)
        self.tcn = CausalTCN(cond_dim * 2 + emb, hid=lstm)
        self.lstm = nn.LSTM(lstm, lstm, 1, batch_first=True)
        self.mha = nn.MultiheadAttention(lstm, heads, batch_first=True, dropout=drop)
        self.ln = nn.LayerNorm(lstm)
        self.drop = nn.Dropout(drop)
        self.mu = nn.Linear(lstm, 1)
        self.ls = nn.Linear(lstm, 1)
        self.qh = nn.ModuleList([nn.Linear(lstm, 1) for _ in qu])
        self.last = None  # detached attention weights from the latest forward

    def forward(self, cond, noise, cid):
        """
        Args:
            cond: (B, L, cond_dim) conditioning tensor.
            noise: (B, L, noise_dim) noise tensor.
            cid: (B,) district ids.

        Returns:
            (mu, ls, qs): mu and ls are (B, L, 1), with ls clamped to
            [-5, 3]; qs is a list of (B, L, 1) tensors, one per quantile head.
        """
        _, seq_len, _ = cond.shape
        district = self.ce(cid).unsqueeze(1).repeat(1, seq_len, 1)
        projected_noise = self.pn(noise)
        hidden = self.tcn(torch.cat([cond, projected_noise, district], -1))
        hidden, _ = self.lstm(hidden)

        # True above the diagonal = position may not attend to later steps.
        future = torch.ones(seq_len, seq_len, device=hidden.device, dtype=torch.bool).triu(diagonal=1)
        attended, attn_weights = self.mha(hidden, hidden, hidden, attn_mask=future, need_weights=True)
        self.last = attn_weights.detach()

        hidden = self.drop(self.ln(attended))
        mu = self.mu(hidden)
        ls = self.ls(hidden).clamp(-5.0, 3.0)
        qs = [head(hidden) for head in self.qh]
        return mu, ls, qs

class Critic(nn.Module):
    """
    WGAN critic over (conditioning, series) pairs.

    The conditioning features, the scored series and a district embedding
    are concatenated per step and run through a one-layer LSTM. The last
    hidden state goes through Linear(lstm, 64) + GELU to give a feature
    vector, then dropout and Linear(64, 1) give the score.

    Returns from ``forward``: (score of shape (B,), feature tensor (B, 64)).
    """
    def __init__(self, cond_dim, nc, emb=8, lstm=LSTM_UNITS, drop=DROP):
        super().__init__()
        self.ce = nn.Embedding(nc, emb)
        self.lstm = nn.LSTM(cond_dim + 1 + emb, lstm, 1, batch_first=True)
        self.fc = nn.Linear(lstm, 64)
        self.drop = nn.Dropout(drop)
        self.out = nn.Linear(64, 1)

    def forward(self, cond, ts, cid):
        _, seq_len, _ = cond.shape
        district = self.ce(cid).unsqueeze(1).repeat(1, seq_len, 1)
        tokens = torch.cat([cond, ts, district], -1)
        hidden_seq, _ = self.lstm(tokens)
        last_hidden = hidden_seq[:, -1, :]
        feat = F.gelu(self.fc(last_hidden))
        score = self.out(self.drop(feat)).squeeze(1)
        return score, feat


# ----------------- loss / train utils -----------------

def smape(y, p):
    """
    Symmetric mean absolute percentage error, in percent (range 0-200).

    Both arrays are flattened. A 1e-8 term in the denominator avoids division
    by zero when both values are zero.
    """
    actual = y.flatten()
    predicted = p.flatten()
    denom = np.abs(actual) + np.abs(predicted) + 1e-8
    return 100 * np.mean(2 * np.abs(predicted - actual) / denom)

def lerp(a, b, t):
    """Linearly interpolate dict values key-by-key: a[k] + (b[k] - a[k]) * t for every key of ``a``."""
    blended = {}
    for key, start in a.items():
        blended[key] = start + (b[key] - start) * t
    return blended

def pinball(pred, target, quantile):
    """
    Mean pinball (quantile) loss.

    With residual r = target - pred, each element contributes
    ``max(q * r, (q - 1) * r)``; the result is the mean over all elements.
    """
    residual = target - pred
    return torch.maximum(quantile * residual, (quantile - 1) * residual).mean()

def tv_l1(x):
    """
    Return (mean |x[:, t+1] - x[:, t]|, mean |x|).

    The first value is the total-variation term along axis 1; the second is
    the mean absolute magnitude.
    """
    steps = torch.diff(x, dim=1)
    return steps.abs().mean(), x.abs().mean()

def gp(critic, yr, yf, cond, cc, lam=10.0):
    """
    WGAN-GP gradient penalty on random interpolates of real and fake series.

    Args:
        critic: callable ``critic(cond, x, cc) -> (score, features)``.
        yr: real series, shape (B, ...).
        yf: fake series, same shape as ``yr``.
        cond: conditioning input passed through to the critic.
        cc: district ids passed through to the critic.
        lam: penalty weight.

    Returns:
        Scalar tensor ``lam * mean_b((||d score_b / d x_b||_2 - 1) ** 2)``.
        It is built with ``create_graph=True``, so it can be back-propagated.
    """
    batch = yr.size(0)
    # One mixing weight per sample, broadcast over every remaining dim and
    # created on the inputs' device. The previous hard-coded (B, 1, 1) shape
    # silently broadcast wrongly for inputs that were not 3-D.
    eps = torch.rand((batch,) + (1,) * (yr.dim() - 1), device=yr.device)
    xi = (eps * yr + (1 - eps) * yf).requires_grad_(True)
    # Presumably here to avoid cuDNN RNN kernels, which don't support the
    # double backward that create_graph=True needs.
    with torch.backends.cudnn.flags(enabled=False):
        score, _ = critic(cond, xi, cc)
    grad = torch.autograd.grad(
        score, xi,
        grad_outputs=torch.ones_like(score),
        create_graph=True, retain_graph=True,
    )[0]
    return ((grad.reshape(batch, -1).norm(2, 1) - 1) ** 2).mean() * lam

def outbreak_penalty(mu, Y):
    """
    Mean under-prediction ``max(Y - mu, 0)``.

    Over-predictions contribute zero, so only falling short of the target
    (e.g. missing an outbreak) is penalized.
    """
    shortfall = torch.clamp(Y - mu, min=0.0)
    return shortfall.mean()

def _g_loss(G, D, X, Y, C, noise, weights, pc_idx):
    """
    Compute the generator's total loss for one batch.

    Terms (scaled by ``weights[...]`` unless noted):
      * "nll": Gaussian negative log-likelihood of Y under (mu, exp(ls)) when
        ABL["hetero"] is on, otherwise L1 between mu and Y.
      * "q": mean pinball loss over the quantile heads (only if ABL["quant"]).
      * "tv" / "l1": total variation and mean magnitude of X[..., pc_idx].
      * "fm" / "adv" (only if ABL["adv"]): L1 feature matching between the
        critic's features for (X, mu) and for the real (X, Y), plus the WGAN
        generator term -mean(critic(X, mu)).
      * outbreak penalty: mean max(Y - mu, 0), fixed weight 0.2.

    Args:
        G: generator, called as ``G(X, noise, C) -> (mu, ls, qs)``.
        D: critic, called as ``D(X, series, C) -> (score, features)``.
        X: conditioning tensor.
        Y: target tensor (same shape as mu).
        C: district ids.
        noise: noise tensor for G.
        weights: dict with keys "nll", "q", "tv", "l1", "fm", "adv".
        pc_idx: indices into X's last axis for the tv/l1 terms.

    Returns:
        Scalar loss tensor.
    """
    # NOTE: G and D contain dropout, so the order of the forward calls below
    # determines RNG consumption.
    mu, ls, qs = G(X, noise, C)
    sig = (ls.exp()).clamp(1e-3, 50.0)

    if ABL["hetero"]:
        # 0.5 * [((Y - mu)/sigma)^2 + log(sigma^2) + log(2*pi)], with log(sigma^2) = 2*ls.
        nll = 0.5 * (((Y - mu) / sig) ** 2 + 2 * ls + math.log(2 * math.pi)).mean()
    else:
        nll = F.l1_loss(mu, Y)

    q_loss = torch.tensor(0.0, device=DEVICE)
    if ABL["quant"]:
        # qs[i] is the head for quantile level Q[i].
        for i, qv in enumerate(Q):
            q_loss += pinball(qs[i], Y, qv)
        q_loss /= len(Q)

    if len(pc_idx) > 0:
        # NOTE(review): these terms are computed on the *input* X, which does
        # not depend on G's parameters, so they look like constants with no
        # gradient for G -- confirm the intended tensor (perhaps mu?).
        tv, l1 = tv_l1(X[..., pc_idx])
    else:
        tv = torch.tensor(0.0, device=DEVICE)
        l1 = torch.tensor(0.0, device=DEVICE)

    fm = torch.tensor(0.0, device=DEVICE)
    adv_term = torch.tensor(0.0, device=DEVICE)
    if ABL["adv"]:
        # Real-pair features serve only as a fixed target for feature matching.
        with torch.no_grad():
            _, real_feat = D(X, Y, C)
        # The critic scores the mean head mu (not a sampled series). Nothing
        # here freezes D, so backward also accumulates gradients into D's
        # parameters; callers must clear them before the critic update.
        score_fake, fake_feat = D(X, mu, C)
        fm = F.l1_loss(fake_feat, real_feat)
        adv_term = -score_fake.mean()

    pen = outbreak_penalty(mu, Y)  # outbreak underprediction cost
    outbreak_w = 0.2

    return (
        weights["nll"]*nll
        + weights["q"]*q_loss
        + weights["tv"]*tv
        + weights["l1"]*l1
        + (weights["fm"]*fm if ABL["adv"] else 0.0)
        + (weights["adv"]*adv_term if ABL["adv"] else 0.0)
        + outbreak_w * pen
    )

def validate(G, Xv, Yv, Cv):
    """
    Score the generator on a validation set in one forward pass.

    G is switched to eval mode (and left there) and fed zero noise.

    Returns:
        (smape, coverage):
          * smape is computed between the mean head and the targets in the
            scaled target space;
          * coverage is the fraction of targets inside the [q10, q90] band
            given by quantile heads 0 and 2.
    """
    G.eval()
    with torch.no_grad():
        cond = torch.tensor(Xv, dtype=torch.float32, device=DEVICE)
        ids = torch.tensor(Cv, dtype=torch.long, device=DEVICE)
        n_seq, seq_len, _ = cond.shape
        zero_noise = torch.zeros(n_seq, seq_len, NOISE, device=DEVICE)

        mu, _, quantiles = G(cond, zero_noise, ids)

        point = mu[..., 0].cpu().numpy()
        truth = Yv[..., 0]
        lower = quantiles[0][..., 0].cpu().numpy()
        upper = quantiles[2][..., 0].cpu().numpy()

        inside_band = (truth >= lower) & (truth <= upper)
        coverage = float(np.mean(inside_band))
        score = smape(truth.reshape(-1), point.reshape(-1))
    return score, coverage


# ----------------- optimizer wrappers -----------------

class Lookahead(torch.optim.Optimizer):
    """
    Lookahead wrapper around a base optimizer.

    A "slow" copy of every trainable parameter is kept. After every ``k``
    base-optimizer steps, the slow weights move towards the fast weights,
    ``slow += alpha * (fast - slow)``, and the fast weights are reset to the
    slow ones.

    NOTE(review): ``torch.optim.Optimizer.__init__`` is not called, so
    base-class members such as ``state``/``defaults``/``state_dict()`` are
    not set up on this wrapper; only ``zero_grad`` and ``step`` are provided.

    Args:
        base_optimizer: optimizer whose ``param_groups`` are shared.
        alpha: slow-weight interpolation factor.
        k: synchronisation period, in steps.
    """
    def __init__(self, base_optimizer, alpha=ALPHA, k=K_SYNC):
        self.base = base_optimizer
        self.param_groups = self.base.param_groups
        self.alpha = alpha
        self.k = k
        self._step = 0

        # One entry per parameter, in param_groups order; None for frozen params.
        self.slow_weights = [
            p.detach().clone() if p.requires_grad else None
            for group in self.param_groups
            for p in group['params']
        ]

    def zero_grad(self, set_to_none=True):
        self.base.zero_grad(set_to_none=set_to_none)

    def step(self, closure=None):
        loss = self.base.step(closure)
        self._step += 1
        if self._step % self.k != 0:
            return loss

        with torch.no_grad():
            params = (p for group in self.param_groups for p in group['params'])
            for slow, p in zip(self.slow_weights, params):
                if not p.requires_grad:
                    continue
                # Keyword ``alpha`` replaces the deprecated add_(Number, Tensor)
                # overload. The slow tensor is updated in place, so no
                # per-sync clone is needed.
                slow.add_(p - slow, alpha=self.alpha)
                p.copy_(slow)
        return loss

class SAM:
    """
    Sharpness-Aware Minimization helper (plain SAM or adaptive ASAM).

    Usage:
        loss.backward()
        sam.first_step()        # w -> w + e(w), grads cleared
        second_loss.backward()  # gradient at the perturbed point
        sam.second_step()       # w + e(w) -> w, clip, optimizer step

    Args:
        optimizer: optimizer exposing ``param_groups``, ``step()`` and
            ``zero_grad(set_to_none=...)`` (a Lookahead-wrapped AdamW here).
        rho: neighbourhood radius.
        adaptive: use ASAM's element-wise scaling T_w = |w|.
    """
    def __init__(self, optimizer, rho=0.05, adaptive=True):
        self.optimizer = optimizer
        self.rho = rho
        self.adaptive = adaptive
        self.e_ws = {}  # parameter tensor -> perturbation applied in first_step

    @torch.no_grad()
    def _grad_norm(self):
        """Global L2 norm of ``|w| * grad`` (adaptive) or of ``grad``."""
        norms = []
        for group in self.optimizer.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                if self.adaptive:
                    norms.append(((p.abs()) * p.grad).norm(p=2))
                else:
                    norms.append((p.grad).norm(p=2))
        if not norms:
            return torch.tensor(0.0, device=DEVICE)
        return torch.norm(torch.stack(norms), p=2)

    @torch.no_grad()
    def first_step(self):
        """Perturb every parameter that has a gradient, then clear grads."""
        grad_norm = self._grad_norm()
        scale = self.rho / (grad_norm + 1e-12)
        self.e_ws = {}

        for group in self.optimizer.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                if self.adaptive:
                    # ASAM: e_w = rho * T_w^2 g / ||T_w g|| with T_w = |w|.
                    # The norm above already uses |w| * g, so the perturbation
                    # must scale by w**2. The previous |w| * g scaling did not
                    # match the ASAM formula.
                    e_w = torch.pow(p, 2) * p.grad * scale
                else:
                    e_w = p.grad * scale
                p.add_(e_w)
                self.e_ws[p] = e_w
        self.optimizer.zero_grad(set_to_none=True)

    @torch.no_grad()
    def second_step(self, clip_norm=CLIP):
        """Undo the perturbation, clip gradients, step and clear grads."""
        for group in self.optimizer.param_groups:
            for p in group["params"]:
                if p in self.e_ws:
                    p.sub_(self.e_ws[p])
        params = []
        for group in self.optimizer.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    params.append(p)
        if len(params) > 0:
            torch.nn.utils.clip_grad_norm_(params, clip_norm)
        self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)
        self.e_ws = {}

def train_gan(Xtr,Ytr,Ctr,Xva,Yva,Cva,cond_dim,nc,pc_idx):
    """
    Train the Generator / Critic pair and return the best checkpoint.

    Per epoch, the schedule variable t = epoch / (EPOCHS_GAN - 1) linearly
    ramps:
      * the noise scale from NSIG[0] to NSIG[1];
      * the loss weights from W_START to W_END;
      * the teacher-forcing mix from TF_START to TF_END.

    Per batch:
      1. Optionally one WGAN-GP critic step (Lookahead-wrapped AdamW).
      2. One generator step with SAM on top of Lookahead-wrapped AdamW.

    After each epoch, ``validate`` yields (smape, coverage). Early stopping
    uses ``smape + 10 * |coverage - 0.9|`` with patience PATIENCE. The best
    weights are restored and saved to ``{OUT}/generator.pt`` and
    ``{OUT}/critic.pt``.

    Args:
        Xtr, Ytr, Ctr: training sequences, targets and district ids
            (presumably the output of ``to_seq`` -- TODO confirm).
        Xva, Yva, Cva: validation arrays, same layout.
        cond_dim: feature width of X (passed to both networks).
        nc: number of districts.
        pc_idx: feature indices for the tv/l1 loss terms.

    Returns:
        (G, D) with the best-scoring weights loaded.
    """
    G=Generator(cond_dim=cond_dim,noise_dim=NOISE,nc=nc).to(DEVICE)
    D=Critic(cond_dim=cond_dim,nc=nc).to(DEVICE)

    baseG = torch.optim.AdamW(G.parameters(),lr=LR_G*TTUR[0],betas=(0.9,0.999),weight_decay=WD)
    baseD = torch.optim.AdamW(D.parameters(),lr=LR_D*TTUR[1],betas=(0.9,0.999),weight_decay=WD)
    optG  = Lookahead(baseG, alpha=ALPHA, k=K_SYNC)
    optD  = Lookahead(baseD, alpha=ALPHA, k=K_SYNC)
    sam   = SAM(optG, rho=0.05, adaptive=True)

    dl=DataLoader(SeqDS(Xtr,Ytr,Ctr),batch_size=BATCH,shuffle=True,drop_last=True)

    best=float("inf"); best_sd=None; wait=0
    for epoch in range(EPOCHS_GAN):
        # NOTE(review): divides by zero if EPOCHS_GAN == 1.
        t=epoch/(EPOCHS_GAN-1)
        noise_scale = NSIG[0]+(NSIG[1]-NSIG[0])*t
        W_START={"nll":1.0,"q":0.5,"tv":0.05,"l1":0.02,"fm":0.2,"adv":0.5}
        W_END  ={"nll":1.0,"q":1.0,"tv":0.10,"l1":0.05,"fm":0.1,"adv":0.4}
        weights=lerp(W_START,W_END,t)
        tf=TF_START+(TF_END-TF_START)*t

        G.train(); D.train()
        for Xb,Yb,Cb in dl:
            # NOTE(review): the DataLoader already yields tensors; torch.tensor()
            # copies them and PyTorch warns about it (warnings are silenced at
            # the top of the file). ``Xb.to(DEVICE, dtype=...)`` would avoid the copy.
            Xb=torch.tensor(Xb,dtype=torch.float32,device=DEVICE)
            Yb=torch.tensor(Yb,dtype=torch.float32,device=DEVICE)
            Cb=torch.tensor(Cb,dtype=torch.long,device=DEVICE)
            B,L,_=Xb.shape

            # teacher forcing schedule (bleeding in model prev_y)
            # The last feature channel is overwritten with a blend of its true
            # value (weight tf) and G's own one-step-shifted mean prediction
            # (weight 1 - tf). G is in train mode here, so dropout is active
            # during this no-grad pass.
            with torch.no_grad():
                mu0,_,_=G(Xb,torch.zeros(B,L,NOISE,device=DEVICE),Cb)
                prev=torch.cat([torch.zeros(B,1,device=DEVICE),mu0[:,:-1,0]],1)
            Xb[:,:,-1]=tf*Xb[:,:,-1]+(1-tf)*prev

            # --- train discriminator ---
            if ABL["adv"]:
                # This zero_grad also clears gradients that the previous
                # generator step accumulated in D's parameters.
                optD.zero_grad(set_to_none=True)
                z=torch.randn(B,L,NOISE,device=DEVICE)*noise_scale
                mu,_,_=G(Xb,z,Cb)
                Y_fake=mu.detach()
                score_real,_=D(Xb,Yb,Cb)
                score_fake,_=D(Xb,Y_fake,Cb)
                # WGAN critic loss plus gradient penalty (weight 10).
                d_loss=-(score_real.mean()-score_fake.mean())+gp(D,Yb,Y_fake,Xb,Cb,10.0)
                d_loss.backward()
                # clip & step discriminator
                d_params = []
                for group in optD.param_groups:
                    for p in group["params"]:
                        if p.grad is not None:
                            d_params.append(p)
                if len(d_params)>0:
                    torch.nn.utils.clip_grad_norm_(d_params,CLIP)
                optD.step()

            # --- train generator with SAM ---
            # First pass: gradient at w, used by SAM to perturb to w + e(w).
            optG.zero_grad(set_to_none=True)
            z1=torch.randn(B,L,NOISE,device=DEVICE)*noise_scale
            L1=_g_loss(G,D,Xb,Yb,Cb,z1,weights,pc_idx)
            L1.backward()
            sam.first_step()

            # second forward/backward at perturbed weights
            # (fresh noise; SAM then restores w, clips and steps optG).
            z2=torch.randn(B,L,NOISE,device=DEVICE)*noise_scale
            L2=_g_loss(G,D,Xb,Yb,Cb,z2,weights,pc_idx)
            L2.backward()
            sam.second_step()

        # --- validation / early stop ---
        # Composite score: SMAPE plus a penalty for 10-90 band coverage away from 0.9.
        sm,cov=validate(G,Xva,Yva,Cva)
        comp=sm+10*abs(cov-0.9)
        if comp<best-1e-6:
            best=comp
            best_sd=(copy.deepcopy(G.state_dict()),copy.deepcopy(D.state_dict()))
            wait=0
        else:
            wait+=1
            if wait>=PATIENCE:
                break

    if best_sd:
        G.load_state_dict(best_sd[0]); D.load_state_dict(best_sd[1])
    torch.save(G.state_dict(),f"{OUT}/generator.pt")
    torch.save(D.state_dict(),f"{OUT}/critic.pt")
    return G,D


# ----------------- forecast band generator -----------------

def psd(matrix):
    """
    Return a positive-definite version of a symmetric matrix, as float32.

    The matrix is eigendecomposed with ``np.linalg.eigh``, eigenvalues below
    1e-6 are raised to 1e-6, and the matrix is rebuilt as V diag(w) V^T.
    """
    w, v = np.linalg.eigh(matrix)
    floored = np.where(w < 1e-6, 1e-6, w)
    rebuilt = v @ np.diag(floored) @ v.T
    return rebuilt.astype(np.float32)

def copula_bands(G,X,C,K=K_MC,rho=0.5):
    """
    Monte-Carlo 10/50/90 bands from G's Gaussian head with AR(1)-style
    temporal correlation. Kept for backward compatibility; the main pipeline
    relies on the direct quantile heads.

    For each sequence i, with sigma = exp(ls) clipped to [1e-3, 50] and
    multiplied by SIG_CALIBRATION_FACTOR, ``K`` draws are taken from
    N(mu_i, cov), where cov[a, b] = sigma_a * sigma_b * rho**|a - b|
    (projected to positive-definite by ``psd``).

    Args:
        G: generator; called with zero noise, in eval mode (left in eval).
        X: (N, L, F) conditioning array.
        C: (N,) district ids.
        K: samples per sequence.
        rho: lag-1 correlation of the temporal kernel (default keeps the
            previous hard-coded 0.5).

    Returns:
        (p10, p50, p90), each of shape (N, L). Empty (0, L) arrays are
        returned for N == 0, where np.stack would otherwise fail.
    """
    G.eval()
    Xt=torch.tensor(X,dtype=torch.float32,device=DEVICE)
    Ct=torch.tensor(C,dtype=torch.long,device=DEVICE)
    with torch.no_grad():
        mu,ls,_=G(Xt,torch.zeros(Xt.shape[0],Xt.shape[1],NOISE,device=DEVICE),Ct)
    mu=mu.cpu().numpy()[...,0]
    sig=np.clip(np.exp(ls.cpu().numpy()[...,0]),1e-3,50.0)
    sig = sig * SIG_CALIBRATION_FACTOR

    # The correlation kernel depends only on L, so build it once, not per sequence.
    L=Xt.shape[1]
    R=np.fromfunction(lambda a,b: rho**np.abs(a-b),(L,L))

    samples=[]
    for i in range(Xt.shape[0]):
        cov=psd((sig[i][:,None]*sig[i][None,:])*R)
        z=np.random.multivariate_normal(np.zeros(L),cov,size=K).astype(np.float32)
        samples.append(mu[i][None,:]+z)
    if not samples:
        empty=np.zeros((0,L),np.float32)
        return empty,empty.copy(),empty.copy()
    samples=np.stack(samples,0)
    return np.percentile(samples,10,1),np.percentile(samples,50,1),np.percentile(samples,90,1)


# ----------------- director tables helpers -----------------

class SeqDS(Dataset):
    """
    Map-style Dataset over three row-aligned arrays.

    Item ``i`` is the tuple ``(X[i], Y[i], C[i])``; the length is ``X.shape[0]``.
    """
    def __init__(self, X, Y, C):
        self.X, self.Y, self.C = X, Y, C

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, i):
        return (self.X[i], self.Y[i], self.C[i])

def calc_reporting_continuity(df_full_raw, seq_len=SEQ):
    """
    Percentage of the last ``seq_len`` weeks with a non-missing case count,
    per district.

    NOTE(review): this is only meaningful on data whose missing case counts
    have not been filled; confirm the caller passes un-imputed rows.

    Returns:
        DataFrame with columns ``District`` and
        ``Reporting_Continuity_Last{seq_len}w(%)``, one row per district.
    """
    label = f"Reporting_Continuity_Last{seq_len}w(%)"
    records = []
    for district, frame in df_full_raw.groupby("District"):
        last_weeks = frame.sort_values(["Year", "Epi_Week"]).tail(seq_len)
        share_reported = last_weeks["weekly_hospitalised_cases"].notna().mean()
        records.append({"District": district, label: 100.0 * share_reported})
    return pd.DataFrame(records)

def classify_data_quality(continuity_pct):
    """Map a reporting-continuity percentage to a label: >=90 reliable, >=70 watch, else field check (NaN -> field check)."""
    for floor, label in ((90, "✅ Reliable"), (70, "🟡 Watch")):
        if continuity_pct >= floor:
            return label
    return "🔴 Needs field check"

def classify_growth(g):
    """Label a growth rate against RAPID_GROWTH_THRESH / MODERATE_GROWTH_THRESH (inclusive lower bounds; NaN -> stable)."""
    bands = (
        (RAPID_GROWTH_THRESH, "🔴 Rapid growth"),
        (MODERATE_GROWTH_THRESH, "🟠 Moderate growth"),
    )
    for floor, label in bands:
        if g >= floor:
            return label
    return "🟢 Stable/decline"

def classify_risk(high_scenario):
    """Label a high-scenario case count: >=40 surge likely, >=20 elevated, else low (NaN -> low)."""
    for floor, label in ((40, "🔴 Surge likely"), (20, "🟠 Elevated")):
        if high_scenario >= floor:
            return label
    return "🟢 Low"

def classify_conf(width):
    """Label an uncertainty-band width: <=5 high, <=15 medium, else low confidence (NaN -> low)."""
    for ceiling, label in ((5, "✅ High confidence"), (15, "🟡 Medium confidence")):
        if width <= ceiling:
            return label
    return "⚠ Low confidence"

def build_isochrone_table(per_district_info):
    """Wrap per-district isochrone records (anything ``pd.DataFrame`` accepts) in a DataFrame."""
    table = pd.DataFrame(per_district_info)
    return table

def calc_climate_lag_influence(
    Xseq, Cseq, feats, next_week_pred_by_seq
):
    """
    Produce a clean, human-readable climate/lag influence table.

    For each feature in ``feats``, compute the Pearson correlation between:
      - the feature's value at the "current" timestep of every sequence, and
      - that sequence's next-week forecast.

    The table:
      - excludes the trailing 'prev_y' channel of ``Xseq``
      - parses names into base_var plus readable lag/roll text
      - collapses seasonality encodings to season(week) / season(month)
      - is sorted by |corr|, descending

    Args:
        Xseq: array indexed as ``Xseq[seq, timestep, channel]``. The last
            channel is assumed to be the appended ``prev_y``. The remaining
            channels are assumed to align 1:1 with ``feats``.
        Cseq: per-sequence district codes. Unused; kept for interface
            compatibility.
        feats: feature names for ``Xseq[..., :-1]``.
        next_week_pred_by_seq: one forecast value per sequence.

    Returns:
        DataFrame with columns feature, base_var, lag_info,
        pearson_corr_with_next_week_forecast and abs_corr. The correlation is
        reported as 0.0 in two cases, because Pearson's r is undefined there:
          - the feature column is (numerically) constant
          - the forecasts are constant or there are fewer than two sequences
    """
    Xseq = np.asarray(Xseq)
    # The "current" timestep is the second-to-last one, mirroring idx_curr in
    # evaluate_and_summarize. Fall back to the last step for length-1 windows
    # instead of raising IndexError.
    t_idx = -2 if Xseq.shape[1] >= 2 else -1
    feature_matrix = Xseq[:, t_idx, :-1]   # drop prev_y
    preds = np.asarray(next_week_pred_by_seq, dtype=float)
    preds_constant = preds.size < 2 or np.std(preds) < 1e-8

    # Checked in this order; the first marker found in the name wins.
    suffix_rules = (
        ("_rmean", "{}w rolling mean"),
        ("_rstd", "{}w rolling std"),
        ("_lag", "lag {}w"),
    )

    def parse_feature_name(raw_name: str):
        """
        Examples:
          Total_Rainfall_lag3    -> base_var='Total_Rainfall', lag_info='lag 3w'
          Avg_Temperature_rmean6 -> base_var='Avg_Temperature', lag_info='6w rolling mean'
          weekly_hospitalised_cases_rstd12 -> base_var='weekly_hospitalised_cases', lag_info='12w rolling std'
          week_sin/cos           -> base_var='season(week)',  lag_info=''
          month_sin/cos          -> base_var='season(month)', lag_info=''
          Year_num               -> base_var='Year_num',      lag_info='same week'
          Avg_NDVI               -> base_var='Avg_NDVI',      lag_info='same week'
        """
        name = raw_name

        # seasonality encodings
        if name in ("week_sin", "week_cos"):
            return "season(week)", ""
        if name in ("month_sin", "month_cos"):
            return "season(month)", ""

        # Rolling mean/std, then lags. Split on the LAST occurrence so names
        # that contain a marker more than once (e.g. "x_lag1_lag2") do not
        # break tuple unpacking.
        for marker, template in suffix_rules:
            if marker in name:
                base, _, tail = name.rpartition(marker)
                return base, template.format(tail)

        # default is "same week"
        return name, "same week"

    rows = []
    for fi, raw_feat_name in enumerate(feats):  # feats aligns with Xseq[..., :-1]
        col_vals = feature_matrix[:, fi]
        if preds_constant or np.std(col_vals) < 1e-8:
            corr = 0.0
        else:
            corr = float(np.corrcoef(col_vals, preds)[0, 1])

        base_var, lag_info = parse_feature_name(raw_feat_name)
        rows.append({
            "feature": raw_feat_name,
            "base_var": base_var,
            "lag_info": lag_info,
            "pearson_corr_with_next_week_forecast": float(corr),
        })

    # Explicit columns keep the schema intact even when feats is empty.
    out_df = pd.DataFrame(
        rows,
        columns=["feature", "base_var", "lag_info", "pearson_corr_with_next_week_forecast"],
    )
    out_df["abs_corr"] = out_df["pearson_corr_with_next_week_forecast"].abs()
    out_df = out_df.sort_values("abs_corr", ascending=False)
    return out_df


# ----------------- evaluate + summarize -----------------

def evaluate_and_summarize(G, Xseq, Yseq, Cseq, ysc, le, df_full_raw, feats):
    """
    Run the generator on evaluation sequences and build the director tables.

    Args:
        G: trained generator. It is called as ``G(X, noise, C)`` and returns
            ``(mu, ls, qs)``. ``qs[0]`` and ``qs[2]`` are used as the p10 and
            p90 quantiles.
        Xseq: sequence inputs.
        Yseq: scaled targets.
        Cseq: district codes, one per sequence.
        ysc: per-district target scalers indexed by district code.
        le: label encoder mapping district codes back to names.
        df_full_raw: full weekly frame. It supplies each district's last
            Year/Epi_Week and its reporting continuity.
        feats: feature names aligned with ``Xseq[..., :-1]``.

    Returns:
        A tuple ``(overall_metrics, products)``:
          - ``overall_metrics`` holds the pooled accuracy/coverage metrics.
          - ``products`` holds the director DataFrames, the executive summary
            and the per-district last epi-week metadata.
    """
    G.eval()

    X_torch = torch.tensor(Xseq, dtype=torch.float32, device=DEVICE)
    C_torch = torch.tensor(Cseq, dtype=torch.long, device=DEVICE)

    with torch.no_grad():
        # Noise is fixed at zero so the forecasts do not depend on random draws.
        mu, _ls, qs = G(
            X_torch,
            torch.zeros(X_torch.shape[0], X_torch.shape[1], NOISE, device=DEVICE),
            C_torch
        )
        mp_scaled  = mu[...,0].cpu().numpy()
        q10_scaled = qs[0][...,0].cpu().numpy()
        q90_scaled = qs[2][...,0].cpu().numpy()

    def inv_scale(seq_arr_scaled, scaler):
        # Undo the scaler, then apply expm1. That presumably inverts a log1p
        # target transform applied upstream; verify against build_mats.
        # Negative results are clipped to 0.
        seq_arr_scaled = seq_arr_scaled.reshape(-1,1)
        unstd = scaler.inverse_transform(seq_arr_scaled).reshape(-1)
        return np.clip(np.expm1(unstd), 0, None)

    last_week_cache = {}

    def last_reported_week(dist_name):
        # Latest (Year, Epi_Week) for the district in df_full_raw, or
        # (None, None) if the district has no rows. Computed once per
        # district rather than once per sequence.
        if dist_name not in last_week_cache:
            gdist = df_full_raw[df_full_raw["District"] == dist_name].sort_values(["Year", "Epi_Week"])
            if len(gdist) > 0:
                lastrow = gdist.iloc[-1]
                last_week_cache[dist_name] = (int(lastrow["Year"]), int(lastrow["Epi_Week"]))
            else:
                last_week_cache[dist_name] = (None, None)
        return last_week_cache[dist_name]

    district_names = []
    last_actual_list = []
    current_actual_list = []
    current_pred_list = []
    next_week_pred_list = []
    next_week_hi_list = []
    next_week_lo_list = []
    growth_rate_list = []
    conf_width_list = []
    peak_week_list = []
    peak_value_list = []
    peak_hi_list = []
    peak_leadtime_list = []
    epiweek_meta_rows = []

    yF_all=[]; mF_all=[]; lF_all=[]; hF_all=[]

    # Within each sequence: the second-to-last step is "this week" and the
    # last step is "next week".
    idx_curr = -2
    idx_next = -1

    last_epi_meta = {}

    # NOTE(review): this loop emits one row per evaluated sequence. If Cseq
    # contains several windows for the same district, the district's rows
    # repeat with the same (district, year, epi_week) key. Confirm that to_seq
    # yields one test window per district.
    for i, cid in enumerate(Cseq):
        scaler = ysc[int(cid)]
        dist_name = le.inverse_transform([cid])[0]

        y_real  = inv_scale(Yseq[i].reshape(-1),     scaler)
        m_real  = inv_scale(mp_scaled[i].reshape(-1),scaler)
        lo_real = inv_scale(q10_scaled[i].reshape(-1),scaler)
        hi_real = inv_scale(q90_scaled[i].reshape(-1),scaler)

        yF_all.append(y_real)
        mF_all.append(m_real)
        lF_all.append(lo_real)
        hF_all.append(hi_real)

        safe_idx_curr = idx_curr if idx_curr >= -len(y_real) else -1
        safe_idx_next = idx_next if idx_next >= -len(y_real) else -1

        actual_curr = y_real[safe_idx_curr]
        pred_curr   = m_real[safe_idx_curr]
        actual_prev = y_real[safe_idx_curr-1] if (safe_idx_curr-1)>=-len(y_real) else y_real[safe_idx_curr]

        pred_next   = m_real[safe_idx_next]
        hi_next     = hi_real[safe_idx_next]
        lo_next     = lo_real[safe_idx_next]

        growth = (pred_next - actual_prev) / (actual_prev + 1e-6)
        conf_width = hi_next - lo_next

        # The window is the last PEAK_LOOKAHEAD_WEEKS predictions, ending at
        # and including position safe_idx_next. Positive bounds are required:
        # the negative end index ``safe_idx_next + 1`` is 0 when
        # safe_idx_next == -1, which would make the slice empty.
        win_end = len(m_real) + safe_idx_next + 1
        win_start = max(0, win_end - PEAK_LOOKAHEAD_WEEKS)
        look_slice = m_real[win_start:win_end]
        hi_slice   = hi_real[win_start:win_end]
        if len(look_slice)==0:
            peak_when="NA"; peak_val=np.nan; peak_hi=np.nan; lead_weeks=np.nan
        else:
            local_max_idx = int(np.argmax(look_slice))
            peak_val = look_slice[local_max_idx]
            peak_hi  = hi_slice[local_max_idx]
            lead_weeks = (len(look_slice)-1) - local_max_idx
            peak_when = f"t+{lead_weeks}w"

        last_year, last_epi = last_reported_week(dist_name)

        district_names.append(dist_name)
        last_actual_list.append(actual_prev)
        current_actual_list.append(actual_curr)
        current_pred_list.append(pred_curr)
        next_week_pred_list.append(pred_next)
        next_week_hi_list.append(hi_next)
        next_week_lo_list.append(lo_next)
        growth_rate_list.append(growth)
        conf_width_list.append(conf_width)
        peak_week_list.append(peak_when)
        peak_value_list.append(peak_val)
        peak_hi_list.append(peak_hi)
        peak_leadtime_list.append(lead_weeks)

        last_epi_meta[dist_name] = {
            "Year": last_year,
            "Epi_Week": last_epi
        }

        epiweek_meta_rows.append({
            "District": dist_name,
            "Year": last_year,
            "Epi_Week": last_epi,
            "Forecast_next_week_median": float(pred_next),
            "Forecast_next_week_hi": float(hi_next),
            "Growth_flag": classify_growth(growth)
        })

    yF_all = np.concatenate(yF_all)
    mF_all = np.concatenate(mF_all)
    lF_all = np.concatenate(lF_all)
    hF_all = np.concatenate(hF_all)

    overall_metrics = {
        "SMAPE": smape(yF_all, mF_all),
        "MSE": mean_squared_error(yF_all, mF_all),
        "RMSE": math.sqrt(mean_squared_error(yF_all, mF_all)),
        "R2": r2_score(yF_all, mF_all),
        # NOTE(review): this is the share of actuals inside [p10, p90], which
        # is nominally an 80% interval despite the key name.
        "Coverage90": float(np.mean((yF_all >= lF_all) & (yF_all <= hF_all))),
    }

    quality_df_raw = calc_reporting_continuity(df_full_raw, seq_len=SEQ)
    quality_df = pd.merge(
        pd.DataFrame({"District": district_names}),
        quality_df_raw,
        on="District",
        how="left"
    )
    cont_col = f"Reporting_Continuity_Last{SEQ}w(%)"
    quality_df["Data_quality_flag"] = quality_df[cont_col].apply(classify_data_quality)

    watchlist_df = pd.DataFrame({
        "District": district_names,
        "Expected_cases_next_week": np.round(next_week_pred_list,2),
        "High_scenario_p90": np.round(next_week_hi_list,2),
    })
    watchlist_df["Status"] = watchlist_df["High_scenario_p90"].apply(classify_risk)
    watchlist_df = watchlist_df.sort_values("High_scenario_p90", ascending=False)

    capacity_lookup = {d: DEFAULT_CAPACITY_PER_DISTRICT for d in district_names}
    overflow_df = pd.DataFrame({
        "District": district_names,
        "Capacity_threshold_beds_per_week": [capacity_lookup[d] for d in district_names],
        "Forecast_median_next_week": np.round(next_week_pred_list,2),
        "High_scenario_p90": np.round(next_week_hi_list,2),
    })
    overflow_df["Breach_risk_flag"] = [
        "YES" if hi > capacity_lookup[d] else "NO"
        for d, hi in zip(district_names, next_week_hi_list)
    ]
    overflow_df = overflow_df.sort_values("High_scenario_p90", ascending=False)

    accel_df = pd.DataFrame({
        "District": district_names,
        "Last_week_cases": np.round(last_actual_list,2),
        "This_week_actual": np.round(current_actual_list,2),
        "This_week_predicted": np.round(current_pred_list,2),
        "Next_week_forecast": np.round(next_week_pred_list,2),
        "Growth_rate_WoW(%)": np.round(np.array(growth_rate_list)*100,1),
    })
    accel_df["Growth_flag"] = accel_df["Growth_rate_WoW(%)"].apply(lambda pct: classify_growth(pct/100.0))
    accel_df = accel_df.sort_values("Growth_rate_WoW(%)", ascending=False)

    confidence_df = pd.DataFrame({
        "District": district_names,
        "Forecast_next_week": np.round(next_week_pred_list,2),
        "Uncertainty_width(p90-p10)": np.round(conf_width_list,2),
    })
    confidence_df["Confidence_flag"] = confidence_df["Uncertainty_width(p90-p10)"].apply(classify_conf)
    confidence_df = confidence_df.sort_values("Uncertainty_width(p90-p10)", ascending=True)

    peak_df = pd.DataFrame({
        "District": district_names,
        "Peak_lead_time_weeks": peak_leadtime_list,
        "Peak_cases_median": np.round(peak_value_list,2),
        "Peak_cases_high(p90)": np.round(peak_hi_list,2),
        "Peak_when": peak_week_list
    }).sort_values("Peak_cases_median", ascending=False)

    isochrone_df = build_isochrone_table(epiweek_meta_rows)

    gap_pct_list = []
    for a, p in zip(current_actual_list, current_pred_list):
        gap_pct = (p - a)/(a + 1e-6)*100.0
        gap_pct_list.append(gap_pct)
    nowcast_gap_df = pd.DataFrame({
        "District": district_names,
        "Current_week_actual": np.round(current_actual_list,2),
        "Current_week_predicted_from_prev": np.round(current_pred_list,2),
        "Nowcast_gap_percent": np.round(gap_pct_list,1),
    }).sort_values("Nowcast_gap_percent", ascending=False)

    climate_influence_df = calc_climate_lag_influence(
        Xseq, Cseq, feats, next_week_pred_list
    )

    top5 = watchlist_df.head(5)[["District","Expected_cases_next_week","High_scenario_p90","Status"]].to_dict(orient="records")
    rapid = accel_df[accel_df["Growth_flag"]=="🔴 Rapid growth"]["District"].tolist()
    overflow_risk = overflow_df[overflow_df["Breach_risk_flag"]=="YES"]["District"].tolist()
    dq_bad = quality_df[quality_df["Data_quality_flag"]=="🔴 Needs field check"]["District"].tolist()
    summary_text = {
        "Top5_high_risk_next_week": to_python(top5),
        "Districts_with_rapid_growth": to_python(rapid),
        "Districts_with_capacity_breach_risk": to_python(overflow_risk),
        "Districts_with_data_quality_issues": to_python(dq_bad),
        "Model_calibration": {
            "Coverage90": round(float(overall_metrics["Coverage90"]),3),
            "SMAPE": round(float(overall_metrics["SMAPE"]),2),
            "R2": round(float(overall_metrics["R2"]),3),
        }
    }

    products = {
        "watchlist_df": watchlist_df,
        "overflow_df": overflow_df,
        "accel_df": accel_df,
        "confidence_df": confidence_df,
        "quality_df": quality_df,
        "peak_df": peak_df,
        "isochrone_df": isochrone_df,
        "nowcast_gap_df": nowcast_gap_df,
        "climate_influence_df": climate_influence_df,
        "executive_summary": summary_text,
        "overall_metrics": overall_metrics,
        "last_epi_meta": last_epi_meta,
    }
    return overall_metrics, products


# ----------------- push results to DB (UNCHANGED OUTPUT SCHEMA) -----------------

def push_products_to_db(conn, products):
    """
    Write every director table in ``products`` to Postgres via ``upsert_table``.

    Three kinds of table are written (output schemas are unchanged):
      - District-level tables are keyed on (district, year, epi_week) and
        upserted without wiping. Year and epi_week come from
        ``products["last_epi_meta"]``; both are None if the district is
        missing there.
      - The isochrone table is also keyed, but takes Year/Epi_Week from its
        own rows.
      - The climate-influence and executive-summary tables have no key and
        are wiped before insert.
    """
    epi_meta = products["last_epi_meta"]

    def send(table, columns, rows, keyed=True):
        # Keyed tables upsert on (district, year, epi_week); unkeyed tables
        # are replaced wholesale.
        upsert_table(
            conn,
            table,
            columns,
            rows,
            pk_cols=["district", "year", "epi_week"] if keyed else None,
            wipe_first=not keyed,
        )

    def epi_rows(frame, extract):
        # One tuple per row: (District, Year, Epi_Week, *extract(row)).
        out = []
        for _, rec in frame.iterrows():
            meta = epi_meta.get(rec["District"], {})
            head = (rec["District"], meta.get("Year"), meta.get("Epi_Week"))
            out.append(head + tuple(extract(rec)))
        return out

    def as_int_or_none(value):
        return None if pd.isna(value) else int(value)

    # WATCHLIST
    send(
        "dengue_watchlist",
        ["district","year","epi_week","expected_cases_next_week","high_scenario_p90","status"],
        epi_rows(products["watchlist_df"], lambda rec: (
            py(rec["Expected_cases_next_week"]),
            py(rec["High_scenario_p90"]),
            rec["Status"],
        )),
    )

    # OVERFLOW RISK
    send(
        "dengue_overflow_risk",
        ["district","year","epi_week","capacity_threshold_beds_per_week","forecast_median_next_week","high_scenario_p90","breach_risk_flag"],
        epi_rows(products["overflow_df"], lambda rec: (
            as_int_or_none(rec["Capacity_threshold_beds_per_week"]),
            py(rec["Forecast_median_next_week"]),
            py(rec["High_scenario_p90"]),
            rec["Breach_risk_flag"],
        )),
    )

    # ACCELERATION ALERTS
    send(
        "dengue_acceleration_alerts",
        ["district","year","epi_week","last_week_cases","this_week_actual","this_week_predicted","next_week_forecast","growth_rate_wow","growth_flag"],
        epi_rows(products["accel_df"], lambda rec: (
            py(rec["Last_week_cases"]),
            py(rec["This_week_actual"]),
            py(rec["This_week_predicted"]),
            py(rec["Next_week_forecast"]),
            py(rec["Growth_rate_WoW(%)"]),
            rec["Growth_flag"],
        )),
    )

    # CONFIDENCE
    send(
        "dengue_confidence",
        ["district","year","epi_week","forecast_next_week","uncertainty_width","confidence_flag"],
        epi_rows(products["confidence_df"], lambda rec: (
            py(rec["Forecast_next_week"]),
            py(rec["Uncertainty_width(p90-p10)"]),
            rec["Confidence_flag"],
        )),
    )

    # QUALITY: the continuity column name embeds the window length, so it is
    # located by prefix.
    quality_df = products["quality_df"]
    rep_col = [col for col in quality_df.columns if "Reporting_Continuity_Last" in col][0]
    send(
        "dengue_surveillance_quality",
        ["district","year","epi_week","reporting_continuity_pct","data_quality_flag"],
        epi_rows(quality_df, lambda rec: (
            py(rec[rep_col]),
            rec["Data_quality_flag"],
        )),
    )

    # PEAK
    send(
        "dengue_peak_projection",
        ["district","year","epi_week","peak_lead_time_weeks","peak_cases_median","peak_cases_high_p90","peak_when"],
        epi_rows(products["peak_df"], lambda rec: (
            py(rec["Peak_lead_time_weeks"]),
            py(rec["Peak_cases_median"]),
            py(rec["Peak_cases_high(p90)"]),
            rec["Peak_when"],
        )),
    )

    # ISOCHRONE: Year/Epi_Week are read from the table itself.
    iso_rows = [
        (
            rec["District"],
            rec["Year"],
            rec["Epi_Week"],
            py(rec["Forecast_next_week_median"]),
            py(rec["Forecast_next_week_hi"]),
            rec["Growth_flag"],
        )
        for _, rec in products["isochrone_df"].iterrows()
    ]
    send(
        "dengue_isochrone_spread",
        ["district","year","epi_week","forecast_next_week_median","forecast_next_week_hi","growth_flag"],
        iso_rows,
    )

    # NOWCAST GAP
    send(
        "dengue_nowcast_gap",
        ["district","year","epi_week","current_week_actual","current_week_predicted_from_prev","nowcast_gap_percent"],
        epi_rows(products["nowcast_gap_df"], lambda rec: (
            py(rec["Current_week_actual"]),
            py(rec["Current_week_predicted_from_prev"]),
            py(rec["Nowcast_gap_percent"]),
        )),
    )

    # CLIMATE INFLUENCE (no key; table is rewritten)
    clim_rows = [
        (
            rec["feature"],
            rec["base_var"],
            rec["lag_info"],
            py(rec["pearson_corr_with_next_week_forecast"]),
            py(rec["abs_corr"]),
        )
        for _, rec in products["climate_influence_df"].iterrows()
    ]
    send(
        "dengue_climate_influence",
        ["feature","base_var","lag_info","pearson_corr_with_next_week_forecast","abs_corr"],
        clim_rows,
        keyed=False,
    )

    # EXEC SUMMARY (single JSON row; table is rewritten)
    summary_json = json.dumps(to_python(products["executive_summary"]))
    send("dengue_exec_summary", ["summary"], [(summary_json,)], keyed=False)


# ----------------- main -----------------

def main():
    """
    Run the end-to-end dengue surge-planning pipeline.

    Steps:
      1. Load weekly data and engineer features.
      2. Split each district into train/val/test.
      3. Fit encoders and scalers on train.
      4. Train the GAN, early-stopping on validation.
      5. Build the director tables on the test split, save them under OUT
         and upsert them into Postgres.

    If the validation or test split yields no sequences, the pipeline falls
    back to the previous split and prints a warning.
    """
    print("1) Load weekly data from Postgres")
    df = load_weekly_from_db()

    # engineered features BEFORE lag/roll
    df = add_time_features(df)
    df = add_incidence_derivs(df)           # surge velocity features
    df = add_static_district_feats(df)      # district-level capacity proxy

    base_series = [
        "weekly_hospitalised_cases",
        "Total_Rainfall",
        "Avg_Humidity",
        "Avg_Temperature",
        "case_growth_1w",
        "case_growth_2w",
        "acceleration",
        "district_capacity_proxy"
    ]
    if "Avg_NDVI" in df.columns:
        base_series.append("Avg_NDVI")
    if "Avg_NDWI" in df.columns:
        base_series.append("Avg_NDWI")

    df = add_lags_rolls(
        df,
        base_cols=base_series,
        lags=(1, 3, 6, 12),
        rolls=(3, 6, 12),
    )

    feats = [
        "Total_Rainfall",
        "Avg_Humidity",
        "Avg_Temperature",
        "case_growth_1w",
        "case_growth_2w",
        "acceleration",
        "district_capacity_proxy",
        "Year_num",
        "week_sin",
        "week_cos",
        "month_sin",
        "month_cos",
    ]
    for optional in ["Avg_NDVI", "Avg_NDWI"]:
        if optional in df.columns:
            feats.append(optional)

    feats += [c for c in df.columns if any(tag in c for tag in ["_lag", "_rmean", "_rstd"])]

    target = "weekly_hospitalised_cases"

    print("2) Prepare train/val/test splits")
    # Per district, and in time order: the last TEST_H_WEEKS rows go to test
    # and the VAL_H_WEEKS rows before them go to validation. Short series
    # drop validation first, then test.
    tr_parts, va_parts, te_parts = [], [], []
    for name, group in df.groupby("District"):
        group = group.sort_values(["Year", "Epi_Week"])
        if len(group) > (VAL_H_WEEKS + TEST_H_WEEKS):
            te_parts.append(group.iloc[-TEST_H_WEEKS:])
            va_parts.append(group.iloc[-(VAL_H_WEEKS + TEST_H_WEEKS) : -TEST_H_WEEKS])
            tr_parts.append(group.iloc[: -(VAL_H_WEEKS + TEST_H_WEEKS)])
        elif len(group) > TEST_H_WEEKS:
            te_parts.append(group.iloc[-TEST_H_WEEKS:])
            tr_parts.append(group.iloc[:-TEST_H_WEEKS])
        else:
            tr_parts.append(group)

    tr_df = pd.concat(tr_parts).reset_index(drop=True)
    va_df = pd.concat(va_parts).reset_index(drop=True) if va_parts else tr_df.iloc[0:0].copy()
    te_df = pd.concat(te_parts).reset_index(drop=True) if te_parts else tr_df.iloc[0:0].copy()

    df_full_raw = df.copy()

    print("3) Fit encoders/scalers on TRAIN")
    split_tr, le_tr, fsc_tr, ysc_tr = build_mats(tr_df, feats, target)

    X_tr, Y_tr, C_tr = to_seq(split_tr, SEQ, prev_y=True)
    if X_tr.shape[0] == 0:
        raise RuntimeError("No train sequences. Reduce SEQ or inspect data continuity.")
    print(f"Train sequences: {X_tr.shape}")

    X_va = Y_va = C_va = None
    if len(va_df) > 0:
        split_va = apply_mats(va_df, feats, target, le_tr, fsc_tr, ysc_tr)
        X_va, Y_va, C_va = to_seq(split_va, SEQ, prev_y=True)
    if X_va is None or X_va.shape[0] == 0:
        print("WARNING: no validation sequences; early stopping will use TRAIN sequences.")
        X_va, Y_va, C_va = X_tr, Y_tr, C_tr

    X_te = Y_te = C_te = None
    if len(te_df) > 0:
        split_te = apply_mats(te_df, feats, target, le_tr, fsc_tr, ysc_tr)
        X_te, Y_te, C_te = to_seq(split_te, SEQ, prev_y=True)
    if X_te is None or X_te.shape[0] == 0:
        print("WARNING: no test sequences; director tables will use the validation "
              "sequences (possibly the TRAIN set), so they are not held-out.")
        X_te, Y_te, C_te = X_va, Y_va, C_va

    nc = int(C_tr.max()) + 1
    cond_dim = X_tr.shape[2]

    # Indices of the seasonal encodings within the model input channels
    # (feats followed by the appended prev_y channel).
    pc_idx = [
        i for i, feature in enumerate(feats + ["prev_y"])
        if feature in ("week_sin", "week_cos", "month_sin", "month_cos")
    ]

    print("4) Train GAN (train) + Early stop (val)")
    G, D = train_gan(X_tr, Y_tr, C_tr, X_va, Y_va, C_va, cond_dim, nc, pc_idx)

    print("5) Build director tables on HELD-OUT TEST")
    overall_metrics, products = evaluate_and_summarize(
        G, X_te, Y_te, C_te, ysc_tr, le_tr, df_full_raw, feats
    )

    # save locally (make sure the output directory exists first)
    os.makedirs(OUT, exist_ok=True)
    products["watchlist_df"].to_csv(f"{OUT}/watchlist.csv", index=False)
    products["overflow_df"].to_csv(f"{OUT}/overflow_risk.csv", index=False)
    products["accel_df"].to_csv(f"{OUT}/acceleration_alerts.csv", index=False)
    products["confidence_df"].to_csv(f"{OUT}/forecast_confidence.csv", index=False)
    products["quality_df"].to_csv(f"{OUT}/surveillance_quality.csv", index=False)
    products["peak_df"].to_csv(f"{OUT}/peak_projection.csv", index=False)
    products["isochrone_df"].to_csv(f"{OUT}/isochrone_spread.csv", index=False)
    products["nowcast_gap_df"].to_csv(f"{OUT}/nowcast_gap.csv", index=False)
    products["climate_influence_df"].to_csv(f"{OUT}/climate_lag_influence.csv", index=False)

    dash_manifest = {
        "executive_summary": products["executive_summary"],
        "overall_metrics": products["overall_metrics"],
        "notes": "Auto-generated decision tables for dengue surge planning, including climate lag screen, spatial isochrone prep, and nowcast gap."
    }
    dash_manifest_clean = to_python(dash_manifest)
    with open(f"{OUT}/DASHBOARD_SUMMARY.json", "w", encoding="utf-8") as f:
        json.dump(dash_manifest_clean, f, indent=2)

    print("6) Create / upsert output tables in Postgres")
    conn = get_db_conn()
    try:
        ensure_output_tables(conn)
        push_products_to_db(conn, products)
    finally:
        # Always release the connection, even if table creation/upsert fails.
        conn.close()

    print("\n=== Executive Summary Preview ===")
    print(json.dumps(to_python(dash_manifest["executive_summary"]), indent=2))
    print("\nArtifacts saved to", OUT, "and pushed to Postgres.\nDone.")


if __name__ == "__main__":
    main()