[medium](https://medium.com/@jonathan_hui/rl-dqn-deep-q-network-e207751f7ae4) <br>
[github](https://github.com/udacity/deep-reinforcement-learning/tree/master/dqn) <br>
[paper](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf)

In [111]:
%reload_ext autoreload
%autoreload 2
import torch
import torch.nn.functional as F 
import random
import numpy as np
from EXITrl.approx_v_base import ApproxVBase
from EXITrl.approx_policy_base import ApproxPolicyBase
from EXITrl.base import Base
from EXITrl.helpers import update_params, ExperienceReplay, WeightDecay
from EXITrl.nn_wrapper import NNWrapper
import gym

In [112]:
class QNetwork(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
        self.linear3 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        return self.linear3(x)

In [128]:
class DQN(Base):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.local_q_network = NNWrapper(
            QNetwork(self.num_state, 64, self.num_action),
            lr=self.alpha
        )
        self.target_q_network = NNWrapper(
            QNetwork(self.num_state, 64, self.num_action),
            lr=0 # use manual update
        )
        update_params(self.local_q_network.model, self.target_q_network.model, tau=0)
        self.num_step = 0
    
    def initialize(self, num_step_to_learn, eps_start, eps_end, eps_decay, num_experience, num_recall):
        self.num_step_to_learn= num_step_to_learn
        self.epsilon_decay = WeightDecay(eps_start, eps_end, eps_decay)
        self.epsilon = self.epsilon_decay.step()
        self.experience_replay = ExperienceReplay(num_experience=num_experience, num_recall=num_recall)
        
    def policy(self, state):
        return self.local_q_network.epsilon_greedy(state, self.epsilon)
    
    def learn(self, state, action, reward, next_state, done):
        # detach because we only backprop local network and update target network weight manually
        targets_next_Q = self.target_q_network.forward(next_state).detach().max(1)[0]
        targets_Q = reward + (self.gamma * targets_next_Q * (1 - done))

        local_Q = self.local_q_network.forward(state)
        expected_Q = local_Q.gather(1, action.unsqueeze(1).long()).squeeze(1)

        loss = F.mse_loss(expected_Q, targets_Q)
        self.local_q_network.backprop(loss)

        update_params(self.local_q_network.model, self.target_q_network.model, self.tau)

    def _loop(self, episode) -> int:
        total_reward = 0
        state = self.env.reset()
        for i in range(1000):
            action = self.policy(state)
            for i in range(4):
                _state, reward, done, _ = self.env.step(action)
                if done: break
            self.experience_replay.remember(state, action, reward, _state, done)
            
            self.num_step += 1
            if self.num_step%self.num_step_to_learn == 0:
                experiences = self.experience_replay.recall()
                self.learn(*experiences)
            state = _state
            
            total_reward += reward
            if done: break
        self.epsilon = self.epsilon_decay.step()
        return total_reward
    
    def _save(self, reward):
        torch.save(self.local_q_network.model.state_dict(), self.save_name)
    def _load(self):
        self.epsilon = 0
        self.local_q_network.model.load_state_dict(torch.load(self.save_name, map_location=device))
        self.local_q_network.model.eval()


In [121]:
try: env.close()
except: pass
env = gym.make('LunarLander-v2')
dqn = DQN(env, 
      num_mean_episode=100,
      num_episodes=2000,
      alpha=5e-4, 
      gamma=.99,
      tau=1e-3,
      save_name="checkpoint/LunarLander-v2-DQN.pth")
dqn.initialize(num_step_to_learn=4, 
               eps_start=1, 
               eps_end=.01, 
               eps_decay=.995, 
               num_experience=2024, 
               num_recall=512)
dqn.train(early_stop=lambda mean_reward: mean_reward>200)
# dqn.play()

Episode 100	Average Score: -154.91
Episode 200	Average Score: -89.525
Episode 300	Average Score: -29.54
Episode 400	Average Score: 40.573
Episode 500	Average Score: 177.55
Episode 522	Average Score: 200.56--- early stop ----
 current_mean_reward: 200.56308424718296 num_mean_episode: 100


In [122]:
try: env.close()
except: pass
env = gym.make('Breakout-ram-v0')
dqn = DQN(env, 
      num_mean_episode=100,
      num_episodes=2000,
      alpha=5e-4, 
      gamma=.99,
      tau=1e-3,
      save_name="checkpoint/Breakout-ram-v0-DQN.pth")
dqn.initialize(num_step_to_learn=4, 
               eps_start=1, 
               eps_end=.01, 
               eps_decay=.995, 
               num_experience=2024, 
               num_recall=512)
dqn.train(early_stop=lambda mean_reward: mean_reward>200)
# dqn.play()

Episode 100	Average Score: 1.33
Episode 200	Average Score: 1.21
Episode 300	Average Score: 0.91
Episode 400	Average Score: 1.33
Episode 500	Average Score: 2.20
Episode 600	Average Score: 2.40
Episode 700	Average Score: 3.04
Episode 800	Average Score: 3.14
Episode 900	Average Score: 3.54
Episode 1000	Average Score: 3.84
Episode 1100	Average Score: 4.02
Episode 1200	Average Score: 4.38
Episode 1300	Average Score: 4.48
Episode 1400	Average Score: 4.56
Episode 1500	Average Score: 4.31
Episode 1519	Average Score: 4.43

KeyboardInterrupt: 

In [129]:
try: env.close()
except: pass
env = gym.make('Breakout-ram-v0')
dqn = DQN(env, 
      num_mean_episode=100,
      num_episodes=20000,
      alpha=5e-4, 
      gamma=.99,
      tau=1e-3,
      save_name="checkpoint/Breakout-ram-v0-DQN.pth")
dqn.initialize(num_step_to_learn=4, 
               eps_start=1, 
               eps_end=.01, 
               eps_decay=.995, 
               num_experience=2024, 
               num_recall=512)
dqn.train(early_stop=lambda mean_reward: mean_reward>200)
# dqn.play()

Episode 8	Average Score: 0.62

KeyboardInterrupt: 