Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

to extend OptionalHasElement and OptionalGetElement to accept tensor and sequence types #4421

Merged
merged 11 commits into from Aug 16, 2022
66 changes: 66 additions & 0 deletions docs/Changelog.md
Expand Up @@ -21320,6 +21320,72 @@ This version of the operator has been available since version 18 of the default
<dd>Constrain input X and output types to float tensors.</dd>
</dl>

### <a name="OptionalGetElement-18"></a>**OptionalGetElement-18**</a>

If the input is a tensor or sequence type, it returns the input.
If the input is an optional type, it outputs the element in the input.
It is an error if the input is an empty optional-type (i.e. does not have an element) and the behavior is undefined in this case.

#### Version

This version of the operator has been available since version 18 of the default ONNX operator set.

#### Inputs

<dl>
<dt><tt>input</tt> : O</dt>
<dd>The optional input.</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : V</dt>
<dd>Output element in the optional input.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128)), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain input type to optional tensor and optional sequence types.</dd>
<dt><tt>V</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain output type to all tensor or sequence types.</dd>
</dl>

### <a name="OptionalHasElement-18"></a>**OptionalHasElement-18**</a>

Returns true if (1) the input is an optional-type and contains an element,
or, (2) the input is a tensor or sequence type.
If the input is not provided or is an empty optional-type, this op returns false.

#### Version

This version of the operator has been available since version 18 of the default ONNX operator set.

#### Inputs (0 - 1)

<dl>
<dt><tt>input</tt> (optional) : O</dt>
<dd>The optional input.</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : B</dt>
<dd>A scalar boolean tensor. If true, it indicates that optional-type input contains an element. Otherwise, it is empty.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128)), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain input type to optional tensor and optional sequence types.</dd>
<dt><tt>B</tt> : tensor(bool)</dt>
<dd>Constrain output to a boolean tensor.</dd>
</dl>

### <a name="Pad-18"></a>**Pad-18**</a>

Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`,
Expand Down
102 changes: 66 additions & 36 deletions docs/Operators.md
Expand Up @@ -97,8 +97,8 @@ For an operator input/output's differentiability, it can be differentiable,
|<a href="#Not">Not</a>|<a href="Changelog.md#Not-1">1</a>|
|<a href="#OneHot">OneHot</a>|<a href="Changelog.md#OneHot-11">11</a>, <a href="Changelog.md#OneHot-9">9</a>|
|<a href="#Optional">Optional</a>|<a href="Changelog.md#Optional-15">15</a>|
|<a href="#OptionalGetElement">OptionalGetElement</a>|<a href="Changelog.md#OptionalGetElement-15">15</a>|
|<a href="#OptionalHasElement">OptionalHasElement</a>|<a href="Changelog.md#OptionalHasElement-15">15</a>|
|<a href="#OptionalGetElement">OptionalGetElement</a>|<a href="Changelog.md#OptionalGetElement-18">18</a>, <a href="Changelog.md#OptionalGetElement-15">15</a>|
|<a href="#OptionalHasElement">OptionalHasElement</a>|<a href="Changelog.md#OptionalHasElement-18">18</a>, <a href="Changelog.md#OptionalHasElement-15">15</a>|
|<a href="#Or">Or</a>|<a href="Changelog.md#Or-7">7</a>, <a href="Changelog.md#Or-1">1</a>|
|<a href="#PRelu">PRelu</a>|<a href="Changelog.md#PRelu-16">16</a>, <a href="Changelog.md#PRelu-9">9</a>, <a href="Changelog.md#PRelu-7">7</a>, <a href="Changelog.md#PRelu-6">6</a>, <a href="Changelog.md#PRelu-1">1</a>|
|<a href="#Pad">Pad</a>|<a href="Changelog.md#Pad-18">18</a>, <a href="Changelog.md#Pad-13">13</a>, <a href="Changelog.md#Pad-11">11</a>, <a href="Changelog.md#Pad-2">2</a>, <a href="Changelog.md#Pad-1">1</a>|
Expand Down Expand Up @@ -14457,12 +14457,15 @@ This version of the operator has been available since version 15 of the default

### <a name="OptionalGetElement"></a><a name="optionalgetelement">**OptionalGetElement**</a>

Outputs the element in the optional-type input. It is an error if the input value does not have an element
and the behavior is undefined in this case.
If the input is a tensor or sequence type, it returns the input.
If the input is an optional type, it outputs the element in the input.
It is an error if the input is an empty optional-type (i.e. does not have an element) and the behavior is undefined in this case.

#### Version

This version of the operator has been available since version 15 of the default ONNX operator set.
This version of the operator has been available since version 18 of the default ONNX operator set.

Other versions of this operator: <a href="Changelog.md#OptionalGetElement-15">15</a>

#### Inputs

Expand All @@ -14481,7 +14484,7 @@ This version of the operator has been available since version 15 of the default
#### Type Constraints

<dl>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128))</dt>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128)), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain input type to optional tensor and optional sequence types.</dd>
<dt><tt>V</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain output type to all tensor or sequence types.</dd>
Expand All @@ -14490,16 +14493,20 @@ This version of the operator has been available since version 15 of the default

### <a name="OptionalHasElement"></a><a name="optionalhaselement">**OptionalHasElement**</a>

Returns true if the optional-type input contains an element. If it is an empty optional-type, this op returns false.
Returns true if (1) the input is an optional-type and contains an element,
or, (2) the input is a tensor or sequence type.
If the input is not provided or is an empty optional-type, this op returns false.

#### Version

This version of the operator has been available since version 15 of the default ONNX operator set.
This version of the operator has been available since version 18 of the default ONNX operator set.

#### Inputs
Other versions of this operator: <a href="Changelog.md#OptionalHasElement-15">15</a>

#### Inputs (0 - 1)

<dl>
<dt><tt>input</tt> : O</dt>
<dt><tt>input</tt> (optional) : O</dt>
<dd>The optional input.</dd>
</dl>

Expand All @@ -14513,7 +14520,7 @@ This version of the operator has been available since version 15 of the default
#### Type Constraints

<dl>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128))</dt>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128)), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain input type to optional tensor and optional sequence types.</dd>
<dt><tt>B</tt> : tensor(bool)</dt>
<dd>Constrain output to a boolean tensor.</dd>
Expand All @@ -14527,17 +14534,29 @@ This version of the operator has been available since version 15 of the default

```python
optional = None

