Skip to content

Commit

Permalink
Speedup hashmap by speeding up node creation
Browse files Browse the repository at this point in the history
  • Loading branch information
arthurprs committed Mar 30, 2024
1 parent 51847ea commit 73c98ae
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 23 deletions.
19 changes: 19 additions & 0 deletions src/fakepool.rs
Expand Up @@ -4,7 +4,9 @@

#![allow(dead_code)]

use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::rc::Rc as RRc;
use std::sync::Arc as RArc;
Expand Down Expand Up @@ -56,6 +58,11 @@ impl<A> Rc<A> {
Rc(RRc::new(value))
}

#[inline(always)]
pub(crate) fn new_uninit(_pool: &Pool<A>) -> Rc<UnsafeCell<MaybeUninit<A>>> {
Rc(RRc::new(UnsafeCell::new(MaybeUninit::uninit())))
}

#[inline(always)]
pub(crate) fn clone_from(_pool: &Pool<A>, value: &A) -> Self
where
Expand Down Expand Up @@ -141,6 +148,11 @@ impl<A> Arc<A> {
Self(RArc::new(value))
}

#[inline(always)]
pub(crate) fn new_uninit(_pool: &Pool<A>) -> Arc<UnsafeCell<MaybeUninit<A>>> {
Arc(RArc::new(UnsafeCell::new(MaybeUninit::uninit())))
}

#[inline(always)]
pub(crate) fn clone_from(_pool: &Pool<A>, value: &A) -> Self
where
Expand Down Expand Up @@ -210,6 +222,8 @@ where
// Triomphe Arc
#[cfg(feature = "triomphe")]
pub(crate) mod triomphe {
use std::cell::UnsafeCell;

use super::*;

#[derive(Default)]
Expand All @@ -229,6 +243,11 @@ pub(crate) mod triomphe {
Self(::triomphe::Arc::new(value))
}

#[inline(always)]
pub(crate) fn new_uninit(_pool: &Pool<A>) -> Arc<UnsafeCell<MaybeUninit<A>>> {
Arc(::triomphe::Arc::new(UnsafeCell::new(MaybeUninit::uninit())))
}

#[inline(always)]
pub(crate) fn clone_from(_pool: &Pool<A>, value: &A) -> Self
where
Expand Down
3 changes: 1 addition & 2 deletions src/hash/set.rs
Expand Up @@ -26,8 +26,7 @@ use std::collections::hash_map::RandomState;
use std::collections::{self, BTreeSet};
use std::fmt::{Debug, Error, Formatter};
use std::hash::{BuildHasher, Hash};
use std::iter::FusedIterator;
use std::iter::{FromIterator, IntoIterator, Sum};
use std::iter::{FromIterator, FusedIterator, Sum};
use std::ops::{Add, Deref, Mul};

use crate::nodes::hamt::{hash_key, Drain as NodeDrain, HashValue, Iter as NodeIter, Node};
Expand Down
64 changes: 45 additions & 19 deletions src/nodes/hamt.rs
Expand Up @@ -3,6 +3,7 @@
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

use std::borrow::Borrow;
use std::cell::UnsafeCell;
use std::fmt;
use std::hash::{BuildHasher, Hash, Hasher};
use std::iter::FusedIterator;
Expand Down Expand Up @@ -104,10 +105,6 @@ impl<A> Entry<A> {
_ => panic!("nodes::hamt::Entry::unwrap_value: unwrapped a non-value"),
}
}

fn from_node(pool: &Pool<Node<A>>, node: Node<A>) -> Self {
Entry::Node(PoolRef::new(pool, node))
}
}

