In [3]:
import gym
import numpy as np
import tensorflow as tf
import pandas as pd
from tensorflow import keras
from src.env_wrapper import EnvWrapper
from src.dqn_model import DQNModel
from src.replay_buffer import PrioritizedReplayBuffer
from src.training_data_manager import TrainingDataManager, save_models_and_data

Environment Initialization

In [4]:
# The NoFrameskip variant of Breakout is used; frame skipping (frame_skip=4)
# and frame stacking (frame_stack=4) are configured on EnvWrapper instead.
gym_env = gym.make("BreakoutNoFrameskip-v4")
env = EnvWrapper(
  gym_env,
  frame_stack=4,
  frame_skip=4,
  seed=42,
)

Model Definition

In [None]:
# Network input: assumes EnvWrapper emits 84x84 observations with the
# 4 stacked frames (frame_stack=4) on the last axis — TODO confirm.
input_shape = (84, 84, 4)
n_outputs = env.action_space.n

# Online network (trained) and target network (used for TD targets).
DQN_model = DQNModel.create(input_shape, n_outputs)
DQN_model_target = DQNModel.create(input_shape, n_outputs)
# Start the target network as an exact copy of the online network.
DQN_model_target.set_weights(DQN_model.get_weights())

Replay Buffer Initialization

In [6]:
# Prioritized replay holding up to 50k transitions, sampled in batches of 32.
replay_buffer = PrioritizedReplayBuffer(batch_size=32, memory_length=50_000)

Training functions

In [7]:
# Optimisation settings shared by training_step().
discount_factor = 0.99                                 # γ in the TD target
optimizer = keras.optimizers.Adam(learning_rate=2.5e-4)
loss_function = keras.losses.MeanSquaredError()        # loss between TD targets and Q(s, a)

In [8]:
def epsilon_greedy_policy(state, epsilon):
  """Choose an action epsilon-greedily with respect to ``DQN_model``.

  With probability ``epsilon`` a uniformly random action index in
  ``[0, n_outputs)`` is returned; otherwise the index of the largest value the
  online network predicts for ``state``.

  Reads the module-level globals ``DQN_model`` and ``n_outputs``.
  """
  explore = np.random.rand() < epsilon
  if explore:
    return np.random.randint(n_outputs)
  # The network expects a batch, so add a leading batch axis of size 1.
  batched_state = tf.expand_dims(tf.convert_to_tensor(state), axis=0)
  q_values = DQN_model(batched_state)[0]
  return tf.argmax(q_values).numpy()

In [9]:
def training_step():
  """Run one prioritized-DQN gradient step on a batch from the replay buffer.

  Samples ``(states, actions, rewards, next_states, dones)`` from
  ``replay_buffer``, builds one-step TD targets with the target network,
  updates ``DQN_model`` by minimising ``loss_function`` between the targets
  and Q(s, a), then reports per-sample squared TD errors back to the buffer
  through ``replay_buffer.update_prio``.

  Reads the module-level globals ``replay_buffer``, ``DQN_model``,
  ``DQN_model_target``, ``discount_factor``, ``n_outputs``,
  ``loss_function`` and ``optimizer``.
  """
  states, actions, rewards, next_states, dones = replay_buffer.sample_experiences()

  # TD target r + γ·max_a' Q_target(s', a'), bootstrap term zeroed on terminal
  # transitions. Computed outside the tape, so no gradient reaches the target net.
  next_Q_values = DQN_model_target(next_states)
  max_next_Q_values = tf.reduce_max(next_Q_values, axis=1)
  # Cast explicitly so boolean/integer done flags mix cleanly with float Q-values.
  not_done = 1.0 - tf.cast(dones, max_next_Q_values.dtype)
  target_Q_values = rewards + (discount_factor * max_next_Q_values) * not_done

  mask = tf.one_hot(actions, n_outputs)
  with tf.GradientTape() as tape:
    all_Q_values = DQN_model(states)
    # Keep only Q(s, a) for the action actually taken in each transition.
    Q_values = tf.reduce_sum(tf.multiply(all_Q_values, mask), axis=1)
    loss = loss_function(target_Q_values, Q_values)
  grads = tape.gradient(loss, DQN_model.trainable_variables)
  optimizer.apply_gradients(zip(grads, DQN_model.trainable_variables))

  # Per-sample squared TD error in a single vectorized op. This equals
  # loss_function([tq], [q]) for each pair, without one eager Keras call per sample.
  # NOTE(review): PER commonly uses |δ| as priority and weights the loss by
  # importance-sampling weights; sample_experiences() returns none here — confirm
  # against PrioritizedReplayBuffer.
  tds = list(tf.square(target_Q_values - Q_values).numpy())
  replay_buffer.update_prio(tds)

