In [1]:
"""
STRUCTURAL CASHFLOWS â€” SINGLE PATH (NO MONTE CARLO), ENDOGENOUS NAV

Inputs (from your pipeline):
- data.parquet (historical panel, for calibrating hazards and size ratios)
- kmp.parquet/csv (recallable limits; if missing for a fund -> soft limits from data)
- omega_projection_sota_{year}_{quarter}_{n_q}q.parquet/csv (from NAV Logic)
    includes: FundID, quarter_end, omega, Adj Strategy, Grade, Fund_Age_Quarters (or AgeBucket)
- nav_start_sota_{year}_{quarter}.parquet/csv (from NAV Logic)
    includes: FundID, NAV_start, NAV_start_source, cap_qe, Adj Strategy, Grade, Fund_Age_Quarters

Model:
- Two-part hazard + lognormal size for:
    Drawdowns: size ratio to Capacity_Pre
    Repayments: size ratio to NAV_prev
    Recallables: conditional on Rep_Regular, ratio to Rep_Regular

- Recallables increase capacity:
    Capacity_Pre = Remaining_Commitment + RC_Avail_Pre
  Drawdowns consume recallables FIFO; only remainder reduces Remaining_Commitment.

- NAV is endogenous:
    NAV_after = max((NAV_prev + Draw_Amount - Rep_Total) * (1 + omega_t), 0)

- Terminal condition:
    At fund cap_qe (last quarter in omega path), enforce NAV=0 by:
      Rep_Terminal = NAV_after
      NAV_end = 0
    Rep_Terminal is separated (does not affect calibration distributions, does not generate recallables).

Copula:
- Single-factor Gaussian copula across funds within each quarter.
- Separate copula draws per component: draw_event, draw_size, rep_event, rep_size, rc_event, rc_size.

Outputs:
- structural_cashflows_endogenous_{year}_{quarter}_{n_q}q.(csv|parquet)
"""

import os
import glob
import time
import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from typing import List, Dict, Tuple
from math import erf, sqrt

t0 = time.perf_counter()

# =============================
# User inputs / paths
# =============================
year = int(input("Enter year (e.g. 2025): ").strip())
quarter = input("Enter quarter (Q1, Q2, Q3, Q4): ").strip().upper()

n_q = int(input("Enter number of quarters to simulate (0 => default 40): ").strip() or "0")
if n_q == 0:
    n_q = 40

RHO_MKT = float(input("Single-factor copula correlation rho_mkt [0.25]: ").strip() or "0.25")
RHO_MKT = float(np.clip(RHO_MKT, 0.0, 0.999))

SEED = int(input("Random seed [1234]: ").strip() or "1234")

BASE_DIR = os.path.join("C:\\Users", os.environ.get("USERNAME"), "Documents", "Equity")
HOME = os.path.join(BASE_DIR, f"{year}_{quarter}")
DATA_DIR = os.path.join(HOME, "data")

data_path = os.path.join(DATA_DIR, "data.parquet")
kmp_path_parquet = os.path.join(DATA_DIR, "kmp.parquet")
kmp_path_csv = os.path.join(DATA_DIR, "kmp.csv")

if not os.path.exists(data_path):
    raise FileNotFoundError(f"Missing {data_path}")
if not (os.path.exists(kmp_path_parquet) or os.path.exists(kmp_path_csv)):
    raise FileNotFoundError("Missing kmp.parquet or kmp.csv in DATA_DIR")

print("HOME:", HOME)
print("DATA_DIR:", DATA_DIR)

# =============================
# Config
# =============================
AGE_BINS_Q = [-1, 7, 15, 23, 31, 39, 59, 79, 10_000]
AGE_LABELS = ["0-2y","2-4y","4-6y","6-8y","8-10y","10-15y","15-20y","20y+"]

NAV_EPS = 100.0         # repayment sizing requires NAV_prev > NAV_EPS
NAV_STOP_EPS = 1.0      # if NAV <= this, stop fund
CAP_EPS = 1.0

SIGMA_FLOOR = 0.35
SIGMA_CAP = 2.0

SOFT_RHO_PCTL = 0.95
SOFT_EXPIRY_FALLBACK = 20

GRADE_STATES = ["A","B","C","D"]

# =============================
# Helpers
# =============================
def make_age_bucket_q(age_q: float):
    return pd.cut(pd.Series([age_q]), bins=AGE_BINS_Q, labels=AGE_LABELS).iloc[0]

def norm_cdf(x: float) -> float:
    return 0.5 * (1.0 + erf(x / sqrt(2.0)))

