In [None]:
"""
STRUCTURAL CASHFLOWS — SINGLE PATH (NO MONTE CARLO) + ANNUAL GRADE TRANSITIONS

Includes:
- Two-part models (hazard + positive size) for:
    1) Drawdowns (size as ratio of capacity)
    2) Repayments (size as ratio of lagged NAV)
    3) Recallables (conditional on repayments; size as ratio of repayment)
- Recallables calibrated from data.parquet column: "Recallable"
- KMP limits used when present; if missing, use soft limits calibrated from data:
    rho_soft(strategy) = p95 of empirical sum(Recallable)/Commitment
    E_soft(strategy)   = median Expiration_Quarters among KMP funds in strategy (fallback 20)
- Recallables increase capacity:
    Capacity = Remaining_Commitment + Recallable_Available
  Drawdowns consume recallables FIFO; only the remainder increases cumulative drawn from commitment.
- Fund horizon cap: planned end date (+ average strategy overrun if fund historically overran)
- NAV updated each quarter from Untitled-1 output: nav_projection_sota_*.parquet/csv
- If NAV_current <= NAV_STOP_EPS: stop all cashflows from that quarter onward for that fund.
- Annual grade transitions (every 4 quarters) using:
    grade_transition_1y_pe.csv / grade_transition_1y_vc.csv / grade_transition_1y_all.csv (fallback identity)

- Single-factor Gaussian copula across funds within each quarter.
  Separate copula draws per component (draw_event, draw_size, rep_event, rep_size, rc_event, rc_size).

Outputs:
  - structural_cashflows_{year}_{quarter}_{n_q}q.(csv|parquet)
  - optional calibration CSVs

Assumptions:
- data.parquet contains:
  FundID, Adj Strategy, Grade, Fund_Age_Quarters,
  Year of Transaction Date, Quarter of Transaction Date,
  Commitment EUR, Adj Drawdown EUR, Adj Repayment EUR, NAV Adjusted EUR, Recallable,
  Planned end date with add. years as per legal doc
  (optional) Capacity column used as proxy for calibration (preferred)
- kmp.(parquet|csv) contains:
  FundID, Recallable_Percentage_Decimal, Expiration_Quarters
- NAV projections exist in DATA_DIR:
  nav_projection_sota_*.parquet/csv with columns: FundID, quarter_end, NAV_projected
- Grade transition matrices exist in DATA_DIR (preferred):
  grade_transition_1y_all.csv and optionally _pe/_vc

"""

import os
import glob
import time
from dataclasses import dataclass, field
from math import erf, sqrt
from statistics import NormalDist
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd

# =============================
# Inputs / Paths
# =============================
t0 = time.perf_counter()

year = int(input("Enter year (e.g. 2025): ").strip())
quarter = input("Enter quarter (Q1, Q2, Q3, Q4): ").strip().upper()
# Validate before the quarter is baked into HOME: a typo would otherwise only
# surface later as a misleading "Missing .../data.parquet" error.
if quarter not in {"Q1", "Q2", "Q3", "Q4"}:
    raise ValueError("Quarter must be one of Q1..Q4")

n_q = int(input("Enter number of quarters to simulate (0 => default 40): ").strip() or "0")
if n_q < 0:
    # A negative count yields an empty simulation grid and a later IndexError
    # on future_qe[-1].
    raise ValueError("Number of quarters to simulate must be >= 0 (0 => default 40)")
if n_q == 0:
    n_q = 40

# Copula factor loading; clipped so sqrt(1 - rho) stays strictly positive.
RHO_MKT = float(input("Single-factor copula correlation rho_mkt [0.25]: ").strip() or "0.25")
RHO_MKT = float(np.clip(RHO_MKT, 0.0, 0.999))

SEED = int(input("Random seed [1234]: ").strip() or "1234")

# USERNAME is the Windows variable; fall back to USER so a missing USERNAME
# produces a clear error here instead of a TypeError from os.path.join(None).
_user_name = os.environ.get("USERNAME") or os.environ.get("USER")
if not _user_name:
    raise OSError("Cannot determine the user name: neither USERNAME nor USER is set")
BASE_DIR = os.path.join("C:\\Users", _user_name, "Documents", "Equity")
HOME = os.path.join(BASE_DIR, f"{year}_{quarter}")
DATA_DIR = os.path.join(HOME, "data")

data_path = os.path.join(DATA_DIR, "data.parquet")
kmp_path_parquet = os.path.join(DATA_DIR, "kmp.parquet")
kmp_path_csv = os.path.join(DATA_DIR, "kmp.csv")

if not os.path.exists(data_path):
    raise FileNotFoundError(f"Missing {data_path}")
if not (os.path.exists(kmp_path_parquet) or os.path.exists(kmp_path_csv)):
    raise FileNotFoundError("Missing kmp.parquet or kmp.csv in DATA_DIR")

print("HOME:", HOME)
print("DATA_DIR:", DATA_DIR)

# =============================
# Config
# =============================
# Fund-age buckets in quarters. pd.cut uses right-closed bins, so (-1, 7] is
# the first 8 quarters ("0-2y"), (7, 15] the next 8 ("2-4y"), and so on.
AGE_BINS_Q = [-1, 7, 15, 23, 31, 39, 59, 79, 10000]
AGE_LABELS = [
    "0-2y", "2-4y", "4-6y", "6-8y",
    "8-10y", "10-15y", "15-20y", "20y+",
]

# Numerical guards on EUR amounts.
NAV_EPS = 100.0      # lagged NAV must exceed this for a repayment ratio / repayment
CAP_EPS = 1.0        # capacity proxy must exceed this for a drawdown ratio
NAV_STOP_EPS = 1.0   # if NAV_current <= this => stop all cashflows for that fund

# Bounds on the fitted lognormal sigma.
SIGMA_FLOOR = 0.35
SIGMA_CAP = 2.0

# Soft recallable limits for funds without KMP data.
SOFT_RHO_PCTL = 0.95
SOFT_EXPIRY_FALLBACK = 20

PLAN_COL = "Planned end date with add. years as per legal doc"

GRADE_STATES = ["A", "B", "C", "D"]

# =============================
# Helpers: quarters / dates
# =============================
def quarter_end_from_year_quarter(y: int, q: str) -> pd.Timestamp:
    """Return the quarter-end timestamp of year *y*, quarter *q*.

    *q* is case-insensitive and may carry surrounding whitespace (" q3 " works).
    Raises ValueError for anything other than Q1..Q4.
    """
    label = q.strip().upper()
    valid = ("Q1", "Q2", "Q3", "Q4")
    if label not in valid:
        raise ValueError("Quarter must be one of Q1..Q4")
    period = pd.Period(f"{y}Q{valid.index(label) + 1}", freq="Q")
    return period.to_timestamp("Q")

