Skip to content

Commit

Permalink
Optimize closure iterator locals (#23787)
Browse files Browse the repository at this point in the history
This PR redefines the relation between lambda lifting and the closureiter
transformation.

Key takeaways:
- Lambdalifting now makes less of a distinction between closureiters and
regular closures. Namely, instead of lifting _all_ closureiter variables,
it lifts only those variables it would also lift for a simple closure,
i.e. those not owned by the closure.
- It is now the closureiter transformation's responsibility to lift all the
locals that need lifting and are not lifted by lambdalifting. So now we
lift only those locals that appear in more than one state. The rest
remain on the stack, yay!
- The closureiter transformation always relies on the closure env param
created by lambdalifting. Special care is taken to make lambdalifting
create it even in cases when it's "too early" to lift.
- Environments created by lambdalifting will contain `:state` only for
closureiters, whereas previously any closure env contained it.

IMO this is a more reasonable approach as it simplifies not only
lambdalifting, but transf too (e.g. freshVarsForClosureIters is now gone
for good).

I tried to organize the changes logically by commits, so it might be
easier to review this on per commit basis.

Some ugliness:
- When adding lifting to the closureiter transformation, I had to repeat the
matching of the `return result = value` node. I tried to understand why it
is needed, but that was just another rabbit hole, so I left it for
another time. @Araq your input is welcome.
- In the last commit I've reused the currently undocumented `liftLocals`
pragma for symbols, so that the closureiter transformation will forcefully
lift those even if they don't require lifting otherwise. This is needed
for [yasync](https://github.com/yglukhov/yasync), or else it will be very
sad.

Overall I'm quite happy with the results, I'm seeing some noticeable
code size reductions in my projects. Heavy closureiter/async users,
please give it a go.
  • Loading branch information
yglukhov committed Jul 3, 2024
1 parent 051a536 commit 05df263
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 124 deletions.
1 change: 1 addition & 0 deletions compiler/ast.nim
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ type
sfMember # proc is a C++ member of a type
sfCodegenDecl # type, proc, global or proc param is marked as codegenDecl
sfWasGenSym # symbol was 'gensym'ed
sfForceLift # variable has to be lifted into closure environment

TSymFlags* = set[TSymFlag]

Expand Down
183 changes: 120 additions & 63 deletions compiler/closureiters.nim
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,6 @@
# else:
# return

# The transformation should play well with lambdalifting, however depending
# on situation, it can be called either before or after lambdalifting
# transformation. As such we behave slightly differently, when accessing
# iterator state, or using temp variables. If lambdalifting did not happen,
# we just create local variables, so that they will be lifted further on.
# Otherwise, we utilize existing env, created by lambdalifting.

# Lambdalifting treats :state variable specially, it should always end up
# as the first field in env. Currently C codegen depends on this behavior.

Expand Down Expand Up @@ -151,7 +144,6 @@ type
Ctx = object
g: ModuleGraph
fn: PSym
stateVarSym: PSym # :state variable. nil if env already introduced by lambdalifting
tmpResultSym: PSym # Used when we return, but finally has to interfere
unrollFinallySym: PSym # Indicates that we're unrolling finally states (either exception happened or premature return)
curExcSym: PSym # Current exception
Expand All @@ -168,18 +160,18 @@ type
nearestFinally: int # Index of the nearest finally block. For try/except it
# is their finally. For finally it is parent finally. Otherwise -1
idgen: IdGenerator
varStates: Table[ItemId, int] # Used to detect if local variable belongs to multiple states

# Module-level constants: a node-kind set plus integer sentinels.
# `localNotSeen` and `localRequiresLifting` are special values stored in
# `Ctx.varStates`; any other stored value is the index of the state in which
# the local was first encountered.
const
nkSkip = {nkEmpty..nkNilLit, nkTemplateDef, nkTypeSection, nkStaticStmt,
nkCommentStmt, nkMixinStmt, nkBindStmt} + procDefs
emptyStateLabel = -1 # NOTE(review): presumably marks a state without a real label; usage not visible here, confirm
localNotSeen = -1 # `varStates` lookup default: the local has not been seen in any state yet
localRequiresLifting = -2 # `varStates` marker: the local must be lifted into the closure env

proc newStateAccess(ctx: var Ctx): PNode =
if ctx.stateVarSym.isNil:
result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)),
getStateField(ctx.g, ctx.fn), ctx.fn.info)
else:
result = newSymNode(ctx.stateVarSym)
result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)),
getStateField(ctx.g, ctx.fn), ctx.fn.info)

