<a href="https://colab.research.google.com/github/manivafapour/ppt_mani/blob/main/RL_3Episode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Standard imports, dataset (TRAINING_TASKS) and dirs
import os, random, copy, math, json, time
from typing import List, Dict, Any, Tuple
from pathlib import Path
from collections import deque
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from rich.console import Console
from rich.progress import track
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

# Shared rich console used for all logging in this notebook.
console = Console()
# Output directories for logs and model checkpoints (no error if they already exist).
for _out_dir in ("logs", "saved_models"):
    os.makedirs(_out_dir, exist_ok=True)

# TRAINING_TASKS: three HumanEval-style tasks; the third one is "search-required".
# Each entry has:
#   task_id        - identifier that CodeGenEnv and SimulatedLLM branch on
#   prompt         - function signature + docstring (embedded into the agent's state)
#   tests          - assert statements executed by run_tests_locally
#   correct_code   - reference solution (SimulatedLLM.generate_code returns it for task1)
#   incorrect_code - buggy version SimulatedLLM.generate_code returns for the other tasks
#   debugged_code  - fixed code SimulatedLLM.debug_code returns (task3: only after SEARCH)
# Code strings annotate with `List` without importing it; run_tests_locally
# pre-executes `from typing import List` in the exec namespace.
TRAINING_TASKS = [
    {
        # Generated code already passes: the scripted path is generate -> test -> stop.
        "task_id": "task1_correct",
        "prompt": "def is_palindrome(s: str) -> bool:\n    \"\"\"Return True if string is palindrome, else False.\"\"\"\n",
        "tests": [
            "assert is_palindrome('racecar') == True",
            "assert is_palindrome('hello') == False",
            "assert is_palindrome('') == True",
            "assert is_palindrome('a') == True",
            "assert is_palindrome('madam') == True"
        ],
        "correct_code": """def is_palindrome(s: str) -> bool:
    \"\"\"Return True if string is palindrome, else False.\"\"\"
    return s == s[::-1]""",
        "incorrect_code": """def is_palindrome(s: str) -> bool:
    \"\"\"Return True if string is palindrome, else False.\"\"\"
    return s == s[0]""",
        "debugged_code": """def is_palindrome(s: str) -> bool:
    \"\"\"Return True if string is palindrome, else False.\"\"\"
    return s == s[::-1]"""
    },
    {
        # Generated code is wrong (returns the first element); one DEBUG fixes it:
        # generate -> test(fail) -> debug -> test(pass) -> stop.
        "task_id": "task2_incorrect",
        "prompt": "def find_max(nums: List[int]) -> int:\n    \"\"\"Return the maximum number in a list.\"\"\"\n",
        "tests": [
            "assert find_max([1, 2, 3, 4, 5]) == 5",
            "assert find_max([-1, -2, -3]) == -1",
            "assert find_max([10]) == 10",
            "assert find_max([5, 3, 9, 1, 7]) == 9",
            "assert find_max([0, 0, 0]) == 0"
        ],
        "correct_code": """def find_max(nums: List[int]) -> int:
    \"\"\"Return the maximum number in a list.\"\"\"
    if not nums:
        raise ValueError("List cannot be empty")
    return max(nums)""",
        "incorrect_code": """def find_max(nums: List[int]) -> int:
    \"\"\"Return the maximum number in a list.\"\"\"
    return nums[0]""",
        "debugged_code": """def find_max(nums: List[int]) -> int:
    \"\"\"Return the maximum number in a list.\"\"\"
    if not nums:
        raise ValueError("List cannot be empty")
    return max(nums)"""
    },
    {
        # Task that needs a SEARCH before the second DEBUG can produce correct code.
        # HumanEval style: small function; incorrect_code + debug behavior are modeled in SimulatedLLM.
        "task_id": "task3_search",
        "prompt": "def sum_unique(nums: List[int]) -> int:\n    \"\"\"Return the sum of elements that appear exactly once in the list.\"\"\"\n",
        "tests": [
            "assert sum_unique([1,2,2,3,4]) == 1+3+4",
            "assert sum_unique([]) == 0",
            "assert sum_unique([5,5,5]) == 0",
            "assert sum_unique([1,2,3]) == 6",
            "assert sum_unique([0,1,0,2,3,2]) == 1+3"
        ],
        "correct_code": """def sum_unique(nums: List[int]) -> int:
    \"\"\"Return the sum of elements that appear exactly once in the list.\"\"\"
    from collections import Counter
    c = Counter(nums)
    return sum(x for x, cnt in c.items() if cnt == 1)""",
        # incorrect_code: sums every distinct value instead of only the values whose count == 1
        "incorrect_code": """def sum_unique(nums: List[int]) -> int:
    \"\"\"Return the sum of elements that appear exactly once in the list.\"\"\"
    # Wrong: sums unique set rather than only elements with count == 1
    return sum(set(nums))""",
        # debugged_code: the correct implementation that a debug after a search yields
        "debugged_code": """def sum_unique(nums: List[int]) -> int:
    \"\"\"Return the sum of elements that appear exactly once in the list.\"\"\"
    from collections import Counter
    c = Counter(nums)
    return sum(x for x, cnt in c.items() if cnt == 1)"""
    }
]

