From fd3ab73a66a9c4de4095dcedc7a43965ec2d21a7 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Thu, 23 Jan 2020 14:10:23 -0800 Subject: [PATCH 1/2] Clarify split supports zero length splits (#2544) * Fixed specs for split * split zero size * fix for test model * Trigger build * clarify spec for splittosequence * Trigger build * fix for feedback * Trigger build Co-authored-by: G. Ramalingam --- docs/Changelog.md | 4 +-- docs/Operators.md | 25 ++++++++++++-- docs/TestCoverage.md | 21 +++++++++++- onnx/backend/test/case/model/sequence.py | 31 ++++++++++++++++++ onnx/backend/test/case/node/split.py | 15 +++++++++ .../test_split_zero_size_splits/model.onnx | Bin 0 -> 209 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 13 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 16 bytes .../test_data_set_0/output_1.pb | Bin 0 -> 16 bytes .../test_data_set_0/output_2.pb | Bin 0 -> 16 bytes .../simple/test_sequence_model8/model.onnx | Bin 0 -> 136 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 9 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 17 bytes onnx/defs/sequence/defs.cc | 3 +- onnx/defs/tensor/defs.cc | 2 +- 15 files changed, 93 insertions(+), 8 deletions(-) create mode 100644 onnx/backend/test/data/node/test_split_zero_size_splits/model.onnx create mode 100644 onnx/backend/test/data/node/test_split_zero_size_splits/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_split_zero_size_splits/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_split_zero_size_splits/test_data_set_0/output_1.pb create mode 100644 onnx/backend/test/data/node/test_split_zero_size_splits/test_data_set_0/output_2.pb create mode 100644 onnx/backend/test/data/simple/test_sequence_model8/model.onnx create mode 100644 onnx/backend/test/data/simple/test_sequence_model8/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/simple/test_sequence_model8/test_data_set_0/output_0.pb diff --git a/docs/Changelog.md b/docs/Changelog.md index f7f96988874..b66b91392da 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -13532,7 +13532,7 @@ This version of the operator has been available since version 11 of the default
axis : int (default is 0)
Which axis to split on. A negative value means counting dimensions from the back. Accepted range is [-rank, rank-1] where r = rank(input).
split : list of ints
-
length of each output
+
length of each output. Values should be >= 0.
#### Inputs @@ -13588,7 +13588,7 @@ This version of the operator has been available since version 11 of the default
input : T
The tensor to split
split (optional) : I
-
Length of each output. It can be either a scalar(tensor of empty shape), or a 1-D tensor. All values must be positive.
+
Length of each output. It can be either a scalar(tensor of empty shape), or a 1-D tensor. All values must be >= 0.
#### Outputs diff --git a/docs/Operators.md b/docs/Operators.md index f32460d67cf..30504ea86e2 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -17033,7 +17033,7 @@ Other versions of this operator: Split-1, axis : int (default is 0)
Which axis to split on. A negative value means counting dimensions from the back. Accepted range is [-rank, rank-1] where r = rank(input).
split : list of ints
-
length of each output
+
length of each output. Values should be >= 0.
#### Inputs @@ -17157,6 +17157,27 @@ expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_s +
+zero_size_splits + +```python +input = np.array([]).astype(np.float32) + +# Split emtpy tensor to tensors of size zero +node = onnx.helper.make_node( + 'Split', + inputs=['input'], + outputs=['output_1', 'output_2', 'output_3'], + split=[0, 0, 0] +) + +expected_outputs = [np.array([]).astype(np.float32), np.array([]).astype(np.float32), np.array([]).astype(np.float32)] +expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_split_zero_size_splits') +``` + +
+ + ###
**SplitToSequence** Split a tensor into a sequence of tensors, along the specified @@ -17189,7 +17210,7 @@ This version of the operator has been available since version 11 of the default
input : T
The tensor to split
split (optional) : I
-
Length of each output. It can be either a scalar(tensor of empty shape), or a 1-D tensor. All values must be positive.
+
Length of each output. It can be either a scalar(tensor of empty shape), or a 1-D tensor. All values must be >= 0.
#### Outputs diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md index b3e1d2a9028..63af64c57c9 100644 --- a/docs/TestCoverage.md +++ b/docs/TestCoverage.md @@ -9571,7 +9571,7 @@ expect(node, inputs=[x], outputs=[y], ### Split -There are 3 test cases, listed as following: +There are 4 test cases, listed as following:
1d @@ -9662,6 +9662,25 @@ expected_outputs = [np.array([1., 2.]).astype(np.float32), np.array([3., 4., 5., expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_split_variable_parts_default_axis') ``` +
+
+zero_size_splits + +```python +input = np.array([]).astype(np.float32) + +# Split emtpy tensor to tensors of size zero +node = onnx.helper.make_node( + 'Split', + inputs=['input'], + outputs=['output_1', 'output_2', 'output_3'], + split=[0, 0, 0] +) + +expected_outputs = [np.array([]).astype(np.float32), np.array([]).astype(np.float32), np.array([]).astype(np.float32)] +expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_split_zero_size_splits') +``` +
diff --git a/onnx/backend/test/case/model/sequence.py b/onnx/backend/test/case/model/sequence.py index 1f20a026c5a..c34c3c440a0 100644 --- a/onnx/backend/test/case/model/sequence.py +++ b/onnx/backend/test/case/model/sequence.py @@ -238,3 +238,34 @@ def make_graph( [pos_at]) model = onnx.helper.make_model(graph, producer_name='backend-test') expect(model, inputs=[x], outputs=[out], name="test_sequence_model7") + + #8th testcase - split zero length + seq_split_node = onnx.helper.make_node('SplitToSequence', ['X'], ['seq_1']) + seq_len_node = onnx.helper.make_node('SequenceLength', ['seq_1'], ['len']) + + tensor_shape = [] # type: ignore + len_shape = [] # type: ignore + + x = np.array([]).astype(np.float32) + out_len = np.int64(0) + + graph = onnx.helper.make_graph( + nodes=[seq_split_node, seq_len_node], + name='Sequence', + inputs=[ + onnx.helper.make_tensor_value_info( + 'X', + onnx.TensorProto.FLOAT, + tensor_shape), # type: ignore + onnx.helper.make_tensor_value_info( + 'Split', + onnx.TensorProto.INT64, + len_shape)], # type: ignore + outputs=[ + onnx.helper.make_tensor_value_info( + 'len', + onnx.TensorProto.INT64, + len_shape)]) # type: ignore + + model = onnx.helper.make_model(graph, producer_name='backend-test') + expect(model, inputs=[x], outputs=[out_len], name="test_sequence_model8") diff --git a/onnx/backend/test/case/node/split.py b/onnx/backend/test/case/node/split.py index fac19ee9f79..028b834764c 100644 --- a/onnx/backend/test/case/node/split.py +++ b/onnx/backend/test/case/node/split.py @@ -90,3 +90,18 @@ def export_default_values(): # type: () -> None expected_outputs = [np.array([1., 2.]).astype(np.float32), np.array([3., 4., 5., 6.]).astype(np.float32)] expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_split_variable_parts_default_axis') + + @staticmethod + def export_zero_size_splits(): # type: () -> None + input = np.array([]).astype(np.float32) + + # Split emtpy tensor to tensors of size zero + node = onnx.helper.make_node( + 'Split', + inputs=['input'], + outputs=['output_1', 'output_2', 'output_3'], + split=[0, 0, 0] + ) + + expected_outputs = [np.array([]).astype(np.float32), np.array([]).astype(np.float32), np.array([]).astype(np.float32)] + expect(node, inputs=[input], outputs=[y for y in expected_outputs], name='test_split_zero_size_splits') diff --git a/onnx/backend/test/data/node/test_split_zero_size_splits/model.onnx b/onnx/backend/test/data/node/test_split_zero_size_splits/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..163f43d9409fb80b373f84251bcaab393c8967d4 GIT binary patch literal 209 zcmd;J6XHoqOwLZtOVKS!EiSRz#mHsH#hRH{P+B6ykzZN@q~i_YEF(C}Scx^bASbg# zOMr{D7{qX30Ko-}>_XBYgW|!0@l~ls`SHb>RjCkuag;F3a4s$m4n`psE+!6!Brz_8 TRoKOigs_VnBa1sR3Ge^_kOwp+ literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_split_zero_size_splits/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_split_zero_size_splits/test_data_set_0/input_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..16d4ac7a398ab96299391c7d2fa5fefc962aeae7 GIT binary patch literal 13 Ucmd;J5MXp-&CDw(E%9Oi01prX+W-In literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_split_zero_size_splits/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_split_zero_size_splits/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..14aa56f56b63a9f129eb665f27be7c5e25d52b1a GIT binary patch literal 16 Xcmd;J5MXrT$S*A^C@qOM^kM)28?^)P literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_split_zero_size_splits/test_data_set_0/output_1.pb b/onnx/backend/test/data/node/test_split_zero_size_splits/test_data_set_0/output_1.pb new file mode 100644 index 0000000000000000000000000000000000000000..2f700936b76f6cde4cd35d91c538ad0008444a54 GIT binary patch literal 16 Xcmd;J5MXrT$S*A^C@qOM@?ro08@L1T literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_split_zero_size_splits/test_data_set_0/output_2.pb b/onnx/backend/test/data/node/test_split_zero_size_splits/test_data_set_0/output_2.pb new file mode 100644 index 0000000000000000000000000000000000000000..f45e84e61dcfeb8d00bdd20a5b3d89d0c7f2a72a GIT binary patch literal 16 Xcmd;J5MXrT$S*A^C@qOM_F@148@mJX literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/simple/test_sequence_model8/model.onnx b/onnx/backend/test/data/simple/test_sequence_model8/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b6bf227d498d5d3b669d883aff9e7731c2a3d78c GIT binary patch literal 136 zcmd;J6XHoqOwLZtOVKS!EiSPt;*#cKj1XciPA!Z#RN@aV$jK}T$q!B~EKSWzPUVu} zg2)Il=cMK-@xjD>QuESFGK4r_{3vdaA#7YM9E?H?QT$x2U~|9%>_QAlyj&pdFs>7m G01p5Xfgtq& literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/simple/test_sequence_model8/test_data_set_0/input_0.pb b/onnx/backend/test/data/simple/test_sequence_model8/test_data_set_0/input_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..65193ec95dc61b964dff17356b743dc850beaa25 GIT binary patch literal 9 Qcmd;J5MXp-jPPOr00XfA{{R30 literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/simple/test_sequence_model8/test_data_set_0/output_0.pb b/onnx/backend/test/data/simple/test_sequence_model8/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..8a36bae46aec6aa4fa64b05be9646174a2a0fead GIT binary patch literal 17 ScmWe&cVf;-&GX`5fC2y)w*l?| literal 0 HcmV?d00001 diff --git a/onnx/defs/sequence/defs.cc b/onnx/defs/sequence/defs.cc index ed12f763ccc..571729cfb7b 100644 --- a/onnx/defs/sequence/defs.cc +++ b/onnx/defs/sequence/defs.cc @@ -346,8 +346,7 @@ ONNX_OPERATOR_SET_SCHEMA( 1, "split", "Length of each output. " - "It can be either a scalar(tensor of empty shape), or a 1-D tensor. " - "All values must be positive. ", + "It can be either a scalar(tensor of empty shape), or a 1-D tensor. All values must be >= 0. ", "I", OpSchema::Optional) .Output( diff --git a/onnx/defs/tensor/defs.cc b/onnx/defs/tensor/defs.cc index 880125cdf90..fa21f7fc9f4 100644 --- a/onnx/defs/tensor/defs.cc +++ b/onnx/defs/tensor/defs.cc @@ -399,7 +399,7 @@ ONNX_OPERATOR_SET_SCHEMA( "where r = rank(input).", AttributeProto::INT, static_cast(0)) - .Attr("split", "length of each output", AttributeProto::INTS, OPTIONAL) + .Attr("split", "length of each output. Values should be >= 0.", AttributeProto::INTS, OPTIONAL) .SetDoc(Split_ver11_doc) .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { for (int i = 0; i < static_cast(ctx.getNumOutputs()); ++i) { From bbd604ef01472ba8c4e88b2da9bd3dbf68c45487 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Fri, 24 Jan 2020 13:05:38 -0800 Subject: [PATCH 2/2] Add Einsum op (#2504) * added einsum * defenition updates * CI fix * fix for CI * added comments * CI fix * CI fix * CI fix * added comments * fix for type annotation * fixed ref implementation * CI fix * Update einsum.py * Update defs.cc * Update einsum.py * Update einsum.py * Update einsum.py * Update einsum.py * fix for feedback * feedback updates * fix CI * Update defs.cc * feedback for comments * feedback updates * updated doc formatting * updated doc * formatting * test fix * fix for feedback * Trigger build * Update onnx/defs/math/defs.cc Co-Authored-By: Jonny Shipton * Update Changelog.md * Update Operators.md * trigger build * merge * merge * fix for feedback * fix for feedback * fix for feedback * added shape_infer test * Trigger build Co-authored-by: Jonny Shipton Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> --- docs/Changelog.md | 57 ++++++ docs/Operators.md | 168 ++++++++++++++++++ docs/TestCoverage.md | 103 ++++++++++- onnx/backend/test/case/node/einsum.py | 97 ++++++++++ .../test_einsum_batch_diagonal/model.onnx | 13 ++ .../test_data_set_0/input_0.pb | Bin 0 -> 614 bytes .../test_data_set_0/output_0.pb | 1 + .../node/test_einsum_batch_matmul/model.onnx | 20 +++ .../test_data_set_0/input_0.pb | Bin 0 -> 254 bytes .../test_data_set_0/input_1.pb | 3 + .../test_data_set_0/output_0.pb | Bin 0 -> 334 bytes .../node/test_einsum_inner_prod/model.onnx | Bin 0 -> 132 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/input_1.pb | 1 + .../test_data_set_0/output_0.pb | 1 + .../test/data/node/test_einsum_sum/model.onnx | 12 ++ .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | 1 + .../node/test_einsum_transpose/model.onnx | 12 ++ .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | 1 + onnx/defs/math/defs.cc | 141 +++++++++++++++ onnx/defs/operator_sets.h | 2 + onnx/test/shape_inference_test.py | 46 +++++ 24 files changed, 681 insertions(+), 1 deletion(-) create mode 100644 onnx/backend/test/case/node/einsum.py create mode 100644 onnx/backend/test/data/node/test_einsum_batch_diagonal/model.onnx create mode 100644 onnx/backend/test/data/node/test_einsum_batch_diagonal/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_einsum_batch_diagonal/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_einsum_batch_matmul/model.onnx create mode 100644 onnx/backend/test/data/node/test_einsum_batch_matmul/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_einsum_batch_matmul/test_data_set_0/input_1.pb create mode 100644 onnx/backend/test/data/node/test_einsum_batch_matmul/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_einsum_inner_prod/model.onnx create mode 100644 onnx/backend/test/data/node/test_einsum_inner_prod/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_einsum_inner_prod/test_data_set_0/input_1.pb create mode 100644 onnx/backend/test/data/node/test_einsum_inner_prod/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_einsum_sum/model.onnx create mode 100644 onnx/backend/test/data/node/test_einsum_sum/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_einsum_sum/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_einsum_transpose/model.onnx create mode 100644 onnx/backend/test/data/node/test_einsum_transpose/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_einsum_transpose/test_data_set_0/output_0.pb diff --git a/docs/Changelog.md b/docs/Changelog.md index b66b91392da..c91475736da 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -13965,6 +13965,63 @@ This version of the operator has been available since version 12 of the default
Constrain input and output types to all numeric tensors.
+### **Einsum-12** + + An einsum of the form ```term1, term2 -> output-term``` produces an output tensor using the following equation + + ```output[output-term] = reduce-sum( input1[term1] * input2[term] )``` + + where the reduce-sum performs a summation over all the indices occurring in in the input terms (term1, term2) + that do not occur in the output-term. + + The Einsum operator evaluates algebraic tensor operations on a sequence of tensors, using the Einstein summation + convention. The equation string contains a comma-separated sequence of lower case letters. Each term corresponds to + an operand tensor, and the characters within the terms correspond to operands dimensions. + + This sequence may be followed by "->" to separate the left and right hand side of the equation. + If the equation contains "->" followed by the right-hand side, the explicit (not classical) form of the Einstein + summation is performed, and the right-hand side indices indicate output tensor dimensions. In other cases, + output indices are (implicitly) set to the alphabetically sorted sequence of indices appearing exactly once in the + equation. + + When a dimension character is repeated in the left-hand side, it represents summation along the dimension. + + The equation may contain ellipsis ("...") to enable broadcasting. Ellipsis must indicate a fixed number of dimensions. + The right-hand side may contain exactly one ellipsis. In implicit mode, the ellipsis dimensions are set to the + beginning of the output. The equation string may contain space (U+0020) character. + +#### Version + +This version of the operator has been available since version 12 of the default ONNX operator set. + +#### Attributes + +
+
equation : string (required)
+
Einsum expression string.
+
+ +#### Inputs (1 - ∞) + +
+
Inputs (variadic) : T
+
Operands
+
+ +#### Outputs + +
+
Output : T
+
Output tensor
+
+ +#### Type Constraints + +
+
T : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double)
+
Constrain input and output types to all numerical tensor types.
+
+ ### **MaxPool-12** MaxPool consumes an input tensor X and applies max pooling across diff --git a/docs/Operators.md b/docs/Operators.md index 30504ea86e2..b50b35aac31 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -37,6 +37,7 @@ * Det * Div * Dropout + * Einsum * Elu * Equal * Erf @@ -4397,6 +4398,173 @@ expect(node, inputs=[X], outputs=[Y, Y_Scale, Y_ZeroPoint], +### **Einsum** + + An einsum of the form ```term1, term2 -> output-term``` produces an output tensor using the following equation + + ```output[output-term] = reduce-sum( input1[term1] * input2[term] )``` + + where the reduce-sum performs a summation over all the indices occurring in in the input terms (term1, term2) + that do not occur in the output-term. + + The Einsum operator evaluates algebraic tensor operations on a sequence of tensors, using the Einstein summation + convention. The equation string contains a comma-separated sequence of lower case letters. Each term corresponds to + an operand tensor, and the characters within the terms correspond to operands dimensions. + + This sequence may be followed by "->" to separate the left and right hand side of the equation. + If the equation contains "->" followed by the right-hand side, the explicit (not classical) form of the Einstein + summation is performed, and the right-hand side indices indicate output tensor dimensions. In other cases, + output indices are (implicitly) set to the alphabetically sorted sequence of indices appearing exactly once in the + equation. + + When a dimension character is repeated in the left-hand side, it represents summation along the dimension. + + The equation may contain ellipsis ("...") to enable broadcasting. Ellipsis must indicate a fixed number of dimensions. + The right-hand side may contain exactly one ellipsis. In implicit mode, the ellipsis dimensions are set to the + beginning of the output. The equation string may contain space (U+0020) character. + +#### Version + +This version of the operator has been available since version 12 of the default ONNX operator set. + +#### Attributes + +
+
equation : string (required)
+
Einsum expression string.
+
+ +#### Inputs (1 - ∞) + +
+
Inputs (variadic) : T
+
Operands
+
+ +#### Outputs + +
+
Output : T
+
Output tensor
+
+ +#### Type Constraints + +
+
T : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double)
+
Constrain input and output types to all numerical tensor types.
+
+ + +#### Examples + +
+einsum_batch_diagonal + +```python +Eqn = '...ii ->...i' +node = onnx.helper.make_node( + 'Einsum', + inputs=['x'], + outputs=['y'], + equation=Eqn +) + +X = np.random.randn(3, 5, 5) +Z = einsum_reference_implementation(Eqn, (X,)) + +expect(node, inputs=[X], outputs=[Z], name='test_einsum_batch_diagonal') +``` + +
+ + +
+einsum_batch_matmul + +```python +Eqn = 'bij, bjk -> bik' +node = onnx.helper.make_node( + 'Einsum', + inputs=['x', 'y'], + outputs=['z'], + equation=Eqn +) + +X = np.random.randn(5, 2, 3) +Y = np.random.randn(5, 3, 4) +Z = einsum_reference_implementation(Eqn, (X, Y)) + +expect(node, inputs=[X, Y], outputs=[Z], name='test_einsum_batch_matmul') +``` + +
+ + +
+einsum_inner_prod + +```python +Eqn = 'i,i' +node = onnx.helper.make_node( + 'Einsum', + inputs=['x', 'y'], + outputs=['z'], + equation=Eqn +) + +X = np.random.randn(5) +Y = np.random.randn(5) +Z = einsum_reference_implementation(Eqn, (X, Y)) + +expect(node, inputs=[X, Y], outputs=[Z], name='test_einsum_inner_prod') +``` + +
+ + +
+einsum_sum + +```python +Eqn = 'ij->i' +node = onnx.helper.make_node( + 'Einsum', + inputs=['x'], + outputs=['y'], + equation=Eqn +) + +X = np.random.randn(3, 4) +Z = einsum_reference_implementation(Eqn, (X,)) + +expect(node, inputs=[X], outputs=[Z], name='test_einsum_sum') +``` + +
+ + +
+einsum_transpose + +```python +Eqn = 'ij->ji' +node = onnx.helper.make_node( + 'Einsum', + inputs=['x'], + outputs=['y'], + equation=Eqn +) + +X = np.random.randn(3, 4) +Y = einsum_reference_implementation(Eqn, (X,)) + +expect(node, inputs=[X], outputs=[Y], name='test_einsum_transpose') +``` + +
+ + ### **Elu** Elu takes one input data (Tensor) and produces one output data diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md index 63af64c57c9..a8c22967d85 100644 --- a/docs/TestCoverage.md +++ b/docs/TestCoverage.md @@ -5,7 +5,7 @@ * [Overall Test Coverage](#overall-test-coverage) # Node Test Coverage ## Summary -Node tests have covered 136/151 (90.07%, 5 generators excluded) common operators. +Node tests have covered 137/152 (90.13%, 5 generators excluded) common operators. Node tests have covered 0/0 (N/A) experimental operators. @@ -2610,6 +2610,107 @@ expect(node, inputs=[X], outputs=[Y, Y_Scale, Y_ZeroPoint], +### Einsum +There are 5 test cases, listed as following: +
+einsum_batch_diagonal + +```python +Eqn = '...ii ->...i' +node = onnx.helper.make_node( + 'Einsum', + inputs=['x'], + outputs=['y'], + equation=Eqn +) + +X = np.random.randn(3, 5, 5) +Z = einsum_reference_implementation(Eqn, (X,)) + +expect(node, inputs=[X], outputs=[Z], name='test_einsum_batch_diagonal') +``` + +
+
+einsum_batch_matmul + +```python +Eqn = 'bij, bjk -> bik' +node = onnx.helper.make_node( + 'Einsum', + inputs=['x', 'y'], + outputs=['z'], + equation=Eqn +) + +X = np.random.randn(5, 2, 3) +Y = np.random.randn(5, 3, 4) +Z = einsum_reference_implementation(Eqn, (X, Y)) + +expect(node, inputs=[X, Y], outputs=[Z], name='test_einsum_batch_matmul') +``` + +
+
+einsum_inner_prod + +```python +Eqn = 'i,i' +node = onnx.helper.make_node( + 'Einsum', + inputs=['x', 'y'], + outputs=['z'], + equation=Eqn +) + +X = np.random.randn(5) +Y = np.random.randn(5) +Z = einsum_reference_implementation(Eqn, (X, Y)) + +expect(node, inputs=[X, Y], outputs=[Z], name='test_einsum_inner_prod') +``` + +
+
+einsum_sum + +```python +Eqn = 'ij->i' +node = onnx.helper.make_node( + 'Einsum', + inputs=['x'], + outputs=['y'], + equation=Eqn +) + +X = np.random.randn(3, 4) +Z = einsum_reference_implementation(Eqn, (X,)) + +expect(node, inputs=[X], outputs=[Z], name='test_einsum_sum') +``` + +
+
+einsum_transpose + +```python +Eqn = 'ij->ji' +node = onnx.helper.make_node( + 'Einsum', + inputs=['x'], + outputs=['y'], + equation=Eqn +) + +X = np.random.randn(3, 4) +Y = einsum_reference_implementation(Eqn, (X,)) + +expect(node, inputs=[X], outputs=[Y], name='test_einsum_transpose') +``` + +
+ + ### Elu There are 2 test cases, listed as following:
diff --git a/onnx/backend/test/case/node/einsum.py b/onnx/backend/test/case/node/einsum.py new file mode 100644 index 00000000000..b8637ce01a6 --- /dev/null +++ b/onnx/backend/test/case/node/einsum.py @@ -0,0 +1,97 @@ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import numpy as np # type: ignore + +import onnx +from ..base import Base +from . import expect +from typing import Tuple, Text + + +def einsum_reference_implementation(Eqn, Operands): # type: (Text, Tuple[np.ndarray, ...]) -> np.ndarray + Z = np.einsum(Eqn, *Operands) + return Z + + +class Einsum(Base): + + @staticmethod + def export_einsum_transpose(): # type: () -> None + Eqn = 'ij->ji' + node = onnx.helper.make_node( + 'Einsum', + inputs=['x'], + outputs=['y'], + equation=Eqn + ) + + X = np.random.randn(3, 4) + Y = einsum_reference_implementation(Eqn, (X,)) + + expect(node, inputs=[X], outputs=[Y], name='test_einsum_transpose') + + @staticmethod + def export_einsum_sum(): # type: () -> None + Eqn = 'ij->i' + node = onnx.helper.make_node( + 'Einsum', + inputs=['x'], + outputs=['y'], + equation=Eqn + ) + + X = np.random.randn(3, 4) + Z = einsum_reference_implementation(Eqn, (X,)) + + expect(node, inputs=[X], outputs=[Z], name='test_einsum_sum') + + @staticmethod + def export_einsum_batch_diagonal(): # type: () -> None + Eqn = '...ii ->...i' + node = onnx.helper.make_node( + 'Einsum', + inputs=['x'], + outputs=['y'], + equation=Eqn + ) + + X = np.random.randn(3, 5, 5) + Z = einsum_reference_implementation(Eqn, (X,)) + + expect(node, inputs=[X], outputs=[Z], name='test_einsum_batch_diagonal') + + @staticmethod + def export_einsum_inner_prod(): # type: () -> None + Eqn = 'i,i' + node = onnx.helper.make_node( + 'Einsum', + inputs=['x', 'y'], + outputs=['z'], + equation=Eqn + ) + + X = np.random.randn(5) + Y = np.random.randn(5) + Z = einsum_reference_implementation(Eqn, (X, Y)) + + expect(node, inputs=[X, Y], outputs=[Z], name='test_einsum_inner_prod') + + @staticmethod + def export_einsum_batch_matmul(): # type: () -> None + Eqn = 'bij, bjk -> bik' + node = onnx.helper.make_node( + 'Einsum', + inputs=['x', 'y'], + outputs=['z'], + equation=Eqn + ) + + X = np.random.randn(5, 2, 3) + Y = np.random.randn(5, 3, 4) + Z = einsum_reference_implementation(Eqn, (X, Y)) + + expect(node, inputs=[X, Y], outputs=[Z], name='test_einsum_batch_matmul') diff --git a/onnx/backend/test/data/node/test_einsum_batch_diagonal/model.onnx b/onnx/backend/test/data/node/test_einsum_batch_diagonal/model.onnx new file mode 100644 index 00000000000..843c949d206 --- /dev/null +++ b/onnx/backend/test/data/node/test_einsum_batch_diagonal/model.onnx @@ -0,0 +1,13 @@ + backend-test:w ++ +xy"Einsum* +equation" ...ii ->...itest_einsum_batch_diagonalZ +x +  + + +b +y +   + +B \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_einsum_batch_diagonal/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_einsum_batch_diagonal/test_data_set_0/input_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..5e8787cd07a4646d48814d8779d038d3096d4d68 GIT binary patch literal 614 zcmV-s0-5~?0|*5O1rQ5D0eDK-1f|dA^Nu642zh)_C@IFV{KkebF z4Fzz*Klg?GOSDAyKNISZ$!l8ZKajJN2pA^6Kg)F@RY7RnKjf9=kn392KRcOCb-(BL zKelSHK%hd(zl4zi9Zd<;Kk$WqmcUf&zdAaw?1gLuz-B3Li*f4YKWvb48g{AcKiMx_ z+t9z~zrDr39CauHKn*$#@f1b(zg=~u4ykRYKRBcr@uT&}zsHAKe?x}&KMnwGSAK%` zKhN9yHc{2XKQro(k!-`KR=ca1cX@Nzpj$DlHTL;zZ3NTf!dz?zfZU~TjsUKKZWf|W5>4HzZGVq+)5StzlpyocR8NjKdi7 biktest_einsum_batch_matmulZ +x +  + + +Z +y +  + + +b +z +  + + +B \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_einsum_batch_matmul/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_einsum_batch_matmul/test_data_set_0/input_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..5b93576069cd70b02a2dd4461a356f80d927e7a8 GIT binary patch literal 254 zcmV642zh)_C@IFV{KkebF z4Fzz*Klg?GOSDAyKNISZ$!l8ZKajJN2pA^6Kg)F@RY7RnKjf9=kn392KRcOCb-(BL zKelSHK%hd(zl4zi9Zd<;Kk$WqmcUf&zdAaw?1gLuz-B3Li*f4YKWvb48g{AcKiMx_ z+t9z~zrDr39CauHKn*$#@f1b(zg=~u4ykRYKRBcr@uT&}zsHAKe?x}&KMnwGSAK%` EKk*}f!vFvP literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_einsum_batch_matmul/test_data_set_0/input_1.pb b/onnx/backend/test/data/node/test_einsum_batch_matmul/test_data_set_0/input_1.pb new file mode 100644 index 00000000000..e5944834113 --- /dev/null +++ b/onnx/backend/test/data/node/test_einsum_batch_matmul/test_data_set_0/input_1.pb @@ -0,0 +1,3 @@ + ByJ6Q?3ꐑ3?ʣhx0WE)Q1DֿN6>?EGE?$`@?) T?e?tR6:?*nL'?zn⿈Fzӿ>?(?}(t|ך?gi?^sBtrH?D~U?|-G? +ǿ7!K&? \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_einsum_batch_matmul/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_einsum_batch_matmul/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..8d382df70a04b4a9784122ba5aa3bda51caad70d GIT binary patch literal 334 zcmV-U0kQrF1qcEN1P}{C0eVWn0`Imexm;V%zdZ2Xd@RG;KRa1-yXMRQz}Q|lMwOfq zz+?kF^TTwrKhU2&c_H=oKSSOlTe4I2KPVZp?Tp6hzuuHS2+3^qzY>ycq4^>EKdGdt z?5Fm}zu>b(HUP)azldIFKWBFVz+6<{t~i|cKSot-$PV=KzcS{M)brNO;9zoTzSIC)L*zeT{p z$YrJ0zurqs|1d}az*3EGCrpR@KlnA%t?cOTKjygNF!eFzbTDO8i+~tzm$)jNc__RKv17z9;j{;z-+dDs@g*3zZi;O g=W+-HKzv90*l-C2!0_*1<4d9KKO|_O)0-s`K(iFDfdBvi literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_einsum_inner_prod/model.onnx b/onnx/backend/test/data/node/test_einsum_inner_prod/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..71095465c6c2d9d2ff7c6a602ac687cca30eaa00 GIT binary patch literal 132 zcmd;J6XHoqOwLZtOVKS!EiSRj<5J~ftl(m-6k@DWVsp*RD=y8|65`@WEi6qe$;{7F zV$RgbT)@aIBnDC*p9)qGpP84JS`=SUl%EpC57aKi#l^wFEyTjb#K8&@s6-Y>;^tzk Q5@O?G0jgkdViMp10Q+nm9smFU literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_einsum_inner_prod/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_einsum_inner_prod/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..60ec8162b62 --- /dev/null +++ b/onnx/backend/test/data/node/test_einsum_inner_prod/test_data_set_0/input_0.pb @@ -0,0 +1 @@ + BxJ(9?S,?"RQ?N1iY@=|? \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_einsum_inner_prod/test_data_set_0/input_1.pb b/onnx/backend/test/data/node/test_einsum_inner_prod/test_data_set_0/input_1.pb new file mode 100644 index 00000000000..9869f3dd629 --- /dev/null +++ b/onnx/backend/test/data/node/test_einsum_inner_prod/test_data_set_0/input_1.pb @@ -0,0 +1 @@ + ByJ(BE2g?6I_ÿYlf)g>G? \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_einsum_inner_prod/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_einsum_inner_prod/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..a9dda2a61dc --- /dev/null +++ b/onnx/backend/test/data/node/test_einsum_inner_prod/test_data_set_0/output_0.pb @@ -0,0 +1 @@ + BzJPR< \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_einsum_sum/model.onnx b/onnx/backend/test/data/node/test_einsum_sum/model.onnx new file mode 100644 index 00000000000..fd3cd270bf6 --- /dev/null +++ b/onnx/backend/test/data/node/test_einsum_sum/model.onnx @@ -0,0 +1,12 @@ + backend-test:] +$ +xy"Einsum* +equation"ij->itest_einsum_sumZ +x +   + +b +y + +  +B \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_einsum_sum/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_einsum_sum/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..86fd72d2a8c --- /dev/null +++ b/onnx/backend/test/data/node/test_einsum_sum/test_data_set_0/input_0.pb @@ -0,0 +1 @@ + BxJ`9?S,?"RQ?N1iY@=|?BE2g?6I_ÿYlf)g>G? p?KD? \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_einsum_sum/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_einsum_sum/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..a69ca927987 --- /dev/null +++ b/onnx/backend/test/data/node/test_einsum_sum/test_data_set_0/output_0.pb @@ -0,0 +1 @@ + ByJ[.܋ @R0?H!@ڻ}? \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_einsum_transpose/model.onnx b/onnx/backend/test/data/node/test_einsum_transpose/model.onnx new file mode 100644 index 00000000000..051b7ec0fa7 --- /dev/null +++ b/onnx/backend/test/data/node/test_einsum_transpose/model.onnx @@ -0,0 +1,12 @@ + backend-test:h +% +xy"Einsum* +equation"ij->jitest_einsum_transposeZ +x +   + +b +y +   + +B \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_einsum_transpose/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_einsum_transpose/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..86fd72d2a8c --- /dev/null +++ b/onnx/backend/test/data/node/test_einsum_transpose/test_data_set_0/input_0.pb @@ -0,0 +1 @@ + BxJ`9?S,?"RQ?N1iY@=|?BE2g?6I_ÿYlf)g>G? p?KD? \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_einsum_transpose/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_einsum_transpose/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..22d7af00ba9 --- /dev/null +++ b/onnx/backend/test/data/node/test_einsum_transpose/test_data_set_0/output_0.pb @@ -0,0 +1 @@ + ByJ`9?=|?YlS,?BEf)g>G?"RQ?2g? p?N1iY@6I_ÿKD? \ No newline at end of file diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index d5213be0582..d0908f602c8 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -2,6 +2,7 @@ // Licensed under the MIT license. #include +#include #include "onnx/defs/schema.h" #include "onnx/defs/tensor_proto_util.h" @@ -1665,4 +1666,144 @@ ONNX_OPERATOR_SET_SCHEMA( } })); +void einsumRankInference( + ONNX_NAMESPACE::InferenceContext& ctx, std::string equation) { + + const size_t numInputs = ctx.getNumInputs(); + if (numInputs < 1 || !hasNInputShapes(ctx, static_cast(numInputs))) { + return; + } + + auto* output_shape = getOutputShape(ctx, 0); + std::string left_equation; + + equation.erase(std::remove(equation.begin(), equation.end(), ' '), equation.end()); // Remove space char + auto mid_index = equation.find("->"); + if (mid_index != std::string::npos) { + // Separate right and left hand sides of the equation + left_equation = equation.substr(0, mid_index); + } else { + // No right hand side + left_equation = equation; + } + + std::string term; + size_t num_operands = 0; + size_t num_ellipsis = 0; + size_t num_ellipsis_indices = 0; + + // Parse the left-hand side + std::stringstream str(left_equation); + while(std::getline(str, term, ',')) { + auto ellipsis_index = term.find("..."); + if (ellipsis_index != std::string::npos) { + if (numInputs <= num_operands) { + fail_shape_inference("Number of input tensors does not match the operands in the equation."); + } + // If there is an ellipsis, the number of dimensions it represents must be total dim - letter dimensions + size_t rank = ctx.getInputType(num_operands)->tensor_type().shape().dim_size(); + if (num_ellipsis == 0) { + num_ellipsis_indices = rank - term.size() + 3; + } else { // ellipsis has been seen before. Check that if dimensions are compatible + if (num_ellipsis_indices != rank - term.size() + 3) { + fail_shape_inference("Ellipsis represents incompatible dimensions."); + } + } + num_ellipsis++; + } + num_operands++; + } + + if (numInputs != num_operands) { + fail_shape_inference("Number of input tensors does not match the operands in the equation."); + } + + const size_t number_of_letters = 26; + size_t num_letter_occurrences[number_of_letters] = {0}; + // Parse the provided right-hand side + if (mid_index != std::string::npos) { + std::string right_equation = equation.substr(mid_index + 2); + auto right_ellipsis_index = right_equation.find("..."); + if (right_ellipsis_index != std::string::npos) { // Right-hand side contains ellipsis + for (size_t i = 0; i < num_ellipsis; ++i) { + output_shape->add_dim(); + } + } + for (char c: right_equation) { // Add a dimension per each character in right hand equation + if (c != '.') { + output_shape->add_dim(); + } + } + } else { // Infer the dimension for right-hand side + // If there's an ellipsis, add it's corresponding dimensions + for (size_t i = 0; i < num_ellipsis_indices; i++) { + output_shape->add_dim(); + } + for (size_t i = 0; i < left_equation.size(); i++) { // Count chars that appear exactly once on left hand side + if ((left_equation.at(i) != ',') && (left_equation.at(i) != '.')) { + num_letter_occurrences[left_equation.at(i) - 'a']++; + } + } + for (size_t index = 0; index < number_of_letters; index++) { + if (num_letter_occurrences[index] == 1) { + output_shape->add_dim(); + } + } + } +} + +static const char* Einsum_ver12_doc = R"DOC( +An einsum of the form ```term1, term2 -> output-term``` produces an output tensor using the following equation + +```output[output-term] = reduce-sum( input1[term1] * input2[term] )``` + +where the reduce-sum performs a summation over all the indices occurring in in the input terms (term1, term2) +that do not occur in the output-term. + +The Einsum operator evaluates algebraic tensor operations on a sequence of tensors, using the Einstein summation +convention. The equation string contains a comma-separated sequence of lower case letters. Each term corresponds to +an operand tensor, and the characters within the terms correspond to operands dimensions. + +This sequence may be followed by "->" to separate the left and right hand side of the equation. +If the equation contains "->" followed by the right-hand side, the explicit (not classical) form of the Einstein +summation is performed, and the right-hand side indices indicate output tensor dimensions. In other cases, +output indices are (implicitly) set to the alphabetically sorted sequence of indices appearing exactly once in the +equation. + +When a dimension character is repeated in the left-hand side, it represents summation along the dimension. + +The equation may contain ellipsis ("...") to enable broadcasting. Ellipsis must indicate a fixed number of dimensions. +The right-hand side may contain exactly one ellipsis. In implicit mode, the ellipsis dimensions are set to the +beginning of the output. The equation string may contain space (U+0020) character. +)DOC"; + +ONNX_OPERATOR_SET_SCHEMA( + Einsum, + 12, + OpSchema() + .SetDoc(Einsum_ver12_doc) + .Attr( + "equation", + "Einsum expression string.", + AttributeProto::STRING) + .Input(0, + "Inputs", + "Operands", + "T", + OpSchema::Variadic) + .Output(0, "Output", "Output tensor", "T") + .TypeConstraint( + "T", + OpSchema::all_numeric_types(), + "Constrain input and output types to all numerical tensor types.") + .TypeAndShapeInferenceFunction([](InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + std::string equation = getAttribute(ctx, "equation", ""); + if (equation.compare("") == 0) { + return; + } + einsumRankInference(ctx, equation); + })); + } // namespace ONNX_NAMESPACE diff --git a/onnx/defs/operator_sets.h b/onnx/defs/operator_sets.h index 3ff8b712431..a0f21a1935a 100644 --- a/onnx/defs/operator_sets.h +++ b/onnx/defs/operator_sets.h @@ -712,6 +712,7 @@ class OpSet_Onnx_ver11 { // Forward declarations for ai.onnx version 12 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 12, ArgMax); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 12, ArgMin); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 12, Einsum); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 12, MaxPool); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 12, ReduceMax); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 12, ReduceMin); @@ -722,6 +723,7 @@ class OpSet_Onnx_ver12 { static void ForEachSchema(std::function fn) { fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnx/test/shape_inference_test.py b/onnx/test/shape_inference_test.py index 012f3f750df..8e4ff8c35ca 100644 --- a/onnx/test/shape_inference_test.py +++ b/onnx/test/shape_inference_test.py @@ -2746,6 +2746,52 @@ def test_gatherelements_indices_missing_shape(self): # type: () -> None []) self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, None)]) # type: ignore + def test_einsum_transpose(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.FLOAT, (3, 4))], + [make_node('Einsum', ['x'], ['y'], equation='ij->ji')], + [],) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (None, None))]) # type: ignore + + def test_einsum_sum_along_dim(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.FLOAT, (3, 4))], + [make_node('Einsum', ['x'], ['y'], equation='i j->i ')], + [],) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (None, ))]) # type: ignore + + def test_einsum_ellipsis(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.FLOAT, (3, 4))], + [make_node('Einsum', ['x'], ['y'], equation='... ii ->... i')], + [],) + self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (None, None))]) # type: ignore + + def test_einsum_batch_matmul(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.FLOAT, (5, 2, 3)), + ('y', TensorProto.FLOAT, (5, 3, 4))], + [make_node('Einsum', ['x', 'y'], ['z'], equation='bij , b jk-> bik')], + [],) + self._assert_inferred(graph, [make_tensor_value_info('z', TensorProto.FLOAT, (None, None, None))]) # type: ignore + + def test_einsum_left_hand_eqn(self): # type: () -> None + graph = self._make_graph( + [('x', TensorProto.FLOAT, (2, 3)), + ('y', TensorProto.FLOAT, (3, 4))], + [make_node('Einsum', ['x', 'y'], ['z'], equation='ij,kl')], + [],) + self._assert_inferred(graph, [make_tensor_value_info('z', TensorProto.FLOAT, (None, None, None, None))]) # type: ignore + + def test_einsum_incorrect_num_inputs(self): # type: () -> None + graph = self._make_graph( + [("x", TensorProto.FLOAT, (2, 3)), + ("y", TensorProto.FLOAT, (2, 3)), + ("z", TensorProto.FLOAT, (2, 3))], + [make_node('Einsum', ['x', 'y'], ['z'], equation='i,...j, k, l-> i')], + []) + self.assertRaises(checker.ValidationError, self._inferred, graph) + if __name__ == '__main__': unittest.main()