Skip to content

Commit

Permalink
refactor: add configuration options to control WHNF
Browse files Browse the repository at this point in the history
This commit also removes parameter `simpleReduce` from discrimination
trees, and take WHNF configuration options.
Reason: it is more dynamic now. For example, the simplifier
will be able to use different configurations for discrimination tree insertion
and retrieval. We need this feature to address issues #2669 and #2281

This commit also removes the dead Meta.Config field `zetaNonDep`.
  • Loading branch information
leodemoura committed Oct 25, 2023
1 parent aecc83e commit 3a13200
Show file tree
Hide file tree
Showing 13 changed files with 225 additions and 180 deletions.
4 changes: 2 additions & 2 deletions src/Lean/Meta/ACLt.lean
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ where
-- Drawback: cost.
return e
else match mode with
| .reduce => DiscrTree.reduce e (simpleReduce := false)
| .reduceSimpleOnly => DiscrTree.reduce e (simpleReduce := true)
| .reduce => DiscrTree.reduce e {}
| .reduceSimpleOnly => DiscrTree.reduce e { iota := false, proj := .no }
| .none => return e

lt (a b : Expr) : MetaM Bool := do
Expand Down
2 changes: 0 additions & 2 deletions src/Lean/Meta/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ structure Config where
Controls which definitions and theorems can be unfolded by `isDefEq` and `whnf`.
-/
transparency : TransparencyMode := TransparencyMode.default
/-- If zetaNonDep == false, then non dependent let-decls are not zeta expanded. -/
zetaNonDep : Bool := true
/--
When `trackZeta = true`, we track all free variables that have been zeta-expanded.
That is, suppose the local context contains
Expand Down
144 changes: 72 additions & 72 deletions src/Lean/Meta/DiscrTree.lean

Large diffs are not rendered by default.

61 changes: 30 additions & 31 deletions src/Lean/Meta/DiscrTreeTypes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,55 +13,50 @@ namespace DiscrTree
/--
Discrimination tree key. See `DiscrTree`
-/
inductive Key (simpleReduce : Bool) where
| const : Name → Nat → Key simpleReduce
| fvar : FVarId → Nat → Key simpleReduce
| lit : Literal → Key simpleReduce
| star : Key simpleReduce
| other : Key simpleReduce
| arrow : Key simpleReduce
| proj : Name → Nat → Nat → Key simpleReduce
inductive Key where
| const : Name → Nat → Key
| fvar : FVarId → Nat → Key
| lit : Literal → Key
| star : Key
| other : Key
| arrow : Key
| proj : Name → Nat → Nat → Key
deriving Inhabited, BEq, Repr

protected def Key.hash : Key s → UInt64
| Key.const n a => mixHash 5237 $ mixHash (hash n) (hash a)
| Key.fvar n a => mixHash 3541 $ mixHash (hash n) (hash a)
| Key.lit v => mixHash 1879 $ hash v
| Key.star => 7883
| Key.other => 2411
| Key.arrow => 17
| Key.proj s i a => mixHash (hash a) $ mixHash (hash s) (hash i)
protected def Key.hash : Key → UInt64
| .const n a => mixHash 5237 $ mixHash (hash n) (hash a)
| .fvar n a => mixHash 3541 $ mixHash (hash n) (hash a)
| .lit v => mixHash 1879 $ hash v
| .star => 7883
| .other => 2411
| .arrow => 17
| .proj s i a => mixHash (hash a) $ mixHash (hash s) (hash i)

instance : Hashable (Key s) := ⟨Key.hash⟩
instance : Hashable Key := ⟨Key.hash⟩

/--
Discrimination tree trie. See `DiscrTree`.
-/
inductive Trie (α : Type) (simpleReduce : Bool) where
| node (vs : Array α) (children : Array (Key simpleReduce × Trie α simpleReduce)) : Trie α simpleReduce
inductive Trie (α : Type) where
| node (vs : Array α) (children : Array (Key × Trie α)) : Trie α

end DiscrTree

open DiscrTree

/--
Discrimination trees. It is an index from terms to values of type `α`.
If `simpleReduce := true`, then only simple reduction are performed while
indexing/retrieving terms. For example, `iota` reduction is not performed.
We use `simpleReduce := false` in the type class resolution module,
and `simpleReduce := true` in `simp`.
/-!
Notes regarding term reduction at the `DiscrTree` module.
Motivations:
- In `simp`, we want to have `simp` theorem such as
```
@[simp] theorem liftOn_mk (a : α) (f : α → γ) (h : ∀ a₁ a₂, r a₁ a₂ → f a₁ = f a₂) :
Quot.liftOn (Quot.mk r a) f h = f a := rfl
```
If we enable `iota`, then the lhs is reduced to `f a`.
Note that when retrieving terms, we may also disable `beta` and `zeta` reduction.
See issue https://github.com/leanprover/lean4/issues/2669
- During type class resolution, we often want to reduce types using even `iota`.
- During type class resolution, we often want to reduce types using even `iota` and projection reductionn.
Example:
```
inductive Ty where
Expand All @@ -80,7 +75,11 @@ def f (a b : Ty.bool.interp) : Ty.bool.interp :=
test (.==.) a b
```
-/
structure DiscrTree (α : Type) (simpleReduce : Bool) where
root : PersistentHashMap (Key simpleReduce) (Trie α simpleReduce) := {}

