### Trying to solve Atari with DQN
- Reference implementation: [transedward/pytorch-dqn](https://github.com/transedward/pytorch-dqn)

In [5]:
%reload_ext autoreload
%autoreload 2
import torch
import torch.nn.functional as F 
import random
import numpy as np
from EXITrl.approx_v_base import ApproxVBase
from EXITrl.approx_policy_base import ApproxPolicyBase
from EXITrl.base import Base
from EXITrl.helpers import update_params, ExperienceReplay, WeightDecay, device
from EXITrl.nn_wrapper import NNWrapper
from env.atari_wrapper import wrap_deepmind_ram
import gym

In [6]:
class QNetwork(torch.nn.Module):
    """Fully connected Q-network: two ReLU hidden layers and a linear output head.

    Args:
        input_size: size of the last dimension of the state tensor.
        hidden_size: width shared by both hidden layers.
        output_size: number of outputs; DQN passes the number of actions,
            giving one Q-value per action.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        linear = torch.nn.Linear
        # Attribute names (linear1..linear3) define the state_dict keys used by
        # saved checkpoints, and creation order fixes the seeded initialisation,
        # so both are kept as-is.
        self.linear1 = linear(input_size, hidden_size)
        self.linear2 = linear(hidden_size, hidden_size)
        self.linear3 = linear(hidden_size, output_size)

    def forward(self, state):
        """Return raw Q-values; the output layer has no activation."""
        features = state
        for hidden_layer in (self.linear1, self.linear2):
            features = F.relu(hidden_layer(features))
        return self.linear3(features)

In [7]:
class DQN(Base):
    """Deep Q-Network agent: epsilon-greedy acting, experience replay, target network.

    Constructor arguments are forwarded unchanged to ``Base``. This class reads
    ``self.num_state``, ``self.num_action``, ``self.alpha``, ``self.gamma``,
    ``self.tau``, ``self.env``, ``self.save_name`` and ``self.additional_log``,
    which are presumably set up by ``Base`` (not visible here -- confirm in
    ``EXITrl.base``). Call :meth:`initialize` before training.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Online network: the only network trained by backprop.
        self.local_q_network = NNWrapper(
            QNetwork(self.num_state, 32, self.num_action),
            lr=self.alpha
        )
        # Target network: lr=0 because its weights are only changed manually
        # through update_params, never by its own optimizer.
        self.target_q_network = NNWrapper(
            QNetwork(self.num_state, 32, self.num_action),
            lr=0 # use manual update
        )
        # NOTE(review): presumably initialises the target from the local network;
        # confirm the tau convention of EXITrl.helpers.update_params.
        update_params(self.local_q_network.model, self.target_q_network.model, tau=0)
        # Agent steps taken across all episodes; drives the learning cadence.
        self.num_step = 0

    def initialize(self, num_step_to_learn, eps_start, eps_end, eps_decay, num_experience, num_recall, skip_frame=1, max_steps_per_episode=1000):
        """Configure exploration, replay memory and the acting loop.

        Args:
            num_step_to_learn: call :meth:`learn` once every this many agent steps.
            eps_start, eps_end, eps_decay: forwarded to ``WeightDecay`` for the
                epsilon-greedy schedule. The schedule is stepped once here and
                once at the end of every episode.
            num_experience: forwarded to ``ExperienceReplay(num_experience=...)``.
            num_recall: forwarded to ``ExperienceReplay(num_recall=...)``.
            skip_frame: number of env steps each chosen action is repeated for.
            max_steps_per_episode: cap on agent steps per episode (defaults to
                1000, the value previously hard-coded in ``_loop``).

        Raises:
            ValueError: if ``num_step_to_learn``, ``skip_frame`` or
                ``max_steps_per_episode`` is smaller than 1.
        """
        # Validate up front: skip_frame < 1 would leave the next state and reward
        # unset in _loop, and num_step_to_learn == 0 would divide by zero there.
        for name, value in (("num_step_to_learn", num_step_to_learn),
                            ("skip_frame", skip_frame),
                            ("max_steps_per_episode", max_steps_per_episode)):
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value}")
        self.num_step_to_learn = num_step_to_learn
        self.epsilon_decay = WeightDecay(eps_start, eps_end, eps_decay)
        self.epsilon = self.epsilon_decay.step()
        self.experience_replay = ExperienceReplay(num_experience=num_experience, num_recall=num_recall)
        self.skip_frame = skip_frame
        self.max_steps_per_episode = max_steps_per_episode

    def policy(self, state):
        """Choose an action epsilon-greedily from the local Q-network."""
        return self.local_q_network.epsilon_greedy(state, self.epsilon)

    def learn(self, state, action, reward, next_state, done):
        """Run one TD update of the local network on a batch, then update the target.

        Assumes batched tensors from ``ExperienceReplay.recall()``, with the
        batch on dim 0 and ``done`` holding 0/1 numbers -- TODO confirm against
        EXITrl.helpers. The regression target is
        ``reward + gamma * max_a Q_target(next_state, a) * (1 - done)``.
        """
        # detach because we only backprop local network and update target network weight manually
        targets_next_Q = self.target_q_network.forward(next_state).detach().max(1)[0]
        targets_Q = reward + (self.gamma * targets_next_Q * (1 - done))

        local_Q = self.local_q_network.forward(state)
        expected_Q = local_Q.gather(1, action.unsqueeze(1).long()).squeeze(1) # select Q from action

        # Huber loss
        loss = F.smooth_l1_loss(expected_Q, targets_Q) # loss = F.mse_loss(expected_Q, targets_Q)
        self.local_q_network.backprop(loss)

        update_params(self.local_q_network.model, self.target_q_network.model, self.tau)

    def _loop(self, episode) -> int:
        """Play one episode, learning along the way, and return its total reward.

        Each chosen action is repeated ``skip_frame`` times. The rewards of all
        repeated frames are summed, so no reward is dropped when
        ``skip_frame > 1``. ``episode`` is not used here; it is presumably
        supplied by ``Base``'s training loop.
        """
        total_reward = 0
        state = self.env.reset()
        for _step in range(self.max_steps_per_episode):
            action = self.policy(state)

            # Repeat the action, accumulating reward over every skipped frame.
            # Previously only the last frame's reward was kept.
            reward = 0
            for _frame in range(self.skip_frame):
                next_state, frame_reward, done, _info = self.env.step(action)
                reward += frame_reward
                if done: break
            self.experience_replay.remember(state, action, reward, next_state, done)

            self.num_step += 1
            if self.num_step % self.num_step_to_learn == 0:
                experiences = self.experience_replay.recall()
                self.learn(*experiences)
            state = next_state

            total_reward += reward
            self.additional_log['num_step'] = self.num_step
            self.additional_log['epsilon'] = self.epsilon_decay.val
            if done: break
        # Epsilon decays once per episode, not per step.
        self.epsilon = self.epsilon_decay.step()
        return total_reward

    def _save(self, reward):
        """Save the local Q-network weights to ``self.save_name`` (``reward`` is unused)."""
        torch.save(self.local_q_network.model.state_dict(), self.save_name)

    def _load(self):
        """Load local Q-network weights from ``self.save_name`` for evaluation."""
        # epsilon = 0 so epsilon_greedy presumably always exploits.
        self.epsilon = 0
        # torch.load unpickles the file: only load checkpoints you trust.
        self.local_q_network.model.load_state_dict(torch.load(self.save_name, map_location=device))
        self.local_q_network.model.eval()


