In [None]:
# load_and_run_hierarchical_inference_prettylog_final_compute_only.py
# Same logic as your script; final output now prints ONLY the compute() function
# and evaluation is separated from final code.

import os, re, math, sys, time
from typing import List, Dict
import numpy as np
import torch

# paths (adjust if your files are elsewhere)
# Checkpoint files for the executive and planner Q-networks.
EXEC_MODEL_PATH = "saved_models/dqn_dqfd_best.pt"
PLANNER_MODEL_PATH = "saved_models/planner_dqn_best.pt"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------
# Pretty logging helper (non-intrusive)
# ---------------------------
from datetime import datetime

def now():
    """Return the current local time formatted as ``HH:MM:SS``."""
    return format(datetime.now(), "%H:%M:%S")

def log_header(title):
    """Print a blank line, then a timestamped title framed by 78-character ``=`` rules."""
    rule = "=" * 78
    print("\n" + rule)
    print(f"[{now()}] {title}")
    print(rule)

def log_section(title):
    """Print a blank line, then a timestamped title framed by 78-character ``-`` rules."""
    rule = "-" * 78
    print("\n" + rule)
    print(f"[{now()}] {title}")
    print(rule)

def log_step(role, action, detail=None):
    """Print a timestamped ``role ⟶ action`` line, optionally followed by detail.

    When ``detail`` is truthy the headline gets a trailing colon and every line
    of ``detail`` is printed beneath it, indented by three spaces.
    """
    headline = f"[{now()}] {role} ⟶ {action}"
    if not detail:
        print(headline)
        return
    print(headline + ":")
    for entry in detail.splitlines():
        print("   " + entry)

def log_simple(*args, **kwargs):
    """Forward all positional and keyword arguments unchanged to ``print``."""
    print(*args, **kwargs)

# initial header
# Emit the run banner, then report which torch device was selected.
log_header("Hierarchical Inference (Planner DQN + Executive DQN) — Pretty Log")
log_simple(f"[{now()}] Device: {device}")

# ---------------------------
# Recreate DuelingQNetwork exactly like training code
# ---------------------------
import torch.nn as nn
import torch.nn.functional as F

class DuelingQNetwork(nn.Module):
    """Dueling Q-network: a shared MLP trunk feeding value and advantage heads.

    The heads are combined as ``V + (A - mean(A))``, with the mean taken over
    the last (output) dimension.

    The submodule attribute names (``trunk``, ``value_head``, ``adv_head``) and
    the layer order inside each ``nn.Sequential`` determine the ``state_dict``
    keys, so they must stay as they are for saved checkpoints to load.

    Args:
        input_dim: width of the input state vector.
        hidden: sequence of trunk layer widths; each gets a ``Linear`` + ``ReLU``.
            Lists and tuples are both accepted.
        output_dim: number of Q-values produced per state.
    """

    def __init__(self, input_dim, hidden=(512, 256), output_dim=4):
        super().__init__()
        layers = []
        prev = input_dim
        for h in hidden:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU())
            prev = h
        self.trunk = nn.Sequential(*layers)
        # **Important**: attribute names match training: value_head, adv_head
        # Head hidden width is half the trunk output, or 32 if that would be 0.
        half = prev // 2 if prev // 2 > 0 else 32
        self.value_head = nn.Sequential(
            nn.Linear(prev, half),
            nn.ReLU(),
            nn.Linear(half, 1),
        )
        self.adv_head = nn.Sequential(
            nn.Linear(prev, half),
            nn.ReLU(),
            nn.Linear(half, output_dim),
        )

    def forward(self, x):
        """Return Q-values for ``x``; the last dimension of ``x`` is the state.

        Works for batched ``(batch, input_dim)`` input and for a single
        unbatched ``(input_dim,)`` state.
        """
        t = self.trunk(x)
        v = self.value_head(t)
        a = self.adv_head(t)
        # Reduce over the last axis (not a hard-coded axis 1) so an unbatched
        # 1-D state works too; identical to dim=1 for 2-D input.
        q = v + (a - a.mean(dim=-1, keepdim=True))
        return q

# ---------------------------
# Executive hash embedder (exact as training)
# ---------------------------
# Number of rows in the hashed-token embedding table and the width of each row.
VOCAB_SIZE = 4096
EMB_DIM = 64
# Fixed seed so the same random table is generated on every run.
_rng = np.random.RandomState(42)
# (VOCAB_SIZE, EMB_DIM) float32 table drawn from a normal distribution with scale 0.5.
_emb_table = _rng.normal(scale=0.5, size=(VOCAB_SIZE, EMB_DIM)).astype(np.float32)

