In [None]:
import os, sys, json
import pathlib
from dataclasses import replace

# Set the CUDA allocator option before torch is imported.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import numpy as np
import pandas as pd
import torch

def _find_project_root():
    here = pathlib.Path.cwd().resolve()
    for p in [here, *here.parents]:
        if (p / "src").is_dir():
            return p
    # Common Google Colab clone location
    cand = pathlib.Path("/content/econml")
    if (cand / "src").is_dir():
        return cand
    raise RuntimeError("Could not find project root containing src/. If on Colab, clone repo to /content/econml.")

# Resolve the project root once and make it importable: its string form is
# put at the front of sys.path unless an identical entry is already present.
PROJECT_ROOT = _find_project_root()
_project_root_entry = str(PROJECT_ROOT)
if _project_root_entry not in sys.path:
    sys.path.insert(0, _project_root_entry)

from src.config import ModelParams, TrainConfig
from src.deqn import PolicyNetwork, Trainer, simulate_paths
from src.io_utils import make_run_dir, save_run_metadata, save_selected_run, pack_config, save_torch, save_csv, save_json, save_npz
from src.metrics import residual_quality

# ---------- config ----------
# Artifacts are written under <project root>/artifacts; CUDA is used when available.
ARTIFACTS_ROOT = str(PROJECT_ROOT / "artifacts")
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)
params = ModelParams(device=DEVICE, dtype=torch.float32)

cfg_seed = 0
# A probe config (built without a run dir) supplies the mode and seed used to
# name the run directory; the base config is then built with that directory.
cfg_probe = TrainConfig.full(seed=cfg_seed)
run_dir = make_run_dir(ARTIFACTS_ROOT, "commitment", tag=cfg_probe.mode, seed=cfg_probe.seed)
cfg_base = TrainConfig.full(seed=cfg_seed, run_dir=run_dir, artifacts_root=ARTIFACTS_ROOT)

# Commitment-only training override for stability: easier warm-up in phase1, strict refinement in phase2.
_phase1_warmup = replace(cfg_base.phase1, eps_stop=1e-5)
_phase2_strict = replace(cfg_base.phase2, eps_stop=1e-8, batch_size=96, lr=1e-6)
cfg = replace(
    cfg_base,
    strict_eps_max_steps=120_000,
    n_path=192,
    phase1=_phase1_warmup,
    phase2=_phase2_strict,
)

# Record the parameters and effective config alongside the run.
_run_metadata = pack_config(params, cfg, extra={"policy": "commitment"})
save_run_metadata(run_dir, _run_metadata)
print("Run dir:", run_dir)


In [None]:

# ---------- model ----------
# Input/output widths of the policy network.
# NOTE(review): 7 and 13 are hard-coded; presumably they must match the state
# and policy dimensions Trainer expects for "commitment" -- confirm in src.deqn.
d_in = 7
d_out = 13
net = PolicyNetwork(d_in, d_out, hidden=cfg.hidden_layers, activation=cfg.activation)

# First-pass trainer for the commitment policy; no per-regime rbar is supplied.
trainer = Trainer(params=params, cfg=cfg, policy="commitment", net=net, rbar_by_regime=None)



In [None]:

# ---------- train ----------
# First-pass training. No commitment SSS is supplied yet (commitment_sss=None);
# it is solved from the trained policy in a later cell.
losses = trainer.train(
    commitment_sss=None,
    n_path=cfg.n_path,
    n_paths_per_step=cfg.n_paths_per_step,
)

# save weights and log
# (`pd` comes from the notebook's top import block rather than a mid-cell import.)
save_torch(os.path.join(run_dir, "weights.pt"), trainer.net.state_dict())
df = pd.DataFrame({"iter": np.arange(len(losses)), "loss": losses})
save_csv(os.path.join(run_dir, "train_log.csv"), df)

# quality on a fresh validation batch sampled from the model's simulated state distribution
# Discretion residuals require autograd through Delta-derivative terms, so only
# that policy keeps gradients enabled; any other policy runs in inference mode.
if trainer.policy == "discretion":
    ctx = torch.enable_grad()
else:
    ctx = torch.inference_mode()

# optional short burn-in for validation states (kept small; training itself is path-based)
val_burn = int(getattr(cfg, "val_burn_in", 200))
with ctx:
    x_val = trainer.simulate_initial_state(int(cfg.val_size), commitment_sss=None)
    for _ in range(val_burn):
        x_val = trainer._step_state(x_val)
    # Evaluate residuals on the burned-in states, then move them to a NumPy array on the CPU.
    resid = trainer._residuals(x_val)
    resid = resid.detach().cpu().numpy()

# Summarise the residuals (tolerance falls back to 1e-3 when cfg has no report_tol) and persist.
q = residual_quality(resid, tol=getattr(cfg, "report_tol", 1e-3))
save_json(os.path.join(run_dir, "train_quality.json"), q)
print("Train quality:", q)

# Commitment SSS and timeless simulations are computed in the next cells.


In [None]:
# ---------- Commitment SSS and diagnostics (timeless bootstrap + refinement) ----------
from dataclasses import replace
from src.steady_states import solve_commitment_sss_from_policy
from src.sanity_checks import fixed_point_check, residuals_check_switching_consistent

# 1) Bootstrap: compute switching-consistent SSS from the first-pass network.
# This gives lagged multipliers (vartheta_prev, varrho_prev) needed for timeless initialization.
comm_sss_bootstrap = solve_commitment_sss_from_policy(params, trainer.net)

