## DQN for continuous action spaces: Normalized Advantage Function (NAF)

In [None]:
%%capture

# Install the system package (xvfb) and the pinned Python packages used by the
# rest of the notebook. `%%capture` above suppresses the installer output.
!apt-get update && apt-get install -y xvfb
!pip install swig
!pip install gym[box2d]==0.23.1 pytorch-lightning==1.6.0 pyvirtualdisplay

#### Set up the virtual display

In [None]:
from pyvirtualdisplay import Display
# Start an invisible 1400x900 virtual display for the rest of the session.
Display(visible=False, size=(1400, 900)).start()

#### Import the necessary code libraries

In [None]:
import copy
import gym
import random
import torch

import numpy as np
import torch.nn.functional as F

from collections import deque, namedtuple
from IPython.display import HTML
from base64 import b64encode

from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset
from torch.optim import AdamW

from pytorch_lightning import LightningModule, Trainer

from gym.wrappers import RecordVideo, RecordEpisodeStatistics


# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
  device = 'cuda:0'
else:
  device = 'cpu'

# Number of CUDA devices visible to PyTorch; passed to the Lightning Trainer.
num_gpus = torch.cuda.device_count()

In [None]:
def display_video(episode=0):
  """Return an inline HTML5 player for one recorded episode.

  Reads ``/content/videos/rl-video-episode-{episode}.mp4``, base64-encodes its
  bytes and embeds them as a ``data:`` URL inside a ``<video>`` element.

  Args:
    episode: Episode index used in the video file name.

  Returns:
    An ``HTML`` display object wrapping the ``<video>`` element.

  Raises:
    FileNotFoundError: If no video file exists for ``episode``.
  """
  # Read-only binary mode is sufficient, and the context manager guarantees
  # the file handle is closed even if reading fails.
  with open(f'/content/videos/rl-video-episode-{episode}.mp4', "rb") as video:
    video_bytes = video.read()
  video_url = f"data:video/mp4;base64,{b64encode(video_bytes).decode()}"
  return HTML(f"<video width=600 controls><source src='{video_url}'></video>")

#### Create the Deep Q-Network

