In [4]:
import dataclasses

import jax
import jax.numpy as jnp


### Jax Pallas Experimental Flash Attention

Github: https://github.com/jax-ml/jax/blob/922935a916dbcf599226f1ce3081feb6481328c3/jax/experimental/pallas/ops/gpu/attention.py

In [25]:
rng = jax.random.PRNGKey(0)

max_seq_len = 64
B, T, nH, C = 32, 15, 8, 128

q = jax.random.normal(rng, (B, T, nH, C), dtype=jnp.bfloat16)
k = jax.random.normal(rng, (B, T, nH, C), dtype=jnp.bfloat16)
v = jax.random.normal(rng, (B, T, nH, C), dtype=jnp.bfloat16)

In [15]:
from jax.experimental.pallas.ops.gpu.attention import mha

@dataclasses.dataclass(frozen=True, slots=True)
class BlockSizes:
  block_q: int
  block_k: int

pad_width = ((0, 0),  # no padding on the first dimension
             (0, max_seq_len-15),
             (0, 0),
             (0, 0))  # pad two zeros on the right side of the second dimension

q = jnp.pad(q, pad_width, mode='constant', constant_values=0)
k = jnp.pad(k, pad_width, mode='constant', constant_values=0)
v = jnp.pad(v, pad_width, mode='constant', constant_values=0)

block_sizes = BlockSizes(block_q=T, block_k=T)

y = mha(q, k, v, block_sizes=block_sizes, segment_ids=None, causal=True)

TypeError: got an unexpected keyword argument 'block_sizes'

### FlashAttention Jax

Github: https://github.com/nshepperd/flash_attn_jax

In [1]:
!pip install flash-attn-jax

Collecting flash-attn-jax
  Downloading flash_attn_jax-0.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.6 kB)
Downloading flash_attn_jax-0.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (75.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.4/75.4 MB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: flash-attn-jax
Successfully installed flash-attn-jax-0.2.2


In [30]:
from flash_attn_jax import flash_mha

# flash_mha : [n, l, h, d] x [n, lk, hk, d] x [n, lk, hk, d] -> [n, l, h, d]
%timeit flash_mha(q,k,v,softmax_scale=None, is_causal=True, window_size=(-1,-1))

393 µs ± 2.29 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


### Flash Attention - Jax
Github: https://github.com/lucidrains/flash-attention-jax

In [18]:
!pip install flash-attention-jax

Collecting flash-attention-jax
  Downloading flash_attention_jax-0.3.1-py3-none-any.whl.metadata (683 bytes)
Downloading flash_attention_jax-0.3.1-py3-none-any.whl (10 kB)
Installing collected packages: flash-attention-jax
Successfully installed flash-attention-jax-0.3.1


In [31]:
from flash_attention_jax import causal_flash_attention

rng_key = jax.random.PRNGKey(42)

%timeit causal_flash_attention(q, k, v)

262 µs ± 15.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


### Kvax

Github: https://github.com/nebius/kvax

In [32]:
!pip install kvax

Collecting kvax
  Downloading kvax-0.1.0-py3-none-any.whl.metadata (31 kB)
Collecting jax>=0.4.34 (from kvax)
  Downloading jax-0.5.2-py3-none-any.whl.metadata (22 kB)
Collecting jax-triton>=0.2.0 (from kvax)
  Downloading jax_triton-0.2.0-py3-none-any.whl.metadata (3.0 kB)
Collecting jaxlib>=0.4.27 (from chex>=0.1.85->kvax)
  Downloading jaxlib-0.5.1-cp311-cp311-manylinux2014_x86_64.whl.metadata (978 bytes)
Downloading kvax-0.1.0-py3-none-any.whl (41 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.0/42.0 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jax-0.5.2-py3-none-any.whl (2.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m88.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jax_triton-0.2.0-py3-none-any.whl (27 kB)
Downloading jaxlib-0.5.1-cp311-cp311-manylinux2014_x86_64.whl (105.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.1/105.1 MB[0m [31m21.7 MB/s[0m eta [36m0:00:00[