proc newStateAssgn(ctx: var Ctx, toValue: PNode): PNode =
# Creates state assignment:
Expand All @@ -195,24 +187,17 @@ proc newEnvVar(ctx: var Ctx, name: string, typ: PType): PSym =
result = newSym(skVar, getIdent(ctx.g.cache, name), ctx.idgen, ctx.fn, ctx.fn.info)
result.typ = typ
result.flags.incl sfNoInit
assert(not typ.isNil)

if not ctx.stateVarSym.isNil:
# We haven't gone through labmda lifting yet, so just create a local var,
# it will be lifted later
if ctx.tempVars.isNil:
ctx.tempVars = newNodeI(nkVarSection, ctx.fn.info)
addVar(ctx.tempVars, newSymNode(result))
else:
let envParam = getEnvParam(ctx.fn)
# let obj = envParam.typ.lastSon
result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen)
assert(not typ.isNil, "Env var needs a type")

let envParam = getEnvParam(ctx.fn)
# let obj = envParam.typ.lastSon
result = addUniqueField(envParam.typ.elementType, result, ctx.g.cache, ctx.idgen)

proc newEnvVarAccess(ctx: Ctx, s: PSym): PNode =
if ctx.stateVarSym.isNil:
result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info)
else:
result = newSymNode(s)
result = rawIndirectAccess(newSymNode(getEnvParam(ctx.fn)), s, ctx.fn.info)

proc newTempVarAccess(ctx: Ctx, s: PSym): PNode =
  ## Builds a plain symbol node referring to the temporary `s`. The node
  ## carries the line info of the routine being transformed (`ctx.fn.info`).
  newSymNode(s, ctx.fn.info)

proc newTmpResultAccess(ctx: var Ctx): PNode =
if ctx.tmpResultSym.isNil:
Expand Down Expand Up @@ -255,9 +240,18 @@ proc addGotoOut(n: PNode, gotoOut: PNode): PNode =
if result.len == 0 or result[^1].kind != nkGotoState:
result.add(gotoOut)

proc newTempVar(ctx: var Ctx, typ: PType): PSym =
result = ctx.newEnvVar(":tmpSlLower" & $ctx.tempVarId, typ)
proc newTempVarDef(ctx: Ctx, s: PSym, initialValue: PNode): PNode =
  ## Builds a `var` section declaring the single symbol `s`.
  ##
  ## The type slot of the identDefs is left empty. When `initialValue` is
  ## nil, the initializer slot is filled with the graph's empty node instead.
  let initializer =
    if initialValue == nil: ctx.g.emptyNode
    else: initialValue
  let identDefs = newTree(nkIdentDefs, newSymNode(s), ctx.g.emptyNode,
                          initializer)
  result = newTree(nkVarSection, identDefs)

proc newTempVar(ctx: var Ctx, typ: PType, parent: PNode, initialValue: PNode = nil): PSym =
  ## Creates a fresh temporary variable of type `typ`, owned by `ctx.fn`, and
  ## appends its `var` section (optionally initialized with `initialValue`)
  ## to `parent`.
  ##
  ## The name is `:tmpSlLower` followed by the current `ctx.tempVarId`, which
  ## is incremented on every call.
  let tmpName = getIdent(ctx.g.cache, ":tmpSlLower" & $ctx.tempVarId)
  result = newSym(skVar, tmpName, ctx.idgen, ctx.fn, ctx.fn.info)
  ctx.tempVarId.inc
  result.typ = typ
  assert(not typ.isNil, "Temp var needs a type")
  let definition = ctx.newTempVarDef(result, initialValue)
  parent.add(definition)

proc hasYields(n: PNode): bool =
# TODO: This is very inefficient. It traverses the node, looking for nkYieldStmt.
Expand Down Expand Up @@ -429,21 +423,20 @@ proc exprToStmtList(n: PNode): tuple[s, res: PNode] =

