From fbfe92f66af3227b87b3f6a981577225e10abb6b Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Sun, 2 Apr 2023 21:53:03 -0700
Subject: [PATCH 1/4] DecoderMaskedMultiHeadAttention enhancement (#15292)
---
cmake/onnxruntime_rocm_hipify.cmake | 3 +
docs/ContribOperators.md | 70 +++++
docs/OperatorKernels.md | 1 +
.../cpu/bert/multihead_attention_helper.h | 15 +-
.../cuda/bert/multihead_attention.cc | 2 +
.../contrib_ops/cuda/cuda_contrib_kernels.cc | 4 +
.../decoder_masked_multihead_attention.cc | 223 ++++++++++++++++
.../decoder_masked_multihead_attention.h | 29 ++
.../decoder/decoder_masked_self_attention.cc | 6 +-
.../decoder_masked_multihead_attention_128.cu | 6 +-
.../decoder_masked_multihead_attention_32.cu | 63 +++++
.../decoder_masked_multihead_attention_64.cu | 6 +-
...decoder_masked_multihead_attention_impl.cu | 192 ++++++++------
.../decoder_masked_multihead_attention_impl.h | 12 +-
...er_masked_multihead_attention_impl_utils.h | 2 +-
.../core/graph/contrib_ops/bert_defs.cc | 115 ++++++++
onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 +
...oder_masked_multihead_attention_op_test.cc | 6 +-
.../python/transformers/test_parity_t5_mha.py | 251 ++++++++++++++++--
19 files changed, 897 insertions(+), 111 deletions(-)
create mode 100644 onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc
create mode 100644 onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.h
create mode 100644 onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu
diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index ebd6229204bb..7bf641d15154 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -90,11 +90,14 @@ set(contrib_ops_excluded_files
"cuda_contrib_kernels.h"
"inverse.cc"
"fused_conv.cc"
+ "decoder/decoder_masked_multihead_attention.h"
+ "decoder/decoder_masked_multihead_attention.cc"
"decoder/decoder_masked_self_attention.h"
"decoder/decoder_masked_self_attention.cc"
"decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h"
"decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention.h"
"decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu"
+ "decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu"
"decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu"
"decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu"
)
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 824df7728207..cb7823f06b4c 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -20,6 +20,7 @@ Do not modify directly.*
* com.microsoft.ConvTransposeWithDynamicPads
* com.microsoft.CropAndResize
* com.microsoft.DecoderAttention
+ * com.microsoft.DecoderMaskedMultiHeadAttention
* com.microsoft.DecoderMaskedSelfAttention
* com.microsoft.DequantizeBFP
* com.microsoft.DequantizeLinear
@@ -1102,6 +1103,75 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.DecoderMaskedMultiHeadAttention**
+
+ Multihead attention that supports input sequence length of 1.
+ Similar to DecoderMaskedSelfAttention but this op excludes QKV MatMul and Bias.
+ This op supports both Self and Cross Attention.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- mask_filter_value : float
+- The value to be filled in the attention mask. Default value is -10000.0f
+- num_heads : int (required)
+- Number of attention heads
+- past_present_share_buffer : int
+- Corresponding past and present are the same tensor; its size is (batch_size, num_heads, max_sequence_length, head_size)
+- scale : float
+- Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+
+
+#### Inputs (3 - 10)
+
+
+- query : T
+- Query with shape (batch_size, 1, hidden_size)
+- key : T
+- Key with shape (batch_size, 1, hidden_size) for self attention or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention
+- value : T
+- Value with shape (batch_size, 1, v_hidden_size) for self attention or past_value with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention
+- mask_index (optional) : M
+- Mask values of shape (batch_size, total_sequence_length) or (batch_size, kv_sequence_length)
+- relative_position_bias (optional) : T
+- Additional bias added to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
+- past_key (optional) : T
+- past state for key with shape (batch_size, num_heads, past_sequence_length, head_size) for self attention. When past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size). The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size), which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x), is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x), where `x = 16 / sizeof(T)`.
+- past_value (optional) : T
+- past state for value with shape (batch_size, num_heads, past_sequence_length, head_size) for self attention. When past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size).
+- past_sequence_length (optional) : M
+- When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0). Cross Attention doesn't need this input.
+- beam_width (optional) : M
+- The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.
+- cache_indirection (optional) : M
+- A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies which beam the 'k'-th token came from for the 'j'-th beam of batch 'i' in the current iteration
+
+
+#### Outputs (1 - 3)
+
+
+- output : T
+- 3D output tensor with shape (batch_size, sequence_length, v_hidden_size)
+- present_key (optional) : T
+- present state for key with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), where effective_seq_length = (past_sequence_length + kv_sequence_length).
+- present_value (optional) : T
+- present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), where effective_seq_length = (past_sequence_length + kv_sequence_length).
+
+
+#### Type Constraints
+
+
+- T : tensor(float), tensor(float16)
+- Constrain input and output types to float tensors.
+- M : tensor(int32)
+- Constrain mask index to integer types
+
+
+
### **com.microsoft.DecoderMaskedSelfAttention**
Self attention that supports input sequence length of 1.
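The `past_present_share_buffer` mode documented above keeps a single (batch_size, num_heads, max_sequence_length, head_size) KV buffer alive across decoding steps and only advances the `past_sequence_length` input. A minimal sketch of that contract, assuming a hypothetical `run_step` callable standing in for one DecoderMaskedMultiHeadAttention invocation (e.g. an InferenceSession.run wrapper; not a real ORT API):

```python
import numpy as np

def decode_with_shared_kv(run_step, num_steps, batch, heads, max_len, head_size):
    # Allocate the KV cache once, sized for max_sequence_length.
    past_key = np.zeros((batch, heads, max_len, head_size), dtype=np.float32)
    past_value = np.zeros_like(past_key)
    for past_len in range(num_steps):
        # Each step feeds exactly one new token (sequence length 1).
        q = np.random.rand(batch, 1, heads * head_size).astype(np.float32)
        k = np.random.rand(batch, 1, heads * head_size).astype(np.float32)
        v = np.random.rand(batch, 1, heads * head_size).astype(np.float32)
        # past_sequence_length is a one-element int32 tensor and may be 0.
        run_step(q, k, v, past_key, past_value, np.array([past_len], dtype=np.int32))
        # With past_present_share_buffer=1, present_key/present_value alias
        # past_key/past_value, so nothing needs to be copied between steps.
```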
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index d4cd2f936389..050d84b19cc9 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -798,6 +798,7 @@ Do not modify directly.*
|ComplexMulConj|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|DecoderAttention|*in* query:**T**<br> *in* key:**T**<br> *in* q_weight:**T**<br> *in* kv_weight:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**B**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* static_kv:**B**<br> *in* use_past:**B**<br> *in* has_layer_state:**B**<br> *in* has_key_padding_mask:**B**<br> *out* output:**T**<br> *out* new_key_cache:**T**<br> *out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)|
+|DecoderMaskedMultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* mask_index:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|DecoderMaskedSelfAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(float16)|
|DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br> **Q** = tensor(int8)<br> **S** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index cc7dad81b4dc..fe1c57e5711f 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -20,10 +20,12 @@ Status CheckInputs(const T* query,
const T* relative_position_bias,
const T* past_key,
const T* past_value,
+ const T* past_seq_len,
void* parameters,
int num_heads,
float mask_filter_value,
float scale,
+ bool past_present_share_buffer,
int max_threads_per_block) {
// key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None
// relative_position_bias : (B, 1, S, L)
@@ -59,6 +61,7 @@ Status CheckInputs(const T* query,
int kv_sequence_length = sequence_length;
int past_sequence_length = 0;
+ int max_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
const auto& past_key_dims = past_key->Shape().GetDims();
const auto& past_value_dims = past_value->Shape().GetDims();
@@ -110,6 +113,14 @@ Status CheckInputs(const T* query,
past_value_dims[3]);
}
past_sequence_length = static_cast<int>(past_key_dims[2]);
+ max_sequence_length = static_cast<int>(past_key_dims[2]);
+ if (past_present_share_buffer) {
+ if (past_seq_len == nullptr || !onnxruntime::IsScalarOr1ElementVector(past_seq_len)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "past_sequence_length tensor must be of one element when past_present_share_buffer is set");
+ }
+ past_sequence_length = *((*past_seq_len).template Data<int32_t>());
+ }
} else if (past_key != nullptr || past_value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall be both present or both absent");
@@ -277,7 +288,7 @@ Status CheckInputs(const T* query,
output_parameters->past_sequence_length = past_sequence_length;
output_parameters->kv_sequence_length = kv_sequence_length;
output_parameters->total_sequence_length = total_sequence_length;
- output_parameters->max_sequence_length = 0;
+ output_parameters->max_sequence_length = max_sequence_length;
output_parameters->input_hidden_size = 0;
output_parameters->hidden_size = hidden_size;
output_parameters->v_hidden_size = v_hidden_size;
@@ -285,7 +296,7 @@ Status CheckInputs(const T* query,
output_parameters->v_head_size = v_hidden_size / num_heads;
output_parameters->num_heads = num_heads;
output_parameters->is_unidirectional = false;
- output_parameters->past_present_share_buffer = false;
+ output_parameters->past_present_share_buffer = past_present_share_buffer;
output_parameters->mask_filter_value = mask_filter_value;
output_parameters->mask_type = mask_type;
output_parameters->scale = scale;
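As context for the hunk above: when `past_present_share_buffer` is set, the past buffer's third dimension gives `max_sequence_length`, while the effective past length arrives in the one-element `past_seq_len` tensor. A Python sketch of that validation logic (an illustration of the C++ above, not an ORT API):

```python
import numpy as np

def resolve_past_lengths(past_key, past_value, past_seq_len, past_present_share_buffer):
    assert past_key.shape == past_value.shape
    # Dim 2 of (batch, num_heads, S, head_size) is the buffer's sequence capacity.
    max_sequence_length = past_key.shape[2]
    past_sequence_length = max_sequence_length
    if past_present_share_buffer:
        # The buffer is sized for the whole generation; the real past length
        # must arrive as a scalar / one-element tensor (and may be 0).
        if past_seq_len is None or np.asarray(past_seq_len).size != 1:
            raise ValueError("past_sequence_length tensor must be of one element "
                             "when past_present_share_buffer is set")
        past_sequence_length = int(np.asarray(past_seq_len).reshape(-1)[0])
    return past_sequence_length, max_sequence_length
```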
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index f077d56f03b7..7c4b65b11372 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -88,10 +88,12 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
relative_position_bias,
past_key,
past_value,
+ nullptr, // past_seq_len
&parameters,
num_heads_,
mask_filter_value_,
scale_,
+ false, // past_present_share_buffer
device_prop.maxThreadsPerBlock));
int sequence_length = parameters.sequence_length;
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 7a14267176eb..0b800c78bc2d 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -133,6 +133,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrd
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention);
#ifdef ENABLE_ATEN
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen);
@@ -279,6 +281,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention)>,
#ifdef ENABLE_ATEN
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc
new file mode 100644
index 000000000000..6130fd9eeb48
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc
@@ -0,0 +1,223 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "core/platform/env_var_utils.h"
+#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
+#include "contrib_ops/cuda/decoder/decoder_masked_multihead_attention.h"
+#include "contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h"
+
+using namespace onnxruntime::cuda;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+// TODO: refactor
+static constexpr int kPastSequenceLengthInputIndex = 7;
+static constexpr int kBeamWidthInputIndex = 8;
+static constexpr int kCacheIndirectionInputIndex = 9;
+static constexpr int kPastInputIndex = 5;
+static constexpr int kPresentOutputIndex = 1;
+
+#define REGISTER_KERNEL_TYPED(T1, T2) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ DecoderMaskedMultiHeadAttention, \
+ kMSDomain, \
+ 1, \
+ T1, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .MayInplace(kPastInputIndex, kPresentOutputIndex) \
+ .MayInplace(kPastInputIndex + 1, kPresentOutputIndex + 1) \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \
+ .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \
+ .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \
+ DecoderMaskedMultiHeadAttention);
+
+REGISTER_KERNEL_TYPED(float, float)
+REGISTER_KERNEL_TYPED(MLFloat16, uint16_t)
+
+template <typename T1, typename T2>
+DecoderMaskedMultiHeadAttention<T1, T2>::DecoderMaskedMultiHeadAttention(const OpKernelInfo& info) : CudaKernel(info) {
+ int64_t num_heads = 0;
+ ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
+ num_heads_ = static_cast<int>(num_heads);
+ mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
+ scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
+ past_present_share_buffer_ = info.GetAttrOrDefault<int64_t>("past_present_share_buffer", 0LL);
+}
+
+template <typename T1, typename T2>
+Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext* context) const {
+ const Tensor* query = context->Input<Tensor>(0);
+ const Tensor* key = context->Input<Tensor>(1);
+ const Tensor* value = context->Input<Tensor>(2);
+ const Tensor* mask_index = context->Input<Tensor>(3);
+ const Tensor* relative_position_bias = context->Input<Tensor>(4);
+ const Tensor* past_key = context->Input<Tensor>(kPastInputIndex);
+ const Tensor* past_value = context->Input<Tensor>(kPastInputIndex + 1);
+ const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
+ const Tensor* beam_width = context->Input<Tensor>(kBeamWidthInputIndex);
+ const Tensor* cache_indir = context->Input<Tensor>(kCacheIndirectionInputIndex);
+
+ auto& device_prop = GetDeviceProp();
+ DecoderMaskedMultiHeadAttentionParams parameters;
+ ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query,
+ key,
+ value,
+ nullptr,  // bias
+ mask_index,
+ relative_position_bias,
+ past_key,
+ past_value,
+ past_seq_len,
+ &parameters,
+ num_heads_,
+ mask_filter_value_,
+ scale_,
+ past_present_share_buffer_,
+ device_prop.maxThreadsPerBlock));
+
+ int batch_size = parameters.batch_size;
+ int sequence_length = parameters.sequence_length;
+
+ // This kernel is for decoding only (i.e., the sequence length has to be 1)
+ if (sequence_length != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input sequence length should be 1 to use DecoderMaskedMultiHeadAttention");
+ }
+
+ if (parameters.head_size != parameters.v_head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "QK head size should be same as V head size to use DecoderMaskedMultiHeadAttention");
+ }
+
+ if (parameters.mask_type != AttentionMaskType::MASK_2D_KEY_PADDING &&
+ parameters.mask_type != AttentionMaskType::MASK_NONE) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
+ "DecoderMaskedMultiHeadAttention only supports no mask or 2D key "
+ "padding mask of shape [batch, total_seq_length] currently");
+ }
+
+ TensorShapeVector output_shape(3);
+ output_shape[0] = static_cast<int64_t>(batch_size);
+ output_shape[1] = static_cast<int64_t>(sequence_length);
+ output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
+ Tensor* output = context->Output(0, output_shape);
+
+ // Present input will have the same shape as the past input
+ Tensor* present_key = context->Output(kPresentOutputIndex, past_key->Shape());
+ Tensor* present_value = context->Output(kPresentOutputIndex + 1, past_value->Shape());
+
+ auto cuda_stream = Stream(context);
+
+ parameters.is_mha = true;
+
+ // Update the q buffers
+ parameters.q = const_cast<T1*>(query->Data<T1>());
+
+ // Update the relative position bias for self attention
+ if (relative_position_bias != nullptr) {
+ parameters.relative_attention_bias = const_cast<T1*>(relative_position_bias->Data<T1>());
+ }
+
+ // Decoder cross-attention
+ if (past_key == nullptr && present_key == nullptr) {
+ if (relative_position_bias != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
+ "DecoderMaskedMultiHeadAttention does not support relative position bias for cross-attention");
+ }
+
+ parameters.is_cross_attention = true;
+ parameters.total_sequence_length = parameters.kv_sequence_length;
+ parameters.max_sequence_length = parameters.kv_sequence_length;
+ // parameters.k and parameters.v are nullptr
+ parameters.k_cache = const_cast<T1*>(key->Data<T1>());
+ parameters.v_cache = const_cast<T1*>(value->Data<T1>());
+ } else {
+ // Sanity check
+ ORT_ENFORCE(past_present_share_buffer_);
+
+ auto* present_key_data = present_key->MutableData<T1>();
+ auto* present_value_data = present_value->MutableData<T1>();
+ auto* past_key_data = past_key->Data<T1>();
+ auto* past_value_data = past_value->Data<T1>();
+
+ // No production use-case will incur this copy cost as the implementation of
+ // GreedySearch/BeamSearch is written in such a way that the past and present buffers
+ // will be shared.
+ // This is just to circumvent the OpTester's limitation of not being able to bind a specific
+ // buffer to inputs/outputs.
+ if (present_key_data != past_key_data) {
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(present_key_data, past_key_data, past_key->SizeInBytes(),
+ cudaMemcpyDeviceToDevice, cuda_stream));
+ }
+ if (present_value_data != past_value_data) {
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(present_value_data, past_value_data, past_value->SizeInBytes(),
+ cudaMemcpyDeviceToDevice, cuda_stream));
+ }
+
+ parameters.is_cross_attention = false;
+
+ parameters.k = const_cast<T1*>(key->Data<T1>());
+ parameters.v = const_cast<T1*>(value->Data<T1>());
+ parameters.k_cache = present_key_data;
+ parameters.v_cache = present_value_data;
+ }
+
+ parameters.out = output->MutableDataRaw();
+
+ // Scale
+ // If the scale is not provided - use `1/sqrt(head_size)`
+ if (parameters.scale == 0.f) {
+ parameters.scale = 1.f / sqrtf(static_cast<float>(parameters.head_size));
+ }
+
+ // Mask
+ if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) {
+ parameters.mask = mask_index->Data<int32_t>();
+ }
+
+ // Beam width (in case we are using this op inside BeamSearch)
+ if (beam_width != nullptr) {
+ parameters.beam_width = static_cast<int>(*beam_width->Data<int32_t>());
+ }
+
+ // Cache indirection (in case we are using this op inside BeamSearch)
+ if (parameters.beam_width > 1) {
+ // If beam width > 1, then cache indirection buffer MUST be present
+ if (cache_indir == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "If beam width is greater than 1, then cache indirection buffer MUST be present");
+ }
+
+ parameters.cache_indir = cache_indir->Data<int32_t>();
+ }
+
+ switch (parameters.head_size) {
+ case 32:
+ mmha_launch_kernel<T2, 32>(parameters, cuda_stream);
+ break;
+
+ case 64:
+ mmha_launch_kernel<T2, 64>(parameters, cuda_stream);
+ break;
+
+ case 128:
+ mmha_launch_kernel<T2, 128>(parameters, cuda_stream);
+ break;
+
+ default:
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
+ "Unsupported head size in DecoderMaskedMultiHeadAttention. "
+ "Got head size: ",
+ parameters.head_size);
+ }
+ return Status::OK();
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
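For the beam-search path above, `cache_indir[i, j, k]` records which source beam produced token k of beam j in batch i. A toy Python sketch of how such an entry selects the right KV row (illustrative names only; this is not the kernel's code):

```python
import numpy as np

def gather_key_for_beam(k_cache, cache_indirection, batch, beam, step):
    # k_cache: [batch * beam_width, num_heads, max_sequence_length, head_size]
    # cache_indirection: [batch, beam_width, max_output_length]
    beam_width = cache_indirection.shape[1]
    src_beam = cache_indirection[batch, beam, step]  # beam that token `step` came from
    return k_cache[batch * beam_width + src_beam, :, step, :]
```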
diff --git a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.h b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.h
new file mode 100644
index 000000000000..8200a66db383
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.h
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template <typename T1, typename T2>
+class DecoderMaskedMultiHeadAttention final : public CudaKernel {
+ public:
+ DecoderMaskedMultiHeadAttention(const OpKernelInfo& info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ protected:
+ int num_heads_; // number of attention heads
+ float mask_filter_value_;
+ float scale_;
+ bool past_present_share_buffer_;
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_self_attention.cc
index bba764970322..98f4642e7903 100644
--- a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_self_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_self_attention.cc
@@ -51,7 +51,7 @@ Status DecoderMaskedSelfAttention<T1, T2>::ComputeInternal(OpKernelContext* cont
const Tensor* cache_indir = context->Input<Tensor>(kCacheIndirectionInputIndex);
auto& device_prop = GetDeviceProp();
- DecoderMaskedSelfAttentionParams parameters;
+ DecoderMaskedMultiHeadAttentionParams parameters;
ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(),
weights->Shape(),
bias->Shape(),
@@ -186,6 +186,10 @@ Status DecoderMaskedSelfAttention<T1, T2>::ComputeInternal(OpKernelContext* cont
}
switch (parameters.head_size) {
+ case 32:
+ mmha_launch_kernel<T2, 32>(parameters, cuda_stream);
+ break;
+
case 64:
mmha_launch_kernel<T2, 64>(parameters, cuda_stream);
break;
diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu
index 4cf00e222b0e..3582758d1dab 100644
--- a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu
+++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu
@@ -40,7 +40,7 @@ using namespace decoder_masked_self_attention_details;
<<>>(params)
template
-void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream) {
+void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream) {
constexpr int THREADS_PER_VALUE = ThreadsPerValue::value;
int total_sequence_length = params.total_sequence_length;
@@ -54,9 +54,9 @@ void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStre
}
// Instantiate templates
-template void mmha_launch_kernel<float, 128>(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream);
+template void mmha_launch_kernel<float, 128>(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream);
-template void mmha_launch_kernel<uint16_t, 128>(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream);
+template void mmha_launch_kernel<uint16_t, 128>(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream);
} // namespace cuda
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu
new file mode 100644
index 000000000000..3d295116252f
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu
@@ -0,0 +1,63 @@
+/*
+ * The implementation of this file is based on code provided by https://github.com/NVIDIA/FasterTransformer
+ *
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Modifications Copyright (c) Microsoft.
+// Licensed under the MIT License.
+
+#include "decoder_masked_multihead_attention_impl.h"
+#include "decoder_masked_multihead_attention_impl_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace decoder_masked_self_attention_details;
+
+#define MMHA_LAUNCH_KERNEL( \
+ T, head_size, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK) \
+ size_t dynamic_block_memory = CalcDynamicBlockMemory<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
+ dim3 grid(params.num_heads, params.batch_size); \
+ masked_multihead_attention_kernel<T, head_size, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK> \
+ <<<grid, THDS_PER_BLOCK, dynamic_block_memory, stream>>>(params)
+
+template <typename T, int head_size>
+void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream) {
+ constexpr int THREADS_PER_VALUE = ThreadsPerValue<T, head_size>::value;
+ int total_sequence_length = params.total_sequence_length;
+
+ if (total_sequence_length < 32) {
+ MMHA_LAUNCH_KERNEL(T, head_size, 4, THREADS_PER_VALUE, 64);
+ } else if (total_sequence_length < 2048) {
+ MMHA_LAUNCH_KERNEL(T, head_size, 2, THREADS_PER_VALUE, 128);
+ } else {
+ MMHA_LAUNCH_KERNEL(T, head_size, 1, THREADS_PER_VALUE, 256);
+ }
+}
+
+// Instantiate templates
+template void mmha_launch_kernel<float, 32>(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream);
+
+template void mmha_launch_kernel<uint16_t, 32>(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
\ No newline at end of file
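The launcher above trades threads-per-key against block size as the attended sequence grows. A Python restatement of that dispatch (the threads-per-value formula is an assumption based on the 16-byte chunking described for the K cache, since ThreadsPerValue's definition is not shown in this patch):

```python
def mmha_launch_config(total_sequence_length, head_size, elem_size):
    # Assumed: one thread covers a 16-byte slice of a value row.
    threads_per_value = head_size * elem_size // 16
    if total_sequence_length < 32:
        threads_per_key, threads_per_block = 4, 64
    elif total_sequence_length < 2048:
        threads_per_key, threads_per_block = 2, 128
    else:
        threads_per_key, threads_per_block = 1, 256
    return threads_per_key, threads_per_value, threads_per_block
```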
diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu
index 325681b0e1de..e5f57fac73cf 100644
--- a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu
+++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu
@@ -40,7 +40,7 @@ using namespace decoder_masked_self_attention_details;
<<>>(params)
template
-void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream) {
+void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream) {
constexpr int THREADS_PER_VALUE = ThreadsPerValue::value;
int total_sequence_length = params.total_sequence_length;
@@ -54,9 +54,9 @@ void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStre
}
// Instantiate templates
-template void mmha_launch_kernel<float, 64>(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream);
+template void mmha_launch_kernel<float, 64>(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream);
-template void mmha_launch_kernel<uint16_t, 64>(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream);
+template void mmha_launch_kernel<uint16_t, 64>(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream);
} // namespace cuda
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
index 26bc6f53b4c2..ea4e10519993 100644
--- a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
@@ -46,7 +46,7 @@ template <
int THREADS_PER_VALUE,
// The number of threads in a threadblock.
int THREADS_PER_BLOCK>
-__global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params) {
+__global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params) {
// This kernel contains some code that cannot be compiled on CUDA ARCH 5.3 or lower
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530
(void)(params);
@@ -137,13 +137,15 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara
float qk = 0.0F;
- int qkv_base_offset = bi * (3 * params.hidden_size) + hi * head_size;
+ int qkv_base_offset = params.is_mha
+ ? bi * params.hidden_size + hi * head_size
+ : bi * (3 * params.hidden_size) + hi * head_size;
const size_t bi_total_seq_length = bi * params.total_sequence_length;
const size_t bi_max_seq_length = bi * params.max_sequence_length;
- int tlength = params.past_sequence_length;
+ int tlength = params.is_cross_attention ? params.kv_sequence_length : params.past_sequence_length;
// First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
const bool is_masked = tidx >= QK_VECS_PER_WARP;
@@ -151,9 +153,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara
// The offset in the Q and K buffer also accounts for the batch.
int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
- // The offset in the bias buffer.
- int qk_bias_offset = hi * head_size + tidx * QK_VEC_SIZE;
-
// Trigger the loads from the Q and K buffers.
Qk_vec_k q;
zero(q);
@@ -163,81 +162,99 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara
}
Qk_vec_k k;
- zero(k);
- if (!is_masked) {
- k = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.k)[qk_offset]));
+ if (!params.is_cross_attention) {
+ zero(k);
+
+ if (!is_masked) {
+ k = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.k)[qk_offset]));
+ }
}
// Trigger the loads from the Q and K bias buffers.
Qk_vec_k q_bias;
- zero(q_bias);
+ Qk_vec_k k_bias;
+ if (!params.is_mha) {
+ // The offset in the bias buffer.
+ int qk_bias_offset = hi * head_size + tidx * QK_VEC_SIZE;
- if (!is_masked) {
- q_bias = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.q_bias)[qk_bias_offset]));
- }
+ zero(q_bias);
- Qk_vec_k k_bias;
- zero(k_bias);
+ if (!is_masked) {
+ q_bias = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.q_bias)[qk_bias_offset]));
+ }
- if (!is_masked) {
- k_bias = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.k_bias)[qk_bias_offset]));
- }
+ zero(k_bias);
- // Computes the Q/K values with bias.
- q = add_vec(q, q_bias);
- k = add_vec(k, k_bias);
+ if (!is_masked) {
+ k_bias = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&reinterpret_cast<T*>(params.k_bias)[qk_bias_offset]));
+ }
+
+ // Computes the Q/K values with bias.
+ q = add_vec(q, q_bias);
+ k = add_vec(k, k_bias);
+ }
T* params_k_cache = reinterpret_cast<T*>(params.k_cache);
+ const float inv_sqrt_dh = params.scale;
+
if (!is_masked) {
// Store the Q values to shared memory.
*reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q;
+ }
- // Write the K values to the global memory cache.
- // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
- // system. We designed it this way as it allows much better memory loads (and there are many
- // more loads) + the stores are really "write and forget" since we won't need the ack before
- // the end of the kernel. There's plenty of time for the transactions to complete.
+ if (!params.is_cross_attention) {
+ if (!is_masked) {
+ // Write the K values to the global memory cache.
+ // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
+ // system. We designed it this way as it allows much better memory loads (and there are many
+ // more loads) + the stores are really "write and forget" since we won't need the ack before
+ // the end of the kernel. There's plenty of time for the transactions to complete.
- // The 16B chunk written by the thread.
- int co = tidx / QK_VECS_IN_16B;
+ // The 16B chunk written by the thread.
+ int co = tidx / QK_VECS_IN_16B;
- // The position of the thread in that 16B chunk.
- int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
+ // The position of the thread in that 16B chunk.
+ int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
- // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
- int offset = bhi * params.max_sequence_length * head_size + co * params.max_sequence_length * QK_ELTS_IN_16B +
- tlength * QK_ELTS_IN_16B + ci;
+ // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
+ int offset = bhi * params.max_sequence_length * head_size + co * params.max_sequence_length * QK_ELTS_IN_16B +
+ tlength * QK_ELTS_IN_16B + ci;
- // Trigger the stores to global memory.
- *reinterpret_cast(¶ms_k_cache[offset]) = vec_conversion(k);
+ // Trigger the stores to global memory.
+ *reinterpret_cast(¶ms_k_cache[offset]) = vec_conversion(k);
- // Compute \sum_i Q[i] * K^T[i] for the current timestep.
- using Qk_vec_acum = Qk_vec_k;
- qk = dot(q, k);
+ // Compute \sum_i Q[i] * K^T[i] for the current timestep.
+ using Qk_vec_acum = Qk_vec_k;
+ qk = dot<Qk_vec_acum, Qk_vec_k>(q, k);
- if (QK_VECS_PER_WARP <= WARP_SIZE) {
-#pragma unroll
- for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
- qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
+ if (QK_VECS_PER_WARP <= WARP_SIZE) {
+ #pragma unroll
+ for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
+ qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
+ }
}
}
- }
- if (QK_VECS_PER_WARP > WARP_SIZE) {
- constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
- qk = block_sum(&red_smem[WARPS_PER_RED], qk);
- }
-
- const float inv_sqrt_dh = params.scale;
+ if (QK_VECS_PER_WARP > WARP_SIZE) {
+ constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
+ qk = block_sum(&red_smem[WARPS_PER_RED], qk);
+ }
- // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
- if (tidx == 0) {
- // Normalize qk.
- qk *= inv_sqrt_dh;
- qk_max = qk;
- qk_smem[tlength] = qk;
+ // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
+ if (tidx == 0) {
+ // Normalize qk.
+ qk *= inv_sqrt_dh;
+ if (params.relative_attention_bias != nullptr) {
+ qk = add_vec(qk,
+ reinterpret_cast<T*>(params.relative_attention_bias)[hi * params.sequence_length
+ * params.total_sequence_length
+ + tlength]);
+ }
+ qk_max = qk;
+ qk_smem[tlength] = qk;
+ }
}
// Make sure the data is in shared memory.
@@ -332,6 +349,12 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara
// Store the product to shared memory. There's one qk value per timestep. Update the max.
if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
+ if (params.relative_attention_bias != nullptr) {
+ qk = add_vec(qk,
+ reinterpret_cast<T*>(params.relative_attention_bias)[hi * params.sequence_length
+ * params.total_sequence_length
+ + ti]);
+ }
qk_max = fmaxf(qk_max, qk);
qk_smem[ti] = qk;
}
@@ -370,7 +393,8 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara
// Compute the logits and start the sum.
float sum = 0.f;
- for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
+ int sum_tlength = params.is_cross_attention ? tlength - 1 : tlength;
+ for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) {
// This is a deviation from FasterTransformer kernel implementation
// but this aligns with ORT's other Attention kernels which strives to
// mimic PyTorch when dealing with mask filter values
@@ -384,7 +408,7 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara
// Normalize the logits.
float inv_sum = __fdividef(1.f, sum + 1.e-6f);
- for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
+ for (int ti = tidx; ti <= sum_tlength; ti += THREADS_PER_BLOCK) {
float logit = qk_smem[ti] * inv_sum;
ConvertFromFloat(logits_smem[ti], logit);
}
@@ -418,12 +442,14 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara
// One group of threads computes the product(s) for the current timestep.
V_vec_k v_bias;
- zero(v_bias);
+ if (!params.is_mha) {
+ zero(v_bias);
- T* params_v_bias = reinterpret_cast<T*>(params.v_bias);
+ T* params_v_bias = reinterpret_cast<T*>(params.v_bias);
- if (vo == tlength % V_PER_ITER) {
- v_bias = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params_v_bias[hi * head_size + vi]));
+ if (vo == tlength % V_PER_ITER) {
+ v_bias = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params_v_bias[hi * head_size + vi]));
+ }
}
// From previous, before values, step
@@ -451,12 +477,14 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara
}
// One group of threads computes the product(s) for the current timestep.
- if (vo == tlength % V_PER_ITER) {
+ if (vo == tlength % V_PER_ITER && !params.is_cross_attention) {
const auto v_offset = qkv_base_offset + vi;
V_vec_k v;
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&reinterpret_cast<T*>(params.v)[v_offset]));
- v = add_vec(v, v_bias);
+ if (!params.is_mha) {
+ v = add_vec(v, v_bias);
+ }
// Store the values with bias back to global memory in the cache for V.
*reinterpret_cast(&v_cache[tlength * head_size]) = vec_conversion(v);
@@ -497,33 +525,47 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionPara
// Template instantiation(s)
+// fp32 + head size = 32
+template void __global__ masked_multihead_attention_kernel<float, 32, 4, 8, 64>(DecoderMaskedMultiHeadAttentionParams params);
+
+template void __global__ masked_multihead_attention_kernel<float, 32, 2, 8, 128>(DecoderMaskedMultiHeadAttentionParams params);
+
+template void __global__ masked_multihead_attention_kernel<float, 32, 1, 8, 256>(DecoderMaskedMultiHeadAttentionParams params);
+
+// fp16 + head size = 32
+template void __global__ masked_multihead_attention_kernel<uint16_t, 32, 4, 4, 64>(DecoderMaskedMultiHeadAttentionParams params);
+
+template void __global__ masked_multihead_attention_kernel<uint16_t, 32, 2, 4, 128>(DecoderMaskedMultiHeadAttentionParams params);
+
+template void __global__ masked_multihead_attention_kernel<uint16_t, 32, 1, 4, 256>(DecoderMaskedMultiHeadAttentionParams params);
+
// fp32 + head size = 64
-template void __global__ masked_multihead_attention_kernel<float, 64, 4, 16, 64>(DecoderMaskedSelfAttentionParams params);
+template void __global__ masked_multihead_attention_kernel<float, 64, 4, 16, 64>(DecoderMaskedMultiHeadAttentionParams params);
-template void __global__ masked_multihead_attention_kernel<float, 64, 2, 16, 128>(DecoderMaskedSelfAttentionParams params);
+template void __global__ masked_multihead_attention_kernel<float, 64, 2, 16, 128>(DecoderMaskedMultiHeadAttentionParams params);
-template void __global__ masked_multihead_attention_kernel<float, 64, 1, 16, 256>(DecoderMaskedSelfAttentionParams params);
+template void __global__ masked_multihead_attention_kernel<float, 64, 1, 16, 256>(DecoderMaskedMultiHeadAttentionParams params);
// fp16 + head size = 64
-template void __global__ masked_multihead_attention_kernel<uint16_t, 64, 4, 8, 64>(DecoderMaskedSelfAttentionParams params);
+template void __global__ masked_multihead_attention_kernel<uint16_t, 64, 4, 8, 64>(DecoderMaskedMultiHeadAttentionParams params);
-template void __global__ masked_multihead_attention_kernel<uint16_t, 64, 2, 8, 128>(DecoderMaskedSelfAttentionParams params);
+template void __global__ masked_multihead_attention_kernel<uint16_t, 64, 2, 8, 128>(DecoderMaskedMultiHeadAttentionParams params);
-template void __global__ masked_multihead_attention_kernel<uint16_t, 64, 1, 8, 256>(DecoderMaskedSelfAttentionParams params);
+template void __global__ masked_multihead_attention_kernel<uint16_t, 64, 1, 8, 256>(DecoderMaskedMultiHeadAttentionParams params);
// fp32 + head size = 128
-template void __global__ masked_multihead_attention_kernel<float, 128, 4, 32, 64>(DecoderMaskedSelfAttentionParams params);
+template void __global__ masked_multihead_attention_kernel<float, 128, 4, 32, 64>(DecoderMaskedMultiHeadAttentionParams params);
-template void __global__ masked_multihead_attention_kernel<float, 128, 2, 32, 128>(DecoderMaskedSelfAttentionParams params);
+template void __global__ masked_multihead_attention_kernel<float, 128, 2, 32, 128>(DecoderMaskedMultiHeadAttentionParams params);
-template void __global__ masked_multihead_attention_kernel<float, 128, 1, 32, 256>(DecoderMaskedSelfAttentionParams params);
+template void __global__ masked_multihead_attention_kernel<float, 128, 1, 32, 256>(DecoderMaskedMultiHeadAttentionParams params);
// fp16 + head size = 128
-template void __global__ masked_multihead_attention_kernel<uint16_t, 128, 4, 16, 64>(DecoderMaskedSelfAttentionParams params);
+template void __global__ masked_multihead_attention_kernel<uint16_t, 128, 4, 16, 64>(DecoderMaskedMultiHeadAttentionParams params);
-template void __global__ masked_multihead_attention_kernel<uint16_t, 128, 2, 16, 128>(DecoderMaskedSelfAttentionParams params);
+template void __global__ masked_multihead_attention_kernel<uint16_t, 128, 2, 16, 128>(DecoderMaskedMultiHeadAttentionParams params);
-template void __global__ masked_multihead_attention_kernel<uint16_t, 128, 1, 16, 256>(DecoderMaskedSelfAttentionParams params);
+template void __global__ masked_multihead_attention_kernel<uint16_t, 128, 1, 16, 256>(DecoderMaskedMultiHeadAttentionParams params);
} // namespace cuda
} // namespace contrib
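The two `relative_attention_bias` additions above index a flattened [num_heads, sequence_length, total_sequence_length] bias (broadcast over batch, as in the test graph's [1, N, 1, T] input) with `hi * sequence_length * total_sequence_length + ti`. A NumPy sketch of that lookup for one (batch, head) pair, assuming sequence_length = 1 as in decoding:

```python
import numpy as np

def biased_logits(raw_qk, rel_bias_flat, head, total_sequence_length, sequence_length=1):
    # raw_qk: [total_sequence_length] QK products for the current decoding step.
    base = head * sequence_length * total_sequence_length
    return raw_qk + rel_bias_flat[base: base + total_sequence_length]
```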
diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h
index fe1e0cb70252..6501103ed067 100644
--- a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.h
@@ -10,9 +10,13 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
-struct DecoderMaskedSelfAttentionParams : AttentionParameters {
+struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters {
int beam_width = 1;
+ // Whether to use multihead attention (excludes MatMul and bias)
+ bool is_mha = false;
+ bool is_cross_attention = false;
+
void* q = nullptr;
void* q_bias = nullptr;
@@ -22,6 +26,8 @@ struct DecoderMaskedSelfAttentionParams : AttentionParameters {
void* v = nullptr;
void* v_bias = nullptr;
+ void* relative_attention_bias = nullptr;
+
void* k_cache = nullptr;
void* v_cache = nullptr;
@@ -43,10 +49,10 @@ template<
int THREADS_PER_VALUE,
// The number of threads in a threadblock.
int THREADS_PER_BLOCK>
-__global__ void masked_multihead_attention_kernel(DecoderMaskedSelfAttentionParams params);
+__global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParams params);
template <typename T, int head_size>
-void mmha_launch_kernel(const DecoderMaskedSelfAttentionParams& params, cudaStream_t stream);
+void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cudaStream_t stream);
diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h
index 5c8da6da7dd2..42d54e38d41e 100644
--- a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h
+++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl_utils.h
@@ -741,7 +741,7 @@ inline __device__ void ConvertFromFloat(uint4& dst, Float8_ src) {
//------------------------------------------------------------
template <typename T>
-inline size_t CalcDynamicBlockMemory(const DecoderMaskedSelfAttentionParams& params,
+inline size_t CalcDynamicBlockMemory(const DecoderMaskedMultiHeadAttentionParams& params,
int threads_per_value, int threads_per_block) {
// The amount of shared memory needed to store the Q*K^T values in float.
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 174dde63582f..f7e2b596f617 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -549,6 +549,121 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
AttentionTypeAndShapeInference(ctx, past_input_index);
}));
+constexpr const char* DecoderMaskedMultiHeadAttention_ver1_doc = R"DOC(
+Multihead attention that supports input sequence length of 1.
+Similar to DecoderMaskedSelfAttention but this op excludes QKV MatMul and Bias.
+This op supports both Self and Cross Attention.
+)DOC";
+
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ DecoderMaskedMultiHeadAttention, 1,
+ OpSchema()
+ .SetDoc(DecoderMaskedMultiHeadAttention_ver1_doc)
+ .Attr("num_heads", "Number of attention heads", AttributeProto::INT)
+ .Attr("past_present_share_buffer",
+ "Corresponding past and present are same tensor, its size is "
+ "(batch_size, num_heads, max_sequence_length, head_size)",
+ AttributeProto::INT,
+ OPTIONAL_VALUE)
+ .Attr("scale",
+ "Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
+ AttributeProto::FLOAT,
+ OPTIONAL_VALUE)
+ .Attr("mask_filter_value",
+ "The value to be filled in the attention mask. Default value is -10000.0f",
+ AttributeProto::FLOAT,
+ OPTIONAL_VALUE)
+ .Input(0,
+ "query",
+ "Query with shape (batch_size, 1, hidden_size)",
+ "T")
+ .Input(1,
+ "key",
+ "Key with shape (batch_size, 1, hidden_size) for self attention "
+ "or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention",
+ "T")
+ .Input(2,
+ "value",
+ "Value with shape (batch_size, 1, v_hidden_size) for self attention "
+ "or past_value with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention",
+ "T")
+ .Input(3,
+ "mask_index",
+ "Mask values of shape (batch_size, total_sequence_length) or (batch_size, kv_sequence_length)",
+ "M",
+ OpSchema::Optional)
+ .Input(4,
+ "relative_position_bias",
+ "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)",
+ "T",
+ OpSchema::Optional)
+ .Input(5,
+ "past_key",
+ "past state for key with shape (batch_size, num_heads, past_sequence_length, head_size) for self attention"
+ "When past_present_share_buffer is set, "
+ "its shape is (batch_size, num_heads, max_sequence_length, head_size). "
+ "The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape "
+ "(batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape "
+ "(batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to "
+ "become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.",
+ "T",
+ OpSchema::Optional)
+ .Input(6,
+ "past_value",
+ "past state for value with shape (batch_size, num_heads, past_sequence_length, head_size) for self attention"
+ "When past_present_share_buffer is set, "
+ "its shape is (batch_size, num_heads, max_sequence_length, head_size). ",
+ "T",
+ OpSchema::Optional)
+ .Input(7,
+ "past_sequence_length",
+ "When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0)."
+ "Cross Attention doesn't need this input.",
+ "M",
+ OpSchema::Optional)
+ .Input(8,
+ "beam_width",
+ "The beam width that is being used while decoding."
+ "If not provided, the beam width will be assumed to be 1.",
+ "M",
+ OpSchema::Optional)
+ .Input(9,
+ "cache_indirection",
+ "A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies"
+ "which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration",
+ "M",
+ OpSchema::Optional)
+ .Output(0,
+ "output",
+ "3D output tensor with shape (batch_size, sequence_length, v_hidden_size)",
+ "T")
+ .Output(1,
+ "present_key",
+ "past state for key with shape (batch_size, num_heads, total_sequence_length, head_size). "
+ "If past_present_share_buffer is set, "
+ "its shape is (batch_size, num_heads, max_sequence_length, head_size), "
+ "while effective_seq_length = (past_sequence_length + kv_sequence_length).",
+ "T",
+ OpSchema::Optional)
+ .Output(2,
+ "present_value",
+ "past state for value with shape (batch_size, num_heads, total_sequence_length, head_size). "
+ "If past_present_share_buffer is set, "
+ "its shape is (batch_size, num_heads, max_sequence_length, head_size), "
+ "while effective_seq_length = (past_sequence_length + kv_sequence_length).",
+ "T",
+ OpSchema::Optional)
+ .TypeConstraint("T",
+ {"tensor(float)", "tensor(float16)"},
+ "Constrain input and output types to float tensors.")
+ .TypeConstraint("M",
+ {"tensor(int32)"},
+ "Constrain mask index to integer types")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ // TODO:
+ (void) (ctx);
+ }));
+
constexpr const char* MultiHeadAttention_ver1_doc = R"DOC(
Multi-Head Self/Cross Attention. Bias from input projection is included.
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index 3066804577f9..772f93ab1b96 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -100,6 +100,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Unique);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WordConvEmbedding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFastGelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedSelfAttention);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedMultiHeadAttention);
class OpSet_Microsoft_ver1 {
public:
@@ -194,6 +195,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
+ fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedMultiHeadAttention)>());
}
};
} // namespace contrib
diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
index 52b41d52c166..d9c870a7dc52 100644
--- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc
@@ -647,7 +647,8 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) {
int sequence_length = 1;
int number_of_heads = 12;
// Vary head_size / hidden_size
- for (int hidden_size = 768; hidden_size <= 1536; hidden_size += 768) {
+ int hidden_sizes[3] = {384, 768, 1536};
+ for (int hidden_size : hidden_sizes) {
int head_size = (hidden_size / number_of_heads);
int total_sequence_length = sequence_length + past_sequence_length;
int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length
@@ -760,7 +761,8 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) {
int number_of_heads = 12;
// Vary head_size / hidden_size
- for (int hidden_size = 768; hidden_size <= 1536; hidden_size += 768) {
+ int hidden_sizes[3] = {384, 768, 1536};
+ for (int hidden_size : hidden_sizes) {
int head_size = (hidden_size / number_of_heads);
int total_sequence_length = sequence_length + past_sequence_length;
int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length
diff --git a/onnxruntime/test/python/transformers/test_parity_t5_mha.py b/onnxruntime/test/python/transformers/test_parity_t5_mha.py
index 51d5ba7838d9..23218e494304 100644
--- a/onnxruntime/test/python/transformers/test_parity_t5_mha.py
+++ b/onnxruntime/test/python/transformers/test_parity_t5_mha.py
@@ -154,6 +154,113 @@ def create_t5_mha_graph(
return model.SerializeToString()
+# For decoder only (not decoder_init), starting from the second iteration
+def create_t5_decoder_masked_mha_graph(
+ batch_size,
+ past_sequence_length,
+ kv_sequence_length,
+ head_size,
+ num_heads,
+ is_cross_attention,
+):
+ from onnx import TensorProto, helper
+
+ nodes = [
+ helper.make_node(
+ "DecoderMaskedMultiHeadAttention",
+ [
+ "query",
+ "key",
+ "value",
+ "mask_index" if is_cross_attention else "",
+ "relative_position_bias" if not is_cross_attention else "",
+ "past_key" if not is_cross_attention else "",
+ "past_value" if not is_cross_attention else "",
+ "past_sequence_length" if not is_cross_attention else "",
+ ],
+ [
+ "output",
+ "present_key" if not is_cross_attention else "",
+ "present_value" if not is_cross_attention else "",
+ ],
+ "DMMHA_0",
+ num_heads=num_heads,
+ mask_filter_value=-10000.0,
+ scale=1.0,
+ past_present_share_buffer=0 if is_cross_attention else 1,
+ domain="com.microsoft",
+ ),
+ ]
+
+ initializers = []
+
+ hidden_size = head_size * num_heads
+
+ graph_inputs = [
+ helper.make_tensor_value_info("query", TensorProto.FLOAT, [batch_size, 1, hidden_size]),
+ ]
+
+ graph_outputs = [
+ helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, 1, hidden_size]),
+ ]
+
+ if is_cross_attention:
+ graph_inputs.append(
+ helper.make_tensor_value_info("mask_index", TensorProto.INT32, [batch_size, kv_sequence_length])
+ )
+ graph_inputs.append(
+ helper.make_tensor_value_info(
+ "key", TensorProto.FLOAT, [batch_size, num_heads, kv_sequence_length, head_size]
+ )
+ )
+ graph_inputs.append(
+ helper.make_tensor_value_info(
+ "value", TensorProto.FLOAT, [batch_size, num_heads, kv_sequence_length, head_size]
+ )
+ )
+ else:
+ graph_inputs.append(helper.make_tensor_value_info("key", TensorProto.FLOAT, [batch_size, 1, hidden_size]))
+ graph_inputs.append(helper.make_tensor_value_info("value", TensorProto.FLOAT, [batch_size, 1, hidden_size]))
+ graph_inputs.append(
+ helper.make_tensor_value_info(
+ "relative_position_bias", TensorProto.FLOAT, [1, num_heads, 1, past_sequence_length + 1]
+ )
+ )
+ # use past_sequence_length + 1 to simulate max_sequence_length
+ graph_inputs.append(
+ helper.make_tensor_value_info(
+ "past_key", TensorProto.FLOAT, [batch_size, num_heads, past_sequence_length + 1, head_size]
+ )
+ )
+ graph_inputs.append(
+ helper.make_tensor_value_info(
+ "past_value", TensorProto.FLOAT, [batch_size, num_heads, past_sequence_length + 1, head_size]
+ )
+ )
+ graph_inputs.append(helper.make_tensor_value_info("past_sequence_length", TensorProto.INT32, [1]))
+ graph_outputs.append(
+ helper.make_tensor_value_info(
+ "present_key", TensorProto.FLOAT, [batch_size, num_heads, past_sequence_length + 1, head_size]
+ )
+ )
+ graph_outputs.append(
+ helper.make_tensor_value_info(
+ "present_value", TensorProto.FLOAT, [batch_size, num_heads, past_sequence_length + 1, head_size]
+ )
+ )
+
+ graph = helper.make_graph(
+ nodes,
+ "T5_DMMHA_Graph",
+ graph_inputs,
+ graph_outputs,
+ initializers,
+ )
+
+ model = helper.make_model(graph)
+ return model.SerializeToString()
+
+
class T5Config:
def __init__(self, is_decoder, batch_size, seq_len, kv_sequence_length, num_heads, head_size, use_past):
self.is_decoder = is_decoder
@@ -173,7 +280,7 @@ def __init__(self, is_decoder, batch_size, seq_len, kv_sequence_length, num_head
class T5Attention(nn.Module):
- def __init__(self, config: T5Config, is_static_kv):
+ def __init__(self, config: T5Config, is_static_kv, use_decoder_masked_kernel: bool = False):
super().__init__()
self.is_decoder = config.is_decoder
self.is_static_kv = is_static_kv
@@ -199,17 +306,52 @@ def __init__(self, config: T5Config, is_static_kv):
self.num_heads = config.num_heads
self.hidden_size = self.d_model
self.use_past = config.use_past
+ self.use_decoder_masked_kernel = use_decoder_masked_kernel
# Create onnx graph
- self.onnx_graph = create_t5_mha_graph(
- self.batch_size,
- self.seq_len,
- self.kv_sequence_length,
- self.head_size,
- self.num_heads,
- self.use_past,
- is_static_kv,
- )
+ if self.use_decoder_masked_kernel:
+ self.onnx_graph = create_t5_decoder_masked_mha_graph(
+ self.batch_size,
+ self.kv_sequence_length,
+ self.kv_sequence_length,
+ self.head_size,
+ self.num_heads,
+ is_static_kv,
+ )
+ else:
+ self.onnx_graph = create_t5_mha_graph(
+ self.batch_size,
+ self.seq_len,
+ self.kv_sequence_length,
+ self.head_size,
+ self.num_heads,
+ self.use_past,
+ is_static_kv,
+ )
+
+ # Reorder 'K' from [B, N, S, H] to [B, N, H/4, S, 4]
+ def reorder_key_cache(self, key_cache, batch_size, num_heads, sequence_length, head_size, max_sequence_length):
+ ordered = np.zeros_like(key_cache)
+
+        # assume float32, for which the kernel packs 4 consecutive elements per vector
+        num_inner_elements = 4
+ chunks = int(head_size / num_inner_elements)
+
+ for b in range(batch_size):
+ for h in range(num_heads):
+ for c in range(chunks):
+ for s in range(sequence_length):
+ base_offset = (b * num_heads * max_sequence_length * head_size) + (
+ h * max_sequence_length * head_size
+ )
+ input_base_offset = base_offset + (s * head_size) + (c * num_inner_elements)
+ output_base_offset = (
+ base_offset + (c * max_sequence_length * num_inner_elements) + (s * num_inner_elements)
+ )
+ for e in range(num_inner_elements):
+ ordered[output_base_offset + e] = key_cache[input_base_offset + e]
+
+ return ordered
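+
+    # Equivalent vectorized sketch of the permutation above (illustrative, not
+    # used by the tests): view the flat cache as [B, N, max_S, H/4, 4] and swap
+    # the max_S and H/4 axes. Matches reorder_key_cache whenever the padding
+    # region beyond sequence_length is zero, as it is at the call sites below.
+    def reorder_key_cache_vectorized(self, key_cache, batch_size, num_heads, head_size, max_sequence_length):
+        view = key_cache.reshape(batch_size, num_heads, max_sequence_length, head_size // 4, 4)
+        return np.ascontiguousarray(view.transpose(0, 1, 3, 2, 4)).reshape(-1)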
def create_inputs(self):
hidden_states = torch.normal(mean=0.5, std=0.1, size=(self.batch_size, self.seq_len, self.hidden_size)).to(
@@ -230,6 +372,10 @@ def create_inputs(self):
position_bias = torch.normal(
mean=0.5, std=0.1, size=(1, self.num_heads, position_bias_length, position_bias_length)
).to(torch.float32)
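+        # DecoderMaskedMultiHeadAttention consumes only the current step's bias
+        # row, i.e. shape [1, num_heads, 1, position_bias_length].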
+ if self.use_decoder_masked_kernel:
+ position_bias = torch.normal(mean=5, std=0.1, size=(1, self.num_heads, 1, position_bias_length)).to(
+ torch.float32
+ )
return hidden_states, key_value_states, past_key_value, attention_mask, position_bias
def torch_forward(
@@ -302,6 +448,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
key_states = project(
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
+
value_states = project(
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
@@ -421,16 +568,57 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
ort_inputs = {
"query": np.ascontiguousarray(query_states.detach().numpy()),
}
+ torch_past_key = np.ascontiguousarray(torch_past_key.detach().numpy())
+ torch_past_value = np.ascontiguousarray(torch_past_value.detach().numpy())
+ max_seq_len = torch_past_key.shape[2] + 1
+ torch_past_key_padded = np.zeros(
+ [torch_past_key.shape[0], torch_past_key.shape[1], max_seq_len, torch_past_key.shape[3]],
+ dtype=np.float32,
+ )
+ torch_past_value_padded = np.zeros(
+ [torch_past_value.shape[0], torch_past_value.shape[1], max_seq_len, torch_past_value.shape[3]],
+ dtype=np.float32,
+ )
+ torch_past_key_padded[:, :, : torch_past_key.shape[2], :] = torch_past_key
+ torch_past_value_padded[:, :, : torch_past_value.shape[2], :] = torch_past_value
if self.is_static_kv:
- ort_inputs["key"] = np.ascontiguousarray(torch_past_key.detach().numpy())
- ort_inputs["value"] = np.ascontiguousarray(torch_past_value.detach().numpy())
+ if self.use_decoder_masked_kernel:
+ reordered_past_key = self.reorder_key_cache(
+ torch_past_key.flatten(),
+ batch_size=batch_size,
+ num_heads=self.num_heads,
+ sequence_length=self.kv_sequence_length,
+ head_size=self.head_size,
+ max_sequence_length=self.kv_sequence_length,
+ )
+ ort_inputs["key"] = reordered_past_key.reshape(torch_past_key.shape)
+ ort_inputs["value"] = torch_past_value
+ else:
+ ort_inputs["key"] = np.ascontiguousarray(torch_past_key)
+ ort_inputs["value"] = np.ascontiguousarray(torch_past_value)
else:
- ort_inputs["past_key"] = np.ascontiguousarray(torch_past_key.detach().numpy())
- ort_inputs["past_value"] = np.ascontiguousarray(torch_past_value.detach().numpy())
ort_inputs["key"] = np.ascontiguousarray(key_states.detach().numpy())
ort_inputs["value"] = np.ascontiguousarray(value_states.detach().numpy())
+ if self.use_decoder_masked_kernel:
+ reordered_past_key = self.reorder_key_cache(
+ torch_past_key_padded.flatten(),
+ batch_size=batch_size,
+ num_heads=self.num_heads,
+ sequence_length=self.kv_sequence_length,
+ head_size=self.head_size,
+ max_sequence_length=max_seq_len,
+ )
+                ort_inputs["past_key"] = reordered_past_key.reshape(torch_past_key_padded.shape)
+ ort_inputs["past_value"] = torch_past_value_padded
+ ort_inputs["past_sequence_length"] = np.array([self.kv_sequence_length], dtype=np.int32)
+ else:
+ ort_inputs["past_key"] = torch_past_key
+ ort_inputs["past_value"] = torch_past_value
if torch_key_padding_mask is not None:
- ort_inputs["key_padding_mask"] = np.ascontiguousarray(torch_key_padding_mask.detach().numpy())
+ if self.use_decoder_masked_kernel:
+ ort_inputs["mask_index"] = np.ascontiguousarray(torch_key_padding_mask.detach().numpy())
+ else:
+ ort_inputs["key_padding_mask"] = np.ascontiguousarray(torch_key_padding_mask.detach().numpy())
if torch_position_bias is not None:
ort_inputs["relative_position_bias"] = np.ascontiguousarray(torch_position_bias.detach().numpy())
@@ -445,7 +633,7 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
return output
-def compare_t5_cross_attention_decoder(batch_size, seq_len, num_heads, head_size, kv_sequence_length):
+def compare_t5_cross_attention_decoder(batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=False):
config = T5Config(
is_decoder=True,
batch_size=batch_size,
@@ -455,7 +643,8 @@ def compare_t5_cross_attention_decoder(batch_size, seq_len, num_heads, head_size
head_size=head_size,
use_past=True,
)
- T5CrossAttention = T5Attention(config, is_static_kv=True) # noqa: N806
+
+ T5CrossAttention = T5Attention(config, is_static_kv=True, use_decoder_masked_kernel=use_dmmha) # noqa: N806
hidden_states, key_value_states, past_key_value, attention_mask, _ = T5CrossAttention.create_inputs()
torch_output = T5CrossAttention.torch_forward(
@@ -521,7 +710,7 @@ def compare_t5_self_attention_decoder_init(batch_size, seq_len, num_heads, head_
assert torch.allclose(torch_output[1][1], ort_output[1][1], atol=1e-4)
-def compare_t5_self_attention_decoder(batch_size, seq_len, num_heads, head_size, kv_sequence_length):
+def compare_t5_self_attention_decoder(batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=False):
config = T5Config(
is_decoder=True,
batch_size=batch_size,
@@ -531,7 +720,8 @@ def compare_t5_self_attention_decoder(batch_size, seq_len, num_heads, head_size,
head_size=head_size,
use_past=True,
)
- T5CrossAttention = T5Attention(config, is_static_kv=False) # noqa: N806
+
+ T5CrossAttention = T5Attention(config, is_static_kv=False, use_decoder_masked_kernel=use_dmmha) # noqa: N806
hidden_states, _, past_key_value, _, position_bias = T5CrossAttention.create_inputs()
torch_output = T5CrossAttention.torch_forward(
@@ -543,8 +733,9 @@ def compare_t5_self_attention_decoder(batch_size, seq_len, num_heads, head_size,
if ort_output is not None:
assert torch.allclose(torch_output[0], ort_output[0], atol=1e-4)
- assert torch.allclose(torch_output[1][0], ort_output[1][0], atol=1e-4)
- assert torch.allclose(torch_output[1][1], ort_output[1][1], atol=1e-4)
+ if not use_dmmha:
+ assert torch.allclose(torch_output[1][0], ort_output[1][0], atol=1e-4)
+ assert torch.allclose(torch_output[1][1], ort_output[1][1], atol=1e-4)
class TestT5MHAParity(unittest.TestCase):
@@ -575,6 +766,24 @@ def test_t5_self_attention_decoder(self):
self.batch_size, self.seq_len, self.num_heads, self.head_size, self.kv_sequence_length
)
+ def test_t5_cross_attention_decoder_masked_mha(self):
+ batch_size = 2
+ seq_len = 1
+ num_heads = 2
+ head_size = 32
+ kv_sequence_length = 2
+ compare_t5_cross_attention_decoder(
+ batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=True
+ )
+
+ def test_t5_self_attention_decoder_masked_mha(self):
+ batch_size = 2
+ seq_len = 1
+ num_heads = 2
+ head_size = 32
+ kv_sequence_length = 2
+ compare_t5_self_attention_decoder(batch_size, seq_len, num_heads, head_size, kv_sequence_length, use_dmmha=True)
+
if __name__ == "__main__":
unittest.main()
From e4aae94f20b995fe4b80643499d46af211e38fc2 Mon Sep 17 00:00:00 2001
From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com>
Date: Mon, 3 Apr 2023 12:47:14 -0700
Subject: [PATCH 2/4] Remove azure build to unblock PRs (#15336)
Temporarily remove Azure build check to unblock PR(s).
We need to investigate the sudden build failure and re-enable it.
Co-authored-by: Randy Shuai
---
.../azure-pipelines/win-ci-pipeline.yml | 20 -------------------
1 file changed, 20 deletions(-)
diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml
index 77ec91824d1e..1b598a405ed7 100644
--- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml
@@ -190,23 +190,3 @@ stages:
WITH_CACHE: true
MachinePool: 'onnxruntime-Win2019-CPU-training'
-- stage: x64_release_azure
- dependsOn: []
- jobs:
- - template: templates/win-ci-vs-2019.yml
- parameters:
- BuildConfig: 'RelWithDebInfo'
- EnvSetupScript: setup_env_azure.bat
- buildArch: x64
- additionalBuildFlags: --use_azure
- msbuildPlatform: x64
- isX86: false
- job_name_suffix: x64_release_azure
- RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }}
- RunStaticCodeAnalysis: false
- EnablePython: false
- isTraining: false
- ORT_EP_NAME: CPU
- GenerateDocumentation: false
- WITH_CACHE: true
- MachinePool: 'Win-CPU-2019'
From 85bb13345d13e0172b5ffa113a9544267f74bee6 Mon Sep 17 00:00:00 2001
From: Matthieu Darbois
Date: Tue, 4 Apr 2023 02:45:12 +0200
Subject: [PATCH 3/4] Rework some external targets to ease building with
`-DFETCHCONTENT_FULLY_DISCONNECTED=ON` (#15323)
### Description
Rework some external targets to ease building with
`-DFETCHCONTENT_FULLY_DISCONNECTED=ON`
This will allow package managers to more easily provide an onnxruntime
package by reducing the amount of patching needed downstream at each
version.
### Motivation and Context
Availability of onnxruntime in some C++ package managers
https://github.com/microsoft/onnxruntime/issues/7150
https://github.com/conan-io/conan-center-index/issues/16699
https://github.com/microsoft/vcpkg/issues/20548
My initial intent is to get this into Conan, but the PR would most likely
be useful (though not tested) for vcpkg as well, and maybe others.
I tried to include only a first batch of patches that are not too specific
(i.e. not Conan-specific).
The first commit reworks `flatbuffers` and just extends what @snnn did
in https://github.com/microsoft/onnxruntime/pull/13991
The second commit reworks `pytorch_cpuinfo`
The third commit reworks `google_nsync`
---
cmake/CMakeLists.txt | 6 ++---
.../external/onnxruntime_external_deps.cmake | 8 ++++++-
cmake/onnxruntime_common.cmake | 4 ++--
cmake/onnxruntime_flatbuffers.cmake | 2 +-
cmake/onnxruntime_providers.cmake | 22 ++++++++---------
cmake/onnxruntime_unittests.cmake | 24 +++++++++----------
cmake/onnxruntime_webassembly.cmake | 4 ++--
7 files changed, 38 insertions(+), 32 deletions(-)
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 0bcf5fa38a6f..b59824069bf3 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -748,7 +748,7 @@ endif()
function(onnxruntime_set_compile_flags target_name)
if (CPUINFO_SUPPORTED)
- onnxruntime_add_include_to_target(${target_name} cpuinfo)
+ onnxruntime_add_include_to_target(${target_name} cpuinfo::cpuinfo)
endif()
if(onnxruntime_ENABLE_EAGER_MODE)
target_compile_definitions(${target_name} PRIVATE ENABLE_EAGER_MODE)
@@ -832,7 +832,7 @@ function(onnxruntime_set_compile_flags target_name)
target_compile_options(${target_name} PRIVATE "-Wno-unused-parameter")
endif()
target_compile_definitions(${target_name} PUBLIC -DNSYNC_ATOMIC_CPP11)
- target_include_directories(${target_name} PRIVATE "${google_nsync_SOURCE_DIR}/public")
+ onnxruntime_add_include_to_target(${target_name} nsync::nsync_cpp)
endif()
foreach(ORT_FLAG ${ORT_PROVIDER_FLAGS})
target_compile_definitions(${target_name} PRIVATE ${ORT_FLAG})
@@ -1469,7 +1469,7 @@ if (WIN32)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${SYS_PATH_LIB})
list(APPEND onnxruntime_EXTERNAL_LIBRARIES debug Dbghelp)
else()
- list(APPEND onnxruntime_EXTERNAL_LIBRARIES nsync_cpp)
+ list(APPEND onnxruntime_EXTERNAL_LIBRARIES nsync::nsync_cpp)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${CMAKE_DL_LIBS} Threads::Threads)
endif()
diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake
index b829e0917363..337f4ce20482 100644
--- a/cmake/external/onnxruntime_external_deps.cmake
+++ b/cmake/external/onnxruntime_external_deps.cmake
@@ -236,7 +236,10 @@ if (NOT WIN32)
#nsync tests failed on Mac Build
set(NSYNC_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
onnxruntime_fetchcontent_makeavailable(google_nsync)
- set(nsync_SOURCE_DIR ${google_nsync_SOURCE_DIR})
+ if (google_nsync_SOURCE_DIR)
+ add_library(nsync::nsync_cpp ALIAS nsync_cpp)
+ target_include_directories(nsync_cpp PUBLIC ${google_nsync_SOURCE_DIR}/public)
+ endif()
endif()
if(onnxruntime_USE_CUDA)
@@ -360,6 +363,9 @@ FetchContent_Declare(
if (CPUINFO_SUPPORTED)
onnxruntime_fetchcontent_makeavailable(pytorch_cpuinfo)
+ if (pytorch_cpuinfo_SOURCE_DIR)
+ add_library(cpuinfo::cpuinfo ALIAS cpuinfo)
+ endif()
endif()
diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake
index 0e02ad9daa7e..685df7a48769 100644
--- a/cmake/onnxruntime_common.cmake
+++ b/cmake/onnxruntime_common.cmake
@@ -194,8 +194,8 @@ if (ARM64 OR ARM OR X86 OR X64 OR X86_64)
# Using it mainly in ARM with Android.
# Its functionality in detecting x86 cpu features is lacking, as is its support for Windows.
if (CPUINFO_SUPPORTED)
- onnxruntime_add_include_to_target(onnxruntime_common cpuinfo)
- list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo clog)
+ onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo)
+ list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo)
endif()
endif()
endif()
diff --git a/cmake/onnxruntime_flatbuffers.cmake b/cmake/onnxruntime_flatbuffers.cmake
index c0cd1699bb33..3ab4c19122ba 100644
--- a/cmake/onnxruntime_flatbuffers.cmake
+++ b/cmake/onnxruntime_flatbuffers.cmake
@@ -9,7 +9,7 @@ file(GLOB onnxruntime_flatbuffers_srcs CONFIGURE_DEPENDS
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_flatbuffers_srcs})
onnxruntime_add_static_library(onnxruntime_flatbuffers ${onnxruntime_flatbuffers_srcs})
-onnxruntime_add_include_to_target(onnxruntime_flatbuffers onnx flatbuffers ${GSL_TARGET})
+onnxruntime_add_include_to_target(onnxruntime_flatbuffers onnx flatbuffers::flatbuffers ${GSL_TARGET})
if(onnxruntime_ENABLE_INSTRUMENT)
target_compile_definitions(onnxruntime_flatbuffers PUBLIC ONNXRUNTIME_ENABLE_INSTRUMENT)
endif()
diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index 452985023560..53508e64d04a 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -548,10 +548,10 @@ if (onnxruntime_USE_CUDA)
if(APPLE)
set_property(TARGET onnxruntime_providers_cuda APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/cuda/exported_symbols.lst")
- target_link_libraries(onnxruntime_providers_cuda PRIVATE nsync_cpp)
+ target_link_libraries(onnxruntime_providers_cuda PRIVATE nsync::nsync_cpp)
elseif(UNIX)
set_property(TARGET onnxruntime_providers_cuda APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/cuda/version_script.lds -Xlinker --gc-sections")
- target_link_libraries(onnxruntime_providers_cuda PRIVATE nsync_cpp)
+ target_link_libraries(onnxruntime_providers_cuda PRIVATE nsync::nsync_cpp)
elseif(WIN32)
set_property(TARGET onnxruntime_providers_cuda APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/cuda/symbols.def")
else()
@@ -609,10 +609,10 @@ if (onnxruntime_USE_DNNL)
INSTALL_RPATH "@loader_path"
BUILD_WITH_INSTALL_RPATH TRUE
INSTALL_RPATH_USE_LINK_PATH FALSE)
- target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync_cpp)
+ target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync::nsync_cpp)
elseif(UNIX)
set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/dnnl/version_script.lds -Xlinker --gc-sections -Xlinker -rpath=\$ORIGIN")
- target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync_cpp)
+ target_link_libraries(onnxruntime_providers_dnnl PRIVATE nsync::nsync_cpp)
elseif(WIN32)
set_property(TARGET onnxruntime_providers_dnnl APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/dnnl/symbols.def")
else()
@@ -742,11 +742,11 @@ if (onnxruntime_USE_TENSORRT)
if(APPLE)
set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker -exported_symbols_list ${ONNXRUNTIME_ROOT}/core/providers/tensorrt/exported_symbols.lst")
- target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync_cpp)
+ target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync::nsync_cpp)
elseif(UNIX)
set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations")
set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/tensorrt/version_script.lds -Xlinker --gc-sections")
- target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync_cpp stdc++fs)
+ target_link_libraries(onnxruntime_providers_tensorrt PRIVATE nsync::nsync_cpp stdc++fs)
elseif(WIN32)
set_property(TARGET onnxruntime_providers_tensorrt APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/tensorrt/symbols.def")
else()
@@ -1091,7 +1091,7 @@ if (onnxruntime_USE_QNN)
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_qnn_cc_srcs})
onnxruntime_add_static_library(onnxruntime_providers_qnn ${onnxruntime_providers_qnn_cc_srcs})
- onnxruntime_add_include_to_target(onnxruntime_providers_qnn onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers Boost::mp11)
+ onnxruntime_add_include_to_target(onnxruntime_providers_qnn onnxruntime_common onnxruntime_framework onnx onnx_proto protobuf::libprotobuf-lite flatbuffers::flatbuffers Boost::mp11)
target_link_libraries(onnxruntime_providers_qnn)
add_dependencies(onnxruntime_providers_qnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
set_target_properties(onnxruntime_providers_qnn PROPERTIES CXX_STANDARD_REQUIRED ON)
@@ -1286,7 +1286,7 @@ if (onnxruntime_USE_MIGRAPHX)
target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare)
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations")
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections")
- target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync_cpp stdc++fs)
+ target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync::nsync_cpp stdc++fs)
include(CheckLibraryExists)
check_library_exists(migraphx::c "migraphx_program_run_async" "/opt/rocm/migraphx/lib" HAS_STREAM_SYNC)
@@ -1552,7 +1552,7 @@ if (onnxruntime_USE_ROCM)
if(UNIX)
set_property(TARGET onnxruntime_providers_rocm APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/rocm/version_script.lds -Xlinker --gc-sections")
- target_link_libraries(onnxruntime_providers_rocm PRIVATE nsync_cpp)
+ target_link_libraries(onnxruntime_providers_rocm PRIVATE nsync::nsync_cpp)
else()
message(FATAL_ERROR "onnxruntime_providers_rocm unknown platform, need to specify shared library exports for it")
endif()
@@ -1688,7 +1688,7 @@ if (onnxruntime_USE_CANN)
onnxruntime_add_include_to_target(onnxruntime_providers_cann onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface)
add_dependencies(onnxruntime_providers_cann onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
- target_link_libraries(onnxruntime_providers_cann PRIVATE ascendcl acl_op_compiler fmk_onnx_parser nsync_cpp ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED})
+ target_link_libraries(onnxruntime_providers_cann PRIVATE ascendcl acl_op_compiler fmk_onnx_parser nsync::nsync_cpp ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED})
target_link_directories(onnxruntime_providers_cann PRIVATE ${onnxruntime_CANN_HOME}/lib64)
target_include_directories(onnxruntime_providers_cann PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${onnxruntime_CANN_HOME} ${onnxruntime_CANN_HOME}/include)
@@ -1710,7 +1710,7 @@ if (onnxruntime_USE_AZURE)
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_azure_src})
onnxruntime_add_static_library(onnxruntime_providers_azure ${onnxruntime_providers_azure_src})
add_dependencies(onnxruntime_providers_azure ${onnxruntime_EXTERNAL_DEPENDENCIES})
- onnxruntime_add_include_to_target(onnxruntime_providers_azure onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11)
+ onnxruntime_add_include_to_target(onnxruntime_providers_azure onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11)
target_link_libraries(onnxruntime_providers_azure PRIVATE onnx onnxruntime_common onnxruntime_framework)
set_target_properties(onnxruntime_providers_azure PROPERTIES FOLDER "ONNXRuntime")
set_target_properties(onnxruntime_providers_azure PROPERTIES LINKER_LANGUAGE CXX)
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 930db84ef42d..4f1cccc7fd49 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -668,8 +668,8 @@ if(MSVC)
"$<$>:/wd6326>")
else()
target_compile_definitions(onnxruntime_test_utils PUBLIC -DNSYNC_ATOMIC_CPP11)
- target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}
- ${nsync_SOURCE_DIR}/public)
+ target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT})
+ onnxruntime_add_include_to_target(onnxruntime_test_utils nsync::nsync_cpp)
endif()
if (onnxruntime_USE_NCCL)
target_include_directories(onnxruntime_test_utils PRIVATE ${NCCL_INCLUDE_DIRS})
@@ -702,8 +702,8 @@ if(MSVC)
"$<$>:/utf-8>")
else()
target_compile_definitions(onnx_test_runner_common PUBLIC -DNSYNC_ATOMIC_CPP11)
- target_include_directories(onnx_test_runner_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}
- ${nsync_SOURCE_DIR}/public)
+ target_include_directories(onnx_test_runner_common PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT})
+ onnxruntime_add_include_to_target(onnx_test_runner_common nsync::nsync_cpp)
endif()
if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
#TODO: fix the warnings, they are dangerous
@@ -1070,7 +1070,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
# "Global initializer calls a non-constexpr function." BENCHMARK_CAPTURE macro needs this.
target_compile_options(onnxruntime_mlas_benchmark PRIVATE /wd26426)
else()
- target_link_libraries(onnxruntime_mlas_benchmark PRIVATE nsync_cpp ${CMAKE_DL_LIBS})
+ target_link_libraries(onnxruntime_mlas_benchmark PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS})
endif()
if (CPUINFO_SUPPORTED AND NOT onnxruntime_BUILD_WEBASSEMBLY)
target_link_libraries(onnxruntime_mlas_benchmark PRIVATE cpuinfo)
@@ -1128,7 +1128,7 @@ if(onnxruntime_ENABLE_EAGER_MODE)
list(APPEND onnxruntime_eager_mode_libs onnxruntime_training tensorboard)
endif()
IF(NOT WIN32)
- list(APPEND onnxruntime_eager_mode_libs nsync_cpp)
+ list(APPEND onnxruntime_eager_mode_libs nsync::nsync_cpp)
endif()
target_link_libraries(onnxruntime_eager_mode_test PRIVATE ${onnxruntime_eager_mode_libs} Threads::Threads ${onnxruntime_EXTERNAL_LIBRARIES})
if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
@@ -1188,7 +1188,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
${onnxruntime_EXTERNAL_LIBRARIES}
${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS})
if(NOT WIN32)
- list(APPEND onnxruntime_perf_test_libs nsync_cpp)
+ list(APPEND onnxruntime_perf_test_libs nsync::nsync_cpp)
if(onnxruntime_USE_SNPE)
list(APPEND onnxruntime_perf_test_libs onnxruntime_providers_snpe)
endif()
@@ -1232,7 +1232,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
# test inference using shared lib
set(onnxruntime_shared_lib_test_LIBS onnxruntime_mocked_allocator onnxruntime_test_utils onnxruntime_common onnx_proto)
if(NOT WIN32)
- list(APPEND onnxruntime_shared_lib_test_LIBS nsync_cpp)
+ list(APPEND onnxruntime_shared_lib_test_LIBS nsync::nsync_cpp)
if(onnxruntime_USE_SNPE)
list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_providers_snpe)
endif()
@@ -1354,7 +1354,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo)
endif()
if(NOT WIN32)
- target_link_libraries(onnxruntime_mlas_test PRIVATE nsync_cpp ${CMAKE_DL_LIBS})
+ target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS})
endif()
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs})
@@ -1546,7 +1546,7 @@ endif()
if (NOT onnxruntime_BUILD_WEBASSEMBLY AND (NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_MINIMAL_BUILD_CUSTOM_OPS))
- file(GLOB_RECURSE custom_op_get_const_input_test_library_src
+ file(GLOB_RECURSE custom_op_get_const_input_test_library_src
"${TEST_SRC_DIR}/testdata/custom_op_get_const_input_test_library/custom_op_lib.cc"
"${TEST_SRC_DIR}/testdata/custom_op_get_const_input_test_library/custom_op.h"
"${TEST_SRC_DIR}/testdata/custom_op_get_const_input_test_library/custom_op.cc"
@@ -1562,7 +1562,7 @@ if (NOT onnxruntime_BUILD_WEBASSEMBLY AND (NOT onnxruntime_MINIMAL_BUILD OR onnx
if (APPLE)
set(ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG "-Xlinker -dead_strip")
else()
- string(CONCAT ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG
+ string(CONCAT ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG
"-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_get_const_input_test_library/custom_op_lib.lds "
"-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack")
endif()
@@ -1582,7 +1582,7 @@ if (onnxruntime_BUILD_SHARED_LIB AND NOT onnxruntime_BUILD_WEBASSEMBLY AND NOT o
set(onnxruntime_logging_apis_test_LIBS onnxruntime_common onnxruntime_test_utils)
if(NOT WIN32)
- list(APPEND onnxruntime_logging_apis_test_LIBS nsync_cpp ${CMAKE_DL_LIBS})
+ list(APPEND onnxruntime_logging_apis_test_LIBS nsync::nsync_cpp ${CMAKE_DL_LIBS})
endif()
AddTest(DYN
diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake
index 886668d6b46b..2188565a876d 100644
--- a/cmake/onnxruntime_webassembly.cmake
+++ b/cmake/onnxruntime_webassembly.cmake
@@ -97,7 +97,7 @@ target_compile_options(onnx PRIVATE -Wno-unused-parameter -Wno-unused-variable)
if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB)
bundle_static_library(onnxruntime_webassembly
- nsync_cpp
+ nsync::nsync_cpp
${PROTOBUF_LIB}
onnx
onnx_proto
@@ -172,7 +172,7 @@ else()
endif()
target_link_libraries(onnxruntime_webassembly PRIVATE
- nsync_cpp
+ nsync::nsync_cpp
${PROTOBUF_LIB}
onnx
onnx_proto
From 44027797b0592d9d96be4d437c3bd6c81a72918e Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Mon, 3 Apr 2023 17:51:42 -0700
Subject: [PATCH 4/4] [QNN EP] Gather support int64 indices input (#15317)
### Description
Gather support int64 indices input
### Motivation and Context
Support more scenario
---
.../builder/opbuilder/gather_op_builder.cc | 68 +++++++------
onnxruntime/test/optimizer/qdq_test_utils.h | 74 ++++++++++++++
.../test/providers/qnn/gather_op_htp_test.cc | 96 +++++++++++++++++++
3 files changed, 210 insertions(+), 28 deletions(-)
create mode 100644 onnxruntime/test/providers/qnn/gather_op_htp_test.cc
diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc
index 79f3435eaa19..44131eca2f88 100644
--- a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc
+++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc
@@ -43,6 +43,7 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
bool do_op_validation) const {
ORT_UNUSED_PARAMETER(do_op_validation);
const auto& inputs = node_unit.Inputs();
+  ORT_RETURN_IF(inputs.size() != 2, "Gather must have exactly 2 inputs.");
ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[0], logger, is_quantized_model, input_names));
// Process indices
@@ -53,54 +54,66 @@ Status GatherOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
return Status::OK();
}
- Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT;
+ std::string indices_input_name(input_name);
Qnn_DataType_t qnn_data_type = QNN_DATATYPE_INT_32;
const auto* type_proto = inputs[1].node_arg.TypeAsProto();
- ORT_RETURN_IF_ERROR(GetQnnDataType(is_quantized_model, type_proto, qnn_data_type));
+ ORT_RETURN_IF_ERROR(GetQnnDataType(false, type_proto, qnn_data_type));
  std::vector<uint8_t> unpacked_tensor;
  std::vector<uint8_t> gather_indices;
bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name);
+
+ ORT_RETURN_IF(is_quantized_model && qnn_data_type == QNN_DATATYPE_INT_64 && !is_initializer_input,
+                "HTP backend doesn't support dynamic (non-initializer) int64 indices.");
+
if (is_initializer_input) {
const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name);
ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(*input_tensor, unpacked_tensor));
- }
-
- // For Quantized model, Gather indices use int32 without quantization
- if (is_quantized_model) {
if (qnn_data_type == QNN_DATATYPE_INT_64) {
- if (!is_initializer_input) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Gather indices only support int32 type on Qnn NPU.");
- } else {
- // Convert initializer from int64 to int32
- size_t size = unpacked_tensor.size() / sizeof(int64_t);
-        const int64_t* gather_indices_int64 = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
-        gather_indices.resize(size * sizeof(int32_t));
-        int32_t* gather_indices_int32 = reinterpret_cast<int32_t*>(gather_indices.data());
-        std::transform(gather_indices_int64, gather_indices_int64 + size, gather_indices_int32,
-                       [](int64_t item) { return SafeInt<int32_t>(item); });
- qnn_data_type = QNN_DATATYPE_INT_32;
- }
+ // Convert initializer from int64 to int32
+ size_t size = unpacked_tensor.size() / sizeof(int64_t);
+      const int64_t* gather_indices_int64 = reinterpret_cast<const int64_t*>(unpacked_tensor.data());
+      gather_indices.resize(size * sizeof(int32_t));
+      int32_t* gather_indices_int32 = reinterpret_cast<int32_t*>(gather_indices.data());
+      std::transform(gather_indices_int64, gather_indices_int64 + size, gather_indices_int32,
+                     [](int64_t item) { return SafeInt<int32_t>(item); });
} else {
- qnn_data_type = QNN_DATATYPE_INT_32;
gather_indices = std::move(unpacked_tensor);
}
- InitializeQuantizeParam(quantize_param, false);
- ORT_RETURN_IF_NOT(qnn_model_wrapper.ProcessQuantizationParameter(inputs[1].quant_param,
- quantize_param.scaleOffsetEncoding.scale,
- quantize_param.scaleOffsetEncoding.offset),
- "Cannot get quantization parameter");
- } else {
- gather_indices = std::move(unpacked_tensor);
+ qnn_data_type = QNN_DATATYPE_INT_32;
}
+  // Even for a quantized model, Gather indices use int32 without quantization
+ Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT;
+
Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_name);
  std::vector<uint32_t> input_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, input_shape), "Cannot get shape");
+  std::vector<uint32_t> cast_output_shape(input_shape);
QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, quantize_param,
std::move(input_shape), std::move(gather_indices));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor.");
- input_names.push_back(input_name);
+
+  if (!is_initializer_input && qnn_data_type == QNN_DATATYPE_INT_64) {
+    // Dynamic int64 indices: insert a Cast node (int64 -> int32) and feed
+    // Gather from the cast output instead of the original input.
+    indices_input_name = input_name + "_cast";
+    QnnTensorWrapper cast_output(indices_input_name, QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_INT_32, quantize_param,
+                                 std::move(cast_output_shape));
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output)), "Failed to add tensor.");
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_input_name,
+                                                      qnn_def::package_name,
+                                                      "Cast",
+                                                      {input_name},
+                                                      {indices_input_name},
+                                                      {},
+                                                      do_op_validation),
+                      "Failed to add node.");
+  }
+
+ input_names.push_back(indices_input_name);
return Status::OK();
}
@@ -121,7 +134,6 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
qnn_model_wrapper.AddParamWrapper(std::move(axis_param));
// if indicies is scalar shape, then need to add Reshape node
- ORT_ENFORCE(input_names.size() == 2, "Gather should has 2 inputs at least!");
const auto& input_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[0]);
const auto& indices_input_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(input_names[1]);
diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h
index 2b2c8074000a..bd8d7707902b 100644
--- a/onnxruntime/test/optimizer/qdq_test_utils.h
+++ b/onnxruntime/test/optimizer/qdq_test_utils.h
@@ -195,6 +195,80 @@ GetQDQTestCaseFn BuildQDQReduceOpTestCase(const std::string& reduce_op_type, con
};
}
+// Creates the following graph:
+//                                  _______________________
+//    input (f32) -> Q -> DQ ->    |                       | -> Q -> DQ -> output (f32)
+//    indices (initializer) ---->  |        Gather         |
+//                                 |_______________________|
+//
+template <typename QuantType, typename IndicesType>
+GetQDQTestCaseFn BuildQDQGatherOpTestCase(const std::vector<int64_t>& input_shape,
+                                          const std::vector<IndicesType> indices,
+                                          const std::vector<int64_t>& indices_shape,
+                                          int64_t axis) {
+ return [input_shape, indices, indices_shape, axis](ModelTestBuilder& builder) {
+    auto* input_data = builder.MakeInput<float>(input_shape, -1.0f, 1.0f);
+ auto* final_output = builder.MakeOutput();
+
+ // input_data -> Q/DQ ->
+    auto* input_qdq_output = AddQDQNodePair<QuantType>(builder, input_data, .003f, 1);
+
+    std::vector<NodeArg*> gather_op_inputs;
+ gather_op_inputs.push_back(input_qdq_output);
+
+ auto* indices_input = builder.MakeInitializer(indices_shape, indices);
+
+ auto* gather_output = builder.MakeIntermediate();
+ Node& gather_node = builder.AddNode("Gather", {input_qdq_output, indices_input}, {gather_output});
+ gather_node.AddAttribute("axis", axis);
+
+ // -> Q/DQ -> final_output
+ auto* q_output = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode<QuantType>(gather_output, .003f, 1, q_output);
+
+    builder.AddDequantizeLinearNode<QuantType>(q_output, .003f, 1, final_output);
+ };
+}
+
+// Creates the following graph:
+//                                     _______________________
+//    input (f32) -> Q -> DQ ->       |                       | -> Q -> DQ -> output (f32)
+//    indices (scalar initializer) -> |        Gather         |
+//                                    |_______________________|
+//
+template <typename QuantType, typename IndicesType>
+GetQDQTestCaseFn BuildQDQGatherOpScalarIndicesTestCase(const std::vector<int64_t>& input_shape,
+                                                       const IndicesType indices,
+                                                       int64_t axis) {
+ return [input_shape, indices, axis](ModelTestBuilder& builder) {
+    auto* input_data = builder.MakeInput<float>(input_shape, -1.0f, 1.0f);
+ auto* final_output = builder.MakeOutput();
+
+ // input_data -> Q/DQ ->
+    auto* input_qdq_output = AddQDQNodePair<QuantType>(builder, input_data, .003f, 1);
+
+    std::vector<NodeArg*> gather_op_inputs;
+ gather_op_inputs.push_back(input_qdq_output);
+
+ auto* indices_input = builder.MakeScalarInitializer(indices);
+
+ auto* gather_output = builder.MakeIntermediate();
+ Node& gather_node = builder.AddNode("Gather", {input_qdq_output, indices_input}, {gather_output});
+ gather_node.AddAttribute("axis", axis);
+
+ // -> Q/DQ -> final_output
+ auto* q_output = builder.MakeIntermediate();
+    builder.AddQuantizeLinearNode<QuantType>(gather_output, .003f, 1, q_output);
+
+    builder.AddDequantizeLinearNode<QuantType>(q_output, .003f, 1, final_output);
+ };
+}
+
template <typename InputType, typename WeightType, typename BiasType, typename OutputType>
GetQDQTestCaseFn BuildQDQConvTestCase(const std::vector<int64_t>& input_shape, const std::vector<int64_t>& weights_shape) {
return [input_shape, weights_shape](ModelTestBuilder& builder) {
diff --git a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc
new file mode 100644
index 000000000000..09d98fab0e95
--- /dev/null
+++ b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc
@@ -0,0 +1,96 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#if !defined(ORT_MINIMAL_BUILD)
+
+#include <string>
+#include "core/graph/graph.h"
+
+#include "test/optimizer/qdq_test_utils.h"
+#include "test/providers/qnn/qnn_test_utils.h"
+
+#include "gtest/gtest.h"
+
+namespace onnxruntime {
+namespace test {
+#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+
+/**
+ * Runs a Gather op model on the QNN HTP backend. Checks the graph node assignment, and that inference
+ * outputs for QNN and CPU match.
+ *
+ * \param opset The opset version to run with.
+ * \param test_description Description of the test for error reporting.
+ * \param scalar_indices Set to true to use a scalar (rank-0) indices initializer instead of a 1-D tensor.
+ * \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None)
+ */
+template <typename QuantType>
+static void RunGatherOpQDQTest(int opset, const char* test_description, bool scalar_indices = false,
+ ExpectedEPNodeAssignment expected_ep_assignment = ExpectedEPNodeAssignment::All) {
+ ProviderOptions provider_options;
+#if defined(_WIN32)
+ provider_options["backend_path"] = "QnnHtp.dll";
+#else
+ provider_options["backend_path"] = "libQnnHtp.so";
+#endif
+
+ constexpr int expected_nodes_in_partition = 1;
+ if (scalar_indices) {
+    RunQnnModelTest(BuildQDQGatherOpScalarIndicesTestCase<QuantType, int32_t>({2, 3, 4},  // input shape
+                                                                              1,          // indices
+                                                                              1),         // axis
+ provider_options,
+ opset,
+ expected_ep_assignment,
+ expected_nodes_in_partition,
+ test_description);
+ } else {
+    RunQnnModelTest(BuildQDQGatherOpTestCase<QuantType, int32_t>({2, 3, 4},                // input shape
+                                                                 std::vector<int32_t>{1},  // indices
+                                                                 {1},                      // indices_shape
+                                                                 1),                       // axis
+ provider_options,
+ opset,
+ expected_ep_assignment,
+ expected_nodes_in_partition,
+ test_description);
+ }
+}
+
+// Test creates a Q -> DQ -> Gather -> Q -> DQ graph, and checks that all
+// nodes are supported by the QNN EP, and that the inference results match the CPU EP results.
+//
+// - Uses uint8 as the quantization type.
+TEST_F(QnnHTPBackendTests, TestQDQGatherOpU8) {
+  RunGatherOpQDQTest<uint8_t>(11, "TestQDQGatherOpU8");
+}
+
+// Test creates a Q -> DQ -> Gather -> Q -> DQ graph, and checks that all
+// nodes are supported by the QNN EP, and that the inference results match the CPU EP results.
+//
+// - Uses int8 as the quantization type.
+TEST_F(QnnHTPBackendTests, TestQDQGatherOpI8) {
+  RunGatherOpQDQTest<int8_t>(11, "TestQDQGatherOpI8");
+}
+
+// Test creates a Q -> DQ -> Gather -> Q -> DQ graph, and checks that all
+// nodes are supported by the QNN EP, and that the inference results match the CPU EP results.
+//
+// - Uses uint8 as the quantization type.
+TEST_F(QnnHTPBackendTests, TestQDQGatherOpScalarIndicesU8) {
+  RunGatherOpQDQTest<uint8_t>(11, "TestQDQGatherOpScalarIndicesU8", true);
+}
+
+// Test creates a Q -> DQ -> Gather -> Q -> DQ graph, and checks that all
+// nodes are supported by the QNN EP, and that the inference results match the CPU EP results.
+//
+// - Uses int8 as the quantization type.
+TEST_F(QnnHTPBackendTests, TestQDQGatherOpScalarIndicesI8) {
+  RunGatherOpQDQTest<int8_t>(11, "TestQDQGatherOpScalarIndicesI8", true);
+}
+
+#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+} // namespace test
+} // namespace onnxruntime
+
+#endif
\ No newline at end of file