-
Notifications
You must be signed in to change notification settings - Fork 350
/
SimpTheorems.lean
452 lines (397 loc) · 18.8 KB
/
SimpTheorems.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.ScopedEnvExtension
import Lean.Util.Recognizers
import Lean.Meta.DiscrTree
import Lean.Meta.AppBuilder
import Lean.Meta.Eqns
import Lean.Meta.Tactic.AuxLemma
import Lean.DocString
namespace Lean.Meta
/--
An `Origin` is an identifier for simp theorems which indicates roughly
what action the user took which lead to this theorem existing in the simp set.
-/
inductive Origin where
/-- A global declaration in the environment. -/
| decl (declName : Name) (post := true) (inv := false)
/--
A local hypothesis.
When `contextual := true` is enabled, this fvar may exist in an extension
of the current local context; it will not be used for rewriting by simp once
it is out of scope but it may end up in the `usedSimps` trace.
-/
| fvar (fvarId : FVarId)
/--
A proof term provided directly to a call to `simp [ref, ...]` where `ref`
is the provided simp argument (of kind `Parser.Tactic.simpLemma`).
The `id` is a unique identifier for the call.
-/
| stx (id : Name) (ref : Syntax)
/--
Some other origin. `name` should not collide with the other types
for erasure to work correctly, and simp trace will ignore this lemma.
The other origins should be preferred if possible.
-/
| other (name : Name)
deriving Inhabited, Repr
/-- A unique identifier corresponding to the origin. -/
def Origin.key : Origin → Name
| .decl declName _ _ => declName
| .fvar fvarId => fvarId.name
| .stx id _ => id
| .other name => name
instance : BEq Origin := ⟨(·.key == ·.key)⟩
instance : Hashable Origin := ⟨(hash ·.key)⟩
/-
Note: we want to use iota reduction when indexing instances. Otherwise,
we cannot use simp theorems such as
```
@[simp] theorem liftOn_mk (a : α) (f : α → γ) (h : ∀ a₁ a₂, r a₁ a₂ → f a₁ = f a₂) :
Quot.liftOn (Quot.mk r a) f h = f a := rfl
```
If we use `iota`, then the lhs is reduced to `f a`.
See comment at `DiscrTree`.
-/
abbrev SimpTheoremKey := DiscrTree.Key
/--
The fields `levelParams` and `proof` are used to encode the proof of the simp theorem.
If the `proof` is a global declaration `c`, we store `Expr.const c []` at `proof` without the universe levels, and `levelParams` is set to `#[]`
When using the lemma, we create fresh universe metavariables.
Motivation: most simp theorems are global declarations, and this approach is faster and saves memory.
The field `levelParams` is not empty only when we elaborate an expression provided by the user, and it contains universe metavariables.
Then, we use `abstractMVars` to abstract the universe metavariables and create new fresh universe parameters that are stored at the field `levelParams`.
-/
structure SimpTheorem where
keys : Array SimpTheoremKey := #[]
/--
It stores universe parameter names for universe polymorphic proofs.
Recall that it is non-empty only when we elaborate an expression provided by the user.
When `proof` is just a constant, we can use the universe parameter names stored in the declaration.
-/
levelParams : Array Name := #[]
proof : Expr
priority : Nat := eval_prio default
post : Bool := true
/-- `perm` is true if lhs and rhs are identical modulo permutation of variables. -/
perm : Bool := false
/--
`origin` is mainly relevant for producing trace messages.
It is also viewed an `id` used to "erase" `simp` theorems from `SimpTheorems`.
-/
origin : Origin
/-- `rfl` is true if `proof` is by `Eq.refl` or `rfl`. -/
rfl : Bool
deriving Inhabited
mutual
partial def isRflProofCore (type : Expr) (proof : Expr) : CoreM Bool := do
match type with
| .forallE _ _ type _ =>
if let .lam _ _ proof _ := proof then
isRflProofCore type proof
else
return false
| _ =>
if type.isAppOfArity ``Eq 3 then
if proof.isAppOfArity ``Eq.refl 2 || proof.isAppOfArity ``rfl 2 then
return true
else if proof.isAppOfArity ``Eq.symm 4 then
-- `Eq.symm` of rfl theorem is a rfl theorem
isRflProofCore type proof.appArg! -- small hack: we don't need to set the exact type
else if proof.getAppFn.isConst then
-- The application of a `rfl` theorem is a `rfl` theorem
-- A constant which is a `rfl` theorem is a `rfl` theorem
isRflTheorem proof.getAppFn.constName!
else
return false
else
return false
partial def isRflTheorem (declName : Name) : CoreM Bool := do
let .thmInfo info ← getConstInfo declName | return false
isRflProofCore info.type info.value
end
def isRflProof (proof : Expr) : MetaM Bool := do
if let .const declName .. := proof then
isRflTheorem declName
else
isRflProofCore (← inferType proof) proof
instance : ToFormat SimpTheorem where
format s :=
let perm := if s.perm then ":perm" else ""
let name := format s.origin.key
let prio := f!":{s.priority}"
name ++ prio ++ perm
def ppOrigin [Monad m] [MonadEnv m] [MonadError m] : Origin → m MessageData
| .decl n post inv => do
let r ← mkConstWithLevelParams n;
match post, inv with
| true, true => return m!"← {r}"
| true, false => return r
| false, true => return m!"↓ ← {r}"
| false, false => return m!"↓ {r}"
| .fvar n => return mkFVar n
| .stx _ ref => return ref
| .other n => return n
def ppSimpTheorem [Monad m] [MonadLiftT IO m] [MonadEnv m] [MonadError m] (s : SimpTheorem) : m MessageData := do
let perm := if s.perm then ":perm" else ""
let name ← ppOrigin s.origin
let prio := m!":{s.priority}"
return name ++ prio ++ perm
instance : BEq SimpTheorem where
beq e₁ e₂ := e₁.proof == e₂.proof
abbrev SimpTheoremTree := DiscrTree SimpTheorem
structure SimpTheorems where
pre : SimpTheoremTree := DiscrTree.empty
post : SimpTheoremTree := DiscrTree.empty
lemmaNames : PHashSet Origin := {}
/--
Constants (and let-declaration `FVarId`) to unfold.
When `zetaDelta := false`, the simplifier will expand a let-declaration if it is in this set.
-/
toUnfold : PHashSet Name := {}
erased : PHashSet Origin := {}
toUnfoldThms : PHashMap Name (Array Name) := {}
deriving Inhabited
/-- Configuration for the discrimination tree. -/
def simpDtConfig : WhnfCoreConfig := { iota := false, proj := .no, zetaDelta := false }
def addSimpTheoremEntry (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems :=
if e.post then
{ d with post := d.post.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames }
else
{ d with pre := d.pre.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames }
where
updateLemmaNames (s : PHashSet Origin) : PHashSet Origin :=
s.insert e.origin
def SimpTheorems.addDeclToUnfoldCore (d : SimpTheorems) (declName : Name) : SimpTheorems :=
{ d with toUnfold := d.toUnfold.insert declName }
def SimpTheorems.addLetDeclToUnfold (d : SimpTheorems) (fvarId : FVarId) : SimpTheorems :=
-- A small hack that relies on the fact that constants and `FVarId` names should be disjoint.
{ d with toUnfold := d.toUnfold.insert fvarId.name }
/-- Return `true` if `declName` is tagged to be unfolded using `unfoldDefinition?` (i.e., without using equational theorems). -/
def SimpTheorems.isDeclToUnfold (d : SimpTheorems) (declName : Name) : Bool :=
d.toUnfold.contains declName
def SimpTheorems.isLetDeclToUnfold (d : SimpTheorems) (fvarId : FVarId) : Bool :=
d.toUnfold.contains fvarId.name -- See comment at `addLetDeclToUnfold`
def SimpTheorems.isLemma (d : SimpTheorems) (thmId : Origin) : Bool :=
d.lemmaNames.contains thmId
/-- Register the equational theorems for the given definition. -/
def SimpTheorems.registerDeclToUnfoldThms (d : SimpTheorems) (declName : Name) (eqThms : Array Name) : SimpTheorems :=
{ d with toUnfoldThms := d.toUnfoldThms.insert declName eqThms }
partial def SimpTheorems.eraseCore (d : SimpTheorems) (thmId : Origin) : SimpTheorems :=
let d := { d with erased := d.erased.insert thmId, lemmaNames := d.lemmaNames.erase thmId }
if let .decl declName .. := thmId then
let d := { d with toUnfold := d.toUnfold.erase declName }
if let some thms := d.toUnfoldThms.find? declName then
let dummy := true
thms.foldl (init := d) (eraseCore · <| .decl · dummy dummy)
else
d
else
d
def SimpTheorems.erase [Monad m] [MonadError m] (d : SimpTheorems) (thmId : Origin) : m SimpTheorems := do
unless d.isLemma thmId ||
match thmId with
| .decl declName .. => d.isDeclToUnfold declName || d.toUnfoldThms.contains declName
| _ => false
do
throwError "'{thmId.key}' does not have [simp] attribute"
return d.eraseCore thmId
private partial def isPerm : Expr → Expr → MetaM Bool
| Expr.app f₁ a₁, Expr.app f₂ a₂ => isPerm f₁ f₂ <&&> isPerm a₁ a₂
| Expr.mdata _ s, t => isPerm s t
| s, Expr.mdata _ t => isPerm s t
| s@(Expr.mvar ..), t@(Expr.mvar ..) => isDefEq s t
| Expr.forallE n₁ d₁ b₁ _, Expr.forallE _ d₂ b₂ _ => isPerm d₁ d₂ <&&> withLocalDeclD n₁ d₁ fun x => isPerm (b₁.instantiate1 x) (b₂.instantiate1 x)
| Expr.lam n₁ d₁ b₁ _, Expr.lam _ d₂ b₂ _ => isPerm d₁ d₂ <&&> withLocalDeclD n₁ d₁ fun x => isPerm (b₁.instantiate1 x) (b₂.instantiate1 x)
| Expr.letE n₁ t₁ v₁ b₁ _, Expr.letE _ t₂ v₂ b₂ _ =>
isPerm t₁ t₂ <&&> isPerm v₁ v₂ <&&> withLetDecl n₁ t₁ v₁ fun x => isPerm (b₁.instantiate1 x) (b₂.instantiate1 x)
| Expr.proj _ i₁ b₁, Expr.proj _ i₂ b₂ => pure (i₁ == i₂) <&&> isPerm b₁ b₂
| s, t => return s == t
private def checkBadRewrite (lhs rhs : Expr) : MetaM Unit := do
let lhs ← DiscrTree.reduceDT lhs (root := true) simpDtConfig
if lhs == rhs && lhs.isFVar then
throwError "invalid `simp` theorem, equation is equivalent to{indentExpr (← mkEq lhs rhs)}"
private partial def shouldPreprocess (type : Expr) : MetaM Bool :=
forallTelescopeReducing type fun _ result => do
if let some (_, lhs, rhs) := result.eq? then
checkBadRewrite lhs rhs
return false
else
return true
private partial def preprocess (e type : Expr) (inv : Bool) (isGlobal : Bool) : MetaM (List (Expr × Expr)) :=
go e type
where
go (e type : Expr) : MetaM (List (Expr × Expr)) := do
let type ← whnf type
if type.isForall then
forallTelescopeReducing type fun xs type => do
let e := mkAppN e xs
let ps ← go e type
ps.mapM fun (e, type) =>
return (← mkLambdaFVars xs e, ← mkForallFVars xs type)
else if let some (_, lhs, rhs) := type.eq? then
if isGlobal then
checkBadRewrite lhs rhs
if inv then
let type ← mkEq rhs lhs
let e ← mkEqSymm e
return [(e, type)]
else
return [(e, type)]
else if let some (lhs, rhs) := type.iff? then
if isGlobal then
checkBadRewrite lhs rhs
if inv then
let type ← mkEq rhs lhs
let e ← mkEqSymm (← mkPropExt e)
return [(e, type)]
else
let type ← mkEq lhs rhs
let e ← mkPropExt e
return [(e, type)]
else if let some (_, lhs, rhs) := type.ne? then
if inv then
throwError "invalid '←' modifier in rewrite rule to 'False'"
if rhs.isConstOf ``Bool.true then
return [(← mkAppM ``Bool.of_not_eq_true #[e], ← mkEq lhs (mkConst ``Bool.false))]
else if rhs.isConstOf ``Bool.false then
return [(← mkAppM ``Bool.of_not_eq_false #[e], ← mkEq lhs (mkConst ``Bool.true))]
let type ← mkEq (← mkEq lhs rhs) (mkConst ``False)
let e ← mkEqFalse e
return [(e, type)]
else if let some p := type.not? then
if inv then
throwError "invalid '←' modifier in rewrite rule to 'False'"
if let some (_, lhs, rhs) := p.eq? then
if rhs.isConstOf ``Bool.true then
return [(← mkAppM ``Bool.of_not_eq_true #[e], ← mkEq lhs (mkConst ``Bool.false))]
else if rhs.isConstOf ``Bool.false then
return [(← mkAppM ``Bool.of_not_eq_false #[e], ← mkEq lhs (mkConst ``Bool.true))]
let type ← mkEq p (mkConst ``False)
let e ← mkEqFalse e
return [(e, type)]
else if let some (type₁, type₂) := type.and? then
let e₁ := mkProj ``And 0 e
let e₂ := mkProj ``And 1 e
return (← go e₁ type₁) ++ (← go e₂ type₂)
else
if inv then
throwError "invalid '←' modifier in rewrite rule to 'True'"
let type ← mkEq type (mkConst ``True)
let e ← mkEqTrue e
return [(e, type)]
private def checkTypeIsProp (type : Expr) : MetaM Unit :=
unless (← isProp type) do
throwError "invalid 'simp', proposition expected{indentExpr type}"
private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array Name) (proof : Expr) (post : Bool) (prio : Nat) (noIndexAtArgs : Bool) : MetaM SimpTheorem := do
assert! origin != .fvar ⟨.anonymous⟩
let type ← instantiateMVars (← inferType e)
withNewMCtxDepth do
let (_, _, type) ← withReducible <| forallMetaTelescopeReducing type
let type ← whnfR type
let (keys, perm) ←
match type.eq? with
| some (_, lhs, rhs) => pure (← DiscrTree.mkPath lhs simpDtConfig noIndexAtArgs, ← isPerm lhs rhs)
| none => throwError "unexpected kind of 'simp' theorem{indentExpr type}"
return { origin, keys, perm, post, levelParams, proof, priority := prio, rfl := (← isRflProof proof) }
private def mkSimpTheoremsFromConst (declName : Name) (post : Bool) (inv : Bool) (prio : Nat) : MetaM (Array SimpTheorem) := do
let cinfo ← getConstInfo declName
let val := mkConst declName (cinfo.levelParams.map mkLevelParam)
withReducible do
let type ← inferType val
checkTypeIsProp type
if inv || (← shouldPreprocess type) then
let mut r := #[]
for (val, type) in (← preprocess val type inv (isGlobal := true)) do
let auxName ← mkAuxLemma cinfo.levelParams type val
r := r.push <| (← mkSimpTheoremCore (.decl declName post inv) (mkConst auxName (cinfo.levelParams.map mkLevelParam)) #[] (mkConst auxName) post prio (noIndexAtArgs := false))
return r
else
return #[← mkSimpTheoremCore (.decl declName post inv) (mkConst declName (cinfo.levelParams.map mkLevelParam)) #[] (mkConst declName) post prio (noIndexAtArgs := false)]
inductive SimpEntry where
| thm : SimpTheorem → SimpEntry
| toUnfold : Name → SimpEntry
| toUnfoldThms : Name → Array Name → SimpEntry
deriving Inhabited
abbrev SimpExtension := SimpleScopedEnvExtension SimpEntry SimpTheorems
def SimpExtension.getTheorems (ext : SimpExtension) : CoreM SimpTheorems :=
return ext.getState (← getEnv)
def addSimpTheorem (ext : SimpExtension) (declName : Name) (post : Bool) (inv : Bool) (attrKind : AttributeKind) (prio : Nat) : MetaM Unit := do
let simpThms ← mkSimpTheoremsFromConst declName post inv prio
for simpThm in simpThms do
ext.add (SimpEntry.thm simpThm) attrKind
def mkSimpExt (name : Name := by exact decl_name%) : IO SimpExtension :=
registerSimpleScopedEnvExtension {
name := name
initial := {}
addEntry := fun d e =>
match e with
| SimpEntry.thm e => addSimpTheoremEntry d e
| SimpEntry.toUnfold n => d.addDeclToUnfoldCore n
| SimpEntry.toUnfoldThms n thms => d.registerDeclToUnfoldThms n thms
}
abbrev SimpExtensionMap := HashMap Name SimpExtension
builtin_initialize simpExtensionMapRef : IO.Ref SimpExtensionMap ← IO.mkRef {}
def getSimpExtension? (attrName : Name) : IO (Option SimpExtension) :=
return (← simpExtensionMapRef.get).find? attrName
/-- Auxiliary method for adding a global declaration to a `SimpTheorems` datastructure. -/
def SimpTheorems.addConst (s : SimpTheorems) (declName : Name) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do
let s := { s with erased := s.erased.erase (.decl declName post inv) }
let simpThms ← mkSimpTheoremsFromConst declName post inv prio
return simpThms.foldl addSimpTheoremEntry s
def SimpTheorem.getValue (simpThm : SimpTheorem) : MetaM Expr := do
if simpThm.proof.isConst && simpThm.levelParams.isEmpty then
let info ← getConstInfo simpThm.proof.constName!
if info.levelParams.isEmpty then
return simpThm.proof
else
return simpThm.proof.updateConst! (← info.levelParams.mapM (fun _ => mkFreshLevelMVar))
else
let us ← simpThm.levelParams.mapM fun _ => mkFreshLevelMVar
return simpThm.proof.instantiateLevelParamsArray simpThm.levelParams us
private def preprocessProof (val : Expr) (inv : Bool) : MetaM (Array Expr) := do
let type ← inferType val
checkTypeIsProp type
let ps ← preprocess val type inv (isGlobal := false)
return ps.toArray.map fun (val, _) => val
/-- Auxiliary method for creating simp theorems from a proof term `val`. -/
def mkSimpTheorems (id : Origin) (levelParams : Array Name) (proof : Expr) (post := true) (inv := false) (prio : Nat := eval_prio default) : MetaM (Array SimpTheorem) :=
withReducible do
(← preprocessProof proof inv).mapM fun val => mkSimpTheoremCore id val levelParams val post prio (noIndexAtArgs := true)
/-- Auxiliary method for adding a local simp theorem to a `SimpTheorems` datastructure. -/
def SimpTheorems.add (s : SimpTheorems) (id : Origin) (levelParams : Array Name) (proof : Expr) (inv := false) (post := true) (prio : Nat := eval_prio default) : MetaM SimpTheorems := do
if proof.isConst then
s.addConst proof.constName! post inv prio
else
let simpThms ← mkSimpTheorems id levelParams proof post inv prio
return simpThms.foldl addSimpTheoremEntry s
def SimpTheorems.addDeclToUnfold (d : SimpTheorems) (declName : Name) : MetaM SimpTheorems := do
if let some eqns ← getEqnsFor? declName then
let mut d := d
for eqn in eqns do
d ← SimpTheorems.addConst d eqn
if hasSmartUnfoldingDecl (← getEnv) declName then
d := d.addDeclToUnfoldCore declName
return d
else
return d.addDeclToUnfoldCore declName
abbrev SimpTheoremsArray := Array SimpTheorems
def SimpTheoremsArray.addTheorem (thmsArray : SimpTheoremsArray) (id : Origin) (h : Expr) : MetaM SimpTheoremsArray :=
if thmsArray.isEmpty then
let thms : SimpTheorems := {}
return #[ (← thms.add id #[] h) ]
else
thmsArray.modifyM 0 fun thms => thms.add id #[] h
def SimpTheoremsArray.eraseTheorem (thmsArray : SimpTheoremsArray) (thmId : Origin) : SimpTheoremsArray :=
thmsArray.map fun thms => thms.eraseCore thmId
def SimpTheoremsArray.isErased (thmsArray : SimpTheoremsArray) (thmId : Origin) : Bool :=
thmsArray.any fun thms => thms.erased.contains thmId
def SimpTheoremsArray.isDeclToUnfold (thmsArray : SimpTheoremsArray) (declName : Name) : Bool :=
thmsArray.any fun thms => thms.isDeclToUnfold declName
def SimpTheoremsArray.isLetDeclToUnfold (thmsArray : SimpTheoremsArray) (fvarId : FVarId) : Bool :=
thmsArray.any fun thms => thms.isLetDeclToUnfold fvarId
end Lean.Meta