From c05aae4763e4cab24b9bcf5634a3632045c5c240 Mon Sep 17 00:00:00 2001 From: Mohsen Zohrevandi Date: Fri, 2 Oct 2020 11:33:31 -0700 Subject: [PATCH] Async usercall interface for SGX enclaves --- .travis.yml | 3 +- Cargo.lock | 18 + Cargo.toml | 1 + async-usercalls/Cargo.toml | 31 ++ async-usercalls/rustfmt.toml | 1 + async-usercalls/src/alloc/allocator.rs | 145 ++++++++ async-usercalls/src/alloc/bitmap.rs | 156 +++++++++ async-usercalls/src/alloc/io_bufs.rs | 260 ++++++++++++++ async-usercalls/src/alloc/mod.rs | 69 ++++ async-usercalls/src/alloc/slab.rs | 198 +++++++++++ async-usercalls/src/alloc/tests.rs | 323 ++++++++++++++++++ async-usercalls/src/batch_drop.rs | 127 +++++++ async-usercalls/src/callback.rs | 89 +++++ async-usercalls/src/duplicated.rs | 168 +++++++++ async-usercalls/src/hacks/async_queues.rs | 50 +++ async-usercalls/src/hacks/mod.rs | 61 ++++ async-usercalls/src/hacks/unsafe_typecasts.rs | 95 ++++++ async-usercalls/src/lib.rs | 165 +++++++++ async-usercalls/src/provider_api.rs | 274 +++++++++++++++ async-usercalls/src/provider_core.rs | 69 ++++ async-usercalls/src/queues.rs | 188 ++++++++++ async-usercalls/src/raw.rs | 155 +++++++++ async-usercalls/src/tests.rs | 251 ++++++++++++++ async-usercalls/test.sh | 14 + 24 files changed, 2910 insertions(+), 1 deletion(-) create mode 100644 async-usercalls/Cargo.toml create mode 100644 async-usercalls/rustfmt.toml create mode 100644 async-usercalls/src/alloc/allocator.rs create mode 100644 async-usercalls/src/alloc/bitmap.rs create mode 100644 async-usercalls/src/alloc/io_bufs.rs create mode 100644 async-usercalls/src/alloc/mod.rs create mode 100644 async-usercalls/src/alloc/slab.rs create mode 100644 async-usercalls/src/alloc/tests.rs create mode 100644 async-usercalls/src/batch_drop.rs create mode 100644 async-usercalls/src/callback.rs create mode 100644 async-usercalls/src/duplicated.rs create mode 100644 async-usercalls/src/hacks/async_queues.rs create mode 100644 async-usercalls/src/hacks/mod.rs create mode 100644 async-usercalls/src/hacks/unsafe_typecasts.rs create mode 100644 async-usercalls/src/lib.rs create mode 100644 async-usercalls/src/provider_api.rs create mode 100644 async-usercalls/src/provider_core.rs create mode 100644 async-usercalls/src/queues.rs create mode 100644 async-usercalls/src/raw.rs create mode 100644 async-usercalls/src/tests.rs create mode 100755 async-usercalls/test.sh diff --git a/.travis.yml b/.travis.yml index 8bb268701..174f1b28d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -29,7 +29,8 @@ matrix: before_script: - rustup target add x86_64-fortanix-unknown-sgx x86_64-unknown-linux-musl script: - - cargo test --verbose --all + - cargo test --verbose --all --exclude async-usercalls + - cargo test --verbose -p async-usercalls --target x86_64-fortanix-unknown-sgx --no-run - cargo test --verbose -p sgx-isa --features sgxstd -Z package-features --target x86_64-fortanix-unknown-sgx --no-run - cargo test --verbose -p sgxs-tools --features pe2sgxs --bin isgx-pe2sgx -Z package-features - cargo test --verbose -p dcap-ql --features link -Z package-features diff --git a/Cargo.lock b/Cargo.lock index 7dae97b4b..9febce21e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -58,6 +58,18 @@ version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d663a8e9a99154b5fb793032533f6328da35e23aac63d5c152279aa8ba356825" +[[package]] +name = "async-usercalls" +version = "0.1.0" +dependencies = [ + "crossbeam-channel", + "fnv", + "fortanix-sgx-abi", + "ipc-queue", + "lazy_static 1.4.0", + 
"spin", +] + [[package]] name = "atty" version = "0.2.14" @@ -2204,6 +2216,12 @@ dependencies = [ "winapi 0.3.8", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "static_assertions" version = "1.1.0" diff --git a/Cargo.toml b/Cargo.toml index 96cf25bcb..df294fd1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "aesm-client", + "async-usercalls", "dcap-provider", "dcap-ql-sys", "dcap-ql", diff --git a/async-usercalls/Cargo.toml b/async-usercalls/Cargo.toml new file mode 100644 index 000000000..b1a1c7ee9 --- /dev/null +++ b/async-usercalls/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "async-usercalls" +version = "0.1.0" +authors = ["Fortanix, Inc."] +license = "MPL-2.0" +edition = "2018" +description = """ +An interface for asynchronous usercalls in SGX enclaves. + +This is an SGX-only crate, you should compile it with the `x86_64-fortanix-unknown-sgx` target. +""" +repository = "https://github.com/fortanix/rust-sgx" +documentation = "https://edp.fortanix.com/docs/api/async_usercalls/" +homepage = "https://edp.fortanix.com/" +keywords = ["sgx", "async", "usercall"] +categories = ["asynchronous"] + +[dependencies] +# Project dependencies +ipc-queue = { version = "0.1", path = "../ipc-queue" } +fortanix-sgx-abi = { version = "0.3", path = "../fortanix-sgx-abi" } + +# External dependencies +lazy_static = "1.4.0" # MIT/Apache-2.0 +crossbeam-channel = "0.4" # MIT/Apache-2.0 +spin = "0.5" # MIT/Apache-2.0 +fnv = "1.0" # MIT/Apache-2.0 + +# For cargo test --target x86_64-fortanix-unknown-sgx +[package.metadata.fortanix-sgx] +threads = 128 diff --git a/async-usercalls/rustfmt.toml b/async-usercalls/rustfmt.toml new file mode 100644 index 000000000..753065179 --- /dev/null +++ b/async-usercalls/rustfmt.toml @@ -0,0 +1 @@ +max_width = 120 diff --git a/async-usercalls/src/alloc/allocator.rs b/async-usercalls/src/alloc/allocator.rs new file mode 100644 index 000000000..7c6cef9f9 --- /dev/null +++ b/async-usercalls/src/alloc/allocator.rs @@ -0,0 +1,145 @@ +use super::slab::{BufSlab, Slab, SlabAllocator, User, MAX_COUNT}; +use std::cmp; +use std::os::fortanix_sgx::usercalls::raw::ByteBuffer; + +pub const MIN_BUF_SIZE: usize = 1 << 5; // 32 bytes +pub const MAX_BUF_SIZE: usize = 1 << 16; // 64 KB +pub const NUM_SIZES: usize = 1 + (MAX_BUF_SIZE / MIN_BUF_SIZE).trailing_zeros() as usize; + +pub struct SharedAllocator { + by_size: Vec>, + byte_buffers: Vec>, +} + +unsafe impl Send for SharedAllocator {} +unsafe impl Sync for SharedAllocator {} + +impl SharedAllocator { + pub fn new(buf_counts: [usize; NUM_SIZES], byte_buffer_count: usize) -> Self { + let mut by_size = Vec::with_capacity(NUM_SIZES); + for i in 0..NUM_SIZES { + by_size.push(make_buf_slabs(buf_counts[i], MIN_BUF_SIZE << i)); + } + let byte_buffers = make_byte_buffers(byte_buffer_count); + Self { by_size, byte_buffers } + } + + pub fn alloc_buf(&self, size: usize) -> Option> { + assert!(size > 0); + if size > MAX_BUF_SIZE { + return None; + } + let (_, index) = size_index(size); + self.by_size[index].alloc() + } + + pub fn alloc_byte_buffer(&self) -> Option> { + self.byte_buffers.alloc() + } +} + +pub struct LocalAllocator { + initial_buf_counts: [usize; NUM_SIZES], + initial_byte_buffer_count: usize, + inner: SharedAllocator, +} + +impl LocalAllocator { + pub fn new(initial_buf_counts: [usize; NUM_SIZES], initial_byte_buffer_count: usize) -> Self 
{ + let mut by_size = Vec::with_capacity(NUM_SIZES); + by_size.resize_with(NUM_SIZES, Default::default); + let byte_buffers = Vec::new(); + Self { + initial_buf_counts, + initial_byte_buffer_count, + inner: SharedAllocator { by_size, byte_buffers }, + } + } + + pub fn alloc_buf(&mut self, request_size: usize) -> User<[u8]> { + assert!(request_size > 0); + if request_size > MAX_BUF_SIZE { + // Always allocate very large buffers directly + return User::<[u8]>::uninitialized(request_size); + } + let (size, index) = size_index(request_size); + if let Some(buf) = self.inner.by_size[index].alloc() { + return buf; + } + let slabs = &mut self.inner.by_size[index]; + if slabs.len() >= 8 { + // Keep the number of slabs for each size small. + return User::<[u8]>::uninitialized(request_size); + } + let count = slabs.last().map_or(self.initial_buf_counts[index], |s| s.count() * 2); + // Limit each slab's count for better worst-case performance. + let count = cmp::min(count, MAX_COUNT / 8); + slabs.push(BufSlab::new(count, size)); + slabs.last().unwrap().alloc().expect("fresh slab failed to allocate") + } + + pub fn alloc_byte_buffer(&mut self) -> User { + let bbs = &mut self.inner.byte_buffers; + if let Some(byte_buffer) = bbs.alloc() { + return byte_buffer; + } + if bbs.len() >= 8 { + // Keep the number of slabs small. + return User::::uninitialized(); + } + let count = bbs.last().map_or(self.initial_byte_buffer_count, |s| s.count() * 2); + // Limit each slab's count for better worst-case performance. + let count = cmp::min(count, MAX_COUNT / 8); + bbs.push(Slab::new(count)); + bbs.last().unwrap().alloc().expect("fresh slab failed to allocate") + } +} + +fn make_buf_slabs(count: usize, size: usize) -> Vec { + match count { + 0 => Vec::new(), + n if n < 1024 => vec![BufSlab::new(n, size)], + n if n < 4 * 1024 => vec![BufSlab::new(n / 2, size), BufSlab::new(n / 2, size)], + n if n < 32 * 1024 => vec![ + BufSlab::new(n / 4, size), + BufSlab::new(n / 4, size), + BufSlab::new(n / 4, size), + BufSlab::new(n / 4, size), + ], + n => vec![ + BufSlab::new(n / 8, size), + BufSlab::new(n / 8, size), + BufSlab::new(n / 8, size), + BufSlab::new(n / 8, size), + BufSlab::new(n / 8, size), + BufSlab::new(n / 8, size), + BufSlab::new(n / 8, size), + BufSlab::new(n / 8, size), + ], + } +} + +fn make_byte_buffers(count: usize) -> Vec> { + match count { + 0 => Vec::new(), + n if n < 1024 => vec![Slab::new(n)], + n if n < 4 * 1024 => vec![Slab::new(n / 2), Slab::new(n / 2)], + n if n < 32 * 1024 => vec![Slab::new(n / 4), Slab::new(n / 4), Slab::new(n / 4), Slab::new(n / 4)], + n => vec![ + Slab::new(n / 8), + Slab::new(n / 8), + Slab::new(n / 8), + Slab::new(n / 8), + Slab::new(n / 8), + Slab::new(n / 8), + Slab::new(n / 8), + Slab::new(n / 8), + ], + } +} + +fn size_index(request_size: usize) -> (usize, usize) { + let size = cmp::max(MIN_BUF_SIZE, request_size.next_power_of_two()); + let index = (size / MIN_BUF_SIZE).trailing_zeros() as usize; + (size, index) +} diff --git a/async-usercalls/src/alloc/bitmap.rs b/async-usercalls/src/alloc/bitmap.rs new file mode 100644 index 000000000..80da1cca5 --- /dev/null +++ b/async-usercalls/src/alloc/bitmap.rs @@ -0,0 +1,156 @@ +use spin::Mutex; +use std::sync::atomic::*; + +pub struct OptionalBitmap(BitmapKind); + +struct LargeBitmap(Mutex); + +struct LargeBitmapInner { + bits: Box<[u64]>, + unset_count: usize, // optimization +} + +enum BitmapKind { + None, + V1(AtomicU8), + V2(AtomicU16), + V3(AtomicU32), + V4(AtomicU64), + V5(LargeBitmap), +} + +impl OptionalBitmap { + pub fn 
none() -> Self { + Self(BitmapKind::None) + } + + /// `bit_count` must be >= 8 and a power of two + pub fn new(bit_count: usize) -> Self { + Self(match bit_count { + 8 => BitmapKind::V1(AtomicU8::new(0)), + 16 => BitmapKind::V2(AtomicU16::new(0)), + 32 => BitmapKind::V3(AtomicU32::new(0)), + 64 => BitmapKind::V4(AtomicU64::new(0)), + n if n > 0 && n % 64 == 0 => { + let bits = vec![0u64; n / 64].into_boxed_slice(); + BitmapKind::V5(LargeBitmap(Mutex::new(LargeBitmapInner { + bits, + unset_count: bit_count, + }))) + } + _ => panic!("bit_count must be >= 8 and a power of two"), + }) + } + + /// set the bit at given index to 0 and panic if the old value was not 1. + pub fn unset(&self, index: usize) { + match self.0 { + BitmapKind::None => {} + BitmapKind::V1(ref a) => a.unset(index), + BitmapKind::V2(ref b) => b.unset(index), + BitmapKind::V3(ref c) => c.unset(index), + BitmapKind::V4(ref d) => d.unset(index), + BitmapKind::V5(ref e) => e.unset(index), + } + } + + /// return the index of a previously unset bit and set that bit to 1. + pub fn reserve(&self) -> Option { + match self.0 { + BitmapKind::None => None, + BitmapKind::V1(ref a) => a.reserve(), + BitmapKind::V2(ref b) => b.reserve(), + BitmapKind::V3(ref c) => c.reserve(), + BitmapKind::V4(ref d) => d.reserve(), + BitmapKind::V5(ref e) => e.reserve(), + } + } +} + +trait BitmapOps { + fn unset(&self, index: usize); + fn reserve(&self) -> Option; +} + +macro_rules! impl_bitmap_ops { + ( $( $t:ty ),* $(,)? ) => {$( + impl BitmapOps for $t { + fn unset(&self, index: usize) { + let bit = 1 << index; + let old = self.fetch_and(!bit, Ordering::Release) & bit; + assert!(old != 0); + } + + fn reserve(&self) -> Option { + let initial = self.load(Ordering::Relaxed); + let unset_count = initial.count_zeros(); + let (mut index, mut bit) = match unset_count { + 0 => return None, + _ => (0, 1), + }; + for _ in 0..unset_count { + // find the next unset bit + while bit & initial != 0 { + index += 1; + bit = bit << 1; + } + let old = self.fetch_or(bit, Ordering::Acquire) & bit; + if old == 0 { + return Some(index); + } + index += 1; + bit = bit << 1; + } + None + } + } + )*}; +} + +impl_bitmap_ops!(AtomicU8, AtomicU16, AtomicU32, AtomicU64); + +impl BitmapOps for LargeBitmap { + fn unset(&self, index: usize) { + let mut inner = self.0.lock(); + let array = &mut inner.bits; + assert!(index < array.len() * 64); + let slot = index / 64; + let offset = index % 64; + let element = &mut array[slot]; + + let bit = 1 << offset; + let old = *element & bit; + *element = *element & !bit; + inner.unset_count += 1; + assert!(old != 0); + } + + fn reserve(&self) -> Option { + let mut inner = self.0.lock(); + if inner.unset_count == 0 { + return None; + } + let array = &mut inner.bits; + for slot in 0..array.len() { + if let (Some(offset), val) = reserve_u64(array[slot]) { + array[slot] = val; + inner.unset_count -= 1; + return Some(slot * 64 + offset); + } + } + unreachable!() + } +} + +fn reserve_u64(element: u64) -> (Option, u64) { + let (mut index, mut bit) = match element.count_zeros() { + 0 => return (None, element), + _ => (0, 1), + }; + // find the first unset bit + while bit & element != 0 { + index += 1; + bit = bit << 1; + } + (Some(index), element | bit) +} diff --git a/async-usercalls/src/alloc/io_bufs.rs b/async-usercalls/src/alloc/io_bufs.rs new file mode 100644 index 000000000..3880e763e --- /dev/null +++ b/async-usercalls/src/alloc/io_bufs.rs @@ -0,0 +1,260 @@ +use super::slab::User; +use std::cell::UnsafeCell; +use std::cmp; +use std::io::IoSlice; 
+use std::ops::{Deref, DerefMut, Range}; +use std::os::fortanix_sgx::usercalls::alloc::UserRef; +use std::sync::Arc; + +pub struct UserBuf(UserBufKind); + +enum UserBufKind { + Owned { + user: User<[u8]>, + range: Range, + }, + Shared { + user: Arc>>, + range: Range, + }, +} + +impl UserBuf { + pub fn into_user(self) -> Result, Self> { + match self.0 { + UserBufKind::Owned { user, .. } => Ok(user), + UserBufKind::Shared { user, range } => Err(Self(UserBufKind::Shared { user, range })), + } + } + + fn into_shared(self) -> Option>>> { + match self.0 { + UserBufKind::Owned { .. } => None, + UserBufKind::Shared { user, .. } => Some(user), + } + } +} + +unsafe impl Send for UserBuf {} + +impl Deref for UserBuf { + type Target = UserRef<[u8]>; + + fn deref(&self) -> &Self::Target { + match self.0 { + UserBufKind::Owned { ref user, ref range } => &user[range.start..range.end], + UserBufKind::Shared { ref user, ref range } => { + let user = unsafe { &*user.get() }; + &user[range.start..range.end] + } + } + } +} + +impl DerefMut for UserBuf { + fn deref_mut(&mut self) -> &mut Self::Target { + match self.0 { + UserBufKind::Owned { + ref mut user, + ref range, + } => &mut user[range.start..range.end], + UserBufKind::Shared { ref user, ref range } => { + let user = unsafe { &mut *user.get() }; + &mut user[range.start..range.end] + } + } + } +} + +impl From> for UserBuf { + fn from(user: User<[u8]>) -> Self { + UserBuf(UserBufKind::Owned { + range: 0..user.len(), + user, + }) + } +} + +impl From<(User<[u8]>, Range)> for UserBuf { + fn from(pair: (User<[u8]>, Range)) -> Self { + UserBuf(UserBufKind::Owned { + user: pair.0, + range: pair.1, + }) + } +} + +/// `WriteBuffer` provides a ring buffer that can be written to by the code +/// running in the enclave while a portion of it can be passed to a `write` +/// usercall running concurrently. It ensures that enclave code does not write +/// to the portion sent to userspace. +pub struct WriteBuffer { + userbuf: Arc>>, + buf_len: usize, + read: u32, + write: u32, +} + +unsafe impl Send for WriteBuffer {} + +impl WriteBuffer { + pub fn new(userbuf: User<[u8]>) -> Self { + Self { + buf_len: userbuf.len(), + userbuf: Arc::new(UnsafeCell::new(userbuf)), + read: 0, + write: 0, + } + } + + pub fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> usize { + if self.is_full() { + return 0; + } + let mut wrote = 0; + for buf in bufs { + wrote += self.write(buf); + } + wrote + } + + pub fn write(&mut self, buf: &[u8]) -> usize { + let (_, write_offset) = self.offsets(); + let rem = self.remaining_capacity(); + let can_write = cmp::min(buf.len(), rem); + let end = cmp::min(self.buf_len, write_offset + can_write); + let n = end - write_offset; + unsafe { + let userbuf = &mut *self.userbuf.get(); + userbuf[write_offset..write_offset + n].copy_from_enclave(&buf[..n]); + } + self.advance_write(n); + n + if n < can_write { self.write(&buf[n..]) } else { 0 } + } + + /// This function returns a slice of bytes appropriate for writing to a socket. + /// Once some or all of these bytes are successfully written to the socket, + /// `self.consume()` must be called to actually consume those bytes. + /// + /// Returns None if the buffer is empty. + /// + /// Panics if called more than once in a row without either calling `consume()` + /// or dropping the previously returned buffer. 
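+    ///
+    /// A minimal usage sketch (illustrative only, not part of the original docs;
+    /// `send_to_userspace` is a hypothetical stand-in for whatever performs the
+    /// actual write, and `alloc_buf` is this crate's allocator helper):
+    ///
+    /// ```ignore
+    /// let mut wb = WriteBuffer::new(alloc_buf(1024));
+    /// wb.write(b"hello");
+    /// if let Some(chunk) = wb.consumable_chunk() {
+    ///     let n = send_to_userspace(&chunk); // e.g. pass `chunk` to a `write` usercall
+    ///     wb.consume(chunk, n);              // mark `n` bytes as consumed
+    /// }
+    /// ```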
+ pub fn consumable_chunk(&mut self) -> Option { + assert!( + Arc::strong_count(&self.userbuf) == 1, + "called consumable_chunk() more than once in a row" + ); + let range = match self.offsets() { + (_, _) if self.read == self.write => return None, // empty + (r, w) if r < w => r..w, + (r, _) => r..self.buf_len, + }; + Some(UserBuf(UserBufKind::Shared { + user: self.userbuf.clone(), + range, + })) + } + + /// Mark `n` bytes as consumed. `buf` must have been produced by a call + /// to `self.consumable_chunk()`. + /// Panics if: + /// - `n > buf.len()` + /// - `buf` was not produced by `self.consumable_chunk()` + /// + /// This function is supposed to be used in conjunction with `consumable_chunk()`. + pub fn consume(&mut self, buf: UserBuf, n: usize) { + assert!(n <= buf.len()); + const PANIC_MESSAGE: &'static str = "`buf` not produced by self.consumable_chunk()"; + let buf = buf.into_shared().expect(PANIC_MESSAGE); + assert!(Arc::ptr_eq(&self.userbuf, &buf), PANIC_MESSAGE); + drop(buf); + assert!(Arc::strong_count(&self.userbuf) == 1, PANIC_MESSAGE); + self.advance_read(n); + } + + fn len(&self) -> usize { + match self.offsets() { + (_, _) if self.read == self.write => 0, // empty + (r, w) if r == w && self.read != self.write => self.buf_len, // full + (r, w) if r < w => w - r, + (r, w) => w + self.buf_len - r, + } + } + + fn remaining_capacity(&self) -> usize { + let len = self.len(); + debug_assert!(len <= self.buf_len); + self.buf_len - len + } + + fn offsets(&self) -> (usize, usize) { + (self.read as usize % self.buf_len, self.write as usize % self.buf_len) + } + + pub fn is_empty(&self) -> bool { + self.read == self.write + } + + fn is_full(&self) -> bool { + let (read_offset, write_offset) = self.offsets(); + read_offset == write_offset && self.read != self.write + } + + fn advance_read(&mut self, by: usize) { + debug_assert!(by <= self.len()); + self.read = ((self.read as usize + by) % (self.buf_len * 2)) as _; + } + + fn advance_write(&mut self, by: usize) { + debug_assert!(by <= self.remaining_capacity()); + self.write = ((self.write as usize + by) % (self.buf_len * 2)) as _; + } +} + +pub struct ReadBuffer { + userbuf: User<[u8]>, + position: usize, + len: usize, +} + +impl ReadBuffer { + /// Constructs a new `ReadBuffer`, assuming `len` bytes of `userbuf` have + /// meaningful data. Panics if `len > userbuf.len()`. + pub fn new(userbuf: User<[u8]>, len: usize) -> ReadBuffer { + assert!(len <= userbuf.len()); + ReadBuffer { + userbuf, + position: 0, + len, + } + } + + pub fn read(&mut self, buf: &mut [u8]) -> usize { + debug_assert!(self.position <= self.len); + if self.position == self.len { + return 0; + } + let n = cmp::min(buf.len(), self.len - self.position); + self.userbuf[self.position..self.position + n].copy_to_enclave(&mut buf[..n]); + self.position += n; + n + } + + /// Returns the number of bytes that have not been read yet. + pub fn remaining_bytes(&self) -> usize { + debug_assert!(self.position <= self.len); + self.len - self.position + } + + pub fn len(&self) -> usize { + self.len + } + + /// Consumes self and returns the internal userspace buffer. + /// It's the caller's responsibility to ensure all bytes have been read + /// before calling this function. 
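+    ///
+    /// A short illustrative sketch (not part of the original docs), assuming
+    /// `rb` is a `ReadBuffer`:
+    ///
+    /// ```ignore
+    /// let mut scratch = [0u8; 64];
+    /// while rb.remaining_bytes() > 0 {
+    ///     let n = rb.read(&mut scratch);
+    ///     // ... process scratch[..n] ...
+    /// }
+    /// let userbuf = rb.into_inner(); // all bytes consumed, safe to reuse the buffer
+    /// ```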
+ pub fn into_inner(self) -> User<[u8]> { + self.userbuf + } +} diff --git a/async-usercalls/src/alloc/mod.rs b/async-usercalls/src/alloc/mod.rs new file mode 100644 index 000000000..ab1085c04 --- /dev/null +++ b/async-usercalls/src/alloc/mod.rs @@ -0,0 +1,69 @@ +use std::cell::RefCell; +use std::os::fortanix_sgx::usercalls::raw::ByteBuffer; + +mod allocator; +mod bitmap; +mod io_bufs; +mod slab; +#[cfg(test)] +mod tests; + +use self::allocator::{LocalAllocator, SharedAllocator}; +pub use self::io_bufs::{ReadBuffer, UserBuf, WriteBuffer}; +pub use self::slab::{User, UserSafeExt}; + +/// Allocates a slice of bytes in userspace that is at least as large as `size`. +pub fn alloc_buf(size: usize) -> User<[u8]> { + if let Some(buf) = SHARED.alloc_buf(size) { + return buf; + } + LOCAL.with(|local| local.borrow_mut().alloc_buf(size)) +} + +/// Allocates a `ByteBuffer` in userspace. +pub fn alloc_byte_buffer() -> User { + if let Some(bb) = SHARED.alloc_byte_buffer() { + return bb; + } + LOCAL.with(|local| local.borrow_mut().alloc_byte_buffer()) +} + +lazy_static::lazy_static! { + static ref SHARED: SharedAllocator = SharedAllocator::new( + [ + 8192, // x 32 bytes + 4096, // x 64 bytes + 2048, // x 128 bytes + 1024, // x 256 bytes + 512, // x 512 bytes + 256, // x 1 KB + 64, // x 2 KB + 32, // x 4 KB + 16, // x 8 KB + 1024, // x 16 KB + 32, // x 32 KB + 16, // x 64 KB + ], + 8192, // x ByteBuffer(s) + ); +} + +std::thread_local! { + static LOCAL: RefCell = RefCell::new(LocalAllocator::new( + [ + 128, // x 32 bytes + 64, // x 64 bytes + 32, // x 128 bytes + 16, // x 256 bytes + 8, // x 512 bytes + 8, // x 1 KB + 8, // x 2 KB + 8, // x 4 KB + 8, // x 8 KB + 8, // x 16 KB + 8, // x 32 KB + 8, // x 64 KB + ], + 64, // x ByteBuffer(s) + )); +} diff --git a/async-usercalls/src/alloc/slab.rs b/async-usercalls/src/alloc/slab.rs new file mode 100644 index 000000000..a9e0a0c44 --- /dev/null +++ b/async-usercalls/src/alloc/slab.rs @@ -0,0 +1,198 @@ +use super::bitmap::OptionalBitmap; +use std::cell::UnsafeCell; +use std::mem; +use std::ops::{Deref, DerefMut}; +use std::os::fortanix_sgx::usercalls::alloc::{User as StdUser, UserRef, UserSafe, UserSafeSized}; +use std::sync::Arc; + +pub const MIN_COUNT: usize = 8; +pub const MAX_COUNT: usize = 64 * 1024; +pub const MIN_UNIT_LEN: usize = 32; + +pub trait SlabAllocator { + type Output; + + fn alloc(&self) -> Option; + fn count(&self) -> usize; + fn total_size(&self) -> usize; +} + +impl SlabAllocator for Vec { + type Output = A::Output; + + fn alloc(&self) -> Option { + for a in self.iter() { + if let Some(buf) = a.alloc() { + return Some(buf); + } + } + None + } + + fn count(&self) -> usize { + self.iter().map(|a| a.count()).sum() + } + + fn total_size(&self) -> usize { + self.iter().map(|a| a.total_size()).sum() + } +} + +struct Storage { + user: UnsafeCell>, + bitmap: OptionalBitmap, +} + +pub struct BufSlab { + storage: Arc>, + unit_len: usize, +} + +impl BufSlab { + pub fn new(count: usize, unit_len: usize) -> Self { + assert!(count.is_power_of_two() && count >= MIN_COUNT && count <= MAX_COUNT); + assert!(unit_len.is_power_of_two() && unit_len >= MIN_UNIT_LEN); + BufSlab { + storage: Arc::new(Storage { + user: UnsafeCell::new(StdUser::<[u8]>::uninitialized(count * unit_len)), + bitmap: OptionalBitmap::new(count), + }), + unit_len, + } + } +} + +impl SlabAllocator for BufSlab { + type Output = User<[u8]>; + + fn alloc(&self) -> Option { + let index = self.storage.bitmap.reserve()?; + let start = index * self.unit_len; + let end = start + self.unit_len; + let 
user = unsafe { &mut *self.storage.user.get() }; + let user_ref = &mut user[start..end]; + Some(User { + user_ref, + storage: self.storage.clone(), + index, + }) + } + + fn count(&self) -> usize { + self.total_size() / self.unit_len + } + + fn total_size(&self) -> usize { + let user = unsafe { &*self.storage.user.get() }; + user.len() + } +} + +pub trait UserSafeExt: UserSafe { + type Element: UserSafeSized; +} + +impl UserSafeExt for [T] { + type Element = T; +} + +impl UserSafeExt for T { + type Element = T; +} + +pub struct User { + user_ref: &'static mut UserRef, + storage: Arc>, + index: usize, +} + +unsafe impl Send for User {} + +impl User { + pub fn uninitialized() -> Self { + let storage = Arc::new(Storage { + user: UnsafeCell::new(StdUser::<[T]>::uninitialized(1)), + bitmap: OptionalBitmap::none(), + }); + let user = unsafe { &mut *storage.user.get() }; + let user_ref = &mut user[0]; + Self { + user_ref, + storage, + index: 0, + } + } +} + +impl User<[T]> { + pub fn uninitialized(n: usize) -> Self { + let storage = Arc::new(Storage { + user: UnsafeCell::new(StdUser::<[T]>::uninitialized(n)), + bitmap: OptionalBitmap::none(), + }); + let user = unsafe { &mut *storage.user.get() }; + let user_ref = &mut user[..]; + Self { + user_ref, + storage, + index: 0, + } + } +} + +impl Drop for User { + fn drop(&mut self) { + self.storage.bitmap.unset(self.index); + } +} + +impl Deref for User { + type Target = UserRef; + + fn deref(&self) -> &Self::Target { + self.user_ref + } +} + +impl DerefMut for User { + fn deref_mut(&mut self) -> &mut Self::Target { + self.user_ref + } +} + +pub struct Slab(Arc>); + +impl Slab { + pub fn new(count: usize) -> Self { + assert!(count.is_power_of_two() && count >= MIN_COUNT && count <= MAX_COUNT); + Slab(Arc::new(Storage { + user: UnsafeCell::new(StdUser::<[T]>::uninitialized(count)), + bitmap: OptionalBitmap::new(count), + })) + } +} + +impl SlabAllocator for Slab { + type Output = User; + + fn alloc(&self) -> Option { + let index = self.0.bitmap.reserve()?; + let user = unsafe { &mut *self.0.user.get() }; + let user_ref = &mut user[index]; + Some(User { + user_ref, + storage: self.0.clone(), + index, + }) + } + + fn count(&self) -> usize { + let user = unsafe { &*self.0.user.get() }; + user.len() + } + + fn total_size(&self) -> usize { + let user = unsafe { &*self.0.user.get() }; + user.len() * mem::size_of::() + } +} diff --git a/async-usercalls/src/alloc/tests.rs b/async-usercalls/src/alloc/tests.rs new file mode 100644 index 000000000..da4e8b3d3 --- /dev/null +++ b/async-usercalls/src/alloc/tests.rs @@ -0,0 +1,323 @@ +use super::allocator::SharedAllocator; +use super::bitmap::*; +use super::io_bufs::{ReadBuffer, UserBuf, WriteBuffer}; +use super::slab::{BufSlab, Slab, SlabAllocator, User}; +use crossbeam_channel as mpmc; +use std::collections::HashSet; +use std::os::fortanix_sgx::usercalls::raw::ByteBuffer; +use std::sync::atomic::*; +use std::sync::Arc; +use std::thread; +use std::time::Instant; + +// Copied from Rust tests (test/ui/mpsc_stress.rs) +struct Barrier { + // Not using mutex/condvar for precision + shared: Arc, + count: usize, +} + +impl Barrier { + fn new(count: usize) -> Vec { + let shared = Arc::new(AtomicUsize::new(0)); + (0..count) + .map(|_| Barrier { + shared: shared.clone(), + count: count, + }) + .collect() + } + + /// Returns when `count` threads enter `wait` + fn wait(self) { + self.shared.fetch_add(1, Ordering::SeqCst); + while self.shared.load(Ordering::SeqCst) != self.count {} + } +} + +#[test] +fn bitmap() { + const BITS: 
usize = 1024; + let bitmap = OptionalBitmap::new(BITS); + for _ in 0..BITS { + assert!(bitmap.reserve().is_some()); + } + let mut indices = vec![34, 7, 5, 6, 120, 121, 122, 127, 0, 9] + .into_iter() + .collect::>(); + for &i in indices.iter() { + bitmap.unset(i); + } + while let Some(index) = bitmap.reserve() { + assert!(indices.remove(&index)); + } + assert!(indices.is_empty()); +} + +#[test] +fn bitmap_concurrent_use() { + const BITS: usize = 16; + const THREADS: usize = 4; + let bitmap = Arc::new(OptionalBitmap::new(BITS)); + for _ in 0..BITS - THREADS { + bitmap.reserve().unwrap(); + } + let mut handles = Vec::with_capacity(THREADS); + let mut barriers = Barrier::new(THREADS); + let (tx, rx) = mpmc::unbounded(); + + for _ in 0..THREADS { + let bitmap = Arc::clone(&bitmap); + let barrier = barriers.pop().unwrap(); + let tx = tx.clone(); + + handles.push(thread::spawn(move || { + barrier.wait(); + let index = bitmap.reserve().unwrap(); + tx.send(index).unwrap(); + })); + } + drop(tx); + for x in rx.iter() { + bitmap.unset(x); + } + for h in handles { + h.join().unwrap(); + } +} + +#[test] +fn buf_slab() { + const COUNT: usize = 16; + const SIZE: usize = 64; + let buf_slab = BufSlab::new(COUNT, SIZE); + + let bufs = (0..COUNT) + .map(|_| { + let buf = buf_slab.alloc().unwrap(); + assert!(buf.len() == SIZE); + buf + }) + .collect::>(); + + assert!(buf_slab.alloc().is_none()); + drop(bufs); + assert!(buf_slab.alloc().is_some()); +} + +#[test] +fn byte_buffer_slab() { + const COUNT: usize = 256; + let slab = Slab::::new(COUNT); + + let bufs = (0..COUNT) + .map(|_| slab.alloc().unwrap()) + .collect::>>(); + + assert!(slab.alloc().is_none()); + drop(bufs); + assert!(slab.alloc().is_some()); +} + +#[test] +fn user_is_send() { + const COUNT: usize = 16; + const SIZE: usize = 1024; + let buf_slab = BufSlab::new(COUNT, SIZE); + + let mut user = buf_slab.alloc().unwrap(); + + let h = thread::spawn(move || { + user[0..5].copy_from_enclave(b"hello"); + }); + + h.join().unwrap(); +} + +fn slab_speed(count: usize) { + let t0 = Instant::now(); + const SIZE: usize = 32; + const N: u32 = 100_000; + let buf_slab = BufSlab::new(count, SIZE); + + let bufs = (0..count - 1).map(|_| buf_slab.alloc().unwrap()).collect::>(); + + let mut x = 0; + for _ in 0..N { + let b = buf_slab.alloc().unwrap(); + x += b.len(); + } + drop(bufs); + drop(buf_slab); + let d = t0.elapsed(); + assert!(x > 0); // prevent the compiler from removing the whole loop above in release mode + println!("count = {} took {:?}", count, d / N); +} + +#[test] +#[ignore] +fn speed_slab() { + println!("\n"); + for i in 3..=16 { + slab_speed(1 << i); + } +} + +#[test] +#[ignore] +fn speed_direct() { + use std::os::fortanix_sgx::usercalls::alloc::User; + + let t0 = Instant::now(); + const SIZE: usize = 32; + const N: u32 = 100_000; + let mut x = 0; + for _ in 0..N { + let b = User::<[u8]>::uninitialized(SIZE); + x += b.len(); + } + let d = t0.elapsed(); + assert!(x > 0); + println!("took {:?}", d / N); +} + +#[test] +fn shared_allocator() { + let a = SharedAllocator::new( + [ + /*32:*/ 2048, /*64:*/ 1024, /*128:*/ 512, /*256:*/ 256, /*512:*/ 128, + /*1K:*/ 64, /*2K:*/ 0, /*4K:*/ 0, /*8K:*/ 0, /*16K:*/ 0, /*32K:*/ 0, + /*64K:*/ 1024, + ], + 1024, + ); + for size in 1..=32 { + let b = a.alloc_buf(size).unwrap(); + assert!(b.len() == 32); + } + for size in 33..=64 { + let b = a.alloc_buf(size).unwrap(); + assert!(b.len() == 64); + } + for &size in &[65, 79, 83, 120, 127, 128] { + let b = a.alloc_buf(size).unwrap(); + assert!(b.len() == 128); + } + for 
&size in &[129, 199, 210, 250, 255, 256] { + let b = a.alloc_buf(size).unwrap(); + assert!(b.len() == 256); + } + for &size in &[257, 299, 365, 500, 512] { + let b = a.alloc_buf(size).unwrap(); + assert!(b.len() == 512); + } + for &size in &[513, 768, 1023, 1024] { + let b = a.alloc_buf(size).unwrap(); + assert!(b.len() == 1024); + } + for i in 2..=32 { + assert!(a.alloc_buf(i * 1024).is_none()); + } + for i in 33..=64 { + let b = a.alloc_buf(i * 1024).unwrap(); + assert!(b.len() == 64 * 1024); + } +} + +fn alloc_speed(count: usize) { + let t0 = Instant::now(); + const SIZE: usize = 32; + const N: u32 = 100_000; + + let bufs = (0..count - 1).map(|_| super::alloc_buf(SIZE)).collect::>(); + + let mut x = 0; + for _ in 0..N { + let b = super::alloc_buf(SIZE); + x += b.len(); + } + drop(bufs); + let d = t0.elapsed(); + assert!(x > 0); + println!("count = {} took {:?}", count, d / N); +} + +#[test] +#[ignore] +fn speed_overall() { + println!("\n"); + for i in 3..=14 { + alloc_speed(1 << i); + } +} + +#[test] +fn alloc_buf_size() { + let b = super::alloc_buf(32); + assert_eq!(b.len(), 32); + let b = super::alloc_buf(128); + assert_eq!(b.len(), 128); + let b = super::alloc_buf(900); + assert_eq!(b.len(), 1024); + let b = super::alloc_buf(8 * 1024); + assert_eq!(b.len(), 8 * 1024); +} + +#[test] +fn write_buffer_basic() { + const LENGTH: usize = 1024; + let mut write_buffer = WriteBuffer::new(super::alloc_buf(1024)); + + let buf = vec![0u8; LENGTH]; + assert_eq!(write_buffer.write(&buf), LENGTH); + assert_eq!(write_buffer.write(&buf), 0); + + let chunk = write_buffer.consumable_chunk().unwrap(); + write_buffer.consume(chunk, 200); + assert_eq!(write_buffer.write(&buf), 200); + assert_eq!(write_buffer.write(&buf), 0); +} + +#[test] +#[should_panic] +fn call_consumable_chunk_twice() { + const LENGTH: usize = 1024; + let mut write_buffer = WriteBuffer::new(super::alloc_buf(1024)); + + let buf = vec![0u8; LENGTH]; + assert_eq!(write_buffer.write(&buf), LENGTH); + assert_eq!(write_buffer.write(&buf), 0); + + let chunk1 = write_buffer.consumable_chunk().unwrap(); + let _ = write_buffer.consumable_chunk().unwrap(); + drop(chunk1); +} + +#[test] +#[should_panic] +fn consume_wrong_buf() { + const LENGTH: usize = 1024; + let mut write_buffer = WriteBuffer::new(super::alloc_buf(1024)); + + let buf = vec![0u8; LENGTH]; + assert_eq!(write_buffer.write(&buf), LENGTH); + assert_eq!(write_buffer.write(&buf), 0); + + let unrelated_buf: UserBuf = super::alloc_buf(512).into(); + write_buffer.consume(unrelated_buf, 100); +} + +#[test] +fn read_buffer_basic() { + let mut buf = super::alloc_buf(64); + const DATA: &'static [u8] = b"hello"; + buf[0..DATA.len()].copy_from_enclave(DATA); + + let mut read_buffer = ReadBuffer::new(buf, DATA.len()); + assert_eq!(read_buffer.len(), DATA.len()); + assert_eq!(read_buffer.remaining_bytes(), DATA.len()); + let mut buf = [0u8; 8]; + assert_eq!(read_buffer.read(&mut buf), DATA.len()); + assert_eq!(read_buffer.remaining_bytes(), 0); + assert_eq!(&buf, b"hello\0\0\0"); +} diff --git a/async-usercalls/src/batch_drop.rs b/async-usercalls/src/batch_drop.rs new file mode 100644 index 000000000..f27b05c4a --- /dev/null +++ b/async-usercalls/src/batch_drop.rs @@ -0,0 +1,127 @@ +use crate::hacks::Usercall; +use crate::provider_core::ProviderCore; +use ipc_queue::Identified; +use std::cell::RefCell; +use std::mem; +use std::os::fortanix_sgx::usercalls::alloc::{User, UserSafe}; +use std::os::fortanix_sgx::usercalls::raw::UsercallNrs; + +pub trait BatchDropable: private::BatchDropable {} +impl 
<T: private::BatchDropable> BatchDropable for T {}
+
+/// Drop the given value at some point in the future (no rush!). This is useful
+/// for freeing userspace memory when we don't particularly care about when the
+/// buffer is freed. Multiple `free` usercalls are batched together and sent to
+/// userspace asynchronously. It is also guaranteed that the memory is freed if
+/// the current thread exits before there is a large enough batch.
+///
+/// This is mainly an optimization to avoid exiting the enclave for each
+/// usercall. Note that even when sending usercalls asynchronously, if the
+/// usercall queue is empty we still need to exit the enclave to signal
+/// userspace that the queue is not empty anymore. The batch send would send
+/// multiple usercalls and notify userspace at most once.
+pub fn batch_drop<T: BatchDropable>(t: T) {
+    t.batch_drop();
+}
+
+mod private {
+    use super::*;
+
+    const BATCH_SIZE: usize = 8;
+
+    struct BatchDropProvider {
+        core: ProviderCore,
+        deferred: Vec<Identified<Usercall>>,
+    }
+
+    impl BatchDropProvider {
+        pub fn new() -> Self {
+            Self {
+                core: ProviderCore::new(None),
+                deferred: Vec::with_capacity(BATCH_SIZE),
+            }
+        }
+
+        fn make_progress(&self, deferred: &[Identified<Usercall>]) -> usize {
+            let sent = self.core.try_send_multiple_usercalls(deferred);
+            if sent == 0 {
+                self.core.send_usercall(deferred[0]);
+                return 1;
+            }
+            sent
+        }
+
+        fn maybe_send_usercall(&mut self, u: Usercall) {
+            self.deferred.push(self.core.assign_id(u));
+            if self.deferred.len() < BATCH_SIZE {
+                return;
+            }
+            let sent = self.make_progress(&self.deferred);
+            let mut not_sent = self.deferred.split_off(sent);
+            self.deferred.clear();
+            self.deferred.append(&mut not_sent);
+        }
+
+        pub fn free<T: ?Sized + UserSafe>(&mut self, buf: User<T>) {
+            let ptr = buf.into_raw();
+            let size = unsafe { mem::size_of_val(&mut *ptr) };
+            let alignment = T::align_of();
+            let ptr = ptr as *mut u8;
+            let u = Usercall(UsercallNrs::free as _, ptr as _, size as _, alignment as _, 0);
+            self.maybe_send_usercall(u);
+        }
+    }
+
+    impl Drop for BatchDropProvider {
+        fn drop(&mut self) {
+            let mut sent = 0;
+            while sent < self.deferred.len() {
+                sent += self.make_progress(&self.deferred[sent..]);
+            }
+        }
+    }
+
+    std::thread_local!
{ + static PROVIDER: RefCell = RefCell::new(BatchDropProvider::new()); + } + + pub trait BatchDropable { + fn batch_drop(self); + } + + impl BatchDropable for User { + fn batch_drop(self) { + PROVIDER.with(|p| p.borrow_mut().free(self)); + } + } +} + +#[cfg(test)] +mod tests { + use super::batch_drop; + use std::os::fortanix_sgx::usercalls::alloc::User; + use std::thread; + + #[test] + fn basic() { + for _ in 0..100 { + batch_drop(User::<[u8]>::uninitialized(100)); + } + } + + #[test] + fn multiple_threads() { + const THREADS: usize = 16; + let mut handles = Vec::with_capacity(THREADS); + for _ in 0..THREADS { + handles.push(thread::spawn(move || { + for _ in 0..1000 { + batch_drop(User::<[u8]>::uninitialized(100)); + } + })); + } + for h in handles { + h.join().unwrap(); + } + } +} diff --git a/async-usercalls/src/callback.rs b/async-usercalls/src/callback.rs new file mode 100644 index 000000000..369ca2b89 --- /dev/null +++ b/async-usercalls/src/callback.rs @@ -0,0 +1,89 @@ +use crate::duplicated::{FromSgxResult, ReturnValue}; +use crate::hacks::Return; +use fortanix_sgx_abi::{Fd, Result as SxgResult}; +use std::io; + +pub struct CbFn(Box); + +impl CbFn { + fn call(self, t: T) { + (self.0)(t); + } +} + +impl From for CbFn +where + F: FnOnce(T) + Send + 'static, +{ + fn from(f: F) -> Self { + Self(Box::new(f)) + } +} + +pub(crate) enum Callback { + Read(CbFn>), + Write(CbFn>), + Flush(CbFn>), + Close(CbFn<()>), + BindStream(CbFn>), + AcceptStream(CbFn>), + ConnectStream(CbFn>), + InsecureTime(CbFn), + Alloc(CbFn>), + Free(CbFn<()>), +} + +impl Callback { + pub(crate) fn call(self, ret: Return) { + use Callback::*; + match self { + Read(cb) => { + let x: (SxgResult, usize) = ReturnValue::from_registers("read", (ret.0, ret.1)); + let x = x.from_sgx_result(); + cb.call(x); + } + Write(cb) => { + let x: (SxgResult, usize) = ReturnValue::from_registers("write", (ret.0, ret.1)); + let x = x.from_sgx_result(); + cb.call(x); + } + Flush(cb) => { + let x: SxgResult = ReturnValue::from_registers("flush", (ret.0, ret.1)); + let x = x.from_sgx_result(); + cb.call(x); + } + Close(cb) => { + assert_eq!((ret.0, ret.1), (0, 0)); + cb.call(()); + } + BindStream(cb) => { + let x: (SxgResult, Fd) = ReturnValue::from_registers("bind_stream", (ret.0, ret.1)); + let x = x.from_sgx_result(); + cb.call(x); + } + AcceptStream(cb) => { + let x: (SxgResult, Fd) = ReturnValue::from_registers("accept_stream", (ret.0, ret.1)); + let x = x.from_sgx_result(); + cb.call(x); + } + ConnectStream(cb) => { + let x: (SxgResult, Fd) = ReturnValue::from_registers("connect_stream", (ret.0, ret.1)); + let x = x.from_sgx_result(); + cb.call(x); + } + InsecureTime(cb) => { + let x: u64 = ReturnValue::from_registers("insecure_time", (ret.0, ret.1)); + cb.call(x); + } + Alloc(cb) => { + let x: (SxgResult, *mut u8) = ReturnValue::from_registers("alloc", (ret.0, ret.1)); + let x = x.from_sgx_result(); + cb.call(x); + } + Free(cb) => { + assert_eq!((ret.0, ret.1), (0, 0)); + cb.call(()); + } + } + } +} diff --git a/async-usercalls/src/duplicated.rs b/async-usercalls/src/duplicated.rs new file mode 100644 index 000000000..0a39e5a1f --- /dev/null +++ b/async-usercalls/src/duplicated.rs @@ -0,0 +1,168 @@ +//! 
this file contains code duplicated from libstd's sys/sgx +use fortanix_sgx_abi::{Error, Result, RESULT_SUCCESS}; +use std::io; +use std::ptr::NonNull; + +fn check_os_error(err: Result) -> i32 { + // FIXME: not sure how to make sure all variants of Error are covered + if err == Error::NotFound as _ + || err == Error::PermissionDenied as _ + || err == Error::ConnectionRefused as _ + || err == Error::ConnectionReset as _ + || err == Error::ConnectionAborted as _ + || err == Error::NotConnected as _ + || err == Error::AddrInUse as _ + || err == Error::AddrNotAvailable as _ + || err == Error::BrokenPipe as _ + || err == Error::AlreadyExists as _ + || err == Error::WouldBlock as _ + || err == Error::InvalidInput as _ + || err == Error::InvalidData as _ + || err == Error::TimedOut as _ + || err == Error::WriteZero as _ + || err == Error::Interrupted as _ + || err == Error::Other as _ + || err == Error::UnexpectedEof as _ + || ((Error::UserRangeStart as _)..=(Error::UserRangeEnd as _)).contains(&err) + { + err + } else { + panic!("Usercall: returned invalid error value {}", err) + } +} + +pub trait FromSgxResult { + type Return; + + fn from_sgx_result(self) -> io::Result; +} + +impl FromSgxResult for (Result, T) { + type Return = T; + + fn from_sgx_result(self) -> io::Result { + if self.0 == RESULT_SUCCESS { + Ok(self.1) + } else { + Err(io::Error::from_raw_os_error(check_os_error(self.0))) + } + } +} + +impl FromSgxResult for Result { + type Return = (); + + fn from_sgx_result(self) -> io::Result { + if self == RESULT_SUCCESS { + Ok(()) + } else { + Err(io::Error::from_raw_os_error(check_os_error(self))) + } + } +} + +type Register = u64; + +pub trait RegisterArgument { + fn from_register(_: Register) -> Self; + fn into_register(self) -> Register; +} + +pub trait ReturnValue { + fn from_registers(call: &'static str, regs: (Register, Register)) -> Self; +} + +macro_rules! define_ra { + (< $i:ident > $t:ty) => { + impl<$i> RegisterArgument for $t { + fn from_register(a: Register) -> Self { + a as _ + } + fn into_register(self) -> Register { + self as _ + } + } + }; + ($i:ty as $t:ty) => { + impl RegisterArgument for $t { + fn from_register(a: Register) -> Self { + a as $i as _ + } + fn into_register(self) -> Register { + self as $i as _ + } + } + }; + ($t:ty) => { + impl RegisterArgument for $t { + fn from_register(a: Register) -> Self { + a as _ + } + fn into_register(self) -> Register { + self as _ + } + } + }; +} + +define_ra!(Register); +define_ra!(i64); +define_ra!(u32); +define_ra!(u32 as i32); +define_ra!(u16); +define_ra!(u16 as i16); +define_ra!(u8); +define_ra!(u8 as i8); +define_ra!(usize); +define_ra!(usize as isize); +define_ra!( *const T); +define_ra!( *mut T); + +impl RegisterArgument for bool { + fn from_register(a: Register) -> bool { + if a != 0 { + true + } else { + false + } + } + fn into_register(self) -> Register { + self as _ + } +} + +impl RegisterArgument for Option> { + fn from_register(a: Register) -> Option> { + NonNull::new(a as _) + } + fn into_register(self) -> Register { + self.map_or(0 as _, NonNull::as_ptr) as _ + } +} + +impl ReturnValue for ! 
{ + fn from_registers(call: &'static str, _regs: (Register, Register)) -> Self { + panic!("Usercall {}: did not expect to be re-entered", call); + } +} + +impl ReturnValue for () { + fn from_registers(_call: &'static str, usercall_retval: (Register, Register)) -> Self { + assert!(usercall_retval.0 == 0); + assert!(usercall_retval.1 == 0); + () + } +} + +impl ReturnValue for T { + fn from_registers(_call: &'static str, usercall_retval: (Register, Register)) -> Self { + assert!(usercall_retval.1 == 0); + T::from_register(usercall_retval.0) + } +} + +impl ReturnValue for (T, U) { + fn from_registers(_call: &'static str, regs: (Register, Register)) -> Self { + (T::from_register(regs.0), U::from_register(regs.1)) + } +} diff --git a/async-usercalls/src/hacks/async_queues.rs b/async-usercalls/src/hacks/async_queues.rs new file mode 100644 index 000000000..a325a28fb --- /dev/null +++ b/async-usercalls/src/hacks/async_queues.rs @@ -0,0 +1,50 @@ +use super::{Cancel, Return, Usercall}; +use crate::duplicated::ReturnValue; +use fortanix_sgx_abi::FifoDescriptor; +use std::num::NonZeroU64; +use std::os::fortanix_sgx::usercalls; +use std::os::fortanix_sgx::usercalls::raw; +use std::{mem, ptr}; + +// TODO: remove these once support for cancel queue is added in `std::os::fortanix_sgx` + +pub unsafe fn async_queues( + usercall_queue: *mut FifoDescriptor, + return_queue: *mut FifoDescriptor, + cancel_queue: *mut FifoDescriptor, +) -> raw::Result { + ReturnValue::from_registers( + "async_queues", + raw::do_usercall( + NonZeroU64::new(raw::UsercallNrs::async_queues as _).unwrap(), + usercall_queue as _, + return_queue as _, + cancel_queue as _, + 0, + false, + ), + ) +} + +pub unsafe fn alloc_descriptor() -> *mut FifoDescriptor { + usercalls::alloc( + mem::size_of::>(), + mem::align_of::>(), + ) + .expect("failed to allocate userspace memory") as _ +} + +pub unsafe fn to_enclave(ptr: *mut FifoDescriptor) -> FifoDescriptor { + let mut dest: FifoDescriptor = mem::zeroed(); + ptr::copy( + ptr as *const u8, + (&mut dest) as *mut FifoDescriptor as *mut u8, + mem::size_of_val(&mut dest), + ); + usercalls::free( + ptr as _, + mem::size_of::>(), + mem::align_of::>(), + ); + dest +} diff --git a/async-usercalls/src/hacks/mod.rs b/async-usercalls/src/hacks/mod.rs new file mode 100644 index 000000000..6e7c183e1 --- /dev/null +++ b/async-usercalls/src/hacks/mod.rs @@ -0,0 +1,61 @@ +use std::ops::{Deref, DerefMut}; +use std::os::fortanix_sgx::usercalls::alloc::UserSafeSized; +use std::os::fortanix_sgx::usercalls::raw::ByteBuffer; + +mod async_queues; +mod unsafe_typecasts; + +pub use self::async_queues::{alloc_descriptor, async_queues, to_enclave}; +pub use self::unsafe_typecasts::{new_std_listener, new_std_stream}; + +#[repr(C)] +#[derive(Copy, Clone, Default)] +pub struct Usercall(pub u64, pub u64, pub u64, pub u64, pub u64); + +unsafe impl UserSafeSized for Usercall {} + +#[repr(C)] +#[derive(Copy, Clone, Default)] +pub struct Return(pub u64, pub u64); + +unsafe impl UserSafeSized for Return {} + +#[repr(C)] +#[derive(Copy, Clone, Default)] +pub struct Cancel { + /// This must be the same value as `Usercall.0`. 
+ pub usercall_nr: u64, +} + +unsafe impl UserSafeSized for Cancel {} + +// Interim solution until we mark the target types appropriately +pub(crate) struct MakeSend(T); + +impl MakeSend { + pub fn new(t: T) -> Self { + Self(t) + } + + #[allow(unused)] + pub fn into_inner(self) -> T { + self.0 + } +} + +impl Deref for MakeSend { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for MakeSend { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +unsafe impl Send for MakeSend {} +unsafe impl Send for MakeSend> {} diff --git a/async-usercalls/src/hacks/unsafe_typecasts.rs b/async-usercalls/src/hacks/unsafe_typecasts.rs new file mode 100644 index 000000000..1e3d67c5e --- /dev/null +++ b/async-usercalls/src/hacks/unsafe_typecasts.rs @@ -0,0 +1,95 @@ +//! The incredibly unsafe code in this module allows us to create +//! `std::net::TcpStream` and `std::net::TcpListener` types from their raw +//! components in SGX. +//! +//! This is obviously very unsafe and not maintainable and is only intended as +//! an iterim solution until we add similar functionality as extension traits +//! in `std::os::fortanix_sgx`. +use fortanix_sgx_abi::Fd; + +mod sgx { + use fortanix_sgx_abi::Fd; + use std::sync::Arc; + + #[derive(Debug)] + pub struct FileDesc { + fd: Fd, + } + + #[derive(Debug, Clone)] + pub struct Socket { + inner: Arc, + local_addr: Option, + } + + #[derive(Clone)] + pub struct TcpStream { + inner: Socket, + peer_addr: Option, + } + + impl TcpStream { + pub fn new(fd: Fd, local_addr: Option, peer_addr: Option) -> TcpStream { + TcpStream { + inner: Socket { + inner: Arc::new(FileDesc { fd }), + local_addr, + }, + peer_addr, + } + } + } + + #[derive(Clone)] + pub struct TcpListener { + inner: Socket, + } + + impl TcpListener { + pub fn new(fd: Fd, local_addr: Option) -> TcpListener { + TcpListener { + inner: Socket { + inner: Arc::new(FileDesc { fd }), + local_addr, + }, + } + } + } +} + +struct TcpStream(self::sgx::TcpStream); +struct TcpListener(self::sgx::TcpListener); + +pub unsafe fn new_std_stream(fd: Fd, local_addr: Option, peer_addr: Option) -> std::net::TcpStream { + let stream = TcpStream(sgx::TcpStream::new(fd, local_addr, peer_addr)); + std::mem::transmute(stream) +} + +pub unsafe fn new_std_listener(fd: Fd, local_addr: Option) -> std::net::TcpListener { + let listener = TcpListener(sgx::TcpListener::new(fd, local_addr)); + std::mem::transmute(listener) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem; + use std::os::fortanix_sgx::io::AsRawFd; + + #[test] + fn sanity_check() { + let fd = 42; + let local = "1.2.3.4:1234"; + let peer = "5.6.7.8:443"; + let stream = unsafe { new_std_stream(fd, Some(local.to_owned()), Some(peer.to_owned())) }; + assert_eq!(stream.as_raw_fd(), fd); + assert_eq!(stream.local_addr().unwrap().to_string(), local); + assert_eq!(stream.peer_addr().unwrap().to_string(), peer); + mem::forget(stream); // not a real stream... + + let listener = unsafe { new_std_listener(fd, Some(local.to_owned())) }; + assert_eq!(listener.as_raw_fd(), fd); + assert_eq!(listener.local_addr().unwrap().to_string(), local); + mem::forget(listener); // not a real listener... 
+    }
+}
diff --git a/async-usercalls/src/lib.rs b/async-usercalls/src/lib.rs
new file mode 100644
index 000000000..45c377813
--- /dev/null
+++ b/async-usercalls/src/lib.rs
@@ -0,0 +1,165 @@
+#![feature(sgx_platform)]
+#![feature(never_type)]
+#![cfg_attr(test, feature(unboxed_closures))]
+#![cfg_attr(test, feature(fn_traits))]
+
+use crossbeam_channel as mpmc;
+use ipc_queue::Identified;
+use std::collections::HashMap;
+use std::os::fortanix_sgx::usercalls::raw::UsercallNrs;
+use std::panic;
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+use std::thread::{self, JoinHandle};
+
+mod alloc;
+mod batch_drop;
+mod callback;
+mod duplicated;
+mod hacks;
+mod provider_api;
+mod provider_core;
+mod queues;
+mod raw;
+#[cfg(test)]
+mod tests;
+
+pub use self::alloc::{alloc_buf, alloc_byte_buffer, ReadBuffer, User, UserBuf, UserSafeExt, WriteBuffer};
+pub use self::batch_drop::batch_drop;
+pub use self::callback::CbFn;
+pub use self::raw::RawApi;
+
+use self::callback::*;
+use self::hacks::{Cancel, Return, Usercall};
+use self::provider_core::ProviderCore;
+use self::queues::*;
+
+pub struct CancelHandle<'p> {
+    c: Identified<Cancel>,
+    tx: &'p Sender<Cancel>,
+}
+
+impl<'p> CancelHandle<'p> {
+    pub fn cancel(self) {
+        self.tx.send(self.c).expect("failed to send cancellation");
+    }
+
+    pub(crate) fn new(c: Identified<Cancel>, tx: &'p Sender<Cancel>) -> Self {
+        CancelHandle { c, tx }
+    }
+}
+
+/// This type provides a mechanism for submitting usercalls asynchronously.
+/// Usercalls are sent to the enclave runner through a queue. The results are
+/// retrieved on a dedicated thread. Users are notified of the results through
+/// callback functions.
+///
+/// Users of this type should take care not to block execution in callbacks.
+/// Certain usercalls can be cancelled through a handle, but note that it is
+/// still possible to receive successful results for cancelled usercalls.
+pub struct AsyncUsercallProvider {
+    core: ProviderCore,
+    callback_tx: mpmc::Sender<(u64, Callback)>,
+    shutdown: Arc<AtomicBool>,
+    join_handle: Option<JoinHandle<()>>,
+}
+
+impl AsyncUsercallProvider {
+    pub fn new() -> Self {
+        let (return_tx, return_rx) = mpmc::unbounded();
+        let core = ProviderCore::new(Some(return_tx));
+        let (callback_tx, callback_rx) = mpmc::unbounded();
+        let shutdown = Arc::new(AtomicBool::new(false));
+        let callback_handler = CallbackHandler {
+            return_rx,
+            callback_rx,
+            shutdown: Arc::clone(&shutdown),
+        };
+        let join_handle = thread::spawn(move || callback_handler.run());
+        Self {
+            core,
+            callback_tx,
+            shutdown,
+            join_handle: Some(join_handle),
+        }
+    }
+
+    #[cfg(test)]
+    pub(crate) fn provider_id(&self) -> u32 {
+        self.core.provider_id()
+    }
+
+    fn send_usercall(&self, usercall: Usercall, callback: Option<Callback>) -> CancelHandle {
+        let usercall = self.core.assign_id(usercall);
+        if let Some(callback) = callback {
+            self.callback_tx
+                .send((usercall.id, callback))
+                .expect("failed to send callback");
+        }
+        self.core.send_usercall(usercall)
+    }
+}
+
+impl Drop for AsyncUsercallProvider {
+    fn drop(&mut self) {
+        self.shutdown.store(true, Ordering::Release);
+        // send a usercall to ensure CallbackHandler wakes up and breaks its loop.
+ let u = Usercall(UsercallNrs::insecure_time as _, 0, 0, 0, 0); + self.send_usercall(u, None); + let join_handle = self.join_handle.take().unwrap(); + join_handle.join().unwrap(); + } +} + +struct CallbackHandler { + return_rx: mpmc::Receiver>, + callback_rx: mpmc::Receiver<(u64, Callback)>, + shutdown: Arc, +} + +impl CallbackHandler { + const BATCH: usize = 1024; + + fn recv_returns(&self) -> ([Identified; Self::BATCH], usize) { + let first = self.return_rx.recv().expect("channel closed unexpectedly"); + let mut returns = [Identified { + id: 0, + data: Return(0, 0), + }; Self::BATCH]; + let mut count = 0; + for ret in std::iter::once(first).chain(self.return_rx.try_iter().take(Self::BATCH - 1)) { + returns[count] = ret; + count += 1; + } + (returns, count) + } + + fn run(self) { + let mut callbacks = HashMap::with_capacity(256); + loop { + // block until there are some returns + let (returns, count) = self.recv_returns(); + // receive pending callbacks + for (id, callback) in self.callback_rx.try_iter() { + callbacks.insert(id, callback); + } + for ret in &returns[..count] { + if let Some(cb) = callbacks.remove(&ret.id) { + let _r = panic::catch_unwind(panic::AssertUnwindSafe(move || { + cb.call(ret.data); + })); + // if let Err(e) = _r { + // let msg = e + // .downcast_ref::() + // .map(String::as_str) + // .or_else(|| e.downcast_ref::<&str>().map(|&s| s)); + // println!("callback paniced: {:?}", msg); + // } + } + } + if self.shutdown.load(Ordering::Acquire) { + break; + } + } + } +} diff --git a/async-usercalls/src/provider_api.rs b/async-usercalls/src/provider_api.rs new file mode 100644 index 000000000..087a22ee7 --- /dev/null +++ b/async-usercalls/src/provider_api.rs @@ -0,0 +1,274 @@ +use crate::alloc::{alloc_buf, alloc_byte_buffer, User, UserBuf}; +use crate::batch_drop; +use crate::hacks::{new_std_listener, new_std_stream, MakeSend}; +use crate::raw::RawApi; +use crate::{AsyncUsercallProvider, CancelHandle}; +use fortanix_sgx_abi::Fd; +use std::io; +use std::mem::{self, ManuallyDrop}; +use std::net::{TcpListener, TcpStream}; +use std::os::fortanix_sgx::usercalls::alloc::{User as StdUser, UserRef, UserSafe}; +use std::os::fortanix_sgx::usercalls::raw::ByteBuffer; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +impl AsyncUsercallProvider { + /// Sends an asynchronous `read` usercall. `callback` is called when a + /// return value is received from userspace. `read_buf` is returned as an + /// argument to `callback` along with the result of the `read` usercall. + /// + /// Returns a handle that can be used to cancel the usercall if desired. + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn read(&self, fd: Fd, read_buf: User<[u8]>, callback: F) -> CancelHandle + where + F: FnOnce(io::Result, User<[u8]>) + Send + 'static, + { + let mut read_buf = ManuallyDrop::new(read_buf); + let ptr = read_buf.as_mut_ptr(); + let len = read_buf.len(); + let cb = move |res: io::Result| { + let read_buf = ManuallyDrop::into_inner(read_buf); + callback(res, read_buf); + }; + unsafe { self.raw_read(fd, ptr, len, Some(cb.into())) } + } + + /// Sends an asynchronous `write` usercall. `callback` is called when a + /// return value is received from userspace. `write_buf` is returned as an + /// argument to `callback` along with the result of the `write` usercall. + /// + /// Returns a handle that can be used to cancel the usercall if desired. + /// Please refer to the type-level documentation for general notes about + /// callbacks. 
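+    ///
+    /// A hedged usage sketch (illustrative only; `provider` and `fd` are
+    /// assumed to already exist):
+    ///
+    /// ```ignore
+    /// let mut buf = alloc_buf(5);
+    /// buf[0..5].copy_from_enclave(b"hello");
+    /// let _handle = provider.write(fd, buf.into(), |res, _buf| {
+    ///     // Runs on the callback thread; keep this short and non-blocking.
+    ///     println!("write returned {:?}", res);
+    /// });
+    /// ```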
+ pub fn write(&self, fd: Fd, write_buf: UserBuf, callback: F) -> CancelHandle + where + F: FnOnce(io::Result, UserBuf) + Send + 'static, + { + let mut write_buf = ManuallyDrop::new(write_buf); + let ptr = write_buf.as_mut_ptr(); + let len = write_buf.len(); + let cb = move |res| { + let write_buf = ManuallyDrop::into_inner(write_buf); + callback(res, write_buf); + }; + unsafe { self.raw_write(fd, ptr, len, Some(cb.into())) } + } + + /// Sends an asynchronous `flush` usercall. `callback` is called when a + /// return value is received from userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn flush(&self, fd: Fd, callback: F) + where + F: FnOnce(io::Result<()>) + Send + 'static, + { + unsafe { + self.raw_flush(fd, Some(callback.into())); + } + } + + /// Sends an asynchronous `close` usercall. If specified, `callback` is + /// called when a return is received from userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn close(&self, fd: Fd, callback: Option) + where + F: FnOnce() + Send + 'static, + { + let cb = callback.map(|callback| move |()| callback()); + unsafe { + self.raw_close(fd, cb.map(Into::into)); + } + } + + /// Sends an asynchronous `bind_stream` usercall. `callback` is called when + /// a return value is received from userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn bind_stream(&self, addr: &str, callback: F) + where + F: FnOnce(io::Result) + Send + 'static, + { + let mut addr_buf = ManuallyDrop::new(alloc_buf(addr.len())); + let mut local_addr = ManuallyDrop::new(MakeSend::new(alloc_byte_buffer())); + + addr_buf[0..addr.len()].copy_from_enclave(addr.as_bytes()); + let addr_buf_ptr = addr_buf.as_raw_mut_ptr() as *mut u8; + let local_addr_ptr = local_addr.as_raw_mut_ptr(); + + let cb = move |res: io::Result| { + let _addr_buf = ManuallyDrop::into_inner(addr_buf); + let local_addr = ManuallyDrop::into_inner(local_addr); + + let local = string_from_bytebuffer(&local_addr, "bind_stream", "local_addr"); + let res = res.map(|fd| unsafe { new_std_listener(fd, Some(local)) }); + callback(res); + }; + unsafe { self.raw_bind_stream(addr_buf_ptr, addr.len(), local_addr_ptr, Some(cb.into())) } + } + + /// Sends an asynchronous `accept_stream` usercall. `callback` is called + /// when a return value is received from userspace. + /// + /// Returns a handle that can be used to cancel the usercall if desired. + /// Please refer to the type-level documentation for general notes about + /// callbacks. 
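+    ///
+    /// A minimal sketch (illustrative only; assumes `provider` is an
+    /// `AsyncUsercallProvider` and `listener_fd` refers to an already-bound
+    /// listener):
+    ///
+    /// ```ignore
+    /// let _handle = provider.accept_stream(listener_fd, |res| match res {
+    ///     Ok(stream) => { /* hand the std::net::TcpStream off to a worker */ }
+    ///     Err(e) => eprintln!("accept failed: {}", e),
+    /// });
+    /// ```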
+ pub fn accept_stream(&self, fd: Fd, callback: F) -> CancelHandle + where + F: FnOnce(io::Result) + Send + 'static, + { + let mut local_addr = ManuallyDrop::new(MakeSend::new(alloc_byte_buffer())); + let mut peer_addr = ManuallyDrop::new(MakeSend::new(alloc_byte_buffer())); + + let local_addr_ptr = local_addr.as_raw_mut_ptr(); + let peer_addr_ptr = peer_addr.as_raw_mut_ptr(); + + let cb = move |res: io::Result| { + let local_addr = ManuallyDrop::into_inner(local_addr); + let peer_addr = ManuallyDrop::into_inner(peer_addr); + + let local = string_from_bytebuffer(&*local_addr, "accept_stream", "local_addr"); + let peer = string_from_bytebuffer(&*peer_addr, "accept_stream", "peer_addr"); + let res = res.map(|fd| unsafe { new_std_stream(fd, Some(local), Some(peer)) }); + callback(res); + }; + unsafe { self.raw_accept_stream(fd, local_addr_ptr, peer_addr_ptr, Some(cb.into())) } + } + + /// Sends an asynchronous `connect_stream` usercall. `callback` is called + /// when a return value is received from userspace. + /// + /// Returns a handle that can be used to cancel the usercall if desired. + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn connect_stream(&self, addr: &str, callback: F) -> CancelHandle + where + F: FnOnce(io::Result) + Send + 'static, + { + let mut addr_buf = ManuallyDrop::new(alloc_buf(addr.len())); + let mut local_addr = ManuallyDrop::new(MakeSend::new(alloc_byte_buffer())); + let mut peer_addr = ManuallyDrop::new(MakeSend::new(alloc_byte_buffer())); + + addr_buf[0..addr.len()].copy_from_enclave(addr.as_bytes()); + let addr_buf_ptr = addr_buf.as_raw_mut_ptr() as *mut u8; + let local_addr_ptr = local_addr.as_raw_mut_ptr(); + let peer_addr_ptr = peer_addr.as_raw_mut_ptr(); + + let cb = move |res: io::Result| { + let _addr_buf = ManuallyDrop::into_inner(addr_buf); + let local_addr = ManuallyDrop::into_inner(local_addr); + let peer_addr = ManuallyDrop::into_inner(peer_addr); + + let local = string_from_bytebuffer(&local_addr, "connect_stream", "local_addr"); + let peer = string_from_bytebuffer(&peer_addr, "connect_stream", "peer_addr"); + let res = res.map(|fd| unsafe { new_std_stream(fd, Some(local), Some(peer)) }); + callback(res); + }; + unsafe { self.raw_connect_stream(addr_buf_ptr, addr.len(), local_addr_ptr, peer_addr_ptr, Some(cb.into())) } + } + + /// Sends an asynchronous `alloc` usercall to allocate one instance of `T` + /// in userspace. `callback` is called when a return value is received from + /// userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn alloc(&self, callback: F) + where + T: UserSafe, + F: FnOnce(io::Result>) + Send + 'static, + { + let cb = move |res: io::Result<*mut u8>| { + let res = res.map(|ptr| unsafe { StdUser::::from_raw(ptr as _) }); + callback(res); + }; + unsafe { + self.raw_alloc(mem::size_of::(), T::align_of(), Some(cb.into())); + } + } + + /// Sends an asynchronous `alloc` usercall to allocate a slice of `T` in + /// userspace with the specified `len`. `callback` is called when a return + /// value is received from userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. 
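+    ///
+    /// # Example
+    ///
+    /// An illustrative sketch (marked `ignore`, not run as a doc test),
+    /// assuming `provider` is an `AsyncUsercallProvider`:
+    ///
+    /// ```ignore
+    /// provider.alloc_slice::<u8, _>(1024, |res| {
+    ///     let user_buf = res.expect("failed to allocate userspace memory");
+    ///     assert_eq!(user_buf.len(), 1024);
+    ///     // The buffer can later be released with the asynchronous `free`.
+    /// });
+    /// ```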
+ pub fn alloc_slice(&self, len: usize, callback: F) + where + [T]: UserSafe, + F: FnOnce(io::Result>) + Send + 'static, + { + let cb = move |res: io::Result<*mut u8>| { + let res = res.map(|ptr| unsafe { StdUser::<[T]>::from_raw_parts(ptr as _, len) }); + callback(res); + }; + unsafe { + self.raw_alloc(len * mem::size_of::(), <[T]>::align_of(), Some(cb.into())); + } + } + + /// Sends an asynchronous `free` usercall to deallocate the userspace + /// buffer `buf`. If specified, `callback` is called when a return is + /// received from userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn free(&self, mut buf: StdUser, callback: Option) + where + T: ?Sized + UserSafe, + F: FnOnce() + Send + 'static, + { + let ptr = buf.as_raw_mut_ptr(); + let cb = callback.map(|callback| move |()| callback()); + unsafe { + self.raw_free( + buf.into_raw() as _, + mem::size_of_val(&mut *ptr), + T::align_of(), + cb.map(Into::into), + ); + } + } + + /// Sends an asynchronous `insecure_time` usercall. `callback` is called + /// when a return value is received from userspace. + /// + /// Please refer to the type-level documentation for general notes about + /// callbacks. + pub fn insecure_time(&self, callback: F) + where + F: FnOnce(SystemTime) + Send + 'static, + { + let cb = move |nanos_since_epoch| { + let t = UNIX_EPOCH + Duration::from_nanos(nanos_since_epoch); + callback(t); + }; + unsafe { + self.raw_insecure_time(Some(cb.into())); + } + } +} + +fn string_from_bytebuffer(buf: &UserRef, usercall: &str, arg: &str) -> String { + String::from_utf8(copy_user_buffer(buf)) + .unwrap_or_else(|_| panic!("Usercall {}: expected {} to be valid UTF-8", usercall, arg)) +} + +// adapted from libstd sys/sgx/abi/usercalls/alloc.rs +fn copy_user_buffer(buf: &UserRef) -> Vec { + unsafe { + let buf = buf.to_enclave(); + if buf.len > 0 { + let user = StdUser::from_raw_parts(buf.data as _, buf.len); + let v = user.to_enclave(); + batch_drop(user); + v + } else { + // Mustn't look at `data` or call `free` if `len` is `0`. 
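+            // A zero-length ByteBuffer carries no userspace allocation, so
+            // return an empty Vec without touching `data`.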
+ Vec::new() + } + } +} diff --git a/async-usercalls/src/provider_core.rs b/async-usercalls/src/provider_core.rs new file mode 100644 index 000000000..3d891d0b8 --- /dev/null +++ b/async-usercalls/src/provider_core.rs @@ -0,0 +1,69 @@ +use crate::hacks::{Cancel, Return, Usercall}; +use crate::queues::*; +use crate::CancelHandle; +use crossbeam_channel as mpmc; +use ipc_queue::Identified; +use std::sync::atomic::{AtomicU32, Ordering}; + +pub(crate) struct ProviderCore { + usercall_tx: Sender, + cancel_tx: Sender, + provider_id: u32, + next_id: AtomicU32, +} + +impl ProviderCore { + pub fn new(return_tx: Option>>) -> ProviderCore { + let (usercall_tx, cancel_tx, provider_id) = PROVIDERS.new_provider(return_tx); + ProviderCore { + usercall_tx, + cancel_tx, + provider_id, + next_id: AtomicU32::new(1), + } + } + + #[cfg(test)] + pub fn provider_id(&self) -> u32 { + self.provider_id + } + + fn next_id(&self) -> u32 { + let id = self.next_id.fetch_add(1, Ordering::Relaxed); + match id { + 0 => self.next_id(), + _ => id, + } + } + + pub fn assign_id(&self, usercall: Usercall) -> Identified { + let id = self.next_id(); + Identified { + id: ((self.provider_id as u64) << 32) | id as u64, + data: usercall, + } + } + + pub fn send_usercall(&self, usercall: Identified) -> CancelHandle { + assert!(usercall.id != 0); + let cancel = Identified { + id: usercall.id, + data: Cancel { + usercall_nr: usercall.data.0, + }, + }; + self.usercall_tx.send(usercall).expect("failed to send async usercall"); + CancelHandle::new(cancel, &self.cancel_tx) + } + + // returns the number of usercalls successfully sent. + pub fn try_send_multiple_usercalls(&self, usercalls: &[Identified]) -> usize { + self.usercall_tx.try_send_multiple(usercalls).unwrap_or(0) + } +} + +impl Drop for ProviderCore { + fn drop(&mut self) { + PROVIDERS.remove_provider(self.provider_id); + } +} diff --git a/async-usercalls/src/queues.rs b/async-usercalls/src/queues.rs new file mode 100644 index 000000000..18c9eade5 --- /dev/null +++ b/async-usercalls/src/queues.rs @@ -0,0 +1,188 @@ +use crate::hacks::{alloc_descriptor, async_queues, to_enclave, Cancel, Return, Usercall}; +use crossbeam_channel as mpmc; +use fortanix_sgx_abi::{EV_CANCELQ_NOT_FULL, EV_RETURNQ_NOT_EMPTY, EV_USERCALLQ_NOT_FULL}; +use ipc_queue::{self, Identified, QueueEvent, RecvError, SynchronizationError, Synchronizer}; +use lazy_static::lazy_static; +use std::os::fortanix_sgx::usercalls::raw; +use std::sync::{Arc, Mutex}; +use std::{io, iter, thread}; + +pub(crate) type Sender = ipc_queue::Sender; +pub(crate) type Receiver = ipc_queue::Receiver; + +pub(crate) struct Providers { + usercall_queue_tx: Sender, + cancel_queue_tx: Sender, + provider_map: Arc>>>>>, +} + +impl Providers { + pub(crate) fn new_provider( + &self, + return_tx: Option>>, + ) -> (Sender, Sender, u32) { + let id = self.provider_map.lock().unwrap().insert(return_tx); + let usercall_queue_tx = self.usercall_queue_tx.clone(); + let cancel_queue_tx = self.cancel_queue_tx.clone(); + (usercall_queue_tx, cancel_queue_tx, id) + } + + pub(crate) fn remove_provider(&self, id: u32) { + let entry = self.provider_map.lock().unwrap().remove(id); + assert!(entry.is_some()); + } +} + +lazy_static! 
{ + pub(crate) static ref PROVIDERS: Providers = { + let (utx, ctx, rx) = init_async_queues().expect("Failed to initialize async queues"); + let provider_map = Arc::new(Mutex::new(Map::new())); + let return_handler = ReturnHandler { + return_queue_rx: rx, + provider_map: Arc::clone(&provider_map), + }; + thread::spawn(move || return_handler.run()); + Providers { + usercall_queue_tx: utx, + cancel_queue_tx: ctx, + provider_map, + } + }; +} + +fn init_async_queues() -> io::Result<(Sender, Sender, Receiver)> { + // FIXME: this is just a hack. Replace these with `User::>::uninitialized().into_raw()` + let usercall_q = unsafe { alloc_descriptor::() }; + let cancel_q = unsafe { alloc_descriptor::() }; + let return_q = unsafe { alloc_descriptor::() }; + + let r = unsafe { async_queues(usercall_q, return_q, cancel_q) }; + if r != 0 { + return Err(io::Error::from_raw_os_error(r)); + } + + // FIXME: this is another hack, replace with `unsafe { User::>::from_raw(q) }.to_enclave()` + let usercall_queue = unsafe { to_enclave(usercall_q) }; + let cancel_queue = unsafe { to_enclave(cancel_q) }; + let return_queue = unsafe { to_enclave(return_q) }; + + let utx = unsafe { Sender::from_descriptor(usercall_queue, QueueSynchronizer { queue: Queue::Usercall }) }; + let ctx = unsafe { Sender::from_descriptor(cancel_queue, QueueSynchronizer { queue: Queue::Cancel }) }; + let rx = unsafe { Receiver::from_descriptor(return_queue, QueueSynchronizer { queue: Queue::Return }) }; + Ok((utx, ctx, rx)) +} + +struct ReturnHandler { + return_queue_rx: Receiver, + provider_map: Arc>>>>>, +} + +impl ReturnHandler { + const N: usize = 1024; + + fn send(&self, returns: &[Identified]) { + // This should hold the lock only for a short amount of time + // since mpmc::Sender::send() will not block (unbounded channel). + // Also note that the lock is uncontested most of the time, so + // taking the lock should be fast. 
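+        // The upper 32 bits of a return's id identify the provider that sent
+        // the original usercall (see `ProviderCore::assign_id`), so each
+        // return is routed to that provider's channel if one is registered.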
+ let provider_map = self.provider_map.lock().unwrap(); + for ret in returns { + let provider_id = (ret.id >> 32) as u32; + if let Some(sender) = provider_map.get(provider_id).and_then(|entry| entry.as_ref()) { + let _ = sender.send(*ret); + } + } + } + + fn run(self) { + const DEFAULT_RETURN: Identified = Identified { + id: 0, + data: Return(0, 0), + }; + loop { + let mut returns = [DEFAULT_RETURN; Self::N]; + let first = match self.return_queue_rx.recv() { + Ok(ret) => ret, + Err(RecvError::Closed) => break, + }; + let mut count = 0; + for ret in iter::once(first).chain(self.return_queue_rx.try_iter().take(Self::N - 1)) { + assert!(ret.id != 0); + returns[count] = ret; + count += 1; + } + self.send(&returns[..count]); + } + } +} + +#[derive(Clone, Copy, Debug)] +enum Queue { + Usercall, + Return, + Cancel, +} + +#[derive(Clone, Debug)] +pub(crate) struct QueueSynchronizer { + queue: Queue, +} + +impl Synchronizer for QueueSynchronizer { + fn wait(&self, event: QueueEvent) -> Result<(), SynchronizationError> { + let ev = match (self.queue, event) { + (Queue::Usercall, QueueEvent::NotEmpty) => panic!("enclave should not recv on usercall queue"), + (Queue::Cancel, QueueEvent::NotEmpty) => panic!("enclave should not recv on cancel queue"), + (Queue::Return, QueueEvent::NotFull) => panic!("enclave should not send on return queue"), + (Queue::Usercall, QueueEvent::NotFull) => EV_USERCALLQ_NOT_FULL, + (Queue::Cancel, QueueEvent::NotFull) => EV_CANCELQ_NOT_FULL, + (Queue::Return, QueueEvent::NotEmpty) => EV_RETURNQ_NOT_EMPTY, + }; + unsafe { + raw::wait(ev, raw::WAIT_INDEFINITE); + } + Ok(()) + } + + fn notify(&self, _event: QueueEvent) { + // any synchronous usercall would do + unsafe { + raw::wait(0, raw::WAIT_NO); + } + } +} + +use self::map::Map; +mod map { + use fnv::FnvHashMap; + + pub struct Map { + map: FnvHashMap, + next_id: u32, + } + + impl Map { + pub fn new() -> Self { + Self { + map: FnvHashMap::with_capacity_and_hasher(16, Default::default()), + next_id: 0, + } + } + + pub fn insert(&mut self, value: T) -> u32 { + let id = self.next_id; + self.next_id += 1; + let old = self.map.insert(id, value); + debug_assert!(old.is_none()); + id + } + + pub fn get(&self, id: u32) -> Option<&T> { + self.map.get(&id) + } + + pub fn remove(&mut self, id: u32) -> Option { + self.map.remove(&id) + } + } +} diff --git a/async-usercalls/src/raw.rs b/async-usercalls/src/raw.rs new file mode 100644 index 000000000..fb2d4fac0 --- /dev/null +++ b/async-usercalls/src/raw.rs @@ -0,0 +1,155 @@ +use crate::callback::*; +use crate::hacks::Usercall; +use crate::{AsyncUsercallProvider, CancelHandle}; +use fortanix_sgx_abi::Fd; +use std::io; +use std::os::fortanix_sgx::usercalls::raw::ByteBuffer; +use std::os::fortanix_sgx::usercalls::raw::UsercallNrs; + +pub trait RawApi { + unsafe fn raw_read( + &self, + fd: Fd, + buf: *mut u8, + len: usize, + callback: Option>>, + ) -> CancelHandle; + + unsafe fn raw_write( + &self, + fd: Fd, + buf: *const u8, + len: usize, + callback: Option>>, + ) -> CancelHandle; + + unsafe fn raw_flush(&self, fd: Fd, callback: Option>>); + + unsafe fn raw_close(&self, fd: Fd, callback: Option>); + + unsafe fn raw_bind_stream( + &self, + addr: *const u8, + len: usize, + local_addr: *mut ByteBuffer, + callback: Option>>, + ); + + unsafe fn raw_accept_stream( + &self, + fd: Fd, + local_addr: *mut ByteBuffer, + peer_addr: *mut ByteBuffer, + callback: Option>>, + ) -> CancelHandle; + + unsafe fn raw_connect_stream( + &self, + addr: *const u8, + len: usize, + local_addr: *mut ByteBuffer, + 
peer_addr: *mut ByteBuffer, + callback: Option>>, + ) -> CancelHandle; + + unsafe fn raw_insecure_time(&self, callback: Option>); + + unsafe fn raw_alloc(&self, size: usize, alignment: usize, callback: Option>>); + + unsafe fn raw_free(&self, ptr: *mut u8, size: usize, alignment: usize, callback: Option>); +} + +impl RawApi for AsyncUsercallProvider { + unsafe fn raw_read( + &self, + fd: Fd, + buf: *mut u8, + len: usize, + callback: Option>>, + ) -> CancelHandle { + let u = Usercall(UsercallNrs::read as _, fd as _, buf as _, len as _, 0); + self.send_usercall(u, callback.map(|cb| Callback::Read(cb))) + } + + unsafe fn raw_write( + &self, + fd: Fd, + buf: *const u8, + len: usize, + callback: Option>>, + ) -> CancelHandle { + let u = Usercall(UsercallNrs::write as _, fd as _, buf as _, len as _, 0); + self.send_usercall(u, callback.map(|cb| Callback::Write(cb))) + } + + unsafe fn raw_flush(&self, fd: Fd, callback: Option>>) { + let u = Usercall(UsercallNrs::flush as _, fd as _, 0, 0, 0); + self.send_usercall(u, callback.map(|cb| Callback::Flush(cb))); + } + + unsafe fn raw_close(&self, fd: Fd, callback: Option>) { + let u = Usercall(UsercallNrs::close as _, fd as _, 0, 0, 0); + self.send_usercall(u, callback.map(|cb| Callback::Close(cb))); + } + + unsafe fn raw_bind_stream( + &self, + addr: *const u8, + len: usize, + local_addr: *mut ByteBuffer, + callback: Option>>, + ) { + let u = Usercall(UsercallNrs::bind_stream as _, addr as _, len as _, local_addr as _, 0); + self.send_usercall(u, callback.map(|cb| Callback::BindStream(cb))); + } + + unsafe fn raw_accept_stream( + &self, + fd: Fd, + local_addr: *mut ByteBuffer, + peer_addr: *mut ByteBuffer, + callback: Option>>, + ) -> CancelHandle { + let u = Usercall( + UsercallNrs::accept_stream as _, + fd as _, + local_addr as _, + peer_addr as _, + 0, + ); + self.send_usercall(u, callback.map(|cb| Callback::AcceptStream(cb))) + } + + unsafe fn raw_connect_stream( + &self, + addr: *const u8, + len: usize, + local_addr: *mut ByteBuffer, + peer_addr: *mut ByteBuffer, + callback: Option>>, + ) -> CancelHandle { + let u = Usercall( + UsercallNrs::connect_stream as _, + addr as _, + len as _, + local_addr as _, + peer_addr as _, + ); + self.send_usercall(u, callback.map(|cb| Callback::ConnectStream(cb))) + } + + unsafe fn raw_insecure_time(&self, callback: Option>) { + let u = Usercall(UsercallNrs::insecure_time as _, 0, 0, 0, 0); + self.send_usercall(u, callback.map(|cb| Callback::InsecureTime(cb))); + } + + unsafe fn raw_alloc(&self, size: usize, alignment: usize, callback: Option>>) { + let u = Usercall(UsercallNrs::alloc as _, size as _, alignment as _, 0, 0); + self.send_usercall(u, callback.map(|cb| Callback::Alloc(cb))); + } + + unsafe fn raw_free(&self, ptr: *mut u8, size: usize, alignment: usize, callback: Option>) { + let u = Usercall(UsercallNrs::free as _, ptr as _, size as _, alignment as _, 0); + self.send_usercall(u, callback.map(|cb| Callback::Free(cb))); + } +} diff --git a/async-usercalls/src/tests.rs b/async-usercalls/src/tests.rs new file mode 100644 index 000000000..78cb2094d --- /dev/null +++ b/async-usercalls/src/tests.rs @@ -0,0 +1,251 @@ +use super::*; +use crate::hacks::MakeSend; +use crossbeam_channel as mpmc; +use std::io; +use std::net::{TcpListener, TcpStream}; +use std::os::fortanix_sgx::io::AsRawFd; +use std::os::fortanix_sgx::usercalls::alloc::User as StdUser; +use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::thread; +use std::time::{Duration, UNIX_EPOCH}; + +#[test] +fn 
get_time_async_raw() { + fn run(tid: u32, provider: AsyncUsercallProvider) -> (u32, u32, Duration) { + let pid = provider.provider_id(); + const N: usize = 500; + let (tx, rx) = mpmc::bounded(N); + for _ in 0..N { + let tx = tx.clone(); + let cb = move |d| { + let system_time = UNIX_EPOCH + Duration::from_nanos(d); + tx.send(system_time).unwrap(); + }; + unsafe { + provider.raw_insecure_time(Some(cb.into())); + } + } + let mut all = Vec::with_capacity(N); + for _ in 0..N { + all.push(rx.recv().unwrap()); + } + + assert_eq!(all.len(), N); + // The results are returned in arbitrary order + all.sort(); + let t0 = *all.first().unwrap(); + let tn = *all.last().unwrap(); + let total = tn.duration_since(t0).unwrap(); + (tid, pid, total / N as u32) + } + + println!(); + const THREADS: usize = 4; + let mut providers = Vec::with_capacity(THREADS); + for _ in 0..THREADS { + providers.push(AsyncUsercallProvider::new()); + } + let mut handles = Vec::with_capacity(THREADS); + for (i, provider) in providers.into_iter().enumerate() { + handles.push(thread::spawn(move || run(i as u32, provider))); + } + for h in handles { + let res = h.join().unwrap(); + println!("[{}/{}] (Tn - T0) / N = {:?}", res.0, res.1, res.2); + } +} + +#[test] +fn raw_alloc_free() { + let provider = AsyncUsercallProvider::new(); + let ptr: Arc> = Arc::new(AtomicPtr::new(0 as _)); + let ptr2 = Arc::clone(&ptr); + const SIZE: usize = 1024; + const ALIGN: usize = 8; + + let (tx, rx) = mpmc::bounded(1); + let cb_alloc = move |p: io::Result<*mut u8>| { + let p = p.unwrap(); + ptr2.store(p, Ordering::Relaxed); + tx.send(()).unwrap(); + }; + unsafe { + provider.raw_alloc(SIZE, ALIGN, Some(cb_alloc.into())); + } + rx.recv().unwrap(); + let p = ptr.load(Ordering::Relaxed); + assert!(!p.is_null()); + + let (tx, rx) = mpmc::bounded(1); + let cb_free = move |()| { + tx.send(()).unwrap(); + }; + unsafe { + provider.raw_free(p, SIZE, ALIGN, Some(cb_free.into())); + } + rx.recv().unwrap(); +} + +#[test] +fn cancel_accept() { + let provider = Arc::new(AsyncUsercallProvider::new()); + let port = 6688; + let addr = format!("0.0.0.0:{}", port); + let (tx, rx) = mpmc::bounded(1); + provider.bind_stream(&addr, move |res| { + tx.send(res).unwrap(); + }); + let bind_res = rx.recv().unwrap(); + let listener = bind_res.unwrap(); + let fd = listener.as_raw_fd(); + let accept_count = Arc::new(AtomicUsize::new(0)); + let accept_count1 = Arc::clone(&accept_count); + let (tx, rx) = mpmc::bounded(1); + let accept = provider.accept_stream(fd, move |res| { + if let Ok(_) = res { + accept_count1.fetch_add(1, Ordering::Relaxed); + } + tx.send(()).unwrap(); + }); + accept.cancel(); + thread::sleep(Duration::from_millis(10)); + let _ = TcpStream::connect(&addr); + let _ = rx.recv(); + assert_eq!(accept_count.load(Ordering::Relaxed), 0); +} + +#[test] +fn connect() { + let listener = TcpListener::bind("0.0.0.0:0").unwrap(); + let addr = listener.local_addr().unwrap().to_string(); + let provider = AsyncUsercallProvider::new(); + let (tx, rx) = mpmc::bounded(1); + provider.connect_stream(&addr, move |res| { + tx.send(res).unwrap(); + }); + let res = rx.recv().unwrap(); + assert!(res.is_ok()); +} + +#[test] +fn safe_alloc_free() { + let provider = AsyncUsercallProvider::new(); + + const LEN: usize = 64 * 1024; + let (tx, rx) = mpmc::bounded(1); + provider.alloc_slice::(LEN, move |res| { + let buf = res.expect("failed to allocate memory"); + tx.send(MakeSend::new(buf)).unwrap(); + }); + let user_buf = rx.recv().unwrap().into_inner(); + assert_eq!(user_buf.len(), LEN); + + 
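+    // Hand the buffer back to userspace with `free`; the callback fires once
+    // the free usercall returns.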
let (tx, rx) = mpmc::bounded(1); + let cb = move || { + tx.send(()).unwrap(); + }; + provider.free(user_buf, Some(cb)); + rx.recv().unwrap(); +} + +unsafe impl Send for MakeSend> {} + +#[test] +#[ignore] +fn echo() { + println!(); + let provider = Arc::new(AsyncUsercallProvider::new()); + const ADDR: &'static str = "0.0.0.0:7799"; + let (tx, rx) = mpmc::bounded(1); + provider.bind_stream(ADDR, move |res| { + tx.send(res).unwrap(); + }); + let bind_res = rx.recv().unwrap(); + let listener = bind_res.unwrap(); + println!("bind done: {:?}", listener); + let fd = listener.as_raw_fd(); + let cb = KeepAccepting { + listener, + provider: Arc::clone(&provider), + }; + provider.accept_stream(fd, cb); + thread::sleep(Duration::from_secs(60)); +} + +struct KeepAccepting { + listener: TcpListener, + provider: Arc, +} + +impl FnOnce<(io::Result,)> for KeepAccepting { + type Output = (); + + extern "rust-call" fn call_once(self, args: (io::Result,)) -> Self::Output { + let res = args.0; + println!("accept result: {:?}", res); + if let Ok(stream) = res { + let fd = stream.as_raw_fd(); + let cb = Echo { + stream, + read: true, + provider: self.provider.clone(), + }; + self.provider.read(fd, alloc_buf(Echo::READ_BUF_SIZE), cb); + } + let provider = Arc::clone(&self.provider); + provider.accept_stream(self.listener.as_raw_fd(), self); + } +} + +struct Echo { + stream: TcpStream, + read: bool, + provider: Arc, +} + +impl Echo { + const READ_BUF_SIZE: usize = 1024; + + fn close(self) { + let fd = self.stream.as_raw_fd(); + println!("connection closed, fd = {}", fd); + self.provider.close(fd, None::>); + } +} + +// read callback +impl FnOnce<(io::Result, User<[u8]>)> for Echo { + type Output = (); + + extern "rust-call" fn call_once(mut self, args: (io::Result, User<[u8]>)) -> Self::Output { + let (res, user) = args; + assert!(self.read); + match res { + Ok(len) if len > 0 => { + self.read = false; + let provider = Arc::clone(&self.provider); + provider.write(self.stream.as_raw_fd(), (user, 0..len).into(), self); + } + _ => self.close(), + } + } +} + +// write callback +impl FnOnce<(io::Result, UserBuf)> for Echo { + type Output = (); + + extern "rust-call" fn call_once(mut self, args: (io::Result, UserBuf)) -> Self::Output { + let (res, _) = args; + assert!(!self.read); + match res { + Ok(len) if len > 0 => { + self.read = true; + let provider = Arc::clone(&self.provider); + provider.read(self.stream.as_raw_fd(), alloc_buf(Echo::READ_BUF_SIZE), self); + } + _ => self.close(), + } + } +} diff --git a/async-usercalls/test.sh b/async-usercalls/test.sh new file mode 100755 index 000000000..cdb85673d --- /dev/null +++ b/async-usercalls/test.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# Run this in parallel with: +# $ cargo test --target x86_64-fortanix-unknown-sgx --release -- --nocapture --ignored echo + +for i in $(seq 1 100); do + echo $i + telnet localhost 7799 < /dev/zero &> /dev/null & + sleep 0.01 +done + +sleep 10s +kill $(jobs -p) +wait