In [None]:
# ============================================================
# Rt panels for Copula simulations: Cori (obs & Copula) + WT (stable)
# - Load simulations from CSV (R: simulated_cases_severity_weighted_v2.csv)
# - Calculate Cori (observed and posterior mixture over simulations)
# - Calculate stable WT (observed) with explosion control
# - Draw unified panel by commune (publication, high resolution)
# ============================================================

import os, math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from datetime import datetime

# If SciPy is available (usually yes), use it for Gamma CDF/PPF
from scipy.stats import gamma as gamma_dist

# -------------------- Configuration --------------------
# Input: CSV with the COPULA simulation paths produced by the R script.
COPULA_SIM_CSV = "path/to/simulated_cases_severity_weighted_v2.csv"

# Observed incidence source. A `df_scaled` DataFrame already in memory is used
# when it holds the observed column; otherwise OBS_DATA_CSV (optional) is read.
OBS_DATA_CSV = None  # e.g., "/path/to/observed.csv" with columns ["Date","Commune","Observed_Cases"]

# Observed-incidence column in df_scaled / OBS_DATA_CSV (None -> auto-detect).
OBS_COL = None  # e.g., "Observed_Cases" or "Gross_Daily_Cases_Mobile_Average_7_Days"

# Communes to analyse and plot (one panel each).
TARGET_COMMUNES = [
    "La Florida", "Cerrillos", "Vitacura",
    "Providencia", "Las Condes", "Santiago",
]

# Serial interval: Gamma(mean, sd) discretized into daily weights up to SI_MAX days.
SI_MEAN = 4.7
SI_SD = 2.9
SI_MAX = 28

# Cori et al.: sliding-window length, Gamma prior (shape A0, rate B0), band quantiles.
TAU = 7
A0 = 1.0
B0 = 1.0
CORI_Q = (0.05, 0.95)
# Posterior draws taken per simulation path when mixing Cori over the ensemble.
MIX_SAMPLES_PER_PATH = 32

# Wallinga-Teunis (stable variant).
WT_MODE = "sampled"        # "expected" | "sampled"
WT_B = 400                 # bootstrap replicates for "sampled"
WT_RIGHT_CENSOR = True
WT_LMIN_MODE = "relative"  # 'relative' (to max(Λ)) or 'absolute'
WT_LMIN_VAL = 1e-6

# Output directory: reuse an OUT_DIR defined earlier (e.g. in a previous cell).
OUT_DIR = globals().get("OUT_DIR", "./eval_out")
os.makedirs(OUT_DIR, exist_ok=True)

# -------------------- SI kernel utilities --------------------
def discretize_si_gamma(mean_si=SI_MEAN, sd_si=SI_SD, max_days=SI_MAX):
    """Discretize a Gamma serial-interval distribution into daily weights.

    Returns ``w`` with ``w[s-1] = F(s) - F(s-1)`` for s = 1..max_days, where F is
    the CDF of a Gamma with the given mean and standard deviation (shape =
    mean^2/var, scale = var/mean). The weights are renormalized to sum to 1 over
    the truncated support.

    Raises:
        ValueError: if mean_si or sd_si is not strictly positive, if
            max_days < 1, or if the truncated support carries no mass.
    """
    # Guard against silent garbage: sd=0 divides by zero, a non-positive mean
    # gives an invalid Gamma, and max_days < 1 yields an empty kernel.
    if not (mean_si > 0 and sd_si > 0):
        raise ValueError(
            f"Serial interval mean and sd must be > 0 (got mean={mean_si}, sd={sd_si})."
        )
    if not max_days >= 1:
        raise ValueError(f"max_days must be >= 1 (got {max_days}).")

    var = sd_si ** 2
    shape = (mean_si ** 2) / var
    scale = var / mean_si
    edges = np.arange(0, max_days + 1, dtype=float)
    cdf = gamma_dist.cdf(edges, a=shape, scale=scale)
    w = np.maximum(np.diff(cdf), 0.0)   # w[0] = P(0 < SI <= 1), etc.
    total = w.sum()
    if not total > 0:
        raise ValueError("Discretized serial interval has no probability mass on 1..max_days.")
    return w / total

