Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from __future__ import print_function
from __future__ import unicode_literals

import unittest

import numpy as np
from onnx import helper, TensorProto, OperatorSetIdProto
from tf2onnx import utils
Expand Down Expand Up @@ -1099,10 +1097,9 @@ def test_transpose_back_to_back_non_const(self):

model_proto = self.make_model(graph, producer_name="onnx-tests")
self.run_transpose_compare(["res"], {"u": np.random.randn(5, 5, 5, 5).astype(np.float32)},
model_proto, remaining_transpose_num=2)
model_proto, remaining_transpose_num=1)

# @check_opset_min_version(9, "string type tensor")
@unittest.skip("FIXME: disabled because of crash on linux/ortnightly")
@check_opset_min_version(9, "string type tensor")
def test_cast_back_to_back_non_const_mixed_types(self):
node0 = helper.make_node("Cast", ["u"], ["v"], to=11, name="cast_0") # double
node1 = helper.make_node("Cast", ["v"], ["w"], to=6, name="cast_1") # int32
Expand All @@ -1113,11 +1110,13 @@ def test_cast_back_to_back_non_const_mixed_types(self):
node5 = helper.make_node("Cast", ["w2"], ["res2"], to=7, name="cast_5") # int64

node6 = helper.make_node("Cast", ["x"], ["x2"], to=9, name="cast_6") # bool
node7 = helper.make_node("Cast", ["x2"], ["x3"], to=8, name="cast_7") # string
node8 = helper.make_node("Cast", ["x3"], ["res3"], to=3, name="cast_8") # int8
# TODO: uncomment below after fix
# https://github.com/microsoft/onnxruntime/issues/2338
# node7 = helper.make_node("Cast", ["x2"], ["x3"], to=8, name="cast_7") # string
node8 = helper.make_node("Cast", ["x2"], ["res3"], to=3, name="cast_8") # int8

graph = helper.make_graph(
[node0, node1, node2, node3, node4, node5, node6, node7, node8],
[node0, node1, node2, node3, node4, node5, node6, node8],
"test-cast-back-to-back-non-const",
[helper.make_tensor_value_info("u", TensorProto.FLOAT, (1, 2, 3))],
[helper.make_tensor_value_info("res", TensorProto.INT64, (1, 2, 3)),
Expand Down
14 changes: 8 additions & 6 deletions tf2onnx/optimizer/back_to_back_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,8 @@ def _optimize_cast(g, node, consumer_nodes):
g.remove_node(node.name)
return q2

# TODO: reactivate after fixing interference with transpose_optimizer
@staticmethod
@_register_func("_Transpose")
@_register_func("Transpose")
def _optimize_transpose(g, node, consumer_nodes):
t1 = list(node.get_attr('perm').ints)
q2 = []
Expand All @@ -132,13 +131,16 @@ def _optimize_transpose(g, node, consumer_nodes):
t2 = list(node2.get_attr('perm').ints)
new_perm = [t1[i] for i in t2]
# check if node2 can be removed. otherwise only update
if new_perm == list(range(len(t2))) \
and not set(node2.output) & set(g.outputs):
if new_perm == list(range(len(t2))):
# both nodes can be deleted
shape = g.get_shape(node2.output[0])
dtype = g.get_dtype(node2.output[0])
node2_consumers = g.find_output_consumers(node2.output[0])
for consumer in node2_consumers:
consumer.input[0] = node.input[0]
g.replace_all_inputs(node2_consumers, node2.output[0], node.input[0])
g.remove_node(node2.name)
if set(node2.output) & set(g.outputs):
g.make_node("Identity", [node.input[0]],
outputs=node2.output, shapes=[shape], dtypes=[dtype])
else:
node2.set_attr('perm', [t1[i] for i in t2])
q2.append(node2.output[0])
Expand Down