From 654f219200452e41487904c2a11dad9c17b2fb85 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 7 Aug 2019 13:04:47 -0400 Subject: [PATCH] second pass: loop -> recursion --- hail/python/hail/ir/renderer.py | 128 +++++++++++++++++++++----------- 1 file changed, 86 insertions(+), 42 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 7dbd905c96e..1d1435c1e51 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -279,6 +279,49 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int: # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing # any lets to be inserted above y. + class State: + ... + + def loop(self, state): + if state.i >= len(state.children): + return + state.builder.append(' ') + child = state.children[state.i] + lift_to = self.lifted_in_scope(child, state.context) + child_outermost_scope = state.outermost_scope + if state.x.new_block(state.ir_child_num): + child_outermost_scope = state.depth + if (lift_to >= 0 and + state.context[lift_to] and + state.context[lift_to].depth >= state.outermost_scope): + state.ir_child_num += 1 + (name, _) = state.context[lift_to].lifted_lets[id(child)] + if id(child) not in state.context[lift_to].visited: + state.context[lift_to].visited[id(child)] = child + let_body = [f'(Let {name} '] + self.print(let_body, state.context, child_outermost_scope, + state.depth + 1, child) + let_body.append(' ') + # let_bodies is built post-order, which guarantees earlier + # lets can't refer to later lets + state.context[lift_to].let_bodies.append(let_body) + state.builder.append(f'(Ref {name})') + state.i += 1 + self.loop(state) + else: + if isinstance(child, ir.BaseIR): + self.print(state.builder, state.context, + child_outermost_scope, state.depth + 1, child) + state.i += 1 + self.loop(state) + else: + state.ir_child_num = self.print_renderable( + state.builder, state.context, + child_outermost_scope, state.depth + 1, + state.ir_child_num, child) + state.i += 1 + self.loop(state) + def print(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, x: 'ir.BaseIR'): if id(x) in self.memo: builder.append(self.memo[id(x)]) @@ -292,59 +335,60 @@ def print(self, builder: List[str], context: List[PrintStackFrame], outermost_sc head = x.render_head(self) if head != '': local_builder.append(head) - children = x.render_children(self) - ir_child_num = 0 - for i in range(0, len(children)): - local_builder.append(' ') - child = children[i] - lift_to = self.lifted_in_scope(child, context) - child_outermost_scope = outermost_scope - if x.new_block(ir_child_num): - child_outermost_scope = depth - if lift_to >= 0 and context[lift_to] and context[lift_to].depth >= outermost_scope: - ir_child_num += 1 - (name, _) = context[lift_to].lifted_lets[id(child)] - if id(child) not in context[lift_to].visited: - context[lift_to].visited[id(child)] = child - let_body = [f'(Let {name} '] - self.print(let_body, context, child_outermost_scope, depth + 1, child) - let_body.append(' ') - # let_bodies is built post-order, which guarantees earlier - # lets can't refer to later lets - context[lift_to].let_bodies.append(let_body) - local_builder.append(f'(Ref {name})') - else: - if isinstance(child, ir.BaseIR): - self.print(local_builder, context, child_outermost_scope, depth + 1, child) - else: - ir_child_num = self.print_renderable(local_builder, context, child_outermost_scope, depth + 1, ir_child_num, child) - local_builder.append(x.render_tail(self)) + + state = self.State() + state.x = x + state.children = x.render_children(self) + state.builder = local_builder + state.context = context + state.outermost_scope = outermost_scope + state.ir_child_num = 0 + state.depth = depth + state.i = 0 + + self.loop(state) + + state.builder.append(state.x.render_tail(self)) if insert_lets: - sf = context[-1] - context.pop() + sf = state.context[-1] + state.context.pop() for let_body in sf.let_bodies: builder.extend(let_body) - builder.extend(local_builder) + builder.extend(state.builder) num_lets = len(sf.lifted_lets) - for i in range(num_lets): + for _ in range(num_lets): builder.append(')') + def loop2(self, state): + if state.i >= len(state.children): + return + state.builder.append(' ') + child = state.children[state.i] + if isinstance(child, ir.BaseIR): + self.print(state.builder, state.context, state.outermost_scope, state.depth, child) + state.ir_child_num += 1 + else: + state.ir_child_num = self.print_renderable(state.builder, state.context, state.outermost_scope, state.depth, state.ir_child_num, child) + state.i += 1 + self.loop2(state) + def print_renderable(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, ir_child_num: int, x: Renderable): assert not isinstance(x, ir.BaseIR) head = x.render_head(self) if head != '': builder.append(head) - children = x.render_children(self) - for i in range(0, len(children)): - builder.append(' ') - child = children[i] - if isinstance(child, ir.BaseIR): - self.print(builder, context, outermost_scope, depth, child) - ir_child_num += 1 - else: - ir_child_num = self.print_renderable(builder, context, outermost_scope, depth, ir_child_num, child) - builder.append(x.render_tail(self)) - return ir_child_num + state = self.State() + state.x = x + state.children = x.render_children(self) + state.builder = builder + state.context = context + state.outermost_scope = outermost_scope + state.ir_child_num = ir_child_num + state.depth = depth + state.i = 0 + self.loop2(state) + state.builder.append(state.x.render_tail(self)) + return state.ir_child_num def compute_new_bindings(self, root: 'ir.BaseIR'): # force computation of all types