diff --git a/llvm/include/llvm/Support/ThreadPool.h b/llvm/include/llvm/Support/ThreadPool.h index c20efc7396b79..d3276a18dc2c6 100644 --- a/llvm/include/llvm/Support/ThreadPool.h +++ b/llvm/include/llvm/Support/ThreadPool.h @@ -14,6 +14,7 @@ #define LLVM_SUPPORT_THREADPOOL_H #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/FunctionExtras.h" #include "llvm/Config/llvm-config.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Jobserver.h" @@ -51,7 +52,7 @@ class ThreadPoolTaskGroup; class LLVM_ABI ThreadPoolInterface { /// The actual method to enqueue a task to be defined by the concrete /// implementation. - virtual void asyncEnqueue(std::function Task, + virtual void asyncEnqueue(llvm::unique_function Task, ThreadPoolTaskGroup *Group) = 0; public: @@ -95,22 +96,22 @@ class LLVM_ABI ThreadPoolInterface { /// used to wait for the task to finish and is *non-blocking* on destruction. template auto async(Func &&F) -> std::shared_future { - return asyncImpl(std::function(std::forward(F)), - nullptr); + return asyncImpl( + llvm::unique_function(std::forward(F)), nullptr); } template auto async(ThreadPoolTaskGroup &Group, Func &&F) -> std::shared_future { - return asyncImpl(std::function(std::forward(F)), - &Group); + return asyncImpl( + llvm::unique_function(std::forward(F)), &Group); } private: /// Asynchronous submission of a task to the pool. The returned future can be /// used to wait for the task to finish and is *non-blocking* on destruction. template - std::shared_future asyncImpl(std::function Task, + std::shared_future asyncImpl(llvm::unique_function Task, ThreadPoolTaskGroup *Group) { auto Future = std::async(std::launch::deferred, std::move(Task)).share(); asyncEnqueue([Future]() { Future.wait(); }, Group); @@ -160,7 +161,7 @@ class LLVM_ABI StdThreadPool : public ThreadPoolInterface { /// Asynchronous submission of a task to the pool. The returned future can be /// used to wait for the task to finish and is *non-blocking* on destruction. - void asyncEnqueue(std::function Task, + void asyncEnqueue(llvm::unique_function Task, ThreadPoolTaskGroup *Group) override { int requestedThreads; { @@ -189,7 +190,8 @@ class LLVM_ABI StdThreadPool : public ThreadPoolInterface { mutable llvm::sys::RWMutex ThreadsLock; /// Tasks waiting for execution in the pool. - std::deque, ThreadPoolTaskGroup *>> Tasks; + std::deque, ThreadPoolTaskGroup *>> + Tasks; /// Locking and signaling for accessing the Tasks queue. std::mutex QueueLock; @@ -239,13 +241,14 @@ class LLVM_ABI SingleThreadExecutor : public ThreadPoolInterface { private: /// Asynchronous submission of a task to the pool. The returned future can be /// used to wait for the task to finish and is *non-blocking* on destruction. - void asyncEnqueue(std::function Task, + void asyncEnqueue(llvm::unique_function Task, ThreadPoolTaskGroup *Group) override { Tasks.emplace_back(std::make_pair(std::move(Task), Group)); } /// Tasks waiting for execution in the pool. - std::deque, ThreadPoolTaskGroup *>> Tasks; + std::deque, ThreadPoolTaskGroup *>> + Tasks; }; #if LLVM_ENABLE_THREADS diff --git a/llvm/lib/Support/ThreadPool.cpp b/llvm/lib/Support/ThreadPool.cpp index 69602688cf3fd..4779e673cc055 100644 --- a/llvm/lib/Support/ThreadPool.cpp +++ b/llvm/lib/Support/ThreadPool.cpp @@ -73,7 +73,7 @@ static LLVM_THREAD_LOCAL std::vector // WaitingForGroup == nullptr means all tasks regardless of their group. void StdThreadPool::processTasks(ThreadPoolTaskGroup *WaitingForGroup) { while (true) { - std::function Task; + llvm::unique_function Task; ThreadPoolTaskGroup *GroupOfTask; { std::unique_lock LockGuard(QueueLock); @@ -189,7 +189,7 @@ void StdThreadPool::processTasksWithJobserver() { // While we hold a job slot, process tasks from the internal queue. while (true) { - std::function Task; + llvm::unique_function Task; ThreadPoolTaskGroup *GroupOfTask = nullptr; { diff --git a/llvm/unittests/Support/ThreadPool.cpp b/llvm/unittests/Support/ThreadPool.cpp index aa7f8744e1417..b5268c82e4199 100644 --- a/llvm/unittests/Support/ThreadPool.cpp +++ b/llvm/unittests/Support/ThreadPool.cpp @@ -183,6 +183,20 @@ TYPED_TEST(ThreadPoolTest, Async) { ASSERT_EQ(2, i.load()); } +TYPED_TEST(ThreadPoolTest, AsyncMoveOnly) { + CHECK_UNSUPPORTED(); + DefaultThreadPool Pool; + std::promise p; + std::future f = p.get_future(); + Pool.async([this, p = std::move(p)]() mutable { + this->waitForMainThread(); + p.set_value(42); + }); + this->setMainThreadReady(); + Pool.wait(); + ASSERT_EQ(42, f.get()); +} + TYPED_TEST(ThreadPoolTest, GetFuture) { CHECK_UNSUPPORTED(); DefaultThreadPool Pool(hardware_concurrency(2));