
Enable inference with a merged decoder in ORTModelForCausalLM #647

Merged — fxmarty merged 78 commits into main from enable-merged-modeling on Feb 15, 2023

Conversation

@JingyaHuang (Collaborator) commented on Dec 27, 2022

What does this PR do?

Enable the use of merged decoders in ORT modeling.

  • Check that it works for large protos, and add a saving option.
  • Enable loading, saving, and automatic merging in ORTModelDecoder.
  • Enable inference with a merged model in ORTDecoder (new use_cache input + dummy inputs for past_key_values).
  • Enable IOBinding for the merged model (bind the new use_cache input; see the sketch after this list).
  • Tests.
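For the IOBinding point, here is a minimal sketch of how the extra boolean input could be bound (the file name and input name are illustrative assumptions; bind_cpu_input is the standard ONNX Runtime call for host-side inputs):

```python
import numpy as np
import onnxruntime as ort

# Illustrative path: the merged decoder produced by this PR.
session = ort.InferenceSession(
    "decoder_model_merged.onnx", providers=["CPUExecutionProvider"]
)
io_binding = session.io_binding()

# The cache flag is a tiny boolean tensor, so it can stay on host memory
# even when the other inputs and outputs are bound to device buffers
# (the other bind calls are elided here).
io_binding.bind_cpu_input("use_cache_branch", np.array([True]))
```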

To discuss:

  • Where should the merging be applied? Should it happen automatically?

In the current logic, if use_merged=True, merging is inferred and applied automatically when necessary. But we could also add a merging option to the exporter. A usage sketch follows below.
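As an illustration, a minimal sketch of what end-to-end inference with a merged decoder could look like (the model name, the from_transformers export flag, and the generation arguments are illustrative; use_merged is the option discussed above):

```python
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# use_merged=True loads (or produces) a single decoder ONNX holding both the
# no-cache and with-cache branches behind a boolean control-flow input.
model = ORTModelForCausalLM.from_pretrained(
    "gpt2", from_transformers=True, use_merged=True
)

inputs = tokenizer("Merged decoders avoid shipping two decoder graphs", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
```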

@HuggingFaceDocBuilderDev commented on Dec 27, 2022

The documentation is not available anymore as the PR was closed or merged.

@JingyaHuang JingyaHuang changed the title Enable merged decoder in ORTModel Enable inference with a merged decoder in ORTModelForCausalLM Jan 2, 2023
@JingyaHuang JingyaHuang marked this pull request as ready for review January 4, 2023 18:10
@@ -263,20 +288,57 @@ def prepare_io_binding(

return io_binding, output_shapes, output_buffers

def prepare_inputs_for_merged(
A Collaborator commented:

Can you add a short description of this method?
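(For context, a hedged sketch of what this helper is meant to do, based on the PR description's "new use_cache input + dummy inputs for past_key_values"; the signature, shapes, and dtype here are illustrative, not the actual implementation:)

```python
import numpy as np

def prepare_inputs_for_merged(past_key_values, num_layers, num_heads, head_dim, batch_size):
    """Prepare the extra inputs the merged decoder always expects.

    On the first forward pass there is no cache yet, so we feed dummy
    past_key_values tensors and point the control flow at the no-cache branch.
    """
    if past_key_values is None:
        # Dummy past of sequence length 1; the no-cache branch ignores it.
        dummy = np.zeros((batch_size, num_heads, 1, head_dim), dtype=np.float32)
        past_key_values = [dummy] * (2 * num_layers)
        use_cache_branch = np.array([False])
    else:
        use_cache_branch = np.array([True])
    return use_cache_branch, past_key_values
```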

@JingyaHuang (Collaborator, Author) left a comment:

Some questions

optimum/commands/export/onnx.py (outdated, resolved)
@@ -507,6 +517,8 @@ def __init__(
f"use_past = {use_past} is different than use_present_in_outputs = {use_present_in_outputs}, the value "
"of use_present_in_outputs value will be used for the outputs."
)
self.is_merged = False
self.use_cache_branch = None
@JingyaHuang (Collaborator, Author) commented:

What's the difference between use_cache_branch, use_past, and use_past_in_inputs? I understand that use_cache_branch is only for the merged decoder case, but why do we need to distinguish them?

And does use_cache_branch imply use_past=True?

@fxmarty (Collaborator) commented on Feb 14, 2023:

> does use_cache_branch imply use_past=True?

Yes, in other cases use_cache_branch does not make sense.

About the difference between use_past and use_past_in_inputs: it seems like legacy code that could be simplified. Or am I missing something, @michaelbenayoun?

use_cache_branch is a flag indicating that, for the merged decoder case, we use the cache branch of the control flow. This flag is used in several places:

[screenshots of the code locations where use_cache_branch is used]
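(To make the merged-graph idea concrete: a small inspection sketch, assuming a merged decoder file named decoder_model_merged.onnx; the file name is illustrative:)

```python
import onnx

model = onnx.load("decoder_model_merged.onnx")
# Alongside input_ids / attention_mask / past_key_values.*, the merged graph
# exposes the boolean flag that selects the cache branch of the control flow.
print([inp.name for inp in model.graph.input])
```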

A Member replied:

use_past is the legacy flag here. Basically you have two "use past" notions:

  1. use_past_in_inputs: the inputs will have past key values
  2. use_present_in_outputs: the outputs will have past key values

If you set only use_past, it sets both. A sketch follows below.
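(A minimal sketch of that relationship, using the attribute names from this thread; the surrounding OnnxConfig machinery is simplified away and this is not the actual implementation:)

```python
class DecoderOnnxConfigSketch:
    def __init__(self, use_past=False, use_past_in_inputs=None, use_present_in_outputs=None):
        # Legacy switch: when the finer-grained flags are not given,
        # use_past drives both of them.
        self.use_past_in_inputs = use_past if use_past_in_inputs is None else use_past_in_inputs
        self.use_present_in_outputs = use_past if use_present_in_outputs is None else use_present_in_outputs
        # Merged-decoder bookkeeping added in this PR (see the diff above).
        self.is_merged = False
        self.use_cache_branch = None
```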

@JingyaHuang (Collaborator, Author) left a comment:

Other questions for you @fxmarty

optimum/exporters/onnx/constants.py (resolved)
optimum/exporters/onnx/base.py (resolved)
optimum/exporters/onnx/config.py (resolved)
@JingyaHuang (Collaborator, Author) left a comment:

LGTM! Thanks a lot for helping to wrap up the PR.

optimum/onnxruntime/modeling_decoder.py (outdated, resolved)
optimum/exporters/onnx/config.py (resolved)
fxmarty and others added 2 commits February 14, 2023 14:30
Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>
optimum/exporters/onnx/base.py (resolved)

optimum/onnx/graph_transformations.py (outdated, resolved ×4)
optimum/onnxruntime/base.py (resolved)
optimum/onnxruntime/modeling_decoder.py (outdated, resolved)
optimum/onnxruntime/modeling_seq2seq.py (resolved)
fxmarty and others added 11 commits February 14, 2023 19:18
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com> (×5)
@fxmarty fxmarty merged commit 2155fbe into main Feb 15, 2023
@fxmarty fxmarty deleted the enable-merged-modeling branch February 15, 2023 15:14