Skip to content

Commit

Permalink
to extend OptionalHasElement and OptionalGetElement to accept tensor …
Browse files Browse the repository at this point in the history
…and sequence types (onnx#4421)
  • Loading branch information
liqunfu authored and Bjarke Roune committed May 6, 2023
1 parent 1d4b36e commit 68e30e4
Show file tree
Hide file tree
Showing 40 changed files with 429 additions and 114 deletions.
66 changes: 66 additions & 0 deletions docs/Changelog.md
Expand Up @@ -21320,6 +21320,72 @@ This version of the operator has been available since version 18 of the default
<dd>Constrain input X and output types to float tensors.</dd>
</dl>

### <a name="OptionalGetElement-18"></a>**OptionalGetElement-18**</a>

If the input is a tensor or sequence type, it returns the input.
If the input is an optional type, it outputs the element in the input.
It is an error if the input is an empty optional-type (i.e. does not have an element) and the behavior is undefined in this case.

#### Version

This version of the operator has been available since version 18 of the default ONNX operator set.

#### Inputs

<dl>
<dt><tt>input</tt> : O</dt>
<dd>The optional input.</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : V</dt>
<dd>Output element in the optional input.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128)), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain input type to optional tensor and optional sequence types.</dd>
<dt><tt>V</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain output type to all tensor or sequence types.</dd>
</dl>

### <a name="OptionalHasElement-18"></a>**OptionalHasElement-18**</a>

Returns true if (1) the input is an optional-type and contains an element,
or, (2) the input is a tensor or sequence type.
If the input is not provided or is an empty optional-type, this op returns false.

#### Version

This version of the operator has been available since version 18 of the default ONNX operator set.

#### Inputs (0 - 1)

<dl>
<dt><tt>input</tt> (optional) : O</dt>
<dd>The optional input.</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : B</dt>
<dd>A scalar boolean tensor. If true, it indicates that optional-type input contains an element. Otherwise, it is empty.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128)), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain input type to optional tensor and optional sequence types.</dd>
<dt><tt>B</tt> : tensor(bool)</dt>
<dd>Constrain output to a boolean tensor.</dd>
</dl>

### <a name="Pad-18"></a>**Pad-18**</a>

Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`,
Expand Down
133 changes: 93 additions & 40 deletions docs/Operators.md
Expand Up @@ -97,8 +97,8 @@ For an operator input/output's differentiability, it can be differentiable,
|<a href="#Not">Not</a>|<a href="Changelog.md#Not-1">1</a>|
|<a href="#OneHot">OneHot</a>|<a href="Changelog.md#OneHot-11">11</a>, <a href="Changelog.md#OneHot-9">9</a>|
|<a href="#Optional">Optional</a>|<a href="Changelog.md#Optional-15">15</a>|
|<a href="#OptionalGetElement">OptionalGetElement</a>|<a href="Changelog.md#OptionalGetElement-15">15</a>|
|<a href="#OptionalHasElement">OptionalHasElement</a>|<a href="Changelog.md#OptionalHasElement-15">15</a>|
|<a href="#OptionalGetElement">OptionalGetElement</a>|<a href="Changelog.md#OptionalGetElement-18">18</a>, <a href="Changelog.md#OptionalGetElement-15">15</a>|
|<a href="#OptionalHasElement">OptionalHasElement</a>|<a href="Changelog.md#OptionalHasElement-18">18</a>, <a href="Changelog.md#OptionalHasElement-15">15</a>|
|<a href="#Or">Or</a>|<a href="Changelog.md#Or-7">7</a>, <a href="Changelog.md#Or-1">1</a>|
|<a href="#PRelu">PRelu</a>|<a href="Changelog.md#PRelu-16">16</a>, <a href="Changelog.md#PRelu-9">9</a>, <a href="Changelog.md#PRelu-7">7</a>, <a href="Changelog.md#PRelu-6">6</a>, <a href="Changelog.md#PRelu-1">1</a>|
|<a href="#Pad">Pad</a>|<a href="Changelog.md#Pad-18">18</a>, <a href="Changelog.md#Pad-13">13</a>, <a href="Changelog.md#Pad-11">11</a>, <a href="Changelog.md#Pad-2">2</a>, <a href="Changelog.md#Pad-1">1</a>|
Expand Down Expand Up @@ -15736,12 +15736,15 @@ This version of the operator has been available since version 15 of the default

### <a name="OptionalGetElement"></a><a name="optionalgetelement">**OptionalGetElement**</a>

Outputs the element in the optional-type input. It is an error if the input value does not have an element
and the behavior is undefined in this case.
If the input is a tensor or sequence type, it returns the input.
If the input is an optional type, it outputs the element in the input.
It is an error if the input is an empty optional-type (i.e. does not have an element) and the behavior is undefined in this case.

#### Version

This version of the operator has been available since version 15 of the default ONNX operator set.
This version of the operator has been available since version 18 of the default ONNX operator set.

Other versions of this operator: <a href="Changelog.md#OptionalGetElement-15">15</a>

#### Inputs

Expand All @@ -15760,7 +15763,7 @@ This version of the operator has been available since version 15 of the default
#### Type Constraints

<dl>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128))</dt>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128)), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain input type to optional tensor and optional sequence types.</dd>
<dt><tt>V</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain output type to all tensor or sequence types.</dd>
Expand All @@ -15769,16 +15772,20 @@ This version of the operator has been available since version 15 of the default

### <a name="OptionalHasElement"></a><a name="optionalhaselement">**OptionalHasElement**</a>

Returns true if the optional-type input contains an element. If it is an empty optional-type, this op returns false.
Returns true if (1) the input is an optional-type and contains an element,
or, (2) the input is a tensor or sequence type.
If the input is not provided or is an empty optional-type, this op returns false.

#### Version

This version of the operator has been available since version 15 of the default ONNX operator set.
This version of the operator has been available since version 18 of the default ONNX operator set.

#### Inputs
Other versions of this operator: <a href="Changelog.md#OptionalHasElement-15">15</a>

#### Inputs (0 - 1)

<dl>
<dt><tt>input</tt> : O</dt>
<dt><tt>input</tt> (optional) : O</dt>
<dd>The optional input.</dd>
</dl>

Expand All @@ -15792,7 +15799,7 @@ This version of the operator has been available since version 15 of the default
#### Type Constraints

<dl>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128))</dt>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128)), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain input type to optional tensor and optional sequence types.</dd>
<dt><tt>B</tt> : tensor(bool)</dt>
<dd>Constrain output to a boolean tensor.</dd>
Expand All @@ -15806,21 +15813,45 @@ This version of the operator has been available since version 15 of the default

```python
optional = None

