diff --git a/src/mavsdk/core/callback_list.h b/src/mavsdk/core/callback_list.h index 9f8f80b74..1ab95fae9 100644 --- a/src/mavsdk/core/callback_list.h +++ b/src/mavsdk/core/callback_list.h @@ -17,6 +17,7 @@ template class CallbackList { Handle subscribe(const std::function& callback); void unsubscribe(Handle handle); + void subscribe_conditional(const std::function& callback); void operator()(Args... args); [[nodiscard]] bool empty(); void clear(); diff --git a/src/mavsdk/core/callback_list.tpp b/src/mavsdk/core/callback_list.tpp index 51bc4abf0..4ee56b85e 100644 --- a/src/mavsdk/core/callback_list.tpp +++ b/src/mavsdk/core/callback_list.tpp @@ -25,6 +25,12 @@ template void CallbackList::unsubscribe(Handleunsubscribe(handle); } +template +void CallbackList::subscribe_conditional(const std::function& callback) +{ + _impl->subscribe_conditional(callback); +} + template void CallbackList::operator()(Args... args) { _impl->exec(args...); diff --git a/src/mavsdk/core/callback_list_impl.h b/src/mavsdk/core/callback_list_impl.h index ec62ece00..1bfe04ba4 100644 --- a/src/mavsdk/core/callback_list_impl.h +++ b/src/mavsdk/core/callback_list_impl.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -16,15 +17,25 @@ template class CallbackListImpl { Handle subscribe(const std::function& callback) { check_removals(); + process_subscriptions(); // We need to return a handle, even if the callback is nullptr to // unsubscribe. That's fine, the handle just won't remove anything // when/if used later. - auto handle = Handle(_last_id++); + auto handleId = _last_id.fetch_add(1, std::memory_order_relaxed); + auto handle = Handle(handleId); if (callback != nullptr) { - std::lock_guard lock(_mutex); - _list.emplace_back(handle, callback); + if (_mutex.try_lock()) { + // We've acquired the lock without blocking, so we can modify the list. + _list.emplace_back(handle, callback); + _mutex.unlock(); + } else { + // We couldn't acquire the lock because we're likely in a callback. + // Defer the subscription. + std::lock_guard lock(_subscribe_later_mutex); + _subscribe_later.emplace_back(handle, callback); + } } else { LogErr() << "Use new unsubscribe methods instead of subscribe(nullptr)\n" << "See: https://mavsdk.mavlink.io/main/en/cpp/api_changes.html#unsubscribe"; @@ -58,21 +69,70 @@ template class CallbackListImpl { std::lock_guard remove_later_lock(_remove_later_mutex); _remove_later.push_back(handle._id); } + + // check and remove from deferred lock list if present + std::lock_guard later_lock(_subscribe_later_mutex); + _subscribe_later.erase( + std::remove_if( + _subscribe_later.begin(), + _subscribe_later.end(), + [&](const auto& pair) { return pair.first._id == handle._id; }), + _subscribe_later.end()); + } + + /** + * @brief Subscribe a new conditional callback to the list. Conditional callbacks + * automatically unsubscribe if the callback evaluates to true, so the user does not + * have to manage a handle for 'one-shot' callbacks. + * + * @param callback The callback function to subscribe. + * @return void Since removal is handled internally, we dont need to expose the Handle. + */ + void subscribe_conditional(const std::function& callback) + { + check_removals(); + process_subscriptions(); + + if (callback != nullptr) { + if (_mutex.try_lock()) { + _cond_cb_list.emplace_back(callback); + _mutex.unlock(); + } else { + // We couldn't acquire the lock because we're likely in a callback. + // Defer the subscription. + std::lock_guard lock(_subscribe_later_mutex); + _subscribe_later_cond.emplace_back(callback); + } + } else { + try_clear(); + } } void exec(Args... args) { check_removals(); + process_subscriptions(); std::lock_guard lock(_mutex); for (const auto& pair : _list) { pair.second(args...); } + + for (auto it = _cond_cb_list.begin(); it != _cond_cb_list.end();) { + if ((*it)(args...)) { + // If the callback returns true, remove it based on the iterator + it = _cond_cb_list.erase(it); + } else { + // Otherwise, move to the next element + ++it; + } + } } void queue(Args... args, const std::function&)>& queue_func) { check_removals(); + process_subscriptions(); std::lock_guard lock(_mutex); @@ -84,15 +144,17 @@ template class CallbackListImpl { bool empty() { check_removals(); + process_subscriptions(); std::lock_guard lock(_mutex); - return _list.empty(); + return _list.empty() && _cond_cb_list.empty(); } void clear() { std::lock_guard lock(_mutex); - return _list.clear(); + _list.clear(); + _cond_cb_list.clear(); } private: @@ -106,6 +168,7 @@ template class CallbackListImpl { if (_remove_all_later) { _remove_all_later = false; _list.clear(); + _cond_cb_list.clear(); _remove_later.clear(); } @@ -121,6 +184,25 @@ template class CallbackListImpl { } } + void process_subscriptions() + { + std::lock_guard subscribe_later_lock(_subscribe_later_mutex); + + if (_mutex.try_lock()) { + for (const auto& sub : _subscribe_later) { + _list.emplace_back(sub); + } + _subscribe_later.clear(); + + // add conditional callbacks + for (const auto& sub : _subscribe_later_cond) { + _cond_cb_list.emplace_back(sub); + } + _subscribe_later_cond.clear(); + _mutex.unlock(); + } + } + void try_clear() { std::unique_lock lock(_mutex, std::try_to_lock); @@ -133,11 +215,17 @@ template class CallbackListImpl { } mutable std::mutex _mutex{}; - uint64_t _last_id{1}; // Start at 1 because 0 is the "null handle" + std::atomic _last_id{1}; // Start at 1 because 0 is the "null handle" std::vector, std::function>> _list{}; + std::vector> _cond_cb_list{}; mutable std::mutex _remove_later_mutex{}; std::vector _remove_later{}; + + std::mutex _subscribe_later_mutex; + std::vector, std::function>> _subscribe_later; + std::vector> _subscribe_later_cond; + bool _remove_all_later{false}; }; diff --git a/src/mavsdk/core/callback_list_test.cpp b/src/mavsdk/core/callback_list_test.cpp index ce199ccd4..208423466 100644 --- a/src/mavsdk/core/callback_list_test.cpp +++ b/src/mavsdk/core/callback_list_test.cpp @@ -1,3 +1,11 @@ +#include +#include +#include +#include +#include +#include +#include + #include "callback_list.h" #include "callback_list.tpp" #include "log.h" @@ -15,6 +23,7 @@ TEST(CallbackList, SubscribeCallUnsubscribe) { unsigned first_called = 0; unsigned second_called = 0; + unsigned conditional_called = 0; CallbackList cl; auto first_handle = cl.subscribe([&](int i, double d) { @@ -33,15 +42,27 @@ TEST(CallbackList, SubscribeCallUnsubscribe) EXPECT_LT(d, 112.0); }); + cl.subscribe_conditional([&](int i, double d) -> bool { + ++conditional_called; + EXPECT_GE(i, 42); + EXPECT_LE(i, 44); + EXPECT_GT(d, 77.0); + EXPECT_LT(d, 112.0); + // Return true if the callback should be removed after being called. + return i == 43 && d == 99.0; + }); + // Call both a first time. cl(42, 77.7); EXPECT_EQ(first_called, 1); EXPECT_EQ(second_called, 1); + EXPECT_EQ(conditional_called, 1); // Call both a second time. cl(43, 88.8); EXPECT_EQ(first_called, 2); EXPECT_EQ(second_called, 2); + EXPECT_EQ(conditional_called, 2); // Now we unsubscribe the first one. cl.unsubscribe(first_handle); @@ -49,16 +70,33 @@ TEST(CallbackList, SubscribeCallUnsubscribe) cl(43, 99.9); EXPECT_EQ(first_called, 2); EXPECT_EQ(second_called, 3); + EXPECT_EQ(conditional_called, 3); // Unsubscribing the first once should be ignored. cl.unsubscribe(first_handle); + // This should call and remove the conditional callback + cl(43, 99.0); + EXPECT_EQ(first_called, 2); + EXPECT_EQ(second_called, 4); + EXPECT_EQ(conditional_called, 4); + + // This should not call the conditional callback anymore + cl(43, 99.0); + EXPECT_EQ(first_called, 2); + EXPECT_EQ(second_called, 5); + EXPECT_EQ(conditional_called, 4); + // Now we unsubscribe the second one as well, no more calling. cl.unsubscribe(second_handle); cl(44, 111.1); EXPECT_EQ(first_called, 2); - EXPECT_EQ(second_called, 3); + EXPECT_EQ(second_called, 5); + EXPECT_EQ(conditional_called, 4); + + // Both handles are manually removed and conditional callback is autoremoved + EXPECT_TRUE(cl.empty()); } TEST(CallbackList, UnsubscribeFromCallback) @@ -98,3 +136,147 @@ TEST(CallbackList, UnsubscribeAllWithNullptr) // It should only be called once. EXPECT_EQ(num_called, 1); } + +TEST(CallbackList, UnsubscribeAllWithClear) +{ + // This is to deal with the previous API where nullptr would + // unsubscribe the callback. + unsigned num_called = 0; + unsigned num_called_other = 0; + + CallbackList<> cl; + cl.subscribe([&]() { ++num_called; }); + cl.subscribe([&]() { ++num_called_other; }); + + // Call once. + cl(); + + // Unsubscribe using clear. + cl.clear(); + + // Call again. + cl(); + + // It should only be called once. + EXPECT_EQ(num_called, 1); + EXPECT_EQ(num_called_other, 1); +} + +TEST(CallbackList, SubscribeAndUnsubscribeWithinCallbacks) +{ + const int test_value1 = 42; + const double test_value2 = 3.14; + const double unsub_value = 4.669; + std::atomic callback_count{0}; + std::atomic nested_callback_count{0}; + + CallbackList cl; + + // Lambda function for subscribing within a callback + auto subscribeCallback = [test_value1, test_value2, unsub_value, &nested_callback_count, &cl]( + std::shared_ptr>> handle_ptr) { + *handle_ptr = cl.subscribe( + [&nested_callback_count, &cl, test_value1, test_value2, unsub_value, handle_ptr]( + int i, double d) { + nested_callback_count++; + // Unsubscribe after execution if unsub_value provided + if (d == unsub_value) { + if (handle_ptr->has_value()) { + cl.unsubscribe(handle_ptr->value()); // this will get deferred unsubscription + } + } + }); + }; + + // Main thread subscription + auto first_handle = cl.subscribe( + [&subscribeCallback, &callback_count, &cl, test_value1, test_value2](int i, double d) { + callback_count++; + + // Define a placeholder for the subscription handle + auto handle_ptr = std::make_shared>>(); + subscribeCallback(handle_ptr); // this will get deferred sub + }); + + EXPECT_EQ(callback_count, 0); + EXPECT_EQ(nested_callback_count, 0); + + cl(test_value1, test_value2); // Calls 1st callback that adds 2nd callback + EXPECT_EQ(callback_count, 1); + EXPECT_EQ(nested_callback_count, 0); + + cl(test_value1, test_value2); // Calls 1st callback that adds 3rd callback, calls 2nd callback + EXPECT_EQ(callback_count, 2); + EXPECT_EQ(nested_callback_count, 1); + + // Calls 1st callback that adds 4th callback, calls 2nd callback which increments and unsubscribes, + // calls the 3rd callback which increments and unsubscribes + cl(test_value1, unsub_value); + EXPECT_EQ(callback_count, 3); + EXPECT_EQ(nested_callback_count, 3); + + // Remove the 1st callback + cl.unsubscribe(first_handle); + + // Calls the only remaining 4th callback that increments and unsubscribes + cl(test_value1, unsub_value); + EXPECT_EQ(callback_count, 3); + EXPECT_EQ(nested_callback_count, 4); + + // List is now empty + EXPECT_TRUE(cl.empty()); +} + +TEST(CallbackList, SubscribeAndUnsubscribeWithinCallbacks2) +{ + const int test_value1 = 42; + const double test_value2 = 3.14; + const int thread_count = 50000; + std::atomic callback_count{0}; + std::atomic nested_callback_count{0}; + CallbackList cl; + + // Define the vector type + using HandleType = Handle; + + // Create a vector of std::shared_ptr>> + std::set handle_set; + std::mutex mutex_; + + // Lambda function to subscribe a callback as assign the handle to the arg + auto addSubCallback = [test_value1, test_value2, &nested_callback_count, &cl]( + std::shared_ptr>> handle_ptr) { + *handle_ptr = cl.subscribe( + [&nested_callback_count, &cl, test_value1, test_value2, handle_ptr]( + int i, double d) { + EXPECT_EQ(i, test_value1); + EXPECT_DOUBLE_EQ(d, test_value2); + nested_callback_count.fetch_add(1, std::memory_order_relaxed); + }); + }; + + // Create multiple threads to simulate concurrent notifications + std::vector threads; + for (int i = 0; i < thread_count; ++i) { + threads.emplace_back([&]() { + // Add a burst of subscriptions in a multithreaded env all at once + auto handle_ptr = std::make_shared>>(); + addSubCallback(handle_ptr); // this will get deferred sub + + // safely put them in a vec for later comparison + std::lock_guard guard(mutex_); + if (handle_ptr && handle_ptr->has_value()) { + if (!handle_set.insert(handle_ptr->value()).second) { + EXPECT_TRUE(false); // handles are not unique + } + } else { + EXPECT_TRUE(false); // handles not filled in + } + }); + } + + // Join all the threads + for (auto& thread : threads) { + thread.join(); + } +} \ No newline at end of file diff --git a/src/mavsdk/core/include/mavsdk/handle.h b/src/mavsdk/core/include/mavsdk/handle.h index 6c67c32f5..c77396478 100644 --- a/src/mavsdk/core/include/mavsdk/handle.h +++ b/src/mavsdk/core/include/mavsdk/handle.h @@ -17,6 +17,8 @@ template class Handle { Handle() = default; ~Handle() = default; + bool operator<(const Handle& other) const { return _id < other._id; } + private: bool operator==(const Handle& other) const { return _id == other._id; }