diff --git a/bastion/src/children.rs b/bastion/src/children.rs index 111a6d26..05d84e43 100644 --- a/bastion/src/children.rs +++ b/bastion/src/children.rs @@ -10,7 +10,7 @@ use runtime::task::JoinHandle; use std::any::Any; use std::fmt::Debug; use std::future::Future; -use std::panic::UnwindSafe; +use std::panic::AssertUnwindSafe; use std::pin::Pin; use std::task::Poll; @@ -29,12 +29,22 @@ where } } -pub trait Closure: Fn(BastionContext, Box) -> Pin> + Shell {} -impl Closure for T where T: Fn(BastionContext, Box) -> Pin> + Shell {} +pub trait Closure: Fn(BastionContext, Box) -> Fut + Shell {} +impl Closure for T where T: Fn(BastionContext, Box) -> Fut + Shell {} // TODO: Ok(T) & Err(E) -pub trait Fut: Future> + Send + UnwindSafe {} -impl Fut for T where T: Future> + Send + UnwindSafe {} +type FutInner = dyn Future> + Send; + +pub struct Fut(Pin>); + +impl From for Fut +where + T: Future> + Send + 'static, +{ + fn from(fut: T) -> Fut { + Fut(Box::pin(fut)) + } +} pub(super) struct Children { thunk: Box, @@ -44,7 +54,7 @@ pub(super) struct Children { } pub(super) struct Child { - exec: CatchUnwind>>, + exec: CatchUnwind>>>, bcast: Broadcast, state: Qutex, } @@ -126,7 +136,7 @@ impl Children { let parent = self.bcast.sender().clone(); let ctx = BastionContext::new(id, parent, state.clone()); - let exec = thunk(ctx, msg) + let exec = AssertUnwindSafe(thunk(ctx, msg).0) .catch_unwind(); let child = Child { exec, bcast, state };