Fixed race condition causing workers to sleep prematurely
Fixed a race condition that could cause a worker to stop spinning prematurely. As a result, some workers would run multiple tasks serially while other workers slept.
natepaynefb authored and ben-clayton committed Nov 29, 2023
1 parent 3eb171e commit 535d491
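
The race, in miniature: the old spin loop checked work.num without holding work.mutex, so a task observed while spinning could be stolen by another worker before this worker re-acquired the lock, and the worker would then go to sleep on an empty queue. Below is a minimal standalone sketch of the hazard and of the lock-then-recheck shape the fix adopts; Queue and both functions are invented for illustration, not marl code.

    // Illustrative sketch only -- Queue and both functions are invented
    // here, not marl code.
    #include <atomic>
    #include <mutex>

    struct Queue {
      std::mutex mutex;
      std::atomic<int> num{0};  // pending task count
    };

    // Racy shape: a task observed by the unlocked read can be stolen by
    // another worker before the caller re-acquires the lock, so the caller
    // may go to sleep even though it "saw" work.
    bool sawWork(Queue& q) {
      return q.num.load() > 0;  // caller locks afterwards -- too late
    }

    // Fixed shape: lock first, then re-check, and return holding the lock
    // so the observed task cannot be re-stolen before the caller acts.
    bool sawWorkAndLock(Queue& q) {
      if (q.num.load() > 0) {
        q.mutex.lock();
        if (q.num.load() > 0) {
          return true;  // still there; caller proceeds holding q.mutex
        }
        q.mutex.unlock();  // stolen in the window; resume spinning
      }
      return false;
    }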
Showing 3 changed files with 51 additions and 15 deletions.
9 changes: 5 additions & 4 deletions include/marl/scheduler.h
@@ -417,10 +417,11 @@ class Scheduler {
   // spinForWork().
   void waitForWork() REQUIRES(work.mutex);
 
-  // spinForWork() attempts to steal work from another Worker, and keeps
+  // spinForWorkAndLock() attempts to steal work from another Worker, and keeps
   // the thread awake for a short duration. This reduces overheads of
-  // frequently putting the thread to sleep and re-waking.
-  void spinForWork();
+  // frequently putting the thread to sleep and re-waking. It locks the mutex
+  // before returning so that a stolen task cannot be re-stolen by other workers.
+  void spinForWorkAndLock() ACQUIRE(work.mutex);
 
   // enqueueFiberTimeouts() enqueues all the fibers that have finished
   // waiting.
@@ -498,7 +499,7 @@ class Scheduler {
   // The immutable configuration used to build the scheduler.
   const Config cfg;
 
-  std::array<std::atomic<int>, 8> spinningWorkers;
+  std::array<std::atomic<int>, MaxWorkerThreads> spinningWorkers;
   std::atomic<unsigned int> nextSpinningWorkerIdx = {0x8000000};
 
   std::atomic<unsigned int> nextEnqueueIndex = {0};
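
The new ACQUIRE(work.mutex) annotation tells Clang's thread safety analysis that the function returns with the mutex held, complementing the REQUIRES(work.mutex) already on waitForWork(). A sketch of how such macros typically wrap the underlying Clang attributes follows; marl ships its own definitions, so the macro spelling here is an assumption, not the repo's code.

    // Assumed spelling for illustration; see marl's own headers for the
    // real definitions. REQUIRES marks a function that must be entered with
    // the capability held; ACQUIRE marks one that returns holding it.
    #if defined(__clang__)
    #define THREAD_ANNOTATION(x) __attribute__((x))
    #else
    #define THREAD_ANNOTATION(x)  // no-op on non-Clang compilers
    #endif

    #define REQUIRES(...) THREAD_ANNOTATION(requires_capability(__VA_ARGS__))
    #define ACQUIRE(...) THREAD_ANNOTATION(acquire_capability(__VA_ARGS__))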
28 changes: 17 additions & 11 deletions src/scheduler.cpp
@@ -130,10 +130,8 @@ Scheduler::Scheduler(const Config& config)
     : cfg(setConfigDefaults(config)),
       workerThreads{},
       singleThreadedWorkers(config.allocator) {
-  for (size_t i = 0; i < spinningWorkers.size(); i++) {
-    spinningWorkers[i] = -1;
-  }
   for (int i = 0; i < cfg.workerThread.count; i++) {
+    spinningWorkers[i] = -1;
     workerThreads[i] =
         cfg.allocator->create<Worker>(this, Worker::Mode::MultiThreaded, i);
   }
@@ -170,7 +168,7 @@ void Scheduler::enqueue(Task&& task) {
   if (cfg.workerThread.count > 0) {
     while (true) {
       // Prioritize workers that have recently started spinning.
-      auto i = --nextSpinningWorkerIdx % spinningWorkers.size();
+      auto i = --nextSpinningWorkerIdx % cfg.workerThread.count;
       auto idx = spinningWorkers[i].exchange(-1);
       if (idx < 0) {
         // If a spinning worker couldn't be found, round-robin the
@@ -212,7 +210,7 @@ bool Scheduler::stealWork(Worker* thief, uint64_t from, Task& out) {
 }
 
 void Scheduler::onBeginSpinning(int workerId) {
-  auto idx = nextSpinningWorkerIdx++ % spinningWorkers.size();
+  auto idx = nextSpinningWorkerIdx++ % cfg.workerThread.count;
   spinningWorkers[idx] = workerId;
 }
 
@@ -572,7 +570,7 @@ void Scheduler::Worker::run() {
     MARL_NAME_THREAD("Thread<%.2d> Fiber<%.2d>", int(id), Fiber::current()->id);
     // This is the entry point for a multi-threaded worker.
     // Start with a regular condition-variable wait for work. This avoids
-    // starting the thread with a spinForWork().
+    // starting the thread with a spinForWorkAndLock().
     work.wait([this]() REQUIRES(work.mutex) {
       return work.num > 0 || work.waiting || shutdown;
     });
@@ -599,8 +597,7 @@ void Scheduler::Worker::waitForWork() {
   if (mode == Mode::MultiThreaded) {
     scheduler->onBeginSpinning(id);
     work.mutex.unlock();
-    spinForWork();
-    work.mutex.lock();
+    spinForWorkAndLock();
   }
 
   work.wait([this]() REQUIRES(work.mutex) {
@@ -637,7 +634,7 @@ void Scheduler::Worker::setFiberState(Fiber* fiber, Fiber::State to) const {
   fiber->state = to;
 }
 
-void Scheduler::Worker::spinForWork() {
+void Scheduler::Worker::spinForWorkAndLock() {
   TRACE("SPIN");
   Task stolen;
 
@@ -652,20 +649,29 @@
       nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
       nop(); nop(); nop(); nop(); nop(); nop(); nop(); nop();
       // clang-format on
+
      if (work.num > 0) {
-        return;
+        work.mutex.lock();
+        if (work.num > 0) {
+          return;
+        }
+        else {
+          // Our new task was stolen by another worker. Keep spinning.
+          work.mutex.unlock();
+        }
      }
     }
 
     if (scheduler->stealWork(this, rng(), stolen)) {
-      marl::lock lock(work.mutex);
+      work.mutex.lock();
       work.tasks.emplace_back(std::move(stolen));
       work.num++;
       return;
     }
 
     std::this_thread::yield();
   }
+  work.mutex.lock();
 }
 
 void Scheduler::Worker::runUntilIdle() {
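
Why returning with the lock held matters is easiest to see at the call site: waitForWork() moves straight from spinning into the condition-variable wait, and keeping work.mutex held across that handoff closes the window in which another worker could re-steal a task the spin loop just queued. Condensed from the two waitForWork() hunks above:

    // Condensed view of Worker::waitForWork() after this change: the lock
    // is dropped only while the worker is actively spinning, and
    // spinForWorkAndLock() re-acquires work.mutex before returning, so the
    // wait predicate runs before any other worker can steal back a task
    // the spin loop enqueued.
    scheduler->onBeginSpinning(id);
    work.mutex.unlock();
    spinForWorkAndLock();  // returns holding work.mutex

    work.wait([this]() REQUIRES(work.mutex) {
      return work.num > 0 || work.waiting || shutdown;
    });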
29 changes: 29 additions & 0 deletions src/scheduler_bench.cpp
@@ -48,6 +48,35 @@ BENCHMARK_DEFINE_F(Schedule, SomeWork)
 }
 BENCHMARK_REGISTER_F(Schedule, SomeWork)->Apply(Schedule::args);
 
+BENCHMARK_DEFINE_F(Schedule, MultipleForkAndJoin)(benchmark::State& state) {
+  run(state, [&](int numTasks) {
+    const int batchSize = std::max(1, Schedule::numThreads(state));
+    for (auto _ : state) {
+      marl::WaitGroup wg;
+      for (int i = 0; i < numTasks; i++) {
+        wg.add(1);
+        marl::schedule([=] {
+          // Give each task a significant amount of work so that concurrency matters.
+          // If any worker performs more than one task, it will affect the results.
+          int value = i;
+          for (int j = 0; j < 256; ++j) {
+            value = doSomeWork(value);
+          }
+          benchmark::DoNotOptimize(value);
+          wg.done();
+        });
+        // Wait for completion after every batch. This simulates the fork-and-join pattern.
+        if ((i + 1) % batchSize == 0) {
+          wg.wait();
+        }
+      }
+      wg.wait();
+    }
+  });
+}
+
+BENCHMARK_REGISTER_F(Schedule, MultipleForkAndJoin)->Apply(Schedule::args<512>);
+
 BENCHMARK_DEFINE_F(Schedule, SomeWorkWorkerAffinityOneOf)
 (benchmark::State& state) {
   marl::Scheduler::Config cfg;
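
The new benchmark exercises the fork-and-join pattern the fix targets: schedule a batch of tasks, join, and repeat, so a worker that serially runs more than one task of a batch shows up directly in the timing. For reference, a minimal standalone fork-and-join with marl's public API looks like this (following the pattern in marl's README; the benchmark fixture above handles scheduler setup itself):

    // Minimal fork-and-join sketch using marl's public API, mirroring what
    // the MultipleForkAndJoin benchmark measures in each batch.
    #include "marl/defer.h"
    #include "marl/scheduler.h"
    #include "marl/waitgroup.h"

    int main() {
      marl::Scheduler scheduler(marl::Scheduler::Config::allCores());
      scheduler.bind();
      defer(scheduler.unbind());

      constexpr int kNumTasks = 8;
      marl::WaitGroup wg(kNumTasks);  // fork: one count per task
      for (int i = 0; i < kNumTasks; i++) {
        marl::schedule([=] {
          defer(wg.done());
          // ... per-task work goes here ...
        });
      }
      wg.wait();  // join: blocks until every task has called done()
    }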
