Skip to content

Commit

Permalink
Add CallbackHandlerWaker
Browse files Browse the repository at this point in the history
  • Loading branch information
mzohreva committed Nov 21, 2020
1 parent 39f1de5 commit 6d52998
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
33 changes: 28 additions & 5 deletions async-usercalls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,13 @@ impl AsyncUsercallProvider {
let callbacks = Mutex::new(HashMap::new());
let (callback_tx, callback_rx) = mpmc::unbounded();
let provider = Self { core, callback_tx };
let (waker_tx, waker_rx) = mpmc::bounded(1);
let handler = CallbackHandler {
return_rx,
callbacks,
callback_rx,
waker_tx,
waker_rx,
};
(provider, handler)
}
Expand All @@ -91,22 +94,41 @@ impl AsyncUsercallProvider {
}
}

pub struct CallbackHandlerWaker(mpmc::Sender<()>);

impl CallbackHandlerWaker {
/// Interrupts the related callback handler's `poll()` if blocked.
pub fn wake(&self) {
let _ = self.0.try_send(());
}
}

pub struct CallbackHandler {
return_rx: mpmc::Receiver<Identified<Return>>,
callbacks: Mutex<HashMap<u64, Callback>>,
// This is used so that threads sending usercalls don't have to take the lock.
callback_rx: mpmc::Receiver<(u64, Callback)>,
waker_rx: mpmc::Receiver<()>,
waker_tx: mpmc::Sender<()>,
}

impl CallbackHandler {
// Returns an object that can be used to interrupt a blocked `self.poll()`.
pub fn waker(&self) -> CallbackHandlerWaker {
CallbackHandlerWaker(self.waker_tx.clone())
}

#[inline]
fn recv_returns(&self, timeout: Option<Duration>, returns: &mut [Identified<Return>]) -> usize {
let first = match timeout {
None => self.return_rx.recv().ok(),
Some(timeout) => match self.return_rx.recv_timeout(timeout) {
Ok(val) => Some(val),
Err(mpmc::RecvTimeoutError::Disconnected) => None,
Err(mpmc::RecvTimeoutError::Timeout) => return 0,
None => mpmc::select! {
recv(self.return_rx) -> res => res.ok(),
recv(self.waker_rx) -> _res => return 0,
},
Some(timeout) => mpmc::select! {
recv(self.return_rx) -> res => res.ok(),
recv(self.waker_rx) -> _res => return 0,
default(timeout) => return 0,
},
}
.expect("return channel closed unexpectedly");
Expand All @@ -122,6 +144,7 @@ impl CallbackHandler {
/// functions. If `timeout` is `None`, it will block execution until at
/// least one return is received, otherwise it will block until there is a
/// return or timeout is elapsed. Returns the number of executed callbacks.
/// This can be interrupted using `CallbackHandlerWaker::wake()`.
pub fn poll(&self, timeout: Option<Duration>) -> usize {
// 1. wait for returns
let mut returns = [Identified {
Expand Down
19 changes: 19 additions & 0 deletions async-usercalls/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,25 @@ fn read_buffer_basic() {
assert_eq!(&buf, b"hello\0\0\0");
}

#[test]
fn callback_handler_waker() {
let (_provider, handler) = AsyncUsercallProvider::new();
let waker = handler.waker();
let (tx, rx) = mpmc::bounded(1);
let h = thread::spawn(move || {
let n1 = handler.poll(None);
tx.send(()).unwrap();
let n2 = handler.poll(Some(Duration::from_secs(3)));
tx.send(()).unwrap();
n1 + n2
});
for _ in 0..2 {
waker.wake();
rx.recv().unwrap();
}
assert_eq!(h.join().unwrap(), 0);
}

#[test]
#[ignore]
fn echo() {
Expand Down

0 comments on commit 6d52998

Please sign in to comment.