PagedAttention Transformation: Rank alignment for replacements #24690
…_heads dimension broadcasted in SDPA itself and not in the UBR pattern.
…ttention/prev_sequence_length_pattern.cpp Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
…tion. Allowed optional Reshape in UBR pattern (appeared in one of MQA cases).
…n matching with Or pattern and multi-output nodes.
Confirm, that expected models are fixed by this PR.
}
if (replacement->get_output_element_type(0) != target_type) {
Is it possible that we will have f16 -> fp32 or int -> fp conversions here?
Could this cause accuracy issues?
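To make the accuracy question concrete, here is a small standalone sketch. It does not use OpenVINO, and the thread does not say which conversions actually occur in this transformation. Widening f16 to f32 is exact, because every half-precision value is representable in single precision. Integer to float can round: a 32-bit float has a 24-bit significand, so integers with magnitude above 2^24 are not all representable. The helper name `survives_float_round_trip` is hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: returns true when an int32 value is unchanged after
// being converted to float and back.
bool survives_float_round_trip(std::int32_t value) {
    const float as_float = static_cast<float>(value);
    // Converting a float outside the int32 range back to int32 is undefined
    // behavior, so treat such values as not surviving the round trip.
    if (as_float >= 2147483648.0f || as_float < -2147483648.0f) {
        return false;
    }
    return static_cast<std::int32_t>(as_float) == value;
}
```

With this helper, `survives_float_round_trip(16777216)` holds, while `survives_float_round_trip(16777217)` does not, since 2^24 + 1 rounds to 2^24 in single precision.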
replacement = std::make_shared<v0::Convert>(replacement, target_type);
}
auto required_shape = gather->get_output_partial_shape(0);
if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
minor: it's better to change the order of the checks, so that the cheap rank check runs first:
required_shape.rank().is_static() && replacement->get_output_partial_shape(0) != required_shape
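The suggestion relies on `&&` short-circuiting: when the left operand is false, the right operand is never evaluated. Both orderings give the same result; the reordered one just skips the shape comparison when the required rank is dynamic. A minimal standalone sketch of that effect, using toy types rather than OpenVINO's (`ToyShape`, `CompareCounter`, `shapes_differ` and `needs_alignment` are all hypothetical names):

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Toy stand-in for a partial shape: no value means "dynamic rank".
using ToyShape = std::optional<std::vector<long>>;

// Counts how many times the (comparatively expensive) comparison runs.
struct CompareCounter {
    int comparisons = 0;
};

bool shapes_differ(const ToyShape& a, const ToyShape& b, CompareCounter& counter) {
    ++counter.comparisons;
    return a != b;
}

// Cheap static-rank check first; the comparison runs only if it passes.
bool needs_alignment(const ToyShape& current, const ToyShape& required, CompareCounter& counter) {
    return required.has_value() && shapes_differ(current, required, counter);
}
```

Calling `needs_alignment` with a dynamic `required` shape leaves `comparisons` at zero, which is the saving the reviewer is pointing at.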
@CuriousPanCake did we test these changes via #24634?
Not yet. I'll test ASAP
See "Confirm, that expected models are fixed by this PR." above
b0dfa6a into openvinotoolkit:master
…inotoolkit#24690) During the elimination of dependencies from `beam_idx` input and `ReadValue`(s), we are replacing them by the new PA-related inputs and sub-expressions dependent on other remaining inputs. In such replacements we need to guarantee matching shape and element type of old and new nodes. Before this PR it was not guaranteed for shape and sometimes a scalar was replaced by a shape of rank 1 that led to errors like `'start' input is not a scalar`. Now the shape is aligned. --------- Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com> Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
During the elimination of dependencies on the `beam_idx` input and `ReadValue`(s), we replace them with the new PA-related inputs and with sub-expressions that depend on the other remaining inputs. In such replacements we need to guarantee that the shape and element type of the old and new nodes match. Before this PR this was not guaranteed for the shape, and sometimes a scalar was replaced by a value of rank 1, which led to errors like `'start' input is not a scalar`. Now the shape is aligned.
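The rank mismatch described above can be sketched in isolation. This is a toy model under stated assumptions, not the code from this PR: shapes are plain vectors of static dimensions, and the hypothetical helper `align_rank` only handles the single-element case (a scalar versus a shape such as `[1]`), which is the case named in the description.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <vector>

// Toy static shape: an empty vector is a scalar (rank 0).
using Dims = std::vector<std::int64_t>;

// Hypothetical helper: returns a shape of rank `required_rank` that holds the
// same data as `current`, or std::nullopt if that cannot be done here.
std::optional<Dims> align_rank(const Dims& current, std::size_t required_rank) {
    if (current.size() == required_rank) {
        return current;  // Rank already matches, nothing to do.
    }
    const std::int64_t element_count = std::accumulate(
        current.begin(), current.end(), std::int64_t{1}, std::multiplies<std::int64_t>());
    if (element_count != 1) {
        return std::nullopt;  // More than one element: not handled by this sketch.
    }
    // A single element can be viewed as any rank by using dimensions of 1;
    // required_rank == 0 yields the empty shape, i.e. a scalar.
    return Dims(required_rank, 1);
}
```

For example, `align_rank({1}, 0)` yields an empty shape (a scalar), which is the kind of input a check like `'start' input is not a scalar` expects, while `align_rank({2}, 0)` yields `std::nullopt` because two elements cannot be viewed as a scalar.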