# 100_fig10_sensitivity_p21

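Figure 10: Sensitivity of SSS inflation and the SSS real rate to the average duration of bad times (1/p21). The normal-regime line uses the baseline p12; the bad-regime line uses the counterfactual p12 -> 1. The efficient allocation is shown for reference.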

In [None]:
import os, sys, json, numpy as np, torch, matplotlib.pyplot as plt
import pathlib

def _find_project_root():
    """Locate the project root, i.e. the first directory that contains ``src/``.

    Candidates are tried in this order: the resolved current working
    directory, each of its ancestors (nearest first), and finally the
    fixed Google Colab clone location ``/content/econml``.

    Returns:
        pathlib.Path: the first candidate that has a ``src`` subdirectory.

    Raises:
        RuntimeError: if no candidate contains a ``src`` directory.
    """
    cwd = pathlib.Path.cwd().resolve()
    # ``Path.parents`` does not include the path itself, so add it up front;
    # the Colab location is only consulted after the whole ancestor chain.
    candidates = (cwd, *cwd.parents, pathlib.Path("/content/econml"))
    root = next((c for c in candidates if (c / "src").is_dir()), None)
    if root is None:
        raise RuntimeError("Could not find project root containing src/. If on Colab, clone repo to /content/econml.")
    return root

PROJECT_ROOT = _find_project_root()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.config import ModelParams
from src.io_utils import load_json, load_npz, load_torch, load_selected_run, find_latest_run_dir
from src.deqn import PolicyNetwork

ART = str(PROJECT_ROOT / "artifacts" / "runs")

def get_run(policy: str) -> str:
    """Return the run directory to use for ``policy``.

    The explicitly selected run (``load_selected_run``) takes precedence;
    when there is none, the latest run directory (``find_latest_run_dir``)
    is used instead. Both lookups are made under ``ART``.

    Raises:
        RuntimeError: if neither lookup yields a run directory.
    """
    run_dir = load_selected_run(ART, policy)
    if run_dir is not None:
        return run_dir

    run_dir = find_latest_run_dir(ART, policy)
    if run_dir is not None:
        return run_dir

    raise RuntimeError(f"No runs found for policy={policy} under {ART}")

def _parse_dtype(s: str):
    """Translate a stored dtype description into a ``torch.dtype``.

    Args:
        s: ``None``, an actual ``torch.dtype`` (returned unchanged), or any
            object whose ``str()`` contains one of the recognised dtype
            names (``float64``, ``float32``, ``bfloat16``).

    Returns:
        The matching ``torch.dtype``; ``torch.float32`` for ``None`` and for
        any text that contains none of the recognised names.
    """
    if isinstance(s, torch.dtype):
        return s
    if s is None:
        return torch.float32

    text = str(s)
    # Checked in this order; the first substring hit wins.
    known = (
        ("float64", torch.float64),
        ("float32", torch.float32),
        ("bfloat16", torch.bfloat16),
    )
    for token, dtype in known:
        if token in text:
            return dtype
    return torch.float32

def load_params_from_run(run_dir: str, *, device="cpu"):
    """Rebuild the ``ModelParams`` saved in a run's ``config.json``.

    Args:
        run_dir: directory that contains ``config.json``.
        device: device to place the parameters on. If ``None``, the device
            recorded in the config is used (``"cpu"`` when absent).

    Returns:
        The result of ``ModelParams(...).to_torch()`` built from the stored
        ``"params"`` entries that are ``ModelParams`` dataclass fields, with
        ``device`` and ``dtype`` overridden as described.
    """
    config = load_json(os.path.join(run_dir, "config.json"))
    stored = config.get("params", {})

    # Drop any stored keys ModelParams does not declare, so configs with
    # extra entries do not break the constructor call.
    known_fields = ModelParams.__dataclass_fields__
    kwargs = {name: value for name, value in stored.items() if name in known_fields}

    if device is None:
        kwargs["device"] = stored.get("device", "cpu")
    else:
        kwargs["device"] = device
    # dtype is stored as text in JSON; convert it back to a torch.dtype.
    kwargs["dtype"] = _parse_dtype(stored.get("dtype"))

    return ModelParams(**kwargs).to_torch()

def load_net_from_run(run_dir: str, d_in: int, d_out: int):
    """Rebuild a run's ``PolicyNetwork`` and load its saved weights.

    Args:
        run_dir: directory that contains ``config.json`` and ``weights.pt``.
        d_in: first positional argument passed to ``PolicyNetwork``.
        d_out: second positional argument passed to ``PolicyNetwork``.

    Returns:
        The network with weights loaded, switched to eval mode.
    """
    train_cfg = load_json(os.path.join(run_dir, "config.json")).get("train_cfg", {})
    net = PolicyNetwork(
        d_in,
        d_out,
        hidden=tuple(train_cfg.get("hidden_layers", (512, 512))),
        activation=train_cfg.get("activation", "selu"),
    )

    checkpoint = load_torch(os.path.join(run_dir, "weights.pt"), map_location="cpu")
    # The file is usually a plain state_dict, but may wrap it under "state_dict".
    wrapped = isinstance(checkpoint, dict) and "state_dict" in checkpoint
    weights = checkpoint["state_dict"] if wrapped else checkpoint

    net.load_state_dict(weights)
    net.eval()
    return net

In [None]:
from src.steady_states import solve_flexprice_sss, solve_taylor_sss, solve_efficient_sss

def ann(x):
    """Scale ``x`` by 400 (used for the "Annualized percent" axes)."""
    return 400.0 * x


params0 = ModelParams(device="cpu", dtype=torch.float32)


def _params_with_transitions(p12, p21):
    """Copy ``params0`` with the given ``p12``/``p21`` and return ``.to_torch()``."""
    return ModelParams(
        beta=params0.beta, gamma=params0.gamma, omega=params0.omega,
        theta=params0.theta, eps=params0.eps, tau_bar=params0.tau_bar,
        rho_A=params0.rho_A, rho_tau=params0.rho_tau, rho_g=params0.rho_g,
        sigma_A=params0.sigma_A, sigma_tau=params0.sigma_tau, sigma_g=params0.sigma_g,
        g_bar=params0.g_bar, eta_bar=params0.eta_bar, bad_state=params0.bad_state,
        p12=p12, p21=p21,
        pi_bar=params0.pi_bar, psi=params0.psi,
        device="cpu", dtype=torch.float32,
    ).to_torch()


grid_p21 = np.linspace(0.02, 0.98, 40)
dur_bad = 1.0 / grid_p21  # average duration of bad times

# Regime 0 results of the baseline model and regime 1 results of the
# counterfactual model, one entry per p21 grid point.
pi_normal_baseline, r_normal_baseline = [], []
pi_bad_counter, r_bad_counter = [], []

for p21 in grid_p21:
    # Baseline model: vary p21 keeping baseline p12.
    p_base = _params_with_transitions(params0.p12, float(p21))
    flex_base = solve_flexprice_sss(p_base)
    tay_base = solve_taylor_sss(p_base, flex_base)
    normal_regime = tay_base.by_regime[0]
    pi_normal_baseline.append(float(normal_regime["pi"]))
    r_normal_baseline.append(float(normal_regime["r"]))

    # Counterfactual for bad-times line: p12 -> 1, vary p21.
    p_cf = _params_with_transitions(1.0, float(p21))
    flex_cf = solve_flexprice_sss(p_cf)
    tay_cf = solve_taylor_sss(p_cf, flex_cf)
    bad_regime = tay_cf.by_regime[1]
    pi_bad_counter.append(float(bad_regime["pi"]))
    r_bad_counter.append(float(bad_regime["r"]))

# Reference values for the efficient allocation (inflation is fixed at zero).
eff = solve_efficient_sss(params0)
pi_eff = 0.0
r_eff = float(eff["r_hat"])

# dur_bad = 1/p21 falls as p21 rises, so sort it to get an increasing x-axis
# and reorder every plotted series with the same permutation.
idx = np.argsort(dur_bad)
x = dur_bad[idx]

fig, ax = plt.subplots(1, 2, figsize=(11, 4))

# One entry per panel: axis, baseline series, counterfactual series,
# efficient-allocation level, title.
panels = (
    (ax[0], pi_normal_baseline, pi_bad_counter, pi_eff, "Figure 10a: SSS inflation sensitivity"),
    (ax[1], r_normal_baseline, r_bad_counter, r_eff, "Figure 10b: SSS real-rate sensitivity"),
)

for panel, normal_series, bad_series, eff_level, title in panels:
    panel.plot(
        x, ann(np.asarray(normal_series)[idx]),
        label="Normal SSS (baseline)", color="tab:blue",
    )
    panel.plot(
        x, ann(np.asarray(bad_series)[idx]),
        label="Bad SSS (counterfactual p12->1)", color="tab:red", linestyle="--",
    )
    panel.axhline(ann(eff_level), color="gray", linestyle=":", label="Efficient allocation")
    panel.set_title(title)
    panel.set_xlabel("Average bad-times duration (quarters)")
    panel.set_ylabel("Annualized percent")
    panel.legend()

plt.tight_layout()
plt.show()
