# In[ ]:
import os

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, plot_results, ts2xy
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

# --- Experiment configuration ---
# Gym environment IDs; a separate DQN agent is trained on each, one after another.
# NOTE(review): LunarLander needs the Box2D extra, and newer gymnasium releases
# deprecate "LunarLander-v2" in favour of "LunarLander-v3" — confirm the installed version.
ENV_IDS = [
    "CartPole-v1",
    "LunarLander-v2",
]
NUM_ENVS_PER_PROBLEM = 4             # parallel training envs per environment ID
TOTAL_TIMESTEPS_PER_PROBLEM = 50000  # training budget for each environment ID
LOG_DIR = "dqn_multi_env_logs"       # root directory for this script's logs and saved models
EVAL_FREQ = 1_000                    # intended interval (timesteps) between evaluations
N_EVAL_EPISODES = 10                 # episodes averaged per evaluation

# --- Ensure the logging root exists before anything writes into it ---
os.makedirs(name=LOG_DIR, exist_ok=True)

# --- Custom Callback for saving best model and logging rewards ---
class EvalCallback(BaseCallback):
    """
    Periodically evaluate the agent on a separate env and save the best model.

    Custom callback built on Stable-Baselines3's ``BaseCallback``. Note that it
    shares its name with SB3's built-in
    ``stable_baselines3.common.callbacks.EvalCallback``, which is not imported here.

    :param eval_env: env used for evaluation. ``reset``/``step`` are called with
        the single-env Gymnasium API (``obs, info`` and a 5-tuple), so it must
        not be a VecEnv.
    :param eval_freq: evaluate each time the training timestep count reaches a
        new multiple of ``eval_freq``. Measured in environment timesteps (summed
        over all parallel training envs), not in callback calls.
    :param log_dir: directory in which the best model is saved as ``best_model``
        via ``model.save``.
    :param n_eval_episodes: number of episodes averaged per evaluation.
    :param verbose: verbosity level; > 0 prints a message when a new best is saved.
    :raises ValueError: if ``eval_freq`` or ``n_eval_episodes`` is smaller than 1.
    """
    def __init__(self, eval_env, eval_freq, log_dir, n_eval_episodes=5, verbose=1):
        super().__init__(verbose)
        if eval_freq < 1:
            raise ValueError(f"eval_freq must be >= 1, got {eval_freq}")
        if n_eval_episodes < 1:
            raise ValueError(f"n_eval_episodes must be >= 1, got {n_eval_episodes}")
        self.eval_env = eval_env
        self.eval_freq = eval_freq
        self.log_dir = log_dir
        self.n_eval_episodes = n_eval_episodes
        self.best_mean_reward = -np.inf
        # Mean evaluation reward of each evaluation, and the training timestep
        # at which it was measured (parallel lists).
        self.episode_rewards = []
        self.timesteps = []
        # Index (num_timesteps // eval_freq) of the most recent evaluation.
        self._last_eval_index = 0

    def _on_training_start(self) -> None:
        # Align the schedule with the model's current timestep so that a
        # resumed ``learn()`` does not evaluate immediately and a reset one
        # (counter back at 0) is not silently suppressed.
        self._last_eval_index = self.num_timesteps // self.eval_freq

    def _on_step(self) -> bool:
        """
        Evaluate once per new multiple of ``eval_freq`` timesteps.

        :return: always ``True`` (never requests early stopping).
        """
        # ``n_calls`` increments once per VecEnv step, but every VecEnv step
        # adds ``n_envs`` to ``num_timesteps``. Scheduling on ``num_timesteps``
        # keeps the interval in env timesteps regardless of the number of
        # parallel envs, and bucketing by index avoids drift when the step
        # size does not divide ``eval_freq``.
        eval_index = self.num_timesteps // self.eval_freq
        if eval_index <= self._last_eval_index:
            return True
        self._last_eval_index = eval_index

        mean_reward, _ = self.evaluate_agent()
        self.episode_rewards.append(mean_reward)
        self.timesteps.append(self.num_timesteps)

        if mean_reward > self.best_mean_reward:
            if self.verbose > 0:
                print(f"New best mean reward: {mean_reward:.2f}, saving model.")
            self.best_mean_reward = mean_reward
            self.model.save(os.path.join(self.log_dir, "best_model"))
        return True

    def evaluate_agent(self):
        """
        Run ``n_eval_episodes`` deterministic episodes on ``eval_env``.

        :return: ``(mean_reward, std_reward)`` of the undiscounted episode returns.
        """
        all_episode_rewards = []
        for _ in range(self.n_eval_episodes):
            obs, _ = self.eval_env.reset()
            done = False
            episode_reward = 0.0
            while not done:
                # Greedy actions: measure the learned policy without exploration.
                action, _ = self.model.predict(obs, deterministic=True)
                obs, reward, terminated, truncated, _ = self.eval_env.step(action)
                # The episode ends on either a terminal state or a truncation.
                done = terminated or truncated
                episode_reward += float(reward)
            all_episode_rewards.append(episode_reward)
        mean_reward = float(np.mean(all_episode_rewards))
        std_reward = float(np.std(all_episode_rewards))
        return mean_reward, std_reward