def _tokens_of(text: str):
    """Return the lower-cased ``\\w+`` runs found in ``text`` (empty list for falsy input)."""
    return re.findall(r"\w+", text.lower()) if text else []

def text_to_emb(text: str) -> np.ndarray:
    """Embed ``text`` as the mean of its tokens' rows in ``_emb_table``.

    Each token from ``_tokens_of`` selects the row
    ``abs(hash(token)) % VOCAB_SIZE``. Text without any tokens yields a
    float32 zero vector of length ``EMB_DIM``.

    NOTE(review): built-in ``hash`` on ``str`` is salted per interpreter
    process unless ``PYTHONHASHSEED`` is fixed, so the selected rows can differ
    between runs — confirm the training run and this script pin the same seed.
    """
    words = _tokens_of(text)
    if not words:
        return np.zeros(EMB_DIM, dtype=np.float32)
    row_ids = [abs(hash(word)) % VOCAB_SIZE for word in words]
    return np.mean(_emb_table[row_ids], axis=0)

def make_exec_state(task_prompt: str, last_code: str, last_test_feedback: str) -> np.ndarray:
    """Build the executive state vector.

    The result concatenates the embeddings of the task prompt, the last code
    and the last test feedback (``None`` is treated as ``""``), followed by
    two 0/1 flags marking whether the code and the feedback contain any
    non-whitespace text. The vector is float32.
    """
    code_text = last_code or ""
    feedback_text = last_test_feedback or ""
    pieces = [
        text_to_emb(task_prompt),
        text_to_emb(code_text),
        text_to_emb(feedback_text),
    ]
    has_code = 1.0 if code_text.strip() else 0.0
    has_feedback = 1.0 if feedback_text.strip() else 0.0
    pieces.append(np.array([has_code, has_feedback], dtype=np.float32))
    return np.concatenate(pieces).astype(np.float32)

# compute exec state dim
# Build one throwaway state and read its length to get the executive network's input width.
EXEC_STATE_DIM = make_exec_state("def foo():\n", "", "").shape[0]
log_simple(f"[{now()}] Exec state dim: {EXEC_STATE_DIM}")

# ---------------------------
# Planner embedder: try sentence-transformers like training, else fallback to hash-embedding
# ---------------------------
USE_SENTENCE_TRANSFORMER = True


def _planner_flags(has_planned, executed_count, integrated, evaluated,
                   invalid_action, last_test_passed) -> np.ndarray:
    """Pack the six planner progress indicators into a float32 vector.

    Every indicator becomes 1.0/0.0 by truthiness except ``executed_count``,
    which is stored as its float value.
    """
    return np.array(
        [
            float(bool(has_planned)),
            float(executed_count),
            float(bool(integrated)),
            float(bool(evaluated)),
            float(bool(invalid_action)),
            float(bool(last_test_passed)),
        ],
        dtype=np.float32,
    )


try:
    from sentence_transformers import SentenceTransformer
    embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    def make_planner_state(prompt: str, last_code: str, last_test_feedback: str,
                         has_planned: bool = False, executed_count: int = 0,
                         integrated: bool = False, evaluated: bool = False,
                         invalid_action: bool = False, last_test_passed: bool = False) -> np.ndarray:
        """Build the planner state from sentence-transformer embeddings.

        The prompt, last code and last feedback (``None`` treated as ``""``)
        are each encoded separately, then concatenated with the six flags.
        """
        texts = (prompt, last_code or "", last_test_feedback or "")
        pieces = [embedder.encode([txt], show_progress_bar=False)[0] for txt in texts]
        pieces.append(_planner_flags(has_planned, executed_count, integrated,
                                     evaluated, invalid_action, last_test_passed))
        return np.concatenate(pieces).astype(np.float32)

    log_simple(f"[{now()}] Loaded sentence-transformers embedder for planner state.")
