# Install and Import Dependencies

In [2]:
!pip install gymnasium torch numpy mediapy

Collecting mediapy
  Downloading mediapy-1.2.2-py3-none-any.whl.metadata (4.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting n

In [1]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
from copy import deepcopy
import mediapy
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Create and Visualize Environment
In this section, we create a cartpole environment, in which the users choose to move the cartpole left or right to balance it. We visualize an episode with a random policy.

In [2]:
env = gym.make('CartPole-v1', render_mode='rgb_array')
print("State Space Shape:", env.observation_space.shape)
print("Number of Actions:", env.action_space.n)
env.reset()

# visualize environment with random actions
done = False
truncated = False
images = [env.render()]

while not (done or truncated):
    action = env.action_space.sample()
    state, reward, done, truncated, info = env.step(action)
    img = env.render()
    images.append(img)

env.close()

mediapy.show_video(images, fps=30)

State Space Shape: (4,)
Number of Actions: 2


0
This browser does not support the video tag.


In [3]:
# Helper Functions
def to_tensor(x):
    if isinstance(x, np.ndarray):
        return torch.tensor(x, dtype=torch.float32, device=device)
    elif isinstance(x, torch.Tensor):
        return x.to(device)
    else:
        x = np.array(x)
        return torch.tensor(x, dtype=torch.float32, device=device)


## DQN Implementation

In [4]:
# Hyperparameters
GAMMA = 0.99  # Discount factor
LR = 0.001  # Learning rate
MAX_EPSILON = 1.0  # Maximum epsilon
MIN_EPSILON = 0.01  # Minimum epsilon
EPSILON_STEPS = 12500 # Epsilon dacay step
BATCH_SIZE = 128  # Batch size
MEMORY_SIZE = 50000  # Replay buffer size
TRAIN_STEPS = 50000  # Number of episodes to train the agent
EVAL_EPISODES = 10  # Number of episodes to evaluate the agent
TARGET_UPDATE_FREQ = 3  # Target network update frequency
EVAL_FREQUENCY = 2000 # Evaluation frequency

In [5]:
# Define Q-Network (Neural Network)
class QNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(QNetwork, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ELU(),
            nn.Linear(64, 64),
            nn.ELU(),
            nn.Linear(64, output_dim)
        )

    def forward(self, x):
        return self.model(x)

In [6]:
# Experience Replay Buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def size(self):
        return len(self.buffer)

In [7]:
# DQN Agent
class DQNAgent:
    def __init__(self, env):
        self.env = env
        self.test_env = deepcopy(env)
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n

        self.q_network = QNetwork(self.state_dim, self.action_dim).to(device)
        self.target_network = deepcopy(self.q_network).to(device)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=LR)
        self.memory = ReplayBuffer(MEMORY_SIZE)
        self.epsilon = MAX_EPSILON

    def select_action(self, state, eval=False):
        if random.random() < self.epsilon and not eval:
            return self.env.action_space.sample()  # Random action (exploration)
        else:
            state = to_tensor(state).unsqueeze(0)
            q_values = self.q_network(state)
            return torch.argmax(q_values, dim=1).item()  # Action with max Q-value (exploitation)

    def update_target_network(self):
        with torch.no_grad():
            for target_param, source_param in zip(self.target_network.parameters(), self.q_network.parameters()):
                target_param.data.copy_(0.9 * target_param.data + 0.1 * source_param.data)

    def train(self):
        if self.memory.size() < BATCH_SIZE:
            return

        # Sample a batch of experiences
        batch = self.memory.sample(BATCH_SIZE)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = to_tensor(states)
        actions = to_tensor(actions).long()
        rewards = to_tensor(rewards)
        next_states = to_tensor(next_states)
        dones = to_tensor(dones)

        # Compute Q-values for current states
        q_values = self.q_network(states)
        q_value = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Compute target Q-values
        next_q_values = self.target_network(next_states)
        next_q_value = next_q_values.max(1)[0]
        target = rewards + GAMMA * next_q_value * (1 - dones)

        # Compute loss
        loss = nn.MSELoss()(q_value, target)

        # Optimize the Q-network
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.decay_epsilon()

    def evaluate(self):
        rewards = []
        for episode in range(EVAL_EPISODES):
            state, info = self.test_env.reset(seed=episode+100)
            done = False
            truncated = False
            episode_reward = 0

            while not (done or truncated):
                with torch.no_grad():
                    action = self.select_action(state, eval=True)
                next_state, reward, done, truncated, info = self.test_env.step(action)
                state = next_state
                episode_reward += reward

            rewards.append(episode_reward)

        return np.mean(rewards)

    def vis_policy(self):
        rec_env = gym.make('CartPole-v1', render_mode='rgb_array')
        state, _ = rec_env.reset()
        done = False
        truncated = False
        images = [rec_env.render()]
        episode_reward = 0

        while not (done or truncated):
            action = self.select_action(state, eval=True)
            state, reward, done, truncated, info = rec_env.step(action)
            img = rec_env.render()
            images.append(img)
            episode_reward += reward
        print("Reward:", episode_reward)
        return images

    def decay_epsilon(self):
        if self.epsilon > MIN_EPSILON:
            self.epsilon -= 1 / EPSILON_STEPS

    def load(self, name):
        self.q_network.load_state_dict(torch.load(os.path.join('models', name)))

    def save(self, name):
        os.makedirs('models', exist_ok=True)
        torch.save(self.q_network.state_dict(), os.path.join('models', name))

    def get_model(self):
        return self.q_network.state_dict()

    def load_model(self, model):
        self.q_network.load_state_dict(model)


