# Segment Trees

In this notebook, I will be exploring and implementing segment trees. A segment tree is a data structure that is common in competitive programming and, more generally, useful for storing an array and answering queries about it. It stores information over intervals of the array, for instance the sum of the values within a segment.

## Initial Implementation

Let's create an implementation that stores the sums of subarrays. We will use pointers from each node to its children. This is not as efficient as storing the nodes in an array; we will update the code to work like that later.

In [36]:
from __future__ import annotations
from dataclasses import dataclass


class PtrNode:
    """One node of a pointer-based segment tree."""

    def __init__(self, bounds: tuple[int, int]):
        # Half-open interval ``[start, end)`` of array indices covered by this node.
        self.bounds = bounds
        # Sum of the array values inside ``bounds``; filled in by the owning tree.
        self.sum_: int = 0
        # Children covering the lower and upper halves of ``bounds``.
        # Both stay ``None`` on leaves.
        self.left: PtrNode | None = None
        self.right: PtrNode | None = None


class PtrSegmentTree:
    """Sum segment tree whose nodes are linked through left/right pointers.

    Every node covers a half-open interval ``[start, end)`` of the input
    array and stores the sum of the values in it. The root covers
    ``(0, len(a))``; an internal node is split at ``(start + end) // 2``.
    An empty input yields a single childless root with sum 0.
    """

    def __init__(self, a: list[int]) -> None:
        """Build the tree over a snapshot of the values in ``a``."""
        root = PtrNode(bounds=(0, len(a)))
        self.root = root
        self._build(self.root, a)

    def _build(self, node: PtrNode, a: list[int]):
        """Recursively fill in ``node`` and create the subtree beneath it."""
        i, j = node.bounds
        # Compare the bounds directly instead of slicing ``a``: slicing copies
        # the covered range at every node. Stopping on an empty interval as
        # well keeps an empty input from recursing forever.
        if j - i <= 1:
            node.sum_ = a[i] if i < j else 0
            return
        mid = (i + j) // 2

        node.left = PtrNode(bounds=(i, mid))
        node.right = PtrNode(bounds=(mid, j))
        self._build(node.left, a)
        self._build(node.right, a)
        node.sum_ = node.left.sum_ + node.right.sum_

    def query(self, bounds: tuple[int, int]) -> int:
        """Return the sum of the values in the half-open range ``bounds``.

        The range is intersected with the array, so bounds that reach past
        either end contribute nothing for the missing part, and an empty or
        inverted range yields 0.
        """
        return self._query(self.root, bounds)

    def _query(self, node: PtrNode | None, bounds: tuple[int, int]) -> int:
        """Sum the part of ``bounds`` that overlaps the interval of ``node``."""
        if node is None:
            return 0
        ni, nj = node.bounds
        # Clamp to this node's interval so a request wider than the node can
        # never be pushed below a leaf (which has no children to recurse into).
        i, j = max(bounds[0], ni), min(bounds[1], nj)
        if i >= j:
            return 0
        if (i, j) == (ni, nj):
            return node.sum_
        # Partial overlap: each child clamps the range to its own half.
        return self._query(node.left, (i, j)) + self._query(node.right, (i, j))

    def update(self, i: int, value: int) -> None:
        """Set the element at index ``i`` to ``value`` and refresh the sums.

        Raises:
            IndexError: if ``i`` is not a valid index ``0 <= i < len(a)``.
                Such an update would otherwise be dropped without notice.
        """
        lo, hi = self.root.bounds
        if not lo <= i < hi:
            raise IndexError(f"index {i} is out of range for a tree over {hi} elements")
        self._update(self.root, i, value)

    def _update(self, node: PtrNode | None, i: int, value: int) -> None:
        """Write ``value`` at leaf ``i`` below ``node`` and recompute sums on the way up."""
        if node is None:
            return
        lo, hi = node.bounds
        if not lo <= i < hi:
            return
        if hi - lo == 1:
            # Leaf covering exactly index ``i``.
            node.sum_ = value
            return
        # Only the child whose half contains ``i`` can change.
        mid = (lo + hi) // 2
        self._update(node.left if i < mid else node.right, i, value)
        node.sum_ = node.left.sum_ + node.right.sum_


# Build the pointer-based tree over a sample array and compare a range query
# with the built-in sum over the same half-open slice.
ar = [-1, 4, 10, 2, -1, -2, 5, 6, 6, 6]
tree = PtrSegmentTree(ar)
tree.query((2, 5)), sum(ar[2:5])

(11, 11)

In [37]:
# Apply the same point update to the plain list and to the tree, then re-check
# a range that includes the updated index.
ar[0] = 100
tree.update(0, 100)
tree.query((0, 3)), sum(ar[0:3])

(114, 114)

In [38]:
# Inspect the root's left child: its stored sum and the interval it covers.
tree.root.left.sum_, tree.root.left.bounds

(115, (0, 5))

