Skip to content
Permalink
Browse files

Merge pull request #12232 from CookiePLMonster/threadpool-improvements

Threadpool improvements
  • Loading branch information...
hrydgard committed Aug 12, 2019
2 parents 8643360 + 86a887d commit 825dac3e905cf6b576eac745c7d7a809de7c8fa0
Showing with 113 additions and 66 deletions.
  1. +1 −0 Common/Common.vcxproj
  2. +42 −0 Common/MakeUnique.h
  3. +5 −7 Common/ThreadPools.cpp
  4. +2 −2 Common/ThreadPools.h
  5. +44 −38 ext/native/thread/threadpool.cpp
  6. +19 −19 ext/native/thread/threadpool.h
@@ -431,6 +431,7 @@
<ClInclude Include="KeyMap.h" />
<ClInclude Include="Log.h" />
<ClInclude Include="LogManager.h" />
<ClInclude Include="MakeUnique.h" />
<ClInclude Include="MathUtil.h" />
<ClInclude Include="MemArena.h" />
<ClInclude Include="MemoryUtil.h" />
@@ -0,0 +1,42 @@
#pragma once

#include <memory>
#include <type_traits>

// make_unique for toolchains whose standard library predates C++14.
//
// When the standard library already provides std::make_unique it is simply
// re-exported into the global namespace. Defining a second, global template
// next to it would make calls such as make_unique<Foo>(std::string("x"))
// ambiguous: argument-dependent lookup pulls in std::make_unique whenever an
// argument's type lives in namespace std, and both candidates have identical
// signatures. The using-declaration names the very same function, so there is
// only ever one candidate.
#if defined(__cpp_lib_make_unique) || \
	(defined(__cplusplus) && __cplusplus >= 201402L) || \
	(defined(_MSVC_LANG) && _MSVC_LANG >= 201402L)

using std::make_unique;

#else

// Single object: constructs a T from the perfectly forwarded arguments and
// returns the owning pointer.
template<class T, class... Args,
	typename std::enable_if<!std::is_array<T>::value, int>::type = 0>
std::unique_ptr<T> make_unique(Args&&... args)
{
	return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}

// Array of unknown bound (T = U[]): allocates `size` value-initialized
// elements, i.e. zeroed for scalar element types.
template<class T,
	typename std::enable_if<std::is_array<T>::value && std::extent<T>::value == 0, int>::type = 0>
std::unique_ptr<T> make_unique(std::size_t size)
{
	return std::unique_ptr<T>(new typename std::remove_extent<T>::type[size]());
}

// Array of known bound (T = U[N]): explicitly rejected, mirroring the
// standard library's make_unique.
template<class T, class... Args,
	typename std::enable_if<std::extent<T>::value != 0, int>::type = 0>
void make_unique(Args&&... args) = delete;

#endif


// Single object, default-initialized (`new T` without parentheses): class
// types run their default constructor, while scalar types are left with an
// indeterminate value and must be written before being read.
template<class T,
	typename std::enable_if<!std::is_array<T>::value, int>::type = 0>
std::unique_ptr<T> make_unique_default_init()
{
	std::unique_ptr<T> owner(new T);
	return owner;
}

// Array of unknown bound (T = U[]): allocates `size` default-initialized
// elements; unlike make_unique<U[]> the storage is not zeroed.
template<class T,
	typename std::enable_if<std::is_array<T>::value && std::extent<T>::value == 0, int>::type = 0>
std::unique_ptr<T> make_unique_default_init(std::size_t size)
{
	typedef typename std::remove_extent<T>::type Element;
	Element* storage = new Element[size];
	return std::unique_ptr<T>(storage);
}

// Array of known bound (T = U[N]): not supported, any such call is rejected
// at compile time.
template<class T, class... Ignored,
	typename std::enable_if<std::extent<T>::value != 0, int>::type = 0>
void make_unique_default_init(Ignored&&...) = delete;
@@ -1,18 +1,16 @@
#include "ThreadPools.h"

#include "../Core/Config.h"
#include "Common/MakeUnique.h"

std::shared_ptr<ThreadPool> GlobalThreadPool::pool;
bool GlobalThreadPool::initialized = false;
std::unique_ptr<ThreadPool> GlobalThreadPool::pool;
std::once_flag GlobalThreadPool::init_flag;

