In [8]:
# Auto-plot all runs discovered under any 'results/' folder (80/20 + LOO, tuned + untuned, S3 + S8)
# Writes PDFs into: figs/<Protocol>/<tuned|untuned>/<Site>/<window>/<run-prefix>__*.pdf
# TITLES: split across two lines to avoid overflow.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.offsetbox import AnchoredText
from matplotlib.backends.backend_pdf import PdfPages
from pathlib import Path
import re
from math import sqrt
from typing import Optional, Tuple, Dict, List

# styling constants shared by every plot below
TICK_LABEL_SIZE  = 14
LEGEND_FONT_SIZE = 14
AXIS_LABEL_SIZE  = 14
TITLE_FONT_SIZE  = 12    # kept smaller than the other text so long titles do not overflow
TITLE_LINE_BREAK = True  # read by make_title(): True -> two-line titles, False -> "base: suffix"

# Font embedding for vector output: 42 selects TrueType (Type 42) fonts in PDF/PS
# (per the matplotlib rcParams documentation) instead of Type 3.
plt.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})

# helpers
def _legend_dedup(ax, loc="best"):
    """Draw a legend on *ax* in which every label appears only once.

    The first handle seen for a label wins and label order is preserved.
    If the axes has no labelled artists, no legend is drawn.
    """
    handles, labels = ax.get_legend_handles_labels()
    if not labels:
        return
    first_handle_by_label = {}
    for handle, label in zip(handles, labels):
        first_handle_by_label.setdefault(label, handle)
    ax.legend(
        list(first_handle_by_label.values()),
        list(first_handle_by_label.keys()),
        loc=loc,
        fontsize=LEGEND_FONT_SIZE,
    )

def _sanitize(s: str) -> str:
    """Return a filename-friendly version of *s*.

    The input is converted with ``str()``. Every run of characters outside
    ``[A-Za-z0-9._-]`` is collapsed to a single ``_``, and leading/trailing
    underscores are then trimmed.
    """
    collapsed = re.sub(r"[^A-Za-z0-9._-]+", "_", str(s))
    return collapsed.strip("_")

def parse_meta_from_preds_filename(p: Path) -> Dict[str, Optional[str]]:
    """Extract run metadata from a preds CSV filename.

    Expected stems (tokens separated by ``__``)::

      8020__<lag>__<zero>__<window>__<tuned/untuned>__<target>__preds__<ts>
      loo__<lag>__<zero>__<window>__<tuned/untuned>__<target>__preds__<ts>

    Fields are read backwards from the first ``preds`` token, so missing
    leading tokens just leave those keys as ``None``. ``tuned`` and ``protocol``
    are lower-cased. The returned dict always has every key, and ``run_prefix``
    is the full stem.
    """
    parts = p.stem.split("__")
    meta: Dict[str, Optional[str]] = dict.fromkeys(
        ("protocol", "lag", "zero", "window", "tuned", "target")
    )
    meta["run_prefix"] = p.stem

    try:
        anchor = parts.index("preds")
    except ValueError:
        # No "preds" marker: keep the all-None skeleton.
        return meta

    # Nearest field to "preds" first, walking towards the start of the stem.
    fields_before_preds = ("target", "tuned", "window", "zero", "lag", "protocol")
    lowered_fields = ("tuned", "protocol")
    for offset, key in enumerate(fields_before_preds, start=1):
        position = anchor - offset
        if position < 0:
            break
        token = parts[position]
        meta[key] = token.lower() if key in lowered_fields else token
    return meta

def pretty_protocol(proto: Optional[str]) -> str:
    """Map a protocol token to the folder/label name used for outputs.

    Matching is case-insensitive:

    - ``None`` or empty gives ``"Unknown"``.
    - 80/20 spellings give ``"80_20"``.
    - Leave-one-out spellings give ``"LOO"``.
    - Anything else is lower-cased and passed through ``_sanitize``.
    """
    if not proto:
        return "Unknown"
    key = proto.lower()
    aliases = {
        "8020": "80_20", "80_20": "80_20", "80-20": "80_20",
        "loo": "LOO", "l-o-o": "LOO",
    }
    if key in aliases:
        return aliases[key]
    return _sanitize(key)

