[t5 optimization] kernel changes to t5 (#14928)
### Description
<!-- Describe your changes. -->

1. Support an optional bias input in the Attention op (used in the T5 encoder).
2. Support broadcasting of rel_pos_bias in attention_softmax.h.
3. Add a scale attribute to the MHA op.
4. Support past_key/past_value and present_key/present_value in MHA (a usage sketch follows the lists below).
5. Add unit tests and parity tests.
6. Fix issue #14920.

Note: the fusions will come in a separate PR, since MT5 still needs to be tested and a GitHub issue needs to be investigated.

Future work:
1. Support a shared buffer for past/present.
2. Enable TRT kernels when possible, and investigate TRT/cutlass kernels with rel_pos_bias.
3. Support KV/QKV packing with past/present.
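
Items 3 and 4 in a nutshell: a rough, hedged sketch of how a com.microsoft MultiHeadAttention node could be wired up with the new optional past_key/past_value inputs, the new present_key/present_value outputs, and the new scale attribute. Tensor names and sizes are illustrative placeholders, not part of this PR.

```python
# Sketch only: names, sizes, and the scale value are made-up examples.
from onnx import helper

mha_node = helper.make_node(
    "MultiHeadAttention",
    # Input order per ContribOperators.md: query, key, value, bias,
    # key_padding_mask, relative_position_bias, past_key, past_value.
    # Empty strings skip optional inputs that are not used here.
    inputs=["query", "key", "value", "", "", "rel_pos_bias", "past_key", "past_value"],
    outputs=["output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=8,
    scale=0.125,  # optional; defaults to 1/sqrt(head_size) when omitted
)
```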

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
wangyems and Ubuntu committed Mar 13, 2023
1 parent b34e570 commit 538d648
Showing 19 changed files with 1,345 additions and 152 deletions.
22 changes: 16 additions & 6 deletions docs/ContribOperators.md
@@ -144,14 +144,14 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
</dl>

#### Inputs (3 - 7)
#### Inputs (2 - 7)

<dl>
<dt><tt>input</tt> : T</dt>
<dd>Input tensor with shape (batch_size, sequence_length, input_hidden_size)</dd>
<dt><tt>weights</tt> : T</dt>
<dd>Merged Q/K/V weights with shape (input_hidden_size, hidden_size + hidden_size + v_hidden_size)</dd>
<dt><tt>bias</tt> : T</dt>
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) for input projection</dd>
<dt><tt>mask_index</tt> (optional) : M</dt>
<dd>Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size)</dd>
@@ -2381,30 +2381,40 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>The value to be filled in the attention mask. Default value is -10000.0f</dd>
<dt><tt>num_heads</tt> : int (required)</dt>
<dd>Number of attention heads</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
</dl>

#### Inputs (1 - 6)
#### Inputs (1 - 8)

<dl>
<dt><tt>query</tt> : T</dt>
<dd>Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)</dd>
<dt><tt>key</tt> (optional) : T</dt>
<dd>Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)</dd>
<dd>Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size), or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size)</dd>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, v_hidden_size)</dd>
<dd>Value with shape (batch_size, kv_sequence_length, v_hidden_size), or past_value with shape (batch_size, num_heads, kv_sequence_length, head_size)</dd>
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
<dd>Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)</dd>
<dt><tt>relative_position_bias</tt> (optional) : T</dt>
<dd>relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dd>past state for self attention key with shape (batch_size, num_heads, past_sequence_length, head_size)</dd>
<dt><tt>past_value</tt> (optional) : T</dt>
<dd>past state for self attention value with shape (batch_size, num_heads, past_sequence_length, head_size)</dd>
</dl>

#### Outputs
#### Outputs (1 - 3)

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
<dt><tt>present_key</tt> (optional) : T</dt>
<dd>present state for cross attention key with shape (batch_size, num_heads, kv_sequence_length, head_size) or present state for self attention key with shape (batch_size, num_heads, total_sequence_length, head_size)</dd>
<dt><tt>present_value</tt> (optional) : T</dt>
<dd>present state for cross attention value with shape (batch_size, num_heads, kv_sequence_length, head_size) or present state for self attention value with shape (batch_size, num_heads, total_sequence_length, head_size)</dd>
</dl>

#### Type Constraints
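
The updated MHA documentation above implies the following shape bookkeeping for self attention with a KV cache. This is a minimal numpy sketch with made-up sizes, not code from this PR.

```python
import numpy as np

batch_size, num_heads, head_size = 2, 8, 64
past_sequence_length, kv_sequence_length = 5, 1  # e.g. one decoding step

past_key = np.zeros((batch_size, num_heads, past_sequence_length, head_size), np.float32)
new_key = np.zeros((batch_size, num_heads, kv_sequence_length, head_size), np.float32)

# present_key is the past state plus this step's keys, concatenated on dim 2,
# so its length is total_sequence_length = past_sequence_length + kv_sequence_length.
present_key = np.concatenate([past_key, new_key], axis=2)
assert present_key.shape == (batch_size, num_heads,
                             past_sequence_length + kv_sequence_length, head_size)
```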
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -813,7 +813,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**<br> *in* weight:**T2**<br> *in* bias:**T3**<br> *in* input_scale:**T3**<br> *in* weight_scale:**T3**<br> *in* mask_index:**T4**<br> *in* input_zero_point:**T1**<br> *in* weight_zero_point:**T2**<br> *in* past:**T3**<br> *out* output:**T3**<br> *out* present:**T3**|1+|**T1** = tensor(int8)<br/> **T2** = tensor(int8)<br/> **T3** = tensor(float), tensor(float16)<br/> **T4** = tensor(int32)|
10 changes: 8 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_base.cc
@@ -193,6 +193,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
}
}

bool broadcast_res_pos_bias = false;
if (relative_position_bias != nullptr) {
const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();

@@ -202,11 +203,14 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
relative_position_bias_dims.size());
}

if (relative_position_bias_dims[0] != batch_size) {
if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'relative_position_bias' dimension 0 should be same as batch_size, got ",
"Input 'relative_position_bias' dimension 0 should be same as batch_size or 1, got ",
relative_position_bias_dims[0]);
}
if (relative_position_bias_dims[0] == 1) {
broadcast_res_pos_bias = true;
}
if (relative_position_bias_dims[1] != num_heads_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'relative_position_bias' dimension 1 should be same as number of heads, got ",
@@ -255,6 +259,8 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
output_parameters->mask_filter_value = mask_filter_value_;
output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias;
output_parameters->pass_past_in_kv = false;
}

return Status::OK();
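
To make the new broadcast_res_pos_bias flag concrete: when relative_position_bias has batch dimension 1, it is shared across the batch when added to the QxK' scores before softmax. The numpy sketch below mirrors that behavior under made-up sizes; it is an illustration, not the attention_softmax.h kernel.

```python
import numpy as np

batch_size, num_heads, seq_len, total_seq_len = 2, 8, 4, 4
scores = np.random.rand(batch_size, num_heads, seq_len, total_seq_len).astype(np.float32)

# Dimension 0 may now be 1 (shared across the batch) instead of batch_size.
rel_pos_bias = np.random.rand(1, num_heads, seq_len, total_seq_len).astype(np.float32)

biased = scores + rel_pos_bias                     # broadcasts over dim 0
probs = np.exp(biased - biased.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)              # softmax over total_sequence_length
```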
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -52,6 +52,8 @@ struct AttentionParameters {
bool is_unidirectional;
bool past_present_share_buffer;
bool do_rotary;
bool broadcast_res_pos_bias;
bool pass_past_in_kv;
float mask_filter_value;
float scale;
AttentionMaskType mask_type;
126 changes: 108 additions & 18 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -18,16 +18,21 @@ Status CheckInputs(const T* query,
const T* bias,
const T* key_padding_mask,
const T* relative_position_bias,
const T* past_key,
const T* past_value,
void* parameters,
int num_heads,
float mask_filter_value,
float scale,
int max_threads_per_block) {
// key_padding_mask (K/V) : (B) or (B, L) or None
// relative_position_bias : (B, 1, S, L)
// past_key : (B, N, S*, H)
// past_value : (B, N, S*, H)
// When no packing for q/k/v:
// query (Q) : (B, S, D)
// key (K) : (B, L, D)
// value (V) : (B, L, D_v)
// key (K) : (B, L, D) or (B, N, S*, H)
// value (V) : (B, L, D_v) or (B, N, S*, H)
// bias (Q/K/V) : (D + D + D_v)
// When packed kv is used:
// query (Q) : (B, S, D)
@@ -40,6 +45,7 @@ Status CheckInputs(const T* query,
// value (V) : None
// bias (Q/K/V) : None


const auto& query_dims = query->Shape().GetDims();
if (query_dims.size() != 3 && query_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions, got ",
@@ -52,15 +58,72 @@ Status CheckInputs(const T* query,
int head_size = static_cast<int>(hidden_size) / num_heads;
int kv_sequence_length = sequence_length;

int past_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
const auto& past_key_dims = past_key->Shape().GetDims();
const auto& past_value_dims = past_value->Shape().GetDims();

if (past_key_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' is expected to have 4 dimensions, got ",
past_key_dims.size());
}
if (past_value_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' is expected to have 4 dimensions, got ",
past_value_dims.size());
}

if (past_key_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 0 should be batch_size, got ",
past_key_dims[0]);
}
if (past_value_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' dimension 0 should be batch_size, got ",
past_value_dims[0]);
}

if (past_key_dims[1] != num_heads) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 1 should be same as number of heads, got ",
past_key_dims[1]);
}
if (past_value_dims[1] != num_heads) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' dimension 1 should be same as number of heads, got ",
past_value_dims[1]);
}
if (past_key_dims[2] != past_value_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall have same dim 2 (past_sequence_length)");
}
if (past_key_dims[3] != head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' dimension 3 should be same as head_size, got ",
past_key_dims[3]);
}
if (past_value_dims[3] != head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_value' dimension 3 should be same as head_size, got ",
past_value_dims[3]);
}
past_sequence_length = static_cast<int>(past_key_dims[2]);
} else if (past_key != nullptr || past_value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall be both present or both absent");
}

if (key != nullptr) {
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions when key is given, got ",
query_dims.size());
}

const auto& key_dims = key->Shape().GetDims();
if (key_dims.size() != 3 && key_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ",
if (key_dims.size() != 3 && key_dims.size() != 4 && key_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3, 4, or 5 dimensions, got ",
key_dims.size());
}
if (query_dims[0] != key_dims[0]) {
@@ -73,8 +136,9 @@ Status CheckInputs(const T* query,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'query' and 'key' shall have same dim 2 (hidden_size)");
}
} else // if (key_dims.size() == 5)
{

kv_sequence_length = static_cast<int>(key_dims[1]);
} else if (key_dims.size() == 5) {
if (static_cast<int>(key_dims[2]) != num_heads || static_cast<int>(key_dims[3]) != 2 || static_cast<int>(key_dims[4]) != head_size) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
@@ -83,9 +147,17 @@ Status CheckInputs(const T* query,
if (value != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format.");
}
}

kv_sequence_length = static_cast<int>(key_dims[1]);
kv_sequence_length = static_cast<int>(key_dims[1]);
} else { // key_dims.size() == 4 (cross-attention with past_key)
if (static_cast<int>(key_dims[1]) != num_heads || static_cast<int>(key_dims[3]) != head_size) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
"Expect 'key' shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key");
}

kv_sequence_length = static_cast<int>(key_dims[2]);
}
} else { // packed QKV
if (query_dims.size() != 5) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 5 dimensions when key is empty, got ",
@@ -128,11 +200,13 @@ Status CheckInputs(const T* query,
}
}

// NOTE: In Cross-Attention, we pass the past key and value to 'key' and 'value' instead of 'past_key' and 'past_value'.
bool pass_past_in_kv = false;
int v_hidden_size = hidden_size;
if (value != nullptr) {
const auto& value_dims = value->Shape().GetDims();
if (value_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
if (value_dims.size() != 3 && value_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 or 4 dimensions, got ",
value_dims.size());
}

@@ -141,13 +215,24 @@ Status CheckInputs(const T* query,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}

if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have same same dim 1 (kv_sequence_length)");
if (value_dims.size() == 3) {
if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have the same dim 1 (kv_sequence_length)");
}
v_hidden_size = static_cast<int>(value_dims[2]);
} else { // value_dims.size() == 4
if (static_cast<int64_t>(kv_sequence_length) != value_dims[2]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'past_key' and 'past_value' shall have the same dim 2 (kv_sequence_length)");
}
v_hidden_size = static_cast<int>(value_dims[1]) * static_cast<int>(value_dims[3]);
pass_past_in_kv = true;
}
v_hidden_size = static_cast<int>(value_dims[2]);
}

int total_sequence_length = past_sequence_length + kv_sequence_length;
bool broadcast_res_pos_bias = false;
if (relative_position_bias != nullptr) {
const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();

@@ -161,6 +246,9 @@ Status CheckInputs(const T* query,
"Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ",
relative_position_bias_dims[0]);
}
if (relative_position_bias_dims[0] == 1) {
broadcast_res_pos_bias = true;
}
if (relative_position_bias_dims[1] != num_heads) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'relative_position_bias' dimension 1 should be same as number of heads, got ",
@@ -171,7 +259,7 @@ Status CheckInputs(const T* query,
"Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ",
relative_position_bias_dims[2]);
}
if (relative_position_bias_dims[3] != kv_sequence_length) {
if (relative_position_bias_dims[3] != total_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ",
relative_position_bias_dims[3]);
@@ -182,9 +270,9 @@ Status CheckInputs(const T* query,
AttentionParameters* output_parameters = reinterpret_cast<AttentionParameters*>(parameters);
output_parameters->batch_size = batch_size;
output_parameters->sequence_length = sequence_length;
output_parameters->past_sequence_length = 0;
output_parameters->past_sequence_length = past_sequence_length;
output_parameters->kv_sequence_length = kv_sequence_length;
output_parameters->total_sequence_length = kv_sequence_length;
output_parameters->total_sequence_length = total_sequence_length;
output_parameters->max_sequence_length = 0;
output_parameters->input_hidden_size = 0;
output_parameters->hidden_size = hidden_size;
Expand All @@ -196,7 +284,9 @@ Status CheckInputs(const T* query,
output_parameters->past_present_share_buffer = false;
output_parameters->mask_filter_value = mask_filter_value;
output_parameters->mask_type = mask_type;
output_parameters->scale = 0.0f;
output_parameters->scale = scale;
output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias;
output_parameters->pass_past_in_kv = pass_past_in_kv;
}

if (max_threads_per_block > 0 && num_heads > max_threads_per_block) {
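
The branching in CheckInputs above is easier to see outside the diff. The sketch below approximates (it is not a translation of the helper) how kv_sequence_length, total_sequence_length, v_hidden_size, and pass_past_in_kv fall out of the key/value ranks for the non-packed cases; the function name and shapes are hypothetical.

```python
def derive_lengths(key_shape, value_shape, past_key_shape=None):
    """Approximate shape bookkeeping from CheckInputs (non-packed Q/K/V only)."""
    past_sequence_length = past_key_shape[2] if past_key_shape is not None else 0
    pass_past_in_kv = False

    if len(key_shape) == 3:        # (B, L, D): regular key
        kv_sequence_length = key_shape[1]
    else:                          # (B, N, L, H): past key passed through 'key' (cross attention)
        kv_sequence_length = key_shape[2]
        pass_past_in_kv = True

    if len(value_shape) == 3:      # (B, L, D_v)
        v_hidden_size = value_shape[2]
    else:                          # (B, N, L, H): past value passed through 'value'
        v_hidden_size = value_shape[1] * value_shape[3]

    total_sequence_length = past_sequence_length + kv_sequence_length
    return kv_sequence_length, total_sequence_length, v_hidden_size, pass_past_in_kv


# Cross attention where the cached K/V are fed through 'key'/'value':
print(derive_lengths((2, 8, 10, 64), (2, 8, 10, 64)))  # -> (10, 10, 512, True)
```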
