/
Borrow.lean
316 lines (269 loc) · 10.6 KB
/
Borrow.lean
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
import Lean.Compiler.ExportAttr
import Lean.Compiler.IR.CompilerM
import Lean.Compiler.IR.NormIds
namespace Lean
namespace IR
namespace Borrow
namespace OwnedSet
/-- Key of the owned set: a function id paired with a variable index. -/
abbrev Key := FunId × Index

/-- Two keys are equal when both their function ids and their indices are equal. -/
def beq (k₁ k₂ : Key) : Bool :=
  match k₁, k₂ with
  | (f₁, x₁), (f₂, x₂) => f₁ == f₂ && x₁ == x₂

instance : BEq Key := ⟨beq⟩

/-- Hash of a key, obtained by mixing the hashes of its two components. -/
def getHash (k : Key) : UInt64 :=
  match k with
  | (f, x) => mixHash (hash f) (hash x)

instance : Hashable Key := ⟨getHash⟩
end OwnedSet
open OwnedSet (Key) in
/-- Set of `OwnedSet.Key`s, represented as a hash map whose values are `Unit`. -/
abbrev OwnedSet := HashMap Key Unit

/-- Add the key `k` to the set `s`. -/
def OwnedSet.insert (s : OwnedSet) (k : OwnedSet.Key) : OwnedSet :=
  HashMap.insert s k ()

/-- Return `true` iff the key `k` is a member of the set `s`. -/
def OwnedSet.contains (s : OwnedSet) (k : OwnedSet.Key) : Bool :=
  HashMap.contains s k
/-! We perform borrow inference in a block of mutually recursive functions.
Join points are viewed as local functions, and are identified using
their local id + the name of the surrounding function.
We keep a mapping from functions and join points to parameters (`Array Param`).
Recall that `Param` contains the field `borrow`. -/
namespace ParamMap
/-- Key of the parameter map: either a function (`decl`) or a join point (`jp`),
the latter identified by the surrounding function name and the join point id. -/
inductive Key where
  | decl (name : FunId)
  | jp (name : FunId) (jpid : JoinPointId)
  deriving BEq

/-- Hash of a key: the hash of the function name for `decl`, and the function name
hash mixed with the join point id hash for `jp`. -/
def getHash (k : Key) : UInt64 :=
  match k with
  | Key.decl n => hash n
  | Key.jp n id => mixHash (hash n) (hash id)

instance : Hashable Key where
  hash := getHash
end ParamMap
open ParamMap (Key)

/-- Mapping from functions and join points to their parameters. -/
abbrev ParamMap := HashMap Key (Array Param)

/-- Pretty-print a `ParamMap` as `{ key -> params ... }`, one entry per line.
A `jp` key is rendered as `<function>:<join point id>`. -/
def ParamMap.fmt (map : ParamMap) : Format :=
  let fmtKey : ParamMap.Key → Format
    | ParamMap.Key.decl n => format n
    | ParamMap.Key.jp n id => format n ++ ":" ++ format id
  let addEntry (acc : Format) (k : ParamMap.Key) (ps : Array Param) : Format :=
    acc ++ Format.line ++ fmtKey k ++ " -> " ++ formatParams ps
  let entries := map.fold addEntry Format.nil
  "{" ++ (Format.nest 1 entries) ++ "}"

instance : ToFormat ParamMap := ⟨ParamMap.fmt⟩

instance : ToString ParamMap where
  toString m := Format.pretty (format m)
namespace InitParamMap
/-- Initialize the `borrow` flag of every parameter: it is set to `p.ty.isObj`,
i.e. parameters that take a reference start out as borrowed. -/
def initBorrow (ps : Array Param) : Array Param :=
  ps.map (fun (p : Param) => { p with borrow := p.ty.isObj })
/-- We do **not** perform borrow inference for constants marked as `export`:
when `exported` is `true` the parameters `ps` are returned unchanged, otherwise
they are initialized with `initBorrow`.

Reason: we currently write wrappers in C++ for using exported functions.
These wrappers use smart pointers such as `object_ref`.
When writing a new wrapper we need to know whether an argument is borrowed or not.
We can revise this decision when we implement code for generating
the wrappers automatically. -/
def initBorrowIfNotExported (exported : Bool) (ps : Array Param) : Array Param :=
  if exported then ps else initBorrow ps
