# Setup

Run the following cell to install JAX with TPU support, along with Flax and Optax. **After running this cell, you must restart the kernel** so that the newly installed packages are picked up.

In [1]:
# Install JAX for TPU
!pip install "jax[tpu]>=0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install flax optax

import os
print("Installation complete. Please restart the runtime/kernel now.")
# os.kill(os.getpid(), 9) # Uncomment to automatically restart in some environments

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html
Collecting libtpu==0.0.23 (from jax[tpu]>=0.4.30)
  Downloading libtpu-0.0.23-cp312-cp312-manylinux_2_31_x86_64.whl.metadata (1.1 kB)
Downloading libtpu-0.0.23-cp312-cp312-manylinux_2_31_x86_64.whl (155.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m155.1/155.1 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: libtpu
  Attempting uninstall: libtpu
    Found existing installation: libtpu 0.0.21
    Uninstalling libtpu-0.0.21:
      Successfully uninstalled libtpu-0.0.21
Successfully installed libtpu-0.0.23
Installation complete. Please restart the runtime/kernel now.


# JAX Flash Sinkhorn Implementation

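This notebook contains a JAX implementation of Sinkhorn attention, used as a reference for validation.

- The attention logits are normalised with log-domain Sinkhorn–Knopp iterations.
- The output is computed in blocks of query rows with `jax.lax.scan`.
- Gradients come from a custom VJP. It solves the adjoint (implicit-differentiation) equations of the Sinkhorn fixed point instead of backpropagating through the iterations.

Note that the Sinkhorn iterations themselves operate on the full `[B, H, M, N]` score matrix. Only the output computation and the backward pass are blocked.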

Running the next cell writes the implementation to `flash_sinkhorn.py` (via `%%writefile`). The cell after it imports the module and runs a smoke test.

In [2]:
%%writefile flash_sinkhorn.py
import jax
import jax.numpy as jnp
from functools import partial


@jax.jit
def sinkhorn_fwd(Q, K, V, n_iters=20):
    """Sinkhorn attention forward pass (also used as the custom-VJP ``fwd`` rule).

    Runs ``n_iters`` log-domain Sinkhorn-Knopp sweeps on the scaled logits
    ``S = Q K^T / sqrt(D)`` to obtain row potentials ``u`` and column potentials
    ``v``. It then computes ``O = P @ V`` with
    ``P = exp(S + u[..., :, None] + v[..., None, :])``, one block of query rows
    at a time.

    Args:
      Q: queries, ``[B, H, M, D]``.
      K: keys, ``[B, H, N, D]``.
      V: values, ``[B, H, N, E]``.
      n_iters: number of (row update, column update) sweeps. Under ``jax.jit``
        this argument is traced, not static.

    Returns:
      ``(O, (u, v, Q, K, V))``: the ``[B, H, M, E]`` output and the residuals
      consumed by the backward rule. Each sweep ends with the column update, so
      the columns of ``P`` sum to one.
    """
    scale = 1.0 / jnp.sqrt(Q.shape[-1])
    # The scaled logits do not change across sweeps. Compute them once instead
    # of re-running the full [B, H, M, N] matmul twice per iteration.
    S = jnp.einsum("bhmd,bhnd->bhmn", Q, K) * scale

    # Initial potentials (zeros)
    u = jnp.zeros(Q.shape[:-1])  # [B, H, M]
    v = jnp.zeros(K.shape[:-1])  # [B, H, N]

    def body_fn(_, potentials):
        u, v = potentials
        # Row update: with the current v, every row of exp(S + u + v) sums to one.
        u = -jax.scipy.special.logsumexp(S + v[..., None, :], axis=-1)
        # Column update (uses the fresh u): every column sums to one.
        v = -jax.scipy.special.logsumexp(S + u[..., :, None], axis=-2)
        return (u, v)

    u, v = jax.lax.fori_loop(0, n_iters, body_fn, (u, v))

    # Compute the output blockwise over query rows.
    BLOCK = 128
    M = Q.shape[-2]
    # A fixed 128-row window cannot be sliced when M < BLOCK. When M is larger
    # but not a multiple of BLOCK, the trailing M % BLOCK rows would be dropped.
    # In both cases fall back to a single block. Its footprint is comparable to
    # the Sinkhorn loop above, which already holds all of S.
    block = BLOCK if M % BLOCK == 0 else M

    def scan_fn(carry, i):
        S_block = jax.lax.dynamic_slice_in_dim(S, i * block, block, axis=-2)
        u_block = jax.lax.dynamic_slice_in_dim(u, i * block, block, axis=-1)
        A_block = jnp.exp(S_block + u_block[..., :, None] + v[..., None, :])
        return carry, A_block @ V  # [B, H, block, E]

    _, O_blocks = jax.lax.scan(scan_fn, None, jnp.arange(M // block))
    # scan stacks the blocks on a new leading axis; merge them back along rows.
    O = jnp.concatenate(O_blocks, axis=-2)

    return O, (u, v, Q, K, V)


# Backward Pass
def sinkhorn_bwd(n_iters, res, g):
    """Custom-VJP backward rule for the Sinkhorn attention forward pass.

    Differentiates through the Sinkhorn fixed point implicitly instead of
    unrolling the forward iterations. With ``P = exp(S + u[:, None] + v[None, :])``,
    ``S = Q K^T * scale`` and ``Z = dO V^T``:

      1. ``dV = P^T dO``, and the direct potential gradients are
         ``g_u = sum_j (P * Z)`` and ``g_v = sum_i (P * Z)``.
      2. Solve the adjoint system ``lu + P lv = g_u``, ``P^T lu + lv = g_v``
         with a fixed number of alternating (block Gauss-Seidel) sweeps.
      3. ``dS = P * (Z - lu[:, None] - lv[None, :])``, then
         ``dQ = dS K * scale`` and ``dK = dS^T Q * scale``.

    Every pass recomputes row blocks of ``P`` inside ``lax.scan``. The full
    ``[B, H, M, N]`` plan is materialised only when ``M`` is not a multiple of
    the block size.

    Args:
      n_iters: forward iteration count, passed first because it is a
        ``nondiff_argnums`` argument of the primal; unused here.
      res: residuals ``(u, v, Q, K, V)`` returned by the forward rule.
      g: cotangent ``dO`` of the output, ``[B, H, M, E]``.

    Returns:
      ``(dQ, dK, dV)`` with the shapes and dtypes of ``Q``, ``K`` and ``V``.
    """
    del n_iters  # only present because of nondiff_argnums
    u, v, Q, K, V = res
    dO = g
    scale = 1.0 / jnp.sqrt(Q.shape[-1])
    BLOCK = 128
    M = Q.shape[-2]
    # Fall back to a single row block when M is not a multiple of BLOCK.
    # Slicing fixed 128-row windows would otherwise fail (M < BLOCK) or
    # silently drop the trailing M % BLOCK rows.
    block = BLOCK if M % BLOCK == 0 else M
    block_ids = jnp.arange(M // block)
    # Adjoint sweeps. Eliminating lu shows the iteration matrix for lv is
    # P^T P, so error shrinks by roughly the squared second singular value of P
    # per sweep. Strongly peaked plans may need more sweeps.
    n_adjoint_iters = 10
    # Accumulate in the promoted dtype (the potentials are typically float32)
    # and cast back at the end. This lets low-precision inputs such as
    # bfloat16 work, and gives cotangents whose dtypes match the primals.
    acc_dtype = jnp.result_type(u, v, Q, K, V, dO)

    def row_block(x, i):
        """Rows [i * block, (i + 1) * block) of x along axis -2."""
        return jax.lax.dynamic_slice_in_dim(x, i * block, block, axis=-2)

    def vec_block(x, i):
        """Entries [i * block, (i + 1) * block) of x along the last axis."""
        return jax.lax.dynamic_slice_in_dim(x, i * block, block, axis=-1)

    def plan_block(i):
        """Recompute row block i of P; returns (q_block, P_block)."""
        q_block = row_block(Q, i)
        S_block = jnp.einsum("bhmd,bhnd->bhmn", q_block, K) * scale
        A_block = jnp.exp(S_block + vec_block(u, i)[..., :, None] + v[..., None, :])
        return q_block, A_block

    def z_block(i):
        """Row block i of dO and of Z = dO @ V^T."""
        do_block = row_block(dO, i)
        return do_block, jnp.matmul(do_block, jnp.swapaxes(V, -1, -2))

    # --- Pass 1: dV and direct gradients w.r.t. the potentials ---
    def pass1_body(carry, i):
        g_v_acc, dV_acc = carry
        _, A_block = plan_block(i)
        do_block, Z_block = z_block(i)
        T_block = A_block * Z_block
        dV_acc = dV_acc + jnp.matmul(
            jnp.swapaxes(A_block, -1, -2), do_block
        ).astype(acc_dtype)
        g_v_acc = g_v_acc + jnp.sum(T_block, axis=-2).astype(acc_dtype)
        # Row sums belong to this block only, so emit them as scan outputs.
        return (g_v_acc, dV_acc), jnp.sum(T_block, axis=-1).astype(acc_dtype)

    pass1_init = (
        jnp.zeros(v.shape, dtype=acc_dtype),
        jnp.zeros(V.shape, dtype=acc_dtype),
    )
    (g_v, dV), g_u_blocks = jax.lax.scan(pass1_body, pass1_init, block_ids)
    g_u = jnp.concatenate(g_u_blocks, axis=-1)  # [B, H, M]

    # --- Solve adjoints: lu + P lv = g_u,  P^T lu + lv = g_v ---
    def matvec_A(vec_v):
        """P @ vec_v assembled block by block -> [B, H, M]."""
        def mv_body(carry, i):
            _, A_block = plan_block(i)
            out = jnp.einsum("bhmn,bhn->bhm", A_block, vec_v)
            return carry, out.astype(acc_dtype)

        _, out_blocks = jax.lax.scan(mv_body, None, block_ids)
        return jnp.concatenate(out_blocks, axis=-1)

    def matvec_AT(vec_u):
        """P^T @ vec_u accumulated over row blocks -> [B, H, N]."""
        def mvt_body(acc, i):
            _, A_block = plan_block(i)
            contrib = jnp.einsum("bhmn,bhm->bhn", A_block, vec_block(vec_u, i))
            return acc + contrib.astype(acc_dtype), None

        out, _ = jax.lax.scan(mvt_body, jnp.zeros(v.shape, dtype=acc_dtype), block_ids)
        return out

    def adjoint_step(_, lam):
        lam_u, lam_v = lam
        lam_u = g_u - matvec_A(lam_v)
        lam_v = g_v - matvec_AT(lam_u)
        return lam_u, lam_v

    lambda_u, lambda_v = jax.lax.fori_loop(
        0, n_adjoint_iters, adjoint_step, (jnp.zeros_like(g_u), jnp.zeros_like(g_v))
    )

    # --- Pass 2: dQ, dK ---
    def pass2_body(dK_acc, i):
        q_block, A_block = plan_block(i)
        _, Z_block = z_block(i)
        lu_block = vec_block(lambda_u, i)
        dS_block = A_block * (Z_block - lu_block[..., :, None] - lambda_v[..., None, :])
        dq_block = (jnp.matmul(dS_block, K) * scale).astype(acc_dtype)
        dK_acc = dK_acc + (
            jnp.matmul(jnp.swapaxes(dS_block, -1, -2), q_block) * scale
        ).astype(acc_dtype)
        return dK_acc, dq_block

    dK, dQ_blocks = jax.lax.scan(
        pass2_body, jnp.zeros(K.shape, dtype=acc_dtype), block_ids
    )
    dQ = jnp.concatenate(dQ_blocks, axis=-2)  # [B, H, M, D]

    return dQ.astype(Q.dtype), dK.astype(K.dtype), dV.astype(V.dtype)


@partial(jax.custom_vjp, nondiff_argnums=(3,))
def flash_sinkhorn(Q, K, V, n_iters=20):
    """Sinkhorn attention with a custom (implicit-differentiation) VJP.

    Q: [B, H, M, D], K: [B, H, N, D], V: [B, H, N, E] -> output [B, H, M, E].
    ``n_iters`` is declared non-differentiable via ``nondiff_argnums=(3,)``, so
    it is handed to the backward rule as its first argument. It is presumably
    meant to be a plain Python int.
    """
    return sinkhorn_fwd(Q, K, V, n_iters)[0]


# The fwd rule shares the primal signature; the bwd rule receives
# (n_iters, residuals, dO).
flash_sinkhorn.defvjp(sinkhorn_fwd, sinkhorn_bwd)


Writing flash_sinkhorn.py


In [4]:
import jax
import jax.numpy as jnp

from flash_sinkhorn import flash_sinkhorn

# Smoke test for the forward and custom backward passes. Failures raise
# instead of being swallowed, so a broken module cannot look like a pass.
# Use independent keys: a single shared key would make Q, K and V the
# identical tensor, which hides bugs (e.g. S = Q Q^T is symmetric).
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
B, H, L, D = 1, 4, 1024, 64
Q = jax.random.normal(key_q, (B, H, L, D))
K = jax.random.normal(key_k, (B, H, L, D))
V = jax.random.normal(key_v, (B, H, L, D))

print("Compiling Flash Sinkhorn...")
out = jax.block_until_ready(flash_sinkhorn(Q, K, V))
assert out.shape == (B, H, L, D), out.shape
assert bool(jnp.all(jnp.isfinite(out))), "forward pass produced non-finite values"
print("Output Shape:", out.shape)

# Also exercise the custom VJP: the gradients must exist, match the input
# shapes, and be finite.
grads = jax.block_until_ready(
    jax.grad(lambda q, k, v: flash_sinkhorn(q, k, v).sum(), argnums=(0, 1, 2))(Q, K, V)
)
assert all(gr.shape == x.shape for gr, x in zip(grads, (Q, K, V)))
assert all(bool(jnp.all(jnp.isfinite(gr))) for gr in grads), "backward pass produced non-finite values"
print("Success!")

Compiling Flash Sinkhorn...
Output Shape: (1, 4, 1024, 64)
Success!
