# Setup

In [11]:
import gymnasium as gym
import numpy as np
import torch as t
import wandb
import warnings
from dataclasses import dataclass
import time

from gymnasium.spaces import Box, Discrete
from jaxtyping import Bool, Float, Int
from torch import nn, Tensor
from tqdm import tqdm, trange

# Silence all warnings to keep cell output readable.
warnings.filterwarnings("ignore")

# Pick the best available accelerator: Apple MPS, then CUDA, then CPU.
# Every branch yields a `t.device` object, so downstream code can rely on one type
# (e.g. `device.type`).
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [12]:
# A single (non-vectorized) CartPole environment that renders frames as RGB arrays.
env = gym.make("CartPole-v1", render_mode="rgb_array")

# Action space: Discrete(2) -> push the cart left or right.
# Observation space: Box(4,) -> position, velocity, angle, angular velocity.
for space_label, space in [("Action space", env.action_space), ("Observation space", env.observation_space)]:
    print(f"{space_label}: {space}")

Action space: Discrete(2)
Observation space: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)


# Building Blocks

Setting up the QNetwork: a simple 3-layer MLP with roughly 11k parameters (10,934).

In [13]:
class QNetwork(nn.Module):
    """MLP mapping a flat observation vector to one Q-value per discrete action.

    Architecture: Linear -> ReLU -> ... -> Linear, with hidden widths (120, 84) by
    default (10,934 parameters for num_obs=4, num_actions=2).

    Args:
        num_obs: Size of the flat observation vector. A length-1 shape tuple such as
            ``env.single_observation_space.shape`` (e.g. ``(4,)``) is also accepted.
        num_actions: Number of discrete actions (the output width).
        hidden_sizes: Widths of the hidden layers. The default reproduces the original
            120/84 network, including its ``state_dict`` key names (net.0/net.2/net.4).

    Raises:
        ValueError: If ``num_obs`` is a shape with more than one dimension.
    """

    def __init__(self, num_obs, num_actions, hidden_sizes=(120, 84)):
        super().__init__()

        # Accept an observation-space shape directly; only flat observations are supported.
        if isinstance(num_obs, (tuple, list)):
            if len(num_obs) != 1:
                raise ValueError(f"QNetwork expects flat observations, got shape {tuple(num_obs)}")
            num_obs = num_obs[0]

        layers = []
        in_features = int(num_obs)
        for width in hidden_sizes:
            layers += [nn.Linear(in_features, width), nn.ReLU()]
            in_features = width
        layers.append(nn.Linear(in_features, num_actions))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values of shape (..., num_actions) for observations of shape (..., num_obs)."""
        return self.net(x)
    
# Sanity check: CartPole has 4 observation dims and 2 discrete actions.
net = QNetwork(num_obs=4, num_actions=2)
num_params = 0
for param in net.parameters():
    num_params += param.numel()
print(f"Parameters: {num_params}")


Parameters: 10934


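Defining the replay buffer. I thought the `add` function was neat, since it slices off the oldest elements once capacity is exceeded.
Since we're training on a torch device (MPS here), `sample` converts the 5 returned arrays into tensors on that device.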

I'm curious how much the buffer capacity affects the rate of catastrophic forgetting. Maybe it follows an exponential decay?

In [14]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for DQN experience replay.

    Transitions live in parallel NumPy arrays; once more than ``capacity`` have been
    added, the oldest ones are dropped.

    Args:
        obs_shape: Shape of one observation, e.g. ``(4,)`` for CartPole.
        action_shape: Shape of one action (``()`` for a Discrete action space).
        capacity: Maximum number of transitions kept. Floats such as ``1e4`` are
            accepted and converted to ``int``.
        seed: Seed for the generator used by :meth:`sample`.
        action_dtype: NumPy dtype used to store actions. Defaults to ``np.int64``
            (discrete action indices, usable directly as ``Tensor.gather`` indices).

    Raises:
        ValueError: If ``capacity`` is less than 1.
    """

    rng: np.random.Generator

    def __init__(self, obs_shape, action_shape, capacity, seed, action_dtype=np.int64):
        self.obs_shape = tuple(obs_shape)
        self.action_shape = tuple(action_shape)
        # Slicing needs an int (config values like 1e4 are floats). A capacity of 0 would
        # make the `[-0:]` trim a no-op and let the buffer grow unbounded, so reject it.
        self.capacity = int(capacity)
        if self.capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity}")
        self.seed = seed
        # The class-level annotation alone does not create the attribute; sample() needs a real generator.
        self.rng = np.random.default_rng(seed)

        # obs, actions, rewards, next_obs, terminated
        self.obs = np.empty((0, *self.obs_shape), dtype=np.float32)
        self.actions = np.empty((0, *self.action_shape), dtype=action_dtype)
        self.rewards = np.empty(0, dtype=np.float32)
        self.next_obs = np.empty((0, *self.obs_shape), dtype=np.float32)
        self.terminated = np.empty(0, dtype=bool)

    def add(self, obs, actions, rewards, next_obs, terminated):
        """Append a batch of transitions (one per env), dropping the oldest beyond capacity.

        Inputs are reshaped to a leading batch dimension, so a single un-batched
        transition (e.g. a scalar reward with a ``(4,)`` observation) is accepted too.
        """
        obs = np.asarray(obs, dtype=np.float32).reshape(-1, *self.obs_shape)
        actions = np.asarray(actions, dtype=self.actions.dtype).reshape(-1, *self.action_shape)
        rewards = np.asarray(rewards, dtype=np.float32).reshape(-1)
        next_obs = np.asarray(next_obs, dtype=np.float32).reshape(-1, *self.obs_shape)
        terminated = np.asarray(terminated, dtype=bool).reshape(-1)

        # Concatenate, then keep only the newest `capacity` rows (FIFO eviction).
        self.obs = np.concatenate((self.obs, obs))[-self.capacity:]
        self.actions = np.concatenate((self.actions, actions))[-self.capacity:]
        self.rewards = np.concatenate((self.rewards, rewards))[-self.capacity:]
        self.next_obs = np.concatenate((self.next_obs, next_obs))[-self.capacity:]
        self.terminated = np.concatenate((self.terminated, terminated))[-self.capacity:]

    def sample(self, batch_size, device):
        """Draw ``batch_size`` stored transitions uniformly at random, with replacement.

        Returns:
            ``(obs, actions, rewards, next_obs, terminated)`` tensors on ``device`` with
            shapes ``(B, *obs_shape)`` float32, ``(B, *action_shape)`` in the stored
            action dtype, ``(B,)`` float32, ``(B, *obs_shape)`` float32 and ``(B,)`` bool.

        Raises:
            ValueError: If the buffer is empty.
        """
        n_stored = len(self.obs)
        if n_stored == 0:
            raise ValueError("Cannot sample from an empty ReplayBuffer")
        # Draw only from filled rows: until the buffer is full, indices up to `capacity`
        # would run past the end of the arrays.
        indices = self.rng.integers(0, n_stored, size=batch_size)

        obs_tensor = t.tensor(self.obs[indices], dtype=t.float32, device=device)
        actions_tensor = t.tensor(self.actions[indices], device=device)
        rewards_tensor = t.tensor(self.rewards[indices], dtype=t.float32, device=device)
        next_obs_tensor = t.tensor(self.next_obs[indices], dtype=t.float32, device=device)
        terminated_tensor = t.tensor(self.terminated[indices], dtype=t.bool, device=device)

        return obs_tensor, actions_tensor, rewards_tensor, next_obs_tensor, terminated_tensor

This only made sense to me mathematically once I realized the `min(..., 1)` clamp is analogous to ReLU being `max(0, x)`.

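I still don't fully understand what's going on behind the scenes with `.detach().cpu().numpy()` and will need to dig a little deeper. Roughly: `.detach()` returns a tensor cut off from the autograd graph, `.cpu()` copies it from the accelerator to host memory, and `.numpy()` exposes that CPU tensor as a NumPy array (which only works for CPU tensors that don't require grad).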

In [15]:
def linear_schedule(curr_step, start_e, end_e, exploration_fraction, total_timesteps):
    """Linearly anneal epsilon from ``start_e`` to ``end_e``, then hold it at ``end_e``.

    The decay spans the first ``exploration_fraction * total_timesteps`` steps; the
    ``min(..., 1)`` clamp keeps epsilon at ``end_e`` afterwards. A zero-length decay
    window (``exploration_fraction == 0`` or ``total_timesteps == 0``) returns ``end_e``
    immediately instead of dividing by zero.
    """
    decay_steps = exploration_fraction * total_timesteps
    if decay_steps <= 0:
        return end_e
    return start_e + (end_e - start_e) * min(curr_step / decay_steps, 1)

# returns the sampled action for each env
def epsilon_greedy_policy(envs, q_net, obs, epsilon, rng=None):
    """Pick one action per env: uniformly random with probability ``epsilon``, else greedy.

    A single coin flip decides for the whole batch (explore for every env or exploit
    for every env); when exploring, each env gets its own random action.

    Args:
        envs: Vector env; only ``envs.single_action_space.n`` is read.
        q_net: Module mapping float32 observations of shape (num_envs, obs_dim) to
            Q-values of shape (num_envs, num_actions).
        obs: Batched observations, array-like of shape (num_envs, obs_dim).
        epsilon: Exploration probability in [0, 1].
        rng: Optional ``np.random.Generator`` for reproducibility. If omitted, the
            global NumPy RNG is used.

    Returns:
        np.ndarray of shape (num_envs,) with integer action indices.
    """
    obs = np.asarray(obs, dtype=np.float32)
    num_envs = obs.shape[0]
    num_actions = envs.single_action_space.n

    coin = rng.random() if rng is not None else np.random.random()
    if coin < epsilon:
        # Draw from an actual generator instance (np.random.Generator.integers is an
        # unbound method), one random action per env.
        if rng is not None:
            return rng.integers(0, num_actions, size=num_envs)
        return np.random.randint(0, num_actions, size=num_envs)

    # Move observations to wherever the network's parameters live (e.g. MPS/CUDA).
    first_param = next(q_net.parameters(), None)
    net_device = first_param.device if first_param is not None else t.device("cpu")
    with t.no_grad():
        q_values = q_net(t.from_numpy(obs).to(net_device))
    return q_values.argmax(dim=-1).cpu().numpy()

# Args and Agents

Defining the standard DQN arguments: global settings, wandb, durations, hyperparameters, and RL-specific values.
Learned that `@dataclass` is meant for classes that mainly hold data: it automatically generates methods like `__init__()` — but only for class attributes that have type annotations.

In [16]:
@dataclass
class DQNArgs:
    """Hyperparameters for DQN on CartPole.

    Every setting carries a type annotation because ``@dataclass`` only turns
    annotated class attributes into ``__init__`` parameters (so e.g.
    ``DQNArgs(seed=1)`` works). Step counts are ints because they are used with
    ``range`` and array slicing. ``total_training_steps`` is derived in
    ``__post_init__`` and is not an ``__init__`` argument.
    """

    # global
    seed: int = 0
    env_id: str = "CartPole-v1"

    # wandb
    wandb_project_name: str = 'DQN CartPole'
    wandb_entity: "str | None" = None
    video_log_freq: int = 50

    # durations
    total_timesteps: int = 1_000_000
    steps_per_train: int = 10
    trains_per_target_update: int = 100
    buffer_size: int = 10_000

    # optimization
    batch_size: int = 128
    learning_rate: float = 2.5e-4

    # RL-specific
    gamma: float = 0.99
    start_e: float = 1.0
    end_e: float = 0.1
    exploration_fraction: float = 0.2

    def __post_init__(self):
        # Env steps remaining after the buffer is pre-filled, divided by env steps per training step.
        self.total_training_steps = (self.total_timesteps - self.buffer_size) // self.steps_per_train


args = DQNArgs()


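Standard implementation of our DQN agent. `true_next_obs` is a copy of `next_obs` in which, for any env that just terminated or was truncated, the entry is replaced by that episode's real final observation (vector envs auto-reset, so `next_obs` already holds the first observation of the next episode). Every step, we add the transition to the buffer and advance our current observation, ready for the next step.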

In [17]:
class DQNAgent:
    """Steps a vectorized env with an epsilon-greedy policy and records transitions.

    Args:
        envs: Vector env — assumed to expose ``single_action_space`` and to return
            per-env batches from ``step`` (TODO confirm for the env passed in).
        buffer: ReplayBuffer that receives one transition per env per step.
        q_network: Q-network used for greedy actions.
        start_e, end_e, exploration_fraction, total_timesteps: Epsilon schedule
            parameters, forwarded to ``linear_schedule``.
    """

    def __init__(self, envs, buffer, q_network, start_e, end_e, exploration_fraction, total_timesteps):
        self.envs = envs
        self.buffer = buffer
        self.q_network = q_network
        self.start_e = start_e
        self.end_e = end_e
        self.exploration_fraction = exploration_fraction
        self.total_timesteps = total_timesteps

        self.step = 0
        self.obs, _ = envs.reset()
        self.epsilon = start_e

    def get_actions(self, obs):
        """Update epsilon for the current step and return one action per env for ``obs``."""
        # Epsilon is annealed by how many steps this agent has taken so far.
        self.epsilon = linear_schedule(self.step, self.start_e, self.end_e, self.exploration_fraction, self.total_timesteps)
        return epsilon_greedy_policy(self.envs, self.q_network, obs, self.epsilon)

    def play_step(self):
        """Take one step in every env, store the transitions, and advance ``self.obs``.

        Returns:
            The ``infos`` dict returned by ``envs.step``.
        """
        self.obs = np.array(self.obs, dtype=np.float32)
        actions = np.asarray(self.get_actions(self.obs))
        next_obs, reward, terminated, truncated, infos = self.envs.step(actions)

        # Auto-reset vector envs put the first obs of the *new* episode in next_obs for
        # finished envs; the buffer needs the real final obs, reported per env in
        # infos["final_observation"]. Replace only the finished entries.
        true_next_obs = np.array(next_obs, dtype=np.float32, copy=True)
        done = np.logical_or(terminated, truncated)
        if done.any() and "final_observation" in infos:
            for env_idx in np.flatnonzero(done):
                true_next_obs[env_idx] = infos["final_observation"][env_idx]

        self.buffer.add(self.obs, actions, reward, true_next_obs, terminated)
        # Continue from the post-reset observations, not the terminal ones.
        self.obs = next_obs
        self.step += 1

        return infos

Returns a dict with the episode length, reward, and duration for the first env whose episode finished this step, or `None` if no episode finished.

In [18]:
def get_episode_data_from_infos(infos):
    """Extract stats for the first finished episode found in ``infos["final_info"]``.

    Entries that are ``None`` or lack an ``"episode"`` key are skipped.

    Returns:
        ``{"episode_length", "episode_reward", "episode_duration"}`` as Python scalars
        (via ``.item()``), or ``None`` if no entry carries episode stats.
    """
    finished = (
        info for info in infos.get("final_info", [])
        if info is not None and "episode" in info
    )
    first = next(finished, None)
    if first is None:
        return None

    stats = first["episode"]
    return {
        "episode_length": stats["l"].item(),
        "episode_reward": stats["r"].item(),
        "episode_duration": stats["t"].item(),
    }

# Training

In [20]:
class DQNTrainer:
    def __init__(self, args):
        """Build the vector env, replay buffer, online/target Q-networks, optimizer and agent from ``args``."""
        self.args = args
        self.rng = np.random.default_rng(args.seed)
        # Seed torch too, so network initialization is reproducible.
        t.manual_seed(args.seed)
        self.run_name = f"{args.env_id}_{args.wandb_project_name}__seed{args.seed}__{time.strftime('%Y-%m-%d_%H-%M-%S')}"

        # The agent and buffer need a *vector* env (single_action_space, batched step
        # outputs, auto-reset with final observations); a plain `gym.make` env has none
        # of these. RecordEpisodeStatistics supplies the per-episode "episode" dict that
        # get_episode_data_from_infos looks for.
        self.envs = gym.vector.SyncVectorEnv(
            [lambda: gym.wrappers.RecordEpisodeStatistics(gym.make(args.env_id, render_mode="rgb_array"))]
        )

        action_shape = self.envs.single_action_space.shape
        num_actions = self.envs.single_action_space.n
        obs_shape = self.envs.single_observation_space.shape
        # QNetwork's first Linear layer needs an int feature count, not a shape tuple.
        num_obs = int(np.prod(obs_shape))

        self.buffer = ReplayBuffer(obs_shape, action_shape, int(args.buffer_size), args.seed)

        self.q_network = QNetwork(num_obs, num_actions).to(device)
        self.target_q_network = QNetwork(num_obs, num_actions).to(device)
        # The target network starts as an exact copy of the online network.
        self.target_q_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = t.optim.AdamW(self.q_network.parameters(), lr=args.learning_rate)

        self.agent = DQNAgent(self.envs, self.buffer, self.q_network, args.start_e, args.end_e, args.exploration_fraction, args.total_timesteps)

    def prepopulate_replay_buffer(self):
        n_steps_to_fill_buffer = self.args.buffer_size
        self.add_to_replay_buffer(n_steps_to_fill_buffer)

    def add_to_replay_buffer(self, n, verbose=False):
        """Play ``n`` agent steps, logging each finished episode's stats to wandb.

        Args:
            n: Number of steps to play (floats such as ``1e4`` are converted to int).
            verbose: Show a tqdm progress bar.

        Returns:
            Stats dict for the most recently finished episode, or ``None`` if none finished.
        """
        n = int(n)
        data = None
        t0 = time.time()

        for _ in tqdm(range(n), disable=not verbose):
            infos = self.agent.play_step()
            new_data = get_episode_data_from_infos(infos)

            if new_data is not None:
                data = new_data
                wandb.log(new_data, step=self.agent.step)

        elapsed = time.time() - t0
        # Avoid ZeroDivisionError when n == 0 or the clock did not advance.
        if elapsed > 0:
            wandb.log({"Samples per second": n / elapsed}, step=self.agent.step)
        return data
    
    