Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: unbundle array size constraint from hash map bucket array (second attempt) #754

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 99 additions & 58 deletions Std/Data/HashMap/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ structure Imp (α : Type u) (β : Type v) where
use the size to determine when to resize the map. -/
size : Nat
/-- The bucket array of the `HashMap`. -/
buckets : Imp.Buckets α β
buckets : Array (AssocList α β)

def Buckets.mk' : {m : Imp α β // 0 < m.2.size} → Imp.Buckets α β
| ⟨m, h⟩ => ⟨m.2, h⟩

namespace Imp

Expand All @@ -81,16 +84,16 @@ A "load factor" of 0.75 is the usual standard for hash maps, so we return `capac
capacity * 4 / 3

/-- Constructs an empty hash map with the specified nonzero number of buckets. -/
@[inline] def empty' (buckets := 8) (h : 0 < buckets := by decide) : Imp α β :=
⟨0, .mk buckets h⟩
@[inline] def empty' (buckets := 8) : Imp α β :=
⟨0, mkArray buckets .nil⟩

/-- Constructs an empty hash map with the specified target capacity. -/
def empty (capacity := 0) : Imp α β :=
def empty (capacity := 0) : {m : Imp α β // 0 < m.2.size} :=
let nbuckets := numBucketsForCapacity capacity
let n : {n : Nat // 0 < n} :=
if h : nbuckets = 0 then ⟨8, by decide⟩
else ⟨nbuckets, Nat.zero_lt_of_ne_zero h⟩
empty' n n.2
⟨empty' n, by simpa [empty'] using n.2⟩

/-- Calculates the bucket index from a hash value `u`. -/
def mkIdx {n : Nat} (h : 0 < n) (u : USize) : {u : USize // u.toNat < n} :=
Expand Down Expand Up @@ -118,27 +121,29 @@ already in the array, which is appropriate when reinserting elements into the ar
h.buckets.1.forM fun b => b.forM f

/-- Given a key `a`, returns a key-value pair in the map whose key compares equal to `a`. -/
def findEntry? [BEq α] [Hashable α] (m : Imp α β) (a : α) : Option (α × β) :=
def findEntry? [BEq α] [Hashable α] (m : Imp α β) (h : 0 < m.2.size) (a : α) : Option (α × β) :=
let ⟨_, buckets⟩ := m
let ⟨i, h⟩ := mkIdx buckets.2 (hash a |>.toUSize)
buckets.1[i].findEntry? a
let ⟨i, h⟩ := mkIdx h (hash a |>.toUSize)
buckets[i].findEntry? a

/-- Looks up an element in the map with key `a`. -/
def find? [BEq α] [Hashable α] (m : Imp α β) (a : α) : Option β :=
def find? [BEq α] [Hashable α] (m : Imp α β) (h : 0 < m.2.size) (a : α) : Option β :=
let ⟨_, buckets⟩ := m
let ⟨i, h⟩ := mkIdx buckets.2 (hash a |>.toUSize)
buckets.1[i].find? a
let ⟨i, h⟩ := mkIdx h (hash a |>.toUSize)
buckets[i].find? a

/-- Returns true if the element `a` is in the map. -/
def contains [BEq α] [Hashable α] (m : Imp α β) (a : α) : Bool :=
def contains [BEq α] [Hashable α] (m : Imp α β) (h : 0 < m.2.size) (a : α) : Bool :=
let ⟨_, buckets⟩ := m
let ⟨i, h⟩ := mkIdx buckets.2 (hash a |>.toUSize)
buckets.1[i].contains a
let ⟨i, h⟩ := mkIdx h (hash a |>.toUSize)
buckets[i].contains a

/-- Copies all the entries from `buckets` into a new hash map with a larger capacity. -/
def expand [Hashable α] (size : Nat) (buckets : Buckets α β) : Imp α β :=
def expand [Hashable α] (size : Nat) (buckets : Buckets α β) :
{m : Imp α β // 0 < m.2.size} :=
let nbuckets := buckets.1.size * 2
{ size, buckets := go 0 buckets.1 (.mk nbuckets (Nat.mul_pos buckets.2 (by decide))) }
let ⟨arr, h⟩ := go 0 buckets.1 (.mk nbuckets (Nat.mul_pos buckets.2 (by decide)))
⟨{ size, buckets := arr }, h⟩
where
/-- Inner loop of `expand`. Copies elements `source[i:]` into `target`,
destroying `source` in the process. -/
Expand All @@ -158,53 +163,70 @@ where
Inserts key-value pair `a, b` into the map.
If an element equal to `a` is already in the map, it is replaced by `b`.
-/
@[inline] def insert [BEq α] [Hashable α] (m : Imp α β) (a : α) (b : β) : Imp α β :=
let ⟨size, buckets⟩ := m
let ⟨i, h⟩ := mkIdx buckets.2 (hash a |>.toUSize)
let bkt := buckets.1[i]
@[inline] def insert [BEq α] [Hashable α] (m : Imp α β) (hm : 0 < m.2.size) (a : α) (b : β) :
{m : Imp α β // 0 < m.2.size} :=
let ⟨size, bucketArray⟩ := m
let buckets : Buckets α β := ⟨bucketArray, hm⟩
let ⟨i, h⟩ := mkIdx hm (hash a |>.toUSize)
let bkt := bucketArray[i]
bif bkt.contains a then
⟨size, buckets.update i (bkt.replace a b) h⟩
let ⟨newBucketArray, hnew⟩ := buckets.update i (bkt.replace a b) h
⟨⟨size, newBucketArray⟩, hnew⟩
else
let size' := size + 1
let buckets' := buckets.update i (.cons a b bkt) h
let ⟨buckets', hnew⟩ := buckets.update i (.cons a b bkt) h
if numBucketsForCapacity size' ≤ buckets.1.size then
{ size := size', buckets := buckets' }
⟨{ size := size', buckets := buckets' }, hnew⟩
else
expand size' buckets'
expand size' ⟨buckets', hnew⟩

/--
Removes key `a` from the map. If it does not exist in the map, the map is returned unchanged.
-/
def erase [BEq α] [Hashable α] (m : Imp α β) (a : α) : Imp α β :=
let ⟨size, buckets⟩ := m
let ⟨i, h⟩ := mkIdx buckets.2 (hash a |>.toUSize)
let bkt := buckets.1[i]
bif bkt.contains a then ⟨size - 1, buckets.update i (bkt.erase a) h⟩ else ⟨size, buckets⟩
def erase [BEq α] [Hashable α] (m : Imp α β) (hm : 0 < m.2.size) (a : α) :
{m : Imp α β // 0 < m.2.size} :=
let ⟨size, bucketArray⟩ := m
let buckets : Buckets α β := ⟨bucketArray, hm⟩
let ⟨i, h⟩ := mkIdx hm (hash a |>.toUSize)
let bkt := bucketArray[i]
bif bkt.contains a then
let ⟨newBuckets, hnew⟩ := buckets.update i (bkt.erase a) h
⟨⟨size - 1, newBuckets⟩, hnew⟩
else
⟨⟨size, bucketArray⟩, hm⟩

/-- Map a function over the values in the map. -/
@[inline] def mapVal (f : α → β → γ) (self : Imp α β) : Imp α γ :=
{ size := self.size, buckets := self.buckets.mapVal f }
@[inline] def mapVal (f : α → β → γ) (self : Imp α β) (hm : 0 < self.2.size) :
{m : Imp α γ // 0 < m.2.size} :=
let ⟨size, bucketArray⟩ := self
let buckets : Buckets α β := ⟨bucketArray, hm⟩
let ⟨newBucketArray, hnew⟩ := buckets.mapVal f
⟨{ size := size, buckets := newBucketArray }, hnew⟩

/-- Performs an in-place edit of the value, ensuring that the value is used linearly. -/
def modify [BEq α] [Hashable α] (m : Imp α β) (a : α) (f : α → β → β) : Imp α β :=
let ⟨size, buckets⟩ := m
let ⟨i, h⟩ := mkIdx buckets.2 (hash a |>.toUSize)
let bkt := buckets.1[i]
def modify [BEq α] [Hashable α] (m : Imp α β) (hm : 0 < m.2.size) (a : α) (f : α → β → β) :
{m : Imp α β // 0 < m.2.size} :=
let ⟨size, bucketArray⟩ := m
let buckets : Buckets α β := ⟨bucketArray, hm⟩
let ⟨i, h⟩ := mkIdx hm (hash a |>.toUSize)
let bkt := bucketArray[i]
let buckets := buckets.update i .nil h -- for linearity
⟨size, buckets.update i (bkt.modify a f) ((Buckets.update_size ..).symm ▸ h)⟩
let ⟨newBucketArray, hnew⟩ := buckets.update i (bkt.modify a f) ((Buckets.update_size ..).symm ▸ h)
⟨⟨size, newBucketArray⟩, hnew⟩

/--
Applies `f` to each key-value pair `a, b` in the map. If it returns `some c` then
`a, c` is pushed into the new map; else the key is removed from the map.
-/
@[specialize] def filterMap {α : Type u} {β : Type v} {γ : Type w}
(f : α → β → Option γ) (m : Imp α β) : Imp α γ :=
let m' := m.buckets.1.mapM (m := StateT (ULift Nat) Id) (go .nil) |>.run ⟨0⟩ |>.run
(m : Imp α β) (hm : 0 < m.2.size) (f : α → β → Option γ) :
{m : Imp α γ // 0 < m.2.size} :=
let m' := m.buckets.mapM (m := StateT (ULift Nat) Id) (go .nil) |>.run ⟨0⟩ |>.run
have : m'.1.size > 0 := by
have := Array.size_mapM (m := StateT (ULift Nat) Id) (go .nil) m.buckets.1
have := Array.size_mapM (m := StateT (ULift Nat) Id) (go .nil) m.buckets
simp [SatisfiesM_StateT_eq, SatisfiesM_Id_eq] at this
simp [this, Id.run, StateT.run, m.2.2, m']
⟨m'.2.1, m'.1, this⟩
simp [this, Id.run, StateT.run, hm, m']
⟨⟨m'.2.1, m'.1⟩, this⟩
where
/-- Inner loop of `filterMap`. Note that this reverses the bucket lists,
but this is fine since bucket lists are unordered. -/
Expand All @@ -215,8 +237,9 @@ where
| some c => go (.cons a c acc) l ⟨n.1 + 1⟩

/-- Constructs a map with the set of all pairs `a, b` such that `f` returns true. -/
@[inline] def filter (f : α → β → Bool) (m : Imp α β) : Imp α β :=
m.filterMap fun a b => bif f a b then some b else none
@[inline] def filter (f : α → β → Bool) (m : Imp α β) (hm : 0 < m.2.size) :
{m : Imp α β // 0 < m.2.size} :=
m.filterMap hm fun a b => bif f a b then some b else none

/--
The well-formedness invariant for a hash map. The first constructor is the real invariant,
Expand All @@ -228,17 +251,32 @@ inductive WF [BEq α] [Hashable α] : Imp α β → Prop where
* The `size` field should match the actual number of elements in the map
* The bucket array should be well-formed, meaning that if the hashable instance
is lawful then every element hashes to its index. -/
| mk : m.size = m.buckets.size → m.buckets.WF → WF m
| mk (m : Imp α β) :
(h : 0 < m.2.size) → m.size = Buckets.size ⟨m.2, h⟩ → Buckets.WF ⟨m.2, h⟩ → WF m
/-- The empty hash map is well formed. -/
| empty' : WF (empty' n h)
| empty' : 0 < n → WF (empty' n)
/-- Inserting into a well formed hash map yields a well formed hash map. -/
| insert : WF m → WF (insert m a b)
| insert : WF m → WF (insert m hm a b).1
/-- Removing an element from a well formed hash map yields a well formed hash map. -/
| erase : WF m → WF (erase m a)
| erase : WF m → WF (erase m h a).1
/-- Replacing an element in a well formed hash map yields a well formed hash map. -/
| modify : WF m → WF (modify m a f)

theorem WF.empty [BEq α] [Hashable α] : WF (empty n : Imp α β) := by unfold empty; apply empty'
| modify : WF m → WF (modify m h a f).1

theorem WF.empty [BEq α] [Hashable α] : WF ((Imp.empty n).1 : Imp α β) := by
dsimp only [Imp.empty]
split
· apply WF.empty'
simp
· next h =>
apply WF.empty'
exact Nat.pos_of_ne_zero h

theorem WF.size [BEq α] [Hashable α] {m : Imp α β} : m.WF → 0 < m.2.size
| mk _ h _ _ => h
| empty' h => by simp [h, Imp.empty']
| insert _ => (Imp.insert _ _ _ _).2
| erase _ => (Imp.erase _ _ _).2
| modify _ => (Imp.modify _ _ _ _).2

end Imp

Expand All @@ -255,7 +293,7 @@ open HashMap.Imp

/-- Make a new hash map with the specified capacity. -/
@[inline] def _root_.Std.mkHashMap [BEq α] [Hashable α] (capacity := 0) : HashMap α β :=
⟨.empty capacity, .empty⟩
⟨(Imp.empty capacity).1, .empty⟩

instance [BEq α] [Hashable α] : Inhabited (HashMap α β) where
default := mkHashMap
Expand Down Expand Up @@ -298,7 +336,8 @@ hashMap.insert "three" 3 = {"one" => 1, "two" => 2, "three" => 3}
hashMap.insert "two" 0 = {"one" => 1, "two" => 0}
```
-/
def insert (self : HashMap α β) (a : α) (b : β) : HashMap α β := ⟨self.1.insert a b, self.2.insert⟩
def insert (self : HashMap α β) (a : α) (b : β) : HashMap α β :=
⟨self.1.insert self.2.size a b, self.2.insert⟩

