add jax.cudnn & add check for bias/mask sharding
Cjkkkk committed Feb 9, 2024
1 parent 49f1537 commit 59307e9
Showing 3 changed files with 77 additions and 25 deletions.
61 changes: 43 additions & 18 deletions jax/_src/cudnn/fused_attention_stablehlo.py
@@ -27,7 +27,6 @@
from jax._src.core import ShapedArray

from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec, NamedSharding

from jax._src.interpreters import batching
@@ -214,7 +213,8 @@ def check_is_flash_attention(query, key):
# check if regular fused attention is supported
is_flash_attention = False
else:
raise NotImplementedError("Unsupported sequence length and head dim.")
raise NotImplementedError(
f"Unsupported sequence length Q {q_seq_len}, KV {kv_sqe_len} and head dim {head_dim}.")
return is_flash_attention, is_cross_attention

def check_cudnn_version(is_flash_attention, is_cross_attention):
@@ -533,46 +533,71 @@ def _get_padded_spec(arg_info):
assert len(spec) <= ndim
return spec + (None,) * (ndim - len(spec))

def _check_qkv_bias_mask_spec(query_spec, key_spec, value_spec, bias_spec, mask_spec):
# check qkv spec
if not query_spec == key_spec == value_spec:
raise ValueError("Query, key and value should have same sharding.")
*batch_spec, q_seq_spec, num_head_spec, head_spec = query_spec
if q_seq_spec != None:
raise ValueError("Sharding on sequence dim is not allowed.")
if head_spec != None:
raise ValueError("Sharding on head dim is not allowed.")
# check bias and mask spec
if bias_spec:
*bias_batch_spec, bias_num_head_spec, bias_q_seq_spec, bias_kv_seq_spec = bias_spec
if bias_batch_spec != batch_spec or bias_num_head_spec != num_head_spec:
raise ValueError("Query and bias should have same sharding on batch and num_head dim.")
if bias_q_seq_spec != None or bias_kv_seq_spec != None:
raise ValueError("Sharding on bias sequence dim is not allowed.")
if mask_spec:
*mask_batch_spec, mask_num_head_spec, mask_q_seq_spec, mask_kv_seq_spec = mask_spec
if mask_batch_spec != batch_spec or mask_num_head_spec != num_head_spec:
raise ValueError("Query and mask should have same sharding on batch and num_head dim.")
if mask_q_seq_spec != None or mask_kv_seq_spec != None:
raise ValueError("Sharding on mask sequence dim is not allowed.")

# fwd custom partition
def _infer_fwd_output_sharding(mesh, arg_shapes):
def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args):
# only sharding on batch and num_head dim is allowed
# (*batch, q_seq, num_head, head)
query_spec = _get_padded_spec(arg_shapes[0])
# (*batch, kv_seq, num_head, head)
key_spec = _get_padded_spec(arg_shapes[1])
value_spec = _get_padded_spec(arg_shapes[2])
if not query_spec == key_spec == value_spec:
raise ValueError("Query, key and value should have same sharding.")
seq_spec = query_spec[-3]
head_spec = query_spec[-1]
if seq_spec != None:
raise ValueError("Sharding on sequence dim is not allowed.")
if head_spec != None:
raise ValueError("Sharding on head dim is not allowed.")
has_bias, has_mask = variadic_args
bias_spec = _get_padded_spec(arg_shapes[3]) if has_bias else None
mask_spec = _get_padded_spec(arg_shapes[4]) if has_mask else None
_check_qkv_bias_mask_spec(query_spec, key_spec, value_spec, bias_spec, mask_spec)
# keep out sharding same as query sharding since they have same shape
out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
# activation sharding
activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None))
*batch_spec, q_seq_spec, num_head_spec, head_spec = query_spec
activation_sharding = NamedSharding(mesh, PartitionSpec(*batch_spec, num_head_spec, q_seq_spec, None))
return (out_sharding, activation_sharding)

_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10))
def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
return _infer_fwd_output_sharding(mesh, arg_shapes)
return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args)

def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
# args sharding
arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
out_shardings = _infer_fwd_output_sharding(mesh, arg_shapes)
out_shardings = _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args)
impl = partial(_dot_product_attention_fwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
return mesh, impl, out_shardings, arg_shardings

# bwd custom partition
def _infer_bwd_output_sharding(mesh, arg_shapes):
def _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args):
# (*batch, q_seq, num_head, head)
query_spec = _get_padded_spec(arg_shapes[0])
# (*batch, kv_seq, num_head, head)
key_spec = _get_padded_spec(arg_shapes[1])
value_spec = _get_padded_spec(arg_shapes[2])
has_bias, has_mask = variadic_args
bias_spec = _get_padded_spec(arg_shapes[3]) if has_bias else None
mask_spec = _get_padded_spec(arg_shapes[4]) if has_mask else None
_check_qkv_bias_mask_spec(query_spec, key_spec, value_spec, bias_spec, mask_spec)
# keep grad query sharding same as query sharding
grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
@@ -582,10 +607,10 @@ def _infer_bwd_output_sharding(mesh, arg_shapes):

_dot_product_attention_bwd_lower = custom_partitioning(_dot_product_attention_bwd_impl, static_argnums=(8,9,10,11,12,13))
def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
return _infer_bwd_output_sharding(mesh, arg_shapes)
return _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args)

def _dot_product_attention_bwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes)
out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args)
# args sharding
arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
impl = partial(_dot_product_attention_bwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
@@ -700,7 +725,7 @@ def dot_product_attention(query: Array,
    mask: mask used to mask out logits with shape of `[batch, num_heads,
q_length, kv_length]`.
scale: scale for the query.
dropout_rate: dropout rate
dropout_rate: dropout rate.
Returns:
Output of shape `[batch, q_length, num_heads, v_depth_per_head]`.
"""
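For context, here is a small sketch (not part of the commit) of the sharding layouts the new _check_qkv_bias_mask_spec check accepts, mirroring the partition specs used in the updated test further down: query/key/value may be sharded on the batch and num_heads dimensions only, and bias/mask must match that sharding while keeping both of their sequence dimensions unsharded. The axis names 'dp'/'tp' are simply the mesh labels the test uses.

from jax.sharding import PartitionSpec

# q/k/v are laid out as (*batch, seq, num_heads, head): sharding batch on 'dp'
# and num_heads on 'tp' is accepted; the seq and head dims must stay unsharded.
qkv_spec = PartitionSpec('dp', None, 'tp', None)       # accepted

# bias/mask are laid out as (*batch, num_heads, q_seq, kv_seq): they must match
# q/k/v on the batch and num_heads dims, with both sequence dims unsharded.
bias_spec = PartitionSpec('dp', 'tp', None, None)      # accepted
mask_spec = PartitionSpec('dp', 'tp', None, None)      # accepted

# A spec like this is rejected, since it shards the q/k/v sequence dimension:
bad_qkv_spec = PartitionSpec('dp', 'tp', None, None)   # ValueError in the check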
15 changes: 15 additions & 0 deletions jax/cudnn/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention
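With this re-export, user code can switch from the private module path to the public one. Below is a minimal usage sketch (not part of the commit); the positional bias/mask arguments are inferred from the docstring and the test wrapper and may differ from the exact signature, and running it requires a CUDA GPU with a cuDNN version accepted by the fused-attention checks.

import jax
import jax.numpy as jnp
from jax.cudnn import dot_product_attention  # new public import path

B, S, N, H = 4, 512, 8, 64  # batch, sequence length, num_heads, head_dim
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
query = jax.random.normal(k1, (B, S, N, H), dtype=jnp.bfloat16)
key   = jax.random.normal(k2, (B, S, N, H), dtype=jnp.bfloat16)
value = jax.random.normal(k3, (B, S, N, H), dtype=jnp.bfloat16)
bias  = jax.random.normal(k4, (B, N, S, S), dtype=jnp.bfloat16)

# bias is added to the attention logits; the mask argument is omitted (None).
out = jax.jit(dot_product_attention)(query, key, value, bias, None)
print(out.shape)  # (B, S, N, H)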
26 changes: 19 additions & 7 deletions tests/fused_attention_stablehlo_test.py
@@ -25,7 +25,7 @@
from jax.sharding import PartitionSpec, NamedSharding
from jax._src import config
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention
from jax.cudnn import dot_product_attention

config.parse_flags_with_absl()
Array = jnp.ndarray
@@ -76,6 +76,8 @@ def get_causal_mask(input_t):
bias = get_causal_mask(attn_weights)
if bias is not None:
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
if mask is not None:
attn_weights = jax.lax.select(mask, attn_weights, large_negative_number)
attn_weights = jax.nn.softmax(attn_weights)
if dropout_rate > 0.:
keep_prob = 1.0 - dropout_rate
@@ -107,14 +109,15 @@ class DotProductAttentionTest(jtu.JaxTestCase):
num_heads=[8],
head_dim=[64, 128],
use_bias=[True],
use_mask=[True],
is_causal_mask=[False],
dropout_rate=[0, 0.5],
scale=[0.5],
dtype=[jnp.float16, jnp.bfloat16]
)
@jtu.run_on_devices("cuda")
def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int,
head_dim: int, use_bias: bool, is_causal_mask: bool,
head_dim: int, use_bias: bool, use_mask: bool, is_causal_mask: bool,
dropout_rate: float, scale: float, dtype: jnp.dtype):
if seq_len == 256 and is_causal_mask:
self.skipTest("Fused attention does not support mask generation.")
@@ -123,7 +126,7 @@ def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int,
if len(jax.local_devices()) <= 4:
self.skipTest("Require at least 4 devices to run sharding tests.")

k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5)
k1, k2, k3, k4, k5, k6 = jax.random.split(jax.random.key(0), 6)
query = jax.random.normal(
k1, (batch_size, seq_len, num_heads, head_dim), dtype=dtype)
key = jax.random.normal(
@@ -137,25 +140,34 @@ def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int,
k5, (batch_size, num_heads, seq_len, seq_len), dtype=dtype)
else:
bias = None

if use_mask:
mask = jax.random.bernoulli(
          k6, 0.5, (batch_size, num_heads, seq_len, seq_len)).astype(dtype)
else:
mask = None
devices = np.array(jax.local_devices()[:4])
devices = devices.reshape((2, 2))
with Mesh(devices, ('dp', 'tp')) as mesh:
qkv_spec = PartitionSpec('dp', None, 'tp', None)
qkv_sharding = NamedSharding(mesh, qkv_spec)
if bias is not None:
bias_spec = PartitionSpec('dp', 'tp', None, None)
mask_spec = PartitionSpec('dp', 'tp', None, None)
else:
bias_spec = PartitionSpec()
mask_spec = PartitionSpec()
bias_sharding = NamedSharding(mesh, bias_spec)
mask_sharding = NamedSharding(mesh, mask_spec)
replicated = NamedSharding(mesh, PartitionSpec())
query = jax.device_put(query, qkv_sharding)
key = jax.device_put(key, qkv_sharding)
value = jax.device_put(value, qkv_sharding)
if bias is not None:
bias = jax.device_put(bias, bias_sharding)
if mask is not None:
mask = jax.device_put(mask, mask_sharding)
grad = jax.device_put(grad, qkv_sharding)
in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding, replicated)
in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding, mask_sharding)
out_shardings = (replicated, (qkv_sharding, qkv_sharding, qkv_sharding))
jitted_sdpa_train = jax.jit(
partial(sdpa_train, scale=scale, is_causal_mask=is_causal_mask, dropout_rate=dropout_rate),
@@ -169,8 +181,8 @@ def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int,
out_shardings=out_shardings
)

out, (query_grad, key_grad, value_grad) = jitted_sdpa_train(query, key, value, grad, bias, None)
out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = jitted_sdpa_train_ref(query, key, value, grad, bias, None)
out, (query_grad, key_grad, value_grad) = jitted_sdpa_train(query, key, value, grad, bias, mask)
out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = jitted_sdpa_train_ref(query, key, value, grad, bias, mask)
self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5)
if seq_len > 512:
# query_grad in flash attention is not deterministic