# Utility local test runner: exec a code string, then run each test statement against it.
def run_tests_locally(code_str: str, tests: List[str]) -> Tuple[bool, Dict[str, Any]]:
    """Execute ``code_str`` and run every test statement in the same namespace.

    WARNING: uses ``exec``; only pass trusted code (here it comes from the in-notebook
    dataset / simulated LLM).

    Args:
        code_str: Python source defining the function(s) under test. ``List`` from
            ``typing`` is pre-imported into the namespace.
        tests: statements (typically ``assert ...``) executed after ``code_str``.

    Returns:
        ``(all_passed, feedback)``. ``feedback`` always contains ``number_passed``,
        ``total`` and ``failures`` (list of ``{"test", "error"}`` dicts). If
        ``code_str`` itself fails to execute, every test is counted as failed and
        ``feedback`` additionally carries ``error`` and ``trace``.
    """
    import traceback

    total = len(tests)
    exec_globals: Dict[str, Any] = {}
    try:
        exec("from typing import List", exec_globals)
        exec(code_str, exec_globals)
    except Exception as e:
        # Keep the same keys as the success path so callers can format
        # "Tests {number_passed}/{total}" without a KeyError.
        return False, {
            "number_passed": 0,
            "total": total,
            "failures": [{"test": test, "error": str(e)} for test in tests],
            "error": str(e),
            "trace": traceback.format_exc(),
        }

    feedback = {"number_passed": 0, "total": total, "failures": []}
    for test in tests:
        try:
            exec(test, exec_globals)
        except AssertionError:
            feedback["failures"].append({"test": test, "error": "AssertionError"})
        except Exception as e:
            feedback["failures"].append({"test": test, "error": str(e)})
        else:
            feedback["number_passed"] += 1
    return feedback["number_passed"] == total, feedback

# Signal that the dataset and the local test runner are defined.
console.log("Dataset and basic utilities loaded.")


In [None]:
from enum import IntEnum

class Action(IntEnum):
    """Discrete action space of the code-generation environment."""

    GENERATE = 0
    TEST = 1
    DEBUG = 2
    SEARCH = 3  # added action: unlocks the correct fix for the search-required task
    STOP = 4


# Lower-case, human-readable label per action, in definition order
# (used by the environment's `info` dict and by logging).
Action.NAMES = {member: member.name.lower() for member in Action}

print("Action space:", Action.NAMES)


Action space: {<Action.GENERATE: 0>: 'generate', <Action.TEST: 1>: 'test', <Action.DEBUG: 2>: 'debug', <Action.SEARCH: 3>: 'search', <Action.STOP: 4>: 'stop'}


In [None]:
#cell-1
# STRICT finite-state environment. Illegal actions used to terminate the episode;
# now they are soft failures: a -5 penalty (see `_fail`) plus a sticky `invalid_action`
# flag that rules out success for the rest of the episode.
# `last_test_passed` is exposed in the state so the agent can mask TEST after a pass.

