# In [None]:
import os, math, json, copy, random, warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import mean_squared_error, r2_score

import psycopg2
from psycopg2.extras import execute_values

# ============================================================
# CONFIG
# ============================================================

# Connection settings for the Postgres instance. Every value can be
# overridden with an environment variable of the same name, so credentials
# do not have to be stored in this file; the literals are fallbacks kept
# only for backward compatibility.
# NOTE(review): a real password is committed below -- rotate it, supply
# PG_PASS through the environment, and then remove the fallback literal.
PG_HOST    = os.environ.get("PG_HOST", "119.148.17.102")
PG_PORT    = int(os.environ.get("PG_PORT", "5432"))
PG_DB      = os.environ.get("PG_DB", "ewarsdb")
PG_USER    = os.environ.get("PG_USER", "ewars")
PG_PASS    = os.environ.get("PG_PASS", "Iedcr@Ewars2025")
PG_SSLMODE = os.environ.get("PG_SSLMODE", "require")

# Root directory for run artefacts; created eagerly (no error if it exists).
OUT_ROOT = "/content/malaria_out"
os.makedirs(OUT_ROOT, exist_ok=True)

SEED=42                                                  # shared seed for random / numpy / torch below
DEVICE="cuda" if torch.cuda.is_available() else "cpu"    # device used for model and tensors

# model / training
SEQ=12                  # lookback window length (months); default L in to_seq
BATCH=64                # DataLoader batch size in train_model
EPOCHS=350              # maximum number of training epochs
PATIENCE=25             # epochs without validation improvement before early stop
LSTM_UNITS=128          # default LSTM hidden size (also attention embed dim)
HEADS=8                 # default number of attention heads
DROP=0.25               # default dropout probability
LR=8e-4                 # AdamW learning rate
WD=1e-4                 # AdamW weight decay
CLIP=1.0                # max gradient norm passed to clip_grad_norm_
TF_START=1.0            # teacher-forcing weight at the first epoch
TF_END=0.3              # teacher-forcing weight at the last epoch (linear schedule)
Q=(0.1,0.5,0.9)         # quantile levels; one output head per level

# split lengths per upazila
# NOTE(review): the consumers of these two values are not visible in this
# section; presumably the trailing months held out for validation / test.
VAL_H_MONTHS  = 6
TEST_H_MONTHS = 6

# classification thresholds / dashboard logic (reuse diarrhoea style)
DEFAULT_CAPACITY_PER_DISTRICT = 25   # capacity assigned to every (district, upazila) in the overflow table
RAPID_GROWTH_THRESH = 0.30     # growth >= 0.30 (30%) => red flag in classify_growth
MODERATE_GROWTH_THRESH = 0.10  # growth >= 0.10 (10%) => orange flag in classify_growth

# Number of trailing sequence steps scanned for the peak projection. Despite
# the name, it is applied as a slice length over the model's sequence steps.
PEAK_LOOKAHEAD_DAYS = 15
CONF_HIGH = 5                  # p90-p10 width <= this => high confidence
CONF_MED  = 15                 # p90-p10 width <= this => medium confidence, else low

# Seed every RNG the pipeline touches so runs are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark=False
    torch.backends.cudnn.deterministic=True

# which climate/env features we keep (these are also the columns that
# add_lags_rolls_monthly expands into lag / rolling features by default)
BASE_FEATURES = [
    "average_temperature",
    "total_rainfall",
    "relative_humidity",
    "average_ndvi",
    "average_ndwi"
]

# ============================================================
# DB HELPERS
# ============================================================

def get_db_conn():
    """
    Open a new psycopg2 connection using the module-level PG_* settings.

    The settings are passed as keyword arguments instead of being
    interpolated into a "key=value" DSN string, so values that contain
    spaces, quotes or backslashes (passwords in particular) need no manual
    escaping.

    Returns:
        A new psycopg2 connection. The caller is responsible for closing it.
    """
    return psycopg2.connect(
        host=PG_HOST,
        port=PG_PORT,
        dbname=PG_DB,
        user=PG_USER,
        password=PG_PASS,
        sslmode=PG_SSLMODE,
    )

def _dedupe_by_pk(rows, cols, pk_cols):
    if not pk_cols:
        return rows
    pk_idx = [cols.index(pk) for pk in pk_cols]
    dedup = {}
    for row in rows:
        key = tuple(row[i] for i in pk_idx)
        dedup[key] = row
    return list(dedup.values())

def upsert_table(conn, table_name, cols, rows, pk_cols=None, wipe_first=False):
    """
    Insert (or upsert) `rows` into `table_name`, optionally truncating first.

    Args:
        conn: open psycopg2 connection.
        table_name: target table. Interpolated into the SQL text as-is, so
            it must come from trusted code, never from user input.
        cols: column names matching the layout of each row (same caveat).
        rows: sequence of row tuples; may be empty.
        pk_cols: key columns. When given, rows are de-duplicated on these
            columns (last one wins) and written with ON CONFLICT handling;
            when omitted, rows are plainly inserted.
        wipe_first: TRUNCATE the table before writing.

    The truncate and the insert run in a single transaction, so a failed
    insert no longer leaves the table empty. On any error the transaction is
    rolled back and the exception is re-raised.
    """
    # Nothing to truncate and nothing to write: leave the connection alone.
    if not wipe_first and not rows:
        return

    col_list = ", ".join(cols)
    try:
        with conn.cursor() as cur:
            if wipe_first:
                cur.execute(f"TRUNCATE TABLE {table_name};")

            if rows and not pk_cols:
                insert_sql = f"INSERT INTO {table_name} ({col_list}) VALUES %s"
                execute_values(cur, insert_sql, rows)
            elif rows:
                rows_dedup = _dedupe_by_pk(rows, cols, pk_cols)
                conflict_target = ", ".join(pk_cols)
                set_updates = ", ".join(
                    f"{c}=EXCLUDED.{c}" for c in cols if c not in pk_cols
                )
                # If every column is part of the key there is nothing to
                # update; "DO UPDATE SET" with an empty list is invalid SQL.
                if set_updates:
                    conflict_action = f"DO UPDATE SET {set_updates}"
                else:
                    conflict_action = "DO NOTHING"
                insert_sql = (
                    f"INSERT INTO {table_name} ({col_list}) VALUES %s "
                    f"ON CONFLICT ({conflict_target}) {conflict_action};"
                )
                execute_values(cur, insert_sql, rows_dedup)
        conn.commit()
    except Exception:
        conn.rollback()
        raise

def to_python(v):
    """
    Recursively convert NumPy scalars and containers to plain Python values.

    - NumPy floats become `float`, NumPy ints become `int`, NumPy bools
      become `bool`.
    - NaN (NumPy or built-in float) becomes None.
    - dicts are converted value by value; lists and tuples both come back
      as lists.
    - Anything else is returned unchanged.
    """
    if isinstance(v, np.floating):
        v = float(v)
    elif isinstance(v, np.integer):
        return int(v)
    elif isinstance(v, np.bool_):
        return bool(v)

    if isinstance(v, float):
        return None if math.isnan(v) else v
    if isinstance(v, dict):
        return {key: to_python(item) for key, item in v.items()}
    if isinstance(v, (list, tuple)):
        return [to_python(item) for item in v]
    return v

def py(v):
    """
    Convert a single NumPy scalar to its plain Python equivalent.

    NaN (after conversion) is mapped to None; other values, including
    containers, are returned as they are. Unlike `to_python` this does not
    recurse.
    """
    conversions = ((np.floating, float), (np.integer, int), (np.bool_, bool))
    for numpy_kind, caster in conversions:
        if isinstance(v, numpy_kind):
            v = caster(v)
            break
    is_nan = isinstance(v, float) and math.isnan(v)
    return None if is_nan else v

# ============================================================
# CREATE OUTPUT TABLES (PER TARGET PREFIX)
# ============================================================

