dnn: expand refactor with cv::broadcast for onnx models (opencv#24295)
* add expand impl with cv::broadcast

* remove expandMid

* deduce shape from -1

* add constant folding

* handle input constant; handle input constant 1d

* add expand conformance tests; add checks to disallow shape of neg values; add early copy for unchanged total elements

* fix ExpandSubgraph

* dummy commit to trigger build

* dummy commit to trigger build 1

* remove conformance from test names
fengyuentau authored and hanliutong committed Oct 7, 2023
1 parent 08bb92f commit 28fefc3
Showing 6 changed files with 342 additions and 153 deletions.
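For context: cv::broadcast is the primitive this refactor delegates to. A minimal standalone sketch of its behavior (assuming OpenCV >= 4.9, where cv::broadcast is available; names are illustrative):

#include <opencv2/core.hpp>
#include <iostream>
#include <vector>

int main() {
// Expand a 1x3 row to 4x3: every output row becomes a copy of src.
cv::Mat src = (cv::Mat_<float>(1, 3) << 1.f, 2.f, 3.f);
std::vector<int> target_shape = {4, 3};
cv::Mat dst;
cv::broadcast(src, target_shape, dst); // numpy-style broadcasting rules
std::cout << dst << std::endl; // [1, 2, 3; 1, 2, 3; 1, 2, 3; 1, 2, 3]
return 0;
}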
6 changes: 6 additions & 0 deletions modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -1144,6 +1144,12 @@ CV__DNN_INLINE_NS_BEGIN
static Ptr<GemmLayer> create(const LayerParams& params);
};

class CV_EXPORTS ExpandLayer : public Layer
{
public:
static Ptr<ExpandLayer> create(const LayerParams &params);
};

//! @}
//! @}
CV__DNN_INLINE_NS_END
1 change: 1 addition & 0 deletions modules/dnn/src/init.cpp
@@ -158,6 +158,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(Reciprocal, ReciprocalLayer);
CV_DNN_REGISTER_LAYER_CLASS(Gather, GatherLayer);
CV_DNN_REGISTER_LAYER_CLASS(LayerNormalization, LayerNormLayer);
CV_DNN_REGISTER_LAYER_CLASS(Expand, ExpandLayer);

CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);
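Once registered, the new layer is reachable through the layer factory. A hypothetical standalone instantiation (sketch only; the name and shape values are illustrative, not from this commit):

#include <opencv2/dnn.hpp>

cv::Ptr<cv::dnn::Layer> makeExpand() {
// Expand layer that broadcasts its input to shape [2, 3, 4].
int shape_vals[] = {2, 3, 4};
cv::dnn::LayerParams lp;
lp.type = "Expand";
lp.name = "expand_example";
lp.set("shape", cv::dnn::DictValue::arrayInt(shape_vals, 3));
return cv::dnn::LayerFactory::createLayerInstance("Expand", lp);
}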
149 changes: 149 additions & 0 deletions modules/dnn/src/layers/expand_layer.cpp
@@ -0,0 +1,149 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include <opencv2/dnn/shape_utils.hpp>

namespace cv { namespace dnn {

class ExpandLayerImpl CV_FINAL : public ExpandLayer
{
public:
ExpandLayerImpl(const LayerParams &params) {
setParamsFrom(params);

// shape as param
CV_CheckTrue(params.has("shape"), "DNN/Expand: shape is required in Expand layer initialization");
DictValue param_shape = params.get("shape");
int ndims_shape = param_shape.size();
CV_CheckGT(ndims_shape, 0, "DNN/Expand: ndims of shape must be > 0");
target_shape.resize(ndims_shape);
for (int i = 0; i < ndims_shape; i++) {
target_shape[i] = param_shape.get<int>(i);
}

// FIXME: remove when 0d/1d mat is available
const_input_1d = params.get("const_input_1d", false);
}

virtual bool supportBackend(int backendId) CV_OVERRIDE {
return backendId == DNN_BACKEND_OPENCV;
}

virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
std::vector<MatShape> &internals) const CV_OVERRIDE {
CV_CheckGE(inputs.size(), static_cast<size_t>(1), "DNN/Expand: at least one input required");
CV_CheckLE(inputs.size(), static_cast<size_t>(2), "DNN/Expand: at most two inputs allowed");
CV_CheckFalse(target_shape.empty(), "DNN/Expand: shape must be known before memory is set");

MatShape input_shape = inputs[0]; // 1d tensor is represented as 2d mat, e.g. [3] -> [3, 1]
if (const_input_1d) {
input_shape = {inputs[0][0]};
}

auto& moreDimension = input_shape.size() > target_shape.size() ? input_shape : target_shape;
auto& lessDimension = input_shape.size() <= target_shape.size() ? input_shape : target_shape;

/* Example:
i = 2
    |
moreDimension: 1 2 3 4 5   -> non-aligned dimensions are copied to the output shape as-is
lessDimension:     1 1 5   -> aligned dimensions must be compatible (equal, or one of them is 1); the bigger one is kept
    |
j = 0 = i - (moreDimension.size() - lessDimension.size());
*/
MatShape outputShape(moreDimension.size(), 1);
for (int i = 0; i < moreDimension.size(); i++) {
int d = moreDimension[i];
int j = i - (moreDimension.size() - lessDimension.size());
if (j >= 0) {
if (d == 1 || lessDimension[j] == 1 || // broadcast
d == lessDimension[j]) { // plain copy
outputShape[i] = std::max(d, lessDimension[j]);
} else {
CV_Error(Error::StsBadSize, cv::format("DNN/Expand: incompatible dimensions %d and %d (must be equal, or one of them must be 1)", moreDimension[i], lessDimension[j]));
}
} else {
outputShape[i] = d;
}
}
outputs.assign(1, outputShape);
return false;
}

virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
std::vector<Mat> inputs;
inputs_arr.getMatVector(inputs);

const auto &input = inputs[0];
auto input_shape = shape(input);
if (const_input_1d) {
input_shape = {input_shape[0]};
}

auto& moreDimension = input_shape.size() > target_shape.size() ? input_shape : target_shape;
auto& lessDimension = input_shape.size() <= target_shape.size() ? input_shape : target_shape;

MatShape final_target_shape(moreDimension.size(), 1);
for (int i = 0; i < moreDimension.size(); i++) {
int d = moreDimension[i];
int j = i - (moreDimension.size() - lessDimension.size());
if (j >= 0) {
final_target_shape[i] = std::max(lessDimension[j], d);
} else {
final_target_shape[i] = d;
}
}
target_shape.clear();
target_shape = std::move(final_target_shape);
}

void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE {
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());

if (inputs_arr.depth() == CV_16S)
{
forward_fallback(inputs_arr, outputs_arr, internals_arr);
return;
}

std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);

int target_shape_total = std::accumulate(target_shape.begin(), target_shape.end(), 1, std::multiplies<int>());
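// Early copy: if the total element count is unchanged, Expand degenerates to a plain copy.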
if (target_shape_total == inputs[0].total()) {
const char *data = inputs[0].ptr<const char>();
char *output = outputs[0].ptr<char>();
int step = target_shape_total * outputs[0].elemSize();
std::memcpy(output, data, step);
return;
}

// const_input_1d: tile the 1-D input along the last target dimension; the product of the leading dimensions gives the repeat count.
if (const_input_1d) {
const char *data = inputs[0].ptr<const char>();
char *output = outputs[0].ptr<char>();
int step = target_shape.back() * outputs[0].elemSize();
int total = std::accumulate(target_shape.begin(), target_shape.end() - 1, 1, std::multiplies<int>());
for (int i = 0; i < total; i++) {
std::memcpy(output + i * step, data, step);
}
} else {
cv::broadcast(inputs[0], target_shape, outputs[0]);
}
}

private:
MatShape target_shape;
bool const_input_1d;
};

Ptr<ExpandLayer> ExpandLayer::create(const LayerParams &params) {
return makePtr<ExpandLayerImpl>(params);
}

}} // cv::dnn
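The alignment rule used by getMemoryShapes() and finalize() above is ordinary right-aligned (numpy-style) shape broadcasting. A hypothetical standalone equivalent, for illustration only:

#include <algorithm>
#include <stdexcept>
#include <vector>

std::vector<int> broadcastShapes(std::vector<int> a, std::vector<int> b) {
if (a.size() < b.size()) std::swap(a, b); // a = moreDimension, b = lessDimension
size_t offset = a.size() - b.size();
std::vector<int> out(a); // non-aligned dims are copied as-is
for (size_t j = 0; j < b.size(); j++) {
int d1 = a[offset + j], d2 = b[j];
if (d1 != d2 && d1 != 1 && d2 != 1)
throw std::invalid_argument("incompatible dimensions");
out[offset + j] = std::max(d1, d2); // aligned dims: keep the bigger one
}
return out;
}

For the example in the comment above, broadcastShapes({1, 2, 3, 4, 5}, {1, 1, 5}) returns {1, 2, 3, 4, 5}.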
132 changes: 132 additions & 0 deletions modules/dnn/src/onnx/onnx_graph_simplifier.cpp
@@ -821,6 +821,16 @@ class GatherCastSubgraph : public Subgraph
}
};

/* Constant folding shape for Expand.
Before fusion:
+--------------------------------------------------------------+ (X)
|                                                              |
ConstantOfShape[input=[4]] -> Mul[B=-1] -> Equal[A=[2, -1, -1, -1]] -> Where[Y=[2, -1, -1, -1]] -> Expand
        \                                          \
         value=[1]                                  (condition)
*/
class ExpandSubgraph : public Subgraph
{
public:
@@ -837,6 +847,128 @@ class ExpandSubgraph : public Subgraph
addNodeToMatch("Expand", input, where);
setFusedNode("Expand", input, shape);
}

static int extractValue(const Ptr<ImportGraphWrapper>& net, int node_id, int64_t &val) {
Ptr<ImportNodeWrapper> node_wrapper = net->getNode(node_id);
opencv_onnx::NodeProto* node = node_wrapper.dynamicCast<ONNXNodeWrapper>()->node;

if (node->attribute_size() == 0) {
// ConstantOfShape without a "value" attribute defaults to 0 (ONNX specification)
val = 0;
return 1;
} else if (node->attribute_size() == 1) {
opencv_onnx::AttributeProto attr = node->attribute(0);
if (attr.name() != "value") {
return 0;
}
Mat mat_value = getMatFromTensor(attr.t());
switch (mat_value.type()) {
case CV_32S: {
val = static_cast<int64_t>(mat_value.at<int>());
} break;
default: return 0;
}
return 1;
}
return 0;
}

static std::vector<int64_t> extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
Mat mat_constant;
if (initializer_id != -1) // initializer
{
mat_constant = onnx_net->getMatFromInitializer(initializer_id);
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
mat_constant = getMatFromTensor(constant_proto);
}

std::vector<int64_t> retvals{mat_constant.begin<int>(), mat_constant.end<int>()};
return retvals;
}

virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds,
std::vector<int>& targetNodesIds) CV_OVERRIDE {
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) {
int64_t value_ConstantOfShape;
if (!extractValue(net, matchedNodesIds[0], value_ConstantOfShape)) {
return false;
}
std::vector<int64_t> input_ConstantOfShape = extractConstant(net, matchedNodesIds[0], 0);
if (input_ConstantOfShape.size() != static_cast<size_t>(1)) {
return false;
}

auto B_Mul = extractConstant(net, matchedNodesIds[1], 1);
if (B_Mul.size() != static_cast<size_t>(1)) {
return false;
}

auto A_Equal = extractConstant(net, matchedNodesIds[2], 0);
if (A_Equal.size() != static_cast<size_t>(input_ConstantOfShape[0])) {
return false;
}

auto Y_Where = extractConstant(net, matchedNodesIds[3], 2);
if (Y_Where.size() != A_Equal.size()) {
return false;
}

// run ConstantOfShape
std::vector<int64_t> output_ConstantOfShape(std::accumulate(input_ConstantOfShape.begin(), input_ConstantOfShape.end(), static_cast<int64_t>(1), std::multiplies<int64_t>()), value_ConstantOfShape);
// run Mul
std::vector<int64_t> output_Mul = output_ConstantOfShape;
for (size_t i = 0; i < output_Mul.size(); i++) {
int64_t b = B_Mul[0];
output_Mul[i] *= b;
}
// run Equal
std::vector<bool> output_Equal(output_Mul.size());
for (size_t i = 0; i < output_Equal.size(); i++) {
output_Equal[i] = (A_Equal[i] == output_Mul[i]);
}
// run Where
std::vector<int64_t> output_Where(output_Equal.size());
for (size_t i = 0; i < output_Where.size(); i++) {
if (output_Equal[i]) {
output_Where[i] = output_ConstantOfShape[i];
} else {
output_Where[i] = Y_Where[i];
}
}
shape = output_Where;

return true;
}
return false;
}

virtual void finalize(const Ptr<ImportGraphWrapper>& graph,
const Ptr<ImportNodeWrapper>& fusedNode,
std::vector<Ptr<ImportNodeWrapper> >& inputs) CV_OVERRIDE {
// replace values
opencv_onnx::NodeProto* node_shape = inputs[1].dynamicCast<ONNXNodeWrapper>()->node;
auto attr = node_shape->mutable_attribute()->Mutable(0);
auto tensor = attr->mutable_t();
tensor->clear_raw_data();
tensor->set_raw_data(std::string((const char*)(shape.data()), shape.size() * sizeof(int64_t)));
}

protected:
std::vector<int64_t> shape;
};

class MishSubgraph : public Subgraph

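The fold performed in ExpandSubgraph::match can be summarized by the following sketch (hypothetical helper, assuming every operand of the subgraph is a compile-time constant):

#include <cstdint>
#include <vector>

std::vector<int64_t> foldExpandShape(int64_t fill_value, // ConstantOfShape attribute "value"
size_t rank, // ConstantOfShape input, e.g. [4]
int64_t mul_b, // Mul constant B
const std::vector<int64_t> &equal_a, // Equal constant A
const std::vector<int64_t> &where_y) // Where constant Y
{
std::vector<int64_t> of_shape(rank, fill_value); // run ConstantOfShape
std::vector<int64_t> shape(rank);
for (size_t i = 0; i < rank; i++) {
int64_t mul = of_shape[i] * mul_b; // run Mul
bool cond = (equal_a[i] == mul); // run Equal
shape[i] = cond ? of_shape[i] : where_y[i]; // run Where
}
return shape;
}

For the diagrammed constants (fill_value=1, rank=4, mul_b=-1, A = Y = [2, -1, -1, -1]) the fold yields [2, 1, 1, 1]: Equal is true wherever A already holds -1, so Where keeps the ConstantOfShape ones in those positions and takes Y's explicit 2 elsewhere.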