class CodeGenEnv:
    """Finite-state environment for the generate -> test -> debug/search -> stop workflow.

    Reads ``task_id``, ``prompt`` and ``tests`` from the task dict. Code comes from the
    module-level ``sim_llm``, tests are executed with ``run_tests_locally`` and
    observations are built by ``make_state_embedding``.

    Reward per step: -0.1 step cost, plus +1 GENERATE, +1/-1 TEST (pass/fail),
    +2 DEBUG, +1.5 SEARCH, +30/-10 STOP (success/failure), an extra -5 for any
    illegal action, and -10 when the step limit ends the episode.
    """

    def __init__(self, task: Dict, max_steps: int = 10):
        self.task = task
        self.max_steps = max_steps
        self.reset()

    def reset(self):
        """Clear all per-episode state and return the initial observation."""
        self.steps = 0
        self.done = False
        self.success = False
        self.current_code = ""
        self.test_feedback = ""
        self.has_generated = False
        self.test_results = []  # one bool per executed TEST action
        self.debug_count = 0
        self.has_searched = False
        self.invalid_action = False  # sticky: set by any illegal action
        return self._state()

    def _fail(self):
        """Penalty for an illegal action.

        SOFT FAIL: the episode is NOT terminated, so exploration can continue and the
        agent can learn the correct preconditions.
        """
        return -5.0

    def step(self, action: int):
        """Apply one action and return ``(state, reward, done, info)``.

        ``info`` has keys ``action`` (action name), ``success`` and ``invalid_action``.

        Raises:
            RuntimeError: if the episode has already finished.
            ValueError: if ``action`` is not one of the ``Action`` values.
        """
        if self.done:
            raise RuntimeError("Episode already finished")
        if action not in Action.NAMES:
            # Reject before touching the step counter so a bad call leaves state unchanged.
            raise ValueError(
                f"Unknown action {action!r}; expected one of {[int(a) for a in Action.NAMES]}"
            )
        self.steps += 1
        reward = -0.1  # per-step cost
        info = {"action": Action.NAMES[action]}

        # ---------- GENERATE ----------
        if action == Action.GENERATE:
            # legal only once per episode
            if self.has_generated:
                reward += self._fail()
                self.invalid_action = True
            else:
                self.current_code = sim_llm.generate_code(self.task["task_id"])
                self.has_generated = True
                reward += 1.0

        # ---------- TEST ----------
        elif action == Action.TEST:
            # illegal: testing before generation
            if not self.has_generated:
                reward += self._fail()
                self.invalid_action = True
            # illegal: too many tests
            elif len(self.test_results) >= 4:
                reward += self._fail()
                self.invalid_action = True
            # illegal: re-testing immediately after a passed test (prevents extra test shortcuts)
            elif self.test_results and self.test_results[-1] is True:
                reward += self._fail()
                self.invalid_action = True
            else:
                passed, fb = run_tests_locally(self.current_code, self.task["tests"])
                self.test_results.append(passed)
                # .get() keeps this robust if the runner returns an error-only feedback
                # dict (e.g. when the candidate code itself fails to execute).
                n_passed = fb.get("number_passed", 0)
                n_total = fb.get("total", len(self.task["tests"]))
                self.test_feedback = f"Tests {n_passed}/{n_total}"
                reward += 1.0 if passed else -1.0

        # ---------- DEBUG ----------
        elif action == Action.DEBUG:
            # Legal only if: at least one test ran, the last test failed, fewer than two
            # debugs so far, and, for a second debug, either SEARCH has happened or
            # exactly two tests have been recorded.
            if (
                not self.test_results
                or self.test_results[-1] is True
                or self.debug_count >= 2
                or (self.debug_count == 1 and not self.has_searched and len(self.test_results) != 2)
            ):
                reward += self._fail()
                self.invalid_action = True
            else:
                self.debug_count += 1
                self.current_code = sim_llm.debug_code(
                    self.task["task_id"],
                    self.current_code,
                    searched=self.has_searched
                )
                reward += 2.0

        # ---------- SEARCH ----------
        elif action == Action.SEARCH:
            # Legal only once, after exactly one debug and exactly two failed tests.
            if (
                self.has_searched
                or self.debug_count != 1
                or self.test_results != [False, False]
            ):
                reward += self._fail()
                self.invalid_action = True
            else:
                self.has_searched = True
                reward += 1.5

        # ---------- STOP ----------
        elif action == Action.STOP:
            self.done = True
            tid = self.task["task_id"]
            # Success requires the task-specific trajectory AND no illegal action this episode.
            if tid != "task3_search" and self.test_results == [True] and not self.invalid_action:
                self.success = True
            elif (
                tid != "task3_search"
                and self.test_results == [False, True]
                and self.debug_count == 1
                and not self.invalid_action
            ):
                self.success = True
            elif (
                tid == "task3_search"
                and self.test_results == [False, False, True]
                and self.debug_count == 2
                and self.has_searched
                and not self.invalid_action
            ):
                self.success = True
            reward += 30.0 if self.success else -10.0

        # ---------- STEP LIMIT ----------
        if self.steps >= self.max_steps and not self.done:
            self.done = True
            reward -= 10.0

        info["success"] = self.success
        info["invalid_action"] = self.invalid_action  # debugging helper
        return self._state(), reward, self.done, info

    def _state(self):
        """Build the observation: text embeddings plus the episode's progress flags."""
        # The flags make the temporal preconditions explicit so the agent can learn/mask them.
        last_test_passed = bool(self.test_results[-1]) if self.test_results else False
        return make_state_embedding(
            self.task["prompt"],
            self.current_code,
            self.test_feedback,
            has_generated=self.has_generated,
            debug_count=self.debug_count,
            has_searched=self.has_searched,
            test_count=len(self.test_results),
            invalid_action=self.invalid_action,
            last_test_passed=last_test_passed
        )


