Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - feat: binary heaps #136

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Mathlib.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import Mathlib.Algebra.GroupWithZero.Defs
import Mathlib.Algebra.Ring.Basic
import Mathlib.Data.Array.Basic
import Mathlib.Data.Array.Defs
import Mathlib.Data.BinaryHeap
import Mathlib.Data.ByteArray
import Mathlib.Data.Char
import Mathlib.Data.Equiv.Basic
Expand Down Expand Up @@ -32,6 +33,7 @@ import Mathlib.Init.Function
import Mathlib.Init.Logic
import Mathlib.Init.Set
import Mathlib.Init.SetNotation
import Mathlib.Init.WF
import Mathlib.Lean.Expr
import Mathlib.Lean.LocalContext
import Mathlib.Logic.Basic
Expand Down
141 changes: 141 additions & 0 deletions Mathlib/Data/BinaryHeap.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/-
Copyright (c) 2021 Mario Carneiro. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mario Carneiro
-/
import Mathlib.Data.Fin.Basic

/-- A max-heap data structure, parameterised by the element type `α` and a Boolean
comparison `lt`. The structure stores only the backing array; no heap invariant is
recorded as a field. -/
structure BinaryHeap (α) (lt : α → α → Bool) where
/-- Backing array holding the elements of the heap. -/
arr : Array α

namespace BinaryHeap

/-- Core operation for binary heaps, expressed directly on arrays.
Given an array which is a max-heap, push item `i` down to restore the max-heap property.

* `lt` is the comparison used to decide which of two elements is larger.
* `a` is the backing array and `i` is the index of the element to push down.