# Product tables keyed by (district, upazila, year, month), in creation
# order. Each entry lists only the table-specific columns; the key columns,
# created_at and the primary key are added by _keyed_table_ddl.
_MALARIA_KEYED_TABLES = (
    ("watchlist", (
        ("expected_cases_next_week", "NUMERIC"),
        ("high_scenario_p90", "NUMERIC"),
        ("status", "TEXT"),
    )),
    ("overflow_risk", (
        ("capacity_threshold_beds_per_week", "INT"),
        ("forecast_median_next_week", "NUMERIC"),
        ("high_scenario_p90", "NUMERIC"),
        ("breach_risk_flag", "TEXT"),
    )),
    ("acceleration_alerts", (
        ("last_week_cases", "NUMERIC"),
        ("this_week_actual", "NUMERIC"),
        ("this_week_predicted", "NUMERIC"),
        ("next_week_forecast", "NUMERIC"),
        ("growth_rate_wow", "NUMERIC"),
        ("growth_flag", "TEXT"),
    )),
    ("confidence", (
        ("forecast_next_week", "NUMERIC"),
        ("uncertainty_width", "NUMERIC"),
        ("confidence_flag", "TEXT"),
    )),
    ("surveillance_quality", (
        ("reporting_continuity_pct", "NUMERIC"),
        ("data_quality_flag", "TEXT"),
    )),
    ("peak_projection", (
        ("peak_lead_time_weeks", "NUMERIC"),
        ("peak_cases_median", "NUMERIC"),
        ("peak_cases_high_p90", "NUMERIC"),
        ("peak_when", "TEXT"),
    )),
    ("isochrone_spread", (
        ("forecast_next_week_median", "NUMERIC"),
        ("forecast_next_week_hi", "NUMERIC"),
        ("growth_flag", "TEXT"),
    )),
    ("nowcast_gap", (
        ("current_week_actual", "NUMERIC"),
        ("current_week_predicted_from_prev", "NUMERIC"),
        ("nowcast_gap_percent", "NUMERIC"),
    )),
)


def _keyed_table_ddl(table_name, extra_cols):
    """Build CREATE TABLE SQL for one (district, upazila, year, month) table."""
    col_defs = [
        "district TEXT NOT NULL",
        "upazila  TEXT NOT NULL",
        "year INT NOT NULL",
        "month INT NOT NULL",
    ]
    col_defs += [f"{col} {sql_type}" for col, sql_type in extra_cols]
    col_defs += [
        "created_at TIMESTAMP DEFAULT now()",
        "PRIMARY KEY (district, upazila, year, month)",
    ]
    return f"CREATE TABLE {table_name} (\n    " + ",\n    ".join(col_defs) + "\n);"


def ensure_output_tables_malaria(conn, prefix):
    """
    Drop and recreate every output table for one target prefix.

    WARNING: this is destructive. All ten `{prefix}_*` tables are dropped
    (with CASCADE) and recreated empty on every call.

    Args:
        conn: open psycopg2 connection.
        prefix: table-name prefix. Interpolated into the SQL text as-is, so
            it must be a trusted identifier.

    All statements run in one transaction: it is committed on success and
    rolled back (then re-raised) on any failure, and the cursor is always
    closed.
    """
    keyed_ddl = [
        (suffix, _keyed_table_ddl(f"{prefix}_{suffix}", extra_cols))
        for suffix, extra_cols in _MALARIA_KEYED_TABLES
    ]

    # The two tables below do not follow the keyed layout.
    climate_ddl = f"""
    CREATE TABLE {prefix}_climate_influence (
        feature TEXT,
        base_var TEXT,
        lag_info TEXT,
        pearson_corr_with_next_week_forecast NUMERIC,
        abs_corr NUMERIC,
        created_at TIMESTAMP DEFAULT now()
    );
    """
    summary_ddl = f"""
    CREATE TABLE {prefix}_exec_summary (
        id SERIAL PRIMARY KEY,
        summary JSONB,
        created_at TIMESTAMP DEFAULT now()
    );
    """
    create_plan = keyed_ddl + [
        ("climate_influence", climate_ddl),
        ("exec_summary", summary_ddl),
    ]

    try:
        with conn.cursor() as cur:
            # Drop in the reverse of creation order, as the original did.
            for suffix, _ in reversed(create_plan):
                cur.execute(f"DROP TABLE IF EXISTS {prefix}_{suffix} CASCADE;")
            for _, ddl in create_plan:
                cur.execute(ddl)
        conn.commit()
    except Exception:
        conn.rollback()
        raise
  
# ============================================================
# DATA LOAD + FEATURE ENGINEERING
# ============================================================

def load_malaria_monthly_from_db():
    """
    Load monthly malaria + climate rows from the `malaria_weather` table and
    return a cleaned DataFrame.

    Columns read: dis_name, upa_name, upazilaid, year, month, pv, pf,
    population, average_temperature, total_rainfall, relative_humidity,
    average_ndvi, average_ndwi.

    Processing steps:
      1. Coerce ids, dates and measurements to numeric (bad values -> NaN).
      2. Drop rows whose year or month could not be parsed; the SQL already
         excludes NULLs, and such rows cannot be placed on the time axis.
      3. Add `pv_rate_1k` / `pf_rate_1k` (cases per 1000 population; NaN
         when population is missing or not positive).
      4. Add `ym` = year*12 + month as a chronological key.
      5. Replace +/-inf with NaN, then forward-fill climate, population and
         rate columns within each upazila in chronological order.
      6. Add `month_sin` / `month_cos` seasonality features.

    Returns:
        pandas.DataFrame sorted by (upazilaid, year, month).
    """
    conn = get_db_conn()
    try:
        sql = """
            SELECT
                dis_name,
                upa_name,
                upazilaid,
                year,
                month,
                pv,
                pf,
                population,
                average_temperature,
                total_rainfall,
                relative_humidity,
                average_ndvi,
                average_ndwi
            FROM malaria_weather
            WHERE year IS NOT NULL
              AND month IS NOT NULL
            ORDER BY upazilaid, year, month;
        """
        df = pd.read_sql(sql, conn)
    finally:
        conn.close()

    # normalize
    df.columns = [c.strip().lower() for c in df.columns]

    df["upazilaid"] = pd.to_numeric(df["upazilaid"], errors="coerce")
    df["year"]  = pd.to_numeric(df["year"], errors="coerce").astype("Int64")
    df["month"] = pd.to_numeric(df["month"], errors="coerce").astype("Int64")

    # Values that failed coercion would make the astype(int) below raise.
    df = df.dropna(subset=["year", "month"]).copy()

    for c in ["pv","pf","population",
              "average_temperature","total_rainfall","relative_humidity",
              "average_ndvi","average_ndwi"]:
        df[c] = pd.to_numeric(df[c], errors="coerce")

    # rate per 1000 pop
    df["pv_rate_1k"] = np.where(
        (df["population"]>0) & df["pv"].notna(),
        1000.0 * df["pv"] / df["population"],
        np.nan
    )
    df["pf_rate_1k"] = np.where(
        (df["population"]>0) & df["pf"].notna(),
        1000.0 * df["pf"] / df["population"],
        np.nan
    )

    # chronological key
    df["ym"] = df["year"].astype(int)*12 + df["month"].astype(int)

    # forward fill climate + population + targets per upazila
    fcols = BASE_FEATURES + ["population","pv_rate_1k","pf_rate_1k"]
    for col in fcols:
        if col not in df.columns:
            df[col] = np.nan
    df = df.sort_values(["upazilaid","year","month"])

    # Turning inf into NaN *before* filling gives the same result as the
    # earlier fill / replace / fill-again sequence in a single pass.
    # GroupBy.ffill keeps the original index, so no index juggling (which
    # behaved differently across pandas versions) is needed.
    df = df.replace([np.inf,-np.inf], np.nan)
    df[fcols] = df.groupby("upazilaid")[fcols].ffill()

    # seasonality embeddings
    month_angle = 2*np.pi*df["month"].astype(float)/12.0
    df["month_sin"] = np.sin(month_angle).astype(np.float32)
    df["month_cos"] = np.cos(month_angle).astype(np.float32)

    return df

