# Setup

Run the following cell to install JAX with TPU support, plus Flax and Optax. **After running this cell, you must restart the kernel** for the changes to take effect.

In [3]:
# Install JAX for TPU
!pip install "jax[tpu]>=0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install flax optax

import os
print("Installation complete. Please restart the runtime/kernel now.")
# os.kill(os.getpid(), 9) # Uncomment to automatically restart in some environments

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Installation complete. Please restart the runtime/kernel now.


# Pallas Sinkhorn Implementation

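This notebook contains the optimized Pallas implementation of the Sinkhorn-Knopp algorithm, including a custom VJP for memory-efficient backward passes. The forward pass runs Sinkhorn independently on each 128×128 query/key tile and averages the per-tile outputs over key tiles. It is therefore a block-local approximation of full Sinkhorn attention, and sequence lengths must be multiples of 128.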

Running the next cell writes the implementation to `pallas_sinkhorn.py`, the module used by the main model.

In [4]:
%%writefile pallas_sinkhorn.py
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

try:
    from jax.experimental.pallas import tpu as pltpu
except ImportError:
    pltpu = None
import functools
from functools import partial


def resolve_compiler_params(dim_semantics):
    """Build TPU compiler params for `pl.pallas_call`, across JAX versions.

    Args:
        dim_semantics: per-grid-axis semantics, e.g. ("parallel", "parallel").

    Returns:
        A compiler-params object carrying `dimension_semantics`, or None when the
        TPU Pallas module is unavailable or exposes neither known class name.
    """
    if pltpu is None:
        return None
    # Newer JAX releases name the class `CompilerParams`; older ones used
    # `TPUCompilerParams`. Prefer the newer name when both exist.
    for class_name in ("CompilerParams", "TPUCompilerParams"):
        params_cls = getattr(pltpu, class_name, None)
        if params_cls is not None:
            return params_cls(dimension_semantics=dim_semantics)
    return None


# Forward Kernel Tile Level
def tile_sinkhorn_kernel(q_ref, k_ref, v_ref, o_ref, u_ref, v_pot_ref, *, n_iters):
    """Run log-domain Sinkhorn on one (query tile, key tile) pair.

    Computes logits = q @ k.T / sqrt(D), then for `n_iters` iterations
    normalizes the rows and then the columns of exp(log_alpha) in log space.
    Writes:
      o_ref     <- exp(log_alpha) @ v   (the transport plan applied to v)
      u_ref     <- accumulated row potential, replicated across 8 columns
      v_pot_ref <- accumulated column potential, replicated across 8 columns
    so that log_alpha == logits + u_ref[:, :1] + v_pot_ref[:, :1].T.

    Args:
        q_ref, k_ref, v_ref: refs to the query, key and value tiles.
        o_ref, u_ref, v_pot_ref: output refs (see above).
        n_iters: static number of Sinkhorn iterations (unrolled at trace time).
    """
    q = q_ref[...]
    k = k_ref[...]
    v = v_ref[...]

    D = q.shape[-1]

    # 1. Local logits for this tile.
    logits = jnp.dot(q, k.T) / jnp.sqrt(D)

    # Potentials, stored 8 columns wide.
    u_acc = jnp.zeros((logits.shape[0], 8), dtype=logits.dtype)
    v_acc = jnp.zeros((logits.shape[1], 8), dtype=logits.dtype)

    # 2. Local Sinkhorn iterations (stable log-sum-exp in both directions).
    log_alpha = logits
    for _ in range(n_iters):
        # Row normalization.
        max_r = jnp.max(log_alpha, axis=1, keepdims=True)
        correction_r = max_r + jnp.log(
            jnp.sum(jnp.exp(log_alpha - max_r), axis=1, keepdims=True)
        )
        # Apply the row correction; without this the column step would act on
        # an un-row-normalized matrix and u would not match the output plan.
        log_alpha = log_alpha - correction_r
        # Broadcast correction_r explicitly for TPU compilation.
        u_acc = u_acc - jnp.broadcast_to(correction_r, u_acc.shape)

        # Column normalization.
        max_c = jnp.max(log_alpha, axis=0, keepdims=True)
        correction_c = max_c + jnp.log(
            jnp.sum(jnp.exp(log_alpha - max_c), axis=0, keepdims=True)
        )
        log_alpha = log_alpha - correction_c
        # Broadcast correction_c explicitly for TPU compilation.
        v_acc = v_acc - jnp.broadcast_to(correction_c.T, v_acc.shape)

    # 3. Apply the transport plan to the value tile.
    o_ref[...] = jnp.dot(jnp.exp(log_alpha), v).astype(o_ref.dtype)
    u_ref[...] = u_acc.astype(u_ref.dtype)
    v_pot_ref[...] = v_acc.astype(v_pot_ref.dtype)


def pallas_sinkhorn_fwd(Q, K, V, n_iters=20):
    """Block-local Sinkhorn attention forward pass (also the custom_vjp fwd rule).

    The sequence axis is split into 128-row tiles. For every (query tile,
    key tile) pair, `tile_sinkhorn_kernel` runs `n_iters` Sinkhorn iterations
    on that tile's logits and applies the resulting plan to the value tile.
    The per-tile outputs of one query tile are averaged over all key tiles.

    Args:
        Q, K, V: arrays of shape [B, H, L, D]. All three are reshaped with Q's
            L and D, so they must share that shape. L must be a multiple of 128.
        n_iters: number of Sinkhorn iterations per tile.

    Returns:
        (out, residuals): `out` has shape [B, H, L, D]; `residuals` is
        (u, v, Q, K, V) where u ([B, H, L]) is the row potential averaged over
        key tiles and v ([B, H, L]) is the column potential averaged over
        query tiles.

    Raises:
        ValueError: if L is not a multiple of the 128-row tile size.
    """
    B, H, L, D = Q.shape
    block_size = 128
    if L % block_size != 0:
        raise ValueError(
            f"Sequence length {L} must be a multiple of the tile size {block_size}."
        )
    num_blocks = L // block_size
    head_dim = D

    # Chunk the sequence axis into [num_blocks, block_size] tiles.
    q_blocks = Q.reshape(B, H, num_blocks, block_size, head_dim)
    k_blocks = K.reshape(B, H, num_blocks, block_size, head_dim)
    v_blocks = V.reshape(B, H, num_blocks, block_size, head_dim)

    # On CPU-only hosts run the kernel through the Pallas interpreter.
    is_cpu = jax.devices()[0].platform == "cpu"
    kw = {"interpret": True} if is_cpu else {}

    out_shape = jax.ShapeDtypeStruct((block_size, head_dim), Q.dtype)
    pot_shape = jax.ShapeDtypeStruct((block_size, 8), Q.dtype)

    def whole_tile(shape):
        # The single-program grid maps each operand as one whole block.
        return pl.BlockSpec(block_shape=shape, index_map=lambda: (0, 0))

    def compute_tile(q_tile, k_tile, v_tile):
        """Run the Sinkhorn kernel on one (query tile, key tile) pair."""
        return pl.pallas_call(
            functools.partial(tile_sinkhorn_kernel, n_iters=n_iters),
            out_shape=(out_shape, pot_shape, pot_shape),
            grid=(),  # Single tile
            in_specs=[whole_tile(t.shape) for t in (q_tile, k_tile, v_tile)],
            out_specs=[
                whole_tile(out_shape.shape),
                whole_tile(pot_shape.shape),
                whole_tile(pot_shape.shape),
            ],
            **kw,
        )(q_tile, k_tile, v_tile)

    def process_q_block(q_tile, k_blocks_all, v_blocks_all):
        """Scan one query tile over all key/value tiles."""

        def scan_kv_blocks(out_sum, kv_tiles):
            k_tile, v_tile = kv_tiles
            out_tile, u_tile, v_pot_tile = compute_tile(q_tile, k_tile, v_tile)
            # Accumulate in float32 so low-precision inputs don't lose accuracy
            # when summing over many key tiles.
            return out_sum + out_tile.astype(jnp.float32), (u_tile, v_pot_tile)

        init_sum = jnp.zeros((block_size, head_dim), dtype=jnp.float32)
        final_sum, (u_stack, v_stack) = jax.lax.scan(
            scan_kv_blocks, init_sum, (k_blocks_all, v_blocks_all)
        )

        final_out = (final_sum / num_blocks).astype(q_tile.dtype)
        u_avg = jnp.mean(u_stack, axis=0)[:, 0]  # [block_size]
        v_all = v_stack[:, :, 0]  # [num_blocks_K, block_size]
        return final_out, u_avg, v_all

    # Vmap over query tiles (sharing all K/V tiles), then heads, then batch.
    vmap_q = jax.vmap(process_q_block, in_axes=(0, None, None))
    vmap_heads = jax.vmap(vmap_q, in_axes=(0, 0, 0))
    vmap_batch = jax.vmap(vmap_heads, in_axes=(0, 0, 0))

    # out: [B, H, num_blocks_Q, block_size, D]
    # u:   [B, H, num_blocks_Q, block_size]
    # v:   [B, H, num_blocks_Q, num_blocks_K, block_size]
    out_blocks, u_blocks, v_pot_blocks = vmap_batch(q_blocks, k_blocks, v_blocks)

    out = out_blocks.reshape(B, H, L, D)
    u_out = u_blocks.reshape(B, H, L)

    # v was computed separately for every query tile; it is a per-key quantity,
    # so average it over query tiles -> [B, H, num_blocks_K, block_size].
    v_avg = jnp.mean(v_pot_blocks, axis=2)
    v_out = v_avg.reshape(B, H, L)

    return out, (u_out, v_out, Q, K, V)



# Backward kernel
def dq_kernel_batched(
    Q_ref, K_ref, V_ref, u_ref, v_ref, dO_ref, lu_ref, lv_ref, dQ_ref
):
    """Pallas kernel: gradient w.r.t. Q for one (batch block, query block) tile.

    Reads one query block together with all key/value rows of its batch block.
    It then walks over the keys in 128-row blocks, accumulating
        A  = exp(S + u[:, None] + v[None, :]),   S = Q K^T / sqrt(D)
        dS = A * (dO V^T - lambda_u[:, None] - lambda_v[None, :])
        dQ += dS K
    and finally writes dQ / sqrt(D) to dQ_ref.
    """
    q_block = Q_ref[...]
    u_block = u_ref[...]
    do_block = dO_ref[...]
    lu_block = lu_ref[...]

    D = q_block.shape[-1]
    scale = 1.0 / jnp.sqrt(D)

    N = K_ref.shape[1]
    BLOCK_N = 128  # Safe Block Size
    UNROLL = 4  # Key blocks processed per unrolled loop iteration
    iters = N // BLOCK_N
    n_unrolled = iters // UNROLL

    def accumulate(off, dq_acc):
        """Add the contribution of key rows [off, off + BLOCK_N)."""
        k_block = K_ref[:, pl.ds(off, BLOCK_N), :]
        v_pot_block = v_ref[:, pl.ds(off, BLOCK_N)]
        v_mat_block = V_ref[:, pl.ds(off, BLOCK_N), :]
        lv_block = lv_ref[:, pl.ds(off, BLOCK_N)]

        s_block = jnp.einsum("bmd,bnd->bmn", q_block, k_block) * scale
        logits = s_block + u_block[:, :, None] + v_pot_block[:, None, :]
        a_block = jnp.exp(logits)

        z_block = jnp.einsum("bme,bne->bmn", do_block, v_mat_block)
        term_block = z_block - lu_block[:, :, None] - lv_block[:, None, :]
        ds_block = a_block * term_block

        return dq_acc + jnp.einsum("bmn,bnd->bmd", ds_block, k_block)

    def unrolled_body(i, dq_acc):
        for u_i in range(UNROLL):
            dq_acc = accumulate((i * UNROLL + u_i) * BLOCK_N, dq_acc)
        return dq_acc

    def tail_body(j, dq_acc):
        return accumulate(j * BLOCK_N, dq_acc)

    dq_acc = jnp.zeros_like(q_block)
    if n_unrolled:
        dq_acc = jax.lax.fori_loop(0, n_unrolled, unrolled_body, dq_acc)
    # Key blocks left over when iters is not a multiple of UNROLL; these were
    # previously dropped (e.g. N == 128 produced dQ == 0).
    if n_unrolled * UNROLL < iters:
        dq_acc = jax.lax.fori_loop(n_unrolled * UNROLL, iters, tail_body, dq_acc)

    dQ_ref[...] = dq_acc * scale


def dk_dv_kernel_batched(
    Q_ref, K_ref, V_ref, u_ref, v_ref, dO_ref, lu_ref, lv_ref, dK_ref, dV_ref
):
    """Pallas kernel: gradients w.r.t. K and V for one (batch block, key block) tile.

    Reads one key/value block together with all query rows of its batch block.
    It then walks over the queries in 128-row blocks, accumulating
        A  = exp(S + u[:, None] + v[None, :]),   S = Q K^T / sqrt(D)
        dS = A * (dO V^T - lambda_u[:, None] - lambda_v[None, :])
        dK += dS^T Q,   dV += A^T dO
    and finally writes dK / sqrt(D) to dK_ref and dV to dV_ref.
    """
    k_block = K_ref[...]
    v_pot_block = v_ref[...]
    v_mat_block = V_ref[...]
    lv_block = lv_ref[...]

    D = k_block.shape[-1]
    M = Q_ref.shape[1]
    BLOCK_M = 128  # Safe Block Size
    UNROLL = 4  # Query blocks processed per unrolled loop iteration
    iters = M // BLOCK_M
    n_unrolled = iters // UNROLL
    scale = 1.0 / jnp.sqrt(D)

    def accumulate(off, state):
        """Add the contribution of query rows [off, off + BLOCK_M)."""
        dk_acc, dv_acc = state
        q_block = Q_ref[:, pl.ds(off, BLOCK_M), :]
        u_block = u_ref[:, pl.ds(off, BLOCK_M)]
        do_block = dO_ref[:, pl.ds(off, BLOCK_M), :]
        lu_block = lu_ref[:, pl.ds(off, BLOCK_M)]

        s_block = jnp.einsum("bmd,bnd->bmn", q_block, k_block) * scale
        logits = s_block + u_block[:, :, None] + v_pot_block[:, None, :]
        a_block = jnp.exp(logits)

        z_block = jnp.einsum("bme,bne->bmn", do_block, v_mat_block)
        term_block = z_block - lu_block[:, :, None] - lv_block[:, None, :]
        ds_block = a_block * term_block

        # Explicit transposes keep the contraction in plain batched-matmul form.
        dk_acc = dk_acc + jnp.einsum(
            "bnm,bmd->bnd", ds_block.transpose(0, 2, 1), q_block
        )
        dv_acc = dv_acc + jnp.einsum(
            "bnm,bme->bne", a_block.transpose(0, 2, 1), do_block
        )
        return (dk_acc, dv_acc)

    def unrolled_body(i, state):
        for u_i in range(UNROLL):
            state = accumulate((i * UNROLL + u_i) * BLOCK_M, state)
        return state

    def tail_body(j, state):
        return accumulate(j * BLOCK_M, state)

    state = (jnp.zeros_like(k_block), jnp.zeros_like(v_mat_block))
    if n_unrolled:
        state = jax.lax.fori_loop(0, n_unrolled, unrolled_body, state)
    # Query blocks left over when iters is not a multiple of UNROLL; these were
    # previously dropped (e.g. M == 128 produced dK == dV == 0).
    if n_unrolled * UNROLL < iters:
        state = jax.lax.fori_loop(n_unrolled * UNROLL, iters, tail_body, state)
    dk_acc, dv_acc = state

    dK_ref[...] = dk_acc * scale
    dV_ref[...] = dv_acc


def pallas_sinkhorn_bwd_fused(u, v, Q, K, V, dO, lambda_u, lambda_v):
    """Compute (dQ, dK, dV) from Sinkhorn potentials and their adjoints.

    With S = Q K^T / sqrt(D) and A = exp(S + u[:, None] + v[None, :]), the two
    Pallas kernels evaluate, per (batch, head):
        dS = A * (dO V^T - lambda_u[:, None] - lambda_v[None, :])
        dQ = dS K / sqrt(D),   dK = dS^T Q / sqrt(D),   dV = A^T dO
    The B and H axes are flattened into a single batch axis.

    Args:
        u, lambda_u: [B, H, M] row potential and its adjoint.
        v, lambda_v: [B, H, N] column potential and its adjoint.
        Q: [B, H, M, D]; K: [B, H, N, D]; V: [B, H, N, E]; dO: [B, H, M, E].

    Returns:
        (dQ, dK, dV) with the shapes of Q, K and V.

    Raises:
        ValueError: if M or N is not a multiple of the 128-row block size.
    """
    B, H, M, D = Q.shape
    _, _, N, E = V.shape

    BLOCK = 128  # Safe Block Size
    BATCH_BLOCK = 8  # Safe Batch Block
    if M % BLOCK or N % BLOCK:
        raise ValueError(
            f"Sequence lengths M={M} and N={N} must be multiples of {BLOCK}."
        )

    Batch = B * H
    # A batch smaller than BATCH_BLOCK uses one block spanning the whole axis.
    # Larger batches are zero-padded up to a multiple of BATCH_BLOCK.
    # Previously Batch // BATCH_BLOCK could be 0 (e.g. B*H == 4), giving an
    # empty grid that never wrote the outputs, and remainder rows were skipped.
    # Every batch row is processed independently, so padded rows are harmless
    # and get sliced off below.
    batch_block = max(1, min(BATCH_BLOCK, Batch))
    pad = (-Batch) % batch_block
    padded_batch = Batch + pad

    def flatten(x, *tail):
        x = x.reshape((Batch,) + tail)
        if pad:
            x = jnp.pad(x, [(0, pad)] + [(0, 0)] * len(tail))
        return x

    Q_f = flatten(Q, M, D)
    K_f = flatten(K, N, D)
    V_f = flatten(V, N, E)
    u_f = flatten(u, M)
    v_f = flatten(v, N)
    dO_f = flatten(dO, M, E)
    lu_f = flatten(lambda_u, M)
    lv_f = flatten(lambda_v, N)

    is_cpu = jax.devices()[0].platform == "cpu"
    kw = (
        {"interpret": True}
        if is_cpu
        else {"compiler_params": resolve_compiler_params(("parallel", "parallel"))}
    )

    # dQ: one program per (batch block, query block); sees all keys/values.
    grid_dq = (padded_batch // batch_block, M // BLOCK)

    dQ_f = pl.pallas_call(
        dq_kernel_batched,
        out_shape=jax.ShapeDtypeStruct(Q_f.shape, Q_f.dtype),
        grid=grid_dq,
        in_specs=[
            pl.BlockSpec(
                index_map=lambda b, m: (b, m, 0), block_shape=(batch_block, BLOCK, D)
            ),
            pl.BlockSpec(
                index_map=lambda b, m: (b, 0, 0), block_shape=(batch_block, N, D)
            ),
            pl.BlockSpec(
                index_map=lambda b, m: (b, 0, 0), block_shape=(batch_block, N, E)
            ),
            pl.BlockSpec(
                index_map=lambda b, m: (b, m), block_shape=(batch_block, BLOCK)
            ),
            pl.BlockSpec(index_map=lambda b, m: (b, 0), block_shape=(batch_block, N)),
            pl.BlockSpec(
                index_map=lambda b, m: (b, m, 0), block_shape=(batch_block, BLOCK, E)
            ),
            pl.BlockSpec(
                index_map=lambda b, m: (b, m), block_shape=(batch_block, BLOCK)
            ),
            pl.BlockSpec(index_map=lambda b, m: (b, 0), block_shape=(batch_block, N)),
        ],
        out_specs=pl.BlockSpec(
            index_map=lambda b, m: (b, m, 0), block_shape=(batch_block, BLOCK, D)
        ),
        **kw,
    )(Q_f, K_f, V_f, u_f, v_f, dO_f, lu_f, lv_f)

    # dK/dV: one program per (batch block, key block); sees all queries.
    grid_dk_dv = (padded_batch // batch_block, N // BLOCK)

    out_shape_k = jax.ShapeDtypeStruct(K_f.shape, K_f.dtype)
    out_shape_v = jax.ShapeDtypeStruct(V_f.shape, V_f.dtype)

    dK_f, dV_f = pl.pallas_call(
        dk_dv_kernel_batched,
        out_shape=(out_shape_k, out_shape_v),
        grid=grid_dk_dv,
        in_specs=[
            pl.BlockSpec(
                index_map=lambda b, n: (b, 0, 0), block_shape=(batch_block, M, D)
            ),
            pl.BlockSpec(
                index_map=lambda b, n: (b, n, 0), block_shape=(batch_block, BLOCK, D)
            ),
            pl.BlockSpec(
                index_map=lambda b, n: (b, n, 0), block_shape=(batch_block, BLOCK, E)
            ),
            pl.BlockSpec(index_map=lambda b, n: (b, 0), block_shape=(batch_block, M)),
            pl.BlockSpec(
                index_map=lambda b, n: (b, n), block_shape=(batch_block, BLOCK)
            ),
            pl.BlockSpec(
                index_map=lambda b, n: (b, 0, 0), block_shape=(batch_block, M, E)
            ),
            pl.BlockSpec(index_map=lambda b, n: (b, 0), block_shape=(batch_block, M)),
            pl.BlockSpec(
                index_map=lambda b, n: (b, n), block_shape=(batch_block, BLOCK)
            ),
        ],
        out_specs=[
            pl.BlockSpec(
                index_map=lambda b, n: (b, n, 0), block_shape=(batch_block, BLOCK, D)
            ),
            pl.BlockSpec(
                index_map=lambda b, n: (b, n, 0), block_shape=(batch_block, BLOCK, E)
            ),
        ],
        **kw,
    )(Q_f, K_f, V_f, u_f, v_f, dO_f, lu_f, lv_f)

    # Drop padded batch rows before restoring [B, H, ...] shapes.
    return (
        dQ_f[:Batch].reshape(Q.shape),
        dK_f[:Batch].reshape(K.shape),
        dV_f[:Batch].reshape(V.shape),
    )


def pallas_sinkhorn_bwd_full(n_iters, res, g):
    """custom_vjp backward rule for `pallas_flash_sinkhorn`.

    Treats the output as O = A V with A = exp(S + u[:, None] + v[None, :]) and
    S = Q K^T / sqrt(D), using the potentials (u, v) saved by the forward pass.
    Steps:
      1. g_u = rowsum(A * (dO V^T)), g_v = colsum(A * (dO V^T)), computed in
         128-row query blocks.
      2. Approximately solve the adjoint system
             lambda_u = g_u - A lambda_v,   lambda_v = g_v - A^T lambda_u
         with 10 alternating fixed-point sweeps (hard-coded below).
      3. Hand (u, v, lambda_u, lambda_v) to `pallas_sinkhorn_bwd_fused` for
         dQ, dK, dV.

    Args:
        n_iters: the forward iteration count (nondiff arg); not used here.
        res: residuals (u, v, Q, K, V) from `pallas_sinkhorn_fwd`.
        g: cotangent of the output O.

    Returns:
        (dQ, dK, dV).

    NOTE(review): the forward pass stores tile-local potentials averaged over
    tiles, so this global A is not guaranteed to match the forward's transport
    plan; gradients are therefore an approximation — confirm acceptable.
    """
    u, v, Q, K, V = res
    dO = g

    # Constants
    scale = 1.0 / jnp.sqrt(Q.shape[-1])
    BLOCK = 128
    M, N = Q.shape[-2], K.shape[-2]  # N is currently unused.

    g_u_init = jnp.zeros_like(u)
    g_v_init = jnp.zeros_like(v)

    # Pass 1: per query block, form T = A * (dO V^T); write its row sums into
    # g_u and accumulate its column sums into g_v.
    def pass1_body(carry, i):
        g_u, g_v = carry

        q_block = jax.lax.dynamic_slice_in_dim(Q, i * BLOCK, BLOCK, axis=-2)
        u_block = jax.lax.dynamic_slice_in_dim(u, i * BLOCK, BLOCK, axis=-1)
        do_block = jax.lax.dynamic_slice_in_dim(dO, i * BLOCK, BLOCK, axis=-2)

        # transpose(0, 1, 3, 2) assumes 4-D [B, H, ., .] inputs.
        Z_block = jnp.matmul(do_block, V.transpose(0, 1, 3, 2))
        S_block = jnp.einsum("bhmd,bhnd->bhmn", q_block, K) * scale
        A_block = jnp.exp(S_block + u_block[..., :, None] + v[..., None, :])
        T_block = A_block * Z_block

        gu_contrib = jnp.sum(T_block, axis=-1)
        gv_contrib = jnp.sum(T_block, axis=-2)

        g_u = jax.lax.dynamic_update_slice_in_dim(g_u, gu_contrib, i * BLOCK, axis=-1)
        g_v = g_v + gv_contrib

        return (g_u, g_v), None

    (g_u, g_v), _ = jax.lax.scan(
        pass1_body, (g_u_init, g_v_init), jnp.arange(M // BLOCK)
    )

    lambda_u = jnp.zeros_like(u)
    lambda_v = jnp.zeros_like(v)

    # A @ vec_v, recomputing A block-by-block over query rows.
    def matvec_A(vec_v):
        def mv_body(carry, i):
            q_block = jax.lax.dynamic_slice_in_dim(Q, i * BLOCK, BLOCK, axis=-2)
            u_block = jax.lax.dynamic_slice_in_dim(u, i * BLOCK, BLOCK, axis=-1)
            S_block = jnp.einsum("bhmd,bhnd->bhmn", q_block, K) * scale
            A_block = jnp.exp(S_block + u_block[..., :, None] + v[..., None, :])
            res = A_block @ vec_v[..., None]
            return carry, res.squeeze(-1)

        # res is stacked as [M // BLOCK, B, H, BLOCK]; concatenating its
        # leading-axis slices along the last axis gives [B, H, M].
        _, res = jax.lax.scan(mv_body, None, jnp.arange(M // BLOCK))
        return jnp.concatenate(res, axis=-1)

    # A^T @ vec_u, accumulated over query-row blocks.
    def matvec_AT(vec_u):
        res = jnp.zeros_like(v)

        def mvt_body(acc, i):
            q_block = jax.lax.dynamic_slice_in_dim(Q, i * BLOCK, BLOCK, axis=-2)
            u_block = jax.lax.dynamic_slice_in_dim(u, i * BLOCK, BLOCK, axis=-1)
            vec_u_block = jax.lax.dynamic_slice_in_dim(vec_u, i * BLOCK, BLOCK, axis=-1)
            S_block = jnp.einsum("bhmd,bhnd->bhmn", q_block, K) * scale
            A_block = jnp.exp(S_block + u_block[..., :, None] + v[..., None, :])
            contrib = jnp.matmul(
                A_block.transpose(0, 1, 3, 2), vec_u_block[..., None]
            ).squeeze(-1)
            return acc + contrib, None

        res, _ = jax.lax.scan(mvt_body, res, jnp.arange(M // BLOCK))
        return res

    # One Gauss-Seidel-style sweep: update lambda_u from lambda_v, then
    # lambda_v from the new lambda_u.
    def adjoint_step(i, state):
        lu, lv = state
        lu = g_u - matvec_A(lv)
        lv = g_v - matvec_AT(lu)
        return (lu, lv)

    # Fixed 10 sweeps, with no convergence check.
    lambda_u, lambda_v = jax.lax.fori_loop(0, 10, adjoint_step, (lambda_u, lambda_v))

    return pallas_sinkhorn_bwd_fused(u, v, Q, K, V, dO, lambda_u, lambda_v)


@partial(jax.custom_vjp, nondiff_argnums=(3,))
def pallas_flash_sinkhorn(Q, K, V, n_iters=20):
    """Sinkhorn attention with a Pallas forward pass and a custom backward pass.

    Args:
        Q, K, V: [B, H, L, D] arrays; L must be a multiple of 128.
        n_iters: Sinkhorn iterations per tile (static, non-differentiable).

    Returns:
        The attention output, shape [B, H, L, D].
    """
    # The forward rule returns (output, residuals); only the output is needed here.
    return pallas_sinkhorn_fwd(Q, K, V, n_iters)[0]


# Forward rule saves residuals; backward rule maps them to (dQ, dK, dV).
pallas_flash_sinkhorn.defvjp(pallas_sinkhorn_fwd, pallas_sinkhorn_bwd_full)


Overwriting pallas_sinkhorn.py


In [5]:
import importlib

import jax
import jax.numpy as jnp

import pallas_sinkhorn

# Reload so a kernel that already imported the module picks up the file
# just rewritten by %%writefile.
importlib.reload(pallas_sinkhorn)
from pallas_sinkhorn import pallas_flash_sinkhorn

# Smoke test; L must be a multiple of the 128-row tile size.
B, H, L, D = 1, 4, 1024, 128
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
Q = jax.random.normal(key_q, (B, H, L, D))
K = jax.random.normal(key_k, (B, H, L, D))
V = jax.random.normal(key_v, (B, H, L, D))

print("Compiling Pallas Sinkhorn...")
out = jax.block_until_ready(pallas_flash_sinkhorn(Q, K, V))
assert out.shape == (B, H, L, D), out.shape
assert bool(jnp.all(jnp.isfinite(out))), "non-finite values in Sinkhorn output"
print("Output Shape:", out.shape)
print("Success!")

Compiling Pallas Sinkhorn...
Output Shape: (1, 4, 1024, 128)
Success!
