/
DecEq.lean
189 lines (175 loc) · 7.95 KB
/
DecEq.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
/-
Copyright (c) 2020 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Meta.Transform
import Lean.Meta.Inductive
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
namespace Lean.Elab.Deriving.DecEq
open Lean.Parser.Term
open Meta
/--
Header for the auxiliary decidable-equality function of `indVal`.

Delegates to `mkHeader` with the class name `DecidableEq` and the argument `2`; callers in this
file read two entries from the resulting `targetNames` (the two values being compared).
-/
def mkDecEqHeader (indVal : InductiveVal) : TermElabM Header :=
  mkHeader `DecidableEq 2 indVal
/--
Build the `match` term used as the body of the auxiliary decidable-equality function.

* `header` provides the discriminants (through `mkDiscrs`).
* `indVal` is the inductive type whose constructors are enumerated to produce the alternatives.
* `auxFunName` is the name of the function being defined; it is referenced from the generated
  code for fields whose type is an application of `indVal.name`.

The result has the shape `match discrs,* with alts*`.
-/
def mkMatch (header : Header) (indVal : InductiveVal) (auxFunName : Name) : TermElabM Term := do
  let discrs ← mkDiscrs header indVal
  let alts ← mkAlts
  `(match $[$discrs],* with $alts:matchAlt*)
where
  -- Right-hand side of an alternative in which both sides are built with the same constructor.
  -- Each entry is `(a, b, recField, isProof)`: the identifiers bound to one field on the left
  -- and on the right, whether the field's type is an application of the inductive type itself,
  -- and whether the field is a proof (its type is a proposition).
  mkSameCtorRhs : List (Ident × Ident × Bool × Bool) → TermElabM Term
    -- No fields left to compare: the two values are equal.
    | [] => ``(isTrue rfl)
    | (a, b, recField, isProof) :: todo => withFreshMacroScope do
      let rhs ← if isProof then
        -- Proof fields are not decided: `a = b` is stated with `rfl`, then the remaining
        -- fields are handled after substituting.
        `(have h : $a = $b := rfl; by subst h; exact $(← mkSameCtorRhs todo):term)
      else
        -- Decide `a = b`; on success substitute and continue with the remaining fields,
        -- on failure refute the equality of the two constructor applications by injection.
        `(if h : $a = $b then
           by subst h; exact $(← mkSameCtorRhs todo):term
          else
           isFalse (by intro n; injection n; apply h _; assumption))
      if recField then
        -- add local instance for `a = b` using the function being defined `auxFunName`
        `(let inst := $(mkIdent auxFunName) $a $b; $rhs)
      else
        return rhs

  -- Produce the match alternatives by visiting every ordered pair of constructors.
  -- A pair of distinct constructors only yields an alternative when `compatibleCtors` holds.
  mkAlts : TermElabM (Array (TSyntax ``matchAlt)) := do
    let mut alts := #[]
    for ctorName₁ in indVal.ctors do
      let ctorInfo ← getConstInfoCtor ctorName₁
      for ctorName₂ in indVal.ctors do
        let mut patterns := #[]
        -- add `_` pattern for indices
        for _ in [:indVal.numIndices] do
          patterns := patterns.push (← `(_))
        if ctorName₁ == ctorName₂ then
          -- Same constructor on both sides: bind the fields and compare them one by one.
          let alt ← forallTelescopeReducing ctorInfo.type fun xs type => do
            let type ← Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
            -- Local copy that starts from the index wildcards collected above.
            let mut patterns := patterns
            let mut ctorArgs1 := #[]
            let mut ctorArgs2 := #[]
            -- add `_` for inductive parameters, they are inaccessible
            for _ in [:indVal.numParams] do
              ctorArgs1 := ctorArgs1.push (← `(_))
              ctorArgs2 := ctorArgs2.push (← `(_))
            -- Fields that still have to be compared, in constructor order.
            let mut todo := #[]
            for i in [:ctorInfo.numFields] do
              -- The fields follow the parameters in the telescope variables `xs`.
              let x := xs[indVal.numParams + i]!
              if type.containsFVar x.fvarId! then
                -- If resulting type depends on this field, we don't need to compare
                ctorArgs1 := ctorArgs1.push (← `(_))
                ctorArgs2 := ctorArgs2.push (← `(_))
              else
                let a := mkIdent (← mkFreshUserName `a)
                let b := mkIdent (← mkFreshUserName `b)
                ctorArgs1 := ctorArgs1.push a
                ctorArgs2 := ctorArgs2.push b
                let recField := (← inferType x).isAppOf indVal.name
                let isProof := (← inferType (← inferType x)).isProp
                todo := todo.push (a, b, recField, isProof)
            -- Both patterns use the same constructor, applied explicitly (`@`) to all arguments.
            patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs1:term*))
            patterns := patterns.push (← `(@$(mkIdent ctorName₁):ident $ctorArgs2:term*))
            let rhs ← mkSameCtorRhs todo.toList
            `(matchAltExpr| | $[$patterns:term],* => $rhs:term)
          alts := alts.push alt
        else if (← compatibleCtors ctorName₁ ctorName₂) then
          -- Distinct constructors: the values differ, which is shown by `injection`.
          patterns := patterns ++ #[(← `($(mkIdent ctorName₁) ..)), (← `($(mkIdent ctorName₂) ..))]
          let rhs ← `(isFalse (by intro h; injection h))
          alts := alts.push (← `(matchAltExpr| | $[$patterns:term],* => $rhs:term))
    return alts