def add_lags_rolls_monthly(df, group_col="upazilaid",
                           base_cols=BASE_FEATURES,
                           lags=(1,3,6),
                           rolls=(3,6)):
    """
    Add per-group lag and rolling-window features for each base column.

    For every column `c` in `base_cols` this creates, within each
    `group_col` group ordered by (year, month):
      - `c_lag{L}` for each L in `lags` (value L rows earlier),
      - `c_rmean{R}` and `c_rstd{R}` for each R in `rolls` (rolling mean /
        std over up to R rows, including the current one).

    Missing values in the new columns (start of each series) are set to 0.0.

    Returns:
        (sorted copy of df with the new columns, list of new column names)
    """
    out = df.sort_values([group_col, "year", "month"]).copy()
    derived = []

    for base in base_cols:
        for lag in lags:
            lag_name = f"{base}_lag{lag}"
            out[lag_name] = out.groupby(group_col)[base].shift(lag)
            derived.append(lag_name)

        for window in rolls:
            mean_name = f"{base}_rmean{window}"
            std_name = f"{base}_rstd{window}"
            per_group = out.groupby(group_col)[base]
            # rolling() on a groupby prepends the group key to the index;
            # dropping that level re-aligns the result with `out`.
            rolling_mean = per_group.rolling(window, min_periods=1).mean()
            out[mean_name] = rolling_mean.reset_index(level=0, drop=True)
            rolling_std = per_group.rolling(window, min_periods=1).std()
            out[std_name] = rolling_std.reset_index(level=0, drop=True).fillna(0.0)
            derived.extend([mean_name, std_name])

    out[derived] = out[derived].fillna(0.0)
    return out, derived

# ============================================================
# SEQUENCE BUILDING / SCALING
# ============================================================

class Split:
    """
    Plain container for one data split.

    Attributes are stored exactly as passed:
      X  -- feature matrix
      y  -- target array
      c  -- entity id per row
      df -- the DataFrame the arrays were built from
    """

    def __init__(self, X, y, c, df):
        self.X = X
        self.y = y
        self.c = c
        self.df = df

def build_mats(df, feat_cols, target_col):
    """
    Fit encoders/scalers on `df` and return the scaled matrices.

    df must have:
      - upazilaid
      - year, month
      - target_col (pv_rate_1k or pf_rate_1k)
    We encode upazilaid -> entity_id.

    Features are standardised with one global StandardScaler. The target is
    clipped at 0, log1p-transformed, then standardised separately for each
    entity.

    Rows are addressed by position, not by index label, so `df` may carry
    any index (for example after filtering or sorting).

    Returns:
        (Split, LabelEncoder, feature StandardScaler,
         dict mapping int entity_id -> fitted target StandardScaler)
    """
    le = LabelEncoder().fit(df["upazilaid"].astype(str).values)

    d = df.copy()
    d["entity_id"] = le.transform(d["upazilaid"].astype(str))
    entity_ids = d["entity_id"].values.astype(np.int64)

    X = d[feat_cols].values.astype(np.float32)
    feat_scaler = StandardScaler().fit(X)
    Xs = feat_scaler.transform(X)

    # log1p target (non-negative)
    y_raw = np.log1p(np.clip(d[target_col].values.reshape(-1,1), 0, None)).astype(np.float32)
    Ys = np.zeros_like(y_raw, np.float32)
    y_scalers = {}
    for cid in np.unique(entity_ids):
        # Positional row numbers of this entity; index labels must not be
        # used here because y_raw / Ys are plain NumPy arrays.
        pos = np.flatnonzero(entity_ids == cid)
        sc  = StandardScaler().fit(y_raw[pos])
        y_scalers[int(cid)] = sc
        Ys[pos] = sc.transform(y_raw[pos])

    return Split(Xs, Ys, entity_ids, d), le, feat_scaler, y_scalers

def apply_mats(df, feat_cols, target_col, le, feat_scaler, y_scalers):
    """
    Transform `df` with encoders/scalers that were fitted by `build_mats`.

    Args:
        df: frame with upazilaid, year, month, `feat_cols`, `target_col`.
        le: fitted LabelEncoder for upazilaid (as strings).
        feat_scaler: fitted feature StandardScaler.
        y_scalers: dict entity_id -> fitted target StandardScaler.

    An entity without an entry in `y_scalers` gets a scaler fitted on this
    split's own target values. That fallback is fitted only when it is
    actually needed.

    Rows are addressed by position, not by index label, so `df` may carry
    any index.

    Returns:
        Split with scaled features, scaled targets, entity ids and a copy
        of `df` that includes `entity_id`.
    """
    d = df.copy()
    d["entity_id"] = le.transform(d["upazilaid"].astype(str))
    entity_ids = d["entity_id"].values.astype(np.int64)

    X = d[feat_cols].values.astype(np.float32)
    Xs = feat_scaler.transform(X)

    y_raw = np.log1p(np.clip(d[target_col].values.reshape(-1,1), 0, None)).astype(np.float32)
    Ys = np.zeros_like(y_raw, np.float32)
    for cid in np.unique(entity_ids):
        pos = np.flatnonzero(entity_ids == cid)
        sc = y_scalers.get(int(cid))
        if sc is None:
            sc = StandardScaler().fit(y_raw[pos])
        Ys[pos] = sc.transform(y_raw[pos])

    return Split(Xs, Ys, entity_ids, d)

def to_seq(split, L=SEQ, prev_y=True):
    """
    Build rolling L-month windows for each upazila.

    A window is kept only if its months are strictly consecutive
    (ym = year*12 + month increases by exactly 1 at every step).

    Args:
        split: Split with X (rows x features), y (rows x 1), c (entity id
            per row) and df (must provide `year` and `month`).
        L: window length in rows.
        prev_y: when True, append one extra feature channel holding the
            previous step's target (0 for the first step of each window).

    Returns:
        (SX, SY, SC) as float32 / float32 / int64 arrays with shapes
        [N, L, F(+1)], [N, L, ...] and [N]. When no window qualifies the
        arrays are empty but keep those trailing dimensions, so downstream
        shape unpacking still works.
    """
    X,y,c,df = split.X, split.y, split.c, split.df
    ym = (df["year"].astype(int)*12 + df["month"].astype(int)).values

    SX,SY,SC = [],[],[]
    for cid in np.unique(c):
        idx = np.where(c==cid)[0]
        # sort within this upazila by ym
        order = np.argsort(ym[idx])
        idx   = idx[order]
        o     = ym[idx]
        for i in range(len(idx)-L+1):
            sl = idx[i:i+L]
            if np.all(np.diff(o[i:i+L]) == 1):
                Xi = X[sl]
                Yi = y[sl]
                if prev_y:
                    prev_col = np.vstack([np.zeros((1,1),np.float32), Yi[:-1]])
                    Xi = np.concatenate([Xi, prev_col], axis=1)
                SX.append(Xi)
                SY.append(Yi)
                SC.append(cid)

    if not SX:
        # np.asarray([]) would collapse to shape (0,); keep the full rank.
        n_feat = X.shape[1] + (1 if prev_y else 0)
        return (
            np.empty((0, L, n_feat), np.float32),
            np.empty((0, L) + tuple(y.shape[1:]), np.float32),
            np.empty((0,), np.int64),
        )

    return (
        np.asarray(SX, np.float32),
        np.asarray(SY, np.float32),
        np.asarray(SC, np.int64),
    )

# ============================================================
# MODEL (LSTM + ATTENTION)
# ============================================================

