In [12]:
from itertools import count
import math

def log2(i: int) -> int:
    """Return floor(log2(i)).

    Integers are handled exactly with ``int.bit_length``. The plain
    ``int(math.log2(i))`` first converts the argument to a float, which can
    round a large int up to the next power of two (e.g. ``2**60 - 1`` would
    give 60 instead of 59). Non-integer inputs fall back to ``math.log2``.

    Raises:
        ValueError: if ``i`` is not positive.
    """
    if isinstance(i, int):
        if i <= 0:
            raise ValueError("math domain error")
        return i.bit_length() - 1
    return int(math.log2(i))

# Positive infinity; negated, it serves as a "smaller than everything" key.
inf = math.inf

## Insertion sort (serial)

In [4]:
def topk_insertion(data, k):
    """Return the k largest items of ``data`` in ascending order.

    A sorted buffer is seeded with the first k items. Every later item that
    beats the current minimum (``topk[0]``) replaces it, and one bubble pass
    moves it into its sorted position.

    Args:
        data: sequence of mutually comparable items (e.g. ``(value, index)`` tuples).
        k: number of items to keep.

    Returns:
        A list of ``min(k, len(data))`` items, smallest first; ``[]`` if k <= 0.
    """
    if k <= 0:
        return []
    topk = sorted(data[:k])                              # (ignore)
    # Only scan items that are not already in the buffer: revisiting data[:k]
    # would insert a seed item a second time and evict a genuine top-k member.
    for x in data[k:]:                                   # * (n-k)
        if x > topk[0]:                                  # | +1
            topk[0] = x                                  # | +1
        for j in range(1, len(topk)):                    # | * (k-1)
            if topk[j-1] > topk[j]:                      # | | +1
                topk[j-1], topk[j] = topk[j], topk[j-1]  # | | +2
    return topk

# Pair each key with its position so duplicate keys become distinct tuples.
topk_insertion([(v, i) for i, v in enumerate([100, 300, 400, 100, 500, 100, 450, 100])], 2)

[(450, 6), (500, 4)]

## Scan-max (parallel)

In [11]:
def scan_argmax(data):
    """Return the index of the maximum of ``data`` using a doubling reduction.

    At stride ``step`` (1, 2, 4, ...), each position j compares its current
    winner with the winner at ``j + step``. After ceil(log2(n)) rounds,
    position 0 holds the argmax of the whole list. On ties the later range
    wins, so the result is the last index holding the maximum.

    Raises:
        ValueError: if ``data`` is empty.
    """
    n = len(data)
    if n == 0:
        raise ValueError("scan_argmax() arg is an empty sequence")
    a = list(range(n))                                   # +1
    step = 1
    # ceil(log2(n)) rounds are needed; floor(log2(n)) misses the tail whenever
    # n is not a power of two (e.g. with n=5, index 4 never reaches slot 0).
    while step < n:                                      # * ceil(log(n))
        a = [a[j]                                        # | +1
             if j + step >= n
                or data[a[j + step]] < data[a[j]]        # | +1
             else a[j + step]
             for j in range(n)]
        step *= 2
    return a[0]

def topk_scan_max(data, k):
    """Pick the k largest items by k rounds of "find the max, then knock it out".

    Each round locates the current maximum with ``scan_argmax`` and overwrites
    that slot with the sentinel ``(-inf, 0)`` so it cannot be chosen again.
    Only a copy of ``data`` is modified; the caller's list is untouched.
    Assuming ``scan_argmax`` returns the index of the maximum, items come out
    largest first. Items are presumably ``(number, ...)`` tuples so they
    compare against the sentinel. If k exceeds len(data), the surplus entries
    are the sentinel itself.
    """
    remaining = data.copy()
    picked = []
    for _ in range(k):                   # * k
        best = scan_argmax(remaining)    # | +2*log(n) + 1
        picked.append(remaining[best])   # | +1
        remaining[best] = (-inf, 0)      # | +1
    return picked

# Pair each key with its position so duplicate keys become distinct tuples.
topk_scan_max([(v, i) for i, v in enumerate([100, 300, 400, 100, 500, 100, 450, 100])], 2)

[(500, 4), (450, 6)]

## Radix select (serial)

