Enable inference with a merged decoder in ORTModelForCausalLM
#647
Conversation
The documentation is not available anymore as the PR was closed or merged.
ORTModelForCausalLM
@@ -263,20 +288,57 @@ def prepare_io_binding(

        return io_binding, output_shapes, output_buffers

    def prepare_inputs_for_merged(
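A minimal sketch of what a method like `prepare_inputs_for_merged` could do, under the assumption (the signature, dimensions, and helper names here are illustrative, not Optimum's actual implementation) that a merged decoder always declares `past_key_values` inputs plus a boolean `use_cache_branch` input, so the first generation step must feed zero-length dummies:

```python
import numpy as np

def prepare_inputs_for_merged(past_key_values, num_heads=4, head_dim=8):
    """Hypothetical sketch: in a merged decoder, the first generation step
    has no real past, but ONNX Runtime still requires every declared input.
    So we feed dummy past_key_values with sequence length 0 and a boolean
    use_cache_branch tensor selecting the control-flow branch to run."""
    if past_key_values is None:
        # No cache yet: take the no-cache branch, with zero-length dummies.
        use_cache_branch = np.array([False])
        dummy = np.zeros((1, num_heads, 0, head_dim), dtype=np.float32)
        past_key_values = [(dummy, dummy)]
    else:
        # Real cache available: take the with-cache branch.
        use_cache_branch = np.array([True])
    return use_cache_branch, past_key_values

branch, pkv = prepare_inputs_for_merged(None)
print(branch[0], pkv[0][0].shape)  # False (1, 4, 0, 8)
```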
Can you add a short description of this method?
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
Some questions
@@ -507,6 +517,8 @@ def __init__(
                f"use_past = {use_past} is different than use_present_in_outputs = {use_present_in_outputs}, the value "
                "of use_present_in_outputs value will be used for the outputs."
            )
        self.is_merged = False
        self.use_cache_branch = None
What's the difference between use_cache_branch, use_past, and use_past_in_inputs? I mean that use_cache_branch must be for the merged-decoder case, but why do we need to distinguish them? And does use_cache_branch require use_past=True?
does use_cache_branch require use_past=True?

Yes, in other cases use_cache_branch does not make sense.
About the difference between use_past and use_past_in_inputs, it seems like legacy code that could be simplified. Or am I missing something @michaelbenayoun?

use_cache_branch is a flag indicating that, in the merged decoder case, we use the cache branch of the control flow. This flag is used in several places:
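The branch selection that `use_cache_branch` describes can be sketched as a toy Python function (this mimics the If-node semantics of a merged decoder export; the function and variable names are illustrative, not Optimum's API):

```python
def run_merged_decoder(tokens, past, use_cache_branch):
    """Toy model of the control flow in a merged decoder: a single graph
    contains both subgraphs, and use_cache_branch selects which one runs."""
    if use_cache_branch:
        # With-cache branch: only the newest token is processed,
        # the rest of the context comes from past_key_values.
        new_past = past + [tokens[-1]]
    else:
        # No-cache (prefill) branch: the whole prompt is processed.
        new_past = list(tokens)
    return new_past

past = run_merged_decoder([1, 2, 3], [], use_cache_branch=False)      # prefill
past = run_merged_decoder([1, 2, 3, 4], past, use_cache_branch=True)  # decode
print(past)  # [1, 2, 3, 4]
```

This is also why use_cache_branch only makes sense together with past key values: the with-cache branch is meaningless without them.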
use_past is the legacy here.
Basically you have two "use past":
- use_past_in_inputs: inputs will have past key values
- use_present_in_outputs: outputs will have past key values
If you set only use_past, it sets both.
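The flag semantics described above can be condensed into a small illustrative config class (this is a sketch of the described behavior, not Optimum's actual OnnxConfig; the class name is made up):

```python
class DecoderOnnxConfig:
    """Illustrative sketch: use_past is a coarse legacy flag that, when set
    alone, enables both fine-grained flags at once."""

    def __init__(self, use_past=False, use_past_in_inputs=None,
                 use_present_in_outputs=None):
        # Each fine-grained flag falls back to use_past when not given.
        self.use_past_in_inputs = (
            use_past if use_past_in_inputs is None else use_past_in_inputs
        )
        self.use_present_in_outputs = (
            use_past if use_present_in_outputs is None else use_present_in_outputs
        )

cfg = DecoderOnnxConfig(use_past=True)
print(cfg.use_past_in_inputs, cfg.use_present_in_outputs)  # True True
```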
Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>
Other questions for you @fxmarty
LGTM! Thanks a lot for helping to wrap up the PR.
Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
What does this PR do?
Enable the use of merged decoders in ORT modeling:
- ORTModelDecoder
- ORTDecoder (new input use_cache + dummy inputs for past_key_values)
- … (use_cache)

To discuss:
In the current logic, if use_merged=True, the merging will be automatically inferred and applied if necessary. But maybe we could also add a merging option in the exporter.
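The "automatically inferred and applied if necessary" logic could be sketched like this (the file names and the returned action strings are assumptions for illustration; the real decision logic lives inside Optimum's model-loading code):

```python
from pathlib import Path

def decide_merge_action(model_dir, use_merged=True):
    """Hypothetical sketch: when the user requests a merged decoder but only
    the two separate decoder ONNX files are present, merging would be
    applied on the fly; otherwise the existing merged file is used."""
    model_dir = Path(model_dir)
    merged = model_dir / "decoder_model_merged.onnx"
    if use_merged and not merged.exists():
        decoder = model_dir / "decoder_model.onnx"
        with_past = model_dir / "decoder_with_past_model.onnx"
        if decoder.exists() and with_past.exists():
            return "merge"  # the merging utility would be invoked here
        raise FileNotFoundError("need both decoder variants to merge")
    return "use_existing" if merged.exists() else "no_merge"
```

An exporter-side option, as discussed, would simply move the "merge" action to export time instead of load time.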