-
Notifications
You must be signed in to change notification settings - Fork 88
/
Ext.lean
216 lines (182 loc) · 8.74 KB
/
Ext.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
/-
Copyright (c) 2019 Paul-Nicolas Madelaine. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul-Nicolas Madelaine, Robert Y. Lewis, Mario Carneiro, Gabriel Ebner
-/
import Lean.Meta.CongrTheorems
import Lean.Meta.Tactic.Simp.SimpTheorems
import Std.Tactic.CoeExt
open Lean Meta
namespace Std.Tactic.NormCast
open Tactic.Coe
/--
The three kinds of `norm_cast` lemma, distinguished by where coercions appear in the statement:

* `elim`:   the LHS has 0 head coes and ≥ 1 internal coe;
* `move`:   the LHS has 1 head coe and 0 internal coes,
            the RHS has 0 head coes and ≥ 1 internal coes;
* `squash`: the LHS has ≥ 1 head coes and 0 internal coes,
            the RHS has fewer head coes than the LHS.
-/
inductive Label where
  /-- An elim lemma: the LHS has 0 head coes and ≥ 1 internal coe. -/
  | elim : Label
  /-- A move lemma: the LHS has 1 head coe and 0 internal coes,
  the RHS has 0 head coes and ≥ 1 internal coes. -/
  | move : Label
  /-- A squash lemma: the LHS has ≥ 1 head coes and 0 internal coes,
  the RHS has fewer head coes. -/
  | squash : Label
  deriving DecidableEq, Repr, Inhabited
/--
Assuming `e` is an application, returns the arguments of `e` that `simp` will rewrite in.

A congruence theorem is requested for the head function of `e` via `mkCongrSimp?`.
If none is available, every argument of `e` is returned. Otherwise only the arguments
whose congruence kind is `.eq` are kept, in their original order.
-/
def getSimpArgs (e : Expr) : MetaM (Array Expr) := do
  -- No congruence theorem: fall back to all arguments.
  let some {argKinds, ..} ← mkCongrSimp? e.getAppFn | return e.getAppArgs
  let mut rewritable := #[]
  -- Walk arguments and kinds in lockstep; the loop ends when either is exhausted.
  for arg in e.getAppArgs, kind in argKinds do
    if kind matches .eq then
      rewritable := rewritable.push arg
  return rewritable
/--
Counts the coercions stacked at the top of `e`.

If the head of `e` is a constant registered as a coercion function (per `getCoeFnInfo?`)
and `e` supplies at least `info.numArgs` arguments, the result is one more than the count
for the coerced argument (`info.coercee`). In every other case the result is `0`.
-/
partial def countHeadCoes (e : Expr) : MetaM Nat := do
  let .const fn _ := e.getAppFn | return 0
  let some info ← getCoeFnInfo? fn | return 0
  -- An under-applied coercion function does not count as a coercion.
  if e.getAppNumArgs < info.numArgs then
    return 0
  return (← countHeadCoes (e.getArg! info.coercee)) + 1
/--
Counts the coercions occurring in `e`, including the ones at the top.

Leading lambdas are opened with `lambdaTelescope` and the body is inspected:
* if the body is a sufficiently applied coercion function, the result is
  `1 + countHeadCoes coercee`, plus `countCoes` of every argument past the coercion's own
  `info.numArgs` arguments (only the *head* coercions of the coercee are counted here);
* otherwise the result is the sum of `countCoes` over the arguments returned by `getSimpArgs`.
-/
partial def countCoes (e : Expr) : MetaM Nat :=
  lambdaTelescope e fun _ body => do
    if let .const fn _ := body.getAppFn then
      if let some info ← getCoeFnInfo? fn then
        if info.numArgs ≤ body.getAppNumArgs then
          let mut total := (← countHeadCoes (body.getArg! info.coercee)) + 1
          -- Extra arguments arise when the coercion result is itself applied.
          for idx in [info.numArgs:body.getAppNumArgs] do
            total := total + (← countCoes (body.getArg! idx))
          return total
    -- Not a coercion application: add up the counts of the simp-visible arguments.
    let counts ← (← getSimpArgs body).mapM countCoes
    return counts.foldl (· + ·) 0
/--
Counts the coercions inside `e`, excluding the ones at the top:
`countCoes e - countHeadCoes e` (truncated subtraction on `Nat`).
-/
def countInternalCoes (e : Expr) : MetaM Nat := do
  let total ← countCoes e
  let atHead ← countHeadCoes e
  return total - atHead
/--
Classifies a declaration of type `ty` as a `norm_cast` rule.

The leading binders of `ty` are stripped with `forallTelescopeReducing` and the conclusion is
put in weak head normal form. The conclusion must be an `=` or an `↔`; its two sides are then
inspected:

