Skip to content

Commit

Permalink
Call each NetworkStateObserver separately.
Browse files Browse the repository at this point in the history
NetworkStateNotifier used to have NetworkStateObserver list as
HashMap<SingleThreadTaskRunner*, Vector<NetworkStateObserver*>> and
call each observer on a taskrunner sequentially.
That caused race condition and use-after-free: what if an observer
calls wait and other observer removes all?
We also should not guarantee the order of registering observers is
kept as notification order: each observer should not depend on others.

To fix that, this patch reconstructs the structure to
HashMap<NetworkStateObserver*, SingleThreadTaskRunner*> and call each
observer on each taskrunner separately.
This implementation follows base/observer_list_threadsafe.h except
the taskrunner is given by the caller.

Fixed: 1278708
Change-Id: Iff5d0008d5b0d98caa5931e2806db3ffc52be6fa
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4280021
Reviewed-by: Kent Tamura <tkent@chromium.org>
Commit-Queue: Yoichi Osato <yoichio@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1111448}
  • Loading branch information
Yoichi Osato authored and Chromium LUCI CQ committed Mar 1, 2023
1 parent dd0a52a commit 9c3ac3d
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 196 deletions.
154 changes: 45 additions & 109 deletions third_party/blink/renderer/platform/network/network_state_notifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,54 +279,55 @@ void NetworkStateNotifier::NotifyObservers(ObserverListMap& map,
DCHECK(IsMainThread());
base::AutoLock locker(lock_);
for (const auto& entry : map) {
scoped_refptr<base::SingleThreadTaskRunner> task_runner = entry.key;
PostCrossThreadTask(
*task_runner, FROM_HERE,
CrossThreadBindOnce(&NetworkStateNotifier::NotifyObserversOnTaskRunner,
CrossThreadUnretained(this),
CrossThreadUnretained(&map), type, task_runner,
state));
entry.value->PostTask(
FROM_HERE,
base::BindOnce(&NetworkStateNotifier::NotifyObserverOnTaskRunner,
base::Unretained(this), base::UnsafeDangling(entry.key),
type, state));
}
}

