In [1]:
from functools import partial

import jax
from jax import Array, numpy as jnp
import jax.experimental.pallas as pl

from pallas_visualisation import visualise, pallas_call

  from .autonotebook import tqdm as notebook_tqdm


We can visualise how `BlockSpec`s divide the inputs and output of a simple vector-add kernel into blocks:

In [2]:
def add_vectors_kernel(x_ref, y_ref, o_ref):
  """Write the elementwise sum of the two input blocks into the output block.

  Args:
    x_ref: First input block; read in full.
    y_ref: Second input block; read in full.
    o_ref: Output block; fully overwritten with ``x + y``.
  """
  # Load both operands in their entirety, add, and store in a single step.
  o_ref[...] = x_ref[...] + y_ref[...]

# Two length-128 vectors of ones, processed in tiles of 32 elements.
block_size = 32
inputs = (jnp.ones((128,)), jnp.ones((128,)))

# The output mirrors the first operand's shape and dtype.
template = inputs[0]

pallas_function = pallas_call(
  add_vectors_kernel,
  out_shape=jax.ShapeDtypeStruct(template.shape, template.dtype),
  # One grid step per tile, rounding up.
  grid=(pl.cdiv(template.shape[0], block_size),),
  # Each operand gets its own spec with the same 1-D tiling: the index map
  # forwards grid index i unchanged.
  in_specs=tuple(
    pl.BlockSpec(lambda i: (i,), (block_size,)) for _ in inputs
  ),
  out_specs=pl.BlockSpec(lambda i: (i,), (block_size,)),
)
visualise(pallas_function, inputs, display_full_grid=True)


We can also visualise more complex indexing, such as a matrix multiplication, to see how the M, N and K dimensions are blocked:

In [4]:
def matmul_kernel(x_ref, y_ref, o_ref, activation, block_k):
  """Compute one output tile of ``activation(x @ y)``, accumulating over K in chunks.

  The contraction axis (columns of ``x_ref`` / rows of ``y_ref``) is walked in
  steps of ``block_k``. A final partial chunk is included, so the result is
  correct even when K is not a multiple of ``block_k``.

  Args:
    x_ref: Left operand tile; indexed as 2-D with the contraction axis last.
    y_ref: Right operand tile; indexed as 2-D with the contraction axis first.
    o_ref: Output tile; fully overwritten with the activated product cast to
      ``o_ref.dtype``.
    activation: Callable applied to the float32 accumulator before the store.
    block_k: Number of contraction elements consumed per accumulation step.
  """
  k_size = x_ref.shape[1]
  acc = jnp.zeros((x_ref.shape[0], y_ref.shape[1]), jnp.float32)
  # Iterate over start offsets rather than `k_size // block_k` whole chunks:
  # floor division would silently drop the trailing `k_size % block_k` elements.
  for k_start in range(0, k_size, block_k):
    # Clamp so the last chunk never reaches past the contraction axis.
    k_stop = min(k_start + block_k, k_size)
    x = x_ref[:, k_start:k_stop]
    y = y_ref[k_start:k_stop, :]
    acc += x @ y
  o_ref[:, :] = activation(acc).astype(o_ref.dtype)

# Operands: x has shape (32, 32) and y has shape (32, 64).
x, y = jnp.ones((32, 32)), jnp.ones((32, 64))
# Tile sizes along the (M, N, K) dimensions.
block_shape = 16, 32, 8

activation = jax.nn.gelu
block_m, block_n, block_k = block_shape

fused_matmul = pallas_call(
  # Bind the static kernel arguments so only the refs remain.
  partial(matmul_kernel, block_k=block_k, activation=activation),
  out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1],), jnp.float32),
  in_specs=[
      # x: block_m rows for grid index i, spanning the whole second axis.
      pl.BlockSpec(lambda i, j: (i, 0), (block_m, x.shape[1])),
      # y: block_n columns for grid index j, spanning the whole first axis.
      pl.BlockSpec(lambda i, j: (0, j), (y.shape[0], block_n))
  ],
  out_specs=pl.BlockSpec(lambda i, j: (i, j), (block_m, block_n)),
  # Derive the grid from the operand shapes and tile sizes instead of
  # hard-coding it, so it stays in sync if either changes. For the values
  # above this evaluates to (2, 2).
  grid=(pl.cdiv(x.shape[0], block_m), pl.cdiv(y.shape[1], block_n)),
  interpret=True
)

visualise(
  fused_matmul,
  (x,y), 
  display_full_grid=True
)