Commit
Fixed gemm layer
Alexander Lyulkov committed May 8, 2024
1 parent 7b25f39 commit c2b8084
Showing 1 changed file with 28 additions and 37 deletions.
65 changes: 28 additions & 37 deletions modules/dnn/src/layers/gemm_layer.cpp
@@ -289,62 +289,53 @@ class GemmLayerImpl CV_FINAL : public GemmLayer {
     virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
                                         const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE
     {
-        auto nodeA = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
-        std::shared_ptr<ov::Node> nodeAB;
+        ov::Output<ov::Node> nodeA = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
+        ov::Output<ov::Node> nodeB;
         if (const_B)
-            nodeAB = std::make_shared<ov::op::v0::MatMul>(
-                nodeA,
-                std::make_shared<ov::op::v0::Constant>(ov::element::f32, getShape(blobs[0]), blobs[0].data),
-                trans_a,
-                trans_b);
+            nodeB = std::make_shared<ov::op::v0::Constant>(ov::element::f32, getShape(blobs[0]), blobs[0].data);
         else
-            nodeAB = std::make_shared<ov::op::v0::MatMul>(
-                nodeA,
-                nodes[1].dynamicCast<InfEngineNgraphNode>()->node,
-                trans_a,
-                trans_b);
+            nodeB = nodes[1].dynamicCast<InfEngineNgraphNode>()->node;
+
+        int flatten_axis = nodeA.get_shape().size() - nodeB.get_shape().size();
+        if (flatten_axis > 0) {
+            std::vector<int> shape(1 + flatten_axis, 0);
+            shape[shape.size() - 1] = -1;
+            nodeA = std::make_shared<ov::op::v1::Reshape>(
+                nodeA,
+                std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{shape.size()}, shape.data()),
+                true);
+        }
 
+        std::shared_ptr<ov::Node> nodeAB = std::make_shared<ov::op::v0::MatMul>(nodeA, nodeB, trans_a, trans_b);
         if (alpha != 1.0f)
         {
             nodeAB = std::make_shared<ov::op::v1::Multiply>(
                 nodeAB,
-                std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1}, &alpha)
-            );
+                std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1}, &alpha));
         }
 
         if (!have_bias)
             return Ptr<BackendNode>(new InfEngineNgraphNode(nodeAB));
 
-        std::shared_ptr<ov::Node> nodeGemm;
-        if (beta != 1.0f)
+        ov::Output<ov::Node> nodeC;
+        if (const_C)
         {
-            std::shared_ptr<ov::Node> nodeC;
-            if (const_C)
-                nodeC = std::make_shared<ov::op::v1::Multiply>(
-                    std::make_shared<ov::op::v0::Constant>(ov::element::f32, getShape(blobs.back()), blobs.back().data),
-                    std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1}, &beta));
-            else
-                nodeC = std::make_shared<ov::op::v1::Multiply>(
-                    nodes.back().dynamicCast<InfEngineNgraphNode>()->node,
-                    std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1}, &beta));
-
-            nodeGemm = std::make_shared<ov::op::v1::Add>(nodeAB, nodeC, ov::op::AutoBroadcastType::NUMPY);
+            auto shape_C = blobs.back().total() == blobs.back().size[0] ? ov::Shape{blobs.back().total()} : getShape(blobs.back());
+            nodeC = std::make_shared<ov::op::v0::Constant>(ov::element::f32, shape_C, blobs.back().data);
         }
         else
         {
-            if (const_C)
-                nodeGemm = std::make_shared<ov::op::v1::Add>(
-                    nodeAB,
-                    std::make_shared<ov::op::v0::Constant>(ov::element::f32, getShape(blobs.back()), blobs.back().data),
-                    ov::op::AutoBroadcastType::NUMPY);
-            else
-                nodeGemm = std::make_shared<ov::op::v1::Add>(
-                    nodeAB,
-                    nodes.back().dynamicCast<InfEngineNgraphNode>()->node,
-                    ov::op::AutoBroadcastType::NUMPY);
+            nodeC = nodes.back().dynamicCast<InfEngineNgraphNode>()->node;
         }
 
+        if (beta != 1.0f)
+        {
+            nodeC = std::make_shared<ov::op::v1::Multiply>(
+                nodeC,
+                std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1}, &beta));
+        }
+
+        auto nodeGemm = std::make_shared<ov::op::v1::Add>(nodeAB, nodeC, ov::op::AutoBroadcastType::NUMPY);
         return Ptr<BackendNode>(new InfEngineNgraphNode(nodeGemm));
     }
 #endif // HAVE_DNN_NGRAPH
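The core of the fix is the new rank-alignment step: when input A has more dimensions than B, the extra trailing dimensions of A are flattened into one before the MatMul. The target shape handed to ov::op::v1::Reshape relies on the operator's special_zero mode (third argument true): a 0 entry keeps the input dimension at that index, and a single -1 entry is inferred from the remaining element count. Below is a minimal standalone C++ sketch of that shape arithmetic only; the helper name reshapeSpecialZero is illustrative, not an OpenVINO API.

#include <cstdio>
#include <vector>

// Output shape of ov::op::v1::Reshape with special_zero = true: a 0 in the
// target copies the input dimension at the same index, and a single -1 is
// inferred from the remaining element count.
std::vector<long> reshapeSpecialZero(const std::vector<long>& in,
                                     const std::vector<int>& target)
{
    std::vector<long> out(target.size(), 0);
    long known = 1;
    int inferred = -1;
    for (size_t i = 0; i < target.size(); i++)
    {
        if (target[i] == 0)
            out[i] = in[i];                // special_zero: keep input dim
        else if (target[i] == -1)
        {
            inferred = (int)i;             // fill in after the loop
            continue;
        }
        else
            out[i] = target[i];
        known *= out[i];
    }
    long total = 1;
    for (long d : in)
        total *= d;
    if (inferred >= 0)
        out[inferred] = total / known;     // infer the -1 dimension
    return out;
}

int main()
{
    // A is rank 3, B is rank 2, so flatten_axis = 3 - 2 = 1 and the layer
    // builds a target of 1 + flatten_axis entries, all 0 except a trailing -1.
    std::vector<long> A = {2, 3, 4};
    std::vector<int> target = {0, -1};
    for (long d : reshapeSpecialZero(A, target))
        std::printf("%ld ", d);            // prints: 2 12
    std::printf("\n");
}

So for an A of shape {2, 3, 4} multiplied against a rank-2 B, the layer reshapes A to {2, 12} and lets MatMul proceed on compatible ranks.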
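The bias path changes too: a constant C whose elements all sit in its first dimension (blobs.back().total() == blobs.back().size[0], e.g. an N×1 Mat) now gets the one-dimensional shape {N}, so the final Add with AutoBroadcastType::NUMPY aligns it against the trailing axis of the MatMul output; the beta scaling is then applied uniformly to either form of C. A small sketch of the right-aligned NumPy broadcasting rule this relies on; the broadcast helper is illustrative, not an OpenVINO call.

#include <algorithm>
#include <cstdio>
#include <vector>

// NumPy-style broadcasting: align shapes from the right; at each axis the
// dimensions must match or one of them must be 1. Returns {} if incompatible.
std::vector<long> broadcast(std::vector<long> a, std::vector<long> b)
{
    size_t n = std::max(a.size(), b.size());
    a.insert(a.begin(), n - a.size(), 1);  // left-pad the shorter shape with 1s
    b.insert(b.begin(), n - b.size(), 1);
    std::vector<long> out(n);
    for (size_t i = 0; i < n; i++)
    {
        if (a[i] != b[i] && a[i] != 1 && b[i] != 1)
            return {};                     // incompatible dimensions
        out[i] = std::max(a[i], b[i]);
    }
    return out;
}

int main()
{
    // MatMul output {batch, M, N} = {2, 4, 8}; the normalized 1-D bias {8}
    // broadcasts along the last axis:
    std::vector<long> ok = broadcast({2, 4, 8}, {8});     // -> {2, 4, 8}
    // A bias left as {8, 1} would align its 8 with M = 4 and fail:
    std::vector<long> bad = broadcast({2, 4, 8}, {8, 1}); // -> {}
    std::printf("ok rank: %zu, bad rank: %zu\n", ok.size(), bad.size());
}

Kept as {8, 1}, the same bias would pair its 8 against M = 4 and fail to broadcast, which appears to be the kind of shape mismatch the 1-D normalization avoids.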
