# Prioritized Replay Deep Q Network

In [1]:
# autoreload code changes
%load_ext autoreload
%autoreload 2

In [2]:
import itertools
import json
import os
import shutil

In [3]:
from matplotlib import pyplot as plt

In [4]:
from banana_env import BananaEnv
from joe_agents.dqn_agent import DqnAgent

## Hyperparameter Grid Search

In [5]:
class Grid():
    """Marks a hyperparameter as a set of candidate values to sweep over.

    The candidate values are stored in the order given and can be iterated
    any number of times. Each call to ``iter(grid)`` returns an independent
    iterator, so the same ``Grid`` can safely be used in nested loops or be
    consumed by several consumers at once.

    Args:
        *args: The candidate values for the hyperparameter.
    """

    def __init__(self, *args):
        # Candidate values, kept as a tuple in the order they were passed.
        self._values = args
        # Cursor used only by the direct ``next(grid)`` protocol below.
        self._idx = 0

    def __iter__(self):
        """Return a fresh iterator over the candidate values.

        The internal cursor is also rewound so that code which calls
        ``iter(grid)`` and then ``next(grid)`` on the grid object itself
        keeps working as before.
        """
        self._idx = 0
        # A separate iterator per call: iteration state is no longer shared,
        # which is what previously truncated nested loops over one Grid.
        return iter(self._values)

    def __next__(self):
        """Return the next candidate value using the grid's own cursor.

        Kept for backward compatibility with callers that advance the grid
        object directly via ``next(grid)``.

        Raises:
            StopIteration: When every candidate value has been returned.
        """
        if self._idx < len(self._values):
            n = self._values[self._idx]
            self._idx += 1
            return n
        raise StopIteration

In [6]:
class ExperimentIterator():
    """Iterates over every combination of the ``Grid`` values in a params dict.

    Values in ``params`` that are ``Grid`` instances are treated as search
    dimensions; every other value is passed through unchanged. Each step of
    the iteration yields a new dict in which every ``Grid`` has been replaced
    by one concrete value, covering the cartesian product of all grids.

    Args:
        params: Mapping of parameter name to either a fixed value or a
            ``Grid`` of candidate values. It is not modified.
    """

    def __init__(self, params):
        self._params = params
        # Working dict holding the most recently produced combination.
        self._current = None
        # (key, Grid) pairs for the parameters being searched over.
        self._grids = []
        # Iterator over the cartesian product of all grid values.
        self._product_iter = None

    def __iter__(self):
        """Restart the sweep from the first combination and return self."""
        self._current = dict(self._params)
        self._grids = [(k, v) for (k, v) in self._params.items() if isinstance(v, Grid)]
        self._product_iter = itertools.product(*(grid for _, grid in self._grids))
        return self

    def __next__(self):
        """Return the next parameter combination as a fresh dict.

        Raises:
            StopIteration: When every combination has been produced.
        """
        if self._product_iter is None:
            # ``next()`` was called before ``iter()``; start the sweep lazily
            # instead of failing on the uninitialised product iterator.
            iter(self)
        combination = next(self._product_iter)
        for (key, _), value in zip(self._grids, combination):
            self._current[key] = value
        # Return a copy: handing out the working dict itself would make every
        # previously yielded result change when the next one is produced.
        return dict(self._current)

In [7]:
# Hyperparameter search space handed to ExperimentIterator.
# Plain values are held fixed for every run; Grid(...) values are swept.
# The grids below multiply out to 2*3*3*2*3*2*2*2*2 = 1728 combinations, and
# key order determines the order (and therefore the exp_{i} numbering) of runs.
# NOTE(review): the meaning of each key is defined by DqnAgent, which is not
# shown here — confirm semantics against that class before changing values.
search_params = {
    # Fixed for every run.
    "episodes": 1000,
    "batch_size": 64,
    # Swept: 2 candidates.
    "buffer_size": Grid(10000, 50000),
    # Swept: 3 candidates.
    "learning_rate": Grid(5e-4, 1e-3, 1e-2),
    # Swept: 3 candidates.
    "discount_rate": Grid(0.9, 0.99, 0.999),
    # Swept: 2 candidates.
    "update_rate": Grid(4, 10),
    # Swept: 3 candidates.
    "epsilon_decay": Grid(0.9, 0.995, 0.999),
    "epsilon_decay_rate": 1,
    # Swept: 2 candidates.
    "min_epsilon": Grid(0.01, 0.1),
    # Swept: replay buffer sampling strategy.
    "replay": Grid("uniform", "prioritized"),
    # Fixed prioritized-replay settings (presumably unused when replay is
    # "uniform" — TODO confirm in DqnAgent).
    "prioritized_replay_damp": 0.6,
    "e_constant": 1e-6,
    "prioritized_replay_beta_anneal_rate": 100,
    "learning_start": 64,
    # Swept: on/off.
    "double_dqn": Grid(False, True),
    # NOTE(review): "deuling" looks like a misspelling of "dueling"; the key
    # must match whatever DqnAgent reads, so it is left unchanged — verify.
    "deuling_dqn": Grid(False, True)
}
# Note: RLlib does not anneal beta; it keeps it constant at 0.4.