result.res = n


proc newEnvVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
proc newTempVarAsgn(ctx: Ctx, s: PSym, v: PNode): PNode =
if isEmptyType(v.typ):
result = v
else:
result = newTree(nkFastAsgn, ctx.newEnvVarAccess(s), v)
result = newTree(nkFastAsgn, ctx.newTempVarAccess(s), v)
result.info = v.info

proc addExprAssgn(ctx: Ctx, output, input: PNode, sym: PSym) =
if input.kind == nkStmtListExpr:
let (st, res) = exprToStmtList(input)
output.add(st)
output.add(ctx.newEnvVarAsgn(sym, res))
output.add(ctx.newTempVarAsgn(sym, res))
else:
output.add(ctx.newEnvVarAsgn(sym, input))
output.add(ctx.newTempVarAsgn(sym, input))

proc convertExprBodyToAsgn(ctx: Ctx, exprBody: PNode, res: PSym): PNode =
result = newNodeI(nkStmtList, exprBody.info)
Expand All @@ -457,6 +450,12 @@ proc boolLit(g: ModuleGraph; info: TLineInfo; value: bool): PNode =
result = newIntLit(g, info, ord value)
result.typ = getSysType(g, info, tyBool)

proc captureVar(c: var Ctx, s: PSym) =
  ## Marks the local `s` as requiring lifting and adds a field for it to the
  ## closure environment object of `c.fn`. Does nothing if `s` has already
  ## been marked.
  if c.varStates.getOrDefault(s.itemId) == localRequiresLifting:
    return # already marked, so the env field has already been added

  c.varStates[s.itemId] = localRequiresLifting
  let envParam = getEnvParam(c.fn)
  discard addField(envParam.typ.elementType, s, c.g.cache, c.idgen)

proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
result = n
case n.kind
Expand Down Expand Up @@ -513,9 +512,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
var tmp: PSym = nil
let isExpr = not isEmptyType(n.typ)
if isExpr:
tmp = ctx.newTempVar(n.typ)
result = newNodeI(nkStmtListExpr, n.info)
result.typ = n.typ
tmp = ctx.newTempVar(n.typ, result)
else:
result = newNodeI(nkStmtList, n.info)

Expand Down Expand Up @@ -566,7 +565,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
else:
internalError(ctx.g.config, "lowerStmtListExpr(nkIf): " & $branch.kind)

if isExpr: result.add(ctx.newEnvVarAccess(tmp))
if isExpr: result.add(ctx.newTempVarAccess(tmp))

of nkTryStmt, nkHiddenTryStmt:
var ns = false
Expand All @@ -580,7 +579,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
if isExpr:
result = newNodeI(nkStmtListExpr, n.info)
result.typ = n.typ
let tmp = ctx.newTempVar(n.typ)
let tmp = ctx.newTempVar(n.typ, result)

n[0] = ctx.convertExprBodyToAsgn(n[0], tmp)
for i in 1..<n.len:
Expand All @@ -596,7 +595,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
else:
internalError(ctx.g.config, "lowerStmtListExpr(nkTryStmt): " & $branch.kind)
result.add(n)
result.add(ctx.newEnvVarAccess(tmp))
result.add(ctx.newTempVarAccess(tmp))

of nkCaseStmt:
var ns = false
Expand All @@ -609,9 +608,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
let isExpr = not isEmptyType(n.typ)

if isExpr:
let tmp = ctx.newTempVar(n.typ)
result = newNodeI(nkStmtListExpr, n.info)
result.typ = n.typ
let tmp = ctx.newTempVar(n.typ, result)

if n[0].kind == nkStmtListExpr:
let (st, ex) = exprToStmtList(n[0])
Expand All @@ -628,7 +627,7 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
else:
internalError(ctx.g.config, "lowerStmtListExpr(nkCaseStmt): " & $branch.kind)
result.add(n)
result.add(ctx.newEnvVarAccess(tmp))
result.add(ctx.newTempVarAccess(tmp))
elif n[0].kind == nkStmtListExpr:
result = newNodeI(nkStmtList, n.info)
let (st, ex) = exprToStmtList(n[0])
Expand Down Expand Up @@ -658,10 +657,10 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
result.add(st)
cond = ex

