# In [22]:
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import dataclass
from typing import List, Union


@dataclass(eq=True, frozen=True)
class Tree:
    """Immutable base type for nodes of the reduction tree built by ``upsweep``.

    Attributes:
        val: Sum of the input elements covered by this subtree.
    """

    val: int  # aggregate (sum) stored at this node


@dataclass(eq=True, frozen=True)
class Leaf(Tree):
    """Tree leaf covering the half-open index range ``[start, until)``.

    The inherited ``val`` holds the sequentially computed sum of that range.
    Because dataclass fields of the base class come first, the positional
    constructor order is ``(val, start, until)``.
    """

    start: int  # first covered index (inclusive)
    until: int  # end of the covered range (exclusive)


@dataclass(eq=True, frozen=True)
class Node(Tree):
    """Internal tree node joining two subtrees.

    The inherited ``val`` holds ``left.val + right.val``. ``left`` covers the
    lower half of the parent's index range and ``right`` the upper half.
    Positional constructor order is ``(val, left, right)``.
    """

    left: Tree   # subtree over the lower half of the range
    right: Tree  # subtree over the upper half of the range


# Shared worker pool used by upsweep/downsweep to fork subtree work.
# It is created with default settings, so its worker count is bounded:
# tasks in this pool should not block waiting on other tasks that may still
# be queued behind them.
executor = ThreadPoolExecutor()

# The concrete tree shapes produced by upsweep and consumed by downsweep.
ITree = Union[Leaf, Node]


def upsweep_sequential(input_ls: List[int], start: int, until: int) -> int:
    """Return the sum of ``input_ls[i]`` for every ``i`` in ``range(start, until)``.

    Indices are used exactly as ``range`` yields them, so an out-of-range
    index raises ``IndexError`` and negative indices count from the end.
    """
    total = 0
    for position in range(start, until):
        total += input_ls[position]
    return total


def upsweep(input_ls: List[int], start: int, until: int, threshold: int) -> ITree:
    """Build the reduction tree for ``range(start, until)`` of ``input_ls``.

    Ranges shorter than ``threshold`` are summed sequentially into a ``Leaf``.
    Longer ranges are halved at ``mid = start + (until - start) // 2``. The two
    halves are reduced independently (one of them possibly on ``executor``)
    and combined into a ``Node`` whose ``val`` is the sum of its children.

    Args:
        input_ls: Values to reduce.
        start: First index of the range (inclusive).
        until: End index of the range (exclusive).
        threshold: Ranges with fewer than this many elements become leaves.
            Ranges of at most one element always become leaves, so a
            ``threshold <= 1`` cannot recurse forever.

    Returns:
        A ``Leaf`` or ``Node`` whose ``val`` is the sum of the covered range.
    """
    size = until - start
    # A 0- or 1-element range cannot be split: mid would equal start and the
    # right half would be the same range again, recursing forever.
    if size < threshold or size <= 1:
        return Leaf(
            start=start, until=until,
            val=upsweep_sequential(input_ls, start, until),
        )

    mid = start + size // 2

    # Fork the left half onto the pool and reduce the right half in this
    # thread. Blocking a pool worker on a task that is still queued in the
    # same bounded pool can deadlock once every worker is waiting. So if the
    # left task has not started yet, cancel it and compute it here instead.
    # We only ever wait on a task that is already running.
    left_future = executor.submit(upsweep, input_ls, start, mid, threshold)
    right = upsweep(input_ls, mid, until, threshold)
    if left_future.cancel():
        left = upsweep(input_ls, start, mid, threshold)
    else:
        left = left_future.result()

    return Node(left=left, right=right, val=left.val + right.val)


def downsweep_sequential(input_ls: List[int], output_ls: List[int], start_val: int, start: int, until: int) -> None:
    """Write running sums over ``range(start, until)`` into ``output_ls``.

    After the call, for every ``idx`` in ``range(start, until)``::

        output_ls[idx] == start_val + input_ls[start] + ... + input_ls[idx]

    Positions outside the range are left untouched.

    Args:
        input_ls: Source values.
        output_ls: Destination list, updated in place.
        start_val: Sum of everything preceding ``start`` (the scan offset).
        start: First index to write (inclusive).
        until: End of the range to write (exclusive).
    """
    running = start_val
    for idx in range(start, until):
        # Accumulate the element at the current index, not the range start.
        running += input_ls[idx]
        output_ls[idx] = running


def downsweep(input_ls: List[int], output_ls: List[int], start_val: int, tree: ITree) -> None:
    """Fill ``output_ls`` over the ranges of ``tree``'s leaves.

    Each ``Leaf`` is handed to ``downsweep_sequential`` with ``start_val`` as
    its offset. For a ``Node``, the left subtree starts from ``start_val`` and
    the right subtree from ``start_val + tree.left.val``. The two subtrees
    write disjoint index ranges, so they are processed independently (one of
    them possibly on ``executor``).

    Args:
        input_ls: Source values.
        output_ls: Destination list, updated in place.
        start_val: Sum of everything preceding this subtree's range.
        tree: Tree previously built over ``input_ls``.

    Raises:
        ValueError: If ``tree`` is neither a ``Leaf`` nor a ``Node``.
    """
    if isinstance(tree, Leaf):
        downsweep_sequential(input_ls, output_ls, start_val, tree.start, tree.until)
        return
    if not isinstance(tree, Node):
        raise ValueError(f"expected Leaf or Node, got {type(tree).__name__}")

    # Fork the left subtree onto the pool and process the right one here.
    # Waiting on a still-queued task from inside a bounded pool can deadlock,
    # so an unstarted left task is cancelled and run inline instead. We only
    # ever wait on a task that is already running.
    left_future = executor.submit(downsweep, input_ls, output_ls, start_val, tree.left)
    downsweep(input_ls, output_ls, start_val + tree.left.val, tree.right)
    if left_future.cancel():
        downsweep(input_ls, output_ls, start_val, tree.left)
    else:
        left_future.result()  # propagate any exception from the worker


def parallel_scan(input_ls: List[int], output_ls: List[int], start_val: int, threshold: int):
    """Scan ``input_ls`` into ``output_ls`` in two passes.

    First ``upsweep`` builds a tree of partial sums over the whole list, using
    ``threshold`` to decide when a range becomes a leaf. Then ``downsweep``
    walks that tree and writes the running totals into ``output_ls`` in place,
    offset by ``start_val``. Returns ``None``.
    """
    sums_tree = upsweep(input_ls, 0, len(input_ls), threshold)
    downsweep(input_ls, output_ls, start_val, sums_tree)

# In [24]:
# Demo: scan 0..9 starting from 0; threshold=2 makes every leaf hold at
# most one element.
input_ls = [*range(10)]
output_ls = [0 for _ in input_ls]

parallel_scan(input_ls, output_ls, start_val=0, threshold=2)

print(output_ls)

# Output: [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]
