Skip to content

Commit

Permalink
Checker should validate the node's inputs/outputs have names when its…
Browse files Browse the repository at this point in the history
… formal parameter is Variadic (onnx#3979)

* check whether there is any empty string when Variadic

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* add 2 tests

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* typo

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* fix typecheck

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* add a check for loop state variables

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>

* Revert "add a check for loop state variables"

This reverts commit 6aff395.

Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
  • Loading branch information
jcwchen authored and Bjarke Roune committed May 6, 2023
1 parent 65b71f1 commit a7d1b11
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 12 deletions.
31 changes: 20 additions & 11 deletions onnx/defs/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,7 @@ void OpSchema::Verify(const NodeProto& node) const {
// Check the values of inputs / outputs
for (int in_idx = 0; in_idx < node.input_size(); ++in_idx) {
if (in_idx >= static_cast<int>(inputs_.size())) {
if (!inputs_.empty() && Variadic == inputs_.back().GetOption()) {
// The last input formal parameter should be variadic.
break;
} else {
if (inputs_.empty() || Variadic != inputs_.back().GetOption()) {
fail_check(
"Node (",
node.name(),
Expand All @@ -200,17 +197,22 @@ void OpSchema::Verify(const NodeProto& node) const {
") in op definition.");
}
}
if (node.input(in_idx).empty() && (Single == inputs_[in_idx].GetOption())) {
if ((in_idx >= static_cast<int>(inputs_.size()) && Variadic == inputs_.back().GetOption()) ||
Variadic == inputs_[in_idx].GetOption()) {
do {
if (node.input(in_idx).empty()) {
fail_check(
"Node (", node.name(), ")'s input ", in_idx, " is marked Variadic but has an empty string in the graph");
}
} while (++in_idx < node.input_size());
} else if (node.input(in_idx).empty() && (Single == inputs_[in_idx].GetOption())) {
fail_check("Node (", node.name(), ")'s input ", in_idx, " is marked single but has an empty string in the graph");
}
}

for (int out_idx = 0; out_idx < node.output_size(); ++out_idx) {
if (out_idx >= static_cast<int>(outputs_.size())) {
if (!outputs_.empty() && Variadic == outputs_.back().GetOption()) {
// The last output formal parameter should be variadic.
break;
} else {
if (outputs_.empty() || Variadic != outputs_.back().GetOption()) {
fail_check(
"Node (",
node.name(),
Expand All @@ -221,8 +223,15 @@ void OpSchema::Verify(const NodeProto& node) const {
") in op definition.");
}
}

if (node.output(out_idx).empty() && (Single == outputs_[out_idx].GetOption())) {
if ((out_idx >= static_cast<int>(outputs_.size()) && Variadic == outputs_.back().GetOption()) ||
Variadic == outputs_[out_idx].GetOption()) {
do {
if (node.output(out_idx).empty()) {
fail_check(
"Node (", node.name(), ")'s output ", out_idx, " is marked Variadic but has an empty string in the graph");
}
} while (++out_idx < node.output_size());
} else if (node.output(out_idx).empty() && (Single == outputs_[out_idx].GetOption())) {
fail_check(
"Node (", node.name(), ")'s output ", out_idx, " is marked single but has an empty string in the graph");
}
Expand Down
120 changes: 119 additions & 1 deletion onnx/test/checker_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import unittest

from typing import Sequence, Text
from typing import List, Sequence, Text
import numpy as np # type: ignore

from onnx import checker, helper, numpy_helper, shape_inference
Expand Down Expand Up @@ -660,6 +660,124 @@ def test_loop_with_same_initializer_input_above_ir4(self) -> None:
)
self.assertRaises(shape_inference.InferenceError, checker.check_model, model, True)

def _contruct_loop_model(self, inputs_list: List[Text], outputs_list: List[Text]) -> onnx.ModelProto:
    """Build a model containing a single Loop node with the given name lists.

    The loop body slices a constant tensor, accumulates the slice into the
    carried state ``y`` and scans the running value. Only the outer Loop
    node's input/output name lists vary, letting callers inject empty names
    to exercise the checker's variadic-name validation.
    """
    # Value infos forming the loop-body graph interface
    # (iteration counter, condition in/out, carried state, scan output).
    iter_count = onnx.helper.make_tensor_value_info('iter_count', onnx.TensorProto.INT64, [])
    cond_in = onnx.helper.make_tensor_value_info('cond_in', onnx.TensorProto.BOOL, [])
    cond_out = onnx.helper.make_tensor_value_info('cond_out', onnx.TensorProto.BOOL, [])
    y_in = onnx.helper.make_tensor_value_info('y_in', onnx.TensorProto.FLOAT, [1])
    y_out = onnx.helper.make_tensor_value_info('y_out', onnx.TensorProto.FLOAT, [1])
    scan_out = onnx.helper.make_tensor_value_info('scan_out', onnx.TensorProto.FLOAT, [1])

    # Body nodes, in the same topological order the graph is built with:
    # pass the condition through, slice x[iter_count : iter_count + 1],
    # add the slice to the carried state, and scan the new state.
    body_nodes = [
        onnx.helper.make_node('Identity', inputs=['cond_in'], outputs=['cond_out']),
        onnx.helper.make_node(
            'Constant',
            inputs=[],
            outputs=['x'],
            value=onnx.helper.make_tensor(
                name='const_tensor_x',
                data_type=onnx.TensorProto.FLOAT,
                dims=[5],
                vals=[1., 2., 3., 4., 5.],
            ),
        ),
        onnx.helper.make_node(
            'Constant',
            inputs=[],
            outputs=['one'],
            value=onnx.helper.make_tensor(
                name='const_tensor_one',
                data_type=onnx.TensorProto.INT64,
                dims=(),
                vals=[1],
            ),
        ),
        onnx.helper.make_node('Add', inputs=['iter_count', 'one'], outputs=['end']),
        onnx.helper.make_node('Unsqueeze', inputs=['iter_count'], outputs=['slice_start'], axes=[0]),
        onnx.helper.make_node('Unsqueeze', inputs=['end'], outputs=['slice_end'], axes=[0]),
        onnx.helper.make_node('Slice', inputs=['x', 'slice_start', 'slice_end'], outputs=['slice_out']),
        onnx.helper.make_node('Add', inputs=['y_in', 'slice_out'], outputs=['y_out']),
        onnx.helper.make_node('Identity', inputs=['y_out'], outputs=['scan_out']),
    ]

    loop_body = onnx.helper.make_graph(
        body_nodes,
        'loop_body',
        [iter_count, cond_in, y_in],
        [cond_out, y_out, scan_out],
    )

    # Outer Loop node: its input/output names come straight from the caller.
    loop_node = onnx.helper.make_node(
        'Loop',
        inputs=inputs_list,
        outputs=outputs_list,
        body=loop_body,
    )

    return helper.make_model(
        opset_imports=[onnx.helper.make_opsetid("", 11)],
        graph=helper.make_graph(
            name='test-loop',
            inputs=[
                helper.make_tensor_value_info('trip_count', TensorProto.INT64, shape=[5]),
                helper.make_tensor_value_info('cond', TensorProto.BOOL, shape=[1]),
                helper.make_tensor_value_info('y', TensorProto.FLOAT, shape=[1]),
            ],
            outputs=[
                helper.make_tensor_value_info('cond', TensorProto.FLOAT, shape=[13]),
                helper.make_tensor_value_info('res_scan', TensorProto.FLOAT, shape=[5, 1]),
            ],
            nodes=[loop_node],
        ),
    )

def test_loop_with_empty_input(self) -> None:
    """An empty string among the Loop node's variadic inputs must be rejected."""
    model = self._contruct_loop_model(['trip_count', 'cond', ''], ['res_y', 'res_scan'])
    with self.assertRaises(checker.ValidationError):
        checker.check_model(model)

def test_loop_with_empty_output(self) -> None:
    """An empty string among the Loop node's variadic outputs must be rejected."""
    model = self._contruct_loop_model(['trip_count', 'cond', 'y'], ['', 'res_scan'])
    with self.assertRaises(checker.ValidationError):
        checker.check_model(model)


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit a7d1b11

Please sign in to comment.