<a href="https://colab.research.google.com/github/oliverdutton/fast_exact_topk_tpu/blob/main/Fast_exact_topk_tpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# !pip install tensorboard tensorboard-plugin-profile
# Pip install will require a restart, then comment the code


import functools
import math

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def block_topk(logits, k, topk_val=None, topk_index=None, start_k=0, num_lanes=128, mode='jax'):
  """Per-lane top-k of `logits` via an unrolled sinking sort.

  The last axis of `logits` is consumed in consecutive blocks of `num_lanes`
  columns. Lane `j` therefore sees columns `j, j + num_lanes, ...`, and for
  each lane the `k` largest values are kept sorted by depth (depth 0 is the
  largest) together with the index of the block they came from.

  Args:
    logits: `(ntokens, vocab)` array (`mode='jax'`) or a Pallas ref
      (`mode='pallas'`). `vocab` must be a multiple of `num_lanes`.
    k: number of depths to keep per lane.
    topk_val: list of `(ntokens, num_lanes)` value arrays, one per depth.
      Required for `mode='pallas'`; re-initialised to `-inf` for `mode='jax'`.
    topk_index: list of `(ntokens, num_lanes)` int32 block-index arrays, one
      per depth. Required for `mode='pallas'`; re-initialised to `-1` for
      `mode='jax'`.
    start_k: depths `< start_k` are treated as already final: they are not
      modified, and elements they already hold are masked out so they cannot
      be inserted a second time at a deeper level.
    num_lanes: width of each block of columns.
    mode: `'jax'` (array input) or `'pallas'` (memory ref input).

  Returns:
    `(topk_val, topk_index)` lists holding the updated per-depth arrays.

  Raises:
    ValueError: on an unknown `mode`, on a vocab size that is not a multiple
      of `num_lanes`, or when the carry lists are missing in pallas mode.
  """
  if mode not in ('jax', 'pallas'):
    raise ValueError("mode must be either `pallas` and a memory ref or `jax` and an array")

  ntokens = logits.shape[0]
  vocab_size = logits.shape[-1]
  if vocab_size % num_lanes != 0:
    # The loop below only visits whole blocks; a ragged tail would be
    # silently ignored and could hide true top-k elements.
    raise ValueError(
        f"vocab size ({vocab_size}) must be a multiple of num_lanes ({num_lanes})")

  if mode == 'jax':
    topk_val = [jnp.full((ntokens, num_lanes), float('-inf'), dtype=logits.dtype) for _ in range(k)]
    # TODO?: Could use uint16 when vocab size < 4M if hardware supports
    topk_index = [jnp.full((ntokens, num_lanes), -1, dtype=jnp.int32) for _ in range(k)]
  elif topk_val is None or topk_index is None:
    raise ValueError("Pass through of topk_val and topk_index expected for pallas topk.")

  def while_body(i, while_carry):
    topk_val, topk_index = while_carry

    # Load block `i` of the vocab axis: one candidate per lane.
    if mode == 'pallas':
      vals_carry = logits[..., pl.dslice(num_lanes*i, num_lanes)]
    else:
      vals_carry = jax.lax.dynamic_slice_in_dim(logits, num_lanes*i, num_lanes, axis=1)

    index_carry = jnp.full((ntokens, num_lanes), i, jnp.int32)

    for depth in range(k):
      if depth < start_k:
        # Nothing will be exchanged into the completed block topk, we just need
        # to invalidate it from flowing downward. So we check if it's already
        # found and invalidate if so.
        vals_carry = jnp.where(index_carry == topk_index[depth], float('-inf'), vals_carry)
      else:
        # Sinking sort: where the candidate beats the stored value, swap them;
        # the loser keeps sinking to the next depth.
        mask = vals_carry > topk_val[depth]
        # TODO: Consider packing bfloat16 val and uint16 index into single uint32 and packed sort as in https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/topk_details/_topk_forward.py
        kept_val = jnp.where(mask, vals_carry, topk_val[depth])
        vals_carry = jnp.where(mask, topk_val[depth], vals_carry)
        topk_val[depth] = kept_val
        kept_index = jnp.where(mask, index_carry, topk_index[depth])
        index_carry = jnp.where(mask, topk_index[depth], index_carry)
        topk_index[depth] = kept_index
    return (topk_val, topk_index)

  return jax.lax.fori_loop(
      0,
      vocab_size // num_lanes,
      while_body,
      (topk_val, topk_index)
  )

