Add test for notify observers race
Kaushik Iska committed Aug 22, 2019
1 parent 67606b3 commit 97d16ca
Showing 3 changed files with 64 additions and 24 deletions.
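
For context on the race being tested: before this change, every MessageLoopTaskQueues method held the shared queue_meta_mutex_ for its entire duration. An observer callback executing inside NotifyObservers therefore pinned the reader lock, and any thread calling CreateTaskQueue (which needs the writer lock) stalled behind it. The following minimal, self-contained sketch contrasts the two locking shapes; GetMutexSketch and the single file-scope mutexes are illustrative stand-ins, not the engine's API.

#include <mutex>
#include <shared_mutex>

std::shared_mutex meta_mutex;  // guards the queue registry
std::mutex observers_mutex;    // per-queue mutex (one queue shown)

// Old shape: the reader lock spans the whole call, so a writer such as
// CreateTaskQueue blocks until every observer callback returns.
void NotifyObserversOld() {
  std::shared_lock<std::shared_mutex> reader(meta_mutex);
  std::scoped_lock observers(observers_mutex);
  // ... observer callbacks run here, still under meta_mutex ...
}

// New shape: take the reader lock only long enough to resolve the
// per-queue mutex, then release it before doing the real work.
std::mutex& GetMutexSketch() {
  std::shared_lock<std::shared_mutex> reader(meta_mutex);
  return observers_mutex;  // the real code looks up queue_entries.at(queue_id)
}

void NotifyObserversNew() {
  std::scoped_lock observers(GetMutexSketch());  // meta_mutex already released
  // ... observer callbacks run here; CreateTaskQueue can proceed ...
}

int main() {
  NotifyObserversOld();
  NotifyObserversNew();
  return 0;
}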
49 changes: 31 additions & 18 deletions fml/message_loop_task_queues.cc
@@ -52,6 +52,7 @@ MessageLoopTaskQueues::~MessageLoopTaskQueues() = default;
 
 void MessageLoopTaskQueues::Dispose(TaskQueueId queue_id) {
   fml::UniqueLock queue_wirter(*queue_meta_mutex_);
+  FML_DCHECK(queue_entries.at(queue_id).subsumed_by == _kUnmerged);
   TaskQueueId subsumed = queue_entries.at(queue_id).owner_of;
   queue_entries.erase(queue_id);
   if (subsumed != _kUnmerged) {
@@ -62,8 +63,7 @@ void MessageLoopTaskQueues::Dispose(TaskQueueId queue_id) {
 void MessageLoopTaskQueues::RegisterTask(TaskQueueId queue_id,
                                          fml::closure task,
                                          fml::TimePoint target_time) {
-  fml::SharedLock queue_reader(*queue_meta_mutex_);
-  std::scoped_lock tasks_lock(*queue_entries.at(queue_id).tasks_mutex);
+  std::scoped_lock tasks_lock(GetMutex(queue_id, MutexType::kTasks));
 
   size_t order = order_++;
   TaskQueueEntry& queue_entry = queue_entries[queue_id];
@@ -76,8 +76,7 @@ void MessageLoopTaskQueues::RegisterTask(TaskQueueId queue_id,
 }
 
 bool MessageLoopTaskQueues::HasPendingTasks(TaskQueueId queue_id) const {
-  fml::SharedLock queue_reader(*queue_meta_mutex_);
-  std::scoped_lock tasks_lock(*queue_entries.at(queue_id).tasks_mutex);
+  std::scoped_lock tasks_lock(GetMutex(queue_id, MutexType::kTasks));
 
   return HasPendingTasksUnlocked(queue_id);
 }
@@ -86,8 +85,7 @@ void MessageLoopTaskQueues::GetTasksToRunNow(
     TaskQueueId queue_id,
     FlushType type,
     std::vector<fml::closure>& invocations) {
-  fml::SharedLock queue_reader(*queue_meta_mutex_);
-  std::scoped_lock tasks_lock(*queue_entries.at(queue_id).tasks_mutex);
+  std::scoped_lock tasks_lock(GetMutex(queue_id, MutexType::kTasks));
 
   if (!HasPendingTasksUnlocked(queue_id)) {
     return;
@@ -117,17 +115,19 @@
 
 void MessageLoopTaskQueues::WakeUp(TaskQueueId queue_id,
                                    fml::TimePoint time) const {
-  fml::SharedLock queue_reader(*queue_meta_mutex_);
-  std::scoped_lock wakeable_lock(*queue_entries.at(queue_id).wakeable_mutex);
+  std::scoped_lock wakeable_lock(GetMutex(queue_id, MutexType::kWakeable));
 
   if (queue_entries.at(queue_id).wakeable) {
     queue_entries.at(queue_id).wakeable->WakeUp(time);
   }
 }
 
 size_t MessageLoopTaskQueues::GetNumPendingTasks(TaskQueueId queue_id) const {
-  fml::SharedLock queue_reader(*queue_meta_mutex_);
-  std::scoped_lock tasks_lock(*queue_entries.at(queue_id).tasks_mutex);
+  std::scoped_lock tasks_lock(GetMutex(queue_id, MutexType::kTasks));
 
+  if (queue_entries.at(queue_id).subsumed_by != _kUnmerged) {
+    return 0;
+  }
+
   size_t total_tasks = 0;
   total_tasks += queue_entries.at(queue_id).delayed_tasks.size();
@@ -143,24 +143,25 @@ size_t MessageLoopTaskQueues::GetNumPendingTasks(TaskQueueId queue_id) const {
 void MessageLoopTaskQueues::AddTaskObserver(TaskQueueId queue_id,
                                             intptr_t key,
                                             fml::closure callback) {
-  fml::SharedLock queue_reader(*queue_meta_mutex_);
-  std::scoped_lock observers_lock(*queue_entries.at(queue_id).observers_mutex);
+  std::scoped_lock observers_lock(GetMutex(queue_id, MutexType::kObservers));
 
   FML_DCHECK(callback != nullptr) << "Observer callback must be non-null.";
   queue_entries[queue_id].task_observers[key] = std::move(callback);
 }
 
 void MessageLoopTaskQueues::RemoveTaskObserver(TaskQueueId queue_id,
                                                intptr_t key) {
-  fml::SharedLock queue_reader(*queue_meta_mutex_);
-  std::scoped_lock observers_lock(*queue_entries.at(queue_id).observers_mutex);
+  std::scoped_lock observers_lock(GetMutex(queue_id, MutexType::kObservers));
 
   queue_entries[queue_id].task_observers.erase(key);
 }
 
 void MessageLoopTaskQueues::NotifyObservers(TaskQueueId queue_id) const {
-  fml::SharedLock queue_reader(*queue_meta_mutex_);
-  std::scoped_lock observers_lock(*queue_entries.at(queue_id).observers_mutex);
+  std::scoped_lock observers_lock(GetMutex(queue_id, MutexType::kObservers));
 
+  if (queue_entries.at(queue_id).subsumed_by != _kUnmerged) {
+    return;
+  }
+
   for (const auto& observer : queue_entries.at(queue_id).task_observers) {
     observer.second();
@@ -176,8 +177,7 @@ void MessageLoopTaskQueues::NotifyObservers(TaskQueueId queue_id) const {
 
 void MessageLoopTaskQueues::SetWakeable(TaskQueueId queue_id,
                                         fml::Wakeable* wakeable) {
-  fml::SharedLock queue_reader(*queue_meta_mutex_);
-  std::scoped_lock wakeable_lock(*queue_entries.at(queue_id).wakeable_mutex);
+  std::scoped_lock wakeable_lock(GetMutex(queue_id, MutexType::kWakeable));
 
   FML_CHECK(!queue_entries[queue_id].wakeable)
       << "Wakeable can only be set once.";
@@ -243,6 +243,19 @@ bool MessageLoopTaskQueues::Owns(TaskQueueId owner,
   return subsumed == queue_entries.at(owner).owner_of || owner == subsumed;
 }
 
+std::mutex& MessageLoopTaskQueues::GetMutex(TaskQueueId queue_id,
+                                            MutexType type) const {
+  fml::SharedLock queue_reader(*queue_meta_mutex_);
+  const auto& entry = queue_entries.at(queue_id);
+  if (type == MutexType::kObservers) {
+    return *entry.observers_mutex;
+  } else if (type == MutexType::kTasks) {
+    return *entry.tasks_mutex;
+  } else {
+    return *entry.wakeable_mutex;
+  }
+}
+
 // Subsumed queues will never have pending tasks.
 // Owning queues will consider both their and their subsumed tasks.
 bool MessageLoopTaskQueues::HasPendingTasksUnlocked(
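
The comment above HasPendingTasksUnlocked states the merge semantics that the new subsumed_by guards in GetNumPendingTasks and NotifyObservers also rely on: a subsumed queue reports no pending work of its own, while its owner counts both queues. A hypothetical sketch of that bookkeeping, with field names mirroring the diff but an illustrative flat-array layout:

#include <cassert>
#include <cstddef>

constexpr size_t kUnmerged = static_cast<size_t>(-1);

struct Entry {
  size_t pending = 0;              // stand-in for delayed_tasks.size()
  size_t subsumed_by = kUnmerged;  // set when another queue owns us
  size_t owner_of = kUnmerged;     // set when we own another queue
};

bool HasPendingSketch(const Entry* entries, size_t id) {
  const Entry& e = entries[id];
  if (e.subsumed_by != kUnmerged) {
    return false;  // subsumed queues never report pending tasks
  }
  bool pending = e.pending > 0;
  if (e.owner_of != kUnmerged) {
    pending = pending || entries[e.owner_of].pending > 0;  // count subsumed tasks too
  }
  return pending;
}

int main() {
  Entry entries[2];
  entries[0].owner_of = 1;
  entries[1].subsumed_by = 0;
  entries[1].pending = 3;
  assert(HasPendingSketch(entries, 0));   // owner sees the subsumed queue's tasks
  assert(!HasPendingSketch(entries, 1));  // subsumed queue reports none
  return 0;
}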
10 changes: 4 additions & 6 deletions fml/message_loop_task_queues.h
@@ -121,12 +121,6 @@ class MessageLoopTaskQueues
  private:
   class MergedQueuesRunner;
 
-  enum class MutexType {
-    kTasks,
-    kObservers,
-    kWakeables,
-  };
-
   using Mutexes = std::vector<std::unique_ptr<std::mutex>>;
 
   MessageLoopTaskQueues();
@@ -135,6 +129,10 @@
 
   void WakeUp(TaskQueueId queue_id, fml::TimePoint time) const;
 
+  enum class MutexType { kObservers, kTasks, kWakeable };
+
+  std::mutex& GetMutex(TaskQueueId queue_id, MutexType type) const;
+
   bool HasPendingTasksUnlocked(TaskQueueId queue_id) const;
 
   const DelayedTask& PeekNextTaskUnlocked(TaskQueueId queue_id,
29 changes: 29 additions & 0 deletions fml/message_loop_task_queues_unittests.cc
@@ -4,6 +4,8 @@
 
 #define FML_USED_ON_EMBEDDER
 
+#include <thread>
+
 #include "flutter/fml/message_loop_task_queues.h"
 #include "flutter/fml/synchronization/count_down_latch.h"
 #include "flutter/fml/synchronization/waitable_event.h"
@@ -136,3 +138,30 @@ TEST(MessageLoopTaskQueue, WokenUpWithNewerTime) {
 
   latch.Wait();
 }
+
+TEST(MessageLoopTaskQueue, NotifyObserversWhileCreatingQueues) {
+  auto task_queues = fml::MessageLoopTaskQueues::GetInstance();
+  fml::TaskQueueId queue_id = task_queues->CreateTaskQueue();
+  fml::AutoResetWaitableEvent first_observer_executing, before_second_observer;
+
+  task_queues->AddTaskObserver(queue_id, queue_id + 1, [&]() {
+    first_observer_executing.Signal();
+    before_second_observer.Wait();
+  });
+
+  for (int i = 0; i < 100; i++) {
+    task_queues->AddTaskObserver(queue_id, queue_id + i + 2, [] {});
+  }
+
+  std::thread notify_observers(
+      [&]() { task_queues->NotifyObservers(queue_id); });
+
+  first_observer_executing.Wait();
+
+  for (int i = 0; i < 100; i++) {
+    task_queues->CreateTaskQueue();
+  }
+
+  before_second_observer.Signal();
+  notify_observers.join();
+}
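
The test is a two-event handshake: the notifier thread is frozen inside the first observer, the main thread then creates queues (which takes the writer lock on queue_meta_mutex_), and only afterwards is the observer released. Under the old locking the notifier would still be holding the reader lock while frozen, CreateTaskQueue would block forever, and the release signal would never be sent, so the test would hang instead of passing. A rough standard-library equivalent of the handshake, where the hand-rolled Event stands in for fml::AutoResetWaitableEvent:

#include <condition_variable>
#include <mutex>
#include <thread>

// Minimal auto-reset event built on a condition variable.
class Event {
 public:
  void Signal() {
    std::lock_guard<std::mutex> lock(mutex_);
    signaled_ = true;
    cv_.notify_one();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return signaled_; });
    signaled_ = false;  // auto-reset for the next Wait()
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  bool signaled_ = false;
};

int main() {
  Event in_callback, release_callback;

  std::thread notifier([&] {
    in_callback.Signal();     // "first observer is executing"
    release_callback.Wait();  // frozen mid-callback
  });

  in_callback.Wait();
  // ... the real test calls CreateTaskQueue() 100 times here ...
  release_callback.Signal();
  notifier.join();
  return 0;
}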
