# 93_fig3_persistent_vs_temporary

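Figure 3: inflation response to a persistent regime switch vs. a temporary ξ shock, where the ξ jump is calibrated so that its impact effect on inflation matches that of the regime switch (run under either commitment or discretion).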

In [None]:
import os, sys, json, numpy as np, torch, matplotlib.pyplot as plt
sys.path.append("..")

from src.config import ModelParams
from src.io_utils import load_json, load_npz, load_torch, load_selected_run, find_latest_run_dir
from src.deqn import PolicyNetwork

ART = os.path.join("..","artifacts","runs")

def get_run(policy: str) -> str:
    """Return the run directory to use for ``policy``.

    The run reported by ``load_selected_run`` is used when there is one;
    otherwise the result of ``find_latest_run_dir`` is used.

    Raises:
        RuntimeError: if both lookups under ``ART`` return ``None``.
    """
    selected = load_selected_run(ART, policy)
    if selected is not None:
        return selected

    # No explicit selection recorded: fall back to the latest run.
    latest = find_latest_run_dir(ART, policy)
    if latest is not None:
        return latest

    raise RuntimeError(f"No runs found for policy={policy} under {ART}")

def _parse_dtype(s: "str | torch.dtype | None"):
    """Translate a dtype specification into a ``torch.dtype``.

    Args:
        s: ``None``, an actual ``torch.dtype``, or any object whose ``str()``
            contains a dtype name (e.g. ``"torch.float64"`` or ``"float32"``).

    Returns:
        ``s`` itself when it already is a ``torch.dtype``; the matching dtype
        when the string contains ``float64``, ``float32``, ``bfloat16`` or
        ``float16``; ``torch.float32`` for ``None`` or an unrecognized spec.
    """
    if s is None:
        return torch.float32
    if isinstance(s, torch.dtype):
        return s

    name = str(s)
    # Order matters: "float16" is a substring of "bfloat16", so the bfloat16
    # entry has to be tested first.
    known = (
        ("float64", torch.float64),
        ("float32", torch.float32),
        ("bfloat16", torch.bfloat16),
        ("float16", torch.float16),
    )
    for token, dtype in known:
        if token in name:
            return dtype

    # Unrecognized spec: keep the historical float32 fallback.
    return torch.float32

def load_params_from_run(run_dir: str, *, device="cpu"):
    """Rebuild the model parameters stored in a run's ``config.json``.

    Only the entries of the config's ``"params"`` mapping that are fields of
    ``ModelParams`` are forwarded; ``device`` and ``dtype`` are always set
    explicitly.

    Args:
        run_dir: directory containing ``config.json``.
        device: device to place the parameters on. The device stored in the
            config (default ``"cpu"``) is used only when ``None`` is passed.

    Returns:
        The result of ``ModelParams(...).to_torch()``.
    """
    config = load_json(os.path.join(run_dir, "config.json"))
    stored = config.get("params", {})

    field_names = ModelParams.__dataclass_fields__
    kwargs = {name: value for name, value in stored.items() if name in field_names}

    # An explicit device argument wins over whatever the run was saved with.
    kwargs["device"] = stored.get("device", "cpu") if device is None else device
    kwargs["dtype"] = _parse_dtype(stored.get("dtype"))
    return ModelParams(**kwargs).to_torch()

def load_net_from_run(run_dir: str, d_in: int, d_out: int):
    """Construct a ``PolicyNetwork`` and load the weights saved in ``run_dir``.

    The hidden-layer sizes and activation come from the ``"train_cfg"`` entry
    of ``config.json`` (defaults: ``(512, 512)`` and ``"selu"``). Weights are
    read from ``weights.pt`` with ``map_location="cpu"``.

    Args:
        run_dir: directory containing ``config.json`` and ``weights.pt``.
        d_in: input dimension passed to ``PolicyNetwork``.
        d_out: output dimension passed to ``PolicyNetwork``.

    Returns:
        The network with weights loaded, switched to eval mode.
    """
    train_cfg = load_json(os.path.join(run_dir, "config.json")).get("train_cfg", {})
    net = PolicyNetwork(
        d_in,
        d_out,
        hidden=tuple(train_cfg.get("hidden_layers", (512, 512))),
        activation=train_cfg.get("activation", "selu"),
    )

    checkpoint = load_torch(os.path.join(run_dir, "weights.pt"), map_location="cpu")
    # Usually the file holds a plain state_dict; unwrap it when it is nested
    # under a "state_dict" key instead.
    is_wrapped = isinstance(checkpoint, dict) and "state_dict" in checkpoint
    net.load_state_dict(checkpoint["state_dict"] if is_wrapped else checkpoint)
    net.eval()
    return net

# --- paper reporting helpers ---
def ann(x):
    """Return ``x`` as an annualized percent (quarterly rate -> annual, i.e. ``400 * x``)."""
    return 400.0 * x


In [None]:
from src.steady_states import solve_commitment_sss_from_policy_switching, solve_discretion_sss_from_policy_switching
from src.experiments import DeterministicPathSpec, simulate_deterministic_path, calibrate_xi_jump_to_match_pi_impact

POLICY = "commitment"  # or "discretion"
run_dir = get_run(POLICY)
params = load_params_from_run(run_dir)

# Build the initial state from the regime-0 entry of the solved steady state.
# The commitment state carries two extra trailing entries (vartheta_prev,
# varrho_prev), hence the different network dimensions.
# NOTE(review): the fifth entry is fixed at 0.0 — presumably a shock/regime
# slot; confirm against the state layout expected by src.experiments.
if POLICY == "commitment":
    net = load_net_from_run(run_dir, 7, 13)
    sss = solve_commitment_sss_from_policy_switching(params, net)
    ss0 = sss.by_regime[0]
    x0_values = [ss0["Delta_prev"], ss0["logA"], ss0["loggtilde"], ss0["xi"], 0.0,
                 ss0["vartheta_prev"], ss0["varrho_prev"]]
else:
    net = load_net_from_run(run_dir, 5, 11)
    sss = solve_discretion_sss_from_policy_switching(params, net)
    ss0 = sss.by_regime[0]
    x0_values = [ss0["Delta_prev"], ss0["logA"], ss0["loggtilde"], ss0["xi"], 0.0]

# NOTE(review): dtype is hard-coded to float32 regardless of the dtype stored
# in the run config — confirm this is intended for float64 runs.
x0 = torch.tensor([[float(v) for v in x0_values]], dtype=torch.float32)

XI_COL = 3  # column of xi in the state vector built above

T = 40

# Persistent experiment: start in regime 0, then stay in regime 1 for T periods.
spec_switch = DeterministicPathSpec(T=T, epsA=0.0, epsg=0.0, epst=0.0, regime_path=[0]+[1]*T)
path_switch = simulate_deterministic_path(params, POLICY, net, x0=x0, spec=spec_switch, compute_implied_i=True)
target_pi0 = float(path_switch["pi"][0].mean())

# Temporary experiment: no regime path, but xi is shifted by the jump that
# reproduces the same impact inflation as the regime switch.
xi_jump = calibrate_xi_jump_to_match_pi_impact(params, POLICY, net, x0=x0, target_pi0=target_pi0, horizon_T=1)
x0_xi = x0.clone()
x0_xi[:, XI_COL] += xi_jump
spec_temp = DeterministicPathSpec(T=T, epsA=0.0, epsg=0.0, epst=0.0, regime_path=None)
path_temp = simulate_deterministic_path(params, POLICY, net, x0=x0_xi, spec=spec_temp, compute_implied_i=True)

t = np.arange(T+1)
plt.figure()
# Annualize exactly once: ann() already converts quarterly rates to annual
# percent, so nesting it would scale the series by 400*400.
plt.plot(t, ann(path_switch["pi"][:,0]), label="Persistent (regime switch)")
plt.plot(t, ann(path_temp["pi"][:,0]), label="Temporary ξ shock", linestyle="--")
plt.axhline(0, linewidth=1)
plt.title("Figure 3: π paths (matched impact)")
plt.xlabel("t"); plt.ylabel("π"); plt.legend(); plt.show()

print("Matched impact target_pi0:", target_pi0, "xi_jump:", xi_jump)