In [None]:
#cell-2
# State embedding: encode the task prompt, the current code and the latest test
# feedback with a sentence-transformer and concatenate the three vectors.
# MODIFIED: a small numeric flag vector (has_generated, debug_count, has_searched,
# test_count, invalid_action, last_test_passed) is appended so the agent sees the
# episode's temporal progress explicitly.

# NOTE: In Colab you may need to pip install sentence-transformers:
# !pip install -q sentence-transformers

from sentence_transformers import SentenceTransformer
# Shared encoder used by make_state_embedding; loaded once when this cell runs
# (weights are downloaded from the Hugging Face Hub on first use).
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

def make_state_embedding(task_prompt: str, last_code: str, last_test_feedback: str,
                         has_generated: bool = False, debug_count: int = 0,
                         has_searched: bool = False, test_count: int = 0,
                         invalid_action: bool = False, last_test_passed: bool = False) -> np.ndarray:
    """Build the agent's float32 observation vector.

    Layout: ``[emb(task_prompt) | emb(last_code) | emb(last_test_feedback) | flags]``
    where ``flags`` holds six values in this order: has_generated, debug_count,
    has_searched, test_count, invalid_action, last_test_passed (booleans as 0.0/1.0,
    counts as plain floats). Consumers that read the trailing flags rely on this order.

    ``None``/empty code or feedback is embedded as the empty string.
    """
    # One batched encoder call instead of three separate forward passes; this function
    # runs on every environment step, so the saving adds up.
    texts = [task_prompt, last_code or "", last_test_feedback or ""]
    embeddings = embedder.encode(texts, show_progress_bar=False)
    # Row order is preserved, so flattening yields task | code | feedback.
    state_vec = np.asarray(embeddings, dtype=np.float32).reshape(-1)
    # append simple numeric flags (small in magnitude)
    flags = np.array([1.0 if has_generated else 0.0,
                      float(debug_count),
                      1.0 if has_searched else 0.0,
                      float(test_count),
                      1.0 if invalid_action else 0.0,
                      1.0 if last_test_passed else 0.0], dtype=np.float32)
    return np.concatenate([state_vec, flags]).astype(np.float32)