/--
Build the command that defines the auxiliary decidable-equality function.

Only the first entries of `ctx.auxFunNames` and `ctx.typeInfos` are used. The generated command
is a `private def` whose binders come from the header, whose result type is
`Decidable (x = y)` for the header's two target names, and whose body is the term produced by
`mkMatch`.
-/
def mkAuxFunction (ctx : Context) : TermElabM Syntax := do
  let auxFunName := ctx.auxFunNames[0]!
  let indVal := ctx.typeInfos[0]!
  let header ← mkDecEqHeader indVal
  -- The body is used as-is below, so an immutable binding is sufficient.
  let body ← mkMatch header indVal auxFunName
  let binders := header.binders
  let type ← `(Decidable ($(mkIdent header.targetNames[0]!) = $(mkIdent header.targetNames[1]!)))
  `(private def $(mkIdent auxFunName):ident $binders:bracketedBinder* : $type:term := $body:term)
/--
Produce all commands needed to derive `DecidableEq` for `indVal`.

A context is created with `mkContext "decEq" indVal.name`. The result is the auxiliary function
command followed by the commands returned by `mkInstanceCmds` (called with
`useAnonCtor := false`). The commands are also emitted on the `Elab.Deriving.decEq` trace class.
-/
def mkDecEqCmds (indVal : InductiveVal) : TermElabM (Array Syntax) := do
  let ctx ← mkContext "decEq" indVal.name
  -- The auxiliary function is generated before the instance commands.
  let auxFunCmd ← mkAuxFunction ctx
  let instanceCmds ← mkInstanceCmds ctx `DecidableEq #[indVal.name] (useAnonCtor := false)
  let cmds := #[auxFunCmd] ++ instanceCmds
  trace[Elab.Deriving.decEq] "\n{cmds}"
  return cmds
open Command
/--
Derive `DecidableEq` for the inductive type `declName` by elaborating the commands produced by
`mkDecEqCmds`.

Returns `false`, without generating anything, when the type is a nested inductive; returns
`true` after all commands have been elaborated.
-/
def mkDecEq (declName : Name) : CommandElabM Bool := do
  let indVal ← getConstInfoInduct declName
  -- nested inductive types are not supported yet
  if indVal.isNested then
    return false
  let cmds ← liftTermElabM (mkDecEqCmds indVal)
  for cmd in cmds do
    elabCommand cmd
  return true
/--
Define `declName.ofNat : Nat → declName` for the inductive type `declName`.

The body is a decision tree over the constructor list: a range of constructor indices is split
at its midpoint with `Nat.ble`, a range of two constructors is resolved with `Nat.beq`, and a
range of one constructor is that constructor. The definition is added with
`ReducibilityHints.abbrev` and compiled.

Throws an error when the type has no constructors. Without this check `mkDecTree 0 0` would
never terminate: neither base case applies and the midpoint split yields the same range again.
-/
partial def mkEnumOfNat (declName : Name) : MetaM Unit := do
  let indVal ← getConstInfoInduct declName
  let enumType := mkConst declName
  let ctors := indVal.ctors.toArray
  if ctors.isEmpty then
    throwError "cannot define `ofNat` for '{declName}': the type has no constructors"
  withLocalDeclD `n (mkConst ``Nat) fun n => do
    -- `cond` instantiated at universe level zero.
    let cond := mkConst ``cond [levelZero]
    -- Decision tree for the constructors with indices in the half-open range `[low, high)`.
    let rec mkDecTree (low high : Nat) : Expr :=
      if low + 1 == high then
        -- Exactly one constructor left.
        mkConst ctors[low]!
      else if low + 2 == high then
        -- Two constructors left: test `n` against the lower index.
        mkApp4 cond enumType (mkApp2 (mkConst ``Nat.beq) n (mkRawNatLit low)) (mkConst ctors[low]!) (mkConst ctors[low+1]!)
      else
        -- Split at the midpoint; `Nat.ble mid n` selects the upper half.
        let mid := (low + high)/2
        let lowBranch := mkDecTree low mid
        let highBranch := mkDecTree mid high
        mkApp4 cond enumType (mkApp2 (mkConst ``Nat.ble) (mkRawNatLit mid) n) highBranch lowBranch
    let value ← mkLambdaFVars #[n] (mkDecTree 0 ctors.size)
    let type ← mkArrow (mkConst ``Nat) enumType
    addAndCompile <| Declaration.defnDecl {
      name := Name.mkStr declName "ofNat"
      levelParams := []
      safety := DefinitionSafety.safe
      hints := ReducibilityHints.abbrev
      value, type
    }
