In [1]:
%%capture
!pip install stable-baselines3[extra]
!pip install moviepy

In [2]:
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CallbackList
from stable_baselines3.common.logger import TensorBoardOutputFormat, Video
from stable_baselines3.common.evaluation import evaluate_policy

from typing import Any, Dict

import gymnasium as gym
import torch as th
import numpy as np

CALLBACK_FREQ = 50000
FRAMESKIP = 1
NUM_TIMESTEPS = 1000000

2024-04-30 12:39:16.581158: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-30 12:39:16.581292: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-30 12:39:16.737097: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
class VideoRecorderCallback(BaseCallback):
    def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):
        """
        Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard

        :param eval_env: A gym environment from which the trajectory is recorded
        :param render_freq: Render the agent's trajectory every eval_freq call of the callback.
        :param n_eval_episodes: Number of episodes to render
        :param deterministic: Whether to use deterministic or stochastic policy
        """
        super().__init__()
        self._eval_env = eval_env
        self._render_freq = render_freq
        self._n_eval_episodes = n_eval_episodes
        self._deterministic = deterministic

    def _on_step(self) -> bool:
        if self.n_calls % self._render_freq == 0:
            screens = []

            def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:
                """
                Renders the environment in its current state, recording the screen in the captured `screens` list

                :param _locals: A dictionary containing all local variables of the callback's scope
                :param _globals: A dictionary containing all global variables of the callback's scope
                """
                screen = self._eval_env.render()
                # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
                screens.append(screen.transpose(2, 0, 1))

            evaluate_policy(
                self.model,
                self._eval_env,
                callback=grab_screens,
                n_eval_episodes=self._n_eval_episodes,
                deterministic=self._deterministic,
            )
            # Convert screens to a numpy array before passing to pytorch
            screens_array = np.array(screens)
            self.logger.record(
                "trajectory/video",
                Video(th.ByteTensor([screens_array]), fps=20),
                exclude=("stdout", "log", "json", "csv"),
            )
        return True

In [4]:
eval_env = Monitor(gym.make("ALE/Pacman-v5", render_mode="rgb_array", frameskip=FRAMESKIP))
train_env = gym.make("ALE/Pacman-v5", render_mode="rgb_array", frameskip=FRAMESKIP)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [5]:
eval_callback = EvalCallback(eval_env, log_path="./", eval_freq=CALLBACK_FREQ, n_eval_episodes=5, deterministic=True, render=False)
video_callback = VideoRecorderCallback(eval_env, render_freq=CALLBACK_FREQ)
callback_list = CallbackList([eval_callback, video_callback])

In [6]:
model = DQN(
    "CnnPolicy",
    train_env,
    verbose=1,
    buffer_size=100000,
    tensorboard_log="./")

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.


In [7]:
model.learn(total_timesteps=NUM_TIMESTEPS, callback=callback_list, tb_log_name="./control/")

Logging to ././control/_1




----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.47e+03 |
|    ep_rew_mean      | 7.25     |
|    exploration_rate | 0.944    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 1275     |
|    time_elapsed     | 4        |
|    total_timesteps  | 5874     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.84e+03 |
|    ep_rew_mean      | 7.62     |
|    exploration_rate | 0.86     |
| time/               |          |
|    episodes         | 8        |
|    fps              | 1282     |
|    time_elapsed     | 11       |
|    total_timesteps  | 14758    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2e+03    |
|    ep_rew_mean      | 8.33     |
|    exploration_rate | 0.772    |
| time/               |          |
|    episodes       

  logger.warn(
  Video(th.ByteTensor([screens_array]), fps=20),


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.83e+03 |
|    ep_rew_mean      | 8.89     |
|    exploration_rate | 0.513    |
| time/               |          |
|    episodes         | 28       |
|    fps              | 430      |
|    time_elapsed     | 119      |
|    total_timesteps  | 51234    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 2.95e-05 |
|    n_updates        | 308      |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.81e+03 |
|    ep_rew_mean      | 9.28     |
|    exploration_rate | 0.45     |
| time/               |          |
|    episodes         | 32       |
|    fps              | 361      |
|    time_elapsed     | 160      |
|    total_timesteps  | 57872    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 1.61e-05 |
|    n_updates      

<stable_baselines3.dqn.dqn.DQN at 0x7ec590945d20>

In [8]:
model.save("ALE-Pacman-v5-control-v2")