Skip to content

Commit

Permalink
Cherry-pick LLaMA GQA mask to rel-1.16.2 (round 4) (#18350)
Browse files Browse the repository at this point in the history
Cherry-pick LLaMA GQA attention mask and script changes to 1.16.2 release branch.

---------
Co-authored-by: aciddelgado <139922440+aciddelgado@users.noreply.github.com>
Co-authored-by: Yufeng Li <liyufeng1987@gmail.com>
Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
  • Loading branch information
tianleiwu committed Nov 8, 2023
1 parent 8f06330 commit 0c5b95f
Show file tree
Hide file tree
Showing 22 changed files with 1,306 additions and 463 deletions.
26 changes: 12 additions & 14 deletions docs/ContribOperators.md
Expand Up @@ -2236,19 +2236,15 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes

<dl>
<dt><tt>is_past_bsnh</tt> : int</dt>
<dd>Whether past kv uses BSNH, otherwise BNSH. Default value is 1 (BSNH).</dd>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for q</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>unidirectional</tt> : int</dt>
<dd>Whether every token can only attend to previous tokens. Default value is 1.</dd>
</dl>

#### Inputs (3 - 6)
#### Inputs

<dl>
<dt><tt>query</tt> : T</dt>
Expand All @@ -2258,11 +2254,13 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>value</tt> : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>past state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
<dd>past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
<dt><tt>past_value</tt> (optional) : T</dt>
<dd>past state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
<dt><tt>past_sequence_length</tt> (optional) : M</dt>
<dd>When buffered past_key and past_value is used (present_key uses same tensor as past_key), requiredto specify past_sequence_length (could be 0). Otherwise, past_sequence_length inferred from past_key.</dd>
<dd>past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
<dt><tt>seqlens_k</tt> : M</dt>
<dd>1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.</dd>
<dt><tt>total_sequence_length</tt> : M</dt>
<dd>Scalar tensor of total sequence length (past + new).</dd>
</dl>

#### Outputs
Expand All @@ -2271,18 +2269,18 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>present_key</tt> : T</dt>
<dd>present state key with support for format BSNH or BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
<dd>present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
<dt><tt>present_value</tt> : T</dt>
<dd>present state value with support for format BSNH or BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
<dd>present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>M</tt> : tensor(int32), tensor(int64)</dt>
<dd>Constrain past sequence length to int tensor.</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
<dd>Constrain mask to int tensor.</dd>
</dl>


Expand Down Expand Up @@ -4766,7 +4764,7 @@ This version of the operator has been available since version 1 of the 'com.micr

### <a name="com.microsoft.RotaryEmbedding"></a><a name="com.microsoft.rotaryembedding">**com.microsoft.RotaryEmbedding**</a>

RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices
RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices
that are multiplied to query and key before the inner product of query and key is taken.

#### Version
Expand Down
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
Expand Up @@ -843,7 +843,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32), tensor(int64)<br/> **T** = tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
Expand Down
11 changes: 6 additions & 5 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
Expand Up @@ -86,18 +86,19 @@ struct PackedAttentionParameters {
// Parameters deduced from node attributes and inputs/outputs.
struct GroupQueryAttentionParameters {
int batch_size;
int sequence_length;
int past_sequence_length; // actual sequence length of past_key and past_value
int kv_sequence_length; // sequence length of key and value (or new_k and new_v when past is present)
int present_sequence_length; // past_sequence_length + kv_sequence_length
int max_sequence_length; // allocated length of past_key and past_value
int sequence_length; // sequence length of input query, key, value
int seqlen_past_kv_cache; // sequence length of past kv tensor
int seqlen_present_kv_cache; // sequence length of present kv tensor
int hidden_size;
int num_heads;
int head_size;
int kv_hidden_size;
int kv_num_heads;
int num_splits; // number of splits for splitkv
bool is_unidirectional; // causal
bool kv_share_buffer;
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
bool left_padding; // copies last token to last index if true
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
Expand Up @@ -401,6 +401,7 @@ Status EfficientAttention(
? data.scratch
: nullptr;
p.stream = stream;
p.has_custom_right_padding = false;
run_memory_efficient_attention(p);
DUMP_TENSOR("efficient attention output", data.output,
parameters.batch_size, parameters.sequence_length, parameters.num_heads, parameters.v_head_size);
Expand Down
133 changes: 132 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
Expand Up @@ -16,6 +16,133 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

template <typename AttentionKernel, int kQueriesPerBlock>
struct RightPaddingBatchHook {
using scalar_t = typename AttentionKernel::scalar_t;
using accum_t = typename AttentionKernel::accum_t;
using lse_scalar_t = typename AttentionKernel::lse_scalar_t;
using output_t = typename AttentionKernel::output_t;
using output_accum_t = typename AttentionKernel::output_accum_t;

static constexpr bool kSupportsDropout = AttentionKernel::kSupportsDropout;
static constexpr bool kSupportsBias = AttentionKernel::kSupportsBias;
static constexpr int kKeysPerBlock = AttentionKernel::kKeysPerBlock;
static constexpr bool kIsAligned = AttentionKernel::kIsAligned;
static constexpr bool kSingleValueIteration = AttentionKernel::kSingleValueIteration;
static constexpr int32_t kAlignLSE = AttentionKernel::kAlignLSE; // block size of backward
static constexpr bool kPreloadV = AttentionKernel::kPreloadV;
static constexpr bool kKeepOutputInRF = AttentionKernel::kKeepOutputInRF;
static constexpr bool kNeedsOutputAccumulatorBuffer = AttentionKernel::kNeedsOutputAccumulatorBuffer;

template <typename Params>
static CUTLASS_DEVICE bool AdvanceToBlockForGQA(Params& p) {
auto batch_id = blockIdx.z;
auto head_id = blockIdx.y;
auto query_start = blockIdx.x * kQueriesPerBlock;

auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE;

// Advance to current batch - in case of different sequence lengths
if (p.seqlen_k_ptr) {
p.num_keys = p.seqlen_k_ptr[batch_id];
}

if (query_start >= p.num_queries) {
return false;
}

// Advance to the current batch / head / query_start
p.query_ptr += batch_id * p.q_strideB + query_start * p.q_strideM + head_id * p.q_strideH;
p.key_ptr += batch_id * p.k_strideB + head_id * p.k_strideH;
p.value_ptr += batch_id * p.v_strideB + head_id * p.v_strideH;
p.output_ptr += int64_t(batch_id * p.num_queries) * p.o_strideM + int64_t(query_start) * p.o_strideM + head_id * p.head_dim_value;

if (kSupportsBias && p.attn_bias_ptr != nullptr) {
p.attn_bias_ptr += (batch_id * p.bias_strideB) + (head_id * p.bias_strideH);
}
if (p.output_accum_ptr != nullptr) {
p.output_accum_ptr += int64_t(batch_id * p.num_queries) * (p.head_dim_value * p.num_heads) +
int64_t(query_start) * (p.head_dim_value * p.num_heads) +
head_id * p.head_dim_value;
} else {
// Accumulate directly in the destination buffer (eg for f32)
p.output_accum_ptr = (accum_t*)(p.output_ptr);
}

if (p.logsumexp_ptr != nullptr) {
// lse[batch_id, head_id, query_start]
p.logsumexp_ptr +=
batch_id * lse_dim * p.num_heads + head_id * lse_dim + query_start;
}

// Custom masking
if (p.causal_diagonal_ptr) {
p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id];
}
if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
p.causal_diagonal_offset += p.num_keys - p.num_queries;
}
if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft ||
p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
// the bottom row of the current block is query_start + kQueriesPerBlock
// the last active key is then query_start + causal_diagonal_offset +
// kQueriesPerBlock so num_keys is the min between actual num_keys and
// this to avoid extra computations
p.num_keys = cutlass::fast_min(
int32_t(query_start + p.causal_diagonal_offset + kQueriesPerBlock),
p.num_keys);
}

p.num_queries -= query_start;
p.num_batches = 0; // no longer used after

// If num_queries == 1, and there is only one key head we're wasting
// 15/16th of tensor core compute In that case :
// - we only launch kernels for head_id % kQueriesPerBlock == 0
// - we iterate over heads instead of queries (strideM = strideH)
if (p.num_queries == 1 && p.k_strideH == 0 && p.v_strideH == 0) {
if (head_id % kQueriesPerBlock != 0)
return false;
p.q_strideM = p.q_strideH;
p.num_queries = p.num_heads;
p.num_heads = 1; // unused but here for intent
// remove causal since n_query = 1
// otherwise, offset would change with head !
p.custom_mask_type = AttentionKernel::NoCustomMask;
p.o_strideM = p.head_dim_value;
}

// Make sure the compiler knows these variables are the same on all
// the threads of the warp.
p.query_ptr = warp_uniform(p.query_ptr);
p.key_ptr = warp_uniform(p.key_ptr);
p.value_ptr = warp_uniform(p.value_ptr);
if (kSupportsBias) {
p.attn_bias_ptr = warp_uniform(p.attn_bias_ptr);
}
p.output_ptr = warp_uniform(p.output_ptr);
p.output_accum_ptr = warp_uniform(p.output_accum_ptr);
p.logsumexp_ptr = warp_uniform(p.logsumexp_ptr);
p.num_queries = warp_uniform(p.num_queries);
p.num_keys = warp_uniform(p.num_keys);
p.num_heads = warp_uniform(p.num_heads);
p.head_dim = warp_uniform(p.head_dim);
p.head_dim_value = warp_uniform(p.head_dim_value);
p.o_strideM = warp_uniform(p.o_strideM);
p.custom_mask_type = warp_uniform(p.custom_mask_type);
return true;
}
};

template <typename AK, int kQueriesPerBlock>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched_impl_right_padding(typename AK::Params p) {
if (!RightPaddingBatchHook<AK, kQueriesPerBlock>::AdvanceToBlockForGQA(p)) {
return;
}
AK::attention_kernel(p);
}

template <typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block, bool single_value_iteration>
void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block, single_value_iteration>;
Expand Down Expand Up @@ -92,7 +219,11 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
}
}

constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
auto kernel_fn = attention_kernel_batched_impl<Attention>;
if (params.has_custom_right_padding) {
kernel_fn = attention_kernel_batched_impl_right_padding<Attention, queries_per_block>;
}

int smem_bytes = sizeof(typename Attention::SharedStorage);
if (smem_bytes > 0xc000) {
ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!");
Expand Down
Expand Up @@ -43,6 +43,8 @@ struct MemoryEfficientAttentionParams {
static bool need_workspace(size_t v_head_size, bool is_float) {
return (v_head_size > 128 && !is_float);
}

bool has_custom_right_padding = false;
};

void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params);
Expand Down

0 comments on commit 0c5b95f

Please sign in to comment.