In [5]:
import os
import time
import requests
import random
import numpy as np
import inspect
import math
import glob
import re
import torch
from concurrent.futures import ThreadPoolExecutor, as_completed
import torch.nn as nn
import gymnasium as gym

# ----------------------------
# 配置 LLM（保持你原有的 KEY/URL）
# ----------------------------
# Endpoint and headers used by call_llm().
# NOTE(security): a bearer token is committed in source below. It is kept only
# as a backward-compatible default; set LLM_API_KEY (and optionally LLM_URL) in
# the environment, and rotate/remove the embedded key before sharing this file.
LLM_URL = os.environ.get('LLM_URL', 'https://api.yesapikey.com/v1/chat/completions')
LLM_HEADERS = {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer ' + os.environ.get(
        'LLM_API_KEY', 'sk-oUmlgJFV5BaTBy8y7048F0E4Af2b4031AdA7B24037F9Bd71'
    ),
}

def call_llm(prompt, model="gpt-4.1-2025-04-14", temperature=0.2, max_tokens=1024,
             max_retries=None, retry_delay=2):
    """Send a single-turn chat completion request and return the reply text.

    The request is retried after ``retry_delay`` seconds whenever the HTTP
    status is not 200, the response carries no usable ``content``, or any
    exception is raised while posting/decoding.

    Args:
        prompt: User message sent as the only chat message.
        model: Model identifier placed in the request body.
        temperature: Sampling temperature placed in the request body.
        max_tokens: Completion token limit placed in the request body.
        max_retries: Number of retries allowed after the first attempt.
            ``None`` (default) retries forever, as the original loop did.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        The ``content`` string of the first choice (never ``None``).

    Raises:
        RuntimeError: If ``max_retries`` is set and all attempts failed.
    """
    data = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "max_tokens": max_tokens
    }
    attempt = 0
    while True:
        attempt += 1
        try:
            response = requests.post(LLM_URL, json=data, headers=LLM_HEADERS, timeout=60)
            if response.status_code == 200:
                resp_json = response.json()
                choices = resp_json.get('choices') if isinstance(resp_json, dict) else None
                if choices:
                    content = choices[0].get('message', {}).get('content')
                    # A missing content used to be returned as None, which
                    # crashed callers that call .strip() on the reply.
                    if content is not None:
                        return content
                print("[LLM] 200 response without usable content; retrying")
            else:
                print("[LLM HTTP]", response.status_code, response.text[:200])
        except Exception as e:
            # Best-effort by design: network and decoding errors are retried.
            print("[LLM Exception]", e)
        if max_retries is not None and attempt > max_retries:
            raise RuntimeError(f"LLM request failed after {attempt} attempts")
        time.sleep(retry_delay)

# ----------------------------
# 环境文档映射 & 静态知识
# ----------------------------
# Documentation page for each supported Gymnasium environment; the URL is
# embedded in LLM prompts. (The stray trailing space each URL carried has
# been removed.)
ENV_DOC_URL = {
    "Acrobot-v1": "https://gymnasium.farama.org/environments/classic_control/acrobot/",
    "CartPole-v1": "https://gymnasium.farama.org/environments/classic_control/cart_pole/",
    "MountainCarContinuous-v0": "https://gymnasium.farama.org/environments/classic_control/mountain_car_continuous/",
    "MountainCar-v0": "https://gymnasium.farama.org/environments/classic_control/mountain_car/",
    "Pendulum-v1": "https://gymnasium.farama.org/environments/classic_control/pendulum/",
}

def get_env_doc_url(env_id: str) -> str:
    """Return the documentation URL for ``env_id``.

    Unknown ids fall back to the Gymnasium home page. The result is stripped
    so that no surrounding whitespace leaks into prompts.
    """
    return ENV_DOC_URL.get(env_id, "https://gymnasium.farama.org/").strip()

# Hand-written facts about each supported environment. Every entry has:
#   state_dim        - number of observation components
#   state_vars       - a name for each observation component
#   state_ranges     - (low, high) pair for each observation component
#   action_space     - valid integers (discrete) or [low, high] (continuous)
#   reward_threshold - episode return counted as a success by Reflection.metrics
#   action_type      - "discrete" or "continuous"
STATIC_KNOWLEDGE = {
    "CartPole-v1": {
        "state_dim": 4,
        "state_vars": [
            "cart_position",
            "cart_velocity",
            "pole_angle",
            "pole_velocity",
        ],
        "state_ranges": [
            (-4.8, 4.8),
            (-math.inf, math.inf),
            (-0.418, 0.418),
            (-math.inf, math.inf),
        ],
        "action_space": [0, 1],
        "reward_threshold": 475,
        "action_type": "discrete",
    },
    "Acrobot-v1": {
        "state_dim": 6,
        "state_vars": [
            "cos_theta1",
            "sin_theta1",
            "cos_theta2",
            "sin_theta2",
            "theta1_dot",
            "theta2_dot",
        ],
        "state_ranges": [
            (-1, 1),
            (-1, 1),
            (-1, 1),
            (-1, 1),
            (-math.inf, math.inf),
            (-math.inf, math.inf),
        ],
        "action_space": [0, 1, 2],
        "reward_threshold": -100,
        "action_type": "discrete",
    },
    "MountainCar-v0": {
        "state_dim": 2,
        "state_vars": ["position", "velocity"],
        "state_ranges": [(-1.2, 0.6), (-0.07, 0.07)],
        "action_space": [0, 1, 2],
        "reward_threshold": -110,
        "action_type": "discrete",
    },
    "MountainCarContinuous-v0": {
        "state_dim": 2,
        "state_vars": ["position", "velocity"],
        "state_ranges": [(-1.2, 0.6), (-0.07, 0.07)],
        "action_space": [-1.0, 1.0],
        "reward_threshold": 90,
        "action_type": "continuous",
    },
    "Pendulum-v1": {
        "state_dim": 3,
        "state_vars": ["cos_theta", "sin_theta", "theta_dot"],
        "state_ranges": [(-1, 1), (-1, 1), (-math.inf, math.inf)],
        "action_space": [-2.0, 2.0],
        "reward_threshold": -200,
        "action_type": "continuous",
    },
}