def one_factor_uniforms(n: int, rng: np.random.Generator, rho_mkt: float) -> np.ndarray:
    rho_mkt = float(np.clip(rho_mkt, 0.0, 0.999))
    Z = rng.standard_normal()
    eps = rng.standard_normal(n)
    X = np.sqrt(rho_mkt) * Z + np.sqrt(1.0 - rho_mkt) * eps
    return np.array([norm_cdf(x) for x in X], dtype=float)

def inv_norm(u: float) -> float:
    # Prefer scipy if available
    try:
        from scipy.special import erfinv
        return sqrt(2.0) * float(erfinv(2.0*u - 1.0))
    except Exception:
        # fallback approximation (deterministic-ish)
        u = float(np.clip(u, 1e-6, 1.0 - 1e-6))
        return float(np.sign(u - 0.5) * np.sqrt(2.0) * np.sqrt(abs(np.log(1.0 - 2.0*abs(u-0.5)))))

def lognormal_from_u(mu: float, sigma: float, u: float) -> float:
    z = inv_norm(u)
    return float(np.exp(mu + sigma * z))

# =============================
# Recallable ledger
# =============================
@dataclass
class RecallableBucket:
    created_q: int
    expiry_q: int
    amount_remaining: float

@dataclass
class RecallableLedger:
    rho: float
    expiry_quarters: int
    commitment: float
    buckets: List[RecallableBucket] = field(default_factory=list)

    def _rc_cap(self) -> float:
        return max(float(self.rho), 0.0) * max(float(self.commitment), 0.0)

    def drop_expired(self, q: int) -> None:
        if int(self.expiry_quarters) <= 0:
            self.buckets = []
            return
        self.buckets = [b for b in self.buckets if b.expiry_q >= q and b.amount_remaining > 0]

    def available(self, q: int) -> float:
        self.drop_expired(q)
        return float(sum(b.amount_remaining for b in self.buckets))

    def add_recallable(self, q: int, rc_amount: float, enforce_cap: bool = True) -> float:
        self.drop_expired(q)
        x = max(float(rc_amount or 0.0), 0.0)
        if x <= 0.0 or int(self.expiry_quarters) <= 0:
            return 0.0

        add_amt = x
        if enforce_cap:
            cap = self._rc_cap()
            cur = self.available(q)
            room = max(cap - cur, 0.0)
            add_amt = min(add_amt, room)

        if add_amt <= 0.0:
            return 0.0

        self.buckets.append(RecallableBucket(
            created_q=q,
            expiry_q=q + int(self.expiry_quarters),
            amount_remaining=float(add_amt)
        ))
        return float(add_amt)

    def consume_for_drawdown(self, q: int, draw_amount: float) -> Dict[str, float]:
        self.drop_expired(q)
        need = max(float(draw_amount or 0.0), 0.0)
        if need <= 0.0:
            return {"use_rc": 0.0, "use_commitment": 0.0}

        self.buckets.sort(key=lambda b: b.created_q)  # FIFO

        use_rc = 0.0
        for b in self.buckets:
            if need <= 0:
                break
            take = min(b.amount_remaining, need)
            b.amount_remaining -= take
            need -= take
            use_rc += take

        self.buckets = [b for b in self.buckets if b.amount_remaining > 0]
        use_commitment = max(float(draw_amount) - use_rc, 0.0)
        return {"use_rc": float(use_rc), "use_commitment": float(use_commitment)}

# =============================
# Load historical data for calibration
# =============================
data = pd.read_parquet(data_path).copy()

# Required columns for calibration
req = [
    "FundID","Adj Strategy","Grade","Fund_Age_Quarters",
    "Year of Transaction Date","Quarter of Transaction Date",
    "Commitment EUR","Adj Drawdown EUR","Adj Repayment EUR","NAV Adjusted EUR","Recallable"
]
missing = [c for c in req if c not in data.columns]
if missing:
    raise ValueError(f"Missing in data.parquet: {missing}")

for c in ["Commitment EUR","Adj Drawdown EUR","Adj Repayment EUR","NAV Adjusted EUR","Recallable","Fund_Age_Quarters"]:
    data[c] = pd.to_numeric(data[c], errors="coerce")

data["Adj Drawdown EUR"] = data["Adj Drawdown EUR"].fillna(0.0).clip(lower=0.0)
data["Adj Repayment EUR"] = data["Adj Repayment EUR"].fillna(0.0).clip(lower=0.0)
data["Recallable"] = data["Recallable"].fillna(0.0).clip(lower=0.0)
data["NAV Adjusted EUR"] = data["NAV Adjusted EUR"].fillna(0.0).clip(lower=0.0)
data["Fund_Age_Quarters"] = data["Fund_Age_Quarters"].fillna(0.0)