let tmp = ctx.newTempVar(cond.typ)
result.add(ctx.newEnvVarAsgn(tmp, cond))
let tmp = ctx.newTempVar(cond.typ, result, cond)
# result.add(ctx.newTempVarAsgn(tmp, cond))

var check = ctx.newEnvVarAccess(tmp)
var check = ctx.newTempVarAccess(tmp)
if n[0].sym.magic == mOr:
check = ctx.g.newNotCall(check)

Expand All @@ -671,12 +670,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
let (st, ex) = exprToStmtList(cond)
ifBody.add(st)
cond = ex
ifBody.add(ctx.newEnvVarAsgn(tmp, cond))
ifBody.add(ctx.newTempVarAsgn(tmp, cond))

let ifBranch = newTree(nkElifBranch, check, ifBody)
let ifNode = newTree(nkIfStmt, ifBranch)
result.add(ifNode)
result.add(ctx.newEnvVarAccess(tmp))
result.add(ctx.newTempVarAccess(tmp))
else:
for i in 0..<n.len:
if n[i].kind == nkStmtListExpr:
Expand All @@ -685,9 +684,9 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
n[i] = ex

if n[i].kind in nkCallKinds: # XXX: This should better be some sort of side effect tracking
let tmp = ctx.newTempVar(n[i].typ)
result.add(ctx.newEnvVarAsgn(tmp, n[i]))
n[i] = ctx.newEnvVarAccess(tmp)
let tmp = ctx.newTempVar(n[i].typ, result, n[i])
# result.add(ctx.newTempVarAsgn(tmp, n[i]))
n[i] = ctx.newTempVarAccess(tmp)

result.add(n)

Expand All @@ -703,6 +702,12 @@ proc lowerStmtListExprs(ctx: var Ctx, n: PNode, needsSplit: var bool): PNode =
let (st, ex) = exprToStmtList(c[^1])
result.add(st)
c[^1] = ex
for i in 0 .. c.len - 3:
if c[i].kind == nkSym:
let s = c[i].sym
if sfForceLift in s.flags:
ctx.captureVar(s)

result.add(varSect)

of nkDiscardStmt, nkReturnStmt, nkRaiseStmt:
Expand Down Expand Up @@ -1279,13 +1284,6 @@ proc wrapIntoStateLoop(ctx: var Ctx, n: PNode): PNode =
result.info = n.info

let localVars = newNodeI(nkStmtList, n.info)
if not ctx.stateVarSym.isNil:
let varSect = newNodeI(nkVarSection, n.info)
addVar(varSect, newSymNode(ctx.stateVarSym))
localVars.add(varSect)

if not ctx.tempVars.isNil:
localVars.add(ctx.tempVars)

let blockStmt = newNodeI(nkBlockStmt, n.info)
blockStmt.add(newSymNode(ctx.stateLoopLabel))
Expand Down Expand Up @@ -1433,15 +1431,67 @@ proc preprocess(c: var PreprocessContext; n: PNode): PNode =
for i in 0 ..< n.len:
result[i] = preprocess(c, n[i])

proc detectCapturedVars(c: var Ctx, n: PNode, stateIdx: int) =
  ## Walks the tree `n`, which belongs to the state with index `stateIdx`,
  ## and records in `c.varStates` the state in which each local of `c.fn` is
  ## seen. A local that shows up in a second, different state is captured
  ## via `captureVar`.
  const localKinds = {skResult, skVar, skLet, skForVar, skTemp}
  case n.kind
  of nkSym:
    let sym = n.sym
    let isOwnLocal = sym.kind in localKinds and sfGlobal notin sym.flags and
                     sym.owner == c.fn
    if not isOwnLocal:
      return
    let seenIn = c.varStates.getOrDefault(sym.itemId, localNotSeen)
    if seenIn == localNotSeen:
      # First time seeing this variable: remember the state it lives in.
      c.varStates[sym.itemId] = stateIdx
    elif seenIn != localRequiresLifting and seenIn != stateIdx:
      # Seen before in another state and not marked yet.
      c.captureVar(sym)
  of nkReturnStmt:
    let returned = n[0]
    if returned.kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}:
      # We have a `result = result` expression produced by the closure
      # transform; leave the LHS alone so that the lifting pass stays
      # correct when `result` is lifted.
      detectCapturedVars(c, returned[1], stateIdx)
    else:
      detectCapturedVars(c, returned, stateIdx)
  else:
    for childIdx in 0 ..< n.safeLen:
      detectCapturedVars(c, n[childIdx], stateIdx)

