Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: merging functions on List + mergeSort #763

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 120 additions & 9 deletions Std/Data/List/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1620,13 +1620,124 @@ See `isSubperm_iff` for a characterization in terms of `List.Subperm`.
def isSubperm [BEq α] (l₁ l₂ : List α) : Bool := ∀ x ∈ l₁, count x l₁ ≤ count x l₂

/--
`O(|l| + |r|)`. Merge two lists using `s` as a switch.
-/
def merge (s : α → α → Bool) (l r : List α) : List α :=
loop l r []
`O(|xs| + |ys|)`. Merge lists `xs` and `ys`. If the lists are sorted according to `lt`, then the
result is sorted as well. If two (or more) elements are equal according to `lt`, they are preserved.
-/
def merge (lt : α → α → Bool) : (xs ys : List α) → List α
| [], xs
| xs, [] => xs
| x :: xs, y :: ys =>
bif lt x y then x :: merge lt xs (y :: ys) else y :: merge lt (x :: xs) ys

/--
Tail recursive version of `merge`. Chosen elements are pushed onto an accumulator (so the
accumulator holds the output prefix in reverse), and once either input is exhausted the
accumulator is reversed onto the remaining input with `reverseAux`.
-/
@[inline] def mergeTR (lt : α → α → Bool) (xs ys : List α) : List α := go xs ys [] where
/-- Auxiliary for `mergeTR`: `mergeTR.go lt xs ys acc = reverseAux acc (merge lt xs ys)`,
i.e. `acc.reverse ++ merge lt xs ys`. -/
go : List α → List α → List α → List α
| [], ys, acc => reverseAux acc ys
| xs, [], acc => reverseAux acc xs
| x::xs, y::ys, acc => bif lt x y then go xs (y::ys) (x::acc) else go (x::xs) ys (y::acc)

/--
`merge` and its tail recursive implementation `mergeTR` are equal. Tagged `@[csimp]` so that
compiled code uses `mergeTR` in place of `merge`.

The local lemma `go` generalizes over the accumulator:
`mergeTR.go lt xs ys acc = reverseAux acc (merge lt xs ys)`.
-/
@[csimp] theorem merge_eq_mergeTR : @merge = @mergeTR := by
funext α lt xs ys
let rec go (acc) : ∀ xs ys, @mergeTR.go α lt xs ys acc = reverseAux acc (merge lt xs ys)
| [], _ => by simp [mergeTR.go, merge]
| _::_, [] => by simp [mergeTR.go, merge]
| x::xs, y::ys => by
simp [mergeTR.go, merge, cond]; split
· exact go _ xs (y::ys)
· exact go _ (x::xs) ys
simp [mergeTR, go]

/--
`O(|xs| + |ys|)`. Merge lists `xs` and `ys`, which must be sorted according to `compare` and must
not contain duplicates. Equal elements are merged using `merge`. If `merge` respects the order
(i.e. for all `x`, `y`, `y'`, `z`, if `x < y < z` and `x < y' < z` then `x < merge y y' < z`)
then the resulting list is again sorted.
-/
def mergeDedupWith [Ord α] (merge : α → α → α) : (xs ys : List α) → List α
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The merge parameter is a bit confusing. How about dedup, join or meld?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

combine?

Copy link
Member Author

@digama0 digama0 Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or just f. I think most of the other list functions don't give their predicate and function arguments long names, and while there are some downsides to having lots of one letter variables in context, one upside is that it makes it visibly distinct from global definitions. I assume that's the reason you flagged merge as being a bad name...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the issue is that merge clashes with a definition in this file. I'm happy with f, combine or whatever you decide.

| [], xs
| xs, [] => xs
| x :: xs, y :: ys =>
match compare x y with
| .lt => x :: mergeDedupWith merge xs (y :: ys)
| .gt => y :: mergeDedupWith merge (x :: xs) ys
| .eq => merge x y :: mergeDedupWith merge xs ys

/--
`O(|xs| + |ys|)`. Merge `xs` and `ys`, both of which must be sorted according to `compare` and
free of duplicates. When an element occurs in both lists, a single copy (the one from `xs`) is
kept in the result.
-/
@[inline] def mergeDedup [Ord α] (xs ys : List α) : List α :=
  mergeDedupWith (fun fromLeft _ => fromLeft) xs ys

/--
`O(|xs| * |ys|)`. Merge `xs` and `ys`, which do not need to be sorted. Elements which occur in
both `xs` and `ys` are only added once. If `xs` and `ys` do not contain duplicates, then neither
does the result.

The longer list is kept in full as the prefix of the result; the elements of the shorter list
that are not already present in the longer one (according to `==`) are appended after it.
-/
def mergeUnsortedDedup [BEq α] (xs ys : List α) : List α :=
  if xs.length < ys.length then go ys xs else go xs ys
where
  /-- Auxiliary definition for `mergeUnsortedDedup`: append to `xs` those elements of `ys` that
  do not already occur in `xs`. -/
  go (xs ys : List α) :=
    -- Keep `y` only when it is NOT already in `xs`; keeping the ones that are present would
    -- duplicate shared elements and drop the new ones.
    xs ++ ys.filter fun y => !xs.any (· == y)

