In [1]:
print("Hello, World!")

Hello, World!


In [None]:
# For each target, load the model and predict
targets = ["pv_rate", "pf_rate", "mixed_rate"]
for tgt in targets:
    print(f"Starting predictions for {tgt}")
    out_dir = f"./final_output/malaria_e2e/Target_{tgt.split('_')[0].upper()}_over_Pop"
    # Load VAE
    vae = CondVAE(inp=len(INDICATORS),hid=HIDDEN_MULT*len(INDICATORS),lat=LATENT,n_upazila=full_df["upazila_code"].max()+1,n_year=full_df["year_code"].max()+1)
    vae.load_state_dict(torch.load(os.path.join(out_dir,"conditional_vae.pt"), map_location=DEVICE))
    vae.to(DEVICE)

    # For scalers, process hist_df
    lat_hist = vae_latent_df(vae, hist_df)
    base_hist = lat_hist.merge(hist_df[["UpazilaID","year","month","ym"] + INDICATORS + [tgt]], on=["UpazilaID","year","month","ym"], how="left")
    base_hist, lag_cols = add_lags(base_hist, INDICATORS, group_col="UpazilaID", lags=(1,3,6), rolls=(3,6))
    pc_cols = [c for c in base_hist.columns if c.startswith("pc_mean_") or c.startswith("pc_std_")]
    feats = pc_cols + ["month_sin","month_cos"] + lag_cols
    # Load training data to fit scalers
    split_tr, le_tr, fsc_tr, ysc_tr = build_mats(base_hist, feats, tgt, "UpazilaID")

    # Get latents for full
    lat = vae_latent_df(vae, full_df)
    # Merge
    merge_cols = ["UpazilaID","year","month","ym"] + INDICATORS
    if tgt in full_df.columns:
        merge_cols += [tgt]
    base = lat.merge(full_df[merge_cols], on=["UpazilaID","year","month","ym"], how="left")
    # Add lags
    base, lag_cols = add_lags(base, INDICATORS, group_col="UpazilaID", lags=(1,3,6), rolls=(3,6))
    # Feats (same as above)
    pc_cols = [c for c in base.columns if c.startswith("pc_mean_") or c.startswith("pc_std_")]
    feats = pc_cols + ["month_sin","month_cos"] + lag_cols

    # Load Generator
    nc = int(split_tr.c.max()) + 1
    cond_dim = len(feats) + 1  # + prev_y
    G = Generator(cond_dim=cond_dim,noise_dim=NOISE,nc=nc)
    G.load_state_dict(torch.load(os.path.join(out_dir,"generator.pt"), map_location=DEVICE))
    G.to(DEVICE)
    G.eval()
    # Now, for each upazila, forecast the next month
    predictions = []
    for uid in hist_df["UpazilaID"].unique():
        g = base[base["UpazilaID"] == uid].sort_values(["year","month"]).copy()
        if len(g) < SEQ:
            continue
        # Take last SEQ
        last_seq = g.iloc[-SEQ:].copy()
        # Prepare X
        X = last_seq[feats].values.astype(np.float32)
        Xs = fsc_tr.transform(X)
        # prev_y, start with 0
        prev_y = np.zeros((SEQ,1), np.float32)
        X_full = np.concatenate([Xs, prev_y], axis=1)
        # Predict next
        X_t = torch.tensor(X_full[np.newaxis, :, :], device=DEVICE)
        C_t = torch.tensor([le_tr.transform([str(uid)])[0]], device=DEVICE)
        with torch.no_grad():
            mu, _, _ = G(X_t, torch.zeros(1, SEQ, NOISE, device=DEVICE), C_t)
        pred_scaled = mu.cpu().numpy()[0, -1, 0]
        pred_rate = np.expm1(ysc_tr[int(C_t.item())].inverse_transform([[pred_scaled]])[0,0])
        predictions.append({"UpazilaID": uid, "predicted_rate": pred_rate, "target": tgt})
    print(f"Predictions count: {len(predictions)}")
    pred_df_out = pd.DataFrame(predictions)
    pred_df_out.to_csv(f"./predictions_{tgt}.csv", index=False)
    print(f"Predictions for {tgt} saved to ./predictions_{tgt}.csv")

Starting predictions for pv_rate
Predictions count: 71
Predictions for pv_rate saved to ./predictions_pv_rate.csv
Starting predictions for pf_rate
Predictions count: 71
Predictions for pv_rate saved to ./predictions_pv_rate.csv
Starting predictions for pf_rate
Predictions count: 71
Predictions for pf_rate saved to ./predictions_pf_rate.csv
Starting predictions for mixed_rate
Predictions count: 71
Predictions for pf_rate saved to ./predictions_pf_rate.csv
Starting predictions for mixed_rate
Predictions count: 71
Predictions for mixed_rate saved to ./predictions_mixed_rate.csv
Predictions count: 71
Predictions for mixed_rate saved to ./predictions_mixed_rate.csv


In [None]:
# ===== Malaria Upazila-wise forecasting (PV/Pop, PF/Pop, MIXED/Pop) =====
import os, math, json, copy, random, warnings
warnings.filterwarnings("ignore")

import numpy as np, pandas as pd
from math import erf
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import mean_squared_error, r2_score
import matplotlib.pyplot as plt