* the LHS must contain at least one coercion;
* if the LHS has no head coercion, the lemma is an `elim` lemma;
* if the LHS has exactly one head coercion, the RHS must not start with a coercion;
  the lemma is a `squash` lemma when the RHS has no internal coercion, a `move` lemma otherwise;
* if the LHS has more than one head coercion, the RHS must have strictly fewer head coercions,
  and the lemma is a `squash` lemma.

Throws an error when the statement fits none of these shapes.
-/
def classifyType (ty : Expr) : MetaM Label :=
  forallTelescopeReducing ty fun _ ty => do
    let ty ← whnf ty
    let (lhs, rhs) ←
      if ty.isAppOfArity ``Eq 3 then pure (ty.getArg! 1, ty.getArg! 2)
      else if ty.isAppOfArity ``Iff 2 then pure (ty.getArg! 0, ty.getArg! 1)
      else throwError "norm_cast: lemma must be = or ↔, but is{indentExpr ty}"
    let lhsCoes ← countCoes lhs
    if lhsCoes = 0 then
      throwError "norm_cast: badly shaped lemma, lhs must contain at least one coe{indentExpr lhs}"
    let lhsHeadCoes ← countHeadCoes lhs
    let rhsHeadCoes ← countHeadCoes rhs
    let rhsInternalCoes ← countInternalCoes rhs
    if lhsHeadCoes = 0 then
      return Label.elim
    else if lhsHeadCoes = 1 then
      unless rhsHeadCoes = 0 do
        throwError "norm_cast: badly shaped lemma, rhs can't start with coe{indentExpr rhs}"
      -- One head coe on the left and no coe at all on the right: the coe simply disappears.
      if rhsInternalCoes = 0 then
        return Label.squash
      else
        return Label.move
    else if rhsHeadCoes < lhsHeadCoes then
      return Label.squash
    else
      throwError "norm_cast: badly shaped squash lemma, rhs must have fewer head coes than lhs{
        indentExpr ty}"
/-- The simp extension behind the `push_cast` simp attribute, registered under the
attribute name `push_cast`. -/
initialize pushCastExt : SimpExtension ←
  registerSimpAttr `push_cast
    ("The `push_cast` simp attribute uses `norm_cast` lemmas " ++
      "to move casts toward the leaf nodes of the expression.")
/--
The `norm_cast` attribute stores three simp sets, one per rewriting direction/purpose.
Each field is a separate `SimpExtension`; lemmas are added to them by `addElim`,
`addMove` and `addSquash`.
-/
structure NormCastExtension where
  /-- A simp set which lifts coercion arrows to the top level. -/
  up : SimpExtension
  /-- A simp set which pushes coercion arrows to the leaves. -/
  down : SimpExtension
  /-- A simp set which simplifies transitive coercions. -/
  squash : SimpExtension
  deriving Inhabited
/--
The `norm_cast` extension data: three freshly created simp extensions, named after this
declaration with the suffixes `up`, `down` and `squash`.
-/
initialize normCastExt : NormCastExtension ← do
  let up ← mkSimpExt (decl_name% ++ `up)
  let down ← mkSimpExt (decl_name% ++ `down)
  let squash ← mkSimpExt (decl_name% ++ `squash)
  pure { up, down, squash }
/--
`addElim decl` registers `decl` as an `elim` lemma: it is added, in its stated direction,
as a post-lemma of the `up` simp set of `normCastExt`.
-/
def addElim (decl : Name)
    (kind := AttributeKind.global) (prio := eval_prio default) : MetaM Unit := do
  let upExt := normCastExt.up
  addSimpTheorem upExt decl (post := true) (inv := false) kind prio
/--
`addMove decl` registers `decl` as a `move` lemma. It is added as a post-lemma to
* the `push_cast` simp set, in its stated direction;
* the `up` simp set of `normCastExt`, in the reversed direction (`inv := true`);
* the `down` simp set of `normCastExt`, in its stated direction.
-/
def addMove (decl : Name)
    (kind := AttributeKind.global) (prio := eval_prio default) : MetaM Unit := do
  -- Each pair is (target simp set, whether the lemma is used right-to-left).
  for (ext, inv) in [(pushCastExt, false), (normCastExt.up, true), (normCastExt.down, false)] do
    addSimpTheorem ext decl (post := true) (inv := inv) kind prio
/--
`addSquash decl` registers `decl` as a `squash` lemma. It is added, in its stated direction
and as a post-lemma, to the `push_cast` simp set and to the `squash` and `down` simp sets
of `normCastExt`.
-/
def addSquash (decl : Name)
    (kind := AttributeKind.global) (prio := eval_prio default) : MetaM Unit := do
  for ext in [pushCastExt, normCastExt.squash, normCastExt.down] do
    addSimpTheorem ext decl (post := true) (inv := false) kind prio
