Skip to content

Commit

Permalink
Merge pull request #25458 from alexlyulkov:al/dnn-openvino-int-support
Browse files Browse the repository at this point in the history
Added int support for OpenVINO dnn backend #25458

Modified dnn OpenVINO integration to support type inference and int operations.

Added OpenVINO support to Cast, CumSum, Expand, Gather, GatherElements, Scatter, ScatterND, Tile layers.
I tried to add Reduce layer, but looks like OpenVINO uses float values inside Reduce operation so it can't pass our int tests.

OpenVINO uses int32 precision for int64 operations, so I've modified input values for int64 tests when backend is OpenVINO.

OpenVINO has a strange behavior with custom layers and int64 values. After model compilation OpenVINO may change types, so the model can have different output type. That's why these tests were disabled:
- Test_ArgMax_Int.random/0, where GetParam() = (4, NGRAPH/CPU)
- Test_ArgMax_Int.random/6, where GetParam() = (11, NGRAPH/CPU)
- Test_Reduce_Int.random/6, where GetParam() = (11, NGRAPH/CPU)
- Test_Reduce_Int.two_axes/6, where GetParam() = (11, NGRAPH/CPU)

Also these tests were temporary disabled, they didn't work on both 4.x and 5.x branches:
- Test_Caffe_layers.layer_prelu_fc/0, where GetParam() = NGRAPH/CPU
- Test_ONNX_layers.LSTM_Activations/0, where GetParam() = NGRAPH/CPU
- Test_ONNX_layers.Quantized_Convolution/0, where GetParam() = NGRAPH/CPU
- Test_ONNX_layers.Quantized_Eltwise_Scalar/0, where GetParam() = NGRAPH/CPU
- Test_TFLite.EfficientDet_int8/0, where GetParam() = NGRAPH/CPU


### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [ ] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
  • Loading branch information
alexlyulkov committed May 15, 2024
1 parent 5bdc419 commit 6af0394
Show file tree
Hide file tree
Showing 22 changed files with 374 additions and 50 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/PR-5.x.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ jobs:
Ubuntu2004-x64:
uses: opencv/ci-gha-workflow/.github/workflows/OCV-PR-5.x-U20.yaml@main

Ubuntu2004-x64-OpenVINO:
uses: opencv/ci-gha-workflow/.github/workflows/OCV-PR-5.x-U20-OpenVINO.yaml@main

Ubuntu2204-x64:
uses: opencv/ci-gha-workflow/.github/workflows/OCV-PR-5.x-U22.yaml@main

Expand Down
71 changes: 59 additions & 12 deletions modules/dnn/src/ie_ngraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,48 @@ ngraphWrappers(const std::vector<Ptr<BackendWrapper> >& ptrs)
return wrappers;
}

ov::element::Type cvTypeToOvType(MatType cvType)
{
switch (cvType) {
case CV_32F:
return ov::element::f32;
case CV_8U:
return ov::element::u8;
case CV_8S:
return ov::element::i8;
case CV_32S:
return ov::element::i32;
case CV_64S:
return ov::element::i64;
default:
CV_Error(Error::StsNotImplemented, format("Unsupported data type %s", typeToString(cvType).c_str()));
}
}

ov::element::Type cvTypeToOvType(const cv::Mat& mat)
{
return cvTypeToOvType(mat.depth());
}

MatType ovTypeToCvType(ov::element::Type ovType)
{
switch (ovType) {
case ov::element::f32:
return CV_32F;
case ov::element::u8:
return CV_8U;
case ov::element::i8:
return CV_8S;
case ov::element::i32:
return CV_32S;
case ov::element::i64:
return CV_64S;
default:
CV_Error(Error::StsNotImplemented, format("Unsupported data type %s", ovType.get_type_name().c_str()));
}
}


