# 94_fig4_persistence_sensitivity

Figure 4: compare the persistence of a temporary ξ (cost-push) shock across two commitment runs (e.g., rho_tau=0.90 vs 0.99).

In [None]:
import os, sys, json, numpy as np, torch, matplotlib.pyplot as plt
import pathlib

def _find_project_root():
    """Locate the repository root: the nearest directory containing ``src/``.

    The search starts at the current working directory and walks upward through
    its parents. As a last resort, the conventional Google Colab clone location
    ``/content/econml`` is tried.

    Raises:
        RuntimeError: if no candidate directory has a ``src/`` subdirectory.
    """
    start = pathlib.Path.cwd().resolve()
    search_order = [start, *start.parents, pathlib.Path("/content/econml")]
    for candidate in search_order:
        if (candidate / "src").is_dir():
            return candidate
    raise RuntimeError("Could not find project root containing src/. If on Colab, clone repo to /content/econml.")

PROJECT_ROOT = _find_project_root()
if sys.path.count(str(PROJECT_ROOT)) == 0:
    # Prepend (rather than append) so the repository's `src` package is found
    # before any same-named package installed in the environment.
    sys.path.insert(0, str(PROJECT_ROOT))

from src.config import ModelParams
from src.io_utils import load_json, load_npz, load_torch, load_selected_run, find_latest_run_dir
from src.deqn import PolicyNetwork

# Root directory passed to the run-lookup helpers in get_run().
ART = os.path.join(str(PROJECT_ROOT), "artifacts", "runs")
if torch.cuda.is_available():
    COMPUTE_DEVICE = "cuda"
else:
    COMPUTE_DEVICE = "cpu"
print(f"Compute device: {COMPUTE_DEVICE}")

def get_run(policy: str) -> str:
    """Return the run directory to use for ``policy``.

    The run reported by ``load_selected_run`` takes precedence. When it returns
    ``None``, the result of ``find_latest_run_dir`` is used instead (presumably
    the newest run; see that helper).

    Raises:
        RuntimeError: if neither helper yields a run directory.
    """
    run_dir = load_selected_run(ART, policy)
    if run_dir is None:
        run_dir = find_latest_run_dir(ART, policy)
    if run_dir is not None:
        return run_dir
    raise RuntimeError(f"No runs found for policy={policy} under {ART}")

def _parse_dtype(s: str):
    """Map a dtype spec read from a run config to a ``torch.dtype``.

    - ``None`` -> ``torch.float32``
    - a ``torch.dtype`` is returned unchanged
    - otherwise the spec is stringified and searched for ``"float64"``,
      ``"float32"`` and ``"bfloat16"``, in that order.

    Anything unrecognized (including ``"float16"``) falls back to float32.
    """
    if s is None:
        return torch.float32
    if isinstance(s, torch.dtype):
        return s
    text = str(s)
    # Order matters: the first token found in the text wins.
    known = (
        ("float64", torch.float64),
        ("float32", torch.float32),
        ("bfloat16", torch.bfloat16),
    )
    for token, dtype in known:
        if token in text:
            return dtype
    return torch.float32

def load_params_from_run(run_dir: str, *, device=None):
    """Rebuild the model parameters of a run from its ``config.json``.

    Only keys of the config's ``params`` section that are ``ModelParams``
    dataclass fields are kept. ``device`` and ``dtype`` are always overridden:
    ``device`` comes from the argument (default ``COMPUTE_DEVICE``), and
    ``dtype`` is parsed from the config with ``_parse_dtype``.

    Returns:
        The result of ``ModelParams(...).to_torch()``.
    """
    params_cfg = load_json(os.path.join(run_dir, "config.json")).get("params", {})
    allowed = ModelParams.__dataclass_fields__
    kwargs = {name: value for name, value in params_cfg.items() if name in allowed}
    kwargs.update(
        device=COMPUTE_DEVICE if device is None else device,
        dtype=_parse_dtype(params_cfg.get("dtype")),
    )
    return ModelParams(**kwargs).to_torch()

def load_net_from_run(run_dir: str, d_in: int, d_out: int, *, device=None):
    """Build a PolicyNetwork with a run's architecture and load its trained weights.

    Architecture (``hidden_layers``, default ``(512, 512)``; ``activation``,
    default ``"selu"``) comes from the config's ``train_cfg`` section. The
    network dtype is parsed from ``params.dtype``, and the network is placed on
    ``device`` (default ``COMPUTE_DEVICE``).

    Weights are read from ``weights.pt``. If that file is absent,
    ``weights_best.pt`` is used instead. Run discovery (``_has_weights``) accepts
    either file, so a run holding only the best checkpoint must also be loadable.
    If neither file exists, the load of ``weights.pt`` raises as before.

    Returns:
        The network in eval mode.
    """
    cfg = load_json(os.path.join(run_dir, "config.json"))
    tc = cfg.get("train_cfg", {})
    hidden = tuple(tc.get("hidden_layers", (512, 512)))
    activation = tc.get("activation", "selu")
    net_dtype = _parse_dtype(cfg.get("params", {}).get("dtype"))
    dev = device if device is not None else COMPUTE_DEVICE
    net = PolicyNetwork(d_in, d_out, hidden=hidden, activation=activation).to(device=dev, dtype=net_dtype)

    weights_path = os.path.join(run_dir, "weights.pt")
    if not os.path.exists(weights_path):
        # Some runs keep only the best checkpoint; accept it as a fallback.
        best_path = os.path.join(run_dir, "weights_best.pt")
        if os.path.exists(best_path):
            weights_path = best_path

    # Load on CPU first; load_state_dict copies into the net's device/dtype.
    state = load_torch(weights_path, map_location="cpu")
    # The checkpoint is usually a plain state_dict, but may be wrapped as {"state_dict": ...}.
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    net.load_state_dict(state)
    net.eval()
    return net

