In [None]:
import math, random

import gymnasium as gym
import numpy as np
from collections import deque

import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd 
import torch.nn.functional as F

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
USE_CUDA = torch.cuda.is_available()


def device(inp):
    """Return `inp` moved to the GPU via `.cuda()` when CUDA is available, else `inp` itself."""
    if not USE_CUDA:
        return inp
    return inp.cuda()

In [None]:
## ENVIRONMENT

# Gymnasium task id shared by the training and visualization cells below.
env_id = "CartPole-v1"

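## Prioritized Experience Replay (PER)

Paper: Schaul et al., *Prioritized Experience Replay* (2015), https://arxiv.org/pdf/1511.05952

Transition $j$ is sampled with probability $P_j = p_j^{\alpha} / \sum_i p_i^{\alpha}$, where the priority $p_j$ is its latest absolute TD error plus a small constant. The resulting sampling bias is corrected with importance-sampling weights $w_j = (N P_j)^{-\beta} / \max_i w_i$, where $\beta$ is annealed linearly from $\beta_0$ to $1$ over training.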

In [None]:
## REPLAY BUFFER

class PrioritizedReplayBuffer:
    """Fixed-size circular replay buffer with proportional prioritization (PER).

    Priorities are stored already raised to ``prob_alpha`` (i.e. ``p_i^alpha``),
    so sampling probabilities are simply ``prio_alp / prio_alp.sum()``.

    Args:
        capacity: Maximum number of transitions kept; the oldest are overwritten.
        num_training: Number of training steps over which ``beta`` is annealed
            linearly from ``beta_0`` to 1.
        prob_alpha: Prioritization exponent ``alpha`` (0 gives uniform sampling).
        beta_0: Initial importance-sampling exponent ``beta``.
    """

    def __init__(self, capacity, num_training, prob_alpha=0.6, beta_0=0.4):
        self.prob_alpha = prob_alpha
        self.beta_0 = beta_0
        self.capacity = capacity
        self.num_training = num_training
        self.buffer = []  # (state, action, reward, next_state, done) tuples
        self.pos = 0      # next slot to write; wraps to 0 once the buffer is full
        self.prio_alp = np.zeros((capacity,), dtype=np.float32)  # p_i^alpha per slot
        self.max_prio_alp = 1.  # priority assigned to newly pushed transitions

    def push(self, batch_state, batch_action, batch_reward, batch_next_state, batch_done):
        """Store a batch of transitions, one per leading index of the arrays.

        New transitions receive ``max_prio_alp``, a running maximum of stored
        priorities that is lowered only when the slot holding it is overwritten.
        """
        num_samples = batch_state.shape[0]
        old_len = len(self.buffer)
        pos = (self.pos + np.arange(num_samples)) % self.capacity

        # Slots that already hold a transition and are about to be overwritten.
        overwritten = pos[pos < old_len]
        if overwritten.size > 0 and self.prio_alp[overwritten].max() >= self.max_prio_alp:
            # The current maximum is being evicted: recompute it over the survivors.
            # All evictions are considered at once so a slot evicted earlier in this
            # batch cannot leak its stale priority back into the maximum.
            survivors = np.zeros(self.capacity, dtype=bool)
            survivors[:old_len] = True
            survivors[pos] = False
            if survivors.any():
                self.max_prio_alp = self.prio_alp[survivors].max()
            # Otherwise every stored transition is replaced: keep the previous maximum
            # so the new transitions still get a strictly positive priority.

        for i, slot in enumerate(pos):
            transition = (
                batch_state[i].reshape(1, -1),
                batch_action[i],
                batch_reward[i],
                batch_next_state[i].reshape(1, -1),
                batch_done[i],
            )
            # Slots are filled in order 0..capacity-1 before wrapping, so a slot is
            # either already occupied or exactly the next one to append.
            if slot < len(self.buffer):
                self.buffer[slot] = transition
            else:
                self.buffer.append(transition)

        self.prio_alp[pos] = self.max_prio_alp
        self.pos = (self.pos + num_samples) % self.capacity

    def sample(self, batch_size, training_idx):
        """Draw ``batch_size`` transitions (with replacement) proportionally to priority.

        Args:
            batch_size: Number of transitions to draw.
            training_idx: Current training step, used to anneal ``beta``.

        Returns:
            ``(states, actions, rewards, next_states, dones, indices, weights)``:
            ``states``/``next_states`` concatenated along axis 0, ``actions``,
            ``rewards`` and ``dones`` as lists, ``indices`` the sampled slots (pass
            them to ``update_priorities``), and ``weights`` float32
            importance-sampling weights normalized so the largest possible weight is 1.

        Raises:
            ValueError: If the buffer is empty.
        """
        buffer_len = len(self.buffer)
        if buffer_len == 0:
            raise ValueError("cannot sample from an empty replay buffer")
        beta = min(1.0, self.beta_0 + training_idx * (1.0 - self.beta_0) / self.num_training)

        # P_j = p_j^{\alpha} / {\sum_i p_i^{\alpha}} over the filled slots only.
        # Normalizing in float64 keeps the vector within np.random.choice's
        # sum-to-one tolerance even for large buffers.
        prio_alps = self.prio_alp[:buffer_len].astype(np.float64)
        probs = prio_alps / prio_alps.sum()

        indices = np.random.choice(buffer_len, batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]

        # w_j = (N * P_j)^{-\beta}, normalized by \max_i w_i = (N * \min_i P_i)^{-\beta}
        weights = (buffer_len * probs[indices]) ** (-beta)
        weights /= (buffer_len * probs.min()) ** (-beta)
        weights = weights.astype(np.float32)

        states, actions, rewards, next_states, dones = zip(*samples)
        return (np.concatenate(states), list(actions), list(rewards),
                np.concatenate(next_states), list(dones), indices, weights)

    def update_priorities(self, batch_indices, batch_priorities):
        """Set new priorities ``p_i`` (e.g. |TD error| + eps) for the given slots.

        Values are stored as ``p_i^alpha``; ``max_prio_alp`` is raised if any stored
        value exceeds it. Priorities should be positive: a slot whose priority is 0
        can never be sampled again.
        """
        prio_alps = np.asarray(batch_priorities, dtype=np.float32) ** self.prob_alpha
        if prio_alps.size == 0:
            return
        self.prio_alp[batch_indices] = prio_alps
        # Read back from storage so the maximum is the exact float32 value that was
        # kept (and with duplicate indices only the value that actually won).
        max_prio_alps = self.prio_alp[batch_indices].max()
        if max_prio_alps > self.max_prio_alp:
            self.max_prio_alp = max_prio_alps

    def __len__(self):
        return len(self.buffer)

In [None]:
## NEURAL NETWORK

class Net(nn.Module):
    """Multilayer perceptron mapping an observation vector to one Q-value per action.

    Args:
        num_inputs: Size of the flat observation vector.
        num_actions: Number of discrete actions (output width).
        hidden_size: Width of each of the three hidden layers (default 16).
    """

    def __init__(self, num_inputs, num_actions, hidden_size=16):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(num_inputs, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_actions),
        )
        self.num_actions = num_actions

    def forward(self, x):
        """Return Q-values whose last dimension has size ``num_actions``."""
        return self.layers(x)

    def eps_act(self, state, epsilon):
        """Epsilon-greedy actions, one per row of ``state``.

        Each row independently takes the greedy action with probability
        ``1 - epsilon`` and a uniformly random action otherwise.
        """
        batch = state.shape[0]
        explore = np.random.rand(batch) <= epsilon
        greedy = self.greedy_act(state)
        random_actions = np.random.randint(0, self.num_actions, size=batch)
        return np.where(explore, random_actions, greedy)

    def greedy_act(self, state):
        """Return the arg-max-Q action for each row of ``state`` as a numpy array."""
        state = device(torch.as_tensor(state, dtype=torch.float32))
        with torch.no_grad():
            q_values = self.forward(state)
        return q_values.argmax(dim=1).cpu().numpy()

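## Double DQN (DDQN)

Paper: van Hasselt et al., *Deep Reinforcement Learning with Double Q-learning* (2015), https://arxiv.org/pdf/1509.06461

The online network (parameters $\theta_t$) selects the next action and the target network ($\theta^-_t$) evaluates it:

$$Y_t = R_{t+1} + \gamma\, Q\left(S_{t+1}, \arg\max_{a'} Q\left(S_{t+1}, a'; \theta_t\right); \theta^-_t\right)$$

The bootstrap term is dropped for transitions stored as done. The target network follows the online one via soft updates $\theta^- \leftarrow \tau\theta + (1-\tau)\theta^-$.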

In [None]:
## DDQN Agent

class DDQNAgent:
    """Double DQN agent with prioritized replay and soft target-network updates.

    Each frame takes one step in every one of ``batch_size`` vectorized copies of
    ``env_id``, stores those transitions, then does one gradient step on a
    prioritized minibatch of ``batch_size`` transitions.

    Args:
        env_id: Gymnasium environment id.
        eps: Constant exploration rate for epsilon-greedy acting.
        gamma: Discount factor.
        lr: Adam learning rate.
        num_frames: Number of vectorized environment steps (= gradient steps); also
            the horizon over which PER's ``beta`` is annealed.
        rep_buf_size: Replay buffer capacity.
        batch_size: Number of parallel environments and minibatch size.
        tau: Soft-update rate, ``target <- tau * online + (1 - tau) * target``.
    """

    def __init__(self, env_id, eps, gamma, lr, num_frames, rep_buf_size, batch_size, tau):
        self.envs = gym.vector.make(env_id, num_envs=batch_size)
        self.eps = eps
        self.gamma = gamma
        self.lr = lr
        self.num_frames = num_frames
        self.rep_buf_size = rep_buf_size
        self.batch_size = batch_size
        self.tau = tau

        num_inputs = self.envs.single_observation_space.shape[0]
        num_actions = self.envs.single_action_space.n
        self.model = device(Net(num_inputs, num_actions))
        self.target = device(Net(num_inputs, num_actions))
        self.target.load_state_dict(self.model.state_dict())  # start from identical weights
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        self.rep_buf = PrioritizedReplayBuffer(rep_buf_size, num_frames)

    def train(self, plot_every=200):
        """Run ``num_frames`` training frames, redrawing the curves every ``plot_every`` frames."""
        losses = []
        all_rewards = []
        episode_reward = np.zeros(self.batch_size)

        state, _ = self.envs.reset()
        try:
            for frame_idx in range(self.num_frames):
                action = self.model.eps_act(state, self.eps)

                next_state, reward, terminated, truncated, info = self.envs.step(action)
                done = np.logical_or(terminated, truncated)

                if "final_observation" in info:
                    # Only true terminations stop bootstrapping; time-limit truncations
                    # must bootstrap from the real last observation, not the reset one.
                    buffer_next_state = self._final_observations(next_state, info)
                    buffer_done = terminated
                else:
                    # No final observations reported: treat every episode end as
                    # terminal (identical to `terminated` when no env finished).
                    buffer_next_state, buffer_done = next_state, done
                self.rep_buf.push(state, action, reward, buffer_next_state, buffer_done)

                self._soft_update_target()

                loss = self.compute_td_loss(frame_idx)
                losses.append(loss.item())

                # Finished sub-envs were auto-reset, so next_state already starts
                # their next episode.
                state = next_state

                episode_reward += reward
                all_rewards.extend(episode_reward[done].tolist())
                episode_reward[done] = 0

                if (frame_idx + 1) % plot_every == 0:
                    self.plot_training(frame_idx, all_rewards, losses)
        finally:
            # Release the vector environment even if training is interrupted
            # (e.g. KeyboardInterrupt from the notebook).
            self.envs.close()

    @staticmethod
    def _final_observations(next_state, info):
        """Return a copy of ``next_state`` with auto-reset rows replaced by their final observation.

        Gymnasium 0.26-0.29 vector envs (the versions providing ``gym.vector.make``)
        reset finished sub-environments inside ``step``; the true last observation is
        reported in ``info["final_observation"]``, with ``None`` for unfinished envs.
        """
        real_next_state = np.array(next_state, copy=True)
        for i, final_obs in enumerate(info["final_observation"]):
            if final_obs is not None:
                real_next_state[i] = final_obs
        return real_next_state

    def _soft_update_target(self):
        """Polyak-average the online parameters into the target network, in place."""
        # Net holds only Linear/ReLU layers, so parameters cover its whole state.
        with torch.no_grad():
            for target_param, model_param in zip(self.target.parameters(), self.model.parameters()):
                target_param.lerp_(model_param, self.tau)

    def compute_td_loss(self, frame_idx):
        """Take one Double-DQN gradient step on a prioritized minibatch.

        Refreshes the sampled transitions' priorities with |TD error| + 1e-5 and
        returns the importance-weighted mean squared TD error (scalar tensor).
        """
        state, action, reward, next_state, done, indices, weights = self.rep_buf.sample(self.batch_size, frame_idx)

        state      = device(torch.as_tensor(state, dtype=torch.float32))
        next_state = device(torch.as_tensor(next_state, dtype=torch.float32))
        action     = device(torch.as_tensor(np.asarray(action), dtype=torch.long))
        reward     = device(torch.as_tensor(np.asarray(reward, dtype=np.float32)))
        done       = device(torch.as_tensor(np.asarray(done, dtype=np.float32)))
        weights    = device(torch.as_tensor(weights, dtype=torch.float32))

        q_value = self.model(state).gather(1, action.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            # Double DQN: the online network picks a', the target network evaluates it.
            next_action = self.model(next_state).argmax(dim=1, keepdim=True)
            next_q_value = self.target(next_state).gather(1, next_action).squeeze(1)
            expected_q_value = reward + self.gamma * next_q_value * (1 - done)

        td_error = expected_q_value - q_value
        loss = (td_error.pow(2) * weights).mean()

        self.rep_buf.update_priorities(indices, td_error.detach().abs().cpu().numpy() + 1e-5)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss

    @staticmethod
    def plot_training(frame_idx, rewards, losses):
        """Redraw the training curves in the notebook output area.

        Left: episode rewards averaged over consecutive blocks of 100 episodes (the
        title reports the mean of the last 10 episodes). Right: per-frame loss.
        """
        clear_output(True)
        fig, (ax_reward, ax_loss) = plt.subplots(1, 2, figsize=(20, 5))

        # Avoid numpy's "Mean of empty slice" warning before any episode finishes.
        reward_ma = np.mean(rewards[-10:]) if rewards else float('nan')
        ax_reward.set_title('episode: {}, total reward(ma-10): {}'.format(len(rewards), reward_ma))
        num_blocks = len(rewards) // 100
        ax_reward.plot(np.array(rewards[:100 * num_blocks]).reshape(-1, 100).mean(axis=1))
        ax_reward.set_xlabel('block of 100 episodes')
        ax_reward.set_ylabel('mean episode reward')

        ax_loss.set_title('frame: {}, loss(ma-10): {:.4f}'.format(frame_idx, np.mean(losses[-10:])))
        ax_loss.plot(losses)
        ax_loss.set_xlabel('frame')
        ax_loss.set_ylabel('loss')
        plt.show()

In [None]:
## Training

# `batch_size` is both the number of parallel environments and the minibatch size.
ddqn_agent = DDQNAgent(
    env_id=env_id,
    eps=0.05,
    gamma=0.99,
    lr=5e-4,
    num_frames=50_000,
    rep_buf_size=10_000,
    batch_size=128,
    tau=0.02,
)
ddqn_agent.train()

In [None]:
## Visualization (Test)

# Roll out one episode with the greedy policy. With render_mode='human' Gymnasium
# draws every frame inside `step`, so no explicit `env.render()` call is needed.
env = gym.make(env_id, render_mode='human')
try:
    state, _ = env.reset()
    episode_return = 0.0
    done = False
    while not done:
        action = ddqn_agent.model.greedy_act(np.expand_dims(state, 0))
        state, reward, terminated, truncated, _ = env.step(action[0])
        episode_return += reward
        done = terminated or truncated
finally:
    # Close the render window even if the rollout is interrupted.
    env.close()
episode_return