# 96_fig6_asymmetry

Figure 6: asymmetry between normal→bad and bad→normal regime switches under commitment, computed from deterministic (forced) regime paths.

In [None]:
import os, sys, json, numpy as np, torch, matplotlib.pyplot as plt
import pathlib

def _find_project_root():
    here = pathlib.Path.cwd().resolve()
    for p in [here, *here.parents]:
        if (p / "src").is_dir():
            return p
    # Common Google Colab clone location
    cand = pathlib.Path("/content/econml")
    if (cand / "src").is_dir():
        return cand
    raise RuntimeError("Could not find project root containing src/. If on Colab, clone repo to /content/econml.")

PROJECT_ROOT = _find_project_root()
# Prepend the checkout to sys.path (only once) so `import src.*` resolves to it.
_root_entry = str(PROJECT_ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
del _root_entry

from src.config import ModelParams
from src.io_utils import load_json, load_npz, load_torch, load_selected_run, find_latest_run_dir
from src.deqn import PolicyNetwork

# Directory that get_run() searches for per-policy training runs.
ART = os.fspath(PROJECT_ROOT.joinpath("artifacts", "runs"))

def get_run(policy: str) -> str:
    """Return the run directory to use for *policy*.

    Prefers the run recorded by ``load_selected_run``. If that yields ``None``,
    falls back to ``find_latest_run_dir``. Both are searched under ``ART``.

    Raises:
        RuntimeError: if neither lookup finds a run.
    """
    run_dir = load_selected_run(ART, policy)
    if run_dir is None:
        run_dir = find_latest_run_dir(ART, policy)
    if run_dir is not None:
        return run_dir
    raise RuntimeError(f"No runs found for policy={policy} under {ART}")

def _parse_dtype(s: str):
    if s is None:
        return torch.float32
    if isinstance(s, torch.dtype):
        return s
    s = str(s)
    if "float64" in s:
        return torch.float64
    if "float32" in s:
        return torch.float32
    if "bfloat16" in s:
        return torch.bfloat16
    return torch.float32

def load_params_from_run(run_dir: str, *, device="cpu"):
    """Rebuild the model parameters stored in ``<run_dir>/config.json``.

    Reads the ``params`` section of the run config and keeps only the keys
    that are dataclass fields of ``ModelParams``. It then sets:

    * ``device`` to the *device* argument. Pass ``device=None`` to use the
      device recorded in the config instead (defaulting to ``"cpu"``).
    * ``dtype`` to the config's dtype, normalized through ``_parse_dtype``.

    Returns:
        The result of ``ModelParams(**fields).to_torch()``.
    """
    saved = load_json(os.path.join(run_dir, "config.json")).get("params", {})
    allowed = ModelParams.__dataclass_fields__
    fields = {name: value for name, value in saved.items() if name in allowed}
    fields["device"] = saved.get("device", "cpu") if device is None else device
    fields["dtype"] = _parse_dtype(saved.get("dtype"))
    return ModelParams(**fields).to_torch()

def load_net_from_run(run_dir: str, d_in: int, d_out: int):
    """Restore the trained ``PolicyNetwork`` saved in *run_dir*, in eval mode.

    The architecture comes from ``train_cfg`` in ``config.json``:
    ``hidden_layers`` defaults to ``(512, 512)`` and ``activation`` to
    ``"selu"``. The network is placed on CPU in the dtype recorded under
    ``params.dtype`` (via ``_parse_dtype``). Weights are read from
    ``weights.pt``; a checkpoint of the form ``{"state_dict": ...}`` is
    unwrapped first.

    Args:
        run_dir: run directory containing ``config.json`` and ``weights.pt``.
        d_in: network input width.
        d_out: network output width.
    """
    cfg = load_json(os.path.join(run_dir, "config.json"))
    train_cfg = cfg.get("train_cfg", {})
    model = PolicyNetwork(
        d_in,
        d_out,
        hidden=tuple(train_cfg.get("hidden_layers", (512, 512))),
        activation=train_cfg.get("activation", "selu"),
    )
    model = model.to(device="cpu", dtype=_parse_dtype(cfg.get("params", {}).get("dtype")))

    checkpoint = load_torch(os.path.join(run_dir, "weights.pt"), map_location="cpu")
    # Accept both a bare state_dict and a wrapper dict holding one.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    model.load_state_dict(checkpoint)
    model.eval()
    return model

# --- paper reporting helpers ---
def ann(x):
    """Convert a quarterly net rate into annualized percent (simple, not compounded).

    ``400 * x`` = 4 quarters per year times 100 for percent. Scalar
    multiplication also applies elementwise to NumPy arrays and torch tensors.
    """
    return 400.0 * x


In [None]:
from src.steady_states import solve_commitment_sss_from_policy_switching, solve_efficient_sss
from src.experiments import DeterministicPathSpec, simulate_deterministic_path

# Load the selected commitment run and solve its regime-specific SSS.
# Then simulate two forced regime paths:
#   normal->bad starting from the normal-regime SSS, and
#   bad->normal starting from the bad-regime SSS.
run_dir = get_run("commitment")
params = load_params_from_run(run_dir)
net = load_net_from_run(run_dir, 7, 13)
sss = solve_commitment_sss_from_policy_switching(params, net)

# Build the initial states in the network's own dtype. load_net_from_run casts
# the net to the dtype recorded in config.json, so a hard-coded float32 state
# would mismatch a float64 (or bfloat16) run.
_state_dtype = next(net.parameters()).dtype


def _sss_state(regime: int) -> torch.Tensor:
    """Return the (1, 7) state tensor at *regime*'s SSS, with the regime flag in slot 4."""
    ss = sss.by_regime[regime]
    return torch.tensor(
        [[
            float(ss["Delta_prev"]),
            float(ss["logA"]),
            float(ss["loggtilde"]),
            float(ss["xi"]),
            float(regime),
            float(ss["vartheta_prev"]),
            float(ss["varrho_prev"]),
        ]],
        dtype=_state_dtype,
    )


x0_n = _sss_state(0)
x0_b = _sss_state(1)

T = 40
path_nb = simulate_deterministic_path(
    params,
    "commitment",
    net,
    x0=x0_n,
    spec=DeterministicPathSpec(T=T, regime_path=[0] + [1] * T),
    compute_implied_i=True,
)
path_bn = simulate_deterministic_path(
    params,
    "commitment",
    net,
    x0=x0_b,
    spec=DeterministicPathSpec(T=T, regime_path=[1] + [0] * T),
    compute_implied_i=True,
)

# regime_path applies to s_{t+1}; use index 1 as aligned impact t=0.
# Take column 0 of each simulated series. x0 has a single row, so the path
# arrays are presumably laid out as (time, batch) — TODO confirm against
# simulate_deterministic_path.
pi_nb_all = path_nb["pi"][:, 0]
pi_bn_all = path_bn["pi"][:, 0]
c_nb_all = path_nb["c"][:, 0]
c_bn_all = path_bn["c"][:, 0]
Delta_nb_all = path_nb["Delta"][:, 0]
Delta_bn_all = path_bn["Delta"][:, 0]
# "i" is available because the paths were simulated with compute_implied_i=True.
i_nb_all = path_nb["i"][:, 0]
i_bn_all = path_bn["i"][:, 0]

# Ex-post real rate r_t = (1 + i_t) / (1 + pi_{t+1}) - 1.
# It pairs i_t with next-period inflation, so r has one fewer entry than i/pi.
r_nb_all = ((1.0 + i_nb_all[:-1]) / (1.0 + pi_nb_all[1:])) - 1.0
r_bn_all = ((1.0 + i_bn_all[:-1]) / (1.0 + pi_bn_all[1:])) - 1.0

# Use a common horizon and align from impact period.
# Every series is sliced as [1:1 + n], which requires n <= len(r_*_all) - 1
# (the shortest series).
n = min(len(r_nb_all) - 1, len(r_bn_all) - 1)
if n <= 0:
    raise RuntimeError("Figure 6 alignment failed: too short deterministic paths.")

t = np.arange(n)

# NOTE(review): inflation, output gap and real rate are kept in levels, while
# Delta (below) is measured relative to its impact value. The mirrored
# "-(bad->normal)" curves therefore only overlay exactly when the starting
# levels coincide. Confirm this matches the intended Figure 6 definition.
pi_nb = pi_nb_all[1:1 + n]
pi_bn = pi_bn_all[1:1 + n]

# Output gap: log deviation of c from the efficient steady-state level c_hat.
eff = solve_efficient_sss(params)
c_hat = float(eff["c_hat"])
x_nb = np.log(c_nb_all[1:1 + n]) - np.log(c_hat)
x_bn = np.log(c_bn_all[1:1 + n]) - np.log(c_hat)

r_nb = r_nb_all[1:1 + n]
r_bn = r_bn_all[1:1 + n]

# Price dispersion relative to its aligned-impact value, so both curves start at 0.
Delta_nb = Delta_nb_all[1:1 + n]
Delta_bn = Delta_bn_all[1:1 + n]
dDelta_nb = Delta_nb - Delta_nb[0]
dDelta_bn = Delta_bn - Delta_bn[0]

# Figure 6 panels: (axes position, title, y-label, normal->bad series, bad->normal series).
# The bad->normal response is drawn sign-flipped. If the responses were
# symmetric, the solid and dashed curves would coincide.
_fig6_panels = [
    ((0, 0), "Figure 6a: Inflation", "Annualized percent", ann(pi_nb), ann(pi_bn)),
    ((0, 1), "Figure 6b: Output gap", "Percent", 100.0 * x_nb, 100.0 * x_bn),
    ((1, 0), "Figure 6c: Real rate", "Annualized percent", ann(r_nb), ann(r_bn)),
    ((1, 1), "Figure 6d: Price dispersion (deviation)", "Delta - Delta(impact)", dDelta_nb, dDelta_bn),
]

fig, ax = plt.subplots(2, 2, figsize=(12, 8))

for (row, col), title, ylabel, y_nb, y_bn in _fig6_panels:
    panel = ax[row, col]
    panel.plot(t, y_nb, label="normal->bad")
    panel.plot(t, -y_bn, label="-(bad->normal)", linestyle="--")
    # Explicit color: axhline otherwise uses rcParams["lines.color"] ("C0"),
    # the same color as the first series, which hides the zero reference line.
    panel.axhline(0, color="black", linewidth=1)
    panel.set_title(title)
    panel.set_xlabel("t")
    panel.set_ylabel(ylabel)
    panel.legend()

plt.tight_layout()
plt.show()