In [22]:
import itertools
import os
import sys
import time
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import einops
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn as nn
import torch.optim as optim
import wandb
from IPython.display import HTML, display
from jaxtyping import Bool, Float, Int
from matplotlib.animation import FuncAnimation
from numpy.random import Generator
from torch import Tensor
from torch.distributions.categorical import Categorical
from torch.optim.optimizer import Optimizer
from tqdm import tqdm

warnings.filterwarnings("ignore")

Arr = np.ndarray

from utils import ppo_arg_help, set_global_seeds, make_env, get_episode_data_from_infos, prepare_atari_env

device = t.device("cuda" if t.cuda.is_available() else "cpu")

In [23]:
@dataclass
class PPOArgs:
    # Hyperparameters and run settings for PPO. Sizes that follow from the declared fields (batch size, minibatch size,
    # number of phases, number of optimizer steps) are computed in `__post_init__` and stored as plain attributes.

    # Basic / global
    seed: int = 1
    env_id: str = "CartPole-v1"
    mode: Literal["classic-control", "atari", "mujoco"] = "classic-control"

    # Wandb / logging
    use_wandb: bool = False
    video_log_freq: int | None = None
    wandb_project_name: str = "PPOCartPole"
    wandb_entity: str | None = None

    # Duration of different phases
    total_timesteps: int = 500_000
    num_envs: int = 4
    num_steps_per_rollout: int = 128
    num_minibatches: int = 4
    batches_per_learning_phase: int = 4

    # Optimization hyperparameters
    lr: float = 2.5e-4
    max_grad_norm: float = 0.5

    # RL hyperparameters
    gamma: float = 0.99

    # PPO-specific hyperparameters
    gae_lambda: float = 0.95
    clip_coef: float = 0.2
    ent_coef: float = 0.01
    vf_coef: float = 0.25

    def __post_init__(self):
        # Number of transitions gathered in one rollout phase (steps per env times number of envs)
        self.batch_size = self.num_steps_per_rollout * self.num_envs

        assert self.batch_size % self.num_minibatches == 0, "batch_size must be divisible by num_minibatches"
        self.minibatch_size = self.batch_size // self.num_minibatches
        # Floor division: any remainder of `total_timesteps` that does not fill a whole batch is dropped
        self.total_phases = self.total_timesteps // self.batch_size
        # One optimizer step per minibatch, `num_minibatches` minibatches per pass, and
        # `batches_per_learning_phase` passes over the batch in every phase
        self.total_training_steps = self.total_phases * self.batches_per_learning_phase * self.num_minibatches

        # Relative path, so it resolves against the current working directory
        self.video_save_path = Path("ppo-videos")

# Example configuration; the derived sizes are recomputed in `PPOArgs.__post_init__` from whatever is passed here
args = PPOArgs(num_minibatches=2)  # changing this also changes minibatch_size and total_training_steps
ppo_arg_help(args)

Unnamed: 0_level_0,default value,description
arg,Unnamed: 1_level_1,Unnamed: 2_level_1
seed,1,seed of the experiment
env_id,'CartPole-v1',the id of the environment
mode,'classic-control',"can be 'classic-control', 'atari' or 'mujoco'"
use_wandb,False,"if toggled, this experiment will be tracked with Weights and Biases"
video_log_freq,,"if not None, we log videos this many episodes apart (so shorter episodes mean more frequent logging)"
wandb_project_name,'PPOCartPole',the name of this experiment (also used as the wandb project name)
wandb_entity,,the entity (team) of wandb's project
total_timesteps,500000,total timesteps of the experiments
num_envs,4,number of synchronized vector environments in our `envs` object (this is N in the '37 Implementational Details' post)
num_steps_per_rollout,128,number of steps taken in the rollout phase (this is M in the '37 Implementational Details' post)