### Pong

In [14]:
# Close the env left over from a previous run of this cell; on a fresh kernel
# `env` does not exist yet, which is the only error we expect here.
try:
    env.close()
except NameError:
    pass

# A Pong episode's score is bounded by +/-21, so the previous early-stop
# threshold of 200 could never trigger. 19 is a common "solved" bar for Pong.
PONG_SOLVED_MEAN_REWARD = 19

env = wrap_deepmind_ram(gym.make('Pong-ram-v0'))
dqn = DQN(env,
          num_mean_episode=100,
          num_episodes=int(1e6),
          alpha=5e-4,
          gamma=.99,
          tau=1,
          save_name="checkpoint/Pong-ram-v0-DQN.pth")
dqn.initialize(num_step_to_learn=3000,
               eps_start=1,
               eps_end=.01,
               eps_decay=.999,
               num_experience=20480,
               num_recall=1024,
               skip_frame=1)
dqn.train(early_stop=lambda mean_reward: mean_reward > PONG_SOLVED_MEAN_REWARD)
# dqn.play()

  result = entry_point.load(False)


Episode 100	Average Score: -20.62 	other{'num_step': 28868, 'epsilon': 0.9047921471137096}
Episode 200	Average Score: -20.59 	other{'num_step': 57747, 'epsilon': 0.818648829478636}
Episode 300	Average Score: -20.75 	other{'num_step': 85473, 'epsilon': 0.7407070321560997}
Episode 400	Average Score: -20.79 	other{'num_step': 113516, 'epsilon': 0.6701859060067403}
Episode 469	Average Score: -20.80 	other{'num_step': 132711, 'epsilon': 0.6254807928315229}

KeyboardInterrupt: 

### Pong