-
Notifications
You must be signed in to change notification settings - Fork 345
/
InferType.lean
342 lines (299 loc) · 12.4 KB
/
InferType.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.Types
import Lean.Compiler.LCNF.PhaseExt
import Lean.Compiler.LCNF.OtherDecl
namespace Lean.Compiler.LCNF
/-! # Type inference for LCNF -/
/-
Note about **erasure confusion**.
1- After instantiating universe polymorphic code, we may have
some types that become propositions, and all propositions are erased.
For example, suppose we have
```
def f (α : Sort u) (x : α → α → Sort v) (a b : α) (h : x a b) ...
```
The LCNF type for this universe polymorphic declaration is
```
def f (α : Sort u) (x : α → α → Sort v) (a b : α) (h : x ◾ ◾) ...
```
Now, if we instantiate `v` with the universe `0`, then `x ◾ ◾` is also a proposition
and should be erased.
2- We may also get "erasure confusion" when instantiating
polymorphic code with types and type formers. For example, suppose we have
```
structure S (α : Type u) (β : Type v) (f : α → β) where
a : α
b : β := f a
```
The LCNF type for `S.mk` is
```
S.mk : {α : Type u} → {β : Type v} → {f : α → β} → α → β → S α β ◾
```
Note that `f` was erased from the resulting type `S α β ◾` because it is
not a type former. Now, suppose we have the valid Lean declaration
```
def f : S Nat Type (fun _ => Nat) :=
S.mk 0 Nat
```
The LCNF type for the value `S.mk 0 Nat` is `S Nat Type ◾` (see the `S.mk` type above),
but the expected type is `S Nat Type (fun x => Nat)`. `fun x => Nat` is not erased
here because it is a type former.
-/
namespace InferType
/-
Type inference algorithm for LCNF. Invoked by the LCNF type checker
to check correctness of LCNF IR.
-/
/--
Monad used by the LCNF type inference procedures.

We use a regular local context (the `ReaderT` environment) to store temporary local
declarations created during type inference, e.g., the binders opened by `withLocalDecl`
while inferring the type of a `∀` or a `fun`. These declarations are not registered in the
LCNF local context. Lookups such as `InferType.getType` and `InferType.getBinderName`
consult this context first.
-/
abbrev InferTypeM := ReaderT LocalContext CompilerM
/--
Return the user-facing binder name of `fvarId`.
A temporary declaration in the `InferTypeM` local context takes precedence; otherwise the
lookup is delegated to `LCNF.getBinderName`.
-/
def getBinderName (fvarId : FVarId) : InferTypeM Name := do
  let lctx ← read
  if let some decl := lctx.find? fvarId then
    return decl.userName
  LCNF.getBinderName fvarId
/--
Return the type of `fvarId`.
A temporary declaration in the `InferTypeM` local context takes precedence; otherwise the
lookup is delegated to `LCNF.getType`.
-/
def getType (fvarId : FVarId) : InferTypeM Expr := do
  let lctx ← read
  if let some decl := lctx.find? fvarId then
    return decl.type
  LCNF.getType fvarId
/--
Build `∀ xs, type`.

`type` is abstracted over all of `xs`. The binder type of `xs[i]` is abstracted only over the
earlier variables `xs[0:i]`. Binder names and types are obtained with
`InferType.getBinderName`/`InferType.getType`, so temporary declarations in the reader context
are visible. Every binder is created with `BinderInfo.default`.
-/
def mkForallFVars (xs : Array Expr) (type : Expr) : InferTypeM Expr :=
  xs.size.foldRevM (init := type.abstract xs) fun idx body => do
    let fvarId := xs[idx]!.fvarId!
    let binderName ← InferType.getBinderName fvarId
    let binderType := (← InferType.getType fvarId).abstractRange idx xs
    return .forallE binderName binderType body .default
/--
Build `∀ params, type` using the free variables of `params`.
Note: `mkForallFVars` is run with an empty local context, so temporary declarations of the
caller's `InferTypeM` context are not consulted here.
-/
def mkForallParams (params : Array Param) (type : Expr) : InferTypeM Expr :=
  (mkForallFVars (params.map fun param => Expr.fvar param.fvarId) type).run {}