def add_quarters(qe: pd.Timestamp, q: float) -> pd.Timestamp:
    """Return the quarter-end lying *q* quarters after *qe* (negative q moves back).

    *q* is rounded with Python's round() (half-to-even) before shifting; a
    missing *qe* yields NaT.
    """
    if not pd.isna(qe):
        shifted = pd.Period(qe, freq="Q") + int(round(q))
        return shifted.to_timestamp("Q")
    return pd.NaT

def qdiff(a: pd.Timestamp, b: pd.Timestamp) -> float:
    """Return the signed number of calendar quarters from *b* to *a* (a minus b).

    Both dates are mapped to quarterly periods first, so any two dates in the
    same quarter are 0 apart. Returns NaN if either argument is missing.
    """
    if pd.isna(a) or pd.isna(b):
        return np.nan
    ord_a, ord_b = (pd.Period(ts, freq="Q").ordinal for ts in (a, b))
    return float(ord_a - ord_b)

start_qe = quarter_end_from_year_quarter(year, quarter)
# Simulation grid: the n_q quarter-ends strictly after the valuation quarter.
future_qe = pd.to_datetime(
    pd.period_range(start=start_qe.to_period("Q") + 1, periods=n_q, freq="Q")
      .to_timestamp("Q")
)

# =============================
# Load data + normalize
# =============================
# Historical fund panel plus the KMP table of contractual recallable limits
# (kmp.parquet preferred over kmp.csv).
data = pd.read_parquet(data_path).copy()
kmp = pd.read_parquet(kmp_path_parquet).copy() if os.path.exists(kmp_path_parquet) else pd.read_csv(kmp_path_csv).copy()

# Columns the calibration and simulation rely on; fail fast if any is absent.
required = [
    "FundID","Adj Strategy","Grade","Fund_Age_Quarters",
    "Year of Transaction Date","Quarter of Transaction Date",
    "Commitment EUR","Adj Drawdown EUR","Adj Repayment EUR","NAV Adjusted EUR",
    "Recallable",
    PLAN_COL
]
missing = [c for c in required if c not in data.columns]
if missing:
    raise ValueError(f"Missing columns in data.parquet: {missing}")

# Numeric hardening: non-numeric entries become NaN.
for c in ["Commitment EUR","Adj Drawdown EUR","Adj Repayment EUR","NAV Adjusted EUR","Recallable","Fund_Age_Quarters"]:
    data[c] = pd.to_numeric(data[c], errors="coerce")

# Cashflows and NAV: missing => 0, negatives clipped to 0.
data["Adj Drawdown EUR"] = data["Adj Drawdown EUR"].fillna(0.0).clip(lower=0.0)
data["Adj Repayment EUR"] = data["Adj Repayment EUR"].fillna(0.0).clip(lower=0.0)
data["Recallable"] = data["Recallable"].fillna(0.0).clip(lower=0.0)
data["NAV Adjusted EUR"] = data["NAV Adjusted EUR"].fillna(0.0).clip(lower=0.0)
data["Fund_Age_Quarters"] = data["Fund_Age_Quarters"].fillna(0.0)

# Panel quarter_end + planned_end_qe
# NOTE(review): astype(int) raises if a transaction year/quarter is missing.
data["quarter_end"] = pd.PeriodIndex(
    data["Year of Transaction Date"].astype(int).astype(str) + "Q" +
    data["Quarter of Transaction Date"].astype(int).astype(str),
    freq="Q"
).to_timestamp("Q")
data["quarter_end"] = pd.to_datetime(data["quarter_end"])

# Planned end date snapped to its quarter-end; unparseable dates become NaT.
data["planned_end_qe"] = pd.to_datetime(data[PLAN_COL], errors="coerce").dt.to_period("Q").dt.to_timestamp("Q")

# Sort and create a quarter index that increases with calendar time.
data["tx_q_idx"] = data["Year of Transaction Date"].astype(int) * 4 + data["Quarter of Transaction Date"].astype(int)
data = data.sort_values(["FundID","tx_q_idx"]).reset_index(drop=True)

# Commitment level per fund (assume flow => cumulative). Missing flows count as
# 0 *before* the running sum: cumsum() leaves a NaN row as NaN, and filling it
# afterwards would set that row's level to 0 instead of the running total. This
# matters because fund_static takes each fund's last row as its commitment.
data["Commitment_Level"] = data["Commitment EUR"].fillna(0.0).groupby(data["FundID"]).cumsum()

# NAV lag in history: previous row's NAV within the fund (NaN on a fund's first row).
data["nav_prev"] = data.groupby("FundID")["NAV Adjusted EUR"].shift(1)

# Age bucket: right-closed AGE_BINS_Q bins labelled with AGE_LABELS.
data["AgeBucket"] = pd.cut(data["Fund_Age_Quarters"], bins=AGE_BINS_Q, labels=AGE_LABELS)

# Optional capacity proxy for calibration: 'Capacity' column (NaN => 0) if
# present, otherwise the cumulative Commitment_Level.
CAP_PROXY_COL = "Capacity" if "Capacity" in data.columns else None
if CAP_PROXY_COL is None:
    print("WARNING: data.parquet has no 'Capacity' column; draw calibration uses Commitment_Level proxy.")
data["cap_proxy"] = pd.to_numeric(data[CAP_PROXY_COL], errors="coerce").fillna(0.0) if CAP_PROXY_COL else data["Commitment_Level"].fillna(0.0)

print("Loaded rows:", len(data), "| Funds:", data["FundID"].nunique())

# =============================
# NAV projection (Untitled-1 output)
# =============================
def find_nav_projection_file(data_dir: str) -> str:
    """Return the most recently modified nav_projection_sota_* file in *data_dir*.

    Both .parquet and .csv candidates are considered. The newest by mtime wins;
    ties go to the earlier candidate (parquet files are listed first).
    Raises FileNotFoundError if no candidate exists.
    """
    patterns = ("nav_projection_sota_*.parquet", "nav_projection_sota_*.csv")
    candidates = [
        path
        for pattern in patterns
        for path in glob.glob(os.path.join(data_dir, pattern))
    ]
    if not candidates:
        raise FileNotFoundError(
            f"No nav_projection_sota file found in {data_dir}. "
            "Run Untitled-1.ipynb first."
        )
    return max(candidates, key=os.path.getmtime)

nav_path = find_nav_projection_file(DATA_DIR)
print("Using NAV projection file:", nav_path)

nav_proj = pd.read_parquet(nav_path) if nav_path.lower().endswith(".parquet") else pd.read_csv(nav_path)

need_nav = {"FundID","quarter_end","NAV_projected"}
if not need_nav.issubset(nav_proj.columns):
    raise ValueError(f"NAV projection must contain columns: {need_nav}")