impl<A> From<CollisionNode<A>> for Entry<A> {
Expand All @@ -123,37 +120,64 @@ impl<A> Default for Node<A> {
}

impl<A> Node<A> {
#[inline]
#[inline(always)]
pub(crate) fn new() -> Self {
Node {
data: SparseChunk::new(),
}
}

/// Special constructor to allow initializing Nodes w/o incurring multiple memory copies.
/// These copies really slow things down once Node crosses a certain size threshold and copies become calls to memcopy.
#[inline]
fn with(pool: &Pool<Self>, with: impl FnOnce(&mut Self)) -> PoolRef<Self> {
let result: PoolRef<UnsafeCell<mem::MaybeUninit<Node<A>>>> = PoolRef::new_uninit(pool);
#[allow(unsafe_code)]
unsafe {
// Initialize the MaybeUninit node
(&mut *result.get()).write(Node::new());
// Safety: UnsafeCell<Self> and UnsafeCell<MaybeUninit<Self>> have the same memory representation
let result: PoolRef<UnsafeCell<Self>> = mem::transmute(result);
let mut_ptr = UnsafeCell::raw_get(&*result);
with(&mut *mut_ptr);
// Safety UnsafeCell<Self> and Self have the same memory representation
mem::transmute(result)
}
}

#[inline]
fn len(&self) -> usize {
self.data.len()
}

#[inline]
pub(crate) fn unit(index: usize, value: Entry<A>) -> Self {
Node {
data: SparseChunk::unit(index, value),
}
fn unit(pool: &Pool<Node<A>>, index: usize, value: Entry<A>) -> PoolRef<Self> {
Self::with(pool, |this| {
this.data.insert(index, value);
})
}

#[inline]
pub(crate) fn pair(index1: usize, value1: Entry<A>, index2: usize, value2: Entry<A>) -> Self {
Node {
data: SparseChunk::pair(index1, value1, index2, value2),
}
fn pair(
pool: &Pool<Node<A>>,
index1: usize,
value1: Entry<A>,
index2: usize,
value2: Entry<A>,
) -> PoolRef<Self> {
Self::with(pool, |this| {
this.data.insert(index1, value1);
this.data.insert(index2, value2);
})
}

#[inline]
pub(crate) fn single_child(pool: &Pool<Node<A>>, index: usize, node: Self) -> Self {
Node {
data: SparseChunk::unit(index, Entry::from_node(pool, node)),
}
pub(crate) fn single_child(
pool: &Pool<Node<A>>,
index: usize,
node: PoolRef<Self>,
) -> PoolRef<Self> {
Self::unit(pool, index, Entry::Node(node))
}

fn pop(&mut self) -> Entry<A> {
Expand All @@ -169,12 +193,13 @@ impl<A: HashValue> Node<A> {
value2: A,
hash2: HashBits,
shift: usize,
) -> Self {
) -> PoolRef<Self> {
let index1 = mask(hash1, shift) as usize;
let index2 = mask(hash2, shift) as usize;
if index1 != index2 {
// Both values fit on the same level.
Node::pair(
pool,
index1,
Entry::Value(value1, hash1),
index2,
Expand All @@ -183,6 +208,7 @@ impl<A: HashValue> Node<A> {
} else if shift + HASH_SHIFT >= HASH_WIDTH {
// If we're at the bottom, we've got a collision.
Node::unit(
pool,
index1,
Entry::from(CollisionNode::new(hash1, value1, value2)),
)
Expand Down Expand Up @@ -307,7 +333,7 @@ impl<A: HashValue> Node<A> {
hash,
shift + HASH_SHIFT,
);
unsafe { ptr::write(entry, Entry::from_node(pool, node)) };
unsafe { ptr::write(entry, Entry::Node(node)) };
} else {
unreachable!()
}
Expand Down
2 changes: 1 addition & 1 deletion src/ord/map.rs
Expand Up @@ -22,7 +22,7 @@ use std::cmp::Ordering;
use std::collections;
use std::fmt::{Debug, Error, Formatter};
use std::hash::{BuildHasher, Hash, Hasher};
use std::iter::{FromIterator, Iterator, Sum};
use std::iter::{FromIterator, Sum};
use std::mem;
use std::ops::{Add, Index, IndexMut, RangeBounds};

Expand Down
2 changes: 1 addition & 1 deletion src/ord/set.rs
Expand Up @@ -20,7 +20,7 @@ use std::cmp::Ordering;
use std::collections;
use std::fmt::{Debug, Error, Formatter};
use std::hash::{BuildHasher, Hash, Hasher};
use std::iter::{FromIterator, IntoIterator, Sum};
use std::iter::{FromIterator, Sum};
use std::ops::{Add, Deref, Mul, RangeBounds};

use crate::hashset::HashSet;
Expand Down

0 comments on commit 73c98ae

Please sign in to comment.