-
Notifications
You must be signed in to change notification settings - Fork 402
/
CompilerM.lean
485 lines (392 loc) · 18.2 KB
/
CompilerM.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.CoreM
import Lean.Compiler.LCNF.Basic
import Lean.Compiler.LCNF.LCtx
import Lean.Compiler.LCNF.ConfigOptions
namespace Lean.Compiler.LCNF
/--
The pipeline phase in which a given `Pass` is supposed to run.
-/
inductive Phase where
/-- Here we still carry most of the original type information; most
of the dependent portion has already been (partially) erased, though. -/
| base
/-- In this phase polymorphism has been eliminated. -/
| mono
/-- In this phase impure stuff such as RC or efficient `BaseIO` transformations happens. -/
| impure
deriving Inhabited
/--
The state managed by the `CompilerM` monad.
-/
structure CompilerM.State where
/--
The LCNF local context (`LCtx`). It stores the local declarations
(parameters, `let`-declarations, and local function declarations)
that the functions in this file look up, add, and erase.
-/
lctx : LCtx := {}
/--
Next auxiliary variable suffix. `mkFreshBinderName` uses the current value as
the numeric component of the fresh name and then increments it.
-/
nextIdx : Nat := 1
deriving Inhabited
/--
The read-only context of the `CompilerM` monad.
-/
structure CompilerM.Context where
/-- The pipeline phase currently being executed (see `getPhase` and `withPhase`). -/
phase : Phase
/-- The compiler configuration options (see `getConfig`). -/
config : ConfigOptions
deriving Inhabited
/--
The LCNF compiler monad: a reader over `CompilerM.Context` layered on top of a
`StateRefT` holding `CompilerM.State`, itself layered on top of `CoreM`.
-/
abbrev CompilerM := ReaderT CompilerM.Context $ StateRefT CompilerM.State CoreM
-- `Monad` instance for `CompilerM` assembled from the `pure` and `bind` of the
-- instance found by `inferInstanceAs`, and tagged `always_inline`.
-- NOTE(review): presumably spelled out this way so that `pure`/`bind` are always
-- inlined for this monad stack — confirm the motivation before changing it.
@[always_inline]
instance : Monad CompilerM := let i := inferInstanceAs (Monad CompilerM); { pure := i.pure, bind := i.bind }
/-- Execute `x` in a context whose `phase` field has been replaced by `phase`. -/
@[inline] def withPhase (phase : Phase) (x : CompilerM α) : CompilerM α :=
  withReader (fun ctx => { ctx with phase := phase }) x
/-- Return the `phase` stored in the current `CompilerM` context. -/
def getPhase : CompilerM Phase := do
  let ctx ← read
  return ctx.phase
/-- Return `true` iff the current phase is `Phase.base`. -/
def inBasePhase : CompilerM Bool := do
  match (← getPhase) with
  | .base => return true
  | _ => return false
-- Messages produced in `CompilerM` are wrapped with the current environment, the
-- LCNF local context (converted with `LCtx.toLocalContext`), the current options,
-- and an empty metavariable context.
instance : AddMessageContext CompilerM where
  addMessageContext msgData := do
    let env ← getEnv
    let st ← get
    let opts ← getOptions
    return MessageData.withContext
      { env := env, lctx := st.lctx.toLocalContext, opts := opts, mctx := {} } msgData
/--
Return the type of the local declaration associated with `fvarId`.
The `let`-declarations are searched first, then the parameters, then the local
function declarations. Throws an error if `fvarId` is in none of them.
-/
def getType (fvarId : FVarId) : CompilerM Expr := do
  let lctx := (← get).lctx
  match lctx.letDecls.find? fvarId with
  | some letDecl => return letDecl.type
  | none =>
    match lctx.params.find? fvarId with
    | some param => return param.type
    | none =>
      match lctx.funDecls.find? fvarId with
      | some funDecl => return funDecl.type
      | none => throwError "unknown free variable {fvarId.name}"
/--
Return the binder name of the local declaration associated with `fvarId`.
The `let`-declarations are searched first, then the parameters, then the local
function declarations. Throws an error if `fvarId` is in none of them.
-/
def getBinderName (fvarId : FVarId) : CompilerM Name := do
  let lctx := (← get).lctx
  if let some letDecl := lctx.letDecls.find? fvarId then
    return letDecl.binderName
  if let some param := lctx.params.find? fvarId then
    return param.binderName
  if let some funDecl := lctx.funDecls.find? fvarId then
    return funDecl.binderName
  throwError "unknown free variable {fvarId.name}"