nav_proj = nav_proj.copy()
# Parse dates, then snap each one to its calendar quarter-end, using the same
# to_period("Q") / to_timestamp("Q") convention as future_qe and planned_end_qe.
# Rows stamped at quarter start, or with a time component, would otherwise fail
# the exact isin() match below. They would be dropped silently and those funds'
# NAV would stay frozen at its previous value.
nav_proj["_qe_raw"] = pd.to_datetime(nav_proj["quarter_end"])
nav_proj["quarter_end"] = nav_proj["_qe_raw"].dt.to_period("Q").dt.to_timestamp("Q")
nav_proj["NAV_projected"] = pd.to_numeric(nav_proj["NAV_projected"], errors="coerce").fillna(0.0).clip(lower=0.0)

# Keep only simulated quarters. A stable sort on the original date means that,
# when several rows land in one (fund, quarter), the latest one wins: to_dict()
# keeps the last occurrence of a duplicated key.
nav_proj = (
    nav_proj[nav_proj["quarter_end"].isin(set(future_qe))]
    .sort_values("_qe_raw", kind="stable")
    .drop(columns="_qe_raw")
)
nav_lookup = nav_proj.set_index(["FundID","quarter_end"])["NAV_projected"].to_dict()
if not nav_lookup:
    print("WARNING: NAV projection has no rows for the simulated quarters; NAV will be carried forward.")

# =============================
# Grade transition matrices (annual)
# =============================
p1_all_path, p1_pe_path, p1_vc_path = (
    os.path.join(DATA_DIR, fname)
    for fname in (
        "grade_transition_1y_all.csv",
        "grade_transition_1y_pe.csv",
        "grade_transition_1y_vc.csv",
    )
)

# Each matrix is optional: a missing file leaves the corresponding name as None.
P1_ALL, P1_PE, P1_VC = (
    pd.read_csv(path, index_col=0) if os.path.exists(path) else None
    for path in (p1_all_path, p1_pe_path, p1_vc_path)
)

def _row_normalize_df(P: pd.DataFrame) -> pd.DataFrame:
    P = P.reindex(index=GRADE_STATES, columns=GRADE_STATES).fillna(0.0)
    P = P.clip(lower=0.0)
    rs = P.sum(axis=1).replace(0.0, 1.0)
    return P.div(rs, axis=0)

# Normalise whichever matrices were found; missing ones stay None.
P1_ALL = None if P1_ALL is None else _row_normalize_df(P1_ALL)
P1_PE = None if P1_PE is None else _row_normalize_df(P1_PE)
P1_VC = None if P1_VC is None else _row_normalize_df(P1_VC)

def get_grade_matrix(strategy: str) -> pd.DataFrame:
    """Pick the annual grade-transition matrix for *strategy*.

    Order: the strategy-specific matrix (PE / VC) if loaded, then the pooled ALL
    matrix, then a 4x4 identity over GRADE_STATES (grades never change).
    """
    specific = {"Private Equity": P1_PE, "Venture Capital": P1_VC}.get(strategy)
    if specific is not None:
        return specific
    if P1_ALL is None:
        return pd.DataFrame(np.eye(4), index=GRADE_STATES, columns=GRADE_STATES)
    return P1_ALL

def sample_next_grade(curr: str, P: pd.DataFrame, rng: np.random.Generator) -> str:
    """Draw the next grade from row *curr* of transition matrix *P*.

    Grades outside GRADE_STATES are treated as "D". Consumes one rng.choice draw.
    """
    row_label = curr if curr in GRADE_STATES else "D"
    probs = P.loc[row_label].to_numpy(dtype=float)
    drawn = rng.choice(GRADE_STATES, p=probs)
    return str(drawn)

# One status token per matrix: its name when loaded, "<name>:None" otherwise.
print(
    "Grade matrices loaded:",
    "ALL:None" if P1_ALL is None else "ALL",
    "PE:None" if P1_PE is None else "PE",
    "VC:None" if P1_VC is None else "VC",
)

# =============================
# KMP + soft limits
# =============================
# KMP columns needed for contractual recallable limits (rho = share of
# commitment, E = expiry in quarters).
kmp_needed = ["FundID","Recallable_Percentage_Decimal","Expiration_Quarters"]
missing_k = [c for c in kmp_needed if c not in kmp.columns]
if missing_k:
    raise ValueError(f"Missing columns in kmp: {missing_k}")

# Coerce limits to numbers; unparseable entries become NaN (=> fund counts as "no KMP").
kmp2 = kmp[kmp_needed].copy()
kmp2["Recallable_Percentage_Decimal"] = pd.to_numeric(kmp2["Recallable_Percentage_Decimal"], errors="coerce")
kmp2["Expiration_Quarters"] = pd.to_numeric(kmp2["Expiration_Quarters"], errors="coerce")

# Static fund row (latest observed in history panel)
fund_static = (
    data.sort_values(["FundID","quarter_end"])
        .groupby("FundID")
        .tail(1)[["FundID","Adj Strategy","Grade","Fund_Age_Quarters","Commitment_Level","NAV Adjusted EUR","planned_end_qe"]]
        .copy()
)
# NOTE(review): a FundID repeated in kmp would duplicate that fund's row here.
fund_static = fund_static.merge(kmp2, on="FundID", how="left")
# KMP limits are used only when both rho and expiry are present.
fund_static["has_kmp"] = fund_static["Recallable_Percentage_Decimal"].notna() & fund_static["Expiration_Quarters"].notna()

# Empirical rho per fund from history: sum(Recallable) / max(Commitment_Level)
# (the max equals the last level when commitment flows are non-negative).
tmp_rho = (
    data.groupby("FundID", as_index=False)
        .agg(sum_rc=("Recallable","sum"), C_last=("Commitment_Level","max"))
)
tmp_rho["rho_emp"] = np.where(tmp_rho["C_last"] > 0, tmp_rho["sum_rc"] / tmp_rho["C_last"], np.nan)
tmp_rho = tmp_rho.merge(fund_static[["FundID","Adj Strategy"]], on="FundID", how="left")

# Soft rho per strategy: SOFT_RHO_PCTL quantile of the empirical per-fund rho.
rho_soft_by_strategy = (
    tmp_rho.dropna(subset=["rho_emp","Adj Strategy"])
           .groupby("Adj Strategy")["rho_emp"]
           .quantile(SOFT_RHO_PCTL)
)

# Soft expiry per strategy: median Expiration_Quarters among funds with KMP limits.
exp_soft_by_strategy = (
    fund_static.loc[fund_static["has_kmp"]]
              .groupby("Adj Strategy")["Expiration_Quarters"]
              .median()
)

