22 changes: 22 additions & 0 deletions onnxruntime/core/providers/openvino/backend_manager.cc
@@ -375,6 +375,18 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) {
return false;
}

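// Returns true if any node in the graph produces a bfloat16-typed output.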
static bool IsModelBF16(const onnxruntime::GraphViewer& graph_viewer) {
const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();
for (std::size_t i = 0; i < node_indices.size(); i++) {
gsl::not_null<const onnxruntime::Node*> node(graph_viewer.GetNode(node_indices[i]));
for (auto& output : node->OutputDefs()) {
if (output->ToProto().type().tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
return true;
}
}
return false;
}

static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name,
[[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto,
[[maybe_unused]] const onnxruntime::Node& fused_node) {
@@ -456,6 +468,16 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
return model_proto;
} else if (IsModelBF16(subgraph)) {
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP bfloat16->float16 optimization pass is enabled";
std::unique_ptr<onnxruntime::Model> model;
Status status = bfloat16_fix::Transform(subgraph, logger, model);
// Validate the transform result before dereferencing the output model.
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
auto model_proto = model->ToProto();
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
print_model_proto_duration();
DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node);
return model_proto;
} else {
LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled";
auto model = subgraph.CreateModel(logger);
7 changes: 4 additions & 3 deletions onnxruntime/core/providers/openvino/ov_versions/data_ops.cc
@@ -555,8 +555,11 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) {
return false;
}

auto dtype = type_proto->tensor_type().elem_type();
// Enable bfloat16 -> float16 on-the-fly conversion
if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16)
return true;
if (is_initializer) {
auto dtype = type_proto->tensor_type().elem_type();
for (auto const& var : supported_types_initializer_) {
if ((var.first <= version_id_) &&
(var.second == dtype)) {
@@ -571,8 +574,6 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) {
#endif
return false;
} else {
auto dtype = type_proto->tensor_type().elem_type();

if (device_id_.find("HETERO") != std::string::npos ||
device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) {
for (auto const& var : supported_types_npu_) {
onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp
@@ -3,6 +3,7 @@

#include "qdq_scales_fix.h"
#include "core/providers/openvino/ov_protobuf_utils.h"
#include "core/framework/float16.h"

#include <fstream>
#include <list>
#include <memory>  // std::unique_ptr (cpplint: include what you use)
@@ -940,5 +941,54 @@
return status;
}
} // namespace qdq_scales_fix

namespace bfloat16_fix {
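// Rewrites all bfloat16 uses in the graph to float16 in place:
// Cast "to" attributes, node output types, and initializer payloads.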
void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) {
for (auto& const_node : gen_graph.original_graph.Nodes()) {
auto* node = const_cast<onnxruntime::Node*>(&const_node);
if (node->OpType() == "Cast") {
for (auto& [name, const_attribute] : node->GetAttributes()) {
auto& attribute = const_cast<ONNX_NAMESPACE::AttributeProto&>(const_attribute);
if (name == "to" && attribute.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT)
if (attribute.i() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
attribute.set_i(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
}
}
for (auto& output : node->OutputDefs()) {
auto& output_proto = const_cast<ONNX_NAMESPACE::TypeProto&>(output->ToProto().type());
if (output_proto.mutable_tensor_type()->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
output_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
}
}

const auto& init_set = gen_graph.original_graph.GetAllInitializedTensors();
for (auto& [key, const_tensor_proto] : init_set) {
auto tensor_proto = const_cast<ONNX_NAMESPACE::TensorProto*>(const_tensor_proto);
auto dt = tensor_proto->data_type();
if (dt == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
auto raw_data = tensor_proto->has_raw_data() ? reinterpret_cast<std::uint16_t*>(tensor_proto->mutable_raw_data()->data()) : nullptr;
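// bf16 and fp16 are both 16 bits wide, so the payload can be re-encoded
// in place: widen each bf16 element to float, then narrow it to IEEE fp16.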
if (raw_data) {
tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
std::int64_t size = 1;
for (int i = 0; i < tensor_proto->dims_size(); ++i)
size *= tensor_proto->dims()[i];
for (std::int64_t i = 0; i < size; ++i) {
raw_data[i] = onnxruntime::MLFloat16(onnxruntime::BFloat16::FromBits(raw_data[i])).val;
}
}
}
}
}
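For reference, a minimal standalone sketch (not part of this diff) of the per-element re-encoding performed by the initializer loop above; it relies only on the float16 helpers the patch already uses (BFloat16::FromBits, the MLFloat16 float constructor, and MLFloat16::val), and the literal bit patterns in the comments are illustrative:

// Hypothetical illustration: re-encode one bfloat16 element as float16,
// the way replace_bf16_with_fp16 does in place.
#include <cstdint>
#include <iostream>
#include "core/framework/float16.h"

int main() {
  std::uint16_t bits = 0x3FC0;  // bfloat16 bit pattern for 1.5f
  float widened = onnxruntime::BFloat16::FromBits(bits).ToFloat();
  bits = onnxruntime::MLFloat16(widened).val;  // IEEE fp16 bits, also 16 wide
  std::cout << widened << " -> 0x" << std::hex << bits << '\n';  // 1.5 -> 0x3e00
  return 0;
}

Because both encodings are 16 bits wide, the existing raw_data buffer can be rewritten without resizing, and fp16's extra mantissa bits make the conversion exact for typical weights; values beyond fp16's finite range (about 65504) would overflow, which the pass implicitly assumes does not occur.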

Status Transform(const GraphViewer& src_graph_viewer,
const logging::Logger& logger,
/*out*/ std::unique_ptr<onnxruntime::Model>& model) {

auto status = qdq_scales_fix::copy_model(src_graph_viewer, logger, model);
// Bail out before touching the copied graph if the copy itself failed.
if (!status.IsOK())
return status;

auto g = qdq_scales_fix::generate_graph_from_onnx(model->MainGraph());
replace_bf16_with_fp16(g);
return status;
}
} // namespace bfloat16_fix
} // namespace openvino_ep
} // namespace onnxruntime
onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h
@@ -15,5 +15,10 @@
const logging::Logger& logger,
/*out*/ std::unique_ptr<onnxruntime::Model>& model);
}
namespace bfloat16_fix {
Status Transform(const GraphViewer& src_graph,
const logging::Logger& logger,
/*out*/ std::unique_ptr<onnxruntime::Model>& model);

}
} // namespace openvino_ep
} // namespace onnxruntime
116 changes: 116 additions & 0 deletions onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc
@@ -0,0 +1,116 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <filesystem>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include "core/session/onnxruntime_cxx_api.h"
#include "core/framework/float16.h"

#include "test/util/include/test/test_environment.h"
#include "test/optimizer/qdq_test_utils.h"

#include "gtest/gtest.h"
#include "gmock/gmock.h"

using namespace ONNX_NAMESPACE;
using namespace onnxruntime::logging;

extern std::unique_ptr<Ort::Env> ort_env;

class OVEP_BF16_Tests : public ::testing::TestWithParam<std::string> {};

namespace detail {
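// Builds a minimal graph that exercises the bf16 pass:
// float input -> Cast(bf16) -> MatMul -> MatMul (bf16 initializers) -> Cast(float) output.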
auto ConstructModel() {
using namespace onnxruntime;
using namespace test;

std::unordered_map<std::string, int> domain_to_version;
domain_to_version[kOnnxDomain] = 19;
Model model("Bfloat16Tester", true, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, DefaultLoggingManager().DefaultLogger());

Graph& graph = model.MainGraph();
ModelTestBuilder builder(graph);
auto dim = 4;
std::vector<float> input_data(dim, 1.0f);
auto* input = builder.MakeInput<float>({dim}, input_data);
builder.graph_.SetInputs({input});

auto* cast_to_bf16 = builder.MakeIntermediate();
Node& cast_node = builder.AddNode("Cast", {input}, {cast_to_bf16}, "");
cast_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16));

std::vector<onnxruntime::BFloat16> weight_data(dim * dim);
for (std::size_t i = 0; i < weight_data.size(); ++i)
weight_data[i] = onnxruntime::BFloat16(static_cast<float>(i % dim) / dim);
auto* weights = builder.MakeInitializer<onnxruntime::BFloat16>({dim, dim}, weight_data);

auto* matmul_out = builder.MakeIntermediate();
builder.AddNode("MatMul", {cast_to_bf16, weights}, {matmul_out});

std::vector<onnxruntime::BFloat16> weight_data_2(dim * dim);
for (std::size_t i = 0; i < weight_data_2.size(); ++i)
weight_data_2[i] = onnxruntime::BFloat16(static_cast<float>(i % dim) / dim);
auto* weights_2 = builder.MakeInitializer<onnxruntime::BFloat16>({dim, dim}, weight_data_2);

auto* matmul_out_2 = builder.MakeIntermediate();
builder.AddNode("MatMul", {matmul_out, weights_2}, {matmul_out_2});

auto* output = builder.MakeOutput();
Node& cast2_node = builder.AddNode("Cast", {matmul_out_2}, {output});
cast2_node.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));

builder.SetGraphOutputs();
auto st = model.MainGraph().Resolve();
if (!st.IsOK())
throw std::runtime_error(st.ErrorMessage());
return model;
}

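// Returns whether an OpenVINO device of the given type can be instantiated
// on this machine; the result is cached across calls.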
auto ProbeDevice(const std::string& device) {
static std::map<std::string, bool> is_present;
if (is_present.find(device) == is_present.end()) {
Ort::SessionOptions sessionOptions;
std::unordered_map<std::string, std::string> ov_options;
ov_options["device_type"] = device;
try {
sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options);
is_present[device] = true;
} catch (...) {
is_present[device] = false;
}
}
return is_present[device];
}
} // namespace detail

namespace onnxruntime {
namespace test {

TEST_P(OVEP_BF16_Tests, TestModelConversion) {
Ort::SessionOptions sessionOptions;
std::unordered_map<std::string, std::string> ov_options;
const auto& device = GetParam();
if (!::detail::ProbeDevice(device))
GTEST_SKIP() << device + " is not available on this machine";

ov_options["device_type"] = device;
auto model = ::detail::ConstructModel();
sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options);

std::string model_data;
model.ToProto().SerializeToString(&model_data);
auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
try {
Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), sessionOptions);
} catch (...) {
FAIL();
}
}
INSTANTIATE_TEST_SUITE_P(OVEP_Tests,
OVEP_BF16_Tests,
::testing::Values("CPU", "GPU", "NPU"));
} // namespace test
} // namespace onnxruntime