From cee79526fd68b6fb6b09d47dae0d48b47d2f7021 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 25 Aug 2021 12:04:20 -0700 Subject: [PATCH] Add opset 15 kernels for Pow, BatchNorm, and Shape (#8442) --- docs/OperatorKernels.md | 18 +++-- .../onnxruntime/core/framework/tensor_shape.h | 9 +++ .../core/optimizer/constant_folding.cc | 38 ++++++++- .../core/optimizer/graph_transformer_utils.cc | 2 - .../core/optimizer/shape_to_initializer.cc | 80 ------------------- .../core/optimizer/shape_to_initializer.h | 32 -------- .../providers/cpu/cpu_execution_provider.cc | 38 ++++++--- .../providers/cpu/math/element_wise_ops.cc | 26 ++++-- .../core/providers/cpu/nn/batch_norm.cc | 50 ++++++++++-- .../core/providers/cpu/tensor/shape_op.cc | 9 ++- .../core/providers/cpu/tensor/shape_op.h | 45 ++++++++++- .../providers/cuda/cuda_execution_provider.cc | 35 +++++--- .../cuda/math/binary_elementwise_ops.cc | 27 +++++-- .../core/providers/cuda/nn/batch_norm.cc | 66 ++++++++------- .../core/providers/cuda/nn/batch_norm.h | 6 +- .../core/providers/cuda/tensor/shape_op.cc | 14 +++- .../providers/rocm/rocm_execution_provider.cc | 16 +++- onnxruntime/test/onnx/main.cc | 16 +--- .../test/optimizer/graph_transform_test.cc | 39 ++++----- .../cpu/math/element_wise_ops_test.cc | 47 +++++++---- .../providers/cpu/nn/batch_norm_op_test.cc | 25 ++++++ .../providers/cpu/tensor/shape_op_test.cc | 79 ++++++++++++++---- .../testdata/kernel_def_hashes/onnx.cpu.json | 16 ++++ .../onnx_backend_test_series_filters.jsonc | 1 - .../core/optimizer/graph_transformer_utils.cc | 11 ++- 25 files changed, 470 insertions(+), 275 deletions(-) delete mode 100644 onnxruntime/core/optimizer/shape_to_initializer.cc delete mode 100644 onnxruntime/core/optimizer/shape_to_initializer.h diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 10227e2d05f2..f4f395a2f84a 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -38,7 +38,8 @@ Do not modify directly.* |AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float)| |||10|**T** = tensor(float)| |||[7, 9]|**T** = tensor(float)| -|BatchNormalization|*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* input_mean:**U**
*in* input_var:**U**
*out* Y:**T**
*out* running_mean:**U**
*out* running_var:**U**

or

*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* mean:**T**
*in* var:**T**
*out* Y:**T**
*out* mean:**T**
*out* var:**T**
*out* saved_mean:**T**
*out* saved_var:**T**

or

*in* X:**T**
*in* scale:**T1**
*in* B:**T1**
*in* input_mean:**T2**
*in* input_var:**T2**
*out* Y:**T**
*out* running_mean:**T2**
*out* running_var:**T2**|14+|**T** = tensor(double), tensor(float)| +|BatchNormalization|*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* input_mean:**U**
*in* input_var:**U**
*out* Y:**T**
*out* running_mean:**U**
*out* running_var:**U**

or

*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* mean:**T**
*in* var:**T**
*out* Y:**T**
*out* mean:**T**
*out* var:**T**
*out* saved_mean:**T**
*out* saved_var:**T**

or

*in* X:**T**
*in* scale:**T1**
*in* B:**T1**
*in* input_mean:**T2**
*in* input_var:**T2**
*out* Y:**T**
*out* running_mean:**T2**
*out* running_var:**T2**|15+|**T** = tensor(double), tensor(float)
**T1** = tensor(double), tensor(float)
**T2** = tensor(double), tensor(float)| +|||14|**T** = tensor(double), tensor(float)
**U** = tensor(double), tensor(float)| |||[9, 13]|**T** = tensor(double), tensor(float)| |||[7, 8]|**T** = tensor(double), tensor(float)| |BitShift|*in* X:**T**
*in* Y:**T**
*out* Z:**T**|11+|**T** = tensor(uint32), tensor(uint64), tensor(uint8)| @@ -202,7 +203,8 @@ Do not modify directly.* |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(int8), tensor(uint32), tensor(uint64), tensor(uint8)| |||[2, 10]|**T** = tensor(double), tensor(float)| |ParametricSoftplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| -|Pow|*in* X:**T**
*in* Y:**T**
*out* Z:**T**

or

*in* X:**T**
*in* Y:**T1**
*out* Z:**T**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|Pow|*in* X:**T**
*in* Y:**T**
*out* Z:**T**

or

*in* X:**T**
*in* Y:**T1**
*out* Z:**T**|15+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|||[13, 14]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 11]|**T** = tensor(double), tensor(float)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(uint8)
**T4** = tensor(int32)| @@ -280,7 +282,8 @@ Do not modify directly.* |SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|Shape|*in* data:**T**
*out* shape:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||[13, 14]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| @@ -446,7 +449,8 @@ Do not modify directly.* |AveragePool|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| |||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(float16)| -|BatchNormalization|*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* input_mean:**U**
*in* input_var:**U**
*out* Y:**T**
*out* running_mean:**U**
*out* running_var:**U**

or

*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* mean:**T**
*in* var:**T**
*out* Y:**T**
*out* mean:**T**
*out* var:**T**
*out* saved_mean:**T**
*out* saved_var:**T**

or

*in* X:**T**
*in* scale:**T1**
*in* B:**T1**
*in* input_mean:**T2**
*in* input_var:**T2**
*out* Y:**T**
*out* running_mean:**T2**
*out* running_var:**T2**|14+|**T** = tensor(double), tensor(float), tensor(float16)| +|BatchNormalization|*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* input_mean:**U**
*in* input_var:**U**
*out* Y:**T**
*out* running_mean:**U**
*out* running_var:**U**

or

*in* X:**T**
*in* scale:**T**
*in* B:**T**
*in* mean:**T**
*in* var:**T**
*out* Y:**T**
*out* mean:**T**
*out* var:**T**
*out* saved_mean:**T**
*out* saved_var:**T**

or

*in* X:**T**
*in* scale:**T1**
*in* B:**T1**
*in* input_mean:**T2**
*in* input_var:**T2**
*out* Y:**T**
*out* running_mean:**T2**
*out* running_var:**T2**|15+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(double), tensor(float), tensor(float16)| +|||14|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)| |||[9, 13]|**T** = tensor(double), tensor(float), tensor(float16)| |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| |Cast|*in* input:**T1**
*out* output:**T2**|13+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -582,7 +586,8 @@ Do not modify directly.* |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)| |ParametricSoftplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|Pow|*in* X:**T**
*in* Y:**T**
*out* Z:**T**

or

*in* X:**T**
*in* Y:**T1**
*out* Z:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|Pow|*in* X:**T**
*in* Y:**T**
*out* Z:**T**

or

*in* X:**T**
*in* Y:**T1**
*out* Z:**T**|15+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| +|||[13, 14]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| |||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64)| |||[7, 11]|**T** = tensor(double), tensor(float), tensor(float16)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|10+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| @@ -653,7 +658,8 @@ Do not modify directly.* |SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|Shape|*in* data:**T**
*out* shape:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||[13, 14]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
diff --git a/include/onnxruntime/core/framework/tensor_shape.h b/include/onnxruntime/core/framework/tensor_shape.h
index 9a2609bc1fc9..89d9b105946e 100644
--- a/include/onnxruntime/core/framework/tensor_shape.h
+++ b/include/onnxruntime/core/framework/tensor_shape.h
@@ -72,6 +72,15 @@ class TensorShape : private std::vector<int64_t> {
     memcpy(dims, data(), sizeof(value_type) * std::min(num_dims, NumDimensions()));
   }
 
+  /**
+     Copy dims from a specific start dim into an array with the given size.
+     `start_dim` is expected to be in the inclusive range [0, NumDimensions() - 1]
+     and this function does no checks to ensure that.
+  */
+  void CopyDims(int64_t* dims, size_t start_dim, size_t num_dims) const {
+    memcpy(dims, data() + start_dim, sizeof(value_type) * std::min(num_dims, NumDimensions() - start_dim));
+  }
+
   /**
      Return underlying vector representation.
   */
diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc
index e78c957bca11..af585db08dcc 100644
--- a/onnxruntime/core/optimizer/constant_folding.cc
+++ b/onnxruntime/core/optimizer/constant_folding.cc
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include <limits>
+
 #include "core/optimizer/constant_folding.h"
 #include "core/optimizer/utils.h"
 #include "core/graph/graph_utils.h"
@@ -25,6 +27,20 @@ ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider,
 // We need to handle a Shape node separately as the input doesn't need to be a constant initializer for
 // Shape to be able to be constant folded.
 static bool ConstantFoldShapeNode(Graph& graph, Node& node) {
+  // Opset-15 Shape supports slicing using the 'start' and 'end' attributes
+  const auto& shape_attributes = node.GetAttributes();
+
+  int64_t start = 0;
+  int64_t end = std::numeric_limits<int64_t>::max();
+
+  for (const auto& attr : shape_attributes) {
+    if (attr.first == "start") {
+      start = attr.second.i();
+    } else if (attr.first == "end") {
+      end = attr.second.i();
+    }
+  }
+
   auto shape = node.MutableInputDefs()[0]->Shape();
   bool is_concrete_shape = true;
   std::vector<int64_t> dim_values;
@@ -42,14 +58,30 @@ static bool ConstantFoldShapeNode(Graph& graph, Node& node) {
   }
 
   if (is_concrete_shape) {
+    int64_t rank = static_cast<int64_t>(dim_values.size());
+
+    // Ascertain the "true" start/end (if they were provided):
+    // the opset-15 Shape op supports slicing the shape values.
+
+    // Deal with negatives and clamp
+    start = start < 0 ? start + rank : start;
+    start = start < 0 ? 0 : ((start > rank) ? rank : start);
+
+    end = end < 0 ? end + rank : end;
+    end = end < 0 ? 0 : ((end > rank) ? rank : end);
+
+    int64_t slice_length = end - start;
+    size_t clamped_slice_length = slice_length < 0 ? 0 : static_cast<size_t>(slice_length);
+
     ONNX_NAMESPACE::TensorProto shape_constant;
     auto* constant_arg_out = node.MutableOutputDefs()[0];
     shape_constant.set_name(constant_arg_out->Name());
     shape_constant.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
-    shape_constant.add_dims(dim_values.size());
-    shape_constant.set_raw_data(dim_values.data(), dim_values.size() * sizeof(int64_t));
+    shape_constant.add_dims(clamped_slice_length);
+    shape_constant.set_raw_data(dim_values.data() + start,
+                                clamped_slice_length * sizeof(int64_t));
     ONNX_NAMESPACE::TensorShapeProto result_shape;
-    result_shape.add_dim()->set_dim_value(dim_values.size());
+    result_shape.add_dim()->set_dim_value(clamped_slice_length);
     constant_arg_out->SetShape(result_shape);
     graph.AddInitializedTensor(shape_constant);
   }
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index fead80b22cfb..6a0bc646bf03 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -37,7 +37,6 @@
 #include "core/optimizer/relu_clip_fusion.h"
 #include "core/optimizer/reshape_fusion.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
-#include "core/optimizer/shape_to_initializer.h"
 #include "core/optimizer/skip_layer_norm_fusion.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/unsqueeze_elimination.h"
@@ -75,7 +74,6 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
   rules.push_back(std::make_unique());
   rules.push_back(std::make_unique());
   rules.push_back(std::make_unique());
-  rules.push_back(std::make_unique<ShapeToInitializer>());
   rules.push_back(std::make_unique());
   rules.push_back(std::make_unique());
   rules.push_back(std::make_unique());
diff --git a/onnxruntime/core/optimizer/shape_to_initializer.cc b/onnxruntime/core/optimizer/shape_to_initializer.cc
deleted file mode 100644
index 02aa17d7724b..000000000000
--- a/onnxruntime/core/optimizer/shape_to_initializer.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "core/optimizer/shape_to_initializer.h"
-#include "core/graph/graph.h"
-#include "core/graph/graph_utils.h"
-#include "core/graph/op.h"
-#include "core/optimizer/initializer.h"
-#include "core/optimizer/optimizer_execution_frame.h"
-#include "core/framework/op_kernel.h"
-#include "core/framework/tensorprotoutils.h"
-
-namespace onnxruntime {
-
-Status ShapeToInitializer::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
-  // Store the statically inferred shape of the input to the Shape operator.
-  const ONNX_NAMESPACE::TensorShapeProto* input_shape_proto = node.InputDefs()[0]->Shape();
-  std::vector<int64_t> input_dims;
-  int num_dimensions = input_shape_proto->dim_size();
-  for (int i = 0; i < num_dimensions; i++) {
-    input_dims.push_back(gsl::narrow_cast<int64_t>(input_shape_proto->dim(i).dim_value()));
-  }
-
-  // Create the TensorProto that will be used as initializer in place of the Shape operator.
- const auto* shape_out_def = node.OutputDefs()[0]; - - ONNX_NAMESPACE::TensorProto shape_initializer_proto; - - shape_initializer_proto.set_name(shape_out_def->Name()); - - TensorShape tensor_shape({gsl::narrow_cast(num_dimensions)}); - for (auto& dim : tensor_shape.GetDims()) { - shape_initializer_proto.add_dims(dim); - } - - auto tensor_proto_data_type = shape_out_def->TypeAsProto()->tensor_type().elem_type(); - shape_initializer_proto.set_data_type(tensor_proto_data_type); - - // Here we expect little-endian format to set raw data of the TensorProto. - shape_initializer_proto.set_raw_data(input_dims.data(), - input_dims.size() * sizeof(decltype(input_dims)::value_type)); - - auto& new_node_arg = graph_utils::AddInitializer(graph, shape_initializer_proto); - - if (graph_utils::ReplaceNodeWithInitializer(graph, node, new_node_arg)) { - rule_effect = RewriteRuleEffect::kRemovedCurrentNode; - } - - return Status::OK(); -} - -bool ShapeToInitializer::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Shape", {1, 13})) { - return false; - } - - // The shape of the input has to be statically known. Moreover, each dimension should have a meaningful value - // (the rule cannot be applied if one of the dimensions has a negative value or if it is a symbolic variable). - const auto* input_shape = node.InputDefs()[0]->Shape(); - if (!input_shape) { - return false; - } - - for (int i = 0, num_dims = input_shape->dim_size(); i < num_dims; i++) { - const auto& input_dim = input_shape->dim(i); - if (!utils::HasDimValue(input_dim) || input_dim.dim_value() < 0) { - return false; - } - } - - // we're going to create an initializer with the same name as the node output - const auto& new_initializer_name = node.OutputDefs()[0]->Name(); - if (!graph_utils::CanReplaceNodeWithInitializer(graph, node, new_initializer_name, logger)) { - return false; - } - - return true; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/shape_to_initializer.h b/onnxruntime/core/optimizer/shape_to_initializer.h deleted file mode 100644 index fd749821d517..000000000000 --- a/onnxruntime/core/optimizer/shape_to_initializer.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/optimizer/rewrite_rule.h" - -namespace onnxruntime { - -/** -@Class ShapeToInitializer - -When the input to a Shape operator is statically known (through shape inference), this rule replaces the Shape node -with an initializer to the downstream nodes. - -It is attempted to be triggered only on nodes with op type "Shape". 
-*/ -class ShapeToInitializer : public RewriteRule { - public: - ShapeToInitializer() noexcept : RewriteRule("ShapeToInitializer") {} - - std::vector TargetOpTypes() const noexcept override { - return {"Shape"}; - } - - private: - bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; - - Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 2186d7c2f23e..620f1632bf02 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -527,7 +527,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, ArgMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, ArgMin); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Less); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Less); @@ -590,7 +590,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Exp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Log); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Log); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 14, Pow); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, DepthToSpace); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, SpaceToDepth); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Slice); @@ -686,12 +686,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Div); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Identity); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, GRU); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, LSTM); class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, RNN); +// Opset 15 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, Pow); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, float, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, double, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 15, Shape); // !!PLEASE READ BELOW!! Following that, add new entries above this comment @@ -1151,9 +1156,9 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BatchNormalization)>, BuildKernelCreateInfo, + BatchNormalization)>, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1547,7 +1552,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { int32_t, ArgMin)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1650,7 +1655,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1810,13 +1815,22 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Div)>, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // Opset 15 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 52b80a092313..b00de30dcadb 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -259,13 +259,27 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sqrt, 6, 12, double, Sqrt); REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 13, float, Sqrt); REG_ELEMENTWISE_TYPED_KERNEL(Sqrt, 13, double, Sqrt); -REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 7, 11, Pow, BuildKernelDefConstraintsFromTypeList(), BuildKernelDefConstraintsFromTypeList()); +REG_ELEMENTWISE_VERSIONED_KERNEL_NONT(Pow, 7, 11, Pow, + BuildKernelDefConstraintsFromTypeList(), + BuildKernelDefConstraintsFromTypeList()); + REG_ELEMENTWISE_VERSIONED_KERNEL_NONT_2(Pow, 12, 12, Pow, - BuildKernelDefConstraintsFromTypeList(), BuildKernelDefConstraintsFromTypeList(), - BuildKernelDefConstraintsFromTypeList(), BuildKernelDefConstraintsFromTypeList()); -REG_ELEMENTWISE_KERNEL_NONT_2(Pow, 13, Pow, - BuildKernelDefConstraintsFromTypeList(), BuildKernelDefConstraintsFromTypeList(), - BuildKernelDefConstraintsFromTypeList(), BuildKernelDefConstraintsFromTypeList()); + BuildKernelDefConstraintsFromTypeList(), + BuildKernelDefConstraintsFromTypeList(), + BuildKernelDefConstraintsFromTypeList(), + BuildKernelDefConstraintsFromTypeList()); + +REG_ELEMENTWISE_VERSIONED_KERNEL_NONT_2(Pow, 13, 14, Pow, + BuildKernelDefConstraintsFromTypeList(), + BuildKernelDefConstraintsFromTypeList(), + BuildKernelDefConstraintsFromTypeList(), + BuildKernelDefConstraintsFromTypeList()); + +REG_ELEMENTWISE_KERNEL_NONT_2(Pow, 15, Pow, + BuildKernelDefConstraintsFromTypeList(), + 
BuildKernelDefConstraintsFromTypeList(), + BuildKernelDefConstraintsFromTypeList(), + BuildKernelDefConstraintsFromTypeList()); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Exp, 6, 12, float, Exp); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Exp, 6, 12, double, Exp); diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.cc b/onnxruntime/core/providers/cpu/nn/batch_norm.cc index 92fff6974bff..ebd1af6c0e9c 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm.cc +++ b/onnxruntime/core/providers/cpu/nn/batch_norm.cc @@ -31,18 +31,52 @@ ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 7, 8, double, // We alias the running mean to the mean so it stays preserved across multiple batches ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 9, 13, float, - KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType()), - BatchNorm); + KernelDefBuilder().Alias(3, 1).Alias(4, 2).TypeConstraint("T", DataTypeImpl::GetTensorType()), + BatchNorm); ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 9, 13, double, - KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType()), - BatchNorm); + KernelDefBuilder().Alias(3, 1).Alias(4, 2).TypeConstraint("T", DataTypeImpl::GetTensorType()), + BatchNorm); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 14, 14, float, + KernelDefBuilder() + .Alias(3, 1) + .Alias(4, 2) + // ORT 1.8 was shipped with just the "T" type constraint and + // we want to maintain backwards compatibility for + // the hash and hence just use "T" for the hash generation + .FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType()}) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("U", DataTypeImpl::GetTensorType()), + BatchNorm); + +ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(BatchNormalization, 14, 14, double, + KernelDefBuilder() + .Alias(3, 1) + .Alias(4, 2) + // ORT 1.8 was shipped with just the "T" type constraint and + // we want to maintain backwards compatibility for + // the hash and hence just use "T" for the hash generation + .FixedTypeConstraintForHash("T", {DataTypeImpl::GetTensorType()}) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("U", DataTypeImpl::GetTensorType()), + BatchNorm); -ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 14, float, - KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType()), +ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 15, float, + KernelDefBuilder() + .Alias(3, 1) + .Alias(4, 2) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), BatchNorm); -ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 14, double, - KernelDefBuilder().Alias(3,1).Alias(4,2).TypeConstraint("T", DataTypeImpl::GetTensorType()), +ONNX_CPU_OPERATOR_TYPED_KERNEL(BatchNormalization, 15, double, + KernelDefBuilder() + .Alias(3, 1) + .Alias(4, 2) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), BatchNorm); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/shape_op.cc b/onnxruntime/core/providers/cpu/tensor/shape_op.cc index 4020d5368511..780cf4b73399 100644 --- a/onnxruntime/core/providers/cpu/tensor/shape_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/shape_op.cc @@ -12,9 +12,16 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( 
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType()), Shape); +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( + Shape, + 13, 14, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + ONNX_CPU_OPERATOR_KERNEL( Shape, - 13, + 15, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorTypes()).TypeConstraint("T1", DataTypeImpl::GetTensorType()), Shape); + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/shape_op.h b/onnxruntime/core/providers/cpu/tensor/shape_op.h index 65881092617f..1bdfb8e0c964 100644 --- a/onnxruntime/core/providers/cpu/tensor/shape_op.h +++ b/onnxruntime/core/providers/cpu/tensor/shape_op.h @@ -9,25 +9,62 @@ #endif #include "gsl/gsl" +#include namespace onnxruntime { class Shape final : public OpKernel { public: Shape(const OpKernelInfo& info) : OpKernel(info) { + info.GetAttrOrDefault("start", &start_index_, 0); + + if (start_index_ != 0) { + // "start" is provided and is non-default (default is 0) + needs_slicing_ = true; + } + + if (info.GetAttr("end", &end_index_).IsOK()) { + needs_slicing_ = true; + } } // Takes a tensor as input and outputs an 1D int64 tensor // containing the shape of the input tensor. Status Compute(OpKernelContext* context) const override { const auto* input = context->Input(0); - const TensorShape& inputShape = input->Shape(); + const TensorShape& input_shape = input->Shape(); + + int64_t rank = gsl::narrow_cast(input_shape.NumDimensions()); + + if (!needs_slicing_) { // vanilla use of Shape (no slicing) + Tensor* output = context->Output(0, {rank}); + input_shape.CopyDims(output->template MutableData(), static_cast(rank)); + } else { // slicing is needed + int64_t true_start = start_index_; + int64_t true_end = end_index_; + + // Deal with negative(s) and clamp + true_start = true_start < 0 ? true_start + rank : true_start; + true_start = true_start < 0 ? 0 : ((true_start > rank) ? rank : true_start); - size_t nDims = inputShape.NumDimensions(); - Tensor* output = context->Output(0, {gsl::narrow_cast(nDims)}); + true_end = true_end < 0 ? true_end + rank : true_end; + true_end = true_end < 0 ? 0 : ((true_end > rank) ? rank : true_end); + + auto slice_length = true_end - true_start; + Tensor* output = context->Output(0, {slice_length < 0 ? 
0 : slice_length}); + + if (slice_length > 0) { + input_shape.CopyDims(output->template MutableData(), true_start, slice_length); + } + } - inputShape.CopyDims(output->template MutableData(), nDims); return Status::OK(); } + + private: + bool needs_slicing_ = false; + int64_t start_index_ = 0; + int64_t end_index_ = std::numeric_limits::max(); }; + } //namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 83d9f6a3fd94..726cf1d13328 100755 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -892,7 +892,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum); //OpSet 13 -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int64_t, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Add); @@ -1005,7 +1005,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint64_t, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, bool, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Reshape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Size); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterElements); @@ -1163,9 +1163,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, ReduceMin); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin); @@ -1182,6 +1182,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, BFloat16, Relu); #endif +//OpSet 15 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Pow); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, float, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, double, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, Shape); + template <> KernelCreateInfo BuildKernelCreateInfo() { return {}; @@ -1728,7 +1735,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // OpSet 13 - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1841,7 +1848,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1998,9 +2005,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2015,6 +2022,14 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, #endif + + // OpSet 15 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc index a20e8e15f34c..f5b8c9aa6b96 100644 --- a/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc @@ -205,9 +205,9 @@ Status BinaryElementwise::Prepare(OpKernelContext* context, Bin #define BINARY_OP_TYPED_VERSIONED_V_BF16(name, class_name, startver, endver) #endif -#define BINARY_OP_VERSIONED_HFD(name, startver, endver) \ - BINARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \ - BINARY_OP_VERSIONED_TYPED(name, startver, endver, float) \ +#define BINARY_OP_VERSIONED_HFD(name, startver, endver) \ + BINARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \ + BINARY_OP_VERSIONED_TYPED(name, startver, endver, float) \ BINARY_OP_VERSIONED_TYPED(name, startver, endver, double) #define BINARY_OP_VERSIONED_UZILHFD(name, startver, endver) \ @@ -318,15 +318,29 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( kOnnxDomain, 12, 12, kCudaExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()).TypeConstraint("T1", BuildKernelDefConstraints()), + (*KernelDefBuilder::Create()) + .TypeConstraint("T", 
BuildKernelDefConstraints()) + .TypeConstraint("T1", BuildKernelDefConstraints()), + Pow); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Pow, + kOnnxDomain, + 13, 14, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraints()) + .TypeConstraint("T1", BuildKernelDefConstraints()), Pow); ONNX_OPERATOR_KERNEL_EX( Pow, kOnnxDomain, - 13, + 15, kCudaExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()).TypeConstraint("T1", BuildKernelDefConstraints()), + (*KernelDefBuilder::Create()) + .TypeConstraint("T", BuildKernelDefConstraints()) + .TypeConstraint("T1", BuildKernelDefConstraints()), Pow); namespace pow12_internal { @@ -524,6 +538,5 @@ BINARY_OP_REGISTER_VERSIONED_HFD(Less, 7, 8) BINARY_LOGICALOP_REGISTER_UZILHFD(GreaterOrEqual, 12) BINARY_LOGICALOP_REGISTER_UZILHFD(LessOrEqual, 12) - } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.cc b/onnxruntime/core/providers/cuda/nn/batch_norm.cc index 1a09ec8f458a..675f002cb87a 100644 --- a/onnxruntime/core/providers/cuda/nn/batch_norm.cc +++ b/onnxruntime/core/providers/cuda/nn/batch_norm.cc @@ -11,33 +11,45 @@ using namespace std; namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - BatchNormalization, \ - kOnnxDomain, \ - 7, 8, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - BatchNorm); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - BatchNormalization, \ - kOnnxDomain, \ - 9, 13, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - BatchNorm); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - BatchNormalization, \ - kOnnxDomain, \ - 14, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + BatchNormalization, \ + kOnnxDomain, \ + 7, 8, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + BatchNorm); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + BatchNormalization, \ + kOnnxDomain, \ + 9, 13, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + BatchNorm); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + BatchNormalization, \ + kOnnxDomain, \ + 14, 14, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("U", DataTypeImpl::GetTensorType()), \ + BatchNorm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + BatchNormalization, \ + kOnnxDomain, \ + 15, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ BatchNorm); template diff --git a/onnxruntime/core/providers/cuda/nn/batch_norm.h b/onnxruntime/core/providers/cuda/nn/batch_norm.h index e8ded6a571f3..792e7da66417 100644 --- a/onnxruntime/core/providers/cuda/nn/batch_norm.h +++ b/onnxruntime/core/providers/cuda/nn/batch_norm.h @@ -39,8 +39,8 @@ class BatchNorm final : public CudaKernel { const auto& node = op_kernel_info.node(); auto opset = node.SinceVersion(); - 
// batch norm opset 14 is not implemented for training mode - ORT_ENFORCE(!(is_training_mode_ && opset==14), "Training mode does not support BN opset 14 yet."); + // batch norm opset 14 (or higher) is not implemented for training mode + ORT_ENFORCE(!(is_training_mode_ && opset >= 14), "Training mode does not support BN opset 14 (or higher) yet."); } Status ComputeInternal(OpKernelContext* context) const override; @@ -50,7 +50,7 @@ class BatchNorm final : public CudaKernel { int64_t spatial_ = 1; // default as per spec cudnnBatchNormMode_t cudnn_batch_norm_mode_; double momentum_; - bool is_training_mode_ = 0; //default as per spec + bool is_training_mode_ = 0; //default as per spec }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/shape_op.cc b/onnxruntime/core/providers/cuda/tensor/shape_op.cc index 5680831f5d12..1007fcc50b38 100644 --- a/onnxruntime/core/providers/cuda/tensor/shape_op.cc +++ b/onnxruntime/core/providers/cuda/tensor/shape_op.cc @@ -20,10 +20,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( .TypeConstraint("T1", DataTypeImpl::GetTensorType()), Shape); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Shape, + kOnnxDomain, + 13, 14, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + // properly force CPU/GPU synch inside the kernel + .OutputMemoryType(OrtMemTypeCPUInput, 0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), + Shape); + ONNX_OPERATOR_KERNEL_EX( Shape, kOnnxDomain, - 13, + 15, kCudaExecutionProvider, (*KernelDefBuilder::Create()) // properly force CPU/GPU synch inside the kernel diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index ec5bf563562e..1c2574cdc7dc 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -827,7 +827,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, Einsum); //OpSet 13 -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int64_t, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint32_t, Add); @@ -934,7 +934,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Cast); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Reshape); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements); @@ -1029,6 +1029,10 @@ class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad); +// opset 15 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, Shape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, Pow); + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -1555,7 +1559,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // OpSet 13 - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1662,7 +1666,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1756,6 +1760,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // opset 15 + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index f8dc8ee9e412..9085613600ca 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -291,7 +291,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) { double per_sample_tolerance = 1e-3; // when cuda is enabled, set it to a larger value for resolving random MNIST test failure // when openvino is enabled, set it to a larger value for resolving MNIST accuracy mismatch - double relative_per_sample_tolerance = enable_cuda ? 0.017 : enable_openvino ? 0.009 : 1e-3; + double relative_per_sample_tolerance = enable_cuda ? 0.017 : enable_openvino ? 0.009 + : 1e-3; Ort::SessionOptions sf; @@ -480,8 +481,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { ORT_TSTR("operator_pow"), ORT_TSTR("bernoulli"), ORT_TSTR("bernoulli_double"), - ORT_TSTR("bernoulli_seed") - }; + ORT_TSTR("bernoulli_seed")}; static const ORTCHAR_T* cuda_flaky_tests[] = { ORT_TSTR("fp16_inception_v1"), @@ -600,16 +600,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) { {"bernoulli_seed", "By design. Test data is for informational purpose because the generator is non deterministic."}, {"bernoulli_seed_expanded", "By design. Test data is for informational purpose because the generator is non deterministic."}, {"bernoulli_expanded", "By design. 
Test data is for informational purpose because the generator is non deterministic."},
-      {"shape", "opset15 updates not supported yet."},
-      {"shape_clip_end", "opset15 updates not supported yet."},
-      {"shape_clip_start", "opset15 updates not supported yet."},
-      {"shape_end_1", "opset15 updates not supported yet."},
-      {"shape_end_negative_1", "opset15 updates not supported yet."},
-      {"shape_example", "opset15 updates not supported yet."},
-      {"shape_start_1", "opset15 updates not supported yet."},
-      {"shape_start_1_end_2", "opset15 updates not supported yet."},
-      {"shape_start_1_end_negative_1", "opset15 updates not supported yet."},
-      {"shape_start_negative_1", "opset15 updates not supported yet."},
       {"test_optional_get_element", "opset15 updates not supported yet."},
       {"test_optional_get_element_sequence", "opset15 updates not supported yet."},
       {"test_optional_has_element", "opset15 updates not supported yet."},
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index d981343f658d..ec36c2988e0b 100644
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -53,7 +53,6 @@
 #include "core/optimizer/relu_clip_fusion.h"
 #include "core/optimizer/reshape_fusion.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
-#include "core/optimizer/shape_to_initializer.h"
 #include "core/optimizer/skip_layer_norm_fusion.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/unsqueeze_elimination.h"
@@ -480,7 +479,7 @@ TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConst
   ASSERT_TRUE(op_to_count["RandomUniform"] == 0);
 }
 
-TEST_F(GraphTransformationTests, ShapeToInitializer) {
+TEST_F(GraphTransformationTests, ConstantFoldingAShapeNodeDeepInTheGraph) {
   auto model_uri = MODEL_FOLDER "shape-add.onnx";
   std::shared_ptr<Model> model;
   ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
@@ -489,17 +488,21 @@
   ASSERT_TRUE(op_to_count["Shape"] == 4);
 
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
-  rule_transformer_L1->Register(std::make_unique<ShapeToInitializer>());
-  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
-
+  std::unique_ptr<CPUExecutionProvider> e =
+      std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+  graph_transformation_mgr.Register(std::make_unique<ConstantFolding>(*e.get(),
+                                                                      false /*skip_dequantize_linear*/),
+                                    TransformerLevel::Level1);
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
 
   op_to_count = CountOpsInGraph(graph);
-  // Two of the Shapes are not eliminated because:
-  // One includes a symbolic dimension.
-  // Another one includes a negative dimension
-  ASSERT_TRUE(op_to_count["Shape"] == 2);
+
+  // A Shape node very deep in the graph (feeding into an Identity node that
+  // produces the graph output) gets constant folded; folding it removes all of
+  // its ancestors, and the Identity node consuming the Shape output is then
+  // constant folded as well, leaving the graph with no nodes.
+  ASSERT_TRUE(op_to_count.size() == 0);
 }
 
 // Check transformations in the case of a subgraph with constant inputs.
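// ---------------------------------------------------------------------------
// Illustration (a sketch, not part of the patch): the constant folder and the
// opset-15 Shape kernel above normalize the optional 'start'/'end' attributes
// with the same rule. The helper name below is hypothetical.
#include <algorithm>
#include <cstdint>

// Map a possibly-negative, possibly-out-of-range attribute value into [0, rank].
inline int64_t NormalizeShapeSliceIndex(int64_t index, int64_t rank) {
  if (index < 0) index += rank;  // negative values count from the back
  return std::min(std::max<int64_t>(index, 0), rank);  // clamp into [0, rank]
}

// For an input of rank 4 (shape {2, 3, 4, 5}):
//   start = 1, end = -1  ->  slice [1, 3): output {3, 4}
//   start = -1           ->  slice [3, 4): output {5}
//   start = 3, end = 1   ->  negative slice length, clamped to an empty output
// ---------------------------------------------------------------------------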
@@ -674,8 +677,8 @@ TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) { graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Add"] == 0); //Add removed from graph - ASSERT_TRUE(op_to_count["Relu"] == 0); //Relu removed from graph + ASSERT_TRUE(op_to_count["Add"] == 0); //Add removed from graph + ASSERT_TRUE(op_to_count["Relu"] == 0); //Relu removed from graph } //Conv->Add->Relu will be left intact since there is Identity depend on Add @@ -695,9 +698,9 @@ TEST_F(GraphTransformationTests, FuseCudaConvAddReluIdentity) { graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Add"] == 1); //Add remains - ASSERT_TRUE(op_to_count["Relu"] == 1); //Relu remains - ASSERT_TRUE(op_to_count["Identity"] == 1); //Identity remains + ASSERT_TRUE(op_to_count["Add"] == 1); //Add remains + ASSERT_TRUE(op_to_count["Relu"] == 1); //Relu remains + ASSERT_TRUE(op_to_count["Identity"] == 1); //Identity remains } //Conv->Add will be left intact since there is no Relu follows @@ -715,7 +718,7 @@ TEST_F(GraphTransformationTests, FuseCudaConvAdd) { graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); op_to_count = CountOpsInGraph(graph); - ASSERT_TRUE(op_to_count["Add"] == 1); //Add remains, no transform applied to the graph + ASSERT_TRUE(op_to_count["Add"] == 1); //Add remains, no transform applied to the graph } #endif @@ -4131,13 +4134,13 @@ TEST_F(GraphTransformationTests, FilterEnabledOptimizers) { const auto& graph = session_object.GetGraph(); - // check the ops that should go away if the constant folding transformer or ShapeToInitializer rewrite rule run + // check the ops that should go away if the constant folding transformer runs std::map op_to_count = CountOpsInGraph(graph); ASSERT_TRUE(op_to_count["Shape"] == 1); ASSERT_TRUE(op_to_count["ConstantOfShape"] == 1); ASSERT_TRUE(op_to_count["Add"] == 1); - ASSERT_STATUS_OK(session_object.FilterEnabledOptimizers({"ConstantFolding", "ShapeToInitializer"})); + ASSERT_STATUS_OK(session_object.FilterEnabledOptimizers({"ConstantFolding"})); ASSERT_STATUS_OK(session_object.Initialize()); // Initialize runs the transformers op_to_count = CountOpsInGraph(graph); diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 8de061e42d70..f4047b2c8892 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -688,6 +688,21 @@ TEST(MathOpTest, Pow_Float_12) { test.Run(); } +TEST(MathOpTest, Pow_Float_15) { + OpTester test("Pow", 15); + std::vector dims{2, 2}; + test.AddInput("X", dims, + {2.0f, 2.0f, + std::sqrt(2.0f), 1.0f}); + test.AddInput("Y", dims, + {0.0f, 8.0f, + 2.0f, 9.0f}); + test.AddOutput("Z", dims, + {1.0f, 256.0f, + 2.0f, 1.0f}); + test.Run(); +} + TEST(MathOpTest, Pow_Double_12) { OpTester test("Pow", 12); std::vector dims{2, 2}; @@ -1635,7 +1650,7 @@ TEST(MathOpTest, LessOrEqual) { test.AddInput("B", dims, {1.0f, 1.0f, 2.0f, -1.0f}); test.AddOutput("C", dims, {true, true, 
@@ -1635,7 +1650,7 @@ TEST(MathOpTest, LessOrEqual) {
   test.AddInput<float>("B", dims, {1.0f, 1.0f, 2.0f, -1.0f});
   test.AddOutput<bool>("C", dims, {true, true, true, true});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, LessOrEqual_Scalar0) {
@@ -1644,7 +1659,7 @@ TEST(MathOpTest, LessOrEqual_Scalar0) {
   test.AddInput<float>("B", {4}, {1.0f, 1.5f, 2.0f, -1.0f});
   test.AddOutput<bool>("C", {4}, {true, true, true, false});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, LessOrEqual_Scalar1) {
@@ -1653,7 +1668,7 @@ TEST(MathOpTest, LessOrEqual_Scalar1) {
   test.AddInput<float>("B", {1}, {1.0f});
   test.AddOutput<bool>("C", {4}, {true, true, false, true});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, LessOrEqual_int64_Scalar1) {
@@ -1662,7 +1677,7 @@ TEST(MathOpTest, LessOrEqual_int64_Scalar1) {
   test.AddInput<int64_t>("B", {1}, {1});
   test.AddOutput<bool>("C", {4}, {true, true, false, true});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, LessOrEqual_broadcastAB) {
@@ -1670,7 +1685,7 @@ TEST(MathOpTest, LessOrEqual_broadcastAB) {
   test.AddInput<int32_t>("A", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17});
   test.AddInput<int32_t>("B", {2}, {15, 7});
   test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, false, false});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, LessOrEqual_broadcastBA) {
@@ -1679,7 +1694,7 @@ TEST(MathOpTest, LessOrEqual_broadcastBA) {
   test.AddInput<int32_t>("B", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17});
   test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, true, true});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, LessOrEqual_multidiretional_broadcastAB) {
@@ -1688,7 +1703,7 @@ TEST(MathOpTest, LessOrEqual_multidiretional_broadcastAB) {
   test.AddInput<int32_t>("B", {2}, {15, 7});
   test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, true, false});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, LessOrEqual_multidiretional_broadcastBA) {
@@ -1697,7 +1712,7 @@ TEST(MathOpTest, LessOrEqual_multidiretional_broadcastBA) {
   test.AddInput<int32_t>("B", {4, 1}, {10, 11, 12, 13});
   test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, false, true});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, Greater_7) {
@@ -1784,7 +1799,7 @@ TEST(MathOpTest, GreaterOrEqual_12_float) {
   test.AddInput<float>("B", dims, {1.0f, 1.0f, 2.0f, -1.0f});
   test.AddOutput<bool>("C", dims, {true, false, false, true});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, GreaterOrEqual_12_double) {
@@ -1794,7 +1809,7 @@ TEST(MathOpTest, GreaterOrEqual_12_double) {
   test.AddInput<double>("B", dims, {1.0, 1.0, 2.0, -1.0});
   test.AddOutput<bool>("C", dims, {true, false, true, true});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, GreaterOrEqual_12_int32) {
@@ -1804,7 +1819,7 @@ TEST(MathOpTest, GreaterOrEqual_12_int32) {
   test.AddInput<int32_t>("B", dims, {15, 7, 12, 9});
   test.AddOutput<bool>("C", dims, {false, true, true, true});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, GreaterOrEqual_12_int64) {
@@ -1814,7 +1829,7 @@ TEST(MathOpTest, GreaterOrEqual_12_int64) {
   test.AddInput<int64_t>("B", dims, {15, 7, 12, 9});
   test.AddOutput<bool>("C", dims, {false, true, true, true});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, GreaterOrEqual_broadcastAB) {
@@ -1823,7 +1838,7 @@ TEST(MathOpTest, GreaterOrEqual_broadcastAB) {
   test.AddInput<int32_t>("B", {2}, {15, 7});
   test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, true, true});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, GreaterOrEqual_broadcastBA) {
@@ -1832,7 +1847,7 @@ TEST(MathOpTest, GreaterOrEqual_broadcastBA) {
   test.AddInput<int32_t>("B", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17});
   test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, false, false});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, GreaterOrEqual_multidiretional_broadcastAB) {
@@ -1841,7 +1856,7 @@ TEST(MathOpTest, GreaterOrEqual_multidiretional_broadcastAB) {
   test.AddInput<int32_t>("B", {2}, {15, 7});
   test.AddOutput<bool>("C", {4, 2}, {false, true, false, true, false, true, false, true});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, GreaterOrEqual_multidiretional_broadcastBA) {
@@ -1850,7 +1865,7 @@ TEST(MathOpTest, GreaterOrEqual_multidiretional_broadcastBA) {
   test.AddInput<int32_t>("B", {4, 1}, {10, 11, 12, 13});
   test.AddOutput<bool>("C", {4, 2}, {true, false, true, false, true, false, true, false});
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-             {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
+           {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(MathOpTest, Equal_bool) {
diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
index 1b08b25512c5..730631cbd6ba 100644
--- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc
@@ -790,5 +790,30 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
 }
 
+TEST(BatchNormTest, ForwardTrainingTestOpset15) {
+  OpTester test("BatchNormalization", 15);
+  float epsilon = 1e-05f;
+  float momentum = 0.1f;
+  int64_t training_mode = 1;
+  test.AddAttribute("epsilon", epsilon);
+  test.AddAttribute("momentum", momentum);
+  test.AddAttribute("training_mode", training_mode);
+  std::vector<int64_t> input_output_dims{2, 2, 2, 2};
+  std::vector<int64_t> channel_dims{2};
+  test.AddInput<float>("X", input_output_dims, {-0.2953f, 0.1180f, 1.0973f, -0.1931f, -0.1999f, -0.0237f, 1.5181f, 0.0076f, -1.0830f, -1.5433f, 0.4327f, -0.9813f, 0.7875f, -0.4080f, -2.3144f, 1.5493f});
+  test.AddInput<float>("scale", channel_dims, {1.0f, 1.0f});
+  test.AddInput<float>("B", channel_dims, {0.0f, 0.0f});
+  test.AddInput<float>("mean", channel_dims, {1.0f, 2.0f});
+  test.AddInput<float>("var", channel_dims, {1.0f, 2.0f});
+
+  test.AddOutput<float>("Y", input_output_dims, {0.0131f, 0.5210f, 1.7244f, 0.1387f, -0.2708f, -0.1191f, 1.2089f, -0.0922f, -0.9548f, -1.5203f, 0.9077f, -0.8298f, 0.5796f, -0.4501f, -2.0921f, 1.2358f});
+
+  test.AddOutput<float>("running_mean", channel_dims, {-0.1754f, 0.303106f});
+  test.AddOutput<float>("running_var", channel_dims, {0.696052f, 1.41316f});
+
+  // Same exclusions as the opset 14 test
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider});
+}
+
 } // namespace test
 } // namespace onnxruntime
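For the training-mode outputs above, the expected running statistics follow the opset 14/15 update running_stat = momentum * input_stat + (1 - momentum) * batch_stat, where the batch variance is the biased estimate (divide by N). A short sketch recomputing channel 0 from the test data, assuming that update rule (illustrative only, not ORT kernel code):

#include <cstdio>
#include <vector>

int main() {
  // Channel 0 elements of X across both batch entries (NCHW layout).
  std::vector<float> c0 = {-0.2953f, 0.1180f, 1.0973f, -0.1931f,
                           -1.0830f, -1.5433f, 0.4327f, -0.9813f};
  float mean = 0.0f, var = 0.0f;
  for (float v : c0) mean += v;
  mean /= c0.size();
  for (float v : c0) var += (v - mean) * (v - mean);
  var /= c0.size();  // biased variance (divide by N, not N-1)
  const float momentum = 0.1f;  // channel 0's input mean and var are both 1.0f
  std::printf("running_mean = %f\n", momentum * 1.0f + (1.0f - momentum) * mean);  // ~ -0.1754
  std::printf("running_var  = %f\n", momentum * 1.0f + (1.0f - momentum) * var);   // ~ 0.696052
  return 0;
}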
diff --git a/onnxruntime/test/providers/cpu/tensor/shape_op_test.cc b/onnxruntime/test/providers/cpu/tensor/shape_op_test.cc
index 91859d15222c..6128556923b8 100644
--- a/onnxruntime/test/providers/cpu/tensor/shape_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/shape_op_test.cc
@@ -4,26 +4,75 @@
 namespace onnxruntime {
 namespace test {
 
-template<typename T>
-void TestShape(const std::initializer_list<T>& data, const std::vector<int64_t>& shape)
-{
+template <typename T>
+void TestShape(const std::initializer_list<T>& data, const std::vector<int64_t>& shape) {
   OpTester test("Shape");
   test.AddInput<T>("data", shape, data);
   test.AddOutput<int64_t>("output", {static_cast<int64_t>(shape.size())}, shape);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});//TensorRT parser: unsupported data types
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  //TensorRT parser: unsupported data types
 }
 
-TEST(ShapeOpTest, ShapeTestBool) { TestShape<bool> ({true, true, false, false, true, false}, {2, 3}); }
-TEST(ShapeOpTest, ShapeTestFloat) { TestShape<float> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 6}); }
-TEST(ShapeOpTest, ShapeTestDouble) { TestShape<double> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {6, 2}); }
-TEST(ShapeOpTest, ShapeTestInt8) { TestShape<int8_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); }
-TEST(ShapeOpTest, ShapeTestInt16) { TestShape<int16_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); }
-TEST(ShapeOpTest, ShapeTestInt32) { TestShape<int32_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {4, 3}); }
-TEST(ShapeOpTest, ShapeTestInt64) { TestShape<int64_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
-TEST(ShapeOpTest, ShapeTestUint8) { TestShape<uint8_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); }
-TEST(ShapeOpTest, ShapeTestUint16) { TestShape<uint16_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
-TEST(ShapeOpTest, ShapeTestUint32) { TestShape<uint32_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); }
-TEST(ShapeOpTest, ShapeTestUint64) { TestShape<uint64_t> ({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
+TEST(ShapeOpTest, ShapeTestBool) { TestShape<bool>({true, true, false, false, true, false}, {2, 3}); }
+TEST(ShapeOpTest, ShapeTestFloat) { TestShape<float>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 6}); }
+TEST(ShapeOpTest, ShapeTestDouble) { TestShape<double>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {6, 2}); }
+TEST(ShapeOpTest, ShapeTestInt8) { TestShape<int8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); }
+TEST(ShapeOpTest, ShapeTestInt16) { TestShape<int16_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {3, 4}); }
+TEST(ShapeOpTest, ShapeTestInt32) { TestShape<int32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {4, 3}); }
+TEST(ShapeOpTest, ShapeTestInt64) { TestShape<int64_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
+TEST(ShapeOpTest, ShapeTestUint8) { TestShape<uint8_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); }
+TEST(ShapeOpTest, ShapeTestUint16) { TestShape<uint16_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
+TEST(ShapeOpTest, ShapeTestUint32) { TestShape<uint32_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {12, 1}); }
+TEST(ShapeOpTest, ShapeTestUint64) { TestShape<uint64_t>({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {1, 12}); }
 TEST(ShapeOpTest, ShapeTestString) { TestShape<std::string>({"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12"}, {1, 12}); }
-} // namespace test
-} // namespace onnxruntime
+
+TEST(ShapeOpTest, ShapeOpset15_Default) {
+  OpTester test("Shape", 15);
+  test.AddInput<int64_t>("data", {1, 2, 2}, {1, 2, 3, 4});
+  test.AddOutput<int64_t>("output", {3}, {1, 2, 2});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  //TensorRT parser: unsupported data types
+}
+
+TEST(ShapeOpTest, ShapeOpset15_StartOnly) {
+  OpTester test("Shape", 15);
+  test.AddAttribute<int64_t>("start", 1);
+  test.AddInput<int64_t>("data", {1, 2, 2}, {1, 2, 3, 4});
+  test.AddOutput<int64_t>("output", {2}, {2, 2});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  //TensorRT parser: unsupported data types
+}
+
+TEST(ShapeOpTest, ShapeOpset15_EndOnly) {
+  OpTester test("Shape", 15);
+  test.AddAttribute<int64_t>("end", 2);
+  test.AddInput<int64_t>("data", {1, 2, 2}, {1, 2, 3, 4});
+  test.AddOutput<int64_t>("output", {2}, {1, 2});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  //TensorRT parser: unsupported data types
+}
+
+TEST(ShapeOpTest, ShapeOpset15_StartAndEnd) {
+  OpTester test("Shape", 15);
+  test.AddAttribute<int64_t>("start", 1);
+  test.AddAttribute<int64_t>("end", 2);
+  test.AddInput<int64_t>("data", {1, 2, 2}, {1, 2, 3, 4});
+  test.AddOutput<int64_t>("output", {1}, {2});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  //TensorRT parser: unsupported data types
+}
+
+TEST(ShapeOpTest, ShapeOpset15_StartAndEndNegative) {
+  OpTester test("Shape", 15);
+  test.AddAttribute<int64_t>("start", -2);
+  test.AddAttribute<int64_t>("end", -1);
+  test.AddInput<int64_t>("data", {1, 2, 2}, {1, 2, 3, 4});
+  test.AddOutput<int64_t>("output", {1}, {2});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  //TensorRT parser: unsupported data types
+}
+
+TEST(ShapeOpTest, ShapeOpset15_StartAndEndProducingEmptySlice) {
+  OpTester test("Shape", 15);
+  test.AddAttribute<int64_t>("start", 2);
+  test.AddAttribute<int64_t>("end", 2);
+  test.AddInput<int64_t>("data", {1, 2, 2}, {1, 2, 3, 4});
+  test.AddOutput<int64_t>("output", {0}, {});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  //TensorRT parser: unsupported data types
+}
+
+} // namespace test
+} // namespace onnxruntime
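The new tests pin down the opset 15 start/end semantics: the output is the Python-style slice shape[start:end], where start defaults to 0, end defaults to the rank, negative values have the rank added, and both are then clamped to [0, rank], so start >= end yields an empty tensor. A sketch of that slicing logic under those assumptions (not ORT's actual kernel code):

#include <algorithm>
#include <cstdint>
#include <vector>

std::vector<int64_t> SliceShape(const std::vector<int64_t>& shape, int64_t start, int64_t end) {
  const int64_t rank = static_cast<int64_t>(shape.size());
  if (start < 0) start += rank;  // negative indices count from the back
  if (end < 0) end += rank;
  start = std::min(std::max(start, int64_t{0}), rank);  // clamp to [0, rank]
  end = std::min(std::max(end, int64_t{0}), rank);
  if (start >= end) return {};  // e.g. start = end = 2 above -> empty {0} output
  return {shape.begin() + start, shape.begin() + end};
}
// SliceShape({1, 2, 2}, 1, 3) == {2, 2}; SliceShape({1, 2, 2}, -2, -1) == {2}.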
test.AddAttribute("end", 2); + test.AddInput("data", {1, 2, 2}, {1, 2, 3, 4}); + test.AddOutput("output", {0}, {}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT parser: unsupported data types +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json index 00b4e414a281..1f2dd81d7d08 100644 --- a/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json +++ b/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json @@ -255,6 +255,14 @@ "BatchNormalization ai.onnx CPUExecutionProvider", 17832136363477464736 ], + [ + "BatchNormalization ai.onnx CPUExecutionProvider", + 3016597991190826984 + ], + [ + "BatchNormalization ai.onnx CPUExecutionProvider", + 9270095107043637928 + ], [ "BitShift ai.onnx CPUExecutionProvider", 4758677670685660688 @@ -1483,6 +1491,10 @@ "Pow ai.onnx CPUExecutionProvider", 12963226513247425672 ], + [ + "Pow ai.onnx CPUExecutionProvider", + 16138602580714332296 + ], [ "PRelu ai.onnx CPUExecutionProvider", 3282999003886175808 @@ -2159,6 +2171,10 @@ "Shape ai.onnx CPUExecutionProvider", 14989007508280400584 ], + [ + "Shape ai.onnx CPUExecutionProvider", + 9917761852037658112 + ], [ "Shrink ai.onnx CPUExecutionProvider", 4706529740707835200 diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index e1218394dddd..fab3f0ce42fb 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -67,7 +67,6 @@ "^test_add_uint8_cpu", "^test_div_uint8_cpu", // Following tests are for opset 15 ops and are not yet implemented in ORT - "^test_shape_*", "^test_optional_*", //GPU failures "^test_batchnorm_epsilon_training_mode_cuda", diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index e14208382fcd..7096f2fc3f86 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -38,7 +38,6 @@ #include "core/optimizer/relu_clip_fusion.h" #include "core/optimizer/reshape_fusion.h" #include "core/optimizer/rule_based_graph_transformer.h" -#include "core/optimizer/shape_to_initializer.h" #include "core/optimizer/skip_layer_norm_fusion.h" #include "core/optimizer/slice_elimination.h" #include "core/optimizer/unsqueeze_elimination.h" @@ -75,7 +74,7 @@ std::vector> GeneratePreTrainingTransformers( case TransformerLevel::Level1: { rule_transformer = std::make_unique(optimizer_utils::GenerateRuleBasedTransformerName(level), - compatible_eps); + compatible_eps); rule_transformer->Register(std::make_unique()); rule_transformer->Register(std::make_unique()); rule_transformer->Register(std::make_unique()); @@ -127,16 +126,16 @@ std::vector> GeneratePreTrainingTransformers( if (config.propagate_cast_ops_config.level >= 0) { std::unordered_set cuda_execution_provider = {onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider}; transformers.emplace_back(std::make_unique(config.propagate_cast_ops_config.strategy, - static_cast(config.propagate_cast_ops_config.level), - config.propagate_cast_ops_config.allow, - cuda_execution_provider)); + static_cast(config.propagate_cast_ops_config.level), + config.propagate_cast_ops_config.allow, + 
diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
index e14208382fcd..7096f2fc3f86 100644
--- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
+++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -38,7 +38,6 @@
 #include "core/optimizer/relu_clip_fusion.h"
 #include "core/optimizer/reshape_fusion.h"
 #include "core/optimizer/rule_based_graph_transformer.h"
-#include "core/optimizer/shape_to_initializer.h"
 #include "core/optimizer/skip_layer_norm_fusion.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/unsqueeze_elimination.h"
@@ -75,7 +74,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
     case TransformerLevel::Level1: {
       rule_transformer = std::make_unique<RuleBasedGraphTransformer>(optimizer_utils::GenerateRuleBasedTransformerName(level),
-                                                                       compatible_eps);
+                                                                     compatible_eps);
       rule_transformer->Register(std::make_unique());
       rule_transformer->Register(std::make_unique());
       rule_transformer->Register(std::make_unique());
@@ -127,16 +126,16 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
       if (config.propagate_cast_ops_config.level >= 0) {
         std::unordered_set<std::string> cuda_execution_provider = {onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider};
         transformers.emplace_back(std::make_unique<PropagateCastOps>(config.propagate_cast_ops_config.strategy,
-                                                                       static_cast<size_t>(config.propagate_cast_ops_config.level),
-                                                                       config.propagate_cast_ops_config.allow,
-                                                                       cuda_execution_provider));
+                                                                     static_cast<size_t>(config.propagate_cast_ops_config.level),
+                                                                     config.propagate_cast_ops_config.allow,
+                                                                     cuda_execution_provider));
       }
     } break;
 
     case TransformerLevel::Level2: {
       rule_transformer = std::make_unique<RuleBasedGraphTransformer>(optimizer_utils::GenerateRuleBasedTransformerName(level),
-                                                                       compatible_eps);
+                                                                     compatible_eps);
       rule_transformer->Register(std::make_unique());
     } break;