In [24]:
import numpy as np

class SumTree:
    def __init__(self, capacity):
        """
            Efficient Implementation, first capacity - 1 nodes are reserved as inner nodes.
            The last capacity nodes are used to store the data values.
            Creates a multilevel binary tree efficiently (no python classes which are slow).
        """
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
        self.size = 0
        self.write = 0

    def add(self, priority, data):
        idx = self.write + self.capacity - 1
        self.data[self.write] = data
        self.update(idx, priority)
        
        self.write += 1
        if self.write >= self.capacity:
            self.write = 0
        self.size = min(self.size + 1, self.capacity)

    def update(self, idx, priority):
        change = priority - self.tree[idx]
        self.tree[idx] = priority
        self._propagate(idx, change)

    def _propagate(self, idx, change):
        parent = (idx - 1) // 2
        self.tree[parent] += change
        if parent != 0:
            self._propagate(parent, change)

    def get(self, s):
        idx = self._retrieve(0, s)
        dataIdx = idx - self.capacity + 1
        return (idx, self.tree[idx], self.data[dataIdx])

    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1
        if left >= len(self.tree):
            return idx
        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    @property
    def total(self):
        return self.tree[0]


In [43]:
test_tree = SumTree(capacity=8)
test_tree.add(1,('test1'))
test_tree.add(2,('test2'))
test_tree.add(3,('test3'))
test_tree.add(4,('test4'))
test_tree.add(5,('test4'))



index 7 propagting to parent 3
index 3 propagting to parent 1
index 1 propagting to parent 0
index 8 propagting to parent 3
index 3 propagting to parent 1
index 1 propagting to parent 0
index 9 propagting to parent 4
index 4 propagting to parent 1
index 1 propagting to parent 0
index 10 propagting to parent 4
index 4 propagting to parent 1
index 1 propagting to parent 0
index 11 propagting to parent 5
index 5 propagting to parent 2
index 2 propagting to parent 0


In [44]:
test_tree.tree

array([15., 10.,  5.,  3.,  7.,  5.,  0.,  1.,  2.,  3.,  4.,  5.,  0.,
        0.,  0.])

In [38]:
test_tree.data

array(['test1', 'test2', 'test3', 'test4', 0, 0, 0], dtype=object)