# Pallas kernel scaffolding
def block_topk_kernel(logits_ref, topk_val_refs, topk_index_refs, max_depth_ref, flag_ref, num_lanes=128, k=64, depth_schedule=None):
  """Pallas kernel: grow a per-lane block top-k until it provably holds the top-k.

  Args:
    logits_ref: input ref for one block of tokens, `(block_token, vocab)`.
    topk_val_refs: list of output refs, one `(block_token, num_lanes)` array
      of values per depth.
    topk_index_refs: list of output refs with the matching int32 block indices.
    max_depth_ref: int32 output ref indexed by global token id. It receives
      `depth - 1` for the first schedule step whose stopping criterion holds
      for that token, and stays `-1` if the criterion never holds.
    flag_ref: one-element int32 scratch ref; non-zero once every token handled
      by this program has met the stopping criterion.
    num_lanes: lane width, forwarded to `block_topk`.
    k: size of the overall top-k being searched for.
    depth_schedule: increasing depths to try; defaults to `(0, 5, 8, 12, k)`.
      Each consecutive pair `(completed_depth, depth)` extends the sorted
      depths from `completed_depth` up to `depth`.
  """
  ### Initialize refs
  shape = topk_val_refs[0].shape
  for val_ref, index_ref in zip(topk_val_refs, topk_index_refs):
    val_ref[...] = jnp.full(shape, float('-inf'), dtype=logits_ref.dtype)
    index_ref[...] = jnp.full(shape, -1, dtype=jnp.int32)

  block_token = logits_ref.shape[0]
  for i in range(block_token):
    max_depth_ref[pl.program_id(0) * block_token + i] = -1

  # flag for termination of while loop
  flag_ref[0] = 0

  ### Run increasing block topk, until sure overall topk present
  if depth_schedule is None:
    depth_schedule = (0, 5, 8, 12, k)

  for completed_depth, depth in zip(depth_schedule, depth_schedule[1:]):
    @pl.when(flag_ref[0] == 0)
    def _():

      topk_vals, topk_indexs = block_topk(
          logits_ref,
          # bf16, bf16 -> i1 mask not supported on v5e so we cast to f32
          # TODO: check v6e, bf16 comparator and make model specific
          topk_val=jax.tree.map(lambda ref: ref[...].astype(jnp.float32), topk_val_refs),
          topk_index=jax.tree.map(lambda ref: ref[...], topk_index_refs),
          k=depth,
          start_k=completed_depth,
          # Must match the lane width of the refs; without this block_topk
          # would fall back to its default of 128.
          num_lanes=num_lanes,
          mode='pallas',
      )

      # Only the newly sorted depths changed; earlier ones are already stored.
      for i in range(completed_depth, depth):
        topk_val_refs[i][...] = topk_vals[i].astype(topk_val_refs[i].dtype)
        topk_index_refs[i][...] = topk_indexs[i].astype(topk_index_refs[i].dtype)

      # Stopping criterion check
      # To find the top-k values of a set, we can split it into N subsets and
      # sort the largest, 2nd-largest, ..., m-th largest value of each subset.
      # Every element outside the per-subset top-(m-1) is no larger than the
      # largest of the m-th largest values (the pivot). So once the
      # top-(m-1) rows hold more than k values >= pivot, those rows must
      # contain the top-k of the whole set.
      # We run a schedule of m's until the full top-k is found.
      pivot_point = topk_vals[depth-1].max(-1, keepdims=True)
      n_larger = sum(
          [(v >= pivot_point) for v in topk_vals[:depth-1]]
      ).astype(jnp.float32).sum(-1)
      # flag SMEM used to check if all searches terminated
      flag_ref[0] = 0
      for i in range(block_token):
        topk_all_present = n_larger[i] > k
        flag_ref[0] += topk_all_present
        # Store when the criterion was first hit for each query
        token_index = pl.program_id(0) * block_token + i
        block_topk_depth = max_depth_ref[token_index]
        max_depth_ref[token_index] = jnp.where(
            topk_all_present & (block_topk_depth == -1),
            depth - 1,
            block_topk_depth)

      # If not all terminated, reset the flag to say we need to search deeper
      @pl.when(flag_ref[0] != block_token)
      def _():
        flag_ref[0] = 0