# 2) Timeless refinement: re-train with commitment states initialized at SSS multipliers.
# We keep phase-2-only refinement (same equations/objective, cheaper than full re-train).
timeless_refine_rounds = 1
timeless_phase2_steps = int(cfg.phase2.steps)
# Refinement config: phase 1 gets zero steps, phase 2 keeps the step count read above.
cfg_tl = replace(cfg, phase1=replace(cfg.phase1, steps=0), phase2=replace(cfg.phase2, steps=timeless_phase2_steps))

timeless_losses = []
comm_sss_final = comm_sss_bootstrap
for rr in range(timeless_refine_rounds):
    # New Trainer on the refinement config, handed the previous trainer's network.
    trainer = Trainer(params=params, cfg=cfg_tl, policy="commitment", net=trainer.net, rbar_by_regime=None)
    # Train with states initialized from the most recent SSS estimate.
    ft_losses = trainer.train(
        commitment_sss=comm_sss_final.by_regime,
        n_path=cfg_tl.n_path,
        n_paths_per_step=cfg_tl.n_paths_per_step,
    )
    timeless_losses.extend(ft_losses)
    # Re-solve the SSS from the refined policy so the next round (and later cells) use it.
    comm_sss_final = solve_commitment_sss_from_policy(params, trainer.net)
    if len(ft_losses) == 0:
        print(f"Timeless refine round {rr+1}: no extra steps were run.")
    else:
        print(f"Timeless refine round {rr+1}: steps={len(ft_losses)}, best_loss={min(ft_losses):.3e}")

# Persist refined network and timeless-refine log.
# Note: this writes to the same "weights.pt" path that the first-pass weights were saved to.
# (`pd` comes from the notebook's top import block rather than a mid-cell import.)
save_torch(os.path.join(run_dir, "weights.pt"), trainer.net.state_dict())
save_csv(
    os.path.join(run_dir, "train_log_timeless_refine.csv"),
    pd.DataFrame({"iter": np.arange(len(timeless_losses)), "loss": timeless_losses}),
)

# Save final switching-consistent timeless SSS
save_json(os.path.join(run_dir, 'sss_policy_fixed_point.json'), {'policy':'commitment','by_regime': comm_sss_final.by_regime})

# Refresh validation quality on SSS-initialized commitment states
val_burn = int(getattr(cfg, "val_burn_in", 200))
with torch.inference_mode():
    # Start the validation batch from the final SSS, then advance it val_burn steps.
    x_val = trainer.simulate_initial_state(int(cfg.val_size), commitment_sss=comm_sss_final.by_regime)
    for _ in range(val_burn):
        x_val = trainer._step_state(x_val)
    resid = trainer._residuals(x_val)
    resid = resid.detach().cpu().numpy()

# Written to the same train_quality.json file name as the first-pass quality report.
q = residual_quality(resid, tol=getattr(cfg, "report_tol", 1e-3))
save_json(os.path.join(run_dir, "train_quality.json"), q)
print("Train quality (after timeless refinement):", q)

# Print the final SSS, one regime at a time in sorted regime order.
print('=== COMMITMENT SSS (switching-consistent, includes lagged multipliers; timeless perspective) ===')
for _s in sorted(comm_sss_final.by_regime.keys()):
    print(f'Regime {_s}:')
    for _k, _v in comm_sss_final.by_regime[_s].items():
        print(f'{_k:>20}: {_v}')

# Sanity checks evaluated at the final SSS; `fp` and `rc` are reused by the next cell.
fp = fixed_point_check(
    params,
    trainer.net,
    policy='commitment',
    sss_by_regime=comm_sss_final.by_regime,
)
rc = residuals_check_switching_consistent(
    params,
    trainer.net,
    policy='commitment',
    sss_by_regime=comm_sss_final.by_regime,
)
print(
    'Fixed-regime one-step check max |x_next-x| by regime (NOT Table-2 SSS):',
    {regime: check.max_abs_state_diff for regime, check in fp.items()},
)
print(
    'Switching-consistent residual check max |res| by regime:',
    {regime: check.max_abs_residual for regime, check in rc.items()},
)
# Residual names are taken from the first regime's result.
print('Residual keys:', list(next(iter(rc.values())).residuals.keys()))


In [None]:
# ---------- Save sanity checks ----------
# Regime keys are cast to int and every metric to float before being written out.
save_json(
    os.path.join(run_dir, 'sanity_checks.json'),
    {
        'policy': 'commitment',
        'fixed_regime_one_step_max_abs_state_diff': {
            int(regime): float(check.max_abs_state_diff) for regime, check in fp.items()
        },
        'residual_max_abs': {
            int(regime): float(check.max_abs_residual) for regime, check in rc.items()
        },
        'residuals_by_regime': {
            int(regime): {name: float(value) for name, value in check.residuals.items()}
            for regime, check in rc.items()
        },
    },
)

In [None]:
# ---------- Simulate (timeless commitment: start from SSS incl. lagged multipliers) ----------
# This produces sim_paths.npz used by Table 2 / figures.
B_sim, T_sim, burn_in_sim = 512, 20000, 2000

# Initial states drawn from the final SSS; B_sim is the count passed to the trainer.
x0_sim = trainer.simulate_initial_state(B_sim, commitment_sss=comm_sss_final.by_regime)
sim = simulate_paths(
    params=params, policy="commitment", net=trainer.net,
    T=T_sim, burn_in=burn_in_sim, x0=x0_sim,
    compute_implied_i=True, gh_n=3, thin=1,
    show_progress=True, store_states=False,
)

# Every entry of the returned mapping is passed to save_npz as a keyword argument.
_sim_paths_file = os.path.join(run_dir, "sim_paths.npz")
save_npz(_sim_paths_file, **sim)
print("Saved sim_paths:", _sim_paths_file)

# Mark run as selected only after all core artifacts are saved.
save_selected_run(ARTIFACTS_ROOT, trainer.policy, run_dir)
