In [17]:
from itertools import count
import math

def log2(i: int) -> int:
    """Return floor(log2(i)) for a positive number `i`.

    Integers are handled exactly with int.bit_length(): int(math.log2(i)) goes
    through a float and can round up to the next integer for large integers just
    below a power of two (e.g. 2**53 - 1). Non-int inputs fall back to math.log2.

    Raises ValueError if `i` is not positive.
    """
    if isinstance(i, int):
        if i <= 0:
            raise ValueError(f"log2() requires a positive argument, got {i}")
        return i.bit_length() - 1
    return int(math.log2(i))

## Insertion sort (serial)

In [18]:
def topk_insertion(data: list[tuple[float, int]], k: int) -> list[tuple[float, int]]:
    """Return the k largest entries of `data` in ascending order, via bounded insertion sort.

    `topk` is kept sorted ascending, so topk[0] is the smallest entry kept so far.
    Every later entry that beats it overwrites topk[0] and is bubbled up by one
    pass to its sorted position.

    If k >= len(data), all entries are returned sorted; k <= 0 returns [].
    """
    if k <= 0:
        return []
    topk = sorted(data[:k])  # seed with the first k entries (can be merged into loop)
    # Only the entries after the seed are inserted; re-inserting the seed would duplicate them.
    for x in data[k:]:                                    # (n-k)*
        if x > topk[0]:                                   # +1
            topk[0] = x                                   # +1
        for j in range(1, len(topk)):
            if topk[j-1] > topk[j]:                       # +1
                topk[j-1], topk[j] = topk[j], topk[j-1]   # +2
    return topk

topk_insertion(list(zip([100, 300, 400, 100, 500, 100, 450, 100], count())), 2)

[(450, 6), (500, 4)]

## Scan-max (parallel)

In [33]:
def scan_argmax(data: list) -> int:
    """Return the index of a maximum element of `data` via a parallel tree reduction.

    Each round compares slot j with slot j + stride (stride = 1, 2, 4, ...) and
    keeps the index of the larger element, so afterwards slot j stands for the
    window [j, j + 2*stride). Rounds continue while stride < len(data), i.e.
    ceil(log2(n)) rounds; stopping after floor(log2(n)) rounds would leave the
    tail unseen whenever n is not a power of two.

    On ties the later index wins (the current slot is kept only when strictly
    larger). Raises ValueError for empty input.
    """
    n = len(data)
    if n == 0:
        raise ValueError("scan_argmax() arg is an empty sequence")
    argmax = list(range(n))                     # +1
    stride = 1
    while stride < n:                           # ceil(log(n))*
        argmax = [argmax[j]                     #   +1
                  if j + stride >= n or data[argmax[j + stride]] < data[argmax[j]]  #   +1
                  else argmax[j + stride]
                  for j in range(n)]
        stride *= 2
    return argmax[0]

def topk_scan_max(data: list[tuple[float, int]], k: int) -> list[tuple[float, int]]:
    """Return the k largest entries of `data`, largest first, via k repeated argmax scans.

    Each round locates the current maximum with `scan_argmax`, records it, and
    overwrites its slot (in a private copy) with a (-inf, 0) sentinel so later
    rounds pass over it. k is clamped to len(data), so a sentinel is never
    returned as if it were a real entry and empty input yields []; k <= 0 gives [].
    """
    remaining = data.copy()                 # private copy: the caller's list is left untouched
    k = min(k, len(remaining))
    topk = [None for _ in range(k)]
    for i in range(k):                      # k*
        j = scan_argmax(remaining)          # +2*log(n) + 1
        topk[i] = remaining[j]              # +1
        remaining[j] = (-float("inf"), 0)   # +1
    return topk

topk_scan_max(list(zip([100, 300, 400, 100, 500, 100, 450, 100], count())), 2)

[(500, 4), (450, 6)]

## Radix select (serial)

In [69]:
def topk_radix_select(data: list[tuple[float, int]], k: int) -> list[tuple[float, int]]:
    """Return the k entries of `data` with the largest keys, found by MSB-first radix select.

    Pass 1 decides the key of the k-th largest entry (`partition`) one bit at a
    time, most significant first, while `count_greater` tracks how many entries
    have a strictly larger key. Pass 2 keeps every entry with key > partition plus
    just enough entries with key == partition to make k; keeping every
    key >= partition would overflow `topk` whenever the k-th key is tied.

    Keys are matched with `&` on their low 32 bits, so they must be non-negative
    ints below 2**32 (a float key raises TypeError). The result is in input order;
    k is clamped to len(data), and k <= 0 returns [].
    """
    k = min(k, len(data))
    if k <= 0:
        return []

    partition = 0
    partition_mask = 0
    count_greater = 0
    for r in range(31, -1, -1):  # 32* (one sweep per key bit)
        count_1 = 0
        mask = 1 << r
        partition |= mask
        partition_mask |= mask
        for x, _ in data:                        # n*
            if x & partition_mask == partition:  #   +2
                count_1 += 1                     #   +2

        if count_greater + count_1 < k:
            # Fewer than k keys at or above this prefix, so the k-th key has bit r clear;
            # the count_1 keys just counted are strictly greater than the new prefix.
            partition ^= mask
            count_greater += count_1

    # Given the key assumptions above, partition is now the k-th largest key and
    # exactly count_greater (< k) keys are strictly larger.
    ties_needed = k - count_greater
    topk = [None for _ in range(k)]
    ptr = 0
    for x in data:                                               # n*
        is_tie = x[0] == partition                               #  +1
        if x[0] > partition or (is_tie and ties_needed > 0):     #  +2
            if is_tie:
                ties_needed -= 1                                 #  +1
            topk[ptr] = x                                        #  +1
            ptr += 1                                             #  +1
    return topk