/--
Run `k` on a fresh free variable.
The variable is declared in the `InferTypeM` local context (not the LCNF one) with name
`binderName`, type `type`, and binder info `binderInfo`, only for the duration of `k`.
-/
@[inline] def withLocalDecl (binderName : Name) (type : Expr) (binderInfo : BinderInfo) (k : Expr → InferTypeM α) : InferTypeM α := do
  let fvarId ← mkFreshFVarId
  let addDecl (lctx : LocalContext) : LocalContext :=
    lctx.mkLocalDecl fvarId binderName type binderInfo
  withReader addDecl (k (.fvar fvarId))
/--
Return the type of the constant `declName` instantiated with the universe levels `us`.

- `lcErased` has type `erasedExpr`.
- Declarations with LCNF code (found by `getDecl?`) use their stored LCNF type.
- Anything else is resolved through `getOtherDeclType`.
-/
def inferConstType (declName : Name) (us : List Level) : CompilerM Expr := do
  if declName == ``lcErased then
    return erasedExpr
  match (← getDecl? declName) with
  | some decl => return decl.instantiateTypeLevelParams us
  | none =>
    /- Declaration does not have code associated with it: constructor, inductive type, foreign function -/
    getOtherDeclType declName us
/-- Type of an LCNF literal: `Nat` for natural number literals, `String` for string literals. -/
def inferLitValueType (value : LitValue) : Expr :=
  mkConst <| match value with
    | .natVal .. => ``Nat
    | .strVal .. => ``String
mutual
/--
Infer the type of an LCNF argument.

- Erased arguments have type `erasedExpr`.
- Type arguments are handled by `inferType`.
- Free variables are looked up with `InferType.getType`. This is the same lookup that
  `inferType` uses for `.fvar`: it consults the temporary `InferTypeM` local context first
  and then falls back to `LCNF.getType`. Calling `LCNF.getType` directly would miss
  temporaries opened by `withLocalDecl`.
-/
partial def inferArgType (arg : Arg) : InferTypeM Expr :=
  match arg with
  | .erased => return erasedExpr
  | .type e => inferType e
  | .fvar fvarId => InferType.getType fvarId
/--
Infer the LCNF type of the type-level expression `e`.
`Sort u` has type `Sort (u+1)`. Constants, applications, binders, and free variables are
dispatched to the specialized procedures. Expressions that must not occur in LCNF types
(`bvar`, `mvar`, `lit`, `mdata`, `proj`, `letE`) are unreachable.
-/
partial def inferType (e : Expr) : InferTypeM Expr :=
  match e with
  | .fvar fvarId => InferType.getType fvarId
  | .const declName us => inferConstType declName us
  | .sort u => return .sort (mkLevelSucc u)
  | .app .. => inferAppType e
  | .lam .. => inferLambdaType e
  | .forallE .. => inferForallType e
  | .bvar .. | .mvar .. | .lit .. | .mdata .. | .proj .. | .letE .. => unreachable!
/--
Infer the type of the right-hand side of an LCNF `let`.
Applications (`.const` and `.fvar` heads) compute the function type first and then apply it
to the arguments with `inferAppTypeCore`.
-/
partial def inferLetValueType (e : LetValue) : InferTypeM Expr := do
  match e with
  | .value litValue => return inferLitValueType litValue
  | .erased => return erasedExpr
  | .proj structName idx fvarId => inferProjType structName idx fvarId
  | .const declName us args =>
    let fnType ← inferConstType declName us
    inferAppTypeCore fnType args
  | .fvar fvarId args =>
    let fnType ← getType fvarId
    inferAppTypeCore fnType args
/--
Compute the result type of applying a function of type `fType` to the LCNF arguments `args`.

Binders are consumed without substitution while `fType` is syntactically a `∀` after
`headBeta`. When it is not, the arguments consumed since the last instantiation
(`args[j:i]`) are substituted with `instantiateRevRangeArgs`, which may expose another `∀`.
If no `∀` appears even then, the result is `erasedExpr`. The remaining pending arguments
are substituted at the end.
NOTE(review): `instantiateRevRangeArgs` is defined elsewhere; presumably it is the `Arg`
analogue of `Expr.instantiateRevRange`. Confirm.
-/
partial def inferAppTypeCore (fType : Expr) (args : Array Arg) : InferTypeM Expr := do
  -- `j` is the index of the first argument not yet substituted into `fType`.
  let mut j := 0
  let mut fType := fType
  for i in [:args.size] do
    fType := fType.headBeta
    match fType with
    | .forallE _ _ b _ => fType := b
    | _ =>
      -- Substitute the pending arguments; the result may be (or reduce to) a `∀`.
      fType := instantiateRevRangeArgs fType j i args |>.headBeta
      match fType with
      | .forallE _ _ b _ => j := i; fType := b
      | _ => return erasedExpr
  return instantiateRevRangeArgs fType j args.size args |>.headBeta