except Exception as err:
    log_simple(f"[{now()}] Could not load sentence-transformers (falling back). Reason: {err}")
    USE_SENTENCE_TRANSFORMER = False

    # Fallback: reuse the hash embedder. Nothing is repeated or padded, so the
    # state is 3 * EMB_DIM + 6 wide.
    # NOTE(review): that width presumably differs from the sentence-transformer
    # path, so a planner checkpoint trained on the latter would likely not fit
    # a network sized from this fallback — confirm.
    def make_planner_state(prompt: str, last_code: str, last_test_feedback: str,
                         has_planned: bool = False, executed_count: int = 0,
                         integrated: bool = False, evaluated: bool = False,
                         invalid_action: bool = False, last_test_passed: bool = False) -> np.ndarray:
        """Build the planner state from hash embeddings plus the six flags."""
        pieces = [
            text_to_emb(prompt),
            text_to_emb(last_code or ""),
            text_to_emb(last_test_feedback or ""),
            _planner_flags(has_planned, executed_count, integrated,
                           evaluated, invalid_action, last_test_passed),
        ]
        return np.concatenate(pieces).astype(np.float32)

# compute planner state dim
# Build one throwaway state with whichever make_planner_state was defined above
# and read its length to get the planner network's input width.
PLANNER_STATE_DIM = make_planner_state("def foo():", "", "").shape[0]
log_simple(f"[{now()}] Planner state dim: {PLANNER_STATE_DIM}")

# ---------------------------
# Instantiate networks with matching shapes
# ---------------------------
# The executive network has 4 outputs and the planner network has 5; both hold
# freshly initialised weights until a checkpoint is loaded into them.
exec_qnet = DuelingQNetwork(EXEC_STATE_DIM, hidden=[512,256], output_dim=4).to(device)
planner_qnet = DuelingQNetwork(PLANNER_STATE_DIM, hidden=[512,256], output_dim=5).to(device)

def try_load(model, path, model_name="model"):
    """Load the checkpoint at ``path`` into ``model`` and report the outcome.

    A strict ``load_state_dict`` is attempted first. If it raises
    ``RuntimeError``, a ``strict=False`` load is attempted and any missing or
    unexpected keys are logged.

    Args:
        model: module whose ``load_state_dict`` receives the checkpoint.
        path: filesystem path of the checkpoint file.
        model_name: label used in the log messages.

    Returns:
        True when weights were applied (strictly or partially); False when the
        file is missing, cannot be read, or cannot be applied even with
        ``strict=False``.
    """
    if not os.path.exists(path):
        log_simple(f"[{now()}] ERROR: {model_name} checkpoint not found at {path}")
        return False
    # NOTE(security): torch.load may unpickle arbitrary objects; only point this
    # at checkpoint files from a trusted source.
    try:
        state = torch.load(path, map_location=device)
    except Exception as load_err:
        # A corrupt or unreadable file is reported by returning False, like
        # every other failure in this function, instead of raising.
        log_simple(f"[{now()}] ERROR: failed to read {model_name} checkpoint at {path}: {load_err}")
        return False
    try:
        model.load_state_dict(state)
        log_simple(f"[{now()}] Loaded {model_name} state_dict (strict match).")
        return True
    except RuntimeError as e:
        log_simple(f"[{now()}] Strict load failed for {model_name}: {e}")
        # attempt non-strict load to see which keys mismatch but still load what we can
        try:
            # load_state_dict reports the mismatching keys in its return value.
            result = model.load_state_dict(state, strict=False)
            missing = sorted(result.missing_keys)
            unexpected = sorted(result.unexpected_keys)
            if missing:
                log_simple(f"[{now()}] Keys in model but not in state (will be left as init): {missing}")
            if unexpected:
                log_simple(f"[{now()}] Keys in state but not in model: {unexpected}")
            log_simple(f"[{now()}] Loaded {model_name} with strict=False (partial load).")
            return True
        except Exception as e2:
            log_simple(f"[{now()}] Failed to load {model_name} even with strict=False: {e2}")
            return False

# Load both checkpoints; each call returns True/False rather than raising on a bad key match.
ok1 = try_load(exec_qnet, EXEC_MODEL_PATH, "executive")
ok2 = try_load(planner_qnet, PLANNER_MODEL_PATH, "planner")

# Stop the whole run with exit status 1 unless both loads reported success.
if not (ok1 and ok2):
    log_section("Model load failure — aborting")
    log_simple(f"[{now()}] One or both models failed to load cleanly. See messages above. Aborting run.")
    sys.exit(1)

# Switch both networks to evaluation mode.
# NOTE(review): in the code visible here neither network is called after this
# point — confirm whether the Q-networks are meant to drive the steps below.
exec_qnet.eval()
planner_qnet.eval()

# ---------------------------
# Now perform the hierarchical inference (no training)
# ---------------------------
log_header("Hierarchical inference start")
# The user query; in the code visible here it is only logged.
QUERY = "write me a code for computing the third root of (a+b) ^ 2"
log_simple(f"[{now()}] Query: {QUERY}")


