In [None]:
# ============================================================
# CVAE-LSTM (heteroscedastic) + nHMM-like soft transitions +
# ARIMA & Persistence baselines + metrics + visualizations
#
# What you get:
#   - Robust checkpoint loader (adapts dims, adds log-variance head)
#   - Two evaluation modes:
#       (A) Tail-only (last TEST_DAYS)    -> matches your original setup
#       (B) Rolling / Regime-stratified   -> recommended for the paper
#   - Free-run reconstruction of full curves (shape metrics)
#   - Publication-ready plots
#   - Comparative summary table by commune
# ============================================================

import os, warnings, math, random
from pathlib import Path
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim

# sklearn compatibility (sparse_output -> sparse for older versions)
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder as _OHE
def OneHotEncoder_compat(**kwargs):
    """
    Build a dense-output OneHotEncoder that works across scikit-learn versions.

    Newer scikit-learn names the dense/sparse switch ``sparse_output``; older
    releases call it ``sparse``. The returned encoder always produces dense
    arrays, so any caller-supplied ``sparse`` / ``sparse_output`` is discarded
    up front. Previously such a key was passed twice, which raised a confusing
    "multiple values for keyword argument" TypeError.

    All other keyword arguments are forwarded unchanged to OneHotEncoder.
    """
    opts = {k: v for k, v in kwargs.items() if k not in ("sparse", "sparse_output")}
    try:
        # Newer sklearn uses 'sparse_output'
        return _OHE(sparse_output=False, **opts)
    except TypeError:
        # Older sklearn uses 'sparse'
        return _OHE(sparse=False, **opts)

from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
from statsmodels.tsa.arima.model import ARIMA

# NOTE(review): this silences ALL warnings process-wide, including sklearn
# convergence / feature-name warnings and statsmodels fit warnings; consider
# narrowing it if fits look suspicious.
warnings.filterwarnings("ignore")
np.set_printoptions(suppress=True, linewidth=160)

