diff --git a/src/future.rs b/src/future.rs index f8159314..f0625c2e 100644 --- a/src/future.rs +++ b/src/future.rs @@ -25,7 +25,7 @@ pub use { pub type PredicateId = String; // Empty struct to be used in InitResult::InitErr to represent the Option None. -struct OptionallyNone; +pub(crate) struct OptionallyNone; pub struct Iter<'i, K, V>(crate::sync_base::iter::Iter<'i, K, V>); diff --git a/src/future/cache.rs b/src/future/cache.rs index 2ca51ed4..acf8ada4 100644 --- a/src/future/cache.rs +++ b/src/future/cache.rs @@ -27,6 +27,7 @@ use std::{ fmt, future::Future, hash::{BuildHasher, Hash}, + pin::Pin, sync::Arc, time::Duration, }; @@ -915,6 +916,7 @@ where /// `init` futures. /// pub async fn get_with(&self, key: K, init: impl Future) -> V { + futures_util::pin_mut!(init); let hash = self.base.hash(&key); let key = Arc::new(key); let replace_if = None as Option bool>; @@ -931,9 +933,9 @@ where K: Borrow, Q: ToOwned + Hash + Eq + ?Sized, { + futures_util::pin_mut!(init); let hash = self.base.hash(key); let replace_if = None as Option bool>; - self.get_or_insert_with_hash_by_ref_and_fun(key, hash, init, replace_if, false) .await .into_value() @@ -948,6 +950,7 @@ where init: impl Future, replace_if: impl FnMut(&V) -> bool, ) -> V { + futures_util::pin_mut!(init); let hash = self.base.hash(&key); let key = Arc::new(key); self.get_or_insert_with_hash_and_fun(key, hash, init, Some(replace_if), false) @@ -1051,6 +1054,7 @@ where where F: Future>, { + futures_util::pin_mut!(init); let hash = self.base.hash(&key); let key = Arc::new(key); self.get_or_optionally_insert_with_hash_and_fun(key, hash, init, false) @@ -1068,6 +1072,7 @@ where K: Borrow, Q: ToOwned + Hash + Eq + ?Sized, { + futures_util::pin_mut!(init); let hash = self.base.hash(key); self.get_or_optionally_insert_with_hash_by_ref_and_fun(key, hash, init, false) .await @@ -1174,6 +1179,7 @@ where F: Future>, E: Send + Sync + 'static, { + futures_util::pin_mut!(init); let hash = self.base.hash(&key); let key = Arc::new(key); self.get_or_try_insert_with_hash_and_fun(key, hash, init, false) @@ -1191,6 +1197,7 @@ where K: Borrow, Q: ToOwned + Hash + Eq + ?Sized, { + futures_util::pin_mut!(init); let hash = self.base.hash(key); self.get_or_try_insert_with_hash_by_ref_and_fun(key, hash, init, false) .await @@ -1428,7 +1435,7 @@ where &self, key: Arc, hash: u64, - init: impl Future, + init: Pin<&mut impl Future>, mut replace_if: Option bool>, need_key: bool, ) -> Entry { @@ -1447,7 +1454,7 @@ where &self, key: &Q, hash: u64, - init: impl Future, + init: Pin<&mut impl Future>, mut replace_if: Option bool>, need_key: bool, ) -> Entry @@ -1471,7 +1478,7 @@ where &self, key: Arc, hash: u64, - init: impl Future, + init: Pin<&mut impl Future>, mut replace_if: Option bool>, need_key: bool, ) -> Entry { @@ -1489,9 +1496,12 @@ where None }; + let type_id = ValueInitializer::::type_id_for_get_with(); + let post_init = ValueInitializer::::post_init_for_get_with; + match self .value_initializer - .init_or_read(Arc::clone(&key), get, init, insert) + .try_init_or_read(&Arc::clone(&key), type_id, get, init, insert, post_init) .await { InitResult::Initialized(v) => { @@ -1546,7 +1556,7 @@ where &self, key: Arc, hash: u64, - init: F, + init: Pin<&mut F>, need_key: bool, ) -> Option> where @@ -1565,7 +1575,7 @@ where &self, key: &Q, hash: u64, - init: F, + init: Pin<&mut F>, need_key: bool, ) -> Option> where @@ -1587,7 +1597,7 @@ where &self, key: Arc, hash: u64, - init: F, + init: Pin<&mut F>, need_key: bool, ) -> Option> where @@ -1608,9 +1618,12 @@ where None }; + let type_id = ValueInitializer::::type_id_for_optionally_get_with(); + let post_init = ValueInitializer::::post_init_for_optionally_get_with; + match self .value_initializer - .optionally_init_or_read(Arc::clone(&key), get, init, insert) + .try_init_or_read(&Arc::clone(&key), type_id, get, init, insert, post_init) .await { InitResult::Initialized(v) => { @@ -1626,7 +1639,7 @@ where &self, key: Arc, hash: u64, - init: F, + init: Pin<&mut F>, need_key: bool, ) -> Result, Arc> where @@ -1645,7 +1658,7 @@ where &self, key: &Q, hash: u64, - init: F, + init: Pin<&mut F>, need_key: bool, ) -> Result, Arc> where @@ -1666,7 +1679,7 @@ where &self, key: Arc, hash: u64, - init: F, + init: Pin<&mut F>, need_key: bool, ) -> Result, Arc> where @@ -1688,9 +1701,12 @@ where None }; + let type_id = ValueInitializer::::type_id_for_try_get_with::(); + let post_init = ValueInitializer::::post_init_for_try_get_with; + match self .value_initializer - .try_init_or_read(Arc::clone(&key), get, init, insert) + .try_init_or_read(&Arc::clone(&key), type_id, get, init, insert, post_init) .await { InitResult::Initialized(v) => { diff --git a/src/future/entry_selector.rs b/src/future/entry_selector.rs index 9485919b..1257c153 100644 --- a/src/future/entry_selector.rs +++ b/src/future/entry_selector.rs @@ -176,6 +176,7 @@ where /// /// [get-with-method]: ./struct.Cache.html#method.get_with pub async fn or_insert_with(self, init: impl Future) -> Entry { + futures_util::pin_mut!(init); let key = Arc::new(self.owned_key); let replace_if = None as Option bool>; self.cache @@ -196,6 +197,7 @@ where init: impl Future, replace_if: impl FnMut(&V) -> bool, ) -> Entry { + futures_util::pin_mut!(init); let key = Arc::new(self.owned_key); self.cache .get_or_insert_with_hash_and_fun(key, self.hash, init, Some(replace_if), true) @@ -269,6 +271,7 @@ where self, init: impl Future>, ) -> Option> { + futures_util::pin_mut!(init); let key = Arc::new(self.owned_key); self.cache .get_or_optionally_insert_with_hash_and_fun(key, self.hash, init, true) @@ -344,6 +347,7 @@ where F: Future>, E: Send + Sync + 'static, { + futures_util::pin_mut!(init); let key = Arc::new(self.owned_key); self.cache .get_or_try_insert_with_hash_and_fun(key, self.hash, init, true) @@ -521,11 +525,10 @@ where /// /// [get-with-method]: ./struct.Cache.html#method.get_with pub async fn or_insert_with(self, init: impl Future) -> Entry { - let owned_key: K = self.ref_key.to_owned(); - let key = Arc::new(owned_key); + futures_util::pin_mut!(init); let replace_if = None as Option bool>; self.cache - .get_or_insert_with_hash_and_fun(key, self.hash, init, replace_if, true) + .get_or_insert_with_hash_by_ref_and_fun(self.ref_key, self.hash, init, replace_if, true) .await } @@ -542,10 +545,15 @@ where init: impl Future, replace_if: impl FnMut(&V) -> bool, ) -> Entry { - let owned_key: K = self.ref_key.to_owned(); - let key = Arc::new(owned_key); + futures_util::pin_mut!(init); self.cache - .get_or_insert_with_hash_and_fun(key, self.hash, init, Some(replace_if), true) + .get_or_insert_with_hash_by_ref_and_fun( + self.ref_key, + self.hash, + init, + Some(replace_if), + true, + ) .await } @@ -615,6 +623,7 @@ where self, init: impl Future>, ) -> Option> { + futures_util::pin_mut!(init); self.cache .get_or_optionally_insert_with_hash_by_ref_and_fun(self.ref_key, self.hash, init, true) .await @@ -690,6 +699,7 @@ where F: Future>, E: Send + Sync + 'static, { + futures_util::pin_mut!(init); self.cache .get_or_try_insert_with_hash_by_ref_and_fun(self.ref_key, self.hash, init, true) .await diff --git a/src/future/value_initializer.rs b/src/future/value_initializer.rs index 540f8f7b..b7f176b2 100644 --- a/src/future/value_initializer.rs +++ b/src/future/value_initializer.rs @@ -4,6 +4,7 @@ use std::{ any::{Any, TypeId}, future::Future, hash::{BuildHasher, Hash}, + pin::Pin, sync::Arc, }; use triomphe::Arc as TrioArc; @@ -115,143 +116,31 @@ where } } - /// # Panics - /// Panics if the `init` future has been panicked. - pub(crate) async fn init_or_read<'a>( - &'a self, - key: Arc, - // Closure to get an existing value from cache. - get: impl FnMut() -> Option, - init: impl Future, - // Closure to insert a new value into cache. - mut insert: impl FnMut(V) -> BoxFuture<'a, ()> + Send + 'a, - ) -> InitResult { - // This closure will be called before the init future is resolved, in order - // to check if the value has already been inserted by other async task. - let pre_init = make_pre_init(get); - - // This closure will be called after the init future has returned a value. It - // will insert the returned value (from init) to the cache, and convert the - // value into a pair of a WaiterValue and an InitResult. - let post_init = |value: V| { - async move { - insert(value.clone()).await; - ( - WaiterValue::Ready(Ok(value.clone())), - InitResult::Initialized(value), - ) - } - .boxed() - }; - - let type_id = TypeId::of::<()>(); - self.do_try_init(&key, type_id, pre_init, init, post_init) - .await - } - - /// # Panics - /// Panics if the `init` future has been panicked. - pub(crate) async fn try_init_or_read<'a, E>( - &'a self, - key: Arc, - get: impl FnMut() -> Option, - init: impl Future>, - mut insert: impl FnMut(V) -> BoxFuture<'a, ()> + Send + 'a, - ) -> InitResult - where - E: Send + Sync + 'static, - { - // This closure will be called before the init future is resolved, in order - // to check if the value has already been inserted by other async task. - let pre_init = make_pre_init(get); - - // This closure will be called after the init future has returned a value. It - // will insert the returned value (from init) to the cache, and convert the - // value into a pair of a WaiterValue and an InitResult. - let post_init = move |value: Result| { - async move { - match value { - Ok(value) => { - insert(value.clone()).await; - ( - WaiterValue::Ready(Ok(value.clone())), - InitResult::Initialized(value), - ) - } - Err(e) => { - let err: ErrorObject = Arc::new(e); - ( - WaiterValue::Ready(Err(Arc::clone(&err))), - InitResult::InitErr(err.downcast().unwrap()), - ) - } - } - } - .boxed() - }; - - let type_id = TypeId::of::(); - self.do_try_init(&key, type_id, pre_init, init, post_init) - .await - } + // + // NOTES: We use `Pin<&mut impl Future>` instead of `impl Future` here for the + // `init` argument. This is because we want to avoid the future size inflation + // caused by calling nested async functions. See the following links for more + // details: + // + // - https://github.com/moka-rs/moka/issues/212 + // - https://swatinem.de/blog/future-size/ + // /// # Panics /// Panics if the `init` future has been panicked. - pub(super) async fn optionally_init_or_read<'a>( - &'a self, - key: Arc, - get: impl FnMut() -> Option, - init: impl Future>, - mut insert: impl FnMut(V) -> BoxFuture<'a, ()> + Send + 'a, - ) -> InitResult { - // This closure will be called before the init future is resolved, in order - // to check if the value has already been inserted by other async task. - let pre_init = make_pre_init(get); - - // This closure will be called after the init future has returned a value. It - // will insert the returned value (from init) to the cache, and convert the - // value into a pair of a WaiterValue and an InitResult. - let post_init = |value: Option| { - async move { - match value { - Some(value) => { - insert(value.clone()).await; - ( - WaiterValue::Ready(Ok(value.clone())), - InitResult::Initialized(value), - ) - } - None => { - // `value` can be either `Some` or `None`. For `None` case, - // without change the existing API too much, we will need to - // convert `None` to Arc here. `Infallible` could not be - // instantiated. So it might be good to use an empty struct - // to indicate the error type. - let err: ErrorObject = Arc::new(OptionallyNone); - ( - WaiterValue::Ready(Err(Arc::clone(&err))), - InitResult::InitErr(err.downcast().unwrap()), - ) - } - } - } - .boxed() - }; - - let type_id = TypeId::of::(); - self.do_try_init(&key, type_id, pre_init, init, post_init) - .await - } - - /// # Panics - /// Panics if the `init` future has been panicked. - async fn do_try_init<'a, O, E>( + pub(crate) async fn try_init_or_read<'a, O, E>( &'a self, key: &Arc, type_id: TypeId, - mut pre_init: impl FnMut() -> Option<(WaiterValue, InitResult)>, - init: impl Future, - post_init: impl FnOnce(O) -> BoxFuture<'a, (WaiterValue, InitResult)>, + // Closure to get an existing value from cache. + mut get: impl FnMut() -> Option, + // Future to initialize a new value. + init: Pin<&mut impl Future>, + // Closure that returns a future to insert a new value into cache. + mut insert: impl FnMut(V) -> BoxFuture<'a, ()> + Send + 'a, + // This function will be called after the init future has returned a value of + // type O. It converts O into Result. + post_init: fn(O) -> Result, ) -> InitResult where E: Send + Sync + 'static, @@ -283,12 +172,12 @@ where ); // Check if the value has already been inserted by other thread. - if let Some((waiter_val, init_res)) = pre_init() { + if let Some(value) = get() { // Yes. Set the waiter value, remove our waiter, and return // the existing value. - waiter_guard.set_waiter_value(waiter_val); + waiter_guard.set_waiter_value(WaiterValue::Ready(Ok(value.clone()))); remove_waiter(&self.waiters, cht_key, hash); - return init_res; + return InitResult::ReadExisting(value); } // The value still does note exist. Let's resolve the init future. @@ -297,7 +186,22 @@ where match AssertUnwindSafe(init).catch_unwind().await { // Resolved. Ok(value) => { - let (waiter_val, init_res) = post_init(value).await; + let (waiter_val, init_res) = match post_init(value) { + Ok(value) => { + insert(value.clone()).await; + ( + WaiterValue::Ready(Ok(value.clone())), + InitResult::Initialized(value), + ) + } + Err(e) => { + let err: ErrorObject = Arc::new(e); + ( + WaiterValue::Ready(Err(Arc::clone(&err))), + InitResult::InitErr(err.downcast().unwrap()), + ) + } + }; waiter_guard.set_waiter_value(waiter_val); remove_waiter(&self.waiters, cht_key, hash); return init_res; @@ -345,6 +249,44 @@ where } } } + + /// The `post_init` function for the `get_with` method of cache. + pub(crate) fn post_init_for_get_with(value: V) -> Result { + Ok(value) + } + + /// The `post_init` function for the `optionally_get_with` method of cache. + pub(crate) fn post_init_for_optionally_get_with( + value: Option, + ) -> Result> { + // `value` can be either `Some` or `None`. For `None` case, without change + // the existing API too much, we will need to convert `None` to Arc here. + // `Infallible` could not be instantiated. So it might be good to use an + // empty struct to indicate the error type. + value.ok_or(Arc::new(OptionallyNone)) + } + + /// The `post_init` function for `try_get_with` method of cache. + pub(crate) fn post_init_for_try_get_with(result: Result) -> Result { + result + } + + /// Returns the `type_id` for `get_with` method of cache. + pub(crate) fn type_id_for_get_with() -> TypeId { + // NOTE: We use a regular function here instead of a const fn because TypeId + // is not stable as a const fn. (as of our MSRV) + TypeId::of::<()>() + } + + /// Returns the `type_id` for `optionally_get_with` method of cache. + pub(crate) fn type_id_for_optionally_get_with() -> TypeId { + TypeId::of::() + } + + /// Returns the `type_id` for `try_get_with` method of cache. + pub(crate) fn type_id_for_try_get_with() -> TypeId { + TypeId::of::() + } } #[inline] @@ -386,23 +328,6 @@ where (cht_key, hash) } -#[inline] -fn make_pre_init( - mut get: impl FnMut() -> Option, -) -> impl FnMut() -> Option<(WaiterValue, InitResult)> -where - V: Clone, -{ - move || { - get().map(|value| { - ( - WaiterValue::Ready(Ok(value.clone())), - InitResult::ReadExisting(value), - ) - }) - } -} - fn panic_if_retry_exhausted_for_panicking(retries: usize, max: usize) { if retries >= max { panic!(