# CartPole Overview

![alt text](images/cartpole_dqn_overview.jpg)

# Setup

In [104]:
import gymnasium as gym
import numpy as np
import torch as t
import wandb
import warnings
from dataclasses import dataclass
import time

from gymnasium.spaces import Box, Discrete
from jaxtyping import Bool, Float, Int
from torch import nn, Tensor
from tqdm import tqdm, trange

warnings.filterwarnings("ignore")
Arr: np.ndarray

device = t.device("mps") if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else t.device("cpu")
print(f"Using device: {device}")

Using device: mps


Only two actions to take: move the cart left or right. We observe 4 numbers: position, velocity, angle, and the angular velocity.

In [105]:
env = gym.make("CartPole-v1", render_mode = "rgb_array")

print(f"Action space: {env.action_space}")
print(f"Observation space: {env.observation_space}")

Action space: Discrete(2)
Observation space: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)


# Building Blocks

Setting up the QNetwork with a simple 3-layer NN with 10k parameters.

In [106]:
class QNetwork(nn.Module):
    def __init__(self, obs_shape, num_actions):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(obs_shape[0], 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_actions),
        )

    def forward(self, x):
        return self.net(x)
    
net = QNetwork(obs_shape = (4,), num_actions = 2)

num_params = sum(p.numel() for p in net.parameters())
print(f"Parameters: {num_params}")


Parameters: 10934


Defining the replay buffer. Thought the add function was neat since it slices off old elements. 
Since using mps, need to tensorify the 5 returned arrays.

I'm curious much the capacity affects the rate of catastrophic forgetting. Maybe exponential decay?

In [107]:
class ReplayBuffer:

    def __init__(self, obs_shape, action_shape, capacity, seed):
        self.obs_shape = obs_shape
        self.action_shape = action_shape
        self.capacity = int(capacity)
        self.seed = seed
        self.rng = np.random.default_rng(seed)

        # obs, actions, rewards, next_obs, terminated
        self.obs = np.empty((0, *self.obs_shape), dtype = np.float32)
        self.actions = np.empty((0, *self.action_shape), dtype = np.int32)
        self.rewards = np.empty(0, dtype = np.float32)
        self.next_obs = np.empty((0, *self.obs_shape), dtype = np.float32)
        self.terminated = np.empty(0, dtype = bool)

    def add(self, obs, actions, rewards, next_obs, terminated):
        # For single environment, we need to add batch dimension to match buffer structure
        
        # obs: shape (obs_features,) -> (1, obs_features)
        if obs.ndim == 1:
            obs = obs[np.newaxis, :]
        
        # actions: scalar or shape () -> (1,)
        if np.isscalar(actions):
            actions = np.array([actions])
        elif actions.ndim == 0:
            actions = actions[np.newaxis]
        
        # rewards: scalar or shape () -> (1,)
        if np.isscalar(rewards):
            rewards = np.array([rewards])
        elif rewards.ndim == 0:
            rewards = rewards[np.newaxis]
        
        # next_obs: shape (obs_features,) -> (1, obs_features)
        if next_obs.ndim == 1:
            next_obs = next_obs[np.newaxis, :]
        
        # terminated: scalar or shape () -> (1,)
        if np.isscalar(terminated):
            terminated = np.array([terminated])
        elif terminated.ndim == 0:
            terminated = terminated[np.newaxis]

        # Now concatenate with proper shapes
        self.obs = np.concatenate((self.obs, obs))[-self.capacity:]
        self.actions = np.concatenate((self.actions, actions))[-self.capacity:]
        self.rewards = np.concatenate((self.rewards, rewards))[-self.capacity:]
        self.next_obs = np.concatenate((self.next_obs, next_obs))[-self.capacity:]
        self.terminated = np.concatenate((self.terminated, terminated))[-self.capacity:]

    def sample(self, batch_size, device):
        # Sample from current buffer size, not capacity
        current_size = len(self.obs)
        if current_size == 0:
            raise ValueError("Cannot sample from empty buffer")
        
        indices = self.rng.integers(0, current_size, size = batch_size)

        obs_tensor = t.tensor(self.obs[indices], dtype = t.float32, device = device)
        actions_tensor = t.tensor(self.actions[indices], dtype = t.long, device = device)  # Use long for actions
        rewards_tensor = t.tensor(self.rewards[indices], dtype = t.float32, device = device)
        next_obs_tensor = t.tensor(self.next_obs[indices], dtype = t.float32, device = device)
        terminated_tensor = t.tensor(self.terminated[indices], dtype = t.bool, device = device)

        return obs_tensor, actions_tensor, rewards_tensor, next_obs_tensor, terminated_tensor