try:
    plt.style.use('seaborn-v0_8-darkgrid')
except Exception:
    pass

# ----------------- Config -----------------
MALARIA_CSV = "./TrainingFile2012_2024.csv"   # <<< set your path
OUT_ROOT = "./output/malaria_e2e"; os.makedirs(OUT_ROOT, exist_ok=True)

SEED=42; DEVICE="cuda" if torch.cuda.is_available() else "cpu"
SEQ=12; BATCH=64; EPOCHS_VAE=400; EPOCHS_GAN=400
LATENT=12; HIDDEN_MULT=4; DROP=0.25; BETA=2.0; CONTR=1e-3; PATIENCE=25
NOISE=16; LSTM_UNITS=128; HEADS=8
LR_VAE=1e-3; LR_G=8e-4; LR_D=3e-4; WD=1e-4; CLIP=1.0; TTUR=(1.0,1.0)
TF_START=1.0; TF_END=0.3
ALPHA=0.5; K_SYNC=5; TOPK=10
Q=(0.1,0.5,0.9)
ABL={"adv":True,"hetero":True,"quant":True}
NSIG=(0.1,1.2); K_MC=50

# indicators provided in your sample
INDICATORS = [
    "Average_temperature","Total_rainfall","Relative_humidity","Average_NDVI","Average_NDWI"
]

VAL_H_MONTHS  = 6    # per Upazila for early stop
TEST_H_MONTHS = 6    # per Upazila for final test

random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark=False; torch.backends.cudnn.deterministic=True

# ----------------- Utils -----------------
def smape(y, p): y=y.flatten(); p=p.flatten(); return 100*np.mean(2*np.abs(p-y)/(np.abs(y)+np.abs(p)+1e-8))
def lerp(a,b,t): return {k:a[k]+(b[k]-a[k])*t for k in a}
def psd(M): e,V=np.linalg.eigh(M); e[e<1e-6]=1e-6; return (V@np.diag(e)@V.T).astype(np.float32)

# ----------------- Data (Malaria) -----------------
def load_malaria(path):
    """
    Expected columns (at minimum):
    UpazilaID, month, year, Population, PV, PF, MIXED, and the 5 indicators.
    Optional: DIS_NAME/DIS_CODE/UPA_NAME etc. are ignored.
    """
    df = pd.read_csv(path, sep=None, engine="python")
    # normalize column names
    df.columns = [c.strip() for c in df.columns]
    # safety: standard names
    colmap = {
        "Month":"month","MonthNo":"month","MONTH":"month",
        "Year":"year","YEAR":"year",
        "UpazilaId":"UpazilaID","UPAZILAID":"UpazilaID","UpazilaID":"UpazilaID"
    }
    for k,v in colmap.items():
        if k in df.columns and v not in df.columns:
            df[v]=df[k]
    # types
    df["UpazilaID"] = pd.to_numeric(df["UpazilaID"], errors="coerce").astype("Int64")
    df["month"] = pd.to_numeric(df["month"], errors="coerce").astype(int)
    df["year"]  = pd.to_numeric(df["year"],  errors="coerce").astype(int)

    # build targets (rates)
    for cases_col, tgt in [("PV","pv_rate"),("PF","pf_rate"),("MIXED","mixed_rate")]:
        if cases_col in df.columns:
            cc = pd.to_numeric(df[cases_col], errors="coerce")
            pp = pd.to_numeric(df["Population"], errors="coerce")
            m  = (pp>0) & cc.notna()
            df[tgt] = np.nan
            df.loc[m, tgt] = (cc[m] / pp[m]).astype(float)
        else:
            df[tgt] = np.nan

    # order & date key
    df = df.sort_values(["UpazilaID","year","month"]).reset_index(drop=True)
    df["ym"] = df["year"]*12 + df["month"]

    # forward-fill per Upazila for indicators + population + targets
    cols_to_ffill = set(INDICATORS + ["Population","pv_rate","pf_rate","mixed_rate"])
    present = [c for c in cols_to_ffill if c in df.columns]
    df[present] = df.groupby("UpazilaID")[present].apply(lambda g: g.ffill()).reset_index(level=0, drop=True)

    # replace inf -> NaN -> ffill (still per Upazila)
    df.replace([np.inf, -np.inf], np.nan, inplace=True)
    df[present] = df.groupby("UpazilaID")[present].apply(lambda g: g.ffill()).reset_index(level=0, drop=True)

    # seasonality
    df["month_sin"] = np.sin(2*np.pi*df["month"]/12.).astype(np.float32)
    df["month_cos"] = np.cos(2*np.pi*df["month"]/12.).astype(np.float32)

    # codes for conditioning VAE
    df["upazila_code"] = pd.Categorical(df["UpazilaID"]).codes
    df["year_code"]    = pd.Categorical(df["year"]).codes
    return df

