Skip to content

Commit

Permalink
[GPU] Improve swiglu fusion pass (#24537)
Browse files Browse the repository at this point in the history
### Details:
 - Allow reversed inputs order to last multiply op for swiglu pattern
  • Loading branch information
vladimir-paramuzov committed May 17, 2024
1 parent 88b0309 commit 7ffea61
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ SwiGLUFusion::SwiGLUFusion() {
auto mul = std::dynamic_pointer_cast<ov::op::v1::Multiply>(pattern_map.at(mul_m).get_node_shared_ptr());
if (!mul || transformation_callback(mul))
return false;
if (mul->input_value(1).get_index() != 1)

size_t split_in_idx = ov::is_type<ov::op::v4::Swish>(mul->get_input_node_shared_ptr(0)) ? 1 : 0;
if (mul->input_value(split_in_idx).get_index() != 1)
return false;

auto variadic_split = std::dynamic_pointer_cast<ov::op::v1::VariadicSplit>(pattern_map.at(variadic_split_m).get_node_shared_ptr());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,28 @@ TEST_F(TransformationTestsF, SwiGLUFusionTest3) {
}
}

TEST_F(TransformationTestsF, SwiGLUFusionTest3ReverseOrder) {
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{ -1, -1, 6 });
auto axis_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1});
auto split_lengths_const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, {3, -1});
auto variadic_split = std::make_shared<ov::op::v1::VariadicSplit>(input, axis_const, split_lengths_const);
auto swish = std::make_shared<ov::op::v4::Swish>(variadic_split->output(0));
auto mul = std::make_shared<ov::op::v1::Multiply>(variadic_split->output(1), swish);

model = std::make_shared<ov::Model>(ov::NodeVector{mul}, ov::ParameterVector{input});
manager.register_pass<SwiGLUFusion>();
}
{
int64_t axis = -1;
int64_t split_lenghts = 3;
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{ -1, -1, 6 });
auto swiglu = std::make_shared<op::SwiGLU>(input, axis, split_lenghts, ov::element::f16);

model_ref = std::make_shared<ov::Model>(ov::NodeVector{swiglu}, ov::ParameterVector{input});
}
}

TEST_F(TransformationTestsF, SwiGLUFusionTest4) {
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{ -1, -1, 6 });
Expand Down

0 comments on commit 7ffea61

Please sign in to comment.