In [1]:
import numpy

class SumTree:
    """Binary sum tree over a fixed-size ring buffer of (priority, data) entries.

    The tree is stored implicitly in ``self.tree`` (length ``2*capacity - 1``):
    node ``i`` has children ``2*i + 1`` and ``2*i + 2``, the last ``capacity``
    slots are the leaves (one per data slot), and every internal node holds the
    sum of its two children, so ``self.tree[0]`` is the sum of all leaf values.
    ``self.data`` holds the payload for each leaf; ``add`` fills slots in
    round-robin order and overwrites the oldest entry once the buffer is full.
    """

    # Class-level default kept for backward compatibility; every instance gets
    # its own counter in __init__.
    write = 0

    def __init__(self, capacity):
        """Create an empty tree with ``capacity`` leaf/data slots.

        Raises:
            ValueError: if ``capacity`` is smaller than 1.
        """
        if capacity < 1:
            raise ValueError(f"capacity must be >= 1, got {capacity!r}")
        self.capacity = capacity
        self.tree = numpy.zeros(2 * capacity - 1)
        self.data = numpy.zeros(capacity, dtype=object)
        # Next data slot that add() will write to.
        self.write = 0
        # Leaf sum captured by normalise(); None until normalise() is called.
        # Kept under its own name so the total() method is never shadowed.
        self.norm_total = None

    def _propagate(self, idx, change):
        """Add ``change`` to every ancestor of tree node ``idx`` (not to ``idx`` itself)."""
        # Iterative walk up to the root. Stopping once idx reaches 0 also covers
        # capacity == 1, where the single node is the root and has no parent.
        while idx > 0:
            idx = (idx - 1) // 2
            self.tree[idx] += change

    def _retrieve(self, idx, s):
        """Descend from node ``idx`` to the leaf whose cumulative range contains ``s``.

        At each node go left if ``s`` fits within the left child's sum, otherwise
        subtract that sum and go right. Returns the tree index of the leaf reached.
        """
        n_nodes = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= n_nodes:
                return idx
            if s <= self.tree[left]:
                idx = left
            else:
                s -= self.tree[left]
                idx = left + 1

    def _rebuild(self):
        """Recompute every internal node as the sum of its children; leaves are untouched."""
        # Children always have larger indices than their parent, so a reverse
        # sweep sees up-to-date child sums.
        for i in range(self.capacity - 2, -1, -1):
            self.tree[i] = self.tree[2 * i + 1] + self.tree[2 * i + 2]

    def total(self):
        """Return the sum of all leaf values (the root node)."""
        return self.tree[0]

    def p_array(self):
        """Return the leaf values (a view into ``self.tree``) in data-slot order."""
        return self.tree[-self.capacity:]

    def add(self, p, data):
        """Store ``data`` with value ``p`` in the next slot, overwriting the oldest entry when full."""
        idx = self.write + self.capacity - 1
        self.data[self.write] = data
        self.update(idx, p)
        # Advance the ring-buffer cursor, wrapping back to slot 0 at capacity.
        self.write = (self.write + 1) % self.capacity

    def update(self, idx, p):
        """Set tree node ``idx`` (a leaf index) to ``p`` and push the difference to its ancestors."""
        change = p - self.tree[idx]
        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        """Find the leaf whose cumulative range contains ``s``.

        Returns:
            ``(tree_idx, value, data)`` where ``tree_idx`` is a tree index usable
            with update(); the matching data slot is ``tree_idx - capacity + 1``.
        """
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1
        return (idx, self.tree[idx], self.data[data_idx])

    def normalise(self):
        """Scale leaf values in place so they sum to 1.

        The pre-scaling leaf sum is remembered in ``self.norm_total`` so that
        unnormalise() can undo the scaling. If every leaf is zero the values are
        left unchanged instead of being turned into NaN.
        """
        leaves = self.tree[self.capacity - 1:]  # basic slice -> view into self.tree
        self.norm_total = numpy.sum(leaves)
        if self.norm_total != 0:
            leaves /= self.norm_total
        self._rebuild()

    def unnormalise(self):
        """Multiply leaf values by the total stored by the last normalise() call.

        Note that each call rescales again; calling it twice multiplies twice.

        Raises:
            RuntimeError: if normalise() has never been called on this tree.
        """
        if self.norm_total is None:
            raise RuntimeError("unnormalise() called before normalise()")
        self.tree[self.capacity - 1:] *= self.norm_total
        self._rebuild()

In [2]:

# Empty sum tree with 10 leaf/data slots.
sumtree = SumTree(capacity=10)


In [3]:
# Build the tree inside this cell so re-running it does not keep appending to
# (and eventually wrapping around) a tree left over from an earlier cell.
sumtree = SumTree(10)

sumtree.add(0.1, numpy.array([1, 2, 3]))
sumtree.add(2, [0.2, 1])
sumtree.add(3, [0.3, 2])
sumtree.add(1, [0.4, 3])
print(sumtree.data)
print(sumtree.tree)
total_before = sumtree.total()
print(total_before)
# Invariant: the root holds the sum of all leaf values.
assert numpy.isclose(sumtree.tree[0], sumtree.tree[-sumtree.capacity:].sum())

# Scale leaf values so they sum to 1, then undo the scaling.
sumtree.normalise()
print(sumtree.tree)
assert numpy.isclose(sumtree.tree[0], 1.0)
sumtree.unnormalise()
assert numpy.isclose(sumtree.tree[0], total_before)
print(sumtree.get(2.6))

[array([1, 2, 3]) list([0.2, 1]) list([0.3, 2]) list([0.4, 3]) 0 0 0 0 0 0]
[6.1 2.1 4.  0.  2.1 4.  0.  0.  0.  0.1 2.  3.  1.  0.  0.  0.  0.  0.
 0. ]
6.1
[1.         0.3442623  0.6557377  0.         0.3442623  0.6557377
 0.         0.         0.         0.01639344 0.32786885 0.49180328
 0.16393443 0.         0.         0.         0.         0.
 0.        ]
(11, 3.0, [0.3, 2])


In [27]:
# Scratch check of Python floor division (`//`), the operator SumTree uses for
# parent indices. NOTE(review): looks like a leftover experiment (out-of-order
# execution count) — consider deleting this cell.
print(divmod(11, 6)[0])

1
