revert onnx#3979 (onnx#4283)
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
jcwchen authored and Bjarke Roune committed May 6, 2023
1 parent 9f83f43 commit d80284f
Showing 2 changed files with 12 additions and 143 deletions.
35 changes: 11 additions & 24 deletions onnx/defs/schema.cc
@@ -186,7 +186,10 @@ void OpSchema::Verify(const NodeProto& node) const {
   // Check the values of inputs / outputs
   for (int in_idx = 0; in_idx < node.input_size(); ++in_idx) {
     if (in_idx >= static_cast<int>(inputs_.size())) {
-      if (inputs_.empty() || Variadic != inputs_.back().GetOption()) {
+      if (!inputs_.empty() && Variadic == inputs_.back().GetOption()) {
+        // The last input formal parameter should be variadic.
+        break;
+      } else {
         fail_check(
             "Node (",
             node.name(),
@@ -197,22 +200,17 @@ void OpSchema::Verify(const NodeProto& node) const {
") in op definition.");
}
}
if ((in_idx >= static_cast<int>(inputs_.size()) && Variadic == inputs_.back().GetOption()) ||
Variadic == inputs_[in_idx].GetOption()) {
do {
if (node.input(in_idx).empty()) {
fail_check(
"Node (", node.name(), ")'s input ", in_idx, " is marked Variadic but has an empty string in the graph");
}
} while (++in_idx < node.input_size());
} else if (node.input(in_idx).empty() && (Single == inputs_[in_idx].GetOption())) {
if (node.input(in_idx).empty() && (Single == inputs_[in_idx].GetOption())) {
fail_check("Node (", node.name(), ")'s input ", in_idx, " is marked single but has an empty string in the graph");
}
}

for (int out_idx = 0; out_idx < node.output_size(); ++out_idx) {
if (out_idx >= static_cast<int>(outputs_.size())) {
if (outputs_.empty() || Variadic != outputs_.back().GetOption()) {
if (!outputs_.empty() && Variadic == outputs_.back().GetOption()) {
// The last output formal parameter should be variadic.
break;
} else {
fail_check(
"Node (",
node.name(),
@@ -223,19 +221,8 @@ void OpSchema::Verify(const NodeProto& node) const {
") in op definition.");
}
}
if ((out_idx >= static_cast<int>(outputs_.size()) && Variadic == outputs_.back().GetOption()) ||
Variadic == outputs_[out_idx].GetOption()) {
do {
if (node.output(out_idx).empty()) {
fail_check(
"Node (",
node.name(),
")'s output ",
out_idx,
" is marked Variadic but has an empty string in the graph");
}
} while (++out_idx < node.output_size());
} else if (node.output(out_idx).empty() && (Single == outputs_[out_idx].GetOption())) {

if (node.output(out_idx).empty() && (Single == outputs_[out_idx].GetOption())) {
fail_check(
"Node (", node.name(), ")'s output ", out_idx, " is marked single but has an empty string in the graph");
}
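Net effect of the C++ change: once a node's actual inputs run past the declared formal parameters, a trailing Variadic formal simply absorbs them, and an empty input name is only rejected when its formal parameter is Single. Below is a rough Python transliteration of the restored input loop, for illustration only; FormalParameter, SINGLE, OPTIONAL and VARIADIC are stand-ins for the C++ schema types, not the ONNX Python API.

# Rough sketch of the restored input-checking loop in OpSchema::Verify.
from dataclasses import dataclass
from typing import List

SINGLE, OPTIONAL, VARIADIC = "Single", "Optional", "Variadic"


@dataclass
class FormalParameter:
    name: str
    option: str  # one of SINGLE, OPTIONAL, VARIADIC


def verify_inputs(node_name: str, node_inputs: List[str], formals: List[FormalParameter]) -> None:
    for in_idx, actual in enumerate(node_inputs):
        if in_idx >= len(formals):
            if formals and formals[-1].option == VARIADIC:
                # Extra trailing inputs are absorbed by a trailing variadic formal.
                break
            raise ValueError(f"Node ({node_name}) has more inputs ({len(node_inputs)}) "
                             f"than declared ({len(formals)}) in op definition.")
        # After the revert, only Single inputs must be non-empty; empty names in
        # Optional or Variadic positions are no longer rejected here.
        if actual == "" and formals[in_idx].option == SINGLE:
            raise ValueError(f"Node ({node_name})'s input {in_idx} is marked single "
                             "but has an empty string in the graph")


# A Loop-like signature (Optional, Optional, Variadic) with an empty trailing
# name now passes this particular check.
verify_inputs("loop_node", ["trip_count", "cond", ""],
              [FormalParameter("M", OPTIONAL),
               FormalParameter("cond", OPTIONAL),
               FormalParameter("v_initial", VARIADIC)])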
120 changes: 1 addition & 119 deletions onnx/test/checker_test.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import unittest
 
-from typing import List, Sequence
+from typing import Sequence
 import numpy as np  # type: ignore
 
 from onnx import checker, helper, numpy_helper, shape_inference
@@ -660,124 +660,6 @@ def test_loop_with_same_initializer_input_above_ir4(self) -> None:
         )
         self.assertRaises(shape_inference.InferenceError, checker.check_model, model, True)
 
-    def _contruct_loop_model(self, inputs_list: List[str], outputs_list: List[str]) -> onnx.ModelProto:
-        y_in = onnx.helper.make_tensor_value_info('y_in', onnx.TensorProto.FLOAT, [1])
-        y_out = onnx.helper.make_tensor_value_info('y_out', onnx.TensorProto.FLOAT, [1])
-        scan_out = onnx.helper.make_tensor_value_info('scan_out', onnx.TensorProto.FLOAT, [1])
-        cond_in = onnx.helper.make_tensor_value_info('cond_in', onnx.TensorProto.BOOL, [])
-        cond_out = onnx.helper.make_tensor_value_info('cond_out', onnx.TensorProto.BOOL, [])
-        iter_count = onnx.helper.make_tensor_value_info('iter_count', onnx.TensorProto.INT64, [])
-
-        x_const_node = onnx.helper.make_node(
-            'Constant',
-            inputs=[],
-            outputs=['x'],
-            value=onnx.helper.make_tensor(
-                name='const_tensor_x',
-                data_type=onnx.TensorProto.FLOAT,
-                dims=[5],
-                vals=[1., 2., 3., 4., 5.],
-            )
-        )
-
-        one_const_node = onnx.helper.make_node(
-            'Constant',
-            inputs=[],
-            outputs=['one'],
-            value=onnx.helper.make_tensor(
-                name='const_tensor_one',
-                data_type=onnx.TensorProto.INT64,
-                dims=(),
-                vals=[1]
-            )
-        )
-
-        i_add_node = onnx.helper.make_node(
-            'Add',
-            inputs=['iter_count', 'one'],
-            outputs=['end']
-        )
-
-        start_unsqueeze_node = onnx.helper.make_node(
-            'Unsqueeze',
-            inputs=['iter_count'],
-            outputs=['slice_start'],
-            axes=[0]
-        )
-
-        end_unsqueeze_node = onnx.helper.make_node(
-            'Unsqueeze',
-            inputs=['end'],
-            outputs=['slice_end'],
-            axes=[0]
-        )
-
-        slice_node = onnx.helper.make_node(
-            'Slice',
-            inputs=['x', 'slice_start', 'slice_end'],
-            outputs=['slice_out']
-        )
-
-        y_add_node = onnx.helper.make_node(
-            'Add',
-            inputs=['y_in', 'slice_out'],
-            outputs=['y_out']
-        )
-
-        identity_node = onnx.helper.make_node(
-            'Identity',
-            inputs=['cond_in'],
-            outputs=['cond_out']
-        )
-
-        scan_identity_node = onnx.helper.make_node(
-            'Identity',
-            inputs=['y_out'],
-            outputs=['scan_out']
-        )
-
-        loop_body = onnx.helper.make_graph(
-            [identity_node, x_const_node, one_const_node, i_add_node,
-             start_unsqueeze_node, end_unsqueeze_node, slice_node, y_add_node,
-             scan_identity_node],
-            'loop_body',
-            [iter_count, cond_in, y_in],
-            [cond_out, y_out, scan_out]
-        )
-
-        node = onnx.helper.make_node(
-            'Loop',
-            inputs=inputs_list,
-            outputs=outputs_list,
-            body=loop_body
-        )
-
-        model = helper.make_model(
-            opset_imports=[onnx.helper.make_opsetid("", 11)],
-            graph=helper.make_graph(
-                name='test-loop',
-                inputs=[
-                    helper.make_tensor_value_info('trip_count', TensorProto.INT64, shape=[5]),
-                    helper.make_tensor_value_info('cond', TensorProto.BOOL, shape=[1]),
-                    helper.make_tensor_value_info('y', TensorProto.FLOAT, shape=[1]),
-                ],
-                outputs=[
-                    helper.make_tensor_value_info('cond', TensorProto.FLOAT, shape=[13]),
-                    helper.make_tensor_value_info('res_scan', TensorProto.FLOAT, shape=[5, 1])
-                ],
-                nodes=[node],
-            ),
-        )
-        return model
-
-    def test_loop_with_empty_input(self) -> None:
-        model = self._contruct_loop_model(['trip_count', 'cond', ''], ['res_y', 'res_scan'])
-        self.assertRaises(checker.ValidationError, checker.check_model, model)
-
-    def test_loop_with_empty_output(self) -> None:
-        model = self._contruct_loop_model(['trip_count', 'cond', 'y'], ['', 'res_scan'])
-        self.assertRaises(checker.ValidationError, checker.check_model, model)
-
 
 if __name__ == '__main__':
     unittest.main()
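The deleted tests built a Loop model with an empty name in a variadic input or output slot and expected checker.ValidationError. A quicker way to poke at the same schema-level check from Python is the node-level checker. The sketch below is not part of the commit: it assumes onnx.checker.check_node with its default context, uses 'Max' only because its single formal input is variadic, and the outcome still depends on whatever other checks the checker runs.

from onnx import checker, helper

# 'Max' declares one trailing Variadic input, so index 1 falls in the variadic
# range. Before this revert the empty name at index 1 tripped the
# "marked Variadic but has an empty string" check; after it, only Single
# inputs are required to be non-empty, so check_node should no longer reject
# the node for that particular reason.
node = helper.make_node('Max', inputs=['a', ''], outputs=['m'])
checker.check_node(node)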
