# 91_fig1_ergodic_distributions

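Figure 1: Ergodic distributions of inflation, the output gap, the real interest rate, and price dispersion under commitment, shown separately by regime (normal, s=0, vs. bad, s=1).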

In [None]:
import os, sys, json, numpy as np, torch, matplotlib.pyplot as plt
sys.path.append("..")

from src.config import ModelParams
from src.io_utils import load_json, load_npz, load_torch, load_selected_run, find_latest_run_dir
from src.deqn import PolicyNetwork

# Root folder that holds the saved training runs (relative to the notebook's working directory).
ART = os.path.join("..", "artifacts", "runs")

def get_run(policy: str) -> str:
    """Resolve the run directory to use for ``policy``.

    The directory returned by ``load_selected_run`` is preferred; when that yields
    ``None``, the one returned by ``find_latest_run_dir`` is used instead.

    Raises:
        RuntimeError: if neither lookup yields a run directory.
    """
    for lookup in (load_selected_run, find_latest_run_dir):
        run_dir = lookup(ART, policy)
        if run_dir is not None:
            return run_dir
    raise RuntimeError(f"No runs found for policy={policy} under {ART}")

def _parse_dtype(s: str):
    if s is None:
        return torch.float32
    if isinstance(s, torch.dtype):
        return s
    s = str(s)
    if "float64" in s:
        return torch.float64
    if "float32" in s:
        return torch.float32
    if "bfloat16" in s:
        return torch.bfloat16
    return torch.float32

def load_params_from_run(run_dir: str, *, device="cpu"):
    """Rebuild the run's ``ModelParams`` from the ``params`` section of its config.json.

    Only keys that are dataclass fields of ``ModelParams`` are forwarded. ``device`` and
    ``dtype`` are always set explicitly.

    Args:
        run_dir: Run directory containing ``config.json``.
        device: Target device, ``"cpu"`` by default. Pass ``None`` to use the device
            stored in the config (itself defaulting to ``"cpu"``).

    Returns:
        The result of ``ModelParams(...).to_torch()``.
    """
    stored = load_json(os.path.join(run_dir, "config.json")).get("params", {})
    field_names = ModelParams.__dataclass_fields__
    kwargs = {name: value for name, value in stored.items() if name in field_names}
    # Explicit overrides: an unset device defers to the config; the dtype string is parsed.
    kwargs["device"] = stored.get("device", "cpu") if device is None else device
    kwargs["dtype"] = _parse_dtype(stored.get("dtype"))
    return ModelParams(**kwargs).to_torch()

def load_net_from_run(run_dir: str, d_in: int, d_out: int):
    """Rebuild the run's ``PolicyNetwork`` and load its saved weights.

    The architecture comes from the ``train_cfg`` section of ``config.json``:
    ``hidden_layers`` defaults to ``(512, 512)`` and ``activation`` to ``"selu"``.
    Weights are read from ``weights.pt`` onto the CPU.

    Returns:
        The network, switched to eval mode.
    """
    train_cfg = load_json(os.path.join(run_dir, "config.json")).get("train_cfg", {})
    net = PolicyNetwork(
        d_in,
        d_out,
        hidden=tuple(train_cfg.get("hidden_layers", (512, 512))),
        activation=train_cfg.get("activation", "selu"),
    )
    checkpoint = load_torch(os.path.join(run_dir, "weights.pt"), map_location="cpu")
    # Accept either a bare state_dict or a wrapper dict carrying a "state_dict" entry.
    is_wrapped = isinstance(checkpoint, dict) and "state_dict" in checkpoint
    net.load_state_dict(checkpoint["state_dict"] if is_wrapped else checkpoint)
    net.eval()
    return net

# --- paper reporting helpers ---
def ann(x):
    """Convert a quarterly rate in decimal units to an annualized percent.

    Multiplies by 400: 4 quarters per year times 100 to convert a decimal to a percent.
    Works on scalars and elementwise on NumPy arrays.
    """
    return 400.0 * x


In [None]:
import warnings

run = get_run("commitment")
params = load_params_from_run(run)
sim_path = os.path.join(run, "sim_paths.npz")
if not os.path.exists(sim_path):
    raise FileNotFoundError(f"Missing sim_paths.npz in {run}. Re-run training notebook with simulation enabled.")
sim = load_npz(sim_path)

if "i" not in sim:
    raise RuntimeError("Figure 1 requires nominal rates i in sim_paths.npz (run commitment simulation with compute_implied_i=True).")

from src.steady_states import solve_efficient_sss
from src.metrics import output_gap_from_consumption

# Flatten every series the same way so the boolean regime masks line up elementwise.
s = sim["s"].reshape(-1).astype(np.int64)
pi = sim["pi"].reshape(-1)
i_nom = sim["i"].reshape(-1)
Delta = sim["Delta"].reshape(-1)

eff = solve_efficient_sss(params)
x_gap = output_gap_from_consumption(sim, eff, params=params, time_varying=True)  # log points
# The helper may return the simulation's original (possibly multi-dimensional) shape.
# Flatten it like the other series so it can be indexed with the regime masks.
x_gap = np.asarray(x_gap).reshape(-1)
if x_gap.size != s.size:
    raise ValueError(f"Output gap has {x_gap.size} entries but the regime series has {s.size}.")

# Real rate via the Fisher relation, aligned to t where pi_{t+1} exists.
# NOTE(review): the t -> t+1 pairing is taken on the flattened series. If the sim arrays hold
# several paths (ndim > 1), confirm the time axis so that pairs do not straddle path boundaries.
r_real = ((1.0 + i_nom[:-1]) / (1.0 + pi[1:])) - 1.0
s_r = s[:-1]

REGIMES = ((0, "normal (s=0)"), (1, "bad (s=1)"))
N_BINS = 60


def _plot_by_regime(axis, values, regimes, title, xlabel, bins=N_BINS):
    """Overlay one histogram per regime of ``values`` on ``axis``.

    Both regimes share common bin edges, computed from their pooled finite values, so
    bar positions and widths are directly comparable. Non-finite values are excluded
    and reported with a warning.
    """
    values = np.asarray(values, dtype=float)
    in_regimes = np.isin(regimes, [code for code, _ in REGIMES])
    finite = np.isfinite(values)
    n_dropped = int(np.count_nonzero(in_regimes & ~finite))
    if n_dropped:
        warnings.warn(f"{title}: dropped {n_dropped} non-finite values")
    keep = in_regimes & finite
    pooled = values[keep]
    # Fall back to a plain bin count when there is nothing finite to size the edges from.
    edges = np.histogram_bin_edges(pooled, bins=bins) if pooled.size else bins
    for code, label in REGIMES:
        axis.hist(values[keep & (regimes == code)], bins=edges, alpha=0.60, label=label)
    axis.set_title(title)
    axis.set_xlabel(xlabel)
    axis.legend()


fig, ax = plt.subplots(2, 2, figsize=(12, 8))

_plot_by_regime(ax[0, 0], ann(pi), s, "Figure 1a: Inflation distribution", "Annualized percent")
_plot_by_regime(ax[0, 1], 100.0 * x_gap, s, "Figure 1b: Output gap distribution", "Percent")
_plot_by_regime(ax[1, 0], ann(r_real), s_r, "Figure 1c: Real rate distribution", "Annualized percent")
_plot_by_regime(ax[1, 1], Delta, s, "Figure 1d: Price dispersion distribution", "Level")

plt.tight_layout()
plt.show()
