diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h index dd75aeea48424c..82bd79390e6108 100644 --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -417,7 +417,7 @@ class DiagnosticEngine { /// The handler type for MLIR diagnostics. This function takes a diagnostic as /// input, and returns success if the handler has fully processed this /// diagnostic. Returns failure otherwise. - using HandlerTy = std::function; + using HandlerTy = llvm::unique_function; /// A handle to a specific registered handler object. using HandlerID = uint64_t; @@ -427,7 +427,7 @@ class DiagnosticEngine { /// handlers will process diagnostics first. This function returns a unique /// identifier for the registered handler, which can be used to unregister /// this handler at a later time. - HandlerID registerHandler(const HandlerTy &handler); + HandlerID registerHandler(HandlerTy handler); /// Set the diagnostic handler with a function that returns void. This is a /// convenient wrapper for handlers that always completely process the given diff --git a/mlir/lib/CAPI/IR/Diagnostics.cpp b/mlir/lib/CAPI/IR/Diagnostics.cpp index 40639c7ba31b71..4a13ae57611d8e 100644 --- a/mlir/lib/CAPI/IR/Diagnostics.cpp +++ b/mlir/lib/CAPI/IR/Diagnostics.cpp @@ -59,11 +59,12 @@ MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler( assert(handler && "unexpected null diagnostic handler"); if (deleteUserData == nullptr) deleteUserData = deleteUserDataNoop; - std::shared_ptr sharedUserData(userData, deleteUserData); DiagnosticEngine::HandlerID id = unwrap(context)->getDiagEngine().registerHandler( - [handler, sharedUserData](Diagnostic &diagnostic) { - return unwrap(handler(wrap(diagnostic), sharedUserData.get())); + [handler, + ownedUserData = std::unique_ptr( + userData, deleteUserData)](Diagnostic &diagnostic) { + return unwrap(handler(wrap(diagnostic), ownedUserData.get())); }); return static_cast(id); } diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 975f6943fc673d..98e4d40461c36b 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -272,10 +272,10 @@ DiagnosticEngine::~DiagnosticEngine() = default; /// Register a new handler for diagnostics to the engine. This function returns /// a unique identifier for the registered handler, which can be used to /// unregister this handler at a later time. -auto DiagnosticEngine::registerHandler(const HandlerTy &handler) -> HandlerID { +auto DiagnosticEngine::registerHandler(HandlerTy handler) -> HandlerID { llvm::sys::SmartScopedLock lock(impl->mutex); auto uniqueID = impl->uniqueHandlerId++; - impl->handlers.insert({uniqueID, handler}); + impl->handlers.insert({uniqueID, std::move(handler)}); return uniqueID; }