In [1]:
import sys
# Print the full version string of the Python interpreter running this notebook.
print(sys.version)

3.8.3 (default, Jul  2 2020, 11:26:31) 
[Clang 10.0.0 ]


In [2]:

class intervalTree(object):
    """Array-backed segment tree with point assignment, range addition and range sums.

    The tree is stored as a complete binary tree inside flat lists: the root
    lives at index 1, the children of node ``p`` are ``2 * p`` and
    ``2 * p + 1``, and the leaves occupy ``[leaf_nodes, 2 * leaf_nodes)``.

    Range additions are recorded on the highest fully covered nodes and are
    never pushed down. The invariant kept for every internal node ``p`` is::

        array_tree[p] == array_tree[2p] + array_tree[2p + 1]
                         + interval_add[p] * interval_count[p]

    Attributes:
        nums_len: number of real elements.
        leaf_nodes: number of leaf slots (smallest power of two >= nums_len).
        array_tree: per-node sum, including additions recorded on the node
            itself and its descendants but not those recorded on ancestors.
        interval_add: per-node pending addition that applies to every real
            element below that node.
        interval_count: per-node number of real (non-padding) elements.
    """

    def __init__(self, nums):
        """
        Build the tree from the initial values.

        :type nums: List[int]
        """
        # A complete binary tree wastes a constant factor of space, but keeps
        # the index arithmetic simple.
        self.nums_len = len(nums)
        self.leaf_nodes = 1  # number of leaf slots

        # Round the leaf count up to a power of two that fits every element.
        while self.leaf_nodes < self.nums_len:
            self.leaf_nodes *= 2

        size = self.leaf_nodes * 2
        # Sum stored at each node.
        self.array_tree = [0] * size
        # Pending range addition stored at each node.
        self.interval_add = [0] * size
        # Number of real elements covered by each node (padding leaves stay 0).
        self.interval_count = [0] * size

        for idx, num in enumerate(nums):
            self.array_tree[idx + self.leaf_nodes] = num
            self.interval_count[idx + self.leaf_nodes] = 1

        # Aggregate bottom-up. The loop also visits slot 0, which is not part
        # of the tree and is never read by any method; it is kept so the
        # contents of the internal lists stay the same as before.
        for idx in reversed(range(self.leaf_nodes)):
            self.array_tree[idx] = self.array_tree[idx * 2] + self.array_tree[idx * 2 + 1]
            self.interval_count[idx] = self.interval_count[idx * 2] + self.interval_count[idx * 2 + 1]

    def update(self, index, val):
        """
        Set the element at ``index`` to ``val``.

        :type index: int
        :type val: int
        :rtype: None
        :raises IndexError: if ``index`` is outside ``[0, nums_len)``.
        """
        if not 0 <= index < self.nums_len:
            raise IndexError("intervalTree index out of range")

        pos = self.leaf_nodes + index

        # Queries add every ancestor's pending addition on top of the value
        # stored in the leaf, so store ``val`` minus those additions to make
        # the element read back as exactly ``val``.
        inherited = 0
        ancestor = pos // 2
        while ancestor >= 1:
            inherited += self.interval_add[ancestor]
            ancestor //= 2

        self.array_tree[pos] = val - inherited
        # Any addition previously recorded on the leaf itself is superseded.
        self.interval_add[pos] = 0

        # Re-aggregate the ancestors, keeping each node's own pending addition.
        while pos > 1:
            pos //= 2
            self.array_tree[pos] = (
                self.array_tree[pos * 2]
                + self.array_tree[pos * 2 + 1]
                + self.interval_add[pos] * self.interval_count[pos]
            )

    def findByInterval(self, st, ed, left, right, array_pos, parent_add=0):
        """
        Sum the elements of ``[left, right]`` that lie below one node.

        :param st: first index covered by the node ``array_pos``.
        :param ed: last index covered by the node ``array_pos``.
        :param left: first index of the queried range.
        :param right: last index of the queried range.
        :param array_pos: position of the current node in the flat lists.
        :param parent_add: sum of ``interval_add`` over the strict ancestors
            of the current node.
        :rtype: int
        """
        # Pruning: the node's range does not intersect the query at all.
        if st > right or ed < left:
            return 0

        # The node is fully covered: its stored sum plus what the ancestors
        # still owe to each real element below it.
        if left <= st and ed <= right:
            return self.array_tree[array_pos] + parent_add * self.interval_count[array_pos]

        mid = st + (ed - st) // 2
        # The current node's pending addition becomes part of what its
        # children inherit.
        child_add = parent_add + self.interval_add[array_pos]

        left_sum = self.findByInterval(st, mid, left, right, array_pos * 2, child_add)
        right_sum = self.findByInterval(mid + 1, ed, left, right, array_pos * 2 + 1, child_add)

        return left_sum + right_sum

    def sumRange(self, left, right):
        """
        Return the sum of the elements in the inclusive range ``[left, right]``.

        :type left: int
        :type right: int
        :rtype: int
        """
        return self.findByInterval(0, self.leaf_nodes - 1, left, right, 1)  # 1 is the root node

    def addRange(self, left, right, value):
        """
        Add ``value`` to every element in the inclusive range ``[left, right]``.

        :type left: int
        :type right: int
        :type value: int
        :rtype: int
        :return: the total amount by which the overall sum increased.
        """
        return self.intervalAdd(0, self.leaf_nodes - 1, left, right, 1, value)

    def intervalAdd(self, st, ed, left, right, array_pos, value):
        """
        Apply a range addition below one node.

        :param st: first index covered by the node ``array_pos``.
        :param ed: last index covered by the node ``array_pos``.
        :param left: first index of the range being incremented.
        :param right: last index of the range being incremented.
        :param array_pos: position of the current node in the flat lists.
        :param value: amount added to each element of the range.
        :rtype: int
        :return: the amount by which this node's stored sum increased.
        """
        # Pruning: the node's range does not intersect the update at all.
        if st > right or ed < left:
            return 0

        if left <= st and ed <= right:
            # Record the addition here instead of descending further.
            delta = self.interval_count[array_pos] * value
            self.interval_add[array_pos] += value
            self.array_tree[array_pos] += delta
            return delta

        mid = st + (ed - st) // 2
        left_delta = self.intervalAdd(st, mid, left, right, array_pos * 2, value)
        right_delta = self.intervalAdd(mid + 1, ed, left, right, array_pos * 2 + 1, value)

        delta = left_delta + right_delta
        self.array_tree[array_pos] += delta

        return delta
        