class Forecaster(nn.Module):
    """
    Sequence model: entity embedding + LSTM + causally-masked self-attention.

    For every time step it emits a mean, a (clamped) log-sigma and one
    prediction per quantile level in `qu`.

    Args:
        cond_dim: number of input features per time step.
        n_entities: size of the entity embedding table.
        emb_ent: entity embedding width.
        lstm: LSTM hidden size (also the attention embedding size).
        heads: number of attention heads.
        drop: dropout probability (attention weights and output).
        qu: quantile levels; one linear head is created per level.
    """

    def __init__(self, cond_dim, n_entities,
                 emb_ent=8, lstm=LSTM_UNITS, heads=HEADS, drop=DROP, qu=Q):
        super().__init__()
        self.q = qu
        self.ent_emb = nn.Embedding(n_entities, emb_ent)

        lstm_input_dim = cond_dim + emb_ent
        self.lstm = nn.LSTM(lstm_input_dim, lstm, 1, batch_first=True)
        self.mha  = nn.MultiheadAttention(lstm, heads, batch_first=True, dropout=drop)
        self.ln   = nn.LayerNorm(lstm)
        self.drop = nn.Dropout(drop)

        # Output heads: mean, log-sigma, and one head per quantile level.
        self.mu = nn.Linear(lstm,1)
        self.ls = nn.Linear(lstm,1)
        self.qh = nn.ModuleList([nn.Linear(lstm,1) for _ in qu])

        # Attention weights of the most recent forward pass (detached).
        self.last_attn = None

    def forward(self, cond, ent_id):
        """
        Args:
            cond: float tensor [B, L, cond_dim].
            ent_id: long tensor [B] of entity ids.

        Returns:
            (mu [B, L, 1], log-sigma [B, L, 1] clamped to [-5, 3],
             list of per-quantile tensors, each [B, L, 1])
        """
        _, seq_len, _ = cond.shape

        # Broadcast each sequence's entity embedding over its time steps.
        ent_vec = self.ent_emb(ent_id).unsqueeze(1).repeat(1, seq_len, 1)
        lstm_in = torch.cat([cond, ent_vec], dim=-1)
        hidden, _ = self.lstm(lstm_in)

        # True above the diagonal = "may not attend", so a step never
        # sees later steps.
        future_mask = torch.triu(
            torch.ones(seq_len, seq_len, device=hidden.device, dtype=torch.bool),
            diagonal=1,
        )
        attended, weights = self.mha(
            hidden, hidden, hidden,
            attn_mask=future_mask, need_weights=True,
        )
        self.last_attn = weights.detach()

        feats = self.drop(self.ln(attended))

        mean_out = self.mu(feats)
        log_sigma = torch.clamp(self.ls(feats), -5.0, 3.0)
        quantile_outs = [head(feats) for head in self.qh]

        return mean_out, log_sigma, quantile_outs

# ============================================================
# TRAIN / VAL
# ============================================================

def smape(y, p):
    """
    Symmetric mean absolute percentage error, in percent.

    Both arrays are flattened first. A small epsilon in the denominator
    avoids division by zero when a target and its prediction are both 0.
    """
    actual = y.flatten()
    predicted = p.flatten()
    numerator = 2*np.abs(predicted-actual)
    denominator = np.abs(actual)+np.abs(predicted)+1e-8
    return 100*np.mean(numerator/denominator)

def pinball(pred, target, quantile):
    """
    Mean pinball (quantile) loss of `pred` against `target`.

    Under-prediction is weighted by `quantile`, over-prediction by
    `1 - quantile`.
    """
    residual = target - pred
    under_penalty = quantile * residual
    over_penalty = (quantile - 1) * residual
    return torch.maximum(under_penalty, over_penalty).mean()

def tv_l1_on_mu(mu):
    """
    Regularisation terms for a [B, L, C] mean prediction.

    Returns:
        (mean absolute difference between consecutive steps along dim 1,
         mean absolute value of `mu`)
    """
    step_change = mu[:, 1:, :] - mu[:, :-1, :]
    total_variation = step_change.abs().mean()
    magnitude = mu.abs().mean()
    return total_variation, magnitude

class SeqDS(Dataset):
    """
    Dataset over pre-built sequence arrays.

    Item `i` is the tuple (X[i], Y[i], C[i]); the length is X.shape[0].
    """

    def __init__(self, X, Y, C):
        self.X = X
        self.Y = Y
        self.C = C

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, i):
        sample = (self.X[i], self.Y[i], self.C[i])
        return sample

def validate_model(M,Xv,Yv,Cv):
    """
    Score model `M` on a validation set of sequences.

    Args:
        M: model returning (mu, log-sigma, quantile list) for (X, C).
        Xv: feature sequences, array [N, L, F].
        Yv: target sequences, array [N, L, 1] (scaled space).
        Cv: entity id per sequence, array [N].

    Returns:
        (SMAPE of the mean prediction, fraction of targets lying between
         the first and third quantile heads) -- both computed on the scaled
         values. The model is left in eval mode.
    """
    M.eval()
    with torch.no_grad():
        features = torch.tensor(Xv, dtype=torch.float32, device=DEVICE)
        entities = torch.tensor(Cv, dtype=torch.long,   device=DEVICE)
        mean_out, _, quantile_outs = M(features, entities)

    mean_pred = mean_out[...,0].cpu().numpy()
    # Quantile heads 0 and 2 are the lower / upper band (0.1 and 0.9 for
    # the module-level Q).
    lower = quantile_outs[0][...,0].cpu().numpy()
    upper = quantile_outs[2][...,0].cpu().numpy()
    truth = Yv[...,0]

    inside_band = (truth>=lower) & (truth<=upper)
    coverage = float(np.mean(inside_band))
    score    = smape(truth.reshape(-1), mean_pred.reshape(-1))

    return score, coverage

def train_model(Xtr,Ytr,Ctr,
                Xva,Yva,Cva,
                cond_dim,nc):
    """
    Train a Forecaster with early stopping and return the best model.

    Args:
        Xtr, Ytr, Ctr: training sequences, targets and entity ids.
        Xva, Yva, Cva: validation sequences, targets and entity ids.
        cond_dim: number of input features per step (including the
            previous-target channel, which must be the last feature).
        nc: number of entities for the embedding table.

    Training details:
      - Loss = Gaussian NLL + mean pinball loss over Q
               + 0.05 * total variation of mu + 0.02 * mean |mu|.
      - Scheduled teacher forcing: the previous-target channel is blended
        with the model's own lagged prediction; the weight on the true
        value moves linearly from TF_START to TF_END over EPOCHS.
      - Early stopping on `SMAPE + 10 * |coverage - 0.9|` over the
        validation set, with PATIENCE epochs of tolerance.

    Returns:
        The model, loaded with the weights of the best validation epoch.
    """
    M = Forecaster(cond_dim=cond_dim, n_entities=nc).to(DEVICE)
    opt = torch.optim.AdamW(M.parameters(),
                            lr=LR,
                            betas=(0.9,0.999),
                            weight_decay=WD)

    train_ds = SeqDS(Xtr,Ytr,Ctr)
    # Dropping the last partial batch is only safe when at least one full
    # batch remains; with fewer than BATCH samples it would drop everything
    # and the model would never be updated.
    dl = DataLoader(train_ds,
                    batch_size=BATCH,
                    shuffle=True,
                    drop_last=len(train_ds) > BATCH)

    best=float("inf")
    best_sd=None
    wait=0

    for e in range(EPOCHS):
        # scheduled sampling factor for prev_y channel
        t  = e / max(1,(EPOCHS-1))
        tf = TF_START + (TF_END-TF_START)*t

        M.train()
        for Xb,Yb,Cb in dl:
            Xb  = Xb.to(torch.float32).to(DEVICE)
            Yb  = Yb.to(torch.float32).to(DEVICE)
            Cb  = Cb.to(torch.long).to(DEVICE)

            B = Xb.shape[0]

            # scheduled teacher forcing for prev_y (last feature in X):
            # a gradient-free pass provides the model's own predictions,
            # shifted one step so position t sees the prediction for t-1.
            with torch.no_grad():
                mu0,_,_ = M(Xb,Cb)
                prev = torch.cat(
                    [torch.zeros(B,1,device=DEVICE), mu0[:,:-1,0]],
                    dim=1
                )
            Xb[:,:,-1] = tf*Xb[:,:,-1] + (1-tf)*prev

            opt.zero_grad(set_to_none=True)

            mu, ls, qs = M(Xb,Cb)

            sig = (ls.exp()).clamp(1e-3,50.0)
            nll = 0.5*(((Yb-mu)/sig)**2 + 2*ls + math.log(2*math.pi)).mean()

            ql = 0.0
            for i, qv in enumerate(Q):
                ql += pinball(qs[i],Yb,qv)
            ql /= len(Q)

            tv, l1 = tv_l1_on_mu(mu)
            loss = 1.0*nll + 1.0*ql + 0.05*tv + 0.02*l1

            loss.backward()
            torch.nn.utils.clip_grad_norm_(M.parameters(),CLIP)
            opt.step()

        sm, cov = validate_model(M,Xva,Yva,Cva)
        # Composite score: accuracy plus a penalty for interval coverage
        # drifting away from 0.9.
        comp = sm + 10*abs(cov-0.9)

        if comp < best-1e-6:
            best    = comp
            best_sd = copy.deepcopy(M.state_dict())
            wait    = 0
        else:
            wait += 1
            if wait >= PATIENCE:
                break

    if best_sd is not None:
        M.load_state_dict(best_sd)
    return M