# Pallas function
def block_topk_pallas(logits, k, num_lanes=128, block_token=None, depth_schedule=None):
  """Launch `block_topk_kernel` over blocks of tokens.

  Args:
    logits: `(num_tokens, vocab_size)` array.
    k: size of the top-k; also the number of depth arrays allocated.
    num_lanes: lane width of every per-depth output array.
    block_token: tokens handled per grid step; defaults to
      `min(32, num_tokens)`. `num_tokens` must be a multiple of it.
    depth_schedule: optional schedule forwarded to the kernel.

  Returns:
    A 4-tuple `(topk_val, topk_index, max_depth, flag)`:
    two lists of `k` arrays of shape `(num_tokens, num_lanes)` (values in the
    logits dtype, int32 block indices), an int32 `(num_tokens,)` array of the
    depth at which each token's search was certified, and the `(1,)` int32
    scratch flag.

  Raises:
    ValueError: if `num_tokens` is not a multiple of `block_token`.
  """
  num_tokens, vocab_size = logits.shape
  if block_token is None:
    block_token = min(32, num_tokens)
  if num_tokens % block_token != 0:
    raise ValueError(
        f'num tokens ({num_tokens}) must be a multiple of token block size ({block_token})')

  out_shape = (
          [jax.ShapeDtypeStruct((num_tokens, num_lanes), logits.dtype) for _ in range(k)],
          [jax.ShapeDtypeStruct((num_tokens, num_lanes), jnp.int32) for _ in range(k)], # uint16 fits vocab size of up to 2**16 * 128 = 8.4M. But not used to avoid unforeseen issues.
          jax.ShapeDtypeStruct((num_tokens,), jnp.int32), # block_topk required to be certain to contain topk vals
          jax.ShapeDtypeStruct((1,), jnp.int32), # scratch for stopping boolean
  )

  def lane_block_spec():
    # One (block_token, num_lanes) tile per grid step along the token axis.
    return pl.BlockSpec((block_token, num_lanes), lambda i: (i, 0))

  out_specs = (
      [lane_block_spec() for _ in range(k)],
      [lane_block_spec() for _ in range(k)],
      # Scalar bookkeeping lives in SMEM and is not tiled over the grid.
      pl.BlockSpec(memory_space=pltpu.SMEM),
      pl.BlockSpec(memory_space=pltpu.SMEM),
  )
  return pl.pallas_call(
      functools.partial(
          block_topk_kernel,
          k=k,
          num_lanes=num_lanes,
          depth_schedule=depth_schedule,
      ),
      in_specs=(
          pl.BlockSpec((block_token, vocab_size), lambda i: (i, 0)),
      ),
      out_shape=out_shape,
      grid=(num_tokens // block_token,),
      out_specs=out_specs,
      debug=False,
      # NOTE(review): newer JAX releases appear to rename this class to
      # `pltpu.CompilerParams` - confirm against the pinned JAX version.
      compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=2**26),
  )(logits)