/--
Similar to `insert`, but also returns a boolean flag indicating whether an existing entry has been
Expand All @@ -323,7 +362,8 @@ hashMap.erase "one" = {"two" => 2}
hashMap.erase "three" = {"one" => 1, "two" => 2}
```
-/
@[inline] def erase (self : HashMap α β) (a : α) : HashMap α β := ⟨self.1.erase a, self.2.erase⟩
@[inline] def erase (self : HashMap α β) (a : α) : HashMap α β :=
⟨self.1.erase self.2.size a, self.2.erase⟩

/--
Performs an in-place edit of the value, ensuring that the value is used linearly.
Expand All @@ -334,7 +374,7 @@ The function `f` is passed the original key of the entry, along with the value i
```
-/
def modify (self : HashMap α β) (a : α) (f : α → β → β) : HashMap α β :=
⟨self.1.modify a f, self.2.modify⟩
⟨self.1.modify self.2.size a f, self.2.modify⟩

/--
Given a key `a`, returns a key-value pair in the map whose key compares equal to `a`.
Expand All @@ -346,7 +386,8 @@ hashMap.findEntry? "one" = some ("one", 1)
hashMap.findEntry? "three" = none
```
-/
@[inline] def findEntry? (self : HashMap α β) (a : α) : Option (α × β) := self.1.findEntry? a
@[inline] def findEntry? (self : HashMap α β) (a : α) : Option (α × β) :=
self.1.findEntry? self.2.size a

/--
Looks up an element in the map with key `a`.
Expand All @@ -356,7 +397,7 @@ hashMap.find? "one" = some 1
hashMap.find? "three" = none
```
-/
@[inline] def find? (self : HashMap α β) (a : α) : Option β := self.1.find? a
@[inline] def find? (self : HashMap α β) (a : α) : Option β := self.1.find? self.2.size a

