In [None]:
import json
import os
import pathlib
import sys

import numpy as np
import pandas as pd
import torch

def _find_project_root():
    here = pathlib.Path.cwd().resolve()
    for p in [here, *here.parents]:
        if (p / "src").is_dir():
            return p
    # Common Google Colab clone location
    cand = pathlib.Path("/content/econml")
    if (cand / "src").is_dir():
        return cand
    raise RuntimeError("Could not find project root containing src/. If on Colab, clone repo to /content/econml.")

# Locate the repo root once and make the local `src` package importable.
PROJECT_ROOT = _find_project_root()
# Prepend (index 0) so this checkout is searched before any other sys.path entry;
# the membership test keeps re-running this cell from adding duplicates.
if str(PROJECT_ROOT) not in sys.path:
    sys.path[0:0] = [str(PROJECT_ROOT)]

from src.config import ModelParams, TrainConfig
from src.deqn import PolicyNetwork, Trainer, simulate_paths
from src.io_utils import make_run_dir, save_run_metadata, save_selected_run, pack_config, save_torch, save_csv, save_json, save_npz, ensure_dir
from src.metrics import residual_quality

# ---------- config ----------
# All run artifacts live under <project root>/artifacts.
ARTIFACTS_ROOT = str(PROJECT_ROOT / "artifacts")

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)
params = ModelParams(device=DEVICE, dtype=torch.float32)

cfg_seed = 0
# A throwaway config is built first only to read `mode` / `seed` for naming the
# run directory; the real config is then rebuilt with that run_dir attached.
cfg_probe = TrainConfig.mid(seed=cfg_seed)
run_dir = make_run_dir(
    ARTIFACTS_ROOT,
    "discretion",
    tag=cfg_probe.mode,
    seed=cfg_probe.seed,
)
cfg = TrainConfig.mid(seed=cfg_seed, run_dir=run_dir, artifacts_root=ARTIFACTS_ROOT)

run_metadata = pack_config(params, cfg, extra={"policy": "discretion"})
save_run_metadata(run_dir, run_metadata)
print("Run dir:", run_dir)


In [None]:
# Per-regime rbar values, passed to Trainer / simulate_paths as `rbar_by_regime`
# only when the policy is "mod_taylor". This notebook trains "discretion", so the
# value stays None; the name exists because later cells reference it.
rbar = None  # only used for mod_taylor


In [None]:

# ---------- model ----------
# Policy trained by this notebook; used below instead of repeating the literal.
POLICY = "discretion"

# Network input/output widths.
# NOTE(review): 5 inputs / 11 outputs must match the state and output layout the
# Trainer expects for the "discretion" policy — confirm in src.deqn before changing.
d_in, d_out = 5, 11

# Seed immediately before weight initialisation. The network is built before the
# Trainer (which receives cfg.seed) exists, so without this the initial weights,
# and therefore the trained result, differ between runs and between re-runs of
# this cell.
torch.manual_seed(cfg.seed)
net = PolicyNetwork(d_in, d_out, hidden=cfg.hidden_layers, activation=cfg.activation)

trainer = Trainer(
    params=params,
    cfg=cfg,
    policy=POLICY,
    net=net,
    # Regime-specific rbar is only consumed by the modified-Taylor rule; for any
    # other policy `rbar` is never evaluated and None is passed.
    rbar_by_regime=rbar if POLICY == "mod_taylor" else None,
)



In [None]:

# ---------- train ----------
losses = trainer.train(
    commitment_sss=None,
    n_path=cfg.n_path,
    n_paths_per_step=cfg.n_paths_per_step,
)

# Persist the trained weights and the per-iteration loss log.
save_torch(os.path.join(run_dir, "weights.pt"), trainer.net.state_dict())
df = pd.DataFrame({"iter": np.arange(len(losses)), "loss": losses})
save_csv(os.path.join(run_dir, "train_log.csv"), df)

# ---------- validation quality ----------
# Fresh validation batch sampled from the model's simulated state distribution,
# advanced by a short burn-in (kept small; training itself is path-based).
val_burn = int(getattr(cfg, "val_burn_in", 200))