# Hard-coded subtask definitions. Each holds an id, a prompt, assert-style
# tests given as source strings, and three candidate implementations.
# The loop below reads only "task_id", "tests" and "correct_code"; the other
# keys are not read in the code visible here.
TASK_ADD = {
    "task_id": "compute_add_001",
    "prompt": "def add(a: int, b: int) -> int:\n    \"\"\"Return a + b.\"\"\"\n",
    "tests": [
        "assert add(1,2) == 3",
        "assert add(0,0) == 0",
        "assert add(-1,2) == 1",
    ],
    "correct_code": "def add(a: int, b: int) -> int:\n    return a + b",
    "incorrect_code": "def add(a: int, b: int) -> int:\n    return a - b",
    "debugged_code": "def add(a: int, b: int) -> int:\n    return a + b"
}

TASK_SQUARE = {
    "task_id": "square_001",
    "prompt": "def square(x: int) -> int:\n    \"\"\"Return x squared.\"\"\"\n",
    "tests": [
        "assert square(2) == 4",
        "assert square(0) == 0",
        "assert square(-3) == 9",
    ],
    "correct_code": "def square(x: int) -> int:\n    return x ** 2",
    "incorrect_code": "def square(x: int) -> int:\n    return x * 2",
    "debugged_code": "def square(x: int) -> int:\n    return x ** 2"
}

# The cube-root reference rounds the float root and returns the exact integer
# when its cube matches the input; otherwise it returns the signed float root.
TASK_CBRT = {
    "task_id": "cbrt_001",
    "prompt": "def cbrt(x: int) -> float:\n    \"\"\"Return the real cube root of x. For perfect cubes return exact integer when possible.\"\"\"\n",
    "tests": [
        "assert cbrt(27) == 3",
        "assert cbrt(8) == 2",
        "assert cbrt(0) == 0",
        "assert cbrt(-27) == -3",
    ],
    "correct_code": "def cbrt(x: int) -> float:\n    if x == 0:\n        return 0\n    sign = -1 if x < 0 else 1\n    x_abs = abs(x)\n    r = round(x_abs ** (1.0 / 3.0))\n    if r ** 3 == x_abs:\n        return sign * r\n    return sign * (x_abs ** (1.0 / 3.0))",
    "incorrect_code": "def cbrt(x: int) -> float:\n    return x ** (1/3)",
    "debugged_code": "def cbrt(x: int) -> float:\n    if x == 0:\n        return 0\n    sign = -1 if x < 0 else 1\n    x_abs = abs(x)\n    r = round(x_abs ** (1.0 / 3.0))\n    if r ** 3 == x_abs:\n        return sign * r\n    return sign * (x_abs ** (1.0 / 3.0))"
}

# Processing order of the subtasks.
SUBTASKS = [TASK_ADD, TASK_SQUARE, TASK_CBRT]

# Executive simulation: run generate->test->stop for each subtask using correct_code
log_section("Executive: run subtasks (simulated LLM actions)")
# Maps task_id -> the source text produced for that subtask.
produced = {}
for t in SUBTASKS:
    log_simple("")
    log_step("Planner", "I will ask the Executive to solve subtask", t["task_id"])
    # generate
    log_step("Executive", "generate", "Producing candidate code...")
    gen_code = t["correct_code"]  # we use the provided correct code (simulated LLM)
    print()
    # Echo the candidate source, indented by four spaces.
    for line in gen_code.splitlines():
        print("    " + line)
    # test
    log_step("Executive", "test", "Running unit tests for the generated code")
    # Run the code and its tests with exec() in a fresh globals dict. This only
    # isolates names from this script; it is not a security sandbox.
    exec_globals = {}
    try:
        exec("from typing import List", exec_globals)
        exec(gen_code, exec_globals)
        passed = True
        passed_count = 0
        for test in t["tests"]:
            # Each test is an assert statement; any exception counts as a failure.
            try:
                exec(test, exec_globals)
                passed_count += 1
            except Exception as e:
                passed = False
                log_simple(f"[{now()}]    Test failed: {test} -> {repr(e)}")
        log_simple(f"[{now()}]    Test results: {passed_count}/{len(t['tests'])} passed")
    except Exception as e:
        # Reached when defining the candidate code itself raises.
        log_simple(f"[{now()}]    Error during generation or test execution: {e}")
        passed = False
    # stop
    log_step("Executive", "stop", "Episode finished for subtask")
    # NOTE(review): `passed` is not read in the code visible here, and the code
    # is recorded whether or not its tests passed — confirm this is intended.
    produced[t["task_id"]] = gen_code

