# Pong

In [1]:
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from torch.optim import Adam
from torch.nn import MSELoss
from matplotlib import animation
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# Uncomment to make matplotlib render animations as jshtml in the browser.
# plt.rcParams["animation.html"] = "jshtml"

# Base seed for the notebook's random number generation.
SEED = 19
# Gymnasium id of the environment this notebook works with (Pong, not the
# LunarLander id left over from an earlier notebook).
ENV_NAME = "ALE/Pong-v5"

# Shared NumPy generator seeded from SEED.
rng = np.random.default_rng(SEED)

def new_seed(rng):
    """Draw a fresh seed in ``[0, 10_000)`` from ``rng`` and return it as a plain Python int."""
    drawn = rng.integers(10_000)
    # .item() converts the NumPy scalar into a built-in int.
    return drawn.item()

In [2]:
def replay(frames):
    """Turn a sequence of image frames into an inline HTML/JS animation.

    Args:
        frames: Non-empty, indexable sequence of images; each one is passed
            to ``imshow`` / ``set_data`` (presumably RGB arrays from
            ``env.render()`` — confirm against the caller).

    Returns:
        An ``IPython.display.HTML`` object wrapping the animation's jshtml
        player (30 ms between frames).

    Raises:
        ValueError: If ``frames`` is empty.
    """
    # Validate before any figure exists: otherwise an empty input fails with
    # an opaque IndexError on frames[0] and leaves an open figure behind.
    if len(frames) == 0:
        raise ValueError("replay() needs at least one frame")

    fig, ax = plt.subplots()
    img = ax.imshow(frames[0])
    ax.axis("off")

    def update(frame):
        # Swap the pixel data of the existing image artist instead of redrawing.
        img.set_data(frame)
        return [img]

    anim = FuncAnimation(fig, update, frames=frames, interval=30, blit=True)
    # Close the figure right away (presumably so the notebook does not also
    # display a static plot next to the player).
    plt.close(fig)
    return HTML(anim.to_jshtml())

In [3]:
def play_episode(agent, env, max_steps=None, seed=19):
    """Run one episode with ``agent`` in ``env`` and return its replay.

    Args:
        agent: Object with an ``act(observation)`` method returning an action.
        env: Environment exposing ``reset(seed=...)``, ``step(action)`` and
            ``render()``; ``render()`` must return an image frame.
        max_steps: Optional cap on the number of environment steps. ``None``
            means run until the environment terminates or truncates.
        seed: Seed passed to ``env.reset``.

    Returns:
        The result of ``replay(frames)``, where ``frames`` holds the initial
        frame plus one frame after every step taken.
    """
    observation, _info = env.reset(seed=seed)

    # Record the initial state, then one frame *after* each step, so the
    # final (terminal) frame is part of the replay as well.
    frames = [env.render()]

    steps_taken = 0
    done = False

    # The step budget is checked before acting, so max_steps=0 takes no step.
    while not done and (max_steps is None or steps_taken < max_steps):
        action = agent.act(observation)
        # apply the selected action to the environment
        observation, _reward, terminated, truncated, _info = env.step(action)
        frames.append(env.render())
        steps_taken += 1
        done = terminated or truncated

    return replay(frames)

In [4]:
class Agent:
    """Common interface for the agents in this notebook; subclasses implement ``act``."""

    def act(self, obs):
        """Return the action to take for observation ``obs``."""
        raise NotImplementedError

class RandomAgent(Agent):
    """Agent that ignores the observation and picks one of ``n`` actions uniformly at random."""

    def __init__(self, n=4, rng=None):
        # ``rng`` is handed to np.random.default_rng, so it may be None,
        # a seed, or an existing Generator.
        self.n = n
        self.rng = np.random.default_rng(rng)

    def act(self, obs):
        """Return a random integer action in ``[0, n)``; ``obs`` is unused."""
        return self.rng.integers(low=0, high=self.n)

Import stable-baselines3 and the Arcade Learning Environment bindings (`ale_py`):

In [5]:
import stable_baselines3 as sb3
import ale_py

In [None]:

# Register the ALE environments with gymnasium so gym.make can resolve their ids.
gym.register_envs(ale_py)

# Random-policy baseline: play one full Pong episode and keep every frame.
env = gym.make("ALE/Pong-v5", render_mode="rgb_array")

frames = []

