Skip to content

Commit

Permalink
Improve EventCount used by the non-blocking threadpool.
Browse files Browse the repository at this point in the history
The current algorithm requires threads to commit/cancel waiting in the order
in which they called Prewait. Spinning caused by that serialization can consume
lots of CPU time on some workloads. Restructure the algorithm to not
require that serialization and remove spin waits from Commit/CancelWait.
Note: this reduces the maximum number of threads from 2^16 to 2^14 to leave
more space for the ABA counter (which is now 22 bits).
Implementation details are explained in comments.
  • Loading branch information
rmlarsen committed Feb 22, 2019
1 parent a4cff5a commit 01da8ca
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 93 deletions.
187 changes: 102 additions & 85 deletions unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ namespace Eigen {
// if (predicate)
// return act();
// EventCount::Waiter& w = waiters[my_index];
// ec.Prewait(&w);
// if (!ec.Prewait(&w))
// return act();
// if (predicate) {
// ec.CancelWait(&w);
// return act();
Expand Down Expand Up @@ -50,78 +51,78 @@ class EventCount {
public:
class Waiter;

EventCount(MaxSizeVector<Waiter>& waiters) : waiters_(waiters) {
EventCount(MaxSizeVector<Waiter>& waiters)
: state_(kStackMask), waiters_(waiters) {
eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
// Initialize epoch to something close to overflow to test overflow.
state_ = kStackMask | (kEpochMask - kEpochInc * waiters.size() * 2);
}

~EventCount() {
// Ensure there are no waiters.
eigen_plain_assert((state_.load() & (kStackMask | kWaiterMask)) == kStackMask);
eigen_plain_assert(state_.load() == kStackMask);
}

// Prewait prepares for waiting.
// After calling this function the thread must re-check the wait predicate
// and call either CancelWait or CommitWait passing the same Waiter object.
void Prewait(Waiter* w) {
w->epoch = state_.fetch_add(kWaiterInc, std::memory_order_relaxed);
std::atomic_thread_fence(std::memory_order_seq_cst);
// If Prewait returns true, the thread must re-check the wait predicate
// and then call either CancelWait or CommitWait.
// Otherwise, the thread should assume the predicate may be true
// and must not call CancelWait/CommitWait (there was a concurrent Notify call).
bool Prewait() {
uint64_t state = state_.load(std::memory_order_relaxed);
for (;;) {
CheckState(state);
uint64_t newstate = state + kWaiterInc;
if ((state & kSignalMask) != 0) {
// Consume the signal and cancel waiting.
newstate -= kSignalInc + kWaiterInc;
}
CheckState(newstate);
if (state_.compare_exchange_weak(state, newstate,
std::memory_order_seq_cst))
return (state & kSignalMask) == 0;
}
}

// CommitWait commits waiting.
// CommitWait commits waiting after Prewait.
void CommitWait(Waiter* w) {
eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
w->state = Waiter::kNotSignaled;
// Modification epoch of this waiter.
uint64_t epoch =
(w->epoch & kEpochMask) +
(((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
const uint64_t me = (w - &waiters_[0]) | w->epoch;
uint64_t state = state_.load(std::memory_order_seq_cst);
for (;;) {
if (int64_t((state & kEpochMask) - epoch) < 0) {
// The preceding waiter has not decided on its fate. Wait until it
// calls either CancelWait or CommitWait, or is notified.
EIGEN_THREAD_YIELD();
state = state_.load(std::memory_order_seq_cst);
continue;
CheckState(state, true);
uint64_t newstate;
if ((state & kSignalMask) != 0) {
// Consume the signal and return immediately.
newstate = state - kWaiterInc - kSignalInc;
} else {
// Remove this thread from pre-wait counter and add to the waiter stack.
newstate = ((state & kWaiterMask) - kWaiterInc) | me;
w->next.store(state & (kStackMask | kEpochMask),
std::memory_order_relaxed);
}
// We've already been notified.
if (int64_t((state & kEpochMask) - epoch) > 0) return;
// Remove this thread from prewait counter and add it to the waiter list.
eigen_plain_assert((state & kWaiterMask) != 0);
uint64_t newstate = state - kWaiterInc + kEpochInc;
newstate = (newstate & ~kStackMask) | (w - &waiters_[0]);
if ((state & kStackMask) == kStackMask)
w->next.store(nullptr, std::memory_order_relaxed);
else
w->next.store(&waiters_[state & kStackMask], std::memory_order_relaxed);
CheckState(newstate);
if (state_.compare_exchange_weak(state, newstate,
std::memory_order_release))
break;
std::memory_order_acq_rel)) {
if ((state & kSignalMask) == 0) {
w->epoch += kEpochInc;
Park(w);
}
return;
}
}
Park(w);
}

// CancelWait cancels effects of the previous Prewait call.
void CancelWait(Waiter* w) {
uint64_t epoch =
(w->epoch & kEpochMask) +
(((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
void CancelWait() {
uint64_t state = state_.load(std::memory_order_relaxed);
for (;;) {
if (int64_t((state & kEpochMask) - epoch) < 0) {
// The preceding waiter has not decided on its fate. Wait until it
// calls either CancelWait or CommitWait, or is notified.
EIGEN_THREAD_YIELD();
state = state_.load(std::memory_order_relaxed);
continue;
}
// We've already been notified.
if (int64_t((state & kEpochMask) - epoch) > 0) return;
// Remove this thread from prewait counter.
eigen_plain_assert((state & kWaiterMask) != 0);
if (state_.compare_exchange_weak(state, state - kWaiterInc + kEpochInc,
std::memory_order_relaxed))
CheckState(state, true);
uint64_t newstate = state - kWaiterInc;
// Also take away a signal if any.
if ((state & kSignalMask) != 0) newstate -= kSignalInc;
CheckState(newstate);
if (state_.compare_exchange_weak(state, newstate,
std::memory_order_acq_rel))
return;
}
}
Expand All @@ -132,35 +133,33 @@ class EventCount {
std::atomic_thread_fence(std::memory_order_seq_cst);
uint64_t state = state_.load(std::memory_order_acquire);
for (;;) {
CheckState(state);
const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
const uint64_t signals = (state & kSignalMask) >> kSignalShift;
// Easy case: no waiters.
if ((state & kStackMask) == kStackMask && (state & kWaiterMask) == 0)
return;
uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
if ((state & kStackMask) == kStackMask && waiters == signals) return;
uint64_t newstate;
if (notifyAll) {
// Reset prewait counter and empty wait list.
newstate = (state & kEpochMask) + (kEpochInc * waiters) + kStackMask;
} else if (waiters) {
// Empty wait stack and set signal to number of pre-wait threads.
newstate =
(state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
} else if (signals < waiters) {
// There is a thread in pre-wait state, unblock it.
newstate = state + kEpochInc - kWaiterInc;
newstate = state + kSignalInc;
} else {
// Pop a waiter from list and unpark it.
Waiter* w = &waiters_[state & kStackMask];
Waiter* wnext = w->next.load(std::memory_order_relaxed);
uint64_t next = kStackMask;
if (wnext != nullptr) next = wnext - &waiters_[0];
// Note: we don't add kEpochInc here. ABA problem on the lock-free stack
// can't happen because a waiter is re-pushed onto the stack only after
// it was in the pre-wait state which inevitably leads to epoch
// increment.
newstate = (state & kEpochMask) + next;
uint64_t next = w->next.load(std::memory_order_relaxed);
newstate = (state & (kWaiterMask | kSignalMask)) | next;
}
CheckState(newstate);
if (state_.compare_exchange_weak(state, newstate,
std::memory_order_acquire)) {
if (!notifyAll && waiters) return; // unblocked pre-wait thread
std::memory_order_acq_rel)) {
if (!notifyAll && (signals < waiters))
return; // unblocked pre-wait thread
if ((state & kStackMask) == kStackMask) return;
Waiter* w = &waiters_[state & kStackMask];
if (!notifyAll) w->next.store(nullptr, std::memory_order_relaxed);
if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
Unpark(w);
return;
}
Expand All @@ -171,11 +170,11 @@ class EventCount {
friend class EventCount;
// Align to 128 byte boundary to prevent false sharing with other Waiter
// objects in the same vector.
EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<Waiter*> next;
EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
std::mutex mu;
std::condition_variable cv;
uint64_t epoch;
unsigned state;
uint64_t epoch = 0;
unsigned state = kNotSignaled;
enum {
kNotSignaled,
kWaiting,
Expand All @@ -185,23 +184,41 @@ class EventCount {

private:
// State_ layout:
// - low kStackBits is a stack of waiters committed wait.
// - low kWaiterBits is a stack of waiters that have committed to waiting
// (indexes in waiters_ array are used as stack elements,
// kStackMask means empty stack).
// - next kWaiterBits is count of waiters in prewait state.
// - next kEpochBits is modification counter.
static const uint64_t kStackBits = 16;
static const uint64_t kStackMask = (1ull << kStackBits) - 1;
static const uint64_t kWaiterBits = 16;
static const uint64_t kWaiterShift = 16;
// - next kWaiterBits is count of pending signals.
// - remaining bits are ABA counter for the stack.
// (stored in Waiter node and incremented on push).
static const uint64_t kWaiterBits = 14;
static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
static const uint64_t kWaiterShift = kWaiterBits;
static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
<< kWaiterShift;
static const uint64_t kWaiterInc = 1ull << kWaiterBits;
static const uint64_t kEpochBits = 32;
static const uint64_t kEpochShift = 32;
static const uint64_t kWaiterInc = 1ull << kWaiterShift;
static const uint64_t kSignalShift = 2 * kWaiterBits;
static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
<< kSignalShift;
static const uint64_t kSignalInc = 1ull << kSignalShift;
static const uint64_t kEpochShift = 3 * kWaiterBits;
static const uint64_t kEpochBits = 64 - kEpochShift;
static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
static const uint64_t kEpochInc = 1ull << kEpochShift;
std::atomic<uint64_t> state_;
MaxSizeVector<Waiter>& waiters_;

// Sanity-checks a packed state_ word.
// Decodes the pre-wait counter and the pending-signal counter from `state`
// and asserts that:
//  - the number of pending signals does not exceed the number of pre-wait
//    threads;
//  - the pre-wait counter is strictly below the all-ones value
//    (1 << kWaiterBits) - 1;
//  - if `waiter` is true, the pre-wait counter is non-zero. CommitWait and
//    CancelWait pass true because the calling thread is itself still counted
//    in the pre-wait counter at that point.
static void CheckState(uint64_t state, bool waiter = false) {
// Compile-time guard on the width of the ABA counter.
static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
const uint64_t signals = (state & kSignalMask) >> kSignalShift;
eigen_plain_assert(waiters >= signals);
eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
eigen_plain_assert(!waiter || waiters > 0);
// NOTE(review): presumably these casts silence unused-variable warnings when
// eigen_plain_assert compiles to nothing — confirm against the macro.
(void)waiters;
(void)signals;
}

void Park(Waiter* w) {
std::unique_lock<std::mutex> lock(w->mu);
while (w->state != Waiter::kSignaled) {
Expand All @@ -210,10 +227,10 @@ class EventCount {
}
}

void Unpark(Waiter* waiters) {
Waiter* next = nullptr;
for (Waiter* w = waiters; w; w = next) {
next = w->next.load(std::memory_order_relaxed);
void Unpark(Waiter* w) {
for (Waiter* next; w; w = next) {
uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
next = wnext == kStackMask ? nullptr : &waiters_[wnext];
unsigned state;
{
std::unique_lock<std::mutex> lock(w->mu);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,11 +374,11 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
eigen_plain_assert(!t->f);
// We already did best-effort emptiness check in Steal, so prepare for
// blocking.
ec_.Prewait(waiter);
if (!ec_.Prewait()) return true;
// Now do a reliable emptiness check.
int victim = NonEmptyQueueIndex();
if (victim != -1) {
ec_.CancelWait(waiter);
ec_.CancelWait();
if (cancelled_) {
return false;
} else {
Expand All @@ -392,7 +392,7 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
blocked_++;
// TODO is blocked_ required to be unsigned?
if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
ec_.CancelWait(waiter);
ec_.CancelWait();
// Almost done, but need to re-check queues.
// Consider that all queues are empty and all worker threads are preempted
// right after incrementing blocked_ above. Now a free-standing thread
Expand Down
10 changes: 5 additions & 5 deletions unsupported/test/cxx11_eventcount.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ static void test_basic_eventcount()
EventCount ec(waiters);
EventCount::Waiter& w = waiters[0];
ec.Notify(false);
ec.Prewait(&w);
VERIFY(ec.Prewait());
ec.Notify(true);
ec.CommitWait(&w);
ec.Prewait(&w);
ec.CancelWait(&w);
VERIFY(ec.Prewait());
ec.CancelWait();
}

// Fake bounded counter-based queue.
Expand Down Expand Up @@ -112,7 +112,7 @@ static void test_stress_eventcount()
unsigned idx = rand_reentrant(&rnd) % kQueues;
if (queues[idx].Pop()) continue;
j--;
ec.Prewait(&w);
if (!ec.Prewait()) continue;
bool empty = true;
for (int q = 0; q < kQueues; q++) {
if (!queues[q].Empty()) {
Expand All @@ -121,7 +121,7 @@ static void test_stress_eventcount()
}
}
if (!empty) {
ec.CancelWait(&w);
ec.CancelWait();
continue;
}
ec.CommitWait(&w);
Expand Down

0 comments on commit 01da8ca

Please sign in to comment.