Skip to content

Commit

Permalink
[QDQ] Hookup NNAPI GetCapability/Compile with shared QDQ selectors (#…
Browse files Browse the repository at this point in the history
…10347)

* add qdqgroup as input for NodeUnit

* minor update

* hookup nnapi_ep

* minor update

* update compiler setting

* Add a simple UT

* Pipeline change to add build minimal extended with NNAPI for Android

* move GetAllNodeUnits to node_unit.h, add UT for NodeUnits, minor updates

* minor updates

* address CR comments

Co-authored-by: gwang0000 <62914304+gwang0000@users.noreply.github.com>
  • Loading branch information
guoyu-wang and guoyu-wang committed Jan 26, 2022
1 parent 9aa5137 commit 4af1166
Show file tree
Hide file tree
Showing 19 changed files with 448 additions and 148 deletions.
15 changes: 15 additions & 0 deletions cmake/onnxruntime_optimizer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,21 @@ if (onnxruntime_MINIMAL_BUILD)
"${ONNXRUNTIME_ROOT}/core/optimizer/graph_transformer.cc"
)

if (onnxruntime_EXTENDED_MINIMAL_BUILD AND onnxruntime_USE_NNAPI_BUILTIN)
list(APPEND onnxruntime_optimizer_src_patterns
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/qdq_util.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/qdq_transformer/qdq_util.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/initializer.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/initializer.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/utils.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/utils.cc"
)
endif()

if (onnxruntime_ENABLE_RUNTIME_OPTIMIZATION_REPLAY_IN_MINIMAL_BUILD)
list(APPEND onnxruntime_optimizer_src_patterns
"${ONNXRUNTIME_INCLUDE_DIR}/core/optimizer/graph_transformer_utils.h"
Expand Down
72 changes: 36 additions & 36 deletions onnxruntime/core/graph/graph_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,26 +480,6 @@ const Node* FirstChildByType(const Node& node, const std::string& child_type) {
return nullptr;
}

std::vector<const Node*> FindChildrenByType(const Node& node, const std::string& child_type) {
// find children and sort them by source argument index:
// Create a 2D vector to hold the result.
// 1st dimension index is output index,
// and the 2nd dimension stores the edges from the output.
std::vector<std::vector<const Node*>> children(node.OutputDefs().size(), std::vector<const Node*>());
for (auto it = node.OutputEdgesBegin(); it != node.OutputEdgesEnd(); it++) {
if (it->GetNode().OpType().compare(child_type) == 0) {
children[it->GetSrcArgIndex()].push_back(&(it->GetNode()));
}
}

// aggregate children
std::vector<const Node*> agg_res;
for (size_t output_idx = 0; output_idx < children.size(); output_idx++) {
agg_res.insert(agg_res.end(), children[output_idx].begin(), children[output_idx].end());
}
return agg_res;
}

const Node* FirstParentByType(const Node& node, const std::string& parent_type) {
for (auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it) {
if ((*it).OpType().compare(parent_type) == 0) {
Expand All @@ -509,22 +489,6 @@ const Node* FirstParentByType(const Node& node, const std::string& parent_type)
return nullptr;
}

std::vector<const Node*> FindParentsByType(const Node& node, const std::string& parent_type) {
// find parents and sort them by destination argument index
// as there is at most one input edge for each input argument,
// there is no need of extra work like FindChildrenByType
std::vector<const Node*> parents(node.InputDefs().size(), nullptr);
for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); it++) {
if (it->GetNode().OpType().compare(parent_type) == 0) {
parents[it->GetDstArgIndex()] = &(it->GetNode());
}
}

// remove unmatched nodes
parents.erase(std::remove(parents.begin(), parents.end(), nullptr), parents.end());
return parents;
}

NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) {
// sanity check as AddInitializedTensor silently ignores attempts to add a duplicate initializer
const ONNX_NAMESPACE::TensorProto* existing = nullptr;
Expand Down Expand Up @@ -778,6 +742,42 @@ NodeArg& CreateNodeArg(Graph& graph, const NodeArg& base_arg) {

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

std::vector<const Node*> FindParentsByType(const Node& node, const std::string& parent_type) {
// find parents and sort them by destination argument index
// as there is at most one input edge for each input argument,
// there is no need of extra work like FindChildrenByType
std::vector<const Node*> parents(node.InputDefs().size(), nullptr);
for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); it++) {
if (it->GetNode().OpType().compare(parent_type) == 0) {
parents[it->GetDstArgIndex()] = &(it->GetNode());
}
}

// remove unmatched nodes
parents.erase(std::remove(parents.begin(), parents.end(), nullptr), parents.end());
return parents;
}

