/
StgToIR.hs
370 lines (304 loc) · 12.5 KB
/
StgToIR.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ParallelListComp #-}
module StgToIR where
import StgLanguage hiding (constructorName)
import qualified StgLanguage as StgL
import ColorUtils
import Debug.Trace
import IR
import IRBuilder
import Control.Monad.Except
import Data.Traversable
import Data.Foldable
import Control.Monad.State.Strict
import Data.Text.Prettyprint.Doc as PP
import qualified OrderedMap as M
import qualified Data.List as L
import qualified Data.Set as S
-- | Reverse function application: @x & f@ feeds the value on the left into
-- the function on the right, i.e. it is @f x@.
(&) :: a -> (a -> b) -> b
x & f = f x
-- | IR type of an entry function: takes no parameters and returns nothing,
-- i.e. @() -> void@.
irTypeEntryFn :: IRType
irTypeEntryFn = IRTypeFunction noParams IRTypeVoid
  where
    noParams = []
-- | IR type of a continuation as produced by alts: a pointer to a
-- @() -> void@ function (see 'irTypeEntryFn').
irTypeContinuation :: IRType
irTypeContinuation = IRTypePointer $ irTypeEntryFn
-- | The type of the ID of a heap object.
-- This is an alias for 'irTypeInt32'; it is the type of the second field of
-- 'irTypeInfoStruct'.
irTypeHeapObjId :: IRType
irTypeHeapObjId = irTypeInt32
-- | IR type of a pointer to an entry function that we tail call into.
-- Remember: a "boxed value" is a lie, they are just functions.
irTypeEntryFnPtr :: IRType
irTypeEntryFnPtr = IRTypePointer entryFn
  where
    entryFn = irTypeEntryFn
-- | IR type of the info struct:
-- @struct { entry function pointer, heap object id }@.
irTypeInfoStruct :: IRType
irTypeInfoStruct = IRTypeStruct infoFields
  where
    infoFields =
      [ irTypeEntryFnPtr -- pointer to the function to call
      , irTypeHeapObjId  -- ID of this object
      ]
-- | Type of the constructor tag.
-- This is an alias for 'irTypeInt32'.
irTypeConstructorTag :: IRType
irTypeConstructorTag = irTypeInt32
-- | IR type of a heap object: @struct { info, void *mem }@.
--
-- TODO: keep a pointer to the info table instead. For now a copy of the info
-- table is stored inline, because that is much easier to do.
irTypeHeapObject :: IRType
irTypeHeapObject =
  IRTypeStruct
    [ irTypeInfoStruct -- info table (stored by value)
    , irTypeMemoryPtr  -- data payload
    ]
-- | IR type of a pointer to a heap object ('irTypeHeapObject').
irTypeHeapObjectPtr :: IRType
irTypeHeapObjectPtr = IRTypePointer $ irTypeHeapObject
-- | Integer identifier of a binding. The matcher function switches on this
-- value to find the function that implements the binding.
type BindingId = Int
-- | Everything the code generator tracks for a single binding.
data BindingData = BindingData
  { binding :: Binding
    -- ^ The STG binding itself.
  , bindingId :: BindingId
    -- ^ Integer identifier assigned to the binding.
  , bindingFn :: Value
    -- ^ IR value of the function generated for the binding.
  }
-- | Renders a 'BindingData' as a brace-delimited block with one field per
-- line, indented by four columns.
instance Pretty BindingData where
  pretty BindingData { binding = b, bindingId = i, bindingFn = fn } =
    vcat
      [ pretty "BindingData {"
      , indent 4 (vcat fieldDocs)
      , pretty "}"
      ]
    where
      fieldDocs =
        [ pretty "binding :=" <+> pretty b
        , pretty "id := " <+> pretty i
        , pretty "bindingFn :=" <+> pretty fn
        ]
-- | Integer identifier of a constructor.
type ConstructorId = Int

-- | Everything the code generator tracks for a single constructor.
data ConstructorData = ConstructorData
  { constructorName :: ConstructorName
    -- ^ Name of the constructor.
  , constructorId :: ConstructorId
    -- ^ Integer identifier assigned to the constructor.
  }
-- | General context needed throughout code generation.
--
-- Naming convention for the fields: G = global, P = pointer.
--
-- TODO: create an IRStack object to represent a stack in LLVM IR.
data Context = Context
  { contstackGP :: Value
    -- ^ Stack pointer to continuation values.
  , contstacknG :: Value
    -- ^ Number of continuation values on the stack.
  , rtagG :: Value
    -- ^ Register for the tag of a constructor.
  , bindingNameToData :: M.OrderedMap VarName BindingData
    -- ^ Binding name to binding data.
  , constructorNameToData :: M.OrderedMap ConstructorName ConstructorData
    -- ^ Constructor name to constructor data.
  , fnmatcher :: Value
    -- ^ Matcher function.
  , fnpushcont :: Value
    -- ^ Function that pushes a continuation value onto the stack.
  , fnpopcont :: Value
    -- ^ Function that pops a continuation value off the stack.
  , fntrap :: Value
    -- ^ Function that traps.
  }