# Smoke test: embed an empty initial state to discover the observation size.
sample_state = make_state_embedding(
    TRAINING_TASKS[0]["prompt"], "", "",
    has_generated=False, debug_count=0, has_searched=False,
    test_count=0, invalid_action=False, last_test_passed=False,
)
STATE_DIM = sample_state.shape[0]
console.log(f"State dim {STATE_DIM}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Loading weights:   0%|          | 0/103 [00:00<?, ?it/s]

BertModel LOAD REPORT from: sentence-transformers/all-MiniLM-L6-v2
Key                     | Status     |  | 
------------------------+------------+--+-
embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [None]:
#cell-3
# SimulatedLLM: deterministic stand-in for a code model, with the SEARCH dependency enforced.

class SimulatedLLM:
    """Deterministic fake LLM that returns canned code per task id.

    * ``task1_correct``: generation already returns the correct code.
    * ``task2_incorrect``: generation is wrong; debugging returns the fix.
    * ``task3_search``: generation is wrong; debugging stays wrong unless the caller
      reports that a search happened (``searched=True``).
    * any other known task id: debugging returns ``current_code`` unchanged.
    """

    # Returned for task3_search when debugging WITHOUT a prior search. The `cnt >= 1`
    # condition still sums every distinct value, so the tests keep failing.
    _UNSEARCHED_TASK3_CODE = (
        "def sum_unique(nums: List[int]) -> int:\n"
        "    # Almost right but intentionally wrong\n"
        "    from collections import Counter\n"
        "    c = Counter(nums)\n"
        "    return sum(x for x, cnt in c.items() if cnt >= 1)  # wrong condition\n"
    )

    # Task field that a (successful) debug returns, per task id.
    _DEBUG_FIELD = {
        "task1_correct": "correct_code",
        "task2_incorrect": "debugged_code",
        "task3_search": "debugged_code",
    }

    def __init__(self, tasks: List[Dict]):
        self.tasks = {}
        for entry in tasks:
            self.tasks[entry["task_id"]] = entry

    def generate_code(self, task_id: str) -> str:
        spec = self.tasks[task_id]
        field = "correct_code" if task_id == "task1_correct" else "incorrect_code"
        return spec[field]

    def debug_code(self, task_id: str, current_code: str, searched: bool = False) -> str:
        spec = self.tasks[task_id]  # unknown task ids raise KeyError
        if task_id == "task3_search" and not searched:
            return self._UNSEARCHED_TASK3_CODE
        field = self._DEBUG_FIELD.get(task_id)
        if field is None:
            return current_code
        return spec[field]

    def search(self, task_id: str) -> bool:
        return task_id == "task3_search"


# 🔴 REQUIRED: CodeGenEnv.step calls this module-level instance by name
# (sim_llm.generate_code / sim_llm.debug_code), so it must exist before any env steps.
sim_llm = SimulatedLLM(TRAINING_TASKS)


In [None]:
#cell-4
# ---------- DQN agent (Double + Dueling) ----------
# DQN with a target network, a uniformly-sampled replay buffer and epsilon-greedy
# exploration. Action masking (derived from the flag vector at the end of the state)
# is applied in DQNAgent.select_action.

class DuelingQNetwork(nn.Module):
    """Dueling Q-network: a shared MLP trunk feeding separate value and advantage heads.

    ``Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a))``.

    Args:
        input_dim: size of the state vector.
        hidden: widths of the trunk's Linear+ReLU layers (any iterable of ints; a
            tuple default avoids the shared-mutable-default pitfall).
        output_dim: number of actions.
    """

    def __init__(self, input_dim, hidden=(512, 256), output_dim=len(Action.NAMES)):
        super().__init__()
        # shared trunk
        layers = []
        prev = input_dim
        for h in hidden:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU())
            prev = h
        self.trunk = nn.Sequential(*layers)
        # Each head's hidden width is half the trunk output, falling back to 32 when
        # that would round down to 0.
        head_dim = prev // 2 if prev // 2 > 0 else 32
        # value stream: one scalar V(s)
        self.value_head = nn.Sequential(
            nn.Linear(prev, head_dim),
            nn.ReLU(),
            nn.Linear(head_dim, 1)
        )
        # advantage stream: one A(s, a) per action
        self.adv_head = nn.Sequential(
            nn.Linear(prev, head_dim),
            nn.ReLU(),
            nn.Linear(head_dim, output_dim)
        )

    def forward(self, x):
        """Return Q-values; expects a batched 2-D input (advantages are averaged over dim 1)."""
        x = self.trunk(x)
        value = self.value_head(x)
        adv = self.adv_head(x)
        # combine into Q-values: Q(s,a) = V(s) + (A(s,a) - mean_a A(s,a))
        return value + (adv - adv.mean(dim=1, keepdim=True))

class ReplayBuffer:
    """Fixed-capacity FIFO store of ``(state, action, reward, next_state, done)`` transitions."""

    def __init__(self, capacity=8000):
        # A bounded deque silently evicts the oldest transition once full.
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, ns, done):
        """Append a single transition."""
        self.buffer.append((s, a, r, ns, done))

    def sample(self, batch_size):
        """Uniformly sample up to ``batch_size`` distinct transitions.

        Returns ``(states, actions, rewards, next_states, dones)`` as numpy arrays;
        rewards and dones are float32.
        """
        count = min(batch_size, len(self.buffer))
        picked = random.sample(self.buffer, count)
        states, actions, rewards, next_states, dones = zip(*picked)
        states_arr = np.stack(states)
        actions_arr = np.array(actions)
        rewards_arr = np.array(rewards, dtype=np.float32)
        next_states_arr = np.stack(next_states)
        dones_arr = np.array(dones, dtype=np.float32)
        return states_arr, actions_arr, rewards_arr, next_states_arr, dones_arr

    def __len__(self):
        return len(self.buffer)

