## Segment Tree

Consider the following problem:

We have an array `arr[0 .. n-1]` and want to support two operations:

1. Find the sum of the elements from index `l` to `r` (inclusive), where `0 <= l <= r <= n-1`.
2. Change the value of a specified element to a new value `x`, i.e. set `arr[i] = x` where `0 <= i <= n-1`.

### Solution

There are two simple approaches:
1. Compute the range sum with a loop. A range sum takes O(n) and an update takes O(1).
2. Build a prefix-sum array whose entry `i` stores `arr[0] + ... + arr[i]`. A range sum becomes O(1) (`prefix[r] - prefix[l-1]`), but an update becomes O(n), because every prefix from index `i` onward must change.
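For comparison, here is a minimal sketch of approach 2 (the class name `PrefixSumArray` is hypothetical, chosen for illustration). The query is a single subtraction, while an update must touch every prefix from `i` onward:

```python
from itertools import accumulate


class PrefixSumArray:
    """Approach 2: prefix[i] holds arr[0] + ... + arr[i]."""

    def __init__(self, arr):
        self.arr = list(arr)
        self.prefix = list(accumulate(self.arr))

    def get_sum(self, l, r):
        # O(1): subtract the prefix that ends just before l.
        return self.prefix[r] - (self.prefix[l - 1] if l > 0 else 0)

    def update(self, i, x):
        # O(n): every prefix from index i onward contains arr[i].
        diff = x - self.arr[i]
        self.arr[i] = x
        for j in range(i, len(self.prefix)):
            self.prefix[j] += diff


ps = PrefixSumArray([1, 3, 5, 7, 9, 11])
print(ps.get_sum(2, 4))  # 5 + 7 + 9 = 21
ps.update(2, 7)
print(ps.get_sum(2, 4))  # 7 + 7 + 9 = 23
```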

Which approach is better depends on how often each operation is requested. But what if range sums and updates are equally frequent?

A segment tree performs both operations in O(log n).

### Representation of a segment tree

Basic features:
1. Leaf nodes are the elements of the input array.
2. Each internal node stores the merge of the leaves under it. The merge operation depends on the problem; for this problem, it is the sum of those leaves.
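Because the merge operation is the only problem-specific part, a build routine can take it as a parameter. This is an illustrative sketch, not the implementation used later in this notebook; the names `build` and `merge` are hypothetical, and `None` stands in for the DUMMY slots:

```python
def build(arr, merge):
    """Return the segment tree array for arr, combining children with `merge`.

    Unused slots (the DUMMY nodes) are left as None.
    """
    n = len(arr)
    size = 2 * (1 << (n - 1).bit_length()) - 1
    tree = [None] * size

    def construct(start, end, cur):
        if start == end:  # leaf: copy the array element
            tree[cur] = arr[start]
        else:  # internal node: merge the results of the two halves
            mid = (start + end) // 2
            tree[cur] = merge(construct(start, mid, 2 * cur + 1),
                              construct(mid + 1, end, 2 * cur + 2))
        return tree[cur]

    construct(0, n - 1, 0)
    return tree


arr = [1, 3, 5, 7, 9, 11]
print(build(arr, lambda a, b: a + b)[:3])  # [36, 9, 27]
print(build(arr, min)[:3])                 # [1, 1, 7]
```

Passing `min` instead of addition turns the same structure into a range-minimum tree; only the merge changes.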

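A segment tree can be stored in an array. It is a full binary tree (every node has either 0 or 2 children): the root is at index 0, and the children of the node at index `i` are at indices `2i+1` and `2i+2`. Slots that do not correspond to a real node (for example, below a leaf that sits one level higher than the others) are filled with dummy nodes.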
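The index arithmetic can be checked against the example array shown below. In this sketch the helper names `left`, `right` and `parent` are hypothetical, and `None` marks a DUMMY slot:

```python
def left(i):
    return 2 * i + 1  # left child of node i


def right(i):
    return 2 * i + 2  # right child of node i


def parent(i):
    return (i - 1) // 2  # parent of node i (valid for i > 0)


st = [36, 9, 27, 4, 5, 16, 11, 1, 3, None, None, 7, 9, None, None]

# Every internal node equals the sum of its two children.
for i, v in enumerate(st):
    if v is not None and right(i) < len(st) and st[left(i)] is not None:
        assert v == st[left(i)] + st[right(i)]

# Walk from the leaf holding arr[2] = 5 (slot 4) up to the root.
path = [4]
while path[-1] > 0:
    path.append(parent(path[-1]))
print(path)  # [4, 1, 0]
```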

![seg_tree](segment_tree.png)

The array representation of the above tree, for the input array `[1, 3, 5, 7, 9, 11]`, is:

`st = [36, 9, 27, 4, 5, 16, 11, 1, 3, DUMMY, DUMMY, 7, 9, DUMMY, DUMMY]`


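Since it is a full binary tree with n leaves, it has n-1 internal nodes, so it contains `2n-1` real nodes [1]. The array, however, also holds the dummy slots: it needs room for a complete binary tree of height `ceil(log2 n)`, i.e. `2 * 2^ceil(log2 n) - 1` slots (15 for n = 6), which is always less than `4n`.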
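The gap between the `2n-1` real nodes and the array length can be computed directly. A minimal sketch, with `tree_array_size` as a hypothetical helper name; `(n - 1).bit_length()` is an integer way to get `ceil(log2 n)` for `n >= 1`:

```python
def tree_array_size(n):
    """Slots needed for the array layout: a complete tree of height ceil(log2 n)."""
    height = (n - 1).bit_length()  # equals ceil(log2 n) for n >= 1
    return 2 * (1 << height) - 1


for n in [1, 6, 8, 9, 1000]:
    # n, real nodes, array slots
    print(n, 2 * n - 1, tree_array_size(n))
```

Because the result is always below `4n`, many implementations simply allocate `4n` slots instead of computing the exact size.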

### Query for range sum
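The query starts at the root (node 0, covering `arr[0 .. n-1]`), i.e. `getSum(tree, 0, 0, n - 1, l, r)`:

```python
def getSum(tree, node, node_s, node_e, l, r):
    """Return arr[l] + ... + arr[r] using the subtree rooted at `node`,
    which covers arr[node_s .. node_e]."""
    # The node's range lies entirely within [l, r]: its stored value is the answer.
    if l <= node_s and node_e <= r:
        return tree[node]
    # The node's range lies completely outside [l, r]: it contributes nothing.
    if node_e < l or node_s > r:
        return 0
    # Partial overlap: split at the midpoint and add both children's results.
    mid = (node_s + node_e) // 2
    return (getSum(tree, 2 * node + 1, node_s, mid, l, r) +
            getSum(tree, 2 * node + 2, mid + 1, node_e, l, r))
```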
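The O(log n) bound for the query can be checked empirically. This sketch instruments the same recursion to record every node it touches; the names `build` and `get_sum_counted` are hypothetical. At most two nodes per level partially overlap `[l, r]`, so at most four nodes per level are visited:

```python
def build(arr):
    """Sum segment tree in the array layout; unused slots stay 0."""
    n = len(arr)
    tree = [0] * (2 * (1 << (n - 1).bit_length()) - 1)

    def construct(s, e, cur):
        if s == e:  # leaf
            tree[cur] = arr[s]
        else:  # internal node: sum of both halves
            mid = (s + e) // 2
            tree[cur] = construct(s, mid, 2 * cur + 1) + construct(mid + 1, e, 2 * cur + 2)
        return tree[cur]

    construct(0, n - 1, 0)
    return tree


def get_sum_counted(tree, node, s, e, l, r, visited):
    visited.append(node)  # record every node the query touches
    if l <= s and e <= r:  # fully inside the query
        return tree[node]
    if e < l or s > r:  # fully outside the query
        return 0
    mid = (s + e) // 2  # partial overlap: recurse into both children
    return (get_sum_counted(tree, 2 * node + 1, s, mid, l, r, visited) +
            get_sum_counted(tree, 2 * node + 2, mid + 1, e, l, r, visited))


n = 1024
arr = list(range(n))
tree = build(arr)
visited = []
total = get_sum_counted(tree, 0, 0, n - 1, 3, 1000, visited)
print(total == sum(arr[3:1001]), len(visited))
```

The number of visited nodes stays proportional to the tree height, not to the length of the queried range.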

### Updating a value
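We can update the value recursively. Let `diff = x - arr[i]`. Starting from the root of the segment tree, add `diff` to every node whose range contains index `i` and recurse into its children; nodes whose range excludes `i` are left unchanged. Exactly one node per level contains `i`, so an update takes O(log n).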
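A standalone sketch of this recursive update (the function name `update` is hypothetical), applied to the example tree with dummy slots stored as 0. Setting `arr[2] = 7` changes only the root-to-leaf path of slots 0, 1 and 4:

```python
def update(tree, node, s, e, index, diff):
    """Add diff to every node whose range [s, e] contains index."""
    if not (s <= index <= e):
        return  # index outside this node's range: nothing to do
    tree[node] += diff
    if s < e:  # internal node: push the change into both halves
        mid = (s + e) // 2
        update(tree, 2 * node + 1, s, mid, index, diff)
        update(tree, 2 * node + 2, mid + 1, e, index, diff)


arr = [1, 3, 5, 7, 9, 11]
st = [36, 9, 27, 4, 5, 16, 11, 1, 3, 0, 0, 7, 9, 0, 0]  # tree for arr, dummies as 0
diff = 7 - arr[2]  # set arr[2] = 7
arr[2] = 7
update(st, 0, 0, len(arr) - 1, 2, diff)
print(st)  # [38, 11, 27, 4, 7, 16, 11, 1, 3, 0, 0, 7, 9, 0, 0]
```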

### Python implementation

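```python
class SegmentationTree():
    """Sum segment tree stored in an array: node i has children 2i+1 and 2i+2."""

    def __init__(self, arr):
        if not arr:
            raise ValueError("arr must be non-empty.")
        self.arr = arr
        # Tree height ceil(log2 n), computed with integers to avoid
        # floating-point rounding in math.log.
        self.leaf = (len(arr) - 1).bit_length()
        self.num_nodes = 2 * (2**self.leaf) - 1
        self.segTree = [0] * self.num_nodes  # DUMMY slots stay 0

        self._construct(0, len(arr) - 1, 0)

    def _getMid(self, s, e):
        return s + (e - s) // 2

    def _construct(self, start, end, cur):
        """Fill segTree[cur] for the segment arr[start .. end] and return its sum."""
        if start == end:  # leaf: the array element itself
            self.segTree[cur] = self.arr[start]
            return self.segTree[cur]

        mid = self._getMid(start, end)
        self.segTree[cur] = (self._construct(start, mid, 2*cur+1) +
                             self._construct(mid+1, end, 2*cur+2))
        return self.segTree[cur]

    def getSum(self, query_s, query_e):
        """Return arr[query_s] + ... + arr[query_e] in O(log n)."""
        if not 0 <= query_s <= query_e < len(self.arr):
            raise IndexError("Invalid query range.")
        return self._getSum(0, len(self.arr)-1, query_s, query_e, 0)

    def update(self, index, new_v):
        """Set arr[index] = new_v and add the difference along the root-to-leaf path."""
        if not 0 <= index < len(self.arr):
            raise IndexError("Invalid index.")
        diff = new_v - self.arr[index]
        self.arr[index] = new_v
        self._update(0, len(self.arr)-1, index, diff, 0)

    def _getSum(self, segment_s, segment_e, query_s, query_e, cur):
        # Segment fully inside the query: use the stored sum.
        if query_s <= segment_s and query_e >= segment_e:
            return self.segTree[cur]
        # Segment fully outside the query: contributes nothing.
        if segment_e < query_s or segment_s > query_e:
            return 0

        # Partial overlap: combine both children.
        mid = self._getMid(segment_s, segment_e)
        return self._getSum(segment_s, mid, query_s, query_e, 2*cur+1) + \
            self._getSum(mid+1, segment_e, query_s, query_e, 2*cur+2)

    def _update(self, segment_s, segment_e, index, diff, cur):
        # Only nodes whose segment contains index change.
        if segment_s <= index <= segment_e:
            self.segTree[cur] += diff
            if segment_s < segment_e:  # internal node: recurse into both halves
                mid = self._getMid(segment_s, segment_e)
                self._update(segment_s, mid, index, diff, 2*cur+1)
                self._update(mid+1, segment_e, index, diff, 2*cur+2)
```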


```python
arr = [1, 3, 5, 7, 9, 11]
st = SegmentationTree(arr)
print(st.arr)
print(st.segTree)
print(st.getSum(2, 4))
st.update(2, 7)
print(st.arr)
print(st.segTree)
print(st.getSum(2, 4))
```

Output:

```
[1, 3, 5, 7, 9, 11]
[36, 9, 27, 4, 5, 16, 11, 1, 3, 0, 0, 7, 9, 0, 0]
21
[1, 3, 7, 7, 9, 11]
[38, 11, 27, 4, 7, 16, 11, 1, 3, 0, 0, 7, 9, 0, 0]
23
```


### Reference
1. [Segment tree: sum of given range](https://www.geeksforgeeks.org/segment-tree-set-1-sum-of-given-range/) (GeeksforGeeks)
2. [Types of binary tree: full binary tree](https://www.geeksforgeeks.org/binary-tree-set-3-types-of-binary-tree/) (GeeksforGeeks)