Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions src/Lean/Meta/Tactic/Cbv/CbvEvalExt.lean
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ open Lean.Meta.Sym.Simp

/--
Entry of the `CbvEvalExtension`.
Consists of the precomputed `Theorem` object and a name of the head function appearing on the left-hand side of the theorem.
Consists of name of the theorem used, the precomputed `Theorem` object
and a name of the head function appearing on the left-hand side of the theorem.
-/
structure CbvEvalEntry where
appFn : Name
thm : Theorem
origin : Name
appFn : Name
thm : Theorem
deriving BEq, Inhabited

/-- Create a `CbvEvalEntry` from a theorem declaration. When `inv = true`, creates an
Expand All @@ -62,15 +64,38 @@ def mkCbvTheoremFromConst (declName : Name) (inv : Bool := false) : MetaM CbvEva
thmDeclName ← mkAuxLemma (kind? := `_cbv_eval) cinfo.levelParams invType invVal
return (constName, thmDeclName)
let thm ← mkTheoremFromDecl thmDeclName
return ⟨fnName, thm⟩
return ⟨declName, fnName, thm⟩

structure CbvEvalState where
lemmas : NameMap Theorems := {}
lemmas : NameMap Theorems := {}
entries : NameMap <| Array CbvEvalEntry := {}
deriving Inhabited

def CbvEvalState.addEntry (s : CbvEvalState) (e : CbvEvalEntry) : CbvEvalState :=
let existing := (s.lemmas.find? e.appFn).getD {}
{ s with lemmas := s.lemmas.insert e.appFn (existing.insert e.thm) }
let lemmas := (s.lemmas.find? e.appFn).getD {}
let entries := (s.entries.find? e.appFn).getD {}
{ s with
lemmas := s.lemmas.insert e.appFn (lemmas.insert e.thm)
entries := s.entries.insert e.appFn <| entries.push e}

/-- Rebuild the `Theorems` for a given `appFn` from the entries that target it. -/
private def CbvEvalState.rebuildLemmas (entries : NameMap <| Array CbvEvalEntry) (appFn : Name) (lemmas : NameMap Theorems) : NameMap Theorems :=
let appFnEntries := entries.getD appFn #[]
if appFnEntries.isEmpty then
lemmas.erase appFn
else
lemmas.insert appFn (appFnEntries.foldl (fun thms e => thms.insert e.thm) {})

/-- Erase a theorem from the state. Returns `none` if the theorem is not found. -/
def CbvEvalState.erase (s : CbvEvalState) (declName : Name) : Option CbvEvalState := do
let (appFn, oldEntries) ← s.entries.foldl (init := none) fun acc appFn entries =>
if acc.isSome then acc
else if entries.any (·.origin == declName) then some (appFn, entries)
else none
let newEntries := oldEntries.filter (·.origin != declName)
let entries := if newEntries.isEmpty then s.entries.erase appFn
else s.entries.insert appFn newEntries
return { lemmas := rebuildLemmas entries appFn s.lemmas, entries }

abbrev CbvEvalExtension := SimpleScopedEnvExtension CbvEvalEntry CbvEvalState

Expand Down Expand Up @@ -99,6 +124,11 @@ builtin_initialize
let inv := !stx[1].isNone
let (entry, _) ← MetaM.run (mkCbvTheoremFromConst lemmaName (inv := inv)) {}
cbvEvalExt.add entry kind
erase := fun declName => do
let s := cbvEvalExt.getState (← getEnv)
match s.erase declName with
| some s' => modifyEnv fun env => cbvEvalExt.modifyState env fun _ => s'
| none => logWarning m!"`{.ofConstName declName}` does not have the `[cbv_eval]` attribute"
}

end Lean.Meta.Tactic.Cbv
100 changes: 100 additions & 0 deletions tests/elab/cbv_eval_erase.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import Std
set_option cbv.warning false
-- Use opaque constants so that cbv_eval is the ONLY way cbv can reduce them.
-- After erasure, cbv makes no progress and the goal must be closed manually.

opaque myConst : Nat
@[cbv_eval] theorem myConst_eq : myConst = 42 := sorry

-- Before erasure: cbv reduces myConst to 42 via cbv_eval
example : myConst = 42 := by cbv

-- Warning when erasing a theorem that doesn't have [cbv_eval]
theorem not_cbv (n : Nat) : n = n := rfl

/-- warning: `not_cbv` does not have the `[cbv_eval]` attribute -/
#guard_msgs in
attribute [-cbv_eval] not_cbv

-- Basic erasure (no section — permanent)
attribute [-cbv_eval] myConst_eq

-- After erasure: cbv can't reduce myConst, so the goal isn't closed
example : myConst = 42 := by
cbv -- makes no progress
exact myConst_eq

-- Scoping: erasure inside a section is reverted
opaque myConst2 : Nat
@[cbv_eval] theorem myConst2_eq : myConst2 = 100 := sorry

section
attribute [-cbv_eval] myConst2_eq

-- Inside section: cbv can't reduce
/--
trace: ⊢ myConst2 = 100
---
warning: declaration uses `sorry`
-/
#guard_msgs in
example : myConst2 = 100 := by
cbv
trace_state
sorry
end

-- Outside section: cbv_eval is back
example : myConst2 = 100 := by cbv

-- Erasure of inverted theorem
opaque myConst3 : Nat
@[cbv_eval ←] theorem myConst3_eq : 7 = myConst3 := sorry

example : myConst3 = 7 := by cbv

section
attribute [-cbv_eval] myConst3_eq

/--
trace: ⊢ myConst3 = 7
---
warning: declaration uses `sorry`
-/
#guard_msgs in
example : myConst3 = 7 := by
cbv
trace_state
sorry
end

-- Reverted after section
example : myConst3 = 7 := by cbv

-- Erasure with multiple cbv_eval rules: erase only one
opaque myFn : Nat → Nat
@[cbv_eval] theorem myFn_zero : myFn 0 = 1 := sorry
@[cbv_eval] theorem myFn_one : myFn 1 = 0 := sorry

example : myFn 0 = 1 := by cbv
example : myFn 1 = 0 := by cbv

section
attribute [-cbv_eval] myFn_zero

-- myFn_zero is erased, so cbv can't reduce myFn 0

/--
trace: ⊢ myFn 0 = 1
---
warning: declaration uses `sorry`
-/
#guard_msgs in
example : myFn 0 = 1 := by
cbv
trace_state
sorry

-- myFn_one is still active
example : myFn 1 = 0 := by cbv
end
5 changes: 5 additions & 0 deletions tests/elab/cbv_eval_erase.lean.out.expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
cbv_eval_erase.lean:7:20-7:30: warning: declaration uses `sorry`
cbv_eval_erase.lean:29:20-29:31: warning: declaration uses `sorry`
cbv_eval_erase.lean:52:22-52:33: warning: declaration uses `sorry`
cbv_eval_erase.lean:76:20-76:29: warning: declaration uses `sorry`
cbv_eval_erase.lean:77:20-77:28: warning: declaration uses `sorry`
Loading