From 951bec15eec1b941081aacae6c4a30a931e01506 Mon Sep 17 00:00:00 2001 From: Lang Hames Date: Mon, 17 Nov 2025 19:28:49 +1100 Subject: [PATCH] [orc-rt] Introduce Task and TaskDispatcher APIs and implementations. Introduces the Task and TaskDispatcher interfaces (TaskDispatcher.h), ThreadPoolTaskDispatcher implementation (ThreadPoolTaskDispatch.h), and updates Session to include a TaskDispatcher instance that can be used to run tasks. TaskDispatcher's introduction is motivated by the need to handle calls to JIT'd code initiated from the controller process: Incoming calls will be wrapped in Tasks and dispatched. Session shutdown will wait on TaskDispatcher shutdown, ensuring that all Tasks are run or destroyed prior to the Session being destroyed. --- orc-rt/include/CMakeLists.txt | 2 + orc-rt/include/orc-rt/Session.h | 23 +++- orc-rt/include/orc-rt/TaskDispatcher.h | 64 ++++++++++ .../include/orc-rt/ThreadPoolTaskDispatcher.h | 48 ++++++++ orc-rt/lib/executor/CMakeLists.txt | 2 + orc-rt/lib/executor/Session.cpp | 58 ++++++--- orc-rt/lib/executor/TaskDispatcher.cpp | 20 ++++ .../lib/executor/ThreadPoolTaskDispatcher.cpp | 70 +++++++++++ orc-rt/unittests/CMakeLists.txt | 1 + orc-rt/unittests/SessionTest.cpp | 94 ++++++++++++++- .../ThreadPoolTaskDispatcherTest.cpp | 110 ++++++++++++++++++ 11 files changed, 467 insertions(+), 25 deletions(-) create mode 100644 orc-rt/include/orc-rt/TaskDispatcher.h create mode 100644 orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h create mode 100644 orc-rt/lib/executor/TaskDispatcher.cpp create mode 100644 orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp create mode 100644 orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp diff --git a/orc-rt/include/CMakeLists.txt b/orc-rt/include/CMakeLists.txt index 8ac8a126dd012..35c45e236c023 100644 --- a/orc-rt/include/CMakeLists.txt +++ b/orc-rt/include/CMakeLists.txt @@ -22,6 +22,8 @@ set(ORC_RT_HEADERS orc-rt/SPSMemoryFlags.h orc-rt/SPSWrapperFunction.h orc-rt/SPSWrapperFunctionBuffer.h + orc-rt/TaskDispatcher.h + orc-rt/ThreadPoolTaskDispatcher.h orc-rt/WrapperFunction.h orc-rt/bind.h orc-rt/bit.h diff --git a/orc-rt/include/orc-rt/Session.h b/orc-rt/include/orc-rt/Session.h index 78bd92bb0d0c8..367cdb9a97b62 100644 --- a/orc-rt/include/orc-rt/Session.h +++ b/orc-rt/include/orc-rt/Session.h @@ -15,10 +15,12 @@ #include "orc-rt/Error.h" #include "orc-rt/ResourceManager.h" +#include "orc-rt/TaskDispatcher.h" #include "orc-rt/move_only_function.h" #include "orc-rt-c/CoreTypes.h" +#include #include #include #include @@ -39,7 +41,10 @@ class Session { /// /// Note that entry into the reporter is not synchronized: it may be /// called from multiple threads concurrently. - Session(ErrorReporterFn ReportError) : ReportError(std::move(ReportError)) {} + Session(std::unique_ptr Dispatcher, + ErrorReporterFn ReportError) + : Dispatcher(std::move(Dispatcher)), ReportError(std::move(ReportError)) { + } // Sessions are not copyable or moveable. Session(const Session &) = delete; @@ -49,6 +54,9 @@ class Session { ~Session(); + /// Dispatch a task using the Session's TaskDispatcher. + void dispatch(std::unique_ptr T) { Dispatcher->dispatch(std::move(T)); } + /// Report an error via the ErrorReporter function. void reportError(Error Err) { ReportError(std::move(Err)); } @@ -67,12 +75,21 @@ class Session { } private: - void shutdownNext(OnShutdownCompleteFn OnShutdownComplete, Error Err, + void shutdownNext(Error Err, std::vector> RemainingRMs); - std::mutex M; + void shutdownComplete(); + + std::unique_ptr Dispatcher; ErrorReporterFn ReportError; + + enum class SessionState { Running, ShuttingDown, Shutdown }; + + std::mutex M; + SessionState State = SessionState::Running; + std::condition_variable StateCV; std::vector> ResourceMgrs; + std::vector ShutdownCallbacks; }; inline orc_rt_SessionRef wrap(Session *S) noexcept { diff --git a/orc-rt/include/orc-rt/TaskDispatcher.h b/orc-rt/include/orc-rt/TaskDispatcher.h new file mode 100644 index 0000000000000..f49d537ef25f7 --- /dev/null +++ b/orc-rt/include/orc-rt/TaskDispatcher.h @@ -0,0 +1,64 @@ +//===----------- TaskDispatcher.h - Task dispatch utils ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Task and TaskDispatcher classes. +// +//===----------------------------------------------------------------------===// + +#ifndef ORC_RT_TASKDISPATCHER_H +#define ORC_RT_TASKDISPATCHER_H + +#include "orc-rt/RTTI.h" + +#include +#include + +namespace orc_rt { + +/// Represents an abstract task to be run. +class Task : public RTTIExtends { +public: + virtual ~Task(); + virtual void run() = 0; +}; + +/// Base class for generic tasks. +class GenericTask : public RTTIExtends {}; + +/// Generic task implementation. +template class GenericTaskImpl : public GenericTask { +public: + GenericTaskImpl(FnT &&Fn) : Fn(std::forward(Fn)) {} + void run() override { Fn(); } + +private: + FnT Fn; +}; + +/// Create a generic task from a function object. +template std::unique_ptr makeGenericTask(FnT &&Fn) { + return std::make_unique>>( + std::forward(Fn)); +} + +/// Abstract base for classes that dispatch Tasks. +class TaskDispatcher { +public: + virtual ~TaskDispatcher(); + + /// Run the given task. + virtual void dispatch(std::unique_ptr T) = 0; + + /// Called by Session. Should cause further dispatches to be rejected, and + /// wait until all previously dispatched tasks have completed. + virtual void shutdown() = 0; +}; + +} // End namespace orc_rt + +#endif // ORC_RT_TASKDISPATCHER_H diff --git a/orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h b/orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h new file mode 100644 index 0000000000000..227c3500a1321 --- /dev/null +++ b/orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h @@ -0,0 +1,48 @@ +//===--- ThreadPoolTaskDispatcher.h - Run tasks in thread pool --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// ThreadPoolTaskDispatcher implementation. +// +//===----------------------------------------------------------------------===// + +#ifndef ORC_RT_THREADPOOLTASKDISPATCHER_H +#define ORC_RT_THREADPOOLTASKDISPATCHER_H + +#include "orc-rt/TaskDispatcher.h" + +#include +#include +#include +#include + +namespace orc_rt { + +/// Thread-pool based TaskDispatcher. +/// +/// Will spawn NumThreads threads to run dispatched Tasks. +class ThreadPoolTaskDispatcher : public TaskDispatcher { +public: + ThreadPoolTaskDispatcher(size_t NumThreads); + ~ThreadPoolTaskDispatcher() override; + void dispatch(std::unique_ptr T) override; + void shutdown() override; + +private: + void taskLoop(); + + std::vector Threads; + + std::mutex M; + bool AcceptingTasks = true; + std::condition_variable CV; + std::vector> PendingTasks; +}; + +} // End namespace orc_rt + +#endif // ORC_RT_THREADPOOLTASKDISPATCHER_H diff --git a/orc-rt/lib/executor/CMakeLists.txt b/orc-rt/lib/executor/CMakeLists.txt index 9750d8e048f74..58b5ec2189d43 100644 --- a/orc-rt/lib/executor/CMakeLists.txt +++ b/orc-rt/lib/executor/CMakeLists.txt @@ -4,6 +4,8 @@ set(files RTTI.cpp Session.cpp SimpleNativeMemoryMap.cpp + TaskDispatcher.cpp + ThreadPoolTaskDispatcher.cpp ) add_library(orc-rt-executor STATIC ${files}) diff --git a/orc-rt/lib/executor/Session.cpp b/orc-rt/lib/executor/Session.cpp index 599bc8705f397..fafa13b1cbb08 100644 --- a/orc-rt/lib/executor/Session.cpp +++ b/orc-rt/lib/executor/Session.cpp @@ -12,8 +12,6 @@ #include "orc-rt/Session.h" -#include - namespace orc_rt { Session::~Session() { waitForShutdown(); } @@ -23,38 +21,62 @@ void Session::shutdown(OnShutdownCompleteFn OnShutdownComplete) { { std::scoped_lock Lock(M); + ShutdownCallbacks.push_back(std::move(OnShutdownComplete)); + + // If somebody else has already called shutdown then there's nothing further + // for us to do here. + if (State >= SessionState::ShuttingDown) + return; + + State = SessionState::ShuttingDown; std::swap(ResourceMgrs, ToShutdown); } - shutdownNext(std::move(OnShutdownComplete), Error::success(), - std::move(ToShutdown)); + shutdownNext(Error::success(), std::move(ToShutdown)); } void Session::waitForShutdown() { - std::promise P; - auto F = P.get_future(); - - shutdown([P = std::move(P)]() mutable { P.set_value(); }); - - F.wait(); + shutdown([]() {}); + std::unique_lock Lock(M); + StateCV.wait(Lock, [&]() { return State == SessionState::Shutdown; }); } void Session::shutdownNext( - OnShutdownCompleteFn OnComplete, Error Err, - std::vector> RemainingRMs) { + Error Err, std::vector> RemainingRMs) { if (Err) reportError(std::move(Err)); if (RemainingRMs.empty()) - return OnComplete(); + return shutdownComplete(); auto NextRM = std::move(RemainingRMs.back()); RemainingRMs.pop_back(); - NextRM->shutdown([this, RemainingRMs = std::move(RemainingRMs), - OnComplete = std::move(OnComplete)](Error Err) mutable { - shutdownNext(std::move(OnComplete), std::move(Err), - std::move(RemainingRMs)); - }); + NextRM->shutdown( + [this, RemainingRMs = std::move(RemainingRMs)](Error Err) mutable { + shutdownNext(std::move(Err), std::move(RemainingRMs)); + }); +} + +void Session::shutdownComplete() { + + std::unique_ptr TmpDispatcher; + std::vector TmpShutdownCallbacks; + { + std::lock_guard Lock(M); + TmpDispatcher = std::move(Dispatcher); + TmpShutdownCallbacks = std::move(ShutdownCallbacks); + } + + TmpDispatcher->shutdown(); + + for (auto &OnShutdownComplete : TmpShutdownCallbacks) + OnShutdownComplete(); + + { + std::lock_guard Lock(M); + State = SessionState::Shutdown; + } + StateCV.notify_all(); } } // namespace orc_rt diff --git a/orc-rt/lib/executor/TaskDispatcher.cpp b/orc-rt/lib/executor/TaskDispatcher.cpp new file mode 100644 index 0000000000000..5f34627fb5150 --- /dev/null +++ b/orc-rt/lib/executor/TaskDispatcher.cpp @@ -0,0 +1,20 @@ +//===- TaskDispatch.cpp ---------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains the implementation of APIs in the orc-rt/TaskDispatch.h header. +// +//===----------------------------------------------------------------------===// + +#include "orc-rt/TaskDispatcher.h" + +namespace orc_rt { + +Task::~Task() = default; +TaskDispatcher::~TaskDispatcher() = default; + +} // namespace orc_rt diff --git a/orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp b/orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp new file mode 100644 index 0000000000000..d6d301302220d --- /dev/null +++ b/orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp @@ -0,0 +1,70 @@ +//===- ThreadPoolTaskDispatch.cpp -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains the implementation of APIs in the orc-rt/ThreadPoolTaskDispatch.h +// header. +// +//===----------------------------------------------------------------------===// + +#include "orc-rt/ThreadPoolTaskDispatcher.h" + +#include + +namespace orc_rt { + +ThreadPoolTaskDispatcher::~ThreadPoolTaskDispatcher() { + assert(!AcceptingTasks && "shutdown was not run"); +} + +ThreadPoolTaskDispatcher::ThreadPoolTaskDispatcher(size_t NumThreads) { + Threads.reserve(NumThreads); + for (size_t I = 0; I < NumThreads; ++I) + Threads.emplace_back([this]() { taskLoop(); }); +} + +void ThreadPoolTaskDispatcher::dispatch(std::unique_ptr T) { + { + std::scoped_lock Lock(M); + if (!AcceptingTasks) + return; + PendingTasks.push_back(std::move(T)); + } + CV.notify_one(); +} + +void ThreadPoolTaskDispatcher::shutdown() { + { + std::scoped_lock Lock(M); + assert(AcceptingTasks && "ThreadPoolTaskDispatcher already shut down?"); + AcceptingTasks = false; + } + CV.notify_all(); + for (auto &Thread : Threads) + Thread.join(); +} + +void ThreadPoolTaskDispatcher::taskLoop() { + while (true) { + std::unique_ptr T; + { + std::unique_lock Lock(M); + CV.wait(Lock, + [this]() { return !PendingTasks.empty() || !AcceptingTasks; }); + + if (!AcceptingTasks && PendingTasks.empty()) + return; + + T = std::move(PendingTasks.back()); + PendingTasks.pop_back(); + } + + T->run(); + } +} + +} // namespace orc_rt diff --git a/orc-rt/unittests/CMakeLists.txt b/orc-rt/unittests/CMakeLists.txt index 7b943e8039449..c43ec17b54de3 100644 --- a/orc-rt/unittests/CMakeLists.txt +++ b/orc-rt/unittests/CMakeLists.txt @@ -31,6 +31,7 @@ add_orc_rt_unittest(CoreTests SPSMemoryFlagsTest.cpp SPSWrapperFunctionTest.cpp SPSWrapperFunctionBufferTest.cpp + ThreadPoolTaskDispatcherTest.cpp WrapperFunctionBufferTest.cpp bind-test.cpp bit-test.cpp diff --git a/orc-rt/unittests/SessionTest.cpp b/orc-rt/unittests/SessionTest.cpp index 7e6084484e227..85b82e65744b0 100644 --- a/orc-rt/unittests/SessionTest.cpp +++ b/orc-rt/unittests/SessionTest.cpp @@ -11,11 +11,17 @@ //===----------------------------------------------------------------------===// #include "orc-rt/Session.h" +#include "orc-rt/ThreadPoolTaskDispatcher.h" + #include "gmock/gmock.h" #include "gtest/gtest.h" +#include +#include #include +#include + using namespace orc_rt; using ::testing::Eq; using ::testing::Optional; @@ -49,17 +55,47 @@ class MockResourceManager : public ResourceManager { move_only_function GenResult; }; +class NoDispatcher : public TaskDispatcher { +public: + void dispatch(std::unique_ptr T) override { + assert(false && "strictly no dispatching!"); + } + void shutdown() override {} +}; + +class EnqueueingDispatcher : public TaskDispatcher { +public: + using OnShutdownRunFn = move_only_function; + EnqueueingDispatcher(std::deque> &Tasks, + OnShutdownRunFn OnShutdownRun = {}) + : Tasks(Tasks), OnShutdownRun(std::move(OnShutdownRun)) {} + void dispatch(std::unique_ptr T) override { + Tasks.push_back(std::move(T)); + } + void shutdown() override { + if (OnShutdownRun) + OnShutdownRun(); + } + +private: + std::deque> &Tasks; + OnShutdownRunFn OnShutdownRun; +}; + // Non-overloaded version of cantFail: allows easy construction of // move_only_functionss. static void noErrors(Error Err) { cantFail(std::move(Err)); } -TEST(SessionTest, TrivialConstructionAndDestruction) { Session S(noErrors); } +TEST(SessionTest, TrivialConstructionAndDestruction) { + Session S(std::make_unique(), noErrors); +} TEST(SessionTest, ReportError) { Error E = Error::success(); cantFail(std::move(E)); // Force error into checked state. - Session S([&](Error Err) { E = std::move(Err); }); + Session S(std::make_unique(), + [&](Error Err) { E = std::move(Err); }); S.reportError(make_error("foo")); if (E) @@ -68,13 +104,27 @@ TEST(SessionTest, ReportError) { ADD_FAILURE() << "Missing error value"; } +TEST(SessionTest, DispatchTask) { + int X = 0; + std::deque> Tasks; + Session S(std::make_unique(Tasks), noErrors); + + EXPECT_EQ(Tasks.size(), 0U); + S.dispatch(makeGenericTask([&]() { ++X; })); + EXPECT_EQ(Tasks.size(), 1U); + auto T = std::move(Tasks.front()); + Tasks.pop_front(); + T->run(); + EXPECT_EQ(X, 1); +} + TEST(SessionTest, SingleResourceManager) { size_t OpIdx = 0; std::optional DetachOpIdx; std::optional ShutdownOpIdx; { - Session S(noErrors); + Session S(std::make_unique(), noErrors); S.addResourceManager(std::make_unique( DetachOpIdx, ShutdownOpIdx, OpIdx)); } @@ -90,7 +140,7 @@ TEST(SessionTest, MultipleResourceManagers) { std::optional ShutdownOpIdx[3]; { - Session S(noErrors); + Session S(std::make_unique(), noErrors); for (size_t I = 0; I != 3; ++I) S.addResourceManager(std::make_unique( DetachOpIdx[I], ShutdownOpIdx[I], OpIdx)); @@ -103,3 +153,39 @@ TEST(SessionTest, MultipleResourceManagers) { EXPECT_THAT(ShutdownOpIdx[I], Optional(Eq(2 - I))); } } + +TEST(SessionTest, ExpectedShutdownSequence) { + // Check that Session shutdown results in... + // 1. ResourceManagers being shut down. + // 2. The TaskDispatcher being shut down. + // 3. A call to OnShutdownComplete. + + size_t OpIdx = 0; + std::optional DetachOpIdx; + std::optional ShutdownOpIdx; + + bool DispatcherShutDown = false; + bool SessionShutdownComplete = false; + std::deque> Tasks; + Session S(std::make_unique( + Tasks, + [&]() { + std::cerr << "Running dispatcher shutdown.\n"; + EXPECT_TRUE(ShutdownOpIdx); + EXPECT_EQ(*ShutdownOpIdx, 0); + EXPECT_FALSE(SessionShutdownComplete); + DispatcherShutDown = true; + }), + noErrors); + S.addResourceManager( + std::make_unique(DetachOpIdx, ShutdownOpIdx, OpIdx)); + + S.shutdown([&]() { + EXPECT_TRUE(DispatcherShutDown); + std::cerr << "Running shutdown callback.\n"; + SessionShutdownComplete = true; + }); + S.waitForShutdown(); + + EXPECT_TRUE(SessionShutdownComplete); +} diff --git a/orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp b/orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp new file mode 100644 index 0000000000000..02cca94a494ff --- /dev/null +++ b/orc-rt/unittests/ThreadPoolTaskDispatcherTest.cpp @@ -0,0 +1,110 @@ +//===-- ThreadPoolTaskDispatcherTest.cpp ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "orc-rt/ThreadPoolTaskDispatcher.h" +#include "gtest/gtest.h" + +#include +#include +#include +#include + +using namespace orc_rt; + +namespace { + +TEST(ThreadPoolTaskDispatcherTest, NoTasks) { + // Check that immediate shutdown works as expected. + ThreadPoolTaskDispatcher Dispatcher(1); + Dispatcher.shutdown(); +} + +TEST(ThreadPoolTaskDispatcherTest, BasicTaskExecution) { + // Smoke test: Check that we can run a single task on a single-threaded pool. + ThreadPoolTaskDispatcher Dispatcher(1); + std::atomic TaskRan = false; + + Dispatcher.dispatch(makeGenericTask([&]() { TaskRan = true; })); + + Dispatcher.shutdown(); + + EXPECT_TRUE(TaskRan); +} + +TEST(ThreadPoolTaskDispatcherTest, SingleThreadMultipleTasks) { + // Check that multiple tasks in a single threaded pool run as expected. + ThreadPoolTaskDispatcher Dispatcher(1); + size_t NumTasksToRun = 10; + std::atomic TasksRun = 0; + + for (size_t I = 0; I != NumTasksToRun; ++I) + Dispatcher.dispatch(makeGenericTask([&]() { ++TasksRun; })); + + Dispatcher.shutdown(); + + EXPECT_EQ(TasksRun, NumTasksToRun); +} + +TEST(ThreadPoolTaskDispatcherTest, ConcurrentTasks) { + // Check that tasks are run concurrently when multiple workers are available. + // Adds two tasks that communicate a value back and forth using futures. + // Neither task should be able to complete without the other having started. + ThreadPoolTaskDispatcher Dispatcher(2); + + std::promise PInit; + std::future FInit = PInit.get_future(); + std::promise P1; + std::future F1 = P1.get_future(); + std::promise P2; + std::future F2 = P2.get_future(); + std::promise PResult; + std::future FResult = PResult.get_future(); + + // Task A gets the initial value, sends it via P1, waits for response on F2. + Dispatcher.dispatch(makeGenericTask([&]() { + P1.set_value(FInit.get()); + PResult.set_value(F2.get()); + })); + + // Task B gets value from F1, sends it back on P2. + Dispatcher.dispatch(makeGenericTask([&]() { P2.set_value(F1.get()); })); + + int ExpectedValue = 42; + PInit.set_value(ExpectedValue); + + Dispatcher.shutdown(); + + EXPECT_EQ(FResult.get(), ExpectedValue); +} + +TEST(ThreadPoolTaskDispatcherTest, TasksRejectedAfterShutdown) { + class TaskToReject : public Task { + public: + TaskToReject(bool &BodyRun, bool &DestructorRun) + : BodyRun(BodyRun), DestructorRun(DestructorRun) {} + ~TaskToReject() { DestructorRun = true; } + void run() override { BodyRun = true; } + + private: + bool &BodyRun; + bool &DestructorRun; + }; + + ThreadPoolTaskDispatcher Dispatcher(1); + Dispatcher.shutdown(); + + bool BodyRun = false; + bool DestructorRun = false; + + Dispatcher.dispatch(std::make_unique(BodyRun, DestructorRun)); + + EXPECT_FALSE(BodyRun); + EXPECT_TRUE(DestructorRun); +} + +} // end anonymous namespace