diff --git a/Cargo.lock b/Cargo.lock index c81b1eb..ceba34b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1627,6 +1627,7 @@ dependencies = [ "base64", "bincode", "color-eyre", + "futures-channel", "futures-util", "hex", "http", diff --git a/Cargo.toml b/Cargo.toml index 3898dc3..fac64fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ repository = "https://github.com/hyperware-ai/process_lib" license = "Apache-2.0" [features] -hyperapp = ["dep:futures-util", "dep:uuid", "logging"] +hyperapp = ["dep:futures-util", "dep:futures-channel", "dep:uuid", "logging"] logging = ["dep:color-eyre", "dep:tracing", "dep:tracing-error", "dep:tracing-subscriber"] hyperwallet = ["dep:hex", "dep:sha3"] simulation-mode = [] @@ -42,6 +42,7 @@ url = "2.4.1" wit-bindgen = "0.42.1" futures-util = { version = "0.3", optional = true } +futures-channel = { version = "0.3", optional = true } uuid = { version = "1.0", features = ["v4"], optional = true } color-eyre = { version = "0.6", features = ["capture-spantrace"], optional = true } diff --git a/src/hyperapp.rs b/src/hyperapp.rs index d03299b..c2059bd 100644 --- a/src/hyperapp.rs +++ b/src/hyperapp.rs @@ -1,7 +1,11 @@ use std::cell::RefCell; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::future::Future; use std::pin::Pin; +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; use std::task::{Context, Poll}; use crate::{ @@ -10,7 +14,8 @@ use crate::{ logging::{error, info}, set_state, timer, Address, BuildError, LazyLoadBlob, Message, Request, SendError, }; -use futures_util::task::noop_waker_ref; +use futures_channel::{mpsc, oneshot}; +use futures_util::task::{waker_ref, ArcWake}; use serde::{Deserialize, Serialize}; use thiserror::Error; use uuid::Uuid; @@ -18,13 +23,13 @@ use uuid::Uuid; thread_local! { static SPAWN_QUEUE: RefCell>>>> = RefCell::new(Vec::new()); - pub static APP_CONTEXT: RefCell = RefCell::new(AppContext { hidden_state: None, executor: Executor::new(), }); pub static RESPONSE_REGISTRY: RefCell>> = RefCell::new(HashMap::new()); + pub static CANCELLED_RESPONSES: RefCell> = RefCell::new(HashSet::new()); pub static APP_HELPERS: RefCell = RefCell::new(AppHelpers { current_server: None, @@ -146,10 +151,53 @@ pub struct Executor { tasks: Vec>>>, } -pub fn spawn(fut: impl Future + 'static) { +struct ExecutorWakeFlag { + triggered: AtomicBool, +} + +impl ExecutorWakeFlag { + fn new() -> Self { + Self { + triggered: AtomicBool::new(false), + } + } + + fn take(&self) -> bool { + self.triggered.swap(false, Ordering::SeqCst) + } +} + +impl ArcWake for ExecutorWakeFlag { + fn wake_by_ref(arc_self: &Arc) { + arc_self.triggered.store(true, Ordering::SeqCst); + } +} + +pub struct JoinHandle { + receiver: oneshot::Receiver, +} + +impl Future for JoinHandle { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let receiver = &mut self.get_mut().receiver; + Pin::new(receiver).poll(cx) + } +} + +pub fn spawn(fut: impl Future + 'static) -> JoinHandle +where + T: 'static, +{ + let (sender, receiver) = oneshot::channel(); SPAWN_QUEUE.with(|queue| { - queue.borrow_mut().push(Box::pin(fut)); - }) + queue.borrow_mut().push(Box::pin(async move { + let result = fut.await; + let _ = sender.send(result); + })); + }); + JoinHandle { receiver } } impl Executor { @@ -158,19 +206,24 @@ impl Executor { } pub fn poll_all_tasks(&mut self) { + let wake_flag = Arc::new(ExecutorWakeFlag::new()); loop { // Drain any newly spawned tasks into our task list SPAWN_QUEUE.with(|queue| { self.tasks.append(&mut queue.borrow_mut()); }); - // Poll all tasks, collecting completed ones + // Poll all tasks, collecting completed ones. + // Put waker into context so tasks can wake the executor if needed. let mut completed = Vec::new(); - let mut ctx = Context::from_waker(noop_waker_ref()); + { + let waker = waker_ref(&wake_flag); + let mut ctx = Context::from_waker(&waker); - for i in 0..self.tasks.len() { - if let Poll::Ready(()) = self.tasks[i].as_mut().poll(&mut ctx) { - completed.push(i); + for i in 0..self.tasks.len() { + if let Poll::Ready(()) = self.tasks[i].as_mut().poll(&mut ctx) { + completed.push(i); + } } } @@ -181,9 +234,10 @@ impl Executor { // Check if there are new tasks spawned during polling let has_new_tasks = SPAWN_QUEUE.with(|queue| !queue.borrow().is_empty()); + // Check if any task woke the executor that needs to be re-polled + let was_woken = wake_flag.take(); - // Continue if new tasks were spawned, otherwise we're done - if !has_new_tasks { + if !has_new_tasks && !was_woken { break; } } @@ -193,6 +247,7 @@ struct ResponseFuture { correlation_id: String, // Capture HTTP context at creation time http_context: Option, + resolved: bool, } impl ResponseFuture { @@ -204,6 +259,7 @@ impl ResponseFuture { Self { correlation_id, http_context, + resolved: false, } } } @@ -212,16 +268,18 @@ impl Future for ResponseFuture { type Output = Vec; fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { - let correlation_id = &self.correlation_id; + let this = self.get_mut(); let maybe_bytes = RESPONSE_REGISTRY.with(|registry| { let mut registry_mut = registry.borrow_mut(); - registry_mut.remove(correlation_id) + registry_mut.remove(&this.correlation_id) }); if let Some(bytes) = maybe_bytes { + this.resolved = true; + // Restore this future's captured context - if let Some(ref context) = self.http_context { + if let Some(ref context) = this.http_context { APP_HELPERS.with(|helpers| { helpers.borrow_mut().current_http_context = Some(context.clone()); }); @@ -234,6 +292,23 @@ impl Future for ResponseFuture { } } +impl Drop for ResponseFuture { + fn drop(&mut self) { + // We want to avoid cleaning up after successful responses + if self.resolved { + return; + } + + RESPONSE_REGISTRY.with(|registry| { + registry.borrow_mut().remove(&self.correlation_id); + }); + + CANCELLED_RESPONSES.with(|set| { + set.borrow_mut().insert(self.correlation_id.clone()); + }); + } +} + #[derive(Debug, Clone, Serialize, Deserialize, Error)] pub enum AppSendError { #[error("SendError: {0}")]