diff --git a/docs/Changelog.md b/docs/Changelog.md index e7c9a05dee4..528b2bffc17 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -24022,6 +24022,37 @@ This version of the operator has been available since version 20 of the default
Constrain grid types to float tensors.
+### **StringConcat-20** + + StringConcat concatenates string tensors elementwise (with NumPy-style broadcasting support) + +#### Version + +This version of the operator has been available since version 20 of the default ONNX operator set. + +#### Inputs + +
+
X (non-differentiable) : T
+
Tensor to prepend in concatenation
+
Y (non-differentiable) : T
+
Tensor to append in concatenation
+
+ +#### Outputs + +
+
Z (non-differentiable) : T
+
Concatenated string tensor
+
+ +#### Type Constraints + +
+
T : tensor(string)
+
Inputs and outputs must be UTF-8 strings
+
+ # ai.onnx.preview.training ## Version 1 of the 'ai.onnx.preview.training' operator set ### **ai.onnx.preview.training.Adagrad-1** diff --git a/docs/Operators.md b/docs/Operators.md index 0efd01da237..e56e1642b3b 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -146,6 +146,7 @@ For an operator input/output's differentiability, it can be differentiable, |SplitToSequence|11| |Sqrt|13, 6, 1| |Squeeze|13, 11, 1| +|StringConcat|20| |StringNormalizer|10| |Sub|14, 13, 7, 6, 1| |Sum|13, 8, 6, 1| @@ -30080,6 +30081,103 @@ expect(node, inputs=[x, axes], outputs=[y], name="test_squeeze_negative_axes") +### **StringConcat** + + StringConcat concatenates string tensors elementwise (with NumPy-style broadcasting support) + +#### Version + +This version of the operator has been available since version 20 of the default ONNX operator set. + +#### Inputs + +
+
X (non-differentiable) : T
+
Tensor to prepend in concatenation
+
Y (non-differentiable) : T
+
Tensor to append in concatenation
+
+ +#### Outputs + +
+
Z (non-differentiable) : T
+
Concatenated string tensor
+
+ +#### Type Constraints + +
+
T : tensor(string)
+
Inputs and outputs must be UTF-8 strings
+
+ + +#### Examples + +
+stringconcat + +```python +node = onnx.helper.make_node( + "StringConcat", + inputs=["x", "y"], + outputs=["result"], +) +x = np.array(["abc", "def"]).astype("object") +y = np.array([".com", ".net"]).astype("object") +result = np.array(["abc.com", "def.net"]).astype("object") + +expect(node, inputs=[x, y], outputs=[result], name="test_string_concat") + +x = np.array(["cat", "dog", "snake"]).astype("object") +y = np.array(["s"]).astype("object") +result = np.array(["cats", "dogs", "snakes"]).astype("object") + +expect( + node, + inputs=[x, y], + outputs=[result], + name="test_string_concat_broadcasting", +) + +x = np.array("cat").astype("object") +y = np.array("s").astype("object") +result = np.array("cats").astype("object") + +expect( + node, + inputs=[x, y], + outputs=[result], + name="test_string_concat_zero_dimensional", +) + +x = np.array(["abc", ""]).astype("object") +y = np.array(["", "abc"]).astype("object") +result = np.array(["abc", "abc"]).astype("object") + +expect( + node, + inputs=[x, y], + outputs=[result], + name="test_string_concat_empty_string", +) + +x = np.array(["的", "中"]).astype("object") +y = np.array(["的", "中"]).astype("object") +result = np.array(["的的", "中中"]).astype("object") + +expect( + node, + inputs=[x, y], + outputs=[result], + name="test_string_concat_utf8", +) +``` + +
+ + ### **StringNormalizer** StringNormalization performs string operations for basic cleaning. diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md index 7b8d0cd2d9e..8d4d2d31518 100644 --- a/docs/TestCoverage.md +++ b/docs/TestCoverage.md @@ -6,7 +6,7 @@ * [Overall Test Coverage](#overall-test-coverage) # Node Test Coverage ## Summary -Node tests have covered 174/187 (93.05%, 5 generators excluded) common operators. +Node tests have covered 175/188 (93.09%, 5 generators excluded) common operators. Node tests have covered 0/0 (N/A) experimental operators. @@ -20594,6 +20594,71 @@ expect(node, inputs=[x, axes], outputs=[y], name="test_squeeze_negative_axes") +### StringConcat +There are 1 test cases, listed as following: +
+stringconcat + +```python +node = onnx.helper.make_node( + "StringConcat", + inputs=["x", "y"], + outputs=["result"], +) +x = np.array(["abc", "def"]).astype("object") +y = np.array([".com", ".net"]).astype("object") +result = np.array(["abc.com", "def.net"]).astype("object") + +expect(node, inputs=[x, y], outputs=[result], name="test_string_concat") + +x = np.array(["cat", "dog", "snake"]).astype("object") +y = np.array(["s"]).astype("object") +result = np.array(["cats", "dogs", "snakes"]).astype("object") + +expect( + node, + inputs=[x, y], + outputs=[result], + name="test_string_concat_broadcasting", +) + +x = np.array("cat").astype("object") +y = np.array("s").astype("object") +result = np.array("cats").astype("object") + +expect( + node, + inputs=[x, y], + outputs=[result], + name="test_string_concat_zero_dimensional", +) + +x = np.array(["abc", ""]).astype("object") +y = np.array(["", "abc"]).astype("object") +result = np.array(["abc", "abc"]).astype("object") + +expect( + node, + inputs=[x, y], + outputs=[result], + name="test_string_concat_empty_string", +) + +x = np.array(["的", "中"]).astype("object") +y = np.array(["的", "中"]).astype("object") +result = np.array(["的的", "中中"]).astype("object") + +expect( + node, + inputs=[x, y], + outputs=[result], + name="test_string_concat_utf8", +) +``` + +
+ + ### StringNormalizer There are 6 test cases, listed as following:
diff --git a/onnx/backend/test/case/node/string_concat.py b/onnx/backend/test/case/node/string_concat.py new file mode 100644 index 00000000000..51f566eecae --- /dev/null +++ b/onnx/backend/test/case/node/string_concat.py @@ -0,0 +1,68 @@ +# Copyright (c) ONNX Project Contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np + +import onnx +from onnx.backend.test.case.base import Base +from onnx.backend.test.case.node import expect + + +class StringConcat(Base): + @staticmethod + def export() -> None: + node = onnx.helper.make_node( + "StringConcat", + inputs=["x", "y"], + outputs=["result"], + ) + x = np.array(["abc", "def"]).astype("object") + y = np.array([".com", ".net"]).astype("object") + result = np.array(["abc.com", "def.net"]).astype("object") + + expect(node, inputs=[x, y], outputs=[result], name="test_string_concat") + + x = np.array(["cat", "dog", "snake"]).astype("object") + y = np.array(["s"]).astype("object") + result = np.array(["cats", "dogs", "snakes"]).astype("object") + + expect( + node, + inputs=[x, y], + outputs=[result], + name="test_string_concat_broadcasting", + ) + + x = np.array("cat").astype("object") + y = np.array("s").astype("object") + result = np.array("cats").astype("object") + + expect( + node, + inputs=[x, y], + outputs=[result], + name="test_string_concat_zero_dimensional", + ) + + x = np.array(["abc", ""]).astype("object") + y = np.array(["", "abc"]).astype("object") + result = np.array(["abc", "abc"]).astype("object") + + expect( + node, + inputs=[x, y], + outputs=[result], + name="test_string_concat_empty_string", + ) + + x = np.array(["的", "中"]).astype("object") + y = np.array(["的", "中"]).astype("object") + result = np.array(["的的", "中中"]).astype("object") + + expect( + node, + inputs=[x, y], + outputs=[result], + name="test_string_concat_utf8", + ) diff --git a/onnx/backend/test/data/node/test_string_concat/model.onnx b/onnx/backend/test/data/node/test_string_concat/model.onnx new file mode 100644 index 00000000000..258e878a556 Binary files /dev/null and b/onnx/backend/test/data/node/test_string_concat/model.onnx differ diff --git a/onnx/backend/test/data/node/test_string_concat/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_string_concat/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..69329cd7227 --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +2abc2defBx \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_string_concat/test_data_set_0/input_1.pb b/onnx/backend/test/data/node/test_string_concat/test_data_set_0/input_1.pb new file mode 100644 index 00000000000..6cc32e25e88 --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat/test_data_set_0/input_1.pb @@ -0,0 +1 @@ +2.com2.netBy \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_string_concat/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_string_concat/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..60d415e647c --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat/test_data_set_0/output_0.pb @@ -0,0 +1 @@ +2abc.com2def.netBresult \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_string_concat_broadcasting/model.onnx b/onnx/backend/test/data/node/test_string_concat_broadcasting/model.onnx new file mode 100644 index 00000000000..cde5cc9dd91 Binary files /dev/null and b/onnx/backend/test/data/node/test_string_concat_broadcasting/model.onnx differ diff --git a/onnx/backend/test/data/node/test_string_concat_broadcasting/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_string_concat_broadcasting/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..8dac795560a --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat_broadcasting/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +2cat2dog2snakeBx \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_string_concat_broadcasting/test_data_set_0/input_1.pb b/onnx/backend/test/data/node/test_string_concat_broadcasting/test_data_set_0/input_1.pb new file mode 100644 index 00000000000..9b18dcdac00 --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat_broadcasting/test_data_set_0/input_1.pb @@ -0,0 +1 @@ +2sBy \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_string_concat_broadcasting/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_string_concat_broadcasting/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..2c1a6b57919 --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat_broadcasting/test_data_set_0/output_0.pb @@ -0,0 +1 @@ +2cats2dogs2snakesBresult \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_string_concat_empty_string/model.onnx b/onnx/backend/test/data/node/test_string_concat_empty_string/model.onnx new file mode 100644 index 00000000000..93a0177f30f Binary files /dev/null and b/onnx/backend/test/data/node/test_string_concat_empty_string/model.onnx differ diff --git a/onnx/backend/test/data/node/test_string_concat_empty_string/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_string_concat_empty_string/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..c15c43b6695 Binary files /dev/null and b/onnx/backend/test/data/node/test_string_concat_empty_string/test_data_set_0/input_0.pb differ diff --git a/onnx/backend/test/data/node/test_string_concat_empty_string/test_data_set_0/input_1.pb b/onnx/backend/test/data/node/test_string_concat_empty_string/test_data_set_0/input_1.pb new file mode 100644 index 00000000000..1c56d03d297 Binary files /dev/null and b/onnx/backend/test/data/node/test_string_concat_empty_string/test_data_set_0/input_1.pb differ diff --git a/onnx/backend/test/data/node/test_string_concat_empty_string/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_string_concat_empty_string/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..dac25ab6f68 --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat_empty_string/test_data_set_0/output_0.pb @@ -0,0 +1 @@ +2abc2abcBresult \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_string_concat_utf8/model.onnx b/onnx/backend/test/data/node/test_string_concat_utf8/model.onnx new file mode 100644 index 00000000000..7ddbbf85aa9 Binary files /dev/null and b/onnx/backend/test/data/node/test_string_concat_utf8/model.onnx differ diff --git a/onnx/backend/test/data/node/test_string_concat_utf8/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_string_concat_utf8/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..1a8f4f43dbf --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat_utf8/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +2的2中Bx \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_string_concat_utf8/test_data_set_0/input_1.pb b/onnx/backend/test/data/node/test_string_concat_utf8/test_data_set_0/input_1.pb new file mode 100644 index 00000000000..8ff819ba45b --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat_utf8/test_data_set_0/input_1.pb @@ -0,0 +1 @@ +2的2中By \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_string_concat_utf8/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_string_concat_utf8/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..8975d577ee8 --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat_utf8/test_data_set_0/output_0.pb @@ -0,0 +1 @@ +2的的2中中Bresult \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_string_concat_zero_dimensional/model.onnx b/onnx/backend/test/data/node/test_string_concat_zero_dimensional/model.onnx new file mode 100644 index 00000000000..610ba8e1357 Binary files /dev/null and b/onnx/backend/test/data/node/test_string_concat_zero_dimensional/model.onnx differ diff --git a/onnx/backend/test/data/node/test_string_concat_zero_dimensional/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_string_concat_zero_dimensional/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..7d66adf5979 --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat_zero_dimensional/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +2catBx \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_string_concat_zero_dimensional/test_data_set_0/input_1.pb b/onnx/backend/test/data/node/test_string_concat_zero_dimensional/test_data_set_0/input_1.pb new file mode 100644 index 00000000000..6d237bd00f1 --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat_zero_dimensional/test_data_set_0/input_1.pb @@ -0,0 +1 @@ +2sBy \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_string_concat_zero_dimensional/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_string_concat_zero_dimensional/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..c4842eb6b11 --- /dev/null +++ b/onnx/backend/test/data/node/test_string_concat_zero_dimensional/test_data_set_0/output_0.pb @@ -0,0 +1 @@ +2catsBresult \ No newline at end of file diff --git a/onnx/defs/nn/defs.cc b/onnx/defs/nn/defs.cc index 410ec4da050..5db37f02340 100644 --- a/onnx/defs/nn/defs.cc +++ b/onnx/defs/nn/defs.cc @@ -2229,75 +2229,6 @@ ONNX_OPERATOR_SET_SCHEMA( }) .SetDoc(TfIdfVectorizer_ver9_doc)); -static const char* StringNormalizer_ver10_doc = R"DOC( -StringNormalization performs string operations for basic cleaning. -This operator has only one input (denoted by X) and only one output -(denoted by Y). This operator first examines the elements in the X, -and removes elements specified in "stopwords" attribute. -After removing stop words, the intermediate result can be further lowercased, -uppercased, or just returned depending the "case_change_action" attribute. -This operator only accepts [C]- and [1, C]-tensor. -If all elements in X are dropped, the output will be the empty value of string tensor with shape [1] -if input shape is [C] and shape [1, 1] if input shape is [1, C]. -)DOC"; - -ONNX_OPERATOR_SET_SCHEMA( - StringNormalizer, - 10, - OpSchema() - .Input(0, "X", "UTF-8 strings to normalize", "tensor(string)") - .Output(0, "Y", "UTF-8 Normalized strings", "tensor(string)") - .Attr( - std::string("case_change_action"), - std::string("string enum that cases output to be lowercased/uppercases/unchanged." - " Valid values are \"LOWER\", \"UPPER\", \"NONE\". Default is \"NONE\""), - AttributeProto::STRING, - std::string("NONE")) - .Attr( - std::string("is_case_sensitive"), - std::string("Boolean. Whether the identification of stop words in X is case-sensitive. Default is false"), - AttributeProto::INT, - static_cast(0)) - .Attr( - "stopwords", - "List of stop words. If not set, no word would be removed from X.", - AttributeProto::STRINGS, - OPTIONAL_VALUE) - .Attr( - "locale", - "Environment dependent string that denotes the locale according to which output strings needs to be upper/lowercased." - "Default en_US or platform specific equivalent as decided by the implementation.", - AttributeProto::STRING, - OPTIONAL_VALUE) - .SetDoc(StringNormalizer_ver10_doc) - .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { - auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type(); - output_elem_type->set_elem_type(TensorProto::STRING); - if (!hasInputShape(ctx, 0)) { - return; - } - TensorShapeProto output_shape; - auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); - auto dim_size = input_shape.dim_size(); - // Last axis dimension is unknown if we have stop-words since we do - // not know how many stop-words are dropped - if (dim_size == 1) { - // Unknown output dimension - output_shape.add_dim(); - } else if (dim_size == 2) { - // Copy B-dim - auto& b_dim = input_shape.dim(0); - if (!b_dim.has_dim_value() || b_dim.dim_value() != 1) { - fail_shape_inference("Input shape must have either [C] or [1,C] dimensions where C > 0"); - } - *output_shape.add_dim() = b_dim; - output_shape.add_dim(); - } else { - fail_shape_inference("Input shape must have either [C] or [1,C] dimensions where C > 0"); - } - updateOutputShape(ctx, 0, output_shape); - })); - static const char* mvn_ver13_doc = R"DOC( A MeanVarianceNormalization Function: Perform mean variance normalization on the input tensor X using formula: `(X-EX)/sqrt(E(X-EX)^2)` diff --git a/onnx/defs/operator_sets.h b/onnx/defs/operator_sets.h index a83adfd194f..0857ac37601 100644 --- a/onnx/defs/operator_sets.h +++ b/onnx/defs/operator_sets.h @@ -1105,6 +1105,7 @@ class OpSet_Onnx_ver19 { class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, GridSample); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, Gelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, ConstantOfShape); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, StringConcat); // Iterate over schema from ai.onnx version 20 class OpSet_Onnx_ver20 { @@ -1113,6 +1114,7 @@ class OpSet_Onnx_ver20 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); } }; diff --git a/onnx/defs/text/defs.cc b/onnx/defs/text/defs.cc new file mode 100644 index 00000000000..935133dadd3 --- /dev/null +++ b/onnx/defs/text/defs.cc @@ -0,0 +1,104 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "onnx/defs/schema.h" + +namespace ONNX_NAMESPACE { +static const char* StringConcat_doc = + R"DOC(StringConcat concatenates string tensors elementwise (with NumPy-style broadcasting support))DOC"; +ONNX_OPERATOR_SET_SCHEMA( + StringConcat, + 20, + OpSchema() + .Input( + 0, + "X", + "Tensor to prepend in concatenation", + "T", + OpSchema::Single, + true, + 1, + OpSchema::NonDifferentiable) + .Input(1, "Y", "Tensor to append in concatenation", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable) + .Output(0, "Z", "Concatenated string tensor", "T", OpSchema::Single, true, 1, OpSchema::NonDifferentiable) + .TypeConstraint("T", {"tensor(string)"}, "Inputs and outputs must be UTF-8 strings") + .SetDoc(StringConcat_doc) + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (hasNInputShapes(ctx, 2)) + bidirectionalBroadcastShapeInference( + ctx.getInputType(0)->tensor_type().shape(), + ctx.getInputType(1)->tensor_type().shape(), + *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()); + })); + +static const char* StringNormalizer_ver10_doc = R"DOC( +StringNormalization performs string operations for basic cleaning. +This operator has only one input (denoted by X) and only one output +(denoted by Y). This operator first examines the elements in the X, +and removes elements specified in "stopwords" attribute. +After removing stop words, the intermediate result can be further lowercased, +uppercased, or just returned depending the "case_change_action" attribute. +This operator only accepts [C]- and [1, C]-tensor. +If all elements in X are dropped, the output will be the empty value of string tensor with shape [1] +if input shape is [C] and shape [1, 1] if input shape is [1, C]. +)DOC"; + +ONNX_OPERATOR_SET_SCHEMA( + StringNormalizer, + 10, + OpSchema() + .Input(0, "X", "UTF-8 strings to normalize", "tensor(string)") + .Output(0, "Y", "UTF-8 Normalized strings", "tensor(string)") + .Attr( + std::string("case_change_action"), + std::string("string enum that cases output to be lowercased/uppercases/unchanged." + " Valid values are \"LOWER\", \"UPPER\", \"NONE\". Default is \"NONE\""), + AttributeProto::STRING, + std::string("NONE")) + .Attr( + std::string("is_case_sensitive"), + std::string("Boolean. Whether the identification of stop words in X is case-sensitive. Default is false"), + AttributeProto::INT, + static_cast(0)) + .Attr( + "stopwords", + "List of stop words. If not set, no word would be removed from X.", + AttributeProto::STRINGS, + OPTIONAL_VALUE) + .Attr( + "locale", + "Environment dependent string that denotes the locale according to which output strings needs to be upper/lowercased." + "Default en_US or platform specific equivalent as decided by the implementation.", + AttributeProto::STRING, + OPTIONAL_VALUE) + .SetDoc(StringNormalizer_ver10_doc) + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type(); + output_elem_type->set_elem_type(TensorProto::STRING); + if (!hasInputShape(ctx, 0)) { + return; + } + TensorShapeProto output_shape; + auto& input_shape = ctx.getInputType(0)->tensor_type().shape(); + auto dim_size = input_shape.dim_size(); + // Last axis dimension is unknown if we have stop-words since we do + // not know how many stop-words are dropped + if (dim_size == 1) { + // Unknown output dimension + output_shape.add_dim(); + } else if (dim_size == 2) { + // Copy B-dim + auto& b_dim = input_shape.dim(0); + if (!b_dim.has_dim_value() || b_dim.dim_value() != 1) { + fail_shape_inference("Input shape must have either [C] or [1,C] dimensions where C > 0"); + } + *output_shape.add_dim() = b_dim; + output_shape.add_dim(); + } else { + fail_shape_inference("Input shape must have either [C] or [1,C] dimensions where C > 0"); + } + updateOutputShape(ctx, 0, output_shape); + })); +} // namespace ONNX_NAMESPACE diff --git a/onnx/reference/ops/_op_list.py b/onnx/reference/ops/_op_list.py index 9d8b8f88536..2ae00f81a23 100644 --- a/onnx/reference/ops/_op_list.py +++ b/onnx/reference/ops/_op_list.py @@ -211,6 +211,7 @@ from onnx.reference.ops.op_sqrt import Sqrt from onnx.reference.ops.op_squeeze import Squeeze_1, Squeeze_11, Squeeze_13 from onnx.reference.ops.op_stft import STFT +from onnx.reference.ops.op_string_concat import StringConcat from onnx.reference.ops.op_string_normalizer import StringNormalizer from onnx.reference.ops.op_sub import Sub from onnx.reference.ops.op_sum import Sum diff --git a/onnx/reference/ops/op_string_concat.py b/onnx/reference/ops/op_string_concat.py new file mode 100644 index 00000000000..3ef0d08210f --- /dev/null +++ b/onnx/reference/ops/op_string_concat.py @@ -0,0 +1,23 @@ +# Copyright (c) ONNX Project Contributors + +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=R0912,R0913,W0221 + +import numpy as np + +from onnx.reference.op_run import OpRun + +_acceptable_str_dtypes = ("U", "O") + + +class StringConcat(OpRun): + def _run(self, x, y): + if ( + x.dtype.kind not in _acceptable_str_dtypes + or y.dtype.kind not in _acceptable_str_dtypes + ): + raise TypeError( + f"Inputs must be string tensors, received dtype {x.dtype} and {y.dtype}" + ) + # As per onnx/mapping.py, object numpy dtype corresponds to TensorProto.STRING + return (np.char.add(x.astype(np.str_), y.astype(np.str_)).astype(object),) diff --git a/onnx/reference/reference_evaluator.py b/onnx/reference/reference_evaluator.py index 3a4ddfe01c4..bc989406aa7 100644 --- a/onnx/reference/reference_evaluator.py +++ b/onnx/reference/reference_evaluator.py @@ -149,7 +149,7 @@ def _run(self, x, alpha=None): # type: ignore The class name must be the same. The domain does not have to be specified for the default domain. However, by default, class `OpRun` will load the most recent for this operator. - It can be explicirely specified by adding static attribute + It can be explicitly specified by adding static attribute `op_schema` of type :class:`OpSchema `. diff --git a/onnx/test/automatic_upgrade_test.py b/onnx/test/automatic_upgrade_test.py index 0277e068a79..acd55749bff 100644 --- a/onnx/test/automatic_upgrade_test.py +++ b/onnx/test/automatic_upgrade_test.py @@ -1789,6 +1789,14 @@ def test_GroupNormalization(self) -> None: attrs={"epsilon": 1e-5, "num_groups": 2}, ) + def test_StringConcat(self) -> None: + self._test_op_upgrade( + "StringConcat", + 20, + [[2, 3], [2, 3]], + [[2, 3]], + ) + def test_ops_tested(self) -> None: all_schemas = onnx.defs.get_all_schemas() all_op_names = [schema.name for schema in all_schemas if schema.domain == ""] diff --git a/onnx/test/reference_evaluator_test.py b/onnx/test/reference_evaluator_test.py index c58ea49d620..fc9534cd150 100644 --- a/onnx/test/reference_evaluator_test.py +++ b/onnx/test/reference_evaluator_test.py @@ -3763,6 +3763,26 @@ def test_constant_of_shape_castlike(self): self.assertEqual(got.dtype, np.uint16) assert_allclose(np.array(1, dtype=np.uint16), got) + @parameterized.parameterized.expand( + [ + (["abc", "def"], [".com", ".net"], ["abc.com", "def.net"], (2,)), + (["cat", "dog", "snake"], ["s"], ["cats", "dogs", "snakes"], (3,)), + ("cat", "s", "cats", ()), + (["a", "ß", "y"], ["a", "ß", "y"], ["aa", "ßß", "yy"], (3,)), + ] + ) + def test_string_concat(self, a, b, expected, expected_shape): + A = make_tensor_value_info("A", TensorProto.STRING, None) + B = make_tensor_value_info("B", TensorProto.STRING, None) + Y = make_tensor_value_info("Y", TensorProto.STRING, None) + node = make_node("StringConcat", inputs=["A", "B"], outputs=["Y"]) + model = make_model(make_graph([node], "g", [A, B], [Y])) + ref = ReferenceEvaluator(model) + result, *_ = ref.run(None, {"A": np.array(a), "B": np.array(b)}) + np.testing.assert_array_equal(result, expected) + self.assertEqual(result.dtype.kind, "O") + self.assertEqual(result.shape, expected_shape) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx/test/shape_inference_test.py b/onnx/test/shape_inference_test.py index 2615bb56186..b44345258cd 100644 --- a/onnx/test/shape_inference_test.py +++ b/onnx/test/shape_inference_test.py @@ -1590,6 +1590,38 @@ def test_squeeze(self, _, version) -> None: opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)], ) + @parameterized.expand(all_versions_for("StringConcat")) + def test_stringconcat(self, _, version) -> None: + graph = self._make_graph( + [ + ("x", TensorProto.STRING, (2, 3, 4)), + ("y", TensorProto.STRING, (2, 3, 4)), + ], + [make_node("StringConcat", ["x", "y"], "z")], + [], + ) + self._assert_inferred( + graph, + [make_tensor_value_info("z", TensorProto.STRING, (2, 3, 4))], + opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)], + ) + + @parameterized.expand(all_versions_for("StringConcat")) + def test_stringconcat_broadcasting(self, _, version) -> None: + graph = self._make_graph( + [ + ("x", TensorProto.STRING, (2, 3, 4)), + ("y", TensorProto.STRING, (1, 3, 1)), + ], + [make_node("StringConcat", ["x", "y"], "z")], + [], + ) + self._assert_inferred( + graph, + [make_tensor_value_info("z", TensorProto.STRING, (2, 3, 4))], + opset_imports=[helper.make_opsetid(ONNX_DOMAIN, version)], + ) + def test_unsqueeze_regular(self) -> None: graph = self._make_graph( [("x", TensorProto.FLOAT, (3, 2)), ("axes", TensorProto.INT64, (4,))], diff --git a/onnx/test/test_backend_onnxruntime.py b/onnx/test/test_backend_onnxruntime.py index 9a87309c15a..ec094ea853a 100644 --- a/onnx/test/test_backend_onnxruntime.py +++ b/onnx/test/test_backend_onnxruntime.py @@ -249,6 +249,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): "|equal" "|identity" "|reshape" + "|string_concat" "|gelu" ")" )