From 527f4ddbfd059d2dd1b8c61ea1b7a53b3d3b1952 Mon Sep 17 00:00:00 2001
From: Thiago Crepaldi
Date: Tue, 23 Apr 2024 21:42:48 +0000
Subject: [PATCH] Add Gradient-2 with bfloat16

Signed-off-by: Thiago Crepaldi
---
 onnx/defs/operator_sets_preview.h |  15 ++-
 onnx/defs/schema.h                |   4 +-
 onnx/defs/training/defs.cc        |  10 +-
 onnx/defs/training/old.cc         | 196 ++++++++++++++++++++++++++++++
 onnx/helper.py                    |   1 +
 5 files changed, 218 insertions(+), 8 deletions(-)
 create mode 100644 onnx/defs/training/old.cc

diff --git a/onnx/defs/operator_sets_preview.h b/onnx/defs/operator_sets_preview.h
index 55ca07bdfb9..7060db8e004 100644
--- a/onnx/defs/operator_sets_preview.h
+++ b/onnx/defs/operator_sets_preview.h
@@ -8,7 +8,7 @@
 namespace ONNX_NAMESPACE {
 
-// Declare training operators.
+// Declare training operators version 1
 
 class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Gradient);
 class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(1, Momentum);
@@ -26,12 +26,25 @@ class OpSet_OnnxPreview_ver1 {
   }
 };
 
+// Declare training operators version 2
+
+class ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(2, Gradient);
+
+// Iterate over schema from ai.onnx.training version 2
+class OpSet_OnnxPreview_ver2 {
+ public:
+  static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
+    fn(GetOpSchema<ONNX_PREVIEW_OPERATOR_SET_SCHEMA_CLASS_NAME(2, Gradient)>());
+  }
+};
+
 // Register preview operators.
 inline void RegisterOnnxPreviewOperatorSetSchema() {
   // Preview operators should have only one version.
   // If changes are needed for a specific preview operator,
   // its spec should be modified without increasing its version.
   RegisterOpSetSchema<OpSet_OnnxPreview_ver1>();
+  RegisterOpSetSchema<OpSet_OnnxPreview_ver2>();
 }
 
 } // namespace ONNX_NAMESPACE
diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h
index e786a96b299..c012e9e8f13 100644
--- a/onnx/defs/schema.h
+++ b/onnx/defs/schema.h
@@ -1181,14 +1181,14 @@ class OpSchemaRegistry final : public ISchemaRegistry {
     // ONNX's preview domain contains operators subject to change, so
     // versioning is not meaningful and that domain should have only one
     // version.
-    map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = std::make_pair(1, 1);
+    map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = std::make_pair(1, 2);
     // Version corresponding last release of ONNX. Update this to match with
     // the max version above in a *release* version of ONNX. But in other
     // versions, the max version may be ahead of the last-release-version.
     last_release_version_map_[ONNX_DOMAIN] = 21;
     last_release_version_map_[AI_ONNX_ML_DOMAIN] = 5;
     last_release_version_map_[AI_ONNX_TRAINING_DOMAIN] = 1;
-    last_release_version_map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = 1;
+    last_release_version_map_[AI_ONNX_PREVIEW_TRAINING_DOMAIN] = 2;
   }
 
   const std::unordered_map<std::string, std::pair<int, int>>& Map() const {
diff --git a/onnx/defs/training/defs.cc b/onnx/defs/training/defs.cc
index 0c3ed7a6311..ecbc9d4d156 100644
--- a/onnx/defs/training/defs.cc
+++ b/onnx/defs/training/defs.cc
@@ -10,7 +10,7 @@
 
 namespace ONNX_NAMESPACE {
 
-static const char* Gradient_ver1_doc = R"DOC(
+static const char* Gradient_ver2_doc = R"DOC(
 Gradient operator computes the partial derivatives of a specific tensor w.r.t.
 some other tensors. This operator is widely used in gradient-based training
 algorithms. To illustrate its use, let's consider a computation graph,
@@ -138,9 +138,9 @@ auto-differentiation.
 ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(
     Gradient,
-    1,
+    2,
     OpSchema()
-        .SetDoc(Gradient_ver1_doc)
+        .SetDoc(Gradient_ver2_doc)
         .Input(
             0,
             "Inputs",
             "The values fed into graph identified by the attributes. "
             "The i-th input is the value of the i-th tensor specified in the "
             "concatenated list of the attribute \"xs\" and the attribute "
             " \"zs\". For example, if xs=[\"A\", \"B\"] and zs=[\"C\"], the "
             "first input is used as the value of symbol \"A\" and the 3rd "
             "input is substituted for all the occurrences of \"C\".",
             "T1",
             OpSchema::Variadic,
             false)
@@ -187,10 +187,10 @@ ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(
             "\"zs\" are the minimal independent variable set that determines "
             "the value of \"y\".",
             AttributeProto::STRING)
-        .TypeConstraint("T1", OpSchema::all_tensor_types(), "Allow outputs to be any kind of tensor.")
+        .TypeConstraint("T1", OpSchema::all_tensor_types_ir4(), "Allow outputs to be any kind of tensor.")
         .TypeConstraint(
             "T2",
-            {"tensor(float16)", "tensor(float)", "tensor(double)"},
+            OpSchema::all_float_types_ir4(),
             "Allow inputs to be any kind of floating-point tensor."));
 
 static const char* Adagrad_ver1_doc = R"DOC(
diff --git a/onnx/defs/training/old.cc b/onnx/defs/training/old.cc
new file mode 100644
index 00000000000..f1e7599d5e3
--- /dev/null
+++ b/onnx/defs/training/old.cc
@@ -0,0 +1,196 @@
+// /*
+// * SPDX-License-Identifier: Apache-2.0
+// */
+
+#include
+#include
+
+// #include "onnx/defs/function.h"
+#include "onnx/defs/schema.h"
+
+namespace ONNX_NAMESPACE {
+
+static const char* Gradient_ver1_doc = R"DOC(
+Gradient operator computes the partial derivatives of a specific tensor w.r.t.
+some other tensors. This operator is widely used in gradient-based training
+algorithms. To illustrate its use, let's consider a computation graph,
+
+```
+X -----.
+       |
+       v
+W --> Conv --> H --> Gemm --> Y
+                      ^
+                      |
+                      Z
+```
+
+, where W and Z are trainable tensors. Note that operators' attributes are
+omitted for the sake of simplicity. Let dY/dW (dY/dZ) be the gradient of
+Y with respect to W (Z). The user can compute gradient by inserting Gradient
+operator to form another graph shown below.
+
+```
+W --> Conv --> H --> Gemm --> Y
+|      ^              ^
+|      |              |
+|      X              Z
+|      |              |
+|      |   .----------'
+|      |   |  (W/Z/X is the 1st/2nd/3rd input of Gradient as shown in
+|      |   |   "xs" followed by "zs")
+|      v   v
+'---> Gradient(xs=["W", "Z"], zs=["X"], y="Y")
+       |   |
+       |   '-----------------------------------> dY/dW (1st output of Gradient)
+       |
+       '---------------------------------------> dY/dZ (2nd output of Gradient)
+```
+
+By definition, the tensor "y" is a function of independent variables in "xs"
+and "zs". Since we only compute the gradient of "y" w.r.t. the differentiable
+variables in "xs", this Gradient only outputs dY/dW and dY/dZ. Note that "H"
+cannot appear in "xs" and "zs". The reason is that "H" can be determined by
+tensors "W" and "X" and therefore "H" is not an independent variable.
+
+All outputs are optional. If needed, for example, user can assign an empty
+string to the 1st output name of that Gradient to skip the generation of dY/dW.
+Note that the concept of optional outputs can also be found in ONNX's RNN, GRU,
+and LSTM.
+
+Gradient operator can compute derivative against intermediate tensors. For
+example, the gradient of Y with respect to H can be done via
+
+```
+W --> Conv --> H --> Gemm --> Y
+       ^       |      ^
+       |       |      |
+       X       |      Z
+       .-------'      |
+       |   .----------'
+       |   |  (H/Z is the 1st/2nd input of Gradient as shown in "xs")
+       v   v
+       Gradient(xs=["H", "Z"], y="Y")
+       |   |
+       |   '-----------------------------------> dY/dH (1st output of Gradient)
+       |
+       '---------------------------------------> dY/dZ (2nd output of Gradient)
+```
+
+It is possible to represent high-order differentiation using Gradient operators.
+For example, given the following linear model:
+
+```
+W --> Gemm --> Y --> Loss --> O
+       ^              ^
+       |              |
+       X              L
+```
+
+To compute the 2nd order derivative of O with respect to W (denoted by
+d^2O/dW^2), one can do
+
+```
+W --> Gemm --> Y --> Loss --> O
+|      ^              ^
+|      |              |
+|      X .------------L
+|      | |            |
+|      | |            v
++------+-+> Gradient(xs=["X", "W"], zs=["L"], y="O") ---> dO/dX (1st output of Gradient)
+|      | |    |
+|      | |    '---> dO/dW (2nd output of Gradient)
+|      v v
+'---> Gradient(xs=["X", "W"], zs=["L"], y="dO/dW") ---> d(dO/dW)dX (1st output of
+       |                                                  Gradient)
+       |
+       |
+       '---> d^2O/dW^2 (2nd output of Gradient)
+```
+
+The tensors named in attributes "xs", "zs", and "y" define the differentiated
+computation graph, and the inputs to Gradient node define the values at
+which the gradient is computed. We can feed different tensors to the identified
+graph. For example, one can compute the gradient of Y with respect to H at
+a specific value of H, H_1, by providing that value as an input to the Gradient
+node.
+
+```
+W --> Conv --> H --> Gemm --> Y
+       ^              ^
+       |              |
+       X              Z
+
+          Z_1 (2nd input of Gradient)
+           |
+           v
+H_1 --> Gradient(xs=["H", "Z"], y="Y") ---> dY/dH when H = H_1 and Y = Y_1.
+           |
+           '------------------------------> dY/dZ (2nd output of Gradient)
+```
+
+When the inputs of Gradient are the tensors named in "xs" and "zs", the
+computation can be optimized. More specifically, intermediate variables in
+forward pass can be reused if the gradient is computed via reverse-mode
+auto-differentiation.
+
+)DOC";
+
+ONNX_PREVIEW_TRAINING_OPERATOR_SET_SCHEMA(
+    Gradient,
+    1,
+    OpSchema()
+        .SetDoc(Gradient_ver1_doc)
+        .Input(
+            0,
+            "Inputs",
+            "The values fed into graph identified by the attributes. "
+            "The i-th input is the value of the i-th tensor specified in the "
+            "concatenated list of the attribute \"xs\" and the attribute "
+            " \"zs\". For example, if xs=[\"A\", \"B\"] and zs=[\"C\"], the "
+            "first input is used as the value of symbol \"A\" and the 3rd "
+            "input is substituted for all the occurrences of \"C\".",
+            "T1",
+            OpSchema::Variadic,
+            false)
+        .Output(
+            0,
+            "Outputs",
+            "The gradient of the tensor specified by the attribute \"y\" "
+            "with respect to each of tensors specified in the "
+            "attribute \"xs\". The i-th output is the gradient of \"y\" with "
+            "respect to the i-th tensor specified in the attribute \"xs\".",
+            "T2",
+            OpSchema::Variadic,
+            false)
+        .Attr(
+            "xs",
+            "Input tensor names of the differentiated sub-graph. It "
+            "contains only the necessary differentiated "
+            "inputs of a (sub-)graph. Variables (usually called "
+            "intermediate variables) that can be generated from inputs "
+            "cannot be included in this attribute.",
+            AttributeProto::STRINGS)
+        .Attr(
+            "zs",
+            "Input tensor names of the differentiated sub-graph. It "
+            "contains only the necessary non-differentiated "
+            "inputs of a (sub-)graph. Variables (usually called "
+            "intermediate variables) that can be generated from inputs "
+            "cannot be included in this attribute.",
+            AttributeProto::STRINGS,
+            OPTIONAL_VALUE)
+        .Attr(
+            "y",
+            "The targeted tensor. It can be viewed as the output of the "
+            "differentiated function. The attribute \"xs\" and attribute "
+            "\"zs\" are the minimal independent variable set that determines "
+            "the value of \"y\".",
+            AttributeProto::STRING)
+        .TypeConstraint("T1", OpSchema::all_tensor_types(), "Allow outputs to be any kind of tensor.")
+        .TypeConstraint(
+            "T2",
+            {"tensor(float16)", "tensor(float)", "tensor(double)"},
+            "Allow inputs to be any kind of floating-point tensor."));
+
+} // namespace ONNX_NAMESPACE
diff --git a/onnx/helper.py b/onnx/helper.py
index d28df25ca3c..0b6512c769b 100644
--- a/onnx/helper.py
+++ b/onnx/helper.py
@@ -76,6 +76,7 @@
     ("1.14.1", 9, 19, 3, 1),
     ("1.15.0", 9, 20, 4, 1),
     ("1.16.0", 10, 21, 5, 1),
+    ("1.17.0", 10, 21, 5, 1),
 ]
 
 VersionMapType = Dict[Tuple[str, int], int]
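
To sanity-check the change from the Python side, the sketch below builds a minimal model that uses Gradient from the `ai.onnx.preview.training` domain with bfloat16 inputs, which is the combination this patch enables. It is not part of the patch: it only uses public `onnx` APIs (`onnx.helper`, `onnx.defs`, `onnx.checker`), the graph name, tensor names, and shapes are illustrative, and it assumes an onnx build that already contains this change; on an unpatched build the preview domain only exposes Gradient-1, whose "T2" constraint does not list bfloat16.

```python
# Reviewer-side sketch (not part of the patch): exercise Gradient-2 with bfloat16.
import onnx
from onnx import TensorProto, helper

PREVIEW_DOMAIN = "ai.onnx.preview.training"

# Forward computation: Y = X @ W via Gemm, with X and W declared as bfloat16.
gemm = helper.make_node("Gemm", ["X", "W"], ["Y"], name="forward_gemm")

# Gradient of Y w.r.t. X and W; "xs"/"y" name tensors of the differentiated graph,
# while the node inputs supply the values at which the gradient is evaluated.
grad = helper.make_node(
    "Gradient",
    ["X", "W"],
    ["dY_dX", "dY_dW"],
    name="gradient",
    domain=PREVIEW_DOMAIN,
    xs=["X", "W"],
    y="Y",
)

graph = helper.make_graph(
    [gemm, grad],
    "gradient_bfloat16_smoke_test",
    inputs=[
        helper.make_tensor_value_info("X", TensorProto.BFLOAT16, [2, 3]),
        helper.make_tensor_value_info("W", TensorProto.BFLOAT16, [3, 4]),
    ],
    outputs=[
        helper.make_tensor_value_info("dY_dX", TensorProto.BFLOAT16, [2, 3]),
        helper.make_tensor_value_info("dY_dW", TensorProto.BFLOAT16, [3, 4]),
    ],
)

# Import the default domain plus version 2 of the preview training domain added here.
model = helper.make_model(
    graph,
    opset_imports=[
        helper.make_opsetid("", 21),
        helper.make_opsetid(PREVIEW_DOMAIN, 2),
    ],
)

# The registry should now report since-version 2 for Gradient in the preview domain,
# and its "T2" constraint should include tensor(bfloat16).
schema = onnx.defs.get_schema("Gradient", domain=PREVIEW_DOMAIN)
print(schema.since_version)
for tc in schema.type_constraints:
    if tc.type_param_str == "T2":
        print(sorted(tc.allowed_type_strs))

# Structural validation; with only Gradient-1 registered this is expected to fail.
onnx.checker.check_model(model)
```

If the printed since-version is 1 rather than 2, the interpreter is most likely picking up a pre-patch onnx installation rather than the locally built one.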