Skip to content

Commit

Permalink
Merge pull request #8493 from msakai/onnx-chainer-transpose-sequence-2
Browse files Browse the repository at this point in the history
Extend ONNX-Chainer's TransposeSequence converter to support more cases
  • Loading branch information
take-cheeze committed Dec 17, 2019
2 parents 8b0c4cc + b1a9f55 commit 9c83c8d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
31 changes: 22 additions & 9 deletions onnx_chainer/functions/array.py
Expand Up @@ -650,13 +650,26 @@ def convert_Rollaxis(func, opset_version, input_names, output_names, context):

def convert_TransposeSequence(
func, opset_version, input_names, output_names, context):
if len(input_names) == 1:
return onnx_helper.make_node(
'Split', input_names, output_names, axis=0),
elif len(output_names) == 1:
return onnx_helper.make_node(
'Concat', input_names, output_names, axis=0),
else:
if any(x.shape != func.inputs[0].shape for x in func.inputs):
raise ValueError(
'ONNX-Chainer can convert TransposeSequence only when input '
'or output length is 1')
'ONNX-Chainer can convert TransposeSequence only when all '
'inputs have same shape')
gb = onnx_helper.GraphBuilder()
n = func.inputs[0].shape[0]

concat_out = gb.op(
'Concat',
[gb.op('Unsqueeze', [name], axes=[0]) for name in input_names],
axis=0)

perm = list(range(len(func.inputs[0].shape) + 1))
perm[0], perm[1] = perm[1], perm[0]
transpose_out = gb.op('Transpose', [concat_out], perm=perm)

split_outs = gb.op('Split', [transpose_out], axis=0, num_outputs=n)
if n == 1:
split_outs = [split_outs]
for i, name in enumerate(split_outs):
gb.op_output_named('Squeeze', [name], [output_names[i]], axes=[0])

return gb.nodes()
2 changes: 2 additions & 0 deletions tests/onnx_chainer_tests/functions_tests/test_arrays.py
Expand Up @@ -528,6 +528,8 @@ def forward(self, x, indices):
{'in_shapes': [(3, 4)], 'name': 'transpose_sequence_single_input'},
{'in_shapes': [(1, 3), (1, 3)],
'name': 'transpose_sequence_single_output'},
{'in_shapes': [(2, 3), (2, 3), (2, 3), (2, 3)],
'name': 'transpose_sequence_same_shape'},
)
class TestTransposeSequence(ONNXModelTest):

Expand Down

0 comments on commit 9c83c8d

Please sign in to comment.