# Implement function retrieval APIs; add documentation for functions (#1112)
Changes from 38 commits
---

```diff
@@ -30,6 +30,8 @@ list(APPEND CMAKE_MODULE_PATH ${ONNX_ROOT}/cmake/Modules)
 if(NOT MSVC)
   set(CMAKE_C_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0")
   set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0")
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
   if(ONNX_COVERAGE)
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fprofile-arcs -ftest-coverage")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fprofile-arcs -ftest-coverage")
```
```diff
@@ -361,6 +363,7 @@ if (MSVC)
     /wd4800 # disable warning type' : forcing value to bool 'true' or 'false' (performance warning)
     /wd4503 # identifier' : decorated name length exceeded, name was truncated
     /wd4146 # unary minus operator applied to unsigned type, result still unsigned
+    /wd4244 # conversion from 'protobuf::int64' to 'int'
     ${EXTRA_FLAGS}
 )
 if(${ONNX_USE_MSVC_STATIC_RUNTIME})
```

> **Reviewer** (on `/wd4244`): where does such a cast happen? Do you also need to ignore this warning in other targets?
>
> **Author:** This happens when I convert fields in FunctionProto in cpp2py. I only notice the warning in the onnx target.
>
> **Reviewer:** At the least you would also need to ignore it on other platforms as well. Maybe what's better is to find the place in the pybind11 binding and fix it?
---

New file (36 lines): the generated function documentation.

````md
## Functions
*This file is automatically generated from the
[def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).
Do not modify directly and instead edit function definitions.*

## ai.onnx (default)
* <sub>experimental</sub><a href="#FuncMeanVarianceNormalization">FuncMeanVarianceNormalization</a>

### <sub>experimental</sub> <a name="FuncMeanVarianceNormalization"></a><a name="funcmeanvariancenormalization">**FuncMeanVarianceNormalization**</a>

A MeanVarianceNormalization Function: Perform mean variance normalization on the input tensor X using formula: <br/> ``` (X-EX)/sqrt(E(X-EX)^2) ``` <br/><br/><b>INPUT: </b>X(float/float16/double) with shape [N,C,W,H] or N-D shape <br/><br/><b>ATTRIBUTE: </b><br/> <tt>axes: </tt>will be passed to ReducedMean Ops. Use [0,2,3] (without C axis for N-D cases) for calculating means and variances along channels. Two variables with the same C-coordinate are associated with the same mean and variance. Use [0,1,2,3] (with C axis) to calculate global mean and global variance with all variables sharing the same mean/variance.<br/> (The KeepDims attribute in ReducedMean is set to true for calculation)<br/><br/><b>OUTPUT: </b>X_MVN(float/float16/double) with shape [N,C,W,H] or the input N-D shape <br/>

#### Version

This version of the function has been available since version 8 of the default ONNX operator set.

#### Inputs

<dl>
<dt>X; </dt>
<br/></dl>

#### Outputs

<dl>
<dt>X_MVN; </dt>
<br/></dl>

#### Attributes

<dl>
<dt>axes;<br/></dt>
</dl>
````

> **Reviewer** (on `<dt>X; </dt>`): the type constraint info should also be inferred from its body and printed here. You may add it in a separate PR and push this in first.

> **Reviewer** (on `<dt>axes;<br/></dt>`): Needs to document its value for across_channels=True/False. Please also make sure the description is understandable for people from outside.
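The generated description above is dense, so here is a minimal NumPy sketch of the documented formula `(X-EX)/sqrt(E(X-EX)^2)` under both axes choices. The helper name `mvn_reference` and the `1e-9` epsilon are illustrative (the epsilon mirrors the one used in the PR's Python test); this is not the ONNX API itself.

```python
import numpy as np

def mvn_reference(x, axes=(0, 2, 3), eps=1e-9):
    """Reference mean-variance normalization per the documented formula:
    (X - E[X]) / sqrt(E[(X - E[X])^2]).  `axes` plays the role of the
    attribute forwarded to the reduce-mean ops (KeepDims is on)."""
    mean = np.mean(x, axis=axes, keepdims=True)
    variance = np.mean((x - mean) ** 2, axis=axes, keepdims=True)
    return (x - mean) / (np.sqrt(variance) + eps)

x = np.random.rand(2, 3, 4, 5).astype(np.float32)

# axes=[0,2,3]: statistics shared across N, W, H -> per-channel normalization
per_channel = mvn_reference(x, axes=(0, 2, 3))

# axes=[0,1,2,3]: one global mean/variance for the whole tensor
global_norm = mvn_reference(x, axes=(0, 1, 2, 3))
```

With `axes=(0, 2, 3)` every channel ends up with (approximately) zero mean and unit variance, which is what "two variables with the same C-coordinate share the same mean and variance" amounts to.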
---

New file (32 lines): the generated function changelog.

````md
## Function Changelog
*This file is automatically generated from the
[def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py).
Do not modify directly and instead edit function definitions.*

# ai.onnx (default)
## Version 8 of domain ai.onnx (default)
### <a name="FuncMeanVarianceNormalization-8"></a>**FuncMeanVarianceNormalization-8**</a>

A MeanVarianceNormalization Function: Perform mean variance normalization on the input tensor X using formula: <br/> ``` (X-EX)/sqrt(E(X-EX)^2) ``` <br/><br/><b>INPUT: </b>X(float/float16/double) with shape [N,C,W,H] or N-D shape <br/><br/><b>ATTRIBUTE: </b><br/> <tt>axes: </tt>will be passed to ReducedMean Ops. Use [0,2,3] (without C axis for N-D cases) for calculating means and variances along channels. Two variables with the same C-coordinate are associated with the same mean and variance. Use [0,1,2,3] (with C axis) to calculate global mean and global variance with all variables sharing the same mean/variance.<br/> (The KeepDims attribute in ReducedMean is set to true for calculation)<br/><br/><b>OUTPUT: </b>X_MVN(float/float16/double) with shape [N,C,W,H] or the input N-D shape <br/>

#### Version

This version of the function has been available since version 8 of the default ONNX operator set.

#### Inputs

<dl>
<dt>X; </dt>
<br/></dl>

#### Outputs

<dl>
<dt>X_MVN; </dt>
<br/></dl>

#### Attributes

<dl>
<dt>axes;<br/></dt>
</dl>
````
---

New file (43 lines): a Python backend test generating the expected output for the function.

```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np  # type: ignore

import onnx
from ..base import Base
from . import expect


class MVN(Base):

    @staticmethod
    def export():  # type: () -> None
        node = onnx.helper.make_node(
            'FuncMeanVarianceNormalization',
            inputs=['X'],
            outputs=['Y'],
            axes=[0, 2, 3]
        )

        input_data = np.array([[[[0.8439683], [0.5665144], [0.05836735]],
                                [[0.02916367], [0.12964272], [0.5060197]],
                                [[0.79538304], [0.9411346], [0.9546573]]],
                               [[[0.17730942], [0.46192095], [0.26480448]],
                                [[0.6746842], [0.01665257], [0.62473077]],
                                [[0.9240844], [0.9722341], [0.11965699]]],
                               [[[0.41356155], [0.9129373], [0.59330076]],
                                [[0.81929934], [0.7862604], [0.11799799]],
                                [[0.69248444], [0.54119414], [0.07513223]]]],
                              dtype=np.float32)

        # Calculate expected output data
        data_mean = np.mean(input_data, axis=(0, 2, 3), keepdims=1)
        data_mean_squared = np.power(data_mean, 2)
        data_squared = np.power(input_data, 2)
        data_squared_mean = np.mean(data_squared, axis=(0, 2, 3), keepdims=1)
        std = np.sqrt(data_squared_mean - data_mean_squared)
        expected_output = (input_data - data_mean) / (std + 1e-9)

        expect(node, inputs=[input_data], outputs=[expected_output],
               name='test_mvn')
```

> **Reviewer** (on `inputs=['X'],`): question please: no attributes needed?
---

New file (39 lines): C++ implementation of the `BuildNode` helper.

```cpp
// Copyright (c) Facebook Inc. and Microsoft Corporation.
// Licensed under the MIT license.

#include "onnx/common/model_helpers.h"
#include "onnx/checker.h"
#include "onnx/defs/schema.h"
#include "onnx/string_utils.h"

namespace ONNX_NAMESPACE {
using namespace Common;

Common::Status BuildNode(
    NodeProto* node,
    const std::string& name,
    const std::string& domain,
    const std::string& doc_string,
    const std::string& op_type,
    std::vector<std::string> const& inputs,
    std::vector<std::string> const& outputs) {
  if (node == NULL) {
    return Status(
        Common::CHECKER,
        Common::INVALID_ARGUMENT,
        "node_proto should not be nullptr.");
  }
  node->set_name(name);
  node->set_domain(domain);
  node->set_doc_string(doc_string);
  node->set_op_type(op_type);
  for (auto& input : inputs) {
    node->add_input(input);
  }
  for (auto& output : outputs) {
    node->add_output(output);
  }

  return Status::OK();
}
} // namespace ONNX_NAMESPACE
```
---

New file (24 lines): header declaring the `BuildNode` helper.

```cpp
// Copyright (c) Facebook Inc. and Microsoft Corporation.
// Licensed under the MIT license.

#pragma once

#include <string>
#include <vector>
#include "onnx/common/status.h"
#include "onnx/onnx-operators_pb.h"

namespace ONNX_NAMESPACE {

// Helper function for register nodes in
// a FunctionProto. Attributes need to be
// registered separately.
Common::Status BuildNode(
    NodeProto* node,
    const std::string& name,
    const std::string& domain,
    const std::string& doc_string,
    const std::string& op_type,
    std::vector<std::string> const& inputs,
    std::vector<std::string> const& outputs);
} // namespace ONNX_NAMESPACE
```

> **Reviewer** (on `NodeProto* node,`): out parameter in the end.
---

Changes to the pybind11 bindings in `onnx_cpp2py_export`:

```diff
@@ -5,6 +5,7 @@
#include <unordered_map>

#include "onnx/checker.h"
#include "onnx/defs/function.h"
#include "onnx/defs/schema.h"
#include "onnx/optimizer/optimize.h"
#include "onnx/py_utils.h"
```

```diff
@@ -96,6 +97,57 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
       .value("COMMON", OpSchema::SupportType::COMMON)
       .value("EXPERIMENTAL", OpSchema::SupportType::EXPERIMENTAL);
 
+  py::class_<FunctionProto> function_proto(defs, "FunctionProto");
+  function_proto.def_property_readonly("name", &FunctionProto::name)
+      .def_property_readonly("doc_string", &FunctionProto::doc_string)
+      .def_property_readonly("since_version", &FunctionProto::since_version)
+      .def_property_readonly(
+          "inputs",
+          [](FunctionProto* fp) -> std::vector<std::string> {
+            std::vector<std::string> _stl_vec;
+            _stl_vec.assign(fp->input().begin(), fp->input().end());
+            return _stl_vec;
+          })
+      .def_property_readonly(
+          "outputs",
+          [](FunctionProto* fp) -> std::vector<std::string> {
+            std::vector<std::string> _stl_vec;
+            _stl_vec.assign(fp->output().begin(), fp->output().end());
+            return _stl_vec;
+          })
+      .def_property_readonly(
+          "attribute",
+          [](FunctionProto* fp) -> std::vector<std::string> {
+            std::vector<std::string> _stl_vec;
+            _stl_vec.assign(fp->attribute().begin(), fp->attribute().end());
+            return _stl_vec;
+          })
+      .def_property_readonly(
+          "nodes", [](FunctionProto* fp) -> std::vector<NodeProto> {
+            std::vector<NodeProto> _stl_vec;
+            _stl_vec.assign(fp->node().begin(), fp->node().end());
+            return _stl_vec;
+          });
+
+  py::class_<NodeProto> node_proto(function_proto, "NodeProto");
+  node_proto.def_property_readonly("name", &NodeProto::name)
+      .def_property_readonly("doc_string", &NodeProto::doc_string)
+      .def_property_readonly("domain", &NodeProto::domain)
+      .def_property_readonly("op_type", &NodeProto::op_type)
+      .def_property_readonly(
+          "inputs",
+          [](NodeProto* np) -> std::vector<std::string> {
+            std::vector<std::string> _stl_vec;
+            _stl_vec.assign(np->input().begin(), np->input().end());
+            return _stl_vec;
+          })
+      .def_property_readonly(
+          "outputs", [](NodeProto* np) -> std::vector<std::string> {
+            std::vector<std::string> _stl_vec;
+            _stl_vec.assign(np->output().begin(), np->output().end());
+            return _stl_vec;
+          });
+
   defs.def(
       "has_schema",
       [](const std::string& op_type, const std::string& domain) -> bool {
```

> **Reviewer** (on `py::class_<FunctionProto> function_proto(defs, "FunctionProto");`): This is not good, because python land already has a […]
>
> **Author:** Currently the solution is blocked by #1194 (and it seems the import issue is in discussion in the protobuf repo as well). I could rename the generated pybind class (to function_schema, probably) and explore a way to resolve #1194 (probably need to manually run text-replacement scripts), then switch to use function_proto later. This class is only used for the display of function information now, so there is no use case for users to really create one using function classes in the python module.
>
> **Author:** Found a way to manually correct the import in the build. Pushed.
>
> Comment resolved.

```diff
@@ -146,6 +198,32 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
         return OpSchemaRegistry::get_all_schemas_with_history();
       });
 
+  defs.def(
+      "get_all_functions",
+      [](const std::string& domain)
+          -> std::unordered_map<std::string, std::vector<FunctionProto>> {
+        // Eliminate unsupported stl containers and smart pointers for Pybind
+        // and return a dict[Text, [FunctionProto]] datatype to Python API
+        std::multimap<std::string, std::unique_ptr<FunctionProto>> temp_ptr_map;
+        std::unordered_map<std::string, std::vector<FunctionProto>> temp_map;
+        FunctionBuilderRegistry& function_registry =
+            FunctionBuilderRegistry::OnnxInstance();
+        Common::Status status =
+            function_registry.GetFunctions(domain, &temp_ptr_map);
+        for (auto iter = temp_ptr_map.begin(); iter != temp_ptr_map.end();
+             ++iter)
+          if (!temp_map.count(iter->first)) {
+            std::vector<FunctionProto> tmp_vec;
+            tmp_vec.emplace_back(*iter->second);
+            temp_map.insert(
+                std::unordered_map<std::string, std::vector<FunctionProto>>::
+                    value_type(iter->first, tmp_vec));
+          } else {
+            temp_map.at(iter->first).emplace_back(*iter->second);
+          }
+        return temp_map;
+      });
 
   // Submodule `checker`
   auto checker = onnx_cpp2py_export.def_submodule("checker");
   checker.doc() = "Checker submodule";
```

> **Reviewer** (on `Common::Status status =`): check status?

> **Reviewer** (on `if (!temp_map.count(iter->first)) {`): you can use […]

> **Reviewer:** why is this needed?
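The `get_all_functions` lambda flattens a `multimap<string, unique_ptr<FunctionProto>>` into the `dict[Text, [FunctionProto]]` that Python sees. The same grouping written directly in Python — with tuples standing in for protos; the sample entries below are made up for illustration — needs only `dict.setdefault`:

```python
from typing import Dict, List, Tuple

# (function_name, proto) pairs in the order the multimap yields them
entries: List[Tuple[str, str]] = [
    ("FuncMeanVarianceNormalization", "since_version=8"),
    ("FuncMeanVarianceNormalization", "since_version=9"),
    ("SomeOtherFunc", "since_version=8"),
]

grouped: Dict[str, List[str]] = {}
for name, proto in entries:
    # setdefault covers both the "first time seen" and "append" branches
    # that the C++ loop spells out with count()/insert()/at()
    grouped.setdefault(name, []).append(proto)
```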