std::vector<const Node*> FindChildrenByType(const Node& node, const std::string& child_type) {
// find children and sort them by source argument index:
// Create a 2D vector to hold the result.
// 1st dimension index is output index,
// and the 2nd dimension stores the edges from the output.
std::vector<std::vector<const Node*>> children(node.OutputDefs().size(), std::vector<const Node*>());
for (auto it = node.OutputEdgesBegin(); it != node.OutputEdgesEnd(); it++) {
if (it->GetNode().OpType().compare(child_type) == 0) {
children[it->GetSrcArgIndex()].push_back(&(it->GetNode()));
}
}

// aggregate children
std::vector<const Node*> agg_res;
for (size_t output_idx = 0; output_idx < children.size(); output_idx++) {
agg_res.insert(agg_res.end(), children[output_idx].begin(), children[output_idx].end());
}
return agg_res;
}

const std::string& GetNodeInputName(const Node& node, int index) {
const auto& inputs = node.InputDefs();
ORT_ENFORCE(index >= 0 && static_cast<size_t>(index) < inputs.size(),
Expand Down
20 changes: 10 additions & 10 deletions onnxruntime/core/graph/graph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,9 @@ bool GetRepeatedNodeAttributeValues(const Node& node,
/** Find the first child of the specified op type. */
const Node* FirstChildByType(const Node& node, const std::string& child_type);

/** Find node children by op types.
@returns The matched children are sorted by source argument index of their corresponding edge.
**/
std::vector<const Node*> FindChildrenByType(const Node& node, const std::string& child_type);

/** Find the first parent of the specified op type. */
const Node* FirstParentByType(const Node& node, const std::string& parent_type);

/** Find node parents by op types.
@returns The matched parents are sorted by destination argument index of their corresponding edge.
**/
std::vector<const Node*> FindParentsByType(const Node& node, const std::string& parent_type);

/** Tests if we can remove a node and merge its input edge (if any) with its output edges.
Conditions:
Input rules:
Expand Down Expand Up @@ -290,6 +280,16 @@ NodeArg& CreateNodeArg(Graph& graph, const NodeArg& base_arg);

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

/** Find node parents by op types.
@returns The matched parents are sorted by destination argument index of their corresponding edge.
**/
std::vector<const Node*> FindParentsByType(const Node& node, const std::string& parent_type);

/** Find node children by op types.
@returns The matched children are sorted by source argument index of their corresponding edge.
**/
std::vector<const Node*> FindChildrenByType(const Node& node, const std::string& child_type);

/** Gets the name of the incoming NodeArg with the specified index for the given node. */
const std::string& GetNodeInputName(const Node& node, int index);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"

Expand Down Expand Up @@ -251,4 +251,4 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
} // namespace QDQ
} // namespace onnxruntime

#endif // !defined(ORT_MINIMAL_BUILD)
#endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#pragma once

#if !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

#include "core/optimizer/selectors_actions/selector_action_transformer.h"

Expand Down Expand Up @@ -193,4 +193,4 @@ class MatMulSelector : public BaseSelector {
} // namespace QDQ
} // namespace onnxruntime

#endif // !defined(ORT_MINIMAL_BUILD)
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

#include "utils.h"

#include <iostream>
Expand All @@ -10,7 +12,6 @@
#include <core/graph/graph_viewer.h>
#include <core/providers/common.h>

#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"

namespace onnxruntime {
Expand Down Expand Up @@ -101,7 +102,7 @@ void SelectorManager::InitializeSelectorsMap() {
}
}

void SelectorManager::Initialize() {
SelectorManager::SelectorManager() {
CreateSelectors();
InitializeSelectorsMap();
}
Expand Down Expand Up @@ -141,4 +142,6 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
}

} // namespace QDQ
} // namespace onnxruntime
} // namespace onnxruntime

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
#pragma once

#include <string>
#include "core/common/common.h"
#include "core/graph/basic_types.h"
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h"
#include "core/optimizer/selectors_actions/helpers.h"

#if !defined(ORT_MINIMAL_BUILD)
#include "onnx/defs/schema.h"
#endif

