From 5a578cba197208c73081076179871d86123c6686 Mon Sep 17 00:00:00 2001 From: Tao Wang Date: Tue, 29 Aug 2023 13:58:37 -0700 Subject: [PATCH] Add segment_ids support to pallas flash attention on GPU. PiperOrigin-RevId: 561130379 --- jax/experimental/pallas/ops/attention.py | 361 ++++++++++++++++++----- tests/pallas/pallas_test.py | 161 +++++++--- 2 files changed, 403 insertions(+), 119 deletions(-) diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index 2cf5cdddb98e..0a9ef0de3f72 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -13,22 +13,33 @@ # limitations under the License. """Module containing fused attention forward and backward pass.""" -import functools +from __future__ import annotations +import functools from typing import Any, Optional import jax -import jax.numpy as jnp from jax import lax - from jax.experimental import pallas as pl +import jax.numpy as jnp +import numpy as np + +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) + def mha_forward_kernel( - q_ref, k_ref, v_ref, # Input arrays - o_ref, # Output - *residual_refs, # Residual outputs - sm_scale: float, causal: bool, - block_q: int, block_d: int, block_k: int): + q_ref, + k_ref, + v_ref, # Input arrays + segment_ids_ref: jax.Array | None, # segment_id arrays + o_ref: Any, # Output + *residual_refs: Any, # Residual outputs + sm_scale: float, + causal: bool, + block_q: int, + block_d: int, + block_k: int, +): seq_len = q_ref.shape[0] start_q = pl.program_id(0) @@ -43,6 +54,11 @@ def mha_forward_kernel( # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. # q tile has shape [block_q, block_d], block_d == head_dim. q = pl.load(q_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None))) + q_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (pl.dslice(start_q * block_q, block_q),)) + ) # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size # (Bc == block_k here), and fast over blocks of q (size Br == block_q here). # Here we only loop over blocks of kv to process entire seq_len, the loop over @@ -51,18 +67,33 @@ def body(start_k, carry): acc, m_prev, l_prev = carry k = pl.load(k_ref, (pl.dslice(start_k * block_k, block_k), slice(None))) + kv_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (pl.dslice(start_k * block_k, block_k),)) + ) qk = jnp.zeros([block_q, block_k], dtype=jnp.float32) qk += pl.dot(q, k.T) # [block_q, block_k] if sm_scale != 1.: qk *= sm_scale # [block_q, block_k] - if causal: - span_q = start_q * block_q + jnp.arange(block_q) - span_k = start_k * block_k + jnp.arange(block_k) - qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf')) # Bring closer to XLA:GPU numerics. qk = qk.astype(q_ref.dtype) qk = qk.astype(jnp.float32) + + if causal or segment_ids_ref is not None: + mask = None + if segment_ids_ref is not None: + mask = segment_mask(q_segment_ids, kv_segment_ids) + if causal: + span_q = start_q * block_q + jnp.arange(block_q) + span_k = start_k * block_k + jnp.arange(block_k) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = ( + causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + ) + # Apply mask to qk. 
+ qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev) l_prev *= jnp.exp(m_prev - m_curr) p = jnp.exp(qk - m_curr[:, None]) @@ -91,22 +122,55 @@ def body(start_k, carry): acc = acc.astype(o_ref.dtype) pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc) -@functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) -@functools.partial(jax.jit, static_argnames=["sm_scale", "causal", "block_q", "block_k", - "backward_pass_impl", - "num_warps", "num_stages", "grid", - "interpret", "debug"]) -def mha(q, k, v, - sm_scale: float = 1.0, - causal: bool = False, - block_q: int = 128, - block_k: int = 128, - backward_pass_impl: str = "triton", - num_warps: Optional[int] = None, - num_stages: int = 2, - grid=None, - interpret: bool = False, - debug: bool = False): + +def segment_mask( + q_segment_ids: jax.Array, + kv_segment_ids: jax.Array, +): + # [B, T, 1] or [T, 1] + q_segment_ids = jnp.expand_dims(q_segment_ids, axis=-1) + # [B, 1, S] or [1, S] + if kv_segment_ids.ndim == 1: + kv_segment_ids = jnp.expand_dims(kv_segment_ids, axis=0) + else: + kv_segment_ids = jnp.expand_dims(kv_segment_ids, axis=1) + return jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_) + + +@functools.partial( + jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13] +) +@functools.partial( + jax.jit, + static_argnames=[ + "sm_scale", + "causal", + "block_q", + "block_k", + "backward_pass_impl", + "num_warps", + "num_stages", + "grid", + "interpret", + "debug", + ], +) +def mha( + q, + k, + v, + segment_ids: jnp.ndarray | None, + sm_scale: float = 1.0, + causal: bool = False, + block_q: int = 128, + block_k: int = 128, + backward_pass_impl: str = "triton", + num_warps: Optional[int] = None, + num_stages: int = 2, + grid=None, + interpret: bool = False, + debug: bool = False, +): del backward_pass_impl batch_size, seq_len, num_heads, head_dim = q.shape block_q = min(block_q, seq_len) @@ -123,26 +187,56 @@ def mha(q, k, v, block_q=block_q, block_k=block_k, block_d=head_dim, causal=causal) + + in_specs = [ + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + ] + in_specs.append( + None # type: ignore[arg-type] + if segment_ids is None + else pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len)) + ) out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype) return pl.pallas_call( kernel, grid=grid_, - in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - ], - out_specs=pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + in_specs=in_specs, + out_specs=pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), num_warps=num_warps_, num_stages=num_stages, out_shape=out_shape, debug=debug, interpret=interpret, - name="mha_forward")(q, k, v) + name="mha_forward", + )(q, k, v, segment_ids) + -def _mha_forward(q, k, v, sm_scale: float, causal: bool, block_q: int, - block_k: int, backward_pass_impl: str, num_warps: Optional[int], - num_stages: int, grid: Any, interpret: bool, debug: bool): +def _mha_forward( + q, + k, + v, + segment_ids: jax.Array | None, + sm_scale: float, + 
causal: bool, + block_q: int, + block_k: int, + backward_pass_impl: str, + num_warps: Optional[int], + num_stages: int, + grid: Any, + interpret: bool, + debug: bool, +): del backward_pass_impl batch_size, seq_len, num_heads, head_dim = q.shape block_q = min(block_q, seq_len) @@ -165,26 +259,42 @@ def _mha_forward(q, k, v, sm_scale: float, causal: bool, block_q: int, jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # m dtype=jnp.float32) ] + in_specs = [ + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + ] + in_specs.append( + None # type: ignore[arg-type] + if segment_ids is None + else pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len)) + ) out, l, m = pl.pallas_call( kernel, grid=grid_, - in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - ], + in_specs=in_specs, out_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec( + lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), ], num_warps=num_warps_, num_stages=num_stages, out_shape=out_shape, debug=debug, interpret=interpret, - name="mha_forward")(q, k, v) - return out, (q, k, v, out, l, m) + name="mha_forward", + )(q, k, v, segment_ids) + return out, (q, k, v, segment_ids, out, l, m) + def _preprocess_backward_kernel(out_ref, dout_ref, l_ref, new_dout_ref, delta_ref, *, @@ -231,14 +341,29 @@ def _preprocess_backward(out, do, l, block_q: int, name="mha_preprocess_backward")(out, do, l) return do_scaled, delta + def mha_backward_kernel( # Inputs - q_ref, k_ref, v_ref, out_ref, do_scaled_ref, - l_ref, m_ref, delta_ref, _, + q_ref, + k_ref, + v_ref, + segment_ids_ref: jax.Array | None, + out_ref, + do_scaled_ref, + l_ref, + m_ref, + delta_ref, + _, # Outputs - dq_ref, dk_ref, dv_ref, - *, sm_scale: float, causal: bool, - block_q: int, block_d: int, block_k: int + dq_ref, + dk_ref, + dv_ref, + *, + sm_scale: float, + causal: bool, + block_q: int, + block_d: int, + block_k: int, ): del out_ref, l_ref # Not needed seq_len = q_ref.shape[0] @@ -250,6 +375,11 @@ def outer_loop(start_k, _): k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None))) v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None))) span_k = start_k * block_k + jnp.arange(block_k) + kv_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (pl.ds(start_k * block_k, block_k),)) + ) def inner_loop(start_q, carry): dv, dk = carry @@ -259,9 +389,28 @@ def inner_loop(start_q, carry): qk = qk.astype(jnp.float32) if sm_scale != 1.0: qk *= sm_scale - if causal: - span_q = start_q * block_q + jnp.arange(block_q) - qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf')) + + q_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (pl.ds(start_q * block_q, block_q),)) + ) + + if causal or segment_ids_ref is not None: + mask = None + if segment_ids_ref is not None: 
+ mask = segment_mask(q_segment_ids, kv_segment_ids) + + if causal: + span_q = start_q * block_q + jnp.arange(block_q) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = ( + causal_mask + if mask is None + else jnp.logical_and(mask, causal_mask) + ) + qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) + m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),)) p = jnp.exp(qk - m[:, None]) do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None))) @@ -291,12 +440,13 @@ def inner_loop(start_q, carry): slice(None)), dk.astype(dk_ref.dtype)) lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None) + def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, backward_pass_impl: str, num_warps: Optional[int], num_stages: int, grid: Any, interpret: bool, debug: bool, res, do): del num_warps, num_stages, grid - q, k, v, out, l, m = res + q, k, v, segment_ids, out, l, m = res batch_size, seq_len, num_heads, head_dim = q.shape block_q = min(block_q, seq_len) @@ -304,8 +454,13 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret) if backward_pass_impl == "xla": - return jax.vjp(functools.partial(mha_reference, sm_scale=sm_scale, - causal=causal), q, k, v)[1](do) + return jax.vjp( + functools.partial(mha_reference, sm_scale=sm_scale, causal=causal), + q, + k, + v, + segment_ids, + )[1](do) elif backward_pass_impl == "triton": # We accumulate into dq so we need to initialize it to zeros. dq = jnp.zeros(q.shape, jnp.float32) @@ -315,50 +470,94 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, jax.ShapeDtypeStruct(v.shape, v.dtype), ] + in_specs = [ + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + ] + if segment_ids is None: + in_specs.insert(3, None) # type: ignore[arg-type] + input_output_aliases = {8: 0} + else: + in_specs.insert(3, pl.BlockSpec(lambda j, k: (j, 0), (None, seq_len))) + input_output_aliases = {9: 0} grid = (batch_size, num_heads) # TODO(sharadmv): figure out why num_warps=8 doesn't work! 
num_warps = 4 dq, dk, dv = pl.pallas_call( - functools.partial(mha_backward_kernel, block_q=block_q, block_d=head_dim, - block_k=block_k, sm_scale=sm_scale, causal=causal), + functools.partial( + mha_backward_kernel, + block_q=block_q, + block_d=head_dim, + block_k=block_k, + sm_scale=sm_scale, + causal=causal, + ), grid=grid, out_shape=out_shapes, - in_specs=[ - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - ], + in_specs=in_specs, out_specs=[ - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), + pl.BlockSpec( + lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) + ), ], name="mha_backward", debug=debug, interpret=interpret, num_warps=num_warps, num_stages=1, - input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq) + input_output_aliases=input_output_aliases, + )(q, k, v, segment_ids, out, do_scaled, l, m, delta, dq) else: raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}") - return dq.astype(q.dtype), dk, dv + return dq.astype(q.dtype), dk, dv, None mha.defvjp(_mha_forward, _mha_backward) @functools.partial(jax.jit, static_argnames=['sm_scale', 'causal']) -def mha_reference(q, k, v, sm_scale=1.0, causal: bool = False): +def mha_reference( + q, + k, + v, + segment_ids: jnp.ndarray | None, + sm_scale=1.0, + causal: bool = False, +): q_seq_len = q.shape[1] kv_seq_len = k.shape[1] logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32) - if causal: - mask = jnp.tril(jnp.ones((1, 1, q_seq_len, kv_seq_len), dtype=bool)) + mask = None + if segment_ids is not None: + mask = jnp.expand_dims(segment_mask(segment_ids, segment_ids), 1) mask = jnp.broadcast_to(mask, logits.shape) - logits = jnp.where(mask, logits, float('-inf')) + if causal: + causal_mask = jnp.tril(jnp.ones((1, 1, q_seq_len, kv_seq_len), dtype=bool)) + causal_mask = jnp.broadcast_to(causal_mask, logits.shape) + mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + logits = logits if mask is None else jnp.where(mask, logits, float("-inf")) weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype) return jnp.einsum('bhqk,bkhc->bqhc', weights, v) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index c7c679bcdedf..00360b264320 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1410,68 +1410,153 @@ def body(x_ref): class FusedAttentionTest(PallasTest): - @parameterized.named_parameters(*[ - (f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{causal=}_{use_fwd=}", - batch_size, seq_len, num_heads, head_dim, causal, use_fwd) - for batch_size, seq_len, num_heads, head_dim, causal, 
use_fwd in [ - (1, 384, 1, 64, False, False), - (2, 384, 2, 64, False, False), - (1, 384, 1, 64, True, False), - (2, 384, 2, 64, True, False), - (1, 384, 8, 64, True, True), - (2, 384, 8, 64, True, True), + @parameterized.named_parameters( + *[ + ( + ( + f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{causal=}" + f"_{use_fwd=}_{use_segment_ids=}" + ), + batch_size, + seq_len, + num_heads, + head_dim, + causal, + use_fwd, + use_segment_ids, + ) + for ( + batch_size, + seq_len, + num_heads, + head_dim, + causal, + use_fwd, + use_segment_ids, + ) in [ + (1, 384, 1, 64, False, False, True), + (1, 384, 1, 64, False, False, False), + (2, 384, 2, 64, False, False, True), + (1, 384, 1, 64, True, False, True), + (2, 384, 2, 64, True, False, True), + (1, 384, 8, 64, True, True, True), + (1, 384, 8, 64, True, True, False), + (2, 384, 8, 64, True, True, True), + ] ] - ]) - def test_fused_attention_fwd(self, batch_size, seq_len, num_heads, head_dim, - causal, use_fwd): + ) + def test_fused_attention_fwd( + self, + batch_size, + seq_len, + num_heads, + head_dim, + causal, + use_fwd, + use_segment_ids, + ): if plgpu.get_compute_capability(0) < 80: raise unittest.SkipTest( "Fused attention only works on GPUs with capability >= sm80") k1, k2, k3 = random.split(random.PRNGKey(0), 3) - q = random.normal(k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16) - k = random.normal(k2, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16) - v = random.normal(k3, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16) + q = random.normal( + k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + k = random.normal( + k2, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + v = random.normal( + k3, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + if use_segment_ids: + segment_ids_1 = jnp.zeros((batch_size, seq_len // 2), dtype=jnp.int32) + segment_ids_2 = jnp.ones((batch_size, seq_len // 2), dtype=jnp.int32) + segment_ids = jnp.concatenate((segment_ids_1, segment_ids_2), axis=-1) + else: + segment_ids = None if use_fwd: + @jax.jit def impl(q, k, v): - v, _ = jax.vjp(functools.partial(attention.mha, causal=causal), q, k, v) + v, _ = jax.vjp( + functools.partial( + attention.mha, causal=causal, segment_ids=segment_ids + ), + q, + k, + v, + ) return v + else: - impl = functools.partial(attention.mha, causal=causal) + impl = functools.partial( + attention.mha, causal=causal, segment_ids=segment_ids + ) o = impl(q, k, v) - o_ref = attention.mha_reference(q, k, v, causal=causal) + o_ref = attention.mha_reference(q, k, v, segment_ids, causal=causal) np.testing.assert_allclose(o, o_ref, atol=0.05) - @parameterized.named_parameters(*[ - (f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{causal=}", - batch_size, seq_len, num_heads, head_dim, causal) - for batch_size, seq_len, num_heads, head_dim, causal in [ - (1, 384, 1, 32, False), - (2, 384, 2, 32, False), - # TODO(b/283035396): (1, 384, 1, 32, True), - # TODO(b/283035396): (2, 384, 2, 32, True), + @parameterized.named_parameters( + *[ + ( + ( + f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{causal=}_" + f"{use_segment_ids=}" + ), + batch_size, + seq_len, + num_heads, + head_dim, + causal, + use_segment_ids, + ) + for ( + batch_size, + seq_len, + num_heads, + head_dim, + causal, + use_segment_ids, + ) in [ + (1, 384, 1, 32, False, True), + (1, 384, 1, 32, False, False), + (2, 384, 2, 32, False, True), + (2, 384, 2, 32, False, False), + # TODO(b/283035396): (1, 384, 1, 32, True, 
True), + # TODO(b/283035396): (2, 384, 2, 32, True, True), + ] ] - ]) - def test_fused_attention_bwd(self, batch_size, seq_len, num_heads, head_dim, - causal): + ) + def test_fused_attention_bwd( + self, batch_size, seq_len, num_heads, head_dim, causal, use_segment_ids + ): if plgpu.get_compute_capability(0) < 80: raise unittest.SkipTest( "Fused attention only works on GPUs with capability >= sm80") k1, k2, k3 = random.split(random.PRNGKey(0), 3) - q = random.normal(k1, (batch_size, seq_len, num_heads, head_dim), - dtype=jnp.float16) - k = random.normal(k2, (batch_size, seq_len, num_heads, head_dim), - dtype=jnp.float16) - v = random.normal(k3, (batch_size, seq_len, num_heads, head_dim), - dtype=jnp.float16) + q = random.normal( + k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + k = random.normal( + k2, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + v = random.normal( + k3, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16 + ) + if use_segment_ids: + segment_ids_1 = jnp.zeros((batch_size, seq_len // 2), dtype=jnp.int32) + segment_ids_2 = jnp.ones((batch_size, seq_len // 2), dtype=jnp.int32) + segment_ids = jnp.concatenate((segment_ids_1, segment_ids_2), axis=-1) + else: + segment_ids = None def f(q, k, v): - return attention.mha(q, k, v, causal=causal).sum() + return attention.mha(q, k, v, segment_ids, causal=causal).sum() def f_ref(q, k, v): - return attention.mha_reference(q, k, v, causal=causal).sum() + return attention.mha_reference(q, k, v, segment_ids, causal=causal).sum() dq, dk, dv = jax.grad(f, argnums=(0, 1, 2))(q, k, v) dq_ref, dk_ref, dv_ref = jax.grad(f_ref, argnums=(0, 1, 2))(q, k, v)
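
For reference, a minimal usage sketch of the new segment_ids argument introduced by this patch. segment_ids carries one int32 id per token with shape (batch_size, seq_len); tokens attend only within their own segment, and when causal=True the segment mask is logically AND-ed with the causal mask (see segment_mask and the masking branch in mha_forward_kernel above). The shapes and the half/half segment split mirror the tests in pallas_test.py; the import path is inferred from the module touched here (jax/experimental/pallas/ops/attention.py), and, as in the tests, a GPU with compute capability >= sm80 is assumed. This is a sketch, not part of the patch itself.

    import functools

    import jax
    import jax.numpy as jnp
    from jax import random
    from jax.experimental.pallas.ops import attention

    # Toy view of the segment_mask helper added above: tokens attend only
    # within their own segment, giving a block-diagonal boolean mask.
    ids = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
    print(attention.segment_mask(ids, ids).astype(jnp.int32))
    # [[1 1 0 0]
    #  [1 1 0 0]
    #  [0 0 1 1]
    #  [0 0 1 1]]

    # End-to-end call, mirroring test_fused_attention_fwd.
    batch_size, seq_len, num_heads, head_dim = 2, 384, 2, 64
    k1, k2, k3 = random.split(random.PRNGKey(0), 3)
    q = random.normal(k1, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16)
    k = random.normal(k2, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16)
    v = random.normal(k3, (batch_size, seq_len, num_heads, head_dim), dtype=jnp.float16)

    # First half of each sequence is segment 0, second half is segment 1.
    segment_ids = jnp.concatenate(
        [jnp.zeros((batch_size, seq_len // 2), dtype=jnp.int32),
         jnp.ones((batch_size, seq_len // 2), dtype=jnp.int32)],
        axis=-1,
    )

    o = attention.mha(q, k, v, segment_ids, causal=True)
    o_ref = attention.mha_reference(q, k, v, segment_ids, causal=True)

    # Gradients flow through q, k, v only; the VJP returns None for segment_ids.
    # (The causal backward pass is still TODO per b/283035396, so causal=False here.)
    dq, dk, dv = jax.grad(
        lambda q, k, v: attention.mha(q, k, v, segment_ids, causal=False).sum(),
        argnums=(0, 1, 2),
    )(q, k, v)

As in the tests, segment_ids and the static flags can also be bound up front with functools.partial(attention.mha, segment_ids=segment_ids, causal=causal) and the result called as impl(q, k, v).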