diff --git a/jax/_src/cudnn/fused_attention_stableHLO.py b/jax/_src/cudnn/fused_attention_stableHLO.py deleted file mode 100644 index 0ff2017fc4fb..000000000000 --- a/jax/_src/cudnn/fused_attention_stableHLO.py +++ /dev/null @@ -1,672 +0,0 @@ -# Copyright 2023 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import partial, reduce -import operator -from typing import Any, Optional -import json - -import jax -import jax.numpy as jnp -from jax import core, dtypes -from jax.interpreters import mlir, xla -from jax.interpreters.mlir import ir -from jaxlib.hlo_helpers import custom_call -from jax._src.lib.mlir.dialects import hlo -from jax._src.core import ShapedArray - -from jax.experimental.custom_partitioning import custom_partitioning -from jax.experimental.pjit import pjit -from jax.sharding import Mesh, PartitionSpec, NamedSharding - -from jax._src.interpreters import batching -from jax._src import dispatch -from jax._src.lib import cuda_versions - -Array = jnp.ndarray -DType = jnp.dtype -PRNGKey = jnp.ndarray - -def element_type_to_backend_config_type_mapping(dtype): - _element_type_to_backend_config_type_mapping = { - ir.BF16Type.get(): "BF16", - ir.F16Type.get(): "F16", - } - return _element_type_to_backend_config_type_mapping.get(dtype) - -def default_layouts(*shapes): - return [range(len(shape) - 1, -1, -1) for shape in shapes] - -def create_dot_product_attention_backend_config(batch, - num_heads, - seq_q, - seq_kv, - dtype, - fmha_scale, - seed, - dropout_rate, - is_flash_attention, - is_causal_mask, - is_bwd): - # b q_seq num_heads head_dim -> Q - # b kv_seq num_heads head_dim -> K - # b kv_seq num_heads head_dim -> V - # b num_heads q_seq kv_seq -> P - # b q_seq num_heads head_dim -> O - # bmm1: Q @ K -> P - # bmm2: P @ V -> O - # bmm2Grad1: P @ dO -> dV - # bmm2Grad2: dO @ V -> dP - # bmm1Grad1: dP @ Q -> dK - # bmm1Grad2: dP @ K -> dQ - backend_config = { - "algorithm":{"algo_id":"0","math_type":"TENSOR_OP_MATH","tuning_knobs":{"17":"1","24":"0"},"is_cudnn_frontend":True,"workspace_size":"0"}, - "fmha_scale":fmha_scale, - "dropout_rate":dropout_rate, - "intermediate_tensor_shape":{"element_type":element_type_to_backend_config_type_mapping(dtype),"dimensions":[str(batch),str(num_heads),str(seq_q),str(seq_kv)],"tuple_shapes":[],"layout":{"dim_level_types":[],"dim_unique":[],"dim_ordered":[],"minor_to_major":["3","2","1","0"],"tiles":[],"element_size_in_bits":"0","memory_space":"0","index_primitive_type":"PRIMITIVE_TYPE_INVALID","pointer_primitive_type":"PRIMITIVE_TYPE_INVALID","dynamic_shape_metadata_prefix_bytes":"0"},"is_dynamic_dimension":[False,False,False,False]}, - "seed":seed, - "is_flash_attention":is_flash_attention, - "is_causal_mask":is_causal_mask - } - fwd_dot_number = { - "bmm1_dot_dimension_numbers":{"lhs_contracting_dimensions":["3"],"rhs_contracting_dimensions":["3"],"lhs_batch_dimensions":["0","2"],"rhs_batch_dimensions":["0","2"]}, - 
"bmm2_dot_dimension_numbers":{"lhs_contracting_dimensions":["3"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":["0","1"],"rhs_batch_dimensions":["0","2"]}, - } - bwd_dot_number = { - "bmm1_grad_gemm1_dot_dimension_numbers":{"lhs_contracting_dimensions":["2"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":["0","1"],"rhs_batch_dimensions":["0","2"]}, - "bmm1_grad_gemm2_dot_dimension_numbers":{"lhs_contracting_dimensions":["3"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":["0","1"],"rhs_batch_dimensions":["0","2"]}, - "bmm2_grad_gemm1_dot_dimension_numbers":{"lhs_contracting_dimensions":["2"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":["0","1"],"rhs_batch_dimensions":["0","2"]}, - "bmm2_grad_gemm2_dot_dimension_numbers":{"lhs_contracting_dimensions":["3"],"rhs_contracting_dimensions":["3"],"lhs_batch_dimensions":["0","2"],"rhs_batch_dimensions":["0","2"]}, - } - if is_bwd: - backend_config = {**backend_config, **bwd_dot_number} - else: - backend_config = {**backend_config, **fwd_dot_number} - - backend_config = json.dumps(backend_config) - return backend_config - -def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd): - index = is_bwd << 3 | has_dropout << 2 | has_mask << 1 | has_bias - _custom_name_maps = [ - # fMHA forward call targets. - "__cudnn$fhmaSoftmax", - "__cudnn$fhmaScaleBiasSoftmax", - "__cudnn$fhmaScaleMaskSoftmax", - "__cudnn$fhmaScaleBiasMaskSoftmax", - "__cudnn$fhmaSoftmaxDropout", - "__cudnn$fhmaScaleBiasSoftmaxDropout", - "__cudnn$fhmaScaleMaskSoftmaxDropout", - "__cudnn$fhmaScaleBiasMaskSoftmaxDropout", - # fMHA backward call targets. - "__cudnn$fhmaSoftmaxBackward", - "__cudnn$fhmaScaleBiasSoftmaxBackward", - "__cudnn$fhmaScaleMaskSoftmaxBackward", - "__cudnn$fhmaScaleBiasMaskSoftmaxBackward", - "__cudnn$fhmaSoftmaxDropoutBackward", - "__cudnn$fhmaScaleBiasSoftmaxDropoutBackward", - "__cudnn$fhmaScaleMaskSoftmaxDropoutBackward", - "__cudnn$fhmaScaleBiasMaskSoftmaxDropoutBackward" - ] - return _custom_name_maps[index] - -def check_qkv_layout(query, key, value): - assert len(query.shape) == len(key.shape) == len(value.shape) == 4, \ - "query, key and value should have rank 4." - - # Only support fp16 and bf16 here - query_dtype = query.dtype - key_dtype = key.dtype - value_dtype = value.dtype - assert query_dtype == key_dtype == value_dtype and query_dtype in [jnp.float16, jnp.bfloat16], \ - "query, key and value should have same dtype and should be float16 or bfloat16" - - q_batch, q_seq_len, q_num_heads, q_head_dim = query.shape - k_batch, k_seq_len, k_num_heads, k_head_dim = key.shape - v_batch, v_seq_len, v_num_heads, v_head_dim = value.shape - assert (q_batch == k_batch == v_batch) \ - and (k_seq_len == v_seq_len) \ - and (q_num_heads == k_num_heads == v_num_heads) \ - and (q_head_dim == k_head_dim == v_head_dim), \ - "query should have layout [batch, q_seq, num_heads, head_dim], " \ - "key and value should have layout [batch, kv_seq, num_heads, head_dim]." 
- -def check_is_flash_attention(query, key): - batch, q_seq_len, num_heads, head_dim = query.shape - _, kv_sqe_len, _, _ = key.shape - # check if attention pattern is supported by flash attention or fused attention - if q_seq_len > 512 and q_seq_len == kv_sqe_len and head_dim in [64, 128]: - # check if flash attention is supported - is_flash_attention = True - elif q_seq_len <= 512 and kv_sqe_len <= 512 and head_dim == 64: - # check if regular fused attention is supported - is_flash_attention = False - else: - raise NotImplementedError("Unsupported sequence length and head dim.") - return is_flash_attention - -def check_cuDNN_version(is_flash_attention): - # check if cuDNN is installed and if cuDNN version contraint is satisfied - if cuda_versions is None: - raise RuntimeError("cuDNN is not detected.") - elif is_flash_attention and cuda_versions.cudnn_get_version() < 8903: - raise RuntimeError("Require cuDNN at lease 8.9.3 to run flash attention.") - elif not is_flash_attention and cuda_versions.cudnn_get_version() < 8901: - raise RuntimeError("Require cuDNN at lease 8.9.1 to run fused attention.") - -def _dot_product_attention_fwd(query, key, value, bias, mask, - scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - output, _ = _dot_product_attention_fwd_p_wrapper.bind( - query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) - return output - -def _dot_product_attention_fwd_rule(query, key, value, bias, mask, - scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - output, activation = _dot_product_attention_fwd_p_wrapper.bind( - query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) - res = (query, key, value, bias, mask, activation, output) - return output, res - -def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, res, grad_output): - # {Q, K, V, bias, mask, activation, fwd_output, dO} - query, key, value, bias, mask, activation, fwd_output = res - grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind( - query, key, value, bias, mask, activation, fwd_output, grad_output, - scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) - grads = (grad_query, grad_key, grad_value, None, None) - return grads - -def _dot_product_attention_fwd_impl(query, key, value, bias, mask, - scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - # args: {Q, K, V, mask*, bias*} - output, activation = _dot_product_attention_fwd_p.bind( - query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) - return output, activation - -def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, fwd_output, grad_output, - scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - grad_query, grad_key, grad_value = _dot_product_attention_bwd_p.bind( - query, key, value, bias, mask, activation, fwd_output, grad_output, - scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - 
is_causal_mask=is_causal_mask) - grads = (grad_query, grad_key, grad_value) - return grads - -def _dot_product_attention_fwd_abstract(query, key, value, bias, mask, - *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - query_dtype = dtypes.canonicalize_dtype(query.dtype) - batch, q_seq_len, num_heads, head_dim = query.shape - _, kv_seq_len, _, _ = key.shape - output_shape = (batch, q_seq_len, num_heads, head_dim) - activation_shape = (batch, num_heads, q_seq_len, kv_seq_len) - softmax_stat_shape = (batch, num_heads, q_seq_len) - if q_seq_len > 512: - # is flash attention - return ( - ShapedArray(output_shape, query_dtype), # output - ShapedArray(softmax_stat_shape, jnp.float32), # softmax_stat - ) - else: - return ( - ShapedArray(output_shape, query_dtype), # output - ShapedArray(activation_shape, query_dtype), # activation - ) - -def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activation, fwd_output, grad_output, - *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - query_dtype = dtypes.canonicalize_dtype(query.dtype) - key_dtype = dtypes.canonicalize_dtype(key.dtype) - value_dtype = dtypes.canonicalize_dtype(value.dtype) - - return ( - ShapedArray( - query.shape, query_dtype - ), # grad query - ShapedArray( - key.shape, key_dtype - ), # grad key - ShapedArray( - value.shape, value_dtype - ), # part value - ) - -def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask, - scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - query_type = ir.RankedTensorType(query.type) - query_shape = query_type.shape - key_type = ir.RankedTensorType(key.type) - key_shape = key_type.shape - value_type = ir.RankedTensorType(value.type) - value_shape = value_type.shape - - batch, q_seq_len, num_heads, head_dim = query_shape - _, kv_seq_len, _, _ = key_shape - - output_shape = (batch, num_heads, q_seq_len, head_dim) - output_layout = (3, 1, 2, 0) - output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) - activation_shape = (batch, num_heads, q_seq_len, kv_seq_len) - softmax_stat_shape = (batch, num_heads, q_seq_len) - scratch_shape = (0,) - scratch_type = ir.IntegerType.get_unsigned(8) - # get backend config - backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, False) - # {Q, K, V, mask*, bias*} - # {output, scratch, activation*} - has_dropout = dropout_rate > 0 - has_bias, has_mask = variadic_args - operands = [query, key, value] - if has_mask: - operands.append(mask) - if has_bias: - operands.append(bias) - # get custom call name - custom_call_name = get_custom_call_name(has_bias, has_mask, has_dropout, False) - # create output types and layouts - if is_flash_attention: - result_types = [ - ir.RankedTensorType.get(output_shape, query_type.element_type), - ir.RankedTensorType.get(scratch_shape, scratch_type), - ir.RankedTensorType.get(softmax_stat_shape, ir.F32Type.get()), - ] - result_layouts = [output_layout] + default_layouts(scratch_shape, softmax_stat_shape) - else: - result_types = [ - ir.RankedTensorType.get(output_shape, query_type.element_type), - ir.RankedTensorType.get(scratch_shape, scratch_type), - ir.RankedTensorType.get(activation_shape, query_type.element_type), - ] - result_layouts = [output_layout] + default_layouts(scratch_shape, activation_shape) - # create custom call here - out = custom_call( - custom_call_name, 
- result_types=result_types, - operands=operands, - backend_config=backend_config, - operand_layouts=default_layouts(*[ir.RankedTensorType(operand.type).shape for operand in operands]), - result_layouts=result_layouts, - ) - # dropout scratch memory - # output should be (batch, q_seq_len, num_heads, head_dim) instead of (batch, num_heads, q_seq_len, head_dim) - return [hlo.transpose(out.results[0], output_transpose_perm), out.results[2]] - -def _dot_product_attention_bwd_cuda_lowering(ctx, query, key, value, bias, mask, activation, fwd_output, grad_output, - scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - query_type = ir.RankedTensorType(query.type) - query_shape = query_type.shape - key_type = ir.RankedTensorType(key.type) - key_shape = key_type.shape - value_type = ir.RankedTensorType(value.type) - value_shape = value_type.shape - activation_type = ir.RankedTensorType(activation.type) - activation_shape = activation_type.shape - grad_output_type = ir.RankedTensorType(grad_output.type) - grad_output_shape = grad_output_type.shape - - batch, q_seq_len, num_heads, head_dim = query_shape - _, kv_seq_len, _, _ = key_shape - scratch_shape = (0,) - scratch_type = ir.IntegerType.get_unsigned(8) - - grad_query_shape = (batch, num_heads, q_seq_len, head_dim) - grad_key_shape = (batch, num_heads, kv_seq_len, head_dim) - grad_value_shape = (batch, num_heads, kv_seq_len, head_dim) - softmax_sum_shape = (batch, num_heads, q_seq_len) - grad_layout = (3, 1, 2, 0) - grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) - backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, True) - # {Q, K, V, activation, dO, mask*, bias*, O*} - # {dQ, dK, dV, d_S*, softmax_sum*, d_Q_accum*, scratch, dbias*} - has_dropout = dropout_rate > 0 - has_bias, has_mask = variadic_args - # create operands - operands = [query, key, value, activation, grad_output] - if has_mask: - operands.append(mask) - if has_bias and is_flash_attention: - # flash attention requires bias in the bwd for remat - operands.append(bias) - if is_flash_attention: - operands.append(fwd_output) - # get custom call name - custom_call_name = get_custom_call_name(has_bias, has_mask, has_dropout, True) - - # create output types and layouts - if is_flash_attention: - result_types = [ - ir.RankedTensorType.get(grad_query_shape, query_type.element_type), # grad query - ir.RankedTensorType.get(grad_key_shape, key_type.element_type), # grad key - ir.RankedTensorType.get(grad_value_shape, value_type.element_type), # grad value - ir.RankedTensorType.get(softmax_sum_shape, ir.F32Type.get()), # softmax_sum - ir.RankedTensorType.get(grad_query_shape, ir.F32Type.get()), # d_Q_accum - ir.RankedTensorType.get(scratch_shape, scratch_type), # scratch - ] - result_layouts = [grad_layout, grad_layout, grad_layout] + default_layouts(softmax_sum_shape, grad_query_shape, scratch_shape) - else: - result_types = [ - ir.RankedTensorType.get(grad_query_shape, query_type.element_type), # grad query - ir.RankedTensorType.get(grad_key_shape, key_type.element_type), # grad key - ir.RankedTensorType.get(grad_value_shape, value_type.element_type), # grad value - ir.RankedTensorType.get(activation_shape, activation_type.element_type), # dS - ir.RankedTensorType.get(scratch_shape, scratch_type), # scratch - ] - result_layouts = [grad_layout, grad_layout, grad_layout] + default_layouts(activation_shape, scratch_shape) - 
out = custom_call( - custom_call_name, - result_types=result_types, - operands=operands, - backend_config=backend_config, - operand_layouts=default_layouts(*[ir.RankedTensorType(operand.type).shape for operand in operands]), - result_layouts=result_layouts, - ) - # Only keep dQ, dK and dV here - return [hlo.transpose(out.results[0], grad_transpose_perm), - hlo.transpose(out.results[1], grad_transpose_perm), - hlo.transpose(out.results[2], grad_transpose_perm)] - -# batcher -def _check_valid_batch_dims(bdims): - for dim in bdims: - assert dim in [0, None], \ - "Currently only support batch_dim in [0, None], " \ - f"but got {dim=}" - -def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - _check_valid_batch_dims(batch_dims) - query, key, value, bias, mask = batched_args - query_bdim = batch_dims[0] - out_bdims = query_bdim, query_bdim - - *batch_tuple, q_seq_len, num_heads, head_dim = query.shape - *_, kv_seq_len, _, _ = key.shape - batch = reduce(operator.mul, batch_tuple) - has_bias, has_mask = variadic_args - # reshape to 4D shape - query = jnp.reshape(query, (batch, q_seq_len, num_heads, head_dim)) - key = jnp.reshape(key, (batch, kv_seq_len, num_heads, head_dim)) - value = jnp.reshape(value, (batch, kv_seq_len, num_heads, head_dim)) - if has_bias: - bias = jnp.reshape(bias, (batch, num_heads, q_seq_len, kv_seq_len)) - if has_mask: - mask = jnp.reshape(mask, (batch, num_heads, q_seq_len, kv_seq_len)) - - output, activation = _dot_product_attention_fwd_p_wrapper.bind( - query, key, value, bias, mask, - scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) - - # reshape to original shape - output = jnp.reshape(output, (*batch_tuple, q_seq_len, num_heads, head_dim)) - if is_flash_attention: - activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len)) - else: - activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len, kv_seq_len)) - return (output, activation), out_bdims - -def _dot_product_attention_bwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): - _check_valid_batch_dims(batch_dims) - query, key, value, bias, mask, activation, fwd_output, grad_output = batched_args - query_bdim = batch_dims[0] - out_bdims = query_bdim, query_bdim, query_bdim - - *batch_tuple, q_seq_len, num_heads, head_dim = query.shape - *_, kv_seq_len, _, _ = key.shape - batch = reduce(operator.mul, batch_tuple) - has_bias, has_mask = variadic_args - # reshape to 4D shape - query = jnp.reshape(query, (batch, q_seq_len, num_heads, head_dim)) - key = jnp.reshape(key, (batch, kv_seq_len, num_heads, head_dim)) - value = jnp.reshape(value, (batch, kv_seq_len, num_heads, head_dim)) - if has_bias: - bias = jnp.reshape(bias, (batch, num_heads, q_seq_len, kv_seq_len)) - if has_mask: - mask = jnp.reshape(mask, (batch, num_heads, q_seq_len, kv_seq_len)) - if is_flash_attention: - activation = jnp.reshape(activation, (batch, num_heads, q_seq_len)) - else: - activation = jnp.reshape(activation, (batch, num_heads, q_seq_len, kv_seq_len)) - fwd_output = jnp.reshape(fwd_output, (batch, q_seq_len, num_heads, head_dim)) - grad_output = jnp.reshape(grad_output, (batch, q_seq_len, num_heads, head_dim)) - - grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind( - query, key, value, bias, - mask, activation, fwd_output, 
grad_output, - scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, - is_causal_mask=is_causal_mask) - - # reshape to original shape - grad_query = jnp.reshape(grad_query, (*batch_tuple, q_seq_len, num_heads, head_dim)) - grad_key = jnp.reshape(grad_key, (*batch_tuple, kv_seq_len, num_heads, head_dim)) - grad_value = jnp.reshape(grad_value, (*batch_tuple, kv_seq_len, num_heads, head_dim)) - grads = (grad_query, grad_key, grad_value) - return grads, out_bdims - -# custom partitioning -def _get_padded_spec(arg_info): - spec = None if arg_info.sharding is None else arg_info.sharding.spec - ndim = arg_info.ndim - if spec is None: - return (None,) * ndim - assert len(spec) <= ndim - return spec + (None,) * (ndim - len(spec)) - -# fwd custom partition -_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10)) -def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape): - # (*batch, q_seq, num_head, head) - query_spec = _get_padded_spec(arg_shapes[0]) - # (*batch, kv_seq, num_head, head) - key_spec = _get_padded_spec(arg_shapes[1]) - # keep out sharding same as query sharding since they have same shape - out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) - # activation sharding - if query_spec[-3] == key_spec[-3]: - # self attention - activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None)) - else: - # cross attention - activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3])) - return (out_sharding, activation_sharding) - -def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape): - # (*batch, q_seq, num_head, head) - query_spec = _get_padded_spec(arg_shapes[0]) - # (*batch, kv_seq, num_head, head) - key_spec = _get_padded_spec(arg_shapes[1]) - # keep out sharding same as query sharding since they have same shape - out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) - # activation sharding - if query_spec[-3] == key_spec[-3]: - # self attention - activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None)) - else: - # cross attention - activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3])) - # args sharding - arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes]) - out_shardings = (out_sharding, activation_sharding) - impl = partial(_dot_product_attention_fwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask) - return mesh, impl, out_shardings, arg_shardings - -# bwd custom partition -_dot_product_attention_bwd_lower = custom_partitioning(_dot_product_attention_bwd_impl, static_argnums=(8,9,10,11,12,13)) -def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape): - # (*batch, q_seq, num_head, head) - query_spec = _get_padded_spec(arg_shapes[0]) - # (*batch, kv_seq, num_head, head) - key_spec = _get_padded_spec(arg_shapes[1]) - # keep grad query sharding same as query sharding - 
grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) - grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec)) - grad_value_sharding = NamedSharding(mesh, PartitionSpec(*key_spec)) - out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding) - return out_shardings - -def _dot_product_attention_bwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape): - # (*batch, q_seq, num_head, head) - query_spec = _get_padded_spec(arg_shapes[0]) - # (*batch, kv_seq, num_head, head) - key_spec = _get_padded_spec(arg_shapes[1]) - # keep grad query sharding same as query sharding - grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) - grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec)) - grad_value_sharding = NamedSharding(mesh, PartitionSpec(*key_spec)) - out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding) - # args sharding - arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes]) - impl = partial(_dot_product_attention_bwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate, - variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask) - return mesh, impl, out_shardings, arg_shardings - -# Create dot_product_attention_fwd_p for forward operation. -_dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd") -_dot_product_attention_fwd_p.multiple_results = True -_dot_product_attention_fwd_p.def_impl(partial(xla.apply_primitive, _dot_product_attention_fwd_p)) -_dot_product_attention_fwd_p.def_abstract_eval(_dot_product_attention_fwd_abstract) - -mlir.register_lowering( - _dot_product_attention_fwd_p, - _dot_product_attention_fwd_cuda_lowering, - platform="gpu", -) - -_dot_product_attention_fwd_p_wrapper = core.Primitive("dot_product_attention_fwd_wrapper") -_dot_product_attention_fwd_p_wrapper.multiple_results = True -_dot_product_attention_fwd_p_wrapper.def_impl(_dot_product_attention_fwd_impl) -_dot_product_attention_fwd_p_wrapper.def_abstract_eval(_dot_product_attention_fwd_abstract) - -# Create dot_product_attention_bwd_p for backward operation. 
-_dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd") -_dot_product_attention_bwd_p.multiple_results = True -_dot_product_attention_bwd_p.def_impl(partial(xla.apply_primitive, _dot_product_attention_bwd_p)) -_dot_product_attention_bwd_p.def_abstract_eval(_dot_product_attention_bwd_abstract) - -mlir.register_lowering( - _dot_product_attention_bwd_p, - _dot_product_attention_bwd_cuda_lowering, - platform="gpu", -) - -_dot_product_attention_bwd_p_wrapper = core.Primitive("dot_product_attention_bwd_wrapper") -_dot_product_attention_bwd_p_wrapper.multiple_results = True -_dot_product_attention_bwd_p_wrapper.def_impl(_dot_product_attention_bwd_impl) -_dot_product_attention_bwd_p_wrapper.def_abstract_eval(_dot_product_attention_bwd_abstract) - - -batching.primitive_batchers[_dot_product_attention_fwd_p_wrapper] = _dot_product_attention_fwd_batcher -batching.primitive_batchers[_dot_product_attention_bwd_p_wrapper] = _dot_product_attention_bwd_batcher - -_dot_product_attention_fwd_lower.def_partition( - infer_sharding_from_operands=_dot_product_attention_fwd_infer_sharding_from_operands, - partition=_dot_product_attention_fwd_partition) - -mlir.register_lowering(_dot_product_attention_fwd_p_wrapper, - mlir.lower_fun(_dot_product_attention_fwd_lower, multiple_results=True)) - -_dot_product_attention_bwd_lower.def_partition( - infer_sharding_from_operands=_dot_product_attention_bwd_infer_sharding_from_operands, - partition=_dot_product_attention_bwd_partition) - -mlir.register_lowering(_dot_product_attention_bwd_p_wrapper, - mlir.lower_fun(_dot_product_attention_bwd_lower, multiple_results=True)) - -dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p) -dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p_wrapper) -dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p) -dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p_wrapper) - -@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10)) -def _dot_product_attention(query: Array, - key: Array, - value: Array, - bias: Array, - mask: Array, - scale: float, - seed: int, - dropout_rate: float, - variadic_args: tuple[bool], - is_flash_attention: bool, - is_causal_mask: bool): - output = _dot_product_attention_fwd( - query, key, value, bias, mask, - scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask) - return output - -# _dot_product_attention_fwd must have the same func signature as _dot_product_attention -_dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_attention_bwd_rule) - -# User interface -def dot_product_attention(query: Array, - key: Array, - value: Array, - scale: float = 1.0, - bias: Optional[Array] = None, - mask: Optional[Array] = None, - is_causal_mask: bool = False, - seed: int = 42, - dropout_rate: float = 0.): - """Computes dot-product attention given query, key, and value. - This is the core function for applying attention based on - https://arxiv.org/abs/1706.03762. It calculates the attention weights given - query and key and combines the values using the attention weights. - batch seq num_heads, head_dim // but all assume Q, K and V will have same - b q_seq num_heads head_dim -> Q - b kv_seq num_heads head_dim -> K - b kv_seq num_heads head_dim -> V - Args: - query: queries for calculating attention with shape of `[batch, q_length, - num_heads, qk_depth_per_head]`. 
- key: keys for calculating attention with shape of `[batch, kv_length, - num_heads, qk_depth_per_head]`. - value: values to be used in attention with shape of `[batch, kv_length, - num_heads, v_depth_per_head]`. - scale: scale for the query. - dropout_rate: dropout rate - Returns: - Output of shape `[batch, length, num_heads, v_depth_per_head]`. - """ - # check if query, key and value layout meets cuDNN layout requirement - check_qkv_layout(query, key, value) - # check if flash attention is supported for this attention pattern - is_flash_attention = check_is_flash_attention(query, key) - # check if cuDNN is installed and if cuDNN version is sufficient - check_cuDNN_version(is_flash_attention) - - variadic_args = (bias is not None, mask is not None) - if bias is None: - bias = jnp.zeros(0, dtype=query.dtype) - if mask is None: - mask = jnp.zeros(0, dtype=query.dtype) - # TODO: remove this once scale behavior is fixed - if scale != 1.0: - query = query * scale - scale = 1.0 - output = _dot_product_attention( - query, key, value, bias, mask, - scale, seed, dropout_rate, variadic_args, - is_flash_attention, is_causal_mask) - return output diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py new file mode 100644 index 000000000000..b1278f122fee --- /dev/null +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -0,0 +1,732 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial, reduce +import operator +from typing import Any, Optional +import json + +import jax +import jax.numpy as jnp +from jax import core, dtypes +from jax.interpreters import mlir, xla +from jax.interpreters.mlir import ir +from jaxlib.hlo_helpers import custom_call +from jax._src.lib.mlir.dialects import hlo +from jax._src.core import ShapedArray + +from jax.experimental.custom_partitioning import custom_partitioning +from jax.experimental.pjit import pjit +from jax.sharding import Mesh, PartitionSpec, NamedSharding + +from jax._src.interpreters import batching +from jax._src import dispatch +from jax._src.lib import cuda_versions + +Array = jnp.ndarray +DType = jnp.dtype +PRNGKey = jnp.ndarray + +def element_type_to_backend_config_type_mapping(dtype): + _element_type_to_backend_config_type_mapping = { + ir.BF16Type.get(): "BF16", + ir.F16Type.get(): "F16", + } + return _element_type_to_backend_config_type_mapping[dtype] + +def default_layouts(*shapes): + return [range(len(shape) - 1, -1, -1) for shape in shapes] + +def create_dot_product_attention_backend_config(batch, + num_heads, + seq_q, + seq_kv, + dtype, + fmha_scale, + seed, + dropout_rate, + is_flash_attention, + is_causal_mask, + is_bwd): + # b q_seq num_heads head_dim -> Q + # b kv_seq num_heads head_dim -> K + # b kv_seq num_heads head_dim -> V + # b num_heads q_seq kv_seq -> P + # b q_seq num_heads head_dim -> O + # bmm1: Q @ K -> P + # bmm2: P @ V -> O + # bmm2Grad1: P @ dO -> dV + # bmm2Grad2: dO @ V -> dP + # bmm1Grad1: dP @ Q -> dK + # bmm1Grad2: dP @ K -> dQ + cudnn_fmha_backend_config = { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": {"17": "1", "24": "0"}, + "is_cudnn_frontend": True, + "workspace_size": "0", + }, + "fmha_scale": fmha_scale, + "dropout_rate": dropout_rate, + "intermediate_tensor_shape": { + "element_type": element_type_to_backend_config_type_mapping(dtype), + "dimensions": [str(batch), str(num_heads), str(seq_q), str(seq_kv)], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0", + }, + "is_dynamic_dimension": [False, False, False, False], + }, + "seed": seed, + "is_flash_attention": is_flash_attention, + "is_causal_mask": is_causal_mask, + } + fwd_dot_number = { + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "2"], + "rhs_batch_dimensions": ["0", "2"], + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["1"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "2"], + }, + } + bwd_dot_number = { + "bmm1_grad_gemm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["2"], + "rhs_contracting_dimensions": ["1"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "2"], + }, + "bmm1_grad_gemm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["1"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "2"], + }, + "bmm2_grad_gemm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["2"], + "rhs_contracting_dimensions": ["1"], + "lhs_batch_dimensions": ["0", "1"], + 
"rhs_batch_dimensions": ["0", "2"], + }, + "bmm2_grad_gemm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "2"], + "rhs_batch_dimensions": ["0", "2"], + }, + } + if is_bwd: + cudnn_fmha_backend_config = {**cudnn_fmha_backend_config, **bwd_dot_number} + else: + cudnn_fmha_backend_config = {**cudnn_fmha_backend_config, **fwd_dot_number} + + backend_config = { + "operation_queue_id":"0", + "wait_on_operation_queues":[], + "cudnn_fmha_backend_config": cudnn_fmha_backend_config + } + backend_config = json.dumps(backend_config) + return backend_config + +def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd): + index = is_bwd << 3 | has_dropout << 2 | has_mask << 1 | has_bias + _custom_name_maps = [ + # fMHA forward call targets. + "__cudnn$fhmaSoftmax", + "__cudnn$fhmaScaleBiasSoftmax", + "__cudnn$fhmaScaleMaskSoftmax", + "__cudnn$fhmaScaleBiasMaskSoftmax", + "__cudnn$fhmaSoftmaxDropout", + "__cudnn$fhmaScaleBiasSoftmaxDropout", + "__cudnn$fhmaScaleMaskSoftmaxDropout", + "__cudnn$fhmaScaleBiasMaskSoftmaxDropout", + # fMHA backward call targets. + "__cudnn$fhmaSoftmaxBackward", + "__cudnn$fhmaScaleBiasSoftmaxBackward", + "__cudnn$fhmaScaleMaskSoftmaxBackward", + "__cudnn$fhmaScaleBiasMaskSoftmaxBackward", + "__cudnn$fhmaSoftmaxDropoutBackward", + "__cudnn$fhmaScaleBiasSoftmaxDropoutBackward", + "__cudnn$fhmaScaleMaskSoftmaxDropoutBackward", + "__cudnn$fhmaScaleBiasMaskSoftmaxDropoutBackward" + ] + return _custom_name_maps[index] + +def check_qkv_layout(query, key, value): + assert len(query.shape) == len(key.shape) == len(value.shape) == 4, \ + "query, key and value should have rank 4." + + # Only support fp16 and bf16 here + query_dtype = query.dtype + key_dtype = key.dtype + value_dtype = value.dtype + assert query_dtype == key_dtype == value_dtype and query_dtype in [jnp.float16, jnp.bfloat16], \ + "query, key and value should have same dtype and should be float16 or bfloat16" + + q_batch, q_seq_len, q_num_heads, q_head_dim = query.shape + k_batch, k_seq_len, k_num_heads, k_head_dim = key.shape + v_batch, v_seq_len, v_num_heads, v_head_dim = value.shape + if not((q_batch == k_batch == v_batch) + and (k_seq_len == v_seq_len) + and (q_num_heads == k_num_heads == v_num_heads) + and (q_head_dim == k_head_dim == v_head_dim)): + raise ValueError( + "query should have layout [batch, q_seq, num_heads, head_dim], " \ + "key and value should have layout [batch, kv_seq, num_heads, head_dim].") + +def check_is_flash_attention(query, key): + batch, q_seq_len, num_heads, head_dim = query.shape + _, kv_sqe_len, _, _ = key.shape + # check if attention pattern is supported by flash attention or fused attention + if q_seq_len > 512 and q_seq_len == kv_sqe_len and head_dim in [64, 128]: + # check if flash attention is supported + is_flash_attention = True + elif q_seq_len <= 512 and kv_sqe_len <= 512 and head_dim == 64: + # check if regular fused attention is supported + is_flash_attention = False + else: + raise NotImplementedError("Unsupported sequence length and head dim.") + return is_flash_attention + +def check_cudnn_version(is_flash_attention): + # check if cuDNN is installed and if cuDNN version contraint is satisfied + if cuda_versions is None: + raise RuntimeError("cuDNN is not detected.") + elif is_flash_attention and cuda_versions.cudnn_get_version() < 8903: + raise RuntimeError("JAX requires cuDNN >= 8.9.3 to use flash attention.") + elif not is_flash_attention and cuda_versions.cudnn_get_version() < 
8901: + raise RuntimeError("JAX requires cuDNN >= 8.9.1 to use fused attention.") + +def _dot_product_attention_fwd(query, key, value, bias, mask, + scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): + output, _ = _dot_product_attention_fwd_p_wrapper.bind( + query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate, + variadic_args=variadic_args, is_flash_attention=is_flash_attention, + is_causal_mask=is_causal_mask) + return output + +def _dot_product_attention_fwd_rule(query, key, value, bias, mask, + scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): + output, activation = _dot_product_attention_fwd_p_wrapper.bind( + query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate, + variadic_args=variadic_args, is_flash_attention=is_flash_attention, + is_causal_mask=is_causal_mask) + res = (query, key, value, bias, mask, activation, output) + return output, res + +def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, res, grad_output): + query, key, value, bias, mask, activation, fwd_output = res + grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind( + query, key, value, bias, mask, activation, fwd_output, grad_output, + scale=scale, seed=seed, dropout_rate=dropout_rate, + variadic_args=variadic_args, is_flash_attention=is_flash_attention, + is_causal_mask=is_causal_mask) + grads = (grad_query, grad_key, grad_value, None, None) + return grads + +def _dot_product_attention_fwd_impl(query, key, value, bias, mask, + scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): + # args: {Q, K, V, mask*, bias*} + output, activation = _dot_product_attention_fwd_p.bind( + query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate, + variadic_args=variadic_args, is_flash_attention=is_flash_attention, + is_causal_mask=is_causal_mask) + return output, activation + +def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, fwd_output, grad_output, + scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): + grad_query, grad_key, grad_value = _dot_product_attention_bwd_p.bind( + query, key, value, bias, mask, activation, fwd_output, grad_output, + scale=scale, seed=seed, dropout_rate=dropout_rate, + variadic_args=variadic_args, is_flash_attention=is_flash_attention, + is_causal_mask=is_causal_mask) + grads = (grad_query, grad_key, grad_value) + return grads + +def _dot_product_attention_fwd_abstract(query, key, value, bias, mask, + *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): + query_dtype = dtypes.canonicalize_dtype(query.dtype) + batch, q_seq_len, num_heads, head_dim = query.shape + _, kv_seq_len, _, _ = key.shape + output_shape = (batch, q_seq_len, num_heads, head_dim) + activation_shape = (batch, num_heads, q_seq_len, kv_seq_len) + softmax_stat_shape = (batch, num_heads, q_seq_len) + if q_seq_len > 512: + # is flash attention + return ( + ShapedArray(output_shape, query_dtype), # output + ShapedArray(softmax_stat_shape, jnp.float32), # softmax_stat + ) + return ( + ShapedArray(output_shape, query_dtype), # output + ShapedArray(activation_shape, query_dtype), # activation + ) + +def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activation, fwd_output, grad_output, + *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): + query_dtype = 
dtypes.canonicalize_dtype(query.dtype) + key_dtype = dtypes.canonicalize_dtype(key.dtype) + value_dtype = dtypes.canonicalize_dtype(value.dtype) + + return ( + ShapedArray( + query.shape, query_dtype + ), # grad query + ShapedArray( + key.shape, key_dtype + ), # grad key + ShapedArray( + value.shape, value_dtype + ), # part value + ) + +def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask, + scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): + query_type = ir.RankedTensorType(query.type) + query_shape = query_type.shape + key_type = ir.RankedTensorType(key.type) + key_shape = key_type.shape + value_type = ir.RankedTensorType(value.type) + value_shape = value_type.shape + + batch, q_seq_len, num_heads, head_dim = query_shape + _, kv_seq_len, _, _ = key_shape + + output_shape = (batch, num_heads, q_seq_len, head_dim) + output_layout = (3, 1, 2, 0) + output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) + activation_shape = (batch, num_heads, q_seq_len, kv_seq_len) + softmax_stat_shape = (batch, num_heads, q_seq_len) + scratch_shape = (0,) + scratch_type = ir.IntegerType.get_unsigned(8) + # get backend config + backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, False) + # {Q, K, V, mask*, bias*} + # {output, scratch, activation*} + has_dropout = dropout_rate > 0 + has_bias, has_mask = variadic_args + operands = [query, key, value] + if has_mask: + operands.append(mask) + if has_bias: + operands.append(bias) + custom_call_name = get_custom_call_name(has_bias, has_mask, has_dropout, False) + # create output types and layouts + if is_flash_attention: + result_types = [ + ir.RankedTensorType.get(output_shape, query_type.element_type), + ir.RankedTensorType.get(scratch_shape, scratch_type), + ir.RankedTensorType.get(softmax_stat_shape, ir.F32Type.get()), + ] + result_layouts = [output_layout] + default_layouts(scratch_shape, softmax_stat_shape) + else: + result_types = [ + ir.RankedTensorType.get(output_shape, query_type.element_type), + ir.RankedTensorType.get(scratch_shape, scratch_type), + ir.RankedTensorType.get(activation_shape, query_type.element_type), + ] + result_layouts = [output_layout] + default_layouts(scratch_shape, activation_shape) + # create custom call here + out = custom_call( + custom_call_name, + result_types=result_types, + operands=operands, + backend_config=backend_config, + operand_layouts=default_layouts(*[ir.RankedTensorType(operand.type).shape for operand in operands]), + result_layouts=result_layouts, + ) + # drop scratch memory + # output should be (batch, q_seq_len, num_heads, head_dim) instead of (batch, num_heads, q_seq_len, head_dim) + return [hlo.transpose(out.results[0], output_transpose_perm), out.results[2]] + +def _dot_product_attention_bwd_cuda_lowering(ctx, query, key, value, bias, mask, activation, fwd_output, grad_output, + scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): + query_type = ir.RankedTensorType(query.type) + query_shape = query_type.shape + key_type = ir.RankedTensorType(key.type) + key_shape = key_type.shape + value_type = ir.RankedTensorType(value.type) + value_shape = value_type.shape + activation_type = ir.RankedTensorType(activation.type) + activation_shape = activation_type.shape + grad_output_type = ir.RankedTensorType(grad_output.type) + grad_output_shape = grad_output_type.shape + + batch, q_seq_len, 
num_heads, head_dim = query_shape + _, kv_seq_len, _, _ = key_shape + scratch_shape = (0,) + scratch_type = ir.IntegerType.get_unsigned(8) + + grad_query_shape = (batch, num_heads, q_seq_len, head_dim) + grad_key_shape = (batch, num_heads, kv_seq_len, head_dim) + grad_value_shape = (batch, num_heads, kv_seq_len, head_dim) + softmax_sum_shape = (batch, num_heads, q_seq_len) + grad_layout = (3, 1, 2, 0) + grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) + backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, True) + # {Q, K, V, activation, dO, mask*, bias*, O*} + # {dQ, dK, dV, d_S*, softmax_sum*, d_Q_accum*, scratch, dbias*} + has_dropout = dropout_rate > 0 + has_bias, has_mask = variadic_args + # create operands + operands = [query, key, value, activation, grad_output] + if has_mask: + operands.append(mask) + if has_bias and is_flash_attention: + # flash attention requires bias in the bwd for remat + operands.append(bias) + if is_flash_attention: + operands.append(fwd_output) + # get custom call name + custom_call_name = get_custom_call_name(has_bias, has_mask, has_dropout, True) + + # create output types and layouts + if is_flash_attention: + result_types = [ + ir.RankedTensorType.get(grad_query_shape, query_type.element_type), # grad query + ir.RankedTensorType.get(grad_key_shape, key_type.element_type), # grad key + ir.RankedTensorType.get(grad_value_shape, value_type.element_type), # grad value + ir.RankedTensorType.get(softmax_sum_shape, ir.F32Type.get()), # softmax_sum + ir.RankedTensorType.get(grad_query_shape, ir.F32Type.get()), # d_Q_accum + ir.RankedTensorType.get(scratch_shape, scratch_type), # scratch + ] + result_layouts = [grad_layout, grad_layout, grad_layout] + default_layouts(softmax_sum_shape, grad_query_shape, scratch_shape) + else: + result_types = [ + ir.RankedTensorType.get(grad_query_shape, query_type.element_type), # grad query + ir.RankedTensorType.get(grad_key_shape, key_type.element_type), # grad key + ir.RankedTensorType.get(grad_value_shape, value_type.element_type), # grad value + ir.RankedTensorType.get(activation_shape, activation_type.element_type), # dS + ir.RankedTensorType.get(scratch_shape, scratch_type), # scratch + ] + result_layouts = [grad_layout, grad_layout, grad_layout] + default_layouts(activation_shape, scratch_shape) + out = custom_call( + custom_call_name, + result_types=result_types, + operands=operands, + backend_config=backend_config, + operand_layouts=default_layouts(*[ir.RankedTensorType(operand.type).shape for operand in operands]), + result_layouts=result_layouts, + ) + # Only keep dQ, dK and dV here + return [hlo.transpose(out.results[0], grad_transpose_perm), + hlo.transpose(out.results[1], grad_transpose_perm), + hlo.transpose(out.results[2], grad_transpose_perm)] + +# batcher +def _check_valid_batch_dims(bdims): + for dim in bdims: + if dim not in [0, None]: + raise NotImplementedError("Currently only support batch_dim in [0, None], " \ + f"but got {dim=}") + +def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): + _check_valid_batch_dims(batch_dims) + query, key, value, bias, mask = batched_args + query_bdim = batch_dims[0] + out_bdims = query_bdim, query_bdim + + *batch_tuple, q_seq_len, num_heads, head_dim = query.shape + *_, kv_seq_len, _, _ = key.shape + new_batch = reduce(operator.mul, 
batch_tuple) + has_bias, has_mask = variadic_args + # reshape to 4D shape + query = jnp.reshape(query, (new_batch, q_seq_len, num_heads, head_dim)) + key = jnp.reshape(key, (new_batch, kv_seq_len, num_heads, head_dim)) + value = jnp.reshape(value, (new_batch, kv_seq_len, num_heads, head_dim)) + if has_bias: + bias = jnp.reshape(bias, (new_batch, num_heads, q_seq_len, kv_seq_len)) + if has_mask: + mask = jnp.reshape(mask, (new_batch, num_heads, q_seq_len, kv_seq_len)) + + output, activation = _dot_product_attention_fwd_p_wrapper.bind( + query, key, value, bias, mask, + scale=scale, seed=seed, dropout_rate=dropout_rate, + variadic_args=variadic_args, is_flash_attention=is_flash_attention, + is_causal_mask=is_causal_mask) + + # reshape to original shape + output = jnp.reshape(output, (*batch_tuple, q_seq_len, num_heads, head_dim)) + if is_flash_attention: + activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len)) + else: + activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len, kv_seq_len)) + return (output, activation), out_bdims + +def _dot_product_attention_bwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask): + _check_valid_batch_dims(batch_dims) + query, key, value, bias, mask, activation, fwd_output, grad_output = batched_args + query_bdim = batch_dims[0] + out_bdims = query_bdim, query_bdim, query_bdim + + *batch_tuple, q_seq_len, num_heads, head_dim = query.shape + *_, kv_seq_len, _, _ = key.shape + new_batch = reduce(operator.mul, batch_tuple) + has_bias, has_mask = variadic_args + # reshape to 4D shape + query = jnp.reshape(query, (new_batch, q_seq_len, num_heads, head_dim)) + key = jnp.reshape(key, (new_batch, kv_seq_len, num_heads, head_dim)) + value = jnp.reshape(value, (new_batch, kv_seq_len, num_heads, head_dim)) + if has_bias: + bias = jnp.reshape(bias, (new_batch, num_heads, q_seq_len, kv_seq_len)) + if has_mask: + mask = jnp.reshape(mask, (new_batch, num_heads, q_seq_len, kv_seq_len)) + if is_flash_attention: + activation = jnp.reshape(activation, (new_batch, num_heads, q_seq_len)) + else: + activation = jnp.reshape(activation, (new_batch, num_heads, q_seq_len, kv_seq_len)) + fwd_output = jnp.reshape(fwd_output, (new_batch, q_seq_len, num_heads, head_dim)) + grad_output = jnp.reshape(grad_output, (new_batch, q_seq_len, num_heads, head_dim)) + + grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind( + query, key, value, bias, + mask, activation, fwd_output, grad_output, + scale=scale, seed=seed, dropout_rate=dropout_rate, + variadic_args=variadic_args, is_flash_attention=is_flash_attention, + is_causal_mask=is_causal_mask) + + # reshape to original shape + grad_query = jnp.reshape(grad_query, (*batch_tuple, q_seq_len, num_heads, head_dim)) + grad_key = jnp.reshape(grad_key, (*batch_tuple, kv_seq_len, num_heads, head_dim)) + grad_value = jnp.reshape(grad_value, (*batch_tuple, kv_seq_len, num_heads, head_dim)) + grads = (grad_query, grad_key, grad_value) + return grads, out_bdims + +# custom partitioning +def _get_padded_spec(arg_info): + spec = None if arg_info.sharding is None else arg_info.sharding.spec + ndim = arg_info.ndim + if spec is None: + return (None,) * ndim + assert len(spec) <= ndim + return spec + (None,) * (ndim - len(spec)) + +# fwd custom partition +_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10)) +def _dot_product_attention_fwd_infer_sharding_from_operands(scale, 
seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape): + # (*batch, q_seq, num_head, head) + query_spec = _get_padded_spec(arg_shapes[0]) + # (*batch, kv_seq, num_head, head) + key_spec = _get_padded_spec(arg_shapes[1]) + # keep out sharding same as query sharding since they have same shape + out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) + # activation sharding + if query_spec[-3] == key_spec[-3]: + # self attention + activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None)) + else: + # cross attention + activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3])) + return (out_sharding, activation_sharding) + +def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape): + # (*batch, q_seq, num_head, head) + query_spec = _get_padded_spec(arg_shapes[0]) + # (*batch, kv_seq, num_head, head) + key_spec = _get_padded_spec(arg_shapes[1]) + # keep out sharding same as query sharding since they have same shape + out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) + # activation sharding + if query_spec[-3] == key_spec[-3]: + # self attention + activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None)) + else: + # cross attention + activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3])) + # args sharding + arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes]) + out_shardings = (out_sharding, activation_sharding) + impl = partial(_dot_product_attention_fwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate, + variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask) + return mesh, impl, out_shardings, arg_shardings + +# bwd custom partition +_dot_product_attention_bwd_lower = custom_partitioning(_dot_product_attention_bwd_impl, static_argnums=(8,9,10,11,12,13)) +def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape): + # (*batch, q_seq, num_head, head) + query_spec = _get_padded_spec(arg_shapes[0]) + # (*batch, kv_seq, num_head, head) + key_spec = _get_padded_spec(arg_shapes[1]) + # keep grad query sharding same as query sharding + grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) + grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec)) + grad_value_sharding = NamedSharding(mesh, PartitionSpec(*key_spec)) + out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding) + return out_shardings + +def _dot_product_attention_bwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape): + # (*batch, q_seq, num_head, head) + query_spec = _get_padded_spec(arg_shapes[0]) + # (*batch, kv_seq, num_head, head) + key_spec = _get_padded_spec(arg_shapes[1]) + # keep grad query sharding same as query sharding + grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) + grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec)) + grad_value_sharding = NamedSharding(mesh, PartitionSpec(*key_spec)) + out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding) + # args sharding + arg_shardings = 
tuple([arg_i.sharding for arg_i in arg_shapes]) + impl = partial(_dot_product_attention_bwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate, + variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask) + return mesh, impl, out_shardings, arg_shardings + +# Create dot_product_attention_fwd_p for forward operation. +_dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd") +_dot_product_attention_fwd_p.multiple_results = True +_dot_product_attention_fwd_p.def_impl(partial(xla.apply_primitive, _dot_product_attention_fwd_p)) +_dot_product_attention_fwd_p.def_abstract_eval(_dot_product_attention_fwd_abstract) + +mlir.register_lowering( + _dot_product_attention_fwd_p, + _dot_product_attention_fwd_cuda_lowering, + platform="cuda", +) + +_dot_product_attention_fwd_p_wrapper = core.Primitive("dot_product_attention_fwd_wrapper") +_dot_product_attention_fwd_p_wrapper.multiple_results = True +_dot_product_attention_fwd_p_wrapper.def_impl(_dot_product_attention_fwd_impl) +_dot_product_attention_fwd_p_wrapper.def_abstract_eval(_dot_product_attention_fwd_abstract) + +# Create dot_product_attention_bwd_p for backward operation. +_dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd") +_dot_product_attention_bwd_p.multiple_results = True +_dot_product_attention_bwd_p.def_impl(partial(xla.apply_primitive, _dot_product_attention_bwd_p)) +_dot_product_attention_bwd_p.def_abstract_eval(_dot_product_attention_bwd_abstract) + +mlir.register_lowering( + _dot_product_attention_bwd_p, + _dot_product_attention_bwd_cuda_lowering, + platform="cuda", +) + +_dot_product_attention_bwd_p_wrapper = core.Primitive("dot_product_attention_bwd_wrapper") +_dot_product_attention_bwd_p_wrapper.multiple_results = True +_dot_product_attention_bwd_p_wrapper.def_impl(_dot_product_attention_bwd_impl) +_dot_product_attention_bwd_p_wrapper.def_abstract_eval(_dot_product_attention_bwd_abstract) + + +batching.primitive_batchers[_dot_product_attention_fwd_p_wrapper] = _dot_product_attention_fwd_batcher +batching.primitive_batchers[_dot_product_attention_bwd_p_wrapper] = _dot_product_attention_bwd_batcher + +_dot_product_attention_fwd_lower.def_partition( + infer_sharding_from_operands=_dot_product_attention_fwd_infer_sharding_from_operands, + partition=_dot_product_attention_fwd_partition) + +mlir.register_lowering(_dot_product_attention_fwd_p_wrapper, + mlir.lower_fun(_dot_product_attention_fwd_lower, multiple_results=True)) + +_dot_product_attention_bwd_lower.def_partition( + infer_sharding_from_operands=_dot_product_attention_bwd_infer_sharding_from_operands, + partition=_dot_product_attention_bwd_partition) + +mlir.register_lowering(_dot_product_attention_bwd_p_wrapper, + mlir.lower_fun(_dot_product_attention_bwd_lower, multiple_results=True)) + +dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p) +dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p_wrapper) +dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p) +dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p_wrapper) + +@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10)) +def _dot_product_attention(query: Array, + key: Array, + value: Array, + bias: Array, + mask: Array, + scale: float, + seed: int, + dropout_rate: float, + variadic_args: tuple[bool], + is_flash_attention: bool, + is_causal_mask: bool): + output = _dot_product_attention_fwd( + query, key, value, bias, mask, + 
scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, + is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask) + return output + +# _dot_product_attention_fwd must have the same func signature as _dot_product_attention +_dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_attention_bwd_rule) + +# User interface +def dot_product_attention(query: Array, + key: Array, + value: Array, + scale: float = 1.0, + bias: Optional[Array] = None, + mask: Optional[Array] = None, + is_causal_mask: bool = False, + seed: int = 42, + dropout_rate: float = 0.): + """Computes dot-product attention given query, key, and value. + This is the core function for applying attention based on + https://arxiv.org/abs/1706.03762. It calculates the attention weights given + query and key and combines the values using the attention weights. + query, key and value must share the same dtype and use these layouts: + b q_seq num_heads head_dim -> Q + b kv_seq num_heads head_dim -> K + b kv_seq num_heads head_dim -> V + Args: + query: queries for calculating attention with shape of `[batch, q_length, + num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of `[batch, kv_length, + num_heads, qk_depth_per_head]`. + value: values to be used in attention with shape of `[batch, kv_length, + num_heads, v_depth_per_head]`. + bias: bias to be added to logits with shape of `[batch, num_heads, + q_length, kv_length]`. + mask: mask used to mask out logits with shape of `[batch, num_heads, + q_length, kv_length]`. + scale: scale for the query. + is_causal_mask: whether to apply causal masking to the attention logits. + seed: random seed used for dropout. + dropout_rate: dropout probability. + Returns: + Output of shape `[batch, q_length, num_heads, v_depth_per_head]`. + """ + # check that the query, key and value layouts meet the cuDNN layout requirements + check_qkv_layout(query, key, value) + # check if flash attention is supported for this attention pattern + is_flash_attention = check_is_flash_attention(query, key) + # check that cuDNN is installed and its version is sufficient + check_cudnn_version(is_flash_attention) + + variadic_args = (bias is not None, mask is not None) + if bias is None: + bias = jnp.zeros(0, dtype=query.dtype) + if mask is None: + mask = jnp.zeros(0, dtype=query.dtype) + # TODO: remove this once scale behavior is fixed + if scale != 1.0: + query = query * scale + scale = 1.0 + output = _dot_product_attention( + query, key, value, bias, mask, + scale, seed, dropout_rate, variadic_args, + is_flash_attention, is_causal_mask) + return output diff --git a/tests/BUILD b/tests/BUILD index 4f71899f768f..829a570e5f88 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1430,6 +1430,19 @@ jax_test( ], ) +jax_test( + name = "fused_attention_stablehlo_test", + srcs = ["fused_attention_stablehlo_test.py"], + disable_backends = [ + "tpu", + "cpu", + ], + shard_count = 4, + deps = [ + "//jax:fused_attention_stablehlo", + ], +) + exports_files( [ "api_test.py", diff --git a/jax/_src/cudnn/fused_attention_stableHLO_test.py b/tests/fused_attention_stablehlo_test.py similarity index 76% rename from jax/_src/cudnn/fused_attention_stableHLO_test.py rename to tests/fused_attention_stablehlo_test.py index 1baa53bc219c..b72d719262e7 100644 --- a/jax/_src/cudnn/fused_attention_stableHLO_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -16,6 +16,8 @@ from absl.testing import absltest from typing import Any, Optional import os +os.environ['XLA_FLAGS'] = '--xla_dump_disable_metadata --xla_gpu_enable_triton_gemm=false --xla_dump_hlo_as_text --xla_dump_to=./scratch/hlo
--xla_dump_hlo_module_re=.*pjit__unnamed_function.* --xla_dump_hlo_pass_re=.* --xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true' + import numpy as np import jax import jax.numpy as jnp @@ -24,41 +26,41 @@ from jax.experimental.pjit import pjit from jax._src import config from jax._src import test_util as jtu -from jax._src.cudnn.fused_attention_stableHLO import dot_product_attention +from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention config.parse_flags_with_absl() Array = jnp.ndarray def f(query: Array, - key: Array, - value: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, - causal_mask: bool = False, - scale: float = 0.5, - dropout_rate: float = 0.1) -> Array: - - output = dot_product_attention( - query, - key, - value, - scale=scale, - bias=bias, - mask=mask, - is_causal_mask=causal_mask, - dropout_rate=dropout_rate) - return output - -def f_train(query: Array, key: Array, value: Array, - grad: Array, bias: Optional[Array] = None, mask: Optional[Array] = None, causal_mask: bool = False, scale: float = 0.5, dropout_rate: float = 0.1) -> Array: + output = dot_product_attention( + query, + key, + value, + scale=scale, + bias=bias, + mask=mask, + is_causal_mask=causal_mask, + dropout_rate=dropout_rate) + return output + +def f_train(query: Array, + key: Array, + value: Array, + grad: Array, + bias: Optional[Array] = None, + mask: Optional[Array] = None, + causal_mask: bool = False, + scale: float = 0.5, + dropout_rate: float = 0.1) -> Array: + out, f_vjp = jax.vjp( partial(f, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate), query, key, value, bias, None) @@ -101,33 +103,27 @@ def get_causal_mask(input_t): attn_weights = jax.nn.softmax(attn_weights) if dropout_rate > 0.: keep_prob = 1.0 - dropout_rate - dropout_shape = list(attn_weights.shape) - dropout_shape[-2] = 1 - dropout_rng = jax.random.PRNGKey(0) - keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape) - keep = jnp.broadcast_to(keep, attn_weights.shape) - multiplier = ( - keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=attn_weights.dtype)) - attn_weights = attn_weights * multiplier + dropout_rng = jax.random.key(0) + keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) + attn_weights = jax.lax.select(keep, attn_weights / keep_prob, jnp.zeros_like(attn_weights)) return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value) def g_train(query: Array, - key: Array, - value: Array, - grad: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, - causal_mask: bool = False, - scale: float = 0.5, - dropout_rate: float = 0.1) -> Array: + key: Array, + value: Array, + grad: Array, + bias: Optional[Array] = None, + mask: Optional[Array] = None, + causal_mask: bool = False, + scale: float = 0.5, + dropout_rate: float = 0.1) -> Array: out_ref, g_vjp = jax.vjp( partial(g, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate), query, key, value, bias, None) query_grad_ref, key_grad_ref, value_grad_ref, _, _ = g_vjp(grad) return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) -@jtu.with_config(jax_legacy_prng_key='allow') class DotProductAttentionTest(jtu.JaxTestCase): @jtu.sample_product( batch_size=[4], @@ -136,7 +132,7 @@ class DotProductAttentionTest(jtu.JaxTestCase): head_dim=[64, 128], use_bias=[True], is_causal_mask=[False], - dropout_rate=[0], + dropout_rate=[0, 0.5], scale=[0.5], dtype=[jnp.float16, jnp.bfloat16] ) @@ -144,15 +140,14 @@ class 
DotProductAttentionTest(jtu.JaxTestCase): def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int, head_dim: int, use_bias: bool, is_causal_mask: bool, dropout_rate: float, scale: float, dtype: jnp.dtype): - if (seq_len == 256 and is_causal_mask): + if seq_len == 256 and is_causal_mask: self.skipTest("Fused attention does not support mask generation.") - if (seq_len == 256 and head_dim == 128): - self.skipTest("Fused attention does not head dim = 128.") + if seq_len == 256 and head_dim == 128: + self.skipTest("Fused attention does not support head dim = 128.") if len(jax.local_devices()) <= 4: self.skipTest("Require at least 4 devices to run sharding tests.") - os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true' - k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5) + k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5) query = jax.random.normal( k1, (batch_size, seq_len, num_heads, head_dim), dtype=dtype) key = jax.random.normal( @@ -197,14 +192,14 @@ def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int, out, (query_grad, key_grad, value_grad) = pjitted_f_train(query, key, value, grad, bias, None) out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = pjitted_g_train(query, key, value, grad, bias, None) - assert jnp.allclose(out_ref, out, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) if seq_len > 512: # query_grad in flash attention is not deterministic - assert jnp.allclose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) + self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) else: - assert jnp.allclose(query_grad_ref, query_grad, rtol=1e-5, atol=1e-5) - assert jnp.allclose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) - assert jnp.allclose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())
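
The forward partitioning rule in this change keeps the output sharded like the query and derives a sharding for the [*batch, num_heads, q_seq, kv_seq] softmax activation by swapping the sequence and head entries of the query spec. A minimal sketch of that inference follows; the mesh axis names ('data', 'model'), the helper name infer_fwd_shardings, and the example specs are illustrative assumptions, not part of the module.

# A minimal sketch, assuming at least 8 local devices and hypothetical mesh
# axis names 'data' and 'model'; infer_fwd_shardings mirrors the forward
# sharding-inference rule in the diff and is not part of jax._src.cudnn.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

def infer_fwd_shardings(mesh, query_spec, key_spec):
  # Output keeps the query sharding; the activation has layout
  # [*batch, num_heads, q_seq, kv_seq], so the head and sequence entries of
  # the query spec swap places. For self-attention (q and kv sequence sharded
  # identically) the kv_seq dimension is left unsharded.
  out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
  if query_spec[-3] == key_spec[-3]:
    activation_spec = PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None)
  else:
    activation_spec = PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3])
  return out_sharding, NamedSharding(mesh, activation_spec)

mesh = Mesh(np.array(jax.devices()[:8]).reshape(4, 2), ('data', 'model'))
# query/key laid out as [batch, seq, num_heads, head_dim]; batch sharded over
# 'data', heads over 'model', sequence and head_dim replicated.
q_spec = ('data', None, 'model', None)
k_spec = ('data', None, 'model', None)
out_sharding, activation_sharding = infer_fwd_shardings(mesh, q_spec, k_spec)
print(out_sharding.spec)         # PartitionSpec('data', None, 'model', None)
print(activation_sharding.spec)  # PartitionSpec('data', 'model', None, None)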
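
And a short usage sketch of the public dot_product_attention entry point added in this change. The shapes, dtype, and scale value are illustrative assumptions; running it requires a GPU backend with a sufficiently new cuDNN, as enforced by check_cudnn_version.

# Illustrative call of the new API (assumed shapes and dtype; needs a cuDNN-capable GPU).
import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

batch, q_seq, kv_seq, num_heads, head_dim = 4, 1024, 1024, 8, 64
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
query = jax.random.normal(k1, (batch, q_seq, num_heads, head_dim), dtype=jnp.bfloat16)
key = jax.random.normal(k2, (batch, kv_seq, num_heads, head_dim), dtype=jnp.bfloat16)
value = jax.random.normal(k3, (batch, kv_seq, num_heads, head_dim), dtype=jnp.bfloat16)

# bias/mask default to None and dropout stays disabled; the scale is folded
# into the query before lowering, per the TODO in dot_product_attention.
fused_attention = jax.jit(lambda q, k, v: dot_product_attention(q, k, v, scale=0.125))
out = fused_attention(query, key, value)
print(out.shape)  # (4, 1024, 8, 64)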