class NgraphCustomOp: public ov::op::Op {
public:
OPENVINO_OP(kOpenCVLayersType);
Expand All @@ -60,14 +102,27 @@ class NgraphCustomOp: public ov::op::Op {

void validate_and_infer_types() override
{
std::vector<MatType> inputTypes(get_input_size());
std::vector<MatType> internalTypes;
std::vector<MatType> outputTypes;
for (int i = 0; i < get_input_size(); ++i)
{
inputTypes[i] = ovTypeToCvType(get_input_element_type(i));
}
cvLayer->getTypes(inputTypes, outputs.size(), internals.size(), outputTypes, internalTypes);
for (int i = 0; i < internals.size(); ++i) {
if (internals[i].depth() != internalTypes[i])
internals[i] = cv::Mat(shape(internals[i]), internalTypes[i]);
}

set_output_size(outputs.size());
for (int i = 0; i < outputs.size(); ++i)
{
ov::PartialShape shape;
for (int j = 0; j < outputs[i].dims; ++j) {
shape.push_back(outputs[i].size[j]);
}
set_output_type(i, get_input_element_type(0), shape);
set_output_type(i, cvTypeToOvType(outputTypes[i]), shape);
}
}

Expand Down Expand Up @@ -270,7 +325,7 @@ ov::ParameterVector InfEngineNgraphNet::setInputs(const std::vector<cv::Mat>& in
for (size_t i = 0; i < inputs.size(); i++)
{
std::vector<size_t> shape = getShape<size_t>(inputs[i]);
auto inp = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape(shape));
auto inp = std::make_shared<ov::op::v0::Parameter>(cvTypeToOvType(inputs[i]), ov::Shape(shape));
inp->set_friendly_name(names[i]);

auto it = std::find_if(inputs_vec.begin(), inputs_vec.end(),
Expand Down Expand Up @@ -427,16 +482,7 @@ void NgraphBackendLayer::forward(InputArrayOfArrays inputs, OutputArrayOfArrays

ov::Tensor wrapToNgraphBlob(const Mat& m) {
std::vector<size_t> shape = getShape<size_t>(m);
if (m.type() == CV_32F)
return ov::Tensor(ov::element::f32, shape, m.data);
else if (m.type() == CV_8U)
return ov::Tensor(ov::element::u8, shape, m.data);
else if (m.type() == CV_8SC1)
return ov::Tensor(ov::element::i8, shape, m.data);
else if (m.type() == CV_32SC1)
return ov::Tensor(ov::element::i32, shape, m.data);
else
CV_Error(Error::StsNotImplemented, format("Unsupported data type %s", typeToString(m.type()).c_str()));
return ov::Tensor(cvTypeToOvType(m), shape, m.data);
}


Expand All @@ -445,6 +491,7 @@ NgraphBackendWrapper::NgraphBackendWrapper(int targetId, const cv::Mat& m)
, host((Mat*)&m)
{
blob = wrapToNgraphBlob(m);
hostMatDepth = m.depth();
}

NgraphBackendWrapper::NgraphBackendWrapper(Ptr<BackendWrapper> wrapper)
Expand Down
4 changes: 4 additions & 0 deletions modules/dnn/src/ie_ngraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ namespace cv { namespace dnn {

#ifdef HAVE_DNN_NGRAPH

ov::element::Type cvTypeToOvType(MatType cvType);
ov::element::Type cvTypeToOvType(const cv::Mat& mat);
MatType ovTypeToCvType(ov::element::Type ovType);

class InfEngineNgraphNode;

class InfEngineNgraphNet
Expand Down
9 changes: 5 additions & 4 deletions modules/dnn/src/int8layers/convolution_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ class ConvolutionLayerInt8Impl CV_FINAL : public BaseConvolutionLayerInt8Impl
CV_Assert(!blobs.empty());
CV_Assert_N(inputs.size() >= 1, nodes.size() >= 1);
CV_CheckTypeEQ(weightsMat.type(), CV_8S, "");
CV_CheckTypeEQ(blobs[0].type(), CV_8S, "");
auto ieInpNode = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
std::vector<size_t> dims = ieInpNode.get_shape();
CV_Check(dims.size(), dims.size() >= 3 && dims.size() <= 5, "");
Expand All @@ -581,7 +582,7 @@ class ConvolutionLayerInt8Impl CV_FINAL : public BaseConvolutionLayerInt8Impl
const int inpGroupCn = nodes.size() > 1 ? ieWeights.get_shape()[1] : blobs[0].size[1];
const int group = inpCn / inpGroupCn;

std::vector<size_t> kernel_shape;
std::vector<int64_t> kernel_shape;
if (group != 1)
{
kernel_shape.push_back(group);
Expand All @@ -592,7 +593,7 @@ class ConvolutionLayerInt8Impl CV_FINAL : public BaseConvolutionLayerInt8Impl

if (nodes.size() == 1)
{
ieWeights = std::make_shared<ov::op::v0::Constant>(ov::element::i8, kernel_shape, blobs[0].data);
ieWeights = std::make_shared<ov::op::v0::Constant>(ov::element::i8, ov::Shape(kernel_shape.begin(), kernel_shape.end()), blobs[0].data);
}
else
{
Expand Down Expand Up @@ -655,7 +656,7 @@ class ConvolutionLayerInt8Impl CV_FINAL : public BaseConvolutionLayerInt8Impl
pad_type);
}

std::vector<size_t> shape(conv_node.get_shape().size(), 1);
std::vector<int64_t> shape(conv_node.get_shape().size(), 1);
shape[1] = conv_node.get_shape()[1];
if (biasvec.size() || nodes.size() == 3)
{
Expand All @@ -672,7 +673,7 @@ class ConvolutionLayerInt8Impl CV_FINAL : public BaseConvolutionLayerInt8Impl
for (int i = 0; i < numOutput; ++i) {
ovBias[i] = (biasvec[i] + input_zp * cv::sum(blobs[0].row(i))[0]) * outputMultiplier[i] * output_sc;
}
bias = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape(shape), ovBias.data());
bias = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape(shape.begin(), shape.end()), ovBias.data());
}
conv_node = std::make_shared<ov::op::v1::Add>(conv_node, bias, ov::op::AutoBroadcastType::NUMPY);
}
Expand Down
3 changes: 2 additions & 1 deletion modules/dnn/src/layer_internals.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ struct DataLayer : public Layer

virtual bool supportBackend(int backendId) CV_OVERRIDE
{
return backendId == DNN_BACKEND_OPENCV;
return backendId == DNN_BACKEND_OPENCV ||
backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;
}

void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
Expand Down
3 changes: 1 addition & 2 deletions modules/dnn/src/layers/blank_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ class BlankLayerImpl CV_FINAL : public BlankLayer
const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE
{
auto ieInpNode = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
ov::OutputVector inp{ieInpNode};
auto blank = std::make_shared<ov::op::v0::Concat>(inp, 0);
auto blank = std::make_shared<ov::op::v1::ConvertLike>(ieInpNode, ieInpNode);
return Ptr<BackendNode>(new InfEngineNgraphNode(blank));
}
#endif // HAVE_DNN_NGRAPH
Expand Down
14 changes: 13 additions & 1 deletion modules/dnn/src/layers/cast_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include "../op_inf_engine.hpp"
#include "../ie_ngraph.hpp"
#include "layers_common.hpp"


Expand All @@ -19,7 +21,8 @@ class CastLayerImpl CV_FINAL : public CastLayer

virtual bool supportBackend(int backendId) CV_OVERRIDE
{
return backendId == DNN_BACKEND_OPENCV;
return backendId == DNN_BACKEND_OPENCV ||
backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;
}

virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
Expand Down Expand Up @@ -83,6 +86,15 @@ class CastLayerImpl CV_FINAL : public CastLayer
inputs[0].convertTo(outputs[0], outputs[0].depth());
}

#ifdef HAVE_DNN_NGRAPH
virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE
{
auto cast = std::make_shared<ov::op::v0::Convert>(nodes[0].dynamicCast<InfEngineNgraphNode>()->node, cvTypeToOvType(outputType));
return Ptr<BackendNode>(new InfEngineNgraphNode(cast));
}
#endif // HAVE_DNN_NGRAPH

private:
int outputType;
};
Expand Down
15 changes: 1 addition & 14 deletions modules/dnn/src/layers/const_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,23 +142,10 @@ class ConstLayerImpl CV_FINAL : public ConstLayer
virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE
{
ov::element::Type dType;
if (blobs[0].depth() == CV_32F) {
dType = ov::element::f32;
} else if (blobs[0].depth() == CV_32S) {
dType = ov::element::i32;
} else if (blobs[0].depth() == CV_8S) {
dType = ov::element::i8;
} else {
CV_Error(Error::StsNotImplemented, format("Unexpected Const data depth: %d", blobs[0].depth()));
}
std::shared_ptr<ov::Node> node =
std::make_shared<ov::op::v0::Constant>(dType,
std::make_shared<ov::op::v0::Constant>(cvTypeToOvType(blobs[0]),
getShape<size_t>(blobs[0]),
blobs[0].data);
if (node->get_element_type() != ov::element::f32) {
node = std::make_shared<ov::op::v0::Convert>(node, ov::element::f32);
}
return Ptr<BackendNode>(new InfEngineNgraphNode(node));
}
#endif // HAVE_DNN_NGRAPH
Expand Down
38 changes: 38 additions & 0 deletions modules/dnn/src/layers/cumsum_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include "../op_inf_engine.hpp"
#include "../ie_ngraph.hpp"
#include "layers_common.hpp"

#include <opencv2/dnn/shape_utils.hpp>
Expand All @@ -23,6 +25,12 @@ class CumSumLayerImpl CV_FINAL : public CumSumLayer
setParamsFrom(params);
}

virtual bool supportBackend(int backendId) CV_OVERRIDE
{
return backendId == DNN_BACKEND_OPENCV ||
backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;
}

bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
Expand Down Expand Up @@ -151,6 +159,36 @@ class CumSumLayerImpl CV_FINAL : public CumSumLayer
}
}

