In [1]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.callbacks import BaseCallback

# Custom callback to log rewards
class RewardLoggerCallback(BaseCallback):
    """Periodically evaluate the model being trained and log the median return.

    Every ``eval_freq`` calls to the callback, the current model is rolled out
    on ``eval_env`` with ``deterministic=True`` for ``n_eval_episodes``
    episodes. The median of the summed per-episode rewards is appended to
    ``median_rewards``; the value of ``num_timesteps`` at that moment is
    appended to ``eval_timesteps`` so the two lists can be plotted against
    each other.

    Args:
        eval_env: Environment used only for evaluation. It is driven through
            ``reset()`` returning ``(obs, info)`` and ``step(action)``
            returning ``(obs, reward, terminated, truncated, info)``.
        eval_freq: Number of callback calls between two evaluations (>= 1).
        n_eval_episodes: Number of episodes played per evaluation (>= 1).
        verbose: Verbosity level forwarded to ``BaseCallback``; when > 0 each
            evaluation result is printed.

    Raises:
        ValueError: If ``eval_freq`` or ``n_eval_episodes`` is smaller than 1.
    """

    def __init__(self, eval_env, eval_freq, n_eval_episodes, verbose=0):
        super().__init__(verbose)
        # Fail early with a clear message: eval_freq == 0 would otherwise raise
        # ZeroDivisionError inside training, and zero episodes would log NaN.
        if eval_freq < 1:
            raise ValueError(f"eval_freq must be >= 1, got {eval_freq}")
        if n_eval_episodes < 1:
            raise ValueError(f"n_eval_episodes must be >= 1, got {n_eval_episodes}")
        self.eval_env = eval_env
        self.eval_freq = eval_freq
        self.n_eval_episodes = n_eval_episodes
        self.median_rewards = []   # one median episode return per evaluation
        self.eval_timesteps = []   # num_timesteps recorded at each evaluation

    def _run_episode(self) -> float:
        """Play one evaluation episode to its end and return the summed reward."""
        obs, _ = self.eval_env.reset()
        episode_return = 0.0
        done = False
        while not done:
            action, _ = self.model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, _ = self.eval_env.step(action)
            episode_return += float(reward)
            # An episode ends on either a terminal state or a truncation.
            done = bool(terminated or truncated)
        return episode_return

    def _on_step(self) -> bool:
        """Run an evaluation round every ``eval_freq`` calls; never stops training."""
        if self.n_calls % self.eval_freq == 0:
            returns = [self._run_episode() for _ in range(self.n_eval_episodes)]
            # Store a plain Python float so the logged values print the same
            # way regardless of the installed NumPy version.
            median_reward = float(np.median(returns))
            self.median_rewards.append(median_reward)
            self.eval_timesteps.append(self.num_timesteps)
            if self.verbose > 0:
                print(f"[eval] timesteps={self.num_timesteps} median_reward={median_reward:.2f}")
        return True

# Experiment configuration
env_ids = ['CartPole-v1', 'MountainCar-v0']   # environments to train on, one DQN each
timesteps = 50000        # training budget per environment
eval_freq = 5000         # callback calls between two evaluation rounds
n_eval_episodes = 5      # episodes played per evaluation round
seed = 42                # fixed seed so a fresh Restart & Run All is (best-effort) repeatable

# Store results: env_id -> list of median evaluation rewards, one per evaluation round
all_results = {}

for env_id in env_ids:
    print(f"\n=== Training on {env_id} ===")
    train_env = gym.make(env_id)
    eval_env = gym.make(env_id)
    try:
        # Seed the evaluation env once; the callback's later reset() calls
        # (made without a seed) keep drawing from this seeded generator.
        eval_env.reset(seed=seed)

        model = DQN("MlpPolicy", train_env, verbose=0, learning_rate=1e-3, seed=seed)

        callback = RewardLoggerCallback(eval_env, eval_freq, n_eval_episodes)

        model.learn(total_timesteps=timesteps, callback=callback)

        all_results[env_id] = callback.median_rewards
    finally:
        # Release both environments even if model construction or training fails.
        train_env.close()
        eval_env.close()



=== Training on CartPole-v1 ===

=== Training on MountainCar-v0 ===


In [2]:
# Show the median evaluation rewards logged for each environment
# (one list per env_id, one entry per evaluation round during training).
all_results

{'CartPole-v1': [9.0, 9.0, 10.0, 10.0, 12.0, 26.0, 358.0, 187.0, 240.0, 207.0],
 'MountainCar-v0': [-200.0,
  -200.0,
  -200.0,
  -200.0,
  -200.0,
  -200.0,
  -200.0,
  -200.0,
  -200.0,
  -200.0]}