In [8]:
# Iterable that yields one concrete parameter dict per combination of the
# Grid values in search_params.
exp_iter = ExperimentIterator(search_params)

## Create Environment to Train In

In [9]:
# Create the environment.
# Path to the Unity Banana executable, relative to the notebook's working
# directory. This points at the Windows x86_64 build, so it must be changed
# to run on another platform or directory layout.
exe = "../../deep-reinforcement-learning/p1_navigation/Banana_Windows_x86_64/Banana.exe"
# NOTE(review): "evn_config" looks like a typo for "env_config"; the name is
# left as-is because it is a notebook-level variable.
# The accepted config keys are defined by BananaEnv (not shown here).
evn_config = {"executable": exe, "train_mode": True}
env = BananaEnv(evn_config)

INFO:unityagents:
'Academy' started successfully!
Unity Academy name: Academy
        Number of Brains: 1
        Number of External Brains : 1
        Lesson number : 0
        Reset Parameters :
		
Unity brain name: BananaBrain
        Number of Visual Observations (per agent): 0
        Vector Observation space type: continuous
        Vector Observation space size (per agent): 37
        Number of stacked Vector Observation: 1
        Vector Action space type: discrete
        Vector Action space size (per agent): 4
        Vector Action descriptions: , , , 


## Run all of the Experiments

In [10]:
# Directory (relative to the working directory) that receives one sub-folder
# per experiment.
exp_folder = "experiments"
# WARNING: destructive. Any existing results directory is deleted, so
# re-running this cell discards the output of all previous experiments.
if os.path.exists(exp_folder):
    shutil.rmtree(exp_folder)
os.mkdir(exp_folder)

In [11]:
# Dimensions of the Banana environment as reported by Unity at start-up:
# a 37-dimensional vector observation and 4 discrete actions.
STATE_SIZE = 37
ACTION_SIZE = 4

# Train one agent per parameter combination and store its artefacts in
# experiments/exp_<i>/.
for i, params in enumerate(exp_iter):
    wkspc_folder = os.path.join(exp_folder, f"exp_{i}")
    os.mkdir(wkspc_folder)

    # Snapshot the parameters so this run keeps its own dict even if the
    # iterator reuses and mutates the object it yielded.
    params = dict(params)

    # Write the parameters before training: a run that is interrupted part-way
    # still records what was being trained, and a value that cannot be
    # serialised fails here instead of after a long training run.
    with open(os.path.join(wkspc_folder, "params.json"), 'w') as f:
        json.dump(params, f)

    agent = DqnAgent(STATE_SIZE, ACTION_SIZE, params)
    scores, epsilons, buffer_stats = agent.train(env)
    agent.save(os.path.join(wkspc_folder, "checkpoint.pth"))

    # Training outputs, each written to its own JSON file in the run folder.
    artifacts = {
        "scores.json": scores,
        "epsilons.json": epsilons,
        "buffer_stats.json": buffer_stats,
    }
    for file_name, payload in artifacts.items():
        with open(os.path.join(wkspc_folder, file_name), 'w') as f:
            json.dump(payload, f)

100%|███████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [30:52<00:00,  1.85s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [30:43<00:00,  1.84s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [29:43<00:00,  1.78s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [30:39<00:00,  1.84s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [53:20<00:00,  3.20s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [55:45<00:00,  3.35s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [53:39<00:00,  3.22s/it]
100%|████████████████████████████████████████████████████████████████

KeyboardInterrupt: 

In [12]:
# Close the environment created earlier once the experiments have finished
# (or have been interrupted).
env.close()