void NetworkStateNotifier::NotifyObserversOnTaskRunner(
ObserverListMap* map,
void NetworkStateNotifier::NotifyObserverOnTaskRunner(
MayBeDangling<NetworkStateObserver> observer,
ObserverType type,
scoped_refptr<base::SingleThreadTaskRunner> task_runner,
const NetworkState& state) {
ObserverList* observer_list = LockAndFindObserverList(*map, task_runner);

// The context could have been removed before the notification task got to
// run.
if (!observer_list)
return;

DCHECK(task_runner->RunsTasksInCurrentSequence());

observer_list->iterating = true;

for (wtf_size_t i = 0; i < observer_list->observers.size(); ++i) {
// Observers removed during iteration are zeroed out, skip them.
if (!observer_list->observers[i])
continue;
switch (type) {
case ObserverType::kOnLineState:
observer_list->observers[i]->OnLineStateChange(state.on_line);
continue;
case ObserverType::kConnectionType:
observer_list->observers[i]->ConnectionChange(
state.type, state.max_bandwidth_mbps, state.effective_type,
state.http_rtt, state.transport_rtt, state.downlink_throughput_mbps,
state.save_data);
continue;
{
base::AutoLock locker(lock_);
ObserverListMap& map = GetObserverMapFor(type);
// It's safe to pass a MayBeDangling pointer to find().
ObserverListMap::iterator it = map.find(observer);
if (map.end() == it) {
return;
}
NOTREACHED();
DCHECK(it->value->RunsTasksInCurrentSequence());
}

observer_list->iterating = false;
switch (type) {
case ObserverType::kOnLineState:
observer->OnLineStateChange(state.on_line);
return;
case ObserverType::kConnectionType:
observer->ConnectionChange(
state.type, state.max_bandwidth_mbps, state.effective_type,
state.http_rtt, state.transport_rtt, state.downlink_throughput_mbps,
state.save_data);
return;
default:
NOTREACHED();
}
}

if (!observer_list->zeroed_observers.empty())
CollectZeroedObservers(*map, observer_list, std::move(task_runner));
NetworkStateNotifier::ObserverListMap& NetworkStateNotifier::GetObserverMapFor(
ObserverType type) {
switch (type) {
case ObserverType::kConnectionType:
return connection_observers_;
case ObserverType::kOnLineState:
return on_line_state_observers_;
default:
NOTREACHED();
return connection_observers_;
}
}

void NetworkStateNotifier::AddObserverToMap(
Expand All @@ -338,85 +339,20 @@ void NetworkStateNotifier::AddObserverToMap(

base::AutoLock locker(lock_);
ObserverListMap::AddResult result =
map.insert(std::move(task_runner), nullptr);
if (result.is_new_entry)
result.stored_value->value = std::make_unique<ObserverList>();

DCHECK(result.stored_value->value->observers.Find(observer) == kNotFound);
result.stored_value->value->observers.push_back(observer);
map.insert(observer, std::move(task_runner));
DCHECK(result.is_new_entry);
}

void NetworkStateNotifier::RemoveObserver(
ObserverType type,
NetworkStateObserver* observer,
scoped_refptr<base::SingleThreadTaskRunner> task_runner) {
switch (type) {
case ObserverType::kConnectionType:
RemoveObserverFromMap(connection_observers_, observer,
std::move(task_runner));
break;
case ObserverType::kOnLineState:
RemoveObserverFromMap(on_line_state_observers_, observer,
std::move(task_runner));
break;
}
}

void NetworkStateNotifier::RemoveObserverFromMap(
ObserverListMap& map,
NetworkStateObserver* observer,
scoped_refptr<base::SingleThreadTaskRunner> task_runner) {
DCHECK(task_runner->RunsTasksInCurrentSequence());
DCHECK(observer);

ObserverList* observer_list = LockAndFindObserverList(map, task_runner);
if (!observer_list)
return;

Vector<NetworkStateObserver*>& observers = observer_list->observers;
wtf_size_t index = observers.Find(observer);
if (index != kNotFound) {
observers[index] = 0;
observer_list->zeroed_observers.push_back(index);
}

if (!observer_list->iterating && !observer_list->zeroed_observers.empty())
CollectZeroedObservers(map, observer_list, std::move(task_runner));
}

NetworkStateNotifier::ObserverList*
NetworkStateNotifier::LockAndFindObserverList(
ObserverListMap& map,
scoped_refptr<base::SingleThreadTaskRunner> task_runner) {
base::AutoLock locker(lock_);
ObserverListMap::iterator it = map.find(task_runner);
return it == map.end() ? nullptr : it->value.get();
}

void NetworkStateNotifier::CollectZeroedObservers(
ObserverListMap& map,
ObserverList* list,
scoped_refptr<base::SingleThreadTaskRunner> task_runner) {
DCHECK(task_runner->RunsTasksInCurrentSequence());
DCHECK(!list->iterating);

// If any observers were removed during the iteration they will have
// 0 values, clean them up.
std::sort(list->zeroed_observers.begin(), list->zeroed_observers.end());
int removed = 0;
for (wtf_size_t i = 0; i < list->zeroed_observers.size(); ++i) {
int index_to_remove = list->zeroed_observers[i] - removed;
DCHECK_EQ(nullptr, list->observers[index_to_remove]);
list->observers.EraseAt(index_to_remove);
removed += 1;
}

list->zeroed_observers.clear();

if (list->observers.empty()) {
base::AutoLock locker(lock_);
map.erase(task_runner); // deletes list
}
ObserverListMap& map = GetObserverMapFor(type);
DCHECK_NE(map.end(), map.find(observer));
map.erase(observer);
}

// static
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,6 @@ class PLATFORM_EXPORT NetworkStateNotifier {
private:
friend class NetworkStateObserverHandle;

struct ObserverList {
ObserverList() : iterating(false) {}
bool iterating;
Vector<NetworkStateObserver*> observers;
Vector<wtf_size_t> zeroed_observers; // Indices in observers that are 0.
};

// This helper scope issues required notifications when mutating the state if
// something has changed. It's only possible to mutate the state on the main
// thread. Note that ScopedNotifier must be destroyed when not holding a lock
Expand All @@ -343,35 +336,20 @@ class PLATFORM_EXPORT NetworkStateNotifier {

// The ObserverListMap is cross-thread accessed, adding/removing Observers
// running on a task runner.
using ObserverListMap = HashMap<scoped_refptr<base::SingleThreadTaskRunner>,
std::unique_ptr<ObserverList>>;
using ObserverListMap = HashMap<NetworkStateObserver*,
scoped_refptr<base::SingleThreadTaskRunner>>;

void NotifyObservers(ObserverListMap&, ObserverType, const NetworkState&);
void NotifyObserversOnTaskRunner(ObserverListMap*,
ObserverType,
scoped_refptr<base::SingleThreadTaskRunner>,
const NetworkState&);

void NotifyObserverOnTaskRunner(MayBeDangling<NetworkStateObserver>,
ObserverType,
const NetworkState&);
ObserverListMap& GetObserverMapFor(ObserverType);
void AddObserverToMap(ObserverListMap&,
NetworkStateObserver*,
scoped_refptr<base::SingleThreadTaskRunner>);
void RemoveObserver(ObserverType,
NetworkStateObserver*,
scoped_refptr<base::SingleThreadTaskRunner>);
void RemoveObserverFromMap(ObserverListMap&,
NetworkStateObserver*,
scoped_refptr<base::SingleThreadTaskRunner>);

ObserverList* LockAndFindObserverList(
ObserverListMap&,
scoped_refptr<base::SingleThreadTaskRunner>);

// Removed observers are nulled out in the list in case the list is being
// iterated over. Once done iterating, call this to clean up nulled
// observers.
void CollectZeroedObservers(ObserverListMap&,
ObserverList*,
scoped_refptr<base::SingleThreadTaskRunner>);

// A random number by which the RTT and downlink estimates are multiplied
// with. The returned random multiplier is a function of the hostname.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include "base/functional/bind.h"
#include "base/task/single_thread_task_runner.h"
#include "base/test/task_environment.h"
#include "base/time/time.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
Expand Down Expand Up @@ -65,7 +66,7 @@ enum class SaveData {

} // namespace

class StateObserver : public NetworkStateNotifier::NetworkStateObserver {
class StateObserver final : public NetworkStateNotifier::NetworkStateObserver {
public:
StateObserver()
: observed_type_(kWebConnectionTypeNone),
Expand All @@ -77,6 +78,7 @@ class StateObserver : public NetworkStateNotifier::NetworkStateObserver {
observed_on_line_state_(false),
observed_save_data_(SaveData::kOff),
callback_count_(0) {}
~StateObserver() = default;

void ConnectionChange(WebConnectionType type,
double max_bandwidth_mbps,
Expand Down Expand Up @@ -411,10 +413,8 @@ TEST_F(NetworkStateNotifierTest, AddObserverWhileNotifying) {
observer1, kWebConnectionTypeBluetooth, kBluetoothMaxBandwidthMbps,
WebEffectiveConnectionType::kTypeUnknown, kUnknownRtt, kUnknownRtt,
kUnknownThroughputMbps, SaveData::kOff));
EXPECT_TRUE(VerifyObservations(
observer2, kWebConnectionTypeBluetooth, kBluetoothMaxBandwidthMbps,
WebEffectiveConnectionType::kTypeUnknown, kUnknownRtt, kUnknownRtt,
kUnknownThroughputMbps, SaveData::kOff));
RunPendingTasks();
EXPECT_EQ(0, observer2.CallbackCount());
}

TEST_F(NetworkStateNotifierTest, RemoveSoleObserverWhileNotifying) {
Expand Down Expand Up @@ -473,60 +473,6 @@ TEST_F(NetworkStateNotifierTest, RemoveCurrentObserverWhileNotifying) {
kUnknownThroughputMbps, SaveData::kOff));
}

TEST_F(NetworkStateNotifierTest, RemovePastObserverWhileNotifying) {
StateObserver observer1, observer2;
std::unique_ptr<NetworkStateNotifier::NetworkStateObserverHandle> handle1 =
notifier_.AddConnectionObserver(&observer1, GetTaskRunner());
std::unique_ptr<NetworkStateNotifier::NetworkStateObserverHandle> handle2 =
notifier_.AddConnectionObserver(&observer2, GetTaskRunner());
observer2.RemoveObserverOnNotification(std::move(handle1));

SetConnection(kWebConnectionTypeBluetooth, kBluetoothMaxBandwidthMbps,
WebEffectiveConnectionType::kTypeUnknown, kUnknownRtt,
kUnknownRtt, kUnknownThroughputMbps, SaveData::kOff);
EXPECT_EQ(observer1.ObservedType(), kWebConnectionTypeBluetooth);
EXPECT_EQ(observer2.ObservedType(), kWebConnectionTypeBluetooth);

SetConnection(kWebConnectionTypeEthernet, kEthernetMaxBandwidthMbps,
WebEffectiveConnectionType::kTypeUnknown, kUnknownRtt,
kUnknownRtt, kUnknownThroughputMbps, SaveData::kOff);
EXPECT_TRUE(VerifyObservations(
observer1, kWebConnectionTypeBluetooth, kBluetoothMaxBandwidthMbps,
WebEffectiveConnectionType::kTypeUnknown, kUnknownRtt, kUnknownRtt,
kUnknownThroughputMbps, SaveData::kOff));
EXPECT_TRUE(VerifyObservations(
observer2, kWebConnectionTypeEthernet, kEthernetMaxBandwidthMbps,
WebEffectiveConnectionType::kTypeUnknown, kUnknownRtt, kUnknownRtt,
kUnknownThroughputMbps, SaveData::kOff));
}

TEST_F(NetworkStateNotifierTest, RemoveFutureObserverWhileNotifying) {
StateObserver observer1, observer2, observer3;
std::unique_ptr<NetworkStateNotifier::NetworkStateObserverHandle> handle1 =
notifier_.AddConnectionObserver(&observer1, GetTaskRunner());
std::unique_ptr<NetworkStateNotifier::NetworkStateObserverHandle> handle2 =
notifier_.AddConnectionObserver(&observer2, GetTaskRunner());
std::unique_ptr<NetworkStateNotifier::NetworkStateObserverHandle> handle3 =
notifier_.AddConnectionObserver(&observer3, GetTaskRunner());
observer1.RemoveObserverOnNotification(std::move(handle2));

SetConnection(kWebConnectionTypeBluetooth, kBluetoothMaxBandwidthMbps,
WebEffectiveConnectionType::kTypeUnknown, kUnknownRtt,
kUnknownRtt, kUnknownThroughputMbps, SaveData::kOff);
EXPECT_TRUE(VerifyObservations(
observer1, kWebConnectionTypeBluetooth, kBluetoothMaxBandwidthMbps,
WebEffectiveConnectionType::kTypeUnknown, kUnknownRtt, kUnknownRtt,
kUnknownThroughputMbps, SaveData::kOff));
EXPECT_TRUE(VerifyObservations(
observer2, kWebConnectionTypeNone, kNoneMaxBandwidthMbps,
WebEffectiveConnectionType::kTypeUnknown, kUnknownRtt, kUnknownRtt,
kUnknownThroughputMbps, SaveData::kOff));
EXPECT_TRUE(VerifyObservations(
observer3, kWebConnectionTypeBluetooth, kBluetoothMaxBandwidthMbps,
WebEffectiveConnectionType::kTypeUnknown, kUnknownRtt, kUnknownRtt,
kUnknownThroughputMbps, SaveData::kOff));
}

// It should be safe to remove multiple observers in one iteration.
TEST_F(NetworkStateNotifierTest, RemoveMultipleObserversWhileNotifying) {
StateObserver observer1, observer2, observer3;
Expand Down Expand Up @@ -1096,4 +1042,49 @@ TEST_F(NetworkStateNotifierTest, SetNetInfoHoldback) {
EXPECT_EQ(0.075, notifier_.GetWebHoldbackDownlinkThroughputMbps().value());
}

// Verify dangling pointer conditions: http://crbug.com/1278708
TEST_F(NetworkStateNotifierTest, RemoveObserverBeforeNotifying) {
base::test::SingleThreadTaskEnvironment task_environment;
scoped_refptr<FakeTaskRunner> task_runner =
base::MakeRefCounted<FakeTaskRunner>();

std::unique_ptr<StateObserver> observer = std::make_unique<StateObserver>();
std::unique_ptr<NetworkStateNotifier::NetworkStateObserverHandle> handle =
notifier_.AddOnLineObserver(observer.get(), task_runner);

SetOnLine(true);
handle.reset();
observer.reset();
task_runner->RunUntilIdle();
}

class OnlineStateObserver : public NetworkStateNotifier::NetworkStateObserver {
public:
void OnLineStateChange(bool on_line) override {
count++;
handle_.reset();
task_runner_->RunUntilIdle();
}
FakeTaskRunner* task_runner_;
std::unique_ptr<NetworkStateNotifier::NetworkStateObserverHandle> handle_;
int count = 0;
};

TEST_F(NetworkStateNotifierTest, RemoveObserverWhileNotifying) {
base::test::SingleThreadTaskEnvironment task_environment;
scoped_refptr<FakeTaskRunner> task_runner =
base::MakeRefCounted<FakeTaskRunner>();

OnlineStateObserver observer;
observer.task_runner_ = task_runner.get();
std::unique_ptr<NetworkStateNotifier::NetworkStateObserverHandle> handle =
notifier_.AddOnLineObserver(&observer, task_runner);
observer.handle_ = std::move(handle);

SetOnLine(true);
SetOnLine(false);
task_runner->RunUntilIdle();
EXPECT_EQ(1, observer.count);
}

} // namespace blink

0 comments on commit 9c3ac3d

Please sign in to comment.