<a href="https://colab.research.google.com/github/manivafapour/ppt_mani/blob/main/RL_Hierarchical.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# planner_dqn_final.py
# Single-episode DQN trainer: planning -> execution -> integration -> evaluation -> stop
# Final version: exploration samples from full action space (no mask) so agent must learn.

import os, random, json, math, time
from collections import deque
from typing import List, Dict, Any, Tuple
import numpy as np
from tqdm.auto import tqdm
from rich.console import Console
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

# Shared rich console used for every status/progress message in this script.
console = Console()
# Create the output directories up front; "saved_models" receives the
# checkpoints written during and after training.
os.makedirs("logs", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)

# ------------------ Training task (single planner task) ------------------
# Description of the one task the planner is trained on.
#   task_id        - identifier of the task
#   prompt         - natural-language task statement (also embedded into the state)
#   tests          - statements that run_tests_locally exec()s one by one, in order,
#                    against the candidate code
#   correct_code / incorrect_code - reference implementations; they are not
#                    referenced anywhere else in the visible code
PLANNER_TASK = {
    "task_id": "planner_compute_2o3",
    "prompt": "Write a Python function `compute(a, b)` that returns (a + b)^(2/3).",
    "tests": [
        # The import is executed like any other entry, so it is counted as a
        # (trivially passing) test and makes `math` available to the asserts below.
        "import math",
        "assert math.isclose(compute(1,2), (1+2)**(2/3), rel_tol=1e-9)",
        "assert math.isclose(compute(0,0), (0+0)**(2/3), rel_tol=1e-9)",
        "assert math.isclose(compute(8,1), (8+1)**(2/3), rel_tol=1e-9)",
        "assert math.isclose(compute(-1,8), (-1+8)**(2/3), rel_tol=1e-9)",
    ],
    "correct_code": """def compute(a, b):
    # compute (a + b) ** (2/3)
    return (a + b) ** (2.0/3.0)""",
    "incorrect_code": """def compute(a, b):
    # wrong: uses integer division or wrong exponent
    return (a + b) ** (1/2)""",
}

# ------------------ Local test runner ------------------
def run_tests_locally(code_str: str, tests: List[str]) -> Tuple[bool, Dict[str, Any]]:
    """Execute candidate code, then run each test statement against it.

    `code_str` is exec'd into a fresh namespace; every entry of `tests` is
    then exec'd in that same namespace, so tests can call whatever the code
    defined and can see names introduced by earlier test statements.

    Args:
        code_str: Python source of the candidate solution.
        tests: Python statements (typically `assert ...`) run in order.

    Returns:
        `(all_passed, feedback)`. `feedback` always contains
        `number_passed`, `total` and `failures` (a list of
        `{"test": ..., "error": ...}` dicts). If the candidate code itself
        fails to execute, `feedback` additionally carries `error` and
        `trace` and no test is run.
    """
    import traceback

    feedback: Dict[str, Any] = {"number_passed": 0, "total": len(tests), "failures": []}
    exec_globals: Dict[str, Any] = {}

    # NOTE(security): exec runs the given source in-process with no sandbox;
    # only pass code from a trusted origin.
    try:
        exec(code_str, exec_globals)
    except Exception as e:
        # Keep the regular feedback keys so callers that format
        # number_passed/total do not hit a KeyError on broken candidate code.
        feedback["error"] = str(e)
        feedback["trace"] = traceback.format_exc()
        return False, feedback

    for test in tests:
        try:
            exec(test, exec_globals)
        except AssertionError:
            feedback["failures"].append({"test": test, "error": "AssertionError"})
        except Exception as e:
            feedback["failures"].append({"test": test, "error": str(e)})
        else:
            feedback["number_passed"] += 1

    return feedback["number_passed"] == feedback["total"], feedback

console.log("Planner training task and test runner ready.")

# ------------------ Planner action space ------------------
from enum import IntEnum