/-- Replace each run `[x₁, ⋯, xₙ]` of equal elements in `xs` with `f ⋯ (f (f x₁ x₂) x₃) ⋯ xₙ`. -/
def mergeAdjacentDups [BEq α] (f : α → α → α) : (xs : List α) → List α
| [] => []
| x :: xs => go x xs
where
/-- Inner loop for `List.merge`. Tail recursive. -/
loop : List α → List α → List α → List α
| [], r, t => reverseAux t r
| l, [], t => reverseAux t l
| a::l, b::r, t => bif s a b then loop l (b::r) (a::t) else loop (a::l) r (b::t)
/-- Auxiliary definition for `mergeAdjacentDups`. -/
go (hd : α)
| [] => [hd]
| x :: xs =>
if x == hd then
go (f hd x) xs
else
hd :: go x xs

/--
`O(|xs|)`. Deduplicate a sorted list, keeping the first element of every run of `==`-equal
elements. The list must be sorted with respect to an order which agrees with `==`, i.e. whenever
`x == y` then `compare x y == .eq`.
-/
def dedupSorted [BEq α] (xs : List α) : List α :=
  mergeAdjacentDups (fun first _ => first) xs

namespace MergeSort

/-- `O(|l|)`. Split `l` into two lists of approximately equal length.
digama0 marked this conversation as resolved.
Show resolved Hide resolved
```
split [1, 2, 3, 4, 5] = ([1, 3, 5], [2, 4])
```
-/
@[simp] def split : List α → List α × List α
| [] => ([], [])
| a :: l =>
let (l₁, l₂) := split l
(a :: l₂, l₁)

/-- Neither component returned by `split l` is longer than `l` itself. -/
theorem length_split_le :
∀ l : List α, length (split l).1 ≤ length l ∧ length (split l).2 ≤ length l
| [] => ⟨Nat.le_refl 0, Nat.le_refl 0⟩
| _ :: l =>
-- `split (a :: l)` swaps the two halves of `split l` and conses `a` onto the new first half,
-- so the two bounds from the recursive call are used crosswise.
let ⟨h₁, h₂⟩ := length_split_le l
⟨Nat.succ_le_succ h₂, Nat.le_succ_of_le h₁⟩

end MergeSort

/--
`O(|l| log |l|)`. Uses merge sort to sort a list in ascending order by `lt`.

Lists with zero or one element are returned unchanged. A longer list is divided in two with
`MergeSort.split`, both parts are sorted recursively, and the results are combined with `merge`.
-/
def mergeSort (lt : α → α → Bool) (l : List α) : List α :=
-- The match is named (`_e`) so the termination proof can substitute the shape of `l`.
match _e : l with
| [] => []
| [a] => [a]
| _ :: _ :: _ =>
let ls := MergeSort.split l
merge lt (mergeSort lt ls.1) (mergeSort lt ls.2)
-- Well-founded recursion on the length of the input list.
termination_by length l
decreasing_by
all_goals subst _e
-- One goal per recursive call; each is discharged from `MergeSort.length_split_le`.
· exact Nat.add_le_add_right (MergeSort.length_split_le _).1 2
· exact Nat.add_le_add_right (MergeSort.length_split_le _).2 2

/--
`O(|xs| log |xs|)`. Sort a list with `mergeSort` (ordering elements by `compare`) and then remove
duplicates from the sorted result with `dedupSorted`. Equality is the `BEq` instance derived from
the given `Ord` instance.
-/
def sortDedup [ord : Ord α] (xs : List α) : List α :=
  have : BEq α := ord.toBEq
  let sorted := xs.mergeSort fun a b => (compare a b).isLT
  dedupSorted sorted
