diff --git a/include/condy/channel.hpp b/include/condy/channel.hpp index eadc51aa..9d441f79 100644 --- a/include/condy/channel.hpp +++ b/include/condy/channel.hpp @@ -368,7 +368,7 @@ class Channel::PushFinishHandle // Successfully canceled assert(result_ == -ENOTRECOVERABLE); result_ = -ECANCELED; - runtime_->resume_work(); + need_resume_ = true; runtime_->schedule(this); } } @@ -427,7 +427,7 @@ class Channel::PopFinishHandle // Successfully canceled assert(result_.first == -ENOTRECOVERABLE); result_.first = -ECANCELED; - runtime_->resume_work(); + need_resume_ = true; runtime_->schedule(this); } } diff --git a/include/condy/runtime.hpp b/include/condy/runtime.hpp index db09531f..940b6cb2 100644 --- a/include/condy/runtime.hpp +++ b/include/condy/runtime.hpp @@ -195,7 +195,8 @@ class Runtime { // Fast path: if the ring is enabled, we can directly schedule the // work tsan_release(work); - schedule_msg_ring_(curr_runtime, work, WorkType::Schedule); + schedule_msg_ring_(curr_runtime, + encode_work(work, WorkType::Schedule)); } else { // Slow path: if the ring is not enabled, we need to acquire the // mutex to ensure the work is scheduled before the ring is enabled @@ -204,13 +205,33 @@ class Runtime { if (state == State::Enabled) { lock.unlock(); tsan_release(work); - schedule_msg_ring_(curr_runtime, work, WorkType::Schedule); + schedule_msg_ring_(curr_runtime, + encode_work(work, WorkType::Schedule)); } else { global_queue_.push_back(work); } } } + // Internal use only. Schedule a cancel request for the given data. + void cancel(void *data) noexcept { + // Ensure align of 8 for encoding + assert(reinterpret_cast(data) % 8 == 0); + auto *curr_runtime = detail::Context::current().runtime(); + if (curr_runtime == this) { + io_uring_sqe *sqe = ring_.get_sqe(); + prep_cancel_(sqe, data); + return; + } + + auto state = state_.load(); + if (state != State::Enabled) { + return; + } + + schedule_msg_ring_(curr_runtime, encode_work(data, WorkType::Cancel)); + } + void pend_work() noexcept { pending_works_++; } void resume_work() noexcept { pending_works_--; } @@ -291,15 +312,15 @@ class Runtime { auto &settings() noexcept { return ring_.settings(); } private: - void schedule_msg_ring_(Runtime *curr_runtime, WorkInvoker *work, - WorkType type) noexcept { + void schedule_msg_ring_(Runtime *curr_runtime, void *data) noexcept { + int ring_fd = this->ring_.ring()->ring_fd; if (curr_runtime != nullptr) { io_uring_sqe *sqe = curr_runtime->ring_.get_sqe(); - prep_msg_ring_(sqe, work, type); + prep_msg_ring_(ring_fd, sqe, data); curr_runtime->pend_work(); } else { io_uring_sqe sqe = {}; - prep_msg_ring_(&sqe, work, type); + prep_msg_ring_(ring_fd, &sqe, data); int r = detail::sync_msg_ring(&sqe); if (r < 0) { panic_on(std::format("sync_msg_ring: {}", std::strerror(-r))); @@ -319,21 +340,27 @@ class Runtime { return; } - schedule_msg_ring_(curr_runtime, nullptr, WorkType::Ignore); + schedule_msg_ring_(curr_runtime, + encode_work(nullptr, WorkType::Ignore)); } void flush_global_queue_() noexcept { local_queue_.push_back(std::move(global_queue_)); } - void prep_msg_ring_(io_uring_sqe *sqe, WorkInvoker *work, - WorkType type) noexcept { - auto data = encode_work(work, type); - io_uring_prep_msg_ring(sqe, this->ring_.ring()->ring_fd, 0, + static void prep_msg_ring_(int ring_fd, io_uring_sqe *sqe, + void *data) noexcept { + io_uring_prep_msg_ring(sqe, ring_fd, 0, reinterpret_cast(data), 0); io_uring_sqe_set_data(sqe, encode_work(nullptr, WorkType::Schedule)); } + static void prep_cancel_(io_uring_sqe *sqe, void *data) noexcept { + io_uring_prep_cancel(sqe, data, 0); + io_uring_sqe_set_data(sqe, encode_work(nullptr, WorkType::Ignore)); + io_uring_sqe_set_flags(sqe, IOSQE_CQE_SKIP_SUCCESS); + } + void flush_ring_() noexcept { auto r = ring_.reap_completions( [this](io_uring_cqe *cqe) { process_cqe_(cqe); }); @@ -371,6 +398,9 @@ class Runtime { tsan_acquire(data); (*work)(); } + } else if (type == WorkType::Cancel) { + io_uring_sqe *sqe = ring_.get_sqe(); + prep_cancel_(sqe, data); } else if (type == WorkType::Common) { auto *handle = static_cast(data); auto op_finish = handle->handle(cqe); diff --git a/include/condy/work_type.hpp b/include/condy/work_type.hpp index 4b5de756..2f582935 100644 --- a/include/condy/work_type.hpp +++ b/include/condy/work_type.hpp @@ -14,7 +14,13 @@ enum class WorkType : uint8_t { Common, Ignore, Schedule, + Cancel, + + // Add new work types above this line + WorkTypeMax, }; +static_assert(static_cast(WorkType::WorkTypeMax) <= 8, + "WorkType must fit in 3 bits"); inline std::pair decode_work(void *ptr) noexcept { intptr_t mask = (1 << 3) - 1; diff --git a/tests/test_runtime.cpp b/tests/test_runtime.cpp index ca983974..a4af6743 100644 --- a/tests/test_runtime.cpp +++ b/tests/test_runtime.cpp @@ -236,5 +236,70 @@ TEST_CASE("test runtime - allow_exit from other thread") { runtime.allow_exit(); + t1.join(); +} + +TEST_CASE("test runtime - cancel from other task") { + condy::Runtime runtime(options); + + auto cancel_task = [&](void *ptr) -> condy::Coro { + runtime.cancel(ptr); + co_return; + }; + auto func = [&]() -> condy::Coro { + __kernel_timespec ts{ + .tv_sec = 60ll * 60ll, + .tv_nsec = 0, + }; + auto aw = + condy::detail::make_op_awaiter(io_uring_prep_timeout, &ts, 0, 0); + void *ptr = aw.get_handle(); + auto t = condy::co_spawn(cancel_task(ptr)); + co_await aw; + co_await t; + }; + + condy::co_spawn(runtime, func()).detach(); + + runtime.allow_exit(); + runtime.run(); +} + +TEST_CASE("test runtime - cancel from other thread") { + condy::Runtime runtime(options); + + std::atomic_bool r1_started = false; + void *ptr = nullptr; + + auto notify_task = [&]() -> condy::Coro { + r1_started = true; + r1_started.notify_one(); + co_return; + }; + + auto func = [&]() -> condy::Coro { + __kernel_timespec ts{ + .tv_sec = 60ll * 60ll, + .tv_nsec = 0, + }; + auto aw = + condy::detail::make_op_awaiter(io_uring_prep_timeout, &ts, 0, 0); + ptr = aw.get_handle(); + auto t = condy::co_spawn(runtime, notify_task()); + co_await aw; + co_await t; + }; + + condy::co_spawn(runtime, func()).detach(); + + std::thread t1([&]() { + runtime.allow_exit(); + runtime.run(); + }); + + r1_started.wait(false); + + runtime.cancel(ptr); + t1.join(); } \ No newline at end of file diff --git a/tests/test_work_type.cpp b/tests/test_work_type.cpp index 8765f620..b57e6baf 100644 --- a/tests/test_work_type.cpp +++ b/tests/test_work_type.cpp @@ -26,4 +26,5 @@ TEST_CASE("test work_type - encode and decode") { test_type(condy::WorkType::Common); test_type(condy::WorkType::Ignore); test_type(condy::WorkType::Schedule); + test_type(condy::WorkType::Cancel); } \ No newline at end of file