Skip to content

Commit

Permalink
added move support for channels
Browse files Browse the repository at this point in the history
Closes #183
  • Loading branch information
klemens-morgenstern committed May 30, 2024
1 parent 4009635 commit 9502d09
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 12 deletions.
8 changes: 7 additions & 1 deletion include/boost/cobalt/channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ struct channel
struct write_op : intrusive::list_base_hook<intrusive::link_mode<intrusive::auto_unlink> >
{
channel * chn;
variant2::variant<T*, const T*> ref;
using ref_t = std::conditional_t<
std::is_copy_constructible_v<T>,
variant2::variant<T*, const T*>,
T*>;
ref_t ref;
boost::source_location loc;
bool cancelled = false, direct = false;
asio::cancellation_slot cancel_slot{};
Expand Down Expand Up @@ -131,10 +135,12 @@ struct channel
public:
read_op read(const boost::source_location & loc = BOOST_CURRENT_LOCATION) {return read_op{{}, this, loc}; }
write_op write(const T && value, const boost::source_location & loc = BOOST_CURRENT_LOCATION)
requires std::is_copy_constructible_v<T>
{
return write_op{{}, this, &value, loc};
}
write_op write(const T & value, const boost::source_location & loc = BOOST_CURRENT_LOCATION)
requires std::is_copy_constructible_v<T>
{
return write_op{{}, this, &value, loc};
}
Expand Down
35 changes: 25 additions & 10 deletions include/boost/cobalt/impl/channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,15 @@ std::coroutine_handle<void> channel<T>::read_op::await_suspend(std::coroutine_ha
auto & op = chn->write_queue_.front();
op.transactional_unlink();
op.direct = true;
if (op.ref.index() == 0)
direct = std::move(*variant2::get<0>(op.ref));
if constexpr (std::is_copy_constructible_v<T>)
{
if (op.ref.index() == 0)
direct = std::move(*variant2::get<0>(op.ref));
else
direct = *variant2::get<1>(op.ref);
}
else
direct = *variant2::get<1>(op.ref);
direct = std::move(*op.ref);
BOOST_ASSERT(op.awaited_from);
asio::post(chn->executor_, std::move(awaited_from));
return op.awaited_from.release();
Expand Down Expand Up @@ -171,7 +176,7 @@ system::result<T> channel<T>::read_op::await_resume(const struct as_result_tag &
asio::post(chn->executor_, std::move(op.awaited_from));
}
}
return {system::in_place_value, value};
return {system::in_place_value, std::move(value)};
}

template<typename T>
Expand Down Expand Up @@ -213,10 +218,15 @@ std::coroutine_handle<void> channel<T>::write_op::await_suspend(std::coroutine_h
cancel_slot.clear();
auto & op = chn->read_queue_.front();
op.transactional_unlink();
if (ref.index() == 0)
op.direct = std::move(*variant2::get<0>(ref));
if constexpr (std::is_copy_constructible_v<T>)
{
if (ref.index() == 0)
op.direct.emplace(std::move(*variant2::get<0>(ref)));
else
op.direct.emplace(*variant2::get<1>(ref));
}
else
op.direct = *variant2::get<1>(ref);
op.direct.emplace(std::move(*ref));

BOOST_ASSERT(op.awaited_from);
direct = true;
Expand Down Expand Up @@ -250,10 +260,15 @@ system::result<void> channel<T>::write_op::await_resume(const struct as_result_
if (!direct)
{
BOOST_ASSERT(!chn->buffer_.full());
if (ref.index() == 0)
chn->buffer_.push_back(std::move(*variant2::get<0>(ref)));
if constexpr (std::is_copy_constructible_v<T>)
{
if (ref.index() == 0)
chn->buffer_.push_back(std::move(*variant2::get<0>(ref)));
else
chn->buffer_.push_back(*variant2::get<1>(ref));
}
else
chn->buffer_.push_back(*variant2::get<1>(ref));
chn->buffer_.push_back(std::move(*ref));
}

if (!chn->read_queue_.empty())
Expand Down
29 changes: 28 additions & 1 deletion test/channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,34 @@ CO_TEST_CASE(reader)

}

}


BOOST_AUTO_TEST_SUITE_END();

namespace boost::cobalt
{

struct move_only
{
move_only() {}
move_only(move_only &&) {}
move_only& operator=(move_only &&) {return * this;}
};

template struct channel<move_only>;

CO_TEST_CASE(unique)
{
std::unique_ptr<int> p{new int(42)};
auto pi = p.get();
cobalt::channel<std::unique_ptr<int>> c{1u};

co_await c.write(std::move(p));
auto p2 = co_await c.read();

BOOST_CHECK(p == nullptr);
BOOST_CHECK(p2.get() == pi);
}

BOOST_AUTO_TEST_SUITE_END();
}

0 comments on commit 9502d09

Please sign in to comment.