def soft_params(strategy: str) -> Tuple[float, int]:
    """Soft recallable limits (rho, expiry quarters) for a fund without KMP data.

    rho is the strategy's calibrated soft rho (0.0 if unknown), clipped to [0, 1].
    The expiry is the strategy's median KMP expiry truncated to int, or
    SOFT_EXPIRY_FALLBACK if unknown, and is never negative.
    """
    rho_raw = rho_soft_by_strategy.get(strategy, 0.0)
    expiry_raw = exp_soft_by_strategy.get(strategy, np.nan)
    expiry = SOFT_EXPIRY_FALLBACK if pd.isna(expiry_raw) else int(expiry_raw)
    return float(np.clip(float(rho_raw), 0.0, 1.0)), max(int(expiry), 0)

# choose rho/E per fund: KMP limits when present, otherwise strategy soft limits
def _limits_for(row) -> Tuple[float, int]:
    """(rho, expiry quarters) for one fund_static row."""
    if bool(row["has_kmp"]):
        return float(row["Recallable_Percentage_Decimal"]), int(row["Expiration_Quarters"])
    return soft_params(row["Adj Strategy"])


_fund_limits = [_limits_for(row) for _, row in fund_static.iterrows()]
rho_used = [rho for rho, _ in _fund_limits]
E_used = [expiry for _, expiry in _fund_limits]

fund_static["rho_used"] = rho_used
fund_static["E_used"] = E_used

print("Funds:", fund_static["FundID"].nunique())
print("Funds with KMP limits:", int(fund_static["has_kmp"].sum()))
print("Funds using soft limits:", int((~fund_static["has_kmp"]).sum()))

# =============================
# Horizon cap: planned end + strategy overrun logic (from history)
# =============================

# last observed quarter_end per fund, plus its planned end and strategy
# ("last" in groupby.agg takes the last non-null value in the group).
fund_last = (
    data.groupby("FundID", as_index=False)
        .agg(
            last_qe=("quarter_end", "max"),
            planned_end_qe=("planned_end_qe", "last"),
            Adj_Strategy=("Adj Strategy", "last"),
        )
)

# Overrun in quarters = max(last observed quarter - planned end, 0); NaN when
# the planned end is unknown.
fund_last["overrun_q"] = fund_last.apply(
    lambda r: max(qdiff(r["last_qe"], r["planned_end_qe"]), 0.0)
    if pd.notna(r["planned_end_qe"]) else np.nan,
    axis=1
)

fund_last["ever_overran"] = fund_last["overrun_q"].fillna(0.0) > 0

# Average overrun by strategy among funds that overran
overran_only = fund_last[(fund_last["overrun_q"].notna()) & (fund_last["overrun_q"] > 0)].copy()
avg_overrun_by_strategy = overran_only.groupby("Adj_Strategy")["overrun_q"].mean().clip(lower=0.0)

ever_overran_map = fund_last.set_index("FundID")["ever_overran"].to_dict()
planned_end_map = fund_last.set_index("FundID")["planned_end_qe"].to_dict()

# Per-fund last quarter-end in which cashflows may be simulated:
#   - no planned end        -> last simulated quarter (no cap)
#   - fund overran history  -> planned end + strategy's average overrun
#                              (rounded to whole quarters by add_quarters)
#   - otherwise             -> planned end
# NOTE(review): a fund that overran by more than its strategy average gets a cap
# earlier than its own last observed quarter, so it is never simulated even if
# it still holds NAV. Confirm this is intended.
cap_qe_map: Dict[str, pd.Timestamp] = {}
for _, row in fund_static.iterrows():
    fid = row["FundID"]
    strat = row["Adj Strategy"]
    planned = planned_end_map.get(fid, pd.NaT)

    if pd.isna(planned):
        cap_qe_map[fid] = future_qe[-1]  # fallback: run to horizon
        continue

    if bool(ever_overran_map.get(fid, False)):
        avg_over = float(avg_overrun_by_strategy.get(strat, 0.0))
        cap_qe_map[fid] = add_quarters(planned, avg_over)
    else:
        cap_qe_map[fid] = planned


# =============================
# Calibration: hazard + lognormal size
# =============================
# Calibration cells: strategy x grade x age bucket.
group_keys = ["Adj Strategy","Grade","AgeBucket"]

# Event indicators (1 if a positive flow occurred in the row's quarter).
data["draw_event"] = (data["Adj Drawdown EUR"] > 0).astype(int)
data["rep_event"] = (data["Adj Repayment EUR"] > 0).astype(int)

# Draw ratio: draw / capacity proxy (NaN when the proxy is <= CAP_EPS or the draw is 0)
data["draw_ratio"] = np.where(data["cap_proxy"] > CAP_EPS, data["Adj Drawdown EUR"] / data["cap_proxy"], np.nan)
data.loc[data["draw_ratio"] <= 0, "draw_ratio"] = np.nan

# Rep ratio: repay / |NAV_prev| (NaN on a fund's first row, where nav_prev is NaN,
# when |NAV_prev| <= NAV_EPS, or when nothing was repaid)
data["rep_ratio"] = np.where(data["nav_prev"].abs() > NAV_EPS, data["Adj Repayment EUR"] / data["nav_prev"].abs(), np.nan)
data.loc[data["rep_ratio"] <= 0, "rep_ratio"] = np.nan

# RC|Rep: event among repayment quarters, size ratio Recallable/Repayment
# (ratio defined only in repayment quarters; zero ratios set to NaN)
data["rc_given_rep_event"] = ((data["Adj Repayment EUR"] > 0) & (data["Recallable"] > 0)).astype(int)
data["rc_ratio_given_rep"] = np.where(data["Adj Repayment EUR"] > 0, data["Recallable"] / data["Adj Repayment EUR"], np.nan)
data.loc[data["rc_ratio_given_rep"] <= 0, "rc_ratio_given_rep"] = np.nan

def fit_lognormal(x: pd.Series, sigma_floor: float = None, sigma_cap: float = None) -> Tuple[float, float, int]:
    """Fit a lognormal to the positive, finite values of *x*.

    Returns (mu, sigma, n):
    - mu and sigma are the mean and sample std (ddof=1) of log(x).
    - n is the number of values used.
    - sigma is clamped to [sigma_floor, sigma_cap]; the defaults are SIGMA_FLOOR
      and SIGMA_CAP. With n == 1, sigma is the floor.
    - With n == 0 the result is (0.0, sigma_floor, 0); treat n == 0 as
      "no information".

    Non-finite values are ignored. A single inf ratio would otherwise make mu
    infinite and sigma NaN.
    """
    floor = SIGMA_FLOOR if sigma_floor is None else float(sigma_floor)
    cap = SIGMA_CAP if sigma_cap is None else float(sigma_cap)

    vals = np.asarray(x, dtype=float)
    vals = vals[np.isfinite(vals)]
    vals = vals[vals > 0]
    n = int(vals.size)
    if n == 0:
        return 0.0, floor, 0

    lx = np.log(vals)
    mu = float(lx.mean())
    sig = float(lx.std(ddof=1)) if n > 1 else floor
    sig = float(np.clip(sig, floor, cap))
    return mu, sig, n

