In [None]:
from pathlib import Path
import sys

here = Path.cwd()
# Walk up from the kernel's working directory to the first directory containing
# pyproject.toml and put <repo>/src on sys.path, so `optimal_ipr` imports without
# being installed.
repo_root = next(
    (p for p in [here, *here.parents] if (p / "pyproject.toml").exists()),
    None,
)
if repo_root is None:
    # Without a default, next() would raise a bare StopIteration with no hint of the cause.
    raise FileNotFoundError(
        f"No pyproject.toml found in {here} or any of its parents; "
        "start the notebook from inside the repository."
    )
src = repo_root / "src"
if str(src) not in sys.path:
    sys.path.insert(0, str(src))

In [None]:
## import random
import random
from dataclasses import dataclass, replace
from decimal import Decimal
from time import time
from typing import Any, Callable, Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm.auto import tqdm

from optimal_ipr.cost import build_cost_function
from optimal_ipr.distributions import value_distribution, build_theta_distribution
from optimal_ipr.fee import build_fee_schedule
from optimal_ipr.lookup import build_lookup_tables
from optimal_ipr.outcomes import welfare_outcomes
from optimal_ipr.probability import build_subjective_probability

SEED = 42
# Seed both global RNGs (NumPy legacy + stdlib) in the kernel process.
# NOTE(review): joblib's loky workers are separate processes and are not covered by
# these calls; presumably the explicit seed=42 passed to build_lookup_tables handles
# determinism there — confirm.
np.random.seed(SEED)
random.seed(SEED)

# Output directory, relative to the kernel's current working directory (not repo_root).
RESULTS_DIR = Path("results") / "robustness"
RESULTS_DIR.mkdir(exist_ok=True, parents=True)


def ensure_exists(path: Path) -> Path:
    """Make sure the directory that will hold ``path`` exists, then hand ``path`` back.

    Despite the name, only the *parent* directory (and any missing ancestors) is
    created; ``path`` itself is neither created nor checked, so the caller can write
    to it next. Safe to call repeatedly.
    """
    parent_dir = path.parent
    parent_dir.mkdir(exist_ok=True, parents=True)
    return path


# Government welfare-weighting schemes, stored in Context.gov_prefs and forwarded to
# welfare_outcomes. "utilitarian" returns a weight of 1.0 for every argument `th`.
GOVERNMENT_PREFERENCES = {
    "utilitarian": lambda th: 1.0,
}

# Regulator objective parameters, stored in Context.reg_prefs and forwarded to
# welfare_outcomes. NOTE(review): the meaning of phi/psi is defined in
# optimal_ipr.outcomes — confirm there.
REGULATOR_PREFERENCES = {
    "welfarist_balanced": {"phi": 0.6, "psi": 0.3},
}


@dataclass(frozen=True)
class Context:
    """Immutable bundle of model primitives shared by every run of the sweep.

    Produced by ``build_base_context`` and consumed by ``run_model``. The cost function
    is deliberately not a field: ``run_model`` takes it as a separate argument so a sweep
    can vary it per run while reusing one context.

    Field order is part of the generated ``__init__`` signature; positional callers
    depend on it, so do not reorder fields.

    NOTE(review): ``frozen=True`` with the default ``eq=True`` makes the dataclass
    generate ``__hash__``/``__eq__`` over all fields. Because of the ndarray and dict
    fields, ``hash(ctx)`` raises TypeError, and ``==`` between contexts holding distinct
    arrays raises ValueError. Do not use contexts as dict keys or compare them.
    """

    # Value grid and its weights, as returned by value_distribution.
    v_grid: np.ndarray
    v_weights: np.ndarray
    # Callables from build_theta_distribution — presumably density, CDF and inverse CDF
    # of theta (F is evaluated on ndarray input by run_model); TODO confirm.
    f: Callable
    F: Callable
    F_inv: Callable
    # Subjective probability (build_subjective_probability) and fee schedule (build_fee_schedule).
    p: Callable
    Z: Callable
    # Scalar parameters forwarded to build_lookup_tables and welfare_outcomes.
    tau_d: float
    tau_f: float
    # Policy grid for bar_beta, forwarded to build_lookup_tables.
    bar_beta_grid: np.ndarray
    # Preference specifications forwarded to welfare_outcomes.
    gov_prefs: Dict[str, Any]
    reg_prefs: Dict[str, Any]


