In [None]:
"""
TD3 for Surrogate-Driven Parameter Optimization (PLAL)
-----------------------------------------------------
- Environment: Random Forest model that predicts DLS from [Time, ScanSpeed, Fluence].
- Agent: TD3 (Twin Delayed DDPG) with:
    * Twin critics (reduces Q overestimation bias)
    * Target policy smoothing (robustness)
    * Delayed policy/target updates (stability)
- Goal: Given a target DLS (e.g., 150 nm), learn to propose parameters that hit it.
- Action/State bounds:
    Time(min):       [ 2.00, 25.00 ]
    ScanSpeed(mm/s): [ 3000, 3500 ]
    Fluence(J/cm²):  [ 1.83, 1.91 ]
- DLS valid range (for reward scaling): [83, 203] nm
"""

import os
import gym
import numpy as np
from gym import spaces
import joblib
from collections import deque
import random
from dataclasses import dataclass
from typing import Tuple, Dict, Any, List

import torch
import torch.nn as nn
import torch.optim as optim


# -------------------------------
# Reproducibility helpers
# -------------------------------
def set_seed(seed: int):
    """Seed every random number generator this script draws from.

    Covers Python's ``random`` module, NumPy's global generator and PyTorch;
    when CUDA is available, all CUDA devices are seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


# -------------------------------
# Constants (bounds + scaling)
# -------------------------------
# Per-dimension bounds, ordered [Time (min), ScanSpeed (mm/s), Fluence (J/cm²)].
ACTION_LOW = np.asarray((2.0, 3000.0, 1.83), dtype=np.float32)
ACTION_HIGH = np.asarray((25.0, 3500.0, 1.91), dtype=np.float32)

# DLS range (nm) used to normalize the reward.
DLS_MIN = 83.0
DLS_MAX = 203.0
# An episode counts as a success once |pred - target| < DLS_TOL (nm).
DLS_TOL = 1.0
# Environment horizon: maximum number of steps per episode.
EPISODE_STEPS = 10


# -------------------------------
# Gym Environment: RF-as-Env
# -------------------------------
class DLS_Environment(gym.Env):
    """
    A minimal Gym-compatible environment that wraps a trained regressor loaded from disk.

    State  = current parameter vector [Time, ScanSpeed, Fluence]
    Action = next proposed parameter vector (same shape, clipped to the bounds)
    Reward = - |predicted DLS - target DLS| / (DLS_MAX - DLS_MIN)
    Done   = success (error < DLS_TOL) or the step budget is used up

    Note that the target DLS is NOT part of the observation: a policy trained in
    this environment is tied to the target it was trained with.
    """

    metadata = {"render.modes": []}

    def __init__(self, model_path: str = "random_forest_dls_model.pkl"):
        """Load the surrogate model from ``model_path`` and set up the spaces."""
        super().__init__()
        # Load your trained Random Forest model.
        # NOTE(review): joblib.load unpickles the file -- only load model files you trust.
        self.model = joblib.load(model_path)

        # If you trained RF on scaled features, load/apply the SAME scaler(s) here:
        # Example:
        # self.scaler_X = joblib.load("feature_scaler.pkl")  # optional
        # self.scaler_y = joblib.load("target_scaler.pkl")   # optional

        self.action_space = spaces.Box(low=ACTION_LOW, high=ACTION_HIGH, dtype=np.float32)
        self.observation_space = spaces.Box(low=ACTION_LOW, high=ACTION_HIGH, dtype=np.float32)

        self.target_dls = 150.0
        self.current_step = 0
        self.max_steps = EPISODE_STEPS
        self.state = None

    def _predict_dls(self, x: np.ndarray) -> float:
        """Predict DLS for one parameter vector; apply scaler if RF was trained with one."""
        X = np.asarray(x, dtype=np.float32).reshape(1, -1)  # single-row batch
        # If you used a scaler during RF training, uncomment:
        # X = self.scaler_X.transform(X)
        pred = float(self.model.predict(X)[0])
        # If target was scaled:
        # pred = float(self.scaler_y.inverse_transform([[pred]])[0,0])
        return pred

    def reset(self, target_dls: float = None) -> np.ndarray:
        """
        Start a new episode from a uniformly random state inside the bounds.

        Args:
            target_dls: new target DLS; when None the previous target is kept.

        Returns:
            A float32 copy of the initial state.
        """
        self.current_step = 0
        if target_dls is not None:
            self.target_dls = float(target_dls)

        # Draw from NumPy's global generator so that set_seed() makes the start
        # states reproducible; Box.sample() uses the space's own generator, which
        # np.random.seed() does not touch.
        low, high = self.action_space.low, self.action_space.high
        start = np.random.uniform(low, high)
        self.state = np.clip(start, low, high).astype(np.float32)
        return self.state.astype(np.float32)

    def step(self, action: np.ndarray):
        """
        Apply action (next proposed parameters), get DLS prediction and reward.

        Args:
            action: parameter vector with the same shape as the action space;
                values outside the bounds are clipped.

        Returns:
            (state, reward, done, info) where info["Predicted_DLS"] holds the
            surrogate prediction for the clipped action.

        Raises:
            ValueError: if the action has the wrong shape or contains NaN/inf.
        """
        action = np.asarray(action, dtype=np.float32)
        if action.shape != self.action_space.shape:
            raise ValueError(
                f"action must have shape {self.action_space.shape}, got {action.shape}"
            )
        if not np.all(np.isfinite(action)):
            raise ValueError(f"action must be finite, got {action}")

        self.current_step += 1
        # Keep within physical bounds
        self.state = np.clip(action, self.action_space.low, self.action_space.high).astype(np.float32)

        # Predict DLS with surrogate model
        predicted_dls = self._predict_dls(self.state)
        abs_error = abs(predicted_dls - self.target_dls)

        # Dense reward: negative normalized absolute error
        reward = -abs_error / (DLS_MAX - DLS_MIN)

        # Episode termination: success or horizon reached
        done = (abs_error < DLS_TOL) or (self.current_step >= self.max_steps)

        info = {"Predicted_DLS": predicted_dls}
        return self.state, float(reward), bool(done), info

    def render(self, mode="human"):
        """No-op: this environment has nothing to render."""
        pass


# -------------------------------
# Replay Buffer
# -------------------------------
class ReplayBuffer:
    """Bounded FIFO store of (state, action, reward, next_state, done) transitions."""

    def __init__(self, capacity: int = 100000):
        # A deque with maxlen drops the oldest transition once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s2, d):
        """Append one transition to the buffer."""
        transition = (s, a, r, s2, d)
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        """
        Draw ``batch_size`` transitions at random, without replacement.

        Returns five arrays (states, actions, rewards, next_states, dones);
        rewards and dones are reshaped to column vectors of shape (batch, 1).
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = (
            np.array(column) for column in zip(*picked)
        )
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)