# -------------------------
# Paths & configuration
# -------------------------
# 'path/to' looks like a placeholder: point ROOT at the directory holding the CSV and checkpoint.
ROOT = 'path/to'
DATA_CSV = os.path.join(ROOT, "covid_data_with_6_states.csv")
MODEL_PTH = os.path.join(ROOT, "cvae_lstm_model_normalized.pth")
OUT_DIR   = os.path.join(ROOT, "eval_out")
# Created at import time so every later CSV/figure write can assume it exists.
os.makedirs(OUT_DIR, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 123
# Seed Python, NumPy and torch RNGs before anything random runs.
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

# Time series & conditioning
FEATURES = ["Gross_Daily_Cases_Mobile_Average_7_Days",
            "Internal_Mobility_Index", "External_Mobility_Index"]
# The first feature is the forecast target; the model's extra log-variance head applies only to it.
CASE_COL = FEATURES[0]
SEQ_LEN  = 7  # LSTM window length

# CVAE hyperparams. LATENT_DIM/HIDDEN_DIM/DEC_HIDDEN/EPOCHS only apply when
# training from scratch; LR, BETA_KL, ALPHA_MSE and BATCH_SIZE are also used by
# the post-load fine-tune.
LATENT_DIM = 6
HIDDEN_DIM = 64
DEC_HIDDEN = 64
BETA_KL    = 0.25     # lower beta -> freer latent space
ALPHA_MSE  = 0.20     # small weight on full-feature MSE
LR         = 1e-3
EPOCHS     = 50
BATCH_SIZE = 64

# Optional quick fine-tune after loading a checkpoint (to calibrate new logvar row)
FINETUNE_AFTER_LOAD_EPOCHS = 6

# Ensembles & horizons
N_ENSEMBLE = 200
H_LIST     = [7, 14]
TEST_DAYS  = 28              # tail used in "A) Tail-only" evaluation
# The two temperatures are applied only at forecast time (training calls the model with temp=1.0).
TEMP_Z     = 1.75            # latent temperature
TEMP_OUT   = 1.50            # output variance temperature (inflates σ)

# ARIMA search grid (compact & robust)
ARIMA_GRID = [(1,0,0),(1,1,0),(0,1,1),(1,1,1),(2,1,1)]

# Target communes for figures
TARGET_COMMUNES = ["La Florida", "Cerrillos", "Vitacura",
                   "Providencia", "Las Condes", "Santiago"]

# ---------------------------------
# Utilities: sequences
# ---------------------------------
def create_sequences(X, C, Y, seq_len=SEQ_LEN):
    """
    Build (X_seq, cond_at_t, Y_t) for one-step-ahead supervised learning.

    For every i in [0, len(X) - seq_len):
      - X_seq[i] = X[i : i+seq_len]   input window
      - C_seq[i] = C[i+seq_len]       condition at prediction time
      - Y_seq[i] = Y[i+seq_len]       next-step target (full feature vector)

    X, C and Y must be row-aligned and come from ONE contiguous series: windows
    are not split at series boundaries, so callers must window each series
    separately.

    When len(X) <= seq_len there are no windows. The empty results keep their
    trailing dimensions (X_seq has shape (0, seq_len, *X.shape[1:]), etc.), so
    they still concatenate cleanly with non-empty results. Previously they
    collapsed to shape (0,).
    """
    X = np.asarray(X); C = np.asarray(C); Y = np.asarray(Y)
    n_windows = max(0, len(X) - seq_len)
    if n_windows == 0:
        return (np.empty((0, seq_len) + X.shape[1:], dtype=X.dtype),
                np.empty((0,) + C.shape[1:], dtype=C.dtype),
                np.empty((0,) + Y.shape[1:], dtype=Y.dtype))
    Xs = np.array([X[i:i+seq_len] for i in range(n_windows)])
    Cs = np.array([C[i+seq_len] for i in range(n_windows)])    # condition at prediction time
    Ys = np.array([Y[i+seq_len] for i in range(n_windows)])    # next-step target
    return Xs, Cs, Ys

# ---------------------------------
# Scoring metrics
# ---------------------------------
def mae(y, yhat):
    """Mean absolute error between two array-likes, as a Python float."""
    residual = np.asarray(y) - np.asarray(yhat)
    return float(np.mean(np.abs(residual)))


def rmse(y, yhat):
    """Root-mean-squared error between two array-likes, as a Python float."""
    residual = np.asarray(y) - np.asarray(yhat)
    return float(np.sqrt(np.mean(residual**2)))

def crps_ensemble(y, samples):
    """
    CRPS of an ensemble forecast (Gneiting & Raftery energy form).

        CRPS = E|X - y| - 0.5 * E|X - X'|

    X and X' are independent draws from the empirical distribution of
    `samples`. E|X - X'| is computed from the sorted samples: the k-th gap
    between consecutive order statistics is spanned by k*(n-k) ordered pairs.
    """
    ordered = np.sort(np.asarray(samples).ravel())
    n = len(ordered)
    y = float(y)
    mean_abs_err = np.mean(np.abs(ordered - y))           # E|X - y|
    gaps = np.diff(ordered)
    k = np.arange(1, n)
    pair_counts = k * (n - k)
    mean_pair_dist = 2.0 * np.sum(pair_counts * gaps) / (n*n)  # E|X - X'|
    return float(mean_abs_err - 0.5*mean_pair_dist)

def interval_score(y, lo, hi, alpha):
    """
    Interval score of a central (1 - alpha) prediction interval [lo, hi].

    `alpha` is the MISCOVERAGE rate (e.g. 0.1 for a 90% interval):

        IS = (hi - lo) + (2/alpha)*(lo - y)*1{y < lo} + (2/alpha)*(y - hi)*1{y > hi}
    """
    y = float(y); lo = float(lo); hi = float(hi)
    width = hi - lo
    penalty = 0.0
    if y < lo: penalty = (2.0/alpha) * (lo - y)
    elif y > hi: penalty = (2.0/alpha) * (y - hi)
    return width + penalty


def wis_from_quantiles(y, interval_levels, qdict):
    """
    Weighted Interval Score (Bracher et al., 2021) from predictive quantiles.

    interval_levels are central COVERAGE levels, e.g. [0.5, 0.9]. Level L uses
    the quantile keys (1-L)/2 and (1+L)/2, rounded to 4 decimals to match
    `empirical_quantiles`. Its miscoverage rate is alpha = 1 - L, which is what
    the interval-score penalty and the weight depend on:

        WIS = [0.5*|y - median| + sum_k (alpha_k/2) * IS_{alpha_k}] / (K + 0.5)

    With no intervals the score reduces to the absolute error of the median.

    Raises:
        KeyError: qdict lacks the median (key 0.5) or a required interval quantile.
        ValueError: an interval level is not strictly between 0 and 1.
    """
    if 0.5 not in qdict:
        raise KeyError("qdict must contain the median under key 0.5")
    levels = list(interval_levels)
    total = 0.5 * abs(float(y) - float(qdict[0.5]))
    for level in levels:
        if not 0.0 < level < 1.0:
            raise ValueError(f"Interval level must be in (0, 1), got {level}")
        # Convert coverage level -> miscoverage rate. Passing the level itself
        # would give the 90% interval a 2/0.9 penalty instead of 2/0.1.
        alpha = 1.0 - level
        lo_key = round(alpha/2.0, 4)
        hi_key = round(1.0 - alpha/2.0, 4)
        if lo_key not in qdict or hi_key not in qdict:
            raise KeyError(f"Missing quantiles for level={level}: need keys {lo_key} and {hi_key}")
        total += (alpha/2.0) * interval_score(y, qdict[lo_key], qdict[hi_key], alpha)
    return float(total / (len(levels) + 0.5))

def empirical_quantiles(samples, qs=(0.05,0.25,0.5,0.75,0.95)):
    """
    Empirical quantiles of a sample set, keyed by the quantile level.

    Keys are round(q, 4), so lookups such as qdict[0.05] are robust to float
    noise. Values use numpy's default (linear) interpolation and are returned
    as Python floats.
    """
    flat = np.asarray(samples).ravel()
    result = {}
    for level in qs:
        result[round(level, 4)] = float(np.quantile(flat, level))
    return result

def coverage(y, lo, hi):
    """Return 1.0 when y lies in the closed interval [lo, hi], else 0.0."""
    value, lower, upper = float(y), float(lo), float(hi)
    return float(lower <= value <= upper)

# ============================
# Heteroscedastic CVAE-LSTM
# + robust checkpoint loader
# ============================
class CVAE_LSTM_HET(nn.Module):
    """
    Heteroscedastic CVAE with LSTM encoder.
      - Encoder: LSTM -> [mu_z, logvar_z]
      - Decoder: [z, cond] -> dense -> [mu_full (D), logvar_case (1)]

    The encoder reads only the input window x and the condition c (never the
    target), so z is drawn from the same (x, c)-dependent distribution during
    training and forecasting.

    The attribute names and the layout of `dec` (Linear at index 0, ReLU at 1,
    Linear at 2) define the state_dict keys that infer_ckpt_dims and
    build_and_load_robust read ('lstm_enc.weight_ih_l0', 'fc_mu.weight',
    'dec.0.weight', 'dec.2.weight'). Renaming or reordering them breaks
    checkpoint loading.

    Args:
        input_dim: number of features per time step (D).
        cond_dim: width of the conditioning vector.
        latent_dim: size of z.
        hidden_dim: LSTM hidden size.
        dec_hidden: width of the decoder's hidden layer.
    """
    def __init__(self, input_dim, cond_dim, latent_dim, hidden_dim, dec_hidden=64):
        super().__init__()
        self.input_dim  = input_dim
        self.cond_dim   = cond_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.dec_hidden = dec_hidden

        self.lstm_enc = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        # Both latent heads read [last LSTM hidden state, condition].
        self.fc_mu    = nn.Linear(hidden_dim + cond_dim, latent_dim)
        self.fc_logv  = nn.Linear(hidden_dim + cond_dim, latent_dim)

        self.dec = nn.Sequential(
            nn.Linear(latent_dim + cond_dim, dec_hidden),
            nn.ReLU(),
            nn.Linear(dec_hidden, input_dim + 1)  # +1 for log-variance of 'case'
        )

    def encode(self, x, c):
        """Return (mu, logvar) of the latent distribution given window x and condition c."""
        # x: (B, T, D), c: (B, C)
        _, (h, _) = self.lstm_enc(x)
        # h[-1] is the final hidden state of the last LSTM layer.
        hc = torch.cat([h[-1], c], dim=1)
        mu = self.fc_mu(hc)
        logv = self.fc_logv(hc)
        return mu, logv

    def reparam(self, mu, logv, temp=1.0):
        """Reparameterised draw z = mu + temp * sigma * eps; temp > 1 widens the latent spread."""
        std = torch.exp(0.5 * logv)
        eps = torch.randn_like(std)
        return mu + temp * std * eps

    def forward(self, x, c, temp=1.0):
        """
        Sample z and decode.

        Returns (mu_full, logvar_case, mu_z, logv_z):
          mu_full     (B, input_dim)  predicted means for all features
          logvar_case (B, 1)          log-variance for the first ('case') feature only
          mu_z, logv_z                latent parameters, used for the KL term
        """
        mu_z, logv_z = self.encode(x, c)
        z = self.reparam(mu_z, logv_z, temp=temp)
        out = self.dec(torch.cat([z, c], dim=1))
        mu_full     = out[:, :-1]  # means for all features
        logvar_case = out[:, -1:]  # log-variance for case only
        return mu_full, logvar_case, mu_z, logv_z

def nll_gaussian_case(y_case, mu_case, logvar_case, out_temp=1.0):
    """
    Element-wise Gaussian negative log-likelihood of the 'case' target.

    The constant 0.5*log(2*pi) is omitted. `out_temp` scales the predicted
    standard deviation (the variance is scaled by out_temp**2).
    """
    scaled_logvar = logvar_case + 2.0*math.log(out_temp)
    return 0.5 * (scaled_logvar + (y_case - mu_case)**2 / torch.exp(scaled_logvar))


def cvae_loss_hetero(mu_full, logvar_case, y_full, mu_z, logv_z,
                     alpha_mse=0.2, beta_kl=0.25, out_temp=1.0):
    """
    CVAE training objective: NLL(case) + alpha_mse * MSE(all features) + beta_kl * KL.

    The NLL covers only column 0 ('case'). The MSE covers every feature mean.
    The KL to N(0, I) is averaged over both batch and latent dimensions.
    """
    case_target = y_full[:, :1]
    case_mean = mu_full[:, :1]
    recon_nll = nll_gaussian_case(case_target, case_mean, logvar_case, out_temp=out_temp).mean()
    feature_mse = nn.functional.mse_loss(mu_full, y_full)
    kl_term = -0.5 * torch.mean(1 + logv_z - mu_z.pow(2) - torch.exp(logv_z))
    return recon_nll + alpha_mse*feature_mse + beta_kl*kl_term

def infer_ckpt_dims(state_dict, input_dim):
    """
    Recover CVAE_LSTM_HET architecture sizes from a saved state_dict.

    Returns:
        (hidden_dim, latent_dim, dec_hidden, cond_dim, has_logvar_row)
        where has_logvar_row is True when the last decoder layer already has
        input_dim + 1 output rows (means + case log-variance).

    Raises:
        ValueError: the last decoder layer has neither input_dim nor
            input_dim + 1 rows (e.g. FEATURES changed since training).
    """
    # LSTM input weights stack the 4 gates: shape (4 * hidden, input_dim).
    hidden = state_dict['lstm_enc.weight_ih_l0'].shape[0] // 4
    latent = state_dict['fc_mu.weight'].shape[0]

    # First decoder Linear: (dec_hidden, latent + cond).
    first_dec = state_dict['dec.0.weight']
    dec_hidden = first_dec.shape[0]
    cond = first_dec.shape[1] - latent

    n_out = state_dict['dec.2.weight'].shape[0]
    if n_out not in (input_dim, input_dim + 1):
        raise ValueError(
            f"Incompatible checkpoint: decoder rows={n_out}, "
            f"but expected {input_dim} (μ only) or {input_dim+1} (μ+logvar). "
            "Did FEATURES change?"
        )

    return hidden, latent, dec_hidden, cond, (n_out == input_dim + 1)


def build_and_load_robust(INPUT_DIM, COND_DIM, ckpt_path, device):
    """
    Rebuild a CVAE_LSTM_HET from a checkpoint, adapting it to the current setup.

    - Hidden, latent and decoder-hidden sizes are taken from the checkpoint so
      their tensors fit; the condition width is the current COND_DIM.
    - The decoder always outputs INPUT_DIM + 1 values (means of all features +
      case log-variance). A checkpoint that only has the INPUT_DIM mean rows is
      expanded: the mean rows are copied and the new log-variance row starts at
      zero (log-variance 0, i.e. unit variance in scaled units, hence the
      optional fine-tune afterwards).
    - All other tensors are copied only when shapes match exactly. If COND_DIM
      differs from the checkpoint, fc_mu, fc_logv and dec.0 (whose input
      width includes the condition) keep their random initialisation.

    Returns the model in eval mode on `device`.

    Raises:
        FileNotFoundError: ckpt_path does not exist.
        ValueError: from infer_ckpt_dims when the decoder output width is incompatible.
    """
    # An explicit raise (not `assert`) so the check survives `python -O`.
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    sd = torch.load(ckpt_path, map_location=device)

    H_old, Z_old, DEC_old, COND_old, has_logvar = infer_ckpt_dims(sd, INPUT_DIM)

    if COND_old != COND_DIM:
        print(f"⚠️  Current COND_DIM={COND_DIM} differs from checkpoint={COND_old}. "
              "I will load all shape-compatible tensors; the first decoder layer "
              "may remain randomly initialized.")

    print(f"→ Checkpoint dims detected: hidden={H_old}, latent={Z_old}, "
          f"dec_hidden={DEC_old}, cond_dim(ckpt)={COND_old}, "
          f"logvar_in_ckpt={has_logvar}")

    # Build model with ckpt dims (decoder output = INPUT_DIM + 1)
    model = CVAE_LSTM_HET(INPUT_DIM, COND_DIM, Z_old, H_old, dec_hidden=DEC_old).to(device)
    new_sd = model.state_dict()

    # Copy every tensor whose shape matches 1:1.
    for k, v in sd.items():
        if k in new_sd and new_sd[k].shape == v.shape:
            new_sd[k] = v.clone()

    # Last decoder layer. Its input width is DEC_old on both sides (cond_dim
    # only affects dec.0), so shapes differ only in the row count: the
    # checkpoint lacks the log-variance row.
    W_ck, b_ck = sd['dec.2.weight'], sd['dec.2.bias']
    W_new, b_new = new_sd['dec.2.weight'], new_sd['dec.2.bias']
    if W_ck.shape != W_new.shape:
        rows = min(W_ck.shape[0], W_new.shape[0])
        cols = min(W_ck.shape[1], W_new.shape[1])
        W_new = W_new.clone()
        b_new = b_new.clone()
        W_new[:rows, :cols] = W_ck[:rows, :cols]
        b_new[:rows] = b_ck[:rows]
        if not has_logvar:
            # Fresh log-variance row starts at zero.
            W_new[rows:, :] = 0.0
            b_new[rows:] = 0.0
        new_sd['dec.2.weight'] = W_new
        new_sd['dec.2.bias']   = b_new

    missing, unexpected = model.load_state_dict(new_sd, strict=False)
    print(f"✓ Checkpoint loaded with adaptation. Missing={list(missing)}, Unexpected={list(unexpected)}")
    model.eval()
    return model


# ---------------------------------
# Load data & encoders
# ---------------------------------
df = pd.read_csv(DATA_CSV, parse_dates=["Date"])
# Sort per commune, then by date, and reset the index: from here on the
# integer index equals the row position, which later code relies on
# (e.g. COND_FULL[train_df.index]).
df = df.sort_values(["Commune","Date"]).reset_index(drop=True)

# Keep the case column in ORIGINAL UNITS under a fixed name
OBS_COL = "Observed_Cases"
df[OBS_COL] = df[CASE_COL].values  # explicit copy

# Global MinMax scaler (fit on ALL features, including cases)
# NOTE(review): this fits on every row, including the last TEST_DAYS used as
# the tail test set, so test-period min/max leak into the scaled inputs.
# Refitting on the training slice would change the scaled space; confirm
# which scaling the saved checkpoint was trained with before changing it.
scaler = MinMaxScaler()
X_all = df[FEATURES].values
X_scaled = scaler.fit_transform(X_all)

# df_scaled: same as df but with FEATURES scaled; OBS_COL stays in original units
df_scaled = df.copy()
df_scaled[FEATURES] = X_scaled  # NOTE: CASE_COL is intentionally scaled here
# while OBS_COL stays in original units and is what we use for ground truth / ARIMA

# (optional) useful sanity check: original-unit cases should exceed 5 somewhere
assert df_scaled[OBS_COL].max() > 5.0, "Observed_Cases parece estar escalado; revisa el flujo."

# Hidden states → integers 0..K-1 (in sorted category order)
state_vals = sorted(df_scaled["Hidden_State"].dropna().unique().tolist())
state_to_id = {s:i for i,s in enumerate(state_vals)}
id_to_state = {i:s for s,i in state_to_id.items()}
K_STATES = len(state_vals)
print(f"Detected hidden-state categories (ordered): {state_vals}  -> K={K_STATES}")

# NOTE(review): the dropna() above suggests Hidden_State may contain NaN; any
# NaN maps to NaN here and astype(int) then raises.
df_scaled["state_id"] = df_scaled["Hidden_State"].map(state_to_id).astype(int)

# Commune encoder (one-hot)
enc_comm = OneHotEncoder_compat(handle_unknown="ignore")
comm_cond = enc_comm.fit_transform(df_scaled[["Commune"]])  # (N, C)

# State encoder (for hard-label training). Categories are ordered as
# state_vals, so one-hot column i corresponds to state_id i; the forecast
# rollout relies on this when it feeds soft state probabilities indexed by state_id.
enc_state = OneHotEncoder_compat(categories=[state_vals], handle_unknown="ignore")
state_oh = enc_state.fit_transform(df_scaled[["Hidden_State"]])

# Full conditioning = [comm_onehot, state_onehot]
COND_FULL = np.concatenate([comm_cond, state_oh], axis=1)
COND_DIM  = COND_FULL.shape[1]
INPUT_DIM = len(FEATURES)

# Keep the case min/max in case you want to inspect them
CASE_MIN, CASE_MAX = scaler.data_min_[0], scaler.data_max_[0]

# ---------------------------------
# Train/test split by time (tail-only mode)
# ---------------------------------
def time_split_by_commune(df_in, test_days=TEST_DAYS):
    """
    Flag the last `test_days` rows of every commune as test (1), the rest as train (0).

    Within each commune, "last" follows the current row order. The flags are
    returned as an int array in ascending index order of df_in.
    """
    flags = []
    for _, grp in df_in.groupby("Commune", sort=False):
        n_rows = len(grp)
        n_train = max(0, n_rows - test_days)
        values = np.r_[np.zeros(n_train, dtype=int), np.ones(n_rows - n_train, dtype=int)]
        flags.append(pd.Series(values, index=grp.index))
    return pd.concat(flags).sort_index().values

is_test = time_split_by_commune(df_scaled, TEST_DAYS)
df_scaled["is_test"] = is_test

# ---------------------------------
# Tensors for training CVAE (train slice only)
# ---------------------------------
train_df = df_scaled[df_scaled["is_test"]==0].copy()
X_train_raw = train_df[FEATURES].values
C_train_raw = COND_FULL[train_df.index]
Y_train_raw = train_df[FEATURES].values  # next-step target built in sequence

# Build windows commune by commune. train_df stacks every commune back to
# back, so windowing it as a single series created samples whose input window
# came from one commune while the condition/target came from the next.
_seq_parts = []
for _, _g in train_df.groupby("Commune", sort=False):
    if len(_g) <= SEQ_LEN:
        continue  # not enough history for a single window
    _gX = _g[FEATURES].values
    # df_scaled has a positional index, so its labels index COND_FULL rows directly.
    _seq_parts.append(create_sequences(_gX, COND_FULL[_g.index], _gX, seq_len=SEQ_LEN))
if not _seq_parts:
    raise ValueError(f"No commune has more than SEQ_LEN={SEQ_LEN} training rows; cannot build sequences.")
X_seq = np.concatenate([p[0] for p in _seq_parts], axis=0)
C_seq = np.concatenate([p[1] for p in _seq_parts], axis=0)
Y_seq = np.concatenate([p[2] for p in _seq_parts], axis=0)

Xt = torch.tensor(X_seq, dtype=torch.float32)
Ct = torch.tensor(C_seq, dtype=torch.float32)
Yt = torch.tensor(Y_seq, dtype=torch.float32)

train_ds = torch.utils.data.TensorDataset(Xt, Ct, Yt)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

# ---------------------------------
# Build model & load checkpoint (or train from scratch)
# ---------------------------------
# Fine-tuned weights go to OUT_DIR instead of overwriting MODEL_PTH. Overwriting
# the source checkpoint made every re-run fine-tune the previous run's output
# again: results drifted from run to run and the original weights were lost.
FINETUNED_PTH = os.path.join(OUT_DIR, "cvae_lstm_model_finetuned.pth")


def _train_epochs(net, loader, n_epochs, tag):
    """
    Train `net` in place with Adam (lr=LR) on cvae_loss_hetero for `n_epochs` epochs.

    After each epoch, prints "[<tag> NN] loss=..." with the per-sample mean loss
    over the loader's dataset. Leaves `net` in eval mode and returns it.
    """
    optimizer = optim.Adam(net.parameters(), lr=LR)
    n_items = len(loader.dataset)
    net.train()
    for ep in range(n_epochs):
        tot = 0.0
        for xb, cb, yb in loader:
            xb, cb, yb = xb.to(DEVICE), cb.to(DEVICE), yb.to(DEVICE)
            optimizer.zero_grad()
            mu_full, logvar_case, mu_z, logv_z = net(xb, cb, temp=1.0)
            loss = cvae_loss_hetero(mu_full, logvar_case, yb, mu_z, logv_z,
                                    alpha_mse=ALPHA_MSE, beta_kl=BETA_KL, out_temp=1.0)
            loss.backward()
            optimizer.step()
            tot += loss.item() * xb.size(0)
        print(f"[{tag} {ep+1:02d}] loss={tot/n_items:.4f}")
    net.eval()
    return net


# ---------------------------------
# Build model & load checkpoint (or train from scratch)
# ---------------------------------
if Path(MODEL_PTH).exists():
    model = build_and_load_robust(INPUT_DIM, COND_DIM, MODEL_PTH, DEVICE)
    # Optional: quick fine-tune to calibrate the new logvar row
    if FINETUNE_AFTER_LOAD_EPOCHS > 0:
        print(f"→ Quick fine-tune for {FINETUNE_AFTER_LOAD_EPOCHS} epochs to calibrate logvar row...")
        _train_epochs(model, train_loader, FINETUNE_AFTER_LOAD_EPOCHS, "FT")
        torch.save(model.state_dict(), FINETUNED_PTH)
        print(f"✓ Fine-tuned weights saved: {FINETUNED_PTH}")
    model.eval()
else:
    print("No checkpoint found. Training from scratch...")
    model = CVAE_LSTM_HET(INPUT_DIM, COND_DIM, LATENT_DIM, HIDDEN_DIM, dec_hidden=DEC_HIDDEN).to(DEVICE)
    _train_epochs(model, train_loader, EPOCHS, "Epoch")
    torch.save(model.state_dict(), MODEL_PTH)
    print(f"✓ Trained & saved: {MODEL_PTH}")
    model.eval()

# ---------------------------------
# Non-homogeneous transition model (multinomial LR)
# P(S_t | S_{t-1}, mobility, commune)
# ---------------------------------
def fit_transition_softmax(df_all):
    """
    Fit P(S_t | S_{t-1}, mobility_t, commune) as a multinomial logistic regression.

    Features per row:
      - one-hot(S_{t-1}), K_STATES columns ordered by state id;
      - the two mobility indices at day t, as stored in df_all (scaled);
      - the commune one-hot from enc_comm.
    Target: state_id at day t.

    Only training rows (is_test == 0) are used, to avoid leakage. The first
    training row of each commune has no previous state and is dropped.

    Returns the fitted LogisticRegression. Its predict_proba columns follow
    clf.classes_, which contains only the state ids seen as targets here.
    """
    use_df = df_all[df_all["is_test"]==0].copy()
    use_df["s_prev"] = use_df.groupby("Commune")["state_id"].shift(1)
    # Mobility enters at the same day t as the target state (no lag). The
    # "_prev" names are kept for continuity; the former groupby(...).shift(0)
    # was an identity operation.
    use_df["im_prev"] = use_df["Internal_Mobility_Index"]
    use_df["em_prev"] = use_df["External_Mobility_Index"]
    use_df = use_df.dropna(subset=["s_prev"]).copy()
    use_df["s_prev"] = use_df["s_prev"].astype(int)

    X_comm = enc_comm.transform(use_df[["Commune"]])
    X_mob  = use_df[["im_prev","em_prev"]].values
    enc_prev = OneHotEncoder_compat(categories=[list(range(K_STATES))])
    X_sp   = enc_prev.fit_transform(use_df[["s_prev"]])

    X = np.hstack([X_sp, X_mob, X_comm])
    y = use_df["state_id"].values.astype(int)

    try:
        clf = LogisticRegression(max_iter=300, multi_class="multinomial", solver="lbfgs")
    except TypeError:
        # `multi_class` is deprecated since scikit-learn 1.5; in releases without
        # it, lbfgs already fits a multinomial model for more than two classes.
        clf = LogisticRegression(max_iter=300, solver="lbfgs")
    clf.fit(X, y)
    return clf

# Fitted once at import time on the training slice of df_scaled.
trans_clf = fit_transition_softmax(df_scaled)

def transition_proba_soft(df_slice_row):
    """
    Return P(S_t | S_{t-1}, mobility, commune) as a vector of length K_STATES.

    df_slice_row must provide "state_id" (taken as S_{t-1}), both mobility
    indices and "Commune". The result is indexed by state id. States the
    transition model never saw as a target get (before renormalisation) the
    1e-8 floor instead of shifting the other probabilities out of position.
    """
    s_prev = int(df_slice_row["state_id"])
    im = float(df_slice_row["Internal_Mobility_Index"])
    em = float(df_slice_row["External_Mobility_Index"])
    comm_vec = enc_comm.transform([[df_slice_row["Commune"]]])
    sp = np.zeros((1,K_STATES)); sp[0, s_prev] = 1.0
    X = np.hstack([sp, np.array([[im,em]]), comm_vec])
    p_fit = trans_clf.predict_proba(X)[0]
    # predict_proba columns follow trans_clf.classes_, not 0..K-1; scatter them
    # into a full-length vector so index i always means state_id i.
    p = np.zeros(K_STATES)
    p[np.asarray(trans_clf.classes_, dtype=int)] = p_fit
    p = np.clip(p, 1e-8, 1.0)
    return p / p.sum()

# ---------------------------------
# Forecast helpers (ARIMA & Persistence)
# ---------------------------------
def persistence_forecast(y_hist, horizon, n_samples=200):
    """
    Persistence (naive / random-walk) forecast with bootstrapped sample paths.

    The point forecast repeats the last observation. Sample paths are built by
    resampling recent one-step changes (the last 29 diffs when more than 30
    points exist, otherwise all diffs) and accumulating them over the horizon:

        y_{T+h} = y_T + e*_1 + ... + e*_h

    As a result the predictive spread widens with h. Previously the same
    one-step noise was reused at every horizon, so long-horizon intervals were
    too narrow. With fewer than two diffs, zero noise is used.

    Returns (mean, samples) with shapes (horizon,) and (horizon, n_samples).
    """
    y_hist = np.asarray(y_hist, dtype=float)
    last = float(y_hist[-1])
    mean = np.full(horizon, last)
    resid = np.diff(y_hist[-30:]) if len(y_hist) > 30 else np.diff(y_hist)
    if resid.size < 2:
        resid = np.array([0.0, 0.0, 0.0])
    steps = np.zeros((horizon, n_samples))
    for h in range(horizon):
        steps[h, :] = np.random.choice(resid, size=n_samples, replace=True)
    samples = last + np.cumsum(steps, axis=0)
    return mean, samples

def arima_forecast(y_hist, horizon, n_samples=200, max_train=180):
    """
    ARIMA baseline: pick the lowest-AIC order from ARIMA_GRID on the last
    `max_train` points, then draw N(mean_h, var_h) samples per horizon step.

    Orders that fail to fit are skipped. If none fit, this falls back to
    persistence_forecast. Returns (mean, samples) with shapes (horizon,) and
    (horizon, n_samples).
    """
    window = np.asarray(y_hist[-max_train:], dtype=float)
    best_fit = None
    best_aic = np.inf
    for order in ARIMA_GRID:
        try:
            fitted = ARIMA(window, order=order).fit()
            # Strict '<' keeps the earliest order on ties (and ignores NaN AICs).
            if fitted.aic < best_aic:
                best_aic, best_fit = fitted.aic, fitted
        except Exception:
            continue

    if best_fit is None:
        return persistence_forecast(y_hist, horizon, n_samples)

    forecast = best_fit.get_forecast(steps=horizon)
    point = np.asarray(forecast.predicted_mean)
    spread = np.sqrt(np.maximum(np.asarray(forecast.var_pred_mean), 1e-8))
    draws = np.zeros((horizon, n_samples))
    for step in range(horizon):
        draws[step, :] = np.random.normal(point[step], spread[step], size=n_samples)
    return point, draws

# ---------------------------------
# CVAE rollout (soft states, autoregressive), mobility persistence
# Returns samples in ORIGINAL units (already inverse-transformed)
# ---------------------------------
@torch.no_grad()
def cvae_rollout(df_comm, origin_idx, horizon, n_samples=N_ENSEMBLE,
                 temp_z=TEMP_Z, temp_out=TEMP_OUT):
    """
    Autoregressive ensemble forecast of the case series from the CVAE.

    Args:
        df_comm: rows of a single commune with a positional index, holding the
            scaled FEATURES, "state_id" and "Commune".
        origin_idx: position in df_comm of the last observed day (forecast origin).
        horizon: number of steps ahead.
        n_samples: ensemble size.
        temp_z: latent temperature passed to the model.
        temp_out: multiplier on the predicted case standard deviation.

    Returns:
        (mean_path, samples) with shapes (horizon,) and (horizon, n_samples),
        in ORIGINAL case units; (None, None) when origin_idx < SEQ_LEN.

    Notes:
        - For t > origin, mobility persists at its last observed value as input
          to the transition model. The LSTM window instead carries the model's
          predicted mobility means forward.
        - Sampled cases are clipped to [0, 1] in scaled space before inverse scaling.
        - The regime path does not depend on the sampled cases: each step feeds
          the argmax of the current soft state probabilities back into the
          transition model. It is therefore computed once and shared by all
          members, and the whole ensemble goes through the network as one batch
          per step. Previously there was one forward pass and one sklearn call
          per member per step. Random draws now happen in a different order, so
          individual samples differ for a fixed seed; their distribution does not.
        - NOTE(review): the first step is conditioned on the one-hot state AT the
          origin, whereas training pairs each target with the state on the target
          day (create_sequences uses C[i+seq_len]). Confirm this lag is intended.
    """
    if origin_idx < SEQ_LEN:
        return None, None

    # Initial LSTM window (scaled)
    Xwin = df_comm[FEATURES].values[origin_idx-SEQ_LEN+1:origin_idx+1].astype(float)
    origin_row = df_comm.iloc[origin_idx]
    im_last = float(origin_row["Internal_Mobility_Index"])
    em_last = float(origin_row["External_Mobility_Index"])
    comm_name = df_comm.iloc[0]["Commune"]
    comm_onehot = enc_comm.transform([[comm_name]])[0]

    # Soft state probabilities per step: hard state at the origin, then the
    # transition model under mobility persistence.
    s_probs = np.zeros(K_STATES)
    s_probs[int(origin_row["state_id"])] = 1.0
    state_path = []
    for h in range(horizon):
        state_path.append(s_probs)
        if h + 1 < horizon:
            pseudo = {
                "state_id": np.argmax(s_probs),
                "Internal_Mobility_Index": im_last, "External_Mobility_Index": em_last,
                "Commune": comm_name
            }
            s_probs = transition_proba_soft(pd.Series(pseudo))

    samples = np.zeros((horizon, n_samples), dtype=float)
    x_buf = np.repeat(Xwin[np.newaxis, :, :], n_samples, axis=0)  # (N, T, D), scaled

    for h, state_soft in enumerate(state_path):
        cond_vec = np.concatenate([comm_onehot, state_soft])
        cond_batch = np.tile(cond_vec, (n_samples, 1))

        xb = torch.tensor(x_buf, dtype=torch.float32, device=DEVICE)
        cb = torch.tensor(cond_batch, dtype=torch.float32, device=DEVICE)

        mu_full, logvar_case, _, _ = model(xb, cb, temp=temp_z)
        mu_full = mu_full.cpu().numpy().astype(float)               # (N, D)
        logvar_case = logvar_case.cpu().numpy()[:, 0].astype(float)  # (N,)
        sigma_case = np.sqrt(np.maximum(1e-8, np.exp(logvar_case))) * temp_out

        # Sample next 'case' in SCALED space and clip to [0,1] for safety
        y_case_scaled = np.clip(np.random.normal(mu_full[:, 0], sigma_case), 0.0, 1.0)

        # Next-step scaled vectors: sampled case + model means for mobility dims
        next_vec = mu_full.copy()
        next_vec[:, 0] = y_case_scaled

        # Inverse-scale to ORIGINAL units for the case variable
        samples[h, :] = scaler.inverse_transform(next_vec)[:, 0]

        # Autoregressive buffer update (keep scaled buffer)
        x_buf = np.concatenate([x_buf[:, 1:, :], next_vec[:, np.newaxis, :]], axis=1)

    mean_path = np.mean(samples, axis=1)
    return mean_path, samples

# ============================================================
# A) Tail-only Evaluation (as in your original setup)
# ============================================================
# Models scored by evaluate_tail, in output-column order.
_TAIL_MODELS = ("CVAE", "ARIMA", "PERSIST")
# (output column prefix, key returned by _score_ensemble_step), in output-column order.
_TAIL_METRICS = (("MAE", "err"), ("CRPS", "crps"), ("WIS", "wis"),
                 ("COV50", "cov50"), ("COV90", "cov90"))


def _score_ensemble_step(y_h, step_samples):
    """Score one observation against one ensemble: |y - mean|, CRPS, WIS(50/90%), 50/90% coverage."""
    q = empirical_quantiles(step_samples, qs=(0.05,0.25,0.5,0.75,0.95))
    return {
        "err": abs(y_h - np.mean(step_samples)),
        "crps": crps_ensemble(y_h, step_samples),
        "wis": wis_from_quantiles(y_h, [0.5,0.9], q),
        "cov50": coverage(y_h, q[0.25], q[0.75]),
        "cov90": coverage(y_h, q[0.05], q[0.95]),
    }


def evaluate_tail(horizon_list=H_LIST, test_days=TEST_DAYS):
    """
    Evaluate CVAE, ARIMA and persistence from every test-period origin of every commune.

    Origins are the rows with df_scaled["is_test"] == 1 that have SEQ_LEN days
    of history and H days of future truth. `test_days` is accepted for
    signature compatibility but not used: the test tail is whatever the
    is_test column already marks.

    For each (commune, horizon), per-step scores are averaged into one row with
    MAE/CRPS/WIS/COV50/COV90 per model. The rows are written to OUT_DIR
    (per commune, pooled by horizon, pooled overall) and printed.

    Returns (res_df, agg, overall, viz_cache). viz_cache holds per-origin
    paths and 5-95% bands for TARGET_COMMUNES at H == 14. When no origin can be
    scored, the frames are empty but still carry the expected columns instead
    of raising a KeyError.
    """
    columns = ["Commune", "horizon"] + [
        f"{metric}_{name}" for name in _TAIL_MODELS for metric, _ in _TAIL_METRICS
    ]
    rows = []
    viz_cache = {}

    for comm, dfc in df_scaled.groupby("Commune"):
        dfc = dfc.sort_values("Date").reset_index(drop=True)
        y_true_full = dfc[OBS_COL].values  # original units

        idxs = np.where(dfc["is_test"].values == 1)[0]
        if len(idxs) == 0:
            continue

        for H in horizon_list:
            scores = {name: {key: [] for _, key in _TAIL_METRICS} for name in _TAIL_MODELS}
            track_viz = (comm in TARGET_COMMUNES) and (H == 14)
            if track_viz:
                viz_cache[(comm,H)] = {"dates": [], "y": [], "cvae_mean": [], "cvae_p05": [], "cvae_p95": [],
                                       "ar_mean": [], "ar_p05": [], "ar_p95": []}

            for origin_idx in idxs:
                # Need a full LSTM window and H days of truth after the origin.
                if origin_idx < SEQ_LEN or origin_idx + H >= len(dfc):
                    continue

                mean_c, samp_c = cvae_rollout(dfc, origin_idx, H, n_samples=N_ENSEMBLE,
                                              temp_z=TEMP_Z, temp_out=TEMP_OUT)
                if mean_c is None:
                    continue

                y_hist = y_true_full[:origin_idx+1]
                mean_a, samp_a = arima_forecast(y_hist, H, n_samples=N_ENSEMBLE)
                _, samp_p = persistence_forecast(y_hist, H, n_samples=N_ENSEMBLE)

                truth = y_true_full[origin_idx+1:origin_idx+1+H]
                # All three ensembles are in original units.
                ensembles = {"CVAE": samp_c, "ARIMA": samp_a, "PERSIST": samp_p}

                for h in range(H):
                    for name, samp in ensembles.items():
                        for key, val in _score_ensemble_step(truth[h], samp[h, :]).items():
                            scores[name][key].append(val)

                if track_viz:
                    cache = viz_cache[(comm,H)]
                    cache["dates"].append(dfc.loc[origin_idx+1:origin_idx+H, "Date"].tolist())
                    cache["y"].append(truth.tolist())
                    cache["cvae_mean"].append(mean_c.tolist())
                    cache["cvae_p05"].append([np.quantile(samp_c[h,:],0.05) for h in range(H)])
                    cache["cvae_p95"].append([np.quantile(samp_c[h,:],0.95) for h in range(H)])
                    cache["ar_mean"].append(mean_a.tolist())
                    cache["ar_p05"].append([np.quantile(samp_a[h,:],0.05) for h in range(H)])
                    cache["ar_p95"].append([np.quantile(samp_a[h,:],0.95) for h in range(H)])

            if not scores["CVAE"]["err"]:
                continue

            row = {"Commune": comm, "horizon": H}
            for name in _TAIL_MODELS:
                for metric, key in _TAIL_METRICS:
                    row[f"{metric}_{name}"] = float(np.mean(scores[name][key]))
            rows.append(row)

    # Explicit columns keep the groupby below valid even when nothing was scored.
    res_df = pd.DataFrame(rows, columns=columns)
    res_df.to_csv(os.path.join(OUT_DIR, "A_tail_metrics_by_commune.csv"), index=False)
    agg = res_df.groupby("horizon").mean(numeric_only=True).reset_index()
    agg.to_csv(os.path.join(OUT_DIR, "A_tail_metrics_pooled_by_horizon.csv"), index=False)
    overall = res_df.mean(numeric_only=True).to_frame().T
    overall.to_csv(os.path.join(OUT_DIR, "A_tail_metrics_pooled_overall.csv"), index=False)

    print("=== [A] Tail-only: pooled by horizon ===")
    print(agg.round(4))
    print("\n=== [A] Tail-only: pooled overall ===")
    print(overall.round(4))

    return res_df, agg, overall, viz_cache

# ============================================================
# B) Rolling / Regime-stratified Evaluation (recommended)
#    (no retraining; just different origins)
# ============================================================
def pick_origins_regime_stratified(dfc, per_regime=4, min_gap=14, max_h=max(H_LIST)):
    """Pick ~per_regime origins per regime, separated by >= min_gap, with room for max_h."""
    idxs = []
    n = len(dfc)
    for s in sorted(dfc["state_id"].unique().tolist()):
        cand = np.where(dfc["state_id"].values == s)[0]
        cand = cand[(cand >= SEQ_LEN) & (cand <= n - max_h - 1)]
        chosen = []
        last = -10**9
        for i in cand:
            if i - last >= min_gap:
                chosen.append(i); last = i
            if len(chosen) >= per_regime:
                break
        idxs.extend(chosen)
    return sorted(set(idxs))

def pick_origins_rolling_every_k(dfc, step=7, max_h=max(H_LIST)):
    """Pick origins every 'step' days along the series."""
    n = len(dfc)
    valid = list(range(SEQ_LEN, n - max_h))
    return valid[::step]

def pick_origins_incidence_quantiles(dfc, per_bucket=4, min_gap=14, max_h=max(H_LIST)):
    """Greedy choice of origins stratified by incidence level (low / mid / high).

    The buckets are cut at the 33% and 66% quantiles of dfc[CASE_COL]. In
    df_scaled this column is min-max scaled, and because the transform is
    monotone the buckets match those on raw counts. Within each bucket
    (low, mid, high in that order), rows are scanned in positional order. A row
    is kept when it is at least `min_gap` rows after the previously kept row of
    the same bucket, and scanning stops after `per_bucket` rows. Eligible
    positions are [SEQ_LEN, len(dfc) - max_h - 1].

    Returns:
        Sorted list of unique positional indices.
    """
    incidence = dfc[CASE_COL].values
    cut_low, cut_high = np.quantile(incidence, [0.33, 0.66])
    bucket_masks = (
        incidence <= cut_low,                             # low
        (incidence > cut_low) & (incidence <= cut_high),  # mid
        incidence > cut_high,                             # high
    )
    upper = len(dfc) - max_h - 1

    chosen_all = []
    for mask in bucket_masks:
        pool = np.flatnonzero(mask)
        pool = pool[(pool >= SEQ_LEN) & (pool <= upper)]
        taken, last_taken = [], -10**9
        for idx in pool:
            if idx - last_taken >= min_gap:
                taken.append(idx)
                last_taken = idx
            if len(taken) >= per_bucket:
                break
        chosen_all.extend(taken)
    return sorted(set(chosen_all))

def _selector_score_step(y_obs, samples, acc):
    """Append one forecast step's MAE/CRPS/WIS/COV50/COV90 to the lists in `acc`.

    y_obs is the observed value and `samples` the ensemble for that step, both in
    ORIGINAL units; `acc` maps metric name -> list. The quantiles are the empirical
    5/25/50/75/95% levels. WIS is computed by wis_from_quantiles at levels
    [0.5, 0.9] (presumably the 50%/90% central intervals). MAE is measured against
    the ensemble mean.
    """
    q = empirical_quantiles(samples, qs=(0.05, 0.25, 0.5, 0.75, 0.95))
    acc["WIS"].append(wis_from_quantiles(y_obs, [0.5, 0.9], q))
    acc["CRPS"].append(crps_ensemble(y_obs, samples))
    acc["COV50"].append(coverage(y_obs, q[0.25], q[0.75]))
    acc["COV90"].append(coverage(y_obs, q[0.05], q[0.95]))
    acc["MAE"].append(abs(y_obs - np.mean(samples)))


def evaluate_with_selector(horizon_list=H_LIST, origin_selector="regime", **selector_kwargs):
    """
    Backtest CVAE, ARIMA and persistence from selected origins (no retraining).

    For each commune in df_scaled, origins are chosen by one of the
    pick_origins_* selectors. For every horizon H and origin, each model forecasts
    H steps and every step is scored in ORIGINAL units (OBS_COL). Scores are then
    averaged over all (origin, step) pairs of a commune/horizon.

    Args:
        horizon_list: forecast horizons to evaluate.
        origin_selector: 'regime' | 'rolling' | 'quantiles'.
        **selector_kwargs: forwarded to the chosen selector.

    Returns:
        (res_df, agg, overall, viz_cache):
        - res_df: per-commune metrics.
        - agg: metrics pooled by horizon.
        - overall: metrics pooled overall.
        - viz_cache: per-origin paths for TARGET_COMMUNES at H == 14. Only
          communes with at least one scored origin get an entry.
        The three tables are also written as CSV files to OUT_DIR. When nothing
        can be scored, the tables are empty instead of raising.

    Raises:
        ValueError: if origin_selector is unknown.
    """
    selectors = {
        "regime": pick_origins_regime_stratified,
        "rolling": pick_origins_rolling_every_k,
        "quantiles": pick_origins_incidence_quantiles,
    }
    if origin_selector not in selectors:
        raise ValueError("Unknown origin_selector")
    select_origins = selectors[origin_selector]

    models = ("CVAE", "ARIMA", "PERSIST")
    metrics = ("MAE", "CRPS", "WIS", "COV50", "COV90")
    # Column order matches the original per-row dict: all CVAE metrics, then ARIMA, then PERSIST.
    result_cols = ["Commune", "horizon"] + [f"{m}_{t}" for t in models for m in metrics]
    viz_keys = ("dates", "y", "cvae_mean", "cvae_p05", "cvae_p95", "ar_mean", "ar_p05", "ar_p95")

    rows = []
    viz_cache = {}

    for comm, dfc in df_scaled.groupby("Commune"):
        dfc = dfc.sort_values("Date").reset_index(drop=True)
        y_true_full = dfc[OBS_COL].values  # original units

        origins = select_origins(dfc, **selector_kwargs)
        if len(origins) == 0:
            continue

        for H in horizon_list:
            scores = {t: {m: [] for m in metrics} for t in models}

            for origin_idx in origins:
                if origin_idx < SEQ_LEN or origin_idx + H >= len(dfc):
                    continue

                mean_c, samp_c = cvae_rollout(dfc, origin_idx, H, n_samples=N_ENSEMBLE,
                                              temp_z=TEMP_Z, temp_out=TEMP_OUT)
                if mean_c is None:
                    continue

                y_hist = y_true_full[:origin_idx + 1]
                mean_a, samp_a = arima_forecast(y_hist, H, n_samples=N_ENSEMBLE)
                _, samp_p = persistence_forecast(y_hist, H, n_samples=N_ENSEMBLE)
                truth = y_true_full[origin_idx + 1:origin_idx + 1 + H]

                # Step-wise metrics, all in ORIGINAL units.
                for h in range(H):
                    for tag, samp in (("CVAE", samp_c), ("ARIMA", samp_a), ("PERSIST", samp_p)):
                        _selector_score_step(truth[h], samp[h, :], scores[tag])

                # Cache paths for figures. The cache entry is created lazily, so
                # communes without scored origins have no (empty) entry.
                if comm in TARGET_COMMUNES and H == 14:
                    block = viz_cache.setdefault((comm, H), {k: [] for k in viz_keys})
                    block["dates"].append(dfc.loc[origin_idx + 1:origin_idx + H, "Date"].tolist())
                    block["y"].append(truth.tolist())
                    block["cvae_mean"].append(mean_c.tolist())
                    block["cvae_p05"].append(np.quantile(samp_c, 0.05, axis=1).tolist())
                    block["cvae_p95"].append(np.quantile(samp_c, 0.95, axis=1).tolist())
                    block["ar_mean"].append(np.asarray(mean_a).tolist())
                    block["ar_p05"].append(np.quantile(samp_a, 0.05, axis=1).tolist())
                    block["ar_p95"].append(np.quantile(samp_a, 0.95, axis=1).tolist())

            if not scores["CVAE"]["MAE"]:
                continue

            row = {"Commune": comm, "horizon": H}
            for tag in models:
                for m in metrics:
                    row[f"{m}_{tag}"] = float(np.mean(scores[tag][m]))
            rows.append(row)

    res_df = pd.DataFrame(rows, columns=result_cols)
    res_df.to_csv(os.path.join(OUT_DIR, "B_selector_metrics_by_commune.csv"), index=False)
    if res_df.empty:
        # groupby("horizon") on a frame without rows/columns would raise; keep the outputs well-formed.
        print("No (commune, horizon) pair could be scored with the selected origins.")
        agg = pd.DataFrame(columns=result_cols[1:])
        overall = pd.DataFrame(columns=result_cols[1:])
    else:
        agg = res_df.groupby("horizon").mean(numeric_only=True).reset_index()
        overall = res_df.mean(numeric_only=True).to_frame().T
    agg.to_csv(os.path.join(OUT_DIR, "B_selector_metrics_pooled_by_horizon.csv"), index=False)
    overall.to_csv(os.path.join(OUT_DIR, "B_selector_metrics_pooled_overall.csv"), index=False)

    print("=== [B] Selector-based: pooled by horizon ===")
    print(agg.round(4))
    print("\n=== [B] Selector-based: pooled overall ===")
    print(overall.round(4))

    return res_df, agg, overall, viz_cache

# ============================================================
# Free-run reconstruction of full curve + shape metrics
# ============================================================
@torch.no_grad()
def cvae_free_run(dfc, start_idx=None):
    """
    One-step-ahead CVAE forecasts from every origin t in [start_idx, len(dfc) - 2].

    start_idx defaults to SEQ_LEN - 1. Despite the name, each step calls
    cvae_rollout with horizon=1, and cvae_rollout reads its input window from the
    OBSERVED rows of dfc. This is therefore a rolling one-step-ahead forecast, not
    a recursion on the model's own outputs.

    Returns:
        Float array of length max(0, len(dfc) - 1 - start_idx). Entry i is the
        ensemble-mean prediction for row start_idx + 1 + i in ORIGINAL units, or
        NaN where cvae_rollout returns None for that origin.
    """
    first_origin = SEQ_LEN - 1 if start_idx is None else start_idx

    def _predict_next(t):
        mean_path, _ = cvae_rollout(dfc, t, horizon=1, n_samples=N_ENSEMBLE,
                                    temp_z=TEMP_Z, temp_out=TEMP_OUT)
        return np.nan if mean_path is None else float(mean_path[0])

    return np.array([_predict_next(t) for t in range(first_origin, len(dfc) - 1)], dtype=float)

def arima_free_run(y, start_idx=None):
    """
    One-step-ahead ARIMA forecasts from every origin t in [start_idx, len(y) - 2].

    At each origin, arima_forecast is called on y[:t+1] (refitting its model grid)
    and only the 1-step mean is kept. start_idx defaults to SEQ_LEN - 1.

    Returns:
        Float array of length max(0, len(y) - 1 - start_idx).
    """
    first_origin = SEQ_LEN - 1 if start_idx is None else start_idx
    return np.array(
        [float(arima_forecast(y[:t + 1], horizon=1, n_samples=64)[0][0])
         for t in range(first_origin, len(y) - 1)],
        dtype=float,
    )

def peak_metrics(y_true, y_pred):
    """Compare the peaks of an observed and a predicted series.

    NaNs are ignored when locating each peak (np.nanargmax).

    Returns:
        dict with
        - "peak_timing_error_days": predicted-peak index minus observed-peak
          index (positive means the predicted peak comes later);
        - "peak_mag_error": |predicted peak value - observed peak value|.
    """
    true_peak_at = int(np.nanargmax(y_true))
    pred_peak_at = int(np.nanargmax(y_pred))
    magnitude_gap = abs(y_pred[pred_peak_at] - y_true[true_peak_at])
    return {
        "peak_timing_error_days": pred_peak_at - true_peak_at,
        "peak_mag_error": float(magnitude_gap),
    }

def series_shape_metrics(y_true, y_pred):
    """Correlation, MAE, RMSE and MASE (scaled by the naive-1 MAE of y_true).

    Both series are first truncated to their common length. Pairs where either
    value is non-finite (e.g. a NaN from an origin the model rejected) are
    dropped, so a single missing prediction no longer turns every metric into
    NaN. The MASE denominator is the mean absolute day-over-day change of the
    (truncated) observed series, ignoring non-finite differences.

    Returns:
        dict with keys "corr", "MAE", "RMSE", "MASE". Every value is NaN when no
        finite pair remains. "corr" is NaN with fewer than two pairs or a
        constant series, and "MASE" is NaN when the denominator is 0 or undefined.
    """
    y_true = np.asarray(y_true, float)
    y_pred = np.asarray(y_pred, float)
    n = min(len(y_true), len(y_pred))
    yt, yp = y_true[:n], y_pred[:n]

    # Naive-1 scale from the observed series (computed before masking predictions).
    naive = np.abs(np.diff(yt))
    naive = naive[np.isfinite(naive)]
    mase_den = float(np.mean(naive)) if naive.size else np.nan

    ok = np.isfinite(yt) & np.isfinite(yp)
    yt, yp = yt[ok], yp[ok]
    if yt.size == 0:
        return {"corr": np.nan, "MAE": np.nan, "RMSE": np.nan, "MASE": np.nan}

    err = yt - yp
    mae_v = float(np.mean(np.abs(err)))
    rmse_v = float(np.sqrt(np.mean(err ** 2)))
    if yt.size > 1 and np.std(yt) > 0 and np.std(yp) > 0:
        corr = float(np.corrcoef(yt, yp)[0, 1])
    else:
        corr = np.nan
    mase = mae_v / mase_den if mase_den > 0 else np.nan  # NaN > 0 is False
    return {"corr": corr, "MAE": mae_v, "RMSE": rmse_v, "MASE": float(mase)}

# ============================================================
# Visualization helpers
# ============================================================
def plot_panels(viz_cache, H=14, n_panels=6):
    """
    Multi-commune panels: Observed vs CVAE vs ARIMA (90% bands) at horizon H.

    One panel is drawn for each commune in TARGET_COMMUNES (at most n_panels)
    that has at least one cached origin in viz_cache[(commune, H)]. Each panel
    shows only the LAST cached origin, which reads more clearly in publication.
    Grid cells left unused by an odd panel count are hidden. The figure is saved
    to OUT_DIR/viz_obs_vs_cvae_arima_H{H}.png.
    """
    # Skip entries that exist but hold no origins: block["dates"][-1] would raise IndexError.
    communes = [
        c for c in TARGET_COMMUNES
        if (c, H) in viz_cache and len(viz_cache[(c, H)]["dates"]) > 0
    ][:max(n_panels, 0)]
    n = len(communes)
    if n == 0:
        print("No communes cached for visualization.")
        return

    nrows = math.ceil(n / 2); ncols = 2
    fig, axes = plt.subplots(nrows, ncols, figsize=(16, 5.5 * nrows), sharex=False)
    axes = np.array(axes).reshape(-1)

    for ax, comm in zip(axes, communes):
        block = viz_cache[(comm, H)]
        dates  = block["dates"][-1]
        y      = block["y"][-1]
        cm     = block["cvae_mean"][-1]; c05 = block["cvae_p05"][-1]; c95 = block["cvae_p95"][-1]
        am     = block["ar_mean"][-1];   a05 = block["ar_p05"][-1];   a95 = block["ar_p95"][-1]

        ax.plot(dates, y, lw=2.0, color="black", label="Observed")
        ax.fill_between(dates, c05, c95, alpha=0.20, label="CVAE 90%", edgecolor='none')
        ax.plot(dates, cm, lw=2.0, linestyle="--", label="CVAE mean")
        ax.fill_between(dates, a05, a95, alpha=0.15, label="ARIMA 90%", edgecolor='none')
        ax.plot(dates, am, lw=2.0, linestyle="-.", label="ARIMA mean")

        ax.set_title(f"{comm} — {H}-day ahead", fontsize=13)
        ax.set_ylabel("Daily cases (7-day MA)")
        ax.grid(True, alpha=0.3)

    # An odd number of panels leaves an empty grid cell; hide it instead of drawing blank axes.
    for ax in axes[n:]:
        ax.set_visible(False)

    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower center", ncol=4, frameon=False)
    fig.autofmt_xdate()
    fig.tight_layout(rect=[0, 0.05, 1, 1])
    path = os.path.join(OUT_DIR, f"viz_obs_vs_cvae_arima_H{H}.png")
    fig.savefig(path, dpi=300)
    plt.show()
    print(f"✓ Saved: {path}")

def plot_last_origin(commune, H=14, viz_cache=None, outdir=OUT_DIR):
    """Plot and save the most recent cached H-step forecast for one commune.

    The figure shows the observed values, the CVAE and ARIMA means, and their
    5-95% bands for the last origin in viz_cache[(commune, H)]. It is written to
    <outdir>/last_origin_<commune, spaces -> underscores>_H<H>.png. If nothing
    is cached for (commune, H), a hint is printed and the function returns None.
    """
    block = None if viz_cache is None else viz_cache.get((commune, H))
    if block is None or len(block["dates"]) == 0:
        print(f"No cache for {commune} (H={H}). Run evaluation first and ensure commune is in TARGET_COMMUNES.")
        return

    dates, y, cm, c05, c95, am, a05, a95 = (
        block[key][-1]
        for key in ("dates", "y", "cvae_mean", "cvae_p05", "cvae_p95", "ar_mean", "ar_p05", "ar_p95")
    )

    fig, ax = plt.subplots(figsize=(11, 5))
    ax.plot(dates, y, color="black", lw=2.0, label="Observed")
    ax.fill_between(dates, c05, c95, alpha=0.20, label="CVAE 90%", edgecolor='none')
    ax.plot(dates, cm, lw=2.0, linestyle="--", label="CVAE mean")
    ax.fill_between(dates, a05, a95, alpha=0.15, label="ARIMA 90%", edgecolor='none')
    ax.plot(dates, am, lw=2.0, linestyle="-.", label="ARIMA mean")
    ax.set_title(f"{commune} — {H}-day ahead (last origin)")
    ax.set_ylabel("Daily cases (7-day MA)")
    ax.grid(True, alpha=0.3)
    ax.legend(frameon=False, ncol=4, loc="upper left")
    fig.tight_layout()

    png_path = os.path.join(outdir, f"last_origin_{commune.replace(' ','_')}_H{H}.png")
    fig.savefig(png_path, dpi=300)
    plt.show()
    print(f"✓ Saved: {png_path}")

# ============================================================
# RUN — choose one or both evaluation modes
# ============================================================

# --- (A) Tail-only (as in your current paper draft) ---
res_tail, agg_tail, overall_tail, viz_tail = evaluate_tail(H_LIST, TEST_DAYS)
print(f"✓ Tail-only CSVs saved in: {OUT_DIR}")

# Visualizations for (A).
# The evaluators in this file cache forecast paths only for H == 14, so no H=7
# panel is drawn: plot_panels(viz_tail, H=7) would only print
# "No communes cached for visualization.".
plot_panels(viz_tail, H=14, n_panels=6)
for c in TARGET_COMMUNES:
    plot_last_origin(c, H=14, viz_cache=viz_tail)

# --- (B) Selector-based backtesting (recommended for paper claims) ---
# Regime-stratified: up to 4 origins per regime, >=14d spacing
res_sel, agg_sel, overall_sel, viz_sel = evaluate_with_selector(
    horizon_list=[7,14],
    origin_selector="regime",
    per_regime=4, min_gap=14
)
print(f"✓ Selector-based CSVs saved in: {OUT_DIR}")

# Visualizations for (B) (H=14 is the only cached horizon)
plot_panels(viz_sel, H=14, n_panels=6)

# ============================================================
# Comparative summary table by commune (both horizons)
# ============================================================
def wide_summary(res_df, tag):
    """
    Build a wide per-commune table of MAE/CRPS/WIS and 50%/90% coverage for
    CVAE, ARIMA and persistence, with one column per metric and horizon
    (e.g. MAE_CVAE_H7, MAE_CVAE_H14).

    Rows sharing (Commune, horizon) are averaged first. The table is sorted by
    Commune, saved to OUT_DIR/<tag>_summary_by_commune_wide.csv, and returned.
    An empty res_df yields an empty table with only a "Commune" column; the old
    code path raised from pd.concat([]) in that case.
    """
    keep_cols = ["Commune","horizon",
                 "MAE_CVAE","CRPS_CVAE","WIS_CVAE","COV50_CVAE","COV90_CVAE",
                 "MAE_ARIMA","CRPS_ARIMA","WIS_ARIMA","COV50_ARIMA","COV90_ARIMA",
                 "MAE_PERSIST","CRPS_PERSIST","WIS_PERSIST","COV50_PERSIST","COV90_PERSIST"]
    if res_df.empty:
        wide = pd.DataFrame(columns=["Commune"])
    else:
        # Aggregate by commune-horizon (mean across origins)
        tmp = (res_df[keep_cols]
               .groupby(["Commune", "horizon"]).mean(numeric_only=True)
               .reset_index())
        # One Commune-indexed frame per horizon, with suffixed columns, joined side by side.
        per_horizon = [
            tmp[tmp["horizon"] == h]
            .drop(columns=["horizon"])
            .set_index("Commune")
            .add_suffix(f"_H{h}")
            for h in sorted(tmp["horizon"].unique())
        ]
        wide = pd.concat(per_horizon, axis=1).reset_index().sort_values("Commune")
    path = os.path.join(OUT_DIR, f"{tag}_summary_by_commune_wide.csv")
    wide.to_csv(path, index=False)
    print(f"✓ Saved table: {path}")
    return wide

# Per-commune wide tables (one column per metric x horizon) for both evaluation
# modes; each call also writes <tag>_summary_by_commune_wide.csv to OUT_DIR.
wide_tail = wide_summary(res_tail, tag="A_tail")
wide_sel  = wide_summary(res_sel,  tag="B_selector")

# Quick look at the first 10 communes of the selector-based table.
print("\n=== Head of selector-based summary (by commune) ===")
print(wide_sel.head(10).round(3))

# ============================================================
# Free-run reconstruction (shape metrics) for a subset or all communes
# ============================================================
shape_columns = ["Commune",
                 "corr_CVAE", "MASE_CVAE", "peak_dt_CVAE", "peak_err_CVAE",
                 "corr_ARIMA", "MASE_ARIMA", "peak_dt_ARIMA", "peak_err_ARIMA"]
shape_rows = []
for comm, dfc in df_scaled.groupby("Commune"):
    dfc = dfc.sort_values("Date").reset_index(drop=True)
    start = SEQ_LEN - 1  # warm-up: first origin with a full SEQ_LEN input window
    # CVAE one-step-ahead predictions for rows start+1 .. end (ORIGINAL units)
    cvae_pred = cvae_free_run(dfc, start_idx=start)
    y_true    = dfc[OBS_COL].values[start+1:]
    # ARIMA one-step-ahead predictions on the original-unit series (refits per day; slow)
    ar_pred   = arima_free_run(dfc[OBS_COL].values, start_idx=start)

    # Align lengths
    nmin = min(len(y_true), len(cvae_pred), len(ar_pred))
    if nmin < 2:
        # Too short for shape metrics; np.nanargmax in peak_metrics would also
        # raise on an empty series and abort the whole loop.
        continue
    y_true = y_true[:nmin]; cvae_pred = cvae_pred[:nmin]; ar_pred = ar_pred[:nmin]

    pm_c = peak_metrics(y_true, cvae_pred);  sm_c = series_shape_metrics(y_true, cvae_pred)
    pm_a = peak_metrics(y_true, ar_pred);    sm_a = series_shape_metrics(y_true, ar_pred)

    shape_rows.append({
        "Commune": comm,
        "corr_CVAE": sm_c["corr"], "MASE_CVAE": sm_c["MASE"],
        "peak_dt_CVAE": pm_c["peak_timing_error_days"], "peak_err_CVAE": pm_c["peak_mag_error"],
        "corr_ARIMA": sm_a["corr"], "MASE_ARIMA": sm_a["MASE"],
        "peak_dt_ARIMA": pm_a["peak_timing_error_days"], "peak_err_ARIMA": pm_a["peak_mag_error"],
    })

# Explicit columns keep sort_values("Commune") valid even if no commune qualified.
shape_df = pd.DataFrame(shape_rows, columns=shape_columns).sort_values("Commune")
shape_path = os.path.join(OUT_DIR, "C_free_run_shape_metrics.csv")
shape_df.to_csv(shape_path, index=False)
print(f"\n✓ Free-run shape metrics saved: {shape_path}")
print(shape_df.head(10).round(3))

# Optional: quick plot of free-run for target communes
def plot_free_run(commune):
    """
    Plot observed cases against the CVAE and ARIMA one-step-ahead runs for one commune.

    Both runs are recomputed from origin SEQ_LEN - 1 onward and aligned with the
    observed series; all three curves are in ORIGINAL units. The figure is saved
    to OUT_DIR/free_run_<commune, spaces -> underscores>.png.
    """
    dfc = df_scaled[df_scaled["Commune"]==commune].sort_values("Date").reset_index(drop=True)
    start = SEQ_LEN - 1
    pred_c = cvae_free_run(dfc, start_idx=start)
    # ARIMA must see ORIGINAL units (OBS_COL). CASE_COL in df_scaled holds the
    # min-max scaled cases, which would draw the ARIMA curve on a [0, 1] scale
    # next to original-unit observations (the shape-metrics loop already uses OBS_COL).
    pred_a = arima_free_run(dfc[OBS_COL].values, start_idx=start)
    y = dfc[OBS_COL].values[start+1:]
    nmin = min(len(y), len(pred_c), len(pred_a))
    dates = dfc["Date"].iloc[start+1:start+1+nmin]

    plt.figure(figsize=(11,5))
    plt.plot(dates, y[:nmin], color="black", lw=2.0, label="Observed")
    plt.plot(dates, pred_c[:nmin], lw=2.0, linestyle="--", label="CVAE (free-run)")
    plt.plot(dates, pred_a[:nmin], lw=2.0, linestyle="-.", label="ARIMA (free-run)")
    plt.title(f"{commune} — Free-run reconstruction")
    plt.ylabel("Daily cases (7-day MA)")
    plt.grid(True, alpha=0.3)
    plt.legend(frameon=False, ncol=3, loc="upper left")
    plt.tight_layout()
    p = os.path.join(OUT_DIR, f"free_run_{commune.replace(' ','_')}.png")
    plt.savefig(p, dpi=300); plt.show()
    print(f"✓ Saved: {p}")

# Free-run figures for the highlighted communes. Each call recomputes the
# CVAE and ARIMA runs (plot_free_run calls cvae_free_run / arima_free_run).
for c in TARGET_COMMUNES:
    plot_free_run(c)

print("\nAll done.")


In [None]:
# ---------------------------------
# Load data & encoders
# ---------------------------------
# Read the panel and order it by (Commune, Date) with a fresh 0..N-1 index.
# Later code (e.g. COND_FULL[train_df.index]) relies on index labels being row positions.
df = pd.read_csv(DATA_CSV, parse_dates=["Date"])
df = df.sort_values(["Commune","Date"]).reset_index(drop=True)

# Keep the case column in ORIGINAL units under a fixed name.
OBS_COL = "Observed_Cases"
df[OBS_COL] = df[CASE_COL].values  # explicit copy

# Global MinMax scaler, fit on ALL rows of FEATURES (cases included).
# NOTE(review): the fit also sees the test-period rows that are split off further
# below, so test-period min/max leak into the scaling. Presumably the saved
# checkpoint was trained with this same scaling; confirm before refitting on
# the training slice only.
scaler = MinMaxScaler()
X_all = df[FEATURES].values
X_scaled = scaler.fit_transform(X_all)

# df_scaled: same as df but with FEATURES scaled; OBS_COL stays in original units.
df_scaled = df.copy()
df_scaled[FEATURES] = X_scaled  # NOTE: CASE_COL is intentionally left scaled here (presumably CASE_COL is in FEATURES)
# while OBS_COL keeps original units and is the column used for ground truth / ARIMA.

# (optional) sanity check: original-unit case counts should exceed the [0, 1] range
# of the scaled columns. A plain `assert` is skipped under `python -O`.
assert df_scaled[OBS_COL].max() > 5.0, "Observed_Cases parece estar escalado; revisa el flujo."

# Hidden states -> integers 0..K-1, following the sorted order of the distinct non-null labels.
state_vals = sorted(df_scaled["Hidden_State"].dropna().unique().tolist())
state_to_id = {s:i for i,s in enumerate(state_vals)}
id_to_state = {i:s for s,i in state_to_id.items()}
K_STATES = len(state_vals)
print(f"Detected hidden-state categories (ordered): {state_vals}  -> K={K_STATES}")

# NOTE(review): rows with a missing Hidden_State map to NaN, and .astype(int) would then raise.
df_scaled["state_id"] = df_scaled["Hidden_State"].map(state_to_id).astype(int)

# Commune encoder (one-hot)
enc_comm = OneHotEncoder_compat(handle_unknown="ignore")
comm_cond = enc_comm.fit_transform(df_scaled[["Commune"]])  # (N, C)

# State encoder: one-hot of the observed (hard) state, used for the training conditioning
enc_state = OneHotEncoder_compat(categories=[state_vals], handle_unknown="ignore")
state_oh = enc_state.fit_transform(df_scaled[["Hidden_State"]])

# Full conditioning = [comm_onehot, state_onehot]
COND_FULL = np.concatenate([comm_cond, state_oh], axis=1)
COND_DIM  = COND_FULL.shape[1]
INPUT_DIM = len(FEATURES)

# Keep the first feature's min/max for inspection (FEATURES[0] is the 7-day-average cases column).
CASE_MIN, CASE_MAX = scaler.data_min_[0], scaler.data_max_[0]

# ---------------------------------
# Train/test split by time (tail-only mode)
# ---------------------------------
def time_split_by_commune(df_in, test_days=TEST_DAYS):
    """Flag the last `test_days` rows of each commune as test (1) and the rest as train (0).

    The rows of each commune are assumed to be in chronological order (df is
    sorted by Commune, Date upstream). The result is aligned row-by-row with
    df_in, so it can be assigned directly as a column. The previous
    implementation re-sorted by index label, which misaligned on a non-sorted
    index, and it raised on an empty frame.

    Returns:
        Integer numpy array of length len(df_in).
    """
    # cumcount(ascending=False) numbers each commune's rows from n-1 down to 0,
    # i.e. the number of rows after the current one within the commune.
    rows_after = df_in.groupby("Commune", sort=False).cumcount(ascending=False)
    return (rows_after < test_days).astype(int).values

is_test = time_split_by_commune(df_scaled, TEST_DAYS)
df_scaled["is_test"] = is_test

# ---------------------------------
# Tensors for training CVAE (train slice only)
# ---------------------------------
train_df = df_scaled[df_scaled["is_test"]==0].copy()
X_train_raw = train_df[FEATURES].values
C_train_raw = COND_FULL[train_df.index]
Y_train_raw = train_df[FEATURES].values  # next-step target built in sequence

# Build windows per commune. train_df stacks communes back to back, so windowing
# the concatenated arrays would create samples whose history belongs to one
# commune and whose target belongs to the next.
_seq_parts = []
for _comm, _g in train_df.groupby("Commune", sort=False):
    if len(_g) <= SEQ_LEN:
        continue  # not enough rows for a single (window, target) pair
    _seq_parts.append(create_sequences(_g[FEATURES].values, COND_FULL[_g.index],
                                       _g[FEATURES].values, seq_len=SEQ_LEN))
X_seq, C_seq, Y_seq = (np.concatenate([np.asarray(part[k]) for part in _seq_parts], axis=0)
                       for k in range(3))
Xt = torch.tensor(X_seq, dtype=torch.float32)
Ct = torch.tensor(C_seq, dtype=torch.float32)
Yt = torch.tensor(Y_seq, dtype=torch.float32)

train_ds = torch.utils.data.TensorDataset(Xt, Ct, Yt)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

# ---------------------------------
# Build model & load checkpoint (or train from scratch)
# ---------------------------------
# Load the checkpoint if it exists (optionally fine-tuning it); otherwise train from scratch.
# Both branches end with `model` in eval mode.
# NOTE(review): the two training loops below are identical apart from the epoch
# count and the log prefix; they could share a helper.
if Path(MODEL_PTH).exists():
    model = build_and_load_robust(INPUT_DIM, COND_DIM, MODEL_PTH, DEVICE)
    # Optional: quick fine-tune to calibrate the new logvar row
    if FINETUNE_AFTER_LOAD_EPOCHS > 0:
        print(f"→ Quick fine-tune for {FINETUNE_AFTER_LOAD_EPOCHS} epochs to calibrate logvar row...")
        opt = optim.Adam(model.parameters(), lr=LR)
        model.train()
        for ep in range(FINETUNE_AFTER_LOAD_EPOCHS):
            tot = 0.0
            for xb, cb, yb in train_loader:
                xb, cb, yb = xb.to(DEVICE), cb.to(DEVICE), yb.to(DEVICE)
                opt.zero_grad()
                mu_full, logvar_case, mu_z, logv_z = model(xb, cb, temp=1.0)
                loss = cvae_loss_hetero(mu_full, logvar_case, yb, mu_z, logv_z,
                                        alpha_mse=ALPHA_MSE, beta_kl=BETA_KL, out_temp=1.0)
                loss.backward(); opt.step()
                # Accumulate a sample-weighted sum so the printed loss is a per-sample mean.
                tot += loss.item() * xb.size(0)
            print(f"[FT {ep+1:02d}] loss={tot/len(train_ds):.4f}")
        # NOTE(review): this overwrites the checkpoint that was just loaded. Every
        # re-run with FINETUNE_AFTER_LOAD_EPOCHS > 0 therefore fine-tunes
        # already fine-tuned weights, and the original weights are lost.
        # Consider saving to a separate path.
        torch.save(model.state_dict(), MODEL_PTH)
        print("✓ Fine-tuned weights saved.")
    model.eval()
else:
    print("No checkpoint found. Training from scratch...")
    model = CVAE_LSTM_HET(INPUT_DIM, COND_DIM, LATENT_DIM, HIDDEN_DIM, dec_hidden=DEC_HIDDEN).to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=LR)
    model.train()
    for ep in range(EPOCHS):
        tot = 0.0
        for xb, cb, yb in train_loader:
            xb, cb, yb = xb.to(DEVICE), cb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad()
            mu_full, logvar_case, mu_z, logv_z = model(xb, cb, temp=1.0)
            loss = cvae_loss_hetero(mu_full, logvar_case, yb, mu_z, logv_z,
                                    alpha_mse=ALPHA_MSE, beta_kl=BETA_KL, out_temp=1.0)
            loss.backward(); opt.step()
            # Sample-weighted sum -> per-sample mean loss in the log line.
            tot += loss.item() * xb.size(0)
        print(f"[Epoch {ep+1:02d}] loss={tot/len(train_ds):.4f}")
    torch.save(model.state_dict(), MODEL_PTH)
    print(f"✓ Trained & saved: {MODEL_PTH}")
    model.eval()

