# In [None]:
# ============================================================
# Copula simulations: load, compute metrics, and visualize (paper-ready)
#
# - Input sims CSV (from R): Date, Commune, X1..XN (N = #simulations)
# - Observed CSV: must contain Date, Commune, and OBS_COL
# - Outputs:
#     * Figure panels PNG/PDF with per-panel metrics (600 dpi)
#     * CSV with metrics by commune
# ============================================================

import os, math, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from typing import Dict, Any, Tuple, List

# --------- User Configuration ---------
# Inputs: per-commune copula simulations exported from R, and the observed
# series they are evaluated against.
SIM_PATH = "path/simulated_cases_severity_weighted_v2.csv"
OBS_PATH = "path/covid_data_with_6_states.csv"
# Column of the observed CSV that the simulations are compared against.
OBS_COL = "Gross_Daily_Cases_Mobile_Average_7_Days"

# Communes to evaluate; the figure panels follow this order.
TARGET_COMMUNES = [
    "La Florida",
    "Cerrillos",
    "Vitacura",
    "Providencia",
    "Las Condes",
    "Santiago",
]

# All figures and CSVs are written here. The directory is created up front
# (no error if it already exists).
OUT_DIR = "./eval_copula_out"
os.makedirs(OUT_DIR, exist_ok=True)

# --------- Helper Functions: Robust Quantiles, CRPS, AAD ----------

def _nanquantile_1d(a: np.ndarray, q: float) -> float:
    """Return the q-th quantile of ``a``, ignoring NaN entries.

    ``a`` is coerced to a float array and ``q`` is passed straight to
    ``np.quantile``. If no non-NaN value remains, NaN is returned instead of
    raising.
    """
    values = np.asarray(a, float)
    kept = values[~np.isnan(values)]
    if not kept.size:
        return np.nan
    return float(np.quantile(kept, q))

def crps_ens(y: float, samples: np.ndarray) -> float:
    """
    Ensemble CRPS of a scalar observation against a set of samples.

    Uses the empirical-CDF form
        CRPS = E|X - y| - 0.5 * E|X - X'|,
    where E|X - X'| averages |x_i - x_j| over all n*n ordered pairs of
    ensemble members (including i == j).

    NaN samples are discarded. Returns NaN when no samples remain or when
    ``y`` is NaN. With one sample the result is |x - y|.
    """
    members = np.asarray(samples, float)
    members = members[~np.isnan(members)]
    if members.size == 0 or np.isnan(y):
        return np.nan

    members = np.sort(members)
    n = len(members)
    obs = float(y)

    # Mean absolute distance between the ensemble and the observation.
    mae_term = np.mean(np.abs(members - obs))
    if n == 1:
        # A single member has no spread term.
        return float(mae_term)

    # Mean pairwise distance. In the sorted ensemble, the gap between
    # members k-1 and k is straddled by k * (n - k) unordered pairs.
    k = np.arange(1, n)
    gaps = np.diff(members)
    spread_term = 2.0 * np.sum(k * (n - k) * gaps) / (n * n)

    return float(mae_term - 0.5 * spread_term)

def asymmetry_degree(lo: float, med: float, hi: float) -> float:
    """
    Asymmetry Degree (AAD) of a prediction interval around its median.

        AAD = ((hi - med) - (med - lo)) / (hi - lo)

    The value is positive when the upper half of the interval is wider and
    negative when the lower half is wider. For lo <= med <= hi it lies in
    [-1, 1]. The width is floored at 1e-12, so a non-positive width does not
    divide by zero. Returns NaN if any input is NaN.
    """
    lo, med, hi = (float(v) for v in (lo, med, hi))
    if np.isnan([lo, med, hi]).any():
        return np.nan
    upper_half = hi - med
    lower_half = med - lo
    return (upper_half - lower_half) / max(hi - lo, 1e-12)

# --------- Load & Build Data Bundles from Copula Simulations ----------

