diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index f6404f686..650d04463 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -7,8 +7,6 @@ from __future__ import print_function from __future__ import unicode_literals -import unittest - import numpy as np from onnx import helper, TensorProto, OperatorSetIdProto from tf2onnx import utils @@ -1099,10 +1097,9 @@ def test_transpose_back_to_back_non_const(self): model_proto = self.make_model(graph, producer_name="onnx-tests") self.run_transpose_compare(["res"], {"u": np.random.randn(5, 5, 5, 5).astype(np.float32)}, - model_proto, remaining_transpose_num=2) + model_proto, remaining_transpose_num=1) - # @check_opset_min_version(9, "string type tensor") - @unittest.skip("FIXME: disabled because of crash on linux/ortnightly") + @check_opset_min_version(9, "string type tensor") def test_cast_back_to_back_non_const_mixed_types(self): node0 = helper.make_node("Cast", ["u"], ["v"], to=11, name="cast_0") # double node1 = helper.make_node("Cast", ["v"], ["w"], to=6, name="cast_1") # int32 @@ -1113,11 +1110,13 @@ def test_cast_back_to_back_non_const_mixed_types(self): node5 = helper.make_node("Cast", ["w2"], ["res2"], to=7, name="cast_5") # int64 node6 = helper.make_node("Cast", ["x"], ["x2"], to=9, name="cast_6") # bool - node7 = helper.make_node("Cast", ["x2"], ["x3"], to=8, name="cast_7") # string - node8 = helper.make_node("Cast", ["x3"], ["res3"], to=3, name="cast_8") # int8 + # TODO: uncomment below after fix + # https://github.com/microsoft/onnxruntime/issues/2338 + # node7 = helper.make_node("Cast", ["x2"], ["x3"], to=8, name="cast_7") # string + node8 = helper.make_node("Cast", ["x2"], ["res3"], to=3, name="cast_8") # int8 graph = helper.make_graph( - [node0, node1, node2, node3, node4, node5, node6, node7, node8], + [node0, node1, node2, node3, node4, node5, node6, node8], "test-cast-back-to-back-non-const", [helper.make_tensor_value_info("u", TensorProto.FLOAT, (1, 2, 3))], [helper.make_tensor_value_info("res", TensorProto.INT64, (1, 2, 3)), diff --git a/tf2onnx/optimizer/back_to_back_optimizer.py b/tf2onnx/optimizer/back_to_back_optimizer.py index 8d30fa9a7..e9db8b07e 100644 --- a/tf2onnx/optimizer/back_to_back_optimizer.py +++ b/tf2onnx/optimizer/back_to_back_optimizer.py @@ -121,9 +121,8 @@ def _optimize_cast(g, node, consumer_nodes): g.remove_node(node.name) return q2 - # TODO: reactivate after fixing interference with transpose_optimizer @staticmethod - @_register_func("_Transpose") + @_register_func("Transpose") def _optimize_transpose(g, node, consumer_nodes): t1 = list(node.get_attr('perm').ints) q2 = [] @@ -132,13 +131,16 @@ def _optimize_transpose(g, node, consumer_nodes): t2 = list(node2.get_attr('perm').ints) new_perm = [t1[i] for i in t2] # check if node2 can be removed. otherwise only update - if new_perm == list(range(len(t2))) \ - and not set(node2.output) & set(g.outputs): + if new_perm == list(range(len(t2))): # both nodes can be deleted + shape = g.get_shape(node2.output[0]) + dtype = g.get_dtype(node2.output[0]) node2_consumers = g.find_output_consumers(node2.output[0]) - for consumer in node2_consumers: - consumer.input[0] = node.input[0] + g.replace_all_inputs(node2_consumers, node2.output[0], node.input[0]) g.remove_node(node2.name) + if set(node2.output) & set(g.outputs): + g.make_node("Identity", [node.input[0]], + outputs=node2.output, shapes=[shape], dtypes=[dtype]) else: node2.set_attr('perm', [t1[i] for i in t2]) q2.append(node2.output[0])