add distributed, non-blocking thread pool implementation
The current thread pool implementation is centralized and does not scale.
This change adds a distributed, non-blocking thread pool implementation.
Both implementations co-exist and can be selected via the
TF_THREAD_POOL env var.

Fixes tensorflow#551
Fixes tensorflow#583
Update tensorflow#932
Update tensorflow#933
dvyukov committed Jan 29, 2016
1 parent 2cb25ab commit e3af358
Showing 7 changed files with 560 additions and 146 deletions.
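Since the commit message says the two implementations co-exist and are selected through the TF_THREAD_POOL environment variable, here is a minimal sketch of how such a switch might be read at pool construction time. The accepted values are not shown in this commit, so the value checked below is purely hypothetical.

```cpp
// Sketch only: reads the TF_THREAD_POOL switch mentioned in the commit
// message. The accepted values are assumptions, not taken from the commit.
#include <cstdlib>
#include <cstring>

static bool UseNonBlockingThreadPool() {
  const char* val = std::getenv("TF_THREAD_POOL");
  if (val == nullptr) return false;             // assumed default: legacy pool
  return std::strcmp(val, "nonblocking") == 0;  // hypothetical value
}
```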
96 changes: 96 additions & 0 deletions tensorflow/core/lib/core/eventcount.h
@@ -0,0 +1,96 @@
// Copyright 2016 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//==============================================================================

#ifndef TENSORFLOW_LIB_CORE_EVENTCOUNT_H_
#define TENSORFLOW_LIB_CORE_EVENTCOUNT_H_

#include <atomic>

#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace thread {

class EventCount {
 public:
  EventCount() : state_(), waiters_() {}

  unsigned Prewait() {
    unsigned state = state_.fetch_or(kWaiter, std::memory_order_relaxed);
    std::atomic_thread_fence(std::memory_order_seq_cst);
    return state & ~kWaiter;
  }

  void Wait(unsigned epoch) {
    mutex_lock lock(mutex_);
    if (epoch != (state_.load(std::memory_order_seq_cst) & ~kWaiter)) return;
    waiters_++;
    cv_.wait(lock);
  }

  void NotifyOne() {
    std::atomic_thread_fence(std::memory_order_seq_cst);
    unsigned state = state_.load(std::memory_order_relaxed);
    if (!(state & kWaiter)) return;
    unsigned waiters;
    {
      mutex_lock lock(mutex_);
      waiters = waiters_;
      if (waiters < 2) {
        waiters_ = 0;
        while (!state_.compare_exchange_weak(state, (state & ~kWaiter) + kEpoch,
                                             std::memory_order_relaxed)) {
        }
      } else {
        waiters_--;
        state_.fetch_add(kEpoch, std::memory_order_relaxed);
      }
    }
    if (waiters) cv_.notify_one();
  }

  void NotifyAll() {
    std::atomic_thread_fence(std::memory_order_seq_cst);
    unsigned state = state_.load(std::memory_order_relaxed);
    if (!(state & kWaiter)) return;
    unsigned waiters;
    {
      mutex_lock lock(mutex_);
      while (!state_.compare_exchange_weak(state, (state & ~kWaiter) + kEpoch,
                                           std::memory_order_relaxed)) {
      }
      waiters = waiters_;
      waiters_ = 0;
    }
    if (waiters) cv_.notify_all();
  }

 private:
  enum {
    kWaiter = 1,
    kEpoch = 2,
  };
  mutex mutex_;
  condition_variable cv_;
  std::atomic<unsigned> state_;
  unsigned waiters_;
  TF_DISALLOW_COPY_AND_ASSIGN(EventCount);
};

} // namespace thread
} // namespace tensorflow

#endif // TENSORFLOW_LIB_CORE_EVENTCOUNT_H_
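EventCount is the classic "eventcount" blocking primitive: a consumer announces its intent to sleep with Prewait(), re-checks its work source, and only then calls Wait(), while a producer publishes work before calling NotifyOne()/NotifyAll(). Below is a minimal usage sketch; the std::mutex-protected deque, its locking, and the function names are illustrative assumptions, not part of this commit.

```cpp
// Usage sketch (not part of the commit): block a worker only when there is
// provably nothing to run. The deque stands in for whatever work source the
// real pool uses.
#include <deque>
#include <functional>
#include <mutex>

#include "tensorflow/core/lib/core/eventcount.h"

static tensorflow::thread::EventCount ec;
static std::mutex queue_mu;
static std::deque<std::function<void()>> work_queue;

static std::function<void()> TryPop() {
  std::lock_guard<std::mutex> l(queue_mu);
  if (work_queue.empty()) return nullptr;  // empty std::function means "no work"
  std::function<void()> fn = std::move(work_queue.front());
  work_queue.pop_front();
  return fn;
}

void Submit(std::function<void()> fn) {
  {
    std::lock_guard<std::mutex> l(queue_mu);
    work_queue.push_back(std::move(fn));  // publish the work first...
  }
  ec.NotifyOne();                         // ...then wake at most one sleeper
}

void WorkerLoop() {
  for (;;) {
    if (auto fn = TryPop()) { fn(); continue; }
    unsigned epoch = ec.Prewait();               // announce intent to sleep
    if (auto fn = TryPop()) { fn(); continue; }  // re-check after announcing
    ec.Wait(epoch);  // sleeps only if no notify happened since Prewait()
  }
}
```

The ordering is what makes this safe: Submit() must make the work visible before NotifyOne(), and the worker must re-check the queue between Prewait() and Wait(); otherwise a wakeup can be lost and the worker could sleep with work pending.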
113 changes: 113 additions & 0 deletions tensorflow/core/lib/core/runqueue.h
@@ -0,0 +1,113 @@
// Copyright 2016 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//==============================================================================

#ifndef TENSORFLOW_LIB_CORE_RUNQUEUE_H_
#define TENSORFLOW_LIB_CORE_RUNQUEUE_H_

#include <atomic>

#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace thread {

template <typename Work, unsigned kSize>
class RunQueueT {
 public:
  RunQueueT() : front_(), back_() {
    for (unsigned i = 0; i < kSize; i++)
      array_[i].state.store(kEmpty, std::memory_order_relaxed);
  }

  Work PushFront(Work w) {
    Elem *e = &array_[front_ % kSize];
    uint8_t s = e->state.load(std::memory_order_relaxed);
    if (s != kEmpty ||
        !e->state.compare_exchange_strong(s, kBusy, std::memory_order_acquire))
      return w;
    e->w = std::move(w);
    e->state.store(kReady, std::memory_order_release);
    front_++;
    return Work();
  }

  Work PopFront() {
    Elem *e = &array_[(front_ - 1) % kSize];
    uint8_t s = e->state.load(std::memory_order_relaxed);
    if (s != kReady ||
        !e->state.compare_exchange_strong(s, kBusy, std::memory_order_acquire))
      return Work();
    Work w = std::move(e->w);
    e->state.store(kEmpty, std::memory_order_release);
    front_--;
    return w;
  }

  Work PushBack(Work w) {
    mutex_lock lock(mutex_);
    Elem *e = &array_[(back_ - 1) % kSize];
    uint8_t s = e->state.load(std::memory_order_relaxed);
    if (s != kEmpty ||
        !e->state.compare_exchange_strong(s, kBusy, std::memory_order_acquire))
      return w;
    e->w = std::move(w);
    e->state.store(kReady, std::memory_order_release);
    back_--;
    return Work();
  }

  Work PopBack() {
    mutex_lock lock(mutex_, std::try_to_lock);
    if (!lock) return Work();
    Elem *e = &array_[back_ % kSize];
    uint8_t s = e->state.load(std::memory_order_relaxed);
    if (s != kReady ||
        !e->state.compare_exchange_strong(s, kBusy, std::memory_order_acquire))
      return Work();
    Work w = std::move(e->w);
    e->state.store(kEmpty, std::memory_order_release);
    back_++;
    return w;
  }

  bool Empty() {
    mutex_lock lock(mutex_);
    Elem *e = &array_[back_ % kSize];
    uint8_t s = e->state.load(std::memory_order_relaxed);
    return s == kEmpty;
  }

 private:
  struct Elem {
    std::atomic<uint8_t> state;
    Work w;
  };
  enum {
    kEmpty,
    kBusy,
    kReady,
  };
  mutex mutex_;
  unsigned front_;
  unsigned back_;
  Elem array_[kSize];
  TF_DISALLOW_COPY_AND_ASSIGN(RunQueueT);
};

} // namespace thread
} // namespace tensorflow

#endif // TENSORFLOW_LIB_CORE_RUNQUEUE_H_
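RunQueueT is a bounded single-owner work queue: the worker that owns it operates on the front without taking the mutex (PushFront/PopFront), while other threads push or steal at the back under the mutex, and PopBack() uses try_to_lock so a stealing thread never blocks. Here is a small sketch of the intended call pattern; the queue size, element type, and function names are assumptions for illustration, not part of the commit.

```cpp
// Sketch (not part of the commit): owner-side LIFO fast path vs. stealing.
#include <functional>

#include "tensorflow/core/lib/core/runqueue.h"

using Queue = tensorflow::thread::RunQueueT<std::function<void()>, 1024>;

// Called only by the owning worker: lock-free front of the queue.
void RunLocalWork(Queue* q) {
  for (;;) {
    std::function<void()> fn = q->PopFront();  // newest item first (LIFO)
    if (!fn) break;                            // empty: stop for this sketch
    fn();
  }
}

// Called by any other worker: take the oldest item from the back.
bool TrySteal(Queue* q, std::function<void()>* out) {
  std::function<void()> fn = q->PopBack();  // try_to_lock inside: never blocks
  if (!fn) return false;                    // empty, busy, or lost the race
  *out = std::move(fn);
  return true;
}
```

Note that PushFront() and PushBack() return the work item unchanged when it cannot be enqueued (slot busy or queue full), so a caller can fall back to running the closure inline or handing it to another queue.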
124 changes: 0 additions & 124 deletions tensorflow/core/lib/core/threadpool.cc

This file was deleted.

30 changes: 8 additions & 22 deletions tensorflow/core/lib/core/threadpool.h
@@ -16,14 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_LIB_CORE_THREADPOOL_H_
#define TENSORFLOW_LIB_CORE_THREADPOOL_H_

#include <deque>
#include <functional>
#include <thread>
#include <vector>
#include <memory>
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace thread {
@@ -45,28 +41,18 @@ class ThreadPool {

// Wait until all scheduled work has finished and then destroy the
// set of threads.
virtual ~ThreadPool();
~ThreadPool();

// Schedule fn() for execution in the pool of threads.
virtual void Schedule(std::function<void()> fn);
// Note that not all implementations guarantee the order in which
// jobs will be run on a single worker thread, and so jobs scheduled
// simultaneously should not have ordering dependencies between them.
void Schedule(std::function<void()> fn);

virtual bool HasPendingClosures() const;
struct Impl;

private:
struct Waiter;
struct Item {
std::function<void()> fn;
uint64 id;
};

void WorkerLoop();

const string name_;
mutable mutex mu_;
std::vector<Thread*> threads_; // All threads
std::vector<Waiter*> waiters_; // Stack of waiting threads.
std::deque<Item> pending_; // Queue of pending work

std::unique_ptr<Impl> impl_;
TF_DISALLOW_COPY_AND_ASSIGN(ThreadPool);
};
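The public ThreadPool interface is otherwise unchanged by this commit: callers construct a pool, call Schedule(), and the destructor still waits for all scheduled work before tearing the threads down. A usage sketch follows; the constructor signature (Env*, pool name, thread count) is not visible in the lines shown above and is an assumption based on the surrounding TensorFlow code of that era.

```cpp
// Usage sketch: scheduling closures on the pool. The constructor arguments
// (Env*, pool name, thread count) are assumed, not shown in this diff.
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"

void Example() {
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "example", 4);
  for (int i = 0; i < 10; ++i) {
    pool.Schedule([i]() {
      // Work item i; per the new comment above, items scheduled together
      // have no guaranteed ordering between them.
    });
  }
  // ~ThreadPool() waits until all scheduled work has finished.
}
```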