In [8]:
# Main loop
def train_dqn():
    env = gym.make('CartPole-v1')
    agent = DQNAgent(env)
    state, _ = env.reset()
    best_reward = -np.inf
    done = False
    truncated = False

    for step in range(1, TRAIN_STEPS+1):
        if done or truncated:
            state, _ = env.reset()
            done = False
            truncated = False

        action = agent.select_action(state)
        next_state, reward, done, truncated, _ = env.step(action)

        agent.memory.push(state, action, reward, next_state, done or truncated)
        agent.train()
        state = next_state


        # Update the target network periodically
        if step % TARGET_UPDATE_FREQ == 0:
            agent.update_target_network()

        # Evaluate the agent periodically
        if step % EVAL_FREQUENCY == 0:
            eval_reward = agent.evaluate()
            if eval_reward > best_reward:
                best_reward = eval_reward
                best_model = agent.get_model()
            print(f"Step {step}, Evaluation Reward: {eval_reward}, Best Reward: {best_reward}")

    agent.load_model(best_model)
    images = agent.vis_policy()
    env.close()

    mediapy.show_video(images, fps=30)




In [9]:
# Train a DQN agent and visualize the best policy
train_dqn()

Step 2000, Evaluation Reward: 206.5, Best Reward: 206.5
Step 4000, Evaluation Reward: 302.4, Best Reward: 302.4
Step 6000, Evaluation Reward: 227.6, Best Reward: 302.4
Step 8000, Evaluation Reward: 183.7, Best Reward: 302.4
Step 10000, Evaluation Reward: 174.5, Best Reward: 302.4
Step 12000, Evaluation Reward: 148.3, Best Reward: 302.4
Step 14000, Evaluation Reward: 120.5, Best Reward: 302.4
Step 16000, Evaluation Reward: 136.1, Best Reward: 302.4
Step 18000, Evaluation Reward: 167.2, Best Reward: 302.4
Step 20000, Evaluation Reward: 252.6, Best Reward: 302.4
Step 22000, Evaluation Reward: 500.0, Best Reward: 500.0
Step 24000, Evaluation Reward: 131.8, Best Reward: 500.0
Step 26000, Evaluation Reward: 263.8, Best Reward: 500.0
Step 28000, Evaluation Reward: 500.0, Best Reward: 500.0
Step 30000, Evaluation Reward: 195.5, Best Reward: 500.0
Step 32000, Evaluation Reward: 371.0, Best Reward: 500.0
Step 34000, Evaluation Reward: 500.0, Best Reward: 500.0
Step 36000, Evaluation Reward: 500.

0
This browser does not support the video tag.


## Double DQN Implementation
In this section, we implement Double DQN based on DQN. Double DQN uses the main network to select the action and the target network to evaluate its value in Q-value updates.