# quarter_end + ordering
data["quarter_end"] = pd.PeriodIndex(
    data["Year of Transaction Date"].astype(int).astype(str) + "Q" +
    data["Quarter of Transaction Date"].astype(int).astype(str),
    freq="Q"
).to_timestamp("Q")
data["quarter_end"] = pd.to_datetime(data["quarter_end"])

data = data.sort_values(["FundID","quarter_end"]).reset_index(drop=True)

# commitment level proxy (flow->cum)
data["Commitment_Level"] = data.groupby("FundID")["Commitment EUR"].cumsum().fillna(0.0)

# NAV lag for repayment ratio calibration
data["nav_prev"] = data.groupby("FundID")["NAV Adjusted EUR"].shift(1)

# age bucket
data["AgeBucket"] = pd.cut(data["Fund_Age_Quarters"], bins=AGE_BINS_Q, labels=AGE_LABELS)

# capacity proxy for draw ratio calibration
CAP_PROXY_COL = "Capacity" if "Capacity" in data.columns else None
if CAP_PROXY_COL is None:
    print("WARNING: data.parquet has no 'Capacity' column; draw calibration uses Commitment_Level proxy.")
data["cap_proxy"] = pd.to_numeric(data[CAP_PROXY_COL], errors="coerce").fillna(0.0) if CAP_PROXY_COL else data["Commitment_Level"].fillna(0.0)

# event flags + ratios
data["draw_event"] = (data["Adj Drawdown EUR"] > 0).astype(int)
data["rep_event"]  = (data["Adj Repayment EUR"] > 0).astype(int)

data["draw_ratio"] = np.where(data["cap_proxy"] > CAP_EPS, data["Adj Drawdown EUR"] / data["cap_proxy"], np.nan)
data.loc[data["draw_ratio"] <= 0, "draw_ratio"] = np.nan

data["rep_ratio"] = np.where(data["nav_prev"].abs() > NAV_EPS, data["Adj Repayment EUR"] / data["nav_prev"].abs(), np.nan)
data.loc[data["rep_ratio"] <= 0, "rep_ratio"] = np.nan

data["rc_given_rep_event"] = ((data["Adj Repayment EUR"] > 0) & (data["Recallable"] > 0)).astype(int)
data["rc_ratio_given_rep"] = np.where(data["Adj Repayment EUR"] > 0, data["Recallable"] / data["Adj Repayment EUR"], np.nan)
data.loc[data["rc_ratio_given_rep"] <= 0, "rc_ratio_given_rep"] = np.nan

def fit_lognormal(x: pd.Series) -> Tuple[float, float]:
    g = x.dropna()
    g = g[g > 0]
    if len(g) == 0:
        return 0.0, SIGMA_FLOOR
    lx = np.log(g.to_numpy(dtype=float))
    mu = float(np.mean(lx))
    sig = float(np.std(lx, ddof=1)) if len(g) > 1 else SIGMA_FLOOR
    sig = float(np.clip(max(sig, SIGMA_FLOOR), SIGMA_FLOOR, SIGMA_CAP))
    return mu, sig

# build calibration tables
group_keys = ["Adj Strategy","Grade","AgeBucket"]

rows = []
for (s,g,a), grp in data.groupby(group_keys, dropna=False):
    p_draw = float(grp["draw_event"].mean()) if len(grp) else 0.0
    p_rep  = float(grp["rep_event"].mean()) if len(grp) else 0.0

    rep_q = grp[grp["Adj Repayment EUR"] > 0]
    p_rc_given_rep = float(rep_q["rc_given_rep_event"].mean()) if len(rep_q) else 0.0

    mu_d, sig_d = fit_lognormal(grp["draw_ratio"])
    mu_r, sig_r = fit_lognormal(grp["rep_ratio"])
    mu_c, sig_c = fit_lognormal(rep_q["rc_ratio_given_rep"]) if len(rep_q) else (0.0, SIGMA_FLOOR)

    rows.append({
        "Adj Strategy": s, "Grade": g, "AgeBucket": a,
        "p_draw": p_draw, "p_rep": p_rep, "p_rc_given_rep": p_rc_given_rep,
        "mu_draw": mu_d, "sig_draw": sig_d,
        "mu_rep": mu_r, "sig_rep": sig_r,
        "mu_rc": mu_c, "sig_rc": sig_c,
        "n_obs": int(len(grp))
    })
cal = pd.DataFrame(rows)

# strategy fallback
cal_s = data.groupby(["Adj Strategy"], dropna=False).agg(
    p_draw=("draw_event","mean"),
    p_rep=("rep_event","mean")
).reset_index()
rc_s = (data[data["Adj Repayment EUR"] > 0]
        .groupby(["Adj Strategy"], dropna=False)["rc_given_rep_event"].mean()
        .reset_index(name="p_rc_given_rep"))