class DQNAgent:
    """Double-DQN agent with a dueling Q-network, target network, uniform replay,
    epsilon-greedy exploration and optional action masking.

    When ``mask_actions`` is True the mask computed from the six trailing state flags
    (see ``_compute_action_mask``) is used both when acting and when choosing the
    bootstrap action in the Double-DQN target, so targets only bootstrap from actions
    the policy could actually take.
    """

    def __init__(self, state_dim, action_dim=len(Action.NAMES), hidden=(512, 256), lr=1e-4, gamma=0.99,
                 buffer_size=8000, batch_size=64, target_update=500, device=None, mask_actions=True):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.qnet = DuelingQNetwork(state_dim, hidden, action_dim).to(self.device)
        self.target = DuelingQNetwork(state_dim, hidden, action_dim).to(self.device)
        self.target.load_state_dict(self.qnet.state_dict())  # target starts equal to online net
        self.opt = Adam(self.qnet.parameters(), lr=lr)
        self.gamma = gamma
        self.buffer = ReplayBuffer(capacity=buffer_size)
        self.batch_size = batch_size
        self.action_dim = action_dim
        # epsilon-greedy schedule: multiplied by eps_decay after every gradient step,
        # floored at eps_min
        self.eps = 1.0
        self.eps_min = 0.05
        self.eps_decay = 0.995
        self.learn_steps = 0
        self.target_update = target_update  # hard copy online -> target every N gradient steps
        self.mask_actions = mask_actions

    def _compute_action_mask(self, state: np.ndarray):
        """Return a bool array (True = allowed) of length ``action_dim`` for ``state``.

        The state must end with six flags in this order: has_generated, debug_count,
        has_searched, test_count, invalid_action, last_test_passed. The rules are a
        conservative approximation of the preconditions in CodeGenEnv.step.
        """
        flags = state[-6:]
        has_generated = bool(flags[0])
        debug_count = int(flags[1])
        has_searched = bool(flags[2])
        test_count = int(flags[3])
        invalid_action = bool(flags[4])
        last_test_passed = bool(flags[5])

        mask = np.ones(self.action_dim, dtype=bool)  # True = allowed

        # GENERATE allowed only if not generated yet
        if has_generated:
            mask[Action.GENERATE] = False

        # TEST allowed only if generated, fewer than 4 tests, and last test was not a pass
        if (not has_generated) or test_count >= 4 or last_test_passed:
            mask[Action.TEST] = False

        # DEBUG allowed only after a failed test with debug_count < 2; a second debug
        # additionally requires exactly two tests (conservative)
        if test_count == 0 or last_test_passed or debug_count >= 2:
            mask[Action.DEBUG] = False
        elif debug_count == 1 and test_count != 2:
            mask[Action.DEBUG] = False

        # SEARCH allowed only with exactly two tests, the last one failed, one debug done,
        # and no previous search (approximates test_results == [False, False])
        if not (debug_count == 1 and test_count == 2 and (not last_test_passed) and (not has_searched)):
            mask[Action.SEARCH] = False

        # STOP allowed only if the last test passed and no invalid action happened
        if not (last_test_passed and (not invalid_action)):
            mask[Action.STOP] = False

        return mask

    def select_action(self, state):
        """Pick an action index for ``state`` (1-D numpy array) epsilon-greedily."""
        if random.random() < self.eps:
            # Exploring: sample uniformly among legal actions when masking is on.
            if self.mask_actions:
                legal_indices = np.flatnonzero(self._compute_action_mask(state))
                if len(legal_indices) > 0:
                    return int(np.random.choice(legal_indices))
            # no masking, or no legal action: uniform over the full action space
            return random.randrange(self.action_dim)

        s = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            qvals = self.qnet(s).cpu().numpy().squeeze(0)
        if self.mask_actions:
            mask = self._compute_action_mask(state)
            if not mask.any():
                # no legal action (should be rare) -> fall back to a random action
                return random.randrange(self.action_dim)
            legal_q = np.where(mask, qvals, -np.inf)
            return int(legal_q.argmax())
        return int(qvals.argmax())

    def push_transition(self, s, a, r, ns, done):
        """Store one transition in the replay buffer."""
        self.buffer.push(s, a, r, ns, done)

    def train_step(self):
        """Run one Double-DQN gradient step on a replay minibatch; return the loss.

        Returns 0.0 without updating while the buffer holds fewer than 32 transitions.
        """
        if len(self.buffer) < 32:
            return 0.0
        s_np, a_np, r_np, ns_np, d_np = self.buffer.sample(self.batch_size)
        s = torch.FloatTensor(s_np).to(self.device)
        a = torch.LongTensor(a_np).to(self.device)
        r = torch.FloatTensor(r_np).to(self.device)
        ns = torch.FloatTensor(ns_np).to(self.device)
        d = torch.FloatTensor(d_np).to(self.device)

        # current Q-values of the actions actually taken
        qvals = self.qnet(s).gather(1, a.unsqueeze(1)).squeeze(1)

        # ---- Double DQN target calculation ----
        # online network selects the next action, target network evaluates it
        with torch.no_grad():
            online_next = self.qnet(ns)
            if self.mask_actions:
                # Only bootstrap from actions the masked policy could take in ns;
                # rows with no legal action fall back to the unmasked argmax.
                legal = torch.as_tensor(
                    np.stack([self._compute_action_mask(x) for x in ns_np]),
                    device=self.device,
                )
                legal[~legal.any(dim=1)] = True
                online_next = online_next.masked_fill(~legal, float("-inf"))
            next_actions = online_next.argmax(dim=1, keepdim=True)  # shape (batch, 1)
            next_q_target = self.target(ns).gather(1, next_actions).squeeze(1)
            target = r + self.gamma * next_q_target * (1 - d)

        loss = F.mse_loss(qvals, target)
        self.opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.qnet.parameters(), 0.5)
        self.opt.step()

        self.learn_steps += 1
        if self.learn_steps % self.target_update == 0:
            self.target.load_state_dict(self.qnet.state_dict())

        # decay epsilon once per gradient step
        self.eps = max(self.eps_min, self.eps * self.eps_decay)
        return float(loss.item())

