From 15ed4927016b854f2b013605c699c9038eba7cab Mon Sep 17 00:00:00 2001
From: wangye
Date: Tue, 21 Jun 2022 02:29:59 +0000
Subject: [PATCH 01/10] optimize t5 encoder

---
 .../transformers/beam_search_device_helper.cc | 64 +++++++++++++------
 .../transformers/beam_search_device_helper.h  | 17 ++---
 .../cpu/transformers/beam_search_impl_base.h  | 17 ++++-
 .../cpu/transformers/beam_search_impl_t5.h    |  5 +-
 .../cpu/transformers/subgraph_t5_decoder.cc   | 12 +++-
 .../cpu/transformers/subgraph_t5_decoder.h    |  1 +
 .../cpu/transformers/subgraph_t5_encoder.cc   | 22 +++----
 .../cpu/transformers/subgraph_t5_encoder.h    |  1 -
 8 files changed, 95 insertions(+), 44 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
index 2e6af56fc473..5e9afde3c0d7 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
@@ -60,6 +60,37 @@ void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator,
   }
 }
 
+template <typename T>
+void ExpandCaches(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded) {
+  // Input shape (batch_size, num_heads, sequence_length, head_size). The input is required to have data type T.
+  // Output shape (batch_size * num_beams, num_heads, sequence_length, head_size)
+
+  const TensorShape& input_shape = input.Get<Tensor>().Shape();
+  const int64_t& batch_size = input_shape[0];
+  const int64_t& num_heads = input_shape[1];
+  const int64_t& sequence_length = input_shape[2];
+  const int64_t& head_size = input_shape[3];
+
+  int64_t dims[] = {batch_size * num_beams, num_heads, sequence_length, head_size};
+  TensorShape expanded_shape(&dims[0], 4);
+
+  MLDataType element_type = input.Get<Tensor>().DataType();
+  ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());
+
+  Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);
+
+  const T* input_data = input.Get<Tensor>().Data<T>();
+  T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
+  T* target = expanded_data;
+  const int64_t chunk_size = sequence_length * num_heads * head_size;
+  for (int i = 0; i < batch_size; i++) {
+    for (int j = 0; j < num_beams; j++) {
+      memcpy(target, input_data + i * chunk_size, sizeof(T) * chunk_size);
+      target += chunk_size;
+    }
+  }
+}
+
 Status CreateGptInputs(
     const Tensor* original_input_ids,
     int num_beams,
@@ -200,6 +231,7 @@ Status ProcessLogits(const OrtValue& logits,  //
   const TensorShape& logits_shape = logits.Get<Tensor>().Shape();
   ORT_ENFORCE(logits_shape.NumDimensions() == 3);
   auto input_length = logits_shape[1];
+  auto beam_batch_size = logits_shape[0];
 
   // Get logits for the last token:
   //   next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size)
@@ -212,7 +244,11 @@ Status ProcessLogits(const OrtValue& logits,  //
       gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size,
                                                       static_cast<gsl::index>(vocab_size));
       gsl::copy(source, target);
-      current_logits += input_length * vocab_size;
+      if (beam_batch_size == batch_beam_size) {
+        current_logits += input_length * vocab_size;
+      } else if (beam_batch_size == batch_size && i % num_beams == 0) {
+        current_logits += input_length * vocab_size;
+      }
     }
   }
 
@@ -456,13 +492,12 @@ Status UpdateGptFeeds(
 Status CreateEncoderInputs(
     const Tensor* original_encoder_input_ids,
     const OrtValue* attn_mask_value,
-    int num_beams,
     int pad_token_id,
     int start_token_id,
     AllocatorPtr allocator,
-    OrtValue& expanded_encoder_input_ids,
-    OrtValue& expanded_encoder_attention_mask,
-    OrtValue& expanded_decoder_input_ids) {
+    OrtValue& encoder_input_ids,
+    OrtValue& encoder_attention_mask,
+    OrtValue& decoder_input_ids) {
   const TensorShape& input_ids_shape = original_encoder_input_ids->Shape();
   ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
   const int64_t& batch_size = input_ids_shape[0];
@@ -475,14 +510,12 @@ Status CreateEncoderInputs(
   // Current shape is (batch_size, sequence_length)
   // Note that we will expand it to (batch_size * num_beams, sequence_length) later.
   // To avoid cloning input_ids, we use const_cast here since this function does not change its content.
-  OrtValue encoder_input_ids;
   Tensor::InitOrtValue(element_type,
                        input_ids_shape,
                        const_cast<Tensor*>(original_encoder_input_ids)->MutableData<int32_t>(),
                        allocator->Info(),
                        encoder_input_ids);
 
-  OrtValue encoder_attention_mask;
   if (attn_mask_value != nullptr) {
     const Tensor& attention_mask = attn_mask_value->Get<Tensor>();
     Tensor::InitOrtValue(element_type, input_ids_shape,
                          const_cast<Tensor*>(&attention_mask)->MutableData<int32_t>(),
@@ -511,20 +544,14 @@ Status CreateEncoderInputs(
     }
   }
 
-  // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
-  // for encoder_input_ids and encoder_attention_mask
-  // TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
-  ExpandInputs(encoder_input_ids, num_beams, allocator, expanded_encoder_input_ids);
-  ExpandInputs(encoder_attention_mask, num_beams, allocator, expanded_encoder_attention_mask);
-
   // decoder_input_ids is optional.
   if (start_token_id >= 0) {
-    // Expanded decoder_input_ids has shape (batch_size * num_beams, 1), and filled with start token ID
-    int64_t dims[] = {batch_size * num_beams, 1};
+    // Fill decoder_input_ids with start token ID
+    int64_t dims[] = {batch_size, 1};
     TensorShape decoder_input_ids_shape(&dims[0], 2);
-    Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, expanded_decoder_input_ids);
-    int32_t* data = expanded_decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
-    for (int i = 0; i < batch_size * num_beams; i++, data++) {
+    Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, decoder_input_ids);
+    int32_t* data = decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
+    for (int i = 0; i < batch_size; i++, data++) {
       *data = start_token_id;
     }
   }
@@ -708,6 +735,7 @@ template Status UpdateDecoderFeeds(
     const transformers::IConsoleDumper* dumper);
 
 template void ExpandInputs<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
+template void ExpandCaches<float>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
 
 }  // namespace BeamSearchCpuDeviceHelper
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
index 8cd7a0291af0..5d82214c830c 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
@@ -107,13 +107,12 @@ using UpdateGptFeedsFunc = std::function<Status(
 using CreateEncoderInputsFunc = std::function<Status(
     const Tensor* original_encoder_input_ids,
     const OrtValue* attn_mask_value,
-    int num_beams,
     int pad_token_id,
     int start_token_id,
     AllocatorPtr allocator,
-    OrtValue& expanded_encoder_input_ids,
-    OrtValue& expanded_encoder_attention_mask,
-    OrtValue& expanded_decoder_input_ids)>;
+    OrtValue& encoder_input_ids,
+    OrtValue& encoder_attention_mask,
+    OrtValue& decoder_input_ids)>;
 
 // Update decoder inputs given decoder outputs of last iteration (for encoder-decoder model like T5).
 template <typename T>
@@ -212,13 +211,12 @@ Status UpdateGptFeeds(
 Status CreateEncoderInputs(
     const Tensor* original_encoder_input_ids,
     const OrtValue* attn_mask_value,
-    int num_beams,
     int pad_token_id,
     int start_token_id,
     AllocatorPtr allocator,
-    OrtValue& expanded_encoder_input_ids,
-    OrtValue& expanded_encoder_attention_mask,
-    OrtValue& expanded_decoder_input_ids);
+    OrtValue& encoder_input_ids,
+    OrtValue& encoder_attention_mask,
+    OrtValue& decoder_input_ids);
 
 // Update decoder inputs given decoder outputs of last iteration.
 template <typename T>
@@ -244,6 +242,9 @@ Status UpdateDecoderFeeds(
 template <typename T>
 void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
 
+template <typename T>
+void ExpandCaches(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
+
 }  // namespace BeamSearchCpuDeviceHelper
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
index 03c88b5aa804..94755989e7ad 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
@@ -95,7 +95,7 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
     this->sequences.Init(this->sequences_space, static_cast<int>(batch_beam_size), sequence_length, max_length);
   }
 
-  // Copy input_ids to sequences[0]
+  // Copy expanded input_ids to sequences[0]
   void SetSequence(gsl::span<const int32_t> input_ids_in_cpu,
                    size_t batch_beam_size,
                    int max_length,
@@ -109,6 +109,21 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
     }
   }
 
+  // Copy unexpanded input_ids to sequences[0]
+  void SetSequence(gsl::span<const int32_t> input_ids_in_cpu,
+                   size_t batch_beam_size,
+                   int beam_size,
+                   int max_length,
+                   int sequence_length) {
+    gsl::span<int32_t> sequences_0 = sequences_space;
+    for (size_t i = 0; i < batch_beam_size; i++) {
+      for (int j = 0; j < sequence_length; j++) {
+        const size_t index = SafeInt<gsl::index>(i) * max_length + j;
+        sequences_0[index] = input_ids_in_cpu[SafeInt<gsl::index>(i) * sequence_length * beam_size + j];
+      }
+    }
+  }
+
  private:
   BufferUniquePtr final_beam_scores_buffer_;
   BufferUniquePtr sequence_lengths_buffer_;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
index 6acbe857e5a7..598000f44b9a 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -115,7 +115,6 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
       encoder_input_ids,
       encoder_attn_mask_value,
       this->implicit_inputs_,
-      parameters->num_beams,
       parameters->pad_token_id,
       parameters->decoder_start_token_id,
      encoder_feeds,
@@ -150,9 +149,10 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
   // Initialize resources
   // ------------------------------------
 
-  // Copy expanded_decoder_input_ids (in CPU) to sequence. It contains decoder_start_token_id for each beam.
+  // Copy decoder_input_ids (in CPU) to sequence. It contains decoder_start_token_id for each beam.
   cpu_state.SetSequence(expanded_decoder_input_ids.Get<Tensor>().DataAsSpan<int32_t>(),
                         static_cast<size_t>(parameters->BatchBeamSize()),
+                        parameters->num_beams,
                         parameters->max_length,
                         parameters->sequence_length);
 
@@ -211,6 +211,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
         encoder_fetches,
         decoder_feeds,
         this->device_copy_int32_func_,
+        parameters->num_beams,
         this->cuda_stream_));
   }
 
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
index bb931255dd17..948a786588a1 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -121,6 +121,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
     const std::vector<OrtValue>& encoder_fetches,
     std::vector<OrtValue>& decoder_feeds,
     const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
+    int num_beam,
     void* stream) {
   ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
 
@@ -144,13 +145,20 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
   decoder_feeds.push_back(input_ids);
 
   // The encoder_attention_mask is copied from the second input of encoder.
-  decoder_feeds.push_back(encoder_feeds[1]);
+  OrtValue expanded_decoder_attention_masks;
+  BeamSearchCpuDeviceHelper::ExpandInputs<int32_t>(encoder_feeds[1],
+                                                   num_beam,
+                                                   allocator,
+                                                   expanded_decoder_attention_masks);
+  decoder_feeds.push_back(expanded_decoder_attention_masks);
 
   // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
   // of encoder.
   // When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
   for (size_t j = 4 - first_past_input_index_; j < encoder_fetches.size(); j++) {
-    decoder_feeds.push_back(encoder_fetches[j]);
+    OrtValue expanded_cache;
+    BeamSearchCpuDeviceHelper::ExpandCaches<float>(encoder_fetches[j], num_beam, allocator, expanded_cache);
+    decoder_feeds.push_back(expanded_cache);
   }
 
   // Pass through implicit inputs.
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
index 108b1c298d75..1cfa2a5372af 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
@@ -28,6 +28,7 @@ class T5DecoderSubgraph : public Subgraph {
       const std::vector<OrtValue>& encoder_fetches,
       std::vector<OrtValue>& decoder_feeds,
      const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
+      int num_beam,
       void* stream);
 
   Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc
index 7574c31ec5b6..153deaa8cb82 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc
@@ -96,24 +96,23 @@ Status T5EncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_i
 
 // Create inputs for first inference of subgraph.
 Status T5EncoderSubgraph::CreateInitialFeeds(
-    const Tensor& encoder_input_ids,
+    const Tensor& original_encoder_input_ids,
     const OrtValue* attn_mask_value,
     const std::vector<const OrtValue*>& implicit_inputs,
-    int num_beams,
     int pad_token_id,
     int start_token_id,
     std::vector<OrtValue>& feeds,
     const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func,
     const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
     IAllocatorUniquePtr<char>& buffer,
-    OrtValue& expanded_decoder_input_ids) {
+    OrtValue& decoder_input_ids) {
   ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
 
   // The ordering is the same as used in Setup.
   feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
 
   // Allocate subgraph inputs to be same device as encoder_input_ids.
-  AllocatorPtr cpu_allocator = session_state_->GetAllocator(encoder_input_ids.Location());
+  AllocatorPtr cpu_allocator = session_state_->GetAllocator(original_encoder_input_ids.Location());
   if (cpu_allocator == nullptr) {
     const IExecutionProvider* provider = GetProvider();
     cpu_allocator = provider->GetAllocator(0, OrtMemTypeDefault);
@@ -121,22 +120,21 @@ Status T5EncoderSubgraph::CreateInitialFeeds(
   ORT_RETURN_IF(cpu_allocator == nullptr, "cpu_allocator shouldn't be nullptr");
 
   // TODO(tianleiwu): expand the outputs instead of inputs to save computation.
-  OrtValue expanded_encoder_input_ids;
-  OrtValue expanded_encoder_attention_mask;
-  ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&encoder_input_ids,
+  OrtValue encoder_input_ids;
+  OrtValue encoder_attention_mask;
+  ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&original_encoder_input_ids,
                                                  attn_mask_value,
-                                                 num_beams,
                                                  pad_token_id,
                                                  start_token_id,
                                                  cpu_allocator,
-                                                 expanded_encoder_input_ids,
-                                                 expanded_encoder_attention_mask,
-                                                 expanded_decoder_input_ids));
+                                                 encoder_input_ids,
+                                                 encoder_attention_mask,
+                                                 decoder_input_ids));
 
   const IExecutionProvider* provider = GetProvider();
   ORT_RETURN_IF_ERROR(add_to_feeds_func(
       provider,
-      {expanded_encoder_input_ids, expanded_encoder_attention_mask, expanded_decoder_input_ids},
+      {encoder_input_ids, encoder_attention_mask, decoder_input_ids},
      feeds,
       buffer));
 
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h
index 83c9cb22c66a..9c67f4962135 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h
@@ -24,7 +24,6 @@ class T5EncoderSubgraph : public Subgraph {
       const Tensor& encoder_input_ids,
       const OrtValue* attn_mask_value,
       const std::vector<const OrtValue*>& implicit_inputs,
-      int num_beams,
       int pad_token_id,
       int start_token_id,
       std::vector<OrtValue>& feeds,

From 8a3dea06dd12e7e9805c529b1deac25d59881322 Mon Sep 17 00:00:00 2001
From: wangye
Date: Tue, 21 Jun 2022 19:22:58 +0000
Subject: [PATCH 02/10] update

---
 .../contrib_ops/cpu/transformers/beam_search_impl_base.h | 2 +-
 .../contrib_ops/cpu/transformers/beam_search_impl_t5.h   | 6 +++---
 .../cuda/transformers/beam_search_device_helper.cc       | 7 ++++++-
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
index 94755989e7ad..790e96a4476b 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
@@ -119,7 +119,7 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
     for (size_t i = 0; i < batch_beam_size; i++) {
       for (int j = 0; j < sequence_length; j++) {
         const size_t index = SafeInt<gsl::index>(i) * max_length + j;
-        sequences_0[index] = input_ids_in_cpu[SafeInt<gsl::index>(i) * sequence_length * beam_size + j];
+        sequences_0[index] = input_ids_in_cpu[SafeInt<gsl::index>(i / beam_size) * sequence_length + j];
       }
     }
   }
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
index 598000f44b9a..a16289105159 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -110,7 +110,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
                                  this->IsCuda());
 
   IAllocatorUniquePtr<char> buffer;
-  OrtValue expanded_decoder_input_ids;  // Tensor in CPU, and it will be used to initialize sequence in cpu_state
+  OrtValue decoder_input_ids;  // Tensor in CPU, and it will be used to initialize sequence in cpu_state
   ORT_RETURN_IF_ERROR(this->encoder_subgraph_.CreateInitialFeeds(
       encoder_input_ids,
       encoder_attn_mask_value,
@@ -121,7 +121,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
       this->create_encoder_inputs_func_,
       this->add_to_feeds_func_,
       buffer,
-      expanded_decoder_input_ids));
+      decoder_input_ids));
 
   ORT_RETURN_IF_ERROR(utils::ExecuteSubgraph(this->encoder_session_state_,
                                              encoder_feeds_fetches_manager,
@@ -150,7 +150,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
   // ------------------------------------
 
   // Copy decoder_input_ids (in CPU) to sequence. It contains decoder_start_token_id for each beam.
-  cpu_state.SetSequence(expanded_decoder_input_ids.Get<Tensor>().DataAsSpan<int32_t>(),
+  cpu_state.SetSequence(decoder_input_ids.Get<Tensor>().DataAsSpan<int32_t>(),
                         static_cast<size_t>(parameters->BatchBeamSize()),
                         parameters->num_beams,
                         parameters->max_length,
                         parameters->sequence_length);
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
index 3c45c2cc60b1..0fc4a2073381 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
@@ -221,6 +221,7 @@ Status ProcessLogits(const OrtValue& logits,  //
   const TensorShape& logits_shape = logits.Get<Tensor>().Shape();
   ORT_ENFORCE(logits_shape.NumDimensions() == 3);
   auto input_length = logits_shape[1];
+  auto beam_batch_size = logits_shape[0];
 
   cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
 
@@ -236,7 +237,11 @@ Status ProcessLogits(const OrtValue& logits,  //
       gsl::span<T> target = next_token_logits.subspan(i * vocab_size, vocab_size);
       CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size,
                                            cudaMemcpyDeviceToDevice, cuda_stream));
-      current_logits += input_length * vocab_size;
+      if (beam_batch_size == batch_beam_size) {
+        current_logits += input_length * vocab_size;
+      } else if (beam_batch_size == batch_size && i % num_beams == 0) {
+        current_logits += input_length * vocab_size;
+      }
     }
   }
 

From 657df469b4dda8b249f9089f8255045252faf232 Mon Sep 17 00:00:00 2001
From: wangye
Date: Tue, 21 Jun 2022 21:07:23 +0000
Subject: [PATCH 03/10] update

---
 .../transformers/beam_search_device_helper.cc | 60 ++++++++++++++-----
 .../transformers/beam_search_device_helper.h  |  3 +
 .../cpu/transformers/subgraph_t5_decoder.cc   | 12 +++-
 .../transformers/beam_search_device_helper.cc | 27 ++++-----
 4 files changed, 69 insertions(+), 33 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
index 5e9afde3c0d7..d4ef19bbdec3 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
@@ -91,6 +91,36 @@ void ExpandCaches(const OrtValue& input, int num_beams, AllocatorPtr allocator,
   }
 }
 
+template <typename T>
+void ExpandHiddenStates(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded) {
+  // Input shape (batch_size, sequence_length, hidden_size). The input is required to have data type T.
+  // Output shape (batch_size * num_beams, sequence_length, hidden_size)
+
+  const TensorShape& input_shape = input.Get<Tensor>().Shape();
+  const int64_t& batch_size = input_shape[0];
+  const int64_t& sequence_length = input_shape[1];
+  const int64_t& hidden_size = input_shape[2];
+
+  int64_t dims[] = {batch_size * num_beams, sequence_length, hidden_size};
+  TensorShape expanded_shape(&dims[0], 3);
+
+  MLDataType element_type = input.Get<Tensor>().DataType();
+  ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());
+
+  Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);
+
+  // const T* input_data = input.Get<Tensor>().Data<T>();
+  // T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
+  // T* target = expanded_data;
+  // const int64_t chunk_size = sequence_length * hidden_size;
+  // for (int i = 0; i < batch_size; i++) {
+  //   for (int j = 0; j < num_beams; j++) {
+  //     memcpy(target, input_data + i * chunk_size, sizeof(T) * chunk_size);
+  //     target += chunk_size;
+  //   }
+  // }
+}
+
 Status CreateGptInputs(
     const Tensor* original_input_ids,
     int num_beams,
@@ -237,33 +267,30 @@ Status ProcessLogits(const OrtValue& logits,  //
   //   next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size)
   // When input_length == 1, use logits directly in SoftmaxCPU below, so this is only needed for input_length > 1.
   gsl::span<T>& next_token_logits = beam_state->next_token_logits;
-  if (input_length > 1) {
-    const T* current_logits = logits_data + (input_length - 1) * vocab_size;
-    for (int i = 0; i < batch_beam_size; i++) {
-      gsl::span<const T> source(current_logits, vocab_size);
-      gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size,
-                                                      static_cast<gsl::index>(vocab_size));
-      gsl::copy(source, target);
-      if (beam_batch_size == batch_beam_size) {
-        current_logits += input_length * vocab_size;
-      } else if (beam_batch_size == batch_size && i % num_beams == 0) {
-        current_logits += input_length * vocab_size;
-      }
+
+  const T* current_logits = logits_data + (input_length - 1) * vocab_size;
+  for (int i = 0; i < batch_beam_size; i++) {
+    gsl::span<const T> source(current_logits, vocab_size);
+    gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size,
+                                                    static_cast<gsl::index>(vocab_size));
+    gsl::copy(source, target);
+    if (beam_batch_size == batch_beam_size) {
+      current_logits += input_length * vocab_size;
+    } else if (beam_batch_size == batch_size && i != 0 && i % num_beams == 0) {
+      current_logits += input_length * vocab_size;
     }
   }
 
 #ifdef DEBUG_BEAM_SEARCH
   dumper->Print("logits", logits);
-  if (input_length > 1) {
-    dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
-  }
+  dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
 #endif
 
   // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
   gsl::span<T>& next_token_scores = beam_state->next_token_scores;
   ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(batch_beam_size,  // rows
                                     vocab_size,       // elements per row
-                                    input_length > 1 ? next_token_logits.data() : logits_data,
+                                    next_token_logits.data(),
                                     next_token_scores.data(),
                                     true,
                                     thread_pool));
@@ -736,6 +763,7 @@ template Status UpdateDecoderFeeds(
     const transformers::IConsoleDumper* dumper);
 
 template void ExpandInputs<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
 template void ExpandCaches<float>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
+template void ExpandHiddenStates<float>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
 
 }  // namespace BeamSearchCpuDeviceHelper
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
index 5d82214c830c..e97c2d36333a 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
@@ -245,6 +245,9 @@ void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator,
 template <typename T>
 void ExpandCaches(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
 
+template <typename T>
+void ExpandHiddenStates(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
+
 }  // namespace BeamSearchCpuDeviceHelper
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
index 948a786588a1..6ebc2fa05c75 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -156,9 +156,15 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
   // of encoder.
   // When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
   for (size_t j = 4 - first_past_input_index_; j < encoder_fetches.size(); j++) {
-    OrtValue expanded_cache;
-    BeamSearchCpuDeviceHelper::ExpandCaches<float>(encoder_fetches[j], num_beam, allocator, expanded_cache);
-    decoder_feeds.push_back(expanded_cache);
+    if (j == 1) {
+      OrtValue expanded_hidden_states;
+      BeamSearchCpuDeviceHelper::ExpandHiddenStates<float>(encoder_fetches[j], num_beam, allocator, expanded_hidden_states);
+      decoder_feeds.push_back(expanded_hidden_states);
+    } else {
+      OrtValue expanded_cache;
+      BeamSearchCpuDeviceHelper::ExpandCaches<float>(encoder_fetches[j], num_beam, allocator, expanded_cache);
+      decoder_feeds.push_back(expanded_cache);
+    }
   }
 
   // Pass through implicit inputs.
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
index 0fc4a2073381..d5e1c5e1cc96 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
@@ -229,19 +229,18 @@ Status ProcessLogits(const OrtValue& logits,  //
   //   next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size)
   // When input_length == 1, use logits directly in SoftmaxCPU below, so this is only needed for input_length > 1.
   gsl::span<T>& next_token_logits = beam_state->next_token_logits;
-  if (input_length > 1) {
-    // TODO(tianleiwu): use one kernel to replace a loop of memory copy.
-    const CudaT* current_logits = logits_data + (input_length - 1) * vocab_size;
-    for (int i = 0; i < batch_beam_size; i++) {
-      gsl::span<const T> source(reinterpret_cast<const T*>(current_logits), vocab_size);
-      gsl::span<T> target = next_token_logits.subspan(i * vocab_size, vocab_size);
-      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size,
-                                           cudaMemcpyDeviceToDevice, cuda_stream));
-      if (beam_batch_size == batch_beam_size) {
-        current_logits += input_length * vocab_size;
-      } else if (beam_batch_size == batch_size && i % num_beams == 0) {
-        current_logits += input_length * vocab_size;
-      }
+
+  // TODO(tianleiwu): use one kernel to replace a loop of memory copy.
+  const CudaT* current_logits = logits_data + (input_length - 1) * vocab_size;
+  for (int i = 0; i < batch_beam_size; i++) {
+    gsl::span<const T> source(reinterpret_cast<const T*>(current_logits), vocab_size);
+    gsl::span<T> target = next_token_logits.subspan(i * vocab_size, vocab_size);
+    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size,
+                                         cudaMemcpyDeviceToDevice, cuda_stream));
+    if (beam_batch_size == batch_beam_size) {
+      current_logits += input_length * vocab_size;
+    } else if (beam_batch_size == batch_size && i != 0 && i % num_beams == 0) {
+      current_logits += input_length * vocab_size;
     }
   }
 
@@ -255,7 +254,7 @@ Status ProcessLogits(const OrtValue& logits,  //
 
   // The output will be float for consideration of precision and easy integration with remaining parts.
   float* Y_data = next_token_scores.data();
-  const CudaT* X_data = input_length > 1 ? reinterpret_cast<const CudaT*>(next_token_logits.data()) : logits_data;
+  const CudaT* X_data = reinterpret_cast<const CudaT*>(next_token_logits.data());
 
   dispatch_blockwise_softmax_forward<CudaT, float, float, true>(
       cuda_stream, Y_data, X_data, vocab_size, vocab_size, batch_size * num_beams);

From 0429340b05cc06eb9d3826065cbf7f245699c935 Mon Sep 17 00:00:00 2001
From: wangye
Date: Tue, 21 Jun 2022 21:13:18 +0000
Subject: [PATCH 04/10] update

---
 .../contrib_ops/cpu/transformers/beam_search_device_helper.cc  | 2 +-
 .../contrib_ops/cuda/transformers/beam_search_device_helper.cc | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
index d4ef19bbdec3..f76c9c2b89e4 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
@@ -276,7 +276,7 @@ Status ProcessLogits(const OrtValue& logits,  //
     gsl::copy(source, target);
     if (beam_batch_size == batch_beam_size) {
       current_logits += input_length * vocab_size;
-    } else if (beam_batch_size == batch_size && i != 0 && i % num_beams == 0) {
+    } else if (beam_batch_size == batch_size && i % num_beams == num_beams - 1) {
       current_logits += input_length * vocab_size;
     }
   }
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
index d5e1c5e1cc96..5c1d02b3394c 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
@@ -239,7 +239,7 @@ Status ProcessLogits(const OrtValue& logits,  //
                                          cudaMemcpyDeviceToDevice, cuda_stream));
     if (beam_batch_size == batch_beam_size) {
       current_logits += input_length * vocab_size;
-    } else if (beam_batch_size == batch_size && i != 0 && i % num_beams == 0) {
+    } else if (beam_batch_size == batch_size && i % num_beams == num_beams - 1) {
       current_logits += input_length * vocab_size;
     }
   }

From 6532abbaa32852889ec75ad59f6bab1be65b88e7 Mon Sep 17 00:00:00 2001
From: wangye
Date: Tue, 21 Jun 2022 23:39:31 +0000
Subject: [PATCH 05/10] refactor expand impl

---
 .../transformers/beam_search_device_helper.cc | 56 +++++--------------
 .../transformers/beam_search_device_helper.h  |  5 +-
 .../cpu/transformers/subgraph_t5_decoder.cc   | 10 ++--
 3 files changed, 21 insertions(+), 50 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
index f76c9c2b89e4..757255e4dd9c 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
@@ -61,28 +61,30 @@ void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator,
 }
 
 template <typename T>
-void ExpandCaches(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded) {
-  // Input shape (batch_size, num_heads, sequence_length, head_size). The input is required to have data type T.
-  // Output shape (batch_size * num_beams, num_heads, sequence_length, head_size)
+void ExpandBuffer(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded, bool only_copy_shape) {
+  // Input shape (batch_size, xxx). The input is required to have data type T.
+  // Output shape (batch_size * num_beams, xxx)
 
   const TensorShape& input_shape = input.Get<Tensor>().Shape();
   const int64_t& batch_size = input_shape[0];
-  const int64_t& num_heads = input_shape[1];
-  const int64_t& sequence_length = input_shape[2];
-  const int64_t& head_size = input_shape[3];
+  const int64_t& chunk_size = static_cast<int64_t>(input_shape.Size() / batch_size);
 
-  int64_t dims[] = {batch_size * num_beams, num_heads, sequence_length, head_size};
-  TensorShape expanded_shape(&dims[0], 4);
+  int64_t dims[4] = {0};
+  input_shape.CopyDims(dims, input_shape.NumDimensions());
+  dims[0] = batch_size * num_beams;
+  TensorShape expanded_shape(&dims[0], input_shape.NumDimensions());
 
   MLDataType element_type = input.Get<Tensor>().DataType();
   ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());
   Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);
 
+  if (only_copy_shape) {
+    return;
+  }
+
   const T* input_data = input.Get<Tensor>().Data<T>();
   T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
   T* target = expanded_data;
-  const int64_t chunk_size = sequence_length * num_heads * head_size;
   for (int i = 0; i < batch_size; i++) {
     for (int j = 0; j < num_beams; j++) {
       memcpy(target, input_data + i * chunk_size, sizeof(T) * chunk_size);
       target += chunk_size;
     }
   }
 }
 
-template <typename T>
-void ExpandHiddenStates(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded) {
-  // Input shape (batch_size, sequence_length, hidden_size). The input is required to have data type T.
-  // Output shape (batch_size * num_beams, sequence_length, hidden_size)
-
-  const TensorShape& input_shape = input.Get<Tensor>().Shape();
-  const int64_t& batch_size = input_shape[0];
-  const int64_t& sequence_length = input_shape[1];
-  const int64_t& hidden_size = input_shape[2];
-
-  int64_t dims[] = {batch_size * num_beams, sequence_length, hidden_size};
-  TensorShape expanded_shape(&dims[0], 3);
-
-  MLDataType element_type = input.Get<Tensor>().DataType();
-  ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());
-
-  Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);
-
-  // const T* input_data = input.Get<Tensor>().Data<T>();
-  // T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
-  // T* target = expanded_data;
-  // const int64_t chunk_size = sequence_length * hidden_size;
-  // for (int i = 0; i < batch_size; i++) {
-  //   for (int j = 0; j < num_beams; j++) {
-  //     memcpy(target, input_data + i * chunk_size, sizeof(T) * chunk_size);
-  //     target += chunk_size;
-  //   }
-  // }
-}
-
 Status CreateGptInputs(
     const Tensor* original_input_ids,
     int num_beams,
@@ -762,8 +734,8 @@ template Status UpdateDecoderFeeds(
     const transformers::IConsoleDumper* dumper);
 
 template void ExpandInputs<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
-template void ExpandCaches<float>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
-template void ExpandHiddenStates<float>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
+template void ExpandBuffer<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded, bool only_copy_shape);
+template void ExpandBuffer<float>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded, bool only_copy_shape);
 
 }  // namespace BeamSearchCpuDeviceHelper
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
index e97c2d36333a..5f4ed8967731 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
@@ -243,10 +243,7 @@ template <typename T>
 void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
 
 template <typename T>
-void ExpandCaches(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
-
-template <typename T>
-void ExpandHiddenStates(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
+void ExpandBuffer(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded, bool only_copy_shape);
 
 }  // namespace BeamSearchCpuDeviceHelper
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
index 6ebc2fa05c75..191236845eed 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -146,10 +146,12 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
 
   // The encoder_attention_mask is copied from the second input of encoder.
   OrtValue expanded_decoder_attention_masks;
-  BeamSearchCpuDeviceHelper::ExpandInputs<int32_t>(encoder_feeds[1],
-                                                   num_beam,
-                                                   allocator,
-                                                   expanded_decoder_attention_masks);
+  std::cout << "expanded_decoder_attention_masks 149" << std::endl;
+  BeamSearchCpuDeviceHelper::ExpandBuffer<int32_t>(encoder_feeds[1],
+                                                   num_beam,
+                                                   allocator,
+                                                   expanded_decoder_attention_masks, false);
+  std::cout << "after crash 149" << std::endl;
   decoder_feeds.push_back(expanded_decoder_attention_masks);
 
   // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
@@ -158,11 +160,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
   for (size_t j = 4 - first_past_input_index_; j < encoder_fetches.size(); j++) {
     if (j == 1) {
       OrtValue expanded_hidden_states;
-      BeamSearchCpuDeviceHelper::ExpandHiddenStates<float>(encoder_fetches[j], num_beam, allocator, expanded_hidden_states);
+      BeamSearchCpuDeviceHelper::ExpandBuffer<float>(encoder_fetches[j], num_beam, allocator, expanded_hidden_states, true);
       decoder_feeds.push_back(expanded_hidden_states);
     } else {
       OrtValue expanded_cache;
-      BeamSearchCpuDeviceHelper::ExpandCaches<float>(encoder_fetches[j], num_beam, allocator, expanded_cache);
+      BeamSearchCpuDeviceHelper::ExpandBuffer<float>(encoder_fetches[j], num_beam, allocator, expanded_cache, false);
       decoder_feeds.push_back(expanded_cache);
     }
   }

From 2f3c687db0e143f5c418bc4dbde38d43b10f2fbe Mon Sep 17 00:00:00 2001
From: wangye
Date: Wed, 22 Jun 2022 01:22:25 +0000
Subject: [PATCH 06/10] cuda tests passed

---
 .../cpu/transformers/beam_search.cc           | 10 ++-
 .../cpu/transformers/beam_search.h            | 12 +++-
 .../transformers/beam_search_device_helper.cc | 38 ++++++++--
 .../transformers/beam_search_device_helper.h  | 18 ++++-
 .../cpu/transformers/beam_search_impl_t5.h    | 16 ++++-
 .../cpu/transformers/subgraph_t5_decoder.cc   | 52 ++++++++++++--
 .../cpu/transformers/subgraph_t5_decoder.h    |  3 +
 .../cuda/transformers/beam_search.cc          |  5 +-
 .../transformers/beam_search_device_helper.cc | 70 +++++++++++++++++++
 .../transformers/beam_search_device_helper.h  |  9 +++
 10 files changed, 216 insertions(+), 17 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
index 4cc3533f45e7..c37ae8caa5ee 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -183,7 +183,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
         device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<float>,
         device_copy_int32_func_ ? device_copy_int32_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<int32_t>,
         create_encoder_inputs_func_ ? create_encoder_inputs_func_ : BeamSearchCpuDeviceHelper::CreateEncoderInputs,
-        update_decoder_feeds_func_ ? update_decoder_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateDecoderFeeds<float>};
+        update_decoder_feeds_func_ ? update_decoder_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateDecoderFeeds<float>,
+        expand_buffer_int32_func_ ? expand_buffer_int32_func_ : BeamSearchCpuDeviceHelper::ExpandBuffer<int32_t>,
+        expand_buffer_float_func_ ? expand_buffer_float_func_ : BeamSearchCpuDeviceHelper::ExpandBuffer<float>,
+        expand_buffer_float16_func_ ? expand_buffer_float16_func_ : BeamSearchCpuDeviceHelper::ExpandBuffer<MLFloat16>};
 
     ORT_RETURN_IF_ERROR(impl.Initialize());
 
     return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
@@ -198,7 +201,10 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
         device_copy_func_,
         device_copy_int32_func_,
         create_encoder_inputs_func_,
-        update_decoder_feeds_fp16_func_};
+        update_decoder_feeds_fp16_func_,
+        expand_buffer_int32_func_,
+        expand_buffer_float_func_,
+        expand_buffer_float16_func_};
 
     ORT_RETURN_IF_ERROR(impl.Initialize());
 
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
index 3a0de820106e..f9cb7d66c585 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
@@ -74,9 +74,15 @@ class BeamSearch : public IControlFlowKernel {
   // device helpers for encoder-decoder model like T5
   void SetDeviceHelpers_EncoderDecoder(
       const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float>& update_decoder_feeds_func,
-      const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func) {
+      const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<int32_t>& expand_buffer_int32_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<float>& expand_buffer_float_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func) {
     update_decoder_feeds_func_ = update_decoder_feeds_func;
     update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func;
+    expand_buffer_int32_func_ = expand_buffer_int32_func;
+    expand_buffer_float_func_ = expand_buffer_float_func;
+    expand_buffer_float16_func_ = expand_buffer_float16_func;
   }
 
  private:
@@ -106,6 +112,10 @@ class BeamSearch : public IControlFlowKernel {
   BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float> update_decoder_feeds_func_;
   BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16> update_decoder_feeds_fp16_func_;
 
+  BeamSearchDeviceHelper::ExpandBufferFunc<int32_t> expand_buffer_int32_func_;
+  BeamSearchDeviceHelper::ExpandBufferFunc<float> expand_buffer_float_func_;
+  BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16> expand_buffer_float16_func_;
+
   //------------------------------------------------------------
   // Subgraph and FeedsFetchesManager re-used for each subgraph execution.
   //------------------------------------------------------------
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
index 757255e4dd9c..bcbca1a5af00 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
@@ -61,9 +61,15 @@ void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator,
 }
 
 template <typename T>
-void ExpandBuffer(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded, bool only_copy_shape) {
+Status ExpandBuffer(void* stream,
+                    const OrtValue& input,
+                    int num_beams,
+                    AllocatorPtr allocator,
+                    OrtValue& expanded,
+                    bool only_copy_shape) {
   // Input shape (batch_size, xxx). The input is required to have data type T.
   // Output shape (batch_size * num_beams, xxx)
+  ORT_UNUSED_PARAMETER(stream);
 
   const TensorShape& input_shape = input.Get<Tensor>().Shape();
   const int64_t& batch_size = input_shape[0];
@@ -79,7 +85,7 @@ Status ExpandBuffer(void* stream,
   Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);
 
   if (only_copy_shape) {
-    return;
+    return Status::OK();
   }
 
   const T* input_data = input.Get<Tensor>().Data<T>();
@@ -91,6 +97,8 @@ Status ExpandBuffer(void* stream,
       target += chunk_size;
     }
   }
+
+  return Status::OK();
 }
 
 Status CreateGptInputs(
@@ -734,8 +742,30 @@ template Status UpdateDecoderFeeds(
     const transformers::IConsoleDumper* dumper);
 
 template void ExpandInputs<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
-template void ExpandBuffer<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded, bool only_copy_shape);
-template void ExpandBuffer<float>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded, bool only_copy_shape);
+
+template Status ExpandBuffer<int32_t>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
+template Status ExpandBuffer<float>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
+template Status ExpandBuffer<MLFloat16>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
 
 }  // namespace BeamSearchCpuDeviceHelper
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
index 5f4ed8967731..ab18eec25cde 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h
@@ -131,8 +131,18 @@ using UpdateDecoderFeedsFunc = std::function<Status(
     transformers::Sequences& sequences,
     const transformers::IConsoleDumper* dumper)>;
+
+template <typename T>
+using ExpandBufferFunc = std::function<Status(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape)>;
 }  // namespace BeamSearchDeviceHelper
 
+
 // These are CPU specific device helper implementations
 namespace BeamSearchCpuDeviceHelper {
 Status TopK(
@@ -243,7 +253,13 @@ template <typename T>
 void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
 
 template <typename T>
-void ExpandBuffer(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded, bool only_copy_shape);
+Status ExpandBuffer(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
 }  // namespace BeamSearchCpuDeviceHelper
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
index a16289105159..5360bfd8f4cf 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -33,7 +33,10 @@ class BeamSearchT5 : public BeamSearchBase<T> {
                const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
                const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
                const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func,
-               const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<T>& update_decoder_feeds_func)
+               const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<T>& update_decoder_feeds_func,
+               const BeamSearchDeviceHelper::ExpandBufferFunc<int32_t>& expand_buffer_int32_func,
+               const BeamSearchDeviceHelper::ExpandBufferFunc<float>& expand_buffer_float_func,
+               const BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func)
       : BeamSearchBase<T>(context, decoder_session_state, thread_pool,
                           cuda_stream, cuda_dumper, params,
                           topk_func, process_logits_func, device_copy_func, device_copy_int32_func),
@@ -43,7 +46,10 @@ class BeamSearchT5 : public BeamSearchBase<T> {
         add_to_feeds_func_(add_to_feeds_func),
         init_beam_state_func_(init_beam_state_func),
         create_encoder_inputs_func_(create_encoder_inputs_func),
-        update_decoder_feeds_func_(update_decoder_feeds_func) {
+        update_decoder_feeds_func_(update_decoder_feeds_func),
+        expand_buffer_int32_func_(expand_buffer_int32_func),
+        expand_buffer_float_func_(expand_buffer_float_func),
+        expand_buffer_float16_func_(expand_buffer_float16_func) {
   }
 
   // Execute beam search in iterations until stopping criteria is reached.
@@ -62,6 +68,9 @@ class BeamSearchT5 : public BeamSearchBase<T> {
   BeamSearchDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_;
   BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<T> update_decoder_feeds_func_;
 
+  BeamSearchDeviceHelper::ExpandBufferFunc<int32_t> expand_buffer_int32_func_;
+  BeamSearchDeviceHelper::ExpandBufferFunc<float> expand_buffer_float_func_;
+  BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16> expand_buffer_float16_func_;
 };
 
 template <typename T>
@@ -211,6 +220,9 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
         encoder_fetches,
         decoder_feeds,
         this->device_copy_int32_func_,
+        this->expand_buffer_int32_func_,
+        this->expand_buffer_float_func_,
+        this->expand_buffer_float16_func_,
         parameters->num_beams,
         this->cuda_stream_));
   }
 
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
index 191236845eed..d8ea5b2dceb2 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -121,6 +121,9 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
     const std::vector<OrtValue>& encoder_fetches,
     std::vector<OrtValue>& decoder_feeds,
     const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
+    const BeamSearchDeviceHelper::ExpandBufferFunc<int32_t>& expand_buffer_int32_func,
+    const BeamSearchDeviceHelper::ExpandBufferFunc<float>& expand_buffer_float_func,
+    const BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func,
     int num_beam,
     void* stream) {
   ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
@@ -150,10 +150,16 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
   // The encoder_attention_mask is copied from the second input of encoder.
   OrtValue expanded_decoder_attention_masks;
   std::cout << "expanded_decoder_attention_masks 149" << std::endl;
-  BeamSearchCpuDeviceHelper::ExpandBuffer<int32_t>(encoder_feeds[1],
-                                                   num_beam,
-                                                   allocator,
-                                                   expanded_decoder_attention_masks, false);
+  ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream,
+                                               encoder_feeds[1],
+                                               num_beam,
+                                               allocator,
+                                               expanded_decoder_attention_masks,
+                                               false));
+  // BeamSearchCpuDeviceHelper::ExpandBuffer<int32_t>(encoder_feeds[1],
+  //                                                  num_beam,
+  //                                                  allocator,
+  //                                                  expanded_decoder_attention_masks, false);
   std::cout << "after crash 149" << std::endl;
   decoder_feeds.push_back(expanded_decoder_attention_masks);
 
   // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output
   // of encoder.
   // When first_past_input_index_ == 2, the past states are copied from the second output of encoder.
   for (size_t j = 4 - first_past_input_index_; j < encoder_fetches.size(); j++) {
     if (j == 1) {
       OrtValue expanded_hidden_states;
-      BeamSearchCpuDeviceHelper::ExpandBuffer<float>(encoder_fetches[j], num_beam, allocator, expanded_hidden_states, true);
+      if (is_output_float16_) {
+        ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
+                                                       encoder_fetches[j],
+                                                       num_beam,
+                                                       allocator,
+                                                       expanded_hidden_states,
+                                                       true));
+      } else {
+        ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
+                                                     encoder_fetches[j],
+                                                     num_beam,
+                                                     allocator,
+                                                     expanded_hidden_states,
+                                                     true));
+      }
+
+      //BeamSearchCpuDeviceHelper::ExpandBuffer<float>(encoder_fetches[j], num_beam, allocator, expanded_hidden_states, true);
       decoder_feeds.push_back(expanded_hidden_states);
     } else {
       OrtValue expanded_cache;
-      BeamSearchCpuDeviceHelper::ExpandBuffer<float>(encoder_fetches[j], num_beam, allocator, expanded_cache, false);
+      if (is_output_float16_) {
+        ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
+                                                       encoder_fetches[j],
+                                                       num_beam,
+                                                       allocator,
+                                                       expanded_cache,
+                                                       false));
+      } else {
+        ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
+                                                     encoder_fetches[j],
+                                                     num_beam,
+                                                     allocator,
+                                                     expanded_cache,
+                                                     false));
+      }
+      //BeamSearchCpuDeviceHelper::ExpandBuffer<float>(encoder_fetches[j], num_beam, allocator, expanded_cache, false);
       decoder_feeds.push_back(expanded_cache);
     }
   }
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
index 1cfa2a5372af..edf7293a978c 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h
@@ -28,6 +28,9 @@ class T5DecoderSubgraph : public Subgraph {
       const std::vector<OrtValue>& encoder_fetches,
       std::vector<OrtValue>& decoder_feeds,
       const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<int32_t>& expand_buffer_int32_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<float>& expand_buffer_float_func,
+      const BeamSearchDeviceHelper::ExpandBufferFunc<MLFloat16>& expand_buffer_float16_func,
       int num_beam,
       void* stream);
 
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc
index 91b660c197e1..5e500f560131 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc
@@ -49,7 +49,10 @@ BeamSearch::BeamSearch(const OpKernelInfo& info)
                        BeamSearchCudaDeviceHelper::UpdateGptFeeds<MLFloat16>);
 
   SetDeviceHelpers_EncoderDecoder(BeamSearchCudaDeviceHelper::UpdateDecoderFeeds<float>,
-                                  BeamSearchCudaDeviceHelper::UpdateDecoderFeeds<MLFloat16>);
+                                  BeamSearchCudaDeviceHelper::UpdateDecoderFeeds<MLFloat16>,
+                                  BeamSearchCudaDeviceHelper::ExpandBuffer<int32_t>,
+                                  BeamSearchCudaDeviceHelper::ExpandBuffer<float>,
+                                  BeamSearchCudaDeviceHelper::ExpandBuffer<MLFloat16>);
 
   SetConsoleDumper(&g_cuda_dumper);
 }
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
index 5c1d02b3394c..b693e7f97d3c 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
@@ -622,6 +622,53 @@ Status UpdateDecoderFeeds(
       t5_decoder_first_past_input_idx, t5_decoder_first_present_output_idx, stream);
 }
 
+template <typename T>
+Status ExpandBuffer(void* stream,
+                    const OrtValue& input,
+                    int num_beams,
+                    AllocatorPtr allocator,
+                    OrtValue& expanded,
+                    bool only_copy_shape) {
+  // Input shape (batch_size, xxx). The input is required to have data type T.
+  // Output shape (batch_size * num_beams, xxx)
+  const TensorShape& input_shape = input.Get<Tensor>().Shape();
+  const int64_t& batch_size = input_shape[0];
+  const int64_t& chunk_size = static_cast<int64_t>(input_shape.Size() / batch_size);
+
+  int64_t dims[4] = {0};
+  input_shape.CopyDims(dims, input_shape.NumDimensions());
+  dims[0] = batch_size * num_beams;
+  TensorShape expanded_shape(&dims[0], input_shape.NumDimensions());
+
+  MLDataType element_type = input.Get<Tensor>().DataType();
+  ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());
+  Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);
+
+  if (only_copy_shape) {
+    return Status::OK();
+  }
+
+  cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+
+  const T* input_data = input.Get<Tensor>().Data<T>();
+  T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
+  T* target = expanded_data;
+  for (int i = 0; i < batch_size; i++) {
+    for (int j = 0; j < num_beams; j++) {
+      CUDA_RETURN_IF_ERROR(
+          cudaMemcpyAsync(
+              target,
+              input_data + i * chunk_size,
+              sizeof(T) * chunk_size,
+              cudaMemcpyDeviceToDevice,
+              cuda_stream));
+      target += chunk_size;
+    }
+  }
+
+  return Status::OK();
+}
+
 // Explicit template instantiations of functions
 template void InitBeamState<float>(transformers::IBeamSearchState<float>* beam_state,
                                    gsl::span<int32_t>& sequence_lengths,
@@ -734,6 +781,29 @@ template Status UpdateDecoderFeeds(
     transformers::Sequences& sequences,
     const transformers::IConsoleDumper* dumper);
 
+template Status ExpandBuffer<int32_t>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
+template Status ExpandBuffer<float>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
+template Status ExpandBuffer<MLFloat16>(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
 }  // namespace BeamSearchCudaDeviceHelper
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h
index a90d0c7ee84c..14f64e923e78 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h
@@ -97,6 +97,15 @@ Status UpdateDecoderFeeds(
     transformers::Sequences& sequences,
     const transformers::IConsoleDumper* dumper);
 
+template <typename T>
+Status ExpandBuffer(
+    void* stream,
+    const OrtValue& input,
+    int num_beams,
+    AllocatorPtr allocator,
+    OrtValue& expanded,
+    bool only_copy_shape);
+
 }  // namespace BeamSearchCudaDeviceHelper
 }  // namespace contrib
 }  // namespace onnxruntime

From fd5bac2e5837064831ac4a7e93a3f8dc93629607 Mon Sep 17 00:00:00 2001
From: wangye
Date: Wed, 22 Jun 2022 01:43:36 +0000
Subject: [PATCH 07/10] update

---
 .../transformers/beam_search_device_helper.cc | 36 ++++++++++---------
 .../cpu/transformers/subgraph_t5_decoder.cc   | 11 ++----
 .../transformers/beam_search_device_helper.cc | 26 ++++++++------
 3 files changed, 37 insertions(+), 36 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
index bcbca1a5af00..ee2c845dde23 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
logits, // // When input_length == 1, use logits directly in SoftmaxCPU below so it only need for input_length > 1. gsl::span& next_token_logits = beam_state->next_token_logits; - const T* current_logits = logits_data + (input_length - 1) * vocab_size; - for (int i = 0; i < batch_beam_size; i++) { - gsl::span source(current_logits, vocab_size); - gsl::span target = next_token_logits.subspan(SafeInt(i) * vocab_size, - static_cast(vocab_size)); - gsl::copy(source, target); - if (beam_batch_size == batch_beam_size) { - current_logits += input_length * vocab_size; - } else if (beam_batch_size == batch_size && i % num_beams == num_beams - 1) { - current_logits += input_length * vocab_size; + if (input_length > 1 || beam_batch_size == batch_size) { + const T* current_logits = logits_data + (input_length - 1) * vocab_size; + for (int i = 0; i < batch_beam_size; i++) { + gsl::span source(current_logits, vocab_size); + gsl::span target = next_token_logits.subspan(SafeInt(i) * vocab_size, + static_cast(vocab_size)); + gsl::copy(source, target); + if (beam_batch_size == batch_beam_size) { + current_logits += input_length * vocab_size; + } else if (beam_batch_size == batch_size && i % num_beams == num_beams - 1) { + current_logits += input_length * vocab_size; + } } } @@ -268,12 +270,14 @@ Status ProcessLogits(const OrtValue& logits, // // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1) gsl::span& next_token_scores = beam_state->next_token_scores; - ORT_RETURN_IF_ERROR(SoftmaxCPU(batch_beam_size, // rows - vocab_size, // elements per row - next_token_logits.data(), - next_token_scores.data(), - true, - thread_pool)); + ORT_RETURN_IF_ERROR( + SoftmaxCPU( + batch_beam_size, // rows + vocab_size, // elements per row + (input_length == 1 && beam_batch_size == batch_beam_size) ? logits_data : next_token_logits.data(), + next_token_scores.data(), + true, + thread_pool)); #ifdef DEBUG_BEAM_SEARCH dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size); diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc index d8ea5b2dceb2..f14c28e032ce 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -149,18 +149,13 @@ Status T5DecoderSubgraph::CreateInitialFeeds( // The encoder_attention_mask is copied from the second input of encoder. OrtValue expanded_decoder_attention_masks; - std::cout << "expanded_decoder_attention_masks 149" << std::endl; ORT_RETURN_IF_ERROR(expand_buffer_int32_func(stream, encoder_feeds[1], num_beam, allocator, expanded_decoder_attention_masks, false)); - // BeamSearchCpuDeviceHelper::ExpandBuffer(encoder_feeds[1], - // num_beam, - // allocator, - // expanded_decoder_attention_masks, false); - std::cout << "after crash 149" << std::endl; + decoder_feeds.push_back(expanded_decoder_attention_masks); // When first_past_input_index_ == 3, the encoder_hidden_states and past states are copied from the second output @@ -168,6 +163,7 @@ Status T5DecoderSubgraph::CreateInitialFeeds( // When first_past_input_index_ == 2, the past states are copied from the second output of encoder. 
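   // Hence the loop below starts at j = 4 - first_past_input_index_: the hidden-states output (j == 1)
   // is expanded only when first_past_input_index_ == 3; otherwise it is skipped and only the caches are expanded.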
   for (size_t j = 4 - first_past_input_index_; j < encoder_fetches.size(); j++) {
     if (j == 1) {
+      ORT_RETURN_IF(has_hidden_state_ == false, "Invalid hidden_states expansion: has_hidden_state_ == false");
       OrtValue expanded_hidden_states;
       if (is_output_float16_) {
         ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
@@ -184,8 +180,6 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
                                                      expanded_hidden_states,
                                                      true));
       }
-
-      //BeamSearchCpuDeviceHelper::ExpandBuffer(encoder_fetches[j], num_beam, allocator, expanded_hidden_states, true);
       decoder_feeds.push_back(expanded_hidden_states);
     } else {
       OrtValue expanded_cache;
@@ -204,7 +198,6 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
                                                      expanded_cache,
                                                      false));
       }
-      //BeamSearchCpuDeviceHelper::ExpandBuffer(encoder_fetches[j], num_beam, allocator, expanded_cache, false);
       decoder_feeds.push_back(expanded_cache);
     }
   }
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
index b693e7f97d3c..96d30d13fbc5 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
@@ -231,16 +231,18 @@ Status ProcessLogits(const OrtValue& logits,  //
   gsl::span<T>& next_token_logits = beam_state->next_token_logits;
   // TODO(tianleiwu): use one kernel to replace a loop of memory copy.
-  const CudaT* current_logits = logits_data + (input_length - 1) * vocab_size;
-  for (int i = 0; i < batch_beam_size; i++) {
-    gsl::span<const T> source(reinterpret_cast<const T*>(current_logits), vocab_size);
-    gsl::span<T> target = next_token_logits.subspan(i * vocab_size, vocab_size);
-    CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size,
-                                         cudaMemcpyDeviceToDevice, cuda_stream));
-    if (beam_batch_size == batch_beam_size) {
-      current_logits += input_length * vocab_size;
-    } else if (beam_batch_size == batch_size && i % num_beams == num_beams - 1) {
-      current_logits += input_length * vocab_size;
+  if (input_length > 1 || beam_batch_size == batch_size) {
+    const CudaT* current_logits = logits_data + (input_length - 1) * vocab_size;
+    for (int i = 0; i < batch_beam_size; i++) {
+      gsl::span<const T> source(reinterpret_cast<const T*>(current_logits), vocab_size);
+      gsl::span<T> target = next_token_logits.subspan(i * vocab_size, vocab_size);
+      CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size,
+                                           cudaMemcpyDeviceToDevice, cuda_stream));
+      if (beam_batch_size == batch_beam_size) {
+        current_logits += input_length * vocab_size;
+      } else if (beam_batch_size == batch_size && i % num_beams == num_beams - 1) {
+        current_logits += input_length * vocab_size;
+      }
     }
   }
@@ -254,7 +256,9 @@ Status ProcessLogits(const OrtValue& logits,  //
   // The output will be float to preserve precision and to integrate easily with the remaining parts.
   float* Y_data = next_token_scores.data();
 
-  const CudaT* X_data = reinterpret_cast<const CudaT*>(next_token_logits.data());
+  const CudaT* X_data = (input_length == 1 && beam_batch_size == batch_beam_size) ?
+                        logits_data :
+                        reinterpret_cast<const CudaT*>(next_token_logits.data());
 
   dispatch_blockwise_softmax_forward<CudaT, float, float, true>(
       cuda_stream, Y_data, X_data, vocab_size, vocab_size, batch_size * num_beams);

From deba41b633e73e826060db1c48449e848affadb1 Mon Sep 17 00:00:00 2001
From: wangye
Date: Wed, 22 Jun 2022 16:16:39 +0000
Subject: [PATCH 08/10] alignment

---
 .../cpu/transformers/beam_search_device_helper.cc | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
index ee2c845dde23..ac6a11391790 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
@@ -62,11 +62,11 @@ void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator,
 template <typename T>
 Status ExpandBuffer(void* stream,
-                  const OrtValue& input,
-                  int num_beams,
-                  AllocatorPtr allocator,
-                  OrtValue& expanded,
-                  bool only_copy_shape) {
+                    const OrtValue& input,
+                    int num_beams,
+                    AllocatorPtr allocator,
+                    OrtValue& expanded,
+                    bool only_copy_shape) {
   // Input shape (batch_size, xxx). The input must have data type T.
   // Output shape (batch_size * num_beams, xxx)
   ORT_UNUSED_PARAMETER(stream);

From 5112922543bfa750ca3d4cd6d86c63464fc3fd88 Mon Sep 17 00:00:00 2001
From: wangye
Date: Wed, 22 Jun 2022 16:20:21 +0000
Subject: [PATCH 09/10] more alignments

---
 .../cpu/transformers/subgraph_t5_decoder.cc | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
index f14c28e032ce..7918a333094c 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -167,11 +167,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
       OrtValue expanded_hidden_states;
       if (is_output_float16_) {
         ORT_RETURN_IF_ERROR(expand_buffer_float16_func(stream,
-                                                     encoder_fetches[j],
-                                                     num_beam,
-                                                     allocator,
-                                                     expanded_hidden_states,
-                                                     true));
+                                                       encoder_fetches[j],
+                                                       num_beam,
+                                                       allocator,
+                                                       expanded_hidden_states,
+                                                       true));
       } else {
         ORT_RETURN_IF_ERROR(expand_buffer_float_func(stream,
                                                      encoder_fetches[j],

From d669a16cb9a45a3014af73ca63d295dbc4481520 Mon Sep 17 00:00:00 2001
From: wangye
Date: Wed, 22 Jun 2022 17:03:30 +0000
Subject: [PATCH 10/10] review comments

---
 .../transformers/beam_search_device_helper.cc | 17 ++++++++++-------
 .../transformers/beam_search_device_helper.cc | 14 ++++++++------
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
index ac6a11391790..7b163dd923a3 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
@@ -60,6 +60,7 @@ void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator,
   }
 }
 
+// TODO(wy): Dispatch it to avoid passing multiple functions to the interface.
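+// (One possible dispatch, sketched here as a suggestion only: expose a single type-erased callback that
+//  inspects the tensor's MLDataType and forwards to ExpandBuffer<int32_t>/<float>/<MLFloat16> internally,
+//  so T5DecoderSubgraph::CreateInitialFeeds would take one function instead of three.)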
 template <typename T>
 Status ExpandBuffer(void* stream,
                     const OrtValue& input,
                     int num_beams,
                     AllocatorPtr allocator,
                     OrtValue& expanded,
                     bool only_copy_shape) {
@@ -241,23 +242,23 @@ Status ProcessLogits(const OrtValue& logits,  //
   const TensorShape& logits_shape = logits.Get<Tensor>().Shape();
   ORT_ENFORCE(logits_shape.NumDimensions() == 3);
   auto input_length = logits_shape[1];
-  auto beam_batch_size = logits_shape[0];
+  auto logits_batch_size = logits_shape[0];
 
   // Get logits for the last token:
   //   next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size)
   // When input_length == 1, use logits directly in SoftmaxCPU below, so the copy here is only needed for input_length > 1.
   gsl::span<T>& next_token_logits = beam_state->next_token_logits;
-  if (input_length > 1 || beam_batch_size == batch_size) {
+  if (input_length > 1 || logits_batch_size == batch_size) {
     const T* current_logits = logits_data + (input_length - 1) * vocab_size;
     for (int i = 0; i < batch_beam_size; i++) {
       gsl::span<const T> source(current_logits, vocab_size);
       gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size,
                                                       static_cast<gsl::index>(vocab_size));
       gsl::copy(source, target);
-      if (beam_batch_size == batch_beam_size) {
+      if (logits_batch_size == batch_beam_size) {
         current_logits += input_length * vocab_size;
-      } else if (beam_batch_size == batch_size && i % num_beams == num_beams - 1) {
+      } else if (logits_batch_size == batch_size && i % num_beams == num_beams - 1) {
         current_logits += input_length * vocab_size;
       }
     }
@@ -265,7 +266,9 @@ Status ProcessLogits(const OrtValue& logits,  //
 #ifdef DEBUG_BEAM_SEARCH
   dumper->Print("logits", logits);
-  dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
+  if (input_length > 1 || logits_batch_size == batch_size) {
+    dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
+  }
 #endif
 
   // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
@@ -274,7 +277,7 @@ Status ProcessLogits(const OrtValue& logits,  //
       SoftmaxCPU<T>(
           batch_beam_size,  // rows
           vocab_size,       // elements per row
-          (input_length == 1 && beam_batch_size == batch_beam_size) ? logits_data : next_token_logits.data(),
+          (input_length == 1 && logits_batch_size == batch_beam_size) ? logits_data : next_token_logits.data(),
           next_token_scores.data(),
           true,
           thread_pool));
@@ -640,7 +643,7 @@ Status UpdateDecoderFeeds(
   TensorShape input_ids_shape(&dims[0], 2);
   Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);
 
-  // TODO: decouple has_hidden_state with full input_ids
+  // TODO(wy): decouple has_hidden_state with full input_ids
   if (has_hidden_state) {
     gsl::copy(beam_next_tokens, input_ids.GetMutable<Tensor>()->MutableDataAsSpan<int32_t>());
   } else {
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
index 96d30d13fbc5..b712908259da 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc
@@ -221,7 +221,7 @@ Status ProcessLogits(const OrtValue& logits,  //
   const TensorShape& logits_shape = logits.Get<Tensor>().Shape();
   ORT_ENFORCE(logits_shape.NumDimensions() == 3);
   auto input_length = logits_shape[1];
-  auto beam_batch_size = logits_shape[0];
+  auto logits_batch_size = logits_shape[0];
 
   cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
 
@@ -231,16 +231,16 @@ Status ProcessLogits(const OrtValue& logits,  //
   gsl::span<T>& next_token_logits = beam_state->next_token_logits;
   // TODO(tianleiwu): use one kernel to replace a loop of memory copy.
-  if (input_length > 1 || beam_batch_size == batch_size) {
+  if (input_length > 1 || logits_batch_size == batch_size) {
     const CudaT* current_logits = logits_data + (input_length - 1) * vocab_size;
     for (int i = 0; i < batch_beam_size; i++) {
       gsl::span<const T> source(reinterpret_cast<const T*>(current_logits), vocab_size);
       gsl::span<T> target = next_token_logits.subspan(i * vocab_size, vocab_size);
       CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size,
                                            cudaMemcpyDeviceToDevice, cuda_stream));
-      if (beam_batch_size == batch_beam_size) {
+      if (logits_batch_size == batch_beam_size) {
         current_logits += input_length * vocab_size;
-      } else if (beam_batch_size == batch_size && i % num_beams == num_beams - 1) {
+      } else if (logits_batch_size == batch_size && i % num_beams == num_beams - 1) {
         current_logits += input_length * vocab_size;
       }
     }
@@ -248,7 +248,9 @@ Status ProcessLogits(const OrtValue& logits,  //
 
 #ifdef DEBUG_BEAM_SEARCH
   dumper->Print("logits", logits);
-  dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
+  if (input_length > 1 || logits_batch_size == batch_size) {
+    dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
+  }
 #endif
 
   // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
@@ -256,7 +258,7 @@ Status ProcessLogits(const OrtValue& logits,  //
   // The output will be float to preserve precision and to integrate easily with the remaining parts.
   float* Y_data = next_token_scores.data();
 
-  const CudaT* X_data = (input_length == 1 && beam_batch_size == batch_beam_size) ?
+  const CudaT* X_data = (input_length == 1 && logits_batch_size == batch_beam_size) ?
                         logits_data :
                         reinterpret_cast<const CudaT*>(next_token_logits.data());
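
The pattern shared by ExpandInputs, ExpandCaches, and ExpandBuffer in this series is a plain replication of the leading batch dimension: row i of a (batch_size, chunk) buffer is copied num_beams times so the result has shape (batch_size * num_beams, chunk). Below is a minimal, self-contained host-side sketch of that pattern. It is illustrative only: ExpandByBeams and its parameters are not part of the ONNX Runtime API, and the real helpers operate on OrtValue tensors (on CUDA, via cudaMemcpyAsync on device memory) rather than std::vector.

    #include <cstring>
    #include <vector>

    // Hypothetical stand-alone illustration of the ExpandBuffer pattern:
    // replicate each of the batch_size rows num_beams times, turning a flat
    // (batch_size, chunk_size) buffer into a (batch_size * num_beams, chunk_size) one.
    template <typename T>
    std::vector<T> ExpandByBeams(const std::vector<T>& input, int batch_size, int num_beams) {
      const size_t chunk_size = input.size() / static_cast<size_t>(batch_size);
      std::vector<T> expanded(input.size() * static_cast<size_t>(num_beams));
      T* target = expanded.data();
      for (int i = 0; i < batch_size; i++) {
        const T* source = input.data() + static_cast<size_t>(i) * chunk_size;
        for (int j = 0; j < num_beams; j++) {
          std::memcpy(target, source, sizeof(T) * chunk_size);  // copy batch row i for beam j
          target += chunk_size;
        }
      }
      return expanded;
    }

    // Usage: ExpandByBeams<float>({1, 2, 3, 4}, /*batch_size=*/2, /*num_beams=*/2)
    // returns {1, 2, 1, 2, 3, 4, 3, 4}.

Note also the role of only_copy_shape in the patches above: the encoder_hidden_states feed is expanded with only_copy_shape == true, so only the enlarged tensor is allocated and no data is copied, while the past-state caches use only_copy_shape == false and get the full replication shown here.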