#ifdef HAVE_DNN_NGRAPH
virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE
{
std::shared_ptr<ov::op::v0::CumSum> cumsum;
if (nodes.size() == 2)
{
int32_t axis_shape = 1;
auto axis_scalar = std::make_shared<ov::op::v1::Reshape>(
nodes[1].dynamicCast<InfEngineNgraphNode>()->node,
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, &axis_shape),
false);
cumsum = std::make_shared<ov::op::v0::CumSum>(
nodes[0].dynamicCast<InfEngineNgraphNode>()->node,
std::make_shared<ov::op::v0::Convert>(axis_scalar, ov::element::i32),
exclusive_raw,
reverse_raw);
}
else
{
cumsum = std::make_shared<ov::op::v0::CumSum>(
nodes[0].dynamicCast<InfEngineNgraphNode>()->node,
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{}, &axis_raw),
exclusive_raw,
reverse_raw);
}
return Ptr<BackendNode>(new InfEngineNgraphNode(cumsum));
}
#endif // HAVE_DNN_NGRAPH

int axis_raw;
int exclusive_raw;
int reverse_raw;
Expand Down
27 changes: 25 additions & 2 deletions modules/dnn/src/layers/expand_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include "../op_inf_engine.hpp"
#include "../ie_ngraph.hpp"
#include <opencv2/dnn/shape_utils.hpp>

