diff --git a/jax/BUILD b/jax/BUILD
index 6b00f1bdd691..db0cc0bcced4 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -605,6 +605,19 @@ pytype_strict_library(
     ],
 )
 
+pytype_strict_library(
+    name = "pallas_gpu_ops",
+    srcs = glob(["experimental/pallas/ops/gpu/**/*.py"]),
+    visibility = [
+        ":pallas_gpu_users",
+    ],
+    deps = [
+        ":jax",
+        ":pallas",
+        ":pallas_gpu",
+    ] + py_deps("numpy"),
+)
+
 pytype_strict_library(
     name = "pallas_tpu_ops",
     srcs = glob(["experimental/pallas/ops/tpu/**/*.py"]),
diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py
new file mode 100644
index 000000000000..7e08836b0f6c
--- /dev/null
+++ b/jax/experimental/pallas/ops/gpu/decode_attention.py
@@ -0,0 +1,345 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Module containing decode attention."""
+from __future__ import annotations
+
+import functools
+from typing import Any
+
+import jax
+from jax import lax
+from jax.experimental import pallas as pl
+import jax.numpy as jnp
+
+
+def attn_forward_kernel(
+    q_ref,  # [num_heads, head_dim]
+    k_ref,  # [k_seq_len, head_dim]
+    v_ref,  # [k_seq_len, head_dim]
+    o_ref: Any,  # [num_heads, head_dim]
+    *residual_refs: Any,  # Residual outputs: [num_heads,], [num_heads,]
+    sm_scale: float,
+    block_k: int,
+):
+  block_h, head_dim = q_ref.shape
+  k_seq_len, _ = k_ref.shape
+  start_q = pl.program_id(0)
+
+  # o is the buffer where we accumulate the output on sram.
+  # m_i and l_i (see FlashAttention2 paper) are updated during the k,v loop.
+  m_i = jnp.zeros(block_h, dtype=jnp.float32) - float("inf")
+  l_i = jnp.zeros(block_h, dtype=jnp.float32)
+  o = jnp.zeros((block_h, head_dim), dtype=jnp.float32)
+
+  # Load q: it will stay in L1 throughout. Indices form a matrix because we
+  # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
+  # q tile has shape [block_h, head_dim].
+  curr_q_slice = pl.dslice(start_q * block_h, block_h)
+  q = pl.load(q_ref, (curr_q_slice, pl.dslice(None)))
+
+  def _dot(a, b):
+    # if a.shape[0] == 1:
+    #   # Use matrix vector product
+    #   return (a.T * b).sum(axis=0, keepdims=True)
+    return pl.dot(a, b)
+
+  # Loop over blocks of kv to process entire kv seq_len.
+  # Grid loops over q blocks over num_heads.
+  def body(start_k, carry):
+    o_prev, m_prev, l_prev = carry
+    curr_k_slice = pl.dslice(start_k * block_k, block_k)
+
+    k = pl.load(k_ref, (curr_k_slice, slice(None)))
+    qk = _dot(q, k.T)  # [block_h, block_k]
+    if sm_scale != 1.0:
+      qk *= sm_scale  # [block_h, block_k]
+
+    m_curr = qk.max(axis=-1)
+    m_next = jnp.maximum(m_prev, m_curr)
+    correction = jnp.exp(m_prev - m_next)
+    l_prev_corr = correction * l_prev
+    s_curr = jnp.exp(
+        qk - m_next[:, None]
+    )  # Use m_next instead of m_curr to avoid a correction on l_curr
+    l_curr = s_curr.sum(axis=-1)
+    l_next = l_prev_corr + l_curr
+    v = pl.load(v_ref, (curr_k_slice, slice(None)))
+    o_curr = _dot(s_curr.astype(v.dtype), v)
+
+    # flash2 unscaled_o
+    o_next = correction[:, None] * o_prev + o_curr
+    return o_next, m_next, l_next
+
+  upper_bound = pl.cdiv(k_seq_len, block_k)  # type: ignore
+  # o is left unscaled; it will be scaled in the final reduction step
+  o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i))
+
+  if residual_refs:
+    l_ref, m_ref = residual_refs
+    pl.store(l_ref, (curr_q_slice,), l_i)
+    pl.store(m_ref, (curr_q_slice,), m_i)
+  # Write output to dram.
+  o = o.astype(o_ref.dtype)
+  pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o)
+
+
+def attn_unbatched(
+    q,  # [num_heads, head_dim]
+    k,  # [k_seq_len, head_dim]
+    v,  # [k_seq_len, head_dim]
+    sm_scale: float,
+    block_h: int,
+    block_k: int,
+    k_splits: int,
+    num_warps: int | None,
+    num_stages: int,
+    grid: tuple[int, ...] | None,
+    interpret: bool,
+    debug: bool,
+):
+  num_heads, head_dim = q.shape
+  k_seq_len, _ = k.shape
+  # Pad num query heads to 16 if needed, and slice output at the end.
+  original_num_heads = None
+  if num_heads < 16:
+    q = jnp.pad(q, ((0, 16 - num_heads), (0, 0)))
+    original_num_heads = num_heads
+    num_heads = q.shape[0]
+  block_h = min(block_h, num_heads)
+  head_splits = pl.cdiv(num_heads, block_h)
+  grid_ = grid
+  if grid_ is None:
+    grid_ = (head_splits, k_splits)
+
+  assert (
+      k_seq_len % k_splits == 0
+  ), f"{k_seq_len=} must be divisible by {k_splits=}"
+  k = k.reshape(k_splits, k_seq_len // k_splits, head_dim)
+  v = v.reshape(k_splits, k_seq_len // k_splits, head_dim)
+  k_seq_len = k_seq_len // k_splits
+  assert min(num_heads, head_dim, k_seq_len) >= 16, "Minimum pl.dot size is 16"
+  block_k = min(block_k, k_seq_len)
+  num_warps_ = num_warps
+  if num_warps_ is None:
+    num_warps_ = 4
+  kernel = functools.partial(
+      attn_forward_kernel,
+      sm_scale=sm_scale,
+      block_k=block_k,
+  )
+
+  o, l, m = pl.pallas_call(
+      kernel,
+      grid=grid_,
+      in_specs=[
+          pl.BlockSpec(lambda i, j: (i, 0), (block_h, head_dim)),
+          pl.BlockSpec(lambda i, j: (j, 0, 0), (None, k_seq_len, head_dim)),
+          pl.BlockSpec(lambda i, j: (j, 0, 0), (None, k_seq_len, head_dim)),
+      ],
+      out_specs=[
+          pl.BlockSpec(lambda i, j: (j, i, 0), (None, block_h, head_dim)),  # o
+          pl.BlockSpec(
+              lambda i, j: (j, i),
+              (
+                  None,
+                  block_h,
+              ),
+          ),  # l
+          pl.BlockSpec(
+              lambda i, j: (j, i),
+              (
+                  None,
+                  block_h,
+              ),
+          ),  # m
+      ],
+      num_warps=num_warps_,
+      num_stages=num_stages,
+      out_shape=[
+          jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype),  # o
+          jax.ShapeDtypeStruct(
+              shape=(k_splits, num_heads), dtype=jnp.float32
+          ),  # l
+          jax.ShapeDtypeStruct(
+              shape=(k_splits, num_heads), dtype=jnp.float32
+          ),  # m
+      ],
+      debug=debug,
+      interpret=interpret,
+      name="mha_forward",
+  )(q, k, v)
+
+  # final round of flash
+  m_next = m.max(axis=0)
+  correction = jnp.exp(m - m_next[None])
+  o = o * correction[:, :, None]
+  l_next = (l * correction).sum(axis=0)
+  o = o.sum(axis=0) / l_next[:, None]
+
+  if original_num_heads is not None:
+    o = o[:original_num_heads, :]
+  return o
+
+
+@functools.partial(
+    jax.jit,
+    static_argnames=[
+        "sm_scale",
+        "block_h",
+        "block_k",
+        "k_splits",
+        "num_warps",
+        "num_stages",
+        "grid",
+        "interpret",
+        "debug",
+    ],
+)
+def mqa(
+    q,  # [batch_size, num_heads, head_dim]
+    k,  # [batch_size, k_seq_len, head_dim]
+    v,  # [batch_size, k_seq_len, head_dim]
+    sm_scale: float = 1.0,
+    block_h: int = 16,
+    block_k: int = 256,
+    k_splits: int = 16,
+    num_warps: int | None = None,
+    num_stages: int = 2,
+    grid: tuple[int, ...] | None = None,
+    interpret: bool = False,
+    debug: bool = False,
+):
+  inner = functools.partial(
+      attn_unbatched,
+      sm_scale=sm_scale,
+      block_h=block_h,
+      block_k=block_k,
+      k_splits=k_splits,
+      num_warps=num_warps,
+      num_stages=num_stages,
+      grid=grid,
+      interpret=interpret,
+      debug=debug,
+  )
+  return jax.vmap(inner)(q, k, v)
+
+
+@functools.partial(
+    jax.jit,
+    static_argnames=[
+        "sm_scale",
+        "block_h",
+        "block_k",
+        "k_splits",
+        "num_warps",
+        "num_stages",
+        "grid",
+        "interpret",
+        "debug",
+    ],
+)
+def gqa(
+    q,  # [batch_size, num_q_heads, head_dim]
+    k,  # [batch_size, k_seq_len, num_kv_heads, head_dim]
+    v,  # [batch_size, k_seq_len, num_kv_heads, head_dim]
+    sm_scale: float = 1.0,
+    block_h: int = 16,
+    block_k: int = 256,
+    k_splits: int = 16,
+    num_warps: int | None = None,
+    num_stages: int = 2,
+    grid: tuple[int, ...] | None = None,
+    interpret: bool = False,
+    debug: bool = False,
+):
+  batch_size, q_heads, head_dim = q.shape
+  kv_heads = k.shape[2]
+  assert kv_heads == v.shape[2]
+  assert q_heads % kv_heads == 0
+  q_heads_per_kv_head = q_heads // kv_heads
+  q_reshaped = q.reshape(batch_size, kv_heads, q_heads_per_kv_head, head_dim)
+  k_transposed = jnp.swapaxes(
+      k, 1, 2
+  )  # [batch_size, num_kv_heads, k_seq_len, head_dim]
+  v_transposed = jnp.swapaxes(
+      v, 1, 2
+  )  # [batch_size, num_kv_heads, k_seq_len, head_dim]
+  inner = functools.partial(
+      attn_unbatched,
+      sm_scale=sm_scale,
+      block_h=block_h,
+      block_k=block_k,
+      k_splits=k_splits,
+      num_warps=num_warps,
+      num_stages=num_stages,
+      grid=grid,
+      interpret=interpret,
+      debug=debug,
+  )
+  with_kv_heads = jax.vmap(inner)
+  o = jax.vmap(with_kv_heads)(q_reshaped, k_transposed, v_transposed)
+  return o.reshape(batch_size, q_heads, head_dim)
+
+
+@functools.partial(jax.jit, static_argnames=["sm_scale"])
+def mqa_reference(
+    q,  # [bs, num_q_heads, head_dim]
+    k,  # [bs, k_seq_len, head_dim]
+    v,  # [bs, k_seq_len, head_dim]
+    sm_scale=1.0,
+):
+  logits = jnp.einsum("bnd,bsd->bns", q, k).astype(jnp.float32)
+  weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
+  return jnp.einsum("bns,bsd->bnd", weights, v)
+
+
+@functools.partial(jax.jit, static_argnames=["sm_scale"])
+def mha_reference(
+    q,  # [bs, num_q_heads, head_dim]
+    k,  # [bs, k_seq_len, num_k_heads, head_dim]
+    v,  # [bs, k_seq_len, num_v_heads, head_dim]
+    sm_scale=1.0,
+):
+  assert q.shape[1] == k.shape[2]
+  logits = jnp.einsum("bnd,bsnd->bns", q, k).astype(jnp.float32)
+  weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
+  return jnp.einsum("bns,bsnd->bnd", weights, v)
+
+
+@functools.partial(jax.jit, static_argnames=["sm_scale"])
+def gqa_reference(
+    q,  # [bs, num_q_heads, head_dim]
+    k,  # [bs, k_seq_len, num_k_heads, head_dim]
+    v,  # [bs, k_seq_len, num_v_heads, head_dim]
+    sm_scale=1.0,
+):
+  bs, num_q_heads, head_dim = q.shape
+  num_kv_heads = k.shape[2]
+  assert num_q_heads % num_kv_heads == 0
+  q_reshaped = q.reshape(
+      bs, num_kv_heads, num_q_heads // num_kv_heads, head_dim
+  )
+  k_transposed = jnp.swapaxes(
+      k, 1, 2
+  )  # [batch_size, num_kv_heads, k_seq_len, head_dim]
+  v_transposed = jnp.swapaxes(
+      v, 1, 2
+  )  # [batch_size, num_kv_heads, k_seq_len, head_dim]
+  logits = jnp.einsum("bkgd,bksd->bkgs", q_reshaped, k_transposed).astype(
+      jnp.float32
+  )
+  weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
+  o = jnp.einsum("bkgs,bksd->bkgd", weights, v_transposed)
+  return o.reshape(bs, num_q_heads, head_dim)
diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD
index af380875202e..20e30295d993 100644
--- a/tests/pallas/BUILD
+++ b/tests/pallas/BUILD
@@ -57,6 +57,39 @@ jax_test(
     ] + py_deps("absl/testing") + py_deps("numpy"),
 )
 
+jax_test(
+    name = "gpu_attention_test",
+    srcs = [
+        "gpu_attention_test.py",
+    ],
+    backend_tags = {
+        "gpu": ["noasan"],  # https://github.com/openai/triton/issues/2918
+    },
+    config_tags_overrides = {
+        "gpu_x32": {
+            "ondemand": False,  # Include in presubmit.
+        },
+    },
+    disable_backends = [
+        "cpu",
+        "tpu",
+    ],
+    disable_configs = [
+        "gpu",
+        "gpu_a100",
+        "gpu_p100",
+    ],
+    enable_configs = [
+        "gpu_x32",
+        "gpu_a100_x32",
+    ],
+    shard_count = 1,
+    deps = [
+        "//third_party/py/jax:pallas_gpu",
+        "//third_party/py/jax:pallas_gpu_ops",
+    ] + py_deps("absl/testing") + py_deps("numpy"),
+)
+
 jax_test(
     name = "indexing_test",
     srcs = [
diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py
new file mode 100644
index 000000000000..3ff740227a24
--- /dev/null
+++ b/tests/pallas/gpu_attention_test.py
@@ -0,0 +1,159 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import jax
+from jax import random
+from jax._src import config
+from jax._src import test_util as jtu
+from jax.experimental.pallas.ops.gpu import decode_attention
+import jax.numpy as jnp
+import numpy as np
+
+
+os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
+
+try:
+  from jax.experimental.pallas import gpu as plgpu
+except ImportError:
+  pass
+# pylint: disable=no-value-for-parameter
+
+
+config.update("jax_traceback_filtering", "off")
+config.parse_flags_with_absl()
+
+
+class PallasTest(jtu.JaxTestCase):
+
+  def check_gpu_capability_at_least(self, capability, device: int = 0):
+    return plgpu.get_compute_capability(device) >= capability
+
+  def setUp(self):
+    if not jtu.test_device_matches(["gpu"]):
+      self.skipTest("Only works on GPU")
+    try:
+      import triton  # noqa: F401
+    except ImportError:
+      self.skipTest("Triton is not installed. Skipping PallasTest.")
+    super().setUp()
+
+class DecodeAttentionTest(PallasTest):
+
+  @parameterized.named_parameters(*[
+      (
+          f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}",
+          batch_size,
+          seq_len,
+          num_heads,
+          head_dim,
+          kwargs,
+      )
+      for (
+          batch_size,
+          seq_len,
+          num_heads,
+          head_dim,
+          kwargs,
+      ) in [
+          (1, 1024, 1, 64, {}),
+          (2, 1024, 2, 64, {}),
+          (1, 1024, 8, 64, {}),
+      ]
+  ])
+  @jax.numpy_dtype_promotion("standard")
+  def test_mqa(
+      self,
+      batch_size,
+      seq_len,
+      num_heads,
+      head_dim,
+      kwargs,
+  ):
+    del kwargs
+    if not self.check_gpu_capability_at_least(80):
+      raise unittest.SkipTest(
+          "Fused attention only works on GPUs with capability >= sm80"
+      )
+
+    k1, k2, k3 = random.split(random.key(0), 3)
+    q = random.normal(k1, (batch_size, num_heads, head_dim), dtype=jnp.float16)
+    k = random.normal(k2, (batch_size, seq_len, head_dim), dtype=jnp.float16)
+    v = random.normal(k3, (batch_size, seq_len, head_dim), dtype=jnp.float16)
+
+    o = decode_attention.mqa(q, k, v)
+    o_ref = decode_attention.mqa_reference(q, k, v)
+    np.testing.assert_allclose(o, o_ref, atol=0.05)
+
+  @parameterized.named_parameters(*[
+      (
+          f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}_{kwargs=}",
+          batch_size,
+          seq_len,
+          num_q_heads,
+          num_kv_heads,
+          head_dim,
+          kwargs,
+      )
+      for (
+          batch_size,
+          seq_len,
+          num_q_heads,
+          num_kv_heads,
+          head_dim,
+          kwargs,
+      ) in [
+          (1, 1024, 16, 4, 64, {}),
+          (1, 1024, 16, 16, 64, {}),
+          (1, 1024, 32, 32, 64, {}),
+      ]
+  ])
+  @jax.numpy_dtype_promotion("standard")
+  def test_gqa(
+      self,
+      batch_size,
+      seq_len,
+      num_q_heads,
+      num_kv_heads,
+      head_dim,
+      kwargs,
+  ):
+    del kwargs
+    if not self.check_gpu_capability_at_least(80):
+      raise unittest.SkipTest(
+          "Fused attention only works on GPUs with capability >= sm80"
+      )
+
+    k1, k2, k3 = random.split(random.key(0), 3)
+    q = random.normal(
+        k1, (batch_size, num_q_heads, head_dim), dtype=jnp.float16
+    )
+    k = random.normal(
+        k2, (batch_size, seq_len, num_kv_heads, head_dim), dtype=jnp.float16
+    )
+    v = random.normal(
+        k3, (batch_size, seq_len, num_kv_heads, head_dim), dtype=jnp.float16
+    )
+
+    o = decode_attention.gqa(q, k, v)
+    o_ref = decode_attention.gqa_reference(q, k, v)
+    np.testing.assert_allclose(o, o_ref, atol=0.05)
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())
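
For reference, a minimal, hypothetical usage sketch (not part of the patch) of the kernels added above. It assumes a GPU with compute capability >= sm80 and Triton installed, and it mirrors the shapes and tolerance exercised in gpu_attention_test.py; the variable names and the explicit k_splits argument are illustrative only (k_splits=16 is already the default).

# Illustrative sketch: compare the Pallas grouped-query decode-attention
# kernel against its XLA reference.
import jax.numpy as jnp
from jax import random
from jax.experimental.pallas.ops.gpu import decode_attention

batch, seq_len, num_q_heads, num_kv_heads, head_dim = 1, 1024, 16, 4, 64
k1, k2, k3 = random.split(random.key(0), 3)
q = random.normal(k1, (batch, num_q_heads, head_dim), dtype=jnp.float16)
k = random.normal(k2, (batch, seq_len, num_kv_heads, head_dim), dtype=jnp.float16)
v = random.normal(k3, (batch, seq_len, num_kv_heads, head_dim), dtype=jnp.float16)

# 16 query heads share 4 KV heads; k_splits partitions the 1024-long KV
# sequence across the second grid axis, so each program sees 1024 / 16 = 64 keys.
o = decode_attention.gqa(q, k, v, k_splits=16)
o_ref = decode_attention.gqa_reference(q, k, v)
print(jnp.max(jnp.abs(o - o_ref)))  # expected to be small (within ~0.05 in fp16)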