Skip to content

Commit

Permalink
Add conditional callback, use atomic handleid and add some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
devbharat committed May 20, 2024
1 parent 10876aa commit d0be5a8
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/mavsdk/core/callback_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ template<typename... Args> class CallbackList {

Handle<Args...> subscribe(const std::function<void(Args...)>& callback);
void unsubscribe(Handle<Args...> handle);
void subscribe_conditional(const std::function<bool(Args...)>& callback);
void operator()(Args... args);
[[nodiscard]] bool empty();
void clear();
Expand Down
6 changes: 6 additions & 0 deletions src/mavsdk/core/callback_list.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ template<typename... Args> void CallbackList<Args...>::unsubscribe(Handle<Args..
_impl->unsubscribe(handle);
}

template<typename... Args>
void CallbackList<Args...>::subscribe_conditional(const std::function<bool(Args...)>& callback)
{
_impl->subscribe_conditional(callback);
}

template<typename... Args> void CallbackList<Args...>::operator()(Args... args)
{
_impl->exec(args...);
Expand Down
100 changes: 94 additions & 6 deletions src/mavsdk/core/callback_list_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cstdint>
#include <functional>
#include <atomic>
#include <memory>
#include <mutex>
#include <utility>
Expand All @@ -16,15 +17,25 @@ template<typename... Args> class CallbackListImpl {
Handle<Args...> subscribe(const std::function<void(Args...)>& callback)
{
check_removals();
process_subscriptions();

// We need to return a handle, even if the callback is nullptr to
// unsubscribe. That's fine, the handle just won't remove anything
// when/if used later.
auto handle = Handle<Args...>(_last_id++);
auto handleId = _last_id.fetch_add(1, std::memory_order_relaxed);
auto handle = Handle<Args...>(handleId);

if (callback != nullptr) {
std::lock_guard<std::mutex> lock(_mutex);
_list.emplace_back(handle, callback);
if (_mutex.try_lock()) {
// We've acquired the lock without blocking, so we can modify the list.
_list.emplace_back(handle, callback);
_mutex.unlock();
} else {
// We couldn't acquire the lock because we're likely in a callback.
// Defer the subscription.
std::lock_guard<std::mutex> lock(_subscribe_later_mutex);
_subscribe_later.emplace_back(handle, callback);
}
} else {
LogErr() << "Use new unsubscribe methods instead of subscribe(nullptr)\n"
<< "See: https://mavsdk.mavlink.io/main/en/cpp/api_changes.html#unsubscribe";
Expand Down Expand Up @@ -58,21 +69,70 @@ template<typename... Args> class CallbackListImpl {
std::lock_guard<std::mutex> remove_later_lock(_remove_later_mutex);
_remove_later.push_back(handle._id);
}

// check and remove from deferred lock list if present
std::lock_guard<std::mutex> later_lock(_subscribe_later_mutex);
_subscribe_later.erase(
std::remove_if(
_subscribe_later.begin(),
_subscribe_later.end(),
[&](const auto& pair) { return pair.first._id == handle._id; }),
_subscribe_later.end());
}

/**
* @brief Subscribe a new conditional callback to the list. Conditional callbacks
* automatically unsubscribe if the callback evaluates to true, so the user does not
* have to manage a handle for 'one-shot' callbacks.
*
* @param callback The callback function to subscribe.
* @return void Since removal is handled internally, we dont need to expose the Handle.
*/
void subscribe_conditional(const std::function<bool(Args...)>& callback)
{
check_removals();
process_subscriptions();

if (callback != nullptr) {
if (_mutex.try_lock()) {
_cond_cb_list.emplace_back(callback);
_mutex.unlock();
} else {
// We couldn't acquire the lock because we're likely in a callback.
// Defer the subscription.
std::lock_guard<std::mutex> lock(_subscribe_later_mutex);
_subscribe_later_cond.emplace_back(callback);
}
} else {
try_clear();
}
}

void exec(Args... args)
{
check_removals();
process_subscriptions();

std::lock_guard<std::mutex> lock(_mutex);
for (const auto& pair : _list) {
pair.second(args...);
}

for (auto it = _cond_cb_list.begin(); it != _cond_cb_list.end();) {
if ((*it)(args...)) {
// If the callback returns true, remove it based on the iterator
it = _cond_cb_list.erase(it);
} else {
// Otherwise, move to the next element
++it;
}
}
}

void queue(Args... args, const std::function<void(const std::function<void()>&)>& queue_func)
{
check_removals();
process_subscriptions();

std::lock_guard<std::mutex> lock(_mutex);

Expand All @@ -84,15 +144,17 @@ template<typename... Args> class CallbackListImpl {
bool empty()
{
check_removals();
process_subscriptions();

std::lock_guard<std::mutex> lock(_mutex);
return _list.empty();
return _list.empty() && _cond_cb_list.empty();
}

void clear()
{
std::lock_guard<std::mutex> lock(_mutex);
return _list.clear();
_list.clear();
_cond_cb_list.clear();
}

private:
Expand All @@ -106,6 +168,7 @@ template<typename... Args> class CallbackListImpl {
if (_remove_all_later) {
_remove_all_later = false;
_list.clear();
_cond_cb_list.clear();
_remove_later.clear();
}

Expand All @@ -121,6 +184,25 @@ template<typename... Args> class CallbackListImpl {
}
}

void process_subscriptions()
{
std::lock_guard<std::mutex> subscribe_later_lock(_subscribe_later_mutex);

if (_mutex.try_lock()) {
for (const auto& sub : _subscribe_later) {
_list.emplace_back(sub);
}
_subscribe_later.clear();

// add conditional callbacks
for (const auto& sub : _subscribe_later_cond) {
_cond_cb_list.emplace_back(sub);
}
_subscribe_later_cond.clear();
_mutex.unlock();
}
}

void try_clear()
{
std::unique_lock<std::mutex> lock(_mutex, std::try_to_lock);
Expand All @@ -133,11 +215,17 @@ template<typename... Args> class CallbackListImpl {
}

mutable std::mutex _mutex{};
uint64_t _last_id{1}; // Start at 1 because 0 is the "null handle"
std::atomic<uint64_t> _last_id{1}; // Start at 1 because 0 is the "null handle"
std::vector<std::pair<Handle<Args...>, std::function<void(Args...)>>> _list{};
std::vector<std::function<bool(Args...)>> _cond_cb_list{};

mutable std::mutex _remove_later_mutex{};
std::vector<uint64_t> _remove_later{};

std::mutex _subscribe_later_mutex;
std::vector<std::pair<Handle<Args...>, std::function<void(Args...)>>> _subscribe_later;
std::vector<std::function<bool(Args...)>> _subscribe_later_cond;

bool _remove_all_later{false};
};

Expand Down
Loading

0 comments on commit d0be5a8

Please sign in to comment.