diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc
index be59b1ae07020..cadeab4cbd4cc 100644
--- a/onnxruntime/core/providers/openvino/backend_manager.cc
+++ b/onnxruntime/core/providers/openvino/backend_manager.cc
@@ -375,6 +375,20 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) {
   return false;
 }
 
+// Returns true if any node output in the graph has bfloat16 element type.
+// Graph inputs and initializers are not inspected.
+static bool IsModelBF16(const onnxruntime::GraphViewer& graph_viewer) {
+  const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
+  for (std::size_t i = 0; i < node_indices.size(); i++) {
+    gsl::not_null node(graph_viewer.GetNode(node_indices[i]));
+    for (auto& output : node->OutputDefs()) {
+      if (output->ToProto().type().tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
+        return true;
+    }
+  }
+  return false;
+}
+
 static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name,
                                 [[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto,
                                 [[maybe_unused]] const onnxruntime::Node& fused_node) {
@@ -456,6 +470,18 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
     DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node);
     ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
     return model_proto;
+  } else if (IsModelBF16(subgraph)) {
+    LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP bfloat16->float16 optimization pass is enabled";
+    std::unique_ptr model;
+    Status status = bfloat16_fix::Transform(subgraph, logger, model);
+    // Check the pass result before touching `model`, so a failed transform is reported
+    // through ORT_ENFORCE instead of dereferencing a model that may not have been created.
+    ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
+    auto model_proto = model->ToProto();
+    model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
+    print_model_proto_duration();
+    DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node);
+    return model_proto;
   } else {
     LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled";
     auto model = subgraph.CreateModel(logger);
diff --git 
a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
index 17e69ad080b90..f991e85ebe518 100644
--- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
+++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
@@ -555,8 +555,11 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) {
     return false;
   }
 
