# 92_fig2_transition_commitment_vs_discretion

Figure 2: forced transition normal→bad, commitment vs discretion, deterministic innovations (eps=0).

In [None]:
import os, sys, json, numpy as np, torch, matplotlib.pyplot as plt
import pathlib

def _find_project_root():
    """Return the first directory that contains a ``src/`` sub-directory.

    Candidates are tried in this order: the resolved current working
    directory, each of its parents (nearest first), and finally the common
    Google Colab clone location ``/content/econml``.

    Raises:
        RuntimeError: if none of the candidates contains ``src/``.
    """
    start = pathlib.Path.cwd().resolve()
    colab_clone = pathlib.Path("/content/econml")
    # Lazy search: the Colab fallback is only probed when the upward walk fails.
    candidates = (start, *start.parents, colab_clone)
    root = next((d for d in candidates if (d / "src").is_dir()), None)
    if root is None:
        raise RuntimeError("Could not find project root containing src/. If on Colab, clone repo to /content/econml.")
    return root

# Resolve the repository root once and make it importable, so that the
# ``from src...`` imports below work regardless of the notebook's cwd.
PROJECT_ROOT = _find_project_root()
_project_root_str = str(PROJECT_ROOT)
if _project_root_str not in sys.path:
    # Prepend (not append) so this checkout is found before other entries.
    sys.path.insert(0, _project_root_str)

from src.config import ModelParams
from src.io_utils import load_json, load_npz, load_torch, load_selected_run, find_latest_run_dir
from src.deqn import PolicyNetwork

# Root directory under which the run directories are looked up.
ART = str(PROJECT_ROOT.joinpath("artifacts", "runs"))
# Use CUDA when torch reports it as available, otherwise stay on the CPU.
COMPUTE_DEVICE = "cpu"
if torch.cuda.is_available():
    COMPUTE_DEVICE = "cuda"
print(f"Compute device: {COMPUTE_DEVICE}")

def get_run(policy: str) -> str:
    """Return the run directory to use for ``policy``.

    The explicitly selected run (``load_selected_run``) takes precedence;
    when that yields ``None`` the latest run directory
    (``find_latest_run_dir``) is used instead. Both are looked up under
    ``ART``.

    Raises:
        RuntimeError: if neither lookup returns a run directory.
    """
    for resolve in (load_selected_run, find_latest_run_dir):
        run_dir = resolve(ART, policy)
        if run_dir is not None:
            return run_dir
    raise RuntimeError(f"No runs found for policy={policy} under {ART}")

def _parse_dtype(s: str):
    """Map a stored dtype description to a ``torch.dtype``.

    ``None`` yields ``torch.float32`` and an actual ``torch.dtype`` is
    returned unchanged. Anything else is converted with ``str()`` and
    searched for the substrings ``float64``, ``float32`` and ``bfloat16``
    (in that order). Unrecognized descriptions fall back to
    ``torch.float32``.
    """
    if s is None:
        return torch.float32
    if isinstance(s, torch.dtype):
        return s

    text = str(s)
    # Ordered table: the first matching substring wins.
    known = (
        ("float64", torch.float64),
        ("float32", torch.float32),
        ("bfloat16", torch.bfloat16),
    )
    for token, dtype in known:
        if token in text:
            return dtype
    return torch.float32

def load_params_from_run(run_dir: str, *, device=None):
    """Rebuild the model parameters stored in ``<run_dir>/config.json``.

    Only entries of the config's ``"params"`` section whose names are
    ``ModelParams`` dataclass fields are forwarded; ``device`` and ``dtype``
    are always overridden (``dtype`` is parsed from the stored value via
    ``_parse_dtype``).

    Args:
        run_dir: directory containing ``config.json``.
        device: target device; defaults to ``COMPUTE_DEVICE`` when ``None``.

    Returns:
        The result of ``ModelParams(...).to_torch()``.
    """
    config = load_json(os.path.join(run_dir, "config.json"))
    stored = config.get("params", {})
    target_device = COMPUTE_DEVICE if device is None else device

    known_fields = ModelParams.__dataclass_fields__
    kwargs = {name: value for name, value in stored.items() if name in known_fields}
    # Stored device/dtype entries (if any) are replaced by the runtime choices.
    kwargs.update(device=target_device, dtype=_parse_dtype(stored.get("dtype")))
    return ModelParams(**kwargs).to_torch()

