In [None]:
import math, random

import gymnasium as gym
import numpy as np
from collections import deque

import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd 
import torch.nn.functional as F

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
USE_CUDA = torch.cuda.is_available()


def device(inp):
    """Move a tensor or module to the GPU when CUDA is available.

    Args:
        inp: Anything with a ``.cuda()`` method (torch tensor or ``nn.Module``).

    Returns:
        ``inp.cuda()`` if CUDA is available, otherwise ``inp`` itself, unchanged.
    """
    return inp.cuda() if USE_CUDA else inp

In [None]:
## ENVIRONMENT
# Gymnasium environment id, shared by the training run and the visual rollout.
env_id = 'CartPole-v1'

In [None]:
## REPLAY BUFFER

class ReplayBuffer:
    """Fixed-capacity experience-replay store of (s, a, r, s', done) transitions."""

    def __init__(self, capacity):
        # A bounded deque drops the oldest transition automatically once it is full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one transition; observations get a leading axis so batches can be concatenated."""
        transition = (
            np.expand_dims(state, 0),
            action,
            reward,
            np.expand_dims(next_state, 0),
            done,
        )
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            (states, actions, rewards, next_states, dones) where the observation
            entries are concatenated arrays and the other three are tuples.
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        stacked_states = np.concatenate(states)
        stacked_next_states = np.concatenate(next_states)
        return stacked_states, actions, rewards, stacked_next_states, dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

In [None]:
## NEURAL NETWORK

class Net(nn.Module):
    """MLP Q-network: three hidden ReLU layers of width 16, one output per action."""

    def __init__(self, num_inputs, num_actions):
        super().__init__()
        hidden = 16
        self.layers = nn.Sequential(
            nn.Linear(num_inputs, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, num_actions),
        )
        self.num_actions = num_actions

    def forward(self, x):
        """Return Q-values of shape (batch, num_actions) for a batch of observations."""
        return self.layers(x)

    def eps_act(self, state, epsilon):
        """Epsilon-greedy actions for a batch of observations.

        Each row independently takes the greedy action with probability
        ``1 - epsilon`` and a uniformly random action otherwise.

        Args:
            state: 2-D array of observations, one row per sample.
            epsilon: Exploration probability.

        Returns:
            Integer NumPy array with one action per row of ``state``.
        """
        batch = state.shape[0]
        exploit = np.random.rand(batch) > epsilon
        greedy = self.greedy_act(state)
        explore = np.random.randint(0, self.num_actions, size=batch)
        return np.where(exploit, greedy, explore)

    def greedy_act(self, state):
        """Index of the highest Q-value for each row of ``state``, as a NumPy array."""
        inputs = device(torch.FloatTensor(state))
        with torch.no_grad():
            q_values = self.forward(inputs)
        _, best = q_values.max(dim=1)
        return best.cpu().numpy()

DQN

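Paper: https://arxiv.org/pdf/1312.5602v1.pdf (Mnih et al., 2013, *Playing Atari with Deep Reinforcement Learning*). The periodically synchronized target network $\theta^-$ used below was introduced in the follow-up paper, Mnih et al., 2015, *Human-level control through deep reinforcement learning* (Nature 518).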

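Target (bootstrapping stops only on true termination, not on time-limit truncation): $y_t = R_{t+1} + \gamma\,(1 - d_{t+1})\max_{a'} Q\left(S_{t+1}, a'; \theta^-\right)$, where $d_{t+1} = 1$ if the episode terminated at step $t+1$ and $0$ otherwise. The online network is trained to minimize the mean squared error $\left(y_t - Q(S_t, A_t; \theta)\right)^2$, and $\theta^-$ is overwritten with $\theta$ every `tgt_upd_delay` frames.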

In [None]:
## DQN Agent

class DQNAgent:
    """Vanilla DQN agent: an online Q-network plus a periodically synced target network.

    One environment copy is created per minibatch slot (``num_envs=batch_size``).
    Every frame therefore pushes ``batch_size`` transitions into the replay buffer
    and then performs one gradient step.

    Args:
        env_id: Gymnasium environment id (needs a discrete action space).
        eps: Constant exploration rate for epsilon-greedy action selection.
        gamma: Discount factor.
        lr: Adam learning rate.
        num_frames: Number of vectorized environment steps to train for.
        rep_buf_size: Replay buffer capacity in transitions; must be >= batch_size.
        batch_size: Number of parallel environments AND minibatch size.
        tgt_upd_delay: Copy online weights into the target network every this many frames.

    Raises:
        ValueError: if ``rep_buf_size < batch_size`` (the buffer could never supply a minibatch).
    """

    def __init__(self, env_id, eps, gamma, lr, num_frames, rep_buf_size, batch_size, tgt_upd_delay):
        # Fail fast instead of crashing inside random.sample during training.
        if rep_buf_size < batch_size:
            raise ValueError(
                'rep_buf_size ({}) must be >= batch_size ({})'.format(rep_buf_size, batch_size))

        # NOTE(review): `gym.vector.make` and info['final_observation'] (used in train) are the
        # pre-1.0 gymnasium vector API; gymnasium 1.0 replaced them -- pin the version or port.
        self.envs = gym.vector.make(env_id, num_envs=batch_size)
        self.eps = eps
        self.gamma = gamma
        self.lr = lr
        self.num_frames = num_frames
        self.rep_buf_size = rep_buf_size
        self.batch_size = batch_size
        self.tgt_upd_delay = tgt_upd_delay

        obs_dim = self.envs.single_observation_space.shape[0]
        n_actions = self.envs.single_action_space.n
        self.model = device(Net(obs_dim, n_actions))
        self.target = device(Net(obs_dim, n_actions))
        self.target.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.criterion = nn.MSELoss()

        self.rep_buf = ReplayBuffer(rep_buf_size)

    def train(self):
        """Collect epsilon-greedy experience and take one TD update per frame.

        Progress is plotted every 200 frames. The vectorized environments are
        closed when training ends, including on error or KeyboardInterrupt. They
        are only created in ``__init__``, so an agent can be trained only once.
        """
        losses = []
        all_rewards = []
        episode_reward = np.zeros(self.batch_size)

        try:
            state, _ = self.envs.reset()
            for frame_idx in range(self.num_frames):
                action = self.model.eps_act(state, self.eps)

                next_state, reward, terminated, truncated, info = self.envs.step(action)
                # Only true termination stops bootstrapping; time-limit truncation does not.
                done = terminated
                for i in range(action.shape[0]):
                    # Finished sub-envs are auto-reset, so next_state[i] already belongs to a new
                    # episode. For truncated episodes, store the real last observation instead.
                    # For terminated episodes the stored next state is masked by (1 - done).
                    if truncated[i]:
                        transition_next = info['final_observation'][i]
                    else:
                        transition_next = next_state[i, :]
                    self.rep_buf.push(state[i, :], action[i], reward[i], transition_next, done[i])

                if frame_idx % self.tgt_upd_delay == 0:
                    self.target.load_state_dict(self.model.state_dict())
                # Learn only once a full minibatch can be sampled. With num_envs == batch_size
                # this holds from the first frame, but the code should not depend on that.
                if len(self.rep_buf) >= self.batch_size:
                    loss = self.compute_td_loss()
                    losses.append(loss.item())

                state = next_state

                episode_reward += reward
                finished = np.logical_or(done, truncated)
                all_rewards.extend(episode_reward[finished].tolist())
                episode_reward[finished] = 0

                if (frame_idx + 1) % 200 == 0:
                    self.plot_training(frame_idx, all_rewards, losses)
        finally:
            self.envs.close()

    def compute_td_loss(self):
        """Sample a minibatch, take one optimizer step on the TD error, and return the loss.

        Target: ``r + gamma * max_a' Q_target(s', a') * (1 - done)``. The target
        network is evaluated without gradients, so only ``self.model`` is updated.

        Returns:
            Scalar loss tensor (already back-propagated).
        """
        state, action, reward, next_state, done = self.rep_buf.sample(self.batch_size)

        # Convert through NumPy with explicit dtypes rather than letting the legacy
        # tensor constructors walk tuples of NumPy scalars element by element.
        state      = device(torch.as_tensor(np.asarray(state, dtype=np.float32)))
        next_state = device(torch.as_tensor(np.asarray(next_state, dtype=np.float32)))
        action     = device(torch.as_tensor(np.asarray(action, dtype=np.int64)))
        reward     = device(torch.as_tensor(np.asarray(reward, dtype=np.float32)))
        done       = device(torch.as_tensor(np.asarray(done, dtype=np.float32)))

        q_values = self.model(state)
        # Q(s, a) for the action actually taken in each transition.
        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            next_q_values = self.target(next_state)
        next_q_value = next_q_values.max(1)[0]

        expected_q_value = reward + self.gamma * next_q_value * (1 - done)

        loss = self.criterion(q_value, expected_q_value)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss

    @staticmethod
    def plot_training(frame_idx, rewards, losses):
        """Redraw the training dashboard in the notebook output cell.

        Args:
            frame_idx: Index of the most recent frame (shown in the loss title).
            rewards: Returns of all finished episodes so far, in completion order.
            losses: TD loss of every training step so far.
        """
        clear_output(True)
        fig, (reward_ax, loss_ax) = plt.subplots(1, 2, figsize=(20, 5))

        # np.mean([]) emits a RuntimeWarning; report nan explicitly until data exists.
        reward_ma = np.mean(rewards[-10:]) if rewards else float('nan')
        loss_ma = np.mean(losses[-10:]) if losses else float('nan')

        # Average returns over complete blocks of 100 episodes; a trailing partial block is dropped.
        n_blocks = len(rewards) // 100
        block_means = np.array(rewards[:100 * n_blocks]).reshape(-1, 100).mean(axis=1)

        reward_ax.set_title('episode: {}, total reward(ma-10): {:.1f}'.format(len(rewards), reward_ma))
        reward_ax.plot(block_means)
        reward_ax.set_xlabel('block of 100 episodes')
        reward_ax.set_ylabel('mean episode reward')

        loss_ax.set_title('frame: {}, loss(ma-10): {:.4f}'.format(frame_idx, loss_ma))
        loss_ax.plot(losses)
        loss_ax.set_xlabel('training step')
        loss_ax.set_ylabel('TD loss (MSE)')
        plt.show()

In [None]:
## Training
# 128 parallel CartPole copies double as the minibatch size; the target net syncs every 50 frames.
dqn_agent = DQNAgent(
    env_id=env_id,
    eps=0.05,
    gamma=0.99,
    lr=5e-4,
    num_frames=50000,
    rep_buf_size=10000,
    batch_size=128,
    tgt_upd_delay=50,
)
dqn_agent.train()

In [None]:
## Visualization (Test)
# Roll out one episode with the trained greedy policy. With render_mode='human',
# gymnasium draws every frame inside step(), so no explicit env.render() is needed
# (an extra call re-draws the same frame and slows the rollout).
env = gym.make(env_id, render_mode='human')
try:
    state, _ = env.reset()
    episode_return = 0.0
    done = False
    while not done:
        action = dqn_agent.model.greedy_act(np.expand_dims(state, 0))
        state, reward, terminated, truncated, _ = env.step(action[0])
        episode_return += reward
        done = terminated or truncated
finally:
    # Close the render window even if the rollout raises or is interrupted.
    env.close()
episode_return