Skip to content

Commit

Permalink
Solve a deadlock problem
Browse files Browse the repository at this point in the history
Summary: The HHWheelTimer implementation mistakenly called the callback's timeoutExpired() function while holding an internal lock (contrary to the comment).  This caused a deadlock problem under production loads.  Modify the code to avoid this.  Also added a test that fails before this change and succeeds after it.

Differential Revision: D7754090

fbshipit-source-id: 1356b98
  • Loading branch information
jkedgar authored and facebook-github-bot committed Apr 27, 2018
1 parent f5c9af8 commit fdc0383
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 26 deletions.
33 changes: 18 additions & 15 deletions sql/hh_wheel_timer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,27 +133,30 @@ void HHWheelTimer::readdToWheel(const CallbackPtr& cb,
// The background thread has woken at the appropriate time - cause all expired
// timers to have their callback functions called
// Inherited from Timer
void HHWheelTimer::expired() noexcept {
void HHWheelTimer::expired(
std::unique_lock<std::recursive_mutex>&& lock) noexcept {
CallbackVct needNotification;

{
// Grab the mutex to protect the data structures
std::lock_guard<std::recursive_mutex> guard(m_mutex);
// Validate that the mutex is locked
DBUG_ASSERT(lock.owns_lock());

m_curr_tick = currentTick();

// In case we have moved past the expected expire tick, expire all items
// from then to the current tick.
auto tick = m_expire_tick;
DBUG_ASSERT(tick <= m_curr_tick);
while (tick <= m_curr_tick) {
expireItems(tick++, needNotification);
}
m_curr_tick = currentTick();

// Determine when next we need to wake
scheduleNextTimeout();
// In case we have moved past the expected expire tick, expire all items
// from then to the current tick.
auto tick = m_expire_tick;
DBUG_ASSERT(tick <= m_curr_tick);
while (tick <= m_curr_tick) {
expireItems(tick++, needNotification);
}

// Determine when next we need to wake
scheduleNextTimeout();

// Release the mutex - the caller assumes we will do this
lock.unlock();
DBUG_ASSERT(!lock.owns_lock());

// Notify all the callbacks that expired (outside the mutex)
for (const auto& itr : needNotification) {
itr.first->timeoutExpired(itr.second);
Expand Down
6 changes: 4 additions & 2 deletions sql/hh_wheel_timer.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,11 @@ class HHWheelTimer : private Timer {
HHWheelTimer& operator=(HHWheelTimer const&) = delete;

// The background thread has woken at the appropriate time - cause all expired
// timers to have their callback functions called
// timers to have their callback functions called.
// Receive ownership of the lock from the caller so that we can release it
// when we are ready.
// Inherited from Timer
void expired() noexcept override;
void expired(std::unique_lock<std::recursive_mutex>&& lock) noexcept override;

// Remove the callback from the timer
ID removeCallback(const CallbackPtr& cb);
Expand Down
20 changes: 15 additions & 5 deletions sql/timer.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,28 @@ class Timer : public my_timer_t {

void trigger() noexcept {
m_triggers_in_process++;
std::lock_guard<std::recursive_mutex> guard(m_mutex);

if (m_scheduled) {
m_scheduled = false;
expired();
{
// Instead of a standard lock guard, create a unique lock so that we
// can pass it to the callback which can then unlock it early.
std::unique_lock<std::recursive_mutex> lock(m_mutex);
DBUG_ASSERT(lock.owns_lock());

if (m_scheduled) {
m_scheduled = false;
// Pass the guard to the callback so the mutex can be released early
expired(std::move(lock));
}
}

m_triggers_in_process--;
}

// Function to override to handle the timer expiring
virtual void expired() noexcept = 0;
// We will pass ownership of the mutex into the function so the callback
// can release it when it desires.
virtual void expired(
std::unique_lock<std::recursive_mutex>&& lock) noexcept = 0;

protected:
std::recursive_mutex m_mutex;
Expand Down
97 changes: 93 additions & 4 deletions unittest/gunit2/hh_wheel_timer-t.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ namespace hh_wheel_timer_unittests {
using seconds = std::chrono::seconds;
using msecs = std::chrono::milliseconds;

constexpr std::chrono::milliseconds k60Seconds(60000);

class CallbackTest : public HHWheelTimer::Callback {
public:
CallbackTest() :
Expand Down Expand Up @@ -72,6 +74,9 @@ class CallbackTest : public HHWheelTimer::Callback {
return duration();
}

// Do nothing in the base version
virtual void setup() {}

void reset() {
setID_ = 0;
expiredID_ = 0;
Expand All @@ -92,6 +97,43 @@ class CallbackTest : public HHWheelTimer::Callback {
std::chrono::steady_clock::time_point end_;
};

class CallbackTestWithMutex :
public CallbackTest,
public std::enable_shared_from_this<CallbackTestWithMutex> {
public:
CallbackTestWithMutex(HHWheelTimer* timer) : timer_(timer) {}

void timeoutExpired(HHWheelTimer::ID id) noexcept override {
CallbackTest::timeoutExpired(id);
checkDeadlock();
}

void timeoutCancelled(HHWheelTimer::ID id) noexcept override {
CallbackTest::timeoutCancelled(id);
checkDeadlock();
}

void setup() override {
std::lock_guard<std::timed_mutex> guard(mutex_);
timer_->cancelTimeout(shared_from_this());
}

private:
static std::timed_mutex mutex_;
HHWheelTimer* timer_;

void checkDeadlock() {
std::unique_lock<std::timed_mutex> lock(mutex_, std::chrono::seconds(2));
if (!lock) {
throw std::runtime_error("Deadlock!!!");
}

lock.unlock();
}
};

std::timed_mutex CallbackTestWithMutex::mutex_;

static void waitForUseCount(std::shared_ptr<CallbackTest>& ptr, long count)
{
// Wait up to 2 seconds for the use_count to drop to the expected count
Expand Down Expand Up @@ -160,10 +202,12 @@ class ThreadTestThread : public thread::Thread
timer_{nullptr},
state_{nullptr},
failures_{0},
ready_{false} {
ready_{false},
useMutex_{false} {
}
virtual ~ThreadTestThread() {}

void requireMutex() { useMutex_ = true; }
void setTimer(HHWheelTimer* timer) { timer_ = timer; }
void setState(enum ControlState* state) { state_ = state; }
void run() {
Expand All @@ -189,7 +233,11 @@ class ThreadTestThread : public thread::Thread
void try_test() {
// Get a timeout from 10 to 1000 millseconds (in multiples of 10ms)
auto to = msecs(((rng_() % 100) + 1) * 10);
auto cb = std::make_shared<CallbackTest>();
auto cb = useMutex_ ?
std::dynamic_pointer_cast<CallbackTest>(
std::make_shared<CallbackTestWithMutex>(timer_)) :
std::make_shared<CallbackTest>();
cb->setup();
cb->setID_ = timer_->scheduleTimeout(cb, to);

if (rng_() % 5 == 0) {
Expand Down Expand Up @@ -225,6 +273,7 @@ class ThreadTestThread : public thread::Thread
std::mt19937_64 rng_;
uint32_t failures_;
bool ready_;
bool useMutex_;
};

// Make sure the template parameters are set correctly. The template
Expand Down Expand Up @@ -535,8 +584,48 @@ TEST_F(HHWheelTimerTest, Threads)

// Wait for 60 seconds
auto cb = std::make_shared<CallbackTest>();
timer.scheduleTimeout(cb, msecs(60000));
cb->wait(msecs(60000));
timer.scheduleTimeout(cb, k60Seconds);
cb->wait(k60Seconds);

// Stop the threads
state = STOPPING;

for (uint32_t ii = 0; ii < NUM_THREADS; ii++) {
threads[ii].join();
EXPECT_EQ(threads[ii].num_failures(), 0U);
}
}

// Run many threads all doing timeouts and using a mutex
TEST_F(HHWheelTimerTest, ThreadsWithAMutex)
{
HHWheelTimer timer;
ThreadTestThread threads[NUM_THREADS];
enum ControlState state = INITIALIZING;

// Get each thread up and running and pass in a pointer to the timer
for (uint32_t ii = 0; ii < NUM_THREADS; ii++) {
threads[ii].requireMutex();
threads[ii].setTimer(&timer);
threads[ii].setState(&state);
threads[ii].start();
}

// Make sure each thread is ready to run
for (uint32_t ii = 0; ii < NUM_THREADS; ii++) {
while (!threads[ii].ready()) {
std::this_thread::sleep_for(msecs(1));
}
}

// Begin timers
state = RUNNING;

// Wait for 60 seconds
auto cb = std::dynamic_pointer_cast<CallbackTest>(
std::make_shared<CallbackTestWithMutex>(&timer));
timer.scheduleTimeout(cb, k60Seconds);
cb->wait(k60Seconds);

// Stop the threads
state = STOPPING;
Expand Down

0 comments on commit fdc0383

Please sign in to comment.