tensor_type_proto = onnx.helper.make_tensor_type_proto(
elem_type=onnx.TensorProto.INT32, shape=[]
)
input_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)
node = onnx.helper.make_node(
"OptionalHasElement", inputs=["optional_input"], outputs=["output"]
)
output = optional_has_element_reference_implementation(optional)
expect(
node,
inputs=[optional],
outputs=[output],
input_type_protos=[input_type_proto],
name="test_optional_has_element_empty",
)
optional_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)

# OptionalHasElement takes a tensor or optional as input
for input_type_proto in [tensor_type_proto, optional_type_proto]:
input_name_options = {
"empty": "optional_input",
"empty_no_input_name": "",
"empty_no_input": None,
}
for test_name_surfix, input_name in input_name_options.items():
if input_type_proto == tensor_type_proto and input_name:
# the input tensor cannot be empty if input name is provided.
continue
node = onnx.helper.make_node(
"OptionalHasElement",
inputs=[] if input_name is None else [input_name],
outputs=["output"],
)
output = optional_has_element_reference_implementation(optional)
test_name = (
"test_optional_has_element_"
+ test_name_surfix
+ (
"_optional_input"
if input_type_proto == optional_type_proto
else "_tensor_input"
)
)
expect(
node,
inputs=[optional] if input_name else [],
outputs=[output],
input_type_protos=[input_type_proto] if input_name else [],
name=test_name,
)
```

</details>
Expand All @@ -15838,7 +15869,7 @@ tensor_type_proto = onnx.helper.make_tensor_type_proto(
],
)
seq_type_proto = onnx.helper.make_sequence_type_proto(tensor_type_proto)
input_type_proto = onnx.helper.make_optional_type_proto(seq_type_proto)
optional_type_proto = onnx.helper.make_optional_type_proto(seq_type_proto)

node = onnx.helper.make_node(
"OptionalGetElement", inputs=["optional_input"], outputs=["output"]
Expand All @@ -15848,7 +15879,14 @@ expect(
node,
inputs=[optional],
outputs=[output],
input_type_protos=[input_type_proto],
input_type_protos=[optional_type_proto],
name="test_optional_get_element_optional_sequence",
)
expect(
node,
inputs=[optional],
outputs=[output],
input_type_protos=[seq_type_proto],
name="test_optional_get_element_sequence",
)
```
Expand All @@ -15867,7 +15905,7 @@ tensor_type_proto = onnx.helper.make_tensor_type_proto(
4,
],
)
input_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)
optional_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)

node = onnx.helper.make_node(
"OptionalGetElement", inputs=["optional_input"], outputs=["output"]
Expand All @@ -15877,8 +15915,15 @@ expect(
node,
inputs=[optional],
outputs=[output],
input_type_protos=[input_type_proto],
name="test_optional_get_element",
input_type_protos=[optional_type_proto],
name="test_optional_get_element_optional_tensor",
)
expect(
node,
inputs=[optional],
outputs=[output],
input_type_protos=[tensor_type_proto],
name="test_optional_get_element_tensor",
)
```

Expand All @@ -15896,18 +15941,26 @@ tensor_type_proto = onnx.helper.make_tensor_type_proto(
4,
],
)
input_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)
node = onnx.helper.make_node(
"OptionalHasElement", inputs=["optional_input"], outputs=["output"]
)
output = optional_has_element_reference_implementation(optional)
expect(
node,
inputs=[optional],
outputs=[output],
input_type_protos=[input_type_proto],
name="test_optional_has_element",
)
optional_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)

# OptionalHasElement takes a tensor or optional as input
for input_type_protos in [tensor_type_proto, optional_type_proto]:
node = onnx.helper.make_node(
"OptionalHasElement", inputs=["optional_input"], outputs=["output"]
)
output = optional_has_element_reference_implementation(optional)
test_name = "test_optional_has_element_" + (
"optional_input"
if input_type_protos == optional_type_proto
else "tensor_input"
)
expect(
node,
inputs=[optional],
outputs=[output],
input_type_protos=[optional_type_proto],
name=test_name,
)
```

</details>
Expand Down

0 comments on commit 68e30e4

Please sign in to comment.