In [10]:
# Double DQN Agent
class DoubleDQNAgent(DQNAgent):
    def train(self):
        if self.memory.size() < BATCH_SIZE:
            return

        # Sample a batch of experiences
        batch = self.memory.sample(BATCH_SIZE)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = to_tensor(states)
        actions = to_tensor(actions).long()
        rewards = to_tensor(rewards)
        next_states = to_tensor(next_states)
        dones = to_tensor(dones)

        # Compute Q-values for current states
        q_values = self.q_network(states)
        q_value = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        ### !!! IMPORTANT !!! ###
        # Compute target Q-values for Double DQN
        next_q_values = self.target_network(next_states)
        next_q_value = next_q_values.gather(1, self.q_network(next_states).argmax(dim=1).unsqueeze(1)).squeeze(1)
        target = rewards + GAMMA * next_q_value * (1 - dones)

        # Compute loss
        loss = nn.MSELoss()(q_value, target)

        # Optimize the Q-network
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.decay_epsilon()

In [11]:
# Main loop
def train_double_dqn():
    env = gym.make('CartPole-v1')
    agent = DoubleDQNAgent(env)
    state, _ = env.reset()
    best_reward = -np.inf
    done = False
    truncated = False

    for step in range(1, TRAIN_STEPS+1):
        if done or truncated:
            state, _ = env.reset()
            done = False
            truncated = False

        action = agent.select_action(state)
        next_state, reward, done, truncated, _ = env.step(action)

        agent.memory.push(state, action, reward, next_state, done or truncated)
        agent.train()
        state = next_state


        # Update the target network periodically
        if step % TARGET_UPDATE_FREQ == 0:
            agent.update_target_network()

        # Evaluate the agent periodically
        if step % EVAL_FREQUENCY == 0:
            eval_reward = agent.evaluate()
            if eval_reward > best_reward:
                best_reward = eval_reward
                best_model = agent.get_model()
            print(f"Step {step}, Evaluation Reward: {eval_reward}, Best Reward: {best_reward}")

    agent.load_model(best_model)
    images = agent.vis_policy()
    env.close()

    mediapy.show_video(images, fps=30)

In [12]:
# Train a Double DQN agent and visualize the best policy
train_double_dqn()

Step 2000, Evaluation Reward: 159.8, Best Reward: 159.8
Step 4000, Evaluation Reward: 229.8, Best Reward: 229.8
Step 6000, Evaluation Reward: 194.8, Best Reward: 229.8
Step 8000, Evaluation Reward: 313.0, Best Reward: 313.0
Step 10000, Evaluation Reward: 204.7, Best Reward: 313.0
Step 12000, Evaluation Reward: 114.0, Best Reward: 313.0
Step 14000, Evaluation Reward: 159.9, Best Reward: 313.0
Step 16000, Evaluation Reward: 228.1, Best Reward: 313.0
Step 18000, Evaluation Reward: 169.5, Best Reward: 313.0
Step 20000, Evaluation Reward: 139.3, Best Reward: 313.0
Step 22000, Evaluation Reward: 119.1, Best Reward: 313.0
Step 24000, Evaluation Reward: 124.1, Best Reward: 313.0
Step 26000, Evaluation Reward: 452.3, Best Reward: 452.3
Step 28000, Evaluation Reward: 500.0, Best Reward: 500.0
Step 30000, Evaluation Reward: 500.0, Best Reward: 500.0
Step 32000, Evaluation Reward: 500.0, Best Reward: 500.0
Step 34000, Evaluation Reward: 500.0, Best Reward: 500.0
Step 36000, Evaluation Reward: 173.

0
This browser does not support the video tag.


## Dueling DQN Implementation
In this section, we implement Dueling DQN based on DQN. Dueling DQN separately estimates state value and advantage functions with two neural networks.

In [13]:
### !!! IMPORTANT !!! ###
# Define Dueling Q-Network (Neural Network)
class DuelingQNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DuelingQNetwork, self).__init__()
        self.feature_model = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ELU(),
        )
        self.value = nn.Sequential(
            nn.Linear(64, 64),
            nn.ELU(),
            nn.Linear(64, output_dim)
        )
        self.advantage = nn.Sequential(
            nn.Linear(64, 64),
            nn.ELU(),
            nn.Linear(64, output_dim)
        )

    def forward(self, x):
        feat = self.feature_model(x)
        value = self.value(feat)
        advantage = self.advantage(feat) - self.advantage(feat).mean()
        return value + advantage