/--
`addInfer decl` classifies `decl` from the shape of its type (see `classifyType`) and
registers it with the matching `addElim`, `addMove` or `addSquash`.

* elim lemma:   LHS has 0 head coes and ≥ 1 internal coe
* move lemma:   LHS has 1 head coe and 0 internal coes,    RHS has 0 head coes and ≥ 1 internal coes
* squash lemma: LHS has ≥ 1 head coes and 0 internal coes, RHS has fewer head coes
-/
def addInfer (decl : Name)
    (kind := AttributeKind.global) (prio := eval_prio default) : MetaM Unit := do
  let info ← getConstInfo decl
  let label ← classifyType info.type
  match label with
  | .elim => addElim decl kind prio
  | .move => addMove decl kind prio
  | .squash => addSquash decl kind prio
namespace Attr

/-- The possible `norm_cast` kinds: `elim`, `move`, or `squash`. -/
syntax normCastLabel := &"elim" <|> &"move" <|> &"squash"

/--
The `norm_cast` attribute should be given to lemmas that describe the
behaviour of a coercion with respect to an operator, a relation, or a particular
function.

It only concerns equality or iff lemmas involving `↑`, `⇑` and `↥`, describing the behavior of
the coercion functions.
It does not apply to the explicit functions that define the coercions.

Examples:
```lean
@[norm_cast] theorem coe_nat_inj' {m n : ℕ} : (↑m : ℤ) = ↑n ↔ m = n

@[norm_cast] theorem coe_int_denom (n : ℤ) : (n : ℚ).denom = 1

@[norm_cast] theorem cast_id : ∀ n : ℚ, ↑n = n

@[norm_cast] theorem coe_nat_add (m n : ℕ) : (↑(m + n) : ℤ) = ↑m + ↑n

@[norm_cast] theorem cast_coe_nat (n : ℕ) : ((n : ℤ) : α) = n

@[norm_cast] theorem cast_one : ((1 : ℚ) : α) = 1
```

Lemmas tagged with `@[norm_cast]` are classified into three categories: `move`, `elim`, and
`squash`. They are classified roughly as follows:

* elim lemma:   LHS has 0 head coes and ≥ 1 internal coe
* move lemma:   LHS has 1 head coe and 0 internal coes,    RHS has 0 head coes and ≥ 1 internal coes
* squash lemma: LHS has ≥ 1 head coes and 0 internal coes, RHS has fewer head coes

`norm_cast` uses `move` and `elim` lemmas to factor coercions toward the root of an expression
and to cancel them from both sides of an equation or relation. It uses `squash` lemmas to clean
up the result.

Occasionally you may want to override the automatic classification.
You can do this by giving an optional `elim`, `move`, or `squash` parameter to the attribute.

```lean
@[simp, norm_cast elim] lemma nat_cast_re (n : ℕ) : (n : ℂ).re = n := by
  rw [← of_real_nat_cast, of_real_re]
```

Don't do this unless you understand what you are doing.

The attribute also accepts an optional numeric literal after the (optional) label; it is used
as the priority with which the lemma is added to the simp sets.

A full description of the tactic, and the use of each lemma category, can be found at
<https://lean-forward.github.io/norm_cast/norm_cast.pdf>.
-/
syntax (name := norm_cast) "norm_cast" (ppSpace normCastLabel)? (ppSpace num)? : attr

end Attr
-- Registers the handler for the `norm_cast` attribute.
-- When the attribute is added to `decl`, the handler parses the optional label and optional
-- priority from the attribute syntax and dispatches to `addElim` / `addMove` / `addSquash`,
-- or to `addInfer` when no label is recognised.
initialize registerBuiltinAttribute {
  name := `norm_cast
  descr := "attribute for norm_cast"
  add := fun decl stx kind => MetaM.run' do
    -- The attribute parser only produces this shape, hence `unreachable!` on mismatch.
    let `(attr| norm_cast $[$label:normCastLabel]? $[$prio]?) := stx | unreachable!
    -- Fall back to the default priority when none is given (or it is not a numeric literal).
    let prio := (prio.bind (·.1.isNatLit?)).getD (eval_prio default)
    -- NOTE(review): `isStrLit?` is applied to the raw `normCastLabel` node. Confirm that it
    -- yields the keyword text here; if it returns `none` for keyword nodes, an explicit
    -- `elim`/`move`/`squash` label silently falls through to the `none` branch and the label
    -- is inferred instead of overridden.
    match label.bind (·.1.isStrLit?) with
    | "elim" => addElim decl kind prio
    | "move" => addMove decl kind prio
    | "squash" => addSquash decl kind prio
    | none => addInfer decl kind prio
    | _ => unreachable!
}