# Maximize Count of Distinct Primes After Split

See problem [3569](https://leetcode.com/problems/maximize-count-of-distinct-primes-after-split/description/).

In [None]:
class SegmentTree:
    """Point-update / range-max segment tree over a list of ints.

    Nodes are stored 1-indexed in ``arr``: the children of node ``k`` are
    ``2k`` and ``2k + 1``, and every node covers a half-open range
    ``[node_left, node_right)`` of the input list.
    """

    def __init__(self, a: list[int]) -> None:
        self.N = len(a)
        # 4 * N slots always suffice for a recursively halved tree of N leaves.
        self.arr = [0] * (self.N * 4)
        # Pending "assign this node's value to its whole range" markers,
        # consumed by ``push``. The public update path is point-only, so
        # these currently stay False.
        self.lazy = [False] * (self.N * 4)
        if self.N > 0:
            # An empty input has no leaves; building it would never terminate.
            self._build(1, a, 0, self.N)

    def _build(self, current: int, a: list[int], left: int, right: int) -> None:
        """Fill node ``current`` from the non-empty slice ``a[left:right]``."""
        # Compare bounds directly instead of slicing, which copied the range.
        if right - left == 1:
            self.arr[current] = a[left]
            return
        mid = (left + right) // 2
        self._build(2 * current, a, left, mid)
        self._build(2 * current + 1, a, mid, right)
        self.arr[current] = max(self.arr[2 * current], self.arr[2 * current + 1])

    def push(self, current: int) -> None:
        """Propagate a pending assignment on internal node ``current`` to its children."""
        if self.lazy[current]:
            value = self.arr[current]
            # Assign each child explicitly; the old tuple-unpacking of a single
            # int/bool raised TypeError.
            self.arr[2 * current] = self.arr[2 * current + 1] = value
            self.lazy[2 * current] = self.lazy[2 * current + 1] = True
            self.lazy[current] = False

    def query(self, bounds: tuple[int, int]) -> int:
        """Return ``max(a[left:right])`` for ``bounds == (left, right)``.

        Bounds are clamped to ``[0, N]``. An empty range returns 0.
        """
        left, right = bounds
        left = max(left, 0)
        right = min(right, self.N)
        return self._query(1, left, right, 0, self.N)

    def _query(
        self,
        current: int,
        left: int,
        right: int,
        node_left: int,
        node_right: int,
    ) -> int:
        """Max over ``[left, right)``, which must lie within ``[node_left, node_right)``."""
        if left >= right:
            # Only reachable for an empty top-level query: recursion below
            # never produces an empty sub-range, so 0 is never mixed into a max.
            return 0
        if (left, right) == (node_left, node_right):
            return self.arr[current]
        self.push(current)
        mid = (node_left + node_right) // 2
        # Descend only into children that overlap the query range.
        if right <= mid:
            return self._query(2 * current, left, right, node_left, mid)
        if left >= mid:
            return self._query(2 * current + 1, left, right, mid, node_right)
        return max(
            self._query(2 * current, left, mid, node_left, mid),
            self._query(2 * current + 1, mid, right, mid, node_right),
        )

    def update(self, i: int, value: int) -> None:
        """Set element ``i`` to ``value``.

        Raises:
            IndexError: if ``i`` is outside ``[0, N)``.
        """
        if not 0 <= i < self.N:
            raise IndexError(f"index {i} out of range for SegmentTree of size {self.N}")
        self._update(1, i, value, 0, self.N)

    def _update(
        self, current: int, i: int, value: int, node_left: int, node_right: int
    ) -> None:
        """Set leaf ``i`` below node ``current`` and refresh maxima on the way up."""
        if not node_left <= i < node_right:
            return
        if node_right - node_left == 1:
            # Leaf: nothing below it, so there is no lazy marker to set.
            self.arr[current] = value
            return
        self.push(current)
        mid = (node_left + node_right) // 2
        if i < mid:
            self._update(2 * current, i, value, node_left, mid)
        else:
            self._update(2 * current + 1, i, value, mid, node_right)
        self.arr[current] = max(self.arr[2 * current], self.arr[2 * current + 1])


# Pair each tree range-max answer with the brute-force slice max it must equal.
ar = [-1, 4, 10, 2, -1, -2, 5, 6, 6, 6, 7]
tree = SegmentTree(ar)
tuple(v for lo, hi in ((2, 5), (7, 10)) for v in (tree.query((lo, hi)), max(ar[lo:hi])))

(10, 10, 6, 6)

In [None]:
# O(n*log(log(n)))
def sieve(n: int) -> list[bool]:
    prime = [True for _ in range(n+1)]
    # Start at 2
    p = 2
    while (p*p <= n):
        if prime[p]:
            for i in range(p*p, n+1, p):
                prime[i] = False
        p += 1
    return prime

# Example input: each query [index, value] sets nums[index] = value.
nums = [2, 1, 3, 1, 2]
queries = [[1, 2], [3, 3]]

# Sieve up to the largest value that can ever appear in nums. Pass max() a
# single list: splatting (max(*xs)) raises TypeError when xs holds just one
# int, e.g. a one-element nums with no queries.
primes = sieve(max([*nums, *(v for _, v in queries)]))

# prime -> {index: None}: the dict serves as an insertion-ordered set of the
# positions where that prime occurs (filled here in ascending index order).
prime_indices: dict[int, dict[int, None]] = {}
for i, num in enumerate(nums):
    if primes[num]:
        prime_indices.setdefault(num, {})[i] = None

# Construct the difference array over split points. Reading split point k as
# prefix nums[:k] / suffix nums[k:], a prime whose first occurrence is f and
# last occurrence is l lies on both sides exactly for k in [f + 1, l], so it
# contributes +1 at f + 1 and -1 at l + 1.
delta = [0] * (len(nums) + 1)
for occurrences in prime_indices.values():
    if len(occurrences) >= 2:
        # min/max instead of the first/last inserted key: this stays correct
        # even if indices are re-inserted out of order by later updates.
        first = min(occurrences)
        last = max(occurrences)
        delta[first + 1] += 1
        delta[last + 1] -= 1

print(delta)
# Prefix sums turn the difference array into a per-split-point count.
for i in range(1, len(nums) + 1):
    delta[i] += delta[i - 1]
print(delta)

tree = SegmentTree(delta)
for ind, new_val in queries:
    old_val = nums[ind]
    nums[ind] = new_val
    # Keep prime_indices in sync with nums. First drop ind from the old
    # prime's occurrence set, forgetting the prime once it no longer occurs.
    # The key exists because nums[ind] == old_val was recorded when it was set.
    if primes[old_val]:
        occurrences = prime_indices[old_val]
        del occurrences[ind]
        if not occurrences:
            del prime_indices[old_val]
    # Then record ind under the new value if that value is prime.
    if primes[new_val]:
        prime_indices.setdefault(new_val, {})[ind] = None
    # TODO: the [first + 1, last] split-point ranges of old_val and new_val
    # (as in the delta array built above) change with this update, and the
    # tree must reflect that. SegmentTree only supports point updates, so
    # this needs a range-add (lazy) variant. Presumably the per-query answer
    # is then len(prime_indices) + the max over valid split points — confirm
    # against the problem statement.


[0, 1, 1, 0, -1, -1]
[0, 1, 2, 2, 1, 0]


In [9]:
# Show the sieve table, a small sieve for comparison, and the prime -> indices map.
(primes, sieve(4), prime_indices)

([True, True, True, True], [True, True, True, True, False], {})