Skip to content

Commit

Permalink
Fix a soundness issue with the component model and async (#6509)
Browse files Browse the repository at this point in the history
* Force `execute_across_threads` to use multiple threads

Currently this uses tokio's `spawn_blocking` but that will reuse threads
in its thread pool. Instead spawn a thread and perform a single poll on
that thread to force lots of fresh threads to be used and ideally stress
TLS management further.

* Add a guard against using stale stacks

This commit adds a guard to Wasmtime's async support to double-check
that when a call to `poll` is finished that the currently active TLS
activation pointer does not point to the stack that is being switched
off of. This is attempting to be a bit of a defense-in-depth measure to
prevent stale pointers from sticking around in TLS. This is currently
happening and causing #6493 which can result in unsoundness but
currently is manifesting as a crash.

* Fix a soundness issue with the component model and async

This commit addresses #6493 by fixing a soundness issue with the async
implementation of the component model. This issue has been presence
since the inception of the addition of async support to the component
model and doesn't represent a recent regression. The underlying problem
is that one of the base assumptions of the trap handling code is that
there's only one single activation in TLS that needs to be pushed/popped
when a stack is switched (e.g. a fiber is switched to or from). In the
case of the component model there might be two activations: one for an
invocation of a component function and then a second for an invocation
of a `realloc` function to return results back to wasm (e.g. in the case
an imported function returns a list).

This problem is fixed by changing how TLS is managed in the presence of
fibers. Previously when a fiber was suspended it would pop a single
activation from the top of the stack and save that to get pushed when
the fiber was resumed. This has the benefit of maintaining an entire
linked list of activations for the current thread but has the problem
above where it doesn't handle a fiber with multiple activations on it.
Instead now TLS management is done when a fiber is resumed instead of
suspended. Instead of pushing/popping a single activation the entire
linked list of activations is tracked for a particular fiber and stored
within the fiber itself. In this manner resuming a fiber will push
all activations onto the current thread and suspending a fiber will pop
all activations for the fiber (and store them as a new linked list in
the fiber's state itself).

This end result is that all activations on a fiber should now be managed
correctly, regardless of how many there are. The main downside of this
commit is that fiber suspension and resumption is more complicated, but
the hope there is that fiber suspension typically corresponds with I/O
not being ready or similar so the order of magnitude of TLS operations
isn't too significant compared to the I/O overhead.

Closes #6493

* Review comments

* Fix restoration during panic
  • Loading branch information
alexcrichton committed Jun 2, 2023
1 parent 1d4686d commit 550a16f
Show file tree
Hide file tree
Showing 10 changed files with 474 additions and 74 deletions.
22 changes: 17 additions & 5 deletions crates/fiber/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::any::Any;
use std::cell::Cell;
use std::io;
use std::marker::PhantomData;
use std::ops::Range;
use std::panic::{self, AssertUnwindSafe};

cfg_if::cfg_if! {
Expand All @@ -26,24 +27,35 @@ impl FiberStack {
Ok(Self(imp::FiberStack::new(size)?))
}

/// Creates a new fiber stack with the given pointer to the top of the stack.
/// Creates a new fiber stack with the given pointer to the bottom of the
/// stack plus the byte length of the stack.
///
/// The `bottom` pointer should be addressable for `len` bytes. The page
/// beneath `bottom` should be unmapped as a guard page.
///
/// # Safety
///
/// This is unsafe because there is no validation of the given stack pointer.
/// This is unsafe because there is no validation of the given pointer.
///
/// The caller must properly allocate the stack space with a guard page and
/// make the pages accessible for correct behavior.
pub unsafe fn from_top_ptr(top: *mut u8) -> io::Result<Self> {
Ok(Self(imp::FiberStack::from_top_ptr(top)?))
pub unsafe fn from_raw_parts(bottom: *mut u8, len: usize) -> io::Result<Self> {
Ok(Self(imp::FiberStack::from_raw_parts(bottom, len)?))
}

/// Gets the top of the stack.
///
/// Returns `None` if the platform does not support getting the top of the stack.
/// Returns `None` if the platform does not support getting the top of the
/// stack.
pub fn top(&self) -> Option<*mut u8> {
self.0.top()
}

/// Returns the range of where this stack resides in memory if the platform
/// supports it.
pub fn range(&self) -> Option<Range<usize>> {
self.0.range()
}
}

pub struct Fiber<'a, Resume, Yield, Return> {
Expand Down
27 changes: 20 additions & 7 deletions crates/fiber/src/unix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,18 @@
use crate::RunResult;
use std::cell::Cell;
use std::io;
use std::ops::Range;
use std::ptr;

#[derive(Debug)]
pub struct FiberStack {
// The top of the stack; for stacks allocated by the fiber implementation itself,
// the base address of the allocation will be `top.sub(len.unwrap())`
top: *mut u8,
// The length of the stack; `None` when the stack was not created by this implementation.
len: Option<usize>,
// The length of the stack
len: usize,
// whether or not this stack was mmap'd
mmap: bool,
}

impl FiberStack {
Expand Down Expand Up @@ -74,25 +77,35 @@ impl FiberStack {

Ok(Self {
top: mmap.cast::<u8>().add(mmap_len),
len: Some(mmap_len),
len: mmap_len,
mmap: true,
})
}
}

pub unsafe fn from_top_ptr(top: *mut u8) -> io::Result<Self> {
Ok(Self { top, len: None })
pub unsafe fn from_raw_parts(base: *mut u8, len: usize) -> io::Result<Self> {
Ok(Self {
top: base.add(len),
len,
mmap: false,
})
}

pub fn top(&self) -> Option<*mut u8> {
Some(self.top)
}

pub fn range(&self) -> Option<Range<usize>> {
let base = unsafe { self.top.sub(self.len) as usize };
Some(base..base + self.len)
}
}

impl Drop for FiberStack {
fn drop(&mut self) {
unsafe {
if let Some(len) = self.len {
let ret = rustix::mm::munmap(self.top.sub(len) as _, len);
if self.mmap {
let ret = rustix::mm::munmap(self.top.sub(self.len) as _, self.len);
debug_assert!(ret.is_ok());
}
}
Expand Down
7 changes: 6 additions & 1 deletion crates/fiber/src/windows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::RunResult;
use std::cell::Cell;
use std::ffi::c_void;
use std::io;
use std::ops::Range;
use std::ptr;
use windows_sys::Win32::Foundation::*;
use windows_sys::Win32::System::Threading::*;
Expand All @@ -14,13 +15,17 @@ impl FiberStack {
Ok(Self(size))
}

pub unsafe fn from_top_ptr(_top: *mut u8) -> io::Result<Self> {
pub unsafe fn from_raw_parts(_base: *mut u8, _len: usize) -> io::Result<Self> {
Err(io::Error::from_raw_os_error(ERROR_NOT_SUPPORTED as i32))
}

pub fn top(&self) -> Option<*mut u8> {
None
}

pub fn range(&self) -> Option<Range<usize>> {
None
}
}

pub struct Fiber {
Expand Down
2 changes: 1 addition & 1 deletion crates/runtime/src/instance/allocator/pooling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ impl StackPool {
commit_stack_pages(bottom_of_stack, size_without_guard)?;

let stack =
wasmtime_fiber::FiberStack::from_top_ptr(bottom_of_stack.add(size_without_guard))?;
wasmtime_fiber::FiberStack::from_raw_parts(bottom_of_stack, size_without_guard)?;
Ok(stack)
}
}
Expand Down
5 changes: 1 addition & 4 deletions crates/runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,7 @@ pub use crate::mmap::Mmap;
pub use crate::mmap_vec::MmapVec;
pub use crate::store_box::*;
pub use crate::table::{Table, TableElement};
pub use crate::traphandlers::{
catch_traps, init_traps, raise_lib_trap, raise_user_trap, resume_panic, tls_eager_initialize,
Backtrace, Frame, SignalHandler, TlsRestore, Trap, TrapReason,
};
pub use crate::traphandlers::*;
pub use crate::vmcontext::{
VMArrayCallFunction, VMArrayCallHostFuncContext, VMContext, VMFuncRef, VMFunctionBody,
VMFunctionImport, VMGlobalDefinition, VMGlobalImport, VMInvokeArgument, VMMemoryDefinition,
Expand Down
149 changes: 117 additions & 32 deletions crates/runtime/src/traphandlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use std::ptr;
use std::sync::Once;

pub use self::backtrace::{Backtrace, Frame};
pub use self::tls::{tls_eager_initialize, TlsRestore};
pub use self::tls::{tls_eager_initialize, AsyncWasmCallState, PreviousAsyncWasmCallState};

cfg_if::cfg_if! {
if #[cfg(miri)] {
Expand Down Expand Up @@ -297,7 +297,7 @@ mod call_thread_state {

pub(crate) limits: *const VMRuntimeLimits,

prev: Cell<tls::Ptr>,
pub(super) prev: Cell<tls::Ptr>,

// The values of `VMRuntimeLimits::last_wasm_{exit_{pc,fp},entry_sp}`
// for the *previous* `CallThreadState` for this same store/limits. Our
Expand Down Expand Up @@ -517,6 +517,8 @@ impl<T: Copy> Drop for ResetCell<'_, T> {
// the caller to the trap site.
mod tls {
use super::CallThreadState;
use std::mem;
use std::ops::Range;

pub use raw::Ptr;

Expand Down Expand Up @@ -588,49 +590,132 @@ mod tls {

pub use raw::initialize as tls_eager_initialize;

/// Opaque state used to help control TLS state across stack switches for
/// async support.
pub struct TlsRestore {
/// Opaque state used to persist the state of the `CallThreadState`
/// activations associated with a fiber stack that's used as part of an
/// async wasm call.
pub struct AsyncWasmCallState {
// The head of a linked list of activations that are currently present
// on an async call's fiber stack. This pointer points to the oldest
// activation frame where the `prev` links internally link to younger
// activation frames.
//
// When pushed onto a thread this linked list is traversed to get pushed
// onto the current thread at the time.
state: raw::Ptr,
}

impl TlsRestore {
/// Takes the TLS state that is currently configured and returns a
/// token that is used to replace it later.
impl AsyncWasmCallState {
/// Creates new state that initially starts as null.
pub fn new() -> AsyncWasmCallState {
AsyncWasmCallState {
state: std::ptr::null_mut(),
}
}

/// Pushes the saved state of this wasm's call onto the current thread's
/// state.
///
/// This will iterate over the linked list of states stored within
/// `self` and push them sequentially onto the current thread's
/// activation list.
///
/// The returned `PreviousAsyncWasmCallState` captures the state of this
/// thread just before this operation, and it must have its `restore`
/// method called to restore the state when the async wasm is suspended
/// from.
///
/// This is not a safe operation since it's intended to only be used
/// with stack switching found with fibers and async wasmtime.
pub unsafe fn take() -> TlsRestore {
// Our tls pointer must be set at this time, and it must not be
// null. We need to restore the previous pointer since we're
// removing ourselves from the call-stack, and in the process we
// null out our own previous field for safety in case it's
// accidentally used later.
let state = raw::get();
if let Some(state) = state.as_ref() {
state.pop();
} else {
// Null case: we aren't in a wasm context, so theres no tls to
// save for restoration.
/// # Unsafety
///
/// Must be carefully coordinated with
/// `PreviousAsyncWasmCallState::restore` and fiber switches to ensure
/// that this doesn't push stale data and the data is popped
/// appropriately.
pub unsafe fn push(self) -> PreviousAsyncWasmCallState {
// Our `state` pointer is a linked list of oldest-to-youngest so by
// pushing in order of the list we restore the youngest-to-oldest
// list as stored in the state of this current thread.
let ret = PreviousAsyncWasmCallState { state: raw::get() };
let mut ptr = self.state;
while let Some(state) = ptr.as_ref() {
ptr = state.prev.replace(std::ptr::null_mut());
state.push();
}
ret
}

TlsRestore { state }
/// Performs a runtime check that this state is indeed null.
pub fn assert_null(&self) {
assert!(self.state.is_null());
}

/// Restores a previous tls state back into this thread's TLS.
/// Asserts that the current CallThreadState pointer, if present, is not
/// in the `range` specified.
///
/// This is unsafe because it's intended to only be used within the
/// context of stack switching within wasmtime.
pub unsafe fn replace(self) {
if let Some(state) = self.state.as_ref() {
state.push();
} else {
// Null case: we aren't in a wasm context, so theres no tls
// to restore.
/// This is used when exiting a future in Wasmtime to assert that the
/// current CallThreadState pointer does not point within the stack
/// we're leaving (e.g. allocated for a fiber).
pub fn assert_current_state_not_in_range(range: Range<usize>) {
let p = raw::get() as usize;
assert!(p < range.start || range.end < p);
}
}

/// Opaque state used to help control TLS state across stack switches for
/// async support.
pub struct PreviousAsyncWasmCallState {
// The head of a linked list, similar to the TLS state. Note though that
// this list is stored in reverse order to assist with `push` and `pop`
// below.
//
// After a `push` call this stores the previous head for the current
// thread so we know when to stop popping during a `pop`.
state: raw::Ptr,
}

impl PreviousAsyncWasmCallState {
/// Pops a fiber's linked list of activations and stores them in
/// `AsyncWasmCallState`.
///
/// This will pop the top activation of this current thread continuously
/// until it reaches whatever the current activation was when `push` was
/// originally called.
///
/// # Unsafety
///
/// Must be paired with a `push` and only performed at a time when a
/// fiber is being suspended.
pub unsafe fn restore(self) -> AsyncWasmCallState {
let thread_head = self.state;
mem::forget(self);
let mut ret = AsyncWasmCallState::new();
loop {
// If the current TLS state is as we originally found it, then
// this loop is finished.
let ptr = raw::get();
if ptr == thread_head {
break ret;
}

// Pop this activation from the current thread's TLS state, and
// then afterwards push it onto our own linked list within this
// `AsyncWasmCallState`. Note that the linked list in `AsyncWasmCallState` is stored
// in reverse order so a subsequent `push` later on pushes
// everything in the right order.
(*ptr).pop();
if let Some(state) = ret.state.as_ref() {
(*ptr).prev.set(state);
}
ret.state = ptr;
}
}
}

impl Drop for PreviousAsyncWasmCallState {
fn drop(&mut self) {
panic!("must be consumed with `restore`");
}
}

/// Configures thread local state such that for the duration of the
/// execution of `closure` any call to `with` will yield `state`, unless
/// this is recursively called again.
Expand Down
Loading

0 comments on commit 550a16f

Please sign in to comment.