In [1]:
import random

In [2]:
class SumTree(object):
    """Fixed-capacity store of data items with a binary sum tree over their weights.

    The tree is kept in the flat list ``nodes``: node ``i`` has children
    ``2*i + 1`` and ``2*i + 2``. The first ``nonleaf_size`` entries are
    internal nodes holding the sum of their two children; the remaining
    ``capacity`` entries are leaves holding one weight each. Leaf
    ``nonleaf_size + k`` belongs to ``data[k]``.

    New items are written round-robin: once ``capacity`` items have been
    added, the next ``add`` overwrites the oldest slot.

    Weights are assumed to be non-negative numbers; nothing here enforces it.
    """

    def __init__(self, capacity):
        """Create an empty tree with room for ``capacity`` items.

        Raises:
            ValueError: if ``capacity`` is smaller than 1 (the tree would
                have no root node).
        """
        if capacity < 1:
            raise ValueError(
                "capacity must be at least 1, got {!r}".format(capacity))
        self.capacity = capacity
        self.tree_size = 2 * self.capacity - 1
        self.nonleaf_size = self.tree_size - self.capacity
        self.nodes = [0.0] * self.tree_size
        self.data = [None] * self.capacity
        # Index into ``data`` where the next added item will be stored.
        self.pointer = 0

    def total(self):
        """Return the sum of all stored weights (the value of the root node)."""
        return self.nodes[0]

    def _node_for(self, number):
        """Return the id of the leaf whose cumulative weight range holds ``number``.

        Starting at the root, descend until a node without children is
        reached. At each step go right when ``number`` exceeds the left
        child's weight (subtracting that weight so ``number`` becomes
        relative to the right subtree), otherwise go left.

        A subtree whose weight is zero is never entered while its sibling has
        weight. This keeps ``number == 0`` and ``number > total()`` from
        landing on an empty leaf: such values resolve to the first and last
        weighted leaf respectively. Only when the whole tree is empty is a
        zero-weight leaf returned.
        """
        node_id = 0

        while True:
            left_child_id = 2 * node_id + 1
            if left_child_id >= self.tree_size:
                break  # no children: node_id is a leaf

            # tree_size is odd and left_child_id is odd, so the right
            # sibling is always inside the list as well.
            right_child_id = left_child_id + 1
            left_weight = self.nodes[left_child_id]
            right_weight = self.nodes[right_child_id]

            go_right = right_weight > 0 and (
                number > left_weight or left_weight <= 0)
            if go_right:
                node_id = right_child_id
                number -= left_weight
            else:
                node_id = left_child_id

        return node_id

    def datatriple_for(self, number):
        """Look up the item whose cumulative weight range contains ``number``.

        Returns:
            A 3-tuple ``(data, weight, leaf_id)``: the stored object, the
            weight held in its leaf, and that leaf's index in ``nodes``.
        """
        weight_id = self._node_for(number)
        data_id = weight_id - self.nonleaf_size
        return self.data[data_id], self.nodes[weight_id], weight_id

    def update_tree(self, current_id, weight):
        """Store ``weight`` in leaf ``current_id`` and refresh every ancestor sum.

        Args:
            current_id: index in ``nodes`` of a leaf, i.e. a value in
                ``[nonleaf_size, tree_size)``.
            weight: the new weight for that leaf.

        Raises:
            IndexError: if ``current_id`` is not a leaf index. Writing to an
                internal node, or to a negative index, would leave the sums
                inconsistent.
        """
        if not self.nonleaf_size <= current_id < self.tree_size:
            raise IndexError(
                "current_id must be a leaf id in [{}, {}), got {!r}".format(
                    self.nonleaf_size, self.tree_size, current_id))

        self.nodes[current_id] = weight
        # Walk up to the root, rebuilding each ancestor from its two children
        # rather than adding a delta, so rounding error cannot pile up over
        # many updates and an empty subtree always sums to exactly zero.
        while current_id > 0:
            current_id = self._parent_id(current_id)
            left_child_id = 2 * current_id + 1
            self.nodes[current_id] = (
                self.nodes[left_child_id] + self.nodes[left_child_id + 1])

    def add(self, new_data, weight):
        """Store ``new_data`` with ``weight`` in the slot at ``pointer``.

        Whatever item previously occupied that slot is replaced, and the
        pointer then advances to the next slot.
        """
        self.data[self.pointer] = new_data
        weight_id = self.nonleaf_size + self.pointer
        self.update_tree(weight_id, weight)
        self._update_pointer()

    def _update_pointer(self):
        """Advance the write position, wrapping to 0 after the last slot."""
        self.pointer = (self.pointer + 1) % self.capacity

    def __str__(self):
        return "nodes: " + str(self.nodes) + "\n data: " + str(self.data) + "\n"

    @classmethod
    def _parent_id(cls, i):
        """Return the index of the parent of node ``i`` (-1 for the root)."""
        return (i - 1) // 2

