In [1]:
from typing import List


class SegmentTree:
  """Point-update segment tree supporting a "rightmost prefix" search.

  Stores integers at positions ``0..n-1`` (all initially 0). Every node keeps,
  for the positions it covers, the total ``sum`` plus the ``min`` and ``max``
  of the running non-empty prefix sums measured from the start of that
  node's range. Nodes use the 1-indexed layout: the children of ``node`` are
  ``2 * node`` and ``2 * node + 1``.
  """

  def __init__(self, n: int):
    """Create a tree over ``n`` positions, all set to 0."""
    self.n = n
    # 4 * n bounds the node count of this 1-indexed, midpoint-split layout.
    self.size = 4 * n
    self.sum = [0] * self.size
    self.min = [0] * self.size
    self.max = [0] * self.size

  def _pull(self, node: int):
    """Recompute ``node``'s aggregates from its two children."""
    l, r = node * 2, node * 2 + 1
    self.sum[node] = self.sum[l] + self.sum[r]
    # Prefixes ending inside the right child are offset by the left child's
    # total sum.
    self.min[node] = min(self.min[l], self.sum[l] + self.min[r])
    self.max[node] = max(self.max[l], self.sum[l] + self.max[r])

  def update(self, idx: int, val: int):
    """Set position ``idx`` to ``val`` and refresh every ancestor.

    Raises:
      IndexError: if ``idx`` is not in ``[0, n)``. Without this check an
        out-of-range index would silently overwrite the first or last leaf,
        and on an empty tree (``n == 0``) the descent below never terminates.
    """
    if not 0 <= idx < self.n:
      raise IndexError(
          f"index {idx} out of range for SegmentTree of size {self.n}")
    node, l, r = 1, 0, self.n - 1
    path = []  # ancestors of the target leaf, root first
    while l != r:
      path.append(node)
      m = l + (r - l) // 2
      if idx <= m:
        node = node * 2
        r = m
      else:
        node = node * 2 + 1
        l = m + 1
    # A leaf's only non-empty prefix is the value itself.
    self.sum[node] = val
    self.min[node] = val
    self.max[node] = val
    # Rebuild the ancestors bottom-up.
    while path:
      self._pull(path.pop())

  def find_rightmost_prefix(self, target: int = 0) -> int:
    """Return the largest ``r`` with ``values[0] + ... + values[r] == target``.

    Returns -1 when no such index exists (including when ``n == 0``).

    Precondition: every stored value is -1, 0 or +1. Then consecutive prefix
    sums differ by at most 1, so the prefixes inside a subtree take every
    integer between their minimum and maximum. A subtree therefore contains
    a matching prefix exactly when ``target`` (less the sum of everything to
    its left) lies in ``[min, max]``. With larger steps that test can pass
    without an exact match, and the returned index may be wrong.
    """
    if self.n == 0:
      # No positions means no non-empty prefix (and the arrays are empty).
      return -1
    node, l, r, sum_before = 1, 0, self.n - 1, 0

    def _exist(node: int, sum_before: int) -> bool:
      # Does some prefix ending inside `node` equal `target`?
      return self.min[node] <= target - sum_before <= self.max[node]

    if not _exist(node, sum_before):
      return -1

    while l != r:
      m = l + (r - l) // 2
      left_child, right_child = node * 2, node * 2 + 1
      sum_before_right = self.sum[left_child] + sum_before
      # Prefer the right half so the answer is the rightmost match.
      if _exist(right_child, sum_before_right):
        node = right_child
        l = m + 1
        sum_before = sum_before_right
      else:
        # The current subtree holds a match and the right half does not,
        # so the left half must.
        node = left_child
        r = m
    return l


class Solution:
  """Longest subarray with equal counts of distinct even and odd values."""

  def longestBalanced(self, nums: List[int]) -> int:
    """Return the length of the longest balanced subarray of ``nums``.

    A subarray is balanced when it holds as many distinct even values as
    distinct odd values. Left endpoints are swept from right to left. For the
    current left endpoint ``start``, the tree holds +1 (even) or -1 (odd) at
    the leftmost index >= ``start`` of each distinct value, and 0 everywhere
    else, including every index before ``start``, which is never written.
    The prefix sum up to ``end`` is therefore
    (#distinct even) - (#distinct odd) in ``nums[start:end + 1]``. The
    longest balanced subarray starting at ``start`` ends at the rightmost
    zero prefix.
    """
    size = len(nums)
    tree = SegmentTree(size)
    leftmost = {}  # value -> leftmost index seen so far in the sweep
    best = 0
    for start in range(size - 1, -1, -1):
      value = nums[start]
      previous = leftmost.get(value)
      if previous is not None:
        # The value's leftmost occurrence moves to `start`; clear the old marker.
        tree.update(previous, 0)
      leftmost[value] = start
      tree.update(start, -1 if value % 2 else 1)
      end = tree.find_rightmost_prefix(target=0)
      # end < start means no balanced subarray begins at `start`.
      if end >= start:
        best = max(best, end - start + 1)
    return best

In [2]:
# The whole array is balanced: distinct evens {2, 4}, distinct odds {5, 3}.
nums = [2, 5, 4, 3]
Solution().longestBalanced(nums=nums)

4

In [3]:
# The whole array is balanced: distinct evens {2, 4}, distinct odds {3, 5}.
# The repeated 2 counts only once.
nums = [3, 2, 2, 5, 4]
Solution().longestBalanced(nums=nums)

5

In [4]:
# [2, 3, 2] is balanced (evens {2}, odds {3}); the whole array is not
# (evens {2}, odds {1, 3}).
nums = [1, 2, 3, 2]
Solution().longestBalanced(nums=nums)

3