-- | Collect every binding of a program by running
-- 'collectBindingsInBinding' on each element of the program and
-- concatenating the results in order.
getBindsInProgram :: Program -> [Binding]
getBindsInProgram = concatMap collectBindingsInBinding
-- | Collect every constructor name of a program by running
-- 'collectConstructorNamesInBinding' on each element of the program and
-- concatenating the results in order.
getConstructorNamesInProgram :: Program -> [ConstructorName]
getConstructorNamesInProgram = concatMap collectConstructorNamesInBinding
-- | Build the function stub that corresponds to a binding: a function named
-- after the binding that takes no parameters and returns void.
--
-- All stubs are built first so that the 'Context' can be populated; the
-- individual bindings are filled in afterwards.
buildFnStubForBind :: Binding -> State ModuleBuilder Value
buildFnStubForBind b = createFunction [] IRTypeVoid stubName
  where
    stubName = _unVarName (_bindingName b)
-- | Create the @alloc_constructor@ function, which allocates a constructor
-- heap object.
--
-- The generated function takes the heap object id as its only parameter,
-- mallocs an 'irTypeHeapObject', stores the id into the id slot of the
-- object's info struct and returns the pointer to the new object.
--
-- Returns the IR value of the created function.
_createAllocConstructorFn :: State ModuleBuilder Value
_createAllocConstructorFn = do
  lbl <- createFunction [irTypeHeapObjId] irTypeHeapObjectPtr "alloc_constructor"
  runFunctionBuilder lbl $ do
    mem <- "mem" =:= InstMalloc irTypeHeapObject
    -- Indices: 0 = through the pointer, 0 = info struct (first field of the
    -- heap object), 1 = heap object id (second field of the info struct).
    heapObjIdLoc <- "heapObjIdLoc" =:= InstGEP mem [ValueConstInt 0, ValueConstInt 0, ValueConstInt 1]
    idval <- getParamValue 0
    appendInst $ InstStore heapObjIdLoc idval
    -- The function is declared to return a heap object pointer, so hand the
    -- freshly allocated object back to the caller.
    setRetInst $ RetInstReturn mem
  return lbl
-- | Create a function that pushes a value onto a stack.
--
-- NOTE: at the moment the generated function only increments the element
-- count; the store of the pushed value into the stack memory is disabled
-- (see the commented-out instructions below), so @stackGP@ is unused.
_createStackPushFn :: String -- ^ function name
                   -> IRType -- ^ type of stack elements
                   -> Value  -- ^ count global
                   -> Value  -- ^ stack pointer global
                   -> State ModuleBuilder Value
_createStackPushFn fnname elemty nG stackGP = do
  pushFn <- createFunction [elemty] IRTypeVoid fnname
  pushFn <$ runFunctionBuilder pushFn pushBody
  where
    pushBody = do
      count <- "nooo" =:= InstLoad nG
      -- Disabled: store the parameter at the current top of the stack.
      -- TODO: understand why this fails.
      --   stackP <- "stackp" =:= InstLoad stackGP
      --   storeaddr <- "storeaddr" =:= InstGEP stackP [n]
      --   val <- getParamValue 0
      --   appendInst $ InstStore storeaddr val
      -- Write back the incremented count (n + 1).
      bumped <- "ninc" =:= InstAdd count (ValueConstInt 1)
      appendInst $ InstStore nG bumped
      return ()
-- | Create a function that pops a value off a stack.
--
-- The generated function takes no parameters. It loads the element count
-- @n@, reads the element at index @n - 1@ (the top of the stack), stores
-- @n - 1@ back as the new count and returns the element that was read.
--
-- The count global holds the number of elements, so with zero-based indexing
-- the top element lives at @n - 1@; index @n@ is one past the top.
--
-- No underflow check is emitted: popping an empty stack reads index @-1@.
_createStackPopFn :: String -- ^ function name
                  -> IRType -- ^ type of stack elements
                  -> Value  -- ^ count global
                  -> Value  -- ^ stack pointer global
                  -> State ModuleBuilder Value