def load_net_from_run(run_dir: str, d_in: int, d_out: int, *, device=None):
    """Rebuild a ``PolicyNetwork`` from a run directory and load its weights.

    The architecture (``hidden_layers``, ``activation``) is read from the
    ``"train_cfg"`` section of ``config.json`` with defaults ``(512, 512)``
    and ``"selu"``. The network dtype is taken from the ``"params"``
    section's ``dtype`` entry.

    Args:
        run_dir: directory containing ``config.json`` and ``weights.pt``.
        d_in: network input dimension.
        d_out: network output dimension.
        device: target device; defaults to ``COMPUTE_DEVICE`` when ``None``.

    Returns:
        The network with weights loaded, after ``eval()`` has been called.
    """
    config = load_json(os.path.join(run_dir, "config.json"))
    train_cfg = config.get("train_cfg", {})
    param_cfg = config.get("params", {})
    target_device = COMPUTE_DEVICE if device is None else device

    net = PolicyNetwork(
        d_in,
        d_out,
        hidden=tuple(train_cfg.get("hidden_layers", (512, 512))),
        activation=train_cfg.get("activation", "selu"),
    )
    net = net.to(device=target_device, dtype=_parse_dtype(param_cfg.get("dtype")))

    # Weights are read onto the CPU first; load_state_dict copies them into
    # the parameters that already live on the target device.
    checkpoint = load_torch(os.path.join(run_dir, "weights.pt"), map_location="cpu")
    # A checkpoint is either a plain state_dict or a dict wrapping one under
    # the "state_dict" key.
    is_wrapped = isinstance(checkpoint, dict) and "state_dict" in checkpoint
    weights = checkpoint["state_dict"] if is_wrapped else checkpoint
    net.load_state_dict(weights)
    net.eval()
    return net

# --- paper reporting helpers ---
def ann(x):
    """Scale ``x`` by 400: annualized percent (quarterly -> annual)."""
    return 400.0 * x


In [None]:
from src.steady_states import solve_commitment_sss_from_policy, solve_discretion_sss_from_policy, solve_efficient_sss
from src.experiments import DeterministicPathSpec, simulate_deterministic_path

# Resolve the run directory of each policy first, then load the calibration
# stored alongside each run.
run_comm, run_disc = map(get_run, ("commitment", "discretion"))
params_comm, params_disc = map(load_params_from_run, (run_comm, run_disc))

# In the paper, policies are compared under identical calibration.
def _assert_same_calibration(p1, p2, *, atol=1e-12):
    """Raise unless ``p1`` and ``p2`` agree on every calibration field.

    Every ``ModelParams`` dataclass field except ``device`` and ``dtype`` is
    compared. Pairs of ``int``/``float`` values are compared with absolute
    tolerance ``atol``; all other values are compared with ``!=``.

    Args:
        p1: parameters of the commitment run.
        p2: parameters of the discretion run.
        atol: absolute tolerance for numeric fields.

    Raises:
        RuntimeError: listing every differing field with both values.
    """
    numeric = (int, float)

    def _mismatch(left, right):
        # Numeric pairs get a tolerance; everything else must compare equal.
        if isinstance(left, numeric) and isinstance(right, numeric):
            return abs(float(left) - float(right)) > atol
        return left != right

    mismatches = []
    for name in ModelParams.__dataclass_fields__:
        if name in ("device", "dtype"):
            continue
        left, right = getattr(p1, name), getattr(p2, name)
        if _mismatch(left, right):
            mismatches.append((name, left, right))

    if not mismatches:
        return

    details = ", ".join(
        f"{name}: commitment={left}, discretion={right}" for name, left, right in mismatches
    )
    raise RuntimeError(
        "Figure 2 requires identical calibration across commitment/discretion runs. "
        f"Differences found: {details}"
    )