tensor_type_proto = onnx.helper.make_tensor_type_proto(elem_type=onnx.TensorProto.INT32, shape=[])
input_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)
node = onnx.helper.make_node(
'OptionalHasElement',
inputs=['optional_input'],
outputs=['output']
)
output = optional_has_element_reference_implementation(optional)
expect(node, inputs=[optional], outputs=[output],
input_type_protos=[input_type_proto],
name='test_optional_has_element_empty')
optional_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)

# OptionalHasElement takes a tensor or optional as input
for input_type_proto in [tensor_type_proto, optional_type_proto]:
input_name_options = {
'empty': 'optional_input',
'empty_no_input_name': '',
'empty_no_input': None,
}
for test_name_surfix, input_name in input_name_options.items():
node = onnx.helper.make_node(
'OptionalHasElement',
inputs=[] if input_name is None else [input_name],
outputs=['output']
)
output = optional_has_element_reference_implementation(optional)
test_name = 'test_optional_has_element_' + test_name_surfix \
+ ('_optional_input' if input_type_proto == optional_type_proto else '_tensor_input')
expect(node, inputs=[optional] if input_name else [], outputs=[output],
input_type_protos=[input_type_proto] if input_name else [],
name=test_name)
```

</details>
Expand All @@ -14550,7 +14569,7 @@ expect(node, inputs=[optional], outputs=[output],
optional = [np.array([1, 2, 3, 4]).astype(np.int32)]
tensor_type_proto = onnx.helper.make_tensor_type_proto(elem_type=onnx.TensorProto.INT32, shape=[4, ])
seq_type_proto = onnx.helper.make_sequence_type_proto(tensor_type_proto)
input_type_proto = onnx.helper.make_optional_type_proto(seq_type_proto)
optional_type_proto = onnx.helper.make_optional_type_proto(seq_type_proto)

node = onnx.helper.make_node(
'OptionalGetElement',
Expand All @@ -14559,7 +14578,10 @@ node = onnx.helper.make_node(
)
output = optional_get_element_reference_implementation(optional)
expect(node, inputs=[optional], outputs=[output],
input_type_protos=[input_type_proto],
input_type_protos=[optional_type_proto],
name='test_optional_get_element_optional_sequence')
expect(node, inputs=[optional], outputs=[output],
input_type_protos=[seq_type_proto],
name='test_optional_get_element_sequence')
```

Expand All @@ -14572,7 +14594,7 @@ expect(node, inputs=[optional], outputs=[output],
```python
optional = np.array([1, 2, 3, 4]).astype(np.float32)
tensor_type_proto = onnx.helper.make_tensor_type_proto(elem_type=onnx.TensorProto.FLOAT, shape=[4, ])
input_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)
optional_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)

node = onnx.helper.make_node(
'OptionalGetElement',
Expand All @@ -14581,8 +14603,11 @@ node = onnx.helper.make_node(
)
output = optional_get_element_reference_implementation(optional)
expect(node, inputs=[optional], outputs=[output],
input_type_protos=[input_type_proto],
name='test_optional_get_element')
input_type_protos=[optional_type_proto],
name='test_optional_get_element_optional_tensor')
expect(node, inputs=[optional], outputs=[output],
input_type_protos=[tensor_type_proto],
name='test_optional_get_element_tensor')
```

</details>
Expand All @@ -14594,16 +14619,21 @@ expect(node, inputs=[optional], outputs=[output],
```python
optional = np.array([1, 2, 3, 4]).astype(np.float32)
tensor_type_proto = onnx.helper.make_tensor_type_proto(elem_type=onnx.TensorProto.FLOAT, shape=[4, ])
input_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)
node = onnx.helper.make_node(
'OptionalHasElement',
inputs=['optional_input'],
outputs=['output']
)
output = optional_has_element_reference_implementation(optional)
expect(node, inputs=[optional], outputs=[output],
input_type_protos=[input_type_proto],
name='test_optional_has_element')
optional_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)

# OptionalHasElement takes a tensor or optional as input
for input_type_protos in [tensor_type_proto, optional_type_proto]:
node = onnx.helper.make_node(
'OptionalHasElement',
inputs=['optional_input'],
outputs=['output']
)
output = optional_has_element_reference_implementation(optional)
test_name = 'test_optional_has_element_' +\
('optional_input' if input_type_protos == optional_type_proto else 'tensor_input')
expect(node, inputs=[optional], outputs=[output],
input_type_protos=[optional_type_proto],
name=test_name)
```

</details>
Expand Down