def topk_on_filtered_subset(topk_val, topk_index, k):
  """Exact top-k over the candidates kept by the block top-k.

  Args:
    topk_val: list of `(ntokens, num_lanes)` value arrays, one per depth.
    topk_index: list of matching `(ntokens, num_lanes)` block-index arrays.
    k: number of results per token.

  Returns:
    `(topk_logits, topk_flat_indices)`, both `(ntokens, k)`. The indices are
    positions along the original vocab axis, `block_index * num_lanes + lane`.
  """
  num_lanes = topk_val[0].shape[-1]
  # Depth arrays are laid side by side, so candidate `j` of the concatenation
  # sits at depth `j // num_lanes`, lane `j % num_lanes`.
  topk_logits, local_indices = jax.lax.top_k(
      jnp.concatenate(topk_val, axis=-1),
      k=k
  )

  @jax.vmap
  def unravel_indices(local_indices, topk_index):
    # Decode with the real layout rather than `(k, num_lanes)`:
    # `jnp.unravel_index` clips out-of-range inputs, which would corrupt the
    # result whenever more than `k` depth rows are passed in.
    depth = local_indices // num_lanes
    col = local_indices % num_lanes
    row = jnp.stack(topk_index)[depth, col]
    return row * num_lanes + col

  topk_flat_indices = unravel_indices(local_indices, topk_index)
  return topk_logits, topk_flat_indices


@functools.partial(jax.jit, static_argnames=('k', 'base_cutoff', 'block_token', 'depth_schedule', 'num_lanes'))
def topk_optimized(logits, k=64, base_cutoff=8, block_token=None, depth_schedule=None, num_lanes=128):
  """Exact top-k of `logits` along the last axis.

  A Pallas kernel first computes a per-lane block top-k and, per token, the
  depth of that block top-k known to contain the overall top-k. A regular
  `top_k` is then run on only that many depth rows.

  Args:
    logits: `(num_tokens, vocab_size)` array.
    k: number of results per token.
    base_cutoff: smallest number of depth rows the final `top_k` is compiled
      for; further variants double it until `k` is reached.
    block_token: tokens per kernel grid step (see `block_topk_pallas`).
    depth_schedule: optional depth schedule forwarded to the kernel.
    num_lanes: lane width used by the kernel.

  Returns:
    `(topk_logits, topk_flat_indices)`, both `(num_tokens, k)`.

  Raises:
    ValueError: if `base_cutoff` is not positive.
  """
  if base_cutoff < 1:
    raise ValueError('base_cutoff must be a positive integer')

  topk_val, topk_index, depths, _ = block_topk_pallas(logits, k=k, block_token=block_token, depth_schedule=depth_schedule, num_lanes=num_lanes)

  # The kernel leaves -1 for tokens whose stopping criterion never fired.
  # Those tokens need every depth row; without this they would be ignored by
  # the max below (or, if all are -1, no branch would run at all).
  depths = jnp.where(depths < 0, k, depths)
  # top-k the smallest number of values we can, by taking max depth required
  # such that all queries in logits are guaranteed to have top-k
  max_depth = depths.max()

  # We compile for a range of shapes, then use jax.lax.switch to run just one.
  # in practice 8 nearly always sufficient
  cutoffs = []
  cutoff = base_cutoff
  while cutoff < k:
    # Skip variants holding fewer than k candidates; top_k could not run.
    if cutoff * num_lanes >= k:
      cutoffs.append(cutoff)
    cutoff *= 2
  cutoffs.append(k)

  # Index of the smallest cutoff that is >= max_depth.
  branch_index = jnp.sum(max_depth > jnp.asarray(cutoffs[:-1], dtype=jnp.int32))
  branches = [
      functools.partial(topk_on_filtered_subset, topk_val[:cutoff], topk_index[:cutoff], k)
      for cutoff in cutoffs
  ]
  return jax.lax.switch(branch_index, branches)

In [None]:
k = 64
num_queries = 32
vocab_size = 201088
logits = jax.random.normal(jax.random.key(7), (num_queries, vocab_size), dtype=jnp.float32).astype(jnp.bfloat16)

# Run each implementation once and compare against the XLA reference.
reference_values, reference_indices = jax.lax.top_k(logits, k)
optimized_values, optimized_indices = topk_optimized(logits, k=k)
all_values_match = (reference_values == optimized_values).all()
exact_index_match = (reference_indices == optimized_indices).mean()
print(f'''All topk_logits match = {all_values_match}. Indices match at {exact_index_match:.0%} of the time,
this is only O(40%) [not 100%] as bf16 has only 2**16=65k possible values so in 200k vocab size
theres high degeneracy and sorting is different. Having checked, it looks correct. Writing full checks would be tricky.''')