def pretty_tuned(tok: Optional[str]) -> str:
    """Normalise a tuning token to ``"tuned"`` or ``"untuned"`` (case-insensitive).

    ``None`` or empty counts as ``"untuned"``. Tokens that are not recognised
    are lower-cased and passed through ``_sanitize``.
    """
    if not tok:
        return "untuned"
    canonical = {
        "tuned": "tuned", "true": "tuned",
        "untuned": "untuned", "false": "untuned", "default": "untuned",
    }
    lowered = tok.lower()
    if lowered in canonical:
        return canonical[lowered]
    return _sanitize(lowered)

def site_species_labels(target: Optional[str]) -> Tuple[str, str, str]:
    """Derive ``(site label, species short name, species code)`` from a target name.

    The site number comes from an ``Adults_<digits>_`` fragment. Species is
    decided by the suffix: ``_Col`` gives A. coluzzii, and ``_Gam`` / ``_Gam0``
    give A. gambiae. Any part that cannot be determined keeps its ``?`` placeholder.
    """
    site, species, code = "Site ?", "Species ?", "?"
    if not target:
        return site, species, code

    site_match = re.search(r"Adults_(\d+)_", target)
    if site_match is not None:
        site = "Site " + site_match.group(1)

    if target.endswith("_Col"):
        species, code = "A. coluzzii", "Col"
    elif target.endswith(("_Gam", "_Gam0")):
        species, code = "A. gambiae", "Gam"
    return site, species, code

def pretty_zero(z: Optional[str]) -> str:
    """Return the human-readable description of a zero-handling token.

    ``None`` or empty gives an em dash. The two known strategies get fixed
    descriptions, and any other token is passed through ``_sanitize``.
    """
    if not z:
        return "—"
    descriptions = {
        "standard": "TRAIN & TEST: 0→NaN",
        "train_keeps_zeros": "TRAIN keeps zeros; TEST 0→NaN",
    }
    if z in descriptions:
        return descriptions[z]
    return _sanitize(z)

def pretty_lag(l: Optional[str]) -> str:
    """Return the display label for a lag-feature token.

    Known tokens map as follows:

    - ``no_lag`` gives ``No-lag``.
    - ``precip_lags`` gives ``Precip-lags``.
    - ``both_lags`` gives ``Both-lags``.

    ``None`` or empty gives an em dash. Any other token is passed through ``_sanitize``.
    """
    if not l:
        return "—"
    known_labels = {
        "no_lag": "No-lag",
        "precip_lags": "Precip-lags",
        "both_lags": "Both-lags",
    }
    return known_labels[l] if l in known_labels else _sanitize(l)

def infer_window_from_dates(df: pd.DataFrame, date_col="Date") -> str:
    """Describe the date span of *df* as ``"<min date>_<max date>"``.

    Returns ``"unknown_window"`` if *date_col* is absent or contains no
    parseable dates.
    """
    if date_col not in df.columns:
        return "unknown_window"
    dates = pd.to_datetime(df[date_col])
    start, end = dates.min(), dates.max()
    if pd.isna(start) or pd.isna(end):
        return "unknown_window"
    return f"{start.date()}_{end.date()}"

def window_token_to_pretty(wtok: Optional[str]) -> str:
    """Normalise a window token to ``"<start>_<end>"``.

    The first ``_to_`` (checked first) or ``:`` separator is replaced by ``_``.
    Tokens without either separator are returned unchanged. ``None`` or empty
    gives ``"unknown_window"``.
    """
    if not wtok:
        return "unknown_window"
    for separator in ("_to_", ":"):
        if separator in wtok:
            start, end = wtok.split(separator, 1)
            return f"{start}_{end}"
    return wtok

def make_title(base_left: str, base_right: str, suffix: str,
               line_break: Optional[bool] = None) -> str:
    """Build a plot title from a shared base and a plot-specific suffix.

    The base is ``"<base_left> — <base_right>"``. If line breaking is on, the
    suffix goes on a second line (base and suffix joined by a newline).
    Otherwise the result is ``"<base>: <suffix>"`` on a single line.

    Args:
        base_left: left part of the base (e.g. site and species).
        base_right: right part of the base (run descriptors).
        suffix: plot-specific description.
        line_break: per-call override. ``None`` (the default) falls back to the
            module-level ``TITLE_LINE_BREAK`` switch.
    """
    if line_break is None:
        line_break = TITLE_LINE_BREAK
    base = f"{base_left} — {base_right}"
    return f"{base}\n{suffix}" if line_break else f"{base}: {suffix}"