/-- Look up `fvarId` among the parameters of the current local context. -/
def findParam? (fvarId : FVarId) : CompilerM (Option Param) := do
  let st ← get
  return st.lctx.params.find? fvarId
/-- Look up `fvarId` among the `let`-declarations of the current local context. -/
def findLetDecl? (fvarId : FVarId) : CompilerM (Option LetDecl) := do
  let st ← get
  return st.lctx.letDecls.find? fvarId
/-- Look up `fvarId` among the local function declarations of the current local context. -/
def findFunDecl? (fvarId : FVarId) : CompilerM (Option FunDecl) := do
  let st ← get
  return st.lctx.funDecls.find? fvarId
/--
Return the value of the `let`-declaration associated with `fvarId`,
or `none` if `fvarId` is not a `let`-declaration in the current local context.
-/
def findLetValue? (fvarId : FVarId) : CompilerM (Option LetValue) := do
  match (← findLetDecl? fvarId) with
  | some letDecl => return some letDecl.value
  | none => return none
/--
Return `true` iff `fvarId` is a `let`-declaration whose value is a `.const`
application and the constant is a constructor in the current environment.
-/
def isConstructorApp (fvarId : FVarId) : CompilerM Bool := do
  match (← findLetValue? fvarId) with
  | some (.const declName _ _) =>
    match (← getEnv).find? declName with
    | some (.ctorInfo ..) => return true
    | _ => return false
  | _ => return false
/--
Return `true` iff `arg` is a free variable for which `LCNF.isConstructorApp` holds.
Any other kind of argument yields `false`.
-/
def Arg.isConstructorApp (arg : Arg) : CompilerM Bool := do
  match arg with
  | .fvar fvarId => LCNF.isConstructorApp fvarId
  | _ => return false
/-- Return the parameter associated with `fvarId`; throw an error if there is none. -/
def getParam (fvarId : FVarId) : CompilerM Param := do
  match (← findParam? fvarId) with
  | some param => return param
  | none => throwError "unknown parameter {fvarId.name}"
/-- Return the `let`-declaration associated with `fvarId`; throw an error if there is none. -/
def getLetDecl (fvarId : FVarId) : CompilerM LetDecl := do
  match (← findLetDecl? fvarId) with
  | some letDecl => return letDecl
  | none => throwError "unknown let-declaration {fvarId.name}"
/-- Return the local function declaration associated with `fvarId`; throw an error if there is none. -/
def getFunDecl (fvarId : FVarId) : CompilerM FunDecl := do
  match (← findFunDecl? fvarId) with
  | some funDecl => return funDecl
  | none => throwError "unknown local function {fvarId.name}"
/-- Apply `f` to the local context stored in the `CompilerM` state. -/
@[inline] def modifyLCtx (f : LCtx → LCtx) : CompilerM Unit :=
  modify fun st => { st with lctx := f st.lctx }
/-- Erase `decl` from the local context using `LCtx.eraseLetDecl`. -/
def eraseLetDecl (decl : LetDecl) : CompilerM Unit :=
  modifyLCtx (·.eraseLetDecl decl)
/--
Erase `decl` from the local context using `LCtx.eraseFunDecl`.
The `recursive` flag is forwarded unchanged.
-/
def eraseFunDecl (decl : FunDecl) (recursive := true) : CompilerM Unit :=
  modifyLCtx (·.eraseFunDecl decl recursive)
/-- Erase `code` from the local context using `LCtx.eraseCode`. -/
def eraseCode (code : Code) : CompilerM Unit :=
  modifyLCtx (·.eraseCode code)
/-- Erase `param` from the local context using `LCtx.eraseParam`. -/
def eraseParam (param : Param) : CompilerM Unit :=
  modifyLCtx (·.eraseParam param)
/-- Erase all of `params` from the local context using `LCtx.eraseParams`. -/
def eraseParams (params : Array Param) : CompilerM Unit :=
  modifyLCtx (·.eraseParams params)