class PlannerAction(IntEnum):
    """Discrete planner actions; the integer value is the index of the Q-network output."""

    PLANNING = 0
    EXECUTION = 1
    INTEGRATION = 2
    EVALUATION = 3
    STOP = 4


# Lower-case display label for each action, keyed by enum member and kept in
# declaration order. Attached to the enum class so callers can write
# PlannerAction.NAMES[...].
PlannerAction.NAMES = {member: member.name.lower() for member in PlannerAction}
console.log("Planner Action space:", PlannerAction.NAMES)

# ------------------ Simulated Planner LLM (deterministic) ------------------
class SimulatedPlannerLLM:
    """Deterministic stand-in for a planner LLM.

    Every stage returns fixed, hard-coded output regardless of its inputs;
    only `evaluate` does real work by running the tests.
    """

    # Canned sub-task prompts returned by planning().
    _SUBTASK_PROMPTS = (
        ("task_a", "Write function add(a,b) that returns a + b"),
        ("task_b", "Write function square(x) that returns x ** 2"),
        ("task_c", "Write function cbrt(x) that returns x ** (1/3)"),
    )

    # Canned sub-task solutions returned by execute().
    _SUBTASK_CODE = (
        ("task_a", "def add(a, b):\n    return a + b\n"),
        ("task_b", "def square(x):\n    return x ** 2\n"),
        ("task_c", "def cbrt(x):\n    return x ** (1.0/3.0)\n"),
    )

    # Helper definitions placed ahead of the final function by integrate().
    _HELPER_SOURCE = (
        "\n"
        "def add(a, b):\n"
        "    return a + b\n"
        "\n"
        "def square(x):\n"
        "    return x ** 2\n"
        "\n"
        "def cbrt(x):\n"
        "    return x ** (1.0/3.0)\n"
    )

    # Final solution emitted by integrate().
    _COMPUTE_SOURCE = (
        "def compute(a, b):\n"
        "    # integrated result computing (a + b) ** (2/3)\n"
        "    return (a + b) ** (2.0/3.0)\n"
    )

    def __init__(self, task: Dict):
        self.task = task

    def planning(self, prompt: str) -> Dict[str, str]:
        """Return a fresh dict of sub-task prompts; `prompt` is ignored."""
        return dict(self._SUBTASK_PROMPTS)

    def execute(self, subtask_prompts: Dict[str, str]) -> Dict[str, str]:
        """Return a fresh dict of sub-task code; `subtask_prompts` is ignored."""
        return dict(self._SUBTASK_CODE)

    def integrate(self, exec_outputs: Dict[str, str], original_prompt: str) -> str:
        """Return the helper definitions followed by the final `compute` function.

        Both arguments are ignored; the returned source is constant.
        """
        return "\n".join((self._HELPER_SOURCE, self._COMPUTE_SOURCE))

    def evaluate(self, integrated_code: str, tests: List[str]) -> Tuple[bool, Dict[str, Any]]:
        """Run `tests` against `integrated_code` via run_tests_locally."""
        return run_tests_locally(integrated_code, tests)

# Module-level simulated planner; PlannerEnv.step calls it for every stage.
sim_planner = SimulatedPlannerLLM(PLANNER_TASK)