# ============================================================
# DASHBOARD HELPERS
# ============================================================

def inverse_seq(arr_scaled, scaler):
    """
    Map scaled, log1p-space values back to the original non-negative scale.

    The values are un-standardised with `scaler.inverse_transform`, passed
    through expm1, and clipped at 0.

    Returns:
        1-D array with one value per input element.
    """
    as_column = np.array(arr_scaled).reshape(-1,1)
    log_space = scaler.inverse_transform(as_column).reshape(-1)
    restored = np.expm1(log_space)
    return np.clip(restored, 0, None)

def classify_growth(g):
    """
    Label a fractional growth rate (0.30 == 30%).

    Tiers are checked from highest to lowest; a value below every tier
    (or NaN) is reported as stable/decline.
    """
    tiers = (
        (RAPID_GROWTH_THRESH, "🔴 Rapid growth"),
        (MODERATE_GROWTH_THRESH, "🟠 Moderate growth"),
    )
    for floor, label in tiers:
        if g >= floor:
            return label
    return "🟢 Stable/decline"

def classify_risk(high_scenario):
    """
    Label the high-scenario (p90) case count.

    >= 40 is a likely surge, >= 20 is elevated, anything else (including
    NaN) is low.
    """
    if high_scenario >= 40:
        return "🔴 Surge likely"
    if high_scenario >= 20:
        return "🟠 Elevated"
    return "🟢 Low"

def classify_conf(width):
    """
    Label forecast confidence from the p90-p10 interval width.

    Narrower is better: width <= CONF_HIGH is high confidence, width <=
    CONF_MED is medium, anything else (including NaN) is low.
    """
    tiers = (
        (CONF_HIGH, "✅ High confidence"),
        (CONF_MED, "🟡 Medium confidence"),
    )
    for ceiling, label in tiers:
        if width <= ceiling:
            return label
    return "⚠ Low confidence"

def classify_data_quality(continuity_pct):
    """
    Label reporting continuity (a percentage).

    >= 90 is reliable, >= 70 needs watching, anything else (including NaN)
    needs a field check.
    """
    if continuity_pct >= 90:
        return "✅ Reliable"
    if continuity_pct >= 70:
        return "🟡 Watch"
    return "🔴 Needs field check"