def load_copula_bundles(sim_csv: str, obs_csv: str, obs_col: str, communes: List[str] = None) -> Tuple[Dict[str, Any], pd.DataFrame]:
    """
    Load simulations and observations, align them by date per commune, and
    compute per-time-step quantiles plus evaluation metrics.

    Parameters
    ----------
    sim_csv : str
        Simulation CSV with ``Date``, ``Commune`` and one column per
        simulation (``X1..XN``). If no column matches that pattern, every
        other numeric column is used instead.
    obs_csv : str
        Observation CSV with ``Date``, ``Commune`` and ``obs_col``.
    obs_col : str
        Observed-series column to evaluate against.
    communes : list of str, optional
        Communes to process. Defaults to every commune in the simulations,
        sorted.

    Returns
    -------
    bundles : dict
        commune -> {"dates", "truth", "q25", "q50", "q75", "mean", "sims"}.
        ``sims`` has shape (N_sims, T); every other array has length T.
    metrics_df : pandas.DataFrame
        One row per processed commune, indexed by ``Commune``. It is empty,
        but still indexed by ``Commune``, when no commune could be processed.
        It is also written to ``OUT_DIR/copula_metrics_by_commune.csv``.

    Raises
    ------
    KeyError
        If ``obs_col`` is missing from the observed CSV.
    RuntimeError
        If no simulation columns can be identified.

    Communes with no simulations, no observations, no overlapping dates, or
    duplicated dates that prevent one-to-one alignment are skipped with a
    printed warning.
    """
    sim_df = pd.read_csv(sim_csv, parse_dates=["Date"])
    sim_df["Commune"] = sim_df["Commune"].astype(str)

    obs_df = pd.read_csv(obs_csv, parse_dates=["Date"])
    obs_df["Commune"] = obs_df["Commune"].astype(str)

    if obs_col not in obs_df.columns:
        raise KeyError(f"Column '{obs_col}' not found in observed CSV.")

    if communes is None:
        communes = sorted(sim_df["Commune"].unique().tolist())

    # Simulation columns: X1, X2, ... or, as a fallback, any other numeric column.
    sim_cols = [c for c in sim_df.columns if re.match(r'^X\d+$', c)]
    if not sim_cols:
        sim_cols = [
            c for c in sim_df.columns
            if c not in ("Date", "Commune") and pd.api.types.is_numeric_dtype(sim_df[c])
        ]
    if not sim_cols:
        raise RuntimeError("No simulation columns found (e.g., X1..XN).")

    bundles: Dict[str, Any] = {}
    metrics_rows: List[Dict[str, Any]] = []

    for comm in communes:
        ssub = sim_df[sim_df["Commune"] == comm]
        if ssub.empty:
            print(f"[WARN] No simulations for commune: {comm}")
            continue

        ocom = obs_df.loc[obs_df["Commune"] == comm, ["Date", obs_col]].rename(columns={obs_col: "Observed"})
        if ocom.empty:
            print(f"[WARN] No observed data for commune: {comm}")
            continue

        # Keep only the dates present in both sources.
        common_dates = np.intersect1d(ssub["Date"].values, ocom["Date"].values)
        if len(common_dates) == 0:
            print(f"[WARN] No overlapping dates for commune: {comm}")
            continue

        ssub = ssub[ssub["Date"].isin(common_dates)].sort_values("Date").reset_index(drop=True)
        ocom = ocom[ocom["Date"].isin(common_dates)].sort_values("Date").reset_index(drop=True)

        # Truth and simulations are paired purely by row position below. A
        # duplicated date on either side would silently misalign the arrays
        # or give them different lengths, so require identical date columns.
        dates = ssub["Date"].to_numpy()
        if not np.array_equal(dates, ocom["Date"].to_numpy()):
            print(f"[WARN] Duplicate dates prevent one-to-one alignment for commune: {comm}")
            continue

        # Simulation matrix: shape (N_sims, T), one row per simulation.
        sims_matrix = ssub[sim_cols].to_numpy(dtype=float).T
        truth = ocom["Observed"].to_numpy(dtype=float)

        # Summaries across simulations (axis=0) for each time step, ignoring NaNs.
        q25 = np.nanquantile(sims_matrix, 0.25, axis=0)
        q50 = np.nanquantile(sims_matrix, 0.50, axis=0)
        q75 = np.nanquantile(sims_matrix, 0.75, axis=0)
        mean = np.nanmean(sims_matrix, axis=0)

        bundle = {
            "dates": dates,
            "truth": truth,
            "q25": q25, "q50": q50, "q75": q75,
            "mean": mean,
            "sims": sims_matrix,  # (N, T)
        }
        bundles[comm] = bundle

        m = _metrics_from_bundle(bundle)
        m["Commune"] = comm
        metrics_rows.append(m)

    if metrics_rows:
        met_df = pd.DataFrame(metrics_rows).set_index("Commune")
    else:
        # pd.DataFrame([]) has no "Commune" column, so set_index would raise.
        print("[WARN] No commune could be evaluated; the metrics table is empty.")
        met_df = pd.DataFrame(index=pd.Index([], name="Commune"))

    os.makedirs(OUT_DIR, exist_ok=True)
    met_path = os.path.join(OUT_DIR, "copula_metrics_by_commune.csv")
    met_df.to_csv(met_path)
    print(f"✓ Metrics CSV saved -> {met_path}")

    return bundles, met_df

# --------- Calculate Metrics from a Bundle (robust to NaN) ----------

