***307. Range Sum Query - Mutable***

https://leetcode.com/problems/range-sum-query-mutable/description/

In [None]:
# https://www.educative.io/answers/what-is-a-segment-tree
class SegmentTree:
    """
    Array-backed segment tree for point updates and range-sum queries.

    The node at position ``i`` of ``self.tree`` stores the sum of one
    contiguous slice of the input; its children are stored at positions
    ``2 * i + 1`` and ``2 * i + 2``.
    """

    def __init__(self, arr):
        """
        Initialize the Segment Tree with the given array.

        An empty array produces an empty tree for which every query
        returns 0 and every update raises IndexError.
        """
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)  # Allocate space for the tree
        # Building over the empty interval [0, -1] would never reach the
        # ``start == end`` base case, so skip the build for empty input.
        if self.n:
            self._build(arr, 0, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        """
        Recursively build the Segment Tree.

        ``node`` is the position in ``self.tree`` that covers
        ``arr[start..end]`` (both bounds inclusive).
        """
        if start == end:  # Leaf node
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        left_child = 2 * node + 1
        right_child = 2 * node + 2
        self._build(arr, left_child, start, mid)
        self._build(arr, right_child, mid + 1, end)
        # Combine results from children
        self.tree[node] = self.tree[left_child] + self.tree[right_child]

    def query(self, L, R):
        """
        Perform a range query for sum in the range [L, R].

        Parts of [L, R] that fall outside the array contribute nothing,
        so an empty tree or an empty range yields 0.
        """
        if not self.n:
            return 0
        return self._query(0, 0, self.n - 1, L, R)

    def _query(self, node, start, end, L, R):
        """
        Helper function to recursively perform range queries.

        Returns the sum of the overlap between [L, R] and the interval
        [start, end] covered by ``node``.
        """
        if R < start or L > end:  # Completely outside range
            return 0
        if L <= start and end <= R:  # Completely inside range
            return self.tree[node]
        # Partially overlapping range
        mid = (start + end) // 2
        left_child = 2 * node + 1
        right_child = 2 * node + 2
        left_query = self._query(left_child, start, mid, L, R)
        right_query = self._query(right_child, mid + 1, end, L, R)
        return left_query + right_query

    def update(self, idx, value):
        """
        Update a specific index in the array and reflect it in the tree.

        Raises:
            IndexError: if ``idx`` is not in ``[0, n - 1]``. Without this
                check an invalid index would be routed to the right-most
                leaf and silently overwrite the last element.
        """
        if not 0 <= idx < self.n:
            raise IndexError(
                f"index {idx} out of range for segment tree of size {self.n}"
            )
        self._update(0, 0, self.n - 1, idx, value)

    def _update(self, node, start, end, idx, value):
        """
        Helper function to recursively update the tree.

        Descends to the leaf for ``idx`` and recomputes every sum on the
        way back up.
        """
        if start == end:  # Leaf node
            self.tree[node] = value
            return
        mid = (start + end) // 2
        left_child = 2 * node + 1
        right_child = 2 * node + 2
        if start <= idx <= mid:
            self._update(left_child, start, mid, idx, value)
        else:
            self._update(right_child, mid + 1, end, idx, value)
        # Recompute the current node's value
        self.tree[node] = self.tree[left_child] + self.tree[right_child]


In [7]:
class TreeNode:
    """Node of a pointer-based segment tree covering indices [start, end]."""

    def __init__(self, start, end):
        """Create a node for the inclusive index interval [start, end]."""
        self.start = start
        self.end = end
        self.total = 0  # sum of the values in [start, end]
        self.left = None  # child covering [start, mid]
        self.right = None  # child covering [mid + 1, end]

    @staticmethod
    def build_tree(nums, start, end):
        """
        Build the subtree for ``nums[start..end]`` and return its root.

        Returns None when the interval is empty (``start > end``).
        """
        if start > end:
            return None

        node = TreeNode(start, end)
        if start == end:
            # Leaf node
            node.total = nums[start]
            return node

        mid = (start + end) // 2
        node.left = TreeNode.build_tree(nums, start, mid)
        node.right = TreeNode.build_tree(nums, mid + 1, end)
        node.total = node.left.total + node.right.total
        return node

    def update(self, index, value):
        """
        Set the element at ``index`` to ``value`` and refresh the sums.

        Raises:
            IndexError: if ``index`` lies outside [start, end]. Without
                this check an invalid index would be routed to a boundary
                leaf and silently overwrite the wrong element.
        """
        if not self.start <= index <= self.end:
            raise IndexError(
                f"index {index} out of range [{self.start}, {self.end}]"
            )

        if self.start == self.end:
            # Leaf node
            self.total = value
            return

        mid = (self.start + self.end) // 2
        if index <= mid:
            self.left.update(index, value)
        else:
            self.right.update(index, value)

        # Update total after child update
        self.total = self.left.total + self.right.total

    def query(self, left, right):
        """
        Return the sum of the elements in [left, right] under this node.

        The range is clamped to the node's own interval, so parts that
        fall outside it contribute nothing and an empty range yields 0.
        """
        # Clamp first: after this a leaf always matches exactly, so the
        # recursion can never descend into a leaf's None children.
        left = max(left, self.start)
        right = min(right, self.end)
        if left > right:
            return 0

        if left == self.start and right == self.end:
            return self.total

        mid = (self.start + self.end) // 2
        if right <= mid:
            return self.left.query(left, right)
        if left > mid:
            return self.right.query(left, right)
        # The range straddles the midpoint: split it between the children.
        return self.left.query(left, mid) + self.right.query(mid + 1, right)


class NumArray:
    """Range-sum container backed by a TreeNode segment tree."""

    def __init__(self, nums: list[int]):
        """Build the tree over ``nums``; empty input leaves ``root`` as None."""
        self.root = None
        if nums:
            last = len(nums) - 1
            self.root = TreeNode.build_tree(nums, 0, last)

    def update(self, index: int, val: int) -> None:
        """Set the element at ``index`` to ``val``; do nothing if there is no tree."""
        if not self.root:
            return
        self.root.update(index, val)

    def sumRange(self, left: int, right: int) -> int:
        """Return the sum over [left, right], or 0 if there is no tree."""
        if not self.root:
            return 0
        return self.root.query(left, right)

# Demo: build a NumArray over [1, 3, 5] and print the full-range sum after
# each point update.
na = NumArray([1, 3, 5])
print(na.sumRange(0, 2))  # 9
na.update(1, 2)  # values are now [1, 2, 5]
print(na.sumRange(0, 2))  # 8
na.update(0, 2)  # values are now [2, 2, 5]
print(na.sumRange(0, 2))  # 9


9
8
9