cal_s = cal_s.merge(rc_s, on="Adj Strategy", how="left").fillna({"p_rc_given_rep": 0.0})

mu_sig_s = []
for s, grp in data.groupby(["Adj Strategy"], dropna=False):
    mu_d, sig_d = fit_lognormal(grp["draw_ratio"])
    mu_r, sig_r = fit_lognormal(grp["rep_ratio"])
    rep_g = grp[grp["Adj Repayment EUR"] > 0]
    mu_c, sig_c = fit_lognormal(rep_g["rc_ratio_given_rep"]) if len(rep_g) else (0.0, SIGMA_FLOOR)
    mu_sig_s.append({"Adj Strategy": s,
                     "mu_draw": mu_d, "sig_draw": sig_d,
                     "mu_rep": mu_r, "sig_rep": sig_r,
                     "mu_rc": mu_c, "sig_rc": sig_c})
mu_sig_s = pd.DataFrame(mu_sig_s)
cal_s = cal_s.merge(mu_sig_s, on="Adj Strategy", how="left")

# global fallback
global_p_draw = float(data["draw_event"].mean())
global_p_rep  = float(data["rep_event"].mean())
rep_all = data[data["Adj Repayment EUR"] > 0]
global_p_rc_given_rep = float(rep_all["rc_given_rep_event"].mean()) if len(rep_all) else 0.0
g_mu_draw, g_sig_draw = fit_lognormal(data["draw_ratio"])
g_mu_rep,  g_sig_rep  = fit_lognormal(data["rep_ratio"])
g_mu_rc,   g_sig_rc   = fit_lognormal(rep_all["rc_ratio_given_rep"]) if len(rep_all) else (0.0, SIGMA_FLOOR)

def lookup_params(strategy, grade, age_bucket) -> Dict[str, float]:
    m = (cal["Adj Strategy"].eq(strategy)) & (cal["Grade"].eq(grade)) & (cal["AgeBucket"].eq(age_bucket))
    if m.any():
        r = cal[m].iloc[0].to_dict()
    else:
        r = {}

    if not r:
        ss = cal_s[cal_s["Adj Strategy"].eq(strategy)]
        if len(ss):
            r = ss.iloc[0].to_dict()

    r["p_draw"] = float(np.clip(r.get("p_draw", global_p_draw), 0.0, 1.0))
    r["p_rep"]  = float(np.clip(r.get("p_rep", global_p_rep), 0.0, 1.0))
    r["p_rc_given_rep"] = float(np.clip(r.get("p_rc_given_rep", global_p_rc_given_rep), 0.0, 1.0))

    r["mu_draw"]  = float(r.get("mu_draw", g_mu_draw))
    r["sig_draw"] = float(np.clip(max(r.get("sig_draw", g_sig_draw), SIGMA_FLOOR), SIGMA_FLOOR, SIGMA_CAP))
    r["mu_rep"]   = float(r.get("mu_rep", g_mu_rep))
    r["sig_rep"]  = float(np.clip(max(r.get("sig_rep", g_sig_rep), SIGMA_FLOOR), SIGMA_FLOOR, SIGMA_CAP))
    r["mu_rc"]    = float(r.get("mu_rc", g_mu_rc))
    r["sig_rc"]   = float(np.clip(max(r.get("sig_rc", g_sig_rc), SIGMA_FLOOR), SIGMA_FLOOR, SIGMA_CAP))
    return r

# =============================
# Load KMP and build soft limits (same as before)
# =============================
kmp = pd.read_parquet(kmp_path_parquet).copy() if os.path.exists(kmp_path_parquet) else pd.read_csv(kmp_path_csv).copy()
kmp_needed = ["FundID","Recallable_Percentage_Decimal","Expiration_Quarters"]
missing_k = [c for c in kmp_needed if c not in kmp.columns]
if missing_k:
    raise ValueError(f"Missing columns in kmp: {missing_k}")

kmp2 = kmp[kmp_needed].copy()
kmp2["Recallable_Percentage_Decimal"] = pd.to_numeric(kmp2["Recallable_Percentage_Decimal"], errors="coerce")
kmp2["Expiration_Quarters"] = pd.to_numeric(kmp2["Expiration_Quarters"], errors="coerce")

# empirical rho per fund
tmp_rho = (
    data.groupby("FundID", as_index=False)
        .agg(sum_rc=("Recallable","sum"), C_last=("Commitment_Level","max"))
)
tmp_rho["rho_emp"] = np.where(tmp_rho["C_last"] > 0, tmp_rho["sum_rc"] / tmp_rho["C_last"], np.nan)
tmp_rho = tmp_rho.merge(
    data.sort_values(["FundID","quarter_end"]).groupby("FundID").tail(1)[["FundID","Adj Strategy"]],
    on="FundID", how="left"
)