/--
Infer the type of the application `e = f a₁ ... aₙ` (an `Expr`, not an LCNF `LetValue`).

The type of the head `f` is inferred with `inferType`. Its binders are then consumed with the
same lazy-instantiation strategy as `inferAppTypeCore`, using `Expr.instantiateRevRange` on
the `Expr` arguments. The result is `erasedExpr` if the function type runs out of `∀` binders.
-/
partial def inferAppType (e : Expr) : InferTypeM Expr := do
  -- `j` is the index of the first argument not yet substituted into `fType`.
  let mut j := 0
  let mut fType ← inferType e.getAppFn
  let args := e.getAppArgs
  for i in [:args.size] do
    fType := fType.headBeta
    match fType with
    | .forallE _ _ b _ => fType := b
    | _ =>
      -- Substitute the pending arguments; the result may be (or reduce to) a `∀`.
      fType := fType.instantiateRevRange j i args |>.headBeta
      match fType with
      | .forallE _ _ b _ => j := i; fType := b
      | _ => return erasedExpr
  return fType.instantiateRevRange j args.size args |>.headBeta
/--
Infer the type of the projection `s.idx` of the structure `structName`, where `s` is a free
variable.

- If the (head-beta-reduced) type of `s` is erased, the result is `erasedExpr`.
- Otherwise, the structure's constructor is instantiated with the structure parameters taken
  from the type of `s` (its type is computed with `inferAppType`). The first `idx` field
  binders are skipped; occurrences of a skipped field in later field types are replaced with
  `erasedExpr`. The result is the domain of the next binder.
- Throws "invalid projection" if the number of parameters does not match `numParams`, or if
  the constructor type runs out of binders without being erased.
NOTE(review): `matchConstStruct` is defined elsewhere; presumably it calls `failed` when the
head of the type is not a structure constant. Confirm.
-/
partial def inferProjType (structName : Name) (idx : Nat) (s : FVarId) : InferTypeM Expr := do
  let failed {α} : Unit → InferTypeM α := fun _ =>
    throwError "invalid projection{indentExpr (mkProj structName idx (mkFVar s))}"
  let structType := (← getType s).headBeta
  if structType.isErased then
    /- TODO: after we erase universe variables, we can just extract a better type using just `structName` and `idx`. -/
    return erasedExpr
  else
    matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal =>
      let n := structVal.numParams
      let structParams := structType.getAppArgs
      if n != structParams.size then
        failed ()
      else do
        let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structParams)
        -- Skip the binders of the fields that precede field `idx`.
        for _ in [:idx] do
          match ctorType with
          | .forallE _ _ body _ =>
            if body.hasLooseBVars then
              -- This can happen when one of the fields is a type or type former.
              ctorType := body.instantiate1 erasedExpr
            else
              ctorType := body
          | _ =>
            if ctorType.isErased then return erasedExpr
            failed ()
        -- The domain of the current binder is the type of field `idx`.
        match ctorType with
        | .forallE _ d _ _ => return d
        | _ =>
          if ctorType.isErased then return erasedExpr
          failed ()
/-- Return `some u` if the inferred type of `type` is `Sort u`, and `none` otherwise. -/
partial def getLevel? (type : Expr) : InferTypeM (Option Level) := do
  let typeType ← inferType type
  if let .sort u := typeType then
    return some u
  return none
/--
Infer the sort of the (possibly nested) `∀` type `e`.

Binders are opened one by one as temporary declarations with `withLocalDecl`. The level of
the body is then combined with the level of each binder domain using `mkLevelIMax'`,
innermost binder first, and the result is normalized. Returns `erasedExpr` if `getLevel?`
cannot determine the level of the body or of some domain.
-/
partial def inferForallType (e : Expr) : InferTypeM Expr :=
  visit e #[]