def total_infectiousness(I, w):
    """Total infectiousness Λ_t = Σ_{s≥1} I_{t−s} · w_s.

    The kernel is stored lag-shifted: ``w[0]`` weights lag s = 1, ``w[1]`` lag 2,
    and so on. Lags at or beyond the series length contribute nothing, so
    Λ_0 is always 0.
    """
    incidence = np.asarray(I, float)
    n_days = len(incidence)
    lam = np.zeros(n_days, dtype=float)
    # Each lag adds the incidence series shifted right by `lag` days.
    for lag, weight in enumerate(w[:n_days], start=1):
        lam[lag:] += incidence[:n_days - lag] * weight
    return lam

# -------------------- Cori (observed and generative mixture) --------------------
def cori_posterior_params(I, w, tau=7, a0=1.0, b0=1.0):
    """Cori et al. Gamma posterior parameters for R_t over a trailing window.

    For every day t whose window [t - tau + 1, t] starts inside the series:
        shape_t = a0 + Σ_window I
        rate_t  = b0 + Σ_window Λ      (Λ from total_infectiousness)
    Days before the first complete window are left as NaN.

    Returns:
        (shape, rate): float arrays with one entry per day of I.
    """
    incidence = np.asarray(I, float)
    n_days = len(incidence)
    lam = total_infectiousness(incidence, w)
    shape = np.full(n_days, np.nan)
    rate = np.full(n_days, np.nan)

    first_day = max(tau - 1, 0)   # earliest day with a window starting at index >= 0
    for end in range(first_day, n_days):
        window = slice(end - tau + 1, end + 1)
        shape[end] = a0 + np.sum(incidence[window])
        rate[end] = b0 + np.sum(lam[window])
    return shape, rate

def cori_quantiles_from_params(shape, rate, q=(0.05, 0.95)):
    """Median and (q[0], q[1]) quantiles of Gamma(shape, rate), element-wise.

    Entries with a non-finite or non-positive shape/rate are returned as NaN.
    Returns the tuple (median, lower, upper) of float arrays shaped like `shape`.
    """
    med, lo, hi = (np.full_like(shape, np.nan, dtype=float) for _ in range(3))
    ok = np.isfinite(shape) & np.isfinite(rate) & (shape > 0) & (rate > 0)
    if ok.any():
        a = shape[ok]
        scale = 1.0 / rate[ok]   # SciPy parameterizes Gamma by scale = 1/rate
        for target, prob in ((med, 0.5), (lo, q[0]), (hi, q[1])):
            target[ok] = gamma_dist.ppf(prob, a=a, scale=scale)
    return med, lo, hi

def cori_from_observed(I, w, tau=7, a0=1.0, b0=1.0, q=(0.05,0.95)):
    """Cori R_t summary for a single incidence series.

    Computes the windowed Gamma posterior with cori_posterior_params and
    returns its (median, q[0]-quantile, q[1]-quantile) arrays.
    """
    posterior = cori_posterior_params(I, w, tau, a0, b0)
    med, lo, hi = cori_quantiles_from_params(*posterior, q=q)
    return med, lo, hi