rho_soft_by_strategy = (
    tmp_rho.dropna(subset=["rho_emp","Adj Strategy"])
           .groupby("Adj Strategy")["rho_emp"]
           .quantile(SOFT_RHO_PCTL)
)

# expiry soft from KMP funds by strategy
fund_kmp_merge = (
    data.sort_values(["FundID","quarter_end"]).groupby("FundID").tail(1)[["FundID","Adj Strategy"]]
    .merge(kmp2, on="FundID", how="left")
)
exp_soft_by_strategy = (
    fund_kmp_merge[fund_kmp_merge["Expiration_Quarters"].notna()]
    .groupby("Adj Strategy")["Expiration_Quarters"].median()
)

def soft_params(strategy: str) -> Tuple[float,int]:
    rho = float(rho_soft_by_strategy.get(strategy, 0.0))
    E = exp_soft_by_strategy.get(strategy, np.nan)
    E = int(E) if pd.notna(E) else SOFT_EXPIRY_FALLBACK
    rho = float(np.clip(rho, 0.0, 1.0))
    E = max(int(E), 0)
    return rho, E

# =============================
# Load omega projection + NAV_start (NEW)
# =============================
def find_omega_file(data_dir: str) -> str:
    cands = glob.glob(os.path.join(data_dir, "omega_projection_sota_*.parquet")) + \
            glob.glob(os.path.join(data_dir, "omega_projection_sota_*.csv"))
    if not cands:
        raise FileNotFoundError("No omega_projection_sota_* file found. Run NAV Logic first.")
    cands.sort(key=os.path.getmtime, reverse=True)
    return cands[0]

def find_navstart_file(data_dir: str, year: int, quarter: str) -> str:
    cands = [
        os.path.join(data_dir, f"nav_start_sota_{year}_{quarter}.parquet"),
        os.path.join(data_dir, f"nav_start_sota_{year}_{quarter}.csv"),
    ]
    for p in cands:
        if os.path.exists(p):
            return p
    # fallback: latest nav_start_sota_*
    c2 = glob.glob(os.path.join(data_dir, "nav_start_sota_*.parquet")) + \
         glob.glob(os.path.join(data_dir, "nav_start_sota_*.csv"))
    if not c2:
        raise FileNotFoundError("No nav_start_sota_* file found. Run NAV Logic first.")
    c2.sort(key=os.path.getmtime, reverse=True)
    return c2[0]

omega_path = find_omega_file(DATA_DIR)
navstart_path = find_navstart_file(DATA_DIR, year, quarter)

print("Using omega file:", omega_path)
print("Using nav_start file:", navstart_path)

omega_df = pd.read_parquet(omega_path) if omega_path.lower().endswith(".parquet") else pd.read_csv(omega_path)
navstart = pd.read_parquet(navstart_path) if navstart_path.lower().endswith(".parquet") else pd.read_csv(navstart_path)

need_omega = {"FundID","quarter_end","omega","Adj Strategy","Grade"}
if not need_omega.issubset(omega_df.columns):
    raise ValueError(f"omega_projection must contain columns: {need_omega}")

need_ns = {"FundID","NAV_start","cap_qe"}
if not need_ns.issubset(navstart.columns):
    raise ValueError(f"nav_start must contain columns: {need_ns}")

omega_df = omega_df.copy()
omega_df["quarter_end"] = pd.to_datetime(omega_df["quarter_end"])
omega_df["omega"] = pd.to_numeric(omega_df["omega"], errors="coerce").fillna(0.0)

# Ensure grade strings
omega_df["Grade"] = omega_df["Grade"].astype(str).str.strip()
omega_df.loc[~omega_df["Grade"].isin(GRADE_STATES), "Grade"] = "D"

navstart = navstart.copy()
navstart["NAV_start"] = pd.to_numeric(navstart["NAV_start"], errors="coerce").fillna(0.0).clip(lower=0.0)
navstart["cap_qe"] = pd.to_datetime(navstart["cap_qe"], errors="coerce")

# Keep only funds present in omega_df
funds = sorted(set(omega_df["FundID"]).intersection(set(navstart["FundID"])))
if not funds:
    raise ValueError("No overlap between omega_projection and nav_start funds.")

omega_df = omega_df[omega_df["FundID"].isin(funds)].copy()
navstart = navstart[navstart["FundID"].isin(funds)].copy()