# The burn-in only advances states, so it runs with autograd disabled. Under
# enable_grad, every _step_state call (which presumably evaluates trainer.net)
# would be recorded, and the graph would keep growing across all `val_burn`
# steps for the whole validation batch. no_grad (not inference_mode) is used so
# the resulting tensors can still enter autograd in the residual step below.
with torch.no_grad():
    x_val = trainer.simulate_initial_state(int(cfg.val_size), commitment_sss=None)
    for _ in range(val_burn):
        x_val = trainer._step_state(x_val)

# Discretion residuals require autograd through Delta-derivative terms; other
# policies can be evaluated in inference mode.
resid_ctx = torch.enable_grad() if trainer.policy == "discretion" else torch.inference_mode()
with resid_ctx:
    resid = trainer._residuals(x_val).detach().cpu().numpy()
q = residual_quality(resid, tol=getattr(cfg, "report_tol", 1e-3))
save_json(os.path.join(run_dir, "train_quality.json"), q)
print("Train quality:", q)

# optional: mark this run as selected for results notebook
save_selected_run(ARTIFACTS_ROOT, trainer.policy, run_dir)

# ---------- simulate ergodic paths ----------
x0 = trainer.simulate_initial_state(512, commitment_sss=None)  # 512 initial states
sim = simulate_paths(
    params=params,
    policy=trainer.policy,
    net=trainer.net,
    T=20000,
    burn_in=2000,
    x0=x0,
    rbar_by_regime=rbar if trainer.policy == "mod_taylor" else None,
    compute_implied_i=True,
    gh_n=3,  # NOTE(review): presumably Gauss-Hermite nodes for the implied-rate expectation — confirm in src.deqn
    thin=10,
    show_progress=True,
)
sim_path = os.path.join(run_dir, "sim_paths.npz")
save_npz(sim_path, **sim)
print("Saved sim_paths to:", sim_path)


In [None]:
# ---------- SSS from trained policy (paper-faithful) ----------
# Solve for the discretion stochastic steady state implied by the trained policy,
# persist it next to the run, and print it regime by regime.
from src.steady_states import solve_discretion_sss_from_policy_switching

disc_sss = solve_discretion_sss_from_policy_switching(params, trainer.net)
_sss_path = os.path.join(run_dir, 'sss_policy_fixed_point.json')
save_json(_sss_path, {'policy': 'discretion', 'by_regime': disc_sss.by_regime})

print('=== DISCRETION SSS (policy fixed point (switching-consistent), by regime) ===')
for _regime in sorted(disc_sss.by_regime):
    print(f'Regime {_regime}:')
    for _name, _value in disc_sss.by_regime[_regime].items():
        print(f'{_name:>20}: {_value}')


In [None]:
# ---------- Sanity checks (fixed-regime mapping + switching-consistent residuals) ----------
# Both checks are evaluated at the per-regime SSS solved in the previous cell.
from src.sanity_checks import fixed_point_check, residuals_check_switching_consistent

_check_kwargs = dict(policy='discretion', sss_by_regime=disc_sss.by_regime)
fp = fixed_point_check(params, trainer.net, **_check_kwargs)
rc = residuals_check_switching_consistent(params, trainer.net, **_check_kwargs)

_fp_max_by_regime = {regime: chk.max_abs_state_diff for regime, chk in fp.items()}
print('Fixed-regime one-step check max |x_next-x| by regime (NOT Table-2 SSS):', _fp_max_by_regime)
_rc_max_by_regime = {regime: chk.max_abs_residual for regime, chk in rc.items()}
print('Switching-consistent residual check max |res| by regime:', _rc_max_by_regime)

# Residual names are read from the first regime's result.
_first_rc = next(iter(rc.values()))
print('Residual keys:', list(_first_rc.residuals.keys()))


In [None]:
# ---------- Save sanity checks ----------
# Convert regime keys to int and metric values to float, i.e. plain Python
# types, before handing the payload to save_json.
_sanity_payload = {
    'policy': 'discretion',
    'fixed_regime_one_step_max_abs_state_diff': {
        int(regime): float(chk.max_abs_state_diff) for regime, chk in fp.items()
    },
    'residual_max_abs': {
        int(regime): float(chk.max_abs_residual) for regime, chk in rc.items()
    },
    'residuals_by_regime': {
        int(regime): {name: float(value) for name, value in chk.residuals.items()}
        for regime, chk in rc.items()
    },
}
save_json(os.path.join(run_dir, 'sanity_checks.json'), _sanity_payload)