Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
698c663
change interface of BeamSearch op
tianleiwu Apr 11, 2022
23c1e38
Merge remote-tracking branch 'origin/master' into EncoderDecoderBeamS…
tianleiwu Apr 27, 2022
75d9168
update interface for cpu
tianleiwu May 3, 2022
ee950f1
update decoder i/o names
tianleiwu May 3, 2022
c4580d0
Add comments for subgraph inputs/outputs
tianleiwu May 3, 2022
261106e
int64 inputs
tianleiwu May 3, 2022
0a4f2a3
move cpu state out of InitBeamState
tianleiwu May 4, 2022
f2a31ca
fix linux build
tianleiwu May 6, 2022
dfc2db3
fix cuda build
tianleiwu May 6, 2022
5c321a8
encoder subgraph dynamic axes to fixed dim value
tianleiwu May 17, 2022
9287bd3
Merge remote-tracking branch 'origin/master' into EncoderDecoderBeamS…
tianleiwu May 18, 2022
beddbd7
add --separate_encoder_and_decoder_init
tianleiwu May 18, 2022
9ad1ec4
correct name t5-3B => t5-3b, t5-11B => t5-11b
tianleiwu May 18, 2022
59042e7
fix export
tianleiwu May 18, 2022
71bd39c
Merge branch 'EncoderDecoderBeamSearch' of https://github.com/microso…
tianleiwu May 18, 2022
5630631
update comments
tianleiwu May 18, 2022
c5687c9
Enable dump
tianleiwu May 18, 2022
691cd4e
update sequence length, and no position for T5
tianleiwu May 24, 2022
71aaba9
update decoder inputs
tianleiwu May 24, 2022
bc28320
include header with directory
tianleiwu May 24, 2022
664c75c
fix line_length>120 and spelling errors
tianleiwu May 25, 2022
2eb4ba3
make sure line_length <= 120
tianleiwu May 25, 2022
2d331d3
fix cpplint warnings
tianleiwu May 26, 2022
219479f
include relative path
tianleiwu May 26, 2022
25397bf
Merge branch 'EncoderDecoderBeamSearch' of https://github.com/microso…
tianleiwu May 31, 2022
6777f47
add --use_int32_inputs in convert t5 to onnx
tianleiwu May 31, 2022
237546f
update cuda device helper to use int32 inputs
tianleiwu Jun 1, 2022
9ae7253
Fix lint warnings
tianleiwu Jun 1, 2022
ea4963d
Merge remote-tracking branch 'origin/master' into EncoderDecoderBeamS…
tianleiwu Jun 1, 2022
ae0e72d
Merge branch 'EncoderDecoderBeamSearch' of https://github.com/microso…
tianleiwu Jun 1, 2022
235f63d
use constexpr for past input/output index
tianleiwu Jun 2, 2022
3f180fe
initial sequence length 1
tianleiwu Jun 2, 2022
fb5c5d7
fix ORT_ENFORCE of hypothesis_buffer_offset_
tianleiwu Jun 2, 2022
9b2c4d0
init sequence from decoder_input_ids in cpu
tianleiwu Jun 6, 2022
28b934e
Merge branch 'EncoderDecoderBeamSearch' of https://github.com/microso…
tianleiwu Jun 6, 2022
dd1aee2
fix cuda error
tianleiwu Jun 6, 2022
313a3b5
fix linux build
tianleiwu Jun 6, 2022
ac6b9e6
add tests and default int32 inputs
tianleiwu Jun 7, 2022
c3beb84
Merge branch 'EncoderDecoderBeamSearch' of https://github.com/Microso…
tianleiwu Jun 7, 2022
29deedb
add more tests; allow t5 in one step
tianleiwu Jun 7, 2022
221a052
update for lint warning
tianleiwu Jun 7, 2022
05142ff
remove useless comment
tianleiwu Jun 7, 2022
d3cff62
update comments; disable dump
tianleiwu Jun 8, 2022
13f113d
add comments
tianleiwu Jun 8, 2022
84e4729
address review feedback
tianleiwu Jun 10, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,12 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>decoder</tt> : graph (required)</dt>
<dd>Decoder subgraph to execute in a loop.</dd>
<dt><tt>decoder_start_token_id</tt> : int</dt>
<dd>The id of the token that indicates decoding starts.</dd>
<dt><tt>early_stopping</tt> : int</dt>
<dd>Whether to enable early stopping.</dd>
<dt><tt>encoder_decoder_init</tt> : graph</dt>
<dd>subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.</dd>
<dt><tt>encoder</tt> : graph</dt>
<dd>The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.</dd>
<dt><tt>eos_token_id</tt> : int (required)</dt>
<dd>The id of the end-of-sequence token</dd>
<dt><tt>model_type</tt> : int</dt>
Expand Down
685 changes: 127 additions & 558 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc

