Prioritized Experience Replay
---

In [None]:
import datetime
import os

from agents.curiosity import CuriosityDQNAgent
from agents.dqn import DQNAgent, ConvQNetworkFactory, ConvQNetwork
from agents.logging import TensorBoardLogger
from agents.per import PERAgent
from agents.random import RandomAgent
from env.env import WindowedGridView, DeliveryDrones
from rl_helpers.rl_helpers import MultiAgentTrainer, test_agents, plot_cumulative_rewards

In [None]:
# Load IPython's autoreload extension. Mode 2 reloads imported modules before each
# cell runs, so edits to the project's .py files (agents, env, rl_helpers) are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Environment used by every agent in this notebook: a DeliveryDrones instance
# wrapped in a WindowedGridView with radius=3 (presumably the size of the local
# window each drone observes -- TODO confirm against env.env).
env = WindowedGridView(DeliveryDrones(), radius=3)

# Override selected entries of the environment's parameter mapping in place.
# The cell that builds the trainer calls env.reset() afterwards.
env.env_params.update({
    'n_drones': 11,
    'pickup_reward': 1,
    'discharge': 2,
    'rgb_render_rescale': 2.0,
})

In [None]:
"""DQN with conv. Q-network"""
# Keyword arguments passed unchanged to both DQN agents below.
_dqn_settings = dict(
    gamma=0.95,
    epsilon_start=1,
    epsilon_decay=0.99,
    epsilon_end=0.01,
    memory_size=10000,
    batch_size=64,
    target_update_interval=500,
)

# Agent 1: four conv layers (32 output channels, kernel 3, padding 1); only the
# first layer uses stride 2. dense_layers=[256] is passed to the factory as-is.
_factory_1 = ConvQNetworkFactory(
    env,
    conv_layers=[
        {'out_channels': 32, 'kernel_size': 3, 'stride': stride, 'padding': 1}
        for stride in (2, 1, 1, 1)
    ],
    dense_layers=[256],
)
dqn_agent_1 = DQNAgent(env, _factory_1, **_dqn_settings)

# Agent 2: a shallower variant -- two conv layers, the second with 64 output channels.
_factory_2 = ConvQNetworkFactory(
    env,
    conv_layers=[
        {'out_channels': 32, 'kernel_size': 3, 'stride': 2, 'padding': 1},
        {'out_channels': 64, 'kernel_size': 3, 'stride': 1, 'padding': 1},
    ],
    dense_layers=[256],
)
dqn_agent_2 = DQNAgent(env, _factory_2, **_dqn_settings)

In [None]:
"""DQN with intrinsic curiosity module"""
# Q-network factory for the curiosity agent: four conv layers (32 output channels,
# kernel 3, padding 1), stride 2 on the first layer and stride 1 on the rest,
# with dense_layers=[256].
dqn_factory = ConvQNetworkFactory(env, conv_layers=[
        {'out_channels': 32, 'kernel_size': 3, 'stride': 2, 'padding': 1},
        {'out_channels': 32, 'kernel_size': 3, 'stride': 1, 'padding': 1},
        {'out_channels': 32, 'kernel_size': 3, 'stride': 1, 'padding': 1},
        {'out_channels': 32, 'kernel_size': 3, 'stride': 1, 'padding': 1},
    ], dense_layers=[256])

# Same keyword arguments as the plain DQN agents, plus eta=0.1 (presumably the
# weight of the intrinsic curiosity reward -- TODO confirm against
# agents.curiosity.CuriosityDQNAgent).
curiosity_agent_1 = CuriosityDQNAgent(env, dqn_factory, gamma=0.95, epsilon_start=1, epsilon_decay=0.99,
                                      epsilon_end=0.01, memory_size=10000, batch_size=64,
                                      target_update_interval=500, eta=0.1)

In [None]:
"""DQN with intrinsic curiosity module)"""
# NOTE: despite the label string at the top of this cell (copied from the
# curiosity cell), this cell builds the prioritized-experience-replay (PER)
# agent, then resets the environment and creates the trainer.
#
# Q-network factory for the PER agent: four conv layers (32 output channels,
# kernel 3, padding 1), stride 2 on the first layer and stride 1 on the rest,
# with dense_layers=[256].
dqn_factory = ConvQNetworkFactory(env, conv_layers=[
        {'out_channels': 32, 'kernel_size': 3, 'stride': 2, 'padding': 1},
        {'out_channels': 32, 'kernel_size': 3, 'stride': 1, 'padding': 1},
        {'out_channels': 32, 'kernel_size': 3, 'stride': 1, 'padding': 1},
        {'out_channels': 32, 'kernel_size': 3, 'stride': 1, 'padding': 1},
    ], dense_layers=[256])

# Same keyword arguments as the plain DQN agents, plus alpha=0.6 and beta=0.4
# (presumably the priority exponent and the importance-sampling exponent of
# PER -- TODO confirm against agents.per.PERAgent).
per_agent_1 = PERAgent(env, dqn_factory, gamma=0.95, epsilon_start=1, epsilon_decay=0.99,
                       epsilon_end=0.01, memory_size=10000, batch_size=64,
                       target_update_interval=500, alpha=0.6, beta=0.4)


# Reset environment with those parameters
env.reset()

# Setup random opponents: every drone starts with a RandomAgent, then the
# learning agents replace the entries for drones 0-3.
agents = {drone.index: RandomAgent(env) for drone in env.drones}
agents[0] = dqn_agent_1
agents[1] = dqn_agent_2
agents[2] = curiosity_agent_1
# Register the PER agent as well. It was previously constructed but never placed
# in `agents`, so the trainer never received it.
agents[3] = per_agent_1

# Create trainer
trainer = MultiAgentTrainer(env, agents, reset_agents=True, seed=0)

In [None]:
# Run training for all agents registered with the trainer. The argument 100 is
# presumably the number of training steps or episodes -- TODO confirm against
# rl_helpers.MultiAgentTrainer.train.
trainer.train(100)

In [None]:
"""
from IPython.display import clear_output

for _ in range(1000):
    trainer.train(10)
    clear_output()
    per_agent_1.inspect_memory()
    #plot_rolling_rewards(trainer.rewards_log)
"""

In [None]:
# Save the first DQN agent via its own save() method. The path is relative, so
# the file lands in the kernel's current working directory.
dqn_agent_1.save('test-agent-5.pt')