# Parameters are estimated at three levels: global -> strategy -> cell
# (strategy x grade x age bucket). They are computed top-down so that each
# level can inherit lognormal size parameters from its parent when it has no
# usable ratios of its own. If a component has events but no usable ratios
# (n == 0), a fit would return mu = 0, i.e. a median ratio of 1.0 (100% of
# capacity / NAV).

# --- Global level (last-resort fallback; also read directly by lookup_params) ---
global_p_draw = float(data["draw_event"].mean())
global_p_rep  = float(data["rep_event"].mean())
rep_all = data[data["Adj Repayment EUR"] > 0]
global_p_rc_given_rep = float(rep_all["rc_given_rep_event"].mean()) if len(rep_all) else 0.0

g_mu_draw, g_sig_draw, _ = fit_lognormal(data["draw_ratio"])
g_mu_rep,  g_sig_rep,  _ = fit_lognormal(data["rep_ratio"])
g_mu_rc,   g_sig_rc,   _ = fit_lognormal(rep_all["rc_ratio_given_rep"]) if len(rep_all) else (0.0, SIGMA_FLOOR, 0)

_global_sizes = {
    "mu_draw": g_mu_draw, "sig_draw": g_sig_draw,
    "mu_rep": g_mu_rep, "sig_rep": g_sig_rep,
    "mu_rc": g_mu_rc, "sig_rc": g_sig_rc,
}
_SIZE_COLS = ["mu_draw", "sig_draw", "n_draw", "mu_rep", "sig_rep", "n_rep", "mu_rc", "sig_rc", "n_rc"]


def _size_params(grp: pd.DataFrame, parent: Dict[str, float]) -> Dict[str, float]:
    """Fit size params on *grp*; components with no usable ratios inherit *parent*'s."""
    rep_part = grp[grp["Adj Repayment EUR"] > 0]
    fits = {
        "draw": fit_lognormal(grp["draw_ratio"]),
        "rep": fit_lognormal(grp["rep_ratio"]),
        "rc": fit_lognormal(rep_part["rc_ratio_given_rep"]) if len(rep_part) else (0.0, SIGMA_FLOOR, 0),
    }
    sizes: Dict[str, float] = {}
    for comp, (mu, sig, n) in fits.items():
        if n == 0:
            mu, sig = parent[f"mu_{comp}"], parent[f"sig_{comp}"]
        sizes[f"mu_{comp}"] = mu
        sizes[f"sig_{comp}"] = sig
        sizes[f"n_{comp}"] = int(n)
    return sizes


# --- Strategy level ---
cal_s = data.groupby(["Adj Strategy"], dropna=False).agg(
    p_draw=("draw_event","mean"),
    p_rep=("rep_event","mean")
).reset_index()
rc_s = (data[data["Adj Repayment EUR"] > 0]
        .groupby(["Adj Strategy"], dropna=False)["rc_given_rep_event"].mean()
        .reset_index(name="p_rc_given_rep"))
cal_s = cal_s.merge(rc_s, on="Adj Strategy", how="left").fillna({"p_rc_given_rep": 0.0})

# Group by the scalar key, not a one-element list. Since pandas 2.0, iterating a
# groupby on a one-element list yields 1-tuple keys. Those never match the plain
# values in cal_s, so the merge would leave every strategy-level mu/sig NaN.
strategy_sizes: Dict[object, Dict[str, float]] = {}
mu_sig_s = []
for s, grp in data.groupby("Adj Strategy", dropna=False):
    sizes = _size_params(grp, _global_sizes)
    strategy_sizes[s] = sizes
    mu_sig_s.append({"Adj Strategy": s, **sizes})
mu_sig_s = pd.DataFrame(mu_sig_s, columns=["Adj Strategy", *_SIZE_COLS])
cal_s = cal_s.merge(mu_sig_s, on="Adj Strategy", how="left")

# --- Cell level ---
cal_rows = []
for (s, g, a), grp in data.groupby(group_keys, dropna=False):
    # With the categorical AgeBucket key, groupby (observed=False) also yields
    # empty cells for unobserved combinations. As p=0 rows they would shadow the
    # strategy fallback in lookup_params (e.g. no draws or repayments at all in
    # an age bucket never observed for that strategy/grade), so skip them.
    if grp.empty:
        continue
    rep_q = grp[grp["Adj Repayment EUR"] > 0]
    cal_rows.append({
        "Adj Strategy": s, "Grade": g, "AgeBucket": a,
        "p_draw": float(grp["draw_event"].mean()),
        "p_rep": float(grp["rep_event"].mean()),
        "p_rc_given_rep": float(rep_q["rc_given_rep_event"].mean()) if len(rep_q) else 0.0,
        **_size_params(grp, strategy_sizes.get(s, _global_sizes)),
        "n_obs": int(len(grp)),
    })

cal = pd.DataFrame(
    cal_rows,
    columns=["Adj Strategy", "Grade", "AgeBucket", "p_draw", "p_rep", "p_rc_given_rep", *_SIZE_COLS, "n_obs"],
)

def lookup_params(strategy, grade, age_bucket) -> Dict[str, float]:
    """Return hazard and size parameters for a (strategy, grade, age bucket) cell.

    Resolution order: the exact cell in ``cal``, then the strategy row in
    ``cal_s``, then the global estimates. The returned dict holds the source
    row's columns, with these keys overwritten:
    p_draw, p_rep, p_rc_given_rep (clipped to [0, 1]); mu_draw, mu_rep, mu_rc;
    sig_draw, sig_rep, sig_rc (clipped to [SIGMA_FLOOR, SIGMA_CAP]).

    A key that is present but NaN or non-numeric falls back to the global value.
    Otherwise a NaN would propagate into the simulated cashflow amounts.
    """
    m = (cal["Adj Strategy"].eq(strategy)) & (cal["Grade"].eq(grade)) & (cal["AgeBucket"].eq(age_bucket))
    if m.any():
        r = cal[m].iloc[0].to_dict()
    else:
        ss = cal_s[cal_s["Adj Strategy"].eq(strategy)]
        r = ss.iloc[0].to_dict() if len(ss) else {}

    def _value(key: str, default: float) -> float:
        # Missing key, non-numeric or NaN value => *default*.
        raw = r.get(key, default)
        try:
            val = float(raw)
        except (TypeError, ValueError):
            return float(default)
        return float(default) if np.isnan(val) else val

    def _prob(key: str, default: float) -> float:
        return float(np.clip(_value(key, default), 0.0, 1.0))

    def _sigma(key: str, default: float) -> float:
        return float(np.clip(_value(key, default), SIGMA_FLOOR, SIGMA_CAP))

    r["p_draw"] = _prob("p_draw", global_p_draw)
    r["p_rep"] = _prob("p_rep", global_p_rep)
    r["p_rc_given_rep"] = _prob("p_rc_given_rep", global_p_rc_given_rep)

    r["mu_draw"] = _value("mu_draw", g_mu_draw)
    r["sig_draw"] = _sigma("sig_draw", g_sig_draw)
    r["mu_rep"] = _value("mu_rep", g_mu_rep)
    r["sig_rep"] = _sigma("sig_rep", g_sig_rep)
    r["mu_rc"] = _value("mu_rc", g_mu_rc)
    r["sig_rc"] = _sigma("sig_rc", g_sig_rc)

    return r

