[medium](https://medium.com/@jonathan_hui/rl-dqn-deep-q-network-e207751f7ae4) <br>
[github](https://github.com/udacity/deep-reinforcement-learning/tree/master/dqn) <br>
[paper](https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf)

In [2]:
%reload_ext autoreload
%autoreload 2
import torch
import torch.nn.functional as F 
import random
import numpy as np
from EXITrl.approx_v_base import ApproxVBase
from EXITrl.approx_policy_base import ApproxPolicyBase
from EXITrl.base import Base
from EXITrl.helpers import update_params, ExperienceReplay, WeightDecay, device
from EXITrl.nn_wrapper import NNWrapper
import gym

In [3]:
class QNetwork(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
        self.linear3 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        return self.linear3(x)

In [7]:
class DQN(Base):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.local_q_network = NNWrapper(
            QNetwork(self.num_state, 64, self.num_action),
            lr=self.alpha
        )
        self.target_q_network = NNWrapper(
            QNetwork(self.num_state, 64, self.num_action),
            lr=0 # use manual update
        )
        update_params(self.local_q_network.model, self.target_q_network.model, tau=0)
        self.num_step = 0
    
    def initialize(self, num_step_to_learn, eps_start, eps_end, eps_decay, num_experience, num_recall, skip_frame=1):
        self.num_step_to_learn= num_step_to_learn
        self.epsilon_decay = WeightDecay(eps_start, eps_end, eps_decay)
        self.epsilon = self.epsilon_decay.step()
        self.experience_replay = ExperienceReplay(num_experience=num_experience, num_recall=num_recall)
        self.skip_frame = skip_frame
        
    def policy(self, state):
        return self.local_q_network.epsilon_greedy(state, self.epsilon)
    
    def learn(self, state, action, reward, next_state, done):
        # detach because we only backprop local network and update target network weight manually
        targets_next_Q = self.target_q_network.forward(next_state).detach().max(1)[0]
        targets_Q = reward + (self.gamma * targets_next_Q * (1 - done))

        local_Q = self.local_q_network.forward(state)
        expected_Q = local_Q.gather(1, action.unsqueeze(1).long()).squeeze(1)

        loss = F.mse_loss(expected_Q, targets_Q)
        self.local_q_network.backprop(loss)

        update_params(self.local_q_network.model, self.target_q_network.model, self.tau)

    def _loop(self, episode) -> int:
        total_reward = 0
        state = self.env.reset()
        for i in range(1000):
            action = self.policy(state)
            for _ in range(self.skip_frame):
                _state, reward, done, _ = self.env.step(action)
                if done: break
            self.experience_replay.remember(state, action, reward, _state, done)
            
            self.num_step += 1
            if self.num_step%self.num_step_to_learn == 0:
                experiences = self.experience_replay.recall()
                self.learn(*experiences)
            state = _state
            
            total_reward += reward
            if done: break
        self.epsilon = self.epsilon_decay.step()
        return total_reward
    
    def _save(self, reward):
        torch.save(self.local_q_network.model.state_dict(), self.save_name)
    def _load(self):
        self.epsilon = 0
        self.local_q_network.model.load_state_dict(torch.load(self.save_name, map_location=device))
        self.local_q_network.model.eval()


In [8]:
try: env.close()
except: pass
env = gym.make('Breakout-ram-v0')
dqn = DQN(env, 
      num_mean_episode=100,
      num_episodes=int(1e6),
      alpha=5e-4, 
      gamma=.95,
      tau=1e-3,
      save_name="checkpoint/Breakout-ram-v0-DQN.pth")
dqn.initialize(num_step_to_learn=4, 
               eps_start=1, 
               eps_end=.01, 
               eps_decay=.995, 
               num_experience=10240, 
               num_recall=1024,
               skip_frame=1)
# dqn.train(early_stop=lambda mean_reward: mean_reward>200)
dqn.play()

  result = entry_point.load(False)


--- self.epsilon_decay 0.995
-----play: 0.995
self.epsilon: 0.995
action_values: tensor([16.9155, 17.4139, 16.5009, 18.3130]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.4783, 17.4141, 16.4131, 18.4886]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.6547, 17.3095, 16.3998, 18.2878]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.3674, 17.3345, 16.3595, 18.4534]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.4529, 17.2517, 16.3311, 18.3050]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.7538, 18.4185, 16.4606, 18.1269]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.8679, 18.3674, 16.4563, 18.0328]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.8516, 18.3541, 16.4698, 18.0428]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.7299, 18.3596, 16.4424, 18.0964]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.6068, 18.3236, 16.4146, 18.0857]) epsilon: 0.995
self

