# How to Write a Softmax GPU Kernel in Pallas







## Softmax Operation

Given an input vector $z = (z_1, \dots, z_n) \in \mathbb{R}^n$, the softmax function $\sigma : \mathbb{R}^n \to (0,1)^n$ produces a probability distribution over the $n$ entries:

$$
\sigma(z)_i = \frac{\exp(z_i)}{\sum_{j=1}^n \exp(z_j)} \quad\text{for } i=1,\dots,n.
$$

Properties:
- Each output is positive and the outputs sum to 1: $\sum_i \sigma(z)_i = 1$.
- Softmax is invariant to shifts: $\sigma(z) = \sigma(z + c \cdot \mathbf{1})$ for any scalar $c$. For numerical stability one commonly uses
$$
\sigma(z)_i = \frac{\exp(z_i - \max_j z_j)}{\sum_{k=1}^n \exp(z_k - \max_j z_j)}.
$$
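
To see why the shift matters in practice, here is a minimal sketch comparing the two formulas on large logits in float32. The helper names `softmax_naive` and `softmax_stable` are illustrative and not part of the kernel developed later.

```python
import jax.numpy as jnp


def softmax_naive(z):
    # Direct translation of the definition; exp() can overflow.
    e = jnp.exp(z)
    return e / jnp.sum(e)


def softmax_stable(z):
    # Subtract the maximum first, so the largest exponent is exp(0) = 1.
    e = jnp.exp(z - jnp.max(z))
    return e / jnp.sum(e)


z = jnp.array([1000.0, 1001.0, 1002.0], dtype=jnp.float32)

naive = softmax_naive(z)    # exp(1000) overflows to inf, and inf / inf is nan
stable = softmax_stable(z)  # finite probabilities that sum to 1

# Shift invariance: the same result as for the logits shifted down by 1000.
small = softmax_stable(jnp.array([0.0, 1.0, 2.0], dtype=jnp.float32))
```

The naive version returns `nan` for every entry, while the shifted version agrees with the softmax of `[0, 1, 2]`.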


Use cases: softmax converts logits to probabilities (for example in classification) and is typically paired with cross-entropy loss for efficient training, because the softmax and the log-loss combine into a simpler expression.
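
As a small illustration of that simplification, the sketch below writes cross-entropy directly in terms of `logsumexp` and checks that its gradient with respect to the logits is the softmax output minus the one-hot label. The function name `cross_entropy` and the example values are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def cross_entropy(logits, label):
    # log softmax(z)_i = z_i - logsumexp(z), so no explicit division is needed.
    log_probs = logits - jax.nn.logsumexp(logits)
    return -log_probs[label]


logits = jnp.array([2.0, -1.0, 0.5])
label = 0

# Gradient of the loss with respect to the logits.
grad = jax.grad(cross_entropy)(logits, label)

# Closed form: softmax(z) - one_hot(label).
expected = jax.nn.softmax(logits) - jax.nn.one_hot(label, logits.shape[0])
```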


## Online Softmax

The online softmax algorithm is a numerically stable method that computes the softmax normalization statistics (the maximum and the sum of exponentials) in a single pass over the input, without storing all exponentials in memory. This is particularly useful for processing large sequences or implementing efficient GPU kernels.

### Algorithm Description

The key idea is to maintain running statistics (maximum value and sum of exponentials) as we iterate through the input:

1. **Track the running maximum**: As we process elements, we keep track of the largest value seen so far.
2. **Update the sum of exponentials**: When we encounter a new maximum, we rescale the sum accumulated so far by exp(old max - new max) and then add the new exponentials.
3. **Compute probabilities**: Finally, divide each shifted exponential by the total sum.

This avoids numerical overflow/underflow and allows us to process data in blocks without materializing the full exponential array.
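
The steps above can be sketched for a single vector processed in fixed-size blocks. Writing $m$ for the running maximum and $\ell$ for the running sum, each block $B$ updates them as

$$
m' = \max\bigl(m, \max_{j \in B} z_j\bigr), \qquad \ell' = \ell \, e^{m - m'} + \sum_{j \in B} e^{z_j - m'}.
$$

The function name `online_softmax_1d` and the block size are illustrative assumptions, and the sketch assumes the input length is a multiple of the block size.

```python
import jax
import jax.numpy as jnp

BLOCK = 4  # illustrative block size, not a tuned value


def online_softmax_1d(x):
    # Assumes x.shape[0] is a multiple of BLOCK.
    blocks = x.reshape(-1, BLOCK)

    def step(carry, block):
        m, l = carry
        # Step 1: update the running maximum.
        m_new = jnp.maximum(m, jnp.max(block))
        # Step 2: rescale the old sum to the new maximum, then add this block.
        l_new = l * jnp.exp(m - m_new) + jnp.sum(jnp.exp(block - m_new))
        return (m_new, l_new), None

    init = (jnp.array(-jnp.inf, dtype=x.dtype), jnp.array(0.0, dtype=x.dtype))
    (m, l), _ = jax.lax.scan(step, init, blocks)

    # Step 3: normalize with the final statistics.
    return jnp.exp(x - m) / l


x = jax.random.normal(jax.random.key(0), (16,))
out = online_softmax_1d(x)
```

Only two scalars are carried between blocks, which is what lets a kernel keep the statistics in registers while streaming over the row.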

### Pseudocode

In [1]:
from functools import partial

import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl
from jax.experimental.pallas import triton as plgpu


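def online_softmax(logits):
    # Reference computation on the full array; the kernel in the next section
    # accumulates the same max and sum statistics block by block.
    max_rows = jnp.max(logits, axis=-1)        # row-wise maximum
    s = jnp.exp(logits - max_rows[..., None])  # shifted exponentials
    l = jnp.sum(s, axis=-1)                    # row-wise normalizer
    l = l[..., None]
    return s / l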

The next step is to convert this into an efficient GPU kernel. See my previous post on writing an efficient matrix multiplication kernel here.


## Forward Pass

In [2]:
INTERPRET_MODE = True # Set to False on GPU

# Pallas softmax
BLOCK_M = 64
BLOCK_N = 64
NUM_WARPS = 4
NUM_STAGES = 3


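def softmax_kernel(x_ref, out_ref, *, n_col_blocks, n_rows, n_cols):
    # Running row-wise maximum and sum of exponentials for this block of rows.
    max_reg = jnp.full((BLOCK_M,), -jnp.inf, dtype=jnp.float32)
    l_reg = jnp.zeros((BLOCK_M,), dtype=jnp.float32)
    # Global row indices handled by this program; rows past n_rows are masked out.
    row_ids = pl.program_id(0) * BLOCK_M + jnp.arange(BLOCK_M)
    row_mask = row_ids < n_rows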

    def stats_body(t, args):
        max_reg, l_reg = args
        idx = pl.dslice(t * BLOCK_N, BLOCK_N)
        col_ids = t * BLOCK_N + jnp.arange(BLOCK_N)
        cols_mask = col_ids < n_cols
        mask = row_mask[:, None] & cols_mask[None, :]

        x_tile = plgpu.load(
            x_ref.at[:, idx],
            mask=mask,
            other=-jnp.inf,
        ).astype(jnp.float32)
        # Online update: rescale the old sum to the new maximum, then add this tile.
        max_tile = jnp.max(x_tile, axis=-1)
        max_new = jnp.maximum(max_reg, max_tile)
        l_update = l_reg * jnp.exp(max_reg - max_new) + jnp.sum(
            jnp.exp(x_tile - max_new[:, None]), axis=-1
        )
        return max_new, l_update

    max_reg, l_reg = jax.lax.fori_loop(0, n_col_blocks, stats_body, (max_reg, l_reg))

    def out_body(t, _):
        idx = pl.dslice(t * BLOCK_N, BLOCK_N)
        col_ids = t * BLOCK_N + jnp.arange(BLOCK_N)
        cols_mask = col_ids < n_cols
        mask = row_mask[:, None] & cols_mask[None, :]

        x_tile = plgpu.load(
            x_ref.at[:, idx],
            mask=mask,
            other=-jnp.inf,
        ).astype(jnp.float32)
        # Second pass: normalize each tile with the final statistics and write it out.
        out_tile = jnp.exp(x_tile - max_reg[:, None]) / l_reg[:, None]
        plgpu.store(out_ref.at[:, idx], out_tile.astype(jnp.float32), mask=mask)

    _ = jax.lax.fori_loop(0, n_col_blocks, out_body, None)


@jax.jit
def softmax(logits):
    n_row_blocks = pl.cdiv(logits.shape[0], BLOCK_M)
    n_col_blocks = pl.cdiv(logits.shape[1], BLOCK_N)
    return pl.pallas_call(
        partial(softmax_kernel, n_col_blocks=n_col_blocks, n_rows=logits.shape[0], n_cols=logits.shape[1]),
        out_shape=jax.ShapeDtypeStruct(logits.shape, jnp.float32),
        grid=(n_row_blocks,),
        in_specs=[pl.BlockSpec((BLOCK_M, logits.shape[1]), lambda i: (i, 0))],
        out_specs=pl.BlockSpec((BLOCK_M, logits.shape[1]), lambda i: (i, 0)),
        interpret=INTERPRET_MODE,
        compiler_params=plgpu.CompilerParams(
            num_warps=NUM_WARPS,
            num_stages=NUM_STAGES,
        ),
    )(logits)
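
The kernel fills masked-out positions with `-jnp.inf` on load. The sketch below uses plain `jax.numpy` (not Pallas) to illustrate why that choice is safe: padding a row with `-inf` leaves its softmax unchanged, because the padded entries contribute `exp(-inf) = 0` to the sum and never win the maximum. The shapes are arbitrary examples.

```python
import jax
import jax.numpy as jnp

BLOCK_N = 64  # same tile width as in the post; only used for padding here

# 100 columns is deliberately not a multiple of BLOCK_N.
x = jax.random.normal(jax.random.key(0), (3, 100))
pad = -x.shape[1] % BLOCK_N  # columns needed to reach the next multiple
x_padded = jnp.pad(x, ((0, 0), (0, pad)), constant_values=-jnp.inf)

probs = jax.nn.softmax(x, axis=-1)
probs_padded = jax.nn.softmax(x_padded, axis=-1)

# A row that is entirely masked has max = -inf, so the shift
# x - max becomes -inf - (-inf) = nan.
all_masked = jnp.full((BLOCK_N,), -jnp.inf)
shifted = all_masked - jnp.max(all_masked)
```

The last two lines show the one caveat: rows that lie entirely outside the array produce `nan` statistics. In the kernel above those rows are never written back, since the store uses the same mask.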


## Performance

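Let's compare the performance of our kernel with the out-of-the-box softmax implementation provided by JAX. Timings are only meaningful with `INTERPRET_MODE = False` on a GPU; in interpret mode the kernel is emulated rather than compiled and runs much slower.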

In [3]:
import time

def bench(fn, *args, iters=10):
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        out = fn(*args)
        out.block_until_ready()
        t1 = time.perf_counter()
        times.append(t1 - t0)
    times.sort()
    return times[len(times) // 2]


d = 1024
key = jax.random.key(0)
logits = jax.random.normal(shape=(d, d), key=key)

out_jax = jax.nn.softmax(logits)
out_pallas = softmax(logits)

assert jnp.allclose(jnp.squeeze(out_jax), out_pallas)

softmax_jit = jax.jit(jax.nn.softmax)

_ = softmax_jit(logits).block_until_ready()
_ = softmax(logits).block_until_ready()

t_jax = bench(softmax_jit, logits)
t_pallas = bench(softmax, logits)

print(f"Jax Softmax: {t_jax*1e3:.2f} ms")
print(f"Pallas Softmax: {t_pallas*1e3:.2f} ms")
print(f"Speedup (jax / pallas): {t_jax / t_pallas:.2f}x")

Jax Softmax: 0.83 ms
Pallas Softmax: 28.84 ms
Speedup (jax / pallas): 0.03x



## Backward Pass


### Gradient / Jacobian
Let $s = \sigma(z)$. The Jacobian matrix $J$ with entries $\partial \sigma_i / \partial z_j$ is
$$
\frac{\partial \sigma_i}{\partial z_j} = s_i(\delta_{ij} - s_j),
$$
where $\delta_{ij}$ is the Kronecker delta. Equivalently, $J = \operatorname{diag}(s) - s s^T$.
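
Because $J$ is symmetric, the backward pass never needs the full matrix: for an upstream gradient $g$, the product is $J^T g = s \odot (g - \langle g, s \rangle)$ per row. The sketch below is one possible way to attach that rule with `jax.custom_vjp`. The forward pass uses `jax.nn.softmax` as a stand-in for the Pallas kernel, and the names `softmax_with_vjp`, `_softmax_impl`, `_fwd` and `_bwd` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def _softmax_impl(z):
    # Stand-in forward pass; a Pallas kernel could be called here instead.
    return jax.nn.softmax(z, axis=-1)


@jax.custom_vjp
def softmax_with_vjp(z):
    return _softmax_impl(z)


def _fwd(z):
    s = _softmax_impl(z)
    return s, s  # the output is the only residual the backward pass needs


def _bwd(s, g):
    # J^T g = s * (g - <g, s>), computed independently for each row.
    dz = s * (g - jnp.sum(g * s, axis=-1, keepdims=True))
    return (dz,)


softmax_with_vjp.defvjp(_fwd, _bwd)

z = jax.random.normal(jax.random.key(0), (4, 8))
w = jax.random.normal(jax.random.key(1), (4, 8))

# Compare against JAX's own autodiff through jax.nn.softmax.
grad_custom = jax.grad(lambda z: jnp.sum(softmax_with_vjp(z) * w))(z)
grad_ref = jax.grad(lambda z: jnp.sum(jax.nn.softmax(z, axis=-1) * w))(z)
```

The backward rule is a row-wise reduction followed by an elementwise product, so it has the same structure as the forward kernel.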



## Let's Evaluate our Kernel in a Model Training Loop

In [5]:
from types import SimpleNamespace

import jax
import jax.numpy as jnp
import flax.nnx as nnx

# `config` was not defined in this cell. The values below are hypothetical
# placeholders so that the cell runs; they are not the original settings.
config = SimpleNamespace(n_experts=2, n_embed=16)

D, B, T, C = 1000, config.n_experts, 5, config.n_embed

default = jax.random.key(69)
gate_noise = jax.random.key(42)
rngs = nnx.Rngs(default=default, gate_noise=gate_noise)

#model = MOE(config, rngs)
#model.train(add_noise=False)
#tx = optax.adam(1e-2)
#state = nnx.Optimizer(model, tx)

x = jax.random.normal(jax.random.key(1000), (D, B, T, C))

# Synthetic routing target: the sign of the first feature picks the transform.
expert_ids = (x[:, :, :, 0] > 0).astype(jnp.int32)[..., None]
t = [
    jax.random.normal(jax.random.key(2000), (C, C)),
    jax.random.normal(jax.random.key(3000), (C, C)),
]

def transform(xi, eid):
    return jnp.where(eid == 1, xi @ t[0], xi @ t[1])

y = jax.vmap(transform)(x, expert_ids)