# Optionally persist the cell-level and strategy-level calibration tables.
save_cal = input("Save calibration tables? (y/n) [n]: ").strip().lower() or "n"
if save_cal in ("y", "yes"):
    for _table, _fname in ((cal, "cal_structural_sga.csv"), (cal_s, "cal_structural_s.csv")):
        _table.to_csv(os.path.join(DATA_DIR, _fname), index=False)
    print("Saved calibration tables.")

# =============================
# Copula utilities
# =============================
def norm_cdf(x: float) -> float:
    """Standard normal CDF: Phi(x) = (1 + erf(x / sqrt(2))) / 2."""
    scaled = x / sqrt(2.0)
    return 0.5 * (1.0 + erf(scaled))

def one_factor_uniforms(n: int, rng: np.random.Generator, rho_mkt: float) -> np.ndarray:
    """Return *n* correlated U(0, 1) draws from a one-factor Gaussian copula.

    X_i = sqrt(rho) * Z + sqrt(1 - rho) * eps_i, with one common factor Z and
    idiosyncratic eps_i, so corr(X_i, X_j) = rho. Each X_i is mapped through
    norm_cdf. rho is clipped to [0, 0.999].
    RNG consumption: one scalar normal, then n normals.
    """
    rho = float(np.clip(rho_mkt, 0.0, 0.999))
    common = rng.standard_normal()
    idio = rng.standard_normal(n)
    latent = np.sqrt(rho) * common + np.sqrt(1.0 - rho) * idio
    return np.fromiter((norm_cdf(v) for v in latent), dtype=float, count=len(latent))

# Standard normal used for quantiles (statistics.NormalDist, stdlib since 3.8).
_STD_NORMAL = NormalDist()
# NormalDist.inv_cdf raises at exactly 0 and 1, so keep u strictly inside (0, 1).
_INV_NORM_U_EPS = 1e-12


def inv_norm(u: float) -> float:
    """Return the standard normal quantile Phi^{-1}(u).

    Uses statistics.NormalDist.inv_cdf (an accurate rational approximation in
    the stdlib), so results are identical whether or not scipy is installed.

    u is clipped to [1e-12, 1 - 1e-12]. An input of exactly 0 or 1 (possible
    when the copula CDF rounds) therefore maps to about -/+7.03 instead of
    -/+inf or an exception.
    """
    u = float(np.clip(u, _INV_NORM_U_EPS, 1.0 - _INV_NORM_U_EPS))
    return _STD_NORMAL.inv_cdf(u)

def lognormal_from_u(mu: float, sigma: float, u: float) -> float:
    """Map a uniform *u* to a lognormal(mu, sigma) draw: exp(mu + sigma * Phi^{-1}(u))."""
    log_value = mu + sigma * inv_norm(u)
    return float(np.exp(log_value))

# =============================
# Recallable ledger
# =============================
@dataclass
class RecallableBucket:
    """One tranche of recallable capital created in a single simulation step."""

    created_q: int            # step in which the tranche was created
    expiry_q: int             # last step in which it can still be drawn (inclusive)
    amount_remaining: float   # undrawn amount left in the tranche


@dataclass
class RecallableLedger:
    """FIFO ledger of recallable capital for one fund.

    Each addition is a bucket that can be drawn until ``created_q +
    expiry_quarters`` (inclusive). With ``enforce_cap=True`` the outstanding
    balance is capped at ``max(rho, 0) * max(commitment, 0)``. A non-positive
    ``expiry_quarters`` disables recallables entirely. Querying ``available``
    also prunes expired or empty buckets.
    """

    rho: float
    expiry_quarters: int
    commitment: float
    buckets: List[RecallableBucket] = field(default_factory=list)

    def _rc_cap(self) -> float:
        """Maximum outstanding recallable balance."""
        rho = max(float(self.rho), 0.0)
        commitment = max(float(self.commitment), 0.0)
        return rho * commitment

    def drop_expired(self, q: int) -> None:
        """Remove buckets that expired before step *q* or are fully used."""
        if int(self.expiry_quarters) > 0:
            self.buckets = [
                bucket for bucket in self.buckets
                if bucket.expiry_q >= q and bucket.amount_remaining > 0
            ]
        else:
            self.buckets = []

    def available(self, q: int) -> float:
        """Total recallable balance usable at step *q*."""
        self.drop_expired(q)
        total = sum(bucket.amount_remaining for bucket in self.buckets)
        return float(total)

    def add_recallable(self, q: int, rc_amount: float, enforce_cap: bool = True) -> float:
        """Add up to *rc_amount* as a new bucket at step *q*; return the amount actually added."""
        self.drop_expired(q)
        requested = max(float(rc_amount or 0.0), 0.0)
        if requested <= 0.0 or int(self.expiry_quarters) <= 0:
            return 0.0

        if enforce_cap:
            headroom = max(self._rc_cap() - self.available(q), 0.0)
            requested = min(requested, headroom)
        if requested <= 0.0:
            return 0.0

        new_bucket = RecallableBucket(
            created_q=q,
            expiry_q=q + int(self.expiry_quarters),
            amount_remaining=float(requested),
        )
        self.buckets.append(new_bucket)
        return float(requested)

    def consume_for_drawdown(self, q: int, draw_amount: float) -> Dict[str, float]:
        """Fund a drawdown at step *q*: recallables first (oldest bucket first), then commitment.

        Returns {"use_rc": amount taken from recallables,
                 "use_commitment": remainder charged to commitment}.
        """
        self.drop_expired(q)
        outstanding = max(float(draw_amount or 0.0), 0.0)
        if outstanding <= 0.0:
            return {"use_rc": 0.0, "use_commitment": 0.0}

        self.buckets.sort(key=lambda bucket: bucket.created_q)  # FIFO

        from_rc = 0.0
        for bucket in self.buckets:
            if outstanding <= 0:
                break
            take = min(bucket.amount_remaining, outstanding)
            bucket.amount_remaining -= take
            outstanding -= take
            from_rc += take

        self.buckets = [bucket for bucket in self.buckets if bucket.amount_remaining > 0]
        from_commitment = max(float(draw_amount) - from_rc, 0.0)
        return {"use_rc": float(from_rc), "use_commitment": float(from_commitment)}

