Add RemovePadding and RestorePadding for BERT model #13701
Conversation
"output tensor with shape (total_tokens, hidden_size)",
"T")
.Output(1,
        "token_offset",
const auto& dims = input->Shape().GetDims();
if (dims.size() != 3) {
  return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is expected to have 3 dimensions, got ",
// total_token_count: 1 + 2 + 4 = 7
// max_token_count: 4
// cumulated_token_count: 0, 1, 1+2, 1+2+4
__global__ void getTokenOffset(int* token_count_buffer,
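To make the quantities in the comment concrete, here is a hypothetical CPU sketch (not the CUDA kernel itself) that computes the three values from per-sequence token counts; the struct and function names are illustrative only:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

struct TokenStats {
  int total;                   // sum of all non-padding tokens
  int max;                     // longest sequence in the batch
  std::vector<int> cumulated;  // exclusive prefix sum, length batch + 1
};

// Given the non-padding token count of each sequence in the batch,
// compute total/max/cumulated token counts in one pass.
TokenStats ComputeTokenStats(const std::vector<int>& token_count) {
  TokenStats s{0, 0, std::vector<int>(token_count.size() + 1, 0)};
  for (size_t i = 0; i < token_count.size(); ++i) {
    s.cumulated[i + 1] = s.cumulated[i] + token_count[i];
    s.total += token_count[i];
    s.max = std::max(s.max, token_count[i]);
  }
  return s;
}
```

For the example in the comment, token counts {1, 2, 4} give total 7, max 4, and cumulated counts 0, 1, 3, 7.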
For offset idx > token_size, we don't actually need to fill it, because restoring won't use it.
__global__ void getTokenOffset(int* token_count_buffer,
It can be implemented with cub::BlockScan. The kernel can be launched with Grid: 1, Block: batch_size. In the kernel:
- it first uses cub::BlockScan to compute cumulated_token_count.
- then each thread fills its token_offset.
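A CPU sketch of the suggested two-step approach (the real kernel would use cub::BlockScan::ExclusiveSum inside a single CUDA block; the loop over `b` stands in for one thread per batch entry, and the layout of the padding offsets past total_tokens is an assumption here):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// token_offset layout assumed: the first `total` entries map each packed
// token to its flattened (b * sequence_length + s) position; the remaining
// entries hold the flattened positions of the padding tokens.
std::vector<int> GetTokenOffset(const std::vector<int>& token_count,
                                int sequence_length) {
  int batch = static_cast<int>(token_count.size());
  // Step 1: exclusive prefix sum (what cub::BlockScan would compute on GPU).
  std::vector<int> cumulated(batch + 1, 0);
  for (int b = 0; b < batch; ++b)
    cumulated[b + 1] = cumulated[b] + token_count[b];
  int total = cumulated[batch];

  // Step 2: each "thread" b fills the offsets of its own sequence.
  std::vector<int> token_offset(batch * sequence_length);
  for (int b = 0; b < batch; ++b) {
    // Start of this row's padding offsets: total tokens so far, plus the
    // padding positions of all previous rows (b * seq_len - cumulated[b]).
    int pad_start = total + b * sequence_length - cumulated[b];
    for (int s = 0; s < sequence_length; ++s) {
      if (s < token_count[b])
        token_offset[cumulated[b] + s] = b * sequence_length + s;
      else
        token_offset[pad_start + s - token_count[b]] = b * sequence_length + s;
    }
  }
  return token_offset;
}
```

For token counts {1, 2, 4} with sequence length 4, the first 7 entries are the packed-token positions 0, 4, 5, 8, 9, 10, 11 and the tail holds the padding positions.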
Good suggestion. There is a TODO in the comments related to this:
// TODO(tianleiwu): Use cub::DevicePartition::Flagged like BuildGlobalIndex in longformer_global_impl.cu
// to build token_offset when sequence length is large.
I could do it in another pull request later.
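For reference, a CPU analogue of building token_offset with a flagged partition, as the TODO suggests doing with cub::DevicePartition::Flagged on GPU: flattened positions holding real tokens are selected to the front, padding positions go to the back. (One caveat: the cub primitive emits the rejected items in reverse order, while std::stable_partition used here keeps their order; the function name and mask convention are assumptions.)

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// mask: flattened (batch * sequence_length) attention mask,
// 1 = real token, 0 = padding.
std::vector<int> BuildTokenOffset(const std::vector<int>& mask) {
  std::vector<int> offsets(mask.size());
  std::iota(offsets.begin(), offsets.end(), 0);  // 0, 1, ..., batch*seq - 1
  // Selected (real-token) positions move to the front, in order.
  std::stable_partition(offsets.begin(), offsets.end(),
                        [&](int i) { return mask[i] == 1; });
  return offsets;
}
```

With mask {1,0,0,0, 1,1,0,0, 1,1,1,1} (token counts 1, 2, 4; sequence length 4), this yields the same ordering as the per-thread kernel sketch above.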
> For offset idx > token_size, we don't need to fill it actually because the restoring won't use it
The purpose is to fill zeros for those padded tokens (to make the result deterministic). Otherwise, we would need to fill the whole output with zeros first, then use another kernel to restore the non-padding tokens.
Another purpose is to make the shape (batch_size, sequence_length). Otherwise, we would need to pass these two values to the RestorePadding operator.
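A hypothetical CPU sketch of what this means for RestorePadding (not the operator's implementation; names and layout are assumptions): because token_offset is filled for all batch_size * sequence_length positions, one pass can scatter the packed tokens back and write zeros at the padding positions, with no separate zero-fill kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// packed: (total_tokens, hidden) compact tokens.
// token_offset: (batch * seq) entries; the first total_tokens map packed
// tokens to their original flattened positions, the rest are padding positions.
std::vector<float> RestorePadding(const std::vector<float>& packed,
                                  const std::vector<int>& token_offset,
                                  int total_tokens, int hidden) {
  std::vector<float> out(token_offset.size() * hidden);
  // One "thread" per (i, h): real tokens are scattered back to their
  // original positions; entries past total_tokens carry padding positions
  // and are written with zeros in the same pass.
  for (size_t i = 0; i < token_offset.size(); ++i)
    for (int h = 0; h < hidden; ++h)
      out[static_cast<size_t>(token_offset[i]) * hidden + h] =
          i < static_cast<size_t>(total_tokens) ? packed[i * hidden + h] : 0.0f;
  return out;
}
```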
Description
Add two operators, RemovePadding and RestorePadding, based on the idea of Effective Transformer (https://github.com/bytedance/effective_transformer), to improve large-batch-size inference for BERT models.
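The core idea: only the total_tokens real tokens flow through the transformer layers, so compute scales with the actual token count instead of batch_size * sequence_length. A hypothetical CPU sketch of the gather that RemovePadding performs (names and layout are assumptions, not the operator's implementation):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// input: (batch * seq, hidden) flattened hidden states with padding.
// token_offset: first total_tokens entries give the flattened position of
// each real token, so a simple gather produces the compact layout.
std::vector<float> RemovePadding(const std::vector<float>& input,
                                 const std::vector<int>& token_offset,
                                 int total_tokens, int hidden) {
  std::vector<float> packed(static_cast<size_t>(total_tokens) * hidden);
  for (int i = 0; i < total_tokens; ++i)
    for (int h = 0; h < hidden; ++h)
      packed[static_cast<size_t>(i) * hidden + h] =
          input[static_cast<size_t>(token_offset[i]) * hidden + h];
  return packed;
}
```

RestorePadding is the inverse scatter, so a remove-then-restore round trip reproduces the input with zeros at the padding positions.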
Motivation and Context