Optimize t5 encoder in beam search #11926

Merged
merged 10 commits into from
Jun 22, 2022

Changes from 3 commits
112 changes: 84 additions & 28 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
@@ -60,6 +60,67 @@ void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator,
}
}

template <typename T>
void ExpandCaches(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded) {
// Input shape (batch_size, num_heads, sequence_length, head_size). The input is required with data type T.
// Output shape (batch_size * num_beams, num_heads, sequence_length, head_size)

const TensorShape& input_shape = input.Get<Tensor>().Shape();
const int64_t& batch_size = input_shape[0];
const int64_t& num_heads = input_shape[1];
const int64_t& sequence_length = input_shape[2];
const int64_t& head_size = input_shape[3];

int64_t dims[] = {batch_size * num_beams, num_heads, sequence_length, head_size};
TensorShape expanded_shape(&dims[0], 4);

MLDataType element_type = input.Get<Tensor>().DataType();
ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());

Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);

const T* input_data = input.Get<Tensor>().Data<T>();
T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
T* target = expanded_data;
const int64_t chunk_size = sequence_length * num_heads * head_size;
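// Replicate each batch entry num_beams times so every beam starts from the same cache contents.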
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < num_beams; j++) {
memcpy(target, input_data + i * chunk_size, sizeof(T) * chunk_size);
target += chunk_size;
}
}
}

template <typename T>
void ExpandHiddenStates(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded) {
// Input shape (batch_size, sequence_length, hidden_size). The input is required with data type T.
// Output shape (batch_size * num_beams, sequence_length, hidden_size)

const TensorShape& input_shape = input.Get<Tensor>().Shape();
const int64_t& batch_size = input_shape[0];
const int64_t& sequence_length = input_shape[1];
const int64_t& hidden_size = input_shape[2];

int64_t dims[] = {batch_size * num_beams, sequence_length, hidden_size};
TensorShape expanded_shape(&dims[0], 3);

MLDataType element_type = input.Get<Tensor>().DataType();
ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());

Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);

const T* input_data = input.Get<Tensor>().Data<T>();
T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
T* target = expanded_data;
const int64_t chunk_size = sequence_length * hidden_size;
// Replicate each batch row num_beams times, mirroring ExpandCaches.
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < num_beams; j++) {
memcpy(target, input_data + i * chunk_size, sizeof(T) * chunk_size);
target += chunk_size;
}
}
}

Status CreateGptInputs(
const Tensor* original_input_ids,
int num_beams,
@@ -200,34 +261,36 @@ Status ProcessLogits(const OrtValue& logits, //
const TensorShape& logits_shape = logits.Get<Tensor>().Shape();
ORT_ENFORCE(logits_shape.NumDimensions() == 3);
auto input_length = logits_shape[1];
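// beam_batch_size is batch_size when the encoder ran on unexpanded inputs, or batch_size * num_beams when the inputs were already expanded.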
auto beam_batch_size = logits_shape[0];

// Get logits for the last token:
// next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size)
// When input_length == 1, use logits directly in SoftmaxCPU below, so the copy is only needed for input_length > 1.
gsl::span<T>& next_token_logits = beam_state->next_token_logits;
if (input_length > 1) {
const T* current_logits = logits_data + (input_length - 1) * vocab_size;
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const T> source(current_logits, vocab_size);
gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size,
static_cast<gsl::index>(vocab_size));
gsl::copy(source, target);

const T* current_logits = logits_data + (input_length - 1) * vocab_size;
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const T> source(current_logits, vocab_size);
gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size,
static_cast<gsl::index>(vocab_size));
gsl::copy(source, target);
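// Advance to the next source row once per beam when the logits are already expanded (beam_batch_size == batch_beam_size), or once per group of num_beams beams when they are not.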
if (beam_batch_size == batch_beam_size) {
current_logits += input_length * vocab_size;
} else if (beam_batch_size == batch_size && i != 0 && i % num_beams == 0) {
current_logits += input_length * vocab_size;
}
}

#ifdef DEBUG_BEAM_SEARCH
dumper->Print("logits", logits);
if (input_length > 1) {
dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
}
dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
#endif

// Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
gsl::span<T>& next_token_scores = beam_state->next_token_scores;
ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(batch_beam_size, // rows
vocab_size, // elements per row
input_length > 1 ? next_token_logits.data() : logits_data,
next_token_logits.data(),
next_token_scores.data(),
true,
thread_pool));
@@ -456,13 +519,12 @@ Status UpdateGptFeeds(
Status CreateEncoderInputs(
const Tensor* original_encoder_input_ids,
const OrtValue* attn_mask_value,
int num_beams,
int pad_token_id,
int start_token_id,
AllocatorPtr allocator,
OrtValue& expanded_encoder_input_ids,
OrtValue& expanded_encoder_attention_mask,
OrtValue& expanded_decoder_input_ids) {
OrtValue& encoder_input_ids,
OrtValue& encoder_attention_mask,
OrtValue& decoder_input_ids) {
const TensorShape& input_ids_shape = original_encoder_input_ids->Shape();
ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
const int64_t& batch_size = input_ids_shape[0];
@@ -475,14 +537,12 @@ Status CreateEncoderInputs(
// Current shape is (batch_size, sequence_length)
// Note that we will expand it to (batch_size * num_beams, sequence_length) later.
// To avoid cloning input_ids, we use const_cast here since this function does not change its content.
OrtValue encoder_input_ids;
Tensor::InitOrtValue(element_type,
input_ids_shape,
const_cast<Tensor*>(original_encoder_input_ids)->MutableData<int32_t>(),
allocator->Info(),
encoder_input_ids);

OrtValue encoder_attention_mask;
if (attn_mask_value != nullptr) {
const Tensor& attention_mask = attn_mask_value->Get<Tensor>();
Tensor::InitOrtValue(element_type, input_ids_shape, const_cast<Tensor*>(&attention_mask)->MutableData<int32_t>(),
@@ -511,20 +571,14 @@ Status CreateEncoderInputs(
}
}

// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
// for encoder_input_ids and encoder_attention_mask
// TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
ExpandInputs<int32_t>(encoder_input_ids, num_beams, allocator, expanded_encoder_input_ids);
ExpandInputs<int32_t>(encoder_attention_mask, num_beams, allocator, expanded_encoder_attention_mask);

// decoder_input_ids is optional.
if (start_token_id >= 0) {
// Expanded decoder_input_ids has shape (batch_size * num_beams, 1), and filled with start token ID
int64_t dims[] = {batch_size * num_beams, 1};
// Fill decoder_input_ids with the start token ID
int64_t dims[] = {batch_size, 1};
TensorShape decoder_input_ids_shape(&dims[0], 2);
Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, expanded_decoder_input_ids);
int32_t* data = expanded_decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
for (int i = 0; i < batch_size * num_beams; i++, data++) {
Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, decoder_input_ids);
int32_t* data = decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
for (int i = 0; i < batch_size; i++, data++) {
*data = start_token_id;
}
}
@@ -708,6 +762,8 @@ template Status UpdateDecoderFeeds<float>(
const transformers::IConsoleDumper* dumper);

template void ExpandInputs<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
template void ExpandCaches<float>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
template void ExpandHiddenStates<float>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);

} // namespace BeamSearchCpuDeviceHelper
} // namespace contrib
@@ -107,13 +107,12 @@ using UpdateGptFeedsFunc = std::function<Status(
using CreateEncoderInputsFunc = std::function<Status(
const Tensor* original_encoder_input_ids,
const OrtValue* attn_mask_value,
int num_beams,
int pad_token_id,
int start_token_id,
AllocatorPtr allocator,
OrtValue& expanded_encoder_input_ids,
OrtValue& expanded_encoder_attention_mask,
OrtValue& expanded_decoder_input_ids)>;
OrtValue& encoder_input_ids,
OrtValue& encoder_attention_mask,
OrtValue& decoder_input_ids)>;

// Update decoder inputs given decoder outputs of last iteration (for encoder-decoder model like T5).
template <typename T>
@@ -212,13 +211,12 @@ Status UpdateGptFeeds(
Status CreateEncoderInputs(
const Tensor* original_encoder_input_ids,
const OrtValue* attn_mask_value,
int num_beams,
int pad_token_id,
int start_token_id,
AllocatorPtr allocator,
OrtValue& expanded_encoder_input_ids,
OrtValue& expanded_encoder_attention_mask,
OrtValue& expanded_decoder_input_ids);
OrtValue& encoder_input_ids,
OrtValue& encoder_attention_mask,
OrtValue& decoder_input_ids);

// Update decoder inputs given decoder outputs of last iteration.
template <typename T>
@@ -244,6 +242,12 @@ Status UpdateDecoderFeeds(
template <typename T>
void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);

template <typename T>
void ExpandCaches(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);

template <typename T>
void ExpandHiddenStates(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);

} // namespace BeamSearchCpuDeviceHelper
} // namespace contrib
} // namespace onnxruntime
17 changes: 16 additions & 1 deletion onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
@@ -95,7 +95,7 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
this->sequences.Init(this->sequences_space, static_cast<int>(batch_beam_size), sequence_length, max_length);
}

// Copy input_ids to sequences[0]
// Copy expanded input_ids to sequences[0]
void SetSequence(gsl::span<const int32_t> input_ids_in_cpu,
size_t batch_beam_size,
int max_length,
@@ -109,6 +109,21 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
}
}

// Copy unexpanded input_ids to sequences[0]
void SetSequence(gsl::span<const int32_t> input_ids_in_cpu,
size_t batch_beam_size,
int beam_size,
int max_length,
int sequence_length) {
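// Row i of the expanded sequences copies from row i / beam_size of the unexpanded input, so all beams of a batch start from the same tokens.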
gsl::span<int32_t> sequences_0 = sequences_space;
for (size_t i = 0; i < batch_beam_size; i++) {
for (int j = 0; j < sequence_length; j++) {
const size_t index = SafeInt<gsl::index>(i) * max_length + j;
sequences_0[index] = input_ids_in_cpu[SafeInt<gsl::index>(i / beam_size) * sequence_length + j];
}
}
}

private:
BufferUniquePtr final_beam_scores_buffer_;
BufferUniquePtr sequence_lengths_buffer_;
11 changes: 6 additions & 5 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -110,19 +110,18 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
this->IsCuda());

IAllocatorUniquePtr<char> buffer;
OrtValue expanded_decoder_input_ids; // Tensor in CPU, and it will be used to initialize sequence in cpu_state
OrtValue decoder_input_ids; // Tensor in CPU, and it will be used to initialize sequence in cpu_state
ORT_RETURN_IF_ERROR(this->encoder_subgraph_.CreateInitialFeeds(
encoder_input_ids,
encoder_attn_mask_value,
this->implicit_inputs_,
parameters->num_beams,
parameters->pad_token_id,
parameters->decoder_start_token_id,
encoder_feeds,
this->create_encoder_inputs_func_,
this->add_to_feeds_func_,
buffer,
expanded_decoder_input_ids));
decoder_input_ids));

ORT_RETURN_IF_ERROR(utils::ExecuteSubgraph(this->encoder_session_state_,
encoder_feeds_fetches_manager,
@@ -150,9 +149,10 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
// Initialize resources
// ------------------------------------

// Copy expanded_decoder_input_ids (in CPU) to sequence. It contains decoder_start_token_id for each beam.
cpu_state.SetSequence(expanded_decoder_input_ids.Get<Tensor>().DataAsSpan<int32_t>(),
// Copy decoder_input_ids (in CPU) to sequence. It contains decoder_start_token_id for each batch entry, and SetSequence replicates it across beams.
cpu_state.SetSequence(decoder_input_ids.Get<Tensor>().DataAsSpan<int32_t>(),
static_cast<size_t>(parameters->BatchBeamSize()),
parameters->num_beams,
parameters->max_length,
parameters->sequence_length);

@@ -211,6 +211,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
encoder_fetches,
decoder_feeds,
this->device_copy_int32_func_,
parameters->num_beams,
this->cuda_stream_));
}

18 changes: 16 additions & 2 deletions onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -121,6 +121,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
const std::vector<OrtValue>& encoder_fetches,
std::vector<OrtValue>& decoder_feeds,
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
int num_beam,
void* stream) {
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");

@@ -144,13 +145,26 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
decoder_feeds.push_back(input_ids);

// The encoder_attention_mask is copied from the second input of encoder and expanded to (batch_size * num_beams, sequence_length).
decoder_feeds.push_back(encoder_feeds[1]);
OrtValue expanded_decoder_attention_masks;
BeamSearchCpuDeviceHelper::ExpandInputs<int32_t>(encoder_feeds[1],
num_beam,
allocator,
expanded_decoder_attention_masks);
decoder_feeds.push_back(expanded_decoder_attention_masks);

// When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
// of encoder.
// When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
for (size_t j = 4 - first_past_input_index_; j < encoder_fetches.size(); j++) {
decoder_feeds.push_back(encoder_fetches[j]);
if (j == 1) {
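// encoder_fetches[1] is encoder_hidden_states with shape (batch_size, sequence_length, hidden_size); the remaining fetches are past key/value caches.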
OrtValue expanded_hidden_states;
BeamSearchCpuDeviceHelper::ExpandHiddenStates<float>(encoder_fetches[j], num_beam, allocator, expanded_hidden_states);
decoder_feeds.push_back(expanded_hidden_states);
} else {
OrtValue expanded_cache;
BeamSearchCpuDeviceHelper::ExpandCaches<float>(encoder_fetches[j], num_beam, allocator, expanded_cache);
decoder_feeds.push_back(expanded_cache);
}
}

// Pass through implicit inputs.
@@ -28,6 +28,7 @@ class T5DecoderSubgraph : public Subgraph {
const std::vector<OrtValue>& encoder_fetches,
std::vector<OrtValue>& decoder_feeds,
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
int num_beam,
void* stream);

Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,