Data manager initialization

In [10]:
# Give the data manager its own environment instance. Wrapping the same
# `gym_env` object used by the training `env` would make any reset, step or
# seeding the manager performs act on the environment the agent trains in.
tdm_env = EnvWrapper(gym.make(gym_env.spec.id), frame_stack=4, frame_skip=4, seed=42)
# n_aav_states: presumably the number of states used for the average-action-value
# metric (see tdm.update_aav in the training loop) — confirm in TrainingDataManager.
tdm = TrainingDataManager(tdm_env, n_aav_states=5)

Training Phase

In [13]:
# Training length: an "epoch" is a fixed number of gradient updates.
max_epochs = 50
updates_per_epoch = 10_000

# Linear epsilon schedule: epsilon = max(epsilon_max - steps / random_steps, epsilon_min).
epsilon_max, epsilon_min = 1.0, 0.1
random_steps = 1e6

# Periods below are measured in environment steps.
target_network_update_period = 10_000   # sync target network with online network
update_period = 4                       # run one training_step()
sample_period = 5 * update_period       # call replay_buffer.update_sample()

In [14]:
# Progress counters for the training loop (environment steps, finished epochs,
# gradient updates within the current epoch) and the running episode return.
steps = epochs = updates = 0
episode_reward = 0

In [None]:
# Main training loop: act epsilon-greedily, store transitions, train every
# `update_period` steps, and log/save after every `updates_per_epoch` updates.
while epochs < max_epochs:
  state = env.reset()
  while True:
    steps += 1
    # Linear epsilon decay per environment step, floored at epsilon_min.
    epsilon = max(epsilon_max - steps / random_steps, epsilon_min)
    action = epsilon_greedy_policy(state, epsilon)
    next_state, reward, done, _ = env.step(action)
    episode_reward += reward
    replay_buffer.add((state, action, reward, next_state, done))
    state = next_state

    if steps % sample_period == 0 and replay_buffer.usable():
      replay_buffer.update_sample()

    if steps % update_period == 0 and replay_buffer.usable():
      updates += 1
      training_step()
      if updates >= updates_per_epoch:
        epochs += 1
        updates = 0
        tdm.update_aav(DQN_model)
        tdm.update_arpe()
        print(f"Epoch {epochs}/{max_epochs} concluded: reward={tdm.get_arpe()[-1]:.2f} - ε={epsilon:.2f} - episodes={tdm.get_episodes()}")
        save_models_and_data("adam", "PrioritizedDQN", DQN_model, DQN_model_target, tdm.get_arpe(), tdm.get_aav())

    # Target sync period is counted in environment steps, not gradient updates.
    if steps % target_network_update_period == 0:
      DQN_model_target.set_weights(DQN_model.get_weights())

    if done:
      tdm.end_episode_update(episode_reward)
      episode_reward = 0
      break

    # Stop once the final epoch has been saved. Otherwise training would keep
    # updating the model past the configured budget (and past the last saved
    # checkpoint) until the current episode happens to end.
    if epochs >= max_epochs:
      episode_reward = 0  # drop the unfinished episode's partial return
      break
env.close()
tdm.print_results()