# --- paper reporting helpers ---
def ann(x):
    """Annualize a quarterly rate and express it in percent (i.e. multiply by 400)."""
    return 400.0 * x


In [None]:
from src.steady_states import solve_commitment_sss_from_policy_switching
from src.experiments import DeterministicPathSpec, simulate_deterministic_path

ART_ROOT = str(PROJECT_ROOT / "artifacts")
ROOT = os.path.join(ART_ROOT, "runs", "commitment")

# rho_tau values of the two commitment runs being compared.
TARGET_RHO_A = 0.90
TARGET_RHO_B = 0.99
# Tolerance for matching a run's rho_tau against the targets. If to_torch()
# stores parameters as float32 (the default dtype in _parse_dtype), 0.90 becomes
# ~0.8999999762, an error of ~2.4e-8, so a 1e-10 tolerance would reject every
# such run. 1e-6 absorbs that rounding while still separating 0.90 from 0.99
# by orders of magnitude.
RHO_TOL = 1e-6
SHOCK_STD_MULT = 1.0  # one-sigma temporary cost-push shock at t=0 for both calibrations


def _has_weights(run_dir: str) -> bool:
    """Return True if ``run_dir`` contains ``weights.pt`` or ``weights_best.pt``."""
    checkpoint_names = ("weights.pt", "weights_best.pt")
    return any(os.path.exists(os.path.join(run_dir, name)) for name in checkpoint_names)


def list_runs(root):
    """List subdirectories of ``root`` that contain saved weights, newest first.

    Entries are ordered by modification time (most recent first). A missing
    ``root`` yields an empty list.
    """
    if not os.path.isdir(root):
        return []
    child_paths = (os.path.join(root, name) for name in os.listdir(root))
    with_weights = [path for path in child_paths if os.path.isdir(path) and _has_weights(path)]
    return sorted(with_weights, key=os.path.getmtime, reverse=True)


def _is_no_regime(params) -> bool:
    """Return True when ``eta_bar``, ``p12`` and ``p21`` all equal 0.0 (as floats)."""
    return all(float(getattr(params, name)) == 0.0 for name in ("eta_bar", "p12", "p21"))


def _is_target_rho(v: float, target: float) -> bool:
    """Return True when ``v`` lies within ``RHO_TOL`` of ``target`` (both coerced to float)."""
    gap = float(v) - float(target)
    return -RHO_TOL <= gap <= RHO_TOL


def _pick_target_pair(candidates, rho_a: float, rho_b: float):
    """Pick the first no-regime run matching ``rho_a`` and the first matching ``rho_b``.

    ``candidates`` is scanned in the given order. Params are loaded via
    ``load_params_from_run`` (cached per run directory), and the scan stops as
    soon as both targets have been matched.

    Returns:
        ``(run_a, run_b, params_a, params_b)``. All four are ``None`` unless
        both targets were matched.
    """
    params_cache = {}
    match_a = match_b = None  # each becomes (run_dir, params) once found

    for run_dir in candidates:
        if run_dir not in params_cache:
            params_cache[run_dir] = load_params_from_run(run_dir)
        params = params_cache[run_dir]
        if not _is_no_regime(params):
            continue
        if match_a is None and _is_target_rho(params.rho_tau, rho_a):
            match_a = (run_dir, params)
        if match_b is None and _is_target_rho(params.rho_tau, rho_b):
            match_b = (run_dir, params)
        if match_a is not None and match_b is not None:
            break

    if match_a is None or match_b is None:
        return None, None, None, None
    (run_a, params_a), (run_b, params_b) = match_a, match_b
    return run_a, run_b, params_a, params_b


# Candidate runs in priority order: the selected commitment run (if it exists
# and has weights), then every run under ROOT that has weights, newest first.
# NOTE(review): get_run() passes ART (artifacts/runs) to load_selected_run, while
# this call passes ART_ROOT (artifacts); confirm which root the helper expects.
selected = load_selected_run(ART_ROOT, "commitment")
if selected is not None and os.path.isdir(selected) and _has_weights(selected):
    runs = [selected]
else:
    runs = []
for rd in list_runs(ROOT):
    if rd in runs:
        continue
    runs.append(rd)

# Manual override: set both RUN_A and RUN_B to explicit run directories to skip
# the automatic search. If either is left as None, both are searched for.
RUN_A = None
RUN_B = None