# ------------------ Planner environment ------------------
class PlannerEnv:
    """Episode environment rewarding the order planning -> execution -> integration -> evaluation -> stop.

    Each step costs 0.1. A stage taken in order earns a bonus; a stage taken
    out of order or repeated earns a penalty and marks the episode with
    `invalid_action`. STOP ends the episode, and so does reaching `max_steps`.
    """

    def __init__(self, task: Dict, max_steps: int = 10):
        self.task = task
        self.max_steps = max_steps
        self.reset()

    def reset(self):
        """Clear all per-episode state and return the initial observation."""
        # Episode bookkeeping.
        self.steps = 0
        self.done = False
        self.success = False
        self.invalid_action = False
        # Stage-completion flags.
        self.has_planned = False
        self.has_executed = False
        self.has_integrated = False
        self.has_evaluated = False
        # Artifacts produced by the stages.
        self.plan_output = None
        self.exec_output = None
        self.integrated_code = ""
        self.test_feedback = ""
        self.test_results = []
        return self._state()

    def _fail(self):
        """Penalty applied for an out-of-order or repeated stage."""
        return -5.0

    def _penalize(self) -> float:
        """Flag the episode as containing an invalid action and return the penalty."""
        penalty = self._fail()
        self.invalid_action = True
        return penalty

    def _do_planning(self) -> float:
        """PLANNING: valid only once; stores the plan."""
        if self.has_planned:
            return self._penalize()
        self.plan_output = sim_planner.planning(self.task["prompt"])
        self.has_planned = True
        return 1.0

    def _do_execution(self) -> float:
        """EXECUTION: valid once, after planning; stores the sub-task outputs."""
        if self.has_executed or not self.has_planned:
            return self._penalize()
        self.exec_output = sim_planner.execute(self.plan_output)
        self.has_executed = True
        return 1.5

    def _do_integration(self) -> float:
        """INTEGRATION: valid once, after execution; stores the integrated code."""
        if self.has_integrated or not self.has_executed:
            return self._penalize()
        self.integrated_code = sim_planner.integrate(self.exec_output, self.task["prompt"])
        self.has_integrated = True
        return 2.0

    def _do_evaluation(self) -> float:
        """EVALUATION: valid once, after integration; runs the task tests."""
        if self.has_evaluated or not self.has_integrated:
            return self._penalize()
        passed, fb = sim_planner.evaluate(self.integrated_code, self.task["tests"])
        self.test_results.append(passed)
        self.test_feedback = f"Tests {fb['number_passed']}/{fb['total']}"
        self.has_evaluated = True
        return 5.0 if passed else -1.0

    def _do_stop(self) -> float:
        """STOP: ends the episode; success needs a passing evaluation and no invalid action."""
        self.done = True
        last_passed = bool(self.test_results) and bool(self.test_results[-1])
        if self.has_evaluated and last_passed and not self.invalid_action:
            self.success = True
        return 30.0 if self.success else -10.0

    def step(self, action: int):
        """Apply `action` and return `(state, reward, done, info)`.

        Raises:
            RuntimeError: if the episode has already finished.
        """
        if self.done:
            raise RuntimeError("Episode already finished")
        self.steps += 1

        chosen = PlannerAction(action)
        info = {"action": PlannerAction.NAMES[chosen]}

        handlers = {
            PlannerAction.PLANNING: self._do_planning,
            PlannerAction.EXECUTION: self._do_execution,
            PlannerAction.INTEGRATION: self._do_integration,
            PlannerAction.EVALUATION: self._do_evaluation,
            PlannerAction.STOP: self._do_stop,
        }

        # Flat per-step cost plus whatever the chosen stage earns or loses.
        reward = -0.1
        reward += handlers[chosen]()

        # Running out of steps without an explicit STOP ends the episode with a penalty.
        if not self.done and self.steps >= self.max_steps:
            self.done = True
            reward -= 10.0

        info["success"] = self.success
        info["invalid_action"] = self.invalid_action
        return self._state(), reward, self.done, info

    def _state(self):
        """Build the observation vector from the task prompt, latest artifacts and progress flags."""
        if self.test_results:
            last_passed = bool(self.test_results[-1])
        else:
            last_passed = False
        return make_state_embedding(
            self.task["prompt"],
            self.integrated_code or "",
            self.test_feedback,
            has_planned=self.has_planned,
            executed_count=int(self.has_executed),
            integrated=int(self.has_integrated),
            evaluated=int(self.has_evaluated),
            invalid_action=self.invalid_action,
            last_test_passed=last_passed,
        )

# ------------------ Embedding helper ------------------
from sentence_transformers import SentenceTransformer
# Sentence-embedding model shared by make_state_embedding; it is instantiated
# once, when this module-level statement runs.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

from functools import lru_cache


