# 11 Batched Top-K with exclusions + stable tie-breaking

Theme: topk, masking, tie-breaking, indexing

## Prompt

You have model scores for candidate items, but some items are not allowed.

- scores: float tensor [B, N]

- allowed: bool tensor [B, N] (True means selectable)

- k: int, 1 <= k <= N

- Guarantee: each row has at least k allowed items.

## Tie-breaking rule

If two items have equal score, choose the smaller index first (stable / deterministic).

## Requirements

- No Python loops over B or N.

- Must not accidentally pick disallowed items.

- Must implement deterministic tie-breaking.

## Why it's good

Candidates often do scores[~allowed] = -inf and topk, but:

- they might mutate input unintentionally

- tie-breaking is subtle (many topk implementations are not stable)

- broadcasting mistakes are common

## Optional extension (very ML-ish)

Return one_hot: [B, N] with True for selected items.

## Implement Function

In [None]:
import numpy as np

def select_topk(scores, allowed, k, **kwargs):
    """
    Pick the k highest-scoring allowed items in every row.

    Equal scores are resolved in favour of the smaller index. The inputs are
    never modified.

    Args:
      scores:  [B, N] float array.
      allowed: bool array broadcastable to [B, N]; True means selectable.
      k: number of items to pick per row; 1 <= k <= N is assumed, and each
         row is assumed to contain at least k allowed items (otherwise the
         tail of that row is filled with disallowed items).
      **kwargs: ignored, so a whole test-case dict can be splatted in.

    Returns:
      dict with
        topk_idx:    [B, k] int64 indices of the selected items, best first
        topk_scores: [B, k] the matching scores, same dtype as `scores`
    """
    scores = np.asarray(scores)
    blocked = ~np.broadcast_to(np.asarray(allowed, dtype=bool), scores.shape)

    # Pass 1: descending score. A stable sort keeps equal scores in
    # ascending index order, which is exactly the required tie-break.
    by_score = np.argsort(-scores, axis=1, kind="stable")

    # Pass 2: stably move allowed items in front of disallowed ones. Ranking
    # on the flag itself (instead of overwriting scores with -inf) keeps an
    # allowed item whose real score is -inf ahead of every disallowed item.
    blocked_sorted = np.take_along_axis(blocked, by_score, axis=1)
    regroup = np.argsort(blocked_sorted, axis=1, kind="stable")
    ranked = np.take_along_axis(by_score, regroup, axis=1)

    topk_idx = ranked[:, :k].astype(np.int64)
    topk_scores = np.take_along_axis(scores, topk_idx, axis=1)

    return {"topk_idx": topk_idx, "topk_scores": topk_scores}

## Test Cases

In [5]:
def _case(scores, allowed, k, topk_idx, topk_scores):
    """Pack one scenario into the dict layout the checker cell consumes."""
    return {
        "scores": np.array(scores, dtype=np.float32),
        "allowed": np.array(allowed, dtype=bool),
        "k": k,
        "output": {
            "topk_idx": np.array(topk_idx, dtype=np.int64),
            "topk_scores": np.array(topk_scores, dtype=np.float32),
        },
    }


test_cases = [
    # 1) Basic: everything allowed, no ties.
    _case(
        scores=[[0.1, 0.4, 0.3, 0.2], [5.0, 1.0, 3.0, 2.0]],
        allowed=[[True, True, True, True], [True, True, True, True]],
        k=2,
        topk_idx=[[1, 2], [0, 2]],
        topk_scores=[[0.4, 0.3], [5.0, 3.0]],
    ),
    # 2) Masking: the highest score is disallowed and must be skipped.
    _case(
        scores=[[10.0, 9.0, 8.0, 7.0]],
        allowed=[[False, True, True, True]],
        k=2,
        topk_idx=[[1, 2]],
        topk_scores=[[9.0, 8.0]],
    ),
    # 3) Tie-breaking among allowed items: equal scores -> smaller index first.
    #    Scores are 1.0 at indices 0, 1, 2 and 0.5 at index 3; take the top 3.
    _case(
        scores=[[1.0, 1.0, 1.0, 0.5]],
        allowed=[[True, True, True, True]],
        k=3,
        topk_idx=[[0, 1, 2]],
        topk_scores=[[1.0, 1.0, 1.0]],
    ),
    # 4) Tie-breaking with disallowed items present.
    #    Allowed indices {1, 2, 4} all score 2.0 -> pick 1, then 2.
    _case(
        scores=[[100.0, 2.0, 2.0, 50.0, 2.0]],
        allowed=[[False, True, True, False, True]],
        k=2,
        topk_idx=[[1, 2]],
        topk_scores=[[2.0, 2.0]],
    ),
    # 5) Negative scores: masking must not introduce a "0 wins the max" bug.
    _case(
        scores=[[-1.0, -3.0, -2.0, -4.0], [-10.0, -1.0, -1.0, -2.0]],
        allowed=[[True, True, True, True], [True, True, True, True]],
        k=2,
        topk_idx=[[0, 2], [1, 2]],
        topk_scores=[[-1.0, -2.0], [-1.0, -1.0]],
    ),
    # 6) k=1 (argmax with exclusions).
    #    Row 0: best allowed is index 3 (score 3.0).
    #    Row 1: best allowed is index 1 (score 8.0).
    _case(
        scores=[[0.0, 1.0, 2.0, 3.0], [9.0, 8.0, 7.0, 6.0]],
        allowed=[[True, False, True, True], [False, True, True, True]],
        k=1,
        topk_idx=[[3], [1]],
        topk_scores=[[3.0], [8.0]],
    ),
    # 7) Stress: mixed batch, mixed allowed sets, ties within a row.
    #    Row 0: allowed {0, 1, 2, 3}; 5.0 is tied at 0, 1, 3 -> 0, 1, 3.
    #    Row 1: allowed {0, 2, 3, 4} with scores {1, 3, 4, 5} -> 4, 3, 2.
    _case(
        scores=[[5.0, 5.0, 4.9, 5.0, 1.0], [1.0, 2.0, 3.0, 4.0, 5.0]],
        allowed=[[True, True, True, True, False], [True, False, True, True, True]],
        k=3,
        topk_idx=[[0, 1, 3], [4, 3, 2]],
        topk_scores=[[5.0, 5.0, 5.0], [5.0, 4.0, 3.0]],
    ),
]