# ----------------- Conditional VAE on indicators -----------------
class CondVAE(nn.Module):
    def __init__(s,inp,hid,lat,n_upazila,n_year,emb=16):
        super().__init__()
        s.ue=nn.Embedding(n_upazila,emb); s.ye=nn.Embedding(n_year,emb)
        s.p=nn.Linear(inp,hid//2); s.f=nn.Linear(hid//2+2*emb,hid)
        s.mu=nn.Linear(hid,lat); s.lv=nn.Linear(hid,lat)
        s.fd=nn.Linear(lat+2*emb,hid); s.out=nn.Linear(hid,inp)
        s.l1=nn.LayerNorm(hid); s.l2=nn.LayerNorm(hid); s.drop=nn.Dropout(DROP)
    def encode(s,x,u,y):
        h=torch.relu(s.p(x)); h=torch.cat([h,s.ue(u),s.ye(y)],1)
        h=s.drop(torch.relu(s.l1(s.f(h)))); return s.mu(h), s.lv(h)
    def reparam(s,m,l): std=(0.5*l).exp(); return m+torch.randn_like(std)*std
    def decode(s,z,u,y):
        h=torch.cat([z,s.ue(u),s.ye(y)],1); h=s.drop(torch.relu(s.l2(s.fd(h))))
        return s.out(h)
    def forward(s,x,u,y): mu,lv=s.encode(x,u,y); z=s.reparam(mu,lv); return s.decode(z,u,y),mu,lv

def vae_loss(xh,x,mu,lv,model,x_in,u,y):
    recon=F.mse_loss(xh,x)
    kl=-0.5*torch.mean(1+lv-mu.pow(2)-lv.exp())
    x_in.requires_grad_(True); mu_c,_=model.encode(x_in,u,y)
    g=torch.autograd.grad(mu_c.sum(),x_in,create_graph=True)[0]
    return recon+BETA*kl+CONTR*(g.pow(2).sum(1).mean())

def train_vae_on_indicators(df, out_dir):
    # build matrix of indicators
    X = df[INDICATORS].astype(np.float32).values
    u = df["upazila_code"].astype(np.int64).values
    y = df["year_code"].astype(np.int64).values

    # train/val split for VAE: earlier years train, later years val
    yq=np.quantile(df["year"],0.8)
    mtr = (df["year"]<=yq)

    Xt=torch.tensor(X[mtr],dtype=torch.float32,device=DEVICE)
    Xv=torch.tensor(X[~mtr],dtype=torch.float32,device=DEVICE)
    ut=torch.tensor(u[mtr],dtype=torch.long,device=DEVICE)
    yt=torch.tensor(y[mtr],dtype=torch.long,device=DEVICE)
    uv=torch.tensor(u[~mtr],dtype=torch.long,device=DEVICE)
    yv=torch.tensor(y[~mtr],dtype=torch.long,device=DEVICE)

    inp=Xt.shape[1]; hid=HIDDEN_MULT*inp
    n_upazila=int(df["upazila_code"].max())+1
    n_year   =int(df["year_code"].max())+1
    vae=CondVAE(inp,hid,LATENT,n_upazila,n_year).to(DEVICE)

    opt=torch.optim.AdamW(vae.parameters(),lr=LR_VAE,weight_decay=WD)
    sch=torch.optim.lr_scheduler.ReduceLROnPlateau(opt,"min",0.5,5)
    best=float("inf"); best_sd=None; wait=0
    for e in range(EPOCHS_VAE):
        vae.train(); opt.zero_grad(set_to_none=True)
        xh,mu,lv=vae(Xt,ut,yt); loss=vae_loss(xh,Xt,mu,lv,vae,Xt.clone(),ut,yt); loss.backward()
        torch.nn.utils.clip_grad_norm_(vae.parameters(),CLIP); opt.step()
        vae.eval()
        with torch.no_grad():
            xhv,mv,lvv=vae(Xv,uv,yv)
            vloss=F.mse_loss(xhv,Xv)+BETA*(-0.5*torch.mean(1+lvv-mv.pow(2)-lvv.exp()))
        sch.step(vloss)
        if vloss.item()<best-1e-6:
            best=vloss.item(); best_sd=copy.deepcopy(vae.state_dict()); wait=0
        else:
            wait+=1
            if e>=30 and wait>=PATIENCE: break
    if best_sd: vae.load_state_dict(best_sd)
    os.makedirs(out_dir, exist_ok=True)
    torch.save(vae.state_dict(), os.path.join(out_dir,"conditional_vae.pt"))
    return vae

def vae_latent_df(vae, df):
    vae.eval()
    X = torch.tensor(df[INDICATORS].astype(np.float32).values, device=DEVICE)
    U = torch.tensor(df["upazila_code"].astype(np.int64).values, device=DEVICE)
    Y = torch.tensor(df["year_code"].astype(np.int64).values, device=DEVICE)
    with torch.no_grad():
        _,mu,_ = vae(X,U,Y)
    Z = mu.detach().cpu().numpy().astype(np.float32)
    out = df[["UpazilaID","year","month","ym"]].copy()
    # attach direct means/std (std not well-defined from single pass → zeros)
    for i in range(LATENT):
        out[f"pc_mean_{i+1}"] = Z[:,i]
        out[f"pc_std_{i+1}"]  = 0.0
    out["month_sin"] = df["month_sin"].values.astype(np.float32)
    out["month_cos"] = df["month_cos"].values.astype(np.float32)
    return out

# --------------- Lag features ----------------
def add_lags(df, base_cols, group_col="UpazilaID", lags=(1,3,6), rolls=(3,6)):
    df = df.sort_values([group_col,"year","month"]).copy()
    new_cols=[]
    for c in base_cols:
        for L in lags:
            col=f"{c}_lag{L}"
            df[col]=df.groupby(group_col)[c].shift(L); new_cols.append(col)
        for R in rolls:
            m=f"{c}_rmean{R}"; s=f"{c}_rstd{R}"
            g=df.groupby(group_col)[c]
            df[m]=g.rolling(R,min_periods=1).mean().reset_index(level=0,drop=True)
            df[s]=g.rolling(R,min_periods=1).std().reset_index(level=0,drop=True).fillna(0.0)
            new_cols += [m,s]
    df[new_cols] = df[new_cols].fillna(0.0)
    return df, new_cols

# ----------------- Seq utilities -----------------
class Split:
    def __init__(s,X,y,c,df): s.X=X; s.y=y; s.c=c; s.df=df

def build_mats(df, feat_cols, target_col, entity_col="UpazilaID"):
    le = LabelEncoder().fit(df[entity_col].astype(str).values)
    d  = df.copy()
    d["entity_id"] = le.transform(d[entity_col].astype(str))

    X = d[feat_cols].values.astype(np.float32)
    feat_scaler = StandardScaler().fit(X)
    Xs = feat_scaler.transform(X)

    # log1p target (non-negative rates)
    y  = np.log1p(np.clip(d[target_col].values.reshape(-1,1), 0, None)).astype(np.float32)
    Ys = np.zeros_like(y, np.float32)

    y_scalers = {}
    for cid, g in d.groupby("entity_id"):
        idx = g.index.values
        sc  = StandardScaler().fit(y[idx])
        y_scalers[cid] = sc
        Ys[idx] = sc.transform(y[idx])

    return Split(Xs, Ys, d["entity_id"].values.astype(np.int64), d), le, feat_scaler, y_scalers

def apply_mats(df, feat_cols, target_col, le, feat_scaler, y_scalers, entity_col="UpazilaID"):
    d = df.copy()
    d["entity_id"] = le.transform(d[entity_col].astype(str))
    X = d[feat_cols].values.astype(np.float32)
    Xs = feat_scaler.transform(X)

    y  = np.log1p(np.clip(d[target_col].values.reshape(-1,1), 0, None)).astype(np.float32)
    Ys = np.zeros_like(y, np.float32)
    for cid, g in d.groupby("entity_id"):
        idx = g.index.values
        sc = y_scalers.get(int(cid), StandardScaler().fit(y[idx]))
        Ys[idx] = sc.transform(y[idx])

    return Split(Xs, Ys, d["entity_id"].values.astype(np.int64), d)

def to_seq(split,L=SEQ,prev_y=True):
    X,y,c,df=split.X,split.y,split.c,split.df
    ym=(df["year"].astype(int)*12+df["month"].astype(int)).values
    SX,SY,SC=[],[],[]
    for cid in np.unique(c):
        idx=np.where(c==cid)[0]; idx=idx[np.argsort(ym[idx])]; o=ym[idx]
        for i in range(len(idx)-L+1):
            sl=idx[i:i+L]
            if np.all(np.diff(o[i:i+L])==1):
                Xi=X[sl]; Yi=y[sl]
                if prev_y:
                    prev=np.vstack([np.zeros((1,1),np.float32),Yi[:-1]])
                    Xi=np.concatenate([Xi,prev],1)
                SX.append(Xi); SY.append(Yi); SC.append(cid)
    return np.asarray(SX,np.float32), np.asarray(SY,np.float32), np.asarray(SC,np.int64)

# ----------------- Models (same as your dengue GAN) -----------------
class CausalTCN(nn.Module):
    def __init__(s,in_ch,hid=64,levels=3,k=3):
        super().__init__(); s.k=k; s.blocks=nn.ModuleList(); ch=in_ch
        for l in range(levels):
            dil=2**l; pad=(k-1)*dil
            s.blocks.append(nn.Sequential(nn.Conv1d(ch,hid,kernel_size=k,dilation=dil,padding=pad),nn.GELU())); ch=hid
    def forward(s,x):
        y=x.transpose(1,2); L=y.size(-1)
        for _,b in enumerate(s.blocks):
            y=b(y); y=y[...,:L]
        return y.transpose(1,2)

class Generator(nn.Module):
    def __init__(s,cond_dim,noise_dim,nc,emb=8,lstm=LSTM_UNITS,heads=HEADS,drop=DROP,qu=Q):
        super().__init__(); s.q=qu; s.ce=nn.Embedding(nc,emb)
        s.pn=nn.Linear(noise_dim,cond_dim); s.tcn=CausalTCN(cond_dim*2+emb,hid=lstm)
        s.lstm=nn.LSTM(lstm,lstm,1,batch_first=True)
        s.mha=nn.MultiheadAttention(lstm,heads,batch_first=True,dropout=drop)
        s.ln=nn.LayerNorm(lstm); s.drop=nn.Dropout(drop)
        s.mu=nn.Linear(lstm,1); s.ls=nn.Linear(lstm,1)
        s.qh=nn.ModuleList([nn.Linear(lstm,1) for _ in qu]); s.last=None
    def forward(s,cond,noise,cid):
        B,L,D=cond.shape; emb=s.ce(cid).unsqueeze(1).repeat(1,L,1)
        z=s.pn(noise); h=s.tcn(torch.cat([cond,z,emb],-1)); h,_=s.lstm(h)
        L = h.size(1)
        mask = torch.triu(torch.ones(L, L, device=h.device, dtype=torch.bool), diagonal=1)
        att, w = s.mha(h, h, h, attn_mask=mask, need_weights=True)
        s.last = w.detach()
        h=s.drop(s.ln(att)); mu=s.mu(h); ls=torch.clamp(s.ls(h),-5.,3.); qs=[q(h) for q in s.qh]
        return mu,ls,qs

class Critic(nn.Module):
    def __init__(s,cond_dim,nc,emb=8,lstm=LSTM_UNITS,drop=DROP):
        super().__init__(); s.ce=nn.Embedding(nc,emb)
        s.lstm=nn.LSTM(cond_dim+1+emb,lstm,1,batch_first=True); s.fc=nn.Linear(lstm,64); s.drop=nn.Dropout(drop); s.out=nn.Linear(64,1)
    def forward(s,cond,ts,c):
        B,L,D=cond.shape
        e=s.ce(c).unsqueeze(1).repeat(1,L,1)
        h,_=s.lstm(torch.cat([cond,ts,e],-1)); f=F.gelu(s.fc(h[:,-1,:]));
        return s.out(s.drop(f)).squeeze(1), f

def pinball(p,t,q): e=t-p; return torch.mean(torch.maximum(q*e,(q-1)*e))
def tv_l1(x): d=x[:,1:]-x[:,:-1]; return d.abs().mean(), x.abs().mean()

def gp(critic,yr,yf,cond,cc,lam=10.):
    B=yr.size(0); eps=torch.rand(B,1,1,device=DEVICE); xi=(eps*yr+(1-eps)*yf).requires_grad_(True)
    with torch.backends.cudnn.flags(enabled=False): s,_=critic(cond,xi,cc)
    g=torch.autograd.grad(s,xi,grad_outputs=torch.ones_like(s),create_graph=True,retain_graph=True,only_inputs=True)[0]
    return ((g.view(B,-1).norm(2,1)-1)**2).mean()*lam

class Lookahead:
    def __init__(s,opt,alpha=ALPHA,k=K_SYNC): s.opt=opt; s.alpha=alpha; s.k=k; s.stepn=0; s.slow={id(p):p.clone().detach() for g in opt.param_groups for p in g["params"] if p.requires_grad}
    def zero_grad(s,set_to_none=True): s.opt.zero_grad(set_to_none=set_to_none)
    def step(s):
        s.opt.step(); s.stepn+=1
        if s.stepn%s.k==0:
            for g in s.opt.param_groups:
                for p in g["params"]:
                    if not p.requires_grad: continue
                    sp=s.slow[id(p)]; sp.data.add_(s.alpha,p.data-sp.data); p.data.copy_(sp.data)

class SAM:
    def __init__(s,opt,rho=0.05,adaptive=True): s.b=opt; s.rho=rho; s.adapt=adaptive; s.eps=None
    def _gn(s,ps):
        ns=[((p.abs()*p.grad) if s.adapt else p.grad).norm(p=2) for p in ps if p.grad is not None]
        return torch.norm(torch.stack(ns),p=2) if ns else torch.tensor(0.,device=DEVICE)
    def first(s,ps):
        scale=s.rho/(s._gn(ps)+1e-12); s.eps=[]
        with torch.no_grad():
            for p in ps:
                if p.grad is None: s.eps.append(None); continue
                e=((p.abs()*p.grad) if s.adapt else p.grad)*scale; p.add_(e); s.eps.append(e)
        s.b.zero_grad(set_to_none=True)
    def second(s,ps):
        with torch.no_grad():
            for p,e in zip(ps,s.eps or []):
                if e is not None: p.sub_(e)
        torch.nn.utils.clip_grad_norm_(ps,CLIP); s.b.step(); s.b.zero_grad(set_to_none=True); s.eps=None

class SeqDS(Dataset):
    def __init__(s,X,Y,C): s.X=X; s.Y=Y; s.C=C
    def __len__(s): return s.X.shape[0]
    def __getitem__(s,i): return s.X[i],s.Y[i],s.C[i]

def _g_loss(G,D,X,Y,C,z,w,pc_idx):
    mu,ls,qs=G(X,z,C); sig=(ls.exp()).clamp(1e-3,50.0)
    nll=0.5*(((Y-mu)/sig)**2+2*ls+math.log(2*math.pi)).mean() if ABL["hetero"] else F.l1_loss(mu,Y)
    ql=torch.tensor(0.,device=DEVICE)
    if ABL["quant"]:
        for i,q in enumerate(Q): ql+=pinball(qs[i],Y,q)
        ql/=len(Q)
    tv,l1=tv_l1(X[...,pc_idx])
    fm=torch.tensor(0.,device=DEVICE); adv_term=torch.tensor(0.,device=DEVICE)
    if ABL["adv"]:
        with torch.no_grad(): _, fr=D(X,Y,C)
        s_fake, ff=D(X,mu,C); fm=F.l1_loss(ff,fr); adv_term=-s_fake.mean()
    return w["nll"]*nll+w["q"]*ql+w["tv"]*tv+w["l1"]*l1+(w.get("fm",0)*fm if ABL["adv"] else 0.)+(w.get("adv",0)*adv_term if ABL["adv"] else 0.)

def validate(G,Xv,Yv,Cv):
    G.eval()
    with torch.no_grad():
        X=torch.tensor(Xv,dtype=torch.float32,device=DEVICE); C=torch.tensor(Cv,dtype=torch.long,device=DEVICE)
        B,L,_=X.shape; mu,ls,qs=G(X,torch.zeros(B,L,NOISE,device=DEVICE),C)
        mp=mu[...,0].cpu().numpy(); y=Yv[...,0]
        q10=qs[0][...,0].cpu().numpy(); q90=qs[2][...,0].cpu().numpy()
        cov=float(np.mean((y>=q10)&(y<=q90))); sm=smape(y.reshape(-1),mp.reshape(-1))
    return sm,cov

def train_gan(Xtr,Ytr,Ctr,Xva,Yva,Cva,cond_dim,nc,pc_idx,out_dir):
    G=Generator(cond_dim=cond_dim,noise_dim=NOISE,nc=nc).to(DEVICE)
    D=Critic(cond_dim=cond_dim,nc=nc).to(DEVICE)
    optG=Lookahead(torch.optim.AdamW(G.parameters(),lr=LR_G*TTUR[0],betas=(0.9,0.999),weight_decay=WD))
    optD=Lookahead(torch.optim.AdamW(D.parameters(),lr=LR_D*TTUR[1],betas=(0.9,0.999),weight_decay=WD))
    sam=SAM(optG,0.05,True)
    dl=DataLoader(SeqDS(Xtr,Ytr,Ctr),batch_size=BATCH,shuffle=True,drop_last=True)
    best=float("inf"); best_sd=None; wait=0
    for e in range(EPOCHS_GAN):
        t=e/(EPOCHS_GAN-1); ns=NSIG[0]+(NSIG[1]-NSIG[0])*t
        W_START = {"nll":1.0,"q":0.5,"tv":0.05,"l1":0.02,"fm":0.2,"adv":0.5}
        W_END   = {"nll":1.0,"q":1.0,"tv":0.10,"l1":0.05,"fm":0.1,"adv":0.4}
        w=lerp(W_START,W_END,t); tf=TF_START+(TF_END-TF_START)*t
        G.train(); D.train()
        for Xb,Yb,Cb in dl:
            Xb=torch.tensor(Xb,dtype=torch.float32,device=DEVICE); Yb=torch.tensor(Yb,dtype=torch.float32,device=DEVICE); Cb=torch.tensor(Cb,dtype=torch.long,device=DEVICE)
            B,L,_=Xb.shape
            with torch.no_grad():
                mu0,_,_=G(Xb,torch.zeros(B,L,NOISE,device=DEVICE),Cb)
                prev=torch.cat([torch.zeros(B,1,device=DEVICE),mu0[:,:-1,0]],1)
            Xb[:,:,-1]=tf*Xb[:,:,-1]+(1-tf)*prev
            # Critic
            if ABL["adv"]:
                optD.zero_grad(set_to_none=True)
                z=torch.randn(B,L,NOISE,device=DEVICE)*ns
                mu,_,_=G(Xb,z,Cb); Yf=mu.detach(); r,_=D(Xb,Yb,Cb); f,_=D(Xb,Yf,Cb)
                dloss=-(r.mean()-f.mean())+gp(D,Yb,Yf,Xb,Cb,10.0); dloss.backward(); torch.nn.utils.clip_grad_norm_(D.parameters(),CLIP); optD.step()
            # Generator + SAM
            optG.zero_grad(set_to_none=True)
            L1=_g_loss(G,D,Xb,Yb,Cb,torch.randn(B,L,NOISE,device=DEVICE)*ns,w,pc_idx); L1.backward(); sam.first(list(G.parameters()))
            L2=_g_loss(G,D,Xb,Yb,Cb,torch.randn(B,L,NOISE,device=DEVICE)*ns,w,pc_idx); L2.backward(); sam.second(list(G.parameters()))
        sm,cov=validate(G,Xva,Yva,Cva); comp=sm+10*abs(cov-0.9)
        if comp<best-1e-6: best=comp; best_sd=(copy.deepcopy(G.state_dict()),copy.deepcopy(D.state_dict())); wait=0
        else:
            wait+=1
            if wait>=PATIENCE: break
    if best_sd: G.load_state_dict(best_sd[0]); D.load_state_dict(best_sd[1])
    torch.save(G.state_dict(),os.path.join(out_dir,"generator.pt")); torch.save(D.state_dict(),os.path.join(out_dir,"critic.pt"))
    return G,D

# --------------- Prob bands ---------------
def copula_bands(G,X,C,K=K_MC):
    G.eval(); X=torch.tensor(X,dtype=torch.float32,device=DEVICE); C=torch.tensor(C,dtype=torch.long,device=DEVICE)
    with torch.no_grad(): mu,ls,_=G(X,torch.zeros(X.shape[0],X.shape[1],NOISE,device=DEVICE),C)
    mu=mu.cpu().numpy()[...,0]; sig=np.clip(np.exp(ls.cpu().numpy()[...,0]),1e-3,50.0)
    N,L=X.shape[0],X.shape[1]; S=[]
    for i in range(N):
        rho=0.5; R=np.fromfunction(lambda a,b: rho**np.abs(a-b),(L,L)); Sg=psd((sig[i][:,None]*sig[i][None,:])*R)
        z=np.random.multivariate_normal(np.zeros(L),Sg,size=K).astype(np.float32); S.append(mu[i][None,:]+z)
    S=np.stack(S,0); return np.percentile(S,10,1),np.percentile(S,50,1),np.percentile(S,90,1)

# --------------- Eval ---------------
def evaluate(G, Xseq, Yseq, Cseq, ysc, le, label="Evaluation"):
    G.eval()
    X = torch.tensor(Xseq, dtype=torch.float32, device=DEVICE)
    C = torch.tensor(Cseq, dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        mu, ls, qs = G(X, torch.zeros(X.shape[0], X.shape[1], NOISE, device=DEVICE), C)
        mp = mu.cpu().numpy()[..., 0]
    q10, _, q90 = copula_bands(G, Xseq, Cseq)
    yF=[]; mF=[]; lF=[]; hF=[]
    for i, cid in enumerate(Cseq):
        sc = ysc[int(cid)]
        y  = np.expm1(sc.inverse_transform(Yseq[i].reshape(-1,1)).reshape(-1))
        m  = np.clip(np.expm1(sc.inverse_transform(mp[i].reshape(-1,1)).reshape(-1)), 0, None)
        lo = np.clip(np.expm1(sc.inverse_transform(q10[i].reshape(-1,1)).reshape(-1)), 0, None)
        hi = np.clip(np.expm1(sc.inverse_transform(q90[i].reshape(-1,1)).reshape(-1)), 0, None)
        yF.append(y); mF.append(m); lF.append(lo); hF.append(hi)
    yF = np.concatenate(yF); mF = np.concatenate(mF); lF = np.concatenate(lF); hF = np.concatenate(hF)
    overall = {
        "SMAPE": smape(yF, mF),
        "MSE": mean_squared_error(yF, mF),
        "RMSE": math.sqrt(mean_squared_error(yF, mF)),
        "R2": r2_score(yF, mF),
        "Coverage90": float(np.mean((yF >= lF) & (yF <= hF))),
    }
    print(f"\n=== {label} ===")
    for k,v in overall.items(): print(f"{k}: {v:.4f}")
    return overall, (mF, lF, hF, yF)

# --------------- Main per-target ---------------
def run_for_target(df_raw, target_col, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    print(f"\n========== Target: {target_col} ==========")

    # 1) Conditional VAE on indicators
    print("1) Train Conditional VAE on indicators ...")
    vae = train_vae_on_indicators(df_raw, out_dir)

    # 2) Latent features
    print("2) Extract latent features ...")
    lat = vae_latent_df(vae, df_raw)

    # 3) Merge back, + keep indicators as base for lags (as requested: use historical indicator values)
    base = lat.merge(
        df_raw[["UpazilaID","year","month","ym"] + INDICATORS + [target_col]],
        on=["UpazilaID","year","month","ym"], how="left", validate="1:1"
    )

    # 4) Lags on indicators only (as per requirement)
    print("3) Build lag features (1,3,6) + rolling stats ...")
    pc_cols = [c for c in base.columns if c.startswith("pc_mean_") or c.startswith("pc_std_")]
    base, lag_cols = add_lags(base, INDICATORS, group_col="UpazilaID", lags=(1,3,6), rolls=(3,6))

    # 5) Feature set
    feats = pc_cols + ["month_sin","month_cos"] + lag_cols

    # 6) Time split per Upazila
    print("4) Split per Upazila: train / val (-12:-6) / test (-6:) ...")
    tr_parts, va_parts, te_parts = [], [], []
    for uid, g in base.groupby("UpazilaID"):
        g = g.sort_values(["year","month"]).copy()
        if len(g) > (VAL_H_MONTHS + TEST_H_MONTHS):
            te_parts.append(g.iloc[-TEST_H_MONTHS:])                                 # test
            va_parts.append(g.iloc[-(VAL_H_MONTHS + TEST_H_MONTHS):-TEST_H_MONTHS])  # val
            tr_parts.append(g.iloc[:-(VAL_H_MONTHS + TEST_H_MONTHS)])                # train
        elif len(g) > TEST_H_MONTHS:
            te_parts.append(g.iloc[-TEST_H_MONTHS:])
            tr_parts.append(g.iloc[:-TEST_H_MONTHS])
        else:
            tr_parts.append(g)
    tr_df = pd.concat(tr_parts).reset_index(drop=True)
    va_df = pd.concat(va_parts).reset_index(drop=True) if va_parts else tr_df.iloc[0:0].copy()
    te_df = pd.concat(te_parts).reset_index(drop=True) if te_parts else tr_df.iloc[0:0].copy()

    # 7) Fit scalers on TRAIN
    print("5) Fit scalers/encoders on TRAIN ...")
    split_tr, le_tr, fsc_tr, ysc_tr = build_mats(tr_df, feats, target_col, entity_col="UpazilaID")

    # Sequences
    X_tr, Y_tr, C_tr = to_seq(split_tr, SEQ, True)
    if X_tr.shape[0] == 0:
        raise RuntimeError("No train sequences of length SEQ were formed. Reduce SEQ or check continuity.")
    print(f"Train sequences: {X_tr.shape}")

    if len(va_df) > 0:
        split_va = apply_mats(va_df, feats, target_col, le_tr, fsc_tr, ysc_tr, "UpazilaID")
        X_va, Y_va, C_va = to_seq(split_va, SEQ, True)
        if X_va.shape[0] == 0:
            X_va, Y_va, C_va = X_tr, Y_tr, C_tr
    else:
        X_va, Y_va, C_va = X_tr, Y_tr, C_tr

    if len(te_df) > 0:
        split_te = apply_mats(te_df, feats, target_col, le_tr, fsc_tr, ysc_tr, "UpazilaID")
        X_te, Y_te, C_te = to_seq(split_te, SEQ, True)
        if X_te.shape[0] == 0:
            X_te, Y_te, C_te = X_va, Y_va, C_va
    else:
        X_te, Y_te, C_te = X_va, Y_va, C_va

    nc = int(C_tr.max()) + 1
    cond_dim = X_tr.shape[2]
    pc_idx = [i for i, f in enumerate(feats + ["prev_y"]) if f.startswith("pc_mean_")]

    # 8) Train GAN
    print("6) Train forecasting GAN ...")
    G, D = train_gan(X_tr, Y_tr, C_tr, X_va, Y_va, C_va, cond_dim, nc, pc_idx, out_dir)

    # 9) Evaluate
    print("7) Evaluate on HELD-OUT TEST ...")
    overall, (mf, lo, hi, y) = evaluate(G, X_te, Y_te, C_te, ysc_tr, le_tr, label="HELD-OUT TEST")

    # save artifacts
    np.savez(os.path.join(out_dir, "predictions_test.npz"), mf=mf, lo=lo, hi=hi, y=y, C_te=C_te)
    json.dump({
        "metrics_test": {k: float(v) for k,v in overall.items()},
        "config": {"seq_len": SEQ, "latent_dim": LATENT, "noise_dim": NOISE, "lstm_units": LSTM_UNITS},
        "notes": "Upazila-wise split; last 6M TEST, prev 6M VAL; indicators→CondVAE→PCs + lagged indicators + seasonality."
    }, open(os.path.join(out_dir,"manifest.json"),"w"), indent=2)

    # simple scatter plot
    fig, ax = plt.subplots(figsize=(6,6))
    ax.scatter(y, mf, alpha=0.6)
    lim = float(max(y.max(), mf.max())*1.05)
    ax.plot([0,lim],[0,lim],'r--')
    ax.set_title(f'Actual vs Predicted — {target_col}')
    ax.set_xlabel('Actual'); ax.set_ylabel('Pred')
    plt.tight_layout(); plt.savefig(os.path.join(out_dir,"scatter_test.png"), dpi=140); plt.close()
    print(f"[{target_col}] Artifacts saved to {out_dir}")

def main():
    print("Load malaria data ...")
    df = load_malaria(MALARIA_CSV)

    # ensure targets exist
    need = {"pv_rate":"PV/Population","pf_rate":"PF/Population","mixed_rate":"MIXED/Population"}
    for k,v in need.items():
        if not np.isfinite(df[k]).any():
            print(f"[WARN] target {k} ({v}) appears empty or non-finite in your data.")

    # run three targets
    targets = [("pv_rate","Target_PV_over_Pop"),
               ("pf_rate","Target_PF_over_Pop"),
               ("mixed_rate","Target_MIXED_over_Pop")]

    for tgt, name in targets:
        out_dir = os.path.join(OUT_ROOT, name); os.makedirs(out_dir, exist_ok=True)
        # drop rows where target is missing at train time (we'll still preserve timeline continuity via to_seq)
        df_t = df.copy()
        # We keep rows for feature continuity but training matrices will reflect where label exists
        run_for_target(df_t, tgt, out_dir)

if __name__=="__main__":
    main()

Load malaria data ...

1) Train Conditional VAE on indicators ...
2) Extract latent features ...
3) Build lag features (1,3,6) + rolling stats ...
4) Split per Upazila: train / val (-12:-6) / test (-6:) ...
5) Fit scalers/encoders on TRAIN ...
Train sequences: (9443, 12, 62)
6) Train forecasting GAN ...
7) Evaluate on HELD-OUT TEST ...

=== HELD-OUT TEST ===
SMAPE: 47.0294
MSE: 0.0000
RMSE: 0.0000
R2: 0.9863
Coverage90: 0.6609
[pv_rate] Artifacts saved to ./output/malaria_e2e/Target_PV_over_Pop

1) Train Conditional VAE on indicators ...
2) Extract latent features ...
3) Build lag features (1,3,6) + rolling stats ...
4) Split per Upazila: train / val (-12:-6) / test (-6:) ...
5) Fit scalers/encoders on TRAIN ...
Train sequences: (9443, 12, 62)
6) Train forecasting GAN ...
7) Evaluate on HELD-OUT TEST ...

=== HELD-OUT TEST ===
SMAPE: 59.2943
MSE: 0.0000
RMSE: 0.0001
R2: 0.9975
Coverage90: 0.7667
[pf_rate] Artifacts saved to ./output/malaria_e2e/Target_PF_over_Pop

1) Train Conditional 