Skip to content

Commit

Permalink
Prevent future::get_or_*_insert_with to panic after an inserting ta…
Browse files Browse the repository at this point in the history
…sk was aborted

Add `WaiterGuard` to `future::value_initializer` which will ensure that the
waiter will be removed when the enclosing future has been aborted.

Fixes #59.
  • Loading branch information
tatsuya6502 committed Dec 28, 2021
1 parent a14cb0d commit d2cbbfd
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 27 deletions.
155 changes: 130 additions & 25 deletions src/future/value_initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,84 @@ use std::{
};

type ErrorObject = Arc<dyn Any + Send + Sync + 'static>;
type WaiterValue<V> = Option<Result<V, ErrorObject>>;
type Waiter<V> = Arc<RwLock<WaiterValue<V>>>;

pub(crate) enum InitResult<V, E> {
Initialized(V),
ReadExisting(V),
InitErr(Arc<E>),
}

enum WaiterValue<V> {
Computing,
Ready(Result<V, ErrorObject>),
// https://github.com/moka-rs/moka/issues/43
InitFuturePanicked,
// https://github.com/moka-rs/moka/issues/59
EnclosingFutureAborted,
}

type Waiter<V> = Arc<RwLock<WaiterValue<V>>>;

struct WaiterGuard<'a, K, V, S>
// NOTE: We usually do not attach trait bounds to hera at the struct definition, but
// the Drop trait requires these bounds here.
where
Arc<K>: Eq + Hash,
V: Clone,
S: BuildHasher,
{
is_waiter_value_set: bool,
key: &'a Arc<K>,
type_id: TypeId,
value_initializer: &'a ValueInitializer<K, V, S>,
write_lock: &'a mut WaiterValue<V>,
}

impl<'a, K, V, S> WaiterGuard<'a, K, V, S>
where
Arc<K>: Eq + Hash,
V: Clone,
S: BuildHasher,
{
fn new(
key: &'a Arc<K>,
type_id: TypeId,
value_initializer: &'a ValueInitializer<K, V, S>,
write_lock: &'a mut WaiterValue<V>,
) -> Self {
Self {
is_waiter_value_set: false,
key,
type_id,
value_initializer,
write_lock,
}
}

fn set_waiter_value(&mut self, v: WaiterValue<V>) {
*self.write_lock = v;
self.is_waiter_value_set = true;
}
}

impl<'a, K, V, S> Drop for WaiterGuard<'a, K, V, S>
where
Arc<K>: Eq + Hash,
V: Clone,
S: BuildHasher,
{
fn drop(&mut self) {
if !self.is_waiter_value_set {
// Value is not set. This means the future containing
// `get_or_*_insert_with` has been aborted:
// https://github.com/moka-rs/moka/issues/59
*self.write_lock = WaiterValue::EnclosingFutureAborted;
self.value_initializer.remove_waiter(self.key, self.type_id);
self.is_waiter_value_set = true;
}
}
}

