# Model of recall and cost

Problem: find the top-$K$ from $N$ elements (optionally in a batch of $B$ independent tasks).

Approximate method:
 1. Split $N$ elements into $L$ buckets (assume $L$ divides $N$), and compute the top-$J$ for each bucket (of size $N/L$).
 2. If $L\, J > K$, get the top-$K$ of these candidates.
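
The two steps above can be sketched with NumPy as follows. This is a minimal illustration rather than a tuned implementation: the function name `approx_topk` and the choice of contiguous buckets are assumptions, since the note does not say how elements are assigned to buckets.

```python
import numpy as np


def approx_topk(x: np.ndarray, K: int, L: int, J: int) -> np.ndarray:
    """Return indices of an approximate top-K of the 1-D array `x`.

    Assumes L divides len(x) and J * L >= K (enough candidates survive step 1).
    """
    N = x.shape[0]
    assert N % L == 0 and J * L >= K
    # Assumption: buckets are contiguous slices of the input.
    buckets = x.reshape(L, N // L)
    # Step 1: indices of the top-J inside each bucket, mapped back to global indices.
    local = np.argsort(-buckets, axis=1)[:, :J]
    candidates = (local + np.arange(L)[:, None] * (N // L)).ravel()
    # Step 2: only needed when more than K candidates survive.
    if J * L > K:
        keep = np.argsort(-x[candidates])[:K]
        candidates = candidates[keep]
    return candidates


# Usage: compare against the exact top-K on random data.
rng = np.random.default_rng(0)
x = rng.permutation(16384).astype(np.float64)
exact = set(np.argsort(-x)[:1024].tolist())
approx = set(approx_topk(x, K=1024, L=512, J=2).tolist())
recall = len(exact & approx) / len(exact)
```

With $L = 1$ and $J = K$ the sketch reduces to an exact top-$K$; smaller $J$ trades recall for less work per bucket.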

Metrics: if $S^*$ is the set of the true top-$K$ elements and $S$ is the set returned by our algorithm, recall is $R = |S^* \cap S|/|S^*|$.
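
For intuition, the expected recall has a closed form under one assumption that is not stated above: each true top-$K$ element lands in one of the $L$ buckets independently and uniformly at random. Rank the true top-$K$ by value, so that element $i$ (for $i = 0, \dots, K-1$) has exactly $i$ better elements. Element $i$ survives step 1 if fewer than $J$ of those $i$ elements share its bucket, and every surviving true top-$K$ element also survives step 2. This gives

$$
\mathbb{E}[R] = \frac{1}{K} \sum_{i=0}^{K-1} \Pr\left[\mathrm{Binomial}(i,\, 1/L) \le J - 1\right]
= \frac{1}{K} \left( J + \sum_{i=J}^{K-1} F_{\mathrm{Bin}}(J-1;\; i,\; 1/L) \right),
$$

where $F_{\mathrm{Bin}}(j;\, n,\, p)$ is the binomial CDF and the first $J$ terms equal 1 because an element with fewer than $J$ better elements always survives. This appears to be the expression that `recall_model` in the Tests section evaluates. With fixed-size buckets the independence assumption holds only approximately.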

Cost model (worst case: under parallel execution, every branch is counted as taken):

 | Algorithm | Total ops (serial) | Shortest-path ops (parallel) |
 | --- | --- | --- |
 | Insertion sort | $B\,N\,K(K+3)/2$ | TODO |
 | Radix select + filter | $B\,N(8 \log_2 N + 3)$ | TODO |
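
As a sanity check, both totals follow from counting ops per input element under the model explained in the next section. For insertion sort, each element pays $K$ comparisons, and with all branches counted as taken, the branch that inserts at position $K - j$ performs $j$ assignments:

$$
\underbrace{K}_{\text{comparisons}} + \underbrace{\sum_{j=1}^{K} j}_{\text{assignments}} = K + \frac{K(K+1)}{2} = \frac{K(K+3)}{2}.
$$

For radix select, each element pays $2 + 1 + 3 + 2 = 8$ ops per bit over $\log_2 N$ bits, plus 3 ops in the final filter scan:

$$
\underbrace{(2 + 1 + 3 + 2)}_{\text{per element, per bit}} \cdot \log_2 N + \underbrace{3}_{\text{filter scan}} = 8 \log_2 N + 3.
$$

Multiplying by $B\,N$ gives the table entries. For example, with $N = 16384$ the radix cost is $8 \cdot 14 + 3 = 115$ ops per element, so under this model insertion sort is cheaper only when $K(K+3)/2 < 115$, that is, for $K \le 13$.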


---

## Explaining the cost model

Model:

 - Iterating through the input data linearly is free.
 - If statements are free, but all branches are "taken".
 - Fixed-offset addressing is free.
 - Ops that cost 1: `{==, <, >, &, |, not, =}`
 - Ops that cost 2: `+=`

```python
# Insertion sort (unrolled)
for i, x in data:
    if x > topk[0]:            # +1
        topk[k-1] = topk[k-2]  # +1 *k
        ...                    # |
        topk[0] = (x, i)       # |

    if x > topk[1]:            # +1
        topk[k-1] = topk[k-2]  # +1 *(k-1)
        ...                    # |
        topk[1] = (x, i)       # |

    ...

    if x > topk[k-1]:          # +1
        topk[k-1] = (x, i)     # +1

# Radix select (base 2)
partition = 0
partition_mask = 0
for r in range(log2(N)-1, -1, -1):  # log2(N) bits are needed to find the partition
    counts = [0, 0]
    mask = 1 << r
    for i, x in data:
        if x & partition_mask == partition:  # +2
            bit = x & mask                   # +1
            if not bit: counts[0] += 1       # +3
            if bit: counts[1] += 1           # +2
    partition_mask |= mask
    partition |= update_partition(counts, mask, ...)

# Filter scan (a second scan may be needed for ties)
ptr = 0
for i, x in data:
    if x > partition:      # +1
        out[ptr] = (x, i)  # +1
        ptr += 1           # +1
```
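
The radix-select pseudocode above leaves `update_partition` elided. The following runnable sketch shows one way the pieces could fit together for non-negative integers. The function name `radix_select_topk`, the `bits` parameter, the `remaining` counter, and the update rule are all assumptions made for illustration, not the author's implementation.

```python
def radix_select_topk(data: list[int], k: int, bits: int) -> list[tuple[int, int]]:
    """Return the k largest (value, index) pairs; values must lie in [0, 2**bits)."""
    assert 0 < k <= len(data)
    prefix = 0       # high bits of the k-th largest value decided so far
    prefix_mask = 0  # which bit positions of `prefix` are already decided
    remaining = k    # rank of the target among elements matching the prefix
    for r in range(bits - 1, -1, -1):
        mask = 1 << r
        # Count elements that match the prefix and have bit r set.
        ones = sum(1 for x in data if (x & prefix_mask) == prefix and (x & mask))
        if ones >= remaining:
            prefix |= mask     # the k-th largest value has bit r set
        else:
            remaining -= ones  # all elements with bit r set rank above the target
        prefix_mask |= mask
    threshold = prefix  # value of the k-th largest element
    # First scan: everything strictly above the threshold.
    out = [(x, i) for i, x in enumerate(data) if x > threshold]
    # Second scan: fill the remaining slots with ties at the threshold.
    ties = [(x, i) for i, x in enumerate(data) if x == threshold]
    return out + ties[: k - len(out)]


top3 = radix_select_topk([5, 1, 9, 3, 7, 2, 8, 4], k=3, bits=4)
```

Unlike the pseudocode, this sketch counts only the set bits and derives the other count implicitly, so its op count does not match the cost model exactly.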

---

## Tests (WIP)

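```python
from collections import Counter
from math import log2
import random

import scipy.stats


def recall_model(K: int, L: int, J: int) -> float:
    """Expected recall, treating the bucket of each top-K element as uniformly random."""
    # The i-th best element (0-indexed) survives if fewer than J of the i better
    # elements share its bucket; the first J elements always survive.
    return (J + sum(scipy.stats.binom.cdf(J - 1, i, p=1 / L) for i in range(J, K))) / K


def recall_simulation(K: int, L: int, J: int, reps: int) -> float:
    """Monte Carlo estimate of recall: drop the top-K into L random buckets, keep J per bucket."""
    recalls = []
    for _ in range(reps):
        counts = Counter(random.randint(0, L - 1) for _ in range(K))
        recalls.append(sum(min(v, J) for v in counts.values()) / K)
    return sum(recalls) / len(recalls)


def ops_insertion(K: int, N: int, B: int) -> int:
    """Total ops for insertion-sort top-K: B * N * K * (K + 3) / 2."""
    return B * N * (K * (K + 3)) // 2


def ops_radix(K: int, N: int, B: int) -> float:
    """Total ops for radix select + filter: B * N * (8 * log2(N) + 3); K is unused."""
    return B * N * (8 * log2(N) + 3)


def ops_topk(K: int, N: int, B: int) -> float:
    """Cost of exact top-K, using whichever algorithm is cheaper."""
    return min(ops_insertion(K=K, N=N, B=B), ops_radix(K=K, N=N, B=B))


def ops_approx_topk(K: int, N: int, L: int, J: int) -> float:
    """Cost of the two-stage method: top-J in each of L buckets, then top-K of the candidates."""
    stage1 = ops_topk(K=J, N=N // L, B=L)
    # The second stage is only needed when more than K candidates survive.
    stage2 = ops_topk(K=K, N=J * L, B=1) if J * L > K else 0
    return stage1 + stage2


# N = 2**20
# K = 1024
N = 16384
K = 1024
# Report ops per input element.
print("      exact ", ops_topk(K=K, N=N, B=1) / N)
print("   ours J=2 ", ops_approx_topk(K=K, N=N, L=K // 2, J=2) / N)
print("theirs L=2K ", ops_approx_topk(K=K, N=N, L=2 * K, J=1) / N)

# for (K, L, J) in [(100, 200, 1), (100, 400, 1), (100, 100, 2)]:
#     print(K, L, J, recall_model(K, L, J), recall_simulation(K, L, J, reps=1000))
```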

Output (ops per input element):

```
      exact  115.0
   ours J=2  5.0
theirs L=2K  13.375
```
