Skip to content

Commit

Permalink
channel: unify the return values of try_* and *
Browse files Browse the repository at this point in the history
  • Loading branch information
sebsura authored and BareosBot committed Sep 22, 2023
1 parent 244c621 commit 706c47a
Showing 1 changed file with 127 additions and 107 deletions.
234 changes: 127 additions & 107 deletions core/src/lib/channel.h
Expand Up @@ -39,7 +39,6 @@ namespace channel {
// This ensures that there is only one producer (who writes to the input)
// and one consumer (who reads from the output).

struct failed_to_acquire_lock {};
struct channel_closed {};

template <typename T> class queue {
Expand Down Expand Up @@ -74,12 +73,22 @@ template <typename T> class queue {
}
handle(const handle&) = delete;
handle& operator=(const handle&) = delete;
handle(handle&&) = delete;
handle& operator=(handle&&) = delete;
handle(handle&& that) : locked{std::move(that.locked)}, update{that.update}
{
that.update = nullptr;
}
handle& operator=(handle&& that)
{
locked = std::move(that.locked);
update = std::exchange(that.update, nullptr);
}

std::vector<T>& data() { return locked->data; }

~handle() { update->notify_one(); }
~handle()
{
if (update) { update->notify_one(); }
}
};

/* the *_lock functions return std::nullopt only if the channel is closed
Expand All @@ -88,54 +97,35 @@ template <typename T> class queue {
* they succeeded, failed to acquire the lock or if the channel was closed
* from the other side. */

std::optional<handle> output_lock()
using result_type = std::variant<handle, channel_closed>;

result_type output_lock()
{
auto locked = shared.lock();
if (locked->out_dead) {
// note(ssura): This happening is programmer error.
// Maybe we should assert this instead ?
Dmsg0(50,
"Tried to read from channel that was closed from the read side.\n");
return std::nullopt;
return channel_closed{};
}

locked.wait(in_update, [](const auto& intern) {
return intern.data.size() > 0 || intern.in_dead;
locked.wait(in_update, [](const auto& queue) {
return queue.data.size() > 0 || queue.in_dead;
});
if (locked->data.size() == 0) {
return std::nullopt;
} else {
return std::make_optional<handle>(std::move(locked), &out_update);
}
}

std::optional<handle> input_lock()
{
auto locked = shared.lock();
locked.wait(out_update, [max_size = max_size](const auto& intern) {
return intern.data.size() < max_size || intern.out_dead;
});
if (locked->in_dead) {
// note(ssura): This happening is programmer error.
// Maybe we should assert this instead ?
Dmsg0(50,
"Tried to write to channel that was closed from the write side.\n");
return std::nullopt;
}
if (locked->out_dead) {
return std::nullopt;
if (locked->data.size() == 0) {
return channel_closed{};
} else {
return std::make_optional<handle>(std::move(locked), &in_update);
return result_type(std::in_place_type<handle>, std::move(locked),
&out_update);
}
}

using try_result
= std::variant<handle, failed_to_acquire_lock, channel_closed>;

try_result try_output_lock()
std::optional<result_type> try_output_lock()
{
auto locked = shared.try_lock();
if (!locked) { return failed_to_acquire_lock{}; }
if (!locked) { return std::nullopt; }
if (locked.value()->out_dead) {
// note(ssura): This happening is programmer error.
// Maybe we should assert this instead ?
Expand All @@ -147,18 +137,39 @@ template <typename T> class queue {
if (locked.value()->in_dead) {
return channel_closed{};
} else {
return failed_to_acquire_lock{};
return std::nullopt;
}
}

return try_result(std::in_place_type<handle>, std::move(locked).value(),
&out_update);
return result_type(std::in_place_type<handle>, std::move(locked).value(),
&out_update);
}

result_type input_lock()
{
auto locked = shared.lock();
locked.wait(out_update, [max_size = max_size](const auto& queue) {
return queue.data.size() < max_size || queue.out_dead;
});
if (locked->in_dead) {
// note(ssura): This happening is programmer error.
// Maybe we should assert this instead ?
Dmsg0(50,
"Tried to write to channel that was closed from the write side.\n");
return channel_closed{};
}
if (locked->out_dead) {
return channel_closed{};
} else {
return result_type(std::in_place_type<handle>, std::move(locked),
&in_update);
}
}

try_result try_input_lock()
std::optional<result_type> try_input_lock()
{
auto locked = shared.try_lock();
if (!locked) { return failed_to_acquire_lock{}; }
if (!locked) { return std::nullopt; }
if (locked.value()->in_dead) {
// note(ssura): This happening is programmer error.
// Maybe we should assert this instead ?
Expand All @@ -167,12 +178,10 @@ template <typename T> class queue {
return channel_closed{};
}
if (locked.value()->out_dead) { return channel_closed{}; }
if (locked.value()->data.size() >= max_size) {
return failed_to_acquire_lock{};
}
if (locked.value()->data.size() >= max_size) { return std::nullopt; }

return try_result(std::in_place_type<handle>, std::move(locked).value(),
&in_update);
return result_type(std::in_place_type<handle>, std::move(locked).value(),
&in_update);
}

void close_in()
Expand All @@ -191,6 +200,8 @@ template <typename T> class queue {
template <typename T> class input {
std::shared_ptr<queue<T>> shared;
bool did_close{false};
using handle_type = typename queue<T>::handle;
using result_type = typename queue<T>::result_type;

public:
explicit input(std::shared_ptr<queue<T>> shared) : shared{std::move(shared)}
Expand All @@ -205,29 +216,19 @@ template <typename T> class input {
template <typename... Args> bool emplace(Args... args)
{
if (did_close) { return false; }
if (auto handle = shared->input_lock()) {
handle->data().emplace_back(std::forward<Args>(args)...);
return true;
} else {
close();
return false;
}
auto result = shared->input_lock();
return do_emplace(result, std::forward<Args>(args)...);
}

template <typename... Args> bool try_emplace(Args... args)
{
if (did_close) { return false; }
auto result = shared->try_input_lock();
if (std::holds_alternative<failed_to_acquire_lock>(result)) {
return false;
} else if (std::holds_alternative<channel_closed>(result)) {
close();
return false;
} else {
std::get<typename queue<T>::handle>(result).data().emplace_back(
std::forward<Args>(args)...);
return true;

if (auto result = shared->try_input_lock()) {
return do_emplace(result.value(), std::forward<Args>(args)...);
}

return false;
}

void close()
Expand All @@ -244,13 +245,41 @@ template <typename T> class input {
{
if (shared) { close(); }
}

private:
template <typename... Args>
inline bool do_emplace(result_type& result, Args... args)
{
return std::visit(
[this, &args...](auto&& val) {
using val_type = std::decay_t<decltype(val)>;
if constexpr (std::is_same_v<val_type, channel_closed>) {
close();
return false;
} else if constexpr (std::is_same_v<val_type, handle_type>) {
val.data().emplace_back(std::forward<Args>(args)...);
return true;
} else {
static_assert("Type not handled");
}
},
result);
}
};

template <typename T> class output {
std::shared_ptr<queue<T>> shared;
std::vector<T> cache{};
typename decltype(cache)::iterator cache_iter = cache.begin();
bool did_close{false};
using handle_type = typename queue<T>::handle;
using result_type = typename queue<T>::result_type;

enum class with_lock
{
No,
Yes,
};

public:
explicit output(std::shared_ptr<queue<T>> shared) : shared{std::move(shared)}
Expand All @@ -262,31 +291,9 @@ template <typename T> class output {
output(const output&) = delete;
output& operator=(const output&) = delete;

std::optional<T> get()
{
if (did_close) { return std::nullopt; }
update_cache();
std::optional<T> get() { return get_internal(with_lock::Yes); }

if (cache_iter != cache.end()) {
std::optional result = std::make_optional<T>(std::move(*cache_iter++));
return result;
} else {
return std::nullopt;
}
}

std::optional<T> try_get()
{
if (did_close) { return std::nullopt; }
try_update_cache();

if (cache_iter != cache.end()) {
std::optional result = std::make_optional<T>(std::move(*cache_iter++));
return result;
} else {
return std::nullopt;
}
}
std::optional<T> try_get() { return get_internal(with_lock::No); }

void close()
{
Expand All @@ -306,36 +313,49 @@ template <typename T> class output {
}

private:
void do_update_cache(std::vector<T>& data)
std::optional<T> get_internal(with_lock lock)
{
cache.clear();
std::swap(data, cache);
cache_iter = cache.begin();
if (did_close) { return std::nullopt; }
update_cache(lock);

if (cache_iter != cache.end()) {
std::optional result = std::make_optional<T>(std::move(*cache_iter++));
return result;
} else {
return std::nullopt;
}
}

void update_cache()
inline void do_update_cache(result_type& result)
{
if (cache_iter == cache.end()) {
if (auto handle = shared->output_lock()) {
do_update_cache(handle->data());
} else {
// this can only happen if the channel was closed.
close();
}
}
std::visit(
[this](auto&& val) {
using val_type = std::decay_t<decltype(val)>;
if constexpr (std::is_same_v<val_type, channel_closed>) {
close();
} else if constexpr (std::is_same_v<val_type, handle_type>) {
cache.clear();
std::swap(cache, val.data());
cache_iter = cache.begin();
} else {
static_assert("Type not handled");
}
},
result);
}

void try_update_cache()
void update_cache(with_lock lock)
{
if (cache_iter == cache.end()) {
auto result = shared->try_output_lock();
if (std::holds_alternative<failed_to_acquire_lock>(result)) {
// intentionally left empty
} else if (std::holds_alternative<channel_closed>(result)) {
close();
using handle_t = typename queue<T>::handle;
std::optional<handle_t> handle;
if (lock == with_lock::No) {
if (auto result = shared->try_output_lock()) {
do_update_cache(result.value());
}
} else {
auto& handle = std::get<typename queue<T>::handle>(result);
do_update_cache(handle.data());
auto result = shared->output_lock();
do_update_cache(result);
}
}
}
Expand Down

0 comments on commit 706c47a

Please sign in to comment.