From d2cbbfd9e5370f2e299a04f8b32bad9023cd0fc3 Mon Sep 17 00:00:00 2001 From: Tatsuya Kawano Date: Tue, 28 Dec 2021 20:27:03 +0800 Subject: [PATCH] Prevent `future::get_or_*_insert_with` to panic after an inserting task was aborted Add `WaiterGuard` to `future::value_initializer` which will ensure that the waiter will be removed when the enclosing future has been aborted. Fixes #59. --- src/future/value_initializer.rs | 155 ++++++++++++++++++++++++++------ src/sync/value_initializer.rs | 4 +- 2 files changed, 132 insertions(+), 27 deletions(-) diff --git a/src/future/value_initializer.rs b/src/future/value_initializer.rs index 99e04092..1ba69209 100644 --- a/src/future/value_initializer.rs +++ b/src/future/value_initializer.rs @@ -7,8 +7,6 @@ use std::{ }; type ErrorObject = Arc; -type WaiterValue = Option>; -type Waiter = Arc>>; pub(crate) enum InitResult { Initialized(V), @@ -16,6 +14,77 @@ pub(crate) enum InitResult { InitErr(Arc), } +enum WaiterValue { + Computing, + Ready(Result), + // https://github.com/moka-rs/moka/issues/43 + InitFuturePanicked, + // https://github.com/moka-rs/moka/issues/59 + EnclosingFutureAborted, +} + +type Waiter = Arc>>; + +struct WaiterGuard<'a, K, V, S> +// NOTE: We usually do not attach trait bounds to hera at the struct definition, but +// the Drop trait requires these bounds here. +where + Arc: Eq + Hash, + V: Clone, + S: BuildHasher, +{ + is_waiter_value_set: bool, + key: &'a Arc, + type_id: TypeId, + value_initializer: &'a ValueInitializer, + write_lock: &'a mut WaiterValue, +} + +impl<'a, K, V, S> WaiterGuard<'a, K, V, S> +where + Arc: Eq + Hash, + V: Clone, + S: BuildHasher, +{ + fn new( + key: &'a Arc, + type_id: TypeId, + value_initializer: &'a ValueInitializer, + write_lock: &'a mut WaiterValue, + ) -> Self { + Self { + is_waiter_value_set: false, + key, + type_id, + value_initializer, + write_lock, + } + } + + fn set_waiter_value(&mut self, v: WaiterValue) { + *self.write_lock = v; + self.is_waiter_value_set = true; + } +} + +impl<'a, K, V, S> Drop for WaiterGuard<'a, K, V, S> +where + Arc: Eq + Hash, + V: Clone, + S: BuildHasher, +{ + fn drop(&mut self) { + if !self.is_waiter_value_set { + // Value is not set. This means the future containing + // `get_or_*_insert_with` has been aborted: + // https://github.com/moka-rs/moka/issues/59 + *self.write_lock = WaiterValue::EnclosingFutureAborted; + self.value_initializer.remove_waiter(self.key, self.type_id); + self.is_waiter_value_set = true; + } + } +} + pub(crate) struct ValueInitializer { // TypeId is the type ID of the concrete error type of generic type E in // try_init_or_read(). We use the type ID as a part of the key to ensure that @@ -44,8 +113,8 @@ where { // This closure will be called after the init closure has returned a value. // It will convert the returned value (from init) into an InitResult. - let post_init = |_key, value: V, lock: &mut WaiterValue| { - *lock = Some(Ok(value.clone())); + let post_init = |_key, value: V, mut guard: WaiterGuard<'_, K, V, S>| { + guard.set_waiter_value(WaiterValue::Ready(Ok(value.clone()))); InitResult::Initialized(value) }; @@ -64,14 +133,15 @@ where // This closure will be called after the init closure has returned a value. // It will convert the returned value (from init) into an InitResult. - let post_init = |key, value: Result, lock: &mut WaiterValue| match value { + let post_init = |key, value: Result, mut guard: WaiterGuard<'_, K, V, S>| match value + { Ok(value) => { - *lock = Some(Ok(value.clone())); + guard.set_waiter_value(WaiterValue::Ready(Ok(value.clone()))); InitResult::Initialized(value) } Err(e) => { let err: ErrorObject = Arc::new(e); - *lock = Some(Err(Arc::clone(&err))); + guard.set_waiter_value(WaiterValue::Ready(Err(Arc::clone(&err)))); self.remove_waiter(key, type_id); InitResult::InitErr(err.downcast().unwrap()) } @@ -91,7 +161,7 @@ where ) -> InitResult where F: Future, - C: FnMut(&'a Arc, O, &mut WaiterValue) -> InitResult, + C: FnMut(&'a Arc, O, WaiterGuard<'_, K, V, S>) -> InitResult, E: Send + Sync + 'static, { use futures_util::FutureExt; @@ -102,19 +172,25 @@ where let mut retries = 0; loop { - let waiter = Arc::new(RwLock::new(None)); + let waiter = Arc::new(RwLock::new(WaiterValue::Computing)); let mut lock = waiter.write().await; match self.try_insert_waiter(key, type_id, &waiter) { None => { // Our waiter was inserted. Let's resolve the init future. + + // Create a guard. This will ensure to remove our waiter when the + // enclosing future has been aborted: + // https://github.com/moka-rs/moka/issues/59 + let mut waiter_guard = WaiterGuard::new(key, type_id, self, &mut lock); + // Catching panic is safe here as we do not try to resolve the future again. match AssertUnwindSafe(init).catch_unwind().await { // Resolved. - Ok(value) => return post_init(key, value, &mut lock), + Ok(value) => return post_init(key, value, waiter_guard), // Panicked. Err(payload) => { - *lock = None; + waiter_guard.set_waiter_value(WaiterValue::InitFuturePanicked); // Remove the waiter so that others can retry. self.remove_waiter(key, type_id); resume_unwind(payload); @@ -126,22 +202,30 @@ where // for a read lock to become available. std::mem::drop(lock); match &*res.read().await { - Some(Ok(value)) => return ReadExisting(value.clone()), - Some(Err(e)) => return InitErr(Arc::clone(e).downcast().unwrap()), - // None means somebody else's init future has been panicked. - None => { + WaiterValue::Ready(Ok(value)) => return ReadExisting(value.clone()), + WaiterValue::Ready(Err(e)) => { + return InitErr(Arc::clone(e).downcast().unwrap()) + } + // Somebody else's init future has been panicked. + WaiterValue::InitFuturePanicked => { retries += 1; - if retries < MAX_RETRIES { - // Retry from the beginning. - continue; - } else { - panic!( - r#"Too many retries. Tried to read the return value from the `init` \ - future but failed {} times. Maybe the `init` kept panicking?"#, - retries - ); - } + panic_if_retry_exhausted_for_panicking(retries, MAX_RETRIES); + // Retry from the beginning. + continue; } + // Somebody else (a future containing `get_or_insert_with`/ + // `get_or_try_insert_with`) has been aborted. + WaiterValue::EnclosingFutureAborted => { + retries += 1; + panic_if_retry_exhausted_for_aborting(retries, MAX_RETRIES); + // Retry from the beginning. + continue; + } + // Unexpected state. + WaiterValue::Computing => panic!( + "Got unexpected state `Computing` after resolving `init` future. \ + This might be a bug in Moka" + ), } } } @@ -168,3 +252,24 @@ where .insert_with_or_modify((key, type_id), || waiter, |_, w| Arc::clone(w)) } } + +fn panic_if_retry_exhausted_for_panicking(retries: usize, max: usize) { + if retries >= max { + panic!( + "Too many retries. Tried to read the return value from the `init` future \ + but failed {} times. Maybe the `init` kept panicking?", + retries + ); + } +} + +fn panic_if_retry_exhausted_for_aborting(retries: usize, max: usize) { + if retries >= max { + panic!( + "Too many retries. Tried to read the return value from the `init` future \ + but failed {} times. Maybe the future containing `get_or_insert_with`/\ + `get_or_try_insert_with` kept being aborted?", + retries + ); + } +} diff --git a/src/sync/value_initializer.rs b/src/sync/value_initializer.rs index da4ff56c..6812d188 100644 --- a/src/sync/value_initializer.rs +++ b/src/sync/value_initializer.rs @@ -131,8 +131,8 @@ where continue; } else { panic!( - r#"Too many retries. Tried to read the return value from the `init` \ - closure but failed {} times. Maybe the `init` kept panicking?"#, + "Too many retries. Tried to read the return value from the `init` \ + closure but failed {} times. Maybe the `init` kept panicking?", retries ); }