diff --git a/include/condy/task.hpp b/include/condy/task.hpp index 0920b4cd..9deb4612 100644 --- a/include/condy/task.hpp +++ b/include/condy/task.hpp @@ -22,13 +22,22 @@ template class TaskBase { public: using PromiseType = typename Coro::promise_type; + TaskBase() : TaskBase(nullptr) {} TaskBase(std::coroutine_handle h) : handle_(h) {} TaskBase(TaskBase &&other) noexcept : handle_(std::exchange(other.handle_, nullptr)) {} + TaskBase &operator=(TaskBase &&other) noexcept { + if (this != &other) { + if (handle_) { + panic_on("Task destroyed without being awaited"); + } + handle_ = std::exchange(other.handle_, nullptr); + } + return *this; + } TaskBase(const TaskBase &) = delete; TaskBase &operator=(const TaskBase &) = delete; - TaskBase &operator=(TaskBase &&other) = delete; ~TaskBase() { if (handle_) { @@ -49,9 +58,16 @@ template class TaskBase { handle_ = nullptr; } + /** + * @brief Check if the task is still joinable. Similar to + * `std::thread::joinable()`. + */ + bool joinable() const noexcept { return handle_ != nullptr; } + /** * @brief Await the task asynchronously. * @return T The result of the coroutine. + * @throw std::invalid_argument If the task is not joinable. * @throws Any exception thrown inside the coroutine. * @details This function allows the caller to await the completion of the * coroutine associated with the task. It suspends the caller coroutine @@ -73,6 +89,9 @@ void TaskBase::wait_inner_( if (detail::Context::current().runtime() != nullptr) [[unlikely]] { throw std::logic_error("Sync wait inside runtime"); } + if (handle == nullptr) [[unlikely]] { + throw std::invalid_argument("Task not joinable"); + } std::promise prom; auto fut = prom.get_future(); struct TaskWaiter : public InvokerAdapter { @@ -112,6 +131,7 @@ class [[nodiscard]] Task : public TaskBase { /** * @brief Wait synchronously for the task to complete and get the result. * @return T The result of the coroutine. + * @throws std::invalid_argument If the task is not joinable. * @throws Any exception thrown inside the coroutine. * @details This function blocks the current thread until the coroutine * associated with the task completes. It then retrieves the result of the @@ -140,6 +160,7 @@ class [[nodiscard]] Task : public TaskBase { /** * @brief Wait synchronously for the task to complete. + * @throws std::invalid_argument If the task is not joinable. * @throws Any exception thrown inside the coroutine. * @details This function blocks the current thread until the coroutine * associated with the task completes. If the coroutine throws an exception, @@ -164,7 +185,12 @@ struct TaskAwaiterBase : public InvokerAdapter> { Runtime *runtime) : task_handle_(task_handle), runtime_(runtime) {} - bool await_ready() const noexcept { return false; } + bool await_ready() const { + if (task_handle_ == nullptr) { + throw std::invalid_argument("Task not joinable"); + } + return false; + } template bool diff --git a/tests/test_task.cpp b/tests/test_task.cpp index 28ad3243..4e175403 100644 --- a/tests/test_task.cpp +++ b/tests/test_task.cpp @@ -3,6 +3,7 @@ #include "condy/runtime_options.hpp" #include "condy/task.hpp" #include +#include #include namespace { @@ -11,6 +12,47 @@ condy::RuntimeOptions options = condy::RuntimeOptions().sq_size(8).cq_size(16); } // namespace +TEST_CASE("test task - construct") { + condy::Runtime runtime(options); + auto func = []() -> condy::Coro { co_return; }; + + condy::Task task; + REQUIRE(!task.joinable()); + + auto task2 = condy::co_spawn(runtime, func()); + REQUIRE(task2.joinable()); + + task = std::move(task2); + REQUIRE(task.joinable()); + // NOLINTNEXTLINE(bugprone-use-after-move) + REQUIRE(!task2.joinable()); + + task.detach(); + REQUIRE(!task.joinable()); +} + +TEST_CASE("test task - joinable check") { + condy::Runtime runtime(options); + std::thread rt_thread([&]() { runtime.run(); }); + + condy::Task task; + REQUIRE(!task.joinable()); + REQUIRE_THROWS_AS(task.wait(), std::invalid_argument); + + auto func = [&]() -> condy::Coro { + REQUIRE(!task.joinable()); + REQUIRE_THROWS_AS((co_await task), std::invalid_argument); + co_return; + }; + + auto task2 = condy::co_spawn(runtime, func()); + REQUIRE(task2.joinable()); + REQUIRE_NOTHROW(task2.wait()); + + runtime.allow_exit(); + rt_thread.join(); +} + TEST_CASE("test task - local spawn and await") { condy::Runtime runtime(options); bool finished = false;