Skip to content

Commit

Permalink
add a check for loop state variables
Browse files Browse the repository at this point in the history
Signed-off-by: Chun-Wei Chen <jacky82226@gmail.com>
  • Loading branch information
jcwchen committed Feb 9, 2022
1 parent 803b8f3 commit 6aff395
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions onnx/defs/controlflow/defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,18 @@ void ClearShape(TypeProto& input_type) {

void LoopInferenceFunction(InferenceContext& ctx) {
auto num_inputs = ctx.getNumInputs();
auto num_outputs = ctx.getNumOutputs();
assert(num_inputs >= 2);
auto num_loop_state_vars = num_inputs - 2; // skip 'M' and 'cond'

if (num_loop_state_vars > num_outputs) {
fail_type_inference(
"The number of loop state variables for input is larger than the number of outputs. ",
num_loop_state_vars,
" > ",
num_outputs);
}

std::vector<const TypeProto*> subgraph_input_types;
subgraph_input_types.reserve(num_inputs);

Expand Down Expand Up @@ -328,8 +337,6 @@ void LoopInferenceFunction(InferenceContext& ctx) {

// if empty(), assume inferencing was skipped
if (!subgraph_output_types.empty()) {
auto num_outputs = ctx.getNumOutputs();

// subgraph outputs the condition value first but that is only used
// internally and not returned by Loop.
if (subgraph_output_types.size() != num_outputs + 1) {
Expand Down

0 comments on commit 6aff395

Please sign in to comment.