From 44699b1f102ee67f1e7c64fd6fa6414aa1740181 Mon Sep 17 00:00:00 2001 From: Niket Naidu Date: Thu, 27 Jun 2024 18:38:34 -0700 Subject: [PATCH 1/2] Use TaskMetadata instead of Droppable Future --- Cargo.toml | 1 - src/droppable_future.rs | 51 -------------------- src/lib.rs | 3 -- src/ticked_async_executor.rs | 92 ++++++++++++++++++++++-------------- 4 files changed, 56 insertions(+), 91 deletions(-) delete mode 100644 src/droppable_future.rs diff --git a/Cargo.toml b/Cargo.toml index f3fb95e..1bf97de 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,6 @@ edition = "2021" [dependencies] async-task = "4.7" -pin-project = "1" [dev-dependencies] tokio = { version = "1", features = ["full"] } diff --git a/src/droppable_future.rs b/src/droppable_future.rs deleted file mode 100644 index 0ab6d57..0000000 --- a/src/droppable_future.rs +++ /dev/null @@ -1,51 +0,0 @@ -use std::{future::Future, pin::Pin}; - -use pin_project::{pin_project, pinned_drop}; - -#[pin_project(PinnedDrop)] -pub struct DroppableFuture -where - F: Future, - D: Fn(), -{ - #[pin] - future: F, - on_drop: D, -} - -impl DroppableFuture -where - F: Future, - D: Fn(), -{ - pub fn new(future: F, on_drop: D) -> Self { - Self { future, on_drop } - } -} - -impl Future for DroppableFuture -where - F: Future, - D: Fn(), -{ - type Output = F::Output; - - fn poll( - self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - let this = self.project(); - this.future.poll(cx) - } -} - -#[pinned_drop] -impl PinnedDrop for DroppableFuture -where - F: Future, - D: Fn(), -{ - fn drop(self: Pin<&mut Self>) { - (self.on_drop)(); - } -} diff --git a/src/lib.rs b/src/lib.rs index 1e5a011..86ac703 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,3 @@ -mod droppable_future; -use droppable_future::*; - mod task_identifier; pub use task_identifier::*; diff --git a/src/ticked_async_executor.rs b/src/ticked_async_executor.rs index 6c06a78..c5c22da 100644 --- a/src/ticked_async_executor.rs +++ b/src/ticked_async_executor.rs @@ -6,7 +6,7 @@ use std::{ }, }; -use crate::{DroppableFuture, TaskIdentifier}; +use crate::TaskIdentifier; #[derive(Debug)] pub enum TaskState { @@ -16,11 +16,37 @@ pub enum TaskState { Drop(TaskIdentifier), } -pub type Task = async_task::Task; -type Payload = (TaskIdentifier, async_task::Runnable); +pub type Task = async_task::Task>; +type TaskRunnable = async_task::Runnable>; +type Payload = (TaskIdentifier, TaskRunnable); -pub struct TickedAsyncExecutor { - channel: (mpsc::Sender, mpsc::Receiver), +/// Task Metadata associated with TickedAsyncExecutor +/// +/// Primarily used to track when the Task is completed/cancelled +pub struct TaskMetadata +where + O: Fn(TaskState) + Send + Sync + 'static, +{ + num_spawned_tasks: Arc, + identifier: TaskIdentifier, + observer: O, +} + +impl Drop for TaskMetadata +where + O: Fn(TaskState) + Send + Sync + 'static, +{ + fn drop(&mut self) { + self.num_spawned_tasks.fetch_sub(1, Ordering::Relaxed); + (self.observer)(TaskState::Drop(self.identifier.clone())); + } +} + +pub struct TickedAsyncExecutor +where + O: Fn(TaskState) + Send + Sync + 'static, +{ + channel: (mpsc::Sender>, mpsc::Receiver>), num_woken_tasks: Arc, num_spawned_tasks: Arc, @@ -53,14 +79,22 @@ where &self, identifier: impl Into, future: impl Future + Send + 'static, - ) -> Task + ) -> Task where T: Send + 'static, { let identifier = identifier.into(); - let future = self.droppable_future(identifier.clone(), future); - let schedule = self.runnable_schedule_cb(identifier); - let (runnable, task) = async_task::spawn(future, schedule); + self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed); + (self.observer)(TaskState::Spawn(identifier.clone())); + + let schedule = self.runnable_schedule_cb(identifier.clone()); + let (runnable, task) = async_task::Builder::new() + .metadata(TaskMetadata { + num_spawned_tasks: self.num_spawned_tasks.clone(), + identifier, + observer: self.observer.clone(), + }) + .spawn(|_m| future, schedule); runnable.schedule(); task } @@ -69,14 +103,22 @@ where &self, identifier: impl Into, future: impl Future + 'static, - ) -> Task + ) -> Task where T: 'static, { let identifier = identifier.into(); - let future = self.droppable_future(identifier.clone(), future); - let schedule = self.runnable_schedule_cb(identifier); - let (runnable, task) = async_task::spawn_local(future, schedule); + self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed); + (self.observer)(TaskState::Spawn(identifier.clone())); + + let schedule = self.runnable_schedule_cb(identifier.clone()); + let (runnable, task) = async_task::Builder::new() + .metadata(TaskMetadata { + num_spawned_tasks: self.num_spawned_tasks.clone(), + identifier, + observer: self.observer.clone(), + }) + .spawn_local(move |_m| future, schedule); runnable.schedule(); task } @@ -104,29 +146,7 @@ where .fetch_sub(num_woken_tasks, Ordering::Relaxed); } - fn droppable_future( - &self, - identifier: TaskIdentifier, - future: F, - ) -> DroppableFuture - where - F: Future, - { - let observer = self.observer.clone(); - - // Spawn Task - self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed); - observer(TaskState::Spawn(identifier.clone())); - - // Droppable Future registering on_drop callback - let num_spawned_tasks = self.num_spawned_tasks.clone(); - DroppableFuture::new(future, move || { - num_spawned_tasks.fetch_sub(1, Ordering::Relaxed); - observer(TaskState::Drop(identifier.clone())); - }) - } - - fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) { + fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(TaskRunnable) { let sender = self.channel.0.clone(); let num_woken_tasks = self.num_woken_tasks.clone(); let observer = self.observer.clone(); From cdb45629f57105757f368b229eba21b8aa5661ab Mon Sep 17 00:00:00 2001 From: Niket Naidu Date: Thu, 27 Jun 2024 19:41:17 -0700 Subject: [PATCH 2/2] Add Send task unit-tests --- src/ticked_async_executor.rs | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/src/ticked_async_executor.rs b/src/ticked_async_executor.rs index c5c22da..1ade368 100644 --- a/src/ticked_async_executor.rs +++ b/src/ticked_async_executor.rs @@ -165,7 +165,7 @@ mod tests { use super::*; #[test] - fn test_multiple_tasks() { + fn test_multiple_local_tasks() { let executor = TickedAsyncExecutor::default(); executor .spawn_local("A", async move { @@ -187,7 +187,7 @@ mod tests { } #[test] - fn test_task_cancellation() { + fn test_local_tasks_cancellation() { let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}")); let task1 = executor.spawn_local("A", async move { loop { @@ -217,4 +217,36 @@ mod tests { executor.tick(); } } + + #[test] + fn test_tasks_cancellation() { + let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}")); + let task1 = executor.spawn("A", async move { + loop { + tokio::task::yield_now().await; + } + }); + + let task2 = executor.spawn(format!("B"), async move { + loop { + tokio::task::yield_now().await; + } + }); + assert_eq!(executor.num_tasks(), 2); + executor.tick(); + + executor + .spawn_local("CancelTasks", async move { + let (t1, t2) = join!(task1.cancel(), task2.cancel()); + assert_eq!(t1, None); + assert_eq!(t2, None); + }) + .detach(); + assert_eq!(executor.num_tasks(), 3); + + // Since we have cancelled the tasks above, the loops should eventually end + while executor.num_tasks() != 0 { + executor.tick(); + } + } }