Only made sense to me mathetically until I realized it's analagous to ReLU being max(0,x)

Still don't exactly understand what's going on behind the scenes with `.detach().cpu().numpy()`, will need to dig a little deeper into the architecture

In [108]:
def linear_schedule(curr_step, start_e, end_e, exploration_fraction, total_timesteps):
    return start_e + (end_e - start_e) * min(curr_step / (exploration_fraction * total_timesteps), 1)

# returns the sampled action for each env
def epsilon_greedy_policy(envs, q_net, obs, epsilon):
    obs = t.from_numpy(obs).float().to(device)  # Move tensor to correct device
    if np.random.random() < epsilon:
        return np.random.randint(0, envs.action_space.n)
    else:
        q_values = q_net(obs)
        return q_values.argmax().detach().cpu().numpy().item()

# Args and Agents

Defining standard arguments for a DQN, for global, wandb, durations, hyperparameters, and rl-specific stuff.
Learned that `@dataclass` is for specific for classes that holds memory, automatically initializes stuff like `def __init__()`

In [109]:
@dataclass
class DQNArgs:
    
    seed = 0
    env_id = "CartPole"

    wandb_project_name = 'DQN CartPole'
    wandb_entity = None
    video_log_freq = 10000

    total_timesteps = 5e5
    steps_per_train = 1e1
    trains_per_target_update = 1e2
    buffer_size = 1e4

    batch_size = 128
    learning_rate = 2.5e-4

    gamma = 0.99
    start_e = 1.0
    end_e = 0.1
    exploration_fraction = 0.2

    def __post_init__(self):
        import pathlib
        self.total_training_steps = int((self.total_timesteps - self.buffer_size) // self.steps_per_train)
        
        # Create video save directory
        section_dir = pathlib.Path.cwd()
        self.video_save_path = section_dir / "videos"
        self.video_save_path.mkdir(exist_ok=True)

args = DQNArgs()


Standard implementation of our DQN Agent. We use true_next_obs, an augmented version of our next_obs with the information of whether we are terminated or truncated. Every single step, we add to the buffer and reset our observation, ready for the next.

In [110]:
class DQNAgent:
    def __init__(self, envs, buffer, q_network, start_e, end_e, exploration_fraction, total_timesteps):
        self.envs = envs
        self.buffer = buffer
        self.q_network = q_network
        self.start_e = start_e
        self.end_e = end_e
        self.exploration_fraction = exploration_fraction
        self.total_timesteps = total_timesteps
        
        self.step = 0
        self.obs, _ = envs.reset()
        self.epsilon = start_e

    def get_actions(self, obs):
        self.epsilon = linear_schedule(self.step, self.start_e, self.end_e, self.exploration_fraction, self.total_timesteps)
        actions = epsilon_greedy_policy(self.envs, self.q_network, self.obs, self.epsilon)
        return actions

    def play_step(self):
        self.obs = np.array(self.obs, dtype = np.float32)
        actions = self.get_actions(self.obs)
        next_obs, reward, terminated, truncated, infos = self.envs.step(actions)

        true_next_obs = next_obs.copy()
        if terminated | truncated:
            # Check if final_observation exists in infos, otherwise use next_obs
            if "final_observation" in infos:
                true_next_obs = infos["final_observation"]
            else:
                true_next_obs = next_obs

        self.buffer.add(self.obs, actions, reward, true_next_obs, terminated)
        
        # Reset environment if episode ended
        if terminated | truncated:
            self.obs, _ = self.envs.reset()
        else:
            self.obs = next_obs
        
        self.step += 1

        return infos

Gives a dict of the episode length & reward & duration for the first terminated env, or `None` if no envs terminate.

In [111]:
def get_episode_data_from_infos(infos):
    for final_info in infos.get("final_info", []):
        if final_info is not None and "episode" in final_info:
            return {"episode_length": final_info["episode"]["l"].item(), 
                    "episode_reward": final_info["episode"]["r"].item(),
                    "episode_duration": final_info["episode"]["t"].item()}

# Training

In [112]:
class DQNTrainer:
    def __init__(self, args):
        self.args = args
        self.rng = np.random.default_rng(args.seed)
        self.run_name = f"{args.env_id}_{args.wandb_project_name}__seed{args.seed}__{time.strftime('%Y-%m-%d_%H-%M-%S')}"
        self.envs = env

        # Create video recording environment
        if args.video_log_freq is not None:
            self.video_env = gym.make("CartPole-v1", render_mode="rgb_array")
            # Wrap with RecordVideo for automatic video recording
            from gymnasium.wrappers import RecordVideo
            self.video_env = RecordVideo(
                self.video_env,
                video_folder=str(args.video_save_path),
                episode_trigger=lambda x: True,  # Record every episode during evaluation
                name_prefix="dqn_cartpole"
            )
        else:
            self.video_env = None

        action_shape = self.envs.action_space.shape
        num_actions = self.envs.action_space.n
        obs_shape = self.envs.observation_space.shape

        self.buffer = ReplayBuffer(obs_shape, action_shape, args.buffer_size, args.seed)

        self.q_network = QNetwork(obs_shape, num_actions).to(device)
        self.target_network = QNetwork(obs_shape, num_actions).to(device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = t.optim.AdamW(self.q_network.parameters(), lr = args.learning_rate)

        self.agent = DQNAgent(self.envs,self.buffer, self.q_network, args.start_e, args.end_e, args.exploration_fraction, args.total_timesteps)

    def add_to_replay_buffer(self, n: int, verbose = False):
        data = None
        t0 = time.time()

        for step in tqdm(range(n), disable = not verbose):
            infos = self.agent.play_step()
            new_data = get_episode_data_from_infos(infos)

            if new_data is not None:
                data = new_data
                wandb.log(new_data, step = self.agent.step)

        wandb.log({"Samples per second": n / (time.time() - t0)}, step = self.agent.step)
        return data
    
    def prepopulate_replay_buffer(self):
        n_steps_to_fill_buffer = self.args.buffer_size
        self.add_to_replay_buffer(int(n_steps_to_fill_buffer))
    
    def record_video_episode(self):
        """Record a single episode for video logging"""
        if self.video_env is None:
            return None
            
        obs, _ = self.video_env.reset()
        episode_reward = 0
        episode_length = 0
        terminated = False
        truncated = False
        
        while not (terminated or truncated):
            obs_tensor = t.from_numpy(obs).float().to(device)
            with t.no_grad():
                q_values = self.q_network(obs_tensor)
                action = q_values.argmax().item()  # Use greedy policy for evaluation
            
            obs, reward, terminated, truncated, _ = self.video_env.step(action)
            episode_reward += reward
            episode_length += 1
            
        return {"eval_episode_reward": episode_reward, "eval_episode_length": episode_length}
    
    def log_videos_to_wandb(self):
        """Find and log the most recent video to wandb"""
        import glob
        import os
        
        if self.video_env is None:
            return
            
        # Find the most recent video file
        video_files = glob.glob(str(self.args.video_save_path / "*.mp4"))
        if video_files:
            # Get the most recent video file
            latest_video = max(video_files, key=os.path.getctime)
            
            # Log video to wandb
            wandb.log({
                "eval_video": wandb.Video(latest_video, fps=30, format="mp4")
            }, step=self.agent.step)
    
    def training_step(self, step):
        obs, actions, rewards, next_obs, terminated = self.buffer.sample(self.args.batch_size, device)

        with t.inference_mode():
            target_max = self.target_network(next_obs).max(-1).values
        predicted_q_vals = self.q_network(obs)[range(len(actions)), actions]

        td_error = rewards + (1 - terminated.float()) * self.args.gamma * target_max - predicted_q_vals
        loss = td_error.pow(2).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        if step % self.args.trains_per_target_update == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        wandb.log({"td_loss": loss, "q_values": predicted_q_vals.mean().item(), "epsilon": self.agent.epsilon}, step = self.agent.step)

    def train(self):
        wandb.init(project = self.args.wandb_project_name, 
                   entity = self.args.wandb_entity,
                   name = self.run_name,
                   monitor_gym = False)  # We'll handle video logging manually
        wandb.watch(self.q_network, log = "all", log_freq = 50)

        pbar = tqdm(range(self.args.total_training_steps))
        last_logged_time = time.time()

        for step in pbar:
            data = self.add_to_replay_buffer(int(self.args.steps_per_train))
            if data is not None and time.time() - last_logged_time > 0.50:
                last_logged_time = time.time()
                pbar.set_postfix(**data)

            self.training_step(step)
            
            # Record and log videos at specified frequency
            if (self.args.video_log_freq is not None and 
                step > 0 and 
                step % self.args.video_log_freq == 0):
                
                eval_data = self.record_video_episode()
                if eval_data is not None:
                    wandb.log(eval_data, step=self.agent.step)
                    pbar.set_description(f"Eval reward: {eval_data['eval_episode_reward']:.1f}")
                
                self.log_videos_to_wandb()

        # Record final video
        if self.args.video_log_freq is not None:
            eval_data = self.record_video_episode()
            if eval_data is not None:
                wandb.log(eval_data, step=self.agent.step)
            self.log_videos_to_wandb()

        self.envs.close()
        if self.video_env is not None:
            self.video_env.close()
        wandb.finish()

In [113]:
args = DQNArgs()
trainer = DQNTrainer(args)
trainer.train()

 20%|██        | 9999/49000 [01:46<08:24, 77.26it/s] 

MoviePy - Building video /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-0.mp4.
MoviePy - Writing video /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-0.mp4



Eval reward: 299.0:  20%|██        | 10007/49000 [01:47<36:01, 18.04it/s]

MoviePy - Done !
MoviePy - video ready /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-0.mp4


Eval reward: 299.0:  41%|████      | 20000/49000 [03:50<04:54, 98.38it/s] 

MoviePy - Building video /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-1.mp4.
MoviePy - Writing video /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-1.mp4



Eval reward: 380.0:  41%|████      | 20010/49000 [03:50<19:49, 24.38it/s]

MoviePy - Done !
MoviePy - video ready /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-1.mp4


Eval reward: 380.0:  61%|██████    | 29999/49000 [05:34<03:08, 100.64it/s]

MoviePy - Building video /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-2.mp4.
MoviePy - Writing video /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-2.mp4



Eval reward: 500.0:  61%|██████    | 30010/49000 [05:35<13:08, 24.07it/s] 

MoviePy - Done !
MoviePy - video ready /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-2.mp4


Eval reward: 500.0:  82%|████████▏ | 39991/49000 [07:25<01:31, 98.92it/s] 

MoviePy - Building video /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-3.mp4.
MoviePy - Writing video /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-3.mp4



Eval reward: 500.0:  82%|████████▏ | 40010/49000 [07:26<05:54, 25.39it/s]

MoviePy - Done !
MoviePy - video ready /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-3.mp4


Eval reward: 500.0: 100%|██████████| 49000/49000 [09:02<00:00, 90.33it/s] 


MoviePy - Building video /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-4.mp4.
MoviePy - Writing video /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-4.mp4





MoviePy - Done !
MoviePy - video ready /Users/alexwa/Documents/GitHub/rl/videos/dqn_cartpole-episode-4.mp4


0,1
Samples per second,█▆▃▄▄▂▂▂▅▃▂▁▂▃▄▁▂▃▃▃▃▂▃▃▃▃▃▃▃▃▁▃▄▃▃▂▂▃▃▃
epsilon,█▇▇▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
q_values,▁▁▃▄▄▅▅▅▆▆▆▆▆▆▆▇▆▆▆▆▆▇▇▇▇▇▇█████▇███████
td_loss,▁▁█▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁

0,1
Samples per second,1302.61934
epsilon,0.1
q_values,100.07554
td_loss,0.00183


In [None]:
test_args = DQNArgs()
test_args.buffer_size = 10
test_args.batch_size = 2

buffer = ReplayBuffer((4,), (), test_args.buffer_size, 0)

obs = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)  # Single obs
action = 1  # Single action
reward = 1.0  # Single reward
next_obs = np.array([1.1, 2.1, 3.1, 4.1], dtype=np.float32)  # Single next_obs
terminated = False  # Single terminated

buffer.add(obs, action, reward, next_obs, terminated)
print(f"Buffer size after adding one sample: {len(buffer.obs)}")
print(f"Buffer obs shape: {buffer.obs.shape}")
print(f"Buffer actions shape: {buffer.actions.shape}")

for i in range(3):
    buffer.add(obs + i, action, reward + i, next_obs + i, terminated)

print(f"Buffer size after adding 4 samples: {len(buffer.obs)}")

if len(buffer.obs) >= test_args.batch_size:
    sample_obs, sample_actions, sample_rewards, sample_next_obs, sample_terminated = buffer.sample(test_args.batch_size, device)
    print(f"Sample shapes - obs: {sample_obs.shape}, actions: {sample_actions.shape}, rewards: {sample_rewards.shape}")
    print("Buffer test successful!")
else:
    print("Not enough samples in buffer to test sampling")

print("ReplayBuffer fixes work correctly!")


Testing ReplayBuffer fixes...
Buffer size after adding one sample: 1
Buffer obs shape: (1, 4)
Buffer actions shape: (1,)
Buffer size after adding 4 samples: 4
Sample shapes - obs: torch.Size([2, 4]), actions: torch.Size([2]), rewards: torch.Size([2])
Buffer test successful!
ReplayBuffer fixes work correctly!