# =============================
# Build per-quarter copula uniforms across funds
# =============================
rng = np.random.default_rng(SEED)
fund_index = {fid:i for i, fid in enumerate(funds)}
n_funds = len(funds)

quarters = sorted(omega_df["quarter_end"].drop_duplicates().tolist())
# limit to at most n_q quarters globally (omega already should be n_q, but keep safe)
quarters = quarters[:n_q]

U_by_q = {}
for qe in quarters:
    U_by_q[qe] = {
        "draw_event": one_factor_uniforms(n_funds, rng, RHO_MKT),
        "draw_size":  one_factor_uniforms(n_funds, rng, RHO_MKT),
        "rep_event":  one_factor_uniforms(n_funds, rng, RHO_MKT),
        "rep_size":   one_factor_uniforms(n_funds, rng, RHO_MKT),
        "rc_event":   one_factor_uniforms(n_funds, rng, RHO_MKT),
        "rc_size":    one_factor_uniforms(n_funds, rng, RHO_MKT),
    }

# =============================
# Initialize fund state + ledgers
# =============================
# commitment level per fund: use last historical Commitment_Level (proxy)
fund_commit = (
    data.sort_values(["FundID","quarter_end"])
        .groupby("FundID")
        .tail(1)[["FundID","Adj Strategy","Commitment_Level"]]
        .set_index("FundID")["Commitment_Level"]
        .to_dict()
)

# merge KMP to navstart to pick rho/E (hard KMP or soft)
kmp_map = kmp2.set_index("FundID")[["Recallable_Percentage_Decimal","Expiration_Quarters"]].to_dict("index")

def get_rho_E(fid: str, strategy: str) -> Tuple[float,int]:
    d = kmp_map.get(fid, None)
    if d is not None and pd.notna(d.get("Recallable_Percentage_Decimal")) and pd.notna(d.get("Expiration_Quarters")):
        rho = float(d["Recallable_Percentage_Decimal"])
        E = int(d["Expiration_Quarters"])
        return float(np.clip(rho, 0.0, 1.0)), max(E, 0)
    return soft_params(strategy)

# state
state = {}
ledgers = {}

navstart_idx = navstart.set_index("FundID")

for fid in funds:
    ns = navstart_idx.loc[fid]
    nav0 = float(ns["NAV_start"])
    cap_qe = ns["cap_qe"]

    # strategy/grade0 for reference (omega provides time-varying grade)
    strat0 = ns["Adj Strategy"] if "Adj Strategy" in ns.index else omega_df[omega_df["FundID"] == fid]["Adj Strategy"].iloc[0]

    C = float(fund_commit.get(fid, 0.0) or 0.0)
    C = max(C, 0.0)

    rho, E = get_rho_E(fid, strat0)
    ledgers[fid] = RecallableLedger(rho=rho, expiry_quarters=E, commitment=C)

    state[fid] = {
        "NAV": nav0,
        "DD_cum_commit": 0.0,
        "alive": True,
        "cap_qe": cap_qe,
        "Commitment": C,
    }

# =============================
# Simulation loop (endogenous NAV)
# =============================
out = []

# pre-group omega rows per fund
omega_by_fund = {fid: omega_df[omega_df["FundID"] == fid].sort_values("quarter_end").copy() for fid in funds}

