In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
#hide
from nbdev.showdoc import *

In [4]:
from anytree import Node, RenderTree


def build_segment_tree(a, start, end):
    """
    [start, end] inclusive
    """
    if start > end:
        raise ValueError(f"Invalid input {start}, {end}")
    if start == end:
        return Node("", v = a[start], start=start, end=end)

    mid = (start + end)//2
    l = build_segment_tree(a, start, mid)
    r = build_segment_tree(a, mid+1, end)
    p = Node("", children=[l, r], v = l.v + r.v, start=start, end=end)
#         l.parent = p
#         r.parent = p
    return p

def query(root, start, end):
    if start > end:
        raise ValueError(f"Invalid input {start}, {end}")

    # not in range
    if root.start > end:
        return 0

    if root.end < start:
        return 0

    # all in range    
    if root.start >= start and root.end <= end:
        return root.v

    return query(root.children[0], start, end) + \
           query(root.children[1], start, end)
    
    
def update(root, i, diff):
    if i < root.start:
        return
    if i > root.end:
        return

    root.v += diff
    if root.is_leaf:
        return

    mid = (root.start + root.end)//2
    if i <= mid:
        update(root.children[0], i, diff)
    else:
        update(root.children[1], i, diff)
    

In [5]:
a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
r = build_segment_tree(a, 0, len(a)-1)
print(RenderTree(r))

print(query(r, 1, 2))

diff = 1 - a[2]
a[2] += diff
update(r, 2, diff);
print(query(r, 1, 2))
print(RenderTree(r))

Node('/', end=11, start=0, v=78)
├── Node('//', end=5, start=0, v=21)
│   ├── Node('///', end=2, start=0, v=6)
│   │   ├── Node('////', end=1, start=0, v=3)
│   │   │   ├── Node('/////', end=0, start=0, v=1)
│   │   │   └── Node('/////', end=1, start=1, v=2)
│   │   └── Node('////', end=2, start=2, v=3)
│   └── Node('///', end=5, start=3, v=15)
│       ├── Node('////', end=4, start=3, v=9)
│       │   ├── Node('/////', end=3, start=3, v=4)
│       │   └── Node('/////', end=4, start=4, v=5)
│       └── Node('////', end=5, start=5, v=6)
└── Node('//', end=11, start=6, v=57)
    ├── Node('///', end=8, start=6, v=24)
    │   ├── Node('////', end=7, start=6, v=15)
    │   │   ├── Node('/////', end=6, start=6, v=7)
    │   │   └── Node('/////', end=7, start=7, v=8)
    │   └── Node('////', end=8, start=8, v=9)
    └── Node('///', end=11, start=9, v=33)
        ├── Node('////', end=10, start=9, v=21)
        │   ├── Node('/////', end=9, start=9, v=10)
        │   └── Node('/////', end=10, sta

In [6]:
total_sz = len(r.descendants)+1
total_sz

23

Observations:
1. we have 2*n-1 nodes, we can store tree in an array, but with large size
2. we don't need start end, which we can recompute them


In [40]:
def get_left_child(i):
    return i*2+1

def get_right_child(i):
    return i*2+2

def build_segment_tree2(st, current, a, start, end):
    """
    st: the array tree
    current: the current node in array tree
    a: the input array
    [start, end] inclusive
    """
    if start > end:
        raise ValueError(f"Invalid input {start}, {end}")
    if start == end:
        print(f"st[{current}] = {a[start]}")
        st[current] = a[start]
        return st[current]

    mid = (start + end)//2
    l = get_left_child(current)
    r = get_right_child(current)
    
    v = build_segment_tree2(st, l, a, start, mid)
    v += build_segment_tree2(st, r, a, mid+1, end)
    st[current] = v
    return v


def query2(st, current, t_start, t_end, l, r):
    """
    st: the array tree
    current: the current node in array tree
    [t_start, t_end]: the original range corresponding to the node
    [l, r]: query range
    """
    if l > r:
        raise ValueError(f"Invalid input {l}, {r}")

    # not in range
    if t_start > r:
        return 0

    if t_end < l:
        return 0

    # all in range    
    if t_start >= l and t_end <= r:
        return st[current]

    mid = (t_start + t_end)//2

    return query2(st, get_left_child(current), t_start, mid, l, r) + \
           query2(st, get_right_child(current), mid+1, t_end, l, r)

In [41]:
from math import ceil, log2
a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
n = len(a)
# Height of segment tree 
x = (int)(ceil(log2(n))); 
# Maximum size of segment tree 
max_size = 2 * (int)(2**x) - 1;
    
tree = [0]*(max_size)
root = 0
build_segment_tree2(tree, root, a, 0, len(a)-1)

print(tree)
print(query2(tree, root, 0, n-1, 1, 2))

st[15] = 1
st[16] = 2
st[8] = 3
st[19] = 4
st[20] = 5
st[10] = 6
st[23] = 7
st[24] = 8
st[12] = 9
st[27] = 10
st[28] = 11
st[14] = 12
[78, 21, 57, 6, 15, 24, 33, 3, 3, 9, 6, 15, 9, 21, 12, 1, 2, 0, 0, 4, 5, 0, 0, 7, 8, 0, 0, 10, 11, 0, 0]
5


Observations:
1. we allocate more space than 2*n-1, which we could still do better


In [47]:
# function to build the tree
def build_segment_tree3(tree, arr) :
    n = len(arr)
    # insert leaf nodes in tree
    for i in range(n) :
        tree[n + i] = arr[i];
     
    # build the tree by calculating parents
    # note 1 is the root
    for i in range(n - 1, 0, -1) :
        tree[i] = tree[i << 1] + tree[i << 1 | 1];
        
        
def update_3(tree, p, value) :
    # set value at position p
    tree[p + n] = value;
    p = p + n;
     
    # move upward and update parents
    i = p;
    while i > 1 :
        # if i has 1 in bit 0, then i^1 will the left child
        # otherwise, i^1 will be the right child
        tree[i >> 1] = tree[i] + tree[i ^ 1];
        i >>= 1;
        
        
def query_3(tree, l, r) :
    res = 0;
     
    # loop to find the sum in the range
    l += n;
    r += n;
     
    while l < r :
     
        # if l is the right child, we add the range in l
        # since we already add it, we increase l
        if (l & 1) :
            res += tree[l];
            l += 1
     
        # if r-1 is the left child, we add the range in r-1
        # since we already add it, we decrease r
        if (r & 1) :
            r -= 1;
            res += tree[r];
        # we finish handle in current level, move up to parent level     
        l >>= 1;
        r >>= 1
     
    return res;        

In [44]:
a = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
n = len(a)
tree = [0]*(2*n)
build_segment_tree3(tree, a)
tree

[0,
 78,
 68,
 10,
 26,
 42,
 3,
 7,
 11,
 15,
 19,
 23,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12]

In [45]:
update_3(tree, 2, 1)
tree

[0,
 76,
 68,
 8,
 26,
 42,
 3,
 5,
 11,
 15,
 19,
 23,
 1,
 2,
 1,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12]

In [48]:
query_3(tree, 1, 3)

3