Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - feat: binary heaps #136

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Mathlib.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Mathlib.Algebra.GroupWithZero.Defs
import Mathlib.Algebra.Ring.Basic
import Mathlib.Data.Array.Basic
import Mathlib.Data.Array.Defs
import Mathlib.Data.BinaryHeap
import Mathlib.Data.ByteArray
import Mathlib.Data.Char
import Mathlib.Data.Equiv.Basic
Expand Down Expand Up @@ -32,6 +33,7 @@ import Mathlib.Init.Function
import Mathlib.Init.Logic
import Mathlib.Init.Set
import Mathlib.Init.SetNotation
import Mathlib.Init.WF
import Mathlib.Lean.Expr
import Mathlib.Lean.LocalContext
import Mathlib.Logic.Basic
Expand Down
139 changes: 139 additions & 0 deletions Mathlib/Data/BinaryHeap.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import Mathlib.Init.WF
import Mathlib.Data.Fin.Basic

/-- A max-heap data structure.

The comparison `lt` is a parameter of the type, so heaps built with different comparison
functions have different types. The structure stores only the backing array; no heap
invariant is carried as a field. -/
structure BinaryHeap (α) (lt : α → α → Bool) where
/-- The backing array holding the heap's elements. -/
arr : Array α

namespace BinaryHeap

/-- Core operation for binary heaps, expressed directly on arrays.
Given an array which is a max-heap, push item `i` down to restore the max-heap property.

