# Stable Baselines 3

> Instead of implementing a DQN ourselves again, in this tutorial we will use stable-baselines3: a Python library for Reinforcement Learning!

In [16]:
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from torch.optim import Adam
from torch.nn import MSELoss
from matplotlib import animation
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# set animations to jshtml to render them in browser
# (left disabled: `replay` below converts explicitly via `anim.to_jshtml()`)
# plt.rcParams["animation.html"] = "jshtml"

# Base seed reused for the generator below and for environment resets.
SEED = 19
# NOTE(review): ENV_NAME is not referenced in the cells visible here, which
# hard-code "ALE/Pong-v5" — confirm whether it is still needed.
ENV_NAME = "LunarLander-v3"

# Notebook-wide NumPy random generator, seeded from SEED.
rng = np.random.default_rng(SEED)

def new_seed(rng):
    """Draw a fresh seed in ``[0, 10_000)`` from ``rng`` and return it as a Python scalar."""
    drawn = rng.integers(10_000)
    return drawn.item()

In [17]:
def replay(frames, interval=30):
    """Build an inline, browser-playable animation from a sequence of frames.

    Args:
        frames: Non-empty, indexable sequence of images accepted by
            ``Axes.imshow`` (e.g. the arrays returned by ``env.render()``).
        interval: Delay between frames in milliseconds, forwarded to
            ``FuncAnimation``.

    Returns:
        An IPython ``HTML`` object wrapping the animation's JS/HTML output.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    # Validate before touching matplotlib so a bad call does not leave an
    # orphaned, never-closed figure behind.
    if len(frames) == 0:
        raise ValueError("replay() needs at least one frame")

    fig, ax = plt.subplots()
    img = ax.imshow(frames[0])
    ax.axis("off")

    def update(frame):
        # Swap the pixel data in place instead of redrawing the whole axes.
        img.set_data(frame)
        return [img]

    anim = FuncAnimation(fig, update, frames=frames, interval=interval, blit=True)
    # Close the figure so only the HTML animation is shown, not a static plot.
    plt.close(fig)
    return HTML(anim.to_jshtml())

In [18]:
def play_episode(agent, env, max_steps=None, seed=19):
    """Run one episode with ``agent`` in ``env`` and return an animated replay.

    Args:
        agent: Object exposing ``act(observation)`` whose return value is
            passed to ``env.step``.
        env: Environment providing ``reset(seed=...)``, ``step(action)`` and
            ``render()``. ``render()`` must return an image frame
            (presumably the env was created with ``render_mode="rgb_array"``).
        max_steps: Optional cap on the number of environment steps; ``None``
            runs until the environment terminates or truncates. A value of
            ``0`` (or less) takes no step and replays only the initial frame.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        Whatever ``replay`` returns for the collected frames: the initial
        state followed by one frame per executed step, so the final state of
        the episode is included.
    """
    observation, _info = env.reset(seed=seed)

    # Start with the initial state; this also guarantees a non-empty replay.
    frames = [env.render()]

    done = False
    step = 0
    # Check the step budget *before* acting so max_steps=0 performs no step.
    while not done and (max_steps is None or step < max_steps):
        action = agent.act(observation)
        # apply the selected action to the environment
        observation, _reward, terminated, truncated, _info = env.step(action)
        step += 1
        # Render after stepping so the terminal state is captured as well.
        frames.append(env.render())
        done = terminated or truncated

    return replay(frames)

In [19]:
class Agent:
    """Abstract base class for the agents used in this notebook."""

    def act(self, obs):
        """Return the action to take for observation ``obs``; subclasses must override."""
        raise NotImplementedError()

class RandomAgent(Agent):
    """Agent that ignores the observation and draws a random action index."""

    def __init__(self, n=4, rng=None):
        # n: number of discrete actions to choose from.
        self.n = n
        # rng: seed or existing generator, as accepted by np.random.default_rng.
        self.rng = np.random.default_rng(rng)

    def act(self, obs):
        """Return a random integer in ``[0, n)``; ``obs`` is unused."""
        num_actions = self.n
        return self.rng.integers(num_actions)

Import stable-baselines3 and `ale_py`, which provides the Atari environments:

In [20]:
import stable_baselines3 as sb3
import ale_py

In [21]:

gym.register_envs(ale_py)

# Initialise the environment
env = gym.make("ALE/Pong-v5", render_mode="rgb_array")

# Frames of one random-play episode, consumed by the replay cell below.
frames = []

try:
    obs, info = env.reset(seed=SEED)
    # reset(seed=...) seeds the environment only; the action space has its own
    # random generator, so seed it too to make the sampled actions reproducible.
    env.action_space.seed(SEED)

    terminated = False
    truncated = False

    while not terminated and not truncated:
        frame = env.render()
        frames.append(frame)
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
finally:
    # Release the environment's resources even if the rollout is interrupted.
    env.close()

In [22]:
# Show the random-play frames collected in the previous cell as an inline animation.
replay(frames)

Now, let's train a DQN for the Pong environment via sb3:

In [23]:
# Initialise the environment
env = gym.make("ALE/Pong-v5", render_mode="rgb_array")

# model = sb3.DQN("MlpPolicy", env, verbose=1)
# reduce replay buffer size to avoid kernel crashout
model = sb3.DQN("CnnPolicy", env, verbose=1, buffer_size=50_000, exploration_fraction=0.2)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.


In [24]:
# Train for 500k environment steps. log_interval=50 prints a stats table every
# 50 episodes; progress_bar=True shows a live progress bar while training.
model.learn(
    total_timesteps=500_000, log_interval=50, progress_bar=True
)

Output()

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 932      |
|    ep_rew_mean      | -20.4    |
|    exploration_rate | 0.557    |
| time/               |          |
|    episodes         | 50       |
|    fps              | 358      |
|    time_elapsed     | 129      |
|    total_timesteps  | 46591    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0123   |
|    n_updates        | 11622    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 997      |
|    ep_rew_mean      | -19.9    |
|    exploration_rate | 0.053    |
| time/               |          |
|    episodes         | 100      |
|    fps              | 446      |
|    time_elapsed     | 223      |
|    total_timesteps  | 99684    |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.000966 |
|    n_updates        | 24895    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.28e+03 |
|    ep_rew_mean      | -18.7    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 150      |
|    fps              | 475      |
|    time_elapsed     | 367      |
|    total_timesteps  | 174949   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00392  |
|    n_updates        | 43712    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.69e+03 |
|    ep_rew_mean      | -17.2    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 200      |
|    fps              | 493      |
|    time_elapsed     | 544      |
|    total_timesteps  | 268718   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00318  |
|    n_updates        | 67154    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 1.98e+03 |
|    ep_rew_mean      | -16.4    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 250      |
|    fps              | 504      |
|    time_elapsed     | 739      |
|    total_timesteps  | 372704   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00271  |
|    n_updates        | 93150    |
----------------------------------


----------------------------------
| rollout/            |          |
|    ep_len_mean      | 2.08e+03 |
|    ep_rew_mean      | -15.9    |
|    exploration_rate | 0.05     |
| time/               |          |
|    episodes         | 300      |
|    fps              | 509      |
|    time_elapsed     | 935      |
|    total_timesteps  | 476982   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00433  |
|    n_updates        | 119220   |
----------------------------------


<stable_baselines3.dqn.dqn.DQN at 0x7f4ef86717f0>

In [28]:
class SB3Agent(Agent):
    """Adapter exposing a stable-baselines3 model through the ``Agent`` interface.

    Args:
        model: Object with a ``predict(obs, deterministic=...)`` method that
            returns an ``(action, state)`` pair, e.g. a trained ``sb3.DQN``.
        deterministic: Forwarded to ``model.predict``. ``True`` (default)
            plays the greedy policy. With ``False``, SB3's DQN keeps taking
            epsilon-greedy random actions at its current exploration rate,
            which distorts an evaluation replay.
    """

    def __init__(self, model, deterministic=True):
        self.model = model
        self.deterministic = deterministic

    def act(self, obs):
        """Return the model's action for ``obs``; any recurrent state is discarded."""
        action, _states = self.model.predict(obs, deterministic=self.deterministic)
        return action

In [29]:
# Wrap the trained model so play_episode can drive it through Agent.act.
sb3_agent = SB3Agent(model)

In [None]:
# Replay the trained agent on the env the model was built with, capped at
# 400 steps; the reset seed is play_episode's default.
play_episode(sb3_agent, env, max_steps=400)

: 