In [None]:
# https://leetcode.com/problems/range-sum-query-mutable/
# https://www.educative.io/answers/what-is-a-segment-tree
# Easy-to-understand explanation of segment trees (link above)
class TreeNode:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.total = 0
        self.left = None
        self.right = None

    @staticmethod
    def build_tree(nums, start, end):
        if start > end:
            return None

        if start == end:
            # Leaf node
            node = TreeNode(start, end)
            node.total = nums[start]
            return node

        mid = (start + end) // 2
        node = TreeNode(start, end)
        node.left = TreeNode.build_tree(nums, start, mid)
        node.right = TreeNode.build_tree(nums, mid + 1, end)
        node.total = node.left.total + node.right.total
        return node

    def update(self, index, value):
        if self.start == self.end:
            # Leaf node
            self.total = value
            return

        mid = (self.start + self.end) // 2
        if index <= mid:
            self.left.update(index, value)
        else:
            self.right.update(index, value)

        # Update total after child update
        self.total = self.left.total + self.right.total


    def query(self, left, right):
        # If the current segment matches the query range, return the total
        if self.start == left and self.end == right:
            return self.total

        # Calculate the midpoint of the current segment
        mid = (self.start + self.end) // 2

        if right <= mid:
            # Entire query range lies in the left child
            return self.left.query(left, right)
        elif left > mid:
            # Entire query range lies in the right child
            return self.right.query(left, right)
        else:
            # Query range overlaps both children
            left_sum = self.left.query(left, mid)
            right_sum = self.right.query(mid + 1, right)
            return left_sum + right_sum



class NumArray:
    """Mutable array supporting point updates and inclusive range sums.

    Backed by a ``TreeNode`` segment tree. When constructed from an empty
    list no tree is built; ``update`` then does nothing and ``sumRange``
    returns 0.
    """

    def __init__(self, nums: list[int]):
        self.root = None
        if nums:
            self.root = TreeNode.build_tree(nums, 0, len(nums) - 1)

    def update(self, index: int, val: int) -> None:
        """Set position ``index`` to ``val`` (no-op for an empty array)."""
        tree = self.root
        if not tree:
            return
        tree.update(index, val)

    def sumRange(self, left: int, right: int) -> int:
        """Return the sum over positions ``left..right`` inclusive (0 if empty)."""
        if not self.root:
            return 0
        return self.root.query(left, right)

# Demo: the example from the LeetCode problem statement.
na = NumArray([1, 3, 5])
print(na.sumRange(0, 2))  # 9 -> array is [1, 3, 5]
for index, value in ((1, 2), (0, 2)):
    na.update(index, value)
    print(na.sumRange(0, 2))  # 8 -> [1, 2, 5], then 7 -> [2, 2, 5]