# =============================
# Simulation (single path) + annual grade transitions
# =============================
rng = np.random.default_rng(SEED)

funds = fund_static["FundID"].tolist()
n_funds = len(funds)

ledgers: Dict[str, RecallableLedger] = {}
state: Dict[str, Dict[str, object]] = {}

# Output schema in record order; also gives an empty result a proper header.
SIM_COLUMNS = [
    "FundID", "quarter_end", "step_q", "Adj Strategy", "Grade",
    "Age_Quarters", "AgeBucket", "Commitment_Level",
    "NAV_prev", "NAV_current", "Stopped_NAV_Zero",
    "RC_Avail_Pre", "Remaining_Commit_Pre", "Capacity_Pre",
    "Draw_Event", "Draw_Amount", "Use_Recallable", "Use_Commitment", "DD_Cum_Commitment",
    "Rep_Event", "Rep_Amount",
    "RC_Event", "RC_Amount_Raw", "RC_Added",
    "RC_Avail_Post", "Remaining_Commit_Post", "Capacity_Post",
    "rho_used", "E_used", "cap_qe",
]


def _age_bucket_label(age_q):
    """Return the AGE_LABELS entry for *age_q* quarters (NaN outside the bins).

    Uses the same right-closed (lo, hi] binning as
    pd.cut(bins=AGE_BINS_Q, labels=AGE_LABELS), without building a Series and
    a Categorical for every fund and quarter.
    """
    for lo, hi, label in zip(AGE_BINS_Q[:-1], AGE_BINS_Q[1:], AGE_LABELS):
        if lo < age_q <= hi:
            return label
    return np.nan


# Commitment already drawn before the first simulated quarter. Starting every
# fund at 0 would give mature funds their full commitment as fresh capacity.
# Preferred source: the latest non-missing historical Capacity, giving
# drawn = C - Capacity. Fallback: cumulative drawdowns net of cumulative
# recallables. Historical recallable headroom is folded into this net figure,
# so each recallable ledger starts empty.
_hist = data.sort_values(["FundID", "quarter_end"])
_hist_by_fund = _hist.groupby("FundID")
_hist_net_drawn = _hist_by_fund["Adj Drawdown EUR"].sum() - _hist_by_fund["Recallable"].sum()
if CAP_PROXY_COL is not None:
    _hist_cap_last = (
        pd.to_numeric(_hist[CAP_PROXY_COL], errors="coerce")
          .groupby(_hist["FundID"])
          .last()
    )
else:
    _hist_cap_last = pd.Series(dtype=float)
del _hist, _hist_by_fund


def _initial_drawn_commitment(fid, commitment: float) -> float:
    """Commitment already drawn by *fid* at the start, clipped to [0, commitment]."""
    cap_last = _hist_cap_last.get(fid, np.nan)
    if pd.notna(cap_last):
        drawn = commitment - float(cap_last)
    else:
        drawn = float(_hist_net_drawn.get(fid, 0.0))
    return float(np.clip(drawn, 0.0, commitment))


_param_cache: Dict[tuple, Dict[str, float]] = {}


def _params_for(strategy, grade, age_bucket) -> Dict[str, float]:
    """Memoised lookup_params; the calibration tables do not change during the simulation."""
    key = (strategy, grade, age_bucket)
    params = _param_cache.get(key)
    if params is None:
        params = _param_cache[key] = lookup_params(strategy, grade, age_bucket)
    return params


# init fund states
for _, row in fund_static.iterrows():
    fid = row["FundID"]
    strat = row["Adj Strategy"]
    grade = row["Grade"] if row["Grade"] in GRADE_STATES else "D"

    C = max(float(row.get("Commitment_Level", 0.0) or 0.0), 0.0)
    rho = float(row.get("rho_used", 0.0) or 0.0)
    E = int(row.get("E_used", 0) or 0)

    ledgers[fid] = RecallableLedger(rho=rho, expiry_quarters=E, commitment=C)

    age0 = row.get("Fund_Age_Quarters", 0.0)
    nav0 = row.get("NAV Adjusted EUR", 0.0)

    state[fid] = {
        "dd_cum_commit": _initial_drawn_commitment(fid, C),
        "age_q": int(0 if pd.isna(age0) else age0),
        "nav_prev": float(0.0 if pd.isna(nav0) else nav0),
        "strategy": strat,
        "grade": grade,
        "commitment": C,
        "alive": 1
    }

out = []

