Skip to content

Commit

Permalink
chore: extend GetElem with getElem! and getElem? (#3694)
Browse files Browse the repository at this point in the history
This makes changes to the `GetElem` class so that it does not lead to
unnecessary overhead in container like `RBMap`.

The changes are to:
1. Make `getElem?` and `getElem!` part of the `GetElem` class so they
can be overridden in instances.
2. Introduce a `LawfulGetElem` class that contains correctness theorems
for `getElem?` and `getElem!` using the original definitions.
3. Reorganize definitions (e.g, by moving `GetElem` out of
`Init.Prelude`) so that the `GetElem` changes are feasible.
4. Provide `LawfulGetElem` instances to complement all existing
`GetElem` instances in Lean core.

To reduce the size of the PR, this doesn't do the work of providing new
`GetElem` instances for `RBMap`, `HashMap` etc. That will be done in a
separate PR (#3688) that depends on this.

---------

Co-authored-by: Mac Malone <tydeu@hatpress.net>
  • Loading branch information
joehendrix and tydeu committed Mar 28, 2024
1 parent 7989f62 commit 0963f34
Show file tree
Hide file tree
Showing 15 changed files with 203 additions and 97 deletions.
4 changes: 3 additions & 1 deletion src/Init/Data/Array/Basic.lean
Expand Up @@ -10,7 +10,7 @@ import Init.Data.Fin.Basic
import Init.Data.UInt.Basic
import Init.Data.Repr
import Init.Data.ToString.Basic
import Init.Util
import Init.GetElem
universe u v w

namespace Array
Expand Down Expand Up @@ -59,6 +59,8 @@ def uget (a : @& Array α) (i : USize) (h : i.toNat < a.size) : α :=
instance : GetElem (Array α) USize α fun xs i => i.toNat < xs.size where
getElem xs i h := xs.uget i h

instance : LawfulGetElem (Array α) USize α fun xs i => i.toNat < xs.size where

def back [Inhabited α] (a : Array α) : α :=
a.get! (a.size - 1)

Expand Down
2 changes: 2 additions & 0 deletions src/Init/Data/Array/Subarray.lean
Expand Up @@ -32,6 +32,8 @@ def get (s : Subarray α) (i : Fin s.size) : α :=
instance : GetElem (Subarray α) Nat α fun xs i => i < xs.size where
getElem xs i h := xs.get ⟨i, h⟩

instance : LawfulGetElem (Subarray α) Nat α fun xs i => i < xs.size where

@[inline] def getD (s : Subarray α) (i : Nat) (v₀ : α) : α :=
if h : i < s.size then s.get ⟨i, h⟩ else v₀

Expand Down
4 changes: 4 additions & 0 deletions src/Init/Data/ByteArray/Basic.lean
Expand Up @@ -52,9 +52,13 @@ def get : (a : @& ByteArray) → (@& Fin a.size) → UInt8
instance : GetElem ByteArray Nat UInt8 fun xs i => i < xs.size where
getElem xs i h := xs.get ⟨i, h⟩

instance : LawfulGetElem ByteArray Nat UInt8 fun xs i => i < xs.size where

instance : GetElem ByteArray USize UInt8 fun xs i => i.val < xs.size where
getElem xs i h := xs.uget i h

instance : LawfulGetElem ByteArray USize UInt8 fun xs i => i.val < xs.size where

@[extern "lean_byte_array_set"]
def set! : ByteArray → (@& Nat) → UInt8 → ByteArray
| ⟨bs⟩, i, b => ⟨bs.set! i b⟩
Expand Down
8 changes: 0 additions & 8 deletions src/Init/Data/Fin/Basic.lean
Expand Up @@ -4,9 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura, Robert Y. Lewis, Keeley Hoek, Mario Carneiro
-/
prelude
import Init.Data.Nat.Div
import Init.Data.Nat.Bitwise.Basic
import Init.Coe

open Nat

Expand Down Expand Up @@ -170,9 +168,3 @@ theorem val_add_one_le_of_lt {n : Nat} {a b : Fin n} (h : a < b) : (a : Nat) + 1
theorem val_add_one_le_of_gt {n : Nat} {a b : Fin n} (h : a > b) : (b : Nat) + 1 ≤ (a : Nat) := h

end Fin

instance [GetElem cont Nat elem dom] : GetElem cont (Fin n) elem fun xs i => dom xs i where
getElem xs i h := getElem xs i.1 h

macro_rules
| `(tactic| get_elem_tactic_trivial) => `(tactic| apply Fin.val_lt_of_le; get_elem_tactic_trivial; done)
4 changes: 4 additions & 0 deletions src/Init/Data/FloatArray/Basic.lean
Expand Up @@ -58,9 +58,13 @@ def get? (ds : FloatArray) (i : Nat) : Option Float :=
instance : GetElem FloatArray Nat Float fun xs i => i < xs.size where
getElem xs i h := xs.get ⟨i, h⟩

instance : LawfulGetElem FloatArray Nat Float fun xs i => i < xs.size where

instance : GetElem FloatArray USize Float fun xs i => i.val < xs.size where
getElem xs i h := xs.uget i h

instance : LawfulGetElem FloatArray USize Float fun xs i => i.val < xs.size where

@[extern "lean_float_array_uset"]
def uset : (a : FloatArray) → (i : USize) → Float → i.toNat < a.size → FloatArray
| ⟨ds⟩, i, v, h => ⟨ds.uset i v h⟩
Expand Down
15 changes: 1 addition & 14 deletions src/Init/Data/List/Basic.lean
Expand Up @@ -7,6 +7,7 @@ prelude
import Init.SimpLemmas
import Init.Data.Nat.Basic
import Init.Data.Nat.Div

set_option linter.missingDocs true -- keep it documented
open Decidable List

Expand Down Expand Up @@ -54,15 +55,6 @@ variable {α : Type u} {β : Type v} {γ : Type w}

namespace List

instance : GetElem (List α) Nat α fun as i => i < as.length where
getElem as i h := as.get ⟨i, h⟩

@[simp] theorem cons_getElem_zero (a : α) (as : List α) (h : 0 < (a :: as).length) : getElem (a :: as) 0 h = a := by
rfl

@[simp] theorem cons_getElem_succ (a : α) (as : List α) (i : Nat) (h : i + 1 < (a :: as).length) : getElem (a :: as) (i+1) h = getElem as i (Nat.lt_of_succ_lt_succ h) := by
rfl

theorem length_add_eq_lengthTRAux (as : List α) (n : Nat) : as.length + n = as.lengthTRAux n := by
induction as generalizing n with
| nil => simp [length, lengthTRAux]
Expand Down Expand Up @@ -520,11 +512,6 @@ def drop : Nat → List α → List α
@[simp] theorem drop_nil : ([] : List α).drop i = [] := by
cases i <;> rfl

theorem get_drop_eq_drop (as : List α) (i : Nat) (h : i < as.length) : as[i] :: as.drop (i+1) = as.drop i :=
match as, i with
| _::_, 0 => rfl
| _::_, i+1 => get_drop_eq_drop _ i _

/--
`O(min n |xs|)`. Returns the first `n` elements of `xs`, or the whole list if `n` is too large.
* `take 0 [a, b, c, d, e] = []`
Expand Down
173 changes: 173 additions & 0 deletions src/Init/GetElem.lean
@@ -0,0 +1,173 @@
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura, Mario Carneiro
-/
prelude
import Init.Util

@[never_extract]
private def outOfBounds [Inhabited α] : α :=
panic! "index out of bounds"

/--
The class `GetElem coll idx elem valid` implements the `xs[i]` notation.
Given `xs[i]` with `xs : coll` and `i : idx`, Lean looks for an instance of
`GetElem coll idx elem valid` and uses this to infer the type of return
value `elem` and side conditions `valid` required to ensure `xs[i]` yields
a valid value of type `elem`.
For example, the instance for arrays looks like
`GetElem (Array α) Nat α (fun xs i => i < xs.size)`.
The proof side-condition `valid xs i` is automatically dispatched by the
`get_elem_tactic` tactic, which can be extended by adding more clauses to
`get_elem_tactic_trivial`.
-/
class GetElem (coll : Type u) (idx : Type v) (elem : outParam (Type w))
(valid : outParam (coll → idx → Prop)) where
/--
The syntax `arr[i]` gets the `i`'th element of the collection `arr`. If there
are proof side conditions to the application, they will be automatically
inferred by the `get_elem_tactic` tactic.
The actual behavior of this class is type-dependent, but here are some
important implementations:
* `arr[i] : α` where `arr : Array α` and `i : Nat` or `i : USize`: does array
indexing with no bounds check and a proof side goal `i < arr.size`.
* `l[i] : α` where `l : List α` and `i : Nat`: index into a list, with proof
side goal `i < l.length`.
* `stx[i] : Syntax` where `stx : Syntax` and `i : Nat`: get a syntax argument,
no side goal (returns `.missing` out of range)
There are other variations on this syntax:
* `arr[i]!` is syntax for `getElem! arr i` which should panic and return
`default : α` if the index is not valid.
* `arr[i]?` is syntax for `getElem?` which should return `none` if the index
is not valid.
* `arr[i]'h` is syntax for `getElem arr i h` with `h` an explicit proof the
index is valid.
-/
getElem (xs : coll) (i : idx) (h : valid xs i) : elem

getElem? (xs : coll) (i : idx) [Decidable (valid xs i)] : Option elem :=
if h : _ then some (getElem xs i h) else none

getElem! [Inhabited elem] (xs : coll) (i : idx) [Decidable (valid xs i)] : elem :=
match getElem? xs i with | some e => e | none => outOfBounds

export GetElem (getElem getElem! getElem?)

@[inherit_doc getElem]
syntax:max term noWs "[" withoutPosition(term) "]" : term
macro_rules | `($x[$i]) => `(getElem $x $i (by get_elem_tactic))

@[inherit_doc getElem]
syntax term noWs "[" withoutPosition(term) "]'" term:max : term
macro_rules | `($x[$i]'$h) => `(getElem $x $i $h)

/--
The syntax `arr[i]?` gets the `i`'th element of the collection `arr` or
returns `none` if `i` is out of bounds.
-/
macro:max x:term noWs "[" i:term "]" noWs "?" : term => `(getElem? $x $i)

/--
The syntax `arr[i]!` gets the `i`'th element of the collection `arr` and
panics `i` is out of bounds.
-/
macro:max x:term noWs "[" i:term "]" noWs "!" : term => `(getElem! $x $i)

class LawfulGetElem (cont : Type u) (idx : Type v) (elem : outParam (Type w))
(dom : outParam (cont → idx → Prop)) [ge : GetElem cont idx elem dom] : Prop where

getElem?_def (c : cont) (i : idx) [Decidable (dom c i)] :
c[i]? = if h : dom c i then some (c[i]'h) else none := by intros; eq_refl
getElem!_def [Inhabited elem] (c : cont) (i : idx) [Decidable (dom c i)] :
c[i]! = match c[i]? with | some e => e | none => default := by intros; eq_refl

export LawfulGetElem (getElem?_def getElem!_def)

theorem getElem?_pos [GetElem cont idx elem dom] [LawfulGetElem cont idx elem dom]
(c : cont) (i : idx) (h : dom c i) [Decidable (dom c i)] : c[i]? = some (c[i]'h) := by
rw [getElem?_def]
exact dif_pos h

theorem getElem?_neg [GetElem cont idx elem dom] [LawfulGetElem cont idx elem dom]
(c : cont) (i : idx) (h : ¬dom c i) [Decidable (dom c i)] : c[i]? = none := by
rw [getElem?_def]
exact dif_neg h

theorem getElem!_pos [GetElem cont idx elem dom] [LawfulGetElem cont idx elem dom]
[Inhabited elem] (c : cont) (i : idx) (h : dom c i) [Decidable (dom c i)] :
c[i]! = c[i]'h := by
simp only [getElem!_def, getElem?_def, h]

theorem getElem!_neg [GetElem cont idx elem dom] [LawfulGetElem cont idx elem dom]
[Inhabited elem] (c : cont) (i : idx) (h : ¬dom c i) [Decidable (dom c i)] : c[i]! = default := by
simp only [getElem!_def, getElem?_def, h]

namespace Fin

instance instGetElemFinVal [GetElem cont Nat elem dom] : GetElem cont (Fin n) elem fun xs i => dom xs i where
getElem xs i h := getElem xs i.1 h
getElem? xs i := getElem? xs i.val
getElem! xs i := getElem! xs i.val

instance [GetElem cont Nat elem dom] [h : LawfulGetElem cont Nat elem dom] :
LawfulGetElem cont (Fin n) elem fun xs i => dom xs i where

getElem?_def _c _i _d := h.getElem?_def ..
getElem!_def _c _i _d := h.getElem!_def ..

@[simp] theorem getElem_fin [GetElem Cont Nat Elem Dom] (a : Cont) (i : Fin n) (h : Dom a i) :
a[i] = a[i.1] := rfl

@[simp] theorem getElem?_fin [h : GetElem Cont Nat Elem Dom] (a : Cont) (i : Fin n)
[Decidable (Dom a i)] : a[i]? = a[i.1]? := by rfl

@[simp] theorem getElem!_fin [GetElem Cont Nat Elem Dom] (a : Cont) (i : Fin n)
[Decidable (Dom a i)] [Inhabited Elem] : a[i]! = a[i.1]! := rfl

macro_rules
| `(tactic| get_elem_tactic_trivial) => `(tactic| apply Fin.val_lt_of_le; get_elem_tactic_trivial; done)

end Fin

namespace List

instance : GetElem (List α) Nat α fun as i => i < as.length where
getElem as i h := as.get ⟨i, h⟩

instance : LawfulGetElem (List α) Nat α fun as i => i < as.length where

@[simp] theorem cons_getElem_zero (a : α) (as : List α) (h : 0 < (a :: as).length) : getElem (a :: as) 0 h = a := by
rfl

@[simp] theorem cons_getElem_succ (a : α) (as : List α) (i : Nat) (h : i + 1 < (a :: as).length) : getElem (a :: as) (i+1) h = getElem as i (Nat.lt_of_succ_lt_succ h) := by
rfl

theorem get_drop_eq_drop (as : List α) (i : Nat) (h : i < as.length) : as[i] :: as.drop (i+1) = as.drop i :=
match as, i with
| _::_, 0 => rfl
| _::_, i+1 => get_drop_eq_drop _ i _

end List

namespace Array

instance : GetElem (Array α) Nat α fun xs i => i < xs.size where
getElem xs i h := xs.get ⟨i, h⟩

instance : LawfulGetElem (Array α) Nat α fun xs i => i < xs.size where

end Array

namespace Lean.Syntax

instance : GetElem Syntax Nat Syntax fun _ _ => True where
getElem stx i _ := stx.getArg i

instance : LawfulGetElem Syntax Nat Syntax fun _ _ => True where

end Lean.Syntax
8 changes: 0 additions & 8 deletions src/Init/Meta.lean
Expand Up @@ -1194,14 +1194,6 @@ instance : Coe (Lean.Term) (Lean.TSyntax `Lean.Parser.Term.funBinder) where

end Lean.Syntax

set_option linter.unusedVariables.funArgs false in
/--
Gadget for automatic parameter support. This is similar to the `optParam` gadget, but it uses
the given tactic.
Like `optParam`, this gadget only affects elaboration.
For example, the tactic will *not* be invoked during type class resolution. -/
abbrev autoParam.{u} (α : Sort u) (tactic : Lean.Syntax) : Sort u := α

/-! # Helper functions for manipulating interpolated strings -/

namespace Lean.Syntax
Expand Down
43 changes: 0 additions & 43 deletions src/Init/Prelude.lean
Expand Up @@ -2543,43 +2543,6 @@ def panic {α : Type u} [Inhabited α] (msg : String) : α :=
-- TODO: this be applied directly to `Inhabited`'s definition when we remove the above workaround
attribute [nospecialize] Inhabited

/--
The class `GetElem cont idx elem dom` implements the `xs[i]` notation.
When you write this, given `xs : cont` and `i : idx`, Lean looks for an instance
of `GetElem cont idx elem dom`. Here `elem` is the type of `xs[i]`, while
`dom` is whatever proof side conditions are required to make this applicable.
For example, the instance for arrays looks like
`GetElem (Array α) Nat α (fun xs i => i < xs.size)`.
The proof side-condition `dom xs i` is automatically dispatched by the
`get_elem_tactic` tactic, which can be extended by adding more clauses to
`get_elem_tactic_trivial`.
-/
class GetElem (cont : Type u) (idx : Type v) (elem : outParam (Type w)) (dom : outParam (cont → idx → Prop)) where
/--
The syntax `arr[i]` gets the `i`'th element of the collection `arr`.
If there are proof side conditions to the application, they will be automatically
inferred by the `get_elem_tactic` tactic.
The actual behavior of this class is type-dependent,
but here are some important implementations:
* `arr[i] : α` where `arr : Array α` and `i : Nat` or `i : USize`:
does array indexing with no bounds check and a proof side goal `i < arr.size`.
* `l[i] : α` where `l : List α` and `i : Nat`: index into a list,
with proof side goal `i < l.length`.
* `stx[i] : Syntax` where `stx : Syntax` and `i : Nat`: get a syntax argument,
no side goal (returns `.missing` out of range)
There are other variations on this syntax:
* `arr[i]`: proves the proof side goal by `get_elem_tactic`
* `arr[i]!`: panics if the side goal is false
* `arr[i]?`: returns `none` if the side goal is false
* `arr[i]'h`: uses `h` to prove the side goal
-/
getElem (xs : cont) (i : idx) (h : dom xs i) : elem

export GetElem (getElem)

/--
`Array α` is the type of [dynamic arrays](https://en.wikipedia.org/wiki/Dynamic_array)
with elements from `α`. This type has special support in the runtime.
Expand Down Expand Up @@ -2637,9 +2600,6 @@ def Array.get {α : Type u} (a : @& Array α) (i : @& Fin a.size) : α :=
def Array.get! {α : Type u} [Inhabited α] (a : @& Array α) (i : @& Nat) : α :=
Array.getD a i default

instance : GetElem (Array α) Nat α fun xs i => LT.lt i xs.size where
getElem xs i h := xs.get ⟨i, h⟩

/--
Push an element onto the end of an array. This is amortized O(1) because
`Array α` is internally a dynamic array.
Expand Down Expand Up @@ -3907,9 +3867,6 @@ def getArg (stx : Syntax) (i : Nat) : Syntax :=
| Syntax.node _ _ args => args.getD i Syntax.missing
| _ => Syntax.missing

instance : GetElem Syntax Nat Syntax fun _ _ => True where
getElem stx i _ := stx.getArg i

/-- Gets the list of arguments of the syntax node, or `#[]` if it's not a `node`. -/
def getArgs (stx : Syntax) : Array Syntax :=
match stx with
Expand Down
16 changes: 8 additions & 8 deletions src/Init/Tactics.lean
Expand Up @@ -1522,16 +1522,16 @@ macro "get_elem_tactic" : tactic =>
- Use `a[i]'h` notation instead, where `h` is a proof that index is valid"
)

@[inherit_doc getElem]
syntax:max term noWs "[" withoutPosition(term) "]" : term
macro_rules | `($x[$i]) => `(getElem $x $i (by get_elem_tactic))

@[inherit_doc getElem]
syntax term noWs "[" withoutPosition(term) "]'" term:max : term
macro_rules | `($x[$i]'$h) => `(getElem $x $i $h)

/--
Searches environment for definitions or theorems that can be substituted in
for `exact?% to solve the goal.
-/
syntax (name := Lean.Parser.Syntax.exact?) "exact?%" : term

set_option linter.unusedVariables.funArgs false in
/--
Gadget for automatic parameter support. This is similar to the `optParam` gadget, but it uses
the given tactic.
Like `optParam`, this gadget only affects elaboration.
For example, the tactic will *not* be invoked during type class resolution. -/
abbrev autoParam.{u} (α : Sort u) (tactic : Lean.Syntax) : Sort u := α

0 comments on commit 0963f34

Please sign in to comment.