Skip to content

Commit

Permalink
Merge pull request #1184 from onnx/tom/ImproveSelectOp
Browse files Browse the repository at this point in the history
Allow conversion of Select op when condition has rank 1 and input has…
  • Loading branch information
TomWildenhain-Microsoft committed Nov 17, 2020
2 parents 2f0db61 + dc7ba67 commit 20ae214
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions tf2onnx/onnx_opset/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,14 +368,12 @@ def version_9(cls, ctx, node, **kwargs):
input_shape = ctx.get_shape(node.input[1])
if input_shape is None:
input_shape = ctx.get_shape(node.input[2])
if cond_shape is None or input_shape is None:
# Fallback if shape is unknown
cls.version_7(ctx, node, **kwargs)
return
node.type = "Where"
input_rank = len(input_shape)
input_rank = len(input_shape) if input_shape is not None else None
cond_rank = len(cond_shape) if cond_shape is not None else None
# if cond shape is 1-dimensional while input has higher rank, need to be reshaped to broadcast
if len(cond_shape) == 1 and input_rank > 1:
if cond_rank == 1 and input_rank != 1:
utils.make_sure(input_rank is not None, "input_rank unknown and cond_rank == 1")
broadcast_shape = [cond_shape[0]] + [1] * (input_rank - 1)
shape_const = ctx.make_const(utils.make_name(node.name), np.array(broadcast_shape, dtype=np.int64))
reshape = ctx.make_node("Reshape", [node.input[0], shape_const.output[0]])
Expand Down

0 comments on commit 20ae214

Please sign in to comment.