Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

to extend OptionalHasElement and OptionalGetElement to accept tensor and sequence types #4421

Merged
merged 11 commits into from Aug 16, 2022
66 changes: 66 additions & 0 deletions docs/Changelog.md
Expand Up @@ -21320,6 +21320,72 @@ This version of the operator has been available since version 18 of the default
<dd>Constrain input X and output types to float tensors.</dd>
</dl>

### <a name="OptionalGetElement-18"></a>**OptionalGetElement-18**</a>

If the input is a tensor or sequence type, it returns the input.
If the input is an optional type, it outputs the element in the input.
It is an error if the input is an empty optional-type (i.e. does not have an element) and the behavior is undefined in this case.

#### Version

This version of the operator has been available since version 18 of the default ONNX operator set.

#### Inputs

<dl>
<dt><tt>input</tt> : O</dt>
<dd>The optional input.</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : V</dt>
<dd>Output element in the optional input.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128)), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain input type to optional tensor and optional sequence types.</dd>
<dt><tt>V</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain output type to all tensor or sequence types.</dd>
</dl>

### <a name="OptionalHasElement-18"></a>**OptionalHasElement-18**</a>

Returns true if (1) the input is an optional-type and contains an element,
or, (2) the input is a tensor or sequence type.
If the input is an empty optional-type, this op returns false.

#### Version

This version of the operator has been available since version 18 of the default ONNX operator set.

#### Inputs

<dl>
<dt><tt>input</tt> : O</dt>
<dd>The optional input.</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> : B</dt>
<dd>A scalar boolean tensor. If true, it indicates that optional-type input contains an element. Otherwise, it is empty.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128)), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain input type to optional tensor and optional sequence types.</dd>
<dt><tt>B</tt> : tensor(bool)</dt>
<dd>Constrain output to a boolean tensor.</dd>
</dl>

### <a name="Pad-18"></a>**Pad-18**</a>

Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`,
Expand Down
46 changes: 37 additions & 9 deletions docs/Operators.md
Expand Up @@ -97,8 +97,8 @@ For an operator input/output's differentiability, it can be differentiable,
|<a href="#Not">Not</a>|<a href="Changelog.md#Not-1">1</a>|
|<a href="#OneHot">OneHot</a>|<a href="Changelog.md#OneHot-11">11</a>, <a href="Changelog.md#OneHot-9">9</a>|
|<a href="#Optional">Optional</a>|<a href="Changelog.md#Optional-15">15</a>|
|<a href="#OptionalGetElement">OptionalGetElement</a>|<a href="Changelog.md#OptionalGetElement-15">15</a>|
|<a href="#OptionalHasElement">OptionalHasElement</a>|<a href="Changelog.md#OptionalHasElement-15">15</a>|
|<a href="#OptionalGetElement">OptionalGetElement</a>|<a href="Changelog.md#OptionalGetElement-18">18</a>, <a href="Changelog.md#OptionalGetElement-15">15</a>|
|<a href="#OptionalHasElement">OptionalHasElement</a>|<a href="Changelog.md#OptionalHasElement-18">18</a>, <a href="Changelog.md#OptionalHasElement-15">15</a>|
|<a href="#Or">Or</a>|<a href="Changelog.md#Or-7">7</a>, <a href="Changelog.md#Or-1">1</a>|
|<a href="#PRelu">PRelu</a>|<a href="Changelog.md#PRelu-16">16</a>, <a href="Changelog.md#PRelu-9">9</a>, <a href="Changelog.md#PRelu-7">7</a>, <a href="Changelog.md#PRelu-6">6</a>, <a href="Changelog.md#PRelu-1">1</a>|
|<a href="#Pad">Pad</a>|<a href="Changelog.md#Pad-18">18</a>, <a href="Changelog.md#Pad-13">13</a>, <a href="Changelog.md#Pad-11">11</a>, <a href="Changelog.md#Pad-2">2</a>, <a href="Changelog.md#Pad-1">1</a>|
Expand Down Expand Up @@ -14457,12 +14457,15 @@ This version of the operator has been available since version 15 of the default

### <a name="OptionalGetElement"></a><a name="optionalgetelement">**OptionalGetElement**</a>

Outputs the element in the optional-type input. It is an error if the input value does not have an element
and the behavior is undefined in this case.
If the input is a tensor or sequence type, it returns the input.
If the input is an optional type, it outputs the element in the input.
It is an error if the input is an empty optional-type (i.e. does not have an element) and the behavior is undefined in this case.

#### Version

This version of the operator has been available since version 15 of the default ONNX operator set.
This version of the operator has been available since version 18 of the default ONNX operator set.

Other versions of this operator: <a href="Changelog.md#OptionalGetElement-15">15</a>

#### Inputs

Expand All @@ -14481,7 +14484,7 @@ This version of the operator has been available since version 15 of the default
#### Type Constraints

<dl>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128))</dt>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128)), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain input type to optional tensor and optional sequence types.</dd>
<dt><tt>V</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain output type to all tensor or sequence types.</dd>
Expand All @@ -14490,11 +14493,15 @@ This version of the operator has been available since version 15 of the default

### <a name="OptionalHasElement"></a><a name="optionalhaselement">**OptionalHasElement**</a>

Returns true if the optional-type input contains an element. If it is an empty optional-type, this op returns false.
Returns true if (1) the input is an optional-type and contains an element,
or, (2) the input is a tensor or sequence type.
If the input is an empty optional-type, this op returns false.

#### Version

This version of the operator has been available since version 15 of the default ONNX operator set.
This version of the operator has been available since version 18 of the default ONNX operator set.

Other versions of this operator: <a href="Changelog.md#OptionalHasElement-15">15</a>

#### Inputs

Expand All @@ -14513,7 +14520,7 @@ This version of the operator has been available since version 15 of the default
#### Type Constraints

<dl>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128))</dt>
<dt><tt>O</tt> : optional(seq(tensor(uint8))), optional(seq(tensor(uint16))), optional(seq(tensor(uint32))), optional(seq(tensor(uint64))), optional(seq(tensor(int8))), optional(seq(tensor(int16))), optional(seq(tensor(int32))), optional(seq(tensor(int64))), optional(seq(tensor(float16))), optional(seq(tensor(float))), optional(seq(tensor(double))), optional(seq(tensor(string))), optional(seq(tensor(bool))), optional(seq(tensor(complex64))), optional(seq(tensor(complex128))), optional(tensor(uint8)), optional(tensor(uint16)), optional(tensor(uint32)), optional(tensor(uint64)), optional(tensor(int8)), optional(tensor(int16)), optional(tensor(int32)), optional(tensor(int64)), optional(tensor(float16)), optional(tensor(float)), optional(tensor(double)), optional(tensor(string)), optional(tensor(bool)), optional(tensor(complex64)), optional(tensor(complex128)), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128), seq(tensor(uint8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(int8)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(float16)), seq(tensor(float)), seq(tensor(double)), seq(tensor(string)), seq(tensor(bool)), seq(tensor(complex64)), seq(tensor(complex128))</dt>
<dd>Constrain input type to optional tensor and optional sequence types.</dd>
<dt><tt>B</tt> : tensor(bool)</dt>
<dd>Constrain output to a boolean tensor.</dd>
Expand Down Expand Up @@ -14609,6 +14616,27 @@ expect(node, inputs=[optional], outputs=[output],
</details>


<details>
<summary>tensor</summary>

```python
tensor = np.array([1, 2, 3, 4]).astype(np.float32)
tensor_type_proto = onnx.helper.make_tensor_type_proto(elem_type=onnx.TensorProto.FLOAT, shape=[4, ])
input_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)
node = onnx.helper.make_node(
'OptionalHasElement',
inputs=['optional_input'],
outputs=['output']
)
output = optional_has_element_reference_implementation(tensor)
expect(node, inputs=[tensor], outputs=[output],
input_type_protos=[input_type_proto],
name='test_optional_has_element_tensor')
```

</details>


### <a name="Or"></a><a name="or">**Or**</a>

Returns the tensor resulted from performing the `or` logical operation
Expand Down
21 changes: 20 additions & 1 deletion docs/TestCoverage.md
Expand Up @@ -9099,7 +9099,7 @@ expect(node, inputs=[indices, depth, values], outputs=[y], name='test_onehot_wit


### OptionalHasElement
There are 4 test cases, listed as following:
There are 5 test cases, listed as following:
<details>
<summary>empty</summary>

Expand Down Expand Up @@ -9178,6 +9178,25 @@ expect(node, inputs=[optional], outputs=[output],
name='test_optional_has_element')
```