def build_base_context() -> Context:
    """Assemble the baseline model primitives shared by every run of the sweep.

    The calibration is hard-coded here:
      * value_distribution(n_v=41, sigma=1.0)
      * build_theta_distribution(noise_level=0.20)
      * build_subjective_probability(base_k=1.2, m_comp=25, F=F, F_inv=F_inv)
      * build_fee_schedule(zeta=0.04, fee_M=1.5)
      * tau_d=0.20, tau_f=0.05
      * bar_beta_grid: 101 evenly spaced points on [0, 1] (step 0.01)
    Preferences are the module-level GOVERNMENT_PREFERENCES / REGULATOR_PREFERENCES
    dicts (the same objects, not copies).

    Returns:
        A frozen ``Context``.
    """
    v_grid, v_weights = value_distribution(n_v=41, sigma=1.0)
    f, F, F_inv = build_theta_distribution(noise_level=0.20)
    p = build_subjective_probability(base_k=1.2, m_comp=25, F=F, F_inv=F_inv)
    Z = build_fee_schedule(zeta=0.04, fee_M=1.5)
    tau_d, tau_f = 0.20, 0.05
    bar_beta_grid = np.linspace(0.0, 1.0, 101)
    # Keyword construction: several adjacent fields share a type (f/F/F_inv/p/Z are all
    # callables, tau_d/tau_f are both floats). A positional call would silently mis-wire
    # them if Context's field order ever changed.
    return Context(
        v_grid=v_grid,
        v_weights=v_weights,
        f=f,
        F=F,
        F_inv=F_inv,
        p=p,
        Z=Z,
        tau_d=tau_d,
        tau_f=tau_f,
        bar_beta_grid=bar_beta_grid,
        gov_prefs=GOVERNMENT_PREFERENCES,
        reg_prefs=REGULATOR_PREFERENCES,
    )


def run_model(ctx: "Context", c: Callable) -> tuple[float, float]:
    """Solve the model for one cost function; return the optimal policy and its welfare change.

    Builds the lookup tables with ``build_lookup_tables`` (fixed ``seed=42``), then
    evaluates them with ``welfare_outcomes`` (``feas=False``).

    Args:
        ctx: Shared model primitives (see ``Context``).
        c: Cost function, e.g. from ``build_cost_function``.

    Returns:
        ``(bar_beta_opt, welfare_pct_change)``: the "Optimal Policy" and
        "Welfare % Change" values from the first row of the ``welfare_outcomes`` table.

    Raises:
        ValueError: if ``welfare_outcomes`` returns an empty table.
    """

    def F_scalar(t: float) -> float:
        # Scalar-in/scalar-out wrapper around ctx.F, which is evaluated on arrays.
        # `.item()` extracts the single element; calling float() directly on a
        # 1-element ndarray is deprecated since NumPy 1.25.
        return float(np.asarray(ctx.F(np.array([t]))).item())

    theta_tilde, theta_winner, *_ = build_lookup_tables(
        ctx.p,
        c,
        ctx.Z,
        ctx.f,
        F_scalar,
        ctx.tau_d,
        ctx.tau_f,
        ctx.bar_beta_grid,
        ctx.v_grid,
        seed=42,
    )
    results_table = welfare_outcomes(
        tau_d=ctx.tau_d,
        tau_f=ctx.tau_f,
        gov_prefs=ctx.gov_prefs,
        reg_prefs=ctx.reg_prefs,
        v_grid=ctx.v_grid,
        v_weights=ctx.v_weights,
        theta_tilde_table=theta_tilde,
        theta_winner_table=theta_winner,
        f=ctx.f,
        F=ctx.F,
        F_inv=ctx.F_inv,
        p=ctx.p,
        c=c,
        Z=ctx.Z,
        feas=False,
    )
    if results_table.empty:
        raise ValueError("welfare_outcomes returned an empty table")
    # Take the first row by position: label 0 is not guaranteed to be in the returned index.
    first_row = results_table.iloc[0]
    bar_beta_opt = float(first_row["Optimal Policy"])
    welfare_pct_change = float(first_row["Welfare % Change"])
    return bar_beta_opt, welfare_pct_change