# ---------------------------------
# Non-homogeneous transition model (multinomial LR)
# P(S_t | S_{t-1}, mobility, commune)
# ---------------------------------
def fit_transition_softmax(df_all):
    """
    Fit the state-transition model P(S_t | S_{t-1}, mobility_t, commune).

    Features per row:
    - one-hot(S_{t-1}) over K_STATES;
    - the Internal/External mobility indices at t, as stored in df_all
      (min-max scaled in df_scaled);
    - the commune one-hot from enc_comm.
    The target is state_id at t. Only rows with is_test == 0 are used, to avoid
    leakage. The first training row of each commune, which has no previous
    state, is dropped. Rows are assumed to be in chronological order within
    each commune.

    Returns:
        A fitted sklearn LogisticRegression (multinomial loss, lbfgs solver).
    """
    use_df = df_all[df_all["is_test"]==0].copy()
    use_df["s_prev"] = use_df.groupby("Commune")["state_id"].shift(1)
    use_df = use_df.dropna(subset=["s_prev"]).copy()
    use_df["s_prev"] = use_df["s_prev"].astype(int)

    X_comm = enc_comm.transform(use_df[["Commune"]])
    # Mobility at t (the old groupby(...).shift(0) was a no-op). transition_proba_soft
    # feeds the row's own mobility values at inference in the same way.
    X_mob  = use_df[["Internal_Mobility_Index", "External_Mobility_Index"]].values
    enc_prev = OneHotEncoder_compat(categories=[list(range(K_STATES))])
    X_sp   = enc_prev.fit_transform(use_df[["s_prev"]])

    X = np.hstack([X_sp, X_mob, X_comm])
    y = use_df["state_id"].values.astype(int)

    try:
        clf = LogisticRegression(max_iter=300, multi_class="multinomial", solver="lbfgs")
    except TypeError:
        # `multi_class` is deprecated since scikit-learn 1.5. Releases without it
        # use the multinomial loss for multiclass targets with lbfgs by default.
        clf = LogisticRegression(max_iter=300, solver="lbfgs")
    clf.fit(X, y)
    return clf

# Fit the transition model once on the training slice; transition_proba_soft
# reads this module-level `trans_clf` at forecast time.
trans_clf = fit_transition_softmax(df_scaled)

def transition_proba_soft(df_slice_row):
    """Next-state distribution P(S_t | S_{t-1}, mobility, commune) from `trans_clf`.

    df_slice_row must support [] lookup of "state_id" (previous state),
    "Internal_Mobility_Index", "External_Mobility_Index" and "Commune".
    The features are assembled in the same order as in fit_transition_softmax:
    [one-hot previous state | mobility | commune one-hot].

    Returns:
        Probability vector of length K_STATES, floored at 1e-8 and renormalized
        to sum to 1.
    """
    prev_onehot = np.eye(K_STATES)[[int(df_slice_row["state_id"])]]  # shape (1, K_STATES)
    mobility = np.array([[float(df_slice_row["Internal_Mobility_Index"]),
                          float(df_slice_row["External_Mobility_Index"])]])
    commune_onehot = enc_comm.transform([[df_slice_row["Commune"]]])
    features = np.hstack([prev_onehot, mobility, commune_onehot])

    probs = np.clip(trans_clf.predict_proba(features)[0], 1e-8, 1.0)
    return probs / probs.sum()

# ---------------------------------
# Forecast helpers (ARIMA & Persistence)
# ---------------------------------
def persistence_forecast(y_hist, horizon, n_samples=200):
    """Persistence (naive) point forecast with bootstrapped random-walk intervals.

    The point forecast repeats the last observation at every step. Intervals come
    from simulating random-walk paths: each step adds a day-over-day change
    resampled (with replacement) from the last <= 30 observations. The spread
    therefore grows with the horizon, as for the naive method
    (sd_h ~ sigma * sqrt(h)). The previous version drew a single 1-step change
    per horizon, which left the interval width flat and too narrow at long
    horizons. Resampled changes are centered on zero so the ensemble mean stays
    at the persistence value instead of drifting.

    Args:
        y_hist: 1-D history (original units); the last value is the origin.
        horizon: number of steps ahead.
        n_samples: number of simulated paths.

    Returns:
        (mean, samples): mean has shape (horizon,) and holds the last
        observation; samples has shape (horizon, n_samples).
    """
    y_hist = np.asarray(y_hist, dtype=float)
    last = float(y_hist[-1])
    mean = np.full(horizon, last)
    # Slicing to the last 30 points also covers histories shorter than 30.
    resid = np.diff(y_hist[-30:])
    if resid.size < 2:
        resid = np.zeros(3)  # too little history: degenerate (zero-width) intervals
    resid = resid - resid.mean()
    steps = np.random.choice(resid, size=(horizon, n_samples), replace=True)
    samples = last + np.cumsum(steps, axis=0)
    return mean, samples

def arima_forecast(y_hist, horizon, n_samples=200, max_train=180):
    """
    Best-AIC ARIMA over ARIMA_GRID, with Gaussian predictive samples.

    Every order in ARIMA_GRID is fit on the last `max_train` observations, and
    orders whose fit (or AIC) raises are skipped. If none succeeds, the function
    falls back to persistence_forecast on the full history. Otherwise it returns
    the forecast mean, shape (horizon,), and samples, shape (horizon, n_samples).
    Samples are drawn independently per step from N(mean_h, var_pred_mean_h),
    with the variance floored at 1e-8.
    NOTE(review): the Gaussian draws can go negative for count data.
    """
    window = np.asarray(y_hist[-max_train:], dtype=float)
    chosen, chosen_aic = None, np.inf
    for order in ARIMA_GRID:
        try:
            fitted = ARIMA(window, order=order).fit()
            if fitted.aic < chosen_aic:
                chosen_aic, chosen = fitted.aic, fitted
        except Exception:
            continue  # best-effort grid: a failing order is simply skipped
    if chosen is None:
        return persistence_forecast(y_hist, horizon, n_samples)

    prediction = chosen.get_forecast(steps=horizon)
    mu = np.asarray(prediction.predicted_mean)
    sd = np.sqrt(np.maximum(np.asarray(prediction.var_pred_mean), 1e-8))
    draws = np.zeros((horizon, n_samples))
    for step in range(horizon):
        draws[step, :] = np.random.normal(mu[step], sd[step], size=n_samples)
    return mu, draws

# ---------------------------------
# CVAE rollout (soft states, autoregressive), mobility persistence
# Returns samples in ORIGINAL units (already inverse-transformed)
# ---------------------------------
@torch.no_grad()
def cvae_rollout(df_comm, origin_idx, horizon, n_samples=N_ENSEMBLE,
                 temp_z=TEMP_Z, temp_out=TEMP_OUT):
    """
    Sample autoregressive CVAE forecasts `horizon` steps ahead from one origin.

    Args:
        df_comm: rows of a single commune in time order, with scaled FEATURES,
            "state_id", the two mobility columns and "Commune".
        origin_idx: positional index (in df_comm) of the last observed step.
        horizon: number of steps ahead.
        n_samples: ensemble size.
        temp_z: temperature passed to the model as `temp`.
        temp_out: multiplier on the model's predicted std of the case variable.

    Returns:
        (mean_path, samples) of shapes (horizon,) and (horizon, n_samples), in
        ORIGINAL units of the case variable (inverse-transformed with `scaler`).
        Returns (None, None) when no full SEQ_LEN window ends at origin_idx or
        origin_idx lies outside df_comm.

    Details:
        - The LSTM input starts as the SEQ_LEN observed rows ending at origin_idx.
        - Each step samples the case value from N(mu, (sigma * temp_out)^2) in
          scaled space, clipped to [0, 1]. The remaining (mobility) dims of the
          next input row are the model's predicted means.
        - The conditioning state vector starts as a one-hot of the observed
          state at the origin. It is then projected with transition_proba_soft,
          holding mobility at the last OBSERVED values. The projection depends
          only on argmax of the previous vector, so the path is identical for
          every sample and is computed once, outside the sample loop.
    """
    # The window covers rows origin_idx-SEQ_LEN+1 .. origin_idx, so SEQ_LEN-1 is the first
    # usable origin (cvae_free_run's default start); anything past the last row is invalid.
    if origin_idx < SEQ_LEN - 1 or origin_idx >= len(df_comm):
        return None, None

    # Initial LSTM window (scaled)
    x_window = df_comm[FEATURES].values[origin_idx - SEQ_LEN + 1:origin_idx + 1].astype(float)
    origin_row = df_comm.iloc[origin_idx]
    im_last = float(origin_row["Internal_Mobility_Index"])
    em_last = float(origin_row["External_Mobility_Index"])
    comm_name = df_comm.iloc[0]["Commune"]
    comm_onehot = enc_comm.transform([[comm_name]])[0].reshape(1, -1)

    # Conditioning path: hard state at the origin, then soft projected states.
    # NOTE(review): the projection collapses to argmax each step; propagating the full
    # mixture would be sum_k s_probs[k] * P(S_t | S_{t-1}=k).
    s_probs = np.zeros(K_STATES)
    s_probs[int(origin_row["state_id"])] = 1.0
    cond_path = []
    for _ in range(horizon):
        cond_vec = np.concatenate([comm_onehot, s_probs.reshape(1, -1)], axis=1)
        cond_path.append(torch.tensor(cond_vec, dtype=torch.float32, device=DEVICE))
        pseudo = {
            "state_id": np.argmax(s_probs),
            "Internal_Mobility_Index": im_last, "External_Mobility_Index": em_last,
            "Commune": comm_name
        }
        s_probs = transition_proba_soft(pd.Series(pseudo))

    samples = np.zeros((horizon, n_samples), dtype=float)
    for m in range(n_samples):
        x_buf = x_window.copy()
        for h in range(horizon):
            xb = torch.tensor(x_buf[np.newaxis, :, :], dtype=torch.float32, device=DEVICE)
            mu_full, logvar_case, _mu_z, _logv_z = model(xb, cond_path[h], temp=temp_z)
            mu_full = mu_full.cpu().numpy()[0]
            logvar_case = logvar_case.cpu().numpy()[0, 0]
            sigma_case = math.sqrt(max(1e-8, math.exp(logvar_case))) * temp_out

            # Next scaled input row: sampled case value (clipped to the MinMax range [0, 1])
            # plus the model's mean for the other dims.
            next_vec = mu_full.copy()
            next_vec[0] = float(np.clip(np.random.normal(mu_full[0], sigma_case), 0.0, 1.0))

            # Inverse-scale to ORIGINAL units (float64, as the original tmp buffer did).
            samples[h, m] = scaler.inverse_transform(next_vec.astype(float).reshape(1, -1))[0, 0]

            # Autoregressive buffer update (keep scaled buffer)
            x_buf = np.vstack([x_buf[1:], next_vec])

    mean_path = np.mean(samples, axis=1)
    return mean_path, samples

# ============================================================
# A) Tail-only Evaluation (as in your original setup)
# ============================================================
def evaluate_tail(horizon_list=H_LIST, test_days=TEST_DAYS):
    """
    Tail-only backtest: every row flagged ``is_test == 1`` is a forecast origin.

    For each commune and each horizon H, the CVAE, ARIMA and persistence
    ensembles are scored step by step against ``OBS_COL`` (original units) with
    MAE of the ensemble mean, CRPS, WIS (50%/90% intervals) and empirical
    50%/90% coverage. Origins without a full ``SEQ_LEN`` look-back window or
    without H future rows are skipped, as are origins where ``cvae_rollout``
    returns None.

    Parameters
    ----------
    horizon_list : iterable of int
        Forecast horizons (days) to evaluate.
    test_days : int
        NOTE(review): not read by this function; the test window comes from the
        ``is_test`` column. Kept in the signature so existing calls still work.

    Returns
    -------
    res_df : DataFrame
        One row per (Commune, horizon) with MAE/CRPS/WIS/COV50/COV90 columns
        for CVAE, ARIMA and PERSIST. It has the same columns (and is empty)
        when nothing could be scored.
    agg : DataFrame
        Mean of ``res_df`` per horizon.
    overall : DataFrame
        Mean of ``res_df`` over all rows.
    viz_cache : dict
        ``{(commune, 14): {...}}`` per-origin paths for TARGET_COMMUNES. An entry
        is created only once at least one origin has been cached.

    Side effects: writes three ``A_tail_*`` CSV files to ``OUT_DIR`` and prints
    the pooled tables.
    """
    methods = ("CVAE", "ARIMA", "PERSIST")
    metric_names = ("MAE", "CRPS", "WIS", "COV50", "COV90")
    columns = ["Commune", "horizon"] + [f"{m}_{meth}" for meth in methods for m in metric_names]
    quantile_levels = (0.05, 0.25, 0.5, 0.75, 0.95)
    viz_h = 14  # only this horizon is cached for the figures
    viz_keys = ("dates", "y", "cvae_mean", "cvae_p05", "cvae_p95", "ar_mean", "ar_p05", "ar_p95")

    def _score_step(acc, y_h, samples):
        # One method at one lead time: WIS, CRPS, 50%/90% coverage, |y - ensemble mean|.
        q = empirical_quantiles(samples, qs=quantile_levels)
        acc["WIS"].append(wis_from_quantiles(y_h, [0.5, 0.9], q))
        acc["CRPS"].append(crps_ensemble(y_h, samples))
        acc["COV50"].append(coverage(y_h, q[0.25], q[0.75]))
        acc["COV90"].append(coverage(y_h, q[0.05], q[0.95]))
        acc["MAE"].append(abs(y_h - np.mean(samples)))

    def _band(samples, H, level):
        # Per-lead-time empirical quantile of an (H, n_samples) ensemble.
        return [np.quantile(samples[h, :], level) for h in range(H)]

    rows = []
    viz_cache = {}

    for comm, dfc in df_scaled.groupby("Commune"):
        dfc = dfc.sort_values("Date").reset_index(drop=True)
        y_true_full = dfc[OBS_COL].values  # original units

        idxs = np.where(dfc["is_test"].values == 1)[0]
        if len(idxs) == 0:
            continue

        for H in horizon_list:
            scores = {meth: {m: [] for m in metric_names} for meth in methods}

            for origin_idx in idxs:
                # Need a full look-back window and H observed future days.
                if origin_idx < SEQ_LEN or origin_idx + H >= len(dfc):
                    continue

                mean_c, samp_c = cvae_rollout(dfc, origin_idx, H, n_samples=N_ENSEMBLE,
                                              temp_z=TEMP_Z, temp_out=TEMP_OUT)
                if mean_c is None:
                    continue

                y_hist = y_true_full[:origin_idx + 1]
                mean_a, samp_a = arima_forecast(y_hist, H, n_samples=N_ENSEMBLE)
                _, samp_p = persistence_forecast(y_hist, H, n_samples=N_ENSEMBLE)

                truth = y_true_full[origin_idx + 1:origin_idx + 1 + H]

                # Step-wise metrics; every ensemble is in ORIGINAL units.
                for h in range(H):
                    for meth, samp in (("CVAE", samp_c), ("ARIMA", samp_a), ("PERSIST", samp_p)):
                        _score_step(scores[meth], truth[h], samp[h, :])

                # Cache paths for the figures; the entry is created lazily so it is never empty.
                if comm in TARGET_COMMUNES and H == viz_h:
                    block = viz_cache.setdefault((comm, H), {k: [] for k in viz_keys})
                    block["dates"].append(dfc.loc[origin_idx + 1:origin_idx + H, "Date"].tolist())
                    block["y"].append(truth.tolist())
                    block["cvae_mean"].append(mean_c.tolist())
                    block["cvae_p05"].append(_band(samp_c, H, 0.05))
                    block["cvae_p95"].append(_band(samp_c, H, 0.95))
                    block["ar_mean"].append(mean_a.tolist())
                    block["ar_p05"].append(_band(samp_a, H, 0.05))
                    block["ar_p95"].append(_band(samp_a, H, 0.95))

            if not scores["CVAE"]["MAE"]:
                continue

            row = {"Commune": comm, "horizon": H}
            for meth in methods:
                for m in metric_names:
                    row[f"{m}_{meth}"] = float(np.mean(scores[meth][m]))
            rows.append(row)

    # Explicit columns keep "horizon" present even when no row was produced,
    # so the pooled groupby below does not raise KeyError.
    res_df = pd.DataFrame(rows, columns=columns)
    res_df.to_csv(os.path.join(OUT_DIR, "A_tail_metrics_by_commune.csv"), index=False)
    agg = res_df.groupby("horizon").mean(numeric_only=True).reset_index()
    agg.to_csv(os.path.join(OUT_DIR, "A_tail_metrics_pooled_by_horizon.csv"), index=False)
    overall = res_df.mean(numeric_only=True).to_frame().T
    overall.to_csv(os.path.join(OUT_DIR, "A_tail_metrics_pooled_overall.csv"), index=False)

    print("=== [A] Tail-only: pooled by horizon ===")
    print(agg.round(4))
    print("\n=== [A] Tail-only: pooled overall ===")
    print(overall.round(4))

    return res_df, agg, overall, viz_cache