/--
Looks up an element in the map with key `a`. Returns `b₀` if the element is not found.
Expand Down Expand Up @@ -390,7 +431,7 @@ hashMap.contains "one" = true
hashMap.contains "three" = false
```
-/
@[inline] def contains (self : HashMap α β) (a : α) : Bool := self.1.contains a
@[inline] def contains (self : HashMap α β) (a : α) : Bool := self.1.contains self.2.size a

/--
Folds a monadic function over the elements in the map (in arbitrary order).
Expand Down Expand Up @@ -483,7 +524,7 @@ def toArray (self : HashMap α β) : Array (α × β) :=
self.fold (init := #[]) fun r k v => r.push (k, v)

/-- The number of buckets in the hash map. -/
def numBuckets (self : HashMap α β) : Nat := self.1.buckets.1.size
def numBuckets (self : HashMap α β) : Nat := self.1.buckets.size

/--
Builds a `HashMap` from a list of key-value pairs.
Expand Down
14 changes: 8 additions & 6 deletions Std/Data/HashMap/WF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@ theorem reinsertAux_WF [BEq α] [Hashable α] {data : Buckets α β} {a : α} {b
| H, _, .tail _ h => H _ h

theorem expand_size [Hashable α] {buckets : Buckets α β} :
(expand sz buckets).buckets.size = buckets.size := by
rw [expand, go]
(Buckets.mk' (expand sz buckets)).size = buckets.size := by
rw [expand]
change (expand.go 0 buckets.val _).size = _ -- Meh
rw [go]
· rw [Buckets.mk_size]; simp [Buckets.size]
· nofun
where
Expand Down Expand Up @@ -140,7 +142,7 @@ theorem expand_WF.foldl [BEq α] [Hashable α] (rank : α → Nat) {l : List (α
exact ⟨h₁, h₂.2⟩

theorem expand_WF [BEq α] [Hashable α] {buckets : Buckets α β} (H : buckets.WF) :
(expand sz buckets).buckets.WF :=
(Buckets.mk' (expand sz buckets)).WF :=
go _ H.1 H.2 ⟨.mk' _, fun _ _ _ _ => by simp_all [Buckets.mk, List.mem_replicate]⟩
where
go (i) {source : Array (AssocList α β)}
Expand Down Expand Up @@ -171,9 +173,9 @@ where
· exact ht.1
termination_by source.size - i

theorem insert_size [BEq α] [Hashable α] {m : Imp α β} {k v}
(h : m.size = m.buckets.size) :
(insert m k v).size = (insert m k v).buckets.size := by
theorem insert_size [BEq α] [Hashable α] {m : Imp α β} (hm) {k v}
(h : m.size = (Buckets.mk' ⟨m, hm⟩).size) :
(insert m hm k v).1.size = (Buckets.mk' (insert m hm k v)).size := by
dsimp [insert, cond]; split
· unfold Buckets.size
refine have ⟨_, _, h₁, _, eq⟩ := Buckets.exists_of_update ..; eq ▸ ?_
Expand Down
Loading