+  auto dtype = type_proto->tensor_type().elem_type();
+  // Enable bfloat16 -> float16 on-the-fly conversion
+  if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16)
+    return true;
   if (is_initializer) {
-    auto dtype = type_proto->tensor_type().elem_type();
     for (auto const& var : supported_types_initializer_) {
       if ((var.first <= version_id_) &&
           (var.second == dtype)) {
@@ -571,8 +574,6 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) {
 #endif
     return false;
   } else {
-    auto dtype = type_proto->tensor_type().elem_type();
-
     if (device_id_.find("HETERO") != std::string::npos ||
         device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) {
       for (auto const& var : supported_types_npu_) {
diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp
index d159930d52845..f1ce230387565 100644
--- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp
+++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp
@@ -3,6 +3,7 @@
 
 #include "qdq_scales_fix.h"
 #include "core/providers/openvino/ov_protobuf_utils.h"
+#include "core/framework/float16.h"
 
 #include
 #include
@@ -940,5 +941,62 @@ Status Transform(const GraphViewer& src_graph_viewer,
   return status;
 }
 }  // namespace qdq_scales_fix
+
+namespace bfloat16_fix {
+// Rewrites bfloat16 to float16 in place on the generated graph: Cast "to" attributes,
+// node output element types, and initializers that keep their payload in raw_data.
+void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) {
+  for (auto& const_node : gen_graph.original_graph.Nodes()) {
+    auto node = const_cast(const_node);
+    if (node->OpType() == "Cast") {
+      for (auto& [name, const_attribute] : node->GetAttributes()) {
+        auto& attribute = const_cast(const_attribute);
+        if (name == "to" && attribute.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT)
+          if (attribute.i() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
+            attribute.set_i(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
+      }
+    }
+    for (auto& output : node->OutputDefs()) {
+      auto& output_proto = const_cast(output->ToProto().type());
+      if (output_proto.mutable_tensor_type()->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
+        output_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
+    }
+  }
+
+  const auto& init_set = gen_graph.original_graph.GetAllInitializedTensors();
+  for (auto& [key, const_tensor_proto] : init_set) {
+    auto tensor_proto = const_cast(const_tensor_proto);
+    auto dt = tensor_proto->data_type();
+    if (dt == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
+      // Only raw_data storage is converted; a bfloat16 tensor without raw_data is left untouched.
+      auto raw_bytes = tensor_proto->has_raw_data() ? tensor_proto->mutable_raw_data() : nullptr;
+      if (raw_bytes) {
+        tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
+        // Both formats are 16 bits wide, so take the element count from the buffer itself.
+        // Using the product of dims would write out of bounds if dims and raw_data disagree.
+        auto raw_data = reinterpret_cast<std::uint16_t*>(raw_bytes->data());
+        const std::size_t size = raw_bytes->size() / sizeof(std::uint16_t);
+        for (std::size_t i = 0; i < size; ++i) {
+          raw_data[i] = onnxruntime::MLFloat16(onnxruntime::BFloat16::FromBits(raw_data[i])).val;
+        }
+      }
+    }
+  }
+}
+
+// Copies the source graph into `model`, then converts bfloat16 to float16 in that copy.
+Status Transform(const GraphViewer& src_graph_viewer,
+                 const logging::Logger& logger,
+                 /*out*/ std::unique_ptr& model) {
+  auto status = qdq_scales_fix::copy_model(src_graph_viewer, logger, model);
+  // Do not dereference `model` when the copy failed; hand the failure back to the caller.
+  if (!status.IsOK())
+    return status;
+  auto g = qdq_scales_fix::generate_graph_from_onnx(model->MainGraph());
+
+  replace_bf16_with_fp16(g);
+  return status;
+}
+}  // namespace bfloat16_fix
 }  // namespace openvino_ep
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h
index c54c531e1bd40..2182850d96c43 100644
--- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h
+++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h
@@ -15,5 +15,11 @@ Status Transform(const GraphViewer& src_graph,
                  const logging::Logger& logger,
                  /*out*/ std::unique_ptr& model);
 }
+namespace bfloat16_fix {
+// Copies `src_graph` into `model` and converts bfloat16 to float16 in the copy.
+Status Transform(const GraphViewer& src_graph,
+                 const logging::Logger& logger,
+                 /*out*/ std::unique_ptr& model);
+}
 }  // namespace openvino_ep
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc b/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc
new file mode 100644
index 0000000000000..fc90563a61bb1
--- /dev/null
+++ b/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc
@@ -0,0 +1,116 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include
+#include
+
+#include "core/session/onnxruntime_cxx_api.h"
+#include "core/framework/float16.h"
+
+#include "test/util/include/test/test_environment.h"
+#include "test/optimizer/qdq_test_utils.h"
+
+#include "gtest/gtest.h"
+#include "gmock/gmock.h"
+
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::logging;
+
+extern std::unique_ptr ort_env;
+
+class OVEP_BF16_Tests : public ::testing::TestWithParam {};
+
+namespace detail {
+// Builds Input -> Cast(to bfloat16) -> MatMul -> MatMul -> Cast(to float) -> Output; MatMul weights are initializers.
+auto ConstructModel() {
+  using namespace onnxruntime;
+  using namespace test;
+  std::unordered_map domain_to_version;
+  domain_to_version[kOnnxDomain] = 19;
+  Model model("Bfloat16Tester", true, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+              domain_to_version, {}, DefaultLoggingManager().DefaultLogger());
+
+  Graph& graph = model.MainGraph();
+  ModelTestBuilder builder(graph);
+  auto dim = 4;
+  std::vector input_data(dim, 1.0f);
+  auto* input = builder.MakeInput({dim}, input_data);
+  builder.graph_.SetInputs({input});
+
+  auto* cast_to_bf16 = builder.MakeIntermediate();
+  Node& cast_node = builder.AddNode("Cast", {input}, {cast_to_bf16}, "");
+  cast_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16));
+
+  std::vector weight_data(dim * dim);
+  for (std::size_t i = 0; i < weight_data.size(); ++i)
+    weight_data[i] = onnxruntime::BFloat16(static_cast(i % dim) / dim);
+  auto* weights = builder.MakeInitializer({dim, dim}, weight_data);
+
+  auto* matmul_out = builder.MakeIntermediate();
+  builder.AddNode("MatMul", {cast_to_bf16, weights}, {matmul_out});
+
+  std::vector weight_data_2(dim * dim);
+  for (std::size_t i = 0; i < weight_data_2.size(); ++i)
+    weight_data_2[i] = onnxruntime::BFloat16(static_cast(i % dim) / dim);
+  auto* weights_2 = builder.MakeInitializer({dim, dim}, weight_data_2);
+
+  auto* matmul_out_2 = builder.MakeIntermediate();
+  builder.AddNode("MatMul", {matmul_out, weights_2}, {matmul_out_2});
+
+  auto* output = builder.MakeOutput();
+  Node& cast2_node = builder.AddNode("Cast", {matmul_out_2}, {output});
+  cast2_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
+
+  builder.SetGraphOutputs();
+  auto st = model.MainGraph().Resolve();
+  if (st != Status::OK())
+    throw std::runtime_error(st.ErrorMessage());
+  return model;
+}
+
+// Reports whether appending the OpenVINO EP for `device` succeeds; the answer is cached per device name.
+auto ProbeDevice(const std::string& device) {
+  static std::map is_present;
+  if (is_present.find(device) == is_present.end()) {
+    Ort::SessionOptions sessionOptions;
+    std::unordered_map ov_options;
+    ov_options["device_type"] = device;
+    try {
+      sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options);
+      is_present[device] = true;
+    } catch (...) {
+      is_present[device] = false;
+    }
+  }
+  return is_present[device];
+}
+}  // namespace detail
+
+namespace onnxruntime {
+namespace test {
+// Creating a session for the bfloat16 model must succeed on every device that can be probed.
+TEST_P(OVEP_BF16_Tests, TestModelConversion) {
+  Ort::SessionOptions sessionOptions;
+  std::unordered_map ov_options;
+  const auto& device = GetParam();
+  if (!::detail::ProbeDevice(device))
+    GTEST_SKIP() << device + " is not available on this machine";
+  ov_options["device_type"] = device;
+  auto model = ::detail::ConstructModel();
+  sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options);
+
+  std::string model_data;
+  model.ToProto().SerializeToString(&model_data);
+  auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
+  try {
+    Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), sessionOptions);
+  } catch (const std::exception& ex) {
+    FAIL() << "Session creation failed on " << device << ": " << ex.what();
+  }
+}
+INSTANTIATE_TEST_SUITE_P(OVEP_Tests,
+                         OVEP_BF16_Tests,
+                         ::testing::Values("CPU", "GPU", "NPU"));
+}  // namespace test
+}  // namespace onnxruntime