# ============================================================
# B) Rolling / Regime-stratified Evaluation (recommended)
#    (no retraining; just different origins)
# ============================================================
def pick_origins_regime_stratified(dfc, per_regime=4, min_gap=14, max_h=max(H_LIST)):
    """
    Pick up to ``per_regime`` forecast origins for each regime (``state_id``).

    Eligible origins have a full ``SEQ_LEN`` look-back window and at least
    ``max_h`` future rows. Within one regime, candidates are scanned in row
    order and accepted greedily when they lie at least ``min_gap`` rows after
    the previously accepted origin of the same regime. Origins from different
    regimes may therefore be closer than ``min_gap``. A non-positive
    ``per_regime`` yields no origins.

    Note: ``max_h`` defaults to ``max(H_LIST)`` evaluated once at definition time.

    Returns a sorted list of unique row positions into ``dfc``.
    """
    n = len(dfc)
    state_arr = dfc["state_id"].values

    def _spaced(cand):
        # Greedy chronological pick with a minimum gap, honouring the quota first.
        chosen, last = [], -10**9
        for i in cand:
            if len(chosen) >= per_regime:
                break
            if i - last >= min_gap:
                chosen.append(i)
                last = i
        return chosen

    idxs = []
    for s in sorted(dfc["state_id"].unique().tolist()):
        cand = np.where(state_arr == s)[0]
        cand = cand[(cand >= SEQ_LEN) & (cand <= n - max_h - 1)]
        idxs.extend(_spaced(cand))
    return sorted(set(idxs))

def pick_origins_rolling_every_k(dfc, step=7, max_h=max(H_LIST)):
    """
    Return every ``step``-th row position from SEQ_LEN up to (excluding) len(dfc) - max_h.

    Each returned origin has a full look-back window and ``max_h`` future rows.
    """
    eligible = range(SEQ_LEN, len(dfc) - max_h)
    # Slicing a range behaves exactly like slicing the equivalent list.
    return list(eligible[::step])

def pick_origins_incidence_quantiles(dfc, per_bucket=4, min_gap=14, max_h=max(H_LIST)):
    """
    Pick origins stratified by incidence terciles (low / mid / high).

    Buckets are built from ``CASE_COL`` using its 0.33 and 0.66 quantiles. Rows
    with missing values are ignored when computing the thresholds, and NaN rows
    fall into no bucket. Within each bucket, eligible origins (full ``SEQ_LEN``
    look-back, ``max_h`` future rows) are taken greedily in row order with at
    least ``min_gap`` rows between accepted origins of that bucket, up to
    ``per_bucket`` of them. A non-positive ``per_bucket`` yields no origins.

    NOTE(review): CASE_COL appears to hold the scaled series. The terciles
    match the original-unit terciles only if the scaler is monotone
    increasing (e.g. MinMax); confirm.

    Returns a sorted list of unique row positions into ``dfc``.
    """
    vals = dfc[CASE_COL].values
    # nanquantile: a single NaN must not turn both thresholds into NaN (which would empty every bucket).
    q1, q2 = np.nanquantile(vals, [0.33, 0.66])
    buckets = {
        "low":  np.where(vals <= q1)[0],
        "mid":  np.where((vals > q1) & (vals <= q2))[0],
        "high": np.where(vals > q2)[0]
    }
    n = len(dfc)

    def _spaced(cand):
        # Greedy chronological pick with a minimum gap, honouring the quota first.
        chosen, last = [], -10**9
        for i in cand:
            if len(chosen) >= per_bucket:
                break
            if i - last >= min_gap:
                chosen.append(i)
                last = i
        return chosen

    idxs = []
    for cand in buckets.values():
        cand = cand[(cand >= SEQ_LEN) & (cand <= n - max_h - 1)]
        idxs.extend(_spaced(cand))
    return sorted(set(idxs))

def evaluate_with_selector(horizon_list=H_LIST, origin_selector="regime", **selector_kwargs):
    """
    Backtest over origins chosen by a selector (no retraining).

    Parameters
    ----------
    horizon_list : iterable of int
        Forecast horizons (days) to evaluate.
    origin_selector : {'regime', 'rolling', 'quantiles'}
        Which origin picker to use:
        ``pick_origins_regime_stratified``, ``pick_origins_rolling_every_k`` or
        ``pick_origins_incidence_quantiles``.
    **selector_kwargs
        Forwarded to the chosen selector.

    Raises
    ------
    ValueError
        If ``origin_selector`` is not one of the names above.

    Returns
    -------
    res_df, agg, overall, viz_cache
        Same layout as ``evaluate_tail``: per (Commune, horizon) metrics (with
        fixed columns even when empty), their mean per horizon, the overall
        mean, and per-origin paths for TARGET_COMMUNES at H == 14. Cache
        entries exist only when non-empty.

    Side effects: writes three ``B_selector_*`` CSV files to ``OUT_DIR`` and
    prints the pooled tables.
    """
    selectors = {
        "regime": pick_origins_regime_stratified,
        "rolling": pick_origins_rolling_every_k,
        "quantiles": pick_origins_incidence_quantiles,
    }
    if origin_selector not in selectors:
        raise ValueError("Unknown origin_selector")
    select_origins = selectors[origin_selector]

    methods = ("CVAE", "ARIMA", "PERSIST")
    metric_names = ("MAE", "CRPS", "WIS", "COV50", "COV90")
    columns = ["Commune", "horizon"] + [f"{m}_{meth}" for meth in methods for m in metric_names]
    quantile_levels = (0.05, 0.25, 0.5, 0.75, 0.95)
    viz_h = 14  # only this horizon is cached for the figures
    viz_keys = ("dates", "y", "cvae_mean", "cvae_p05", "cvae_p95", "ar_mean", "ar_p05", "ar_p95")

    def _score_step(acc, y_h, samples):
        # One method at one lead time: WIS, CRPS, 50%/90% coverage, |y - ensemble mean|.
        q = empirical_quantiles(samples, qs=quantile_levels)
        acc["WIS"].append(wis_from_quantiles(y_h, [0.5, 0.9], q))
        acc["CRPS"].append(crps_ensemble(y_h, samples))
        acc["COV50"].append(coverage(y_h, q[0.25], q[0.75]))
        acc["COV90"].append(coverage(y_h, q[0.05], q[0.95]))
        acc["MAE"].append(abs(y_h - np.mean(samples)))

    def _band(samples, H, level):
        # Per-lead-time empirical quantile of an (H, n_samples) ensemble.
        return [np.quantile(samples[h, :], level) for h in range(H)]

    rows = []
    viz_cache = {}

    for comm, dfc in df_scaled.groupby("Commune"):
        dfc = dfc.sort_values("Date").reset_index(drop=True)
        y_true_full = dfc[OBS_COL].values  # original units

        origins = select_origins(dfc, **selector_kwargs)
        if len(origins) == 0:
            continue

        for H in horizon_list:
            scores = {meth: {m: [] for m in metric_names} for meth in methods}

            for origin_idx in origins:
                # Need a full look-back window and H observed future days.
                if origin_idx < SEQ_LEN or origin_idx + H >= len(dfc):
                    continue

                mean_c, samp_c = cvae_rollout(dfc, origin_idx, H, n_samples=N_ENSEMBLE,
                                              temp_z=TEMP_Z, temp_out=TEMP_OUT)
                if mean_c is None:
                    continue

                y_hist = y_true_full[:origin_idx + 1]
                mean_a, samp_a = arima_forecast(y_hist, H, n_samples=N_ENSEMBLE)
                _, samp_p = persistence_forecast(y_hist, H, n_samples=N_ENSEMBLE)

                truth = y_true_full[origin_idx + 1:origin_idx + 1 + H]

                # Step-wise metrics (all in ORIGINAL units).
                for h in range(H):
                    for meth, samp in (("CVAE", samp_c), ("ARIMA", samp_a), ("PERSIST", samp_p)):
                        _score_step(scores[meth], truth[h], samp[h, :])

                # Cache for figures (plots use the last origin); created lazily so never empty.
                if comm in TARGET_COMMUNES and H == viz_h:
                    block = viz_cache.setdefault((comm, H), {k: [] for k in viz_keys})
                    block["dates"].append(dfc.loc[origin_idx + 1:origin_idx + H, "Date"].tolist())
                    block["y"].append(truth.tolist())
                    block["cvae_mean"].append(mean_c.tolist())
                    block["cvae_p05"].append(_band(samp_c, H, 0.05))
                    block["cvae_p95"].append(_band(samp_c, H, 0.95))
                    block["ar_mean"].append(mean_a.tolist())
                    block["ar_p05"].append(_band(samp_a, H, 0.05))
                    block["ar_p95"].append(_band(samp_a, H, 0.95))

            if not scores["CVAE"]["MAE"]:
                continue

            row = {"Commune": comm, "horizon": H}
            for meth in methods:
                for m in metric_names:
                    row[f"{m}_{meth}"] = float(np.mean(scores[meth][m]))
            rows.append(row)

    # Explicit columns keep "horizon" present even when no row was produced.
    res_df = pd.DataFrame(rows, columns=columns)
    res_df.to_csv(os.path.join(OUT_DIR, "B_selector_metrics_by_commune.csv"), index=False)
    agg = res_df.groupby("horizon").mean(numeric_only=True).reset_index()
    agg.to_csv(os.path.join(OUT_DIR, "B_selector_metrics_pooled_by_horizon.csv"), index=False)
    overall = res_df.mean(numeric_only=True).to_frame().T
    overall.to_csv(os.path.join(OUT_DIR, "B_selector_metrics_pooled_overall.csv"), index=False)

    print("=== [B] Selector-based: pooled by horizon ===")
    print(agg.round(4))
    print("\n=== [B] Selector-based: pooled overall ===")
    print(overall.round(4))

    return res_df, agg, overall, viz_cache

# ============================================================
# Free-run reconstruction of full curve + shape metrics
# ============================================================
@torch.no_grad()
def cvae_free_run(dfc, start_idx=None, n_samples=None, temp_z=None, temp_out=None):
    """
    One-step-ahead CVAE predictions over the whole series.

    For every origin t in ``[start_idx, len(dfc) - 2]`` this calls
    ``cvae_rollout(dfc, t, horizon=1, ...)`` and keeps the ensemble mean.
    Predictions are NOT fed back between origins (every call receives the
    unchanged ``dfc``). The result is therefore a sequence of independent
    1-step forecasts rather than a recursive free run.

    Parameters
    ----------
    dfc : DataFrame
        Rows of one commune (callers pass them sorted by date with a reset index).
    start_idx : int, optional
        First origin; defaults to ``SEQ_LEN - 1``.
    n_samples, temp_z, temp_out : optional
        Ensemble size and temperatures. They default to the module-level
        ``N_ENSEMBLE``, ``TEMP_Z`` and ``TEMP_OUT``, read at call time.

    Returns
    -------
    np.ndarray of float with length ``len(dfc) - 1 - start_idx``. Element i is
    the forecast for row ``start_idx + 1 + i``, in the units returned by
    ``cvae_rollout`` (original units). It is NaN where ``cvae_rollout``
    returned None.
    """
    if start_idx is None:
        start_idx = SEQ_LEN - 1
    if n_samples is None:
        n_samples = N_ENSEMBLE
    if temp_z is None:
        temp_z = TEMP_Z
    if temp_out is None:
        temp_out = TEMP_OUT

    preds = []
    for t in range(start_idx, len(dfc) - 1):
        mean1, _ = cvae_rollout(dfc, t, horizon=1, n_samples=n_samples,
                                temp_z=temp_z, temp_out=temp_out)
        preds.append(float(mean1[0]) if mean1 is not None else np.nan)
    return np.array(preds, dtype=float)

def arima_free_run(y, start_idx=None, n_samples=64):
    """
    One-step-ahead ARIMA predictions over the whole series.

    For every origin t in ``[start_idx, len(y) - 2]``, ``arima_forecast`` is
    called on the observed history ``y[:t+1]`` with horizon 1, and its mean is
    kept. Forecasts are never fed back into the history.

    Parameters
    ----------
    y : array-like
        Observed series (callers pass ``OBS_COL`` values, i.e. original units).
    start_idx : int, optional
        First origin; defaults to ``SEQ_LEN - 1``.
    n_samples : int, default 64
        Ensemble size requested from ``arima_forecast`` (only its mean is used).

    Returns
    -------
    np.ndarray of float with length ``len(y) - 1 - start_idx``. Element i is
    the forecast for position ``start_idx + 1 + i``.
    """
    if start_idx is None:
        start_idx = SEQ_LEN - 1
    out = []
    for t in range(start_idx, len(y) - 1):
        mean, _ = arima_forecast(y[:t + 1], horizon=1, n_samples=n_samples)
        out.append(float(mean[0]))
    return np.array(out, dtype=float)