/-- Traverse a function body and add an entry `jp fnid j ↦ initBorrow xs` to the
`ParamMap` for every join point declaration `j` with parameters `xs` found in it. -/
partial def visitFnBody (fnid : FunId) (body : FnBody) : StateM ParamMap Unit := do
  match body with
  | FnBody.jdecl j xs v b =>
    modify fun m => m.insert (ParamMap.Key.jp fnid j) (initBorrow xs)
    visitFnBody fnid v
    visitFnBody fnid b
  | FnBody.case _ _ _ alts =>
    alts.forM fun alt => visitFnBody fnid alt.body
  | e =>
    -- Non-terminal instruction: continue with the rest of the body.
    unless e.isTerminal do
      let (_, rest) := e.split
      visitFnBody fnid rest
/-- Register the parameters of every function declaration in `decls` (using
`initBorrowIfNotExported`) and then the join points of its body.
Declarations that are not `fdecl` are ignored. -/
def visitDecls (env : Environment) (decls : Array Decl) : StateM ParamMap Unit :=
  decls.forM fun
    | Decl.fdecl (f := f) (xs := xs) (body := b) .. => do
      let ps := initBorrowIfNotExported (isExport env f) xs
      modify fun m => m.insert (ParamMap.Key.decl f) ps
      visitFnBody f b
    | _ => pure ()
end InitParamMap
/-- Build the initial `ParamMap` for `decls` by running `InitParamMap.visitDecls`
starting from the empty map. -/
def mkInitParamMap (env : Environment) (decls : Array Decl) : ParamMap :=
  let collect : StateM ParamMap ParamMap := do
    InitParamMap.visitDecls env decls
    get
  collect.run' {}
/-! Apply the inferred borrow annotations stored at `ParamMap` to a block of mutually
recursive functions. -/
namespace ApplyParamMap
/-- Rebuild a function body, replacing the parameters of every join point declaration
with the ones stored in `paramMap` under the key `jp fn j`.
A join point without an entry in `paramMap` is treated as unreachable. -/
partial def visitFnBody (fn : FunId) (paramMap : ParamMap) (body : FnBody) : FnBody :=
  match body with
  | FnBody.jdecl j _ v b =>
    let v' := visitFnBody fn paramMap v
    let b' := visitFnBody fn paramMap b
    match paramMap.find? (ParamMap.Key.jp fn j) with
    | some ys => FnBody.jdecl j ys v' b'
    | none => unreachable!
  | FnBody.case tid x xType alts =>
    let alts' := alts.map fun alt => alt.modifyBody (visitFnBody fn paramMap)
    FnBody.case tid x xType alts'
  | e =>
    if e.isTerminal then
      e
    else
      -- Split off the leading instruction, rewrite the rest, and reattach it.
      let (instr, rest) := e.split
      instr.setBody (visitFnBody fn paramMap rest)
/-- Replace the parameters of every function declaration in `decls` with the ones
stored in `paramMap` under the key `decl f`, after rewriting its body with `visitFnBody`.
Other declarations are returned unchanged. -/
def visitDecls (decls : Array Decl) (paramMap : ParamMap) : Array Decl :=
  decls.map fun
    | Decl.fdecl f _ ty b info =>
      let b' := visitFnBody f paramMap b
      match paramMap.find? (ParamMap.Key.decl f) with
      | some xs => Decl.fdecl f xs ty b' info
      | none => unreachable!
    | other => other
end ApplyParamMap
/-- Apply the borrow annotations stored in `map` to the declarations `decls`. -/
def applyParamMap (decls : Array Decl) (map : ParamMap) : Array Decl :=
  ApplyParamMap.visitDecls (decls := decls) (paramMap := map)
/-- Read-only context of the borrow inference monad `M`. -/
structure BorrowInfCtx where
  env      : Environment
  /-- Block of mutually recursive functions. -/
  decls    : Array Decl
  /-- Function being analyzed. -/
  currFn   : FunId    := default
  /-- Set of all function parameters in scope.
  This is used to implement the heuristic at `ownArgsIfParam`. -/
  paramSet : IndexSet := {}
/-- Mutable state of the borrow inference monad `M`. -/
structure BorrowInfState where
  /-- Set of variables that must be `owned`. -/
  owned    : OwnedSet := {}
  /-- Flag set whenever `owned` or `paramMap` changes; `whileModifing` resets it
  before each iteration and reads it afterwards. -/
  modified : Bool     := false
  /-- Current borrow annotations for functions and join points. -/
  paramMap : ParamMap
/-- Borrow inference monad: reads a `BorrowInfCtx` and threads a `BorrowInfState`. -/
abbrev M :=
  ReaderT BorrowInfCtx (StateM BorrowInfState)
