[explaining article](https://www.fcodelabs.com/2019/03/18/Sum-Tree-Introduction/)

In [1]:
%load_ext cython

In [2]:
%matplotlib inline
import matplotlib.pyplot as plot
import numba
import numpy as np
import seaborn as sns

In [3]:
import cProfile
import pstats

In [4]:
# benchmark workload: fill, sample, update, repeat
def buffer_benchmark(factory, max_frame, batch_size):
    buffer = factory()
    for frame in range(max_frame):
        buffer.add('frame%d' % frame, 1.0)
        if len(buffer) > batch_size:
            indices, weights, objects = buffer.sample(batch_size)
            assert len(indices) == len(weights) == len(objects) == batch_size, 'invalid output of sample()'
            weights = np.clip(weights + np.random.random(weights.shape) - 0.5, 0.01, 1.0)
            buffer.update(indices, weights)

In [5]:
class NaivePrioritizedBuffer(object):
    def __init__(self, capacity, prob_alpha=0.6):
        self.prob_alpha = prob_alpha
        self.capacity = capacity
        self.buffer = []
        self.pos = 0
        self.priorities = np.zeros((capacity,), dtype=np.float32)
    
    def add(self, obj, p):
        max_prio = np.max(self.priorities) if self.buffer else 1.0

        if len(self.buffer) < self.capacity:
            self.buffer.append(obj)
        else:
            self.buffer[self.pos] = obj
        
        self.priorities[self.pos] = max_prio
        self.pos = (self.pos + 1) % self.capacity
    
    def sample(self, n, beta=0.4):
        if len(self.buffer) == self.capacity:
            prios = self.priorities
        else:
            prios = self.priorities[: self.pos]
        probs = prios ** self.prob_alpha
        probs /= probs.sum()
        
        indices = np.random.choice(len(self.buffer), n, p=probs)
        samples = [self.buffer[idx] for idx in indices]
        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights /= weights.max()
        
        return indices, weights, samples
    
    def update(self, indices, priorities):
        for idx, prio in zip(indices, priorities):
            self.priorities[idx] = prio
    
    def __len__(self):
        return len(self.buffer)

In [6]:
factory = lambda: NaivePrioritizedBuffer(10000, 0.6)
%timeit buffer_benchmark(factory, 10000, 32)
%timeit buffer_benchmark(factory, 20000, 32)
%timeit buffer_benchmark(factory, 50000, 32)
%timeit buffer_benchmark(factory, 100000, 32)

1.94 s ± 57.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
4.87 s ± 83.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
13.6 s ± 177 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
28 s ± 58.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


most of sumtree code derived from [github](https://github.com/chuyangliu/Snake/blob/master/snake/util/sumtree.py):

In [7]:
%%cython -c=-O3
# cython: boundscheck=False
import numpy as np
cimport numpy as np

cdef class SumTree:
    cpdef float[:] tree
    cpdef np.ndarray nd_tree
    cpdef size_t capacity
    cpdef size_t pos
    cpdef size_t count
    def __init__(self, capacity : size_t):
        self.nd_tree = np.zeros((capacity * 2 - 1, ), dtype=np.float32)
        self.tree = self.nd_tree
        self.capacity = capacity
        self.pos = 0
        self.count = 0

    def get_one(self, s : float) -> [size_t, float]:
        cdef size_t parent = 0
        cdef size_t left, right
        with nogil:
            while parent < self.capacity - 1:
                left = parent * 2 + 1
                right = left + 1
                if s <= self.tree[left]:
                    parent = left
                else:
                    s -= self.tree[left]
                    parent = right
        return parent - self.capacity + 1, self.tree[parent]
    
    def add_one(self, value : float) -> size_t:
        cdef size_t idx = self.pos
        self.update_one(idx, value)
        self.pos = (self.pos + 1) % self.capacity
        self.count = min(self.count + 1, self.capacity)
        return idx
    
    def update_one(self, idx : size_t, value : float):
        idx = idx + self.capacity - 1
        cdef float[:] tree = self.tree
        cdef float change = value - tree[idx]
        with nogil:
            while True:
                tree[idx] += change
                if 0 == idx:
                    break
                idx = (idx - 1) // 2

    def sum_total(self) -> float:
        return self.tree[0]
    
    def __len__(self) -> size_t:
        return self.count
        
class SumTreePrioritizedBufferCython(object):
    def __init__(self, capacity : size_t, alpha : float):
        self.capacity = capacity
        self.tree = SumTree(capacity)
        self.data = np.empty(capacity, dtype=object)

    def __len__(self) -> size_t:
        return len(self.tree)
    
    def add(self, obj : object, p : float):
        idx = self.tree.add_one(p)
        self.data[idx] = obj
    
    def update(self, indices, priorities):
        for i in range(indices.size):
            idx, prio = indices[i], priorities[i]
            self.tree.update_one(int(idx), float(prio))

    def sample(self, n : size_t):
        segment = self.tree.sum_total() / float(n)
        a = np.arange(n) * segment
        b = a + segment
        s = np.random.uniform(a, b)
        indices = np.zeros(n, dtype=np.uint32)
        weights = np.empty(n, dtype=np.float32)
        samples = np.empty(n, dtype=object)
        for i in range(n):
            idx, prio = self.tree.get_one(s[i])
            indices[i] = idx
            weights[i] = prio
            samples[i] = self.data[idx]
        return indices, weights, samples

In [8]:
factory = lambda: SumTreePrioritizedBufferCython(10000, 0.6)
%timeit buffer_benchmark(factory, 10000, 32)
%timeit buffer_benchmark(factory, 20000, 32)
%timeit buffer_benchmark(factory, 50000, 32)
%timeit buffer_benchmark(factory, 100000, 32)

597 ms ± 2.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.21 s ± 3.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.05 s ± 35.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
6.04 s ± 39.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
class SumTreePrioritizedBufferNumba(object):
    def __init__(self, capacity, prob_alpha=0.6):
        self.capacity, self.prob_alpha = capacity, prob_alpha
        self.tree = np.zeros(capacity * 2 - 1, dtype=np.float32)
        self.data = np.empty(capacity, dtype=object)
        self.pos, self.len = 0, 0

    @numba.jit
    def _update(self, idx, value):
        tree = self.tree
        change = value - tree[idx]
        tree[idx] = value
        while True:
            tree[idx] += change
            if 0 == idx:
                break
            idx = (idx - 1) // 2
    
    @numba.jit
    def _retrieve(self, s):
        tree_idx, parent = None, 0
        while True:
            if parent >= self.capacity - 1:
                tree_idx = parent
                break
            left = parent * 2 + 1
            right = left + 1
            if s <= self.tree[left]:
                parent = left
            else:
                s -= self.tree[left]
                parent = right
        return tree_idx
    
    @numba.jit
    def add(self, obj, p=1):
        idx = self.pos + self.capacity - 1
        self.data[self.pos] = obj
        self._update(idx, p)
        self.pos = (self.pos + 1) % self.capacity
        self.len = min(self.len + 1, self.capacity)

    @numba.jit
    def sample(self, n, beta=0.4):
        segment = self.tree[0] / n
        a = np.arange(n) * segment
        b = a + segment
        s = np.random.uniform(a, b)
        indices = np.zeros(n, dtype=np.int32)
        weights = np.empty(n, dtype=np.float32)
        samples = np.empty(n, dtype=object)
        for i in range(n):
            idx = self._retrieve(s[i])
            indices[i] = idx
            weights[i] = self.tree[idx]
            samples[i] = self.data[idx - self.capacity + 1]
        return indices, weights, samples
    
    @numba.jit
    def update(self, indices, priorities):
        for idx, prio in zip(indices, priorities):
            self._update(idx, prio)
    
    def __len__(self):
        return self.len

In [10]:
pr = cProfile.Profile()
pr.enable()
buffer_benchmark(factory, 10000, 32)
pr.disable()
st = pstats.Stats(pr).sort_stats('tottime')
st.print_stats()

         149576 function calls in 0.647 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.540    0.540    0.647    0.647 <ipython-input-4-64489c64b0f7>:2(buffer_benchmark)
     9968    0.029    0.000    0.029    0.000 {method 'reduce' of 'numpy.ufunc' objects}
     9968    0.014    0.000    0.048    0.000 /home/user/.local/lib/python3.7/site-packages/numpy/core/fromnumeric.py:69(_wrapreduction)
     9968    0.013    0.000    0.013    0.000 {method 'random_sample' of 'mtrand.RandomState' objects}
     9968    0.012    0.000    0.012    0.000 {method 'clip' of 'numpy.ndarray' objects}
     9969    0.010    0.000    0.010    0.000 /home/user/.local/lib/python3.7/site-packages/numpy/core/_internal.py:886(npy_ctypes_check)
     9968    0.008    0.000    0.056    0.000 /home/user/.local/lib/python3.7/site-packages/numpy/core/fromnumeric.py:2171(all)
    39904    0.006    0.000    0.006    0.000 {built-in method built

<pstats.Stats at 0x7f6b4d79cf98>

In [11]:
factory = lambda: SumTreePrioritizedBufferNumba(10000, 0.6)
%timeit buffer_benchmark(factory, 10000, 32)
%timeit buffer_benchmark(factory, 20000, 32)
%timeit buffer_benchmark(factory, 50000, 32)
%timeit buffer_benchmark(factory, 100000, 32)

14.3 s ± 50 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
29 s ± 94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1min 12s ± 243 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2min 25s ± 727 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
factory = lambda: SumTreePrioritizedBufferCython(100000, 0.6)
%timeit buffer_benchmark(factory, 100000, 32)
%timeit buffer_benchmark(factory, 200000, 32)
%timeit buffer_benchmark(factory, 500000, 32)
%timeit buffer_benchmark(factory, 1000000, 32)

6.3 s ± 18.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
12.6 s ± 58.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
31.8 s ± 118 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1min 4s ± 259 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
factory = lambda: SumTreePrioritizedBufferNumba(100000, 0.6)
%timeit buffer_benchmark(factory, 100000, 32)
%timeit buffer_benchmark(factory, 200000, 32)
%timeit buffer_benchmark(factory, 500000, 32)
%timeit buffer_benchmark(factory, 1000000, 32)

2min 31s ± 260 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
5min 4s ± 740 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
12min 42s ± 1.85 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
25min 25s ± 6.62 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [14]:
factory = lambda: NaivePrioritizedBuffer(100000, 0.6)
%timeit buffer_benchmark(factory, 100000, 32)
%timeit buffer_benchmark(factory, 200000, 32)
%timeit buffer_benchmark(factory, 500000, 32)
%timeit buffer_benchmark(factory, 1000000, 32)

1min 50s ± 1.25 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
5min 22s ± 4.58 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
16min 33s ± 10.5 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
35min 12s ± 4.82 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