@lru_cache(maxsize=256)
def _embed_text(text: str) -> np.ndarray:
    """Return the embedding of `text`, memoised per distinct string.

    The environment asks for the same handful of strings (the task prompt,
    the empty string, the integrated code, the feedback line) on every step,
    so caching avoids a model forward pass for each of them. Callers must not
    mutate the returned array because it is shared through the cache.
    """
    return embedder.encode([text], show_progress_bar=False)[0]


def make_state_embedding(task_prompt: str, last_code: str, last_test_feedback: str,
                         has_planned: bool = False, executed_count: int = 0,
                         integrated: bool = False, evaluated: bool = False,
                         invalid_action: bool = False, last_test_passed: bool = False) -> np.ndarray:
    """Build the flat observation vector for the planner agent.

    Layout: embedding(task_prompt) | embedding(last_code) |
    embedding(last_test_feedback) | 6 progress flags, in the order
    has_planned, executed_count, integrated, evaluated, invalid_action,
    last_test_passed.

    Args:
        task_prompt: Task statement.
        last_code: Latest integrated code; falsy values are embedded as "".
        last_test_feedback: Latest feedback text; falsy values are embedded as "".
        has_planned, integrated, evaluated, invalid_action, last_test_passed:
            Truthy/falsy progress indicators, stored as 1.0/0.0.
        executed_count: Stored as `float(executed_count)`.

    Returns:
        A new 1-D float32 array (never a view of a cached embedding).
    """
    flags = np.array([
        1.0 if has_planned else 0.0,
        float(executed_count),
        1.0 if integrated else 0.0,
        1.0 if evaluated else 0.0,
        1.0 if invalid_action else 0.0,
        1.0 if last_test_passed else 0.0,
    ], dtype=np.float32)
    # np.concatenate copies, so the cached embeddings are never aliased by the result.
    return np.concatenate([
        _embed_text(task_prompt),
        _embed_text(last_code or ""),
        _embed_text(last_test_feedback or ""),
        flags,
    ]).astype(np.float32)

# Build one throwaway state for the initial situation (nothing done yet) just
# to read off the length of the observation vector.
sample_state = make_state_embedding(PLANNER_TASK["prompt"], "", "", False, 0, False, False, False, False)
STATE_DIM = sample_state.shape[0]
console.log(f"Planner state dim: {STATE_DIM}")

# ------------------ Dueling DQN + Replay Buffer ------------------
class DuelingQNetwork(nn.Module):
    """Dueling Q-network: shared MLP trunk feeding a state-value head and an advantage head.

    Q-values are combined as `value + (advantage - mean(advantage))`.

    Args:
        input_dim: Size of the state vector.
        hidden: Widths of the trunk's Linear+ReLU layers, in order. May be
            empty, in which case the trunk is the identity.
        output_dim: Number of actions.
    """

    def __init__(self, input_dim, hidden=(512, 256), output_dim=len(PlannerAction.NAMES)):
        super().__init__()
        layers = []
        prev = input_dim
        for h in hidden:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU())
            prev = h
        self.trunk = nn.Sequential(*layers)

        # Both heads use half the trunk width, falling back to 32 when that would be 0.
        head_dim = prev // 2 if prev // 2 > 0 else 32
        self.value_head = nn.Sequential(
            nn.Linear(prev, head_dim),
            nn.ReLU(),
            nn.Linear(head_dim, 1),
        )
        self.adv_head = nn.Sequential(
            nn.Linear(prev, head_dim),
            nn.ReLU(),
            nn.Linear(head_dim, output_dim),
        )

    def forward(self, x):
        """Return Q-values with the same leading shape as `x` and `output_dim` as the last axis."""
        x = self.trunk(x)
        value = self.value_head(x)
        adv = self.adv_head(x)
        # Centre the advantages over the action axis (the last one) so that
        # value and advantage are identifiable; using -1 keeps batched input
        # unchanged and also accepts a single unbatched state.
        return value + (adv - adv.mean(dim=-1, keepdim=True))

