In [1]:
import tensorflow as tf
import numpy as np

In [9]:
class Memory(object):
    """Prioritized experience replay buffer backed by an array sum tree.

    ``tree`` is a flat binary tree of ``2 * capacity - 1`` nodes.  The last
    ``capacity`` nodes are leaves holding one priority per slot of ``data``;
    every internal node holds the sum of its two children, so ``tree[0]`` is
    the total priority.  Slots are written in ring-buffer order, overwriting
    the oldest entry once the buffer is full.

    Slots that were never written keep a priority of exactly zero, so they
    carry no sampling mass and do not take part in the importance-weight
    normalisation.
    """

    # Added to every error magnitude so a stored entry never has zero priority.
    eps = 0.001
    # Exponent applied to the clipped error to obtain the priority.
    alpha = 0.6
    # Importance-sampling exponent; moved towards 1 on every ``sample`` call.
    beta = 0.4
    # Amount added to ``beta`` per ``sample`` call.
    beta_inc = 1e-5
    # Upper bound the error magnitude is clipped to before exponentiation.
    abs_err = 1
    # Index of the next ``data`` slot to write (class-level default kept for
    # compatibility; each instance gets its own value in ``__init__``).
    data_pointer = 0

    def __init__(self, capacity):
        """Create an empty buffer with room for ``capacity`` entries.

        Raises:
            ValueError: if ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError("capacity must be at least 1")
        self.capacity = capacity
        # Zero-initialised so that every parent equals the sum of its children
        # from the very first insert onwards.
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.data_pointer = 0
        # Number of slots written so far, saturating at ``capacity``.
        self._size = 0

    @property
    def root_priority(self):
        """Total priority of all stored entries (value of the tree root)."""
        return self.tree[0]

    def remember(self, error, data):
        """Store ``data`` with a priority derived from ``error``.

        The entry is written at ``data_pointer``, which then advances and
        wraps to 0 after the last slot.
        """
        leaf_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self._set_priority(leaf_idx, self._get_priority(error))

        self.data_pointer += 1
        if self.data_pointer >= self.capacity:
            self.data_pointer = 0
        self._size = min(self._size + 1, self.capacity)

    def sample(self, n):
        """Draw ``n`` entries by stratified, priority-proportional sampling.

        The total priority is split into ``n`` equal strata and one value is
        drawn uniformly from each.  ``beta`` is increased by ``beta_inc``
        (capped at 1) on every call.

        Returns:
            A tuple ``(batch_idx, batch_memory, weights)``: the list of tree
            indices of the sampled leaves (usable with ``update``), the
            sampled entries stacked with ``np.vstack``, and an ``(n, 1)``
            array of importance-sampling weights normalised by the largest
            possible weight among the stored entries.

        Raises:
            ValueError: if ``n`` is smaller than 1 or the buffer is empty.
        """
        if n < 1:
            raise ValueError("n must be at least 1")
        if self._size == 0:
            raise ValueError("cannot sample from an empty memory")

        total = self.root_priority
        segment = total / n
        self.beta = min(1.0, self.beta + self.beta_inc)

        # Only slots that were actually written take part in the
        # normalisation; leaf ``first_leaf + k`` belongs to ``data[k]``.
        first_leaf = self.capacity - 1
        stored = self.tree[first_leaf:first_leaf + self._size]
        min_prob = np.min(stored) / total
        max_weight = np.power(self.capacity * min_prob, -self.beta)

        batch_idx, batch_memory, priorities = [], [], []
        for stratum in range(n):
            low = segment * stratum
            high = segment * (stratum + 1)
            while True:
                idx = self._recover(np.random.uniform(low, high))
                data_idx = idx - first_leaf
                # Unwritten leaves have zero mass, so this only triggers on
                # floating-point edge cases; redraw inside the same stratum.
                if data_idx < self._size:
                    break
            batch_idx.append(idx)
            batch_memory.append(self.data[data_idx])
            priorities.append(self.tree[idx])

        probs = np.vstack(priorities) / total
        weights = np.power(self.capacity * probs, -self.beta) / max_weight
        return batch_idx, np.vstack(batch_memory), weights

    def update(self, idx, error):
        """Recompute the priority of tree leaf ``idx`` from ``error``."""
        self._set_priority(idx, self._get_priority(error))

    def _set_priority(self, tree_idx, priority):
        """Write ``priority`` at ``tree_idx`` and fix up all ancestor sums."""
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        self._propagate_change(tree_idx, change)

    def _get_priority(self, error):
        """Return ``min(|error| + eps, abs_err) ** alpha``.

        A new value is computed, so an array passed by the caller is left
        untouched.  Taking the magnitude keeps the priority strictly positive
        for negative errors as well.
        """
        magnitude = np.abs(error) + self.eps
        return np.power(np.minimum(magnitude, self.abs_err), self.alpha)

    def _propagate_change(self, tree_idx, change):
        """Add ``change`` to every ancestor of ``tree_idx`` up to the root.

        Does nothing when ``tree_idx`` is the root itself, which is the case
        for the single leaf of a ``capacity == 1`` buffer.
        """
        while tree_idx > 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def _recover(self, lower_bound, parent_i=0):
        """Return the index of the leaf that covers ``lower_bound``.

        Starting at ``parent_i``, descend left when the remaining value fits
        in the left child's mass; otherwise subtract that mass and descend
        right.
        """
        node = parent_i
        n_nodes = len(self.tree)
        while True:
            lchild_i = 2 * node + 1
            if lchild_i >= n_nodes:
                return node
            left_mass = self.tree[lchild_i]
            if lower_bound <= left_mass:
                node = lchild_i
            else:
                # Plain subtraction (not ``-=``) so a caller-owned array
                # scalar is never modified in place.
                lower_bound = lower_bound - left_mass
                node = lchild_i + 1