In [None]:
ctx = build_base_context()


def run_for_c_min_cost(value: float) -> dict:
    """Solve the model once with C_MIN_COST set to ``value``.

    The cost function is rebuilt for each call with TARGET_AVG_COST_SHARE=0.50 and
    GAMMA_C_COST=3.0 held fixed; every other primitive comes from the module-level
    ``ctx``. Returns one record (dict) for the sweep results table.
    """
    cost_fn = build_cost_function(
        ctx.f,
        TARGET_AVG_COST_SHARE=0.50,
        C_MIN_COST=value,
        GAMMA_C_COST=3.0,
    )
    beta_opt, welfare_change = run_model(ctx, cost_fn)
    record = {"c_min_cost": value}
    record["bar_beta_opt"] = beta_opt
    record["welfare_pct_change"] = welfare_change
    return record


# Sweep grid: C_MIN_COST = 0.00, 0.05, ..., 0.50 (11 values). The 1e-9 pad keeps 0.5
# inside np.arange's half-open interval. Grid points are not exact decimals
# (3 * 0.05 == 0.15000000000000002 in binary floating point), so never look them up with ==.
values = np.arange(0.0, 0.5 + 1e-9, 0.05)
start = time()
results = Parallel(n_jobs=-1, backend="loky")(
    delayed(run_for_c_min_cost)(v) for v in tqdm(values, desc="C_MIN_COST sweep")
)
total_time = time() - start
results_df = pd.DataFrame(results)
csv_path = ensure_exists(RESULTS_DIR / "c_min_cost_sweep.csv")
results_df.to_csv(csv_path, index=False)

# Twin-axis plot: optimal policy (left, blue) and welfare change (right, orange).
plt.rcParams.update({"font.size": 14})
fig, ax1 = plt.subplots(figsize=(16, 9))
ax1.plot(results_df["c_min_cost"], results_df["bar_beta_opt"], color="tab:blue")
ax1.set_xlabel("C_MIN_COST")
ax1.set_ylabel("Optimal Policy", color="tab:blue")
ax1.set_title("Sensitivity of optimal policy and welfare to C_MIN_COST")
ax2 = ax1.twinx()
ax2.plot(results_df["c_min_cost"], results_df["welfare_pct_change"], color="tab:orange")
ax2.set_ylabel("Welfare % Change", color="tab:orange")
fig.tight_layout()
png_path = ensure_exists(RESULTS_DIR / "c_min_cost_sweep.png")
fig.savefig(png_path, dpi=100, bbox_inches="tight")
plt.show()

# Cross-check: recompute one grid point serially in the kernel and compare it with the
# value the parallel workers produced.
baseline_value = 0.05
tolerance = 1e-6
baseline_calc = run_for_c_min_cost(baseline_value)
baseline_mask = np.isclose(results_df["c_min_cost"], baseline_value)
if not baseline_mask.any():
    raise ValueError(f"baseline_value={baseline_value} is not on the sweep grid")
baseline_row = results_df.loc[baseline_mask].iloc[0]

beta_diff = abs(baseline_calc["bar_beta_opt"] - baseline_row["bar_beta_opt"])
welfare_diff = abs(baseline_calc["welfare_pct_change"] - baseline_row["welfare_pct_change"])
diff_report = f"Δbar_beta_opt={beta_diff:.3e}, Δwelfare_pct_change={welfare_diff:.3e}"
# Written as `not (... <= ...)` so a NaN difference counts as a failure.
if not (beta_diff <= tolerance and welfare_diff <= tolerance):
    raise AssertionError(f"Baseline cross-check failed: {diff_report}")
print(f"Baseline cross-check passed: {diff_report}")
print(f"Total wall time: {total_time:.2f} s; per-value avg: {total_time/len(values):.2f} s")
print(f"{len(values)} values evaluated")
print(f"CSV saved to {csv_path}")
print(f"PNG saved to {png_path}")

In [None]:
# First and last five rows of the sweep results, rendered as rich tables.
display(results_df.head(5), results_df.tail(5))