# An intervalTree object is instantiated and called as such:
# obj = intervalTree(nums)
# obj.update(index, val)
# param_2 = obj.sumRange(left, right)

In [3]:
# Point-update check on a three-element array.
nums = [1, 3, 5]
interval_tree = intervalTree(nums)

# Total before the update.
print(interval_tree.sumRange(left=0, right=2))

# Replace the element at index 1 with 2, then print the total again.
interval_tree.update(index=1, val=2)
print(interval_tree.sumRange(left=0, right=2))

9
8


In [4]:
def print_all_range_value(interval_tree, st, ed):
    """Print the range sum of every sub-interval [i, j] with st <= i <= j <= ed.

    Intervals are visited in increasing order of i, then of j, and each sum is
    obtained from ``interval_tree.sumRange(i, j)``.
    """
    # Lazily enumerate every (start, end) pair in the same order as two nested loops.
    pairs = ((i, j) for i in range(st, ed + 1) for j in range(i, ed + 1))
    for i, j in pairs:
        print(f'sum of interval [{i}, {j}] is', interval_tree.sumRange(i, j))

In [5]:
# Two overlapping range additions on an all-zero array of eight elements.
nums = [0] * 8
interval_tree = intervalTree(nums)

interval_tree.addRange(left=0, right=3, value=3)
interval_tree.addRange(left=2, right=7, value=2)

# Dump the three internal lists, one per line, to inspect the stored state.
print(
    interval_tree.array_tree,
    interval_tree.interval_add,
    interval_tree.interval_count,
    sep="\n",
)

# List the sum of every sub-interval of [0, 7].
print_all_range_value(interval_tree, st=0, ed=7)

[0, 24, 16, 8, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[0, 0, 3, 2, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[8, 8, 4, 4, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1]
sum of interval [0, 0] is 3
sum of interval [0, 1] is 6
sum of interval [0, 2] is 11
sum of interval [0, 3] is 16
sum of interval [0, 4] is 18
sum of interval [0, 5] is 20
sum of interval [0, 6] is 22
sum of interval [0, 7] is 24
sum of interval [1, 1] is 3
sum of interval [1, 2] is 8
sum of interval [1, 3] is 13
sum of interval [1, 4] is 15
sum of interval [1, 5] is 17
sum of interval [1, 6] is 19
sum of interval [1, 7] is 21
sum of interval [2, 2] is 5
sum of interval [2, 3] is 10
sum of interval [2, 4] is 12
sum of interval [2, 5] is 14
sum of interval [2, 6] is 16
sum of interval [2, 7] is 18
sum of interval [3, 3] is 5
sum of interval [3, 4] is 7
sum of interval [3, 5] is 9
sum of interval [3, 6] is 11
sum of interval [3, 7] is 13
sum of interval [4, 4] is 2
sum of interval [4, 5] is 4
sum of interval [4, 6] is 6
sum of interval [4, 7

In [6]:
# Two disjoint range additions on a fresh all-zero array of eight elements.
nums = [0] * 8
interval_tree = intervalTree(nums)

interval_tree.addRange(left=0, right=1, value=3)
interval_tree.addRange(left=2, right=3, value=2)

# List the sum of every sub-interval of [0, 7].
print_all_range_value(interval_tree, st=0, ed=7)

sum of interval [0, 0] is 3
sum of interval [0, 1] is 6
sum of interval [0, 2] is 8
sum of interval [0, 3] is 10
sum of interval [0, 4] is 10
sum of interval [0, 5] is 10
sum of interval [0, 6] is 10
sum of interval [0, 7] is 10
sum of interval [1, 1] is 3
sum of interval [1, 2] is 5
sum of interval [1, 3] is 7
sum of interval [1, 4] is 7
sum of interval [1, 5] is 7
sum of interval [1, 6] is 7
sum of interval [1, 7] is 7
sum of interval [2, 2] is 2
sum of interval [2, 3] is 4
sum of interval [2, 4] is 4
sum of interval [2, 5] is 4
sum of interval [2, 6] is 4
sum of interval [2, 7] is 4
sum of interval [3, 3] is 2
sum of interval [3, 4] is 2
sum of interval [3, 5] is 2
sum of interval [3, 6] is 2
sum of interval [3, 7] is 2
sum of interval [4, 4] is 0
sum of interval [4, 5] is 0
sum of interval [4, 6] is 0
sum of interval [4, 7] is 0
sum of interval [5, 5] is 0
sum of interval [5, 6] is 0
sum of interval [5, 7] is 0
sum of interval [6, 6] is 0
sum of interval [6, 7] is 0
sum of interval