where
  visit (e : Expr) (fvars : Array Expr) : InferTypeM Expr := do
    if let .forallE binderName domain body binderInfo := e then
      withLocalDecl binderName (domain.instantiateRev fvars) binderInfo fun fvar =>
        visit body (fvars.push fvar)
    else
      let some bodyLevel ← getLevel? (e.instantiateRev fvars) | return erasedExpr
      let mut level := bodyLevel
      for fvar in fvars.reverse do
        let domain ← inferType fvar
        let some domainLevel ← getLevel? domain | return erasedExpr
        level := mkLevelIMax' domainLevel level
      return .sort level.normalize
/--
Infer the type of the (possibly nested) lambda `e`.

`fun` binders and `let` declarations are opened as temporary declarations. The body type is
inferred and then abstracted over the `fun` binders only, producing a `∀`.
NOTE(review): `let`-bound variables are not abstracted, so if the body type mentions one, that
free variable would remain in the result. Confirm this cannot happen for LCNF types.
-/
partial def inferLambdaType (e : Expr) : InferTypeM Expr :=
  visit e #[] #[]
where
  visit (e : Expr) (lamFVars : Array Expr) (allFVars : Array Expr) : InferTypeM Expr := do
    match e with
    | .lam binderName domain body binderInfo =>
      withLocalDecl binderName (domain.instantiateRev allFVars) binderInfo fun fvar =>
        visit body (lamFVars.push fvar) (allFVars.push fvar)
    | .letE declName declType _ body _ =>
      withLocalDecl declName (declType.instantiateRev allFVars) .default fun fvar =>
        visit body lamFVars (allFVars.push fvar)
    | _ =>
      let bodyType ← inferType (e.instantiateRev allFVars)
      mkForallFVars lamFVars bodyType
end
end InferType
/-- Infer the LCNF type of `e`, starting from an empty temporary local context. -/
def inferType (e : Expr) : CompilerM Expr :=
  (InferType.inferType e).run {}
/-- Result type of applying a function of type `fnType` to `args` (see `InferType.inferAppTypeCore`). -/
def inferAppType (fnType : Expr) (args : Array Arg) : CompilerM Expr :=
  (InferType.inferAppTypeCore fnType args).run {}
/--
Return the universe level `u` such that `type : Sort u`.
An erased type of `type` yields `levelOne`. Any other non-sort type raises "type expected".
-/
def getLevel (type : Expr) : CompilerM Level := do
  let typeType ← inferType type
  if let .sort u := typeType then
    return u
  if typeType.isErased then
    return levelOne
  throwError "type expected{indentExpr type}"
/-- Infer the type of an LCNF argument, starting from an empty temporary local context. -/
def Arg.inferType (arg : Arg) : CompilerM Expr :=
  (InferType.inferArgType arg).run {}
/-- Infer the type of a `let` value, starting from an empty temporary local context. -/
def LetValue.inferType (e : LetValue) : CompilerM Expr :=
  (InferType.inferLetValueType e).run {}
/--
Infer the type of an LCNF code block.
`let`/`fun`/`jp` nodes take the type of their continuation. `return x` has the type of `x`.
`jmp j args` has the result type of applying `j` to `args`. `unreach` and `cases` carry their
type explicitly.
-/
def Code.inferType (code : Code) : CompilerM Expr := do
  match code with
  | .return fvarId => getType fvarId
  | .jmp fvarId args => inferAppType (← getType fvarId) args
  | .unreach type => return type
  | .cases c => return c.resultType
  | .let _ k | .fun _ k | .jp _ k => k.inferType
/-- Type `∀ params, T`, where `T` is the inferred type of `code`. -/
def Code.inferParamType (params : Array Param) (code : Code) : CompilerM Expr := do
  let bodyType ← code.inferType
  InferType.mkForallParams params bodyType |>.run {}
/-- The type of a `cases` alternative is the type of its code. -/
def AltCore.inferType (alt : Alt) : CompilerM Expr :=
  Code.inferType alt.getCode
/-- Create a `let` declaration for `e` with a fresh binder name (prefix `prefixName`) and its inferred type. -/
def mkAuxLetDecl (e : LetValue) (prefixName := `_x) : CompilerM LetDecl := do
  let binderName ← mkFreshBinderName prefixName
  let type ← e.inferType
  mkLetDecl binderName type e
/-- Build `∀ params, type` in `CompilerM` (see `InferType.mkForallParams`). -/
def mkForallParams (params : Array Param) (type : Expr) : CompilerM Expr :=
  (InferType.mkForallParams params type).run {}
/--
Create a local function declaration with parameters `params` and body `code`.
Its type is `∀ params, T`, where `T` is the inferred type of `code`, and its name is a fresh
binder name with prefix `prefixName`.
-/
def mkAuxFunDecl (params : Array Param) (code : Code) (prefixName := `_f) : CompilerM FunDecl := do
  let bodyType ← code.inferType
  let type ← mkForallParams params bodyType
  let binderName ← mkFreshBinderName prefixName
  mkFunDecl binderName type params code