# Integration: stitch the helper sources together and append compute()
log_section("Integration: combining helper functions into final compute()")
# Helper sources are joined in this fixed order, separated by blank lines.
helper_ids = ("compute_add_001", "square_001", "cbrt_001")
compute_src = """

def compute(a, b):
    \"\"\"Compute the cube root of (a+b)^2 (i.e. (a+b)^(2/3)).\"\"\"
    s = add(a, b)
    s2 = square(s)
    return cbrt(s2)
"""
integrated = "\n\n".join(produced[task_id] for task_id in helper_ids) + compute_src
# Show the assembled module in full, without truncation.
log_simple(f"[{now()}] Integrated code assembled. Full content below:\n")
print(integrated)

# Final evaluation: executive receives integrated code as last generated code, then test
log_section("Final Evaluation: validate integrated compute()")
# Fresh namespace for the integrated module; pre-loaded with math and typing.List.
exec_globals = {}
# NOTE(review): final_ok is not read in the code visible here.
final_ok = False
try:
    exec("import math\nfrom typing import List", exec_globals)
    exec(integrated, exec_globals)
    # Each entry pairs an expression to evaluate with its expected value,
    # computed here directly as (a+b)**(2/3).
    final_tests = [
        ("compute(1,2)", (1+2)**(2/3)),
        ("compute(0,0)", (0+0)**(2/3)),
        ("compute(8,1)", (8+1)**(2/3)),
        ("compute(-1,8)", (-1+8)**(2/3)),
    ]
    ok_count = 0
    for expr, expected in final_tests:
        try:
            out = eval(expr, exec_globals)
            # Relative tolerance only (abs_tol stays at its default of 0), so an
            # expected value of 0 passes only when the output is exactly 0.
            is_ok = math.isclose(out, expected, rel_tol=1e-9)
            log_simple(f"[{now()}] TEST: {expr} -> {out}  expected ≈ {expected}  ok={is_ok}")
            if is_ok: ok_count += 1
        except Exception as e:
            # An exception while evaluating one expression is logged and counted as not passed.
            log_simple(f"[{now()}] TEST: {expr} -> ERROR: {e}")
    log_simple(f"[{now()}] Final tests passed: {ok_count}/{len(final_tests)}")
    final_ok = (ok_count == len(final_tests))
except Exception as e:
    # Reached when executing the integrated source itself fails.
    log_simple(f"[{now()}] Error during final integration/eval: {e}")

# ---------------------------
# Print final code delivered to user: ONLY the compute() function (no helpers)
# and separated from evaluation logs
# ---------------------------
# NOTE(review): this banner and the code below are printed unconditionally,
# whatever the outcome of the final tests — confirm this is intended.
print("\n" + "="*78)
print(f"[{now()}] FINAL CODE (compute function only)")
print("="*78 + "\n")

# Separate literal copy of the compute() source. The printed function calls
# add, square and cbrt, whose definitions are not part of this string.
compute_only = """def compute(a, b):
    \"\"\"Compute the cube root of (a+b)^2 (i.e. (a+b)^(2/3)).\"\"\"
    s = add(a, b)
    s2 = square(s)
    return cbrt(s2)
"""
print(compute_only)



[18:29:06] Hierarchical Inference (Planner DQN + Executive DQN) — Pretty Log
[18:29:06] Device: cpu
[18:29:06] Exec state dim: 194


Loading weights:   0%|          | 0/103 [00:00<?, ?it/s]

BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


[18:29:08] Loaded sentence-transformers embedder for planner state.
[18:29:08] Planner state dim: 1158
[18:29:08] Loaded executive state_dict (strict match).
[18:29:08] Loaded planner state_dict (strict match).

[18:29:08] Hierarchical inference start
[18:29:08] Query: write me a code for computing the third root of (a+b) ^ 2

------------------------------------------------------------------------------
[18:29:08] Executive: run subtasks (simulated LLM actions)
------------------------------------------------------------------------------

[18:29:08] Planner ⟶ I will ask the Executive to solve subtask:
   compute_add_001
[18:29:08] Executive ⟶ generate:
   Producing candidate code...

    def add(a: int, b: int) -> int:
        return a + b
[18:29:08] Executive ⟶ test:
   Running unit tests for the generated code
[18:29:08]    Test results: 3/3 passed
[18:29:08] Executive ⟶ stop:
   Episode finished for subtask

[18:29:08] Planner ⟶ I will ask the Executive to solve subtask:
   square