# core plotting for a single preds CSV
def _read_preds_csv(csv_path: Path) -> pd.DataFrame:
    """Load a ``__preds__`` CSV, validate it and return it sorted by ``Date``.

    The delimiter is sniffed (``sep=None`` with the python engine).

    Raises:
        ValueError: if ``Date``, ``True`` or ``Predicted`` is missing, or the
            file has no data rows.
    """
    df = pd.read_csv(csv_path, sep=None, engine="python")
    for col in ("Date", "True", "Predicted"):
        if col not in df.columns:
            raise ValueError(f"[{csv_path.name}] missing '{col}' column")
    if df.empty:
        raise ValueError(f"[{csv_path.name}] contains no rows")
    df["Date"] = pd.to_datetime(df["Date"])
    return df.sort_values("Date", ignore_index=True)


def _missing_mask(df: pd.DataFrame) -> pd.Series:
    """Return a bool Series marking rows whose residual is treated as NaN.

    If a ``was_missing`` column exists, a row is marked when its value reads
    "true" (case-insensitive, surrounding whitespace ignored); anything else
    counts as False. Without that column, rows with ``True == 0`` are marked.
    """
    if "was_missing" in df.columns:
        # .eq("true") always yields bool dtype. The former .map({...}).fillna(False)
        # could leave an object-dtype Series, on which ``~mask`` returns the
        # integers -1/-2 instead of booleans and breaks .loc indexing.
        return df["was_missing"].astype(str).str.strip().str.lower().eq("true")
    return df["True"] == 0


def _preds_metrics(df: pd.DataFrame,
                   mask_nan: pd.Series) -> Tuple[Optional[float], Optional[float]]:
    """Return ``(MAE, RMSE)`` for a preds frame.

    The first numeric value of an ``MAE`` / ``RMSE`` column is preferred.
    Otherwise the metric is computed over rows that are not in *mask_nan* and
    have finite True/Predicted values. A metric stays ``None`` if no such row
    exists.
    """
    def _first_numeric(col: str) -> Optional[float]:
        if col in df.columns:
            values = pd.to_numeric(df[col], errors="coerce").dropna()
            if not values.empty:
                return float(values.iloc[0])
        return None

    mae_val = _first_numeric("MAE")
    rmse_val = _first_numeric("RMSE")
    if mae_val is None or rmse_val is None:
        y_true = df.loc[~mask_nan, "True"].astype(float).to_numpy()
        y_pred = df.loc[~mask_nan, "Predicted"].astype(float).to_numpy()
        # Drop NaN/inf pairs so a single bad row does not turn the metric into NaN.
        finite = np.isfinite(y_true) & np.isfinite(y_pred)
        err = y_pred[finite] - y_true[finite]
        if err.size:
            if mae_val is None:
                mae_val = float(np.mean(np.abs(err)))
            if rmse_val is None:
                rmse_val = float(np.sqrt(np.mean(err ** 2)))
    return mae_val, rmse_val


