Skip to content

Commit

Permalink
second pass: loop -> recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Aug 7, 2019
1 parent ed6ce81 commit 654f219
Showing 1 changed file with 86 additions and 42 deletions.
128 changes: 86 additions & 42 deletions hail/python/hail/ir/renderer.py
Expand Up @@ -279,6 +279,49 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int:
# * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing
# any lets to be inserted above y.

class State:
...

def loop(self, state):
if state.i >= len(state.children):
return
state.builder.append(' ')
child = state.children[state.i]
lift_to = self.lifted_in_scope(child, state.context)
child_outermost_scope = state.outermost_scope
if state.x.new_block(state.ir_child_num):
child_outermost_scope = state.depth
if (lift_to >= 0 and
state.context[lift_to] and
state.context[lift_to].depth >= state.outermost_scope):
state.ir_child_num += 1
(name, _) = state.context[lift_to].lifted_lets[id(child)]
if id(child) not in state.context[lift_to].visited:
state.context[lift_to].visited[id(child)] = child
let_body = [f'(Let {name} ']
self.print(let_body, state.context, child_outermost_scope,
state.depth + 1, child)
let_body.append(' ')
# let_bodies is built post-order, which guarantees earlier
# lets can't refer to later lets
state.context[lift_to].let_bodies.append(let_body)
state.builder.append(f'(Ref {name})')
state.i += 1
self.loop(state)
else:
if isinstance(child, ir.BaseIR):
self.print(state.builder, state.context,
child_outermost_scope, state.depth + 1, child)
state.i += 1
self.loop(state)
else:
state.ir_child_num = self.print_renderable(
state.builder, state.context,
child_outermost_scope, state.depth + 1,
state.ir_child_num, child)
state.i += 1
self.loop(state)

def print(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, x: 'ir.BaseIR'):
if id(x) in self.memo:
builder.append(self.memo[id(x)])
Expand All @@ -292,59 +335,60 @@ def print(self, builder: List[str], context: List[PrintStackFrame], outermost_sc
head = x.render_head(self)
if head != '':
local_builder.append(head)
children = x.render_children(self)
ir_child_num = 0
for i in range(0, len(children)):
local_builder.append(' ')
child = children[i]
lift_to = self.lifted_in_scope(child, context)
child_outermost_scope = outermost_scope
if x.new_block(ir_child_num):
child_outermost_scope = depth
if lift_to >= 0 and context[lift_to] and context[lift_to].depth >= outermost_scope:
ir_child_num += 1
(name, _) = context[lift_to].lifted_lets[id(child)]
if id(child) not in context[lift_to].visited:
context[lift_to].visited[id(child)] = child
let_body = [f'(Let {name} ']
self.print(let_body, context, child_outermost_scope, depth + 1, child)
let_body.append(' ')
# let_bodies is built post-order, which guarantees earlier
# lets can't refer to later lets
context[lift_to].let_bodies.append(let_body)
local_builder.append(f'(Ref {name})')
else:
if isinstance(child, ir.BaseIR):
self.print(local_builder, context, child_outermost_scope, depth + 1, child)
else:
ir_child_num = self.print_renderable(local_builder, context, child_outermost_scope, depth + 1, ir_child_num, child)
local_builder.append(x.render_tail(self))

state = self.State()
state.x = x
state.children = x.render_children(self)
state.builder = local_builder
state.context = context
state.outermost_scope = outermost_scope
state.ir_child_num = 0
state.depth = depth
state.i = 0

self.loop(state)

state.builder.append(state.x.render_tail(self))
if insert_lets:
sf = context[-1]
context.pop()
sf = state.context[-1]
state.context.pop()
for let_body in sf.let_bodies:
builder.extend(let_body)
builder.extend(local_builder)
builder.extend(state.builder)
num_lets = len(sf.lifted_lets)
for i in range(num_lets):
for _ in range(num_lets):
builder.append(')')

def loop2(self, state):
if state.i >= len(state.children):
return
state.builder.append(' ')
child = state.children[state.i]
if isinstance(child, ir.BaseIR):
self.print(state.builder, state.context, state.outermost_scope, state.depth, child)
state.ir_child_num += 1
else:
state.ir_child_num = self.print_renderable(state.builder, state.context, state.outermost_scope, state.depth, state.ir_child_num, child)
state.i += 1
self.loop2(state)

def print_renderable(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, ir_child_num: int, x: Renderable):
assert not isinstance(x, ir.BaseIR)
head = x.render_head(self)
if head != '':
builder.append(head)
children = x.render_children(self)
for i in range(0, len(children)):
builder.append(' ')
child = children[i]
if isinstance(child, ir.BaseIR):
self.print(builder, context, outermost_scope, depth, child)
ir_child_num += 1
else:
ir_child_num = self.print_renderable(builder, context, outermost_scope, depth, ir_child_num, child)
builder.append(x.render_tail(self))
return ir_child_num
state = self.State()
state.x = x
state.children = x.render_children(self)
state.builder = builder
state.context = context
state.outermost_scope = outermost_scope
state.ir_child_num = ir_child_num
state.depth = depth
state.i = 0
self.loop2(state)
state.builder.append(state.x.render_tail(self))
return state.ir_child_num

def compute_new_bindings(self, root: 'ir.BaseIR'):
# force computation of all types
Expand Down

0 comments on commit 654f219

Please sign in to comment.