diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 4a20847c0890c..abb5b31b76e44 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -37,19 +37,35 @@ ov::CompiledModel BackendManager::GetOVCompiledModel() { return ov::CompiledModel(); } +static bool ShouldExportEpContext(const SessionContext& session_context, const SubGraphContext& subgraph_context) { + return session_context.so_context_enable && (subgraph_context.is_ep_ctx_ovir_encapsulated || !subgraph_context.is_ep_ctx_graph); +} + BackendManager::BackendManager(SessionContext& session_context, - SharedContext& shared_context, + SharedContextManager& shared_context_manager, const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger, EPCtxHandler& ep_ctx_handle) : ep_ctx_handle_(ep_ctx_handle), session_context_(session_context), - shared_context_{shared_context} { + shared_context_manager_(shared_context_manager) { subgraph_context_.is_ep_ctx_graph = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(subgraph); // If the graph contains a OVIR wrapped node, we check if it has matching xml file name attribute subgraph_context_.is_ep_ctx_ovir_encapsulated = ep_ctx_handle_.CheckEPCacheContextAttribute(subgraph, session_context_.onnx_model_path_name.filename().replace_extension("xml").string()); + if (subgraph_context_.is_ep_ctx_graph && !subgraph_context_.is_ep_ctx_ovir_encapsulated) { + shared_context_ = ep_ctx_handle.GetSharedContextForEpContextSubgraph(subgraph, + session_context_.GetModelPath()); + } else if (session_context_.so_context_enable && session_context_.so_share_ep_contexts) { + shared_context_ = shared_context_manager_.GetOrCreateActiveSharedContext(session_context_.GetOutputBinPath()); + } else { + // Creating a shared context to satisfy backend. It won't be used for weight sharing. 
+ // Don't make it the active shared context since we don't actually want to share it. + shared_context_ = shared_context_manager_.GetOrCreateSharedContext(session_context_.GetOutputBinPath()); + } + ORT_ENFORCE(shared_context_, "Could not create a shared context."); + subgraph_context_.model_precision = [&](const GraphViewer& graph_viewer) { // return empty if graph has no inputs or if types are not one of FP32/FP16 // else assume the type of the first input @@ -107,23 +123,6 @@ BackendManager::BackendManager(SessionContext& session_context, } std::string device_type = session_context_.device_type; - auto& sw = shared_context_.shared_weights; - if (session_context_.so_share_ep_contexts && !sw.metadata.empty()) { - std::filesystem::path weight_filename = session_context_.onnx_model_path_name.parent_path(); - if (sw.external_weight_filename.empty()) { - // Reasonable assumption that all metadata entries have the same external file location - sw.external_weight_filename = sw.metadata.begin()->second.location; - } - weight_filename /= sw.external_weight_filename; - std::ifstream weight_file(weight_filename); - - ORT_ENFORCE(weight_file, "Initializer file not found: ", weight_filename.string()); - if (!sw.mapped_weights) { - sw.mapped_weights = std::make_unique(weight_filename); - } - backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights); - } - if (subgraph_context_.has_dynamic_input_shape) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Model has symbolic input dims"; if ((!session_context_.disable_dynamic_shapes && @@ -138,7 +137,7 @@ BackendManager::BackendManager(SessionContext& session_context, concrete_backend_ = BackendFactory::MakeBackend(model_proto, session_context_, subgraph_context_, - shared_context_, + *shared_context_, model_stream); } catch (std::string const& msg) { ORT_THROW(msg); @@ -162,7 +161,7 @@ BackendManager::BackendManager(SessionContext& session_context, concrete_backend_ = 
BackendFactory::MakeBackend(model_proto, session_context_, subgraph_context_, - shared_context_, + *shared_context_, model_stream); } catch (const OnnxRuntimeException& ex) { std::string exception_str = ex.what(); @@ -193,15 +192,15 @@ BackendManager::BackendManager(SessionContext& session_context, } } } - if (session_context_.so_context_enable && - (subgraph_context_.is_ep_ctx_ovir_encapsulated || !subgraph_context_.is_ep_ctx_graph)) { + + if (ShouldExportEpContext(session_context_, subgraph_context_)) { if (concrete_backend_) { - auto status = onnxruntime::openvino_ep::BackendManager::ExportCompiledBlobAsEPCtxNode(subgraph); - if (!status.IsOK()) { - ORT_THROW(status); - } + shared_context_->AddNativeBlob(subgraph_context_.subgraph_name, concrete_backend_->GetOVCompiledModel()); } else { - ORT_THROW("[OpenVINO-EP] Cannot export compiled blob as EPCtx Node: Backend not initialized."); + ORT_THROW( + "Exporting dynamically compiled models at runtime is not supported. " + "Cannot export blobs of dynamic models that request static shape inference. " + "To export this model, set disable_dynamic_shapes to False"); } } } @@ -210,13 +209,9 @@ BackendManager::BackendManager(SessionContext& session_context, // precompiled blob is set. If that's the case: // By default, create model in embed mode where the blob stream is exported as data within // the EPContext node. -Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& graph_body_viewer) { - if (session_context_.disable_dynamic_shapes && subgraph_context_.has_dynamic_input_shape) { - std::string exception_str = - "Exporting dynamically compiled models at runtime is not supported. " - "Cannot export blobs of dynamic models that request static shape inference. 
" - "To export this model, set disable_dynamic_shapes to False"; - ORT_THROW(exception_str); +void BackendManager::TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& graph_body_viewer, bool include_embed_data) { + if (!ShouldExportEpContext(session_context_, subgraph_context_) || !concrete_backend_) { + return; } // If embed_mode, then pass on the serialized blob @@ -224,11 +219,10 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie std::string model_blob_str; auto compiled_model = concrete_backend_->GetOVCompiledModel(); if (session_context_.so_context_embed_mode) { // Internal blob - std::ostringstream model_blob_stream; - compiled_model.export_model(model_blob_stream); - model_blob_str = std::move(model_blob_stream).str(); - if (model_blob_str.empty()) { - ORT_THROW("Model blob stream is empty after exporting the compiled model."); + if (include_embed_data) { + std::stringstream ss; + shared_context_->Serialize(ss); + model_blob_str = std::move(ss).str(); } } else { // External blob // Build name by combining EpCtx model name (if available) and subgraph name. 
Model @@ -238,30 +232,17 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie name = graph_body_viewer.ModelPath().stem().string(); } ORT_ENFORCE(!name.empty()); - name += "_" + subgraph_context_.subgraph_name; - std::filesystem::path blob_filename = session_context_.so_context_file_path; - if (blob_filename.empty()) { - blob_filename = session_context_.onnx_model_path_name; - } - blob_filename = blob_filename.parent_path() / (name + ".blob"); - std::ofstream blob_file(blob_filename, - std::ios::out | std::ios::trunc | std::ios::binary); - if (!blob_file) { - std::ostringstream err_msg; - err_msg << "Unable to open file for epctx model dump: " << blob_filename; - ORT_THROW(err_msg.str()); - } - compiled_model.export_model(blob_file); - model_blob_str = blob_filename.filename().string(); + model_blob_str = shared_context_->GetBinPath().filename().string(); } - ORT_RETURN_IF_ERROR(ep_ctx_handle_.AddOVEPCtxNodeToGraph(graph_body_viewer, - subgraph_context_.subgraph_name, - session_context_.so_context_embed_mode, - std::move(model_blob_str))); - - return Status::OK(); + auto status = ep_ctx_handle_.AddOVEPCtxNodeToGraph(graph_body_viewer, + subgraph_context_.subgraph_name, + session_context_.so_context_embed_mode, + std::move(model_blob_str)); + if (!status.IsOK()) { + ORT_THROW("[OpenVINO-EP] Failed to add OVEP EPContext node to the graph: " + status.ErrorMessage()); + } } bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const { @@ -568,7 +549,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, if ((session_context_.device_type.find("NPU") != std::string::npos) && (enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) { std::unique_ptr model; - Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights); + Status status = 
CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, *shared_context_); auto model_proto = model->ToProto(); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); print_model_proto_duration(); @@ -835,7 +816,7 @@ void BackendManager::Compute(OrtKernelContext* context) { dynamic_backend = BackendFactory::MakeBackend(modelproto_with_concrete_shapes, session_context_, subgraph_context_, - shared_context_, + *shared_context_, model_stream); } catch (const OnnxRuntimeException& ex) { // Build option disables fallback to CPU on compilation failures with NPU. @@ -855,7 +836,7 @@ void BackendManager::Compute(OrtKernelContext* context) { dynamic_backend = BackendFactory::MakeBackend(modelproto_with_concrete_shapes, session_context_, subgraph_context_, - shared_context_, + *shared_context_, model_stream); } catch (std::string const& msg) { ORT_THROW(msg); diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index f091f95fe1c16..64dadb6c2151b 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -20,7 +20,7 @@ namespace openvino_ep { class BackendManager { public: BackendManager(SessionContext& session_context, - SharedContext& shared_context, + SharedContextManager& shared_context_manager, const onnxruntime::Node& fused_node, const onnxruntime::GraphViewer& subgraph, const logging::Logger& logger, @@ -28,7 +28,7 @@ class BackendManager { void Compute(OrtKernelContext* context); void ShutdownBackendManager(); SessionContext& GetSessionContext(); - Status ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph); + void TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, bool include_embed_data); ov::CompiledModel GetOVCompiledModel(); void RewindKVCache(size_t index); @@ -59,7 +59,8 @@ class BackendManager { 
SubGraphContext subgraph_context_; EPCtxHandler& ep_ctx_handle_; SessionContext& session_context_; - SharedContext& shared_context_; + SharedContextManager& shared_context_manager_; + std::shared_ptr shared_context_; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 7201c47a805e3..45e518d16686e 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -20,130 +20,6 @@ using Exception = ov::Exception; namespace onnxruntime { namespace openvino_ep { -SharedContext::SharedWeights::WeightsFile::WeightsFile(std::filesystem::path filename) : file_(filename, std::ios::in | std::ios::binary), file_path_(filename) { - try { - file_.exceptions(std::ifstream::failbit | std::ifstream::badbit); - weights_size_ = std::filesystem::file_size(filename); - } catch (const std::exception& e) { - ORT_THROW("Error: Failed to open weight file at ", filename.string(), " ", e.what()); - } -} - -void SharedContext::SharedWeights::WeightsFile::load_weights(size_t file_offset, void* data, size_t size) { - ORT_ENFORCE(file_offset < weights_size_ && size <= weights_size_ && (file_offset <= weights_size_ - size), "Error: File offset is out of bounds."); - file_.seekg(file_offset); - file_.read(reinterpret_cast(data), size); -} - -void* SharedContext::SharedWeights::WeightsFile::TryGetOrCreateDeviceMapping(std::optional& remote_context) { - std::string dev_name{}; - if (remote_context) { - dev_name = remote_context->get_device_name(); - } - - auto [it, inserted] = imported_device_tensors_.emplace(dev_name, MappingContainer{}); - if (inserted) { - if (dev_name == "NPU") { -#if OPENVINO_VERSION_AT_LEAST(2025, 3) - // try to import the memory mapped file to remote tensor - ORT_ENFORCE(remote_context, "Error: Remote context is required for NPU device."); - auto npu_context = remote_context->as(); - auto&& l0_tensor = 
npu_context.create_tensor(ov::element::Type_t::u8, {weights_size_}, ov::intel_npu::FileDescriptor(file_path_)); - it->second = MappingContainer{.ptr_ = l0_tensor.get(), .tensor_ = l0_tensor}; -#endif - } else if (dev_name.empty()) { - // CPU/virtual device case, create a CPU tensor memory mapped from file - auto&& mmaped_tensor = ov::read_tensor_data(file_path_); - it->second = MappingContainer{.ptr_ = mmaped_tensor.data(), .tensor_ = mmaped_tensor}; - } - } - - return it->second.ptr_; -} - -std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeights::Metadata::Map& metadata) { - try { - stream << metadata.size(); - - // Write each key-value pair - // Put elements in separate lines to facilitate reading - for (const auto& [key, value] : metadata) { - stream << std::endl - << key.name; - stream << std::endl - << value.location; - stream << std::endl - << value.data_offset; - stream << std::endl - << value.size; - stream << std::endl - << value.dimensions.size(); - for (const auto& dim : value.dimensions) { - stream << std::endl - << dim; - } - stream << std::endl - << value.element_type; - } - } catch (const Exception& e) { - ORT_THROW("Error: Failed to write map data.", e.what()); - } catch (...) 
{ - ORT_THROW("Error: Failed to write map data."); - } - - ORT_ENFORCE(stream.good(), "Error: Failed to write map data."); - return stream; -} - -std::istream& operator>>(std::istream& stream, SharedContext::SharedWeights::Metadata::Map& metadata) { - size_t map_size{0}; - try { - stream >> map_size; - - while (!stream.eof()) { - SharedContext::SharedWeights::Metadata::Key key; - SharedContext::SharedWeights::Metadata::Value value; - stream >> key.name; - stream >> value.location; - stream >> value.data_offset; - stream >> value.size; - size_t num_dimensions; - stream >> num_dimensions; - - if (stream.fail()) { - ORT_THROW("Error: Failed to read num_dimensions from stream."); - } - - constexpr size_t MAX_SAFE_DIMENSIONS = 1024; - - size_t safe_num_dimensions = num_dimensions; - - if (num_dimensions == 0 || safe_num_dimensions > MAX_SAFE_DIMENSIONS) { - ORT_THROW("Invalid number of dimensions provided."); - } - try { - value.dimensions.resize(safe_num_dimensions); - } catch (const std::bad_alloc&) { - ORT_THROW("Error: Memory allocation failed while resizing dimensions."); - } - - for (auto& dim : value.dimensions) { - stream >> dim; - } - stream >> value.element_type; - metadata.emplace(key, value); - } - } catch (const Exception& e) { - ORT_THROW("Error: Failed to read map data.", e.what()); - } catch (...) 
{ - ORT_THROW("Error: Failed to read map data."); - } - - ORT_ENFORCE(metadata.size() == map_size, "Error: Inconsistent map data."); - - return stream; -} - namespace backend_utils { bool IsDebugEnabled() { @@ -390,96 +266,10 @@ void printPerformanceCounts(const std::vector& performanceMap, } void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName) { - auto performanceMap = request->GetNewObj().get_profiling_info(); + auto performanceMap = request->GetInfReq().get_profiling_info(); printPerformanceCounts(performanceMap, stream, std::move(deviceName)); } -ov::element::Type GetOpenVINOElementType(ONNX_NAMESPACE::TensorProto_DataType dt) { - static std::unordered_map map{ - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, ov::element::f32}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT8, ov::element::u8}, - {ONNX_NAMESPACE::TensorProto_DataType_INT8, ov::element::i8}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT16, ov::element::u16}, - {ONNX_NAMESPACE::TensorProto_DataType_INT16, ov::element::i16}, - {ONNX_NAMESPACE::TensorProto_DataType_INT32, ov::element::i32}, - {ONNX_NAMESPACE::TensorProto_DataType_INT64, ov::element::i64}, - {ONNX_NAMESPACE::TensorProto_DataType_STRING, ov::element::string}, - {ONNX_NAMESPACE::TensorProto_DataType_BOOL, ov::element::boolean}, - {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, ov::element::f16}, - {ONNX_NAMESPACE::TensorProto_DataType_DOUBLE, ov::element::f64}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT32, ov::element::u32}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT64, ov::element::u64}, - //{ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64, ov::element::undefined}, - //{ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128, ov::element::undefined}, - {ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16, ov::element::bf16}, - //{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN, ov::element::undefined}, - //{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ, ov::element::undefined}, - 
{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2, ov::element::f8e5m2}, - //{ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ, ov::element::undefined}, - {ONNX_NAMESPACE::TensorProto_DataType_UINT4, ov::element::u4}, - {ONNX_NAMESPACE::TensorProto_DataType_INT4, ov::element::i4}, - }; - - if (auto result = map.find(dt); result != map.end()) { - return result->second; - } else { - throw std::runtime_error("Unsupported ONNX data type: " + std::to_string(dt)); - } -} - -// Function to handle tensor creation from external data -void CreateOVTensors(const std::string& device_name, - SharedContext::SharedWeights::Metadata::Map& metadata_map, - SharedContext::SharedWeights::WeightsFile& weights) { - // Get remote context if available - std::optional opt_remote_ctx; - try { - opt_remote_ctx = OVCore::Get()->core.get_default_context(device_name); - } catch (const std::exception&) { - // Remote context not available - } - - for (auto& [key, value] : metadata_map) { - if (value.tensor) continue; - - // Get element data type - auto onnx_element_type = (ONNX_NAMESPACE::TensorProto_DataType)value.element_type; - ov::element::Type ov_elementType = GetOpenVINOElementType(onnx_element_type); - - // Try to get memory-mapped weights - ov::Tensor tensor; - uint8_t* mmaped_weights = static_cast(weights.TryGetOrCreateDeviceMapping(opt_remote_ctx)); - - if (mmaped_weights) { - // We have memory mapped weights. Create a Tensor view into it for this value. 
- ORT_ENFORCE(value.data_offset < weights.Size() && - value.size <= weights.Size() && - (value.data_offset <= weights.Size() - value.size), - "File offset + size outside of external initializer file"); - void* mmapped_offset = static_cast(mmaped_weights + value.data_offset); - tensor = ov::Tensor(ov_elementType, value.dimensions, mmapped_offset); - } else { - ORT_ENFORCE(opt_remote_ctx, "Expected either memory-mapped weights or a valid remote context, but neither is available for device: ", device_name); - // Can't mmap the file to device tensor, create a host tensor and copy the data - tensor = opt_remote_ctx->create_host_tensor(ov_elementType, value.dimensions); - ORT_ENFORCE(tensor.get_byte_size() == value.size, "Remote tensor size mismatch"); - weights.load_weights(value.data_offset, tensor.data(), value.size); - } - - ORT_ENFORCE(tensor.get_byte_size() == value.size, "Unexpected tensor size mismatch"); - value.tensor = std::make_shared(std::move(tensor)); - } -} - -void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map) { - for (auto& [key, value] : metadata_map) { - if (value.tensor) { - value.tensor.reset(); - } - } - metadata_map.clear(); -} - bool IsModelStreamXML(std::istream& model_stream) { std::streampos originalPos = model_stream.tellg(); diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index 27f791c7a5bd1..8ba35e0abd1bc 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -99,11 +99,6 @@ CreateOVModel(std::string&& model, const SessionContext& session_context, std::map>& const_outputs_map); -void CreateOVTensors(const std::string& device_name, - SharedContext::SharedWeights::Metadata::Map& metadata_map, - SharedContext::SharedWeights::WeightsFile& weights); -void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map); - void printPerformanceCounts(const 
std::vector& performanceMap, std::ostream& stream, std::string deviceName); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index a950538c7c5fd..d7fc0553fb1d4 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -138,20 +138,13 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr } int num_infer_req = (session_context_.num_of_threads > 0) ? session_context_.num_of_threads : 1; std::function initializer = [](OVInferRequestPtr) {}; - auto metadata = shared_context_.shared_weights.metadata; if (session_context_.so_share_ep_contexts) { - initializer = [&metadata](OVInferRequestPtr ir_ptr) { - const auto input_count = ir_ptr->GetNumInputs(); - for (auto i = 0u; i < input_count; i++) { - using Key = SharedContext::SharedWeights::Metadata::Key; - const auto tensor_key = Key{ir_ptr->GetInputTensorName(i)}; - if (metadata.contains(tensor_key)) { - auto& value = metadata.at(tensor_key); - ir_ptr->SetTensor(tensor_key.name, value.tensor); - } - } + auto model_dir = session_context_.GetModelPath().parent_path(); + initializer = [this, model_dir = std::move(model_dir)](OVInferRequestPtr ir_ptr) { + shared_context_.SetSharedWeightsOnInferRequest(ir_ptr->GetInfReq(), model_dir); }; } + infer_req_pool_ = std::make_unique(exe_network_, num_infer_req, std::move(initializer)); bindings_ = std::make_unique(exe_network_, subgraph_context_, session_context_); } diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index edd9f176658f8..b14e05191dfaa 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -13,80 +13,14 @@ #include "core/common/common.h" #include "core/providers/openvino/ov_interface.h" #include "core/providers/shared_library/provider_api.h" +#include "ov_bin_manager.h" 
+#include "ov_shared_context.h" namespace onnxruntime { namespace openvino_ep { namespace fs = std::filesystem; -class SharedContext : public WeakSingleton { - // Keep the core alive as long as the shared SharedContext are alive. - std::shared_ptr OVCore_; - - public: - SharedContext() : OVCore_(OVCore::Get()) {} - struct SharedWeights { - struct Metadata { - struct Key { - std::string name; - bool operator==(const Key&) const = default; - }; - struct Hash { - std::size_t operator()(const Key& key) const noexcept { - return std::hash()(key.name); - } - }; - struct Value { - std::string location; - unsigned int data_offset; - unsigned int size; - std::vector dimensions; - std::int32_t element_type; - std::shared_ptr tensor; - }; - using Map = std::unordered_map; - friend std::ostream& operator<<(std::ostream& right, const Metadata::Map& metadata); - friend std::istream& operator>>(std::istream& right, Metadata::Map& metadata); - }; - - struct WeightsFile { - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightsFile); - WeightsFile() = delete; - explicit WeightsFile(std::filesystem::path filename); - - void load_weights(size_t file_offset, void* data, size_t size); - void* TryGetOrCreateDeviceMapping(std::optional& remote_context); - size_t Size() const { return weights_size_; } - - private: - std::ifstream file_; - std::filesystem::path file_path_; - size_t weights_size_; - struct MappingContainer { - void* ptr_{nullptr}; - ov::Tensor tensor_; - }; - std::map imported_device_tensors_; - }; - - void clear() { - metadata.clear(); - metadata_filepath.clear(); - external_weight_filename.clear(); - mapped_weights.reset(); - } - - fs::path external_weight_filename; - std::unique_ptr mapped_weights; - Metadata::Map metadata; - fs::path metadata_filepath; - } shared_weights; - - void clear() { - shared_weights.clear(); - } -}; - using config_t = std::map; using reshape_t = std::map; using layout_t = std::map; @@ -127,8 +61,8 @@ struct ProviderInfo { bool 
so_disable_cpu_ep_fallback{false}; // ORT session option bool so_context_embed_mode{false}; // ORT session option bool so_share_ep_contexts{false}; // ORT session option - fs::path so_context_file_path{}; // ORT session option bool so_stop_share_ep_contexts{false}; // ORT session option + fs::path so_context_file_path{}; // ORT session option const ConfigOptions* config_options{NULL}; const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", @@ -156,8 +90,24 @@ struct SessionContext : ProviderInfo { mutable bool has_external_weights = false; // Value is set to mutable to modify from capability const std::vector OpenVINO_Version = {OPENVINO_VERSION_MAJOR, OPENVINO_VERSION_MINOR}; const std::string openvino_sdk_version = std::to_string(OPENVINO_VERSION_MAJOR) + "." + std::to_string(OPENVINO_VERSION_MINOR); + RuntimeConfig runtime_config; + const std::filesystem::path& GetModelPath() const { + return onnx_model_path_name.empty() ? 
so_context_file_path : onnx_model_path_name; + } + + const std::filesystem::path GetOutputBinPath() const { + std::filesystem::path bin_file_name = so_context_file_path; + if (bin_file_name.empty()) { + bin_file_name = onnx_model_path_name; + } + if (bin_file_name.empty()) { + return {}; + } + return BinManager::GetBinPathForModel(bin_file_name); + } + private: void InitRuntimeConfig() { if (config_options) { diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index 051a39bd4f205..3260d18e9f43c 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -12,32 +12,11 @@ namespace onnxruntime { namespace openvino_ep { -EPCtxHandler::EPCtxHandler(std::string ov_sdk_version, const logging::Logger& logger) : openvino_sdk_version_(std::move(ov_sdk_version)), logger_(logger) { - epctx_model_ = Model::Create("ovep_context_model", false, logger_); -} - -/* Export the serialized blob string embedded onto an EPContext Node - * along with other metadata necessary to validate the graph on import - */ - -Status EPCtxHandler::ExportEPCtxModel(const std::string& model_name) { - // Serialize modelproto to string - auto model_proto = epctx_model_->ToProto(); - model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - - // Finally, dump the model - std::ofstream epctx_onnx_model(model_name, - std::ios::out | std::ios::trunc | std::ios::binary); - if (!epctx_onnx_model) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to create epctx onnx model file"); - } +EPCtxHandler::EPCtxHandler(std::string ov_sdk_version, const logging::Logger& logger, std::shared_ptr shared_context_manager) + : openvino_sdk_version_(std::move(ov_sdk_version)), logger_(logger), shared_context_manager_(std::move(shared_context_manager)) { + ORT_ENFORCE(shared_context_manager_ != nullptr, "SharedContextManager pointer is null in 
EPCtxHandler constructor."); - if (!model_proto->SerializeToOstream(epctx_onnx_model)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to serialize model to file"); - } - LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Export blob as EPContext Node"; - - return Status::OK(); + epctx_model_ = Model::Create("ovep_context_model", false, logger_); } Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, @@ -59,7 +38,7 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, // Create EP context node attributes auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create(); - node_attributes->reserve(4); + node_attributes->reserve(6); { // Create EP context node attributes @@ -70,6 +49,13 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, embed_mode_attr->set_i(embed_mode); node_attributes->emplace(EMBED_MODE, std::move(*embed_mode_attr)); + // main context + auto main_graph_attr = ONNX_NAMESPACE::AttributeProto::Create(); + main_graph_attr->set_name(MAIN_CONTEXT); + main_graph_attr->set_type(onnx::AttributeProto_AttributeType_INT); + main_graph_attr->set_i(model_blob_str.empty() ? 
0 : 1); + node_attributes->emplace(MAIN_CONTEXT, std::move(*main_graph_attr)); + // ep context auto ep_cache_context_attr = ONNX_NAMESPACE::AttributeProto::Create(); ep_cache_context_attr->set_name(EP_CACHE_CONTEXT); @@ -90,6 +76,13 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, source_attr->set_type(onnx::AttributeProto_AttributeType_STRING); source_attr->set_s(kOpenVINOExecutionProvider); node_attributes->emplace(SOURCE, std::move(*source_attr)); + + // partition name + auto partition_name_attr = ONNX_NAMESPACE::AttributeProto::Create(); + partition_name_attr->set_name(PARTITION_NAME); + partition_name_attr->set_type(onnx::AttributeProto_AttributeType_STRING); + partition_name_attr->set_s(graph_name); + node_attributes->emplace(PARTITION_NAME, std::move(*partition_name_attr)); } // Create EP context node @@ -100,8 +93,30 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, return Status::OK(); } -std::unique_ptr -EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const { +std::shared_ptr EPCtxHandler::GetSharedContextForEpContextSubgraph(const GraphViewer& subgraph_view, const std::filesystem::path& ep_context_path) const { + if (!CheckForOVEPCtxNodeInGraph(subgraph_view)) { + return nullptr; + } + + auto first_index = *subgraph_view.GetNodesInTopologicalOrder().begin(); + auto node = subgraph_view.GetNode(first_index); + ORT_ENFORCE(node != nullptr); + auto& attrs = node->GetAttributes(); + ORT_ENFORCE(attrs.count(EP_CACHE_CONTEXT) == 1); + const auto& ep_cache_context = attrs.at(EP_CACHE_CONTEXT).s(); + + ORT_ENFORCE(attrs.count(EMBED_MODE) == 1); + bool embed_mode = static_cast(attrs.at(EMBED_MODE).i()); + + std::filesystem::path bin_path{}; + if (!embed_mode) { + bin_path = ep_context_path.parent_path() / ep_cache_context; + } + + return shared_context_manager_->GetOrCreateSharedContext(bin_path); +} + +std::unique_ptr 
EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const { auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin(); auto node = graph_viewer.GetNode(first_index); ORT_ENFORCE(node != nullptr); @@ -130,16 +145,23 @@ EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_pa bool isXML = backend_utils::IsModelStreamXML(*result); std::filesystem::path native_blob_path{}; if (!isXML) { + ORT_ENFORCE(attrs.count(PARTITION_NAME) == 1, "Expected partition name for native ep context node"); + const auto& partition_name = attrs.at(PARTITION_NAME).s(); + // If the model stream is not an XML (i.e. precompiled blob), the OpenVINO SDK version that it was // exported with must match the version that is currently running. native_blob_path = std::move(blob_filepath); ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_), "EPCtx blob was exported / is compatible with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() + ", but OpenVINO SDK version currently in use is " + openvino_sdk_version_); + + result.reset(); // Release the stream as we will get the native blob from SharedContext + auto shared_context = shared_context_manager_->GetOrCreateSharedContext(native_blob_path); + return std::make_unique(shared_context->GetNativeBlobAsStream(partition_name), shared_context->GetNativeBlob(partition_name)); } LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; - return std::make_unique(std::move(result), native_blob_path); + return std::make_unique(std::move(result), ov::Tensor()); } bool EPCtxHandler::CheckForOVEPCtxNodeInGraph(const GraphViewer& graph_viewer) const { @@ -196,5 +218,61 @@ bool EPCtxHandler::CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, return false; } +void EPCtxHandler::Initialize(const std::vector& fused_nodes, const std::filesystem::path& ep_context_dir) { + bool has_embed_nodes = false; + 
bool has_non_embed_nodes = false; + bool has_main_context = false; + for (const auto& fused_node_graph : fused_nodes) { + const GraphViewer& graph_viewer = fused_node_graph.filtered_graph; + + // Only process graphs that contain ep context nodes. + if (!CheckForOVEPCtxNodeInGraph(graph_viewer)) { + continue; + } + + auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin(); + const Node* node = graph_viewer.GetNode(first_index); + ORT_ENFORCE(node != nullptr, "Node pointer is null despite CheckForOVEPCtxNodeInGraph returning true"); + + auto& attrs = node->GetAttributes(); + ORT_ENFORCE(attrs.count(EP_CACHE_CONTEXT) == 1, "EP_CACHE_CONTEXT attribute missing"); + + bool embed_mode = false; + if (attrs.count(EMBED_MODE) == 1) { + embed_mode = static_cast(attrs.at(EMBED_MODE).i()); + } + has_embed_nodes |= embed_mode; + has_non_embed_nodes |= !embed_mode; + + bool main_context = true; + if (attrs.count(MAIN_CONTEXT) == 1) { + main_context = static_cast(attrs.at(MAIN_CONTEXT).i()); + } + has_main_context |= main_context; + + const std::string& ep_cache_context = attrs.at(EP_CACHE_CONTEXT).s(); + if (embed_mode) { + std::filesystem::path dummy_path{}; + auto shared_context = shared_context_manager_->GetOrCreateSharedContext(dummy_path); + if (main_context) { + ORT_ENFORCE(!ep_cache_context.empty(), "Embedded EP context is indicated but EP_CACHE_CONTEXT attribute is empty."); + std::istringstream ss(ep_cache_context); + shared_context->Deserialize(ss); + } + } else { + std::filesystem::path ep_context_path = ep_context_dir / ep_cache_context; + if (ep_context_path.extension() != ".xml") { + auto shared_context = shared_context_manager_->GetOrCreateSharedContext(ep_context_path); + shared_context->Deserialize(); + } + } + } + + ORT_ENFORCE(!(has_embed_nodes && has_non_embed_nodes), + "Mixed embed and non-embed EP context nodes are not supported in a single model."); + ORT_ENFORCE(!(has_embed_nodes && !has_main_context), + "Expected at least one main context 
node when embedded EP context nodes are present."); +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h index f207f5014ca1f..fc2a56c1d0671 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h @@ -8,43 +8,52 @@ #include #include "core/providers/shared_library/provider_api.h" +#include "core/framework/execution_provider.h" +#include "ov_bin_manager.h" +#include "ov_shared_context.h" namespace onnxruntime { namespace openvino_ep { +class SharedBinManager; + struct ModelBlobWrapper { - ModelBlobWrapper(std::unique_ptr stream, const std::filesystem::path& native_blob_path) : stream_(std::move(stream)), maybe_native_blob_path_(native_blob_path) {} + ModelBlobWrapper(std::unique_ptr stream, const ov::Tensor& tensor) : stream_(std::move(stream)), tensor_(tensor) {} std::unique_ptr stream_; - std::filesystem::path maybe_native_blob_path_; + ov::Tensor tensor_; // May be empty if model blob is provided as stream only. 
}; // Utilities to handle EPContext node export and parsing of an EPContext node // to create the compiled_model object to infer on static const char EPCONTEXT_OP[] = "EPContext"; static const char EMBED_MODE[] = "embed_mode"; +static const char MAIN_CONTEXT[] = "main_context"; +static const char PARTITION_NAME[] = "partition_name"; static const char EP_CACHE_CONTEXT[] = "ep_cache_context"; static const char EP_SDK_VER[] = "ep_sdk_version"; static const char SOURCE[] = "source"; class EPCtxHandler { public: - EPCtxHandler(std::string ov_sdk_version, const logging::Logger& logger); + EPCtxHandler(std::string ov_sdk_version, const logging::Logger& logger, std::shared_ptr shared_context_manager); EPCtxHandler(const EPCtxHandler&) = delete; // No copy constructor - Status ExportEPCtxModel(const std::string& model_name); - bool CheckForOVEPCtxNodeInGraph(const GraphViewer& graph_viewer) const; + bool CheckForOVEPCtxNodeInGraph(const GraphViewer& subgraph_view) const; + std::shared_ptr GetSharedContextForEpContextSubgraph(const GraphViewer& subgraph_view, const std::filesystem::path& ep_context_path) const; bool CheckForOVEPCtxNode(const Node& node) const; - Status AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, + Status AddOVEPCtxNodeToGraph(const GraphViewer& subgraph_view, const std::string& graph_name, const bool embed_mode, std::string&& model_blob_str) const; - std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const; + std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& subgraph_view) const; InlinedVector GetEPCtxNodes() const; - bool CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const; + bool CheckEPCacheContextAttribute(const GraphViewer& subgraph_view, const std::string& target_attr_extn) const; + void Initialize(const std::vector& fused_nodes, const std::filesystem::path& 
ep_context_path); private: const std::string openvino_sdk_version_; std::unique_ptr epctx_model_; const logging::Logger& logger_; + std::shared_ptr shared_context_manager_; }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 049af81c9ffb2..f9c9fa2ea6f48 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -17,6 +17,7 @@ #ifdef USE_OVEP_NPU_MEMORY #include "core/providers/openvino/ov_allocator.h" #endif +#include "ov_interface.h" namespace onnxruntime { namespace openvino_ep { @@ -54,11 +55,12 @@ static std::vector parseDevices(const std::string& device_string, } #endif -OpenVINOExecutionProvider::OpenVINOExecutionProvider(const ProviderInfo& info, std::shared_ptr shared_context) +OpenVINOExecutionProvider::OpenVINOExecutionProvider(const ProviderInfo& info) : IExecutionProvider{onnxruntime::kOpenVINOExecutionProvider}, session_context_(info), - shared_context_{std::move(shared_context)}, - ep_ctx_handle_{session_context_.openvino_sdk_version, *GetLogger()} { + ov_core_(OVCore::Get()), + shared_context_manager_(SharedContextManager::Get()), + ep_ctx_handle_{session_context_.openvino_sdk_version, *GetLogger(), shared_context_manager_} { InitProviderOrtApi(); #ifdef _WIN32 session_id_ = global_session_counter_.fetch_add(1) + 1; @@ -72,7 +74,6 @@ OpenVINOExecutionProvider::~OpenVINOExecutionProvider() { backend_manager.ShutdownBackendManager(); } backend_managers_.clear(); - shared_context_.reset(); } std::vector> @@ -102,6 +103,11 @@ common::Status OpenVINOExecutionProvider::Compile( auto& logger = *GetLogger(); Status status = Status::OK(); + if (session_context_.so_context_enable && session_context_.so_context_embed_mode && session_context_.so_share_ep_contexts) { + return Status(common::StatusCategory::ONNXRUNTIME, common::EP_FAIL, + 
std::string("Invalid EP context configuration: ") + kOrtSessionOptionEpContextEmbedMode + " must be 0 if " + kOrtSessionOptionShareEpContexts + " is 1."); + } + bool is_epctx_model = false; if (!fused_nodes.empty()) { // Assume these properties are constant for all the model subgraphs, otherwise move to SubGraphContext @@ -115,24 +121,8 @@ common::Status OpenVINOExecutionProvider::Compile( is_epctx_model = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(graph_body_viewer_0); } - // The block below is executed during EP context model inference - auto& metadata = shared_context_->shared_weights.metadata; // Metadata object in memory - if (session_context_.so_share_ep_contexts && - is_epctx_model && - metadata.empty()) { - fs::path context_model_file_path = session_context_.so_context_file_path; - if (context_model_file_path.empty()) { - // If ep.context_file_path is not set the input model path is used - context_model_file_path = session_context_.onnx_model_path_name; - } - - // Metadata is always read from model location, this could be a source or epctx model - fs::path metadata_filename = context_model_file_path.stem().string() + "_metadata.bin"; - fs::path metadata_file_path = context_model_file_path.parent_path() / metadata_filename; - std::ifstream file(metadata_file_path, std::ios::binary); - ORT_RETURN_IF_NOT(file, "Metadata file was not found: " + metadata_file_path.string()); - shared_context_->shared_weights.metadata_filepath = std::move(metadata_file_path); - file >> metadata; + if (is_epctx_model) { + ep_ctx_handle_.Initialize(fused_nodes, session_context_.GetOutputBinPath().parent_path()); } struct OpenVINOEPFunctionState { @@ -153,12 +143,11 @@ common::Status OpenVINOExecutionProvider::Compile( // For original model, check if the user wants to export a model with pre-compiled blob auto& backend_manager = backend_managers_.emplace_back(session_context_, - *shared_context_, + *shared_context_manager_, fused_node, graph_body_viewer, logger, ep_ctx_handle_); - 
compute_info.create_state_func = [&backend_manager](ComputeContext* context, FunctionState* state) { OpenVINOEPFunctionState* p = new OpenVINOEPFunctionState{ @@ -189,42 +178,31 @@ common::Status OpenVINOExecutionProvider::Compile( }; node_compute_funcs.push_back(std::move(compute_info)); - - if (!status.IsOK()) { - break; - } } - // The block below is executed during EP context model generation - if (session_context_.so_context_enable && - session_context_.so_share_ep_contexts && - !metadata.empty()) { - // For models after the first the metadata name comes from the shared context - fs::path metadata_file_path = shared_context_->shared_weights.metadata_filepath; - if (metadata_file_path.empty()) { - metadata_file_path = session_context_.so_context_file_path; - std::string name_append{"_metadata.bin"}; - if (metadata_file_path.empty()) { - metadata_file_path = session_context_.onnx_model_path_name; - name_append = "_ctx" + name_append; - } - auto metadata_filename = metadata_file_path.stem().string() + name_append; - metadata_file_path.replace_filename(metadata_filename); - shared_context_->shared_weights.metadata_filepath = metadata_file_path; - } + // Export compiled blobs as EPContext nodes if context enable is set + if (session_context_.so_context_enable) { + auto backend_it = backend_managers_.begin(); + bool is_first = true; - // Metadata is generated only for shared contexts - // If saving metadata then save it to the provided path or use the original model path - // Multiple calls to Compile() will update the metadata and for the last call - // the resulting file will contain the aggregated content - std::ofstream file{metadata_file_path, std::ios::binary}; - ORT_RETURN_IF_NOT(file, "Metadata file could not be written: ", metadata_file_path); - file << metadata; - } + for (const auto& fused_node_graph : fused_nodes) { + const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph; + + // Set include_embed_data to true only for the first backend 
manager + backend_it->TryExportCompiledBlobAsEPCtxNode(graph_body_viewer, is_first); - if (session_context_.so_stop_share_ep_contexts) { - if (shared_context_) { - shared_context_->clear(); + is_first = false; + ++backend_it; + } + + // bit clunky ideally we should try to fold this into ep context handler + if (!session_context_.so_context_embed_mode) { + auto shared_context = shared_context_manager_->GetOrCreateActiveSharedContext(session_context_.GetOutputBinPath()); + shared_context->Serialize(); + if (session_context_.so_stop_share_ep_contexts) { + shared_context_manager_->ClearActiveSharedContext(); + shared_context->Clear(); + } } } diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index a375a9ee788bd..326f6de30498f 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -15,6 +15,9 @@ #include "core/providers/openvino/backend_manager.h" #include "core/providers/openvino/contexts.h" +#include "ov_shared_context.h" +#include "ov_bin_manager.h" +#include "ov_interface.h" #ifdef _WIN32 #include "core/providers/openvino/ov_tracing.h" @@ -50,7 +53,7 @@ static std::vector split(const std::string& s, char delim) { // Logical device representation. 
class OpenVINOExecutionProvider : public IExecutionProvider { public: - explicit OpenVINOExecutionProvider(const ProviderInfo& info, std::shared_ptr shared_context); + explicit OpenVINOExecutionProvider(const ProviderInfo& info); ~OpenVINOExecutionProvider(); std::vector> @@ -76,7 +79,9 @@ class OpenVINOExecutionProvider : public IExecutionProvider { #endif private: SessionContext session_context_; - std::shared_ptr shared_context_; + std::shared_ptr ov_core_; + std::shared_ptr shared_context_manager_; + std::list backend_managers_; // EP session owns the backend objects EPCtxHandler ep_ctx_handle_; diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 298eb25713bec..cb94fb3793024 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -16,6 +16,7 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "nlohmann/json.hpp" #include "core/providers/openvino/openvino_parser_utils.h" +#include "ov_interface.h" namespace onnxruntime { namespace openvino_ep { @@ -381,14 +382,14 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, } struct OpenVINOProviderFactory : IExecutionProviderFactory { - OpenVINOProviderFactory(ProviderInfo provider_info, std::shared_ptr shared_context) - : provider_info_(std::move(provider_info)), shared_context_(std::move(shared_context)) {} + OpenVINOProviderFactory(ProviderInfo provider_info, std::shared_ptr ov_core) + : provider_info_(std::move(provider_info)), ov_core_(ov_core) {} ~OpenVINOProviderFactory() override {} std::unique_ptr CreateProvider() override { ParseConfigOptions(provider_info_); - return std::make_unique(provider_info_, shared_context_); + return std::make_unique(provider_info_); } // Called by InferenceSession when registering EPs. 
Allows creation of an EP instance that is initialized with @@ -421,7 +422,7 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { ParseProviderInfo(provider_options, &config_options, provider_info); ParseConfigOptions(provider_info); - auto ov_ep = std::make_unique(provider_info, shared_context_); + auto ov_ep = std::make_unique(provider_info); ov_ep->SetLogger(reinterpret_cast(&session_logger)); return ov_ep; } @@ -432,14 +433,14 @@ struct OpenVINOProviderFactory : IExecutionProviderFactory { std::unique_ptr CreateProvider_V2(const OrtSessionOptions& /*session_options*/, const OrtLogger& session_logger) { ProviderInfo provider_info = provider_info_; - auto ov_ep = std::make_unique(provider_info, shared_context_); + auto ov_ep = std::make_unique(provider_info); ov_ep->SetLogger(reinterpret_cast(&session_logger)); return ov_ep; } private: ProviderInfo provider_info_; - std::shared_ptr shared_context_; + std::shared_ptr ov_core_; }; struct ProviderInfo_OpenVINO_Impl : ProviderInfo_OpenVINO { @@ -464,7 +465,7 @@ struct OpenVINO_Provider : Provider { ProviderInfo pi; ParseProviderInfo(provider_options, config_options, pi); - return std::make_shared(pi, SharedContext::Get()); + return std::make_shared(pi, OVCore::Get()); } Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, @@ -571,7 +572,7 @@ struct OpenVINO_Provider : Provider { ParseConfigOptions(pi); // Create and return the execution provider - auto factory = std::make_unique(pi, SharedContext::Get()); + auto factory = std::make_unique(pi, OVCore::Get()); ep = factory->CreateProvider_V2(session_options, logger); return Status::OK(); } diff --git a/onnxruntime/core/providers/openvino/ov_bin_manager.cc b/onnxruntime/core/providers/openvino/ov_bin_manager.cc new file mode 100644 index 0000000000000..bdab631bb478b --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_bin_manager.cc @@ -0,0 +1,440 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + 
+#include "ov_bin_manager.h" +#include "ov_shared_context.h" +#include +#include "core/providers/shared_library/provider_api.h" // for ORT_VERSION and kOpenVINOExecutionProvider + +namespace onnxruntime { +namespace openvino_ep { + +static inline uint64_t AlignUp(uint64_t value, uint64_t alignment) { + return (value + alignment - 1) / alignment * alignment; +} + +// Custom streambuf that wraps an ov::Tensor's memory +// Provides us a std::istream interface over the tensor data without copying. +// Only supports input operations. +class TensorStreamBuf : public std::streambuf { + public: + explicit TensorStreamBuf(ov::Tensor& tensor) { + char* data = const_cast(tensor.data()); + size_t size = tensor.get_byte_size(); + setg(data, data, data + size); + } + + protected: + // Override seekoff for proper seeking support + std::streampos seekoff(std::streamoff off, std::ios_base::seekdir dir, std::ios_base::openmode which) override { + if (which & std::ios_base::in) { + char* new_pos = nullptr; + switch (dir) { + case std::ios_base::beg: + new_pos = eback() + off; + break; + case std::ios_base::cur: + new_pos = gptr() + off; + break; + case std::ios_base::end: + new_pos = egptr() + off; + break; + default: + return std::streampos(std::streamoff(-1)); + } + + if (new_pos >= eback() && new_pos <= egptr()) { + setg(eback(), new_pos, egptr()); + return std::streampos(new_pos - eback()); + } + } + return std::streampos(std::streamoff(-1)); + } + + // Override seekpos for proper seeking support + std::streampos seekpos(std::streampos pos, std::ios_base::openmode which) override { + return seekoff(std::streamoff(pos), std::ios_base::beg, which); + } +}; + +// Custom istream that owns the tensor to ensure proper lifetime management +class TensorStream : public std::istream { + public: + explicit TensorStream(ov::Tensor tensor) + : std::istream(&buf_), + tensor_(std::move(tensor)), + buf_(tensor_) {} + + private: + ov::Tensor tensor_; // Keep tensor alive + TensorStreamBuf buf_; 
// Buffer wrapping tensor data +}; + +/* + Logical layout of the single binary file: + [Header] + [BSON Metadata] ← Contains blob_metadata_map with data_offset and size for each blob + [Padding to 64K alignment] ← Blob section starts here (64K aligned) + [Blob 1] ← BSON blob_metadata_map["blob_name"].data_offset points here + [Padding to 64K alignment] ← Each blob end is 64K aligned + [Blob 2] ← BSON blob_metadata_map["blob_name2"].data_offset points here + [Padding to 64K alignment] + [Blob 3] ← BSON blob_metadata_map["blob_name3"].data_offset points here + ... + + BSON Schema: + { + "version": , // BSON schema version (semver format) + "producer": , // Producer identifier (e.g., "onnxruntime-openvino-ep-plugin") + "weights_metadata_map": { // Map of ONNX tensor names to external weight file metadata + "": { + "location": , // Relative path to external weights file + "data_offset": , // Offset within external weights file + "size": // Size of weight data in bytes + }, + ... + }, + "blob_metadata_map": { // Map of blob names to compiled model blob metadata + "": { + "data_offset": , // Absolute file offset to blob data (64K aligned) + "size": // Actual blob data size (excluding padding) + }, + ... + } + } + + Note: data_offset values in blob_metadata_map are absolute file positions. + size values exclude alignment padding bytes. 
+*/ + +// "OVEP_BIN" in little-endian (memory will read as 'O','V','E','P','_','B','I','N') +constexpr uint64_t kMagicNumber = 0x4E49425F5045564FULL; + +enum class BinVersion : uint64_t { + v1 = 1, + current = v1 +}; + +struct header_t { + uint64_t magic; + uint64_t version; + uint64_t header_size; + uint64_t bson_start_offset; + uint64_t bson_size; +}; + +constexpr uint64_t kBlobAlignment = 64 * 1024; + +// BSON field names +namespace BSONFields { +constexpr const char* kVersion = "version"; +constexpr const char* kProducer = "producer"; +constexpr const char* kWeightsMetadata = "weights_metadata_map"; +constexpr const char* kBlobMetadata = "blob_metadata_map"; +constexpr const char* kLocation = "location"; +constexpr const char* kDataOffset = "data_offset"; +constexpr const char* kSize = "size"; +constexpr const char* kCurrentBsonVersion = "1.0.0"; +constexpr const char* kProducerName = "onnxruntime-openvino-ep-" ORT_VERSION; +} // namespace BSONFields + +template +constexpr std::underlying_type_t to_underlying(E e) noexcept { + static_assert(std::is_enum_v, "to_underlying requires an enum type"); + return static_cast>(e); +} + +void BinManager::AddNativeBlob(const std::string& name, const ov::CompiledModel& compiled_model) { + std::unique_lock lock(mutex_); + native_blobs_[name] = BlobContainer{.compiled_model = compiled_model, .tensor = {}, .data = {}, .serialized_info = {0, 0}}; +} + +ov::Tensor BinManager::GetNativeBlob(const std::string& blob_name) { + std::unique_lock lock(mutex_); + + auto it = native_blobs_.find(blob_name); + ORT_ENFORCE(it != native_blobs_.end(), "Blob not found for ", blob_name); + + auto& blob_container = it->second; + if (blob_container.tensor) { + return blob_container.tensor; + } + + ORT_ENFORCE(blob_container.serialized_info.size > 0 || !blob_container.data.empty(), + "Blob has no serialization info or embedded data for ", blob_name); + + if (!external_bin_path_.value_or("").empty() && !mapped_bin_) { + // Use ov::read_tensor_data 
to create a memory-mapped tensor from external file + mapped_bin_ = ov::read_tensor_data(external_bin_path_.value()); + } + + if (mapped_bin_) { + // Create a tensor from memory-mapped external file + blob_container.tensor = ov::Tensor( + ov::element::u8, + ov::Shape{blob_container.serialized_info.size}, + mapped_bin_.data() + blob_container.serialized_info.file_offset); + } else { + // Create a tensor from embedded data vector + blob_container.tensor = ov::Tensor( + ov::element::u8, + ov::Shape{blob_container.data.size()}, + blob_container.data.data()); + } + + return blob_container.tensor; +} + +std::unique_ptr BinManager::GetNativeBlobAsStream(const std::string& blob_name) { + return std::make_unique(GetNativeBlob(blob_name)); +} + +void BinManager::Clear() { + std::unique_lock lock(mutex_); + native_blobs_.clear(); + mapped_bin_ = {}; + external_bin_path_.reset(); +} + +std::filesystem::path BinManager::GetBinPathForModel(const std::filesystem::path& model_path) { + ORT_ENFORCE(!model_path.empty()); + return model_path.parent_path() / (model_path.stem().string() + "_" + kOpenVINOExecutionProvider + ".bin"); +} + +void BinManager::Serialize(std::shared_ptr shared_context) { + auto path = GetExternalBinPath(); + std::ofstream stream(path, std::ios::out | std::ios::binary); + ORT_ENFORCE(stream.is_open(), "Failed to open file for serialization: " + path.string()); + Serialize(stream, shared_context); +} + +void BinManager::Deserialize(std::shared_ptr shared_context) { + auto path = GetExternalBinPath(); + std::ifstream stream(path, std::ios::in | std::ios::binary); + ORT_ENFORCE(stream.is_open(), "Failed to open file for deserialization: " + path.string()); + Deserialize(stream, shared_context); +} + +bool BinManager::ShouldSerialize(const std::shared_ptr& shared_context) const { + if (shared_context) { + auto metadata = shared_context->GetMetadataCopy(); + if (!metadata.empty()) { + return true; + } + } + return !native_blobs_.empty(); +} + +void 
BinManager::Serialize(std::ostream& stream, std::shared_ptr shared_context) { + std::shared_lock ul(mutex_); + + if (!ShouldSerialize(shared_context)) { + // nothing to serialize + return; + } + + const auto stream_start = stream.tellp(); + + auto write_alignment_padding = [&stream](uint64_t current_pos, uint64_t alignment) { + uint64_t aligned_position = AlignUp(current_pos, alignment); + uint64_t padding_size = aligned_position - current_pos; + if (padding_size > 0) { + std::vector padding(padding_size, 0); + stream.write(padding.data(), padding.size()); + ORT_ENFORCE(stream.good(), "Error: Failed to write alignment padding."); + } + }; + + // Reserve space for header (will be updated later) + header_t header{}; + header.magic = kMagicNumber; + header.version = to_underlying(BinVersion::current); + header.header_size = sizeof(header_t); + stream.write(reinterpret_cast(&header), sizeof(header)); + ORT_ENFORCE(stream.good(), "Error: Failed to write header."); + + // Build JSON metadata + nlohmann::json j; + j[BSONFields::kVersion] = BSONFields::kCurrentBsonVersion; + j[BSONFields::kProducer] = BSONFields::kProducerName; + + // Add weights metadata as a map (from SharedContext if available) + if (shared_context) { + auto metadata = shared_context->GetMetadataCopy(); + if (!metadata.empty()) { + nlohmann::json weights_map = nlohmann::json::object(); + for (const auto& [key, value] : metadata) { + nlohmann::json weight_entry; + weight_entry[BSONFields::kLocation] = value.serialized.location.string(); + weight_entry[BSONFields::kDataOffset] = value.serialized.data_offset; + weight_entry[BSONFields::kSize] = value.serialized.size; + weights_map[key] = weight_entry; + } + j[BSONFields::kWeightsMetadata] = weights_map; + } + } + + // Add blob metadata with placeholder values as a map (will be updated after writing blobs) + nlohmann::json blob_map = nlohmann::json::object(); + for (const auto& [key, value] : native_blobs_) { + nlohmann::json blob_entry; + auto max_val = 
std::numeric_limits::max(); + // Placehold max size since we don't know actual offsets/sizes yet, and if they aren't max they might serialize smaller. + blob_entry[BSONFields::kDataOffset] = max_val; + blob_entry[BSONFields::kSize] = max_val; + blob_map[key] = blob_entry; + } + j[BSONFields::kBlobMetadata] = blob_map; + + // Write BSON metadata (will be rewritten later with correct blob info) + header.bson_start_offset = stream.tellp(); + + size_t orig_bson_size; + { + std::vector bson_data = nlohmann::json::to_bson(j); + orig_bson_size = bson_data.size(); + stream.write(reinterpret_cast(bson_data.data()), bson_data.size()); + ORT_ENFORCE(stream.good(), "Error: Failed to write BSON data."); + } + uint64_t bson_end = stream.tellp(); + + write_alignment_padding(bson_end, kBlobAlignment); + + // Write blob data and capture actual offsets/sizes + for (auto& [blob_name, value] : native_blobs_) { + uint64_t blob_start = stream.tellp(); + value.compiled_model.export_model(stream); + ORT_ENFORCE(stream.good(), "Error: Failed to write blob data for ", blob_name); + // Seek to end of stream after writing in case export model didn't leave us there + stream.seekp(0, std::ios::end); + uint64_t blob_end = stream.tellp(); + uint64_t blob_size = blob_end - blob_start; + + // Update the BlobContainer + BSON with serialization info + value.serialized_info.file_offset = blob_start; + value.serialized_info.size = blob_size; + j[BSONFields::kBlobMetadata][blob_name][BSONFields::kDataOffset] = blob_start; + j[BSONFields::kBlobMetadata][blob_name][BSONFields::kSize] = blob_size; + + write_alignment_padding(blob_end, kBlobAlignment); + } + + // Rewrite BSON metadata with correct blob info + std::vector updated_bson_data = nlohmann::json::to_bson(j); + ORT_ENFORCE(updated_bson_data.size() <= orig_bson_size, + "Error: BSON size larger after updating blob info. 
Original: ", orig_bson_size, + " Updated: ", updated_bson_data.size()); + + stream.seekp(header.bson_start_offset); + stream.write(reinterpret_cast(updated_bson_data.data()), updated_bson_data.size()); + ORT_ENFORCE(stream.good(), "Error: Failed to rewrite BSON data."); + bson_end = stream.tellp(); + header.bson_size = bson_end - header.bson_start_offset; + + // Update header with BSON offsets + stream.seekp(stream_start); + stream.write(reinterpret_cast(&header), sizeof(header)); + ORT_ENFORCE(stream.good(), "Error: Failed to update header."); + + stream.seekp(0, std::ios::end); // Move to end after writing. +} + +void BinManager::Deserialize(std::istream& stream, std::shared_ptr shared_context) { + // Read and validate header + header_t header{}; + + stream.read(reinterpret_cast(&header), sizeof(header)); + ORT_ENFORCE(stream.good(), "Error: Failed to read header."); + ORT_ENFORCE(header.magic == kMagicNumber, "Error: Invalid magic number. Expected: 0x", std::hex, kMagicNumber, " Got: 0x", header.magic); + ORT_ENFORCE(header.version == to_underlying(BinVersion::current), "Error: Unsupported file version: ", header.version); + ORT_ENFORCE(header.header_size == sizeof(header_t), "Error: Header size mismatch."); + + // Seek to BSON metadata and read it + stream.seekg(header.bson_start_offset); + ORT_ENFORCE(stream.good(), "Error: Failed to seek to BSON metadata."); + + // Parse BSON + nlohmann::json j; + { + std::vector bson_data(header.bson_size); + stream.read(reinterpret_cast(bson_data.data()), header.bson_size); + j = nlohmann::json::from_bson(bson_data); + } + + // Validate BSON version (check major version compatibility) + ORT_ENFORCE(j.contains(BSONFields::kVersion), "Error: Missing version in BSON metadata."); + auto bson_version = j[BSONFields::kVersion].get(); + + // Extract major version from semver strings (format: "major.minor.patch") + auto get_major_version = [](const std::string& version) -> int { + size_t dot_pos = version.find('.'); + if (dot_pos 
== std::string::npos) return -1; + try { + return std::stoi(version.substr(0, dot_pos)); + } catch (...) { + return -1; + } + }; + + int file_major = get_major_version(bson_version); + int current_major = get_major_version(BSONFields::kCurrentBsonVersion); + + ORT_ENFORCE(file_major >= 0 && current_major >= 0, + "Error: Invalid BSON version format. Expected: ", BSONFields::kCurrentBsonVersion, + " Got: ", bson_version); + ORT_ENFORCE(file_major == current_major, + "Error: Incompatible BSON schema major version. Expected: ", current_major, + " Got: ", file_major, " (full version: ", bson_version, ")"); + + // Parse weights metadata and populate SharedContext if available + if (j.contains(BSONFields::kWeightsMetadata)) { + ORT_ENFORCE(shared_context, "Error: Bin contains shared weights metadata but no SharedContext was provided during deserialization."); + const auto& weights_map = j[BSONFields::kWeightsMetadata]; + if (weights_map.is_object()) { + for (const auto& [weight_name, weight_entry] : weights_map.items()) { + auto location = weight_entry[BSONFields::kLocation].get(); + auto data_offset = weight_entry[BSONFields::kDataOffset].get(); + auto size = weight_entry[BSONFields::kSize].get(); + shared_context->AddExternalWeight(weight_name, data_offset, size, location); + } + } + } + + // Parse blob metadata + ORT_ENFORCE(j.contains(BSONFields::kBlobMetadata), "Error: Missing blob metadata in BSON."); + const auto& blob_map = j[BSONFields::kBlobMetadata]; + ORT_ENFORCE(blob_map.is_object(), "Error: Blob metadata must be an object."); + + // Determine if we're deserializing from an external file or embedded stream + const bool has_external_file = !external_bin_path_.value_or("").empty(); + + std::unique_lock lock(mutex_); + for (const auto& [blob_name, blob_entry] : blob_map.items()) { + uint64_t blob_offset = blob_entry[BSONFields::kDataOffset].get(); + uint64_t blob_size = blob_entry[BSONFields::kSize].get(); + + BlobContainer container; + 
container.serialized_info.file_offset = blob_offset; + container.serialized_info.size = blob_size; + + // If no external file, extract blob data into vector + if (!has_external_file) { + // Seek to blob offset and read data into vector + auto current_pos = stream.tellg(); + stream.seekg(blob_offset); + ORT_ENFORCE(stream.good(), "Error: Failed to seek to blob data for ", blob_name); + + container.data.resize(blob_size); + stream.read(reinterpret_cast(container.data.data()), blob_size); + ORT_ENFORCE(stream.good(), "Error: Failed to read blob data for ", blob_name); + + // Restore stream position + stream.seekg(current_pos); + } + + native_blobs_[blob_name] = std::move(container); + } +} + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_bin_manager.h b/onnxruntime/core/providers/openvino/ov_bin_manager.h new file mode 100644 index 0000000000000..d6d6ada2d252a --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_bin_manager.h @@ -0,0 +1,77 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/runtime/core.hpp" +#include "weak_singleton.h" + +namespace onnxruntime { +namespace openvino_ep { + +// Forward declaration +class SharedContext; + +// Manages native compiled model blobs and binary file serialization/deserialization +class BinManager { + public: + BinManager() = default; + BinManager(const std::filesystem::path& external_bin_path) : external_bin_path_(external_bin_path) {} + ~BinManager() = default; + + // Blob management + void AddNativeBlob(const std::string& name, const ov::CompiledModel& compiled_model); + ov::Tensor GetNativeBlob(const std::string& blob_name); + std::unique_ptr GetNativeBlobAsStream(const std::string& blob_name); + void Clear(); + + // Serialization/Deserialization + void Serialize(std::ostream& stream, std::shared_ptr shared_context = nullptr); 
+ void Deserialize(std::istream& stream, std::shared_ptr shared_context = nullptr); + + void Serialize(std::shared_ptr shared_context = nullptr); + void Deserialize(std::shared_ptr shared_context = nullptr); + + // Path management + void TrySetExternalBinPath(const std::filesystem::path& bin_path) { + std::unique_lock lock(mutex_); + if (!external_bin_path_) { + external_bin_path_ = bin_path; + } + } + std::filesystem::path GetExternalBinPath() const { + std::shared_lock lock(mutex_); + return external_bin_path_.value_or(""); + } + + static std::filesystem::path GetBinPathForModel(const std::filesystem::path& model_path); + + private: + struct BlobContainer { + ov::CompiledModel compiled_model; + ov::Tensor tensor; + std::vector data; // For embedded blobs when no external file exists + struct { + uint64_t file_offset{0}; + uint64_t size{0}; + } serialized_info; + }; + + bool ShouldSerialize(const std::shared_ptr& shared_context) const; + + mutable std::shared_mutex mutex_; + std::optional external_bin_path_; + ov::Tensor mapped_bin_; + std::unordered_map native_blobs_; +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_factory.cc b/onnxruntime/core/providers/openvino/ov_factory.cc index 2853cc17726ab..5119c611d3f3d 100644 --- a/onnxruntime/core/providers/openvino/ov_factory.cc +++ b/onnxruntime/core/providers/openvino/ov_factory.cc @@ -16,7 +16,7 @@ #include "onnxruntime_c_api.h" #include "ov_factory.h" #include "openvino/openvino.hpp" -#include "ov_interface.h" +#include "weak_singleton.h" using namespace onnxruntime::openvino_ep; using ov_core_singleton = onnxruntime::openvino_ep::WeakSingleton; diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index e97bbaceee4e2..85fc4d93d6243 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -199,8 +199,8 @@ OVExeNetwork 
OVCore::ImportModel(ModelBlobWrapper& model_blob, return OvExceptionBoundary([&]() { ov::CompiledModel obj; #if (OPENVINO_VERSION_MAJOR > 2025 || (OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR >= 3)) - if (!model_blob.maybe_native_blob_path_.empty()) { - obj = core.import_model(ov::read_tensor_data(model_blob.maybe_native_blob_path_), hw_target, device_config); + if (model_blob.tensor_) { + obj = core.import_model(model_blob.tensor_, hw_target, device_config); } else { obj = core.import_model(*model_blob.stream_, hw_target, device_config); } diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index d5d4bd1af0c6a..5df5420a427f2 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -18,6 +18,7 @@ #include "openvino/frontend/manager.hpp" #include "openvino/core/dimension.hpp" #include "openvino/core/partial_shape.hpp" +#include "weak_singleton.h" #include @@ -47,31 +48,6 @@ typedef std::shared_ptr OVTensorPtr; std::optional queryOVProperty(const std::string& property, const std::string& device_type); -template -class WeakSingleton { - public: - static std::shared_ptr Get() { - static std::weak_ptr instance; - static std::mutex mutex; - - auto ptr = instance.lock(); - if (!ptr) { - std::lock_guard lock(mutex); - // ensure another thread didn't create an instance while this thread was waiting - ptr = instance.lock(); - if (!ptr) { - ptr = std::make_shared(); - instance = ptr; - } - } - return ptr; - } - - protected: - WeakSingleton() = default; - virtual ~WeakSingleton() = default; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeakSingleton); -}; struct OVCore : WeakSingleton { ov::Core core; @@ -153,7 +129,7 @@ class OVInferRequest { virtual void Infer(); explicit OVInferRequest(ov::InferRequest obj) : ovInfReq(std::move(obj)) {} OVInferRequest() : ovInfReq(ov::InferRequest()) {} - ov::InferRequest& GetNewObj() { + 
ov::InferRequest& GetInfReq() { return ovInfReq; } virtual void RewindKVCache([[maybe_unused]] size_t index) {} diff --git a/onnxruntime/core/providers/openvino/ov_shared_context.cc b/onnxruntime/core/providers/openvino/ov_shared_context.cc new file mode 100644 index 0000000000000..84cce6e7e16d4 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_shared_context.cc @@ -0,0 +1,145 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#include "ov_shared_context.h" +#include "ov_interface.h" + +#include "openvino/runtime/intel_npu/level_zero/level_zero.hpp" +#include "openvino/core/type/element_type.hpp" + +namespace onnxruntime { +namespace openvino_ep { + +SharedContext::SharedContext(std::filesystem::path bin_path) + : bin_path_(std::move(bin_path)), + bin_manager_(bin_path_) { +} + +static bool InRange(size_t offset, size_t size, size_t total_size) { + return (offset < total_size) && (size <= total_size) && (offset <= total_size - size); +} + +// Weights file handling +SharedContext::WeightsFile::WeightsFile(const std::filesystem::path& filename) : file_(filename, std::ios::in | std::ios::binary), file_path_(filename) { + try { + file_.exceptions(std::ifstream::failbit | std::ifstream::badbit); + weights_size_ = std::filesystem::file_size(filename); + } catch (std::exception& e) { + ORT_THROW("Error: Failed to open weight file at ", filename.string(), " ", e.what()); + } +} + +void SharedContext::WeightsFile::LoadWeights(size_t file_offset, void* data, size_t size) { + ORT_ENFORCE(InRange(file_offset, size, weights_size_), "Error: File offset is out of bounds."); + file_.seekg(file_offset); + file_.read(static_cast(data), size); +} + +void* SharedContext::WeightsFile::TryGetOrCreateDeviceMapping(std::optional& remote_context) { + std::string dev_name{}; + if (remote_context) { + dev_name = remote_context->get_device_name(); + } + + auto [it, inserted] = imported_device_tensors_.emplace(dev_name, MappingContainer{}); + if (inserted) { 
+ if (dev_name == "NPU") { + // try to import the memory mapped file to remote tensor +#if (OPENVINO_VERSION_MAJOR > 2025 || (OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR >= 3)) + ORT_ENFORCE(remote_context, "Error: Remote context is required for NPU device."); + auto npu_context = remote_context->as(); + auto&& l0_tensor = npu_context.create_tensor(ov::element::Type_t::u8, {weights_size_}, ov::intel_npu::FileDescriptor(file_path_)); + it->second = MappingContainer{.ptr_ = l0_tensor.get(), .tensor_ = l0_tensor}; +#endif + } else if (dev_name.empty()) { + // CPU/virtual device case, create a CPU tensor memory mapped from file + auto&& mmaped_tensor = ov::read_tensor_data(file_path_); + it->second = MappingContainer{.ptr_ = mmaped_tensor.data(), .tensor_ = mmaped_tensor}; + } + } + + return it->second.ptr_; +} + +void SharedContext::LoadTensorFromFile( + Metadata::Value& value, + const std::filesystem::path& model_dir, + std::optional& remote_context, + const ov::element::Type& element_type, + const ov::Shape& dimensions) { + const auto weights_location = model_dir / value.serialized.location; + auto& weights_file = weight_files_[weights_location]; + if (!weights_file) { + weights_file = std::make_unique(weights_location); + } + + ov::Tensor tensor; + uint8_t* mmaped_weights = static_cast(weights_file->TryGetOrCreateDeviceMapping(remote_context)); + if (mmaped_weights) { + // We have memory mapped weights. Create a Tensor view into it for this value. 
+ ORT_ENFORCE(InRange(value.serialized.data_offset, value.serialized.size, weights_file->Size()), "File offset + size outside of external initializer file"); + void* mmapped_offset = static_cast(mmaped_weights + value.serialized.data_offset); + tensor = ov::Tensor(element_type, dimensions, mmapped_offset); + } else { + ORT_ENFORCE(remote_context, "Unexpected: Don't have remote context and memory mapped weights is null!"); + // Can't mmap the file to device tensor, create a host tensor and copy the data + tensor = remote_context->create_host_tensor(element_type, dimensions); + ORT_ENFORCE(tensor.get_byte_size() == value.serialized.size, "Remote tensor size mismatch"); + weights_file->LoadWeights(value.serialized.data_offset, tensor.data(), value.serialized.size); + } + + ORT_ENFORCE(tensor.get_byte_size() == value.serialized.size, "Tensor size mismatch"); + value.tensor = std::make_shared(std::move(tensor)); +} + +void SharedContext::SetSharedWeightsOnInferRequest(ov::InferRequest& ir, const std::filesystem::path& model_dir) { + auto&& compiled_model = ir.get_compiled_model(); + std::optional opt_remote_ctx; + try { + opt_remote_ctx = compiled_model.get_context(); + } catch (ov::Exception&) { + // CPU may not have a remote context. 
+ } + + std::unique_lock ul(mutex_); + for (const auto& input : compiled_model.inputs()) { + const std::string tensor_name = *input.get_names().begin(); + + auto it = metadata_.find(tensor_name); + if (it == metadata_.end()) continue; // No shared weight for this tensor + auto& value = it->second; + + if (!value.tensor) { + LoadTensorFromFile(value, model_dir, opt_remote_ctx, input.get_element_type(), input.get_shape()); + } + ir.set_tensor(tensor_name, *value.tensor); + } +} + +void SharedContext::Serialize(std::ostream& stream) { + bin_manager_.Serialize(stream, shared_from_this()); +} + +void SharedContext::Deserialize(std::istream& stream) { + bin_manager_.Deserialize(stream, shared_from_this()); +} + +void SharedContext::Serialize() { + bin_manager_.Serialize(shared_from_this()); +} + +void SharedContext::Deserialize() { + bin_manager_.Deserialize(shared_from_this()); +} + +void SharedContext::Clear() { + // Outside the mutex since bin_manager has it's own lock, and we want to keep lock ordering consistent + // It's ok for clear to not be fully atomic we're primarily interested in internal consistency. 
+ bin_manager_.Clear(); + std::unique_lock lock(mutex_); + weight_files_.clear(); + metadata_.clear(); +} + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/ov_shared_context.h b/onnxruntime/core/providers/openvino/ov_shared_context.h new file mode 100644 index 0000000000000..c893b64442fa4 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_shared_context.h @@ -0,0 +1,159 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/runtime/core.hpp" +#include "ov_bin_manager.h" +#include "weak_singleton.h" + +namespace onnxruntime { +namespace openvino_ep { + +class SharedContext : public std::enable_shared_from_this { + public: + explicit SharedContext(std::filesystem::path bin_path); + SharedContext() : SharedContext("") {} + + struct Metadata { + struct Value { + struct { + std::filesystem::path location{}; + size_t data_offset{0}; + size_t size{0}; + } serialized; + + std::shared_ptr tensor; + }; + using Map = std::unordered_map; + }; + + bool IsSharedWeight(const std::string& name) const { + std::shared_lock lock(mutex_); + return metadata_.contains(name); + } + + void AddExternalWeight(const std::string& name, size_t offset, size_t size, const std::filesystem::path& location) { + Metadata::Value value; + value.serialized.data_offset = offset; + value.serialized.size = size; + value.serialized.location = location; + std::unique_lock lock(mutex_); + metadata_[name] = std::move(value); + } + + Metadata::Map GetMetadataCopy() const { + std::shared_lock lock(mutex_); + return metadata_; + } + + void SetSharedWeightsOnInferRequest(ov::InferRequest& ir, const std::filesystem::path& model_dir); + + void AddNativeBlob(const std::string& name, const ov::CompiledModel& compiled_model) { + bin_manager_.AddNativeBlob(name, compiled_model); + } + + ov::Tensor GetNativeBlob(const 
std::string& blob_name) { + return bin_manager_.GetNativeBlob(blob_name); + } + + std::unique_ptr GetNativeBlobAsStream(const std::string& blob_name) { + return bin_manager_.GetNativeBlobAsStream(blob_name); + } + + void Serialize(std::ostream& stream); + void Deserialize(std::istream& stream); + void Serialize(); + void Deserialize(); + + void Clear(); + + std::filesystem::path GetBinPath() const { + return bin_manager_.GetExternalBinPath(); + } + + static std::filesystem::path GetBinPathForModel(const std::filesystem::path& model_path) { + return BinManager::GetBinPathForModel(model_path); + } + + private: + struct WeightsFile { + ORT_DISALLOW_COPY_AND_ASSIGNMENT(WeightsFile); + WeightsFile() = delete; + virtual ~WeightsFile() = default; + explicit WeightsFile(const std::filesystem::path& filename); + void LoadWeights(size_t file_offset, void* data, size_t size); + void* TryGetOrCreateDeviceMapping(std::optional& remote_context); + size_t Size() const { return weights_size_; } + + private: + std::ifstream file_; + std::filesystem::path file_path_; + size_t weights_size_; + struct MappingContainer { + void* ptr_{nullptr}; + ov::Tensor tensor_; + }; + std::map imported_device_tensors_; + }; + + void LoadTensorFromFile( + Metadata::Value& value, + const std::filesystem::path& model_dir, + std::optional& remote_context, + const ov::element::Type& element_type, + const ov::Shape& dimensions); + + mutable std::shared_mutex mutex_; + std::filesystem::path bin_path_; + BinManager bin_manager_; + std::unordered_map> weight_files_; + Metadata::Map metadata_; +}; + +class SharedContextManager : public WeakSingleton { + public: + std::shared_ptr GetOrCreateActiveSharedContext(const std::filesystem::path& model_path) { + std::lock_guard lock(mutex_); + if (active_context_) { + return active_context_; + } + auto [it, inserted] = contexts_.try_emplace(model_path, nullptr); + if (inserted) { + it->second = std::make_shared(model_path); + } + active_context_ = it->second; + 
return it->second; + } + + std::shared_ptr GetOrCreateSharedContext(const std::filesystem::path& model_path) { + std::lock_guard lock(mutex_); + auto [it, inserted] = contexts_.try_emplace(model_path, nullptr); + if (inserted) { + it->second = std::make_shared(model_path); + } + return it->second; + } + + void ClearActiveSharedContext() { + std::lock_guard lock(mutex_); + active_context_ = nullptr; + } + + private: + mutable std::mutex mutex_; + std::unordered_map> contexts_; + std::shared_ptr active_context_; +}; + +} // namespace openvino_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index e010851f22e50..2e5bb7b8c86be 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -704,7 +704,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, bool enable_ovep_weight_sharing, bool enable_ovep_qdq_optimizer, /*out*/ std::unique_ptr& model, - /*out*/ sw& shared_weights) { + /*out*/ SharedContext& shared_context) { // NOTE: This function is a re-implementation of GraphViewerToProto() in core/graph/graph_proto_serializer.cc // with the following differences: // - Uses onnxruntime::Graph APIs instead of onnx::GraphProto APIs. 
@@ -824,34 +824,28 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, }); // initialize map for creating metadata for initilizers with external weights - auto& metadata = shared_weights.metadata; - - const auto& insert_metadata = [&metadata](const ONNX_NAMESPACE::TensorProto& proto) { - sw::Metadata::Map::key_type key{proto.name()}; - sw::Metadata::Map::mapped_type value{}; + const auto& add_shared_weight = [&shared_context](const ONNX_NAMESPACE::TensorProto& proto) { using mutable_proto_t = ONNX_NAMESPACE::TensorProto*; auto& mutable_proto = *const_cast(&proto); auto* entry_protos = mutable_proto.mutable_external_data(); + + std::string location = ""; + size_t data_offset = 0, size = 0; for (int i = 0; i < entry_protos->size(); i++) { auto& string_entry_proto{entry_protos->at(i)}; const auto& pb_key{*(string_entry_proto.mutable_key())}; const auto& pb_value{*(string_entry_proto.mutable_value())}; if (pb_key == "location") { - value.location = pb_value; + location = pb_value; } else if (pb_key == "offset") { - value.data_offset = std::stoul(pb_value); + data_offset = std::stoul(pb_value); } else if (pb_key == "length") { - value.size = std::stoul(pb_value); + size = std::stoul(pb_value); } } - value.element_type = proto.data_type(); - value.dimensions.resize(proto.dims_size()); - for (uint32_t index = 0; auto& dim : value.dimensions) { - dim = proto.dims()[index++]; - } - metadata.emplace(key, std::move(value)); + shared_context.AddExternalWeight(proto.name(), data_offset, size, location); }; // Handle initializers @@ -871,7 +865,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, if (!is_quant_param) { // This is actual weight data - so to convert to input for weight sharing - insert_metadata(initializer_tensor); + add_shared_weight(initializer_tensor); AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, name); } else { // This is a quantization parameter - keep as initializer even if external @@ -912,7 +906,7 
@@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, if (!init_with_data && utils::HasExternalData(initializer_tensor) && enable_ovep_weight_sharing) { - insert_metadata(initializer_tensor); + add_shared_weight(initializer_tensor); // Add initializer as input if it has external data AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, input->Name()); diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h index 53de0fd019311..e649b3ec71943 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace openvino_ep { -using sw = SharedContext::SharedWeights; +class SharedContext; // Creates a new model without the DQ/Q operators in the src graph as per pre-defined rulesets Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, @@ -18,8 +18,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, bool enable_ovep_weight_sharing, bool enable_ovep_qdq_optimizer, /*out*/ std::unique_ptr& model, - /*out*/ sw& shared_weights); + /*out*/ SharedContext& shared_context); -bool dumpMetaDataMapToBinary(const sw::Metadata::Map& shared_weights, const std::string& filename); } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/weak_singleton.h b/onnxruntime/core/providers/openvino/weak_singleton.h new file mode 100644 index 0000000000000..949ed1b527c60 --- /dev/null +++ b/onnxruntime/core/providers/openvino/weak_singleton.h @@ -0,0 +1,40 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License + +#pragma once + +#include +#include +#include "core/common/common.h" + +namespace onnxruntime { +namespace openvino_ep { + +template +class WeakSingleton { + public: + static std::shared_ptr Get() { + static 
std::weak_ptr instance; + static std::mutex mutex; + + auto ptr = instance.lock(); + if (!ptr) { + std::lock_guard lock(mutex); + // ensure another thread didn't create an instance while this thread was waiting + ptr = instance.lock(); + if (!ptr) { + ptr = std::make_shared(); + instance = ptr; + } + } + return ptr; + } + + protected: + WeakSingleton() = default; + virtual ~WeakSingleton() = default; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeakSingleton); +}; + +} // namespace openvino_ep +} // namespace onnxruntime