_createStackPopFn fnname elemty nG stackGP = do
  lbl <- createFunction [] elemty fnname
  runFunctionBuilder lbl $ do
    n <- "n" =:= InstLoad nG
    -- Index of the top element, which is also the new element count.
    n' <- "ndec" =:= InstAdd n (ValueConstInt (-1))
    -- Load the pointer to the stack memory.
    stackP <- "stackp" =:= InstLoad stackGP
    -- Compute the address of the top element and read it.
    loadaddr <- "loadaddr" =:= InstGEP stackP [n']
    loadval <- "loadval" =:= InstLoad loadaddr
    appendInst $ InstStore nG n'
    setRetInst $ RetInstReturn loadval
    return ()
  return lbl
-- | Create the 'Context' that holds the data needed to build the whole LLVM
-- module for our program.
--
-- Creates the globals (continuation stack, continuation count, tag
-- register), one function stub per binding, the matcher and trap function
-- declarations and the continuation push/pop helpers. Bindings and
-- constructors are numbered from 1 in the order they are given.
createContext :: [Binding] -> [ConstructorName] -> State ModuleBuilder Context
createContext bs cnames = do
  -- NOTE: the pointer in the global is implicit, in the sense of GEP.
  contstack <- createGlobalVariable "stackcont" (IRTypePointer irTypeContinuation)
  contn <- createGlobalVariable "contn" irTypeInt32
  stubs <- traverse buildFnStubForBind bs
  rtag <- createGlobalVariable "rtag" (IRTypePointer irTypeConstructorTag)
  matcher <- createFunction [irTypeConstructorTag] irTypeContinuation "matcher"
  trap <- createFunction [] IRTypeVoid "llvm.trap"
  pushcont <- _createStackPushFn "pushcont" irTypeContinuation contn contstack
  popcont <- _createStackPopFn "popcont" irTypeContinuation contn contstack
  -- allocContructor <- _createAllocConstructorFn
  let bdatas = zipWith3 mkBindingData [1 ..] stubs bs
      cdatas = zipWith mkConstructorData [1 ..] cnames
  return Context
    { contstackGP = contstack
    , contstacknG = contn
    , rtagG = rtag
    , bindingNameToData = M.fromList [(_bindingName (binding d), d) | d <- bdatas]
    , constructorNameToData = M.fromList [(constructorName d, d) | d <- cdatas]
    , fnmatcher = matcher
    , fnpushcont = pushcont
    , fnpopcont = popcont
    , fntrap = trap
    }
  where
    -- Pair a binding with its id and its function stub.
    mkBindingData bid fn b =
      BindingData { binding = b, bindingId = bid, bindingFn = fn }
    -- Pair a constructor name with its id.
    mkConstructorData cid cname =
      ConstructorData { constructorName = cname, constructorId = cid }
-- | Emit a call to the context's push function ('fnpushcont') with the given
-- continuation value as its only argument. Used by alts.
pushCont :: Context -> Value -> State FunctionBuilder ()
pushCont ctx val =
  appendInst (InstCall (fnpushcont ctx) [val]) >> return ()
-- | Build the instruction that pops a continuation from the stack: a call to
-- the context's pop function ('fnpopcont') with no arguments. Used by alts.
--
-- The instruction is only constructed here, not emitted; the caller has to
-- name its result with @(=:=)@.
popCont :: Context -> Inst
popCont ctx = InstCall popFn noArgs
  where
    popFn = fnpopcont ctx
    noArgs = []
-- | Fill in the body of the matcher function ('fnmatcher'): a function that
-- takes a binding id and returns the function that belongs to that id.
--
-- The entry block switches on the parameter. Every binding gets its own
-- basic block that returns the binding's function; unknown ids jump to a
-- failure block that returns an undefined continuation.
createMatcher :: Context -> State ModuleBuilder ()
createMatcher ctx = runFunctionBuilder (fnmatcher ctx) matcherBody
  where
    bdata :: M.OrderedMap VarName BindingData
    bdata = bindingNameToData ctx

    -- Build the whole matcher: per-binding blocks first, then the failure
    -- block, and finally the switch in the entry block.
    matcherBody :: State FunctionBuilder ()
    matcherBody = do
      entrybb <- getEntryBBLabel
      switchCases <- traverse switchCaseForBind (M.keys bdata)
      param <- getParamValue 0
      -- Failure block.
      errBB <- createBB "switch.fail"
      focusBB errBB
      setRetInst (RetInstReturn (ValueUndef irTypeContinuation))
      -- Entry block.
      focusBB entrybb
      setRetInst (RetInstSwitch param errBB switchCases)

    -- Build the basic block that returns the function of one binding.
    -- Yields the switch case value (the binding id) together with the label
    -- of the block to jump to.
    switchCaseForBind :: VarName -> State FunctionBuilder (Value, BBLabel)
    switchCaseForBind bname = do
      bbid <- createBB ("switch." ++ _unVarName bname)
      focusBB bbid
      let info = bdata M.! bname
      setRetInst (RetInstReturn (bindingFn info))
      return (ValueConstInt (bindingId info), bbid)
-- | Build the instruction that calls the matcher with the id of the binding
-- named @bname@, yielding the function for that binding.
--
-- The name is looked up with @M.!@ in 'bindingNameToData', so it must be a
-- known binding.
createMatcherCallWithName :: Context -> VarName -> Inst
createMatcherCallWithName ctx bname = InstCall (fnmatcher ctx) [ValueConstInt bid]
  where
    bid = bindingId (bindingNameToData ctx M.! bname)
-- | Push an STG atom onto the correct stack.
--
-- Variable atoms are looked up in the name-to-value map (with @M.!@) and
-- pushed as continuations. Integer atoms are not supported yet and raise an
-- error.
pushAtomToStack :: Context -> M.OrderedMap VarName Value -> Atom -> State FunctionBuilder ()
pushAtomToStack _ _ (AtomInt (StgInt _)) = error "Unimplemented pushInt"
pushAtomToStack ctx nametoval (AtomVarName v) =
  let val = nametoval M.! v
  in pushCont ctx val
-- | Generate code for an expression node.
--
-- Only function application is implemented; every other node raises an
-- error that includes the pretty-printed node.
codegenExprNode :: Context
                -> M.OrderedMap VarName Value -- ^ mapping from a variable name to the value used to access it
                -> ExprNode -- ^ expression node
                -> State FunctionBuilder ()
-- Function application: resolve the callee (a known value, or else a call to
-- the matcher), push the arguments in order, then call the callee.
codegenExprNode ctx nametoval (ExprNodeFnApplication fnname atoms) = do
  fn <- maybe ("fn" =:= createMatcherCallWithName ctx fnname) return (M.lookup fnname nametoval)
  for_ atoms (pushAtomToStack ctx nametoval)
  appendInst (InstCall fn [])
  return ()
-- Constructor codegen (disabled):
{-
codegenExprNode ctx nametoval (ExprNodeConstructor (Constructor name atoms)) = do
    jumpfn <- "jumpfn" =:= popCont ctx
    for atoms (pushAtomToStack ctx nametoval)
    appendInst $ InstCall jumpfn []
    return ()
-}
codegenExprNode _ _ e =
  error (docToString message)
  where
    message = vcat [pretty " Unimplemented codegen for exprnode: ", indent 4 (pretty e)]
-- | Generate the body of the top-level binding called @name@.
--
-- The bound variables of the binding's lambda are popped off the
-- continuation stack, combined with the functions of all top-level bindings
-- into a name-to-value map, and the lambda's expression is then compiled
-- against that map. On a name clash the map is built as
-- @bound \`M.union\` toplevel@.
--
-- The name is looked up with @M.!@ in 'bindingNameToData', so it must be a
-- known binding.
setupTopLevelBinding :: Context -> VarName -> State FunctionBuilder ()
setupTopLevelBinding ctx name = do
  let b = binding (bindingNameToData ctx M.! name) :: Binding
  let Lambda { _lambdaBoundVarIdentifiers = bound
             , _lambdaExprNode = e
             } = _bindingLambda b
  -- Arguments are pushed in declaration order, so for bound = A B C the
  -- stack holds (top first):
  --   C
  --   B
  --   A
  -- Hence we pop in reversed order: the first pop yields C, the last A.
  poppedVals <- for (reverse bound) (\v -> _unVarName v =:= popCont ctx)
  -- poppedVals is in popped (reversed) order; reverse it back so that every
  -- name is paired with the value popped for it.
  let boundNameToVal = M.fromList (zip bound (reverse poppedVals)) :: M.OrderedMap VarName Value
  let toplevelNameToVal = fmap bindingFn (bindingNameToData ctx) :: M.OrderedMap VarName Value
  let nameToVal = boundNameToVal `M.union` toplevelNameToVal
  codegenExprNode ctx nameToVal e
-- | Compile an STG program into an IR module.
--
-- Collects the bindings and constructor names of the program, creates the
-- 'Context', builds the matcher function and then generates the body of
-- every binding's function.
programToModule :: Program -> Module
programToModule p = runModuleBuilder $ do
  ctx <- createContext (getBindsInProgram p) (getConstructorNamesInProgram p)
  createMatcher ctx
  traverse_ (buildBinding ctx) (M.toList (bindingNameToData ctx))
  where
    -- Fill in the function stub of one binding.
    buildBinding ctx (bname, bdata) =
      runFunctionBuilder (bindingFn bdata) (setupTopLevelBinding ctx bname)