/--
Discrimination trees. It is an index from terms to values of type `α`.
-/
structure DiscrTree (α : Type) where
root : PersistentHashMap Key (Trie α) := {}

end Lean.Meta
15 changes: 9 additions & 6 deletions src/Lean/Meta/Instances.lean
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def f (a b : Ty.bool.interp) : Ty.bool.interp :=
See comment at `DiscrTree`.
-/

abbrev InstanceKey := DiscrTree.Key (simpleReduce := false)
abbrev InstanceKey := DiscrTree.Key

structure InstanceEntry where
keys : Array InstanceKey
Expand All @@ -63,18 +63,21 @@ instance : ToFormat InstanceEntry where
| some n => format n
| _ => "<local>"

abbrev InstanceTree := DiscrTree InstanceEntry (simpleReduce := false)
abbrev InstanceTree := DiscrTree InstanceEntry

structure Instances where
discrTree : InstanceTree := DiscrTree.empty
instanceNames : PHashMap Name InstanceEntry := {}
erased : PHashSet Name := {}
deriving Inhabited

/-- Configuration for the discrimination tree module -/
def tcDtConfig : WhnfCoreConfig := {}

def addInstanceEntry (d : Instances) (e : InstanceEntry) : Instances :=
match e.globalName? with
| some n => { d with discrTree := d.discrTree.insertCore e.keys e, instanceNames := d.instanceNames.insert n e, erased := d.erased.erase n }
| none => { d with discrTree := d.discrTree.insertCore e.keys e }
| some n => { d with discrTree := d.discrTree.insertCore e.keys e tcDtConfig, instanceNames := d.instanceNames.insert n e, erased := d.erased.erase n }
| none => { d with discrTree := d.discrTree.insertCore e.keys e tcDtConfig }

def Instances.eraseCore (d : Instances) (declName : Name) : Instances :=
{ d with erased := d.erased.insert declName, instanceNames := d.instanceNames.erase declName }
Expand All @@ -94,7 +97,7 @@ private def mkInstanceKey (e : Expr) : MetaM (Array InstanceKey) := do
let type ← inferType e
withNewMCtxDepth do
let (_, _, type) ← forallMetaTelescopeReducing type
DiscrTree.mkPath type
DiscrTree.mkPath type tcDtConfig

/--
Compute the order the arguments of `inst` should by synthesized.
Expand Down Expand Up @@ -207,7 +210,7 @@ builtin_initialize
modifyEnv fun env => instanceExtension.modifyState env fun _ => s
}

def getGlobalInstancesIndex : CoreM (DiscrTree InstanceEntry (simpleReduce := false)) :=
def getGlobalInstancesIndex : CoreM (DiscrTree InstanceEntry) :=
return Meta.instanceExtension.getState (← getEnv) |>.discrTree