try:
    obs, info = env.reset(seed=SEED)
    # reset(seed=...) does not seed the action space; seed it explicitly so
    # the sampled actions (and hence the episode) are reproducible.
    env.action_space.seed(SEED)

    terminated = False
    truncated = False

    while not terminated and not truncated:
        frame = env.render()
        frames.append(frame)
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)

    # Frames are captured before each step, so add the final state explicitly.
    frame = env.render()
    frames.append(frame)
finally:
    # Release the emulator even if the rollout raises.
    env.close()

In [None]:
# Show the frames recorded during the random-action episode as an animation.
replay(frames)

Now, let's train a PPO agent for the Pong environment via sb3:

In [None]:
from gymnasium import ObservationWrapper, ActionWrapper
from gymnasium.wrappers import FrameStackObservation, GrayscaleObservation
from stable_baselines3.common.env_checker import check_env
from gymnasium.spaces import Discrete, Box
#from stable_baselines3.common.vec_env import VecTransposeImage

class TransposeObservation(ObservationWrapper):
    """Move the leading axis of a 3-D observation to the end: ``(c, h, w) -> (h, w, c)``.

    The observation space is rebuilt by applying the same axis permutation to
    the wrapped space's ``low``/``high`` bounds, so the bounds and dtype of the
    wrapped space are preserved instead of being assumed to be ``[0, 255]``.
    Assumes the wrapped observation space is a ``Box`` — TODO confirm for
    other environments.

    Raises:
        ValueError: If the wrapped observation space is not 3-dimensional.
    """

    # Axis permutation shared by the space bounds and the observations.
    _AXES = (1, 2, 0)

    def __init__(self, env):
        super().__init__(env)
        source_space = self.observation_space
        if len(source_space.shape) != 3:
            raise ValueError(
                f"expected a 3-D observation space, got shape {source_space.shape}"
            )
        # Box infers its shape from the (already permuted) bound arrays.
        self.observation_space = Box(
            low=source_space.low.transpose(self._AXES),
            high=source_space.high.transpose(self._AXES),
            dtype=source_space.dtype,
        )

    def observation(self, obs: np.ndarray):
        """Return ``obs`` with axes reordered from ``(c, h, w)`` to ``(h, w, c)``."""
        return obs.transpose(self._AXES)

class LessActions(ActionWrapper):
    """Expose a two-action space and shift each choice by 2 before it reaches the wrapped env."""

    def __init__(self, env):
        super().__init__(env)
        # The agent only chooses between 0 and 1; per the original note these
        # are meant to be the Up and Down actions — confirm against the
        # environment's action table.
        self.action_space = Discrete(n=2)

    def action(self, action):
        """Map the reduced action (0 or 1) to the wrapped env's action id (2 or 3)."""
        offset = 2
        return action + offset

# Build the training environment: grayscale Pong frames, two consecutive
# frames stacked per observation, then the stack axis moved last.
# RGB alternative:
# env = gym.make("ALE/Pong-v5", render_mode="rgb_array")
env = gym.make("ALE/Pong-v5", render_mode="rgb_array", obs_type="grayscale")
env = FrameStackObservation(env, stack_size=2)
env = TransposeObservation(env)
# Optional: restrict the agent to two actions.
# env = LessActions(env)

# Verify that the wrapped environment passes stable-baselines3's checks.
check_env(env)

# Alternatives tried:
# model = sb3.DQN("MlpPolicy", env, verbose=1)
# reduce replay buffer size to avoid kernel crashout
# model = sb3.DQN("CnnPolicy", env, verbose=1, buffer_size=50_000, exploration_fraction=0.2)
# Seed the model from the notebook-wide SEED instead of repeating the literal.
model = sb3.PPO("CnnPolicy", env, verbose=1, seed=SEED)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.


In [None]:
# Train for 300k timesteps; with log_interval=5 the stats table below is
# printed every 5 iterations, and progress_bar=True shows a progress bar.
model.learn(
    total_timesteps=300_000, log_interval=5, progress_bar=True
)

