Note: the code below is heavily over-annotated — most comments are brain-dumping notes rather than polished documentation.

# List of Tumbling Blocks

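1. **Not adding type annotations to my dataclass fields**. Without an annotation, `@dataclass` treats the assignment as a plain class-level attribute (a shared default) rather than a field, so it does not become an `__init__` parameter and cannot be set via the constructor.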
2. **Mixing up the output-layer initialisation std between the actor & critic networks**. The critic's output layer needs a larger std (e.g. 1) so it can estimate returns over a wide range. The actor's output layer needs a small std (e.g. 0.01) so the initial logits are close to zero and the policy starts out nearly uniform, which encourages exploration instead of committing early to particular actions. A small std for the actor's output layer is *one of the most important initialisation details*.

# Setup

In [1]:
import gymnasium as gym
import numpy as np
import torch as t
import torch.nn as nn
import torch.optim as optim
import wandb
import warnings
import time
import einops

from dataclasses import dataclass
from gymnasium.spaces import Box, Discrete
from jaxtyping import Bool, Float, Int
from torch import nn, Tensor
from torch.distributions.categorical import Categorical
from torch.optim.optimizer import Optimizer
from tqdm import tqdm, trange
from typing import Literal

# Notebook-wide: silence every warning (deliberately broad for a scratch notebook).
warnings.filterwarnings("ignore")

# Type alias for NumPy arrays in annotations. (A bare `Arr: np.ndarray` only
# declares an annotation and never binds the name `Arr`.)
Arr = np.ndarray

# Prefer Apple-silicon MPS, then CUDA, then CPU. Every branch yields a
# `t.device` object rather than mixing device objects with plain strings.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [2]:
# effectively @dataclass is always going to be paired with the Args class since it initalizes all of the arguments
# and gets rid of the the annoying initalizations in def __init__(self, )
@dataclass 
class PPOArgs:
    seed: int = 1
    env_id: str = "CartPole-PPO"
    mode: Literal["classic-control", "atari", "mujoco"] = "classic-control"

    total_timesteps: int = 500000
    num_envs: int = 4
    num_steps_per_rollout: int = 128
    num_minibatches: int = 4
    batches_per_learning_phase: int = 4

    lr: float = 2.5e-4
    max_grad_norm: float = 0.5

    gamma: float = 0.99
    gae_lambda: float = 0.95
    clip_coef: float = 0.2
    ent_coef: float = 0.01
    vf_coef: float = 0.25

    video_log_freq: int | None = None
    wandb_project_name: str = "Cartpole_PPO"
    wandb_entity: str = None

    # comments in reference to num_minibatches = 2
    def __post_init__(self):
        self.batch_size = self.num_steps_per_rollout * self.num_envs # 512

        self.minibatch_size = self.batch_size // self.num_minibatches # 256 
        self.total_phases = self.total_timesteps // self.batch_size # 976
        self.total_training_steps = self.total_phases * self.batches_per_learning_phase * self.num_minibatches # 7808

        self.video_save_path = "videos/cartpole_ppo"

args = PPOArgs(num_minibatches = 2)

In [3]:
# this is essentially He initialization
# orthogonality preserves magnitude and structures of signals as they propagate through the nn
def layer_init(layer, std = np.sqrt(2), bias_const = 0.0):
    t.nn.init.orthogonal_(layer.weight, std)
    t.nn.init.constant_(layer.bias, bias_const)
    return layer

def get_actor_and_critic_classic(num_obs, num_actions):
    """Build the MLP actor and critic used for classic-control environments.

    Both networks share the layout
    Linear(num_obs, 64) -> Tanh -> Linear(64, 64) -> Tanh -> Linear(64, out_dim),
    with every Linear initialised by ``layer_init``. The two networks differ only
    in their output head:
      * actor:  out_dim = num_actions, output-layer std 0.01 (near-uniform initial policy)
      * critic: out_dim = 1,           output-layer std 1

    Returns:
        (actor, critic) as two separate ``nn.Sequential`` modules.
    """
    def build(out_dim, head_std):
        # Layers are constructed and initialised left-to-right, actor before critic.
        return nn.Sequential(
            layer_init(nn.Linear(num_obs, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, out_dim), std=head_std),
        )

    policy_net = build(num_actions, 0.01)
    value_net = build(1, 1)
    return policy_net, value_net

def get_actor_and_critic(envs, mode):
    """Build the (actor, critic) networks used for PPO on ``envs``.

    Args:
        envs: vectorised env exposing ``single_observation_space`` and
            ``single_action_space``.
        mode: one of ``"classic-control"``, ``"atari"``, ``"mujoco"`` (the values of
            ``PPOArgs.mode``). ``"classic_control"`` is also accepted for
            backward compatibility.

    Returns:
        ``(actor, critic)`` modules.

    Raises:
        NotImplementedError: for ``"atari"`` and ``"mujoco"`` (not written yet).
        ValueError: for any other ``mode``.
    """
    # e.g. (4,) for CartPole; atari/mujoco shapes not yet pinned down here
    obs_shape = envs.single_observation_space.shape
    # flattened observation size; int() so nn.Linear receives a plain Python int
    num_obs = int(np.array(obs_shape).prod())

    action_space = envs.single_action_space
    # Discrete spaces: number of actions; otherwise the flattened action shape
    num_actions = int(action_space.n if isinstance(action_space, gym.spaces.Discrete)
                      else np.array(action_space.shape).prod())

    if mode in ("classic-control", "classic_control"):
        return get_actor_and_critic_classic(num_obs, num_actions)
    if mode in ("atari", "mujoco"):
        raise NotImplementedError(f"mode {mode!r} is not implemented yet")
    raise ValueError(f"unknown mode {mode!r}")

In [4]:
@t.inference_mode()
def compute_advantages(next_value, next_terminated, rewards, values, terminated, gamma, gae_lambda):
    """Compute Generalised Advantage Estimates (GAE) for one rollout.

    Computed backwards over time s:
        delta_s = r_s + gamma * (1 - d_{s+1}) * V_{s+1} - V_s
        A_s     = delta_s + gamma * gae_lambda * (1 - d_{s+1}) * A_{s+1}
    Here d_{s+1} is the termination flag of the *following* observation
    (``terminated[s+1]``). For the last step, V and d of the following
    observation come from ``next_value`` / ``next_terminated``.

    Args:
        next_value: value estimate of the observation after the final step;
            presumably shape (num_envs,).
        next_terminated: termination flags for that observation, same shape.
        rewards, values, terminated: per-step tensors with a leading time axis,
            presumably (num_steps, num_envs).
        gamma: discount factor.
        gae_lambda: GAE smoothing parameter lambda.

    Returns:
        Advantages shaped like ``deltas`` (the broadcast of rewards and values).
    """
    terminated = terminated.float()
    next_terminated = next_terminated.float()

    # Shift by one step along time: the "next" entries for step s are those of s+1,
    # with the bootstrap value/flag appended for the final step.
    next_values = t.cat([values[1:], next_value[None, :]])
    next_terminated = t.cat([terminated[1:], next_terminated[None, :]])

    # One-step TD errors; no bootstrapping across a termination.
    deltas = rewards + gamma * (1.0 - next_terminated) * next_values - values
    advantages = t.zeros_like(deltas)

    advantages[-1] = deltas[-1]

    for s in reversed(range(values.shape[0] - 1)):
        advantages[s] = deltas[s] + gamma * gae_lambda * (1.0 - next_terminated[s]) * advantages[s + 1]

    return advantages