# Tie-independent index check: every returned index must point at the value
# that was returned for it, even when tied elements are ordered differently.
indices_consistent = (jnp.take_along_axis(logits, optimized_indices, axis=-1) == optimized_values).all()
print(f'Returned indices point at the returned values = {indices_consistent}')

In [None]:
import functools
import jax
import jax.numpy as jnp

# Benchmark problem size.
k = 64
num_queries = 32
vocab_size = 201088
hidden_dim = 2880

# One independent PRNG key per random tensor; reusing a key would make the
# draws correlated.
logit_key, key_act, key_weight = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(key_act, (num_queries, hidden_dim), dtype=jnp.bfloat16)
w = jax.random.normal(key_weight, (hidden_dim, vocab_size), dtype=jnp.bfloat16)
logits = jax.random.normal(logit_key, (num_queries, vocab_size), dtype=jnp.float32).astype(jnp.bfloat16)

# Jitted XLA baselines: exact top-k and approximate max-k.
topk_xla = jax.jit(jax.lax.top_k, static_argnames=('k',))
approx_topk_xla = jax.jit(jax.lax.approx_max_k, static_argnames=('k',))

@jax.jit
@functools.partial(jax.vmap, in_axes=(0, None))
def matmul_and_topk_xla(x, w, k=k):
  """Project one activation row through `w` and take the XLA top-k of the result.

  Mapped over the leading axis of `x` with `w` shared, then jitted.
  """
  return jax.lax.top_k(jnp.matmul(x, w), k)

def run():
  """Execute every benchmarked op once, blocking on each result in turn."""
  wait = jax.block_until_ready

  # reference runtimes
  wait(x @ w)
  wait(matmul_and_topk_xla(x, w))
  wait(topk_xla(logits, k=k))

  # optimized tpu run
  wait(topk_optimized(logits, k=k))

  wait(approx_topk_xla(logits, k=k))


# First call runs outside the profiler (presumably a warm-up so that jit
# compilation is not part of the recorded trace).
run()
# Second call is recorded by the JAX profiler into /tmp/tensorboard.
with jax.profiler.trace("/tmp/tensorboard"):
  run()


In [None]:
# 32 tokens, top-64, v6e: topk_xla 1.32ms, topk_pallas_tpu 0.120ms (11x faster)
# 2048 tokens, top-64, v6e: topk_xla 87.25ms, topk_pallas_tpu 7.23ms (12x faster)


In [None]:
# !pip install tensorboard tensorboard-plugin-profile
%load_ext tensorboard
%tensorboard --logdir=/tmp/tensorboard

In [None]:
# Track how many iters you expect it to take
# stopping_probs = (logits.reshape(logits.shape[0], -1, 128).sort(1).swapaxes(0,1)[::-1].max(-1) < jax.lax.top_k(logits, 64)[0].min(-1)).sum(-1)[:8] / logits.shape[0]

In [None]:
# Some maths calculating expectations of how far in subset top-m you need
# to go to have filtered out all overall top-k values
# Beware, this code is slow (as you might expect), so precomputed vals below


# Each next-largest value is the top-most value not yet taken from one of the 128 lanes,
# e.g. the second value taken is either one of the 127 other top-1s or the 2nd largest of the already-taken lane,
# and so on.

def add_one(x, i):
  """Return a copy of tuple `x` with position `i` incremented.

  When `i` is one past the end, a new trailing entry of 1 is appended instead.
  """
  # A position past the end starts a new entry at 1; otherwise bump in place.
  bumped = 1 if i == len(x) else x[i] + 1
  return x[:i] + (bumped,) + x[i+1:]