for step, qe in enumerate(future_qe, start=1):
    # Separate copula draws per component (one common factor per component and quarter)
    U_draw_event = one_factor_uniforms(n_funds, rng, RHO_MKT)
    U_draw_size  = one_factor_uniforms(n_funds, rng, RHO_MKT)
    U_rep_event  = one_factor_uniforms(n_funds, rng, RHO_MKT)
    U_rep_size   = one_factor_uniforms(n_funds, rng, RHO_MKT)
    U_rc_event   = one_factor_uniforms(n_funds, rng, RHO_MKT)
    U_rc_size    = one_factor_uniforms(n_funds, rng, RHO_MKT)

    for i, fid in enumerate(funds):
        st = state[fid]
        ledger = ledgers[fid]

        strategy = st["strategy"]
        cap_qe = cap_qe_map.get(fid, future_qe[-1])

        # Past the fund's horizon cap, or already stopped on zero NAV: no cashflows.
        if qe > cap_qe or int(st["alive"]) == 0:
            continue

        # NAV from the projection; carry the previous NAV forward when missing.
        nav_current = max(float(nav_lookup.get((fid, qe), st["nav_prev"])), 0.0)

        # NAV stop: record one all-zero cashflow row, then retire the fund.
        if nav_current <= NAV_STOP_EPS:
            st["alive"] = 0
            rc_avail = ledger.available(step)
            remaining_commit = max(st["commitment"] - st["dd_cum_commit"], 0.0)
            capacity = remaining_commit + rc_avail
            out.append({
                "FundID": fid,
                "quarter_end": qe,
                "step_q": step,
                "Adj Strategy": strategy,
                "Grade": st["grade"],
                "Age_Quarters": st["age_q"],
                "AgeBucket": _age_bucket_label(st["age_q"]),
                "Commitment_Level": st["commitment"],
                "NAV_prev": float(st["nav_prev"]),
                "NAV_current": float(nav_current),
                "Stopped_NAV_Zero": 1,

                "RC_Avail_Pre": rc_avail,
                "Remaining_Commit_Pre": remaining_commit,
                "Capacity_Pre": capacity,

                "Draw_Event": 0, "Draw_Amount": 0.0, "Use_Recallable": 0.0, "Use_Commitment": 0.0,
                "DD_Cum_Commitment": float(st["dd_cum_commit"]),

                "Rep_Event": 0, "Rep_Amount": 0.0,
                "RC_Event": 0, "RC_Amount_Raw": 0.0, "RC_Added": 0.0,

                "RC_Avail_Post": rc_avail,
                "Remaining_Commit_Post": remaining_commit,
                "Capacity_Post": capacity,

                "rho_used": float(ledger.rho),
                "E_used": int(ledger.expiry_quarters),
                "cap_qe": cap_qe
            })
            st["nav_prev"] = nav_current
            continue

        # Age advances
        st["age_q"] += 1
        age_q = int(st["age_q"])
        age_bucket = _age_bucket_label(age_q)

        # Annual grade transition (every 4 quarters)
        if step % 4 == 0:
            P = get_grade_matrix(strategy)
            st["grade"] = sample_next_grade(st["grade"], P, rng)

        grade = st["grade"]

        params = _params_for(strategy, grade, age_bucket)

        rc_avail_pre = ledger.available(step)
        remaining_commit_pre = max(st["commitment"] - st["dd_cum_commit"], 0.0)
        capacity_pre = remaining_commit_pre + rc_avail_pre

        # Drawdown: hazard + lognormal ratio * capacity (recallables consumed first)
        draw_event = (U_draw_event[i] < params["p_draw"]) and (capacity_pre > 0.0)
        draw_amt = 0.0
        use_rc = 0.0
        use_commit = 0.0

        if draw_event:
            ratio = lognormal_from_u(params["mu_draw"], params["sig_draw"], float(U_draw_size[i]))
            ratio = float(np.clip(ratio, 0.0, 1.0))
            draw_amt = ratio * capacity_pre

            cons = ledger.consume_for_drawdown(step, draw_amt)
            use_rc = cons["use_rc"]
            use_commit = cons["use_commitment"]
            st["dd_cum_commit"] += use_commit

        # Repayment: hazard + lognormal ratio * NAV_prev (ratio capped at 2)
        nav_prev = float(st["nav_prev"])
        rep_event = (U_rep_event[i] < params["p_rep"]) and (nav_prev > NAV_EPS)

        rep_amt = 0.0
        if rep_event:
            rep_ratio = lognormal_from_u(params["mu_rep"], params["sig_rep"], float(U_rep_size[i]))
            rep_ratio = float(np.clip(rep_ratio, 0.0, 2.0))
            rep_amt = rep_ratio * nav_prev

        # Recallable conditional on repayment; the ledger enforces the rho*C cap
        rc_event = (rep_amt > 0.0) and (U_rc_event[i] < params["p_rc_given_rep"])
        rc_amt_raw = 0.0
        rc_added = 0.0
        if rc_event:
            rc_ratio = lognormal_from_u(params["mu_rc"], params["sig_rc"], float(U_rc_size[i]))
            rc_ratio = float(np.clip(rc_ratio, 0.0, 1.0))
            rc_amt_raw = rc_ratio * rep_amt
            rc_added = ledger.add_recallable(step, rc_amt_raw, enforce_cap=True)

        rc_avail_post = ledger.available(step)
        remaining_commit_post = max(st["commitment"] - st["dd_cum_commit"], 0.0)
        capacity_post = remaining_commit_post + rc_avail_post

        out.append({
            "FundID": fid,
            "quarter_end": qe,
            "step_q": step,
            "Adj Strategy": strategy,
            "Grade": grade,
            "Age_Quarters": age_q,
            "AgeBucket": age_bucket,
            "Commitment_Level": st["commitment"],

            "NAV_prev": nav_prev,
            "NAV_current": nav_current,
            "Stopped_NAV_Zero": 0,

            "RC_Avail_Pre": rc_avail_pre,
            "Remaining_Commit_Pre": remaining_commit_pre,
            "Capacity_Pre": capacity_pre,

            "Draw_Event": int(draw_event),
            "Draw_Amount": float(draw_amt),
            "Use_Recallable": float(use_rc),
            "Use_Commitment": float(use_commit),
            "DD_Cum_Commitment": float(st["dd_cum_commit"]),

            "Rep_Event": int(rep_event),
            "Rep_Amount": float(rep_amt),

            "RC_Event": int(rc_event),
            "RC_Amount_Raw": float(rc_amt_raw),
            "RC_Added": float(rc_added),

            "RC_Avail_Post": float(rc_avail_post),
            "Remaining_Commit_Post": float(remaining_commit_post),
            "Capacity_Post": float(capacity_post),

            "rho_used": float(ledger.rho),
            "E_used": int(ledger.expiry_quarters),
            "cap_qe": cap_qe
        })

        st["nav_prev"] = nav_current

sim = pd.DataFrame(out, columns=SIM_COLUMNS)

print("\nSimulated rows:", len(sim))
print("Simulated funds:", sim["FundID"].nunique())

# =============================
# Save outputs
# =============================
save = input("Save structural cashflow outputs? (y/n) [y]: ").strip().lower()
if save in ("", "y", "yes"):  # an empty answer means the default "y"
    out_csv = os.path.join(DATA_DIR, f"structural_cashflows_{year}_{quarter}_{n_q}q.csv")
    out_pq  = os.path.join(DATA_DIR, f"structural_cashflows_{year}_{quarter}_{n_q}q.parquet")
    # CSV uses ';' separators and a UTF-8 BOM (utf-8-sig).
    sim.to_csv(out_csv, index=False, sep=";", encoding="utf-8-sig")
    sim.to_parquet(out_pq, index=False)
    print("Saved:", out_csv, out_pq, sep="\n")

# =============================
# Sanity checks
# =============================
print("\nSanity checks:")
if sim.empty:
    # With no simulated rows, sim may have no columns at all (pd.DataFrame([])),
    # so the column lookups below would raise KeyError.
    print("No simulated rows; sanity checks skipped.")
else:
    print("Total draw amount:", float(sim["Draw_Amount"].sum()))
    print("Total repayment amount:", float(sim["Rep_Amount"].sum()))
    print("Total recallable added:", float(sim["RC_Added"].sum()))
    print("NAV stop count (rows):", int(sim["Stopped_NAV_Zero"].sum()))

    # Cap compliance at end per fund (RC <= rho*C), using each fund's last simulated row
    end = sim.sort_values(["FundID","quarter_end"]).groupby("FundID").tail(1).copy()
    end["rc_cap"] = end["rho_used"] * end["Commitment_Level"]
    viol = end[end["RC_Avail_Post"] - end["rc_cap"] > 1e-6]
    print("Cap violations at end (should be 0):", len(viol))
    if len(viol):
        print(viol[["FundID","RC_Avail_Post","rc_cap","rho_used","Commitment_Level"]].head(20))

print("\nRuntime (seconds):", round(time.perf_counter() - t0, 2))