proc detectCapturedVars(c: var Ctx) =
  ## Runs the capture detection over the body of every state in `c.states`,
  ## using the state's position in that sequence as its index.
  for stateIdx in 0 ..< c.states.len:
    detectCapturedVars(c, c.states[stateIdx].body, stateIdx)

proc liftLocals(c: var Ctx, n: PNode): PNode =
  ## Rewrites `n` in place so that every symbol marked
  ## `localRequiresLifting` in `c.varStates` is replaced by an access to its
  ## field in the closure environment of `c.fn`. Returns the node that
  ## should take the place of `n`: the new field access for a lifted symbol,
  ## otherwise `n` itself.
  case n.kind
  of nkSym:
    let sym = n.sym
    if c.varStates.getOrDefault(sym.itemId) != localRequiresLifting:
      return n
    let envParam = getEnvParam(c.fn)
    let field = getFieldFromObj(envParam.typ.elementType, sym)
    assert(field != nil)
    result = rawIndirectAccess(newSymNode(envParam), field, n.info)
  of nkReturnStmt:
    let returned = n[0]
    if returned.kind in {nkAsgn, nkFastAsgn, nkSinkAsgn}:
      # We have a `result = result` expression produced by the closure
      # transform; leave the LHS alone so that the lifting pass stays
      # correct when `result` is lifted.
      returned[1] = liftLocals(c, returned[1])
    else:
      n[0] = liftLocals(c, returned)
    result = n
  else:
    for childIdx in 0 ..< n.safeLen:
      n[childIdx] = liftLocals(c, n[childIdx])
    result = n

proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n: PNode): PNode =
var ctx = Ctx(g: g, fn: fn, idgen: idgen)

if getEnvParam(fn).isNil:
# Lambda lifting was not done yet. Use temporary :state sym, which will
# be handled specially by lambda lifting. Local temp vars (if needed)
# should follow the same logic.
ctx.stateVarSym = newSym(skVar, getIdent(ctx.g.cache, ":state"), idgen, fn, fn.info)
ctx.stateVarSym.typ = g.createClosureIterStateType(fn, idgen)
# The transformation should always happen after at least partial lambdalifting
# is performed, so that the closure iter environment is always created upfront.
doAssert(getEnvParam(fn) != nil, "Env param not created before iter transformation")

ctx.stateLoopLabel = newSym(skLabel, getIdent(ctx.g.cache, ":stateLoop"), idgen, fn, fn.info)
var pc = PreprocessContext(finallys: @[], config: g.config, idgen: idgen)
var n = preprocess(pc, n.toStmtList)
Expand All @@ -1466,13 +1516,18 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n:
let caseDispatcher = newTreeI(nkCaseStmt, n.info,
ctx.newStateAccess())

# Lambdalifting will not touch our locals; it is our responsibility to lift those that
# need it.
detectCapturedVars(ctx)

for s in ctx.states:
let body = ctx.transformStateAssignments(s.body)
caseDispatcher.add newTreeI(nkOfBranch, body.info, g.newIntLit(body.info, s.label), body)

caseDispatcher.add newTreeI(nkElse, n.info, newTreeI(nkReturnStmt, n.info, g.emptyNode))

result = wrapIntoStateLoop(ctx, caseDispatcher)
result = liftLocals(ctx, result)

when false:
echo "TRANSFORM TO STATES: "
Expand All @@ -1481,3 +1536,5 @@ proc transformClosureIterator*(g: ModuleGraph; idgen: IdGenerator; fn: PSym, n:
echo "exception table:"
for i, e in ctx.exceptionTable:
echo i, " -> ", e

echo "ENV: ", renderTree(getEnvParam(fn).typ.elementType.n)
Loading

0 comments on commit 05df263

Please sign in to comment.