def _metrics_from_bundle(bundle: Dict[str, Any]) -> Dict[str, float]:
    """
    Compute evaluation metrics for a single commune bundle.

    Parameters
    ----------
    bundle : dict
        Must provide "truth", "q25", "q50", "q75" (length-T arrays) and
        "sims" with shape (N_sims, T), as built by ``load_copula_bundles``.

    Returns
    -------
    dict
        AW50  : mean width of the 50% interval (q75 - q25), ignoring NaN.
        COV50 : fraction of time steps whose observation lies in [q25, q75],
                counted only over steps where the observation and both
                bounds are non-NaN. NaN if no such step exists.
        AAD50 : mean asymmetry degree of the 50% interval, ignoring NaN.
        CRPS  : mean ensemble CRPS over time steps, ignoring NaN.
        MAE, RMSE : errors of the median (q50) used as the point forecast.
    """
    y = np.asarray(bundle["truth"], float)
    q25 = np.asarray(bundle["q25"], float)
    q50 = np.asarray(bundle["q50"], float)
    q75 = np.asarray(bundle["q75"], float)
    sims = np.asarray(bundle["sims"], float)  # (N, T)

    # AW50: average width of the 50% prediction interval.
    AW50 = float(np.nanmean(q75 - q25))

    # COV50: empirical coverage of the 50% interval. Any comparison with NaN
    # is False, so without this mask a missing observation (or bound) would
    # count as "not covered" and bias coverage downward.
    valid = ~(np.isnan(y) | np.isnan(q25) | np.isnan(q75))
    if valid.any():
        inside = (y[valid] >= q25[valid]) & (y[valid] <= q75[valid])
        COV50 = float(np.mean(inside))
    else:
        COV50 = np.nan

    # AAD50: average asymmetry degree of the 50% prediction interval.
    AAD50 = float(np.nanmean([asymmetry_degree(lo, md, hi) for lo, md, hi in zip(q25, q50, q75)]))

    # CRPS averaged over time. crps_ens discards NaN samples itself and
    # returns NaN for an all-NaN column or a NaN observation.
    crps_vals = [crps_ens(y_t, sims[:, t]) for t, y_t in enumerate(y)]
    CRPS = float(np.nanmean(crps_vals))

    # Point accuracy, using the median as the point forecast.
    MAE = float(np.nanmean(np.abs(y - q50)))
    RMSE = float(np.sqrt(np.nanmean((y - q50) ** 2)))

    return {"AW50": AW50, "COV50": COV50, "AAD50": AAD50, "CRPS": CRPS, "MAE": MAE, "RMSE": RMSE}

# --------- Plotting Panels (Observation + Copula Forecast) ----------

