Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,14 @@ class SplitDimensionM: public CommonOptimizations::SubgraphPass {

void reshape_subgraph(const std::shared_ptr<op::Subgraph>& subgraph, const ov::Shape& shape, size_t batch_m_dim, size_t new_m_dim);

static size_t get_dim_M(const ov::Shape& shape) {
return *(shape.rbegin() + dim_M_index);
}

size_t m_concurrency;

static const size_t min_kernel_m;
static const size_t dim_M_index;
};
} // namespace pass
} // namespace snippets
Expand Down
16 changes: 11 additions & 5 deletions src/common/snippets/src/pass/split_dimension_m.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
#include "snippets/utils/utils.hpp"

namespace {
size_t get_dim_M(const ov::Shape& shape) {
return *(shape.rbegin() + 1);
}
bool is_prime_number(size_t value) {
if (ov::snippets::utils::one_of(value, 2lu, 3lu)) return true;
if (value == 1 || value % 2 == 0 || value % 3 == 0) return false;
Expand All @@ -28,6 +25,7 @@ namespace snippets {
namespace pass {

const size_t SplitDimensionM::min_kernel_m = 32;
const size_t SplitDimensionM::dim_M_index = 1;

bool SplitDimensionM::is_supported_matmul(const std::shared_ptr<const ov::Node>& node) {
const auto matmul = ov::as_type_ptr<const ov::op::v0::MatMul>(node);
Expand Down Expand Up @@ -113,13 +111,15 @@ std::vector<size_t> SplitDimensionM::get_updated_order(const std::vector<size_t>
}

ov::snippets::VectorDims SplitDimensionM::reshape_m_dim(ov::snippets::VectorDims shape, size_t m_index, size_t batch_m_dim, size_t new_m_dim) {
OPENVINO_ASSERT(m_index < shape.size(), "Incorrect M index: it should be less than target shape rank");
if (shape[m_index] == 1)
return unsqueeze_m_dim(std::move(shape), m_index);
shape[m_index] = new_m_dim;
shape.insert(shape.begin() + m_index, batch_m_dim);
return shape;
}
ov::snippets::VectorDims SplitDimensionM::unsqueeze_m_dim(ov::snippets::VectorDims shape, size_t m_index) {
OPENVINO_ASSERT(m_index < shape.size(), "Incorrect M index: it should be less than target shape rank");
shape.insert(shape.begin() + m_index, 1);
return shape;
}
Expand Down Expand Up @@ -194,6 +194,7 @@ void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subg
};

auto get_updated_shape = [&](const ov::snippets::VectorDims& shape, size_t m_index, bool split_m_dim) {
OPENVINO_ASSERT(m_index < shape.size(), "Dimension index must be less than shape rank");
const auto current_m_dim = shape[m_index];
OPENVINO_ASSERT(!split_m_dim || current_m_dim == 1 || current_m_dim == m_dim, "Incorrect shape for splitting!");
const auto new_shape = split_m_dim ? reshape_m_dim(shape, m_index, batch_m_dim, new_m_dim) : unsqueeze_m_dim(shape, m_index);
Expand All @@ -205,7 +206,8 @@ void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subg
const auto order_constant = ov::as_type_ptr<ov::op::v0::Constant>(transpose->get_input_node_shared_ptr(1));
OPENVINO_ASSERT(order_constant != nullptr, "Transpose must have Constant order");
const auto order = order_constant->cast_vector<size_t>();
const auto m_index = is_input ? order[order.size() - 2] : order.size() - 2; // Index of M dimension in the previous order
const auto forward_index = order.size() - 1 - dim_M_index;
const auto m_index = is_input ? order[forward_index] : forward_index; // Index of M dimension in the previous order
const auto new_order = get_updated_order(order, m_index);
transpose->set_argument(1, std::make_shared<ov::op::v0::Constant>(order_constant->get_element_type(), ov::Shape{new_order.size()}, new_order));
return m_index;
Expand All @@ -217,9 +219,13 @@ void SplitDimensionM::reshape_subgraph(const std::shared_ptr<op::Subgraph>& subg
return;

const auto shape = param->get_partial_shape().get_shape();
// if the index of dimension M is equal or greater than Shape rank, no need to reshape it.
if (shape.size() <= dim_M_index)
return;

const auto consumers = param->get_output_target_inputs(0);
const auto shared_consumer = consumers.begin()->get_node()->shared_from_this();
auto m_index = shape.size() - 2;
auto m_index = shape.size() - 1 - dim_M_index;
if (ov::is_type<ov::op::v1::Transpose>(shared_consumer)) {
m_index = reshape_transpose(shared_consumer, true);
}
Expand Down
10 changes: 10 additions & 0 deletions src/common/snippets/tests/src/pass/mha_tokenization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,16 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM) {
run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM_ScalarParams) {
const auto& f = MHASelectSplitMFunction(std::vector<PartialShape>{{8, 512, 18}, {8, 18, 64}, {1}, {64}, {8, 64, 512}},
std::vector<Shape>{{8, 2, 256, 18}, {8, 1, 18, 64}, {}, {},
{8, 1, 64, 512}, {8, 512, 512}});
model = f.getOriginal();
model_ref = f.getReference();
config.set_concurrency(16);
run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Reshape_extraction) {
const auto& f = MHAWithExtractedReshapeFunction(std::vector<PartialShape>{{400, 196, 80},
{400, 80, 196},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,10 @@ std::shared_ptr<ov::Model> MHASelectSplitMFunction::initReference() const {
auto param2 = std::make_shared<ov::opset1::Parameter>(precision, input_shapes[4]);
ov::ParameterVector ngraphParam = {param0, param1, addParam, selectParam, param2};

auto make_reshape = [](const std::shared_ptr<ov::Node>& node, const ov::Shape& new_shape) {
auto make_reshape = [](const std::shared_ptr<ov::Node>& node, const ov::Shape& new_shape) -> std::shared_ptr<ov::Node> {
if (new_shape.empty()) {
return node;
}
auto shape_const = ov::op::v0::Constant::create(ov::element::i32, {new_shape.size()}, new_shape);
return std::make_shared<ov::op::v1::Reshape>(node, shape_const, true);
};
Expand Down
Loading