diff --git a/imhamt/src/hamt.rs b/imhamt/src/hamt.rs index cb99d798d..632a7b895 100644 --- a/imhamt/src/hamt.rs +++ b/imhamt/src/hamt.rs @@ -4,6 +4,7 @@ use super::node::{ update_rec, Entry, LookupRet, Node, NodeIter, }; pub use super::operation::{InsertError, RemoveError, ReplaceError, UpdateError}; +use std::borrow::Borrow; use std::iter::FromIterator; use std::marker::PhantomData; use std::mem::swap; @@ -177,8 +178,12 @@ impl Hamt { impl Hamt { /// Try to get the element related to key K - pub fn lookup(&self, k: &K) -> Option<&V> { - let h = HashedKey::compute(self.hasher, &k); + pub fn lookup(&self, k: &Q) -> Option<&V> + where + K: Borrow, + Q: Hash + PartialEq, + { + let h = HashedKey::compute(self.hasher, k.borrow()); let mut n = &self.root; let mut lvl = 0; loop { @@ -193,7 +198,10 @@ impl Hamt { } } /// Check if the key is contained into the HAMT - pub fn contains_key(&self, k: &K) -> bool { + pub fn contains_key(&self, k: &Q) -> bool + where + K: Borrow, + { self.lookup(k).map_or_else(|| false, |_| true) } pub fn iter(&self) -> HamtIter { diff --git a/imhamt/src/node/reference.rs b/imhamt/src/node/reference.rs index 6a4ca6201..8ec8e6db6 100644 --- a/imhamt/src/node/reference.rs +++ b/imhamt/src/node/reference.rs @@ -3,6 +3,7 @@ use super::super::hash::{HashedKey, LevelIndex}; use super::super::helper; use super::super::operation::*; use super::super::sharedref::SharedRef; +use std::borrow::Borrow; use std::slice; @@ -305,11 +306,11 @@ pub enum LookupRet<'a, K, V> { ContinueIn(&'a Node), } -pub fn lookup_one<'a, K: PartialEq, V>( +pub fn lookup_one<'a, Q: PartialEq, K: PartialEq + Borrow, V>( node: &'a Node, h: &HashedKey, lvl: usize, - k: &K, + k: &Q, ) -> LookupRet<'a, K, V> { let level_hash = h.level_index(lvl); let idx = node.bitmap.get_index_sparse(level_hash); @@ -318,7 +319,7 @@ pub fn lookup_one<'a, K: PartialEq, V>( } else { match &(node.get_child(idx)).as_ref() { Entry::Leaf(lh, lk, lv) => { - if lh == h && lk == k { + if lh == h && lk.borrow() == k { LookupRet::Found(lv) } else { LookupRet::NotFound @@ -328,7 +329,7 @@ pub fn lookup_one<'a, K: PartialEq, V>( if lh != h { LookupRet::NotFound } else { - match col.0.iter().find(|(lk, _)| lk == k) { + match col.0.iter().find(|(lk, _)| lk.borrow() == k) { None => LookupRet::NotFound, Some(lkv) => LookupRet::Found(&lkv.1), }