# CHAPTER 5
## Deep Q-Network

In [5]:
import gym
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np

import time
from datetime import datetime
from collections import deque
import sys

import altair as alt

import atari_wrappers as atari

In [25]:
def make_env(env_name, fire=True, frames_num=2, noop_num=30, skip_frames=True):
    
    env = gym.make(env_name)
    
    if skip_frames:
        env = atari.MaxAndSkipEnv(env) ## Return only every skip-th frame
        
    if fire:
        env = atari.FireResetEnv(env) ## Fire at the beggining
        
    env = atari.NoopResetEnv(env,noop_max=noop_num)
    env = atari.WarpFrame(env) ## Reshape image
    env = atari.FrameStack(env, frames_num) ## Stack last 2 frames
    
    return env

In [91]:
class QNet(Model):
    
    def __init__(self, h_layers, h_size, o_size, h_activation=tf.nn.relu, o_activation=None):
        
        super(QNet,self).__init__()
        self.conv_layer1 = Conv2D(filters=16, kernel_size=8, strides=4, padding='valid', activation='relu')
        self.conv_layer2 = Conv2D(filters=32, kernel_size=4, strides=2, padding='valid', activation='relu')
        self.conv_layer3 = Conv2D(filters=32, kernel_size=3, strides=1, padding='valid', activation='relu')
        
        self.flatten_layer = Flatten()
        
        self.hidden_layers = [Dense(h_size[i], activation=h_activation) for i in range(h_layers)]
        self.output_layer = Dense(o_size, activation=o_activation)
        
    def call(self,input_data):
        
        x = input_data
        
        x = self.conv_layer1(x)
        x = self.conv_layer2(x)
        x = self.conv_layer3(x)
        
        x = self.flatten_layer(x)
        
        for layer in self.hidden_layers:
            
            x = layer(x)
            
        return self.output_layer(x)

In [28]:
def scale_frames(frames):
    
    return np.array(frames, dtype=np.float32)/255.0

In [142]:
class ExperienceBuffer():
    
    def __init__(self,buffer_size):
        
        self.obs_buf = deque(maxlen=buffer_size)
        self.rew_buf = deque(maxlen=buffer_size)
        self.act_buf = deque(maxlen=buffer_size)
        self.next_obs_buf = deque(maxlen=buffer_size)
        self.done_buf = deque(maxlen=buffer_size)
        
    def add(self, obs, rew, act, next_obs, done):
        
        self.obs_buf.append(obs)
        self.rew_buf.append(rew)
        self.act_buf.append(act)
        self.next_obs_buf.append(obs2)
        self.done_buf.append(done)
        
    def samble_minibatch(self, batch_size):
        
        mb_indices = np.random.randint(len(self.obs_buf),size=batch_size)
        
        mb_obs = scale_frames([self.obs_buf[i] for i in mb_indices])
        mb_rew = [self.rew_buf[i] for i in mb_indices] 
        mb_act = [self.act_buf[i] for i in mb_indices]
        mb_next_obs = scale_frames([self.next_obs_buf[i] for i in mb_indices])
        mb_done = [self.done_buf[i] for i in mb_indices]
    
        return mb_obs, mb_rew, mb_act, mb_next_obs, mb_done
    
    def __len__(self):
        return len(self.obs_buf)
        

In [30]:
current_milli_time = lambda: int(round(time.time() * 1000))

In [48]:
def update_target(target_qv,online_qv):
    
    target_qv.set_weights(online_qv.get_weights())

In [136]:
def e_greedy(action_values,epsilon=0.1):
    
    if np.random.uniform(0,1) < epsilon:
        
        return np.random.randint(len(action_values))
    
    else:
        
        return np.argmax(action_values)

In [131]:
def DQN(env_name, hidden_layers =1, hidden_size=[32], alpha=1e-2, num_epochs=2000, buffer_size=100000, gamma=0.99,
        update_target_net=1000, batch_size=64, update_freq=4, frames_num=2, min_buffer_size=5000, test_frequency=20,
        start_exp=1, end_exp=0.1, exp_steps=100000):
    
    env = make_env(env_name, frames_num=frames_num, skip_frames=True, noop_num=20)
    env_test = make_env(env_name,frames_num=frames_num, skip_frames=True, noop_num=20)
    
    env_test = gym.wrappers.Monitor(env_test, "VIDEOS/TEST_VIDEOS"+env_name+str(current_milli_time()), force=True,
                                    video_callable=lambda x: x%20==0)
    
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.n
    
    target_qv = QNet(h_layers=hidden_layers, h_size=hidden_size, o_size=act_dim)
    online_qv = QNet(h_layers=hidden_layers, h_size=hidden_size, o_size=act_dim)
    
    obs = env.reset()
    obs = scale_frames(obs)
    
    _ = target_qv.predict(np.array([obs]))
    _ = online_qv.predict(np.array([obs]))
    
    update_target(target_qv,online_qv)
    
    #####################
    ### TENSORBOARD ##### --> Not implemented
    #####################
    
    step_count = 0
    last_update_loss = []
    ep_time = current_milli_time()
    batch_rew = []
    
    buffer = ExperienceBuffer(buffer_size)
    epsilon = start_exp
    eps_decay = (start_exp - end_exp)/exp_steps
    
    obs = env.reset()
    
    for epoch in range(num_epochs):
        
        game_reward = 0
        done = False
        
        while not done:
            
            obs_process = np.array([scale_frames(obs)])
            action_values = online_qv.predict(obs_process)[0]
            
            action = e_greedy(action_values, epsilon)
            next_obs, reward, done, _ = env.step(action) 
            buffer.add(obs, reward, action, next_obs, done)
            
            obs = next_obs
            game_reward += reward
            step_count += 1
            
            if epsilon > end_exp:
                epsilon -= eps_decay
                
            if len(buffer) > min_buffer_size and (step_count % update_freq == 0):
                
                mb_obs, mb_rew, mb_act, mb_next_obs, mb_done = buffer.sample_minibatch(batch_size)
                
                
                #mb_target_act = target_qv.predict
            
            done=True
            
    return action_values

In [132]:
t = DQN('PongNoFrameskip-v4')



In [133]:
t

array([[-0.17649953, -0.05435145,  0.08430061, -0.03263599,  0.11034723,
        -0.02487457]], dtype=float32)