void GlobalThreadPool::Loop(const std::function<void(int,int)>& loop, int lower, int upper) {
Inititialize();
std::call_once(init_flag, Inititialize);
pool->ParallelLoop(loop, lower, upper);
}

void GlobalThreadPool::Inititialize() {
if(!initialized) {
pool = std::make_shared<ThreadPool>(g_Config.iNumWorkerThreads);
initialized = true;
}
pool = make_unique<ThreadPool>(g_Config.iNumWorkerThreads);
}
@@ -9,7 +9,7 @@ class GlobalThreadPool {
static void Loop(const std::function<void(int,int)>& loop, int lower, int upper);

private:
static std::shared_ptr<ThreadPool> pool;
static bool initialized;
static std::unique_ptr<ThreadPool> pool;
static std::once_flag init_flag;
static void Inititialize();
};
@@ -1,61 +1,61 @@
#include "base/logging.h"
#include "thread/threadpool.h"
#include "thread/threadutil.h"
#include "Common/MakeUnique.h"

///////////////////////////// WorkerThread

WorkerThread::WorkerThread() : active(true), started(false) {
thread.reset(new std::thread(std::bind(&WorkerThread::WorkFunc, this)));
while(!started) { };
// Stops the worker: clears `active` and signals the condition variable while
// holding `mutex`, then joins the thread if it is joinable.
WorkerThread::~WorkerThread() {
	std::unique_lock<std::mutex> lock(mutex);
	active = false;
	signal.notify_one();
	// Release the mutex before joining; the join happens outside the lock,
	// exactly as with the scoped guard this replaces.
	lock.unlock();

	if (!thread.joinable()) {
		return;
	}
	thread.join();
}

WorkerThread::~WorkerThread() {
mutex.lock();
active = false;
signal.notify_one();
mutex.unlock();
thread->join();
void WorkerThread::StartUp() {
thread = std::thread(std::bind(&WorkerThread::WorkFunc, this));
}

void WorkerThread::Process(const std::function<void()>& work) {
mutex.lock();
work_ = work;
void WorkerThread::Process(std::function<void()> work) {
std::lock_guard<std::mutex> guard(mutex);
work_ = std::move(work);
jobsTarget = jobsDone + 1;
signal.notify_one();
mutex.unlock();
}

void WorkerThread::WaitForCompletion() {
std::unique_lock<std::mutex> guard(doneMutex);
if (jobsDone < jobsTarget) {
while (jobsDone < jobsTarget) {
done.wait(guard);
}
}

void WorkerThread::WorkFunc() {
setCurrentThreadName("Worker");
std::unique_lock<std::mutex> guard(mutex);
started = true;
while (active) {
signal.wait(guard);
// 'active == false' is one of the conditions for signaling,
// do not "optimize" it
while (active && jobsTarget <= jobsDone) {
signal.wait(guard);
}
if (active) {
work_();
doneMutex.lock();
done.notify_one();

std::lock_guard<std::mutex> doneGuard(doneMutex);
jobsDone++;
doneMutex.unlock();
done.notify_one();
}
}
}

LoopWorkerThread::LoopWorkerThread() : WorkerThread(true) {
thread.reset(new std::thread(std::bind(&LoopWorkerThread::WorkFunc, this)));
while (!started) { };
}

void LoopWorkerThread::Process(const std::function<void(int, int)> &work, int start, int end) {
void LoopWorkerThread::Process(std::function<void(int, int)> work, int start, int end) {
std::lock_guard<std::mutex> guard(mutex);
work_ = work;
work_ = std::move(work);
start_ = start;
end_ = end;
jobsTarget = jobsDone + 1;
@@ -65,22 +65,25 @@ void LoopWorkerThread::Process(const std::function<void(int, int)> &work, int st
void LoopWorkerThread::WorkFunc() {
setCurrentThreadName("LoopWorker");
std::unique_lock<std::mutex> guard(mutex);
started = true;
while (active) {
signal.wait(guard);
// 'active == false' is one of the conditions for signaling,
// do not "optimize" it
while (active && jobsTarget <= jobsDone) {
signal.wait(guard);
}
if (active) {
work_(start_, end_);
doneMutex.lock();
done.notify_one();

std::lock_guard<std::mutex> doneGuard(doneMutex);
jobsDone++;
doneMutex.unlock();
done.notify_one();
}
}
}

///////////////////////////// ThreadPool

ThreadPool::ThreadPool(int numThreads) : workersStarted(false) {
ThreadPool::ThreadPool(int numThreads) {
if (numThreads <= 0) {
numThreads_ = 1;
ILOG("ThreadPool: Bad number of threads %i", numThreads);
@@ -94,8 +97,11 @@ ThreadPool::ThreadPool(int numThreads) : workersStarted(false) {

void ThreadPool::StartWorkers() {
if (!workersStarted) {
for(int i = 0; i < numThreads_; ++i) {
workers.push_back(std::make_shared<LoopWorkerThread>());
workers.reserve(numThreads_ - 1);
for(int i = 0; i < numThreads_ - 1; ++i) { // create one less worker thread as the thread calling ParallelLoop will also do work
auto workerPtr = make_unique<LoopWorkerThread>();
workerPtr->StartUp();
workers.push_back(std::move(workerPtr));
}
workersStarted = true;
}
@@ -111,14 +117,14 @@ void ThreadPool::ParallelLoop(const std::function<void(int,int)> &loop, int lowe
// but doesn't matter since all our loops are power of 2
int chunk = range / numThreads_;
int s = lower;
for (int i = 0; i < numThreads_ - 1; ++i) {
workers[i]->Process(loop, s, s+chunk);
for (auto& worker : workers) {
worker->Process(loop, s, s+chunk);
s+=chunk;
}
// This is the final chunk.
loop(s, upper);
for (int i = 0; i < numThreads_ - 1; ++i) {
workers[i]->WaitForCompletion();
for (auto& worker : workers) {
worker->WaitForCompletion();
}
} else {
loop(lower, upper);
@@ -12,41 +12,41 @@
// Only handles a single item of work at a time.
class WorkerThread {
public:
WorkerThread();
WorkerThread() = default;
virtual ~WorkerThread();

void StartUp();

// submit a new work item
void Process(const std::function<void()>& work);
void Process(std::function<void()> work);
// wait for a submitted work item to be completed
void WaitForCompletion();

protected:
WorkerThread(bool ignored) : active(true), started(false) {}
virtual void WorkFunc();

std::unique_ptr<std::thread> thread; // the worker thread
std::thread thread; // the worker thread
std::condition_variable signal; // used to signal new work
std::condition_variable done; // used to signal work completion
std::mutex mutex, doneMutex; // associated with each respective condition variable
volatile bool active, started;
bool active = true;
int jobsDone = 0;
int jobsTarget = 0;
private:
virtual void WorkFunc();

std::function<void()> work_; // the work to be done by this thread

WorkerThread(const WorkerThread& other); // prevent copies
void operator =(const WorkerThread &other);
WorkerThread(const WorkerThread& other) = delete; // prevent copies
void operator =(const WorkerThread &other) = delete;
};

class LoopWorkerThread : public WorkerThread {
class LoopWorkerThread final : public WorkerThread {
public:
LoopWorkerThread();
void Process(const std::function<void(int, int)> &work, int start, int end);

protected:
virtual void WorkFunc();
LoopWorkerThread() = default;
void Process(std::function<void(int, int)> work, int start, int end);

private:
virtual void WorkFunc() override;

int start_;
int end_;
std::function<void(int, int)> work_; // the work to be done by this thread
@@ -65,13 +65,13 @@ class ThreadPool {

private:
int numThreads_;
std::vector<std::shared_ptr<LoopWorkerThread>> workers;
std::vector<std::unique_ptr<LoopWorkerThread>> workers;
std::mutex mutex; // used to sequentialize loop execution

bool workersStarted;
bool workersStarted = false;
void StartWorkers();

ThreadPool(const ThreadPool& other); // prevent copies
void operator =(const ThreadPool &other);
ThreadPool(const ThreadPool& other) = delete; // prevent copies
void operator =(const ThreadPool &other) = delete;
};

0 comments on commit 825dac3

Please sign in to comment.
You can’t perform that action at this time.