In [None]:
class NafDQN(nn.Module):
  """Q-network with a Normalized Advantage Function (NAF) head.

  A shared two-layer MLP feeds three linear heads: the greedy action ``mu``,
  the state value ``V`` and the entries of a lower-triangular matrix ``L``.
  ``forward`` returns ``Q(s, a) = V(s) - 1/2 (a - mu)^T P (a - mu)`` with
  ``P = L @ L^T``, so ``Q`` is maximised at ``a = mu`` where it equals ``V``.

  Args:
    hidden_size: Width of both hidden layers.
    obs_size: Number of observation features.
    action_dims: Number of action components.
    max_action: Array-like with one bound per action component; the ``tanh``
      output of the ``mu`` head is scaled by it.
  """

  def __init__(self, hidden_size, obs_size, action_dims, max_action):
    super().__init__()
    self.action_dims = action_dims
    # Registered as a non-persistent buffer so that it follows the module
    # through `.to(...)` / `.cuda()` without adding a key to the state dict.
    self.register_buffer(
      'max_action',
      torch.as_tensor(max_action, dtype=torch.float32).clone(),
      persistent=False,
    )
    self.net = nn.Sequential(
      nn.Linear(obs_size, hidden_size),
      nn.ReLU(),
      nn.Linear(hidden_size, hidden_size),
      nn.ReLU(),
    )
    self.linear_mu = nn.Linear(hidden_size, action_dims)
    self.linear_value = nn.Linear(hidden_size, 1)
    # One output per entry of the lower triangle (diagonal included).
    self.linear_matrix = nn.Linear(hidden_size, action_dims * (action_dims + 1) // 2)

  @torch.no_grad()
  def mu(self, x):
    """Return the greedy action for each state in ``x`` (no gradient tracking)."""
    x = self.net(x)
    x = self.linear_mu(x)
    x = torch.tanh(x) * self.max_action
    return x

  @torch.no_grad()
  def value(self, x):
    """Return the state value ``V(s)``, shape ``(batch, 1)`` (no gradient tracking)."""
    x = self.net(x)
    x = self.linear_value(x)
    return x

  def forward(self, x, a):
    """Return ``Q(s, a)`` with shape ``(batch, 1)``.

    Args:
      x: Batch of states, shape ``(batch, obs_size)``.
      a: Batch of actions, shape ``(batch, action_dims)``.
    """
    x = self.net(x)
    mu = torch.tanh(self.linear_mu(x)) * self.max_action
    value = self.linear_value(x)
    entries = torch.tanh(self.linear_matrix(x))

    # Scatter the predicted entries into the lower triangle of a zero matrix,
    # allocated with the dtype and device of the activations.
    rows, cols = torch.tril_indices(
      row=self.action_dims, col=self.action_dims, offset=0, device=x.device
    )
    raw = x.new_zeros((x.shape[0], self.action_dims, self.action_dims))
    raw[:, rows, cols] = entries

    # Exponentiate the diagonal so that it is strictly positive; the strictly
    # lower part is kept as predicted.
    diagonal = torch.diagonal(raw, dim1=1, dim2=2).exp()
    L = torch.tril(raw, diagonal=-1) + torch.diag_embed(diagonal)

    # Matrix product (not element-wise): P = L L^T is positive definite and
    # keeps the off-diagonal terms that couple the action components.
    P = L @ L.transpose(1, 2)

    u_mu = (a - mu).unsqueeze(dim=1)      # (batch, 1, action_dims)
    u_mu_t = u_mu.transpose(1, 2)         # (batch, action_dims, 1)

    adv = -0.5 * (u_mu @ P @ u_mu_t)      # (batch, 1, 1)
    adv = adv.squeeze(dim=-1)             # (batch, 1)
    return value + adv


#### Create the policy

In [None]:
def noisy_policy(state, env, net, epsilon=0.0):
  """Return the network's greedy action for ``state`` plus Gaussian noise.

  Args:
    state: A single observation (array-like); the batch dimension is added here.
    env: Environment whose ``action_space.low`` / ``action_space.high`` NumPy
      arrays bound the returned action.
    net: Module exposing ``mu(states)``.
    epsilon: Standard deviation of the zero-mean noise added to ``mu``.

  Returns:
    NumPy array holding the clipped action, one entry per action component.
  """
  # Run on whatever device the network lives on, so the policy keeps working
  # if the module is moved after construction.
  first_param = next(net.parameters(), None)
  net_device = first_param.device if first_param is not None else torch.device('cpu')

  state = torch.as_tensor(state, device=net_device).unsqueeze(0)
  amin = torch.from_numpy(env.action_space.low).to(net_device)
  amax = torch.from_numpy(env.action_space.high).to(net_device)
  mu = net.mu(state)
  mu = mu + epsilon * torch.randn_like(mu)
  action = mu.clamp(amin, amax)
  # Drop only the batch dimension: a bare squeeze() would also collapse a
  # one-component action into a 0-d array.
  action = action.squeeze(0).cpu().numpy()
  return action

#### Create the replay buffer

In [None]:
class ReplayBuffer:
  """Bounded FIFO store of experiences with uniform random sampling."""

  def __init__(self, capacity):
    # Once `capacity` items are stored, appending evicts the oldest one.
    self.buffer = deque([], capacity)

  def __len__(self):
    """Number of experiences currently stored."""
    stored = self.buffer
    return len(stored)

  def append(self, experience):
    """Add one experience at the newest end of the buffer."""
    stored = self.buffer
    stored.append(experience)

  def sample(self, batch_size):
    """Return ``batch_size`` distinct experiences chosen uniformly at random."""
    picked = random.sample(self.buffer, batch_size)
    return picked

In [None]:
class RLDataset(IterableDataset):
  """Iterable dataset that draws a fresh sample from a buffer on every pass.

  Args:
    buffer: Object exposing ``sample(n)``.
    sample_size: Number of experiences requested from the buffer per iteration.
  """

  def __init__(self, buffer, sample_size=400):
    self.sample_size = sample_size
    self.buffer = buffer

  def __iter__(self):
    # The buffer is sampled when iteration starts, so each pass yields a new draw.
    yield from self.buffer.sample(self.sample_size)

#### Create the environment

In [None]:
class RepeatActionWrapper(gym.Wrapper):
  """Apply each action for up to ``n`` consecutive steps of the wrapped env.

  The rewards of the repeated steps are summed. Repetition stops early as soon
  as the wrapped environment reports ``done``.

  Args:
    env: Environment to wrap.
    n: Maximum number of times each action is applied; must be at least 1.

  Raises:
    ValueError: If ``n`` is smaller than 1.
  """

  def __init__(self, env, n):
    # With n < 1 the loop in `step` would never run and there would be no
    # observation to return, so reject it up front.
    if n < 1:
      raise ValueError(f"n must be at least 1, got {n}")
    super().__init__(env)
    self.env = env
    self.n = n

  def step(self, action):
    """Repeat ``action`` and return the last observation with the summed reward.

    Returns:
      Tuple ``(next_state, total_reward, done, info)`` where ``next_state``,
      ``done`` and ``info`` come from the last step actually executed.
    """
    total_reward = 0.0
    for _ in range(self.n):
      next_state, reward, done, info = self.env.step(action)
      total_reward += reward
      if done:
        break
    return next_state, total_reward, done, info

In [None]:
def create_environment(name, video_folder='./videos', record_every=50, repeat=8):
  """Build the Gym environment with recording, action repeat and statistics.

  Wrappers are applied in this order: ``RecordVideo`` directly around the base
  environment, then ``RepeatActionWrapper``, then ``RecordEpisodeStatistics``
  as the outermost layer.

  Args:
    name: Environment id passed to ``gym.make``.
    video_folder: Directory where ``RecordVideo`` stores its files.
    record_every: An episode is recorded when its index is a multiple of this.
    repeat: Number of times ``RepeatActionWrapper`` applies each action.

  Returns:
    The fully wrapped environment.
  """
  env = gym.make(name)
  env = RecordVideo(
    env,
    video_folder=video_folder,
    episode_trigger=lambda episode: episode % record_every == 0,
  )
  env = RepeatActionWrapper(env, n=repeat)
  env = RecordEpisodeStatistics(env)
  return env

#### Update the target network

In [None]:
def polyak_average(net, target_net, tau=0.01):
  """Move ``target_net``'s parameters a fraction ``tau`` towards ``net``'s.

  Each target parameter is overwritten in place with
  ``tau * online + (1 - tau) * target``. ``net`` is not modified.

  Args:
    net: Module providing the online parameters.
    target_net: Module whose parameters are updated; parameters are paired
      with those of ``net`` in iteration order.
    tau: Interpolation weight given to the online parameters.
  """
  # no_grad lets the leaf parameters be updated in place without recording
  # the operation in the autograd graph.
  with torch.no_grad():
    for qp, tp in zip(net.parameters(), target_net.parameters()):
      tp.copy_(tau * qp + (1 - tau) * tp)

#### Create the Deep Q-Learning algorithm

In [None]:
class NAFDeepQLearning(LightningModule):
  """Deep Q-learning with a NAF critic, driven by PyTorch Lightning.

  Each training epoch fits the online network on a sample drawn from the
  replay buffer; at the end of the epoch one more episode is played with the
  noisy policy and the target network is moved towards the online one.

  Args:
    env_name: Environment id passed to ``create_environment``.
    policy: Callable ``policy(obs, env, net, epsilon=...)`` returning an action.
    capacity: Maximum number of transitions kept in the replay buffer.
    batch_size: Batch size of the training dataloader.
    lr: Learning rate handed to the optimizer.
    hidden_size: Hidden width of the Q-network.
    gamma: Discount factor used in the TD target.
    loss_fn: Loss applied to ``(Q(s, a), target)``.
    optim: Optimizer class, called as ``optim(params, lr=lr)``.
    eps_start: Noise scale used for the first epoch.
    eps_end: Lower bound of the noise scale.
    eps_last_episode: Divisor of the per-epoch noise decay.
    samples_per_epoch: Transitions sampled from the buffer per epoch; also the
      minimum buffer size reached before training starts.
    tau: Polyak interpolation weight for the target network update.
  """

  def __init__(self, env_name, policy=noisy_policy, capacity=100_000,
               batch_size=256, lr=1e-4, hidden_size=512, gamma=0.99,
               loss_fn=F.smooth_l1_loss, optim=AdamW, eps_start=2.0, eps_end=0.2,
               eps_last_episode=1_000, samples_per_epoch=1_000, tau=0.01):

    super().__init__()
    self.env = create_environment(env_name)

    obs_size = self.env.observation_space.shape[0]
    action_dims = self.env.action_space.shape[0]
    max_action = self.env.action_space.high

    self.q_net = NafDQN(hidden_size, obs_size, action_dims, max_action).to(device)
    self.target_q_net = copy.deepcopy(self.q_net)
    self.policy = policy

    self.buffer = ReplayBuffer(capacity=capacity)

    self.save_hyperparameters()

    # Warm-up: play episodes with random actions (no policy is passed) until
    # the buffer holds enough transitions for one epoch.
    while len(self.buffer) < self.hparams.samples_per_epoch:

      print(f"{len(self.buffer)} samples in experience buffer. Filling...")
      self.play_episode(epsilon=self.hparams.eps_start)

  @torch.no_grad()
  def play_episode(self, policy=None, epsilon=0.):
    """Play one episode and append every transition to the replay buffer.

    Args:
      policy: Action-selection callable; when falsy, actions are sampled
        uniformly from the environment's action space.
      epsilon: Noise scale forwarded to ``policy``.
    """
    obs = self.env.reset()
    done = False

    while not done:
      if policy:
        action = policy(obs, self.env, self.q_net, epsilon=epsilon)
      else:
        action = self.env.action_space.sample()

      next_obs, reward, done, info = self.env.step(action)
      exp = (obs, action, reward, done, next_obs)
      self.buffer.append(exp)
      obs = next_obs

  def forward(self, x, a=None):
    """Query the online network.

    Args:
      x: Batch of states.
      a: Optional batch of actions.

    Returns:
      ``Q(x, a)`` when ``a`` is given; otherwise the greedy action ``mu(x)``
      (computed without gradient tracking).
    """
    # The Q-network needs both a state and an action, so a state-only call is
    # answered with the greedy action instead.
    if a is None:
      return self.q_net.mu(x)
    return self.q_net(x, a)

  def configure_optimizers(self):
    """Create the optimizer for the online network only."""
    q_net_optimizer = self.hparams.optim(self.q_net.parameters(), lr=self.hparams.lr)
    return [q_net_optimizer]

  def train_dataloader(self):
    """Return a dataloader that samples ``samples_per_epoch`` transitions from the buffer."""
    dataset = RLDataset(self.buffer, self.hparams.samples_per_epoch)
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=self.hparams.batch_size,
    )
    return dataloader

  def training_step(self, batch, batch_idx):
    """Compute the TD loss for one batch of transitions.

    ``batch`` holds ``(states, actions, rewards, dones, next_states)`` in the
    order in which ``play_episode`` stores them.
    """
    states, actions, rewards, dones, next_states = batch
    rewards = rewards.unsqueeze(1)
    dones = dones.unsqueeze(1)

    action_values = self.q_net(states, actions)

    # `value` runs without gradient tracking, so the target is a constant.
    next_state_values = self.target_q_net.value(next_states)
    # No bootstrapping from transitions flagged as done.
    next_state_values[dones] = 0.0

    target = rewards + self.hparams.gamma * next_state_values
    # Collated rewards may arrive in a wider float type than the network
    # output; compare both in the prediction's dtype.
    target = target.to(action_values.dtype)

    loss = self.hparams.loss_fn(action_values, target)
    # The tag is kept as-is for existing dashboards; the logged value is
    # whatever `loss_fn` computes.
    self.log('episode/MSE Loss', loss)
    return loss

  def training_epoch_end(self, training_step_outputs):
    """Play one episode with the noisy policy, update the target net and log the return."""

    # NOTE(review): the noise scale drops by 1 / eps_last_episode per epoch, so
    # eps_end is reached after (eps_start - eps_end) * eps_last_episode epochs
    # rather than at eps_last_episode -- confirm this is the intended schedule.
    epsilon = max(
        self.hparams.eps_end,
        self.hparams.eps_start - self.current_epoch / self.hparams.eps_last_episode
    )

    self.play_episode(policy=self.policy, epsilon=epsilon)

    polyak_average(self.q_net, self.target_q_net, tau=self.hparams.tau)

    # Return of the episode that was just played.
    self.log("episode/Return", self.env.return_queue[-1])

#### Purge logs and run the visualization tool (Tensorboard)

In [None]:
# Delete the Lightning logs and recorded videos left over from earlier runs.
!rm -r /content/lightning_logs/
!rm -r /content/videos/
# Load the TensorBoard notebook extension and point it at the log directory.
%load_ext tensorboard
%tensorboard --logdir /content/lightning_logs/

#### Train the policy

In [None]:
# Build the agent; its constructor also pre-fills the replay buffer.
algo = NAFDeepQLearning('LunarLanderContinuous-v2')

# Use every visible GPU (0 means CPU) and train for up to 10,000 epochs.
trainer = Trainer(
    gpus=num_gpus, 
    max_epochs=10_000
)

trainer.fit(algo)

#### Check the resulting policy

In [None]:
# Show the recording of episode 4300; a video file must exist for that episode.
display_video(episode=4300)