Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compio-runtime/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "compio-runtime"
version = "0.9.4"
version = "0.9.5"
description = "High-level runtime for compio"
categories = ["asynchronous"]
keywords = ["async", "runtime"]
Expand Down
87 changes: 57 additions & 30 deletions compio-runtime/src/runtime/scheduler/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
use std::{cell::RefCell, future::Future, marker::PhantomData, rc::Rc, sync::Arc, task::Waker};
use std::{
cell::RefCell,
future::Future,
marker::PhantomData,
rc::Rc,
sync::{Arc, Weak},
task::Waker,
};

use async_task::{Runnable, Task};
use compio_driver::NotifyHandle;
Expand All @@ -15,31 +22,16 @@ mod send_wrapper;

/// A task queue consisting of a local queue and a synchronized queue.
struct TaskQueue {
local_queue: SendWrapper<LocalQueue<Runnable>>,
sync_queue: SegQueue<Runnable>,
local_queue: Arc<SendWrapper<LocalQueue<Runnable>>>,
sync_queue: Arc<SegQueue<Runnable>>,
}

impl TaskQueue {
/// Creates a new `TaskQueue`.
fn new() -> Self {
Self {
local_queue: SendWrapper::new(LocalQueue::new()),
sync_queue: SegQueue::new(),
}
}

/// Pushes a `Runnable` task to the appropriate queue.
///
/// If the current thread is the same as the creator thread, push to the
/// local queue. Otherwise, push to the sync queue.
fn push(&self, runnable: Runnable, notify: &NotifyHandle) {
if let Some(local_queue) = self.local_queue.get() {
local_queue.push(runnable);
#[cfg(feature = "notify-always")]
notify.notify().ok();
} else {
self.sync_queue.push(runnable);
notify.notify().ok();
local_queue: Arc::new(SendWrapper::new(LocalQueue::new())),
sync_queue: Arc::new(SegQueue::new()),
}
}

Expand Down Expand Up @@ -94,12 +86,52 @@ impl TaskQueue {
drop(item);
}
}

/// Downgrades the `TaskQueue` into a `WeakTaskQueue`.
fn downgrade(&self) -> WeakTaskQueue {
WeakTaskQueue {
local_queue: Arc::downgrade(&self.local_queue),
sync_queue: Arc::downgrade(&self.sync_queue),
local_thread: self.local_queue.tracker(),
}
}
}

/// A weak reference to a `TaskQueue`.
struct WeakTaskQueue {
local_queue: Weak<SendWrapper<LocalQueue<Runnable>>>,
sync_queue: Weak<SegQueue<Runnable>>,
// `()` is a trivial type, so it won't panic on drop even if moved to another thread.
local_thread: SendWrapper<()>,
}

impl WeakTaskQueue {
/// Upgrades the `WeakTaskQueue` and pushes the `runnable` into the
/// appropriate queue.
fn upgrade_and_push(&self, runnable: Runnable, notify: &NotifyHandle) {
if self.local_thread.valid() {
// It's ok to drop the runnable on the same thread.
if let Some(local_queue) = self.local_queue.upgrade() {
// SAFETY: already checked
unsafe { local_queue.get_unchecked() }.push(runnable);
#[cfg(feature = "notify-always")]
notify.notify().ok();
}
} else if let Some(sync_queue) = self.sync_queue.upgrade() {
sync_queue.push(runnable);
notify.notify().ok();
} else {
// We have to leak the runnable since it's not safe to drop it on another
// thread.
std::mem::forget(runnable);
}
}
}

/// A scheduler for managing and executing tasks.
pub(crate) struct Scheduler {
/// Queue for scheduled tasks.
task_queue: Arc<TaskQueue>,
task_queue: TaskQueue,

/// `Waker` of active tasks.
active_tasks: Rc<RefCell<Slab<Waker>>>,
Expand All @@ -115,7 +147,7 @@ impl Scheduler {
/// Creates a new `Scheduler`.
pub(crate) fn new(event_interval: usize) -> Self {
Self {
task_queue: Arc::new(TaskQueue::new()),
task_queue: TaskQueue::new(),
active_tasks: Rc::new(RefCell::new(Slab::new())),
event_interval,
_local_marker: PhantomData,
Expand Down Expand Up @@ -150,16 +182,11 @@ impl Scheduler {

let schedule = {
// The schedule closure is managed by the `Waker` and may be dropped on another
// thread, so use `Weak` to ensure the `TaskQueue` is always dropped
// thread, so use `WeakTaskQueue` to ensure the `TaskQueue` is always dropped
// on the creator thread.
let task_queue = Arc::downgrade(&self.task_queue);
let task_queue = self.task_queue.downgrade();

move |runnable| {
// The `upgrade()` never fails because all tasks are dropped when the
// `Scheduler` is dropped, if a `Waker` is used after that, the
// schedule closure will never be called.
task_queue.upgrade().unwrap().push(runnable, &notify);
}
move |runnable| task_queue.upgrade_and_push(runnable, &notify)
};

let (runnable, task) = async_task::spawn_unchecked(future, schedule);
Expand Down
11 changes: 11 additions & 0 deletions compio-runtime/src/runtime/scheduler/send_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,20 @@ impl<T> SendWrapper<T> {

/// Returns a reference to the contained value, if valid.
#[inline]
#[allow(dead_code)]
pub fn get(&self) -> Option<&T> {
if self.valid() { Some(&self.data) } else { None }
}

/// Returns a tracker that can be used to check if the current thread is
/// the same as the creator thread.
#[inline]
pub fn tracker(&self) -> SendWrapper<()> {
SendWrapper {
data: ManuallyDrop::new(()),
thread_id: self.thread_id,
}
}
}

unsafe impl<T> Send for SendWrapper<T> {}
Expand Down
Loading