/--
Erase the declaration wrapped by `decl` from the local context.
Join points and local functions are both erased with `eraseFunDecl`.
-/
def eraseCodeDecl (decl : CodeDecl) : CompilerM Unit := do
  match decl with
  | .let letDecl => eraseLetDecl letDecl
  | .jp jpDecl => eraseFunDecl jpDecl
  | .fun funDecl => eraseFunDecl funDecl
/--
Erase every declaration in `decls` from the local context, in array order.
-/
def eraseCodeDecls (decls : Array CodeDecl) : CompilerM Unit := do
  for decl in decls do
    eraseCodeDecl decl
/-- Erase the parameters of `decl`, and then its body, from the local context. -/
def eraseDecl (decl : Decl) : CompilerM Unit := do
  let { params, value, .. } := decl
  eraseParams params
  eraseCode value
/-- Alias for `eraseDecl` that can be used with dot notation (`decl.erase`). -/
abbrev Decl.erase (decl : Decl) : CompilerM Unit :=
eraseDecl decl
/--
A free variable substitution.
We use these substitutions when inlining definitions and "internalizing" LCNF code into `CompilerM`.
During the internalization process, we ensure that the free variables in the LCNF code do not collide
with the ones already in the `CompilerM` local context.
Remark: in LCNF, (computationally relevant) data is in A-normal form, but this is not the case for types and type formers.
So, when inlining, we often want to replace a free variable with a type or type former.
The substitution contains entries `fvarId ↦ e` such that `e` is a valid LCNF argument. That is,
`e` is a free variable, a type (or type former), or `lcErased`.
`Check.lean` contains a substitution validator.
-/
abbrev FVarSubst := HashMap FVarId Expr
/--
Replace the free variables in `e` using the substitution `s`.

If `translator = true`, the free variables occurring in the range of the substitution are assumed to
live in another local context, so a replacement is returned as is. This is the case, for example,
during internalization, where all free variables of an expression are replaced with new ones that do
not collide with the ones in the current local context.

If `translator = false`, the substitution is assumed to contain replacements in the same local
context, so a replacement is normalized again. Given entries such as
`x₁ ↦ x₂`, `x₂ ↦ x₃`, ..., `xₙ₋₁ ↦ xₙ` and the expression `f x₁ x₂`, the result is `f xₙ xₙ`.
We use this setting, for example, in the simplifier.
-/
private partial def normExprImp (s : FVarSubst) (e : Expr) (translator : Bool) : Expr :=
  go e
where
  -- Walks down the function part of an application. Unlike `go`, it does not
  -- call `headBeta` on the intermediate applications; only `go` does that, on
  -- the application it was invoked on.
  goApp (e : Expr) : Expr :=
    match e with
    | .app fn arg => e.updateApp! (goApp fn) (go arg)
    | _ => go e
  go (e : Expr) : Expr :=
    -- Expressions without free variables are returned untouched.
    if !e.hasFVar then
      e
    else
      match e with
      | .fvar fvarId =>
        match s.find? fvarId with
        | none => e
        | some e' => if translator then e' else go e'
      | .lit .. | .const .. | .sort .. | .mvar .. | .bvar .. => e
      | .app fn arg => (e.updateApp! (goApp fn) (go arg)).headBeta
      | .mdata _ body => e.updateMData! (go body)
      | .proj _ _ inner => e.updateProj! (go inner)
      | .forallE _ dom body _ => e.updateForallE! (go dom) (go body)
      | .lam _ dom body _ => e.updateLambdaE! (go dom) (go body)
      | .letE .. => unreachable! -- Valid LCNF does not contain `let`-declarations
/--
Result type for `normFVar` and `normFVarImp`.
-/
inductive NormFVarResult where
  /-- New free variable. -/
  | fvar (fvarId : FVarId)
  /--
  Free variable has been erased. This can happen when instantiating polymorphic code
  with computationally irrelevant stuff. -/
  | erased
  deriving Inhabited
/--
Normalize the given free variable.
See `normExprImp` for documentation on the `translator` parameter.
This function is meant to be used in contexts where the input free variable is computationally relevant.

