<a href="https://colab.research.google.com/github/dude123studios/AdvancedReinforcementLearning/blob/main/Deep%20Q%20networks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import random
import gym
import numpy as np
from collections import deque 
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.optimizers import Adam
import tensorflow as tf

In [None]:
# Shape of one preprocessed frame fed to the network (preprocess_state
# reshapes each frame to 88 x 80 with a single channel).
state_size = (88, 80, 1)

env = gym.make('MsPacman-v0')
# One Q-value output per discrete action the environment exposes.
action_size = env.action_space.n

# Gray level (mean of this RGB triple) that preprocess_state blanks out to
# raise contrast.
color = np.array([120, 164, 74]).mean()


def preprocess_state(state):
    """Turn a raw RGB frame into a normalized (1, 88, 80, 1) network input.

    Steps: take every other row in 1..175 and every other column, average the
    RGB channels to grayscale, set pixels whose gray level matches ``color``
    to 0, then rescale from [0, 255] to [-1, 1).

    Args:
        state: RGB frame indexed as (row, column, channel). The final reshape
            requires the crop to come out as 88 x 80 (e.g. a 210 x 160 frame).

    Returns:
        A float array of shape (1, 88, 80, 1): batch axis, rows, cols, channel.
    """
    # Crop + downsample: rows 1, 3, ..., 175 (88 rows) and every other column.
    image = state[1:176:2, ::2]

    # Grayscale by channel mean. This allocates a new float array, so the
    # in-place edit below never touches the caller's frame.
    image = image.mean(axis=2)

    # Raise contrast by blanking the `color` gray level. Both sides are float
    # means, so compare with a tolerance rather than exact equality.
    image[np.isclose(image, color)] = 0

    # Centre on zero: [0, 255] -> [-1, 1).
    image = (image - 128) / 128

    # Add batch and channel axes.
    return image.reshape(1, 88, 80, 1)

In [None]:
class DQN:
    """Deep Q-network agent with experience replay and a target network.

    Args:
        state_size: Shape of one preprocessed observation, used as the
            network's ``input_shape``.
        action_size: Number of discrete actions; the network outputs one
            Q-value per action.
        gamma: Discount factor in the TD target.
        epsilon: Probability that ``epsilon_greedy`` picks a random action.
            Nothing in this class decays it.
        buffer_size: Maximum number of stored transitions; the oldest are
            dropped first once the buffer is full.
        update_rate: Interval at which the caller should call
            ``update_target_network``. It is only stored here; the training
            loop decides what unit it counts in.
    """

    def __init__(self, state_size, action_size, *, gamma=0.9, epsilon=0.8,
                 buffer_size=5000, update_rate=5):
        self.state_size = state_size
        self.action_size = action_size

        self.replay_buffer = deque(maxlen=buffer_size)

        self.gamma = gamma
        self.epsilon = epsilon
        self.update_rate = update_rate

        self.main_network = self.build_network()

        # clone_model copies the architecture but creates freshly initialized
        # weights, so copy the main network's weights explicitly: the target
        # network must start out identical to the main network.
        self.target_network = tf.keras.models.clone_model(self.main_network)
        self.update_target_network()

    def build_network(self):
        """Build and compile the Q-network.

        Three ReLU conv layers, a 512-unit ReLU dense layer, and a linear
        output with one Q-value per action; compiled with MSE loss and Adam.
        """
        model = Sequential([
            Conv2D(32, (8, 8), strides=4, padding='same', input_shape=self.state_size),
            Activation('relu'),

            Conv2D(64, (4, 4), strides=2, padding='same'),
            Activation('relu'),

            Conv2D(64, (3, 3), strides=1, padding='same'),
            Activation('relu'),

            Flatten(),

            Dense(512, activation='relu'),
            Dense(self.action_size, activation='linear')
        ])

        model.compile(loss='mse', optimizer=Adam())

        return model

    def store_transition(self, state, action, reward, next_state, done):
        """Append one (state, action, reward, next_state, done) transition."""
        self.replay_buffer.append((state, action, reward, next_state, done))

    def epsilon_greedy(self, state):
        """Choose an action for ``state`` (a batch holding one observation).

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the highest predicted Q-value.
        """
        if random.uniform(0, 1) < self.epsilon:
            return np.random.randint(self.action_size)

        # verbose=0: avoid printing a progress bar on every environment step.
        q_values = self.main_network.predict(state, verbose=0)
        return np.argmax(q_values[0])

    def train(self, batch_size):
        """Fit the main network once on a random minibatch of transitions.

        Both networks are evaluated with one batched ``predict`` each and the
        main network is fitted in a single gradient step. The previous
        implementation made a predict/predict/fit round-trip per transition.
        """
        minibatch = random.sample(self.replay_buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)

        # Assumes each stored state already has a leading batch axis of size 1
        # (as preprocess_state produces), so concatenating on axis 0 gives
        # (batch_size, *state_size).
        states = np.concatenate(states, axis=0)
        next_states = np.concatenate(next_states, axis=0)

        q_values = self.main_network.predict(states, verbose=0)
        next_q_values = self.target_network.predict(next_states, verbose=0)

        targets = self._compute_targets(
            q_values, next_q_values, actions, rewards, dones, self.gamma)

        # batch_size=len(targets) keeps this a single gradient step even for
        # minibatches larger than Keras' default fit batch size.
        self.main_network.fit(states, targets, batch_size=len(targets),
                              epochs=1, verbose=0)

    @staticmethod
    def _compute_targets(q_values, next_q_values, actions, rewards, dones, gamma):
        """Return a copy of ``q_values`` with each taken action set to its TD target.

        The target is ``reward`` for terminal transitions and
        ``reward + gamma * max(next_q_values[row])`` otherwise. All other
        actions keep the network's own prediction, so they contribute zero
        error under the MSE loss.
        """
        targets = np.array(q_values, copy=True)
        rewards = np.asarray(rewards, dtype=np.float64)
        dones = np.asarray(dones, dtype=bool)

        best_next = np.max(next_q_values, axis=1)
        td_targets = np.where(dones, rewards, rewards + gamma * best_next)

        targets[np.arange(len(targets)), np.asarray(actions)] = td_targets
        return targets

    def update_target_network(self):
        """Copy the main network's weights into the target network."""
        self.target_network.set_weights(self.main_network.get_weights())