In [14]:
# Dueling DQN Agent
class DuelingDQNAgent(DQNAgent):
    def __init__(self, env):
        super().__init__(env)
        self.q_network = DuelingQNetwork(self.state_dim, self.action_dim).to(device)
        self.target_network = deepcopy(self.q_network).to(device)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=LR)

In [15]:
# Main loop
def train_dueling_dqn():
    env = gym.make('CartPole-v1')
    agent = DuelingDQNAgent(env)
    state, _ = env.reset()
    best_reward = -np.inf
    done = False
    truncated = False

    for step in range(1, TRAIN_STEPS+1):
        if done or truncated:
            state, _ = env.reset()
            done = False
            truncated = False

        action = agent.select_action(state)
        next_state, reward, done, truncated, _ = env.step(action)

        agent.memory.push(state, action, reward, next_state, done or truncated)
        agent.train()
        state = next_state


        # Update the target network periodically
        if step % TARGET_UPDATE_FREQ == 0:
            agent.update_target_network()

        # Evaluate the agent periodically
        if step % EVAL_FREQUENCY == 0:
            eval_reward = agent.evaluate()
            if eval_reward > best_reward:
                best_reward = eval_reward
                best_model = agent.get_model()
            print(f"Step {step}, Evaluation Reward: {eval_reward}, Best Reward: {best_reward}")

    agent.load_model(best_model)
    images = agent.vis_policy()
    env.close()

    mediapy.show_video(images, fps=30)

In [16]:
# Train a Dueling DQN agent and visualize the best policy
train_dueling_dqn()

Step 2000, Evaluation Reward: 260.8, Best Reward: 260.8
Step 4000, Evaluation Reward: 311.6, Best Reward: 311.6
Step 6000, Evaluation Reward: 351.5, Best Reward: 351.5
Step 8000, Evaluation Reward: 321.3, Best Reward: 351.5
Step 10000, Evaluation Reward: 235.8, Best Reward: 351.5
Step 12000, Evaluation Reward: 238.2, Best Reward: 351.5
Step 14000, Evaluation Reward: 265.5, Best Reward: 351.5
Step 16000, Evaluation Reward: 233.2, Best Reward: 351.5
Step 18000, Evaluation Reward: 282.8, Best Reward: 351.5
Step 20000, Evaluation Reward: 233.7, Best Reward: 351.5
Step 22000, Evaluation Reward: 204.6, Best Reward: 351.5
Step 24000, Evaluation Reward: 168.4, Best Reward: 351.5
Step 26000, Evaluation Reward: 133.2, Best Reward: 351.5
Step 28000, Evaluation Reward: 198.9, Best Reward: 351.5
Step 30000, Evaluation Reward: 500.0, Best Reward: 500.0
Step 32000, Evaluation Reward: 500.0, Best Reward: 500.0
Step 34000, Evaluation Reward: 500.0, Best Reward: 500.0
Step 36000, Evaluation Reward: 500.

0
This browser does not support the video tag.


## N-step DQN Implementation
In this section, we implement N-step DQN based on DQN. N-step DQN uses n-step TD returns, allowing it to incorporate more future rewards per update.

In [17]:
# N-Step Replay Buffer
class NStepReplayBuffer:
    def __init__(self, capacity, n_step, gamma):
        self.buffer = deque(maxlen=capacity)
        self.n_step_buffer = deque(maxlen=n_step)
        self.n_step = n_step
        self.gamma = gamma

    def push(self, state, action, reward, next_state, done):
        """Store an experience with N-Step return"""
        self.n_step_buffer.append((state, action, reward, next_state, done))

        # If we have enough steps in buffer, add the N-Step transition to main memory
        if len(self.n_step_buffer) == self.n_step:
            state, action, rewards, next_state, done = self._get_n_step_info()
            self.buffer.append((state, action, rewards, next_state, done))

    ### !!! IMPORTANT !!! ###
    def _get_n_step_info(self):
        """Compute N-Step return"""
        state, action, reward, _, done = self.n_step_buffer[0]
        next_state = self.n_step_buffer[-1][3]
        for i in range(1, len(self.n_step_buffer)):
            reward += self.gamma ** i * self.n_step_buffer[i][2]
            done = done or self.n_step_buffer[i][4]

        return state, action, reward, next_state, done

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def size(self):
        return len(self.buffer)