for fid in funds:
    st = state[fid]
    ledger = ledgers[fid]
    df_f = omega_by_fund[fid]

    # apply cap_qe filter
    cap_qe = st["cap_qe"]
    if pd.notna(cap_qe):
        df_f = df_f[df_f["quarter_end"] <= cap_qe].copy()
    # limit to n_q
    df_f = df_f.head(n_q).copy()
    if df_f.empty:
        continue

    # last quarter for terminal enforcement
    qe_last = df_f["quarter_end"].max()

    for step, row in enumerate(df_f.itertuples(index=False), start=1):
        qe = row.quarter_end
        if qe not in U_by_q:
            # if omega has a quarter beyond our global quarter list, skip safely
            continue

        if not st["alive"]:
            break

        # grade/strategy from omega path (consistent with omega generation)
        strategy = getattr(row, "Adj_Strategy", None) or getattr(row, "Adj Strategy", None) or row._asdict().get("Adj Strategy")
        # pandas itertuples sanitizes names; easiest: use row._asdict()
        rdict = row._asdict()
        strategy = rdict.get("Adj Strategy", strategy)
        grade = rdict.get("Grade", "D")
        omega_t = float(rdict.get("omega", 0.0))

        # age for calibration (prefer Fund_Age_Quarters from omega file if present)
        age_q = rdict.get("Fund_Age_Quarters", None)
        if age_q is None or pd.isna(age_q):
            # fallback: keep a simple counter (not perfect but ok)
            age_q = step
        age_bucket = make_age_bucket_q(float(age_q))

        # stop if NAV already ~0
        NAV_prev = float(st["NAV"])
        if NAV_prev <= NAV_STOP_EPS:
            st["alive"] = False
            out.append({
                "FundID": fid, "quarter_end": qe, "step_q": step,
                "Adj Strategy": strategy, "Grade": grade, "AgeBucket": age_bucket,
                "NAV_prev": NAV_prev, "omega": omega_t,
                "Stopped_NAV_Zero": 1,

                "Capacity_Pre": 0.0,
                "Draw_Event": 0, "Draw_Amount": 0.0,
                "Rep_Event": 0, "Rep_Regular": 0.0, "Rep_Terminal": 0.0, "Rep_Total": 0.0,
                "RC_Event": 0, "RC_Added": 0.0,
                "NAV_after_valuation": 0.0,
                "NAV_end": 0.0
            })
            break

        # capacity
        rc_avail_pre = ledger.available(step)
        remaining_commit_pre = max(st["Commitment"] - st["DD_cum_commit"], 0.0)
        capacity_pre = remaining_commit_pre + rc_avail_pre

        params = lookup_params(strategy, grade, age_bucket)

        # copula uniforms for this quarter (fund-specific index)
        i = fund_index[fid]
        U = U_by_q[qe]

        # ---- Drawdown ----
        draw_event = (U["draw_event"][i] < params["p_draw"]) and (capacity_pre > 0.0)
        draw_amt = 0.0
        use_rc = 0.0
        use_commit = 0.0

        if draw_event:
            ratio = lognormal_from_u(params["mu_draw"], params["sig_draw"], float(U["draw_size"][i]))
            ratio = float(np.clip(ratio, 0.0, 1.0))
            draw_amt = ratio * capacity_pre
            cons = ledger.consume_for_drawdown(step, draw_amt)
            use_rc = cons["use_rc"]
            use_commit = cons["use_commitment"]
            st["DD_cum_commit"] += use_commit

        # ---- Repayment (regular) ----
        rep_regular = 0.0
        rep_event = (U["rep_event"][i] < params["p_rep"]) and (NAV_prev > NAV_EPS)
        if rep_event:
            rep_ratio = lognormal_from_u(params["mu_rep"], params["sig_rep"], float(U["rep_size"][i]))
            rep_ratio = float(np.clip(rep_ratio, 0.0, 2.0))
            rep_regular = rep_ratio * NAV_prev

        # ---- Recallable (conditional on regular repayment ONLY) ----
        rc_added = 0.0
        rc_event = (rep_regular > 0.0) and (U["rc_event"][i] < params["p_rc_given_rep"])
        if rc_event:
            rc_ratio = lognormal_from_u(params["mu_rc"], params["sig_rc"], float(U["rc_size"][i]))
            rc_ratio = float(np.clip(rc_ratio, 0.0, 1.0))
            rc_amt_raw = rc_ratio * rep_regular
            rc_added = ledger.add_recallable(step, rc_amt_raw, enforce_cap=True)

        # ---- Endogenous NAV update ----
        # Flow net into NAV base: +Draw - Rep_Regular (terminal handled separately)
        nav_after_flow = NAV_prev + float(draw_amt) - float(rep_regular)
        nav_after_flow = max(nav_after_flow, 0.0)

        nav_after_val = nav_after_flow * (1.0 + float(omega_t))
        if not np.isfinite(nav_after_val):
            nav_after_val = 0.0
        nav_after_val = max(float(nav_after_val), 0.0)

        rep_terminal = 0.0
        nav_end = nav_after_val

        # terminal enforcement at fund end quarter
        if qe == qe_last:
            rep_terminal = float(nav_after_val)
            nav_end = 0.0
            st["alive"] = False  # fund ends after this quarter

        # update NAV state
        st["NAV"] = nav_end

        # post capacity
        rc_avail_post = ledger.available(step)
        remaining_commit_post = max(st["Commitment"] - st["DD_cum_commit"], 0.0)
        capacity_post = remaining_commit_post + rc_avail_post

        out.append({
            "FundID": fid,
            "quarter_end": qe,
            "step_q": step,
            "Adj Strategy": strategy,
            "Grade": grade,
            "AgeBucket": age_bucket,

            "omega": float(omega_t),

            "NAV_prev": float(NAV_prev),
            "Stopped_NAV_Zero": 0,

            "RC_Avail_Pre": float(rc_avail_pre),
            "Remaining_Commit_Pre": float(remaining_commit_pre),
            "Capacity_Pre": float(capacity_pre),

            "Draw_Event": int(draw_event),
            "Draw_Amount": float(draw_amt),
            "Use_Recallable": float(use_rc),
            "Use_Commitment": float(use_commit),
            "DD_Cum_Commitment": float(st["DD_cum_commit"]),

            "Rep_Event": int(rep_event),
            "Rep_Regular": float(rep_regular),
            "Rep_Terminal": float(rep_terminal),
            "Rep_Total": float(rep_regular + rep_terminal),

            "RC_Event": int(rc_event),
            "RC_Added": float(rc_added),

            "NAV_after_flow": float(nav_after_flow),
            "NAV_after_valuation": float(nav_after_val),
            "NAV_end": float(nav_end),

            "RC_Avail_Post": float(rc_avail_post),
            "Remaining_Commit_Post": float(remaining_commit_post),
            "Capacity_Post": float(capacity_post),

            "rho_used": float(ledger.rho),
            "E_used": int(ledger.expiry_quarters),
            "cap_qe": st["cap_qe"],
            "qe_last": qe_last,
        })

