Skip to content

Commit

Permalink
Resolve PR comments; Refine Function type check implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Raymond Yang committed Jun 22, 2018
1 parent cb82317 commit 1274128
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 44 deletions.
94 changes: 74 additions & 20 deletions onnx/defs/experiments/functions.cc
Expand Up @@ -4,7 +4,7 @@
#include "onnx/defs/function.h"
using namespace ONNX_NAMESPACE;

static Common::Status BuildFc(std::unique_ptr<FunctionProto>* func_proto) {
static Common::Status BuildMVN(std::unique_ptr<FunctionProto>* func_proto) {
if (nullptr == func_proto) {
return Status(
Common::CHECKER,
Expand All @@ -14,31 +14,85 @@ static Common::Status BuildFc(std::unique_ptr<FunctionProto>* func_proto) {

func_proto->reset(new FunctionProto);
auto& func = **func_proto;
func.set_name("FC");
func.set_doc_string("this is a full connection function.");
func.set_name("MeanVarianceNormalization");
func.set_doc_string(
"A MeanVarianceNormalization Function: Perform mean variance normalization on the input tensor X");
func.set_since_version(8);
func.add_input("w");
func.add_input("x");
func.add_input("b");
func.add_output("y");
func.add_input("X");
func.add_input("Pow_exponent");
func.add_output("X_MVN");
func.add_attribute("axes");
NodeProto* node0 = func.add_node();
node0->set_name("node0");
node0->set_name("Reduced_Mean_0");
node0->set_domain("");
node0->set_doc_string("This is a matmul testing node ");
node0->set_op_type("MatMul");
node0->add_input("w");
node0->add_input("x");
node0->add_output("y_1");
node0->set_doc_string("Caculating Reduced Mean on input tensor X");
node0->set_op_type("ReduceMean");
node0->add_input("X");
node0->add_output("X_RM");
AttributeProto* attr0 = node0->add_attribute();
attr0->set_ref_attr_name("axes");
attr0->set_name("axes");
attr0->set_type(AttributeProto_AttributeType_INTS);
NodeProto* node1 = func.add_node();
node1->set_name("node1");
node1->set_name("Pow_0");
node1->set_domain("");
node1->set_doc_string("This is a add testing node ");
node1->set_op_type("Add");
node1->add_input("y_1");
node1->add_input("b");
node1->add_output("y");
node1->set_doc_string("Caculating (EX)^2");
node1->set_op_type("Pow");
node1->add_input("X_RM");
node1->add_input("Pow_exponent");
node1->add_output("EX_POW");
NodeProto* node2 = func.add_node();
node2->set_name("Pow_1");
node2->set_domain("");
node2->set_doc_string("Caculating X^2");
node2->set_op_type("Pow");
node2->add_input("X");
node2->add_input("Pow_exponent");
node2->add_output("X_POW");
NodeProto* node3 = func.add_node();
node3->set_name("Reduced_Mean_1");
node3->set_domain("");
node3->set_doc_string("Caculating E(X^2)");
node3->set_op_type("ReduceMean");
node3->add_input("X_POW");
node3->add_output("E_XPOW");
AttributeProto* attr1 = node3->add_attribute();
attr1->set_ref_attr_name("axes");
attr1->set_name("axes");
attr1->set_type(AttributeProto_AttributeType_INTS);
NodeProto* node4 = func.add_node();
node4->set_name("SUB_0");
node4->set_domain("");
node4->set_doc_string("Caculating variance (E(X^2)-(EX)^2)");
node4->set_op_type("Sub");
node4->add_input("EX_POW");
node4->add_input("E_XPOW");
node4->add_output("VAR");
NodeProto* node5 = func.add_node();
node5->set_name("SQRT_0");
node5->set_domain("");
node5->set_doc_string("Caculating standard variance from variance");
node5->set_op_type("Sqrt");
node5->add_input("VAR");
node5->add_output("STD_VAR");
NodeProto* node6 = func.add_node();
node6->set_name("SUB_1");
node6->set_domain("");
node6->set_doc_string("Caculating X-EX");
node6->set_op_type("Sub");
node6->add_input("X");
node6->add_input("X_RM");
node6->add_output("X_VAR");
NodeProto* node7 = func.add_node();
node7->set_name("DIV_0");
node7->set_domain("");
node7->set_doc_string("Caculating X-EX");
node7->set_op_type("Div");
node7->add_input("X_VAR");
node7->add_input("STD_VAR");
node7->add_output("X_MVN");

return Status::OK();
}

ONNX_FUNCTION(FunctionBuilder().SetDomain("").SetBuildFunction(BuildFc));
ONNX_FUNCTION(FunctionBuilder().SetDomain("").SetBuildFunction(BuildMVN));
1 change: 1 addition & 0 deletions onnx/defs/function.cc
Expand Up @@ -66,6 +66,7 @@ Status FunctionBuilderRegistry::GetFunctions(
? function_proto->since_version()
: version_range.second});
ctx.set_opset_imports(op_set);
ctx.set_is_main_graph(false);
LexicalScopeContext lex_ctx;
try {
check_function(*function_proto, ctx, lex_ctx);
Expand Down
23 changes: 23 additions & 0 deletions onnx/defs/function.h
Expand Up @@ -57,4 +57,27 @@ class FunctionBuilderRegistry {
static Common::Status function_builder_##counter##_status = \
FunctionBuilderRegistry::OnnxInstance().Register(function_builder);

// Example to register a function.
// Common::Status BuildFc(std::unique_ptr<FunctionProto>* func_proto) {
// if (nullptr == func_proto) {
// return Status(
// Common::CHECKER,
// Common::INVALID_ARGUMENT,
// "func_proto should not be nullptr.");
// }
//
// func_proto->reset(new FunctionProto);
// auto& func = **func_proto;
// func.set_name("FC");
// set function inputs.
// set function outputs.
// set function attributes.
// set function description.
// set function body (nodes).
//
// return Status::OK();
//}
//
// ONNX_FUNCTION(FunctionBuilder().SetDomain("").SetBuildFunction(BuildFc));

} // namespace ONNX_NAMESPACE
60 changes: 39 additions & 21 deletions onnx/defs/schema.cc
Expand Up @@ -140,20 +140,19 @@ void OpSchema::Verify(const NodeProto& node) const {

for (int out_idx = 0; out_idx < node.output_size(); ++out_idx) {
if (out_idx >= static_cast<int>(outputs_.size())) {
if (outputs_.size() > 0 && Variadic == outputs_.back().GetOption()) {
// The last output formal parameter should be variadic.
break;
}
else {
fail_check(
"Node (",
node.name(),
") has more outputs (",
node.output_size(),
") than declared (",
outputs_.size(),
") in op definition.");
}
if (outputs_.size() > 0 && Variadic == outputs_.back().GetOption()) {
// The last output formal parameter should be variadic.
break;
} else {
fail_check(
"Node (",
node.name(),
") has more outputs (",
node.output_size(),
") than declared (",
outputs_.size(),
") in op definition.");
}
}

if (node.output(out_idx).empty() &&
Expand Down Expand Up @@ -191,6 +190,14 @@ void OpSchema::Verify(const NodeProto& node) const {
fail_check("Unrecognized attribute: ", name);
}

if (attr_proto.has_ref_attr_name()) {
if (!attr_proto.has_type() || attr_proto.type() != expected_type) {
fail_check(
"Unmatch attribute type in '", node.name() + " : " + name, "'");
}
continue;
}

switch (expected_type) {
case AttributeProto::FLOAT:
if (!attr_proto.has_f()) {
Expand Down Expand Up @@ -269,20 +276,23 @@ OpSchema& OpSchema::SinceVersion(OperatorSetVersion v) {
}

OpSchema& OpSchema::NumInputs(std::set<int> allowed_input_nums) {
num_inputs_allowed_ = [MOVE_CAPTURE_IF_CPP14(allowed_input_nums)](int n) -> bool {
num_inputs_allowed_ =
[MOVE_CAPTURE_IF_CPP14(allowed_input_nums)](int n) -> bool {
return allowed_input_nums.count(n);
};
return *this;
}

OpSchema& OpSchema::NumOutputs(std::set<int> allowed_output_nums) {
num_outputs_allowed_ = [MOVE_CAPTURE_IF_CPP14(allowed_output_nums)](int n) -> bool {
num_outputs_allowed_ =
[MOVE_CAPTURE_IF_CPP14(allowed_output_nums)](int n) -> bool {
return allowed_output_nums.count(n);
};
return *this;
}

OpSchema& OpSchema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction) {
OpSchema& OpSchema::TypeAndShapeInferenceFunction(
InferenceFunction inferenceFunction) {
tensor_inference_function_ = inferenceFunction;
return *this;
}
Expand Down Expand Up @@ -464,7 +474,11 @@ OpSchema& OpSchema::Input(
if (int(inputs_.size()) <= n) {
inputs_.resize(n + 1);
}
inputs_[n] = FormalParameter(std::move(name), std::move(description), std::move(type_str), param_option);
inputs_[n] = FormalParameter(
std::move(name),
std::move(description),
std::move(type_str),
param_option);
return *this;
}

Expand All @@ -491,7 +505,11 @@ OpSchema& OpSchema::Output(
if (int(outputs_.size()) <= n) {
outputs_.resize(n + 1);
}
outputs_[n] = FormalParameter(std::move(name), std::move(description), std::move(type_str), param_option);
outputs_[n] = FormalParameter(
std::move(name),
std::move(description),
std::move(type_str),
param_option);
return *this;
}

Expand Down Expand Up @@ -520,8 +538,8 @@ OpSchema& OpSchema::TypeConstraint(
}
type_constraints_.insert(
std::make_pair(type_str, std::make_pair(d, description)));
type_constraint_params_.push_back(
TypeConstraintParam(std::move(type_str), std::move(constraints), std::move(description)));
type_constraint_params_.push_back(TypeConstraintParam(
std::move(type_str), std::move(constraints), std::move(description)));
return *this;
}

Expand Down
7 changes: 4 additions & 3 deletions onnx/test/c++/function_get_test.cc
Expand Up @@ -11,10 +11,11 @@ TEST(FunctionAPITest, Get_All_Functions) {
Common::Status status = function_registry.GetFunctions("", &temp_map);
size_t input_size = temp_map.size();
EXPECT_EQ(input_size, 1);
EXPECT_EQ(temp_map.count("FC"), 1);
auto temp_iter = temp_map.find("FC");
EXPECT_EQ(temp_map.count("MeanVarianceNormalization"), 1);
auto temp_iter = temp_map.find("MeanVarianceNormalization");
EXPECT_EQ(
temp_iter->second->doc_string(), "this is a full connection function.");
temp_iter->second->doc_string(),
"A MeanVarianceNormalization Function: Perform mean variance normalization on the input tensor X");
}
} // namespace Test
} // namespace ONNX_NAMESPACE

0 comments on commit 1274128

Please sign in to comment.