In [39]:
# Inspect the root's right child: its stored sum and the interval it covers.
tree.root.right.sum_, tree.root.right.bounds

(21, (5, 10))

In [40]:
# Reference sums computed directly from the list: the whole array, then each half.
sum(ar), sum(ar[0:5]), sum(ar[5:10])

(136, 115, 21)

Apparently the more common way to do this is with an "implicit data structure", using an array to house the tree. We keep track of a node's position in the node array as `i`; its left child is then at `2*i` and its right child at `(2*i)+1`, with the root at index 1. The node array is created with size `4*N`, where `N` is the size of the input array.

In [41]:
class ArrSegmentTree:
    """Sum segment tree stored implicitly in a flat list.

    The root lives at index 1 of ``self.arr``; the children of the node at
    index ``k`` live at ``2 * k`` and ``2 * k + 1``. Every node covers a
    half-open interval ``[node_left, node_right)`` of the input array, split
    at ``(node_left + node_right) // 2``. ``self.arr`` has ``4 * N`` slots,
    where ``N`` is the length of the input.
    """

    def __init__(self, a: list[int]) -> None:
        """Build the tree over a snapshot of the values in ``a``."""
        self.N = len(a)
        self.arr = [0] * (self.N * 4)
        self._build(1, a, 0, self.N)

    def _build(self, current: int, a: list[int], left: int, right: int) -> None:
        """Fill ``self.arr[current]`` with the sum of ``a[left:right]``, recursively."""
        if right <= left:
            # Empty interval (only reachable for an empty input): nothing to
            # store, and splitting it again would recurse forever.
            return
        # Compare the bounds directly instead of slicing ``a``: slicing copies
        # the covered range at every node.
        if right - left == 1:
            self.arr[current] = a[left]
            return
        mid = (left + right) // 2
        self._build(2 * current, a, left, mid)
        self._build(2 * current + 1, a, mid, right)
        self.arr[current] = self.arr[2 * current] + self.arr[2 * current + 1]

    def query(self, bounds: tuple[int, int]) -> int:
        """Return the sum of the values in the half-open range ``bounds``.

        The range is intersected with the array, so bounds that reach past
        either end contribute nothing for the missing part, and an empty or
        inverted range yields 0.
        """
        left, right = bounds
        return self._query(1, left, right, 0, self.N)

    def _query(
        self,
        current: int,
        left: int,
        right: int,
        node_left: int,
        node_right: int,
    ) -> int:
        """Sum the part of ``[left, right)`` that overlaps the node at ``current``."""
        # Clamp to this node's interval so a request wider than the node can
        # never keep descending below a leaf.
        left, right = max(left, node_left), min(right, node_right)
        if left >= right:
            return 0
        if (left, right) == (node_left, node_right):
            return self.arr[current]
        mid = (node_left + node_right) // 2
        # Partial overlap: each child clamps the range to its own half.
        return self._query(current * 2, left, right, node_left, mid) + self._query(
            current * 2 + 1, left, right, mid, node_right
        )

    def update(self, i: int, value: int) -> None:
        """Set the element at index ``i`` to ``value`` and refresh the sums.

        Raises:
            IndexError: if ``i`` is not a valid index ``0 <= i < N``.
                Such an update would otherwise be dropped without notice.
        """
        if not 0 <= i < self.N:
            raise IndexError(f"index {i} is out of range for a tree over {self.N} elements")
        self._update(1, i, value, 0, self.N)

    def _update(
        self, current: int, i: int, value: int, node_left: int, node_right: int
    ) -> None:
        """Write ``value`` at leaf ``i`` below ``current`` and recompute sums on the way up."""
        if not node_left <= i < node_right:
            return
        if node_right - node_left == 1:
            # Leaf covering exactly index ``i``.
            self.arr[current] = value
            return
        # Only the child whose half contains ``i`` can change.
        mid = (node_left + node_right) // 2
        if i < mid:
            self._update(current * 2, i, value, node_left, mid)
        else:
            self._update(current * 2 + 1, i, value, mid, node_right)
        self.arr[current] = self.arr[current * 2] + self.arr[current * 2 + 1]


# Build the array-backed tree over a sample array (odd length this time) and
# compare a range query with the built-in sum over the same half-open slice.
ar = [-1, 4, 10, 2, -1, -2, 5, 6, 6, 6, 7]
tree = ArrSegmentTree(ar)
tree.query((2, 5)), sum(ar[2:5])

(11, 11)

In [42]:
# Slot 1 holds the root and slots 2 and 3 hold its two children; the root's
# value is shown next to the total of the list for comparison.
tree.arr[1], tree.arr[2], tree.arr[3], sum(ar)

(42, 14, 28, 42)

In [43]:
# Apply the same point update to the plain list and to the tree, then re-check
# a range that includes the updated index.
ar[0] = 100
tree.update(0, 100)
tree.query((0, 3)), sum(ar[0:3])

(114, 114)