# Directory searched by find_expert_file_for_env() for saved checkpoints.
EXPERTS_DIR = "experts"

# ----------------------------
# Knowledge / Memory 类（基本不变）
# ----------------------------
class Knowledge:
    """Static facts about one environment plus notes gathered during a run.

    ``static_knowledge`` holds the STATIC_KNOWLEDGE entry of the loaded
    environment; ``dynamic_knowledge`` is a list of entries appended through
    ``add_dynamic_entry``.
    """

    def __init__(self):
        self.static_knowledge = {}
        self.dynamic_knowledge = []

    def load_static_knowledge(self, env_id):
        """Select the static entry for ``env_id`` and clear the dynamic notes.

        Raises:
            ValueError: If ``env_id`` has no STATIC_KNOWLEDGE entry.
        """
        if env_id in STATIC_KNOWLEDGE:
            self.static_knowledge = STATIC_KNOWLEDGE[env_id]
            self.dynamic_knowledge = []
            return
        raise ValueError("Unsupported environment")

    def add_dynamic_entry(self, entry):
        """Append one entry to the dynamic notes."""
        self.dynamic_knowledge.append(entry)

    def get_dynamic_guidance(self, env_id):
        """Ask the LLM to turn the dynamic notes into short policy heuristics.

        Returns whatever ``call_llm`` returns for the prompt.
        """
        prompt = f"""
I am generating a policy in environment {env_id}.
Current dynamic knowledge entries: {self.dynamic_knowledge}

Focus on environment principles, physics, and dynamics, not superficial patterns.
Please provide concise heuristic suggestions for policy generation based on this knowledge, such as:
- State ranges to prioritize
- Common failing action patterns
- Recommended threshold adjustments

Return a short, structured bullet list (no prose).
"""
        return call_llm(prompt)

class Memory:
    """Append-only log of episodes; each has a ``steps`` list and a ``summary``."""

    def __init__(self):
        self.episodes = []

    def start_episode(self):
        """Open a new, empty episode record."""
        self.episodes.append({"steps": [], "summary": None})

    def add_step(self, s, a, r, done):
        """Record one transition in the most recent episode.

        The state is stored as a plain list, a scalar action as ``float`` and
        any other action as a list, the reward as ``float`` and ``done`` as
        ``bool``.

        Raises:
            ValueError: If no episode has been started.
        """
        if len(self.episodes) == 0:
            raise ValueError("Please call start_episode() before adding steps!")
        state = np.array(s).tolist()
        if np.isscalar(a):
            action = float(a)
        else:
            action = np.array(a).tolist()
        record = {"s": state, "a": action, "r": float(r), "done": bool(done)}
        self.episodes[-1]["steps"].append(record)

    def add_episode_summary(self, env_id, policy_version):
        """Attach return/length totals to the most recent episode.

        Raises:
            ValueError: If no episode has been started.
        """
        if len(self.episodes) == 0:
            raise ValueError("No running episode!")
        current = self.episodes[-1]
        rewards = [step["r"] for step in current["steps"]]
        current["summary"] = {
            "env_id": env_id,
            "policy_version": policy_version,
            "return": sum(rewards),
            "length": len(rewards),
        }

    def get_recent_episodes(self, n=5):
        """Return the summaries of the last ``n`` summarized episodes."""
        finished = []
        for episode in self.episodes:
            if episode["summary"] is not None:
                finished.append(episode["summary"])
        return finished[-n:]

# ----------------------------
# Expert loader / Lightweight wrapper（改进版）
# ----------------------------
def find_expert_file_for_env(env_id):
    """Locate a saved expert checkpoint for ``env_id`` inside EXPERTS_DIR.

    Filename patterns are tried in priority order; for the first pattern with
    any match, the most recently modified file is returned. Returns ``None``
    when no pattern matches.
    """
    name_patterns = (
        f"{env_id}_best_expert.*",
        f"expert_{env_id}*",
        f"{env_id}*cleanrl_model",
        f"*{env_id}*cleanrl_model",
    )
    for name_pattern in name_patterns:
        candidates = glob.glob(os.path.join(EXPERTS_DIR, name_pattern))
        if not candidates:
            continue
        # Newest file wins; ties keep glob order, like a stable reverse sort.
        return max(candidates, key=os.path.getmtime)
    return None

def _unwrap_state_dict(maybe_dict):
    """If saved the checkpoint was a dict wrapper, try to unwrap known fields."""
    if isinstance(maybe_dict, dict):
        for key in ("state_dict", "model_state_dict", "policy_state_dict", "params", "model"):
            if key in maybe_dict and isinstance(maybe_dict[key], dict):
                return maybe_dict[key]
        return maybe_dict
    else:
        return None