In [3]:
# Walk a three-slot tree through five insertions, printing it after each one.
# The last two insertions wrap around and overwrite the oldest slots.
my_sumtree = SumTree(capacity=3)
print(my_sumtree)
# nodes: [0.0, 0.0, 0.0, 0.0, 0.0]
#  data: [None, None, None]

demo_items = [
    ('x', 8.0),  # nodes: [8.0, 0.0, 8.0, 0.0, 0.0]    data: ['x', None, None]
    ('y', 5.0),  # nodes: [13.0, 5.0, 8.0, 5.0, 0.0]   data: ['x', 'y', None]
    ('z', 7.0),  # nodes: [20.0, 12.0, 8.0, 5.0, 7.0]  data: ['x', 'y', 'z']
    ('a', 3.0),  # nodes: [15.0, 12.0, 3.0, 5.0, 7.0]  data: ['a', 'y', 'z']
    ('b', 7.0),  # nodes: [17.0, 14.0, 3.0, 7.0, 7.0]  data: ['a', 'b', 'z']
]
for demo_item, demo_weight in demo_items:
    my_sumtree.add(demo_item, demo_weight)
    print(my_sumtree)

nodes: [0.0, 0.0, 0.0, 0.0, 0.0]
 data: [None, None, None]

nodes: [8.0, 0.0, 8.0, 0.0, 0.0]
 data: ['x', None, None]

nodes: [13.0, 5.0, 8.0, 5.0, 0.0]
 data: ['x', 'y', None]

nodes: [20.0, 12.0, 8.0, 5.0, 7.0]
 data: ['x', 'y', 'z']

nodes: [15.0, 12.0, 3.0, 5.0, 7.0]
 data: ['a', 'y', 'z']

nodes: [17.0, 14.0, 3.0, 7.0, 7.0]
 data: ['a', 'b', 'z']



In [4]:
# Fill a five-slot tree; each (item, weight) pair takes the next free slot.
my_sumtree = SumTree(capacity=5)
for filled_item, filled_weight in zip('abcde', (70.0, 20.0, 40.0, 40.0, 15.0)):
    my_sumtree.add(filled_item, filled_weight)

print(my_sumtree)
# nodes: [185.0, 125.0, 60.0, 55.0, 70.0, 20.0, 40.0, 40.0, 15.0]
#  data: ['a', 'b', 'c', 'd', 'e']

# Node 8 is the leaf holding the weight of 'e'. Lowering it from 15.0 to 10.0
# reduces every node on the path to the root (3, 1, 0) by 5.0.
my_sumtree.update_tree(weight=10.0, current_id=8)
print(my_sumtree)
# nodes: [180.0, 120.0, 60.0, 50.0, 70.0, 20.0, 40.0, 40.0, 10.0]
#  data: ['a', 'b', 'c', 'd', 'e']