sim = pd.DataFrame(out)
print("\nSimulated rows:", len(sim))
print("Simulated funds:", sim["FundID"].nunique())

# =============================
# Save
# =============================
save = input("Save structural cashflow outputs? (y/n) [y]: ").strip().lower()
if save in {"", "y", "yes"}:
    out_csv = os.path.join(DATA_DIR, f"structural_cashflows_endogenous_{year}_{quarter}_{n_q}q.csv")
    out_pq  = os.path.join(DATA_DIR, f"structural_cashflows_endogenous_{year}_{quarter}_{n_q}q.parquet")
    sim.to_csv(out_csv, index=False, sep=";", encoding="utf-8-sig")
    sim.to_parquet(out_pq, index=False)
    print("Saved:")
    print(out_csv)
    print(out_pq)

# =============================
# Sanity checks
# =============================
print("\nSanity checks:")
print("Total Draw:", float(sim["Draw_Amount"].sum()))
print("Total Rep Regular:", float(sim["Rep_Regular"].sum()))
print("Total Rep Terminal:", float(sim["Rep_Terminal"].sum()))
print("Total Rep Total:", float(sim["Rep_Total"].sum()))
print("Total Recallable Added:", float(sim["RC_Added"].sum()))

# NAV at end should be 0 for ended funds (by construction)
end_nav = sim.sort_values(["FundID","quarter_end"]).groupby("FundID").tail(1)
print("End NAV > 0 (should be 0):", int((end_nav["NAV_end"] > 1e-6).sum()))

# Cap compliance (RC <= rho*C) at fund end
end_nav = end_nav.copy()
end_nav["rc_cap"] = end_nav["rho_used"] * end_nav["Commitment_Level"] if "Commitment_Level" in end_nav.columns else end_nav["rho_used"] * end_nav["Remaining_Commit_Post"].add(end_nav["DD_Cum_Commitment"], fill_value=0.0)
# Better: use state commitment = Commitment from output
end_nav["rc_cap"] = end_nav["rho_used"] * end_nav["Remaining_Commit_Post"].add(end_nav["DD_Cum_Commitment"], fill_value=0.0)
viol = end_nav[end_nav["RC_Avail_Post"] - end_nav["rc_cap"] > 1e-6]
print("Cap violations at end (should be 0):", len(viol))
if len(viol):
    print(viol[["FundID","RC_Avail_Post","rc_cap","rho_used"]].head(20))

print("\nRuntime (seconds):", round(time.perf_counter() - t0, 2))


HOME: C:\Users\MANJANID\Documents\Equity\2025_Q3
DATA_DIR: C:\Users\MANJANID\Documents\Equity\2025_Q3\data


  for (s,g,a), grp in data.groupby(group_keys, dropna=False):


Using omega file: C:\Users\MANJANID\Documents\Equity\2025_Q3\data\omega_projection_sota_2025_Q3_40q.parquet
Using nav_start file: C:\Users\MANJANID\Documents\Equity\2025_Q3\data\nav_start_sota_2025_Q3.parquet

Simulated rows: 13274
Simulated funds: 945
Saved:
C:\Users\MANJANID\Documents\Equity\2025_Q3\data\structural_cashflows_endogenous_2025_Q3_40q.csv
C:\Users\MANJANID\Documents\Equity\2025_Q3\data\structural_cashflows_endogenous_2025_Q3_40q.parquet

Sanity checks:
Total Draw: 21587008279.104668
Total Rep Regular: 23152870352.907166
Total Rep Terminal: 31355793706.1512
Total Rep Total: 54508664059.058365
Total Recallable Added: 814505817.9592048
End NAV > 0 (should be 0): 0
Cap violations at end (should be 0): 0

Runtime (seconds): 35.56
