# Setup

In [12]:
import gymnasium as gym
import numpy as np
import torch as t
import wandb
import warnings

from gymnasium.spaces import Box, Discrete
from jaxtyping import Bool, Float, Int
from torch import nn, Tensor
from tqdm import tqdm, trange

# Silence all warnings so library chatter does not clutter the cell outputs.
# NOTE(review): this is a blanket filter, so deprecation notices are hidden too.
warnings.filterwarnings("ignore")

# Pick the compute device in priority order: Apple MPS, then CUDA, then CPU.
# Every branch yields a `t.device`, so downstream code always sees one type
# (the CUDA branch used to produce the bare string "cuda").
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [13]:
# Build the CartPole-v1 environment, requesting the "rgb_array" render mode.
env = gym.make("CartPole-v1", render_mode="rgb_array")

# Discrete(2): the two actions are "push left" and "push right".
print(f"Action space: {env.action_space}")
# Box(4): position, velocity, angle, angular velocity.
print(f"Observation space: {env.observation_space}")

Action space: Discrete(2)
Observation space: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)


# Building Blocks

Setting up the QNetwork: a simple 3-layer fully connected network with ~11k parameters.

In [8]:
class QNetwork(nn.Module):
    """Small fully connected network producing one output per action.

    Architecture: Linear(num_obs, 120) -> ReLU -> Linear(120, 84) -> ReLU
    -> Linear(84, num_actions).

    Args:
        num_obs: Size of the input (observation) vector.
        num_actions: Number of outputs, one per action.
    """

    def __init__(self, num_obs, num_actions):
        super().__init__()

        # Layer widths from input to output. A ReLU follows every linear
        # layer except the last, which stays linear.
        widths = [num_obs, 120, 84, num_actions]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.pop()  # remove the ReLU that would follow the output layer

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the stacked layers and return the result."""
        return self.net(x)
    
# Instantiate the network with 4 observation inputs and 2 action outputs,
# then report its total number of parameters.
net = QNetwork(num_obs=4, num_actions=2)
num_params = sum(param.numel() for param in net.parameters())
print(f"Parameters: {num_params}")


Parameters: 10934


Defining the replay buffer. I thought the `add` method was neat: slicing with `[-capacity:]` drops the oldest elements once the buffer is full.
Since we're using MPS, the five sampled arrays need to be converted to tensors on the device.

I'm curious how much the capacity affects the rate of catastrophic forgetting. Maybe exponential decay?

In [15]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Five parallel NumPy arrays hold, per transition: ``obs``, ``actions``,
    ``rewards``, ``next_obs`` and ``terminated``. New transitions are appended
    at the end; once more than ``capacity`` are stored, the oldest are dropped.

    Args:
        num_envs: Stored on the instance; not otherwise used by this class.
        obs_shape: Shape of a single observation (without the batch axis).
        action_shape: Shape of a single action (``()`` for scalar actions).
        capacity: Maximum number of transitions kept. Must be positive.
        seed: Seed for the NumPy generator that draws sample indices.

    Raises:
        ValueError: If ``capacity`` is not positive.
    """

    rng: np.random.Generator

    def __init__(self, num_envs, obs_shape, action_shape, capacity, seed):
        # `[-capacity:]` with capacity == 0 is `[-0:]`, which keeps *everything*,
        # so a non-positive capacity would silently make the buffer unbounded.
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")

        self.num_envs = num_envs
        self.obs_shape = obs_shape
        self.action_shape = action_shape
        self.capacity = capacity
        self.seed = seed
        # Seeded generator used by `sample`; without it `self.rng` is only a
        # class-level annotation and sampling fails with AttributeError.
        self.rng = np.random.default_rng(seed)

        # obs, actions, rewards, next_obs, terminated
        self.obs = np.empty((0, *self.obs_shape), dtype = np.float32)
        # NOTE(review): actions are stored and returned as float32; discrete
        # action indices need an integer cast at the call site before being
        # used for indexing - confirm against the training loop.
        self.actions = np.empty((0, *self.action_shape), dtype = np.float32)
        self.rewards = np.empty(0, dtype = np.float32)
        self.next_obs = np.empty((0, *self.obs_shape), dtype = np.float32)
        self.terminated = np.empty(0, dtype = bool)

    def _append(self, stored, new):
        """Append ``new`` to ``stored`` and keep only the newest ``capacity`` rows."""
        # Cast to the storage dtype first so np.concatenate cannot silently
        # promote the buffer (e.g. float32 + int64 -> float64).
        new = np.asarray(new, dtype = stored.dtype)
        return np.concatenate((stored, new))[-self.capacity:]

    def add(self, obs, actions, rewards, next_obs, terminated):
        """Append a batch of transitions, evicting the oldest beyond ``capacity``.

        Each argument is an array-like whose leading axis indexes transitions;
        all five are expected to share the same leading length.
        """
        self.obs = self._append(self.obs, obs)
        self.actions = self._append(self.actions, actions)
        self.rewards = self._append(self.rewards, rewards)
        self.next_obs = self._append(self.next_obs, next_obs)
        self.terminated = self._append(self.terminated, terminated)

    def sample(self, batch_size, device):
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Args:
            batch_size: Number of transitions to draw.
            device: Torch device (or device string) the tensors are created on.

        Returns:
            Tuple ``(obs, actions, rewards, next_obs, terminated)`` of tensors
            indexed by the same random indices. The first four are float32 and
            ``terminated`` is bool.

        Raises:
            ValueError: If the buffer holds no transitions yet.
        """
        size = len(self.obs)
        if size == 0:
            raise ValueError("Cannot sample from an empty ReplayBuffer")

        # Index over what is actually stored, not over `capacity`: until the
        # buffer fills up, indices in [size, capacity) would be out of bounds.
        indices = self.rng.integers(0, size, size = batch_size)

        obs_tensor = t.tensor(self.obs[indices], dtype = t.float32, device = device)
        actions_tensor = t.tensor(self.actions[indices], dtype = t.float32, device = device)
        rewards_tensor = t.tensor(self.rewards[indices], dtype = t.float32, device = device)
        next_obs_tensor = t.tensor(self.next_obs[indices], dtype = t.float32, device = device)
        terminated_tensor = t.tensor(self.terminated[indices], dtype = t.bool, device = device)

        return obs_tensor, actions_tensor, rewards_tensor, next_obs_tensor, terminated_tensor