diff --git a/compio-runtime/Cargo.toml b/compio-runtime/Cargo.toml index a1f780e7..545b0326 100644 --- a/compio-runtime/Cargo.toml +++ b/compio-runtime/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "compio-runtime" -version = "0.9.4" +version = "0.9.5" description = "High-level runtime for compio" categories = ["asynchronous"] keywords = ["async", "runtime"] diff --git a/compio-runtime/src/runtime/scheduler/mod.rs b/compio-runtime/src/runtime/scheduler/mod.rs index 8f4ab45a..3afd9e3c 100644 --- a/compio-runtime/src/runtime/scheduler/mod.rs +++ b/compio-runtime/src/runtime/scheduler/mod.rs @@ -1,4 +1,11 @@ -use std::{cell::RefCell, future::Future, marker::PhantomData, rc::Rc, sync::Arc, task::Waker}; +use std::{ + cell::RefCell, + future::Future, + marker::PhantomData, + rc::Rc, + sync::{Arc, Weak}, + task::Waker, +}; use async_task::{Runnable, Task}; use compio_driver::NotifyHandle; @@ -15,31 +22,16 @@ mod send_wrapper; /// A task queue consisting of a local queue and a synchronized queue. struct TaskQueue { - local_queue: SendWrapper>, - sync_queue: SegQueue, + local_queue: Arc>>, + sync_queue: Arc>, } impl TaskQueue { /// Creates a new `TaskQueue`. fn new() -> Self { Self { - local_queue: SendWrapper::new(LocalQueue::new()), - sync_queue: SegQueue::new(), - } - } - - /// Pushes a `Runnable` task to the appropriate queue. - /// - /// If the current thread is the same as the creator thread, push to the - /// local queue. Otherwise, push to the sync queue. - fn push(&self, runnable: Runnable, notify: &NotifyHandle) { - if let Some(local_queue) = self.local_queue.get() { - local_queue.push(runnable); - #[cfg(feature = "notify-always")] - notify.notify().ok(); - } else { - self.sync_queue.push(runnable); - notify.notify().ok(); + local_queue: Arc::new(SendWrapper::new(LocalQueue::new())), + sync_queue: Arc::new(SegQueue::new()), } } @@ -94,12 +86,52 @@ impl TaskQueue { drop(item); } } + + /// Downgrades the `TaskQueue` into a `WeakTaskQueue`. + fn downgrade(&self) -> WeakTaskQueue { + WeakTaskQueue { + local_queue: Arc::downgrade(&self.local_queue), + sync_queue: Arc::downgrade(&self.sync_queue), + local_thread: self.local_queue.tracker(), + } + } +} + +/// A weak reference to a `TaskQueue`. +struct WeakTaskQueue { + local_queue: Weak>>, + sync_queue: Weak>, + // `()` is a trivial type, so it won't panic on drop even if moved to another thread. + local_thread: SendWrapper<()>, +} + +impl WeakTaskQueue { + /// Upgrades the `WeakTaskQueue` and pushes the `runnable` into the + /// appropriate queue. + fn upgrade_and_push(&self, runnable: Runnable, notify: &NotifyHandle) { + if self.local_thread.valid() { + // It's ok to drop the runnable on the same thread. + if let Some(local_queue) = self.local_queue.upgrade() { + // SAFETY: already checked + unsafe { local_queue.get_unchecked() }.push(runnable); + #[cfg(feature = "notify-always")] + notify.notify().ok(); + } + } else if let Some(sync_queue) = self.sync_queue.upgrade() { + sync_queue.push(runnable); + notify.notify().ok(); + } else { + // We have to leak the runnable since it's not safe to drop it on another + // thread. + std::mem::forget(runnable); + } + } } /// A scheduler for managing and executing tasks. pub(crate) struct Scheduler { /// Queue for scheduled tasks. - task_queue: Arc, + task_queue: TaskQueue, /// `Waker` of active tasks. active_tasks: Rc>>, @@ -115,7 +147,7 @@ impl Scheduler { /// Creates a new `Scheduler`. pub(crate) fn new(event_interval: usize) -> Self { Self { - task_queue: Arc::new(TaskQueue::new()), + task_queue: TaskQueue::new(), active_tasks: Rc::new(RefCell::new(Slab::new())), event_interval, _local_marker: PhantomData, @@ -150,16 +182,11 @@ impl Scheduler { let schedule = { // The schedule closure is managed by the `Waker` and may be dropped on another - // thread, so use `Weak` to ensure the `TaskQueue` is always dropped + // thread, so use `WeakTaskQueue` to ensure the `TaskQueue` is always dropped // on the creator thread. - let task_queue = Arc::downgrade(&self.task_queue); + let task_queue = self.task_queue.downgrade(); - move |runnable| { - // The `upgrade()` never fails because all tasks are dropped when the - // `Scheduler` is dropped, if a `Waker` is used after that, the - // schedule closure will never be called. - task_queue.upgrade().unwrap().push(runnable, ¬ify); - } + move |runnable| task_queue.upgrade_and_push(runnable, ¬ify) }; let (runnable, task) = async_task::spawn_unchecked(future, schedule); diff --git a/compio-runtime/src/runtime/scheduler/send_wrapper.rs b/compio-runtime/src/runtime/scheduler/send_wrapper.rs index 65a50652..8ed9df01 100644 --- a/compio-runtime/src/runtime/scheduler/send_wrapper.rs +++ b/compio-runtime/src/runtime/scheduler/send_wrapper.rs @@ -53,9 +53,20 @@ impl SendWrapper { /// Returns a reference to the contained value, if valid. #[inline] + #[allow(dead_code)] pub fn get(&self) -> Option<&T> { if self.valid() { Some(&self.data) } else { None } } + + /// Returns a tracker that can be used to check if the current thread is + /// the same as the creator thread. + #[inline] + pub fn tracker(&self) -> SendWrapper<()> { + SendWrapper { + data: ManuallyDrop::new(()), + thread_id: self.thread_id, + } + } } unsafe impl Send for SendWrapper {}