New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement function retrieval APIs; Add documentation for functions #1112
Changes from all commits
f14cabb
cdbc593
d251325
cc7bdf4
2a4fd29
65ab836
83d5328
668c3bd
fcfd814
826da02
e11baa5
03970b5
c3639d1
3bca586
fd7b65f
cb82317
7bd3ded
411a5c7
1bb892d
bbeffa0
0be2361
1ae57b1
398143b
a3271ed
06016b1
87398e9
5812ed2
b3a77df
743dc8d
e3010fa
0db95c4
95e3b84
8906f13
2d7d271
35728c5
9d37027
936a830
4731be7
5f3794e
9579398
c4a25f7
9c6a01e
f07f977
f081fd6
abad4cc
d22b043
293c4c8
af7009e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
## Functions | ||
*This file is automatically generated from the | ||
[def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py). | ||
Do not modify directly and instead edit function definitions.* |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
## Functions | ||
*This file is automatically generated from the | ||
[def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py). | ||
Do not modify directly and instead edit function definitions.* | ||
## ai.onnx (default) | ||
* <sub>experimental</sub><a href="#FuncMeanVarianceNormalization">FuncMeanVarianceNormalization</a> | ||
|
||
|
||
|
||
### <sub>experimental</sub> <a name="FuncMeanVarianceNormalization"></a><a name="funcmeanvariancenormalization">**FuncMeanVarianceNormalization**</a> | ||
|
||
A MeanVarianceNormalization Function: Perform mean variance normalization on the input tensor X using formula: <br/> ``` (X-EX)/sqrt(E(X-EX)^2) ``` <br/><br/><b>INPUT: </b>X(float/float16/double) with shape [N,C,W,H] or N-D shape <br/><br/><b>ATTRIBUTE: </b><br/> <tt>axes: </tt>will be passed to ReducedMean Ops. Use [0,2,3] (without C axis for N-D cases) for calculating means and variances along channels. Two variables with the same C-coordinate are associated with the same mean and variance. Use [0,1,2,3] (with C axis) to calculate global mean and global variance with all variables sharing the same mean/variance.<br/> (The KeepDims attribute in ReducedMean is set to true for calculation)<br/><br/><b>OUTPUT: </b>X_MVN(float/float16/double) with the same shape as input X<br/> | ||
|
||
#### Version | ||
|
||
This version of the function has been available since version 8 of the default ONNX operator set. | ||
|
||
#### Inputs | ||
|
||
<dl> | ||
<dt>X; </dt> | ||
<br/></dl> | ||
|
||
#### Outputs | ||
|
||
<dl> | ||
<dt>X_MVN; </dt> | ||
<br/></dl> | ||
|
||
#### Attributes | ||
|
||
<dl> | ||
<dt>axes;<br/></dt> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Needs to document done its value for across_channels=True/False. Please also make sure the description is understandable for people from outside. |
||
</dl> | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
## Function Changelog | ||
*This file is automatically generated from the | ||
[def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py). | ||
Do not modify directly and instead edit function definitions.* | ||
## ai.onnx.ml |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
## Function Changelog | ||
*This file is automatically generated from the | ||
[def files](/onnx/defs) via [this script](/onnx/defs/gen_doc.py). | ||
Do not modify directly and instead edit function definitions.* | ||
# ai.onnx (default) | ||
## Version 8 of domain ai.onnx (default) | ||
### <a name="FuncMeanVarianceNormalization-8"></a>**FuncMeanVarianceNormalization-8**</a> | ||
|
||
A MeanVarianceNormalization Function: Perform mean variance normalization on the input tensor X using formula: <br/> ``` (X-EX)/sqrt(E(X-EX)^2) ``` <br/><br/><b>INPUT: </b>X(float/float16/double) with shape [N,C,W,H] or N-D shape <br/><br/><b>ATTRIBUTE: </b><br/> <tt>axes: </tt>will be passed to ReducedMean Ops. Use [0,2,3] (without C axis for N-D cases) for calculating means and variances along channels. Two variables with the same C-coordinate are associated with the same mean and variance. Use [0,1,2,3] (with C axis) to calculate global mean and global variance with all variables sharing the same mean/variance.<br/> (The KeepDims attribute in ReducedMean is set to true for calculation)<br/><br/><b>OUTPUT: </b>X_MVN(float/float16/double) with the same shape as input X<br/> | ||
|
||
#### Version | ||
|
||
This version of the function has been available since version 8 of the default ONNX operator set. | ||
|
||
#### Inputs | ||
|
||
<dl> | ||
<dt>X; </dt> | ||
<br/></dl> | ||
|
||
#### Outputs | ||
|
||
<dl> | ||
<dt>X_MVN; </dt> | ||
<br/></dl> | ||
|
||
#### Attributes | ||
|
||
<dl> | ||
<dt>axes;<br/></dt> | ||
</dl> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
from __future__ import unicode_literals | ||
|
||
import numpy as np # type: ignore | ||
|
||
import onnx | ||
from ..base import Base | ||
from . import expect | ||
|
||
|
||
class MVN(Base): | ||
|
||
@staticmethod | ||
def export(): # type: () -> None | ||
node = onnx.helper.make_node( | ||
'FuncMeanVarianceNormalization', | ||
inputs=['X'], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. question please: no attributes needed? |
||
outputs=['Y'], | ||
axes=[0, 2, 3] | ||
) | ||
|
||
input_data = np.array([[[[0.8439683], [0.5665144], [0.05836735]], | ||
[[0.02916367], [0.12964272], [0.5060197]], | ||
[[0.79538304], [0.9411346], [0.9546573]]], | ||
[[[0.17730942], [0.46192095], [0.26480448]], | ||
[[0.6746842], [0.01665257], [0.62473077]], | ||
[[0.9240844], [0.9722341], [0.11965699]]], | ||
[[[0.41356155], [0.9129373], [0.59330076]], | ||
[[0.81929934], [0.7862604], [0.11799799]], | ||
[[0.69248444], [0.54119414], [0.07513223]]]], dtype=np.float32) | ||
|
||
# Calculate expected output data | ||
data_mean = np.mean(input_data, axis=(0, 2, 3), keepdims=1) | ||
data_mean_squared = np.power(data_mean, 2) | ||
data_squared = np.power(input_data, 2) | ||
data_squared_mean = np.mean(data_squared, axis=(0, 2, 3), keepdims=1) | ||
std = np.sqrt(data_squared_mean - data_mean_squared) | ||
expected_output = (input_data - data_mean) / (std + 1e-9) | ||
|
||
expect(node, inputs=[input_data], outputs=[expected_output], | ||
name='test_mvn') |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// Copyright (c) Facebook Inc. and Microsoft Corporation. | ||
// Licensed under the MIT license. | ||
|
||
#include "onnx/common/model_helpers.h" | ||
#include "onnx/checker.h" | ||
#include "onnx/defs/schema.h" | ||
#include "onnx/string_utils.h" | ||
|
||
namespace ONNX_NAMESPACE { | ||
using namespace Common; | ||
|
||
Common::Status BuildNode( | ||
const std::string& name, | ||
const std::string& domain, | ||
const std::string& doc_string, | ||
const std::string& op_type, | ||
std::vector<std::string> const& inputs, | ||
std::vector<std::string> const& outputs, | ||
NodeProto* node) { | ||
if (node == NULL) { | ||
return Status( | ||
Common::CHECKER, | ||
Common::INVALID_ARGUMENT, | ||
"node_proto should not be nullptr."); | ||
} | ||
node->set_name(name); | ||
node->set_domain(domain); | ||
node->set_doc_string(doc_string); | ||
node->set_op_type(op_type); | ||
for (auto& input : inputs) { | ||
node->add_input(input); | ||
} | ||
for (auto& output : outputs) { | ||
node->add_output(output); | ||
} | ||
|
||
return Status::OK(); | ||
} | ||
} // namespace ONNX_NAMESPACE |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// Copyright (c) Facebook Inc. and Microsoft Corporation. | ||
// Licensed under the MIT license. | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
#include <vector> | ||
#include "onnx/common/status.h" | ||
#include "onnx/onnx-operators_pb.h" | ||
|
||
namespace ONNX_NAMESPACE { | ||
|
||
// Helper function for register nodes in | ||
// a FunctionProto. Attributes need to be | ||
// registered separately. | ||
Common::Status BuildNode( | ||
const std::string& name, | ||
const std::string& domain, | ||
const std::string& doc_string, | ||
const std::string& op_type, | ||
std::vector<std::string> const& inputs, | ||
std::vector<std::string> const& outputs, | ||
/*OUT*/ NodeProto* node); | ||
} // namespace ONNX_NAMESPACE |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
#include <unordered_map> | ||
|
||
#include "onnx/checker.h" | ||
#include "onnx/defs/function.h" | ||
#include "onnx/defs/schema.h" | ||
#include "onnx/optimizer/optimize.h" | ||
#include "onnx/py_utils.h" | ||
|
@@ -146,6 +147,34 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) { | |
return OpSchemaRegistry::get_all_schemas_with_history(); | ||
}); | ||
|
||
defs.def( | ||
"get_all_functions", | ||
[](const std::string& domain) | ||
-> std::unordered_map<std::string, std::vector<py::bytes>> { | ||
std::multimap<std::string, std::unique_ptr<FunctionProto>> temp_ptr_map; | ||
std::unordered_map<std::string, std::vector<py::bytes>> temp_map; | ||
FunctionBuilderRegistry& function_registry = | ||
FunctionBuilderRegistry::OnnxInstance(); | ||
|
||
Common::Status status = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. check status? |
||
function_registry.GetFunctions(domain, &temp_ptr_map); | ||
if (!status.IsOK()) { | ||
throw std::runtime_error( | ||
"Failed to retrieve function list for domain '" + domain + "'!"); | ||
} | ||
for (auto iter = temp_ptr_map.begin(); iter != temp_ptr_map.end(); | ||
++iter) { | ||
std::string bytes; | ||
if (!iter->second->SerializeToString(&bytes)) { | ||
throw std::runtime_error( | ||
"Failed to serilize registered function for '" + iter->first + | ||
"'!"); | ||
} | ||
temp_map[iter->first].emplace_back(py::bytes(std::move(bytes))); | ||
} | ||
return temp_map; | ||
}); | ||
|
||
// Submodule `checker` | ||
auto checker = onnx_cpp2py_export.def_submodule("checker"); | ||
checker.doc() = "Checker submodule"; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the type constraint info should also be inferred from its body and printed here. You may add it in separate PR and push this in firstly.