In [3]:
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Dropout, Dense, Flatten
from keras.optimizers import Adam
from keras.callbacks import TensorBoard
import numpy as np
from collections import deque
import time

In [4]:
# Own Tensorboard class
class ModifiedTensorBoard(TensorBoard):
    """TensorBoard callback that keeps a single log across many ``.fit()`` calls.

    The stock callback hooks are overridden so that metrics are written under
    ``self.step`` instead of the per-fit epoch index. ``self.step`` starts at 1
    and is never advanced by this class; the training loop is expected to
    update it.

    NOTE(review): ``tf.summary.FileWriter`` and ``self._write_logs`` are used
    as-is from the installed TensorFlow/Keras; both look like the TF 1.x-era
    API - confirm against the versions this notebook runs on.
    """

    def __init__(self, **kwargs):
        """Forward ``kwargs`` to ``TensorBoard`` and open one writer on ``log_dir``."""
        super().__init__(**kwargs)
        self.step = 1
        self.writer = tf.summary.FileWriter(self.log_dir)

    def set_model(self, model):
        """Do nothing, so that no default log writer is created."""
        pass

    def on_epoch_end(self, epoch, logs=None):
        """Write the epoch's metrics under ``self.step``; ``epoch`` is ignored.

        ``logs`` may be ``None`` (its declared default); in that case an empty
        set of stats is forwarded instead of failing on ``**None``.
        """
        self.update_stats(**(logs or {}))

    def on_batch_end(self, batch, logs=None):
        """Do nothing; nothing is saved per batch."""
        pass

    def on_train_end(self, _=None):
        """Do nothing, so the writer is not closed when a ``.fit()`` call ends."""
        pass

    def update_stats(self, **stats):
        """Write the given metrics (keyword name -> value) at ``self.step``.

        Delegates to ``self._write_logs``, which is not defined in this class
        (presumably inherited from ``TensorBoard``).
        """
        self._write_logs(stats, self.step)

### Create the DQN Agent Class

In [None]:
# Upper bound on the number of experiences kept in the agent's replay deque;
# once full, the deque drops its oldest entries.
REPLAY_MEMORY_SIZE = 50_000

# Label embedded in the TensorBoard log directory name
# (presumably refers to the two 256-filter conv layers - confirm).
MODEL_NAME = "256x2"

class DQNAgent:
    """DQN agent bundling a main model, a target model and a replay memory.

    Attributes:
        env: the environment object the networks were sized from.
        model: the main network (used for training, and queried by ``get_qs``).
        target_model: a second network with the same architecture, initialised
            with the main network's weights.
        replay_memory: bounded deque of stored experiences.
        tensorboard: ``ModifiedTensorBoard`` callback logging under ``logs/``.
        target_update_counter: integer counter initialised to 0.
    """

    def __init__(self, env=None):
        """Build both networks and the supporting state.

        Args:
            env: object exposing ``STATE_SPACE_SIZE`` and ``ACTION_SPACE_SIZE``;
                it is forwarded to ``create_model``.

        Raises:
            TypeError: if ``env`` is not supplied. The models cannot be sized
                without it.
        """
        if env is None:
            raise TypeError(
                "DQNAgent requires an `env` exposing STATE_SPACE_SIZE and "
                "ACTION_SPACE_SIZE to build its models"
            )
        self.env = env

        # Main model - used for the training
        self.model = self.create_model(env)

        # Target model - used for the prediction; starts out as an exact
        # copy of the main model's weights.
        self.target_model = self.create_model(env)
        self.target_model.set_weights(self.model.get_weights())

        self.replay_memory = deque(maxlen=REPLAY_MEMORY_SIZE)
        self.tensorboard = ModifiedTensorBoard(log_dir=f"logs/{MODEL_NAME}-{int(time.time())}")
        self.target_update_counter = 0

    def create_model(self, env):
        """Build and compile the convolutional Q-network.

        Args:
            env: object exposing ``STATE_SPACE_SIZE`` (input shape of the first
                conv layer) and ``ACTION_SPACE_SIZE`` (number of output units).

        Returns:
            A compiled ``Sequential`` model (MSE loss, Adam with learning rate
            0.001) with one linear output per action.
        """
        model = Sequential([
            Conv2D(256, (3, 3), input_shape=env.STATE_SPACE_SIZE, activation='relu'),
            MaxPooling2D((2, 2)),
            Dropout(0.2),
            # Only the first layer of a Sequential model needs input_shape;
            # later layers infer their input from the preceding layer.
            Conv2D(256, (3, 3), activation='relu'),
            MaxPooling2D((2, 2)),
            Dropout(0.2),
            Flatten(),
            Dense(64),
            Dense(env.ACTION_SPACE_SIZE, activation='linear')
        ])
        model.compile(loss='mse', optimizer=Adam(learning_rate=0.001), metrics=['accuracy'])
        return model

    def update_replay_memory(self, experience):
        """Append one experience to the replay memory.

        The deque is bounded, so the oldest entry is discarded once it is full.
        """
        self.replay_memory.append(experience)

    def get_qs(self, state, step=None):
        """Return the main model's prediction for a single state.

        Args:
            state: array-like observation; it is converted to an ndarray, so
                nested lists are accepted as well.
            step: unused; kept (now optional) so existing callers still work.

        Returns:
            The first (and only) row of ``self.model.predict`` for the
            one-element batch built from ``state``.
        """
        state = np.asarray(state)
        # Prepend a batch axis and scale by 1/255
        # (presumably 8-bit pixel values -> [0, 1]; confirm against the env).
        batch = state.reshape(-1, *state.shape) / 255
        return self.model.predict(batch)[0]