# Both runs must share one calibration; from here on a single `params` is used.
_assert_same_calibration(params_comm, params_disc)
params = params_comm

# Input/output sizes differ by policy: the commitment state below carries two
# extra entries (vartheta_prev, varrho_prev), hence 7 inputs instead of 5.
net_comm = load_net_from_run(run_comm, 7, 13)
net_disc = load_net_from_run(run_disc, 5, 11)

comm_sss = solve_commitment_sss_from_policy(params, net_comm)
disc_sss = solve_discretion_sss_from_policy(params, net_disc)

# Initial states are read from the regime-0 entry of each steady-state object.
_SHARED_STATE_KEYS = ("Delta_prev", "logA", "loggtilde", "xi")
_COMM_EXTRA_KEYS = ("vartheta_prev", "varrho_prev")
_comm0 = comm_sss.by_regime[0]
_disc0 = disc_sss.by_regime[0]

# The fifth state entry is fixed at 0.0 for both policies
# (presumably the initial regime indicator -- TODO confirm against src.experiments).
_comm_row = (
    [float(_comm0[key]) for key in _SHARED_STATE_KEYS]
    + [0.0]
    + [float(_comm0[key]) for key in _COMM_EXTRA_KEYS]
)
_disc_row = [float(_disc0[key]) for key in _SHARED_STATE_KEYS] + [0.0]

x0_comm = torch.tensor([_comm_row], dtype=torch.float32)
x0_disc = torch.tensor([_disc_row], dtype=torch.float32)

# Paper-like horizon/axis: show a short pre-shock window and post-switch dynamics.
pre = 5        # number of pre-switch periods drawn left of the switch date
n_post = 20    # number of post-switch periods plotted
# One period more than plotted is simulated: the realized real rate of period t
# below needs inflation of period t + 1.
T_path = n_post + 1
# Regime 0 in the initial period, regime 1 in every period afterwards; all
# innovations are switched off (eps = 0).
spec = DeterministicPathSpec(T=T_path, epsA=0.0, epsg=0.0, epst=0.0, regime_path=[0] + [1] * T_path)

path_comm = simulate_deterministic_path(params, "commitment", net_comm, x0=x0_comm, spec=spec, compute_implied_i=True)
path_disc = simulate_deterministic_path(params, "discretion", net_disc, x0=x0_disc, spec=spec, compute_implied_i=True)

# Column 0 of every path variable is used (a single simulated path is assumed
# -- TODO confirm the layout returned by simulate_deterministic_path).
pi_comm_all = path_comm["pi"][:, 0]
pi_disc_all = path_disc["pi"][:, 0]
i_comm_all = path_comm["i"][:, 0]
i_disc_all = path_disc["i"][:, 0]
c_comm_all = path_comm["c"][:, 0]
c_disc_all = path_disc["c"][:, 0]
Delta_comm_all = path_comm["Delta"][:, 0]
Delta_disc_all = path_disc["Delta"][:, 0]

# Realized real rate of period t: (1 + i_t) / (1 + pi_{t+1}) - 1.
r_comm_all = ((1.0 + i_comm_all[:-1]) / (1.0 + pi_comm_all[1:])) - 1.0
r_disc_all = ((1.0 + i_disc_all[:-1]) / (1.0 + pi_disc_all[1:])) - 1.0