Returns the resulting array bundled with a proof that its size equals `a.size`. -/
def heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
{a' : Array α // a'.size = a.size} :=
-- Candidate child positions of `i`: `2 * i + 1` and `2 * i + 2`.
let left := 2 * i.1 + 1
let right := left + 1
-- Both candidate positions are at least `i`; these bounds travel with the chosen index below.
have left_le : i ≤ left := Nat.le_trans
(by rw [Nat.succ_mul, Nat.one_mul]; exact Nat.le_add_left i i)
(Nat.le_add_right ..)
have right_le : i ≤ right := Nat.le_trans left_le (Nat.le_add_right ..)
have i_le : i ≤ i := Nat.le_refl _
-- First pick between `i` and `left`: `left` wins only if it is in bounds and
-- `lt (a.get i) (a.get left)` holds.
have j : {j : Fin a.size // i ≤ j} := if h : left < a.size then
if lt (a.get i) (a.get ⟨left, h⟩) then ⟨⟨left, h⟩, left_le⟩ else ⟨i, i_le⟩ else ⟨i, i_le⟩
-- Then compare the current pick against `right` in the same way.
have j := if h : right < a.size then
if lt (a.get j) (a.get ⟨right, h⟩) then ⟨⟨right, h⟩, right_le⟩ else j else j
-- If `i` itself was picked, the array is returned unchanged.
if h : i.1 = j then ⟨a, rfl⟩ else
-- Otherwise swap `i` with the picked position and continue from there.
let a' := a.swap i j
-- The picked index, re-expressed as an index into the swapped array (same size).
let j' := ⟨j, by rw [a.size_swap i j]; exact j.1.2⟩
-- Termination argument: `i ≤ j` and `i ≠ j`, so `size - index` strictly decreases.
have : a'.size - j < a.size - i := by
rw [a.size_swap i j]; exact Nat.sub_lt_sub_left i.2 <| Nat.lt_of_le_and_ne j.2 h
let ⟨a₂, h₂⟩ := heapifyDown lt a' j'
-- Chain the size equalities: result size = swapped size = original size.
⟨a₂, h₂.trans (a.size_swap i j)⟩
termination_by measure fun ⟨α, lt, a, i⟩ => a.size - i
decreasing_by assumption

/-- `heapifyDown` returns an array of the same size as its input. This is exactly the proof
component bundled with the result. -/
@[simp] theorem size_heapifyDown (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
    (heapifyDown lt a i).1.size = a.size :=
  (heapifyDown lt a i).property

/-- Core operation for binary heaps, expressed directly on arrays.
Construct a heap from an unsorted array, by heapifying all the elements.

Returns the resulting array bundled with a proof that its size equals `a.size`. -/
def mkHeap (lt : α → α → Bool) (a : Array α) : {a' : Array α // a'.size = a.size} :=
-- `loop i a h` calls `heapifyDown` at indices `i - 1`, `i - 2`, …, `0` in that order;
-- `h : i ≤ a.size` keeps every visited index in bounds.
let rec loop : (i : Nat) → (a : Array α) → i ≤ a.size → {a' : Array α // a'.size = a.size}
| 0, a, _ => ⟨a, rfl⟩
| i+1, a, h =>
let h := Nat.lt_of_succ_le h
let a' := heapifyDown lt a ⟨i, h⟩
-- The bound is transported across `heapifyDown` using its size-preservation proof.
let ⟨a₂, h₂⟩ := loop i a' ((heapifyDown ..).2.symm ▸ le_of_lt h)
⟨a₂, h₂.trans a'.2⟩
-- Start at `a.size / 2`, so the first index processed is `a.size / 2 - 1`.
loop (a.size / 2) a (Nat.div_le_self ..)

/-- `mkHeap` returns an array of the same size as its input. This is exactly the proof
component bundled with the result.

The statement takes no index argument. A `Fin a.size` binder would not occur anywhere in
the conclusion, so `simp` would have no way to instantiate it, and the lemma could not be
applied at all when `a` is empty (there is no element of `Fin 0`). -/
@[simp] theorem size_mkHeap (lt : α → α → Bool) (a : Array α) :
    (mkHeap lt a).1.size = a.size :=
  (mkHeap lt a).2

/-- Core operation for binary heaps, expressed directly on arrays.
Given an array which is a max-heap, push item `i` up to restore the max-heap property.

* `lt` is the comparison used to decide whether the element must move up.
* `a` is the backing array and `i` is the index of the element to push up.

Returns the resulting array bundled with a proof that its size equals `a.size`. -/
def heapifyUp (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
{a' : Array α // a'.size = a.size} :=
-- Index `0` has no position above it: return the array unchanged.
if i0 : i.1 = 0 then ⟨a, rfl⟩ else
-- The position above `i` is `(i - 1) / 2`, which is strictly smaller than `i`.
-- This fact is also what the termination proof picks up via `assumption`.
have : (i.1 - 1) / 2 < i := lt_of_le_of_lt (Nat.div_le_self ..) $
Nat.sub_lt (Nat.pos_iff_ne_zero.2 i0) Nat.one_pos
let j := ⟨(i.1 - 1) / 2, lt_trans this i.2⟩
-- Swap and continue upward only when `lt (a.get j) (a.get i)` holds.
if lt (a.get j) (a.get i) then
let a' := a.swap i j
let ⟨a₂, h₂⟩ := heapifyUp lt a' ⟨j.1, by rw [a.size_swap i j]; exact j.2⟩
-- Chain the size equalities: result size = swapped size = original size.
⟨a₂, h₂.trans (a.size_swap i j)⟩
else ⟨a, rfl⟩
-- The measure is the last packed argument, i.e. the index `i`.
termination_by measure (·.2.2.2)
decreasing_by assumption

/-- `heapifyUp` returns an array of the same size as its input. This is exactly the proof
component bundled with the result. -/
@[simp] theorem size_heapifyUp (lt : α → α → Bool) (a : Array α) (i : Fin a.size) :
    (heapifyUp lt a i).1.size = a.size :=
  (heapifyUp lt a i).property

/-- `O(1)`. The heap with no elements: its backing array is the empty array. -/
def empty (lt) : BinaryHeap α lt :=
  { arr := #[] }

-- The default inhabitant of a `BinaryHeap` is the empty heap.
instance (lt) : Inhabited (BinaryHeap α lt) where
  default := empty lt
-- Makes the empty-collection notation available for `BinaryHeap`, denoting the empty heap.
instance (lt) : EmptyCollection (BinaryHeap α lt) where
  emptyCollection := empty lt

/-- `O(1)`. The number of elements stored in a `BinaryHeap`, i.e. the size of its backing
array. -/
def size {lt} (self : BinaryHeap α lt) : Nat :=
  self.arr.size

/-- `O(log n)`. Insert an element into a `BinaryHeap`, preserving the max-heap property.
The element is pushed onto the end of the backing array and then moved up with
`heapifyUp`. -/
def insert {lt} (self : BinaryHeap α lt) (x : α) : BinaryHeap α lt where
-- `n` is the size before the push, which is the index at which `x` lands after it.
arr := let n := self.size;
heapifyUp lt (self.1.push x) ⟨n, by rw [Array.size_push]; apply Nat.lt_succ_self⟩

/-- Inserting an element increases the size of the heap by exactly one. -/
@[simp] theorem size_insert {lt} (self : BinaryHeap α lt) (x : α) :
(self.insert x).size = self.size + 1 := by
-- Unfold `insert` and `size`, and use that `heapifyUp` preserves the array size.
simp [insert, size, size_heapifyUp]

/-- `O(1)`. The maximum element of a `BinaryHeap`: the element at index `0` of the backing
array, or `none` when that index is out of bounds. -/
def max {lt} (self : BinaryHeap α lt) : Option α :=
  self.arr.get? 0

/-- Auxiliary for `popMax`. Removes the element at index `0` and returns the resulting heap
bundled with a proof that its size is `self.size - 1`. An empty heap is returned
unchanged. -/
def popMaxAux {lt} (self : BinaryHeap α lt) : {a' : BinaryHeap α lt // a'.size = self.size - 1} :=
match e: self.1.size with
-- No elements: nothing to remove.
| 0 => ⟨self, by simp [size, e]⟩
| n+1 =>
-- With size `n + 1`, both `0` and `n` (the last position) are valid indices.
have h0 := by rw [e]; apply Nat.succ_pos
have hn := by rw [e]; apply Nat.lt_succ_self
if hn0 : 0 < n then
-- At least two elements: swap the first and last elements, drop the last one
-- (the old first element), then push the new first element down.
let a := self.1.swap ⟨0, h0⟩ ⟨n, hn⟩ |>.pop
⟨⟨heapifyDown lt a ⟨0, by rwa [Array.size_pop, Array.size_swap, e, Nat.add_sub_cancel]⟩⟩,
by simp [size]⟩
else
-- Exactly one element: dropping it leaves an empty backing array.
⟨⟨self.1.pop⟩, by simp [size]⟩

/-- `O(log n)`. Remove the maximum element from a `BinaryHeap`.
This returns only the remaining heap, so call `max` beforehand to read the element being
removed. -/
def popMax {lt} (self : BinaryHeap α lt) : BinaryHeap α lt :=
  self.popMaxAux.val

/-- Removing the maximum decreases the size by one (truncated subtraction, so the size of
an empty heap stays `0`). This is exactly the proof component of `popMaxAux`. -/
@[simp] theorem size_popMax {lt} (self : BinaryHeap α lt) :
    self.popMax.size = self.size - 1 :=
  self.popMaxAux.property

/-- `O(log n)`. Return the maximum element of a `BinaryHeap` (if any) paired with the heap
that remains after removing it. -/
def extractMax {lt} (self : BinaryHeap α lt) : Option α × BinaryHeap α lt :=
  ⟨self.max, self.popMax⟩

end BinaryHeap

/-- `O(n)`. Convert an unsorted array to a `BinaryHeap` by running `BinaryHeap.mkHeap` on it
and keeping the resulting array. -/
def Array.toBinaryHeap (lt : α → α → Bool) (a : Array α) : BinaryHeap α lt :=
  { arr := (BinaryHeap.mkHeap lt a).val }

/-- `O(n log n)`. Sort an array using a `BinaryHeap`.

The heap is built over the flipped comparison `gt y x := lt x y`. Its top element is then
repeatedly removed and appended to the output array until the heap is empty. -/
@[specialize] def Array.heapSort (a : Array α) (lt : α → α → Bool) : Array α :=
-- Flipped comparison: the heap below is ordered by `gt`, not by `lt`.
let gt y x := lt x y
-- Drain the heap `a`, appending each removed top element to `out`.
let rec loop (a : BinaryHeap α gt) (out : Array α) : Array α :=
match e: a.max with
-- Heap exhausted: `out` holds every element that was removed.
| none => out
| some x =>
-- Termination: `max` returned `some`, so the backing array is non-empty and
-- `popMax` strictly decreases the size.
have : a.popMax.size < a.size := by
simp; refine Nat.sub_lt (Decidable.of_not_not fun h: ¬ 0 < a.1.size => ?_) Nat.zero_lt_one
simp [BinaryHeap.max, Array.get?, h] at e
loop a.popMax (out.push x)
loop (a.toBinaryHeap gt) #[]
-- The measure is the size of the heap argument of `loop`.
termination_by measure (·.2.2.1.size)
decreasing_by assumption
60 changes: 19 additions & 41 deletions Mathlib/Data/ByteArray.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import Mathlib.Init.WF
import Mathlib.Data.Nat.Basic
import Mathlib.Data.Char
import Mathlib.Data.UInt
Expand Down Expand Up @@ -30,29 +31,19 @@ def toArray : ByteSlice → ByteArray
/-- Index into a byte slice. The `getOp` function allows the use of the `buf[i]` notation.
The index is relative to the slice: it is added to the slice's offset before the underlying
array is read. -/
@[inline] def getOp (self : ByteSlice) (idx : Nat) : UInt8 :=
  let pos := self.off + idx
  self.arr.get! pos

/-- Implementation of `forIn.loop`. -/
partial def forIn.loop.impl [Monad m] (f : UInt8 → β → m (ForInStep β))

/-- The inner loop of the `forIn` implementation for byte slices. -/
def forIn.loop [Monad m] (f : UInt8 → β → m (ForInStep β))
(arr : ByteArray) (off _end : Nat) (i : Nat) (b : β) : m β :=
if i < _end then do
if h : i < _end then do
match ← f (arr.get! i) b with
| ForInStep.done b => pure b
| ForInStep.yield b => impl f arr off _end (i+1) b
| ForInStep.yield b => have := Nat.Up.next h; loop f arr off _end (i+1) b
else b

set_option codegen false in
/-- The inner loop of the `forIn` implementation for byte slices. It is defined twice:
this version is the model, while `forIn.loop.impl` is the version used for code generation. -/
@[implementedBy forIn.loop.impl]
def forIn.loop [Monad m] (f : UInt8 → β → m (ForInStep β))
(arr : ByteArray) (off _end : Nat) (i : Nat) (b : β) : m β := do
(Nat.Up.WF _end).fix (x := i) (C := fun _ => ∀ b, m β) (b := b)
fun i IH b =>
if h : i < _end then do
let b ← f (arr.get! i) b
match b with
| ForInStep.done b => pure b
| ForInStep.yield b => IH (i+1) (Nat.Up.next h) b
else b
termination_by by
iterate 6 refine skipLeft' fun _ => ?_
exact skipLeft' fun _end => generalizeRight (Nat.upRel _end)
decreasing_by (iterate 7 apply PSigma.Lex.right); assumption

-- Iterating a byte slice runs `forIn.loop` over the underlying array, starting at index
-- `off` and stopping at `off + len`.
instance : ForIn m ByteSlice UInt8 where
  forIn := fun ⟨arr, off, len⟩ b f => forIn.loop f arr off (off + len) off b
Expand All @@ -66,31 +57,18 @@ def ByteSliceT.toSlice : ByteSliceT → ByteSlice
/-- Convert a byte array into a byte slice.
The slice is built from the array itself, the value `0`, and `arr.size`
(presumably the offset and the length — confirm against the `ByteSlice` definition). -/
def ByteArray.toSlice (arr : ByteArray) : ByteSlice := ⟨arr, 0, arr.size⟩

/-- Implementation of `String.toAsciiByteArray.loop`. -/
partial def String.toAsciiByteArray.loop.impl
(s : String) (out : ByteArray) (p : Pos) : ByteArray :=
if s.atEnd p then out else
let c := s.get p
impl s (out.push c.toUInt8) (s.next p)

set_option codegen false in
/-- The inner loop of `String.toAsciiByteArray`. Because it uses well founded recursion, we have
to write the compiler version of the implementation separately from the version used for
reasoning inside lean. -/
@[implementedBy String.toAsciiByteArray.loop.impl]
def String.toAsciiByteArray.loop (s : String) (out : ByteArray) (p : Pos) : ByteArray :=
(Nat.Up.WF (utf8ByteSize s)).fix (x := p) (C := fun _ => ∀ out, ByteArray) (out := out)
fun p IH i =>
if h : s.atEnd p then out else
let c := s.get p
IH (s.next p) (out := out.push c.toUInt8)
⟨Nat.lt_add_of_pos_right (String.csize_pos _),
Nat.lt_of_not_le (mt decide_eq_true h)⟩

/-- Convert a string of assumed-ASCII characters into a byte array.
(If any characters are non-ASCII they will be reduced modulo 256.) -/
def String.toAsciiByteArray (s : String) : ByteArray :=
String.toAsciiByteArray.loop s ByteArray.empty 0
let rec loop (p : Pos) (out : ByteArray) : ByteArray :=
if h : s.atEnd p then out else
let c := s.get p
have : Nat.Up (utf8ByteSize s) (next s p) p :=
⟨Nat.lt_add_of_pos_right (String.csize_pos _), Nat.lt_of_not_le (mt decide_eq_true h)⟩
loop (s.next p) (out.push c.toUInt8)
loop 0 ByteArray.empty
termination_by skipLeft' fun s => generalizeRight $ Nat.upRel (utf8ByteSize s)
decreasing_by apply PSigma.Lex.right; assumption

/-- Convert a byte slice into a string. This does not handle non-ASCII characters correctly:
every byte will become a unicode character with codepoint < 256. -/
Expand Down
Loading