pub(crate) struct ValueInitializer<K, V, S> {
// TypeId is the type ID of the concrete error type of generic type E in
// try_init_or_read(). We use the type ID as a part of the key to ensure that
Expand Down Expand Up @@ -44,8 +113,8 @@ where
{
// This closure will be called after the init closure has returned a value.
// It will convert the returned value (from init) into an InitResult.
let post_init = |_key, value: V, lock: &mut WaiterValue<V>| {
*lock = Some(Ok(value.clone()));
let post_init = |_key, value: V, mut guard: WaiterGuard<'_, K, V, S>| {
guard.set_waiter_value(WaiterValue::Ready(Ok(value.clone())));
InitResult::Initialized(value)
};

Expand All @@ -64,14 +133,15 @@ where

// This closure will be called after the init closure has returned a value.
// It will convert the returned value (from init) into an InitResult.
let post_init = |key, value: Result<V, E>, lock: &mut WaiterValue<V>| match value {
let post_init = |key, value: Result<V, E>, mut guard: WaiterGuard<'_, K, V, S>| match value
{
Ok(value) => {
*lock = Some(Ok(value.clone()));
guard.set_waiter_value(WaiterValue::Ready(Ok(value.clone())));
InitResult::Initialized(value)
}
Err(e) => {
let err: ErrorObject = Arc::new(e);
*lock = Some(Err(Arc::clone(&err)));
guard.set_waiter_value(WaiterValue::Ready(Err(Arc::clone(&err))));
self.remove_waiter(key, type_id);
InitResult::InitErr(err.downcast().unwrap())
}
Expand All @@ -91,7 +161,7 @@ where
) -> InitResult<V, E>
where
F: Future<Output = O>,
C: FnMut(&'a Arc<K>, O, &mut WaiterValue<V>) -> InitResult<V, E>,
C: FnMut(&'a Arc<K>, O, WaiterGuard<'_, K, V, S>) -> InitResult<V, E>,
E: Send + Sync + 'static,
{
use futures_util::FutureExt;
Expand All @@ -102,19 +172,25 @@ where
let mut retries = 0;

loop {
let waiter = Arc::new(RwLock::new(None));
let waiter = Arc::new(RwLock::new(WaiterValue::Computing));
let mut lock = waiter.write().await;

match self.try_insert_waiter(key, type_id, &waiter) {
None => {
// Our waiter was inserted. Let's resolve the init future.

// Create a guard. This will ensure to remove our waiter when the
// enclosing future has been aborted:
// https://github.com/moka-rs/moka/issues/59
let mut waiter_guard = WaiterGuard::new(key, type_id, self, &mut lock);

// Catching panic is safe here as we do not try to resolve the future again.
match AssertUnwindSafe(init).catch_unwind().await {
// Resolved.
Ok(value) => return post_init(key, value, &mut lock),
Ok(value) => return post_init(key, value, waiter_guard),
// Panicked.
Err(payload) => {
*lock = None;
waiter_guard.set_waiter_value(WaiterValue::InitFuturePanicked);
// Remove the waiter so that others can retry.
self.remove_waiter(key, type_id);
resume_unwind(payload);
Expand All @@ -126,22 +202,30 @@ where
// for a read lock to become available.
std::mem::drop(lock);
match &*res.read().await {
Some(Ok(value)) => return ReadExisting(value.clone()),
Some(Err(e)) => return InitErr(Arc::clone(e).downcast().unwrap()),
// None means somebody else's init future has been panicked.
None => {
WaiterValue::Ready(Ok(value)) => return ReadExisting(value.clone()),
WaiterValue::Ready(Err(e)) => {
return InitErr(Arc::clone(e).downcast().unwrap())
}
// Somebody else's init future has been panicked.
WaiterValue::InitFuturePanicked => {
retries += 1;
if retries < MAX_RETRIES {
// Retry from the beginning.
continue;
} else {
panic!(
r#"Too many retries. Tried to read the return value from the `init` \
future but failed {} times. Maybe the `init` kept panicking?"#,
retries
);
}
panic_if_retry_exhausted_for_panicking(retries, MAX_RETRIES);
// Retry from the beginning.
continue;
}
// Somebody else (a future containing `get_or_insert_with`/
// `get_or_try_insert_with`) has been aborted.
WaiterValue::EnclosingFutureAborted => {
retries += 1;
panic_if_retry_exhausted_for_aborting(retries, MAX_RETRIES);
// Retry from the beginning.
continue;
}
// Unexpected state.
WaiterValue::Computing => panic!(
"Got unexpected state `Computing` after resolving `init` future. \
This might be a bug in Moka"
),
}
}
}
Expand All @@ -168,3 +252,24 @@ where
.insert_with_or_modify((key, type_id), || waiter, |_, w| Arc::clone(w))
}
}

fn panic_if_retry_exhausted_for_panicking(retries: usize, max: usize) {
if retries >= max {
panic!(
"Too many retries. Tried to read the return value from the `init` future \
but failed {} times. Maybe the `init` kept panicking?",
retries
);
}
}

fn panic_if_retry_exhausted_for_aborting(retries: usize, max: usize) {
if retries >= max {
panic!(
"Too many retries. Tried to read the return value from the `init` future \
but failed {} times. Maybe the future containing `get_or_insert_with`/\
`get_or_try_insert_with` kept being aborted?",
retries
);
}
}
4 changes: 2 additions & 2 deletions src/sync/value_initializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ where
continue;
} else {
panic!(
r#"Too many retries. Tried to read the return value from the `init` \
closure but failed {} times. Maybe the `init` kept panicking?"#,
"Too many retries. Tried to read the return value from the `init` \
closure but failed {} times. Maybe the `init` kept panicking?",
retries
);
}
Expand Down

0 comments on commit d2cbbfd

Please sign in to comment.