Large diffs are not rendered by default.

92 changes: 61 additions & 31 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
// Licensed under the MIT License.

#pragma once

#include <memory>
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/controlflow/utils.h"
#include "beam_search_parameters.h"
#include "gpt_subgraph.h"
#include "beam_search_device_helper.h"
#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h"
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
#include "contrib_ops/cpu/transformers/beam_search_device_helper.h"

namespace onnxruntime {
class FeedsFetchesManager;
Expand All @@ -20,7 +24,11 @@ using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel
class BeamSearch : public IControlFlowKernel {
public:
BeamSearch(const OpKernelInfo& info)
: IControlFlowKernel(info), feeds_fetches_manager_(nullptr), cuda_stream_(nullptr), dumper_(nullptr) {
: IControlFlowKernel(info),
encoder_feeds_fetches_manager_(nullptr),
decoder_feeds_fetches_manager_(nullptr),
cuda_stream_(nullptr),
dumper_(nullptr) {
Init(info);
}

Expand All @@ -36,54 +44,76 @@ class BeamSearch : public IControlFlowKernel {
// Stores the caller-supplied stream pointer in cuda_stream_
// (the constructor initializes cuda_stream_ to nullptr).
void SetComputeStream(void* stream) { cuda_stream_ = stream; }
// Stores the caller-supplied dumper pointer in dumper_
// (the constructor initializes dumper_ to nullptr).
void SetConsoleDumper(IConsoleDumper* dumper) { dumper_ = dumper; }

// Device helpers that are the same for both GPT and encoder-decoder models.
void SetDeviceHelpers(
// const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func,
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const BeamSearchDeviceHelper::TopkFunc& topk_func) {
// create_inputs_func_ = create_inputs_func;
add_to_feeds_func_ = add_to_feeds_func;
topk_func_ = topk_func;
}

// Type dependent helpers: float
void SetDeviceHelpers(
const BeamSearchDeviceHelper::TopkFunc& topk_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
const BeamSearchDeviceHelper::ProcessLogitsFunc<float>& process_logits_func,
const BeamSearchDeviceHelper::ProcessLogitsFunc<MLFloat16>& process_logits_fp16_func,
const BeamSearchDeviceHelper::InitBeamStateFunc<float>& init_beam_state_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
const BeamSearchDeviceHelper::UpdateFeedsFunc<float>& update_feeds_func) {
const BeamSearchDeviceHelper::InitBeamStateFunc<MLFloat16>& init_beam_state_fp16_func) {
add_to_feeds_func_ = add_to_feeds_func;
topk_func_ = topk_func;
device_copy_func_ = device_copy_func;
device_copy_int32_func_ = device_copy_int32_func;
process_logits_func_ = process_logits_func;
process_logits_fp16_func_ = process_logits_fp16_func;
init_beam_state_func_ = init_beam_state_func;
device_copy_func_ = device_copy_func;
update_feeds_func_ = update_feeds_func;
init_beam_state_fp16_func_ = init_beam_state_fp16_func;
}

// Type dependent helpers: MLFloat16
void SetDeviceHelpers(
const BeamSearchDeviceHelper::ProcessLogitsFunc<MLFloat16>& process_logits_func,
const BeamSearchDeviceHelper::InitBeamStateFunc<MLFloat16>& init_beam_state_func,
const BeamSearchDeviceHelper::UpdateFeedsFunc<MLFloat16>& update_feeds_func) {
process_logits_fp16_func_ = process_logits_func;
init_beam_state_fp16_func_ = init_beam_state_func;
update_feeds_fp16_func_ = update_feeds_func;
// Device helpers for GPT models.
// Copies the two feed-update callbacks (one for float, one for MLFloat16)
// into the corresponding members.
void SetDeviceHelpers_Gpt(
const BeamSearchDeviceHelper::UpdateGptFeedsFunc<float>& update_gpt_feeds_func,
const BeamSearchDeviceHelper::UpdateGptFeedsFunc<MLFloat16>& update_gpt_feeds_fp16_func) {
update_gpt_feeds_func_ = update_gpt_feeds_func;
update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func;
}

// Device helpers for encoder-decoder models like T5.
// Copies the two decoder feed-update callbacks (one for float, one for MLFloat16)
// into the corresponding members.
void SetDeviceHelpers_EncoderDecoder(
const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float>& update_decoder_feeds_func,
const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func) {
update_decoder_feeds_func_ = update_decoder_feeds_func;
update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func;
}

private:
// Device specific functions
BeamSearchDeviceHelper::CreateInputsFunc create_inputs_func_;
BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
BeamSearchDeviceHelper::TopkFunc topk_func_;
BeamSearchDeviceHelper::ProcessLogitsFunc<float> process_logits_func_;
BeamSearchDeviceHelper::InitBeamStateFunc<float> init_beam_state_func_;
BeamSearchDeviceHelper::DeviceCopyFunc<float> device_copy_func_;
BeamSearchDeviceHelper::UpdateFeedsFunc<float> update_feeds_func_;
BeamSearchDeviceHelper::DeviceCopyFunc<int32_t> device_copy_int32_func_;

BeamSearchDeviceHelper::ProcessLogitsFunc<float> process_logits_func_;
BeamSearchDeviceHelper::ProcessLogitsFunc<MLFloat16> process_logits_fp16_func_;

BeamSearchDeviceHelper::InitBeamStateFunc<float> init_beam_state_func_;
BeamSearchDeviceHelper::InitBeamStateFunc<MLFloat16> init_beam_state_fp16_func_;
BeamSearchDeviceHelper::UpdateFeedsFunc<MLFloat16> update_feeds_fp16_func_;

//------------------------------------------------------------
// Device specific functions for GPT
//------------------------------------------------------------
BeamSearchDeviceHelper::UpdateGptFeedsFunc<float> update_gpt_feeds_func_;
BeamSearchDeviceHelper::UpdateGptFeedsFunc<MLFloat16> update_gpt_feeds_fp16_func_;

//------------------------------------------------------------
// Device specific functions for encoder-decoder model like T5
//------------------------------------------------------------
BeamSearchDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_;

BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float> update_decoder_feeds_func_;
BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16> update_decoder_feeds_fp16_func_;

//------------------------------------------------------------
// Subgraph and FeedsFetchesManager re-used for each subgraph execution.
//------------------------------------------------------------
std::unique_ptr<GptSubgraph> gpt_subgraph_;
FeedsFetchesManager* feeds_fetches_manager_;
std::unique_ptr<T5EncoderSubgraph> t5_encoder_subgraph_;
std::unique_ptr<T5DecoderSubgraph> t5_decoder_subgraph_;
FeedsFetchesManager* encoder_feeds_fetches_manager_;
FeedsFetchesManager* decoder_feeds_fetches_manager_;

void* cuda_stream_;

Expand Down
Loading