# -------------------------------
# Neural Networks
# -------------------------------
class Actor(nn.Module):
    """
    Deterministic policy network: state -> action.

    The last layer is squashed with tanh() and then mapped affinely onto
    [action_low, action_high], so every output is a bounded, physical action.
    """
    def __init__(self, state_dim: int, action_low: np.ndarray, action_high: np.ndarray, hidden: int = 256):
        super().__init__()
        low_t = torch.tensor(action_low, dtype=torch.float32)
        high_t = torch.tensor(action_high, dtype=torch.float32)
        # Plain attributes and registered buffers hold the same bounds; forward()
        # reads only the buffers (a_low / a_high), which are part of the state_dict.
        self.action_low = low_t
        self.action_high = high_t
        self.register_buffer("a_low", low_t)
        self.register_buffer("a_high", high_t)

        action_dim = len(action_low)
        layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        """Return the action for state(s) ``s``, scaled into [a_low, a_high]."""
        squashed = torch.tanh(self.net(s))       # (-1, 1)
        unit = 0.5 * (squashed + 1.0)            # (0, 1)
        span = self.a_high - self.a_low
        return unit * span + self.a_low


class Critic(nn.Module):
    """
    Q-network: (state, action) -> scalar Q-value.

    State and action are concatenated along the last dimension and passed
    through a two-hidden-layer MLP with a single output unit.
    """
    def __init__(self, state_dim: int, action_dim: int, hidden: int = 256):
        super().__init__()
        input_dim = state_dim + action_dim
        layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        ]
        self.q = nn.Sequential(*layers)

    def forward(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Return Q(s, a) with a trailing dimension of size 1."""
        return self.q(torch.cat((s, a), dim=-1))


# -------------------------------
# TD3 Agent
# -------------------------------
@dataclass
class TD3Config:
    """
    Hyperparameters for one TD3 training run.

    All noise settings are fractions of the per-dimension action range
    (high - low), so they adapt to dimensions with very different scales.

    Raises:
        ValueError: on construction, if a field is outside its usable range
            (see ``__post_init__``).
    """
    actor_lr: float = 1e-4            # Adam learning rate for the actor
    critic_lr: float = 1e-3           # Adam learning rate for both critics
    gamma: float = 0.99               # discount factor
    tau: float = 0.005                # soft target-update rate
    batch_size: int = 64              # replay batch size per gradient step
    policy_noise_frac: float = 0.10   # fraction of (high-low) for target smoothing
    noise_clip_frac: float = 0.20     # clip target noise to this fraction
    explore_noise_frac: float = 0.05  # exploration noise fraction of (high-low)
    policy_delay: int = 2             # delayed actor/target updates
    max_episodes: int = 400           # upper bound on training episodes per run
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: int = 0

    def __post_init__(self):
        """Reject settings that would otherwise fail (or misbehave) deep inside training."""
        if self.actor_lr <= 0 or self.critic_lr <= 0:
            raise ValueError("actor_lr and critic_lr must be > 0")
        if not 0.0 <= self.gamma <= 1.0:
            raise ValueError(f"gamma must be in [0, 1], got {self.gamma}")
        if not 0.0 < self.tau <= 1.0:
            raise ValueError(f"tau must be in (0, 1], got {self.tau}")
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if min(self.policy_noise_frac, self.noise_clip_frac, self.explore_noise_frac) < 0:
            raise ValueError("noise fractions must be >= 0")
        # The update schedule is `step % policy_delay`, so 0 would divide by zero.
        if self.policy_delay < 1:
            raise ValueError(f"policy_delay must be >= 1, got {self.policy_delay}")
        if self.max_episodes < 0:
            raise ValueError(f"max_episodes must be >= 0, got {self.max_episodes}")


class TD3Agent:
    """
    TD3 agent: one actor, twin critics, a target copy of each, and a replay buffer.

    States and actions are rescaled to [-1, 1] before they enter any network.
    The raw parameter vector mixes very different scales (ScanSpeed ~3000 next
    to Fluence ~1.8); feeding it unscaled drives the actor's tanh() into
    saturation, which pins the policy to a corner of the action box and kills
    its gradient.

    Consequences of the rescaling:
      * ``select_action`` and the replay buffer still work in raw (physical)
        units -- callers do not need to change anything.
      * ``self.actor`` / ``self.critic*`` expect *scaled* inputs; go through
        ``select_action`` rather than calling the networks directly.
    """

    def __init__(self, state_dim: int, action_low: np.ndarray, action_high: np.ndarray, cfg: TD3Config,
                 state_low: np.ndarray = None, state_high: np.ndarray = None):
        """
        Args:
            state_dim: dimensionality of the observation vector.
            action_low / action_high: per-dimension action bounds.
            cfg: hyperparameters.
            state_low / state_high: optional per-dimension observation bounds used
                for input scaling. When omitted and ``state_dim`` equals the action
                dimension, the action bounds are reused (in this file the
                observation is the parameter vector itself). Otherwise states are
                passed through unscaled.
        """
        self.cfg = cfg
        self.device = torch.device(cfg.device)

        self.action_low  = torch.tensor(action_low,  dtype=torch.float32, device=self.device)
        self.action_high = torch.tensor(action_high, dtype=torch.float32, device=self.device)
        self.act_range   = (self.action_high - self.action_low)

        if state_low is None or state_high is None:
            if state_dim == len(action_low):
                state_low, state_high = action_low, action_high
            else:
                # Bounds of [-1, 1] turn the scaling below into the identity map.
                state_low, state_high = [-1.0] * state_dim, [1.0] * state_dim
        self.state_low  = torch.tensor(state_low,  dtype=torch.float32, device=self.device)
        self.state_high = torch.tensor(state_high, dtype=torch.float32, device=self.device)
        self._state_span = self._nonzero_span(self.state_high - self.state_low)
        self._act_span   = self._nonzero_span(self.act_range)

        # Networks
        action_dim = len(action_low)
        self.actor        = Actor(state_dim, action_low, action_high).to(self.device)
        self.actor_target = Actor(state_dim, action_low, action_high).to(self.device)
        self.critic1      = Critic(state_dim, action_dim).to(self.device)
        self.critic2      = Critic(state_dim, action_dim).to(self.device)
        self.critic1_t    = Critic(state_dim, action_dim).to(self.device)
        self.critic2_t    = Critic(state_dim, action_dim).to(self.device)

        # Target init: targets start as exact copies of the online networks
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic1_t.load_state_dict(self.critic1.state_dict())
        self.critic2_t.load_state_dict(self.critic2.state_dict())

        # Optims
        self.actor_opt   = optim.Adam(self.actor.parameters(),   lr=cfg.actor_lr)
        self.critic1_opt = optim.Adam(self.critic1.parameters(), lr=cfg.critic_lr)
        self.critic2_opt = optim.Adam(self.critic2.parameters(), lr=cfg.critic_lr)

        self.replay = ReplayBuffer(capacity=100_000)
        self.total_it = 0  # gradient step counter

    @staticmethod
    def _nonzero_span(span: torch.Tensor) -> torch.Tensor:
        """Replace zero-width ranges by 1 so that scaling never divides by zero."""
        return torch.where(span > 0, span, torch.ones_like(span))

    def _scale_state(self, s: torch.Tensor) -> torch.Tensor:
        """Map raw states from [state_low, state_high] onto [-1, 1]."""
        return 2.0 * (s - self.state_low) / self._state_span - 1.0

    def _scale_action(self, a: torch.Tensor) -> torch.Tensor:
        """Map raw actions from [action_low, action_high] onto [-1, 1] (differentiable)."""
        return 2.0 * (a - self.action_low) / self._act_span - 1.0

    def _clip_action(self, a: torch.Tensor) -> torch.Tensor:
        """Clamp raw actions to the action bounds."""
        return torch.max(torch.min(a, self.action_high), self.action_low)

    def _soft_update(self, target: nn.Module, source: nn.Module) -> None:
        """Polyak averaging: target <- (1 - tau) * target + tau * source."""
        tau = self.cfg.tau
        for targ, src in zip(target.parameters(), source.parameters()):
            targ.data.mul_(1 - tau).add_(tau * src.data)

    @torch.no_grad()
    def select_action(self, state: np.ndarray, explore: bool = True) -> np.ndarray:
        """
        Deterministic policy + optional exploration noise (bounded).

        Args:
            state: raw (unscaled) observation vector.
            explore: when True, add Gaussian noise whose std is
                ``explore_noise_frac`` times the per-dimension action range.

        Returns:
            float32 action in physical units, clipped to the action bounds.
        """
        s = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
        a = self.actor(self._scale_state(s)).squeeze(0)  # already scaled to [low, high]

        if explore:
            # randn * std broadcasts cleanly; torch.normal with mean/std tensors of
            # different shapes is not something to rely on.
            a = a + torch.randn_like(a) * (self.cfg.explore_noise_frac * self.act_range)

        # Clip to bounds (safety)
        a = self._clip_action(a)
        return a.detach().cpu().numpy().astype(np.float32)

    def train_step(self):
        """
        Run one TD3 gradient step; does nothing until the replay buffer holds a full batch.

        Critics are updated every call; the actor and all target networks are
        updated every ``policy_delay`` calls.
        """
        if len(self.replay) < self.cfg.batch_size:
            return

        self.total_it += 1

        # Sample batch (raw units) and move it to the training device
        s, a, r, s2, d = (
            torch.as_tensor(x, dtype=torch.float32, device=self.device)
            for x in self.replay.sample(self.cfg.batch_size)
        )
        s_n, s2_n = self._scale_state(s), self._scale_state(s2)
        a_n = self._scale_action(a)

        with torch.no_grad():
            # Target policy smoothing: actor_target(s2) + clipped noise
            noise_std  = self.cfg.policy_noise_frac * self.act_range
            noise_clip = self.cfg.noise_clip_frac  * self.act_range

            noise = torch.randn_like(a) * noise_std
            noise = torch.max(torch.min(noise, noise_clip), -noise_clip)

            target_a = self._clip_action(self.actor_target(s2_n) + noise)
            target_a_n = self._scale_action(target_a)

            # Twin target critics and min for clipped double-Q
            q1_t = self.critic1_t(s2_n, target_a_n)
            q2_t = self.critic2_t(s2_n, target_a_n)
            q_t_min = torch.minimum(q1_t, q2_t)

            y = r + (1.0 - d) * self.cfg.gamma * q_t_min

        # Critic updates (both critics regress onto the same target y)
        critic1_loss = nn.functional.mse_loss(self.critic1(s_n, a_n), y)
        critic2_loss = nn.functional.mse_loss(self.critic2(s_n, a_n), y)

        self.critic1_opt.zero_grad()
        critic1_loss.backward()
        self.critic1_opt.step()

        self.critic2_opt.zero_grad()
        critic2_loss.backward()
        self.critic2_opt.step()

        # Delayed policy (actor) and target updates
        if self.total_it % self.cfg.policy_delay == 0:
            # Actor aims to maximize Q1(s, actor(s)) => minimize negative Q.
            # The gradient flows through _scale_action (an affine map) into the actor.
            actor_actions = self.actor(s_n)
            actor_loss = -self.critic1(s_n, self._scale_action(actor_actions)).mean()
            self.actor_opt.zero_grad()
            actor_loss.backward()
            self.actor_opt.step()

            # Soft target updates
            with torch.no_grad():
                self._soft_update(self.actor_target, self.actor)
                self._soft_update(self.critic1_t, self.critic1)
                self._soft_update(self.critic2_t, self.critic2)


# -------------------------------
# Training utilities
# -------------------------------
def run_td3_training(env: DLS_Environment, agent: TD3Agent, episodes: int = 400, target_dls: float = 150.0,
                     print_every: int = 10) -> Dict[str, Any]:
    """
    Train TD3 for a number of episodes against a single, fixed target DLS.

    Every environment step is pushed to the agent's replay buffer and followed
    by one ``agent.train_step()`` call.

    Args:
        env: environment to interact with.
        agent: agent to train (modified in place).
        episodes: number of episodes to run.
        target_dls: target passed to ``env.reset`` at the start of every episode.
        print_every: print a progress line every this many episodes;
            0, None or a negative value disables printing.

    Returns:
        Dict with two lists, one entry per episode:
          "episode_rewards": sum of rewards collected in the episode
          "final_dls":       predicted DLS at the episode's last step (NaN if unavailable)
    """
    logs = {"episode_rewards": [], "final_dls": []}
    should_print = bool(print_every) and print_every > 0  # avoids `ep % 0`

    for ep in range(1, episodes + 1):
        state = env.reset(target_dls=target_dls)
        ep_reward = 0.0
        last_info = {}
        for _ in range(env.max_steps):
            action = agent.select_action(state, explore=True)
            next_state, reward, done, info = env.step(action)
            agent.replay.push(state, action, reward, next_state, float(done))
            agent.train_step()
            state = next_state
            ep_reward += reward
            last_info = info
            if done:
                break

        final_dls = last_info.get("Predicted_DLS", np.nan)
        logs["episode_rewards"].append(ep_reward)
        logs["final_dls"].append(final_dls)

        if should_print and (ep % print_every) == 0:
            print(f"[TD3] Episode {ep:4d} | Reward: {ep_reward: .4f} | Final DLS: {final_dls: .2f} nm")
    return logs


def evaluate_agent(env: DLS_Environment, agent: TD3Agent, target_dls: float, trials: int = 20) -> Tuple[np.ndarray, float, float]:
    """
    Noise-free evaluation: query the policy from several random start states and
    keep the action whose predicted DLS is closest to the target.

    Args:
        env: environment used to draw start states and score actions.
        agent: trained agent; queried with ``explore=False``.
        target_dls: desired DLS in nm.
        trials: number of random restarts (must be >= 1).

    Returns:
        (best_action, best_pred_dls, best_error_nm).

    Raises:
        ValueError: if ``trials`` is smaller than 1 (there would be nothing to return).
    """
    if trials < 1:
        raise ValueError(f"trials must be >= 1, got {trials}")

    best_action, best_pred, best_err = None, float("nan"), float("inf")
    for _ in range(trials):
        state = env.reset(target_dls=target_dls)
        action = agent.select_action(state, explore=False)  # no noise for inference
        _, _, _, info = env.step(action)
        pred = info["Predicted_DLS"]
        err = abs(pred - target_dls)
        # Always accept the first trial, and let any finite error replace a NaN
        # one, so a result is returned even when some predictions are NaN.
        if best_action is None or err < best_err or (np.isnan(best_err) and not np.isnan(err)):
            best_err = err
            best_pred = pred
            best_action = action.copy()
    return best_action, float(best_pred), float(best_err)


# -------------------------------
# Simple hyperparameter tuning
# -------------------------------
def small_hparam_sweep(env_path: str = "random_forest_dls_model.pkl",
                       seeds: List[int] = None,
                       configs: List[TD3Config] = None,
                       episodes: int = 300,
                       target_dls: float = 150.0) -> Tuple[TD3Agent, TD3Config, Dict[str, Any]]:
    """
    Very small TD3 sweep: trains one agent per (seed, config) pair and keeps the
    one with the best average episode reward over its last 50 episodes.
    Extend the config list if you have more compute time.

    Args:
        env_path: path of the surrogate model file handed to DLS_Environment.
        seeds: seeds to try; defaults to [0, 1].
        configs: configs to try; defaults to three hand-picked settings.
        episodes: episodes per run (capped by each config's ``max_episodes``).
        target_dls: target DLS every run is trained for.

    Returns:
        (best_agent, best_cfg, best_logs); all three are None when ``seeds`` or
        ``configs`` is empty.
    """
    from dataclasses import replace  # local import: only this function needs it

    if seeds is None:
        seeds = [0, 1]
    if configs is None:
        configs = [
            TD3Config(actor_lr=1e-4, critic_lr=1e-3, explore_noise_frac=0.05, policy_noise_frac=0.10, noise_clip_frac=0.20, policy_delay=2),
            TD3Config(actor_lr=3e-4, critic_lr=3e-4, explore_noise_frac=0.10, policy_noise_frac=0.10, noise_clip_frac=0.20, policy_delay=2),
            TD3Config(actor_lr=1e-4, critic_lr=1e-3, explore_noise_frac=0.08, policy_noise_frac=0.15, noise_clip_frac=0.25, policy_delay=2),
        ]

    best_score = float("-inf")
    best_agent, best_cfg, best_logs = None, None, None

    for seed in seeds:
        for base_cfg in configs:
            # Copy the config with this run's seed; the caller's object stays untouched.
            cfg = replace(base_cfg, seed=seed)
            print(f"\n=== Trying config: {cfg} ===")
            set_seed(cfg.seed)

            env = DLS_Environment(model_path=env_path)
            agent = TD3Agent(state_dim=3, action_low=ACTION_LOW, action_high=ACTION_HIGH, cfg=cfg)
            logs = run_td3_training(env, agent, episodes=min(episodes, cfg.max_episodes), target_dls=target_dls, print_every=25)

            # Slicing already copes with fewer than 50 episodes.
            recent = logs["episode_rewards"][-50:]
            avg_last50 = float(np.nanmean(recent)) if recent else float("nan")
            print(f"Avg reward (last 50): {avg_last50:.4f}")

            # Keep the first run unconditionally, and let any real score replace
            # a NaN one (comparisons against NaN are always False).
            if best_agent is None or np.isnan(best_score) or avg_last50 > best_score:
                best_score, best_agent, best_cfg, best_logs = avg_last50, agent, cfg, logs

    print("\n>>> Best config selected:", best_cfg)
    return best_agent, best_cfg, best_logs


# -------------------------------
# Main: train + interactive query
# -------------------------------
if __name__ == "__main__":
    MODEL_PATH = "random_forest_dls_model.pkl"
    DEFAULT_TARGET = 150.0
    TRAIN_EPISODES = 300        # increase if you want stronger policies

    # 1) (Optional) quick hyperparameter sweep to pick a good TD3 config
    best_agent, best_cfg, _ = small_hparam_sweep(
        env_path=MODEL_PATH,
        seeds=[0, 1],
        episodes=TRAIN_EPISODES,
        target_dls=DEFAULT_TARGET
    )

    # The observation is only [Time, ScanSpeed, Fluence]: the target never reaches
    # the policy, so an agent is only meaningful for the target it was trained on.
    # Keep one trained agent per requested target and train new ones on demand.
    agents_by_target = {DEFAULT_TARGET: best_agent}

    # Load the surrogate model once instead of once per query.
    env = DLS_Environment(model_path=MODEL_PATH)

    # 2) Let the user query any target DLS and get suggested parameters
    print("\nYou can now query the trained TD3 policy with any desired DLS in [83, 203] nm.")
    print("Enter e.g. 150 (or press Enter to use 150 by default). Ctrl+C to exit.")
    print(f"Targets other than {DEFAULT_TARGET:g} nm trigger a training run for that target first.")

    while True:
        try:
            raw = input("\nDesired DLS (nm) [83..203]: ").strip()

            # Only a failed number parse is an *input* error; anything else
            # (model or training failures) is allowed to surface as itself.
            try:
                target = DEFAULT_TARGET if raw == "" else float(raw)
            except ValueError as e:
                print(f"Input error ({e}). Please enter a number between {DLS_MIN} and {DLS_MAX}.")
                continue
            if not np.isfinite(target):
                print(f"Input error ({raw!r} is not a finite number). Please enter a number between {DLS_MIN} and {DLS_MAX}.")
                continue
            target = float(np.clip(target, DLS_MIN, DLS_MAX))

            agent = agents_by_target.get(target)
            if agent is None:
                print(f"\nNo policy for {target:.2f} nm yet; training one with the best config...")
                set_seed(best_cfg.seed)
                agent = TD3Agent(state_dim=3, action_low=ACTION_LOW, action_high=ACTION_HIGH, cfg=best_cfg)
                run_td3_training(env, agent, episodes=min(TRAIN_EPISODES, best_cfg.max_episodes),
                                 target_dls=target, print_every=25)
                agents_by_target[target] = agent

            action, pred_dls, err = evaluate_agent(env, agent, target_dls=target, trials=30)

            # `:g` keeps float32 bounds readable (1.83 instead of 1.8300000429153442).
            print("\nSuggested parameters (TD3 policy):")
            print(f"  Time (min):        {action[0]:.2f}  (bounds: {ACTION_LOW[0]:g}..{ACTION_HIGH[0]:g})")
            print(f"  Scanspeed (mm/s):  {action[1]:.2f}  (bounds: {ACTION_LOW[1]:g}..{ACTION_HIGH[1]:g})")
            print(f"  Fluence (J/cm²):   {action[2]:.3f}  (bounds: {ACTION_LOW[2]:g}..{ACTION_HIGH[2]:g})")
            print(f"Predicted DLS:       {pred_dls:.2f} nm | Target: {target:.2f} nm | Error: {err:.2f} nm")

        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin closed (e.g. non-interactive run); without this the
            # loop would spin forever re-reading an exhausted input stream.
            print("\nExiting. Goodbye!")
            break




=== Trying config: TD3Config(actor_lr=0.0001, critic_lr=0.001, gamma=0.99, tau=0.005, batch_size=64, policy_noise_frac=0.1, noise_clip_frac=0.2, explore_noise_frac=0.05, policy_delay=2, max_episodes=400, device='cpu', seed=0) ===
[TD3] Episode   25 | Reward: -1.2712 | Final DLS:  134.90 nm
[TD3] Episode   50 | Reward: -1.2583 | Final DLS:  134.90 nm
[TD3] Episode   75 | Reward: -1.2583 | Final DLS:  134.90 nm
[TD3] Episode  100 | Reward: -1.2893 | Final DLS:  134.28 nm
[TD3] Episode  125 | Reward: -1.2634 | Final DLS:  134.90 nm
[TD3] Episode  150 | Reward: -1.2660 | Final DLS:  134.90 nm
[TD3] Episode  175 | Reward: -1.2583 | Final DLS:  134.90 nm
[TD3] Episode  200 | Reward: -1.2634 | Final DLS:  134.90 nm
[TD3] Episode  225 | Reward: -1.2583 | Final DLS:  134.90 nm
[TD3] Episode  250 | Reward: -1.2686 | Final DLS:  134.90 nm
[TD3] Episode  275 | Reward: -1.2634 | Final DLS:  134.90 nm
[TD3] Episode  300 | Reward: -1.2583 | Final DLS:  134.90 nm
Avg reward (last 50): -1.2643

=== Tr


Desired DLS (nm) [83..203]:  150



Suggested parameters (TD3 policy):
  Time (min):        25.00  (bounds: 2.0..25.0)
  Scanspeed (mm/s):  3000.00  (bounds: 3000.0..3500.0)
  Fluence (J/cm²):   1.830  (bounds: 1.8300000429153442..1.909999966621399)
Predicted DLS:       134.90 nm | Target: 150.00 nm | Error: 15.10 nm



Desired DLS (nm) [83..203]:  95



Suggested parameters (TD3 policy):
  Time (min):        25.00  (bounds: 2.0..25.0)
  Scanspeed (mm/s):  3000.00  (bounds: 3000.0..3500.0)
  Fluence (J/cm²):   1.830  (bounds: 1.8300000429153442..1.909999966621399)
Predicted DLS:       134.90 nm | Target: 95.00 nm | Error: 39.90 nm



Desired DLS (nm) [83..203]:  145



Suggested parameters (TD3 policy):
  Time (min):        25.00  (bounds: 2.0..25.0)
  Scanspeed (mm/s):  3000.00  (bounds: 3000.0..3500.0)
  Fluence (J/cm²):   1.830  (bounds: 1.8300000429153442..1.909999966621399)
Predicted DLS:       134.90 nm | Target: 145.00 nm | Error: 10.10 nm



Desired DLS (nm) [83..203]:  190



Suggested parameters (TD3 policy):
  Time (min):        25.00  (bounds: 2.0..25.0)
  Scanspeed (mm/s):  3000.00  (bounds: 3000.0..3500.0)
  Fluence (J/cm²):   1.830  (bounds: 1.8300000429153442..1.909999966621399)
Predicted DLS:       134.90 nm | Target: 190.00 nm | Error: 55.10 nm



Desired DLS (nm) [83..203]:  150



Suggested parameters (TD3 policy):
  Time (min):        25.00  (bounds: 2.0..25.0)
  Scanspeed (mm/s):  3000.00  (bounds: 3000.0..3500.0)
  Fluence (J/cm²):   1.830  (bounds: 1.8300000429153442..1.909999966621399)
Predicted DLS:       134.90 nm | Target: 150.00 nm | Error: 15.10 nm
