Skip to content

Commit

Permalink
PagedAttention Transformation: Rank alignment for replacements (#24690)…
Browse files Browse the repository at this point in the history
… (#24713)

Reapplied #24690 to the release branch.

Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
  • Loading branch information
3 people committed May 27, 2024
1 parent a6335d2 commit 67111cc
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,5 @@ class PrevSequenceLengthPattern;
class ov::pass::PrevSequenceLengthPattern : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("PrevSequenceLengthPattern", "0");
explicit PrevSequenceLengthPattern(const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len,
std::shared_ptr<ov::Node>);
explicit PrevSequenceLengthPattern(std::shared_ptr<ov::Node> prev_max_seq_len, std::shared_ptr<ov::Node> batch_dim);
};
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

using namespace ov::op;

ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(
const std::shared_ptr<ov::op::v1::Subtract>& prev_max_seq_len,
std::shared_ptr<ov::Node> batch_dim) {
ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(std::shared_ptr<ov::Node> prev_max_seq_len,
std::shared_ptr<ov::Node> batch_dim) {
MATCHER_SCOPE(PrevSequenceLengthPattern);
// The transformation addresses two cases that look similar: (1) previous sequence length, (2) batch size in
// kv-cache state In first case it should replace it by prev_max_seq_len. For the second case, connect to batch_dim.
Expand All @@ -39,30 +38,23 @@ ov::pass::PrevSequenceLengthPattern::PrevSequenceLengthPattern(
auto axis = gather_index->cast_vector<int64_t>().at(0);
auto kv_init_shape = pattern_map.at(kv_past).get_node()->get_input_partial_shape(0);
auto target_type = gather->get_output_element_type(0);
std::shared_ptr<ov::Node> replacement;
if (kv_init_shape[axis].is_static() && kv_init_shape[axis].get_length() == 0) {
// this is a sequence dimension based on how the initialization expression is build for stateful models
std::shared_ptr<ov::Node> replacement;
if (prev_max_seq_len->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(prev_max_seq_len, target_type);
} else {
replacement = prev_max_seq_len;
}
replace_node(
gather,
std::make_shared<v1::Reshape>(replacement, v0::Constant::create(element::i64, Shape{1}, {1}), false));
return true;
} else { // assumption that any other axis should point to batch dimension, precise reasoning is too complex
// (TODO)
// this is a batch dimension
std::shared_ptr<ov::Node> replacement;
if (batch_dim->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(batch_dim, target_type);
} else {
replacement = batch_dim;
}
replace_node(gather, replacement);
return true;
replacement = prev_max_seq_len;
} else {
// assumption that any other axis should point to batch dimension, precise reasoning is too complex
// TODO: provide more reliable check
replacement = batch_dim;
}
if (replacement->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(replacement, target_type);
}
auto required_shape = gather->get_output_partial_shape(0);
if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
}
replace_node(gather, replacement);
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(seq, matcher_name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "openvino/cc/pass/itt.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
Expand All @@ -29,11 +30,13 @@ ov::pass::TotalSequenceLengthPattern::TotalSequenceLengthPattern(
// use symbolic infra or look at the constant input
auto gather = m.get_match_root();
auto target_type = gather->get_output_element_type(0);
std::shared_ptr<Node> replacement;
if (max_context_len->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(max_context_len, target_type);
} else {
replacement = max_context_len;
std::shared_ptr<Node> replacement = max_context_len;
if (replacement->get_output_element_type(0) != target_type) {
replacement = std::make_shared<v0::Convert>(replacement, target_type);
}
auto required_shape = gather->get_output_partial_shape(0);
if (replacement->get_output_partial_shape(0) != required_shape && required_shape.rank().is_static()) {
replacement = op::util::reshapeTo(replacement, Shape(required_shape.rank().get_length(), 1));
}
replace_node(gather, replacement);
return true;
Expand Down

0 comments on commit 67111cc

Please sign in to comment.