namespace onnxruntime {

Expand All @@ -15,6 +18,9 @@ class Node;

namespace QDQ {

struct NodeGroup;
class NodeGroupSelector;

// struct that provides a join between selector and op versions supported
struct OpVersionsAndSelector {
using OpVersionsMap = std::unordered_map<std::string, std::vector<ONNX_NAMESPACE::OperatorSetVersion>>;
Expand Down Expand Up @@ -52,9 +58,7 @@ class Selectors {
// class that manages qdq node group selections
class SelectorManager {
public:
SelectorManager() = default;

void Initialize();
SelectorManager();

// Methods that finds and returns a vector of QDQ::NodeGroup in a given graph
// Can be used in QDQ support in different EPs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
#pragma once

#include <functional>
#if !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#include <optional>
#endif // !defined(ORT_MINIMAL_BUILD)
#endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

#include "core/framework/kernel_registry_manager.h"
#include "core/optimizer/graph_transformer.h"
Expand All @@ -19,7 +19,7 @@ class Graph;
class GraphViewer;
class Node;

#if !defined(ORT_MINIMAL_BUILD)
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

// Base class for a selector which checks for a match and returns the set of nodes involved.
struct NodeSelector {
Expand All @@ -33,7 +33,7 @@ struct NodeSelector {
NodeSelector() = default;
};

#endif // !defined(ORT_MINIMAL_BUILD)
#endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

// class to manage a set of selector and associated actions
class SelectorActionRegistry {
Expand Down
37 changes: 26 additions & 11 deletions onnxruntime/core/optimizer/utils.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if !defined(ORT_MINIMAL_BUILD)
#include "core/graph/constants.h"
#include "core/graph/onnx_protobuf.h"
#include "core/graph/graph_utils.h"
Expand All @@ -13,29 +15,25 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#endif // #if !defined(ORT_MINIMAL_BUILD)

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
#include "core/graph/node_arg.h"
#endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

using namespace onnxruntime;

namespace onnxruntime {
namespace optimizer_utils {

#if !defined(ORT_MINIMAL_BUILD)

bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
return tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ||
tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
}

bool IsScalar(const NodeArg& input_arg) {
auto shape = input_arg.Shape();
if (shape == nullptr) {
// shape inferencing wasn't able to populate shape information for this NodeArg
return false;
}

auto dim_size = shape->dim_size();
return dim_size == 0 || (dim_size == 1 && shape->dim(0).has_dim_value() && shape->dim(0).dim_value() == 1);
}

// Check whether input is a constant scalar with expected float value.
bool IsInitializerWithExpectedValue(const Graph& graph, const NodeArg& input_arg, float expected_value, bool is_constant) {
if (!IsScalar(input_arg)) {
Expand Down Expand Up @@ -293,5 +291,22 @@ bool IsOperationDeterministic(const std::string& domain, const std::string& op)
return itDomain->second.count(op) == 0;
}

#endif // #if !defined(ORT_MINIMAL_BUILD)

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

bool IsScalar(const NodeArg& input_arg) {
auto shape = input_arg.Shape();
if (shape == nullptr) {
// shape inferencing wasn't able to populate shape information for this NodeArg
return false;
}

auto dim_size = shape->dim_size();
return dim_size == 0 || (dim_size == 1 && shape->dim(0).has_dim_value() && shape->dim(0).dim_value() == 1);
}

#endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

} // namespace optimizer_utils
} // namespace onnxruntime
16 changes: 13 additions & 3 deletions onnxruntime/core/optimizer/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,22 @@

#pragma once

#if !defined(ORT_MINIMAL_BUILD)
#include "core/graph/onnx_protobuf.h"
#include "core/graph/graph.h"
#endif // !#if !defined(ORT_MINIMAL_BUILD)

namespace onnxruntime {
class Graph;
class NodeArg;

namespace optimizer_utils {

#if !defined(ORT_MINIMAL_BUILD)

// Check if TensorProto contains a floating point type.
bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto);

// Check if NodeArg takes in a scalar tensor.
bool IsScalar(const NodeArg& input_arg);

/** Check whether a input is initializer with specified float value.
@param expected_value is the expected value of the initializer.
@param is_constant means whether the initializer is required to be constant.
Expand Down Expand Up @@ -101,5 +102,14 @@ bool CheckOutputEdges(const Graph& graph, const Node& node, size_t expected_outp

bool IsOperationDeterministic(const std::string& domain, const std::string& op);

#endif // !#if !defined(ORT_MINIMAL_BUILD)

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

// Check if NodeArg takes in a scalar tensor.
bool IsScalar(const NodeArg& input_arg);

#endif // #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

} // namespace optimizer_utils
} // namespace onnxruntime
Loading

0 comments on commit 4af1166

Please sign in to comment.