In [None]:
%load_ext cython

In [None]:
%%cython -c=-O3
# cython: boundscheck=False
import numpy as np
cimport numpy as np

from cython.parallel cimport prange
from libc.math cimport pow, exp, log

cdef int _st_get_one_tree_idx(float[:] tree, int capacity, float s) nogil:
    # Walk from the root down to a leaf, at each level choosing the child whose
    # cumulative-priority range contains `s`.  Returns the *tree* index of the
    # chosen leaf; callers convert it to a data index by subtracting
    # (capacity - 1).
    #
    # Layout: indices 0 .. capacity-2 are internal nodes (children 2k+1, 2k+2),
    # indices capacity-1 .. 2*capacity-2 are leaves.  For a capacity that is not
    # a power of two the leaves sit at two different depths, and the slots that
    # are filled first can all live in the root's right subtree, leaving the
    # left subtree empty (mass 0).
    cdef int parent = 0
    cdef int base = capacity - 1
    cdef float lval
    while parent < base:
        parent = parent * 2 + 1
        lval = tree[parent]
        # Go right when `s` lies beyond the left subtree's mass OR when the left
        # subtree is empty (lval <= 0, reachable with s == 0 exactly) -- but only
        # if the right subtree actually holds mass.  This guarantees we never
        # stop on a zero-priority (unfilled) leaf while a non-empty one exists,
        # regardless of float round-off in `s`.
        if (s > lval or lval <= 0) and tree[parent + 1] > 0:
            s -= lval
            parent += 1
    return parent

cdef void _st_to_prios(float[:] errors, float upper, float alpha) nogil:
    # Convert raw errors to sampling priorities IN PLACE:
    #     errors[i] = clip(|errors[i]|, 1e-5, upper) ** alpha
    # The magnitude is used so signed errors (e.g. TD errors) of either sign get
    # the same priority instead of negative ones collapsing to the floor.
    # NaN is mapped to the floor so one bad value cannot turn every ancestor
    # sum in the tree (and hence all later sampling) into NaN.
    cdef Py_ssize_t n = errors.shape[0]
    cdef Py_ssize_t i
    cdef float e
    for i in range(n):
        e = errors[i]
        if e < 0:
            e = -e
        # `not (e >= floor)` is also true for NaN, unlike `e < floor`.
        if not (e >= 1e-5):
            e = 1e-5
        if e > upper:
            e = upper
        errors[i] = pow(e, alpha)

cdef void _st_update_one(float[:] tree, int capacity, int data_idx, float value) nogil:
    # Set the priority of data slot `data_idx` to `value` and refresh every
    # ancestor on the path to the root (tree[0] holds the total mass).
    #
    # Ancestors are recomputed as the sum of their two children rather than
    # adjusted by a delta: with float32 storage, repeatedly applying
    # `+= value - old` over millions of updates lets internal sums (and
    # tree[0]) drift away from the true sum of the leaves.  Recomputing keeps
    # every node equal to the (rounded) sum of its children after each call,
    # so errors no longer accumulate across updates.
    #
    # No range check is done here (the module is compiled with
    # boundscheck=False); callers must ensure 0 <= data_idx < capacity.
    cdef int i = data_idx + capacity - 1
    tree[i] = value
    while i > 0:
        i = (i - 1) // 2
        tree[i] = tree[2 * i + 1] + tree[2 * i + 2]

cdef void _st_update_many(float[:] tree, int capacity, int[:] data_indices, float[:] values) nogil:
    # Pairwise batch update: slot data_indices[k] receives priority values[k],
    # applied in order via _st_update_one.  `values` must be at least as long
    # as `data_indices`; that is not checked here (boundscheck=False).
    cdef Py_ssize_t k
    for k in range(data_indices.shape[0]):
        _st_update_one(tree, capacity, data_indices[k], values[k])

cdef void _st_sample(float[:] tree, int capacity, int n, float[:] in_s, int[:] out_indices, float[:] out_weights) except *:
    # For each target cumulative value in_s[i], locate the leaf it falls into
    # and write that leaf's data index to out_indices[i] and its priority to
    # out_weights[i].
    #
    # Shape checks are explicit `raise`s, not `assert`s: with boundscheck=False
    # they are the only guard against out-of-bounds memory access, and Cython
    # assertions can be disabled (CYTHON_WITHOUT_ASSERTIONS / python -O).
    # `except *` makes the error propagate to the Python caller; without it a
    # `cdef void` function under Cython 0.x only prints "Exception ignored" and
    # returns, leaving the output buffers uninitialised.
    #
    # NOTE(review): prange only runs in parallel when compiled and linked with
    # OpenMP (e.g. -fopenmp); the `%%cython -c=-O3` magic does not pass that,
    # so this loop presumably runs serially.
    cdef int i
    cdef int t_i
    cdef int base = capacity - 1
    if tree.shape[0] != capacity * 2 - 1:
        raise ValueError("capacity does not match tree shape")
    if n != in_s.shape[0]:
        raise ValueError("in_s shape don't match n")
    if n != out_indices.shape[0]:
        raise ValueError("out_indices shape don't match n")
    if n != out_weights.shape[0]:
        raise ValueError("out_weights shape don't match n")
    for i in prange(n, nogil=True):
        t_i = _st_get_one_tree_idx(tree, capacity, in_s[i])
        out_indices[i] = t_i - base
        out_weights[i] = tree[t_i]