def peak_metrics(y_true, y_pred):
    """
    Peak timing (days) and magnitude error between two aligned series.

    NaNs are ignored when locating each peak.

    Returns
    -------
    dict
        ``peak_timing_error_days`` is the predicted peak index minus the true
        peak index. ``peak_mag_error`` is |predicted peak value - true peak
        value|. Both are NaN when either series is empty or all-NaN, instead
        of ``np.nanargmax`` raising ValueError.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # np.isnan(empty).all() is True, so this also covers empty inputs.
    if np.isnan(y_true).all() or np.isnan(y_pred).all():
        return {"peak_timing_error_days": np.nan, "peak_mag_error": np.nan}
    i_t = int(np.nanargmax(y_true)); i_p = int(np.nanargmax(y_pred))
    return {"peak_timing_error_days": i_p - i_t,
            "peak_mag_error": float(abs(y_pred[i_p] - y_true[i_t]))}

def series_shape_metrics(y_true, y_pred):
    """
    Correlation, MAE, RMSE and MASE (scaled by the naive-1 error) of a forecast path.

    Both series are truncated to their common length. Pairs where either value
    is non-finite are dropped before computing corr/MAE/RMSE, so a single
    missing forecast (e.g. NaN from ``cvae_free_run``) no longer turns every
    metric into NaN. The MASE scale is the mean absolute one-step change of
    ``y_true``, computed over its finite consecutive differences.

    Returns
    -------
    dict with float values for "corr", "MAE", "RMSE" and "MASE". A value is NaN
    when undefined: no valid pairs, a constant series (corr), or a zero or
    undefined naive scale (MASE).
    """
    y_true = np.asarray(y_true, float); y_pred = np.asarray(y_pred, float)
    n = min(len(y_true), len(y_pred))
    yt = y_true[:n]; yp = y_pred[:n]

    # Naive-1 scale depends on the truth only.
    naive = np.abs(np.diff(yt))
    naive = naive[np.isfinite(naive)]
    mase_den = float(np.mean(naive)) if naive.size else np.nan

    ok = np.isfinite(yt) & np.isfinite(yp)
    yt_ok, yp_ok = yt[ok], yp[ok]
    if yt_ok.size == 0:
        return {"corr": np.nan, "MAE": np.nan, "RMSE": np.nan, "MASE": np.nan}

    err = yt_ok - yp_ok
    mae_v = float(np.mean(np.abs(err)))
    rmse_v = float(np.sqrt(np.mean(err ** 2)))
    if yt_ok.size > 1 and np.std(yt_ok) > 0 and np.std(yp_ok) > 0:
        corr = np.corrcoef(yt_ok, yp_ok)[0, 1]
    else:
        corr = np.nan
    mase = mae_v / mase_den if mase_den > 0 else np.nan  # NaN > 0 is False
    return {"corr": float(corr), "MAE": mae_v, "RMSE": rmse_v, "MASE": float(mase)}

# ============================================================
# Visualization helpers
# ============================================================
def plot_panels(viz_cache, H=14, n_panels=6):
    """
    Multi-commune panels: Observed vs CVAE vs ARIMA (means and 90% bands).

    Only the last cached origin of each commune is drawn, which reads more
    clearly in a publication. Communes from TARGET_COMMUNES with no cached
    origin at horizon H, or with an empty cache entry, are skipped. Panels are
    laid out two per row, and unused grid cells are hidden.

    Side effects: saves ``viz_obs_vs_cvae_arima_H{H}.png`` in OUT_DIR and
    shows the figure.
    """
    communes = [c for c in TARGET_COMMUNES
                if (c, H) in viz_cache and len(viz_cache[(c, H)]["dates"]) > 0]
    n = min(n_panels, len(communes))
    if n == 0:
        print("No communes cached for visualization.")
        return

    nrows = math.ceil(n/2); ncols = 2
    fig, axes = plt.subplots(nrows, ncols, figsize=(16, 5.5*nrows), sharex=False)
    axes = np.array(axes).reshape(-1)

    for ax, comm in zip(axes, communes[:n]):
        block = viz_cache[(comm,H)]
        dates  = block["dates"][-1]
        y      = block["y"][-1]
        cm     = block["cvae_mean"][-1]; c05 = block["cvae_p05"][-1]; c95 = block["cvae_p95"][-1]
        am     = block["ar_mean"][-1];   a05 = block["ar_p05"][-1];   a95 = block["ar_p95"][-1]

        ax.plot(dates, y, lw=2.0, color="black", label="Observed")
        ax.fill_between(dates, c05, c95, alpha=0.20, label="CVAE 90%", edgecolor='none')
        ax.plot(dates, cm, lw=2.0, linestyle="--", label="CVAE mean")
        ax.fill_between(dates, a05, a95, alpha=0.15, label="ARIMA 90%", edgecolor='none')
        ax.plot(dates, am, lw=2.0, linestyle="-.", label="ARIMA mean")

        ax.set_title(f"{comm} — {H}-day ahead", fontsize=13)
        ax.set_ylabel("Daily cases (7-day MA)")
        ax.grid(True, alpha=0.3)

    # An odd number of panels leaves an empty cell in the 2-column grid.
    for ax in axes[n:]:
        ax.set_visible(False)

    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower center", ncol=4, frameon=False)
    fig.autofmt_xdate()
    fig.tight_layout(rect=[0,0.05,1,1])
    path = os.path.join(OUT_DIR, f"viz_obs_vs_cvae_arima_H{H}.png")
    fig.savefig(path, dpi=300)
    plt.show()
    print(f"✓ Saved: {path}")

def plot_last_origin(commune, H=14, viz_cache=None, outdir=OUT_DIR):
    """
    Draw one commune's most recent cached forecast (Observed, CVAE, ARIMA, 90% bands).

    Prints a hint and returns None when ``viz_cache`` has no non-empty entry
    for ``(commune, H)``. Otherwise the figure is saved as
    ``last_origin_<commune>_H{H}.png`` in ``outdir`` and shown.
    """
    key = (commune, H)
    cached = viz_cache is not None and key in viz_cache and len(viz_cache[key]["dates"]) != 0
    if not cached:
        print(f"No cache for {commune} (H={H}). Run evaluation first and ensure commune is in TARGET_COMMUNES.")
        return

    entry = viz_cache[key]
    (x_dates, observed,
     cvae_mu, cvae_lo, cvae_hi,
     ar_mu, ar_lo, ar_hi) = (entry[name][-1] for name in ("dates", "y",
                                                           "cvae_mean", "cvae_p05", "cvae_p95",
                                                           "ar_mean", "ar_p05", "ar_p95"))

    fig, ax = plt.subplots(figsize=(11, 5))
    ax.plot(x_dates, observed, color="black", lw=2.0, label="Observed")
    ax.fill_between(x_dates, cvae_lo, cvae_hi, alpha=0.20, label="CVAE 90%", edgecolor='none')
    ax.plot(x_dates, cvae_mu, lw=2.0, linestyle="--", label="CVAE mean")
    ax.fill_between(x_dates, ar_lo, ar_hi, alpha=0.15, label="ARIMA 90%", edgecolor='none')
    ax.plot(x_dates, ar_mu, lw=2.0, linestyle="-.", label="ARIMA mean")
    ax.set_title(f"{commune} — {H}-day ahead (last origin)")
    ax.set_ylabel("Daily cases (7-day MA)")
    ax.grid(True, alpha=0.3)
    ax.legend(frameon=False, ncol=4, loc="upper left")
    fig.tight_layout()

    png_path = os.path.join(outdir, f"last_origin_{commune.replace(' ','_')}_H{H}.png")
    fig.savefig(png_path, dpi=300)
    plt.show()
    print(f"✓ Saved: {png_path}")

# ============================================================
# RUN — choose one or both evaluation modes
# ============================================================

# --- (A) Tail-only (as in your current paper draft) ---
# Scores every is_test origin for each horizon in H_LIST and writes the A_tail_* CSVs.
res_tail, agg_tail, overall_tail, viz_tail = evaluate_tail(H_LIST, TEST_DAYS)
print(f"✓ Tail-only CSVs saved in: {OUT_DIR}")

# Visualizations for (A)
plot_panels(viz_tail, H=14, n_panels=6)
# NOTE(review): the evaluators only cache figure data for H == 14, so this H=7 call
# is expected to print "No communes cached for visualization." and draw nothing.
plot_panels(viz_tail, H=7,  n_panels=6)
for c in TARGET_COMMUNES:
    plot_last_origin(c, H=14, viz_cache=viz_tail)

# --- (B) Selector-based backtesting (recommended for paper claims) ---
# Regime-stratified: up to 4 origins per regime, >=14d spacing within a regime.
# NOTE(review): the selector's max_h defaults to max(H_LIST), which may be larger than
# max(horizon_list) here; that only restricts which origins are eligible.
res_sel, agg_sel, overall_sel, viz_sel = evaluate_with_selector(
    horizon_list=[7,14],
    origin_selector="regime",
    per_regime=4, min_gap=14
)
print(f"✓ Selector-based CSVs saved in: {OUT_DIR}")

# Visualizations for (B)
plot_panels(viz_sel, H=14, n_panels=6)

# ============================================================
# Comparative summary table by commune (both horizons)
# ============================================================
def wide_summary(res_df, tag):
    """
    Build a wide per-commune table of CVAE/ARIMA/Persistence metrics.

    ``res_df`` rows (one per commune and horizon, as returned by the
    evaluators) are averaged per (Commune, horizon). Each horizon h then
    becomes its own set of columns with suffix ``_H{h}`` (e.g. ``MAE_CVAE_H7``,
    ``MAE_CVAE_H14``), and those sets are joined on Commune. If ``res_df`` is
    empty (nothing was scored), an empty table with only a ``Commune`` column
    is written instead of failing.

    Side effects: writes ``{tag}_summary_by_commune_wide.csv`` to OUT_DIR.
    Returns the wide DataFrame sorted by Commune.
    """
    keep_cols = ["Commune","horizon",
                 "MAE_CVAE","CRPS_CVAE","WIS_CVAE","COV50_CVAE","COV90_CVAE",
                 "MAE_ARIMA","CRPS_ARIMA","WIS_ARIMA","COV50_ARIMA","COV90_ARIMA",
                 "MAE_PERSIST","CRPS_PERSIST","WIS_PERSIST","COV50_PERSIST","COV90_PERSIST"]
    if res_df.empty:
        # pd.concat([]) would raise, and an empty frame may lack the metric columns.
        wide = pd.DataFrame(columns=["Commune"])
    else:
        tmp = res_df[keep_cols].copy()
        # Aggregate by commune-horizon (mean across rows)
        tmp = tmp.groupby(["Commune","horizon"]).mean(numeric_only=True).reset_index()
        # Pivot to wide: columns like MAE_CVAE_H7, MAE_CVAE_H14, ...
        out = []
        for h in sorted(tmp["horizon"].unique()):
            dfh = tmp[tmp["horizon"]==h].drop(columns=["horizon"]).copy()
            dfh = dfh.set_index("Commune")
            dfh = dfh.add_suffix(f"_H{h}")
            out.append(dfh)
        wide = pd.concat(out, axis=1).reset_index()
        wide = wide.sort_values("Commune")
    path = os.path.join(OUT_DIR, f"{tag}_summary_by_commune_wide.csv")
    wide.to_csv(path, index=False)
    print(f"✓ Saved table: {path}")
    return wide

# Wide per-commune tables for both evaluation modes (each also written to OUT_DIR as CSV).
wide_tail = wide_summary(res_tail, tag="A_tail")
wide_sel  = wide_summary(res_sel,  tag="B_selector")

print("\n=== Head of selector-based summary (by commune) ===")
print(wide_sel.head(10).round(3))

# ============================================================
# Free-run reconstruction (shape metrics) for a subset or all communes
# ============================================================
# One row per commune: correlation, MASE and peak timing/magnitude errors of the
# CVAE and ARIMA one-step-ahead paths against the observed curve (original units).
shape_rows = []
for comm, dfc in df_scaled.groupby("Commune"):
    dfc = dfc.sort_values("Date").reset_index(drop=True)
    start = SEQ_LEN - 1  # warm-up
    # CVAE "free-run" (ORIGINAL units): one 1-step forecast per origin t >= start,
    # each issued from the unchanged dfc (forecasts are not fed back).
    cvae_pred = cvae_free_run(dfc, start_idx=start)
    y_true    = dfc[OBS_COL].values[start+1:]
    # ARIMA free-run: refit on observed history y[:t+1] for every origin.
    ar_pred   = arima_free_run(dfc[OBS_COL].values, start_idx=start)

    # Align lengths
    nmin = min(len(y_true), len(cvae_pred), len(ar_pred))
    y_true = y_true[:nmin]; cvae_pred = cvae_pred[:nmin]; ar_pred = ar_pred[:nmin]

    # NOTE(review): if a commune has <= SEQ_LEN rows these arrays are empty;
    # check how peak_metrics behaves on empty input before relying on this loop.
    pm_c = peak_metrics(y_true, cvae_pred);  sm_c = series_shape_metrics(y_true, cvae_pred)
    pm_a = peak_metrics(y_true, ar_pred);    sm_a = series_shape_metrics(y_true, ar_pred)

    shape_rows.append({
        "Commune": comm,
        "corr_CVAE": sm_c["corr"], "MASE_CVAE": sm_c["MASE"],
        "peak_dt_CVAE": pm_c["peak_timing_error_days"], "peak_err_CVAE": pm_c["peak_mag_error"],
        "corr_ARIMA": sm_a["corr"], "MASE_ARIMA": sm_a["MASE"],
        "peak_dt_ARIMA": pm_a["peak_timing_error_days"], "peak_err_ARIMA": pm_a["peak_mag_error"],
    })

shape_df = pd.DataFrame(shape_rows).sort_values("Commune")
shape_path = os.path.join(OUT_DIR, "C_free_run_shape_metrics.csv")
shape_df.to_csv(shape_path, index=False)
print(f"\n✓ Free-run shape metrics saved: {shape_path}")
print(shape_df.head(10).round(3))

# Optional: quick plot of free-run for target communes
def plot_free_run(commune):
    """
    Plot observed cases against the CVAE and ARIMA one-step-ahead paths for one commune.

    All three curves are in ORIGINAL units. The CVAE path comes from
    ``cvae_free_run``. The ARIMA path is fitted on ``OBS_COL``, matching the
    shape-metrics loop; ``CASE_COL`` holds the scaled series and must not be
    mixed with the observed curve. The curves are truncated to a common length
    starting at row ``SEQ_LEN``.

    Side effects: saves ``free_run_<commune>.png`` in OUT_DIR and shows it.
    """
    dfc = df_scaled[df_scaled["Commune"]==commune].sort_values("Date").reset_index(drop=True)
    start = SEQ_LEN - 1
    pred_c = cvae_free_run(dfc, start_idx=start)
    # Fit ARIMA on original units so it is comparable with the observed curve and the CVAE path.
    pred_a = arima_free_run(dfc[OBS_COL].values, start_idx=start)
    y = dfc[OBS_COL].values[start+1:]
    nmin = min(len(y), len(pred_c), len(pred_a))
    dates = dfc["Date"].iloc[start+1:start+1+nmin]

    plt.figure(figsize=(11,5))
    plt.plot(dates, y[:nmin], color="black", lw=2.0, label="Observed")
    plt.plot(dates, pred_c[:nmin], lw=2.0, linestyle="--", label="CVAE (free-run)")
    plt.plot(dates, pred_a[:nmin], lw=2.0, linestyle="-.", label="ARIMA (free-run)")
    plt.title(f"{commune} — Free-run reconstruction")
    plt.ylabel("Daily cases (7-day MA)")
    plt.grid(True, alpha=0.3)
    plt.legend(frameon=False, ncol=3, loc="upper left")
    plt.tight_layout()
    p = os.path.join(OUT_DIR, f"free_run_{commune.replace(' ','_')}.png")
    plt.savefig(p, dpi=300); plt.show()
    print(f"✓ Saved: {p}")

# One-step-ahead reconstruction figure per target commune (recomputes the paths).
for c in TARGET_COMMUNES:
    plot_free_run(c)

print("\nAll done.")


In [None]:
# ============================================================
# Full-curve CVAE reconstructions (regime-conditioned) +
# Metrics: AW/POC/AAD (90% & 50%), CRPS, MAE/RMSE (median) +
# Plots & summary tables (per-commune and pooled).
# ============================================================

import os
import math, numpy as np, pandas as pd
import matplotlib.pyplot as plt
import torch  # needed for @torch.no_grad()

# IMPORTANT: we assume the following objects already exist from your previous code:
# - df_scaled  (with FEATURES scaled and an extra column OBS_COL="Observed_Cases" in ORIGINAL units)
# - model, DEVICE, scaler, enc_comm, K_STATES, FEATURES, CASE_COL, SEQ_LEN, OUT_DIR, TARGET_COMMUNES
# If not, run the previous cells first (with the scale fix that creates OBS_COL).

# --------- Config for simulations ----------
N_SIMS_FULL   = 400        # number of simulated trajectories (ensemble members per day)
TEMP_Z_FULL   = 1.25       # latent temperature (passed to model(..., temp=...))
TEMP_OUT_FULL = 1.25       # multiplies the predicted std of the CASE output (not the variance)
OBS_COL       = "Observed_Cases"  # original-units cases column from the scale fix (re-bound here; earlier cells also use it)

# --------- Helpers: quantiles, coverage, CRPS ----------
def empirical_quants(a, qs=(0.05,0.25,0.5,0.75,0.95)):
    """Return {float(level): empirical quantile} for each level in ``qs``, over the flattened input."""
    flat = np.asarray(a).ravel()
    result = {}
    for level in qs:
        result[float(level)] = float(np.quantile(flat, level))
    return result

def crps_ens(y, samples):
    s = np.sort(np.asarray(samples).ravel())
    n = len(s); y = float(y)
    e1 = np.mean(np.abs(s - y))                              # E|X - y|
    diffs = np.diff(s)
    weights = np.arange(1, n) * (n - np.arange(1, n))
    e2 = 2.0 * np.sum(weights * diffs) / (n*n)               # E|X - X'|
    return float(e1 - 0.5*e2)

def cov_indicator(y, lo, hi):
    """Return 1.0 when lo <= y <= hi (bounds inclusive), otherwise 0.0 (also for NaN)."""
    return float(bool(lo <= float(y) <= hi))

# --------- AAD (Average Asymmetry Degree) ----------
def asymmetry_degree(lo, med, hi):
    """
    Signed asymmetry of an interval around its median.

        A = ((hi - med) - (med - lo)) / (hi - lo)

    A is in [-1, 1] when lo <= med <= hi: 0 means symmetric, > 0 means a
    longer upper arm, < 0 means a longer lower arm. The denominator is floored
    at 1e-12 to avoid dividing by zero for degenerate intervals.
    """
    lower, centre, upper = float(lo), float(med), float(hi)
    upper_arm = upper - centre
    lower_arm = centre - lower
    return (upper_arm - lower_arm) / max(upper - lower, 1e-12)

# --------- One-hot (hard) state encoding ----------
def state_onehot_from_id(state_id, K):
    """Return a fresh length-K float vector with 1.0 at index int(state_id) and 0.0 elsewhere."""
    return np.eye(K, dtype=float)[int(state_id)].copy()

# --------- Full-curve regime-conditioned CVAE simulation ----------
@torch.no_grad()
def simulate_full_curve_cvae(df_comm_scaled,
                             n_sims=N_SIMS_FULL,
                             temp_z=TEMP_Z_FULL,
                             temp_out=TEMP_OUT_FULL,
                             use_soft_states=False):
    """
    Simulate full-length CASE trajectories for one commune, chaining day by day.

    - The look-back buffer starts from the first SEQ_LEN observed (scaled) feature rows.
    - Day t is conditioned on [commune one-hot, state(t)]. state(t) is the
      observed state as a hard one-hot, or (use_soft_states=True) the output of
      ``transition_proba_soft`` given state(t-1) and the mobility values at t.
    - Only CASE is sampled, from N(mu, (exp(logvar)**0.5 * temp_out)**2) in
      scaled space and clipped to [0, 1]. Mobility dims are replaced by the
      observed (scaled) mobility at t, and the row is pushed into the
      autoregressive buffer.

    Parameters
    ----------
    df_comm_scaled : DataFrame
        Rows of one commune with FEATURES (scaled), OBS_COL (original units),
        "state_id", "Date" and "Commune".
    n_sims : int
        Number of simulated trajectories.
    temp_z, temp_out : float
        Latent temperature passed to the model, and the multiplier on the CASE std.
    use_soft_states : bool
        Use soft transition probabilities instead of the observed hard state.

    Returns
    -------
    dict or None
        None when the commune has <= SEQ_LEN rows. Otherwise keys "dates",
        "truth" (OBS_COL from row SEQ_LEN on), "q05", "q25", "q50", "q75",
        "q95", "mean" (per-day ensemble statistics) and "sims" with shape
        (n_sims, len - SEQ_LEN). All values are in ORIGINAL units.

    Raises
    ------
    AssertionError
        If the observed series peaks at <= 5, a heuristic guard against
        passing a scaled column as OBS_COL. It is raised explicitly so it
        survives ``python -O``.
    """
    # Safety: need enough history for the first prediction
    if len(df_comm_scaled) <= SEQ_LEN:
        return None

    # Precompute encodings & arrays
    comm_name    = df_comm_scaled.iloc[0]["Commune"]
    comm_onehot  = enc_comm.transform([[comm_name]])[0]           # shape (C,)
    X_all_scaled = df_comm_scaled[FEATURES].values.astype(float)  # scaled features (cases+mobilities)
    states_id    = df_comm_scaled["state_id"].values.astype(int)  # 0..K-1
    dates        = df_comm_scaled["Date"].tolist()

    T = len(df_comm_scaled)
    steps = T - SEQ_LEN
    n_feat = len(FEATURES)

    # Truth in ORIGINAL units (use OBS_COL, not CASE_COL which is scaled)
    truth = df_comm_scaled[OBS_COL].values[SEQ_LEN:]
    if not np.nanmax(truth) > 5.0:
        raise AssertionError("Observed_Cases seems too small (maybe scaled?). Check the scale fix.")

    # Conditioning depends only on the day t, never on the simulation index, so it is
    # built once per day instead of n_sims times (this also avoids per-step .iloc lookups).
    # NOTE(review): assumes transition_proba_soft is deterministic (does not draw from np.random).
    if use_soft_states:
        im_obs = df_comm_scaled["Internal_Mobility_Index"].values
        em_obs = df_comm_scaled["External_Mobility_Index"].values
    cond_rows = []
    for t in range(SEQ_LEN, T):
        if use_soft_states:
            # Approximate soft state via LR using s_{t-1} and mobility(t)
            pseudo = {
                "state_id": states_id[t-1],
                "Internal_Mobility_Index": im_obs[t],
                "External_Mobility_Index": em_obs[t],
                "Commune": comm_name
            }
            st_vec = transition_proba_soft(pd.Series(pseudo))
        else:
            st_vec = state_onehot_from_id(states_id[t], K_STATES)  # hard one-hot of observed state
        cond_rows.append(np.concatenate([comm_onehot, st_vec], axis=0))
    cond_all = torch.tensor(np.stack(cond_rows), dtype=torch.float32, device=DEVICE)  # (steps, C+K)

    # Storage for simulated trajectories (ORIGINAL units)
    sims = np.zeros((n_sims, steps), dtype=float)
    init_buf = X_all_scaled[:SEQ_LEN, :].copy()
    scaled_path = np.zeros((steps, n_feat))  # one trajectory's scaled rows, reused per sim

    for m in range(n_sims):
        x_buf = init_buf.copy()
        for k, t in enumerate(range(SEQ_LEN, T)):
            xb = torch.tensor(x_buf[None, :, :], dtype=torch.float32, device=DEVICE)
            mu_full, logvar_case, _, _ = model(xb, cond_all[k:k + 1], temp=temp_z)
            mu_full = mu_full.cpu().numpy()[0]       # (D,)
            logvar  = float(logvar_case.cpu().numpy()[0, 0])

            mu_case    = float(mu_full[0])
            sigma_case = math.sqrt(max(1e-10, math.exp(logvar))) * temp_out

            # Sample CASE in scaled space and clip to [0,1]
            y_scaled = float(np.clip(np.random.normal(mu_case, sigma_case), 0.0, 1.0))

            # Next buffer row: sampled CASE + observed (scaled) mobilities at day t.
            # Copying mu_full keeps the model's dtype, exactly as before.
            next_vec_scaled = mu_full.copy()
            next_vec_scaled[0]  = y_scaled
            next_vec_scaled[1:] = X_all_scaled[t, 1:]
            scaled_path[k, :] = next_vec_scaled

            # AR buffer update (scaled)
            x_buf = np.vstack([x_buf[1:], next_vec_scaled])

        # inverse_transform works row by row, so one call per trajectory equals per-step calls.
        sims[m, :] = scaler.inverse_transform(scaled_path)[:, 0]

    # Per-day ensemble statistics (vectorized over days)
    q05, q25, q50, q75, q95 = np.quantile(sims, [0.05, 0.25, 0.5, 0.75, 0.95], axis=0)
    mean = sims.mean(axis=0)

    return {
        "dates": dates[SEQ_LEN:],
        "truth": truth,  # ORIGINAL units
        "q05": q05, "q25": q25, "q50": q50,
        "q75": q75, "q95": q95, "mean": mean,
        "sims": sims  # (n_sims, steps)
    }

# --------- Metrics over a full curve ----------
def metrics_full_curve(bundle):
    """
    Score one full-curve simulation bundle against the observed series.

    Parameters
    ----------
    bundle : dict
        Output of ``simulate_full_curve_cvae`` (or any dict with the same keys):
        ``truth``, ``q05``, ``q25``, ``q50``, ``q75``, ``q95`` (1-D, aligned in
        time) and ``sims`` with shape ``(n_sims, len(truth))``.

    Returns
    -------
    dict
        AW90 / AW50      mean width of the central 90% / 50% band,
        COV90 / COV50    fraction of days the observation falls inside the band,
        AAD90 / AAD50    mean signed asymmetry of the band around the median,
        AAD*_abs         mean absolute asymmetry,
        CRPS             ensemble CRPS averaged over time,
        MAE_median / RMSE_median   point error of the median.
        Widths, CRPS and errors are in the units of ``truth`` (ORIGINAL units
        for bundles produced by ``simulate_full_curve_cvae``).
    """
    y   = np.asarray(bundle["truth"], dtype=float)
    q05 = np.asarray(bundle["q05"], dtype=float)
    q25 = np.asarray(bundle["q25"], dtype=float)
    q50 = np.asarray(bundle["q50"], dtype=float)
    q75 = np.asarray(bundle["q75"], dtype=float)
    q95 = np.asarray(bundle["q95"], dtype=float)
    sims = np.asarray(bundle["sims"], dtype=float)

    # AW (average width)
    AW90 = float(np.mean(q95 - q05))
    AW50 = float(np.mean(q75 - q25))

    # Coverage (POC) for central 90% and 50%
    COV90 = float(np.mean((y >= q05) & (y <= q95)))
    COV50 = float(np.mean((y >= q25) & (y <= q75)))

    # AAD (signed) and absolute AAD; evaluate each interval's asymmetry once
    asym90 = [asymmetry_degree(lo, md, hi) for lo, md, hi in zip(q05, q50, q95)]
    asym50 = [asymmetry_degree(lo, md, hi) for lo, md, hi in zip(q25, q50, q75)]
    AAD90 = float(np.mean(asym90))
    AAD90_abs = float(np.mean(np.abs(asym90)))
    AAD50 = float(np.mean(asym50))
    AAD50_abs = float(np.mean(np.abs(asym50)))

    # CRPS (ensemble) averaged over time
    CRPS = float(np.mean([crps_ens(y[t], sims[:, t]) for t in range(len(y))]))

    # Point accuracy (median)
    MAE_med  = float(np.mean(np.abs(y - q50)))
    RMSE_med = float(np.sqrt(np.mean((y - q50)**2)))

    return {
        "AW90": AW90, "COV90": COV90, "AAD90": AAD90, "AAD90_abs": AAD90_abs,
        "AW50": AW50, "COV50": COV50, "AAD50": AAD50, "AAD50_abs": AAD50_abs,
        "CRPS": CRPS, "MAE_median": MAE_med, "RMSE_median": RMSE_med
    }

# --------- Run full-curve simulations for many communes ----------
def run_full_curve_all(df_scaled, communes=None, n_sims=N_SIMS_FULL,
                       temp_z=TEMP_Z_FULL, temp_out=TEMP_OUT_FULL,
                       use_soft_states=False, out_dir=OUT_DIR):
    """
    Simulate and score all requested communes with the CVAE free-run.

    For each commune a per-time CSV (``fullcurve_<commune>.csv``) with the
    observed series, ensemble mean and quantiles is written to ``out_dir``.
    Communes too short for SEQ_LEN are skipped with a message. Afterwards a
    per-commune metrics table and a pooled table (unweighted mean over
    communes) are written and printed.

    Returns
    -------
    (results, met_df, pooled)
        results : {commune: {"bundle": ..., "metrics": ..., "csv": path}}
        met_df  : DataFrame indexed by Commune with the metric columns
        pooled  : one-row DataFrame (index "PooledMean")
        When every commune is skipped, met_df is empty (same columns) and
        pooled is a row of NaN instead of the previous KeyError.
    """
    metric_cols = ["AW90", "COV90", "AAD90", "AAD90_abs", "AW50", "COV50",
                   "AAD50", "AAD50_abs", "CRPS", "MAE_median", "RMSE_median"]
    results = {}
    metrics_rows = []

    groups = df_scaled.groupby("Commune")
    if communes is None:
        communes = list(groups.groups.keys())

    for comm in communes:
        dfc = groups.get_group(comm).sort_values("Date").reset_index(drop=True)
        bundle = simulate_full_curve_cvae(dfc, n_sims=n_sims,
                                          temp_z=temp_z, temp_out=temp_out,
                                          use_soft_states=use_soft_states)
        if bundle is None:
            print(f"[skip] {comm}: not enough data for SEQ_LEN={SEQ_LEN}")
            continue

        # Save per-time outputs (helpful for SI)
        df_out = pd.DataFrame({
            "Date": bundle["dates"],
            "Observed": bundle["truth"],
            "Mean": bundle["mean"],
            "Q05": bundle["q05"], "Q25": bundle["q25"], "Q50": bundle["q50"],
            "Q75": bundle["q75"], "Q95": bundle["q95"]
        })
        csv_path = os.path.join(out_dir, f"fullcurve_{comm.replace(' ','_')}.csv")
        df_out.to_csv(csv_path, index=False)

        # Metrics
        m = metrics_full_curve(bundle)
        m["Commune"] = comm
        metrics_rows.append(m)
        results[comm] = {"bundle": bundle, "metrics": m, "csv": csv_path}
        print(f"✓ {comm}: saved per-time CSV → {csv_path}")

    # Summary tables. Explicit columns keep this valid when metrics_rows is
    # empty (otherwise there is no "Commune" column to set as index).
    met_df = (pd.DataFrame(metrics_rows, columns=["Commune", *metric_cols])
                .set_index("Commune")
                .astype(float))
    met_path = os.path.join(out_dir, "fullcurve_metrics_by_commune.csv")
    met_df.to_csv(met_path)

    pooled = met_df.mean().to_frame().T
    pooled.index = ["PooledMean"]
    pooled_path = os.path.join(out_dir, "fullcurve_metrics_pooled.csv")
    pooled.to_csv(pooled_path)

    print("\n=== Full-curve metrics: by commune ===")
    print(met_df.round(3))
    print("\n=== Full-curve metrics: pooled mean ===")
    print(pooled.round(3))
    print(f"\n✓ Saved summary CSVs in {out_dir}")

    return results, met_df, pooled

# --------- Visualization: full-curve panels (publication-ready) ----------
def plot_fullcurve_panels(results, communes, ncols=2, band="90", out_dir=OUT_DIR, fname="fullcurve_panels.png"):
    """
    Plot observed vs CVAE (median & mean) with a central band for selected communes.

    Parameters
    ----------
    results : dict
        ``{commune: {"bundle": bundle, ...}}`` as returned by ``run_full_curve_all``.
    communes : sequence of str
        Communes to draw, one panel each; each must be a key of ``results``.
    ncols : int
        Panels per row (values below 1 are treated as 1).
    band : {"90", "50"}
        Central band to shade: q05–q95 ("90") or q25–q75 ("50").
    out_dir, fname : str
        The figure is saved to ``out_dir/fname`` at 300 dpi, then shown.

    Raises
    ------
    ValueError
        If ``band`` is not "90"/"50" or ``communes`` is empty.
    """
    if band not in ("90", "50"):
        raise ValueError(f"band must be '90' or '50', got {band!r}")
    communes = list(communes)
    if not communes:
        raise ValueError("communes must contain at least one commune to plot")

    lo_key, hi_key = ("q05", "q95") if band == "90" else ("q25", "q75")
    lab = f"CVAE {band}%"

    ncols = max(1, ncols)
    nrows = int(math.ceil(len(communes) / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(16, 5.0*nrows), sharex=False)
    axes = np.array(axes).reshape(-1)

    for ax, comm in zip(axes, communes):
        b = results[comm]["bundle"]
        dates = b["dates"]
        ax.plot(dates, b["truth"], color="black", lw=2.0, label="Observed")
        ax.fill_between(dates, b[lo_key], b[hi_key], alpha=0.2, label=lab, edgecolor='none')
        ax.plot(dates, b["q50"],  lw=2.0, linestyle="--", label="CVAE median")
        ax.plot(dates, b["mean"], lw=1.8, linestyle="-.", label="CVAE mean")
        ax.set_title(comm)
        ax.set_ylabel("Daily cases (7-day MA)")
        ax.grid(True, alpha=0.3)

    # Blank out grid cells left over when len(communes) is not a multiple of ncols
    for ax in axes[len(communes):]:
        ax.axis("off")

    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower center", ncol=4, frameon=False)
    fig.autofmt_xdate()
    fig.tight_layout(rect=[0,0.05,1,1])
    path = os.path.join(out_dir, fname)
    fig.savefig(path, dpi=300)
    plt.show()
    print(f"✓ Saved: {path}")

# ============================================================
# RUN: full-curve simulations + metrics + plots
# Every commune goes to the CSVs and the summary table; only the
# COMM_FOR_PANELS subset is drawn as panels.
# ============================================================
COMM_FOR_PANELS = TARGET_COMMUNES  # same selection as the earlier cells

results_full, table_full, pooled_full = run_full_curve_all(
    df_scaled,
    communes=None,           # None -> every commune present in df_scaled
    n_sims=N_SIMS_FULL,
    temp_z=TEMP_Z_FULL,
    temp_out=TEMP_OUT_FULL,
    use_soft_states=False,   # True -> LR-projected soft states instead of observed ones
    out_dir=OUT_DIR,
)

# Panels: first with the 90% band, then with the 50% band
plot_fullcurve_panels(
    results_full, COMM_FOR_PANELS,
    ncols=2, band="90", out_dir=OUT_DIR, fname="fullcurve_panels_90.png",
)
plot_fullcurve_panels(
    results_full, COMM_FOR_PANELS,
    ncols=2, band="50", out_dir=OUT_DIR, fname="fullcurve_panels_50.png",
)

# Paper table: every metric rounded to 3 decimals (round() returns a new frame)
display_tbl = table_full.round(dict.fromkeys(
    ["AW90", "COV90", "AAD90", "AAD90_abs",
     "AW50", "COV50", "AAD50", "AAD50_abs",
     "CRPS", "MAE_median", "RMSE_median"],
    3,
))
print("\n=== Table (rounded) — Full-curve CVAE metrics by commune ===")
print(display_tbl)
print("\n=== Pooled mean (rounded) ===")
print(pooled_full.round(3))


In [None]:
# ============================================================
# FULL-CURVE (entire series) — CVAE & ARIMA with predictive bands
# Metrics (AW/POC/AAD/CRPS/MAE/RMSE), tables and 600-dpi panels
# Requires, already in scope: df_scaled, FEATURES, SEQ_LEN, CASE_COL,
# enc_comm, scaler, K_STATES, model, DEVICE, OUT_DIR, TARGET_COMMUNES
# (plus transition_proba_soft when use_soft_states=True). OBS_COL is set below.
# ============================================================

import os, math, numpy as np, pandas as pd
import matplotlib.pyplot as plt
import torch
from statsmodels.tsa.arima.model import ARIMA

# ---------- Config ----------
N_SIMS_FULL   = 400          # ensemble size per day (used for both CVAE and ARIMA)
TEMP_Z_FULL   = 1.25         # forwarded as model(..., temp=...) in the CVAE free-run (presumably the latent-sampling temperature)
TEMP_OUT_FULL = 1.25         # multiplier on the CVAE case-channel standard deviation
OBS_COL       = "Observed_Cases"  # observed-cases column in ORIGINAL (unscaled) units
# Candidate ARIMA (p, d, q) orders; the lowest-AIC fit is kept at each step.
# NOTE: the original comment read "in case it is not defined", but this line
# always overrides any ARIMA_GRID from an earlier cell — it is not a fallback.
ARIMA_GRID    = [(1,0,0),(1,1,0),(0,1,1),(1,1,1),(2,1,1)]

# ---------- Utilidades ----------
def empirical_quants(a, qs=(0.05,0.25,0.5,0.75,0.95)):
    """
    Empirical quantiles of ``a`` (flattened) at every level in ``qs``.

    Returns a dict mapping ``float(q)`` to the (linearly interpolated)
    quantile value as a Python float.
    """
    flat = np.asarray(a).ravel()
    levels = [float(q) for q in qs]
    if not levels:
        return {}
    # One vectorised call evaluates all levels with the same interpolation rule
    values = np.quantile(flat, levels)
    return {level: float(value) for level, value in zip(levels, values)}

def crps_ens(y, samples):
    """
    Ensemble CRPS of the scalar observation ``y`` against ``samples``.

    Uses the empirical distribution of the (flattened) samples:
    CRPS = E|X - y| - 0.5 * E|X - X'|. Lower is better.
    """
    ordered = np.sort(np.asarray(samples).ravel())
    count = len(ordered)
    obs = float(y)
    mean_abs_error = np.mean(np.abs(ordered - obs))
    # The gap between sorted neighbours i-1 and i is spanned by i*(count-i)
    # ordered pairs, which gives E|X - X'| without an O(n^2) pairwise sum.
    idx = np.arange(1, count)
    pair_counts = idx * (count - idx)
    mean_abs_spread = 2.0 * np.sum(pair_counts * np.diff(ordered)) / (count*count)
    return float(mean_abs_error - 0.5*mean_abs_spread)

def asymmetry_degree(lo, med, hi):
    """
    Signed asymmetry of the interval [lo, hi] around ``med``.

    Computed as (upper arm - lower arm) / (hi - lo); positive means the upper
    arm is longer. The width is floored at 1e-12 to avoid dividing by zero.
    """
    lower, centre, upper = float(lo), float(med), float(hi)
    upper_arm = upper - centre
    lower_arm = centre - lower
    return (upper_arm - lower_arm) / max(upper - lower, 1e-12)

def state_onehot_from_id(state_id, K):
    """Length-K float vector with 1.0 at index ``int(state_id)`` and 0.0 elsewhere."""
    # Row of the identity matrix; copy so the result owns its own buffer
    return np.eye(K, dtype=float)[int(state_id)].copy()

# ---------- Simulación CVAE a lo largo de TODA la curva ----------
@torch.no_grad()
def simulate_full_curve_cvae(df_comm_scaled,
                             n_sims=N_SIMS_FULL,
                             temp_z=TEMP_Z_FULL,
                             temp_out=TEMP_OUT_FULL,
                             use_soft_states=False):
    """
    Free-run the CVAE over one commune's whole series and summarise the ensemble.

    The first SEQ_LEN rows of scaled FEATURES seed an autoregressive buffer.
    For every later day t the model is called on the buffer plus the
    commune/state conditioning vector. The CASE channel is sampled from
    Normal(mu_case, sqrt(exp(logvar)) * temp_out), clipped to [0, 1] in scaled
    space and fed back into the buffer. The mobility channels are replaced by
    the observed (scaled) values at day t.

    Parameters
    ----------
    df_comm_scaled : pandas.DataFrame
        Rows of a single commune sorted by Date, with FEATURES (scaled),
        "state_id", "Commune", "Date" and OBS_COL (ORIGINAL units).
    n_sims : int
        Number of independent trajectories.
    temp_z : float
        Forwarded to ``model(..., temp=temp_z)``.
    temp_out : float
        Multiplier on the CASE standard deviation.
    use_soft_states : bool
        If True, the state part of the conditioning comes from
        ``transition_proba_soft`` (previous observed state + day-t mobility)
        instead of the observed one-hot state at day t.

    Returns
    -------
    dict or None
        None if the series has at most SEQ_LEN rows. Otherwise a dict with:
        ``dates`` and ``truth`` (OBS_COL) from day SEQ_LEN on; per-day ensemble
        quantiles ``q05``..``q95`` and ``mean`` (ORIGINAL units); and ``sims``
        with shape (n_sims, T - SEQ_LEN).

    Raises
    ------
    ValueError
        If OBS_COL looks scaled (its max is not above 5).
    """
    if len(df_comm_scaled) <= SEQ_LEN:
        return None

    comm_name    = df_comm_scaled.iloc[0]["Commune"]
    comm_onehot  = enc_comm.transform([[comm_name]])[0]
    X_all_scaled = df_comm_scaled[FEATURES].values.astype(float)     # SCALED
    states_id    = df_comm_scaled["state_id"].values.astype(int)
    dates        = df_comm_scaled["Date"].tolist()

    T = len(df_comm_scaled)
    steps = T - SEQ_LEN
    truth = df_comm_scaled[OBS_COL].values[SEQ_LEN:]
    # Explicit check rather than `assert`, which is stripped under `python -O`
    if not np.nanmax(truth) > 5.0:
        raise ValueError("Observed_Cases parece escalado; revisa tu flujo de escala.")

    # The conditioning vector depends only on the day t, not on the trajectory,
    # so build each tensor once instead of n_sims times. NOTE(review): this
    # assumes transition_proba_soft is deterministic (e.g. LR predict_proba).
    cond_tensors = []
    for t in range(SEQ_LEN, T):
        if use_soft_states:
            pseudo = {
                "state_id": states_id[t-1],
                "Internal_Mobility_Index": df_comm_scaled.iloc[t]["Internal_Mobility_Index"],
                "External_Mobility_Index": df_comm_scaled.iloc[t]["External_Mobility_Index"],
                "Commune": comm_name
            }
            st_vec = transition_proba_soft(pd.Series(pseudo))
        else:
            st_vec = state_onehot_from_id(states_id[t], K_STATES)
        cond_vec = np.concatenate([comm_onehot, st_vec], axis=0)[None, :]
        cond_tensors.append(torch.tensor(cond_vec, dtype=torch.float32, device=DEVICE))

    init_buf = X_all_scaled[:SEQ_LEN, :].copy()
    sims = np.zeros((n_sims, steps), dtype=float)
    tmp = np.zeros((1, len(FEATURES)))  # scratch row for inverse_transform

    for m in range(n_sims):
        x_buf = init_buf.copy()
        for k, cb in enumerate(cond_tensors):
            t = SEQ_LEN + k
            xb = torch.tensor(x_buf[None, :, :], dtype=torch.float32, device=DEVICE)

            mu_full, logvar_case, _, _ = model(xb, cb, temp=temp_z)
            mu_full = mu_full.cpu().numpy()[0]
            logvar  = float(logvar_case.cpu().numpy()[0, 0])
            mu_case = float(mu_full[0])
            sigma_case = math.sqrt(max(1e-10, math.exp(logvar))) * temp_out

            # Sample CASE in scaled space and clip to [0, 1]
            y_scaled = float(np.clip(np.random.normal(mu_case, sigma_case), 0.0, 1.0))
            next_vec_scaled = mu_full.copy()
            next_vec_scaled[0]  = y_scaled
            next_vec_scaled[1:] = X_all_scaled[t, 1:]  # observed IM/EM (SCALED)

            # Store CASE in ORIGINAL units
            tmp[0, :] = next_vec_scaled
            sims[m, k] = float(scaler.inverse_transform(tmp)[0, 0])

            # AR buffer update (scaled)
            x_buf = np.vstack([x_buf[1:], next_vec_scaled])

    # Per-day ensemble summary (same linear-interpolation quantiles as empirical_quants)
    q05, q25, q50, q75, q95 = np.quantile(sims, (0.05, 0.25, 0.5, 0.75, 0.95), axis=0)

    return {
        "dates": dates[SEQ_LEN:],
        "truth": truth,
        "q05": q05, "q25": q25, "q50": q50,
        "q75": q75, "q95": q95, "mean": sims.mean(axis=0),
        "sims": sims
    }

# ---------- ARIMA rolling 1-paso con bandas para TODA la curva ----------
def _censored_normal_mean(mu, std):
    """
    Mean of max(0, X) for X ~ Normal(mu, std**2).

    Falls back to max(0, mu) when std is not a positive number; that is the
    limit of the formula as std -> 0.
    """
    if not std > 0:
        return max(0.0, mu)
    a = mu / std
    cdf = 0.5 * (1.0 + math.erf(a / math.sqrt(2.0)))
    pdf = math.exp(-0.5 * a * a) / math.sqrt(2.0 * math.pi)
    return mu * cdf + std * pdf


def arima_full_curve_with_bands(y_orig, dates, start_idx=None, n_sims=N_SIMS_FULL):
    """
    Rolling one-step-ahead ARIMA baseline with Normal predictive bands.

    At every forecast origin t in [start_idx, n-2], each order in ARIMA_GRID is
    fit on y[:t+1]. The lowest-AIC fit forecasts y[t+1]. If no order can be
    fit, a random-walk forecast is used, with the spread of recent daily
    changes.

    The predictive distribution is Normal(mu, std) censored at 0, since case
    counts cannot be negative. All quantiles (including the median), ``mean``
    and the ``sims`` ensemble come from that same censored distribution. This
    keeps bands, CRPS and point metrics consistent and the quantiles ordered.

    Parameters
    ----------
    y_orig : array-like
        Series in ORIGINAL units (use OBS_COL!).
    dates : list
        Dates aligned with ``y_orig``.
    start_idx : int, optional
        First forecast origin. Defaults to SEQ_LEN - 1, so the first target is
        y[SEQ_LEN] — aligned with simulate_full_curve_cvae.
    n_sims : int
        Samples drawn per step (used for CRPS).

    Returns
    -------
    dict
        ``dates``, ``truth`` and ``q05``..``q95``/``mean`` (length
        n - start_idx - 1, starting at the first target), plus ``sims`` with
        shape (n_sims, n - start_idx - 1).
    """
    if start_idx is None:
        start_idx = SEQ_LEN - 1
    y = np.asarray(y_orig, float)
    n = len(y)
    # Output length follows start_idx (previously hard-wired to SEQ_LEN, which
    # broke alignment / indexing for any non-default start_idx).
    first_target = start_idx + 1
    steps = max(0, n - first_target)

    mean = np.zeros(steps)
    q05  = np.zeros(steps); q25 = np.zeros(steps); q50 = np.zeros(steps)
    q75  = np.zeros(steps); q95 = np.zeros(steps)
    sims = np.zeros((n_sims, steps))

    Z05, Z25 = -1.6448536269514729, -0.6744897501960817
    Z75, Z95 =  0.6744897501960817,  1.6448536269514729

    for k, t in enumerate(range(start_idx, n - 1)):
        y_train = y[:t+1]
        best, best_aic = None, np.inf
        for order in ARIMA_GRID:
            try:
                res = ARIMA(y_train, order=order).fit()
                if res.aic < best_aic:
                    best, best_aic = res, res.aic
            except Exception:
                continue  # order not estimable on this window -> try the next one

        if best is None:
            mu = y_train[-1]
            std = np.std(np.diff(y_train[-30:])) if len(y_train) > 2 else 1.0
        else:
            fc  = best.get_forecast(steps=1)
            mu  = float(fc.predicted_mean[0])
            var = float(fc.var_pred_mean[0])
            std = math.sqrt(max(var, 1e-12))
        mu, std = float(mu), float(std)

        # Quantiles of max(0, X) are max(0, quantiles of X): clip the median too,
        # otherwise q50 < q25 whenever mu < 0.
        q05[k] = max(0.0, mu + Z05*std); q25[k] = max(0.0, mu + Z25*std)
        q50[k] = max(0.0, mu)
        q75[k] = max(0.0, mu + Z75*std); q95[k] = max(0.0, mu + Z95*std)
        mean[k] = _censored_normal_mean(mu, std)
        sims[:, k] = np.maximum(np.random.normal(mu, std, size=n_sims), 0.0)

    return {
        "dates": dates[first_target:],
        "truth": y[first_target:],
        "q05": q05, "q25": q25, "q50": q50, "q75": q75, "q95": q95,
        "mean": mean, "sims": sims
    }

# ---------- Métricas de curva completa ----------
def metrics_full_curve(bundle):
    """
    Band, calibration and point metrics of a full-curve bundle.

    ``bundle`` needs ``truth``, ``q05``/``q25``/``q50``/``q75``/``q95`` and
    ``sims`` (n_sims x len(truth)). Returns AW/COV/AAD/AAD_abs for the 90% and
    50% bands, the time-averaged ensemble CRPS and the median's MAE/RMSE.
    """
    truth = np.asarray(bundle["truth"]).astype(float)
    lo90, mid, hi90 = (np.asarray(bundle[key]) for key in ("q05", "q50", "q95"))
    lo50, hi50 = (np.asarray(bundle[key]) for key in ("q25", "q75"))
    ensemble = bundle["sims"]

    # asymmetry_degree is pure, so each interval only needs to be scored once
    asym90 = [asymmetry_degree(a, b, c) for a, b, c in zip(lo90, mid, hi90)]
    asym50 = [asymmetry_degree(a, b, c) for a, b, c in zip(lo50, mid, hi50)]
    crps_by_day = [crps_ens(truth[i], ensemble[:, i]) for i in range(len(truth))]
    residual = truth - mid

    return {
        "AW90": float(np.mean(hi90 - lo90)),
        "COV90": float(np.mean((truth >= lo90) & (truth <= hi90))),
        "AAD90": float(np.mean(asym90)),
        "AAD90_abs": float(np.mean([abs(v) for v in asym90])),
        "AW50": float(np.mean(hi50 - lo50)),
        "COV50": float(np.mean((truth >= lo50) & (truth <= hi50))),
        "AAD50": float(np.mean(asym50)),
        "AAD50_abs": float(np.mean([abs(v) for v in asym50])),
        "CRPS": float(np.mean(crps_by_day)),
        "MAE_median": float(np.mean(np.abs(residual))),
        "RMSE_median": float(np.sqrt(np.mean(residual**2))),
    }

# ---------- Ejecuta simulación de TODA la curva para todas las comunas ----------
def _bundle_to_frame(bundle):
    """Per-day table (Date, Observed, Mean, Q05..Q95) for one simulation bundle."""
    return pd.DataFrame({
        "Date": bundle["dates"], "Observed": bundle["truth"], "Mean": bundle["mean"],
        "Q05": bundle["q05"], "Q25": bundle["q25"], "Q50": bundle["q50"],
        "Q75": bundle["q75"], "Q95": bundle["q95"]
    })


def run_full_curve_both(df_scaled, communes=None, n_sims=N_SIMS_FULL,
                        temp_z=TEMP_Z_FULL, temp_out=TEMP_OUT_FULL, out_dir=OUT_DIR):
    """
    Run the CVAE free-run and the rolling ARIMA baseline for every requested commune.

    For each commune, both models produce a full-curve bundle, which is scored
    with ``metrics_full_curve`` and saved as a per-day CSV in ``out_dir``.
    Communes too short for SEQ_LEN are skipped for both models. The function
    then writes:
      - per-commune metric tables for each model;
      - a combined table (columns suffixed ``_CVAE`` / ``_ARIMA``, inner join);
      - pooled metrics (unweighted mean over communes).

    Returns
    -------
    (results_cvae, results_arima, combo, pooled)
        ``{commune: bundle}`` dicts for both models, the combined per-commune
        DataFrame and the pooled DataFrame (rows PooledMean_CVAE /
        PooledMean_ARIMA). When every commune is skipped, the tables are empty
        (pooled rows are NaN) instead of raising KeyError.
    """
    metric_cols = ["AW90", "COV90", "AAD90", "AAD90_abs", "AW50", "COV50",
                   "AAD50", "AAD50_abs", "CRPS", "MAE_median", "RMSE_median"]
    results_cvae, results_arima = {}, {}
    rows_cvae, rows_arima = [], []

    groups = df_scaled.groupby("Commune")
    if communes is None:
        communes = list(groups.groups.keys())

    for comm in communes:
        dfc = groups.get_group(comm).sort_values("Date").reset_index(drop=True)
        safe_name = comm.replace(' ', '_')

        # CVAE
        b_cvae = simulate_full_curve_cvae(dfc, n_sims=n_sims, temp_z=temp_z, temp_out=temp_out)
        if b_cvae is None:
            print(f"[skip] {comm}: insufficient length")
            continue
        results_cvae[comm] = b_cvae
        m_c = metrics_full_curve(b_cvae); m_c["Commune"] = comm; rows_cvae.append(m_c)
        _bundle_to_frame(b_cvae).to_csv(
            os.path.join(out_dir, f"fullcurve_CVAE_{safe_name}.csv"), index=False)

        # ARIMA (on OBS_COL, i.e. ORIGINAL units!)
        y_obs = dfc[OBS_COL].values
        b_ari = arima_full_curve_with_bands(y_obs, dfc["Date"].tolist(), start_idx=SEQ_LEN-1, n_sims=n_sims)
        results_arima[comm] = b_ari
        m_a = metrics_full_curve(b_ari); m_a["Commune"] = comm; rows_arima.append(m_a)
        _bundle_to_frame(b_ari).to_csv(
            os.path.join(out_dir, f"fullcurve_ARIMA_{safe_name}.csv"), index=False)

        print(f"✓ {comm}: CSV CVAE/ARIMA guardados (curva completa)")

    # Per-commune tables. Explicit columns keep this valid when every commune
    # was skipped (an empty row list has no "Commune" column to index on).
    def to_table(rows):
        return (pd.DataFrame(rows, columns=["Commune", *metric_cols])
                  .set_index("Commune")
                  .astype(float)
                  .sort_index())

    tbl_c = to_table(rows_cvae)
    tbl_a = to_table(rows_arima)
    tbl_c.to_csv(os.path.join(out_dir, "fullcurve_metrics_CVAE_by_commune.csv"))
    tbl_a.to_csv(os.path.join(out_dir, "fullcurve_metrics_ARIMA_by_commune.csv"))

    # Combined table (suffixed columns)
    combo = tbl_c.add_suffix("_CVAE").join(tbl_a.add_suffix("_ARIMA"), how="inner")
    combo.to_csv(os.path.join(out_dir, "fullcurve_metrics_COMBINED_by_commune.csv"))

    # Pooled (unweighted mean over communes)
    pooled_c = tbl_c.mean().to_frame().T; pooled_c.index = ["PooledMean_CVAE"]
    pooled_a = tbl_a.mean().to_frame().T; pooled_a.index = ["PooledMean_ARIMA"]
    pooled = pd.concat([pooled_c, pooled_a], axis=0)
    pooled.to_csv(os.path.join(out_dir, "fullcurve_metrics_pooled_CVAE_ARIMA.csv"))

    print("\n=== Métricas por comuna — CVAE ==="); print(tbl_c.round(3))
    print("\n=== Métricas por comuna — ARIMA ==="); print(tbl_a.round(3))
    print("\n=== Pooled (medias) ==="); print(pooled.round(3))
    print(f"\n✓ Tablas guardadas en {out_dir}")

    return results_cvae, results_arima, combo, pooled

# ---------- Paneles (ambas bandas y ambos modelos) ----------
def plot_fullcurve_panels_both(results_cvae, results_arima, communes=None,
                               band="90", ncols=3, out_dir=OUT_DIR,
                               fname="fullcurve_panels_CVAE_ARIMA.png"):
    """
    One panel per commune overlaying the CVAE and ARIMA full-curve forecasts.

    Each panel shows the observed series, the CVAE and ARIMA central bands
    (90%: q05–q95, 50%: q25–q75), the CVAE median and the ARIMA mean. The
    figure is saved at 600 dpi to ``out_dir/fname`` and then shown.

    Parameters
    ----------
    results_cvae, results_arima : dict
        ``{commune: bundle}`` as returned by ``run_full_curve_both``.
    communes : sequence of str, optional
        Communes to draw; defaults to every key of ``results_cvae``, sorted.
    band : {"90", "50"}
        Which central band to shade.
    ncols : int
        Panels per row (values below 1 are treated as 1).

    Raises
    ------
    ValueError
        If ``band`` is not "90"/"50" or there is nothing to plot.
    """
    if band not in ("90", "50"):
        raise ValueError(f"band must be '90' or '50', got {band!r}")
    if communes is None:
        communes = sorted(results_cvae)
    communes = list(communes)
    if not communes:
        raise ValueError("no communes to plot")

    lo_key, hi_key = ("q05", "q95") if band == "90" else ("q25", "q75")

    ncols = max(1, ncols)
    nrows = int(math.ceil(len(communes) / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(5.6*ncols, 4.2*nrows), sharex=False)
    axes = np.array(axes).reshape(-1)

    for ax, comm in zip(axes, communes):
        b_c = results_cvae[comm]; b_a = results_arima[comm]
        dates = b_c["dates"]  # both bundles start at day SEQ_LEN with default settings

        ax.plot(dates, b_c["truth"], color="black", lw=2.0, label="Observed", zorder=5)
        ax.fill_between(dates, b_c[lo_key], b_c[hi_key], alpha=0.18, label=f"CVAE {band}%",
                        edgecolor='none', zorder=1)
        ax.fill_between(dates, b_a[lo_key], b_a[hi_key], alpha=0.18, label=f"ARIMA {band}%",
                        edgecolor='none', zorder=2)
        ax.plot(dates, b_c["q50"],  lw=2.0, linestyle="--", label="CVAE median", zorder=6)
        ax.plot(dates, b_a["mean"], lw=2.0, linestyle="-.", label="ARIMA mean",  zorder=6)
        ax.set_title(comm)
        ax.set_ylabel("Daily cases (7-day MA)")
        ax.grid(True, alpha=0.3)

    # Blank out grid cells left over when len(communes) is not a multiple of ncols
    for ax in axes[len(communes):]:
        ax.axis("off")

    # Global legend
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower center", ncol=5, frameon=False)
    fig.autofmt_xdate()
    fig.tight_layout(rect=[0,0.05,1,1])
    path = os.path.join(out_dir, fname)
    fig.savefig(path, dpi=600)   # high resolution for publication
    plt.show()
    print(f"✓ Panel guardado: {path}")

# ============================
# Run everything (full curve)
# ============================
ALL_COMMUNES = sorted(df_scaled["Commune"].unique().tolist())  # or use TARGET_COMMUNES
res_cvae, res_arima, table_combined, pooled = run_full_curve_both(
    df_scaled,
    communes=ALL_COMMUNES,
    n_sims=N_SIMS_FULL,
    temp_z=TEMP_Z_FULL,
    temp_out=TEMP_OUT_FULL,
    out_dir=OUT_DIR,
)

# 90% and 50% panels for ALL communes (the figure gets tall when there are many)
plot_fullcurve_panels_both(
    res_cvae, res_arima, communes=ALL_COMMUNES, band="90", ncols=3,
    out_dir=OUT_DIR, fname="fullcurve_panels_90_CVAE_ARIMA.png",
)
plot_fullcurve_panels_both(
    res_cvae, res_arima, communes=ALL_COMMUNES, band="50", ncols=3,
    out_dir=OUT_DIR, fname="fullcurve_panels_50_CVAE_ARIMA.png",
)

# Quick look at the combined table (rounded) for the paper
print("\n=== Tabla combinada (redondeada) — métricas por comuna (curva completa) ===")
print(table_combined.round(3))


In [None]:
# ============================================================
# Paper figure: CVAE-only panels (Observed + CVAE median/mean + 50% band)
# - Observed: black solid line
# - CVAE 50% band: lighter blue, higher transparency
# - CVAE median: blue dashed
# - CVAE mean: orange solid (no dashes, per request)
# - High-quality export (PNG + PDF)
# ============================================================

import os
import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

# Panel selection; (re)defined unconditionally so this cell runs on its own.
TARGET_COMMUNES = [
    "La Florida",
    "Cerrillos",
    "Vitacura",
    "Providencia",
    "Las Condes",
    "Santiago",
]

def _ensure_bundles_for(communes,
                        results_dict=None,
                        n_sims=400,
                        temp_z=1.25,
                        temp_out=1.25,
                        use_soft_states=False):
    """
    Ensure we have the full-curve simulation bundle for each commune.

    Bundles are taken from ``results_dict`` (``{commune: {"bundle": ...}}``).
    When that argument is None, the global ``results_full`` is used if it
    exists. Communes missing from it are simulated on the fly with
    ``simulate_full_curve_cvae``.

    Returns
    -------
    dict
        ``{commune: bundle}``; bundle keys: dates, q05, q25, q50, q75, q95,
        mean, truth (np arrays), plus sims for freshly simulated bundles.

    Raises
    ------
    RuntimeError
        If a commune is absent from ``df_scaled`` or too short to simulate.
    """
    bundles = {}
    # Try to reuse an existing dictionary produced by run_full_curve_all
    if results_dict is None:
        try:
            results_dict = results_full  # must exist if already computed
        except NameError:
            results_dict = None

    groups = df_scaled.groupby("Commune")
    for comm in communes:
        if (results_dict is not None) and (comm in results_dict):
            bundles[comm] = results_dict[comm]["bundle"]
            continue
        # Clear error instead of a bare KeyError from get_group
        if comm not in groups.groups:
            raise RuntimeError(f"Commune '{comm}' not present in df_scaled.")
        dfc = groups.get_group(comm).sort_values("Date").reset_index(drop=True)
        b = simulate_full_curve_cvae(
            dfc, n_sims=n_sims, temp_z=temp_z, temp_out=temp_out,
            use_soft_states=use_soft_states
        )
        if b is None:
            raise RuntimeError(f"Not enough data to simulate: {comm}")
        bundles[comm] = b
    return bundles

def plot_cvae_fullcurve_panels_cv_only(
    communes=TARGET_COMMUNES,
    results_dict=None,
    ncols=2,
    fig_width=16,
    row_height=5.0,
    save_basename="paper_panels_CVAE_only_50band_with_observed",
    title="Full-curve reconstructions — Observed & CVAE (median, mean, 50% band)"
):
    """
    Paper panel figure (e.g. 3x2 for six communes), one subplot per commune, with:
      - Observed curve (black solid)
      - CVAE 50% band q25–q75 (lighter blue, more transparent)
      - CVAE median (blue dashed)
      - CVAE mean (orange solid; no dashes)

    Bundles come from ``_ensure_bundles_for`` (reused or simulated on the fly).
    Saves ``<save_basename>.png`` (600 dpi) and ``.pdf`` to OUT_DIR, then shows
    the figure. The publication rcParams apply only while this figure is built,
    saved and shown. Previously they were written into the global
    ``plt.rcParams`` and leaked into every later figure.
    """
    bundles = _ensure_bundles_for(communes, results_dict)

    n = len(communes)
    ncols = max(1, ncols)
    nrows = int(math.ceil(n / ncols))

    # Publication-oriented aesthetics (scoped via rc_context below)
    paper_rc = {
        "font.family": "DejaVu Sans",
        "axes.titlesize": 13,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 11,
        "savefig.dpi": 600,   # high-resolution export
    }

    # Colors & styles
    col_obs   = "#000000"   # observed (black)
    col_median= "#1f77b4"   # blue for median
    col_mean  = "#ff7f0e"   # orange for mean
    col_band  = "#8bb6e8"   # slightly lighter blue for the band
    alpha50   = 0.20        # clearer (more transparent) band
    lw_obs    = 2.0
    lw_line   = 2.0

    png_path = os.path.join(OUT_DIR, f"{save_basename}.png")
    pdf_path = os.path.join(OUT_DIR, f"{save_basename}.pdf")

    with plt.rc_context(paper_rc):
        fig, axes = plt.subplots(
            nrows, ncols,
            figsize=(fig_width, row_height * nrows),
            sharex=False, sharey=False
        )
        axes = np.array(axes).reshape(-1)

        for ax, comm in zip(axes, communes):
            b = bundles[comm]
            dates = b["dates"]

            # 50% band (lighter + clearer)
            ax.fill_between(dates, b["q25"], b["q75"],
                            color=col_band, alpha=alpha50, edgecolor="none")

            # Observed
            ax.plot(dates, b["truth"], color=col_obs, lw=lw_obs, label=None)

            # CVAE median (dashed) & mean (solid, no dashes)
            ax.plot(dates, b["q50"],  color=col_median, linestyle="--", lw=lw_line, label=None)
            ax.plot(dates, b["mean"], color=col_mean,   linestyle="-",  lw=lw_line, label=None)

            ax.set_title(comm)
            ax.set_ylabel("Daily cases (7-day MA)")
            ax.grid(True, alpha=0.3)

        # Hide any unused subplots (if communes % ncols != 0)
        for k in range(len(communes), len(axes)):
            axes[k].axis("off")

        # Global legend built from proxy artists (panel artists carry no labels)
        legend_elems = [
            Patch(facecolor=col_band, edgecolor="none", alpha=alpha50, label="CVAE 50% band"),
            Line2D([0], [0], color=col_median, lw=lw_line, linestyle="--", label="CVAE median"),
            Line2D([0], [0], color=col_mean,   lw=lw_line, linestyle="-",  label="CVAE mean"),
            Line2D([0], [0], color=col_obs,    lw=lw_obs,  linestyle="-",  label="Observed"),
        ]
        fig.legend(legend_elems, [h.get_label() for h in legend_elems],
                   loc="lower center", ncol=4, frameon=False)

        # Layout + optional super-title
        if title:
            fig.suptitle(title, y=0.99, fontsize=14)
            fig.tight_layout(rect=[0, 0.07, 1, 0.96])
        else:
            fig.tight_layout(rect=[0, 0.07, 1, 1])

        # Save both PNG and PDF (publication-ready vector); render inside the
        # context so the scoped rcParams also apply to the displayed figure.
        fig.savefig(png_path, dpi=600, bbox_inches="tight")
        fig.savefig(pdf_path, bbox_inches="tight")
        plt.show()

    print(f"✓ Saved: {png_path}")
    print(f"✓ Saved: {pdf_path}")

# ---- Run: 3x2 panel (2 columns x 3 rows) for the six target communes ----
plot_cvae_fullcurve_panels_cv_only(
    communes=TARGET_COMMUNES,
    # None -> reuse the global results_full when it exists, else simulate on the fly
    results_dict=None,
    ncols=2,
    fig_width=16,
    row_height=5.0,
    save_basename="paper_panels_CVAE_only_50band_with_observed",
    title="Full-curve reconstructions — Observed & CVAE (median, mean, 50% band)",
)


In [None]:
# ============================================================
# Paper figure: CVAE-only full-curve panels for target communes
# - Shows: Observed (black), CVAE median (blue, dashed), CVAE mean (orange, solid)
# - Band: central 50% (slightly lighter, aesthetic)
# - Per-panel metrics card (top-right): AW50, COV50, AAD50, CRPS, MAE, RMSE
# - High-quality export: PNG (600 dpi) and PDF
# ============================================================

import os, math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

# -------- Target communes (edit if needed; always redefined here) --------
TARGET_COMMUNES = [
    "La Florida",
    "Cerrillos",
    "Vitacura",
    "Providencia",
    "Las Condes",
    "Santiago",
]

# -------- Safe OUT_DIR fallback --------
# Keep an OUT_DIR defined by an earlier cell; otherwise default to ./eval_out.
if "OUT_DIR" not in globals():
    OUT_DIR = "./eval_out"
os.makedirs(OUT_DIR, exist_ok=True)

# --------- CRPS (ensemble) & asymmetry helpers ----------
def crps_ens(y, samples):
    """
    Continuous ranked probability score of observation ``y`` for an ensemble.

    Computed on the empirical distribution of ``samples`` (flattened) as
    E|X - y| - 0.5 * E|X - X'|; smaller values mean a better forecast.
    """
    xs = np.sort(np.asarray(samples).ravel())
    size = len(xs)
    target = float(y)
    term_obs = np.mean(np.abs(xs - target))            # E|X - y|
    k = np.arange(1, size)
    gaps = np.diff(xs)
    # Each sorted gap is crossed by k*(size-k) ordered pairs -> E|X - X'|
    term_spread = 2.0 * np.sum((k * (size - k)) * gaps) / (size*size)
    return float(term_obs - 0.5*term_spread)

def asymmetry_degree(lo, med, hi):
    """
    Signed asymmetry of [lo, hi] around ``med``: ((hi-med) - (med-lo)) / (hi-lo).

    Lies in [-1, 1] when lo <= med <= hi; > 0 means a longer upper tail.
    The width is floored at 1e-12 so a zero-width interval yields 0.
    """
    a, m, b = float(lo), float(med), float(hi)
    span = max(b - a, 1e-12)
    return ((b - m) - (m - a)) / span

# --------- Metrics from a simulation bundle ----------
def _metrics_from_bundle(bundle):
    """
    50%-band and point metrics for one full-curve bundle.

    Expects keys: truth, q25, q50, q75, and optionally sims (one trajectory per
    row, one column per day of ``truth``). Any array-like ensemble is accepted.
    A 1-D ensemble is treated as a single trajectory.

    Returns
    -------
    dict
        AW50, COV50, AAD50 (signed), CRPS (NaN when no ensemble is given),
        MAE and RMSE of the median.
    """
    y   = np.asarray(bundle["truth"], float)
    q25 = np.asarray(bundle["q25"],  float)
    q50 = np.asarray(bundle["q50"],  float)
    q75 = np.asarray(bundle["q75"],  float)

    # Width & coverage
    width50 = q75 - q25
    AW50  = float(np.mean(width50))
    COV50 = float(np.mean((y >= q25) & (y <= q75)))

    # AAD50 (signed)
    AAD50 = float(np.mean([asymmetry_degree(lo, md, hi) for lo, md, hi in zip(q25, q50, q75)]))

    # CRPS averaged in time (requires sims). Coerce to a 2-D array first: a list
    # has no `.size` (AttributeError) and a 1-D array cannot be column-indexed.
    sims = bundle.get("sims", None)
    if sims is not None:
        sims = np.atleast_2d(np.asarray(sims, float))
    if sims is not None and sims.size > 0:
        CRPS = float(np.mean([crps_ens(y[t], sims[:, t]) for t in range(len(y))]))
    else:
        CRPS = float("nan")

    # Point accuracy (vs median)
    MAE  = float(np.mean(np.abs(y - q50)))
    RMSE = float(np.sqrt(np.mean((y - q50)**2)))

    return {"AW50": AW50, "COV50": COV50, "AAD50": AAD50, "CRPS": CRPS, "MAE": MAE, "RMSE": RMSE}

# --------- Ensure bundles for communes ----------
def _ensure_bundles_for(communes,
                        results_dict=None,
                        n_sims=400,
                        temp_z=1.25,
                        temp_out=1.25,
                        use_soft_states=False):
    """
    Make sure we have the full-curve simulation bundle for each commune.

    Bundles are reused from ``results_dict`` (``{commune: {"bundle": ...}}``),
    which defaults to the global ``results_full`` if it exists. Remaining
    communes are simulated with ``simulate_full_curve_cvae``.

    Returns
    -------
    dict
        ``{commune: bundle}``; bundle keys: dates, truth, q05, q25, q50, q75,
        q95, mean, sims.

    Raises
    ------
    RuntimeError
        If df_scaled / simulate_full_curve_cvae are not defined, a commune is
        absent from df_scaled, or a series is too short to simulate.
    """
    # Need df_scaled and simulate_full_curve_cvae in scope. `from None` drops the
    # internal NameError so the traceback shows only the actionable message.
    try:
        df_scaled  # noqa
    except NameError:
        raise RuntimeError("df_scaled not found in scope. Load and prepare your dataframe first.") from None
    try:
        simulate_full_curve_cvae  # noqa
    except NameError:
        raise RuntimeError("simulate_full_curve_cvae not found. Define/import it before plotting.") from None

    bundles = {}
    # Reuse a previous dict produced by run_full_curve_all
    if results_dict is None:
        try:
            results_dict = results_full  # optional global
        except NameError:
            results_dict = None

    groups = df_scaled.groupby("Commune")
    for comm in communes:
        if (results_dict is not None) and (comm in results_dict):
            bundles[comm] = results_dict[comm]["bundle"]
            continue
        if comm not in groups.groups:
            raise RuntimeError(f"Commune '{comm}' not present in df_scaled.")
        dfc = groups.get_group(comm).sort_values("Date").reset_index(drop=True)
        b = simulate_full_curve_cvae(
            dfc, n_sims=n_sims, temp_z=temp_z, temp_out=temp_out,
            use_soft_states=use_soft_states
        )
        if b is None:
            raise RuntimeError(f"Not enough data to simulate: {comm}")
        bundles[comm] = b
    return bundles

# --------- Plot panels (with metrics box top-right) ----------
def plot_cvae_fullcurve_panels_with_metrics(
    communes=TARGET_COMMUNES,
    results_dict=None,
    ncols=2,
    fig_width=16,
    row_height=5.0,
    save_basename="paper_panels_CVAE_50band_with_observed_metrics",
    title="Full-curve reconstructions — Observed & CVAE (median, mean, 50% band)"
):
    """
    Build a multi-panel figure (e.g. 3x2 for 6 communes) of CVAE full-curve reconstructions.

    Each panel shows:
      - Observed (black solid), CVAE median (blue dashed), CVAE mean (orange solid)
      - the 50% central band (q25..q75, light and subtle)
      - a metrics card (AW50, COV50, AAD50, CRPS, MAE, RMSE) in the top-right corner,
        computed by ``_metrics_from_bundle``.

    Parameters
    ----------
    communes : sequence of commune names, one panel each.
    results_dict : optional dict {commune: {"bundle": ...}} of precomputed bundles,
        forwarded to ``_ensure_bundles_for`` (missing communes are simulated).
    ncols : number of panel columns (values < 1 are treated as 1).
    fig_width, row_height : figure width and per-row height (inches).
    save_basename : file stem of the PNG (600 dpi) and PDF written to OUT_DIR.
    title : figure suptitle; falsy to omit.

    Raises
    ------
    ValueError
        If ``communes`` is empty (previously this surfaced as an opaque
        ``plt.subplots`` error about zero rows).
    """
    # Legend proxy artists. Imported locally so the function does not depend on a
    # module-level import (the Patch/Line2D imports visible in this notebook are in
    # a later cell than the one that calls this function).
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch

    n = len(communes)
    if n == 0:
        raise ValueError("No communes to plot: 'communes' is empty.")

    bundles = _ensure_bundles_for(communes, results_dict)

    ncols = max(1, ncols)
    nrows = int(math.ceil(n / ncols))

    # Publication-oriented aesthetics (note: mutates global rcParams)
    plt.rcParams.update({
        "font.family": "DejaVu Sans",
        "axes.titlesize": 13,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 11,
        "savefig.dpi": 600,   # high-resolution export
    })

    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(fig_width, row_height * nrows),
        sharex=False, sharey=False
    )
    # Flatten so a single axis, a row, or a grid are all iterated the same way.
    axes = np.array(axes).reshape(-1)

    # Colors & styles
    col_obs    = "#000000"   # observed
    col_median = "#1f77b4"   # median
    col_mean   = "#ff7f0e"   # mean (solid)
    col_band   = "#b7c1c9"   # 50% band
    alpha50    = 0.40
    lw_obs     = 2.0
    lw_line    = 2.0

    for ax, comm in zip(axes, communes):
        b = bundles[comm]
        dates = b["dates"]

        # 50% band first so the lines are drawn on top of it.
        ax.fill_between(dates, b["q25"], b["q75"],
                        color=col_band, alpha=alpha50, edgecolor="none")

        # Lines: observed (solid), median (dashed), mean (solid)
        ax.plot(dates, b["truth"], color=col_obs,   lw=lw_obs)
        ax.plot(dates, b["q50"],   color=col_median, linestyle="--", lw=lw_line)
        ax.plot(dates, b["mean"],  color=col_mean,   linestyle="-",  lw=lw_line)

        ax.set_title(comm)
        ax.set_ylabel("Daily cases (7-day MA)")
        ax.grid(True, alpha=0.7)

        # ---------- Metrics card (top-right) ----------
        m = _metrics_from_bundle(b)
        lines = [
            f"AW50: {m['AW50']:.1f}",
            f"COV50: {m['COV50']:.3f}",
            f"AAD50: {m['AAD50']:.3f}",
            f"CRPS: {m['CRPS']:.2f}" if not np.isnan(m["CRPS"]) else "CRPS: —",
            f"MAE: {m['MAE']:.2f}",
            f"RMSE: {m['RMSE']:.2f}",
        ]
        txt = "\n".join(lines)

        ax.text(
            0.98, 0.98, txt,                 # top-right corner in axes coords
            transform=ax.transAxes,
            va="top", ha="right",
            fontsize=12.0,
            color="#222222",
            bbox=dict(boxstyle="round,pad=0.35,rounding_size=0.12",
                      fc="white", ec="#4c4d4f", lw=0.9, alpha=0.88)
        )

    # Hide any unused subplots (if the grid has more cells than communes)
    for k in range(len(communes), len(axes)):
        axes[k].axis("off")

    # Global legend built from proxy artists (panels themselves carry no labels)
    legend_elems = [
        Patch(facecolor=col_band, edgecolor="none", alpha=alpha50, label="CVAE 50% band"),
        Line2D([0], [0], color=col_median, lw=lw_line, linestyle="--", label="CVAE median"),
        Line2D([0], [0], color=col_mean,   lw=lw_line, linestyle="-",  label="CVAE mean"),
        Line2D([0], [0], color=col_obs,    lw=lw_obs,  linestyle="-",  label="Observed"),
    ]
    fig.legend(legend_elems, [h.get_label() for h in legend_elems],
               loc="lower center", ncol=4, frameon=False)

    # Layout: reserve the bottom strip for the legend and, if present, the top for the title.
    if title:
        fig.suptitle(title, y=0.99, fontsize=14)
        fig.tight_layout(rect=[0, 0.07, 1, 0.96])
    else:
        fig.tight_layout(rect=[0, 0.07, 1, 1])

    png_path = os.path.join(OUT_DIR, f"{save_basename}.png")
    pdf_path = os.path.join(OUT_DIR, f"{save_basename}.pdf")
    fig.savefig(png_path, dpi=600, bbox_inches="tight")
    fig.savefig(pdf_path, bbox_inches="tight")
    plt.show()
    print(f"✓ Saved: {png_path}")
    print(f"✓ Saved: {pdf_path}")

# ---- Run: 3×2 figure (2 columns × 3 rows) for the six target communes ----
# results_dict=None simulates bundles on the fly; pass results_full to reuse them.
plot_cvae_fullcurve_panels_with_metrics(
    title="Full-curve reconstructions — Observed & CVAE (median, mean, 50% band)",
    save_basename="paper_panels_CVAE_50band_with_observed_metrics",
    communes=TARGET_COMMUNES,
    results_dict=None,
    fig_width=16,
    row_height=5.0,
    ncols=2,
)


In [None]:
# ============================================================
# Cori Rt (EpiEstim-like) from observed & CVAE-simulated curves
# - Discretized serial interval (Gamma) with mean/sd configurable
# - Sliding-window Cori posterior (Gamma) with user priors
# - Rt for observed series + Rt across CVAE simulations (quantiles)
# - CSV export + publication-quality panel plots (3×2)
# ============================================================

import os, math, numpy as np, pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# ------------------------------------------------
# Configuration for the Cori R_t analysis (tweakable)
# ------------------------------------------------
# Serial interval (days): Gamma mean and std, plus the truncation of its
# discrete support used when building the daily weights w_s.
SI_MEAN, SI_SD, SI_MAX = 5.0, 2.0, 30

# Cori estimator: sliding-window length, and a Gamma prior on R_t with
# shape A0 and rate B0 (prior mean = A0 / B0).
TAU = 7
A0, B0 = 1.0, 1.0

# Tail mass of the credible intervals: 0.05 -> 95% intervals.
ALPHA_CR = 0.05

# Colors for the R_t panels
COL_OBS_LINE, COL_OBS_BAND = "#111111", "#777777"      # observed: median line, 95% band
COL_CVAE_LINE, COL_CVAE_BAND = "#1f77b4", "#1f77b433"  # CVAE: median line, band (RGBA hex, translucent)

# ----------------------------------------------------------
# (Optional) ensure we have CVAE bundles if results_full is missing
# ----------------------------------------------------------
def ensure_cvae_bundles(communes, n_sims=400, temp_z=1.25, temp_out=1.25, use_soft_states=False):
    """
    Collect one full-curve CVAE bundle per commune.

    Bundles already stored in a global ``results_full`` (as
    ``results_full[commune]["bundle"]``) are reused; every other commune is
    simulated from its rows of ``df_scaled`` (sorted by Date) with
    ``simulate_full_curve_cvae`` and the given sampling settings.

    Returns: dict {commune: bundle}.
    Raises: RuntimeError when a simulation returns None.
    """
    try:
        cached = results_full  # optional global
    except NameError:
        cached = None
    by_commune = df_scaled.groupby("Commune")

    def _simulate(name):
        # Sorted, re-indexed slice of one commune -> simulated bundle (or error).
        frame = by_commune.get_group(name).sort_values("Date").reset_index(drop=True)
        bundle = simulate_full_curve_cvae(
            frame,
            n_sims=n_sims,
            temp_z=temp_z,
            temp_out=temp_out,
            use_soft_states=use_soft_states,
        )
        if bundle is None:
            raise RuntimeError(f"Not enough data to simulate: {name}")
        return bundle

    out = {}
    for name in communes:
        if cached is not None and name in cached:
            out[name] = cached[name]["bundle"]
        else:
            out[name] = _simulate(name)
    return out

# ----------------------------------------
# Discretized Gamma PMF for serial interval
# ----------------------------------------
def _gamma_k_mean_sd(mean, sd):
    # shape k, scale theta for Gamma(k, theta) with given mean/sd
    theta = (sd**2) / mean
    k = mean / theta
    return k, theta

def discretized_serial_interval(mean=SI_MEAN, sd=SI_SD, max_days=SI_MAX, eps=1e-12):
    """
    Discretize a Gamma(mean, sd) serial interval into daily weights.

    The result is 0-based: ``w[j]`` is the weight for a lag of ``j + 1`` days.
    It equals the Gamma density at the midpoint ``j + 0.5`` of the daily bin
    ``[j, j + 1]`` (midpoint rule, a SciPy-free approximation of Gamma CDF
    differences), renormalized so the ``max_days`` weights sum to ~1.

    Parameters
    ----------
    mean, sd : float
        Mean and standard deviation of the serial interval (days).
    max_days : int
        Number of daily lags kept (support truncation).
    eps : float
        Floor for the log argument and stabilizer in the normalizing sum.

    Returns
    -------
    np.ndarray of shape (max_days,)
    """
    k, theta = _gamma_k_mean_sd(mean, sd)
    xs = np.arange(0.5, max_days + 0.5)  # bin midpoints 0.5, 1.5, ..., max_days - 0.5
    # Gamma pdf f(x) = x^{k-1} e^{-x/theta} / (Gamma(k) theta^k), evaluated in the
    # log domain so large shapes do not overflow Gamma(k) or x^{k-1}.
    log_norm = math.lgamma(k) + k * math.log(theta)
    pdf = np.exp((k - 1) * np.log(np.clip(xs, eps, None)) - xs / theta - log_norm)
    return pdf / (pdf.sum() + eps)

# ----------------------------------------
# Infectivity term Λ_t by discrete convolution
# ----------------------------------------
def infectivity_lambda(I, w):
    """
    Total infectiousness Λ_t of the renewal equation.

    Parameters
    ----------
    I : array-like, length T
        Incidence (non-negative floats expected).
    w : array-like, length L
        Serial-interval weights, 0-based: ``w[j]`` is the weight for lag ``j + 1``.

    Returns
    -------
    np.ndarray of length T with
        Λ[t] = sum_{s=1}^{min(t, L)} I[t-s] * w[s-1]
    (so Λ[0] = 0). Empty ``I`` or ``w`` yields zeros.
    """
    I = np.asarray(I, dtype=float)
    w = np.asarray(w, dtype=float)
    T = len(I)
    if T == 0 or w.size == 0:
        return np.zeros(T, dtype=float)
    # Prepend a zero weight for lag 0 so day t's own incidence never enters Λ_t.
    # The first T entries of the full convolution are then exactly the causal sums
    # above (terms with t - s < 0 simply do not exist), replacing an O(T*L) Python loop.
    kernel = np.concatenate([[0.0], w])
    return np.convolve(I, kernel)[:T]

# ----------------------------------------
# Cori posterior for a sliding window
# ----------------------------------------
def _gamma_quantiles_wh(q, shape, rate):
    """Wilson–Hilferty approximation of Gamma(shape, rate) quantiles (SciPy-free fallback).

    If X ~ Gamma(k, 1) then (X / k)^(1/3) is approximately Normal(1 - 1/(9k), 1/(9k)),
    so x_q ≈ k * (1 - 1/(9k) + z_q * sqrt(1/(9k)))^3 / rate, with z_q the standard
    normal quantile of q. The cube base is clipped at 0 so small shapes cannot give
    negative quantiles.
    """
    from statistics import NormalDist
    k = np.maximum(np.asarray(shape, dtype=float), 1e-6)
    z = NormalDist().inv_cdf(q)
    base = 1.0 - 1.0 / (9.0 * k) + z * np.sqrt(1.0 / (9.0 * k))
    return k * np.clip(base, 0.0, None) ** 3 / np.asarray(rate, dtype=float)


def cori_posterior_rt(I, w, tau=TAU, a0=A0, b0=B0, alpha=ALPHA_CR, eps=1e-12):
    """
    EpiEstim-like Cori estimate of R_t with a Gamma(a0, b0) prior (rate parameterization).

    For each day t >= tau - 1 the posterior is Gamma(a_post, b_post) with
        a_post = a0 + max(sum_{k=t-tau+1..t} I_k, 0)
        b_post = b0 + max(sum_{k=t-tau+1..t} Λ_k, 0) + eps,   Λ_k = sum_{s>=1} I_{k-s} w_s
    where Λ comes from ``infectivity_lambda``. Earlier days are NaN.

    Quantiles use ``scipy.stats.gamma`` when SciPy is importable, otherwise a
    Wilson–Hilferty approximation (accurate for moderately large shapes).

    Parameters
    ----------
    I : array-like, length T     incidence.
    w : array-like               serial-interval weights (0-based, w[j] -> lag j+1).
    tau : int >= 1               window length (days).
    a0, b0 : float               prior shape and rate.
    alpha : float                tail mass of the central credible interval.
    eps : float                  added to the rate to avoid zero.

    Returns
    -------
    DataFrame with columns Rt_mean, Rt_median, Rt_qlo, Rt_qhi, shape, rate.

    Raises
    ------
    ValueError if tau < 1.
    """
    if tau < 1:
        raise ValueError(f"tau must be >= 1, got {tau}")
    I = np.asarray(I, dtype=float)
    w = np.asarray(w, dtype=float)
    T = len(I)
    lam = infectivity_lambda(I, w)

    means  = np.full(T, np.nan)
    med    = np.full(T, np.nan)
    qlo    = np.full(T, np.nan)
    qhi    = np.full(T, np.nan)
    ashape = np.full(T, np.nan)
    arate  = np.full(T, np.nan)

    # Guarded import: without SciPy we fall back to the Wilson–Hilferty approximation
    # (previously an unguarded import made the fallback unreachable).
    try:
        from scipy.stats import gamma as sg
    except ImportError:
        sg = None

    t_idx = np.arange(tau - 1, T)
    if t_idx.size:
        # Cumulative sums with a leading zero: the window [t-tau+1, t] sums to
        # c[t+1] - c[t+1-tau], and t >= tau-1 keeps the lower index >= 0.
        cI = np.concatenate([[0.0], np.cumsum(I)])
        cL = np.concatenate([[0.0], np.cumsum(lam)])
        i_sum = cI[t_idx + 1] - cI[t_idx + 1 - tau]
        l_sum = cL[t_idx + 1] - cL[t_idx + 1 - tau]

        a_post = a0 + np.maximum(i_sum, 0.0)
        b_post = b0 + np.maximum(l_sum, 0.0) + eps  # avoid zero-rate

        means[t_idx]  = a_post / b_post
        ashape[t_idx] = a_post
        arate[t_idx]  = b_post

        if sg is not None:
            # SciPy uses shape k and scale = 1/rate
            scale = 1.0 / b_post
            med[t_idx] = sg.ppf(0.5, a_post, scale=scale)
            qlo[t_idx] = sg.ppf(alpha / 2.0, a_post, scale=scale)
            qhi[t_idx] = sg.ppf(1.0 - alpha / 2.0, a_post, scale=scale)
        else:
            med[t_idx] = _gamma_quantiles_wh(0.5, a_post, b_post)
            qlo[t_idx] = _gamma_quantiles_wh(alpha / 2.0, a_post, b_post)
            qhi[t_idx] = _gamma_quantiles_wh(1.0 - alpha / 2.0, a_post, b_post)

    out = pd.DataFrame({
        "Rt_mean": means, "Rt_median": med, "Rt_qlo": qlo, "Rt_qhi": qhi,
        "shape": ashape, "rate": arate
    })
    return out

# ----------------------------------------------------
# Rt for observed series and for CVAE simulations
# ----------------------------------------------------
def rt_observed_for_commune(commune, w=None, tau=TAU, a0=A0, b0=B0):
    """
    Cori R_t posterior summary for the observed incidence of one commune.

    The commune's rows of ``df_scaled`` are sorted by Date, and column
    ``OBS_COL`` is used as incidence. When ``w`` is None the default discretized
    serial interval is used.

    Returns the ``cori_posterior_rt`` table with leading "Date" and "Commune" columns.
    """
    subset = df_scaled.loc[df_scaled["Commune"] == commune]
    subset = subset.sort_values("Date").reset_index(drop=True)
    day_index = subset["Date"].values
    incidence = np.asarray(subset[OBS_COL].values, dtype=float)
    si_weights = discretized_serial_interval() if w is None else w

    table = cori_posterior_rt(incidence, si_weights, tau=tau, a0=a0, b0=b0)
    table.insert(0, "Date", day_index)
    table.insert(1, "Commune", commune)
    return table

def rt_cvae_for_commune(commune, bundles=None, w=None, tau=TAU, a0=A0, b0=B0,
                        qlo=0.025, qhi=0.975):
    """
    Cori R_t across CVAE simulations for one commune.

    Each simulated path in ``bundle["sims"]`` (shape (M, steps)) is stitched to
    the observed prefix of length ``T - steps`` (T = number of observed days),
    giving a full-length incidence series. For each series the Cori posterior
    MEAN is computed, and the reported band and median are quantiles of those
    per-simulation posterior means (i.e. spread across simulations, not the
    per-day posterior uncertainty).

    Parameters
    ----------
    commune : str
    bundles : dict {commune: bundle}, optional (obtained with ensure_cvae_bundles if None)
    w : serial-interval weights; default discretized_serial_interval()
    tau, a0, b0 : Cori window and prior, forwarded to cori_posterior_rt
    qlo, qhi : quantile levels for the band (default 2.5% / 97.5%)

    Returns
    -------
    DataFrame aligned with ``bundle["dates"]`` (the simulated days):
    Date, Commune, Rt_cvae_median, Rt_cvae_qlo, Rt_cvae_qhi.

    Raises
    ------
    ValueError if the sims are not 2-D or are longer than the observed series.
    """
    if bundles is None:
        bundles = ensure_cvae_bundles([commune])
    b = bundles[commune]
    sims = np.asarray(b["sims"], dtype=float)
    if sims.ndim != 2:
        raise ValueError(f"Expected 2-D sims (M, steps) for {commune}, got shape {sims.shape}.")
    n_sims, steps = sims.shape

    dfc = df_scaled[df_scaled["Commune"] == commune].sort_values("Date").reset_index(drop=True)
    T_full = len(dfc)
    # Derive the observed-prefix length from the data instead of assuming SEQ_LEN,
    # matching how the sims were produced (same approach as _cori_posterior_from_sims).
    hist = T_full - steps
    if hist < 0:
        raise ValueError(
            f"Simulations for {commune} have {steps} steps but only {T_full} observed days."
        )
    I_prefix = np.asarray(dfc[OBS_COL].values[:hist], dtype=float)

    if w is None:
        w = discretized_serial_interval()

    # Posterior-mean R_t path for every simulation (full length, prefix included).
    Rt_mean_all = np.full((n_sims, T_full), np.nan)
    for m in range(n_sims):
        I_full = np.concatenate([I_prefix, sims[m]])
        Rt_mean_all[m] = cori_posterior_rt(I_full, w, tau=tau, a0=a0, b0=b0)["Rt_mean"].to_numpy()

    # Report only the simulated part so rows align with b["dates"].
    Rt_qlo, Rt_med, Rt_qhi = np.nanquantile(Rt_mean_all[:, hist:], [qlo, 0.5, qhi], axis=0)

    out = pd.DataFrame({
        "Date": np.asarray(b["dates"]),
        "Commune": commune,
        "Rt_cvae_median": Rt_med,
        "Rt_cvae_qlo": Rt_qlo,
        "Rt_cvae_qhi": Rt_qhi
    })
    return out

# ----------------------------------------------------
# Batch: compute & save CSVs for target communes
# ----------------------------------------------------
def compute_and_save_rt_all(communes=TARGET_COMMUNES,
                            tau=TAU, a0=A0, b0=B0,
                            si_mean=SI_MEAN, si_sd=SI_SD, si_max=SI_MAX,
                            out_dir=OUT_DIR, bundles=None):
    """
    Compute Cori R_t (observed and CVAE) for each commune and write CSVs.

    Per commune it writes ``rt_observed_<commune>.csv`` and ``rt_cvae_<commune>.csv``
    (spaces replaced by underscores), plus the stacked ``rt_observed_all.csv`` and
    ``rt_cvae_all.csv``. ``out_dir`` is created if missing.

    Parameters
    ----------
    bundles : dict {commune: bundle}, optional
        Precomputed CVAE bundles to reuse; when None they are obtained with
        ``ensure_cvae_bundles`` (which simulates communes not in ``results_full``).

    Returns
    -------
    (obs_all, cvae_all, w): stacked DataFrames (empty if ``communes`` is empty)
    and the serial-interval weights used.
    """
    os.makedirs(out_dir, exist_ok=True)
    w = discretized_serial_interval(mean=si_mean, sd=si_sd, max_days=si_max)
    if bundles is None:
        bundles = ensure_cvae_bundles(communes)

    rows_obs, rows_cvae = [], []
    for comm in communes:
        slug = comm.replace(' ', '_')

        R_obs = rt_observed_for_commune(comm, w=w, tau=tau, a0=a0, b0=b0)
        R_obs.to_csv(os.path.join(out_dir, f"rt_observed_{slug}.csv"), index=False)
        rows_obs.append(R_obs)

        R_cv = rt_cvae_for_commune(comm, bundles=bundles, w=w, tau=tau, a0=a0, b0=b0)
        R_cv.to_csv(os.path.join(out_dir, f"rt_cvae_{slug}.csv"), index=False)
        rows_cvae.append(R_cv)

    # pd.concat raises on an empty list, so fall back to empty frames.
    obs_all = pd.concat(rows_obs, ignore_index=True) if rows_obs else pd.DataFrame()
    cvae_all = pd.concat(rows_cvae, ignore_index=True) if rows_cvae else pd.DataFrame()
    obs_all.to_csv(os.path.join(out_dir, "rt_observed_all.csv"), index=False)
    cvae_all.to_csv(os.path.join(out_dir, "rt_cvae_all.csv"), index=False)
    print(f"✓ Saved Rt CSVs in {out_dir}")
    return obs_all, cvae_all, w

# ----------------------------------------------------
# Plot panels (3×2) of Rt for target communes
# ----------------------------------------------------
def plot_rt_panels(communes=TARGET_COMMUNES, tau=TAU, a0=A0, b0=B0,
                   si_mean=SI_MEAN, si_sd=SI_SD, si_max=SI_MAX,
                   out_path=None, title=r"Time-varying reproduction number $R_t$ (Cori)",
                   bundles=None, ncols=2):
    """
    Plot Cori R_t panels (observed vs. CVAE) per commune and save a PNG.

    Each panel shows the observed posterior median with its credible band
    (grey), the CVAE median with its across-simulation band (blue), and a
    dashed R_t = 1 reference.

    Parameters
    ----------
    communes : sequence of commune names (one panel each; must be non-empty).
    tau, a0, b0 : Cori window and prior.
    si_mean, si_sd, si_max : serial-interval parameters.
    out_path : PNG path (default OUT_DIR/rt_cori_panels.png).
    title : figure suptitle; falsy to omit.
    bundles : dict {commune: bundle}, optional
        Precomputed CVAE bundles. Passing the bundles used for the CSV export
        makes the plot use the same simulations; when None they are obtained with
        ``ensure_cvae_bundles`` (which re-simulates communes not in ``results_full``).
    ncols : number of panel columns (default 2; values < 1 are treated as 1).

    Raises
    ------
    ValueError if ``communes`` is empty.
    """
    n = len(communes)
    if n == 0:
        raise ValueError("No communes to plot: 'communes' is empty.")
    ncols = max(1, int(ncols))
    nrows = int(math.ceil(n / ncols))

    w = discretized_serial_interval(mean=si_mean, sd=si_sd, max_days=si_max)
    if bundles is None:
        bundles = ensure_cvae_bundles(communes)

    # Aesthetics for paper (note: mutates global rcParams)
    plt.rcParams.update({
        "font.family": "DejaVu Sans",
        "axes.titlesize": 13,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 11,
        "savefig.dpi": 600,
    })

    fig, axes = plt.subplots(nrows, ncols, figsize=(16, 4.6*nrows), sharex=False)
    axes = np.array(axes).reshape(-1)

    legend_elems = [
        Patch(facecolor=COL_OBS_BAND, edgecolor="none", alpha=0.25, label="Observed 95% CI"),
        Line2D([0],[0], color=COL_OBS_LINE, lw=2.0, label="Observed median"),
        Patch(facecolor=COL_CVAE_BAND, edgecolor="none", alpha=1.0, label="CVAE band (sim median)"),
        Line2D([0],[0], color=COL_CVAE_LINE, lw=2.0, label="CVAE median"),
    ]

    for ax, comm in zip(axes, communes):
        # observed: band first, then median on top
        R_obs = rt_observed_for_commune(comm, w=w, tau=tau, a0=a0, b0=b0)
        ax.fill_between(R_obs["Date"], R_obs["Rt_qlo"], R_obs["Rt_qhi"],
                        color=COL_OBS_BAND, alpha=0.25, edgecolor="none")
        ax.plot(R_obs["Date"], R_obs["Rt_median"], color=COL_OBS_LINE, lw=2.0)

        # cvae (spread across simulations; band color carries its own alpha)
        R_cv = rt_cvae_for_commune(comm, bundles=bundles, w=w, tau=tau, a0=a0, b0=b0)
        ax.fill_between(R_cv["Date"], R_cv["Rt_cvae_qlo"], R_cv["Rt_cvae_qhi"],
                        color=COL_CVAE_BAND, edgecolor="none")
        ax.plot(R_cv["Date"], R_cv["Rt_cvae_median"], color=COL_CVAE_LINE, lw=2.0)

        ax.axhline(1.0, color="#444444", lw=1.2, ls="--", alpha=0.7)
        ax.set_title(comm)
        ax.set_ylabel(r"$R_t$")
        ax.grid(True, alpha=0.3)

    # hide unused axes
    for k in range(len(communes), len(axes)):
        axes[k].axis("off")

    fig.legend(legend_elems, [h.get_label() for h in legend_elems],
               loc="lower center", ncol=4, frameon=False)

    if title:
        fig.suptitle(title, y=0.99, fontsize=14)
        fig.tight_layout(rect=[0,0.05,1,0.96])
    else:
        fig.tight_layout(rect=[0,0.05,1,1])

    if out_path is None:
        out_path = os.path.join(OUT_DIR, "rt_cori_panels.png")
    fig.savefig(out_path, dpi=600, bbox_inches="tight")
    plt.show()
    print(f"✓ Saved: {out_path}")

# -------------------------
# Run it
# -------------------------
# Compute + export the Cori R_t CSVs, then draw the observed-vs-CVAE panels,
# both using the configuration constants defined above.
obs_all, cvae_all, w_si = compute_and_save_rt_all(
    out_dir=OUT_DIR,
    communes=TARGET_COMMUNES,
    si_mean=SI_MEAN, si_sd=SI_SD, si_max=SI_MAX,
    tau=TAU, a0=A0, b0=B0,
)
plot_rt_panels(
    title=r"Time-varying reproduction number $R_t$ (Cori): observed vs. CVAE",
    out_path=os.path.join(OUT_DIR, "rt_cori_panels.png"),
    communes=TARGET_COMMUNES,
    si_mean=SI_MEAN, si_sd=SI_SD, si_max=SI_MAX,
    tau=TAU, a0=A0, b0=B0,
)


In [None]:
# ------------------ Robust setup for OBS_COL + Wallinga–Teunis run ------------------
import os
import numpy as np

# 1) Ensure df_scaled has a column with observed incidence in ORIGINAL (unscaled) units.
#    If it does not exist yet, it is populated from the source column CASE_COL of the
#    unscaled dataframe `df`.
if 'OBS_COL' not in globals():
    OBS_COL = "Observed_Cases"   # name of the observed-incidence column

# If df_scaled does not have the observed column yet, create it.
if OBS_COL not in df_scaled.columns:
    # Look both names up defensively: CASE_COL or df may be undefined in this session,
    # which previously raised a NameError (even while building the error message).
    _df_raw = globals().get("df")
    _case_col = globals().get("CASE_COL")
    if (_df_raw is not None) and (_case_col is not None) and (_case_col in _df_raw.columns):
        # Values are copied positionally (.values), so df must hold the same rows in
        # the same order as df_scaled; at least guard against a length mismatch.
        if len(_df_raw) != len(df_scaled):
            raise RuntimeError(
                f"Cannot copy '{_case_col}' into {OBS_COL}: df has {len(_df_raw)} rows "
                f"but df_scaled has {len(df_scaled)}."
            )
        df_scaled[OBS_COL] = _df_raw[_case_col].values
    else:
        raise RuntimeError(
            f"No encuentro datos originales para crear {OBS_COL}. "
            f"Asegúrate de tener 'df' sin escalar y la columna CASE_COL='{_case_col}'."
        )

# Minimal cleaning and checks
df_scaled[OBS_COL] = np.asarray(df_scaled[OBS_COL], float)
df_scaled[OBS_COL] = np.clip(df_scaled[OBS_COL], a_min=0.0, a_max=None)  # no negative counts
# Explicit check instead of `assert`, which is skipped under `python -O`.
if not df_scaled.groupby("Commune")[OBS_COL].apply(lambda s: s.notna().all()).all():
    raise ValueError(f"Hay NaNs en {OBS_COL} para alguna comuna.")

# 2) Serial-interval parameters and paths.
# NOTE: this rebinds the SI_MEAN / SI_SD / SI_MAX globals set in the Cori cell; code
# that reads these globals afterwards will see the values below.
SI_MEAN = 4.7
SI_SD   = 2.9
SI_MAX  = 28
B_SAMPLES = 400

WT_DIR = os.path.join(OUT_DIR, "WT")
os.makedirs(WT_DIR, exist_ok=True)

# 3) Run Wallinga–Teunis on the column named by OBS_COL (pass the variable, not a literal).
WT_RES = run_wt_all(
    df_scaled,
    communes=TARGET_COMMUNES,
    si_mean=SI_MEAN,
    si_sd=SI_SD,
    si_max=SI_MAX,
    mode="sampled",               # "expected" = faster (point curve), "sampled" = with bands
    B=B_SAMPLES,
    correct_right_censor=True,
    out_dir=WT_DIR,
    use_column=OBS_COL
)

# 4) High-quality panel figure (2 columns -> 3x2 for six communes)
WT_FIG_PATH = os.path.join(WT_DIR, "wt_panels.png")
plot_wt_panels(
    WT_RES,
    TARGET_COMMUNES,
    ncols=2,
    fig_width=16,
    row_height=5.0,
    title="Wallinga–Teunis case reproduction number (observed incidence)",
    save_path=WT_FIG_PATH
)

print(f"\n✓ Listo. CSVs en: {WT_DIR}")
print(f"✓ Figura: {WT_FIG_PATH}")


In [None]:
# ============================================================
# Unified Rt panels: Cori (observed), Cori (CVAE), WT — per commune
# - Bands: 95% (Cori-observed, gray), 90% (Cori-CVAE, blue), 90% (WT, green)
# - Lines: medians (solid), plus Rt=1 reference
# - Publication quality (high DPI)
# ============================================================

import os, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from datetime import datetime

# Communes shown in the unified panels (example selection; adjust as needed).
# NOTE: this rebinds the module-level TARGET_COMMUNES that earlier cells also use.
TARGET_COMMUNES = [
    "La Florida",
    "Cerrillos",
    "Vitacura",
    "Providencia",
    "Las Condes",
    "Santiago",
]

def _series_from(dict_, keys, name):
    """Build a DataFrame (Date index) from dictionary arrays with given keys."""
    # keys: {"date":"dates", "med":"Rt_median", "lo":"Rt_lo", "hi":"Rt_hi"}
    ds = pd.to_datetime(pd.Series(dict_[keys["date"]], dtype="object"))
    df = pd.DataFrame({
        f"{name}_med": dict_[keys["med"]],
        f"{name}_lo":  dict_[keys["lo"]],
        f"{name}_hi":  dict_[keys["hi"]],
    }, index=ds)
    df = df[~df.index.duplicated(keep="first")].sort_index()
    return df

def _last_invalid_span(valid_arr):
    """Return start date of the last invalid (0) run, else None."""
    v = np.asarray(valid_arr, int)
    if v.size == 0 or v[-1] != 0:
        return None
    # find last transition from 1 -> 0
    i = len(v) - 1
    while i >= 0 and v[i] == 0:
        i -= 1
    return i + 1  # first invalid index

def plot_rt_unified_panels(
    RT_RES, WT_RES, communes=TARGET_COMMUNES,
    ncols=2, fig_width=16, row_height=5.0,
    save_path=None, title="Time-varying reproduction number $R_t$: Cori (obs & CVAE) and WT"
):
    """
    Draw one panel per commune overlaying three R_t estimates, then save the figure.

    On the dates shared by all three sources (inner join), each panel shows:
    Cori on observed incidence (gray band + near-black median), Cori on CVAE
    simulations (blue band + median), Wallinga–Teunis (green band + median),
    and a dotted R_t = 1 reference. If a WT entry provides a ``valid`` mask with
    one element per WT date, the trailing run of invalid days is shaded.

    Parameters
    ----------
    RT_RES : dict[commune] -> {dates, Rt_obs_median, Rt_obs_lo, Rt_obs_hi,
                               Rt_cvae_median, Rt_cvae_lo, Rt_cvae_hi}
    WT_RES : dict[commune] -> {dates, Rt_median, Rt_lo, Rt_hi, valid(optional)}
    communes : commune names, one panel each. A commune missing from RT_RES or
        WT_RES gets a "No data" placeholder; one with no shared dates gets a
        "No aligned dates" placeholder.
    ncols, fig_width, row_height : grid layout; figure height = row_height * nrows.
    save_path : output PNG path (default OUT_DIR/rt_unified_panels.png).
    title : figure suptitle; falsy to omit.

    Side effects: updates global matplotlib rcParams, writes the image at 600 dpi,
    calls plt.show() and prints the saved path.

    NOTE(review): the legend hard-codes the band levels (95% obs, 90% CVAE, 90% WT);
    confirm they match how RT_RES / WT_RES were actually produced.
    """
    # Aesthetics
    plt.rcParams.update({
        "font.family": "DejaVu Sans",
        "axes.titlesize": 12.5,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10.5,
        "savefig.dpi": 600,
    })

    # Colors & alpha
    col_obs_band = "#808080"   # gray
    col_obs_line = "#222222"   # near-black
    col_cvae_band = "#4C78A8"  # blue-ish
    col_cvae_line = "#2F5B8B"
    col_wt_band  = "#59A14F"   # green-ish
    col_wt_line  = "#2E7D32"
    a_obs, a_cvae, a_wt = 0.18, 0.22, 0.22
    lw = 2.0

    n = len(communes)
    nrows = int(math.ceil(n / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(fig_width, row_height*nrows), sharex=False, sharey=False)
    # Flatten so single-axis, single-row and grid layouts are iterated uniformly.
    axes = np.array(axes).reshape(-1)

    for ax, comm in zip(axes, communes):
        # Placeholder panel when either source lacks this commune.
        if comm not in RT_RES or comm not in WT_RES:
            ax.text(0.5, 0.5, f"No data for {comm}", ha="center", va="center", transform=ax.transAxes)
            ax.axis("off"); continue

        # Build aligned series (inner join on dates)
        rt = RT_RES[comm]
        wt = WT_RES[comm]

        df_obs = _series_from(rt,
                              {"date":"dates","med":"Rt_obs_median","lo":"Rt_obs_lo","hi":"Rt_obs_hi"},
                              "obs")
        df_cvae = _series_from(rt,
                               {"date":"dates","med":"Rt_cvae_median","lo":"Rt_cvae_lo","hi":"Rt_cvae_hi"},
                               "cvae")
        df_wt = _series_from(wt,
                             {"date":"dates","med":"Rt_median","lo":"Rt_lo","hi":"Rt_hi"},
                             "wt")

        # Align by intersection of dates (safer to avoid edge artifacts)
        df = df_obs.join(df_cvae, how="inner").join(df_wt, how="inner")
        if df.empty:
            ax.text(0.5, 0.5, f"No aligned dates for {comm}", ha="center", va="center", transform=ax.transAxes)
            ax.axis("off"); continue

        dates = df.index

        # Bands (drawn before the medians so the lines sit on top)
        ax.fill_between(dates, df["obs_lo"],  df["obs_hi"],  color=col_obs_band,  alpha=a_obs,  edgecolor="none", label=None)
        ax.fill_between(dates, df["cvae_lo"], df["cvae_hi"], color=col_cvae_band, alpha=a_cvae, edgecolor="none", label=None)
        ax.fill_between(dates, df["wt_lo"],   df["wt_hi"],   color=col_wt_band,   alpha=a_wt,   edgecolor="none", label=None)

        # Medians
        ax.plot(dates, df["obs_med"],  color=col_obs_line,  lw=lw, label=None)
        ax.plot(dates, df["cvae_med"], color=col_cvae_line, lw=lw, label=None)
        ax.plot(dates, df["wt_med"],   color=col_wt_line,   lw=lw, label=None)

        # Rt=1 reference
        ax.axhline(1.0, color="#000000", lw=1.0, ls=":", alpha=0.8)

        # Right-censor shading (if WT provides validity mask). Uses the raw WT dates
        # (not the joined index) and only when the mask has one entry per WT date;
        # shades from the first day of the trailing invalid run to the last WT date.
        if "valid" in wt:
            valid = np.asarray(wt["valid"], int)
            if len(valid) == len(wt["dates"]):
                idx0 = _last_invalid_span(valid)
                if idx0 is not None and idx0 < len(wt["dates"]):
                    t0 = pd.to_datetime(wt["dates"][idx0])
                    t1 = pd.to_datetime(wt["dates"][-1])
                    ax.axvspan(t0, t1, color="#9E9E9E", alpha=0.12, lw=0, zorder=0)

        ax.set_title(comm)
        ax.set_ylabel(r"$R_t$")
        ax.grid(True, alpha=0.3)

    # Hide unused axes
    for k in range(len(communes), len(axes)):
        axes[k].axis("off")

    # Legend built from proxy artists (the plotted artists carry no labels);
    # band levels here are labels only, not computed from the data.
    legend_elems = [
        Patch(facecolor=col_obs_band,  edgecolor="none", alpha=a_obs,  label="Cori (obs) 95% band"),
        Line2D([0],[0], color=col_obs_line,  lw=lw, label="Cori (obs) median"),
        Patch(facecolor=col_cvae_band, edgecolor="none", alpha=a_cvae, label="Cori (CVAE) 90% band"),
        Line2D([0],[0], color=col_cvae_line, lw=lw, label="Cori (CVAE) median"),
        Patch(facecolor=col_wt_band,   edgecolor="none", alpha=a_wt,   label="WT 90% band"),
        Line2D([0],[0], color=col_wt_line,   lw=lw, label="WT median"),
        Line2D([0],[0], color="#000000", lw=1.0, ls=":", label=r"$R_t=1$")
    ]
    fig.legend(legend_elems, [h.get_label() for h in legend_elems],
               loc="lower center", ncol=4, frameon=False)

    if title:
        fig.suptitle(title, y=0.995, fontsize=14)

    # Reserve the bottom strip for the legend and a sliver at the top for the title.
    fig.tight_layout(rect=[0,0.05,1,0.98])
    if save_path is None:
        save_path = os.path.join(OUT_DIR, "rt_unified_panels.png")
    fig.savefig(save_path, dpi=600, bbox_inches="tight")
    plt.show()
    print(f"✓ Saved: {save_path}")



In [None]:
import numpy as np
import pandas as pd
import os

# ------------------------------------------------------------
# Discretize serial interval as a Gamma-based discrete PMF
# (mid-point approximation; numerically stable, SciPy-free)
# ------------------------------------------------------------
def discretize_si_gamma(mean_si=4.7, sd_si=2.9, max_days=28):
    """
    Discretize a Gamma(mean_si, sd_si) serial interval into a daily PMF.

    Returns ``w`` with shape (max_days + 1,). By convention ``w[0] = 0`` (no
    same-day transmission), and ``w[s]`` for s = 1..max_days is the Gamma density
    at the midpoint ``s - 0.5``, normalized so that ``w[1:].sum() == 1``.

    The density is evaluated in the log domain and shifted by its maximum before
    exponentiating. The normalizing constant Gamma(k) * theta**k cancels in the
    normalization, so neither ``math.gamma`` (overflows for shape > ~171) nor the
    removed ``np.math`` alias (gone in NumPy 2.0) is needed.
    """
    mu, sd = float(mean_si), float(sd_si)
    k = (mu / sd) ** 2             # shape
    theta = (sd ** 2) / mu         # scale
    # Mid-point rule for discrete days 1..max_days
    s = np.arange(1, max_days + 1, dtype=float)
    if s.size == 0:
        return np.zeros(1)
    # Mid-points (s - 0.5), clipped to >= 1e-6 so log() stays finite
    x = np.clip(s - 0.5, 1e-6, None)
    log_pdf = (k - 1.0) * np.log(x) - x / theta
    pdf_mid = np.exp(log_pdf - log_pdf.max())
    w = pdf_mid / pdf_mid.sum()
    # prepend w[0]=0 for convenience in convolutions
    return np.concatenate([[0.0], w])

# ------------------------------------------------------------
# Renewal equation pieces: lambda_t and Cori posterior
# ------------------------------------------------------------
def _renewal_lambda(I, w):
    """
    I: incidence array (T,)
    w: PMF array with w[0]=0, len = L+1 (L = si_max)
    Returns lambda array (T,), lambda[t] = sum_{s=1..min(t,L)} I[t-s]*w[s]
    """
    T = len(I)
    L = len(w) - 1
    lam = np.zeros(T, dtype=float)
    for t in range(T):
        smax = min(t, L)
        if smax > 0:
            # reversed slice of last smax incidences * w[1:smax+1]
            lam[t] = np.dot(I[t-smax:t][::-1], w[1:smax+1])
    return lam

def _cori_posterior(I, w, tau=7, a0=1.0, b0=5.0, draws=400, rng=None):
    """
    Cori et al. posterior for R_t with a Gamma(shape=a0, rate=b0) prior.

    R_t is assumed constant over the window [t-tau+1, t]. The posterior is then
    Gamma(shape = a0 + max(sum I, 0), rate = b0 + max(sum lambda, 0)) over that
    window, with lambda from ``_renewal_lambda``. Quantiles are Monte Carlo
    estimates from ``draws`` posterior samples.

    Parameters
    ----------
    I : array-like (T,)          incidence.
    w : array-like               SI PMF with w[0] = 0 (see discretize_si_gamma).
    tau : int                    window length (days).
    a0, b0 : float               prior shape and rate.
    draws : int                  posterior samples per day for the quantiles.
    rng : numpy.random.Generator, optional (default: seeded with 123).

    Returns
    -------
    dict with arrays "mean", "q05", "q50", "q95" (NaN where not estimable) and a
    boolean "valid" mask (True where a posterior was computed).
    """
    if rng is None:
        rng = np.random.default_rng(123)
    I = np.asarray(I, float)
    T = len(I)
    lam = _renewal_lambda(I, w)

    Rt_mean = np.full(T, np.nan, float)
    q05 = np.full(T, np.nan, float)
    q50 = np.full(T, np.nan, float)
    q95 = np.full(T, np.nan, float)
    valid = np.zeros(T, dtype=bool)

    # rolling sums for speed
    csum_I = np.cumsum(I)
    csum_lam = np.cumsum(lam)

    for t in range(T):
        if t < tau - 1:
            continue
        i0 = t - tau + 1
        sum_I = csum_I[t] - (csum_I[i0-1] if i0 > 0 else 0.0)
        sum_lam = csum_lam[t] - (csum_lam[i0-1] if i0 > 0 else 0.0)

        # Clamp the incidence sum (as cori_posterior_rt does): a negative window sum
        # would make the Gamma shape <= 0 and rng.gamma would raise.
        a = a0 + max(0.0, sum_I)
        b = b0 + max(0.0, sum_lam)
        if b <= 0.0:
            continue

        Rt_mean[t] = a / b
        # Sample draws from Gamma(shape=a, scale=1/b)
        samp = rng.gamma(shape=a, scale=1.0/b, size=draws)
        q05[t], q50[t], q95[t] = np.quantile(samp, [0.05, 0.50, 0.95])
        valid[t] = True

    return {"mean": Rt_mean, "q05": q05, "q50": q50, "q95": q95, "valid": valid}

def _cori_posterior_from_sims(I_obs_full, sims, w, tau=7, a0=1.0, b0=5.0,
                              draws_per_sim=1, rng=None):
    """
    Aggregate the Cori posterior over CVAE simulations.

    Each simulated path is stitched after the observed history:
        I_full^(m) = concat(I_obs_full[:T - steps], sims[m, :])
    and the windowed Cori posterior (same formulas as ``_cori_posterior``) is
    computed for it.

    Parameters
    ----------
    I_obs_full : array-like (T,)   observed incidence over the full length.
    sims : array-like (M, steps)   trajectories simulated by the CVAE.
    w : SI PMF with w[0] = 0.
    tau, a0, b0 : window length and Gamma prior (shape, rate).
    draws_per_sim : posterior draws per simulation and day pooled for the quantiles.
    rng : numpy.random.Generator used for ALL draws (default: seeded with 123),
        so results are reproducible for a given generator state.

    Returns
    -------
    dict with "mean" (posterior mean averaged over sims), "q05"/"q50"/"q95"
    (quantiles of the pooled posterior draws) and boolean "valid".

    Raises
    ------
    ValueError if the simulations leave no observed history (steps >= T).
    """
    if rng is None:
        rng = np.random.default_rng(123)
    sims = np.asarray(sims, float)
    I_obs_full = np.asarray(I_obs_full, float)
    M, steps = sims.shape
    T = len(I_obs_full)
    hist = T - steps
    if hist < 1:
        raise ValueError("Sim length inconsistent with observed length.")

    # Storage
    means_acc = np.zeros(T, dtype=float)
    draws_acc = [[] for _ in range(T)]
    valid_any = np.zeros(T, dtype=bool)

    for m in range(M):
        I_full = np.concatenate([I_obs_full[:hist], sims[m, :]], axis=0)
        lam = _renewal_lambda(I_full, w)

        csum_I = np.cumsum(I_full)
        csum_lam = np.cumsum(lam)

        for t in range(T):
            if t < tau - 1:
                continue
            i0 = t - tau + 1
            sum_I = csum_I[t] - (csum_I[i0-1] if i0 > 0 else 0.0)
            sum_lam = csum_lam[t] - (csum_lam[i0-1] if i0 > 0 else 0.0)
            # Clamp: simulated paths may dip below zero, which would make the
            # Gamma shape <= 0 and the sampler raise.
            a = a0 + max(0.0, sum_I)
            b = b0 + max(0.0, sum_lam)
            if b <= 0.0:
                continue
            means_acc[t] += a / b
            # one (or few) posterior draw(s) per sim to build an ensemble of draws;
            # use the supplied generator (the global np.random ignored `rng`/seed).
            if draws_per_sim > 0:
                samp = np.atleast_1d(rng.gamma(shape=a, scale=1.0/b, size=draws_per_sim))
                draws_acc[t].extend(samp.tolist())
            valid_any[t] = True

    # Aggregate
    Rt_mean = np.full(T, np.nan, float)
    q05 = np.full(T, np.nan, float)
    q50 = np.full(T, np.nan, float)
    q95 = np.full(T, np.nan, float)

    nz = (valid_any & (np.array([len(x) for x in draws_acc]) > 0))
    if np.any(valid_any):
        Rt_mean[valid_any] = means_acc[valid_any] / float(M)
    for t in np.where(nz)[0]:
        arr = np.asarray(draws_acc[t], float)
        q05[t], q50[t], q95[t] = np.quantile(arr, [0.05, 0.50, 0.95])

    return {"mean": Rt_mean, "q05": q05, "q50": q50, "q95": q95, "valid": valid_any}

# ------------------------------------------------------------
# Public runner: Cori over communes (observed / CVAE / both)
# ------------------------------------------------------------
def run_rt_cori_all(df_scaled,
                    results_dict=None,         # results_full-style dict (needed for "sim" or "both")
                    communes=None,
                    si_mean=4.7, si_sd=2.9, si_max=28,
                    tau=7,
                    mode="both",               # "obs" | "sim" | "both"
                    B=400,                     # posterior draws per day for "obs"; 1 draw per sim for "sim"
                    use_column="Observed_Cases",
                    out_dir=None,
                    a0=1.0, b0=5.0,
                    seed=123):
    """Run the Cori R_t estimator per commune on observed and/or CVAE-simulated incidence.

    Parameters
    ----------
    df_scaled : pandas.DataFrame
        Long-format frame with at least the columns "Commune", "Date" and
        ``use_column``.
    results_dict : dict, optional
        Mapping ``commune -> {"bundle": {"sims": array-like}}``; required when
        ``mode`` is "sim" or "both". ``sims`` is converted to a float array
        (presumably shaped (M, steps) — TODO confirm against the producer).
    communes : sequence, optional
        Communes to process; defaults to every commune in ``df_scaled``, sorted.
    si_mean, si_sd, si_max :
        Serial-interval parameters forwarded to ``discretize_si_gamma``.
    tau : int
        Window length forwarded to the posterior helpers.
    mode : {"obs", "sim", "both"}
        Which estimates to compute.
    B : int
        Number of posterior draws per day for the observed-series estimate.
    use_column : str
        Column of ``df_scaled`` holding the incidence series.
    out_dir : str, optional
        If given, one CSV per commune (``rt_cori_<commune>.csv``) is written here.
    a0, b0 : float
        Gamma prior parameters forwarded to the posterior helpers.
    seed : int
        Seed of the ``numpy.random.Generator`` passed to the posterior helpers.

    Returns
    -------
    dict
        ``results[comm] = {"dates": ndarray, "Rt_obs": {...}, "Rt_sim": {...}}``,
        where each Rt block has keys mean/q05/q50/q95/valid. Only the blocks
        requested by ``mode`` are present.

    Raises
    ------
    ValueError
        If ``mode`` is not one of "obs"/"sim"/"both", or CVAE sims are missing
        for a commune when they are needed.
    KeyError
        If ``use_column`` is missing, or a requested commune is absent from
        ``df_scaled`` (raised by ``groupby().get_group``).
    """
    valid_modes = ("obs", "sim", "both")
    if mode not in valid_modes:
        # Explicit raise rather than `assert`: asserts are stripped under
        # `python -O`, which would silently return dates-only results.
        raise ValueError(f"mode must be one of {valid_modes}, got {mode!r}.")
    want_obs = mode in ("obs", "both")
    want_sim = mode in ("sim", "both")

    if communes is None:
        communes = sorted(df_scaled["Commune"].unique().tolist())
    groups = df_scaled.groupby("Commune")
    w = discretize_si_gamma(mean_si=si_mean, sd_si=si_sd, max_days=si_max)
    rng = np.random.default_rng(seed)
    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)

    results = {}
    for comm in communes:
        dfc = groups.get_group(comm).sort_values("Date").reset_index(drop=True)
        if use_column not in dfc.columns:
            raise KeyError(f"{use_column} not found in df for commune {comm}.")
        dates = dfc["Date"].values
        I_obs = np.asarray(dfc[use_column].values, float)
        out_comm = {"dates": dates}

        if want_obs:
            out_comm["Rt_obs"] = _cori_posterior(I_obs, w, tau=tau, a0=a0, b0=b0, draws=B, rng=rng)

        if want_sim:
            if (results_dict is None) or (comm not in results_dict):
                raise ValueError(f"results_dict missing CVAE sims for {comm}.")
            sims = np.asarray(results_dict[comm]["bundle"]["sims"], float)
            out_comm["Rt_sim"] = _cori_posterior_from_sims(
                I_obs_full=I_obs, sims=sims, w=w, tau=tau, a0=a0, b0=b0,
                draws_per_sim=1, rng=rng
            )

        results[comm] = out_comm

        # Optional per-commune CSV: Date, then <block>_{mean,q05,q50,q95,valid}
        # for each block that was computed (obs first, then sim).
        if out_dir is not None:
            df_out = pd.DataFrame({"Date": dates})
            for key in ("Rt_obs", "Rt_sim"):
                if key not in out_comm:
                    continue
                blk = out_comm[key]
                for stat in ("mean", "q05", "q50", "q95"):
                    df_out[f"{key}_{stat}"] = blk[stat]
                df_out[f"{key}_valid"] = blk["valid"].astype(int)
            # str() so non-string commune labels (e.g. numeric codes) still work.
            fname = f"rt_cori_{str(comm).replace(' ', '_')}.csv"
            df_out.to_csv(os.path.join(out_dir, fname), index=False)

    return results


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

def _get_quant_series_from_nested(d, prefix="Rt_obs"):
    """Extract the median and 90% band from a nested R_t result.

    Expects ``d[prefix]`` to be a mapping with (at least) the keys "q05",
    "q50" and "q95"; other keys such as "mean" or "valid" are ignored.

    Returns
    -------
    dict or None
        ``{"med": q50, "lo": q05, "hi": q95}`` as float arrays, or ``None``
        when ``d`` is None, has no ``prefix`` entry, or that entry lacks any
        of the three quantile keys.
    """
    if d is None or prefix not in d:
        return None
    block = d[prefix]
    # output key -> source quantile key (insertion order = med, lo, hi)
    wanted = {"med": "q50", "lo": "q05", "hi": "q95"}
    if any(src not in block for src in wanted.values()):
        return None
    return {dst: np.asarray(block[src], float) for dst, src in wanted.items()}

def _get_quant_series_from_flat(d, med_key, lo_key, hi_key):
    """Extract the median and band from flat keys (e.g. WT's Rt_q50/Rt_q05/Rt_q95).

    Returns ``{"med", "lo", "hi"}`` as float arrays read from ``d[med_key]``,
    ``d[lo_key]`` and ``d[hi_key]``, or ``None`` if any of those keys is missing.
    """
    pairs = (("med", med_key), ("lo", lo_key), ("hi", hi_key))
    for _, src in pairs:
        if src not in d:
            return None
    return {dst: np.asarray(d[src], float) for dst, src in pairs}

def _dates_array(d):
    """Return ``d["dates"]`` parsed into a numpy datetime64 array.

    The stored dates may be a numpy datetime64 array or a list of timestamps /
    date-like values. They are wrapped in an object-dtype Series first and then
    parsed with ``pd.to_datetime``.
    """
    as_objects = pd.Series(d["dates"], dtype="object")
    parsed = pd.to_datetime(as_objects)
    return parsed.values

def plot_rt_unified_panels(RT_RES, WT_RES, communes,
                           ncols=2,
                           fig_width=16,
                           row_height=5.0,
                           save_path=None,
                           title=r"Time-varying reproduction number $R_t$: Cori (obs & CVAE) and WT"):
    """Draw one R_t panel per commune comparing Cori (obs / CVAE) and Wallinga-Teunis.

    Each panel shows, when the data is available:
      - Cori on observed incidence: 90% band (q05-q95) + solid median,
      - Cori on CVAE simulations: 90% band + dashed median,
      - Wallinga-Teunis: dash-dot median only,
      - a dotted horizontal reference line at R_t = 1.
    The shared legend lists only the series drawn in at least one panel.

    Parameters
    ----------
    RT_RES : dict
        Nested per-commune results as produced by ``run_rt_cori_all``:
        ``RT_RES[comm] = {"dates", "Rt_obs": {...}, "Rt_sim": {...}}``.
    WT_RES : dict or None
        Per-commune WT results with flat keys, either ``q50/q05/q95`` or
        ``Rt_q50/Rt_q05/Rt_q95``, optionally with their own ``"dates"``
        (otherwise the Cori dates are reused).
    communes : sequence
        Communes to plot, one panel each, filled row by row.
    ncols : int
        Number of panel columns (values below 1 are treated as 1).
    fig_width, row_height : float
        Figure width and per-row height passed to ``figsize`` (inches).
    save_path : str, optional
        Output file; defaults to ``"rt_unified_panels.png"``.
    title : str or None
        Figure suptitle; a falsy value disables it. Without usetex, matplotlib
        renders text outside ``$...$`` literally, so use a plain ampersand
        rather than a LaTeX-escaped one.

    Raises
    ------
    ValueError
        If ``communes`` is empty.

    Side effects: updates the global ``plt.rcParams``, saves the figure at
    600 dpi, calls ``plt.show()`` and prints the saved path.
    """
    n = len(communes)
    if n == 0:
        raise ValueError("communes must contain at least one commune to plot.")

    # Publication styling (note: this mutates the global rcParams).
    plt.rcParams.update({
        "font.family": "DejaVu Sans",
        "axes.titlesize": 13,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 11,
        "savefig.dpi": 600,
    })

    ncols = max(1, ncols)
    nrows = int(math.ceil(n / ncols))
    fig, axes = plt.subplots(nrows, ncols,
                             figsize=(fig_width, row_height * nrows),
                             sharex=False, sharey=False)
    # Flatten so a single Axes, a 1-D row or a 2-D grid can all be zipped.
    axes = np.array(axes).reshape(-1)

    # Colours / styles
    col_obs = "#1f77b4"  # blue: Cori (obs)
    col_sim = "#ff7f0e"  # orange: Cori (CVAE)
    col_wt = "#2ca02c"   # green: WT
    alpha_obs = 0.20
    alpha_sim = 0.15
    lw_med = 2.0

    # Which series were actually drawn somewhere; drives the shared legend.
    drawn = {"obs": False, "sim": False, "wt": False}

    for ax, comm in zip(axes, communes):
        # ----- Cori (obs/sim) from RT_RES -----
        if comm not in RT_RES:
            ax.text(0.5, 0.5, f"Missing in RT_RES:\n{comm}", ha="center", va="center", transform=ax.transAxes)
            ax.axis("off")
            continue

        rt = RT_RES[comm]
        dates = _dates_array(rt)
        s_obs = _get_quant_series_from_nested(rt, "Rt_obs")
        s_sim = _get_quant_series_from_nested(rt, "Rt_sim")

        # ----- WT from WT_RES (flat keys, two naming conventions) -----
        s_wt = None
        dates_wt = dates
        if (WT_RES is not None) and (comm in WT_RES):
            wt = WT_RES[comm]
            if all(k in wt for k in ("q50", "q05", "q95")):
                s_wt = _get_quant_series_from_flat(wt, "q50", "q05", "q95")
            else:
                s_wt = _get_quant_series_from_flat(wt, "Rt_q50", "Rt_q05", "Rt_q95")
            if "dates" in wt:
                dates_wt = _dates_array(wt)

        # ----- Plots -----
        if s_obs is not None:
            ax.fill_between(dates, s_obs["lo"], s_obs["hi"],
                            color=col_obs, alpha=alpha_obs, edgecolor="none",
                            label="Cori (obs) 90%")
            ax.plot(dates, s_obs["med"], color=col_obs, lw=lw_med, linestyle="-",
                    label="Cori (obs) median")
            drawn["obs"] = True

        if s_sim is not None:
            ax.fill_between(dates, s_sim["lo"], s_sim["hi"],
                            color=col_sim, alpha=alpha_sim, edgecolor="none",
                            label="Cori (CVAE) 90%")
            ax.plot(dates, s_sim["med"], color=col_sim, lw=lw_med, linestyle="--",
                    label="Cori (CVAE) median")
            drawn["sim"] = True

        if s_wt is not None:
            ax.plot(dates_wt, s_wt["med"], color=col_wt, lw=1.8, linestyle="-.",
                    label="WT median")
            drawn["wt"] = True

        # Critical threshold R_t = 1
        ax.axhline(1.0, color="gray", lw=1.2, linestyle=":", alpha=0.8)

        ax.set_title(comm)
        ax.set_ylabel(r"$R_t$")
        ax.grid(True, alpha=0.3)

    # Hide unused trailing subplots
    for k in range(n, len(axes)):
        axes[k].axis("off")

    # Shared legend: only series that appear in at least one panel.
    legend_elems = []
    if drawn["obs"]:
        legend_elems += [
            Patch(facecolor=col_obs, edgecolor="none", alpha=alpha_obs, label="Cori (obs) 90%"),
            Line2D([0], [0], color=col_obs, lw=lw_med, linestyle="-", label="Cori (obs) median"),
        ]
    if drawn["sim"]:
        legend_elems += [
            Patch(facecolor=col_sim, edgecolor="none", alpha=alpha_sim, label="Cori (CVAE) 90%"),
            Line2D([0], [0], color=col_sim, lw=lw_med, linestyle="--", label="Cori (CVAE) median"),
        ]
    if drawn["wt"]:
        legend_elems.append(Line2D([0], [0], color=col_wt, lw=1.8, linestyle="-.", label="WT median"))
    legend_elems.append(Line2D([0], [0], color="gray", lw=1.2, linestyle=":", label=r"$R_t=1$"))
    fig.legend(legend_elems, [h.get_label() for h in legend_elems],
               loc="lower center", ncol=len(legend_elems), frameon=False)

    if title:
        fig.suptitle(title, y=0.99, fontsize=14)
        fig.tight_layout(rect=[0, 0.06, 1, 0.965])
    else:
        fig.tight_layout(rect=[0, 0.06, 1, 1])

    if save_path is None:
        save_path = "rt_unified_panels.png"
    fig.savefig(save_path, dpi=600, bbox_inches="tight")
    plt.show()
    print(f"✓ Saved: {save_path}")


In [None]:
# Compute Cori R_t (observed + CVAE) for the target communes.
# NOTE(review): this cell references OBS_COL, TARGET_COMMUNES, results_for_rt,
# SI_MEAN/SI_SD/SI_MAX, TAU, B_SAMPLES and RT_DIR. In the visible part of the
# notebook those names are only assigned in the *following* cell, which also
# recomputes RT_RES with the same arguments. Run that cell first (or skip this
# one); if these names are defined earlier in the file, confirm which run wins.
RT_RES = run_rt_cori_all(
    df_scaled=df_scaled,
    results_dict=results_for_rt,    # or results_full if you already have it
    communes=TARGET_COMMUNES,
    si_mean=SI_MEAN, si_sd=SI_SD, si_max=SI_MAX,
    tau=TAU,
    mode="both",
    B=B_SAMPLES,
    use_column=OBS_COL,
    out_dir=RT_DIR
)


In [None]:
# ============================================================
# 0) Ensure OBS_COL exists in df_scaled (incidence in original units)
# ============================================================
# Name of the column holding un-scaled incidence; passed as `use_column`
# to the Cori and WT runners later in this cell.
OBS_COL = "Observed_Cases"

def attach_observed_cases(df_scaled_in, CASE_COL=CASE_COL):
    """Return a copy of ``df_scaled_in`` that carries ``OBS_COL`` in original units.

    If ``OBS_COL`` is already present, the copy is returned as-is. Otherwise
    the strategies below are tried in order, each falling through to the next
    on failure:

      1) Copy ``CASE_COL`` from the raw frame ``df`` (module global). The copy
         is positional, so it is used only when both frames have the same
         length and identical values in whichever of "Commune"/"Date" they
         both carry. This prevents silently attaching misaligned rows.
      2) Invert the module-global ``scaler`` (must expose ``data_min_``, i.e.
         be a fitted MinMax-style scaler) over ``FEATURES`` and take the
         ``CASE_COL`` column.
      3) Min-max inversion using module globals ``CASE_MIN``/``CASE_MAX``.

    Parameters
    ----------
    df_scaled_in : pandas.DataFrame
        Scaled data; never modified in place.
    CASE_COL : str
        Name of the case column. The default is the global ``CASE_COL`` as
        it was when this function was defined.

    Returns
    -------
    pandas.DataFrame
        Copy of the input with ``OBS_COL`` attached.

    Raises
    ------
    RuntimeError
        If all three strategies fail.
    """
    df_s = df_scaled_in.copy()
    if OBS_COL in df_s.columns:
        return df_s

    # (1) Copy from the raw df if it is in memory and its rows line up.
    try:
        raw = globals().get("df")
        if raw is not None and CASE_COL in raw.columns and len(raw) == len(df_s):
            key_cols = [c for c in ("Commune", "Date")
                        if c in raw.columns and c in df_s.columns]
            rows_aligned = all(
                np.array_equal(raw[c].to_numpy(), df_s[c].to_numpy()) for c in key_cols
            )
            if rows_aligned:
                df_s[OBS_COL] = raw[CASE_COL].values
                return df_s
    except Exception:
        # Deliberate best-effort: any problem with the raw frame falls through to (2).
        pass

    # (2) Inverse transform via the fitted scaler, if available.
    try:
        _ = scaler.data_min_  # check that it is the already-fitted MinMax scaler
        idx = FEATURES.index(CASE_COL)
        X_scaled = df_s[FEATURES].values
        X_unscaled = scaler.inverse_transform(X_scaled)
        df_s[OBS_COL] = X_unscaled[:, idx]
        return df_s
    except Exception:
        # Deliberate best-effort: fall through to the min/max fallback.
        pass

    # (3) Fallback using only the case column's min/max.
    try:
        span = CASE_MAX - CASE_MIN
        df_s[OBS_COL] = CASE_MIN + df_s[CASE_COL].values * span
        return df_s
    except Exception as e:
        raise RuntimeError(
            "Cannot reconstruct OBS_COL; re-ejecuta el bloque de carga que define 'df' y 'scaler'."
        ) from e

# Apply the fix: make sure df_scaled carries OBS_COL (original-unit incidence).
df_scaled = attach_observed_cases(df_scaled, CASE_COL=CASE_COL)

# ============================================================
# 1) CVAE bundles (if not already in memory) for the target communes
# ============================================================
TARGET_COMMUNES = ["La Florida", "Cerrillos", "Vitacura",
                   "Providencia", "Las Condes", "Santiago"]

import os
os.makedirs(OUT_DIR, exist_ok=True)

# Reuse results_full when it already covers every target commune; otherwise
# regenerate full-curve simulations. Only the "not available" failures are
# caught (name undefined, commune missing, object not indexable by commune)
# so that unrelated errors are not silently masked by a costly recompute.
try:
    _ = [results_full[c] for c in TARGET_COMMUNES]
    results_for_rt = results_full
except (NameError, KeyError, TypeError):
    results_for_rt, _, _ = run_full_curve_all(
        df_scaled,
        communes=TARGET_COMMUNES,
        n_sims=400,
        temp_z=1.25,
        temp_out=1.25,
        use_soft_states=False,
        out_dir=OUT_DIR
    )

# ============================================================
# 2) Cori (observed + CVAE) -> RT_RES
# ============================================================
SI_MEAN, SI_SD, SI_MAX = 4.7, 2.9, 28
TAU = 7
B_SAMPLES = 400

RT_DIR = os.path.join(OUT_DIR, "RT_Cori")
os.makedirs(RT_DIR, exist_ok=True)

RT_RES = run_rt_cori_all(
    df_scaled=df_scaled,
    results_dict=results_for_rt,
    communes=TARGET_COMMUNES,
    si_mean=SI_MEAN, si_sd=SI_SD, si_max=SI_MAX,
    tau=TAU,
    mode="both",            # 'obs' | 'sim' | 'both'
    B=B_SAMPLES,
    use_column=OBS_COL,
    out_dir=RT_DIR
)

# ============================================================
# 3) Wallinga-Teunis (observed) -> WT_RES
# ============================================================
WT_DIR = os.path.join(OUT_DIR, "RT_WT")
os.makedirs(WT_DIR, exist_ok=True)

WT_RES = run_wt_all(
    df_scaled=df_scaled,
    communes=TARGET_COMMUNES,
    si_mean=SI_MEAN, si_sd=SI_SD, si_max=SI_MAX,
    mode="sampled",         # 'expected' | 'sampled'
    B=B_SAMPLES,
    correct_right_censor=True,
    out_dir=WT_DIR,
    use_column=OBS_COL
)

# ============================================================
# 4) Unified panel: Cori (obs & CVAE) + WT
# ============================================================
PANEL_PATH = os.path.join(OUT_DIR, "rt_unified_panels.png")

plot_rt_unified_panels(
    RT_RES=RT_RES,
    WT_RES=WT_RES,
    communes=TARGET_COMMUNES,
    ncols=2,
    fig_width=16,
    row_height=5.0,
    save_path=PANEL_PATH,
    # Plain "&": outside $...$ matplotlib (without usetex) would render "\&" literally.
    title=r"Time-varying reproduction number $R_t$: Cori (obs & CVAE) and WT"
)

print("✓ Panel guardado en:", PANEL_PATH)