# Signal that the network, replay buffer and agent classes in this cell are defined.
console.log("DQN agent (Double + Dueling + Action masking) ready.")


In [None]:
#cell-5
# Training loop for DQN with logging / metrics (used for plots).

def _seed_expert_transitions(agent, tasks, max_steps, episodes_per_task):
    """Push scripted, successful trajectories for each known task into the replay buffer."""
    action_name_to_int = {v: k for k, v in Action.NAMES.items()}
    expert_seqs = {
        "task1_correct": ["generate", "test", "stop"],
        "task2_incorrect": ["generate", "test", "debug", "test", "stop"],
        # task3: gen -> test(fail) -> debug -> test(fail) -> search -> debug -> test(pass) -> stop
        "task3_search": ["generate", "test", "debug", "test", "search", "debug", "test", "stop"],
    }
    for task in tasks:
        seq = expert_seqs.get(task["task_id"])
        if seq is None:
            continue  # no script for this task
        for _ in range(episodes_per_task):
            env = CodeGenEnv(task, max_steps=max_steps)
            s = env.reset()
            for action_name in seq:
                a = action_name_to_int[action_name]
                ns, r, done, _ = env.step(a)
                agent.push_transition(s, a, r, ns, float(done))
                s = ns
                if done:
                    break


def train_dqn(agent: DQNAgent, tasks: List[Dict], num_episodes=1200, max_steps=10, log_every=50,
              seed_expert_episodes_per_task=8, best_min_episodes=None):
    """Train ``agent`` on randomly chosen tasks and return per-episode metrics.

    The replay buffer is first seeded with scripted expert episodes. During training the
    best model (highest success over the last 100 episodes, ties broken by the mean
    reward of the last ``log_every`` episodes) is saved to ``saved_models/dqn_best.pt``.

    Args:
        agent: the DQN agent to train (mutated in place).
        tasks: task dicts; one is chosen uniformly at random per episode.
        num_episodes: number of training episodes.
        max_steps: per-episode step limit (must be >= 1).
        log_every: logging interval, also the window for the average-reward tie-break.
        seed_expert_episodes_per_task: scripted episodes pushed per task before training.
        best_min_episodes: first episode at which the best-model checkpoint may be saved.
            Defaults to ``min(100, num_episodes)`` so the success window is full before
            checkpoints are compared; otherwise an early window of a single successful
            episode (100%) can never be beaten and locks in an untrained "best" model.

    Returns:
        dict of lists: ``episode_rewards``, ``success_rate_window`` (mean success over the
        last 100 episodes), ``episode_steps``, ``losses`` (mean loss per step) and
        ``sample_eff`` (cumulative reward / cumulative env steps).

    Raises:
        ValueError: if ``max_steps < 1``.
    """
    if max_steps < 1:
        raise ValueError("max_steps must be >= 1")

    metrics = {
        "episode_rewards": [],
        "success_rate_window": [],
        "episode_steps": [],
        "losses": [],
        "sample_eff": []
    }
    success_history = deque(maxlen=100)
    if best_min_episodes is None:
        best_min_episodes = min(success_history.maxlen, num_episodes)
    total_env_steps = 0
    cumulative_reward = 0.0  # running sum of episode rewards (for sample_eff)

    # --- Seed replay buffer with expert episodes (scripted correct trajectories) ---
    _seed_expert_transitions(agent, tasks, max_steps, seed_expert_episodes_per_task)

    best_success = -1.0  # best "success(last100)" seen so far
    best_avg_reward = -1e9

    for ep in tqdm(range(1, num_episodes + 1), desc="DQN Training"):
        task = random.choice(tasks)
        env = CodeGenEnv(task, max_steps=max_steps)
        state = env.reset()
        ep_reward = 0.0
        ep_loss = 0.0

        for t in range(max_steps):
            action = agent.select_action(state)
            next_state, reward, done, info = env.step(action)
            agent.push_transition(state, action, reward, next_state, float(done))
            loss = agent.train_step()
            state = next_state
            ep_reward += reward
            ep_loss += loss
            total_env_steps += 1
            if done:
                break

        n_steps = t + 1
        metrics["episode_rewards"].append(ep_reward)
        metrics["losses"].append(ep_loss / n_steps)
        metrics["episode_steps"].append(n_steps)
        success_history.append(1.0 if done and info.get("success", False) else 0.0)

        metrics["success_rate_window"].append(np.mean(success_history))
        # Running total instead of re-summing the whole reward history every episode (O(n^2)).
        cumulative_reward += ep_reward
        metrics["sample_eff"].append(cumulative_reward / (total_env_steps + 1e-8))

        # check for best model (prefer success rate over avg reward), only once the
        # success window is meaningful
        current_success = metrics["success_rate_window"][-1]
        current_avg_reward = np.mean(metrics["episode_rewards"][-log_every:])
        if ep >= best_min_episodes and (
            current_success > best_success
            or (current_success == best_success and current_avg_reward > best_avg_reward)
        ):
            best_success = current_success
            best_avg_reward = current_avg_reward
            torch.save(agent.qnet.state_dict(), "saved_models/dqn_best.pt")
            console.log(f"[green]Ep {ep:4d} | New best model saved: success(last100)={best_success:.3%}, avg_reward={best_avg_reward:.3f}[/green]")

        if ep % log_every == 0 or ep == 1:
            console.log(f"[blue]Ep {ep:4d} | AvgReward {current_avg_reward:.3f} | "
                        f"Success(last100) {metrics['success_rate_window'][-1]:.3%} | Eps {agent.eps:.3f}")

    return metrics