## Check Test Cases

In [65]:
from tqdm import tqdm

# Run every case through the implementation under test and tally the cases
# whose outputs match the expected arrays exactly.
test_func = select_topk
error_list = []
correct = 0
num = len(test_cases)
for t in tqdm(test_cases):
    # The whole case dict is splatted in, so test_func also receives the
    # 'output' entry and has to tolerate extra keyword arguments.
    result = test_func(**t)
    expected = t['output']
    # Iterate over the *expected* keys so an empty or partial result cannot
    # pass vacuously; array_equal also compares shapes, so a wrongly shaped
    # result cannot slip through via broadcasting.
    if all(
        key in result and np.array_equal(result[key], expected[key])
        for key in expected
    ):
        correct += 1
    else:
        error_list.append(f'Received {result}\non input\n{t}.\n\n')

print(f'Test cases passed: {correct}/{num}.\n')

if correct == num:
    print('Success!')
else:
    for e in error_list:
        print(e)

7it [00:00, 3456.17it/s]

Test cases passed: 7/7.

Success!





## Solutions

My solution (2026-02-04):

```python
def select_topk(scores, allowed, k, **kwargs):
    """
    Greedy top-k: take the best remaining allowed item per row, k times.

    Only the k picks are looped over (never B or N). np.argmax returns the
    first maximum, which gives the "smaller index wins" tie-break. The inputs
    are never modified.

    Args:
      scores:  [B, N] float array.
      allowed: bool array broadcastable to [B, N]; True means selectable.
      k: number of picks per row; 1 <= k <= N is assumed, with at least k
         allowed items in every row.
      **kwargs: ignored, so a whole test-case dict can be splatted in.

    Returns:
      topk_idx: [B, k]  # int64 indices of selected items, best first
      topk_scores: [B, k]  # matching scores, same dtype as `scores`
    """
    scores = np.asarray(scores)
    B, N = scores.shape
    rows = np.arange(B)

    # Private, writable copy of the mask that tracks what can still be picked.
    remaining = np.array(np.broadcast_to(allowed, scores.shape), dtype=bool)

    # Explicit dtypes: indices are integers, scores keep the input precision.
    topk_idx = np.zeros(shape=(B, k), dtype=np.int64)
    topk_scores = np.zeros(shape=(B, k), dtype=scores.dtype)
    for i in range(k):
        masked = np.where(remaining, scores, -np.inf)
        kth_idx = masked.argmax(axis=1)
        # When the best remaining score is itself -inf, argmax can land on a
        # masked slot; fall back to the first index that is still selectable.
        kth_idx = np.where(
            remaining[rows, kth_idx], kth_idx, remaining.argmax(axis=1)
        )
        topk_idx[:, i] = kth_idx
        topk_scores[:, i] = scores[rows, kth_idx]
        # Retire the pick through the mask instead of a -inf sentinel, so a
        # genuine -inf score is never mistaken for "already taken".
        remaining[rows, kth_idx] = False

    return {"topk_idx": topk_idx, "topk_scores": topk_scores}
```

This is what ChatGPT came up with:

```python
def select_topk(scores: np.ndarray, allowed: np.ndarray, k: int, **kwargs):
    """
    Vectorized top-k over allowed items with deterministic tie-breaking.

    scores:  [B, N] float array (not modified)
    allowed: bool array broadcastable to [B, N]; True means selectable
    k:       number of items per row; 1 <= k <= N is assumed, with at least
             k allowed items in every row
    kwargs:  ignored, so the checker can splat a whole test-case dict
             (including its 'output' entry) into the call

    returns:
      topk_idx:    [B, k] int64, best first, equal scores by ascending index
      topk_scores: [B, k] same dtype as scores
    """
    scores = np.asarray(scores)
    B, N = scores.shape

    # 1 for disallowed, 0 for allowed, so allowed items sort first.
    blocked = (~np.broadcast_to(np.asarray(allowed, dtype=bool), (B, N))).astype(np.int8)

    # Indices 0..N-1 broadcast to [B, N]
    idx = np.broadcast_to(np.arange(N), (B, N))

    # lexsort treats the *last* key as the primary one:
    # - primary:   blocked (allowed before disallowed)
    # - secondary: -scores (descending score)
    # - tertiary:  idx     (ascending index on ties)
    # Sorting on the flag rather than on scores overwritten with -inf keeps
    # an allowed item whose real score is -inf ahead of disallowed items.
    order = np.lexsort((idx, -scores, blocked), axis=1)  # [B, N]
    topk_idx = order[:, :k].astype(np.int64)            # [B, k]

    topk_scores = scores[np.arange(B)[:, None], topk_idx]  # [B, k]

    return {"topk_idx": topk_idx,
            "topk_scores": topk_scores.astype(scores.dtype)}
```