# Index 0 is the pre-switch period; indices 1..n_post are the plotted
# post-switch periods.
pi_comm_post = pi_comm_all[1:1 + n_post]
pi_disc_post = pi_disc_all[1:1 + n_post]
r_comm_post = r_comm_all[1:1 + n_post]
r_disc_post = r_disc_all[1:1 + n_post]
Delta_comm_post = Delta_comm_all[1:1 + n_post]
Delta_disc_post = Delta_disc_all[1:1 + n_post]

# Slicing silently truncates when the simulated path is too short, so verify
# the lengths explicitly -- for both policies, not only commitment.
for _pi_post, _r_post in ((pi_comm_post, r_comm_post), (pi_disc_post, r_disc_post)):
    if len(_pi_post) != n_post or len(_r_post) != n_post:
        raise RuntimeError("Figure 2 alignment failed: deterministic path length mismatch.")

eff = solve_efficient_sss(params)
c_hat = float(eff["c_hat"])
# Output gap: log deviation of consumption from the efficient level c_hat.
x_comm_post = np.log(c_comm_all[1:1 + n_post]) - np.log(c_hat)
x_disc_post = np.log(c_disc_all[1:1 + n_post]) - np.log(c_hat)

# Pre-switch real rate. It is built from period-0 values only, so the flat
# pre-shock segment shows the pre-switch level like the other three panels.
# Using r_*_post[0] here would repeat the first post-switch value and hide the
# jump at the switch date; r_*_all[0] is unsuitable too because it already
# mixes in post-switch inflation.
r_comm_pre = (1.0 + i_comm_all[0]) / (1.0 + pi_comm_all[0]) - 1.0
r_disc_pre = (1.0 + i_disc_all[0]) / (1.0 + pi_disc_all[0]) - 1.0

# Each plotted series = `pre` copies of the period-0 value + post-switch path.
pi_comm = np.concatenate([np.full(pre, pi_comm_all[0]), pi_comm_post])
pi_disc = np.concatenate([np.full(pre, pi_disc_all[0]), pi_disc_post])
x_comm = np.concatenate([np.full(pre, np.log(c_comm_all[0]) - np.log(c_hat)), x_comm_post])
x_disc = np.concatenate([np.full(pre, np.log(c_disc_all[0]) - np.log(c_hat)), x_disc_post])
r_comm = np.concatenate([np.full(pre, r_comm_pre), r_comm_post])
r_disc = np.concatenate([np.full(pre, r_disc_pre), r_disc_post])
Delta_comm = np.concatenate([np.full(pre, Delta_comm_all[0]), Delta_comm_post])
Delta_disc = np.concatenate([np.full(pre, Delta_disc_all[0]), Delta_disc_post])

# Time axis: negative values are the pre-switch window, 0 is the first
# post-switch period.
t = np.arange(-pre, n_post)

fig, ax = plt.subplots(2, 2, figsize=(12, 8))

# One entry per panel:
# (axes, commitment series, discretion series, reference level, title, y-label)
_panels = [
    (ax[0, 0], ann(pi_comm), ann(pi_disc), 0.0, "Figure 2a: Inflation", "Ann. perc."),
    (ax[0, 1], 100.0 * x_comm, 100.0 * x_disc, 0.0, "Figure 2b: Output gap", "Perc. of log"),
    (ax[1, 0], ann(r_comm), ann(r_disc), 0.0, "Figure 2c: Real interest rate", "Ann. perc."),
    (ax[1, 1], Delta_comm, Delta_disc, 1.0, "Figure 2d: Price dispersion", "Value"),
]

for _axis, _y_comm, _y_disc, _ref_level, _title, _ylabel in _panels:
    _axis.plot(t, _y_comm, label="Commitment")
    _axis.plot(t, _y_disc, label="Discretion", linestyle="--")
    # Thin horizontal reference line (0 for the rate/gap panels, 1 for dispersion).
    _axis.axhline(_ref_level, linewidth=1)
    _axis.set_title(_title)
    _axis.set_xlabel("Time in quarters")
    _axis.set_ylabel(_ylabel)
    _axis.legend()

plt.tight_layout()
plt.show()