def evaluate_and_build_products_malaria(
    M,
    Xseq, Yseq, Cseq,
    y_scalers, le,
    seq_lookup_df,
    feat_cols
):
    """
    Run the model on a set of sequences and build the dashboard products.

    Args:
        M: trained model returning (mu, log-sigma, quantile list).
        Xseq: feature sequences [N, L, F+1]; the last channel is the
            previous-target channel and is excluded from the feature
            correlation table.
        Yseq: scaled target sequences [N, L, 1].
        Cseq: entity id per sequence [N].
        y_scalers: dict entity_id -> target scaler used to undo scaling.
        le: label encoder (accepted for interface compatibility; unused).
        seq_lookup_df: mapping entity_id -> DataFrame with at least
            dis_name, upa_name, year, month for that entity.
        feat_cols: feature names, aligned with the first F channels of Xseq.

    Notes:
        - One output row is produced per sequence. Its district / upazila /
          year / month come from the *latest* row of that entity's lookup
          frame.
        - "Weekly" figures are the model's values multiplied by 7
          (placeholder scaling kept from the original pipeline).
        - A feature's correlation with the forecast is reported as 0.0 when
          it is undefined (constant feature, constant or non-finite
          forecasts) so that no NaN reaches the output table.

    Returns:
        (overall_metrics dict, products dict of DataFrames / summaries)
    """

    M.eval()
    X_t = torch.tensor(Xseq, dtype=torch.float32, device=DEVICE)
    C_t = torch.tensor(Cseq, dtype=torch.long,   device=DEVICE)

    with torch.no_grad():
        mu, ls, qs = M(X_t, C_t)
        mu_np  = mu.cpu().numpy()[...,0]      # [N,L]
        q10_np = qs[0].cpu().numpy()[...,0]
        q90_np = qs[2].cpu().numpy()[...,0]

    # Second-to-last step is treated as "current", last step as "next".
    idx_curr = -2
    idx_next = -1

    district_names = []
    upazila_names  = []
    year_list      = []
    month_list     = []

    last_week_cases_list = []
    this_week_actual_list = []
    this_week_pred_list   = []
    next_week_forecast_list = []
    next_week_hi_list       = []
    next_week_lo_list       = []
    growth_rate_list        = []
    conf_width_list         = []

    peak_day_list = []
    peak_val_list = []
    peak_hi_list  = []
    peak_leadtime_list = []

    all_y_real = []
    all_mu_real = []
    all_lo_real = []
    all_hi_real = []

    last_meta = {}  # key=(district,upazila)

    for i, cid in enumerate(Cseq):
        sc = y_scalers[int(cid)]
        upa_df = seq_lookup_df[int(cid)]
        upa_df = upa_df.sort_values(["year","month"])

        y_real_full   = inverse_seq(Yseq[i].reshape(-1),  sc)
        mu_real_full  = inverse_seq(mu_np[i].reshape(-1), sc)
        lo_real_full  = inverse_seq(q10_np[i].reshape(-1),sc)
        hi_real_full  = inverse_seq(q90_np[i].reshape(-1),sc)

        all_y_real.append(y_real_full)
        all_mu_real.append(mu_real_full)
        all_lo_real.append(lo_real_full)
        all_hi_real.append(hi_real_full)

        # Fall back to the last element when the sequence is too short for
        # the requested negative index.
        safe_idx_curr = idx_curr if idx_curr >= -len(y_real_full) else -1
        safe_idx_next = idx_next if idx_next >= -len(y_real_full) else -1

        actual_curr  = y_real_full[safe_idx_curr]
        pred_curr    = mu_real_full[safe_idx_curr]
        if (safe_idx_curr-1)>=-len(y_real_full):
            actual_prev = y_real_full[safe_idx_curr-1]
        else:
            actual_prev = y_real_full[safe_idx_curr]

        pred_next    = mu_real_full[safe_idx_next]
        hi_next      = hi_real_full[safe_idx_next]
        lo_next      = lo_real_full[safe_idx_next]

        growth      = (pred_next - actual_prev) / (actual_prev + 1e-6)

        # crude peak projection on last PEAK_LOOKAHEAD_DAYS steps
        look_slice = mu_real_full[-PEAK_LOOKAHEAD_DAYS:]
        hi_slice   = hi_real_full[-PEAK_LOOKAHEAD_DAYS:]
        if len(look_slice)==0:
            peak_when="NA"; peak_val=np.nan; peak_hi=np.nan; lead_weeks=np.nan
        else:
            local_max_idx = int(np.argmax(look_slice))
            peak_val = look_slice[local_max_idx]
            peak_hi  = hi_slice[local_max_idx]
            lead_weeks = ((len(look_slice)-1) - local_max_idx)/7.0
            peak_when = f"t+{(len(look_slice)-1 - local_max_idx)}d"

        # latest metadata for this upazila
        lastrow = upa_df.iloc[-1]
        dis_name  = str(lastrow["dis_name"])
        upa_name  = str(lastrow["upa_name"])
        yval      = int(lastrow["year"])
        mval      = int(lastrow["month"])

        # scale to "weekly-style" numbers for dashboards (placeholder)
        next_week_forecast = pred_next * 7.0
        next_week_hi       = hi_next   * 7.0
        next_week_lo       = lo_next   * 7.0
        this_week_pred     = pred_curr * 7.0
        last_week_cases    = actual_prev * 7.0
        this_week_actual   = actual_curr * 7.0

        district_names.append(dis_name)
        upazila_names.append(upa_name)
        year_list.append(yval)
        month_list.append(mval)

        last_week_cases_list.append(last_week_cases)
        this_week_actual_list.append(this_week_actual)
        this_week_pred_list.append(this_week_pred)
        next_week_forecast_list.append(next_week_forecast)
        next_week_hi_list.append(next_week_hi)
        next_week_lo_list.append(next_week_lo)
        growth_rate_list.append(growth)
        conf_width_list.append(next_week_hi - next_week_lo)

        peak_day_list.append(peak_when)
        peak_val_list.append(peak_val * 7.0)
        peak_hi_list.append(peak_hi * 7.0)
        peak_leadtime_list.append(lead_weeks)

        last_meta[(dis_name, upa_name)] = {
            "Year": yval,
            "Month": mval
        }

    all_y_real  = np.concatenate(all_y_real)
    all_mu_real = np.concatenate(all_mu_real)
    all_lo_real = np.concatenate(all_lo_real)
    all_hi_real = np.concatenate(all_hi_real)

    overall_metrics = {
        "SMAPE": smape(all_y_real, all_mu_real),
        "MSE": mean_squared_error(all_y_real, all_mu_real),
        "RMSE": math.sqrt(mean_squared_error(all_y_real, all_mu_real)),
        "R2": r2_score(all_y_real, all_mu_real),
        "Coverage90": float(np.mean((all_y_real >= all_lo_real) & (all_y_real <= all_hi_real))),
    }

    # build tables

    watchlist_df = pd.DataFrame({
        "District": district_names,
        "Upazila":  upazila_names,
        "Year":     year_list,
        "Month":    month_list,
        "Expected_cases_next_week": np.round(next_week_forecast_list,2),
        "High_scenario_p90":        np.round(next_week_hi_list,2),
    })
    watchlist_df["Status"] = watchlist_df["High_scenario_p90"].apply(classify_risk)

    capacity_lookup = {
        (d,u): DEFAULT_CAPACITY_PER_DISTRICT
        for d,u in zip(district_names, upazila_names)
    }
    overflow_df = pd.DataFrame({
        "District": district_names,
        "Upazila":  upazila_names,
        "Year":     year_list,
        "Month":    month_list,
        "Capacity_threshold_beds_per_week": [
            capacity_lookup[(d,u)]
            for d,u in zip(district_names, upazila_names)
        ],
        "Forecast_median_next_week": np.round(next_week_forecast_list,2),
        "High_scenario_p90":         np.round(next_week_hi_list,2),
    })
    overflow_df["Breach_risk_flag"] = [
        "YES" if hi > capacity_lookup[(d,u)] else "NO"
        for (d,u), hi in zip(zip(district_names,upazila_names), next_week_hi_list)
    ]

    accel_df = pd.DataFrame({
        "District": district_names,
        "Upazila":  upazila_names,
        "Year":     year_list,
        "Month":    month_list,
        "Last_week_cases":        np.round(last_week_cases_list,2),
        "This_week_actual":       np.round(this_week_actual_list,2),
        "This_week_predicted":    np.round(this_week_pred_list,2),
        "Next_week_forecast":     np.round(next_week_forecast_list,2),
        "Growth_rate_WoW(%)":     np.round(np.array(growth_rate_list)*100,1),
    })
    accel_df["Growth_flag"] = accel_df["Growth_rate_WoW(%)"].apply(
        lambda pct: classify_growth(pct/100.0)
    )

    confidence_df = pd.DataFrame({
        "District": district_names,
        "Upazila":  upazila_names,
        "Year":     year_list,
        "Month":    month_list,
        "Forecast_next_week":            np.round(next_week_forecast_list,2),
        "Uncertainty_width(p90-p10)":    np.round(conf_width_list,2),
    })
    confidence_df["Confidence_flag"] = confidence_df["Uncertainty_width(p90-p10)"].apply(classify_conf)

    # simple placeholder for data quality: 100% reporting
    quality_df = pd.DataFrame({
        "District": district_names,
        "Upazila":  upazila_names,
        "Year":     year_list,
        "Month":    month_list,
        "reporting_continuity_pct": [100.0]*len(district_names),
    })
    quality_df["Data_quality_flag"] = quality_df["reporting_continuity_pct"].apply(classify_data_quality)

    peak_df = pd.DataFrame({
        "District": district_names,
        "Upazila":  upazila_names,
        "Year":     year_list,
        "Month":    month_list,
        "Peak_lead_time_weeks": peak_leadtime_list,
        "Peak_cases_median":    np.round(peak_val_list,2),
        "Peak_cases_high(p90)": np.round(peak_hi_list,2),
        "Peak_when":            peak_day_list
    })

    isochrone_df = pd.DataFrame({
        "District": district_names,
        "Upazila":  upazila_names,
        "Year":     year_list,
        "Month":    month_list,
        "Forecast_next_week_median": np.round(next_week_forecast_list,2),
        "Forecast_next_week_hi":     np.round(next_week_hi_list,2),
        "Growth_flag": [
            classify_growth(g) for g in growth_rate_list
        ],
    })

    gap_pct_list = []
    for a, p in zip(this_week_actual_list, this_week_pred_list):
        gap_pct = (p - a)/(a + 1e-6)*100.0
        gap_pct_list.append(gap_pct)
    nowcast_gap_df = pd.DataFrame({
        "District": district_names,
        "Upazila":  upazila_names,
        "Year":     year_list,
        "Month":    month_list,
        "Current_week_actual":               np.round(this_week_actual_list,2),
        "Current_week_predicted_from_prev":  np.round(this_week_pred_list,2),
        "Nowcast_gap_percent":               np.round(gap_pct_list,1),
    })

    # climate importance via corr with forecast
    last_feats_matrix = Xseq[:, -2, :-1]  # exclude prev_y channel
    preds_arr = np.array(next_week_forecast_list)
    # Pearson correlation is undefined if either side has no variance.
    # `not (x >= eps)` is also True when the std itself is NaN.
    preds_degenerate = not (np.std(preds_arr) >= 1e-8)

    feat_import_rows = []
    def parse_feature_name(raw_name: str):
        if raw_name.startswith("month_"):
            return "season(month)", ""
        if "_rmean" in raw_name:
            base, tail = raw_name.split("_rmean")
            return base, f"{tail}m rolling mean"
        if "_rstd" in raw_name:
            base, tail = raw_name.split("_rstd")
            return base, f"{tail}m rolling std"
        if "_lag" in raw_name:
            base, tail = raw_name.split("_lag")
            return base, f"lag {tail}m"
        return raw_name, "same month"

    for fi in range(last_feats_matrix.shape[1]):
        col_vals = last_feats_matrix[:, fi]
        if preds_degenerate or np.std(col_vals) < 1e-8:
            corr = 0.0
        else:
            corr = np.corrcoef(col_vals, preds_arr)[0,1]
            if not np.isfinite(corr):
                corr = 0.0

        raw_feat_name = feat_cols[fi]
        base_var, lag_info = parse_feature_name(raw_feat_name)
        feat_import_rows.append({
            "feature": raw_feat_name,
            "base_var": base_var,
            "lag_info": lag_info,
            "pearson_corr_with_next_week_forecast": float(corr),
        })

    climate_influence_df = pd.DataFrame(feat_import_rows)
    climate_influence_df["abs_corr"] = (
        climate_influence_df["pearson_corr_with_next_week_forecast"].abs()
    )
    climate_influence_df = climate_influence_df.sort_values(
        "abs_corr", ascending=False
    )

    top5 = watchlist_df.sort_values("High_scenario_p90", ascending=False).head(5)[
        ["District","Upazila","Expected_cases_next_week","High_scenario_p90","Status"]
    ].to_dict(orient="records")

    rapid = accel_df[accel_df["Growth_flag"]=="🔴 Rapid growth"][["District","Upazila"]].apply(tuple,1).tolist()
    overflow_risk = overflow_df[overflow_df["Breach_risk_flag"]=="YES"][["District","Upazila"]].apply(tuple,1).tolist()
    dq_bad = quality_df[quality_df["Data_quality_flag"]=="🔴 Needs field check"][["District","Upazila"]].apply(tuple,1).tolist()

    summary_text = {
        "Top5_high_risk_next_week": to_python(top5),
        "Upazilas_with_rapid_growth": to_python(rapid),
        "Upazilas_with_capacity_breach_risk": to_python(overflow_risk),
        "Upazilas_with_data_quality_issues": to_python(dq_bad),
        "Model_calibration": {
            "Coverage90": round(float(overall_metrics["Coverage90"]),3),
            "SMAPE": round(float(overall_metrics["SMAPE"]),2),
            "R2": round(float(overall_metrics["R2"]),3),
        }
    }

    products = {
        "watchlist_df": watchlist_df,
        "overflow_df": overflow_df,
        "accel_df": accel_df,
        "confidence_df": confidence_df,
        "quality_df": quality_df,
        "peak_df": peak_df,
        "isochrone_df": isochrone_df,
        "nowcast_gap_df": nowcast_gap_df,
        "climate_influence_df": climate_influence_df,
        "executive_summary": summary_text,
        "overall_metrics": overall_metrics,
        "last_meta": last_meta,
    }

    return overall_metrics, products