def compute_termination_probabilities(k=64, num_lanes=128):
  """Distribution of the per-lane depth needed to cover the overall top-k.

  Values are taken one at a time, and each of the `num_lanes` lanes is treated
  as equally likely to supply the next one. A state is a tuple where
  `state[j]` is the number of lanes that have supplied at least `j + 1`
  values, so `len(state)` is the deepest any lane has been drawn from.

  Args:
    k: number of values taken overall.
    num_lanes: number of lanes the values are spread over.

  Returns:
    A list of length `k`; entry `m - 1` is the probability that the deepest
    lane supplied exactly `m` of the `k` values.

  Raises:
    ValueError: if `k < 1`.
  """
  if k < 1:
    raise ValueError('k must be at least 1')

  parent_states = {
      (1,): 1.0, # only non zero values are in the hash
  }
  for _ in range(k - 1):
    child_states = {}
    for state, parent_prob in parent_states.items():
      for i in range(len(state) + 1):
        # Fraction of lanes that have supplied exactly `i` values so far.
        if i == 0:
          lane_fraction = (1 - state[0] / num_lanes)
        elif i == len(state):
          lane_fraction = state[i-1] / num_lanes
        else:
          lane_fraction = (state[i-1] - state[i]) / num_lanes
        if lane_fraction > 0:
          child = add_one(state, i)
          child_states[child] = child_states.get(child, 0) + lane_fraction * parent_prob
    parent_states = child_states

  # Read the final generation from `parent_states`: it is defined even when
  # the loop above did not run (k == 1).
  probs = [0 for _ in range(k)]
  for state, p in parent_states.items():
    probs[len(state)-1] += p
  return probs

# Generated with: compute_termination_probabilities(k=64, num_lanes=128)
# Precomputed probs for k=64 and 128 lanes (random subsets being sorted).
# Entry m-1 is the probability that the deepest lane supplies exactly m values.
# NOTE(review): jnp.array presumably stores these as float32 unless x64 is
# enabled, in which case the smallest tail entries underflow to 0 - confirm
# if the tail values matter.
probs = jnp.array([
4.1812404072740824e-09,
0.13436890728003542,
0.673656436392413,
0.17274256944422797,
0.01775253655062719,
0.0013838922281360039,
9.032679548553634e-05,
5.067660246625492e-06,
2.4828455416721553e-07,
1.075248105864657e-08,
4.156291895201961e-10,
1.4454295961378757e-11,
4.552534161066638e-13,
1.3058450068301742e-14,
3.427414716089682e-16,
8.264927218917032e-18,
1.8375011880871763e-19,
3.77788958180651e-21,
7.201944499092436e-23,
1.2759350490518157e-24,
2.1050297022227087e-26,
3.239666327687077e-28,
4.658198759426817e-30,
6.265949774819566e-32,
7.894109952528496e-34,
9.323751912435309e-36,
1.0332533469598683e-37,
1.0750948773204598e-39,
1.0508665648530184e-41,
9.653629860854444e-44,
8.33689142161678e-46,
6.769621479167146e-48,
5.168883019168442e-50,
3.710870161978265e-52,
2.5045242937986133e-54,
1.5886090227506496e-56,
9.466067809537837e-59,
5.295976602932471e-61,
2.7800402115131096e-63,
1.368130025350939e-65,
6.305957481932515e-68,
2.719104276048889e-70,
1.0954091571704027e-72,
4.116605637182969e-75,
1.4406318940272857e-77,
4.6853827433273565e-80,
1.412914883228222e-82,
3.9402153895800056e-85,
1.0130716090837228e-87,
2.393082541142658e-90,
5.1726347963559024e-93,
1.0182351961330517e-95,
1.81530565348338e-98,
2.9116888580223354e-101,
4.168487985715583e-104,
5.275083221518592e-107,
5.82962643626865e-110,
5.53996538879725e-113,
4.43611268287515e-116,
2.9108350937500984e-119,
1.502948286433509e-122,
5.7262444238005166e-126,
1.431382183177232e-129,
1.761050914342067e-133])

### Termination probs
block_token = 32 # algorithm batched to sort 32 tokens of subsets
# cumsum()[m-1] is the chance that one token needs at most depth m; raising it
# to block_token gives the chance that all tokens of a block do (treating
# tokens as independent).
probs.cumsum() ** block_token
# With the top-8 of each of the 128 lanes, a single token has a ~99.99997%
# chance of holding its full top-64; for all 32 tokens together that is
# ~99.999%.