diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0bcf5fa38a6f..b59824069bf3 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -748,7 +748,7 @@ endif() function(onnxruntime_set_compile_flags target_name) if (CPUINFO_SUPPORTED) - onnxruntime_add_include_to_target(${target_name} cpuinfo) + onnxruntime_add_include_to_target(${target_name} cpuinfo::cpuinfo) endif() if(onnxruntime_ENABLE_EAGER_MODE) target_compile_definitions(${target_name} PRIVATE ENABLE_EAGER_MODE) @@ -832,7 +832,7 @@ function(onnxruntime_set_compile_flags target_name) target_compile_options(${target_name} PRIVATE "-Wno-unused-parameter") endif() target_compile_definitions(${target_name} PUBLIC -DNSYNC_ATOMIC_CPP11) - target_include_directories(${target_name} PRIVATE "${google_nsync_SOURCE_DIR}/public") + onnxruntime_add_include_to_target(${target_name} nsync::nsync_cpp) endif() foreach(ORT_FLAG ${ORT_PROVIDER_FLAGS}) target_compile_definitions(${target_name} PRIVATE ${ORT_FLAG}) @@ -1469,7 +1469,7 @@ if (WIN32) list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${SYS_PATH_LIB}) list(APPEND onnxruntime_EXTERNAL_LIBRARIES debug Dbghelp) else() - list(APPEND onnxruntime_EXTERNAL_LIBRARIES nsync_cpp) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES nsync::nsync_cpp) list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${CMAKE_DL_LIBS} Threads::Threads) endif() diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index b829e0917363..337f4ce20482 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -236,7 +236,10 @@ if (NOT WIN32) #nsync tests failed on Mac Build set(NSYNC_ENABLE_TESTS OFF CACHE BOOL "" FORCE) onnxruntime_fetchcontent_makeavailable(google_nsync) - set(nsync_SOURCE_DIR ${google_nsync_SOURCE_DIR}) + if (google_nsync_SOURCE_DIR) + add_library(nsync::nsync_cpp ALIAS nsync_cpp) + target_include_directories(nsync_cpp PUBLIC ${google_nsync_SOURCE_DIR}/public) + endif() endif() if(onnxruntime_USE_CUDA) @@ -360,6 +363,9 @@ FetchContent_Declare( if (CPUINFO_SUPPORTED) onnxruntime_fetchcontent_makeavailable(pytorch_cpuinfo) + if (pytorch_cpuinfo_SOURCE_DIR) + add_library(cpuinfo::cpuinfo ALIAS cpuinfo) + endif() endif() diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 0e02ad9daa7e..685df7a48769 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -194,8 +194,8 @@ if (ARM64 OR ARM OR X86 OR X64 OR X86_64) # Using it mainly in ARM with Android. # Its functionality in detecting x86 cpu features are lacking, so is support for Windows. 
if (CPUINFO_SUPPORTED) - onnxruntime_add_include_to_target(onnxruntime_common cpuinfo) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo clog) + onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo) endif() endif() endif() diff --git a/cmake/onnxruntime_flatbuffers.cmake b/cmake/onnxruntime_flatbuffers.cmake index c0cd1699bb33..3ab4c19122ba 100644 --- a/cmake/onnxruntime_flatbuffers.cmake +++ b/cmake/onnxruntime_flatbuffers.cmake @@ -9,7 +9,7 @@ file(GLOB onnxruntime_flatbuffers_srcs CONFIGURE_DEPENDS source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_flatbuffers_srcs}) onnxruntime_add_static_library(onnxruntime_flatbuffers ${onnxruntime_flatbuffers_srcs}) -onnxruntime_add_include_to_target(onnxruntime_flatbuffers onnx flatbuffers ${GSL_TARGET}) +onnxruntime_add_include_to_target(onnxruntime_flatbuffers onnx flatbuffers::flatbuffers ${GSL_TARGET}) if(onnxruntime_ENABLE_INSTRUMENT) target_compile_definitions(onnxruntime_flatbuffers PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT) endif() diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 452985023560..53508e64d04a 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -548,10 +548,10 @@ if (onnxruntime_USE_CUDA) if(APPLE) set_property(TARGET onnxruntime_providers_cuda APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/cuda/exported_symbols.lst") - target_link_libraries(onnxruntime_providers_cuda PRIVATE nsync_cpp) + target_link_libraries(onnxruntime_providers_cuda PRIVATE nsync::nsync_cpp) elseif(UNIX) set_property(TARGET onnxruntime_providers_cuda APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/cuda/version_script.lds -Xlinker --gc-sections") - target_link_libraries(onnxruntime_providers_cuda PRIVATE nsync_cpp) + target_link_libraries(onnxruntime_providers_cuda PRIVATE nsync::nsync_cpp) elseif(WIN32) set_property(TARGET onnxruntime_providers_cuda APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/cuda/symbols.def") else() @@ -609,10 +609,10 @@ if (onnxruntime_USE_DNNL) INSTALL_RPATH "@loader_path" BUILD_WITH_INSTALL_RPATH TRUE INSTALL_RPATH_USE_LINK_PATH FALSE) - target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync_cpp) + target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync::nsync_cpp) elseif(UNIX) set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/dnnl/version_script.lds -Xlinker --gc-sections -Xlinker -rpath=\$ORIGIN") - target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync_cpp) + target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync::nsync_cpp) elseif(WIN32) set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/dnnl/symbols.def") else() @@ -742,11 +742,11 @@ if (onnxruntime_USE_TENSORRT) if(APPLE) set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/tensorrt/exported_symbols.lst") - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync_cpp) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync::nsync_cpp) elseif(UNIX) set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") 
set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/tensorrt/version_script.lds -Xlinker --gc-sections") - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync_cpp stdc++fs) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync::nsync_cpp stdc++fs) elseif(WIN32) set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/tensorrt/symbols.def") else() @@ -1091,7 +1091,7 @@ if (onnxruntime_USE_QNN) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_qnn_cc_srcs}) onnxruntime_add_static_library(onnxruntime_providers_qnn ${onnxruntime_providers_qnn_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_qnn onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers Boost::mp11) + onnxruntime_add_include_to_target(onnxruntime_providers_qnn onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers::flatbuffers Boost::mp11) target_link_libraries(onnxruntime_providers_qnn) add_dependencies(onnxruntime_providers_qnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_qnn PROPERTIES CXX_STANDARD_REQUIRED ON) @@ -1286,7 +1286,7 @@ if (onnxruntime_USE_MIGRAPHX) target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare) set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections") - target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync_cpp stdc++fs) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync::nsync_cpp stdc++fs) include(CheckLibraryExists) check_library_exists(migraphx::c "migraphx_program_run_async" "/opt/rocm/migraphx/lib" HAS_STREAM_SYNC) @@ -1552,7 +1552,7 @@ if (onnxruntime_USE_ROCM) if(UNIX) set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections") - target_link_libraries(onnxruntime_providers_rocm PRIVATE nsync_cpp) + target_link_libraries(onnxruntime_providers_rocm PRIVATE nsync::nsync_cpp) else() message(FATAL_ERROR "onnxruntime_providers_rocm unknown platform, need to specify shared library exports for it") endif() @@ -1688,7 +1688,7 @@ if (onnxruntime_USE_CANN) onnxruntime_add_include_to_target(onnxruntime_providers_cann onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) add_dependencies(onnxruntime_providers_cann onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) - target_link_libraries(onnxruntime_providers_cann PRIVATE ascendcl acl_op_compiler fmk_onnx_parser nsync_cpp ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED}) + target_link_libraries(onnxruntime_providers_cann PRIVATE ascendcl acl_op_compiler fmk_onnx_parser nsync::nsync_cpp ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED}) target_link_directories(onnxruntime_providers_cann PRIVATE ${onnxruntime_CANN_HOME}/lib64) target_include_directories(onnxruntime_providers_cann PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} 
${onnxruntime_CANN_HOME} ${onnxruntime_CANN_HOME}/include) @@ -1710,7 +1710,7 @@ if (onnxruntime_USE_AZURE) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_azure_src}) onnxruntime_add_static_library(onnxruntime_providers_azure ${onnxruntime_providers_azure_src}) add_dependencies(onnxruntime_providers_azure ${onnxruntime_EXTERNAL_DEPENDENCIES}) - onnxruntime_add_include_to_target(onnxruntime_providers_azure onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11) + onnxruntime_add_include_to_target(onnxruntime_providers_azure onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11) target_link_libraries(onnxruntime_providers_azure PRIVATE onnx onnxruntime_common onnxruntime_framework) set_target_properties(onnxruntime_providers_azure PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_providers_azure PROPERTIES LINKER_LANGUAGE CXX) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index ebd6229204bb..7bf641d15154 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -90,11 +90,14 @@ set(contrib_ops_excluded_files "cuda_contrib_kernels.h" "inverse.cc" "fused_conv.cc" + "decoder/decoder_masked_multihead_attention.h" + "decoder/decoder_masked_multihead_attention.cc" "decoder/decoder_masked_self_attention.h" "decoder/decoder_masked_self_attention.cc" "decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h" "decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention.h" "decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu" + "decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu" "decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu" "decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu" ) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 930db84ef42d..4f1cccc7fd49 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -668,8 +668,8 @@ if(MSVC) "$<$>:/wd6326>") else() target_compile_definitions(onnxruntime_test_utils PUBLIC -DNSYNC_ATOMIC_CPP11) - target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} - ${nsync_SOURCE_DIR}/public) + target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) + onnxruntime_add_include_to_target(onnxruntime_test_utils nsync::nsync_cpp) endif() if (onnxruntime_USE_NCCL) target_include_directories(onnxruntime_test_utils PRIVATE ${NCCL_INCLUDE_DIRS}) @@ -702,8 +702,8 @@ if(MSVC) "$<$>:/utf-8>") else() target_compile_definitions(onnx_test_runner_common PUBLIC -DNSYNC_ATOMIC_CPP11) - target_include_directories(onnx_test_runner_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT} - ${nsync_SOURCE_DIR}/public) + target_include_directories(onnx_test_runner_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) + onnxruntime_add_include_to_target(onnx_test_runner_common nsync::nsync_cpp) endif() if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) #TODO: fix the warnings, they are dangerous @@ -1070,7 +1070,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) # "Global initializer calls a non-constexpr function." BENCHMARK_CAPTURE macro needs this. 
target_compile_options(onnxruntime_mlas_benchmark PRIVATE /wd26426) else() - target_link_libraries(onnxruntime_mlas_benchmark PRIVATE nsync_cpp ${CMAKE_DL_LIBS}) + target_link_libraries(onnxruntime_mlas_benchmark PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS}) endif() if (CPUINFO_SUPPORTED AND NOT onnxruntime_BUILD_WEBASSEMBLY) target_link_libraries(onnxruntime_mlas_benchmark PRIVATE cpuinfo) @@ -1128,7 +1128,7 @@ if(onnxruntime_ENABLE_EAGER_MODE) list(APPEND onnxruntime_eager_mode_libs onnxruntime_training tensorboard) endif() IF(NOT WIN32) - list(APPEND onnxruntime_eager_mode_libs nsync_cpp) + list(APPEND onnxruntime_eager_mode_libs nsync::nsync_cpp) endif() target_link_libraries(onnxruntime_eager_mode_test PRIVATE ${onnxruntime_eager_mode_libs} Threads::Threads ${onnxruntime_EXTERNAL_LIBRARIES}) if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) @@ -1188,7 +1188,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) if(NOT WIN32) - list(APPEND onnxruntime_perf_test_libs nsync_cpp) + list(APPEND onnxruntime_perf_test_libs nsync::nsync_cpp) if(onnxruntime_USE_SNPE) list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe) endif() @@ -1232,7 +1232,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) # test inference using shared lib set(onnxruntime_shared_lib_test_LIBS onnxruntime_mocked_allocator onnxruntime_test_utils onnxruntime_common onnx_proto) if(NOT WIN32) - list(APPEND onnxruntime_shared_lib_test_LIBS nsync_cpp) + list(APPEND onnxruntime_shared_lib_test_LIBS nsync::nsync_cpp) if(onnxruntime_USE_SNPE) list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_providers_snpe) endif() @@ -1354,7 +1354,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo) endif() if(NOT WIN32) - target_link_libraries(onnxruntime_mlas_test PRIVATE nsync_cpp ${CMAKE_DL_LIBS}) + target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS}) endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs}) @@ -1546,7 +1546,7 @@ endif() if (NOT onnxruntime_BUILD_WEBASSEMBLY AND (NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_MINIMAL_BUILD_CUSTOM_OPS)) - file(GLOB_RECURSE custom_op_get_const_input_test_library_src + file(GLOB_RECURSE custom_op_get_const_input_test_library_src "${TEST_SRC_DIR}/testdata/custom_op_get_const_input_test_library/custom_op_lib.cc" "${TEST_SRC_DIR}/testdata/custom_op_get_const_input_test_library/custom_op.h" "${TEST_SRC_DIR}/testdata/custom_op_get_const_input_test_library/custom_op.cc" @@ -1562,7 +1562,7 @@ if (NOT onnxruntime_BUILD_WEBASSEMBLY AND (NOT onnxruntime_MINIMAL_BUILD OR onnx if (APPLE) set(ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG "-Xlinker -dead_strip") else() - string(CONCAT ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG + string(CONCAT ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG "-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_get_const_input_test_library/custom_op_lib.lds " "-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack") endif() @@ -1582,7 +1582,7 @@ if (onnxruntime_BUILD_SHARED_LIB AND NOT onnxruntime_BUILD_WEBASSEMBLY AND NOT o set(onnxruntime_logging_apis_test_LIBS onnxruntime_common onnxruntime_test_utils) if(NOT WIN32) - list(APPEND onnxruntime_logging_apis_test_LIBS nsync_cpp ${CMAKE_DL_LIBS}) + list(APPEND onnxruntime_logging_apis_test_LIBS nsync::nsync_cpp 
${CMAKE_DL_LIBS}) endif() AddTest(DYN diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 886668d6b46b..2188565a876d 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -97,7 +97,7 @@ target_compile_options(onnx PRIVATE -Wno-unused-parameter -Wno-unused-variable) if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB) bundle_static_library(onnxruntime_webassembly - nsync_cpp + nsync::nsync_cpp ${PROTOBUF_LIB} onnx onnx_proto @@ -172,7 +172,7 @@ else() endif() target_link_libraries(onnxruntime_webassembly PRIVATE - nsync_cpp + nsync::nsync_cpp ${PROTOBUF_LIB} onnx onnx_proto diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 824df7728207..cb7823f06b4c 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -20,6 +20,7 @@ Do not modify directly.* * com.microsoft.ConvTransposeWithDynamicPads * com.microsoft.CropAndResize * com.microsoft.DecoderAttention + * com.microsoft.DecoderMaskedMultiHeadAttention * com.microsoft.DecoderMaskedSelfAttention * com.microsoft.DequantizeBFP * com.microsoft.DequantizeLinear @@ -1102,6 +1103,75 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.DecoderMaskedMultiHeadAttention** + + Multihead attention that supports input sequence length of 1. + Similar to DecoderMaskedSelfAttention but this op excludes QKV MatMul and Bias. + This op supports both Self and Cross Attention. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+<dl>
+<dt><tt>mask_filter_value</tt> : float</dt>
+<dd>The value to be filled in the attention mask. Default value is -10000.0f</dd>
+<dt><tt>num_heads</tt> : int (required)</dt>
+<dd>Number of attention heads</dd>
+<dt><tt>past_present_share_buffer</tt> : int</dt>
+<dd>Corresponding past and present are same tensor, its size is (batch_size, num_heads, max_sequence_length, head_size)</dd>
+<dt><tt>scale</tt> : float</dt>
+<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
+</dl>
+
+#### Inputs (3 - 10)
+
+<dl>
+<dt><tt>query</tt> : T</dt>
+<dd>Query with shape (batch_size, 1, hidden_size)</dd>
+<dt><tt>key</tt> : T</dt>
+<dd>Key with shape (batch_size, 1, hidden_size) for self attention or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention</dd>
+<dt><tt>value</tt> : T</dt>
+<dd>Value with shape (batch_size, 1, v_hidden_size) for self attention or past_value with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention</dd>
+<dt><tt>mask_index</tt> (optional) : M</dt>
+<dd>Mask values of shape (batch_size, total_sequence_length) or (batch_size, kv_sequence_length)</dd>
+<dt><tt>relative_position_bias</tt> (optional) : T</dt>
+<dd>Additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)</dd>
+<dt><tt>past_key</tt> (optional) : T</dt>
+<dd>Past state for key with shape (batch_size, num_heads, past_sequence_length, head_size) for self attention. When past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size). The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size), which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x), is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.</dd>
+<dt><tt>past_value</tt> (optional) : T</dt>
+<dd>Past state for value with shape (batch_size, num_heads, past_sequence_length, head_size) for self attention. When past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size).</dd>
+<dt><tt>past_sequence_length</tt> (optional) : M</dt>
+<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0). Cross attention doesn't need this input.</dd>
+<dt><tt>beam_width</tt> (optional) : M</dt>
+<dd>The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.</dd>
+<dt><tt>cache_indirection</tt> (optional) : M</dt>
+<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies which beam the 'k'-th token came from for the 'j'-th beam for batch 'i' in the current iteration</dd>
+</dl>
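The past_key re-ordering described above is easier to see as array operations. Below is a rough numpy sketch of the layout change, using example shapes and `x = 16 / sizeof(T)` (8 for float16); it only illustrates the described memory layout and is not the kernel's actual cache-copy code.

```python
import numpy as np

# Example shapes (assumptions for illustration only).
batch_size, num_heads, max_sequence_length, head_size = 2, 8, 64, 32
x = 8  # 16 / sizeof(T) for T = float16

k_cache = np.zeros((batch_size, num_heads, max_sequence_length, head_size), dtype=np.float16)

# (B, N, S_max, H) -> (B, N, S_max, H/x, x) -> (B, N, H/x, S_max, x)
k_cache_reordered = (
    k_cache.reshape(batch_size, num_heads, max_sequence_length, head_size // x, x)
           .transpose(0, 1, 3, 2, 4)
)
assert k_cache_reordered.shape == (2, 8, 4, 64, 8)
```

This grouping keeps each 16-byte chunk of a head contiguous along the sequence dimension, which is what lets the kernel issue coalesced loads from the key cache.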
+
+#### Outputs (1 - 3)
+
+<dl>
+<dt><tt>output</tt> : T</dt>
+<dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
+<dt><tt>present_key</tt> (optional) : T</dt>
+<dd>past state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
+<dt><tt>present_value</tt> (optional) : T</dt>
+<dd>past state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
+</dl>
+
+#### Type Constraints
+
+<dl>
+<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
+<dd>Constrain input and output types to float tensors.</dd>
+<dt><tt>M</tt> : tensor(int32)</dt>
+<dd>Constrain mask index to integer types</dd>
+</dl>
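For reference, here is a hedged Python sketch of how a graph author might instantiate this operator with the `onnx` helper API. Tensor names, shapes, and the `num_heads` value are illustrative assumptions, not values taken from this PR.

```python
from onnx import helper

# Self-attention decoding step with a shared past/present KV buffer.
# Optional inputs that are omitted are passed as empty strings.
node = helper.make_node(
    "DecoderMaskedMultiHeadAttention",
    inputs=[
        "query",                 # (batch_size, 1, hidden_size)
        "key",                   # (batch_size, 1, hidden_size) for self attention
        "value",                 # (batch_size, 1, v_hidden_size) for self attention
        "",                      # mask_index (optional)
        "",                      # relative_position_bias (optional)
        "past_key",              # (batch_size, num_heads, max_sequence_length, head_size)
        "past_value",            # (batch_size, num_heads, max_sequence_length, head_size)
        "past_sequence_length",  # scalar int32
        "beam_width",            # scalar int32, assumed 1 if omitted
        "cache_indirection",     # [batch_size, beam_width, max_output_length]
    ],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=8,                 # example value
    past_present_share_buffer=1,
)
```

The model's opset imports would also need an entry for the `com.microsoft` domain (version 1) for this node to validate.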
+ + ### **com.microsoft.DecoderMaskedSelfAttention** Self attention that supports input sequence length of 1. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d4cd2f936389..050d84b19cc9 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -798,6 +798,7 @@ Do not modify directly.* |ComplexMulConj|*in* A:**T**
<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|DecoderAttention|*in* query:**T**<br> *in* key:**T**<br> *in* q_weight:**T**<br> *in* kv_weight:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**B**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* static_kv:**B**<br> *in* use_past:**B**<br> *in* has_layer_state:**B**<br> *in* has_key_padding_mask:**B**<br> *out* output:**T**<br> *out* new_key_cache:**T**<br> *out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)|
+|DecoderMaskedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* mask_index:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|DecoderMaskedSelfAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(float16)|
|DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br> **Q** = tensor(int8)<br>
**S** = tensor(float)| diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index cc7dad81b4dc..fe1c57e5711f 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -20,10 +20,12 @@ Status CheckInputs(const T* query, const T* relative_position_bias, const T* past_key, const T* past_value, + const T* past_seq_len, void* parameters, int num_heads, float mask_filter_value, float scale, + bool past_present_share_buffer, int max_threads_per_block) { // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None // relative_position_bias : (B, 1, S, L) @@ -59,6 +61,7 @@ Status CheckInputs(const T* query, int kv_sequence_length = sequence_length; int past_sequence_length = 0; + int max_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { const auto& past_key_dims = past_key->Shape().GetDims(); const auto& past_value_dims = past_value->Shape().GetDims(); @@ -110,6 +113,14 @@ Status CheckInputs(const T* query, past_value_dims[3]); } past_sequence_length = static_cast(past_key_dims[2]); + max_sequence_length = static_cast(past_key_dims[2]); + if (past_present_share_buffer) { + if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "past_sequence_length tensor must be of one element when past_present_share_buffer is set"); + } + past_sequence_length = *((*past_seq_len).template Data()); + } } else if (past_key != nullptr || past_value != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' and 'past_value' shall be both present or both absent"); @@ -277,7 +288,7 @@ Status CheckInputs(const T* query, output_parameters->past_sequence_length = past_sequence_length; output_parameters->kv_sequence_length = kv_sequence_length; output_parameters->total_sequence_length = total_sequence_length; - output_parameters->max_sequence_length = 0; + output_parameters->max_sequence_length = max_sequence_length; output_parameters->input_hidden_size = 0; output_parameters->hidden_size = hidden_size; output_parameters->v_hidden_size = v_hidden_size; @@ -285,7 +296,7 @@ Status CheckInputs(const T* query, output_parameters->v_head_size = v_hidden_size / num_heads; output_parameters->num_heads = num_heads; output_parameters->is_unidirectional = false; - output_parameters->past_present_share_buffer = false; + output_parameters->past_present_share_buffer = past_present_share_buffer; output_parameters->mask_filter_value = mask_filter_value; output_parameters->mask_type = mask_type; output_parameters->scale = scale; diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index f077d56f03b7..7c4b65b11372 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -88,10 +88,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { relative_position_bias, past_key, past_value, + nullptr, // past_seq_len ¶meters, num_heads_, mask_filter_value_, scale_, + false, // past_present_share_buffer device_prop.maxThreadsPerBlock)); int sequence_length = parameters.sequence_length; diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 7a14267176eb..0b800c78bc2d 100644 --- 
a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -133,6 +133,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrd class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); #ifdef ENABLE_ATEN class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); @@ -279,6 +281,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc new file mode 100644 index 000000000000..6130fd9eeb48 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc @@ -0,0 +1,223 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/cuda/decoder/decoder_masked_multihead_attention.h" +#include "contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// TODO: refactor +static constexpr int kPastSequenceLengthInputIndex = 7; +static constexpr int kBeamWidthInputIndex = 8; +static constexpr int kCacheIndirectionInputIndex = 9; +static constexpr int kPastInputIndex = 5; +static constexpr int kPresentOutputIndex = 1; + +#define REGISTER_KERNEL_TYPED(T1, T2) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DecoderMaskedMultiHeadAttention, \ + kMSDomain, \ + 1, \ + T1, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(kPastInputIndex, kPresentOutputIndex) \ + .MayInplace(kPastInputIndex + 1, kPresentOutputIndex + 1) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \ + .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \ + DecoderMaskedMultiHeadAttention); + +REGISTER_KERNEL_TYPED(float, float) +REGISTER_KERNEL_TYPED(MLFloat16, uint16_t) + +template +DecoderMaskedMultiHeadAttention::DecoderMaskedMultiHeadAttention(const OpKernelInfo& info) : CudaKernel(info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + scale_ = info.GetAttrOrDefault("scale", 0.0f); + past_present_share_buffer_ = 
info.GetAttrOrDefault("past_present_share_buffer", 0LL); +} + +template +Status DecoderMaskedMultiHeadAttention::ComputeInternal(OpKernelContext* context) const { + const Tensor* query = context->Input(0); + const Tensor* key = context->Input(1); + const Tensor* value = context->Input(2); + const Tensor* mask_index = context->Input(3); + const Tensor* relative_position_bias = context->Input(4); + const Tensor* past_key = context->Input(kPastInputIndex); + const Tensor* past_value = context->Input(kPastInputIndex + 1); + const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); + const Tensor* beam_width = context->Input(kBeamWidthInputIndex); + const Tensor* cache_indir = context->Input(kCacheIndirectionInputIndex); + + auto& device_prop = GetDeviceProp(); + DecoderMaskedMultiHeadAttentionParams parameters; + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, + key, + value, + nullptr, //bias + mask_index, + relative_position_bias, + past_key, + past_value, + past_seq_len, + ¶meters, + num_heads_, + mask_filter_value_, + scale_, + past_present_share_buffer_, + device_prop.maxThreadsPerBlock)); + + int batch_size = parameters.batch_size; + int sequence_length = parameters.sequence_length; + + // This kernel is for decoding only (i.e.) sequence length has to be 1 + if (sequence_length != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention"); + } + + if (parameters.head_size != parameters.v_head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "QK head size should be same as V head size to use DecoderMaskedMultiHeadAttention"); + } + + if (parameters.mask_type != AttentionMaskType::MASK_2D_KEY_PADDING && + parameters.mask_type != AttentionMaskType::MASK_NONE) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "DecoderMaskedMultiHeadAttention only supports no mask or 2D key " + "padding mask of shape [batch, total_seq_length] currently"); + } + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(parameters.v_hidden_size); + Tensor* output = context->Output(0, output_shape); + + // Present input will have the same shape as the past input + Tensor* present_key = context->Output(kPresentOutputIndex, past_key->Shape()); + Tensor* present_value = context->Output(kPresentOutputIndex + 1, past_value->Shape()); + + auto cuda_stream = Stream(context); + + parameters.is_mha = true; + + // Update the q buffers + parameters.q = const_cast(query->Data()); + + // Update the relative position bias for self attention + if (relative_position_bias != nullptr) { + parameters.relative_attention_bias = const_cast(relative_position_bias->Data()); + } + + // Decoder cross-attention + if (past_key == nullptr && present_key == nullptr) { + if (relative_position_bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "DecoderMaskedMultiHeadAttention does not support relative position bias for cross-attention"); + } + + parameters.is_cross_attention = true; + parameters.total_sequence_length = parameters.kv_sequence_length; + parameters.max_sequence_length = parameters.kv_sequence_length; + // parameters.k and paraneters.v are nullptr + parameters.k_cache = const_cast(key->Data()); + parameters.v_cache = const_cast(value->Data()); + } else { + // Sanity check + ORT_ENFORCE(past_present_share_buffer_); + + auto* present_key_data = 
present_key->MutableData(); + auto* present_value_data = present_value->MutableData(); + auto* past_key_data = past_key->Data(); + auto* past_value_data = past_value->Data(); + + // No production use-case will incur this copy cost as the implementation of + // GreedySearch/BeamSearch is written in such a way that the past and present buffers + // will be shared. + // This is just to circumvent the OpTester's limitation of not being able to bind a specific + // buffer to inputs/outputs. + if (present_key_data != past_key_data) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(present_key_data, past_key_data, past_key->SizeInBytes(), + cudaMemcpyDeviceToDevice, cuda_stream)); + } + if (present_value_data != past_value_data) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(present_value_data, past_value_data, past_value->SizeInBytes(), + cudaMemcpyDeviceToDevice, cuda_stream)); + } + + parameters.is_cross_attention = false; + + parameters.k = const_cast(key->Data()); + parameters.v = const_cast(value->Data()); + parameters.k_cache = present_key_data; + parameters.v_cache = present_value_data; + } + + parameters.out = output->MutableDataRaw(); + + // Scale + // If the scale is not provided - use `1/sqrt(head_size)` + if (parameters.scale == 0.f) { + parameters.scale = 1.f / sqrtf(static_cast(parameters.head_size)); + } + + // Mask + if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + parameters.mask = mask_index->Data(); + } + + // Beam width (in case we are using this op inside BeamSearch) + if (beam_width != nullptr) { + parameters.beam_width = static_cast(*beam_width->Data()); + } + + // Cache indirection (in case we are using this op inside BeamSearch) + if (parameters.beam_width > 1) { + // If beam width > 1, then cache indirection buffer MUST be present + if (cache_indir == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "If beam width is greater than 1, then cache indirection buffer MUST be present"); + } + + parameters.cache_indir = cache_indir->Data(); + } + + switch (parameters.head_size) { + case 32: + mmha_launch_kernel(parameters, cuda_stream); + break; + + case 64: + mmha_launch_kernel(parameters, cuda_stream); + break; + + case 128: + mmha_launch_kernel(parameters, cuda_stream); + break; + + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "Unsupported head size in DecoderMaskedMultiHeadAttention. " + "Got head size: ", + parameters.head_size); + } + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.h b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.h new file mode 100644 index 000000000000..8200a66db383 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class DecoderMaskedMultiHeadAttention final : public CudaKernel { + public: + DecoderMaskedMultiHeadAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + float mask_filter_value_; + float scale_; + bool past_present_share_buffer_; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_self_attention.cc index bba764970322..98f4642e7903 100644 --- a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_self_attention.cc +++ b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_self_attention.cc @@ -51,7 +51,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont const Tensor* cache_indir = context->Input(kCacheIndirectionInputIndex); auto& device_prop = GetDeviceProp(); - DecoderMaskedSelfAttentionParams parameters; + DecoderMaskedMultiHeadAttentionParams parameters; ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), @@ -186,6 +186,10 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont } switch (parameters.head_size) { + case 32: + mmha_launch_kernel(parameters, cuda_stream); + break; + case 64: mmha_launch_kernel(parameters, cuda_stream); break; diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu index 4cf00e222b0e..3582758d1dab 100644 --- a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu +++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu @@ -40,7 +40,7 @@ using namespace decoder_masked_self_attention_details; <<>>(params) template -void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream) { +void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream) { constexpr int THREADS_PER_VALUE = ThreadsPerValue::value; int total_sequence_length = params.total_sequence_length; @@ -54,9 +54,9 @@ void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStre } // Instantiate templates -template void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu new file mode 100644 index 000000000000..3d295116252f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu @@ 
-0,0 +1,63 @@ +/* + * The implementation of this file is based on code provided by https://github.com/NVIDIA/FasterTransformer + * + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Modifications Copyright (c) Microsoft. +// Licensed under the MIT License. + +#include "decoder_masked_multihead_attention_impl.h" +#include "decoder_masked_multihead_attention_impl_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace decoder_masked_self_attention_details; + +#define MMHA_LAUNCH_KERNEL( \ + T, head_size, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK) \ + size_t dynamic_block_memory = CalcDynamicBlockMemory(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + masked_multihead_attention_kernel \ + <<>>(params) + +template +void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream) { + constexpr int THREADS_PER_VALUE = ThreadsPerValue::value; + int total_sequence_length = params.total_sequence_length; + + if (total_sequence_length < 32) { + MMHA_LAUNCH_KERNEL(T, head_size, 4, THREADS_PER_VALUE, 64); + } else if (total_sequence_length < 2048) { + MMHA_LAUNCH_KERNEL(T, head_size, 2, THREADS_PER_VALUE, 128); + } else { + MMHA_LAUNCH_KERNEL(T, head_size, 1, THREADS_PER_VALUE, 256); + } +} + +// Instantiate templates +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream); + +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu index 325681b0e1de..e5f57fac73cf 100644 --- a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu +++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu @@ -40,7 +40,7 @@ using namespace decoder_masked_self_attention_details; <<>>(params) template -void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream) { +void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream) { constexpr int THREADS_PER_VALUE = ThreadsPerValue::value; int total_sequence_length = params.total_sequence_length; @@ -54,9 +54,9 @@ void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStre } // Instantiate templates -template void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream); -template void 
mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 26bc6f53b4c2..ea4e10519993 100644 --- a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -46,7 +46,7 @@ template < int THREADS_PER_VALUE, // The number of threads in a threadblock. int THREADS_PER_BLOCK> -__global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params) { +__global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params) { // This kernel contains some code that cannot be compiled on CUDA ARCH 5.3 or lower #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 (void)(params); @@ -137,13 +137,15 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara float qk = 0.0F; - int qkv_base_offset = bi * (3 * params.hidden_size) + hi * head_size; + int qkv_base_offset = params.is_mha + ? bi * params.hidden_size + hi * head_size + : bi * (3 * params.hidden_size) + hi * head_size; const size_t bi_total_seq_length = bi * params.total_sequence_length; const size_t bi_max_seq_length = bi * params.max_sequence_length; - int tlength = params.past_sequence_length; + int tlength = params.is_cross_attention ? params.kv_sequence_length : params.past_sequence_length; // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. const bool is_masked = tidx >= QK_VECS_PER_WARP; @@ -151,9 +153,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara // The offset in the Q and K buffer also accounts for the batch. int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; - // The offset in the bias buffer. - int qk_bias_offset = hi * head_size + tidx * QK_VEC_SIZE; - // Trigger the loads from the Q and K buffers. Qk_vec_k q; zero(q); @@ -163,81 +162,99 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara } Qk_vec_k k; - zero(k); - if (!is_masked) { - k = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.k)[qk_offset])); + if (!params.is_cross_attention) { + zero(k); + + if (!is_masked) { + k = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.k)[qk_offset])); + } } // Trigger the loads from the Q and K bias buffers. Qk_vec_k q_bias; - zero(q_bias); + Qk_vec_k k_bias; + if (!params.is_mha) { + // The offset in the bias buffer. + int qk_bias_offset = hi * head_size + tidx * QK_VEC_SIZE; - if (!is_masked) { - q_bias = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.q_bias)[qk_bias_offset])); - } + zero(q_bias); - Qk_vec_k k_bias; - zero(k_bias); + if (!is_masked) { + q_bias = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.q_bias)[qk_bias_offset])); + } - if (!is_masked) { - k_bias = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.k_bias)[qk_bias_offset])); - } + zero(k_bias); - // Computes the Q/K values with bias. 
- q = add_vec(q, q_bias); - k = add_vec(k, k_bias); + if (!is_masked) { + k_bias = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.k_bias)[qk_bias_offset])); + } + + // Computes the Q/K values with bias. + q = add_vec(q, q_bias); + k = add_vec(k, k_bias); + } T* params_k_cache = reinterpret_cast(params.k_cache); + const float inv_sqrt_dh = params.scale; + if (!is_masked) { // Store the Q values to shared memory. *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; + } - // Write the K values to the global memory cache. - // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory - // system. We designed it this way as it allows much better memory loads (and there are many - // more loads) + the stores are really "write and forget" since we won't need the ack before - // the end of the kernel. There's plenty of time for the transactions to complete. + if (!params.is_cross_attention) { + if (!is_masked) { + // Write the K values to the global memory cache. + // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory + // system. We designed it this way as it allows much better memory loads (and there are many + // more loads) + the stores are really "write and forget" since we won't need the ack before + // the end of the kernel. There's plenty of time for the transactions to complete. - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi * params.max_sequence_length * head_size + co * params.max_sequence_length * QK_ELTS_IN_16B + - tlength * QK_ELTS_IN_16B + ci; + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int offset = bhi * params.max_sequence_length * head_size + co * params.max_sequence_length * QK_ELTS_IN_16B + + tlength * QK_ELTS_IN_16B + ci; - // Trigger the stores to global memory. - *reinterpret_cast(¶ms_k_cache[offset]) = vec_conversion(k); + // Trigger the stores to global memory. + *reinterpret_cast(¶ms_k_cache[offset]) = vec_conversion(k); - // Compute \sum_i Q[i] * K^T[i] for the current timestep. - using Qk_vec_acum = Qk_vec_k; - qk = dot(q, k); + // Compute \sum_i Q[i] * K^T[i] for the current timestep. + using Qk_vec_acum = Qk_vec_k; + qk = dot(q, k); - if (QK_VECS_PER_WARP <= WARP_SIZE) { -#pragma unroll - for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + if (QK_VECS_PER_WARP <= WARP_SIZE) { + #pragma unroll + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } } } - } - if (QK_VECS_PER_WARP > WARP_SIZE) { - constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; - qk = block_sum(&red_smem[WARPS_PER_RED], qk); - } - - const float inv_sqrt_dh = params.scale; + if (QK_VECS_PER_WARP > WARP_SIZE) { + constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + qk = block_sum(&red_smem[WARPS_PER_RED], qk); + } - // Store that value in shared memory. Keep the Q*K^T value in register for softmax. - if (tidx == 0) { - // Normalize qk. 
- qk *= inv_sqrt_dh; - qk_max = qk; - qk_smem[tlength] = qk; + // Store that value in shared memory. Keep the Q*K^T value in register for softmax. + if (tidx == 0) { + // Normalize qk. + qk *= inv_sqrt_dh; + if (params.relative_attention_bias != nullptr) { + qk = add_vec(qk, + reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length + * params.total_sequence_length + + tlength]); + } + qk_max = qk; + qk_smem[tlength] = qk; + } } // Make sure the data is in shared memory. @@ -332,6 +349,12 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara // Store the product to shared memory. There's one qk value per timestep. Update the max. if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + if (params.relative_attention_bias != nullptr) { + qk = add_vec(qk, + reinterpret_cast(params.relative_attention_bias)[hi * params.sequence_length + * params.total_sequence_length + + ti]); + } qk_max = fmaxf(qk_max, qk); qk_smem[ti] = qk; } @@ -370,7 +393,8 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara // Compute the logits and start the sum. float sum = 0.f; - for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + int sum_tlength = params.is_cross_attention ? tlength - 1 : tlength; + for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { // This is a deviation from FasterTransformer kernel implementation // but this aligns with ORT's other Attention kernels which strives to // mimic PyTorch when dealing with mask filter values @@ -384,7 +408,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara // Normalize the logits. float inv_sum = __fdividef(1.f, sum + 1.e-6f); - for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) { float logit = qk_smem[ti] * inv_sum; ConvertFromFloat(logits_smem[ti], logit); } @@ -418,12 +442,14 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara // One group of threads computes the product(s) for the current timestep. V_vec_k v_bias; - zero(v_bias); + if (!params.is_mha) { + zero(v_bias); - T* params_v_bias = reinterpret_cast(params.v_bias); + T* params_v_bias = reinterpret_cast(params.v_bias); - if (vo == tlength % V_PER_ITER) { - v_bias = vec_conversion(*reinterpret_cast(¶ms_v_bias[hi * head_size + vi])); + if (vo == tlength % V_PER_ITER) { + v_bias = vec_conversion(*reinterpret_cast(¶ms_v_bias[hi * head_size + vi])); + } } // From previous, before values, step @@ -451,12 +477,14 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara } // One group of threads computes the product(s) for the current timestep. - if (vo == tlength % V_PER_ITER) { + if (vo == tlength % V_PER_ITER && !params.is_cross_attention) { const auto v_offset = qkv_base_offset + vi; V_vec_k v; v = vec_conversion(*reinterpret_cast(&reinterpret_cast(params.v)[v_offset])); - v = add_vec(v, v_bias); + if (!params.is_mha) { + v = add_vec(v, v_bias); + } // Store the values with bias back to global memory in the cache for V. 
*reinterpret_cast(&v_cache[tlength * head_size]) = vec_conversion(v); @@ -497,33 +525,47 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara // Template instantiation(s) +// fp32 + head size = 32 +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); + +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); + +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); + +// fp16 + head size = 32 +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); + +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); + +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); + // fp32 + head size = 64 -template void __global__ masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); -template void __global__ masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); -template void __global__ masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); // fp16 + head size = 64 -template void __global__ masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); -template void __global__ masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); -template void __global__ masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); // fp32 + head size = 128 -template void __global__ masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); -template void __global__ masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); -template void __global__ masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); // fp16 + head size = 128 -template void __global__ masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); -template void __global__ masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); -template void __global__ masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params); +template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params); } // namespace cuda } // namespace 
diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h
index fe1e0cb70252..6501103ed067 100644
--- a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h
@@ -10,9 +10,13 @@ namespace onnxruntime {
 namespace contrib {
 namespace cuda {
 
-struct DecoderMaskedSelfAttentionParams : AttentionParameters {
+struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters {
   int beam_width = 1;
 
+  // Whether to use multihead attention (excludes MatMul and bias)
+  bool is_mha = false;
+  bool is_cross_attention = false;
+
   void* q = nullptr;
   void* q_bias = nullptr;
 
@@ -22,6 +26,8 @@ struct DecoderMaskedSelfAttentionParams : AttentionParameters {
   void* v = nullptr;
   void* v_bias = nullptr;
 
+  void* relative_attention_bias = nullptr;
+
   void* k_cache = nullptr;
   void* v_cache = nullptr;
 
@@ -43,10 +49,10 @@ template<
     int THREADS_PER_VALUE,
     // The number of threads in a threadblock.
     int THREADS_PER_BLOCK>
-__global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params);
+__global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params);
 
 template
-void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream);
+void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream);
 
diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h
index 5c8da6da7dd2..42d54e38d41e 100644
--- a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h
+++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h
@@ -741,7 +741,7 @@ inline __device__ void ConvertFromFloat(uint4& dst, Float8_ src) {
 //------------------------------------------------------------
 
 template
-inline size_t CalcDynamicBlockMemory(const DecoderMaskedSelfAttentionParams& params,
+inline size_t CalcDynamicBlockMemory(const DecoderMaskedMultiHeadAttentionParams& params,
                                      int threads_per_value, int threads_per_block) {
   // The amount of shared memory needed to store the Q*K^T values in float.
 
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 174dde63582f..f7e2b596f617 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -549,6 +549,121 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
           AttentionTypeAndShapeInference(ctx, past_input_index);
         }));
 
+constexpr const char* DecoderMaskedMultiHeadAttention_ver1_doc = R"DOC(
+Multihead attention that supports input sequence length of 1.
+Similar to DecoderMaskedSelfAttention but this op excludes QKV MatMul and Bias.
+This op supports both Self and Cross Attention.
+)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + DecoderMaskedMultiHeadAttention, 1, + OpSchema() + .SetDoc(DecoderMaskedMultiHeadAttention_ver1_doc) + .Attr("num_heads", "Number of attention heads", AttributeProto::INT) + .Attr("past_present_share_buffer", + "Corresponding past and present are same tensor, its size is " + "(batch_size, num_heads, max_sequence_length, head_size)", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("scale", + "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", + AttributeProto::FLOAT, + OPTIONAL_VALUE) + .Attr("mask_filter_value", + "The value to be filled in the attention mask. Default value is -10000.0f", + AttributeProto::FLOAT, + OPTIONAL_VALUE) + .Input(0, + "query", + "Query with shape (batch_size, 1, hidden_size)", + "T") + .Input(1, + "key", + "Key with shape (batch_size, 1, hidden_size) for self attention " + "or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention", + "T") + .Input(2, + "value", + "Value with shape (batch_size, 1, v_hidden_size) for self attention " + "or past_value with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention", + "T") + .Input(3, + "mask_index", + "Mask values of shape (batch_size, total_sequence_length) or (batch_size, kv_sequence_length)", + "M", + OpSchema::Optional) + .Input(4, + "relative_position_bias", + "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)", + "T", + OpSchema::Optional) + .Input(5, + "past_key", + "past state for key with shape (batch_size, num_heads, past_sequence_length, head_size) for self attention" + "When past_present_share_buffer is set, " + "its shape is (batch_size, num_heads, max_sequence_length, head_size). " + "The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape " + "(batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape " + "(batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to " + "become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.", + "T", + OpSchema::Optional) + .Input(6, + "past_value", + "past state for value with shape (batch_size, num_heads, past_sequence_length, head_size) for self attention" + "When past_present_share_buffer is set, " + "its shape is (batch_size, num_heads, max_sequence_length, head_size). ", + "T", + OpSchema::Optional) + .Input(7, + "past_sequence_length", + "When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0)." + "Cross Attention doesn't need this input.", + "M", + OpSchema::Optional) + .Input(8, + "beam_width", + "The beam width that is being used while decoding." + "If not provided, the beam width will be assumed to be 1.", + "M", + OpSchema::Optional) + .Input(9, + "cache_indirection", + "A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies" + "which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration", + "M", + OpSchema::Optional) + .Output(0, + "output", + "3D output tensor with shape (batch_size, sequence_length, v_hidden_size)", + "T") + .Output(1, + "present_key", + "past state for key with shape (batch_size, num_heads, total_sequence_length, head_size). 
" + "If past_present_share_buffer is set, " + "its shape is (batch_size, num_heads, max_sequence_length, head_size), " + "while effective_seq_length = (past_sequence_length + kv_sequence_length).", + "T", + OpSchema::Optional) + .Output(2, + "present_value", + "past state for value with shape (batch_size, num_heads, total_sequence_length, head_size). " + "If past_present_share_buffer is set, " + "its shape is (batch_size, num_heads, max_sequence_length, head_size), " + "while effective_seq_length = (past_sequence_length + kv_sequence_length).", + "T", + OpSchema::Optional) + .TypeConstraint("T", + {"tensor(float)", "tensor(float16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint("M", + {"tensor(int32)"}, + "Constrain mask index to integer types") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // TODO: + (void) (ctx); + })); + constexpr const char* MultiHeadAttention_ver1_doc = R"DOC( Multi-Head Self/Cross Attention. Bias from input projection is included. diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index 3066804577f9..772f93ab1b96 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -100,6 +100,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Unique); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WordConvEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFastGelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedSelfAttention); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedMultiHeadAttention); class OpSet_Microsoft_ver1 { public: @@ -194,6 +195,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); } }; } // namespace contrib diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc index 79f3435eaa19..44131eca2f88 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc @@ -43,6 +43,7 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, bool do_op_validation) const { ORT_UNUSED_PARAMETER(do_op_validation); const auto& inputs = node_unit.Inputs(); + ORT_RETURN_IF(inputs.size() != 2, "Gather should has 2 inputs at least!"); ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, is_quantized_model, input_names)); // Process indices @@ -53,54 +54,66 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } - Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; + std::string indices_input_name(input_name); Qnn_DataType_t qnn_data_type = QNN_DATATYPE_INT_32; const auto* type_proto = inputs[1].node_arg.TypeAsProto(); - ORT_RETURN_IF_ERROR(GetQnnDataType(is_quantized_model, type_proto, qnn_data_type)); + ORT_RETURN_IF_ERROR(GetQnnDataType(false, type_proto, qnn_data_type)); std::vector unpacked_tensor; std::vector gather_indices; bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); + + ORT_RETURN_IF(is_quantized_model && qnn_data_type == QNN_DATATYPE_INT_64 && !is_initializer_input, + "HTP backend doesn't support any int64 data type."); + if (is_initializer_input) { const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); 
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(*input_tensor, unpacked_tensor)); - } - - // For Quantized model, Gather indices use int32 without quantization - if (is_quantized_model) { if (qnn_data_type == QNN_DATATYPE_INT_64) { - if (!is_initializer_input) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Gather indices only support int32 type on Qnn NPU."); - } else { - // Convert initializer from int64 to int32 - size_t size = unpacked_tensor.size() / sizeof(int64_t); - const int64_t* gather_indices_int64 = reinterpret_cast(unpacked_tensor.data()); - gather_indices.resize(size * sizeof(int32_t)); - int32_t* gather_indices_int32 = reinterpret_cast(gather_indices.data()); - std::transform(gather_indices_int64, gather_indices_int64 + size, gather_indices_int32, - [](int64_t item) { return SafeInt(item); }); - qnn_data_type = QNN_DATATYPE_INT_32; - } + // Convert initializer from int64 to int32 + size_t size = unpacked_tensor.size() / sizeof(int64_t); + const int64_t* gather_indices_int64 = reinterpret_cast(unpacked_tensor.data()); + gather_indices.resize(size * sizeof(int32_t)); + int32_t* gather_indices_int32 = reinterpret_cast(gather_indices.data()); + std::transform(gather_indices_int64, gather_indices_int64 + size, gather_indices_int32, + [](int64_t item) { return SafeInt(item); }); } else { - qnn_data_type = QNN_DATATYPE_INT_32; gather_indices = std::move(unpacked_tensor); } - InitializeQuantizeParam(quantize_param, false); - ORT_RETURN_IF_NOT(qnn_model_wrapper.ProcessQuantizationParameter(inputs[1].quant_param, - quantize_param.scaleOffsetEncoding.scale, - quantize_param.scaleOffsetEncoding.offset), - "Cannot get quantization parameter"); - } else { - gather_indices = std::move(unpacked_tensor); + qnn_data_type = QNN_DATATYPE_INT_32; } + // Even for Quantized model, Gather indices use int32 without quantization + Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; + Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_name); std::vector input_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, input_shape), "Cannot get shape"); + std::vector cast_output_shape(input_shape); QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, quantize_param, std::move(input_shape), std::move(gather_indices)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); - input_names.push_back(input_name); + + if (!is_initializer_input && qnn_data_type == QNN_DATATYPE_INT_64) { + // Insert cast node int64 -> int32 + if (qnn_data_type == QNN_DATATYPE_INT_64) { + // Add Cast node for indices + indices_input_name = input_name + "_cast"; + QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32, quantize_param, + std::move(cast_output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_input_name, + qnn_def::package_name, + "Cast", + {input_name}, + {indices_input_name}, + {}, + do_op_validation), + "Failed to add node."); + } + } + + input_names.push_back(indices_input_name); return Status::OK(); } @@ -121,7 +134,6 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w qnn_model_wrapper.AddParamWrapper(std::move(axis_param)); // if indicies is scalar shape, then need to add Reshape node - ORT_ENFORCE(input_names.size() == 2, "Gather should has 2 inputs at least!"); const 
auto& input_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[0]); const auto& indices_input_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[1]); diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc index 52b41d52c166..d9c870a7dc52 100644 --- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc @@ -647,7 +647,8 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { int sequence_length = 1; int number_of_heads = 12; // Vary head_size / hidden_size - for (int hidden_size = 768; hidden_size <= 1536; hidden_size += 768) { + int hidden_sizes[3] = {384, 768, 1536}; + for (int hidden_size : hidden_sizes) { int head_size = (hidden_size / number_of_heads); int total_sequence_length = sequence_length + past_sequence_length; int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length @@ -760,7 +761,8 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { int number_of_heads = 12; // Vary head_size / hidden_size - for (int hidden_size = 768; hidden_size <= 1536; hidden_size += 768) { + int hidden_sizes[3] = {384, 768, 1536}; + for (int hidden_size : hidden_sizes) { int head_size = (hidden_size / number_of_heads); int total_sequence_length = sequence_length + past_sequence_length; int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 2b2c8074000a..bd8d7707902b 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -195,6 +195,80 @@ GetQDQTestCaseFn BuildQDQReduceOpTestCase(const std::string& reduce_op_type, con }; } +// Creates the following graph: +// _______________________ +// input (f32) -> Q -> DQ -> | | -> Q -> DQ -> output (f32) +// axes (int32, initializer) -> | Gather | +// |_______________________| +// +template +GetQDQTestCaseFn BuildQDQGatherOpTestCase(const std::vector& input_shape, + const std::vector indices, + const std::vector& indices_shape, + int64_t axis) { + return [input_shape, indices, indices_shape, axis](ModelTestBuilder& builder) { + + auto* input_data = builder.MakeInput(input_shape, -1.0f, 1.0f); + auto* final_output = builder.MakeOutput(); + + // input_data -> Q/DQ -> + auto* input_qdq_output = AddQDQNodePair(builder, input_data, .003f, 1); + + std::vector gather_op_inputs; + gather_op_inputs.push_back(input_qdq_output); + + auto* indices_input = builder.MakeInitializer(indices_shape, indices); + + auto* gather_output = builder.MakeIntermediate(); + Node& gather_node = builder.AddNode("Gather", {input_qdq_output, indices_input}, {gather_output}); + gather_node.AddAttribute("axis", axis); + + // -> Q/DQ -> final_output + auto* q_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(gather_output, .003f, 1, + q_output); + + builder.AddDequantizeLinearNode(q_output, .003f, 1, + final_output); + }; +} + +// Creates the following graph: +// _______________________ +// input (f32) -> Q -> DQ -> | | -> Q -> DQ -> output (f32) +// axes (int32, initializer) -> | Gather | +// |_______________________| +// +template +GetQDQTestCaseFn BuildQDQGatherOpScalarIndicesTestCase(const std::vector& input_shape, + const IndicesType indices, + int64_t axis) { + return [input_shape, indices, 
 axis](ModelTestBuilder& builder) {
+    auto* input_data = builder.MakeInput(input_shape, -1.0f, 1.0f);
+    auto* final_output = builder.MakeOutput();
+
+    // input_data -> Q/DQ ->
+    auto* input_qdq_output = AddQDQNodePair(builder, input_data, .003f, 1);
+
+    std::vector gather_op_inputs;
+    gather_op_inputs.push_back(input_qdq_output);
+
+    auto* indices_input = builder.MakeScalarInitializer(indices);
+
+    auto* gather_output = builder.MakeIntermediate();
+    Node& gather_node = builder.AddNode("Gather", {input_qdq_output, indices_input}, {gather_output});
+    gather_node.AddAttribute("axis", axis);
+
+    // -> Q/DQ -> final_output
+    auto* q_output = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode(gather_output, .003f, 1,
+                                  q_output);
+
+    builder.AddDequantizeLinearNode(q_output, .003f, 1,
+                                    final_output);
+  };
+}
+
 template
 GetQDQTestCaseFn BuildQDQConvTestCase(const std::vector& input_shape, const std::vector& weights_shape) {
   return [input_shape, weights_shape](ModelTestBuilder& builder) {
diff --git a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc
new file mode 100644
index 000000000000..09d98fab0e95
--- /dev/null
+++ b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc
@@ -0,0 +1,96 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if !defined(ORT_MINIMAL_BUILD)
+
+#include
+#include "core/graph/graph.h"
+
+#include "test/optimizer/qdq_test_utils.h"
+#include "test/providers/qnn/qnn_test_utils.h"
+
+#include "gtest/gtest.h"
+
+namespace onnxruntime {
+namespace test {
+#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+
+/**
+ * Runs a Gather op model on the QNN HTP backend. Checks the graph node assignment, and that inference
+ * outputs for QNN and CPU match.
+ *
+ * \param opset The opset version.
+ * \param test_description Description of the test for error reporting.
+ * \param scalar_indices True to test Gather with a scalar indices initializer.
+ * \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None)
+ */
+template
+static void RunGatherOpQDQTest(int opset, const char* test_description, bool scalar_indices = false,
+                               ExpectedEPNodeAssignment expected_ep_assignment = ExpectedEPNodeAssignment::All) {
+  ProviderOptions provider_options;
+#if defined(_WIN32)
+  provider_options["backend_path"] = "QnnHtp.dll";
+#else
+  provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+  constexpr int expected_nodes_in_partition = 1;
+  if (scalar_indices) {
+    RunQnnModelTest(BuildQDQGatherOpScalarIndicesTestCase({2, 3, 4},  // input shape
+                                                          1,          // indices
+                                                          1),         // axis
+                    provider_options,
+                    opset,
+                    expected_ep_assignment,
+                    expected_nodes_in_partition,
+                    test_description);
+  } else {
+    RunQnnModelTest(BuildQDQGatherOpTestCase({2, 3, 4},        // input shape
+                                             std::vector{1},   // indices
+                                             {1},              // indices_shape
+                                             1),               // axis
+                    provider_options,
+                    opset,
+                    expected_ep_assignment,
+                    expected_nodes_in_partition,
+                    test_description);
+  }
+}
+
+// Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all
+// nodes are supported by the QNN EP, and that the inference results match the CPU EP results.
+// +// - Uses int8 as the quantization type. +TEST_F(QnnHTPBackendTests, TestQDQGatherOpI8) { + RunGatherOpQDQTest(11, "TestQDQGatherOpI8"); +} + +// Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all +// nodes are supported by the QNN EP, and that the inference results match the CPU EP results. +// +// - Uses uint8 as the quantization type. +TEST_F(QnnHTPBackendTests, TestQDQGatherOpScalarIndicesU8) { + RunGatherOpQDQTest(11, "TestQDQGatherOpScalarIndicesU8", true); +} + +// Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all +// nodes are supported by the QNN EP, and that the inference results match the CPU EP results. +// +// - Uses int8 as the quantization type. +TEST_F(QnnHTPBackendTests, TestQDQGatherOpScalarIndicesI8) { + RunGatherOpQDQTest(11, "TestQDQGatherOpScalarIndicesI8", true); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime + +#endif \ No newline at end of file diff --git a/onnxruntime/test/python/transformers/test_parity_t5_mha.py b/onnxruntime/test/python/transformers/test_parity_t5_mha.py index 51d5ba7838d9..23218e494304 100644 --- a/onnxruntime/test/python/transformers/test_parity_t5_mha.py +++ b/onnxruntime/test/python/transformers/test_parity_t5_mha.py @@ -154,6 +154,113 @@ def create_t5_mha_graph( return model.SerializeToString() +# For decoder only (not decoder_init) starting from second iteration +def create_t5_decoder_masked_mha_graph( + batch_size, + past_sequence_length, + kv_sequence_length, + head_size, + num_heads, + is_cross_attention, +): + from onnx import TensorProto, helper + + nodes = [ + helper.make_node( + "DecoderMaskedMultiHeadAttention", + [ + "query", + "key", + "value", + "mask_index" if is_cross_attention else "", + "relative_position_bias" if not is_cross_attention else "", + "past_key" if not is_cross_attention else "", + "past_value" if not is_cross_attention else "", + "past_sequence_length" if not is_cross_attention else "", + ], + [ + "output", + "present_key" if not is_cross_attention else "", + "present_value" if not is_cross_attention else "", + ], + "DMMHA_0", + num_heads=num_heads, + mask_filter_value=-10000.0, + scale=1.0, + past_present_share_buffer=0 if is_cross_attention else 1, + domain="com.microsoft", + ), + ] + + initializers = [] + + hidden_size = head_size * num_heads + + graph_inputs = [ + helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch_size, 1, hidden_size]), + ] + + graph_outputs = [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, 1, hidden_size]), + ] + + if is_cross_attention: + graph_inputs.append( + helper.make_tensor_value_info("mask_index", TensorProto.INT32, [batch_size, kv_sequence_length]) + ) + graph_inputs.append( + helper.make_tensor_value_info( + "key", TensorProto.FLOAT, [batch_size, num_heads, kv_sequence_length, head_size] + ) + ) + graph_inputs.append( + helper.make_tensor_value_info( + "value", TensorProto.FLOAT, [batch_size, num_heads, kv_sequence_length, head_size] + ) + ) + else: + graph_inputs.append(helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch_size, 1, hidden_size])) + graph_inputs.append(helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, 1, hidden_size])) + graph_inputs.append( + helper.make_tensor_value_info( + "relative_position_bias", TensorProto.FLOAT, [1, num_heads, 1, past_sequence_length + 1] + ) + ) + # use past_sequence_length + 1 to simulate max_sequence_length + graph_inputs.append( + 
helper.make_tensor_value_info( + "past_key", TensorProto.FLOAT, [batch_size, num_heads, past_sequence_length + 1, head_size] + ) + ) + graph_inputs.append( + helper.make_tensor_value_info( + "past_value", TensorProto.FLOAT, [batch_size, num_heads, past_sequence_length + 1, head_size] + ) + ) + graph_inputs.append(helper.make_tensor_value_info("past_sequence_length", TensorProto.INT32, [1])) + graph_outputs.append( + helper.make_tensor_value_info( + "present_key", TensorProto.FLOAT, [batch_size, num_heads, past_sequence_length + 1, head_size] + ) + ) + graph_outputs.append( + helper.make_tensor_value_info( + "present_value", TensorProto.FLOAT, [batch_size, num_heads, past_sequence_length + 1, head_size] + ) + ) + + graph = helper.make_graph( + nodes, + "T5_DMMHA_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + class T5Config: def __init__(self, is_decoder, batch_size, seq_len, kv_sequence_length, num_heads, head_size, use_past): self.is_decoder = is_decoder @@ -173,7 +280,7 @@ def __init__(self, is_decoder, batch_size, seq_len, kv_sequence_length, num_head class T5Attention(nn.Module): - def __init__(self, config: T5Config, is_static_kv): + def __init__(self, config: T5Config, is_static_kv, use_decoder_masked_kernel: bool = False): super().__init__() self.is_decoder = config.is_decoder self.is_static_kv = is_static_kv @@ -199,17 +306,52 @@ def __init__(self, config: T5Config, is_static_kv): self.num_heads = config.num_heads self.hidden_size = self.d_model self.use_past = config.use_past + self.use_decoder_masked_kernel = use_decoder_masked_kernel # Create onnx graph - self.onnx_graph = create_t5_mha_graph( - self.batch_size, - self.seq_len, - self.kv_sequence_length, - self.head_size, - self.num_heads, - self.use_past, - is_static_kv, - ) + if self.use_decoder_masked_kernel: + self.onnx_graph = create_t5_decoder_masked_mha_graph( + self.batch_size, + self.kv_sequence_length, + self.kv_sequence_length, + self.head_size, + self.num_heads, + is_static_kv, + ) + else: + self.onnx_graph = create_t5_mha_graph( + self.batch_size, + self.seq_len, + self.kv_sequence_length, + self.head_size, + self.num_heads, + self.use_past, + is_static_kv, + ) + + # Reorder 'K' from [B, N, S, H] to [B, N, H/4, S, 4] + def reorder_key_cache(self, key_cache, batch_size, num_heads, sequence_length, head_size, max_sequence_length): + ordered = np.zeros_like(key_cache) + + # assume float + num_inner_elements = 4 + chunks = int(head_size / num_inner_elements) + + for b in range(batch_size): + for h in range(num_heads): + for c in range(chunks): + for s in range(sequence_length): + base_offset = (b * num_heads * max_sequence_length * head_size) + ( + h * max_sequence_length * head_size + ) + input_base_offset = base_offset + (s * head_size) + (c * num_inner_elements) + output_base_offset = ( + base_offset + (c * max_sequence_length * num_inner_elements) + (s * num_inner_elements) + ) + for e in range(num_inner_elements): + ordered[output_base_offset + e] = key_cache[input_base_offset + e] + + return ordered def create_inputs(self): hidden_states = torch.normal(mean=0.5, std=0.1, size=(self.batch_size, self.seq_len, self.hidden_size)).to( @@ -230,6 +372,10 @@ def create_inputs(self): position_bias = torch.normal( mean=0.5, std=0.1, size=(1, self.num_heads, position_bias_length, position_bias_length) ).to(torch.float32) + if self.use_decoder_masked_kernel: + position_bias = torch.normal(mean=5, std=0.1, size=(1, self.num_heads, 1, 
position_bias_length)).to( + torch.float32 + ) return hidden_states, key_value_states, past_key_value, attention_mask, position_bias def torch_forward( @@ -302,6 +448,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): key_states = project( hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None ) + value_states = project( hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None ) @@ -421,16 +568,57 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): ort_inputs = { "query": np.ascontiguousarray(query_states.detach().numpy()), } + torch_past_key = np.ascontiguousarray(torch_past_key.detach().numpy()) + torch_past_value = np.ascontiguousarray(torch_past_value.detach().numpy()) + max_seq_len = torch_past_key.shape[2] + 1 + torch_past_key_padded = np.zeros( + [torch_past_key.shape[0], torch_past_key.shape[1], max_seq_len, torch_past_key.shape[3]], + dtype=np.float32, + ) + torch_past_value_padded = np.zeros( + [torch_past_value.shape[0], torch_past_value.shape[1], max_seq_len, torch_past_value.shape[3]], + dtype=np.float32, + ) + torch_past_key_padded[:, :, : torch_past_key.shape[2], :] = torch_past_key + torch_past_value_padded[:, :, : torch_past_value.shape[2], :] = torch_past_value if self.is_static_kv: - ort_inputs["key"] = np.ascontiguousarray(torch_past_key.detach().numpy()) - ort_inputs["value"] = np.ascontiguousarray(torch_past_value.detach().numpy()) + if self.use_decoder_masked_kernel: + reordered_past_key = self.reorder_key_cache( + torch_past_key.flatten(), + batch_size=batch_size, + num_heads=self.num_heads, + sequence_length=self.kv_sequence_length, + head_size=self.head_size, + max_sequence_length=self.kv_sequence_length, + ) + ort_inputs["key"] = reordered_past_key.reshape(torch_past_key.shape) + ort_inputs["value"] = torch_past_value + else: + ort_inputs["key"] = np.ascontiguousarray(torch_past_key) + ort_inputs["value"] = np.ascontiguousarray(torch_past_value) else: - ort_inputs["past_key"] = np.ascontiguousarray(torch_past_key.detach().numpy()) - ort_inputs["past_value"] = np.ascontiguousarray(torch_past_value.detach().numpy()) ort_inputs["key"] = np.ascontiguousarray(key_states.detach().numpy()) ort_inputs["value"] = np.ascontiguousarray(value_states.detach().numpy()) + if self.use_decoder_masked_kernel: + reordered_past_key = self.reorder_key_cache( + torch_past_key_padded.flatten(), + batch_size=batch_size, + num_heads=self.num_heads, + sequence_length=self.kv_sequence_length, + head_size=self.head_size, + max_sequence_length=max_seq_len, + ) + ort_inputs["past_key"] = reordered_past_key.reshape(torch_past_value_padded.shape) + ort_inputs["past_value"] = torch_past_value_padded + ort_inputs["past_sequence_length"] = np.array([self.kv_sequence_length], dtype=np.int32) + else: + ort_inputs["past_key"] = torch_past_key + ort_inputs["past_value"] = torch_past_value if torch_key_padding_mask is not None: - ort_inputs["key_padding_mask"] = np.ascontiguousarray(torch_key_padding_mask.detach().numpy()) + if self.use_decoder_masked_kernel: + ort_inputs["mask_index"] = np.ascontiguousarray(torch_key_padding_mask.detach().numpy()) + else: + ort_inputs["key_padding_mask"] = np.ascontiguousarray(torch_key_padding_mask.detach().numpy()) if torch_position_bias is not None: ort_inputs["relative_position_bias"] = np.ascontiguousarray(torch_position_bias.detach().numpy()) @@ -445,7 +633,7 @@ def project(hidden_states, proj_layer, key_value_states, 
past_key_value): return output -def compare_t5_cross_attention_decoder(batch_size, seq_len, num_heads, head_size, kv_sequence_length): +def compare_t5_cross_attention_decoder(batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=False): config = T5Config( is_decoder=True, batch_size=batch_size, @@ -455,7 +643,8 @@ def compare_t5_cross_attention_decoder(batch_size, seq_len, num_heads, head_size head_size=head_size, use_past=True, ) - T5CrossAttention = T5Attention(config, is_static_kv=True) # noqa: N806 + + T5CrossAttention = T5Attention(config, is_static_kv=True, use_decoder_masked_kernel=use_dmmha) # noqa: N806 hidden_states, key_value_states, past_key_value, attention_mask, _ = T5CrossAttention.create_inputs() torch_output = T5CrossAttention.torch_forward( @@ -521,7 +710,7 @@ def compare_t5_self_attention_decoder_init(batch_size, seq_len, num_heads, head_ assert torch.allclose(torch_output[1][1], ort_output[1][1], atol=1e-4) -def compare_t5_self_attention_decoder(batch_size, seq_len, num_heads, head_size, kv_sequence_length): +def compare_t5_self_attention_decoder(batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=False): config = T5Config( is_decoder=True, batch_size=batch_size, @@ -531,7 +720,8 @@ def compare_t5_self_attention_decoder(batch_size, seq_len, num_heads, head_size, head_size=head_size, use_past=True, ) - T5CrossAttention = T5Attention(config, is_static_kv=False) # noqa: N806 + + T5CrossAttention = T5Attention(config, is_static_kv=False, use_decoder_masked_kernel=use_dmmha) # noqa: N806 hidden_states, _, past_key_value, _, position_bias = T5CrossAttention.create_inputs() torch_output = T5CrossAttention.torch_forward( @@ -543,8 +733,9 @@ def compare_t5_self_attention_decoder(batch_size, seq_len, num_heads, head_size, if ort_output is not None: assert torch.allclose(torch_output[0], ort_output[0], atol=1e-4) - assert torch.allclose(torch_output[1][0], ort_output[1][0], atol=1e-4) - assert torch.allclose(torch_output[1][1], ort_output[1][1], atol=1e-4) + if not use_dmmha: + assert torch.allclose(torch_output[1][0], ort_output[1][0], atol=1e-4) + assert torch.allclose(torch_output[1][1], ort_output[1][1], atol=1e-4) class TestT5MHAParity(unittest.TestCase): @@ -575,6 +766,24 @@ def test_t5_self_attention_decoder(self): self.batch_size, self.seq_len, self.num_heads, self.head_size, self.kv_sequence_length ) + def test_t5_cross_attention_decoder_masked_mha(self): + batch_size = 2 + seq_len = 1 + num_heads = 2 + head_size = 32 + kv_sequence_length = 2 + compare_t5_cross_attention_decoder( + batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=True + ) + + def test_t5_self_attention_decoder_masked_mha(self): + batch_size = 2 + seq_len = 1 + num_heads = 2 + head_size = 32 + kv_sequence_length = 2 + compare_t5_self_attention_decoder(batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=True) + if __name__ == "__main__": unittest.main() diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index 77ec91824d1e..1b598a405ed7 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -190,23 +190,3 @@ stages: WITH_CACHE: true MachinePool: 'onnxruntime-Win2019-CPU-training' -- stage: x64_release_azure - dependsOn: [] - jobs: - - template: templates/win-ci-vs-2019.yml - parameters: - BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_azure.bat - 
buildArch: x64 - additionalBuildFlags: --use_azure - msbuildPlatform: x64 - isX86: false - job_name_suffix: x64_release_azure - RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false - EnablePython: false - isTraining: false - ORT_EP_NAME: CPU - GenerateDocumentation: false - WITH_CACHE: true - MachinePool: 'Win-CPU-2019'
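For reference, the DecoderMaskedMultiHeadAttention schema registered in bert_defs.cc can be exercised with a small ONNX graph much like the one built in create_t5_decoder_masked_mha_graph above. A minimal self-attention sketch with the onnx helper API; the shapes, opset versions, and dimension values are assumptions chosen to match the schema's self-attention path, not taken from the tests:

from onnx import TensorProto, helper

B, num_heads, head_size = 2, 2, 32
past_seq_len = 4
max_seq_len = past_seq_len + 1          # past buffers padded to max_sequence_length
hidden = num_heads * head_size

node = helper.make_node(
    "DecoderMaskedMultiHeadAttention",
    inputs=["query", "key", "value", "",           # mask_index omitted
            "relative_position_bias", "past_key", "past_value", "past_sequence_length"],
    outputs=["output", "present_key", "present_value"],
    name="dmmha",
    num_heads=num_heads,
    past_present_share_buffer=1,
    domain="com.microsoft",
)

def f32(name, shape):
    return helper.make_tensor_value_info(name, TensorProto.FLOAT, shape)

graph = helper.make_graph(
    [node],
    "dmmha_self_attention",
    inputs=[
        f32("query", [B, 1, hidden]),
        f32("key", [B, 1, hidden]),
        f32("value", [B, 1, hidden]),
        f32("relative_position_bias", [1, num_heads, 1, max_seq_len]),
        # With past_present_share_buffer=1, past_key is expected in the reordered
        # (batch, num_heads, head_size/x, max_seq_len, x) layout described by the schema.
        f32("past_key", [B, num_heads, max_seq_len, head_size]),
        f32("past_value", [B, num_heads, max_seq_len, head_size]),
        helper.make_tensor_value_info("past_sequence_length", TensorProto.INT32, [1]),
    ],
    outputs=[
        f32("output", [B, 1, hidden]),
        f32("present_key", [B, num_heads, max_seq_len, head_size]),
        f32("present_value", [B, num_heads, max_seq_len, head_size]),
    ],
)

model = helper.make_model(
    graph, opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)]
)

The parity test above drives an equivalent graph through onnxruntime and compares the output against the PyTorch reference implementation.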