# ============================================================
# PUSH TO DB FOR MALARIA
# ============================================================

def push_products_to_db_malaria(conn, prefix, products):
    """Write every dashboard product in ``products`` to ``{prefix}_*`` tables.

    All writes go through ``upsert_table``; the meaning of ``pk_cols`` and
    ``wipe_first`` is defined by that helper.

    Parameters
    ----------
    conn
        Database connection, passed unchanged to ``upsert_table``.
    prefix : str
        Table-name prefix; each table is named ``f"{prefix}_<suffix>"``.
    products : dict
        Must contain ``"last_meta"``, the per-upazila DataFrames listed in
        ``per_upazila_specs`` below, ``"climate_influence_df"`` and
        ``"executive_summary"``.

    Notes
    -----
    Per-upazila tables are keyed on (district, upazila, year, month). Year and
    month come from ``last_meta[(District, Upazila)]``; when that key is
    missing, ``py(None)`` is stored for both.
    """
    last_meta = products["last_meta"]
    key_cols = ["district", "upazila", "year", "month"]

    def _as_is(value):
        # Text/flag columns are written exactly as they appear in the frame.
        return value

    def _capacity(value):
        # Capacity is stored as an integer; missing values become SQL NULL.
        return int(value) if not pd.isna(value) else None

    # One entry per per-upazila table:
    #   (key in `products`, table suffix,
    #    [(DataFrame column, DB column, converter), ...])
    # Field order defines both the tuple layout and the column list.
    per_upazila_specs = [
        ("watchlist_df", "watchlist", [
            ("Expected_cases_next_week", "expected_cases_next_week", py),
            ("High_scenario_p90", "high_scenario_p90", py),
            ("Status", "status", _as_is),
        ]),
        ("overflow_df", "overflow_risk", [
            ("Capacity_threshold_beds_per_week",
             "capacity_threshold_beds_per_week", _capacity),
            ("Forecast_median_next_week", "forecast_median_next_week", py),
            ("High_scenario_p90", "high_scenario_p90", py),
            ("Breach_risk_flag", "breach_risk_flag", _as_is),
        ]),
        ("accel_df", "acceleration_alerts", [
            ("Last_week_cases", "last_week_cases", py),
            ("This_week_actual", "this_week_actual", py),
            ("This_week_predicted", "this_week_predicted", py),
            ("Next_week_forecast", "next_week_forecast", py),
            ("Growth_rate_WoW(%)", "growth_rate_wow", py),
            ("Growth_flag", "growth_flag", _as_is),
        ]),
        ("confidence_df", "confidence", [
            ("Forecast_next_week", "forecast_next_week", py),
            ("Uncertainty_width(p90-p10)", "uncertainty_width", py),
            ("Confidence_flag", "confidence_flag", _as_is),
        ]),
        ("quality_df", "surveillance_quality", [
            ("reporting_continuity_pct", "reporting_continuity_pct", py),
            ("Data_quality_flag", "data_quality_flag", _as_is),
        ]),
        ("peak_df", "peak_projection", [
            ("Peak_lead_time_weeks", "peak_lead_time_weeks", py),
            ("Peak_cases_median", "peak_cases_median", py),
            ("Peak_cases_high(p90)", "peak_cases_high_p90", py),
            ("Peak_when", "peak_when", _as_is),
        ]),
        ("isochrone_df", "isochrone_spread", [
            ("Forecast_next_week_median", "forecast_next_week_median", py),
            ("Forecast_next_week_hi", "forecast_next_week_hi", py),
            ("Growth_flag", "growth_flag", _as_is),
        ]),
        ("nowcast_gap_df", "nowcast_gap", [
            ("Current_week_actual", "current_week_actual", py),
            ("Current_week_predicted_from_prev",
             "current_week_predicted_from_prev", py),
            ("Nowcast_gap_percent", "nowcast_gap_percent", py),
        ]),
    ]

    for product_key, table_suffix, fields in per_upazila_specs:
        rows = []
        for _, r in products[product_key].iterrows():
            meta = last_meta.get((r["District"], r["Upazila"]), {})
            rows.append((
                r["District"],
                r["Upazila"],
                py(meta.get("Year")),
                py(meta.get("Month")),
                *(convert(r[src_col]) for src_col, _, convert in fields),
            ))
        upsert_table(
            conn,
            f"{prefix}_{table_suffix}",
            key_cols + [db_col for _, db_col, _ in fields],
            rows,
            pk_cols=list(key_cols),
            wipe_first=False
        )

    # climate influence: not keyed per upazila, so it is written with
    # pk_cols=None and wipe_first=True instead of the keyed upsert above.
    clim_rows = [
        (
            r["feature"],
            r["base_var"],
            r["lag_info"],
            py(r["pearson_corr_with_next_week_forecast"]),
            py(r["abs_corr"]),
        )
        for _, r in products["climate_influence_df"].iterrows()
    ]
    upsert_table(
        conn,
        f"{prefix}_climate_influence",
        ["feature","base_var","lag_info",
         "pearson_corr_with_next_week_forecast","abs_corr"],
        clim_rows,
        pk_cols=None,
        wipe_first=True
    )

    # exec summary: a single row holding the JSON-serialised summary dict.
    summary_json = json.dumps(to_python(products["executive_summary"]))
    upsert_table(
        conn,
        f"{prefix}_exec_summary",
        ["summary"],
        [(summary_json,)],
        pk_cols=None,
        wipe_first=True
    )

# ============================================================
# PIPELINE PER TARGET (pv_rate_1k / pf_rate_1k)
# ============================================================

