diff --git a/docs/overview.yaml b/docs/overview.yaml index 7094ca0693040..1952065e5a3c8 100644 --- a/docs/overview.yaml +++ b/docs/overview.yaml @@ -319,6 +319,7 @@ Data structures: set: 'set' finite set: 'finset' multiset: 'multiset' + ordered set: 'ordset' Maps: key-value map: 'alist' @@ -330,6 +331,7 @@ Data structures: Trees: tree: 'tree' red-black tree: 'rbtree' + size-balanced binary search tree: 'ordnode' W type: 'W_type' Logic and computation: diff --git a/src/algebra/order.lean b/src/algebra/order.lean index 1de4ca0400dc9..79b467fd1d61a 100644 --- a/src/algebra/order.lean +++ b/src/algebra/order.lean @@ -236,6 +236,28 @@ lemma le_iff_le_iff_lt_iff_lt {β} [linear_order α] [linear_order β] end decidable +/-- Like `cmp`, but uses a `≤` on the type instead of `<`. Given two elements +`x` and `y`, returns a three-way comparison result `ordering`. -/ +def cmp_le {α} [has_le α] [@decidable_rel α (≤)] (x y : α) : ordering := +if x ≤ y then + if y ≤ x then ordering.eq else ordering.lt +else ordering.gt + +theorem cmp_le_swap {α} [has_le α] [is_total α (≤)] [@decidable_rel α (≤)] (x y : α) : + (cmp_le x y).swap = cmp_le y x := +begin + by_cases xy : x ≤ y; by_cases yx : y ≤ x; simp [cmp_le, *, ordering.swap], + cases not_or xy yx (total_of _ _ _) +end + +theorem cmp_le_eq_cmp {α} [preorder α] [is_total α (≤)] + [@decidable_rel α (≤)] [@decidable_rel α (<)] (x y : α) : cmp_le x y = cmp x y := +begin + by_cases xy : x ≤ y; by_cases yx : y ≤ x; + simp [cmp_le, lt_iff_le_not_le, *, cmp, cmp_using], + cases not_or xy yx (total_of _ _ _) +end + namespace ordering /-- `compares o a b` means that `a` and `b` have the ordering relation diff --git a/src/data/nat/dist.lean b/src/data/nat/dist.lean index 7c2e8a276065c..d6b4046861e51 100644 --- a/src/data/nat/dist.lean +++ b/src/data/nat/dist.lean @@ -38,6 +38,18 @@ begin rw [dist.def, sub_eq_zero_of_le h, zero_add] end theorem dist_eq_sub_of_le_right {n m : ℕ} (h : m ≤ n) : dist n m = n - m := begin rw [dist_comm], apply dist_eq_sub_of_le h end +theorem dist_tri_left (n m : ℕ) : m ≤ dist n m + n := +le_trans (nat.le_sub_add _ _) (add_le_add_right (le_add_left _ _) _) + +theorem dist_tri_right (n m : ℕ) : m ≤ n + dist n m := +by rw add_comm; apply dist_tri_left + +theorem dist_tri_left' (n m : ℕ) : n ≤ dist n m + m := +by rw dist_comm; apply dist_tri_left + +theorem dist_tri_right' (n m : ℕ) : n ≤ m + dist n m := +by rw dist_comm; apply dist_tri_right + theorem dist_zero_right (n : ℕ) : dist n 0 = n := eq.trans (dist_eq_sub_of_le_right (zero_le n)) (nat.sub_zero n) diff --git a/src/data/nat/psub.lean b/src/data/nat/psub.lean index d675962351cf3..ea87106ac939a 100644 --- a/src/data/nat/psub.lean +++ b/src/data/nat/psub.lean @@ -56,7 +56,7 @@ theorem psub_eq_some {m : ℕ} : ∀ {n k}, psub m n = some k ↔ k + n = m simp [psub_eq_some, add_comm, add_left_comm, nat.succ_eq_add_one] end -theorem psub_eq_none (m n : ℕ) : psub m n = none ↔ m < n := +theorem psub_eq_none {m n : ℕ} : psub m n = none ↔ m < n := begin cases s : psub m n; simp [eq_comm], { show m < n, refine lt_of_not_ge (λ h, _), @@ -74,4 +74,11 @@ psub_eq_some.2 $ nat.sub_add_cancel h theorem psub_add (m n k) : psub m (n + k) = do x ← psub m n, psub x k := by induction k; simp [*, add_succ, bind_assoc] +/-- Same as `psub`, but with a more efficient implementation. -/ +@[inline] def psub' (m n : ℕ) : option ℕ := if n ≤ m then some (m - n) else none + +theorem psub'_eq_psub (m n) : psub' m n = psub m n := +by rw [psub']; split_ifs; + [exact (psub_eq_sub h).symm, exact (psub_eq_none.2 (not_le.1 h)).symm] + end nat diff --git a/src/data/ordmap/ordnode.lean b/src/data/ordmap/ordnode.lean new file mode 100644 index 0000000000000..a12102699cc80 --- /dev/null +++ b/src/data/ordmap/ordnode.lean @@ -0,0 +1,1169 @@ +/- +Copyright (c) 2017 Mario Carneiro. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Mario Carneiro +-/ +import data.list.defs +import data.nat.psub + +/-! +# Ordered sets + +This file defines a data structure for ordered sets, supporting a +variety of useful operations including insertion and deletion, +logarithmic time lookup, set operations, folds, +and conversion from lists. + +The `ordnode α` operations all assume that `α` has the structure of +a total preorder, meaning a `≤` operation that is + +* Transitive: `x ≤ y → y ≤ z → x ≤ z` +* Reflexive: `x ≤ x` +* Total: `x ≤ y ∨ y ≤ x` + +For example, in order to use this data structure as a map type, one +can store pairs `(k, v)` where `(k, v) ≤ (k', v')` is defined to mean +`k ≤ k'` (assuming that the key values are linearly ordered). + +Two values `x,y` are equivalent if `x ≤ y` and `y ≤ x`. An `ordnode α` +maintains the invariant that it never stores two equivalent nodes; +the insertion operation comes with two variants depending on whether +you want to keep the old value or the new value in case you insert a value +that is equivalent to one in the set. + +The operations in this file are not verified, in the sense that they provide +"raw operations" that work for programming purposes but the invariants +are not explicitly in the structure. See `ordset` for a verified version +of this data structure. + +## Main definitions + +* `ordnode α`: A set of values of type `α` + +## Implementation notes + +Based on weight balanced trees: + + * Stephen Adams, "Efficient sets: a balancing act", + Journal of Functional Programming 3(4):553-562, October 1993, + . + * J. Nievergelt and E.M. Reingold, + "Binary search trees of bounded balance", + SIAM journal of computing 2(1), March 1973. + +Ported from Haskell's `Data.Set`. + +## Tags + +ordered map, ordered set, data structure + +-/ + +universes u + +/-- An `ordnode α` is a finite set of values, represented as a tree. + The operations on this type maintain that the tree is balanced + and correctly stores subtree sizes at each level. -/ +inductive ordnode (α : Type u) : Type u +| nil {} : ordnode +| node (size : ℕ) (l : ordnode) (x : α) (r : ordnode) : ordnode + +namespace ordnode +variable {α : Type u} + +instance : has_emptyc (ordnode α) := ⟨nil⟩ +instance : inhabited (ordnode α) := ⟨nil⟩ + +/-- **Internal use only** + +The maximal relative difference between the sizes of +two trees, it corresponds with the `w` in Adams' paper. + +According to the Haskell comment, only `(delta, ratio)` settings +of `(3, 2)` and `(4, 2)` will work, and the proofs in +`ordset.lean` assume `delta := 3` and `ratio := 2`. -/ +@[inline] def delta := 3 + +/-- **Internal use only** + +The ratio between an outer and inner sibling of the +heavier subtree in an unbalanced setting. It determines +whether a double or single rotation should be performed +to restore balance. It is corresponds with the inverse +of `α` in Adam's article. -/ +@[inline] def ratio := 2 + +/-- O(1). Construct a singleton set containing value `a`. + + singleton 3 = {3} -/ +@[inline] protected def singleton (a : α) : ordnode α := node 1 nil a nil +local prefix `ι`:max := ordnode.singleton + +instance : has_singleton α (ordnode α) := ⟨ordnode.singleton⟩ + +/-- O(1). Get the size of the set. + + size {2, 1, 1, 4} = 3 -/ +@[inline, simp] def size : ordnode α → ℕ +| nil := 0 +| (node sz _ _ _) := sz + +/-- O(1). Is the set empty? + + empty ∅ = tt + empty {1, 2, 3} = ff -/ +@[inline] def empty : ordnode α → bool +| nil := tt +| (node _ _ _ _) := ff + +/-- **Internal use only**, because it violates the BST property on the original order. + +O(n). The dual of a tree is a tree with its left and right sides reversed throughout. +The dual of a valid BST is valid under the dual order. This is convenient for exploiting +symmetries in the algorithms. -/ +@[simp] def dual : ordnode α → ordnode α +| nil := nil +| (node s l x r) := node s (dual r) x (dual l) + +/-- **Internal use only** + +O(1). Construct a node with the correct size information, without rebalancing. -/ +@[inline, reducible] def node' (l : ordnode α) (x : α) (r : ordnode α) : ordnode α := +node (size l + size r + 1) l x r + +/-- Basic pretty printing for `ordnode α` that shows the structure of the tree. + + repr {3, 1, 2, 4} = ((∅ 1 ∅) 2 ((∅ 3 ∅) 4 ∅)) -/ +def repr {α} [has_repr α] : ordnode α → string +| nil := "∅" +| (node _ l x r) := "(" ++ repr l ++ " " ++ _root_.repr x ++ " " ++ repr r ++ ")" + +instance {α} [has_repr α] : has_repr (ordnode α) := ⟨repr⟩ + +/-- **Internal use only** + +O(1). Rebalance a tree which was previously balanced but has had its left +side grow by 1, or its right side shrink by 1. -/ +-- Note: The function has been written with tactics to avoid extra junk +def balance_l (l : ordnode α) (x : α) (r : ordnode α) : ordnode α := +by clean begin + cases id r with rs, + { cases id l with ls ll lx lr, + { exact ι x }, + { cases id ll with lls, + { cases lr with _ _ lrx, + { exact node 2 l x nil }, + { exact node 3 (ι lx) lrx (ι x) } }, + { cases id lr with lrs lrl lrx lrr, + { exact node 3 ll lx (ι x) }, + { exact if lrs < ratio * lls then + node (ls+1) ll lx (node (lrs+1) lr x nil) + else + node (ls+1) + (node (lls + size lrl + 1) ll lx lrl) lrx + (node (size lrr + 1) lrr x nil) } } } }, + { cases id l with ls ll lx lr, + { exact node (rs+1) nil x r }, + { refine if ls > delta * rs then _ else node (ls + rs + 1) l x r, + cases id ll with lls, {exact nil /-should not happen-/}, + cases id lr with lrs lrl lrx lrr, {exact nil /-should not happen-/}, + exact if lrs < ratio * lls then + node (ls + rs + 1) ll lx (node (rs + lrs + 1) lr x r) + else + node (ls + rs + 1) + (node (lls + size lrl + 1) ll lx lrl) lrx + (node (size lrr + rs + 1) lrr x r) } } +end + +/-- **Internal use only** + +O(1). Rebalance a tree which was previously balanced but has had its right +side grow by 1, or its left side shrink by 1. -/ +def balance_r (l : ordnode α) (x : α) (r : ordnode α) : ordnode α := +by clean begin + cases id l with ls, + { cases id r with rs rl rx rr, + { exact ι x }, + { cases id rr with rrs, + { cases rl with _ _ rlx, + { exact node 2 nil x r }, + { exact node 3 (ι x) rlx (ι rx) } }, + { cases id rl with rls rll rlx rlr, + { exact node 3 (ι x) rx rr }, + { exact if rls < ratio * rrs then + node (rs+1) (node (rls+1) nil x rl) rx rr + else + node (rs+1) + (node (size rll + 1) nil x rll) rlx + (node (size rlr + rrs + 1) rlr rx rr) } } } }, + { cases id r with rs rl rx rr, + { exact node (ls+1) l x nil }, + { refine if rs > delta * ls then _ else node (ls + rs + 1) l x r, + cases id rr with rrs, {exact nil /-should not happen-/}, + cases id rl with rls rll rlx rlr, {exact nil /-should not happen-/}, + exact if rls < ratio * rrs then + node (ls + rs + 1) (node (ls + rls + 1) l x rl) rx rr + else + node (ls + rs + 1) + (node (ls + size rll + 1) l x rll) rlx + (node (size rlr + rrs + 1) rlr rx rr) } } +end + +/-- **Internal use only** + +O(1). Rebalance a tree which was previously balanced but has had one side change +by at most 1. -/ +def balance (l : ordnode α) (x : α) (r : ordnode α) : ordnode α := +by clean begin + cases id l with ls ll lx lr, + { cases id r with rs rl rx rr, + { exact ι x }, + { cases id rl with rls rll rlx rlr, + { cases id rr, + { exact node 2 nil x r }, + { exact node 3 (ι x) rx rr } }, + { cases id rr with rrs, + { exact node 3 (ι x) rlx (ι rx) }, + { exact if rls < ratio * rrs then + node (rs+1) (node (rls+1) nil x rl) rx rr + else + node (rs+1) + (node (size rll + 1) nil x rll) rlx + (node (size rlr + rrs + 1) rlr rx rr) } } } }, + { cases id r with rs rl rx rr, + { cases id ll with lls, + { cases lr with _ _ lrx, + { exact node 2 l x nil }, + { exact node 3 (ι lx) lrx (ι x) } }, + { cases id lr with lrs lrl lrx lrr, + { exact node 3 ll lx (ι x) }, + { exact if lrs < ratio * lls then + node (ls+1) ll lx (node (lrs+1) lr x nil) + else + node (ls+1) + (node (lls + size lrl + 1) ll lx lrl) lrx + (node (size lrr + 1) lrr x nil) } } }, + { refine + if delta * ls < rs then _ else + if delta * rs < ls then _ else + node (ls+rs+1) l x r, + { cases id rl with rls rll rlx rlr, {exact nil /-should not happen-/}, + cases id rr with rrs, {exact nil /-should not happen-/}, + exact if rls < ratio * rrs then + node (ls+rs+1) (node (ls+rls+1) l x rl) rx rr + else + node (ls+rs+1) + (node (ls + size rll + 1) l x rll) rlx + (node (size rlr + rrs + 1) rlr rx rr) }, + { cases id ll with lls, {exact nil /-should not happen-/}, + cases id lr with lrs lrl lrx lrr, {exact nil /-should not happen-/}, + exact if lrs < ratio * lls then + node (ls+rs+1) ll lx (node (lrs+rs+1) lr x r) + else + node (ls+rs+1) + (node (lls + size lrl + 1) ll lx lrl) lrx + (node (size lrr + rs + 1) lrr x r) } } } +end + +/-- O(n). Does every element of the map satisfy property `P`? + + all (λ x, x < 5) {1, 2, 3} = true + all (λ x, x < 5) {1, 2, 3, 5} = false -/ +def all (P : α → Prop) : ordnode α → Prop +| nil := true +| (node _ l x r) := all l ∧ P x ∧ all r + +instance all.decidable {P : α → Prop} [decidable_pred P] (t) : decidable (all P t) := +by induction t; dunfold all; resetI; apply_instance + +/-- O(n). Does any element of the map satisfy property `P`? + + any (λ x, x < 2) {1, 2, 3} = true + any (λ x, x < 2) {2, 3, 5} = false -/ +def any (P : α → Prop) : ordnode α → Prop +| nil := false +| (node _ l x r) := any l ∨ P x ∨ any r + +instance any.decidable {P : α → Prop} [decidable_pred P] (t) : decidable (any P t) := +by induction t; dunfold any; resetI; apply_instance + +/-- O(n). Exact membership in the set. This is useful primarily for stating +correctness properties; use `∈` for a version that actually uses the BST property +of the tree. + + emem 2 {1, 2, 3} = true + emem 4 {1, 2, 3} = false -/ +def emem (x : α) : ordnode α → Prop := any (eq x) + +instance emem.decidable [decidable_eq α] (x : α) : ∀ t, decidable (emem x t) := +any.decidable + +/-- O(n). Approximate membership in the set, that is, whether some element in the +set is equivalent to this one in the preorder. This is useful primarily for stating +correctness properties; use `∈` for a version that actually uses the BST property +of the tree. + + amem 2 {1, 2, 3} = true + amem 4 {1, 2, 3} = false + +To see the difference with `emem`, we need a preorder that is not a partial order. +For example, suppose we compare pairs of numbers using only their first coordinate. Then: + + emem (0, 1) {(0, 0), (1, 2)} = false + amem (0, 1) {(0, 0), (1, 2)} = true + (0, 1) ∈ {(0, 0), (1, 2)} = true + +The `∈` relation is equivalent to `amem` as long as the `ordnode` is well formed, +and should always be used instead of `amem`. -/ +def amem [has_le α] (x : α) : ordnode α → Prop := any (λ y, x ≤ y ∧ y ≤ x) + +instance amem.decidable [has_le α] [@decidable_rel α (≤)] (x : α) : ∀ t, decidable (amem x t) := +any.decidable + +/-- O(log n). Return the minimum element of the tree, or the provided default value. + + find_min' 37 {1, 2, 3} = 1 + find_min' 37 ∅ = 37 -/ +def find_min' : ordnode α → α → α +| nil x := x +| (node _ l x r) _ := find_min' l x + +/-- O(log n). Return the minimum element of the tree, if it exists. + + find_min {1, 2, 3} = some 1 + find_min ∅ = none -/ +def find_min : ordnode α → option α +| nil := none +| (node _ l x r) := some (find_min' l x) + +/-- O(log n). Return the maximum element of the tree, or the provided default value. + + find_max' 37 {1, 2, 3} = 3 + find_max' 37 ∅ = 37 -/ +def find_max' : α → ordnode α → α +| x nil := x +| _ (node _ l x r) := find_max' x r + +/-- O(log n). Return the maximum element of the tree, if it exists. + + find_max {1, 2, 3} = some 3 + find_max ∅ = none -/ +def find_max : ordnode α → option α +| nil := none +| (node _ l x r) := some (find_max' x r) + +/-- O(log n). Remove the minimum element from the tree, or do nothing if it is already empty. + + erase_min {1, 2, 3} = {2, 3} + erase_min ∅ = ∅ -/ +def erase_min : ordnode α → ordnode α +| nil := nil +| (node _ nil x r) := r +| (node _ l x r) := balance_r (erase_min l) x r + +/-- O(log n). Remove the maximum element from the tree, or do nothing if it is already empty. + + erase_max {1, 2, 3} = {1, 2} + erase_max ∅ = ∅ -/ +def erase_max : ordnode α → ordnode α +| nil := nil +| (node _ l x nil) := l +| (node _ l x r) := balance_l l x (erase_max r) + +/-- **Internal use only**, because it requires a balancing constraint on the inputs. + +O(log n). Extract and remove the minimum element from a nonempty tree. -/ +def split_min' : ordnode α → α → ordnode α → α × ordnode α +| nil x r := (x, r) +| (node _ ll lx lr) x r := + let (xm, l') := split_min' ll lx lr in + (xm, balance_r l' x r) + +/-- O(log n). Extract and remove the minimum element from the tree, if it exists. + + split_min {1, 2, 3} = some (1, {2, 3}) + split_min ∅ = none -/ +def split_min : ordnode α → option (α × ordnode α) +| nil := none +| (node _ l x r) := split_min' l x r + +/-- **Internal use only**, because it requires a balancing constraint on the inputs. + +O(log n). Extract and remove the maximum element from a nonempty tree. -/ +def split_max' : ordnode α → α → ordnode α → ordnode α × α +| l x nil := (l, x) +| l x (node _ rl rx rr) := + let (r', xm) := split_max' rl rx rr in + (balance_l l x r', xm) + +/-- O(log n). Extract and remove the maximum element from the tree, if it exists. + + split_max {1, 2, 3} = some ({1, 2}, 3) + split_max ∅ = none -/ +def split_max : ordnode α → option (ordnode α × α) +| nil := none +| (node _ x l r) := split_max' x l r + +/-- **Internal use only** + +O(log(m+n)). Concatenate two trees that are balanced and ordered with respect to each other. -/ +def glue : ordnode α → ordnode α → ordnode α +| nil r := r +| l@(node _ _ _ _) nil := l +| l@(node sl ll lx lr) r@(node sr rl rx rr) := + if sl > sr then + let (l', m) := split_max' ll lx lr in balance_r l' m r + else + let (m, r') := split_min' rl rx rr in balance_l l m r' + +/-- O(log(m+n)). Concatenate two trees that are ordered with respect to each other. + + merge {1, 2} {3, 4} = {1, 2, 3, 4} + merge {3, 4} {1, 2} = precondition violation -/ +def merge (l : ordnode α) : ordnode α → ordnode α := +ordnode.rec_on l (λ r, r) $ λ ls ll lx lr IHll IHlr r, +ordnode.rec_on r (node ls ll lx lr) $ λ rs rl rx rr IHrl IHrr, +if delta * ls < rs then + balance_l IHrl rx rr +else if delta * rs < ls then + balance_r ll lx (IHlr $ node rs rl rx rr) +else glue (node ls ll lx lr) (node rs rl rx rr) + +/-- O(log n). Insert an element above all the others, without any comparisons. +(Assumes that the element is in fact above all the others). + + insert_max {1, 2} 4 = {1, 2, 4} + insert_max {1, 2} 0 = precondition violation -/ +def insert_max : ordnode α → α → ordnode α +| nil x := ι x +| (node _ l y r) x := balance_r l y (insert_max r x) + +/-- O(log n). Insert an element below all the others, without any comparisons. +(Assumes that the element is in fact below all the others). + + insert_min {1, 2} 0 = {0, 1, 2} + insert_min {1, 2} 4 = precondition violation -/ +def insert_min (x : α) : ordnode α → ordnode α +| nil := ι x +| (node _ l y r) := balance_r (insert_min l) y r + +/-- O(log(m+n)). Build a tree from an element between two trees, without any +assumption on the relative sizes. + + link {1, 2} 4 {5, 6} = {1, 2, 4, 5, 6} + link {1, 3} 2 {5} = precondition violation -/ +def link (l : ordnode α) (x : α) : ordnode α → ordnode α := +ordnode.rec_on l (insert_min x) $ λ ls ll lx lr IHll IHlr r, +ordnode.rec_on r (insert_max l x) $ λ rs rl rx rr IHrl IHrr, +if delta * ls < rs then + balance_l IHrl rx rr +else if delta * rs < ls then + balance_r ll lx (IHlr r) +else node' l x r + +/-- O(n). Filter the elements of a tree satisfying a predicate. + + filter (λ x, x < 3) {1, 2, 4} = {1, 2} + filter (λ x, x > 5) {1, 2, 4} = ∅ -/ +def filter (p : α → Prop) [decidable_pred p] : ordnode α → ordnode α +| nil := nil +| (node _ l x r) := + if p x then link (filter l) x (filter r) + else merge (filter l) (filter r) + +/-- O(n). Split the elements of a tree into those satisfying, and not satisfying, a predicate. + + partition (λ x, x < 3) {1, 2, 4} = ({1, 2}, {3}) -/ +def partition (p : α → Prop) [decidable_pred p] : ordnode α → ordnode α × ordnode α +| nil := (nil, nil) +| (node _ l x r) := + let (l₁, l₂) := partition l, (r₁, r₂) := partition r in + if p x then (link l₁ x r₁, merge l₂ r₂) + else (merge l₁ r₁, link l₂ x r₂) + +/-- O(n). Map a function across a tree, without changing the structure. Only valid when +the function is strictly monotonic, i.e. `x < y → f x < f y`. + + partition (λ x, x + 2) {1, 2, 4} = {2, 3, 6} + partition (λ x : ℕ, x - 2) {1, 2, 4} = precondition violation -/ +def map {β} (f : α → β) : ordnode α → ordnode β +| nil := nil +| (node s l x r) := node s (map l) (f x) (map r) + +/-- O(n). Fold a function across the structure of a tree. + + fold z f {1, 2, 4} = f (f z 1 z) 2 (f z 4 z) + +The exact structure of function applications depends on the tree and so +is unspecified. -/ +def fold {β} (z : β) (f : β → α → β → β) : ordnode α → β +| nil := z +| (node _ l x r) := f (fold l) x (fold r) + +/-- O(n). Fold a function from left to right (in increasing order) across the tree. + + foldl f z {1, 2, 4} = f (f (f z 1) 2) 4 -/ +def foldl {β} (f : β → α → β) : β → ordnode α → β +| z nil := z +| z (node _ l x r) := foldl (f (foldl z l) x) r + +/-- O(n). Fold a function from right to left (in decreasing order) across the tree. + + foldl f {1, 2, 4} z = f 1 (f 2 (f 4 z)) -/ +def foldr {β} (f : α → β → β) : ordnode α → β → β +| nil z := z +| (node _ l x r) z := foldr l (f x (foldr r z)) + +/-- O(n). Build a list of elements in ascending order from the tree. + + to_list {1, 2, 4} = [1, 2, 4] + to_list {2, 1, 1, 4} = [1, 2, 4] -/ +def to_list (t : ordnode α) : list α := foldr list.cons t [] + +/-- O(n). Build a list of elements in descending order from the tree. + + to_rev_list {1, 2, 4} = [4, 2, 1] + to_rev_list {2, 1, 1, 4} = [4, 2, 1] -/ +def to_rev_list (t : ordnode α) : list α := foldl (flip list.cons) [] t + +instance [has_to_string α] : has_to_string (ordnode α) := +⟨λ t, "{" ++ string.intercalate ", " (t.to_list.map to_string) ++ "}"⟩ + +meta instance [has_to_format α] : has_to_format (ordnode α) := +⟨λ t, "{" ++ format.intercalate ", " (t.to_list.map to_fmt) ++ "}"⟩ + +/-- O(n). True if the trees have the same elements, ignoring structural differences. + + equiv {1, 2, 4} {2, 1, 1, 4} = true + equiv {1, 2, 4} {1, 2, 3} = false -/ +def equiv (t₁ t₂ : ordnode α) : Prop := +t₁.size = t₂.size ∧ t₁.to_list = t₂.to_list + +instance [decidable_eq α] : decidable_rel (@equiv α) := λ t₁ t₂, and.decidable + +/-- O(2^n). Constructs the powerset of a given set, that is, the set of all subsets. + + powerset {1, 2, 3} = {∅, {1}, {2}, {3}, {1,2}, {1,3}, {2,3}, {1,2,3}} -/ +def powerset (t : ordnode α) : ordnode (ordnode α) := +insert_min nil $ foldr (λ x ts, glue (insert_min (ι x) (map (insert_min x) ts)) ts) t nil + +/-- O(m*n). The cartesian product of two sets: `(a, b) ∈ s.prod t` iff `a ∈ s` and `b ∈ t`. + + prod {1, 2} {2, 3} = {(1, 2), (1, 3), (2, 2), (2, 3)} -/ +protected def prod {β} (t₁ : ordnode α) (t₂ : ordnode β) : ordnode (α × β) := +fold nil (λ s₁ a s₂, merge s₁ $ merge (map (prod.mk a) t₂) s₂) t₁ + +/-- O(m+n). Build a set on the disjoint union by combining sets on the factors. +`inl a ∈ s.copair t` iff `a ∈ s`, and `inr b ∈ s.copair t` iff `b ∈ t`. + + copair {1, 2} {2, 3} = {inl 1, inl 2, inr 2, inr 3} -/ +protected def copair {β} (t₁ : ordnode α) (t₂ : ordnode β) : ordnode (α ⊕ β) := +merge (map sum.inl t₁) (map sum.inr t₂) + +/-- O(n). Map a partial function across a set. The result depends on a proof +that the function is defined on all members of the set. + + pmap (fin.mk : ∀ n, n < 4 → fin 4) {1, 2} H = {(1 : fin 4), (2 : fin 4)} -/ +def pmap {P : α → Prop} {β} (f : ∀ a, P a → β) : ∀ t : ordnode α, all P t → ordnode β +| nil _ := nil +| (node s l x r) ⟨hl, hx, hr⟩ := node s (pmap l hl) (f x hx) (pmap r hr) + +/-- O(n). "Attach" the information that every element of `t` satisfies property +P to these elements inside the set, producing a set in the subtype. + + attach' (λ x, x < 4) {1, 2} H = ({1, 2} : ordnode {x // x<4}) -/ +def attach' {P : α → Prop} : ∀ t, all P t → ordnode {a // P a} := pmap subtype.mk + +/-- O(log n). Get the `i`th element of the set, by its index from left to right. + + nth {a, b, c, d} 2 = some c + nth {a, b, c, d} 5 = none -/ +def nth : ordnode α → ℕ → option α +| nil i := none +| (node _ l x r) i := + match nat.psub' i (size l) with + | none := nth l i + | some 0 := some x + | some (j + 1) := nth r j + end + +/-- O(log n). Remove the `i`th element of the set, by its index from left to right. + + remove_nth {a, b, c, d} 2 = {a, b, d} + remove_nth {a, b, c, d} 5 = {a, b, c, d} -/ +def remove_nth : ordnode α → ℕ → ordnode α +| nil i := nil +| (node _ l x r) i := + match nat.psub' i (size l) with + | none := balance_r (remove_nth l i) x r + | some 0 := glue l r + | some (j + 1) := balance_l l x (remove_nth r j) + end + +/-- Auxiliary definition for `take`. (Can also be used in lieu of `take` if you know the +index is within the range of the data structure.) + + take_aux {a, b, c, d} 2 = {a, b} + take_aux {a, b, c, d} 5 = {a, b, c, d} -/ +def take_aux : ordnode α → ℕ → ordnode α +| nil i := nil +| (node _ l x r) i := + if i = 0 then nil else + match nat.psub' i (size l) with + | none := take_aux l i + | some 0 := l + | some (j + 1) := link l x (take_aux r j) + end + +/-- O(log n). Get the first `i` elements of the set, counted from the left. + + take 2 {a, b, c, d} = {a, b} + take 5 {a, b, c, d} = {a, b, c, d} -/ +def take (i : ℕ) (t : ordnode α) : ordnode α := +if size t ≤ i then t else take_aux t i + +/-- Auxiliary definition for `drop`. (Can also be used in lieu of `drop` if you know the +index is within the range of the data structure.) + + drop_aux {a, b, c, d} 2 = {c, d} + drop_aux {a, b, c, d} 5 = ∅ -/ +def drop_aux : ordnode α → ℕ → ordnode α +| nil i := nil +| t@(node _ l x r) i := + if i = 0 then t else + match nat.psub' i (size l) with + | none := link (drop_aux l i) x r + | some 0 := insert_min x r + | some (j + 1) := drop_aux r j + end + +/-- O(log n). Remove the first `i` elements of the set, counted from the left. + + drop 2 {a, b, c, d} = {c, d} + drop 5 {a, b, c, d} = ∅ -/ +def drop (i : ℕ) (t : ordnode α) : ordnode α := +if size t ≤ i then nil else drop_aux t i + +/-- Auxiliary definition for `split_at`. (Can also be used in lieu of `split_at` if you know the +index is within the range of the data structure.) + + split_at_aux {a, b, c, d} 2 = ({a, b}, {c, d}) + split_at_aux {a, b, c, d} 5 = ({a, b, c, d}, ∅) -/ +def split_at_aux : ordnode α → ℕ → ordnode α × ordnode α +| nil i := (nil, nil) +| t@(node _ l x r) i := + if i = 0 then (nil, t) else + match nat.psub' i (size l) with + | none := let (l₁, l₂) := split_at_aux l i in (l₁, link l₂ x r) + | some 0 := (glue l r, insert_min x r) + | some (j + 1) := let (r₁, r₂) := split_at_aux r j in (link l x r₁, r₂) + end + +/-- O(log n). Split a set at the `i`th element, getting the first `i` and everything else. + + split_at 2 {a, b, c, d} = ({a, b}, {c, d}) + split_at 5 {a, b, c, d} = ({a, b, c, d}, ∅) -/ +def split_at (i : ℕ) (t : ordnode α) : ordnode α × ordnode α := +if size t ≤ i then (t, nil) else split_at_aux t i + +/-- O(log n). Get an initial segment of the set that satisfies the predicate `p`. +`p` is required to be antitone, that is, `x < y → p y → p x`. + + take_while (λ x, x < 4) {1, 2, 3, 4, 5} = {1, 2, 3} + take_while (λ x, x > 4) {1, 2, 3, 4, 5} = precondition violation -/ +def take_while (p : α → Prop) [decidable_pred p] : ordnode α → ordnode α +| nil := nil +| (node _ l x r) := if p x then link l x (take_while r) else take_while l + +/-- O(log n). Remove an initial segment of the set that satisfies the predicate `p`. +`p` is required to be antitone, that is, `x < y → p y → p x`. + + drop_while (λ x, x < 4) {1, 2, 3, 4, 5} = {4, 5} + drop_while (λ x, x > 4) {1, 2, 3, 4, 5} = precondition violation -/ +def drop_while (p : α → Prop) [decidable_pred p] : ordnode α → ordnode α +| nil := nil +| (node _ l x r) := if p x then drop_while r else link (drop_while l) x r + +/-- O(log n). Split the set into those satisfying and not satisfying the predicate `p`. +`p` is required to be antitone, that is, `x < y → p y → p x`. + + span (λ x, x < 4) {1, 2, 3, 4, 5} = ({1, 2, 3}, {4, 5}) + span (λ x, x > 4) {1, 2, 3, 4, 5} = precondition violation -/ +def span (p : α → Prop) [decidable_pred p] : ordnode α → ordnode α × ordnode α +| nil := (nil, nil) +| (node _ l x r) := + if p x then let (r₁, r₂) := span r in (link l x r₁, r₂) + else let (l₁, l₂) := span l in (l₁, link l₂ x r) + +/-- Auxiliary definition for `of_asc_list`. + +**Note:** This function is defined by well founded recursion, so it will probably not compute +in the kernel, meaning that you probably can't prove things like +`of_asc_list [1, 2, 3] = {1, 2, 3}` by `rfl`. +This implementation is optimized for VM evaluation. -/ +def of_asc_list_aux₁ : ∀ l : list α, ℕ → ordnode α × {l' : list α // l'.length ≤ l.length} +| [] := λ s, (nil, ⟨[], le_refl _⟩) +| (x :: xs) := λ s, + if s = 1 then (ι x, ⟨xs, nat.le_succ _⟩) else + have _, from nat.lt_succ_self xs.length, + match of_asc_list_aux₁ xs (s.shiftl 1) with + | (t, ⟨[], h⟩) := (t, ⟨[], nat.zero_le _⟩) + | (l, ⟨y :: ys, h⟩) := + have _, from nat.le_succ_of_le h, + let (r, ⟨zs, h'⟩) := of_asc_list_aux₁ ys (s.shiftl 1) in + (link l y r, ⟨zs, le_trans h' (le_of_lt this)⟩) + end +using_well_founded +{ rel_tac := λ _ _, `[exact ⟨_, measure_wf list.length⟩], + dec_tac := `[assumption] } + +/-- Auxiliary definition for `of_asc_list`. -/ +def of_asc_list_aux₂ : list α → ordnode α → ℕ → ordnode α +| [] := λ t s, t +| (x :: xs) := λ l s, + have _, from nat.lt_succ_self xs.length, + match of_asc_list_aux₁ xs s with + | (r, ⟨ys, h⟩) := + have _, from nat.lt_succ_of_le h, + of_asc_list_aux₂ ys (link l x r) (s.shiftl 1) + end +using_well_founded +{ rel_tac := λ _ _, `[exact ⟨_, measure_wf list.length⟩], + dec_tac := `[assumption] } + +/-- O(n). Build a set from a list which is already sorted. Performs no comparisons. + + of_asc_list [1, 2, 3] = {1, 2, 3} + of_asc_list [3, 2, 1] = precondition violation -/ +def of_asc_list : list α → ordnode α +| [] := nil +| (x :: xs) := of_asc_list_aux₂ xs (ι x) 1 + +section +variables [has_le α] [@decidable_rel α (≤)] + +/-- O(log n). Does the set (approximately) contain the element `x`? That is, +is there an element that is equivalent to `x` in the order? + + 1 ∈ {1, 2, 3} = true + 4 ∈ {1, 2, 3} = false + +Using a preorder on `ℕ × ℕ` that only compares the first coordinate: + + (1, 1) ∈ {(0, 1), (1, 2)} = true + (3, 1) ∈ {(0, 1), (1, 2)} = false -/ +def mem (x : α) : ordnode α → bool +| nil := ff +| (node _ l y r) := + match cmp_le x y with + | ordering.lt := mem l + | ordering.eq := tt + | ordering.gt := mem r + end + +/-- O(log n). Retrieve an element in the set that is equivalent to `x` in the order, +if it exists. + + find 1 {1, 2, 3} = some 1 + find 4 {1, 2, 3} = none + +Using a preorder on `ℕ × ℕ` that only compares the first coordinate: + + find (1, 1) {(0, 1), (1, 2)} = some (1, 2) + find (3, 1) {(0, 1), (1, 2)} = none -/ +def find (x : α) : ordnode α → option α +| nil := none +| (node _ l y r) := + match cmp_le x y with + | ordering.lt := find l + | ordering.eq := some y + | ordering.gt := find r + end + +instance : has_mem α (ordnode α) := ⟨λ x t, t.mem x⟩ + +instance mem.decidable (x : α) (t : ordnode α) : decidable (x ∈ t) := +bool.decidable_eq _ _ + +/-- O(log n). Insert an element into the set, preserving balance and the BST property. +If an equivalent element is already in the set, the function `f` is used to generate +the element to insert (being passed the current value in the set). + + insert_with f 0 {1, 2, 3} = {0, 1, 2, 3} + insert_with f 1 {1, 2, 3} = {f 1, 2, 3} + +Using a preorder on `ℕ × ℕ` that only compares the first coordinate: + + insert_with f (1, 1) {(0, 1), (1, 2)} = {(0, 1), f (1, 2)} + insert_with f (3, 1) {(0, 1), (1, 2)} = {(0, 1), (1, 2), (3, 1)} -/ +def insert_with (f : α → α) (x : α) : ordnode α → ordnode α +| nil := ι x +| t@(node sz l y r) := + match cmp_le x y with + | ordering.lt := balance_l (insert_with l) y r + | ordering.eq := node sz l (f y) r + | ordering.gt := balance_r l y (insert_with r) + end + +/-- O(log n). Modify an element in the set with the given function, +doing nothing if the key is not found. +Note that the element returned by `f` must be equivalent to `x`. + + adjust_with f 0 {1, 2, 3} = {1, 2, 3} + adjust_with f 1 {1, 2, 3} = {f 1, 2, 3} + +Using a preorder on `ℕ × ℕ` that only compares the first coordinate: + + adjust_with f (1, 1) {(0, 1), (1, 2)} = {(0, 1), f (1, 2)} + adjust_with f (3, 1) {(0, 1), (1, 2)} = {(0, 1), (1, 2)} -/ +def adjust_with (f : α → α) (x : α) : ordnode α → ordnode α +| nil := nil +| t@(node sz l y r) := + match cmp_le x y with + | ordering.lt := node sz (adjust_with l) y r + | ordering.eq := node sz l (f y) r + | ordering.gt := node sz l y (adjust_with r) + end + +/-- O(log n). Modify an element in the set with the given function, +doing nothing if the key is not found. +Note that the element returned by `f` must be equivalent to `x`. + + update_with f 0 {1, 2, 3} = {1, 2, 3} + update_with f 1 {1, 2, 3} = {2, 3} if f 1 = none + = {a, 2, 3} if f 1 = some a -/ +def update_with (f : α → option α) (x : α) : ordnode α → ordnode α +| nil := nil +| t@(node sz l y r) := + match cmp_le x y with + | ordering.lt := balance_r (update_with l) y r + | ordering.eq := + match f y with + | none := glue l r + | some a := node sz l a r + end + | ordering.gt := balance_l l y (update_with r) + end + +/-- O(log n). Modify an element in the set with the given function, +doing nothing if the key is not found. +Note that the element returned by `f` must be equivalent to `x`. + + alter f 0 {1, 2, 3} = {1, 2, 3} if f none = none + = {a, 1, 2, 3} if f none = some a + alter f 1 {1, 2, 3} = {2, 3} if f 1 = none + = {a, 2, 3} if f 1 = some a -/ +def alter (f : option α → option α) (x : α) : ordnode α → ordnode α +| nil := option.rec_on (f none) nil ordnode.singleton +| t@(node sz l y r) := + match cmp_le x y with + | ordering.lt := balance (alter l) y r + | ordering.eq := + match f (some y) with + | none := glue l r + | some a := node sz l a r + end + | ordering.gt := balance l y (alter r) + end + +/-- O(log n). Insert an element into the set, preserving balance and the BST property. +If an equivalent element is already in the set, this replaces it. + + insert 1 {1, 2, 3} = {1, 2, 3} + insert 4 {1, 2, 3} = {1, 2, 3, 4} + +Using a preorder on `ℕ × ℕ` that only compares the first coordinate: + + insert (1, 1) {(0, 1), (1, 2)} = {(0, 1), (1, 1)} + insert (3, 1) {(0, 1), (1, 2)} = {(0, 1), (1, 2), (3, 1)} -/ +protected def insert (x : α) : ordnode α → ordnode α +| nil := ι x +| (node sz l y r) := + match cmp_le x y with + | ordering.lt := balance_l (insert l) y r + | ordering.eq := node sz l x r + | ordering.gt := balance_r l y (insert r) + end + +instance : has_insert α (ordnode α) := ⟨ordnode.insert⟩ + +/-- O(log n). Insert an element into the set, preserving balance and the BST property. +If an equivalent element is already in the set, the set is returned as is. + + insert' 1 {1, 2, 3} = {1, 2, 3} + insert' 4 {1, 2, 3} = {1, 2, 3, 4} + +Using a preorder on `ℕ × ℕ` that only compares the first coordinate: + + insert' (1, 1) {(0, 1), (1, 2)} = {(0, 1), (1, 2)} + insert' (3, 1) {(0, 1), (1, 2)} = {(0, 1), (1, 2), (3, 1)} -/ +def insert' (x : α) : ordnode α → ordnode α +| nil := ι x +| t@(node sz l y r) := + match cmp_le x y with + | ordering.lt := balance_l (insert' l) y r + | ordering.eq := t + | ordering.gt := balance_r l y (insert' r) + end + +/-- O(log n). Split the tree into those smaller than `x` and those greater than it. +If an element equivalent to `x` is in the set, it is discarded. + + split 2 {1, 2, 4} = ({1}, {4}) + split 3 {1, 2, 4} = ({1, 2}, {4}) + split 4 {1, 2, 4} = ({1, 2}, ∅) + +Using a preorder on `ℕ × ℕ` that only compares the first coordinate: + + split (1, 1) {(0, 1), (1, 2)} = ({(0, 1)}, ∅) + split (3, 1) {(0, 1), (1, 2)} = ({(0, 1), (1, 2)}, ∅) -/ +def split (x : α) : ordnode α → ordnode α × ordnode α +| nil := (nil, nil) +| (node sz l y r) := + match cmp_le x y with + | ordering.lt := let (lt, gt) := split l in (lt, link gt y r) + | ordering.eq := (l, r) + | ordering.gt := let (lt, gt) := split r in (link l y lt, gt) + end + +/-- O(log n). Split the tree into those smaller than `x` and those greater than it, +plus an element equivalent to `x`, if it exists. + + split 2 {1, 2, 4} = ({1}, some 2, {4}) + split 3 {1, 2, 4} = ({1, 2}, none, {4}) + split 4 {1, 2, 4} = ({1, 2}, some 4, ∅) + +Using a preorder on `ℕ × ℕ` that only compares the first coordinate: + + split (1, 1) {(0, 1), (1, 2)} = ({(0, 1)}, some (1, 2), ∅) + split (3, 1) {(0, 1), (1, 2)} = ({(0, 1), (1, 2)}, none, ∅) -/ +def split3 (x : α) : ordnode α → ordnode α × option α × ordnode α +| nil := (nil, none, nil) +| (node sz l y r) := + match cmp_le x y with + | ordering.lt := let (lt, f, gt) := split3 l in (lt, f, link gt y r) + | ordering.eq := (l, some y, r) + | ordering.gt := let (lt, f, gt) := split3 r in (link l y lt, f, gt) + end + +/-- O(log n). Remove an element from the set equivalent to `x`. Does nothing if there +is no such element. + + erase 1 {1, 2, 3} = {2, 3} + erase 4 {1, 2, 3} = {1, 2, 3} + +Using a preorder on `ℕ × ℕ` that only compares the first coordinate: + + erase (1, 1) {(0, 1), (1, 2)} = {(0, 1)} + erase (3, 1) {(0, 1), (1, 2)} = {(0, 1), (1, 2)} -/ +def erase (x : α) : ordnode α → ordnode α +| nil := nil +| t@(node sz l y r) := + match cmp_le x y with + | ordering.lt := balance_r (erase l) y r + | ordering.eq := glue l r + | ordering.gt := balance_l l y (erase r) + end + +/-- Auxiliary definition for `find_lt`. -/ +def find_lt_aux (x : α) : ordnode α → α → α +| nil best := best +| (node _ l y r) best := + if x ≤ y then find_lt_aux l best else find_lt_aux r y + +/-- O(log n). Get the largest element in the tree that is `< x`. + + find_lt 2 {1, 2, 4} = some 1 + find_lt 3 {1, 2, 4} = some 2 + find_lt 0 {1, 2, 4} = none -/ +def find_lt (x : α) : ordnode α → option α +| nil := none +| (node _ l y r) := + if x ≤ y then find_lt l else some (find_lt_aux x r y) + +/-- Auxiliary definition for `find_gt`. -/ +def find_gt_aux (x : α) : ordnode α → α → α +| nil best := best +| (node _ l y r) best := + if y ≤ x then find_gt_aux r best else find_gt_aux l y + +/-- O(log n). Get the smallest element in the tree that is `> x`. + + find_lt 2 {1, 2, 4} = some 4 + find_lt 3 {1, 2, 4} = some 4 + find_lt 4 {1, 2, 4} = none -/ +def find_gt (x : α) : ordnode α → option α +| nil := none +| (node _ l y r) := + if y ≤ x then find_gt r else some (find_gt_aux x l y) + +/-- Auxiliary definition for `find_le`. -/ +def find_le_aux (x : α) : ordnode α → α → α +| nil best := best +| (node _ l y r) best := + match cmp_le x y with + | ordering.lt := find_le_aux l best + | ordering.eq := y + | ordering.gt := find_le_aux r y + end + +/-- O(log n). Get the largest element in the tree that is `≤ x`. + + find_le 2 {1, 2, 4} = some 2 + find_le 3 {1, 2, 4} = some 2 + find_le 0 {1, 2, 4} = none -/ +def find_le (x : α) : ordnode α → option α +| nil := none +| (node _ l y r) := + match cmp_le x y with + | ordering.lt := find_le l + | ordering.eq := some y + | ordering.gt := some (find_le_aux x r y) + end + +/-- Auxiliary definition for `find_ge`. -/ +def find_ge_aux (x : α) : ordnode α → α → α +| nil best := best +| (node _ l y r) best := + match cmp_le x y with + | ordering.lt := find_ge_aux l y + | ordering.eq := y + | ordering.gt := find_ge_aux r best + end + +/-- O(log n). Get the smallest element in the tree that is `≥ x`. + + find_le 2 {1, 2, 4} = some 2 + find_le 3 {1, 2, 4} = some 4 + find_le 5 {1, 2, 4} = none -/ +def find_ge (x : α) : ordnode α → option α +| nil := none +| (node _ l y r) := + match cmp_le x y with + | ordering.lt := some (find_ge_aux x l y) + | ordering.eq := some y + | ordering.gt := find_ge r + end + +/-- Auxiliary definition for `find_index`. -/ +def find_index_aux (x : α) : ordnode α → ℕ → option ℕ +| nil i := none +| (node _ l y r) i := + match cmp_le x y with + | ordering.lt := find_index_aux l i + | ordering.eq := some (i + size l) + | ordering.gt := find_index_aux r (i + size l + 1) + end + +/-- O(log n). Get the index, counting from the left, +of an element equivalent to `x` if it exists. + + find_index 2 {1, 2, 4} = some 1 + find_index 4 {1, 2, 4} = some 2 + find_index 5 {1, 2, 4} = none -/ +def find_index (x : α) (t : ordnode α) : option ℕ := find_index_aux x t 0 + +/-- Auxiliary definition for `is_subset`. -/ +def is_subset_aux : ordnode α → ordnode α → bool +| nil _ := tt +| _ nil := ff +| (node _ l x r) t := + let (lt, found, gt) := split3 x t in + found.is_some && is_subset_aux l lt && is_subset_aux r gt + +/-- O(m+n). Is every element of `t₁` equivalent to some element of `t₂`? + + is_subset {1, 4} {1, 2, 4} = tt + is_subset {1, 3} {1, 2, 4} = ff -/ +def is_subset (t₁ t₂ : ordnode α) : bool := +to_bool (size t₁ ≤ size t₂) && is_subset_aux t₁ t₂ + +/-- O(m+n). Is every element of `t₁` not equivalent to any element of `t₂`? + + disjoint {1, 3} {2, 4} = tt + disjoint {1, 2} {2, 4} = ff -/ +def disjoint : ordnode α → ordnode α → bool +| nil _ := tt +| _ nil := tt +| (node _ l x r) t := + let (lt, found, gt) := split3 x t in + found.is_none && disjoint l lt && disjoint r gt + +/-- O(m * log(|m ∪ n| + 1)), m ≤ n. The union of two sets, preferring members of + `t₁` over those of `t₂` when equivalent elements are encountered. + + union {1, 2} {2, 3} = {1, 2, 3} + union {1, 3} {2} = {1, 2, 3} + +Using a preorder on `ℕ × ℕ` that only compares the first coordinate: + + union {(1, 1)} {(0, 1), (1, 2)} = {(0, 1), (1, 1)} -/ +def union : ordnode α → ordnode α → ordnode α +| t₁ nil := t₁ +| nil t₂ := t₂ +| t₁@(node s₁ l₁ x₁ r₁) t₂@(node s₂ l₂ x₂ r₂) := + if s₂ = 1 then insert' x₂ t₁ else + if s₁ = 1 then insert x₁ t₂ else + let (l₂', r₂') := split x₁ t₂ in + link (union l₁ l₂') x₁ (union r₁ r₂') + +/-- O(m * log(|m ∪ n| + 1)), m ≤ n. Difference of two sets. + + diff {1, 2} {2, 3} = {1} + diff {1, 2, 3} {2} = {1, 3} -/ +def diff : ordnode α → ordnode α → ordnode α +| t₁ nil := t₁ +| t₁ t₂@(node _ l₂ x r₂) := cond t₁.empty t₂ $ + let (l₁, r₁) := split x t₁, + l₁₂ := diff l₁ l₂, r₁₂ := diff r₁ r₂ in + if size l₁₂ + size r₁₂ = size t₁ then t₁ else + merge l₁₂ r₁₂ + +/-- O(m * log(|m ∪ n| + 1)), m ≤ n. Intersection of two sets, preferring members of +`t₁` over those of `t₂` when equivalent elements are encountered. + + inter {1, 2} {2, 3} = {2} + inter {1, 3} {2} = ∅ -/ +def inter : ordnode α → ordnode α → ordnode α +| nil t₂ := nil +| t₁@(node _ l₁ x r₁) t₂ := cond t₂.empty t₁ $ + let (l₂, y, r₂) := split3 x t₂, + l₁₂ := inter l₁ l₂, r₁₂ := inter r₁ r₂ in + cond y.is_some (link l₁₂ x r₁₂) (merge l₁₂ r₁₂) + +/-- O(n * log n). Build a set from a list, preferring elements that appear earlier in the list +in the case of equivalent elements. + + of_list [1, 2, 3] = {1, 2, 3} + of_list [2, 1, 1, 3] = {1, 2, 3} + +Using a preorder on `ℕ × ℕ` that only compares the first coordinate: + + of_list [(1, 1), (0, 1), (1, 2)] = {(0, 1), (1, 1)} -/ +def of_list (l : list α) : ordnode α := l.foldr insert nil + +/-- O(n * log n). Adaptively chooses between the linear and log-linear algorithm depending + on whether the input list is already sorted. + + of_list' [1, 2, 3] = {1, 2, 3} + of_list' [2, 1, 1, 3] = {1, 2, 3} -/ +def of_list' : list α → ordnode α +| [] := nil +| (x :: xs) := + if list.chain (λ a b, ¬ b ≤ a) x xs + then of_asc_list (x :: xs) + else of_list (x :: xs) + +/-- O(n * log n). Map a function on a set. Unlike `map` this has no requirements on +`f`, and the resulting set may be smaller than the input if `f` is noninjective. +Equivalent elements are selected with a preference for smaller source elements. + + image (λ x, x + 2) {1, 2, 4} = {3, 4, 6} + image (λ x : ℕ, x - 2) {1, 2, 4} = {0, 2} -/ +def image {α β} [has_le β] [@decidable_rel β (≤)] + (f : α → β) (t : ordnode α) : ordnode β := +of_list (t.to_list.map f) + +end + +end ordnode diff --git a/src/data/ordmap/ordset.lean b/src/data/ordmap/ordset.lean new file mode 100644 index 0000000000000..11a26a5381ace --- /dev/null +++ b/src/data/ordmap/ordset.lean @@ -0,0 +1,1478 @@ +/- +Copyright (c) 2017 Mario Carneiro. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Mario Carneiro +-/ +import data.ordmap.ordnode +import algebra.ordered_ring +import data.nat.dist +import tactic.linarith + +/-! +# Verification of the `ordnode α` datatype + +This file proves the correctness of the operations in `data.ordmap.ordnode`. +The public facing version is the type `ordset α`, which is a wrapper around +`ordnode α` which includes the correctness invariant of the type, and it exposes +parallel operations like `insert` as functions on `ordset` that do the same +thing but bundle the correctness proofs. The advantage is that it is possible +to, for example, prove that the result of `find` on `insert` will actually find +the element, while `ordnode` cannot guarantee this if the input tree did not +satisfy the type invariants. + +## Main definitions + +* `ordset α`: A well formed set of values of type `α` + +## Implementation notes + +The majority of this file is actually in the `ordnode` namespace, because we first +have to prove the correctness of all the operations (and defining what correctness +means here is actually somewhat subtle). So all the actual `ordset` operations are +at the very end, once we have all the theorems. + +An `ordnode α` is an inductive type which describes a tree which stores the `size` at +internal nodes. The correctness invariant of an `ordnode α` is: + +* `ordnode.sized t`: All internal `size` fields must match the actual measured + size of the tree. (This is not hard to satisfy.) +* `ordnode.balanced t`: Unless the tree has the form `()` or `((a) b)` or `(a (b))` + (that is, nil or a single singleton subtree), the two subtrees must satisfy + `size l ≤ δ * size r` and `size r ≤ δ * size l`, where `δ := 3` is a global + parameter of the data structure (and this property must hold recursively at subtrees). + This is why we say this is a "size balanced tree" data structure. +* `ordnode.bounded lo hi t`: The members of the tree must be in strictly increasing order, + meaning that if `a` is in the left subtree and `b` is the root, then `a ≤ b` and + `¬ (b ≤ a)`. We enforce this using `ordnode.bounded` which includes also a global + upper and lower bound. + +Because the `ordnode` file was ported from Haskell, the correctness invariants of some +of the functions have not been spelled out, and some theorems like +`ordnode.valid'.balance_l_aux` show very intricate assumptions on the sizes, +which may need to be revised if it turns out some operations violate these assumptions, +because there is a decent amount of slop in the actual data structure invariants, so the +theorem will go through with multiple choices of assumption. + +**Note:** This file is incomplete, in the sense that the intent is to have verified +versions and lemmas about all the definitions in `ordnode.lean`, but at the moment only +a few operations are verified (the hard part should be out of the way, but still). +Contributors are encouraged to pick this up and finish the job, if it appeals to you. + +## Tags + +ordered map, ordered set, data structure, verified programming + +-/ + +variable {α : Type*} +namespace ordnode + +/-! ### delta and ratio -/ + +theorem not_le_delta {s} (H : 1 ≤ s) : ¬ s ≤ delta * 0 := +λ h, by rw mul_zero at h; exact not_lt_of_le h H + +theorem delta_lt_false {a b : ℕ} + (h₁ : delta * a < b) (h₂ : delta * b < a) : false := +not_le_of_lt (lt_trans ((mul_lt_mul_left dec_trivial).2 h₁) h₂) $ +by simpa [mul_assoc] using nat.mul_le_mul_right a (dec_trivial : 1 ≤ delta * delta) + +/-! ### `singleton` -/ + +/-! ### `size` and `empty` -/ + +/-- O(n). Computes the actual number of elements in the set, ignoring the cached `size` field. -/ +def real_size : ordnode α → ℕ +| nil := 0 +| (node _ l _ r) := real_size l + real_size r + 1 + +/-! ### `sized` -/ + +/-- The `sized` property asserts that all the `size` fields in nodes match the actual size of the +respective subtrees. -/ +def sized : ordnode α → Prop +| nil := true +| (node s l _ r) := s = size l + size r + 1 ∧ sized l ∧ sized r + +theorem sized.node' {l x r} (hl : @sized α l) (hr : sized r) : sized (node' l x r) := +⟨rfl, hl, hr⟩ + +theorem sized.eq_node' {s l x r} (h : @sized α (node s l x r)) : node s l x r = node' l x r := +by rw h.1; refl + +theorem sized.size_eq {s l x r} (H : sized (@node α s l x r)) : + size (@node α s l x r) = size l + size r + 1 := H.1 + +@[elab_as_eliminator] theorem sized.induction {t} (hl : @sized α t) + {C : ordnode α → Prop} (H0 : C nil) + (H1 : ∀ l x r, C l → C r → C (node' l x r)) : C t := +begin + induction t, {exact H0}, + rw hl.eq_node', + exact H1 _ _ _ (t_ih_l hl.2.1) (t_ih_r hl.2.2) +end + +theorem size_eq_real_size : ∀ {t : ordnode α}, sized t → size t = real_size t +| nil _ := rfl +| (node s l x r) ⟨h₁, h₂, h₃⟩ := + by rw [size, h₁, size_eq_real_size h₂, size_eq_real_size h₃]; refl + +@[simp] theorem sized.size_eq_zero {t : ordnode α} (ht : sized t) : size t = 0 ↔ t = nil := +by cases t; [simp, simp [ht.1]] + +theorem sized.pos {s l x r} (h : sized (@node α s l x r)) : 0 < s := +by rw h.1; apply nat.le_add_left + +/-! `dual` -/ + +theorem dual_dual : ∀ (t : ordnode α), dual (dual t) = t +| nil := rfl +| (node s l x r) := by rw [dual, dual, dual_dual, dual_dual] + +@[simp] theorem size_dual (t : ordnode α) : size (dual t) = size t := +by cases t; refl + +/-! `balanced` -/ + +/-- The `balanced_sz l r` asserts that a hypothetical tree with children of sizes `l` and `r` is +balanced: either `l ≤ δ * r` and `r ≤ δ * r`, or the tree is trivial with a singleton on one side +and nothing on the other. -/ +def balanced_sz (l r : ℕ) : Prop := +l + r ≤ 1 ∨ (l ≤ delta * r ∧ r ≤ delta * l) + +instance balanced_sz.dec : decidable_rel balanced_sz := λ l r, or.decidable + +/-- The `balanced t` asserts that the tree `t` satisfies the balance invariants +(at every level). -/ +def balanced : ordnode α → Prop +| nil := true +| (node _ l _ r) := balanced_sz (size l) (size r) ∧ balanced l ∧ balanced r + +instance balanced.dec : decidable_pred (@balanced α) | t := +by induction t; unfold balanced; resetI; apply_instance + +theorem balanced_sz.symm {l r : ℕ} : balanced_sz l r → balanced_sz r l := +or.imp (by rw add_comm; exact id) and.symm + +theorem balanced_sz_zero {l : ℕ} : balanced_sz l 0 ↔ l ≤ 1 := +by simp [balanced_sz]; apply or_iff_left_of_imp; rintro rfl; exact zero_le_one + +theorem balanced_sz_up {l r₁ r₂ : ℕ} (h₁ : r₁ ≤ r₂) + (h₂ : l + r₂ ≤ 1 ∨ r₂ ≤ delta * l) + (H : balanced_sz l r₁) : balanced_sz l r₂ := +begin + refine or_iff_not_imp_left.2 (λ h, _), + refine ⟨_, h₂.resolve_left h⟩, + cases H, + { cases r₂, + { cases h (le_trans (nat.add_le_add_left (nat.zero_le _) _) H) }, + { exact le_trans (le_trans (nat.le_add_right _ _) H) (nat.le_add_left 1 _) } }, + { exact le_trans H.1 (nat.mul_le_mul_left _ h₁) } +end + +theorem balanced_sz_down {l r₁ r₂ : ℕ} (h₁ : r₁ ≤ r₂) + (h₂ : l + r₂ ≤ 1 ∨ l ≤ delta * r₁) + (H : balanced_sz l r₂) : balanced_sz l r₁ := +have l + r₂ ≤ 1 → balanced_sz l r₁, from + λ H, or.inl (le_trans (nat.add_le_add_left h₁ _) H), +or.cases_on H this (λ H, or.cases_on h₂ this (λ h₂, + or.inr ⟨h₂, le_trans h₁ H.2⟩)) + +theorem balanced.dual : ∀ {t : ordnode α}, balanced t → balanced (dual t) +| nil h := ⟨⟩ +| (node s l x r) ⟨b, bl, br⟩ := + ⟨by rw [size_dual, size_dual]; exact b.symm, br.dual, bl.dual⟩ + +/-! ### `rotate` and `balance` -/ + +/-- Build a tree from three nodes, left associated (ignores the invariants). -/ +def node3_l (l : ordnode α) (x : α) (m : ordnode α) (y : α) (r : ordnode α) : ordnode α := +node' (node' l x m) y r + +/-- Build a tree from three nodes, right associated (ignores the invariants). -/ +def node3_r (l : ordnode α) (x : α) (m : ordnode α) (y : α) (r : ordnode α) : ordnode α := +node' l x (node' m y r) + +/-- Build a tree from three nodes, with `a () b -> (a ()) b` and `a (b c) d -> ((a b) (c d))`. -/ +def node4_l : ordnode α → α → ordnode α → α → ordnode α → ordnode α +| l x (node _ ml y mr) z r := node' (node' l x ml) y (node' mr z r) +| l x nil z r := node3_l l x nil z r -- should not happen + +/-- Build a tree from three nodes, with `a () b -> a (() b)` and `a (b c) d -> ((a b) (c d))`. -/ +def node4_r : ordnode α → α → ordnode α → α → ordnode α → ordnode α +| l x (node _ ml y mr) z r := node' (node' l x ml) y (node' mr z r) +| l x nil z r := node3_r l x nil z r -- should not happen + +/-- Concatenate two nodes, performing a left rotation `x (y z) -> ((x y) z)` +if balance is upset. -/ +def rotate_l : ordnode α → α → ordnode α → ordnode α +| l x (node _ m y r) := + if size m < ratio * size r then node3_l l x m y r else node4_l l x m y r +| l x nil := node' l x nil -- should not happen + +/-- Concatenate two nodes, performing a right rotation `(x y) z -> (x (y z))` +if balance is upset. -/ +def rotate_r : ordnode α → α → ordnode α → ordnode α +| (node _ l x m) y r := + if size m < ratio * size l then node3_r l x m y r else node4_r l x m y r +| nil y r := node' nil y r -- should not happen + +/-- A left balance operation. This will rebalance a concatenation, assuming the original nodes are +not too far from balanced. -/ +def balance_l' (l : ordnode α) (x : α) (r : ordnode α) : ordnode α := +if size l + size r ≤ 1 then node' l x r else +if size l > delta * size r then rotate_r l x r else +node' l x r + +/-- A right balance operation. This will rebalance a concatenation, assuming the original nodes are +not too far from balanced. -/ +def balance_r' (l : ordnode α) (x : α) (r : ordnode α) : ordnode α := +if size l + size r ≤ 1 then node' l x r else +if size r > delta * size l then rotate_l l x r else +node' l x r + +/-- The full balance operation. This is the same as `balance`, but with less manual inlining. +It is somewhat easier to work with this version in proofs. -/ +def balance' (l : ordnode α) (x : α) (r : ordnode α) : ordnode α := +if size l + size r ≤ 1 then node' l x r else +if size r > delta * size l then rotate_l l x r else +if size l > delta * size r then rotate_r l x r else +node' l x r + +theorem dual_node' (l : ordnode α) (x : α) (r : ordnode α) : + dual (node' l x r) = node' (dual r) x (dual l) := by simp [node', add_comm] + +theorem dual_node3_l (l : ordnode α) (x : α) (m : ordnode α) (y : α) (r : ordnode α) : + dual (node3_l l x m y r) = node3_r (dual r) y (dual m) x (dual l) := +by simp [node3_l, node3_r, dual_node'] + +theorem dual_node3_r (l : ordnode α) (x : α) (m : ordnode α) (y : α) (r : ordnode α) : + dual (node3_r l x m y r) = node3_l (dual r) y (dual m) x (dual l) := +by simp [node3_l, node3_r, dual_node'] + +theorem dual_node4_l (l : ordnode α) (x : α) (m : ordnode α) (y : α) (r : ordnode α) : + dual (node4_l l x m y r) = node4_r (dual r) y (dual m) x (dual l) := +by cases m; simp [node4_l, node4_r, dual_node3_l, dual_node'] + +theorem dual_node4_r (l : ordnode α) (x : α) (m : ordnode α) (y : α) (r : ordnode α) : + dual (node4_r l x m y r) = node4_l (dual r) y (dual m) x (dual l) := +by cases m; simp [node4_l, node4_r, dual_node3_r, dual_node'] + +theorem dual_rotate_l (l : ordnode α) (x : α) (r : ordnode α) : + dual (rotate_l l x r) = rotate_r (dual r) x (dual l) := +by cases r; simp [rotate_l, rotate_r, dual_node']; + split_ifs; simp [dual_node3_l, dual_node4_l] + +theorem dual_rotate_r (l : ordnode α) (x : α) (r : ordnode α) : + dual (rotate_r l x r) = rotate_l (dual r) x (dual l) := +by rw [← dual_dual (rotate_l _ _ _), dual_rotate_l, dual_dual, dual_dual] + +theorem dual_balance' (l : ordnode α) (x : α) (r : ordnode α) : + dual (balance' l x r) = balance' (dual r) x (dual l) := +begin + simp [balance', add_comm], split_ifs; simp [dual_node', dual_rotate_l, dual_rotate_r], + cases delta_lt_false h_1 h_2 +end + +theorem dual_balance_l (l : ordnode α) (x : α) (r : ordnode α) : + dual (balance_l l x r) = balance_r (dual r) x (dual l) := +begin + unfold balance_l balance_r, + cases r with rs rl rx rr, + { cases l with ls ll lx lr, {refl}, + cases ll with lls lll llx llr; cases lr with lrs lrl lrx lrr; + dsimp only [dual]; try {refl}, + split_ifs; repeat {simp [h, add_comm]} }, + { cases l with ls ll lx lr, {refl}, + dsimp only [dual], + split_ifs, swap, {simp [add_comm]}, + cases ll with lls lll llx llr; cases lr with lrs lrl lrx lrr; try {refl}, + dsimp only [dual], + split_ifs; simp [h, add_comm] }, +end + +theorem dual_balance_r (l : ordnode α) (x : α) (r : ordnode α) : + dual (balance_r l x r) = balance_l (dual r) x (dual l) := +by rw [← dual_dual (balance_l _ _ _), dual_balance_l, dual_dual, dual_dual] + +theorem sized.node3_l {l x m y r} + (hl : @sized α l) (hm : sized m) (hr : sized r) : sized (node3_l l x m y r) := +(hl.node' hm).node' hr + +theorem sized.node3_r {l x m y r} + (hl : @sized α l) (hm : sized m) (hr : sized r) : sized (node3_r l x m y r) := +hl.node' (hm.node' hr) + +theorem sized.node4_l {l x m y r} + (hl : @sized α l) (hm : sized m) (hr : sized r) : sized (node4_l l x m y r) := +by cases m; [exact (hl.node' hm).node' hr, + exact (hl.node' hm.2.1).node' (hm.2.2.node' hr)] + +theorem node3_l_size {l x m y r} : + size (@node3_l α l x m y r) = size l + size m + size r + 2 := +by dsimp [node3_l, node', size]; rw add_right_comm _ 1 + +theorem node3_r_size {l x m y r} : + size (@node3_r α l x m y r) = size l + size m + size r + 2 := +by dsimp [node3_r, node', size]; rw [← add_assoc, ← add_assoc] + +theorem node4_l_size {l x m y r} (hm : sized m) : + size (@node4_l α l x m y r) = size l + size m + size r + 2 := +by cases m; simp [node4_l, node3_l, node', add_comm, add_left_comm]; [skip, simp [size, hm.1]]; + rw [← add_assoc, ← bit0]; simp [add_comm, add_left_comm] + +theorem sized.dual : ∀ {t : ordnode α} (h : sized t), sized (dual t) +| nil h := ⟨⟩ +| (node s l x r) ⟨rfl, sl, sr⟩ := ⟨by simp [size_dual, add_comm], sized.dual sr, sized.dual sl⟩ + +theorem sized.dual_iff {t : ordnode α} : sized (dual t) ↔ sized t := +⟨λ h, by rw ← dual_dual t; exact h.dual, sized.dual⟩ + +theorem sized.rotate_l {l x r} (hl : @sized α l) (hr : sized r) : sized (rotate_l l x r) := +begin + cases r, {exact hl.node' hr}, + rw rotate_l, split_ifs, + { exact hl.node3_l hr.2.1 hr.2.2 }, + { exact hl.node4_l hr.2.1 hr.2.2 } +end + +theorem sized.rotate_r {l x r} (hl : @sized α l) (hr : sized r) : sized (rotate_r l x r) := +sized.dual_iff.1 $ by rw dual_rotate_r; exact hr.dual.rotate_l hl.dual + +theorem sized.rotate_l_size {l x r} (hm : sized r) : + size (@rotate_l α l x r) = size l + size r + 1 := +begin + cases r; simp [rotate_l], + simp [size, hm.1, add_comm, add_left_comm], rw [← add_assoc, ← bit0], simp, + split_ifs; simp [node3_l_size, node4_l_size hm.2.1, add_comm, add_left_comm] +end + +theorem sized.rotate_r_size {l x r} (hl : sized l) : + size (@rotate_r α l x r) = size l + size r + 1 := +by rw [← size_dual, dual_rotate_r, hl.dual.rotate_l_size, + size_dual, size_dual, add_comm (size l)] + +theorem sized.balance' {l x r} (hl : @sized α l) (hr : sized r) : sized (balance' l x r) := +begin + unfold balance', split_ifs, + { exact hl.node' hr }, + { exact hl.rotate_l hr }, + { exact hl.rotate_r hr }, + { exact hl.node' hr } +end + +theorem size_balance' {l x r} (hl : @sized α l) (hr : sized r) : + size (@balance' α l x r) = size l + size r + 1 := +begin + unfold balance', split_ifs, + { refl }, + { exact hr.rotate_l_size }, + { exact hl.rotate_r_size }, + { refl } +end + +/-! ## `all`, `any`, `emem`, `amem` -/ + +theorem all.imp {P Q : α → Prop} (H : ∀ a, P a → Q a) : ∀ {t}, all P t → all Q t +| nil h := ⟨⟩ +| (node _ l x r) ⟨h₁, h₂, h₃⟩ := ⟨h₁.imp, H _ h₂, h₃.imp⟩ + +theorem any.imp {P Q : α → Prop} (H : ∀ a, P a → Q a) : ∀ {t}, any P t → any Q t +| nil := id +| (node _ l x r) := or.imp any.imp $ or.imp (H _) any.imp + +theorem all_singleton {P : α → Prop} {x : α} : all P (singleton x) ↔ P x := +⟨λ h, h.2.1, λ h, ⟨⟨⟩, h, ⟨⟩⟩⟩ + +theorem any_singleton {P : α → Prop} {x : α} : any P (singleton x) ↔ P x := +⟨by rintro (⟨⟨⟩⟩ | h | ⟨⟨⟩⟩); exact h, λ h, or.inr (or.inl h)⟩ + +theorem all_dual {P : α → Prop} : ∀ {t : ordnode α}, + all P (dual t) ↔ all P t +| nil := iff.rfl +| (node s l x r) := + ⟨λ ⟨hr, hx, hl⟩, ⟨all_dual.1 hl, hx, all_dual.1 hr⟩, + λ ⟨hl, hx, hr⟩, ⟨all_dual.2 hr, hx, all_dual.2 hl⟩⟩ + +theorem all_iff_forall {P : α → Prop} : ∀ {t}, all P t ↔ ∀ x, emem x t → P x +| nil := (iff_true_intro $ by rintro _ ⟨⟩).symm +| (node _ l x r) := + by simp [all, emem, all_iff_forall, any, or_imp_distrib, forall_and_distrib] + +theorem any_iff_exists {P : α → Prop} : ∀ {t}, any P t ↔ ∃ x, emem x t ∧ P x +| nil := ⟨by rintro ⟨⟩, by rintro ⟨_, ⟨⟩, _⟩⟩ +| (node _ l x r) := + by simp [any, emem, any_iff_exists, or_and_distrib_right, exists_or_distrib] + +theorem emem_iff_all {x : α} {t} : emem x t ↔ ∀ P, all P t → P x := +⟨λ h P al, all_iff_forall.1 al _ h, + λ H, H _ $ all_iff_forall.2 $ λ _, id⟩ + +theorem all_node' {P l x r} : + @all α P (node' l x r) ↔ all P l ∧ P x ∧ all P r := iff.rfl + +theorem all_node3_l {P l x m y r} : + @all α P (node3_l l x m y r) ↔ all P l ∧ P x ∧ all P m ∧ P y ∧ all P r := +by simp [node3_l, all_node', and_assoc] + +theorem all_node3_r {P l x m y r} : + @all α P (node3_r l x m y r) ↔ all P l ∧ P x ∧ all P m ∧ P y ∧ all P r := iff.rfl + +theorem all_node4_l {P l x m y r} : + @all α P (node4_l l x m y r) ↔ all P l ∧ P x ∧ all P m ∧ P y ∧ all P r := +by cases m; simp [node4_l, all_node', all, all_node3_l, and_assoc] + +theorem all_node4_r {P l x m y r} : + @all α P (node4_r l x m y r) ↔ all P l ∧ P x ∧ all P m ∧ P y ∧ all P r := +by cases m; simp [node4_r, all_node', all, all_node3_r, and_assoc] + +theorem all_rotate_l {P l x r} : + @all α P (rotate_l l x r) ↔ all P l ∧ P x ∧ all P r := +by cases r; simp [rotate_l, all_node']; + split_ifs; simp [all_node3_l, all_node4_l, all] + +theorem all_rotate_r {P l x r} : + @all α P (rotate_r l x r) ↔ all P l ∧ P x ∧ all P r := +by rw [← all_dual, dual_rotate_r, all_rotate_l]; + simp [all_dual, and_comm, and.left_comm] + +theorem all_balance' {P l x r} : + @all α P (balance' l x r) ↔ all P l ∧ P x ∧ all P r := +by rw balance'; split_ifs; simp [all_node', all_rotate_l, all_rotate_r] + +/-! ### `to_list` -/ + +theorem foldr_cons_eq_to_list : ∀ (t : ordnode α) (r : list α), + t.foldr list.cons r = to_list t ++ r +| nil r := rfl +| (node _ l x r) r' := by rw [foldr, foldr_cons_eq_to_list, foldr_cons_eq_to_list, + ← list.cons_append, ← list.append_assoc, ← foldr_cons_eq_to_list]; refl + +@[simp] theorem to_list_nil : to_list (@nil α) = [] := rfl + +@[simp] theorem to_list_node (s l x r) : to_list (@node α s l x r) = to_list l ++ x :: to_list r := +by rw [to_list, foldr, foldr_cons_eq_to_list]; refl + +theorem emem_iff_mem_to_list {x : α} {t} : emem x t ↔ x ∈ to_list t := +by unfold emem; induction t; simp [any, *, or_assoc] + +theorem length_to_list' : ∀ t : ordnode α, (to_list t).length = t.real_size +| nil := rfl +| (node _ l _ r) := by rw [to_list_node, list.length_append, list.length_cons, + length_to_list', length_to_list']; refl + +theorem length_to_list {t : ordnode α} (h : sized t) : (to_list t).length = t.size := +by rw [length_to_list', size_eq_real_size h] + +theorem equiv_iff {t₁ t₂ : ordnode α} (h₁ : sized t₁) (h₂ : sized t₂) : + equiv t₁ t₂ ↔ to_list t₁ = to_list t₂ := +and_iff_right_of_imp $ λ h, by rw [← length_to_list h₁, h, length_to_list h₂] + +/-! ### `(find/erase/split)_(min/max)` -/ + +theorem find_min'_dual : ∀ t (x : α), find_min' (dual t) x = find_max' x t +| nil x := rfl +| (node _ l x r) _ := find_min'_dual r x + +theorem find_max'_dual (t) (x : α) : find_max' x (dual t) = find_min' t x := +by rw [← find_min'_dual, dual_dual] + +theorem find_min_dual : ∀ t : ordnode α, find_min (dual t) = find_max t +| nil := rfl +| (node _ l x r) := congr_arg some $ find_min'_dual _ _ + +theorem find_max_dual (t : ordnode α) : find_max (dual t) = find_min t := +by rw [← find_min_dual, dual_dual] + +theorem dual_erase_min : ∀ t : ordnode α, dual (erase_min t) = erase_max (dual t) +| nil := rfl +| (node _ nil x r) := rfl +| (node _ l@(node _ _ _ _) x r) := + by rw [erase_min, dual_balance_r, dual_erase_min, dual, dual, dual, erase_max] + +theorem dual_erase_max (t : ordnode α) : dual (erase_max t) = erase_min (dual t) := +by rw [← dual_dual (erase_min _), dual_erase_min, dual_dual] + +theorem split_min_eq : ∀ s l (x : α) r, + split_min' l x r = (find_min' l x, erase_min (node s l x r)) +| _ nil x r := rfl +| _ (node ls ll lx lr) x r := + by rw [split_min', split_min_eq, split_min', find_min', erase_min] + +theorem split_max_eq : ∀ s l (x : α) r, + split_max' l x r = (erase_max (node s l x r), find_max' x r) +| _ l x nil := rfl +| _ l x (node ls ll lx lr) := + by rw [split_max', split_max_eq, split_max', find_max', erase_max] + +@[elab_as_eliminator] +theorem find_min'_all {P : α → Prop} : ∀ t (x : α), all P t → P x → P (find_min' t x) +| nil x h hx := hx +| (node _ ll lx lr) x ⟨h₁, h₂, h₃⟩ hx := find_min'_all _ _ h₁ h₂ + +@[elab_as_eliminator] +theorem find_max'_all {P : α → Prop} : ∀ (x : α) t, P x → all P t → P (find_max' x t) +| x nil hx h := hx +| x (node _ ll lx lr) hx ⟨h₁, h₂, h₃⟩ := find_max'_all _ _ h₂ h₃ + +/-! ### `glue` -/ + +/-! ### `merge` -/ + +@[simp] theorem merge_nil_left (t : ordnode α) : merge t nil = t := by cases t; refl + +@[simp] theorem merge_nil_right (t : ordnode α) : merge nil t = t := rfl + +@[simp] theorem merge_node {ls ll lx lr rs rl rx rr} : + merge (@node α ls ll lx lr) (node rs rl rx rr) = + if delta * ls < rs then + balance_l (merge (node ls ll lx lr) rl) rx rr + else if delta * rs < ls then + balance_r ll lx (merge lr (node rs rl rx rr)) + else glue (node ls ll lx lr) (node rs rl rx rr) := rfl + +/-! ### `insert` -/ + +theorem dual_insert [preorder α] [is_total α (≤)] [@decidable_rel α (≤)] (x : α) : + ∀ t : ordnode α, dual (ordnode.insert x t) = @ordnode.insert (order_dual α) _ _ x (dual t) +| nil := rfl +| (node _ l y r) := begin + rw [ordnode.insert, dual, ordnode.insert, order_dual.cmp_le_flip, ← cmp_le_swap x y], + cases cmp_le x y; + simp [ordering.swap, ordnode.insert, dual_balance_l, dual_balance_r, dual_insert] +end + +/-! ### `balance` properties -/ + +theorem balance_eq_balance' {l x r} + (hl : balanced l) (hr : balanced r) + (sl : sized l) (sr : sized r) : + @balance α l x r = balance' l x r := +begin + cases l with ls ll lx lr, + { cases r with rs rl rx rr, + { refl }, + { rw sr.eq_node' at hr ⊢, + cases rl with rls rll rlx rlr; cases rr with rrs rrl rrx rrr; + dsimp [balance, balance'], + { refl }, + { have : size rrl = 0 ∧ size rrr = 0, + { have := balanced_sz_zero.1 hr.1.symm, + rwa [size, sr.2.2.1, nat.succ_le_succ_iff, + nat.le_zero_iff, add_eq_zero_iff] at this }, + cases sr.2.2.2.1.size_eq_zero.1 this.1, + cases sr.2.2.2.2.size_eq_zero.1 this.2, + have : rrs = 1 := sr.2.2.1, subst rrs, + rw [if_neg, if_pos, rotate_l, if_pos], {refl}, + all_goals {exact dec_trivial} }, + { have : size rll = 0 ∧ size rlr = 0, + { have := balanced_sz_zero.1 hr.1, + rwa [size, sr.2.1.1, nat.succ_le_succ_iff, + nat.le_zero_iff, add_eq_zero_iff] at this }, + cases sr.2.1.2.1.size_eq_zero.1 this.1, + cases sr.2.1.2.2.size_eq_zero.1 this.2, + have : rls = 1 := sr.2.1.1, subst rls, + rw [if_neg, if_pos, rotate_l, if_neg], {refl}, + all_goals {exact dec_trivial} }, + { symmetry, rw [zero_add, if_neg, if_pos, rotate_l], + { split_ifs, + { simp [node3_l, node', add_comm, add_left_comm] }, + { simp [node4_l, node', sr.2.1.1, add_comm, add_left_comm] } }, + { exact dec_trivial }, + { exact not_le_of_gt (nat.succ_lt_succ + (add_pos sr.2.1.pos sr.2.2.pos)) } } } }, + { cases r with rs rl rx rr, + { rw sl.eq_node' at hl ⊢, + cases ll with lls lll llx llr; cases lr with lrs lrl lrx lrr; + dsimp [balance, balance'], + { refl }, + { have : size lrl = 0 ∧ size lrr = 0, + { have := balanced_sz_zero.1 hl.1.symm, + rwa [size, sl.2.2.1, nat.succ_le_succ_iff, + nat.le_zero_iff, add_eq_zero_iff] at this }, + cases sl.2.2.2.1.size_eq_zero.1 this.1, + cases sl.2.2.2.2.size_eq_zero.1 this.2, + have : lrs = 1 := sl.2.2.1, subst lrs, + rw [if_neg, if_neg, if_pos, rotate_r, if_neg], {refl}, + all_goals {exact dec_trivial} }, + { have : size lll = 0 ∧ size llr = 0, + { have := balanced_sz_zero.1 hl.1, + rwa [size, sl.2.1.1, nat.succ_le_succ_iff, + nat.le_zero_iff, add_eq_zero_iff] at this }, + cases sl.2.1.2.1.size_eq_zero.1 this.1, + cases sl.2.1.2.2.size_eq_zero.1 this.2, + have : lls = 1 := sl.2.1.1, subst lls, + rw [if_neg, if_neg, if_pos, rotate_r, if_pos], {refl}, + all_goals {exact dec_trivial} }, + { symmetry, rw [if_neg, if_neg, if_pos, rotate_r], + { split_ifs, + { simp [node3_r, node', add_comm, add_left_comm] }, + { simp [node4_r, node', sl.2.2.1, add_comm, add_left_comm] } }, + { exact dec_trivial }, + { exact dec_trivial }, + { exact not_le_of_gt (nat.succ_lt_succ + (add_pos sl.2.1.pos sl.2.2.pos)) } } }, + { simp [balance, balance'], + symmetry, rw [if_neg], + { split_ifs, + { have rd : delta ≤ size rl + size rr, + { have := lt_of_le_of_lt (nat.mul_le_mul_left _ sl.pos) h, + rwa [sr.1, nat.lt_succ_iff] at this }, + cases rl with rls rll rlx rlr, + { rw [size, zero_add] at rd, + exact absurd (le_trans rd (balanced_sz_zero.1 hr.1.symm)) dec_trivial }, + cases rr with rrs rrl rrx rrr, + { exact absurd (le_trans rd (balanced_sz_zero.1 hr.1)) dec_trivial }, + dsimp [rotate_l], split_ifs, + { simp [node3_l, node', sr.1, add_comm, add_left_comm] }, + { simp [node4_l, node', sr.1, sr.2.1.1, add_comm, add_left_comm] } }, + { have ld : delta ≤ size ll + size lr, + { have := lt_of_le_of_lt (nat.mul_le_mul_left _ sr.pos) h_1, + rwa [sl.1, nat.lt_succ_iff] at this }, + cases ll with lls lll llx llr, + { rw [size, zero_add] at ld, + exact absurd (le_trans ld (balanced_sz_zero.1 hl.1.symm)) dec_trivial }, + cases lr with lrs lrl lrx lrr, + { exact absurd (le_trans ld (balanced_sz_zero.1 hl.1)) dec_trivial }, + dsimp [rotate_r], split_ifs, + { simp [node3_r, node', sl.1, add_comm, add_left_comm] }, + { simp [node4_r, node', sl.1, sl.2.2.1, add_comm, add_left_comm] } }, + { simp [node'] } }, + { exact not_le_of_gt (add_le_add sl.pos sr.pos : 2 ≤ ls + rs) } } } +end + +theorem balance_l_eq_balance {l x r} + (sl : sized l) (sr : sized r) + (H1 : size l = 0 → size r ≤ 1) + (H2 : 1 ≤ size l → 1 ≤ size r → size r ≤ delta * size l) : + @balance_l α l x r = balance l x r := +begin + cases r with rs rl rx rr, + { refl }, + { cases l with ls ll lx lr, + { have : size rl = 0 ∧ size rr = 0, + { have := H1 rfl, + rwa [size, sr.1, nat.succ_le_succ_iff, + nat.le_zero_iff, add_eq_zero_iff] at this }, + cases sr.2.1.size_eq_zero.1 this.1, + cases sr.2.2.size_eq_zero.1 this.2, + rw sr.eq_node', refl }, + { replace H2 : ¬ rs > delta * ls := not_lt_of_le (H2 sl.pos sr.pos), + simp [balance_l, balance, H2]; split_ifs; simp [add_comm] } } +end + +/-- `raised n m` means `m` is either equal or one up from `n`. -/ +def raised (n m : ℕ) : Prop := m = n ∨ m = n + 1 + +theorem raised_iff {n m} : raised n m ↔ n ≤ m ∧ m ≤ n + 1 := +begin + split, rintro (rfl | rfl), + { exact ⟨le_refl _, nat.le_succ _⟩ }, + { exact ⟨nat.le_succ _, le_refl _⟩ }, + { rintro ⟨h₁, h₂⟩, + rcases eq_or_lt_of_le h₁ with rfl | h₁, + { exact or.inl rfl }, + { exact or.inr (le_antisymm h₂ h₁) } } +end + +theorem raised.dist_le {n m} (H : raised n m) : nat.dist n m ≤ 1 := +by cases raised_iff.1 H with H1 H2; + rwa [nat.dist_eq_sub_of_le H1, nat.sub_le_left_iff_le_add] + +theorem raised.dist_le' {n m} (H : raised n m) : nat.dist m n ≤ 1 := +by rw nat.dist_comm; exact H.dist_le + +theorem raised.add_left (k) {n m} (H : raised n m) : raised (k + n) (k + m) := +begin + rcases H with rfl | rfl, + { exact or.inl rfl }, + { exact or.inr rfl } +end + +theorem raised.add_right (k) {n m} (H : raised n m) : raised (n + k) (m + k) := +by rw [add_comm, add_comm m]; exact H.add_left _ + +theorem raised.right {l x₁ x₂ r₁ r₂} (H : raised (size r₁) (size r₂)) : + raised (size (@node' α l x₁ r₁)) (size (@node' α l x₂ r₂)) := +begin + dsimp [node', size], generalize_hyp : size r₂ = m at H ⊢, + rcases H with rfl | rfl, + { exact or.inl rfl }, + { exact or.inr rfl } +end + +theorem balance_l_eq_balance' {l x r} + (hl : balanced l) (hr : balanced r) + (sl : sized l) (sr : sized r) + (H : (∃ l', raised l' (size l) ∧ balanced_sz l' (size r)) ∨ + (∃ r', raised (size r) r' ∧ balanced_sz (size l) r')) : + @balance_l α l x r = balance' l x r := +begin + rw [← balance_eq_balance' hl hr sl sr, balance_l_eq_balance sl sr], + { intro l0, rw l0 at H, + rcases H with ⟨_, ⟨⟨⟩⟩|⟨⟨⟩⟩, H⟩ | ⟨r', e, H⟩, + { exact balanced_sz_zero.1 H.symm }, + exact le_trans (raised_iff.1 e).1 (balanced_sz_zero.1 H.symm) }, + { intros l1 r1, + rcases H with ⟨l', e, H | ⟨H₁, H₂⟩⟩ | ⟨r', e, H | ⟨H₁, H₂⟩⟩, + { exact le_trans (le_trans (nat.le_add_left _ _) H) + (mul_pos dec_trivial l1 : (0:ℕ)<_) }, + { exact le_trans H₂ (nat.mul_le_mul_left _ (raised_iff.1 e).1) }, + { cases raised_iff.1 e, unfold delta, linarith }, + { exact le_trans (raised_iff.1 e).1 H₂ } } +end + +theorem balance_sz_dual {l r} + (H : (∃ l', raised (@size α l) l' ∧ balanced_sz l' (@size α r)) ∨ + ∃ r', raised r' (size r) ∧ balanced_sz (size l) r') : + (∃ l', raised l' (size (dual r)) ∧ balanced_sz l' (size (dual l))) ∨ + ∃ r', raised (size (dual l)) r' ∧ balanced_sz (size (dual r)) r' := +begin + rw [size_dual, size_dual], + exact H.symm.imp + (Exists.imp $ λ _, and.imp_right balanced_sz.symm) + (Exists.imp $ λ _, and.imp_right balanced_sz.symm) +end + +theorem size_balance_l {l x r} + (hl : balanced l) (hr : balanced r) + (sl : sized l) (sr : sized r) + (H : (∃ l', raised l' (size l) ∧ balanced_sz l' (size r)) ∨ + (∃ r', raised (size r) r' ∧ balanced_sz (size l) r')) : + size (@balance_l α l x r) = size l + size r + 1 := +by rw [balance_l_eq_balance' hl hr sl sr H, size_balance' sl sr] + +theorem all_balance_l {P l x r} + (hl : balanced l) (hr : balanced r) + (sl : sized l) (sr : sized r) + (H : (∃ l', raised l' (size l) ∧ balanced_sz l' (size r)) ∨ + (∃ r', raised (size r) r' ∧ balanced_sz (size l) r')) : + all P (@balance_l α l x r) ↔ all P l ∧ P x ∧ all P r := +by rw [balance_l_eq_balance' hl hr sl sr H, all_balance'] + +theorem balance_r_eq_balance' {l x r} + (hl : balanced l) (hr : balanced r) + (sl : sized l) (sr : sized r) + (H : (∃ l', raised (size l) l' ∧ balanced_sz l' (size r)) ∨ + (∃ r', raised r' (size r) ∧ balanced_sz (size l) r')) : + @balance_r α l x r = balance' l x r := +by rw [← dual_dual (balance_r l x r), dual_balance_r, + balance_l_eq_balance' hr.dual hl.dual sr.dual sl.dual (balance_sz_dual H), + ← dual_balance', dual_dual] + +theorem size_balance_r {l x r} + (hl : balanced l) (hr : balanced r) + (sl : sized l) (sr : sized r) + (H : (∃ l', raised (size l) l' ∧ balanced_sz l' (size r)) ∨ + (∃ r', raised r' (size r) ∧ balanced_sz (size l) r')) : + size (@balance_r α l x r) = size l + size r + 1 := +by rw [balance_r_eq_balance' hl hr sl sr H, size_balance' sl sr] + +theorem all_balance_r {P l x r} + (hl : balanced l) (hr : balanced r) + (sl : sized l) (sr : sized r) + (H : (∃ l', raised (size l) l' ∧ balanced_sz l' (size r)) ∨ + (∃ r', raised r' (size r) ∧ balanced_sz (size l) r')) : + all P (@balance_r α l x r) ↔ all P l ∧ P x ∧ all P r := +by rw [balance_r_eq_balance' hl hr sl sr H, all_balance'] + +/-! ### `bounded` -/ + +section +variable [preorder α] + +/-- `bounded t lo hi` says that every element `x ∈ t` is in the range `lo < x < hi`, and also this +property holds recursively in subtrees, making the full tree a BST. The bounds can be set to +`lo = ⊥` and `hi = ⊤` if we care only about the internal ordering constraints. -/ +def bounded : ordnode α → with_bot α → with_top α → Prop +| nil (some a) (some b) := a < b +| nil _ _ := true +| (node _ l x r) o₁ o₂ := bounded l o₁ ↑x ∧ bounded r ↑x o₂ + +theorem bounded.dual : ∀ {t : ordnode α} {o₁ o₂} (h : bounded t o₁ o₂), + @bounded (order_dual α) _ (dual t) o₂ o₁ +| nil o₁ o₂ h := by cases o₁; cases o₂; try {trivial}; exact h +| (node s l x r) _ _ ⟨ol, or⟩ := ⟨or.dual, ol.dual⟩ + +theorem bounded.dual_iff {t : ordnode α} {o₁ o₂} : bounded t o₁ o₂ ↔ + @bounded (order_dual α) _ (dual t) o₂ o₁ := +⟨bounded.dual, λ h, by have := bounded.dual h; + rwa [dual_dual, order_dual.preorder.dual_dual] at this⟩ + +theorem bounded.weak_left : ∀ {t : ordnode α} {o₁ o₂}, bounded t o₁ o₂ → bounded t ⊥ o₂ +| nil o₁ o₂ h := by cases o₂; try {trivial}; exact h +| (node s l x r) _ _ ⟨ol, or⟩ := ⟨ol.weak_left, or⟩ + +theorem bounded.weak_right : ∀ {t : ordnode α} {o₁ o₂}, bounded t o₁ o₂ → bounded t o₁ ⊤ +| nil o₁ o₂ h := by cases o₁; try {trivial}; exact h +| (node s l x r) _ _ ⟨ol, or⟩ := ⟨ol, or.weak_right⟩ + +theorem bounded.weak {t : ordnode α} {o₁ o₂} (h : bounded t o₁ o₂) : bounded t ⊥ ⊤ := +h.weak_left.weak_right + +theorem bounded.mono_left {x y : α} (xy : x ≤ y) : + ∀ {t : ordnode α} {o}, bounded t ↑y o → bounded t ↑x o +| nil none h := ⟨⟩ +| nil (some z) h := lt_of_le_of_lt xy h +| (node s l z r) o ⟨ol, or⟩ := ⟨ol.mono_left, or⟩ + +theorem bounded.mono_right {x y : α} (xy : x ≤ y) : + ∀ {t : ordnode α} {o}, bounded t o ↑x → bounded t o ↑y +| nil none h := ⟨⟩ +| nil (some z) h := lt_of_lt_of_le h xy +| (node s l z r) o ⟨ol, or⟩ := ⟨ol, or.mono_right⟩ + +theorem bounded.to_lt : ∀ {t : ordnode α} {x y : α}, bounded t x y → x < y +| nil x y h := h +| (node _ l y r) x z ⟨h₁, h₂⟩ := lt_trans h₁.to_lt h₂.to_lt + +theorem bounded.to_nil {t : ordnode α} : ∀ {o₁ o₂}, bounded t o₁ o₂ → bounded nil o₁ o₂ +| none _ h := ⟨⟩ +| (some _) none h := ⟨⟩ +| (some x) (some y) h := h.to_lt + +theorem bounded.trans_left {t₁ t₂ : ordnode α} {x : α} : + ∀ {o₁ o₂}, bounded t₁ o₁ ↑x → bounded t₂ ↑x o₂ → bounded t₂ o₁ o₂ +| none o₂ h₁ h₂ := h₂.weak_left +| (some y) o₂ h₁ h₂ := h₂.mono_left (le_of_lt h₁.to_lt) + +theorem bounded.trans_right {t₁ t₂ : ordnode α} {x : α} : + ∀ {o₁ o₂}, bounded t₁ o₁ ↑x → bounded t₂ ↑x o₂ → bounded t₁ o₁ o₂ +| o₁ none h₁ h₂ := h₁.weak_right +| o₁ (some y) h₁ h₂ := h₁.mono_right (le_of_lt h₂.to_lt) + +theorem bounded.mem_lt : ∀ {t o} {x : α}, bounded t o ↑x → all (< x) t +| nil o x _ := ⟨⟩ +| (node _ l y r) o x ⟨h₁, h₂⟩ := + ⟨h₁.mem_lt.imp (λ z h, lt_trans h h₂.to_lt), h₂.to_lt, h₂.mem_lt⟩ + +theorem bounded.mem_gt : ∀ {t o} {x : α}, bounded t ↑x o → all (> x) t +| nil o x _ := ⟨⟩ +| (node _ l y r) o x ⟨h₁, h₂⟩ := + ⟨h₁.mem_gt, h₁.to_lt, h₂.mem_gt.imp (λ z, lt_trans h₁.to_lt)⟩ + +theorem bounded.of_lt : ∀ {t o₁ o₂} {x : α}, + bounded t o₁ o₂ → bounded nil o₁ ↑x → all (< x) t → bounded t o₁ ↑x +| nil o₁ o₂ x _ hn _ := hn +| (node _ l y r) o₁ o₂ x ⟨h₁, h₂⟩ hn ⟨al₁, al₂, al₃⟩ := ⟨h₁, h₂.of_lt al₂ al₃⟩ + +theorem bounded.of_gt : ∀ {t o₁ o₂} {x : α}, + bounded t o₁ o₂ → bounded nil ↑x o₂ → all (> x) t → bounded t ↑x o₂ +| nil o₁ o₂ x _ hn _ := hn +| (node _ l y r) o₁ o₂ x ⟨h₁, h₂⟩ hn ⟨al₁, al₂, al₃⟩ := ⟨h₁.of_gt al₂ al₁, h₂⟩ + +theorem bounded.to_sep {t₁ t₂ o₁ o₂} {x : α} + (h₁ : bounded t₁ o₁ ↑x) (h₂ : bounded t₂ ↑x o₂) : t₁.all (λ y, t₂.all (λ z : α, y < z)) := +h₁.mem_lt.imp $ λ y yx, h₂.mem_gt.imp $ λ z xz, lt_trans yx xz + +end + +/-! ### `valid` -/ + +section +variable [preorder α] + +/-- The validity predicate for an `ordnode` subtree. This asserts that the `size` fields are +correct, the tree is balanced, and the elements of the tree are organized according to the +ordering. This version of `valid` also puts all elements in the tree in the interval `(lo, hi)`. -/ +structure valid' (lo : with_bot α) (t : ordnode α) (hi : with_top α) : Prop := +(ord : t.bounded lo hi) +(sz : t.sized) +(bal : t.balanced) + +/-- The validity predicate for an `ordnode` subtree. This asserts that the `size` fields are +correct, the tree is balanced, and the elements of the tree are organized according to the +ordering. -/ +def valid (t : ordnode α) : Prop := valid' ⊥ t ⊤ + +theorem valid'.mono_left {x y : α} (xy : x ≤ y) + {t : ordnode α} {o} (h : valid' ↑y t o) : valid' ↑x t o := +⟨h.1.mono_left xy, h.2, h.3⟩ + +theorem valid'.mono_right {x y : α} (xy : x ≤ y) + {t : ordnode α} {o} (h : valid' o t ↑x) : valid' o t ↑y := +⟨h.1.mono_right xy, h.2, h.3⟩ + +theorem valid'.trans_left {t₁ t₂ : ordnode α} {x : α} {o₁ o₂} + (h : bounded t₁ o₁ ↑x) (H : valid' ↑x t₂ o₂) : valid' o₁ t₂ o₂ := +⟨h.trans_left H.1, H.2, H.3⟩ + +theorem valid'.trans_right {t₁ t₂ : ordnode α} {x : α} {o₁ o₂} + (H : valid' o₁ t₁ ↑x) (h : bounded t₂ ↑x o₂) : valid' o₁ t₁ o₂ := +⟨H.1.trans_right h, H.2, H.3⟩ + +theorem valid'.of_lt {t : ordnode α} {x : α} {o₁ o₂} + (H : valid' o₁ t o₂) (h₁ : bounded nil o₁ ↑x) (h₂ : all (< x) t) : valid' o₁ t ↑x := +⟨H.1.of_lt h₁ h₂, H.2, H.3⟩ + +theorem valid'.of_gt {t : ordnode α} {x : α} {o₁ o₂} + (H : valid' o₁ t o₂) (h₁ : bounded nil ↑x o₂) (h₂ : all (> x) t) : valid' ↑x t o₂ := +⟨H.1.of_gt h₁ h₂, H.2, H.3⟩ + +theorem valid'.valid {t o₁ o₂} (h : @valid' α _ o₁ t o₂) : valid t := ⟨h.1.weak, h.2, h.3⟩ + +theorem valid'_nil {o₁ o₂} (h : bounded nil o₁ o₂) : valid' o₁ (@nil α) o₂ := ⟨h, ⟨⟩, ⟨⟩⟩ + +theorem valid_nil : valid (@nil α) := valid'_nil ⟨⟩ + +theorem valid'.node {s l x r o₁ o₂} + (hl : valid' o₁ l ↑x) (hr : valid' ↑x r o₂) + (H : balanced_sz (size l) (size r)) (hs : s = size l + size r + 1) : + valid' o₁ (@node α s l x r) o₂ := +⟨⟨hl.1, hr.1⟩, ⟨hs, hl.2, hr.2⟩, ⟨H, hl.3, hr.3⟩⟩ + +theorem valid'.dual : ∀ {t : ordnode α} {o₁ o₂} (h : valid' o₁ t o₂), + @valid' (order_dual α) _ o₂ (dual t) o₁ +| nil o₁ o₂ h := valid'_nil h.1.dual +| (node s l x r) o₁ o₂ ⟨⟨ol, or⟩, ⟨rfl, sl, sr⟩, ⟨b, bl, br⟩⟩ := + let ⟨ol', sl', bl'⟩ := valid'.dual ⟨ol, sl, bl⟩, + ⟨or', sr', br'⟩ := valid'.dual ⟨or, sr, br⟩ in + ⟨⟨or', ol'⟩, + ⟨by simp [size_dual, add_comm], sr', sl'⟩, + ⟨by rw [size_dual, size_dual]; exact b.symm, br', bl'⟩⟩ + +theorem valid'.dual_iff {t : ordnode α} {o₁ o₂} : valid' o₁ t o₂ ↔ + @valid' (order_dual α) _ o₂ (dual t) o₁ := +⟨valid'.dual, λ h, by have := valid'.dual h; + rwa [dual_dual, order_dual.preorder.dual_dual] at this⟩ + +theorem valid.dual {t : ordnode α} : valid t → + @valid (order_dual α) _ (dual t) := valid'.dual + +theorem valid.dual_iff {t : ordnode α} : valid t ↔ + @valid (order_dual α) _ (dual t) := valid'.dual_iff + +theorem valid'.left {s l x r o₁ o₂} (H : valid' o₁ (@node α s l x r) o₂) : valid' o₁ l x := +⟨H.1.1, H.2.2.1, H.3.2.1⟩ + +theorem valid'.right {s l x r o₁ o₂} (H : valid' o₁ (@node α s l x r) o₂) : valid' ↑x r o₂ := +⟨H.1.2, H.2.2.2, H.3.2.2⟩ + +theorem valid.left {s l x r} (H : valid (@node α s l x r)) : valid l := H.left.valid + +theorem valid.right {s l x r} (H : valid (@node α s l x r)) : valid r := H.right.valid + +theorem valid.size_eq {s l x r} (H : valid (@node α s l x r)) : + size (@node α s l x r) = size l + size r + 1 := H.2.1 + +theorem valid'.node' {l x r o₁ o₂} (hl : valid' o₁ l ↑x) (hr : valid' ↑x r o₂) + (H : balanced_sz (size l) (size r)) : valid' o₁ (@node' α l x r) o₂ := +hl.node hr H rfl + +theorem valid'_singleton {x : α} {o₁ o₂} + (h₁ : bounded nil o₁ ↑x) (h₂ : bounded nil ↑x o₂) : valid' o₁ (singleton x : ordnode α) o₂ := +(valid'_nil h₁).node (valid'_nil h₂) (or.inl zero_le_one) rfl + +theorem valid_singleton {x : α} : valid (singleton x : ordnode α) := valid'_singleton ⟨⟩ ⟨⟩ + +theorem valid'.node3_l {l x m y r o₁ o₂} + (hl : valid' o₁ l ↑x) (hm : valid' ↑x m ↑y) (hr : valid' ↑y r o₂) + (H1 : balanced_sz (size l) (size m)) + (H2 : balanced_sz (size l + size m + 1) (size r)) : + valid' o₁ (@node3_l α l x m y r) o₂ := +(hl.node' hm H1).node' hr H2 + +theorem valid'.node3_r {l x m y r o₁ o₂} + (hl : valid' o₁ l ↑x) (hm : valid' ↑x m ↑y) (hr : valid' ↑y r o₂) + (H1 : balanced_sz (size l) (size m + size r + 1)) + (H2 : balanced_sz (size m) (size r)) : + valid' o₁ (@node3_r α l x m y r) o₂ := +hl.node' (hm.node' hr H2) H1 + +theorem valid'.node4_l_lemma₁ {a b c d : ℕ} + (lr₂ : 3 * (b + c + 1 + d) ≤ 16 * a + 9) + (mr₂ : b + c + 1 ≤ 3 * d) + (mm₁ : b ≤ 3 * c) : b < 3 * a + 1 := by linarith + +theorem valid'.node4_l_lemma₂ {b c d : ℕ} (mr₂ : b + c + 1 ≤ 3 * d) : c ≤ 3 * d := by linarith + +theorem valid'.node4_l_lemma₃ {b c d : ℕ} + (mr₁ : 2 * d ≤ b + c + 1) + (mm₁ : b ≤ 3 * c) : d ≤ 3 * c := by linarith + +theorem valid'.node4_l_lemma₄ {a b c d : ℕ} + (lr₁ : 3 * a ≤ b + c + 1 + d) + (mr₂ : b + c + 1 ≤ 3 * d) + (mm₁ : b ≤ 3 * c) : a + b + 1 ≤ 3 * (c + d + 1) := by linarith + +theorem valid'.node4_l_lemma₅ {a b c d : ℕ} + (lr₂ : 3 * (b + c + 1 + d) ≤ 16 * a + 9) + (mr₁ : 2 * d ≤ b + c + 1) + (mm₂ : c ≤ 3 * b) : c + d + 1 ≤ 3 * (a + b + 1) := by linarith + +theorem valid'.node4_l {l x m y r o₁ o₂} + (hl : valid' o₁ l ↑x) (hm : valid' ↑x m ↑y) (hr : valid' ↑y r o₂) + (Hm : 0 < size m) + (H : (size l = 0 ∧ size m = 1 ∧ size r ≤ 1) ∨ + (0 < size l ∧ ratio * size r ≤ size m ∧ + delta * size l ≤ size m + size r ∧ + 3 * (size m + size r) ≤ 16 * size l + 9 ∧ + size m ≤ delta * size r)) : + valid' o₁ (@node4_l α l x m y r) o₂ := +begin + cases m with s ml z mr, {cases Hm}, + suffices : balanced_sz (size l) (size ml) ∧ + balanced_sz (size mr) (size r) ∧ + balanced_sz (size l + size ml + 1) (size mr + size r + 1), + from (valid'.node' (hl.node' hm.left this.1) (hm.right.node' hr this.2.1) this.2.2), + rcases H with ⟨l0, m1, r0⟩ | ⟨l0, mr₁, lr₁, lr₂, mr₂⟩, + { rw [hm.2.size_eq, nat.succ_inj', add_eq_zero_iff] at m1, + rw [l0, m1.1, m1.2], rcases size r with _|_|_; exact dec_trivial }, + { cases nat.eq_zero_or_pos (size r) with r0 r0, + { rw r0 at mr₂, cases not_le_of_lt Hm mr₂ }, + rw [hm.2.size_eq] at lr₁ lr₂ mr₁ mr₂, + by_cases mm : size ml + size mr ≤ 1, + { have r1 := le_antisymm ((mul_le_mul_left dec_trivial).1 + (le_trans mr₁ (nat.succ_le_succ mm) : _ ≤ ratio * 1)) r0, + rw [r1, add_assoc] at lr₁, + have l1 := le_antisymm ((mul_le_mul_left dec_trivial).1 + (le_trans lr₁ (add_le_add_right mm 2) : _ ≤ delta * 1)) l0, + rw [l1, r1], + cases size ml; cases size mr, + { exact dec_trivial }, + { rw zero_add at mm, rcases mm with _|⟨_,⟨⟩⟩, + exact dec_trivial }, + { rcases mm with _|⟨_,⟨⟩⟩, exact dec_trivial }, + { rw nat.succ_add at mm, rcases mm with _|⟨_,⟨⟩⟩ } }, + rcases hm.3.1.resolve_left mm with ⟨mm₁, mm₂⟩, + cases nat.eq_zero_or_pos (size ml) with ml0 ml0, + { rw [ml0, mul_zero, nat.le_zero_iff] at mm₂, + rw [ml0, mm₂] at mm, cases mm dec_trivial }, + cases nat.eq_zero_or_pos (size mr) with mr0 mr0, + { rw [mr0, mul_zero, nat.le_zero_iff] at mm₁, + rw [mr0, mm₁] at mm, cases mm dec_trivial }, + have : 2 * size l ≤ size ml + size mr + 1, + { have := nat.mul_le_mul_left _ lr₁, + rw [mul_left_comm, mul_add] at this, + have := le_trans this (add_le_add_left mr₁ _), + rw [← nat.succ_mul] at this, + exact (mul_le_mul_left dec_trivial).1 this }, + refine ⟨or.inr ⟨_, _⟩, or.inr ⟨_, _⟩, or.inr ⟨_, _⟩⟩, + { refine (mul_le_mul_left dec_trivial).1 (le_trans this _), + rw [two_mul, nat.succ_le_iff], + refine add_lt_add_of_lt_of_le _ mm₂, + simpa using (mul_lt_mul_right ml0).2 (dec_trivial:1<3) }, + { exact nat.le_of_lt_succ (valid'.node4_l_lemma₁ lr₂ mr₂ mm₁) }, + { exact valid'.node4_l_lemma₂ mr₂ }, + { exact valid'.node4_l_lemma₃ mr₁ mm₁ }, + { exact valid'.node4_l_lemma₄ lr₁ mr₂ mm₁ }, + { exact valid'.node4_l_lemma₅ lr₂ mr₁ mm₂ } } +end + +theorem valid'.rotate_l_lemma₁ {a b c : ℕ} + (H2 : 3 * a ≤ b + c) (hb₂ : c ≤ 3 * b) : a ≤ 3 * b := by linarith + +theorem valid'.rotate_l_lemma₂ {a b c : ℕ} + (H3 : 2 * (b + c) ≤ 9 * a + 3) (h : b < 2 * c) : b < 3 * a + 1 := by linarith + +theorem valid'.rotate_l_lemma₃ {a b c : ℕ} + (H2 : 3 * a ≤ b + c) (h : b < 2 * c) : a + b < 3 * c := by linarith + +theorem valid'.rotate_l_lemma₄ {a b : ℕ} + (H3 : 2 * b ≤ 9 * a + 3) : 3 * b ≤ 16 * a + 9 := by linarith + +theorem valid'.rotate_l {l x r o₁ o₂} (hl : valid' o₁ l ↑x) (hr : valid' ↑x r o₂) + (H1 : ¬ size l + size r ≤ 1) + (H2 : delta * size l < size r) + (H3 : 2 * size r ≤ 9 * size l + 5 ∨ size r ≤ 3) : + valid' o₁ (@rotate_l α l x r) o₂ := +begin + cases r with rs rl rx rr, {cases H2}, + rw [hr.2.size_eq, nat.lt_succ_iff] at H2, + rw [hr.2.size_eq] at H3, + replace H3 : 2 * (size rl + size rr) ≤ 9 * size l + 3 ∨ size rl + size rr ≤ 2 := + H3.imp (@nat.le_of_add_le_add_right 2 _ _) nat.le_of_succ_le_succ, + have H3_0 : size l = 0 → size rl + size rr ≤ 2, + { intro l0, rw l0 at H3, + exact (or_iff_right_of_imp $ by exact λ h, + (mul_le_mul_left dec_trivial).1 (le_trans h dec_trivial)).1 H3 }, + have H3p : size l > 0 → 2 * (size rl + size rr) ≤ 9 * size l + 3 := + λ l0 : 1 ≤ size l, (or_iff_left_of_imp $ by intro; linarith).1 H3, + have ablem : ∀ {a b : ℕ}, 1 ≤ a → a + b ≤ 2 → b ≤ 1, {intros, linarith}, + have hlp : size l > 0 → ¬ size rl + size rr ≤ 1 := λ l0 hb, absurd + (le_trans (le_trans (nat.mul_le_mul_left _ l0) H2) hb) dec_trivial, + rw rotate_l, split_ifs, + { have rr0 : size rr > 0 := (mul_lt_mul_left dec_trivial).1 + (lt_of_le_of_lt (nat.zero_le _) h : ratio * 0 < _), + suffices : balanced_sz (size l) (size rl) ∧ balanced_sz (size l + size rl + 1) (size rr), + { exact hl.node3_l hr.left hr.right this.1 this.2 }, + cases nat.eq_zero_or_pos (size l) with l0 l0, + { rw l0, replace H3 := H3_0 l0, + have := hr.3.1, + cases nat.eq_zero_or_pos (size rl) with rl0 rl0, + { rw rl0 at this ⊢, + rw le_antisymm (balanced_sz_zero.1 this.symm) rr0, + exact dec_trivial }, + have rr1 : size rr = 1 := le_antisymm (ablem rl0 H3) rr0, + rw add_comm at H3, + rw [rr1, show size rl = 1, from le_antisymm (ablem rr0 H3) rl0], + exact dec_trivial }, + replace H3 := H3p l0, + rcases hr.3.1.resolve_left (hlp l0) with ⟨hb₁, hb₂⟩, + cases nat.eq_zero_or_pos (size rl) with rl0 rl0, + { rw rl0 at hb₂, cases not_le_of_gt rr0 hb₂ }, + cases eq_or_lt_of_le (show 1 ≤ size rr, from rr0) with rr1 rr1, + { rw [← rr1] at h H2 ⊢, + have : size rl = 1 := le_antisymm (nat.lt_succ_iff.1 h) rl0, + rw this at H2, + exact absurd (le_trans (nat.mul_le_mul_left _ l0) H2) dec_trivial }, + refine ⟨or.inr ⟨_, _⟩, or.inr ⟨_, _⟩⟩, + { exact valid'.rotate_l_lemma₁ H2 hb₂ }, + { exact nat.le_of_lt_succ (valid'.rotate_l_lemma₂ H3 h) }, + { exact valid'.rotate_l_lemma₃ H2 h }, + { exact le_trans hb₂ (nat.mul_le_mul_left _ $ + le_trans (nat.le_add_left _ _) (nat.le_add_right _ _)) } }, + { cases nat.eq_zero_or_pos (size rl) with rl0 rl0, + { rw [rl0, not_lt, nat.le_zero_iff, nat.mul_eq_zero] at h, + replace h := h.resolve_left dec_trivial, + rw [rl0, h, nat.le_zero_iff, nat.mul_eq_zero] at H2, + rw [hr.2.size_eq, rl0, h, H2.resolve_left dec_trivial] at H1, + cases H1 dec_trivial }, + refine hl.node4_l hr.left hr.right rl0 _, + cases nat.eq_zero_or_pos (size l) with l0 l0, + { replace H3 := H3_0 l0, + cases nat.eq_zero_or_pos (size rr) with rr0 rr0, + { have := hr.3.1, + rw rr0 at this, + exact or.inl ⟨l0, + le_antisymm (balanced_sz_zero.1 this) rl0, + rr0.symm ▸ zero_le_one⟩ }, + exact or.inl ⟨l0, + le_antisymm (ablem rr0 $ by rwa add_comm) rl0, + ablem rl0 H3⟩ }, + exact or.inr ⟨l0, not_lt.1 h, H2, + valid'.rotate_l_lemma₄ (H3p l0), + (hr.3.1.resolve_left (hlp l0)).1⟩ } +end + +theorem valid'.rotate_r {l x r o₁ o₂} (hl : valid' o₁ l ↑x) (hr : valid' ↑x r o₂) + (H1 : ¬ size l + size r ≤ 1) + (H2 : delta * size r < size l) + (H3 : 2 * size l ≤ 9 * size r + 5 ∨ size l ≤ 3) : + valid' o₁ (@rotate_r α l x r) o₂ := +begin + refine valid'.dual_iff.2 _, + rw dual_rotate_r, + refine hr.dual.rotate_l hl.dual _ _ _, + { rwa [size_dual, size_dual, add_comm] }, + { rwa [size_dual, size_dual] }, + { rwa [size_dual, size_dual] } +end + +theorem valid'.balance'_aux {l x r o₁ o₂} (hl : valid' o₁ l ↑x) (hr : valid' ↑x r o₂) + (H₁ : 2 * @size α r ≤ 9 * size l + 5 ∨ size r ≤ 3) + (H₂ : 2 * @size α l ≤ 9 * size r + 5 ∨ size l ≤ 3) : + valid' o₁ (@balance' α l x r) o₂ := +begin + rw balance', split_ifs, + { exact hl.node' hr (or.inl h) }, + { exact hl.rotate_l hr h h_1 H₁ }, + { exact hl.rotate_r hr h h_2 H₂ }, + { exact hl.node' hr (or.inr ⟨not_lt.1 h_2, not_lt.1 h_1⟩) } +end + +theorem valid'.balance'_lemma {α l l' r r'} + (H1 : balanced_sz l' r') + (H2 : nat.dist (@size α l) l' ≤ 1 ∧ size r = r' ∨ + nat.dist (size r) r' ≤ 1 ∧ size l = l') : + 2 * @size α r ≤ 9 * size l + 5 ∨ size r ≤ 3 := +begin + suffices : @size α r ≤ 3 * (size l + 1), + { cases nat.eq_zero_or_pos (size l) with l0 l0, + { apply or.inr, rwa l0 at this }, + change 1 ≤ _ at l0, apply or.inl, linarith }, + rcases H2 with ⟨hl, rfl⟩ | ⟨hr, rfl⟩; + rcases H1 with h | ⟨h₁, h₂⟩, + { exact le_trans (nat.le_add_left _ _) (le_trans h (nat.le_add_left _ _)) }, + { exact le_trans h₂ (nat.mul_le_mul_left _ $ + le_trans (nat.dist_tri_right _ _) (nat.add_le_add_left hl _)) }, + { exact le_trans (nat.dist_tri_left' _ _) + (le_trans (add_le_add hr (le_trans (nat.le_add_left _ _) h)) dec_trivial) }, + { rw nat.mul_succ, + exact le_trans (nat.dist_tri_right' _ _) + (add_le_add h₂ (le_trans hr dec_trivial)) }, +end + +theorem valid'.balance' {l x r o₁ o₂} (hl : valid' o₁ l ↑x) (hr : valid' ↑x r o₂) + (H : ∃ l' r', balanced_sz l' r' ∧ + (nat.dist (size l) l' ≤ 1 ∧ size r = r' ∨ + nat.dist (size r) r' ≤ 1 ∧ size l = l')) : + valid' o₁ (@balance' α l x r) o₂ := +let ⟨l', r', H1, H2⟩ := H in +valid'.balance'_aux hl hr (valid'.balance'_lemma H1 H2) (valid'.balance'_lemma H1.symm H2.symm) + +theorem valid'.balance {l x r o₁ o₂} (hl : valid' o₁ l ↑x) (hr : valid' ↑x r o₂) + (H : ∃ l' r', balanced_sz l' r' ∧ + (nat.dist (size l) l' ≤ 1 ∧ size r = r' ∨ + nat.dist (size r) r' ≤ 1 ∧ size l = l')) : + valid' o₁ (@balance α l x r) o₂ := +by rw balance_eq_balance' hl.3 hr.3 hl.2 hr.2; exact hl.balance' hr H + +theorem valid'.balance_l_aux {l x r o₁ o₂} (hl : valid' o₁ l ↑x) (hr : valid' ↑x r o₂) + (H₁ : size l = 0 → size r ≤ 1) + (H₂ : 1 ≤ size l → 1 ≤ size r → size r ≤ delta * size l) + (H₃ : 2 * @size α l ≤ 9 * size r + 5 ∨ size l ≤ 3) : + valid' o₁ (@balance_l α l x r) o₂ := +begin + rw [balance_l_eq_balance hl.2 hr.2 H₁ H₂, balance_eq_balance' hl.3 hr.3 hl.2 hr.2], + refine hl.balance'_aux hr (or.inl _) H₃, + cases nat.eq_zero_or_pos (size r) with r0 r0, + { rw r0, exact nat.zero_le _ }, + cases nat.eq_zero_or_pos (size l) with l0 l0, + { rw l0, exact le_trans (nat.mul_le_mul_left _ (H₁ l0)) dec_trivial }, + replace H₂ : _ ≤ 3 * _ := H₂ l0 r0, linarith +end + +theorem valid'.balance_l {l x r o₁ o₂} (hl : valid' o₁ l ↑x) (hr : valid' ↑x r o₂) + (H : (∃ l', raised l' (size l) ∧ balanced_sz l' (size r)) ∨ + (∃ r', raised (size r) r' ∧ balanced_sz (size l) r')) : + valid' o₁ (@balance_l α l x r) o₂ := +begin + rw balance_l_eq_balance' hl.3 hr.3 hl.2 hr.2 H, + refine hl.balance' hr _, + rcases H with ⟨l', e, H⟩ | ⟨r', e, H⟩, + { exact ⟨_, _, H, or.inl ⟨e.dist_le', rfl⟩⟩ }, + { exact ⟨_, _, H, or.inr ⟨e.dist_le, rfl⟩⟩ }, +end + +theorem valid'.balance_r_aux {l x r o₁ o₂} (hl : valid' o₁ l ↑x) (hr : valid' ↑x r o₂) + (H₁ : size r = 0 → size l ≤ 1) + (H₂ : 1 ≤ size r → 1 ≤ size l → size l ≤ delta * size r) + (H₃ : 2 * @size α r ≤ 9 * size l + 5 ∨ size r ≤ 3) : + valid' o₁ (@balance_r α l x r) o₂ := +begin + rw [valid'.dual_iff, dual_balance_r], + have := hr.dual.balance_l_aux hl.dual, + rw [size_dual, size_dual] at this, + exact this H₁ H₂ H₃ +end + +theorem valid'.balance_r {l x r o₁ o₂} (hl : valid' o₁ l ↑x) (hr : valid' ↑x r o₂) + (H : (∃ l', raised (size l) l' ∧ balanced_sz l' (size r)) ∨ + (∃ r', raised r' (size r) ∧ balanced_sz (size l) r')) : + valid' o₁ (@balance_r α l x r) o₂ := +by rw [valid'.dual_iff, dual_balance_r]; exact +hr.dual.balance_l hl.dual (balance_sz_dual H) + +theorem valid'.erase_max_aux {s l x r o₁ o₂} + (H : valid' o₁ (node s l x r) o₂) : + valid' o₁ (@erase_max α (node' l x r)) ↑(find_max' x r) ∧ + size (node' l x r) = size (erase_max (node' l x r)) + 1 := +begin + have := H.2.eq_node', rw this at H, clear this, + induction r with rs rl rx rr IHrl IHrr generalizing l x o₁, + { exact ⟨H.left, rfl⟩ }, + have := H.2.2.2.eq_node', rw this at H ⊢, + rcases IHrr H.right with ⟨h, e⟩, + refine ⟨valid'.balance_l H.left h (or.inr ⟨_, or.inr e, H.3.1⟩), _⟩, + rw [erase_max, size_balance_l H.3.2.1 h.3 H.2.2.1 h.2 (or.inr ⟨_, or.inr e, H.3.1⟩)], + rw [size, e], refl +end + +theorem valid'.erase_min_aux {s l x r o₁ o₂} + (H : valid' o₁ (node s l x r) o₂) : + valid' ↑(find_min' l x) (@erase_min α (node' l x r)) o₂ ∧ + size (node' l x r) = size (erase_min (node' l x r)) + 1 := +by have := H.dual.erase_max_aux; + rwa [← dual_node', size_dual, ← dual_erase_min, + size_dual, ← valid'.dual_iff, find_max'_dual] at this + +theorem erase_min.valid : ∀ {t} (h : @valid α _ t), valid (erase_min t) +| nil _ := valid_nil +| (node _ l x r) h := by rw h.2.eq_node'; exact h.erase_min_aux.1.valid + +theorem erase_max.valid {t} (h : @valid α _ t) : valid (erase_max t) := +by rw [valid.dual_iff, dual_erase_max]; exact erase_min.valid h.dual + +theorem valid'.glue_aux {l r o₁ o₂} + (hl : valid' o₁ l o₂) (hr : valid' o₁ r o₂) + (sep : l.all (λ x, r.all (λ y, x < y))) + (bal : balanced_sz (size l) (size r)) : + valid' o₁ (@glue α l r) o₂ ∧ size (glue l r) = size l + size r := +begin + cases l with ls ll lx lr, {exact ⟨hr, (zero_add _).symm⟩ }, + cases r with rs rl rx rr, {exact ⟨hl, rfl⟩ }, + dsimp [glue], split_ifs, + { rw [split_max_eq, glue], + cases valid'.erase_max_aux hl with v e, + suffices H, + refine ⟨valid'.balance_r v (hr.of_gt _ _) H, _⟩, + { refine find_max'_all lx lr hl.1.2.to_nil (sep.2.2.imp _), + exact λ x h, hr.1.2.to_nil.mono_left (le_of_lt h.2.1) }, + { exact @find_max'_all _ (λ a, all (> a) (node rs rl rx rr)) lx lr sep.2.1 sep.2.2 }, + { rw [size_balance_r v.3 hr.3 v.2 hr.2 H, add_right_comm, ← e, hl.2.1], refl }, + { refine or.inl ⟨_, or.inr e, _⟩, + rwa hl.2.eq_node' at bal } }, + { rw [split_min_eq, glue], + cases valid'.erase_min_aux hr with v e, + suffices H, + refine ⟨valid'.balance_l (hl.of_lt _ _) v H, _⟩, + { refine @find_min'_all _ (λ a, bounded nil o₁ ↑a) rl rx (sep.2.1.1.imp _) hr.1.1.to_nil, + exact λ y h, hl.1.1.to_nil.mono_right (le_of_lt h) }, + { exact @find_min'_all _ (λ a, all (< a) (node ls ll lx lr)) rl rx + (all_iff_forall.2 $ λ x hx, sep.imp $ λ y hy, all_iff_forall.1 hy.1 _ hx) + (sep.imp $ λ y hy, hy.2.1) }, + { rw [size_balance_l hl.3 v.3 hl.2 v.2 H, add_assoc, ← e, hr.2.1], refl }, + { refine or.inr ⟨_, or.inr e, _⟩, + rwa hr.2.eq_node' at bal } }, +end + +theorem valid'.glue {l x r o₁ o₂} + (hl : valid' o₁ l ↑(x:α)) (hr : valid' ↑x r o₂) : + balanced_sz (size l) (size r) → + valid' o₁ (@glue α l r) o₂ ∧ size (@glue α l r) = size l + size r := +valid'.glue_aux (hl.trans_right hr.1) (hr.trans_left hl.1) (hl.1.to_sep hr.1) + +theorem valid'.merge_lemma {a b c : ℕ} + (h₁ : 3 * a < b + c + 1) (h₂ : b ≤ 3 * c) : 2 * (a + b) ≤ 9 * c + 5 := +by linarith + +theorem valid'.merge_aux₁ {o₁ o₂ ls ll lx lr rs rl rx rr t} + (hl : valid' o₁ (@node α ls ll lx lr) o₂) + (hr : valid' o₁ (node rs rl rx rr) o₂) + (h : delta * ls < rs) + (v : valid' o₁ t ↑rx) + (e : size t = ls + size rl) : + valid' o₁ (balance_l t rx rr) o₂ ∧ size (balance_l t rx rr) = ls + rs := +begin + rw hl.2.1 at e, + rw [hl.2.1, hr.2.1, delta] at h, + rcases hr.3.1 with H|⟨hr₁, hr₂⟩, {linarith}, + suffices H₂, suffices H₁, + refine ⟨valid'.balance_l_aux v hr.right H₁ H₂ _, _⟩, + { rw e, exact or.inl (valid'.merge_lemma h hr₁) }, + { rw [balance_l_eq_balance v.2 hr.2.2.2 H₁ H₂, balance_eq_balance' v.3 hr.3.2.2 v.2 hr.2.2.2, + size_balance' v.2 hr.2.2.2, e, hl.2.1, hr.2.1], simp [add_comm, add_left_comm] }, + { rw [e, add_right_comm], rintro ⟨⟩ }, + { intros _ h₁, rw e, unfold delta at hr₂ ⊢, linarith } +end + +theorem valid'.merge_aux {l r o₁ o₂} + (hl : valid' o₁ l o₂) (hr : valid' o₁ r o₂) + (sep : l.all (λ x, r.all (λ y, x < y))) : + valid' o₁ (@merge α l r) o₂ ∧ size (merge l r) = size l + size r := +begin + induction l with ls ll lx lr IHll IHlr generalizing o₁ o₂ r, + { exact ⟨hr, (zero_add _).symm⟩ }, + induction r with rs rl rx rr IHrl IHrr generalizing o₁ o₂, + { exact ⟨hl, rfl⟩ }, + rw [merge_node], split_ifs, + { cases IHrl (sep.imp $ λ x h, h.1) + (hl.of_lt hr.1.1.to_nil $ sep.imp $ λ x h, h.2.1) hr.left with v e, + exact valid'.merge_aux₁ hl hr h v e }, + { cases IHlr hl.right (hr.of_gt hl.1.2.to_nil sep.2.1) sep.2.2 with v e, + have := valid'.merge_aux₁ hr.dual hl.dual h_1 v.dual, + rw [size_dual, add_comm, size_dual, + ← dual_balance_r, ← valid'.dual_iff, size_dual, add_comm rs] at this, + exact this e }, + { refine valid'.glue_aux hl hr sep (or.inr ⟨not_lt.1 h_1, not_lt.1 h⟩) } +end + +theorem valid.merge {l r} (hl : valid l) (hr : valid r) + (sep : l.all (λ x, r.all (λ y, x < y))) : valid (@merge α l r) := +(valid'.merge_aux hl hr sep).1 + +theorem insert_with.valid_aux [is_total α (≤)] [@decidable_rel α (≤)] + (f : α → α) (x : α) (hf : ∀ y, x ≤ y ∧ y ≤ x → x ≤ f y ∧ f y ≤ x) : ∀ {t o₁ o₂}, + valid' o₁ t o₂ → bounded nil o₁ ↑x → bounded nil ↑x o₂ → + valid' o₁ (insert_with f x t) o₂ ∧ + raised (size t) (size (insert_with f x t)) +| nil o₁ o₂ _ bl br := ⟨valid'_singleton bl br, or.inr rfl⟩ +| (node sz l y r) o₁ o₂ h bl br := begin + rw [insert_with, cmp_le], + split_ifs; rw [insert_with], + { rcases h with ⟨⟨lx, xr⟩, hs, hb⟩, + rcases hf _ ⟨h_1, h_2⟩ with ⟨xf, fx⟩, + refine ⟨⟨⟨lx.mono_right (le_trans h_2 xf), + xr.mono_left (le_trans fx h_1)⟩, hs, hb⟩, or.inl rfl⟩ }, + { rcases insert_with.valid_aux h.left bl (lt_of_le_not_le h_1 h_2) with ⟨vl, e⟩, + suffices H, + { refine ⟨vl.balance_l h.right H, _⟩, + rw [size_balance_l vl.3 h.3.2.2 vl.2 h.2.2.2 H, h.2.size_eq], + refine (e.add_right _).add_right _ }, + { exact or.inl ⟨_, e, h.3.1⟩ } }, + { have : y < x := lt_of_le_not_le ((total_of (≤) _ _).resolve_left h_1) h_1, + rcases insert_with.valid_aux h.right this br with ⟨vr, e⟩, + suffices H, + { refine ⟨h.left.balance_r vr H, _⟩, + rw [size_balance_r h.3.2.1 vr.3 h.2.2.1 vr.2 H, h.2.size_eq], + refine (e.add_left _).add_right _ }, + { exact or.inr ⟨_, e, h.3.1⟩ } }, + end + +theorem insert_with.valid [is_total α (≤)] [@decidable_rel α (≤)] + (f : α → α) (x : α) (hf : ∀ y, x ≤ y ∧ y ≤ x → x ≤ f y ∧ f y ≤ x) + {t} (h : valid t) : valid (insert_with f x t) := +(insert_with.valid_aux _ _ hf h ⟨⟩ ⟨⟩).1 + +theorem insert_eq_insert_with [@decidable_rel α (≤)] + (x : α) : ∀ t, ordnode.insert x t = insert_with (λ _, x) x t +| nil := rfl +| (node _ l y r) := by unfold ordnode.insert insert_with; + cases cmp_le x y; unfold ordnode.insert insert_with; simp [insert_eq_insert_with] + +theorem insert.valid [is_total α (≤)] [@decidable_rel α (≤)] + (x : α) {t} (h : valid t) : valid (ordnode.insert x t) := +by rw insert_eq_insert_with; exact +insert_with.valid _ _ (λ _ _, ⟨le_refl _, le_refl _⟩) h + +theorem insert'_eq_insert_with [@decidable_rel α (≤)] + (x : α) : ∀ t, insert' x t = insert_with id x t +| nil := rfl +| (node _ l y r) := by unfold insert' insert_with; + cases cmp_le x y; unfold insert' insert_with; simp [insert'_eq_insert_with] + +theorem insert'.valid [is_total α (≤)] [@decidable_rel α (≤)] + (x : α) {t} (h : valid t) : valid (insert' x t) := +by rw insert'_eq_insert_with; exact insert_with.valid _ _ (λ _, id) h + +end + +end ordnode + +/-- An `ordset α` is a finite set of values, represented as a tree. The operations on this type +maintain that the tree is balanced and correctly stores subtree sizes at each level. The +correctness property of the tree is baked into the type, so all operations on this type are correct +by construction. -/ +def ordset (α : Type*) [preorder α] := {t : ordnode α // t.valid} + +namespace ordset +open ordnode +variable [preorder α] + +/-- O(1). The empty set. -/ +def nil : ordset α := ⟨nil, ⟨⟩, ⟨⟩, ⟨⟩⟩ + +/-- O(1). Get the size of the set. -/ +def size (s : ordset α) : ℕ := s.1.size + +/-- O(1). Construct a singleton set containing value `a`. -/ +protected def singleton (a : α) : ordset α := ⟨singleton a, valid_singleton⟩ + +instance : has_emptyc (ordset α) := ⟨nil⟩ +instance : inhabited (ordset α) := ⟨nil⟩ +instance : has_singleton α (ordset α) := ⟨ordset.singleton⟩ + +/-- O(1). Is the set empty? -/ +def empty (s : ordset α) : Prop := s = ∅ + +theorem empty_iff {s : ordset α} : s = ∅ ↔ s.1.empty := +⟨λ h, by cases h; exact rfl, + λ h, by cases s; cases s_val; [exact rfl, cases h]⟩ + +instance : decidable_pred (@empty α _) := +λ s, decidable_of_iff' _ empty_iff + +/-- O(log n). Insert an element into the set, preserving balance and the BST property. + If an equivalent element is already in the set, this replaces it. -/ +protected def insert [is_total α (≤)] [@decidable_rel α (≤)] (x : α) (s : ordset α) : ordset α := +⟨ordnode.insert x s.1, insert.valid _ s.2⟩ + +instance [is_total α (≤)] [@decidable_rel α (≤)] : has_insert α (ordset α) := ⟨ordset.insert⟩ + +/-- O(log n). Insert an element into the set, preserving balance and the BST property. + If an equivalent element is already in the set, the set is returned as is. -/ +def insert' [is_total α (≤)] [@decidable_rel α (≤)] (x : α) (s : ordset α) : ordset α := +⟨insert' x s.1, insert'.valid _ s.2⟩ + +end ordset diff --git a/src/order/basic.lean b/src/order/basic.lean index aeda3093a9150..379a2c9a066c1 100644 --- a/src/order/basic.lean +++ b/src/order/basic.lean @@ -189,6 +189,21 @@ instance (α : Type*) [linear_order α] : linear_order (order_dual α) := instance : Π [inhabited α], inhabited (order_dual α) := id +theorem preorder.dual_dual (α : Type*) [H : preorder α] : + order_dual.preorder (order_dual α) = H := +preorder.ext $ λ _ _, iff.rfl + +theorem partial_order.dual_dual (α : Type*) [H : partial_order α] : + order_dual.partial_order (order_dual α) = H := +partial_order.ext $ λ _ _, iff.rfl + +theorem linear_order.dual_dual (α : Type*) [H : linear_order α] : + order_dual.linear_order (order_dual α) = H := +linear_order.ext $ λ _ _, iff.rfl + +theorem cmp_le_flip {α} [has_le α] [@decidable_rel α (≤)] (x y : α) : + @cmp_le (order_dual α) _ _ x y = cmp_le y x := rfl + end order_dual namespace strict_mono_incr_on