def cori_from_ensemble(samples_matrix, w, tau=7, a0=1.0, b0=1.0, q=(0.05,0.95),
                       mix_samples_per_path=32, seed=12345):
    """
    Cori R_t as an equal-weight posterior mixture over simulated incidence paths.

    For each simulation m, the Cori posterior Gamma(shape_m[t], rate_m[t]) is
    computed; `mix_samples_per_path` draws are taken from it on every day with a
    valid window, and the draws pooled across paths are summarised per day.

    Parameters
    ----------
    samples_matrix : array-like, shape (T, M) or (T,)
        Simulated daily incidence, one column per path; a 1-D array is treated
        as a single path.
    w : array-like
        Discrete serial-interval weights (w[0] is lag 1).
    tau, a0, b0 : Cori window length and Gamma prior (shape a0, rate b0).
    q : (float, float)
        Lower/upper quantiles of the pooled draws.
    mix_samples_per_path : int
        Posterior draws per path and day.
    seed : int
        Seed of the NumPy Generator used for the draws.

    Returns
    -------
    (med, lo, hi) : float arrays of length T; NaN on days where no path has a
    valid posterior (always the first tau-1 days).
    """
    rng = np.random.default_rng(seed)
    S = np.asarray(samples_matrix, float)
    if S.ndim == 1:
        S = S[:, None]           # single path -> (T, 1)
    T, M = S.shape
    Rt_samples = np.full((T, M * mix_samples_per_path), np.nan, dtype=float)

    for m in range(M):
        shape, rate = cori_posterior_params(S[:, m], w, tau, a0, b0)
        ok = np.isfinite(shape) & np.isfinite(rate) & (rate > 0) & (shape > 0)
        if not ok.any():
            continue
        scale = 1.0 / rate[ok]   # NumPy's gamma takes scale = 1/rate
        draws = rng.gamma(shape[ok][:, None], scale[:, None],
                          size=(ok.sum(), mix_samples_per_path))
        cols = slice(m * mix_samples_per_path, (m + 1) * mix_samples_per_path)
        Rt_samples[ok, cols] = draws

    med = np.full(T, np.nan)
    lo = np.full(T, np.nan)
    hi = np.full(T, np.nan)
    # Summarise only days with at least one draw: nanmedian/nanquantile emit an
    # "All-NaN slice" RuntimeWarning for empty rows, which would fire every call.
    has_draws = np.isfinite(Rt_samples).any(axis=1)
    if has_draws.any():
        pooled = Rt_samples[has_draws]
        med[has_draws] = np.nanmedian(pooled, axis=1)
        lo[has_draws] = np.nanquantile(pooled, q[0], axis=1)
        hi[has_draws] = np.nanquantile(pooled, q[1], axis=1)
    return med, lo, hi

# -------------------- Stable WT (observed) --------------------
def wt_expected_stable(I, w, correct_right_censor=True, lmin_mode="relative", lmin_val=1e-6):
    """
    Stable Wallinga-Teunis 'expected' estimator.

    A case on day t is credited with the expected share w[s-1] / Λ[i] of every
    later case on day i = t + s (s = 1..len(w), within the series):

        num_t = Σ_{valid s} I[i] * w[s-1] / Λ[i]

    A lag is *valid* only if Λ[i] >= Lmin and Λ[i] > 0; other terms are dropped
    so near-zero infectiousness cannot blow the estimate up.
    Lmin = lmin_val * max(Λ) when lmin_mode == "relative", else lmin_val.

    correct_right_censor=True:
        R_t = num_t / Σ_{valid s} w[s-1]
        (renormalizes over dropped lags AND lags beyond the end of the data).
    correct_right_censor=False:
        R_t = num_t / (Σ_{valid s} w[s-1] + Σ_{lags past the end} w[s-1])
        (only dropped low-Λ lags are renormalized; lags past the end count as
        zero secondary cases, so R_t is biased downward near the end).

    Returns
    -------
    Rt : float array, NaN where undefined.
    valid : int array, 0 where R_t is undefined (last day, or no valid lag).
    """
    I = np.asarray(I, float)
    w = np.asarray(w, float)
    T = len(I)
    Rt = np.full(T, np.nan)
    valid = np.ones(T, dtype=int)
    if T == 0:
        return Rt, valid

    L = total_infectiousness(I, w)
    finite_L = L[np.isfinite(L)]
    Lmax = float(finite_L.max()) if finite_L.size else 1.0
    Lmin = (lmin_val * max(Lmax, 1e-12)) if (lmin_mode == "relative") else float(lmin_val)

    for t in range(T):
        max_s = min(len(w), T - t - 1)   # lags still inside the observed series
        if max_s <= 0:
            valid[t] = 0
            continue

        num = 0.0
        mass_valid = 0.0
        for s in range(1, max_s + 1):
            i = t + s
            # L[i] > 0 also protects against Lmin <= 0 (absolute mode / lmin_val=0).
            if L[i] >= Lmin and L[i] > 0:
                num += I[i] * w[s - 1] / L[i]
                mass_valid += w[s - 1]

        if mass_valid <= 0:
            valid[t] = 0
            continue

        if correct_right_censor:
            denom = mass_valid
        else:
            # Censored lags keep their weight in the denominator (no correction).
            denom = mass_valid + float(np.sum(w[max_s:]))
        Rt[t] = num / denom

    return Rt, valid