In [16]:
def topk_radix_select(data, k):
    """Return k items with the largest keys via MSB-first radix select.

    ``data`` is a sequence of ``(key, payload)`` pairs. Only bits 31..0 of
    each key are inspected, so keys are assumed to be non-negative integers
    below 2**32.

    Phase 1 fixes the k-th largest key one bit at a time, most significant
    bit first. Phase 2 collects, in input order, every item whose key is
    strictly greater than that value, plus only as many items equal to it as
    are needed to reach k. Limiting the ties keeps the result at k items when
    several keys equal the k-th largest.

    Returns:
        A list of length k; if ``len(data) < k`` the trailing slots stay None.
    """
    # Find kth value
    kth_value, mask, count_gt = 0, 0, 0
    for r in range(31, -1, -1):          # * 32 (one pass per bit)
        r_mask = 1 << r
        kth_value |= r_mask              # tentatively set bit r ...
        mask |= r_mask
        count_1 = 0
        for x, _ in data:                # | * n
            if x & mask == kth_value:    # | | +2
                count_1 += 1             # | | +2
        if count_gt + count_1 < k:
            # ... too few keys share this prefix, so they all lie strictly
            # above the k-th value: clear bit r and bank their count.
            kth_value ^= r_mask
            count_gt += count_1
    # count_gt is now the number of keys strictly greater than kth_value.

    # Collect topk
    ties_left = k - count_gt             # how many keys == kth_value to keep
    topk = [None] * k
    i = 0
    for x in data:                       # * n
        if x[0] > kth_value or (x[0] == kth_value and ties_left > 0):  # | +1
            if x[0] == kth_value:
                ties_left -= 1
            topk[i] = x                  # | +1
            i += 1                       # | +2
    return topk

# Pair each key with its position so duplicate keys become distinct tuples.
topk_radix_select([(v, i) for i, v in enumerate([100, 300, 400, 100, 500, 100, 450, 100])], 2)

[(500, 4), (450, 6)]

## Radix select (parallel)

In [31]:
def scan_cumsum(data):
    """Inclusive prefix sum of ``data`` (Hillis-Steele scan).

    Each round adds the value ``step`` positions to the left, and ``step``
    doubles until it reaches ``len(data)``, i.e. ceil(log2(n)) rounds.
    Booleans add as 0/1, so a membership mask becomes running counts.
    Returns a new list; ``data`` itself is not modified.
    """
    s = list(data)                                       # +1
    step = 1
    # Loop ceil(log2(n)) times, not floor: with floor, n=3 stops after one
    # round and the last element never picks up data[0].
    while step < len(s):                                 # * ceil(log(n))
        s = [s[j] + (s[j - step] if j >= step else 0)    # | +2
             for j in range(len(s))]
        step *= 2
    return s

def topk_radix_select_parallel(data, k):
    """Data-parallel formulation of MSB-first radix select.

    Each per-element step is a list comprehension (one "parallel" step), and
    each count is a prefix sum via ``scan_cumsum``. ``data`` is a sequence of
    ``(key, payload)`` pairs. Only bits 31..0 are inspected, so keys are
    assumed to be non-negative integers below 2**32.

    Phase 1 fixes the k-th largest key bit by bit, MSB first. Phase 2 keeps,
    in input order, all items whose key is greater than that value plus only
    as many ties as are needed to reach k. It then scatters the kept items
    into place using prefix-sum offsets.

    Returns:
        A list of length k; if ``len(data) < k`` the trailing slots stay None.
    """
    if not data:
        # Nothing to select; the prefix sums below would have no last element.
        return [None] * k

    # Find kth value
    kth_value, mask, count_gt = 0, 0, 0
    for r in range(31, -1, -1):                              # * 32 (one pass per bit)
        r_mask = 1 << r                                      # | +1
        kth_value |= r_mask                                  # | +2
        mask |= r_mask                                       # | +2
        count_1s = [x & mask == kth_value for x, _ in data]  # | +2
        count_1 = scan_cumsum(count_1s)[-1]                  # | +2*log(n) + 1
        if count_gt + count_1 < k:                           # | +2
            kth_value ^= r_mask                              # | +2
            count_gt += count_1                              # | +2
    # count_gt is now the number of keys strictly greater than kth_value.

    # Collect topk. Rank the ties so only the first (k - count_gt) are kept;
    # keeping every key >= kth_value would overflow topk when ties straddle k.
    is_tie = [x == kth_value for x, _ in data]
    tie_rank = scan_cumsum(is_tie)                           # +2*log(n) (+ 1)
    in_topk = [x > kth_value or (is_tie[i] and tie_rank[i] <= k - count_gt)
               for i, (x, _) in enumerate(data)]
    offset = scan_cumsum(in_topk)                            # +2*log(n) (+ 1)
    topk = [None] * k
    for i in range(len(data)):                               # (in parallel)
        if in_topk[i]:
            topk[offset[i] - 1] = data[i]
    return topk

# Pair each key with its position so duplicate keys become distinct tuples.
topk_radix_select_parallel([(v, i) for i, v in enumerate([100, 300, 400, 100, 500, 100, 450, 100])], 2)

[(500, 4), (450, 6)]