def plot_one_preds_csv(csv_path: Path,
                       out_root: Path,
                       time_legend_loc: str = "upper left",
                       other_legend_loc: str = "best",
                       metrics_box_loc: str = "upper right") -> List[Path]:
    """Render the four diagnostic plots for one ``__preds__`` CSV.

    Output goes to ``out_root/<protocol>/<tuned>/<site>/<window>/``. Each plot
    is saved as its own PDF and also appended as a page of
    ``<run>__ALL.pdf``. The per-plot files are:

    - ``<run>__time_series.pdf``
    - ``<run>__residuals_time.pdf``
    - ``<run>__residual_hist.pdf``
    - ``<run>__true_vs_pred_scatter.pdf``

    Args:
        csv_path: preds CSV with ``Date``, ``True`` and ``Predicted`` columns.
            Optional columns are ``was_missing``, ``Residual``, ``MAE`` and ``RMSE``.
        out_root: root of the output figure tree.
        time_legend_loc: legend location for the time-series plot.
        other_legend_loc: legend location for the residual and scatter plots.
        metrics_box_loc: location of the MAE/RMSE box drawn on every plot.

    Returns:
        Paths of the four per-plot PDFs. The combined ``__ALL.pdf`` is not listed.

    Raises:
        ValueError: if a required column is missing or the CSV has no rows.
    """
    meta     = parse_meta_from_preds_filename(csv_path)
    proto    = pretty_protocol(meta.get("protocol"))
    tuned    = pretty_tuned(meta.get("tuned"))
    target   = meta.get("target")
    lag_lbl  = pretty_lag(meta.get("lag"))
    zero_lbl = pretty_zero(meta.get("zero"))

    # Load and derive everything before creating any output, so a bad CSV
    # fails before a figure or PDF is opened.
    df = _read_preds_csv(csv_path)
    mask_nan = _missing_mask(df)
    known = ~mask_nan

    # Residual = True - Predicted (thesis convention). It is only filled in when
    # the CSV does not already carry usable residuals.
    if "Residual" not in df.columns or df["Residual"].isna().all():
        df["Residual"] = np.where(
            known, df["True"].astype(float) - df["Predicted"].astype(float), np.nan
        )

    mae_val, rmse_val = _preds_metrics(df, mask_nan)

    site_label, species_short, _species_code = site_species_labels(target)
    win_token = meta.get("window")
    win_txt = window_token_to_pretty(win_token) if win_token else infer_window_from_dates(df)

    out_dir = out_root / proto / tuned / site_label / win_txt
    out_dir.mkdir(parents=True, exist_ok=True)

    base_left  = f"{site_label}, {species_short}"
    base_right = f"{proto} — {lag_lbl} — {zero_lbl} — {tuned} — {win_txt}"
    run_prefix = _sanitize(meta.get("run_prefix") or csv_path.stem)
    y_species  = f"{species_short} count ({site_label})"
    narrow_title_size = 9  # histogram/scatter figures are narrower than the time plots

    # A few days of padding so the first and last markers are not clipped.
    date_pad = pd.Timedelta(days=3)
    x_min = df["Date"].min() - date_pad
    x_max = df["Date"].max() + date_pad

    # Scatter data plus limits that ignore NaN/inf values.
    x_true = df["True"].astype(float).to_numpy()
    y_pred = df["Predicted"].astype(float).to_numpy()
    miss = mask_nan.to_numpy(dtype=bool)
    finite_vals = np.concatenate([x_true, y_pred])
    finite_vals = finite_vals[np.isfinite(finite_vals)]
    if finite_vals.size:
        data_min, data_max = float(finite_vals.min()), float(finite_vals.max())
    else:
        data_min, data_max = 0.0, 1.0  # nothing plottable: fall back to a unit box
    span = max(1e-9, data_max - data_min)
    pad_val = max(0.5, 0.06 * span)
    lims = (data_min - pad_val, data_max + pad_val)

    def _metrics_box(ax):
        parts = []
        if mae_val is not None:
            parts.append(f"MAE: {mae_val:.3f}")
        if rmse_val is not None:
            parts.append(f"RMSE: {rmse_val:.3f}")
        if parts:
            at = AnchoredText("\n".join(parts), loc=metrics_box_loc,
                              prop=dict(size=16), frameon=True, borderpad=0.8)
            at.patch.set_alpha(0.75)
            ax.add_artist(at)

    def _format_month_axis(ax):
        ax.xaxis.set_major_locator(mdates.MonthLocator())
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%b"))

    def _draw_time_series(ax):
        ax.plot(df["Date"], df["True"],      marker="o", linewidth=1.5, label="True")
        ax.plot(df["Date"], df["Predicted"], marker="s", linewidth=1.5, label="Predicted")
        if mask_nan.any():
            ax.scatter(df.loc[mask_nan, "Date"], df.loc[mask_nan, "True"],
                       marker="o", s=90, label="True=0 (residual NaN)")
            ax.scatter(df.loc[mask_nan, "Date"], df.loc[mask_nan, "Predicted"],
                       marker="s", s=90, label="Predicted at True=0")
        ax.set_xlim(x_min, x_max)
        ax.margins(x=0.0, y=0.06)
        ax.set_title(make_title(base_left, base_right, "True vs Predicted over time"),
                     fontsize=TITLE_FONT_SIZE, wrap=True)
        ax.set_xlabel("Date", fontsize=AXIS_LABEL_SIZE)
        ax.set_ylabel(y_species, fontsize=AXIS_LABEL_SIZE)
        _format_month_axis(ax)

    def _draw_residual_series(ax):
        ax.axhline(0, linestyle="--", linewidth=1)
        ax.plot(df.loc[known, "Date"], df.loc[known, "Residual"], marker="o", linewidth=1.5)
        ax.set_xlim(x_min, x_max)
        ax.margins(x=0.0, y=0.08)
        ax.set_title(make_title(base_left, base_right, "Residuals over time (True − Predicted)"),
                     fontsize=TITLE_FONT_SIZE, wrap=True)
        ax.set_xlabel("Date", fontsize=AXIS_LABEL_SIZE)
        ax.set_ylabel("Residual (True − Predicted)", fontsize=AXIS_LABEL_SIZE)
        _format_month_axis(ax)

    def _draw_residual_hist(ax):
        ax.hist(df.loc[known, "Residual"].dropna(), bins="auto")
        ax.set_title(make_title(base_left, base_right, "Residual distribution (True − Predicted)"),
                     fontsize=narrow_title_size, wrap=True)
        ax.set_xlabel("Residual (True − Predicted)", fontsize=AXIS_LABEL_SIZE)
        ax.set_ylabel("Frequency", fontsize=AXIS_LABEL_SIZE)

    def _draw_scatter(ax):
        ax.scatter(x_true[~miss], y_pred[~miss], marker="o", alpha=0.9, label="Observed")
        if miss.any():
            ax.scatter(x_true[miss], y_pred[miss], marker="o", s=90, label="Residual NaN (True=0)")
        ax.plot(lims, lims, linestyle="--", linewidth=1)
        ax.set_xlim(lims)
        ax.set_ylim(lims)
        ax.set_title(make_title(base_left, base_right, "True vs Predicted (scatter)"),
                     fontsize=narrow_title_size, wrap=True)
        ax.set_xlabel(f"True count ({species_short}, {site_label})", fontsize=AXIS_LABEL_SIZE)
        ax.set_ylabel(f"Predicted count ({species_short}, {site_label})", fontsize=AXIS_LABEL_SIZE)

    def _render(pdf_pages, filename: str, figsize, legend_loc: str, draw) -> Path:
        """Draw one figure with *draw* and the shared decorations, then save it.

        The figure is written to its own PDF and appended to *pdf_pages*.
        """
        path = out_dir / f"{run_prefix}__{filename}"
        fig, ax = plt.subplots(figsize=figsize)
        try:
            draw(ax)
            ax.grid(True, alpha=0.3)
            _metrics_box(ax)
            _legend_dedup(ax, loc=legend_loc)
            ax.tick_params(axis="both", labelsize=TICK_LABEL_SIZE)
            fig.tight_layout()
            fig.savefig(path, bbox_inches="tight")
            pdf_pages.savefig(fig)
        finally:
            # Close even on failure: in a long batch, leaked figures pile up in pyplot.
            plt.close(fig)
        return path

    # The context manager finalises __ALL.pdf even if one of the plots raises.
    with PdfPages(out_dir / f"{run_prefix}__ALL.pdf") as pdf_pages:
        paths_out: List[Path] = [
            _render(pdf_pages, "time_series.pdf", (11, 5), time_legend_loc, _draw_time_series),
            _render(pdf_pages, "residuals_time.pdf", (11, 4), other_legend_loc, _draw_residual_series),
            _render(pdf_pages, "residual_hist.pdf", (7, 4), other_legend_loc, _draw_residual_hist),
            _render(pdf_pages, "true_vs_pred_scatter.pdf", (6.8, 6.8), other_legend_loc, _draw_scatter),
        ]
    return paths_out