In [18]:
N_STEP = 3
class NStepDQNAgent(DQNAgent):
    def __init__(self, env):
        super().__init__(env)
        self.memory = NStepReplayBuffer(MEMORY_SIZE, N_STEP, GAMMA)

In [19]:
# Main loop
def train_nstep_dqn():
    env = gym.make('CartPole-v1')
    agent = NStepDQNAgent(env)
    state, _ = env.reset()
    best_reward = -np.inf
    done = False
    truncated = False

    for step in range(1, TRAIN_STEPS+1):
        if done or truncated:
            state, _ = env.reset()
            done = False
            truncated = False

        action = agent.select_action(state)
        next_state, reward, done, truncated, _ = env.step(action)

        agent.memory.push(state, action, reward, next_state, done or truncated)
        agent.train()
        state = next_state


        # Update the target network periodically
        if step % TARGET_UPDATE_FREQ == 0:
            agent.update_target_network()

        # Evaluate the agent periodically
        if step % EVAL_FREQUENCY == 0:
            eval_reward = agent.evaluate()
            if eval_reward > best_reward:
                best_reward = eval_reward
                best_model = agent.get_model()
            print(f"Step {step}, Evaluation Reward: {eval_reward}, Best Reward: {best_reward}")

    agent.load_model(best_model)
    images = agent.vis_policy()
    env.close()

    mediapy.show_video(images, fps=30)

In [20]:
train_nstep_dqn()

Step 2000, Evaluation Reward: 221.2, Best Reward: 221.2
Step 4000, Evaluation Reward: 352.1, Best Reward: 352.1
Step 6000, Evaluation Reward: 434.3, Best Reward: 434.3
Step 8000, Evaluation Reward: 406.4, Best Reward: 434.3
Step 10000, Evaluation Reward: 401.1, Best Reward: 434.3
Step 12000, Evaluation Reward: 396.6, Best Reward: 434.3
Step 14000, Evaluation Reward: 350.4, Best Reward: 434.3
Step 16000, Evaluation Reward: 404.4, Best Reward: 434.3
Step 18000, Evaluation Reward: 465.6, Best Reward: 465.6
Step 20000, Evaluation Reward: 476.0, Best Reward: 476.0
Step 22000, Evaluation Reward: 423.9, Best Reward: 476.0
Step 24000, Evaluation Reward: 498.4, Best Reward: 498.4
Step 26000, Evaluation Reward: 478.5, Best Reward: 498.4
Step 28000, Evaluation Reward: 471.5, Best Reward: 498.4
Step 30000, Evaluation Reward: 474.7, Best Reward: 498.4
Step 32000, Evaluation Reward: 500.0, Best Reward: 500.0
Step 34000, Evaluation Reward: 500.0, Best Reward: 500.0
Step 36000, Evaluation Reward: 500.

0
This browser does not support the video tag.


## Prioritized Experiance Replay (PER) DQN Implementation
In this section, we implement PER DQN based on DQN. PER DQN prioritizes experiences with higher temporal-difference (TD) errors, increasing the likelihood of learning from more informative transitions.

In [21]:
# Prioritized Replay Buffer
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha):
        self.capacity = capacity
        self.alpha = alpha  # Prioritization parameter
        self.buffer = []
        self.priorities = np.zeros((capacity,), dtype=np.float32)
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        """Store transition with max priority"""
        max_priority = self.priorities.max() if self.buffer else 1.0  # Highest priority
        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.position] = (state, action, reward, next_state, done)

        self.priorities[self.position] = max_priority  # Assign max priority
        self.position = (self.position + 1) % self.capacity  # Circular buffer

    ### !!! IMPORTANT !!! ###
    def sample(self, batch_size, beta):
        """Sample batch with probability proportional to priority"""
        if len(self.buffer) == 0:
            return [], [], [], [], [], []

        priorities = self.priorities[:len(self.buffer)]
        probabilities = priorities ** self.alpha
        probabilities /= probabilities.sum()  # Normalize

        indices = np.random.choice(len(self.buffer), batch_size, p=probabilities)
        samples = [self.buffer[idx] for idx in indices]

        # Compute importance sampling weights
        weights = (len(self.buffer) * probabilities[indices]) ** (-beta)
        weights /= weights.max()  # Normalize

        states, actions, rewards, next_states, dones = zip(*samples)
        return states, actions, rewards, next_states, dones, indices, weights

    ### !!! IMPORTANT !!! ###
    def update_priorities(self, indices, td_errors):
        """Update priorities based on TD error"""
        for i, td_error in zip(indices, td_errors):
            self.priorities[i] = abs(td_error) + 1e-2  # Small offset to avoid zero priority

    def size(self):
        return len(self.buffer)

