<a href="https://colab.research.google.com/github/jbpacker/deep-rl-class/blob/main/unit3/DQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Self-made DQN implementation

Weights and Biases: https://wandb.ai/jefsnacker/DQN?workspace=user-jefsnacker

Next steps:
1. Add an evaluation mode that runs every n training steps
1. Record video replays of evaluation episodes
1. Do a longer run to check whether the agent is actually learning

### Resources

PPO implementation details: https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/

A great description of DQN can be found here: https://huggingface.co/blog/deep-rl-dqn

Walkthrough code (Double DQN on CartPole): https://pylessons.com/CartPole-DDQN

## Setup
Install the required libraries and start a virtual display, so that gym environments can render in a headless runtime such as Colab.

In [4]:
%%capture
!pip install pyglet==1.5.1 
!apt install python-opengl
!apt install ffmpeg
!apt install xvfb
!pip3 install pyvirtualdisplay

# Virtual display
from pyvirtualdisplay import Display

virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()

In [5]:
%%capture
!pip install git+https://github.com/openai/gym.git # We install gym using git since Taxi-v3 "rgb_array version" is not on PyPi release
!pip install pygame
!pip install numpy
!pip install wandb

In [6]:
import numpy as np
import gym
import random
# import imageio
# import os
# import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb


# import pickle5 as pickle

# Preferred compute device. NOTE(review): nothing below moves models or tensors to it yet.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Environment

CartPole has a discrete action space (`gym.spaces.Discrete`); `env.action_space.n` gives the number of actions: https://github.com/openai/gym/blob/master/gym/spaces/discrete.py

## Model

In [7]:
class DiscreteModel(nn.Module):
    """Fully connected Q-network for an environment with a discrete action space.

    Takes observations whose last dimension is ``env.observation_space.shape[0]``
    and returns one Q-value per action (``env.action_space.n`` outputs), via
    three ReLU hidden layers of width 512, 256 and 64.
    """

    def __init__(self, env):
        super().__init__()
        n_inputs = env.observation_space.shape[0]
        n_actions = env.action_space.n

        # Layers are created in the order l1..l4 so parameter initialisation
        # and state_dict keys stay the same.
        self.l1 = nn.Linear(n_inputs, 512)
        self.l2 = nn.Linear(512, 256)
        self.l3 = nn.Linear(256, 64)
        self.l4 = nn.Linear(64, n_actions)

    def forward(self, x):
        # The three hidden layers share the same Linear -> ReLU pattern.
        for hidden in (self.l1, self.l2, self.l3):
            x = F.relu(hidden(x))
        # The output layer stays linear: Q-values are unbounded.
        return self.l4(x)

In [None]:
# Smoke test: a batch of four CartPole observations should give one row of
# Q-values per observation and one column per action.
sample_obs = torch.tensor([[-0.0140, -0.1768,  0.0392,  0.2849],
                           [-0.0175, -0.3725,  0.0449,  0.5897],
                           [-0.0250, -0.5682,  0.0567,  0.8962],
                           [-0.0363, -0.7640,  0.0746,  1.2061]], dtype=torch.float32)

model = DiscreteModel(gym.make("CartPole-v1"))
with torch.no_grad():  # inference only; no autograd graph needed
    q_values = model(sample_obs)
assert q_values.shape == (sample_obs.shape[0], model.l4.out_features)
q_values

## Helper Functions

In [10]:
def get_action_from_model(Q, obs):
    """Return the greedy (argmax-Q) action for each observation row.

    Args:
        Q: callable mapping a float32 tensor of shape (batch, obs_dim) to
            Q-values of shape (batch, n_actions), e.g. a ``DiscreteModel``.
        obs: numpy array or torch tensor of shape (batch, obs_dim). A single
            1-D observation of shape (obs_dim,) is treated as a batch of one.

    Returns:
        int64 tensor of shape (batch,) holding the index of the largest
        Q-value in each row.
    """
    # as_tensor accepts numpy arrays and tensors alike. Casting to float32
    # avoids a dtype mismatch with float32 model weights when obs is float64.
    obs = torch.as_tensor(obs, dtype=torch.float32)
    if obs.dim() == 1:
        obs = obs.unsqueeze(0)  # argmax over dim=1 needs a batch dimension

    pred = Q(obs)
    return torch.argmax(pred, dim=1)

# eps percent chance of taking greedy action. otherwise random.
def sample_action_eps_greedy(eps, Q_model, obs, env):
    if random.random() > eps:
        return torch.tensor([[env.action_space.sample()]])
    else:
        with torch.no_grad():
            return get_action_from_model(Q_model, obs)

In [None]:
env = gym.make("CartPole-v1")
Q = DiscreteModel(env)

# One CartPole observation as a (1, obs_dim) batch.
single_obs = torch.tensor([[-0.0140, -0.1768,  0.0392,  0.2849]], dtype=torch.float32)

