From 57b0325f52374b89a48a7727e4a0eed57b775c12 Mon Sep 17 00:00:00 2001 From: Mohsen Zohrevandi Date: Wed, 21 Oct 2020 15:04:10 -0700 Subject: [PATCH] Synchronize cancel queue with usercall queue --- enclave-runner/src/usercalls/mod.rs | 55 ++++++++++------------ fortanix-sgx-abi/src/lib.rs | 14 ++++-- ipc-queue/src/fifo.rs | 30 ++++++++---- ipc-queue/src/interface_async.rs | 72 ++++++++++++++++++++++++++++- ipc-queue/src/interface_sync.rs | 4 +- ipc-queue/src/lib.rs | 19 ++++++++ ipc-queue/src/position.rs | 48 +++++++++++++++++++ 7 files changed, 195 insertions(+), 47 deletions(-) create mode 100644 ipc-queue/src/position.rs diff --git a/enclave-runner/src/usercalls/mod.rs b/enclave-runner/src/usercalls/mod.rs index 477f869d..9a072f32 100644 --- a/enclave-runner/src/usercalls/mod.rs +++ b/enclave-runner/src/usercalls/mod.rs @@ -33,7 +33,7 @@ use tokio::sync::broadcast; use tokio::sync::mpsc as async_mpsc; use fortanix_sgx_abi::*; -use ipc_queue::{self, DescriptorGuard, Identified, QueueEvent}; +use ipc_queue::{self, DescriptorGuard, Identified, QueueEvent, WritePosition}; use sgxs::loader::Tcs as SgxsTcs; use crate::loader::{EnclavePanic, ErasedTcs}; @@ -636,26 +636,22 @@ impl Work { enum UsercallEvent { Started(u64, tokio::sync::oneshot::Sender<()>), Finished(u64), - Cancelled(u64, Instant), -} - -fn ignore_cancel_impl(usercall_nr: u64) -> bool { - usercall_nr != UsercallList::read as u64 && - usercall_nr != UsercallList::read_alloc as u64 && - usercall_nr != UsercallList::write as u64 && - usercall_nr != UsercallList::accept_stream as u64 && - usercall_nr != UsercallList::connect_stream as u64 && - usercall_nr != UsercallList::wait as u64 + Cancelled(u64, WritePosition), } trait IgnoreCancel { fn ignore_cancel(&self) -> bool; } + impl IgnoreCancel for Identified { - fn ignore_cancel(&self) -> bool { ignore_cancel_impl(self.data.0) } -} -impl IgnoreCancel for Identified { - fn ignore_cancel(&self) -> bool { ignore_cancel_impl(self.data.usercall_nr) } + fn 
ignore_cancel(&self) -> bool { + self.data.0 != UsercallList::read as u64 && + self.data.0 != UsercallList::read_alloc as u64 && + self.data.0 != UsercallList::write as u64 && + self.data.0 != UsercallList::accept_stream as u64 && + self.data.0 != UsercallList::connect_stream as u64 && + self.data.0 != UsercallList::wait as u64 + } } impl EnclaveState { @@ -892,6 +888,8 @@ impl EnclaveState { *enclave_clone.fifo_guards.lock().await = Some(fifo_guards); *enclave_clone.return_queue_tx.lock().await = Some(return_queue_tx); + let usercall_queue_monitor = usercall_queue_rx.position_monitor(); + tokio::task::spawn_local(async move { while let Ok(usercall) = usercall_queue_rx.recv().await { let _ = io_queue_send.send(UsercallSendData::Async(usercall)); @@ -900,37 +898,32 @@ impl EnclaveState { let (usercall_event_tx, mut usercall_event_rx) = async_mpsc::unbounded_channel(); let usercall_event_tx_clone = usercall_event_tx.clone(); + let usercall_queue_monitor_clone = usercall_queue_monitor.clone(); tokio::task::spawn_local(async move { while let Ok(c) = cancel_queue_rx.recv().await { - if !c.ignore_cancel() { - let _ = usercall_event_tx_clone.send(UsercallEvent::Cancelled(c.id, Instant::now())); - } + let write_position = usercall_queue_monitor_clone.write_position(); + let _ = usercall_event_tx_clone.send(UsercallEvent::Cancelled(c.id, write_position)); } }); tokio::task::spawn_local(async move { let mut notifiers = HashMap::new(); - let mut cancels: HashMap = HashMap::new(); - // This should be greater than the amount of time it takes for the enclave runner - // to start executing a usercall after the enclave sends it on the usercall_queue. 
- const CANCEL_EXPIRY: Duration = Duration::from_millis(100); + let mut cancels: HashMap = HashMap::new(); loop { match usercall_event_rx.recv().await.expect("usercall_event channel closed unexpectedly") { UsercallEvent::Started(id, notifier) => match cancels.remove(&id) { - Some(t) if t.elapsed() < CANCEL_EXPIRY => { let _ = notifier.send(()); }, + Some(_) => { let _ = notifier.send(()); }, _ => { notifiers.insert(id, notifier); }, }, UsercallEvent::Finished(id) => { notifiers.remove(&id); }, - UsercallEvent::Cancelled(id, t) => if t.elapsed() < CANCEL_EXPIRY { - match notifiers.remove(&id) { - Some(notifier) => { let _ = notifier.send(()); }, - None => { cancels.insert(id, t); }, - } + UsercallEvent::Cancelled(id, wp) => match notifiers.remove(&id) { + Some(notifier) => { let _ = notifier.send(()); }, + None => { cancels.insert(id, wp); }, }, } - // cleanup expired cancels - let now = Instant::now(); - cancels.retain(|_id, &mut t| now - t < CANCEL_EXPIRY); + // cleanup old cancels + let read_position = usercall_queue_monitor.read_position(); + cancels.retain(|_id, wp| !read_position.is_past(wp)); } }); diff --git a/fortanix-sgx-abi/src/lib.rs b/fortanix-sgx-abi/src/lib.rs index 9aff4891..1bf22e0e 100644 --- a/fortanix-sgx-abi/src/lib.rs +++ b/fortanix-sgx-abi/src/lib.rs @@ -626,9 +626,13 @@ impl Usercalls { /// Additionally, userspace may choose to ignore cancellations for non-blocking /// usercalls. Userspace should be able to cancel a usercall that has been sent /// by the enclave but not yet received by the userspace, i.e. if cancellation -/// is received before the usercall itself. However, userspace should not keep -/// cancellations forever since that would prevent the enclave from re-using -/// usercall ids. +/// is received before the usercall itself. 
To avoid keeping such cancellations +/// forever and preventing the enclave from re-using usercall ids, userspace +/// should synchronize cancel queue with the usercall queue such that the +/// following invariant is maintained: whenever the enclave writes an id to the +/// usercall or cancel queue, the enclave will not reuse that id until the +/// usercall queue's read pointer has advanced to the write pointer at the time +/// the id was written. /// /// *TODO*: Add diagram. /// @@ -718,8 +722,8 @@ pub mod async { #[derive(Copy, Clone, Default)] #[cfg_attr(feature = "rustc-dep-of-std", unstable(feature = "sgx_platform", issue = "56975"))] pub struct Cancel { - /// This must be the same value as `Usercall.0`. - pub usercall_nr: u64, + /// Reserved for future use. + pub reserved: u64, } /// A circular buffer used as a FIFO queue with atomic reads and writes. diff --git a/ipc-queue/src/fifo.rs b/ipc-queue/src/fifo.rs index 31b92989..0a6005a0 100644 --- a/ipc-queue/src/fifo.rs +++ b/ipc-queue/src/fifo.rs @@ -6,7 +6,7 @@ use std::cell::UnsafeCell; use std::mem; -use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicU32, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use fortanix_sgx_abi::{FifoDescriptor, WithId}; @@ -33,7 +33,7 @@ where let arc = Arc::new(FifoBuffer::new(len)); let inner = Fifo::from_arc(arc); let tx = AsyncSender { inner: inner.clone(), synchronizer: s.clone() }; - let rx = AsyncReceiver { inner, synchronizer: s }; + let rx = AsyncReceiver { inner, synchronizer: s, read_epoch: Arc::new(AtomicU32::new(0)) }; (tx, rx) } @@ -87,6 +87,12 @@ impl Clone for Fifo { } } +impl Fifo { + pub(crate) fn current_offsets(&self, ordering: Ordering) -> Offsets { + Offsets::new(self.offsets.load(ordering), self.data.len() as u32) + } +} + impl Fifo { pub(crate) unsafe fn from_descriptor(descriptor: FifoDescriptor) -> Self { assert!( @@ -152,7 +158,7 @@ impl Fifo { pub(crate) fn try_send_impl(&self, val: Identified) -> Result 
{ let (new, was_empty) = loop { // 1. Load the current offsets. - let current = Offsets::new(self.offsets.load(Ordering::SeqCst), self.data.len() as u32); + let current = self.current_offsets(Ordering::SeqCst); let was_empty = current.is_empty(); // 2. If the queue is full, wait, then go to step 1. @@ -179,9 +185,9 @@ impl Fifo { Ok(was_empty) } - pub(crate) fn try_recv_impl(&self) -> Result<(Identified, /*wake up writer:*/ bool), TryRecvError> { + pub(crate) fn try_recv_impl(&self) -> Result<(Identified, /*wake up writer:*/ bool, /*read offset wrapped around:*/bool), TryRecvError> { // 1. Load the current offsets. - let current = Offsets::new(self.offsets.load(Ordering::SeqCst), self.data.len() as u32); + let current = self.current_offsets(Ordering::SeqCst); // 2. If the queue is empty, wait, then go to step 1. if current.is_empty() { @@ -216,7 +222,7 @@ impl Fifo { // 8. If the queue was full before step 7, signal the writer to wake up. let was_full = Offsets::new(before, self.data.len() as u32).is_full(); - Ok((val, was_full)) + Ok((val, was_full, new.read_offset() == 0)) } } @@ -282,6 +288,14 @@ impl Offsets { ..*self } } + + pub(crate) fn read_high_bit(&self) -> bool { + self.read & self.len == self.len + } + + pub(crate) fn write_high_bit(&self) -> bool { + self.write & self.len == self.len + } } #[cfg(test)] @@ -308,7 +322,7 @@ mod tests { } for i in 1..=7 { - let (v, wake) = inner.try_recv_impl().unwrap(); + let (v, wake, _) = inner.try_recv_impl().unwrap(); assert!(!wake); assert_eq!(v.id, i); assert_eq!(v.data.0, i); @@ -327,7 +341,7 @@ mod tests { assert!(inner.try_send_impl(Identified { id: 9, data: TestValue(9) }).is_err()); for i in 1..=8 { - let (v, wake) = inner.try_recv_impl().unwrap(); + let (v, wake, _) = inner.try_recv_impl().unwrap(); assert!(if i == 1 { wake } else { !wake }); assert_eq!(v.id, i); assert_eq!(v.data.0, i); diff --git a/ipc-queue/src/interface_async.rs b/ipc-queue/src/interface_async.rs index ea4dc9a7..5571a763 100644 --- 
a/ipc-queue/src/interface_async.rs +++ b/ipc-queue/src/interface_async.rs @@ -5,6 +5,7 @@ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ use super::*; +use std::sync::atomic::Ordering; unsafe impl Send for AsyncSender {} unsafe impl Sync for AsyncSender {} @@ -52,10 +53,13 @@ impl AsyncReceiver { pub async fn recv(&self) -> Result, RecvError> { loop { match self.inner.try_recv_impl() { - Ok((val, wake_sender)) => { + Ok((val, wake_sender, read_wrapped_around)) => { if wake_sender { self.synchronizer.notify(QueueEvent::NotFull); } + if read_wrapped_around { + self.read_epoch.fetch_add(1, Ordering::Relaxed); + } return Ok(val); } Err(TryRecvError::QueueEmpty) => { @@ -68,6 +72,13 @@ impl AsyncReceiver { } } + pub fn position_monitor(&self) -> PositionMonitor { + PositionMonitor { + read_epoch: self.read_epoch.clone(), + fifo: self.inner.clone(), + } + } + /// Consumes `self` and returns a DescriptorGuard. /// The returned guard can be used to make `FifoDescriptor`s that remain /// valid as long as the guard is not dropped. 
@@ -153,6 +164,65 @@ mod tests { do_multi_sender(1024, 30, 100).await; } + #[tokio::test] + async fn positions() { + const LEN: usize = 16; + let s = TestAsyncSynchronizer::new(); + let (tx, rx) = bounded_async(LEN, s); + let monitor = rx.position_monitor(); + let mut id = 1; + + let p0 = monitor.write_position(); + tx.send(Identified { id, data: TestValue(1) }).await.unwrap(); + let p1 = monitor.write_position(); + tx.send(Identified { id: id + 1, data: TestValue(2) }).await.unwrap(); + let p2 = monitor.write_position(); + tx.send(Identified { id: id + 2, data: TestValue(3) }).await.unwrap(); + let p3 = monitor.write_position(); + id += 3; + assert!(monitor.read_position().is_past(&p0) == false); + assert!(monitor.read_position().is_past(&p1) == false); + assert!(monitor.read_position().is_past(&p2) == false); + assert!(monitor.read_position().is_past(&p3) == false); + + rx.recv().await.unwrap(); + assert!(monitor.read_position().is_past(&p0) == true); + assert!(monitor.read_position().is_past(&p1) == false); + assert!(monitor.read_position().is_past(&p2) == false); + assert!(monitor.read_position().is_past(&p3) == false); + + rx.recv().await.unwrap(); + assert!(monitor.read_position().is_past(&p0) == true); + assert!(monitor.read_position().is_past(&p1) == true); + assert!(monitor.read_position().is_past(&p2) == false); + assert!(monitor.read_position().is_past(&p3) == false); + + rx.recv().await.unwrap(); + assert!(monitor.read_position().is_past(&p0) == true); + assert!(monitor.read_position().is_past(&p1) == true); + assert!(monitor.read_position().is_past(&p2) == true); + assert!(monitor.read_position().is_past(&p3) == false); + + for i in 0..1000 { + let n = 1 + (i % LEN); + let p4 = monitor.write_position(); + for _ in 0..n { + tx.send(Identified { id, data: TestValue(id) }).await.unwrap(); + id += 1; + } + let p5 = monitor.write_position(); + for _ in 0..n { + rx.recv().await.unwrap(); + assert!(monitor.read_position().is_past(&p0) == true); + 
assert!(monitor.read_position().is_past(&p1) == true); + assert!(monitor.read_position().is_past(&p2) == true); + assert!(monitor.read_position().is_past(&p3) == true); + assert!(monitor.read_position().is_past(&p4) == true); + assert!(monitor.read_position().is_past(&p5) == false); + } + } + } + struct Subscription { tx: broadcast::Sender, rx: Mutex>, diff --git a/ipc-queue/src/interface_sync.rs b/ipc-queue/src/interface_sync.rs index 66f39fe8..dfed16d4 100644 --- a/ipc-queue/src/interface_sync.rs +++ b/ipc-queue/src/interface_sync.rs @@ -112,7 +112,7 @@ impl Receiver { } pub fn try_recv(&self) -> Result, TryRecvError> { - self.inner.try_recv_impl().map(|(val, wake_sender)| { + self.inner.try_recv_impl().map(|(val, wake_sender, _)| { if wake_sender { self.synchronizer.notify(QueueEvent::NotFull); } @@ -127,7 +127,7 @@ impl Receiver { pub fn recv(&self) -> Result, RecvError> { loop { match self.inner.try_recv_impl() { - Ok((val, wake_sender)) => { + Ok((val, wake_sender, _)) => { if wake_sender { self.synchronizer.notify(QueueEvent::NotFull); } diff --git a/ipc-queue/src/lib.rs b/ipc-queue/src/lib.rs index 68a0f016..85b2a36b 100644 --- a/ipc-queue/src/lib.rs +++ b/ipc-queue/src/lib.rs @@ -10,6 +10,7 @@ use std::future::Future; #[cfg(target_env = "sgx")] use std::os::fortanix_sgx::usercalls::alloc::UserSafeSized; use std::pin::Pin; +use std::sync::atomic::AtomicU32; use std::sync::Arc; use fortanix_sgx_abi::FifoDescriptor; @@ -19,6 +20,7 @@ use self::fifo::{Fifo, FifoBuffer}; mod fifo; mod interface_sync; mod interface_async; +mod position; #[cfg(test)] mod test_support; @@ -123,6 +125,7 @@ pub struct AsyncSender { pub struct AsyncReceiver { inner: Fifo, synchronizer: S, + read_epoch: Arc, } /// `DescriptorGuard` can produce a `FifoDescriptor` that is guaranteed @@ -137,3 +140,19 @@ impl DescriptorGuard { self.descriptor } } + +/// `PositionMonitor` can be used to record the current read/write positions +/// of a queue. 
Even though a queue is comprised of a limited number of slots +/// arranged as a ring buffer, we can assign a position to each value written/ +/// read to/from the queue. This is useful in case we want to know whether or +/// not a particular value written to the queue has been read. +pub struct PositionMonitor<T: 'static> { + read_epoch: Arc<AtomicU32>, + fifo: Fifo<T>, +} + +/// A read position in a queue. +pub struct ReadPosition(u64); + +/// A write position in a queue. +pub struct WritePosition(u64); diff --git a/ipc-queue/src/position.rs b/ipc-queue/src/position.rs new file mode 100644 index 00000000..22c30e5d --- /dev/null +++ b/ipc-queue/src/position.rs @@ -0,0 +1,48 @@ +/* Copyright (c) Fortanix, Inc. + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use super::*; +use std::sync::atomic::Ordering; + +impl<T> PositionMonitor<T> { + pub fn read_position(&self) -> ReadPosition { + let current = self.fifo.current_offsets(Ordering::Relaxed); + let read_epoch = self.read_epoch.load(Ordering::Relaxed); + ReadPosition(((read_epoch as u64) << 32) | (current.read_offset() as u64)) + } + + pub fn write_position(&self) -> WritePosition { + let current = self.fifo.current_offsets(Ordering::Relaxed); + let mut write_epoch = self.read_epoch.load(Ordering::Relaxed); + if current.read_high_bit() != current.write_high_bit() { + write_epoch += 1; + } + WritePosition(((write_epoch as u64) << 32) | (current.write_offset() as u64)) + } +} + +impl<T> Clone for PositionMonitor<T> { + fn clone(&self) -> Self { + Self { + read_epoch: self.read_epoch.clone(), + fifo: self.fifo.clone(), + } + } +} + +impl ReadPosition { + /// A `WritePosition` can be compared to a `ReadPosition` **correctly** if + /// at most 2³¹ writes have occurred since the write position was recorded.
+ pub fn is_past(&self, write: &WritePosition) -> bool { + let (read, write) = (self.0, write.0); + let hr = read & (1 << 63); + let hw = write & (1 << 63); + if hr == hw { + return read > write; + } + true + } +}