/--
Add the theorem `declName.ofNat_toCtorIdx : ∀ x, declName.ofNat (declName.toCtorIdx x) = x`.

The proof applies the type's `casesOn` to `x` and supplies one `Eq.refl` per constructor, in the
order of `indVal.ctors`. It refers to the constants `declName.ofNat` and `declName.toCtorIdx`
by name.
-/
def mkEnumOfNatThm (declName : Name) : MetaM Unit := do
  let indVal ← getConstInfoInduct declName
  let enumType := mkConst declName
  let toCtorIdxFn := mkConst (Name.mkStr declName "toCtorIdx")
  let ofNatFn := mkConst (Name.mkStr declName "ofNat")
  -- `Eq` and `Eq.refl` at universe level one, already applied to the enumeration type.
  let eqAtEnum := mkApp (mkConst ``Eq [levelOne]) enumType
  let reflAtEnum := mkApp (mkConst ``Eq.refl [levelOne]) enumType
  withLocalDeclD `x enumType fun x => do
    let statement := mkApp2 eqAtEnum (mkApp ofNatFn (mkApp toCtorIdxFn x)) x
    let motive ← mkLambdaFVars #[x] statement
    let casesOn := mkConst (mkCasesOnName declName) [levelZero]
    -- Append one `Eq.refl ctor` argument per constructor to `casesOn motive x`.
    let proofBody := indVal.ctors.foldl
      (fun acc ctor => mkApp acc (mkApp reflAtEnum (mkConst ctor)))
      (mkApp2 casesOn motive x)
    let value ← mkLambdaFVars #[x] proofBody
    let type ← mkForallFVars #[x] statement
    addAndCompile <| Declaration.thmDecl {
      name := Name.mkStr declName "ofNat_toCtorIdx"
      levelParams := []
      value, type
    }
/--
Derive `DecidableEq` for the type `declName` by comparing constructor indices.

First defines `declName.ofNat` and the theorem `declName.ofNat_toCtorIdx`, then elaborates an
instance that decides `x = y` by testing `x.toCtorIdx = y.toCtorIdx`. The generated command is
also emitted on the `Elab.Deriving.decEq` trace class.
-/
def mkDecEqEnum (declName : Name) : CommandElabM Unit := do
  liftTermElabM (mkEnumOfNat declName)
  liftTermElabM (mkEnumOfNatThm declName)
  let enumIdent := mkIdent declName
  let ofNatId := mkIdent (Name.mkStr declName "ofNat")
  let roundTripThmId := mkIdent (Name.mkStr declName "ofNat_toCtorIdx")
  let instCmd ← `(
    instance : DecidableEq $enumIdent :=
      fun x y =>
        if h : x.toCtorIdx = y.toCtorIdx then
          -- We use `rfl` in the following proof because the first script fails for unit-like datatypes due to etaStruct.
          isTrue (by first | have aux := congrArg $ofNatId h; rw [$roundTripThmId:ident, $roundTripThmId:ident] at aux; assumption | rfl)
        else
          isFalse fun h => by subst h; contradiction
  )
  trace[Elab.Deriving.decEq] "\n{instCmd}"
  elabCommand instCmd
/--
Deriving handler for `DecidableEq`.

Returns `false` unless exactly one declaration name is given. For a single name it uses
`mkDecEqEnum` when `isEnumType` holds (and returns `true`), and otherwise returns the result
of `mkDecEq`.
-/
def mkDecEqInstanceHandler (declNames : Array Name) : CommandElabM Bool := do
  -- mutually inductive types are not supported yet
  if declNames.size != 1 then
    return false
  let declName := declNames[0]!
  let isEnum ← isEnumType declName
  if isEnum then
    mkDecEqEnum declName
    return true
  mkDecEq declName
-- Install `mkDecEqInstanceHandler` as the deriving handler for `DecidableEq`.
builtin_initialize registerDerivingHandler `DecidableEq mkDecEqInstanceHandler

-- Register the trace class used by the `trace[Elab.Deriving.decEq]` messages in this file.
builtin_initialize registerTraceClass `Elab.Deriving.decEq
end Lean.Elab.Deriving.DecEq