# Define PPO_Agent, SAC_Actor, and DQN_MLP classes
class PPO_Agent(nn.Module):
    """Gaussian actor-critic with two 64-unit tanh hidden layers per head.

    Submodule and parameter names (``critic``, ``actor_mean``,
    ``actor_logstd``) are what ``ExpertWrapper._load`` looks for in a
    checkpoint, so they must not change.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()

        def init_linear(in_features, out_features, std=np.sqrt(2), bias_const=0.0):
            # Orthogonal weights with gain ``std`` and a constant bias.
            linear = nn.Linear(in_features, out_features)
            torch.nn.init.orthogonal_(linear.weight, std)
            torch.nn.init.constant_(linear.bias, bias_const)
            return linear

        def build_head(out_features, head_std):
            return nn.Sequential(
                init_linear(obs_dim, 64),
                nn.Tanh(),
                init_linear(64, 64),
                nn.Tanh(),
                init_linear(64, out_features, std=head_std),
            )

        # Critic is built before the actor so initialization order is kept.
        self.critic = build_head(1, 1.0)
        self.actor_mean = build_head(act_dim, 0.01)
        self.actor_logstd = nn.Parameter(torch.zeros(1, act_dim))

    def get_value(self, x):
        """Return the critic output for observations ``x``."""
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        """Return ``(action, log_prob, entropy, value)`` for observations ``x``.

        When ``action`` is ``None`` one is sampled from the Normal policy;
        log-prob and entropy are summed over dimension 1.
        """
        mean = self.actor_mean(x)
        std = torch.exp(self.actor_logstd.expand_as(mean))
        policy_dist = torch.distributions.Normal(mean, std)
        chosen = policy_dist.sample() if action is None else action
        log_prob = policy_dist.log_prob(chosen).sum(1)
        entropy = policy_dist.entropy().sum(1)
        return chosen, log_prob, entropy, self.critic(x)

class SAC_Actor(nn.Module):
    """Tanh-squashed Gaussian actor with two 256-unit ReLU hidden layers.

    Layer names (``fc1``, ``fc2``, ``fc_mean``, ``fc_logstd``) and the
    ``action_scale`` / ``action_bias`` buffers are part of the state dict and
    must keep their names for checkpoint loading.
    """

    # Range into which the tanh-bounded log-std output is rescaled.
    LOG_STD_MAX = 2
    LOG_STD_MIN = -5

    def __init__(self, obs_dim, act_dim, action_low, action_high):
        super().__init__()
        self.fc1 = nn.Linear(obs_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc_mean = nn.Linear(256, act_dim)
        self.fc_logstd = nn.Linear(256, act_dim)
        # Affine map taking tanh output in [-1, 1] to [action_low, action_high].
        action_scale = (action_high - action_low) / 2.0
        action_bias = (action_high + action_low) / 2.0
        self.register_buffer("action_scale", torch.tensor(action_scale, dtype=torch.float32))
        self.register_buffer("action_bias", torch.tensor(action_bias, dtype=torch.float32))

    def forward(self, x):
        """Return ``(mean, log_std)`` of the pre-squash Gaussian for ``x``."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        mean = self.fc_mean(x)
        log_std = torch.tanh(self.fc_logstd(x))
        span = self.LOG_STD_MAX - self.LOG_STD_MIN
        log_std = self.LOG_STD_MIN + 0.5 * span * (log_std + 1)
        return mean, log_std

    def get_action(self, x):
        """Sample a squashed action.

        Returns:
            ``(action, log_prob, mean_action)`` where ``action`` is a
            reparameterized sample mapped into the action range, ``log_prob``
            is its log-density summed over dimension 1 (keepdim), and
            ``mean_action`` is the squashed, rescaled mean.
        """
        mean, log_std = self.forward(x)
        std = torch.exp(log_std)
        normal = torch.distributions.Normal(mean, std)
        x_t = normal.rsample()
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # Change-of-variables correction for the tanh squash and rescale;
        # without it the value is the density of x_t, not of ``action``.
        log_prob = log_prob - torch.log(self.action_scale * (1 - y_t.pow(2)) + 1e-6)
        log_prob = log_prob.sum(1, keepdim=True)
        mean_action = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean_action

class DQN_MLP(nn.Module):
    """Q-network: obs -> 120 -> 84 -> one value per action, ReLU in between.

    The layers live in ``self.network`` so state-dict keys keep the
    ``network.<index>`` form that ``ExpertWrapper._load`` detects.
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        layers = []
        in_features = obs_dim
        for width in (120, 84):
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        layers.append(nn.Linear(in_features, act_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values for observations ``x``."""
        return self.network(x)