print("greedy action:", get_action_from_model(Q, single_obs))
print("eps=0.0:", sample_action_eps_greedy(0.0, Q, single_obs, env))
print("eps=1.0:", sample_action_eps_greedy(1.0, Q, single_obs, env))

env.close()

## Hyperparameters


In [12]:
# --- Training schedule ---
epochs = 500            # number of collect-then-update iterations in train()
max_rollout = 600       # per-episode step cap used by collect_data
                        # NOTE(review): gym registers CartPole-v1 with a 500-step limit, so this may never trigger — confirm
batch_size = 512        # environment steps collected per epoch
minibatch_size = 128    # NOTE(review): not referenced by any cell in this notebook

# --- Epsilon-greedy exploration ---
eps = 1.0               # NOTE(review): not read anywhere; train() recomputes eps every epoch
max_epsilon = 1.0       # eps = min + (max - min) * exp(-decay_rate * epoch)
min_epsilon = 1e-3
decay_rate = 0.08

# --- Target network ---
epochs_to_reset_Q = 10  # Q_frozen <- Q whenever epoch % epochs_to_reset_Q == 0

# --- Optimisation ---
gamma = 0.95            # discount factor in the TD target
lr = 5e-4               # SGD learning rate


## Training Loop

In [13]:
def collect_data(env, next_obs, next_done, running_reward, eps, Q, log=False, print_debug=False):
    """Step ``env`` for ``batch_size`` transitions with epsilon-greedy actions from ``Q``.

    Environment state is threaded through calls: pass the returned
    ``next_obs``, ``next_done`` and ``running_reward`` back into the next call
    so episodes continue across batches.

    Args:
        env: gym environment. This code assumes the pre-0.26 API, where
            ``reset()`` returns the observation and ``step()`` returns a 4-tuple.
        next_obs: observation to act on first.
        next_done: if truthy, the environment is reset before collecting.
        running_reward: reward accumulated so far in the current episode.
        eps: exploration parameter forwarded to ``sample_action_eps_greedy``.
        Q: model used for greedy actions.
        log: if True, log episode length and return to wandb when an episode ends.
        print_debug: currently unused; kept for interface compatibility.

    Returns:
        ``(data, next_obs, next_done, running_reward)``.

        ``data`` is a list of five arrays with one row per step:
        obs ``(N, obs_dim)`` float32, action ``(N, 1)`` float32,
        next_obs ``(N, obs_dim)`` float32, reward ``(N, 1)`` float32 and
        done ``(N, 1)`` bool. Each done value is the flag ``env.step``
        returned for that transition, i.e. whether its ``next_obs`` is terminal.

        The returned ``next_done`` is always False, because finished episodes
        are reset immediately.
    """
    obs_dim = env.observation_space.shape[0]
    # Accumulate rows in lists and convert once at the end; np.append in a
    # loop re-copies the whole array on every step (quadratic).
    obs_rows, action_rows, next_obs_rows, reward_rows, done_rows = [], [], [], [], []

    if next_done:
        # The caller's episode already ended: start a fresh one rather than
        # stepping a finished environment.
        next_obs = env.reset()
        next_done = False
        running_reward = 0

    # NOTE: this counter only covers steps taken inside this call. An episode
    # that straddles two calls is logged with a truncated length and can run
    # past max_rollout.
    rollout_steps = 0
    for _ in range(batch_size):
        obs = next_obs
        obs_batch = torch.as_tensor(obs, dtype=torch.float32).reshape(1, -1)
        action = sample_action_eps_greedy(eps, Q, obs_batch, env)
        action_value = action.reshape(-1)[0].item()

        next_obs, reward, next_done, info = env.step(action_value)
        running_reward += reward
        rollout_steps += 1

        obs_rows.append(np.reshape(obs, -1))
        action_rows.append(action_value)
        next_obs_rows.append(np.reshape(next_obs, -1))
        reward_rows.append(reward)
        # Store the flag produced by *this* step: True means next_obs is
        # terminal, so the TD target must not bootstrap from it.
        done_rows.append(bool(next_done))

        if next_done or rollout_steps >= max_rollout:
            if log:
                wandb.log({"rollout_steps": rollout_steps, "running_reward": running_reward})
            next_obs = env.reset()
            next_done = False
            running_reward = 0
            rollout_steps = 0

    # reshape keeps the documented 2-D shapes even when no steps were taken.
    data = [
        np.asarray(obs_rows, dtype=np.float32).reshape(-1, obs_dim),       # obs
        np.asarray(action_rows, dtype=np.float32).reshape(-1, 1),          # action
        np.asarray(next_obs_rows, dtype=np.float32).reshape(-1, obs_dim),  # next_obs
        np.asarray(reward_rows, dtype=np.float32).reshape(-1, 1),          # reward
        np.asarray(done_rows, dtype=bool).reshape(-1, 1),                  # done
    ]
    return data, next_obs, next_done, running_reward



