Optimize t5 encoder in beam search #11926

Merged · 10 commits · Jun 22, 2022
Changes from 9 commits
10 changes: 8 additions & 2 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -183,7 +183,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
         device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<float>,
         device_copy_int32_func_ ? device_copy_int32_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<int32_t>,
         create_encoder_inputs_func_ ? create_encoder_inputs_func_ : BeamSearchCpuDeviceHelper::CreateEncoderInputs,
-        update_decoder_feeds_func_ ? update_decoder_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateDecoderFeeds<float>};
+        update_decoder_feeds_func_ ? update_decoder_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateDecoderFeeds<float>,
+        expand_buffer_int32_func_ ? expand_buffer_int32_func_ : BeamSearchCpuDeviceHelper::ExpandBuffer<int32_t>,
+        expand_buffer_float_func_ ? expand_buffer_float_func_ : BeamSearchCpuDeviceHelper::ExpandBuffer<float>,
+        expand_buffer_float16_func_ ? expand_buffer_float16_func_ : BeamSearchCpuDeviceHelper::ExpandBuffer<MLFloat16>};
     ORT_RETURN_IF_ERROR(impl.Initialize());
 
     return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
@@ -198,7 +201,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
         device_copy_func_,
         device_copy_int32_func_,
         create_encoder_inputs_func_,
-        update_decoder_feeds_fp16_func_};
+        update_decoder_feeds_fp16_func_,
+        expand_buffer_int32_func_,
+        expand_buffer_float_func_,
+        expand_buffer_float16_func_};
 
     ORT_RETURN_IF_ERROR(impl.Initialize());
 
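Note on the wiring above: each helper slot uses the execution provider's registered function when present and falls back to the BeamSearchCpuDeviceHelper implementation otherwise. A minimal stand-alone sketch of that dispatch pattern (illustrative names, not part of this PR):

#include <functional>
#include <iostream>

using ExpandFunc = std::function<void(int num_beams)>;

// Stand-in for a CPU default such as BeamSearchCpuDeviceHelper::ExpandBuffer<T>.
void CpuExpand(int num_beams) { std::cout << "cpu expand x" << num_beams << "\n"; }

struct Kernel {
  ExpandFunc expand_buffer_func_;  // empty unless an execution provider registered one

  void Compute(int num_beams) const {
    // Same ternary fallback as in BeamSearch::Compute above.
    (expand_buffer_func_ ? expand_buffer_func_ : ExpandFunc(CpuExpand))(num_beams);
  }
};

int main() {
  Kernel k;
  k.Compute(4);  // no helper registered: prints "cpu expand x4"
  k.expand_buffer_func_ = [](int n) { std::cout << "gpu expand x" << n << "\n"; };
  k.Compute(4);  // registered helper wins: prints "gpu expand x4"
}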
12 changes: 11 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/beam_search.h
@@ -74,9 +74,15 @@ class BeamSearch : public IControlFlowKernel {
   // device helpers for encoder-decoder model like T5
   void SetDeviceHelpers_EncoderDecoder(
       const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float>& update_decoder_feeds_func,
-      const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func) {
+      const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<int32_t>& expand_buffer_int32_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<float>& expand_buffer_float_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func) {
     update_decoder_feeds_func_ = update_decoder_feeds_func;
     update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func;
+    expand_buffer_int32_func_ = expand_buffer_int32_func;
+    expand_buffer_float_func_ = expand_buffer_float_func;
+    expand_buffer_float16_func_ = expand_buffer_float16_func;
   }
 
  private:
@@ -106,6 +112,10 @@
   BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float> update_decoder_feeds_func_;
   BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16> update_decoder_feeds_fp16_func_;
 
+  BeamSearchDeviceHelper::ExpandBufferFunc<int32_t> expand_buffer_int32_func_;
+  BeamSearchDeviceHelper::ExpandBufferFunc<float> expand_buffer_float_func_;
+  BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16> expand_buffer_float16_func_;
+
   //------------------------------------------------------------
   // Subgraph and FeedsFetchesManager re-used for each subgraph execution.
   //------------------------------------------------------------
118 changes: 90 additions & 28 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
@@ -60,6 +60,47 @@ void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator,
   }
 }
 
+template <typename T>
+Status ExpandBuffer(void* stream,
+                    const OrtValue& input,
+                    int num_beams,
+                    AllocatorPtr allocator,
+                    OrtValue& expanded,
+                    bool only_copy_shape) {
+  // Input shape (batch_size, xxx). The input is required with data type T.
+  // Output shape (batch_size * num_beams, xxx)
+  ORT_UNUSED_PARAMETER(stream);
+
+  const TensorShape& input_shape = input.Get<Tensor>().Shape();
+  const int64_t& batch_size = input_shape[0];
+  const int64_t& chunk_size = static_cast<int64_t>(input_shape.Size() / batch_size);
+
+  int64_t dims[4] = {0};
+  input_shape.CopyDims(dims, input_shape.NumDimensions());
+  dims[0] = batch_size * num_beams;
+  TensorShape expanded_shape(&dims[0], input_shape.NumDimensions());
+
+  MLDataType element_type = input.Get<Tensor>().DataType();
+  ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());
+  Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);
+
+  if (only_copy_shape) {
+    return Status::OK();
+  }
+
+  const T* input_data = input.Get<Tensor>().Data<T>();
+  T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
+  T* target = expanded_data;
+  for (int i = 0; i < batch_size; i++) {
+    for (int j = 0; j < num_beams; j++) {
+      memcpy(target, input_data + i * chunk_size, sizeof(T) * chunk_size);
+      target += chunk_size;
+    }
+  }
+
+  return Status::OK();
+}
+
 Status CreateGptInputs(
     const Tensor* original_input_ids,
     int num_beams,
@@ -200,37 +241,43 @@ Status ProcessLogits(const OrtValue& logits,  //
   const TensorShape& logits_shape = logits.Get<Tensor>().Shape();
   ORT_ENFORCE(logits_shape.NumDimensions() == 3);
   auto input_length = logits_shape[1];
+  auto beam_batch_size = logits_shape[0];
 
   // Get logits for the last token:
   // next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size)
   // When input_length == 1, use logits directly in SoftmaxCPU below so it only need for input_length > 1.
   gsl::span<T>& next_token_logits = beam_state->next_token_logits;
-  if (input_length > 1) {
+  if (input_length > 1 || beam_batch_size == batch_size) {
     const T* current_logits = logits_data + (input_length - 1) * vocab_size;
     for (int i = 0; i < batch_beam_size; i++) {
       gsl::span<const T> source(current_logits, vocab_size);
       gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size,
                                                       static_cast<gsl::index>(vocab_size));
       gsl::copy(source, target);
-      current_logits += input_length * vocab_size;
+      if (beam_batch_size == batch_beam_size) {
+        current_logits += input_length * vocab_size;
+      } else if (beam_batch_size == batch_size && i % num_beams == num_beams - 1) {
+        current_logits += input_length * vocab_size;
+      }
     }
   }
 
 #ifdef DEBUG_BEAM_SEARCH
   dumper->Print("logits", logits);
-  if (input_length > 1) {
-    dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
-  }
+  dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
 #endif
 
   // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
   gsl::span<T>& next_token_scores = beam_state->next_token_scores;
-  ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(batch_beam_size,  // rows
-                                    vocab_size,       // elements per row
-                                    input_length > 1 ? next_token_logits.data() : logits_data,
-                                    next_token_scores.data(),
-                                    true,
-                                    thread_pool));
+  ORT_RETURN_IF_ERROR(
+      SoftmaxCPU<T>(
+          batch_beam_size,  // rows
+          vocab_size,       // elements per row
+          (input_length == 1 && beam_batch_size == batch_beam_size) ? logits_data : next_token_logits.data(),
+          next_token_scores.data(),
+          true,
+          thread_pool));
 
 #ifdef DEBUG_BEAM_SEARCH
   dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size);
@@ -456,13 +503,12 @@ Status UpdateGptFeeds(
 Status CreateEncoderInputs(
     const Tensor* original_encoder_input_ids,
     const OrtValue* attn_mask_value,
-    int num_beams,
     int pad_token_id,
     int start_token_id,
     AllocatorPtr allocator,
-    OrtValue& expanded_encoder_input_ids,
-    OrtValue& expanded_encoder_attention_mask,
-    OrtValue& expanded_decoder_input_ids) {
+    OrtValue& encoder_input_ids,
+    OrtValue& encoder_attention_mask,
+    OrtValue& decoder_input_ids) {
   const TensorShape& input_ids_shape = original_encoder_input_ids->Shape();
   ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
   const int64_t& batch_size = input_ids_shape[0];
@@ -475,14 +521,12 @@
   // Current shape is (batch_size, sequence_length)
   // Note that we will expand it to (batch_size * num_beams, sequence_length) later.
   // To avoid cloning input_ids, we use const_cast here since this function does not change its content.
-  OrtValue encoder_input_ids;
   Tensor::InitOrtValue(element_type,
                        input_ids_shape,
                        const_cast<Tensor*>(original_encoder_input_ids)->MutableData<int32_t>(),
                        allocator->Info(),
                        encoder_input_ids);
 
-  OrtValue encoder_attention_mask;
   if (attn_mask_value != nullptr) {
     const Tensor& attention_mask = attn_mask_value->Get<Tensor>();
     Tensor::InitOrtValue(element_type, input_ids_shape, const_cast<Tensor*>(&attention_mask)->MutableData<int32_t>(),
@@ -511,20 +555,14 @@
     }
   }
 
-  // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
-  // for encoder_input_ids and encoder_attention_mask
-  // TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
-  ExpandInputs<int32_t>(encoder_input_ids, num_beams, allocator, expanded_encoder_input_ids);
-  ExpandInputs<int32_t>(encoder_attention_mask, num_beams, allocator, expanded_encoder_attention_mask);
-
   // decoder_input_ids is optional.
   if (start_token_id >= 0) {
-    // Expanded decoder_input_ids has shape (batch_size * num_beams, 1), and filled with start token ID
-    int64_t dims[] = {batch_size * num_beams, 1};
+    // Filled decoder_input_ids with start token ID
+    int64_t dims[] = {batch_size, 1};
     TensorShape decoder_input_ids_shape(&dims[0], 2);
-    Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, expanded_decoder_input_ids);
-    int32_t* data = expanded_decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
-    for (int i = 0; i < batch_size * num_beams; i++, data++) {
+    Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, decoder_input_ids);
+    int32_t* data = decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
+    for (int i = 0; i < batch_size; i++, data++) {
       *data = start_token_id;
     }
   }
@@ -709,6 +747,30 @@ template Status UpdateDecoderFeeds<float>(
 
 template void ExpandInputs<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
 
+template Status ExpandBuffer<int32_t>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
+template Status ExpandBuffer<float>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
+template Status ExpandBuffer<MLFloat16>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
 }  // namespace BeamSearchCpuDeviceHelper
 }  // namespace contrib
 }  // namespace onnxruntime
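The ProcessLogits change above is the core of this optimization: the encoder subgraph now runs on batch_size sequences rather than batch_size * num_beams, so on the first decoding step the logits arrive with leading dimension batch_size and each row has to be fanned out to that batch's num_beams beam slots. A self-contained sketch of the new stride logic with toy numbers (values are illustrative, not taken from the PR):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int batch_size = 2, num_beams = 3, vocab_size = 4, input_length = 1;
  const int batch_beam_size = batch_size * num_beams;
  const int beam_batch_size = batch_size;  // unexpanded encoder logits

  // Flattened logits of shape (beam_batch_size, input_length, vocab_size).
  std::vector<float> logits(beam_batch_size * input_length * vocab_size);
  for (size_t k = 0; k < logits.size(); ++k) logits[k] = static_cast<float>(k);

  std::vector<float> next_token_logits(batch_beam_size * vocab_size);
  const float* current = logits.data() + (input_length - 1) * vocab_size;
  for (int i = 0; i < batch_beam_size; ++i) {
    std::copy(current, current + vocab_size,
              next_token_logits.begin() + static_cast<size_t>(i) * vocab_size);
    // Advance to the next source row only after all beams of this batch are
    // filled, mirroring the "else if" branch added above.
    if (beam_batch_size == batch_size && i % num_beams == num_beams - 1)
      current += input_length * vocab_size;
  }

  // Beam slots 0..2 hold batch 0's last-token logits, slots 3..5 hold batch 1's.
  for (int i = 0; i < batch_beam_size; ++i)
    printf("slot %d starts with logit %.0f\n", i,
           next_token_logits[static_cast<size_t>(i) * vocab_size]);
}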
onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
@@ -107,13 +107,12 @@ using UpdateGptFeedsFunc = std::function<Status(
 using CreateEncoderInputsFunc = std::function<Status(
     const Tensor* original_encoder_input_ids,
     const OrtValue* attn_mask_value,
-    int num_beams,
     int pad_token_id,
     int start_token_id,
     AllocatorPtr allocator,
-    OrtValue& expanded_encoder_input_ids,
-    OrtValue& expanded_encoder_attention_mask,
-    OrtValue& expanded_decoder_input_ids)>;
+    OrtValue& encoder_input_ids,
+    OrtValue& encoder_attention_mask,
+    OrtValue& decoder_input_ids)>;
 
 // Update decoder inputs given decoder outputs of last iteration (for encoder-decoder model like T5).
 template <typename T>
@@ -132,8 +131,18 @@ using UpdateDecoderFeedsFunc = std::function<Status(
     int current_length,
     transformers::Sequences& sequences,
     const transformers::IConsoleDumper* dumper)>;
+
+template <typename T>
+using ExpandBufferFunc = std::function<Status(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape)>;
 }  // namespace BeamSearchDeviceHelper
 
+
 // These are CPU specific device helper implementations
 namespace BeamSearchCpuDeviceHelper {
 Status TopK(
@@ -212,13 +221,12 @@ Status UpdateGptFeeds(
 Status CreateEncoderInputs(
     const Tensor* original_encoder_input_ids,
     const OrtValue* attn_mask_value,
-    int num_beams,
     int pad_token_id,
     int start_token_id,
     AllocatorPtr allocator,
-    OrtValue& expanded_encoder_input_ids,
-    OrtValue& expanded_encoder_attention_mask,
-    OrtValue& expanded_decoder_input_ids);
+    OrtValue& encoder_input_ids,
+    OrtValue& encoder_attention_mask,
+    OrtValue& decoder_input_ids);
 
 // Update decoder inputs given decoder outputs of last iteration.
 template <typename T>
@@ -244,6 +252,15 @@ Status UpdateDecoderFeeds(
 template <typename T>
 void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
 
+template <typename T>
+Status ExpandBuffer(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
 }  // namespace BeamSearchCpuDeviceHelper
 }  // namespace contrib
 }  // namespace onnxruntime
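ExpandBuffer, declared above and implemented in the .cc diff, generalizes the old ExpandInputs: it repeats each row of a (batch_size, ...) tensor num_beams times along dim 0, and when only_copy_shape is true it only allocates the expanded output without copying (presumably so output buffers can be pre-shaped cheaply). A minimal model of the copy loop without the OrtValue/allocator plumbing:

#include <cstdio>
#include <cstring>
#include <vector>

template <typename T>
std::vector<T> ExpandRows(const std::vector<T>& input, int batch_size, int num_beams) {
  const size_t chunk = input.size() / batch_size;  // everything after dim 0
  std::vector<T> expanded(input.size() * num_beams);
  T* target = expanded.data();
  for (int i = 0; i < batch_size; ++i) {
    for (int j = 0; j < num_beams; ++j) {
      // Repeat batch row i once per beam, exactly like the memcpy loop above.
      std::memcpy(target, input.data() + static_cast<size_t>(i) * chunk, sizeof(T) * chunk);
      target += chunk;
    }
  }
  return expanded;
}

int main() {
  // (batch_size=2, seq=2) expands to (4, 2): rows {1,2} and {3,4} are each doubled.
  auto out = ExpandRows<int>({1, 2, 3, 4}, /*batch_size=*/2, /*num_beams=*/2);
  for (int v : out) printf("%d ", v);  // prints: 1 2 1 2 3 4 3 4
  printf("\n");
}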
17 changes: 16 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
@@ -95,7 +95,7 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
     this->sequences.Init(this->sequences_space, static_cast<int>(batch_beam_size), sequence_length, max_length);
   }
 
-  // Copy input_ids to sequences[0]
+  // Copy expanded input_ids to sequences[0]
   void SetSequence(gsl::span<const int32_t> input_ids_in_cpu,
                    size_t batch_beam_size,
                    int max_length,
@@ -109,6 +109,21 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
     }
   }
 
+  // Copy unexpanded input_ids to sequences[0]
+  void SetSequence(gsl::span<const int32_t> input_ids_in_cpu,
+                   size_t batch_beam_size,
+                   int beam_size,
+                   int max_length,
+                   int sequence_length) {
+    gsl::span<int32_t> sequences_0 = sequences_space;
+    for (size_t i = 0; i < batch_beam_size; i++) {
+      for (int j = 0; j < sequence_length; j++) {
+        const size_t index = SafeInt<gsl::index>(i) * max_length + j;
+        sequences_0[index] = input_ids_in_cpu[SafeInt<gsl::index>(i / beam_size) * sequence_length + j];
+      }
+    }
+  }
+
  private:
   BufferUniquePtr final_beam_scores_buffer_;
   BufferUniquePtr sequence_lengths_buffer_;
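In the new overload above, input_ids is unexpanded with shape (batch_size, sequence_length), so beam slot i reads its prompt from row i / beam_size. A toy check of that indexing (illustrative values only):

#include <cstdio>
#include <vector>

int main() {
  const int batch_size = 2, beam_size = 2, sequence_length = 3, max_length = 5;
  const size_t batch_beam_size = static_cast<size_t>(batch_size) * beam_size;

  // Unexpanded input_ids: (batch_size, sequence_length).
  std::vector<int> input_ids = {10, 11, 12,   // batch 0
                                20, 21, 22};  // batch 1
  std::vector<int> sequences_space(batch_beam_size * max_length, 0);

  for (size_t i = 0; i < batch_beam_size; ++i)
    for (int j = 0; j < sequence_length; ++j)
      sequences_space[i * max_length + j] =
          input_ids[(i / beam_size) * sequence_length + j];

  // Slots 0 and 1 start from batch 0's prompt; slots 2 and 3 from batch 1's.
  for (size_t i = 0; i < batch_beam_size; ++i) {
    for (int j = 0; j < max_length; ++j) printf("%d ", sequences_space[i * max_length + j]);
    printf("\n");
  }
}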