Output()

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 1e+03       |
|    ep_rew_mean          | -20.5       |
| time/                   |             |
|    fps                  | 466         |
|    iterations           | 5           |
|    time_elapsed         | 21          |
|    total_timesteps      | 10240       |
| train/                  |             |
|    approx_kl            | 0.016081145 |
|    clip_fraction        | 0.167       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.645      |
|    explained_variance   | 0.742       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0424     |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.0259     |
|    value_loss           | 0.0223      |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 1.19e+03   |
|    ep_rew_mean          | -20.2      |
| time/                   |            |
|    fps                  | 460        |
|    iterations           | 10         |
|    time_elapsed         | 44         |
|    total_timesteps      | 20480      |
| train/                  |            |
|    approx_kl            | 0.07413569 |
|    clip_fraction        | 0.375      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.527     |
|    explained_variance   | 0.294      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.103     |
|    n_updates            | 90         |
|    policy_gradient_loss | -0.0677    |
|    value_loss           | 0.0216     |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 1.32e+03   |
|    ep_rew_mean          | -20.2      |
| time/                   |            |
|    fps                  | 458        |
|    iterations           | 15         |
|    time_elapsed         | 67         |
|    total_timesteps      | 30720      |
| train/                  |            |
|    approx_kl            | 0.15076292 |
|    clip_fraction        | 0.393      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.422     |
|    explained_variance   | 0.481      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.116     |
|    n_updates            | 140        |
|    policy_gradient_loss | -0.0729    |
|    value_loss           | 0.0189     |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 1.42e+03   |
|    ep_rew_mean          | -20.1      |
| time/                   |            |
|    fps                  | 456        |
|    iterations           | 20         |
|    time_elapsed         | 89         |
|    total_timesteps      | 40960      |
| train/                  |            |
|    approx_kl            | 0.20115513 |
|    clip_fraction        | 0.372      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.342     |
|    explained_variance   | 0.54       |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0966    |
|    n_updates            | 190        |
|    policy_gradient_loss | -0.0663    |
|    value_loss           | 0.0126     |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 1.57e+03   |
|    ep_rew_mean          | -19.7      |
| time/                   |            |
|    fps                  | 455        |
|    iterations           | 25         |
|    time_elapsed         | 112        |
|    total_timesteps      | 51200      |
| train/                  |            |
|    approx_kl            | 0.37605894 |
|    clip_fraction        | 0.394      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.305     |
|    explained_variance   | 0.578      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.111     |
|    n_updates            | 240        |
|    policy_gradient_loss | -0.075     |
|    value_loss           | 0.0126     |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 1.61e+03   |
|    ep_rew_mean          | -19.7      |
| time/                   |            |
|    fps                  | 454        |
|    iterations           | 30         |
|    time_elapsed         | 135        |
|    total_timesteps      | 61440      |
| train/                  |            |
|    approx_kl            | 0.45337072 |
|    clip_fraction        | 0.361      |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.289     |
|    explained_variance   | 0.541      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.0966    |
|    n_updates            | 290        |
|    policy_gradient_loss | -0.0634    |
|    value_loss           | 0.013      |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 1.72e+03   |
|    ep_rew_mean          | -19.3      |
| time/                   |            |
|    fps                  | 455        |
|    iterations           | 35         |
|    time_elapsed         | 157        |
|    total_timesteps      | 71680      |
| train/                  |            |
|    approx_kl            | 0.24024826 |
|    clip_fraction        | 0.33       |
|    clip_range           | 0.2        |
|    entropy_loss         | -0.272     |
|    explained_variance   | 0.742      |
|    learning_rate        | 0.0003     |
|    loss                 | -0.103     |
|    n_updates            | 340        |
|    policy_gradient_loss | -0.0585    |
|    value_loss           | 0.0133     |
----------------------------------------


In [None]:
class SB3Agent(Agent):
    """Adapter exposing a stable-baselines3 model through the ``Agent.act`` interface.

    Args:
        model: Object with a ``predict(obs, deterministic=...)`` method that
            returns an ``(action, state)`` pair, e.g. a trained sb3 algorithm.
        deterministic: Forwarded to ``model.predict``. The default ``False``
            keeps the previous behaviour of this class.
    """

    def __init__(self, model, deterministic=False):
        self.model = model
        self.deterministic = deterministic

    def act(self, obs):
        """Return the model's action for ``obs``; the returned state is discarded."""
        action, _state = self.model.predict(obs, deterministic=self.deterministic)
        return action

In [None]:
# Wrap the trained model so it can be driven through the Agent.act interface.
sb3_agent = SB3Agent(model)

In [None]:
# Replay at most 400 steps of the model's policy in the wrapped Pong environment.
play_episode(sb3_agent, env, max_steps=400)

## Bookmarks

- https://karpathy.github.io/2016/05/31/rl/