In [None]:
env = gym.make("CartPole-v1")
Q = DiscreteModel(env)

next_obs = env.reset()
next_done = False
running_reward = 0

# Collect one batch at each eps value, continuing the same environment state.
for demo_eps in (1.0, 0.0):
    data, next_obs, next_done, running_reward = collect_data(
        env, next_obs, next_done, running_reward, demo_eps, Q)
    print(f"eps={demo_eps}: next_obs={next_obs}, next_done={next_done}, running_reward={running_reward}")

# Show the shape of each array instead of dumping every row of the batch.
print([column.shape for column in data])

env.close()


In [14]:
def train(log=False, print_debug=False):
    """Train a DQN agent on CartPole-v1.

    Each epoch:
    1. Collect ``batch_size`` steps with an epsilon-greedy policy on the
       online network ``Q``.
    2. Take one SGD step on the mean-squared TD error of that batch.

    TD targets come from the target network ``Q_frozen``, which is copied from
    ``Q`` every ``epochs_to_reset_Q`` epochs. Hyperparameters are read from the
    module-level values set in the Hyperparameters cell.

    Args:
        log: if True, log metrics (and gradients via ``wandb.watch``) to the
            Weights & Biases project "DQN", and finish the run at the end.
        print_debug: if True, print intermediate tensors every epoch.
    """
    if log:
        wandb.init(project="DQN")
    env = gym.make("CartPole-v1")
    Q = DiscreteModel(env)
    Q_frozen = DiscreteModel(env)
    # Start the target network as an exact copy of Q instead of an unrelated
    # random init. It is never optimized directly, so keep it out of autograd.
    Q_frozen.load_state_dict(Q.state_dict())
    Q_frozen.requires_grad_(False)

    Q.train()

    if log:
        wandb.watch(Q, log_freq=1)

    next_obs = env.reset()
    next_done = False

    optimizer = torch.optim.SGD(Q.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    running_reward = 0

    for epoch in range(epochs):
        if print_debug:
            print("*****EPOCH {}".format(epoch))
        eps = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * epoch)
        if log:
            # A single call so both values land on the same wandb step.
            wandb.log({"epsilon": eps, "epoch": epoch})

        # Standard DQN acts with the online network; Q_frozen only supplies targets.
        data, next_obs, next_done, running_reward = collect_data(
            env=env,
            next_obs=next_obs,
            next_done=next_done,
            running_reward=running_reward,
            eps=eps,
            Q=Q,
            log=log,
            print_debug=print_debug,
        )

        if print_debug:
            print("**TRAIN")

        # collect_data returns (N, 1) columns. Flatten them to (N,) so the
        # arithmetic below stays elementwise instead of broadcasting to (N, N).
        tensor_obs = torch.as_tensor(data[0], dtype=torch.float32)
        tensor_action = torch.as_tensor(data[1]).reshape(-1).long()
        tensor_next_obs = torch.as_tensor(data[2], dtype=torch.float32)
        tensor_reward = torch.as_tensor(data[3], dtype=torch.float32).reshape(-1)
        tensor_done = torch.as_tensor(data[4]).reshape(-1).bool()

        # TD target:
        #   terminal transitions:     r
        #   non-terminal transitions: r + gamma * max_a' Q_frozen(s', a')
        # The target is a constant for this update, so no gradient flows through it.
        with torch.no_grad():
            q_next = Q_frozen(tensor_next_obs)
            max_q_next = q_next.max(dim=1).values
            td_target = tensor_reward + gamma * max_q_next * (~tensor_done).float()

        # Q(s, a) for the action actually taken in each transition, shape (N,).
        q_with_a = Q(tensor_obs).gather(1, tensor_action.unsqueeze(1)).squeeze(1)

        if print_debug:
            print("max(q_next)")
            print(max_q_next)
            print("next obs")
            print(tensor_next_obs)
            print("q next")
            print(q_next)
            print("done")
            print(tensor_done)
            print("not done")
            print(~tensor_done)
            print("td_target")
            print(td_target)
            print("q with a")
            print(q_with_a)

        loss = loss_fn(q_with_a, td_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log:
            wandb.log({"loss": loss.item()})

        # Refresh the target network every epochs_to_reset_Q epochs.
        if epoch % epochs_to_reset_Q == 0:
            Q_frozen.load_state_dict(Q.state_dict())

    env.close()
    if log:
        wandb.finish()

In [None]:
# Full training run without Weights & Biases logging.
train(log=False, print_debug=False)

In [None]:
# Full training run with metrics logged to the Weights & Biases "DQN" project.
train(log=True, print_debug=False)