<a href="https://colab.research.google.com/github/murphybrendan/ml-courses/blob/main/huggingface/deep-rl/unit3/dqn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

## Install necessary packages

In [2]:
!pip install wandb einops pygame stable_baselines3
!pip install gymnasium[classic_control,box2d,atari]
!pip install gymnasium[accept-rom-license]

Collecting wandb
  Using cached wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.7 MB)
Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
Collecting stable_baselines3
  Downloading stable_baselines3-2.3.2-py3-none-any.whl (182 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m182.3/182.3 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m29.3 MB/s[0m eta [36m0:0

## Set up the virtual display

In [3]:
%%capture
!apt install python-opengl
!apt install xvfb
!pip3 install pyvirtualdisplay

In [4]:
# Virtual display
# Create and start a 1400x900 display with visible=0 (Xvfb is installed in the
# previous cell) -- presumably so environments can render frames on a machine
# without a physical screen, such as a Colab VM.
from pyvirtualdisplay import Display

virtual_display = Display(visible=0, size=(1400, 900))
# Being the last expression of the cell, the value returned by start() is
# echoed as the cell output.
virtual_display.start()

<pyvirtualdisplay.display.Display at 0x7e3301f95e40>

# DQN Implementation

Implement the Q-network. It is a simple feed-forward network with a configurable number of hidden layers. The input dimension is the size of an observation vector, and the output dimension is the number of actions, so the network produces one Q-value per action.

In [10]:
import torch

class QNetwork(torch.nn.Module):
    """Feed-forward Q-network mapping an observation to one Q-value per action.

    The network is a stack of ``Linear -> ReLU`` pairs, one per entry of
    ``hidden_layers``, followed by a final ``Linear`` layer with
    ``action_space_dim`` outputs and no activation.

    Args:
        observation_dim: Size of the last dimension of the input tensor.
        action_space_dim: Number of outputs (one Q-value per action).
        hidden_layers: Iterable with the width of each hidden layer. Any
            iterable (list, tuple, ...) is accepted; an empty one yields a
            single linear layer.
    """

    def __init__(self, observation_dim, action_space_dim, hidden_layers=(128, 64)):
        super().__init__()
        # Unpacking builds a fresh list, so the caller's sequence is never
        # aliased or mutated and non-list sequences such as tuples work too.
        layer_dim = [observation_dim, *hidden_layers]
        layers = []
        for in_features, out_features in zip(layer_dim[:-1], layer_dim[1:]):
            layers.append(torch.nn.Linear(in_features, out_features))
            layers.append(torch.nn.ReLU())
        # Output head: raw Q-values, hence no activation after it.
        layers.append(torch.nn.Linear(layer_dim[-1], action_space_dim))
        self.layers = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the Q-values for ``x`` (last dimension must be ``observation_dim``)."""
        return self.layers(x)

In [11]:
# Smoke test: build a network with observation_dim=10 and action_space_dim=2
# (default hidden layers) and display its layer structure.
q = QNetwork(10, 2)
q

QNetwork(
  (layers): Sequential(
    (0): Linear(in_features=10, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=2, bias=True)
  )
)

Define the epsilon-greedy policy

In [12]:
import random
from gymnasium.spaces import Space

def epsilon_greedy_policy(Q: QNetwork, observation: torch.Tensor, action_space: Space, epsilon: float) -> int:
    """Choose an action with an epsilon-greedy rule.

    With probability ``epsilon`` an action is drawn from
    ``action_space.sample()``; otherwise the action with the largest Q-value
    for ``observation`` is returned.

    Args:
        Q: Network mapping an observation tensor to per-action Q-values.
        observation: A single observation; a leading batch dimension of size
            one is also fine because the argmax is taken over all outputs.
        action_space: Space to sample exploratory actions from.
        epsilon: Exploration probability. ``0`` is always greedy, ``1`` always
            samples.

    Returns:
        The chosen action. The greedy branch returns a plain Python ``int``;
        the exploratory branch returns whatever ``action_space.sample()``
        yields.
    """
    if random.random() < epsilon:
        return action_space.sample()
    # Action selection is not part of the loss, so skip autograd bookkeeping
    # and hand back a Python int (as the annotation promises) rather than a
    # 0-dim tensor that would keep the graph alive.
    with torch.no_grad():
        return int(Q(observation).argmax().item())


Implement DQN (using the Double DQN target), with the `ReplayBuffer` from stable_baselines3.

In [14]:
import numpy as np
import torch.nn.functional as F
from gymnasium import Env
from gymnasium.spaces import Discrete
from stable_baselines3.common.buffers import ReplayBuffer
from tqdm.notebook import trange, tqdm


class DQN:
    """Double DQN agent for a single (non-vectorized) Gymnasium environment.

    The online network ``q`` selects the next action and the target network
    ``q_hat`` evaluates it. Transitions are stored in a stable_baselines3
    ``ReplayBuffer``.

    The action space must be discrete (``env.action_space.n`` is read).
    Observations may be array-like (flattened before entering the network) or
    ``Discrete`` (one-hot encoded).

    Args:
        env: Environment following the Gymnasium API, i.e. ``reset()`` returns
            ``(observation, info)`` and ``step()`` returns a 5-tuple.
        buffer_size: Capacity of the replay buffer, in transitions.
        batch_size: Number of transitions sampled per gradient step.
        gamma: Discount factor.
        train_freq: Number of environment steps between gradient steps.
        exploration_initial_eps: Epsilon at the start of ``learn``.
        exploration_final_eps: Epsilon reached at the end of ``learn``.
        learning_starts: Environment steps to collect before training counters
            start, which delays the first gradient step.
        target_update_interval: Environment steps between target-network syncs.
        learning_rate: Learning rate of the SGD optimizer.
    """

    def __init__(
        self,
        env: Env,
        buffer_size=1000000,
        batch_size=32,
        gamma=0.99,
        train_freq=4,
        exploration_initial_eps=1.0,
        exploration_final_eps=0.05,
        learning_starts=100,
        target_update_interval=10000,
        learning_rate=0.0001,
    ) -> None:
        self.env = env
        self.batch_size = batch_size
        self.gamma = gamma
        self.train_freq = train_freq
        self.target_update_interval = target_update_interval
        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.learning_starts = learning_starts

        # Keep networks and sampled batches on the same device.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Network input size: number of states for Discrete observations
        # (one-hot encoded), otherwise the flattened size of one observation.
        self._discrete_obs = isinstance(env.observation_space, Discrete)
        if self._discrete_obs:
            self.obs_dim = int(env.observation_space.n)
        else:
            self.obs_dim = int(np.prod(env.observation_space.shape))
        n_actions = int(env.action_space.n)

        self.replay_buffer = ReplayBuffer(
            buffer_size, env.observation_space, env.action_space, device=self.device
        )
        self.q = QNetwork(self.obs_dim, n_actions).to(self.device)
        # The target network starts as an exact copy of the online network.
        # load_state_dict expects a state dict (not a module) and does not
        # return the module, so it cannot be chained onto the constructor.
        self.q_hat = QNetwork(self.obs_dim, n_actions).to(self.device)
        self.q_hat.load_state_dict(self.q.state_dict())
        self.optimizer = torch.optim.SGD(self.q.parameters(), learning_rate)

        # Initialising both counters at learning_starts delays the first
        # gradient step / target sync until enough transitions are collected.
        self.last_target_update = learning_starts
        self.last_training_step = learning_starts
        self.timestep = 0
        self.eps = exploration_initial_eps
        self.prev_observation = None

    def _to_features(self, observations) -> torch.Tensor:
        """Convert one observation or a batch into a float ``(batch, obs_dim)`` tensor."""
        obs = torch.as_tensor(observations, device=self.device)
        if self._discrete_obs:
            return F.one_hot(obs.long().reshape(-1), self.obs_dim).float()
        return obs.float().reshape(-1, self.obs_dim)

    def collect_rollouts(self):
        """Take one epsilon-greedy environment step and store the transition.

        ``learn`` sets ``prev_observation`` and ``eps`` before calling this.
        """
        with torch.no_grad():
            action = epsilon_greedy_policy(
                self.q, self._to_features(self.prev_observation), self.env.action_space, self.eps
            )
        # The policy may return a tensor, a NumPy integer or an int.
        action = int(action)
        next_observation, reward, terminated, truncated, info = self.env.step(action)
        # Only `terminated` is stored as done: a time-limit truncation must
        # still bootstrap from the next state. The buffer reshapes its inputs,
        # so they are passed as NumPy arrays.
        self.replay_buffer.add(
            np.asarray(self.prev_observation),
            np.asarray(next_observation),
            np.asarray([action]),
            reward,
            terminated,
            [info],
        )
        if terminated or truncated:
            # Start a new episode instead of stepping a finished environment.
            next_observation, _ = self.env.reset()
        self.prev_observation = next_observation
        self.timestep += 1

    def step(self):
        """Run one gradient step on a minibatch sampled from the replay buffer."""
        # No normalisation wrapper is used, so nothing is passed as `env`.
        samples = self.replay_buffer.sample(self.batch_size)
        observations = self._to_features(samples.observations)
        next_observations = self._to_features(samples.next_observations)
        actions = samples.actions.long().reshape(-1, 1)

        with torch.no_grad():
            # Double DQN: the online network chooses the next action and the
            # target network evaluates it (per sample, hence dim=1 / gather).
            next_actions = self.q(next_observations).argmax(dim=1, keepdim=True)
            next_q = self.q_hat(next_observations).gather(1, next_actions)
            # Discounted future return is 0 if this was a terminating state.
            y = samples.rewards + self.gamma * next_q * (1.0 - samples.dones)

        q_values = self.q(observations).gather(1, actions)
        loss = F.mse_loss(q_values, y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def learn(self, total_timesteps):
        """Interact with the environment for ``total_timesteps`` steps while training.

        Epsilon decays linearly from ``exploration_initial_eps`` to
        ``exploration_final_eps`` over ``total_timesteps``. A gradient step is
        taken every ``train_freq`` environment steps, and the target network is
        synced at most every ``target_update_interval`` steps (checked after a
        gradient step).
        """
        self.prev_observation, _ = self.env.reset()
        for _ in trange(total_timesteps):
            # Clamp so epsilon never drops below the final value if learn()
            # is called again on an agent that already has steps recorded.
            progress = min(1.0, self.timestep / total_timesteps)
            self.eps = self.exploration_initial_eps + progress * (
                self.exploration_final_eps - self.exploration_initial_eps
            )

            self.collect_rollouts()
            if self.timestep - self.last_training_step >= self.train_freq:
                self.step()
                self.last_training_step = self.timestep
                if self.timestep - self.last_target_update >= self.target_update_interval:
                    self.q_hat.load_state_dict(self.q.state_dict())
                    self.last_target_update = self.timestep
