diff --git a/async-usercalls/src/lib.rs b/async-usercalls/src/lib.rs index 0c51d29a3..21f42ba27 100644 --- a/async-usercalls/src/lib.rs +++ b/async-usercalls/src/lib.rs @@ -67,10 +67,13 @@ impl AsyncUsercallProvider { let callbacks = Mutex::new(HashMap::new()); let (callback_tx, callback_rx) = mpmc::unbounded(); let provider = Self { core, callback_tx }; + let (waker_tx, waker_rx) = mpmc::bounded(1); let handler = CallbackHandler { return_rx, callbacks, callback_rx, + waker_tx, + waker_rx, }; (provider, handler) } @@ -91,22 +94,41 @@ impl AsyncUsercallProvider { } } +pub struct CallbackHandlerWaker(mpmc::Sender<()>); + +impl CallbackHandlerWaker { + /// Interrupts the related callback handler's `poll()` if blocked. + pub fn wake(&self) { + let _ = self.0.try_send(()); + } +} + pub struct CallbackHandler { return_rx: mpmc::Receiver>, callbacks: Mutex>, // This is used so that threads sending usercalls don't have to take the lock. callback_rx: mpmc::Receiver<(u64, Callback)>, + waker_rx: mpmc::Receiver<()>, + waker_tx: mpmc::Sender<()>, } impl CallbackHandler { + // Returns an object that can be used to interrupt a blocked `self.poll()`. + pub fn waker(&self) -> CallbackHandlerWaker { + CallbackHandlerWaker(self.waker_tx.clone()) + } + #[inline] fn recv_returns(&self, timeout: Option, returns: &mut [Identified]) -> usize { let first = match timeout { - None => self.return_rx.recv().ok(), - Some(timeout) => match self.return_rx.recv_timeout(timeout) { - Ok(val) => Some(val), - Err(mpmc::RecvTimeoutError::Disconnected) => None, - Err(mpmc::RecvTimeoutError::Timeout) => return 0, + None => mpmc::select! { + recv(self.return_rx) -> res => res.ok(), + recv(self.waker_rx) -> _res => return 0, + }, + Some(timeout) => mpmc::select! { + recv(self.return_rx) -> res => res.ok(), + recv(self.waker_rx) -> _res => return 0, + default(timeout) => return 0, }, } .expect("return channel closed unexpectedly"); @@ -122,6 +144,7 @@ impl CallbackHandler { /// functions. If `timeout` is `None`, it will block execution until at /// least one return is received, otherwise it will block until there is a /// return or timeout is elapsed. Returns the number of executed callbacks. + /// This can be interrupted using `CallbackHandlerWaker::wake()`. pub fn poll(&self, timeout: Option) -> usize { // 1. wait for returns let mut returns = [Identified { diff --git a/async-usercalls/src/tests.rs b/async-usercalls/src/tests.rs index 2bdc473b7..ff838c48c 100644 --- a/async-usercalls/src/tests.rs +++ b/async-usercalls/src/tests.rs @@ -251,6 +251,25 @@ fn read_buffer_basic() { assert_eq!(&buf, b"hello\0\0\0"); } +#[test] +fn callback_handler_waker() { + let (_provider, handler) = AsyncUsercallProvider::new(); + let waker = handler.waker(); + let (tx, rx) = mpmc::bounded(1); + let h = thread::spawn(move || { + let n1 = handler.poll(None); + tx.send(()).unwrap(); + let n2 = handler.poll(Some(Duration::from_secs(3))); + tx.send(()).unwrap(); + n1 + n2 + }); + for _ in 0..2 { + waker.wake(); + rx.recv().unwrap(); + } + assert_eq!(h.join().unwrap(), 0); +} + #[test] #[ignore] fn echo() {