# discover all preds CSVs under any 'results/'
def find_all_results_dirs(start: Path) -> List[Path]:
    """Return every directory named ``results`` below *start*, sorted by path.

    ``Path.rglob`` yields matches in filesystem order. Sorting makes the batch
    order (and therefore the log) deterministic, matching ``find_all_preds_csvs``.
    Regular files named ``results`` are ignored.
    """
    return sorted(p for p in start.rglob("results") if p.is_dir())

def find_all_preds_csvs(results_dir: Path) -> List[Path]:
    """List every ``*__preds__*.csv`` at any depth under *results_dir*, in sorted path order."""
    matches = list(results_dir.glob("**/*__preds__*.csv"))
    matches.sort()
    return matches

# main batch
def plot_everything(repo_root: Optional[Path] = None,
                    out_root: Path = Path("figs")) -> None:
    """Plot every ``__preds__`` CSV found under any ``results/`` directory.

    Args:
        repo_root: directory to search. ``None`` (the default) means the
            current working directory at call time. The former ``Path.cwd()``
            default was evaluated only once, when the function was defined.
        out_root: root of the output figure tree.

    Raises:
        FileNotFoundError: if no ``results`` directory exists under *repo_root*.

    A CSV that fails to plot is reported with a ``[warn]`` line and skipped,
    so one bad run does not abort the batch.
    """
    if repo_root is None:
        repo_root = Path.cwd()
    results_dirs = find_all_results_dirs(repo_root)
    if not results_dirs:
        raise FileNotFoundError("No 'results' directories found under this repo.")
    total_csv = 0
    print(f"[info] found {len(results_dirs)} 'results' dirs")

    # A results/ dir nested inside another results/ dir is searched twice by
    # rglob. Track resolved paths so each CSV is plotted only once.
    already_plotted = set()
    for rdir in results_dirs:
        preds_files = find_all_preds_csvs(rdir)
        if not preds_files:
            print(f"[skip] {rdir} (no __preds__ CSVs)")
            continue
        print(f"[info] {rdir}: {len(preds_files)} preds files")
        for csv_path in preds_files:
            key = csv_path.resolve()
            if key in already_plotted:
                continue
            already_plotted.add(key)
            try:
                written = plot_one_preds_csv(csv_path, out_root=out_root)
            except Exception as e:  # batch boundary: report and keep going
                print(f"[warn] {csv_path}: {e}")
                continue
            total_csv += 1
            print(f"[ok] {csv_path.name} -> {len(written)} PDFs")

    print(f"[done] plotted {total_csv} runs. Output tree at: {out_root.resolve()}")