class ReplayBuffer:
    """Bounded FIFO store of (s, a, r, ns, done) transitions with uniform sampling."""

    def __init__(self, capacity=8000):
        # A deque with maxlen drops the oldest transition once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, ns, done):
        """Append one transition."""
        transition = (s, a, r, ns, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw up to `batch_size` distinct transitions and return them as five stacked arrays.

        Returns `(states, actions, rewards, next_states, dones)`; rewards and
        dones are float32. The buffer must not be empty.
        """
        count = min(batch_size, len(self.buffer))
        picked = random.sample(self.buffer, count)
        states, actions, rewards, next_states, dones = zip(*picked)
        return (
            np.stack(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.stack(next_states),
            np.array(dones, dtype=np.float32),
        )

    def __len__(self):
        return len(self.buffer)

class DQNAgent:
    """Epsilon-greedy Double-DQN agent built on DuelingQNetwork and ReplayBuffer.

    Args:
        state_dim: Length of the state vector.
        action_dim: Number of discrete actions.
        hidden: Trunk layer widths passed to both Q-networks.
        lr: Adam learning rate.
        gamma: Discount factor.
        buffer_size: Replay buffer capacity.
        batch_size: Maximum number of transitions per gradient step.
        target_update: Copy online weights into the target net every this
            many gradient steps.
        device: Torch device; defaults to CUDA when available, else CPU.
        mask_actions: If true, greedy selection only considers actions that
            `_compute_action_mask` marks legal.
        eps_decay: Multiplicative epsilon decay applied after each gradient step.
    """

    # Gradient updates start only once this many transitions are stored.
    MIN_REPLAY_SIZE = 32

    def __init__(self, state_dim, action_dim=len(PlannerAction.NAMES), hidden=(512, 256), lr=1e-4, gamma=0.99,
                 buffer_size=8000, batch_size=64, target_update=500, device=None, mask_actions=True,
                 eps_decay=0.9995):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.qnet = DuelingQNetwork(state_dim, hidden, action_dim).to(self.device)
        self.target = DuelingQNetwork(state_dim, hidden, action_dim).to(self.device)
        self.target.load_state_dict(self.qnet.state_dict())
        self.opt = Adam(self.qnet.parameters(), lr=lr)
        self.gamma = gamma
        self.buffer = ReplayBuffer(capacity=buffer_size)
        self.batch_size = batch_size
        self.action_dim = action_dim
        self.eps = 1.0
        self.eps_min = 0.05
        self.eps_decay = eps_decay
        self.learn_steps = 0
        self.target_update = target_update
        self.mask_actions = mask_actions

    def _compute_action_mask(self, state: np.ndarray):
        """Return a boolean array marking which actions are legal in `state`.

        Reads the last six entries of `state` as the flags
        (has_planned, executed_count, integrated, evaluated, invalid_action,
        last_test_passed).
        """
        flags = state[-6:]
        has_planned = bool(flags[0])
        executed_count = int(flags[1])
        integrated = bool(flags[2])
        evaluated = bool(flags[3])
        invalid_action = bool(flags[4])
        last_test_passed = bool(flags[5])

        mask = np.ones(self.action_dim, dtype=bool)
        if has_planned:
            mask[PlannerAction.PLANNING] = False
        if (not has_planned) or (executed_count >= 1):
            mask[PlannerAction.EXECUTION] = False
        if (not executed_count) or integrated:
            mask[PlannerAction.INTEGRATION] = False
        if (not integrated) or evaluated:
            mask[PlannerAction.EVALUATION] = False
        # STOP is only offered when it would count as a success.
        if not (last_test_passed and (not invalid_action)):
            mask[PlannerAction.STOP] = False
        return mask

    def select_action(self, state):
        """Pick an action index for `state` with an epsilon-greedy policy.

        Exploration samples uniformly from the full, unmasked action space so
        the agent also experiences invalid actions. Exploitation takes the
        arg-max Q-value, restricted to legal actions when `mask_actions` is
        set; if no action is legal it falls back to a uniform random action.
        """
        if random.random() < self.eps:
            return random.randrange(self.action_dim)

        s = torch.as_tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        with torch.no_grad():
            qvals = self.qnet(s).cpu().numpy().squeeze(0)

        if not self.mask_actions:
            return int(qvals.argmax())

        mask = self._compute_action_mask(state)
        if not mask.any():
            return random.randrange(self.action_dim)
        return int(np.where(mask, qvals, -np.inf).argmax())

    def push_transition(self, s, a, r, ns, done):
        """Store one transition in the replay buffer."""
        self.buffer.push(s, a, r, ns, done)

    def train_step(self):
        """Run one gradient step on a sampled batch and return the loss.

        Returns 0.0 without training while the buffer holds fewer than
        `MIN_REPLAY_SIZE` transitions. Otherwise performs a Double-DQN update
        (the online net chooses the next action, the target net scores it),
        clips the gradient norm, periodically syncs the target net and decays
        epsilon.
        """
        if len(self.buffer) < self.MIN_REPLAY_SIZE:
            return 0.0

        s, a, r, ns, d = self.buffer.sample(self.batch_size)
        s = torch.as_tensor(s, dtype=torch.float32, device=self.device)
        a = torch.as_tensor(a, dtype=torch.long, device=self.device)
        r = torch.as_tensor(r, dtype=torch.float32, device=self.device)
        ns = torch.as_tensor(ns, dtype=torch.float32, device=self.device)
        d = torch.as_tensor(d, dtype=torch.float32, device=self.device)

        qvals = self.qnet(s).gather(1, a.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            next_actions = self.qnet(ns).argmax(dim=1, keepdim=True)
            next_q_target = self.target(ns).gather(1, next_actions).squeeze(1)
            # (1 - d) removes the bootstrap term for terminal transitions.
            target = r + self.gamma * next_q_target * (1 - d)

        loss = F.mse_loss(qvals, target)
        self.opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.qnet.parameters(), 0.5)
        self.opt.step()

        self.learn_steps += 1
        if self.learn_steps % self.target_update == 0:
            self.target.load_state_dict(self.qnet.state_dict())

        self.eps = max(self.eps_min, self.eps * self.eps_decay)
        return float(loss.item())

console.log("DQN agent for planner ready.")

# ------------------ Training loop ------------------
def train_planner(agent: DQNAgent, task: Dict, num_episodes=1000, max_steps=10, log_every=50,
                  seed_expert_episodes=0, save_after=100):
    """Train `agent` on `task` and return per-episode metrics.

    Args:
        agent: Agent providing select_action / push_transition / train_step.
        task: Task dict handed to PlannerEnv.
        num_episodes: Number of training episodes.
        max_steps: Step limit per episode (also passed to PlannerEnv).
        log_every: Log a summary every this many episodes (and on episode 1);
            also the window used for the average reward.
        seed_expert_episodes: Number of scripted expert episodes whose
            transitions are pushed into the replay buffer before training.
        save_after: The best-model checkpoint is only written for episodes
            after this one.

    Returns:
        Dict of lists, one entry per episode: `episode_rewards`, `losses`
        (mean loss per step), `episode_steps`, `success_rate_window`
        (mean success over the last 100 episodes) and `sample_eff`
        (cumulative reward divided by cumulative env steps).
    """
    metrics = {"episode_rewards": [], "losses": [], "episode_steps": [], "success_rate_window": [], "sample_eff": []}
    success_history = deque(maxlen=100)
    total_env_steps = 0

    ACTION_NAME_TO_INT = {v: k for k, v in PlannerAction.NAMES.items()}
    expert_seq = ["planning", "execution", "integration", "evaluation", "stop"]

    # Optionally pre-fill the replay buffer with the scripted, correct sequence.
    for _ in range(seed_expert_episodes):
        env = PlannerEnv(task, max_steps=max_steps)
        s = env.reset()
        for action_name in expert_seq:
            a = ACTION_NAME_TO_INT[action_name]
            ns, r, done, info = env.step(a)
            agent.push_transition(s, a, r, ns, float(done))
            s = ns
            if done:
                break

    best_success = 0.0
    best_avg_reward = -1e9

    for ep in tqdm(range(1, num_episodes+1), desc="Planner DQN Training"):
        env = PlannerEnv(task, max_steps=max_steps)
        state = env.reset()
        ep_reward = 0.0
        ep_loss = 0.0
        # Initialised up front so the bookkeeping below is well defined even
        # when the step loop does not execute at all (max_steps <= 0).
        steps_taken = 0
        done = False
        info: Dict[str, Any] = {}

        for _ in range(max_steps):
            action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)
            agent.push_transition(state, action, reward, next_state, float(done))
            loss = agent.train_step()
            state = next_state
            ep_reward += reward
            ep_loss += loss
            steps_taken += 1
            total_env_steps += 1
            if done:
                break

        metrics["episode_rewards"].append(ep_reward)
        metrics["losses"].append(ep_loss / steps_taken if steps_taken else 0.0)
        metrics["episode_steps"].append(steps_taken)
        success_history.append(1.0 if done and info.get("success", False) else 0.0)
        metrics["success_rate_window"].append(np.mean(success_history))
        metrics["sample_eff"].append(sum(metrics["episode_rewards"]) / (total_env_steps + 1e-8))

        current_success = metrics["success_rate_window"][-1]
        current_avg_reward = np.mean(metrics["episode_rewards"][-log_every:])

        # Checkpoint on a strictly better success rate, or the same rate with a better average reward.
        improved = current_success > best_success or (
            current_success == best_success and current_avg_reward > best_avg_reward
        )
        if ep > save_after and improved:
            best_success = current_success
            best_avg_reward = current_avg_reward
            torch.save(agent.qnet.state_dict(), "saved_models/planner_dqn_best.pt")
            console.log(f"[green]Ep {ep:4d} | New best model saved: success(last100)={best_success:.3%}, avg_reward={best_avg_reward:.3f}[/green]")

        if ep % log_every == 0 or ep == 1:
            console.log(f"[blue]Ep {ep:4d} | AvgReward {current_avg_reward:.3f} | "
                        f"Success(last100) {current_success:.3%} | Eps {agent.eps:.3f}")

    return metrics

# ------------------ Run training ------------------
if __name__ == '__main__':
    # Build the agent, train it on the single planner task and save the final weights.
    agent = DQNAgent(state_dim=STATE_DIM, action_dim=len(PlannerAction.NAMES), hidden=[512,256], lr=3e-4,
                     buffer_size=8000, batch_size=64, target_update=200, mask_actions=True, eps_decay=0.9995)

    metrics = train_planner(agent, PLANNER_TASK, num_episodes=3000, max_steps=10, log_every=50, seed_expert_episodes=0, save_after=100)

    torch.save(agent.qnet.state_dict(), "saved_models/planner_dqn_final.pt")
    console.log("Planner training complete. Models saved to saved_models/planner_dqn_*.pt")

    # Demo rollouts are meant to show the learned policy, so exploration is
    # switched off for them; otherwise epsilon (never below eps_min) would
    # still inject random actions into the printed paths.
    trained_eps = agent.eps
    agent.eps = 0.0
    try:
        for i in range(5):
            env = PlannerEnv(PLANNER_TASK, max_steps=10)
            s = env.reset()
            path = []
            info = {}
            # The environment ends the episode by max_steps at the latest.
            for _ in range(env.max_steps):
                a = agent.select_action(s)
                path.append(PlannerAction.NAMES[PlannerAction(a)])
                s, r, done, info = env.step(a)
                if done:
                    break
            console.log(f"Demo {i+1}: {path} | success={info.get('success',False)}")
    finally:
        agent.eps = trained_eps


Loading weights:   0%|          | 0/103 [00:00<?, ?it/s]

BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


Planner DQN Training:   0%|          | 0/3000 [00:00<?, ?it/s]