From 12bbb48e5707e21f8094742ad689250d6a34e950 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Thu, 11 May 2023 02:25:22 -0400 Subject: [PATCH 01/23] Added implementation for GELU Signed-off-by: pranshupant --- onnx/defs/math/defs.cc | 119 +++++++++++++++++++++++++++++++++----- onnx/defs/operator_sets.h | 2 + 2 files changed, 107 insertions(+), 14 deletions(-) diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index bc3d95ac07f..9fe0e3285ef 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -470,6 +470,26 @@ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x})) ``` )DOC"; +ONNX_OPERATOR_SET_SCHEMA( + Mish, + 18, + OpSchema() + .SetDoc(mish_ver18_doc) + .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input X and output types to float tensors.") + .FunctionBody(R"ONNX( + { + Softplus_X = Softplus (X) + TanHSoftplusX = Tanh (Softplus_X) + Y = Mul (X, TanHSoftplusX) + } + )ONNX") + .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); + static const char* celu_ver12_doc = R"DOC( Continuously Differentiable Exponential Linear Units: Perform the linear unit element-wise on the input tensor X @@ -538,24 +558,95 @@ ONNX_OPERATOR_SET_SCHEMA( .SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyCelu) .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); +static const char* gelu_ver20_doc = R"DOC( +Gaussian Error Linear Units: +Gelu takes input data (Tensor) and an argument approximate, and produces one +output data (Tensor) where the function `f(x) = alpha * x for x < 0`, +`f(x) = x for x >= 0`, is applied to the data tensor elementwise. +Perform the linear unit element-wise on the input tensor X +using formula: + +``` +0.5*x*(1+erf(x/sqrt(2))) +``` + +When approximate is set to tanh + +``` +0.5*x*(1+Tanh(sqrt(2/π)*(x+0.044715*x^3))) +``` + +)DOC"; + +static std::string gelu_default_approx = "none"; + +bool BuildContextDependentFunctionBodyGelu( + const FunctionBodyBuildContext& ctx, + const OpSchema& schema, + FunctionProto& functionProto) { + auto approx_attr_proto = ctx.getAttribute("approximate"); + std::string approx = approx_attr_proto != nullptr && approx_attr_proto->has_s() + ? 
approx_attr_proto->s() + : gelu_default_alpha; + FunctionBuilder builder(functionProto); + + if (approx == "tanh") { + builder.Add(R"( + Half = Constant () + HalfCast = CastLike (Half, X) + One = Constant () + OneCast = CastLike (One, X) + TwoOverPi = Constant () + TwoOverPiCast = CastLike (TwoOverPi, X) + C0 = Constant () + C0Cast = CastLike (C0, X) + SqrtTwoOverPi = Sqrt (TwoOverPiCast) + Three = Constant () + ThreeCast = CastLike (Three, X) + CubeX = Pow ( X, ThreeCast) + XCubeC0 = Mul (C0Cast, CubeX) + XC0XCube = Sum (X, XCubeC0) + ErfApprox = Tanh (XC0XCube) + PhiApprox = Sum (OneCast, ErfApprox) + MultX = Mul (Half, X) + Y = Mul (MultX, PhiApprox) + )"); + } else { + builder.Add(R"( + Half = Constant () + HalfCast = CastLike (Half, X) + One = Constant () + OneCast = CastLike (One, X) + Two = Constant () + TwoCast = CastLike (Two, X) + SqrtTwo = Sqrt (TwoCast) + XSqrt = Div (X, SqrtTwo) + ErfXSqrt = Erf(XSqrt) + Phi = Sum (OneCast, ErfXSqrt) + MultX = Mul (Half, X) + Y = Mul (MultX, Phi) + )"); + } + schema.BuildFunction(functionProto); + return true; +} + ONNX_OPERATOR_SET_SCHEMA( - Mish, - 18, + Gelu, + 20, OpSchema() - .SetDoc(mish_ver18_doc) + .SetDoc(gelu_ver20_doc) .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input X and output types to float tensors.") - .FunctionBody(R"ONNX( - { - Softplus_X = Softplus (X) - TanHSoftplusX = Tanh (Softplus_X) - Y = Mul (X, TanHSoftplusX) - } - )ONNX") + .Attr( + "approximate", + "Type of gelu approximation algorithm to use: tanh, none(default)." + "'none': do not use approximation." 
+ "'tanh': use tanh approximation.", + AttributeProto::STRING, + gelu_default_appox) + .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float32 tensors.") + .SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyGelu) .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); static const char* Exp_ver13_doc = R"DOC( diff --git a/onnx/defs/operator_sets.h b/onnx/defs/operator_sets.h index 76473d2f9e0..a83adfd194f 100644 --- a/onnx/defs/operator_sets.h +++ b/onnx/defs/operator_sets.h @@ -1103,6 +1103,7 @@ class OpSet_Onnx_ver19 { // Forward declarations for ai.onnx version 20 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, GridSample); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, Gelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, ConstantOfShape); // Iterate over schema from ai.onnx version 20 @@ -1110,6 +1111,7 @@ class OpSet_Onnx_ver20 { public: static void ForEachSchema(std::function fn) { fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); } }; From 0cf8a6cfe9a4b709c46547df8abb1e60389ea915 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Mon, 29 May 2023 23:09:28 -0400 Subject: [PATCH 02/23] adding unit test for gelu Signed-off-by: pranshupant --- onnx/backend/test/case/node/gelu.py | 44 +++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 onnx/backend/test/case/node/gelu.py diff --git a/onnx/backend/test/case/node/gelu.py b/onnx/backend/test/case/node/gelu.py new file mode 100644 index 00000000000..7c4fe59e319 --- /dev/null +++ b/onnx/backend/test/case/node/gelu.py @@ -0,0 +1,44 @@ +# Copyright (c) ONNX Project Contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import math + +import numpy as np + +import onnx +from onnx.backend.test.case.base import Base +from onnx.backend.test.case.node import expect + + +class Gelu(Base): + @staticmethod + def export() -> None: + node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approx="tanh") + + x = np.array([-1, 0, 1]).astype(np.float32) + # expected output [-0.158808, 0., 0.841192] + cdf = 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) + y = x * cdf + expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") + + x = np.random.randn(3, 4, 5).astype(np.float32) + # expected output [2.9963627, 3.99993, 4.9999995] + cdf = 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) + y = x * cdf + expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") + + @staticmethod + def export_gelu_default() -> None: + node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) + + x = np.random.randn(-1, 0, 1).astype(np.float32) + # expected output [-0.15865526, 0., 0.84134474] + y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) + expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") + + x = np.random.randn(3, 4, 5).astype(np.float32) + # expected output [2.99595031, 3.99987331, 4.99999857] + y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) + expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") + From 273366960013898fee162521b17547fb711a938f Mon Sep 17 00:00:00 2001 From: pranshupant Date: Mon, 29 May 2023 23:47:50 -0400 Subject: [PATCH 03/23] updated attribute name from appox to approximate + added trivial automatic_upgrade_tests Signed-off-by: pranshupant --- onnx/backend/test/case/node/gelu.py | 4 ++-- onnx/defs/math/defs.cc | 19 ++++--------------- onnx/test/automatic_upgrade_test.py | 6 ++++++ 3 files changed, 12 
insertions(+), 17 deletions(-) diff --git a/onnx/backend/test/case/node/gelu.py b/onnx/backend/test/case/node/gelu.py index 7c4fe59e319..d7a6e2de270 100644 --- a/onnx/backend/test/case/node/gelu.py +++ b/onnx/backend/test/case/node/gelu.py @@ -14,7 +14,7 @@ class Gelu(Base): @staticmethod def export() -> None: - node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approx="tanh") + node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.158808, 0., 0.841192] @@ -31,7 +31,7 @@ def export() -> None: @staticmethod def export_gelu_default() -> None: node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) - + x = np.random.randn(-1, 0, 1).astype(np.float32) # expected output [-0.15865526, 0., 0.84134474] y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index 9fe0e3285ef..ceba1fd317c 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -559,22 +559,11 @@ ONNX_OPERATOR_SET_SCHEMA( .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); static const char* gelu_ver20_doc = R"DOC( -Gaussian Error Linear Units: Gelu takes input data (Tensor) and an argument approximate, and produces one -output data (Tensor) where the function `f(x) = alpha * x for x < 0`, -`f(x) = x for x >= 0`, is applied to the data tensor elementwise. -Perform the linear unit element-wise on the input tensor X -using formula: - -``` -0.5*x*(1+erf(x/sqrt(2))) -``` - -When approximate is set to tanh - -``` -0.5*x*(1+Tanh(sqrt(2/π)*(x+0.044715*x^3))) -``` +output data (Tensor) where the function `y = 0.5 * x * (1 + erf(x/sqrt(2)))` +is applied to the tensor elementwise. When the attribute "approximate" is set +to "tanh", the function `y = 0.5 * x * (1 + Tanh(sqrt(2/π)*(x+0.044715*x^3)))` +is applied to the tensor elementwise to estimate. 
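For a quick numerical sanity check of the two formulas in the doc string above, here is a minimal NumPy sketch (illustrative only; `gelu_exact` and `gelu_tanh` are hypothetical helper names, not part of this patch):

```python
import math

import numpy as np


def gelu_exact(x):
    # Exact form from the doc: y = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))


def gelu_tanh(x):
    # Approximate form: y = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))


x = np.linspace(-4.0, 4.0, 101).astype(np.float32)
# The two curves agree closely; the absolute gap stays well below 1e-2.
print(float(np.max(np.abs(gelu_exact(x) - gelu_tanh(x)))))
```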
)DOC"; diff --git a/onnx/test/automatic_upgrade_test.py b/onnx/test/automatic_upgrade_test.py index 248f338ce82..12714fea997 100644 --- a/onnx/test/automatic_upgrade_test.py +++ b/onnx/test/automatic_upgrade_test.py @@ -464,6 +464,12 @@ def test_GatherElements(self) -> None: def test_GatherND(self) -> None: self._test_op_upgrade("GatherND", 11, [[1, 2, 3], [1, 2, 3]], [[1, 2]]) + def test_Gelu_1(self) -> None: + self._test_op_upgrade("Gelu", 20, attrs={"approximate": "tanh"}) + + def test_Gelu_2(self) -> None: + self._test_op_upgrade("Gelu", 20) + def test_Gemm(self) -> None: self._test_op_upgrade("Gemm", 1, [[5, 4], [4, 3], [3]], [[5, 3]]) From 9ccb7ffa5a0a040dd2d8a743425d736dbbd0c364 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Tue, 30 May 2023 23:34:06 -0400 Subject: [PATCH 04/23] updated doc and unit test Signed-off-by: pranshupant --- onnx/backend/test/case/node/gelu.py | 8 +++----- onnx/defs/math/defs.cc | 22 +++++++++++++--------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/onnx/backend/test/case/node/gelu.py b/onnx/backend/test/case/node/gelu.py index d7a6e2de270..ee337db6487 100644 --- a/onnx/backend/test/case/node/gelu.py +++ b/onnx/backend/test/case/node/gelu.py @@ -18,21 +18,19 @@ def export() -> None: x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.158808, 0., 0.841192] - cdf = 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) - y = x * cdf + y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.9963627, 3.99993, 4.9999995] - cdf = 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) - y = x * cdf + y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") @staticmethod def export_gelu_default() -> None: node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) - x = np.random.randn(-1, 0, 1).astype(np.float32) + x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.15865526, 0., 0.84134474] y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index ceba1fd317c..78f90805f07 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -559,11 +559,12 @@ ONNX_OPERATOR_SET_SCHEMA( .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); static const char* gelu_ver20_doc = R"DOC( -Gelu takes input data (Tensor) and an argument approximate, and produces one -output data (Tensor) where the function `y = 0.5 * x * (1 + erf(x/sqrt(2)))` -is applied to the tensor elementwise. When the attribute "approximate" is set -to "tanh", the function `y = 0.5 * x * (1 + Tanh(sqrt(2/π)*(x+0.044715*x^3)))` -is applied to the tensor elementwise to estimate. +Gelu takes one input data (Tensor) and produces one +output data (Tensor) where the gaussian error linear units function, +`y = 0.5 * x * (1 + erf(x/sqrt(2)))` is applied to the tensor elementwise. +If the attribute "approximate" is set to "tanh", the function estimation, +`y = 0.5 * x * (1 + Tanh(sqrt(2/π) * (x + 0.044715 * x^3)))` is used and applied +to the tensor elementwise. 
)DOC"; @@ -576,7 +577,7 @@ bool BuildContextDependentFunctionBodyGelu( auto approx_attr_proto = ctx.getAttribute("approximate"); std::string approx = approx_attr_proto != nullptr && approx_attr_proto->has_s() ? approx_attr_proto->s() - : gelu_default_alpha; + : gelu_default_approx; FunctionBuilder builder(functionProto); if (approx == "tanh") { @@ -629,12 +630,15 @@ ONNX_OPERATOR_SET_SCHEMA( .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .Attr( "approximate", - "Type of gelu approximation algorithm to use: tanh, none(default)." + "Gelu approximation algorithm: tanh, none(default)." "'none': do not use approximation." "'tanh': use tanh approximation.", AttributeProto::STRING, - gelu_default_appox) - .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float32 tensors.") + gelu_default_approx) + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors.") .SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyGelu) .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); From c4503e9e8df2ff1133c09dec0daf48091b533d6d Mon Sep 17 00:00:00 2001 From: pranshupant Date: Tue, 30 May 2023 23:36:12 -0400 Subject: [PATCH 05/23] Adding generated doc files Signed-off-by: pranshupant --- docs/Changelog.md | 42 ++++++++++++++++++++++ docs/Operators.md | 86 ++++++++++++++++++++++++++++++++++++++++++++ docs/TestCoverage.md | 42 +++++++++++++++++++++- 3 files changed, 169 insertions(+), 1 deletion(-) diff --git a/docs/Changelog.md b/docs/Changelog.md index c82dbbbc083..5dc7251c2d7 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -23881,6 +23881,48 @@ This version of the operator has been available since version 19 of the default ## Version 20 of the default ONNX operator set +### **Gelu-20** + + Gelu takes one input data (Tensor) and produces one + output data (Tensor) where the gaussian error linear units function, + `y = 0.5 * x * (1 + erf(x/sqrt(2)))` is applied to the tensor elementwise. + If the attribute "approximate" is set to "tanh", the function estimation, + `y = 0.5 * x * (1 + Tanh(sqrt(2/π) * (x + 0.044715 * x^3)))` is used and applied + to the tensor elementwise. + + +#### Version + +This version of the operator has been available since version 20 of the default ONNX operator set. + +#### Attributes + +
+
approximate : string (default is none)
+
Gelu approximation algorithm: tanh, none(default).'none': do not use approximation.'tanh': use tanh approximation.
+
+ +#### Inputs + +
+
X (differentiable) : T
+
Input tensor
+
+ +#### Outputs + +
+
Y (differentiable) : T
+
Output tensor
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float), tensor(double)
+
Constrain input and output types to float tensors.
+
+ ### **ConstantOfShape-20** Generate a tensor with given value and shape. diff --git a/docs/Operators.md b/docs/Operators.md index e9fa1c5f6eb..a7f91ccb914 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -170,6 +170,7 @@ For an operator input/output's differentiability, it can be differentiable, |Clip|13, 12, 11, 6, 1|13| |DynamicQuantizeLinear|11|11| |Elu|6, 1|18| +|Gelu|20|20| |GreaterOrEqual|16, 12|16| |GroupNormalization|18|18| |HammingWindow|17|17| @@ -9410,6 +9411,91 @@ expect( +### **Gelu** + + Gelu takes one input data (Tensor) and produces one + output data (Tensor) where the gaussian error linear units function, + `y = 0.5 * x * (1 + erf(x/sqrt(2)))` is applied to the tensor elementwise. + If the attribute "approximate" is set to "tanh", the function estimation, + `y = 0.5 * x * (1 + Tanh(sqrt(2/π) * (x + 0.044715 * x^3)))` is used and applied + to the tensor elementwise. + + +#### Version + +This version of the operator has been available since version 20 of the default ONNX operator set. + +#### Attributes + +
+
approximate : string (default is none)
+
Gelu approximation algorithm: tanh, none(default).'none': do not use approximation.'tanh': use tanh approximation.
+
+ +#### Inputs + +
+
X (differentiable) : T
+
Input tensor
+
+ +#### Outputs + +
+
Y (differentiable) : T
+
Output tensor
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float), tensor(double)
+
Constrain input and output types to float tensors.
+
+ + +#### Examples + +
+gelu + +```python +node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") + +x = np.array([-1, 0, 1]).astype(np.float32) +# expected output [-0.158808, 0., 0.841192] +y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") + +x = np.random.randn(3, 4, 5).astype(np.float32) +# expected output [2.9963627, 3.99993, 4.9999995] +y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") +``` + +
+ + +
+gelu_default + +```python +node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) + +x = np.array([-1, 0, 1]).astype(np.float32) +# expected output [-0.15865526, 0., 0.84134474] +y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") + +x = np.random.randn(3, 4, 5).astype(np.float32) +# expected output [2.99595031, 3.99987331, 4.99999857] +y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") +``` + +
+ + ### **Gemm** General Matrix multiplication: diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md index f59159e2a4c..01e9ca420a2 100644 --- a/docs/TestCoverage.md +++ b/docs/TestCoverage.md @@ -6,7 +6,7 @@ * [Overall Test Coverage](#overall-test-coverage) # Node Test Coverage ## Summary -Node tests have covered 173/186 (93.01%, 5 generators excluded) common operators. +Node tests have covered 174/187 (93.05%, 5 generators excluded) common operators. Node tests have covered 0/0 (N/A) experimental operators. @@ -6241,6 +6241,46 @@ expect( +### Gelu +There are 2 test cases, listed as following: +
+gelu + +```python +node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") + +x = np.array([-1, 0, 1]).astype(np.float32) +# expected output [-0.158808, 0., 0.841192] +y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") + +x = np.random.randn(3, 4, 5).astype(np.float32) +# expected output [2.9963627, 3.99993, 4.9999995] +y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") +``` + +
+
+gelu_default + +```python +node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) + +x = np.array([-1, 0, 1]).astype(np.float32) +# expected output [-0.15865526, 0., 0.84134474] +y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") + +x = np.random.randn(3, 4, 5).astype(np.float32) +# expected output [2.99595031, 3.99987331, 4.99999857] +y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") +``` + +
+ + ### Gemm There are 11 test cases, listed as following:
From f23be94fce437d5b8ae9e887b2938f7204e1d3f5 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Tue, 30 May 2023 23:46:55 -0400 Subject: [PATCH 06/23] adding test data files Signed-off-by: pranshupant --- .../data/node/test_gelu_default_1/model.onnx | Bin 0 -> 93 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 21 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 33 bytes .../node/test_gelu_default_1_expanded/model.onnx | Bin 0 -> 1425 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 21 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 33 bytes .../data/node/test_gelu_default_2/model.onnx | Bin 0 -> 109 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | 3 +++ .../node/test_gelu_default_2_expanded/model.onnx | Bin 0 -> 1441 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | 3 +++ .../test/data/node/test_gelu_tanh_1/model.onnx | Bin 0 -> 114 bytes .../test_gelu_tanh_1/test_data_set_0/input_0.pb | Bin 0 -> 21 bytes .../test_gelu_tanh_1/test_data_set_0/output_0.pb | Bin 0 -> 21 bytes .../node/test_gelu_tanh_1_expanded/model.onnx | Bin 0 -> 2062 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 21 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 21 bytes .../test/data/node/test_gelu_tanh_2/model.onnx | Bin 0 -> 130 bytes .../test_gelu_tanh_2/test_data_set_0/input_0.pb | 1 + .../test_gelu_tanh_2/test_data_set_0/output_0.pb | Bin 0 -> 254 bytes .../node/test_gelu_tanh_2_expanded/model.onnx | Bin 0 -> 2078 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | Bin 0 -> 254 bytes 24 files changed, 10 insertions(+) create mode 100644 onnx/backend/test/data/node/test_gelu_default_1/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_default_1/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_1/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_1_expanded/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_2/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1_expanded/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/output_0.pb create mode 100644 
onnx/backend/test/data/node/test_gelu_tanh_2_expanded/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/output_0.pb diff --git a/onnx/backend/test/data/node/test_gelu_default_1/model.onnx b/onnx/backend/test/data/node/test_gelu_default_1/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9de98825326d939f2b7294a5497098a61bcb58a3 GIT binary patch literal 93 zcmdKLDtPoGnKczp@qq8;j zA_PV+$?U+qH*db8EwM9D!;v=0q1B~reF8j*;~Skj>EH6jTZvAU%dI?>dOlGm(TPl* z8QQEc@>b<(9CZp)T4ii=A2w$yce)7;9*2yvi#A*Gu*JXu6&CnhoD);&sM2CO_`S literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/input_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..8a9445744b63f66e76c3ef4fce746606ffc6f47e GIT binary patch literal 21 Ycmd;J7GQK@tnlJtU})IS00s^A03TBWEC2ui literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..e3bff4eb0ce00f047fa9733e094caeb277b188a4 GIT binary patch literal 33 jcmd;J7T|Vbtn`v73rsw6!SBd^1}Ipx`uREUH?Qmgmud?% literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_default_2/model.onnx b/onnx/backend/test/data/node/test_gelu_default_2/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4fd1b6ba5a58b03e9164f8bd4c4825ce69e5173e GIT binary patch literal 109 zcmdKLDtPoz?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! >*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..00ae3481971 --- /dev/null +++ b/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ + ByJ\ ?!p?ٴS0?m@=?M:WĿb93?= +]|nn@nobS?|6?{Hv@?3h*?R΋*??@Ь?D0J?X+?Lų,|ſ닿#E$?$R'M?"ſhn{@eۉ{&}H?#afԔ???r3F?aU?BHIſN $ Ә45$?+?.&?&mӾPoj]ÿM#\Hy5hIj?2v ÿ}fE¿\B+Cq?qqc\7'aG<ſe?cÿو`Pml5/2i@?0?v?q3/?c;Yſ@{T + \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx b/onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..11296adc1e5278193f9e90d2b5c8857f46cffca1 GIT binary patch literal 1441 zcmbtUO-sW-5Z#y*;qTiTdQcQvp?C?~W}~6$CN{gZy?OL!^rt#o zQbB~k=wXvthIwz^%+C1xyfqQinN;zKmZk1~HsOh%-N?*D_8z@hE0&2cnT|STgdSt;vd1=#*JWUU3JcDPUl3F2i&A4c z_`OSK5@|>LlmPLdWQ~}QZ!skhQLUB^Es=jp8NKH2VXJm`I`U$ literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..bae0ffd6324 --- /dev/null +++ b/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +BxJx?h>z?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! 
>*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..00ae3481971 --- /dev/null +++ b/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ + ByJ\ ?!p?ٴS0?m@=?M:WĿb93?= +]|nn@nobS?|6?{Hv@?3h*?R΋*??@Ь?D0J?X+?Lų,|ſ닿#E$?$R'M?"ſhn{@eۉ{&}H?#afԔ???r3F?aU?BHIſN $ Ә45$?+?.&?&mӾPoj]ÿM#\Hy5hIj?2v ÿ}fE¿\B+Cq?qqc\7'aG<ſe?cÿو`Pml5/2i@?0?v?q3/?c;Yſ@{T + \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_tanh_1/model.onnx b/onnx/backend/test/data/node/test_gelu_tanh_1/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cbb06f2b052747f6aa0a281ec0e25ef28a8fb20c GIT binary patch literal 114 zcmdl1k8X$;I^x^uMQ<+cJ7tF~ z#xCuUj?)ckn+<&G+`3-ZG%TM^nM+w=4{gp|FC+z&S4~|}l*_8}q3Y)f@{trKQWv-+ zL&9>Wj1%F{yTOp~h7b4qb8y>*8FgB+gBc9H=l49vnd<u z-tJw6YvX6HNC`X?X*)PWyFZ8-#(wnVFC*|);R9wofj@Kh;DL(j;5MY1u)2)9;r*2t c_H#~bUW5!o#X8zBv=(1j{#~f*NXhSh10;XhRsaA1 literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/input_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..8a9445744b63f66e76c3ef4fce746606ffc6f47e GIT binary patch literal 21 Ycmd;J7GQK@tnlJtU})IS00s^A03TBWEC2ui literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..0f554cc42e247392ae38456c23643f18032a9088 GIT binary patch literal 21 bcmd;J7GQK@tn}iUFi&Y80}#YSgxdoED_sPK literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2/model.onnx b/onnx/backend/test/data/node/test_gelu_tanh_2/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..887e5c52023cb40139da847694442f1a46320594 GIT binary patch literal 130 zcmdz?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! 
>*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..d9ba5e0c6fbe5ba8123405c2bcd15ff1fa98a415 GIT binary patch literal 254 zcmVKZr_(KK^!5KOdwGKme%cKl^(mzPy%6KMPoh zy?+Bjz5JPqKIdelJ(G&8Kb%Y#KTN5JJ^YlIKCyjQK3Yn$KOXO|y|J!GJ_-CQzUk*z zyj{EaKGkS4KP4G1z7v@bK(wmay>~>yJYYAYz1iQlKX|yVKN&Q#J+P~KJ`-vxzF^Wo zz0Rruz6JiXJ=#2sKMg&IKemSozSjioyCN4k#6xVE^^lNdGc)hKc{{6Z;(W_#4+-rQ12PTDZ#sI!hd0CxZN9d9(U(K}_Nh%4 zV~2J~$L@x-9Sl5b-#TvB&`gg`1BV8M9h93nZb%A9mkm{tq|2)GDXV1(c}OTHk{7rn z0>X5r0RzI`cY`5e1t0JMG`>d38V)ZhhY=D<$tfHQK*%t7!Y9!1#@wsTh;I#6@_79+ zQXWero|v;a6NZ!FWZ?odiCJ1nW^4caA~Dtw31s9fj^F??@Mq|M8o;HiPyFDwn<`&2 zGuLquin`JF$vQh%Dv5FvB$kwm2R9CfkoH~Ig}(aL@3GGrsJ19&*Z!SI3DG)qy8q13}sR1Q2#lTdi4ul?@TcDy+c)bKp;(UAUp5IJgWcCoC`HYItwu hh5fwC4=+Z#CgVKX&=mHl?8&v*#`3R?s){6i`y0Vs*{c8m literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..bae0ffd6324 --- /dev/null +++ b/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +BxJx?h>z?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! >*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..d9ba5e0c6fbe5ba8123405c2bcd15ff1fa98a415 GIT binary patch literal 254 zcmVKZr_(KK^!5KOdwGKme%cKl^(mzPy%6KMPoh zy?+Bjz5JPqKIdelJ(G&8Kb%Y#KTN5JJ^YlIKCyjQK3Yn$KOXO|y|J!GJ_-CQzUk*z zyj{EaKGkS4KP4G1z7v@bK(wmay>~>yJYYAYz1iQlKX|yVKN&Q#J+P~KJ`-vxzF^Wo zz0Rruz6JiXJ=#2sKMg&IKemSozSjioy Date: Tue, 20 Jun 2023 00:29:49 -0400 Subject: [PATCH 07/23] update to test name and added reference op implementation Signed-off-by: pranshupant --- docs/Operators.md | 32 ++++++++++++++--------------- docs/TestCoverage.md | 32 ++++++++++++++--------------- onnx/backend/test/case/node/gelu.py | 2 +- onnx/reference/ops/op_gelu.py | 18 ++++++++++++++++ 4 files changed, 51 insertions(+), 33 deletions(-) create mode 100644 onnx/reference/ops/op_gelu.py diff --git a/docs/Operators.md b/docs/Operators.md index a7f91ccb914..70ed68867c5 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -9457,40 +9457,40 @@ This version of the operator has been available since version 20 of the default #### Examples
-gelu +gelu_default ```python -node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") +node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) x = np.array([-1, 0, 1]).astype(np.float32) -# expected output [-0.158808, 0., 0.841192] -y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) -expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") +# expected output [-0.15865526, 0., 0.84134474] +y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") x = np.random.randn(3, 4, 5).astype(np.float32) -# expected output [2.9963627, 3.99993, 4.9999995] -y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) -expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") +# expected output [2.99595031, 3.99987331, 4.99999857] +y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") ```
-gelu_default +gelu_tanh ```python -node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) +node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") x = np.array([-1, 0, 1]).astype(np.float32) -# expected output [-0.15865526, 0., 0.84134474] -y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) -expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") +# expected output [-0.158808, 0., 0.841192] +y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") x = np.random.randn(3, 4, 5).astype(np.float32) -# expected output [2.99595031, 3.99987331, 4.99999857] -y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) -expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") +# expected output [2.9963627, 3.99993, 4.9999995] +y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") ```
diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md index 01e9ca420a2..82df746dcdd 100644 --- a/docs/TestCoverage.md +++ b/docs/TestCoverage.md @@ -6244,38 +6244,38 @@ expect( ### Gelu There are 2 test cases, listed as following:
-gelu +gelu_default ```python -node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") +node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) x = np.array([-1, 0, 1]).astype(np.float32) -# expected output [-0.158808, 0., 0.841192] -y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) -expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") +# expected output [-0.15865526, 0., 0.84134474] +y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") x = np.random.randn(3, 4, 5).astype(np.float32) -# expected output [2.9963627, 3.99993, 4.9999995] -y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) -expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") +# expected output [2.99595031, 3.99987331, 4.99999857] +y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") ```
-gelu_default +gelu_tanh ```python -node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) +node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") x = np.array([-1, 0, 1]).astype(np.float32) -# expected output [-0.15865526, 0., 0.84134474] -y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) -expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") +# expected output [-0.158808, 0., 0.841192] +y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") x = np.random.randn(3, 4, 5).astype(np.float32) -# expected output [2.99595031, 3.99987331, 4.99999857] -y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) -expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") +# expected output [2.9963627, 3.99993, 4.9999995] +y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") ```
diff --git a/onnx/backend/test/case/node/gelu.py b/onnx/backend/test/case/node/gelu.py index ee337db6487..d57b3dd6940 100644 --- a/onnx/backend/test/case/node/gelu.py +++ b/onnx/backend/test/case/node/gelu.py @@ -13,7 +13,7 @@ class Gelu(Base): @staticmethod - def export() -> None: + def export_gelu_tanh() -> None: node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") x = np.array([-1, 0, 1]).astype(np.float32) diff --git a/onnx/reference/ops/op_gelu.py b/onnx/reference/ops/op_gelu.py new file mode 100644 index 00000000000..14715467647 --- /dev/null +++ b/onnx/reference/ops/op_gelu.py @@ -0,0 +1,18 @@ +# Copyright (c) ONNX Project Contributors + +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=W0221 + +import math + +import numpy as np + +from onnx.reference.ops._op import OpRunUnaryNum + + +class Gelu(OpRunUnaryNum): + def _run(self, x, approximate="none"): # type: ignore + if approximate == "tanh": + return (x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))),) + return (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))),) + From 0f62240981d4496b29c23da4baea4d636179bc17 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Wed, 21 Jun 2023 02:35:00 -0400 Subject: [PATCH 08/23] updates based on PR feedback Signed-off-by: pranshupant --- docs/Changelog.md | 8 ++++---- docs/Operators.md | 8 ++++---- onnx/defs/math/defs.cc | 29 +++++++++++++++-------------- onnx/reference/ops/op_gelu.py | 18 ------------------ onnx/test/automatic_upgrade_test.py | 4 ++-- 5 files changed, 25 insertions(+), 42 deletions(-) delete mode 100644 onnx/reference/ops/op_gelu.py diff --git a/docs/Changelog.md b/docs/Changelog.md index 5dc7251c2d7..e180096d521 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -23885,9 +23885,9 @@ This version of the operator has been available since version 19 of the default Gelu takes one input data (Tensor) and produces one output data (Tensor) where the gaussian error linear units function, - `y = 0.5 * x * (1 + erf(x/sqrt(2)))` is applied to the tensor elementwise. + $y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. If the attribute "approximate" is set to "tanh", the function estimation, - `y = 0.5 * x * (1 + Tanh(sqrt(2/π) * (x + 0.044715 * x^3)))` is used and applied + $y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied to the tensor elementwise. @@ -23899,7 +23899,7 @@ This version of the operator has been available since version 20 of the default
approximate : string (default is none)
-
Gelu approximation algorithm: tanh, none(default).'none': do not use approximation.'tanh': use tanh approximation.
+
Gelu approximation algorithm: `"tanh"`, `"none"`(default).`"none"`: do not use approximation.`"tanh"`: use tanh approximation.
#### Inputs @@ -23919,7 +23919,7 @@ This version of the operator has been available since version 20 of the default #### Type Constraints
-
T : tensor(float16), tensor(float), tensor(double)
+
T : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)
Constrain input and output types to float tensors.
diff --git a/docs/Operators.md b/docs/Operators.md index 70ed68867c5..aa2f64387ad 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -9415,9 +9415,9 @@ expect( Gelu takes one input data (Tensor) and produces one output data (Tensor) where the gaussian error linear units function, - `y = 0.5 * x * (1 + erf(x/sqrt(2)))` is applied to the tensor elementwise. + $y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. If the attribute "approximate" is set to "tanh", the function estimation, - `y = 0.5 * x * (1 + Tanh(sqrt(2/π) * (x + 0.044715 * x^3)))` is used and applied + $y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied to the tensor elementwise. @@ -9429,7 +9429,7 @@ This version of the operator has been available since version 20 of the default
approximate : string (default is none)
-
Gelu approximation algorithm: tanh, none(default).'none': do not use approximation.'tanh': use tanh approximation.
+
Gelu approximation algorithm: `"tanh"`, `"none"`(default).`"none"`: do not use approximation.`"tanh"`: use tanh approximation.
#### Inputs @@ -9449,7 +9449,7 @@ This version of the operator has been available since version 20 of the default #### Type Constraints
-
T : tensor(float16), tensor(float), tensor(double)
+
T : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)
Constrain input and output types to float tensors.
diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index 78f90805f07..ce4e5659b70 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -561,9 +561,9 @@ ONNX_OPERATOR_SET_SCHEMA( static const char* gelu_ver20_doc = R"DOC( Gelu takes one input data (Tensor) and produces one output data (Tensor) where the gaussian error linear units function, -`y = 0.5 * x * (1 + erf(x/sqrt(2)))` is applied to the tensor elementwise. +$y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. If the attribute "approximate" is set to "tanh", the function estimation, -`y = 0.5 * x * (1 + Tanh(sqrt(2/π) * (x + 0.044715 * x^3)))` is used and applied +$y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied to the tensor elementwise. )DOC"; @@ -575,12 +575,12 @@ bool BuildContextDependentFunctionBodyGelu( const OpSchema& schema, FunctionProto& functionProto) { auto approx_attr_proto = ctx.getAttribute("approximate"); - std::string approx = approx_attr_proto != nullptr && approx_attr_proto->has_s() + std::string approximate = approx_attr_proto != nullptr && approx_attr_proto->has_s() ? approx_attr_proto->s() : gelu_default_approx; FunctionBuilder builder(functionProto); - if (approx == "tanh") { + if (approximate == "tanh") { builder.Add(R"( Half = Constant () HalfCast = CastLike (Half, X) @@ -593,12 +593,13 @@ bool BuildContextDependentFunctionBodyGelu( SqrtTwoOverPi = Sqrt (TwoOverPiCast) Three = Constant () ThreeCast = CastLike (Three, X) - CubeX = Pow ( X, ThreeCast) - XCubeC0 = Mul (C0Cast, CubeX) - XC0XCube = Sum (X, XCubeC0) - ErfApprox = Tanh (XC0XCube) + XCubed = Pow (X, ThreeCast) + XCubedC0 = Mul (C0Cast, XCubed) + XC0XCubed = Sum (X, XCubedC0) + TanhInput = Mul (SqrtTwoOverPi, XC0XCubed) + ErfApprox = Tanh (TanhInput) PhiApprox = Sum (OneCast, ErfApprox) - MultX = Mul (Half, X) + MultX = Mul (HalfCast, X) Y = Mul (MultX, PhiApprox) )"); } else { @@ -613,7 +614,7 @@ bool BuildContextDependentFunctionBodyGelu( XSqrt = Div (X, SqrtTwo) ErfXSqrt = Erf(XSqrt) Phi = Sum (OneCast, ErfXSqrt) - MultX = Mul (Half, X) + MultX = Mul (HalfCast, X) Y = Mul (MultX, Phi) )"); } @@ -630,14 +631,14 @@ ONNX_OPERATOR_SET_SCHEMA( .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .Attr( "approximate", - "Gelu approximation algorithm: tanh, none(default)." - "'none': do not use approximation." - "'tanh': use tanh approximation.", + "Gelu approximation algorithm: `\"tanh\"`, `\"none\"`(default)." + "`\"none\"`: do not use approximation." 
+ "`\"tanh\"`: use tanh approximation.", AttributeProto::STRING, gelu_default_approx) .TypeConstraint( "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, + {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyGelu) .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); diff --git a/onnx/reference/ops/op_gelu.py b/onnx/reference/ops/op_gelu.py deleted file mode 100644 index 14715467647..00000000000 --- a/onnx/reference/ops/op_gelu.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) ONNX Project Contributors - -# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=W0221 - -import math - -import numpy as np - -from onnx.reference.ops._op import OpRunUnaryNum - - -class Gelu(OpRunUnaryNum): - def _run(self, x, approximate="none"): # type: ignore - if approximate == "tanh": - return (x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))),) - return (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))),) - diff --git a/onnx/test/automatic_upgrade_test.py b/onnx/test/automatic_upgrade_test.py index 12714fea997..0277e068a79 100644 --- a/onnx/test/automatic_upgrade_test.py +++ b/onnx/test/automatic_upgrade_test.py @@ -464,10 +464,10 @@ def test_GatherElements(self) -> None: def test_GatherND(self) -> None: self._test_op_upgrade("GatherND", 11, [[1, 2, 3], [1, 2, 3]], [[1, 2]]) - def test_Gelu_1(self) -> None: + def test_Gelu_approximate_tanh(self) -> None: self._test_op_upgrade("Gelu", 20, attrs={"approximate": "tanh"}) - def test_Gelu_2(self) -> None: + def test_Gelu(self) -> None: self._test_op_upgrade("Gelu", 20) def test_Gemm(self) -> None: From ab54319b3fc3d00406665722098b14d7bef9eec6 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Tue, 4 Jul 2023 20:33:10 -0400 Subject: [PATCH 09/23] fixed linting issues and test failures Signed-off-by: pranshupant --- docs/Operators.md | 20 ++++++++++++----- docs/TestCoverage.md | 20 ++++++++++++----- onnx/backend/test/case/node/gelu.py | 21 +++++++++++++----- .../data/node/test_gelu_default_1/model.onnx | Bin 93 -> 93 bytes .../test_data_set_0/output_0.pb | Bin 33 -> 21 bytes .../test_gelu_default_1_expanded/model.onnx | Bin 1425 -> 1429 bytes .../test_data_set_0/output_0.pb | Bin 33 -> 21 bytes .../data/node/test_gelu_default_2/model.onnx | Bin 109 -> 109 bytes .../test_data_set_0/output_0.pb | 6 ++--- .../test_gelu_default_2_expanded/model.onnx | Bin 1441 -> 1445 bytes .../test_data_set_0/output_0.pb | 6 ++--- .../node/test_gelu_tanh_1_expanded/model.onnx | Bin 2062 -> 2239 bytes .../node/test_gelu_tanh_2_expanded/model.onnx | Bin 2078 -> 2255 bytes 13 files changed, 51 insertions(+), 22 deletions(-) diff --git a/docs/Operators.md b/docs/Operators.md index aa2f64387ad..59db43da7c1 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -9464,12 +9464,12 @@ node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.15865526, 0., 0.84134474] -y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.99595031, 3.99987331, 4.99999857] -y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +y = (0.5 * x * (1 + np.vectorize(math.erf)(x / 
np.sqrt(2)))).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") ``` @@ -9480,16 +9480,26 @@ expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") gelu_tanh ```python -node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") +node = onnx.helper.make_node( + "Gelu", inputs=["x"], outputs=["y"], approximate="tanh" +) x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.158808, 0., 0.841192] -y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) +).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.9963627, 3.99993, 4.9999995] -y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) +).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") ``` diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md index 82df746dcdd..7b8d0cd2d9e 100644 --- a/docs/TestCoverage.md +++ b/docs/TestCoverage.md @@ -6251,12 +6251,12 @@ node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.15865526, 0., 0.84134474] -y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.99595031, 3.99987331, 4.99999857] -y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") ``` @@ -6265,16 +6265,26 @@ expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") gelu_tanh ```python -node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") +node = onnx.helper.make_node( + "Gelu", inputs=["x"], outputs=["y"], approximate="tanh" +) x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.158808, 0., 0.841192] -y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) +).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.9963627, 3.99993, 4.9999995] -y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) +).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") ``` diff --git a/onnx/backend/test/case/node/gelu.py b/onnx/backend/test/case/node/gelu.py index d57b3dd6940..cc93a4f5471 100644 --- a/onnx/backend/test/case/node/gelu.py +++ b/onnx/backend/test/case/node/gelu.py @@ -14,16 +14,26 @@ class Gelu(Base): @staticmethod def export_gelu_tanh() -> None: - node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") + node = onnx.helper.make_node( + "Gelu", inputs=["x"], outputs=["y"], approximate="tanh" + ) x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.158808, 0., 0.841192] - y = x * 0.5 * (1 + 
np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) + y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) + ).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.9963627, 3.99993, 4.9999995] - y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) + y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) + ).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") @staticmethod @@ -32,11 +42,10 @@ def export_gelu_default() -> None: x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.15865526, 0., 0.84134474] - y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) + y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.99595031, 3.99987331, 4.99999857] - y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) + y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") - diff --git a/onnx/backend/test/data/node/test_gelu_default_1/model.onnx b/onnx/backend/test/data/node/test_gelu_default_1/model.onnx index 9de98825326d939f2b7294a5497098a61bcb58a3..ada8f652bed5fdba31049c4a5363d16c902a76fe 100644 GIT binary patch delta 18 Zcma!zoe;pwD8$0W#KG*u!o?sU0stQe0!jb? delta 18 Zcma!zoe;pwEyTjb#KG*u!o?sU0stS00#pD1 diff --git a/onnx/backend/test/data/node/test_gelu_default_1/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_1/test_data_set_0/output_0.pb index e3bff4eb0ce00f047fa9733e094caeb277b188a4..b12e822f15a5648d6c9c8f16d2ac4470c3534a8f 100644 GIT binary patch literal 21 acmd;J7GQK@tn}h(D^uFX00ePK;r0M2Qv_iE literal 33 jcmd;J7T|Vbtn`v73rsw6!SBd^1}Ipx`uREUH?Qmgmud?% diff --git a/onnx/backend/test/data/node/test_gelu_default_1_expanded/model.onnx b/onnx/backend/test/data/node/test_gelu_default_1_expanded/model.onnx index c5e924d4a058a08bc62c1090910e1917403fd982..ffee90beaf2288ed808fd440b22f0db19dbe931c 100644 GIT binary patch delta 59 zcmbQpJ(XLGgHwnnDKR-aH7`ZCB(=E2>JJxJsL

7K4p$nk-C(ToV_XOkT+Hn3qwA Pg^P)U*@=aVK|llm_;3#K delta 80 zcmbQrJ&{|SgHwnnDKR-aH7`ZCB(=E2>IWBBsF42TiOkv?opf24bGg85PLIT#G%m&p fAp?-iWC2#z$+uY6@p22Xa4~T(JF##v2#5dxeh(9+ diff --git a/onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb index e3bff4eb0ce00f047fa9733e094caeb277b188a4..b12e822f15a5648d6c9c8f16d2ac4470c3534a8f 100644 GIT binary patch literal 21 acmd;J7GQK@tn}h(D^uFX00ePK;r0M2Qv_iE literal 33 jcmd;J7T|Vbtn`v73rsw6!SBd^1}Ipx`uREUH?Qmgmud?% diff --git a/onnx/backend/test/data/node/test_gelu_default_2/model.onnx b/onnx/backend/test/data/node/test_gelu_default_2/model.onnx index 4fd1b6ba5a58b03e9164f8bd4c4825ce69e5173e..c03f4701e47a4232ff057b67b2fee68c6ba294ff 100644 GIT binary patch delta 26 ecmd1Joe&|)D8$3X#K8>2EI`ca#KOfOAOZk1@&dB} delta 26 ecmd1Joe&|)EyTmc#K8>2EI`ca#KOfOAOZk2!~(ql diff --git a/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb index 00ae3481971..c55aea167f7 100644 --- a/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb +++ b/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb @@ -1,3 +1,3 @@ - ByJ\ ?!p?ٴS0?m@=?M:WĿb93?= -]|nn@nobS?|6?{Hv@?3h*?R΋*??@Ь?D0J?X+?Lų,|ſ닿#E$?$R'M?"ſhn{@eۉ{&}H?#afԔ???r3F?aU?BHIſN $ Ә45$?+?.&?&mӾPoj]ÿM#\Hy5hIj?2v ÿ}fE¿\B+Cq?qqc\7'aG<ſe?cÿو`Pml5/2i@?0?v?q3/?c;Yſ@{T - \ No newline at end of file +ByJ?K>Q?A @t? V$I?WuB>d=?VQ?V=q>~W>Q?xG>d+8^_$>o2?.ޓ@3ٽxD<+7?п?24=Iz>jL*'A=D?D? +| +BKR?h@:UZ?T;%;)>faJ>s=?>*S  \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx b/onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx index 11296adc1e5278193f9e90d2b5c8857f46cffca1..1988c1b6297e78111e1b62265dd3b2f27301d617 100644 GIT binary patch delta 67 zcmZ3;y_8#wgHwnnDKR-aH7`ZCB(=E2s)w5^RA_P`i@`=WO%|p?u89jxCNE@pEy^gw U!^OnG48$xz%<9C##ULO80A(%?>i_@% delta 72 zcmZ3=y^vd+gHwnnDKR-aH7`ZCB(=E2s)L&=R7ijFL}u-cPP#12xm;j2r$=H=8W&@Q XkO4?$vH&aVQ?A @t? V$I?WuB>d=?VQ?V=q>~W>Q?xG>d+8^_$>o2?.ޓ@3ٽxD<+7?п?24=Iz>jL*'A=D?D? 
+| +BKR?h@:UZ?T;%;)>faJ>s=?>*S  \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/model.onnx b/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/model.onnx index 65f21c7627a04507ee554c3cbf4214e783d7f056..254f702635caeb04cb5f0e8932f0bf4e1329481b 100644 GIT binary patch delta 256 zcmeAZ*e@u~!70R(l$e~InwO$ml3HA1wM>vJNJwY0A+yG01$LH=&e|-Dxs&r*%qM?h z(iDntE=@{JQDP3rFSlagn#ste2U5u4Y~Y+&Tmt1l)Cd`XWW*pGX9Fc>-_jhQmV7S8 z3NEk|7fedX5Nw$M8wWSg5ll7UkCQo4IW#WQMGpB%+Y6>UjWq9Tll$O8@ m%>%2^|=0%$yCe4CECKup1dGgeE7j>rD=0-vIzjFhQ~a delta 163 zcmdll*e4*)!70R(l$e~InwO$ml3HA1^_8D1NJwY0A+yFtdwmwhoXPnt=92|EH2IuM zlTss;m;>_5tr)nbPrk=u1QKQk3k&H_7G%{Ga5hk4_ASi;%I9)1R&ap@S)2`o3?^@6 tHQ-ia4ld0F%Ype!LWUq)CckIf%$&{zwvodlF((bEQwV6X-eh);9RSO+DZKyy diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/model.onnx b/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/model.onnx index 5b277e8c732ee9532eae615987ae1afaea4a04d1..6456042ec3620ff3548aef75a16c2f9386c30af9 100644 GIT binary patch delta 256 zcmbOya9&WHgHwnnDKR-aH7`ZCB(=E2YMUTekdV$~LuQT13hXQ!owZpQb0_Dsm{0!1 zq$w2PT$+@cqQo4KUv9;~HItD`52TR8*}yrmxCF|9s1Y&%$%sKX&IU@%zNI-pE%{uG z6%A=okjHV$r}$-$+$KxIo9C+jn7O`gEa%ftnjW=;Vs)f7(5%ka!AC@p~* mng>>+$psUKn&`x~nK>I`8OSRfU^g;W2u)64*P9&1egXh`z(Mi= delta 163 zcmX>vI8Q*FgHwnnDKR-aH7`ZCB(=E2s#Sn1NJwY0A+yFtdwmwhoXPnt=92|EH2IuM zlTss;m;>_5tr)nbPrk=u1QKQk3k&H_7G%{Ga5hk4_ASi;%I9)1R&ap@S)2`o3?^@6 tHQ-ia4ld0F%Ype!LWUq)CckIf%$&{zwvodlF((bEQwV6X-eh);69A&#DSiL| From ad409113d93a34a00b5d991c024572b488d2ada7 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Tue, 4 Jul 2023 23:13:33 -0400 Subject: [PATCH 10/23] Disabled GELU ORT tests for gelu (opset 20) Signed-off-by: pranshupant --- onnx/test/test_backend_onnxruntime.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnx/test/test_backend_onnxruntime.py b/onnx/test/test_backend_onnxruntime.py index 06811d7c1d7..9a87309c15a 100644 --- a/onnx/test/test_backend_onnxruntime.py +++ b/onnx/test/test_backend_onnxruntime.py @@ -249,6 +249,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): "|equal" "|identity" "|reshape" + "|gelu" ")" ) From fc730d14ed87180cfa5f56201bb57699e4ba1644 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Wed, 5 Jul 2023 11:20:48 -0400 Subject: [PATCH 11/23] Fixed C++ linting issues Signed-off-by: pranshupant --- onnx/defs/math/defs.cc | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index ce4e5659b70..9128a957d58 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -560,10 +560,10 @@ ONNX_OPERATOR_SET_SCHEMA( static const char* gelu_ver20_doc = R"DOC( Gelu takes one input data (Tensor) and produces one -output data (Tensor) where the gaussian error linear units function, -$y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. -If the attribute "approximate" is set to "tanh", the function estimation, -$y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied +output data (Tensor) where the gaussian error linear units function, +$y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. +If the attribute "approximate" is set to "tanh", the function estimation, +$y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied to the tensor elementwise. 
)DOC"; @@ -575,11 +575,10 @@ bool BuildContextDependentFunctionBodyGelu( const OpSchema& schema, FunctionProto& functionProto) { auto approx_attr_proto = ctx.getAttribute("approximate"); - std::string approximate = approx_attr_proto != nullptr && approx_attr_proto->has_s() - ? approx_attr_proto->s() - : gelu_default_approx; + std::string approximate = + approx_attr_proto != nullptr && approx_attr_proto->has_s() ? approx_attr_proto->s() : gelu_default_approx; FunctionBuilder builder(functionProto); - + if (approximate == "tanh") { builder.Add(R"( Half = Constant () @@ -601,7 +600,7 @@ bool BuildContextDependentFunctionBodyGelu( PhiApprox = Sum (OneCast, ErfApprox) MultX = Mul (HalfCast, X) Y = Mul (MultX, PhiApprox) - )"); + )"); } else { builder.Add(R"( Half = Constant () From 8f43c38e53d85dddb16428ae96cfd14a93d5c765 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Thu, 11 May 2023 02:25:22 -0400 Subject: [PATCH 12/23] Added implementation for GELU Signed-off-by: pranshupant --- onnx/defs/math/defs.cc | 119 +++++++++++++++++++++++++++++++++----- onnx/defs/operator_sets.h | 2 + 2 files changed, 107 insertions(+), 14 deletions(-) diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index bc3d95ac07f..9fe0e3285ef 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -470,6 +470,26 @@ mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x})) ``` )DOC"; +ONNX_OPERATOR_SET_SCHEMA( + Mish, + 18, + OpSchema() + .SetDoc(mish_ver18_doc) + .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input X and output types to float tensors.") + .FunctionBody(R"ONNX( + { + Softplus_X = Softplus (X) + TanHSoftplusX = Tanh (Softplus_X) + Y = Mul (X, TanHSoftplusX) + } + )ONNX") + .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); + static const char* celu_ver12_doc = R"DOC( Continuously Differentiable Exponential Linear Units: Perform the linear unit element-wise on the input tensor X @@ -538,24 +558,95 @@ ONNX_OPERATOR_SET_SCHEMA( .SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyCelu) .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); +static const char* gelu_ver20_doc = R"DOC( +Gaussian Error Linear Units: +Gelu takes input data (Tensor) and an argument approximate, and produces one +output data (Tensor) where the function `f(x) = alpha * x for x < 0`, +`f(x) = x for x >= 0`, is applied to the data tensor elementwise. +Perform the linear unit element-wise on the input tensor X +using formula: + +``` +0.5*x*(1+erf(x/sqrt(2))) +``` + +When approximate is set to tanh + +``` +0.5*x*(1+Tanh(sqrt(2/π)*(x+0.044715*x^3))) +``` + +)DOC"; + +static std::string gelu_default_approx = "none"; + +bool BuildContextDependentFunctionBodyGelu( + const FunctionBodyBuildContext& ctx, + const OpSchema& schema, + FunctionProto& functionProto) { + auto approx_attr_proto = ctx.getAttribute("approximate"); + std::string approx = approx_attr_proto != nullptr && approx_attr_proto->has_s() + ? 
approx_attr_proto->s() + : gelu_default_alpha; + FunctionBuilder builder(functionProto); + + if (approx == "tanh") { + builder.Add(R"( + Half = Constant () + HalfCast = CastLike (Half, X) + One = Constant () + OneCast = CastLike (One, X) + TwoOverPi = Constant () + TwoOverPiCast = CastLike (TwoOverPi, X) + C0 = Constant () + C0Cast = CastLike (C0, X) + SqrtTwoOverPi = Sqrt (TwoOverPiCast) + Three = Constant () + ThreeCast = CastLike (Three, X) + CubeX = Pow ( X, ThreeCast) + XCubeC0 = Mul (C0Cast, CubeX) + XC0XCube = Sum (X, XCubeC0) + ErfApprox = Tanh (XC0XCube) + PhiApprox = Sum (OneCast, ErfApprox) + MultX = Mul (Half, X) + Y = Mul (MultX, PhiApprox) + )"); + } else { + builder.Add(R"( + Half = Constant () + HalfCast = CastLike (Half, X) + One = Constant () + OneCast = CastLike (One, X) + Two = Constant () + TwoCast = CastLike (Two, X) + SqrtTwo = Sqrt (TwoCast) + XSqrt = Div (X, SqrtTwo) + ErfXSqrt = Erf(XSqrt) + Phi = Sum (OneCast, ErfXSqrt) + MultX = Mul (Half, X) + Y = Mul (MultX, Phi) + )"); + } + schema.BuildFunction(functionProto); + return true; +} + ONNX_OPERATOR_SET_SCHEMA( - Mish, - 18, + Gelu, + 20, OpSchema() - .SetDoc(mish_ver18_doc) + .SetDoc(gelu_ver20_doc) .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input X and output types to float tensors.") - .FunctionBody(R"ONNX( - { - Softplus_X = Softplus (X) - TanHSoftplusX = Tanh (Softplus_X) - Y = Mul (X, TanHSoftplusX) - } - )ONNX") + .Attr( + "approximate", + "Type of gelu approximation algorithm to use: tanh, none(default)." + "'none': do not use approximation." 
+ "'tanh': use tanh approximation.", + AttributeProto::STRING, + gelu_default_appox) + .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float32 tensors.") + .SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyGelu) .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); static const char* Exp_ver13_doc = R"DOC( diff --git a/onnx/defs/operator_sets.h b/onnx/defs/operator_sets.h index 76473d2f9e0..a83adfd194f 100644 --- a/onnx/defs/operator_sets.h +++ b/onnx/defs/operator_sets.h @@ -1103,6 +1103,7 @@ class OpSet_Onnx_ver19 { // Forward declarations for ai.onnx version 20 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, GridSample); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, Gelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, ConstantOfShape); // Iterate over schema from ai.onnx version 20 @@ -1110,6 +1111,7 @@ class OpSet_Onnx_ver20 { public: static void ForEachSchema(std::function fn) { fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); } }; From 78519256aa53bcf29f8853b6512e51ea2246413a Mon Sep 17 00:00:00 2001 From: pranshupant Date: Mon, 29 May 2023 23:09:28 -0400 Subject: [PATCH 13/23] adding unit test for gelu Signed-off-by: pranshupant --- onnx/backend/test/case/node/gelu.py | 44 +++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 onnx/backend/test/case/node/gelu.py diff --git a/onnx/backend/test/case/node/gelu.py b/onnx/backend/test/case/node/gelu.py new file mode 100644 index 00000000000..7c4fe59e319 --- /dev/null +++ b/onnx/backend/test/case/node/gelu.py @@ -0,0 +1,44 @@ +# Copyright (c) ONNX Project Contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import math + +import numpy as np + +import onnx +from onnx.backend.test.case.base import Base +from onnx.backend.test.case.node import expect + + +class Gelu(Base): + @staticmethod + def export() -> None: + node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approx="tanh") + + x = np.array([-1, 0, 1]).astype(np.float32) + # expected output [-0.158808, 0., 0.841192] + cdf = 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) + y = x * cdf + expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") + + x = np.random.randn(3, 4, 5).astype(np.float32) + # expected output [2.9963627, 3.99993, 4.9999995] + cdf = 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) + y = x * cdf + expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") + + @staticmethod + def export_gelu_default() -> None: + node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) + + x = np.random.randn(-1, 0, 1).astype(np.float32) + # expected output [-0.15865526, 0., 0.84134474] + y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) + expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") + + x = np.random.randn(3, 4, 5).astype(np.float32) + # expected output [2.99595031, 3.99987331, 4.99999857] + y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) + expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") + From 93160015fc297a7a40508ae2a7724e5525a0f698 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Mon, 29 May 2023 23:47:50 -0400 Subject: [PATCH 14/23] updated attribute name from appox to approximate + added trivial automatic_upgrade_tests Signed-off-by: pranshupant --- onnx/backend/test/case/node/gelu.py | 4 ++-- onnx/defs/math/defs.cc | 19 ++++--------------- onnx/test/automatic_upgrade_test.py | 6 ++++++ 3 files changed, 12 
insertions(+), 17 deletions(-) diff --git a/onnx/backend/test/case/node/gelu.py b/onnx/backend/test/case/node/gelu.py index 7c4fe59e319..d7a6e2de270 100644 --- a/onnx/backend/test/case/node/gelu.py +++ b/onnx/backend/test/case/node/gelu.py @@ -14,7 +14,7 @@ class Gelu(Base): @staticmethod def export() -> None: - node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approx="tanh") + node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.158808, 0., 0.841192] @@ -31,7 +31,7 @@ def export() -> None: @staticmethod def export_gelu_default() -> None: node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) - + x = np.random.randn(-1, 0, 1).astype(np.float32) # expected output [-0.15865526, 0., 0.84134474] y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index 9fe0e3285ef..ceba1fd317c 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -559,22 +559,11 @@ ONNX_OPERATOR_SET_SCHEMA( .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); static const char* gelu_ver20_doc = R"DOC( -Gaussian Error Linear Units: Gelu takes input data (Tensor) and an argument approximate, and produces one -output data (Tensor) where the function `f(x) = alpha * x for x < 0`, -`f(x) = x for x >= 0`, is applied to the data tensor elementwise. -Perform the linear unit element-wise on the input tensor X -using formula: - -``` -0.5*x*(1+erf(x/sqrt(2))) -``` - -When approximate is set to tanh - -``` -0.5*x*(1+Tanh(sqrt(2/π)*(x+0.044715*x^3))) -``` +output data (Tensor) where the function `y = 0.5 * x * (1 + erf(x/sqrt(2)))` +is applied to the tensor elementwise. When the attribute "approximate" is set +to "tanh", the function `y = 0.5 * x * (1 + Tanh(sqrt(2/π)*(x+0.044715*x^3)))` +is applied to the tensor elementwise to estimate. 
)DOC"; diff --git a/onnx/test/automatic_upgrade_test.py b/onnx/test/automatic_upgrade_test.py index 248f338ce82..12714fea997 100644 --- a/onnx/test/automatic_upgrade_test.py +++ b/onnx/test/automatic_upgrade_test.py @@ -464,6 +464,12 @@ def test_GatherElements(self) -> None: def test_GatherND(self) -> None: self._test_op_upgrade("GatherND", 11, [[1, 2, 3], [1, 2, 3]], [[1, 2]]) + def test_Gelu_1(self) -> None: + self._test_op_upgrade("Gelu", 20, attrs={"approximate": "tanh"}) + + def test_Gelu_2(self) -> None: + self._test_op_upgrade("Gelu", 20) + def test_Gemm(self) -> None: self._test_op_upgrade("Gemm", 1, [[5, 4], [4, 3], [3]], [[5, 3]]) From 283a4fd38bfab585646c18204a213f6806e869d4 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Tue, 30 May 2023 23:34:06 -0400 Subject: [PATCH 15/23] updated doc and unit test Signed-off-by: pranshupant --- onnx/backend/test/case/node/gelu.py | 8 +++----- onnx/defs/math/defs.cc | 22 +++++++++++++--------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/onnx/backend/test/case/node/gelu.py b/onnx/backend/test/case/node/gelu.py index d7a6e2de270..ee337db6487 100644 --- a/onnx/backend/test/case/node/gelu.py +++ b/onnx/backend/test/case/node/gelu.py @@ -18,21 +18,19 @@ def export() -> None: x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.158808, 0., 0.841192] - cdf = 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) - y = x * cdf + y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.9963627, 3.99993, 4.9999995] - cdf = 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) - y = x * cdf + y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") @staticmethod def export_gelu_default() -> None: node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) - x = np.random.randn(-1, 0, 1).astype(np.float32) + x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.15865526, 0., 0.84134474] y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index ceba1fd317c..78f90805f07 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -559,11 +559,12 @@ ONNX_OPERATOR_SET_SCHEMA( .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); static const char* gelu_ver20_doc = R"DOC( -Gelu takes input data (Tensor) and an argument approximate, and produces one -output data (Tensor) where the function `y = 0.5 * x * (1 + erf(x/sqrt(2)))` -is applied to the tensor elementwise. When the attribute "approximate" is set -to "tanh", the function `y = 0.5 * x * (1 + Tanh(sqrt(2/π)*(x+0.044715*x^3)))` -is applied to the tensor elementwise to estimate. +Gelu takes one input data (Tensor) and produces one +output data (Tensor) where the gaussian error linear units function, +`y = 0.5 * x * (1 + erf(x/sqrt(2)))` is applied to the tensor elementwise. +If the attribute "approximate" is set to "tanh", the function estimation, +`y = 0.5 * x * (1 + Tanh(sqrt(2/π) * (x + 0.044715 * x^3)))` is used and applied +to the tensor elementwise. 
)DOC"; @@ -576,7 +577,7 @@ bool BuildContextDependentFunctionBodyGelu( auto approx_attr_proto = ctx.getAttribute("approximate"); std::string approx = approx_attr_proto != nullptr && approx_attr_proto->has_s() ? approx_attr_proto->s() - : gelu_default_alpha; + : gelu_default_approx; FunctionBuilder builder(functionProto); if (approx == "tanh") { @@ -629,12 +630,15 @@ ONNX_OPERATOR_SET_SCHEMA( .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .Attr( "approximate", - "Type of gelu approximation algorithm to use: tanh, none(default)." + "Gelu approximation algorithm: tanh, none(default)." "'none': do not use approximation." "'tanh': use tanh approximation.", AttributeProto::STRING, - gelu_default_appox) - .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float32 tensors.") + gelu_default_approx) + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors.") .SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyGelu) .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); From 58f8280030ff3351c2848e19354211ecd2a0f865 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Tue, 30 May 2023 23:36:12 -0400 Subject: [PATCH 16/23] Adding generated doc files Signed-off-by: pranshupant --- docs/Changelog.md | 42 ++++++++++++++++++++++ docs/Operators.md | 86 ++++++++++++++++++++++++++++++++++++++++++++ docs/TestCoverage.md | 42 +++++++++++++++++++++- 3 files changed, 169 insertions(+), 1 deletion(-) diff --git a/docs/Changelog.md b/docs/Changelog.md index c82dbbbc083..5dc7251c2d7 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -23881,6 +23881,48 @@ This version of the operator has been available since version 19 of the default ## Version 20 of the default ONNX operator set +### **Gelu-20** + + Gelu takes one input data (Tensor) and produces one + output data (Tensor) where the gaussian error linear units function, + `y = 0.5 * x * (1 + erf(x/sqrt(2)))` is applied to the tensor elementwise. + If the attribute "approximate" is set to "tanh", the function estimation, + `y = 0.5 * x * (1 + Tanh(sqrt(2/π) * (x + 0.044715 * x^3)))` is used and applied + to the tensor elementwise. + + +#### Version + +This version of the operator has been available since version 20 of the default ONNX operator set. + +#### Attributes + +

+<dl>
+<dt><tt>approximate</tt> : string (default is none)</dt>
+<dd>Gelu approximation algorithm: tanh, none(default).'none': do not use approximation.'tanh': use tanh approximation.</dd>
+</dl>
+
+#### Inputs
+
+<dl>
+<dt><tt>X</tt> (differentiable) : T</dt>
+<dd>Input tensor</dd>
+</dl>
+
+#### Outputs
+
+<dl>
+<dt><tt>Y</tt> (differentiable) : T</dt>
+<dd>Output tensor</dd>
+</dl>
+
+#### Type Constraints
+
+<dl>
+<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
+<dd>Constrain input and output types to float tensors.</dd>
+</dl>
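Concretely, the two formulas in this changelog entry reduce to a few lines of NumPy. A minimal sketch for readers of the entry (`gelu_ref` is an illustrative name, not part of the operator set or of this patch):

```python
import math

import numpy as np


def gelu_ref(x, approximate="none"):
    """Illustrative-only transcription of the two formulas above."""
    if approximate == "tanh":
        # 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))
    # 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))


x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
print(gelu_ref(x))                      # ~[-0.15865526, 0., 0.84134474]
print(gelu_ref(x, approximate="tanh"))  # ~[-0.158808, 0., 0.841192]
```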
+ ### **ConstantOfShape-20** Generate a tensor with given value and shape. diff --git a/docs/Operators.md b/docs/Operators.md index e9fa1c5f6eb..a7f91ccb914 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -170,6 +170,7 @@ For an operator input/output's differentiability, it can be differentiable, |Clip|13, 12, 11, 6, 1|13| |DynamicQuantizeLinear|11|11| |Elu|6, 1|18| +|Gelu|20|20| |GreaterOrEqual|16, 12|16| |GroupNormalization|18|18| |HammingWindow|17|17| @@ -9410,6 +9411,91 @@ expect(
+### **Gelu** + + Gelu takes one input data (Tensor) and produces one + output data (Tensor) where the gaussian error linear units function, + `y = 0.5 * x * (1 + erf(x/sqrt(2)))` is applied to the tensor elementwise. + If the attribute "approximate" is set to "tanh", the function estimation, + `y = 0.5 * x * (1 + Tanh(sqrt(2/π) * (x + 0.044715 * x^3)))` is used and applied + to the tensor elementwise. + + +#### Version + +This version of the operator has been available since version 20 of the default ONNX operator set. + +#### Attributes + +
+<dl>
+<dt><tt>approximate</tt> : string (default is none)</dt>
+<dd>Gelu approximation algorithm: tanh, none(default).'none': do not use approximation.'tanh': use tanh approximation.</dd>
+</dl>
+
+#### Inputs
+
+<dl>
+<dt><tt>X</tt> (differentiable) : T</dt>
+<dd>Input tensor</dd>
+</dl>
+
+#### Outputs
+
+<dl>
+<dt><tt>Y</tt> (differentiable) : T</dt>
+<dd>Output tensor</dd>
+</dl>
+
+#### Type Constraints
+
+<dl>
+<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
+<dd>Constrain input and output types to float tensors.</dd>
+</dl>
+
+
+#### Examples
+
+<details>
+<summary>gelu</summary>
+
+```python
+node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh")
+
+x = np.array([-1, 0, 1]).astype(np.float32)
+# expected output [-0.158808, 0., 0.841192]
+y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3)))))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")
+
+x = np.random.randn(3, 4, 5).astype(np.float32)
+# expected output [2.9963627, 3.99993, 4.9999995]
+y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3)))))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")
+```
+
+</details>
+
+
+<details>
+<summary>gelu_default</summary>
+
+```python
+node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])
+
+x = np.array([-1, 0, 1]).astype(np.float32)
+# expected output [-0.15865526, 0., 0.84134474]
+y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")
+
+x = np.random.randn(3, 4, 5).astype(np.float32)
+# expected output [2.99595031, 3.99987331, 4.99999857]
+y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
+```
+
+</details>
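The documented cases above can be replayed against the pure-Python reference runtime that ships with onnx. A sketch, assuming the installed `onnx` build already registers Gelu under default opset 20:

```python
import numpy as np
import onnx
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

# Minimal one-node model for the default (erf) variant.
node = helper.make_node("Gelu", ["X"], ["Y"])
graph = helper.make_graph(
    [node],
    "gelu_default_demo",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [3])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])

sess = ReferenceEvaluator(model)
x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
print(sess.run(None, {"X": x})[0])  # ~[-0.15865526, 0., 0.84134474]
```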
+ + ### **Gemm** General Matrix multiplication: diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md index f59159e2a4c..01e9ca420a2 100644 --- a/docs/TestCoverage.md +++ b/docs/TestCoverage.md @@ -6,7 +6,7 @@ * [Overall Test Coverage](#overall-test-coverage) # Node Test Coverage ## Summary -Node tests have covered 173/186 (93.01%, 5 generators excluded) common operators. +Node tests have covered 174/187 (93.05%, 5 generators excluded) common operators. Node tests have covered 0/0 (N/A) experimental operators. @@ -6241,6 +6241,46 @@ expect( +### Gelu +There are 2 test cases, listed as following: +
+<details>
+<summary>gelu</summary>
+
+```python
+node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh")
+
+x = np.array([-1, 0, 1]).astype(np.float32)
+# expected output [-0.158808, 0., 0.841192]
+y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3)))))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")
+
+x = np.random.randn(3, 4, 5).astype(np.float32)
+# expected output [2.9963627, 3.99993, 4.9999995]
+y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3)))))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")
+```
+
+</details>
+<details>
+<summary>gelu_default</summary>
+
+```python
+node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])
+
+x = np.array([-1, 0, 1]).astype(np.float32)
+# expected output [-0.15865526, 0., 0.84134474]
+y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")
+
+x = np.random.randn(3, 4, 5).astype(np.float32)
+# expected output [2.99595031, 3.99987331, 4.99999857]
+y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
+```
+
+</details>
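The two test groups use different formulas, so their outputs are close but not identical. A quick NumPy-only check of the gap between the erf form and the tanh estimate (illustrative only, not part of the patch):

```python
import math

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(10_000).astype(np.float32)

exact = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
approx = 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

# The gap is small everywhere; typically below ~1e-3 in absolute terms.
print(np.max(np.abs(exact - approx)))
```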
+ + ### Gemm There are 11 test cases, listed as following:
From 2bb32363d6e766d93dcca00fe3196e64c6f3602f Mon Sep 17 00:00:00 2001 From: pranshupant Date: Tue, 30 May 2023 23:46:55 -0400 Subject: [PATCH 17/23] adding test data files Signed-off-by: pranshupant --- .../data/node/test_gelu_default_1/model.onnx | Bin 0 -> 93 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 21 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 33 bytes .../node/test_gelu_default_1_expanded/model.onnx | Bin 0 -> 1425 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 21 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 33 bytes .../data/node/test_gelu_default_2/model.onnx | Bin 0 -> 109 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | 3 +++ .../node/test_gelu_default_2_expanded/model.onnx | Bin 0 -> 1441 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | 3 +++ .../test/data/node/test_gelu_tanh_1/model.onnx | Bin 0 -> 114 bytes .../test_gelu_tanh_1/test_data_set_0/input_0.pb | Bin 0 -> 21 bytes .../test_gelu_tanh_1/test_data_set_0/output_0.pb | Bin 0 -> 21 bytes .../node/test_gelu_tanh_1_expanded/model.onnx | Bin 0 -> 2062 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 21 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 21 bytes .../test/data/node/test_gelu_tanh_2/model.onnx | Bin 0 -> 130 bytes .../test_gelu_tanh_2/test_data_set_0/input_0.pb | 1 + .../test_gelu_tanh_2/test_data_set_0/output_0.pb | Bin 0 -> 254 bytes .../node/test_gelu_tanh_2_expanded/model.onnx | Bin 0 -> 2078 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | Bin 0 -> 254 bytes 24 files changed, 10 insertions(+) create mode 100644 onnx/backend/test/data/node/test_gelu_default_1/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_default_1/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_1/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_1_expanded/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_2/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1_expanded/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/output_0.pb create mode 100644 
onnx/backend/test/data/node/test_gelu_tanh_2_expanded/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/output_0.pb diff --git a/onnx/backend/test/data/node/test_gelu_default_1/model.onnx b/onnx/backend/test/data/node/test_gelu_default_1/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9de98825326d939f2b7294a5497098a61bcb58a3 GIT binary patch literal 93 zcmdKLDtPoGnKczp@qq8;j zA_PV+$?U+qH*db8EwM9D!;v=0q1B~reF8j*;~Skj>EH6jTZvAU%dI?>dOlGm(TPl* z8QQEc@>b<(9CZp)T4ii=A2w$yce)7;9*2yvi#A*Gu*JXu6&CnhoD);&sM2CO_`S literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/input_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..8a9445744b63f66e76c3ef4fce746606ffc6f47e GIT binary patch literal 21 Ycmd;J7GQK@tnlJtU})IS00s^A03TBWEC2ui literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..e3bff4eb0ce00f047fa9733e094caeb277b188a4 GIT binary patch literal 33 jcmd;J7T|Vbtn`v73rsw6!SBd^1}Ipx`uREUH?Qmgmud?% literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_default_2/model.onnx b/onnx/backend/test/data/node/test_gelu_default_2/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..4fd1b6ba5a58b03e9164f8bd4c4825ce69e5173e GIT binary patch literal 109 zcmdKLDtPoz?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! >*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..00ae3481971 --- /dev/null +++ b/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ + ByJ\ ?!p?ٴS0?m@=?M:WĿb93?= +]|nn@nobS?|6?{Hv@?3h*?R΋*??@Ь?D0J?X+?Lų,|ſ닿#E$?$R'M?"ſhn{@eۉ{&}H?#afԔ???r3F?aU?BHIſN $ Ә45$?+?.&?&mӾPoj]ÿM#\Hy5hIj?2v ÿ}fE¿\B+Cq?qqc\7'aG<ſe?cÿو`Pml5/2i@?0?v?q3/?c;Yſ@{T + \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx b/onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..11296adc1e5278193f9e90d2b5c8857f46cffca1 GIT binary patch literal 1441 zcmbtUO-sW-5Z#y*;qTiTdQcQvp?C?~W}~6$CN{gZy?OL!^rt#o zQbB~k=wXvthIwz^%+C1xyfqQinN;zKmZk1~HsOh%-N?*D_8z@hE0&2cnT|STgdSt;vd1=#*JWUU3JcDPUl3F2i&A4c z_`OSK5@|>LlmPLdWQ~}QZ!skhQLUB^Es=jp8NKH2VXJm`I`U$ literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..bae0ffd6324 --- /dev/null +++ b/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +BxJx?h>z?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! 
>*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..00ae3481971 --- /dev/null +++ b/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ + ByJ\ ?!p?ٴS0?m@=?M:WĿb93?= +]|nn@nobS?|6?{Hv@?3h*?R΋*??@Ь?D0J?X+?Lų,|ſ닿#E$?$R'M?"ſhn{@eۉ{&}H?#afԔ???r3F?aU?BHIſN $ Ә45$?+?.&?&mӾPoj]ÿM#\Hy5hIj?2v ÿ}fE¿\B+Cq?qqc\7'aG<ſe?cÿو`Pml5/2i@?0?v?q3/?c;Yſ@{T + \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_tanh_1/model.onnx b/onnx/backend/test/data/node/test_gelu_tanh_1/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cbb06f2b052747f6aa0a281ec0e25ef28a8fb20c GIT binary patch literal 114 zcmdl1k8X$;I^x^uMQ<+cJ7tF~ z#xCuUj?)ckn+<&G+`3-ZG%TM^nM+w=4{gp|FC+z&S4~|}l*_8}q3Y)f@{trKQWv-+ zL&9>Wj1%F{yTOp~h7b4qb8y>*8FgB+gBc9H=l49vnd<u z-tJw6YvX6HNC`X?X*)PWyFZ8-#(wnVFC*|);R9wofj@Kh;DL(j;5MY1u)2)9;r*2t c_H#~bUW5!o#X8zBv=(1j{#~f*NXhSh10;XhRsaA1 literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/input_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..8a9445744b63f66e76c3ef4fce746606ffc6f47e GIT binary patch literal 21 Ycmd;J7GQK@tnlJtU})IS00s^A03TBWEC2ui literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..0f554cc42e247392ae38456c23643f18032a9088 GIT binary patch literal 21 bcmd;J7GQK@tn}iUFi&Y80}#YSgxdoED_sPK literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2/model.onnx b/onnx/backend/test/data/node/test_gelu_tanh_2/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..887e5c52023cb40139da847694442f1a46320594 GIT binary patch literal 130 zcmdz?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! 
>*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..d9ba5e0c6fbe5ba8123405c2bcd15ff1fa98a415 GIT binary patch literal 254 zcmVKZr_(KK^!5KOdwGKme%cKl^(mzPy%6KMPoh zy?+Bjz5JPqKIdelJ(G&8Kb%Y#KTN5JJ^YlIKCyjQK3Yn$KOXO|y|J!GJ_-CQzUk*z zyj{EaKGkS4KP4G1z7v@bK(wmay>~>yJYYAYz1iQlKX|yVKN&Q#J+P~KJ`-vxzF^Wo zz0Rruz6JiXJ=#2sKMg&IKemSozSjioyCN4k#6xVE^^lNdGc)hKc{{6Z;(W_#4+-rQ12PTDZ#sI!hd0CxZN9d9(U(K}_Nh%4 zV~2J~$L@x-9Sl5b-#TvB&`gg`1BV8M9h93nZb%A9mkm{tq|2)GDXV1(c}OTHk{7rn z0>X5r0RzI`cY`5e1t0JMG`>d38V)ZhhY=D<$tfHQK*%t7!Y9!1#@wsTh;I#6@_79+ zQXWero|v;a6NZ!FWZ?odiCJ1nW^4caA~Dtw31s9fj^F??@Mq|M8o;HiPyFDwn<`&2 zGuLquin`JF$vQh%Dv5FvB$kwm2R9CfkoH~Ig}(aL@3GGrsJ19&*Z!SI3DG)qy8q13}sR1Q2#lTdi4ul?@TcDy+c)bKp;(UAUp5IJgWcCoC`HYItwu hh5fwC4=+Z#CgVKX&=mHl?8&v*#`3R?s){6i`y0Vs*{c8m literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..bae0ffd6324 --- /dev/null +++ b/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +BxJx?h>z?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! >*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..d9ba5e0c6fbe5ba8123405c2bcd15ff1fa98a415 GIT binary patch literal 254 zcmVKZr_(KK^!5KOdwGKme%cKl^(mzPy%6KMPoh zy?+Bjz5JPqKIdelJ(G&8Kb%Y#KTN5JJ^YlIKCyjQK3Yn$KOXO|y|J!GJ_-CQzUk*z zyj{EaKGkS4KP4G1z7v@bK(wmay>~>yJYYAYz1iQlKX|yVKN&Q#J+P~KJ`-vxzF^Wo zz0Rruz6JiXJ=#2sKMg&IKemSozSjioy Date: Tue, 20 Jun 2023 00:29:49 -0400 Subject: [PATCH 18/23] update to test name and added reference op implementation Signed-off-by: pranshupant --- docs/Operators.md | 32 ++++++++++++++--------------- docs/TestCoverage.md | 32 ++++++++++++++--------------- onnx/backend/test/case/node/gelu.py | 2 +- onnx/reference/ops/op_gelu.py | 18 ++++++++++++++++ 4 files changed, 51 insertions(+), 33 deletions(-) create mode 100644 onnx/reference/ops/op_gelu.py diff --git a/docs/Operators.md b/docs/Operators.md index a7f91ccb914..70ed68867c5 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -9457,40 +9457,40 @@ This version of the operator has been available since version 20 of the default #### Examples
-<summary>gelu</summary>
+<summary>gelu_default</summary>
 
 ```python
-node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh")
+node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])
 
 x = np.array([-1, 0, 1]).astype(np.float32)
-# expected output [-0.158808, 0., 0.841192]
-y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3)))))
-expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")
+# expected output [-0.15865526, 0., 0.84134474]
+y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")
 
 x = np.random.randn(3, 4, 5).astype(np.float32)
-# expected output [2.9963627, 3.99993, 4.9999995]
-y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3)))))
-expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")
+# expected output [2.99595031, 3.99987331, 4.99999857]
+y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
 ```
 
 </details>
 
 
 <details>
-<summary>gelu_default</summary>
+<summary>gelu_tanh</summary>
 
 ```python
-node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])
+node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh")
 
 x = np.array([-1, 0, 1]).astype(np.float32)
-# expected output [-0.15865526, 0., 0.84134474]
-y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
-expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")
+# expected output [-0.158808, 0., 0.841192]
+y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3)))))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")
 
 x = np.random.randn(3, 4, 5).astype(np.float32)
-# expected output [2.99595031, 3.99987331, 4.99999857]
-y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
-expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
+# expected output [2.9963627, 3.99993, 4.9999995]
+y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3)))))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")
 ```
diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md index 01e9ca420a2..82df746dcdd 100644 --- a/docs/TestCoverage.md +++ b/docs/TestCoverage.md @@ -6244,38 +6244,38 @@ expect( ### Gelu There are 2 test cases, listed as following:
-<summary>gelu</summary>
+<summary>gelu_default</summary>
 
 ```python
-node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh")
+node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])
 
 x = np.array([-1, 0, 1]).astype(np.float32)
-# expected output [-0.158808, 0., 0.841192]
-y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3)))))
-expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")
+# expected output [-0.15865526, 0., 0.84134474]
+y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")
 
 x = np.random.randn(3, 4, 5).astype(np.float32)
-# expected output [2.9963627, 3.99993, 4.9999995]
-y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3)))))
-expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")
+# expected output [2.99595031, 3.99987331, 4.99999857]
+y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
 ```
 
 </details>
 
 <details>
-<summary>gelu_default</summary>
+<summary>gelu_tanh</summary>
 
 ```python
-node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])
+node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh")
 
 x = np.array([-1, 0, 1]).astype(np.float32)
-# expected output [-0.15865526, 0., 0.84134474]
-y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
-expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")
+# expected output [-0.158808, 0., 0.841192]
+y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3)))))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")
 
 x = np.random.randn(3, 4, 5).astype(np.float32)
-# expected output [2.99595031, 3.99987331, 4.99999857]
-y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))
-expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
+# expected output [2.9963627, 3.99993, 4.9999995]
+y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3)))))
+expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")
 ```
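The renamed `gelu_default`/`gelu_tanh` cases correspond one-to-one to PyTorch's `approximate` switch, which can serve as an independent check of the expected values. A sketch, assuming `torch` (>= 1.12) is installed; it is not a dependency of this patch:

```python
import math

import numpy as np
import torch
import torch.nn.functional as F

x = torch.randn(3, 4, 5)
x_np = x.numpy()

# NumPy versions of the two documented formulas.
exact_np = 0.5 * x_np * (1 + np.vectorize(math.erf)(x_np / np.sqrt(2)))
tanh_np = 0.5 * x_np * (1 + np.tanh(np.sqrt(2 / np.pi) * (x_np + 0.044715 * x_np**3)))

assert np.allclose(F.gelu(x, approximate="none").numpy(), exact_np, atol=1e-5)
assert np.allclose(F.gelu(x, approximate="tanh").numpy(), tanh_np, atol=1e-5)
```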
diff --git a/onnx/backend/test/case/node/gelu.py b/onnx/backend/test/case/node/gelu.py index ee337db6487..d57b3dd6940 100644 --- a/onnx/backend/test/case/node/gelu.py +++ b/onnx/backend/test/case/node/gelu.py @@ -13,7 +13,7 @@ class Gelu(Base): @staticmethod - def export() -> None: + def export_gelu_tanh() -> None: node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") x = np.array([-1, 0, 1]).astype(np.float32) diff --git a/onnx/reference/ops/op_gelu.py b/onnx/reference/ops/op_gelu.py new file mode 100644 index 00000000000..14715467647 --- /dev/null +++ b/onnx/reference/ops/op_gelu.py @@ -0,0 +1,18 @@ +# Copyright (c) ONNX Project Contributors + +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=W0221 + +import math + +import numpy as np + +from onnx.reference.ops._op import OpRunUnaryNum + + +class Gelu(OpRunUnaryNum): + def _run(self, x, approximate="none"): # type: ignore + if approximate == "tanh": + return (x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))),) + return (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))),) + From b2f86d51d2a23e2539866383fd9e39bddf378bd3 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Wed, 21 Jun 2023 02:35:00 -0400 Subject: [PATCH 19/23] updates based on PR feedback Signed-off-by: pranshupant --- docs/Changelog.md | 8 ++++---- docs/Operators.md | 8 ++++---- onnx/defs/math/defs.cc | 29 +++++++++++++++-------------- onnx/reference/ops/op_gelu.py | 18 ------------------ onnx/test/automatic_upgrade_test.py | 4 ++-- 5 files changed, 25 insertions(+), 42 deletions(-) delete mode 100644 onnx/reference/ops/op_gelu.py diff --git a/docs/Changelog.md b/docs/Changelog.md index 5dc7251c2d7..e180096d521 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -23885,9 +23885,9 @@ This version of the operator has been available since version 19 of the default Gelu takes one input data (Tensor) and produces one output data (Tensor) where the gaussian error linear units function, - `y = 0.5 * x * (1 + erf(x/sqrt(2)))` is applied to the tensor elementwise. + $y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. If the attribute "approximate" is set to "tanh", the function estimation, - `y = 0.5 * x * (1 + Tanh(sqrt(2/π) * (x + 0.044715 * x^3)))` is used and applied + $y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied to the tensor elementwise. @@ -23899,7 +23899,7 @@ This version of the operator has been available since version 20 of the default
 <dt><tt>approximate</tt> : string (default is none)</dt>
-<dd>Gelu approximation algorithm: tanh, none(default).'none': do not use approximation.'tanh': use tanh approximation.</dd>
+<dd>Gelu approximation algorithm: `"tanh"`, `"none"`(default).`"none"`: do not use approximation.`"tanh"`: use tanh approximation.</dd>
 </dl>
 
 #### Inputs
@@ -23919,7 +23919,7 @@ This version of the operator has been available since version 20 of the default
 #### Type Constraints
 
 <dl>
-<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
+<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
 <dd>Constrain input and output types to float tensors.</dd>
 </dl>
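With the schema in place, a Gelu node carrying the `approximate` attribute validates like any other op, and shape inference simply propagates X's type and shape to Y. A sketch (assumes an `onnx` build that ships default opset 20):

```python
import onnx
from onnx import TensorProto, helper

node = helper.make_node("Gelu", ["X"], ["Y"], approximate="tanh")
graph = helper.make_graph(
    [node],
    "gelu_tanh_demo",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 4, 5])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, None)],  # shape left for inference
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])

onnx.checker.check_model(model)
inferred = onnx.shape_inference.infer_shapes(model)
# propagateShapeAndTypeFromFirstInput: Y comes back as float32[3, 4, 5].
print(inferred.graph.output[0])
```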
diff --git a/docs/Operators.md b/docs/Operators.md index 70ed68867c5..aa2f64387ad 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -9415,9 +9415,9 @@ expect( Gelu takes one input data (Tensor) and produces one output data (Tensor) where the gaussian error linear units function, - `y = 0.5 * x * (1 + erf(x/sqrt(2)))` is applied to the tensor elementwise. + $y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. If the attribute "approximate" is set to "tanh", the function estimation, - `y = 0.5 * x * (1 + Tanh(sqrt(2/π) * (x + 0.044715 * x^3)))` is used and applied + $y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied to the tensor elementwise. @@ -9429,7 +9429,7 @@ This version of the operator has been available since version 20 of the default
 <dt><tt>approximate</tt> : string (default is none)</dt>
-<dd>Gelu approximation algorithm: tanh, none(default).'none': do not use approximation.'tanh': use tanh approximation.</dd>
+<dd>Gelu approximation algorithm: `"tanh"`, `"none"`(default).`"none"`: do not use approximation.`"tanh"`: use tanh approximation.</dd>
 </dl>
 
 #### Inputs
@@ -9449,7 +9449,7 @@ This version of the operator has been available since version 20 of the default
 #### Type Constraints
 
 <dl>
-<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
+<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
 <dd>Constrain input and output types to float tensors.</dd>
 </dl>
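The defs.cc hunk that follows rewrites the tanh-branch function body. What that builder graph computes, step by step, is the tanh formula above; a NumPy transcription for reference (illustrative only; the real body uses Constant, CastLike, Pow, Sum and Mul nodes, and the intermediate names here mirror those node outputs):

```python
import numpy as np


def gelu_tanh_function_body(X):
    # Constants are cast to X's dtype in the real graph via CastLike.
    HalfCast = np.float32(0.5)
    OneCast = np.float32(1.0)
    TwoOverPiCast = np.float32(2.0 / np.pi)
    C0Cast = np.float32(0.044715)
    SqrtTwoOverPi = np.sqrt(TwoOverPiCast)
    ThreeCast = np.float32(3.0)
    XCubed = np.power(X, ThreeCast)
    XCubedC0 = C0Cast * XCubed
    XC0XCubed = X + XCubedC0          # Sum in the ONNX graph
    TanhInput = SqrtTwoOverPi * XC0XCubed
    ErfApprox = np.tanh(TanhInput)
    PhiApprox = OneCast + ErfApprox   # Sum in the ONNX graph
    MultX = HalfCast * X
    return MultX * PhiApprox
```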
diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index 78f90805f07..ce4e5659b70 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -561,9 +561,9 @@ ONNX_OPERATOR_SET_SCHEMA( static const char* gelu_ver20_doc = R"DOC( Gelu takes one input data (Tensor) and produces one output data (Tensor) where the gaussian error linear units function, -`y = 0.5 * x * (1 + erf(x/sqrt(2)))` is applied to the tensor elementwise. +$y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. If the attribute "approximate" is set to "tanh", the function estimation, -`y = 0.5 * x * (1 + Tanh(sqrt(2/π) * (x + 0.044715 * x^3)))` is used and applied +$y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied to the tensor elementwise. )DOC"; @@ -575,12 +575,12 @@ bool BuildContextDependentFunctionBodyGelu( const OpSchema& schema, FunctionProto& functionProto) { auto approx_attr_proto = ctx.getAttribute("approximate"); - std::string approx = approx_attr_proto != nullptr && approx_attr_proto->has_s() + std::string approximate = approx_attr_proto != nullptr && approx_attr_proto->has_s() ? approx_attr_proto->s() : gelu_default_approx; FunctionBuilder builder(functionProto); - if (approx == "tanh") { + if (approximate == "tanh") { builder.Add(R"( Half = Constant () HalfCast = CastLike (Half, X) @@ -593,12 +593,13 @@ bool BuildContextDependentFunctionBodyGelu( SqrtTwoOverPi = Sqrt (TwoOverPiCast) Three = Constant () ThreeCast = CastLike (Three, X) - CubeX = Pow ( X, ThreeCast) - XCubeC0 = Mul (C0Cast, CubeX) - XC0XCube = Sum (X, XCubeC0) - ErfApprox = Tanh (XC0XCube) + XCubed = Pow (X, ThreeCast) + XCubedC0 = Mul (C0Cast, XCubed) + XC0XCubed = Sum (X, XCubedC0) + TanhInput = Mul (SqrtTwoOverPi, XC0XCubed) + ErfApprox = Tanh (TanhInput) PhiApprox = Sum (OneCast, ErfApprox) - MultX = Mul (Half, X) + MultX = Mul (HalfCast, X) Y = Mul (MultX, PhiApprox) )"); } else { @@ -613,7 +614,7 @@ bool BuildContextDependentFunctionBodyGelu( XSqrt = Div (X, SqrtTwo) ErfXSqrt = Erf(XSqrt) Phi = Sum (OneCast, ErfXSqrt) - MultX = Mul (Half, X) + MultX = Mul (HalfCast, X) Y = Mul (MultX, Phi) )"); } @@ -630,14 +631,14 @@ ONNX_OPERATOR_SET_SCHEMA( .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .Attr( "approximate", - "Gelu approximation algorithm: tanh, none(default)." - "'none': do not use approximation." - "'tanh': use tanh approximation.", + "Gelu approximation algorithm: `\"tanh\"`, `\"none\"`(default)." + "`\"none\"`: do not use approximation." 
+ "`\"tanh\"`: use tanh approximation.", AttributeProto::STRING, gelu_default_approx) .TypeConstraint( "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, + {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyGelu) .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); diff --git a/onnx/reference/ops/op_gelu.py b/onnx/reference/ops/op_gelu.py deleted file mode 100644 index 14715467647..00000000000 --- a/onnx/reference/ops/op_gelu.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) ONNX Project Contributors - -# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=W0221 - -import math - -import numpy as np - -from onnx.reference.ops._op import OpRunUnaryNum - - -class Gelu(OpRunUnaryNum): - def _run(self, x, approximate="none"): # type: ignore - if approximate == "tanh": - return (x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))),) - return (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))),) - diff --git a/onnx/test/automatic_upgrade_test.py b/onnx/test/automatic_upgrade_test.py index 12714fea997..0277e068a79 100644 --- a/onnx/test/automatic_upgrade_test.py +++ b/onnx/test/automatic_upgrade_test.py @@ -464,10 +464,10 @@ def test_GatherElements(self) -> None: def test_GatherND(self) -> None: self._test_op_upgrade("GatherND", 11, [[1, 2, 3], [1, 2, 3]], [[1, 2]]) - def test_Gelu_1(self) -> None: + def test_Gelu_approximate_tanh(self) -> None: self._test_op_upgrade("Gelu", 20, attrs={"approximate": "tanh"}) - def test_Gelu_2(self) -> None: + def test_Gelu(self) -> None: self._test_op_upgrade("Gelu", 20) def test_Gemm(self) -> None: From a47c3c1afbc8b32186d5c85beac2052b21bd490e Mon Sep 17 00:00:00 2001 From: pranshupant Date: Tue, 4 Jul 2023 20:33:10 -0400 Subject: [PATCH 20/23] fixed linting issues and test failures Signed-off-by: pranshupant --- docs/Operators.md | 20 ++++++++++++----- docs/TestCoverage.md | 20 ++++++++++++----- onnx/backend/test/case/node/gelu.py | 21 +++++++++++++----- .../data/node/test_gelu_default_1/model.onnx | Bin 93 -> 93 bytes .../test_data_set_0/output_0.pb | Bin 33 -> 21 bytes .../test_gelu_default_1_expanded/model.onnx | Bin 1425 -> 1429 bytes .../test_data_set_0/output_0.pb | Bin 33 -> 21 bytes .../data/node/test_gelu_default_2/model.onnx | Bin 109 -> 109 bytes .../test_data_set_0/output_0.pb | 6 ++--- .../test_gelu_default_2_expanded/model.onnx | Bin 1441 -> 1445 bytes .../test_data_set_0/output_0.pb | 6 ++--- .../node/test_gelu_tanh_1_expanded/model.onnx | Bin 2062 -> 2239 bytes .../node/test_gelu_tanh_2_expanded/model.onnx | Bin 2078 -> 2255 bytes 13 files changed, 51 insertions(+), 22 deletions(-) diff --git a/docs/Operators.md b/docs/Operators.md index aa2f64387ad..59db43da7c1 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -9464,12 +9464,12 @@ node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.15865526, 0., 0.84134474] -y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.99595031, 3.99987331, 4.99999857] -y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +y = (0.5 * x * (1 + np.vectorize(math.erf)(x / 
np.sqrt(2)))).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") ``` @@ -9480,16 +9480,26 @@ expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") gelu_tanh ```python -node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") +node = onnx.helper.make_node( + "Gelu", inputs=["x"], outputs=["y"], approximate="tanh" +) x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.158808, 0., 0.841192] -y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) +).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.9963627, 3.99993, 4.9999995] -y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) +).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") ``` diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md index 82df746dcdd..7b8d0cd2d9e 100644 --- a/docs/TestCoverage.md +++ b/docs/TestCoverage.md @@ -6251,12 +6251,12 @@ node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.15865526, 0., 0.84134474] -y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.99595031, 3.99987331, 4.99999857] -y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) +y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") ``` @@ -6265,16 +6265,26 @@ expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") gelu_tanh ```python -node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") +node = onnx.helper.make_node( + "Gelu", inputs=["x"], outputs=["y"], approximate="tanh" +) x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.158808, 0., 0.841192] -y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) +).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.9963627, 3.99993, 4.9999995] -y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) +y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) +).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") ``` diff --git a/onnx/backend/test/case/node/gelu.py b/onnx/backend/test/case/node/gelu.py index d57b3dd6940..cc93a4f5471 100644 --- a/onnx/backend/test/case/node/gelu.py +++ b/onnx/backend/test/case/node/gelu.py @@ -14,16 +14,26 @@ class Gelu(Base): @staticmethod def export_gelu_tanh() -> None: - node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"], approximate="tanh") + node = onnx.helper.make_node( + "Gelu", inputs=["x"], outputs=["y"], approximate="tanh" + ) x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.158808, 0., 0.841192] - y = x * 0.5 * (1 + 
np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) + y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) + ).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.9963627, 3.99993, 4.9999995] - y = x * 0.5 * (1 + np.tanh((np.sqrt(2/np.pi) * (x + 0.044715 * np.power(x, 3))))) + y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) + ).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") @staticmethod @@ -32,11 +42,10 @@ def export_gelu_default() -> None: x = np.array([-1, 0, 1]).astype(np.float32) # expected output [-0.15865526, 0., 0.84134474] - y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) + y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") x = np.random.randn(3, 4, 5).astype(np.float32) # expected output [2.99595031, 3.99987331, 4.99999857] - y = 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2))) + y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") - diff --git a/onnx/backend/test/data/node/test_gelu_default_1/model.onnx b/onnx/backend/test/data/node/test_gelu_default_1/model.onnx index 9de98825326d939f2b7294a5497098a61bcb58a3..ada8f652bed5fdba31049c4a5363d16c902a76fe 100644 GIT binary patch delta 18 Zcma!zoe;pwD8$0W#KG*u!o?sU0stQe0!jb? delta 18 Zcma!zoe;pwEyTjb#KG*u!o?sU0stS00#pD1 diff --git a/onnx/backend/test/data/node/test_gelu_default_1/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_1/test_data_set_0/output_0.pb index e3bff4eb0ce00f047fa9733e094caeb277b188a4..b12e822f15a5648d6c9c8f16d2ac4470c3534a8f 100644 GIT binary patch literal 21 acmd;J7GQK@tn}h(D^uFX00ePK;r0M2Qv_iE literal 33 jcmd;J7T|Vbtn`v73rsw6!SBd^1}Ipx`uREUH?Qmgmud?% diff --git a/onnx/backend/test/data/node/test_gelu_default_1_expanded/model.onnx b/onnx/backend/test/data/node/test_gelu_default_1_expanded/model.onnx index c5e924d4a058a08bc62c1090910e1917403fd982..ffee90beaf2288ed808fd440b22f0db19dbe931c 100644 GIT binary patch delta 59 zcmbQpJ(XLGgHwnnDKR-aH7`ZCB(=E2>JJxJsL

7K4p$nk-C(ToV_XOkT+Hn3qwA Pg^P)U*@=aVK|llm_;3#K delta 80 zcmbQrJ&{|SgHwnnDKR-aH7`ZCB(=E2>IWBBsF42TiOkv?opf24bGg85PLIT#G%m&p fAp?-iWC2#z$+uY6@p22Xa4~T(JF##v2#5dxeh(9+ diff --git a/onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb index e3bff4eb0ce00f047fa9733e094caeb277b188a4..b12e822f15a5648d6c9c8f16d2ac4470c3534a8f 100644 GIT binary patch literal 21 acmd;J7GQK@tn}h(D^uFX00ePK;r0M2Qv_iE literal 33 jcmd;J7T|Vbtn`v73rsw6!SBd^1}Ipx`uREUH?Qmgmud?% diff --git a/onnx/backend/test/data/node/test_gelu_default_2/model.onnx b/onnx/backend/test/data/node/test_gelu_default_2/model.onnx index 4fd1b6ba5a58b03e9164f8bd4c4825ce69e5173e..c03f4701e47a4232ff057b67b2fee68c6ba294ff 100644 GIT binary patch delta 26 ecmd1Joe&|)D8$3X#K8>2EI`ca#KOfOAOZk1@&dB} delta 26 ecmd1Joe&|)EyTmc#K8>2EI`ca#KOfOAOZk2!~(ql diff --git a/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb index 00ae3481971..c55aea167f7 100644 --- a/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb +++ b/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb @@ -1,3 +1,3 @@ - ByJ\ ?!p?ٴS0?m@=?M:WĿb93?= -]|nn@nobS?|6?{Hv@?3h*?R΋*??@Ь?D0J?X+?Lų,|ſ닿#E$?$R'M?"ſhn{@eۉ{&}H?#afԔ???r3F?aU?BHIſN $ Ә45$?+?.&?&mӾPoj]ÿM#\Hy5hIj?2v ÿ}fE¿\B+Cq?qqc\7'aG<ſe?cÿو`Pml5/2i@?0?v?q3/?c;Yſ@{T - \ No newline at end of file +ByJ?K>Q?A @t? V$I?WuB>d=?VQ?V=q>~W>Q?xG>d+8^_$>o2?.ޓ@3ٽxD<+7?п?24=Iz>jL*'A=D?D? +| +BKR?h@:UZ?T;%;)>faJ>s=?>*S  \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx b/onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx index 11296adc1e5278193f9e90d2b5c8857f46cffca1..1988c1b6297e78111e1b62265dd3b2f27301d617 100644 GIT binary patch delta 67 zcmZ3;y_8#wgHwnnDKR-aH7`ZCB(=E2s)w5^RA_P`i@`=WO%|p?u89jxCNE@pEy^gw U!^OnG48$xz%<9C##ULO80A(%?>i_@% delta 72 zcmZ3=y^vd+gHwnnDKR-aH7`ZCB(=E2s)L&=R7ijFL}u-cPP#12xm;j2r$=H=8W&@Q XkO4?$vH&aVQ?A @t? V$I?WuB>d=?VQ?V=q>~W>Q?xG>d+8^_$>o2?.ޓ@3ٽxD<+7?п?24=Iz>jL*'A=D?D? 
+| +BKR?h@:UZ?T;%;)>faJ>s=?>*S  \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/model.onnx b/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/model.onnx index 65f21c7627a04507ee554c3cbf4214e783d7f056..254f702635caeb04cb5f0e8932f0bf4e1329481b 100644 GIT binary patch delta 256 zcmeAZ*e@u~!70R(l$e~InwO$ml3HA1wM>vJNJwY0A+yG01$LH=&e|-Dxs&r*%qM?h z(iDntE=@{JQDP3rFSlagn#ste2U5u4Y~Y+&Tmt1l)Cd`XWW*pGX9Fc>-_jhQmV7S8 z3NEk|7fedX5Nw$M8wWSg5ll7UkCQo4IW#WQMGpB%+Y6>UjWq9Tll$O8@ m%>%2^|=0%$yCe4CECKup1dGgeE7j>rD=0-vIzjFhQ~a delta 163 zcmdll*e4*)!70R(l$e~InwO$ml3HA1^_8D1NJwY0A+yFtdwmwhoXPnt=92|EH2IuM zlTss;m;>_5tr)nbPrk=u1QKQk3k&H_7G%{Ga5hk4_ASi;%I9)1R&ap@S)2`o3?^@6 tHQ-ia4ld0F%Ype!LWUq)CckIf%$&{zwvodlF((bEQwV6X-eh);9RSO+DZKyy diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/model.onnx b/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/model.onnx index 5b277e8c732ee9532eae615987ae1afaea4a04d1..6456042ec3620ff3548aef75a16c2f9386c30af9 100644 GIT binary patch delta 256 zcmbOya9&WHgHwnnDKR-aH7`ZCB(=E2YMUTekdV$~LuQT13hXQ!owZpQb0_Dsm{0!1 zq$w2PT$+@cqQo4KUv9;~HItD`52TR8*}yrmxCF|9s1Y&%$%sKX&IU@%zNI-pE%{uG z6%A=okjHV$r}$-$+$KxIo9C+jn7O`gEa%ftnjW=;Vs)f7(5%ka!AC@p~* mng>>+$psUKn&`x~nK>I`8OSRfU^g;W2u)64*P9&1egXh`z(Mi= delta 163 zcmX>vI8Q*FgHwnnDKR-aH7`ZCB(=E2s#Sn1NJwY0A+yFtdwmwhoXPnt=92|EH2IuM zlTss;m;>_5tr)nbPrk=u1QKQk3k&H_7G%{Ga5hk4_ASi;%I9)1R&ap@S)2`o3?^@6 tHQ-ia4ld0F%Ype!LWUq)CckIf%$&{zwvodlF((bEQwV6X-eh);69A&#DSiL| From e9eeb1d5ca40a6ad80ff38cf5f74b2e55b0e39b7 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Tue, 4 Jul 2023 23:13:33 -0400 Subject: [PATCH 21/23] Disabled GELU ORT tests for gelu (opset 20) Signed-off-by: pranshupant --- onnx/test/test_backend_onnxruntime.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnx/test/test_backend_onnxruntime.py b/onnx/test/test_backend_onnxruntime.py index 06811d7c1d7..9a87309c15a 100644 --- a/onnx/test/test_backend_onnxruntime.py +++ b/onnx/test/test_backend_onnxruntime.py @@ -249,6 +249,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): "|equal" "|identity" "|reshape" + "|gelu" ")" ) From ba95b2f5bc95c33e7d6760fcb0c6aa39d168bf70 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Wed, 5 Jul 2023 11:20:48 -0400 Subject: [PATCH 22/23] Fixed C++ linting issues Signed-off-by: pranshupant --- onnx/defs/math/defs.cc | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/onnx/defs/math/defs.cc b/onnx/defs/math/defs.cc index ce4e5659b70..9128a957d58 100644 --- a/onnx/defs/math/defs.cc +++ b/onnx/defs/math/defs.cc @@ -560,10 +560,10 @@ ONNX_OPERATOR_SET_SCHEMA( static const char* gelu_ver20_doc = R"DOC( Gelu takes one input data (Tensor) and produces one -output data (Tensor) where the gaussian error linear units function, -$y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. -If the attribute "approximate" is set to "tanh", the function estimation, -$y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied +output data (Tensor) where the gaussian error linear units function, +$y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. +If the attribute "approximate" is set to "tanh", the function estimation, +$y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied to the tensor elementwise. 
)DOC"; @@ -575,11 +575,10 @@ bool BuildContextDependentFunctionBodyGelu( const OpSchema& schema, FunctionProto& functionProto) { auto approx_attr_proto = ctx.getAttribute("approximate"); - std::string approximate = approx_attr_proto != nullptr && approx_attr_proto->has_s() - ? approx_attr_proto->s() - : gelu_default_approx; + std::string approximate = + approx_attr_proto != nullptr && approx_attr_proto->has_s() ? approx_attr_proto->s() : gelu_default_approx; FunctionBuilder builder(functionProto); - + if (approximate == "tanh") { builder.Add(R"( Half = Constant () @@ -601,7 +600,7 @@ bool BuildContextDependentFunctionBodyGelu( PhiApprox = Sum (OneCast, ErfApprox) MultX = Mul (HalfCast, X) Y = Mul (MultX, PhiApprox) - )"); + )"); } else { builder.Add(R"( Half = Constant () From 00b977f014c3a88fb90465f91dfdf872f0e2a581 Mon Sep 17 00:00:00 2001 From: pranshupant Date: Sat, 8 Jul 2023 01:19:55 -0400 Subject: [PATCH 23/23] Updated Changelog to account for #5390 Signed-off-by: pranshupant --- docs/Changelog.md | 56 +++++++++++++++++++++++------------------------ 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/docs/Changelog.md b/docs/Changelog.md index e180096d521..0a69e103102 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -23881,15 +23881,9 @@ This version of the operator has been available since version 19 of the default ## Version 20 of the default ONNX operator set -### **Gelu-20** - - Gelu takes one input data (Tensor) and produces one - output data (Tensor) where the gaussian error linear units function, - $y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. - If the attribute "approximate" is set to "tanh", the function estimation, - $y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied - to the tensor elementwise. +### **ConstantOfShape-20** + Generate a tensor with given value and shape. #### Version @@ -23898,34 +23892,42 @@ This version of the operator has been available since version 20 of the default #### Attributes

-
approximate : string (default is none)
-
Gelu approximation algorithm: `"tanh"`, `"none"` (default). `"none"`: do not use approximation. `"tanh"`: use tanh approximation.
+
value : tensor
+
(Optional) The value of the output elements. Should be a one-element tensor. If not specified, it defaults to a tensor of value 0 and datatype float32.
#### Inputs
-
X (differentiable) : T
-
Input tensor
+
input : T1
+
1D tensor. The shape of the expected output tensor. If empty tensor is given, the output would be a scalar. All values must be >= 0.
#### Outputs
-
Y (differentiable) : T
-
Output tensor
+
output : T2
+
Output tensor of shape specified by 'input'.If attribute 'value' is specified, the value and datatype of the output tensor is taken from 'value'.If attribute 'value' is not specified, the value in the output defaults to 0, and the datatype defaults to float32.
#### Type Constraints
-
T : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)
-
Constrain input and output types to float tensors.
+
T1 : tensor(int64)
+
Constrain input types.
+
T2 : tensor(float16), tensor(float), tensor(double), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(bool), tensor(bfloat16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
+
Constrain output types to be numerics.
-### **ConstantOfShape-20** +### **Gelu-20** + + Gelu takes one input data (Tensor) and produces one + output data (Tensor) where the gaussian error linear units function, + $y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. + If the attribute "approximate" is set to "tanh", the function estimation, + $y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied + to the tensor elementwise. - Generate a tensor with given value and shape. #### Version @@ -23934,31 +23936,29 @@ This version of the operator has been available since version 20 of the default #### Attributes
-
value : tensor
-
(Optional) The value of the output elements.Should be a one-element tensor. If not specified, it defaults to a tensor of value 0 and datatype float32
+
approximate : string (default is none)
+
Gelu approximation algorithm: `"tanh"`, `"none"`(default).`"none"`: do not use approximation.`"tanh"`: use tanh approximation.
#### Inputs
-
input : T1
-
1D tensor. The shape of the expected output tensor. If empty tensor is given, the output would be a scalar. All values must be >= 0.
+
X (differentiable) : T
+
Input tensor
#### Outputs
-
output : T2
-
Output tensor of shape specified by 'input'. If attribute 'value' is specified, the value and datatype of the output tensor are taken from 'value'. If attribute 'value' is not specified, the value in the output defaults to 0, and the datatype defaults to float32.
+
Y (differentiable) : T
+
Output tensor
#### Type Constraints
-
T1 : tensor(int64)
-
Constrain input types.
-
T2 : tensor(float16), tensor(float), tensor(double), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(bool), tensor(bfloat16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
-
Constrain output types to be numerics.
+
T : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)
+
Constrain input and output types to float tensors.
### **GridSample-20**