def wt_bootstrap_stable(I, w, B=400, correct_right_censor=True,
                        lmin_mode="relative", lmin_val=1e-6,
                        q_lo=0.05, q_hi=0.95, seed=123):
    """
    Parametric bootstrap for the stable Wallinga-Teunis estimator.

    Each of the B replicates redraws daily incidence as I*_t ~ Poisson(max(I_t, 0)),
    runs wt_expected_stable on it, and the replicates are summarised per day.
    Days whose observed incidence is non-finite (e.g. leading NaNs of a moving
    average) stay NaN in every replicate instead of reaching the Poisson
    sampler, which rejects NaN means.

    Returns
    -------
    Rt_med, Rt_lo, Rt_hi : float arrays
        Median and q_lo / q_hi quantiles over replicates; NaN on days where no
        replicate produced a finite estimate.
    valid : int array
        1 where at least one replicate produced a finite R_t, else 0.
    """
    rng = np.random.default_rng(seed)
    I = np.asarray(I, float)
    T = len(I)
    observed = np.isfinite(I)
    # Poisson means; non-finite days get a placeholder 0 and are re-masked below.
    lam = np.where(observed, np.maximum(I, 0.0), 0.0)

    Rt_s = np.full((T, B), np.nan)
    for b in range(B):
        I_b = rng.poisson(lam=lam).astype(float)
        I_b[~observed] = np.nan
        Rt_b, _ = wt_expected_stable(I_b, w,
                                     correct_right_censor=correct_right_censor,
                                     lmin_mode=lmin_mode, lmin_val=lmin_val)
        Rt_s[:, b] = Rt_b

    Rt_med = np.full(T, np.nan)
    Rt_lo = np.full(T, np.nan)
    Rt_hi = np.full(T, np.nan)
    # Summarise only days with a finite replicate: nanmedian/nanquantile warn on
    # all-NaN rows, and the last day is always undefined for WT.
    has_est = np.isfinite(Rt_s).any(axis=1)
    if has_est.any():
        est = Rt_s[has_est]
        Rt_med[has_est] = np.nanmedian(est, axis=1)
        Rt_lo[has_est] = np.nanquantile(est, q_lo, axis=1)
        Rt_hi[has_est] = np.nanquantile(est, q_hi, axis=1)
    # Validity from the whole ensemble (deterministic), not from one arbitrary replicate.
    valid = has_est.astype(int)
    return Rt_med, Rt_lo, Rt_hi, valid

# -------------------- Data loading (simulations + observed) --------------------
def load_copula_simulations(csv_path):
    """
    Read COPULA simulation paths from a CSV with columns Date, Commune, X1..XM.

    Only columns named exactly ``X<digits>`` are treated as simulation paths, so
    stray columns such as an R row-name column ``X`` are not mixed into the
    ensemble. Column order from the file is preserved.

    Parameters
    ----------
    csv_path : str or file-like
        Anything accepted by ``pandas.read_csv``.

    Returns
    -------
    dict[commune] -> {"dates": list[pd.Timestamp], "samples": np.ndarray (T, M)}
        Rows sorted by Date within each commune.

    Raises
    ------
    ValueError
        If no ``X<digits>`` simulation columns are present.
    """
    import re

    df = pd.read_csv(csv_path, parse_dates=["Date"])
    sim_cols = [c for c in df.columns if re.fullmatch(r"X\d+", str(c))]
    if len(sim_cols) == 0:
        raise ValueError("No simulation columns found (prefix 'X').")
    out = {}
    for comm, d in df.groupby("Commune"):
        d = d.sort_values("Date").reset_index(drop=True)
        out[comm] = {
            "dates": pd.to_datetime(d["Date"]).tolist(),
            "samples": d[sim_cols].to_numpy(dtype=float),  # (T, M)
        }
    return out

