In [8]:
import functools
import math

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def blockwise_topk(\n logits,\n k,\n block_topk_val=None,\n block_topk_index=None,\n start_k=0,\n num_blocks=128,\n mode="jax",\n):\n """Compute blockwise top-k."""\n ntokens = logits.shape[0]\n assert (logits.dtype==jnp.bfloat16) and (pl.cdiv(logits.shape[1], num_blocks) < jnp.iinfo(jnp.uint16).max)\n\n # pack\n # vals are in bf16, indexs in uint16\n block_topk_packeds = [[],[]]\n for i in range(k):\n   block_topk_packeds[0].append(\n      ((block_topk_val[i].view(jnp.int32) << 16) >> 16\n      ) & (block_topk_val[i].view(jnp.int32) >> 16)\n   )\n   block_topk_packeds[1].append(\n      ((block_topk_val[i].view(jnp.int32) >> 16\n      ) & ((block_topk_val[i].view(jnp.int32) << 16) >> 16)\n      )\n   )\n\n def while_body(i, while_carry):\n   block_topk_packed = while_carry\n\n   if mode == "pallas":\n     vals_carry = logits.bitcast(jnp.int32)[..., pl.dslice(num_blocks * i, num_blocks)]\n   elif mode == "jax":\n     vals_carry = jax.lax.dynamic_slice_in_dim(\n       logits, num_blocks * i, num_blocks, axis=1\n     ).view(jnp.int32)\n   else:\n     raise ValueError(\n       "mode must be either `pallas` and a memory ref or `jax` and an array"\n     )\n\n   index_carry = jnp.full((ntokens, num_blocks), i, jnp.int32)\n   left_bubble = ((vals_carry << 16) >> 16) & index_carry\n   right_bubble = (vals_carry >> 16) & index_carry\n\n   bubbles = [left_bubble, right_bubble]\n   for (bubble, block_topk_packed) in zip(bubbles, block_topk_packeds):\n     for j in range(k):\n       if j < start_k:\n         # Nothing will be exchanged into the completed block topk, we just need\n         # to invalidate it from flowing downward. 
So we check if it's already\n         # found and invalidate if so.\n         bubble = jnp.where(\n           bubble == block_topk_packed[j], float("-inf"), bubble\n         )\n       else:\n         # Sinking bubble sort\n         mask = bubble > block_topk_packed[j]\n         block_topk_packed[j], bubble = (\n           jnp.where(v, bubble , block_topk_packed[j]) for v in (mask, ~mask)\n         )\n    return block_topk_packeds\n  block_topk_packeds = jax.lax.fori_loop(0, logits.shape[-1] // num_blocks, while_body, block_topk_packeds )\n\n  # unpack\n  for i in range(start_k, k):\n    block_topk_val[i] = ((block_topk_packed[0][i] << 16) >> 16) & (block_topk_packed[1][i] << 16)\n    block_topk_index[i] = ((block_topk_packed[0][i] >> 16) & ((block_topk_packed[1][i] << 16) >> 16))\n    block_topk_val[i] = block_topk_val[i].view(jnp.bfloat16)\n    block_topk_index[i] = block_topk_index[i].view(jnp.uint16)\n  return block_topk_val, block_topk_induex\n  \n\n# Pallas kernel scaffolding\ndef topk_blockwise_superset_kernel(\n logits_ref,\n block_topm_val_refs,\n block_topm_index_refs,\n max_m_ref,\n flag_ref,\n num_blocks: int = 128,\n k: int = 64,\n m_schedule: tuple[int] | None = None,\n):\n """Compute blockwise top-m's until they contain global top-k."""\n ### Initialize refs\n shape = block_topm_val_refs[0].shape\n for i in range(len(block_topm_val_refs)):\n   block_topm_val_refs[i][...] = jnp.full(shape, float("-inf"), dtype=logits_ref.dtype)\n   block_topm_index_refs[i][...] 
= jnp.full(shape, jnp.iinfo(jnp.uint16).max, dtype=jnp.uint16)\n\n block_token = logits_ref.shape[0]\n for i in range(block_token):\n   # Worst case m = k\n   max_m_ref[pl.program_id(0) * block_token + i] = k\n\n # flag for termination of while loop\n flag_ref[0] = 0\n\n ### Run increasing block topk, until sure overall topk present\n if m_schedule is None:\n   m_schedule = (5, 8, 12)\n # Ensure worst case of all k in one block is covered\n m_schedule = (0,) + m_schedule + (k,)\n\n for completed_m, m in zip(m_schedule, m_schedule[1:]):\n\n   @pl.when(flag_ref[0] == 0)\n   def _():\n     topk_vals, topk_indexs = blockwise_topk(\n       logits_ref,\n       # bf16, bf16 -> i1 mask not supported on v5e so we cast to f32\n       # TODO: check v6e, bf16 comparitor and make model specific\n       block_topk_val=jax.tree.map(\n         lambda ref: ref[...], block_topm_val_refs\n       ),\n       block_topk_index=jax.tree.map(lambda ref: ref[...], block_topm_index_refs),\n       k=m,\n       start_k=completed_m,\n       mode="pallas",\n     )\n\n     for i in range(completed_m, m):\n       block_topm_val_refs[i][...] = topk_vals[i].astype(block_topm_val_refs[i].dtype)\n       block_topm_index_refs[i][...] 
= topk_indexs[i].astype(\n         block_topm_index_refs[i].dtype\n       )\n\n     # Stopping criterion check\n     # To find top-k values of a set, we can split into N subsets,\n     # and sort the largest, 2nd-largest, 3-rd largest, ..., m-th largest values for each subset\n     # When in the superset of top-(m-1) subsets there are more than k values\n     # larger (or equal than) the largest m'th largest value from the subsets\n     # then the top-(m-1) subsets must contain the top-k of the set.\n     # We run a schedule of m's until we have that full top-k found.\n     pivot_point = topk_vals[m - 1].max(-1, keepdims=True)\n     n_larger = (\n       sum([(v >= pivot_point) for v in topk_vals[: m - 1]])\n       .astype(jnp.float32)\n       .sum(-1)\n     )\n     # flag SMEM used to check if all searches terminated\n     flag_ref[0] = 0\n     for i in range(block_token):\n       blockwise_topm_contains_topk = n_larger[i] >= k\n       flag_ref[0] += blockwise_topm_contains_topk\n       # Store when the criteria was hit for each query\n       token_index = pl.program_id(0) * block_token + i\n       max_m = max_m_ref[token_index]\n       max_m_ref[token_index] = jnp.where(\n         blockwise_topm_contains_topk & (max_m == k), m - 1, max_m\n       )\n\n     # If not all terminated, reset the flag say we need to search deeper\n     @pl.when(flag_ref[0] != block_token)\n     def _():\n       flag_ref[0] = 0\n\n\n# Pallas function\ndef topk_blockwise_superset_pallas(\n logits, k, num_blocks=128, block_token=None, m_schedule=None\n):\n num_tokens, vocab_size = logits.shape\n if block_token is None:\n   block_token = min(32, num_tokens)\n if num_tokens % block_token != 0:\n   raise ValueError("token block size must be a multiple of num tokens")\n\n out_shape = (\n   [jax.ShapeDtypeStruct((num_tokens, num_blocks), logits.dtype) for i in range(k)],\n   # uint16 fits vocab size of up to 2**16 * 128 = 8.4M. 
But not used to avoid unforseen issues.\n   [jax.ShapeDtypeStruct((num_tokens, num_blocks), jnp.uint16) for i in range(k)],\n   jax.ShapeDtypeStruct((num_tokens,), jnp.int32),\n   jax.ShapeDtypeStruct((1,), jnp.int32),  # scratch for termination flag\n )\n out_specs = jax.tree.map(\n   lambda _: pl.BlockSpec((block_token, num_blocks), lambda i: (i, 0)), out_shape[:2]\n )\n out_specs += (\n   pl.BlockSpec(memory_space=pltpu.SMEM),\n   pl.BlockSpec(memory_space=pltpu.SMEM),\n )\n return pl.pallas_call(\n   functools.partial(\n     topk_blockwise_superset_kernel,\n     k=k,\n     num_blocks=num_blocks,\n     m_schedule=m_schedule,\n   ),\n   in_specs=(pl.BlockSpec((block_token, vocab_size), lambda i: (i, 0)),),\n   out_shape=out_shape,\n   grid=(num_tokens // block_token),\n   out_specs=out_specs,\n   compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=2**26),\n )(logits)\n\n\ndef topk_on_filtered_subset(block_topm_val, block_topm_index, k):\n num_blocks = block_topm_val[0].shape[-1]\n topk_logits, local_indices = jax.lax.top_k(\n   jnp.concatenate(block_topm_val, axis=-1), k=k\n )\n\n @jax.vmap\n def unravel_indices(local_indices, block_topm_index):\n   m, col = jnp.unravel_index(local_indices, (k, num_blocks))\n   row = jnp.stack(block_topm_index)[m, col]\n   flat_index = row * num_blocks + col\n   return flat_index\n\n topk_flat_indices = unravel_indices(local_indices, block_topm_index)\n return topk_logits, topk_flat_indices\n\n\n@functools.partial(\n jax.jit,\n static_argnames=(\n   "k",\n   "num_blocks",\n   "m_stage1_schedule",\n   "m_stage2_schedule",\n   "block_token",\n ),\n)\ndef topk_optimized(\n logits,\n k: int = 64,\n num_blocks: int = 128,\n m_stage1_schedule: tuple[int] | None = None,\n m_stage2_schedule: tuple[int] | None = None,\n block_token: int | None = None,\n):\n """Fast implementation of jax.lax.top_k on TPUs."""\n if logits.ndim != 2:\n   raise ValueError("Expected 2D input")\n block_topm_val, block_topm_index, termination_m, _ = 
topk_blockwise_superset_pallas(\n   logits,\n   k=k,\n   block_token=block_token,\n   m_schedule=m_stage1_schedule,\n   num_blocks=num_blocks,\n )\n\n # top-k the smallest number of values we can, by taking max m required\n # such that all queries to have full top-k\n # We compile for a range of shapes, then use jax.lax.cond to run just one.\n # in practice 8 nearly always sufficient\n if m_stage2_schedule is None:\n   m_init = 8\n   m_stage2_schedule = [\n     m_init * (2**i) for i in range(int(math.log2(k // m_init)) + 1)\n   ]\n # Guarantee all cases covered\n m_stage2_schedule = (-1,) + tuple(m_stage2_schedule) + (k,)\n\n # Buffer for output to be written in to\n topk_logits, topk_flat_indices = jax.tree.map(\n   jnp.zeros_like,\n   topk_on_filtered_subset(block_topm_val[:1], block_topm_index[:1], k=k),\n )\n max_m = termination_m.max()\n for lower_m, upper_m in zip(m_stage2_schedule, m_stage2_schedule[1:]):\n   topk_logits, topk_flat_indices = jax.lax.cond(\n     (max_m > lower_m) & (max_m <= upper_m),\n     lambda *args: topk_on_filtered_subset(\n       block_topm_val=block_topm_val[:upper_m],\n       block_topm_index=block_topm_index[:upper_m],\n       k=k,\n     ),\n     lambda *args: args,\n     topk_logits,\n     topk_flat_indices,\n   )\n return topk_logits, topk_flat_indices\n

In [9]:
k = 64
num_queries = 128
vocab_size = 201088

# To get a large value range, draw uniform uint16 bit patterns and bitcast them
# to bfloat16, then replace NaNs and +infs with 0 (-inf and subnormals are kept).
dtype = jnp.uint16
logits = jax.lax.bitcast_convert_type(jax.random.randint(jax.random.key(7), (num_queries, vocab_size), dtype=dtype, minval=jnp.iinfo(dtype).min, maxval=jnp.iinfo(dtype).max), jnp.bfloat16)
logits = jnp.where(jnp.isnan(logits) | (logits==jnp.inf), 0, logits) # remove the nans and +infs

# Adversarial logits, in practice this is astronomically unlikely.
# Every 128th entry is 1, so with the default 128 block columns all of them land
# in the same column (the m = k worst case). Same dtype (bf16) as `logits`.
logits_worst_case = jnp.zeros((num_queries, vocab_size), dtype=jnp.bfloat16).at[..., ::128].set(1.)

# Compute each result once and reuse it.
ref_vals, ref_idx = jax.lax.top_k(logits, k)
opt_vals, opt_idx = topk_optimized(logits, k=k)
all_values_match = (ref_vals == opt_vals).all()
exact_index_match = (ref_idx == opt_idx).mean()
# Ties make exact index agreement unreliable, but every returned index must
# point at the value returned alongside it.
indices_consistent = (jnp.take_along_axis(logits, opt_idx, axis=1) == opt_vals).all()
print(f'''All topk_logits match = {all_values_match}. Indices match at {exact_index_match:.0%} of the time,
this is only O(40%) [not 100%] as bf16 has only 2**16=65k possible values so in 200k vocab size
theres high degeneracy and sorting is different. Having checked, it looks correct. Writing full checks would be tricky.''')
print(f'Returned indices point at the returned values = {indices_consistent}.')


All topk_logits match = True. Indices match at 32% of the time,
this is only O(40%) [not 100%] as bf16 has only 2**16=65k possible values so in 200k vocab size
theres high degeneracy and sorting is different. Having checked, it looks correct. Writing full checks would be tricky.


In [None]:
import functools
import jax
import jax.numpy as jnp

k = 64
num_queries = 32
vocab_size = 201088
hidden_dim = 2880

logit_key, key_act, key_weight = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(key_act, (num_queries, hidden_dim), dtype=jnp.bfloat16)
w = jax.random.normal(key_weight, (hidden_dim, vocab_size), dtype=jnp.bfloat16)
# Use the dedicated key: reusing key_weight would draw these logits from the
# same random stream as `w` instead of an independent one.
logits = jax.random.normal(logit_key, (num_queries, vocab_size), dtype=jnp.float32).astype(jnp.bfloat16)

# jitted XLA baselines; `k` must be static because it fixes the output shape.
topk_xla = jax.jit(jax.lax.top_k, static_argnames=('k',))
approx_topk_xla = jax.jit(jax.lax.approx_max_k, static_argnames=('k',))

@jax.jit
@functools.partial(jax.vmap, in_axes=(0, None))
def matmul_and_topk_xla(x, w, k=k):
  """Reference baseline: project one row of ``x`` through ``w``, then exact XLA top-k.

  ``vmap`` maps over the rows of ``x`` while sharing ``w``. The default ``k``
  is the notebook-level ``k`` captured when this cell is executed.
  """
  return jax.lax.top_k(x @ w, k)

def run():
  """Run every benchmarked variant once, blocking until each finishes.

  Each call is wrapped in a ``jax.profiler.TraceAnnotation`` so that the
  variants are labelled in the profiler trace.
  """
  benchmarks = (
    # reference runtimes
    ("matmul", lambda: x @ w),
    ("matmul_and_topk_xla", lambda: matmul_and_topk_xla(x, w)),
    ("topk_xla", lambda: topk_xla(logits, k=k)),
    # Optimized tpu run on random logits
    ("topk_optimized", lambda: topk_optimized(logits, k=k)),
    # Optimized tpu on adversarial logits, to check astronomically unlikely worst case runtime
    ("topk_optimized_worst_case", lambda: topk_optimized(logits_worst_case, k=k)),
    # Not exact. Runtime varies with recall, here run with default 0.95
    ("approx_topk_xla", lambda: approx_topk_xla(logits, k=k)),
  )
  for name, fn in benchmarks:
    with jax.profiler.TraceAnnotation(name):
      jax.block_until_ready(fn())


# Warm-up run first so compilation is excluded from the profiled run.
run()
with jax.profiler.trace("/tmp/tensorboard"):
  run()


In [None]:
# 32 tokens, top-64, v5e: topk_xla 1.32ms, topk_pallas_tpu 0.120ms (11x faster)
# 2048 tokens, top-64, v5e: topk_xla 87.25ms, topk_pallas_tpu 7.23ms (12x faster)


In [None]:
# !pip install tensorboard tensorboard-plugin-profile
%load_ext tensorboard
%tensorboard --logdir=/tmp/tensorboard


In [None]:
# Track how many iters you expect it to take
# stopping_probs = (logits.reshape(logits.shape[0], -1, 128).sort(1).swapaxes(0,1)[::-1].max(-1) < jax.lax.top_k(logits, 64)[0].min(-1)).sum(-1)[:8] / logits.shape[0]


In [None]:
# Some maths estimating how far into each subset's top-m you need to go before
# the union of those top-m's contains all of the overall top-k values.
# Beware: this code is slow, so precomputed values are given below.


# Values are taken in descending order: each next-largest value is the top-most
# not-yet-taken value of one of the 128 lanes. E.g. the second value taken is
# either one of the 127 other lanes' top-1s or the 2nd largest of the lane
# already taken from, and so on.

def add_one(x, i):
  """Return a copy of count tuple ``x`` with entry ``i`` incremented.

  ``i == len(x)`` appends a new entry equal to 1.
  """
  if i == len(x):
    # append to end
    return x + (1,)
  # increment an intermediate val
  return x[:i] + (x[i] + 1,) + x[i + 1:]

def compute_termination_probabilities(k = 64, num_lanes = 128):
  """Distribution of how deep into the per-lane sorted lists the top-k reaches.

  Model: each successive next-largest value is taken from a lane chosen
  uniformly at random among ``num_lanes``. A state is a tuple where
  ``state[r]`` is the number of lanes from which at least ``r + 1`` values have
  been taken, so ``len(state)`` is the deepest level reached.

  Args:
    k: number of top values taken (>= 1).
    num_lanes: number of lanes (subsets).

  Returns:
    List of length ``k``; entry ``d - 1`` is the probability that the deepest
    lane holds exactly ``d`` of the top-``k`` values.

  Raises:
    ValueError: if ``k < 1``.
  """
  if k < 1:
    raise ValueError("k must be >= 1")
  states = {
      (1,): 1.0, # only non zero values are in the hash
  }
  for _ in range(k - 1):
    next_states = {}
    for state, parent_prob in states.items():
      for i in range(len(state) + 1):
        if i == 0:
          # value comes from a lane nothing has been taken from yet
          child_prob = 1 - state[0] / num_lanes
        elif i == len(state):
          # value comes from one of the deepest lanes -> one level deeper
          child_prob = state[i - 1] / num_lanes
        else:
          # value comes from a lane holding exactly i taken values
          child_prob = (state[i - 1] - state[i]) / num_lanes
        if child_prob > 0:
          child = add_one(state, i)
          next_states[child] = next_states.get(child, 0.0) + child_prob * parent_prob
    states = next_states

  # Read the final generation from `states` (not the inner loop's dict) so that
  # k == 1, where the loop never runs, works too.
  probs = [0.0] * k
  for state, p in states.items():
    probs[len(state) - 1] += p
  return probs

# compute_termination_probabilities(k=64, num_lanes=128)
# precomputed probs for k=64 and 128 lanes (random subsets being sorted)
# probs[d - 1] is the probability, for a single token, that the deepest lane
# holds exactly d of the top-64 values, i.e. that the per-lane top-d is the
# smallest depth whose union contains the full top-64 (d = 1..64).
# NOTE(review): under JAX's default 32-bit mode this array is float32, so the
# smallest entries below underflow to 0; that does not affect the cumsum.
probs = jnp.array([
4.1812404072740824e-09,
0.13436890728003542,
0.673656436392413,
0.17274256944422797,
0.01775253655062719,
0.0013838922281360039,
9.032679548553634e-05,
5.067660246625492e-06,
2.4828455416721553e-07,
1.075248105864657e-08,
4.156291895201961e-10,
1.4454295961378757e-11,
4.552534161066638e-13,
1.3058450068301742e-14,
3.427414716089682e-16,
8.264927218917032e-18,
1.8375011880871763e-19,
3.77788958180651e-21,
7.201944499092436e-23,
1.2759350490518157e-24,
2.1050297022227087e-26,
3.239666327687077e-28,
4.658198759426817e-30,
6.265949774819566e-32,
7.894109952528496e-34,
9.323751912435309e-36,
1.0332533469598683e-37,
1.0750948773204598e-39,
1.0508665648530184e-41,
9.653629860854444e-44,
8.33689142161678e-46,
6.769621479167146e-48,
5.168883019168442e-50,
3.710870161978265e-52,
2.5045242937986133e-54,
1.5886090227506496e-56,
9.466067809537837e-59,
5.295976602932471e-61,
2.7800402115131096e-63,
1.368130025350939e-65,
6.305957481932515e-68,
2.719104276048889e-70,
1.0954091571704027e-72,
4.116605637182969e-75,
1.4406318940272857e-77,
4.6853827433273565e-80,
1.412914883228222e-82,
3.9402153895800056e-85,
1.0130716090837228e-87,
2.393082541142658e-90,
5.1726347963559024e-93,
1.0182351961330517e-95,
1.81530565348338e-98,
2.9116888580223354e-101,
4.168487985715583e-104,
5.275083221518592e-107,
5.82962643626865e-110,
5.53996538879725e-113,
4.43611268287515e-116,
2.9108350937500984e-119,
1.502948286433509e-122,
5.7262444238005166e-126,
1.431382183177232e-129,
1.761050914342067e-133])

### Termination probs
block_token = 32 # algorithm batched to sort 32 tokens of subsets
# Entry d - 1: probability that every one of the `block_token` tokens has its
# full top-64 within the per-lane top-d (tokens treated as independent).
probs.cumsum() ** block_token
# By 128 top-8's (entry 7) this is ~99.9992% (1 - 32 * 2.6e-7 ~= 1 - 8.3e-6);
# the per-token chance alone is ~99.99997%. `probs` is float32 by default, and
# float32 spacing near 1 is ~6e-8, so the displayed value is only approximate.
