Skip to content

Commit

Permalink
Merge pull request #1238 from onnx/tom/multiscanout
Browse files Browse the repository at this point in the history
Add support for multiple scan outputs
  • Loading branch information
TomWildenhain-Microsoft committed Dec 17, 2020
2 parents 06c1a32 + 0013b3a commit ac2f675
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 12 deletions.
34 changes: 34 additions & 0 deletions tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,40 @@ def b(i, out_ta):
output_names_with_port = ["i:0", "output_ta:0"]
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)

def test_while_loop_with_multi_scan_outputs(self):
def func(i, inputs1, inputs2):
inputs1_ = tf.identity(inputs1)
inputs2_ = tf.identity(inputs2)
input_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs1_)
input_ta2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs2_)
output_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
output_ta2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)

c = lambda i, *_: tf.logical_and(tf.less(i, 10), i >= 0)

def b(i, out_ta, out_ta2):
new_i = tf.add(i, 1)
x = input_ta.read(i)
y = input_ta2.read(i)
z = x + 3 + y
p = x * y * 2
out_ta_new = out_ta.write(i, z)
out_ta_new2 = out_ta2.write(i, p)
return new_i, out_ta_new, out_ta_new2

i_final, out_final, out_final2 = tf.while_loop(c, b, [i, output_ta, output_ta2])
i_final_ = tf.identity(i_final, name="i")
out_final_ = tf.identity(out_final.stack(), name="output_ta")
out_final2_ = tf.identity(out_final2.stack(), name="output_ta2")
return i_final_, out_final_, out_final2_

input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
feed_dict = {"input_1:0": np.array(0, dtype=np.int32),
"input_2:0": np.array([2.0, 16.0, 5.0, 1.6, 5.0, 6.0, 7.0, 8.0, 9.0, 10.], dtype=np.float32),
"input_3:0": np.array([1.0, 2.0, 3.0, 4.0, 5.0, 16.0, 7.0, 8.0, 9.0, 10.], dtype=np.float32)}
output_names_with_port = ["i:0", "output_ta:0", "output_ta2:0"]
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)

@check_onnxruntime_min_version(
"0.5.0",
"disable this case due to onnxruntime loop issue: https://github.com/microsoft/onnxruntime/issues/1272"
Expand Down
23 changes: 11 additions & 12 deletions tf2onnx/onnx_opset/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def version_7(cls, ctx, node, **kwargs):
del output_names[idx]
del body.outputs[idx]

removed_scan_outputs = {}
scan_output_names = []
# remove tensor array that are passed in to the loop
for idx, n in reversed(to_remove):
ctx.remove_node(n.name)
Expand All @@ -430,19 +430,15 @@ def version_7(cls, ctx, node, **kwargs):
del body.func_inputs[idx]
del cond_graph.func_inputs[idx]
del tf_while_inputs[idx]
# save the index of the scan output
removed_scan_outputs[body.outputs[idx]] = idx
scan_output_names.append(body.outputs[idx])
del body.outputs[idx]
# FIXME: Output shapes may be in wrong order if there are multiple scan outputs
output_shapes.append(output_shapes[idx])
output_dtypes.append(output_dtypes[idx])
output_names.append(output_names[idx])
del output_shapes[idx]
del output_dtypes[idx]
del output_names[idx]

utils.make_sure(len(removed_scan_outputs) <= 1, "converter only supports while loops with a single scan output")

ctx.remove_node(node.name)

# In onnx 'cond' is a variable, not a function. We need to inject the subgraph into the main graph
Expand All @@ -467,7 +463,7 @@ def version_7(cls, ctx, node, **kwargs):
ctx.replace_all_inputs(k, v) # ops=ctx.get_nodes()

wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, removed_scan_outputs)
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, scan_output_names)

# if there was a tensorflow variant type, bind in a real type here
# FIXME: I don't think this is needed anymore
Expand All @@ -477,7 +473,7 @@ def version_7(cls, ctx, node, **kwargs):


def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
output_dtypes, scope, parent, cond_graph, tf_while_inputs, removed_scan_outputs):
output_dtypes, scope, parent, cond_graph, tf_while_inputs, scan_output_names):
"""Wire subgraph graph into main."""
remove_parents = []
to_remove = []
Expand Down Expand Up @@ -521,9 +517,10 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
g.replace_inputs(node, [node.input[2]])
scan_outputs.append(node.output[0])

if len(scan_outputs) != len(removed_scan_outputs):
if len(scan_outputs) != len(scan_output_names):
raise ValueError("While loop couldn't find scan output index for nodes")

names_to_scan_outputs = {}
for output in scan_outputs:
last_output = output
consumers = g.find_output_consumers(last_output)
Expand All @@ -533,10 +530,12 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
raise ValueError("While loop couldn't find scan output index for node " + node.name)
last_output = node.output[0]
consumers = g.find_output_consumers(last_output)
if last_output not in removed_scan_outputs:
if last_output not in scan_output_names:
raise ValueError("While loop couldn't find scan output index for node " + node.name)
# TODO: store index to ensure scan outputs are in correct order for multiple outputs
# initial_output_index = removed_scan_outputs[last_output]
names_to_scan_outputs[last_output] = output

# Reorder scan outputs
scan_outputs = [names_to_scan_outputs[name] for name in scan_output_names]

# remove all nodes feeding to TensorListSetItem's reserved tensor
while remove_parents:
Expand Down

0 comments on commit ac2f675

Please sign in to comment.