Skip to content

Commit

Permalink
feat: Add cleanup code in case of await trap (#232)
Browse files Browse the repository at this point in the history
* Add cleanup code in case of await trap

* fmt

Co-authored-by: Linwei Shang <linwei.shang@dfinity.org>
  • Loading branch information
adamspofford-dfinity and lwshang committed Mar 28, 2022
1 parent 4147564 commit 095be3b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
24 changes: 24 additions & 0 deletions src/ic-cdk/src/api/call.rs
Expand Up @@ -7,6 +7,7 @@ use serde::ser::Error;
use std::future::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::Ordering;
use std::task::{Context, Poll, Waker};

#[cfg(target_arch = "wasm32-unknown-unknown")]
Expand Down Expand Up @@ -216,6 +217,28 @@ fn callback(state_ptr: *const InnerCell<CallFutureState<Vec<u8>>>) {
}
}

/// This function is called when [callback] was just called with the same parameter, and trapped.
/// We can't guarantee internal consistency at this point, but we can at least e.g. drop mutex guards.
/// Waker is a very opaque API, so the best we can do is set a global flag and proceed normally.
fn cleanup(state_ptr: *const InnerCell<CallFutureState<Vec<u8>>>) {
let state = unsafe { WasmCell::from_raw(state_ptr) };
// We set the call result, even though it won't be read on the default executor, because we can't guarantee it was called on our executor.
// None of these calls trap - the rollback from the previous trap ensures that the Mutex is not in a poisoned state.
{
state.borrow_mut().result = Some(match reject_code() {
RejectionCode::NoError => unsafe { Ok(arg_data_raw()) },
n => Err((n, reject_message())),
});
}
let w = state.borrow_mut().waker.take();
if let Some(waker) = w {
// Flag that we do not want to actually wake the task - we want to drop it *without* executing it.
crate::futures::CLEANUP.store(true, Ordering::Relaxed);
waker.wake();
crate::futures::CLEANUP.store(false, Ordering::Relaxed);
}
}

/// Similar to `call`, but without serialization.
pub fn call_raw(
id: Principal,
Expand Down Expand Up @@ -276,6 +299,7 @@ fn call_raw_internal(

ic0::call_data_append(args_raw.as_ptr() as i32, args_raw.len() as i32);
payment_func();
ic0::call_on_cleanup(cleanup as usize as i32, state_ptr as i32);
ic0::call_perform()
};

Expand Down
20 changes: 14 additions & 6 deletions src/ic-cdk/src/futures.rs
@@ -1,5 +1,6 @@
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::task::Context;

/// Must be called on every top-level future corresponding to a method call of a
Expand Down Expand Up @@ -36,6 +37,8 @@ pub fn spawn<F: 'static + Future<Output = ()>>(future: F) {
}
}

pub(crate) static CLEANUP: AtomicBool = AtomicBool::new(false);

// This module contains the implementation of a waker we're using for waking
// top-level futures (the ones returned by canister methods). The waker polls
// the future once and re-pins it on the heap, if it's pending. If the future is
Expand All @@ -44,7 +47,10 @@ pub fn spawn<F: 'static + Future<Output = ()>>(future: F) {
// waker was used as intended.
mod waker {
use super::*;
use std::task::{RawWaker, RawWakerVTable, Waker};
use std::{
sync::atomic::Ordering,
task::{RawWaker, RawWakerVTable, Waker},
};
type FuturePtr = *mut dyn Future<Output = ()>;

static MY_VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);
Expand All @@ -61,16 +67,18 @@ mod waker {
// Then, the waker will restore the future from the pointer we passed into the
// waker inside the `kickstart` method and poll the future again. If the future
// is pending, we leave it on the heap. If it's ready, we deallocate the
// pointer.
// pointer. If CLEANUP is set, then we're recovering from a callback trap, and
// want to drop the future without executing any more of it.
unsafe fn wake(ptr: *const ()) {
let boxed_future_ptr_ptr = Box::from_raw(ptr as *mut FuturePtr);
let future_ptr: FuturePtr = *boxed_future_ptr_ptr;
let boxed_future = Box::from_raw(future_ptr);
let mut pinned_future = Pin::new_unchecked(&mut *future_ptr);
if pinned_future
.as_mut()
.poll(&mut Context::from_waker(&waker::waker(ptr)))
.is_pending()
if !super::CLEANUP.load(Ordering::Relaxed)
&& pinned_future
.as_mut()
.poll(&mut Context::from_waker(&waker::waker(ptr)))
.is_pending()
{
Box::into_raw(boxed_future_ptr_ptr);
Box::into_raw(boxed_future);
Expand Down

0 comments on commit 095be3b

Please sign in to comment.