nodes: [185.0, 125.0, 60.0, 55.0, 70.0, 20.0, 40.0, 40.0, 15.0]
 data: ['a', 'b', 'c', 'd', 'e']

nodes: [180.0, 120.0, 60.0, 50.0, 70.0, 20.0, 40.0, 40.0, 10.0]
 data: ['a', 'b', 'c', 'd', 'e']



In [5]:
# Resolve a few cumulative-weight values to the item that owns them and print
# one row per lookup: queried number, item, item weight, leaf id.
row_template = '{:6} {} {:4} {}'
for probe in (20.0, 47.0, 119.5, 122.75, 155.0):
    found = my_sumtree.datatriple_for(probe)
    print(row_template.format(probe, found[0], found[1], found[2]))

# Expected rows:
#   20.0 d 40.0 7
#   47.0 e 10.0 8
#  119.5 a 70.0 4
# 122.75 b 20.0 5
#  155.0 c 40.0 6

nodes: [180.0, 120.0, 60.0, 50.0, 70.0, 20.0, 40.0, 40.0, 10.0]
 data: ['a', 'b', 'c', 'd', 'e']

  20.0 d 40.0 7
  47.0 e 10.0 8
 119.5 a 70.0 4
122.75 b 20.0 5
 155.0 c 40.0 6

nodes: [180.0, 120.0, 60.0, 50.0, 70.0, 20.0, 40.0, 40.0, 10.0]
 data: ['a', 'b', 'c', 'd', 'e']



In [36]:
def stratified_sampling(number_of_segments, sumtree, number_of_samples):
    """Draw samples from ``sumtree``, cycling through equal-width weight segments.

    The range ``[0, sumtree.total()]`` is cut into ``number_of_segments``
    equal segments. Sample ``i`` is drawn uniformly from segment
    ``i % number_of_segments``, so consecutive samples walk through the
    segments in order and start over after the last one.

    Args:
        number_of_segments: how many equal-width segments to split the total
            weight into; must be at least 1.
        sumtree: any object providing ``total()`` and
            ``datatriple_for(number)``.
        number_of_samples: how many samples to draw. Zero or a negative
            value yields an empty list.

    Returns:
        A list with one entry per sample, each being whatever
        ``sumtree.datatriple_for`` returned for the drawn number.

    Raises:
        ValueError: if ``number_of_segments`` is smaller than 1.
    """
    if number_of_segments < 1:
        raise ValueError(
            "number_of_segments must be at least 1, got {!r}".format(
                number_of_segments))

    samples = []
    segment_size = sumtree.total() / number_of_segments

    drawn = 0
    while drawn < number_of_samples:
        segment = drawn % number_of_segments
        # Draw between the segment's own bounds. No fixed offset is taken off
        # the upper bound: with segments narrower than such an offset the
        # upper bound would drop below the lower one and the draw could fall
        # outside the segment (below zero for the first segment).
        start = segment_size * segment
        end = segment_size * (segment + 1)
        r = random.uniform(start, end)
        samples.append(sumtree.datatriple_for(r))
        drawn += 1

    return samples

In [38]:
# Build a three-item tree with total weight 20.0 and draw nine samples over
# five segments of width 4.0 each.
my_sumtree = SumTree(capacity=3)
for sampled_item, sampled_weight in (('x', 8.0), ('y', 5.0), ('z', 7.0)):
    my_sumtree.add(sampled_item, sampled_weight)
print(my_sumtree)

samples = stratified_sampling(5, my_sumtree, 9)

for sample in samples:
    print(sample[0])

# Draws are random, so only the segment of each draw is fixed. Expected items,
# in order (the second segment straddles the boundary between y and z):
# y
# z or y
# z
# x
# x
# y
# z or y
# z
# x

nodes: [20.0, 12.0, 8.0, 5.0, 7.0]
 data: ['x', 'y', 'z']

y
z
z
x
x
y
y
z
x