</details>
<details>
<summary>tensor</summary>

```python
tensor = np.array([1, 2, 3, 4]).astype(np.float32)
tensor_type_proto = onnx.helper.make_tensor_type_proto(elem_type=onnx.TensorProto.FLOAT, shape=[4, ])
input_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)
node = onnx.helper.make_node(
'OptionalHasElement',
inputs=['optional_input'],
outputs=['output']
)
output = optional_has_element_reference_implementation(tensor)
expect(node, inputs=[tensor], outputs=[output],
input_type_protos=[input_type_proto],
name='test_optional_has_element_tensor')
```

</details>


Expand Down
15 changes: 15 additions & 0 deletions onnx/backend/test/case/node/optionalhaselement.py
Expand Up @@ -46,3 +46,18 @@ def export_empty() -> None:
expect(node, inputs=[optional], outputs=[output],
input_type_protos=[input_type_proto],
name='test_optional_has_element_empty')

@staticmethod
def export_tensor() -> None:
tensor = np.array([1, 2, 3, 4]).astype(np.float32)
tensor_type_proto = onnx.helper.make_tensor_type_proto(elem_type=onnx.TensorProto.FLOAT, shape=[4, ])
input_type_proto = onnx.helper.make_optional_type_proto(tensor_type_proto)
node = onnx.helper.make_node(
'OptionalHasElement',
inputs=['optional_input'],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another test-case where inputs=[] or input=[''] would be useful.

How about test-cases for optional_get_element?

outputs=['output']
)
output = optional_has_element_reference_implementation(tensor)
expect(node, inputs=[tensor], outputs=[output],
input_type_protos=[input_type_proto],
name='test_optional_has_element_tensor')
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1 @@
 BoutputJ
4 changes: 4 additions & 0 deletions onnx/defs/operator_sets.h
Expand Up @@ -1004,6 +1004,8 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 18, Pad);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 18, CenterCropPad);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 18, Resize);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 18, Mish);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 18, OptionalGetElement);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 18, OptionalHasElement);

// Iterate over schema from ai.onnx version 18
class OpSet_Onnx_ver18 {
Expand All @@ -1013,6 +1015,8 @@ class OpSet_Onnx_ver18 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 18, CenterCropPad)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 18, Resize)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 18, Mish)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 18, OptionalGetElement)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 18, OptionalHasElement)>());
}
};

Expand Down