In [1]:
# pip install 'gym[atari]'
# pip install 'stable-baselines3[extra]'

In [2]:
from matplotlib import pyplot as plt
import numpy as np

In [4]:
# from stable_baselines3.common.policies import CnnPolicy
from stable_baselines3 import A2C, PPO
from stable_baselines3.common.env_util import make_vec_env

In [11]:
# Number of environment copies stepped together by the vectorized env.
n_envs = 16
# Vectorized CartPole environment built from `n_envs` copies, created with a fixed seed.
env = make_vec_env(
    "CartPole-v1",
    n_envs=n_envs,
    seed=0,
)

In [12]:
# Reset the vectorized env and keep the initial observations it returns.
s0 = env.reset()

In [15]:
# s0

In [23]:
from typing import Any, Dict

import gym
import torch as th

from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.logger import Video


class VideoRecorderCallback(BaseCallback):
    """
    Training callback that periodically rolls out the current model in a separate
    evaluation environment, captures the rendered frames and logs them as a video.
    """

    def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):
        """
        Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard

        :param eval_env: A gym environment from which the trajectory is recorded
        :param render_freq: Render the agent's trajectory every ``render_freq`` calls of the callback.
        :param n_eval_episodes: Number of episodes to render
        :param deterministic: Whether to use deterministic or stochastic policy
        """
        super().__init__()
        self._eval_env = eval_env
        self._render_freq = render_freq
        self._n_eval_episodes = n_eval_episodes
        self._deterministic = deterministic

    def _on_step(self) -> bool:
        """
        Record and log a video on every ``render_freq``-th call; do nothing otherwise.

        :return: Always ``True``, so this callback never requests that training stop.
        """
        if self.n_calls % self._render_freq != 0:
            return True

        screens = []

        def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:
            """
            Renders the environment in its current state, recording the screen in the captured `screens` list

            :param _locals: A dictionary containing all local variables of the callback's scope
            :param _globals: A dictionary containing all global variables of the callback's scope
            """
            screen = self._eval_env.render(mode="rgb_array")
            # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention
            screens.append(screen.transpose(2, 0, 1))

        evaluate_policy(
            self.model,
            self._eval_env,
            callback=grab_screens,
            n_eval_episodes=self._n_eval_episodes,
            deterministic=self._deterministic,
        )

        if not screens:
            # No frame was captured (e.g. zero evaluation episodes): there is nothing to log.
            return True

        # Stack the frames into one numpy array first and wrap it as a tensor, instead of
        # handing a nested Python list of arrays to the tensor constructor.
        # The extra outer list adds a leading axis of size 1 in front of the frame axis.
        frames = th.from_numpy(np.asarray([screens], dtype=np.uint8))
        self.logger.record(
            f"trajectory/video_{self.n_calls}",
            Video(frames, fps=40),
            # Videos can only be written by the TensorBoard output; skip the text-based ones.
            exclude=("stdout", "log", "json", "csv"),
        )
        return True

In [24]:
# Rebind `env` to a 2-copy vectorized CartPole env (replaces the 16-copy one created earlier).
env = make_vec_env("CartPole-v1", n_envs=2, seed=0)

In [25]:
# PPO agent with an MLP policy trained on `env`.
# The model is seeded with the same value used when creating the env so that
# repeated runs of the notebook are reproducible.
# NOTE(review): the log directory name mentions "a2c_cheeta" although this cell
# builds a PPO model on CartPole — consider renaming it; left unchanged here so
# existing TensorBoard runs stay in place.
model = PPO(
    "MlpPolicy",
    env,
    verbose=1,
    seed=0,
    tensorboard_log="./a2c_cheeta_tensorboard/",
)

Using cpu device


In [26]:
# Display the observation space of the vectorized env.
env.observation_space

Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)

In [31]:
# Shape of the observations returned by resetting the model's env
# (one row per environment copy).
model.env.reset().shape

(2, 4)

In [28]:
# Video-recording callback: evaluates on its own single-copy CartPole env
# and renders once every 5000 callback calls.
video_recorder = VideoRecorderCallback(
    eval_env=make_vec_env("CartPole-v1", n_envs=1, seed=0),
    render_freq=5000,
)


In [None]:
# Train for five million timesteps with the video-recording callback attached.
model.learn(
    total_timesteps=5_000_000,
    callback=video_recorder,
)

Logging to ./a2c_cheeta_tensorboard/A2C_2
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 23.5     |
|    ep_rew_mean        | 23.5     |
| time/                 |          |
|    fps                | 167      |
|    iterations         | 100      |
|    time_elapsed       | 5        |
|    total_timesteps    | 1000     |
| train/                |          |
|    entropy_loss       | -0.659   |
|    explained_variance | -0.0279  |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 1.54     |
|    value_loss         | 8.75     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 28.4     |
|    ep_rew_mean        | 28.4     |
| time/                 |          |
|    fps                | 311      |
|    iterations         | 200      |
|    time_elapsed       | 6        |
|    total_timesteps    | 2000   

In [None]:
# Watch the trained agent act in the training env.
# The rollout is bounded so this cell terminates on its own (an unbounded
# `while True` can only be stopped by interrupting the kernel, which breaks
# "Restart & Run All"). Raise N_RENDER_STEPS to watch for longer.
N_RENDER_STEPS = 1_000

obs = env.reset()
for _ in range(N_RENDER_STEPS):
    # By default, deterministic=False, so we use the stochastic policy
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()