# Replay Buffer



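**Contents:**
1. Basic Replay Buffer

2. Analysis of Proportion-based Prioritized Experience Replay

3. Analysis of Rank-based Prioritized Experience Replay

---

This section focuses on how to analyze and design these replay buffers. Every design involves trade-offs, so the design choices made here are not the only reasonable ones.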

## 1 Basic Replay Buffer

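### 1) Functional analysis

Before designing more complex structures, we start with a simple one. A basic Replay Buffer only needs three functions:
  * 1. Record newly added transitions
  * 2. Forget transitions that are too old
  * 3. Sample from the stored memory
  
### 2) Implementation

These functions are all simple. If we store every transition in a list, we can sample with the random.sample() function.

Old transitions can be forgotten either automatically, by using a deque, or by overwriting transitions that already exist in the list. Here we choose the second approach.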
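
The two forgetting strategies mentioned above can be compared on a toy example. This is only a sketch: `CAPACITY`, `dq`, and `ring` are illustrative names that do not appear in the notebook, and plain integers stand in for transitions.

```python
import random
from collections import deque

CAPACITY = 3  # hypothetical toy capacity; integers stand in for transitions

# Option 1: a deque with maxlen drops the oldest item automatically once full.
dq = deque(maxlen=CAPACITY)
# Option 2: a preallocated list overwritten in a ring (the approach chosen here).
ring = [None] * CAPACITY

for t in range(5):
    dq.append(t)
    ring[t % CAPACITY] = t  # the oldest slot is overwritten once the list is full

stored = list(dq)                 # the three most recent items, oldest first
batch = random.sample(stored, 2)  # uniform sampling without replacement
```

Both containers end up holding the same three most recent items, only in a different order. One reason to prefer the list is that indexing into the middle of a deque is not constant time, while list indexing is.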

### 3) Code

In [1]:
import random
import torch
import numpy as np

class Replay_Buffer:
    '''
    Vanilla replay buffer
    '''
    
    def __init__(self, capacity=int(1e6), batch_size=None):
        
        self.capacity = capacity
        self.memory = [None for _ in range(capacity)] # save tuples (state, action, reward, next_state, done)
        self.ind_max = 0 # how many transitions have been stored
        
    def remember(self, state, action, reward, next_state, done):
        
        ind = self.ind_max % self.capacity
        self.memory[ind] = (state, action, reward, next_state, done)
        self.ind_max += 1
        
    def sample(self, k):
        '''
        Return k transitions sampled uniformly without replacement.
        Make sure that at least k transitions are stored before calling this method.
        '''
        index_set = random.sample(range(len(self)), k)
        states = torch.from_numpy(np.vstack([self.memory[ind][0] for ind in index_set])).float()
        actions = torch.from_numpy(np.vstack([self.memory[ind][1] for ind in index_set])).long()
        rewards = torch.from_numpy(np.vstack([self.memory[ind][2] for ind in index_set])).float()
        next_states = torch.from_numpy(np.vstack([self.memory[ind][3] for ind in index_set])).float()
        dones = torch.from_numpy(np.vstack([self.memory[ind][4] for ind in index_set]).astype(np.uint8)).float()
        
        return states, actions, rewards, next_states, dones
    
    def __len__(self):
        return min(self.ind_max, self.capacity)

## 2 Proportion-based Replay Buffer

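### 1) Functional analysis

Besides the basic functions above, the Proportion-based Replay Buffer also needs a Sum Tree that records and updates the cumulative weights so that sampling is fast. The new functions to add are:
  * 1. A Sum Tree that stores and updates the weight of each transition
  * 2. A method for updating the Sum Tree
  
### 2) Implementation

Since we already have a Sum Tree implementation, we only need to use it correctly here. The only thing left to implement is a method that updates the Sum Tree.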
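
For intuition about what the Sum Tree computes, the same sampling rule can be written with a flat list of cumulative sums. This is a sketch under assumptions: `weights` holds made-up priorities and `sample_index` is a hypothetical helper, not part of the notebook. A Sum Tree answers the same query in $O(\log(n))$ time and, unlike the flat list, can also update a single weight in $O(\log(n))$ time.

```python
import bisect
import itertools
import random

weights = [1.0, 3.0, 0.5, 2.5]                # hypothetical priorities
cumsum = list(itertools.accumulate(weights))  # [1.0, 4.0, 4.5, 7.0]
total = cumsum[-1]

def sample_index(u):
    """Return the first index whose cumulative sum is no smaller than u."""
    return bisect.bisect_left(cumsum, u)

# Drawing u uniformly from [0, total) selects index i with probability weights[i] / total.
random.seed(0)
n_draws = 20000
counts = [0] * len(weights)
for _ in range(n_draws):
    counts[sample_index(random.random() * total)] += 1
freqs = [c / n_draws for c in counts]
```

The empirical frequencies in `freqs` should be close to `weights[i] / total`, which is the proportional sampling rule the buffer relies on.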

To make it easy to plug into an agent, the calling interface should match the basic Replay Buffer class as closely as possible.

### 3) Code

In [2]:
class SumTree:
    
    def __init__(self, capacity):
        
        self.capacity = capacity
        # the first capacity-1 positions are not leaves
        self.vals = [0 for _ in range(2*capacity - 1)] # think about why if you are not familiar with this
        
    def retrive(self, num):
        '''
        Find the first leaf index whose cumulative sum is no smaller than num
        '''
        ind = 0 # search from root
        while ind < self.capacity-1: # not a leaf
            left = 2*ind + 1
            right = left + 1
            if num > self.vals[left]: # the sum of the whole left subtree is not large enough
                num -= self.vals[left] # think about why?
                ind = right
            else: # search in the left tree
                ind = left
        return ind - self.capacity + 1
    
    def update(self, delta, ind):
        '''
        Change the value at ind by delta, and update the tree
        Notice that this ind should be the index in real memory part, instead of the ind in self.vals
        '''
        ind += self.capacity - 1
        while True:
            self.vals[ind] += delta
            if ind == 0:
                break
            ind -= 1
            ind //= 2

In [3]:
EPSILON = 0.05
ALPHA = 0.5
TD_INIT = 1

class Proportion_Replay_Buffer:
    '''
    Proportion-based replay buffer
    '''
    
    def __init__(self, capacity=int(1e6), batch_size=None):
        self.capacity = capacity
        self.alpha = ALPHA
        self.memory = [None for _ in range(capacity)]
        self.weights = SumTree(self.capacity)
        self.default = TD_INIT
        self.ind_max = 0
        
    def remember(self, state, action, reward, next_state, done):
        index = self.ind_max % self.capacity
        self.memory[index] = (state, action, reward, next_state, done)
        delta = (self.default+EPSILON)**self.alpha - self.weights.vals[index+self.capacity-1]
        self.weights.update(delta, index)
        self.ind_max += 1
        
    def sample(self, batch_size):
        index_set = [self.weights.retrive(self.weights.vals[0]*random.random()) for _ in range(batch_size)]
        #print(index_set)
        probs = torch.from_numpy(np.vstack([self.weights.vals[ind+self.capacity-1]/self.weights.vals[0] for ind in index_set])).float()                     
        
        states = torch.from_numpy(np.vstack([self.memory[ind][0] for ind in index_set])).float()
        actions = torch.from_numpy(np.vstack([self.memory[ind][1] for ind in index_set])).long()
        rewards = torch.from_numpy(np.vstack([self.memory[ind][2] for ind in index_set])).float()
        next_states = torch.from_numpy(np.vstack([self.memory[ind][3] for ind in index_set])).float()
        dones = torch.from_numpy(np.vstack([self.memory[ind][4] for ind in index_set]).astype(np.uint8)).float()

        return index_set, states, actions, rewards, next_states, dones, probs
                                 
    def insert(self, error, index):
        # the priority uses the magnitude of the TD-error, so negative errors are valid input
        delta = (abs(error)+EPSILON)**self.alpha - self.weights.vals[index+self.capacity-1]
        self.weights.update(delta, index)
            
    def __len__(self):
        return min(self.capacity, self.ind_max)
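
The notebook does not show how the returned `probs` are consumed. One common use is the importance-sampling correction $w_i = (N \cdot P(i))^{-\beta}$, rescaled by the largest weight in the batch. The sketch below assumes that use: `importance_weights` and the default `beta=0.4` are illustrative choices, not values taken from this notebook.

```python
import numpy as np

def importance_weights(probs, buffer_len, beta=0.4):
    """Compute w_i = (N * P(i))^(-beta), rescaled so the largest weight is 1.

    probs      : sampling probabilities of the sampled transitions
    buffer_len : N, the number of transitions currently stored
    beta       : hypothetical correction strength (1.0 means full correction)
    """
    probs = np.asarray(probs, dtype=np.float64)
    weights = (buffer_len * probs) ** (-beta)
    return weights / weights.max()

# The transition sampled with higher probability receives the smaller weight.
w = importance_weights([0.5, 0.25, 0.25], buffer_len=3)
```

Under this assumption, each weight would multiply the loss of its sampled transition. With uniform probabilities every weight equals 1, so the correction disappears.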

## 3 Rank-based Replay Buffer

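### 1) Functional analysis

Besides the basic functions above, the Rank-based Replay Buffer also needs to know the rank of each transition's TD-error in order to adjust its weight. Based on this rank, it also needs to compute and store the segment boundaries used by the sampling scheme described in the paper. Roughly, the new functions needed are:
  * 1. Storage of the TD-error of every transition and its corresponding rank
  * 2. A method for updating the ranks
  
### 2) Implementation

Training changes many TD-errors and keeps adding new transitions. To update ranks quickly, we need to maintain a sorted sequence of all TD-errors at all times, so that a rank can be determined in $O(\log(n))$ time. Otherwise, every new sample, and every transition whose TD-error changes after training, would need $O(n)$ time to update its rank.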
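
Looking up a rank in a sorted sequence can be sketched with the standard `bisect` module. The TD-error values below are made up for illustration. Storing the negated error means the largest TD-error ends up at rank 0.

```python
import bisect

# Sorted list of (-TD_error, transition_index) tuples.
errors = []
for index, td_error in enumerate([0.2, 0.9, 0.5]):  # hypothetical TD-errors
    bisect.insort(errors, (-td_error, index))

# Binary search for the rank that a new transition with TD-error 0.7 would get.
rank_of_new = bisect.bisect(errors, (-0.7, 3))
```

Here `errors` ends up ordered from the largest TD-error to the smallest, and the new transition would land between the errors 0.9 and 0.5. Only the search is logarithmic: inserting into a Python list still shifts the elements behind the insertion point.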

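On top of this sorted sequence of TD-errors, the fast way to sample is to sample within the sequence and then map the result back to a concrete transition. There are two ways to store the data:
  * 1. Store each transition together with its TD-error in a tuple
  * 2. Store the transitions in a list, and store each transition's index together with its TD-error in a tuple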

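Here we choose the second option, because what we really need is a transition $\longrightarrow$ rank mapping. Without it, deleting a transition would take $O(n)$ time to find the rank and TD-error that should be deleted. There is therefore no need to store a transition together with its TD-error. The storage layout we choose is:
  * As before, a list that stores the transitions
  * A second list that stores the rank of each transition
  * A third list that stores (TD-error, transition index) tuples

With this layout, every lookup can be done in $O(1)$ time.

On top of this, we also need to handle changes to the rank list and the TD-error list, mainly insertions and deletions, which take $O(n)$ time on average.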
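
The segment boundaries used for rank-based sampling can be sketched as follows, assuming the weight of 0-indexed rank $r$ is $(1/(1+r))^\alpha$ as in the code of this section. The function `rank_segments` is a hypothetical standalone helper, not part of the notebook's class, and the final clamp to `n - 1` is an extra safety measure added for this sketch.

```python
import bisect
import itertools

def rank_segments(n, batch_size, alpha=0.5):
    """Split ranks 0..n-1 into batch_size contiguous segments of roughly equal probability mass.

    Returns the last rank (inclusive) of each segment.
    """
    weights = [(1.0 / (rank + 1)) ** alpha for rank in range(n)]
    cumulated = list(itertools.accumulate(weights))
    total = cumulated[-1]
    ends, prev = [], -1
    for i in range(batch_size):
        end = bisect.bisect_left(cumulated, total * ((i + 1) / batch_size))
        end = min(max(end, prev + 1), n - 1)  # keep every segment non-empty and in range
        ends.append(end)
        prev = end
    return ends

ends = rank_segments(n=7, batch_size=3)
```

For 7 transitions and a batch size of 3, the segments cover ranks 0-1, 2-3, and 4-6. One rank is then drawn uniformly from each segment, so high-ranked transitions, which sit in shorter segments, are sampled more often.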

### 3) Code

Note that the code here is not fully optimized. Finding further room for optimization is left as homework.

In [4]:
import bisect

class Rank_Replay_Buffer:
    '''
    Rank-based replay buffer
    '''
    
    def __init__(self, capacity=int(1e6), batch_size=64):
        self.capacity = capacity
        self.batch_size = batch_size
        self.alpha = ALPHA
        self.memory = [None for _ in range(capacity)]
        self.segments = [-1] + [None for _ in range(batch_size)] # the ith index will be in [segments[i-1]+1, segments[i]]
        
        self.errors = [] # saves (-TD_error, index of transition), sorted
        self.memory_to_rank = [None for _ in range(capacity)]
        
        self.ind_max = 0 # how many transitions have been stored
        self.total_weights = 0 # sum of p_i
        self.cumulated_weights = []
        
    def remember(self, state, action, reward, next_state, done):
        index = self.ind_max % self.capacity
        if self.ind_max >= self.capacity: # memory is full, need to pop
            self.pop(index)
        else: # memory is not full, need to adjust weights and find segment points
            self.total_weights += (1/(1+self.ind_max))**self.alpha # memory is not full, calculate new weights
            self.cumulated_weights.append(self.total_weights)
            self.update_segments()
        
        max_error = -self.errors[0][0] if self.errors else 0
        self.insert(max_error, index)
        self.memory[index] = (state, action, reward, next_state, done)
        self.ind_max += 1
        
    def sample(self, batch_size=None): # notice that batch_size is not used; it only keeps the call signature consistent
        index_set = [random.randint(self.segments[i]+1, self.segments[i+1]) for i in range(self.batch_size)]
        probs = torch.from_numpy(np.vstack([(1/(1+ind))**self.alpha/self.total_weights for ind in index_set])).float()
        
        index_set = [self.errors[ind][1] for ind in index_set]
        states = torch.from_numpy(np.vstack([self.memory[ind][0] for ind in index_set])).float()
        actions = torch.from_numpy(np.vstack([self.memory[ind][1] for ind in index_set])).long()
        rewards = torch.from_numpy(np.vstack([self.memory[ind][2] for ind in index_set])).float()
        next_states = torch.from_numpy(np.vstack([self.memory[ind][3] for ind in index_set])).float()
        dones = torch.from_numpy(np.vstack([self.memory[ind][4] for ind in index_set]).astype(np.uint8)).float()
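        # sampled transitions are removed from the sorted errors here;
        # the caller must call insert() with the new TD-error to put each one back
        for ind in index_set:
            self.pop(ind)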
        
        return index_set, states, actions, rewards, next_states, dones, probs
    
    def insert(self, error, index):
        '''
        Input : 
            error : the TD-error of this transition
            index : the location of this transition
        insert error into self.errors and update self.memory_to_rank accordingly
        '''
        ind = bisect.bisect(self.errors, (-error, index))
        self.memory_to_rank[index] = ind
        self.errors.insert(ind, (-error, index))
        for i in range(ind+1, len(self.errors)):
            self.memory_to_rank[self.errors[i][1]] += 1
        
    def pop(self, index):
        '''
        Input :
            index : the location of a transition
        remove this transition from self.errors and update self.memory_to_rank accordingly
        '''
        ind = self.memory_to_rank[index]
        self.memory_to_rank[index] = None
        self.errors.pop(ind)
        for i in range(ind, len(self.errors)):
            self.memory_to_rank[self.errors[i][1]] -= 1
        
    def update_segments(self):
        '''
        Update the segment points.
        '''
        if self.ind_max+1 < self.batch_size: # if there is no enough transitions
            return None
        for i in range(self.batch_size):
            ind = bisect.bisect_left(self.cumulated_weights, self.total_weights*((i+1)/self.batch_size))
            self.segments[i+1] = max(ind, self.segments[i]+1)
            
    def __len__(self):
        return min(self.capacity, self.ind_max)

## 4 Debug

In [5]:
states = np.random.standard_normal((11,5))
actions = np.random.randint(low=0, high=4, size=10).reshape(-1, 1)
rewards = np.random.rand(10, 1)
dones = np.array([random.random()<0.2 for _ in range(10)])

def get_transition(i):
    # the order must match remember(state, action, reward, next_state, done)
    return states[i, :], actions[i], rewards[i], states[i+1, :], dones[i]

In [6]:
capacity = 7
batch_size = 3
N = 100

test = Proportion_Replay_Buffer(capacity=capacity, batch_size=batch_size)
for i in range(N):
    test.remember(*get_transition(i % 10))
    print('Episode {}'.format(i+1))
    print('Current errors are : {}'.format(test.weights.vals[capacity-1:]))
    if len(test) >= batch_size:
        temp_index, temp_states, temp_actions, temp_rewards, temp_next_states, temp_dones, temp_probs = test.sample(batch_size)
        temp_probs = temp_probs.numpy().reshape(-1)
        delta = [np.round(random.random(), 3) for _ in range(batch_size)]
        print('Sampled transitions {} with probs {}'.format(temp_index, temp_probs))
        print('Generated TD-error : {}'.format(delta))
        for ind, error in zip(temp_index, delta):
            test.insert(error, ind)
        print('=========================After updates========================')
        print('Current errors are : {}'.format(test.weights.vals[capacity-1:]))
    print()

Episode 1
Current errors are : [1.05, 0, 0, 0, 0, 0, 0]

Episode 2
Current errors are : [1.05, 1.05, 0, 0, 0, 0, 0]

Episode 3
Current errors are : [1.05, 1.05, 1.05, 0, 0, 0, 0]
Sampled transitions [1, 1, 0] with probs [0.33333334 0.33333334 0.33333334]
Generated TD-error : [0.389, 0.187, 0.522]
Current errors are : [0.5720000000000001, 0.237, 1.05, 0, 0, 0, 0]

Episode 4
Current errors are : [0.5720000000000001, 0.237, 1.05, 1.05, 0, 0, 0]
Sampled transitions [0, 0, 2] with probs [0.19663115 0.19663115 0.36094877]
Generated TD-error : [0.053, 0.985, 0.613]
Current errors are : [1.035, 0.237, 0.663, 1.05, 0, 0, 0]

Episode 5
Current errors are : [1.035, 0.237, 0.663, 1.05, 1.05, 0, 0]
Sampled transitions [1, 0, 2] with probs [0.05873606 0.25650558 0.16431227]
Generated TD-error : [0.513, 0.054, 0.614]
Current errors are : [0.10399999999999998, 0.5630000000000001, 0.664, 1.05, 1.05, 0, 0]

Episode 6
Current errors are : [0.10399999999999998, 0.5630000000000001, 0.664, 1.05, 1.05, 1.05,

In [7]:
capacity = 7
batch_size = 3
N = 100

test = Rank_Replay_Buffer(capacity=capacity, batch_size=batch_size)
for i in range(N):
    test.remember(*get_transition(i % 10))
    print('Episode {}'.format(i+1))
    print('Current relations (memory_to_rank) are : {}'.format(test.memory_to_rank))
    print('Current errors (td_error, index) are : {}'.format(test.errors))
    print('Current segment weights are : {}'.format([test.total_weights * (i / batch_size) for i in range(1, batch_size+1)]))
    print('Current cumulated sums are : {}'.format(test.cumulated_weights))
    print('Current segment is : {}'.format(test.segments))
    if len(test) >= batch_size:
        temp_index, temp_states, temp_actions, temp_rewards, temp_next_states, temp_dones, temp_probs = test.sample(batch_size)
        temp_probs = temp_probs.numpy().reshape(-1)
        delta = [np.round(random.random(), 3) for _ in range(batch_size)]
        print('Sampled transitions {} with probs {}'.format(temp_index, temp_probs))
        print('Generated TD-error : {}'.format(delta))
        for ind, error in zip(temp_index, delta):
            test.insert(error, ind)
        print('=========================After updates========================')
        print('Current relations (memory_to_rank) are : {}'.format(test.memory_to_rank))
        print('Current errors (td_error, index) are : {}'.format(test.errors))
    print()

Episode 1
Current relations (memory_to_rank) are : [0, None, None, None, None, None, None]
Current errors (td_error, index) are : [(0, 0)]
Current segment weights are : [0.3333333333333333, 0.6666666666666666, 1.0]
Current cumulated sums are : [1.0]
Current segment is : [-1, None, None, None]

Episode 2
Current relations (memory_to_rank) are : [0, 1, None, None, None, None, None]
Current errors (td_error, index) are : [(0, 0), (0, 1)]
Current segment weights are : [0.5690355937288492, 1.1380711874576983, 1.7071067811865475]
Current cumulated sums are : [1.0, 1.7071067811865475]
Current segment is : [-1, None, None, None]

Episode 3
Current relations (memory_to_rank) are : [0, 1, 2, None, None, None, None]
Current errors (td_error, index) are : [(0, 0), (0, 1), (0, 2)]
Current segment weights are : [0.7614856834587244, 1.5229713669174487, 2.284457050376173]
Current cumulated sums are : [1.0, 1.7071067811865475, 2.284457050376173]
Current segment is : [-1, 0, 1, 2]
Sampled transitions [0

Sampled transitions [0, 4, 3] with probs [0.24888726 0.14369513 0.09407055]
Generated TD-error : [0.087, 0.467, 0.464]
Current relations (memory_to_rank) are : [6, 0, 1, 4, 3, 2, 5]
Current errors (td_error, index) are : [(-0.999, 1), (-0.831, 2), (-0.732, 5), (-0.467, 4), (-0.464, 3), (-0.271, 6), (-0.087, 0)]

Episode 73
Current relations (memory_to_rank) are : [6, 0, 1, 4, 3, 2, 5]
Current errors (td_error, index) are : [(-0.999, 1), (-0.999, 2), (-0.732, 5), (-0.467, 4), (-0.464, 3), (-0.271, 6), (-0.087, 0)]
Current segment weights are : [1.3392944697830738, 2.6785889395661475, 4.0178834093492215]
Current cumulated sums are : [1.0, 1.7071067811865475, 2.284457050376173, 2.784457050376173, 3.231670645876131, 3.639918936339994, 4.0178834093492215]
Current segment is : [-1, 1, 3, 6]
Sampled transitions [1, 4, 3] with probs [0.24888726 0.12444363 0.11130577]
Generated TD-error : [0.208, 0.53, 0.946]
Current relations (memory_to_rank) are : [6, 5, 0, 1, 3, 2, 4]
Current errors (td_erro