Skip to content

Commit ed1f8fe

Browse files
committed
feat: HashMap.modify
1 parent 5507f9d commit ed1f8fe

File tree

6 files changed

+134
-29
lines changed

6 files changed

+134
-29
lines changed

Std/Data/Array/Lemmas.lean

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ theorem get_set (a : Array α) (i : Fin a.size) (j : Nat) (hj : j < a.size) (v :
107107
(a.set i v)[j]'(by simp [*]) = if i = j then v else a[j] := by
108108
if h : i.1 = j then subst j; simp [*] else simp [*]
109109

110+
theorem set_set (a : Array α) (i : Fin a.size) (v v' : α) :
111+
(a.set i v).set ⟨i, by simp [i.2]⟩ v' = a.set i v' := by simp [set, List.set_set]
112+
110113
private theorem fin_cast_val (e : n = n') (i : Fin n) : e ▸ i = ⟨i.1, e ▸ i.2⟩ := by cases e; rfl
111114

112115
theorem swap_def (a : Array α) (i j : Fin a.size) :

Std/Data/AssocList.lean

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,23 @@ with key equal to `a` to have key `a` and value `b`.
182182
@[simp] theorem erase_toList [BEq α] (a : α) (l : AssocList α β) :
183183
(erase a l).toList = l.toList.eraseP (·.1 == a) := eraseP_toList ..
184184

185+
/--
186+
`O(n)`. Replace the first entry `a', b` in the list
187+
with key equal to `a` to have key `a` and value `f a' b`.
188+
-/
189+
@[simp] def modify [BEq α] (a : α) (f : α → β → β) : AssocList α β → AssocList α β
190+
| nil => nil
191+
| cons k v es => match k == a with
192+
| true => cons a (f k v) es
193+
| false => cons k v (modify a f es)
194+
195+
@[simp] theorem modify_toList [BEq α] (a : α) (l : AssocList α β) :
196+
(modify a f l).toList =
197+
l.toList.replaceF fun (k, v) => bif k == a then some (a, f k v) else none := by
198+
simp [cond]
199+
induction l with simp [List.replaceF]
200+
| cons k v es ih => cases k == a <;> simp [ih]
201+
185202
/-- The implementation of `ForIn`, which enables `for (k, v) in aList do ...` notation. -/
186203
@[specialize] protected def forIn [Monad m]
187204
(as : AssocList α β) (init : δ) (f : (α × β) → δ → m (ForInStep δ)) : m δ :=

Std/Data/HashMap/Basic.lean

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ Note: this is marked `noncomputable` because it is only intended for specificati
3939
-/
4040
noncomputable def size (data : Bucket α β) : Nat := .sum (data.1.data.map (·.toList.length))
4141

42+
@[simp] theorem update_size (self : Bucket α β) (i d h) :
43+
(self.update i d h).1.size = self.1.size := Array.size_uset ..
44+
4245
/-- Map a function over the values in the map. -/
4346
@[specialize] def mapVal (f : α → β → γ) (self : Bucket α β) : Bucket α γ :=
4447
⟨self.1.map (.mapVal f), by simp [self.2]⟩
@@ -181,6 +184,14 @@ def erase [BEq α] [Hashable α] (m : Imp α β) (a : α) : Imp α β :=
181184
@[inline] def mapVal (f : α → β → γ) (self : Imp α β) : Imp α γ :=
182185
{ size := self.size, buckets := self.buckets.mapVal f }
183186

187+
/-- Performs an in-place edit of the value, ensuring that the value is used linearly. -/
188+
def modify [BEq α] [Hashable α] (m : Imp α β) (a : α) (f : α → β → β) : Imp α β :=
189+
let ⟨size, buckets⟩ := m
190+
let ⟨i, h⟩ := mkIdx buckets.2 (hash a |>.toUSize)
191+
let bkt := buckets.1[i]
192+
let buckets := buckets.update i .nil h -- for linearity
193+
⟨size, buckets.update i (bkt.modify a f) ((Bucket.update_size ..).symm ▸ h)⟩
194+
184195
/--
185196
Applies `f` to each key-value pair `a, b` in the map. If it returns `some c` then
186197
`a, c` is pushed into the new map; else the key is removed from the map.
@@ -223,6 +234,8 @@ inductive WF [BEq α] [Hashable α] : Imp α β → Prop where
223234
| insert : WF m → WF (insert m a b)
224235
/-- Removing an element from a well formed hash map yields a well formed hash map. -/
225236
| erase : WF m → WF (erase m a)
237+
/-- Replacing an element in a well formed hash map yields a well formed hash map. -/
238+
| modify : WF m → WF (modify m a f)
226239

227240
theorem WF.empty [BEq α] [Hashable α] : WF (empty n : Imp α β) := by unfold empty; apply empty'
228241

@@ -280,6 +293,13 @@ Removes key `a` from the map. If it does not exist in the map, the map is return
280293
-/
281294
@[inline] def erase (self : HashMap α β) (a : α) : HashMap α β := ⟨self.1.erase a, self.2.erase⟩
282295

296+
/--
297+
Performs an in-place edit of the value, ensuring that the value is used linearly.
298+
The function `f` is passed the original key of the entry, along with the value in the map.
299+
-/
300+
def modify (self : HashMap α β) (a : α) (f : α → β → β) : HashMap α β :=
301+
⟨self.1.modify a f, self.2.modify⟩
302+
283303
/-- Given a key `a`, returns a key-value pair in the map whose key compares equal to `a`. -/
284304
@[inline] def findEntry? (self : HashMap α β) (a : α) : Option (α × β) := self.1.findEntry? a
285305

Std/Data/HashMap/WF.lean

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@ namespace Bucket
2121
theorem update_data (self : Bucket α β) (i d h) :
2222
(self.update i d h).1.data = self.1.data.set i.toNat d := rfl
2323

24-
@[simp] theorem update_size (self : Bucket α β) (i d h) :
25-
(self.update i d h).1.size = self.1.size := Array.size_uset ..
26-
2724
theorem exists_of_update (self : Bucket α β) (i d h) :
2825
∃ l₁ l₂, self.1.data = l₁ ++ self.1[i] :: l₂ ∧ List.length l₁ = i.toNat ∧
2926
(self.update i d h).1.data = l₁ ++ d :: l₂ := by
3027
simp [Array.getElem_eq_data_get]; exact List.exists_of_set' h
3128

29+
theorem update_update (self : Bucket α β) (i d d' h h') :
30+
(self.update i d h).update i d' h' = self.update i d' h := by
31+
simp [update]; congr 1; rw [Array.set_set]
32+
3233
theorem size_eq (data : Bucket α β) :
3334
size data = .sum (data.1.data.map (·.toList.length)) := rfl
3435

@@ -185,8 +186,8 @@ theorem insert_size [BEq α] [Hashable α] {m : Imp α β} {k v}
185186
refine have ⟨_, _, h₁, _, eq⟩ := Bucket.exists_of_update ..; eq ▸ ?_
186187
simp [h₁, Bucket.size_eq, Nat.succ_add]; rfl
187188

188-
private theorem mem_replaceF {l : List (α × β)} {x : α × β} {p : α × β → Bool} :
189-
x ∈ (l.replaceF fun a => bif p a then some (k, v) else none) → x.1 = k ∨ x ∈ l := by
189+
private theorem mem_replaceF {l : List (α × β)} {x : α × β} {p : α × β → Bool} {f : α × β → β} :
190+
x ∈ (l.replaceF fun a => bif p a then some (k, f a) else none) → x.1 = k ∨ x ∈ l := by
190191
induction l with
191192
| nil => exact .inr
192193
| cons a l ih =>
@@ -200,16 +201,16 @@ private theorem mem_replaceF {l : List (α × β)} {x : α × β} {p : α × β
200201
| .inr h => exact (ih h).imp_right .inr
201202

202203
private theorem pairwise_replaceF [BEq α] [PartialEquivBEq α]
203-
{l : List (α × β)} {x : α × β} (hx₁ : x ∈ l) (hx₂ : x.fst == k)
204+
{l : List (α × β)} {f : α × β → β}
204205
(H : l.Pairwise fun a b => ¬(a.fst == b.fst)) :
205-
(l.replaceF fun a => bif a.fst == k then some (k, v) else none)
206+
(l.replaceF fun a => bif a.fst == k then some (k, f a) else none)
206207
|>.Pairwise fun a b => ¬(a.fst == b.fst) := by
207-
induction hx₁ with
208-
| head => simp_all; exact (H.1 · · ∘ PartialEquivBEq.trans hx₂)
209-
| tail _ _ ih =>
208+
induction l with
209+
| nil => simp [H]
210+
| cons a l ih =>
210211
simp at H ⊢
211-
generalize e : cond .. = z; revert e
212-
unfold cond; split <;> (intro h; subst h; simp)
212+
generalize e : cond .. = z; unfold cond at e; revert e
213+
split <;> (intro h; subst h; simp)
213214
· next e => exact ⟨(H.1 · · ∘ PartialEquivBEq.trans e), H.2
214215
· next e =>
215216
refine ⟨fun a h => ?_, ih H.2
@@ -223,7 +224,7 @@ theorem insert_WF [BEq α] [Hashable α] {m : Imp α β} {k v}
223224
· next h₁ =>
224225
simp at h₁; have ⟨x, hx₁, hx₂⟩ := h₁
225226
refine h.update (fun H => ?_) (fun H a h => ?_)
226-
· simp; exact pairwise_replaceF hx₁ hx₂ H
227+
· simp; exact pairwise_replaceF H
227228
· simp [AssocList.All] at H h ⊢
228229
match mem_replaceF h with
229230
| .inl rfl => rfl
@@ -261,13 +262,32 @@ theorem erase_WF [BEq α] [Hashable α] {m : Imp α β} {k}
261262
· exact H _ (List.mem_of_mem_eraseP h)
262263
· exact h
263264

265+
theorem modify_size [BEq α] [Hashable α] {m : Imp α β} {k}
266+
(h : m.size = m.buckets.size) :
267+
(modify m k f).size = (modify m k f).buckets.size := by
268+
dsimp [modify, cond]; rw [Bucket.update_update]
269+
simp [h, Bucket.size]
270+
refine have ⟨_, _, h₁, _, eq⟩ := Bucket.exists_of_update ..; eq ▸ ?_
271+
simp [h, h₁, Bucket.size_eq]
272+
273+
theorem modify_WF [BEq α] [Hashable α] {m : Imp α β} {k}
274+
(h : m.buckets.WF) : (modify m k f).buckets.WF := by
275+
dsimp [modify, cond]; rw [Bucket.update_update]
276+
refine h.update (fun H => ?_) (fun H a h => ?_) <;> simp at h ⊢
277+
· exact pairwise_replaceF H
278+
· simp [AssocList.All] at H h ⊢
279+
match mem_replaceF h with
280+
| .inl rfl => rfl
281+
| .inr h => exact H _ h
282+
264283
theorem WF.out [BEq α] [Hashable α] {m : Imp α β} (h : m.WF) :
265284
m.size = m.buckets.size ∧ m.buckets.WF := by
266285
induction h with
267286
| mk h₁ h₂ => exact ⟨h₁, h₂⟩
268287
| @empty' _ h => exact ⟨(Bucket.mk_size h).symm, .mk' h⟩
269288
| insert _ ih => exact ⟨insert_size ih.1, insert_WF ih.2
270289
| erase _ ih => exact ⟨erase_size ih.1, erase_WF ih.2
290+
| modify _ ih => exact ⟨modify_size ih.1, modify_WF ih.2
271291

272292
theorem WF_iff [BEq α] [Hashable α] {m : Imp α β} :
273293
m.WF ↔ m.size = m.buckets.size ∧ m.buckets.WF :=

Std/Data/List/Basic.lean

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -390,26 +390,23 @@ def indexOf [BEq α] (a : α) : List α → Nat := findIdx (a == ·)
390390
| none => x :: replaceF f xs
391391
| some a => a :: xs
392392

393-
/-- Tail recursive version of `replaceF`. -/
393+
/-- Tail-recursive version of `replaceF`. -/
394394
@[inline] def replaceFTR (f : α → Option α) (l : List α) : List α := go l #[] where
395-
/-- Auxiliary for `replaceFTR`:
396-
`replaceFTR.go f l xs acc = acc.toList ++ replaceF f xs` if `f` returns `some`, else `l`. -/
397-
go : List α → Array α → List α
398-
| [], _ => l
395+
/-- Auxiliary for `replaceFTR`: `replaceFTR.go f xs acc = acc.toList ++ replaceF f xs`. -/
396+
@[specialize] go : List α → Array α → List α
397+
| [], acc => acc.toList
399398
| x :: xs, acc => match f x with
400399
| none => go xs (acc.push x)
401-
| some a => acc.toListAppend (a :: xs)
400+
| some a' => acc.toListAppend (a' :: xs)
402401

403402
@[csimp] theorem replaceF_eq_replaceFTR : @replaceF = @replaceFTR := by
404-
funext α f l; simp [replaceFTR]
405-
suffices ∀ xs acc, l = acc.data ++ xs →
406-
replaceFTR.go f l xs acc = acc.data ++ xs.replaceF f from
407-
(this l #[] (by simp)).symm
408-
intro xs; induction xs with intro acc
409-
| nil => simp [replaceF, replaceFTR.go]
410-
| cons x xs IH =>
411-
simp [replaceF, replaceFTR.go]; split <;> simp [*]
412-
· intro h; rw [IH]; simp; simp; exact h
403+
funext α p l; simp [replaceFTR]
404+
let rec go (acc) : ∀ xs, replaceFTR.go p xs acc = acc.data ++ xs.replaceF p
405+
| [] => by simp [replaceFTR.go, replaceF]
406+
| x::xs => by
407+
simp [replaceFTR.go, replaceF]; cases p x <;> simp
408+
· rw [go _ xs]; simp
409+
exact (go #[] _).symm
413410

414411
/-- Inserts an element into a list without duplication. -/
415412
@[inline] protected def insert [DecidableEq α] (a : α) (l : List α) : List α :=

Std/Data/List/Lemmas.lean

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,11 @@ theorem set_comm (a b : α) : ∀ {n m : Nat} (l : List α), n ≠ m →
819819
| n+1, m+1, x :: t, h =>
820820
congrArg _ <| set_comm a b t fun h' => h <| Nat.succ_inj'.mpr h'
821821

822+
theorem set_set (a b : α) : ∀ (l : List α) (n : Nat), (l.set n a).set n b = l.set n b
823+
| [], _ => by simp
824+
| _ :: _, 0 => by simp [set]
825+
| _ :: _, _+1 => by simp [set, set_set]
826+
822827
@[simp] theorem get_set_eq (l : List α) (i : Nat) (a : α) (h : i < (l.set i a).length) :
823828
(l.set i a).get ⟨i, h⟩ = a := by
824829
rw [← Option.some_inj, ← get?_eq_get, get?_set_eq, get?_eq_get] <;> simp_all
@@ -1280,8 +1285,51 @@ theorem Pairwise.imp {α R S} (H : ∀ {a b}, R a b → S a b) :
12801285

12811286
/-! ### replaceF -/
12821287

1288+
theorem replaceF_nil : [].replaceF p = [] := rfl
1289+
1290+
theorem replaceF_cons (a : α) (l : List α) :
1291+
(a :: l).replaceF p = match p a with
1292+
| none => a :: replaceF p l
1293+
| some a' => a' :: l := rfl
1294+
1295+
theorem replaceF_cons_of_some {l : List α} (p) (h : p a = some a') :
1296+
(a :: l).replaceF p = a' :: l := by
1297+
simp [replaceF_cons, h]
1298+
1299+
theorem replaceF_cons_of_none {l : List α} (p) (h : p a = none) :
1300+
(a :: l).replaceF p = a :: l.replaceF p := by simp [replaceF_cons, h]
1301+
1302+
theorem replaceF_of_forall_none {l : List α} (h : ∀ a, a ∈ l → p a = none) : l.replaceF p = l := by
1303+
induction l with
1304+
| nil => rfl
1305+
| cons _ _ ih => simp [h _ (.head ..), ih (forall_mem_cons.1 h).2]
1306+
1307+
theorem exists_of_replaceF : ∀ {l : List α} {a a'} (al : a ∈ l) (pa : p a = some a'),
1308+
∃ a a' l₁ l₂,
1309+
(∀ b ∈ l₁, p b = none) ∧ p a = some a' ∧ l = l₁ ++ a :: l₂ ∧ l.replaceF p = l₁ ++ a' :: l₂
1310+
| b :: l, a, a', al, pa =>
1311+
match pb : p b with
1312+
| some b' => ⟨b, b', [], l, forall_mem_nil _, pb, by simp [pb]⟩
1313+
| none =>
1314+
match al with
1315+
| .head .. => nomatch pb.symm.trans pa
1316+
| .tail _ al =>
1317+
let ⟨c, c', l₁, l₂, h₁, h₂, h₃, h₄⟩ := exists_of_replaceF al pa
1318+
⟨c, c', b::l₁, l₂, (forall_mem_cons ..).2 ⟨pb, h₁⟩,
1319+
h₂, by rw [h₃, cons_append], by simp [pb, h₄]⟩
1320+
1321+
theorem exists_or_eq_self_of_replaceF (p) (l : List α) :
1322+
l.replaceF p = l ∨ ∃ a a' l₁ l₂,
1323+
(∀ b ∈ l₁, p b = none) ∧ p a = some a' ∧ l = l₁ ++ a :: l₂ ∧ l.replaceF p = l₁ ++ a' :: l₂ :=
1324+
if h : ∃ a ∈ l, (p a).isSome then
1325+
let ⟨_, ha, pa⟩ := h
1326+
.inr (exists_of_replaceF ha (Option.get_mem pa))
1327+
else
1328+
.inl <| replaceF_of_forall_none fun a ha =>
1329+
Option.not_isSome_iff_eq_none.1 fun h' => h ⟨a, ha, h'⟩
1330+
12831331
@[simp] theorem length_replaceF : length (replaceF f l) = length l := by
1284-
induction l <;> simp; split <;> simp [*]
1332+
induction l <;> simp [replaceF]; split <;> simp [*]
12851333

12861334
/-! ### disjoint -/
12871335

0 commit comments

Comments
 (0)