Extend Where op to permit bfloat16 types #3738

Merged: 5 commits, Sep 28, 2021
41 changes: 41 additions & 0 deletions docs/Changelog.md
@@ -20381,6 +20381,47 @@ This version of the operator has been available since version 16 of the default ONNX operator set.
<dd>Constrain input and output types to any tensor type.</dd>
</dl>

### <a name="Where-16"></a>**Where-16**

Return elements, either from X or Y, depending on condition
(with Numpy-style broadcasting support).
Where behaves like numpy.where with three parameters:
https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

**History**
- Version 16 adds bfloat16 to the types allowed (for the second and third parameter).

#### Version

This version of the operator has been available since version 16 of the default ONNX operator set.

#### Inputs

<dl>
<dt><tt>condition</tt> (non-differentiable) : B</dt>
<dd>When True (nonzero), yield X, otherwise yield Y</dd>
<dt><tt>X</tt> (differentiable) : T</dt>
<dd>values selected at indices where condition is True</dd>
<dt><tt>Y</tt> (differentiable) : T</dt>
<dd>values selected at indices where condition is False</dd>
</dl>

#### Outputs

<dl>
<dt><tt>output</tt> (differentiable) : T</dt>
<dd>Tensor of shape equal to the broadcasted shape of condition, X, and Y.</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>B</tt> : tensor(bool)</dt>
<dd>Constrain to boolean tensors.</dd>
<dt><tt>T</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(bfloat16), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)</dt>
<dd>Constrain input and output types to all tensor types (including bfloat).</dd>
</dl>
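The spec above states that Where follows numpy.where with NumPy-style broadcasting. A minimal sketch of those selection-and-broadcast semantics, demonstrated with numpy itself rather than an ONNX runtime:

```python
import numpy as np

# condition has shape (3, 1) and broadcasts against X and Y of shape (3, 4);
# np.where picks from x wherever condition is True, from y elsewhere.
condition = np.array([[True], [False], [True]])
x = np.ones((3, 4), dtype=np.float32)
y = np.zeros((3, 4), dtype=np.float32)

out = np.where(condition, x, y)
# out has the broadcasted shape (3, 4): rows 0 and 2 come from x, row 1 from y
```

This mirrors the output contract above: the result shape is the multidirectional broadcast of condition, X, and Y.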

# ai.onnx.preview.training
## Version 1 of the 'ai.onnx.preview.training' operator set
### <a name="ai.onnx.preview.training.Adagrad-1"></a>**ai.onnx.preview.training.Adagrad-1**
13 changes: 9 additions & 4 deletions docs/Operators.md
@@ -165,7 +165,7 @@ For an operator input/output's differentiability, it can be differentiable,
|<a href="#Unique">Unique</a>|<a href="Changelog.md#Unique-11">11</a>|
|<a href="#Unsqueeze">Unsqueeze</a>|<a href="Changelog.md#Unsqueeze-13">13</a>, <a href="Changelog.md#Unsqueeze-11">11</a>, <a href="Changelog.md#Unsqueeze-1">1</a>|
|<a href="#Upsample">Upsample</a> (deprecated)|<a href="Changelog.md#Upsample-10">10</a>, <a href="Changelog.md#Upsample-9">9</a>, <a href="Changelog.md#Upsample-7">7</a>|
|<a href="#Where">Where</a>|<a href="Changelog.md#Where-9">9</a>|
|<a href="#Where">Where</a>|<a href="Changelog.md#Where-16">16</a>, <a href="Changelog.md#Where-9">9</a>|
|<a href="#Xor">Xor</a>|<a href="Changelog.md#Xor-7">7</a>, <a href="Changelog.md#Xor-1">1</a>|
|**Function**|**Since version**|
|<a href="#Bernoulli">Bernoulli</a>|<a href="Changelog.md#Bernoulli-15">15</a>|
@@ -24019,10 +24019,15 @@ expect(node, inputs=[data, scales], outputs=[output],
(with Numpy-style broadcasting support).
Where behaves like numpy.where with three parameters:
https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

**History**
- Version 16 adds bfloat16 to the types allowed (for the second and third parameter).

#### Version

This version of the operator has been available since version 9 of the default ONNX operator set.
This version of the operator has been available since version 16 of the default ONNX operator set.

Other versions of this operator: <a href="Changelog.md#Where-9">9</a>

#### Inputs

@@ -24047,8 +24052,8 @@ This version of the operator has been available since version 9 of the default ONNX operator set.
<dl>
<dt><tt>B</tt> : tensor(bool)</dt>
<dd>Constrain to boolean tensors.</dd>
<dt><tt>T</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)</dt>
<dd>Constrain input and output types to all tensor types.</dd>
<dt><tt>T</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(bfloat16), tensor(float16), tensor(float), tensor(double), tensor(string), tensor(bool), tensor(complex64), tensor(complex128)</dt>
<dd>Constrain input and output types to all tensor types (including bfloat).</dd>
</dl>


2 changes: 2 additions & 0 deletions onnx/defs/operator_sets.h
@@ -998,6 +998,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 16, ScatterElements);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 16, If);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 16, Loop);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 16, Identity);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 16, Where);

// Iterate over schema from ai.onnx version 16
class OpSet_Onnx_ver16 {
@@ -1009,6 +1010,7 @@ class OpSet_Onnx_ver16 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 16, If)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 16, Loop)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 16, Identity)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 16, Where)>());
}
};
inline void RegisterOnnxOperatorSetSchema() {
13 changes: 8 additions & 5 deletions onnx/defs/tensor/defs.cc
@@ -2898,18 +2898,21 @@ ONNX_OPERATOR_SET_SCHEMA(
}
}));

static const char* Where_ver9_doc = R"DOC(
static const char* Where_ver16_doc = R"DOC(
Return elements, either from X or Y, depending on condition
(with Numpy-style broadcasting support).
Where behaves like numpy.where with three parameters:
https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html

**History**
- Version 16 adds bfloat16 to the types allowed (for the second and third parameter).
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
Where,
9,
16,
OpSchema()
.SetDoc(Where_ver9_doc)
.SetDoc(Where_ver16_doc)
.Input(
0,
"condition",
@@ -2949,8 +2952,8 @@ ONNX_OPERATOR_SET_SCHEMA(
.TypeConstraint("B", {"tensor(bool)"}, "Constrain to boolean tensors.")
.TypeConstraint(
"T",
OpSchema::all_tensor_types(),
"Constrain input and output types to all tensor types.")
OpSchema::all_tensor_types_with_bfloat(),
"Constrain input and output types to all tensor types (including bfloat).")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 1, 0);
if (hasNInputShapes(ctx, 3)) {
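The inference function in this hunk delegates to `multidirectionalBroadcastShapeInference` over the three input shapes. As a rough Python sketch of that broadcasting rule (my own simplification handling static dimensions only, not the actual ONNX C++ helper):

```python
def broadcast_shapes(shapes):
    # NumPy-style multidirectional broadcasting over a list of static shapes:
    # align shapes from the right; each output dim is the max of the aligned
    # dims, and any aligned dim that differs from that max must be 1.
    rank = max(len(s) for s in shapes)
    out = []
    for i in range(rank):
        dims = [s[len(s) - rank + i] for s in shapes if len(s) - rank + i >= 0]
        target = max(dims)
        if any(d not in (1, target) for d in dims):
            raise ValueError(f"incompatible dims {dims}")
        out.append(target)
    return tuple(out)
```

For Where, the shapes of condition, X, and Y are all fed through this rule to produce the output shape, which is why condition may have a lower rank than the data inputs.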
66 changes: 66 additions & 0 deletions onnx/defs/tensor/old.cc
@@ -4417,4 +4417,70 @@ ONNX_OPERATOR_SET_SCHEMA(
"Constrain input and output types to all tensor and sequence types.")
.TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput));

static const char* Where_ver9_doc = R"DOC(
Return elements, either from X or Y, depending on condition
(with Numpy-style broadcasting support).
Where behaves like numpy.where with three parameters:
https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html
)DOC";

ONNX_OPERATOR_SET_SCHEMA(
Where,
9,
OpSchema()
.SetDoc(Where_ver9_doc)
.Input(
0,
"condition",
"When True (nonzero), yield X, otherwise yield Y",
"B",
OpSchema::Single,
true,
1,
OpSchema::NonDifferentiable)
.Input(
1,
"X",
"values selected at indices where condition is True",
"T",
OpSchema::Single,
true,
1,
OpSchema::Differentiable)
.Input(
2,
"Y",
"values selected at indices where condition is False",
"T",
OpSchema::Single,
true,
1,
OpSchema::Differentiable)
.Output(
0,
"output",
"Tensor of shape equal to the broadcasted shape of condition, X, and Y.",
"T",
OpSchema::Single,
true,
1,
OpSchema::Differentiable)
.TypeConstraint("B", {"tensor(bool)"}, "Constrain to boolean tensors.")
.TypeConstraint(
"T",
OpSchema::all_tensor_types(),
"Constrain input and output types to all tensor types.")
.TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 1, 0);
if (hasNInputShapes(ctx, 3)) {
std::vector<const TensorShapeProto*> shapes;
shapes.push_back(&ctx.getInputType(0)->tensor_type().shape());
shapes.push_back(&ctx.getInputType(1)->tensor_type().shape());
shapes.push_back(&ctx.getInputType(2)->tensor_type().shape());
multidirectionalBroadcastShapeInference(
shapes,
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape());
}
}));

} // namespace ONNX_NAMESPACE
7 changes: 7 additions & 0 deletions onnx/test/shape_inference_test.py
@@ -3873,6 +3873,13 @@ def test_optional_sequence_get_element(self):  # type: () -> None
[])
self._assert_inferred(graph, [optional_val_info, sequence_val_into, output_val_into]) # type: ignore

def test_where_bfloat(self): # type: () -> None
graph = self._make_graph(
[('cond', TensorProto.BOOL, (10,)), ('x', TensorProto.BFLOAT16, (10,)), ('y', TensorProto.BFLOAT16, (10,))],
[make_node('Where', ['cond', 'x', 'y'], ['out'])],
[])
self._assert_inferred(graph, [make_tensor_value_info('out', TensorProto.BFLOAT16, (10,))]) # type: ignore


if __name__ == '__main__':
unittest.main()
2 changes: 2 additions & 0 deletions onnx/version_converter/convert.h
@@ -768,6 +768,8 @@ class DefaultVersionConverter : public BaseVersionConverter {
OpSetID(15), OpSetID(16)));
registerAdapter(make_unique<CompatibleAdapter>("If",
OpSetID(15), OpSetID(16)));
registerAdapter(make_unique<CompatibleAdapter>("Where",
OpSetID(15), OpSetID(16)));
}

ModelProto convert_version(