diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py
index a0f7abb449..7be4ef1a89 100644
--- a/onnxscript/function_libs/torch_lib/ops/nn.py
+++ b/onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1700,7 +1700,7 @@ def aten_scaled_dot_product_attention(
 
 
 @torch_op("aten::_scaled_dot_product_flash_attention", private=True)
-def _aten_scaled_dot_product_flash_attention_fillin_empty_outputs(
+def _aten__scaled_dot_product_flash_attention_fillin_empty_outputs(
     query: TFloat,
 ) -> Tuple[FLOAT, INT64, INT64, FLOAT]:
     query_first_three_dims = op.Slice(
@@ -1723,7 +1723,7 @@ def _aten_scaled_dot_product_flash_attention_fillin_empty_outputs(
 
 
 @torch_op("aten::_scaled_dot_product_flash_attention", trace_only=True)
-def aten_scaled_dot_product_flash_attention(
+def aten__scaled_dot_product_flash_attention(
     query: TFloat,
     key: TFloat,
     value: TFloat,
@@ -1751,7 +1751,7 @@ def aten_scaled_dot_product_flash_attention(
         empty_tensor_int,
         empty_int,
         empty_tensor_float,
-    ) = _aten_scaled_dot_product_flash_attention_fillin_empty_outputs(query)
+    ) = _aten__scaled_dot_product_flash_attention_fillin_empty_outputs(query)
 
     return (
         result,
@@ -1766,6 +1766,73 @@ def aten_scaled_dot_product_flash_attention(
     )
 
 
+@torch_op("aten::_scaled_dot_product_efficient_attention", private=True)
+def _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
+    query: TFloat,
+    compute_log_sumexp: bool,
+) -> Tuple[FLOAT, INT64]:
+    """_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)"""
+
+    query = op.Transpose(query, perm=[0, 2, 1, 3])
+    query_shape = op.Shape(query)
+    query_first_dims = query_shape[:1]
+    query_second_dims = query_shape[1:2]
+    num_heads = query_shape[-2:-1]
+
+    if compute_log_sumexp:
+        logsumexp_dim = op.Cast(
+            op.Ceil(op.Cast(query_second_dims, to=FLOAT.dtype) / 32.0) * 32.0, to=INT64.dtype
+        )
+        logsum_exp = op.Expand(
+            0.0, op.Concat(query_first_dims, num_heads, logsumexp_dim, axis=0)
+        )
+    else:
+        logsum_exp = op.Expand(0.0, op.Concat(query_first_dims, num_heads, [0], axis=0))
+
+    # See Note [Seed and Offset]:
+    empty_tensor_int = op.Cast(
+        op.ConstantOfShape(
+            op.Constant(value=onnx.helper.make_tensor("Empty_INTS", INT64.dtype, [0], []))
+        ),
+        to=INT64.dtype,
+    )
+
+    return logsum_exp, empty_tensor_int
+
+
+@torch_op("aten::_scaled_dot_product_efficient_attention", trace_only=True)
+def aten__scaled_dot_product_efficient_attention(
+    query: TFloat,
+    key: TFloat,
+    value: TFloat,
+    attn_bias: Optional[TFloat],  # pylint: disable=unused-argument
+    compute_log_sumexp: bool,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+) -> Tuple[TFloat, FLOAT, INT64, INT64]:
+    """_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)"""
+
+    result = aten_scaled_dot_product_attention(
+        query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale
+    )
+
+    # The following outputs are not consumed by the graph.
+    (
+        logsumexp,
+        empty_tensor_int,
+    ) = _aten_scaled_dot_product_efficient_attention_fillin_empty_outputs(
+        query, compute_log_sumexp
+    )
+
+    return (
+        result,
+        logsumexp,
+        empty_tensor_int,
+        empty_tensor_int,
+    )
+
+
 @torch_op("aten::scaled_dot_product_attention", trace_only=True)
 def aten_scaled_dot_product_attention_bool_mask(
     query: TFloat,
diff --git a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
index 2fa8c9a609..64e6a06225 100644
--- a/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
+++ b/onnxscript/tests/function_libs/torch_lib/extra_opinfo.py
@@ -9,7 +9,11 @@
 
 import torch
 from torch import testing as torch_testing
-from torch.testing._internal import common_dtype, common_methods_invocations
+from torch.testing._internal import (
+    common_device_type,
+    common_dtype,
+    common_methods_invocations,
+)
 from torch.testing._internal.opinfo import core as opinfo_core
 
 S = 5
@@ -1298,7 +1302,7 @@ def sample_inputs__softmax(
         yield opinfo_core.SampleInput(make_arg(shape), args=dim, kwargs=kwargs)
 
 
-def sample_inputs_scaled_dot_product_flash_attention(
+def sample_inputs__scaled_dot_product_flash_attention(
     op_info, device, dtype, requires_grad, **kwargs
 ):
     del op_info
@@ -1342,6 +1346,41 @@ def sample_inputs_scaled_dot_product_flash_attention(
     yield from samples
 
 
+def sample_inputs__scaled_dot_product_efficient_attention(
+    op_info, device, dtype, requires_grad, **kwargs
+):
+    del op_info
+    del kwargs
+
+    make = opinfo_core.partial(
+        opinfo_core.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+    batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8
+
+    dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
+    dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
+
+    qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
+    samples = []
+    for qkv_shape, is_causal, dropout_p, compute_log_sumexp in opinfo_core.product(
+        qkv_shapes, [True, False], [0.0], [True, False]
+    ):
+        shape_q, shape_kv = qkv_shape
+        samples.append(
+            opinfo_core.SampleInput(
+                make(shape_q),
+                make(shape_kv),
+                make(shape_kv),
+                attn_bias=None,
+                is_causal=is_causal,
+                dropout_p=dropout_p,
+                compute_log_sumexp=compute_log_sumexp,
+            )
+        )
+
+    yield from samples
+
+
 # NOTE: In `_native_batch_norm_legit` tests, it generates two kinds of args:
 # 1. (input, weight, bias, running_mean, running_var, training, momentum, eps)
 # 2. (input, weight, bias, training, momentum, eps)
@@ -1765,11 +1804,26 @@ def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwar
         dtypes=common_dtype.floating_types_and(torch.bfloat16),
         # NOTE: Different from aten::scaled_dot_product_attention, this op doesn't support
         # dim<=3 input.
-        sample_inputs_func=sample_inputs_scaled_dot_product_flash_attention,
+        sample_inputs_func=sample_inputs__scaled_dot_product_flash_attention,
+        supports_out=False,
+        supports_forward_ad=False,
+        supports_fwgrad_bwgrad=True,
+        check_batched_forward_grad=False,
+    ),
+    opinfo_core.OpInfo(
+        "ops.aten._scaled_dot_product_efficient_attention",
+        aten_name="_scaled_dot_product_efficient_attention",
+        # only supports CUDA
+        dtypes=common_dtype.empty_types(),
+        dtypesIfCUDA=common_dtype.floating_types_and(torch.bfloat16),
+        # NOTE: Different from aten::scaled_dot_product_attention, this op doesn't support
+        # dim<=3 input.
+        sample_inputs_func=sample_inputs__scaled_dot_product_efficient_attention,
         supports_out=False,
         supports_forward_ad=False,
         supports_fwgrad_bwgrad=True,
         check_batched_forward_grad=False,
+        decorators=[common_device_type.onlyCUDA],
     ),
     opinfo_core.OpInfo(
         "ops.aten._native_batch_norm_legit",
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test.py b/onnxscript/tests/function_libs/torch_lib/ops_test.py
index 0913267be6..9cae237c80 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test.py
@@ -406,11 +406,11 @@ def test_complex_output_match_opinfo_(
 
 
 common_device_type.instantiate_device_type_tests(
-    TestOutputConsistencyEager, globals(), only_for="cpu"
+    TestOutputConsistencyEager, globals(), only_for=["cpu", "cuda"]
 )
 
 common_device_type.instantiate_device_type_tests(
-    TestOutputConsistencyFullGraph, globals(), only_for="cpu"
+    TestOutputConsistencyFullGraph, globals(), only_for=["cpu", "cuda"]
 )
 
 if __name__ == "__main__":
diff --git a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
index 15feaa6b02..abf3198723 100644
--- a/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
+++ b/onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -2018,15 +2018,33 @@ def _where_input_wrangler(
     ),
     TorchLibOpInfo(
         "ops.aten._scaled_dot_product_flash_attention",
-        nn_ops.aten_scaled_dot_product_flash_attention,
+        nn_ops.aten__scaled_dot_product_flash_attention,
         trace_only=True,
         tolerance={torch.float32: (3e-4, 1.5e-5)},
         # Output[0] is OK, but other outputs just have the same shape with zero values
         nondeterministic=True,
+        compare_shape_only_for_output=(1, 2, 3, 4, 5, 6, 7, 8),
     ).skip(
         enabled_if=version_utils.torch_older_than("2.1"),
         reason="The operator is not supported in older version.",
     ),
+    TorchLibOpInfo(
+        "ops.aten._scaled_dot_product_efficient_attention",
+        nn_ops.aten__scaled_dot_product_efficient_attention,
+        trace_only=True,
+        tolerance={torch.float32: (3e-4, 1.5e-5)},
+        # Output[0] is OK, but other outputs just have the same shape with zero values
+        nondeterministic=True,
+        compare_shape_only_for_output=(1, 2, 3),
+    )
+    .skip(
+        enabled_if=version_utils.torch_older_than("2.1"),
+        reason="The operator is not supported in older version.",
+    )
+    .skip(
+        enabled_if=not torch.cuda.is_available(),
+        reason="_scaled_dot_product_efficient_attention only supports CUDA",
+    ),
     TorchLibOpInfo(
         "nn.functional.scaled_dot_product_attention_bool_mask",
         nn_ops.aten_scaled_dot_product_attention_bool_mask,