if RUN_A is not None and RUN_B is not None:
    if not (_has_weights(RUN_A) and _has_weights(RUN_B)):
        raise RuntimeError("RUN_A/RUN_B must point to runs containing weights.pt or weights_best.pt.")
    paramsA = load_params_from_run(RUN_A)
    paramsB = load_params_from_run(RUN_B)
else:
    RUN_A, RUN_B, paramsA, paramsB = _pick_target_pair(runs, TARGET_RHO_A, TARGET_RHO_B)
    # _pick_target_pair returns all-None unless both targets were found.
    if RUN_A is None:
        raise RuntimeError(
            f"Need two no-regime commitment runs with rho_tau={TARGET_RHO_A} and rho_tau={TARGET_RHO_B}."
        )

print("RUN_A:", RUN_A)
print("RUN_B:", RUN_B)


def _require_no_regimes(params, label):
    """Raise ``RuntimeError`` unless ``params`` describes the no-regime model.

    The no-regime model means ``eta_bar``, ``p12`` and ``p21`` all equal 0.0.
    ``label`` prefixes the error message to identify which run failed.
    """
    no_regime = float(params.eta_bar) == 0.0 and float(params.p12) == 0.0 and float(params.p21) == 0.0
    if no_regime:
        return
    raise RuntimeError(
        f"{label}: Figure 4 is defined for no-regime model only (eta_bar=0, p12=0, p21=0). "
        f"Current values: eta_bar={params.eta_bar}, p12={params.p12}, p21={params.p21}."
    )


_require_no_regimes(paramsA, "Run A")
_require_no_regimes(paramsB, "Run B")
if not _is_target_rho(paramsA.rho_tau, TARGET_RHO_A):
    raise RuntimeError(f"Run A must have rho_tau={TARGET_RHO_A}, got {paramsA.rho_tau}.")
if not _is_target_rho(paramsB.rho_tau, TARGET_RHO_B):
    raise RuntimeError(f"Run B must have rho_tau={TARGET_RHO_B}, got {paramsB.rho_tau}.")

# Policy networks with 7 inputs (the state below) and 13 outputs.
netA = load_net_from_run(RUN_A, 7, 13)
netB = load_net_from_run(RUN_B, 7, 13)

sssA = solve_commitment_sss_from_policy_switching(paramsA, netA)
sssB = solve_commitment_sss_from_policy_switching(paramsB, netB)


def _sss_state(sss, params):
    """Build the (1, 7) initial state from the regime-0 entry of a steady-state solution.

    Column order: Delta_prev, logA, loggtilde, xi, 0.0, vartheta_prev,
    varrho_prev. The tensor uses the run's configured dtype rather than a
    hard-coded float32. For float64 runs, this starts the simulation exactly at
    the computed steady state and matches the dtype of the network loaded by
    load_net_from_run. If params exposes no ``dtype``, float32 is used as before.
    """
    ss = sss.by_regime[0]
    row = [
        float(ss["Delta_prev"]),
        float(ss["logA"]),
        float(ss["loggtilde"]),
        float(ss["xi"]),
        0.0,  # NOTE(review): column 4 is always initialized to 0.0 here; confirm its meaning
        float(ss["vartheta_prev"]),
        float(ss["varrho_prev"]),
    ]
    return torch.tensor([row], dtype=_parse_dtype(getattr(params, "dtype", None)))


x0A = _sss_state(sssA, paramsA)
x0B = _sss_state(sssB, paramsB)

T = 40

# One-time temporary shock at t=0 (same size in sigma-units for both runs), then deterministic decay.
# Column 3 of the state is xi.
xiA = SHOCK_STD_MULT * float(paramsA.sigma_tau)
xiB = SHOCK_STD_MULT * float(paramsB.sigma_tau)
x0A2 = x0A.clone()
x0A2[:, 3] += xiA
x0B2 = x0B.clone()
x0B2[:, 3] += xiB

# No further innovations after t=0.
spec = DeterministicPathSpec(T=T, epsA=0.0, epsg=0.0, epst=0.0, regime_path=None)
pathA = simulate_deterministic_path(paramsA, "commitment", netA, x0=x0A2, spec=spec, compute_implied_i=True)
pathB = simulate_deterministic_path(paramsB, "commitment", netB, x0=x0B2, spec=spec, compute_implied_i=True)

t = np.arange(T + 1)
plt.figure()
plt.plot(t, ann(pathA["pi"][:, 0]), label=f"rho_tau={paramsA.rho_tau:.2f}")
plt.plot(t, ann(pathB["pi"][:, 0]), label=f"rho_tau={paramsB.rho_tau:.2f}", linestyle="--")
plt.axhline(0, linewidth=1)
plt.title("Figure 4: pi persistence (temporary cost-push shock)")
plt.xlabel("Time in quarters")
plt.ylabel("Annualized percent")
plt.legend()
plt.show()


In [None]:
# The no-regime configuration and the target rho_tau values (TARGET_RHO_A, TARGET_RHO_B) are validated in the previous cell before any simulation runs.
