In [14]:
import numpy as np
from collections import defaultdict

In [2]:
def make_tree(capacity):
    """
    Build an all-zero sum tree stored as a list of levels, root first.

    Internal levels have widths 1, 2, 4, ... for as long as the width is
    strictly smaller than ``capacity``; the final level holds exactly
    ``capacity`` leaves. Every entry starts at weight 0, so the invariant
    "each node is the sum of its children" holds trivially.

    Parameters
    ----------
    capacity : int
        Number of leaves in the last level.

    Returns
    -------
    list of list of int
        ``levels[0]`` is the single-node root and ``levels[-1]`` the leaves.
    """
    levels = []
    width = 1
    # Internal levels: keep doubling until one more level would reach capacity.
    while width < capacity:
        levels.append([0] * width)
        width += width
    # The leaf level is sized to the requested capacity, not the next power of 2.
    return levels + [[0] * capacity]

In [3]:
def update_tree(tree, idx, val):
    """
    Set the weight of leaf ``idx`` to ``val`` and propagate the change upward.

    The difference between the new and old leaf weight is added to the leaf
    and to each of its ancestors, so every node keeps holding the sum of its
    children. Mutates ``tree`` in place and returns nothing.

    Parameters
    ----------
    tree : list of lists
        Levels ordered root first, leaves last (the layout make_tree builds).
    idx : int
        Leaf position. Negative values count from the end of the leaf level,
        as with ordinary list indexing.
    val : number
        New weight for that leaf.

    Raises
    ------
    IndexError
        If ``idx`` does not refer to an existing leaf. The tree is left
        untouched in that case.
    """
    leaves = tree[-1]
    n_leaves = len(leaves)
    # A negative index must be made absolute before walking upward: halving
    # -1 stays at -1, which would pick the *last* node of every level rather
    # than the real ancestors whenever the leaf level is not completely full.
    if idx < 0:
        idx += n_leaves
    if not 0 <= idx < n_leaves:
        raise IndexError("leaf index out of range")

    delta = val - leaves[idx]
    # Walk from the leaf level up to the root; the parent of node i is i // 2.
    for level in reversed(tree):
        level[idx] += delta
        idx //= 2

In [10]:
def get_leaf(tree):
    """
    Randomly select a leaf index with probability proportional to its weight.

    An integer is drawn uniformly from ``[0, total)``, where ``total`` is the
    root weight. The walk then descends one level at a time: if the remaining
    value falls inside the left child's weight it goes left, otherwise it
    subtracts that weight and goes right.

    Parameters
    ----------
    tree : list of lists
        Levels ordered root first, leaves last, where each internal node holds
        the sum of its two children. The draw uses ``np.random.randint``, so
        weights are expected to be non-negative integers.

    Returns
    -------
    int
        Index of the chosen leaf within ``tree[-1]``.

    Raises
    ------
    ValueError
        If the total weight stored at the root is not positive, i.e. there is
        nothing to sample from.
    """
    total = tree[0][0]
    if total <= 0:
        # Without this guard numpy fails with an opaque "low >= high" message.
        raise ValueError(
            "cannot sample a leaf: total weight must be positive, got %r" % (total,)
        )

    val = np.random.randint(0, total)
    # idx always points at the LEFT child of the node currently being visited.
    idx = 0
    for level in tree[1:-1]:
        left_val = level[idx]
        if val >= left_val:
            # Skip the left subtree: move to the right sibling (idx + 1) and
            # then to that node's left child on the next level.
            val -= left_val
            idx = (idx + 1) * 2
        else:
            idx *= 2

    # On the leaf level only the left/right choice remains.
    left_leaf = tree[-1][idx]
    return idx + 1 if val >= left_leaf else idx

In [16]:
def simple_hist(ll):
    """
    Count how many times each element occurs in an iterable.

    Parameters
    ----------
    ll : iterable of hashable
        Elements to tally.

    Returns
    -------
    collections.defaultdict
        Maps each distinct element to its number of occurrences; looking up
        an element that never occurred yields 0.
    """
    tally = defaultdict(int)
    for item in ll:
        tally[item] = tally[item] + 1
    return tally

In [11]:
# Build a 7-leaf tree and give leaves 1, 2 and 6 the weights 7, 3 and 4.
tree = make_tree(7)
update_tree(tree, 1, 7)
update_tree(tree, 2, 3)
update_tree(tree, 6, 4)
# Show the levels root-first; the root should hold the total 7 + 3 + 4 = 14.
tree

[[14], [10, 4], [7, 3, 0, 4], [0, 7, 3, 0, 0, 0, 4]]

In [30]:
# Draw two batches of three weighted samples and keep the distinct leaves seen
# in each batch. get_leaf uses NumPy's global random state and no seed is set
# in the cells shown here, so the printed sets can differ from run to run.
s1 = {get_leaf(tree) for i in range(3)}
s2 = {get_leaf(tree) for i in range(3)}
print(s1, s2)

{1} {1, 6}


In [39]:
# Union of both batches: every leaf index that was sampled at least once.
s = s1 | s2
s

{1, 6}

In [33]:
# Number of distinct leaves sampled across the two batches.
len(s1| s2)

2

In [38]:
# In-place union into an initially empty set: s ends up with the elements of
# s1, while s1 itself is left unchanged.
s = set()
s |= s1
s

{1}