In [None]:
import os, sys, json
import numpy as np
import torch
import pathlib

def _find_project_root():
    here = pathlib.Path.cwd().resolve()
    for p in [here, *here.parents]:
        if (p / "src").is_dir():
            return p
    # Common Google Colab clone location
    cand = pathlib.Path("/content/econml")
    if (cand / "src").is_dir():
        return cand
    raise RuntimeError("Could not find project root containing src/. If on Colab, clone repo to /content/econml.")

# Resolve the repository root once and prepend it to sys.path (only if it is
# not already there) so the `from src...` imports that follow can be resolved
# no matter which sub-directory the notebook was launched from.
PROJECT_ROOT = _find_project_root()
_project_root_str = str(PROJECT_ROOT)
if _project_root_str not in sys.path:
    sys.path.insert(0, _project_root_str)

from src.config import ModelParams
from src.steady_states import solve_flexprice_sss, export_rbar_tensor
from src.io_utils import ensure_dir, save_json

# Output directory for generated artifacts: <project root>/artifacts.
# Kept as a plain str because it is combined with os.path.join further down.
ARTIFACTS_ROOT = str(PROJECT_ROOT / "artifacts")

# Use the GPU whenever PyTorch reports one is available, else fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)

# Model parameters are constructed for the chosen device with float32 dtype.
params = ModelParams(dtype=torch.float32, device=DEVICE)

# Solve the flex-price steady state, then build the rbar tensor from it.
flex = solve_flexprice_sss(params)
rbar = export_rbar_tensor(params, flex)

print("=== FLEX SSS (by regime) ===")
# NOTE(review): only regimes 0 and 1 are printed here, whereas the JSON dump
# saves the whole `flex.by_regime`; extend this if the model gains regimes.
for regime in (0, 1):
    print(f"Regime {regime}:")
    regime_values = flex.by_regime[regime]
    for name, value in regime_values.items():
        # Right-align the variable names in an 18-character column.
        print(f"  {name:>18s}: {value}")

print("\n=== rbar_by_regime ===")
print(rbar)

# Persist the flex-price solution to <ARTIFACTS_ROOT>/flex/sss.json.
out_dir = os.path.join(ARTIFACTS_ROOT, "flex")
ensure_dir(out_dir)
sss_path = os.path.join(out_dir, "sss.json")
payload = {
    "policy": "flex",
    # NOTE(review): passed through unconverted; assumes save_json can
    # serialize whatever values by_regime holds (e.g. tensors) — confirm.
    "by_regime": flex.by_regime,
    # Detach from autograd, copy to host memory, and convert to nested lists
    # so the tensor is JSON-friendly.
    "rbar_by_regime": rbar.detach().cpu().tolist(),
}
save_json(sss_path, payload)
print(f"Saved flex SSS to: {sss_path}")