The element at `i` is compared (via `lt`) with its children at indices `2 * i + 1` and
`2 * i + 2`. If one of them wins the comparison, it is swapped with `i` and the procedure
recurses from that child's index; otherwise the array is returned unchanged.
The result carries a proof that the array size is preserved. -/
def heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
{a' : Array α // a'.size = a.size} :=
-- Indices of the two children of `i` in the implicit binary tree.
let left := 2 * i.1 + 1
let right := left + 1
-- Both child indices are at least `i`; these facts populate the subtype `{j // i ≤ j}` below.
have left_le : i ≤ left := Nat.le_trans
(by rw [Nat.succ_mul, Nat.one_mul]; exact Nat.le_add_left i i)
(Nat.le_add_right ..)
have right_le : i ≤ right := Nat.le_trans left_le (Nat.le_add_right ..)
have i_le : i ≤ i := Nat.le_refl _
-- First candidate: the left child if it is in bounds and `lt`-greater than `a[i]`, else `i`.
have j : {j : Fin a.size // i ≤ j} := if h : left < a.size then
if lt (a.get i) (a.get ⟨left, h⟩) then ⟨⟨left, h⟩, left_le⟩ else ⟨i, i_le⟩ else ⟨i, i_le⟩
-- Second candidate: the right child if it is in bounds and `lt`-greater than the current one.
have j := if h : right < a.size then
if lt (a.get j) (a.get ⟨right, h⟩) then ⟨⟨right, h⟩, right_le⟩ else j else j
-- If the candidate is still `i`, nothing has to move.
if h : i.1 = j then ⟨a, rfl⟩ else
let a' := a.swap i j
-- Re-type the index `j` for the swapped array, whose size equals `a.size`.
let j' := ⟨j, by rw [a.size_swap i j]; exact j.1.2⟩
-- Evidence that the recursive call is smaller in the relation given to `termination_by`.
-- `H` abstracts the array size as `n` so that `subst` can identify `a'.size` with `a.size`.
have : (skipLeft Fin.upRel).1 ⟨a'.size, j'⟩ ⟨a.size, i⟩ := by
have H {n} (h : n = a.size) (j' : Fin n) (e' : i.1 < j'.1) :
(skipLeft Fin.upRel).1 ⟨n, j'⟩ ⟨a.size, i⟩ := by
subst n; exact PSigma.Lex.right _ e'
exact H (a.size_swap i j) _ (lt_of_le_of_ne j.2 h)
let ⟨a₂, h₂⟩ := heapifyDown lt a' j'
⟨a₂, h₂.trans (a.size_swap i j)⟩
-- Termination: the packed arguments are mapped to the pair `⟨a.size, i⟩` and compared with
-- `skipLeft Fin.upRel` (presumably provided by `Mathlib.Init.WF` — confirm).
termination_by invImage (fun ⟨_, _, a, i⟩ => (⟨a.size, i⟩ : (n : ℕ) ×' Fin n)) $ skipLeft Fin.upRel
Copy link
Member

@gebner gebner Dec 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's a simpler termination argument:

/-- Core operation for binary heaps, expressed directly on arrays.
Given an array which is a max-heap, push item `i` down to restore the max-heap property.

This variant proves termination with the numeric measure `a.size - i`, which strictly
decreases because the recursive call happens at an index `j` with `i < j < a.size`. -/
def heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
  {a' : Array α // a'.size = a.size} :=
  -- Indices of the two children of `i`.
  let left := 2 * i.1 + 1
  let right := left + 1
  have left_le : i ≤ left := Nat.le_trans
    (by rw [Nat.succ_mul, Nat.one_mul]; exact Nat.le_add_left i i)
    (Nat.le_add_right ..)
  have right_le : i ≤ right := Nat.le_trans left_le (Nat.le_add_right ..)
  have i_le : i ≤ i := Nat.le_refl _
  -- Candidate index: left child if in bounds and `lt`-greater than `a[i]`, else `i`.
  have j : {j : Fin a.size // i ≤ j} := if h : left < a.size then
    if lt (a.get i) (a.get ⟨left, h⟩) then ⟨⟨left, h⟩, left_le⟩ else ⟨i, i_le⟩ else ⟨i, i_le⟩
  -- Then the right child, if in bounds and `lt`-greater than the current candidate.
  have j := if h : right < a.size then
    if lt (a.get j) (a.get ⟨right, h⟩) then ⟨⟨right, h⟩, right_le⟩ else j else j
  if h : i.1 = j then ⟨a, rfl⟩ else
    let a' := a.swap i j
    let j' := ⟨j, by rw [a.size_swap i j]; exact j.1.2⟩
    -- Decreasing evidence picked up by `decreasing_by ... assumption` below.
    have : a.size - j < a.size - i :=
      Nat.sub_lt_sub_left i.2 <| Nat.lt_of_le_and_ne j.2 h
    let ⟨a₂, h₂⟩ := heapifyDown lt a' j'
    ⟨a₂, h₂.trans (a.size_swap i j)⟩
termination_by measure fun ⟨α, lt, a, i⟩ => a.size - i
decreasing_by simp [measure, invImage, InvImage]; assumption

BTW, instead of Fin.upRel I think it would be more ergonomic to talk about -i and use the standard order. (alas, this doesn't work directly because you then have two versions of -: one in Fin a.size and one in Fin a'.size...)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole point of Nat.upRel and Fin.upRel is so that you don't have to do this a - x < a - y argument all the time. I think we should try to make this kind of proof more ergonomic; there is no reason to always be doing downward induction. I'm not too happy with the skipLeft stuff (I think more should be automated than is currently the case), but I do think that the original proof is in the right direction.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I find the measure version much easier to read compared to explicitly specifying the relations. Although ideally measure would accept functions into any well-ordered type. Then you could write measure fun ⟨α, lt, a, i⟩ => (⟨a.size, -i⟩ : (n : ℕ) ×' Fin n) or measure fun ⟨α, lt, a, i⟩ => (⟨a.size, i⟩ : (n : ℕ) ×' OrderDual (Fin n)). The OrderDual version also works for any finite type, not just Fin n.

decreasing_by assumption

/-- `heapifyDown` returns an array of the same size as its input; this is the proof
component of its subtype result. -/
@[simp] theorem size_heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
    (heapifyDown lt a i).val.size = a.size :=
  (heapifyDown lt a i).property

/-- Core operation for binary heaps, expressed directly on arrays.
Construct a heap from an unsorted array, by heapifying all the elements.

`heapifyDown` is applied at the indices `a.size / 2 - 1`, ..., `1`, `0`, in that order.
The result carries a proof that the array size is preserved. -/
def mkHeap (lt : α → α → Bool) (a : Array α) : {a' : Array α // a'.size = a.size} :=
-- `loop i a h` sifts down the elements at indices `i - 1` down to `0`;
-- `h : i ≤ a.size` keeps every index it uses in bounds.
let rec loop : (i : Nat) → (a : Array α) → i ≤ a.size → {a' : Array α // a'.size = a.size}
| 0, a, _ => ⟨a, rfl⟩
| i+1, a, h =>
let h := Nat.lt_of_succ_le h
let a' := heapifyDown lt a ⟨i, h⟩
-- The bound is transported along the size equality produced by `heapifyDown`.
let ⟨a₂, h₂⟩ := loop i a' ((heapifyDown ..).2.symm ▸ le_of_lt h)
⟨a₂, h₂.trans a'.2⟩
-- Start at `a.size / 2`, so the first index visited is `a.size / 2 - 1`.
loop (a.size / 2) a (Nat.div_le_self ..)

/-- `mkHeap` returns an array of the same size as its input.

NOTE(review): the binder `(i : Fin a.size)` appears in neither the statement nor the proof;
it looks like a copy-paste leftover from `size_heapifyDown` / `size_heapifyUp`. It likely
keeps `simp` from applying this lemma automatically, since `i` cannot be inferred from the
left-hand side. Consider removing it after confirming no caller passes it explicitly. -/
@[simp] theorem size_mkHeap (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
(mkHeap lt a).1.size = a.size := (mkHeap lt a).2

/-- Core operation for binary heaps, expressed directly on arrays.
Given an array which is a max-heap, push item `i` up to restore the max-heap property.

The element at `i` is compared with its parent at index `(i - 1) / 2`. If the parent is
`lt`-smaller, the two are swapped and the procedure recurses from the parent's index.
It stops at index `0` or as soon as the parent is not `lt`-smaller.
The result carries a proof that the array size is preserved. -/
def heapifyUp (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
{a' : Array α // a'.size = a.size} :=
-- Index `0` has no parent, so nothing has to move.
if i0 : i.1 = 0 then ⟨a, rfl⟩ else
-- The parent index is strictly smaller than `i`; this fact both bounds `j` and
-- is the decreasing evidence used by `decreasing_by assumption`.
have : (i.1 - 1) / 2 < i := lt_of_le_of_lt (Nat.div_le_self ..) $
Nat.sub_lt (Nat.pos_iff_ne_zero.2 i0) Nat.one_pos
let j := ⟨(i.1 - 1) / 2, lt_trans this i.2⟩
if lt (a.get j) (a.get i) then
let a' := a.swap i j
-- Recurse at the parent's index, re-typed for the swapped array.
let ⟨a₂, h₂⟩ := heapifyUp lt a' ⟨j.1, by rw [a.size_swap i j]; exact j.2⟩
⟨a₂, h₂.trans (a.size_swap i j)⟩
else ⟨a, rfl⟩
-- Termination measure: the last component of the packed arguments, i.e. the index `i`.
termination_by measure (·.2.2.2)
decreasing_by assumption

/-- `heapifyUp` returns an array of the same size as its input; this is the proof
component of its subtype result. -/
@[simp] theorem size_heapifyUp (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
    (heapifyUp lt a i).val.size = a.size :=
  (heapifyUp lt a i).property

/-- `O(1)`. The heap with no elements, for the comparison function `lt`. -/
@[inline] def empty (lt) : BinaryHeap α lt :=
  { arr := #[] }
digama0 marked this conversation as resolved.
Show resolved Hide resolved

/-- The default `BinaryHeap` is the empty heap. -/
instance (lt) : Inhabited (BinaryHeap α lt) where
  default := empty lt

/-- `O(1)`. The number of elements stored in the heap, i.e. the size of its backing array. -/
@[inline] def size {lt} (self : BinaryHeap α lt) : Nat :=
  self.arr.size
digama0 marked this conversation as resolved.
Show resolved Hide resolved

/-- `O(log n)`. Insert an element into a `BinaryHeap`, preserving the max-heap property.

The element is pushed onto the end of the backing array and then moved up with
`heapifyUp`, starting from its index (the heap's previous size). -/
def insert {lt} (self : BinaryHeap α lt) (x : α) : BinaryHeap α lt where
-- `n` is the index the new element receives after `push`; the proof shows it is in bounds.
arr := let n := self.size;
heapifyUp lt (self.1.push x) ⟨n, by rw [Array.size_push]; apply Nat.lt_succ_self⟩

/-- Inserting an element increases the size of the heap by exactly one. -/
@[simp] theorem size_insert {lt} (self : BinaryHeap α lt) (x : α) :
(self.insert x).size = self.size + 1 := by
-- Unfold `insert` and `size`; `heapifyUp` preserves the size of the pushed array.
simp [insert, size, size_heapifyUp]

/-- `O(1)`. The element at the root of the heap (index `0` of the backing array),
or `none` if the heap is empty. -/
def max {lt} (self : BinaryHeap α lt) : Option α :=
  self.arr.get? 0

/-- Auxiliary for `popMax`.

Removes the root element and returns the resulting heap together with a proof that its
size is `self.size - 1`. An empty heap is returned unchanged (its size is `0 - 1 = 0`). -/
def popMaxAux {lt} (self : BinaryHeap α lt) : {a' : BinaryHeap α lt // a'.size = self.size - 1} :=
match e: self.1.size with
-- Empty heap: nothing to remove.
| 0 => ⟨self, by simp [size, e]⟩
| n+1 =>
-- Indices `0` and `n` (the last element) are both in bounds.
have h0 := by rw [e]; apply Nat.succ_pos
have hn := by rw [e]; apply Nat.lt_succ_self
if hn0 : 0 < n then
-- At least two elements: move the last element to the root, drop the old root
-- (now at the end), and sift the new root down.
let a := self.1.swap ⟨0, h0⟩ ⟨n, hn⟩ |>.pop
⟨⟨heapifyDown lt a ⟨0, by rwa [Array.size_pop, Array.size_swap, e, Nat.add_sub_cancel]⟩⟩,
by simp [size]⟩
else
-- Exactly one element: removing it leaves the empty heap.
⟨⟨self.1.pop⟩, by simp [size]⟩

/-- `O(log n)`. Drop the maximum element of a `BinaryHeap`, returning the remaining heap.
The removed element is not returned; use `max` beforehand to read it. -/
@[inline] def popMax {lt} (self : BinaryHeap α lt) : BinaryHeap α lt :=
  self.popMaxAux.val

/-- `popMax` decreases the size by one (truncated subtraction, so an empty heap stays at
size `0`); this is the proof component of `popMaxAux`. -/
@[simp] theorem size_popMax {lt} (self : BinaryHeap α lt) :
    self.popMax.size = self.size - 1 :=
  self.popMaxAux.property

/-- `O(log n)`. Pair the result of `max` (the root element, if any) with the result of
`popMax` (the heap with that element removed). -/
def extractMax {lt} (self : BinaryHeap α lt) : Option α × BinaryHeap α lt :=
  ⟨self.max, self.popMax⟩

end BinaryHeap

/-- `O(n)`. Build a `BinaryHeap` from an arbitrary array by running `BinaryHeap.mkHeap`
on it and wrapping the resulting array. -/
@[inline] def Array.toBinaryHeap (lt : α → α → Bool) (a : Array α) : BinaryHeap α lt :=
  ⟨(BinaryHeap.mkHeap lt a).val⟩

/-- `O(n log n)`. Sort an array using a `BinaryHeap`.

The heap is built with the flipped comparison `gt y x := lt x y`, and its root is popped
repeatedly and appended to the output. Because the comparison is flipped, the root of this
heap is the `lt`-least remaining element (assuming `lt` behaves like a strict order), so
the output is in ascending `lt` order. -/
@[inline] def Array.heapSort (a : Array α) (lt : α → α → Bool) : Array α :=
-- Flipped comparison: makes the max-heap behave as a min-heap with respect to `lt`.
let gt y x := lt x y
let rec loop (a : BinaryHeap α gt) (out : Array α) : Array α :=
match e: a.max with
-- Heap exhausted: `out` holds every element.
| none => out
| some x =>
-- Decreasing evidence for termination: popping shrinks a non-empty heap.
-- Non-emptiness follows from `e : a.max = some x`, which is contradictory for size `0`.
have : a.popMax.size < a.size := by
simp; refine Nat.sub_lt (Decidable.of_not_not fun h: ¬ 0 < a.1.size => ?_) Nat.zero_lt_one
simp [BinaryHeap.max, Array.get?, h] at e
loop a.popMax (out.push x)
loop (a.toBinaryHeap gt) #[]
-- Termination measure: the size of the heap argument of `loop`.
termination_by measure (·.2.2.1.size)
decreasing_by assumption
60 changes: 19 additions & 41 deletions Mathlib/Data/ByteArray.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Mathlib.Init.WF
import Mathlib.Data.Nat.Basic
import Mathlib.Data.Char
import Mathlib.Data.UInt
Expand Down Expand Up @@ -30,29 +31,19 @@ def toArray : ByteSlice → ByteArray
/-- Index into a byte slice. The `getOp` function allows the use of the `buf[i]` notation.

Reads the underlying array at position `self.off + idx`. No check against the slice's own
length is visible here. NOTE(review): out-of-range behaviour is whatever `ByteArray.get!`
does (presumably a panic with a default value) — confirm. -/
@[inline] def getOp (self : ByteSlice) (idx : Nat) : UInt8 := self.arr.get! (self.off + idx)

/-- Implementation of `forIn.loop`. -/
partial def forIn.loop.impl [Monad m] (f : UInt8 → β → m (ForInStep β))

/-- The inner loop of the `forIn` implementation for byte slices. -/
def forIn.loop [Monad m] (f : UInt8 → β → m (ForInStep β))
(arr : ByteArray) (off _end : Nat) (i : Nat) (b : β) : m β :=
if i < _end then do
if h : i < _end then do
match ← f (arr.get! i) b with
| ForInStep.done b => pure b
| ForInStep.yield b => impl f arr off _end (i+1) b
| ForInStep.yield b => have := Nat.Up.next h; loop f arr off _end (i+1) b
else b

set_option codegen false in
/-- The inner loop of the `forIn` implementation for byte slices. It is defined twice:
this version is the model, while `forIn.loop.impl` is the version used for code generation. -/
@[implementedBy forIn.loop.impl]
def forIn.loop [Monad m] (f : UInt8 → β → m (ForInStep β))
(arr : ByteArray) (off _end : Nat) (i : Nat) (b : β) : m β := do
(Nat.Up.WF _end).fix (x := i) (C := fun _ => ∀ b, m β) (b := b)
fun i IH b =>
if h : i < _end then do
let b ← f (arr.get! i) b
match b with
| ForInStep.done b => pure b
| ForInStep.yield b => IH (i+1) (Nat.Up.next h) b
else b
termination_by by
iterate 6 refine skipLeft fun _ => ?_
exact skipLeft fun _end => generalizeRight (Nat.upRel _end)
decreasing_by (iterate 7 apply PSigma.Lex.right); assumption

/-- `ForIn` instance for byte slices: the slice is destructured as `⟨arr, off, len⟩` and
`forIn.loop` is run on `arr` with `off + len` as the end index and `off` as the starting
index. -/
instance : ForIn m ByteSlice UInt8 :=
⟨fun ⟨arr, off, len⟩ b f => forIn.loop f arr off (off + len) off b⟩
Expand All @@ -66,31 +57,18 @@ def ByteSliceT.toSlice : ByteSliceT → ByteSlice
/-- Convert a byte array into a byte slice.

The slice is built from `arr`, offset `0`, and `arr.size` (the third component is
presumably the slice length — confirm against the `ByteSlice` definition), so it spans the
whole array. -/
def ByteArray.toSlice (arr : ByteArray) : ByteSlice := ⟨arr, 0, arr.size⟩

/-- Implementation of `String.toAsciiByteArray.loop`. -/
partial def String.toAsciiByteArray.loop.impl
(s : String) (out : ByteArray) (p : Pos) : ByteArray :=
if s.atEnd p then out else
let c := s.get p
impl s (out.push c.toUInt8) (s.next p)

set_option codegen false in
/-- The inner loop of `String.toAsciiByteArray`. Because it uses well founded recursion, we have
to write the compiler version of the implementation separately from the version used for
reasoning inside lean. -/
@[implementedBy String.toAsciiByteArray.loop.impl]
def String.toAsciiByteArray.loop (s : String) (out : ByteArray) (p : Pos) : ByteArray :=
(Nat.Up.WF (utf8ByteSize s)).fix (x := p) (C := fun _ => ∀ out, ByteArray) (out := out)
fun p IH i =>
if h : s.atEnd p then out else
let c := s.get p
IH (s.next p) (out := out.push c.toUInt8)
⟨Nat.lt_add_of_pos_right (String.csize_pos _),
Nat.lt_of_not_le (mt decide_eq_true h)⟩

/-- Convert a string of assumed-ASCII characters into a byte array.
(If any characters are non-ASCII they will be reduced modulo 256.) -/
def String.toAsciiByteArray (s : String) : ByteArray :=
String.toAsciiByteArray.loop s ByteArray.empty 0
let rec loop (p : Pos) (out : ByteArray) : ByteArray :=
if h : s.atEnd p then out else
let c := s.get p
have : Nat.Up (utf8ByteSize s) (next s p) p :=
⟨Nat.lt_add_of_pos_right (String.csize_pos _), Nat.lt_of_not_le (mt decide_eq_true h)⟩
loop (s.next p) (out.push c.toUInt8)
loop 0 ByteArray.empty
termination_by skipLeft fun s => generalizeRight $ Nat.upRel (utf8ByteSize s)
decreasing_by apply PSigma.Lex.right; assumption

/-- Convert a byte slice into a string. This does not handle non-ASCII characters correctly:
every byte will become a unicode character with codepoint < 256. -/
Expand Down
Loading