# final run: search from the current working directory and write the figure tree into ./figs
plot_everything(out_root=Path("figs"), repo_root=Path.cwd())


[info] found 8 'results' dirs
[info] /Users/pradumchauhan/Desktop/FINAL_THESIS_17_NOV/Pipeline_MICE/pipeline_80_20_NoTuned/pipeline_80_20_2015_S3_NoTuned/results: 24 preds files
[ok] 8020__both_lags__standard__2015-07-01_2015-10-21__untuned__Adults_3_Col__preds__20251009_012851.csv -> 4 PDFs
[ok] 8020__both_lags__standard__2015-07-01_2015-10-21__untuned__Adults_3_Gam__preds__20251009_012851.csv -> 4 PDFs
[ok] 8020__both_lags__standard__2015-07-01_2015-10-31__untuned__Adults_3_Col__preds__20251009_012852.csv -> 4 PDFs
[ok] 8020__both_lags__standard__2015-07-01_2015-10-31__untuned__Adults_3_Gam__preds__20251009_012852.csv -> 4 PDFs
[ok] 8020__both_lags__train_keeps_zeros__2015-07-01_2015-10-21__untuned__Adults_3_Col__preds__20251009_012852.csv -> 4 PDFs
[ok] 8020__both_lags__train_keeps_zeros__2015-07-01_2015-10-21__untuned__Adults_3_Gam__preds__20251009_012852.csv -> 4 PDFs
[ok] 8020__both_lags__train_keeps_zeros__2015-07-01_2015-10-31__untuned__Adults_3_Col__preds__20251009_012853.csv 