From c2b8084ec5df903e873f7a0eb6567168f46a32ac Mon Sep 17 00:00:00 2001 From: Alexander Lyulkov Date: Wed, 8 May 2024 16:50:50 +0300 Subject: [PATCH] Fixed gemm layer --- modules/dnn/src/layers/gemm_layer.cpp | 65 ++++++++++++--------------- 1 file changed, 28 insertions(+), 37 deletions(-) diff --git a/modules/dnn/src/layers/gemm_layer.cpp b/modules/dnn/src/layers/gemm_layer.cpp index f770e4e8bda1..ac0914c2c210 100644 --- a/modules/dnn/src/layers/gemm_layer.cpp +++ b/modules/dnn/src/layers/gemm_layer.cpp @@ -289,62 +289,53 @@ class GemmLayerImpl CV_FINAL : public GemmLayer { virtual Ptr initNgraph(const std::vector >& inputs, const std::vector >& nodes) CV_OVERRIDE { - auto nodeA = nodes[0].dynamicCast()->node; - std::shared_ptr nodeAB; - + ov::Output nodeA = nodes[0].dynamicCast()->node; + ov::Output nodeB; if (const_B) - nodeAB = std::make_shared( - nodeA, - std::make_shared(ov::element::f32, getShape(blobs[0]), blobs[0].data), - trans_a, - trans_b); + nodeB = std::make_shared(ov::element::f32, getShape(blobs[0]), blobs[0].data); else - nodeAB = std::make_shared( + nodeB = nodes[1].dynamicCast()->node; + + int flatten_axis = nodeA.get_shape().size() - nodeB.get_shape().size(); + if (flatten_axis > 0) { + std::vector shape(1 + flatten_axis, 0); + shape[shape.size() - 1] = -1; + nodeA = std::make_shared( nodeA, - nodes[1].dynamicCast()->node, - trans_a, - trans_b); + std::make_shared(ov::element::i32, ov::Shape{shape.size()}, shape.data()), + true); + } + std::shared_ptr nodeAB = std::make_shared(nodeA, nodeB, trans_a, trans_b); if (alpha != 1.0f) { nodeAB = std::make_shared( nodeAB, - std::make_shared(ov::element::f32, ov::Shape{1}, &alpha) - ); + std::make_shared(ov::element::f32, ov::Shape{1}, &alpha)); } if (!have_bias) return Ptr(new InfEngineNgraphNode(nodeAB)); - std::shared_ptr nodeGemm; - if (beta != 1.0f) + ov::Output nodeC; + if (const_C) { - std::shared_ptr nodeC; - if (const_C) - nodeC = std::make_shared( - std::make_shared(ov::element::f32, getShape(blobs.back()), blobs.back().data), - std::make_shared(ov::element::f32, ov::Shape{1}, &beta)); - else - nodeC = std::make_shared( - nodes.back().dynamicCast()->node, - std::make_shared(ov::element::f32, ov::Shape{1}, &beta)); - - nodeGemm = std::make_shared(nodeAB, nodeC, ov::op::AutoBroadcastType::NUMPY); + auto shape_C = blobs.back().total() == blobs.back().size[0] ? ov::Shape{blobs.back().total()} : getShape(blobs.back()); + nodeC = std::make_shared(ov::element::f32, shape_C, blobs.back().data); } else { - if (const_C) - nodeGemm = std::make_shared( - nodeAB, - std::make_shared(ov::element::f32, getShape(blobs.back()), blobs.back().data), - ov::op::AutoBroadcastType::NUMPY); - else - nodeGemm = std::make_shared( - nodeAB, - nodes.back().dynamicCast()->node, - ov::op::AutoBroadcastType::NUMPY); + nodeC = nodes.back().dynamicCast()->node; + } + + if (beta != 1.0f) + { + nodeC = std::make_shared( + nodeC, + std::make_shared(ov::element::f32, ov::Shape{1}, &beta)); } + auto nodeGemm = std::make_shared(nodeAB, nodeC, ov::op::AutoBroadcastType::NUMPY); return Ptr(new InfEngineNgraphNode(nodeGemm)); } #endif // HAVE_DNN_NGRAPH