From 0c5b95fc86750526d09ee9e669a98506116c6bde Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 8 Nov 2023 13:53:56 -0800 Subject: [PATCH] Cherry-pick LLaMA GQA mask to rel-1.16.2 (round 4) (#18350) Cherry-pick LLaMA GQA attention mask and script changes to 1.16.2 release branch. --------- Co-authored-by: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Co-authored-by: Yufeng Li Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> --- docs/ContribOperators.md | 26 +- docs/OperatorKernels.md | 2 +- .../contrib_ops/cpu/bert/attention_common.h | 11 +- .../contrib_ops/cuda/bert/attention_impl.cu | 1 + .../bert/cutlass_fmha/fmha_launch_template.h | 133 +++- .../cutlass_fmha/memory_efficient_attention.h | 2 + .../cuda/bert/group_query_attention.cc | 74 +- .../cuda/bert/group_query_attention.h | 2 +- .../cuda/bert/group_query_attention_helper.h | 99 ++- .../cuda/bert/group_query_attention_impl.cu | 326 ++++++--- .../cuda/bert/group_query_attention_impl.h | 3 +- .../cuda/bert/packed_attention_impl.cu | 1 + .../bert/packed_multihead_attention_impl.cu | 1 + .../core/graph/contrib_ops/bert_defs.cc | 38 +- .../tools/transformers/convert_generation.py | 119 ++- .../tools/transformers/models/llama/README.md | 97 ++- .../transformers/models/llama/benchmark.py | 22 +- .../models/llama/convert_to_onnx.py | 43 +- .../transformers/models/llama/llama_inputs.py | 47 +- .../transformers/models/llama/llama_parity.py | 29 +- .../transformers/models/llama/llama_torch.py | 6 +- .../python/transformers/test_flash_attn.py | 687 ++++++++++++++---- 22 files changed, 1306 insertions(+), 463 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 43816c3db813..7778a4d3696b 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2236,19 +2236,15 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
-
is_past_bsnh : int
-
Whether past kv uses BSNH, otherwise BNSH. Default value is 1 (BSNH).
kv_num_heads : int (required)
Number of attention heads for k and v
num_heads : int (required)
Number of attention heads for q
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
-
unidirectional : int
-
Whether every token can only attend to previous tokens. Default value is 1.
-#### Inputs (3 - 6) +#### Inputs
query : T
@@ -2258,11 +2254,13 @@ This version of the operator has been available since version 1 of the 'com.micr
value : T
Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
past_key (optional) : T
-
past state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
+
past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
past_value (optional) : T
-
past state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
-
past_sequence_length (optional) : M
-
When buffered past_key and past_value is used (present_key uses same tensor as past_key), requiredto specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.
+
past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
+
seqlens_k : M
+
1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
+
total_sequence_length : M
+
Scalar tensor of total sequence length (past + new).
#### Outputs @@ -2271,9 +2269,9 @@ This version of the operator has been available since version 1 of the 'com.micr
output : T
3D output tensor with shape (batch_size, sequence_length, hidden_size)
present_key : T
-
present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
+
present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
present_value : T
-
present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
+
present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
#### Type Constraints @@ -2281,8 +2279,8 @@ This version of the operator has been available since version 1 of the 'com.micr
T : tensor(float16)
Constrain input and output to float tensors.
-
M : tensor(int32), tensor(int64)
-
Constrain past sequence length to int tensor.
+
M : tensor(int32)
+
Constrain mask to int tensor.
@@ -4766,7 +4764,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.RotaryEmbedding** - RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices + RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices that are multiplied to query and key before the inner product of query and key is taken. #### Version diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 7802a6853ac7..f4142adc07bb 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -843,7 +843,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32), tensor(int64)
**T** = tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 0fd8790e0d29..b693b58c7c40 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -86,11 +86,9 @@ struct PackedAttentionParameters { // Parameters deduced from node attributes and inputs/outputs. struct GroupQueryAttentionParameters { int batch_size; - int sequence_length; - int past_sequence_length; // actual sequence length of past_key and past_value - int kv_sequence_length; // sequence length of key and value (or new_k and new_v when past is present) - int present_sequence_length; // past_sequence_length + kv_sequence_length - int max_sequence_length; // allocated length of past_key and past_value + int sequence_length; // sequence length of input query, key, value + int seqlen_past_kv_cache; // sequence length of past kv tensor + int seqlen_present_kv_cache; // sequence length of present kv tensor int hidden_size; int num_heads; int head_size; @@ -98,6 +96,9 @@ struct GroupQueryAttentionParameters { int kv_num_heads; int num_splits; // number of splits for splitkv bool is_unidirectional; // causal + bool kv_share_buffer; + bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool left_padding; // copies last token to last index if true float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 16ce3a899fb5..83c426e7e6ed 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -401,6 +401,7 @@ Status EfficientAttention( ? data.scratch : nullptr; p.stream = stream; + p.has_custom_right_padding = false; run_memory_efficient_attention(p); DUMP_TENSOR("efficient attention output", data.output, parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 51c3d3d3a458..db78722cc0e4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -16,6 +16,133 @@ namespace onnxruntime { namespace contrib { namespace cuda { +template +struct RightPaddingBatchHook { + using scalar_t = typename AttentionKernel::scalar_t; + using accum_t = typename AttentionKernel::accum_t; + using lse_scalar_t = typename AttentionKernel::lse_scalar_t; + using output_t = typename AttentionKernel::output_t; + using output_accum_t = typename AttentionKernel::output_accum_t; + + static constexpr bool kSupportsDropout = AttentionKernel::kSupportsDropout; + static constexpr bool kSupportsBias = AttentionKernel::kSupportsBias; + static constexpr int kKeysPerBlock = AttentionKernel::kKeysPerBlock; + static constexpr bool kIsAligned = AttentionKernel::kIsAligned; + static constexpr bool kSingleValueIteration = AttentionKernel::kSingleValueIteration; + static constexpr int32_t kAlignLSE = AttentionKernel::kAlignLSE; // block size of backward + static constexpr bool kPreloadV = AttentionKernel::kPreloadV; + static constexpr bool kKeepOutputInRF = AttentionKernel::kKeepOutputInRF; + static constexpr bool kNeedsOutputAccumulatorBuffer = AttentionKernel::kNeedsOutputAccumulatorBuffer; + + template + static CUTLASS_DEVICE bool AdvanceToBlockForGQA(Params& p) { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE; + + // Advance to current batch - in case of different sequence lengths + if (p.seqlen_k_ptr) { + p.num_keys = p.seqlen_k_ptr[batch_id]; + } + + if (query_start >= p.num_queries) { + return false; + } + + // Advance to the current batch / head / query_start + p.query_ptr += batch_id * p.q_strideB + query_start * p.q_strideM + head_id * p.q_strideH; + p.key_ptr += batch_id * p.k_strideB + head_id * p.k_strideH; + p.value_ptr += batch_id * p.v_strideB + head_id * p.v_strideH; + p.output_ptr += int64_t(batch_id * p.num_queries) * p.o_strideM + int64_t(query_start) * p.o_strideM + head_id * p.head_dim_value; + + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + p.attn_bias_ptr += (batch_id * p.bias_strideB) + (head_id * p.bias_strideH); + } + if (p.output_accum_ptr != nullptr) { + p.output_accum_ptr += int64_t(batch_id * p.num_queries) * (p.head_dim_value * p.num_heads) + + int64_t(query_start) * (p.head_dim_value * p.num_heads) + + head_id * p.head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + p.output_accum_ptr = (accum_t*)(p.output_ptr); + } + + if (p.logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + p.logsumexp_ptr += + batch_id * lse_dim * p.num_heads + head_id * lse_dim + query_start; + } + + // Custom masking + if (p.causal_diagonal_ptr) { + p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id]; + } + if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) { + p.causal_diagonal_offset += p.num_keys - p.num_queries; + } + if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft || + p.custom_mask_type == AttentionKernel::CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys and + // this to avoid extra computations + p.num_keys = cutlass::fast_min( + int32_t(query_start + p.causal_diagonal_offset + kQueriesPerBlock), + p.num_keys); + } + + p.num_queries -= query_start; + p.num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (p.num_queries == 1 && p.k_strideH == 0 && p.v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) + return false; + p.q_strideM = p.q_strideH; + p.num_queries = p.num_heads; + p.num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + p.custom_mask_type = AttentionKernel::NoCustomMask; + p.o_strideM = p.head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + p.query_ptr = warp_uniform(p.query_ptr); + p.key_ptr = warp_uniform(p.key_ptr); + p.value_ptr = warp_uniform(p.value_ptr); + if (kSupportsBias) { + p.attn_bias_ptr = warp_uniform(p.attn_bias_ptr); + } + p.output_ptr = warp_uniform(p.output_ptr); + p.output_accum_ptr = warp_uniform(p.output_accum_ptr); + p.logsumexp_ptr = warp_uniform(p.logsumexp_ptr); + p.num_queries = warp_uniform(p.num_queries); + p.num_keys = warp_uniform(p.num_keys); + p.num_heads = warp_uniform(p.num_heads); + p.head_dim = warp_uniform(p.head_dim); + p.head_dim_value = warp_uniform(p.head_dim_value); + p.o_strideM = warp_uniform(p.o_strideM); + p.custom_mask_type = warp_uniform(p.custom_mask_type); + return true; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl_right_padding(typename AK::Params p) { + if (!RightPaddingBatchHook::AdvanceToBlockForGQA(p)) { + return; + } + AK::attention_kernel(p); +} + template void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { using Attention = AttentionKernel; @@ -92,7 +219,11 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { } } - constexpr auto kernel_fn = attention_kernel_batched_impl; + auto kernel_fn = attention_kernel_batched_impl; + if (params.has_custom_right_padding) { + kernel_fn = attention_kernel_batched_impl_right_padding; + } + int smem_bytes = sizeof(typename Attention::SharedStorage); if (smem_bytes > 0xc000) { ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!"); diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index f16567bb6f2b..484b783db172 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -43,6 +43,8 @@ struct MemoryEfficientAttentionParams { static bool need_workspace(size_t v_head_size, bool is_float) { return (v_head_size > 128 && !is_float); } + + bool has_custom_right_padding = false; }; void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 8694dc998c7a..f21dff08e035 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -17,19 +17,19 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GroupQueryAttention, \ - kMSDomain, \ - 1, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("M", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) \ - .MayInplace(3, 1) \ - .MayInplace(4, 2) \ - .InputMemoryType(OrtMemTypeCPUInput, 5), \ +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", {DataTypeImpl::GetTensorType()}) \ + .MayInplace(3, 1) \ + .MayInplace(4, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 6), \ GroupQueryAttention); // REGISTER_KERNEL_TYPED(float) @@ -44,8 +44,9 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); num_heads_ = static_cast(num_heads); kv_num_heads_ = static_cast(kv_num_heads); - is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 1) == 1; - is_past_bsnh_ = info.GetAttrOrDefault("is_past_bsnh", 1) == 1; + is_unidirectional_ = true; + // left_padding_ = info.GetAttrOrDefault("left_padding_last_token", 0) == 1; + is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION @@ -70,7 +71,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* value = context->Input(2); const Tensor* past_key = context->Input(3); const Tensor* past_value = context->Input(4); - const Tensor* past_seq_len = context->Input(5); + const Tensor* seqlens_k = context->Input(5); + const Tensor* total_seqlen = context->Input(6); auto& device_prop = GetDeviceProp(); GroupQueryAttentionParameters parameters; @@ -85,11 +87,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { ¶meters, num_heads_, kv_num_heads_, - past_seq_len, + seqlens_k, + total_seqlen, is_past_bsnh_, scale_, device_prop.maxThreadsPerBlock)); parameters.is_unidirectional = is_unidirectional_; + // parameters.left_padding = left_padding_; int sequence_length = parameters.sequence_length; TensorShapeVector output_shape(3); @@ -108,33 +112,26 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { size_t softmax_lse_bytes = 0; size_t softmax_lse_accum_bytes = 0; size_t out_accum_bytes = 0; - size_t seqlens_k_bytes = 0; if (use_flash_attention) { // softmax buffer softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); // split kv buffer using namespace std; auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( - parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.batch_size, parameters.sequence_length, parameters.sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); parameters.num_splits = num_splits; softmax_lse_accum_bytes = slse_accum_bytes; out_accum_bytes = o_accum_bytes; - // seqlens_k buffer - if (past_key != nullptr) { - seqlens_k_bytes = sizeof(int) * parameters.batch_size; - } } auto softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); auto softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); auto out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); - auto seqlens_k_buffer = GetScratchBuffer(seqlens_k_bytes, context->GetComputeStream()); #else constexpr bool use_flash_attention = false; auto softmax_lse_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr auto softmax_lse_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr auto out_accum_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr - auto seqlens_k_buffer = GetScratchBuffer(0, context->GetComputeStream()); // nullptr #endif #if USE_MEMORY_EFFICIENT_ATTENTION @@ -143,7 +140,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_memory_efficient_attention_ && (parameters.head_size & 7) == 0 && - parameters.sequence_length <= parameters.past_sequence_length + parameters.kv_sequence_length && + parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && has_memory_efficient_attention(sm, sizeof(T) == 2); // allocate buffers @@ -151,7 +148,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { // need a buffer if we must ungroup kv const bool needs_buff = (parameters.num_heads != parameters.kv_num_heads); if (use_memory_efficient_attention && needs_buff) { - kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * (parameters.past_sequence_length + parameters.kv_sequence_length) * parameters.head_size); + kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size); } size_t fmha_buffer_bytes = 0; if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) { @@ -167,13 +164,18 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream()); #endif + // seqlens_k buffer + size_t seqlens_k_bytes = 0; + seqlens_k_bytes = sizeof(int) * parameters.batch_size; + auto seqlens_k_buffer = GetScratchBuffer(seqlens_k_bytes, context->GetComputeStream()); + std::vector present_dims; if (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { present_dims = { - parameters.batch_size, parameters.present_sequence_length, parameters.kv_num_heads, parameters.head_size}; + parameters.batch_size, parameters.seqlen_present_kv_cache, parameters.kv_num_heads, parameters.head_size}; } else { // BNSH present_dims = { - parameters.batch_size, parameters.kv_num_heads, parameters.present_sequence_length, parameters.head_size}; + parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, parameters.head_size}; } TensorShape present_shape(present_dims); Tensor* present_key = context->Output(1, present_shape); @@ -187,8 +189,15 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.output = reinterpret_cast(output->MutableData()); data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); + data.seqlens_k = const_cast(seqlens_k->Data()); data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; + if (data.past_key == data.present_key) { + parameters.kv_share_buffer = true; + } else { + parameters.kv_share_buffer = false; + } + // Flash Buffers if (softmax_lse_buffer != nullptr) { data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); } @@ -199,8 +208,9 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.out_accum = reinterpret_cast(out_accum_buffer.get()); } if (seqlens_k_buffer != nullptr) { - data.seqlens_k = reinterpret_cast(seqlens_k_buffer.get()); + data.seqlens_k_total = reinterpret_cast(seqlens_k_buffer.get()); } + // Memory Efficient Buffers if (k_buffer != nullptr) { data.k = reinterpret_cast(k_buffer.get()); data.v = reinterpret_cast(v_buffer.get()); @@ -208,6 +218,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (fmha_buffer != nullptr) { data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); } + if (k_buffer != nullptr) { + data.k = reinterpret_cast(k_buffer.get()); + data.v = reinterpret_cast(v_buffer.get()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index a90418ec2243..aade0436dc14 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -22,7 +22,7 @@ class GroupQueryAttention final : public CudaKernel { protected: int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention - int past_sequence_length_; + // bool left_padding_; // shifts last token to end of buffer bool is_unidirectional_; // causal bool is_past_bsnh_; float scale_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 8c21de9ced05..2cb9955807f2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -19,20 +19,21 @@ Status CheckInputs(const Tensor* query, void* parameters, int num_heads, int kv_num_heads, - const Tensor* past_seq_len, + const Tensor* seqlens_k, + const Tensor* total_seqlen, bool is_past_bsnh, float scale) { - // Note: Here S* is max_sequence_length, S- is past_sequence_length, S+ is kv_sequence_length - // past_key : (B, S*, N_k, H) or (B, N_k, S*, H) or (B, S-, N_k, H) or (B, N_k, S-, H) - // past_value : (B, S*, N_k, H) or (B, N_k, S*, H) or (B, S-, N_k, H) or (B, N_k, S-, H) + // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length + // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) + // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) // no packing for q/k/v: // query (Q) : (B, S, D) - // key (K) : (B, S+, D_kv) - // value (V) : (B, S+, D_kv) + // key (K) : (B, S, D_kv) + // value (V) : (B, S, D_kv) ORT_UNUSED_PARAMETER(value); AttentionQkvFormat qkv_format = Q_K_V_BSNH; - AttentionQkvFormat past_kv_format = Q_K_V_BSNH; + AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; const auto& query_dims = query->Shape().GetDims(); const auto& key_dims = key->Shape().GetDims(); @@ -47,10 +48,9 @@ Status CheckInputs(const Tensor* query, int q_hidden_size = static_cast(query_dims[2]); int head_size = static_cast(q_hidden_size) / num_heads; - int kv_sequence_length = static_cast(key_dims[1]); int kv_hidden_size = static_cast(key_dims[2]); - int max_sequence_length = 0; + int32_t past_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { const auto& past_key_dims = past_key->Shape().GetDims(); const auto& past_value_dims = past_value->Shape().GetDims(); @@ -79,7 +79,6 @@ Status CheckInputs(const Tensor* query, // BNSH if (!is_past_bsnh) { - past_kv_format = Q_K_V_BNSH; if (past_key_dims[2] != past_value_dims[2]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence" @@ -94,11 +93,10 @@ Status CheckInputs(const Tensor* query, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_value' shall have kv_num_heads"); } - // We assume all sequence in past kv are left-padded to max or past sequence length - max_sequence_length = static_cast(past_key_dims[2]); + // We assume all sequence in past kv are right-padded to max or past sequence length + past_sequence_length = static_cast(past_key_dims[2]); // BSNH } else { - past_kv_format = Q_K_V_BSNH; if (past_key_dims[1] != past_value_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "BNSH Input 'past_key' and 'past_value' should have same dimension 1 (max sequence" @@ -113,8 +111,8 @@ Status CheckInputs(const Tensor* query, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_value' shall have kv_num_heads"); } - // We assume all sequence in past kv are left-padded to max or past sequence length - max_sequence_length = static_cast(past_key_dims[1]); + // We assume all sequence in past kv are right-padded to max or past sequence length + past_sequence_length = static_cast(past_key_dims[1]); } if (past_key_dims[3] != head_size) { @@ -129,7 +127,7 @@ Status CheckInputs(const Tensor* query, } } else if (past_key != nullptr || past_value != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall be both present or both absent"); + "Input 'past_key' and 'past_value' shall be both present or both absent."); } if (key_dims.size() != 3) { @@ -158,56 +156,45 @@ Status CheckInputs(const Tensor* query, "Input 'query' and 'value' shall have same dim 0 (batch_size)"); } - if (static_cast(kv_sequence_length) != value_dims[1]) { + if (static_cast(sequence_length) != value_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)"); + "Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)"); } if (value_dims[2] != kv_hidden_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); } - // When kv-cache, we take past_seq_len as an argument... otherwise we use sequence length of past kv directly. - int32_t past_sequence_length = 0; - int present_sequence_length = kv_sequence_length; - if (past_seq_len != nullptr) { - if (past_key == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past KV must be present as share-buffer when using past_seq_len pointer."); - } - if (!onnxruntime::IsScalarOr1ElementVector(past_seq_len)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "past_sequence_length tensor must be of one element when using past kv."); - } - if (past_seq_len->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_INT32) { - past_sequence_length = *((*past_seq_len).template Data()); - } else { - past_sequence_length = static_cast(*((*past_seq_len).template Data())); - } - if (past_sequence_length + kv_sequence_length > max_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "KV buffer too small... shall be that max_sequence_length >= past_sequence_length + kv_sequence_length"); - } - present_sequence_length = max_sequence_length; - } else if (past_key != nullptr) { - past_sequence_length = max_sequence_length; // this is the length of past_key tensor - present_sequence_length = past_sequence_length + kv_sequence_length; + // Check seqlens_k tensor (holding past seqlen for token gen) + const auto& seqlens_dim = seqlens_k->Shape().GetDims(); + if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "seqlens_k must be shape (batch_size)."); } + // Set present sequence length and kv_share_buffer from input total_seqlen tensor + if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "total_sequence_length tensor must be of one element."); + } + int total_sequence_length = *((*total_seqlen).template Data()); + int present_sequence_length = std::max(total_sequence_length, past_sequence_length); + + bool is_prompt = sequence_length != 1; + if (parameters != nullptr) { GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); output_parameters->batch_size = batch_size; - output_parameters->sequence_length = sequence_length; - output_parameters->past_sequence_length = past_sequence_length; - output_parameters->kv_sequence_length = kv_sequence_length; - output_parameters->present_sequence_length = present_sequence_length; - output_parameters->max_sequence_length = max_sequence_length; + output_parameters->sequence_length = sequence_length; // sequence length of Q + output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors + output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors output_parameters->hidden_size = q_hidden_size; output_parameters->num_heads = num_heads; output_parameters->head_size = q_hidden_size / num_heads; output_parameters->kv_hidden_size = kv_hidden_size; output_parameters->kv_num_heads = kv_num_heads; output_parameters->is_unidirectional = true; + output_parameters->is_prompt = is_prompt; output_parameters->scale = scale; output_parameters->qkv_format = qkv_format; output_parameters->past_kv_format = past_kv_format; @@ -216,16 +203,16 @@ Status CheckInputs(const Tensor* query, return Status::OK(); } -template -Status CheckInputs(const T* query, - const T* key, - const T* value, - const T* past_key, - const T* past_value, +Status CheckInputs(const Tensor* query, + const Tensor* key, + const Tensor* value, + const Tensor* past_key, + const Tensor* past_value, void* parameters, int num_heads, int kv_num_heads, - const T* past_seq_len, + const Tensor* seqlens_k, + const Tensor* total_seqlen, bool is_past_bsnh, float scale, int max_threads_per_block) { @@ -233,7 +220,7 @@ Status CheckInputs(const T* query, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, past_seq_len, is_past_bsnh, scale); + return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); } } // namespace group_query_attention_helper diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 0455825c364a..2d158155eeba 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -41,6 +41,8 @@ limitations under the License. #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" +#include using namespace onnxruntime::cuda; @@ -60,35 +62,37 @@ __global__ void repeat_seqlen(int32_t* seqlens_k, int32_t seqlen, int batch_size // Adapted from ConcatTensorToTensor kernel in attention_kv_cache.cu file template __global__ void ConcatNewToPastKV(const int new_seqlen, + const int past_buffer_seqlen, const T* past_kv, const T* new_kv, T* present_kv, + const int* seqlens_k, const bool is_bsnh) { // refers to past; otherwise bnsh const int h = threadIdx.x; const int n = threadIdx.y; const int s = blockIdx.x; const int b = blockIdx.y; - const int present_seqlen = gridDim.x; + const int present_buffer_seqlen = gridDim.x; const int num_heads = blockDim.y; const int H = blockDim.x; - const int present_batch_stride = present_seqlen * num_heads * H; + const int present_batch_stride = present_buffer_seqlen * num_heads * H; const int row_stride = is_bsnh ? num_heads * H : H; - const int present_head_stride = is_bsnh ? H : present_seqlen * H; + const int present_head_stride = is_bsnh ? H : present_buffer_seqlen * H; // past_kv: BPNH or BNPH // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = present_seqlen - new_seqlen; + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { - const int past_batch_stride = past_seqlen * num_heads * H; - const int past_head_stride = is_bsnh ? H : past_seqlen * H; + const int past_batch_stride = past_buffer_seqlen * num_heads * H; + const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; present_kv[out_offset] = past_kv[in_offset]; - } else if (s < present_seqlen) { + } else if (s < past_seqlen + new_seqlen) { // Note: new KV always BSNH const int new_batch_stride = new_seqlen * num_heads * H; const int new_row_stride = num_heads * H; @@ -101,11 +105,13 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, // Use when (H*)*num_heads > 1024 template __global__ void ConcatNewToPastKVLarge(const int new_seqlen, + const int past_buffer_seqlen, const int H, const int num_heads, const T* past_kv, const T* new_kv, T* present_kv, + const int* seqlens_k, const bool is_bsnh) { int i = threadIdx.x + (blockDim.x * blockIdx.x); if (i < H * num_heads) { @@ -113,24 +119,24 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int n = i / H; const int s = blockIdx.y; const int b = blockIdx.z; - const int present_seqlen = gridDim.y; + const int present_buffer_seqlen = gridDim.y; - const int present_batch_stride = present_seqlen * num_heads * H; + const int present_batch_stride = present_buffer_seqlen * num_heads * H; const int row_stride = is_bsnh ? num_heads * H : H; - const int present_head_stride = is_bsnh ? H : present_seqlen * H; + const int present_head_stride = is_bsnh ? H : present_buffer_seqlen * H; // past_kv: BPNH or BNPH // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = present_seqlen - new_seqlen; + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { - const int past_batch_stride = past_seqlen * num_heads * H; - const int past_head_stride = is_bsnh ? H : past_seqlen * H; + const int past_batch_stride = past_buffer_seqlen * num_heads * H; + const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; present_kv[out_offset] = past_kv[in_offset]; - } else if (s < present_seqlen) { + } else if (s < past_seqlen + new_seqlen) { const int new_batch_stride = new_seqlen * num_heads * H; const int new_row_stride = num_heads * H; const int new_head_stride = H; @@ -147,10 +153,13 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter cudaStream_t stream, const int max_threads_per_block) { const int batch_size = parameters.batch_size; - const int kv_sequence_length = parameters.kv_sequence_length; - const int present_sequence_length = parameters.present_sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int past_sequence_length = parameters.seqlen_past_kv_cache; + const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; + const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); + AttentionQkvFormat past_kv_format = parameters.past_kv_format; assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); @@ -159,32 +168,40 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter const dim3 grid(present_sequence_length, batch_size, 1); const dim3 block(H, kv_num_heads, 1); ConcatNewToPastKV<<>>(kv_sequence_length, + past_sequence_length, reinterpret_cast(data.past_key), reinterpret_cast(data.key), reinterpret_cast(data.present_key), + seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKV<<>>(kv_sequence_length, + past_sequence_length, reinterpret_cast(data.past_value), reinterpret_cast(data.value), reinterpret_cast(data.present_value), + seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } else { int steps = (H * kv_num_heads + 255) / 256; const dim3 grid(steps, present_sequence_length, batch_size); const dim3 block(256, 1, 1); ConcatNewToPastKVLarge<<>>(kv_sequence_length, + past_sequence_length, H, kv_num_heads, reinterpret_cast(data.past_key), reinterpret_cast(data.key), reinterpret_cast(data.present_key), + seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKVLarge<<>>(kv_sequence_length, + past_sequence_length, H, kv_num_heads, reinterpret_cast(data.past_value), reinterpret_cast(data.value), reinterpret_cast(data.present_value), + seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } return CUDA_CALL(cudaGetLastError()); @@ -192,10 +209,10 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter // Kernel to append new kv to kv buffer in place template -__global__ void ConcatKVInPlace(const int past_seqlen, - const int present_seqlen, +__global__ void ConcatKVInPlace(const int max_seqlen, T* kv_buff, const T* new_kv, + const int* seqlens_k, const bool is_bsnh) { // refers to kv buff; otherwise bnsh const int h = threadIdx.x; const int n = threadIdx.y; @@ -206,14 +223,16 @@ __global__ void ConcatKVInPlace(const int past_seqlen, const int num_heads = blockDim.y; const int H = blockDim.x; - const int present_batch_stride = present_seqlen * num_heads * H; + const int present_batch_stride = max_seqlen * num_heads * H; const int present_row_stride = is_bsnh ? num_heads * H : H; - const int present_head_stride = is_bsnh ? H : present_seqlen * H; + const int present_head_stride = is_bsnh ? H : max_seqlen * H; // kv_buff: BTNH or BNTH with buffered memory for new // new_kv: BLNH - int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h; + const int past_seq_len = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h; // Note: new KV always BSNH const int new_batch_stride = new_seqlen * num_heads * H; const int new_row_stride = num_heads * H; @@ -223,12 +242,12 @@ __global__ void ConcatKVInPlace(const int past_seqlen, } template -__global__ void ConcatKVInPlaceLarge(const int past_seqlen, - const int present_seqlen, +__global__ void ConcatKVInPlaceLarge(const int max_seqlen, const int H, const int num_heads, T* kv_buff, const T* new_kv, + const int* seqlens_k, const bool is_bsnh) { // refers to kv buff; otherwise bnsh int i = threadIdx.x + (blockDim.x * blockIdx.x); if (i < H * num_heads) { @@ -238,14 +257,16 @@ __global__ void ConcatKVInPlaceLarge(const int past_seqlen, const int b = blockIdx.z; const int new_seqlen = gridDim.y; - const int present_batch_stride = present_seqlen * num_heads * H; + const int present_batch_stride = max_seqlen * num_heads * H; const int present_row_stride = is_bsnh ? num_heads * H : H; - const int present_head_stride = is_bsnh ? H : present_seqlen * H; + const int present_head_stride = is_bsnh ? H : max_seqlen * H; // kv_buff: BTNH or BNTH with buffered memory for new // new_kv: BLNH - int out_offset = b * present_batch_stride + (s + past_seqlen) * present_row_stride + n * present_head_stride + h; + const int past_seq_len = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + int out_offset = b * present_batch_stride + (s + past_seq_len) * present_row_stride + n * present_head_stride + h; // Note: new KV always BSNH const int new_batch_stride = new_seqlen * num_heads * H; const int new_row_stride = num_heads * H; @@ -262,44 +283,47 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, cudaStream_t stream, const int max_threads_per_block) { const int batch_size = parameters.batch_size; - const int kv_sequence_length = parameters.kv_sequence_length; - const int present_sequence_length = parameters.present_sequence_length; - const int past_sequence_length = parameters.past_sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; + + // Indicates past sequence_length of each sequence + const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); + AttentionQkvFormat past_kv_format = parameters.past_kv_format; assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); const int H = head_size / 4; if (H * kv_num_heads <= max_threads_per_block) { const dim3 grid(kv_sequence_length, batch_size, 1); const dim3 block(H, kv_num_heads, 1); - ConcatKVInPlace<<>>(past_sequence_length, - present_sequence_length, + ConcatKVInPlace<<>>(present_sequence_length, reinterpret_cast(data.present_key), reinterpret_cast(data.key), + seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ConcatKVInPlace<<>>(past_sequence_length, - present_sequence_length, + ConcatKVInPlace<<>>(present_sequence_length, reinterpret_cast(data.present_value), reinterpret_cast(data.value), + seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } else { int steps = int(ceil(float(H * kv_num_heads) / 256.0)); const dim3 grid(steps, kv_sequence_length, batch_size); const dim3 block(256, 1, 1); - ConcatKVInPlaceLarge<<>>(past_sequence_length, - present_sequence_length, + ConcatKVInPlaceLarge<<>>(present_sequence_length, H, kv_num_heads, reinterpret_cast(data.present_key), reinterpret_cast(data.key), + seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); - ConcatKVInPlaceLarge<<>>(past_sequence_length, - present_sequence_length, + ConcatKVInPlaceLarge<<>>(present_sequence_length, H, kv_num_heads, reinterpret_cast(data.present_value), reinterpret_cast(data.value), + seqlens_k, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } return CUDA_CALL(cudaGetLastError()); @@ -417,6 +441,82 @@ Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, return CUDA_CALL(cudaGetLastError()); } + +__global__ void PastToTotalSeqlen(int32_t* seqlens_k, + int32_t* seqlens_k_buff, + const int add_seqlen) { + seqlens_k_buff[threadIdx.x] = seqlens_k[threadIdx.x] + add_seqlen; +} + +// Convert Past to Total sequence length tensor +Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, + int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, + const int threads_per_block) { + if (parameters.is_prompt) { + return Status::OK(); + } + const int batch_size = parameters.batch_size; + const int add_seqlen = is_total ? parameters.sequence_length : 0; + + const dim3 grid(1, 1, 1); + // TODO(aciddelgado): unlikely but could have a bigger batch_size than max_threads + const dim3 block(batch_size, 1, 1); + + // TODO(aciddelgado): small version + PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, add_seqlen); + + return CUDA_CALL(cudaGetLastError()); +} + +// // Kernel to append new kv to kv buffer in place +// template +// __global__ void LeftPadLast(const int max_seqlen, +// T* kv_buff, +// const int* seqlens_k) { // refers to kv buff; otherwise bnsh +// const int h = threadIdx.x; +// const int n = blockIdx.x; +// const int b = blockIdx.y; + +// const int num_heads = gridDim.x; +// const int H = blockDim.x; + +// const int present_batch_stride = max_seqlen * num_heads * H; +// const int present_row_stride = num_heads * H; +// const int present_head_stride = H; + +// // kv_buff: BTNH or BNTH with buffered memory for new +// // new_kv: BLNH + +// const int s = seqlens_k[b]; + +// const int in_offset = b * present_batch_stride + s * present_row_stride + n * present_head_stride + h; +// const int out_offset = b * present_batch_stride + (max_seqlen - 1) * present_row_stride + n * present_head_stride + h; +// kv_buff[out_offset] = kv_buff[in_offset]; +// } + +// // Concat new to kv buffer in place +// template +// Status LaunchLeftPadLast(contrib::GroupQueryAttentionParameters& parameters, +// GroupQueryAttentionData& data, +// cudaStream_t stream, +// const int max_threads_per_block) { +// const int batch_size = parameters.batch_size; +// const int sequence_length = parameters.sequence_length; +// const int num_heads = parameters.num_heads; +// const int head_size = parameters.head_size; + +// // Indicates past sequence_length of each sequence +// const int* seqlens_k = reinterpret_cast(data.seqlens_k); + +// const int H = head_size / 4; +// const dim3 grid(num_heads, batch_size, 1); +// const dim3 block(H, 1, 1); +// LeftPadLast<<>>(sequence_length, +// reinterpret_cast(data.output), +// seqlens_k); +// return CUDA_CALL(cudaGetLastError()); +// } + ////////// Launch Kernels #if USE_FLASH_ATTENTION @@ -430,8 +530,8 @@ Status FlashAttention( const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int present_sequence_length = parameters.present_sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; @@ -443,40 +543,81 @@ Status FlashAttention( bool is_causal = parameters.is_unidirectional; - if (data.past_key != nullptr && data.past_key == data.present_key) { + // Note: seqlens_k is past sequence length for flash + if (parameters.is_prompt) { + // Launch kernel to copy seqlen + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk -1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); + } + + void* seqlens_k = reinterpret_cast(data.seqlens_k); + + if (parameters.kv_share_buffer) { // Share buffer case + if (data.past_key == nullptr || data.past_key != data.present_key) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past and present kv shall share the same tensor when kv_share_buffer is on."); + } + + if (parameters.is_prompt) { + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + key = nullptr; + value = nullptr; + seqlens_k = reinterpret_cast(data.seqlens_k_total); + } + void* present_key = reinterpret_cast(const_cast(data.present_key)); void* present_value = reinterpret_cast(const_cast(data.present_value)); - // Launch kernel to copy seqlen - int thr_per_blk = 256; - int blk_in_grid = ceil(float(batch_size) / thr_per_blk); - repeat_seqlen<<>>(data.seqlens_k, parameters.past_sequence_length, batch_size); + DUMP_TENSOR_INIT(); + DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), - reinterpret_cast(data.seqlens_k), batch_size, num_heads, kv_num_heads, + seqlens_k, batch_size, num_heads, kv_num_heads, head_size, sequence_length, present_sequence_length, kv_sequence_length, scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum))); - } else { - // Not share buffer or no past (prompt generation) + // Not share buffer case // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient + if (data.past_key != nullptr && data.past_key == data.present_key) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past and present kv share the same tensor but kv_share_buffer is not on."); + } + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + if (!parameters.is_prompt) { + ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + } + + seqlens_k = reinterpret_cast(data.seqlens_k_total); + void* present_key = reinterpret_cast(const_cast(data.present_key)); void* present_value = reinterpret_cast(const_cast(data.present_value)); + DUMP_TENSOR_INIT(); + DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); + DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size); + DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size); + DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size); + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( - device_prop, stream, query, present_key, present_value, data.output, reinterpret_cast(data.softmax_lse), - batch_size, num_heads, kv_num_heads, head_size, - sequence_length, present_sequence_length, scale, is_causal, parameters.num_splits, - reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), past_bsnh)); + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast(data.softmax_lse), + seqlens_k, batch_size, num_heads, kv_num_heads, + head_size, sequence_length, present_sequence_length, 0, + scale, is_causal, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum))); } + // if (parameters.left_padding && parameters.is_prompt) { + // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); + // } + DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); @@ -495,9 +636,7 @@ Status EfficientAttention( const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; - const int past_sequence_length = parameters.past_sequence_length; - const int present_sequence_length = parameters.present_sequence_length; + const int present_sequence_length = parameters.seqlen_present_kv_cache; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; @@ -506,64 +645,68 @@ Status EfficientAttention( const void* query = reinterpret_cast(data.query); const void* key = reinterpret_cast(data.key); const void* value = reinterpret_cast(data.value); - if (data.past_key != nullptr) { - // Past key case - // concatenate new kv to past kv - if (data.past_key == data.present_key) { - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); - } else { - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + + if (parameters.is_prompt) { + // Launch kernel to copy seqlen + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); + } else { + ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + } + + if (parameters.kv_share_buffer) { + // Share buffer case + if (data.past_key == nullptr || data.past_key != data.present_key) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past and present kv shall share the same tensor when kv_share_buffer is on."); } - const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - if (num_heads == kv_num_heads) { - // Use present kv directly if not grouped - key = reinterpret_cast(data.present_key); - value = reinterpret_cast(data.present_value); - } else { - // Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path - float2* k_buff = reinterpret_cast(data.k); - float2* v_buff = reinterpret_cast(data.v); - const float2* k_og = reinterpret_cast(data.present_key); - const float2* v_og = reinterpret_cast(data.present_value); - ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, past_sequence_length + kv_sequence_length, - present_sequence_length, is_bsnh, stream, max_threads_per_block)); - key = reinterpret_cast(data.k); - value = reinterpret_cast(data.v); + // Concatenate new kv in place + ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); + } else { + // Not share buffer case + if (data.past_key != nullptr && data.past_key == data.present_key) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Past and present kv share the same tensor but kv_share_buffer is not on."); } - } else if (num_heads == kv_num_heads) { - // no past or present and no need to ungroup... still copy kv into present buffer + // Copy past and concat new KV to present buffer ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + } + + // Ungroup if grouped, otherwise use present kv directly + const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + if (num_heads == kv_num_heads) { + // Use present kv directly if not grouped key = reinterpret_cast(data.present_key); value = reinterpret_cast(data.present_value); } else { - // intermediate buffer so q and kv have same num heads... still copy kv into present buffer - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + // Otherwise we use intermediate buffers to run memory efficient attention... best avoid this path float2* k_buff = reinterpret_cast(data.k); float2* v_buff = reinterpret_cast(data.v); const float2* k_og = reinterpret_cast(data.present_key); const float2* v_og = reinterpret_cast(data.present_value); - ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, kv_sequence_length, - kv_sequence_length, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH, stream, - max_threads_per_block)); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, present_sequence_length, + present_sequence_length, is_bsnh, stream, max_threads_per_block)); key = reinterpret_cast(data.k); value = reinterpret_cast(data.v); } + DUMP_TENSOR_INIT(); + DUMP_TENSOR("seqlens_k", data.seqlens_k_total, batch_size, 1); + MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; p.is_half = sizeof(T) == 2; p.batch_size = batch_size; p.num_heads = num_heads; p.sequence_length = sequence_length; - p.kv_sequence_length = past_sequence_length + kv_sequence_length; - p.max_sequence_length = (num_heads == kv_num_heads) ? present_sequence_length : past_sequence_length + kv_sequence_length; + p.kv_sequence_length = present_sequence_length; // TOTALLY UNNECESSARY IF WE HAVE SEQLENS_K, maybe remove + p.max_sequence_length = present_sequence_length; p.qk_head_size = head_size; p.v_head_size = head_size; p.causal = parameters.is_unidirectional; p.scale = scale; - p.seqlen_k_ptr = nullptr; - p.seqstart_q_ptr = nullptr; - p.seqstart_k_ptr = nullptr; + p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient p.query = query; p.key = key; p.value = value; @@ -575,8 +718,13 @@ Status EfficientAttention( ? data.fmha_buffer : nullptr; p.stream = stream; + p.has_custom_right_padding = true; run_memory_efficient_attention(p); + // if (parameters.left_padding && parameters.is_prompt) { + // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); + // } + DUMP_TENSOR_INIT(); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 8412631078e6..de32d7ea9316 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -20,11 +20,12 @@ struct GroupQueryAttentionData { const T* value = nullptr; const T* past_key = nullptr; const T* past_value = nullptr; + int* seqlens_k = nullptr; // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; - int* seqlens_k = nullptr; + int* seqlens_k_total = nullptr; // Memory Efficient buffers T* fmha_buffer = nullptr; T* k = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index d7aeef1501cd..3b5232083940 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -529,6 +529,7 @@ Status FusedScaledDotProductAttentionCutlass( p.output = data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? accum_workspace : nullptr; p.stream = stream; + p.has_custom_right_padding = false; run_memory_efficient_attention(p); DUMP_TENSOR("PackedAttention cutlass output", data.output, parameters.token_count, num_heads, v_head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 3fe9dbf8ed34..8a508241d80b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -708,6 +708,7 @@ Status FusedAttentionCutlass( ? (data.workspace + (data.no_qkv_workspace ? 0 : (elements_qk + elements_qk + elements_v))) : nullptr; p.stream = stream; + p.has_custom_right_padding = false; run_memory_efficient_attention(p); DUMP_TENSOR_INIT(); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index db32cb3c05de..893776e7786f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -990,18 +990,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .SetDoc(GroupQueryAttention_ver1_doc) .Attr("num_heads", "Number of attention heads for q", AttributeProto::INT) .Attr("kv_num_heads", "Number of attention heads for k and v", AttributeProto::INT) - .Attr("unidirectional", - "Whether every token can only attend to previous tokens. Default value is 1.", - AttributeProto::INT, - static_cast(1)) - .Attr("is_past_bsnh", - "Whether past kv uses BSNH, otherwise BNSH. Default value is 1 (BSNH).", - AttributeProto::INT, - static_cast(1)) .Attr("scale", "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) + // .Attr("left_padding_last_token", + // "Copy last token to last index of buffer. Default is 0; 1 when true.", + // AttributeProto::INT, + // OPTIONAL_VALUE) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size)", @@ -1016,40 +1012,42 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "T") .Input(3, "past_key", - "past state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key" + "past state key with support for format BNSH. When past_key uses same tensor as present_key" "(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.", "T", OpSchema::Optional) .Input(4, "past_value", - "past state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value" + "past state value with support for format BNSH. When past_value uses same tensor as present_value" "(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.", "T", OpSchema::Optional) .Input(5, - "past_sequence_length", - "When buffered past_key and past_value is used (present_key uses same tensor as past_key), required" - "to specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.", - "M", - OpSchema::Optional) + "seqlens_k", + "1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.", + "M") + .Input(6, + "total_sequence_length", + "Scalar tensor of total sequence length (past + new).", + "M") .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T") .Output(1, "present_key", - "present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key" + "present state key with support for format BNSH. When past_key uses same tensor as present_key" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", "T") .Output(2, "present_value", - "present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value" + "present state value with support for format BNSH. When past_value uses same tensor as present_value" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", "T") .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output to float tensors.") - .TypeConstraint("M", {"tensor(int32)", "tensor(int64)"}, "Constrain past sequence length to int tensor.") + .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { GroupQueryAttentionTypeAndShapeInference(ctx, 3); })); @@ -1119,7 +1117,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( })); constexpr const char* RotaryEmbedding_ver1_doc = R"DOC( -RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices +RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices that are multiplied to query and key before the inner product of query and key is taken. )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1532,4 +1530,4 @@ ONNX_MS_OPERATOR_SET_SCHEMA( })); } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 7aca5e8526a2..b59af41c49df 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1272,39 +1272,96 @@ def find_past_seq_len_usage(subg: GraphProto): return tensor_names_to_rename, nodes_to_remove -def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1): - past_seq_len = past_seq_len_input - if past_seq_len not in model.get_graphs_input_names(): - # Add model input for past sequence length - new_input = onnx.helper.make_tensor_value_info(past_seq_len, onnx.TensorProto.INT64, shape=[1]) - model.model.graph.input.append(new_input) +def replace_mha_with_gqa(model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1): + # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes + # + # attention_mask + # / \ + # ReduceSum Shape + # | | + # Sub Gather + # | | + # seqlens_k total_sequence_length + # | | + # Cast to int32 Cast to int32 + + model.add_initializer( + onnx.helper.make_tensor( + name="one", + data_type=TensorProto.INT64, + dims=[1], + vals=[1], + ) + ) + reduce_sum_node = onnx.helper.make_node( + "ReduceSum", + inputs=[attn_mask, "one"], + outputs=[attn_mask + "_row_sums"], + name=model.create_node_name("ReduceSum"), + ) + sub_node = onnx.helper.make_node( + "Sub", + inputs=[attn_mask + "_row_sums", "one"], + outputs=["seqlens_k_int64"], + name=model.create_node_name("Sub"), + ) + seqlen_k_cast_node = onnx.helper.make_node( + "Cast", + inputs=["seqlens_k_int64"], + outputs=["seqlens_k"], + name=model.create_node_name("Cast"), + to=TensorProto.INT32, + ) + shape_node = onnx.helper.make_node( + "Shape", + inputs=[attn_mask], + outputs=[attn_mask + "_shape"], + name=model.create_node_name("Shape"), + ) + gather_node = onnx.helper.make_node( + "Gather", + inputs=[attn_mask + "_shape", "one"], + outputs=["total_seq_len_int64"], + name=model.create_node_name("Gather"), + axis=0, + ) + total_seqlen_cast_node = onnx.helper.make_node( + "Cast", + inputs=["total_seq_len_int64"], + outputs=["total_seq_len"], + name=model.create_node_name("Cast"), + to=TensorProto.INT32, + ) + model.model.graph.node.extend( + [reduce_sum_node, sub_node, seqlen_k_cast_node, shape_node, gather_node, total_seqlen_cast_node] + ) # Replace MultiHeadAttention with GroupQueryAttention - for node in model.model.graph.node: - if node.op_type == "MultiHeadAttention": - num_heads_mha = 0 - for att in node.attribute: - if att.name == "num_heads": - num_heads_mha = att.i - gqa_node = onnx.helper.make_node( - "GroupQueryAttention", - inputs=[ - node.input[0], # query - node.input[1], # key - node.input[2], # value - node.input[6], # past_key - node.input[7], # past_value - past_seq_len, # past_sequence_length - ], - outputs=node.output, - name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"), - domain="com.microsoft", - num_heads=num_heads_mha // world_size, - kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size, - is_past_bsnh=0, - ) - model.model.graph.node.remove(node) - model.model.graph.node.extend([gqa_node]) + mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node)) + for node in mha_nodes: + num_heads_mha = 0 + for att in node.attribute: + if att.name == "num_heads": + num_heads_mha = att.i + gqa_node = onnx.helper.make_node( + "GroupQueryAttention", + inputs=[ + node.input[0], # query + node.input[1], # key + node.input[2], # value + node.input[6], # past_key + node.input[7], # past_value + "seqlens_k", # seqlens_k (for attention_mask) + "total_seq_len", # total_seq_len (for attention_mask) + ], + outputs=node.output, + name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"), + domain="com.microsoft", + num_heads=num_heads_mha // world_size, + kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size, + ) + model.model.graph.node.remove(node) + model.model.graph.node.extend([gqa_node]) return model diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index 1bb6940d1cd7..0c6f830ed26b 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -117,7 +117,7 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu ``` -Export for FP16 CUDA +Export for FP16 CUDA (with MultiHeadAttention) ``` # From source: $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda @@ -126,6 +126,63 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output l $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda ``` +Export for FP16 CUDA (with GroupQueryAttention) +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --use_gqa + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --use_gqa +``` + +Note: GroupQueryAttention currently runs on Linux for FP16 CUDA and INT4 CUDA models, and it can provide faster inference than MultiHeadAttention, especially for large sequence lengths (e.g. 1024 or larger). For the best performance, you should pre-allocate the KV cache buffers to have size `(batch_size, num_heads, max_sequence_length, head_size)` so that the past KV and present KV caches share the same memory. You also need to bind them with ONNX Runtime's [IO binding](https://onnxruntime.ai/docs/api/python/api_summary.html#iobinding). + +Here is an example of how you can bind directly to `torch.tensor` objects: +``` +# Assumes all inputs and outputs to the model are pre-allocated with the correct shapes in GPU memory + +# Bind inputs +for k, v in inputs.items(): + io_binding.bind_input( + name=k, + device_type="cuda", + device_id=0, + element_type=np.float16, + shape=tuple(v.shape), + buffer_ptr=v.data_ptr() + ) + +# Bind outputs +for output in model.get_outputs(): + name = output.name + if "present" in name: + # Bind KV cache outputs to KV cache inputs + v = inputs[name.replace("present", "past_key_values")] + io_binding.bind_output( + name=name, + device_type="cuda", + device_id=0, + element_type=np.float16, + shape=tuple(v.shape), + buffer_ptr=v.data_ptr() + ) + else: + # Bind other outputs as actual outputs + v = outputs[name] + io_binding.bind_output( + name=name, + device_type="cuda", + device_id=0, + element_type=np.float16, + shape=tuple(v.shape), + buffer_ptr=v.data_ptr() + ) + +io_binding.synchronize_inputs() +sess.run_with_iobinding(io_binding) +io_binding.synchronize_outputs() +``` + Export for INT8 CPU (SmoothQuant) ``` # From source: @@ -149,12 +206,14 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama Export for INT4 CUDA ``` # From source: -$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda --use_gqa # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda --use_gqa ``` +Note: See the FP16 CUDA notes about GroupQueryAttention. The `--use_gqa` flag is optional. + Export for INT4 CPU ``` # From source: @@ -168,13 +227,13 @@ Export LLaMA-2 70B sharded model into 4 partitions ``` # From source: # 1. Install necessary packages from requirements-70b-model.txt +$ pip install -r requirements-70b-model.txt # 2. Build ONNX Runtime from source with NCCL enabled. Here is a sample command: -$ ./build.sh --config RelWithDebInfo --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ +$ ./build.sh --config Release --use_cuda --cuda_home /usr/local/cuda-12.2 --cudnn_home /usr/local/cuda-12.2 --build_wheel --cuda_version=12.2 --parallel --skip_tests --enable_nccl --nccl_home /usr/local/cuda-12.2 --use_mpi --mpi_home=/usr/lib/x86_64-linux-gnu/ # 3. Shard and export the LLaMA-2 70B model. With FP16, you will need at least 140GB of GPU memory to load the model. Therefore, you will need at least 4 40GB A100 GPUs or 2 80GB A100 GPUs to shard the PyTorch model and export each shard to ONNX. Here is an example command: -$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-dis --precision fp16 --execution_provider cuda - +$ CUDA_VISIBLE_DEVICES=0,1,2,3 bash convert_70b_model.sh 4 -m meta-llama/Llama-2-70b-hf --output llama2-70b-distributed --precision fp16 --execution_provider cuda --use_gqa ``` ## Benchmark LLaMA-2 @@ -220,7 +279,20 @@ python3 -m models.llama.benchmark \ --auth ``` -4. ONNX Runtime, FP32, Microsoft custom export +4. Optimum + ONNX Runtime, FP16, export via Optimum or convert_to_onnx +``` +python3 -m models.llama.benchmark \ + --benchmark-type hf-ort \ + --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \ + --model-name meta-llama/Llama-2-7b-hf \ + --precision fp16 \ + --batch-sizes "1 2" \ + --sequence-lengths "8 16" \ + --device cuda \ + --auth +``` + +5. ONNX Runtime, FP32, Microsoft custom export ``` python3 -m models.llama.benchmark \ --benchmark-type ort-msft \ @@ -232,7 +304,7 @@ python3 -m models.llama.benchmark \ --device cpu ``` -5. ONNX Runtime, FP16, Microsoft custom export +6. ONNX Runtime, FP16, Microsoft custom export ``` python3 -m models.llama.benchmark \ --benchmark-type ort-msft \ @@ -244,7 +316,7 @@ python3 -m models.llama.benchmark \ --device cuda ``` -6. ONNX Runtime, FP32, convert_to_onnx, use 2nd GPU +7. ONNX Runtime, FP32, convert_to_onnx, use 2nd GPU ``` CUDA_VISIBLE_DEVICES=1 python3 -m models.llama.benchmark \ --benchmark-type ort-convert-to-onnx \ @@ -256,7 +328,7 @@ CUDA_VISIBLE_DEVICES=1 python3 -m models.llama.benchmark \ --device cpu ``` -7. ONNX Runtime, FP16, convert_to_onnx, use 5th GPU +8. ONNX Runtime, FP16, convert_to_onnx, use 5th GPU ``` CUDA_VISIBLE_DEVICES=4 python3 -m models.llama.benchmark \ --benchmark-type ort-convert-to-onnx \ @@ -283,5 +355,8 @@ python3 -m models.llama.benchmark_all \ --precision fp16 \ --batch-sizes "1 2" \ --sequence-lengths "8 16" \ - --device cuda + --device cuda \ + --warmup-runs 5 \ + --num-runs 1000 \ + --timeout 60 # number of minutes before moving to the next benchmark ``` diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index be678931de5d..021b0dd03a9d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -11,6 +11,7 @@ import onnx import psutil import torch +from benchmark_helper import measure_memory, setup_logger from dist_settings import get_rank, get_size from llama_inputs import ( add_io_bindings, @@ -22,10 +23,9 @@ from optimum.onnxruntime import ORTModelForCausalLM from torch.profiler import ProfilerActivity, profile, record_function from tqdm import trange -from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer import onnxruntime as ort -from onnxruntime.transformers.benchmark_helper import measure_memory, setup_logger logger = logging.getLogger(__name__) @@ -107,6 +107,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=0, max_seq_len=max_seq_len, use_fp16=args.use_fp16, + use_gqa=args.use_gqa, engine="pt", return_dict=True, ) @@ -118,6 +119,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=args.sequence_length, max_seq_len=max_seq_len, use_fp16=args.use_fp16, + use_gqa=args.use_gqa, engine="pt", return_dict=True, ) @@ -132,6 +134,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=0, max_seq_len=max_seq_len, use_fp16=args.use_fp16, + use_gqa=args.use_gqa, engine="ort", return_dict=True, world_size=args.world_size, @@ -144,6 +147,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): past_seq_len=args.sequence_length, max_seq_len=max_seq_len, use_fp16=args.use_fp16, + use_gqa=args.use_gqa, engine="ort", return_dict=True, world_size=args.world_size, @@ -160,6 +164,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): seq_len=args.sequence_length, max_seq_len=max_seq_len, use_fp16=args.use_fp16, + use_gqa=args.use_gqa, split_kv=split_kv, ) iter_inputs = get_msft_sample_inputs( @@ -169,6 +174,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): seq_len=1, max_seq_len=max_seq_len, use_fp16=args.use_fp16, + use_gqa=args.use_gqa, split_kv=split_kv, ) @@ -192,7 +198,7 @@ def get_model(args: argparse.Namespace): if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name start_time = time.time() - model = LlamaForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( source, torch_dtype=torch.float16 if args.use_fp16 else torch.float32, use_auth_token=args.auth, @@ -456,7 +462,7 @@ def prepare_ort_inputs(inputs, kv_cache_ortvalues): # Add IO bindings for non-CPU execution providers if args.device != "cpu": io_binding, kv_cache_ortvalues = add_io_bindings( - model, inputs, args.device, int(args.rank), kv_cache_ortvalues + model, inputs, args.device, int(args.rank), args.use_gqa, kv_cache_ortvalues ) setattr(args, "io_binding", io_binding) # noqa: B010 return io_binding, kv_cache_ortvalues @@ -650,8 +656,8 @@ def main(): args.rank = rank args.world_size = world_size - tokenizer = LlamaTokenizer.from_pretrained(args.model_name) - config = LlamaConfig.from_pretrained(args.model_name) + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + config = AutoConfig.from_pretrained(args.model_name) target_device = f"cuda:{args.rank}" if args.device != "cpu" else args.device use_fp16 = args.precision == "fp16" @@ -670,9 +676,9 @@ def main(): gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node)) use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu" - setattr(args, "past_present_share_buffer", use_buffer_share) # noqa: B010 + setattr(args, "use_gqa", use_buffer_share) # noqa: B010 else: - setattr(args, "past_present_share_buffer", False) # noqa: B010 + setattr(args, "use_gqa", False) # noqa: B010 # Measure prompt cost (init_inputs) and generated token cost (iter_inputs) for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths): diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index b0e0b41e75d3..c9c7f4d39d42 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -7,6 +7,8 @@ import onnx import torch +from benchmark_helper import Precision, prepare_environment, setup_logger +from convert_generation import replace_mha_with_gqa from dist_settings import barrier, get_rank, get_size, init_dist from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check @@ -14,12 +16,10 @@ from onnx_model import OnnxModel from optimizer import optimize_model from packaging import version -from transformers import LlamaConfig, LlamaForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM from onnxruntime import quantization as ort_quantization from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer -from onnxruntime.transformers.benchmark_helper import Precision, prepare_environment, setup_logger -from onnxruntime.transformers.convert_generation import replace_mha_with_gqa logger = logging.getLogger("") init_dist() @@ -133,7 +133,7 @@ def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: st # temp_dir.cleanup() # def run_dynamo_export( - args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 + args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1 ): from torch._dynamo import config @@ -194,7 +194,7 @@ def _prepare_dir(dir_path): def run_torchscript_separate_export( - args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 + args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1 ): # Dummy values for export batch_size, sequence_length = 2, 8 @@ -313,7 +313,7 @@ def run_torchscript_separate_export( def run_torchscript_merged_export( - args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM, rank: int = 0, world_size: int = 1 + args: argparse.Namespace, l_config: AutoConfig, llama: AutoModelForCausalLM, rank: int = 0, world_size: int = 1 ): # Dummy values for export batch_size, sequence_length, past_sequence_length = 2, 8, 0 @@ -391,7 +391,7 @@ def run_torchscript_merged_export( # Optimize the model as FP32 -def optimize_export(config: LlamaConfig, input_path: str, output_path: str): +def optimize_export(config: AutoConfig, input_path: str, output_path: str): from fusion_options import FusionOptions optimization_options = FusionOptions("gpt2") @@ -411,7 +411,7 @@ def optimize_export(config: LlamaConfig, input_path: str, output_path: str): def convert_to_float16( - args: argparse.Namespace, config: LlamaConfig, old_paths: List[str], rank: int = 0, world_size: int = 1 + args: argparse.Namespace, config: AutoConfig, old_paths: List[str], rank: int = 0, world_size: int = 1 ): decoder_model_fp16_path = os.path.join(args.output, f"rank_{rank}_{args.model_name}_decoder_model_fp16.onnx") decoder_with_past_model_fp16_path = os.path.join( @@ -427,7 +427,8 @@ def convert_to_float16( if os.path.exists(fp32_path): model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True)) model.convert_float_to_float16(keep_io_types=False) - model = use_group_query_attention(config, model, world_size) + if args.use_gqa: + model = use_group_query_attention(config, model, world_size) model.save_model_to_file(fp16_path, use_external_data_format=True) del model logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!") @@ -437,11 +438,9 @@ def convert_to_float16( return new_paths -def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel, world_size: int = 1): - # Replace MultiHeadAttention with GroupQueryAttention and remove attention mask nodes - fp16_model_opt = replace_mha_with_gqa( - fp16_model_opt, "past_sequence_length", config.num_key_value_heads, world_size - ) +def use_group_query_attention(config: AutoConfig, fp16_model_opt: OnnxModel, world_size: int = 1): + # Replace MultiHeadAttention with GroupQueryAttention + fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "attention_mask", config.num_key_value_heads, world_size) fp16_model_opt.prune_graph() fp16_model_opt.update_graph(allow_remove_graph_inputs=True) return fp16_model_opt @@ -520,8 +519,8 @@ def smooth_quant( logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") - logger.info(f"Removing {args.nc_workspace}") - os.system(f"rm -R {args.nc_workspace}") + logger.warning(f"Removing {args.nc_workspace}") + shutil.rmtree(args.nc_workspace) def remove_existing_model(model_path: str): @@ -594,6 +593,14 @@ def get_args(): ) parser.set_defaults(reexport=False) + parser.add_argument( + "--use_gqa", + required=False, + action="store_true", + help="Use GroupQueryAttention instead of MultiHeadAttention", + ) + parser.set_defaults(use_gqa=False) + parser.add_argument( "--no_merged", required=False, @@ -747,7 +754,7 @@ def main(): location = args.original_model_name if use_auth_token else args.input - # use cuda for Llama-2-70b to speedup export, other models use CPU by default + # Use CUDA for LLaMA-2-70B to speed up export and CPU for other models l_config, llama = setup_torch_model( args, location, use_auth_token, device=args.device if args.model_name == "Llama-2-70b-hf" else None ) @@ -944,6 +951,8 @@ def main(): parity_cmd.append("--use_past_kv") if "merged" in filename: parity_cmd.append("--merged") + if args.use_gqa: + parity_cmd.append("--use_gqa") try: logger.debug(f"check parity with cmd: {parity_cmd}") diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index 6530eead55f0..bae1ae82e8f7 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -2,7 +2,7 @@ import numpy as np import torch -from transformers import LlamaConfig +from transformers import AutoConfig from onnxruntime import InferenceSession, OrtValue @@ -24,7 +24,7 @@ def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): # attention_mask: (batch_size, sequence_length) # position_ids: (batch_size, sequence_length) def get_sample_inputs( - config: LlamaConfig, + config: AutoConfig, device: torch.device, batch_size: int, seq_len: int, @@ -59,7 +59,7 @@ def get_sample_inputs( # past_key: (batch_size, num_heads, past_sequence_length, head_size) # past_value: (batch_size, num_heads, past_sequence_length, head_size) def get_sample_with_past_kv_inputs( - config: LlamaConfig, + config: AutoConfig, device: torch.device, batch_size: int, past_seq_len: int, @@ -115,13 +115,14 @@ def get_sample_with_past_kv_inputs( # For models with GQA, kv_sequence_length = max_sequence_length # For models without GQA, kv_sequence_length = past_sequence_length def get_merged_sample_with_past_kv_inputs( - config: LlamaConfig, + config: AutoConfig, device: torch.device, batch_size: int, seq_len: int, past_seq_len: int, max_seq_len: int, use_fp16: bool = False, + use_gqa: bool = False, engine: str = "pt", return_dict: bool = False, world_size: int = 1, @@ -156,9 +157,7 @@ def get_merged_sample_with_past_kv_inputs( assert isinstance(past_kv, dict) inputs.update(past_kv) - if use_fp16: # If model has GQA - del inputs["attention_mask"] - inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64) + if use_gqa: inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len) else: @@ -170,12 +169,13 @@ def get_merged_sample_with_past_kv_inputs( # Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx def get_msft_sample_inputs( - config: LlamaConfig, + config: AutoConfig, batch_size: int, past_seq_len: int, seq_len: int, max_seq_len: int, use_fp16: bool, + use_gqa: bool, split_kv: bool, ): np_dtype = np.float16 if use_fp16 else np.float32 @@ -213,8 +213,7 @@ def get_msft_sample_inputs( } ) - if use_fp16: # If model has GQA - del ort_inputs["attn_mask"] + if use_gqa: ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) return ort_inputs @@ -222,7 +221,7 @@ def get_msft_sample_inputs( # Create past_key_values # Each is of shape (batch_size, num_heads, past_sequence_length, head_size) -def get_past_kv_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1): +def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1): num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads torch_dtype = torch.float16 if use_fp16 else torch.float32 past_kv = [ @@ -247,8 +246,7 @@ def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tenso # Format PyTorch inputs to ONNX Runtime inputs def convert_inputs_for_ort( pt_inputs: dict, - use_fp16: bool, - use_buffer_share: bool = False, + use_gqa: bool = False, past_seq_len: int = 0, max_seq_len: int = 2048, device: str = "", @@ -260,17 +258,11 @@ def convert_inputs_for_ort( ort_inputs[k] = v elif k == "past_key_values": ort_inputs.update(flatten_past_kv_inputs(v)) - elif k == "attention_mask" and use_fp16 and use_buffer_share: - # Skip because FP16 model has GroupQueryAttention, uses buffer sharing, - # and GQA supports a causal mask by default - - # Instead, add the past sequence length input for GQA - ort_inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64) else: ort_inputs[k] = v.detach().cpu().numpy() - # Reshape kv caches if using past-present-share-buffer - if use_buffer_share and device != "" and device != "cpu" and device_id > -1: + # Reshape KV caches if using past-present-share-buffer + if use_gqa and device != "" and device != "cpu" and device_id > -1: ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len) return ort_inputs @@ -289,17 +281,14 @@ def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_se # Add IO bindings for execution providers -def add_io_bindings(model: InferenceSession, ort_inputs: dict, device: str, device_id: int, kv_cache_ortvalues: dict): - use_fp16 = False +def add_io_bindings( + model: InferenceSession, ort_inputs: dict, device: str, device_id: int, use_gqa: bool, kv_cache_ortvalues: dict +): io_binding = model.io_binding() for k, v in ort_inputs.items(): - # Detect if model is in FP16 - if v.dtype == np.float16: - use_fp16 = True - # Bind OrtValue inputs to device - if use_fp16 and ("cache" in k or "past_key_values" in k): + if use_gqa and ("cache" in k or "past_key_values" in k): if k not in kv_cache_ortvalues: v_device = OrtValue.ortvalue_from_numpy(v, device_type=device, device_id=device_id) io_binding.bind_ortvalue_input(k, v_device) @@ -313,7 +302,7 @@ def add_io_bindings(model: InferenceSession, ort_inputs: dict, device: str, devi for output in model.get_outputs(): name = output.name - if use_fp16 and ("out" in name or "present" in name): + if use_gqa and ("out" in name or "present" in name): # Bind present KV cache outputs to past KV cache inputs in order to buffer share input_name = name.replace("out", "cache").replace("present", "past_key_values") io_binding.bind_ortvalue_output(name, kv_cache_ortvalues[input_name]) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 42581caf3bb9..418a65325c8f 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -6,6 +6,7 @@ import numpy as np import torch +from benchmark_helper import setup_logger from dist_settings import get_rank, get_size from llama_inputs import ( add_io_bindings, @@ -15,10 +16,9 @@ get_sample_with_past_kv_inputs, ) from llama_torch import setup_torch_model -from transformers import LlamaConfig, LlamaForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM import onnxruntime as ort -from onnxruntime.transformers.benchmark_helper import setup_logger logger = logging.getLogger("") @@ -30,7 +30,7 @@ def get_sequence_lengths(args: argparse.Namespace): return past_sequence_length, curr_sequence_length, max_sequence_length -def get_inputs(args: argparse.Namespace, config: LlamaConfig): +def get_inputs(args: argparse.Namespace, config: AutoConfig): # Dummy values for parity world_size = get_size() batch_size = 2 @@ -45,6 +45,7 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig): past_seq_len=past_sequence_length, max_seq_len=max_sequence_length, use_fp16=args.use_fp16, + use_gqa=args.use_gqa, return_dict=True, world_size=world_size, ) @@ -64,7 +65,9 @@ def get_inputs(args: argparse.Namespace, config: LlamaConfig): return inputs -def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM, kv_cache_ortvalues: dict): +def verify_parity( + args: argparse.Namespace, config: AutoConfig, pt_model: AutoModelForCausalLM, kv_cache_ortvalues: dict +): inputs = get_inputs(args, config) # Run inference with PyTorch @@ -82,8 +85,7 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama past_sequence_length, _, max_sequence_length = get_sequence_lengths(args) inputs = convert_inputs_for_ort( inputs, - use_fp16=args.use_fp16, - use_buffer_share=args.use_fp16, + use_gqa=args.use_gqa, past_seq_len=past_sequence_length, max_seq_len=max_sequence_length, device=args.execution_provider, @@ -102,7 +104,12 @@ def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: Llama # Add IO bindings for non-CPU execution providers if args.execution_provider != "cpu": io_binding, kv_cache_ortvalues = add_io_bindings( - ort_model, inputs, args.execution_provider, int(args.rank), kv_cache_ortvalues + ort_model, + inputs, + args.execution_provider, + int(args.rank), + args.use_gqa, + kv_cache_ortvalues, ) io_binding.synchronize_inputs() @@ -183,6 +190,14 @@ def get_args(argv: List[str]): ) parser.set_defaults(use_past_kv=False) + parser.add_argument( + "-g", + "--use_gqa", + action="store_true", + help="Use if model has GroupQueryAttention", + ) + parser.set_defaults(use_gqa=False) + parser.add_argument( "--merged", action="store_true", diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py index cf6406dde5be..94e0397116d1 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_torch.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_torch.py @@ -3,7 +3,7 @@ import torch from dist_settings import barrier, get_rank, get_size -from transformers import LlamaConfig, LlamaForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM logger = logging.getLogger("") @@ -19,9 +19,9 @@ def setup_torch_model(args, location, use_auth_token, torch_dtype=torch.float32, for i in range(world_size): if i == rank % (world_size): - l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir) + l_config = AutoConfig.from_pretrained(location, use_auth_token=use_auth_token, cache_dir=args.cache_dir) l_config.use_cache = True - llama = LlamaForCausalLM.from_pretrained( + llama = AutoModelForCausalLM.from_pretrained( location, use_auth_token=use_auth_token, config=l_config, diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 319fed87dc9e..99f62ffdb9f5 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -52,6 +52,25 @@ def __init__(self, b, s, s2, sp, n, n2, h): self.head_size = h +class PromptConfig: + batch_size = 0 + q_sequence_length = 0 + kv_sequence_length = 0 + buffer_sequence_length = 0 + num_heads = 0 + kv_num_heads = 0 + head_size = 0 + + def __init__(self, b, sq, skv, sb, n, n2, h): + self.batch_size = b + self.q_sequence_length = sq + self.kv_sequence_length = skv + self.buffer_sequence_length = sb + self.num_heads = n + self.kv_num_heads = n2 + self.head_size = h + + def create_packed_multihead_attention_graph(config): nodes = [ helper.make_node( @@ -164,7 +183,9 @@ def create_multihead_attention_graph(config): return model.SerializeToString() -def create_group_query_attention_graph_no_past(config, causal=False, present_kv_format=Formats.BSNH): +def create_group_query_attention_graph_prompt(config, past_kv_format=Formats.BSNH, share_buffer=True): + past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 + present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length nodes = [ helper.make_node( "GroupQueryAttention", @@ -172,13 +193,17 @@ def create_group_query_attention_graph_no_past(config, causal=False, present_kv_ "query", "key", "value", + "past_key" if share_buffer else "", + "past_value" if share_buffer else "", + "seqlens_k", + "total_sequence_length", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, - unidirectional=1 if causal else 0, - is_past_bsnh=1 if present_kv_format == Formats.BSNH else 0, + # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, + # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", ), ] @@ -189,7 +214,7 @@ def create_group_query_attention_graph_no_past(config, causal=False, present_kv_ TensorProto.FLOAT16, [ config.batch_size, - config.sequence_length, + config.q_sequence_length, config.num_heads * config.head_size, ], ), @@ -211,21 +236,54 @@ def create_group_query_attention_graph_no_past(config, causal=False, present_kv_ config.kv_num_heads * config.head_size, ], ), + helper.make_tensor_value_info( + "seqlens_k", + TensorProto.INT32, + [config.batch_size], + ), + helper.make_tensor_value_info( + "total_sequence_length", + TensorProto.INT32, + [1], + ), ] + if share_buffer: + graph_input += [ + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "past_value", + TensorProto.FLOAT16, + [ + config.batch_size, + past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, + config.head_size, + ], + ), + ] graph_output = [ helper.make_tensor_value_info( "output", TensorProto.FLOAT16, - [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size], ), helper.make_tensor_value_info( "present_key", TensorProto.FLOAT16, [ config.batch_size, - config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, config.head_size, ], ), @@ -234,8 +292,8 @@ def create_group_query_attention_graph_no_past(config, causal=False, present_kv_ TensorProto.FLOAT16, [ config.batch_size, - config.kv_sequence_length if present_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if present_kv_format == Formats.BSNH else config.kv_sequence_length, + present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, config.head_size, ], ), @@ -252,10 +310,10 @@ def create_group_query_attention_graph_no_past(config, causal=False, present_kv_ return model.SerializeToString() -def create_group_query_attention_graph_past(config, causal=False, past_kv_format=Formats.BSNH, share_buffer=True): - past_kv_seqlen = config.kv_sequence_length if share_buffer else config.past_sequence_length +def create_group_query_attention_graph_past(config, past_kv_format=Formats.BSNH, share_buffer=True): + past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( - config.kv_sequence_length if share_buffer else config.past_sequence_length + config.sequence_length + config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length ) nodes = [ helper.make_node( @@ -266,14 +324,15 @@ def create_group_query_attention_graph_past(config, causal=False, past_kv_format "value", "past_key", "past_value", - "past_sequence_length" if share_buffer else "", + "seqlens_k", + "total_sequence_length", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, - unidirectional=1 if causal else 0, - is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, + # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, + # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", ), ] @@ -327,14 +386,18 @@ def create_group_query_attention_graph_past(config, causal=False, past_kv_format ], ), ] - if share_buffer: - graph_input += [ - helper.make_tensor_value_info( - "past_sequence_length", - TensorProto.INT32, - [1], - ) - ] + graph_input += [ + helper.make_tensor_value_info( + "seqlens_k", + TensorProto.INT32, + [config.batch_size], + ), + helper.make_tensor_value_info( + "total_sequence_length", + TensorProto.INT32, + [1], + ), + ] graph_output = [ helper.make_tensor_value_info( @@ -573,26 +636,78 @@ def mha_func(q, k, v, config): return output -def gqa_no_past_func(q, k, v, config, causal=True, present_kv_format=Formats.BSNH): - onnx_model_str = create_group_query_attention_graph_no_past(config, causal, present_kv_format=present_kv_format) - q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) - k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) - v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) - ort_inputs = { - "query": q.detach().cpu().numpy(), - "key": k.detach().cpu().numpy(), - "value": v.detach().cpu().numpy(), - } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) - ort_output, _, _ = ort_session.run(None, ort_inputs) - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output +def gqa_prompt_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True): + onnx_model_str = create_group_query_attention_graph_prompt(config, past_kv_format, share_buffer) + q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) + past_k = k.clone() if share_buffer else None + past_v = v.clone() if share_buffer else None + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + if share_buffer: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "key": new_k.detach().cpu().numpy(), + "value": new_v.detach().cpu().numpy(), + "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), + "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + io_binding = ort_session.io_binding() + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + io_binding.bind_input( + "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() + ) + io_binding.bind_input( + "past_value", + "cuda", + 0, + numpy.float16, + ort_inputs["past_value"].shape(), + ort_inputs["past_value"].data_ptr(), + ) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) + io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v + else: + ort_inputs = { + "query": q.detach().cpu().numpy(), + "key": new_k.detach().cpu().numpy(), + "value": new_v.detach().cpu().numpy(), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + io_binding = ort_session.io_binding() + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + io_binding.bind_output("output") + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + ort_output = numpy.array(ort_output) + output = torch.tensor(ort_output) + return output, present_k, present_v -def gqa_past_func(q, k, v, config, new_k, new_v, past_kv_format=Formats.BSNH, causal=False, share_buffer=True): - onnx_model_str = create_group_query_attention_graph_past(config, causal, past_kv_format, share_buffer) +def gqa_past_func(q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True): + onnx_model_str = create_group_query_attention_graph_past(config, past_kv_format, share_buffer) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() @@ -605,7 +720,8 @@ def gqa_past_func(q, k, v, config, new_k, new_v, past_kv_format=Formats.BSNH, ca "value": new_v.detach().cpu().numpy(), "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), - "past_sequence_length": torch.tensor([config.past_sequence_length], dtype=torch.int32) + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([config.kv_sequence_length], dtype=torch.int32) .detach() .cpu() .numpy(), @@ -627,7 +743,8 @@ def gqa_past_func(q, k, v, config, new_k, new_v, past_kv_format=Formats.BSNH, ca ort_inputs["past_value"].shape(), ort_inputs["past_value"].data_ptr(), ) - io_binding.bind_cpu_input("past_sequence_length", ort_inputs["past_sequence_length"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) io_binding.bind_output("output") io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) @@ -643,6 +760,13 @@ def gqa_past_func(q, k, v, config, new_k, new_v, past_kv_format=Formats.BSNH, ca "value": new_v.detach().cpu().numpy(), "past_key": past_k.detach().cpu().numpy(), "past_value": past_v.detach().cpu().numpy(), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor( + [config.kv_sequence_length + config.sequence_length], dtype=torch.int32 + ) + .detach() + .cpu() + .numpy(), } sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) @@ -652,6 +776,8 @@ def gqa_past_func(q, k, v, config, new_k, new_v, past_kv_format=Formats.BSNH, ca io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) io_binding.bind_output("output") io_binding.bind_output("present_key") io_binding.bind_output("present_value") @@ -829,15 +955,15 @@ def parity_check_mha( ) -def parity_check_gqa_no_past( +def parity_check_gqa_prompt( config, - causal=False, + past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, ): q = torch.randn( config.batch_size, - config.sequence_length, + config.q_sequence_length, config.num_heads, config.head_size, device="cuda", @@ -845,6 +971,24 @@ def parity_check_gqa_no_past( requires_grad=False, ) k = torch.randn( + config.batch_size, + config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_k = torch.randn( config.batch_size, config.kv_sequence_length, config.kv_num_heads, @@ -853,7 +997,7 @@ def parity_check_gqa_no_past( dtype=torch.float16, requires_grad=False, ) - v = torch.randn( + new_v = torch.randn( config.batch_size, config.kv_sequence_length, config.kv_num_heads, @@ -862,23 +1006,157 @@ def parity_check_gqa_no_past( dtype=torch.float16, requires_grad=False, ) + # Pytorch to compare - out_ref, _ = attention_ref(q, k, v, None, None, 0.0, None, causal=causal) + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) + # cache_seqlens = torch.randint( + # 0, + # config.kv_sequence_length, + # (config.batch_size,), + # dtype=torch.int32, + # device="cuda", + # ) + # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") + arange = rearrange(torch.arange(config.buffer_sequence_length, device="cuda"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + kv_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) + kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") + update_mask = arange < kv_seqlens_expanded + k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = arange < cache_seqlens_expanded + out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + # Flash function - out = gqa_no_past_func(q, k, v, config, causal=causal) + out, present_k, present_v = gqa_prompt_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, True) out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # Make sure past-present buffer updating correctly + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + # Compare results print( - " causal:", - causal, + "KV-buffer", + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", config.batch_size, " S:", - config.sequence_length, + config.q_sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + numpy.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +def parity_check_gqa_prompt_no_buff( + config, + past_format=Formats.BSNH, + rtol=1e-3, + atol=1e-3, +): + q = torch.randn( + config.batch_size, + config.q_sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + + # Pytorch to compare + k_cache_ref = new_k.clone() + v_cache_ref = new_v.clone() + # if past_format == Formats.BNSH: + # k_cache_ref = k_cache_ref.transpose(1, 2) + # v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) + # cache_seqlens = torch.randint( + # 0, + # config.kv_sequence_length, + # (config.batch_size,), + # dtype=torch.int32, + # device="cuda", + # ) + # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + brange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + new_mask = brange < cache_seqlens_expanded + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + out, present_k, present_v = gqa_prompt_func(q, None, None, config, new_k, new_v, cache_seqlens, past_format, False) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Make sure past-present buffer updating correctly + assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) + + # Compare results + print( + "KV-buffer", + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " B:", + config.batch_size, + " S:", + config.q_sequence_length, " kv S:", config.kv_sequence_length, " N:", @@ -901,7 +1179,6 @@ def parity_check_gqa_no_past( def parity_check_gqa_past( config, - causal=False, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -958,7 +1235,14 @@ def parity_check_gqa_past( if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) - cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) + # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) + cache_seqlens = torch.randint( + 0, + config.kv_sequence_length - config.sequence_length + 1, + (config.batch_size,), + dtype=torch.int32, + device="cuda", + ) arange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( @@ -969,14 +1253,14 @@ def parity_check_gqa_past( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal) + out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, past_format, causal, True) + out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, True) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -990,8 +1274,6 @@ def parity_check_gqa_past( "KV-buffer", "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", - " causal:", - causal, " B:", config.batch_size, " S:", @@ -1018,7 +1300,135 @@ def parity_check_gqa_past( def parity_check_gqa_past_no_buff( config, - causal=False, + past_format=Formats.BSNH, + rtol=1e-3, + atol=1e-3, +): + torch.manual_seed(69) + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, + config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_k = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + new_v = torch.randn( + config.batch_size, + config.sequence_length, + config.kv_num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + + # Pytorch to compare + k_cache_ref = k.clone() + v_cache_ref = v.clone() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + k_cache_ref = torch.cat((k_cache_ref, new_k), 1) + v_cache_ref = torch.cat((v_cache_ref, new_v), 1) + # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) + cache_seqlens = torch.randint( + 0, + config.kv_sequence_length, + (config.batch_size,), + dtype=torch.int32, + device="cuda", + ) + cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cuda"), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length + ) + k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") + k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) + key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length + out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) + out_ref = out_ref.detach().cpu().numpy() + if past_format == Formats.BNSH: + k_cache_ref = k_cache_ref.transpose(1, 2) + v_cache_ref = v_cache_ref.transpose(1, 2) + + # Flash function + out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, cache_seqlens, past_format, False) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + + # Make sure past-present buffer updating correctly + # assert numpy.allclose( + # present_k[:, :, :-1, :], k_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True + # ) + # assert numpy.allclose( + # present_v[:, :, :-1, :], v_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True + # ) + + # Compare results + print( + "NO buff", + "past kv format:", + "BSNH" if past_format == Formats.BSNH else "BNSH", + " B:", + config.batch_size, + " S:", + config.sequence_length, + " kv S:", + config.kv_sequence_length, + " N:", + config.num_heads, + " kv N:", + config.kv_num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + numpy.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +def parity_check_gqa_past_no_buff_no_mask( + config, past_format=Formats.BSNH, rtol=1e-3, atol=1e-3, @@ -1080,14 +1490,14 @@ def parity_check_gqa_past_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = None - out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=causal) + out_ref, _ = attention_ref(q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, past_format, causal, False) + out, present_k, present_v = gqa_past_func(q, k, v, config, new_k, new_v, past_format, False) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1129,8 +1539,6 @@ def parity_check_gqa_past_no_buff( "Unbuffered", "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", - " causal:", - causal, " B:", config.batch_size, " S:", @@ -1214,46 +1622,51 @@ def test_gqa_no_past(self): return major, minor = torch.cuda.get_device_capability() torch.manual_seed(69) - print("-------- TEST GQA ---------") - batches = [2] if pipeline_mode else [1, 5] + print("-------- TEST GQA NO PAST (PROMPT CASE) ---------") + batches = [3] if pipeline_mode else [1, 3, 5] seqs = ( - [(1, 128), (113, 211), (2048, 2048)] + [ + (127, 127), + (35, 35), + (2000, 2000), + (200, 200), + (240, 240), + ] if pipeline_mode else [ - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (1024, 1024), - (1023, 1024), - (2048, 2048), + (127, 127), + (35, 35), + (2000, 2000), + (200, 200), + (240, 240), ] ) - num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 32), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] if major < 5 or (major == 5 and minor < 3): return - print("------- MEMORY EFFICIENT ATTENTION ---------") + print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" for b in batches: - for s, s2 in seqs: + for sq, skv in seqs: for n, n2 in num_h: for h in h_sizes: - for causal in [True, False]: - config = Config(b, s, s2, 0, n, n2, h) - parity_check_gqa_no_past(config, causal=causal) + for past_kv_format in [Formats.BNSH]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt(config, past_format=past_kv_format) + parity_check_gqa_prompt_no_buff(config, past_format=past_kv_format) if major < 8 or platform.system() != "Linux": return - print("------- FLASH ATTENTION --------") + print("------- FLASH ATTENTION (PROMPT CASE) --------") os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" for b in batches: - for s, s2 in seqs: + for sq, skv in seqs: for n, n2 in num_h: for h in h_sizes: - for causal in [True, False]: - config = Config(b, s, s2, 0, n, n2, h) - parity_check_gqa_no_past(config, causal=causal) + for past_kv_format in [Formats.BNSH]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt(config, past_format=past_kv_format) + parity_check_gqa_prompt_no_buff(config, past_format=past_kv_format) def test_gqa_past(self): if not torch.cuda.is_available(): @@ -1262,78 +1675,74 @@ def test_gqa_past(self): if major < 5 or (major == 5 and minor < 3): return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("-------- TEST GQA PAST ---------") - print("-------- MEMORY EFFICEINT --------") - batches = [2] if pipeline_mode else [1, 2] + print("-------- TEST GQA PAST (TOKEN GEN) ---------") + print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") + batches = [5] if pipeline_mode else [1, 3, 5] seqs = ( - [(1, 128), (3, 1024), (64, 2048)] + [(1, 128), (1, 1024), (1, 2048)] if pipeline_mode else [ (1, 128), (1, 339), - (3, 1024), - (64, 800), - (64, 256), - (3, 799), - (64, 2048), - (16, 20000), - (1, 128 * 512), - (16, 128 * 512), - (128, 128), + (1, 1024), + (1, 5000), + (1, 800), + (1, 256), + (1, 799), + (1, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), ] ) - num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + num_h = [(32, 32), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) for b in batches: for s, s2 in seqs: for n, n2 in num_h: for h in h_sizes: - for causal in [True]: - for past_kv_format in [Formats.BNSH, Formats.BSNH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - causal=causal, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - causal=causal, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + for past_kv_format in [Formats.BNSH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) if major < 8 or platform.system() != "Linux": return - print("------- FLASH ATTENTION -------") + print("------- FLASH ATTENTION (TOKEN GEN) -------") os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" for b in batches: for s, s2 in seqs: for n, n2 in num_h: for h in h_sizes: - for causal in [True]: - for past_kv_format in [Formats.BNSH, Formats.BSNH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - causal=causal, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - causal=causal, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + for past_kv_format in [Formats.BNSH]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) + parity_check_gqa_past_no_buff( + config, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + ) if __name__ == "__main__": unittest.main() + # test_gqa = TestGQA() + # test_gqa.test_gqa_past()