In [24]:
def layer_init(layer: nn.Linear, std=np.sqrt(2), bias_const=0.0):
    """
    Initialise `layer` in place and return it, so the call can be nested directly inside `nn.Sequential(...)`.

    The weight is filled with a (semi-)orthogonal matrix scaled by `std`, and every bias entry is set to `bias_const`.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer


def get_actor_and_critic(
    envs: gym.vector.SyncVectorEnv,
    mode: Literal["classic-control", "atari", "mujoco"] = "classic-control",
) -> tuple[nn.Module, nn.Module]:
    """
    Builds the (actor, critic) pair for the requested `mode` and moves both networks onto `device`.

    The observation size is the product of the single-env observation shape. The action count is `n` for a
    `Discrete` action space, and the product of the action shape for any other space.
    """
    assert mode in ["classic-control", "atari", "mujoco"]

    action_space = envs.single_action_space
    obs_shape = envs.single_observation_space.shape
    num_obs = np.array(obs_shape).prod()
    if isinstance(action_space, gym.spaces.Discrete):
        num_actions = action_space.n
    else:
        num_actions = np.array(action_space.shape).prod()

    # The builder names are only looked up inside the branch that is taken, so builders for other modes do not need
    # to be defined yet when this function runs.
    if mode == "classic-control":
        networks = get_actor_and_critic_classic(num_obs, num_actions)
    elif mode == "atari":
        networks = get_actor_and_critic_atari(obs_shape, num_actions)  # you'll implement these later
    else:
        networks = get_actor_and_critic_mujoco(num_obs, num_actions)  # you'll implement these later

    actor, critic = networks
    return actor.to(device), critic.to(device)


def get_actor_and_critic_classic(num_obs: int, num_actions: int):
    """
    Builds the "classic-control" networks: two independent MLPs with two hidden layers of width 64 and Tanh
    activations. The critic ends in a single output (init std 1.0), the actor in `num_actions` logits (init std 0.01).

    Returns (actor, critic).
    """
    hidden_size = 64

    def build_mlp(out_features: int, out_std: float) -> nn.Sequential:
        # Layers are created (and initialised) strictly in input -> output order
        return nn.Sequential(
            layer_init(nn.Linear(num_obs, hidden_size)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_size, hidden_size)),
            nn.Tanh(),
            layer_init(nn.Linear(hidden_size, out_features), std=out_std),
        )

    # The critic is constructed before the actor so that the sequence of random draws used for initialisation is fixed
    critic = build_mlp(1, out_std=1.0)
    actor = build_mlp(num_actions, out_std=0.01)
    return actor, critic

In [25]:
@t.inference_mode()
def compute_advantages(
    next_value: Float[Tensor, "num_envs"],
    next_terminated: Bool[Tensor, "num_envs"],
    rewards: Float[Tensor, "buffer_size num_envs"],
    values: Float[Tensor, "buffer_size num_envs"],
    terminated: Bool[Tensor, "buffer_size num_envs"],
    gamma: float,
    gae_lambda: float,
) -> Float[Tensor, "buffer_size num_envs"]:
    """
    Generalized Advantage Estimation over one rollout.

    `values[t]` and `terminated[t]` describe the state at step t, while `next_value` / `next_terminated` describe
    the state reached after the final step. With d denoting the termination flag:

        delta_t = r_t + gamma * (1 - d_{t+1}) * V(s_{t+1}) - V(s_t)
        A_t     = delta_t + gamma * gae_lambda * (1 - d_{t+1}) * A_{t+1},    with A_{T-1} = delta_{T-1}
    """
    num_steps = values.shape[0]
    done = terminated.float()

    # Shift by one step to line V(s_{t+1}) and d_{t+1} up with index t; the final row comes from the bootstrap inputs
    value_after = t.concat([values[1:], next_value[None, :]])
    done_after = t.concat([done[1:], next_terminated.float()[None, :]])
    continuing = 1.0 - done_after

    deltas = rewards + gamma * value_after * continuing - values

    # Backward recursion; `continuing[step]` is 0 exactly where step + 1 starts a new episode, which cuts the sum there
    discount = gamma * gae_lambda
    advantages = t.zeros_like(deltas)
    advantages[-1] = deltas[-1]
    for step in range(num_steps - 2, -1, -1):
        advantages[step] = deltas[step] + discount * continuing[step] * advantages[step + 1]

    return advantages

In [26]:
def get_minibatch_indices(rng: Generator, batch_size: int, minibatch_size: int) -> list[np.ndarray]:
    """
    Randomly partitions the indices 0, 1, ..., batch_size - 1 into equally sized minibatches.

    Returns a list with `batch_size // minibatch_size` arrays, each holding `minibatch_size` indices; every index
    appears in exactly one array. `batch_size` must be a multiple of `minibatch_size`.
    """
    num_minibatches, leftover = divmod(batch_size, minibatch_size)
    assert leftover == 0

    shuffled = rng.permutation(batch_size)
    return [row for row in shuffled.reshape(num_minibatches, minibatch_size)]


# Inline sanity check for `get_minibatch_indices`, using a seeded generator so the cell is reproducible
rng = np.random.default_rng(0)

batch_size = 12
minibatch_size = 6
# num_minibatches = batch_size // minibatch_size = 2

indices = get_minibatch_indices(rng, batch_size, minibatch_size)

# The result must be a list of arrays which together cover every index in [0, batch_size) exactly once
assert isinstance(indices, list)
assert all(isinstance(x, np.ndarray) for x in indices)
assert np.array(indices).shape == (2, 6)
assert sorted(np.unique(indices)) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
print("All tests in `test_minibatch_indexes` passed!")

All tests in `test_minibatch_indexes` passed!


In [27]:
@dataclass
class ReplayMinibatch:
    """
    Samples from the replay memory, converted to PyTorch for use in neural network training.

    Data is equivalent to (s_t, a_t, logpi(a_t|s_t), A_t, A_t + V(s_t), d_t), where d_t is the termination flag that
    was stored in the buffer alongside s_t.
    """

    # Observations s_t
    obs: Float[Tensor, "minibatch_size *obs_shape"]
    # Actions a_t that were taken from s_t
    actions: Int[Tensor, "minibatch_size *action_shape"]
    # Log-probabilities of those actions under the policy that collected the data
    logprobs: Float[Tensor, "minibatch_size"]
    # Advantage estimates A_t
    advantages: Float[Tensor, "minibatch_size"]
    # Targets for the critic, A_t + V(s_t)
    returns: Float[Tensor, "minibatch_size"]
    # Termination flags as stored in the replay memory
    terminated: Bool[Tensor, "minibatch_size"]


class ReplayMemory:
    """
    Rollout buffer. Transitions are appended one environment step at a time with `add`, and `get_minibatches`
    turns the stored rollout into shuffled `ReplayMinibatch` objects and then empties the buffer.
    """

    rng: Generator
    obs: Float[Arr, "buffer_size num_envs *obs_shape"]
    actions: Int[Arr, "buffer_size num_envs *action_shape"]
    logprobs: Float[Arr, "buffer_size num_envs"]
    values: Float[Arr, "buffer_size num_envs"]
    rewards: Float[Arr, "buffer_size num_envs"]
    terminated: Bool[Arr, "buffer_size num_envs"]

    def __init__(
        self,
        num_envs: int,
        obs_shape: tuple,
        action_shape: tuple,
        batch_size: int,
        minibatch_size: int,
        batches_per_learning_phase: int,
        seed: int = 42,
    ):
        self.num_envs = num_envs
        self.obs_shape = obs_shape
        self.action_shape = action_shape
        self.batch_size = batch_size
        self.minibatch_size = minibatch_size
        self.batches_per_learning_phase = batches_per_learning_phase
        # Private generator, used only to shuffle minibatch indices
        self.rng = np.random.default_rng(seed)
        self.reset()

    def reset(self):
        """Replaces every buffer with an empty array (length 0 along the time axis)."""
        layout = (
            ("obs", self.obs_shape, np.float32),
            ("actions", self.action_shape, np.int32),
            ("logprobs", (), np.float32),
            ("values", (), np.float32),
            ("rewards", (), np.float32),
            ("terminated", (), bool),
        )
        for name, trailing_shape, dtype in layout:
            setattr(self, name, np.empty((0, self.num_envs, *trailing_shape), dtype=dtype))

    def add(
        self,
        obs: Float[Arr, "num_envs *obs_shape"],
        actions: Int[Arr, "num_envs *action_shape"],
        logprobs: Float[Arr, "num_envs"],
        values: Float[Arr, "num_envs"],
        rewards: Float[Arr, "num_envs"],
        terminated: Bool[Arr, "num_envs"],
    ) -> None:
        """Appends one step of data (one entry per environment) to the end of every buffer."""
        incoming = {
            "obs": (obs, self.obs_shape),
            "actions": (actions, self.action_shape),
            "logprobs": (logprobs, ()),
            "values": (values, ()),
            "rewards": (rewards, ()),
            "terminated": (terminated, ()),
        }

        # Validate every input before touching any buffer
        for array, trailing_shape in incoming.values():
            assert isinstance(array, np.ndarray)
            assert array.shape == (self.num_envs, *trailing_shape)

        # Grow each buffer by one row along the time axis (nothing is ever sliced off)
        for name, (array, _) in incoming.items():
            setattr(self, name, np.concatenate((getattr(self, name), array[None, :])))

    def get_minibatches(
        self, next_value: Tensor, next_terminated: Tensor, gamma: float, gae_lambda: float
    ) -> list[ReplayMinibatch]:
        """
        Converts the stored rollout into minibatches and clears the memory.

        The rollout is shuffled and split `batches_per_learning_phase` times, so every stored transition appears in
        exactly that many of the returned minibatches. Each minibatch holds `minibatch_size` transitions.
        """

        def to_device(array):
            return t.tensor(array, device=device)

        obs = to_device(self.obs)
        actions = to_device(self.actions)
        logprobs = to_device(self.logprobs)
        values = to_device(self.values)
        rewards = to_device(self.rewards)
        terminated = to_device(self.terminated)

        advantages = compute_advantages(next_value, next_terminated, rewards, values, terminated, gamma, gae_lambda)
        returns = advantages + values

        # Merge the (time, env) axes once up front; field order matches the ReplayMinibatch constructor
        flat_fields = [
            field.flatten(0, 1) for field in (obs, actions, logprobs, advantages, returns, terminated)
        ]

        minibatches = []
        for _ in range(self.batches_per_learning_phase):
            for idx in get_minibatch_indices(self.rng, self.batch_size, self.minibatch_size):
                minibatches.append(ReplayMinibatch(*[field[idx] for field in flat_fields]))

        # The rollout has been consumed, so start the next one from an empty buffer
        self.reset()

        return minibatches

In [28]:
class PPOAgent:
    """
    Steps a set of vectorised environments with the current policy and records each transition in replay memory.

    `next_obs` / `next_terminated` always describe the state the next call to `play_step` will act from. Both are
    kept as tensors on `device`; `next_terminated` is boolean throughout, so the replay memory's `terminated`
    buffer keeps its boolean dtype.
    """

    critic: nn.Sequential
    actor: nn.Sequential

    def __init__(self, envs: gym.vector.SyncVectorEnv, actor: nn.Module, critic: nn.Module, memory: ReplayMemory):
        self.envs = envs
        self.actor = actor
        self.critic = critic
        self.memory = memory

        self.step = 0  # Tracking number of steps taken (across all environments)
        self.next_obs = t.tensor(envs.reset()[0], device=device, dtype=t.float)  # need starting obs (in tensor form)
        self.next_terminated = t.zeros(envs.num_envs, device=device, dtype=t.bool)  # need starting termination=False

    def play_step(self) -> list[dict]:
        """
        Carries out a single interaction step between the agent and the environment, and adds results to the replay
        memory.

        Returns whatever `self.envs.step` returned as its `infos` value, unchanged.
        """
        # Get newest observations (i.e. where we're starting from)
        obs = self.next_obs
        terminated = self.next_terminated

        # Sample actions from the current policy and evaluate the critic; no gradients are needed during rollouts
        with t.inference_mode():
            dist = Categorical(logits=self.actor(obs))
            actions = dist.sample()
            logprobs = dist.log_prob(actions).cpu().numpy()
            values = self.critic(obs).flatten().cpu().numpy()
        actions_np = actions.cpu().numpy()

        # Step environment based on the sampled action. The truncation flags are not used by this agent.
        next_obs, rewards, next_terminated, next_truncated, infos = self.envs.step(actions_np)

        # Store the transition; `terminated` is the flag that belongs to `obs`, not to `next_obs`
        self.memory.add(obs.cpu().numpy(), actions_np, logprobs, values, rewards, terminated.cpu().numpy())

        # Set next observation & termination state. The flag stays boolean (it used to be cast to float here, which
        # made the replay memory's boolean `terminated` buffer silently turn into float32 on the next `add`).
        self.next_obs = t.from_numpy(next_obs).to(device, dtype=t.float)
        self.next_terminated = t.from_numpy(next_terminated).to(device, dtype=t.bool)

        self.step += self.envs.num_envs
        return infos

    def get_minibatches(self, gamma: float, gae_lambda: float) -> list[ReplayMinibatch]:
        """
        Gets minibatches from the replay memory, and resets the memory.

        The critic's value for `self.next_obs` is passed on together with `self.next_terminated`, so the advantage
        computation can bootstrap from the state reached after the last stored step.
        """
        with t.inference_mode():
            next_value = self.critic(self.next_obs).flatten()
        minibatches = self.memory.get_minibatches(next_value, self.next_terminated, gamma, gae_lambda)
        # Explicit reset so this method's contract holds whatever the memory's own get_minibatches does
        self.memory.reset()
        return minibatches


In [29]:
def calc_clipped_surrogate_objective(
    probs: Categorical,
    mb_action: Int[Tensor, "minibatch_size"],
    mb_advantages: Float[Tensor, "minibatch_size"],
    mb_logprobs: Float[Tensor, "minibatch_size"],
    clip_coef: float,
    eps: float = 1e-8,
) -> Float[Tensor, ""]:
    """Return the clipped surrogate objective, suitable for maximisation with gradient ascent.

    probs:
        the current policy's action distribution for each sample in the minibatch
    mb_action:
        the actions that were taken in the sampled minibatch
    mb_advantages:
        advantages calculated from the sampled minibatch; they are standardised (zero mean, unit std) before use
    mb_logprobs:
        logprobs of the actions taken in the sampled minibatch (according to the old policy)
    clip_coef:
        amount of clipping, denoted by epsilon in Eq 7.
    eps:
        added to the std dev of mb_advantages when normalizing (to avoid dividing by zero)
    """
    assert mb_action.shape == mb_advantages.shape == mb_logprobs.shape

    # Probability ratio pi_new(a|s) / pi_old(a|s), formed in log space
    log_ratio = probs.log_prob(mb_action) - mb_logprobs
    ratio = log_ratio.exp()

    normalized_adv = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + eps)

    unclipped_term = ratio * normalized_adv
    clipped_term = ratio.clamp(1 - clip_coef, 1 + clip_coef) * normalized_adv

    # The elementwise minimum makes the objective a pessimistic bound on the unclipped one
    return t.min(unclipped_term, clipped_term).mean()


In [30]:
def calc_value_function_loss(
    values: Float[Tensor, "minibatch_size"], mb_returns: Float[Tensor, "minibatch_size"], vf_coef: float
) -> Float[Tensor, ""]:
    """Compute the value function portion of the loss: `vf_coef` times the mean squared error.

    values:
        the value function predictions for the sampled minibatch (using the updated critic network)
    mb_returns:
        the target for our updated critic network (computed as `advantages + values` from the old network)
    vf_coef:
        the coefficient for the value loss, which weights its contribution to the overall loss. Denoted by c_1 in the paper.
    """
    assert values.shape == mb_returns.shape

    squared_error = t.square(values - mb_returns)
    return vf_coef * squared_error.mean()


In [31]:
def calc_entropy_bonus(dist: Categorical, ent_coef: float):
    """Return the entropy bonus term, suitable for gradient ascent.

    dist:
        the probability distribution for the current policy
    ent_coef:
        the coefficient for the entropy loss, which weights its contribution to the overall objective function. Denoted by c_2 in the paper.
    """
    # Entropy is computed per sample and then averaged over the minibatch
    mean_entropy = dist.entropy().mean()
    return ent_coef * mean_entropy

In [32]:
class PPOScheduler:
    """
    Linear learning-rate schedule that writes directly into an optimizer's param groups.

    After k calls to `step` every group's learning rate equals
    `initial_lr + (k / total_phases) * (end_lr - initial_lr)`, so it reaches `end_lr` on call number `total_phases`.
    """

    def __init__(self, optimizer: Optimizer, initial_lr: float, end_lr: float, total_phases: int):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.end_lr = end_lr
        self.total_phases = total_phases
        self.n_step_calls = 0

    def step(self):
        """Advance the schedule by one call and overwrite `param_group["lr"]` in every param group.

        Raises an AssertionError if called more than `total_phases` times (the call is still counted).
        """
        self.n_step_calls += 1
        progress = self.n_step_calls / self.total_phases
        assert progress <= 1

        new_lr = self.initial_lr + progress * (self.end_lr - self.initial_lr)
        for group in self.optimizer.param_groups:
            group["lr"] = new_lr


def make_optimizer(
    actor: nn.Module, critic: nn.Module, total_phases: int, initial_lr: float, end_lr: float = 0.0
) -> tuple[optim.AdamW, PPOScheduler]:
    """
    Return an AdamW optimizer over the actor and critic parameters, together with its attached linear scheduler.

    The optimizer is created with `maximize=True`, so it performs gradient *ascent* on whatever `.backward()` was
    called on. The scheduler moves the learning rate linearly from `initial_lr` to `end_lr` over `total_phases`
    calls to its `step` method.

    Parameters that the actor and critic share (e.g. a common feature extractor) are handed to the optimizer only
    once; listing a tensor twice in one param group would make the optimizer update it twice per step.
    """
    seen_ids: set[int] = set()
    params = []
    for param in itertools.chain(actor.parameters(), critic.parameters()):
        # Compare by identity: `==` on tensors is elementwise, so it cannot be used for de-duplication
        if id(param) not in seen_ids:
            seen_ids.add(id(param))
            params.append(param)

    # NOTE(review): AdamW applies its default weight decay here; the previous docstring said "Adam" -- confirm
    # which optimizer is intended before changing it, since that would alter training behaviour.
    optimizer = optim.AdamW(params, lr=initial_lr, eps=1e-5, maximize=True)
    scheduler = PPOScheduler(optimizer, initial_lr, end_lr, total_phases)
    return optimizer, scheduler

In [33]:
class PPOTrainer:
    """
    Builds the vectorised environments, replay memory, networks, optimizer and agent from a `PPOArgs` object, and
    runs training as `total_phases` repetitions of (rollout phase, learning phase).
    """

    def __init__(self, args: PPOArgs):
        set_global_seeds(args.seed)
        self.args = args
        self.run_name = f"{args.env_id}__{args.wandb_project_name}__seed{args.seed}__{time.strftime('%Y%m%d-%H%M%S')}"
        self.envs = gym.vector.SyncVectorEnv(
            [make_env(idx=idx, run_name=self.run_name, **args.__dict__) for idx in range(args.num_envs)]
        )

        # Define some basic variables from our environment
        self.num_envs = self.envs.num_envs
        self.action_shape = self.envs.single_action_space.shape
        self.obs_shape = self.envs.single_observation_space.shape

        # Create our replay memory
        self.memory = ReplayMemory(
            self.num_envs,
            self.obs_shape,
            self.action_shape,
            args.batch_size,
            args.minibatch_size,
            args.batches_per_learning_phase,
            args.seed,
        )

        # Create our networks & optimizer. The scheduler is stepped once per learning phase (see `learning_phase`),
        # so it must be given the number of phases -- not the number of optimizer steps -- for the learning rate to
        # reach its end value when training finishes.
        self.actor, self.critic = get_actor_and_critic(self.envs, mode=args.mode)
        self.optimizer, self.scheduler = make_optimizer(self.actor, self.critic, args.total_phases, args.lr)

        # Create our agent
        self.agent = PPOAgent(self.envs, self.actor, self.critic, self.memory)

    def _unique_parameters(self) -> list[nn.Parameter]:
        """Returns the actor and critic parameters, listing tensors shared by both networks only once."""
        seen_ids: set[int] = set()
        unique = []
        for param in itertools.chain(self.actor.parameters(), self.critic.parameters()):
            if id(param) not in seen_ids:
                seen_ids.add(id(param))
                unique.append(param)
        return unique

    def rollout_phase(self) -> dict | None:
        """
        This function populates the memory with a new set of experiences, using `self.agent.play_step` to step through
        the environment. It also returns a dict of data which you can include in your progress bar postfix.

        The returned dict is the most recent non-None result of `get_episode_data_from_infos` during this rollout, or
        None if there was none.
        """
        data = None
        t0 = time.time()

        for _ in range(self.args.num_steps_per_rollout):
            # Play a step, returning the infos dict (containing information for each environment)
            infos = self.agent.play_step()

            # Get data from environments, and log it if some environment did actually terminate
            new_data = get_episode_data_from_infos(infos)
            if new_data is not None:
                data = new_data
                if self.args.use_wandb:
                    wandb.log(new_data, step=self.agent.step)

        if self.args.use_wandb:
            # Environment steps per second over this rollout
            wandb.log(
                {"SPS": (self.args.num_steps_per_rollout * self.num_envs) / (time.time() - t0)}, step=self.agent.step
            )

        return data

    def learning_phase(self) -> None:
        """
        This function does the following:
            - Generates minibatches from memory
            - Calculates the objective function, and takes an optimization step based on it
            - Clips the gradients (see detail #11)
            - Steps the learning rate scheduler (once, after all minibatches have been processed)
        """
        minibatches = self.agent.get_minibatches(self.args.gamma, self.args.gae_lambda)
        # De-duplicated so that a parameter shared by actor and critic contributes once to the gradient norm and is
        # rescaled once by the clipping
        params = self._unique_parameters()
        for minibatch in minibatches:
            objective_fn = self.compute_ppo_objective(minibatch)
            objective_fn.backward()
            nn.utils.clip_grad_norm_(params, self.args.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
        self.scheduler.step()

    def compute_ppo_objective(self, minibatch: ReplayMinibatch) -> Float[Tensor, ""]:
        """
        Handles learning phase for a single minibatch. Returns objective function to be maximized.

        The objective is `clipped surrogate - value loss + entropy bonus`. When wandb logging is enabled, diagnostic
        values for this minibatch are logged as a side effect.
        """
        logits = self.actor(minibatch.obs)
        dist = Categorical(logits=logits)
        # flatten (rather than squeeze) so a minibatch of size 1 still yields a 1-D tensor
        values = self.critic(minibatch.obs).flatten()

        clipped_surrogate_objective = calc_clipped_surrogate_objective(
            dist, minibatch.actions, minibatch.advantages, minibatch.logprobs, self.args.clip_coef
        )
        value_loss = calc_value_function_loss(values, minibatch.returns, self.args.vf_coef)
        entropy_bonus = calc_entropy_bonus(dist, self.args.ent_coef)

        # The value loss is subtracted because the total is maximised
        total_objective_function = clipped_surrogate_objective - value_loss + entropy_bonus

        if self.args.use_wandb:
            # Diagnostics only; nothing in here takes part in the gradient computation
            with t.inference_mode():
                logratio = dist.log_prob(minibatch.actions) - minibatch.logprobs
                ratio = logratio.exp()
                approx_kl = (ratio - 1 - logratio).mean().item()
                clipfrac = ((ratio - 1.0).abs() > self.args.clip_coef).float().mean().item()
            wandb.log(
                dict(
                    total_steps=self.agent.step,
                    values=values.mean().item(),
                    lr=self.scheduler.optimizer.param_groups[0]["lr"],
                    value_loss=value_loss.item(),
                    clipped_surrogate_objective=clipped_surrogate_objective.item(),
                    entropy=entropy_bonus.item(),
                    approx_kl=approx_kl,
                    clipfrac=clipfrac,
                ),
                step=self.agent.step,
            )

        return total_objective_function

    def train(self) -> None:
        """
        Runs `total_phases` iterations of rollout followed by learning.

        The environments are closed and the wandb run (if any) is finished even when training stops early, e.g. on a
        KeyboardInterrupt.
        """
        # Use this trainer's own args rather than whatever module-level `args` happens to exist
        if self.args.use_wandb:
            wandb.init(
                project=self.args.wandb_project_name,
                entity=self.args.wandb_entity,
                name=self.run_name,
                monitor_gym=self.args.video_log_freq is not None,
            )
            wandb.watch([self.actor, self.critic], log="all", log_freq=50)

        pbar = tqdm(range(self.args.total_phases))
        last_logged_time = time.time()  # so we don't update the progress bar too much

        try:
            for phase in pbar:
                data = self.rollout_phase()
                if data is not None and time.time() - last_logged_time > 0.5:
                    last_logged_time = time.time()
                    pbar.set_postfix(phase=phase, **data)

                self.learning_phase()
        finally:
            self.envs.close()
            if self.args.use_wandb:
                wandb.finish()

In [34]:
# args = PPOArgs(use_wandb=True, video_log_freq=50)
# trainer = PPOTrainer(args)
# trainer.train()

# Atari

In [35]:
# Show the ids of every environment currently registered with gymnasium (the output is long)
gym.envs.registration.registry.keys()

dict_keys(['CartPole-v0', 'CartPole-v1', 'MountainCar-v0', 'MountainCarContinuous-v0', 'Pendulum-v1', 'Acrobot-v1', 'phys2d/CartPole-v0', 'phys2d/CartPole-v1', 'phys2d/Pendulum-v0', 'LunarLander-v2', 'LunarLanderContinuous-v2', 'BipedalWalker-v3', 'BipedalWalkerHardcore-v3', 'CarRacing-v2', 'Blackjack-v1', 'FrozenLake-v1', 'FrozenLake8x8-v1', 'CliffWalking-v0', 'Taxi-v3', 'tabular/Blackjack-v0', 'tabular/CliffWalking-v0', 'Reacher-v2', 'Reacher-v4', 'Pusher-v2', 'Pusher-v4', 'InvertedPendulum-v2', 'InvertedPendulum-v4', 'InvertedDoublePendulum-v2', 'InvertedDoublePendulum-v4', 'HalfCheetah-v2', 'HalfCheetah-v3', 'HalfCheetah-v4', 'Hopper-v2', 'Hopper-v3', 'Hopper-v4', 'Swimmer-v2', 'Swimmer-v3', 'Swimmer-v4', 'Walker2d-v2', 'Walker2d-v3', 'Walker2d-v4', 'Ant-v2', 'Ant-v3', 'Ant-v4', 'Humanoid-v2', 'Humanoid-v3', 'Humanoid-v4', 'HumanoidStandup-v2', 'HumanoidStandup-v4', 'GymV26Environment-v0', 'GymV21Environment-v0', 'Adventure-v0', 'AdventureDeterministic-v0', 'AdventureNoFrameskip-v0

In [36]:
# Create the raw (unwrapped by us) Breakout environment; "rgb_array" rendering returns frames as arrays
env = gym.make("ALE/Breakout-v5", render_mode="rgb_array")

print(env.action_space)  # Discrete(4): 4 actions to choose from
print(env.observation_space)  # Box(0, 255, (210, 160, 3), uint8): an RGB image of the game screen

Discrete(4)
Box(0, 255, (210, 160, 3), uint8)


In [37]:
# Human-readable name of each of the discrete actions, in action-index order
print(env.get_action_meanings())

['NOOP', 'FIRE', 'RIGHT', 'LEFT']


In [38]:
def display_frames(frames: Int[Arr, "timesteps height width channels"], figsize=(4, 5)):
    """
    Shows `frames` as an inline JavaScript animation in the notebook, one frame every 100 ms.

    `frames` is indexed along its first axis; each entry must be something `imshow` can draw.
    """
    fig, ax = plt.subplots(figsize=figsize)
    image = ax.imshow(frames[0])
    # Close the figure that was just created -- presumably so the notebook does not also render it as a static plot
    plt.close()

    def draw(frame):
        # Swap the pixel data of the existing image artist and report it as the only artist that changed
        image.set_array(frame)
        return [image]

    animation = FuncAnimation(fig, draw, frames=frames, interval=100)
    animation_html = animation.to_jshtml()
    display(HTML(animation_html))


# Number of random-policy steps to record
nsteps = 150

# Roll out a uniformly random policy on the raw environment and keep every observed frame
frames = []
obs, info = env.reset()
for _ in tqdm(range(nsteps)):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    frames.append(obs)
    # An environment must be reset once its episode has ended; stepping it again without a reset is not supported
    if terminated or truncated:
        obs, info = env.reset()

display_frames(np.stack(frames))

100%|██████████| 150/150 [00:00<00:00, 4329.02it/s]


In [44]:
# Wrap the raw environment with the preprocessing from `utils.prepare_atari_env`.
# NOTE: this cell relies on `env` and `nsteps` defined in earlier cells.
env_wrapped = prepare_atari_env(env)

# Same random rollout as before, this time recording the preprocessed observations
frames = []
obs, info = env_wrapped.reset()
for _ in tqdm(range(nsteps)):
    action = env_wrapped.action_space.sample()
    obs, reward, terminated, truncated, info = env_wrapped.step(action)
    # Lay the stacked frames out side by side and add a trailing axis of size 3 holding copies of the values
    tiled = einops.repeat(np.array(obs), "frames h w -> h (frames w) 3")  # stack frames across the row
    frames.append(tiled)
    # An environment must be reset once its episode has ended; stepping it again without a reset is not supported
    if terminated or truncated:
        obs, info = env_wrapped.reset()

display_frames(np.stack(frames), figsize=(12, 3))

100%|██████████| 150/150 [00:00<00:00, 1972.60it/s]


In [40]:
def get_actor_and_critic_atari(obs_shape: tuple[int, ...], num_actions: int) -> tuple[nn.Sequential, nn.Sequential]:
    """
    Returns (actor, critic) in the "atari" case: a convolutional feature extractor followed by one linear head each.

    obs_shape:
        (channels, height, width) of a single observation. The number of input channels and the size of the first
        linear layer are derived from it, so frame stacks other than 4 and non-square images are supported.
    num_actions:
        number of logits produced by the actor head.

    The actor and critic share the *same* feature-extractor module (not copies), so its parameters appear in both
    networks' `.parameters()`.

    Raises an AssertionError if `obs_shape` does not have 3 entries or the image is too small for the three
    convolutions.
    """
    assert len(obs_shape) == 3, f"expected obs_shape (channels, height, width), got {obs_shape}"
    in_channels, height, width = (int(dim) for dim in obs_shape)

    # (kernel_size, stride) of the three unpadded convolutions below
    conv_specs = ((8, 4), (4, 2), (3, 1))

    def size_after_convolutions(size: int) -> int:
        # Output length of an unpadded convolution is floor((size - kernel) / stride) + 1
        for kernel, stride in conv_specs:
            size = (size - kernel) // stride + 1
            assert size > 0, f"observation of shape {obs_shape} is too small for the convolutional stack"
        return size

    in_features = 64 * size_after_convolutions(height) * size_after_convolutions(width)

    hidden = nn.Sequential(
        layer_init(nn.Conv2d(in_channels, 32, 8, stride=4, padding=0)),
        nn.ReLU(),
        layer_init(nn.Conv2d(32, 64, 4, stride=2, padding=0)),
        nn.ReLU(),
        layer_init(nn.Conv2d(64, 64, 3, stride=1, padding=0)),
        nn.ReLU(),
        nn.Flatten(),
        layer_init(nn.Linear(in_features, 512)),
        nn.ReLU(),
    )

    actor = nn.Sequential(hidden, layer_init(nn.Linear(512, num_actions), std=0.01))
    critic = nn.Sequential(hidden, layer_init(nn.Linear(512, 1), std=1))

    return actor, critic

In [41]:
# Configuration for the Breakout run; fields not listed here keep their PPOArgs defaults.
# Note this rebinds the module-level name `args` used by earlier cells.
args = PPOArgs(
    env_id="ALE/Breakout-v5",
    wandb_project_name="PPOAtari",
    use_wandb=True,
    mode="atari",
    clip_coef=0.1,
    num_envs=8,
    video_log_freq=25,
)
# Constructing the trainer already creates the environments, networks, optimizer and agent
trainer = PPOTrainer(args)

In [42]:
# Long-running cell: the recorded output below shows the run being stopped by a KeyboardInterrupt at phase 327 of 488
trainer.train()

 67%|██████▋   | 327/488 [46:48<23:02,  8.59s/it, episode_duration=1, episode_length=241, episode_reward=2, phase=326]    


KeyboardInterrupt: 