/-- Return the `currFn` field of the context, i.e. the function being analyzed. -/
def getCurrFn : M FunId := do
  return (← read).currFn
/-- Set the `modified` flag of the state to `true`. -/
def markModified : M Unit :=
  modify (fun st => { st with modified := true })
/-- Mark the variable `x` of the current function as owned.
If it was not already owned, the `modified` flag is also set. -/
def ownVar (x : VarId) : M Unit := do
  let currFn ← getCurrFn
  let key : OwnedSet.Key := (currFn, x.idx)
  unless (← get).owned.contains key do
    modify fun st => { st with owned := st.owned.insert key, modified := true }
/-- Mark the argument as owned when it is a variable; other arguments are ignored. -/
def ownArg : Arg → M Unit
  | Arg.var x => ownVar x
  | _ => pure ()
/-- Apply `ownArg` to every argument in `xs`, in order. -/
def ownArgs (xs : Array Arg) : M Unit := do
  for x in xs do
    ownArg x
/-- Return `true` iff the variable `x` of the current function is in the `owned` set. -/
def isOwned (x : VarId) : M Bool := do
  let fn ← getCurrFn
  return (← get).owned.contains (fn, x.idx)
/-- Updates `map[k]` using the current set of `owned` variables:
every parameter that is still marked as `borrow` but is now owned has its `borrow`
flag cleared, and the state is marked as modified. Does nothing if `k` has no entry. -/
def updateParamMap (k : ParamMap.Key) : M Unit := do
  let some ps := (← get).paramMap.find? k | pure ()
  let ps' ← ps.mapM fun (p : Param) => do
    if p.borrow then
      if (← isOwned p.x) then
        markModified
        return { p with borrow := false }
    return p
  modify fun st => { st with paramMap := st.paramMap.insert k ps' }
/-- Return the parameters associated with `k`.
The `paramMap` of the state is consulted first. If `k` is a `decl` key without an
entry there, the declaration is looked up in the environment with `findEnvDecl`.
Every other case is treated as unreachable. -/
def getParamInfo (k : ParamMap.Key) : M (Array Param) := do
  if let some ps := (← get).paramMap.find? k then
    return ps
  match k with
  | ParamMap.Key.decl fn =>
    let ctx ← read
    match findEnvDecl ctx.env fn with
    | some decl => pure decl.params
    | none => unreachable!
  | _ => unreachable!
/-- For each index `i` of `xs`: if `ps[i]` is owned (i.e. not `borrow`),
then mark `xs[i]` as owned. -/
def ownArgsUsingParams (xs : Array Arg) (ps : Array Param) : M Unit := do
  for i in [0:xs.size] do
    let x := xs[i]!
    let p := ps[i]!
    if !p.borrow then
      ownArg x
/-- For each `xs[i]`, if `xs[i]` is an owned variable, then mark `ps[i]` as owned.
We use this action to preserve tail calls. That is, if we have
a tail call `f xs`, if the i-th parameter is borrowed, but `xs[i]` is owned
we would have to insert a `dec xs[i]` after `f xs` and consequently
"break" the tail call. -/
def ownParamsUsingArgs (xs : Array Arg) (ps : Array Param) : M Unit := do
  for i in [0:xs.size] do
    let arg := xs[i]!
    let p := ps[i]!
    if let Arg.var x := arg then
      if (← isOwned x) then
        ownVar p.x
/-- Mark `xs[i]` as owned if it is a variable whose index is in the context's `paramSet`.
We use this action to mark function parameters that are being "packed" inside constructors.
This is a heuristic, and is not related with the effectiveness of the reset/reuse optimization.
It is useful for code such as
```
def f (x y : obj) :=
let z := ctor_1 x y;
ret z
```
-/
def ownArgsIfParam (xs : Array Arg) : M Unit := do
  let paramSet := (← read).paramSet
  for arg in xs do
    if let Arg.var x := arg then
      if paramSet.contains x.idx then
        ownVar x
/-- Collect ownership constraints for the declaration `z := e`.
Depending on the shape of `e`, this marks `z`, variables occurring in `e`,
or both, as owned. Expressions not listed below add no constraint. -/
def collectExpr (z : VarId) (e : Expr) : M Unit := do
  match e with
  | Expr.reset _ x =>
    ownVar z
    ownVar x
  | Expr.reuse x _ _ ys =>
    ownVar z
    ownVar x
    ownArgsIfParam ys
  | Expr.ctor _ xs =>
    ownVar z
    ownArgsIfParam xs
  | Expr.proj _ x =>
    -- Ownership is propagated in both directions between `x` and its projection `z`.
    if (← isOwned x) then ownVar z
    if (← isOwned z) then ownVar x
  | Expr.fap g xs =>
    let ps ← getParamInfo (ParamMap.Key.decl g)
    ownVar z
    ownArgsUsingParams xs ps
  | Expr.ap x ys =>
    ownVar z
    ownVar x
    ownArgs ys
  | Expr.pap _ xs =>
    ownVar z
    ownArgs xs
  | _ => pure ()
/-- If `x := g ys` is immediately followed by `ret x` and `g` is one of the
declarations in the current block, mark the parameters of `g` as owned
according to the arguments `ys` (see `ownParamsUsingArgs`). -/
def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit := do
  let ctx ← read
  match v, b with
  | Expr.fap g ys, FnBody.ret (Arg.var z) =>
    let calleeInBlock := ctx.decls.any (·.name == g)
    if calleeInBlock && x == z then
      let ps ← getParamInfo (ParamMap.Key.decl g)
      ownParamsUsingArgs ys ps
  | _, _ => pure ()
/-- Extend `ctx.paramSet` with the indices of the parameters `ps`. -/
def updateParamSet (ctx : BorrowInfCtx) (ps : Array Param) : BorrowInfCtx :=
  let extended : IndexSet :=
    ps.foldl (init := ctx.paramSet) fun (acc : IndexSet) (p : Param) => acc.insert p.x.idx
  { ctx with paramSet := extended }
/-- Collect ownership constraints for a function body.
For a variable declaration, the continuation is processed *before* the declared
expression, so that information about uses is available when the declaration is visited. -/
partial def collectFnBody (body : FnBody) : M Unit := do
  match body with
  | FnBody.jdecl j ys v b =>
    -- The join point parameters are in scope only inside its body `v`.
    withReader (updateParamSet · ys) (collectFnBody v)
    let ctx ← read
    updateParamMap (ParamMap.Key.jp ctx.currFn j)
    collectFnBody b
  | FnBody.vdecl x _ v b =>
    collectFnBody b
    collectExpr x v
    preserveTailCall x v b
  | FnBody.jmp j ys =>
    let ctx ← read
    let ps ← getParamInfo (ParamMap.Key.jp ctx.currFn j)
    ownArgsUsingParams ys ps -- for making sure the join point can reuse
    ownParamsUsingArgs ys ps -- for making sure the tail call is preserved
  | FnBody.case _ _ _ alts =>
    for alt in alts do
      collectFnBody alt.body
  | e =>
    unless e.isTerminal do
      collectFnBody e.body
/-- Collect ownership constraints for a function declaration and then update its
entry in the parameter map. The body is analyzed in a context where `currFn` is the
declaration's name and `paramSet` contains its parameters.
Declarations that are not `fdecl` are ignored. -/
partial def collectDecl : Decl → M Unit
  | Decl.fdecl (f := f) (xs := ys) (body := b) .. =>
    let enter (ctx : BorrowInfCtx) : BorrowInfCtx :=
      { updateParamSet ctx ys with currFn := f }
    withReader enter do
      collectFnBody b
      updateParamMap (ParamMap.Key.decl f)
  | _ => pure ()
/-- Keep executing `x` until it reaches a fixpoint, i.e. until a run of `x`
leaves the `modified` flag (reset before each run) set to `false`. -/
partial def whileModifing (x : M Unit) : M Unit := do
  modify fun st => { st with modified := false }
  x
  if (← get).modified then
    whileModifing x
/-- Run `collectDecl` over all declarations of the context until nothing changes,
and return the resulting parameter map. -/
def collectDecls : M ParamMap := do
  let decls := (← read).decls
  whileModifing (decls.forM collectDecl)
  return (← get).paramMap
/-- Infer borrow annotations for the block of declarations `decls`,
starting from the map produced by `mkInitParamMap`. -/
def infer (env : Environment) (decls : Array Decl) : ParamMap :=
  let ctx : BorrowInfCtx := { env, decls }
  let init : BorrowInfState := { paramMap := mkInitParamMap env decls }
  (collectDecls ctx).run' init
end Borrow
/-- Infer borrow annotations for `decls` in the current environment and return the
declarations with those annotations applied. -/
def inferBorrow (decls : Array Decl) : CompilerM (Array Decl) := do
  let paramMap := Borrow.infer (← getEnv) decls
  return Borrow.applyParamMap decls paramMap
end IR
end Lean