diff --git a/src/lib.rs b/src/lib.rs
index a0870bf..a2953c3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -215,10 +215,23 @@ macro_rules! test_println {
         }
     }
 }
+
+#[cfg(all(test, loom))]
+macro_rules! test_dbg {
+    ($e:expr) => {
+        match $e {
+            e => {
+                test_println!("{} = {:?}", stringify!($e), &e);
+                e
+            }
+        }
+    };
+}
+
 mod clear;
 pub mod implementation;
 mod page;
-mod pool;
+pub mod pool;
 pub(crate) mod sync;
 mod tid;
 pub(crate) use tid::Tid;
@@ -228,7 +241,9 @@ mod shard;
 use cfg::CfgPrivate;
 pub use cfg::{Config, DefaultConfig};
 pub use clear::Clear;
-pub use pool::{Pool, PoolGuard};
+#[doc(inline)]
+pub use pool::Pool;
+use std::ptr;
 use shard::Shard;
 use std::{fmt, marker::PhantomData};
 
@@ -247,7 +262,8 @@ pub struct Slab<T, C: cfg::Config = DefaultConfig> {
 /// references is currently being accessed. If the item is removed from the slab
 /// while a guard exists, the removal will be deferred until all guards are dropped.
 pub struct Guard<'a, T, C: cfg::Config = DefaultConfig> {
-    inner: page::slot::Guard<'a, T, C>,
+    inner: page::slot::Guard<'a, Option<T>, C>,
+    value: ptr::NonNull<T>,
     shard: &'a Shard<Option<T>, C>,
     key: usize,
 }
@@ -304,7 +320,10 @@ impl<T, C: cfg::Config> Slab<T, C> {
         test_println!("insert {:?}", tid);
         let mut value = Some(value);
         shard
-            .init_with(|slot| slot.insert(&mut value))
+            .init_with(|idx, slot| {
+                let gen = slot.insert(&mut value)?;
+                Some(gen.pack(idx))
+            })
             .map(|idx| tid.pack(idx))
     }
 
@@ -452,13 +471,14 @@ impl<T, C: cfg::Config> Slab<T, C> {
         test_println!("get {:?}; current={:?}", tid, Tid::<C>::current());
         let shard = self.shards.get(tid.as_usize())?;
-        let inner = shard.get(key, |x| {
-            x.as_ref().expect(
-                "if a slot can be accessed at the current generation, its value must be `Some`",
-            )
-        })?;
-
-        Some(Guard { inner, shard, key })
+        let inner = shard.with_slot(key, |slot| slot.get(C::unpack_gen(key)))?;
+        let value = ptr::NonNull::from(inner.slot().value().as_ref().unwrap());
+        Some(Guard {
+            inner,
+            value,
+            shard,
+            key,
+        })
     }
 
     /// Returns `true` if the slab contains a value for the given key.
@@ -517,13 +537,23 @@ impl<'a, T, C: cfg::Config> Guard<'a, T, C> {
     pub fn key(&self) -> usize {
         self.key
     }
+
+    #[inline(always)]
+    fn value(&self) -> &T {
+        unsafe {
+            // Safety: this pointer is always valid, as it is projected from
+            // the `Option<T>` kept alive by `self.inner` --- caching it here
+            // just avoids having to `expect` the option in the hot path when
+            // dereferencing.
+            self.value.as_ref()
+        }
+    }
 }
 
 impl<'a, T, C: cfg::Config> std::ops::Deref for Guard<'a, T, C> {
     type Target = T;
 
     fn deref(&self) -> &Self::Target {
-        self.inner.item()
+        self.value()
     }
 }
 
@@ -547,7 +577,7 @@ where
     C: cfg::Config,
 {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        fmt::Debug::fmt(self.inner.item(), f)
+        fmt::Debug::fmt(self.value(), f)
     }
 }
 
@@ -557,7 +587,7 @@ where
     C: cfg::Config,
 {
     fn eq(&self, other: &T) -> bool {
-        self.inner.item().eq(other)
+        self.value().eq(other)
     }
 }
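Note on the lib.rs change above: `Guard` now carries both the slot-level guard and a cached `ptr::NonNull<T>` into the slot's `Option<T>`, so `Deref` no longer re-unwraps the option on every access. A small usage sketch against the crate's public `Slab` API (illustrative only, not part of this diff):

```rust
use sharded_slab::Slab;

fn main() {
    let slab: Slab<String> = Slab::new();
    let key = slab.insert(String::from("hello")).unwrap();

    // While the guard is alive, `deref` reads through the cached
    // `NonNull<T>` rather than matching on the slot's `Option<T>`.
    let guard = slab.get(key).unwrap();
    assert_eq!(guard, String::from("hello"));

    // Removal is deferred while a guard exists...
    slab.remove(key);
    assert_eq!(guard, String::from("hello"));

    // ...and the slot becomes observably empty once the guard drops.
    drop(guard);
    assert!(slab.get(key).is_none());
}
```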
diff --git a/src/page/mod.rs b/src/page/mod.rs
index 143889c..48bd209 100644
--- a/src/page/mod.rs
+++ b/src/page/mod.rs
@@ -5,7 +5,7 @@ use crate::Pack;
 pub(crate) mod slot;
 mod stack;
 
-use self::slot::Slot;
+pub(crate) use self::slot::Slot;
 use std::{fmt, marker::PhantomData};
 
 /// A page address encodes the location of a slot within a shard (the page
@@ -175,21 +175,18 @@ where
     }
 
     #[inline]
-    pub(crate) fn get<U>(
-        &self,
+    pub(crate) fn with_slot<'a, U>(
+        &'a self,
         addr: Addr<C>,
-        idx: usize,
-        f: impl FnOnce(&T) -> &U,
-    ) -> Option<slot::Guard<'_, U, C>> {
+        f: impl FnOnce(&'a Slot<T, C>) -> Option<U>,
+    ) -> Option<U> {
         let poff = addr.offset() - self.prev_sz;
 
         test_println!("-> offset {:?}", poff);
 
         self.slab.with(|slab| {
-            unsafe { &*slab }
-                .as_ref()?
-                .get(poff)?
-                .get(C::unpack_gen(idx), f)
+            let slot = unsafe { &*slab }.as_ref()?.get(poff)?;
+            f(slot)
         })
     }
 
@@ -264,15 +261,11 @@ where
     T: Clear + Default,
     C: cfg::Config,
 {
-    /// Allocate and initialize a new slot.
-    ///
-    /// It does this via the provided initializatin function `func`. Once it get's the generation
-    /// number for the new slot, it performs the operations required to return the key to the
-    /// caller.]
-    pub(crate) fn init_with<F>(&self, local: &Local, func: F) -> Option<usize>
-    where
-        F: FnOnce(&Slot<T, C>) -> Option<slot::Generation<C>>,
-    {
+    pub(crate) fn init_with<U>(
+        &self,
+        local: &Local,
+        init: impl FnOnce(usize, &Slot<T, C>) -> Option<U>,
+    ) -> Option<U> {
         let head = self.pop(local)?;
 
         // do we need to allocate storage for this page?
@@ -280,19 +273,20 @@ where
             self.allocate();
         }
 
-        let gen = self.slab.with(|slab| {
+        let index = head + self.prev_sz;
+
+        let result = self.slab.with(|slab| {
             let slab = unsafe { &*(slab) }
                 .as_ref()
                 .expect("page must have been allocated to insert!");
             let slot = &slab[head];
+            let result = init(index, slot)?;
             local.set_head(slot.next());
-            func(slot)
+            Some(result)
         })?;
 
-        let index = head + self.prev_sz;
-
-        test_println!("-> initialize_new_slot: insert at offset: {}", index);
-        Some(gen.pack(index))
+        test_println!("-> init_with: insert at offset: {}", index);
+        Some(result)
     }
 
     /// Allocates storage for the page's slots.
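The refactor above replaces reference projection (`FnOnce(&T) -> &U`) with a fallible closure over the whole slot. The shape is easier to see in isolation; here is a minimal sketch with stand-in types (the crate's real `Page` and `Slot` are generic over a `cfg::Config`):

```rust
// Stand-in types; illustrative only.
struct Slot<T> {
    value: Option<T>,
}

struct Page<T> {
    slots: Vec<Slot<T>>,
}

impl<T> Page<T> {
    // The callback borrows the slot for the caller's lifetime and may fail,
    // so lookups compose with `?` and can return owned results (like a key)
    // or borrowing results (like a guard).
    fn with_slot<'a, U>(
        &'a self,
        idx: usize,
        f: impl FnOnce(&'a Slot<T>) -> Option<U>,
    ) -> Option<U> {
        let slot = self.slots.get(idx)?;
        f(slot)
    }
}

fn main() {
    let page = Page {
        slots: vec![Slot { value: Some(42) }],
    };
    let doubled = page.with_slot(0, |slot| slot.value.map(|v| v * 2));
    assert_eq!(doubled, Some(84));
    assert_eq!(page.with_slot(1, |slot| slot.value), None); // out of bounds
}
```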
diff --git a/src/page/slot.rs b/src/page/slot.rs
index b23b15a..ceed33a 100644
--- a/src/page/slot.rs
+++ b/src/page/slot.rs
@@ -4,7 +4,7 @@ use crate::sync::{
     UnsafeCell,
 };
 use crate::{cfg, clear::Clear, Pack, Tid};
-use std::{fmt, marker::PhantomData};
+use std::{fmt, marker::PhantomData, mem, ptr, thread};
 
 pub(crate) struct Slot<T, C> {
     lifecycle: AtomicUsize,
 }
 
 #[derive(Debug)]
-pub(crate) struct Guard<'a, T, C = cfg::DefaultConfig> {
-    item: &'a T,
-    lifecycle: &'a AtomicUsize,
-    _cfg: PhantomData<fn(C)>,
+pub(crate) struct Guard<'a, T, C: cfg::Config = cfg::DefaultConfig> {
+    slot: &'a Slot<T, C>,
+}
+
+#[derive(Debug)]
+pub(crate) struct GuardMut<'a, T, C: cfg::Config = cfg::DefaultConfig> {
+    slot: &'a Slot<T, C>,
+}
+
+#[derive(Debug)]
+pub(crate) struct InitGuard<T, C: cfg::Config = cfg::DefaultConfig> {
+    slot: ptr::NonNull<Slot<T, C>>,
+    curr_lifecycle: usize,
+    released: bool,
 }
 
 #[repr(transparent)]
 struct LifecycleGen<C>(Generation<C>);
 
 #[derive(Debug, Eq, PartialEq, Copy, Clone)]
 #[repr(usize)]
 enum State {
-    NotRemoved = 0b00,
+    Present = 0b00,
     Marked = 0b01,
     Removing = 0b11,
 }
 
@@ -87,7 +97,7 @@ where
     }
 
     #[inline(always)]
-    pub(super) fn value(&self) -> &T {
+    pub(crate) fn value(&self) -> &T {
         self.item.with(|item| unsafe { &*item })
     }
 
@@ -99,11 +109,7 @@ where
     }
 
     #[inline(always)]
-    pub(in crate::page) fn get<U>(
-        &self,
-        gen: Generation<C>,
-        f: impl FnOnce(&T) -> &U,
-    ) -> Option<Guard<'_, U, C>> {
+    pub(crate) fn get<'a>(&'a self, gen: Generation<C>) -> Option<Guard<'a, T, C>> {
         let mut lifecycle = self.lifecycle.load(Ordering::Acquire);
         loop {
             // Unpack the current state.
             // current, and the slot must not be in the process of being
             // removed. If we can no longer access the slot at the given
             // generation, return `None`.
-            if gen != current_gen || state != Lifecycle::NOT_REMOVED {
+            if gen != current_gen || state != Lifecycle::PRESENT {
                 test_println!("-> get: no longer exists!");
                 return None;
             }
 
-            // Would incrementing the ref count cause an overflow?
-            if refs.value >= RefCount::<C>::MAX {
-                test_println!(
-                    "-> get: max concurrent references ({}) reached!",
-                    RefCount::<C>::MAX
-                );
-                return None;
-            }
-
             // Try to increment the slot's ref count by one.
-            let new_refs = refs.incr();
+            let new_refs = refs.incr()?;
             match self.lifecycle.compare_exchange(
                 lifecycle,
                 new_refs.pack(current_gen.pack(state.pack(0))),
@@ -149,15 +146,11 @@ where
                 Ok(_) => {
                     // Okay, the ref count was incremented successfully! We can
                     // now return a guard!
-                    let item = f(self.value());
-
                     test_println!("-> {:?}", new_refs);
 
-                    return Some(Guard {
-                        item,
-                        lifecycle: &self.lifecycle,
-                        _cfg: PhantomData,
-                    });
+                    return Some(Guard { slot: self });
                 }
                 Err(actual) => {
                     // Another thread modified the slot's state before us! We
@@ -305,19 +298,19 @@ where
     /// Initialize a slot
     ///
     /// This method initializes and sets up the state for a slot. When being used in `Pool`, we
     /// only need to ensure that the `Slot` is in the right state, while when being used in a
     /// `Slab` we want to insert a value into it, as the memory is not initialized.
-    pub(crate) fn initialize_state(&self, f: impl FnOnce(&mut T)) -> Option<Generation<C>> {
+    pub(crate) fn init(&self) -> Option<InitGuard<T, C>> {
         // Load the current lifecycle state.
         let lifecycle = self.lifecycle.load(Ordering::Acquire);
-        let gen = LifecycleGen::from_packed(lifecycle).0;
+        let gen = LifecycleGen::<C>::from_packed(lifecycle).0;
         let refs = RefCount::<C>::from_packed(lifecycle);
 
         test_println!(
-            "-> initialize_state; state={:?}; gen={:?}; refs={:?}",
+            "-> initialize_state; state={:?}; gen={:?}; refs={:?};",
             Lifecycle::<C>::from_packed(lifecycle),
             gen,
-            refs
+            refs,
         );
 
         if refs.value != 0 {
@@ -325,32 +318,11 @@ where
             return None;
         }
 
-        // Set the slot's state to NOT_REMOVED.
-        let new_lifecycle = gen.pack(Lifecycle::<C>::NOT_REMOVED.pack(0));
-        let was_set = self.lifecycle.compare_exchange(
-            lifecycle,
-            new_lifecycle,
-            Ordering::AcqRel,
-            Ordering::Acquire,
-        );
-        if let Err(_actual) = was_set {
-            // The slot was modified while we were inserting to it! It's no
-            // longer safe to insert a new value.
-            test_println!(
-                "-> modified during insert, cancelling! new={:#x}; expected={:#x}; actual={:#x};",
-                new_lifecycle,
-                lifecycle,
-                _actual
-            );
-            return None;
-        }
-
-        // call provided function to update this slot
-        self.item.with_mut(|item| unsafe {
-            f(&mut *item);
-        });
-
-        Some(gen)
+        Some(InitGuard {
+            slot: ptr::NonNull::from(self),
+            curr_lifecycle: lifecycle,
+            released: false,
+        })
     }
 }
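For readers new to this file: the `lifecycle` word packs three fields into a single `AtomicUsize` — the two low bits hold the `State` (`Present = 0b00`, `Marked = 0b01`, `Removing = 0b11`), the next bits hold a reference count, and the high bits hold the generation. A standalone sketch of that packing (field widths here are illustrative; the crate derives them from its `Config`):

```rust
// Illustrative widths; sharded-slab computes these from cfg::Config.
const STATE_BITS: u32 = 2;
const REF_BITS: u32 = 14;
const STATE_MASK: usize = (1 << STATE_BITS) - 1;
const REF_MASK: usize = ((1 << REF_BITS) - 1) << STATE_BITS;

fn pack(state: usize, refs: usize, generation: usize) -> usize {
    debug_assert!(state <= STATE_MASK);
    debug_assert!(refs < (1 << REF_BITS));
    state | (refs << STATE_BITS) | (generation << (STATE_BITS + REF_BITS))
}

fn unpack(lifecycle: usize) -> (usize, usize, usize) {
    (
        lifecycle & STATE_MASK,               // state
        (lifecycle & REF_MASK) >> STATE_BITS, // ref count
        lifecycle >> (STATE_BITS + REF_BITS), // generation
    )
}

fn main() {
    let word = pack(0b00, 3, 7); // Present, three guards, generation 7
    assert_eq!(unpack(word), (0b00, 3, 7));
}
```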
@@ -371,9 +343,16 @@ where
         debug_assert!(self.is_empty(), "inserted into full slot");
         debug_assert!(value.is_some(), "inserted twice");
 
-        let gen = self.initialize_state(|item| {
-            *item = value.take();
-        })?;
+        let mut guard = self.init()?;
+        let gen = guard.generation();
+        unsafe {
+            // Safety: Accessing the value of an `InitGuard` is unsafe because
+            // it has a pointer to a slot which may dangle. Here, we know the
+            // pointed slot is alive because we have a reference to it in scope,
+            // and the `InitGuard` will be dropped when this function returns.
+            mem::swap(guard.value_mut(), value);
+            guard.release();
+        };
 
         test_println!("-> inserted at {:?}", gen);
 
         Some(gen)
@@ -419,7 +398,7 @@ where
 {
     pub(in crate::page) fn new(next: usize) -> Self {
         Self {
-            lifecycle: AtomicUsize::new(0),
+            lifecycle: AtomicUsize::new(Lifecycle::<C>::REMOVING.as_usize()),
             item: UnsafeCell::new(T::default()),
             next: UnsafeCell::new(next),
             _cfg: PhantomData,
@@ -465,6 +444,52 @@ where
     }
 }
 
+impl<T, C: cfg::Config> Slot<T, C> {
+    fn release(&self) -> bool {
+        let mut lifecycle = self.lifecycle.load(Ordering::Acquire);
+        loop {
+            let refs = RefCount::<C>::from_packed(lifecycle);
+            let state = Lifecycle::<C>::from_packed(lifecycle).state;
+            let gen = LifecycleGen::<C>::from_packed(lifecycle).0;
+
+            // Are we the last guard, and is the slot marked for removal?
+            let dropping = refs.value == 1 && state == State::Marked;
+            let new_lifecycle = if dropping {
+                // If so, we want to advance the state to "removing".
+                gen.pack(State::Removing as usize)
+            } else {
+                // Otherwise, just subtract 1 from the ref count.
+                refs.decr().pack(lifecycle)
+            };
+
+            test_println!(
+                "-> drop guard: state={:?}; gen={:?}; refs={:?}; lifecycle={:#x}; new_lifecycle={:#x}; dropping={:?}",
+                state,
+                gen,
+                refs,
+                lifecycle,
+                new_lifecycle,
+                dropping
+            );
+            match self.lifecycle.compare_exchange(
+                lifecycle,
+                new_lifecycle,
+                Ordering::AcqRel,
+                Ordering::Acquire,
+            ) {
+                Ok(_) => {
+                    test_println!("-> drop guard: done; dropping={:?}", dropping);
+                    return dropping;
+                }
+                Err(actual) => {
+                    test_println!("-> drop guard; retry, actual={:#x}", actual);
+                    lifecycle = actual;
+                }
+            }
+        }
+    }
+}
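`Slot::release` above is the standard compare-exchange retry loop: read the packed word, compute its successor, and on contention resume from the value the failed CAS reported. Reduced to a skeleton over a plain counter (no packing; assumes the count is nonzero):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

fn decrement(refs: &AtomicUsize) -> usize {
    let mut current = refs.load(Ordering::Acquire);
    loop {
        // In the real code, this step unpacks, modifies, and repacks the word.
        let next = current - 1;
        match refs.compare_exchange(current, next, Ordering::AcqRel, Ordering::Acquire) {
            Ok(_) => return next,
            // Lost a race: retry from the actual value the CAS observed,
            // rather than issuing a separate reload.
            Err(actual) => current = actual,
        }
    }
}

fn main() {
    let refs = AtomicUsize::new(3);
    assert_eq!(decrement(&refs), 2);
}
```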
+
 impl<T, C: cfg::Config> fmt::Debug for Slot<T, C> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         let lifecycle = self.lifecycle.load(Ordering::Relaxed);
@@ -524,51 +549,16 @@ impl<C: cfg::Config> Copy for Generation<C> {}
 
 impl<'a, T, C: cfg::Config> Guard<'a, T, C> {
     pub(crate) fn release(&self) -> bool {
-        let mut lifecycle = self.lifecycle.load(Ordering::Acquire);
-        loop {
-            let refs = RefCount::<C>::from_packed(lifecycle);
-            let state = Lifecycle::<C>::from_packed(lifecycle).state;
-            let gen = LifecycleGen::<C>::from_packed(lifecycle).0;
-
-            // Are we the last guard, and is the slot marked for removal?
-            let dropping = refs.value == 1 && state == State::Marked;
-            let new_lifecycle = if dropping {
-                // If so, we want to advance the state to "removing"
-                gen.pack(State::Removing as usize)
-            } else {
-                // Otherwise, just subtract 1 from the ref count.
-                refs.decr().pack(lifecycle)
-            };
+        self.slot.release()
+    }
 
-            test_println!(
-                "-> drop guard: state={:?}; gen={:?}; refs={:?}; lifecycle={:#x}; new_lifecycle={:#x}; dropping={:?}",
-                state,
-                gen,
-                refs,
-                lifecycle,
-                new_lifecycle,
-                dropping
-            );
-            match self.lifecycle.compare_exchange(
-                lifecycle,
-                new_lifecycle,
-                Ordering::AcqRel,
-                Ordering::Acquire,
-            ) {
-                Ok(_) => {
-                    test_println!("-> drop guard: done; dropping={:?}", dropping);
-                    return dropping;
-                }
-                Err(actual) => {
-                    test_println!("-> drop guard; retry, actual={:#x}", actual);
-                    lifecycle = actual;
-                }
-            }
-        }
+    pub(crate) fn slot(&self) -> &Slot<T, C> {
+        self.slot
     }
 
-    pub(crate) fn item(&self) -> &T {
-        self.item
+    #[inline(always)]
+    pub(crate) fn value(&self) -> &T {
+        self.slot.item.with(|item| unsafe { &*item })
     }
 }
 
@@ -579,9 +569,12 @@ impl<C: cfg::Config> Lifecycle<C> {
         state: State::Marked,
         _cfg: PhantomData,
     };
-
-    const NOT_REMOVED: Self = Self {
-        state: State::NotRemoved,
+    const REMOVING: Self = Self {
+        state: State::Removing,
+        _cfg: PhantomData,
+    };
+    const PRESENT: Self = Self {
+        state: State::Present,
         _cfg: PhantomData,
     };
 }
@@ -593,7 +586,7 @@ impl<C: cfg::Config> Pack<C> for Lifecycle<C> {
     fn from_usize(u: usize) -> Self {
         Self {
             state: match u & Self::MASK {
-                0b00 => State::NotRemoved,
+                0b00 => State::Present,
                 0b01 => State::Marked,
                 0b11 => State::Removing,
                 bad => unreachable!("weird lifecycle {:#b}", bad),
@@ -628,7 +621,7 @@ impl<C: cfg::Config> Pack<C> for RefCount<C> {
     type Prev = Lifecycle<C>;
 
     fn from_usize(value: usize) -> Self {
-        debug_assert!(value <= Self::MAX);
+        debug_assert!(value <= Self::BITS);
         Self {
             value,
             _cfg: PhantomData,
@@ -641,20 +634,16 @@ impl<C: cfg::Config> Pack<C> for RefCount<C> {
 }
 
 impl<C: cfg::Config> RefCount<C> {
-    pub(crate) const MAX: usize = Self::BITS;
+    pub(crate) const MAX: usize = Self::BITS - 1;
 
     #[inline]
-    fn incr(self) -> Self {
-        // It's okay for this to be a debug assertion, because the check in
-        // `Slot::get` should protect against incrementing the reference count
-        // if it would overflow. This is intended to test that the check is in
-        // place.
-        debug_assert!(
-            self.value < Self::MAX,
-            "incrementing ref count would overflow max value ({})",
-            Self::MAX
-        );
-        Self::from_usize(self.value + 1)
+    fn incr(self) -> Option<Self> {
+        if self.value >= Self::MAX {
+            test_println!("-> get: {}; MAX={}", self.value, RefCount::<C>::MAX);
+            return None;
+        }
+
+        Some(Self::from_usize(self.value + 1))
     }
 
     #[inline]
@@ -712,6 +701,117 @@ impl<C: cfg::Config> Pack<C> for LifecycleGen<C> {
     }
 }
 
+impl<T, C: cfg::Config> InitGuard<T, C> {
+    pub(crate) fn generation(&self) -> Generation<C> {
+        LifecycleGen::<C>::from_packed(self.curr_lifecycle).0
+    }
+
+    /// Returns a borrowed reference to the slot's value.
+    ///
+    /// ## Safety
+    ///
+    /// This dereferences a raw pointer to the slot. The caller is responsible
+    /// for ensuring that the `InitGuard` does not outlive the slab that
+    /// contains the pointed slot. Failure to do so means this pointer may
+    /// dangle.
+    pub(crate) unsafe fn value(&self) -> &T {
+        self.slot.as_ref().item.with(|val| &*val)
+    }
+
+    /// Returns a mutably borrowed reference to the slot's value.
+    ///
+    /// ## Safety
+    ///
+    /// This dereferences a raw pointer to the slot. The caller is responsible
+    /// for ensuring that the `InitGuard` does not outlive the slab that
+    /// contains the pointed slot. Failure to do so means this pointer may
+    /// dangle.
+    ///
+    /// It's safe to reference the slot mutably, though, because creating an
+    /// `InitGuard` ensures there are no outstanding immutable references.
+    pub(crate) unsafe fn value_mut(&mut self) -> &mut T {
+        self.slot.as_ref().item.with_mut(|val| &mut *val)
+    }
+
+    /// Releases the guard, returning `true` if the slot should be cleared.
+    ///
+    /// ## Safety
+    ///
+    /// This dereferences a raw pointer to the slot. The caller is responsible
+    /// for ensuring that the `InitGuard` does not outlive the slab that
+    /// contains the pointed slot. Failure to do so means this pointer may
+    /// dangle.
+    pub(crate) unsafe fn release(&mut self) -> bool {
+        test_println!(
+            "InitGuard::release; curr_lifecycle={:?};",
+            Lifecycle::<C>::from_packed(self.curr_lifecycle)
+        );
+        if self.released {
+            test_println!("-> already released!");
+            return false;
+        }
+        self.released = true;
+        let mut curr_lifecycle = self.curr_lifecycle;
+        let slot = self.slot.as_ref();
+        let new_lifecycle = LifecycleGen::<C>::from_packed(self.curr_lifecycle)
+            .pack(Lifecycle::<C>::PRESENT.pack(0));
+
+        match slot.lifecycle.compare_exchange(
+            curr_lifecycle,
+            new_lifecycle,
+            Ordering::AcqRel,
+            Ordering::Acquire,
+        ) {
+            Ok(_) => {
+                test_println!("--> advanced to PRESENT; done");
+                return false;
+            }
+            Err(actual) => {
+                test_println!(
+                    "--> lifecycle changed; actual={:?}",
+                    Lifecycle::<C>::from_packed(actual)
+                );
+                curr_lifecycle = actual;
+            }
+        }
+
+        // If the state is no longer the one we observed at creation, we are
+        // now responsible for releasing the slot.
+        loop {
+            let refs = RefCount::<C>::from_packed(curr_lifecycle);
+            let state = Lifecycle::<C>::from_packed(curr_lifecycle).state;
+
+            test_println!(
+                "-> InitGuard::release; lifecycle={:#x}; state={:?}; refs={:?};",
+                curr_lifecycle,
+                state,
+                refs,
+            );
+
+            debug_assert!(state == State::Marked || thread::panicking(), "state was not MARKED; someone else has removed the slot while we have exclusive access!\nactual={:?}", state);
+            debug_assert!(refs.value == 0 || thread::panicking(), "ref count was not 0; someone else has referenced the slot while we have exclusive access!\nactual={:?}", refs);
+            let new_lifecycle = self.generation().pack(State::Removing as usize);
+
+            match slot.lifecycle.compare_exchange(
+                curr_lifecycle,
+                new_lifecycle,
+                Ordering::AcqRel,
+                Ordering::Acquire,
+            ) {
+                Ok(_) => {
+                    test_println!("-> InitGuard::release: done!");
+                    return true;
+                }
+                Err(actual) => {
+                    debug_assert!(thread::panicking(), "we should not have to retry this CAS!");
+                    test_println!("-> InitGuard::release; retry, actual={:#x}", actual);
+                    curr_lifecycle = actual;
+                }
+            }
+        }
+    }
+}
+
 // === helpers ===
 
 #[inline(always)]
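`init` plus `InitGuard::release` form a two-phase commit: the slot is claimed while still unpublished (new slots now start in `Removing`), the value is written, and a single CAS flips the state to `Present`. A reduced sketch of the same protocol (stand-in types; the hardcoded states mirror the `State` discriminants):

```rust
use std::cell::UnsafeCell;
use std::sync::atomic::{AtomicUsize, Ordering};

const PRESENT: usize = 0b00;
const REMOVING: usize = 0b11;

struct MiniSlot {
    lifecycle: AtomicUsize,
    value: UnsafeCell<Option<u32>>,
}

// Sound for this single-threaded sketch; the real crate guards the cell
// with the lifecycle state machine.
unsafe impl Sync for MiniSlot {}

impl MiniSlot {
    fn new() -> Self {
        MiniSlot {
            // Phase zero: unpublished slots look like they are being removed.
            lifecycle: AtomicUsize::new(REMOVING),
            value: UnsafeCell::new(None),
        }
    }

    fn init(&self, value: u32) -> bool {
        // Phase one: write the value while no reader can see the slot.
        unsafe { *self.value.get() = Some(value) };
        // Phase two: publish with a single CAS; a failure means a racing
        // clear, and the caller becomes responsible for teardown.
        self.lifecycle
            .compare_exchange(REMOVING, PRESENT, Ordering::AcqRel, Ordering::Acquire)
            .is_ok()
    }
}

fn main() {
    let slot = MiniSlot::new();
    assert!(slot.init(42));
    assert_eq!(slot.lifecycle.load(Ordering::Acquire), PRESENT);
}
```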
diff --git a/src/pool.rs b/src/pool.rs
index 4b01517..efddebe 100644
--- a/src/pool.rs
+++ b/src/pool.rs
@@ -1,7 +1,15 @@
+//! A lock-free concurrent object pool.
+//!
+//! See the [`Pool` type's documentation][pool] for details on the object pool API and how
+//! it differs from the [`Slab`] API.
+//!
+//! [pool]: ../struct.Pool.html
+//! [`Slab`]: ../struct.Slab.html
 use crate::{
     cfg::{self, CfgPrivate, DefaultConfig},
     clear::Clear,
     page, shard,
+    sync::atomic,
     tid::Tid,
     Pack, Shard,
 };
@@ -31,17 +39,30 @@ use std::{fmt, marker::PhantomData};
 /// # use sharded_slab::Pool;
 /// let pool: Pool<String> = Pool::new();
 ///
-/// let key = pool.create(|item| item.push_str("hello world")).unwrap();
+/// let key = pool.create_with(|item| item.push_str("hello world")).unwrap();
 /// assert_eq!(pool.get(key).unwrap(), String::from("hello world"));
 /// ```
 ///
-/// Pool entries can be cleared either by manually calling [`Pool::clear`]. This marks the entry to
+/// Create a new pooled item, returning a guard that allows mutable access:
+/// ```
+/// # use sharded_slab::Pool;
+/// let pool: Pool<String> = Pool::new();
+///
+/// let mut guard = pool.create().unwrap();
+/// let key = guard.key();
+/// guard.push_str("hello world");
+///
+/// drop(guard); // release the guard, allowing immutable access.
+/// assert_eq!(pool.get(key).unwrap(), String::from("hello world"));
+/// ```
+///
+/// Pool entries can be cleared by calling [`Pool::clear`]. This marks the entry to
 /// be cleared when all guards referencing it are dropped.
 /// ```
 /// # use sharded_slab::Pool;
 /// let pool: Pool<String> = Pool::new();
 ///
-/// let key = pool.create(|item| item.push_str("hello world")).unwrap();
+/// let key = pool.create_with(|item| item.push_str("hello world")).unwrap();
 ///
 /// // Mark this entry to be cleared.
 /// pool.clear(key);
@@ -94,7 +115,7 @@ where
 /// While the guard exists, it indicates to the pool that the item the guard references is
 /// currently being accessed. If the item is removed from the pool while the guard exists, the
 /// removal will be deferred until all guards are dropped.
-pub struct PoolGuard<'a, T, C>
+pub struct Ref<'a, T, C>
 where
     T: Clear + Default,
     C: cfg::Config,
 {
     inner: page::slot::Guard<'a, T, C>,
     shard: &'a Shard<T, C>,
     key: usize,
 }
@@ -104,6 +125,23 @@ where
+/// A guard that allows exclusive mutable access to an object in a pool.
+///
+/// While the guard exists, it indicates to the pool that the item the guard
+/// references is currently being accessed. If the item is removed from the pool
+/// while a guard exists, the removal will be deferred until the guard is
+/// dropped. The slot cannot be accessed by other threads while it is accessed
+/// mutably.
+pub struct RefMut<'a, T, C = DefaultConfig>
+where
+    T: Clear + Default,
+    C: cfg::Config,
+{
+    inner: page::slot::InitGuard<T, C>,
+    shard: &'a Shard<T, C>,
+    key: usize,
+}
+
 impl<T, C> Pool<T, C>
 where
     T: Clear + Default,
     C: cfg::Config,
@@ -124,7 +162,8 @@ where
     /// [`Slab::insert`]: struct.Slab.html#method.insert
     pub const USED_BITS: usize = C::USED_BITS;
 
-    /// Creates a new object in the pool, returning a key that can be used to access it.
+    /// Creates a new object in the pool, returning a [`RefMut`] guard that
+    /// may be used to mutate the new object.
     ///
     /// If this function returns `None`, then the shard for the current thread is full and no items
     /// can be added until some are removed, or the maximum number of shards has been reached.
     ///
     /// # Examples
     /// ```rust
     /// # use sharded_slab::Pool;
+    /// # use std::thread;
     /// let pool: Pool<String> = Pool::new();
-    /// let key = pool.create(|item| item.push_str("Hello")).unwrap();
-    /// assert_eq!(pool.get(key).unwrap(), String::from("Hello"));
+    ///
+    /// // Create a new pooled item, returning a guard that allows mutable
+    /// // access to the new item.
+    /// let mut item = pool.create().unwrap();
+    /// // Return a key that allows indexing the created item once the guard
+    /// // has been dropped.
+    /// let key = item.key();
+    ///
+    /// // Mutate the item.
+    /// item.push_str("Hello");
+    /// // Drop the guard, releasing mutable access to the new item.
+    /// drop(item);
+    ///
+    /// // Other threads may now (immutably) access the item using the returned key.
+    /// thread::spawn(move || {
+    ///     assert_eq!(pool.get(key).unwrap(), String::from("Hello"));
+    /// }).join().unwrap();
     /// ```
-    pub fn create(&self, initializer: impl FnOnce(&mut T)) -> Option<usize> {
+    ///
+    /// [`RefMut`]: pool/struct.RefMut.html
+    pub fn create(&self) -> Option<RefMut<'_, T, C>> {
         let (tid, shard) = self.shards.current();
-        let mut init = Some(initializer);
         test_println!("pool: create {:?}", tid);
-        shard
-            .init_with(|slot| {
-                let init = init.take().expect("initializer will only be called once");
-                slot.initialize_state(init)
-            })
-            .map(|idx| tid.pack(idx))
+        let (key, inner) = shard.init_with(|idx, slot| {
+            let guard = slot.init()?;
+            let gen = guard.generation();
+            Some((gen.pack(idx), guard))
+        })?;
+        Some(RefMut {
+            inner,
+            key: tid.pack(key),
+            shard,
+        })
     }
 
+    /// Creates a new object in the pool with the provided initializer,
+    /// returning a key that may be used to access the new object.
+    ///
+    /// If this function returns `None`, then the shard for the current thread is full and no items
+    /// can be added until some are removed, or the maximum number of shards has been reached.
+    ///
+    /// # Examples
+    /// ```rust
+    /// # use sharded_slab::Pool;
+    /// # use std::thread;
+    /// let pool: Pool<String> = Pool::new();
+    ///
+    /// // Create a new pooled item, returning its integer key.
+    /// let key = pool.create_with(|s| s.push_str("Hello")).unwrap();
+    ///
+    /// // Other threads may now (immutably) access the item using the key.
+    /// thread::spawn(move || {
+    ///     assert_eq!(pool.get(key).unwrap(), String::from("Hello"));
+    /// }).join().unwrap();
+    /// ```
+    pub fn create_with(&self, init: impl FnOnce(&mut T)) -> Option<usize> {
+        test_println!("pool: create_with");
+        let mut guard = self.create()?;
+        init(&mut guard);
+        Some(guard.key())
+    }
 
     /// Return a reference to the value associated with the given key.
     ///
     /// If the pool does not contain a value for the given key, `None` is returned instead.
     ///
     /// # Examples
     ///
     /// ```rust
     /// # use sharded_slab::Pool;
-    /// let pool: Pool<String> = sharded_slab::Pool::new();
-    /// let key = pool.create(|item| item.push_str("hello world")).unwrap();
+    /// let pool: Pool<String> = Pool::new();
+    /// let key = pool.create_with(|item| item.push_str("hello world")).unwrap();
     ///
     /// assert_eq!(pool.get(key).unwrap(), String::from("hello world"));
     /// assert!(pool.get(12345).is_none());
     /// ```
-    pub fn get(&self, key: usize) -> Option<PoolGuard<'_, T, C>> {
+    pub fn get(&self, key: usize) -> Option<Ref<'_, T, C>> {
         let tid = C::unpack_tid(key);
 
         test_println!("pool: get {:?}; current={:?}", tid, Tid::<C>::current());
         let shard = self.shards.get(tid.as_usize())?;
-        let inner = shard.get(key, |x| x)?;
-
-        Some(PoolGuard { inner, shard, key })
+        let inner = shard.with_slot(key, |slot| slot.get(C::unpack_gen(key)))?;
+        Some(Ref { inner, shard, key })
     }
 
     /// Remove the value using the storage associated with the given key from the pool, returning
@@ -184,7 +270,12 @@ where
     /// ```rust
     /// # use sharded_slab::Pool;
     /// let pool: Pool<String> = Pool::new();
-    /// let key = pool.create(|item| item.push_str("hello world")).unwrap();
+    ///
+    /// // Check out an item from the pool.
+    /// let mut item = pool.create().unwrap();
+    /// let key = item.key();
+    /// item.push_str("hello world");
+    /// drop(item);
     ///
     /// assert_eq!(pool.get(key).unwrap(), String::from("hello world"));
     ///
@@ -196,7 +287,7 @@ where
     /// # use sharded_slab::Pool;
     /// let pool: Pool<String> = Pool::new();
     ///
-    /// let key = pool.create(|item| item.push_str("Hello world!")).unwrap();
+    /// let key = pool.create_with(|item| item.push_str("Hello world!")).unwrap();
     ///
     /// // Clearing a key that doesn't exist in the `Pool` will return `false`
     /// assert_eq!(pool.clear(key + 69420), false);
@@ -259,7 +350,7 @@ where
     }
 }
 
-impl<'a, T, C> PoolGuard<'a, T, C>
+impl<'a, T, C> Ref<'a, T, C>
 where
     T: Clear + Default,
     C: cfg::Config,
@@ -268,9 +359,14 @@ where
     pub fn key(&self) -> usize {
         self.key
     }
+
+    #[inline]
+    fn value(&self) -> &T {
+        self.inner.value()
+    }
 }
 
-impl<'a, T, C> std::ops::Deref for PoolGuard<'a, T, C>
+impl<'a, T, C> std::ops::Deref for Ref<'a, T, C>
 where
     T: Clear + Default,
     C: cfg::Config,
@@ -278,18 +374,17 @@ where
     type Target = T;
 
     fn deref(&self) -> &Self::Target {
-        self.inner.item()
+        self.value()
     }
 }
 
-impl<'a, T, C> Drop for PoolGuard<'a, T, C>
+impl<'a, T, C> Drop for Ref<'a, T, C>
 where
     T: Clear + Default,
     C: cfg::Config,
 {
     fn drop(&mut self) {
-        use crate::sync::atomic;
-        test_println!(" -> drop PoolGuard: clearing data");
+        test_println!("-> drop Ref: try clearing data");
         if self.inner.release() {
             atomic::fence(atomic::Ordering::Acquire);
             if Tid::<C>::current().as_usize() == self.shard.tid {
@@ -301,22 +396,164 @@ where
     }
 }
 
-impl<'a, T, C> fmt::Debug for PoolGuard<'a, T, C>
+impl<'a, T, C> fmt::Debug for Ref<'a, T, C>
+where
+    T: fmt::Debug + Clear + Default,
+    C: cfg::Config,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        fmt::Debug::fmt(self.value(), f)
+    }
+}
+
+impl<'a, T, C> PartialEq<T> for Ref<'a, T, C>
+where
+    T: PartialEq<T> + Clear + Default,
+    C: cfg::Config,
+{
+    fn eq(&self, other: &T) -> bool {
+        *self.value() == *other
+    }
+}
+
+// === impl RefMut ===
+
+impl<'a, T, C: cfg::Config> RefMut<'a, T, C>
+where
+    T: Clear + Default,
+    C: cfg::Config,
+{
+    /// Returns the key used to access the guard.
+    pub fn key(&self) -> usize {
+        self.key
+    }
+
+    /// Downgrades the mutable guard to an immutable guard, allowing access to
+    /// the pooled value from other threads.
+    ///
+    /// ## Examples
+    ///
+    /// ```
+    /// # use sharded_slab::Pool;
+    /// # use std::{sync::Arc, thread};
+    /// let pool = Arc::new(Pool::<String>::new());
+    ///
+    /// let mut guard_mut = pool.create().unwrap();
+    /// let key = guard_mut.key();
+    /// guard_mut.push_str("Hello");
+    ///
+    /// // The pooled string is currently borrowed mutably, so other threads
+    /// // may not access it.
+    /// let pool2 = pool.clone();
+    /// thread::spawn(move || {
+    ///     assert!(pool2.get(key).is_none())
+    /// }).join().unwrap();
+    ///
+    /// // Downgrade the guard to an immutable reference.
+    /// let guard = guard_mut.downgrade();
+    ///
+    /// // Now, other threads may also access the pooled value.
+    /// let pool2 = pool.clone();
+    /// thread::spawn(move || {
+    ///     let guard = pool2.get(key)
+    ///         .expect("the item may now be referenced by other threads");
+    ///     assert_eq!(guard, String::from("Hello"));
+    /// }).join().unwrap();
+    ///
+    /// // We can still access the value immutably through the downgraded guard.
+    /// assert_eq!(guard, String::from("Hello"));
+    /// ```
+    pub fn downgrade(mut self) -> Ref<'a, T, C> {
+        unsafe {
+            self.inner.release();
+        }
+        let inner = self
+            .shard
+            .with_slot(self.key, |slot| slot.get(C::unpack_gen(self.key)))
+            .expect("generation advanced before a value was released?");
+        Ref {
+            inner,
+            shard: self.shard,
+            key: self.key,
+        }
+    }
+
+    #[inline]
+    fn value(&self) -> &T {
+        unsafe {
+            // Safety: we are holding a reference to the shard which keeps the
+            // pointed slot alive. The returned reference will not outlive
+            // `self`.
+            self.inner.value()
+        }
+    }
+}
+
+impl<'a, T, C: cfg::Config> std::ops::Deref for RefMut<'a, T, C>
+where
+    T: Clear + Default,
+    C: cfg::Config,
+{
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        self.value()
+    }
+}
+
+impl<'a, T, C> std::ops::DerefMut for RefMut<'a, T, C>
+where
+    T: Clear + Default,
+    C: cfg::Config,
+{
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        unsafe {
+            // Safety: we are holding a reference to the shard which keeps the
+            // pointed slot alive. The returned reference will not outlive `self`.
+            self.inner.value_mut()
+        }
+    }
+}
+
+impl<'a, T, C> Drop for RefMut<'a, T, C>
+where
+    T: Clear + Default,
+    C: cfg::Config,
+{
+    fn drop(&mut self) {
+        test_println!(" -> drop RefMut: try clearing data");
+        let should_clear = unsafe {
+            // Safety: we are holding a reference to the shard which keeps the
+            // pointed slot alive, so the guard's raw slot pointer cannot
+            // dangle here.
+            self.inner.release()
+        };
+        if should_clear {
+            atomic::fence(atomic::Ordering::Acquire);
+            if Tid::<C>::current().as_usize() == self.shard.tid {
+                self.shard.mark_clear_local(self.key);
+            } else {
+                self.shard.mark_clear_remote(self.key);
+            }
+        }
+    }
+}
+
+impl<'a, T, C> fmt::Debug for RefMut<'a, T, C>
 where
     T: fmt::Debug + Clear + Default,
     C: cfg::Config,
 {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        fmt::Debug::fmt(self.inner.item(), f)
+        fmt::Debug::fmt(self.value(), f)
     }
 }
 
-impl<'a, T, C> PartialEq<T> for PoolGuard<'a, T, C>
+impl<'a, T, C> PartialEq<T> for RefMut<'a, T, C>
 where
     T: PartialEq<T> + Clear + Default,
     C: cfg::Config,
 {
     fn eq(&self, other: &T) -> bool {
-        *self.inner.item() == *other
+        self.value().eq(other)
     }
 }
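A quick end-to-end sketch of the pool API this file now exposes (`create`, `DerefMut` through `RefMut`, and `downgrade`); illustrative usage, assuming the methods added in this diff:

```rust
use sharded_slab::Pool;

fn main() {
    let pool: Pool<String> = Pool::new();

    // Exclusive access while the RefMut is live.
    let mut guard = pool.create().expect("pool should not be full");
    let key = guard.key();
    guard.push_str("hello");

    // Downgrade to a shared guard; `get` now succeeds as well.
    let read_guard = guard.downgrade();
    assert_eq!(pool.get(key).expect("slot is now readable"), *read_guard);
}
```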
diff --git a/src/shard.rs b/src/shard.rs
index c5f8bcd..cc95b38 100644
--- a/src/shard.rs
+++ b/src/shard.rs
@@ -68,11 +68,11 @@ where
     C: cfg::Config,
 {
     #[inline(always)]
-    pub(crate) fn get<U>(
-        &self,
+    pub(crate) fn with_slot<'a, U>(
+        &'a self,
         idx: usize,
-        f: impl FnOnce(&T) -> &U,
-    ) -> Option<page::slot::Guard<'_, U, C>> {
+        f: impl FnOnce(&'a page::Slot<T, C>) -> Option<U>,
+    ) -> Option<U> {
         debug_assert_eq!(Tid::<C>::from_packed(idx).as_usize(), self.tid);
         let (addr, page_index) = page::indices::<C>(idx);
 
@@ -81,7 +81,7 @@ where
             return None;
         }
 
-        self.shared[page_index].get(addr, idx, f)
+        self.shared[page_index].with_slot(addr, f)
     }
 
     pub(crate) fn new(tid: usize) -> Self {
@@ -161,18 +161,18 @@ where
     T: Clear + Default,
     C: cfg::Config,
 {
-    pub(crate) fn init_with<F>(&self, mut func: F) -> Option<usize>
-    where
-        F: FnMut(&page::slot::Slot<T, C>) -> Option<page::slot::Generation<C>>,
-    {
-        // Can we fit the value into an existing page?
+    pub(crate) fn init_with<U>(
+        &self,
+        mut init: impl FnMut(usize, &page::Slot<T, C>) -> Option<U>,
+    ) -> Option<U> {
+        // Can we fit the value into an existing page?
         for (page_idx, page) in self.shared.iter().enumerate() {
             let local = self.local(page_idx);
 
             test_println!("-> page {}; {:?}; {:?}", page_idx, local, page);
 
-            if let Some(poff) = page.init_with(local, &mut func) {
-                return Some(poff);
+            if let Some(res) = page.init_with(local, &mut init) {
+                return Some(res);
             }
         }
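Why `init_with` is now generic over `U`: the same shard primitive must produce just a packed key for `Slab::insert`, but a key plus the live `InitGuard` for `Pool::create`. A toy sketch of that design (stand-in types, not the crate's):

```rust
struct Slot;
struct Guard<'a>(&'a Slot);

// Returning Option<U> lets each caller decide what "success" carries.
fn init_with<'a, U>(
    slot: &'a Slot,
    idx: usize,
    init: impl FnOnce(usize, &'a Slot) -> Option<U>,
) -> Option<U> {
    init(idx, slot)
}

fn main() {
    let slot = Slot;
    // Slab-style: yield only the key.
    assert_eq!(init_with(&slot, 3, |idx, _| Some(idx)), Some(3));
    // Pool-style: yield the key *and* a guard borrowing the slot.
    let pair = init_with(&slot, 3, |idx, slot| Some((idx, Guard(slot))));
    assert!(pair.is_some());
}
```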
diff --git a/src/tests/loom_pool.rs b/src/tests/loom_pool.rs
index 47e9ba2..36b187b 100644
--- a/src/tests/loom_pool.rs
+++ b/src/tests/loom_pool.rs
@@ -73,7 +73,9 @@ fn dont_drop() {
     let pool: Pool<DontDropMe> = Pool::new();
     let (item1, value) = DontDropMe::new(1);
     test_println!("-> dont_drop: Inserting into pool {}", item1.id);
-    let idx = pool.create(move |item| *item = value).expect("Create");
+    let idx = pool
+        .create_with(move |item| *item = value)
+        .expect("create_with");
 
     item1.assert_not_clear();
 
@@ -85,13 +87,15 @@ fn dont_drop() {
 }
 
 #[test]
-fn concurrent_create_clear() {
-    run_model("concurrent_create_clear", || {
+fn concurrent_create_with_clear() {
+    run_model("concurrent_create_with_clear", || {
         let pool: Arc<Pool<DontDropMe>> = Arc::new(Pool::new());
         let pair = Arc::new((Mutex::new(None), Condvar::new()));
 
         let (item1, value) = DontDropMe::new(1);
-        let idx1 = pool.create(move |item| *item = value).expect("Create");
+        let idx1 = pool
+            .create_with(move |item| *item = value)
+            .expect("create_with");
         let p = pool.clone();
         let pair2 = pair.clone();
         let test_value = item1.clone();
@@ -130,7 +134,9 @@ fn racy_clear() {
         let pool = Arc::new(Pool::new());
         let (item, value) = DontDropMe::new(1);
 
-        let idx = pool.create(move |item| *item = value).expect("Create");
+        let idx = pool
+            .create_with(move |item| *item = value)
+            .expect("create_with");
         assert_eq!(pool.get(idx).unwrap().0.id, item.id);
 
         let p = pool.clone();
@@ -156,12 +162,16 @@ fn clear_local_and_reuse() {
     let pool = Arc::new(Pool::new_with_config::<TinyConfig>());
     let idx1 = pool
-        .create(|item: &mut String| {
+        .create_with(|item: &mut String| {
             item.push_str("hello world");
         })
-        .expect("create");
-    let idx2 = pool.create(|item| item.push_str("foo")).expect("create");
-    let idx3 = pool.create(|item| item.push_str("bar")).expect("create");
+        .expect("create_with");
+    let idx2 = pool
+        .create_with(|item| item.push_str("foo"))
+        .expect("create_with");
+    let idx3 = pool
+        .create_with(|item| item.push_str("bar"))
+        .expect("create_with");
 
     assert_eq!(pool.get(idx1).unwrap(), String::from("hello world"));
     assert_eq!(pool.get(idx2).unwrap(), String::from("foo"));
@@ -170,10 +180,148 @@ fn clear_local_and_reuse() {
 
     let first = idx1 & (!crate::page::slot::Generation::<TinyConfig>::MASK);
     assert!(pool.clear(idx1));
 
-    let idx1 = pool.create(move |item| item.push_str("h")).expect("create");
+    let idx1 = pool
+        .create_with(move |item| item.push_str("h"))
+        .expect("create_with");
 
     let second = idx1 & (!crate::page::slot::Generation::<TinyConfig>::MASK);
     assert_eq!(first, second);
     assert!(pool.get(idx1).unwrap().capacity() >= 11);
     })
 }
+
+#[test]
+fn create_mut_guard_prevents_access() {
+    run_model("create_mut_guard_prevents_access", || {
+        let pool = Arc::new(Pool::<String>::new());
+        let guard = pool.create().unwrap();
+        let key: usize = guard.key();
+
+        let pool2 = pool.clone();
+        thread::spawn(move || {
+            assert!(pool2.get(key).is_none());
+        })
+        .join()
+        .unwrap();
+    });
+}
+
+#[test]
+fn create_mut_guard() {
+    run_model("create_mut_guard", || {
+        let pool = Arc::new(Pool::<String>::new());
+        let mut guard = pool.create().unwrap();
+        let key: usize = guard.key();
+
+        let pool2 = pool.clone();
+        let t1 = thread::spawn(move || {
+            test_dbg!(pool2.get(key));
+        });
+
+        guard.push_str("Hello world");
+        drop(guard);
+
+        t1.join().unwrap();
+    });
+}
+
+#[test]
+fn create_mut_guard_2() {
+    run_model("create_mut_guard_2", || {
+        let pool = Arc::new(Pool::<String>::new());
+        let mut guard = pool.create().unwrap();
+        let key: usize = guard.key();
+
+        let pool2 = pool.clone();
+        let pool3 = pool.clone();
+        let t1 = thread::spawn(move || {
+            test_dbg!(pool2.get(key));
+        });
+
+        guard.push_str("Hello world");
+        let t2 = thread::spawn(move || {
+            test_dbg!(pool3.get(key));
+        });
+        drop(guard);
+
+        t1.join().unwrap();
+        t2.join().unwrap();
+    });
+}
+
+#[test]
+fn create_mut_guard_downgrade() {
+    run_model("create_mut_guard_downgrade", || {
+        let pool = Arc::new(Pool::<String>::new());
+        let mut guard = pool.create().unwrap();
+        let key: usize = guard.key();
+
+        let pool2 = pool.clone();
+        let pool3 = pool.clone();
+        let t1 = thread::spawn(move || {
+            test_dbg!(pool2.get(key));
+        });
+
+        guard.push_str("Hello world");
+        let guard = guard.downgrade();
+        let t2 = thread::spawn(move || {
+            test_dbg!(pool3.get(key));
+        });
+
+        t1.join().unwrap();
+        t2.join().unwrap();
+        assert_eq!(guard, "Hello world".to_owned());
+    });
+}
+
+#[test]
+fn create_mut_guard_downgrade_clear() {
+    run_model("create_mut_guard_downgrade_clear", || {
+        let pool = Arc::new(Pool::<String>::new());
+        let mut guard = pool.create().unwrap();
+        let key: usize = guard.key();
+
+        let pool2 = pool.clone();
+
+        guard.push_str("Hello world");
+        let guard = guard.downgrade();
+        let pool3 = pool.clone();
+        let t1 = thread::spawn(move || {
+            test_dbg!(pool2.get(key));
+        });
+        let t2 = thread::spawn(move || {
+            test_dbg!(pool3.clear(key));
+        });
+
+        assert_eq!(guard, "Hello world".to_owned());
+        drop(guard);
+
+        t1.join().unwrap();
+        t2.join().unwrap();
+
+        assert!(pool.get(key).is_none());
+    });
+}
+
+#[test]
+fn create_mut_downgrade_during_clear() {
+    run_model("create_mut_downgrade_during_clear", || {
+        let pool = Arc::new(Pool::<String>::new());
+        let mut guard = pool.create().unwrap();
+        let key: usize = guard.key();
+        guard.push_str("Hello world");
+
+        let pool2 = pool.clone();
+        let guard = guard.downgrade();
+        let t1 = thread::spawn(move || {
+            test_dbg!(pool2.clear(key));
+        });
+
+        t1.join().unwrap();
+
+        assert_eq!(guard, "Hello world".to_owned());
+        drop(guard);
+
+        assert!(pool.get(key).is_none());
+    });
+}
diff --git a/src/tests/loom_slab.rs b/src/tests/loom_slab.rs
index 3122240..7d7c0fc 100644
--- a/src/tests/loom_slab.rs
+++ b/src/tests/loom_slab.rs
@@ -265,7 +265,7 @@ fn concurrent_remove_remote_and_reuse() {
 }
 
 struct SetDropped {
-    value: usize,
+    val: usize,
     dropped: std::sync::Arc<AtomicBool>,
 }
 
@@ -274,10 +274,10 @@ struct AssertDropped {
 }
 
 impl AssertDropped {
-    fn new(value: usize) -> (Self, SetDropped) {
+    fn new(val: usize) -> (Self, SetDropped) {
         let dropped = std::sync::Arc::new(AtomicBool::new(false));
         let val = SetDropped {
-            value,
+            val,
             dropped: dropped.clone(),
         };
         (Self { dropped }, val)
@@ -358,7 +358,7 @@ fn remove_remote_during_insert() {
 
     let t1 = thread::spawn(move || {
         let g = slab2.get(idx);
-        assert_ne!(g.as_ref().map(|v| v.value), Some(2));
+        assert_ne!(g.as_ref().map(|v| v.val), Some(2));
         drop(g);
     });
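For readers unfamiliar with the harness used above: `run_model` in these test files wraps `loom::model`, which exhaustively explores thread interleavings of the closure. A minimal standalone test of the same shape (assuming the `loom` dev-dependency and a `--cfg loom` build):

```rust
#[cfg(loom)]
#[test]
fn minimal_model() {
    loom::model(|| {
        use loom::sync::atomic::{AtomicUsize, Ordering};
        use loom::{sync::Arc, thread};

        let n = Arc::new(AtomicUsize::new(0));
        let n2 = n.clone();
        let t = thread::spawn(move || {
            n2.fetch_add(1, Ordering::AcqRel);
        });
        n.fetch_add(1, Ordering::AcqRel);
        t.join().unwrap();

        // Every explored interleaving must observe both increments.
        assert_eq!(n.load(Ordering::Acquire), 2);
    });
}
```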