## Implementation of Prioritized Experience Replay

This notebook implements the rank-based variant of prioritized experience replay. The code is based on the implementation contributed by [Damcy](https://github.com/Damcy/prioritized-experience-replay).

In [3]:
import math
import random
import numpy as np
from utils import BinaryHeap

## Current

In [None]:
class Experience_Buffer():
    """Fixed-size circular replay memory that samples transitions uniformly.

    Each screen is stored once. A state is assembled on demand from the
    ``history_length`` consecutive screens that end at a given index.

    Args:
        memory_size: Number of experiences kept before the write pointer
            wraps around and starts overwriting the oldest entries.
        batch_size: Number of transitions returned by ``sample_from_replay``.
        history_length: Number of consecutive screens that form one state.
        screen_height: Height of one stored screen.
        screen_width: Width of one stored screen.

    The four keyword-only settings default to the module-level variables of
    the same name (where this class originally read them from), so
    ``Experience_Buffer(memory_size)`` keeps working unchanged.

    Raises:
        NameError: If a setting is neither passed in nor defined globally.
    """

    def __init__(self, memory_size, *, batch_size=None, history_length=None,
                 screen_height=None, screen_width=None):
        self.memory_size = memory_size
        self.batch_size = self._setting('batch_size', batch_size)
        self.history_length = self._setting('history_length', history_length)
        self.screen_height = self._setting('screen_height', screen_height)
        self.screen_width = self._setting('screen_width', screen_width)

        # These are the arrays where we will store the experiences.
        # np.int_ / np.bool_ are the concrete scalar types: the np.bool alias
        # is missing from NumPy 1.24-1.26, and passing the abstract np.integer
        # class as a dtype is deprecated.
        self.actions = np.empty(self.memory_size, dtype=np.uint8)
        self.rewards = np.empty(self.memory_size, dtype=np.int_)
        self.screens = np.empty(
            (self.memory_size, self.screen_height, self.screen_width),
            dtype=np.float16)
        self.terminals = np.empty(self.memory_size, dtype=np.bool_)

        # Batch buffers are allocated once and overwritten by every sample.
        batch_shape = (self.batch_size, self.history_length,
                       self.screen_height, self.screen_width)
        self.prestates = np.empty(batch_shape, dtype=np.float16)
        self.poststates = np.empty(batch_shape, dtype=np.float16)

        self.current = 0     # Pointer to the current saving location
        self.count = 0       # Number of collected experiences

    @staticmethod
    def _setting(name, value):
        """Return ``value``, or the module-level variable ``name`` if it is None."""
        if value is not None:
            return value
        try:
            return globals()[name]
        except KeyError:
            raise NameError("name '%s' is not defined" % name) from None

    def add(self, screen, reward, action, terminal):
        """Write one experience at the current pointer and advance the pointer.

        The values are cast to the dtypes of the storage arrays (``uint8``
        actions, integer rewards, ``float16`` screens, boolean terminals).
        """
        slot = self.current
        self.actions[slot] = action
        self.rewards[slot] = reward
        self.screens[slot, ...] = screen
        self.terminals[slot] = terminal

        self.count = max(self.count, slot + 1)
        # Pointer resets when reaching memory_size
        self.current = (slot + 1) % self.memory_size

    def getState(self, index):
        """Return the ``history_length`` screens ending at ``index``.

        ``index`` is taken modulo ``self.count``, so negative values are
        accepted. An empty buffer (``count == 0``) raises ZeroDivisionError.
        """
        index = index % self.count
        first = index - (self.history_length - 1)
        # If the window does not cross the start of the array, slice directly
        if first >= 0:
            return self.screens[first:index + 1, ...]
        # Otherwise list the wrapped-around indexes explicitly
        wrapped = [(index - back) % self.count
                   for back in reversed(range(self.history_length))]
        return self.screens[wrapped, ...]

    def sample_from_replay(self):
        """Draw ``batch_size`` transitions uniformly at random.

        An index is rejected and redrawn when its history window would
        straddle the write pointer or contain a terminal flag.

        Returns:
            Tuple ``(prestates, actions, rewards, poststates, terminals)``.
            The two state arrays are transposed to
            ``(batch, height, width, history)`` and are views of buffers that
            the next call overwrites; copy them if they must outlive it.

        Raises:
            ValueError: If no more than ``history_length`` experiences have
                been stored yet, so no index can be drawn.
        """
        if self.count <= self.history_length:
            raise ValueError(
                'cannot sample: %d experiences stored, more than %d required'
                % (self.count, self.history_length))

        indexes = []
        while len(indexes) < self.batch_size:
            index = random.randint(self.history_length, self.count - 1)
            # If index wraps over current pointer, get new one
            if index >= self.current and index - self.history_length < self.current:
                continue
            # If index wraps over terminal state, get new one
            if self.terminals[(index - self.history_length):index].any():
                continue
            # Use the index otherwise
            slot = len(indexes)
            self.prestates[slot, ...] = self.getState(index - 1)
            self.poststates[slot, ...] = self.getState(index)
            indexes.append(index)

        actions = self.actions[indexes]
        rewards = self.rewards[indexes]
        terminals = self.terminals[indexes]

        channels_last = (0, 2, 3, 1)
        return (np.transpose(self.prestates, channels_last), actions, rewards,
                np.transpose(self.poststates, channels_last), terminals)

In [None]:
class Experience():
    """Rank-based prioritized experience replay memory.

    Experiences live in a dict keyed by a cyclic slot id running from 1 to
    ``memory_size``. The same ids are registered in a ``BinaryHeap`` together
    with a priority. Sampling probabilities depend only on rank,
    ``P(i) = rank_i^(-alpha) / sum_j rank_j^(-alpha)``, so they are
    precomputed once for a series of fill levels by ``build_distributions``.

    Args:
        memory_size: Maximum number of experiences held at once.
    """

    def __init__(self, memory_size):
        self.size = memory_size
        self.alpha = 0.7            # Exponent applied to the rank in P(i)
        self.beta_zero = 0.5        # Starting value used to derive beta_grad
        self.batch_size = 32        # Also the number of strata per distribution
        self.partition_num = 100    # Split total size to N segments

        self.index = 0              # Slot id written by the latest store()
        self.record_size = 0        # Number of stored experiences, capped at size
        self.learn_start = 1000     # CHANGE
        self._experience = {}
        self.priority_queue = BinaryHeap(self.size)
        self.distributions = self.build_distributions()

        # Slope that takes a value from beta_zero to 1 over the training run.
        # NOTE(review): this reads the globals TOTAL_STEPS and LEARN_START
        # while the distributions use self.learn_start - confirm the two are
        # meant to agree.
        self.beta_grad = (1-self.beta_zero) / (TOTAL_STEPS - LEARN_START)

    def build_distributions(self):
        '''
        Preprocess probabilities: (rank_i)^(-alpha) / sum((rank_i)^(-alpha))

        One distribution is built for every multiple n of
        floor(size / partition_num) that is at least learn_start (and at
        least batch_size, since fewer ranks cannot fill every stratum).

        Returns a dict mapping the 1-based partition number n / partition_size
        to {'pdf': [...], 'strata_ends': {...}}, where
          * pdf[r - 1] is the probability of rank r (1 <= r <= n);
          * strata_ends[s] is the last rank before stratum s, so stratum s
            covers ranks strata_ends[s] + 1 .. strata_ends[s + 1], with
            strata_ends[1] == 0 and strata_ends[batch_size + 1] == n.

        Raises ValueError when size < partition_num (partition size of zero).
        '''
        partition_size = math.floor(self.size / self.partition_num)
        if partition_size < 1:
            raise ValueError(
                'memory size %r is smaller than partition_num %r'
                % (self.size, self.partition_num))

        # The unnormalised weights rank^(-alpha) do not depend on n, so they
        # are computed once and sliced instead of rebuilt per partition.
        rank_weights = [math.pow(rank, -self.alpha)
                        for rank in range(1, self.size + 1)]
        stratum_mass = 1/float(self.batch_size)

        results = {}
        # Creating different distributions according to the number of
        # experiences which have been collected
        fill_levels = range(partition_size, self.size + 1, partition_size)
        for current_partition, n in enumerate(fill_levels, start=1):
            if n < self.learn_start or n < self.batch_size:
                continue

            # P(i) = (rank_i)^(-alpha) / sum((rank_i)^(-alpha))
            weights = rank_weights[:n]
            pdf_sum = math.fsum(weights)
            pdf = [weight/pdf_sum for weight in weights]

            # Split each distribution to K segments, setting k = batch_size
            cdf = np.cumsum(pdf)
            strata_ends = {1: 0, self.batch_size+1: n}
            step = stratum_mass
            index = 1
            for s in range(2, self.batch_size+1):
                while index < n - 1 and cdf[index] < step:
                    index += 1
                # The top ranks can each weigh more than one stratum, which
                # used to give consecutive strata the same end position (an
                # empty stratum). Force the ends to increase strictly while
                # leaving at least one rank for every remaining stratum.
                boundary = max(index, strata_ends[s - 1] + 1)
                boundary = min(boundary, n - (self.batch_size + 1 - s))
                strata_ends[s] = boundary
                step += stratum_mass

            results[current_partition] = {'pdf': pdf,
                                          'strata_ends': strata_ends}

        return results

    def store(self,experience):
        '''
        Store experience: experience is a tuple of (s1,a,r,s2,t)

        The slot id cycles through 1..size, so once the memory is full the
        oldest entry is overwritten. The slot is registered in the priority
        queue with the queue's current maximum priority.
        '''
        if self.record_size < self.size:
            self.record_size += 1
        # 0 -> 1, k -> k + 1, size -> 1
        self.index = self.index % self.size + 1

        self._experience[self.index] = experience
        # Add to priority queue
        priority = self.priority_queue.get_max_priority()
        self.priority_queue.update(priority,self.index)

    def retrieve(self,indices):
        '''
        Get experiences from indices

        indices are slot ids as used by store(); an unknown id raises KeyError.
        '''
        return [self._experience[v] for v in indices]

    def rebalance(self):
        '''
        Rebalance priority queue
        '''
        self.priority_queue.balance_tree()

    def update_priority(self,indices,delta):
        '''
        Update the priority values according to new observations

        The new priority of indices[i] is the absolute value of delta[i].
        '''
        for position, experience_id in enumerate(indices):
            self.priority_queue.update(math.fabs(delta[position]),experience_id)

    def sample_from_replay(self,global_step):
        '''
        Sample a batch of experiences for training step global_step.

        The notebook declared this method without a body, which made the
        whole class a syntax error. Until the sampling logic is written it
        fails loudly instead.
        '''
        raise NotImplementedError(
            'Experience.sample_from_replay has not been implemented yet')

In [16]:
# Worked example of the stratification done in Experience.build_distributions,
# on a problem small enough to print: 10 ranks split into 5 strata.
alpha = 0.9
n = 10
batch_size = 5

# P(i) = (rank_i)^(-alpha) / sum((rank_i)^(-alpha)), for ranks 1..n
pdf = [math.pow(rank, -alpha) for rank in range(1, n+1)]
pdf_sum = math.fsum(pdf)
distribution = {'pdf': [weight/pdf_sum for weight in pdf]}

print(distribution['pdf'])

# Running total of the probabilities; cdf[i] covers ranks 1..i+1
cdf = np.cumsum(distribution['pdf'])

print(cdf)

# First and last boundaries are known up front: nothing before stratum 1,
# and all n ranks once the last stratum is closed.
strata_ends = {1: 0, batch_size+1: n}
step = 1/float(batch_size)
print('step: %.2f' % step)

# Walk the cdf once, recording where each multiple of 1/batch_size is reached
index = 1
for s in range(2, batch_size+1):
    while not cdf[index] >= step:
        index += 1
    strata_ends[s] = index
    step += 1/float(batch_size)

distribution['strata_ends'] = strata_ends

print(strata_ends)

[0.3104488028397141, 0.1663653941798837, 0.11549970106682916, 0.0891530072831946, 0.07293180893490515, 0.061894757267151136, 0.05387679509211429, 0.04777591365571642, 0.042970630984892164, 0.03908318869559933]
[ 0.3104488   0.4768142   0.5923139   0.68146691  0.75439871  0.81629347
  0.87017027  0.91794618  0.96091681  1.        ]
step: 0.20
{1: 0, 2: 1, 3: 1, 4: 3, 5: 5, 6: 10}