In [None]:
# --- Training hyper-parameters ---------------------------------------------
num_episodes = 200      # number of episodes to play
num_timesteps = 2000    # cap on environment steps per episode
batch_size = 8          # transitions sampled per call to dqn.train

# NOTE(review): num_screens is not read anywhere in the visible code; it looks
# like a placeholder for frame stacking that was never wired in — confirm.
num_screens = 4

# --- Agent and loop state --------------------------------------------------
dqn = DQN(state_size, action_size)

# Incremented once per episode by the training loop; drives target syncs.
time_step = 0

# Reassigned by env.step inside the training loop before it is read.
done = False

In [None]:
# Run one minibatch update every `train_every` environment steps. A single
# update per episode (hundreds of frames) is far too few for the network to
# learn from the replay buffer.
train_every = 4

for i in range(num_episodes):

    Return = 0

    state = preprocess_state(env.reset())

    # time_step counts episodes: the target network is synced every
    # dqn.update_rate episodes, after the inner loop below.
    time_step += 1

    for t in range(num_timesteps):

        action = dqn.epsilon_greedy(state)

        next_state, reward, done, _ = env.step(action)
        next_state = preprocess_state(next_state)

        dqn.store_transition(state, action, reward, next_state, done)

        state = next_state
        Return += reward

        # Learn from replayed experience during the episode, once the buffer
        # holds enough transitions to sample a minibatch.
        if t % train_every == 0 and len(dqn.replay_buffer) > batch_size:
            dqn.train(batch_size)

        if done:
            break

    # Report every episode, including ones cut off at num_timesteps without
    # reaching a terminal state.
    print('Episode: ', i, ', ' 'Return', Return)

    if time_step % dqn.update_rate == 0:
        dqn.update_target_network()

Episode:  0 , Return 290.0
Episode:  1 , Return 250.0
Episode:  2 , Return 590.0
Episode:  3 , Return 200.0
Episode:  4 , Return 240.0
Episode:  5 , Return 140.0
Episode:  6 , Return 540.0
Episode:  7 , Return 350.0
Episode:  8 , Return 200.0
Episode:  9 , Return 430.0
Episode:  10 , Return 180.0
Episode:  11 , Return 220.0
Episode:  12 , Return 280.0
Episode:  13 , Return 460.0
Episode:  14 , Return 240.0
Episode:  15 , Return 320.0
Episode:  16 , Return 260.0
Episode:  17 , Return 230.0
Episode:  18 , Return 190.0
Episode:  19 , Return 320.0
Episode:  20 , Return 160.0
Episode:  21 , Return 260.0
Episode:  22 , Return 360.0
Episode:  23 , Return 180.0
Episode:  24 , Return 300.0
Episode:  25 , Return 730.0
Episode:  26 , Return 220.0
Episode:  27 , Return 380.0
Episode:  28 , Return 230.0
Episode:  29 , Return 180.0
Episode:  30 , Return 310.0
Episode:  31 , Return 180.0
Episode:  32 , Return 260.0
Episode:  33 , Return 200.0
Episode:  34 , Return 280.0
Episode:  35 , Return 450.0
Ep

KeyboardInterrupt: ignored