def load_observed_frame(OBS_COL_name=None, obs_csv=None):
    """
    Build a DataFrame with columns ['Date', 'Commune', 'OBSERVED'].

    Sources, in order of preference:
    1. A global ``df_scaled`` DataFrame, if it exists and contains OBS_COL_name
       (or, when OBS_COL_name is None, one of the known candidate columns).
       If it lacks the column, the CSV source is tried instead.
    2. ``obs_csv`` (path or buffer accepted by pandas.read_csv), which must
       contain OBS_COL_name or one of the known candidate columns.

    Raises
    ------
    ValueError
        If obs_csv is used but no observed column can be determined.
    RuntimeError
        If neither source yields data.
    """
    # 1) Optional global defined in an earlier notebook cell.
    try:
        df_src = df_scaled  # noqa: F821
    except NameError:
        df_src = None

    if df_src is not None:
        if OBS_COL_name is not None:
            # Honour the documented contract: only use df_scaled if it has the column.
            candidates = [OBS_COL_name] if OBS_COL_name in df_src.columns else []
        else:
            candidates = [
                cand
                for cand in ("Observed_Cases", "OBS_COL", "Gross_Daily_Cases_Mobile_Average_7_Days", "Cases")
                if cand in df_src.columns
            ]
        if candidates:
            col = candidates[0]
            # Column selection already returns a new frame, so df_scaled is never mutated.
            out = df_src[["Date", "Commune", col]].rename(columns={col: "OBSERVED"})
            return out.assign(Date=pd.to_datetime(out["Date"]))

    # 2) External CSV
    if obs_csv is not None:
        df = pd.read_csv(obs_csv, parse_dates=["Date"])
        if OBS_COL_name is None:
            for cand in ["Observed_Cases", "Gross_Daily_Cases_Mobile_Average_7_Days", "Cases"]:
                if cand in df.columns:
                    OBS_COL_name = cand
                    break
        if OBS_COL_name is None or OBS_COL_name not in df.columns:
            raise ValueError("Could not determine observed column in OBS_DATA_CSV.")
        return df[["Date", "Commune", OBS_COL_name]].rename(columns={OBS_COL_name: "OBSERVED"})

    raise RuntimeError("No observed data found. Define df_scaled with cases column or provide OBS_DATA_CSV.")

# -------------------- Pipeline: Cori (obs & Copula) + WT --------------------
def _align_samples_to_dates(samples, sample_dates, target_dates):
    """Reindex a (T_sim, M) simulation matrix onto ``target_dates`` by calendar date.

    Simulation dates absent from the target are dropped. Target dates absent
    from the simulation become all-NaN rows; downstream Cori windows that involve
    them come out NaN. Duplicate simulation dates keep their first occurrence.
    """
    frame = pd.DataFrame(np.asarray(samples, float),
                         index=pd.DatetimeIndex(pd.to_datetime(list(sample_dates))))
    frame = frame[~frame.index.duplicated(keep="first")]
    target = pd.DatetimeIndex(pd.to_datetime(list(target_dates)))
    return frame.reindex(target).to_numpy(dtype=float)