class PrioritizedBuffer(object):
    """Proportional prioritized replay buffer backed by a float32 sum tree.

    Items are stored in a ring of `capacity` slots.  The priority of slot k is
    the leaf tree[k + capacity - 1]; internal nodes hold the sum of their
    subtree, so tree[0] is the total priority mass used for sampling.

    Args:
        capacity: number of slots (must be >= 1).
        alpha: exponent passed to `_st_to_prios` when `update` converts errors
            to priorities.
        beta: initial importance-sampling exponent; increased by `beta_step`
            (capped at 1) on every `sample` call.
        beta_step: per-`sample` increment of `beta`.
    """

    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_step=1e-4):
        if capacity < 1:
            raise ValueError("capacity must be >= 1")
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_step = beta_step

        # capacity - 1 internal nodes followed by capacity leaves.
        self.tree = np.zeros((capacity * 2 - 1,), dtype=np.float32)
        self.data = np.empty(capacity, dtype=object)
        self.pos = 0    # next slot to write (ring-buffer cursor)
        self.count = 0  # number of filled slots; saturates at capacity

    def __len__(self):
        """Number of stored items (at most `capacity`)."""
        return self.count

    def add(self, obj):
        """Store `obj` in the next ring slot with priority 1.

        Once the buffer is full the oldest item is overwritten.
        """
        self.data[self.pos] = obj
        # `update` clips errors to upper=1 before raising them to alpha, so for
        # alpha >= 0 no stored priority exceeds 1: new items start at the max.
        _st_update_one(self.tree, self.capacity, self.pos, 1)
        self.pos = (self.pos + 1) % self.capacity
        if self.count < self.capacity:
            self.count += 1

    def update(self, indices, weights):
        """Set new priorities for already-stored items.

        Args:
            indices: slot indices (e.g. as returned by `sample`); any integer
                array-like.
            weights: per-index error values, converted to priorities by
                `_st_to_prios` (clamped to [1e-5, 1], then ** alpha).

        The caller's `weights` array is left untouched: `_st_to_prios` works in
        place, so a private float32 copy is converted instead.

        Raises:
            ValueError: if `indices` and `weights` differ in shape.
            IndexError: if any index is outside [0, len(self)).  The sum-tree
                helpers run with bounds checking disabled, so an out-of-range
                index would otherwise corrupt memory.
        """
        idx = np.asarray(indices, dtype=np.int64)
        prios = np.array(weights, dtype=np.float32)  # np.array copies by default
        if idx.shape != prios.shape:
            raise ValueError(
                "indices and weights must have the same shape, got %s and %s"
                % (idx.shape, prios.shape))
        if idx.size and (idx.min() < 0 or idx.max() >= self.count):
            raise IndexError("indices must lie in [0, %d)" % self.count)
        # Range-checked above, so narrowing to int32 cannot wrap.
        idx = idx.astype(np.int32)
        _st_to_prios(prios, 1, self.alpha)
        _st_update_many(self.tree, self.capacity, idx, prios)

    def sample(self, n):
        """Draw `n` items by stratified proportional sampling.

        [0, total) is split into `n` equal segments and one target is drawn
        uniformly from each, using NumPy's global RNG.

        Returns:
            (indices, weights, data): int32 slot indices, float32
            importance-sampling weights (N * P(i)) ** -beta normalised so the
            batch maximum is 1, and the stored objects as an object array.

        Raises:
            ValueError: if the buffer is empty or `n` < 1.
        """
        if self.count == 0:
            # tree[0] == 0 here: weights would be NaN and data all None.
            raise ValueError("cannot sample from an empty buffer")
        if n < 1:
            raise ValueError("n must be a positive integer")
        weights = np.empty(n, dtype=np.float32)
        indices = np.empty(n, dtype=np.int32)
        total = self.tree[0]
        segment = total / n
        a = np.arange(n, dtype=np.float32) * segment
        s = np.random.uniform(a, a + segment).astype(np.float32)
        _st_sample(self.tree, self.capacity, n, s, indices, weights)
        self.beta = min(1, self.beta + self.beta_step)
        # P(i) = p_i / total; IS weight = (N * P(i)) ** -beta.
        weights = np.power(weights / total * self.count, -self.beta)
        weights /= weights.max()
        return indices, weights.astype(np.float32), self.data[indices]

In [None]:
%%time
# Stress test: push N_STEPS items through a CAPACITY-slot buffer (wrapping the
# ring several times), sampling and re-prioritising a batch on every step, and
# check that sampling never returns an unfilled slot.  Long-running by design.
import numpy as np              # explicit: don't rely on `np` leaking from the %%cython cell
from tqdm.auto import tqdm      # tqdm_notebook is deprecated

CAPACITY = 1_000_000
N_STEPS = 10_000_000
BATCH = 32
SEED = 0

# PrioritizedBuffer.sample draws from NumPy's global RNG; seed it so the run
# (and any assertion failure) is reproducible.
np.random.seed(SEED)

p = PrioritizedBuffer(CAPACITY)
tq = tqdm(range(N_STEPS))
for k in tq:
    p.add(f'_{k}')
    if len(p) >= BATCH:
        i, w, d = p.sample(BATCH)
        for j in range(BATCH):
            assert d[j] is not None, f'k={k}, i[{j}]={i[j]}, w[{j}]={w[j]}, d[{j}]={d[j]}, tree[0]={p.tree[0]}'
        # Perturb the IS weights and feed them back as new "errors".
        w += (np.random.random_sample(BATCH) - 0.5) * 0.5
        p.update(i, w)
print(i)
print(w)
print(d)
print(p.tree[:15])