namespace cv { namespace dnn {
Expand All @@ -27,8 +29,10 @@ class ExpandLayerImpl CV_FINAL : public ExpandLayer
const_input_1d = params.get("const_input_1d", false);
}

virtual bool supportBackend(int backendId) CV_OVERRIDE {
return backendId == DNN_BACKEND_OPENCV;
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
return backendId == DNN_BACKEND_OPENCV ||
backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;
}

virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
Expand Down Expand Up @@ -145,6 +149,25 @@ class ExpandLayerImpl CV_FINAL : public ExpandLayer
}
}

#ifdef HAVE_DNN_NGRAPH
virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE
{
auto input_shape = nodes[0].dynamicCast<InfEngineNgraphNode>()->node.get_shape();
CV_CheckGE(target_shape.size(), input_shape.size(), "");

std::vector<int32_t> output_shape(target_shape.begin(), target_shape.end());
for (int i = 1; i < input_shape.size() + 1; ++i)
output_shape[output_shape.size() - i] = std::max(
(int32_t)input_shape[input_shape.size() - i],
output_shape[output_shape.size() - i]);

auto shape_node = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{output_shape.size()}, output_shape.data());
auto expand = std::make_shared<ov::op::v3::Broadcast>(nodes[0].dynamicCast<InfEngineNgraphNode>()->node, shape_node);
return Ptr<BackendNode>(new InfEngineNgraphNode(expand));
}
#endif // HAVE_DNN_NGRAPH

private:
MatShape target_shape;
bool const_input_1d;
Expand Down

0 comments on commit 6af0394

Please sign in to comment.