# Short training run (smoke-test sized); raise num_episodes to train longer.
agent = DQNAgent(
    state_dim=STATE_DIM,
    action_dim=len(Action.NAMES),
    hidden=[512, 256],
    lr=3e-4,
    buffer_size=8000,
    batch_size=64,
    target_update=200,
    mask_actions=True,
)
metrics = train_dqn(agent, TRAINING_TASKS, num_episodes=500, max_steps=10, log_every=50)

# Persist the last weights; the best checkpoint was already written during training.
torch.save(agent.qnet.state_dict(), "saved_models/dqn_final.pt")
console.log("Training complete and final model saved.")


DQN Training:   0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
#cell-7
# Reload the best checkpoint saved during training (falling back to the final one),
# re-save the loaded weights as the final model, then show demo rollouts per task.

best_path = "saved_models/dqn_best.pt"
final_path = "saved_models/dqn_final.pt"


def _load_weights(path: str) -> None:
    """Load a saved state_dict from ``path`` into ``agent.qnet`` on the agent's device."""
    agent.qnet.load_state_dict(torch.load(path, map_location=agent.device))


# load best if available, otherwise the final checkpoint
if os.path.exists(best_path):
    try:
        _load_weights(best_path)
        console.log(f"Loaded best model from {best_path}")
    except Exception as e:
        console.log(f"[red]Failed to load best model ({e}), loading final model if present.[/red]")
        if os.path.exists(final_path):
            _load_weights(final_path)
            console.log(f"Loaded final model from {final_path}")
elif os.path.exists(final_path):
    _load_weights(final_path)
    console.log(f"No best model found; loaded final model from {final_path}")
else:
    console.log("[yellow]No saved model found; running with current agent weights.[/yellow]")

# NOTE: this overwrites dqn_final.pt with whatever weights are now loaded
# (the best checkpoint when one was found).
torch.save(agent.qnet.state_dict(), final_path)
console.log("Final model saved to saved_models/dqn_final.pt")

# Demo rollouts should show the learned policy, so epsilon-greedy exploration is
# disabled for the demos (set DEMO_EPS > 0 to see exploratory variation) and the
# training epsilon is restored afterwards.
DEMO_EPS = 0.0
_saved_eps = agent.eps
agent.eps = DEMO_EPS
try:
    for task in TRAINING_TASKS:
        console.log(f"[bold]Demo rollouts for {task['task_id']}[/bold]")
        for i in range(3):
            env = CodeGenEnv(task, max_steps=10)
            s = env.reset()
            path = []
            info = {}
            # the env always ends the episode by its own step limit
            for _ in range(env.max_steps):
                a = agent.select_action(s)
                path.append(Action.NAMES[a])
                s, _reward, done, info = env.step(a)
                if done:
                    break
            console.log(f"  Path {i+1}: {path} | success={info.get('success',False)}")
finally:
    agent.eps = _saved_eps
