# Prioritized Experience Replay
This notebook tests the basic code needed to build an agent with prioritized experience replay (Schaul et al., 2015).

## Initial setup

In [7]:
from vizdoom import *
import sys
sys.path.insert(0, "../python")
from Network import Network
import tensorflow as tf
import numpy as np

Normally, we would manipulate (hyper)parameters within an agent file, but currently the agent is only set up to use a uniform replay memory. Thus, in this case, we will hard-code the parameters below to create a network, and then build our new prioritized ER from scratch.

In [5]:
tf.reset_default_graph()

# Set network parameters
phi = 1
num_channels = 1
output_shape = 4
output_directory = "../tmp/"
sess = tf.Session()
train_mode = True
lr = 0.01
net_file = "../networks/dqn_basic.json"

# Create main network
main_network = Network(phi=phi, 
                       num_channels=num_channels, 
                       output_shape=output_shape,
                       train_mode=True,
                       learning_rate=lr,
                       network_file=net_file,
                       params_file=None,
                       output_directory=output_directory,
                       session=sess,
                       scope="main_network")

# Create target network used to calculate target Q values
target_network = Network(phi=phi, 
                         num_channels=num_channels, 
                         output_shape=output_shape,
                         learning_rate=lr,
                         train_mode=True,
                         network_file=net_file,
                         params_file=None,
                         output_directory=output_directory,
                         session=sess,
                         scope="target_network")

## Build prioritized experience replay memory
Now that the initial network is created, we will build a replay memory that incorporates prioritization.

We will begin by initializing the basic building blocks of replay memory: arrays to store transition values, which include the state (s1), next state (s2), action taken (a), reward received (r), and whether or not the next state is terminal (isterminal).

In [9]:
# Set basic parameters
capacity = 5
state_shape = [30, 45]

# Initialize arrays to store transition variables
s1 = np.zeros([capacity] + list(state_shape), dtype=np.float32)
s2 = np.zeros([capacity] + list(state_shape), dtype=np.float32)
a = np.zeros(capacity, dtype=np.int32)
r = np.zeros(capacity, dtype=np.float32)
isterminal = np.zeros(capacity, dtype=np.float32)        

Now let's define a function that adds a transition to replay memory.

In [10]:
def add_transition(s1_, s2_, a_, r_, isterminal_, pos):
    s1[pos] = s1_
    s2[pos] = s2_
    a[pos] = a_
    r[pos] = r_
    isterminal[pos] = isterminal_
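
The `pos` argument is left to the caller. One common choice, assumed here rather than taken from the agent code, is to treat the arrays as a circular buffer so that the oldest transition is overwritten once `capacity` is reached. The `next_position` helper and the dummy values below are hypothetical and only illustrate the wrap-around.

```python
import numpy as np

capacity = 5
state_shape = [30, 45]

# Same storage layout as above, reduced to the fields needed for the demo
s1 = np.zeros([capacity] + list(state_shape), dtype=np.float32)
r = np.zeros(capacity, dtype=np.float32)

def next_position(pos, capacity):
    """Hypothetical helper: advance the write index, wrapping at capacity."""
    return (pos + 1) % capacity

pos = 0
written = []
for step in range(7):
    # Store a dummy state and use the step number as the reward
    s1[pos] = np.full(state_shape, step, dtype=np.float32)
    r[pos] = step
    written.append(pos)
    pos = next_position(pos, capacity)

print(written)  # [0, 1, 2, 3, 4, 0, 1]
print(r)        # steps 5 and 6 have overwritten slots 0 and 1
```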

In addition to simply adding the values into the arrays, we also must assign each transition a priority. There are two basic schemes discussed in the paper:
- **Proportional prioritization**: $p_i = |\delta _i| + \epsilon$, where $\epsilon$ is a small constant to avoid edge-cases in which the TD error is zero (and thus leads to zero probability of sampling--see below).
- **Rank-based prioritization**: $p_i = \frac{1}{rank(i)}$, where $rank(i)$ is the rank of transition $i$ when the replay memory is sorted by $|\delta_i|$.

We will use the proportional scheme because it is easier to implement; in the paper, both variants performed comparably overall, although performance varied from game to game.

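This function takes the TD error $\delta_i = r_i + \gamma \max_{a'} Q'(s_i', a') - Q(s_i, a_i)$ as input, where $Q'$ is the target network and $\gamma$ is the discount factor, and returns the priority $p_i$ of the transition.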

In [11]:
def assign_priority(delta):
    return abs(delta) + 0.1
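
To make the link between priority and sampling concrete, here is a minimal sketch of how priorities turn into sampling probabilities. It uses the form from the paper, $P(i) = p_i^\alpha / \sum_k p_k^\alpha$; the rest of this notebook effectively uses $\alpha = 1$. The `sampling_probabilities` helper and the TD errors are made up for illustration.

```python
import numpy as np

def assign_priority(delta, epsilon=0.1):
    # Proportional priority; epsilon matches the constant used above
    return np.abs(delta) + epsilon

def sampling_probabilities(priorities, alpha=1.0):
    # P(i) = p_i^alpha / sum_k p_k^alpha; alpha = 0 recovers uniform sampling
    scaled = np.asarray(priorities, dtype=np.float64) ** alpha
    return scaled / scaled.sum()

deltas = np.array([0.0, -0.4, 0.9, 2.4])   # made-up TD errors
p = assign_priority(deltas)

print(p)                                   # priorities, all strictly positive
print(sampling_probabilities(p))           # proportional to p
print(sampling_probabilities(p, alpha=0))  # uniform over the four transitions
```

Note that the transition with zero TD error still receives a nonzero probability, which is the purpose of $\epsilon$.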

Sorting and selecting transitions from replay memory becomes prohibitively expensive as the replay memory size grows. If naively implemented, the time to search and insert based on priorities scales as $O(n \log n)$ and $O(n)$, respectively. To reduce this cost, we will implement a sum tree: a binary tree stored in an array like a binary heap, which the rest of this notebook refers to simply as the heap (see [here](https://jaromiru.com/2016/11/07/lets-make-a-dqn-double-learning-and-prioritized-experience-replay/) for a good explanation).

The binary heap will be implemented using a numpy array. Since we know the number of leaves (i.e. the replay memory capacity), we can compute the total number of elements in the array. If the PER capacity is $N$, then the layer directly above the leaves must be of size $2^{\lceil \log_2 N \rceil - 1}$. For example, if $N=20$, then that layer must have $2^{\lceil \log_2 20 \rceil - 1}=2^{5-1}=16$ nodes; if $N=33$, then it must have $2^{\lceil \log_2 33 \rceil - 1}=2^{6-1}=32$. The total number of nodes in all layers above the leaves is one less than twice the size of that next-to-last layer:

$\sum_{k=0}^{\lceil \log_2 N \rceil - 1} 2^{k}=2 \cdot 2^{\lceil \log_2 N \rceil - 1}-1=2^{\lceil \log_2 N \rceil}-1$

Thus the total number of elements is simply the above plus the number of transitions:

$2^{\lceil \log_2 N \rceil}-1+N$

Because we will begin indexing at 1 instead of 0, we will need to add one more to our array capacity:

$2^{\lceil \log_2 N \rceil}+N$

In [84]:
import math

num_leaves = 8
num_elements = 2 ** math.ceil(math.log(num_leaves, 2)) + num_leaves
heap = np.zeros(num_elements, dtype=np.float32)
print(heap.shape)

(16,)
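
Floating-point logarithms can be off by a tiny amount for large inputs, which `math.ceil` may then round the wrong way. As a hedged alternative, the sketch below computes the same quantities with integer arithmetic only. The helper names `leaf_start` and `heap_size` are hypothetical and not part of the agent code.

```python
def leaf_start(num_leaves):
    # Smallest power of two >= num_leaves, computed with integer arithmetic;
    # equals 2 ** ceil(log2(num_leaves)) for num_leaves >= 1
    return 1 << (num_leaves - 1).bit_length()

def heap_size(num_leaves):
    # Internal nodes plus the unused index 0 occupy leaf_start slots,
    # followed by one slot per transition
    return leaf_start(num_leaves) + num_leaves

for n in (8, 20, 33):
    print(n, leaf_start(n), heap_size(n))
# 8 8 16
# 20 32 52
# 33 64 97
```

When `num_leaves` is not a power of two, the last internal node can have a right-child index past the end of an array of this size (for example, `num_leaves = 5` gives size 13, but node 6 has children 12 and 13). Padding the array to `2 * leaf_start(num_leaves)` is one simple way around this; the examples in this notebook only use powers of two.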


Because we will be adding in unsorted, sequential order, we will simply add transition priorities from left to right across the bottom layer of the tree. However, while the addition itself is easy, we must perform additional operations to maintain the tree's defining invariant: each parent node is equal to the sum of its two child nodes. When a priority is added, we must propagate the new value up the tree, changing the values of the parent nodes accordingly. Keeping in mind that the indices of the left and right children of node $i$ are given by $2i$ and $2i+1$, respectively, we can formulate the code below.

In [129]:
# Set starting position of transitions in heap
start_pos = 2 ** math.ceil(math.log(num_leaves, 2))
print(start_pos)

# Set the parent of node child_id to the sum of its two children
def _propagate(child_id):
    parent_id = child_id // 2
    heap[parent_id] = heap[2 * parent_id] + heap[2 * parent_id + 1]
    
# Add priority leaf to heap and update parent nodes
def add_priority(p, i, verbose=True):  
    # note that while heap is 1-indexed, RM is still 0-indexed
    j = start_pos + i 
    
    # Add priority of transition i to heap
    heap[j] = p
    
    # Walk up the tree, updating each ancestor in turn
    while j > 1:
        if verbose: print(j, end=" ")
        _propagate(j)
        j = j // 2
    if verbose: print()

8
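
The invariant described above (every parent equals the sum of its children) can also be checked directly. The following is a minimal, self-contained sketch on a four-leaf tree; `build_sum_tree` and `check_sum_tree` are hypothetical helpers, and the sketch assumes the number of leaves is a power of two.

```python
import numpy as np

def build_sum_tree(priorities):
    # Assumes len(priorities) is a power of two
    n = len(priorities)
    tree = np.zeros(2 * n, dtype=np.float64)
    tree[n:] = priorities
    # Fill internal nodes bottom-up; index 0 stays unused
    for i in range(n - 1, 0, -1):
        tree[i] = tree[2 * i] + tree[2 * i + 1]
    return tree

def check_sum_tree(tree):
    # Every internal node must equal the sum of its two children
    n = tree.size // 2
    return all(np.isclose(tree[i], tree[2 * i] + tree[2 * i + 1])
               for i in range(1, n))

tree = build_sum_tree([1, 2, 3, 4])
print(tree[1:4])             # internal nodes: root 10, then 3 and 7
print(check_sum_tree(tree))  # True
```

The root, `tree[1]`, always holds the total priority, which is what the sampling step later relies on.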


Now let's run through an example of propagation. We will add a few priority values to the tree and observe how they are propagated upward to maintain the special structure noted above. We will print which indices were updated and the final value of the heap after updating.

In [86]:
# Initial values
np.set_printoptions(precision=1)
print(heap)

[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.]


In [87]:
# Add priority 3.5
add_priority(3.5, 0)
print(heap)

8 4 2 
[ 0.   3.5  3.5  0.   3.5  0.   0.   0.   3.5  0.   0.   0.   0.   0.   0.
  0. ]


In [88]:
# Add priority 2.2
add_priority(2.2, 1)
print(heap)

9 4 2 
[ 0.   5.7  5.7  0.   5.7  0.   0.   0.   3.5  2.2  0.   0.   0.   0.   0.
  0. ]


In [89]:
# Add priority 4.6
add_priority(4.6, 2)
print(heap)

10 5 2 
[  0.   10.3  10.3   0.    5.7   4.6   0.    0.    3.5   2.2   4.6   0.
   0.    0.    0.    0. ]


Seems to be working! Now let's add a list of numbers to a new heap and see if we get the same result as in this picture:

![Example of Sum Tree](prioritized_experience_replay/sumtree.png)

(credit: https://jaromiru.com/2016/11/07/lets-make-a-dqn-double-learning-and-prioritized-experience-replay/)

In [137]:
# Create new blank heap
num_leaves = 8
num_elements = 2 ** math.ceil(math.log(num_leaves, 2)) + num_leaves
heap = np.zeros(num_elements, dtype=np.float32)
start_pos = 2 ** math.ceil(math.log(num_leaves, 2))

# List of priorities to add from example above
priorities = [3, 10, 12, 4, 1, 2, 8, 2]

for i, p in enumerate(priorities):
    add_priority(p, i)
    print(heap)

8 4 2 
[ 0.  3.  3.  0.  3.  0.  0.  0.  3.  0.  0.  0.  0.  0.  0.  0.]
9 4 2 
[  0.  13.  13.   0.  13.   0.   0.   0.   3.  10.   0.   0.   0.   0.   0.
   0.]
10 5 2 
[  0.  25.  25.   0.  13.  12.   0.   0.   3.  10.  12.   0.   0.   0.   0.
   0.]
11 5 2 
[  0.  29.  29.   0.  13.  16.   0.   0.   3.  10.  12.   4.   0.   0.   0.
   0.]
12 6 3 
[  0.  30.  29.   1.  13.  16.   1.   0.   3.  10.  12.   4.   1.   0.   0.
   0.]
13 6 3 
[  0.  32.  29.   3.  13.  16.   3.   0.   3.  10.  12.   4.   1.   2.   0.
   0.]
14 7 3 
[  0.  40.  29.  11.  13.  16.   3.   8.   3.  10.  12.   4.   1.   2.   8.
   0.]
15 7 3 
[  0.  42.  29.  13.  13.  16.   3.  10.   3.  10.  12.   4.   1.   2.   8.
   2.]


They match! Now that we've built the update function, we must add the other side of the coin: binary search. Like a normal binary tree search, the algorithm will be as follows:

    if value <= node.left.value: move left with value
    else:                        move right with (value - node.left.value)

In [138]:
def _search(node, m, verbose):
    # Verbose code to track search
    if verbose:
        print("Node %d, node value %d, search value %d" % (node, heap[node], m))
    
    # Return value if no children
    if 2 * node > heap.size - 1:
        return node, heap[node]
    
    # Move left
    if m <= heap[2 * node]:
        return _search(2 * node, m, verbose)
    
    # Move right
    else:
        m = m - heap[2 * node]
        return _search(2 * node + 1, m, verbose)

def retrieve(m, verbose=True):
    return _search(1, m, verbose=verbose)

In order to sample transitions based on their priority, we need to generate a random number in the range $[0, p_{total}]$, where $p_{total}=\sum_{i} p_i$. Then we will match the random number with a transition: if the random number falls within the interval that a transition occupies in the cumulative sum of priorities, we will choose that transition. Let's walk through an example in which the random number is 24, using the tree above.

In [139]:
rand_int = 24
i, p = retrieve(rand_int)
print("Found transition %d with priority %d" % (i - start_pos, p))

Node 1, node value 42, search value 24
Node 2, node value 29, search value 24
Node 5, node value 16, search value 11
Node 10, node value 12, search value 11
Found transition 2 with priority 12


Awesome! Now let's try with some more random numbers. You can verify the results by looking at the tree diagram above.

In [140]:
rand_int = [12, 29, 35]
for m in rand_int:  # use m, since r would overwrite the reward array
    i, p = retrieve(m)
    print("Found transition %d with priority %d" % (i - start_pos, p),
          end="\n\n")

Node 1, node value 42, search value 12
Node 2, node value 29, search value 12
Node 4, node value 13, search value 12
Node 9, node value 10, search value 9
Found transition 1 with priority 10

Node 1, node value 42, search value 29
Node 2, node value 29, search value 29
Node 5, node value 16, search value 16
Node 11, node value 4, search value 4
Found transition 3 with priority 4

Node 1, node value 42, search value 35
Node 3, node value 13, search value 6
Node 7, node value 10, search value 3
Node 14, node value 8, search value 3
Found transition 6 with priority 8



Note that because of the `<=` operator in the `move left` part of the search function, ties go to the leftmost node that contains the random int in its cumulative range.
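
As a sanity check on the sampling scheme, the sketch below draws many uniform random numbers in $[0, p_{total}]$ and compares how often each transition is selected with its target probability $p_i / p_{total}$. It is self-contained and uses an iterative `sample_leaf` function (a hypothetical stand-in for the recursive search above), again assuming a power-of-two number of leaves.

```python
import numpy as np

priorities = np.array([3, 10, 12, 4, 1, 2, 8, 2], dtype=np.float64)
n = priorities.size                      # power of two, as above

# Build the sum tree bottom-up (1-indexed, leaves at n .. 2n - 1)
tree = np.zeros(2 * n)
tree[n:] = priorities
for i in range(n - 1, 0, -1):
    tree[i] = tree[2 * i] + tree[2 * i + 1]

def sample_leaf(tree, m):
    # Iterative version of the search above: descend until a leaf is reached
    node, n_leaves = 1, tree.size // 2
    while node < n_leaves:
        left = 2 * node
        if m <= tree[left]:
            node = left
        else:
            m -= tree[left]
            node = left + 1
    return node - n_leaves               # 0-indexed transition

rng = np.random.default_rng(0)
num_samples = 20000
counts = np.zeros(n)
for m in rng.uniform(0.0, tree[1], size=num_samples):
    counts[sample_leaf(tree, m)] += 1

print(np.round(counts / num_samples, 3))   # empirical frequencies
print(np.round(priorities / tree[1], 3))   # target p_i / p_total
```

The two printed rows should agree up to sampling noise, with transition 2 (priority 12) chosen most often.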

Now let's run some tests to make sure both the update and search functions scale as $O(\log N)$.

In [147]:
from time import time

sizes = [4 ** x for x in range(5, 10)]
update_time = []
search_time = []
for i, size in enumerate(sizes):
    print("Creating heap of size %d..." % size)
    
    # Create new blank heap
    num_leaves = size
    num_elements = 2 ** math.ceil(math.log(num_leaves, 2)) + num_leaves
    heap = np.zeros(num_elements, dtype=np.float32)
    start_pos = 2 ** math.ceil(math.log(num_leaves, 2))
    
    # Add random priorities
    priorities = 10 * np.random.random(num_leaves)
    for j in range(num_leaves):
        add_priority(priorities[j], j, verbose=False)
    
    # Test time to update
    start_time = time()
    t = 1000
    for j in range(t):
        #print("Update iteration %d of %d..." % (j+1, t))
        p = 10 * np.random.random()
        k = np.random.randint(0, num_leaves)
        add_priority(p, k, verbose=False)
    end_time = time()
    update_time.append((end_time - start_time) / t)
    
    # Test time to search
    start_time = time()
    t = 1000
    for j in range(t):
        #print("Search iteration %d of %d..." % (j+1, t))
        m = heap[1] * np.random.random()
        retrieve(m, verbose=False)
    end_time = time()
    search_time.append((end_time - start_time) / t)

print("Update times: ", update_time)
update_time_ratios = [update_time[i+1] / update_time[i] for i in range(len(update_time)-2)]
print("Update time ratios: ", update_time_ratios)

print("Search times: ", search_time)
search_time_ratios = [search_time[i+1] / search_time[i] for i in range(len(search_time)-2)]
print("Search time ratios: ", search_time_ratios)

Creating heap of size 1024...
Creating heap of size 4096...
Creating heap of size 16384...
Creating heap of size 65536...
Creating heap of size 262144...
Update times:  [1.2619733810424804e-05, 1.143193244934082e-05, 1.329207420349121e-05, 1.4959096908569336e-05, 1.6237974166870117e-05]
Update time ratios:  [0.9058774631123538, 1.1627145508769734, 1.125414790766085]
Search times:  [1.3097047805786133e-05, 1.2251138687133789e-05, 1.611495018005371e-05, 1.6282081604003907e-05, 1.766824722290039e-05]
Search time ratios:  [0.9354122294431398, 1.3153838668872238, 1.0103712032667072]
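
To interpret these ratios, it helps to know what they would look like under different scaling laws. The sketch below assumes the cost is proportional to the tree depth $\log_2 N$ and computes the expected ratio for each fourfold increase in size, next to the ratio expected under linear scaling. It covers all four consecutive size pairs, of which the output above reports the first three.

```python
import math

sizes = [4 ** x for x in range(5, 10)]

# If cost is proportional to tree depth, time(N) ~ log2(N)
depths = [math.log2(n) for n in sizes]
log_ratios = [depths[i + 1] / depths[i] for i in range(len(depths) - 1)]

# For comparison: ratios expected under linear scaling
linear_ratios = [sizes[i + 1] / sizes[i] for i in range(len(sizes) - 1)]

print([round(x, 3) for x in log_ratios])  # [1.2, 1.167, 1.143, 1.125]
print(linear_ratios)                      # [4.0, 4.0, 4.0, 4.0]
```

The measured ratios of roughly 0.9 to 1.3 are noisy but sit close to the logarithmic prediction and far from the factor of 4 that linear scaling would produce.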