/-- Create a join point declaration; same as `mkAuxFunDecl`, but with default prefix `_jp`. -/
def mkAuxJpDecl (params : Array Param) (code : Code) (prefixName := `_jp) : CompilerM FunDecl :=
  mkAuxFunDecl params code (prefixName := prefixName)
/-- Create a join point declaration with the single parameter `param`. -/
def mkAuxJpDecl' (param : Param) (code : Code) (prefixName := `_jp) : CompilerM FunDecl :=
  mkAuxFunDecl #[param] code prefixName
/--
Compute the result type of a `cases` from its alternatives.
The inferred types of the alternatives are combined left to right with `joinTypes`.
Throws if `alts` is empty.
-/
def mkCasesResultType (alts : Array Alt) : CompilerM Expr := do
  if alts.isEmpty then
    throwError "`Code.bind` failed, empty `cases` found"
  let firstType ← alts[0]!.inferType
  alts.foldlM (init := firstType) (start := 1) fun acc alt => do
    return joinTypes acc (← alt.inferType)
/--
Return `true` if `type` should be erased. See item 1 in the note above, where `x ◾ ◾` is
a proposition and should be erased when the universe level parameter is set to 0.

The head of `type` is inspected after `headBeta`, looking through `mdata`, applications, and
`∀`/`fun` binders. A constant head is compatible only if it is erased. A sort is not
compatible. A bound or free variable head is compatible if its type is a predicate type
(`isPredicateType`).

Remark: `predVars` is a bitmask that indicates whether de Bruijn variables are predicates or not.
That is, `#i` is a predicate if `predVars[predVars.size - i - 1] = true`.
-/
partial def isErasedCompatible (type : Expr) (predVars : Array Bool := #[]) : CompilerM Bool :=
  visit type predVars
where
  visit (type : Expr) (predVars : Array Bool) : CompilerM Bool := do
    let type := type.headBeta
    match type with
    | .app fn _ => visit fn predVars
    | .forallE _ binderType body _
    | .lam _ binderType body _ => visit body (predVars.push (isPredicateType binderType))
    | .mdata _ inner => visit inner predVars
    | .bvar idx => return predVars[predVars.size - idx - 1]!
    | .fvar fvarId => return isPredicateType (← getType fvarId)
    | .sort .. => return false
    | .const .. => return type.isErased
    | .proj .. | .mvar .. | .letE .. | .lit .. => unreachable!
/--
Return `true` if the given LCNF types are equivalent.
For example, `List Nat` and `(fun x => List x) Nat` are equivalent.

The comparison is structural up to:
- head beta-reduction at every level of the traversal,
- `mdata` annotations,
- binder names and binder info of `∀`/`fun`,
- universe level equivalence (`Level.isEquiv`).

Any two erased types are equivalent; for example, `◾ α` is equivalent to `◾`.
-/
partial def eqvTypes (a b : Expr) : Bool :=
  if a == b || (a.isErased && b.isErased) then
    true
  else
    let aReduced := a.headBeta
    let bReduced := b.headBeta
    if aReduced != a || bReduced != b then
      -- At least one side had a beta-redex at the head; compare the reduced forms.
      eqvTypes aReduced bReduced
    else
      match a, b with
      | .mdata _ a₁, b₁ => eqvTypes a₁ b₁
      | a₁, .mdata _ b₁ => eqvTypes a₁ b₁
      | .app f₁ x₁, .app f₂ x₂ => eqvTypes f₁ f₂ && eqvTypes x₁ x₂
      | .forallE _ d₁ b₁ _, .forallE _ d₂ b₂ _ => eqvTypes d₁ d₂ && eqvTypes b₁ b₂
      | .lam _ d₁ b₁ _, .lam _ d₂ b₂ _ => eqvTypes d₁ d₂ && eqvTypes b₁ b₂
      | .sort u, .sort v => Level.isEquiv u v
      | .const n us, .const m vs => n == m && List.isEqv us vs Level.isEquiv
      | _, _ => false
end Lean.Compiler.LCNF