class ExpertWrapper:
    """Lazily loaded wrapper around a saved expert policy checkpoint.

    The checkpoint is read on first use. Its key names decide which network
    class is built (``SAC_Actor``, ``PPO_Agent`` or ``DQN_MLP``); observation
    and action sizes come from STATIC_KNOWLEDGE for ``env_id``.
    """

    def __init__(self, env_id, model_path, device="cpu"):
        self.env_id = env_id
        self.model_path = model_path
        self.device = device
        self.model_type = None
        self.policy = None
        self.loaded = False
        sk = STATIC_KNOWLEDGE[env_id]
        self.obs_dim = sk["state_dim"]
        self.action_type = sk["action_type"]
        # Continuous environments are treated as having one action dimension.
        self.act_dim = len(sk["action_space"]) if sk["action_type"] == "discrete" else 1
        self.action_low = sk["action_space"][0] if sk["action_type"] == "continuous" else None
        self.action_high = sk["action_space"][-1] if sk["action_type"] == "continuous" else None

    def _load(self):
        """Read the checkpoint, build the matching network and load weights.

        Raises:
            ValueError: If the checkpoint is not a dict, its keys match no
                known layout, or an SAC checkpoint is paired with a
                discrete-action environment.
        """
        print(f"[Info] loading expert from {self.model_path}")
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be placed in the experts directory.
        raw = torch.load(self.model_path, map_location="cpu")
        # Accept both a bare state dict and one nested in a wrapper dict.
        state = _unwrap_state_dict(raw)
        if state is None:
            raise ValueError(f"Unknown expert checkpoint format: type={type(raw).__name__}")
        keys = list(state.keys())
        print("[Info] Model keys:", keys[:10])  # Debug print
        if any(k.startswith(('fc1', 'fc_mean')) for k in keys):
            if self.action_low is None or self.action_high is None:
                raise ValueError(
                    f"SAC checkpoint needs a continuous action range, but {self.env_id} is discrete"
                )
            self.model_type = "sac_actor"
            print("[Info] Detected checkpoint type: SAC")
            self.policy = SAC_Actor(self.obs_dim, self.act_dim, self.action_low, self.action_high)
        elif any(k == 'actor_logstd' or k.startswith('actor_mean') for k in keys):
            self.model_type = "ppo_policy"
            print("[Info] Detected checkpoint type: PPO")
            self.policy = PPO_Agent(self.obs_dim, self.act_dim)
        elif any(k.startswith('network.') for k in keys):
            self.model_type = "dqn"
            print("[Info] Detected checkpoint type: DQN")
            self.policy = DQN_MLP(self.obs_dim, self.act_dim)
        else:
            raise ValueError(f"Unknown expert checkpoint format: keys={keys[:10]}")
        # strict=False tolerates minor mismatches; report them so a partly
        # initialized expert does not go unnoticed.
        result = self.policy.load_state_dict(state, strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(f"[Warn] checkpoint mismatch: missing={list(result.missing_keys)} "
                  f"unexpected={list(result.unexpected_keys)}")
        self.policy.to(self.device)
        self.policy.eval()
        self.loaded = True
        print(f"[Info] Expert loaded ({self.model_type})")

    def act(self, obs, deterministic=True):
        """Return the expert's action for one observation.

        Args:
            obs: A single observation; a 1-D input gets a batch dimension.
            deterministic: For PPO/SAC, use the policy mean instead of a
                sample. DQN always takes the argmax of the Q-values.

        Raises:
            RuntimeError: If the loaded model type is not supported.
        """
        if not self.loaded:
            self._load()
        obs = torch.as_tensor(obs, dtype=torch.float32).to(self.device)
        if obs.ndim == 1:
            obs = obs.unsqueeze(0)

        with torch.no_grad():
            if self.model_type == "ppo_policy":
                if deterministic:
                    # Previously this branch always sampled, which made
                    # "deterministic" evaluations noisy.
                    action = self.policy.actor_mean(obs)
                else:
                    action, _, _, _ = self.policy.get_action_and_value(obs)
            elif self.model_type == "sac_actor":
                sampled, _, mean = self.policy.get_action(obs)
                action = mean if deterministic else sampled
            elif self.model_type == "dqn":
                action = torch.argmax(self.policy(obs), dim=1)
            else:
                raise RuntimeError("Unsupported expert type")

        return action.cpu().numpy()[0]

    def run_episodes(self, episodes=3, render=False):
        """Roll out the expert deterministically and return one return per episode.

        Args:
            episodes: Number of episodes to run.
            render: If true, the environment is created with
                ``render_mode="human"``.

        Returns:
            A list with ``episodes`` floats.
        """
        env = self.make_env_with_wrappers(render_mode="human" if render else None)
        returns = []
        try:
            for _ in range(episodes):
                obs, _ = env.reset()
                done = False
                total_r = 0.0
                while not done:
                    action = self.act(obs, deterministic=True)
                    obs, r, terminated, truncated, _ = env.step(action)
                    total_r += r
                    done = terminated or truncated
                # Recorded per episode; it used to sit outside the loop, so
                # only the last episode was ever reported.
                returns.append(float(total_r))
        finally:
            env.close()
        return returns

    def make_env_with_wrappers(self, render_mode=None):
        """Create the environment, adding observation/action wrappers for PPO.

        The checkpoint is loaded first so that ``model_type`` is known before
        deciding whether wrappers are needed.
        """
        if not self.loaded:
            self._load()
        env = gym.make(self.env_id, render_mode=render_mode)
        if self.model_type != "ppo_policy":
            return env
        # Each wrapper is best-effort: a failure is reported and skipped.
        wrappers = (
            ("FlattenObservation", lambda e: gym.wrappers.FlattenObservation(e)),
            ("ClipAction", lambda e: gym.wrappers.ClipAction(e)),
            ("NormalizeObservation", lambda e: gym.wrappers.NormalizeObservation(e)),
            ("TransformObservation",
             lambda e: gym.wrappers.TransformObservation(e, lambda obs: np.clip(obs, -10, 10))),
        )
        for name, wrap in wrappers:
            try:
                env = wrap(env)
            except Exception as exc:
                print(f"[Warn] could not apply {name}: {exc}")
        return env

def evaluate_expert(env, expert: ExpertWrapper, n_eval_episodes=3):
    """Run the expert deterministically in ``env`` and return its mean return.

    The first observation/action of every episode and each episode total are
    printed for debugging.
    """
    episode_returns = []
    for episode_no in range(1, n_eval_episodes + 1):
        obs, _ = env.reset()
        total_reward = 0
        steps = 0
        finished = False
        while not finished:
            action = expert.act(obs, deterministic=True)
            if steps == 0:
                # One sample per episode to sanity-check obs/action format.
                print(f"[ExpertEval debug] sample obs: {obs}")
                print(f"[ExpertEval debug] sample action: {action}")
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += reward
            steps += 1
            finished = terminated or truncated
        episode_returns.append(total_reward)
        print(f"[ExpertEval] Episode {episode_no}: total_reward={total_reward:.3f}, steps={steps}")
    avg_return = sum(episode_returns) / len(episode_returns)
    print(f"[ExpertEval] Avg Return over {n_eval_episodes} eps: {avg_return:.2f}")
    return avg_return

# ----------------------------
# Reflection (集成 expert)
# ----------------------------
class Reflection:
    """Summarizes recent episodes and asks the LLM for diagnoses and edits.

    Results of the LLM calls are also appended to the shared ``Knowledge``
    object as dynamic entries.
    """

    def __init__(self, knowledge: "Knowledge", use_expert=True):
        self.knowledge = knowledge
        self.use_expert = use_expert
        # env_id -> ExpertWrapper, or None when no expert could be set up.
        self._expert_cache = {}

    def load_expert_for_env(self, env_id):
        """Return a cached ``ExpertWrapper`` for ``env_id``, or ``None``.

        ``None`` is returned when experts are disabled, no checkpoint file is
        found, or constructing the wrapper fails; negative results are cached.
        """
        if not self.use_expert:
            return None
        if env_id in self._expert_cache:
            return self._expert_cache[env_id]
        path = find_expert_file_for_env(env_id)
        if path is None:
            self._expert_cache[env_id] = None
            return None
        try:
            wrapper = ExpertWrapper(env_id, path)
            self._expert_cache[env_id] = wrapper
            return wrapper
        except Exception as e:
            print("[Reflection] failed to load expert:", e)
            self._expert_cache[env_id] = None
            return None

    def metrics(self, recent_episodes):
        """Compute average return, average length and success rate.

        An episode is a success when its return reaches the environment's
        ``reward_threshold``. All three values are 0.0 for an empty list.
        """
        returns = [ep["return"] for ep in recent_episodes]
        lengths = [ep["length"] for ep in recent_episodes]
        avg_return = np.mean(returns) if returns else 0.0
        avg_length = np.mean(lengths) if lengths else 0.0
        threshold = self.knowledge.static_knowledge.get("reward_threshold", 0.0)
        success_count = sum(1 for ep in recent_episodes if ep["return"] >= threshold)
        success_rate = (success_count / len(recent_episodes)) if recent_episodes else 0.0
        return {"avg_return": float(avg_return), "avg_length": float(avg_length), "success_rate": float(success_rate)}

    def _expert_summary(self, env_id):
        """Run the expert for 3 episodes; return a summary dict or ``None``.

        A failure during the rollout is reported as ``{"error": ...}``.
        """
        expert = self.load_expert_for_env(env_id)
        if not expert:
            return None
        try:
            expert_rs = expert.run_episodes(episodes=3, render=False)
            return {"expert_returns": expert_rs, "expert_mean": float(np.mean(expert_rs))}
        except Exception as e:
            return {"error": str(e)}

    @staticmethod
    def _ask_llm(prompt):
        """Return the stripped LLM reply, or "" if the reply is empty/None."""
        return (call_llm(prompt) or "").strip()

    def failure_pattern(self, recent_episodes, env_id):
        """Ask the LLM for the dominant failure pattern and record it."""
        expert_summary = self._expert_summary(env_id)
        prompt = f"""
I have the following {env_id} environment episode summaries: {recent_episodes}
Expert performance (deterministic eval) if available: {expert_summary}
Please analyze the most common failure patterns, including state characteristics, action issues, and return patterns. Focus only on key points.
Return a concise paragraph.
"""
        pattern = self._ask_llm(prompt)
        self.knowledge.add_dynamic_entry({"env_id": env_id, "failure_pattern": pattern})
        return pattern

    def edit_suggestion(self, recent_episodes, env_id):
        """Ask the LLM for a single policy edit and record it."""
        expert_summary = self._expert_summary(env_id)
        prompt = f"""
Based on recent episode data from environment {env_id}: {recent_episodes}
Expert performance (deterministic eval) if available: {expert_summary}
Generate one policy editing suggestion in one of the following formats:
- add_rule(condition -> action)
- modify_threshold(variable, old_value, new_value)
- reprioritize(rule_i over rule_j)

Return exactly one line with one edit.
"""
        suggestion = self._ask_llm(prompt)
        self.knowledge.add_dynamic_entry({"env_id": env_id, "edit_suggestion": suggestion})
        return suggestion

# ----------------------------
# Helpers: action constraint text, code generation, compilation (保持你原始实现)
# ----------------------------
def _action_constraints_text(static_knowledge: dict) -> str:
    a = static_knowledge["action_space"]
    if static_knowledge.get("action_type") == "discrete":
        return f"Discrete actions; valid actions are exactly the integers in {a}."
    else:
        lo, hi = a[0], a[1]
        return f"Continuous action; return a single float within [{lo}, {hi}]. Clip if necessary."

def generate_rule_policy_code(env_id, knowledge: Knowledge):
    """Ask the LLM for the source of a rule-based ``policy(state)`` function.

    The prompt combines the static observation/action description, the
    documentation URL and the LLM-generated dynamic guidance. Returns the raw
    reply of ``call_llm``.
    """
    static = knowledge.static_knowledge
    # The guidance request is issued first; it is itself an LLM call.
    guidance = knowledge.get_dynamic_guidance(env_id) or ""
    doc_url = get_env_doc_url(env_id)
    action_desc = _action_constraints_text(static)
    observation_lines = "\n".join(
        f"- {name} in range {rng}"
        for name, rng in zip(static["state_vars"], static["state_ranges"])
    )

    prompt = f"""
You are writing a deterministic, white-box **rule-based policy** for Gymnasium environment "{env_id}".
Focus on **environment principles, physics, and dynamics**, not superficial patterns.
The policy must be based on simple if-else statements or threshold comparisons using state variables.
Environment documentation: {doc_url}

Observation (state vector):
{observation_lines}

Action constraints:
- {action_desc}
- May import 'math' if needed
- Must be deterministic
- Do not use loops, functions, or external libraries except math
- Example (discrete): if state[2] > 0: return 1 else: return 0
- Example (continuous): return max(min(k1*state[1]-k2*state[0], hi), lo)

Dynamic guidance:
{guidance}

Output requirements:
- Only one Python function: def policy(state): ...
- No explanations, no markdown, no print
- Returned action strictly satisfies constraints
"""
    return call_llm(prompt)

def compile_policy_or_default(code, sk):
    """Turn LLM-produced source into a ``policy(state)`` callable.

    A Markdown code fence around the source is removed before compiling. The
    source is also registered with ``linecache`` under a synthetic filename
    so that ``inspect.getsource`` can retrieve it later.

    If the text is missing, does not compile, fails while executing, or does
    not define a callable named ``policy``, the reason is printed and a
    constant default policy is returned: the first discrete action, or the
    midpoint of the continuous range.

    Args:
        code: Source text returned by the LLM (may be ``None``).
        sk: Static knowledge dict with ``action_type`` and ``action_space``.

    Returns:
        A callable taking the state and returning an action.
    """
    import linecache  # local import: not among this file's top-level imports

    try:
        if not isinstance(code, str):
            raise ValueError("LLM returned no code")
        # Models often wrap code in ``` fences despite being told not to;
        # exec() on the fenced text would always fail with a SyntaxError.
        fenced = re.search(r"```[A-Za-z0-9_+-]*[ \t]*\r?\n(.*?)```", code, re.DOTALL)
        source = fenced.group(1) if fenced else code
        if not source.endswith("\n"):
            source += "\n"
        filename = f"<llm_policy_{abs(hash(source))}>"
        compiled = compile(source, filename, "exec")
        # mtime=None keeps linecache from discarding this in-memory entry.
        linecache.cache[filename] = (len(source), None, source.splitlines(True), filename)
        local_vars = {"math": math, "np": np}
        # NOTE(security): this executes model-generated code in-process
        # without sandboxing; run only in an environment where that is
        # acceptable.
        exec(compiled, local_vars)
        policy_fn = local_vars.get("policy")
        if policy_fn is None:
            raise ValueError("No function 'policy' found")
        if not callable(policy_fn):
            raise ValueError("'policy' is not callable")
        return policy_fn
    except Exception as exc:
        print("[Policy compile] using default policy:", exc)
        if sk["action_type"] == "discrete":
            def policy(state): return sk["action_space"][0]
        else:
            lo, hi = sk["action_space"]
            def policy(state): return 0.5 * (lo + hi)
        return policy

def generate_base_policies(env_id, knowledge: Knowledge, n_candidates=3):
    """Generate ``n_candidates`` policies, each from its own LLM request.

    Every reply is passed through ``compile_policy_or_default``, so the
    returned list always holds callables.
    """
    static = knowledge.static_knowledge
    return [
        compile_policy_or_default(generate_rule_policy_code(env_id, knowledge), static)
        for _ in range(n_candidates)
    ]

def apply_edit(policy_fn, edit_text, knowledge: Knowledge, env_id: str):
    """Ask the LLM to revise ``policy_fn`` according to ``edit_text``.

    The current source is obtained with ``inspect.getsource``; when that
    fails, a constant default policy is presented as the current policy
    instead. The reply is compiled with ``compile_policy_or_default``.
    """
    static = knowledge.static_knowledge
    try:
        current_source = inspect.getsource(policy_fn)
    except Exception:
        # Source unavailable: describe the default constant policy instead.
        if static["action_type"] == "discrete":
            fallback_value = static["action_space"][0]
        else:
            low, high = static["action_space"]
            fallback_value = 0.5 * (low + high)
        current_source = "def policy(state):\n    return " + str(fallback_value)

    constraints = _action_constraints_text(static)
    doc_url = get_env_doc_url(env_id)

    prompt = f"""
Revise deterministic, **rule-based** policy for environment at {doc_url}.
Focus on physics, dynamics, and environment principles.
Constraints: {constraints}
Current policy:
{current_source}

Edit suggestion: {edit_text}

Rules:
- Keep it deterministic and rule-based (if-else / threshold)
- Only output a single valid Python function: def policy(state): ...
- You may use 'math'
"""
    return compile_policy_or_default(call_llm(prompt), static)

# ----------------------------
# Safe step & Safe policy call（保持原实现）
# ----------------------------
def safe_step(env, action):
    """Step ``env``, clipping continuous actions into the known range.

    The range comes from STATIC_KNOWLEDGE for the environment's spec id. A
    scalar continuous action is wrapped in a one-element float32 array.
    Discrete actions are passed through untouched.
    """
    static = STATIC_KNOWLEDGE[env.unwrapped.spec.id]
    if static["action_type"] != "continuous":
        return env.step(action)
    low, high = static["action_space"]
    if np.isscalar(action):
        clipped = np.array([np.clip(action, low, high)], dtype=np.float32)
    else:
        clipped = np.clip(np.array(action, dtype=np.float32), low, high)
    return env.step(clipped)

def safe_policy_call(state, policy_fn, sk, env):
    """Call ``policy_fn`` defensively and return an action to execute.

    If the policy raises, a random discrete action or the midpoint of the
    continuous range is used instead. Continuous actions are clipped and
    returned as ``float``.

    For discrete actions, when the unwrapped env offers ``clone_state`` and
    ``restore_state``, every action is tried for one step from a snapshot and
    the one with the highest immediate reward is returned; otherwise the
    policy's own action is returned.
    """
    try:
        proposed = policy_fn(state)
    except Exception:
        if sk["action_type"] != "discrete":
            low, high = sk["action_space"]
            proposed = 0.5 * (low + high)
        else:
            proposed = random.choice(sk["action_space"])

    if sk["action_type"] == "continuous":
        low, high = sk["action_space"]
        return float(np.clip(proposed, low, high))

    base_env = getattr(env, "unwrapped", env)
    if not (hasattr(base_env, "clone_state") and hasattr(base_env, "restore_state")):
        return proposed

    try:
        snapshot = base_env.clone_state()
    except Exception:
        return proposed

    best_action = proposed
    best_reward = -float('inf')
    for candidate in sk["action_space"]:
        try:
            _, reward, _, _, _ = safe_step(env, candidate)
            base_env.restore_state(snapshot)
            if reward > best_reward:
                best_reward, best_action = reward, candidate
        except Exception:
            # Try to leave the env at the snapshot even if the probe failed.
            try:
                base_env.restore_state(snapshot)
            except Exception:
                pass
            continue
    return best_action

# ----------------------------
# Evaluation helpers (eval_policy_once, parallel_eval_candidates)
# ----------------------------
def eval_policy_once(env_id, policy_fn, episodes=5, use_mcts=True, no_trunc_reset_for=None):
    """Run ``policy_fn`` for ``episodes`` episodes and return summary metrics.

    Args:
        env_id: Gymnasium environment id; must be in STATIC_KNOWLEDGE.
        policy_fn: Callable mapping a state to an action.
        episodes: Number of episodes to run.
        use_mcts: If true, actions go through ``safe_policy_call``; otherwise
            the policy is called directly, falling back to a random discrete
            action or the continuous midpoint when it raises.
        no_trunc_reset_for: Accepted for interface compatibility; it does not
            affect the evaluation.

    Returns:
        The dict produced by ``Reflection.metrics`` for the episodes run.
    """
    knowledge = Knowledge()
    knowledge.load_static_knowledge(env_id)
    sk = knowledge.static_knowledge
    mem = Memory()

    env = gym.make(env_id)
    try:
        for _ in range(episodes):
            s, _ = env.reset()
            done = False
            mem.start_episode()
            while not done:
                if use_mcts:
                    a = safe_policy_call(s, policy_fn, sk, env)
                else:
                    try:
                        a = policy_fn(s)
                    except Exception:
                        if sk["action_type"] == "discrete":
                            a = random.choice(sk["action_space"])
                        else:
                            lo, hi = sk["action_space"]
                            a = 0.5 * (lo + hi)
                s_next, r, terminated, truncated, _ = safe_step(env, a)
                done = terminated or truncated
                mem.add_step(s, a, r, done)
                s = s_next
            mem.add_episode_summary(env_id, policy_version=0)
    finally:
        # Close the env even when a step or the policy output raises.
        env.close()

    refl = Reflection(knowledge)
    return refl.metrics(mem.get_recent_episodes(n=episodes))

def parallel_eval_candidates(env_id, policy_fns, episodes_each=5):
    """Evaluate several policies concurrently with up to three threads.

    Args:
        env_id: Gymnasium environment id passed to ``eval_policy_once``.
        policy_fns: Policies to evaluate; may be empty.
        episodes_each: Episodes run per policy.

    Returns:
        A list of ``(policy_fn, metrics)`` pairs in completion order. A
        policy whose evaluation raised gets ``avg_return`` of -1e9. An empty
        input yields an empty list.
    """
    if not policy_fns:
        # ThreadPoolExecutor rejects max_workers=0, so short-circuit here.
        return []
    results = []
    no_trunc_reset_for = {"Acrobot-v1", "Pendulum-v1"}
    with ThreadPoolExecutor(max_workers=min(len(policy_fns), 3)) as ex:
        fut2fn = {
            ex.submit(eval_policy_once, env_id, fn, episodes_each, True, no_trunc_reset_for): fn
            for fn in policy_fns
        }
        for fut in as_completed(fut2fn):
            fn = fut2fn[fut]
            try:
                metrics = fut.result()
            except Exception as e:
                # Sentinel score keeps a failed candidate at the bottom.
                metrics = {"avg_return": -1e9, "avg_length": 0.0, "success_rate": 0.0}
                print("[Parallel Eval Exception]", e)
            results.append((fn, metrics))
    return results

# ----------------------------
# PolicyPool 类（原始实现）
# ----------------------------
class PolicyPool:
    """Fixed-capacity pool of candidate policies with UCB selection.

    ``policies``, ``counts`` and ``values`` are parallel lists: the entry,
    how often it was selected, and its running mean reward.
    """

    def __init__(self, max_size=5):
        self.policies = []
        self.max_size = max_size
        self.counts = []
        self.values = []

    def add_policy(self, policy_fn, version, metrics=None):
        """Add a policy; once the pool is full it overwrites the worst slot."""
        record = {"fn": policy_fn, "version": version, "metrics": metrics}
        if len(self.policies) >= self.max_size:
            slot = self.get_worst_policy_idx()
            self.policies[slot] = record
            self.counts[slot] = 0
            self.values[slot] = 0.0
            return
        self.policies.append(record)
        self.counts.append(0)
        self.values.append(0.0)

    def select_policy_ucb(self, c=1.0):
        """Pick a policy by UCB score and count the selection.

        Never-selected policies score infinity, so they are tried first.

        Returns:
            ``(policy_fn, index)`` of the chosen entry.

        Raises:
            RuntimeError: If the pool is empty.
        """
        if len(self.policies) == 0:
            raise RuntimeError("PolicyPool is empty")
        log_total = math.log(sum(self.counts) + 1)
        scores = [
            float('inf') if pulls == 0 else value + c * math.sqrt(log_total / pulls)
            for value, pulls in zip(self.values, self.counts)
        ]
        chosen = int(np.argmax(scores))
        self.counts[chosen] += 1
        return self.policies[chosen]["fn"], chosen

    def update_policy_value(self, idx, reward):
        """Fold ``reward`` into the running mean of entry ``idx``.

        The selection count was already incremented by ``select_policy_ucb``.
        """
        pulls = self.counts[idx]
        if pulls <= 0:
            self.values[idx] = reward
            return
        self.values[idx] = ((pulls - 1) / pulls) * self.values[idx] + (1 / pulls) * reward

    def get_worst_policy_idx(self):
        """Return the index with the lowest ``metrics["avg_return"]``.

        Entries without that metric count as minus infinity.
        """
        scores = [
            entry["metrics"]["avg_return"]
            if entry["metrics"] and "avg_return" in entry["metrics"]
            else -float('inf')
            for entry in self.policies
        ]
        return int(np.argmin(scores))

# ----------------------------
# 主闭环训练/搜索循环（run_env_loop）
# ----------------------------
def run_env_loop(env_id, max_iters=10, episodes_per_iter=10, ma_window=3,
                 success_rate_threshold=0.8, pool_size=5, n_init_candidates=3):
    """Closed-loop policy search for one environment.

    An initial pool is filled with LLM-generated policies. Each iteration
    then selects one policy by UCB, runs it, computes metrics over the last
    ``ma_window`` episodes, asks the LLM for a failure pattern and an edit,
    generates new candidates (one edited, two fresh) and overwrites the
    worst pool entry with the best of them.

    Args:
        env_id: Gymnasium environment id; must be in STATIC_KNOWLEDGE.
        max_iters: Maximum number of iterations.
        episodes_per_iter: Episodes run with the selected policy per
            iteration; candidates are evaluated on half as many (at least 2).
        ma_window: Number of most recent episodes used for the metrics.
        success_rate_threshold: Success rate required to stop early.
        pool_size: Capacity of the policy pool.
        n_init_candidates: Number of initial LLM-generated policies.

    Returns:
        None. Progress is reported with ``print``.
    """
    knowledge = Knowledge()
    knowledge.load_static_knowledge(env_id)
    memory = Memory()
    reflection = Reflection(knowledge, use_expert=True)
    policy_version = 0
    policy_pool = PolicyPool(max_size=pool_size)

    # initial candidates
    init_fns = generate_base_policies(env_id, knowledge, n_candidates=n_init_candidates)
    evaluated = parallel_eval_candidates(env_id, init_fns, episodes_each=max(2, episodes_per_iter//2))
    # Best candidates first; only the top ``pool_size`` enter the pool.
    evaluated.sort(key=lambda t: t[1]["avg_return"], reverse=True)
    for fn, m in evaluated[:pool_size]:
        policy_version += 1
        policy_pool.add_policy(fn, policy_version, metrics=m)

    for iter_idx in range(max_iters):
        print(f"=== Iteration {iter_idx+1} ===")
        # Selecting also increments the chosen entry's count in the pool.
        policy_fn, idx = policy_pool.select_policy_ucb()

        # Run selected policy
        env = gym.make(env_id)
        iteration_returns = []
        truncated_seen = False
        for ep in range(episodes_per_iter):
            s, _ = env.reset()
            done = False
            memory.start_episode()
            while not done:
                a = safe_policy_call(s, policy_fn, knowledge.static_knowledge, env)
                s_next, r, terminated, truncated, info = safe_step(env, a)
                done = terminated or truncated
                # Truncation is only noted for envs outside the exempt list.
                if truncated and env_id not in ["Acrobot-v1", "Pendulum-v1"]:
                    truncated_seen = True
                memory.add_step(s, a, r, done)
                s = s_next
            memory.add_episode_summary(env_id, policy_version=policy_pool.policies[idx]["version"])
            iteration_returns.append(memory.episodes[-1]["summary"]["return"])
        env.close()

        # Metrics cover the last ``ma_window`` episodes held in memory.
        recent_ma = memory.get_recent_episodes(n=ma_window)
        metrics = reflection.metrics(recent_ma)
        policy_pool.policies[idx]["metrics"] = metrics
        policy_pool.update_policy_value(idx, metrics["avg_return"])

        # compute expert deterministic mean (3 episodes) for display comparison
        expert_mean = None
        expert_wrapper = reflection.load_expert_for_env(env_id)
        if expert_wrapper is not None:
            try:
                ers = expert_wrapper.run_episodes(episodes=3, render=False)
                expert_mean = float(np.mean(ers))
            except Exception as e:
                expert_mean = None

        # print including expert mean
        if expert_mean is None:
            print(f"[Selected idx {idx} v{policy_pool.policies[idx]['version']}] MA Return={metrics['avg_return']:.2f}  SR={metrics['success_rate']:.2f}  ExpertMean=N/A")
        else:
            print(f"[Selected idx {idx} v{policy_pool.policies[idx]['version']}] MA Return={metrics['avg_return']:.2f}  SR={metrics['success_rate']:.2f}  ExpertMean={expert_mean:.2f}")

        # Both calls query the LLM and append their result to ``knowledge``.
        pattern = reflection.failure_pattern(recent_ma, env_id)
        print("Failure Pattern:", pattern)
        edit = reflection.edit_suggestion(recent_ma, env_id)
        print("Edit Suggestion:", edit)

        # Replace worst
        worst_idx = policy_pool.get_worst_policy_idx()
        seed_fn = policy_fn
        new_candidates = []
        try:
            edited_fn = apply_edit(seed_fn, edit, knowledge, env_id)
            new_candidates.append(edited_fn)
        except Exception:
            pass
        fresh = generate_base_policies(env_id, knowledge, n_candidates=2)
        new_candidates.extend(fresh)

        cand_eval = parallel_eval_candidates(env_id, new_candidates, episodes_each=max(2, episodes_per_iter//2))
        cand_eval.sort(key=lambda t: t[1]["avg_return"], reverse=True)
        best_new_fn, best_new_metrics = cand_eval[0]
        policy_version += 1
        # The worst slot is overwritten unconditionally, without comparing
        # the new candidate's score to the entry it replaces.
        policy_pool.policies[worst_idx] = {"fn": best_new_fn, "version": policy_version, "metrics": best_new_metrics}
        policy_pool.counts[worst_idx] = 0
        policy_pool.values[worst_idx] = 0.0
        print(f"Replaced worst idx {worst_idx} with new v{policy_version}: avg_return={best_new_metrics['avg_return']:.2f}")

        # Convergence is checked only once enough episodes have accumulated;
        # ``metrics`` still refers to the policy run in this iteration.
        if len(memory.episodes) >= ma_window * episodes_per_iter:
            thr = knowledge.static_knowledge.get("reward_threshold", 0.0)
            if metrics["avg_return"] >= thr and metrics["success_rate"] >= success_rate_threshold:
                print(f"Converged! MA Return={metrics['avg_return']:.2f}, SR={metrics['success_rate']:.2f}")
                break

        if truncated_seen and env_id not in ["Acrobot-v1", "Pendulum-v1"]:
            print("[Info] Truncated observed (non-exempt env) — handled via pool search (no hard reset).")

# ----------------------------
# 脚本入口 — 逐个运行任务
# ----------------------------
if __name__ == "__main__":
    # Environments are processed one after another, in this order.
    env_list = [
        "Acrobot-v1",
        "CartPole-v1",
        "MountainCarContinuous-v0",
        "MountainCar-v0",
        "Pendulum-v1",
    ]

    # Report which expert checkpoint (if any) was found for each environment.
    print("Expert files found:")
    for env_name in env_list:
        expert_path = find_expert_file_for_env(env_name)
        print(f"  {env_name}: {expert_path}")
    print("Starting closed-loop runs (this may take long).")

    run_settings = dict(
        max_iters=20,
        episodes_per_iter=20,
        ma_window=3,
        success_rate_threshold=0.8,
        pool_size=5,
        n_init_candidates=3,
    )
    for env_id in env_list:
        print(f"\n==== Running {env_id} ====")
        run_env_loop(env_id, **run_settings)

Expert files found:
  Acrobot-v1: experts/Acrobot-v1_best_expert.cleanrl_model
  CartPole-v1: experts/CartPole-v1_best_expert.cleanrl_model
  MountainCarContinuous-v0: experts/MountainCarContinuous-v0_best_expert.cleanrl_model
  MountainCar-v0: experts/MountainCar-v0_best_expert.cleanrl_model
  Pendulum-v1: experts/Pendulum-v1_best_expert.cleanrl_model
Starting closed-loop runs (this may take long).

==== Running Acrobot-v1 ====
=== Iteration 1 ===
[Info] loading expert from experts/Acrobot-v1_best_expert.cleanrl_model
[Info] Model keys: ['network.0.weight', 'network.0.bias', 'network.2.weight', 'network.2.bias', 'network.4.weight', 'network.4.bias']
[Info] Detected checkpoint type: DQN
[Info] Expert loaded (dqn)
[Selected idx 0 v1] MA Return=-500.00  SR=0.00  ExpertMean=-78.00


  state = torch.load(self.model_path, map_location="cpu")


Failure Pattern: The Acrobot-v1 episode summaries indicate consistent failure, with all episodes reaching the maximum length of 500 steps and receiving the minimum possible return of -500.0. This pattern suggests the agent is unable to achieve the task objective (raising the end-effector above a target height) and is likely stuck in unproductive state trajectories, possibly oscillating or remaining in low-energy configurations. The agent’s actions may be ineffective, such as repeating the same or random actions without exploiting the environment’s dynamics. Compared to expert performance (mean return -62.0), the agent’s returns are drastically worse, highlighting a lack of learning or exploration. Overall, the key failure patterns are persistent inability to escape poor states, ineffective action selection, and no progress toward the goal.
Edit Suggestion: add_rule(if steps_without_progress > 100 -> increase exploration rate)
Replaced worst idx 0 with new v4: avg_return=-500.00
=== Ite