If `fvarId` is not in the substitution, it is returned unchanged. If it is mapped to another free
variable, that variable is returned (after being normalized again when `translator = false`).
If it is mapped to an erased expression, the result is `.erased`.
This function panics if the substitution maps `fvarId` to any other expression, e.g., a type
(or type former) that is not `lcErased`. Recall that a valid `FVarSubst` contains only
expressions that are free variables, `lcErased`, or type formers.
-/
private partial def normFVarImp (s : FVarSubst) (fvarId : FVarId) (translator : Bool) : NormFVarResult :=
  match s.find? fvarId with
  | none => .fvar fvarId
  | some (.fvar fvarId') =>
    if translator then .fvar fvarId' else normFVarImp s fvarId' translator
  | some e =>
    if e.isErased then
      .erased
    else
      panic! s!"invalid LCNF substitution of free variable with expression {e}"
/--
Replace the free variables in `arg` using the substitution `s`.
A free variable argument mapped to an erased expression becomes `.erased`, and one mapped to any
other non-free-variable expression becomes a `.type` argument.
See `normExprImp` for documentation on the `translator` parameter.
-/
private partial def normArgImp (s : FVarSubst) (arg : Arg) (translator : Bool) : Arg :=
  match arg with
  | .erased => arg
  | .type e => arg.updateType! (normExprImp s e translator)
  | .fvar fvarId =>
    match s.find? fvarId with
    | none => arg
    | some (.fvar fvarId') =>
      if translator then .fvar fvarId' else normArgImp s (.fvar fvarId') translator
    | some e => if e.isErased then .erased else .type e
/-- Apply `normArgImp` to every element of `args`. -/
private def normArgsImp (s : FVarSubst) (args : Array Arg) (translator : Bool) : Array Arg :=
  args.mapMono fun arg => normArgImp s arg translator
/--
Replace the free variables in the `let`-value `e` using the substitution `s`.
If the free variable of a `.proj` or `.fvar` value normalizes to `.erased`, the whole value
becomes `.erased`.
See `normExprImp` for documentation on the `translator` parameter.
-/
private partial def normLetValueImp (s : FVarSubst) (e : LetValue) (translator : Bool) : LetValue :=
  match e with
  | .erased => e
  | .value .. => e
  | .const _ _ args => e.updateArgs! (normArgsImp s args translator)
  | .proj _ _ fvarId =>
    match normFVarImp s fvarId translator with
    | .fvar fvarId' => e.updateProj! fvarId'
    | .erased => .erased
  | .fvar fvarId args =>
    match normFVarImp s fvarId translator with
    | .fvar fvarId' => e.updateFVar! fvarId' (normArgsImp s args translator)
    | .erased => .erased
/--
Interface for monads that have a free variable substitution.
The `translator` parameter is the flag that `normFVar`, `normExpr`, `normArg`, and
`normLetValue` pass to the corresponding `*Imp` functions; see `normExprImp`.
-/
class MonadFVarSubst (m : Type → Type) (translator : outParam Bool) where
getSubst : m FVarSubst
export MonadFVarSubst (getSubst)
-- Any monad `n` that `m` lifts into inherits `m`'s substitution, with the same `translator` flag.
instance (m n) [MonadLift m n] [MonadFVarSubst m t] : MonadFVarSubst n t where
getSubst := liftM (getSubst : m _)
/--
Interface for monads whose free variable substitution can be modified.
-/
class MonadFVarSubstState (m : Type → Type) where
modifySubst : (FVarSubst → FVarSubst) → m Unit
export MonadFVarSubstState (modifySubst)
-- Any monad `n` that `m` lifts into can modify `m`'s substitution.
instance (m n) [MonadLift m n] [MonadFVarSubstState m] : MonadFVarSubstState n where
modifySubst f := liftM (modifySubst f : m _)
/--
Insert the entry `fvarId ↦ fvarId'` into the free variable substitution.
-/
@[inline] def addFVarSubst [MonadFVarSubstState m] (fvarId : FVarId) (fvarId' : FVarId) : m Unit :=
  modifySubst (·.insert fvarId (.fvar fvarId'))
/--
Insert the entry `fvarId ↦ e` into the free variable substitution.
`e` must be a valid LCNF argument, that is, a free variable, a type (or type former), or `lcErased`.
See `Check.lean` for the free variable substitution checker.
-/
@[inline] def addSubst [MonadFVarSubstState m] (fvarId : FVarId) (e : Expr) : m Unit :=
  modifySubst (·.insert fvarId e)
-- Monadic front end for `normFVarImp`: uses the substitution and `translator` flag of `m`.
@[inline, inherit_doc normFVarImp] def normFVar [MonadFVarSubst m t] [Monad m] (fvarId : FVarId) : m NormFVarResult := do
  let subst ← getSubst
  return normFVarImp subst fvarId t
-- Monadic front end for `normExprImp`: uses the substitution and `translator` flag of `m`.
@[inline, inherit_doc normExprImp] def normExpr [MonadFVarSubst m t] [Monad m] (e : Expr) : m Expr := do
  let subst ← getSubst
  return normExprImp subst e t
-- Monadic front end for `normArgImp`: uses the substitution and `translator` flag of `m`.
@[inline, inherit_doc normArgImp] def normArg [MonadFVarSubst m t] [Monad m] (arg : Arg) : m Arg := do
  let subst ← getSubst
  return normArgImp subst arg t
-- Monadic front end for `normLetValueImp`: uses the substitution and `translator` flag of `m`.
@[inline, inherit_doc normLetValueImp] def normLetValue [MonadFVarSubst m t] [Monad m] (e : LetValue) : m LetValue := do
  let subst ← getSubst
  return normLetValueImp subst e t
-- Non-private alias for `normExprImp` taking the substitution and `translator` flag explicitly.
@[inherit_doc normExprImp]
abbrev normExprCore (s : FVarSubst) (e : Expr) (translator : Bool) : Expr :=
normExprImp s e translator
/--
Normalize every argument in `args` using the current substitution.
-/
def normArgs [MonadFVarSubst m t] [Monad m] (args : Array Arg) : m (Array Arg) := do
  let subst ← getSubst
  return normArgsImp subst args t
/--
Create a fresh binder name by appending the current `nextIdx` to `binderName` as a numeric
component, and then increment `nextIdx`.
-/
def mkFreshBinderName (binderName := `_x): CompilerM Name := do
  let idx := (← get).nextIdx
  modify fun st => { st with nextIdx := idx + 1 }
  return .num binderName idx
/--
Return `binderName` unless it is anonymous. In that case, return a fresh binder name
created from `baseName` with `mkFreshBinderName`.
-/
def ensureNotAnonymous (binderName : Name) (baseName : Name) : CompilerM Name :=
  match binderName with
  | .anonymous => mkFreshBinderName baseName
  | _ => pure binderName
/-!
Helper functions for creating LCNF local declarations.
-/
/--
Create a parameter with a fresh `FVarId` and register it in the local context.
An anonymous `binderName` is replaced by a fresh name based on `_y`.
-/
def mkParam (binderName : Name) (type : Expr) (borrow : Bool) : CompilerM Param := do
  let fvarId ← mkFreshFVarId
  let name ← ensureNotAnonymous binderName `_y
  let param : Param := { fvarId := fvarId, binderName := name, type := type, borrow := borrow }
  modifyLCtx (·.addParam param)
  return param
/--
Create a `let`-declaration with a fresh `FVarId` and register it in the local context.
An anonymous `binderName` is replaced by a fresh name based on `_x`.
-/
def mkLetDecl (binderName : Name) (type : Expr) (value : LetValue) : CompilerM LetDecl := do
  let fvarId ← mkFreshFVarId
  let name ← ensureNotAnonymous binderName `_x
  let letDecl : LetDecl := { fvarId := fvarId, binderName := name, type := type, value := value }
  modifyLCtx (·.addLetDecl letDecl)
  return letDecl
/--
Create a local function declaration with a fresh `FVarId` and register it in the local context.
An anonymous `binderName` is replaced by a fresh name based on `_f`.
-/
def mkFunDecl (binderName : Name) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
  let fvarId ← mkFreshFVarId
  let name ← ensureNotAnonymous binderName `_f
  let funDecl : FunDecl :=
    { fvarId := fvarId, binderName := name, type := type, params := params, value := value }
  modifyLCtx (·.addFunDecl funDecl)
  return funDecl
/-- Create a `let`-declaration with a fresh `_x`-based name, type `erasedExpr`, and value `.erased`. -/
def mkLetDeclErased : CompilerM LetDecl := do
  let binderName ← mkFreshBinderName `_x
  mkLetDecl binderName erasedExpr .erased
/-- Build the code `let _x.i := .erased; return _x.i` using a fresh erased `let`-declaration. -/
def mkReturnErased : CompilerM Code := do
  let erasedDecl ← mkLetDeclErased
  let ret : Code := .return erasedDecl.fvarId
  return .let erasedDecl ret
/--
Implementation of `Param.update`. If `type` is pointer-equal to the current type, `p` is
returned unchanged and the local context is not touched. Otherwise, the updated parameter
is stored in the local context and returned.
-/
private unsafe def updateParamImp (p : Param) (type : Expr) : CompilerM Param := do
  if ptrEq type p.type then
    return p
  let updated := { p with type := type }
  modifyLCtx (·.addParam updated)
  return updated
@[implemented_by updateParamImp] opaque Param.update (p : Param) (type : Expr) : CompilerM Param
/--
Implementation of `LetDecl.update`. If both `type` and `value` are pointer-equal to the current
ones, `decl` is returned unchanged and the local context is not touched. Otherwise, the updated
declaration is stored in the local context and returned.
-/
private unsafe def updateLetDeclImp (decl : LetDecl) (type : Expr) (value : LetValue) : CompilerM LetDecl := do
  if ptrEq type decl.type && ptrEq value decl.value then
    return decl
  let updated := { decl with type := type, value := value }
  modifyLCtx (·.addLetDecl updated)
  return updated
@[implemented_by updateLetDeclImp] opaque LetDecl.update (decl : LetDecl) (type : Expr) (value : LetValue) : CompilerM LetDecl
/-- Update only the value of `decl`, keeping its type, via `LetDecl.update`. -/
def LetDecl.updateValue (decl : LetDecl) (value : LetValue) : CompilerM LetDecl :=
  LetDecl.update decl decl.type value
/--
Implementation of `FunDeclCore.update`. If `type`, `params`, and `value` are all pointer-equal to
the current ones, `decl` is returned unchanged and the local context is not touched. Otherwise,
the updated declaration is stored in the local context and returned.
-/
private unsafe def updateFunDeclImp (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl := do
  if ptrEq type decl.type && ptrEq params decl.params && ptrEq value decl.value then
    return decl
  let updated := { decl with type := type, params := params, value := value }
  modifyLCtx (·.addFunDecl updated)
  return updated
@[implemented_by updateFunDeclImp] opaque FunDeclCore.update (decl: FunDecl) (type : Expr) (params : Array Param) (value : Code) : CompilerM FunDecl
/-- Update the type and body of `decl`, keeping its parameters, via `FunDeclCore.update`. -/
abbrev FunDeclCore.update' (decl : FunDecl) (type : Expr) (value : Code) : CompilerM FunDecl :=
  FunDeclCore.update decl type decl.params value
/-- Update only the body of `decl`, keeping its type and parameters, via `FunDeclCore.update`. -/
abbrev FunDeclCore.updateValue (decl : FunDecl) (value : Code) : CompilerM FunDecl :=
  FunDeclCore.update decl decl.type decl.params value
/-- Normalize the type of `p` with the current substitution and update `p` with the result. -/
@[inline] def normParam [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (p : Param) : m Param := do
  let type ← normExpr p.type
  p.update type
/-- Apply `normParam` to every parameter in `ps`. -/
def normParams [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (ps : Array Param) : m (Array Param) :=
  ps.mapMonoM fun p => normParam p
/-- Normalize the type and the value of `decl` with the current substitution and update `decl`. -/
def normLetDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : LetDecl) : m LetDecl := do
  let type ← normExpr decl.type
  let value ← normLetValue decl.value
  decl.update type value
/--
`CompilerM` extended with a read-only free variable substitution.
The `_translator` argument is not used in the definition itself; it is the `translator`
flag of the `MonadFVarSubst` instance below.
-/
abbrev NormalizerM (_translator : Bool) := ReaderT FVarSubst CompilerM
-- The substitution of `NormalizerM t` is the value stored in the reader.
instance : MonadFVarSubst (NormalizerM t) t where
getSubst := read
/--
If `result` is `.fvar fvarId`, return `x fvarId`. Otherwise, `result` is `.erased`,
and this function returns the code `let _x.i := .erased; return _x.i` built by `mkReturnErased`.
-/
@[inline] def withNormFVarResult [MonadLiftT CompilerM m] [Monad m] (result : NormFVarResult) (x : FVarId → m Code) : m Code := do
  match result with
  | .erased => mkReturnErased
  | .fvar fvarId => x fvarId
mutual
/--
Normalize the type, the parameters, and the body of `decl` with the substitution stored in
`NormalizerM`, and update `decl` with the results.
-/
partial def normFunDeclImp (decl : FunDecl) : NormalizerM t FunDecl := do
let type ← normExpr decl.type
let params ← normParams decl.params
let value ← normCodeImp decl.value
decl.update type params value
/--
Normalize `code` with the substitution stored in `NormalizerM`.
For `.return`, `.jmp`, and `.cases`, if the relevant free variable normalizes to `.erased`,
the whole code is replaced by the result of `mkReturnErased` (see `withNormFVarResult`).
-/
partial def normCodeImp (code : Code) : NormalizerM t Code := do
match code with
| .let decl k => return code.updateLet! (← normLetDecl decl) (← normCodeImp k)
-- Local functions and join points are normalized in the same way.
| .fun decl k | .jp decl k => return code.updateFun! (← normFunDeclImp decl) (← normCodeImp k)
| .return fvarId => withNormFVarResult (← normFVar fvarId) fun fvarId => return code.updateReturn! fvarId
| .jmp fvarId args => withNormFVarResult (← normFVar fvarId) fun fvarId => return code.updateJmp! fvarId (← normArgs args)
| .unreach type => return code.updateUnreach! (← normExpr type)
| .cases c =>
-- The result type is normalized before the discriminant is inspected.
let resultType ← normExpr c.resultType
withNormFVarResult (← normFVar c.discr) fun discr => do
let alts ← c.alts.mapMonoM fun alt =>
match alt with
| .alt _ params k => return alt.updateAlt! (← normParams params) (← normCodeImp k)
| .default k => return alt.updateCode (← normCodeImp k)
return code.updateCases! resultType discr alts
end
/-- Normalize `decl` with the current substitution of `m` by running `normFunDeclImp` on it. -/
@[inline] def normFunDecl [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (decl : FunDecl) : m FunDecl := do
  let subst ← getSubst
  normFunDeclImp (t := t) decl subst
/--
Normalize `code` with the current substitution of `m`.
Similar to `internalize`, but does not refresh `FVarId`s.
-/
@[inline] def normCode [MonadLiftT CompilerM m] [Monad m] [MonadFVarSubst m t] (code : Code) : m Code := do
  let subst ← getSubst
  normCodeImp (t := t) code subst
/--
Replace the free variables in `e` using the substitution `s`.
See `normExprImp` for documentation on the `translator` parameter.
-/
def replaceExprFVars (e : Expr) (s : FVarSubst) (translator : Bool) : CompilerM Expr :=
  let norm : NormalizerM translator Expr := normExpr e
  norm.run s
/--
Replace the free variables in `code` using the substitution `s`.
See `normExprImp` for documentation on the `translator` parameter.
-/
def replaceFVars (code : Code) (s : FVarSubst) (translator : Bool) : CompilerM Code :=
  let norm : NormalizerM translator Code := normCode code
  norm.run s
/-- Create a fresh binder name based on `_jp`. -/
def mkFreshJpName : CompilerM Name :=
  mkFreshBinderName `_jp
/-- Create a parameter of the given `type` with a fresh binder name based on `_y`. -/
def mkAuxParam (type : Expr) (borrow := false) : CompilerM Param := do
  let binderName ← mkFreshBinderName `_y
  mkParam binderName type borrow
/-- Return the `config` stored in the current `CompilerM` context. -/
def getConfig : CompilerM ConfigOptions := do
  let ctx ← read
  return ctx.config
/--
Run `x` in `CoreM`, starting from the state `s` and the given `phase`, and discard the final state.
The configuration is computed from the current options with `toConfigOptions`.
-/
def CompilerM.run (x : CompilerM α) (s : State := {}) (phase : Phase := .base) : CoreM α := do
  let opts ← getOptions
  let ctx : CompilerM.Context := { phase := phase, config := toConfigOptions opts }
  (x ctx).run' s
end Lean.Compiler.LCNF