In [1]:
import dataclasses

import jax
import jax.numpy as jnp


### JAX Pallas Experimental Flash Attention

GitHub: https://github.com/jax-ml/jax/blob/922935a916dbcf599226f1ce3081feb6481328c3/jax/experimental/pallas/ops/gpu/attention.py

In [2]:
# Fixed seed for reproducibility. The key is split into independent subkeys:
# passing the same key to jax.random.normal three times would make q, k and v
# bit-for-bit identical.
rng = jax.random.PRNGKey(0)
rng_q, rng_k, rng_v = jax.random.split(rng, 3)

max_seq_len = 64  # sequence length that q/k/v are later zero-padded to
B, T, nH, C = 32, 15, 8, 128  # presumably batch, seq len, heads, head dim — matches the [n, l, h, d] layout used below

q = jax.random.normal(rng_q, (B, T, nH, C), dtype=jnp.bfloat16)
k = jax.random.normal(rng_k, (B, T, nH, C), dtype=jnp.bfloat16)
v = jax.random.normal(rng_v, (B, T, nH, C), dtype=jnp.bfloat16)

In [3]:
from jax.experimental.pallas.ops.gpu.attention import mha

@dataclasses.dataclass(frozen=True, slots=True)
class BlockSizes:
  block_q: int
  block_k: int

pad_width = ((0, 0),  # no padding on the first dimension
             (0, max_seq_len-T),
             (0, 0),
             (0, 0))  # pad two zeros on the right side of the second dimension

q = jnp.pad(q, pad_width, mode='constant', constant_values=0)
k = jnp.pad(k, pad_width, mode='constant', constant_values=0)
v = jnp.pad(v, pad_width, mode='constant', constant_values=0)

block_sizes = BlockSizes(block_q=T, block_k=T)

#y = mha(q, k, v, block_sizes=block_sizes, segment_ids=None, causal=True)
%timeit mha(q, k, v, segment_ids=None, causal=True)

684 μs ± 53.7 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


### FlashAttention JAX (`flash_attn_jax`)

GitHub: https://github.com/nshepperd/flash_attn_jax

In [4]:
!export CUDA_HOME=/usr/
!uv add flash-attn-jax

  pid, fd = os.forkpty()


[2K[2mResolved [1m186 packages[0m [2min 137ms[0m[0m                                       [0m
[2K  [31m×[0m Failed to build `flash-attn-jax==0.2.2`                                             
[31m  ├─▶ [0mThe build backend returned an error
[31m  ╰─▶ [0mCall to `setuptools.build_meta:__legacy__.build_wheel` failed (exit
[31m      [0mstatus: 1)

[31m      [0m[31m[stderr][39m
[31m      [0mfatal: invalid gitfile format: /home/ubuntu/.cache/uv/sdists-v8/.git
[31m      [0mTraceback (most recent call last):
[31m      [0m  File [35m"<string>"[0m, line [35m14[0m, in [35m<module>[0m
[31m      [0m    requires = get_requires_for_build({})
[31m      [0m  File
[31m      [0m[35m"/home/ubuntu/.cache/uv/builds-v0/.tmp82FD9r/lib/python3.13/site-packages/setuptools/build_meta.py"[0m,
[31m      [0mline [35m334[0m, in [35mget_requires_for_build_wheel[0m
[31m      [0m    return [31mself._get_build_requires[0m[1;31m(config_settings, requirements=[])[0m


In [5]:
from flash_attn_jax import flash_mha

# flash_mha : [n, l, h, d] x [n, lk, hk, d] x [n, lk, hk, d] -> [n, l, h, d]
%timeit flash_mha(q,k,v,softmax_scale=None, is_causal=True, window_size=(-1,-1))

ModuleNotFoundError: No module named 'flash_attn_jax'

### Flash Attention - JAX
GitHub: https://github.com/lucidrains/flash-attention-jax

In [None]:
!pip install flash-attention-jax

In [None]:
from flash_attention_jax import causal_flash_attention

rng_key = jax.random.PRNGKey(42)

%timeit causal_flash_attention(q, k, v)

### Kvax

GitHub: https://github.com/nebius/kvax

In [None]:
!pip install kvax

### NVIDIA TransformerEngine

GitHub: https://github.com/NVIDIA/TransformerEngine

In [None]:
!pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