self.epsilon: 0.995
action_values: tensor([13.9127, 14.3542, 14.4485, 15.2595]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([13.7161, 14.4338, 14.4813, 15.4216]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([13.7060, 14.4321, 14.5236, 15.4228]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([13.7032, 14.4381, 14.5635, 15.4187]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([13.7068, 14.4407, 14.6157, 15.4114]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([13.7191, 14.4368, 14.6670, 15.4021]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([13.7314, 14.4329, 14.7183, 15.3929]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([13.7406, 14.4300, 14.7568, 15.3859]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([13.9549, 14.3504, 14.8053, 15.2081]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([13.9947, 14.3552, 14.8705, 15.2053]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.0008

action_values: tensor([16.8531, 14.6295, 15.9222, 14.5475]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.6616, 14.6549, 15.8540, 14.6851]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.4600, 14.6261, 15.7262, 14.7472]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.6014, 14.5498, 15.7901, 14.5667]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.8808, 14.6104, 16.0473, 14.5170]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.9046, 14.6120, 16.1081, 14.5114]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.9115, 14.6097, 16.1329, 14.5066]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.9184, 14.6074, 16.1577, 14.5018]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.9321, 14.6028, 16.2073, 14.4922]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.8213, 14.6302, 16.1805, 14.5907]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.7777, 14.5966, 16.1492, 

action_values: tensor([15.3201, 16.3206, 15.7193, 17.5841]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.2988, 16.3146, 15.7135, 17.5879]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.2562, 16.3025, 15.7020, 17.5956]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.2349, 16.2965, 15.6963, 17.5994]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.1923, 16.2844, 15.6848, 17.6071]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.1498, 16.2723, 15.6733, 17.6148]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.1072, 16.2603, 15.6618, 17.6225]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.8879, 15.9336, 15.5412, 17.1373]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.7618, 15.1378, 15.4974, 14.8479]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.0369, 15.4566, 15.7820, 15.1293]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.0669, 15.4257, 15.8291, 

self.epsilon: 0.995
action_values: tensor([16.7745, 13.7376, 15.6192, 13.1512]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.8293, 13.7815, 15.7456, 13.1832]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.8537, 13.7907, 15.7976, 13.1871]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.8748, 13.7833, 15.8467, 13.1766]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.8958, 13.7758, 15.8958, 13.1660]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.5841, 14.0944, 16.5000, 12.6105]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.0362, 14.3495, 16.0757, 13.1152]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.0875, 14.3732, 16.1197, 13.1031]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.4431, 14.4336, 16.3407, 12.9494]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.6571, 14.3832, 15.8746, 13.3147]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.4073

action_values: tensor([16.1353, 18.4565, 16.5731, 18.0286]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.8410, 18.4765, 16.5318, 18.1647]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.5292, 18.4391, 16.4093, 18.2004]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.6422, 18.3550, 16.4249, 18.0519]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.0520, 18.3549, 16.5550, 17.9266]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.5069, 18.3957, 16.8085, 17.8162]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.2687, 18.3926, 16.6810, 17.9404]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.9666, 14.9465, 14.2638, 14.5778]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.6874, 15.0645, 14.1994, 14.7072]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.4156, 15.1479, 14.1067, 14.7423]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.3207, 15.1941, 14.0987, 

action_values: tensor([14.8274, 14.1480, 14.1718, 15.0741]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.8085, 14.1291, 14.1565, 15.0875]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.7708, 14.0914, 14.1260, 15.1144]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.1221, 14.4208, 14.1752, 15.4658]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.5111, 14.1300, 14.0797, 15.2990]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.4733, 14.0923, 14.0492, 15.3259]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.8291, 14.4174, 14.0984, 15.6832]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.8095, 14.3992, 14.0832, 15.6957]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.3979, 14.0169, 13.9881, 15.3797]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.2909, 13.6203, 13.9060, 14.2151]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([14.6016, 13.4253, 13.9295, 

self.epsilon: 0.995
action_values: tensor([15.0539, 12.5446, 13.8064, 14.0243]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.0824, 12.5305, 13.8450, 14.0393]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.3891, 12.4469, 13.8726, 13.9434]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.6176, 12.3943, 13.9357, 13.9843]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.6365, 12.3815, 13.9721, 13.9936]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.7909, 12.3415, 14.0009, 13.9229]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([15.8763, 12.3621, 14.0697, 13.9356]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.1036, 12.3463, 14.1311, 13.8429]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.2189, 12.3739, 14.2268, 13.8593]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.7244, 12.3553, 14.3790, 13.6844]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.8479

