Some of the code was inspired or cloned from these githubs👏

https://github.com/adventuresinML/adventures-in-ml-code/blob/master/dueling_q_tensorflow2.py
https://simoninithomas.github.io/Deep_reinforcement_learning_Course/   

In [1]:
import gym
import tensorflow as tf
import random
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
import cv2

from tensorflow.keras.models import Model
from tensorflow.keras.layers import *

from memory import replay_buffer, frame_stack
from models import agent, encoder, decoder

In [2]:
MAX_EPSILON = 1
MIN_EPSILON = 0.01
EPSILON_MIN_ITER = 5000
DELAY_TRAINING = 300
GAMMA = 0.95
BATCH_SIZE = 32
TAU = 0.08
RANDOM_REWARD_STD = 1.0

env = gym.make("Breakout-v0")

In [3]:
class DQNAgent:
    def __init__(self, env):
        
        #Environment
        self.state_size = env.observation_space.shape[0]
        self.num_actions = env.action_space.n
        
        #Initiate networks
        self.primary_network = agent(20, 32, self.num_actions)
        self.target_network = agent(20, 32, self.num_actions)
        self.enc = encoder((160,160,1),20,5)
        self.dec = decoder((160,160,1),20,5)
        
        self.optimizer = tf.keras.optimizers.Adam()
        self.MSE = tf.keras.losses.MeanSquaredError()
        # make target_network = primary_network
        for t, e in zip(self.target_network.trainable_variables, self.primary_network.trainable_variables):
            t.assign(e)
        
        #Initiate memory
        self.qtable = replay_buffer(500000)
        self.states = replay_buffer(100000)
        self.frame_stack = frame_stack(4, (160,160))

    def update_network(self):
        for t, e in zip(self.target_network.trainable_variables, self.primary_network.trainable_variables):
            t.assign(t * (1 - TAU) + e * TAU)
    
    def choose_action(self, state, eps):
        state = self.preprocess_frame(state)
        latent_representation = self.enc.predict(np.expand_dims(state, axis=0))
        
        if random.random() < eps:
            return latent_representation[0], env.action_space.sample()
        else:
            return latent_representation[0], np.argmax(self.primary_network.predict(latent_representation)[0])
        
    def preprocess_frame(self,frame):
        resized = frame[32:192,0:160]
        gray = cv2.cvtColor(resized, cv2.COLOR_RGB2GRAY)/255.
        return np.expand_dims(gray, axis=2).astype(np.float32)
        
    def train(self):
        batch = self.qtable.sample(BATCH_SIZE)
        states = np.array([val[0] for val in batch], dtype=np.float32)
        actions = np.array([val[1] for val in batch])
        rewards = np.array([val[2] for val in batch], dtype=np.float32)
        next_states = np.array([(np.zeros(self.state_size) if val[3] is None else val[3]) for val in batch], dtype=np.float32)

        with tf.GradientTape() as tape:
            prim_qt = self.primary_network(states)
            prim_qtp1 = self.primary_network(next_states)
            target_q = prim_qt.numpy()
            updates = rewards
            valid_idxs = np.array(next_states).sum(axis=1) != 0
            batch_idxs = np.arange(BATCH_SIZE)
            prim_action_tp1 = np.argmax(prim_qtp1.numpy(), axis=1)
            q_from_target = self.target_network(next_states)
            updates[valid_idxs] += GAMMA * q_from_target.numpy()[batch_idxs[valid_idxs], prim_action_tp1[valid_idxs]]
            target_q[batch_idxs, actions] = updates
            loss = self.MSE(prim_qt, target_q)
            
        gradients = tape.gradient(loss, self.primary_network.trainable_variables)   
        self.optimizer.apply_gradients(zip(gradients, self.primary_network.trainable_variables))
        self.update_network()

In [4]:
DQNAgent = DQNAgent(env)

In [6]:
num_episodes = 120
eps = MAX_EPSILON
steps = 0
rewards = []
avg_rewards = []

for episode in range(num_episodes):
    episode_reward = 0
    state = env.reset()
    representation, action = DQNAgent.choose_action(state, eps)
    while True:
        env.render()
        
        #Need to formulate strategy to get next_represntation
        
        state, reward, done, info = env.step(action)
        next_representation, action = DQNAgent.choose_action(state, 0.5)
        episode_reward += reward
        
        if done:
            representation = None
        # store in memory
        experience = representation, action, reward, next_representation, done
        DQNAgent.qtable.store(experience)
        DQNAgent.states.store(state)


        # linearly decay the eps value
        #if steps > DELAY_TRAINING:
        #    DQNAgent.train()
        #    
        #    eps = MAX_EPSILON - ((steps - DELAY_TRAINING) / EPSILON_MIN_ITER) * \
        #          (MAX_EPSILON - MIN_EPSILON) if steps < EPSILON_MIN_ITER else \
        #        MIN_EPSILON
        steps += 1

        if done:
            if steps > DELAY_TRAINING:
                print("episode: {}, reward: {}, average reward: {}".format(episode, np.round(episode_reward, decimals=2), np.mean(rewards[-10:])))
            else:
                print("episode: {}, pretraining...".format(episode))
            break

        representation = next_representation
        
    rewards.append(episode_reward)
    avg_rewards.append(np.mean(rewards[-10:]))

env.close()
plt.plot(rewards)
plt.plot(avg_rewards)
plt.plot()
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.show()        

ValueError: Error when checking input: expected input_1 to have shape (210,) but got array with shape (20,)

In [7]:
env.close()

In [7]:
DQNAgent.primary_network.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 210)]        0                                            
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 32)           6752        input_1[0][0]                    
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 32)           6752        input_1[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 4)            132         dense_2[0][0]                    
______________________________________________________________________________________________