74 changes: 21 additions & 53 deletions Std/Data/List/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2684,66 +2684,34 @@ theorem indexOf_mem_indexesOf [BEq α] [LawfulBEq α] {xs : List α} (m : x ∈
specialize ih m
simpa

theorem merge_loop_nil_left (s : α → α → Bool) (r t) :
merge.loop s [] r t = reverseAux t r := by
rw [merge.loop]

theorem merge_loop_nil_right (s : α → α → Bool) (l t) :
merge.loop s l [] t = reverseAux t l := by
cases l <;> rw [merge.loop]; intro; contradiction

theorem merge_loop (s : α → α → Bool) (l r t) :
merge.loop s l r t = reverseAux t (merge s l r) := by
rw [merge]; generalize hn : l.length + r.length = n
induction n using Nat.recAux generalizing l r t with
| zero =>
rw [eq_nil_of_length_eq_zero (Nat.eq_zero_of_add_eq_zero_left hn)]
rw [eq_nil_of_length_eq_zero (Nat.eq_zero_of_add_eq_zero_right hn)]
rfl
| succ n ih =>
match l, r with
| [], r => simp only [merge_loop_nil_left]; rfl
| l, [] => simp only [merge_loop_nil_right]; rfl
| a::l, b::r =>
simp only [merge.loop, cond]
split
· have hn : l.length + (b :: r).length = n := by
apply Nat.add_right_cancel (m:=1)
rw [←hn]; simp only [length_cons, Nat.add_succ, Nat.succ_add]
rw [ih _ _ (a::t) hn, ih _ _ [] hn, ih _ _ [a] hn]; rfl
· have hn : (a::l).length + r.length = n := by
apply Nat.add_right_cancel (m:=1)
rw [←hn]; simp only [length_cons, Nat.add_succ, Nat.succ_add]
rw [ih _ _ (b::t) hn, ih _ _ [] hn, ih _ _ [b] hn]; rfl

@[simp] theorem merge_nil (s : α → α → Bool) (l) : merge s l [] = l := merge_loop_nil_right ..
@[simp] theorem merge_nil (lt : α → α → Bool) (l) : merge lt l [] = l := by cases l <;> simp [merge]

@[simp] theorem nil_merge (s : α → α → Bool) (r) : merge s [] r = r := merge_loop_nil_left ..
@[simp] theorem nil_merge (lt : α → α → Bool) (r) : merge lt [] r = r := by simp [merge]

theorem cons_merge_cons (s : α → α → Bool) (a b l r) :
merge s (a::l) (b::r) = if s a b then a :: merge s l (b::r) else b :: merge s (a::l) r := by
simp only [merge, merge.loop, cond]; split <;> (next hs => rw [hs, merge_loop]; rfl)
theorem cons_merge_cons (lt : α → α → Bool) (a b l r) :
merge lt (a::l) (b::r) = if lt a b then a :: merge lt l (b::r) else b :: merge lt (a::l) r := by
simp only [merge, cond_eq_if]

@[simp] theorem cons_merge_cons_pos (s : α → α → Bool) (l r) (h : s a b) :
merge s (a::l) (b::r) = a :: merge s l (b::r) := by
@[simp] theorem cons_merge_cons_pos (lt : α → α → Bool) (l r) (h : lt a b) :
merge lt (a::l) (b::r) = a :: merge lt l (b::r) := by
rw [cons_merge_cons, if_pos h]

@[simp] theorem cons_merge_cons_neg (s : α → α → Bool) (l r) (h : ¬ s a b) :
merge s (a::l) (b::r) = b :: merge s (a::l) r := by
@[simp] theorem cons_merge_cons_neg (lt : α → α → Bool) (l r) (h : ¬ lt a b) :
merge lt (a::l) (b::r) = b :: merge lt (a::l) r := by
rw [cons_merge_cons, if_neg h]

@[simp] theorem length_merge (s : α → α → Bool) (l r) :
(merge s l r).length = l.length + r.length := by
@[simp] theorem length_merge (lt : α → α → Bool) (l r) :
(merge lt l r).length = l.length + r.length := by
match l, r with
| [], r => simp
| l, [] => simp
| a::l, b::r =>
rw [cons_merge_cons]
split
· simp_arith [length_merge s l (b::r)]
· simp_arith [length_merge s (a::l) r]
· simp_arith [length_merge lt l (b::r)]
· simp_arith [length_merge lt (a::l) r]

theorem mem_merge_left (s : α → α → Bool) (h : x ∈ l) : x ∈ merge s l r := by
theorem mem_merge_left (lt : α → α → Bool) (h : x ∈ l) : x ∈ merge lt l r := by
match l, r with
| l, [] => simp [h]
| a::l, b::r =>
Expand All @@ -2752,25 +2720,25 @@ theorem mem_merge_left (s : α → α → Bool) (h : x ∈ l) : x ∈ merge s l
rw [cons_merge_cons]
split
· exact mem_cons_self ..
· apply mem_cons_of_mem; exact mem_merge_left s h
· apply mem_cons_of_mem; exact mem_merge_left lt h
| .inr h' =>
rw [cons_merge_cons]
split
· apply mem_cons_of_mem; exact mem_merge_left s h'
· apply mem_cons_of_mem; exact mem_merge_left s h
· apply mem_cons_of_mem; exact mem_merge_left lt h'
· apply mem_cons_of_mem; exact mem_merge_left lt h

theorem mem_merge_right (s : α → α → Bool) (h : x ∈ r) : x ∈ merge s l r := by
theorem mem_merge_right (lt : α → α → Bool) (h : x ∈ r) : x ∈ merge lt l r := by
match l, r with
| [], r => simp [h]
| a::l, b::r =>
match mem_cons.1 h with
| .inl rfl =>
rw [cons_merge_cons]
split
· apply mem_cons_of_mem; exact mem_merge_right s h
· apply mem_cons_of_mem; exact mem_merge_right lt h
· exact mem_cons_self ..
| .inr h' =>
rw [cons_merge_cons]
split
· apply mem_cons_of_mem; exact mem_merge_right s h
· apply mem_cons_of_mem; exact mem_merge_right s h'
· apply mem_cons_of_mem; exact mem_merge_right lt h
· apply mem_cons_of_mem; exact mem_merge_right lt h'