# --- Training and Evaluation Loop ---
# Maps env_id -> {'timesteps': [...], 'rewards': [...]} taken from EvalCallback.
results = {}

for env_id in ENV_IDS:
    print(f"\n--- Training DQN on {env_id} ---")

    # Per-environment directory for the Monitor CSVs (read back later with
    # ``load_results``) and for the best model saved by the callback.
    env_log_dir = os.path.join(LOG_DIR, env_id)

    # Vectorized training envs. Passing the env ID string lets make_vec_env
    # build each env and wrap it in a Monitor that writes
    # ``<env_log_dir>/<rank>.monitor.csv``; without ``monitor_dir`` no CSV is
    # written at all. SubprocVecEnv runs each env in its own process; for
    # cheap envs like CartPole, DummyVecEnv may be faster due to less overhead.
    # NOTE: when run as a plain script (not a notebook), SubprocVecEnv with a
    # spawn/forkserver start method needs this code under an
    # ``if __name__ == "__main__":`` guard.
    train_env = make_vec_env(
        env_id,
        n_envs=NUM_ENVS_PER_PROBLEM,
        monitor_dir=env_log_dir,
        vec_env_cls=SubprocVecEnv,
    )
    eval_env = None
    try:
        # Separate single (non-vectorized) env so evaluation episodes are clear-cut.
        eval_env = Monitor(gym.make(env_id))

        model = DQN(
            "MlpPolicy",
            train_env,
            learning_rate=1e-4,
            buffer_size=100000,
            learning_starts=1000,
            batch_size=32,
            gamma=0.99,
            train_freq=4,
            gradient_steps=1,
            target_update_interval=1000,
            exploration_fraction=0.1,
            exploration_initial_eps=1.0,
            exploration_final_eps=0.05,
            verbose=0,
            tensorboard_log=LOG_DIR,
        )

        callback = EvalCallback(eval_env, EVAL_FREQ, env_log_dir, N_EVAL_EPISODES)

        # A distinct tb_log_name keeps each environment's TensorBoard run
        # separately labelled instead of generic DQN_1, DQN_2, ...
        model.learn(
            total_timesteps=TOTAL_TIMESTEPS_PER_PROBLEM,
            callback=callback,
            progress_bar=True,
            tb_log_name=f"DQN_{env_id}",
        )
    finally:
        # Always release the worker processes and env resources, even if
        # environment creation or training raises.
        train_env.close()
        if eval_env is not None:
            eval_env.close()

    # Store the evaluation curve collected by the callback.
    results[env_id] = {
        'timesteps': callback.timesteps,
        'rewards': callback.episode_rewards
    }

# --- Plotting Mean Evaluation Reward ---
# Each curve is the EvalCallback's mean reward over N_EVAL_EPISODES
# deterministic evaluation episodes, recorded at every evaluation point.
# That mean is the usual learning-curve metric; a per-episode *median* of the
# training rewards would instead need the full Monitor episode logs and a
# median computed over a sliding window.
fig, ax = plt.subplots(figsize=(12, 6))

for env_id, curve in results.items():
    eval_steps = np.array(curve['timesteps'])
    eval_means = np.array(curve['rewards'])
    ax.plot(eval_steps, eval_means, label=f'{env_id} Mean Evaluation Reward')

ax.set_xlabel("Timesteps")
ax.set_ylabel("Mean Episode Reward (Evaluation)")
ax.set_title("DQN Training Performance Across Different Environments")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

# Plot training curves from the Monitor logs. The Monitor wrapper writes one
# row per finished training episode to ``*.monitor.csv`` files; ``ts2xy`` turns
# them into (cumulative timesteps, raw episode reward) arrays, which are plotted
# here together with a rolling mean for readability.
print("\nPlotting training results using Stable-Baselines3's results plotter:")
ROLLING_WINDOW = 50  # episodes averaged per point of the smoothed curve
for env_id in ENV_IDS:
    log_path = os.path.join(LOG_DIR, env_id)
    if not os.path.isdir(log_path):
        print(f"Log directory for {env_id} not found at {log_path}")
        continue
    # load_results raises when the directory has no Monitor CSV files (e.g. it
    # only holds best_model.zip), so check before loading.
    if not any(name.endswith("monitor.csv") for name in os.listdir(log_path)):
        print(f"No training logs found for {env_id} at {log_path}")
        continue
    x, y = ts2xy(load_results(log_path), 'timesteps')
    if len(x) == 0:
        print(f"No training logs found for {env_id} at {log_path}")
        continue

    # Rolling mean over the last `window` episodes; shrink the window when
    # fewer episodes were logged. A "valid" convolution drops the first
    # window-1 points, so the x values are trimmed to match.
    window = min(ROLLING_WINDOW, len(y))
    y_smooth = np.convolve(y, np.ones(window) / window, mode="valid")
    x_smooth = x[window - 1:]

    plt.figure(figsize=(10, 5))
    plt.plot(x, y, alpha=0.3, label="Episode reward")
    plt.plot(x_smooth, y_smooth, label=f"Rolling mean ({window} episodes)")
    plt.xlabel("Timesteps")
    plt.ylabel("Episode Reward (Training)")
    plt.title(f"DQN Training Curve for {env_id}")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

: 