Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions orc-rt/include/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ set(ORC_RT_HEADERS
orc-rt/SPSMemoryFlags.h
orc-rt/SPSWrapperFunction.h
orc-rt/SPSWrapperFunctionBuffer.h
orc-rt/TaskDispatcher.h
orc-rt/ThreadPoolTaskDispatcher.h
orc-rt/WrapperFunction.h
orc-rt/bind.h
orc-rt/bit.h
Expand Down
23 changes: 20 additions & 3 deletions orc-rt/include/orc-rt/Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@

#include "orc-rt/Error.h"
#include "orc-rt/ResourceManager.h"
#include "orc-rt/TaskDispatcher.h"
#include "orc-rt/move_only_function.h"

#include "orc-rt-c/CoreTypes.h"

#include <condition_variable>
#include <memory>
#include <mutex>
#include <vector>
Expand All @@ -39,7 +41,10 @@ class Session {
///
/// Note that entry into the reporter is not synchronized: it may be
/// called from multiple threads concurrently.
Session(ErrorReporterFn ReportError) : ReportError(std::move(ReportError)) {}
Session(std::unique_ptr<TaskDispatcher> Dispatcher,
ErrorReporterFn ReportError)
: Dispatcher(std::move(Dispatcher)), ReportError(std::move(ReportError)) {
}

// Sessions are not copyable or moveable.
Session(const Session &) = delete;
Expand All @@ -49,6 +54,9 @@ class Session {

~Session();

/// Dispatch a task using the Session's TaskDispatcher.
void dispatch(std::unique_ptr<Task> T) { Dispatcher->dispatch(std::move(T)); }

/// Report an error via the ErrorReporter function.
void reportError(Error Err) { ReportError(std::move(Err)); }

Expand All @@ -67,12 +75,21 @@ class Session {
}

private:
void shutdownNext(OnShutdownCompleteFn OnShutdownComplete, Error Err,
void shutdownNext(Error Err,
std::vector<std::unique_ptr<ResourceManager>> RemainingRMs);

std::mutex M;
void shutdownComplete();

std::unique_ptr<TaskDispatcher> Dispatcher;
ErrorReporterFn ReportError;

enum class SessionState { Running, ShuttingDown, Shutdown };

std::mutex M;
SessionState State = SessionState::Running;
std::condition_variable StateCV;
std::vector<std::unique_ptr<ResourceManager>> ResourceMgrs;
std::vector<OnShutdownCompleteFn> ShutdownCallbacks;
};

inline orc_rt_SessionRef wrap(Session *S) noexcept {
Expand Down
64 changes: 64 additions & 0 deletions orc-rt/include/orc-rt/TaskDispatcher.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//===----------- TaskDispatcher.h - Task dispatch utils ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Task and TaskDispatcher classes.
//
//===----------------------------------------------------------------------===//

#ifndef ORC_RT_TASKDISPATCHER_H
#define ORC_RT_TASKDISPATCHER_H

#include "orc-rt/RTTI.h"

#include <memory>
#include <utility>

namespace orc_rt {

/// Represents an abstract task to be run.
class Task : public RTTIExtends<Task, RTTIRoot> {
public:
virtual ~Task();
virtual void run() = 0;
};

/// Base class for generic tasks.
class GenericTask : public RTTIExtends<GenericTask, Task> {};

/// Generic task implementation.
template <typename FnT> class GenericTaskImpl : public GenericTask {
public:
GenericTaskImpl(FnT &&Fn) : Fn(std::forward<FnT>(Fn)) {}
void run() override { Fn(); }

private:
FnT Fn;
};

/// Create a generic task from a function object.
template <typename FnT> std::unique_ptr<GenericTask> makeGenericTask(FnT &&Fn) {
return std::make_unique<GenericTaskImpl<std::decay_t<FnT>>>(
std::forward<FnT>(Fn));
}

/// Abstract base for classes that dispatch Tasks.
class TaskDispatcher {
public:
virtual ~TaskDispatcher();

/// Run the given task.
virtual void dispatch(std::unique_ptr<Task> T) = 0;

/// Called by Session. Should cause further dispatches to be rejected, and
/// wait until all previously dispatched tasks have completed.
virtual void shutdown() = 0;
};

} // End namespace orc_rt

#endif // ORC_RT_TASKDISPATCHER_H
48 changes: 48 additions & 0 deletions orc-rt/include/orc-rt/ThreadPoolTaskDispatcher.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//===--- ThreadPoolTaskDispatcher.h - Run tasks in thread pool --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// ThreadPoolTaskDispatcher implementation.
//
//===----------------------------------------------------------------------===//

#ifndef ORC_RT_THREADPOOLTASKDISPATCHER_H
#define ORC_RT_THREADPOOLTASKDISPATCHER_H

#include "orc-rt/TaskDispatcher.h"

#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

namespace orc_rt {

/// Thread-pool based TaskDispatcher.
///
/// Will spawn NumThreads threads to run dispatched Tasks.
class ThreadPoolTaskDispatcher : public TaskDispatcher {
public:
ThreadPoolTaskDispatcher(size_t NumThreads);
~ThreadPoolTaskDispatcher() override;
void dispatch(std::unique_ptr<Task> T) override;
void shutdown() override;

private:
void taskLoop();

std::vector<std::thread> Threads;

std::mutex M;
bool AcceptingTasks = true;
std::condition_variable CV;
std::vector<std::unique_ptr<Task>> PendingTasks;
};

} // End namespace orc_rt

#endif // ORC_RT_THREADPOOLTASKDISPATCHER_H
2 changes: 2 additions & 0 deletions orc-rt/lib/executor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ set(files
RTTI.cpp
Session.cpp
SimpleNativeMemoryMap.cpp
TaskDispatcher.cpp
ThreadPoolTaskDispatcher.cpp
)

add_library(orc-rt-executor STATIC ${files})
Expand Down
58 changes: 40 additions & 18 deletions orc-rt/lib/executor/Session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

#include "orc-rt/Session.h"

#include <future>

namespace orc_rt {

Session::~Session() { waitForShutdown(); }
Expand All @@ -23,38 +21,62 @@ void Session::shutdown(OnShutdownCompleteFn OnShutdownComplete) {

{
std::scoped_lock<std::mutex> Lock(M);
ShutdownCallbacks.push_back(std::move(OnShutdownComplete));

// If somebody else has already called shutdown then there's nothing further
// for us to do here.
if (State >= SessionState::ShuttingDown)
return;

State = SessionState::ShuttingDown;
std::swap(ResourceMgrs, ToShutdown);
}

shutdownNext(std::move(OnShutdownComplete), Error::success(),
std::move(ToShutdown));
shutdownNext(Error::success(), std::move(ToShutdown));
}

void Session::waitForShutdown() {
std::promise<void> P;
auto F = P.get_future();

shutdown([P = std::move(P)]() mutable { P.set_value(); });

F.wait();
shutdown([]() {});
std::unique_lock<std::mutex> Lock(M);
StateCV.wait(Lock, [&]() { return State == SessionState::Shutdown; });
}

void Session::shutdownNext(
OnShutdownCompleteFn OnComplete, Error Err,
std::vector<std::unique_ptr<ResourceManager>> RemainingRMs) {
Error Err, std::vector<std::unique_ptr<ResourceManager>> RemainingRMs) {
if (Err)
reportError(std::move(Err));

if (RemainingRMs.empty())
return OnComplete();
return shutdownComplete();

auto NextRM = std::move(RemainingRMs.back());
RemainingRMs.pop_back();
NextRM->shutdown([this, RemainingRMs = std::move(RemainingRMs),
OnComplete = std::move(OnComplete)](Error Err) mutable {
shutdownNext(std::move(OnComplete), std::move(Err),
std::move(RemainingRMs));
});
NextRM->shutdown(
[this, RemainingRMs = std::move(RemainingRMs)](Error Err) mutable {
shutdownNext(std::move(Err), std::move(RemainingRMs));
});
}

void Session::shutdownComplete() {

std::unique_ptr<TaskDispatcher> TmpDispatcher;
std::vector<OnShutdownCompleteFn> TmpShutdownCallbacks;
{
std::lock_guard<std::mutex> Lock(M);
TmpDispatcher = std::move(Dispatcher);
TmpShutdownCallbacks = std::move(ShutdownCallbacks);
}

TmpDispatcher->shutdown();

for (auto &OnShutdownComplete : TmpShutdownCallbacks)
OnShutdownComplete();

{
std::lock_guard<std::mutex> Lock(M);
State = SessionState::Shutdown;
}
StateCV.notify_all();
}

} // namespace orc_rt
20 changes: 20 additions & 0 deletions orc-rt/lib/executor/TaskDispatcher.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- TaskDispatch.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the implementation of APIs in the orc-rt/TaskDispatch.h header.
//
//===----------------------------------------------------------------------===//

#include "orc-rt/TaskDispatcher.h"

namespace orc_rt {

Task::~Task() = default;
TaskDispatcher::~TaskDispatcher() = default;

} // namespace orc_rt
70 changes: 70 additions & 0 deletions orc-rt/lib/executor/ThreadPoolTaskDispatcher.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
//===- ThreadPoolTaskDispatch.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the implementation of APIs in the orc-rt/ThreadPoolTaskDispatch.h
// header.
//
//===----------------------------------------------------------------------===//

#include "orc-rt/ThreadPoolTaskDispatcher.h"

#include <cassert>

namespace orc_rt {

ThreadPoolTaskDispatcher::~ThreadPoolTaskDispatcher() {
assert(!AcceptingTasks && "shutdown was not run");
}

ThreadPoolTaskDispatcher::ThreadPoolTaskDispatcher(size_t NumThreads) {
Threads.reserve(NumThreads);
for (size_t I = 0; I < NumThreads; ++I)
Threads.emplace_back([this]() { taskLoop(); });
}

void ThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) {
{
std::scoped_lock<std::mutex> Lock(M);
if (!AcceptingTasks)
return;
PendingTasks.push_back(std::move(T));
}
CV.notify_one();
}

void ThreadPoolTaskDispatcher::shutdown() {
{
std::scoped_lock<std::mutex> Lock(M);
assert(AcceptingTasks && "ThreadPoolTaskDispatcher already shut down?");
AcceptingTasks = false;
}
CV.notify_all();
for (auto &Thread : Threads)
Thread.join();
}

void ThreadPoolTaskDispatcher::taskLoop() {
while (true) {
std::unique_ptr<Task> T;
{
std::unique_lock<std::mutex> Lock(M);
CV.wait(Lock,
[this]() { return !PendingTasks.empty() || !AcceptingTasks; });

if (!AcceptingTasks && PendingTasks.empty())
return;

T = std::move(PendingTasks.back());
PendingTasks.pop_back();
}

T->run();
}
}

} // namespace orc_rt
1 change: 1 addition & 0 deletions orc-rt/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ add_orc_rt_unittest(CoreTests
SPSMemoryFlagsTest.cpp
SPSWrapperFunctionTest.cpp
SPSWrapperFunctionBufferTest.cpp
ThreadPoolTaskDispatcherTest.cpp
WrapperFunctionBufferTest.cpp
bind-test.cpp
bit-test.cpp
Expand Down
Loading
Loading