def plot_copula_fullcurve_panels_with_metrics(
    bundles: Dict[str, Any],
    communes: List[str] = None,
    ncols: int = 2,
    fig_width: float = 16,
    row_height: float = 5.0,
    save_basename: str = "paper_panels_COPULA_50band_with_observed_metrics",
    title: str = "Full-curve reconstructions — Observed & Copula (median, mean, 50% band)",
    ylabel: str = "Daily cases (7-day MA)",
):
    """
    Draw one panel per commune comparing observations with copula simulations.

    Each panel shows the observed series, the copula median and mean, the
    50% band (q25-q75), and a box with the bundle's metrics. The figure is
    saved as PNG (600 dpi) and PDF under ``OUT_DIR`` and then shown. A
    commune missing from ``bundles`` leaves its panel blank.

    Parameters
    ----------
    bundles : dict
        commune -> bundle, as returned by ``load_copula_bundles``.
    communes : list of str, optional
        Panel order. Defaults to ``TARGET_COMMUNES``, looked up at call time.
    ncols : int
        Panels per row; values below 1 are treated as 1.
    fig_width, row_height : float
        Figure width and height per row, in inches.
    save_basename : str
        File-name stem for the PNG/PDF outputs.
    title : str
        Figure suptitle. A falsy value omits it.
    ylabel : str
        Y-axis label used on every panel.
    """
    if communes is None:
        communes = TARGET_COMMUNES
    if not communes:
        # plt.subplots(0, ...) would raise; there is nothing to draw anyway.
        print("[WARN] No communes to plot; no figure produced.")
        return

    n = len(communes)
    ncols = max(1, ncols)
    nrows = int(math.ceil(n / ncols))

    # Publication-quality styling, scoped to this figure through rc_context
    # so it does not leak into the caller's global rcParams.
    paper_style = {
        "font.family": "DejaVu Sans",
        "axes.titlesize": 13,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 11,
        "savefig.dpi": 600,
    }

    # Colors and styles
    col_obs    = "#000000"   # Observed data
    col_median = "#1f77b4"   # Copula median (50% quantile)
    col_mean   = "#ff7f0e"   # Copula mean
    col_band   = "#b7c1c9"   # 50% Prediction Interval band
    alpha50    = 0.40
    lw_obs     = 2.0
    lw_line    = 2.0

    with plt.rc_context(paper_style):
        fig, axes = plt.subplots(
            nrows, ncols,
            figsize=(fig_width, row_height * nrows),
            sharex=False, sharey=False
        )
        # Flatten so a 1x1, 1xN or NxM grid can be iterated uniformly.
        axes = np.array(axes).reshape(-1)

        for ax, comm in zip(axes, communes):
            if comm not in bundles:
                ax.axis("off")  # No data for this commune: leave the panel blank
                continue

            b = bundles[comm]
            dates = b["dates"]

            # 50% prediction interval band (Q25 to Q75)
            ax.fill_between(dates, b["q25"], b["q75"],
                            color=col_band, alpha=alpha50, edgecolor="none")

            # Observed (solid black), median (dashed blue), mean (solid orange)
            ax.plot(dates, b["truth"], color=col_obs,    lw=lw_obs,  label=None)
            ax.plot(dates, b["q50"],   color=col_median, linestyle="--", lw=lw_line, label=None)
            ax.plot(dates, b["mean"],  color=col_mean,   linestyle="-",  lw=lw_line, label=None)

            ax.set_title(comm)
            ax.set_ylabel(ylabel)
            ax.grid(True, alpha=0.7)

            # Metrics text box in the top-right corner
            m = _metrics_from_bundle(b)
            lines = [
                f"AW50: {m['AW50']:.1f}",
                f"COV50: {m['COV50']:.3f}",
                f"AAD50: {m['AAD50']:.3f}",
                f"CRPS: {m['CRPS']:.2f}" if not np.isnan(m["CRPS"]) else "CRPS: —",
                f"MAE: {m['MAE']:.2f}",
                f"RMSE: {m['RMSE']:.2f}",
            ]
            ax.text(
                0.98, 0.98, "\n".join(lines),
                transform=ax.transAxes,
                va="top", ha="right",
                fontsize=12.0,
                color="#222222",
                bbox=dict(boxstyle="round,pad=0.35,rounding_size=0.12",
                          fc="white", ec="#4c4d4f", lw=0.9, alpha=0.88)
            )

        # Hide grid cells beyond the number of communes
        for ax in axes[n:]:
            ax.axis("off")

        # One shared legend for all panels
        legend_elems = [
            Patch(facecolor=col_band, edgecolor="none", alpha=alpha50, label="Copula 50% band"),
            Line2D([0], [0], color=col_median, lw=lw_line, linestyle="--", label="Copula median"),
            Line2D([0], [0], color=col_mean,   lw=lw_line, linestyle="-",  label="Copula mean"),
            Line2D([0], [0], color=col_obs,    lw=lw_obs,  linestyle="-",  label="Observed"),
        ]
        fig.legend(legend_elems, [h.get_label() for h in legend_elems],
                   loc="lower center", ncol=4, frameon=False)

        # Leave room at the bottom for the legend (and at the top for the title)
        if title:
            fig.suptitle(title, y=0.99, fontsize=14)
            fig.tight_layout(rect=[0, 0.07, 1, 0.96])
        else:
            fig.tight_layout(rect=[0, 0.07, 1, 1])

        png_path = os.path.join(OUT_DIR, f"{save_basename}.png")
        pdf_path = os.path.join(OUT_DIR, f"{save_basename}.pdf")
        fig.savefig(png_path, dpi=600, bbox_inches="tight")
        fig.savefig(pdf_path, bbox_inches="tight")
        plt.show()

    print(f"✓ Saved: {png_path}")
    print(f"✓ Saved: {pdf_path}")

# ============================
# RUN: Load -> Metrics -> Plot
# ============================
if __name__ == "__main__":
    # Step 1: load simulations and observations and compute per-commune metrics
    bundles_copula, metrics_copula = load_copula_bundles(
        sim_csv=SIM_PATH, obs_csv=OBS_PATH, obs_col=OBS_COL, communes=TARGET_COMMUNES,
    )

    # Step 2: render and export the multi-panel figure (PNG + PDF)
    plot_kwargs = dict(
        ncols=2,
        fig_width=16,
        row_height=5.0,
        save_basename="paper_panels_COPULA_50band_with_observed_metrics",
        title="Full-curve Reconstructions: Observed vs. Copula Simulation",
    )
    plot_copula_fullcurve_panels_with_metrics(
        bundles=bundles_copula, communes=TARGET_COMMUNES, **plot_kwargs
    )

    # Step 3 (optional): print the metrics table rounded to 3 decimal places
    print("\n=== Copula Metrics by Commune (Rounded) ===")
    decimals = {col: 3 for col in ("AW50", "COV50", "AAD50", "CRPS", "MAE", "RMSE")}
    print(metrics_copula.round(decimals))