In [22]:
ALPHA = 0.7
BETA = 0.4

# PER DQN Agent
class PERDQNAgent(DQNAgent):
    def __init__(self, env):
        super().__init__(env)
        self.memory = PrioritizedReplayBuffer(MEMORY_SIZE, ALPHA)
        self.beta = BETA  # Importance Sampling Correction

    def train(self):
        if self.memory.size() < BATCH_SIZE:
            return

        # Sample a batch of experiences
        states, actions, rewards, next_states, dones, indices, weights = self.memory.sample(BATCH_SIZE, self.beta)

        states = to_tensor(states)
        actions = to_tensor(actions).long()
        rewards = to_tensor(rewards)
        next_states = to_tensor(next_states)
        dones = to_tensor(dones)
        weights = to_tensor(weights)

        # Compute TD error
        q_values = self.q_network(states)
        q_value = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        next_q_values = self.target_network(next_states)
        next_q_value = next_q_values.max(1)[0]
        target = rewards + GAMMA * next_q_value * (1 - dones)

        td_errors = target - q_value
        loss = (weights * td_errors.pow(2)).mean()

        # Update priorities
        self.memory.update_priorities(indices, td_errors.detach().cpu().numpy())

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

In [23]:
# Main loop
def train_per_dqn():
    env = gym.make('CartPole-v1')
    agent = PERDQNAgent(env)
    state, _ = env.reset()
    best_reward = -np.inf
    done = False
    truncated = False

    for step in range(1, TRAIN_STEPS+1):
        if done or truncated:
            state, _ = env.reset()
            done = False
            truncated = False

        action = agent.select_action(state)
        next_state, reward, done, truncated, _ = env.step(action)

        agent.memory.push(state, action, reward, next_state, done or truncated)
        agent.train()
        state = next_state


        # Update the target network periodically
        if step % TARGET_UPDATE_FREQ == 0:
            agent.update_target_network()

        # Evaluate the agent periodically
        if step % EVAL_FREQUENCY == 0:
            eval_reward = agent.evaluate()
            if eval_reward > best_reward:
                best_reward = eval_reward
                best_model = agent.get_model()
            print(f"Step {step}, Evaluation Reward: {eval_reward}, Best Reward: {best_reward}")

    agent.load_model(best_model)
    images = agent.vis_policy()
    env.close()

    mediapy.show_video(images, fps=30)

In [24]:
train_per_dqn()

Step 2000, Evaluation Reward: 255.4, Best Reward: 255.4
Step 4000, Evaluation Reward: 222.8, Best Reward: 255.4
Step 6000, Evaluation Reward: 222.4, Best Reward: 255.4
Step 8000, Evaluation Reward: 236.9, Best Reward: 255.4
Step 10000, Evaluation Reward: 74.9, Best Reward: 255.4
Step 12000, Evaluation Reward: 83.7, Best Reward: 255.4
Step 14000, Evaluation Reward: 105.6, Best Reward: 255.4
Step 16000, Evaluation Reward: 130.6, Best Reward: 255.4
Step 18000, Evaluation Reward: 118.1, Best Reward: 255.4
Step 20000, Evaluation Reward: 151.7, Best Reward: 255.4
Step 22000, Evaluation Reward: 437.0, Best Reward: 437.0
Step 24000, Evaluation Reward: 160.3, Best Reward: 437.0
Step 26000, Evaluation Reward: 127.3, Best Reward: 437.0
Step 28000, Evaluation Reward: 110.2, Best Reward: 437.0
Step 30000, Evaluation Reward: 108.1, Best Reward: 437.0
Step 32000, Evaluation Reward: 105.0, Best Reward: 437.0
Step 34000, Evaluation Reward: 105.7, Best Reward: 437.0
Step 36000, Evaluation Reward: 112.3,

0
This browser does not support the video tag.