def run_target_pipeline(df_all, target_col, prefix, out_dir):
    """Train, evaluate and publish the forecaster for one malaria target.

    Steps: build lagged/rolled climate features, split each upazila's history
    into train / validation / test by time, fit scalers on train, train the
    model, build the dashboard products on the test sequences, save them under
    ``out_dir`` and push them to Postgres tables named ``{prefix}_*``.

    Parameters
    ----------
    df_all : pandas.DataFrame
        Monthly data; must contain ``upazilaid``, ``year``, ``month`` plus the
        columns consumed by ``add_lags_rolls_monthly`` / ``build_mats``.
    target_col : str
        Name of the column to forecast.
    prefix : str
        Prefix for the output tables.
    out_dir : str
        Directory for local artifacts; created if it does not exist.

    Raises
    ------
    RuntimeError
        If no training sequences can be built.
    """
    os.makedirs(out_dir, exist_ok=True)
    print(f"\n=== Running malaria target: {target_col} ({prefix}) ===")

    # Build lagged/rolled features on climate
    df_feat, lag_cols = add_lags_rolls_monthly(
        df_all,
        group_col="upazilaid",
        base_cols=BASE_FEATURES,
        lags=(1,3,6),
        rolls=(3,6)
    )

    # final feature list
    feat_cols = ["month_sin","month_cos"] + lag_cols

    # per-upazila temporal split: TRAIN / VAL / TEST
    # Every upazila contributes at least one row to TRAIN, so the label
    # encoder fitted on TRAIN covers every id that appears in VAL/TEST.
    holdout = VAL_H_MONTHS + TEST_H_MONTHS
    tr_parts, va_parts, te_parts = [], [], []
    for uid, g in df_feat.groupby("upazilaid"):
        g = g.sort_values(["year","month"]).copy()
        if len(g) > holdout:
            te_parts.append(g.iloc[-TEST_H_MONTHS:])
            va_parts.append(g.iloc[-holdout:-TEST_H_MONTHS])
            tr_parts.append(g.iloc[:-holdout])
        elif len(g) > TEST_H_MONTHS:
            # Too short for a validation slice: test + train only.
            te_parts.append(g.iloc[-TEST_H_MONTHS:])
            tr_parts.append(g.iloc[:-TEST_H_MONTHS])
        else:
            tr_parts.append(g)
    tr_df = pd.concat(tr_parts).reset_index(drop=True)
    va_df = pd.concat(va_parts).reset_index(drop=True) if va_parts else tr_df.iloc[0:0].copy()
    te_df = pd.concat(te_parts).reset_index(drop=True) if te_parts else tr_df.iloc[0:0].copy()

    # Fit scalers on TRAIN
    split_tr, le_tr, feat_scaler_tr, y_scalers_tr = build_mats(tr_df, feat_cols, target_col)

    # Sequences
    X_tr, Y_tr, C_tr = to_seq(split_tr, SEQ, prev_y=True)
    if X_tr.shape[0] == 0:
        raise RuntimeError("No train sequences. Reduce SEQ or check data continuity.")
    print(f"Train sequences: {X_tr.shape}")

    # NOTE(review): the holdout slices are VAL_H_MONTHS / TEST_H_MONTHS rows
    # per upazila while windows are SEQ long. Whether to_seq/apply_mats can
    # build any window from such a slice is not visible here -- confirm the
    # fallbacks below are not taken on every run.

    # Validation seqs (fall back to the training sequences when none exist)
    X_va = None
    if len(va_df) > 0:
        split_va = apply_mats(va_df, feat_cols, target_col, le_tr, feat_scaler_tr, y_scalers_tr)
        X_va, Y_va, C_va = to_seq(split_va, SEQ, prev_y=True)
    if X_va is None or X_va.shape[0] == 0:
        print(f"NOTE: no validation sequences for {target_col}; "
              "reusing the training sequences as the validation set.")
        X_va, Y_va, C_va = X_tr, Y_tr, C_tr

    # Test seqs (fall back to whatever the validation set is when none exist)
    X_te = None
    if len(te_df) > 0:
        split_te = apply_mats(te_df, feat_cols, target_col, le_tr, feat_scaler_tr, y_scalers_tr)
        X_te, Y_te, C_te = to_seq(split_te, SEQ, prev_y=True)
    if X_te is None or X_te.shape[0] == 0:
        print(f"NOTE: no test sequences for {target_col}; "
              "evaluating on the validation set instead, so reported "
              "metrics are not from held-out test data.")
        X_te, Y_te, C_te = X_va, Y_va, C_va

    # Build lookup df for metadata per entity_id for TEST split
    # We want each upazila's rows in time order, including names + year/month
    te_df_with_eid = te_df.copy()
    te_df_with_eid["entity_id"] = le_tr.transform(te_df_with_eid["upazilaid"].astype(str))
    seq_lookup_df = {}
    for cid, g in te_df_with_eid.groupby("entity_id"):
        seq_lookup_df[int(cid)] = g.sort_values(["year","month"]).copy()

    # Train model
    nc = int(C_tr.max()) + 1
    cond_dim = X_tr.shape[2]
    print("Training LSTM+Attention forecaster ...")
    M = train_model(
        X_tr, Y_tr, C_tr,
        X_va, Y_va, C_va,
        cond_dim, nc
    )
    torch.save(M.state_dict(), os.path.join(out_dir, "forecaster.pt"))

    # Build director products on TEST split
    overall_metrics, products = evaluate_and_build_products_malaria(
        M,
        X_te, Y_te, C_te,
        y_scalers_tr, le_tr,
        seq_lookup_df,
        feat_cols
    )

    # save local artifacts (products key -> CSV file name)
    csv_artifacts = [
        ("watchlist_df", "watchlist.csv"),
        ("overflow_df", "overflow_risk.csv"),
        ("accel_df", "acceleration_alerts.csv"),
        ("confidence_df", "forecast_confidence.csv"),
        ("quality_df", "surveillance_quality.csv"),
        ("peak_df", "peak_projection.csv"),
        ("isochrone_df", "isochrone_spread.csv"),
        ("nowcast_gap_df", "nowcast_gap.csv"),
        ("climate_influence_df", "climate_lag_influence.csv"),
    ]
    for product_key, file_name in csv_artifacts:
        products[product_key].to_csv(os.path.join(out_dir, file_name), index=False)

    dash_manifest = {
        "executive_summary": products["executive_summary"],
        "overall_metrics": products["overall_metrics"],
        "notes": f"Auto-generated malaria decision tables for {prefix} (monthly model scaled to 'weekly')."
    }
    with open(os.path.join(out_dir,"DASHBOARD_SUMMARY.json"), "w") as f:
        json.dump(to_python(dash_manifest), f, indent=2)

    # Push to DB. Close the connection even if table creation or a write
    # fails, so a failed run does not leave a server session open.
    conn = get_db_conn()
    try:
        ensure_output_tables_malaria(conn, prefix)
        push_products_to_db_malaria(conn, prefix, products)
    finally:
        conn.close()

    print("\n=== Executive Summary Preview ===")
    print(json.dumps(to_python(dash_manifest["executive_summary"]), indent=2))
    print(f"\nArtifacts saved to {out_dir} and pushed to Postgres.\n")

# ============================================================
# MAIN
# ============================================================

def main():
    """Load the monthly malaria data once and run the pipeline per target.

    The PV target is processed first, then PF. Each target gets its own
    output directory under ``OUT_ROOT`` (named after the target column) and
    its own table prefix.
    """
    print("1) Load malaria monthly data from Postgres")
    monthly_df = load_malaria_monthly_from_db()

    # (target column, table prefix) -> tables: malaria_pv_* / malaria_pf_*
    targets = (
        ("pv_rate_1k", "malaria_pv"),
        ("pf_rate_1k", "malaria_pf"),
    )
    for rate_col, table_prefix in targets:
        target_dir = os.path.join(OUT_ROOT, rate_col)
        run_target_pipeline(
            monthly_df,
            target_col=rate_col,
            prefix=table_prefix,
            out_dir=target_dir
        )

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()