def run_rt_all_copula(df_obs, copula_dict, communes, si_mean=SI_MEAN, si_sd=SI_SD, si_max=SI_MAX,
                      tau=TAU, a0=A0, b0=B0, cori_q=CORI_Q,
                      mix_samples_per_path=MIX_SAMPLES_PER_PATH,
                      wt_mode=WT_MODE, wt_B=WT_B, wt_right_censor=WT_RIGHT_CENSOR,
                      wt_lmin_mode=WT_LMIN_MODE, wt_lmin_val=WT_LMIN_VAL):
    """
    Estimate R_t per commune: Cori on observed data, Cori mixed over the
    COPULA simulations, and the stable Wallinga-Teunis estimator on observed data.

    Parameters
    ----------
    df_obs : DataFrame with columns Date, Commune, OBSERVED.
    copula_dict : dict[commune] -> {"dates": [...], "samples": (T, M)} as built by
        load_copula_simulations. Samples are aligned to the observed dates by
        calendar date when "dates" is present.
    communes : communes to process; those missing from df_obs or copula_dict are skipped.
    Remaining parameters configure the serial interval, Cori and WT estimators.

    Returns
    -------
    RT_RES[comm] -> {dates, Rt_obs_median/lo/hi, Rt_gen_median/lo/hi}
    WT_RES[comm] -> {dates, Rt_median/lo/hi, valid}
    All series of a commune are indexed by that commune's observed dates.
    """
    w = discretize_si_gamma(si_mean, si_sd, si_max)
    RT_RES, WT_RES = {}, {}
    groups = df_obs.groupby("Commune")

    for comm in communes:
        if comm not in copula_dict or comm not in groups.groups:
            continue

        # Observed data (negative values clipped to zero)
        d = groups.get_group(comm).sort_values("Date").reset_index(drop=True)
        dates_obs = pd.to_datetime(d["Date"]).tolist()
        I_obs = np.clip(d["OBSERVED"].to_numpy(dtype=float), a_min=0.0, a_max=None)

        # Cori (observed)
        obs_med, obs_lo, obs_hi = cori_from_observed(I_obs, w, tau=tau, a0=a0, b0=b0, q=cori_q)

        # Cori (COPULA): align simulations to the observed dates so both series
        # share the same time axis, then mix posteriors over simulations.
        sim = copula_dict[comm]
        samples = sim["samples"]
        if sim.get("dates") is not None:
            samples = _align_samples_to_dates(samples, sim["dates"], dates_obs)
        gen_med, gen_lo, gen_hi = cori_from_ensemble(samples, w, tau=tau, a0=a0, b0=b0,
                                                     q=cori_q, mix_samples_per_path=mix_samples_per_path)

        RT_RES[comm] = {
            "dates": dates_obs,
            "Rt_obs_median": obs_med.tolist(),
            "Rt_obs_lo":     obs_lo.tolist(),
            "Rt_obs_hi":     obs_hi.tolist(),
            "Rt_gen_median": gen_med.tolist(),
            "Rt_gen_lo":     gen_lo.tolist(),
            "Rt_gen_hi":     gen_hi.tolist(),
        }

        # WT (observed), stable variant
        if wt_mode == "expected":
            Rt_wt, valid = wt_expected_stable(I_obs, w,
                                              correct_right_censor=wt_right_censor,
                                              lmin_mode=wt_lmin_mode, lmin_val=wt_lmin_val)
            wt_med = wt_lo = wt_hi = Rt_wt   # point estimate: band collapses onto the line
        else:
            wt_med, wt_lo, wt_hi, valid = wt_bootstrap_stable(I_obs, w, B=wt_B,
                                                              correct_right_censor=wt_right_censor,
                                                              lmin_mode=wt_lmin_mode, lmin_val=wt_lmin_val)
        WT_RES[comm] = {
            "dates": dates_obs,
            "Rt_median": wt_med.tolist(),
            "Rt_lo": wt_lo.tolist(),
            "Rt_hi": wt_hi.tolist(),
            "valid": valid.tolist(),
        }

    return RT_RES, WT_RES

# -------------------- Unified plot --------------------
def _series_from(dict_, keys, name):
    ds = pd.to_datetime(pd.Series(dict_[keys["date"]], dtype="object"))
    df = pd.DataFrame({
        f"{name}_med": dict_[keys["med"]],
        f"{name}_lo":  dict_[keys["lo"]],
        f"{name}_hi":  dict_[keys["hi"]],
    }, index=ds)
    df = df[~df.index.duplicated(keep="first")].sort_index()
    return df

def _last_invalid_span(valid_arr):
    v = np.asarray(valid_arr, int)
    if v.size == 0 or v[-1] != 0:
        return None
    i = len(v) - 1
    while i >= 0 and v[i] == 0:
        i -= 1
    return i + 1

