From 405d30fcef13d5954fdc63fb3354c4ee8825abe3 Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Mon, 13 Mar 2023 23:36:26 -0700 Subject: [PATCH 1/2] Fix CPU memory leak due to external weights not getting unmapped when using non-CPU EP. --- onnxruntime/core/framework/callback.h | 3 ++- onnxruntime/core/framework/session_state_utils.cc | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/callback.h b/onnxruntime/core/framework/callback.h index 88f14d7a5a53..11ec92d18949 100644 --- a/onnxruntime/core/framework/callback.h +++ b/onnxruntime/core/framework/callback.h @@ -40,6 +40,7 @@ struct OrtCallbackInvoker { */ class ScopedOrtCallbackInvoker { public: + ScopedOrtCallbackInvoker() {} explicit ScopedOrtCallbackInvoker(OrtCallback callback) noexcept : callback_(callback) {} @@ -69,6 +70,6 @@ class ScopedOrtCallbackInvoker { private: ORT_DISALLOW_COPY_AND_ASSIGNMENT(ScopedOrtCallbackInvoker); - OrtCallback callback_; + OrtCallback callback_{}; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 4f4b71a8acdd..624f9162258c 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -161,9 +161,11 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st } OrtCallback ext_data_deleter; + ScopedOrtCallbackInvoker scoped_ort_callback_invoker; if (utils::HasExternalData(tensor_proto)) { ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor, ext_data_deleter)); + scoped_ort_callback_invoker = ScopedOrtCallbackInvoker(ext_data_deleter); } else { ORT_RETURN_IF_ERROR(utils::TensorProtoToTensor(env, proto_path.c_str(), tensor_proto, *p_deserialize_tensor)); } From 13d782879844a5908c919c2dc14aad45c9090a1b Mon Sep 17 00:00:00 2001 From: Pranav Sharma Date: Wed, 15 Mar 2023 16:47:28 -0700 Subject: [PATCH 2/2] Use optional --- onnxruntime/core/framework/callback.h | 3 +-- onnxruntime/core/framework/session_state_utils.cc | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/framework/callback.h b/onnxruntime/core/framework/callback.h index 11ec92d18949..88f14d7a5a53 100644 --- a/onnxruntime/core/framework/callback.h +++ b/onnxruntime/core/framework/callback.h @@ -40,7 +40,6 @@ struct OrtCallbackInvoker { */ class ScopedOrtCallbackInvoker { public: - ScopedOrtCallbackInvoker() {} explicit ScopedOrtCallbackInvoker(OrtCallback callback) noexcept : callback_(callback) {} @@ -70,6 +69,6 @@ class ScopedOrtCallbackInvoker { private: ORT_DISALLOW_COPY_AND_ASSIGNMENT(ScopedOrtCallbackInvoker); - OrtCallback callback_{}; + OrtCallback callback_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 624f9162258c..0a1720b5b7aa 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -161,7 +161,7 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st } OrtCallback ext_data_deleter; - ScopedOrtCallbackInvoker scoped_ort_callback_invoker; + std::optional scoped_ort_callback_invoker; if (utils::HasExternalData(tensor_proto)) { ORT_RETURN_IF_ERROR(ExtDataTensorProtoToTensor(env, proto_path, tensor_proto, *p_deserialize_tensor, ext_data_deleter));