def getErasedInstances : CoreM (PHashSet Name) :=
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/SynthInstance.lean
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def getInstances (type : Expr) : MetaM (Array Instance) := do
| none => throwError "type class instance expected{indentExpr type}"
| some className =>
let globalInstances ← getGlobalInstancesIndex
let result ← globalInstances.getUnify type
let result ← globalInstances.getUnify type tcDtConfig
-- Using insertion sort because it is stable and the array `result` should be mostly sorted.
-- Most instances have default priority.
let result := result.insertionSort fun e₁ e₂ => e₁.priority < e₂.priority
Expand Down
2 changes: 1 addition & 1 deletion src/Lean/Meta/Tactic/Simp/Rewrite.lean
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def tryTheorem? (e : Expr) (thm : SimpTheorem) (discharge? : Expr → SimpM (Opt
Remark: the parameter tag is used for creating trace messages. It is irrelevant otherwise.
-/
def rewrite? (e : Expr) (s : SimpTheoremTree) (erased : PHashSet Origin) (discharge? : Expr → SimpM (Option Expr)) (tag : String) (rflOnly : Bool) : SimpM (Option Result) := do
let candidates ← s.getMatchWithExtra e
let candidates ← s.getMatchWithExtra e simpDtConfig
if candidates.isEmpty then
trace[Debug.Meta.Tactic.simp] "no theorems found for {tag}-rewriting {e}"
return none
Expand Down
15 changes: 9 additions & 6 deletions src/Lean/Meta/Tactic/Simp/SimpTheorems.lean
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ If we use `iota`, then the lhs is reduced to `f a`.
See comment at `DiscrTree`.
-/

abbrev SimpTheoremKey := DiscrTree.Key (simpleReduce := true)
abbrev SimpTheoremKey := DiscrTree.Key

/--
The fields `levelParams` and `proof` are used to encode the proof of the simp theorem.
Expand Down Expand Up @@ -151,7 +151,7 @@ def ppSimpTheorem [Monad m] [MonadLiftT IO m] [MonadEnv m] [MonadError m] (s : S
instance : BEq SimpTheorem where
beq e₁ e₂ := e₁.proof == e₂.proof

abbrev SimpTheoremTree := DiscrTree SimpTheorem (simpleReduce := true)
abbrev SimpTheoremTree := DiscrTree SimpTheorem

structure SimpTheorems where
pre : SimpTheoremTree := DiscrTree.empty
Expand All @@ -162,11 +162,14 @@ structure SimpTheorems where
toUnfoldThms : PHashMap Name (Array Name) := {}
deriving Inhabited

/-- Configuration for the discrimination tree. -/
def simpDtConfig : WhnfCoreConfig := { iota := false, proj := .no }

def addSimpTheoremEntry (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems :=
if e.post then
{ d with post := d.post.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames }
{ d with post := d.post.insertCore e.keys e simpDtConfig, lemmaNames := updateLemmaNames d.lemmaNames }
else
{ d with pre := d.pre.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames }
{ d with pre := d.pre.insertCore e.keys e simpDtConfig, lemmaNames := updateLemmaNames d.lemmaNames }
where
updateLemmaNames (s : PHashSet Origin) : PHashSet Origin :=
s.insert e.origin
Expand Down Expand Up @@ -218,7 +221,7 @@ private partial def isPerm : Expr → Expr → MetaM Bool
| s, t => return s == t

private def checkBadRewrite (lhs rhs : Expr) : MetaM Unit := do
let lhs ← DiscrTree.reduceDT lhs (root := true) (simpleReduce := true)
let lhs ← DiscrTree.reduceDT lhs (root := true) simpDtConfig
if lhs == rhs && lhs.isFVar then
throwError "invalid `simp` theorem, equation is equivalent to{indentExpr (← mkEq lhs rhs)}"

Expand Down Expand Up @@ -305,7 +308,7 @@ private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array
let type ← whnfR type
let (keys, perm) ←
match type.eq? with
| some (_, lhs, rhs) => pure (← DiscrTree.mkPath lhs, ← isPerm lhs rhs)
| some (_, lhs, rhs) => pure (← DiscrTree.mkPath lhs simpDtConfig, ← isPerm lhs rhs)
| none => throwError "unexpected kind of 'simp' theorem{indentExpr type}"
return { origin, keys, perm, post, levelParams, proof, priority := prio, rfl := (← isRflProof proof) }

Expand Down
12 changes: 7 additions & 5 deletions src/Lean/Meta/UnificationHint.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ import Lean.Meta.SynthInstance

namespace Lean.Meta

abbrev UnificationHintKey := DiscrTree.Key (simpleReduce := true)
abbrev UnificationHintKey := DiscrTree.Key

structure UnificationHintEntry where
keys : Array UnificationHintKey
val : Name
deriving Inhabited

abbrev UnificationHintTree := DiscrTree Name (simpleReduce := true)
abbrev UnificationHintTree := DiscrTree Name

structure UnificationHints where
discrTree : UnificationHintTree := DiscrTree.empty
Expand All @@ -26,8 +26,10 @@ structure UnificationHints where
instance : ToFormat UnificationHints where
format h := format h.discrTree

def UnificationHints.config : WhnfCoreConfig := {}

def UnificationHints.add (hints : UnificationHints) (e : UnificationHintEntry) : UnificationHints :=
{ hints with discrTree := hints.discrTree.insertCore e.keys e.val }
{ hints with discrTree := hints.discrTree.insertCore e.keys e.val config }

builtin_initialize unificationHintExtension : SimpleScopedEnvExtension UnificationHintEntry UnificationHints ←
registerSimpleScopedEnvExtension {
Expand Down Expand Up @@ -78,7 +80,7 @@ def addUnificationHint (declName : Name) (kind : AttributeKind) : MetaM Unit :=
match decodeUnificationHint body with
| Except.error msg => throwError msg
| Except.ok hint =>
let keys ← DiscrTree.mkPath hint.pattern.lhs
let keys ← DiscrTree.mkPath hint.pattern.lhs UnificationHints.config
validateHint hint
unificationHintExtension.add { keys := keys, val := declName } kind

Expand All @@ -98,7 +100,7 @@ def tryUnificationHints (t s : Expr) : MetaM Bool := do
if t.isMVar then
return false
let hints := unificationHintExtension.getState (← getEnv)
let candidates ← hints.discrTree.getMatch t
let candidates ← hints.discrTree.getMatch t UnificationHints.config
for candidate in candidates do
if (← tryCandidate candidate) then
return true
Expand Down
Loading

0 comments on commit 3a13200

Please sign in to comment.