In [1]:
%%capture
!pip install stable-baselines3[extra]
!pip install moviepy

In [2]:
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CallbackList
from stable_baselines3.common.logger import TensorBoardOutputFormat, Video
from stable_baselines3.common.evaluation import evaluate_policy

from typing import Any, Dict

import gymnasium as gym
import torch as th
import numpy as np

EVAL_CALLBACK_FREQ = 50_000
VIDEO_CALLBACK_FREQ = 250_000
FRAMESKIP = 4
NUM_TIMESTEPS = 1_000_000

2024-05-01 16:40:49.228128: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-01 16:40:49.228225: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-01 16:40:49.464967: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
class VideoRecorderCallback(BaseCallback):
    def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):
        """
        Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard

        :param eval_env: A gym environment from which the trajectory is recorded
        :param render_freq: Render the agent's trajectory every eval_freq call of the callback.
        :param n_eval_episodes: Number of episodes to render
        :param deterministic: Whether to use deterministic or stochastic policy
        """
        super().__init__()
        self._eval_env = eval_env
        self._render_freq = render_freq
        self._n_eval_episodes = n_eval_episodes
        self._deterministic = deterministic

    def _on_step(self) -> bool:
        if self.n_calls % self._render_freq == 0:
            screens = []

            def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:
                """
                Renders the environment in its current state, recording the screen in the captured `screens` list

                :param _locals: A dictionary containing all local variables of the callback's scope
                :param _globals: A dictionary containing all global variables of the callback's scope
                """
                screen = self._eval_env.render()
                # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
                screens.append(screen.transpose(2, 0, 1))

            evaluate_policy(
                self.model,
                self._eval_env,
                callback=grab_screens,
                n_eval_episodes=self._n_eval_episodes,
                deterministic=self._deterministic,
            )
            # Convert screens to a numpy array before passing to pytorch
            screens_array = np.array(screens)
            self.logger.record(
                "trajectory/video",
                Video(th.ByteTensor([screens_array]), fps=60),
                exclude=("stdout", "log", "json", "csv"),
            )
        return True

In [4]:
eval_env = Monitor(gym.make("ALE/Pacman-v5", render_mode="rgb_array", frameskip=FRAMESKIP))
train_env = gym.make("ALE/Pacman-v5", render_mode="rgb_array", frameskip=FRAMESKIP)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [5]:
eval_callback = EvalCallback(eval_env, log_path="./", eval_freq=EVAL_CALLBACK_FREQ, n_eval_episodes=5, deterministic=True, render=False)
video_callback = VideoRecorderCallback(eval_env, render_freq=VIDEO_CALLBACK_FREQ)
callback_list = CallbackList([eval_callback, video_callback])

In [6]:
model = DQN(
    "CnnPolicy",
    train_env,
    verbose=1,
    buffer_size=100_000,
    tensorboard_log="./")

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.


In [7]:
model.learn(total_timesteps=NUM_TIMESTEPS, callback=callback_list, tb_log_name="./control/")

Logging to ././control/_1




----------------------------------
| rollout/            |          |
|    ep_len_mean      | 419      |
|    ep_rew_mean      | 16       |
|    exploration_rate | 0.984    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 674      |
|    time_elapsed     | 2        |
|    total_timesteps  | 1676     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 409      |
|    ep_rew_mean      | 15.9     |
|    exploration_rate | 0.969    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 686      |
|    time_elapsed     | 4        |
|    total_timesteps  | 3272     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 430      |
|    ep_rew_mean      | 16.8     |
|    exploration_rate | 0.951    |
| time/               |          |
|    episodes       

  logger.warn(
  Video(th.ByteTensor([screens_array]), fps=60),


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 891      |
|    ep_rew_mean      | 54.9     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 432      |
|    fps              | 180      |
|    time_elapsed     | 1401     |
|    total_timesteps  | 252772   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0537   |
|    n_updates        | 50692    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 892      |
|    ep_rew_mean      | 55.7     |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 436      |
|    fps              | 179      |
|    time_elapsed     | 1424     |
|    total_timesteps  | 255480   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0172   |
|    n_updates      

<stable_baselines3.dqn.dqn.DQN at 0x7cf026a8d7b0>

In [8]:
model.save("ALE-Pacman-v5-control-v3")