GPT2 one step beam search update with configuration support#7425

Merged
tianleiwu merged 25 commits into microsoft:master from xi-liu-ds:xi-liu-ds/beam_search_update
Apr 29, 2021

Conversation

@xi-liu-ds
Member

@xi-liu-ds xi-liu-ds commented Apr 22, 2021

Description: This PR updates one-step beam search to support early stopping, temperature, repetition penalty, length penalty, excluded token ids, and sampling in the ONNX graph.

Motivation and Context

  • Customers would like configuration support in beam search. This PR fulfills that purpose by adding a GPT2LMHeadModel_BeamSearchStepConfiguration class that builds the following configuration options directly into the ONNX compute graph.
    • early stopping of finished beams
    • temperature
    • repetition penalty
    • length penalty
    • excluded token ids
    • sampling
    • ignore end of sentence token in model inference
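As a rough sketch (not the PR's actual graph implementation; the helper name and constants below are illustrative), options like temperature, repetition penalty, and excluded token ids are typically applied to the model's next-token logits before each beam-search step:

```python
import numpy as np

def shape_logits(logits, generated_ids, temperature=1.0,
                 repetition_penalty=1.0, excluded_token_ids=()):
    """Hypothetical helper showing how temperature, repetition penalty,
    and excluded token ids commonly reshape next-token logits."""
    logits = logits / temperature  # >1 flattens, <1 sharpens the distribution
    for tid in set(generated_ids):
        # CTRL-style repetition penalty: discourage already-generated tokens.
        if logits[tid] > 0:
            logits[tid] /= repetition_penalty
        else:
            logits[tid] *= repetition_penalty
    for tid in excluded_token_ids:
        logits[tid] = -1e4  # effectively ban excluded tokens
    return logits
```

In the PR these transformations are expressed as ONNX graph operators rather than Python, so they run inside the exported model.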

How to run
In bash/CMD/PowerShell, run:

python onnxruntime/python/tools/transformers/convert_to_onnx.py \
--model_name_or_path=[path/to/model/folder] --model_class=GPT2LMHeadModel_BeamSearchStepConfiguration \
--output=[path/to/output/onnx_file_name] -o --precision=int8

Optionally, add --input_test_file=[path/to/test/file] and set configuration flags (--ignore_eos, --repetition_penalty, --temperature, --excluded_token_ids, --length_penalty, --do_sample, --do_sample_top_p, --do_sample_top_k).

@xi-liu-ds xi-liu-ds requested a review from a team as a code owner April 22, 2021 21:37
@xi-liu-ds xi-liu-ds changed the title GPT2 one step beam search update with early stopping GPT2 one step beam search update with configuration support Apr 22, 2021
Comment thread onnxruntime/python/tools/transformers/convert_to_onnx.py Outdated
@tianleiwu
Contributor

Please add some tests. For example, in test_gpt2.py?

@microsoft microsoft deleted a comment from xi-liu-ds Apr 23, 2021
Comment thread onnxruntime/python/tools/transformers/convert_to_onnx.py Outdated
def top_k_top_p_filtering(log_probs, top_p=1.0, top_k=0):
'''Set tail event (out of top_p) to a big negative number'''
sorted_log_probs, sorted_indices = torch.sort(log_probs, descending=True)
cumulative_probs = torch.cumsum(sorted_log_probs.exp(), dim=-1)
Contributor

@tianleiwu tianleiwu Apr 23, 2021

CumSum for float16 needs ONNX opset version 14. Users will encounter issues with float16 models right now since PyTorch and ORT do not support opset 14 yet.

Member Author

@xi-liu-ds xi-liu-ds Apr 27, 2021

I have tested with PyTorch and ORT 1.7.0 and it works fine. Could you clarify? Thanks
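For reference, the top-k / top-p (nucleus) filtering quoted at the start of this thread can be sketched in plain numpy (an illustrative re-implementation; the PR itself uses torch, and this is not its exact code):

```python
import numpy as np

def top_k_top_p_filtering(log_probs, top_p=1.0, top_k=0, filter_value=-1e4):
    """Mask tokens outside the top-k set and outside the nucleus whose
    cumulative probability just reaches top_p (sketch, not the PR's code)."""
    log_probs = log_probs.copy()
    if top_k > 0:
        # Mask everything below the k-th highest log-probability.
        kth = np.sort(log_probs)[-top_k]
        log_probs[log_probs < kth] = filter_value
    if top_p < 1.0:
        order = np.argsort(log_probs)[::-1]      # descending sort indices
        cumulative = np.cumsum(np.exp(log_probs[order]))
        remove = cumulative > top_p              # tail of the nucleus
        remove[1:] = remove[:-1].copy()          # keep the first token crossing top_p
        remove[0] = False                        # always keep the best token
        log_probs[order[remove]] = filter_value
    return log_probs
```

The reviewer's concern applies to the cumulative-sum step: in the exported graph that torch.cumsum becomes an ONNX CumSum node, which only gained float16 support at opset 14.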

@xi-liu-ds
Member Author

Please add some tests. For example, in test_gpt2.py?

Done

Comment thread onnxruntime/python/tools/transformers/benchmark_gpt2.py Outdated