Skip to content

Commit

Permalink
Synchronize cancel queue with usercall queue
Browse files Browse the repository at this point in the history
  • Loading branch information
mzohreva committed Oct 21, 2020
1 parent 556765d commit 95b558e
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 44 deletions.
55 changes: 24 additions & 31 deletions enclave-runner/src/usercalls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use tokio::sync::broadcast;
use tokio::sync::mpsc as async_mpsc;

use fortanix_sgx_abi::*;
use ipc_queue::{self, DescriptorGuard, Identified, QueueEvent};
use ipc_queue::{self, DescriptorGuard, Identified, QueueEvent, WritePosition};
use sgxs::loader::Tcs as SgxsTcs;

use crate::loader::{EnclavePanic, ErasedTcs};
Expand Down Expand Up @@ -636,26 +636,22 @@ impl Work {
enum UsercallEvent {
Started(u64, tokio::sync::oneshot::Sender<()>),
Finished(u64),
Cancelled(u64, Instant),
}

fn ignore_cancel_impl(usercall_nr: u64) -> bool {
usercall_nr != UsercallList::read as u64 &&
usercall_nr != UsercallList::read_alloc as u64 &&
usercall_nr != UsercallList::write as u64 &&
usercall_nr != UsercallList::accept_stream as u64 &&
usercall_nr != UsercallList::connect_stream as u64 &&
usercall_nr != UsercallList::wait as u64
Cancelled(u64, WritePosition),
}

trait IgnoreCancel {
fn ignore_cancel(&self) -> bool;
}

impl IgnoreCancel for Identified<Usercall> {
fn ignore_cancel(&self) -> bool { ignore_cancel_impl(self.data.0) }
}
impl IgnoreCancel for Identified<Cancel> {
fn ignore_cancel(&self) -> bool { ignore_cancel_impl(self.data.usercall_nr) }
fn ignore_cancel(&self) -> bool {
self.data.0 != UsercallList::read as u64 &&
self.data.0 != UsercallList::read_alloc as u64 &&
self.data.0 != UsercallList::write as u64 &&
self.data.0 != UsercallList::accept_stream as u64 &&
self.data.0 != UsercallList::connect_stream as u64 &&
self.data.0 != UsercallList::wait as u64
}
}

impl EnclaveState {
Expand Down Expand Up @@ -892,6 +888,8 @@ impl EnclaveState {
*enclave_clone.fifo_guards.lock().await = Some(fifo_guards);
*enclave_clone.return_queue_tx.lock().await = Some(return_queue_tx);

let usercall_queue_monitor = usercall_queue_rx.position_monitor();

tokio::task::spawn_local(async move {
while let Ok(usercall) = usercall_queue_rx.recv().await {
let _ = io_queue_send.send(UsercallSendData::Async(usercall));
Expand All @@ -900,37 +898,32 @@ impl EnclaveState {

let (usercall_event_tx, mut usercall_event_rx) = async_mpsc::unbounded_channel();
let usercall_event_tx_clone = usercall_event_tx.clone();
let usercall_queue_monitor_clone = usercall_queue_monitor.clone();
tokio::task::spawn_local(async move {
while let Ok(c) = cancel_queue_rx.recv().await {
if !c.ignore_cancel() {
let _ = usercall_event_tx_clone.send(UsercallEvent::Cancelled(c.id, Instant::now()));
}
let write_position = usercall_queue_monitor_clone.write_position();
let _ = usercall_event_tx_clone.send(UsercallEvent::Cancelled(c.id, write_position));
}
});

tokio::task::spawn_local(async move {
let mut notifiers = HashMap::new();
let mut cancels: HashMap<u64, Instant> = HashMap::new();
// This should be greater than the amount of time it takes for the enclave runner
// to start executing a usercall after the enclave sends it on the usercall_queue.
const CANCEL_EXPIRY: Duration = Duration::from_millis(100);
let mut cancels: HashMap<u64, WritePosition> = HashMap::new();
loop {
match usercall_event_rx.recv().await.expect("usercall_event channel closed unexpectedly") {
UsercallEvent::Started(id, notifier) => match cancels.remove(&id) {
Some(t) if t.elapsed() < CANCEL_EXPIRY => { let _ = notifier.send(()); },
Some(_) => { let _ = notifier.send(()); },
_ => { notifiers.insert(id, notifier); },
},
UsercallEvent::Finished(id) => { notifiers.remove(&id); },
UsercallEvent::Cancelled(id, t) => if t.elapsed() < CANCEL_EXPIRY {
match notifiers.remove(&id) {
Some(notifier) => { let _ = notifier.send(()); },
None => { cancels.insert(id, t); },
}
UsercallEvent::Cancelled(id, wp) => match notifiers.remove(&id) {
Some(notifier) => { let _ = notifier.send(()); },
None => { cancels.insert(id, wp); },
},
}
// cleanup expired cancels
let now = Instant::now();
cancels.retain(|_id, &mut t| now - t < CANCEL_EXPIRY);
// cleanup old cancels
let read_position = usercall_queue_monitor.read_position();
cancels.retain(|_id, wp| !read_position.is_past(wp));
}
});

Expand Down
4 changes: 2 additions & 2 deletions fortanix-sgx-abi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -718,8 +718,8 @@ pub mod async {
#[derive(Copy, Clone, Default)]
#[cfg_attr(feature = "rustc-dep-of-std", unstable(feature = "sgx_platform", issue = "56975"))]
pub struct Cancel {
/// This must be the same value as `Usercall.0`.
pub usercall_nr: u64,
/// Reserved for future use.
pub reserved: u64,
}

/// A circular buffer used as a FIFO queue with atomic reads and writes.
Expand Down
30 changes: 22 additions & 8 deletions ipc-queue/src/fifo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

use std::cell::UnsafeCell;
use std::mem;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::atomic::{AtomicU32, AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;

use fortanix_sgx_abi::{FifoDescriptor, WithId};
Expand All @@ -33,7 +33,7 @@ where
let arc = Arc::new(FifoBuffer::new(len));
let inner = Fifo::from_arc(arc);
let tx = AsyncSender { inner: inner.clone(), synchronizer: s.clone() };
let rx = AsyncReceiver { inner, synchronizer: s };
let rx = AsyncReceiver { inner, synchronizer: s, read_epoch: Arc::new(AtomicU32::new(0)) };
(tx, rx)
}

Expand Down Expand Up @@ -87,6 +87,12 @@ impl<T> Clone for Fifo<T> {
}
}

impl<T> Fifo<T> {
pub(crate) fn current_offsets(&self, ordering: Ordering) -> Offsets {
Offsets::new(self.offsets.load(ordering), self.data.len() as u32)
}
}

impl<T: Transmittable> Fifo<T> {
pub(crate) unsafe fn from_descriptor(descriptor: FifoDescriptor<T>) -> Self {
assert!(
Expand Down Expand Up @@ -152,7 +158,7 @@ impl<T: Transmittable> Fifo<T> {
pub(crate) fn try_send_impl(&self, val: Identified<T>) -> Result</*wake up reader:*/ bool, TrySendError> {
let (new, was_empty) = loop {
// 1. Load the current offsets.
let current = Offsets::new(self.offsets.load(Ordering::SeqCst), self.data.len() as u32);
let current = self.current_offsets(Ordering::SeqCst);
let was_empty = current.is_empty();

// 2. If the queue is full, wait, then go to step 1.
Expand All @@ -179,9 +185,9 @@ impl<T: Transmittable> Fifo<T> {
Ok(was_empty)
}

pub(crate) fn try_recv_impl(&self) -> Result<(Identified<T>, /*wake up writer:*/ bool), TryRecvError> {
pub(crate) fn try_recv_impl(&self) -> Result<(Identified<T>, /*wake up writer:*/ bool, /*read offset wrapped around:*/bool), TryRecvError> {
// 1. Load the current offsets.
let current = Offsets::new(self.offsets.load(Ordering::SeqCst), self.data.len() as u32);
let current = self.current_offsets(Ordering::SeqCst);

// 2. If the queue is empty, wait, then go to step 1.
if current.is_empty() {
Expand Down Expand Up @@ -216,7 +222,7 @@ impl<T: Transmittable> Fifo<T> {

// 8. If the queue was full before step 7, signal the writer to wake up.
let was_full = Offsets::new(before, self.data.len() as u32).is_full();
Ok((val, was_full))
Ok((val, was_full, new.read_offset() == 0))
}
}

Expand Down Expand Up @@ -282,6 +288,14 @@ impl Offsets {
..*self
}
}

pub(crate) fn read_high_bit(&self) -> bool {
self.read & self.len == self.len
}

pub(crate) fn write_high_bit(&self) -> bool {
self.write & self.len == self.len
}
}

#[cfg(test)]
Expand All @@ -308,7 +322,7 @@ mod tests {
}

for i in 1..=7 {
let (v, wake) = inner.try_recv_impl().unwrap();
let (v, wake, _) = inner.try_recv_impl().unwrap();
assert!(!wake);
assert_eq!(v.id, i);
assert_eq!(v.data.0, i);
Expand All @@ -327,7 +341,7 @@ mod tests {
assert!(inner.try_send_impl(Identified { id: 9, data: TestValue(9) }).is_err());

for i in 1..=8 {
let (v, wake) = inner.try_recv_impl().unwrap();
let (v, wake, _) = inner.try_recv_impl().unwrap();
assert!(if i == 1 { wake } else { !wake });
assert_eq!(v.id, i);
assert_eq!(v.data.0, i);
Expand Down
72 changes: 71 additions & 1 deletion ipc-queue/src/interface_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */

use super::*;
use std::sync::atomic::Ordering;

unsafe impl<T: Send, S: Send> Send for AsyncSender<T, S> {}
unsafe impl<T: Send, S: Sync> Sync for AsyncSender<T, S> {}
Expand Down Expand Up @@ -52,10 +53,13 @@ impl<T: Transmittable, S: AsyncSynchronizer> AsyncReceiver<T, S> {
pub async fn recv(&self) -> Result<Identified<T>, RecvError> {
loop {
match self.inner.try_recv_impl() {
Ok((val, wake_sender)) => {
Ok((val, wake_sender, read_wrapped_around)) => {
if wake_sender {
self.synchronizer.notify(QueueEvent::NotFull);
}
if read_wrapped_around {
self.read_epoch.fetch_add(1, Ordering::Relaxed);
}
return Ok(val);
}
Err(TryRecvError::QueueEmpty) => {
Expand All @@ -68,6 +72,13 @@ impl<T: Transmittable, S: AsyncSynchronizer> AsyncReceiver<T, S> {
}
}

pub fn position_monitor(&self) -> PositionMonitor<T> {
PositionMonitor {
read_epoch: self.read_epoch.clone(),
fifo: self.inner.clone(),
}
}

/// Consumes `self` and returns a DescriptorGuard.
/// The returned guard can be used to make `FifoDescriptor`s that remain
/// valid as long as the guard is not dropped.
Expand Down Expand Up @@ -153,6 +164,65 @@ mod tests {
do_multi_sender(1024, 30, 100).await;
}

#[tokio::test]
async fn positions() {
const LEN: usize = 16;
let s = TestAsyncSynchronizer::new();
let (tx, rx) = bounded_async(LEN, s);
let monitor = rx.position_monitor();
let mut id = 1;

let p0 = monitor.write_position();
tx.send(Identified { id, data: TestValue(1) }).await.unwrap();
let p1 = monitor.write_position();
tx.send(Identified { id: id + 1, data: TestValue(2) }).await.unwrap();
let p2 = monitor.write_position();
tx.send(Identified { id: id + 2, data: TestValue(3) }).await.unwrap();
let p3 = monitor.write_position();
id += 3;
assert!(monitor.read_position().is_past(&p0) == false);
assert!(monitor.read_position().is_past(&p1) == false);
assert!(monitor.read_position().is_past(&p2) == false);
assert!(monitor.read_position().is_past(&p3) == false);

rx.recv().await.unwrap();
assert!(monitor.read_position().is_past(&p0) == true);
assert!(monitor.read_position().is_past(&p1) == false);
assert!(monitor.read_position().is_past(&p2) == false);
assert!(monitor.read_position().is_past(&p3) == false);

rx.recv().await.unwrap();
assert!(monitor.read_position().is_past(&p0) == true);
assert!(monitor.read_position().is_past(&p1) == true);
assert!(monitor.read_position().is_past(&p2) == false);
assert!(monitor.read_position().is_past(&p3) == false);

rx.recv().await.unwrap();
assert!(monitor.read_position().is_past(&p0) == true);
assert!(monitor.read_position().is_past(&p1) == true);
assert!(monitor.read_position().is_past(&p2) == true);
assert!(monitor.read_position().is_past(&p3) == false);

for i in 0..1000 {
let n = 1 + (i % LEN);
let p4 = monitor.write_position();
for _ in 0..n {
tx.send(Identified { id, data: TestValue(id) }).await.unwrap();
id += 1;
}
let p5 = monitor.write_position();
for _ in 0..n {
rx.recv().await.unwrap();
assert!(monitor.read_position().is_past(&p0) == true);
assert!(monitor.read_position().is_past(&p1) == true);
assert!(monitor.read_position().is_past(&p2) == true);
assert!(monitor.read_position().is_past(&p3) == true);
assert!(monitor.read_position().is_past(&p4) == true);
assert!(monitor.read_position().is_past(&p5) == false);
}
}
}

struct Subscription<T> {
tx: broadcast::Sender<T>,
rx: Mutex<broadcast::Receiver<T>>,
Expand Down
4 changes: 2 additions & 2 deletions ipc-queue/src/interface_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl<T: Transmittable, S: Synchronizer> Receiver<T, S> {
}

pub fn try_recv(&self) -> Result<Identified<T>, TryRecvError> {
self.inner.try_recv_impl().map(|(val, wake_sender)| {
self.inner.try_recv_impl().map(|(val, wake_sender, _)| {
if wake_sender {
self.synchronizer.notify(QueueEvent::NotFull);
}
Expand All @@ -127,7 +127,7 @@ impl<T: Transmittable, S: Synchronizer> Receiver<T, S> {
pub fn recv(&self) -> Result<Identified<T>, RecvError> {
loop {
match self.inner.try_recv_impl() {
Ok((val, wake_sender)) => {
Ok((val, wake_sender, _)) => {
if wake_sender {
self.synchronizer.notify(QueueEvent::NotFull);
}
Expand Down
19 changes: 19 additions & 0 deletions ipc-queue/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::future::Future;
#[cfg(target_env = "sgx")]
use std::os::fortanix_sgx::usercalls::alloc::UserSafeSized;
use std::pin::Pin;
use std::sync::atomic::AtomicU32;
use std::sync::Arc;

use fortanix_sgx_abi::FifoDescriptor;
Expand All @@ -19,6 +20,7 @@ use self::fifo::{Fifo, FifoBuffer};
mod fifo;
mod interface_sync;
mod interface_async;
mod position;
#[cfg(test)]
mod test_support;

Expand Down Expand Up @@ -123,6 +125,7 @@ pub struct AsyncSender<T: 'static, S> {
pub struct AsyncReceiver<T: 'static, S> {
inner: Fifo<T>,
synchronizer: S,
read_epoch: Arc<AtomicU32>,
}

/// `DescriptorGuard<T>` can produce a `FifoDescriptor<T>` that is guaranteed
Expand All @@ -137,3 +140,19 @@ impl<T> DescriptorGuard<T> {
self.descriptor
}
}

/// `PositionMonitor<T>` can be used to record the current read/write positions
/// of a queue. Even though a queue is comprised of a limited number of slots
/// arranged as a ring buffer, we can assign a position to each value written/
/// read to/from the queue. This is useful in case we want to know whether or
/// not a particular value written to the queue has been read.
pub struct PositionMonitor<T: 'static> {
read_epoch: Arc<AtomicU32>,
fifo: Fifo<T>,
}

/// A read position in a queue.
pub struct ReadPosition(u64);

/// A write position in a queue.
pub struct WritePosition(u64);
Loading

0 comments on commit 95b558e

Please sign in to comment.