Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 59 additions & 9 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
from onnx import helper, numpy_helper, TensorProto, OperatorSetIdProto
from parameterized import parameterized

from backend_test_base import Tf2OnnxBackendTestBase
from common import unittest_main, group_nodes_by_type, check_opset_min_version, check_opset_max_version, get_test_config
from tf2onnx import utils, constants
Expand Down Expand Up @@ -309,21 +310,31 @@ def test_transpose_dequantize_with_axis(self, shape, perm_input, perm_output):
model_proto, remaining_transpose_num=0)

@parameterized.expand([
((2, 3, 4, 5), [1, 2, 1, 2], (1, 2, 2, 1), [0, 2, 3, 1], [0, 3, 1, 2]),
((2, 3, 4, 5, 6), [1, 2, 1, 2, 1], (1, 1, 2, 1, 2), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
([2, 3, 4, 5], [1, 2, 1, 2], [1], [0, 2, 3, 1], [0, 3, 1, 2]),
([2, 3, 4, 5], [1, 2, 1, 2], [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
([2, 3, 4, 5], [1, 2, 1, 2], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [2], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [2, 3], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [0, 1, 2, 3, 4], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
])
@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'start' and 'ends'")
def test_transpose_slice(self, input_shape, slice_size, output_shape, perm_input, perm_output):
starts = np.array([0] * len(input_shape), dtype=np.int64)
ends = np.array(slice_size, dtype=np.int64)
axes = np.array(list(range(len(input_shape))), dtype=np.int64)
@check_opset_max_version(9, "Slice in opset 9 takes 'axes', 'starts' and 'ends' as attributes")
def test_transpose_slice(self, input_shape, slice_size, axes, perm_input, perm_output):
axes = np.array(axes, dtype=np.int64)
starts = np.array([0] * axes.size, dtype=np.int64)
ends = []
for i in range(axes.size):
ends.append(slice_size[axes[i]])
ends = np.array(ends, dtype=np.int64)
output_shape = input_shape.copy()
for axis in axes:
output_shape[perm_input[axis]] = slice_size[axis]
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
node2 = helper.make_node("Slice", ["Y", "starts", "ends", "axes"], ["Z"], name="relu")
node2 = helper.make_node("Slice", ["Y"], ["Z"], starts=starts, ends=ends, axes=axes, name="slice")
node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")

graph = helper.make_graph(
[node1, node2, node3],
"relu-test",
"slice-test",
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
[helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)],
[
Expand All @@ -337,6 +348,45 @@ def test_transpose_slice(self, input_shape, slice_size, output_shape, perm_input
self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)},
model_proto, remaining_transpose_num=0)

@parameterized.expand([
    ([2, 3, 4, 5], [1, 2, 1, 2], [1], [0, 2, 3, 1], [0, 3, 1, 2]),
    ([2, 3, 4, 5], [1, 2, 1, 2], [1, 2], [0, 2, 3, 1], [0, 3, 1, 2]),
    ([2, 3, 4, 5], [1, 2, 1, 2], [0, 1, 2, 3], [0, 2, 3, 1], [0, 3, 1, 2]),
    ([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [2], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
    ([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [2, 3], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
    ([2, 3, 4, 5, 6], [1, 2, 1, 2, 1], [0, 1, 2, 3, 4], [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
])
@check_opset_min_version(10, "Slice in opset 10 can accept dynamic 'start' and 'ends'")
def test_transpose_slice_opset_10(self, input_shape, slice_size, axes, perm_input, perm_output):
    """Check that Transpose -> Slice -> Transpose is fully folded away (opset >= 10).

    Builds a graph where a Slice (with const 'starts'/'ends'/'axes' initializer
    inputs, as required from opset 10 on) sits between two inverse Transposes,
    then asserts the transpose optimizer removes both Transpose nodes.

    Parameters (per parameterized case):
        input_shape: shape of graph input X.
        slice_size: per-axis target extent used to compute 'ends'.
        axes: axes (in the transposed tensor) that the Slice restricts.
        perm_input / perm_output: the two inverse permutations.
    """
    axes = np.array(axes, dtype=np.int32)
    # Slice starts at 0 on every sliced axis and stops at slice_size[axis];
    # axes not listed keep their full extent.
    starts = np.array([0] * axes.size, dtype=np.int32)
    ends = np.array([slice_size[axis] for axis in axes], dtype=np.int32)
    # Expected output shape: the Slice acts on the transposed tensor, so each
    # sliced axis maps back to position perm_input[axis] of the original shape.
    output_shape = input_shape.copy()
    for axis in axes:
        output_shape[perm_input[axis]] = slice_size[axis]
    node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm_input, name="trans_1")
    # From opset 10, starts/ends/axes are inputs (here const initializers),
    # not attributes as in opset <= 9.
    node2 = helper.make_node("Slice", ["Y", "starts", "ends", "axes"], ["Z"], name="slice")
    node3 = helper.make_node("Transpose", ["Z"], ["Z1"], perm=perm_output, name="trans_2")

    graph = helper.make_graph(
        [node1, node2, node3],
        "slice-test",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
        [helper.make_tensor_value_info("Z1", TensorProto.FLOAT, output_shape)],
        [
            helper.make_tensor("starts", TensorProto.INT32, starts.shape, starts),
            helper.make_tensor("ends", TensorProto.INT32, ends.shape, ends),
            helper.make_tensor("axes", TensorProto.INT32, axes.shape, axes)
        ]
    )

    model_proto = self.make_model(graph, producer_name="onnx-tests")
    self.run_transpose_compare(["Z1"], {"X": np.random.randn(*input_shape).astype(np.float32)},
                               model_proto, remaining_transpose_num=0)

@parameterized.expand([
((2, 3, 4, 5), (2, 4, 5, 3), [0, 2, 3, 1], [0, 3, 1, 2]),
((2, 3, 4, 5, 6), (2, 4, 5, 6, 3), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]),
Expand Down
40 changes: 21 additions & 19 deletions tf2onnx/optimizer/transpose_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,25 +712,27 @@ def _slice_handler(self, trans, node):
if not axes_values:
return False
axes = axes_values.ints
if axes == list(range(trans_rank)):
new_axes = NCHW_TO_NHWC if trans_rank == 4 else NCDHW_TO_NDHWC
node.set_attr("axes", new_axes)
return self._switch_transpose_and_node(node, trans)
else: # in opset 10, axes is input instead of an attribute.
if len(node.inputs) >= 4 and node.inputs[3].is_const():
axes = node.inputs[3].get_tensor_value(as_list=True)
if axes == list(range(trans_rank)):
axes = NCHW_TO_NHWC if trans_rank == 4 else NCDHW_TO_NDHWC
# axes node might be shared
new_axes = np.array(axes, dtype=np.int64)
if self._nodes_has_single_consumer_node([node.inputs[3]]):
node.inputs[3].set_tensor_value(new_axes)
else:
new_axes_const = self._g.make_const(
utils.make_name(node.inputs[3].name), new_axes
)
self._g.replace_input(node, node.input[3], new_axes_const.output[0], 3)
return self._switch_transpose_and_node(node, trans)
perm = NCHW_TO_NHWC if trans_rank == 4 else NCDHW_TO_NDHWC
new_axes = [perm[axes[i]] for i in range(len(axes))]
node.set_attr("axes", new_axes)
return self._switch_transpose_and_node(node, trans)
# in opset 10, axes is input instead of an attribute.
if len(node.inputs) >= 4 and node.inputs[3].is_const():
axes = node.inputs[3].get_tensor_value(as_list=False)
dtype = axes.dtype
axes = axes.tolist()
perm = NCHW_TO_NHWC if trans_rank == 4 else NCDHW_TO_NDHWC
axes = [perm[axes[i]] for i in range(len(axes))]
# axes node might be shared
new_axes = np.array(axes, dtype=dtype)
if self._nodes_has_single_consumer_node([node.inputs[3]]):
node.inputs[3].set_tensor_value(new_axes)
else:
new_axes_const = self._g.make_const(
utils.make_name(node.inputs[3].name), new_axes
)
self._g.replace_input(node, node.input[3], new_axes_const.output[0], 3)
return self._switch_transpose_and_node(node, trans)
return False

def _quantize_handler(self, trans, node):
Expand Down