topk_radix_select(list(zip([100, 300, 400, 100, 500, 100, 450, 100], count())), 2)

[(500, 4), (450, 6)]

## Radix select (parallel)

In [87]:
def scan_sum(data: list[float]) -> float:
    """Return the sum of `data` via a parallel tree reduction.

    Each round adds slot j + stride into slot j (0 past the end), doubling stride
    while it is below len(data), i.e. ceil(log2(n)) rounds, so slot 0 holds the
    full total even when n is not a power of two. Empty input sums to 0.
    """
    partial = list(data)                          # +1
    n = len(partial)
    stride = 1
    while stride < n:                             # ceil(log(n))*
        partial = [partial[j] + (partial[j + stride] if j + stride < n else 0)
                   for j in range(n)]             # +2
        stride *= 2
    return partial[0] if partial else 0

def scan_cumsum(data: list[int]) -> list[int]:
    """Return the inclusive prefix sums of `data` via a Hillis-Steele parallel scan.

    Each round adds slot j - stride into slot j (0 before the start), doubling
    stride while it is below len(data), i.e. ceil(log2(n)) rounds, so the last
    slots are complete even when n is not a power of two. The input list is not
    modified; empty input returns [].
    """
    prefix = list(data)                           # +1
    n = len(prefix)
    stride = 1
    while stride < n:                             # ceil(log(n))*
        prefix = [prefix[j] + (prefix[j - stride] if j >= stride else 0)
                  for j in range(n)]              # +2
        stride *= 2
    return prefix

def topk_radix_select_parallel(data: list[tuple[float, int]], k: int) -> list[tuple[float, int]]:
    """Return the k entries of `data` with the largest keys, via radix select built from scans.

    Same algorithm as the serial radix select, with every per-element loop written
    as an element-wise map plus a scan (`scan_sum` / `scan_cumsum`):
    - the partition search counts prefix matches with `scan_sum`;
    - selection keeps all keys > partition and ranks the ties (key == partition)
      with `scan_cumsum`, keeping only the first k - count_greater of them, so the
      output never overflows when the k-th key is tied;
    - a final `scan_cumsum` over the keep-flags gives each kept entry its slot.

    Keys are matched with `&` on their low 32 bits, so they must be non-negative
    ints below 2**32. The result is in input order; k is clamped to len(data), and
    k <= 0 (including empty `data`) returns [].
    """
    k = min(k, len(data))
    if k <= 0:
        return []

    partition = 0
    partition_mask = 0
    count_greater = 0
    for r in range(31, -1, -1):  # 32* (one round per key bit)
        mask = 1 << r                                                  # +1
        partition |= mask                                              # +2
        partition_mask |= mask                                         # +2

        count_1s = [x & partition_mask == partition for x, _ in data]  # +2
        count_1 = scan_sum(count_1s)                                   # +2log(n) + 1

        if count_greater + count_1 < k:                                # +2
            partition ^= mask                                          # +2
            count_greater += count_1                                   # +2

    # partition is now the k-th largest key; count_greater (< k) keys are strictly larger.
    ties_needed = k - count_greater
    is_greater = [x > partition for x, _ in data]                      # +1
    is_tie = [x == partition for x, _ in data]                         # +1
    tie_rank = scan_cumsum(is_tie)                                     # +2log(n) (+ 1)
    # keep ties in input order until k entries are selected in total
    in_topk = [g or (t and rank <= ties_needed)
               for g, t, rank in zip(is_greater, is_tie, tie_rank)]    # +3
    offset = scan_cumsum(in_topk)                                      # +2log(n) (+ 1)
    topk = [None for _ in range(k)]
    for i in range(len(data)):  # (parallel)
        if in_topk[i]:
            topk[offset[i] - 1] = data[i]
    return topk

topk_radix_select_parallel(list(zip([100, 300, 400, 100, 500, 100, 450, 100], count())), 2)

[(500, 4), (450, 6)]