# Git

## If a commit fails the black / pre-commit checks:
- To ignore the failing pre-commit hooks / local checks and just commit + push to your branch,
- bypass them with `--no-verify` — this works as long as your remote (GitHub/GitLab) doesn’t block pushes via branch protection.
    - Commit without running hooks (pre-commit / git hooks)
        git commit -m "WIP: push despite failing hooks" --no-verify
    - Push without running pre-push hooks
        git push origin <your-branch> --no-verify

In [None]:
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# --------------------------
# small safety helpers
# --------------------------
def _existing_cols(df, cols):
    """Filter ``cols`` down to the names found in ``df.columns``, keeping their order.

    Selecting with the result cannot raise a KeyError for a missing column.
    """
    present = []
    for name in cols:
        if name in df.columns:
            present.append(name)
    return present

def _to_numeric(s):
    """Coerce ``s`` to numeric with pandas; values that cannot be parsed become NaN."""
    converted = pd.to_numeric(s, errors="coerce")
    return converted

def _ensure_numeric_cols(df, cols):
    """Coerce each listed column that exists in ``df`` to numeric (bad values -> NaN).

    ``df`` is modified in place and also returned; names missing from ``df`` are ignored.
    """
    targets = [name for name in cols if name in df.columns]
    for name in targets:
        df[name] = pd.to_numeric(df[name], errors="coerce")
    return df

def _ensure_int_flag(df, col):
    """Turn a flag column into int64 so that ``.mean()`` works on it.

    Values are coerced to numeric, anything unparseable or missing becomes 0.
    ``df`` is modified in place and returned; a missing column is a no-op.
    """
    if col not in df.columns:
        return df
    numeric = pd.to_numeric(df[col], errors="coerce")
    df[col] = numeric.fillna(0).astype("int64")
    return df

# --------------------------
# tidy interval stats (overall OR by group)
# --------------------------
_PCTS = [(0.10, "p10"), (0.25, "p25"), (0.50, "p50"), (0.75, "p75"), (0.90, "p90")]

def interval_summary(df, metrics, group_cols=None):
    """
    Summarise numeric interval columns, overall or per group.

    Parameters
    ----------
    df : DataFrame holding the metric (and optional grouping) columns. Not modified.
    metrics : column names to summarise; names missing from ``df`` are ignored.
    group_cols : optional list of grouping columns; names missing from ``df`` are
        ignored, and if none remain the overall summary is returned instead.

    Returns
    -------
    Tidy table that ALWAYS carries the full statistic schema
    (count, mean, std, min, p10, p25, p50, p75, p90, max):
    - overall: index=metric
    - grouped: columns=[*group_cols, metric, <statistics>]
    Series with no valid values get count=0 and NaN for every other statistic.
    When no requested metric exists the result is an empty table with the same columns.
    """
    stat_cols = ["count", "mean", "std", "min", *(name for _, name in _PCTS), "max"]
    quantile_levels = [q for q, _ in _PCTS]

    metrics = [m for m in metrics if m in df.columns]
    group_cols = [c for c in (group_cols or []) if c in df.columns]

    if not metrics:
        # Keep the schema even when there is nothing to summarise, so callers can
        # still address columns such as "metric" or "p50" without a KeyError.
        if group_cols:
            return pd.DataFrame(columns=group_cols + ["metric"] + stat_cols)
        return pd.DataFrame(columns=stat_cols, index=pd.Index([], name="metric"))

    # Work on a copy so the numeric coercion never leaks into the caller's frame.
    df = df.copy()
    for m in metrics:
        df[m] = pd.to_numeric(df[m], errors="coerce")

    def _describe(s):
        s = s.dropna()
        n = int(s.shape[0])
        if n == 0:
            # Remaining statistics are filled with NaN by the final reindex.
            return {"count": 0}
        stats = {
            "count": n,
            "mean": float(s.mean()),
            "std": float(s.std(ddof=1)),  # sample std; NaN for a single value
            "min": float(s.min()),
            "max": float(s.max()),
        }
        # One quantile call for all levels; results come back in the requested order.
        values = s.quantile(quantile_levels).to_numpy()
        for (_, name), value in zip(_PCTS, values):
            stats[name] = float(value)
        return stats

    if not group_cols:
        rows = [{"metric": m, **_describe(df[m])} for m in metrics]
        out = pd.DataFrame(rows).set_index("metric")
        return out.reindex(columns=stat_cols)

    rows = []
    for keys, sub in df.groupby(group_cols, dropna=False):
        # Depending on the pandas version a single grouping column yields a scalar key.
        if not isinstance(keys, tuple):
            keys = (keys,)
        key_dict = dict(zip(group_cols, keys))
        rows.extend({**key_dict, "metric": m, **_describe(sub[m])} for m in metrics)
    out = pd.DataFrame(rows)
    return out.reindex(columns=group_cols + ["metric"] + stat_cols)

# --------------------------
# plots
# --------------------------
def plot_metric_histograms(df, metric, group_col=None, outpath=None, bins=40):
    """Draw one histogram of ``metric`` over all rows of ``df``.

    ``group_col`` does not split the data here; it only decides whether the
    title carries the " (overall)" suffix. Returns ``outpath`` after plotting,
    or None when the column is missing or has no numeric values.
    """
    if metric not in df.columns:
        return None

    values = pd.to_numeric(df[metric], errors="coerce").dropna()
    if values.empty:
        return None

    suffix = "" if group_col else " (overall)"

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.hist(values, bins=bins)
    ax.set_title(f"Distribution: {metric}" + suffix)
    ax.set_xlabel("Days")
    ax.set_ylabel("Count")
    fig.tight_layout()

    # Only write a file when a target was supplied.
    if outpath:
        fig.savefig(outpath, dpi=200)
    plt.show()
    return outpath

def plot_group_median_bar(summary_df, group_col, metric, outpath=None):
    """Bar chart of the median (``p50``) per group for ONE metric.

    Parameters
    ----------
    summary_df : tidy summary with at least the columns ``metric``, ``group_col``
        and ``p50`` (the grouped output of ``interval_summary`` has this shape).
    group_col : column whose values label the bars.
    metric : value of the ``metric`` column to plot.
    outpath : optional file path; the figure is saved there when given.

    Returns
    -------
    ``outpath`` once a chart has been drawn, or None when nothing can be plotted
    (empty input, a required column is missing, or no numeric median is left
    for ``metric``).
    """
    # Validate all required columns up front: filtering on "metric" before this
    # check would raise a KeyError for an overall-shaped or column-less table.
    required = {"metric", group_col, "p50"}
    if summary_df.empty or not required.issubset(summary_df.columns):
        return None

    sdf = summary_df[summary_df["metric"] == metric].copy()
    sdf["p50"] = pd.to_numeric(sdf["p50"], errors="coerce")
    sdf = sdf.dropna(subset=["p50"])
    if sdf.empty:
        # Same convention as the histogram helper: no data, no figure, no file.
        return None

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(sdf[group_col].astype(str), sdf["p50"].astype(float))
    ax.set_title(f"Median {metric} by {group_col}")
    ax.set_ylabel("Days (median)")
    ax.set_xlabel(group_col)
    ax.tick_params(axis="x", rotation=45)
    fig.tight_layout()
    if outpath:
        fig.savefig(outpath, dpi=200)
    plt.show()
    return outpath

def pickup_rule_tables(di, wip_col="wip_band", gap_col="gap_band", event_col="event_newcase"):
    """
    Cross-tabulate the event flag by workload band and gap band.

    Returns a pair of wide tables (rows = ``wip_col``, columns = ``gap_col``):
      prob: mean of ``event_col`` per cell, i.e. P(new case today | wip_band, gap_band)
      counts: number of rows (staff-days) per cell
    Two empty DataFrames are returned when any of the three columns is missing.
    """
    wanted = [wip_col, gap_col, event_col]
    if len(_existing_cols(di, wanted)) < 3:
        return pd.DataFrame(), pd.DataFrame()

    cells = di[wanted].copy()
    # The flag must be numeric for the mean to be computable.
    cells[event_col] = pd.to_numeric(cells[event_col], errors="coerce")

    per_cell = cells.groupby([wip_col, gap_col], dropna=False)[event_col]
    prob = per_cell.mean().unstack(gap_col)
    counts = per_cell.size().unstack(gap_col)
    return prob, counts

def plot_prob_heatmap(prob, outpath=None, title="P(new case | workload, gap)"):
    """Render ``prob`` as a matplotlib heatmap (rows on the y axis, columns on x).

    Every cell is coerced to float first; non-numeric or missing cells are drawn
    as 0.0. Returns ``outpath`` after plotting, or None for an empty table.
    """
    if prob.empty:
        return None

    numeric = prob.apply(pd.to_numeric, errors="coerce").fillna(0.0)
    grid = numeric.to_numpy(dtype=float)
    row_labels = [str(label) for label in numeric.index]
    col_labels = [str(label) for label in numeric.columns]

    fig, ax = plt.subplots(figsize=(7, 3.5))
    image = ax.imshow(grid, aspect="auto")
    ax.set_title(title)
    ax.set_yticks(range(len(row_labels)))
    ax.set_yticklabels(row_labels)
    ax.set_xticks(range(len(col_labels)))
    ax.set_xticklabels(col_labels, rotation=45, ha="right")
    fig.colorbar(image, ax=ax, fraction=0.02, pad=0.03)
    fig.tight_layout()

    if outpath:
        fig.savefig(outpath, dpi=200)
    plt.show()
    return outpath

# --------------------------
# THE wrapper
# --------------------------
def run_interval_outputs(
    typed,
    di,
    group_col="case_type",
    outdir="data/out/plot/plots",
    interval_metrics=None,
):
    """
    Produce the interval-distribution tables, the pickup-rule tables and their plots.

    Replaces the hard-coded 'by case_type' section: call it with
    group_col='case_type' OR group_col='application_type' etc.

    Parameters
    ----------
    typed : not used by this function; kept so existing calls keep working.
    di : DataFrame that provides the interval metrics, ``group_col`` and the
        pickup-rule columns (wip_band, gap_band, event_newcase).
    group_col : column used for the per-group breakdown.
    outdir : directory for the CSV/PNG outputs; created if needed.
    interval_metrics : metric columns to summarise; defaults to
        days_to_pg_signoff, days_alloc_to_close and inter_pickup_days.

    Returns
    -------
    dict with keys "overall", "by_group", "pickup_prob", "pickup_counts".

    Notes
    -----
    ``display`` is not defined in this file; it is presumably the IPython/Jupyter
    built-in, so the function is expected to run inside a notebook.
    """
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    # choose metrics you want in the interval distribution table
    if interval_metrics is None:
        interval_metrics = ["days_to_pg_signoff", "days_alloc_to_close", "inter_pickup_days"]

    # ---------- interval distributions (overall) ----------
    overall = interval_summary(di, interval_metrics, group_cols=None)
    display(overall)

    # save tidy table
    overall.to_csv(outdir / "interval_summary_overall.csv")

    # simple histogram for the key metrics
    for m in _existing_cols(di, interval_metrics):
        plot_metric_histograms(di, m, outpath=outdir / f"hist_{m}_overall.png")

    # ---------- interval distributions (by chosen group) ----------
    by_group = interval_summary(di, interval_metrics, group_cols=[group_col])
    display(by_group.head(20))

    # interval_summary only yields a "metric" COLUMN for a real grouped result.
    # When group_col (or every metric) is missing it returns an overall-shaped or
    # empty table, where the metric names live in the index (or nowhere).
    has_metric_col = "metric" in by_group.columns
    # Keep the index in the CSV in that fallback case, otherwise the metric names
    # would be dropped from the file.
    by_group.to_csv(outdir / f"interval_summary_by_{group_col}.csv", index=not has_metric_col)

    # a readable bar chart: median inter-pickup gap by group (or swap metric)
    if has_metric_col and by_group["metric"].eq("inter_pickup_days").any():
        plot_group_median_bar(by_group, group_col=group_col, metric="inter_pickup_days",
                              outpath=outdir / f"median_inter_pickup_days_by_{group_col}.png")

    # ---------- rules table: P(new case | wip, gap) ----------
    prob, counts = pickup_rule_tables(di)
    display(prob)
    display(counts)
    prob.to_csv(outdir / "pickup_prob_matrix.csv")
    counts.to_csv(outdir / "pickup_counts_matrix.csv")

    plot_prob_heatmap(prob, outpath=outdir / "pickup_prob_heatmap.png")

    # OPTIONAL: bar chart for Low workload only (numeric coercion avoids the NAType bar error)
    if not prob.empty and "Low" in prob.index:
        rules_low = prob.loc["Low"].copy()
        rules_low = pd.to_numeric(rules_low, errors="coerce").dropna()
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.bar(rules_low.index.astype(str), rules_low.values.astype(float))
        ax.set_ylabel("P(new case today)")
        ax.set_title("Probability of new case vs gap since last pickup\n(Low workload band)")
        ax.tick_params(axis="x", rotation=45)
        fig.tight_layout()
        fig.savefig(outdir / "pickup_prob_low_bar.png", dpi=200)
        plt.show()

    return {
        "overall": overall,
        "by_group": by_group,
        "pickup_prob": prob,
        "pickup_counts": counts,
    }



# --------------------------
# Generate distributions for ALL metrics (overall + breakdowns)
# Helper: filter + tidy 
# (creates a metric column so you don’t hit KeyError: 'metric')
# --------------------------

def _filter_metrics(d: dict, metrics: list[str]) -> dict:
    """Return the entries of ``d`` whose key is listed in ``metrics``, in ``metrics`` order."""
    kept = {}
    for name in metrics:
        if name in d:
            kept[name] = d[name]
    return kept

def tidy_overall(interval_dists_overall: dict) -> pd.DataFrame:
    """
    Turn ``{metric: stats_dict}`` into one row per metric.

    The metric names end up in a regular 'metric' column rather than the index,
    followed by one column per statistic.
    """
    wide = pd.DataFrame(interval_dists_overall)
    per_metric = wide.T
    flat = per_metric.reset_index()
    return flat.rename(columns={"index": "metric"})

def tidy_by(interval_dists_by: dict, by_cols: list[str]) -> pd.DataFrame:
    """
    Flatten ``{metric: {group_key: stats_dict}}`` into a long table.

    Each (metric, group) pair becomes one row holding 'metric', the group key
    spread over ``by_cols`` (scalar keys are treated as 1-tuples), and the
    statistics from ``stats_dict``.
    """
    records = []
    for metric, groups in interval_dists_by.items():
        for gkey, stats in groups.items():
            key_parts = gkey if isinstance(gkey, tuple) else (gkey,)
            records.append({"metric": metric, **dict(zip(by_cols, key_parts)), **stats})
    return pd.DataFrame(records)


# --------------------------
# “school holiday season” label (Dec/Jan, Apr, Jul/Aug)
# --------------------------
# “multi-week” gaps in new case starts / pickups are very plausibly driven by 
# (a) people being off the system for chunks of time (term-time contracts, annual leave, sickness, training), 
# and (b) school holiday blocks (Christmas, Easter-ish, summer). 
# Test this directly with a couple of lightweight additions to your existing notebook: 
# (1) tag each pickup/gap with a “holiday season” label, 
# (2) infer likely term-time workers from repeated summer/Christmas inactivity, 
# and (3) add a 3-month rolling average to smooth month-to-month volatility.
# If long gaps really are “people not starting cases during holidays”, we would expect:
#     - higher inter_pickup_days for gaps ending in Jan (post-Christmas),
#     - higher gaps around Apr,
#     - higher gaps around late Aug or Sep (post-summer).

def school_holiday_season(d: pd.Timestamp) -> str:
    """
    Label a date with a coarse UK school-holiday season (not region-specific dates).

    - "Christmas": 15 Dec to 10 Jan inclusive
    - "Easter": any day in April (approximation)
    - "Summer": July or August
    - "Other/Term": everything else
    - "__NA__": missing input
    """
    if pd.isna(d):
        return "__NA__"

    ts = pd.Timestamp(d)
    month, day = ts.month, ts.day

    # The Christmas window straddles the year boundary, hence two halves.
    late_december = month == 12 and day >= 15
    early_january = month == 1 and day <= 10
    if late_december or early_january:
        return "Christmas"

    # Easter is approximated by the whole of April.
    if month == 4:
        return "Easter"

    if month in (7, 8):
        return "Summer"

    return "Other/Term"

def add_holiday_flags(df: pd.DataFrame, date_col: str = "date") -> pd.DataFrame:
    """Return a copy of ``df`` with holiday-season columns derived from ``date_col``.

    ``date_col`` is parsed to datetime (unparseable values become NaT). Two columns
    are added: ``holiday_season`` (label from ``school_holiday_season``) and
    ``is_school_holiday_season`` (True for Christmas, Easter or Summer).
    """
    holiday_labels = ["Christmas", "Easter", "Summer"]

    flagged = df.copy()
    parsed = pd.to_datetime(flagged[date_col], errors="coerce")
    flagged[date_col] = parsed
    flagged["holiday_season"] = flagged[date_col].apply(school_holiday_season)
    in_holiday = flagged["holiday_season"].isin(holiday_labels)
    flagged["is_school_holiday_season"] = in_holiday
    return flagged


# --------------------------
# Infer “term-time workers” (so they don’t inflate full-time medians)
# --------------------------
# This is worthwhile if term-time staff are recorded with FTE≈1 during on-strength weeks, 
# because your “Full-time” bucket will pick up their holiday gaps and inflate the median gap

# def infer_term_time_workers(
#     typed: pd.DataFrame,
#     staff_col: str = "staff_id",
#     alloc_col: str = "dt_alloc_invest",
#     fte_col: str = "fte",
#     min_cases_per_year: int = 8,
#     years_required: int = 2,
# ) -> pd.DataFrame:
#     """
#     Heuristic:
#     - Consider each staff-year where they have enough allocations (>= min_cases_per_year)
#     - If they have *zero* allocations in Jul/Aug in that year, count that as "summer off"
#     - If "summer off" happens in >= years_required distinct years, mark as term-time-like
#     """
#     df = typed[[staff_col, alloc_col, fte_col]].copy()
#     df[alloc_col] = pd.to_datetime(df[alloc_col], errors="coerce")
#     df = df.dropna(subset=[staff_col, alloc_col])

#     df["year"] = df[alloc_col].dt.year
#     df["month"] = df[alloc_col].dt.month

#     per_year_total = df.groupby([staff_col, "year"]).size().rename("n_allocs").reset_index()
#     per_year_summer = (
#         df[df["month"].isin([7, 8])]
#         .groupby([staff_col, "year"]).size()
#         .rename("n_summer_allocs")
#         .reset_index()
#     )

#     per_year = per_year_total.merge(per_year_summer, on=[staff_col, "year"], how="left")
#     per_year["n_summer_allocs"] = per_year["n_summer_allocs"].fillna(0)

#     # only judge years with enough activity
#     per_year = per_year[per_year["n_allocs"] >= min_cases_per_year].copy()
#     per_year["summer_off"] = per_year["n_summer_allocs"].eq(0)

#     staff_flag = (
#         per_year.groupby(staff_col)["summer_off"].sum()
#         .rename("years_summer_off")
#         .reset_index()
#     )
#     staff_flag["term_time_inferred"] = staff_flag["years_summer_off"] >= years_required

#     # Attach a stable staff FTE estimate (median)
#     staff_fte = typed.groupby(staff_col)[fte_col].median().rename("fte_median").reset_index()
#     out = staff_flag.merge(staff_fte, on=staff_col, how="left")

#     # final label
#     def label_row(r):
#         if bool(r.get("term_time_inferred", False)):
#             return "Term-time (inferred)"
#         if pd.isna(r.get("fte_median")):
#             return "__NA__"
#         return "Full-time" if r["fte_median"] >= 0.9 else "Part-time"

#     out["staff_work_pattern"] = out.apply(label_row, axis=1)
#     return out[[staff_col, "fte_median", "years_summer_off", "term_time_inferred", "staff_work_pattern"]]



# --------------------------
# 3-month rolling average to smooth volatility (Dec↔Jan etc.)
# --------------------------
def add_rolling_avg(series: pd.Series, window: int = 3) -> pd.Series:
    """Trailing rolling mean over ``window`` points; early points average what is available."""
    roller = series.rolling(window=window, min_periods=1)
    return roller.mean()




# --------------------------
# monthly counts from 2020 onwards
# --------------------------
# cases received (dt_received_inv)
# cases allocated (dt_alloc_invest)
# cases PG sign-off (dt_pg_signoff)
# cases closed (dt_close)


CASE_FLOW_DATE_COLS = {
    "received": "dt_received_inv",
    "allocated": "dt_alloc_invest",
    "pg_signoff": "dt_pg_signoff",
    "closed": "dt_close",
}

def monthly_case_flow_counts(
    typed: pd.DataFrame,
    start: str = "2020-01-01",
    end: str = "2025-10-31",
    date_cols: dict = CASE_FLOW_DATE_COLS,
    case_id_col: str = "case_id",
    distinct_cases: bool = True,
) -> pd.DataFrame:
    """
    Returns a wide monthly table with counts for each event type:
      month | received | allocated | pg_signoff | closed

    Parameters
    ----------
    typed : case-level frame; not modified.
    start, end : inclusive window on the event date. Both are compared as
        timestamps, so with a date-only ``end`` an event later on that same day
        (non-midnight time) falls outside the window.
    date_cols : mapping ``label -> date column``; labels whose column is missing
        from ``typed`` are skipped.
    case_id_col : column identifying a case.
    distinct_cases : count DISTINCT ``case_id_col`` per month per event (default);
        set to False to count rows/events instead.

    Returns
    -------
    One row per calendar month from the month of ``start`` up to the last month
    that has any event, with integer counts (0 where nothing happened).
    If none of the date columns exists, or no event falls inside the window,
    the result is an empty table that still carries the expected columns.
    """
    df = typed.copy()
    start_ts = pd.Timestamp(start)
    end_ts = pd.Timestamp(end)

    # Ensure datetime
    for col in date_cols.values():
        if col in df.columns:
            df[col] = pd.to_datetime(df[col], errors="coerce")

    out = None

    for label, col in date_cols.items():
        if col not in df.columns:
            continue

        sub = (
            df[[case_id_col, col]]
            .dropna(subset=[col])
            .loc[lambda x: ((x[col] >= start_ts) & (x[col] <= end_ts))]
            .assign(month=lambda x: x[col].dt.to_period("M").dt.to_timestamp())
        )

        per_month = sub.groupby("month")[case_id_col]
        s = per_month.nunique() if distinct_cases else per_month.size()

        tmp = s.rename(label).reset_index()
        out = tmp if out is None else out.merge(tmp, on="month", how="outer")

    if out is None:
        return pd.DataFrame(columns=["month"] + list(date_cols.keys()))

    out = out.sort_values("month").set_index("month")

    # Fill missing months with zeros for nicer plots. Skip this when nothing fell
    # inside the window: the last observed month would be NaT and date_range
    # cannot build a range from it.
    if not out.empty:
        full_idx = pd.date_range(
            start=start_ts.to_period("M").to_timestamp(),
            end=out.index.max(),
            freq="MS"
        )
        out = out.reindex(full_idx, fill_value=0)
        out.index.name = "month"
    out = out.reset_index()

    # Make integer counts safely (outer merges create NaNs)
    for c in out.columns:
        if c != "month":
            out[c] = pd.to_numeric(out[c], errors="coerce").fillna(0).astype(int)
    return out


def plot_monthly_case_flow(
    monthly_df: pd.DataFrame,
    rolling_months: int | None = None,
    title: str = "Monthly case flow counts (from 2020)",
    outpath=None,
):
    """
    Line plot of every column in ``monthly_df`` against its 'month' column.

    With ``rolling_months`` greater than 1 (e.g. 3) each series is replaced by its
    trailing rolling mean and the title says so. The figure is saved to
    ``outpath`` when one is given. Returns ``(fig, ax)``.
    """
    indexed = monthly_df.copy().set_index("month").sort_index()

    smoothing = bool(rolling_months and rolling_months > 1)
    if smoothing:
        series = indexed.rolling(rolling_months, min_periods=1).mean()
        heading = title + f" — {rolling_months}m rolling avg"
    else:
        series = indexed
        heading = title

    fig, ax = plt.subplots(figsize=(11, 5))
    for name in series.columns:
        ax.plot(series.index, series[name], label=name)

    ax.set_title(heading)
    ax.set_xlabel("Month")
    ax.set_ylabel("Number of cases")
    ax.legend()
    fig.autofmt_xdate()

    if outpath is not None:
        fig.savefig(outpath, dpi=150, bbox_inches="tight")

    plt.show()
    return fig, ax

def monthly_case_flow_by(
    typed: pd.DataFrame,
    group_cols: list[str] = ["case_type"],
    start="2020-01-01",
    date_cols=CASE_FLOW_DATE_COLS,
) -> pd.DataFrame:
    """
    Returns monthly counts of received / allocated / pg_signoff / closed
    grouped by case_type (or other columns in group_cols).
    """
    dfs = []
    start_ts = pd.Timestamp(start)
    for label, col in date_cols.items():
        if col not in typed.columns:
            continue
        sub = (
            typed.dropna(subset=[col])
            .assign(month=lambda x: pd.to_datetime(x[col]).dt.to_period("M").dt.to_timestamp())
            .loc[lambda x: x["month"] >= start_ts]
        )
        g = (
            sub.groupby(group_cols + ["month"])["case_id"]
            .nunique()
            .rename(label)
            .reset_index()
        )
        dfs.append(g)

    out = dfs[0]
    for d in dfs[1:]:
        out = out.merge(d, on=group_cols + ["month"], how="outer")

    out = out.fillna(0).sort_values(group_cols + ["month"])
    return out



def infer_term_time_workers(
    typed: pd.DataFrame,
    staff_col: str = "staff_id",
    alloc_date_col: str = "dt_alloc_invest",
    case_id_col: str = "case_id",
    start: str = "2020-01-01",
    # school holiday months: Jan (xmas spillover), Apr, Jul, Aug, Dec
    holiday_months=(1, 4, 7, 8, 12),
    min_months: int = 12,
    holiday_gap_threshold: float = 0.25,
) -> pd.DataFrame:
    """
    Flag staff as 'term-time-like' if they have disproportionately many zero-allocation months
    in school-holiday months compared to term months.

    Parameters
    ----------
    typed : case-level frame; not modified.
    staff_col : staff identifier column (required).
    alloc_date_col : allocation date column; if absent, dt_alloc_invest and then
        dt_alloc_team are tried as fallbacks.
    case_id_col : if present, distinct cases are counted per staff-month,
        otherwise rows are counted.
    start : allocations before this date are ignored (None keeps everything).
    holiday_months : calendar months treated as school-holiday months.
    min_months : minimum number of observed months required for the flag.
    holiday_gap_threshold : minimum (holiday zero-rate - term zero-rate) for the flag.

    Returns
    -------
    One row per staff member with: term_time_like, term_time_band, n_months,
    active_months, holiday_zero_rate, term_zero_rate, holiday_minus_term.
    A rate is NaN when the observed span contains no month of that kind; such
    staff are never flagged.

    Raises
    ------
    KeyError if ``staff_col`` or every candidate allocation date column is missing.

    Notes
    -----
    Every staff member is evaluated over the same month range (first to last
    allocation month in the data), so ``n_months`` is identical for all rows and
    months before someone's first / after their last allocation count as zero
    months. NOTE(review): confirm this is intended for staff who joined or left
    part-way through the period.
    """
    result_cols = [
        staff_col,
        "term_time_like",
        "term_time_band",
        "n_months",
        "active_months",
        "holiday_zero_rate",
        "term_zero_rate",
        "holiday_minus_term",
    ]

    df = typed.copy()

    if staff_col not in df.columns:
        raise KeyError(f"'{staff_col}' not in typed.columns")

    # Choose allocation date column (fallbacks if you use different naming)
    if alloc_date_col not in df.columns:
        for c in ["dt_alloc_invest", "dt_alloc_team"]:
            if c in df.columns:
                alloc_date_col = c
                break
        else:
            raise KeyError(
                f"'{alloc_date_col}' not found. Expected something like dt_alloc_invest / dt_alloc_team."
            )

    df[alloc_date_col] = pd.to_datetime(df[alloc_date_col], errors="coerce")
    df = df[df[alloc_date_col].notna()]

    if start is not None:
        df = df[df[alloc_date_col] >= pd.Timestamp(start)]

    # Copy once after filtering so the column assignment below targets an
    # independent frame rather than a slice.
    df = df.copy()

    # Monthly bucket
    df["month"] = df[alloc_date_col].dt.to_period("M").dt.to_timestamp()

    # Monthly allocations per staff (distinct cases if case_id exists)
    per_staff_month = df.groupby([staff_col, "month"])
    if case_id_col in df.columns:
        monthly = per_staff_month[case_id_col].nunique().rename("alloc_cases").reset_index()
    else:
        monthly = per_staff_month.size().rename("alloc_cases").reset_index()

    if monthly.empty:
        # No data -> return empty frame with expected columns
        return pd.DataFrame(columns=result_cols)

    # Build complete staff x month grid so missing months become 0
    all_staff = monthly[staff_col].dropna().unique()
    month_index = pd.date_range(monthly["month"].min(), monthly["month"].max(), freq="MS")

    grid = (
        pd.MultiIndex.from_product([all_staff, month_index], names=[staff_col, "month"])
        .to_frame(index=False)
    )

    monthly = grid.merge(monthly, on=[staff_col, "month"], how="left")
    monthly["alloc_cases"] = monthly["alloc_cases"].fillna(0)

    monthly["is_holiday_month"] = monthly["month"].dt.month.isin(list(holiday_months))
    monthly["is_zero"] = monthly["alloc_cases"].eq(0)

    # Rates: zero-month fraction in holiday vs term months
    rates = (
        monthly.groupby([staff_col, "is_holiday_month"])["is_zero"]
        .mean()
        .unstack()
        .rename(columns={False: "term_zero_rate", True: "holiday_zero_rate"})
    )
    # A short span can contain only term months (or only holiday months); the
    # unstack then yields a single column. Add the missing one as NaN so the
    # difference below is defined instead of raising a KeyError.
    for rate_col in ("term_zero_rate", "holiday_zero_rate"):
        if rate_col not in rates.columns:
            rates[rate_col] = np.nan

    # Meta: months observed and active months
    meta = monthly.groupby(staff_col).agg(
        n_months=("month", "nunique"),
        active_months=("alloc_cases", lambda s: (s > 0).sum()),
    )

    out = meta.join(rates).reset_index()

    out["holiday_minus_term"] = out["holiday_zero_rate"] - out["term_zero_rate"]
    # Comparisons against NaN are False, so staff with an undefined rate are not flagged.
    out["term_time_like"] = (
        (out["n_months"] >= min_months) &
        (out["holiday_minus_term"] >= holiday_gap_threshold)
    )
    out["term_time_band"] = np.where(out["term_time_like"], "Term-time-like", "Other")

    return out[result_cols]



def _group_key_to_label(key, group_cols):
    """Build a legend label such as ``"col_a=x | col_b=y"`` from a group key.

    A scalar key (single grouping column) is treated as a 1-tuple.
    """
    values = key if isinstance(key, tuple) else (key,)
    return " | ".join(f"{col}={val}" for col, val in zip(group_cols, values))

def _top_groups_mask(df, group_cols, value_col, top_n):
    """Row mask that keeps only the ``top_n`` groups by total ``value_col``.

    Helps keep plots readable when a grouping column (e.g. concern_type) has
    many levels.

    Parameters
    ----------
    df : long table holding ``group_cols`` and ``value_col``.
    group_cols : list of grouping columns (one or several).
    value_col : column summed per group to rank the groups.
    top_n : number of groups to keep; None keeps every row.

    Returns
    -------
    Boolean Series aligned to ``df.index``. Groups whose key contains NaN are
    ranked like any other group.
    """
    if top_n is None:
        return pd.Series(True, index=df.index)

    # Integer id per group. Comparing ids instead of the raw key values keeps NaN
    # keys matchable: NaN != NaN, so a tuple/set membership test on the keys
    # themselves would always drop such a group.
    group_id = df.groupby(group_cols, dropna=False).ngroup()

    totals = (
        df[value_col]
        .groupby(group_id.to_numpy())
        .sum()
        .sort_values(ascending=False)
    )
    return group_id.isin(totals.head(top_n).index)

def plot_monthly_flow_lines(
    monthly_df: pd.DataFrame,
    group_cols: list,
    value_col: str,
    outpath,
    title: str,
    step_date="2022-10-01",
    rolling=None,          # e.g. 3 for 3-month rolling average
    top_n=12,              # keep plots readable for high-cardinality columns
):
    """Plot one line per group of ``value_col`` over 'month' and save it to ``outpath``.

    The parent directory of ``outpath`` is created first. Only the ``top_n``
    groups by total volume are drawn, optionally smoothed with a per-group
    rolling mean, and a dashed red vertical line marks ``step_date``.
    Raises KeyError if ``monthly_df`` has no 'month' column; prints a skip
    message and returns None if ``value_col`` is missing.
    """
    target = Path(outpath)
    target.parent.mkdir(parents=True, exist_ok=True)

    available = monthly_df.columns
    if "month" not in available:
        raise KeyError("monthly_df must include a 'month' column")
    if value_col not in available:
        print(f"[skip plot] '{value_col}' not in monthly_df")
        return

    data = monthly_df.copy()
    data["month"] = pd.to_datetime(data["month"])

    # Restrict to the largest groups (no-op when top_n is None).
    keep_rows = _top_groups_mask(data, group_cols, value_col=value_col, top_n=top_n)
    data = data.loc[keep_rows].copy()

    smooth = rolling is not None and rolling > 1

    fig, ax = plt.subplots(figsize=(14, 6))
    for key, group in data.groupby(group_cols, dropna=False):
        ordered = group.sort_values("month")
        series = ordered[value_col]
        if smooth:
            # rolling average computed within the group only
            series = series.rolling(window=rolling, min_periods=1).mean()
        ax.plot(ordered["month"], series, label=_group_key_to_label(key, group_cols), alpha=0.65)

    ax.axvline(pd.Timestamp(step_date), color="red", ls="--", lw=2, label=f"Step change {step_date}")
    ax.set_title(title)
    ax.set_ylabel("Number of cases")

    # Legend sits outside the axes on the right; keep the handle for savefig.
    legend = ax.legend(
        loc="center left",
        bbox_to_anchor=(1.02, 0.5),
        frameon=False,
        fontsize=8,
    )
    # Reserve room on the right-hand side for that legend.
    fig.subplots_adjust(right=0.78)
    fig.tight_layout(rect=[0, 0, 0.78, 1])

    # Pass the legend explicitly so the tight bounding box includes it.
    fig.savefig(
        target,
        dpi=150,
        bbox_inches="tight",
        bbox_extra_artists=(legend,),
        pad_inches=0.2,
    )
    plt.show()

def monthly_flow_outputs_by_breakdown(
    typed: pd.DataFrame,
    group_cols: list,
    start="2020-01-01",
    outcsv_dir=None,
    outplot_dir=None,
    name=None,
    plot_cols=("received", "pg_signoff"),
    step_date="2022-10-01",
    rolling=None,
    top_n=12,
):
    """Build the monthly flow table for one breakdown and optionally save CSV and plots.

    Returns the table from ``monthly_case_flow_by``. If any grouping column is
    missing from ``typed`` a skip message is printed and None is returned.
    Files are named after ``name`` (or the joined ``group_cols`` when no name is
    given); a CSV is written only when ``outcsv_dir`` is set, and one line plot
    per entry of ``plot_cols`` only when ``outplot_dir`` is set.
    """
    # defensive: skip if columns missing
    absent = [col for col in group_cols if col not in typed.columns]
    if absent:
        print(f"[skip] {name or group_cols}: missing columns {absent}")
        return None

    table = monthly_case_flow_by(typed, group_cols, start=start)
    label = name or "_".join(group_cols)

    if outcsv_dir is not None:
        csv_dir = Path(outcsv_dir)
        csv_dir.mkdir(parents=True, exist_ok=True)
        table.to_csv(csv_dir / f"monthly_flow_by_{label}.csv", index=False)

    if outplot_dir is None:
        return table

    plot_dir = Path(outplot_dir)
    plot_dir.mkdir(parents=True, exist_ok=True)
    for value in plot_cols:
        plot_monthly_flow_lines(
            monthly_df=table,
            group_cols=group_cols,
            value_col=value,
            outpath=plot_dir / f"monthly_{value}_by_{label}.png",
            title=f"Monthly {value} by {label}",
            step_date=step_date,
            rolling=rolling,
            top_n=top_n,
        )

    return table

def run_monthly_flow_breakdowns(
    typed: pd.DataFrame,
    breakdowns: dict,
    start="2020-01-01",
    outcsv_dir=None,
    outplot_dir=None,
    plot_cols=("received", "pg_signoff"),
    step_date="2022-10-01",
    rolling=None,
    top_n=12,
):
    """Run ``monthly_flow_outputs_by_breakdown`` for every entry of ``breakdowns``.

    ``breakdowns`` maps a name to its list of grouping columns. Returns
    ``{name: monthly table}``, leaving out breakdowns that were skipped
    (i.e. for which the helper returned None).
    """
    # Settings that are identical for every breakdown.
    shared = dict(
        typed=typed,
        start=start,
        outcsv_dir=outcsv_dir,
        outplot_dir=outplot_dir,
        plot_cols=plot_cols,
        step_date=step_date,
        rolling=rolling,
        top_n=top_n,
    )

    results = {}
    for label, columns in breakdowns.items():
        table = monthly_flow_outputs_by_breakdown(group_cols=columns, name=label, **shared)
        if table is None:
            continue
        results[label] = table
    return results


In [None]:
# One collective end-to-end demo


import pandas as pd

from preprocessing import load_raw, engineer
from time_series import build_backlog_series, build_daily_panel
from interval_analysis import IntervalAnalysis, plot_pg_signoff_monthly_trends, plot_allocation_monthly_trends
from eda_opg import EDAConfig, OPGInvestigationEDA
from distributions import interval_change_distribution



def demo_all():
    from pathlib import Path

    DEFAULT_INTERVAL_METRICS = [
        # “new case start” gap
        "inter_pickup_days",
        # alloc → PG sign-off (already engineered in preprocessing.py)
        "days_to_signoff",
        # received → PG sign-off
        "days_to_pg_signoff",
        # alloc → close
        "days_alloc_to_close",
        # received/alloc → legal review request
        "days_recieved_to_legal_review",
        "days_alloc_to_req_legal_review",
        # received → alloc
        "days_to_alloc",
    ]
    
    def _describe_days(s: pd.Series) -> dict:
        s = pd.to_numeric(s, errors="coerce").dropna()
        if s.empty:
            return {"count": 0, "mean": np.nan, "std": np.nan, "min": np.nan,
                    "p10": np.nan, "p25": np.nan, "p50": np.nan, "p75": np.nan, "p90": np.nan, "max": np.nan}
        return {
            "count": int(s.shape[0]),
            "mean": float(s.mean()),
            "std": float(s.std(ddof=1)),
            "min": float(s.min()),
            "p10": float(s.quantile(0.10)),
            "p25": float(s.quantile(0.25)),
            "p50": float(s.quantile(0.50)),
            "p75": float(s.quantile(0.75)),
            "p90": float(s.quantile(0.90)),
            "max": float(s.max()),
        }
    
    def interval_distributions_by(di: pd.DataFrame, by: list[str], metrics: list[str] | None = None) -> dict[str, pd.DataFrame]:
        metrics = metrics or [m for m in DEFAULT_INTERVAL_METRICS if m in di.columns]
        out: dict[str, pd.DataFrame] = {}
        for metric in metrics:
            tbl = (
                di.groupby(by, dropna=False)[metric]
                  .apply(_describe_days)
                  .apply(pd.Series)
                  .reset_index()
            )
            out[metric] = tbl
        return out
    
    def save_and_plot_interval_breakdowns(
        di: pd.DataFrame,
        by: list[str],
        outdir: Path,
        label: str | None = None,
        metrics: list[str] | None = None,
        max_categories: int = 25,
    ) -> dict[str, pd.DataFrame]:
        """Write per-metric summary CSVs and mean-value bar charts for a breakdown.

        For every table returned by ``interval_distributions_by(di, by, metrics)``:
        the table is saved as ``interval_dists_by_<label>__<metric>.csv`` and a
        bar chart of the ``mean`` column is saved as
        ``bar_mean_<metric>_by_<label>.png``, both inside ``outdir`` (created
        if missing).

        ``label`` defaults to the ``by`` columns joined with "_". When ``by``
        has exactly one column and more than ``max_categories`` rows remain,
        only the ``max_categories`` rows with the highest ``count`` are plotted.

        Returns the dict of tables unchanged.
        """
        if not label:
            label = "_".join(by)
        target_dir = Path(outdir)
        target_dir.mkdir(parents=True, exist_ok=True)

        single_group_col = len(by) == 1
        tables = interval_distributions_by(di, by=by, metrics=metrics)

        for metric_name, table in tables.items():
            # Tidy table to disk first, untouched by the plotting filters below.
            csv_path = target_dir / f"interval_dists_by_{label}__{metric_name}.csv"
            table.to_csv(csv_path, index=False)

            # Rows without a usable mean cannot be drawn as bars.
            chart_data = table.copy()
            chart_data["mean"] = pd.to_numeric(chart_data["mean"], errors="coerce")
            chart_data = chart_data.dropna(subset=["mean"])

            if single_group_col:
                if chart_data.shape[0] > max_categories:
                    by_volume = chart_data.sort_values("count", ascending=False)
                    chart_data = by_volume.head(max_categories)
                tick_labels = chart_data[by[0]].astype(str)
            else:
                # Multi-column breakdown: one combined label per row.
                tick_labels = chart_data[by].astype(str).agg(" | ".join, axis=1)

            fig, ax = plt.subplots(figsize=(10, 4))
            ax.bar(tick_labels, chart_data["mean"])
            ax.set_title(f"Mean {metric_name} by {label}")
            ax.set_ylabel(f"Mean {metric_name} (days)")
            ax.tick_params(axis="x", rotation=45, labelsize=8)
            plt.tight_layout()
            png_path = target_dir / f"bar_mean_{metric_name}_by_{label}.png"
            plt.savefig(png_path, dpi=150)
            plt.close(fig)

        return tables

        
    outdir = "data/out/plot/plots"
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)
    
    outcsv = "data/out/csv/csvs"
    outcsv = Path(outcsv)
    outcsv.mkdir(parents=True, exist_ok=True)
    
    raw, colmap = load_raw(Path("data/raw/raw.csv")) # REAL DATA
    typed = engineer(raw, colmap)
    
    # infer term-time-like pattern and merge back
    staff_pattern = infer_term_time_workers(typed)
    # keep existing logic expecting "staff_work_pattern"
    staff_pattern["staff_work_pattern"] = staff_pattern["term_time_band"]
    typed = typed.merge(staff_pattern, on="staff_id", how="left")
    typed = typed.copy()


        
    # --------------------------
    # Monthly case flow counts (from 2020 onwards)
    # --------------------------
    # ---- Monthly case flow counts (from 2020 onwards) ----
    outcsv = Path("data/out/csv/csvs/monthly_flows")
    outcsv.mkdir(parents=True, exist_ok=True)
    outdir = Path("data/out/plot/plots/monthly_flows")
    outdir.mkdir(parents=True, exist_ok=True)
    
    # Distinct vs events: distinct_cases=True counts unique case_id per month per event (usually what people mean by “number of cases”). 
    # If you have multiple allocations per case and you want to count allocations rather than cases, set distinct_cases=False.
    # Missing columns: if any of the date columns don’t exist (e.g., dt_close missing), the function will just skip that series
    # and still produce the others.
    # Date column names: if your engineered dataframe uses different names, update CASE_FLOW_DATE_COLS.
    
    monthly_flows = monthly_case_flow_counts(
        typed,
        start="2020-01-01",
        end="2025-10-31",
        distinct_cases=True,   # change to False if you want to count events/rows
    )
    
    monthly_flows.to_csv(outcsv / "monthly_case_flow_counts_from_2020.csv", index=False)
    
    # One chart with all four series
    plot_monthly_case_flow(
        monthly_flows,
        rolling_months=None,
        outpath=outdir / "monthly_case_flow_counts_from_2020.png",
    )
    
    # Optional: smoother view (3-month rolling average)
    plot_monthly_case_flow(
        monthly_flows,
        rolling_months=3,
        outpath=outdir / "monthly_case_flow_counts_from_2020_roll3.png",
    )



        
    # --------------------------
    # Monthly case flow counts case-type trends (from 2020 onwards)
    # --------------------------
    FLOW_BREAKDOWNS = {
        "case_type": ["case_type"],
        "application_type": ["application_type"],
        "legal_review": ["legal_review"],
        "concern_type": ["concern_type"],
        # NOTE: this combo can create loads of lines; top_n will keep it readable
        "case_status_app_legal": ["case_type", "application_type", "legal_review", "concern_type"],
    }
    
    monthly_flow_outputs = run_monthly_flow_breakdowns(
        typed,
        breakdowns=FLOW_BREAKDOWNS,
        start="2020-01-01",
        outcsv_dir=outcsv,  # e.g. Path("data/out/csv")
        outplot_dir=outdir / "monthly_flows_by_breakdown",  # e.g. Path("data/out/plot/plots")
        plot_cols=("received", "pg_signoff"),   # add "closed"/"allocated" if your monthly fn includes them
        step_date="2022-10-01",
        rolling=3,     # set None to turn off smoothing
        top_n=12,      # increase if you want more lines on high-cardinality columns
    )

    # monthly_flow_outputs_by_breakdown(
    #     typed,
    #     group_cols=["some_other_cat_col"],
    #     start="2020-01-01",
    #     outcsv_dir=outcsv,
    #     outplot_dir=outdir / "monthly_flows_by_breakdown",
    #     name="some_other_cat_col",
    #     rolling=3,
    #     top_n=15,
    # )


    # monthly_by_case_type = monthly_case_flow_by(typed, ["case_type"], start="2020-01-01")
    # monthly_by_case_type.to_csv(outcsv / "monthly_by_case_type_flow_by_case_type.csv", index=False)

    # # recieved
    # fig, ax = plt.subplots(figsize=(10,6))
    # for ct in monthly_by_case_type["case_type"].unique():
    #     sub = monthly_by_case_type[monthly_by_case_type["case_type"]==ct]
    #     ax.plot(sub["month"], sub["received"], label=f"{ct} received", alpha=0.6)
    # ax.axvline(pd.Timestamp("2022-10-01"), color="red", ls="--", lw=2, label="Step change Oct 2022")
    # ax.set_title("Monthly investigations (received) by case type")
    # ax.set_ylabel("Number of cases")
    # ax.legend()
    # plt.tight_layout()
    # plot_monthly_by_case_type_pg_received = outdir / "monthly_by_case_type_pg_received.png"
    # plt.savefig(plot_monthly_by_case_type_pg_received, dpi=150)
    # plt.show()
    
    # # signoff
    # fig, ax = plt.subplots(figsize=(10,6))
    # for ct in monthly_by_case_type["case_type"].unique():
    #     sub = monthly_by_case_type[monthly_by_case_type["case_type"]==ct]
    #     ax.plot(sub["month"], sub["pg_signoff"], label=f"{ct} pg_signoff", alpha=0.6)
    # ax.axvline(pd.Timestamp("2022-10-01"), color="red", ls="--", lw=2, label="Step change Oct 2022")
    # ax.set_title("Monthly investigations (pg_signoff) by case type")
    # ax.set_ylabel("Number of cases")
    # ax.legend()
    # plt.tight_layout()
    # plot_monthly_by_case_type_pg_signoff = outdir / "monthly_by_case_type_pg_signoff.png"
    # plt.savefig(plot_monthly_by_case_type_pg_signoff, dpi=150)
    # plt.show()
    
    # --------------------------
    # Link workforce metrics (FTE) - demand vs staff
    # --------------------------
    # fte_monthly = (
    #     typed.groupby(typed["dt_alloc_invest"].dt.to_period("M"))["fte"]
    #     .mean()
    #     .rename("avg_fte")
    #     .reset_index()
    # )
    # fte_monthly["month"] = fte_monthly["dt_alloc_invest"].dt.to_timestamp()
    # plt.figure(figsize=(10,5))
    # plt.plot(monthly_flows["month"], monthly_flows["received"], label="New investigations")
    # plt.plot(fte_monthly["month"], fte_monthly["avg_fte"]*10, label="Avg FTE × 10 (scaled)")
    # plt.axvline(pd.Timestamp("2022-10-01"), color="red", ls="--")
    # plt.legend(); plt.title("New investigations vs staff FTE")
    # plt.show()
    # plot_fte_monthly = outdir / "fte_monthly.png"
    # plt.savefig(plot_fte_monthly, dpi=150)

    # avg_fte line is “flat” because you’re taking the mean FTE across allocated cases each month 
    # — if most staff have stable FTE (often 1.0), the monthly mean won’t move much.
    # What we want instead is a monthly headcount of investigators available/active, 
    # split into full-time vs part-time, and optionally a monthly total FTE capacity (sum of FTEs of active staff).


    # --- configure ---
    FTE_FULL_TIME_CUTOFF = 0.8              # change threshold if needed
    STAFF_MONTH_DATE_COL = "dt_alloc_invest"  # "activity" definition (see note below)
    # Note: "available/active" means the staff member had at least one allocation
    # (dt_alloc_invest) in that month. To make "active" mean that any work happened,
    # switch STAFF_MONTH_DATE_COL to another date column (e.g. dt_pg_signoff, dt_close)
    # or build activity from multiple event dates.
    
    # 1) Estimate each staff member's typical FTE (use median to be robust)
    staff_fte = (
        typed.dropna(subset=["staff_id"])
            .groupby("staff_id")["fte"]
            .median()
            .rename("staff_fte")
            .reset_index()
    )
    
    # 2) Classify staff as Full-time / Part-time (and keep Unknown if missing)
    staff_fte["fte_band"] = np.where(
        staff_fte["staff_fte"].ge(FTE_FULL_TIME_CUTOFF), "Full-time",
        np.where(staff_fte["staff_fte"].notna(), "Part-time", "Unknown")
    )
    
    # 3) Attach staff_fte + band back onto typed
    typed_staff = typed.merge(staff_fte, on="staff_id", how="left")
    typed_staff = typed_staff[typed_staff["dt_alloc_invest"] >= pd.Timestamp("2020-01-01")]
    
    # 4) Build staff "active this month" based on STAFF_MONTH_DATE_COL
    staff_active = (
        typed_staff.dropna(subset=[STAFF_MONTH_DATE_COL, "staff_id"])
            .assign(month=lambda d: d[STAFF_MONTH_DATE_COL].dt.to_period("M").dt.to_timestamp())
            [["month", "staff_id", "staff_fte", "fte_band"]]
            .drop_duplicates(["month", "staff_id"])   # key: count each staff once per month
    )
    
    # 5) Monthly headcount (unique staff) by band + total FTE capacity by band
    staff_monthly = (
        staff_active.groupby(["month", "fte_band"], as_index=False)
            .agg(
                n_staff=("staff_id", "nunique"),
                sum_fte=("staff_fte", "sum"),
            )
    )
    
    # Wide forms for plotting
    staff_counts_wide = (
        staff_monthly.pivot(index="month", columns="fte_band", values="n_staff")
        .fillna(0)
        .reset_index()
    )
    staff_fte_wide = (
        staff_monthly.pivot(index="month", columns="fte_band", values="sum_fte")
        .fillna(0)
        .reset_index()
    )
    
    # Convenience totals
    for df in (staff_counts_wide, staff_fte_wide):
        cols = [c for c in df.columns if c != "month"]
        df["total"] = df[cols].sum(axis=1)

    # Ensure months are aligned (optional but helpful)
    mf = monthly_flows.copy()
    mf["month"] = pd.to_datetime(mf["month"])
    
    fig, ax1 = plt.subplots(figsize=(10, 5))
    
    # Demand side (cases)
    if "received" in mf.columns:
        ax1.plot(mf["month"], mf["received"], label="Received (new investigations)", color="cyan")
    if "pg_signoff" in mf.columns:
        ax1.plot(mf["month"], mf["pg_signoff"], label="PG sign-offs", color="pink")
    # if "closed" in mf.columns:
    #     ax1.plot(mf["month"], mf["closed"], label="Closed")
    
    ax1.set_ylabel("Cases per month")
    ax1.axvline(pd.Timestamp("2022-10-01"), color="red", ls="--")
    
    # Supply side (staff headcount)
    ax2 = ax1.twinx()
    if "Full-time" in staff_counts_wide.columns:
        ax2.plot(staff_counts_wide["month"], staff_counts_wide["Full-time"], label="Full-time investigators")
    if "Part-time" in staff_counts_wide.columns:
        ax2.plot(staff_counts_wide["month"], staff_counts_wide["Part-time"], label="Part-time investigators")
    ax2.plot(staff_counts_wide["month"], staff_counts_wide["total"], label="Total investigators", linestyle=":")
    
    ax2.set_ylabel("Investigators active (unique staff/month)")
    
    # combine legends
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper left")
    ax1.set_title("Monthly demand vs investigator availability (FT/PT headcount)")
    plt.tight_layout()
    plot_fte_monthly = outdir / "demand_vs_staff_headcount.png"
    plt.savefig(plot_fte_monthly, dpi=150)
    plt.show()
    
    # Quick check: did staff mix change around Oct 2022?
    # This gives a simple pre/post comparison including “PG sign-offs per investigator”:
    cut = pd.Timestamp("2022-10-01")
    # merge demand + staff totals
    staff_tot = staff_counts_wide[["month", "total"]].rename(columns={"total": "n_investigators"})
    merged = mf.merge(staff_tot, on="month", how="left")
    
    merged["pg_signoff_per_investigator"] = np.where(
        merged["n_investigators"] > 0,
        merged.get("pg_signoff", np.nan) / merged["n_investigators"],
        np.nan
    )
    summary = (
        merged.assign(period=np.where(merged["month"] < cut, "pre", "post"))
              .groupby("period")[["received", "pg_signoff", "n_investigators", "pg_signoff_per_investigator"]] # , "closed"
              .mean(numeric_only=True)
    )
    print(summary)

    
    # --------------------------
    # “case age” lens (from receipt → PG sign-off)
    # --------------------------
    typed["case_age_days"] = (typed["dt_pg_signoff"] - typed["dt_received_inv"]).dt.days
    typed["age_band"] = pd.cut(
        typed["case_age_days"],
        bins=[0,90,180,365,730,5000],
        labels=["<3 m","3-6 m","6-12 m","1-2 y",">2 y"]
    )
    age_monthly = (
        typed.groupby([typed["dt_received_inv"].dt.to_period("M"),"age_band"])["case_id"]
        .nunique()
        .reset_index()
        .rename(columns={"dt_received_inv":"month"})
    )
    age_monthly["month"]=age_monthly["month"].dt.to_timestamp()
    age_monthly.to_csv(outcsv / "age_monthly_case_flow_counts_from_2020.csv", index=False)
    # ---- tidy/standardise ----
    age_monthly = age_monthly.copy()
    age_monthly["month"] = pd.to_datetime(age_monthly["month"])
    
    # if your count column is still called case_id, rename it
    if "case_id" in age_monthly.columns and "n_cases" not in age_monthly.columns:
        age_monthly = age_monthly.rename(columns={"case_id": "n_cases"})
    
    # pivot to wide
    pivot = (
        age_monthly
        .pivot_table(index="month", columns="age_band", values="n_cases", aggfunc="sum")
        .fillna(0)
        .sort_index()
    )
    
    # ensure continuous monthly index
    all_months = pd.date_range(pivot.index.min(), pivot.index.max(), freq="MS")
    pivot = pivot.reindex(all_months, fill_value=0)
    pivot.index.name = "month"
    
    # ---- plot ----
    fig, ax = plt.subplots(figsize=(11, 5))
    ax.stackplot(
        pivot.index,
        [pivot[c].values for c in pivot.columns],
        labels=[str(c) for c in pivot.columns],
        alpha=0.9
    )
    ax.axvline(pd.Timestamp("2022-10-01"), color="red", ls="--", lw=2, label="Oct 2022")
    ax.set_title("Monthly case counts by Case Age Band (Received - PG signed off) ")
    ax.set_ylabel("Number of cases")
    ax.legend(loc="upper left", ncol=3)
    plt.tight_layout()
    plot_age_monthly = outdir / "age_monthly.png"
    plt.savefig(plot_age_monthly, dpi=150)
    plt.show()
    fig, ax = plt.subplots(figsize=(11, 5))
    bottom = np.zeros(len(pivot.index))
    
    for col in pivot.columns:
        ax.bar(pivot.index, pivot[col].values, bottom=bottom, label=str(col), width=25)  # ~month width
        bottom += pivot[col].values
    
    ax.axvline(pd.Timestamp("2022-10-01"), color="red", ls="--", lw=2)
    ax.set_title("Monthly case counts by Case Age Band (Received - PG signed off) stacked bars")
    ax.set_ylabel("Number of cases")
    ax.legend(loc="upper left", ncol=3)
    plt.tight_layout()
    plot_age_monthly_stacked_car_chart = outdir / "age_monthly_stacked_car_chart.png"
    plt.savefig(plot_age_monthly_stacked_car_chart, dpi=150)
    plt.show()
    total = pivot.sum(axis=1)
    roll3 = total.rolling(3, min_periods=1).mean()
    
    plt.figure(figsize=(11, 4))
    plt.plot(total.index, total.values, label="Monthly total")
    plt.plot(roll3.index, roll3.values, label="3-month rolling avg")
    plt.axvline(pd.Timestamp("2022-10-01"), color="red", ls="--", lw=2)
    plt.title("Monthly total (All Case Age Band (Received - PG signed off)) with 3-month rolling average")
    plt.ylabel("Number of cases")
    plt.legend()
    plt.tight_layout()
    plot_age_monthly_stacked_car_chart_3m = outdir / "age_monthly_stacked_car_chart_3m.png"
    plt.savefig(plot_age_monthly_stacked_car_chart_3m, dpi=150)
    plt.show()
    
    pivot_roll3 = pivot.rolling(3, min_periods=1).mean()
    fig, ax = plt.subplots(figsize=(11, 5))
    ax.stackplot(
        pivot_roll3.index,
        [pivot_roll3[c].values for c in pivot_roll3.columns],
        labels=[str(c) for c in pivot_roll3.columns],
        alpha=0.9
    )
    ax.axvline(pd.Timestamp("2022-10-01"), color="red", ls="--", lw=2)
    ax.set_title("Case Age Band (Received - PG signed off) volumes (3-month rolling average)")
    ax.set_ylabel("Avg cases / month (3-mo)")
    ax.legend(loc="upper left", ncol=3)
    plt.tight_layout()
    plot_age_monthly_stacked_car_chart_3m_ageband = outdir / "age_monthly_stacked_car_chart_3m_ageband.png"
    plt.savefig(plot_age_monthly_stacked_car_chart_3m_ageband, dpi=150)
    plt.show()


    # make sure key date columns are datetime (skip if not present)
    for c in ["dt_received_inv", "dt_alloc_invest", "dt_pg_signoff", "dt_close", "dt_legal_review_req1"]:
        if c in typed.columns:
            typed[c] = pd.to_datetime(typed[c], errors="coerce")
    
    # alloc → PG sign-off
    if "days_to_signoff" not in typed.columns and {"dt_pg_signoff", "dt_alloc_invest"} <= set(typed.columns):
        typed["days_to_signoff"] = (typed["dt_pg_signoff"] - typed["dt_alloc_invest"]).dt.days
    
    # alloc → close
    if "days_alloc_to_close" not in typed.columns and {"dt_close", "dt_alloc_invest"} <= set(typed.columns):
        typed["days_alloc_to_close"] = (typed["dt_close"] - typed["dt_alloc_invest"]).dt.days
    
    # received → PG sign-off
    if "days_to_pg_signoff" not in typed.columns and {"dt_pg_signoff", "dt_received_inv"} <= set(typed.columns):
        typed["days_to_pg_signoff"] = (typed["dt_pg_signoff"] - typed["dt_received_inv"]).dt.days
    
    # received → alloc
    if "days_to_alloc" not in typed.columns and {"dt_alloc_invest", "dt_received_inv"} <= set(typed.columns):
        typed["days_to_alloc"] = (typed["dt_alloc_invest"] - typed["dt_received_inv"]).dt.days
    
    # received → legal review request
    if "days_recieved_to_legal_review" not in typed.columns and {"dt_legal_review_req1", "dt_received_inv"} <= set(typed.columns):
        typed["days_recieved_to_legal_review"] = (typed["dt_legal_review_req1"] - typed["dt_received_inv"]).dt.days
    
    # alloc → legal review request
    if "days_alloc_to_req_legal_review" not in typed.columns and {"dt_legal_review_req1", "dt_alloc_invest"} <= set(typed.columns):
        typed["days_alloc_to_req_legal_review"] = (typed["dt_legal_review_req1"] - typed["dt_alloc_invest"]).dt.days

    
    if "inter_pickup_days" not in typed.columns and "time_since_last_pickup" in typed.columns:
        typed["inter_pickup_days"] = typed["time_since_last_pickup"]
        
    typed['days_to_signoff'] = typed["days_to_pg_signoff"]
    typed["legal_review"] = pd.to_numeric(typed["legal_review"], errors="coerce").fillna(0).astype("int8")
    # or: typed["legal_review"] = typed["legal_review"].astype("int8")

    typed.to_csv(outcsv / "typed.csv", index=False)
    
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    typed['days_to_alloc'].hist(ax=axes[0], bins=50)
    axes[0].set_title('Days to allocation (all cases)')
    axes[0].set_xlabel('Days')
    axes[0].set_ylabel('Number of cases')
    typed['days_to_signoff'].hist(ax=axes[1], bins=50)
    axes[1].set_title('Days to sign-off (all cases)')
    axes[1].set_xlabel('Days')
    plt.tight_layout()
    plot_outliers = outdir / "outliers.png"
    plt.savefig(plot_outliers, bbox_inches="tight", dpi=150)

    # Time-series panels
    backlog = build_backlog_series(typed)
    if "backlog_available" in backlog.columns and "backlog" not in backlog.columns:
        backlog = backlog.rename(columns={"backlog_available": "backlog"})
    daily, backlog_ts, events = build_daily_panel(typed)
    daily.to_csv(outcsv / "daily.csv", index=False)
    backlog_ts.to_csv(outcsv / "backlog_ts.csv", index=False)
    events.to_csv(outcsv / "events.csv", index=False)
    
    # Interval frame for analysis (includes wip_load, time_since_last_pickup, event_newcase, etc.)
    di = IntervalAnalysis.build_interval_frame(
        typed, 
        backlog_series=backlog_ts#,
        #bank_holidays=True
    )
    
    # add holiday flags on di "date"
    #di = add_holiday_flags(di, date_col="date")
    
    # bring staff pattern onto di (if staff_id is present in di; otherwise merge via case_id from typed)
    if "staff_id" in di.columns:
        di = di.merge(staff_pattern[["staff_id", "staff_work_pattern"]], on="staff_id", how="left")
    else:
        di = di.merge(typed[["case_id", "staff_id"]], on="case_id", how="left")
        di = di.merge(staff_pattern[["staff_id", "staff_work_pattern"]], on="staff_id", how="left")

        
    # Ensure `di` contains case-level fields needed for breakdowns (e.g. application_type).
    # `IntervalAnalysis.build_interval_frame()` keeps a fixed set of columns, so we merge extra case-level fields here.
    if "application_type" in typed.columns and "application_type" not in di.columns:
        _app = typed[["case_id", "application_type"]].drop_duplicates("case_id").copy()
        # Robust join even if case_id is stored as int in one frame and string in the other
        _app["case_id"] = _app["case_id"].astype("string")
        di["case_id"] = di["case_id"].astype("string")
        di = di.merge(_app, on="case_id", how="left")


    # All the “interval” analysis we want should be plugged in after di is created. 
    # That guarantees we are working on real data, not synthetic.
    di.to_csv(outcsv / "di.csv", index=False)

    
    # --------------------------
    # Direct test: do multi-week gaps spike in those holiday seasons?
    # --------------------------
    alloc = (
        typed[["staff_id", "dt_alloc_invest", "staff_work_pattern"]]
        .copy()
    )
    alloc["dt_alloc_invest"] = pd.to_datetime(alloc["dt_alloc_invest"], errors="coerce")
    alloc = alloc.dropna(subset=["staff_id", "dt_alloc_invest"]).sort_values(["staff_id", "dt_alloc_invest"])
    
    alloc["prev_alloc"] = alloc.groupby("staff_id")["dt_alloc_invest"].shift(1)
    alloc["inter_pickup_days"] = (alloc["dt_alloc_invest"] - alloc["prev_alloc"]).dt.days
    alloc = alloc.dropna(subset=["inter_pickup_days"])
    
    alloc["holiday_season"] = alloc["dt_alloc_invest"].apply(school_holiday_season)
    
    gap_summary = (
        alloc.groupby(["staff_work_pattern", "holiday_season"])["inter_pickup_days"]
        .agg(n="count", median="median", p75=lambda s: s.quantile(0.75), p90=lambda s: s.quantile(0.90), mean="mean")
        .reset_index()
        .sort_values(["staff_work_pattern", "holiday_season"])
    )
    display(gap_summary)

    outplot = Path("data/out/plot/plots/Inter-pickup_gap_by_holiday_season")
    outplot.mkdir(parents=True, exist_ok=True)

    plot_df = alloc.copy()
    plot_df = plot_df[plot_df["inter_pickup_days"].between(0, plot_df["inter_pickup_days"].quantile(0.99))]
    
    fig, ax = plt.subplots(figsize=(8, 4))
    plot_df.boxplot(column="inter_pickup_days", by="holiday_season", ax=ax)
    ax.set_title("Inter-pickup gap by holiday season (outliers clipped at p99)")
    ax.set_xlabel("Holiday season")
    ax.set_ylabel("Days between case pickups")
    plt.suptitle("")
    fig.savefig(outplot / "boxplot_pickup_gap_by_holiday_season.png", dpi=150)
    plt.show()
    plt.close(fig)


    
    # --------------------------
    # Overall distributions (all interval metrics)
    # --------------------------
    
    INTERVAL_METRICS = [
        # “new case start” gap
        "inter_pickup_days",
        # alloc → PG sign-off
        "days_to_signoff",
        # received → PG sign-off
        "days_to_pg_signoff",
        # alloc → close
        "days_alloc_to_close",
        # received/alloc → legal review request
        "days_recieved_to_legal_review",
        "days_alloc_to_req_legal_review",
        # received → alloc
        "days_to_alloc",
    ]

    # Set any negative durations to NA so they don't affect mins/means/plots
    for c in INTERVAL_METRICS:
        if c in di.columns:
            di.loc[di[c] < 0, c] = pd.NA
    
    # --- overall ---
    all_overall = IntervalAnalysis.analyse_interval_distributions(di)  # uses interval_columns_available internally
    interval_dists_overall = _filter_metrics(all_overall, INTERVAL_METRICS)
    
    overall_df = tidy_overall(interval_dists_overall).sort_values("metric")
    print("\n=== Interval distributions (overall, last 4 years) ===")
    display(overall_df)


    # Breakdowns (case_type, risk, application_type, legal_review, and combo)
    # “frequency distribution summary” table (count/mean/std/percentiles/etc) 
    # for each metric, broken down by case type, risk, application type, legal review, and all combined.
    
    BREAKDOWNS = {
        "case_type": ["case_type"],
        #"risk_band": ["risk_band"],
        "application_type": ["application_type"],
        "legal_review": ["legal_review"],
        "concern_type": ["concern_type"],
        # full breakdown requested:
        "case_status_app_legal": ["case_type", "application_type", "legal_review", "concern_type"], #"risk_band", 
    }
    
    breakdown_tables = {}
    
    for name, cols in BREAKDOWNS.items():
        all_by = IntervalAnalysis.analyse_interval_distributions(di, by=cols)
        d = _filter_metrics(all_by, INTERVAL_METRICS)
        df = tidy_by(d, cols)
        breakdown_tables[name] = df
    
        print(f"\n=== Interval distributions by {name} ===")
        display(df.head(20))


    # --------------------------
    # Add basic charts for each metric (hist + box + trend)
    # --------------------------

    # Histograms (overall)
    outplot = Path("data/out/plot/plots/interval_metrics")
    outplot.mkdir(parents=True, exist_ok=True)
    
    # Get the metric series the SAME way IntervalAnalysis does (handles derived metrics)
    metric_series = IntervalAnalysis.interval_columns_available(di)
    
    for metric in INTERVAL_METRICS:
        s = metric_series.get(metric)
        if s is None:
            continue
    
        x = pd.to_numeric(s, errors="coerce").dropna()
        if x.empty:
            continue
    
        fig, ax = plt.subplots(figsize=(7, 4))
        ax.hist(x, bins=50)
        ax.set_title(f"Distribution: {metric}")
        ax.set_xlabel("Days")
        ax.set_ylabel("Count")
        fig.tight_layout()
        fig.savefig(outplot / f"hist_{metric}.png", dpi=150)
        plt.close(fig)


    # Boxplots by case type (repeatable for other group cols)
    GROUP_COL = "case_type"  # change to "risk_band" or "application_type" etc.
    
    for metric in INTERVAL_METRICS:
        s = metric_series.get(metric)
        if s is None or GROUP_COL not in di.columns:
            continue
    
        tmp = di[[GROUP_COL]].copy()
        tmp[metric] = pd.to_numeric(s, errors="coerce")
        tmp = tmp.dropna(subset=[GROUP_COL, metric])
    
        if tmp.empty:
            continue
    
        fig, ax = plt.subplots(figsize=(10, 4))
        tmp.boxplot(column=metric, by=GROUP_COL, ax=ax, rot=45)
        ax.set_title(f"{metric} by {GROUP_COL}")
        ax.set_xlabel(GROUP_COL)
        ax.set_ylabel("Days")
        plt.suptitle("")  # removes pandas default subtitle
        fig.tight_layout()
        fig.savefig(outplot / f"box_{metric}_by_{GROUP_COL}.png", dpi=150)
        plt.close(fig)


    
    # materialises any derived metrics into di
    _available = IntervalAnalysis.interval_columns_available(di)
    for _m in INTERVAL_METRICS:
        if _m not in di.columns and _m in _available:
            di[_m] = _available[_m]


    

    # Monthly trends (median by case type) for ALL metrics
    for metric in INTERVAL_METRICS:
        if metric not in di.columns:
            continue
        trend = IntervalAnalysis.monthly_trend(di, metric=metric, by=["case_type"])
        # print("\n=== Trend ===")
        # display(f"trend \n: {trend}")

    # for metric in INTERVAL_METRICS:
    #     trend = IntervalAnalysis.monthly_trend(di, metric=metric, by=["case_type"])
    #     if trend is None or len(trend) == 0:
    #         continue
    
        # The columns returned by monthly_trend (e.g. 'case_type', 'yyyymm', ...) vary by
        # implementation; the pivot below needs a column named after the metric, so skip
        # this metric when that column is absent.
        if metric not in trend.columns:
            continue


        # Ensure we have a real datetime month column for pivoting/plotting
        if "month" not in trend.columns:
            if "yyyymm" in trend.columns:
                trend = trend.copy()
                trend["month"] = pd.to_datetime(trend["yyyymm"].astype(str) + "-01")
            elif "date" in trend.columns:
                trend = trend.copy()
                trend["month"] = pd.to_datetime(trend["date"]).dt.to_period("M").dt.to_timestamp()
        
        pivot = (
            trend.pivot_table(index="month", columns="case_type", values=metric, aggfunc="median")
            .sort_index()
        )

        pivot = trend.pivot_table(index="month", columns="case_type", values=metric, aggfunc="median")
        # pivot = trend.pivot_table(index="yyyymm", columns="case_type", values=metric, aggfunc="median").sort_index()

        if pivot.empty:
            continue
    
        fig, ax = plt.subplots(figsize=(10, 4))
        pivot.plot(ax=ax)
        ax.set_title(f"Monthly median {metric} by case_type")
        ax.set_xlabel("Month")
        ax.set_ylabel("Days")
        fig.tight_layout()
        fig.savefig(outplot / f"trend_{metric}_by_case_type.png", dpi=150)
        plt.close(fig)


        # Rolling - overall monthly median of any metric
        tmp = di[["date", metric]].copy()
        tmp["date"] = pd.to_datetime(tmp["date"], errors="coerce")
        tmp = tmp.dropna(subset=["date", metric])
        tmp["month"] = tmp["date"].dt.to_period("M").dt.to_timestamp()
        
        monthly = tmp.groupby("month")[metric].median().sort_index()
        rolling3 = add_rolling_avg(monthly, window=3)
        
        fig, ax = plt.subplots(figsize=(8, 4))
        ax.plot(monthly.index, monthly.values, label="Monthly median")
        ax.plot(rolling3.index, rolling3.values, label="3-mo rolling avg")
        ax.set_title(f"{metric} over time (smoothed)")
        ax.set_xlabel("Month")
        ax.set_ylabel(f"Median {metric}")
        ax.legend()
        fig.tight_layout()
        plt.show()
        fig.savefig(outplot / f"rolling average_median_{metric}_by_case_type.png", dpi=150)
        plt.close(fig)


    # ------------------------------------------------------
    # Distributions of key time intervals:
    # How long from allocation to close vs to PG signoff
    # How long gaps between pickups vs case lifecycle durations
    # Distribution shape (boxplots)
    # ------------------------------------------------------
    interval_dists_overall = IntervalAnalysis.analyse_interval_distributions(di)
    
    # Remove unnecessary columns/metrics
    #interval_dists_overall.pop("days_alloc_to_req_legal_review", None)

    
    # Optional: print a quick summary so you can see something immediately
    print("\n=== Interval distributions (overall, last 4 years) ===")
    interval_dists_overall_df = pd.DataFrame(interval_dists_overall).T
    interval_dists_overall_df.to_csv(outcsv / "interval_dists_overall_last4yrs.csv")
    
    overall_df = (
        pd.DataFrame(interval_dists_overall)
        .T                      # metrics become rows
        .reset_index()
        .rename(columns={"index": "metric", "p50": "median"})
        [["metric", "count", "mean", "std", "min", "p10", "p25", "median", "p75", "p90", "max"]]
    )
    print(overall_df)
    #print(pd.DataFrame(interval_dists_overall).T.head())
    # Save Interval distributions DataFrame to CSV
    overall_df.to_csv(outcsv / "interval_dists_overall_df.csv", index=False)
    
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(overall_df["metric"], overall_df["median"])
    ax.set_ylabel("Median days")
    ax.set_title("Median interval length (last 4 years)")
    ax.set_xticklabels(overall_df["metric"], rotation=45, ha="right")
    plt.tight_layout()
    plot_interval_dists_overall = outdir / "plot_interval_dists_overall.png"
    plt.savefig(plot_interval_dists_overall, bbox_inches="tight", dpi=150)
    plt.show()


    # Distribution shape (boxplots)
    # Keep it simple: use a subset of metrics
    metrics_to_plot = ["days_to_alloc", 
                       "days_to_pg_signoff", 
                       "days_alloc_to_close", 
                       "days_recieved_to_legal_review", 
                       "days_alloc_to_req_legal_review", 
                       "inter_pickup_days"
                      ]

    # Set any negative durations to NA so they don't affect mins/means/plots
    # That will prevent negative mins/means in anything computed from di afterwards 
    # (distributions, charts, trends, etc.), without inventing fake values.
    for c in metrics_to_plot:
        if c in di.columns:
            di.loc[di[c] < 0, c] = pd.NA
            
    # Use the interval frame (di) for derived interval metrics
    df_for_box = di.copy()
    
    # Ensure derived columns exist (in case you want to reference them directly)
    if "days_alloc_to_close" not in df_for_box.columns and {"dt_close", "dt_alloc_invest"}.issubset(df_for_box.columns):
        df_for_box["days_alloc_to_close"] = (df_for_box["dt_close"] - df_for_box["dt_alloc_invest"]).dt.days.astype("float")
    
    if "inter_pickup_days" not in df_for_box.columns and "time_since_last_pickup" in df_for_box.columns:
        df_for_box["inter_pickup_days"] = df_for_box["time_since_last_pickup"]
    
    cols_for_box = ["days_to_alloc", 
                       "days_to_pg_signoff", 
                       "days_alloc_to_close", 
                       "days_recieved_to_legal_review", 
                       "days_alloc_to_req_legal_review", 
                       "inter_pickup_days"
                      ]
    cols_present = [c for c in cols_for_box if c in df_for_box.columns]
    
    missing = sorted(set(cols_for_box) - set(cols_present))
    if missing:
        print(f"Skipping missing metrics (not found in df): {missing}")
    
    box_df = (
        df_for_box[cols_present]
          .melt(var_name="metric", value_name="value")
          .dropna()
    )

    fig, ax = plt.subplots(figsize=(18, 12))
    box_df.boxplot(by="metric", column="value", ax=ax)
    ax.set_xlabel("")
    ax.set_ylabel("Days")
    ax.set_title("Interval distributions (last 4 years)")
    plt.suptitle("")
    plt.tight_layout()
    plot_interval_dists_boxplots = outdir / "plot_interval_dists_boxplots.png"
    plt.savefig(plot_interval_dists_boxplots, bbox_inches="tight", dpi=150)
    plt.show()



    
    # ------------------------------------------------------------
    # Interval distributions by case_type and application_type
    # ------------------------------------------------------------

    # Interval distributions by case_type
    
    out_case_type = run_interval_outputs(
        typed=typed,
        di=di,
        group_col="case_type",
        outdir="data/out/plot/plots/by_case_type",
        interval_metrics=INTERVAL_METRICS,
    )

    # Interval distributions by application_type

    if "application_type" in typed.columns and "application_type" not in di.columns:
        di = di.merge(typed[["case_id", "application_type"]], on="case_id", how="left")
    
    out_app_type = run_interval_outputs(
        typed=typed,
        di=di,
        group_col="application_type",
        outdir="data/out/plot/plots/by_application_type",
        interval_metrics=INTERVAL_METRICS,
    )

    # Interval distributions by concern_type

    if "concern_type" in typed.columns and "concern_type" not in di.columns:
        di = di.merge(typed[["case_id", "concern_type"]], on="case_id", how="left")
    
    out_app_type = run_interval_outputs(
        typed=typed,
        di=di,
        group_col="concern_type",
        outdir="data/out/plot/plots/by_concern_type",
        interval_metrics=INTERVAL_METRICS,
    )

    
    # Quick sanity check (optional)
    print("Metrics present in di:", [m for m in INTERVAL_METRICS if m in di.columns])
    print("Metrics missing from di:", [m for m in INTERVAL_METRICS if m not in di.columns])


    # ------------------------------------------------------------
    # Add the extra breakdowns for (risk, application type, legal reviewed)
    # - Time intervals for new case starts broken down by: case_type, risk, application_type, legal_review
    # - Time from allocated → PG sign-off broken down by same
    # ------------------------------------------------------------

    # Make sure legal_review is numeric (also fixes your "category mean" error later)
    typed = typed.copy()
    typed["legal_review"] = pd.to_numeric(typed.get("legal_review"), errors="coerce").fillna(0).astype("int64")
    
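    # NOTE(review): the guard below can never be true ("status" is tested as both absent and present),
    # so the string cast never runs. The original intent (deriving a risk band from 'risk' if present) needs confirming.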
    if "status" not in typed.columns and "status" in typed.columns:
        typed["status"] = typed["status"].astype("string")
    
    # Ensure the same columns exist in di (merge from typed by case_id if needed)
    need_in_di = ["case_id", "case_type", "status", "application_type", "legal_review"]
    missing_in_di = [c for c in need_in_di if c not in di.columns]
    if missing_in_di and "case_id" in di.columns and "case_id" in typed.columns:
        di = di.merge(typed[["case_id"] + [c for c in missing_in_di if c in typed.columns]],
                      on="case_id", how="left")


    # ------------------------------------------------------------
    # “New case start interval” distributions (frequency + summary)
    # - Uses new case start rows (event_newcase == 1) and summarises the gap since the last pickup.
    # - The allocated → PG sign-off interval is broken down the same way in the section that follows.
    # ------------------------------------------------------------

    from pathlib import Path
    outdir = Path("data/out/plot/plots/newcase_intervals")
    outdir.mkdir(parents=True, exist_ok=True)
    
    # Use whichever column you actually have for the gap-on-newcase rows
    # Prefer: inter_pickup_days (if it exists)
    gap_col = "inter_pickup_days" if "inter_pickup_days" in di.columns else "time_since_last_pickup"
    
    newcase = di.loc[di.get("event_newcase") == 1].copy()
    newcase[gap_col] = pd.to_numeric(newcase[gap_col], errors="coerce")
    
    # Create the same bands you’ve been using (if gap_band not present)
    if "gap_band" not in newcase.columns:
        newcase["gap_band"] = pd.cut(
            newcase[gap_col],
            bins=[-0.1, 7, 14, 28, 91, np.inf],
            labels=["<1 week", "1–2 weeks", "2–4 weeks", "4–13 weeks", ">13 weeks"],
        )
    
    # (1) Summary stats by requested breakdown
    breakdowns = ["case_type", "status", "application_type", "legal_review"]
    for b in breakdowns:
        if b in newcase.columns:
            summ = interval_summary(newcase, metrics=[gap_col], group_cols=[b])
            display(summ.head(30))
            summ.to_csv(outdir / f"summary_gap_by_{b}.csv", index=False)
    
    # (2) Frequency distributions (counts in bands) by each breakdown
    for b in breakdowns:
        if b in newcase.columns:
            freq = (
                newcase.groupby([b, "gap_band"], dropna=False)[gap_col]
                .size()
                .reset_index(name="n")
            )
            freq.to_csv(outdir / f"freq_gapband_by_{b}.csv", index=False)
    
            # quick plot: stacked bars is messy; do a simple “top line” plot per band total
            pivot = freq.pivot_table(index=b, columns="gap_band", values="n", fill_value=0)
            fig, ax = plt.subplots(figsize=(9, 4))
            for col in pivot.columns:
                ax.plot(pivot.index.astype(str), pivot[col].values, label=str(col))
            ax.set_title(f"New case start gaps: frequency by {b} (by gap band)")
            ax.set_ylabel("Count")
            ax.tick_params(axis="x", rotation=45)
            ax.legend()
            fig.tight_layout()
            fig.savefig(outdir / f"freq_gapband_lines_by_{b}.png", dpi=200)
            plt.show()



    # ------------------------------------------------------
    # “Allocated → PG sign-off” interval broken down the same way
    # ------------------------------------------------------
    outdir2 = Path("data/out/plot/plots/alloc_to_pg_signoff")
    outdir2.mkdir(parents=True, exist_ok=True)
    
    # compute alloc -> signoff days (only where dates exist)
    typed = typed.copy()
    typed["days_alloc_to_pg_signoff"] = (
        (pd.to_datetime(typed["dt_pg_signoff"], errors="coerce") -
         pd.to_datetime(typed["dt_alloc_invest"], errors="coerce"))
        .dt.days
    )
    
    metric = "days_alloc_to_pg_signoff"
    
    # summary stats by breakdown
    breakdowns = ["case_type", "application_type", "legal_review"] # , "risk_band"
    for b in breakdowns:
        if b in typed.columns:
            summ = interval_summary(typed, metrics=[metric], group_cols=[b])
            display(summ.head(30))
            summ.to_csv(outdir2 / f"summary_{metric}_by_{b}.csv", index=False)
    
            # median bar chart
            plot_group_median_bar(summ, group_col=b, metric=metric,
                                  outpath=outdir2 / f"median_{metric}_by_{b}.png")
    
    # optional: histogram overall
    plot_metric_histograms(typed, metric, outpath=outdir2 / f"hist_{metric}_overall.png")

    
    # -----------------------------------------------------------------------
    # Legal review rate: break down + “Risk × age” deeper dive
    # - Legal review rate by case_type / risk / application_type
    # - “Risk × age”: why legal review appears more likely as time increases
    # -----------------------------------------------------------------------
    
    # - Legal review rate by case_type / status / application_type
    outdir3 = Path("data/out/plot/plots/legal_review")
    outdir3.mkdir(parents=True, exist_ok=True)
    
    typed = typed.copy()
    typed["legal_review"] = pd.to_numeric(typed["legal_review"], errors="coerce").fillna(0).astype(int)
    
    def legal_rate_table(df, by):
        """
        Summarise how often cases go to legal review, per group.

        Parameters
        ----------
        df : DataFrame with at least "case_id" and "legal_review" columns.
        by : column name, or list of column names, to group on. Names that
             are not columns of df are ignored.

        Returns
        -------
        DataFrame with one row per group (missing keys kept as their own
        group): the grouping columns, then n_cases (count of non-null
        case_id) and legal_rate (mean of legal_review). If none of the
        requested columns exist, a single overall row holding just
        n_cases and legal_rate is returned.
        """
        # A bare string would otherwise be iterated character by character.
        requested = [by] if isinstance(by, str) else list(by)
        keys = [c for c in requested if c in df.columns]

        if not keys:
            # groupby([]) raises "No group keys passed!"; report the overall rate instead.
            return pd.DataFrame(
                {
                    "n_cases": [df["case_id"].count()],
                    "legal_rate": [df["legal_review"].mean()],
                }
            )

        return (
            df[keys + ["legal_review", "case_id"]]
            .groupby(keys, dropna=False)
            .agg(n_cases=("case_id", "count"), legal_rate=("legal_review", "mean"))
            .reset_index()
        )
    
    for b in ["case_type", "status", "application_type"]:
        if b in typed.columns:
            tab = legal_rate_table(typed, by=[b])
            display(tab.sort_values("legal_rate", ascending=False).head(30))
            tab.to_csv(outdir3 / f"legal_rate_by_{b}.csv", index=False)
    
            # bar plot
            fig, ax = plt.subplots(figsize=(8, 4))
            ax.bar(tab[b].astype(str), tab["legal_rate"].astype(float))
            ax.set_title(f"Legal review rate by {b}")
            ax.set_ylabel("Proportion legal_review=1")
            ax.tick_params(axis="x", rotation=45)
            fig.tight_layout()
            fig.savefig(outdir3 / f"legal_rate_by_{b}.png", dpi=200)
            plt.show()
    
    # two-way breakdown: case_type x status (often very informative)
    if "case_type" in typed.columns and "status" in typed.columns:
        tab2 = legal_rate_table(typed, by=["case_type", "status"])
        pivot = tab2.pivot_table(index="case_type", columns="status", values="legal_rate")
        display(pivot)
    
        # heatmap
        plot_prob_heatmap(pivot, outpath=outdir3 / "legal_rate_heatmap_case_type_x_status.png",
                          title="Legal review rate (case_type × status)")



    # - “Risk × age”: why legal review appears more likely as time increases
    # Two different “age” concepts are often mixed up:
    # - Case age (how long the case has been open)
    # - Donor age (actual age of the donor — requires linked data)
    # The earlier “Risk × age” slide appears to mean case age. If so, it is very common for legal review likelihood to rise with case age because of:
    # - Accumulation effect: the longer a case stays open, the more opportunities it has to trigger a legal step (this is basically a hazard story).
    # - Confounding by complexity: complex cases both (a) last longer and (b) need legal review more often.
    # - Reverse causality: legal review itself can extend duration (queues, waiting for opinion, rework).
    # A better way to “explain” it is to estimate the hazard of legal review by case age band, optionally controlling for risk/case type.
    # The code below does this when di has daily rows with a legal-review event flag and a case age (days since received):

    outdir4 = Path("data/out/plot/plots/legal_hazard")
    outdir4.mkdir(parents=True, exist_ok=True)
    
    # pick a case-age column:
    # if you already have something like weeks_since_start or time index, use it.
    # Otherwise compute case_age_days from di["date"] - di["dt_received_inv"] (if both exist)
    if "case_age_days" not in di.columns and "date" in di.columns and "dt_received_inv" in di.columns:
        di = di.copy()
        di["case_age_days"] = (
            pd.to_datetime(di["date"], errors="coerce") -
            pd.to_datetime(di["dt_received_inv"], errors="coerce")
        ).dt.days
    
    # age bands
    if "case_age_days" in di.columns:
        di["case_age_band"] = pd.cut(
            pd.to_numeric(di["case_age_days"], errors="coerce"),
            bins=[-0.1, 7, 14, 28, 56, 91, 182, 365, np.inf],
            labels=["0–7","8–14","15–28","29–56","57–91","92–182","183–365",">365"],
        )
    
        # hazard = P(event_req_legal_review today | case is in that age band)
        if "event_req_legal_review" in di.columns:
            di["event_req_legal_review"] = pd.to_numeric(di["event_req_legal_review"], errors="coerce").fillna(0).astype(int)
    
            # overall hazard by age band
            hz = di.groupby("case_age_band", dropna=False)["event_req_legal_review"].mean().reset_index(name="p_legal_today")
            display(hz)
    
            fig, ax = plt.subplots(figsize=(7, 4))
            ax.plot(hz["case_age_band"].astype(str), hz["p_legal_today"].astype(float))
            ax.set_title("Daily hazard of legal review vs case age band")
            ax.set_ylabel("P(legal review today)")
            ax.tick_params(axis="x", rotation=45)
            fig.tight_layout()
            fig.savefig(outdir4 / "hazard_legal_by_case_age.png", dpi=200)
            plt.show()
    
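            # hazard by status (if present)
            # NOTE(review): the groupby below reads "event_legal", while the guard above checks
            # "event_req_legal_review" — confirm which column is intended.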
            if "status" in di.columns:
                hz2 = (
                    di.groupby(["status", "case_age_band"], dropna=False)["event_legal"]
                      .mean()
                      .unstack("case_age_band")
                )
                display(hz2)
                plot_prob_heatmap(hz2, outpath=outdir4 / "hazard_legal_status_x_age.png",
                                  title="Daily hazard of legal review (status × case age)")
    
    # To analyse legal review against linked characteristics (donor age, sex, number of attorneys, LPA type, time since registered),
    # those fields must be in the typed table (or joined in). The workflow is:
    # Join linked investigations + LPA register data into typed on a stable key (often the “LPA/Deputy ID”).
    # Make sure the joined fields are clean types:
    # - donor_age numeric
    # - donor_sex category/string
    # - num_attorneys numeric
    # - lpa_type category/string
    # - time_since_registered_days numeric
    
    # Produce:
    # - legal review rate by each characteristic (1D)
    # - 2D interactions that matter (e.g., lpa_type × donor_age_band, num_attorneys_band × case_type)
    # - a simple baseline predictive model (logistic regression) as a benchmark for “is it predictable at all?”

    # Important modelling note (so you don’t accidentally “cheat”):
    # - If the goal is to predict legal review at receipt/allocation, avoid post-start variables
    #   such as final duration, days_to_signoff, etc. Those leak future information.

    # Minimal “rate by characteristic” pattern:

    # Example: donor age bands (once donor_age exists in typed)
    if "donor_age" in typed.columns:
        typed = typed.copy()
        typed["donor_age"] = pd.to_numeric(typed["donor_age"], errors="coerce")
        typed["donor_age_band"] = pd.cut(
            typed["donor_age"],
            bins=[0, 30, 40, 50, 60, 70, 80, 90, 120],
            right=False,
            labels=["<30","30–39","40–49","50–59","60–69","70–79","80–89","90+"],
        )
    
        tab = typed.groupby("donor_age_band", dropna=False).agg(
            n_cases=("case_id", "count"),
            legal_rate=("legal_review", "mean")
        ).reset_index()
    
        display(tab)



    # --------------------------------------------------------------
    # Add “linked characteristics” for predicting legal review 
    # (LPA type, donor age, sex, #attorneys, time since registered)
    # --------------------------------------------------------------
    

    # ------------------------------------------------------
    # Probability of new case start vs workload & gap
    # ------------------------------------------------------
    
    # “Rules” about new case starts vs workload & time since last case:
    # - From the interval frame di:
    # - wip_load – investigators’ weighted caseload on that date
    # - time_since_last_pickup – days since they last started a new case
    # - event_newcase – indicator that a new case started that day
    # - What we want is a conditional probability table:
    #     - P(new case today | caseload band, gap since last pickup band)
    # - If caseload LOW & gap LONG → probability HIGH?
    # - If caseload HIGH & gap LONG → probability LOW?
    # etc.

    # Can “rules” be inferred from workload + time since last pickup?
    # Yes — empirical “rules” can be inferred, such as:
    # - P(new case start | workload band, gap band)
    # - …because that is exactly what the pickup_prob matrix represents.

    # What we can safely say:
    # - You’ve estimated conditional probabilities from observed staff-days.
    # - Those can be used as a policy rule in simulation (“if an investigator is Low workload and gap is >13 weeks then sample a new case with probability p”).
    # What you should not claim without further work:
    # - That workload/gap cause allocations (there may be operational policies, staffing changes, backlog availability, etc.)
   
    # If we want stronger evidence, next step is a simple model:
    # - logistic regression (or gradient boosting) predicting event_newcase from 
    # wip_load, time_since_last_pickup, seasonality, bank holidays, backlog level, team, etc.
    # - then compare fitted probabilities to your banded rule table (they should broadly align).


    
    # “Rules” table – probability of new case start by workload & gap

    pickup_df = di[
        ["date", "staff_id", "wip_load", "time_since_last_pickup", "event_newcase"]
    ].copy()

    # Drop rows where we don't know the gap or caseload
    pickup_df = pickup_df.dropna(subset=["wip_load", "time_since_last_pickup"])

    # Define workload bands (tweak thresholds as needed for OPG)
    pickup_df["wip_band"] = pd.cut(
        pickup_df["wip_load"],
        #bins=[0, 40, 80, 120, float("inf")],
        bins=[0, 2, 3, 4, float("inf")],
        labels=["Low", "Medium", "High", "Very high"],
        right=False,
        include_lowest=True,
    )

    # Define time-since-last-pickup bands (gap in days)
    pickup_df["gap_band"] = pd.cut(
        pickup_df["time_since_last_pickup"],
        bins=[0, 7, 14, 28, 90, float("inf")],
        labels=["<1 week", "1–2 weeks", "2–4 weeks", "4–13 weeks", ">13 weeks"],
        right=False,
        include_lowest=True,
    )

    # Probability: mean of event_newcase within each (wip_band, gap_band)
    pickup_prob = (
        pickup_df
        .groupby(["wip_band", "gap_band"], dropna=False, observed=False)["event_newcase"]
        .mean()
        .unstack("gap_band")
    )

    # Counts: how many staff-days in each cell (for reliability)
    pickup_counts = (
        pickup_df
        .groupby(["wip_band", "gap_band"], dropna=False, observed=False)["event_newcase"]
        .size()
        .unstack("gap_band")
    )

    print("\n=== P(new case start | workload band, gap band) ===")
    print(pickup_prob)

    print("\n=== Number of staff-days underlying each cell ===")
    print(pickup_counts)

    # Convert to long / tidy format
    prob_long = (
        pickup_prob
        .reset_index()
        .melt(id_vars="wip_band", var_name="gap_band", value_name="prob_new_case")
    )
    
    counts_long = (
        pickup_counts
        .reset_index()
        .melt(id_vars="wip_band", var_name="gap_band", value_name="staff_days")
    )
    
    rules_df = (
        prob_long
        .merge(counts_long, on=["wip_band", "gap_band"])
        .sort_values(["wip_band", "gap_band"])
    )
    
    print("===== “Rules” table – probability of new case start by workload & gap =====")
    print(rules_df)

    # rules_low = pd.DataFrame({
    #     'C': ['<1 week','1–2 weeks','2–4 weeks','4–13 weeks','>13 weeks'],
    #     'prob_new_case': [0.058057,0.033535,0.048967,0.123724,0.142045],
    #     'staff_days': [7441,4294,3145,1665,352]
    # })
    # Keep only the Low workload band (your chart title says Low)
    rules_low = (
        rules_df.loc[
            rules_df["wip_band"].eq("Low"),
            ["gap_band", "prob_new_case", "staff_days"]
        ]
        .copy()
    )
    
    # Keep only cells that actually have observations
    rules_low["staff_days"] = pd.to_numeric(rules_low["staff_days"], errors="coerce")
    rules_low = rules_low.loc[rules_low["staff_days"].fillna(0) > 0]
    
    # Convert <NA> -> NaN -> float (matplotlib-safe)
    rules_low["prob_new_case"] = pd.to_numeric(rules_low["prob_new_case"], errors="coerce")
    
    # Optional: enforce sensible x-axis order
    gap_order = ["<1 week", "1–2 weeks", "2–4 weeks", "4–13 weeks", ">13 weeks"]
    rules_low["gap_band"] = pd.Categorical(rules_low["gap_band"], categories=gap_order, ordered=True)
    rules_low = rules_low.sort_values("gap_band")
    
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(rules_low["gap_band"].astype(str), rules_low["prob_new_case"].astype(float))
    ax.set_ylabel("P(new case today)")
    ax.set_title("Probability of new case vs gap since last pickup\n(Low workload band)")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.show()


    
    # Convert to float for plotting; <NA> becomes NaN
    prob = pickup_prob.astype("float")
    counts = pickup_counts.astype("float")
    
    fig, ax = plt.subplots(figsize=(7, 3.5))
    im = ax.imshow(prob.values, aspect="auto")
    
    ax.set_xticks(range(prob.shape[1]))
    ax.set_xticklabels(prob.columns, rotation=45, ha="right")
    ax.set_yticks(range(prob.shape[0]))
    ax.set_yticklabels(prob.index)
    
    ax.set_title("P(new case start | workload band, gap band)")
    plt.colorbar(im, ax=ax)
    
    # annotate with counts
    for i in range(prob.shape[0]):
        for j in range(prob.shape[1]):
            p = prob.values[i, j]
            n = counts.values[i, j]
            if np.isfinite(p):
                ax.text(j, i, f"{p:.3f}\\n(n={int(n)})", ha="center", va="center")
    
    plt.tight_layout()
    plot_P_new_case_workload_gap_band = outdir / "plot_P_new_case_workload_gap_band.png"
    plt.savefig(plot_P_new_case_workload_gap_band, bbox_inches="tight", dpi=150)
    plt.show()
    print("If your table only shows Low and everything else is empty: that usually means your wip_band calculation (or the WIP itself) is not varying in the dataset you’re feeding into the rule-table. The fastest check is:")

    # Always show details
    # daily[["wip", "wip_load"]].describe()
    # daily["wip_band"].value_counts(dropna=False)


    # Heatmap of “rules” (visual fuzzy rule surface)
    # If you later get non-zero values for Medium/High/Very high, this becomes a really intuitive “rule map” for a fuzzy system.    
    # Ensure the matrix is truly numeric (not object / pd.NA)
    prob_matrix = pickup_prob.apply(pd.to_numeric, errors="coerce").fillna(0.0)
    print(prob_matrix.dtypes)

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(prob_matrix.to_numpy(dtype=float), aspect="auto", vmin=0, vmax=1)

    
    ax.set_xticks(range(len(prob_matrix.columns)))
    ax.set_xticklabels(prob_matrix.columns, rotation=45, ha="right")
    ax.set_yticks(range(len(prob_matrix.index)))
    ax.set_yticklabels(prob_matrix.index)
    
    ax.set_xlabel("Gap band (time since last pickup)")
    ax.set_ylabel("WIP band (current caseload)")
    ax.set_title("P(new case start | workload, gap)")
    
    fig.colorbar(im, ax=ax, label="Probability of new case start")
    plt.tight_layout()
    plot_heatmaps_rules_fuzzy = outdir / "plot_heatmaps_rules_fuzzy.png"
    plt.savefig(plot_heatmaps_rules_fuzzy, bbox_inches="tight", dpi=150)
    plt.show()


    # year-on-year interval changes
    alloc_change = interval_change_distribution(
        typed,
        interval_col="days_to_alloc",
        date_col="dt_received_inv",  # <- real received-date column
        group="case_type",
    )

    alloc_annual_stats = alloc_change["annual_stats"]
    alloc_yoy_change = alloc_change["yoy_change"]

    print("\n=== Annual distributions of days_to_alloc by case_type ===")
    print(alloc_annual_stats.head())
    fig, ax = plt.subplots(figsize=(8, 4))
    
    for ct, grp in alloc_annual_stats.groupby("case_type"):
        ax.plot(grp["year"], grp["median"], marker="o", label=ct)
    
    ax.set_xlabel("Year")
    ax.set_ylabel("Median days from received to allocation")
    ax.set_title("Median allocation delay by case type over time")
    ax.legend(title="Case type", ncol=2)
    plt.tight_layout()
    plot_Annual_dist_days_to_alloc_b_case_type = outdir / "plot_Annual_dist_days_to_alloc_b_case_type.png"
    plt.savefig(plot_Annual_dist_days_to_alloc_b_case_type, bbox_inches="tight", dpi=150)
    plt.show()

    main_types = ["Aspect", "Fraud", "Investigation", "Multiple", "Multiple Sub", "TPO"]
    mask = alloc_annual_stats["case_type"].isin(main_types)
    
    fig, ax = plt.subplots(figsize=(8, 4))
    for ct, grp in alloc_annual_stats[mask].groupby("case_type"):
        ax.plot(grp["year"], grp["median"], marker="o", label=ct)
    
    ax.set_xlabel("Year")
    ax.set_ylabel("Median days to allocation")
    ax.set_title("Median allocation delay for main case types")
    ax.legend(title="Case type", ncol=2)
    plt.tight_layout()
    plot_Annual_dist_days_to_alloc_restricted_case_type = outdir / "plot_Annual_dist_days_to_alloc_restricted_case_type.png"
    plt.savefig(plot_Annual_dist_days_to_alloc_restricted_case_type, bbox_inches="tight", dpi=150)
    plt.show()



    print("\n=== Year-on-year change in median days_to_alloc by case_type ===")
    print(alloc_yoy_change.head())

    # Focus on realistic years to avoid huge early artefacts if needed
    yoy_df = alloc_yoy_change.copy()
    
    fig, ax = plt.subplots(figsize=(8, 4))
    
    for ct, grp in yoy_df.groupby("case_type"):
        ax.plot(grp["year"], grp["yoy_median_change"], marker="o", label=ct)
    
    ax.axhline(0, linestyle="--")
    ax.set_xlabel("Year")
    ax.set_ylabel("Change in median days to allocation vs previous year")
    ax.set_title("Year-on-year change in allocation delays by case type")
    ax.legend(title="Case type", ncol=2)
    plt.tight_layout()
    plot_alloc_yoy_change = outdir / "plot_alloc_yoy_change.png"
    plt.savefig(plot_alloc_yoy_change, bbox_inches="tight", dpi=150)
    plt.show()


    print("===== TREND ANALYSIS =====")
    # Trend Analysis
    trend = IntervalAnalysis.monthly_trend(
        di, metric="days_to_pg_signoff", agg="median", by=["case_type"]
    ).copy()
    trend["month"] = pd.to_datetime(trend["yyyymm"] + "-01")
    
    print("\n=== INTERVAL TREND HEAD ===")
    print(trend.head())

    fig, ax = plt.subplots(figsize=(8, 4))
    
    ax.plot(trend["month"], trend["days_to_pg_signoff"], marker="o")
    ax.set_xlabel("Month")
    ax.set_ylabel("Median days to PG signoff")
    ax.set_title("Monthly median days to PG signoff – ALL case types")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plot_trend_analysis = outdir / "plot_trend_analysis.png"
    plt.savefig(plot_trend_analysis, bbox_inches="tight", dpi=150)
    plt.show()

    # Faceted by case_type (simple overlay version)
    fig, ax = plt.subplots(figsize=(9, 5))
    
    for ct, grp in trend.groupby("case_type"):
        ax.plot(grp["month"], grp["days_to_pg_signoff"], marker="o", linewidth=1, label=ct)
    
    ax.set_xlabel("Month")
    ax.set_ylabel("Median days to PG signoff")
    ax.set_title("Monthly median PG signoff times by case type")
    plt.xticks(rotation=45, ha="right")
    ax.legend(title="Case type", ncol=2)
    plt.tight_layout()
    plot_trend_analysis_by_case_type = outdir / "plot_trend_analysis_by_case_type.png"
    plt.savefig(plot_trend_analysis_by_case_type, bbox_inches="tight", dpi=150)
    plt.show()




    # Extend the above to the FTE and investigator interval distributions
    
    interval_dists_by_fte = IntervalAnalysis.analyse_interval_distributions(
        di, by=["fte"]
    )
    interval_dists_by_fte.pop("days_alloc_to_req_legal_review", None)
    
    interval_dists_by_investigator = IntervalAnalysis.analyse_interval_distributions(
        di, by=["staff_id"]
    )
    interval_dists_by_investigator.pop("days_alloc_to_req_legal_review", None)




    cfg = EDAConfig(
        id_col="case_id",
        date_received="dt_received_inv",
        date_allocated="dt_alloc_invest",
        date_signed_off="dt_pg_signoff",
    )
    cfg.numeric_cols = [
        "days_to_alloc",
        "days_to_signoff",
        "legal_review",
        "fte",
        "weighting",
    ]
    eda = OPGInvestigationEDA(typed, cfg)
    print("=== EDA COLUMNS ===")
    print(eda.df.columns.tolist())
    
    # --- EDA code from demo_eda.py ---
    
    print("=== EDA OVERVIEW ===")
    overview = eda.quick_overview()
    #print(overview)

    print("=== EDA MISSING ===")
    missing_pct = eda.missingness_matrix()
    missing_vs_target = eda.missing_vs_target("days_to_signoff")#, "legal_review")
    outliers_signoff = eda.iqr_outliers("days_to_signoff")
    outliers_allocate = eda.iqr_outliers("days_to_alloc")
    #cat_summary = eda.group_summary(["case_type", "risk_band"], target="legal_review")
    # cat_summary = eda.group_summary(
    #     ["case_type", "risk_band"], 
    #     metrics={"legal_rate": ("legal_review", "mean")} 
    # ) #"n": ("id", "count"), 
    weight_summary = eda.group_summary( 
        by=["weighting"],
        metrics={ #"new_column_name": ("existing_column_name", "aggfunc")
            "n_cases": ("case_id", "count"),
            "legal_rate": ("legal_review", "mean"),          # proportion of cases with legal_review=1
            "median_days_to_signoff": ("days_to_signoff", "median"), # "aggfunc" one of: "count", "mean", "median", "min", "max", "std", etc.
        },
    )

    case_weight_summary = eda.group_summary(
        by=["case_type", "weighting"],
        metrics={
            "n_cases": ("case_id", "count"),
            "median_days_to_alloc": ("days_to_alloc", "median"),
            "median_days_to_signoff": ("days_to_signoff", "median"),
        },
    )
    case_weight_summary = case_weight_summary.sort_values(["case_type", "weighting"])

    legal_review_by_case_type = eda.group_summary(
        by=["case_type"],
        metrics={
            #"avg_backlog": ("backlog", "mean"),
            "legal_rate": ("legal_review", "mean"),
        },
    )

    legal_review_by_case_status = eda.group_summary(
        by=["case_type", "status"],
        metrics={
            #"avg_backlog": ("backlog", "mean"),
            "legal_rate": ("legal_review", "mean"),
        },
    )

    staff_summary = eda.group_summary(
        by=["staff_id", "fte"],
        metrics={
            "n_cases": ("case_id", "count"),
            "mean_days_to_alloc": ("days_to_alloc", "mean"),
            "mean_days_to_signoff": ("days_to_signoff", "mean"),
            "legal_rate": ("legal_review", "mean"),
        },
    )

    fte_weight_summary = eda.group_summary(
        by=["fte"],
        metrics={
            "total_weight": ("weighting", "sum"),
            "avg_weight": ("weighting", "mean"),
        },
    )

    case_weight_full = eda.group_summary(
        by=["case_type", "weighting"],
        metrics={
            "staff_id": ("staff_id", "count"),
            #"avg_backlog": ("backlog", "mean"),
            "median_days_to_alloc": ("days_to_alloc", "median"),
            "median_days_to_signoff": ("days_to_signoff", "median"),
            "legal_rate": ("legal_review", "mean"),
        },
    )
    case_weight_full = case_weight_full.sort_values(["case_type", "weighting"])


    status_summary = eda.group_summary(
        by=["status"],
        metrics={
            "staff_id": ("staff_id", "count"),
            "legal_rate": ("legal_review", "mean"),
            "median_days_to_signoff": ("days_to_signoff", "median"),
            "median_days_to_alloc": ("days_to_alloc", "median"),
        },
    )

    legal_case_summary = eda.group_summary(
        by=["case_type", "legal_review"],
        metrics={
            "staff_id": ("staff_id", "count"),
            "median_days_to_signoff": ("days_to_signoff", "median"),
            "median_days_to_alloc": ("days_to_alloc", "median"),
        },
    )



    corrs = eda.numeric_correlations(method="spearman")
    class_balance = eda.imbalance_summary()
    leakage_hits = eda.leakage_scan(["post", "signed", "decision"])
    interaction = eda.binned_interaction_rate(
        num_col="days_to_alloc",
        cat_col="weighting",
        target="legal_review",
    )
    
    # ts_7d, lag_corrs = eda.resample_time_series(
    #     metrics={"days_to_alloc": ("days_to_alloc", "last"), 
    #              "staff_count": ("staff_id", "count")}
    # )
    
    # km_q = eda.km_quantiles_by_group(group="weighting")
    # monthly_kpis = eda.monthly_kpis()
    cramers_case_type_w = eda.cramers_v(typed["case_type"],typed["weighting"])
    cramers_case_type_fte = eda.cramers_v(typed["case_type"],typed["fte"])
    
    
    # --- END EDA code ---



    def weighted_mean(values, weights):
        """
        Weighted mean of `values` using `weights`, ignoring missing pairs.

        A pair is dropped when either its value or its weight is NaN/None.
        Returns NaN when no valid pair remains or when the remaining weights
        sum to zero (the mean is undefined in both cases).
        Non-numeric objects other than None still raise on float conversion.
        """
        v = np.asarray(values)
        w = np.asarray(weights)

        # np.isnan raises TypeError on object arrays (e.g. a column holding
        # None); casting to float turns None into NaN so the mask below works.
        if v.dtype == object:
            v = v.astype(float)
        if w.dtype == object:
            w = w.astype(float)

        valid = ~(np.isnan(v) | np.isnan(w))
        if not valid.any():
            return np.nan

        total_weight = w[valid].sum()
        # Guard the division: a zero denominator would otherwise emit a
        # RuntimeWarning and yield nan/inf depending on the numerator.
        if total_weight == 0:
            return np.nan

        return (v[valid] * w[valid]).sum() / total_weight
    
    # weighted mean days_to_signoff by case_type
    weighted_signoff = (
        eda.df
        .groupby("case_type")
        .apply(lambda g: weighted_mean(g["days_to_signoff"], g["weighting"]))
        .reset_index(name="w_mean_days_to_signoff")
    )

    # weighted mean days_to_alloc by case_type
    weighted_alloc = (
        eda.df
        .groupby("case_type")
        .apply(lambda g: weighted_mean(g["days_to_alloc"], g["weighting"]))
        .reset_index(name="w_mean_days_to_alloc")
    )
    
    # Call your plotting function for the interval and trends
    results = plot_pg_signoff_monthly_trends(di,"data/out/plot/plots")
    # Extract for returning
    trend_all = results["trend_all"]
    plot_paths = results["plots"]

    results_alloc = plot_allocation_monthly_trends(di,"data/out/plot/plots")
    # Extract for returning
    trend_all_alloc = results_alloc["trend_all"]
    plot_paths_alloc = results_alloc["plots"]

    print("=== LEGAL REVIEW, DATE OF LEGAL REQUEST, STATUS ===")
    typed[["legal_review", "dt_legal_review_req1", "status"]].head()
    print("=== LEGAL REVIEW COUNTS ===")
    typed["legal_review"].value_counts()
    print("=== STATUS ===")
    print(typed["status"].value_counts(dropna=False))

    # status_summary = eda.group_summary(
    #     by=["status"],
    #     metrics={
    #         "n_cases": ("id", "count"),
    #         "legal_rate": ("legal_review", "mean"),
    #         "median_days_to_signoff": ("days_to_signoff", "median"),
    #         "median_days_to_alloc": ("days_to_alloc", "median"),
    #     },
    # )
    # print(status_summary)
    
    # risk_legal_summary = eda.group_summary(
    #     by=["risk_band", "legal_review"],
    #     metrics={
    #         "n_cases": ("id", "count"),
    #         "median_days_to_signoff": ("days_to_signoff", "median"),
    #     },
    # )
    # print(risk_legal_summary)

    # --------------------------------------------------------------
    # Fuzzy inference + micro-simulation: is it appropriate?
    # --------------------------------------------------------------

    # It can be a very good fit if the stakeholder goal is:
    # - transparent “human-like” decision logic
    # - smooth transitions between LOW/MED/HIGH (instead of hard thresholds)
    # - scenario testing (e.g., “what if we increase staffing”, “what if we change triage rules”)
    
    # How I’d use it here:
    # - Use your empirical tables / fitted model to calibrate membership functions and rule strengths.
    # - Use fuzzy inference to output:
    #    - probability of picking up a new case today
    #    - probability of escalating to legal today (hazard)
    #    - probability of closing today (hazard)
    # - Then your micro-sim steps day-by-day.
    
    # Big watch-outs:
    # - Fuzzy systems are easy to write and easy to miscalibrate; you still need a validation loop:
    #    - reproduce historical backlog curve
    #    - reproduce distributions (pickup gaps, time-to-signoff, legal review rates)
    #    - reproduce variation by team/case type
    
    # A pragmatic compromise that often works well:
    # - Fit a simple statistical model first (logistic / survival / gradient boosting),
    # - Then translate it into fuzzy rules for interpretability.

    return {
        "raw": raw,
        "typed": typed,
        "daily": daily,
        "backlog": backlog_ts,
        "events": events,
        "backlog_ts": backlog_ts,
        "di": di,
        "trend": trend,
        "trend_all": trend_all,
        "trend_all_alloc": trend_all_alloc,
        "eda": {
            "cfg": cfg,
            "overview": overview,
            "missing_pct": missing_pct,
            "missing_vs_target": missing_vs_target,
            "outliers_signoff": outliers_signoff,
            "outliers_allocate": outliers_allocate,
            "weight_summary": weight_summary,
            "case_weight_summary": case_weight_summary,
            "legal_review_by_case_type": legal_review_by_case_type,
            "legal_review_by_case_status": legal_review_by_case_status,
            "staff_summary": staff_summary,
            "fte_weight_summary": fte_weight_summary,
            "status_summary": status_summary,
            "case_weight_full": case_weight_full,
            "legal_case_summary": legal_case_summary,
            "corrs": corrs,
            "cramers_case_type_w": cramers_case_type_w,
            "cramers_case_type_fte": cramers_case_type_fte,
            "class_balance": class_balance,
            "leakage_hits": leakage_hits,
            "interaction": interaction,
            #"ts_7d": ts_7d,
            #"lag_corrs": lag_corrs,
            #"km_quantiles": km_q,
            #"monthly_kpis": monthly_kpis,
        },
        "interval_dists_overall": interval_dists_overall,
        "interval_dists_by_case_type": out_case_type,
        "interval_dists_by_app_type": out_app_type,
        "interval_dists_by_fte": interval_dists_by_fte,
        "interval_dists_by_investigator": interval_dists_by_investigator,
        "pickup_prob": pickup_prob,
        "pickup_counts": pickup_counts,
        "alloc_annual_stats": alloc_annual_stats,
        "alloc_yoy_change": alloc_yoy_change,
        "weighted_signoff": weighted_signoff,
        "weighted_alloc": weighted_alloc,
        "monthly_case_flow_counts": monthly_flows,
        "monthly_flow_outputs": monthly_flow_outputs,
        "plot_paths": plot_paths,
        "plot_paths_alloc": plot_paths_alloc,
        #"plots": plot_paths,
    }

# from demo_pipeline import demo_all
# Run the demo pipeline as soon as this cell/module executes and keep the
# result in `outputs` so it can be inspected afterwards.
outputs = demo_all()