self.epsilon: 0.995
action_values: tensor([17.9822, 14.1922, 15.6022, 13.9039]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([17.7626, 13.9296, 14.6611, 12.7759]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([17.5763, 13.9106, 14.5412, 12.8501]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([17.5107, 13.8615, 14.4435, 12.8347]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([17.2540, 13.9064, 14.9278, 12.7959]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([17.2313, 13.9217, 14.9207, 12.8079]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.8696, 13.9267, 14.7261, 12.9771]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.7133, 13.8950, 14.6004, 13.0192]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([16.9332, 13.9077, 14.6917, 12.8878]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([17.2783, 13.9608, 14.3238, 12.9275]) epsilon: 0.995
self.epsilon: 0.995
action_values: tensor([17.2704

### LunarLander

In [None]:
try: env.close()
except: pass
env = gym.make('LunarLander-v2')
dqn = DQN(env, 
      num_mean_episode=100,
      num_episodes=2000,
      alpha=5e-4, 
      gamma=.99,
      tau=1e-3,
      save_name="checkpoint/LunarLander-v2-DQN.pth")
dqn.initialize(num_step_to_learn=4, 
               eps_start=1, 
               eps_end=.01, 
               eps_decay=.995, 
               num_experience=2048, 
               num_recall=512,
               skip_frame=1)
dqn.train(early_stop=lambda mean_reward: mean_reward>200)
# dqn.play()

### Skip frame

In [143]:
try: env.close()
except: pass
env = gym.make('LunarLander-v2')
dqn = DQN(env, 
      num_mean_episode=100,
      num_episodes=2000,
      alpha=5e-4, 
      gamma=.95,
      tau=1e-3,
      save_name="checkpoint/LunarLander-v2-DQN.pth")
dqn.initialize(num_step_to_learn=4, 
               eps_start=1, 
               eps_end=.01, 
               eps_decay=.995, 
               num_experience=2048, 
               num_recall=32,
               skip_frame=3)
dqn.train(early_stop=lambda mean_reward: mean_reward>200)
# dqn.play()

Episode 100	Average Score: -148.99
Episode 200	Average Score: -142.11
Episode 300	Average Score: -149.07
Episode 400	Average Score: -150.33
Episode 500	Average Score: -166.24
Episode 600	Average Score: -175.00
Episode 700	Average Score: -155.82
Episode 800	Average Score: -142.49
Episode 900	Average Score: -95.593
Episode 1000	Average Score: -10.25
Episode 1100	Average Score: 32.34
Episode 1200	Average Score: 36.93
Episode 1300	Average Score: 31.62
Episode 1400	Average Score: 53.91
Episode 1500	Average Score: 56.46
Episode 1600	Average Score: 70.57
Episode 1700	Average Score: 54.78
Episode 1800	Average Score: 103.50
Episode 1900	Average Score: 102.85
Episode 2000	Average Score: 85.092


### Breakout

In [None]:
try: env.close()
except: pass
env = gym.make('Breakout-ram-v0')
dqn = DQN(env, 
      num_mean_episode=100,
      num_episodes=int(1e6),
      alpha=5e-4, 
      gamma=.95,
      tau=1e-3,
      save_name="checkpoint/Breakout-ram-v0-DQN.pth")
dqn.initialize(num_step_to_learn=4, 
               eps_start=1, 
               eps_end=.01, 
               eps_decay=.995, 
               num_experience=10240, 
               num_recall=1024,
               skip_frame=1)
dqn.train(early_stop=lambda mean_reward: mean_reward>200)
# dqn.play()

Episode 100	Average Score: 1.15
Episode 200	Average Score: 1.13
Episode 300	Average Score: 1.60
Episode 400	Average Score: 1.85
Episode 500	Average Score: 2.60
Episode 600	Average Score: 3.22
Episode 700	Average Score: 3.52
Episode 800	Average Score: 3.44
Episode 900	Average Score: 3.72
Episode 1000	Average Score: 3.96
Episode 1100	Average Score: 4.04
Episode 1200	Average Score: 3.53
Episode 1300	Average Score: 3.88
Episode 1390	Average Score: 4.69

### Pong

In [145]:
try: env.close()
except: pass
env = gym.make('Pong-ram-v0')
dqn = DQN(env, 
      num_mean_episode=100,
      num_episodes=int(1e6),
      alpha=5e-4, 
      gamma=.99,
      tau=1e-3,
      save_name="checkpoint/Pong-ram-v0-DQN.pth")
dqn.initialize(num_step_to_learn=4, 
               eps_start=1, 
               eps_end=.01, 
               eps_decay=.995, 
               num_experience=2048, 
               num_recall=64,
               skip_frame=2)
dqn.train(early_stop=lambda mean_reward: mean_reward>200)
# dqn.play()

Episode 100	Average Score: -10.37
Episode 200	Average Score: -10.57
Episode 300	Average Score: -10.79
Episode 366	Average Score: -10.72

KeyboardInterrupt: 