def plot_rt_unified_panels_general(
    RT_RES, WT_RES, communes=TARGET_COMMUNES,
    gen_name="Copula", ncols=2, fig_width=16, row_height=5.0,
    save_path=None, title=None, force_ylim=True,
    cori_q=CORI_Q, wt_q=(0.05, 0.95),
):
    """
    Draw one R_t panel per commune overlaying Cori (observed), Cori (generative
    ensemble) and Wallinga-Teunis estimates, then save, show and report the figure.

    Parameters
    ----------
    RT_RES, WT_RES : dict
        Per-commune results in the format returned by run_rt_all_copula.
    communes : list of str
        Panels to draw, in order. Communes missing from either dict, or with no
        dates common to all three series, get a placeholder panel.
    gen_name : str
        Name of the generative model used in the legend and default title.
    ncols, fig_width, row_height : grid layout and figure size (inches).
    save_path : str or None
        Output file; defaults to OUT_DIR/rt_unified_panels_copula.png.
    title : str or None
        Figure title; a default mentioning gen_name is used when None.
    force_ylim : bool
        If True, the y-axis spans [0, max(2.5, 98th percentile of the medians)].
    cori_q, wt_q : (float, float)
        Quantiles the Cori / WT bands were computed with; used only to label the
        legend (e.g. (0.05, 0.95) -> "90% band").

    Note: updates matplotlib's global rcParams as a side effect.
    """
    plt.rcParams.update({
        "font.family": "DejaVu Sans",
        "axes.titlesize": 12.5,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10.5,
        "savefig.dpi": 600,
    })
    col_obs_band = "#808080"; col_obs_line = "#222222"
    col_gen_band = "#4C78A8"; col_gen_line = "#2F5B8B"
    col_wt_band  = "#59A14F"; col_wt_line  = "#2E7D32"
    a_obs, a_gen, a_wt = 0.18, 0.22, 0.22
    lw = 2.0

    # Band coverage labels derived from the quantiles actually used
    # (the Cori bands use CORI_Q = (0.05, 0.95), i.e. 90%, not 95%).
    cori_pct = f"{round((cori_q[1] - cori_q[0]) * 100)}%"
    wt_pct = f"{round((wt_q[1] - wt_q[0]) * 100)}%"

    n = len(communes); nrows = int(math.ceil(n / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(fig_width, row_height*nrows), sharex=False, sharey=False)
    axes = np.array(axes).reshape(-1)

    for ax, comm in zip(axes, communes):
        if comm not in RT_RES or comm not in WT_RES:
            ax.text(0.5, 0.5, f"No data for {comm}", ha="center", va="center", transform=ax.transAxes)
            ax.axis("off"); continue

        rt = RT_RES[comm]; wt = WT_RES[comm]
        df_obs = _series_from(rt, {"date":"dates","med":"Rt_obs_median","lo":"Rt_obs_lo","hi":"Rt_obs_hi"}, "obs")
        df_gen = _series_from(rt, {"date":"dates","med":"Rt_gen_median","lo":"Rt_gen_lo","hi":"Rt_gen_hi"}, "gen")
        df_wt  = _series_from(wt, {"date":"dates","med":"Rt_median","lo":"Rt_lo","hi":"Rt_hi"}, "wt")

        df = df_obs.join(df_gen, how="inner").join(df_wt, how="inner")
        if df.empty:
            ax.text(0.5,0.5,f"No aligned dates for {comm}",ha="center",va="center",transform=ax.transAxes)
            ax.axis("off"); continue

        dates = df.index

        # Bands
        ax.fill_between(dates, df["obs_lo"], df["obs_hi"], color=col_obs_band, alpha=a_obs, edgecolor="none")
        ax.fill_between(dates, df["gen_lo"], df["gen_hi"], color=col_gen_band, alpha=a_gen, edgecolor="none")
        ax.fill_between(dates, df["wt_lo"],  df["wt_hi"],  color=col_wt_band,  alpha=a_wt,  edgecolor="none")

        # Medians
        ax.plot(dates, df["obs_med"], color=col_obs_line, lw=lw)
        ax.plot(dates, df["gen_med"], color=col_gen_line, lw=lw)
        ax.plot(dates, df["wt_med"],  color=col_wt_line,  lw=lw)

        # Rt=1 reference
        ax.axhline(1.0, color="#000000", lw=1.0, ls=":", alpha=0.85)

        # Shade the trailing span where WT is undefined
        if "valid" in wt and len(wt["valid"]) == len(wt["dates"]):
            idx0 = _last_invalid_span(wt["valid"])
            if idx0 is not None and idx0 < len(wt["dates"]):
                t0 = pd.to_datetime(wt["dates"][idx0]); t1 = pd.to_datetime(wt["dates"][-1])
                ax.axvspan(t0, t1, color="#9E9E9E", alpha=0.12, lw=0, zorder=0)

        # Vertical range (optional, keeps Rt=1 visible)
        if force_ylim:
            q98 = np.nanquantile(np.r_[df["obs_med"].values, df["gen_med"].values, df["wt_med"].values], 0.98)
            ylim_top = max(2.5, float(q98))
            ax.set_ylim(0.0, ylim_top)

        ax.set_title(comm); ax.set_ylabel(r"$R_t$"); ax.grid(True, alpha=0.3)

    for k in range(len(communes), len(axes)):
        axes[k].axis("off")

    legend_elems = [
        Patch(facecolor=col_obs_band, edgecolor="none", alpha=a_obs, label=f"Cori (obs) {cori_pct} band"),
        Line2D([0],[0], color=col_obs_line, lw=lw, label="Cori (obs) median"),
        Patch(facecolor=col_gen_band, edgecolor="none", alpha=a_gen, label=f"Cori ({gen_name}) {cori_pct} band"),
        Line2D([0],[0], color=col_gen_line, lw=lw, label=f"Cori ({gen_name}) median"),
        Patch(facecolor=col_wt_band,  edgecolor="none", alpha=a_wt,  label=f"WT {wt_pct} band"),
        Line2D([0],[0], color=col_wt_line,  lw=lw, label="WT median"),
        Line2D([0],[0], color="#000000", lw=1.0, ls=":", label=r"$R_t=1$")
    ]
    fig.legend(legend_elems, [h.get_label() for h in legend_elems],
               loc="lower center", ncol=4, frameon=False)

    if title is None:
        # Plain "&": outside $...$ matplotlib renders text literally, so "\&" would show the backslash.
        title = f"Time-varying reproduction number $R_t$: Cori (obs & {gen_name}) and WT"
    fig.suptitle(title, y=0.995, fontsize=14)

    fig.tight_layout(rect=[0,0.05,1,0.98])
    if save_path is None:
        save_path = os.path.join(OUT_DIR, "rt_unified_panels_copula.png")
    fig.savefig(save_path, dpi=600, bbox_inches="tight")
    plt.show()
    print("✓ Saved:", save_path)

# -------------------- Run everything --------------------
# 1) Copula simulations: commune -> {"dates", "samples" (T, M)}
COPULA_ENS = load_copula_simulations(COPULA_SIM_CSV)

# 2) Observed data: columns Date, Commune, OBSERVED
df_obs = load_observed_frame(OBS_COL_name=OBS_COL, obs_csv=OBS_DATA_CSV)

# 3) Keep only the target communes (date consistency with the simulations is not checked here)
df_obs = df_obs[df_obs["Commune"].isin(TARGET_COMMUNES)].copy()

# 4) Calculate Rt (Cori obs/copula + stable WT)
RT_RES, WT_RES = run_rt_all_copula(
    df_obs=df_obs,
    copula_dict=COPULA_ENS,
    communes=TARGET_COMMUNES,
    si_mean=SI_MEAN, si_sd=SI_SD, si_max=SI_MAX,
    tau=TAU, a0=A0, b0=B0, cori_q=CORI_Q,
    mix_samples_per_path=MIX_SAMPLES_PER_PATH,
    wt_mode=WT_MODE, wt_B=WT_B, wt_right_censor=WT_RIGHT_CENSOR,
    wt_lmin_mode=WT_LMIN_MODE, wt_lmin_val=WT_LMIN_VAL
)

# 5) Unified visualization.
#    The title uses a plain "&": text outside $...$ is rendered literally by
#    matplotlib, so "\&" would print the backslash.
plot_rt_unified_panels_general(
    RT_RES=RT_RES,
    WT_RES=WT_RES,
    communes=TARGET_COMMUNES,
    gen_name="Copula",
    ncols=2,
    fig_width=16,
    row_height=5.0,
    save_path=os.path.join(OUT_DIR, "rt_unified_panels_copula.png"),
    title=r"Time-varying reproduction number $R_t$: Cori (obs & Copula) and WT",
    force_ylim=True
)