From 62e56806caf2803e1ce48e49934ff665369e3213 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 24 Jul 2019 11:51:06 -0400 Subject: [PATCH 01/47] New CSE IR renrerer in Python MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Doesn’t handle agg environments correctly. --- hail/python/hail/backend/backend.py | 4 +- hail/python/hail/ir/base_ir.py | 20 ++- hail/python/hail/ir/blockmatrix_ir.py | 6 + hail/python/hail/ir/ir.py | 53 ++++++++ hail/python/hail/ir/matrix_ir.py | 73 +++++++++++ hail/python/hail/ir/renderer.py | 174 +++++++++++++++++++++++++- hail/python/hail/ir/table_ir.py | 41 ++++++ hail/python/test/hail/test_ir.py | 53 ++++++++ 8 files changed, 419 insertions(+), 5 deletions(-) diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 24c260e638e..163981965c2 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -6,7 +6,7 @@ from hail.expr.table_type import * from hail.expr.matrix_type import * from hail.expr.blockmatrix_type import * -from hail.ir.renderer import Renderer +from hail.ir.renderer import Renderer, CSERenderer, NewCSE from hail.table import Table from hail.matrixtable import MatrixTable @@ -99,7 +99,7 @@ def fs(self): def _to_java_ir(self, ir): if not hasattr(ir, '_jir'): - r = Renderer(stop_at_jir=True) + r = CSERenderer(stop_at_jir=True) # FIXME parse should be static ir._jir = ir.parse(r(ir), ir_map=r.jirs) return ir._jir diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index e116ffe40e0..6070233d306 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -1,6 +1,6 @@ import abc -from typing import List +from typing import List, Set from hail.utils.java import Env from .renderer import Renderer, Renderable, RenderableStr @@ -44,7 +44,8 @@ def head_str(self): def parse(self, code, ref_map, ir_map): return - @abc.abstractproperty + @property + @abc.abstractmethod def typ(self): return @@ -71,6 +72,13 @@ def _eq(self, other): def __hash__(self): return 31 + hash(str(self)) + @staticmethod + def new_block(i: int): + return False + + def binds(self, i: int) -> Set[str]: + return set() + class IR(BaseIR): def __init__(self, *children): @@ -146,6 +154,9 @@ def typ(self): def parse(self, code, ref_map={}, ir_map={}): return Env.hail().expr.ir.IRParser.parse_table_ir(code, ref_map, ir_map) + global_env = {'global'} + row_env = {'global', 'row'} + class MatrixIR(BaseIR): def __init__(self, *children): @@ -165,6 +176,11 @@ def typ(self): def parse(self, code, ref_map={}, ir_map={}): return Env.hail().expr.ir.IRParser.parse_matrix_ir(code, ref_map, ir_map) + global_env = {'global'} + row_env = {'global', 'va'} + col_env = {'global', 'sa'} + entry_env = {'global', 'sa', 'va', 'g'} + class BlockMatrixIR(BaseIR): def __init__(self, *children): diff --git a/hail/python/hail/ir/blockmatrix_ir.py b/hail/python/hail/ir/blockmatrix_ir.py index b6db226171f..3a18caefa08 100644 --- a/hail/python/hail/ir/blockmatrix_ir.py +++ b/hail/python/hail/ir/blockmatrix_ir.py @@ -37,6 +37,9 @@ def __init__(self, child, f): def _compute_type(self): self._type = self.child.typ + def binds(self, i): + return {'element'} if i == 1 else {} + class BlockMatrixMap2(BlockMatrixIR): @typecheck_method(left=BlockMatrixIR, right=BlockMatrixIR, f=IR) @@ -50,6 +53,9 @@ def _compute_type(self): self.right.typ # Force self._type = self.left.typ + def binds(self, i): + return {'l', 'r'} if i == 2 else {} + class BlockMatrixDot(BlockMatrixIR): @typecheck_method(left=BlockMatrixIR, right=BlockMatrixIR) diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 82917a29a81..d8be066d9dc 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -236,6 +236,10 @@ def _compute_type(self, env, agg_env): assert (self.cnsq.typ == self.altr.typ) self._type = self.cnsq.typ + @staticmethod + def new_block(i): + return i == 1 or i == 2 + class Coalesce(IR): @typecheck_method(values=IR) @@ -264,6 +268,10 @@ def __init__(self, name, value, body): self.value = value self.body = body + @staticmethod + def new_block(i): + return i > 0 + @typecheck_method(value=IR, body=IR) def copy(self, value, body): return Let(self.name, value, body) @@ -283,6 +291,9 @@ def _compute_type(self, env, agg_env): self.body._compute_type(_env_bind(env, self.name, self.value._type), agg_env) self._type = self.body._type + def binds(self, i): + return {self.name} if i == 1 else {} + class Ref(IR): @typecheck_method(name=str) @@ -550,6 +561,9 @@ def _compute_type(self, env, agg_env): self.body._compute_type(_env_bind(env, self.name, self.nd.typ.element_type), agg_env) self._type = tndarray(self.body.typ, self.nd.typ.ndim) + def binds(self, i): + return {self.name} if i == 1 else {} + class NDArrayRef(IR): @typecheck_method(nd=IR, idxs=sequenceof(IR)) @@ -696,6 +710,9 @@ def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) self._type = self.a.typ + def binds(self, i): + return {self.l_name, self.r_name} if i == 1 else {} + class ToSet(IR): @typecheck_method(a=IR) @@ -806,6 +823,9 @@ def _compute_type(self, env, agg_env): self.body._compute_type(_env_bind(env, self.name, self.a.typ.element_type), agg_env) self._type = tarray(self.body.typ) + def binds(self, i): + return {self.name} if i == 1 else {} + class ArrayFilter(IR): @typecheck_method(a=IR, name=str, body=IR) @@ -834,6 +854,9 @@ def _compute_type(self, env, agg_env): self.body._compute_type(_env_bind(env, self.name, self.a.typ.element_type), agg_env) self._type = self.a.typ + def binds(self, i): + return {self.name} if i == 1 else {} + class ArrayFlatMap(IR): @typecheck_method(a=IR, name=str, body=IR) @@ -862,6 +885,9 @@ def _compute_type(self, env, agg_env): self.body._compute_type(_env_bind(env, self.name, self.a.typ.element_type), agg_env) self._type = tarray(self.body.typ.element_type) + def binds(self, i): + return {self.name} if i == 1 else {} + class ArrayFold(IR): @typecheck_method(a=IR, zero=IR, accum_name=str, value_name=str, body=IR) @@ -898,6 +924,9 @@ def _compute_type(self, env, agg_env): agg_env) self._type = self.zero.typ + def binds(self, i): + return {self.accum_name, self.value_name} if i == 2 else {} + class ArrayScan(IR): @typecheck_method(a=IR, zero=IR, accum_name=str, value_name=str, body=IR) @@ -934,6 +963,9 @@ def _compute_type(self, env, agg_env): agg_env) self._type = tarray(self.body.typ) + def binds(self, i): + return {self.accum_name, self.value_name} if i == 2 else {} + class ArrayLeftJoinDistinct(IR): @typecheck_method(left=IR, right=IR, l_name=str, r_name=str, compare=IR, join=IR) @@ -961,6 +993,9 @@ def _eq(self, other): def bound_variables(self): return {self.l_name, self.r_name} | super().bound_variables + def binds(self, i): + return {self.l_name, self.r_name} if i == 2 or i == 3 else {} + class ArrayFor(IR): @typecheck_method(a=IR, value_name=str, body=IR) @@ -989,6 +1024,9 @@ def _compute_type(self, env, agg_env): self.body._compute_type(_env_bind(env, self.value_name, self.a.typ.element_type), agg_env) self._type = tvoid + def binds(self, i): + return {self.value_name} if i == 1 else {} + class AggFilter(IR): @typecheck_method(cond=IR, agg_ir=IR, is_scan=bool) @@ -1042,6 +1080,9 @@ def _compute_type(self, env, agg_env): self.agg_body._compute_type(env, _env_bind(agg_env, self.name, self.array.typ.element_type)) self._type = self.agg_body.typ + def binds(self, i): + return {self.name} if i == 1 else {} + class AggGroupBy(IR): @typecheck_method(key=IR, agg_ir=IR, is_scan=bool) @@ -1097,6 +1138,9 @@ def _compute_type(self, env, agg_env): def bound_variables(self): return {self.element_name, self.index_name} | super().bound_variables + def binds(self, i): + return {self.index_name, self.element_name} if i == 1 else {} + def _register(registry, name, f): registry[name].append(f) @@ -1615,6 +1659,9 @@ def _compute_type(self, env, agg_env): self.max._compute_type(env, agg_env) self._type = tfloat64 + def binds(self, i): + return {self.argname} if i == 0 else {} + class TableCount(IR): @typecheck_method(child=TableIR) @@ -1677,6 +1724,9 @@ def _compute_type(self, env, agg_env): self.query._compute_type(self.child.typ.global_env(), self.child.typ.row_env()) self._type = self.query.typ + def binds(self, i): + return TableIR.row_env if i == 1 else {} + class MatrixAggregate(IR): @typecheck_method(child=MatrixIR, query=IR) @@ -1698,6 +1748,9 @@ def _compute_type(self, env, agg_env): self.query._compute_type(self.child.typ.global_env(), self.child.typ.entry_env()) self._type = self.query.typ + def binds(self, i): + return MatrixIR.entry_env if i == 1 else {} + class TableWrite(IR): @typecheck_method(child=TableIR, writer=TableWriter) diff --git a/hail/python/hail/ir/matrix_ir.py b/hail/python/hail/ir/matrix_ir.py index a1fa0e47d7d..5b5663503e0 100644 --- a/hail/python/hail/ir/matrix_ir.py +++ b/hail/python/hail/ir/matrix_ir.py @@ -10,6 +10,10 @@ def __init__(self, child, entry_expr, row_expr): self.entry_expr = entry_expr self.row_expr = row_expr + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): child_typ = self.child.typ self.entry_expr._compute_type(child_typ.col_env(), child_typ.entry_env()) @@ -22,6 +26,14 @@ def _compute_type(self): child_typ.row_key, self.entry_expr.typ) + def binds(self, i): + if i == 1: + return self.entry_env + elif i == 2: + return self.row_env + else: + return {} + class MatrixRead(MatrixIR): def __init__(self, reader, drop_cols=False, drop_rows=False): @@ -46,10 +58,17 @@ def __init__(self, child, pred): self.child = child self.pred = pred + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): self.pred._compute_type(self.child.typ.row_env(), None) self._type = self.child.typ + def binds(self, i): + return self.row_env if i == 1 else {} + class MatrixChooseCols(MatrixIR): def __init__(self, child, old_indices): @@ -74,6 +93,10 @@ def __init__(self, child, new_col, new_key): self.new_col = new_col self.new_key = new_key + @staticmethod + def new_block(i): + return i > 0 + def head_str(self): return '(' + ' '.join(f'"{escape_str(f)}"' for f in self.new_key) + ')' if self.new_key is not None else 'None' @@ -91,6 +114,9 @@ def _compute_type(self): child_typ.row_key, child_typ.entry_type) + def binds(self, i): + return self.entry_env if i == 1 else {} + class MatrixUnionCols(MatrixIR): def __init__(self, left, right): @@ -109,6 +135,10 @@ def __init__(self, child, new_entry): self.child = child self.new_entry = new_entry + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): child_typ = self.child.typ self.new_entry._compute_type(child_typ.entry_env(), None) @@ -120,6 +150,9 @@ def _compute_type(self): child_typ.row_key, self.new_entry.typ) + def binds(self, i): + return self.entry_env if i == 1 else {} + class MatrixFilterEntries(MatrixIR): def __init__(self, child, pred): @@ -127,10 +160,17 @@ def __init__(self, child, pred): self.child = child self.pred = pred + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): self.pred._compute_type(self.child.typ.entry_env(), None) self._type = self.child.typ + def binds(self, i): + return self.entry_env if i == 1 else {} + class MatrixKeyRowsBy(MatrixIR): def __init__(self, child, keys, is_sorted=False): @@ -164,6 +204,10 @@ def __init__(self, child, new_row): self.child = child self.new_row = new_row + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): child_typ = self.child.typ self.new_row._compute_type(child_typ.row_env(), child_typ.entry_env()) @@ -175,6 +219,9 @@ def _compute_type(self): child_typ.row_key, child_typ.entry_type) + def binds(self, i): + return self.entry_env if i == 1 else {} + class MatrixMapGlobals(MatrixIR): def __init__(self, child, new_global): @@ -182,6 +229,10 @@ def __init__(self, child, new_global): self.child = child self.new_global = new_global + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): child_typ = self.child.typ self.new_global._compute_type(child_typ.global_env(), None) @@ -193,6 +244,9 @@ def _compute_type(self): child_typ.row_key, child_typ.entry_type) + def binds(self, i): + return self.global_env if i == 1 else {} + class MatrixFilterCols(MatrixIR): def __init__(self, child, pred): @@ -200,10 +254,17 @@ def __init__(self, child, pred): self.child = child self.pred = pred + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): self.pred._compute_type(self.child.typ.col_env(), None) self._type = self.child.typ + def binds(self, i): + return self.col_env if i == 1 else {} + class MatrixCollectColsByKey(MatrixIR): def __init__(self, child): @@ -229,6 +290,10 @@ def __init__(self, child, entry_expr, col_expr): self.entry_expr = entry_expr self.col_expr = col_expr + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): child_typ = self.child.typ self.entry_expr._compute_type(child_typ.row_env(), child_typ.entry_env()) @@ -241,6 +306,14 @@ def _compute_type(self): child_typ.row_key, self.entry_expr.typ) + def binds(self, i): + if i == 1: + return self.entry_env + elif i == 2: + return self.col_env + else: + return {} + class MatrixExplodeRows(MatrixIR): def __init__(self, child, path): diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index c1339496032..ccea35c7792 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -1,6 +1,6 @@ from hail import ir import abc -from typing import Sequence, List +from typing import Sequence, List, Set, Dict class Renderable(object): @@ -134,3 +134,175 @@ def __call__(self, x: 'Renderable'): x = top.pop() return ''.join(builder) + + +class Scope: + def __init__(self): + self.visited: Set[int] = set() + self.lifted_lets: Dict[int, str] = {} + self.let_bodies: List[str] = [] + + def visit(self, x: 'BaseIR'): + self.visited.add(id(x)) + +class CSERenderer(Renderer): + def __init__(self, stop_at_jir=False): + self.stop_at_jir = stop_at_jir + self.jir_count = 0 + self.jirs = {} + self.memo: Dict[int, Sequence[str]] = {} + self.uid_count = 0 + self.scopes: Dict[int, Scope] = {} + + def uid(self) -> str: + self.uid_count += 1 + return f'__cse_{self.uid_count}' + + def add_jir(self, x: 'ir.BaseIR'): + jir_id = f'm{self.jir_count}' + self.jir_count += 1 + self.jirs[jir_id] = x._jir + if isinstance(x, ir.MatrixIR): + return f'(JavaMatrix {jir_id})' + elif isinstance(x, ir.TableIR): + return f'(JavaTable {jir_id})' + elif isinstance(x, ir.BlockMatrixIR): + return f'(JavaBlockMatrix {jir_id})' + else: + assert isinstance(x, ir.IR) + return f'(JavaIR {jir_id})' + + def find_in_context(self, x: 'ir.BaseIR', context: List[Scope]) -> int: + for i in range(len(context)): + if id(x) in context[i].visited: + return i + return -1 + + def lifted_in_context(self, x: 'ir.BaseIR', context: List[Scope]) -> int: + for i in range(len(context)): + if id(x) in context[i].lifted_lets: + return i + return -1 + + # Pre: + # * 'context' is a list of 'Scope's, one for each potential let-insertion + # site. + # * 'ref_to_scope' maps each bound variable to the index in 'context' of the + # scope of its binding site. + # Post: + # * Returns set of free variables in 'x'. + # * Each subtree of 'x' is flagged as visited in the outermost scope + # containing all of its free variables. + # * Each subtree previously visited (either an earlier subtree of 'x', or + # marked visited in 'context') is added to set of lets in its (previously + # computed) outermost scope. + # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing + # any lets to be inserted above y. + def recur(self, context: List[Scope], ref_to_scope: Dict[str, int], x: 'ir.BaseIR') -> Set[str]: + # Ref nodes should never be lifted to a let. (Not that it would be + # incorrect, just pointlessly adding names for the same thing.) + if isinstance(x, ir.Ref): + return {x.name} + free_vars = set() + for i in range(len(x.children)): + child = x.children[i] + # FIXME: maintain a union of seen nodes in all scopes + seen_in_context = self.find_in_context(child, context) + if seen_in_context >= 0: + # for now, only add non-relational lets + # FIXME: if generating uid can be moved to second pass, this + # can be factored into 'Scope' + if id(child) not in context[seen_in_context].lifted_lets and isinstance(child, ir.IR): + # second time we've seen 'child', lift to a let + context[seen_in_context].lifted_lets[id(child)] = self.uid() + elif self.stop_at_jir and hasattr(child, '_jir'): + self.memo[id(child)] = self.add_jir(child) + else: + binds = x.binds(i) + if len(binds) > 0: + new_scope = Scope() + if x.new_block(i): + # Repeated subtrees of this child should never be lifted to + # lets above 'x'. We accomplish that by clearing the context + # in the recursive call. + child_context = [new_scope] + child_ref_to_scope = {} + new_idx = 0 + else: + new_idx = len(context) + context.append(new_scope) + child_context = context + child_ref_to_scope = ref_to_scope.copy() + for v in binds: + child_ref_to_scope[v] = new_idx + child_free_vars = self.recur(child_context, child_ref_to_scope, child) - binds + free_vars |= child_free_vars + if not x.new_block(i): + context.pop() + new_scope.visited.clear() + self.scopes[id(child)] = new_scope + elif x.new_block(i): + new_scope = Scope() + free_vars |= self.recur([new_scope], {}, child) + new_scope.visited.clear() + self.scopes[id(child)] = new_scope + else: + free_vars |= self.recur(context, ref_to_scope, child) + if len(free_vars) > 0: + outermost_scope = max((ref_to_scope.get(v, 0) for v in free_vars)) + else: + outermost_scope = 0 + context[outermost_scope].visited.add(id(x)) + return free_vars + + def print(self, builder: List[str], context: List[Scope], x: Renderable): + if id(x) in self.memo: + builder.append(self.memo[id(x)]) + return + insert_lets = id(x) in self.scopes and len(self.scopes[id(x)].lifted_lets) > 0 + if insert_lets: + local_builder = [] + context.append(self.scopes[id(x)]) + else: + local_builder = builder + head = x.render_head(self) + if head != '': + local_builder.append(head) + children = x.render_children(self) + for i in range(0, len(children)): + local_builder.append(' ') + child = children[i] + lift_to = self.lifted_in_context(child, context) + if lift_to >= 0: + name = context[lift_to].lifted_lets[id(child)] + if id(child) not in context[lift_to].visited: + context[lift_to].visited.add(id(child)) + let_body = [f'(Let {name} '] + self.print(let_body, context, child) + let_body.append(' ') + # let_bodies is built post-order, which guarantees earlier + # lets can't refer to later lets + context[lift_to].let_bodies.append(let_body) + local_builder.append(f'(Ref {name})') + else: + self.print(local_builder, context, child) + local_builder.append(x.render_tail(self)) + if insert_lets: + context.pop() + for let_body in self.scopes[id(x)].let_bodies: + builder.extend(let_body) + builder.extend(local_builder) + num_lets = len(self.scopes[id(x)].lifted_lets) + for i in range(num_lets): + builder.append(')') + + + def __call__(self, x: 'BaseIR') -> str: + root_scope = Scope() + free_vars = self.recur([root_scope], {}, x) + root_scope.visited = set() + assert(len(free_vars) == 0) + self.scopes[id(x)] = root_scope + builder = [] + self.print(builder, [], x) + return ''.join(builder) diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index e96b4f09dbe..92d0d654bec 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -125,12 +125,23 @@ def __init__(self, child, new_globals): self.child = child self.new_globals = new_globals + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): self.new_globals._compute_type(self.child.typ.global_env(), None) self._type = hl.ttable(self.new_globals.typ, self.child.typ.row_type, self.child.typ.row_key) + @staticmethod + def new_block(i): + return i == 1 + + def binds(self, i): + return self.global_env if i == 1 else {} + class TableExplode(TableIR): def __init__(self, child, path): @@ -176,6 +187,10 @@ def __init__(self, child, new_row): self.child = child self.new_row = new_row + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): # agg_env for scans self.new_row._compute_type(self.child.typ.row_env(), self.child.typ.row_env()) @@ -184,6 +199,9 @@ def _compute_type(self): self.new_row.typ, self.child.typ.row_key) + def binds(self, i): + return self.row_env if i == 1 else {} + class TableRead(TableIR): def __init__(self, reader, drop_rows=False): @@ -241,10 +259,17 @@ def __init__(self, child, pred): self.child = child self.pred = pred + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): self.pred._compute_type(self.child.typ.row_env(), None) self._type = self.child.typ + def binds(self, i): + return self.row_env if i == 1 else {} + class TableKeyByAndAggregate(TableIR): def __init__(self, child, expr, new_key, n_partitions, buffer_size): @@ -255,6 +280,10 @@ def __init__(self, child, expr, new_key, n_partitions, buffer_size): self.n_partitions = n_partitions self.buffer_size = buffer_size + @staticmethod + def new_block(i): + return i > 0 + def head_str(self): return f'{self.n_partitions} {self.buffer_size}' @@ -268,6 +297,9 @@ def _compute_type(self): self.new_key.typ._concat(self.expr.typ), list(self.new_key.typ)) + def binds(self, i): + return self.row_env if i == 1 or i == 2 else {} + class TableAggregateByKey(TableIR): def __init__(self, child, expr): @@ -275,6 +307,10 @@ def __init__(self, child, expr): self.child = child self.expr = expr + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): child_typ = self.child.typ self.expr._compute_type(child_typ.global_env(), child_typ.row_env()) @@ -282,6 +318,9 @@ def _compute_type(self): child_typ.key_type._concat(self.expr.typ), child_typ.row_key) + def binds(self, i): + return self.row_env if i == 1 else {} + class MatrixColsTable(TableIR): def __init__(self, child): @@ -300,6 +339,8 @@ def __init__(self, rows_and_global, n_partitions): self.rows_and_global = rows_and_global self.n_partitions = n_partitions + # FIXME: Does this need a new_block? + def head_str(self): return self.n_partitions diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index 29ad1787604..1056470b295 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -1,7 +1,9 @@ import unittest import hail as hl import hail.ir as ir +from hail.ir.renderer import CSERenderer, NewCSE from hail.expr import construct_expr +from hail.expr.types import tint32 from hail.utils.java import Env from hail.utils import new_temp_file from .helpers import * @@ -351,3 +353,54 @@ def test_value_same_after_parsing(self): None)) new_globals = hl.eval(hl.Table(map_globals_ir).index_globals()) self.assertEqual(new_globals, hl.Struct(foo=v)) + +class CSETests(unittest.TestCase): + def test_cse(self): + x = ir.I32(5) + x = ir.ApplyBinaryPrimOp('+', x, x) + expected = ( + '(Let __cse_1 (I32 5)' + ' (ApplyBinaryPrimOp `+`' + ' (Ref __cse_1)' + ' (Ref __cse_1)))') + self.assertEqual(NewCSE()(x), expected) + + def test_cse2(self): + x = ir.I32(5) + y = ir.I32(4) + sum = ir.ApplyBinaryPrimOp('+', x, x) + prod = ir.ApplyBinaryPrimOp('*', sum, y) + div = ir.ApplyBinaryPrimOp('/', prod, sum) + expected = ( + '(Let __cse_1 (I32 5)' + ' (Let __cse_2 (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' + ' (ApplyBinaryPrimOp `/`' + ' (ApplyBinaryPrimOp `*`' + ' (Ref __cse_2)' + ' (I32 4))' + ' (Ref __cse_2))))') + self.assertEqual(NewCSE()(div), expected) + + def test_cse_ifs(self): + outer_repeated = ir.I32(5) + inner_repeated = ir.I32(1) + sum = ir.ApplyBinaryPrimOp('+', inner_repeated, inner_repeated) + prod = ir.ApplyBinaryPrimOp('*', sum, outer_repeated) + cond = ir.If(ir.TrueIR(), prod, outer_repeated) + expected = ( + '(If (True)' + ' (Let __cse_1 (I32 1)' + ' (ApplyBinaryPrimOp `*`' + ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' + ' (I32 5)))' + ' (I32 5))' + ) + self.assertEqual(NewCSE()(cond), expected) + # + # def test_foo(self): + # array = ir.MakeArray([ir.I32(5)], tint32) + # ref = ir.Ref('x') + # sum = ir.ApplyBinaryPrimOp('+', ref, ref) + # map = ir.ArrayMap(array, 'x', sum) + # print(NewCSE()(map)) + # self.assertTrue(False) From 689ca407c3880c4e321129eeb193a0716171fcec Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Mon, 29 Jul 2019 14:57:16 -0400 Subject: [PATCH 02/47] works with agg bindings --- hail/python/hail/backend/backend.py | 2 +- hail/python/hail/ir/__init__.py | 1 + hail/python/hail/ir/base_ir.py | 61 +++++++++- hail/python/hail/ir/ir.py | 173 +++++++++++++++++++++------- hail/python/hail/ir/matrix_ir.py | 85 +++++++++++--- hail/python/hail/ir/renderer.py | 72 +++++++----- hail/python/hail/ir/table_ir.py | 39 ++++++- hail/python/test/hail/test_ir.py | 8 +- 8 files changed, 339 insertions(+), 102 deletions(-) diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 163981965c2..8a993e76b05 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -6,7 +6,7 @@ from hail.expr.table_type import * from hail.expr.matrix_type import * from hail.expr.blockmatrix_type import * -from hail.ir.renderer import Renderer, CSERenderer, NewCSE +from hail.ir.renderer import Renderer, CSERenderer from hail.table import Table from hail.matrixtable import MatrixTable diff --git a/hail/python/hail/ir/__init__.py b/hail/python/hail/ir/__init__.py index 614d25abe21..38f22c98770 100644 --- a/hail/python/hail/ir/__init__.py +++ b/hail/python/hail/ir/__init__.py @@ -1,3 +1,4 @@ +from .base_ir import * from .ir import * from .register_functions import * from .register_aggregators import * diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index 6070233d306..bb2483c5cf3 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -1,11 +1,23 @@ import abc -from typing import List, Set +from typing import List, Set, Optional from hail.utils.java import Env from .renderer import Renderer, Renderable, RenderableStr +def _env_bind(env, *bindings): + if bindings: + if env: + res = env.copy() + res.update(bindings) + return res + else: + return dict(bindings) + else: + return env + + class BaseIR(Renderable): def __init__(self, *children): super().__init__() @@ -72,12 +84,49 @@ def _eq(self, other): def __hash__(self): return 31 + hash(str(self)) - @staticmethod - def new_block(i: int): + def new_block(self, i: int): + return self.uses_agg_context(i) or self.uses_scan_context(i) + + def binds(self, i: int): + return False + + def bindings(self, i: int): + return [] + + def agg_bindings(self, i: int): + return [] + + def scan_bindings(self, i: int): + return [] + + def uses_agg_context(self, i: int): + return False + + def uses_scan_context(self, i: int): return False - def binds(self, i: int) -> Set[str]: - return set() + def child_context_without_bindings(self, i: int, parent_context): + (eval_c, agg_c, scan_c) = parent_context + if self.uses_agg_context(i): + return (agg_c, None, None) + elif self.uses_scan_context(i): + return (scan_c, None, None) + else: + return parent_context + + def child_context(self, i: int, parent_context, default_value=None): + base = self.child_context_without_bindings(i, parent_context) + if not self.binds(i): + return base + eval_b = self.bindings(i) + agg_b = self.agg_bindings(i) + scan_b = self.scan_bindings(i) + if default_value: + eval_b = [(var, default_value) for (var, _) in eval_b] + agg_b = [(var, default_value) for (var, _) in agg_b] + scan_b = [(var, default_value) for (var, _) in scan_b] + (eval_c, agg_c, scan_c) = base + return _env_bind(eval_c, *eval_b), _env_bind(agg_c, *agg_b), _env_bind(scan_c, *scan_b) class IR(BaseIR): @@ -114,6 +163,8 @@ def map_ir(self, f): return self.copy(*new_children) + # FIXME: This is only used to compute the free variables in a subtree, + # in an incorrect way. Should just define a free variables method directly. @property def bound_variables(self): return {v for child in self.children for v in child.bound_variables} diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index d8be066d9dc..f157104f96a 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -9,15 +9,11 @@ from hail.typecheck import * from hail.utils.misc import escape_str, dump_json, parsable_strings, escape_id from .base_ir import * +from .base_ir import _env_bind from .matrix_writer import MatrixWriter, MatrixNativeMultiWriter from .renderer import Renderer, Renderable, RenderableStr, ParensRenderer from .table_writer import TableWriter -def _env_bind(env, k, v): - env = env.copy() - env[k] = v - return env - class I32(IR): @typecheck_method(x=int) @@ -236,8 +232,7 @@ def _compute_type(self, env, agg_env): assert (self.cnsq.typ == self.altr.typ) self._type = self.cnsq.typ - @staticmethod - def new_block(i): + def new_block(self, i): return i == 1 or i == 2 @@ -268,8 +263,7 @@ def __init__(self, name, value, body): self.value = value self.body = body - @staticmethod - def new_block(i): + def new_block(self, i): return i > 0 @typecheck_method(value=IR, body=IR) @@ -288,11 +282,14 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self.value._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, self.name, self.value._type), agg_env) + self.body._compute_type(_env_bind(env, *self.bindings(1)), agg_env) self._type = self.body._type + def bindings(self, i): + return [(self.name, self.value._type)] if i == 1 else [] + def binds(self, i): - return {self.name} if i == 1 else {} + return i == 1 class Ref(IR): @@ -558,11 +555,14 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.nd._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, self.name, self.nd.typ.element_type), agg_env) + self.body._compute_type(_env_bind(env, (self.name, self.nd.typ.element_type)), agg_env) self._type = tndarray(self.body.typ, self.nd.typ.ndim) + def bindings(self, i): + return [(self.name, self.nd.typ.element_type)] if i == 1 else [] + def binds(self, i): - return {self.name} if i == 1 else {} + return i == 1 class NDArrayRef(IR): @@ -710,8 +710,11 @@ def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) self._type = self.a.typ + def bindings(self, i): + return [(self.l_name, self.a.typ.element_type), (self.r_name, self.a.typ.element_type)] if i == 1 else [] + def binds(self, i): - return {self.l_name, self.r_name} if i == 1 else {} + return i == 1 class ToSet(IR): @@ -820,11 +823,14 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, self.name, self.a.typ.element_type), agg_env) + self.body._compute_type(_env_bind(env, (self.name, self.a.typ.element_type)), agg_env) self._type = tarray(self.body.typ) + def bindings(self, i): + return [(self.name, self.a.typ.element_type)] if i == 1 else [] + def binds(self, i): - return {self.name} if i == 1 else {} + return i == 1 class ArrayFilter(IR): @@ -851,11 +857,14 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, self.name, self.a.typ.element_type), agg_env) + self.body._compute_type(_env_bind(env, (self.name, self.a.typ.element_type)), agg_env) self._type = self.a.typ + def bindings(self, i): + return [(self.name, self.a.typ.element_type)] if i == 1 else [] + def binds(self, i): - return {self.name} if i == 1 else {} + return i == 1 class ArrayFlatMap(IR): @@ -882,11 +891,14 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, self.name, self.a.typ.element_type), agg_env) + self.body._compute_type(_env_bind(env, (self.name, self.a.typ.element_type)), agg_env) self._type = tarray(self.body.typ.element_type) + def bindings(self, i): + return [(self.name, self.a.typ.element_type)] if i == 1 else [] + def binds(self, i): - return {self.name} if i == 1 else {} + return i == 1 class ArrayFold(IR): @@ -918,14 +930,16 @@ def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) self.zero._compute_type(env, agg_env) self.body._compute_type( - _env_bind( - _env_bind(env, self.value_name, self.a.typ.element_type), - self.accum_name, self.zero.typ), + _env_bind(env, (self.value_name, self.a.typ.element_type), + (self.accum_name, self.zero.typ)), agg_env) self._type = self.zero.typ + def bindings(self, i): + return [(self.accum_name, self.zero.typ), (self.value_name, self.a.typ.element_type)] if i == 2 else [] + def binds(self, i): - return {self.accum_name, self.value_name} if i == 2 else {} + return i == 2 class ArrayScan(IR): @@ -957,14 +971,16 @@ def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) self.zero._compute_type(env, agg_env) self.body._compute_type( - _env_bind( - _env_bind(env, self.value_name, self.a.typ.element_type), - self.accum_name, self.zero.typ), + _env_bind(env, (self.value_name, self.a.typ.element_type), + (self.accum_name, self.zero.typ)), agg_env) self._type = tarray(self.body.typ) + def bindings(self, i): + return [(self.accum_name, self.zero.typ), (self.value_name, self.a.typ.element_type)] if i == 2 else [] + def binds(self, i): - return {self.accum_name, self.value_name} if i == 2 else {} + return i == 2 class ArrayLeftJoinDistinct(IR): @@ -993,8 +1009,13 @@ def _eq(self, other): def bound_variables(self): return {self.l_name, self.r_name} | super().bound_variables + def bindings(self, i): + return [(self.l_name, self.left.typ.element_type), + (self.r_name, self.right.typ.element_type)]\ + if i == 2 or i == 3 else [] + def binds(self, i): - return {self.l_name, self.r_name} if i == 2 or i == 3 else {} + return i == 2 or i == 3 class ArrayFor(IR): @@ -1021,11 +1042,14 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, self.value_name, self.a.typ.element_type), agg_env) + self.body._compute_type(_env_bind(env, (self.value_name, self.a.typ.element_type)), agg_env) self._type = tvoid + def bindings(self, i): + return [(self.value_name, self.a.typ.element_type)] if i == 1 else [] + def binds(self, i): - return {self.value_name} if i == 1 else {} + return i == 1 class AggFilter(IR): @@ -1051,6 +1075,12 @@ def _compute_type(self, env, agg_env): self.agg_ir._compute_type(env, agg_env) self._type = self.agg_ir.typ + def uses_agg_context(self, i: int): + return i == 0 + + def uses_scan_context(self, i: int): + return i == 0 + class AggExplode(IR): @typecheck_method(array=IR, name=str, agg_body=IR, is_scan=bool) @@ -1077,11 +1107,26 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.array._compute_type(agg_env, None) - self.agg_body._compute_type(env, _env_bind(agg_env, self.name, self.array.typ.element_type)) + new_agg = _env_bind(agg_env, (self.name, self.array.typ.element_type)) + if not new_agg: + print("...") + self.agg_body._compute_type(env, _env_bind(agg_env, (self.name, self.array.typ.element_type))) self._type = self.agg_body.typ + def agg_bindings(self, i): + return [(self.name, self.array.typ.element_type)] if i == 1 else [] + + def scan_bindings(self, i): + return [(self.name, self.array.typ.element_type)] if i == 1 else [] + def binds(self, i): - return {self.name} if i == 1 else {} + return i == 1 + + def uses_agg_context(self, i: int): + return i == 0 + + def uses_scan_context(self, i: int): + return i == 0 class AggGroupBy(IR): @@ -1107,6 +1152,12 @@ def _compute_type(self, env, agg_env): self.agg_ir._compute_type(env, agg_env) self._type = tdict(self.key.typ, self.agg_ir.typ) + def uses_agg_context(self, i: int): + return i == 0 + + def uses_scan_context(self, i: int): + return i == 0 + class AggArrayPerElement(IR): @typecheck_method(array=IR, element_name=str, index_name=str, agg_ir=IR, is_scan=bool) @@ -1130,16 +1181,31 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self.array._compute_type(agg_env, None) - self.agg_ir._compute_type(_env_bind(env, self.index_name, tint32), - _env_bind(agg_env, self.element_name, self.array.typ.element_type)) + self.agg_ir._compute_type(_env_bind(env, (self.index_name, tint32)), + _env_bind(agg_env, (self.element_name, self.array.typ.element_type))) self._type = tarray(self.agg_ir.typ) @property def bound_variables(self): return {self.element_name, self.index_name} | super().bound_variables + def uses_agg_context(self, i: int): + return i == 0 + + def uses_scan_context(self, i: int): + return i == 0 + + def bindings(self, i): + return [(self.index_name, tint32)] if i == 1 else [] + + def agg_bindings(self, i): + return [(self.element_name, self.array.typ.element_type)] if i == 1 else [] + + def scan_bindings(self, i): + return [(self.element_name, self.array.typ.element_type)] if i == 1 else [] + def binds(self, i): - return {self.index_name, self.element_name} if i == 1 else {} + return i == 1 def _register(registry, name, f): @@ -1249,6 +1315,9 @@ class ApplyAggOp(BaseApplyAggOp): def __init__(self, agg_op, constructor_args, init_op_args, seq_op_args): super().__init__(agg_op, constructor_args, init_op_args, seq_op_args) + def uses_agg_context(self, i: int): + return i >= len(self.constructor_args) + (len(self.init_op_args) if self.init_op_args else 0) + class ApplyScanOp(BaseApplyAggOp): @typecheck_method(agg_op=str, @@ -1258,6 +1327,9 @@ class ApplyScanOp(BaseApplyAggOp): def __init__(self, agg_op, constructor_args, init_op_args, seq_op_args): super().__init__(agg_op, constructor_args, init_op_args, seq_op_args) + def uses_scan_context(self, i: int): + return i >= len(self.constructor_args) + (len(self.init_op_args) if self.init_op_args else 0) + class Begin(IR): @typecheck_method(xs=sequenceof(IR)) @@ -1654,13 +1726,16 @@ def _eq(self, other): return other.argname == self.argname def _compute_type(self, env, agg_env): - self.function._compute_type(_env_bind(env, self.argname, tfloat64), agg_env) + self.function._compute_type(_env_bind(env, (self.argname, tfloat64)), agg_env) self.min._compute_type(env, agg_env) self.max._compute_type(env, agg_env) self._type = tfloat64 + def bindings(self, i): + return [(self.argname, tfloat64)] if i == 0 else [] + def binds(self, i): - return {self.argname} if i == 0 else {} + return i == 0 class TableCount(IR): @@ -1724,8 +1799,17 @@ def _compute_type(self, env, agg_env): self.query._compute_type(self.child.typ.global_env(), self.child.typ.row_env()) self._type = self.query.typ + def new_block(self, i: int): + return i == 1 + + def bindings(self, i): + return self.child.typ.global_env() if i == 1 else [] + + def agg_bindings(self, i): + return self.child.typ.row_env() if i == 1 else [] + def binds(self, i): - return TableIR.row_env if i == 1 else {} + return i == 1 class MatrixAggregate(IR): @@ -1748,8 +1832,17 @@ def _compute_type(self, env, agg_env): self.query._compute_type(self.child.typ.global_env(), self.child.typ.entry_env()) self._type = self.query.typ + def new_block(self, i: int): + return i == 1 + + def bindings(self, i): + return self.child.typ.global_env() if i == 1 else [] + + def agg_bindings(self, i): + return self.child.typ.entry_env() if i == 1 else [] + def binds(self, i): - return MatrixIR.entry_env if i == 1 else {} + return i == 1 class TableWrite(IR): diff --git a/hail/python/hail/ir/matrix_ir.py b/hail/python/hail/ir/matrix_ir.py index 5b5663503e0..6bb4ba7d264 100644 --- a/hail/python/hail/ir/matrix_ir.py +++ b/hail/python/hail/ir/matrix_ir.py @@ -26,13 +26,24 @@ def _compute_type(self): child_typ.row_key, self.entry_expr.typ) - def binds(self, i): + def bindings(self, i): if i == 1: - return self.entry_env + return self.child.typ.col_env() elif i == 2: - return self.row_env + return self.child.typ.global_env() else: - return {} + return [] + + def agg_bindings(self, i): + if i == 1: + return self.child.typ.entry_env() + elif i == 2: + return self.child.typ.row_env() + else: + return [] + + def binds(self, i): + return i == 1 or i == 2 class MatrixRead(MatrixIR): @@ -66,8 +77,11 @@ def _compute_type(self): self.pred._compute_type(self.child.typ.row_env(), None) self._type = self.child.typ + def bindings(self, i): + return self.child.typ.row_env() if i == 1 else [] + def binds(self, i): - return self.row_env if i == 1 else {} + return i == 1 class MatrixChooseCols(MatrixIR): @@ -114,8 +128,17 @@ def _compute_type(self): child_typ.row_key, child_typ.entry_type) + def bindings(self, i): + return self.child.typ.col_env() if i == 1 else [] + + def agg_bindings(self, i): + return self.child.typ.entry_env() if i == 1 else [] + + def scan_bindings(self, i): + return self.child.typ.col_env() if i == 1 else [] + def binds(self, i): - return self.entry_env if i == 1 else {} + return i == 1 class MatrixUnionCols(MatrixIR): @@ -150,8 +173,11 @@ def _compute_type(self): child_typ.row_key, self.new_entry.typ) + def bindings(self, i): + return self.child.typ.entry_env() if i == 1 else [] + def binds(self, i): - return self.entry_env if i == 1 else {} + return i == 1 class MatrixFilterEntries(MatrixIR): @@ -168,8 +194,11 @@ def _compute_type(self): self.pred._compute_type(self.child.typ.entry_env(), None) self._type = self.child.typ + def bindings(self, i): + return self.child.typ.entry_env() if i == 1 else [] + def binds(self, i): - return self.entry_env if i == 1 else {} + return i == 1 class MatrixKeyRowsBy(MatrixIR): @@ -219,8 +248,17 @@ def _compute_type(self): child_typ.row_key, child_typ.entry_type) + def bindings(self, i): + return self.child.typ.row_env() if i == 1 else [] + + def agg_bindings(self, i): + return self.child.typ.entry_env() if i == 1 else [] + + def scan_bindings(self, i): + return self.child.typ.row_env() if i == 1 else [] + def binds(self, i): - return self.entry_env if i == 1 else {} + return i == 1 class MatrixMapGlobals(MatrixIR): @@ -244,8 +282,11 @@ def _compute_type(self): child_typ.row_key, child_typ.entry_type) + def bindings(self, i): + return self.child.typ.global_env() if i == 1 else [] + def binds(self, i): - return self.global_env if i == 1 else {} + return i == 1 class MatrixFilterCols(MatrixIR): @@ -262,8 +303,11 @@ def _compute_type(self): self.pred._compute_type(self.child.typ.col_env(), None) self._type = self.child.typ + def bindings(self, i): + return self.child.typ.col_env() if i == 1 else [] + def binds(self, i): - return self.col_env if i == 1 else {} + return i == 1 class MatrixCollectColsByKey(MatrixIR): @@ -306,13 +350,24 @@ def _compute_type(self): child_typ.row_key, self.entry_expr.typ) - def binds(self, i): + def bindings(self, i): if i == 1: - return self.entry_env + return self.child.typ.row_env() elif i == 2: - return self.col_env + return self.child.typ.global_env() else: - return {} + return [] + + def agg_bindings(self, i): + if i == 1: + return self.child.typ.entry_env() + elif i == 2: + return self.child.typ.col_env() + else: + return [] + + def binds(self, i): + return i == 1 or i == 2 class MatrixExplodeRows(MatrixIR): diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index ccea35c7792..9dbd8abdd52 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -172,18 +172,20 @@ def add_jir(self, x: 'ir.BaseIR'): assert isinstance(x, ir.IR) return f'(JavaIR {jir_id})' - def find_in_context(self, x: 'ir.BaseIR', context: List[Scope]) -> int: + def find_in_scope(self, x: 'ir.BaseIR', context: List[Scope]) -> int: for i in range(len(context)): if id(x) in context[i].visited: return i return -1 - def lifted_in_context(self, x: 'ir.BaseIR', context: List[Scope]) -> int: + def lifted_in_scope(self, x: 'ir.BaseIR', context: List[Scope]) -> int: for i in range(len(context)): if id(x) in context[i].lifted_lets: return i return -1 + Context = (Dict[str, int], Dict[str, int], Dict[str, int]) + # Pre: # * 'context' is a list of 'Scope's, one for each potential let-insertion # site. @@ -198,61 +200,64 @@ def lifted_in_context(self, x: 'ir.BaseIR', context: List[Scope]) -> int: # computed) outermost scope. # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing # any lets to be inserted above y. - def recur(self, context: List[Scope], ref_to_scope: Dict[str, int], x: 'ir.BaseIR') -> Set[str]: + def recur(self, scopes: List[Scope], context: Context, x: 'ir.BaseIR') -> Set[str]: # Ref nodes should never be lifted to a let. (Not that it would be # incorrect, just pointlessly adding names for the same thing.) if isinstance(x, ir.Ref): return {x.name} + if isinstance(x, ir.GetField): + print('...') free_vars = set() for i in range(len(x.children)): child = x.children[i] # FIXME: maintain a union of seen nodes in all scopes - seen_in_context = self.find_in_context(child, context) - if seen_in_context >= 0: - # for now, only add non-relational lets - # FIXME: if generating uid can be moved to second pass, this - # can be factored into 'Scope' - if id(child) not in context[seen_in_context].lifted_lets and isinstance(child, ir.IR): + seen_in_scope = self.find_in_scope(child, scopes) + if seen_in_scope >= 0: + # we've seen 'child' before, no need to traverse + if id(child) not in scopes[seen_in_scope].lifted_lets and isinstance(child, ir.IR): # second time we've seen 'child', lift to a let - context[seen_in_context].lifted_lets[id(child)] = self.uid() + scopes[seen_in_scope].lifted_lets[id(child)] = self.uid() elif self.stop_at_jir and hasattr(child, '_jir'): self.memo[id(child)] = self.add_jir(child) else: - binds = x.binds(i) - if len(binds) > 0: + if x.binds(i) or x.new_block(i): + def get_vars(bindings): + if isinstance(bindings, dict): + bindings = bindings.items() + return [var for (var, _) in bindings] + eval_b = get_vars(x.bindings(i)) + agg_b = get_vars(x.agg_bindings(i)) + scan_b = get_vars(x.scan_bindings(i)) new_scope = Scope() if x.new_block(i): # Repeated subtrees of this child should never be lifted to # lets above 'x'. We accomplish that by clearing the context # in the recursive call. - child_context = [new_scope] - child_ref_to_scope = {} + child_scopes = [new_scope] + (eval_c, agg_c, scan_c) = ({}, {}, {}) new_idx = 0 else: - new_idx = len(context) - context.append(new_scope) - child_context = context - child_ref_to_scope = ref_to_scope.copy() - for v in binds: - child_ref_to_scope[v] = new_idx - child_free_vars = self.recur(child_context, child_ref_to_scope, child) - binds + new_idx = len(scopes) + scopes.append(new_scope) + child_scopes = scopes + (eval_c, agg_c, scan_c) = x.child_context_without_bindings(i, context) + eval_c = ir.base_ir._env_bind(eval_c, *[(var, new_idx) for var in eval_b]) + agg_c = ir.base_ir._env_bind(agg_c, *[(var, new_idx) for var in agg_b]) + scan_c = ir.base_ir._env_bind(scan_c, *[(var, new_idx) for var in scan_b]) + child_free_vars = self.recur(child_scopes, (eval_c, agg_c, scan_c), child) + child_free_vars.difference_update(eval_b, agg_b, scan_b) free_vars |= child_free_vars if not x.new_block(i): - context.pop() - new_scope.visited.clear() - self.scopes[id(child)] = new_scope - elif x.new_block(i): - new_scope = Scope() - free_vars |= self.recur([new_scope], {}, child) + scopes.pop() new_scope.visited.clear() self.scopes[id(child)] = new_scope else: - free_vars |= self.recur(context, ref_to_scope, child) + free_vars |= self.recur(scopes, x.child_context_without_bindings(i, context), child) if len(free_vars) > 0: - outermost_scope = max((ref_to_scope.get(v, 0) for v in free_vars)) + outermost_scope = max((context[0].get(v, 0) for v in free_vars)) else: outermost_scope = 0 - context[outermost_scope].visited.add(id(x)) + scopes[outermost_scope].visited.add(id(x)) return free_vars def print(self, builder: List[str], context: List[Scope], x: Renderable): @@ -272,7 +277,7 @@ def print(self, builder: List[str], context: List[Scope], x: Renderable): for i in range(0, len(children)): local_builder.append(' ') child = children[i] - lift_to = self.lifted_in_context(child, context) + lift_to = self.lifted_in_scope(child, context) if lift_to >= 0: name = context[lift_to].lifted_lets[id(child)] if id(child) not in context[lift_to].visited: @@ -298,9 +303,12 @@ def print(self, builder: List[str], context: List[Scope], x: Renderable): def __call__(self, x: 'BaseIR') -> str: + x.typ root_scope = Scope() - free_vars = self.recur([root_scope], {}, x) + free_vars = self.recur([root_scope], ({}, {}, {}), x) root_scope.visited = set() + if len(free_vars) != 0: + print('...') assert(len(free_vars) == 0) self.scopes[id(x)] = root_scope builder = [] diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index 92d0d654bec..825a092835d 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -139,8 +139,11 @@ def _compute_type(self): def new_block(i): return i == 1 + def bindings(self, i): + return self.child.typ.global_env() if i == 1 else [] + def binds(self, i): - return self.global_env if i == 1 else {} + return i == 1 class TableExplode(TableIR): @@ -199,8 +202,14 @@ def _compute_type(self): self.new_row.typ, self.child.typ.row_key) + def bindings(self, i): + return self.child.typ.row_env() if i == 1 else [] + + def scan_bindings(self, i): + return self.child.typ.row_env() if i == 1 else [] + def binds(self, i): - return self.row_env if i == 1 else {} + return i == 1 class TableRead(TableIR): @@ -267,8 +276,11 @@ def _compute_type(self): self.pred._compute_type(self.child.typ.row_env(), None) self._type = self.child.typ + def bindings(self, i): + return self.child.typ.row_env() if i == 1 else [] + def binds(self, i): - return self.row_env if i == 1 else {} + return i == 1 class TableKeyByAndAggregate(TableIR): @@ -297,8 +309,19 @@ def _compute_type(self): self.new_key.typ._concat(self.expr.typ), list(self.new_key.typ)) + def bindings(self, i): + if i == 1: + return self.child.typ.global_env() + elif i == 2: + return self.child.typ.row_env() + else: + return [] + + def agg_bindings(self, i): + return self.child.typ.row_env() if i == 1 else [] + def binds(self, i): - return self.row_env if i == 1 or i == 2 else {} + return i == 1 or i == 2 class TableAggregateByKey(TableIR): @@ -318,8 +341,14 @@ def _compute_type(self): child_typ.key_type._concat(self.expr.typ), child_typ.row_key) + def bindings(self, i): + return self.child.typ.row_env() if i == 1 else [] + + def agg_bindings(self, i): + return self.child.typ.row_env() if i == 1 else [] + def binds(self, i): - return self.row_env if i == 1 else {} + return i == 1 class MatrixColsTable(TableIR): diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index 1056470b295..5124bfc46d4 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -1,7 +1,7 @@ import unittest import hail as hl import hail.ir as ir -from hail.ir.renderer import CSERenderer, NewCSE +from hail.ir.renderer import CSERenderer from hail.expr import construct_expr from hail.expr.types import tint32 from hail.utils.java import Env @@ -363,7 +363,7 @@ def test_cse(self): ' (ApplyBinaryPrimOp `+`' ' (Ref __cse_1)' ' (Ref __cse_1)))') - self.assertEqual(NewCSE()(x), expected) + self.assertEqual(CSERenderer()(x), expected) def test_cse2(self): x = ir.I32(5) @@ -379,7 +379,7 @@ def test_cse2(self): ' (Ref __cse_2)' ' (I32 4))' ' (Ref __cse_2))))') - self.assertEqual(NewCSE()(div), expected) + self.assertEqual(CSERenderer()(div), expected) def test_cse_ifs(self): outer_repeated = ir.I32(5) @@ -395,7 +395,7 @@ def test_cse_ifs(self): ' (I32 5)))' ' (I32 5))' ) - self.assertEqual(NewCSE()(cond), expected) + self.assertEqual(CSERenderer()(cond), expected) # # def test_foo(self): # array = ir.MakeArray([ir.I32(5)], tint32) From ed133da523882a25a01940019d7cdfab81809c46 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 31 Jul 2019 14:40:18 -0400 Subject: [PATCH 03/47] some refactoring --- hail/python/hail/ir/renderer.py | 82 ++++++++++++++++----------------- 1 file changed, 39 insertions(+), 43 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 9dbd8abdd52..fb31816f173 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -172,14 +172,16 @@ def add_jir(self, x: 'ir.BaseIR'): assert isinstance(x, ir.IR) return f'(JavaIR {jir_id})' - def find_in_scope(self, x: 'ir.BaseIR', context: List[Scope]) -> int: - for i in range(len(context)): + @staticmethod + def find_in_scope(x: 'ir.BaseIR', context: List[Scope], outermost_scope: int) -> int: + for i in range(outermost_scope, len(context), -1): if id(x) in context[i].visited: return i return -1 - def lifted_in_scope(self, x: 'ir.BaseIR', context: List[Scope]) -> int: - for i in range(len(context)): + @staticmethod + def lifted_in_scope(x: 'ir.BaseIR', context: List[Scope]) -> int: + for i in range(0, len(context), -1): if id(x) in context[i].lifted_lets: return i return -1 @@ -200,18 +202,16 @@ def lifted_in_scope(self, x: 'ir.BaseIR', context: List[Scope]) -> int: # computed) outermost scope. # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing # any lets to be inserted above y. - def recur(self, scopes: List[Scope], context: Context, x: 'ir.BaseIR') -> Set[str]: + def recur(self, scopes: List[Scope], outermost_scope: int, context: Context, x: 'ir.BaseIR') -> Set[str]: # Ref nodes should never be lifted to a let. (Not that it would be # incorrect, just pointlessly adding names for the same thing.) if isinstance(x, ir.Ref): return {x.name} - if isinstance(x, ir.GetField): - print('...') free_vars = set() for i in range(len(x.children)): child = x.children[i] # FIXME: maintain a union of seen nodes in all scopes - seen_in_scope = self.find_in_scope(child, scopes) + seen_in_scope = self.find_in_scope(child, scopes, outermost_scope) if seen_in_scope >= 0: # we've seen 'child' before, no need to traverse if id(child) not in scopes[seen_in_scope].lifted_lets and isinstance(child, ir.IR): @@ -220,44 +220,40 @@ def recur(self, scopes: List[Scope], context: Context, x: 'ir.BaseIR') -> Set[st elif self.stop_at_jir and hasattr(child, '_jir'): self.memo[id(child)] = self.add_jir(child) else: + def get_vars(bindings): + if isinstance(bindings, dict): + bindings = bindings.items() + return [var for (var, _) in bindings] + eval_b = get_vars(x.bindings(i)) + agg_b = get_vars(x.agg_bindings(i)) + scan_b = get_vars(x.scan_bindings(i)) + new_scope = Scope() + depth = len(scopes) + scopes.append(new_scope) + if x.new_block(i): + child_context = ({}, {}, {}) + child_outermost_scope = depth + else: + child_context = x.child_context_without_bindings(i, context) + child_outermost_scope = outermost_scope + if x.binds(i): + (eval_c, agg_c, scan_c) = child_context + eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in eval_b]) + agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in agg_b]) + scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in scan_b]) + child_context = (eval_c, agg_c, scan_c) + child_free_vars = self.recur(scopes, child_outermost_scope, child_context, child) + child_free_vars.difference_update(eval_b, agg_b, scan_b) + free_vars |= child_free_vars + scopes.pop() + new_scope.visited.clear() if x.binds(i) or x.new_block(i): - def get_vars(bindings): - if isinstance(bindings, dict): - bindings = bindings.items() - return [var for (var, _) in bindings] - eval_b = get_vars(x.bindings(i)) - agg_b = get_vars(x.agg_bindings(i)) - scan_b = get_vars(x.scan_bindings(i)) - new_scope = Scope() - if x.new_block(i): - # Repeated subtrees of this child should never be lifted to - # lets above 'x'. We accomplish that by clearing the context - # in the recursive call. - child_scopes = [new_scope] - (eval_c, agg_c, scan_c) = ({}, {}, {}) - new_idx = 0 - else: - new_idx = len(scopes) - scopes.append(new_scope) - child_scopes = scopes - (eval_c, agg_c, scan_c) = x.child_context_without_bindings(i, context) - eval_c = ir.base_ir._env_bind(eval_c, *[(var, new_idx) for var in eval_b]) - agg_c = ir.base_ir._env_bind(agg_c, *[(var, new_idx) for var in agg_b]) - scan_c = ir.base_ir._env_bind(scan_c, *[(var, new_idx) for var in scan_b]) - child_free_vars = self.recur(child_scopes, (eval_c, agg_c, scan_c), child) - child_free_vars.difference_update(eval_b, agg_b, scan_b) - free_vars |= child_free_vars - if not x.new_block(i): - scopes.pop() - new_scope.visited.clear() self.scopes[id(child)] = new_scope - else: - free_vars |= self.recur(scopes, x.child_context_without_bindings(i, context), child) if len(free_vars) > 0: - outermost_scope = max((context[0].get(v, 0) for v in free_vars)) + bind_site = max((context[0].get(v, outermost_scope) for v in free_vars)) else: - outermost_scope = 0 - scopes[outermost_scope].visited.add(id(x)) + bind_site = outermost_scope + scopes[bind_site].visited.add(id(x)) return free_vars def print(self, builder: List[str], context: List[Scope], x: Renderable): @@ -305,7 +301,7 @@ def print(self, builder: List[str], context: List[Scope], x: Renderable): def __call__(self, x: 'BaseIR') -> str: x.typ root_scope = Scope() - free_vars = self.recur([root_scope], ({}, {}, {}), x) + free_vars = self.recur([root_scope], 0, ({}, {}, {}), x) root_scope.visited = set() if len(free_vars) != 0: print('...') From 3c392e6c655652556e178c203b81d0af04459b73 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 1 Aug 2019 14:04:06 -0400 Subject: [PATCH 04/47] finally fully working? --- hail/python/hail/ir/ir.py | 21 +++--- hail/python/hail/ir/renderer.py | 116 ++++++++++++++++++++++---------- 2 files changed, 87 insertions(+), 50 deletions(-) diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index f157104f96a..8c68d0c0bd1 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -263,9 +263,6 @@ def __init__(self, name, value, body): self.value = value self.body = body - def new_block(self, i): - return i > 0 - @typecheck_method(value=IR, body=IR) def copy(self, value, body): return Let(self.name, value, body) @@ -1076,10 +1073,10 @@ def _compute_type(self, env, agg_env): self._type = self.agg_ir.typ def uses_agg_context(self, i: int): - return i == 0 + return i == 0 and not self.is_scan def uses_scan_context(self, i: int): - return i == 0 + return i == 0 and self.is_scan class AggExplode(IR): @@ -1108,8 +1105,6 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.array._compute_type(agg_env, None) new_agg = _env_bind(agg_env, (self.name, self.array.typ.element_type)) - if not new_agg: - print("...") self.agg_body._compute_type(env, _env_bind(agg_env, (self.name, self.array.typ.element_type))) self._type = self.agg_body.typ @@ -1123,10 +1118,10 @@ def binds(self, i): return i == 1 def uses_agg_context(self, i: int): - return i == 0 + return i == 0 and not self.is_scan def uses_scan_context(self, i: int): - return i == 0 + return i == 0 and self.is_scan class AggGroupBy(IR): @@ -1153,10 +1148,10 @@ def _compute_type(self, env, agg_env): self._type = tdict(self.key.typ, self.agg_ir.typ) def uses_agg_context(self, i: int): - return i == 0 + return i == 0 and not self.is_scan def uses_scan_context(self, i: int): - return i == 0 + return i == 0 and self.is_scan class AggArrayPerElement(IR): @@ -1190,10 +1185,10 @@ def bound_variables(self): return {self.element_name, self.index_name} | super().bound_variables def uses_agg_context(self, i: int): - return i == 0 + return i == 0 and not self.is_scan def uses_scan_context(self, i: int): - return i == 0 + return i == 0 and self.is_scan def bindings(self, i): return [(self.index_name, tint32)] if i == 1 else [] diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index fb31816f173..8a4836f065a 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -137,14 +137,12 @@ def __call__(self, x: 'Renderable'): class Scope: - def __init__(self): - self.visited: Set[int] = set() - self.lifted_lets: Dict[int, str] = {} + def __init__(self, depth: int): + self.depth = depth + self.visited: Dict[int, 'ir.BaseIR'] = {} + self.lifted_lets: Dict[int, (str, 'ir.BaseIR')] = {} self.let_bodies: List[str] = [] - def visit(self, x: 'BaseIR'): - self.visited.add(id(x)) - class CSERenderer(Renderer): def __init__(self, stop_at_jir=False): self.stop_at_jir = stop_at_jir @@ -174,20 +172,28 @@ def add_jir(self, x: 'ir.BaseIR'): @staticmethod def find_in_scope(x: 'ir.BaseIR', context: List[Scope], outermost_scope: int) -> int: - for i in range(outermost_scope, len(context), -1): + for i in range(len(context) - 1, outermost_scope - 1, -1): if id(x) in context[i].visited: return i return -1 @staticmethod def lifted_in_scope(x: 'ir.BaseIR', context: List[Scope]) -> int: - for i in range(0, len(context), -1): + for i in range(len(context) - 1, -1, -1): if id(x) in context[i].lifted_lets: return i return -1 Context = (Dict[str, int], Dict[str, int], Dict[str, int]) + # class State: + # def __init__(self, scopes, outermost_scope, context): + # self.scopes = scopes + # self.outermost_scope = outermost_scope + # self.context = context + # + # def loop(self, state, x, free_vars, i): + # Pre: # * 'context' is a list of 'Scope's, one for each potential let-insertion # site. @@ -202,21 +208,23 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[Scope]) -> int: # computed) outermost scope. # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing # any lets to be inserted above y. - def recur(self, scopes: List[Scope], outermost_scope: int, context: Context, x: 'ir.BaseIR') -> Set[str]: - # Ref nodes should never be lifted to a let. (Not that it would be - # incorrect, just pointlessly adding names for the same thing.) - if isinstance(x, ir.Ref): - return {x.name} - free_vars = set() + def recur(self, scopes: List[Scope], outermost_scope: int, context: Context, x: 'ir.BaseIR') -> Dict[str, int]: + free_vars = {} + depth = len(scopes) for i in range(len(x.children)): child = x.children[i] # FIXME: maintain a union of seen nodes in all scopes seen_in_scope = self.find_in_scope(child, scopes, outermost_scope) if seen_in_scope >= 0: # we've seen 'child' before, no need to traverse - if id(child) not in scopes[seen_in_scope].lifted_lets and isinstance(child, ir.IR): - # second time we've seen 'child', lift to a let - scopes[seen_in_scope].lifted_lets[id(child)] = self.uid() + if isinstance(child, ir.IR): + if id(child) not in scopes[seen_in_scope].lifted_lets: + # second time we've seen 'child', lift to a let + uid = self.uid() + scopes[seen_in_scope].lifted_lets[id(child)] = (uid, child) + else: + (uid, _) = scopes[seen_in_scope].lifted_lets[id(child)] + free_vars[uid] = seen_in_scope elif self.stop_at_jir and hasattr(child, '_jir'): self.memo[id(child)] = self.add_jir(child) else: @@ -227,36 +235,45 @@ def get_vars(bindings): eval_b = get_vars(x.bindings(i)) agg_b = get_vars(x.agg_bindings(i)) scan_b = get_vars(x.scan_bindings(i)) - new_scope = Scope() - depth = len(scopes) + new_scope = Scope(depth) scopes.append(new_scope) if x.new_block(i): - child_context = ({}, {}, {}) child_outermost_scope = depth else: - child_context = x.child_context_without_bindings(i, context) child_outermost_scope = outermost_scope + child_context = x.child_context_without_bindings(i, context) if x.binds(i): (eval_c, agg_c, scan_c) = child_context eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in eval_b]) agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in agg_b]) scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in scan_b]) child_context = (eval_c, agg_c, scan_c) - child_free_vars = self.recur(scopes, child_outermost_scope, child_context, child) - child_free_vars.difference_update(eval_b, agg_b, scan_b) - free_vars |= child_free_vars + if isinstance(child, ir.Ref): + if child.name not in child_context[0]: + print('...') + child_free_vars = {child.name: child_context[0][child.name]} + else: + child_free_vars = self.recur(scopes, child_outermost_scope, child_context, child) + for var in [*eval_b, *agg_b, *scan_b]: + var_depth = child_free_vars.pop(var, depth) + assert var_depth == depth + for (var, _) in new_scope.lifted_lets.values(): + var_depth = child_free_vars.pop(var, depth) + assert var_depth == depth + free_vars.update(child_free_vars) scopes.pop() new_scope.visited.clear() - if x.binds(i) or x.new_block(i): + if new_scope.lifted_lets: self.scopes[id(child)] = new_scope if len(free_vars) > 0: - bind_site = max((context[0].get(v, outermost_scope) for v in free_vars)) + bind_depth = max(free_vars.values()) + bind_depth = max(bind_depth, outermost_scope) else: - bind_site = outermost_scope - scopes[bind_site].visited.add(id(x)) + bind_depth = outermost_scope + scopes[bind_depth].visited[id(x)] = x return free_vars - def print(self, builder: List[str], context: List[Scope], x: Renderable): + def print(self, builder: List[str], context: List[Scope], outermost_scope: int, depth: int, x: 'ir.BaseIR'): if id(x) in self.memo: builder.append(self.memo[id(x)]) return @@ -270,23 +287,31 @@ def print(self, builder: List[str], context: List[Scope], x: Renderable): if head != '': local_builder.append(head) children = x.render_children(self) + ir_child_num = 0 for i in range(0, len(children)): local_builder.append(' ') child = children[i] lift_to = self.lifted_in_scope(child, context) - if lift_to >= 0: - name = context[lift_to].lifted_lets[id(child)] + child_outermost_scope = outermost_scope + if x.new_block(ir_child_num): + child_outermost_scope = depth + if lift_to >= 0 and context[lift_to] and context[lift_to].depth >= outermost_scope: + ir_child_num += 1 + (name, _) = context[lift_to].lifted_lets[id(child)] if id(child) not in context[lift_to].visited: - context[lift_to].visited.add(id(child)) + context[lift_to].visited[id(child)] = child let_body = [f'(Let {name} '] - self.print(let_body, context, child) + self.print(let_body, context, child_outermost_scope, depth + 1, child) let_body.append(' ') # let_bodies is built post-order, which guarantees earlier # lets can't refer to later lets context[lift_to].let_bodies.append(let_body) local_builder.append(f'(Ref {name})') else: - self.print(local_builder, context, child) + if isinstance(child, ir.BaseIR): + self.print(local_builder, context, child_outermost_scope, depth + 1, child) + else: + ir_child_num = self.print_renderable(local_builder, context, child_outermost_scope, depth + 1, ir_child_num, child) local_builder.append(x.render_tail(self)) if insert_lets: context.pop() @@ -297,16 +322,33 @@ def print(self, builder: List[str], context: List[Scope], x: Renderable): for i in range(num_lets): builder.append(')') + def print_renderable(self, builder: List[str], context: List[Scope], outermost_scope: int, depth: int, ir_child_num: int, x: Renderable): + assert not isinstance(x, ir.BaseIR) + head = x.render_head(self) + if head != '': + builder.append(head) + children = x.render_children(self) + for i in range(0, len(children)): + builder.append(' ') + child = children[i] + if isinstance(child, ir.BaseIR): + self.print(builder, context, outermost_scope, depth, child) + ir_child_num += 1 + else: + ir_child_num = self.print_renderable(builder, context, outermost_scope, depth, ir_child_num, child) + builder.append(x.render_tail(self)) + return ir_child_num + def __call__(self, x: 'BaseIR') -> str: x.typ - root_scope = Scope() + root_scope = Scope(0) free_vars = self.recur([root_scope], 0, ({}, {}, {}), x) - root_scope.visited = set() + root_scope.visited = {} if len(free_vars) != 0: print('...') assert(len(free_vars) == 0) self.scopes[id(x)] = root_scope builder = [] - self.print(builder, [], x) + self.print(builder, [], 0, 1, x) return ''.join(builder) From 9e788d3c0a2b5c901969aabc3ffd0b3b6ccf0bfe Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 1 Aug 2019 14:29:51 -0400 Subject: [PATCH 05/47] loop -> recursion --- hail/python/hail/ir/renderer.py | 137 +++++++++++++++++--------------- 1 file changed, 74 insertions(+), 63 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 8a4836f065a..1b8cfd18812 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -143,6 +143,16 @@ def __init__(self, depth: int): self.lifted_lets: Dict[int, (str, 'ir.BaseIR')] = {} self.let_bodies: List[str] = [] +class State: + def __init__(self, scopes, outermost_scope, context, depth, free_vars): + # immutable + self.outermost_scope = outermost_scope + self.context = context + self.depth = depth + # mutable + self.scopes = scopes + self.free_vars = free_vars + class CSERenderer(Renderer): def __init__(self, stop_at_jir=False): self.stop_at_jir = stop_at_jir @@ -186,13 +196,69 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[Scope]) -> int: Context = (Dict[str, int], Dict[str, int], Dict[str, int]) - # class State: - # def __init__(self, scopes, outermost_scope, context): - # self.scopes = scopes - # self.outermost_scope = outermost_scope - # self.context = context - # - # def loop(self, state, x, free_vars, i): + def loop(self, state, x, i): + if i >= len(x.children): + return state.free_vars + child = x.children[i] + seen_in_scope = self.find_in_scope(child, state.scopes, state.outermost_scope) + if seen_in_scope >= 0: + # we've seen 'child' before, should not traverse (or we will find + # too many lifts) + if isinstance(child, ir.IR): + if id(child) not in state.scopes[seen_in_scope].lifted_lets: + # second time we've seen 'child', lift to a let + uid = self.uid() + state.scopes[seen_in_scope].lifted_lets[id(child)] = (uid, child) + else: + (uid, _) = state.scopes[seen_in_scope].lifted_lets[id(child)] + state.free_vars[uid] = seen_in_scope + return self.loop(state, x, i + 1) + elif self.stop_at_jir and hasattr(child, '_jir'): + self.memo[id(child)] = self.add_jir(child) + return self.loop(state, x, i + 1) + else: + def get_vars(bindings): + if isinstance(bindings, dict): + bindings = bindings.items() + return [var for (var, _) in bindings] + eval_b = get_vars(x.bindings(i)) + agg_b = get_vars(x.agg_bindings(i)) + scan_b = get_vars(x.scan_bindings(i)) + new_scope = Scope(state.depth) + state.scopes.append(new_scope) + + if x.new_block(i): + child_outermost_scope = state.depth + else: + child_outermost_scope = state.outermost_scope + + child_context = x.child_context_without_bindings(i, state.context) + if x.binds(i): + (eval_c, agg_c, scan_c) = child_context + eval_c = ir.base_ir._env_bind(eval_c, *[(var, state.depth) for var in eval_b]) + agg_c = ir.base_ir._env_bind(agg_c, *[(var, state.depth) for var in agg_b]) + scan_c = ir.base_ir._env_bind(scan_c, *[(var, state.depth) for var in scan_b]) + child_context = (eval_c, agg_c, scan_c) + + # continuation uses: bindings, scopes, child_outermost_scope, child_context + + if isinstance(child, ir.Ref): + child_free_vars = {child.name: child_context[0][child.name]} + else: + child_free_vars = self.recur(state.scopes, child_outermost_scope, child_context, child) + + for var in [*eval_b, *agg_b, *scan_b]: + var_depth = child_free_vars.pop(var, state.depth) + assert var_depth == state.depth + for (var, _) in new_scope.lifted_lets.values(): + var_depth = child_free_vars.pop(var, state.depth) + assert var_depth == state.depth + state.free_vars.update(child_free_vars) + state.scopes.pop() + new_scope.visited.clear() + if new_scope.lifted_lets: + self.scopes[id(child)] = new_scope + return self.loop(state, x, i + 1) # Pre: # * 'context' is a list of 'Scope's, one for each potential let-insertion @@ -209,62 +275,7 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[Scope]) -> int: # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing # any lets to be inserted above y. def recur(self, scopes: List[Scope], outermost_scope: int, context: Context, x: 'ir.BaseIR') -> Dict[str, int]: - free_vars = {} - depth = len(scopes) - for i in range(len(x.children)): - child = x.children[i] - # FIXME: maintain a union of seen nodes in all scopes - seen_in_scope = self.find_in_scope(child, scopes, outermost_scope) - if seen_in_scope >= 0: - # we've seen 'child' before, no need to traverse - if isinstance(child, ir.IR): - if id(child) not in scopes[seen_in_scope].lifted_lets: - # second time we've seen 'child', lift to a let - uid = self.uid() - scopes[seen_in_scope].lifted_lets[id(child)] = (uid, child) - else: - (uid, _) = scopes[seen_in_scope].lifted_lets[id(child)] - free_vars[uid] = seen_in_scope - elif self.stop_at_jir and hasattr(child, '_jir'): - self.memo[id(child)] = self.add_jir(child) - else: - def get_vars(bindings): - if isinstance(bindings, dict): - bindings = bindings.items() - return [var for (var, _) in bindings] - eval_b = get_vars(x.bindings(i)) - agg_b = get_vars(x.agg_bindings(i)) - scan_b = get_vars(x.scan_bindings(i)) - new_scope = Scope(depth) - scopes.append(new_scope) - if x.new_block(i): - child_outermost_scope = depth - else: - child_outermost_scope = outermost_scope - child_context = x.child_context_without_bindings(i, context) - if x.binds(i): - (eval_c, agg_c, scan_c) = child_context - eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in eval_b]) - agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in agg_b]) - scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in scan_b]) - child_context = (eval_c, agg_c, scan_c) - if isinstance(child, ir.Ref): - if child.name not in child_context[0]: - print('...') - child_free_vars = {child.name: child_context[0][child.name]} - else: - child_free_vars = self.recur(scopes, child_outermost_scope, child_context, child) - for var in [*eval_b, *agg_b, *scan_b]: - var_depth = child_free_vars.pop(var, depth) - assert var_depth == depth - for (var, _) in new_scope.lifted_lets.values(): - var_depth = child_free_vars.pop(var, depth) - assert var_depth == depth - free_vars.update(child_free_vars) - scopes.pop() - new_scope.visited.clear() - if new_scope.lifted_lets: - self.scopes[id(child)] = new_scope + free_vars = self.loop(State(scopes, outermost_scope, context, len(scopes), {}), x, 0) if len(free_vars) > 0: bind_depth = max(free_vars.values()) bind_depth = max(bind_depth, outermost_scope) From c2e9f9268cfa1333d2fa65e1db3b9c409b9a73e9 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 1 Aug 2019 15:02:48 -0400 Subject: [PATCH 06/47] CPS transform --- hail/python/hail/ir/renderer.py | 107 +++++++++++++++++-------------- hail/python/test/hail/test_ir.py | 6 +- 2 files changed, 61 insertions(+), 52 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 1b8cfd18812..1759d573aee 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -1,6 +1,6 @@ from hail import ir import abc -from typing import Sequence, List, Set, Dict +from typing import Sequence, List, Set, Dict, Callable class Renderable(object): @@ -143,15 +143,18 @@ def __init__(self, depth: int): self.lifted_lets: Dict[int, (str, 'ir.BaseIR')] = {} self.let_bodies: List[str] = [] -class State: - def __init__(self, scopes, outermost_scope, context, depth, free_vars): +class LoopState: + def __init__(self, scopes, outermost_scope, context): # immutable self.outermost_scope = outermost_scope self.context = context - self.depth = depth # mutable self.scopes = scopes - self.free_vars = free_vars + self.free_vars = {} + self.i = 0 + +Vars = Dict[str, int] +Context = (Vars, Vars, Vars) class CSERenderer(Renderer): def __init__(self, stop_at_jir=False): @@ -194,12 +197,11 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[Scope]) -> int: return i return -1 - Context = (Dict[str, int], Dict[str, int], Dict[str, int]) - - def loop(self, state, x, i): + def loop(self, state, x, i, k: Callable[[Vars], str]) -> str: if i >= len(x.children): - return state.free_vars + return k(state.free_vars) child = x.children[i] + depth = len(state.scopes) seen_in_scope = self.find_in_scope(child, state.scopes, state.outermost_scope) if seen_in_scope >= 0: # we've seen 'child' before, should not traverse (or we will find @@ -212,10 +214,10 @@ def loop(self, state, x, i): else: (uid, _) = state.scopes[seen_in_scope].lifted_lets[id(child)] state.free_vars[uid] = seen_in_scope - return self.loop(state, x, i + 1) + return self.loop(state, x, i + 1, k) elif self.stop_at_jir and hasattr(child, '_jir'): self.memo[id(child)] = self.add_jir(child) - return self.loop(state, x, i + 1) + return self.loop(state, x, i + 1, k) else: def get_vars(bindings): if isinstance(bindings, dict): @@ -224,41 +226,43 @@ def get_vars(bindings): eval_b = get_vars(x.bindings(i)) agg_b = get_vars(x.agg_bindings(i)) scan_b = get_vars(x.scan_bindings(i)) - new_scope = Scope(state.depth) + new_scope = Scope(depth) state.scopes.append(new_scope) if x.new_block(i): - child_outermost_scope = state.depth + child_outermost_scope = depth else: child_outermost_scope = state.outermost_scope child_context = x.child_context_without_bindings(i, state.context) if x.binds(i): (eval_c, agg_c, scan_c) = child_context - eval_c = ir.base_ir._env_bind(eval_c, *[(var, state.depth) for var in eval_b]) - agg_c = ir.base_ir._env_bind(agg_c, *[(var, state.depth) for var in agg_b]) - scan_c = ir.base_ir._env_bind(scan_c, *[(var, state.depth) for var in scan_b]) + eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in eval_b]) + agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in agg_b]) + scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in scan_b]) child_context = (eval_c, agg_c, scan_c) # continuation uses: bindings, scopes, child_outermost_scope, child_context + def kont(child_free_vars: Vars) -> str: + for var in [*eval_b, *agg_b, *scan_b]: + var_depth = child_free_vars.pop(var, depth) + assert var_depth == depth + for (var, _) in new_scope.lifted_lets.values(): + var_depth = child_free_vars.pop(var, depth) + assert var_depth == depth + state.free_vars.update(child_free_vars) + state.scopes.pop() + new_scope.visited.clear() + if new_scope.lifted_lets: + self.scopes[id(child)] = new_scope + return self.loop(state, x, i + 1, k) + if isinstance(child, ir.Ref): - child_free_vars = {child.name: child_context[0][child.name]} + return kont({child.name: child_context[0][child.name]}) else: - child_free_vars = self.recur(state.scopes, child_outermost_scope, child_context, child) - - for var in [*eval_b, *agg_b, *scan_b]: - var_depth = child_free_vars.pop(var, state.depth) - assert var_depth == state.depth - for (var, _) in new_scope.lifted_lets.values(): - var_depth = child_free_vars.pop(var, state.depth) - assert var_depth == state.depth - state.free_vars.update(child_free_vars) - state.scopes.pop() - new_scope.visited.clear() - if new_scope.lifted_lets: - self.scopes[id(child)] = new_scope - return self.loop(state, x, i + 1) + return self.recur(state.scopes, child_outermost_scope, child_context, child, kont) + # Pre: # * 'context' is a list of 'Scope's, one for each potential let-insertion @@ -274,15 +278,16 @@ def get_vars(bindings): # computed) outermost scope. # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing # any lets to be inserted above y. - def recur(self, scopes: List[Scope], outermost_scope: int, context: Context, x: 'ir.BaseIR') -> Dict[str, int]: - free_vars = self.loop(State(scopes, outermost_scope, context, len(scopes), {}), x, 0) - if len(free_vars) > 0: - bind_depth = max(free_vars.values()) - bind_depth = max(bind_depth, outermost_scope) - else: - bind_depth = outermost_scope - scopes[bind_depth].visited[id(x)] = x - return free_vars + def recur(self, scopes: List[Scope], outermost_scope: int, context: Context, x: 'ir.BaseIR', k: Callable[[Vars], str]) -> str: + def kont(free_vars: Vars) -> str: + if len(free_vars) > 0: + bind_depth = max(free_vars.values()) + bind_depth = max(bind_depth, outermost_scope) + else: + bind_depth = outermost_scope + scopes[bind_depth].visited[id(x)] = x + return k(free_vars) + return self.loop(LoopState(scopes, outermost_scope, context), x, 0, kont) def print(self, builder: List[str], context: List[Scope], outermost_scope: int, depth: int, x: 'ir.BaseIR'): if id(x) in self.memo: @@ -354,12 +359,16 @@ def print_renderable(self, builder: List[str], context: List[Scope], outermost_s def __call__(self, x: 'BaseIR') -> str: x.typ root_scope = Scope(0) - free_vars = self.recur([root_scope], 0, ({}, {}, {}), x) - root_scope.visited = {} - if len(free_vars) != 0: - print('...') - assert(len(free_vars) == 0) - self.scopes[id(x)] = root_scope - builder = [] - self.print(builder, [], 0, 1, x) - return ''.join(builder) + def kont(free_vars: Vars) -> str: + root_scope.visited = {} + if len(free_vars) != 0: + print('...') + for (var, _) in root_scope.lifted_lets.values(): + var_depth = free_vars.pop(var, 0) + assert var_depth == 0 + assert(len(free_vars) == 0) + self.scopes[id(x)] = root_scope + builder = [] + self.print(builder, [], 0, 1, x) + return ''.join(builder) + return self.recur([root_scope], 0, ({}, {}, {}), x, kont) diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index 5124bfc46d4..7928f17ece6 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -363,7 +363,7 @@ def test_cse(self): ' (ApplyBinaryPrimOp `+`' ' (Ref __cse_1)' ' (Ref __cse_1)))') - self.assertEqual(CSERenderer()(x), expected) + self.assertEqual(expected, CSERenderer()(x)) def test_cse2(self): x = ir.I32(5) @@ -379,7 +379,7 @@ def test_cse2(self): ' (Ref __cse_2)' ' (I32 4))' ' (Ref __cse_2))))') - self.assertEqual(CSERenderer()(div), expected) + self.assertEqual(expected, CSERenderer()(div)) def test_cse_ifs(self): outer_repeated = ir.I32(5) @@ -395,7 +395,7 @@ def test_cse_ifs(self): ' (I32 5)))' ' (I32 5))' ) - self.assertEqual(CSERenderer()(cond), expected) + self.assertEqual(expected, CSERenderer()(cond)) # # def test_foo(self): # array = ir.MakeArray([ir.I32(5)], tint32) From 6619b17264f2aedd3ed840ef33f595c39b5cdbe5 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 1 Aug 2019 15:12:36 -0400 Subject: [PATCH 07/47] inline recur --- hail/python/hail/ir/renderer.py | 58 ++++++++++++++++----------------- 1 file changed, 28 insertions(+), 30 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 1759d573aee..e7c4b766c7a 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -197,6 +197,20 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[Scope]) -> int: return i return -1 + # Pre: + # * 'context' is a list of 'Scope's, one for each potential let-insertion + # site. + # * 'ref_to_scope' maps each bound variable to the index in 'context' of the + # scope of its binding site. + # Post: + # * Returns set of free variables in 'x'. + # * Each subtree of 'x' is flagged as visited in the outermost scope + # containing all of its free variables. + # * Each subtree previously visited (either an earlier subtree of 'x', or + # marked visited in 'context') is added to set of lets in its (previously + # computed) outermost scope. + # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing + # any lets to be inserted above y. def loop(self, state, x, i, k: Callable[[Vars], str]) -> str: if i >= len(x.children): return k(state.free_vars) @@ -244,7 +258,7 @@ def get_vars(bindings): # continuation uses: bindings, scopes, child_outermost_scope, child_context - def kont(child_free_vars: Vars) -> str: + def post_visit(child_free_vars: Vars) -> str: for var in [*eval_b, *agg_b, *scan_b]: var_depth = child_free_vars.pop(var, depth) assert var_depth == depth @@ -259,35 +273,17 @@ def kont(child_free_vars: Vars) -> str: return self.loop(state, x, i + 1, k) if isinstance(child, ir.Ref): - return kont({child.name: child_context[0][child.name]}) + return post_visit({child.name: child_context[0][child.name]}) else: - return self.recur(state.scopes, child_outermost_scope, child_context, child, kont) - - - # Pre: - # * 'context' is a list of 'Scope's, one for each potential let-insertion - # site. - # * 'ref_to_scope' maps each bound variable to the index in 'context' of the - # scope of its binding site. - # Post: - # * Returns set of free variables in 'x'. - # * Each subtree of 'x' is flagged as visited in the outermost scope - # containing all of its free variables. - # * Each subtree previously visited (either an earlier subtree of 'x', or - # marked visited in 'context') is added to set of lets in its (previously - # computed) outermost scope. - # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing - # any lets to be inserted above y. - def recur(self, scopes: List[Scope], outermost_scope: int, context: Context, x: 'ir.BaseIR', k: Callable[[Vars], str]) -> str: - def kont(free_vars: Vars) -> str: - if len(free_vars) > 0: - bind_depth = max(free_vars.values()) - bind_depth = max(bind_depth, outermost_scope) - else: - bind_depth = outermost_scope - scopes[bind_depth].visited[id(x)] = x - return k(free_vars) - return self.loop(LoopState(scopes, outermost_scope, context), x, 0, kont) + def post_children(free_vars: Vars) -> str: + if len(free_vars) > 0: + bind_depth = max(free_vars.values()) + bind_depth = max(bind_depth, child_outermost_scope) + else: + bind_depth = child_outermost_scope + state.scopes[bind_depth].visited[id(child)] = child + return post_visit(free_vars) + return self.loop(LoopState(state.scopes, child_outermost_scope, child_context), child, 0, post_children) def print(self, builder: List[str], context: List[Scope], outermost_scope: int, depth: int, x: 'ir.BaseIR'): if id(x) in self.memo: @@ -359,6 +355,7 @@ def print_renderable(self, builder: List[str], context: List[Scope], outermost_s def __call__(self, x: 'BaseIR') -> str: x.typ root_scope = Scope(0) + def kont(free_vars: Vars) -> str: root_scope.visited = {} if len(free_vars) != 0: @@ -371,4 +368,5 @@ def kont(free_vars: Vars) -> str: builder = [] self.print(builder, [], 0, 1, x) return ''.join(builder) - return self.recur([root_scope], 0, ({}, {}, {}), x, kont) + + return self.loop(LoopState([root_scope], 0, ({}, {}, {})), x, 0, kont) From 28f23e36e14014314dd0b28dc198ee30200acb15 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 1 Aug 2019 16:28:47 -0400 Subject: [PATCH 08/47] defunctionalization --- hail/python/hail/ir/renderer.py | 118 ++++++++++++++++++++------------ 1 file changed, 74 insertions(+), 44 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index e7c4b766c7a..f8ee5f8dd64 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -144,15 +144,34 @@ def __init__(self, depth: int): self.let_bodies: List[str] = [] class LoopState: - def __init__(self, scopes, outermost_scope, context): + def __init__(self, scopes, outermost_scope, context, x, k): # immutable self.outermost_scope = outermost_scope self.context = context + self.x = x + self.k = k # mutable self.scopes = scopes self.free_vars = {} self.i = 0 +class Kont: + pass + +class PostChildren(Kont): + def __init__(self, state, child_outermost_scope, eval_b, agg_b, scan_b, new_bindings): + self.state = state + self.child_outermost_scope = child_outermost_scope + self.eval_b = eval_b + self.agg_b = agg_b + self.scan_b = scan_b + self.new_bindings = new_bindings + +class BeginSecondPass(Kont): + def __init__(self, root_scope, root): + self.root_scope = root_scope + self.root = root + Vars = Dict[str, int] Context = (Vars, Vars, Vars) @@ -211,9 +230,11 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[Scope]) -> int: # computed) outermost scope. # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing # any lets to be inserted above y. - def loop(self, state, x, i, k: Callable[[Vars], str]) -> str: + def loop(self, state) -> str: + i = state.i + x = state.x if i >= len(x.children): - return k(state.free_vars) + return self.apply(state.k, state.free_vars) child = x.children[i] depth = len(state.scopes) seen_in_scope = self.find_in_scope(child, state.scopes, state.outermost_scope) @@ -228,10 +249,12 @@ def loop(self, state, x, i, k: Callable[[Vars], str]) -> str: else: (uid, _) = state.scopes[seen_in_scope].lifted_lets[id(child)] state.free_vars[uid] = seen_in_scope - return self.loop(state, x, i + 1, k) + state.i += 1 + return self.loop(state) elif self.stop_at_jir and hasattr(child, '_jir'): self.memo[id(child)] = self.add_jir(child) - return self.loop(state, x, i + 1, k) + state.i += 1 + return self.loop(state) else: def get_vars(bindings): if isinstance(bindings, dict): @@ -256,34 +279,53 @@ def get_vars(bindings): scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in scan_b]) child_context = (eval_c, agg_c, scan_c) - # continuation uses: bindings, scopes, child_outermost_scope, child_context - - def post_visit(child_free_vars: Vars) -> str: + if isinstance(child, ir.Ref): + state.scopes.pop() + child_free_vars = {child.name: child_context[0][child.name]} for var in [*eval_b, *agg_b, *scan_b]: - var_depth = child_free_vars.pop(var, depth) - assert var_depth == depth - for (var, _) in new_scope.lifted_lets.values(): - var_depth = child_free_vars.pop(var, depth) - assert var_depth == depth + child_free_vars.pop(var, 0) state.free_vars.update(child_free_vars) - state.scopes.pop() - new_scope.visited.clear() - if new_scope.lifted_lets: - self.scopes[id(child)] = new_scope - return self.loop(state, x, i + 1, k) - - if isinstance(child, ir.Ref): - return post_visit({child.name: child_context[0][child.name]}) + state.i += 1 + return self.loop(state) else: - def post_children(free_vars: Vars) -> str: - if len(free_vars) > 0: - bind_depth = max(free_vars.values()) - bind_depth = max(bind_depth, child_outermost_scope) - else: - bind_depth = child_outermost_scope - state.scopes[bind_depth].visited[id(child)] = child - return post_visit(free_vars) - return self.loop(LoopState(state.scopes, child_outermost_scope, child_context), child, 0, post_children) + pc = PostChildren(state, child_outermost_scope, eval_b, agg_b, scan_b, new_scope.lifted_lets.values()) + ls = LoopState(state.scopes, child_outermost_scope, child_context, child, pc) + return self.loop(ls) + + def apply(self, kont: Kont, free_vars: Vars): + if isinstance(kont, PostChildren): + child = kont.state.x.children[kont.state.i] + if len(free_vars) > 0: + bind_depth = max(free_vars.values()) + bind_depth = max(bind_depth, kont.child_outermost_scope) + else: + bind_depth = kont.child_outermost_scope + kont.state.scopes[bind_depth].visited[id(child)] = child + new_scope = kont.state.scopes[-1] + kont.state.scopes.pop() + new_scope.visited.clear() + if new_scope.lifted_lets: + self.scopes[id(child)] = new_scope + for var in [*kont.eval_b, *kont.agg_b, *kont.scan_b]: + free_vars.pop(var, 0) + for (var, _) in kont.new_bindings: + free_vars.pop(var, 0) + kont.state.free_vars.update(free_vars) + kont.state.i += 1 + return self.loop(kont.state) + else: + assert isinstance(kont, BeginSecondPass) + kont.root_scope.visited = {} + if len(free_vars) != 0: + print('...') + for (var, _) in kont.root_scope.lifted_lets.values(): + var_depth = free_vars.pop(var, 0) + assert var_depth == 0 + assert(len(free_vars) == 0) + self.scopes[id(kont.root)] = kont.root_scope + builder = [] + self.print(builder, [], 0, 1, kont.root) + return ''.join(builder) def print(self, builder: List[str], context: List[Scope], outermost_scope: int, depth: int, x: 'ir.BaseIR'): if id(x) in self.memo: @@ -356,17 +398,5 @@ def __call__(self, x: 'BaseIR') -> str: x.typ root_scope = Scope(0) - def kont(free_vars: Vars) -> str: - root_scope.visited = {} - if len(free_vars) != 0: - print('...') - for (var, _) in root_scope.lifted_lets.values(): - var_depth = free_vars.pop(var, 0) - assert var_depth == 0 - assert(len(free_vars) == 0) - self.scopes[id(x)] = root_scope - builder = [] - self.print(builder, [], 0, 1, x) - return ''.join(builder) - - return self.loop(LoopState([root_scope], 0, ({}, {}, {})), x, 0, kont) + k = BeginSecondPass(root_scope, x) + return self.loop(LoopState([root_scope], 0, ({}, {}, {}), x, k)) From 144a1bd5976090cf40b39459946c63ee9a34b3f2 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 1 Aug 2019 16:40:53 -0400 Subject: [PATCH 09/47] recursion -> loop --- hail/python/hail/ir/renderer.py | 197 ++++++++++++++++---------------- 1 file changed, 100 insertions(+), 97 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index f8ee5f8dd64..fcc2937e58e 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -230,102 +230,6 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[Scope]) -> int: # computed) outermost scope. # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing # any lets to be inserted above y. - def loop(self, state) -> str: - i = state.i - x = state.x - if i >= len(x.children): - return self.apply(state.k, state.free_vars) - child = x.children[i] - depth = len(state.scopes) - seen_in_scope = self.find_in_scope(child, state.scopes, state.outermost_scope) - if seen_in_scope >= 0: - # we've seen 'child' before, should not traverse (or we will find - # too many lifts) - if isinstance(child, ir.IR): - if id(child) not in state.scopes[seen_in_scope].lifted_lets: - # second time we've seen 'child', lift to a let - uid = self.uid() - state.scopes[seen_in_scope].lifted_lets[id(child)] = (uid, child) - else: - (uid, _) = state.scopes[seen_in_scope].lifted_lets[id(child)] - state.free_vars[uid] = seen_in_scope - state.i += 1 - return self.loop(state) - elif self.stop_at_jir and hasattr(child, '_jir'): - self.memo[id(child)] = self.add_jir(child) - state.i += 1 - return self.loop(state) - else: - def get_vars(bindings): - if isinstance(bindings, dict): - bindings = bindings.items() - return [var for (var, _) in bindings] - eval_b = get_vars(x.bindings(i)) - agg_b = get_vars(x.agg_bindings(i)) - scan_b = get_vars(x.scan_bindings(i)) - new_scope = Scope(depth) - state.scopes.append(new_scope) - - if x.new_block(i): - child_outermost_scope = depth - else: - child_outermost_scope = state.outermost_scope - - child_context = x.child_context_without_bindings(i, state.context) - if x.binds(i): - (eval_c, agg_c, scan_c) = child_context - eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in eval_b]) - agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in agg_b]) - scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in scan_b]) - child_context = (eval_c, agg_c, scan_c) - - if isinstance(child, ir.Ref): - state.scopes.pop() - child_free_vars = {child.name: child_context[0][child.name]} - for var in [*eval_b, *agg_b, *scan_b]: - child_free_vars.pop(var, 0) - state.free_vars.update(child_free_vars) - state.i += 1 - return self.loop(state) - else: - pc = PostChildren(state, child_outermost_scope, eval_b, agg_b, scan_b, new_scope.lifted_lets.values()) - ls = LoopState(state.scopes, child_outermost_scope, child_context, child, pc) - return self.loop(ls) - - def apply(self, kont: Kont, free_vars: Vars): - if isinstance(kont, PostChildren): - child = kont.state.x.children[kont.state.i] - if len(free_vars) > 0: - bind_depth = max(free_vars.values()) - bind_depth = max(bind_depth, kont.child_outermost_scope) - else: - bind_depth = kont.child_outermost_scope - kont.state.scopes[bind_depth].visited[id(child)] = child - new_scope = kont.state.scopes[-1] - kont.state.scopes.pop() - new_scope.visited.clear() - if new_scope.lifted_lets: - self.scopes[id(child)] = new_scope - for var in [*kont.eval_b, *kont.agg_b, *kont.scan_b]: - free_vars.pop(var, 0) - for (var, _) in kont.new_bindings: - free_vars.pop(var, 0) - kont.state.free_vars.update(free_vars) - kont.state.i += 1 - return self.loop(kont.state) - else: - assert isinstance(kont, BeginSecondPass) - kont.root_scope.visited = {} - if len(free_vars) != 0: - print('...') - for (var, _) in kont.root_scope.lifted_lets.values(): - var_depth = free_vars.pop(var, 0) - assert var_depth == 0 - assert(len(free_vars) == 0) - self.scopes[id(kont.root)] = kont.root_scope - builder = [] - self.print(builder, [], 0, 1, kont.root) - return ''.join(builder) def print(self, builder: List[str], context: List[Scope], outermost_scope: int, depth: int, x: 'ir.BaseIR'): if id(x) in self.memo: @@ -399,4 +303,103 @@ def __call__(self, x: 'BaseIR') -> str: root_scope = Scope(0) k = BeginSecondPass(root_scope, x) - return self.loop(LoopState([root_scope], 0, ({}, {}, {}), x, k)) + + state = LoopState([root_scope], 0, ({}, {}, {}), x, k) + while True: + i = state.i + x = state.x + if i >= len(x.children): + kont = state.k + free_vars = state.free_vars + + if isinstance(kont, PostChildren): + child = kont.state.x.children[kont.state.i] + if len(free_vars) > 0: + bind_depth = max(free_vars.values()) + bind_depth = max(bind_depth, kont.child_outermost_scope) + else: + bind_depth = kont.child_outermost_scope + kont.state.scopes[bind_depth].visited[id(child)] = child + new_scope = kont.state.scopes[-1] + kont.state.scopes.pop() + new_scope.visited.clear() + if new_scope.lifted_lets: + self.scopes[id(child)] = new_scope + for var in [*kont.eval_b, *kont.agg_b, *kont.scan_b]: + free_vars.pop(var, 0) + for (var, _) in kont.new_bindings: + free_vars.pop(var, 0) + kont.state.free_vars.update(free_vars) + state = kont.state + state.i += 1 + continue + else: + assert isinstance(kont, BeginSecondPass) + kont.root_scope.visited = {} + if len(free_vars) != 0: + print('...') + for (var, _) in kont.root_scope.lifted_lets.values(): + var_depth = free_vars.pop(var, 0) + assert var_depth == 0 + assert(len(free_vars) == 0) + self.scopes[id(kont.root)] = kont.root_scope + builder = [] + self.print(builder, [], 0, 1, kont.root) + return ''.join(builder) + + child = x.children[i] + depth = len(state.scopes) + seen_in_scope = self.find_in_scope(child, state.scopes, state.outermost_scope) + if seen_in_scope >= 0: + # we've seen 'child' before, should not traverse (or we will find + # too many lifts) + if isinstance(child, ir.IR): + if id(child) not in state.scopes[seen_in_scope].lifted_lets: + # second time we've seen 'child', lift to a let + uid = self.uid() + state.scopes[seen_in_scope].lifted_lets[id(child)] = (uid, child) + else: + (uid, _) = state.scopes[seen_in_scope].lifted_lets[id(child)] + state.free_vars[uid] = seen_in_scope + state.i += 1 + continue + elif self.stop_at_jir and hasattr(child, '_jir'): + self.memo[id(child)] = self.add_jir(child) + state.i += 1 + continue + else: + def get_vars(bindings): + if isinstance(bindings, dict): + bindings = bindings.items() + return [var for (var, _) in bindings] + eval_b = get_vars(x.bindings(i)) + agg_b = get_vars(x.agg_bindings(i)) + scan_b = get_vars(x.scan_bindings(i)) + new_scope = Scope(depth) + state.scopes.append(new_scope) + + if x.new_block(i): + child_outermost_scope = depth + else: + child_outermost_scope = state.outermost_scope + + child_context = x.child_context_without_bindings(i, state.context) + if x.binds(i): + (eval_c, agg_c, scan_c) = child_context + eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in eval_b]) + agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in agg_b]) + scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in scan_b]) + child_context = (eval_c, agg_c, scan_c) + + if isinstance(child, ir.Ref): + state.scopes.pop() + child_free_vars = {child.name: child_context[0][child.name]} + for var in [*eval_b, *agg_b, *scan_b]: + child_free_vars.pop(var, 0) + state.free_vars.update(child_free_vars) + state.i += 1 + continue + else: + pc = PostChildren(state, child_outermost_scope, eval_b, agg_b, scan_b, new_scope.lifted_lets.values()) + state = LoopState(state.scopes, child_outermost_scope, child_context, child, pc) + continue From e0c76adf18cd93b4dddbdaeaa1309abf109c75fa Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Fri, 2 Aug 2019 15:11:21 -0400 Subject: [PATCH 10/47] cleaning up --- hail/python/hail/ir/renderer.py | 237 +++++++++++++++---------------- hail/python/test/hail/test_ir.py | 1 + 2 files changed, 116 insertions(+), 122 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index fcc2937e58e..cdb91e159ac 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -143,38 +143,31 @@ def __init__(self, depth: int): self.lifted_lets: Dict[int, (str, 'ir.BaseIR')] = {} self.let_bodies: List[str] = [] -class LoopState: - def __init__(self, scopes, outermost_scope, context, x, k): - # immutable - self.outermost_scope = outermost_scope - self.context = context - self.x = x - self.k = k - # mutable - self.scopes = scopes - self.free_vars = {} - self.i = 0 - -class Kont: - pass - -class PostChildren(Kont): - def __init__(self, state, child_outermost_scope, eval_b, agg_b, scan_b, new_bindings): - self.state = state - self.child_outermost_scope = child_outermost_scope - self.eval_b = eval_b - self.agg_b = agg_b - self.scan_b = scan_b - self.new_bindings = new_bindings - -class BeginSecondPass(Kont): - def __init__(self, root_scope, root): - self.root_scope = root_scope - self.root = root Vars = Dict[str, int] Context = (Vars, Vars, Vars) + +class StackFrame: + def __init__(self, outermost_scope: int, context: Context, x: 'ir.BaseIR'): + # immutable + self.parent_outermost_scope = outermost_scope + self.parent_context = context + self.parent = x + # mutable + self.parent_free_vars = {} + self.visited: Dict[int, 'ir.BaseIR'] = {} + self.lifted_lets: Dict[int, (str, 'ir.BaseIR')] = {} + self.let_bodies: List[str] = [] + # cached state while visiting a child + self.child_idx = 0 + self.child_outermost_scope = None + self.eval_b = None + self.agg_b = None + self.scan_b = None + self.new_bindings = None + + class CSERenderer(Renderer): def __init__(self, stop_at_jir=False): self.stop_at_jir = stop_at_jir @@ -298,108 +291,108 @@ def print_renderable(self, builder: List[str], context: List[Scope], outermost_s return ir_child_num - def __call__(self, x: 'BaseIR') -> str: + def __call__(self, x: 'ir.BaseIR') -> str: x.typ - root_scope = Scope(0) - - k = BeginSecondPass(root_scope, x) - state = LoopState([root_scope], 0, ({}, {}, {}), x, k) + state = StackFrame(0, ({}, {}, {}), x) + kont_stack = [state] while True: - i = state.i - x = state.x + state = kont_stack[-1] + depth = len(kont_stack) + i = state.child_idx + x = state.parent + if i >= len(x.children): - kont = state.k - free_vars = state.free_vars - - if isinstance(kont, PostChildren): - child = kont.state.x.children[kont.state.i] - if len(free_vars) > 0: - bind_depth = max(free_vars.values()) - bind_depth = max(bind_depth, kont.child_outermost_scope) - else: - bind_depth = kont.child_outermost_scope - kont.state.scopes[bind_depth].visited[id(child)] = child - new_scope = kont.state.scopes[-1] - kont.state.scopes.pop() - new_scope.visited.clear() - if new_scope.lifted_lets: - self.scopes[id(child)] = new_scope - for var in [*kont.eval_b, *kont.agg_b, *kont.scan_b]: - free_vars.pop(var, 0) - for (var, _) in kont.new_bindings: - free_vars.pop(var, 0) - kont.state.free_vars.update(free_vars) - state = kont.state - state.i += 1 - continue + if len(kont_stack) <= 1: + break + + child_state = state + + if len(child_state.parent_free_vars) > 0: + bind_depth = max(child_state.parent_free_vars.values()) + bind_depth = max(bind_depth, child_state.parent_outermost_scope) else: - assert isinstance(kont, BeginSecondPass) - kont.root_scope.visited = {} - if len(free_vars) != 0: - print('...') - for (var, _) in kont.root_scope.lifted_lets.values(): - var_depth = free_vars.pop(var, 0) - assert var_depth == 0 - assert(len(free_vars) == 0) - self.scopes[id(kont.root)] = kont.root_scope - builder = [] - self.print(builder, [], 0, 1, kont.root) - return ''.join(builder) + bind_depth = child_state.parent_outermost_scope + kont_stack[bind_depth].visited[id(x)] = x + + kont_stack.pop() + state = kont_stack[-1] + depth = len(kont_stack) + + if child_state.lifted_lets: + new_scope = Scope(depth) + new_scope.lifted_lets = child_state.lifted_lets + self.scopes[id(child_state.parent)] = new_scope + for var in [*state.eval_b, *state.agg_b, *state.scan_b]: + child_state.parent_free_vars.pop(var, 0) + for (var, _) in child_state.lifted_lets.values(): + child_state.parent_free_vars.pop(var, 0) + state.parent_free_vars.update(child_state.parent_free_vars) + state.child_idx += 1 + continue child = x.children[i] - depth = len(state.scopes) - seen_in_scope = self.find_in_scope(child, state.scopes, state.outermost_scope) - if seen_in_scope >= 0: + + if self.stop_at_jir and hasattr(child, '_jir'): + self.memo[id(child)] = self.add_jir(child) + state.child_idx += 1 + continue + + seen_in_scope = self.find_in_scope(child, kont_stack, state.parent_outermost_scope) + + if seen_in_scope >= 0 and isinstance(child, ir.IR): # we've seen 'child' before, should not traverse (or we will find # too many lifts) - if isinstance(child, ir.IR): - if id(child) not in state.scopes[seen_in_scope].lifted_lets: - # second time we've seen 'child', lift to a let - uid = self.uid() - state.scopes[seen_in_scope].lifted_lets[id(child)] = (uid, child) - else: - (uid, _) = state.scopes[seen_in_scope].lifted_lets[id(child)] - state.free_vars[uid] = seen_in_scope - state.i += 1 - continue - elif self.stop_at_jir and hasattr(child, '_jir'): - self.memo[id(child)] = self.add_jir(child) - state.i += 1 + if id(child) not in kont_stack[seen_in_scope].lifted_lets: + # second time we've seen 'child', lift to a let + uid = self.uid() + kont_stack[seen_in_scope].lifted_lets[id(child)] = (uid, child) + else: + (uid, _) = kont_stack[seen_in_scope].lifted_lets[id(child)] + state.parent_free_vars[uid] = seen_in_scope + state.child_idx += 1 continue + + def get_vars(bindings): + if isinstance(bindings, dict): + bindings = bindings.items() + return [var for (var, _) in bindings] + state.eval_b = get_vars(x.bindings(i)) + state.agg_b = get_vars(x.agg_bindings(i)) + state.scan_b = get_vars(x.scan_bindings(i)) + + if x.new_block(i): + state.child_outermost_scope = depth else: - def get_vars(bindings): - if isinstance(bindings, dict): - bindings = bindings.items() - return [var for (var, _) in bindings] - eval_b = get_vars(x.bindings(i)) - agg_b = get_vars(x.agg_bindings(i)) - scan_b = get_vars(x.scan_bindings(i)) - new_scope = Scope(depth) - state.scopes.append(new_scope) - - if x.new_block(i): - child_outermost_scope = depth - else: - child_outermost_scope = state.outermost_scope - - child_context = x.child_context_without_bindings(i, state.context) - if x.binds(i): - (eval_c, agg_c, scan_c) = child_context - eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in eval_b]) - agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in agg_b]) - scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in scan_b]) - child_context = (eval_c, agg_c, scan_c) - - if isinstance(child, ir.Ref): - state.scopes.pop() - child_free_vars = {child.name: child_context[0][child.name]} - for var in [*eval_b, *agg_b, *scan_b]: - child_free_vars.pop(var, 0) - state.free_vars.update(child_free_vars) - state.i += 1 - continue - else: - pc = PostChildren(state, child_outermost_scope, eval_b, agg_b, scan_b, new_scope.lifted_lets.values()) - state = LoopState(state.scopes, child_outermost_scope, child_context, child, pc) - continue + state.child_outermost_scope = state.parent_outermost_scope + + child_context = x.child_context_without_bindings(i, state.parent_context) + if x.binds(i): + (eval_c, agg_c, scan_c) = child_context + eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in state.eval_b]) + agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in state.agg_b]) + scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in state.scan_b]) + child_context = (eval_c, agg_c, scan_c) + + if isinstance(child, ir.Ref): + child_free_vars = {child.name: child_context[0][child.name]} + for var in [*state.eval_b, *state.agg_b, *state.scan_b]: + child_free_vars.pop(var, 0) + state.parent_free_vars.update(child_free_vars) + state.child_idx += 1 + continue + + state = StackFrame(state.child_outermost_scope, child_context, child) + kont_stack.append(state) + continue + + for (var, _) in state.lifted_lets.values(): + var_depth = state.parent_free_vars.pop(var, 0) + assert var_depth == 0 + assert(len(state.parent_free_vars) == 0) + root_scope = Scope(0) + root_scope.lifted_lets = state.lifted_lets + self.scopes[id(x)] = root_scope + builder = [] + self.print(builder, [], 0, 1, x) + return ''.join(builder) diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index 7928f17ece6..34792cf08ca 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -354,6 +354,7 @@ def test_value_same_after_parsing(self): new_globals = hl.eval(hl.Table(map_globals_ir).index_globals()) self.assertEqual(new_globals, hl.Struct(foo=v)) + class CSETests(unittest.TestCase): def test_cse(self): x = ir.I32(5) From e12b7e552e07afc6a9cf2877f1bf04cd3ac86422 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 7 Aug 2019 08:50:49 -0400 Subject: [PATCH 11/47] some refactoring --- hail/python/hail/ir/renderer.py | 220 ++++++++++++++++++-------------- 1 file changed, 122 insertions(+), 98 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index cdb91e159ac..fbeb4dbeb15 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -1,6 +1,6 @@ from hail import ir import abc -from typing import Sequence, List, Set, Dict, Callable +from typing import Sequence, List, Set, Dict, Callable, Tuple class Renderable(object): @@ -136,12 +136,18 @@ def __call__(self, x: 'Renderable'): return ''.join(builder) -class Scope: - def __init__(self, depth: int): +class BindingSite: + def __init__(self, lifted_lets: Dict[int, Tuple[str, 'ir.BaseIR']], depth: int): self.depth = depth - self.visited: Dict[int, 'ir.BaseIR'] = {} - self.lifted_lets: Dict[int, (str, 'ir.BaseIR')] = {} - self.let_bodies: List[str] = [] + self.lifted_lets = lifted_lets + + +class PrintStackFrame: + def __init__(self, binding_site: BindingSite): + self.depth = binding_site.depth + self.lifted_lets = binding_site.lifted_lets + self.visited = {} + self.let_bodies = [] Vars = Dict[str, int] @@ -149,23 +155,48 @@ def __init__(self, depth: int): class StackFrame: - def __init__(self, outermost_scope: int, context: Context, x: 'ir.BaseIR'): + def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR'): # immutable - self.parent_outermost_scope = outermost_scope - self.parent_context = context - self.parent = x + self.min_binding_depth = min_binding_depth + self.context = context + self.node = x # mutable - self.parent_free_vars = {} + self.free_vars = {} self.visited: Dict[int, 'ir.BaseIR'] = {} self.lifted_lets: Dict[int, (str, 'ir.BaseIR')] = {} - self.let_bodies: List[str] = [] # cached state while visiting a child self.child_idx = 0 - self.child_outermost_scope = None - self.eval_b = None - self.agg_b = None - self.scan_b = None - self.new_bindings = None + self.child_eval_bindings = None + self.child_agg_bindings = None + self.child_scan_bindings = None + + # compute depth at which we might bind this node + def bind_depth(self) -> int: + if len(self.free_vars) > 0: + bind_depth = max(self.free_vars.values()) + bind_depth = max(bind_depth, self.min_binding_depth) + else: + bind_depth = self.min_binding_depth + return bind_depth + + def make_child_frame(self, depth: int) -> 'StackFrame': + x = self.node + i = self.child_idx - 1 + child = x.children[i] + if x.new_block(i): + child_outermost_scope = depth + else: + child_outermost_scope = self.min_binding_depth + + child_context = x.child_context_without_bindings(i, self.context) + if x.binds(i): + (eval_c, agg_c, scan_c) = child_context + eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in self.child_eval_bindings]) + agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in self.child_agg_bindings]) + scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in self.child_scan_bindings]) + child_context = (eval_c, agg_c, scan_c) + + return StackFrame(child_outermost_scope, child_context, child) class CSERenderer(Renderer): @@ -175,7 +206,7 @@ def __init__(self, stop_at_jir=False): self.jirs = {} self.memo: Dict[int, Sequence[str]] = {} self.uid_count = 0 - self.scopes: Dict[int, Scope] = {} + self.binding_sites: Dict[int, BindingSite] = {} def uid(self) -> str: self.uid_count += 1 @@ -196,14 +227,14 @@ def add_jir(self, x: 'ir.BaseIR'): return f'(JavaIR {jir_id})' @staticmethod - def find_in_scope(x: 'ir.BaseIR', context: List[Scope], outermost_scope: int) -> int: + def find_in_scope(x: 'ir.BaseIR', context: List[StackFrame], outermost_scope: int) -> int: for i in range(len(context) - 1, outermost_scope - 1, -1): if id(x) in context[i].visited: return i return -1 @staticmethod - def lifted_in_scope(x: 'ir.BaseIR', context: List[Scope]) -> int: + def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int: for i in range(len(context) - 1, -1, -1): if id(x) in context[i].lifted_lets: return i @@ -224,14 +255,14 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[Scope]) -> int: # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing # any lets to be inserted above y. - def print(self, builder: List[str], context: List[Scope], outermost_scope: int, depth: int, x: 'ir.BaseIR'): + def print(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, x: 'ir.BaseIR'): if id(x) in self.memo: builder.append(self.memo[id(x)]) return - insert_lets = id(x) in self.scopes and len(self.scopes[id(x)].lifted_lets) > 0 + insert_lets = id(x) in self.binding_sites and len(self.binding_sites[id(x)].lifted_lets) > 0 if insert_lets: local_builder = [] - context.append(self.scopes[id(x)]) + context.append(PrintStackFrame(self.binding_sites[id(x)])) else: local_builder = builder head = x.render_head(self) @@ -265,15 +296,16 @@ def print(self, builder: List[str], context: List[Scope], outermost_scope: int, ir_child_num = self.print_renderable(local_builder, context, child_outermost_scope, depth + 1, ir_child_num, child) local_builder.append(x.render_tail(self)) if insert_lets: + sf = context[-1] context.pop() - for let_body in self.scopes[id(x)].let_bodies: + for let_body in sf.let_bodies: builder.extend(let_body) builder.extend(local_builder) - num_lets = len(self.scopes[id(x)].lifted_lets) + num_lets = len(sf.lifted_lets) for i in range(num_lets): builder.append(')') - def print_renderable(self, builder: List[str], context: List[Scope], outermost_scope: int, depth: int, ir_child_num: int, x: Renderable): + def print_renderable(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, ir_child_num: int, x: Renderable): assert not isinstance(x, ir.BaseIR) head = x.render_head(self) if head != '': @@ -290,109 +322,101 @@ def print_renderable(self, builder: List[str], context: List[Scope], outermost_s builder.append(x.render_tail(self)) return ir_child_num + def compute_new_bindings(self, root: 'ir.BaseIR'): + # force computation of all types + root.typ - def __call__(self, x: 'ir.BaseIR') -> str: - x.typ + root_frame = StackFrame(0, ({}, {}, {}), root) + stack = [root_frame] + binding_sites = {} - state = StackFrame(0, ({}, {}, {}), x) - kont_stack = [state] while True: - state = kont_stack[-1] - depth = len(kont_stack) - i = state.child_idx - x = state.parent + frame = stack[-1] + node = frame.node + child_idx = frame.child_idx + frame.child_idx += 1 - if i >= len(x.children): - if len(kont_stack) <= 1: + if child_idx >= len(node.children): + if len(stack) <= 1: break - child_state = state - - if len(child_state.parent_free_vars) > 0: - bind_depth = max(child_state.parent_free_vars.values()) - bind_depth = max(bind_depth, child_state.parent_outermost_scope) - else: - bind_depth = child_state.parent_outermost_scope - kont_stack[bind_depth].visited[id(x)] = x - - kont_stack.pop() - state = kont_stack[-1] - depth = len(kont_stack) - - if child_state.lifted_lets: - new_scope = Scope(depth) - new_scope.lifted_lets = child_state.lifted_lets - self.scopes[id(child_state.parent)] = new_scope - for var in [*state.eval_b, *state.agg_b, *state.scan_b]: - child_state.parent_free_vars.pop(var, 0) - for (var, _) in child_state.lifted_lets.values(): - child_state.parent_free_vars.pop(var, 0) - state.parent_free_vars.update(child_state.parent_free_vars) - state.child_idx += 1 + parent_frame = stack[-2] + + # mark node as visited at potential let insertion site + stack[frame.bind_depth()].visited[id(node)] = node + + # if any lets being inserted here, add node to registry of + # binding sites + if frame.lifted_lets: + binding_sites[id(node)] = \ + BindingSite(frame.lifted_lets, len(stack)) + + # subtract vars bound by parent from free_vars + for var in [*parent_frame.child_eval_bindings, + *parent_frame.child_agg_bindings, + *parent_frame.child_scan_bindings]: + frame.free_vars.pop(var, 0) + # subtract vars that will be bound by inserted lets + for (var, _) in frame.lifted_lets.values(): + frame.free_vars.pop(var, 0) + # update parent's free variables + parent_frame.free_vars.update(frame.free_vars) + stack.pop() continue - child = x.children[i] + child = node.children[child_idx] if self.stop_at_jir and hasattr(child, '_jir'): self.memo[id(child)] = self.add_jir(child) - state.child_idx += 1 continue - seen_in_scope = self.find_in_scope(child, kont_stack, state.parent_outermost_scope) + seen_in_scope = self.find_in_scope(child, stack, frame.min_binding_depth) if seen_in_scope >= 0 and isinstance(child, ir.IR): # we've seen 'child' before, should not traverse (or we will find # too many lifts) - if id(child) not in kont_stack[seen_in_scope].lifted_lets: + if id(child) in stack[seen_in_scope].lifted_lets: + (uid, _) = stack[seen_in_scope].lifted_lets[id(child)] + else: # second time we've seen 'child', lift to a let uid = self.uid() - kont_stack[seen_in_scope].lifted_lets[id(child)] = (uid, child) - else: - (uid, _) = kont_stack[seen_in_scope].lifted_lets[id(child)] - state.parent_free_vars[uid] = seen_in_scope - state.child_idx += 1 + stack[seen_in_scope].lifted_lets[id(child)] = (uid, child) + # Since we are not traversing 'child', we don't know its free + # variables. To prevent a parent from being lifted too high, + # we must register 'child' as having the free variable 'uid', + # which will be true when 'child' is replaced by "Ref uid". + frame.free_vars[uid] = seen_in_scope continue + # first time visiting 'child' + # compute vars bound in 'child' by 'node' def get_vars(bindings): if isinstance(bindings, dict): bindings = bindings.items() return [var for (var, _) in bindings] - state.eval_b = get_vars(x.bindings(i)) - state.agg_b = get_vars(x.agg_bindings(i)) - state.scan_b = get_vars(x.scan_bindings(i)) - - if x.new_block(i): - state.child_outermost_scope = depth - else: - state.child_outermost_scope = state.parent_outermost_scope - - child_context = x.child_context_without_bindings(i, state.parent_context) - if x.binds(i): - (eval_c, agg_c, scan_c) = child_context - eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in state.eval_b]) - agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in state.agg_b]) - scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in state.scan_b]) - child_context = (eval_c, agg_c, scan_c) + frame.child_eval_bindings = get_vars(node.bindings(child_idx)) + frame.child_agg_bindings = get_vars(node.agg_bindings(child_idx)) + frame.child_scan_bindings = get_vars(node.scan_bindings(child_idx)) if isinstance(child, ir.Ref): - child_free_vars = {child.name: child_context[0][child.name]} - for var in [*state.eval_b, *state.agg_b, *state.scan_b]: - child_free_vars.pop(var, 0) - state.parent_free_vars.update(child_free_vars) - state.child_idx += 1 + if child.name not in frame.child_eval_bindings: + (eval_c, _, _) = node.child_context_without_bindings(child_idx, frame.context) + frame.free_vars.update({child.name: eval_c[child.name]}) continue - state = StackFrame(state.child_outermost_scope, child_context, child) - kont_stack.append(state) + stack.append(frame.make_child_frame(len(stack))) continue - for (var, _) in state.lifted_lets.values(): - var_depth = state.parent_free_vars.pop(var, 0) + for (var, _) in root_frame.lifted_lets.values(): + var_depth = root_frame.free_vars.pop(var, 0) assert var_depth == 0 - assert(len(state.parent_free_vars) == 0) - root_scope = Scope(0) - root_scope.lifted_lets = state.lifted_lets - self.scopes[id(x)] = root_scope + assert(len(root_frame.free_vars) == 0) + + binding_sites[id(root)] = BindingSite(frame.lifted_lets, 0) + return binding_sites + + def __call__(self, root: 'ir.BaseIR') -> str: + self.binding_sites = self.compute_new_bindings(root) builder = [] - self.print(builder, [], 0, 1, x) + self.print(builder, [], 0, 1, root) return ''.join(builder) From ed6ce8104248f168cb22557c313b338796a491f8 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 7 Aug 2019 09:56:43 -0400 Subject: [PATCH 12/47] more refactoring --- hail/python/hail/expr/matrix_type.py | 30 +++++++---- hail/python/hail/expr/table_type.py | 11 ++-- hail/python/hail/ir/base_ir.py | 19 +++---- hail/python/hail/ir/ir.py | 8 +-- hail/python/hail/ir/matrix_ir.py | 38 +++++++------- hail/python/hail/ir/renderer.py | 75 ++++++++++++++++------------ hail/python/hail/ir/table_ir.py | 18 +++---- 7 files changed, 114 insertions(+), 85 deletions(-) diff --git a/hail/python/hail/expr/matrix_type.py b/hail/python/hail/expr/matrix_type.py index d63010489d5..3f4144d9a62 100644 --- a/hail/python/hail/expr/matrix_type.py +++ b/hail/python/hail/expr/matrix_type.py @@ -122,22 +122,34 @@ def _rename(self, global_map, col_map, row_map, entry_map): [row_map.get(k, k) for k in self.row_key], self.entry_type._rename(entry_map)) + def global_bindings(self): + return [('global', self.global_type)] + def global_env(self): - return {'global': self.global_type} + return dict(self.global_bindings()) + + def row_bindings(self): + return [('global', self.global_type), + ('va', self.row_type)] def row_env(self): - return {'global': self.global_type, - 'va': self.row_type} + return dict(self.row_bindings()) + + def col_bindings(self): + return [('global', self.global_type), + ('sa', self.col_type)] def col_env(self): - return {'global': self.global_type, - 'sa': self.col_type} + return dict(self.col_bindings()) + + def entry_bindings(self): + return [('global', self.global_type), + ('va', self.row_type), + ('sa', self.col_type), + ('g', self.entry_type)] def entry_env(self): - return {'global': self.global_type, - 'va': self.row_type, - 'sa': self.col_type, - 'g': self.entry_type} + return dict(self.entry_bindings()) import pprint diff --git a/hail/python/hail/expr/table_type.py b/hail/python/hail/expr/table_type.py index 705dea5120b..b908edcd792 100644 --- a/hail/python/hail/expr/table_type.py +++ b/hail/python/hail/expr/table_type.py @@ -81,11 +81,16 @@ def _rename(self, global_map, row_map): [row_map.get(k, k) for k in self.row_key]) def row_env(self): - return {'global': self.global_type, - 'row': self.row_type} + return dict(self.row_bindings()) + + def row_bindings(self): + return [('global', self.global_type), ('row', self.row_type)] def global_env(self): - return {'global': self.global_type} + return dict(self.global_bindings()) + + def global_bindings(self): + return [('global', self.global_type)] import pprint diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index bb2483c5cf3..21e50d30cbd 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -1,8 +1,9 @@ import abc -from typing import List, Set, Optional +from typing import List, Tuple from hail.utils.java import Env +from hail.expr.types import HailType from .renderer import Renderer, Renderable, RenderableStr @@ -84,25 +85,25 @@ def _eq(self, other): def __hash__(self): return 31 + hash(str(self)) - def new_block(self, i: int): + def new_block(self, i: int) -> bool: return self.uses_agg_context(i) or self.uses_scan_context(i) - def binds(self, i: int): + def binds(self, i: int) -> bool: return False - def bindings(self, i: int): + def bindings(self, i: int) -> List[Tuple[str, HailType]]: return [] - def agg_bindings(self, i: int): + def agg_bindings(self, i: int) -> List[Tuple[str, HailType]]: return [] - def scan_bindings(self, i: int): + def scan_bindings(self, i: int) -> List[Tuple[str, HailType]]: return [] - def uses_agg_context(self, i: int): + def uses_agg_context(self, i: int) -> bool: return False - def uses_scan_context(self, i: int): + def uses_scan_context(self, i: int) -> bool: return False def child_context_without_bindings(self, i: int, parent_context): @@ -121,7 +122,7 @@ def child_context(self, i: int, parent_context, default_value=None): eval_b = self.bindings(i) agg_b = self.agg_bindings(i) scan_b = self.scan_bindings(i) - if default_value: + if default_value is not None: eval_b = [(var, default_value) for (var, _) in eval_b] agg_b = [(var, default_value) for (var, _) in agg_b] scan_b = [(var, default_value) for (var, _) in scan_b] diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 8c68d0c0bd1..1114f641d18 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -1798,10 +1798,10 @@ def new_block(self, i: int): return i == 1 def bindings(self, i): - return self.child.typ.global_env() if i == 1 else [] + return self.child.typ.global_bindings() if i == 1 else [] def agg_bindings(self, i): - return self.child.typ.row_env() if i == 1 else [] + return self.child.typ.row_bindings() if i == 1 else [] def binds(self, i): return i == 1 @@ -1831,10 +1831,10 @@ def new_block(self, i: int): return i == 1 def bindings(self, i): - return self.child.typ.global_env() if i == 1 else [] + return self.child.typ.global_bindings() if i == 1 else [] def agg_bindings(self, i): - return self.child.typ.entry_env() if i == 1 else [] + return self.child.typ.entry_bindings() if i == 1 else [] def binds(self, i): return i == 1 diff --git a/hail/python/hail/ir/matrix_ir.py b/hail/python/hail/ir/matrix_ir.py index 6bb4ba7d264..485477c2d1f 100644 --- a/hail/python/hail/ir/matrix_ir.py +++ b/hail/python/hail/ir/matrix_ir.py @@ -28,17 +28,17 @@ def _compute_type(self): def bindings(self, i): if i == 1: - return self.child.typ.col_env() + return self.child.typ.col_bindings() elif i == 2: - return self.child.typ.global_env() + return self.child.typ.global_bindings() else: return [] def agg_bindings(self, i): if i == 1: - return self.child.typ.entry_env() + return self.child.typ.entry_bindings() elif i == 2: - return self.child.typ.row_env() + return self.child.typ.row_bindings() else: return [] @@ -78,7 +78,7 @@ def _compute_type(self): self._type = self.child.typ def bindings(self, i): - return self.child.typ.row_env() if i == 1 else [] + return self.child.typ.row_bindings() if i == 1 else [] def binds(self, i): return i == 1 @@ -129,13 +129,13 @@ def _compute_type(self): child_typ.entry_type) def bindings(self, i): - return self.child.typ.col_env() if i == 1 else [] + return self.child.typ.col_bindings() if i == 1 else [] def agg_bindings(self, i): - return self.child.typ.entry_env() if i == 1 else [] + return self.child.typ.entry_bindings() if i == 1 else [] def scan_bindings(self, i): - return self.child.typ.col_env() if i == 1 else [] + return self.child.typ.col_bindings() if i == 1 else [] def binds(self, i): return i == 1 @@ -174,7 +174,7 @@ def _compute_type(self): self.new_entry.typ) def bindings(self, i): - return self.child.typ.entry_env() if i == 1 else [] + return self.child.typ.entry_bindings() if i == 1 else [] def binds(self, i): return i == 1 @@ -195,7 +195,7 @@ def _compute_type(self): self._type = self.child.typ def bindings(self, i): - return self.child.typ.entry_env() if i == 1 else [] + return self.child.typ.entry_bindings() if i == 1 else [] def binds(self, i): return i == 1 @@ -249,13 +249,13 @@ def _compute_type(self): child_typ.entry_type) def bindings(self, i): - return self.child.typ.row_env() if i == 1 else [] + return self.child.typ.row_bindings() if i == 1 else [] def agg_bindings(self, i): - return self.child.typ.entry_env() if i == 1 else [] + return self.child.typ.entry_bindings() if i == 1 else [] def scan_bindings(self, i): - return self.child.typ.row_env() if i == 1 else [] + return self.child.typ.row_bindings() if i == 1 else [] def binds(self, i): return i == 1 @@ -283,7 +283,7 @@ def _compute_type(self): child_typ.entry_type) def bindings(self, i): - return self.child.typ.global_env() if i == 1 else [] + return self.child.typ.global_bindings() if i == 1 else [] def binds(self, i): return i == 1 @@ -304,7 +304,7 @@ def _compute_type(self): self._type = self.child.typ def bindings(self, i): - return self.child.typ.col_env() if i == 1 else [] + return self.child.typ.col_bindings() if i == 1 else [] def binds(self, i): return i == 1 @@ -352,17 +352,17 @@ def _compute_type(self): def bindings(self, i): if i == 1: - return self.child.typ.row_env() + return self.child.typ.row_bindingss() elif i == 2: - return self.child.typ.global_env() + return self.child.typ.global_bindings() else: return [] def agg_bindings(self, i): if i == 1: - return self.child.typ.entry_env() + return self.child.typ.entry_bindings() elif i == 2: - return self.child.typ.col_env() + return self.child.typ.col_bindings() else: return [] diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index fbeb4dbeb15..7dbd905c96e 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -155,20 +155,19 @@ def __init__(self, binding_site: BindingSite): class StackFrame: - def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR'): + def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR', + new_bindings=(None, None, None)): # immutable self.min_binding_depth = min_binding_depth self.context = context self.node = x + (self._parent_eval_bindings, self._parent_agg_bindings, + self._parent_scan_bindings) = new_bindings # mutable self.free_vars = {} self.visited: Dict[int, 'ir.BaseIR'] = {} self.lifted_lets: Dict[int, (str, 'ir.BaseIR')] = {} - # cached state while visiting a child self.child_idx = 0 - self.child_eval_bindings = None - self.child_agg_bindings = None - self.child_scan_bindings = None # compute depth at which we might bind this node def bind_depth(self) -> int: @@ -188,15 +187,40 @@ def make_child_frame(self, depth: int) -> 'StackFrame': else: child_outermost_scope = self.min_binding_depth + # compute vars bound in 'child' by 'node' + def get_vars(bindings): + return [var for (var, _) in bindings] + eval_bindings = get_vars(x.bindings(i)) + agg_bindings = get_vars(x.agg_bindings(i)) + scan_bindings = get_vars(x.scan_bindings(i)) + child_context = x.child_context_without_bindings(i, self.context) if x.binds(i): (eval_c, agg_c, scan_c) = child_context - eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in self.child_eval_bindings]) - agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in self.child_agg_bindings]) - scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in self.child_scan_bindings]) + eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in + eval_bindings]) + agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in + agg_bindings]) + scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in + scan_bindings]) child_context = (eval_c, agg_c, scan_c) - return StackFrame(child_outermost_scope, child_context, child) + new_bindings = (eval_bindings, agg_bindings, scan_bindings) + + return StackFrame(child_outermost_scope, child_context, child, + new_bindings) + + def update_parent_free_vars(self, parent_free_vars: Set[str]): + # subtract vars bound by parent from free_vars + for var in [*self._parent_eval_bindings, + *self._parent_agg_bindings, + *self._parent_scan_bindings]: + self.free_vars.pop(var, 0) + # subtract vars that will be bound by inserted lets + for (var, _) in self.lifted_lets.values(): + self.free_vars.pop(var, 0) + # update parent's free variables + parent_free_vars.update(self.free_vars) class CSERenderer(Renderer): @@ -351,16 +375,8 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): binding_sites[id(node)] = \ BindingSite(frame.lifted_lets, len(stack)) - # subtract vars bound by parent from free_vars - for var in [*parent_frame.child_eval_bindings, - *parent_frame.child_agg_bindings, - *parent_frame.child_scan_bindings]: - frame.free_vars.pop(var, 0) - # subtract vars that will be bound by inserted lets - for (var, _) in frame.lifted_lets.values(): - frame.free_vars.pop(var, 0) - # update parent's free variables - parent_frame.free_vars.update(frame.free_vars) + frame.update_parent_free_vars(parent_frame.free_vars) + stack.pop() continue @@ -370,11 +386,12 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): self.memo[id(child)] = self.add_jir(child) continue - seen_in_scope = self.find_in_scope(child, stack, frame.min_binding_depth) + seen_in_scope = self.find_in_scope(child, stack, + frame.min_binding_depth) if seen_in_scope >= 0 and isinstance(child, ir.IR): - # we've seen 'child' before, should not traverse (or we will find - # too many lifts) + # we've seen 'child' before, should not traverse (or we will + # find too many lifts) if id(child) in stack[seen_in_scope].lifted_lets: (uid, _) = stack[seen_in_scope].lifted_lets[id(child)] else: @@ -389,18 +406,12 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): continue # first time visiting 'child' - # compute vars bound in 'child' by 'node' - def get_vars(bindings): - if isinstance(bindings, dict): - bindings = bindings.items() - return [var for (var, _) in bindings] - frame.child_eval_bindings = get_vars(node.bindings(child_idx)) - frame.child_agg_bindings = get_vars(node.agg_bindings(child_idx)) - frame.child_scan_bindings = get_vars(node.scan_bindings(child_idx)) if isinstance(child, ir.Ref): - if child.name not in frame.child_eval_bindings: - (eval_c, _, _) = node.child_context_without_bindings(child_idx, frame.context) + if child.name not in (var for (var, _) in + node.bindings(child_idx)): + (eval_c, _, _) = node.child_context_without_bindings( + child_idx, frame.context) frame.free_vars.update({child.name: eval_c[child.name]}) continue diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index 825a092835d..23b40aebe49 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -140,7 +140,7 @@ def new_block(i): return i == 1 def bindings(self, i): - return self.child.typ.global_env() if i == 1 else [] + return self.child.typ.global_bindings() if i == 1 else [] def binds(self, i): return i == 1 @@ -203,10 +203,10 @@ def _compute_type(self): self.child.typ.row_key) def bindings(self, i): - return self.child.typ.row_env() if i == 1 else [] + return self.child.typ.row_bindings() if i == 1 else [] def scan_bindings(self, i): - return self.child.typ.row_env() if i == 1 else [] + return self.child.typ.row_bindings() if i == 1 else [] def binds(self, i): return i == 1 @@ -277,7 +277,7 @@ def _compute_type(self): self._type = self.child.typ def bindings(self, i): - return self.child.typ.row_env() if i == 1 else [] + return self.child.typ.row_bindings() if i == 1 else [] def binds(self, i): return i == 1 @@ -311,14 +311,14 @@ def _compute_type(self): def bindings(self, i): if i == 1: - return self.child.typ.global_env() + return self.child.typ.global_bindings() elif i == 2: - return self.child.typ.row_env() + return self.child.typ.row_bindings() else: return [] def agg_bindings(self, i): - return self.child.typ.row_env() if i == 1 else [] + return self.child.typ.row_bindings() if i == 1 else [] def binds(self, i): return i == 1 or i == 2 @@ -342,10 +342,10 @@ def _compute_type(self): child_typ.row_key) def bindings(self, i): - return self.child.typ.row_env() if i == 1 else [] + return self.child.typ.row_bindings() if i == 1 else [] def agg_bindings(self, i): - return self.child.typ.row_env() if i == 1 else [] + return self.child.typ.row_bindings() if i == 1 else [] def binds(self, i): return i == 1 From 654f219200452e41487904c2a11dad9c17b2fb85 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 7 Aug 2019 13:04:47 -0400 Subject: [PATCH 13/47] second pass: loop -> recursion --- hail/python/hail/ir/renderer.py | 128 +++++++++++++++++++++----------- 1 file changed, 86 insertions(+), 42 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 7dbd905c96e..1d1435c1e51 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -279,6 +279,49 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int: # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing # any lets to be inserted above y. + class State: + ... + + def loop(self, state): + if state.i >= len(state.children): + return + state.builder.append(' ') + child = state.children[state.i] + lift_to = self.lifted_in_scope(child, state.context) + child_outermost_scope = state.outermost_scope + if state.x.new_block(state.ir_child_num): + child_outermost_scope = state.depth + if (lift_to >= 0 and + state.context[lift_to] and + state.context[lift_to].depth >= state.outermost_scope): + state.ir_child_num += 1 + (name, _) = state.context[lift_to].lifted_lets[id(child)] + if id(child) not in state.context[lift_to].visited: + state.context[lift_to].visited[id(child)] = child + let_body = [f'(Let {name} '] + self.print(let_body, state.context, child_outermost_scope, + state.depth + 1, child) + let_body.append(' ') + # let_bodies is built post-order, which guarantees earlier + # lets can't refer to later lets + state.context[lift_to].let_bodies.append(let_body) + state.builder.append(f'(Ref {name})') + state.i += 1 + self.loop(state) + else: + if isinstance(child, ir.BaseIR): + self.print(state.builder, state.context, + child_outermost_scope, state.depth + 1, child) + state.i += 1 + self.loop(state) + else: + state.ir_child_num = self.print_renderable( + state.builder, state.context, + child_outermost_scope, state.depth + 1, + state.ir_child_num, child) + state.i += 1 + self.loop(state) + def print(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, x: 'ir.BaseIR'): if id(x) in self.memo: builder.append(self.memo[id(x)]) @@ -292,59 +335,60 @@ def print(self, builder: List[str], context: List[PrintStackFrame], outermost_sc head = x.render_head(self) if head != '': local_builder.append(head) - children = x.render_children(self) - ir_child_num = 0 - for i in range(0, len(children)): - local_builder.append(' ') - child = children[i] - lift_to = self.lifted_in_scope(child, context) - child_outermost_scope = outermost_scope - if x.new_block(ir_child_num): - child_outermost_scope = depth - if lift_to >= 0 and context[lift_to] and context[lift_to].depth >= outermost_scope: - ir_child_num += 1 - (name, _) = context[lift_to].lifted_lets[id(child)] - if id(child) not in context[lift_to].visited: - context[lift_to].visited[id(child)] = child - let_body = [f'(Let {name} '] - self.print(let_body, context, child_outermost_scope, depth + 1, child) - let_body.append(' ') - # let_bodies is built post-order, which guarantees earlier - # lets can't refer to later lets - context[lift_to].let_bodies.append(let_body) - local_builder.append(f'(Ref {name})') - else: - if isinstance(child, ir.BaseIR): - self.print(local_builder, context, child_outermost_scope, depth + 1, child) - else: - ir_child_num = self.print_renderable(local_builder, context, child_outermost_scope, depth + 1, ir_child_num, child) - local_builder.append(x.render_tail(self)) + + state = self.State() + state.x = x + state.children = x.render_children(self) + state.builder = local_builder + state.context = context + state.outermost_scope = outermost_scope + state.ir_child_num = 0 + state.depth = depth + state.i = 0 + + self.loop(state) + + state.builder.append(state.x.render_tail(self)) if insert_lets: - sf = context[-1] - context.pop() + sf = state.context[-1] + state.context.pop() for let_body in sf.let_bodies: builder.extend(let_body) - builder.extend(local_builder) + builder.extend(state.builder) num_lets = len(sf.lifted_lets) - for i in range(num_lets): + for _ in range(num_lets): builder.append(')') + def loop2(self, state): + if state.i >= len(state.children): + return + state.builder.append(' ') + child = state.children[state.i] + if isinstance(child, ir.BaseIR): + self.print(state.builder, state.context, state.outermost_scope, state.depth, child) + state.ir_child_num += 1 + else: + state.ir_child_num = self.print_renderable(state.builder, state.context, state.outermost_scope, state.depth, state.ir_child_num, child) + state.i += 1 + self.loop2(state) + def print_renderable(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, ir_child_num: int, x: Renderable): assert not isinstance(x, ir.BaseIR) head = x.render_head(self) if head != '': builder.append(head) - children = x.render_children(self) - for i in range(0, len(children)): - builder.append(' ') - child = children[i] - if isinstance(child, ir.BaseIR): - self.print(builder, context, outermost_scope, depth, child) - ir_child_num += 1 - else: - ir_child_num = self.print_renderable(builder, context, outermost_scope, depth, ir_child_num, child) - builder.append(x.render_tail(self)) - return ir_child_num + state = self.State() + state.x = x + state.children = x.render_children(self) + state.builder = builder + state.context = context + state.outermost_scope = outermost_scope + state.ir_child_num = ir_child_num + state.depth = depth + state.i = 0 + self.loop2(state) + state.builder.append(state.x.render_tail(self)) + return state.ir_child_num def compute_new_bindings(self, root: 'ir.BaseIR'): # force computation of all types From f10306ea0ac6441e8364dcde45edeaea5791ba1e Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 8 Aug 2019 10:50:38 -0400 Subject: [PATCH 14/47] CPS --- hail/python/hail/ir/renderer.py | 132 ++++++++++++++++++-------------- 1 file changed, 73 insertions(+), 59 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 1d1435c1e51..ac6a7f43ec1 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -282,9 +282,9 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int: class State: ... - def loop(self, state): + def loop(self, state, k): if state.i >= len(state.children): - return + return k() state.builder.append(' ') child = state.children[state.i] lift_to = self.lifted_in_scope(child, state.context) @@ -299,80 +299,92 @@ def loop(self, state): if id(child) not in state.context[lift_to].visited: state.context[lift_to].visited[id(child)] = child let_body = [f'(Let {name} '] - self.print(let_body, state.context, child_outermost_scope, - state.depth + 1, child) - let_body.append(' ') - # let_bodies is built post-order, which guarantees earlier - # lets can't refer to later lets - state.context[lift_to].let_bodies.append(let_body) - state.builder.append(f'(Ref {name})') - state.i += 1 - self.loop(state) + + def post_lift_visit(): + let_body.append(' ') + # let_bodies is built post-order, which guarantees earlier + # lets can't refer to later lets + state.context[lift_to].let_bodies.append(let_body) + state.builder.append(f'(Ref {name})') + state.i += 1 + return self.loop(state, k) + + return self.print(let_body, state.context, child_outermost_scope, + state.depth + 1, child, post_lift_visit) + else: + state.builder.append(f'(Ref {name})') + state.i += 1 + return self.loop(state, k) else: - if isinstance(child, ir.BaseIR): - self.print(state.builder, state.context, - child_outermost_scope, state.depth + 1, child) + def post_visit(): state.i += 1 - self.loop(state) + return self.loop(state, k) + if isinstance(child, ir.BaseIR): + return self.print(state.builder, state.context, + child_outermost_scope, state.depth + 1, child, + post_visit) else: - state.ir_child_num = self.print_renderable( + return self.print_renderable( state.builder, state.context, child_outermost_scope, state.depth + 1, - state.ir_child_num, child) - state.i += 1 - self.loop(state) - - def print(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, x: 'ir.BaseIR'): - if id(x) in self.memo: - builder.append(self.memo[id(x)]) - return - insert_lets = id(x) in self.binding_sites and len(self.binding_sites[id(x)].lifted_lets) > 0 - if insert_lets: - local_builder = [] - context.append(PrintStackFrame(self.binding_sites[id(x)])) - else: - local_builder = builder - head = x.render_head(self) - if head != '': - local_builder.append(head) + state.ir_child_num, child, post_visit) + def print(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, x: 'ir.BaseIR', k): state = self.State() state.x = x state.children = x.render_children(self) - state.builder = local_builder + state.builder = builder state.context = context state.outermost_scope = outermost_scope state.ir_child_num = 0 state.depth = depth state.i = 0 - self.loop(state) - - state.builder.append(state.x.render_tail(self)) + if id(x) in self.memo: + state.builder.append(self.memo[id(x)]) + return k() + insert_lets = id(x) in self.binding_sites and len(self.binding_sites[id(x)].lifted_lets) > 0 if insert_lets: - sf = state.context[-1] - state.context.pop() - for let_body in sf.let_bodies: - builder.extend(let_body) - builder.extend(state.builder) - num_lets = len(sf.lifted_lets) - for _ in range(num_lets): - builder.append(')') - - def loop2(self, state): + state.builder = [] + context.append(PrintStackFrame(self.binding_sites[id(x)])) + head = x.render_head(self) + if head != '': + state.builder.append(head) + + def post_children(): + state.builder.append(state.x.render_tail(self)) + if insert_lets: + sf = state.context[-1] + state.context.pop() + for let_body in sf.let_bodies: + builder.extend(let_body) + builder.extend(state.builder) + num_lets = len(sf.lifted_lets) + for _ in range(num_lets): + builder.append(')') + return k() + + return self.loop(state, post_children) + + + def loop2(self, state, k): if state.i >= len(state.children): - return + return k() state.builder.append(' ') child = state.children[state.i] if isinstance(child, ir.BaseIR): - self.print(state.builder, state.context, state.outermost_scope, state.depth, child) - state.ir_child_num += 1 + def post_visit_2(): + state.ir_child_num += 1 + state.i += 1 + return self.loop2(state, k) + return self.print(state.builder, state.context, state.outermost_scope, state.depth, child, post_visit_2) else: - state.ir_child_num = self.print_renderable(state.builder, state.context, state.outermost_scope, state.depth, state.ir_child_num, child) - state.i += 1 - self.loop2(state) + def post_visit_renderable_2(): + state.i += 1 + return self.loop2(state, k) + return self.print_renderable(state.builder, state.context, state.outermost_scope, state.depth, state.ir_child_num, child, post_visit_renderable_2) - def print_renderable(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, ir_child_num: int, x: Renderable): + def print_renderable(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, ir_child_num: int, x: Renderable, k): assert not isinstance(x, ir.BaseIR) head = x.render_head(self) if head != '': @@ -386,9 +398,10 @@ def print_renderable(self, builder: List[str], context: List[PrintStackFrame], o state.ir_child_num = ir_child_num state.depth = depth state.i = 0 - self.loop2(state) - state.builder.append(state.x.render_tail(self)) - return state.ir_child_num + def post_children_renderable(): + state.builder.append(state.x.render_tail(self)) + return k() + return self.loop2(state, post_children_renderable) def compute_new_bindings(self, root: 'ir.BaseIR'): # force computation of all types @@ -473,5 +486,6 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): def __call__(self, root: 'ir.BaseIR') -> str: self.binding_sites = self.compute_new_bindings(root) builder = [] - self.print(builder, [], 0, 1, root) - return ''.join(builder) + def post_root(): + return ''.join(builder) + return self.print(builder, [], 0, 1, root, post_root) From 8460050d5fe6925e2c579e31e68abc4e2f25992d Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 8 Aug 2019 11:23:40 -0400 Subject: [PATCH 15/47] factor out args into state object --- hail/python/hail/ir/renderer.py | 74 ++++++++++++++++----------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index ac6a7f43ec1..85d40042815 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -281,6 +281,21 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int: class State: ... + def __init__(self, x, children, builder, context, outermost_scope, depth): + self.x = x + self.children = children + self.builder = builder + self.context = context + self.outermost_scope = outermost_scope + self.ir_child_num = 0 + self.depth = depth + self.i = 0 + + def copy(self): + new_state = CSERenderer.State(self.x, self.children, self.builder, self.context, self.outermost_scope, self.depth) + new_state.ir_child_num = self.ir_child_num + new_state.i = self.i + return new_state def loop(self, state, k): if state.i >= len(state.children): @@ -309,8 +324,8 @@ def post_lift_visit(): state.i += 1 return self.loop(state, k) - return self.print(let_body, state.context, child_outermost_scope, - state.depth + 1, child, post_lift_visit) + new_state = self.State(child, child.render_children(self), let_body, state.context, child_outermost_scope, state.depth + 1) + return self.print(new_state, post_lift_visit) else: state.builder.append(f'(Ref {name})') state.i += 1 @@ -320,33 +335,23 @@ def post_visit(): state.i += 1 return self.loop(state, k) if isinstance(child, ir.BaseIR): - return self.print(state.builder, state.context, - child_outermost_scope, state.depth + 1, child, - post_visit) + new_state = self.State(child, child.render_children(self), state.builder, state.context, child_outermost_scope, state.depth + 1) + return self.print(new_state, post_visit) else: - return self.print_renderable( - state.builder, state.context, - child_outermost_scope, state.depth + 1, - state.ir_child_num, child, post_visit) - - def print(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, x: 'ir.BaseIR', k): - state = self.State() - state.x = x - state.children = x.render_children(self) - state.builder = builder - state.context = context - state.outermost_scope = outermost_scope - state.ir_child_num = 0 - state.depth = depth - state.i = 0 + new_state = self.State(child, child.render_children(self), state.builder, state.context, child_outermost_scope, state.depth + 1) + new_state.ir_child_num = state.ir_child_num + return self.print_renderable(new_state, post_visit) + def print(self, state, k): + x = state.x + builder = state.builder if id(x) in self.memo: state.builder.append(self.memo[id(x)]) return k() insert_lets = id(x) in self.binding_sites and len(self.binding_sites[id(x)].lifted_lets) > 0 if insert_lets: state.builder = [] - context.append(PrintStackFrame(self.binding_sites[id(x)])) + state.context.append(PrintStackFrame(self.binding_sites[id(x)])) head = x.render_head(self) if head != '': state.builder.append(head) @@ -377,27 +382,21 @@ def post_visit_2(): state.ir_child_num += 1 state.i += 1 return self.loop2(state, k) - return self.print(state.builder, state.context, state.outermost_scope, state.depth, child, post_visit_2) + new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth) + return self.print(new_state, post_visit_2) else: def post_visit_renderable_2(): state.i += 1 return self.loop2(state, k) - return self.print_renderable(state.builder, state.context, state.outermost_scope, state.depth, state.ir_child_num, child, post_visit_renderable_2) + new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth) + new_state.ir_child_num = state.ir_child_num + return self.print_renderable(new_state, post_visit_renderable_2) - def print_renderable(self, builder: List[str], context: List[PrintStackFrame], outermost_scope: int, depth: int, ir_child_num: int, x: Renderable, k): - assert not isinstance(x, ir.BaseIR) - head = x.render_head(self) + def print_renderable(self, state, k): + assert not isinstance(state.x, ir.BaseIR) + head = state.x.render_head(self) if head != '': - builder.append(head) - state = self.State() - state.x = x - state.children = x.render_children(self) - state.builder = builder - state.context = context - state.outermost_scope = outermost_scope - state.ir_child_num = ir_child_num - state.depth = depth - state.i = 0 + state.builder.append(head) def post_children_renderable(): state.builder.append(state.x.render_tail(self)) return k() @@ -488,4 +487,5 @@ def __call__(self, root: 'ir.BaseIR') -> str: builder = [] def post_root(): return ''.join(builder) - return self.print(builder, [], 0, 1, root, post_root) + state = self.State(root, root.render_children(self), builder, [], 0, 1) + return self.print(state, post_root) From 16eb524226f9aa3a6ac7b3011a2341d2ec75a882 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 8 Aug 2019 11:36:28 -0400 Subject: [PATCH 16/47] inline print_renderable --- hail/python/hail/ir/renderer.py | 53 +++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 85d40042815..cc90aaa47bb 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -338,9 +338,19 @@ def post_visit(): new_state = self.State(child, child.render_children(self), state.builder, state.context, child_outermost_scope, state.depth + 1) return self.print(new_state, post_visit) else: + head = child.render_head(self) + if head != '': + state.builder.append(head) + new_state = self.State(child, child.render_children(self), state.builder, state.context, child_outermost_scope, state.depth + 1) new_state.ir_child_num = state.ir_child_num - return self.print_renderable(new_state, post_visit) + + def post_children_renderable(): + new_state.builder.append(new_state.x.render_tail(self)) + state.i += 1 + return self.loop(state, k) + + return self.loop2(new_state, post_children_renderable) def print(self, state, k): x = state.x @@ -359,18 +369,21 @@ def print(self, state, k): def post_children(): state.builder.append(state.x.render_tail(self)) if insert_lets: - sf = state.context[-1] - state.context.pop() - for let_body in sf.let_bodies: - builder.extend(let_body) - builder.extend(state.builder) - num_lets = len(sf.lifted_lets) - for _ in range(num_lets): - builder.append(')') + self.add_lets(state, builder) return k() return self.loop(state, post_children) + @staticmethod + def add_lets(state, builder): + sf = state.context[-1] + state.context.pop() + for let_body in sf.let_bodies: + builder.extend(let_body) + builder.extend(state.builder) + num_lets = len(sf.lifted_lets) + for _ in range(num_lets): + builder.append(')') def loop2(self, state, k): if state.i >= len(state.children): @@ -385,22 +398,18 @@ def post_visit_2(): new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth) return self.print(new_state, post_visit_2) else: - def post_visit_renderable_2(): - state.i += 1 - return self.loop2(state, k) + head = child.render_head(self) + if head != '': + state.builder.append(head) + new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth) new_state.ir_child_num = state.ir_child_num - return self.print_renderable(new_state, post_visit_renderable_2) - def print_renderable(self, state, k): - assert not isinstance(state.x, ir.BaseIR) - head = state.x.render_head(self) - if head != '': - state.builder.append(head) - def post_children_renderable(): - state.builder.append(state.x.render_tail(self)) - return k() - return self.loop2(state, post_children_renderable) + def post_children_renderable(): + new_state.builder.append(new_state.x.render_tail(self)) + state.i += 1 + return self.loop2(state, k) + return self.loop2(new_state, post_children_renderable) def compute_new_bindings(self, root: 'ir.BaseIR'): # force computation of all types From d196b2e5bf2892c8e6ab28a9a0223eb283b0630f Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 8 Aug 2019 14:34:15 -0400 Subject: [PATCH 17/47] inline print --- hail/python/hail/ir/renderer.py | 129 ++++++++++++++++++++++---------- 1 file changed, 88 insertions(+), 41 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index cc90aaa47bb..4a0e187fa09 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -297,6 +297,26 @@ def copy(self): new_state.i = self.i return new_state + def pre_add_lets(self, binding_sites): + x = self.x + insert_lets = id(x) in binding_sites and len(binding_sites[id(x)].lifted_lets) > 0 + if insert_lets: + self.builder = [] + self.context.append(PrintStackFrame(binding_sites[id(x)])) + head = x.render_head(self) + if head != '': + self.builder.append(head) + return insert_lets + + @staticmethod + def post_lift_visit(state, let_body, lift_to, name): + let_body.append(' ') + # let_bodies is built post-order, which guarantees earlier + # lets can't refer to later lets + state.context[lift_to].let_bodies.append(let_body) + state.builder.append(f'(Ref {name})') + state.i += 1 + def loop(self, state, k): if state.i >= len(state.children): return k() @@ -315,28 +335,47 @@ def loop(self, state, k): state.context[lift_to].visited[id(child)] = child let_body = [f'(Let {name} '] - def post_lift_visit(): - let_body.append(' ') - # let_bodies is built post-order, which guarantees earlier - # lets can't refer to later lets - state.context[lift_to].let_bodies.append(let_body) - state.builder.append(f'(Ref {name})') - state.i += 1 + new_state = self.State(child, child.render_children(self), let_body, state.context, child_outermost_scope, state.depth + 1) + x = new_state.x + builder = state.builder + if id(x) in self.memo: + new_state.builder.append(self.memo[id(x)]) + self.post_lift_visit(state, let_body, lift_to, name) return self.loop(state, k) - new_state = self.State(child, child.render_children(self), let_body, state.context, child_outermost_scope, state.depth + 1) - return self.print(new_state, post_lift_visit) + insert_lets = new_state.pre_add_lets(self.binding_sites) + + def post_children(): + new_state.builder.append(new_state.x.render_tail(self)) + if insert_lets: + self.add_lets(new_state, builder) + self.post_lift_visit(state, let_body, lift_to, name) + return self.loop(state, k) + + return self.loop(new_state, post_children) else: state.builder.append(f'(Ref {name})') state.i += 1 return self.loop(state, k) else: - def post_visit(): - state.i += 1 - return self.loop(state, k) if isinstance(child, ir.BaseIR): new_state = self.State(child, child.render_children(self), state.builder, state.context, child_outermost_scope, state.depth + 1) - return self.print(new_state, post_visit) + + if id(child) in self.memo: + new_state.builder.append(self.memo[id(child)]) + state.i += 1 + return self.loop(state, k) + + insert_lets = new_state.pre_add_lets(self.binding_sites) + + def post_children(): + new_state.builder.append(child.render_tail(self)) + if insert_lets: + self.add_lets(new_state, state.builder) + state.i += 1 + return self.loop(state, k) + + return self.loop(new_state, post_children) else: head = child.render_head(self) if head != '': @@ -352,28 +391,6 @@ def post_children_renderable(): return self.loop2(new_state, post_children_renderable) - def print(self, state, k): - x = state.x - builder = state.builder - if id(x) in self.memo: - state.builder.append(self.memo[id(x)]) - return k() - insert_lets = id(x) in self.binding_sites and len(self.binding_sites[id(x)].lifted_lets) > 0 - if insert_lets: - state.builder = [] - state.context.append(PrintStackFrame(self.binding_sites[id(x)])) - head = x.render_head(self) - if head != '': - state.builder.append(head) - - def post_children(): - state.builder.append(state.x.render_tail(self)) - if insert_lets: - self.add_lets(state, builder) - return k() - - return self.loop(state, post_children) - @staticmethod def add_lets(state, builder): sf = state.context[-1] @@ -391,12 +408,27 @@ def loop2(self, state, k): state.builder.append(' ') child = state.children[state.i] if isinstance(child, ir.BaseIR): - def post_visit_2(): + new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth) + + x = new_state.x + builder = new_state.builder + if id(x) in self.memo: + new_state.builder.append(self.memo[id(x)]) state.ir_child_num += 1 state.i += 1 return self.loop2(state, k) - new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth) - return self.print(new_state, post_visit_2) + + insert_lets = new_state.pre_add_lets(self.binding_sites) + + def post_children(): + new_state.builder.append(new_state.x.render_tail(self)) + if insert_lets: + self.add_lets(new_state, builder) + state.ir_child_num += 1 + state.i += 1 + return self.loop2(state, k) + + return self.loop(new_state, post_children) else: head = child.render_head(self) if head != '': @@ -494,7 +526,22 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): def __call__(self, root: 'ir.BaseIR') -> str: self.binding_sites = self.compute_new_bindings(root) builder = [] - def post_root(): - return ''.join(builder) state = self.State(root, root.render_children(self), builder, [], 0, 1) - return self.print(state, post_root) + if id(root) in self.memo: + state.builder.append(self.memo[id(root)]) + return ''.join(builder) + insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0 + if insert_lets: + state.builder = [] + state.context.append(PrintStackFrame(self.binding_sites[id(root)])) + head = root.render_head(self) + if head != '': + state.builder.append(head) + + def post_children(): + state.builder.append(root.render_tail(self)) + if insert_lets: + self.add_lets(state, builder) + return ''.join(builder) + + return self.loop(state, post_children) From 3db71dbbcb2afd30a1565082a130240d80b069d2 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 8 Aug 2019 15:05:37 -0400 Subject: [PATCH 18/47] wip --- hail/python/hail/ir/renderer.py | 37 ++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 4a0e187fa09..e781ca5a876 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -348,7 +348,7 @@ def loop(self, state, k): def post_children(): new_state.builder.append(new_state.x.render_tail(self)) if insert_lets: - self.add_lets(new_state, builder) + self.add_lets(new_state.context, new_state.builder, builder) self.post_lift_visit(state, let_body, lift_to, name) return self.loop(state, k) @@ -371,7 +371,7 @@ def post_children(): def post_children(): new_state.builder.append(child.render_tail(self)) if insert_lets: - self.add_lets(new_state, state.builder) + self.add_lets(new_state.context, new_state.builder, state.builder) state.i += 1 return self.loop(state, k) @@ -391,13 +391,36 @@ def post_children_renderable(): return self.loop2(new_state, post_children_renderable) + class Kont: + def __apply__(self): + pass + + class PostChildren(Kont): + def __init__(self, node, state, insert_lets, builder, local_builder, k): + self.state = state + self.insert_lets = insert_lets + self.builder = builder + self.k = k + self.local_builder = local_builder + self.node = node + self.context = state.context + + def apply(self, k: Kont): + if isinstance(k, self.PostChildren): + k.local_builder.append(k.node.render_tail(self)) + if k.insert_lets: + CSERenderer.add_lets(k.context, k.local_builder, k.builder) + k.state.ir_child_num += 1 + k.state.i += 1 + return k.loop2(k.state, k.k) + @staticmethod - def add_lets(state, builder): - sf = state.context[-1] - state.context.pop() + def add_lets(context, local_builder, builder): + sf = context[-1] + context.pop() for let_body in sf.let_bodies: builder.extend(let_body) - builder.extend(state.builder) + builder.extend(local_builder) num_lets = len(sf.lifted_lets) for _ in range(num_lets): builder.append(')') @@ -541,7 +564,7 @@ def __call__(self, root: 'ir.BaseIR') -> str: def post_children(): state.builder.append(root.render_tail(self)) if insert_lets: - self.add_lets(state, builder) + self.add_lets(state.context, state.builder, builder) return ''.join(builder) return self.loop(state, post_children) From 55c32baaf379890d1c57d55ba865852293044515 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Fri, 9 Aug 2019 15:34:17 -0400 Subject: [PATCH 19/47] defunctionalize --- hail/python/hail/ir/renderer.py | 211 ++++++++++++++------------------ 1 file changed, 90 insertions(+), 121 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index e781ca5a876..fd4d8d221d2 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -308,111 +308,125 @@ def pre_add_lets(self, binding_sites): self.builder.append(head) return insert_lets - @staticmethod - def post_lift_visit(state, let_body, lift_to, name): - let_body.append(' ') - # let_bodies is built post-order, which guarantees earlier - # lets can't refer to later lets - state.context[lift_to].let_bodies.append(let_body) - state.builder.append(f'(Ref {name})') - state.i += 1 - def loop(self, state, k): if state.i >= len(state.children): - return k() + return self.apply(k) state.builder.append(' ') child = state.children[state.i] - lift_to = self.lifted_in_scope(child, state.context) - child_outermost_scope = state.outermost_scope - if state.x.new_block(state.ir_child_num): - child_outermost_scope = state.depth + + new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth) + + if isinstance(state.x, ir.BaseIR): + if state.x.new_block(state.ir_child_num): + new_state.outermost_scope = state.depth + if isinstance(child, ir.BaseIR): + new_state.depth += 1 + lift_to = self.lifted_in_scope(child, state.context) + else: + lift_to = -1 + if (lift_to >= 0 and state.context[lift_to] and state.context[lift_to].depth >= state.outermost_scope): state.ir_child_num += 1 (name, _) = state.context[lift_to].lifted_lets[id(child)] - if id(child) not in state.context[lift_to].visited: - state.context[lift_to].visited[id(child)] = child - let_body = [f'(Let {name} '] - - new_state = self.State(child, child.render_children(self), let_body, state.context, child_outermost_scope, state.depth + 1) - x = new_state.x - builder = state.builder - if id(x) in self.memo: - new_state.builder.append(self.memo[id(x)]) - self.post_lift_visit(state, let_body, lift_to, name) - return self.loop(state, k) - - insert_lets = new_state.pre_add_lets(self.binding_sites) - - def post_children(): - new_state.builder.append(new_state.x.render_tail(self)) - if insert_lets: - self.add_lets(new_state.context, new_state.builder, builder) - self.post_lift_visit(state, let_body, lift_to, name) - return self.loop(state, k) - - return self.loop(new_state, post_children) - else: + + if id(child) in state.context[lift_to].visited: state.builder.append(f'(Ref {name})') state.i += 1 return self.loop(state, k) - else: - if isinstance(child, ir.BaseIR): - new_state = self.State(child, child.render_children(self), state.builder, state.context, child_outermost_scope, state.depth + 1) - if id(child) in self.memo: - new_state.builder.append(self.memo[id(child)]) - state.i += 1 - return self.loop(state, k) + state.context[lift_to].visited[id(child)] = child - insert_lets = new_state.pre_add_lets(self.binding_sites) + new_state.builder = [f'(Let {name} '] - def post_children(): - new_state.builder.append(child.render_tail(self)) - if insert_lets: - self.add_lets(new_state.context, new_state.builder, state.builder) - state.i += 1 - return self.loop(state, k) + if id(child) in self.memo: + pcl = self.PostChildrenLifted(state, state.builder, new_state.builder, self.memo[id(child)], lift_to, name, k) + return self.apply(pcl) - return self.loop(new_state, post_children) - else: - head = child.render_head(self) - if head != '': - state.builder.append(head) + insert_lets = new_state.pre_add_lets(self.binding_sites) + assert(not insert_lets) + + pcl = self.PostChildrenLifted(state, state.builder, new_state.builder, [child.render_tail(self)], lift_to, name, k) + + return self.loop(new_state, pcl) + + if isinstance(child, ir.BaseIR): + if id(child) in self.memo: + new_state.builder.append(*self.memo[id(child)]) + state.i += 1 + return self.loop(state, k) + + insert_lets = new_state.pre_add_lets(self.binding_sites) + + pc = self.PostChildren(state, state.builder, new_state.builder, [child.render_tail(self)], insert_lets, k) + + return self.loop(new_state, pc) + else: + head = child.render_head(self) + if head != '': + state.builder.append(head) - new_state = self.State(child, child.render_children(self), state.builder, state.context, child_outermost_scope, state.depth + 1) - new_state.ir_child_num = state.ir_child_num + new_state.ir_child_num = state.ir_child_num + insert_lets = False - def post_children_renderable(): - new_state.builder.append(new_state.x.render_tail(self)) - state.i += 1 - return self.loop(state, k) + pc = self.PostChildren(state, state.builder, new_state.builder, new_state.x.render_tail(self), insert_lets, k) - return self.loop2(new_state, post_children_renderable) + return self.loop(new_state, pc) class Kont: - def __apply__(self): - pass + pass + + class PostChildrenLifted(Kont): + def __init__(self, state, builder, local_builder, tail, lift_to, name, k): + self.state = state + self.builder = builder + self.k = k + self.local_builder = local_builder + self.tail = tail + self.lift_to = lift_to + self.name = name class PostChildren(Kont): - def __init__(self, node, state, insert_lets, builder, local_builder, k): + def __init__(self, state, builder, local_builder, tail, insert_lets, k): self.state = state - self.insert_lets = insert_lets self.builder = builder self.k = k self.local_builder = local_builder - self.node = node - self.context = state.context + self.tail = tail + self.insert_lets = insert_lets + + class PostRoot(Kont): + def __init__(self, state, builder, local_builder, tail, insert_lets): + self.state = state + self.builder = builder + self.local_builder = local_builder + self.tail = tail + self.insert_lets = insert_lets def apply(self, k: Kont): - if isinstance(k, self.PostChildren): - k.local_builder.append(k.node.render_tail(self)) + if isinstance(k, self.PostChildrenLifted): + k.local_builder.extend(k.tail) + k.local_builder.append(' ') + # let_bodies is built post-order, which guarantees earlier + # lets can't refer to later lets + k.state.context[k.lift_to].let_bodies.append(k.local_builder) + k.builder.append(f'(Ref {k.name})') + k.state.i += 1 + return self.loop(k.state, k.k) + elif isinstance(k, self.PostChildren): + k.local_builder.extend(k.tail) if k.insert_lets: - CSERenderer.add_lets(k.context, k.local_builder, k.builder) - k.state.ir_child_num += 1 + self.add_lets(k.state.context, k.local_builder, k.builder) + if not isinstance(k.state.x, ir.BaseIR): + k.state.ir_child_num += 1 k.state.i += 1 - return k.loop2(k.state, k.k) + return self.loop(k.state, k.k) + else: + k.local_builder.append(k.tail) + if k.insert_lets: + self.add_lets(k.state.context, k.local_builder, k.builder) + return ''.join(k.builder) @staticmethod def add_lets(context, local_builder, builder): @@ -425,47 +439,6 @@ def add_lets(context, local_builder, builder): for _ in range(num_lets): builder.append(')') - def loop2(self, state, k): - if state.i >= len(state.children): - return k() - state.builder.append(' ') - child = state.children[state.i] - if isinstance(child, ir.BaseIR): - new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth) - - x = new_state.x - builder = new_state.builder - if id(x) in self.memo: - new_state.builder.append(self.memo[id(x)]) - state.ir_child_num += 1 - state.i += 1 - return self.loop2(state, k) - - insert_lets = new_state.pre_add_lets(self.binding_sites) - - def post_children(): - new_state.builder.append(new_state.x.render_tail(self)) - if insert_lets: - self.add_lets(new_state, builder) - state.ir_child_num += 1 - state.i += 1 - return self.loop2(state, k) - - return self.loop(new_state, post_children) - else: - head = child.render_head(self) - if head != '': - state.builder.append(head) - - new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth) - new_state.ir_child_num = state.ir_child_num - - def post_children_renderable(): - new_state.builder.append(new_state.x.render_tail(self)) - state.i += 1 - return self.loop2(state, k) - return self.loop2(new_state, post_children_renderable) - def compute_new_bindings(self, root: 'ir.BaseIR'): # force computation of all types root.typ @@ -561,10 +534,6 @@ def __call__(self, root: 'ir.BaseIR') -> str: if head != '': state.builder.append(head) - def post_children(): - state.builder.append(root.render_tail(self)) - if insert_lets: - self.add_lets(state.context, state.builder, builder) - return ''.join(builder) + pc = self.PostRoot(state, builder, state.builder, root.render_tail(self), insert_lets) - return self.loop(state, post_children) + return self.loop(state, pc) From 50ff007079592bbb23faffd56ffb6904cde1ade9 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Fri, 9 Aug 2019 16:18:55 -0400 Subject: [PATCH 20/47] recursion -> loop --- hail/python/hail/ir/renderer.py | 182 ++++++++++++++++---------------- 1 file changed, 93 insertions(+), 89 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index fd4d8d221d2..a925ff40ead 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -308,72 +308,6 @@ def pre_add_lets(self, binding_sites): self.builder.append(head) return insert_lets - def loop(self, state, k): - if state.i >= len(state.children): - return self.apply(k) - state.builder.append(' ') - child = state.children[state.i] - - new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth) - - if isinstance(state.x, ir.BaseIR): - if state.x.new_block(state.ir_child_num): - new_state.outermost_scope = state.depth - if isinstance(child, ir.BaseIR): - new_state.depth += 1 - lift_to = self.lifted_in_scope(child, state.context) - else: - lift_to = -1 - - if (lift_to >= 0 and - state.context[lift_to] and - state.context[lift_to].depth >= state.outermost_scope): - state.ir_child_num += 1 - (name, _) = state.context[lift_to].lifted_lets[id(child)] - - if id(child) in state.context[lift_to].visited: - state.builder.append(f'(Ref {name})') - state.i += 1 - return self.loop(state, k) - - state.context[lift_to].visited[id(child)] = child - - new_state.builder = [f'(Let {name} '] - - if id(child) in self.memo: - pcl = self.PostChildrenLifted(state, state.builder, new_state.builder, self.memo[id(child)], lift_to, name, k) - return self.apply(pcl) - - insert_lets = new_state.pre_add_lets(self.binding_sites) - assert(not insert_lets) - - pcl = self.PostChildrenLifted(state, state.builder, new_state.builder, [child.render_tail(self)], lift_to, name, k) - - return self.loop(new_state, pcl) - - if isinstance(child, ir.BaseIR): - if id(child) in self.memo: - new_state.builder.append(*self.memo[id(child)]) - state.i += 1 - return self.loop(state, k) - - insert_lets = new_state.pre_add_lets(self.binding_sites) - - pc = self.PostChildren(state, state.builder, new_state.builder, [child.render_tail(self)], insert_lets, k) - - return self.loop(new_state, pc) - else: - head = child.render_head(self) - if head != '': - state.builder.append(head) - - new_state.ir_child_num = state.ir_child_num - insert_lets = False - - pc = self.PostChildren(state, state.builder, new_state.builder, new_state.x.render_tail(self), insert_lets, k) - - return self.loop(new_state, pc) - class Kont: pass @@ -404,29 +338,99 @@ def __init__(self, state, builder, local_builder, tail, insert_lets): self.tail = tail self.insert_lets = insert_lets - def apply(self, k: Kont): - if isinstance(k, self.PostChildrenLifted): - k.local_builder.extend(k.tail) - k.local_builder.append(' ') - # let_bodies is built post-order, which guarantees earlier - # lets can't refer to later lets - k.state.context[k.lift_to].let_bodies.append(k.local_builder) - k.builder.append(f'(Ref {k.name})') - k.state.i += 1 - return self.loop(k.state, k.k) - elif isinstance(k, self.PostChildren): - k.local_builder.extend(k.tail) - if k.insert_lets: - self.add_lets(k.state.context, k.local_builder, k.builder) - if not isinstance(k.state.x, ir.BaseIR): - k.state.ir_child_num += 1 - k.state.i += 1 - return self.loop(k.state, k.k) - else: - k.local_builder.append(k.tail) - if k.insert_lets: - self.add_lets(k.state.context, k.local_builder, k.builder) - return ''.join(k.builder) + def loop(self, state, k): + while True: + if state.i >= len(state.children): + if isinstance(k, self.PostChildrenLifted): + k.local_builder.extend(k.tail) + k.local_builder.append(' ') + # let_bodies is built post-order, which guarantees earlier + # lets can't refer to later lets + k.state.context[k.lift_to].let_bodies.append(k.local_builder) + k.builder.append(f'(Ref {k.name})') + state = k.state + state.i += 1 + k = k.k + continue + elif isinstance(k, self.PostChildren): + k.local_builder.extend(k.tail) + if k.insert_lets: + self.add_lets(k.state.context, k.local_builder, k.builder) + if not isinstance(k.state.x, ir.BaseIR): + k.state.ir_child_num += 1 + state = k.state + state.i += 1 + k = k.k + continue + else: + assert isinstance(k, self.PostRoot) + k.local_builder.append(k.tail) + if k.insert_lets: + self.add_lets(k.state.context, k.local_builder, k.builder) + return ''.join(k.builder) + + state.builder.append(' ') + child = state.children[state.i] + + new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth) + + if isinstance(state.x, ir.BaseIR): + if state.x.new_block(state.ir_child_num): + new_state.outermost_scope = state.depth + if isinstance(child, ir.BaseIR): + new_state.depth += 1 + lift_to = self.lifted_in_scope(child, state.context) + else: + lift_to = -1 + + if (lift_to >= 0 and + state.context[lift_to] and + state.context[lift_to].depth >= state.outermost_scope): + state.ir_child_num += 1 + (name, _) = state.context[lift_to].lifted_lets[id(child)] + + if id(child) in state.context[lift_to].visited: + state.builder.append(f'(Ref {name})') + state.i += 1 + continue + + state.context[lift_to].visited[id(child)] = child + + new_state.builder = [f'(Let {name} '] + + if id(child) in self.memo: + pcl = self.PostChildrenLifted(state, state.builder, new_state.builder, self.memo[id(child)], lift_to, name, k) + return self.apply(pcl) + + insert_lets = new_state.pre_add_lets(self.binding_sites) + assert(not insert_lets) + + k = self.PostChildrenLifted(state, state.builder, new_state.builder, [child.render_tail(self)], lift_to, name, k) + state = new_state + continue + + if isinstance(child, ir.BaseIR): + if id(child) in self.memo: + new_state.builder.append(*self.memo[id(child)]) + state.i += 1 + continue + + insert_lets = new_state.pre_add_lets(self.binding_sites) + + k = self.PostChildren(state, state.builder, new_state.builder, [child.render_tail(self)], insert_lets, k) + state = new_state + continue + else: + head = child.render_head(self) + if head != '': + state.builder.append(head) + + new_state.ir_child_num = state.ir_child_num + insert_lets = False + + k = self.PostChildren(state, state.builder, new_state.builder, new_state.x.render_tail(self), insert_lets, k) + state = new_state + continue @staticmethod def add_lets(context, local_builder, builder): From 6ca6b238d99fb3c59671161b003ed97734c27b75 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Mon, 12 Aug 2019 10:48:31 -0400 Subject: [PATCH 21/47] linkedlist -> stack --- hail/python/hail/ir/renderer.py | 76 ++++++++++++++++----------------- 1 file changed, 36 insertions(+), 40 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index a925ff40ead..1b32f1d151f 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -312,25 +312,15 @@ class Kont: pass class PostChildrenLifted(Kont): - def __init__(self, state, builder, local_builder, tail, lift_to, name, k): + def __init__(self, state, builder, local_builder, tail, lift_to, name): self.state = state self.builder = builder - self.k = k self.local_builder = local_builder self.tail = tail self.lift_to = lift_to self.name = name class PostChildren(Kont): - def __init__(self, state, builder, local_builder, tail, insert_lets, k): - self.state = state - self.builder = builder - self.k = k - self.local_builder = local_builder - self.tail = tail - self.insert_lets = insert_lets - - class PostRoot(Kont): def __init__(self, state, builder, local_builder, tail, insert_lets): self.state = state self.builder = builder @@ -338,9 +328,33 @@ def __init__(self, state, builder, local_builder, tail, insert_lets): self.tail = tail self.insert_lets = insert_lets - def loop(self, state, k): + def loop(self, root): + root_builder = [] + root_state = self.State(root, root.render_children(self), root_builder, [], 0, 1) + if id(root) in self.memo: + root_state.builder.append(self.memo[id(root)]) + return ''.join(root_builder) + insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0 + if insert_lets: + root_state.builder = [] + root_state.context.append(PrintStackFrame(self.binding_sites[id(root)])) + head = root.render_head(self) + if head != '': + root_state.builder.append(head) + + stack = [] + + state = root_state + while True: if state.i >= len(state.children): + if not stack: + root_state.builder.append(root.render_tail(self)) + insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0 + if insert_lets: + self.add_lets(root_state.context, root_state.builder, root_builder) + return ''.join(root_builder) + k = stack[-1] if isinstance(k, self.PostChildrenLifted): k.local_builder.extend(k.tail) k.local_builder.append(' ') @@ -350,9 +364,10 @@ def loop(self, state, k): k.builder.append(f'(Ref {k.name})') state = k.state state.i += 1 - k = k.k + stack.pop() continue - elif isinstance(k, self.PostChildren): + else: + assert isinstance(k, self.PostChildren) k.local_builder.extend(k.tail) if k.insert_lets: self.add_lets(k.state.context, k.local_builder, k.builder) @@ -360,14 +375,8 @@ def loop(self, state, k): k.state.ir_child_num += 1 state = k.state state.i += 1 - k = k.k + stack.pop() continue - else: - assert isinstance(k, self.PostRoot) - k.local_builder.append(k.tail) - if k.insert_lets: - self.add_lets(k.state.context, k.local_builder, k.builder) - return ''.join(k.builder) state.builder.append(' ') child = state.children[state.i] @@ -399,13 +408,14 @@ def loop(self, state, k): new_state.builder = [f'(Let {name} '] if id(child) in self.memo: - pcl = self.PostChildrenLifted(state, state.builder, new_state.builder, self.memo[id(child)], lift_to, name, k) + pcl = self.PostChildrenLifted(state, state.builder, new_state.builder, self.memo[id(child)], lift_to, name) + # FIXME return self.apply(pcl) insert_lets = new_state.pre_add_lets(self.binding_sites) assert(not insert_lets) - k = self.PostChildrenLifted(state, state.builder, new_state.builder, [child.render_tail(self)], lift_to, name, k) + stack.append(self.PostChildrenLifted(state, state.builder, new_state.builder, [child.render_tail(self)], lift_to, name)) state = new_state continue @@ -417,7 +427,7 @@ def loop(self, state, k): insert_lets = new_state.pre_add_lets(self.binding_sites) - k = self.PostChildren(state, state.builder, new_state.builder, [child.render_tail(self)], insert_lets, k) + stack.append(self.PostChildren(state, state.builder, new_state.builder, [child.render_tail(self)], insert_lets)) state = new_state continue else: @@ -428,7 +438,7 @@ def loop(self, state, k): new_state.ir_child_num = state.ir_child_num insert_lets = False - k = self.PostChildren(state, state.builder, new_state.builder, new_state.x.render_tail(self), insert_lets, k) + stack.append(self.PostChildren(state, state.builder, new_state.builder, new_state.x.render_tail(self), insert_lets)) state = new_state continue @@ -525,19 +535,5 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): def __call__(self, root: 'ir.BaseIR') -> str: self.binding_sites = self.compute_new_bindings(root) - builder = [] - state = self.State(root, root.render_children(self), builder, [], 0, 1) - if id(root) in self.memo: - state.builder.append(self.memo[id(root)]) - return ''.join(builder) - insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0 - if insert_lets: - state.builder = [] - state.context.append(PrintStackFrame(self.binding_sites[id(root)])) - head = root.render_head(self) - if head != '': - state.builder.append(head) - - pc = self.PostRoot(state, builder, state.builder, root.render_tail(self), insert_lets) - return self.loop(state, pc) + return self.loop(root) From d2df59240d3e85521b9049b35b8d19aa61085fce Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Mon, 12 Aug 2019 11:46:21 -0400 Subject: [PATCH 22/47] refactoring --- hail/python/hail/ir/renderer.py | 53 ++++++++++++++++----------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 1b32f1d151f..42c7c927f91 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -280,29 +280,21 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int: # any lets to be inserted above y. class State: - ... - def __init__(self, x, children, builder, context, outermost_scope, depth): + def __init__(self, x, children, builder, outermost_scope, depth): self.x = x self.children = children self.builder = builder - self.context = context self.outermost_scope = outermost_scope self.ir_child_num = 0 self.depth = depth self.i = 0 - def copy(self): - new_state = CSERenderer.State(self.x, self.children, self.builder, self.context, self.outermost_scope, self.depth) - new_state.ir_child_num = self.ir_child_num - new_state.i = self.i - return new_state - - def pre_add_lets(self, binding_sites): + def pre_add_lets(self, binding_sites, context): x = self.x insert_lets = id(x) in binding_sites and len(binding_sites[id(x)].lifted_lets) > 0 if insert_lets: self.builder = [] - self.context.append(PrintStackFrame(binding_sites[id(x)])) + context.append(PrintStackFrame(binding_sites[id(x)])) head = x.render_head(self) if head != '': self.builder.append(head) @@ -330,29 +322,34 @@ def __init__(self, state, builder, local_builder, tail, insert_lets): def loop(self, root): root_builder = [] - root_state = self.State(root, root.render_children(self), root_builder, [], 0, 1) + context = [] + root_state = self.State(root, root.render_children(self), root_builder, 0, 1) if id(root) in self.memo: root_state.builder.append(self.memo[id(root)]) return ''.join(root_builder) insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0 if insert_lets: root_state.builder = [] - root_state.context.append(PrintStackFrame(self.binding_sites[id(root)])) + context.append(PrintStackFrame(self.binding_sites[id(root)])) head = root.render_head(self) if head != '': root_state.builder.append(head) - stack = [] + root_frame = lambda x: None + root_frame.state = root_state + # stack = [root_frame] + stack = [None] state = root_state while True: + # state = stack[-1].state if state.i >= len(state.children): - if not stack: + if len(stack) <= 1: root_state.builder.append(root.render_tail(self)) insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0 if insert_lets: - self.add_lets(root_state.context, root_state.builder, root_builder) + self.add_lets(context, root_state.builder, root_builder) return ''.join(root_builder) k = stack[-1] if isinstance(k, self.PostChildrenLifted): @@ -360,7 +357,7 @@ def loop(self, root): k.local_builder.append(' ') # let_bodies is built post-order, which guarantees earlier # lets can't refer to later lets - k.state.context[k.lift_to].let_bodies.append(k.local_builder) + context[k.lift_to].let_bodies.append(k.local_builder) k.builder.append(f'(Ref {k.name})') state = k.state state.i += 1 @@ -370,7 +367,7 @@ def loop(self, root): assert isinstance(k, self.PostChildren) k.local_builder.extend(k.tail) if k.insert_lets: - self.add_lets(k.state.context, k.local_builder, k.builder) + self.add_lets(context, k.local_builder, k.builder) if not isinstance(k.state.x, ir.BaseIR): k.state.ir_child_num += 1 state = k.state @@ -381,38 +378,38 @@ def loop(self, root): state.builder.append(' ') child = state.children[state.i] - new_state = self.State(child, child.render_children(self), state.builder, state.context, state.outermost_scope, state.depth) + new_state = self.State(child, child.render_children(self), state.builder, state.outermost_scope, state.depth) if isinstance(state.x, ir.BaseIR): if state.x.new_block(state.ir_child_num): new_state.outermost_scope = state.depth if isinstance(child, ir.BaseIR): new_state.depth += 1 - lift_to = self.lifted_in_scope(child, state.context) + lift_to = self.lifted_in_scope(child, context) else: lift_to = -1 if (lift_to >= 0 and - state.context[lift_to] and - state.context[lift_to].depth >= state.outermost_scope): + context[lift_to] and + context[lift_to].depth >= state.outermost_scope): state.ir_child_num += 1 - (name, _) = state.context[lift_to].lifted_lets[id(child)] + (name, _) = context[lift_to].lifted_lets[id(child)] - if id(child) in state.context[lift_to].visited: + if id(child) in context[lift_to].visited: state.builder.append(f'(Ref {name})') state.i += 1 continue - state.context[lift_to].visited[id(child)] = child + context[lift_to].visited[id(child)] = child new_state.builder = [f'(Let {name} '] if id(child) in self.memo: - pcl = self.PostChildrenLifted(state, state.builder, new_state.builder, self.memo[id(child)], lift_to, name) + pcl = self.PostChildrenLifted(new_state, state.builder, new_state.builder, self.memo[id(child)], lift_to, name) # FIXME return self.apply(pcl) - insert_lets = new_state.pre_add_lets(self.binding_sites) + insert_lets = new_state.pre_add_lets(self.binding_sites, context) assert(not insert_lets) stack.append(self.PostChildrenLifted(state, state.builder, new_state.builder, [child.render_tail(self)], lift_to, name)) @@ -425,7 +422,7 @@ def loop(self, root): state.i += 1 continue - insert_lets = new_state.pre_add_lets(self.binding_sites) + insert_lets = new_state.pre_add_lets(self.binding_sites, context) stack.append(self.PostChildren(state, state.builder, new_state.builder, [child.render_tail(self)], insert_lets)) state = new_state From 1cccb0516018652e9896d9308cc0aae4a397ffc2 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Tue, 13 Aug 2019 10:05:40 -0400 Subject: [PATCH 23/47] refactoring, bugfixes --- hail/python/hail/ir/renderer.py | 66 ++++++++++++++++----------------- 1 file changed, 31 insertions(+), 35 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 42c7c927f91..04b400adb71 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -137,9 +137,10 @@ def __call__(self, x: 'Renderable'): class BindingSite: - def __init__(self, lifted_lets: Dict[int, Tuple[str, 'ir.BaseIR']], depth: int): + def __init__(self, lifted_lets: Dict[int, Tuple[str, 'ir.BaseIR']], depth: int, node: 'ir.BaseIR'): self.depth = depth self.lifted_lets = lifted_lets + self.node = node class PrintStackFrame: @@ -252,14 +253,14 @@ def add_jir(self, x: 'ir.BaseIR'): @staticmethod def find_in_scope(x: 'ir.BaseIR', context: List[StackFrame], outermost_scope: int) -> int: - for i in range(len(context) - 1, outermost_scope - 1, -1): + for i in reversed(range(len(context))): if id(x) in context[i].visited: return i return -1 @staticmethod def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int: - for i in range(len(context) - 1, -1, -1): + for i in range(len(context)): if id(x) in context[i].lifted_lets: return i return -1 @@ -304,20 +305,18 @@ class Kont: pass class PostChildrenLifted(Kont): - def __init__(self, state, builder, local_builder, tail, lift_to, name): + def __init__(self, state, builder, lift_to, name): self.state = state self.builder = builder - self.local_builder = local_builder - self.tail = tail + self.local_builder = state.builder self.lift_to = lift_to self.name = name class PostChildren(Kont): - def __init__(self, state, builder, local_builder, tail, insert_lets): + def __init__(self, state, builder, insert_lets): self.state = state self.builder = builder - self.local_builder = local_builder - self.tail = tail + self.local_builder = state.builder self.insert_lets = insert_lets def loop(self, root): @@ -337,13 +336,11 @@ def loop(self, root): root_frame = lambda x: None root_frame.state = root_state - # stack = [root_frame] - stack = [None] - - state = root_state + stack = [root_frame] while True: - # state = stack[-1].state + state = stack[-1].state + if state.i >= len(state.children): if len(stack) <= 1: root_state.builder.append(root.render_tail(self)) @@ -353,26 +350,28 @@ def loop(self, root): return ''.join(root_builder) k = stack[-1] if isinstance(k, self.PostChildrenLifted): - k.local_builder.extend(k.tail) + if id(k.state.x) in self.memo: + k.local_builder.extend(self.memo[id(k.state.x)]) + else: + k.local_builder.append(k.state.x.render_tail(self)) k.local_builder.append(' ') # let_bodies is built post-order, which guarantees earlier # lets can't refer to later lets context[k.lift_to].let_bodies.append(k.local_builder) k.builder.append(f'(Ref {k.name})') - state = k.state - state.i += 1 stack.pop() + stack[-1].state.i += 1 continue else: assert isinstance(k, self.PostChildren) - k.local_builder.extend(k.tail) + k.local_builder.append(k.state.x.render_tail(self)) if k.insert_lets: self.add_lets(context, k.local_builder, k.builder) - if not isinstance(k.state.x, ir.BaseIR): - k.state.ir_child_num += 1 - state = k.state - state.i += 1 stack.pop() + state = stack[-1].state + if not isinstance(state.x, ir.BaseIR): + state.ir_child_num += 1 + state.i += 1 continue state.builder.append(' ') @@ -405,27 +404,25 @@ def loop(self, root): new_state.builder = [f'(Let {name} '] if id(child) in self.memo: - pcl = self.PostChildrenLifted(new_state, state.builder, new_state.builder, self.memo[id(child)], lift_to, name) - # FIXME - return self.apply(pcl) + new_state.children = [] + stack.append(self.PostChildrenLifted(new_state, state.builder, lift_to, name)) + continue insert_lets = new_state.pre_add_lets(self.binding_sites, context) assert(not insert_lets) - stack.append(self.PostChildrenLifted(state, state.builder, new_state.builder, [child.render_tail(self)], lift_to, name)) - state = new_state + stack.append(self.PostChildrenLifted(new_state, state.builder, lift_to, name)) continue if isinstance(child, ir.BaseIR): if id(child) in self.memo: - new_state.builder.append(*self.memo[id(child)]) + new_state.builder.extend(self.memo[id(child)]) state.i += 1 continue insert_lets = new_state.pre_add_lets(self.binding_sites, context) - stack.append(self.PostChildren(state, state.builder, new_state.builder, [child.render_tail(self)], insert_lets)) - state = new_state + stack.append(self.PostChildren(new_state, state.builder, insert_lets)) continue else: head = child.render_head(self) @@ -435,8 +432,7 @@ def loop(self, root): new_state.ir_child_num = state.ir_child_num insert_lets = False - stack.append(self.PostChildren(state, state.builder, new_state.builder, new_state.x.render_tail(self), insert_lets)) - state = new_state + stack.append(self.PostChildren(new_state, state.builder, insert_lets)) continue @staticmethod @@ -446,7 +442,7 @@ def add_lets(context, local_builder, builder): for let_body in sf.let_bodies: builder.extend(let_body) builder.extend(local_builder) - num_lets = len(sf.lifted_lets) + num_lets = len(sf.let_bodies) for _ in range(num_lets): builder.append(')') @@ -477,7 +473,7 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): # binding sites if frame.lifted_lets: binding_sites[id(node)] = \ - BindingSite(frame.lifted_lets, len(stack)) + BindingSite(frame.lifted_lets, len(stack), node) frame.update_parent_free_vars(parent_frame.free_vars) @@ -527,7 +523,7 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): assert var_depth == 0 assert(len(root_frame.free_vars) == 0) - binding_sites[id(root)] = BindingSite(frame.lifted_lets, 0) + binding_sites[id(root)] = BindingSite(frame.lifted_lets, 0, root) return binding_sites def __call__(self, root: 'ir.BaseIR') -> str: From fb6e986dd49de59107d776798a7d2e8c7b8f0324 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 14 Aug 2019 10:43:19 -0400 Subject: [PATCH 24/47] bugfixes --- hail/python/hail/expr/matrix_type.py | 57 +++-- hail/python/hail/expr/table_type.py | 20 +- hail/python/hail/ir/base_ir.py | 64 +++-- hail/python/hail/ir/blockmatrix_ir.py | 18 ++ hail/python/hail/ir/ir.py | 234 +++++++++++++----- hail/python/hail/ir/matrix_ir.py | 76 +++--- hail/python/hail/ir/renderer.py | 91 ++++--- hail/python/hail/ir/table_ir.py | 36 +-- .../scala/is/hail/expr/ir/LowerMatrixIR.scala | 7 +- 9 files changed, 384 insertions(+), 219 deletions(-) diff --git a/hail/python/hail/expr/matrix_type.py b/hail/python/hail/expr/matrix_type.py index 3f4144d9a62..bb812010d07 100644 --- a/hail/python/hail/expr/matrix_type.py +++ b/hail/python/hail/expr/matrix_type.py @@ -122,34 +122,39 @@ def _rename(self, global_map, col_map, row_map, entry_map): [row_map.get(k, k) for k in self.row_key], self.entry_type._rename(entry_map)) - def global_bindings(self): - return [('global', self.global_type)] - - def global_env(self): - return dict(self.global_bindings()) - - def row_bindings(self): - return [('global', self.global_type), - ('va', self.row_type)] - - def row_env(self): - return dict(self.row_bindings()) - - def col_bindings(self): - return [('global', self.global_type), - ('sa', self.col_type)] - - def col_env(self): - return dict(self.col_bindings()) + def global_env(self, default_value=None): + if default_value is None: + return {'global': self.global_type} + else: + return {'global': default_value} - def entry_bindings(self): - return [('global', self.global_type), - ('va', self.row_type), - ('sa', self.col_type), - ('g', self.entry_type)] + def row_env(self, default_value=None): + if default_value is None: + return {'global': self.global_type, + 'va': self.row_type} + else: + return {'global': default_value, + 'va': default_value} - def entry_env(self): - return dict(self.entry_bindings()) + def col_env(self, default_value=None): + if default_value is None: + return {'global': self.global_type, + 'sa': self.col_type} + else: + return {'global': default_value, + 'sa': default_value} + + def entry_env(self, default_value=None): + if default_value is None: + return {'global': self.global_type, + 'va': self.row_type, + 'sa': self.col_type, + 'g': self.entry_type} + else: + return {'global': default_value, + 'va': default_value, + 'sa': default_value, + 'g': default_value} import pprint diff --git a/hail/python/hail/expr/table_type.py b/hail/python/hail/expr/table_type.py index b908edcd792..06c6fe1a041 100644 --- a/hail/python/hail/expr/table_type.py +++ b/hail/python/hail/expr/table_type.py @@ -80,17 +80,17 @@ def _rename(self, global_map, row_map): self.row_type._rename(row_map), [row_map.get(k, k) for k in self.row_key]) - def row_env(self): - return dict(self.row_bindings()) - - def row_bindings(self): - return [('global', self.global_type), ('row', self.row_type)] - - def global_env(self): - return dict(self.global_bindings()) + def row_env(self, default_value=None): + if default_value is None: + return {'global': self.global_type, 'row': self.row_type} + else: + return {'global': default_value, 'row': default_value} - def global_bindings(self): - return [('global', self.global_type)] + def global_env(self, default_value=None): + if default_value is None: + return {'global': self.global_type} + else: + return {'global': default_value} import pprint diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index 21e50d30cbd..1e852174975 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -7,7 +7,7 @@ from .renderer import Renderer, Renderable, RenderableStr -def _env_bind(env, *bindings): +def _env_bind(env, bindings): if bindings: if env: res = env.copy() @@ -85,20 +85,33 @@ def _eq(self, other): def __hash__(self): return 31 + hash(str(self)) + @abc.abstractmethod def new_block(self, i: int) -> bool: - return self.uses_agg_context(i) or self.uses_scan_context(i) + ... + + @staticmethod + def is_effectful() -> bool: + return False def binds(self, i: int) -> bool: return False - def bindings(self, i: int) -> List[Tuple[str, HailType]]: - return [] + def bindings(self, i: int, default_value=None): + """Compute variables bound in child 'i'. - def agg_bindings(self, i: int) -> List[Tuple[str, HailType]]: - return [] + Returns + ------- + dict + mapping from bound variables to 'default_value', if provided, + otherwise to their types + """ + return {} + + def agg_bindings(self, i: int, default_value=None): + return {} - def scan_bindings(self, i: int) -> List[Tuple[str, HailType]]: - return [] + def scan_bindings(self, i: int, default_value=None): + return {} def uses_agg_context(self, i: int) -> bool: return False @@ -119,15 +132,11 @@ def child_context(self, i: int, parent_context, default_value=None): base = self.child_context_without_bindings(i, parent_context) if not self.binds(i): return base - eval_b = self.bindings(i) - agg_b = self.agg_bindings(i) - scan_b = self.scan_bindings(i) - if default_value is not None: - eval_b = [(var, default_value) for (var, _) in eval_b] - agg_b = [(var, default_value) for (var, _) in agg_b] - scan_b = [(var, default_value) for (var, _) in scan_b] + eval_b = self.bindings(i, default_value) + agg_b = self.agg_bindings(i, default_value) + scan_b = self.scan_bindings(i, default_value) (eval_c, agg_c, scan_c) = base - return _env_bind(eval_c, *eval_b), _env_bind(agg_c, *agg_b), _env_bind(scan_c, *scan_b) + return _env_bind(eval_c, eval_b), _env_bind(agg_c, agg_b), _env_bind(scan_c, scan_b) class IR(BaseIR): @@ -177,15 +186,21 @@ def typ(self): assert self._type is not None, self return self._type + def new_block(self, i: int) -> bool: + return self.uses_agg_context(i) or self.uses_scan_context(i) + @abc.abstractmethod def _compute_type(self, env, agg_env): raise NotImplementedError(self) def parse(self, code, ref_map={}, ir_map={}): - return Env.hail().expr.ir.IRParser.parse_value_ir( - code, - {k: t._parsable_string() for k, t in ref_map.items()}, - ir_map) + try: + return Env.hail().expr.ir.IRParser.parse_value_ir( + code, + {k: t._parsable_string() for k, t in ref_map.items()}, + ir_map) + except: + raise class TableIR(BaseIR): @@ -203,6 +218,9 @@ def typ(self): assert self._type is not None, self return self._type + def new_block(self, i: int) -> bool: + return True + def parse(self, code, ref_map={}, ir_map={}): return Env.hail().expr.ir.IRParser.parse_table_ir(code, ref_map, ir_map) @@ -225,6 +243,9 @@ def typ(self): assert self._type is not None, self return self._type + def new_block(self, i: int) -> bool: + return True + def parse(self, code, ref_map={}, ir_map={}): return Env.hail().expr.ir.IRParser.parse_matrix_ir(code, ref_map, ir_map) @@ -249,6 +270,9 @@ def typ(self): assert self._type is not None, self return self._type + def new_block(self, i: int) -> bool: + return True + def parse(self, code, ref_map={}, ir_map={}): return Env.hail().expr.ir.IRParser.parse_blockmatrix_ir(code, ref_map, ir_map) diff --git a/hail/python/hail/ir/blockmatrix_ir.py b/hail/python/hail/ir/blockmatrix_ir.py index 3a18caefa08..704d29e5b9e 100644 --- a/hail/python/hail/ir/blockmatrix_ir.py +++ b/hail/python/hail/ir/blockmatrix_ir.py @@ -37,6 +37,13 @@ def __init__(self, child, f): def _compute_type(self): self._type = self.child.typ + def bindings(self, i: int, default_value=None): + if i == 1: + value = self.child.typ.element_type if default_value is None else default_value + return {'element': value} + else: + return {} + def binds(self, i): return {'element'} if i == 1 else {} @@ -53,6 +60,17 @@ def _compute_type(self): self.right.typ # Force self._type = self.left.typ + def bindings(self, i: int, default_value=None): + if i == 2: + if default_value is None: + l_value = self.left.typ.element_type + r_value = self.right.typ.element_type + else: + (l_value, r_value) = (default_value, default_value) + return {'l': l_value, 'r': r_value} + else: + return {} + def binds(self, i): return {'l', 'r'} if i == 2 else {} diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 1114f641d18..434b877a209 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -279,11 +279,18 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self.value._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, *self.bindings(1)), agg_env) + self.body._compute_type(_env_bind(env, self.bindings(1)), agg_env) self._type = self.body._type - def bindings(self, i): - return [(self.name, self.value._type)] if i == 1 else [] + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.value._type + else: + value = default_value + return {self.name: value} + else: + return {} def binds(self, i): return i == 1 @@ -552,11 +559,18 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.nd._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, (self.name, self.nd.typ.element_type)), agg_env) + self.body._compute_type(_env_bind(env, [(self.name, self.nd.typ.element_type)]), agg_env) self._type = tndarray(self.body.typ, self.nd.typ.ndim) - def bindings(self, i): - return [(self.name, self.nd.typ.element_type)] if i == 1 else [] + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.nd.typ.element_type + else: + value = default_value + return {self.name: value} + else: + return {} def binds(self, i): return i == 1 @@ -679,6 +693,10 @@ def _compute_type(self, env, agg_env): self.path._compute_type(env, agg_env) self._type = tvoid + @staticmethod + def is_effectful() -> bool: + return True + class ArraySort(IR): @typecheck_method(a=IR, l_name=str, r_name=str, compare=IR) @@ -707,8 +725,15 @@ def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) self._type = self.a.typ - def bindings(self, i): - return [(self.l_name, self.a.typ.element_type), (self.r_name, self.a.typ.element_type)] if i == 1 else [] + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.a.typ.element_type + else: + value = default_value + return {self.l_name: value, self.r_name: value} + else: + return {} def binds(self, i): return i == 1 @@ -820,11 +845,18 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, (self.name, self.a.typ.element_type)), agg_env) + self.body._compute_type(_env_bind(env, [(self.name, self.a.typ.element_type)]), agg_env) self._type = tarray(self.body.typ) - def bindings(self, i): - return [(self.name, self.a.typ.element_type)] if i == 1 else [] + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.a.typ.element_type + else: + value = default_value + return {self.name: value} + else: + return {} def binds(self, i): return i == 1 @@ -854,11 +886,18 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, (self.name, self.a.typ.element_type)), agg_env) + self.body._compute_type(_env_bind(env, [(self.name, self.a.typ.element_type)]), agg_env) self._type = self.a.typ - def bindings(self, i): - return [(self.name, self.a.typ.element_type)] if i == 1 else [] + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.a.typ.element_type + else: + value = default_value + return {self.name: value} + else: + return {} def binds(self, i): return i == 1 @@ -888,11 +927,17 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, (self.name, self.a.typ.element_type)), agg_env) + self.body._compute_type(_env_bind(env, [(self.name, self.a.typ.element_type)]), agg_env) self._type = tarray(self.body.typ.element_type) - def bindings(self, i): - return [(self.name, self.a.typ.element_type)] if i == 1 else [] + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.a.typ.element_type + else: + value = default_value + return {self.name: value} + return {} def binds(self, i): return i == 1 @@ -927,13 +972,19 @@ def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) self.zero._compute_type(env, agg_env) self.body._compute_type( - _env_bind(env, (self.value_name, self.a.typ.element_type), - (self.accum_name, self.zero.typ)), + _env_bind(env, [(self.value_name, self.a.typ.element_type), + (self.accum_name, self.zero.typ)]), agg_env) self._type = self.zero.typ - def bindings(self, i): - return [(self.accum_name, self.zero.typ), (self.value_name, self.a.typ.element_type)] if i == 2 else [] + def bindings(self, i, default_value=None): + if i == 2: + if default_value is None: + return {self.accum_name: self.zero.typ, self.value_name: self.a.typ.element_type} + else: + return {self.accum_name: default_value, self.value_name: default_value} + else: + return {} def binds(self, i): return i == 2 @@ -968,13 +1019,19 @@ def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) self.zero._compute_type(env, agg_env) self.body._compute_type( - _env_bind(env, (self.value_name, self.a.typ.element_type), - (self.accum_name, self.zero.typ)), + _env_bind(env, [(self.value_name, self.a.typ.element_type), + (self.accum_name, self.zero.typ)]), agg_env) self._type = tarray(self.body.typ) - def bindings(self, i): - return [(self.accum_name, self.zero.typ), (self.value_name, self.a.typ.element_type)] if i == 2 else [] + def bindings(self, i, default_value=None): + if i == 2: + if default_value is None: + return {self.accum_name: self.zero.typ, self.value_name: self.a.typ.element_type} + else: + return {self.accum_name: default_value, self.value_name: default_value} + else: + return {} def binds(self, i): return i == 2 @@ -1006,10 +1063,16 @@ def _eq(self, other): def bound_variables(self): return {self.l_name, self.r_name} | super().bound_variables - def bindings(self, i): - return [(self.l_name, self.left.typ.element_type), - (self.r_name, self.right.typ.element_type)]\ - if i == 2 or i == 3 else [] + def bindings(self, i, default_value=None): + if i == 2 or i == 3: + if default_value is None: + return {self.l_name: self.left.typ.element_type, + self.r_name: self.right.typ.element_type} + else: + return {self.l_name: default_value, + self.r_name: default_value} + else: + return {} def binds(self, i): return i == 2 or i == 3 @@ -1039,11 +1102,18 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, (self.value_name, self.a.typ.element_type)), agg_env) + self.body._compute_type(_env_bind(env, [(self.value_name, self.a.typ.element_type)]), agg_env) self._type = tvoid - def bindings(self, i): - return [(self.value_name, self.a.typ.element_type)] if i == 1 else [] + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.a.typ.element_type + else: + value = default_value + return {self.value_name: value} + else: + return {} def binds(self, i): return i == 1 @@ -1104,15 +1174,22 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.array._compute_type(agg_env, None) - new_agg = _env_bind(agg_env, (self.name, self.array.typ.element_type)) - self.agg_body._compute_type(env, _env_bind(agg_env, (self.name, self.array.typ.element_type))) + new_agg = _env_bind(agg_env, [(self.name, self.array.typ.element_type)]) + self.agg_body._compute_type(env, new_agg) self._type = self.agg_body.typ - def agg_bindings(self, i): - return [(self.name, self.array.typ.element_type)] if i == 1 else [] + def agg_bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.array.typ.element_type + else: + value = default_value + return {self.name: value} + else: + return {} - def scan_bindings(self, i): - return [(self.name, self.array.typ.element_type)] if i == 1 else [] + def scan_bindings(self, i, default_value=None): + return self.agg_bindings(i, default_value) def binds(self, i): return i == 1 @@ -1176,8 +1253,8 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self.array._compute_type(agg_env, None) - self.agg_ir._compute_type(_env_bind(env, (self.index_name, tint32)), - _env_bind(agg_env, (self.element_name, self.array.typ.element_type))) + self.agg_ir._compute_type(_env_bind(env, [(self.index_name, tint32)]), + _env_bind(agg_env, [(self.element_name, self.array.typ.element_type)])) self._type = tarray(self.agg_ir.typ) @property @@ -1190,14 +1267,25 @@ def uses_agg_context(self, i: int): def uses_scan_context(self, i: int): return i == 0 and self.is_scan - def bindings(self, i): - return [(self.index_name, tint32)] if i == 1 else [] + def bindings(self, i, default_value=None): + if i == 1: + value = tint32 if default_value is None else default_value + return {self.index_name: value} + else: + return {} - def agg_bindings(self, i): - return [(self.element_name, self.array.typ.element_type)] if i == 1 else [] + def agg_bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.array.typ.element_type + else: + value = default_value + return {self.element_name: value} + else: + return {} - def scan_bindings(self, i): - return [(self.element_name, self.array.typ.element_type)] if i == 1 else [] + def scan_bindings(self, i, default_value=None): + return self.agg_bindings(i, default_value) def binds(self, i): return i == 1 @@ -1575,6 +1663,10 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self._type = self._typ + @staticmethod + def is_effectful() -> bool: + return True + _function_registry = defaultdict(list) _seeded_function_registry = defaultdict(list) @@ -1696,6 +1788,10 @@ def _compute_type(self, env, agg_env): self._type = lookup_seeded_function_return_type(self.function, [a.typ for a in self.args]) + @staticmethod + def is_effectful() -> bool: + return True + class Uniroot(IR): @typecheck_method(argname=str, function=IR, min=IR, max=IR) @@ -1721,13 +1817,17 @@ def _eq(self, other): return other.argname == self.argname def _compute_type(self, env, agg_env): - self.function._compute_type(_env_bind(env, (self.argname, tfloat64)), agg_env) + self.function._compute_type(_env_bind(env, [(self.argname, tfloat64)]), agg_env) self.min._compute_type(env, agg_env) self.max._compute_type(env, agg_env) self._type = tfloat64 - def bindings(self, i): - return [(self.argname, tfloat64)] if i == 0 else [] + def bindings(self, i, default_value=None): + if i == 0: + value = tfloat64 if default_value is None else default_value + return {self.argname: value} + else: + return {} def binds(self, i): return i == 0 @@ -1797,11 +1897,11 @@ def _compute_type(self, env, agg_env): def new_block(self, i: int): return i == 1 - def bindings(self, i): - return self.child.typ.global_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.global_env(default_value) if i == 1 else {} - def agg_bindings(self, i): - return self.child.typ.row_bindings() if i == 1 else [] + def agg_bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 @@ -1830,11 +1930,11 @@ def _compute_type(self, env, agg_env): def new_block(self, i: int): return i == 1 - def bindings(self, i): - return self.child.typ.global_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.global_env(default_value) if i == 1 else {} - def agg_bindings(self, i): - return self.child.typ.entry_bindings() if i == 1 else [] + def agg_bindings(self, i, default_value=None): + return self.child.typ.entry_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 @@ -1861,6 +1961,10 @@ def _compute_type(self, env, agg_env): self.child._compute_type() self._type = tvoid + @staticmethod + def is_effectful() -> bool: + return True + class MatrixWrite(IR): @typecheck_method(child=MatrixIR, matrix_writer=MatrixWriter) def __init__(self, child, matrix_writer): @@ -1882,6 +1986,10 @@ def _compute_type(self, env, agg_env): self.child._compute_type() self._type = tvoid + @staticmethod + def is_effectful() -> bool: + return True + class MatrixMultiWrite(IR): @typecheck_method(children=sequenceof(MatrixIR), writer=MatrixNativeMultiWriter) @@ -1903,6 +2011,10 @@ def _compute_type(self, env, agg_env): x._compute_type() self._type = tvoid + @staticmethod + def is_effectful() -> bool: + return True + class BlockMatrixWrite(IR): @typecheck_method(child=BlockMatrixIR, writer=BlockMatrixWriter) @@ -1924,6 +2036,10 @@ def _compute_type(self, env, agg_env): self.child._compute_type() self._type = tvoid + @staticmethod + def is_effectful() -> bool: + return True + class BlockMatrixMultiWrite(IR): @typecheck_method(block_matrices=sequenceof(BlockMatrixIR), writer=BlockMatrixMultiWriter) @@ -1946,6 +2062,10 @@ def _compute_type(self, env, agg_env): x._compute_type() self._type = tvoid + @staticmethod + def is_effectful() -> bool: + return True + class TableToValueApply(IR): def __init__(self, child, config): diff --git a/hail/python/hail/ir/matrix_ir.py b/hail/python/hail/ir/matrix_ir.py index 485477c2d1f..7a791550d00 100644 --- a/hail/python/hail/ir/matrix_ir.py +++ b/hail/python/hail/ir/matrix_ir.py @@ -26,21 +26,21 @@ def _compute_type(self): child_typ.row_key, self.entry_expr.typ) - def bindings(self, i): + def bindings(self, i, default_value=None): if i == 1: - return self.child.typ.col_bindings() + return self.child.typ.col_env(default_value) elif i == 2: - return self.child.typ.global_bindings() + return self.child.typ.global_env(default_value) else: - return [] + return {} - def agg_bindings(self, i): + def agg_bindings(self, i, default_value=None): if i == 1: - return self.child.typ.entry_bindings() + return self.child.typ.entry_env(default_value) elif i == 2: - return self.child.typ.row_bindings() + return self.child.typ.row_env(default_value) else: - return [] + return {} def binds(self, i): return i == 1 or i == 2 @@ -77,8 +77,8 @@ def _compute_type(self): self.pred._compute_type(self.child.typ.row_env(), None) self._type = self.child.typ - def bindings(self, i): - return self.child.typ.row_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 @@ -128,14 +128,14 @@ def _compute_type(self): child_typ.row_key, child_typ.entry_type) - def bindings(self, i): - return self.child.typ.col_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.col_env(default_value) if i == 1 else {} - def agg_bindings(self, i): - return self.child.typ.entry_bindings() if i == 1 else [] + def agg_bindings(self, i, default_value=None): + return self.child.typ.entry_env(default_value) if i == 1 else {} - def scan_bindings(self, i): - return self.child.typ.col_bindings() if i == 1 else [] + def scan_bindings(self, i, default_value=None): + return self.child.typ.col_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 @@ -173,8 +173,8 @@ def _compute_type(self): child_typ.row_key, self.new_entry.typ) - def bindings(self, i): - return self.child.typ.entry_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.entry_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 @@ -194,8 +194,8 @@ def _compute_type(self): self.pred._compute_type(self.child.typ.entry_env(), None) self._type = self.child.typ - def bindings(self, i): - return self.child.typ.entry_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.entry_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 @@ -248,14 +248,14 @@ def _compute_type(self): child_typ.row_key, child_typ.entry_type) - def bindings(self, i): - return self.child.typ.row_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} - def agg_bindings(self, i): - return self.child.typ.entry_bindings() if i == 1 else [] + def agg_bindings(self, i, default_value=None): + return self.child.typ.entry_env(default_value) if i == 1 else {} - def scan_bindings(self, i): - return self.child.typ.row_bindings() if i == 1 else [] + def scan_bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 @@ -282,8 +282,8 @@ def _compute_type(self): child_typ.row_key, child_typ.entry_type) - def bindings(self, i): - return self.child.typ.global_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.global_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 @@ -303,8 +303,8 @@ def _compute_type(self): self.pred._compute_type(self.child.typ.col_env(), None) self._type = self.child.typ - def bindings(self, i): - return self.child.typ.col_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.col_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 @@ -350,21 +350,21 @@ def _compute_type(self): child_typ.row_key, self.entry_expr.typ) - def bindings(self, i): + def bindings(self, i, default_value=None): if i == 1: - return self.child.typ.row_bindingss() + return self.child.typ.row_env(default_value) elif i == 2: - return self.child.typ.global_bindings() + return self.child.typ.global_env(default_value) else: - return [] + return {} - def agg_bindings(self, i): + def agg_bindings(self, i, default_value=None): if i == 1: - return self.child.typ.entry_bindings() + return self.child.typ.entry_env(default_value) elif i == 2: - return self.child.typ.col_bindings() + return self.child.typ.col_env(default_value) else: - return [] + return {} def binds(self, i): return i == 1 or i == 2 diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 04b400adb71..83e87f3416e 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -173,7 +173,10 @@ def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR', # compute depth at which we might bind this node def bind_depth(self) -> int: if len(self.free_vars) > 0: - bind_depth = max(self.free_vars.values()) + try: + bind_depth = max(self.free_vars.values()) + except: + raise bind_depth = max(bind_depth, self.min_binding_depth) else: bind_depth = self.min_binding_depth @@ -189,24 +192,11 @@ def make_child_frame(self, depth: int) -> 'StackFrame': child_outermost_scope = self.min_binding_depth # compute vars bound in 'child' by 'node' - def get_vars(bindings): - return [var for (var, _) in bindings] - eval_bindings = get_vars(x.bindings(i)) - agg_bindings = get_vars(x.agg_bindings(i)) - scan_bindings = get_vars(x.scan_bindings(i)) - - child_context = x.child_context_without_bindings(i, self.context) - if x.binds(i): - (eval_c, agg_c, scan_c) = child_context - eval_c = ir.base_ir._env_bind(eval_c, *[(var, depth) for var in - eval_bindings]) - agg_c = ir.base_ir._env_bind(agg_c, *[(var, depth) for var in - agg_bindings]) - scan_c = ir.base_ir._env_bind(scan_c, *[(var, depth) for var in - scan_bindings]) - child_context = (eval_c, agg_c, scan_c) - + eval_bindings = x.bindings(i, 0).keys() + agg_bindings = x.agg_bindings(i, 0).keys() + scan_bindings = x.scan_bindings(i, 0).keys() new_bindings = (eval_bindings, agg_bindings, scan_bindings) + child_context = x.child_context(i, self.context, depth) return StackFrame(child_outermost_scope, child_context, child, new_bindings) @@ -237,19 +227,11 @@ def uid(self) -> str: self.uid_count += 1 return f'__cse_{self.uid_count}' - def add_jir(self, x: 'ir.BaseIR'): + def add_jir(self, jir): jir_id = f'm{self.jir_count}' self.jir_count += 1 - self.jirs[jir_id] = x._jir - if isinstance(x, ir.MatrixIR): - return f'(JavaMatrix {jir_id})' - elif isinstance(x, ir.TableIR): - return f'(JavaTable {jir_id})' - elif isinstance(x, ir.BlockMatrixIR): - return f'(JavaBlockMatrix {jir_id})' - else: - assert isinstance(x, ir.IR) - return f'(JavaIR {jir_id})' + self.jirs[jir_id] = jir + return jir_id @staticmethod def find_in_scope(x: 'ir.BaseIR', context: List[StackFrame], outermost_scope: int) -> int: @@ -290,16 +272,16 @@ def __init__(self, x, children, builder, outermost_scope, depth): self.depth = depth self.i = 0 - def pre_add_lets(self, binding_sites, context): - x = self.x - insert_lets = id(x) in binding_sites and len(binding_sites[id(x)].lifted_lets) > 0 - if insert_lets: - self.builder = [] - context.append(PrintStackFrame(binding_sites[id(x)])) - head = x.render_head(self) - if head != '': - self.builder.append(head) - return insert_lets + def pre_add_lets(self, state, binding_sites, context): + x = state.x + insert_lets = id(x) in binding_sites and len(binding_sites[id(x)].lifted_lets) > 0 + if insert_lets: + state.builder = [] + context.append(PrintStackFrame(binding_sites[id(x)])) + head = x.render_head(self) + if head != '': + state.builder.append(head) + return insert_lets class Kont: pass @@ -408,7 +390,7 @@ def loop(self, root): stack.append(self.PostChildrenLifted(new_state, state.builder, lift_to, name)) continue - insert_lets = new_state.pre_add_lets(self.binding_sites, context) + insert_lets = self.pre_add_lets(new_state, self.binding_sites, context) assert(not insert_lets) stack.append(self.PostChildrenLifted(new_state, state.builder, lift_to, name)) @@ -420,7 +402,7 @@ def loop(self, root): state.i += 1 continue - insert_lets = new_state.pre_add_lets(self.binding_sites, context) + insert_lets = self.pre_add_lets(new_state, self.binding_sites, context) stack.append(self.PostChildren(new_state, state.builder, insert_lets)) continue @@ -447,9 +429,6 @@ def add_lets(context, local_builder, builder): builder.append(')') def compute_new_bindings(self, root: 'ir.BaseIR'): - # force computation of all types - root.typ - root_frame = StackFrame(0, ({}, {}, {}), root) stack = [root_frame] binding_sites = {} @@ -467,7 +446,8 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): parent_frame = stack[-2] # mark node as visited at potential let insertion site - stack[frame.bind_depth()].visited[id(node)] = node + if not node.is_effectful(): + stack[frame.bind_depth()].visited[id(node)] = node # if any lets being inserted here, add node to registry of # binding sites @@ -483,7 +463,18 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): child = node.children[child_idx] if self.stop_at_jir and hasattr(child, '_jir'): - self.memo[id(child)] = self.add_jir(child) + jir_id = self.add_jir(child._jir) + if isinstance(child, ir.MatrixIR): + jref = f'(JavaMatrix {jir_id})' + elif isinstance(child, ir.TableIR): + jref = f'(JavaTable {jir_id})' + elif isinstance(child, ir.BlockMatrixIR): + jref = f'(JavaBlockMatrix {jir_id})' + else: + assert isinstance(child, ir.IR) + jref = f'(JavaIR {jir_id})' + + self.memo[id(child)] = jref continue seen_in_scope = self.find_in_scope(child, stack, @@ -508,11 +499,13 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): # first time visiting 'child' if isinstance(child, ir.Ref): - if child.name not in (var for (var, _) in - node.bindings(child_idx)): + if child.name not in node.bindings(child_idx, default_value=0).keys(): (eval_c, _, _) = node.child_context_without_bindings( child_idx, frame.context) - frame.free_vars.update({child.name: eval_c[child.name]}) + try: + frame.free_vars.update({child.name: eval_c[child.name]}) + except: + raise continue stack.append(frame.make_child_frame(len(stack))) diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index 23b40aebe49..e17ce80bc9f 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -139,8 +139,8 @@ def _compute_type(self): def new_block(i): return i == 1 - def bindings(self, i): - return self.child.typ.global_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.global_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 @@ -202,11 +202,11 @@ def _compute_type(self): self.new_row.typ, self.child.typ.row_key) - def bindings(self, i): - return self.child.typ.row_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} - def scan_bindings(self, i): - return self.child.typ.row_bindings() if i == 1 else [] + def scan_bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 @@ -276,8 +276,8 @@ def _compute_type(self): self.pred._compute_type(self.child.typ.row_env(), None) self._type = self.child.typ - def bindings(self, i): - return self.child.typ.row_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 @@ -309,16 +309,16 @@ def _compute_type(self): self.new_key.typ._concat(self.expr.typ), list(self.new_key.typ)) - def bindings(self, i): + def bindings(self, i, default_value=None): if i == 1: - return self.child.typ.global_bindings() + return self.child.typ.global_env(default_value) elif i == 2: - return self.child.typ.row_bindings() + return self.child.typ.row_env(default_value) else: - return [] + return {} - def agg_bindings(self, i): - return self.child.typ.row_bindings() if i == 1 else [] + def agg_bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 or i == 2 @@ -341,11 +341,11 @@ def _compute_type(self): child_typ.key_type._concat(self.expr.typ), child_typ.row_key) - def bindings(self, i): - return self.child.typ.row_bindings() if i == 1 else [] + def bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} - def agg_bindings(self, i): - return self.child.typ.row_bindings() if i == 1 else [] + def agg_bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} def binds(self, i): return i == 1 diff --git a/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala index 38509dfdaf0..2d25fc8fa05 100644 --- a/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala @@ -66,6 +66,11 @@ object LowerMatrixIR { BindingEnv(e, agg = Some(e), scan = Some(e)) } + def matrixGlobalSubstEnv(child: MatrixIR): BindingEnv[IRProxy] = { + val e = Env[IRProxy]("global" -> 'global.selectFields(child.typ.globalType.fieldNames: _*)) + BindingEnv(e, agg = Some(e), scan = Some(e)) + } + def matrixSubstEnvIR(child: MatrixIR, lowered: TableIR): BindingEnv[IR] = { val e = Env[IR]("global" -> SelectFields(Ref("global", lowered.typ.globalType), child.typ.globalType.fieldNames), "va" -> SelectFields(Ref("row", lowered.typ.rowType), child.typ.rowType.fieldNames)) @@ -128,7 +133,7 @@ object LowerMatrixIR { irRange(0, 'global (colsField).len) .filter('i ~> (let(sa = 'global (colsField)('i)) - in subst(pred, matrixSubstEnv(child)))))) + in subst(pred, matrixGlobalSubstEnv(child)))))) .mapRows('row.insertFields(entriesField -> 'global ('newColIdx).map('i ~> 'row (entriesField)('i)))) .mapGlobals('global .insertFields(colsField -> From 427d8c78c9bc6fb067682dac6058ee753a5bc984 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 14 Aug 2019 12:50:30 -0400 Subject: [PATCH 25/47] wip --- hail/python/hail/backend/backend.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 8a993e76b05..232d6d93ea4 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -100,8 +100,15 @@ def fs(self): def _to_java_ir(self, ir): if not hasattr(ir, '_jir'): r = CSERenderer(stop_at_jir=True) + r_old = Renderer(stop_at_jir=True) # FIXME parse should be static - ir._jir = ir.parse(r(ir), ir_map=r.jirs) + no_cse = r_old(ir) + print('without cse:') + print(no_cse) + cse = r(ir) + print('with cse:') + print(cse) + ir._jir = ir.parse(cse, ir_map=r.jirs) return ir._jir def execute(self, ir, timed=False): From 6aa0ba14c3a7c35794f10d900e62d176992a63a6 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 22 Aug 2019 10:10:54 -0400 Subject: [PATCH 26/47] refactoring --- hail/python/hail/ir/renderer.py | 203 +++++++++++++++----------------- 1 file changed, 98 insertions(+), 105 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 83e87f3416e..3677c5dbfa8 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -136,26 +136,11 @@ def __call__(self, x: 'Renderable'): return ''.join(builder) -class BindingSite: - def __init__(self, lifted_lets: Dict[int, Tuple[str, 'ir.BaseIR']], depth: int, node: 'ir.BaseIR'): - self.depth = depth - self.lifted_lets = lifted_lets - self.node = node - - -class PrintStackFrame: - def __init__(self, binding_site: BindingSite): - self.depth = binding_site.depth - self.lifted_lets = binding_site.lifted_lets - self.visited = {} - self.let_bodies = [] - - Vars = Dict[str, int] Context = (Vars, Vars, Vars) -class StackFrame: +class AnalysisStackFrame: def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR', new_bindings=(None, None, None)): # immutable @@ -182,7 +167,7 @@ def bind_depth(self) -> int: bind_depth = self.min_binding_depth return bind_depth - def make_child_frame(self, depth: int) -> 'StackFrame': + def make_child_frame(self, depth: int) -> 'AnalysisStackFrame': x = self.node i = self.child_idx - 1 child = x.children[i] @@ -198,8 +183,8 @@ def make_child_frame(self, depth: int) -> 'StackFrame': new_bindings = (eval_bindings, agg_bindings, scan_bindings) child_context = x.child_context(i, self.context, depth) - return StackFrame(child_outermost_scope, child_context, child, - new_bindings) + return AnalysisStackFrame(child_outermost_scope, child_context, child, + new_bindings) def update_parent_free_vars(self, parent_free_vars: Set[str]): # subtract vars bound by parent from free_vars @@ -214,6 +199,45 @@ def update_parent_free_vars(self, parent_free_vars: Set[str]): parent_free_vars.update(self.free_vars) +class BindingSite: + def __init__(self, lifted_lets: Dict[int, Tuple[str, 'ir.BaseIR']], depth: int, node: 'ir.BaseIR'): + self.depth = depth + self.lifted_lets = lifted_lets + self.node = node + + +class BindingsStackFrame: + def __init__(self, binding_site: BindingSite): + self.depth = binding_site.depth + self.lifted_lets = binding_site.lifted_lets + self.visited = {} + self.let_bodies = [] + + +class PrintStackFrame: + def __init__(self, x, children, builder, outermost_scope, depth): + self.x = x + self.children = children + self.local_builder = builder + self.outermost_scope = outermost_scope + self.ir_child_num = 0 + self.depth = depth + self.i = 0 + +class PostChildrenLifted(PrintStackFrame): + def __init__(self, x, children, local_builder, outermost_scope, depth, builder, lift_to, name): + super().__init__(x, children, local_builder, outermost_scope, depth) + self.builder = builder + self.lift_to = lift_to + self.name = name + +class PostChildren(PrintStackFrame): + def __init__(self, x, children, local_builder, outermost_scope, depth, builder, insert_lets): + super().__init__(x, children, local_builder, outermost_scope, depth) + self.builder = builder + self.insert_lets = insert_lets + + class CSERenderer(Renderer): def __init__(self, stop_at_jir=False): self.stop_at_jir = stop_at_jir @@ -234,14 +258,14 @@ def add_jir(self, jir): return jir_id @staticmethod - def find_in_scope(x: 'ir.BaseIR', context: List[StackFrame], outermost_scope: int) -> int: + def find_in_scope(x: 'ir.BaseIR', context: List[AnalysisStackFrame], outermost_scope: int) -> int: for i in reversed(range(len(context))): if id(x) in context[i].visited: return i return -1 @staticmethod - def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int: + def lifted_in_scope(x: 'ir.BaseIR', context: List[AnalysisStackFrame]) -> int: for i in range(len(context)): if id(x) in context[i].lifted_lets: return i @@ -262,110 +286,71 @@ def lifted_in_scope(x: 'ir.BaseIR', context: List[StackFrame]) -> int: # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing # any lets to be inserted above y. - class State: - def __init__(self, x, children, builder, outermost_scope, depth): - self.x = x - self.children = children - self.builder = builder - self.outermost_scope = outermost_scope - self.ir_child_num = 0 - self.depth = depth - self.i = 0 - - def pre_add_lets(self, state, binding_sites, context): - x = state.x - insert_lets = id(x) in binding_sites and len(binding_sites[id(x)].lifted_lets) > 0 - if insert_lets: - state.builder = [] - context.append(PrintStackFrame(binding_sites[id(x)])) - head = x.render_head(self) - if head != '': - state.builder.append(head) - return insert_lets - - class Kont: - pass - - class PostChildrenLifted(Kont): - def __init__(self, state, builder, lift_to, name): - self.state = state - self.builder = builder - self.local_builder = state.builder - self.lift_to = lift_to - self.name = name - - class PostChildren(Kont): - def __init__(self, state, builder, insert_lets): - self.state = state - self.builder = builder - self.local_builder = state.builder - self.insert_lets = insert_lets - - def loop(self, root): + def build_string(self, root): root_builder = [] context = [] - root_state = self.State(root, root.render_children(self), root_builder, 0, 1) + root_state = PrintStackFrame(root, root.render_children(self), root_builder, 0, 1) if id(root) in self.memo: - root_state.builder.append(self.memo[id(root)]) + root_state.local_builder.append(self.memo[id(root)]) return ''.join(root_builder) insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0 if insert_lets: - root_state.builder = [] - context.append(PrintStackFrame(self.binding_sites[id(root)])) + root_state.local_builder = [] + context.append(BindingsStackFrame(self.binding_sites[id(root)])) head = root.render_head(self) if head != '': - root_state.builder.append(head) + root_state.local_builder.append(head) - root_frame = lambda x: None - root_frame.state = root_state - stack = [root_frame] + stack = [root_state] while True: - state = stack[-1].state + state = stack[-1] if state.i >= len(state.children): if len(stack) <= 1: - root_state.builder.append(root.render_tail(self)) + root_state.local_builder.append(root.render_tail(self)) insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0 if insert_lets: - self.add_lets(context, root_state.builder, root_builder) + self.add_lets(context, root_state.local_builder, root_builder) return ''.join(root_builder) - k = stack[-1] - if isinstance(k, self.PostChildrenLifted): - if id(k.state.x) in self.memo: - k.local_builder.extend(self.memo[id(k.state.x)]) + if isinstance(state, PostChildrenLifted): + if id(state.x) in self.memo: + state.local_builder.extend(self.memo[id(state.x)]) else: - k.local_builder.append(k.state.x.render_tail(self)) - k.local_builder.append(' ') + state.local_builder.append(state.x.render_tail(self)) + state.local_builder.append(' ') # let_bodies is built post-order, which guarantees earlier # lets can't refer to later lets - context[k.lift_to].let_bodies.append(k.local_builder) - k.builder.append(f'(Ref {k.name})') + context[state.lift_to].let_bodies.append(state.local_builder) + state.builder.append(f'(Ref {state.name})') stack.pop() - stack[-1].state.i += 1 + stack[-1].i += 1 continue else: - assert isinstance(k, self.PostChildren) - k.local_builder.append(k.state.x.render_tail(self)) - if k.insert_lets: - self.add_lets(context, k.local_builder, k.builder) + assert isinstance(state, PostChildren) + state.local_builder.append(state.x.render_tail(self)) + if state.insert_lets: + self.add_lets(context, state.local_builder, state.builder) stack.pop() - state = stack[-1].state + state = stack[-1] if not isinstance(state.x, ir.BaseIR): state.ir_child_num += 1 state.i += 1 continue - state.builder.append(' ') + state.local_builder.append(' ') child = state.children[state.i] - new_state = self.State(child, child.render_children(self), state.builder, state.outermost_scope, state.depth) + child_children = child.render_children(self) + child_local_builder = state.local_builder + child_outermost_scope = state.outermost_scope + child_depth = state.depth if isinstance(state.x, ir.BaseIR): if state.x.new_block(state.ir_child_num): - new_state.outermost_scope = state.depth + child_outermost_scope = state.depth if isinstance(child, ir.BaseIR): - new_state.depth += 1 + child_depth += 1 lift_to = self.lifted_in_scope(child, context) else: lift_to = -1 @@ -377,44 +362,52 @@ def loop(self, root): (name, _) = context[lift_to].lifted_lets[id(child)] if id(child) in context[lift_to].visited: - state.builder.append(f'(Ref {name})') + state.local_builder.append(f'(Ref {name})') state.i += 1 continue context[lift_to].visited[id(child)] = child - new_state.builder = [f'(Let {name} '] + child_local_builder = [f'(Let {name} '] if id(child) in self.memo: - new_state.children = [] - stack.append(self.PostChildrenLifted(new_state, state.builder, lift_to, name)) + child_children = [] + stack.append(PostChildrenLifted(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, lift_to, name)) continue - insert_lets = self.pre_add_lets(new_state, self.binding_sites, context) - assert(not insert_lets) + assert(not (id(child) in self.binding_sites and len(self.binding_sites[id(child)].lifted_lets) > 0)) + head = child.render_head(self) + if head != '': + child_local_builder.append(head) - stack.append(self.PostChildrenLifted(new_state, state.builder, lift_to, name)) + stack.append(PostChildrenLifted(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, lift_to, name)) continue if isinstance(child, ir.BaseIR): if id(child) in self.memo: - new_state.builder.extend(self.memo[id(child)]) + child_local_builder.extend(self.memo[id(child)]) state.i += 1 continue - insert_lets = self.pre_add_lets(new_state, self.binding_sites, context) + insert_lets = id(child) in self.binding_sites and len(self.binding_sites[id(child)].lifted_lets) > 0 + if insert_lets: + child_local_builder = [] + context.append(BindingsStackFrame(self.binding_sites[id(child)])) + head = child.render_head(self) + if head != '': + child_local_builder.append(head) - stack.append(self.PostChildren(new_state, state.builder, insert_lets)) + stack.append(PostChildren(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, insert_lets)) continue else: head = child.render_head(self) if head != '': - state.builder.append(head) + child_local_builder.append(head) + new_state = PostChildren(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, False) new_state.ir_child_num = state.ir_child_num - insert_lets = False - stack.append(self.PostChildren(new_state, state.builder, insert_lets)) + stack.append(new_state) continue @staticmethod @@ -429,7 +422,7 @@ def add_lets(context, local_builder, builder): builder.append(')') def compute_new_bindings(self, root: 'ir.BaseIR'): - root_frame = StackFrame(0, ({}, {}, {}), root) + root_frame = AnalysisStackFrame(0, ({}, {}, {}), root) stack = [root_frame] binding_sites = {} @@ -522,4 +515,4 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): def __call__(self, root: 'ir.BaseIR') -> str: self.binding_sites = self.compute_new_bindings(root) - return self.loop(root) + return self.build_string(root) From 3d3671f81eca5dd114dacc56df1fb3e4e9b15934 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 22 Aug 2019 15:35:42 -0400 Subject: [PATCH 27/47] refactoring --- hail/python/hail/ir/renderer.py | 345 +++++++++++++++----------------- 1 file changed, 157 insertions(+), 188 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 3677c5dbfa8..59e8e970931 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -158,10 +158,7 @@ def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR', # compute depth at which we might bind this node def bind_depth(self) -> int: if len(self.free_vars) > 0: - try: - bind_depth = max(self.free_vars.values()) - except: - raise + bind_depth = max(self.free_vars.values()) bind_depth = max(bind_depth, self.min_binding_depth) else: bind_depth = self.min_binding_depth @@ -215,27 +212,25 @@ def __init__(self, binding_site: BindingSite): class PrintStackFrame: - def __init__(self, x, children, builder, outermost_scope, depth): + def __init__(self, x, children, local_builder, outermost_scope, depth, builder, insert_lets, lift_to_frame=None): self.x = x self.children = children - self.local_builder = builder + self.builder = builder + self.local_builder = local_builder self.outermost_scope = outermost_scope self.ir_child_num = 0 self.depth = depth + self.insert_lets = insert_lets + self.lift_to_frame = lift_to_frame self.i = 0 -class PostChildrenLifted(PrintStackFrame): - def __init__(self, x, children, local_builder, outermost_scope, depth, builder, lift_to, name): - super().__init__(x, children, local_builder, outermost_scope, depth) - self.builder = builder - self.lift_to = lift_to - self.name = name - -class PostChildren(PrintStackFrame): - def __init__(self, x, children, local_builder, outermost_scope, depth, builder, insert_lets): - super().__init__(x, children, local_builder, outermost_scope, depth) - self.builder = builder - self.insert_lets = insert_lets + def add_lets(self, let_bodies): + for let_body in let_bodies: + self.builder.extend(let_body) + self.builder.extend(self.local_builder) + num_lets = len(let_bodies) + for _ in range(num_lets): + self.builder.append(')') class CSERenderer(Renderer): @@ -257,170 +252,35 @@ def add_jir(self, jir): self.jirs[jir_id] = jir return jir_id - @staticmethod - def find_in_scope(x: 'ir.BaseIR', context: List[AnalysisStackFrame], outermost_scope: int) -> int: - for i in reversed(range(len(context))): - if id(x) in context[i].visited: - return i - return -1 - - @staticmethod - def lifted_in_scope(x: 'ir.BaseIR', context: List[AnalysisStackFrame]) -> int: - for i in range(len(context)): - if id(x) in context[i].lifted_lets: - return i - return -1 - - # Pre: - # * 'context' is a list of 'Scope's, one for each potential let-insertion - # site. - # * 'ref_to_scope' maps each bound variable to the index in 'context' of the - # scope of its binding site. - # Post: - # * Returns set of free variables in 'x'. - # * Each subtree of 'x' is flagged as visited in the outermost scope - # containing all of its free variables. - # * Each subtree previously visited (either an earlier subtree of 'x', or - # marked visited in 'context') is added to set of lets in its (previously - # computed) outermost scope. - # * 'self.scopes' is updated to map subtrees y of 'x' to scopes containing - # any lets to be inserted above y. - - def build_string(self, root): - root_builder = [] - context = [] - root_state = PrintStackFrame(root, root.render_children(self), root_builder, 0, 1) - if id(root) in self.memo: - root_state.local_builder.append(self.memo[id(root)]) - return ''.join(root_builder) - insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0 - if insert_lets: - root_state.local_builder = [] - context.append(BindingsStackFrame(self.binding_sites[id(root)])) - head = root.render_head(self) - if head != '': - root_state.local_builder.append(head) - - stack = [root_state] - - while True: - state = stack[-1] - - if state.i >= len(state.children): - if len(stack) <= 1: - root_state.local_builder.append(root.render_tail(self)) - insert_lets = id(root) in self.binding_sites and len(self.binding_sites[id(root)].lifted_lets) > 0 - if insert_lets: - self.add_lets(context, root_state.local_builder, root_builder) - return ''.join(root_builder) - if isinstance(state, PostChildrenLifted): - if id(state.x) in self.memo: - state.local_builder.extend(self.memo[id(state.x)]) - else: - state.local_builder.append(state.x.render_tail(self)) - state.local_builder.append(' ') - # let_bodies is built post-order, which guarantees earlier - # lets can't refer to later lets - context[state.lift_to].let_bodies.append(state.local_builder) - state.builder.append(f'(Ref {state.name})') - stack.pop() - stack[-1].i += 1 - continue - else: - assert isinstance(state, PostChildren) - state.local_builder.append(state.x.render_tail(self)) - if state.insert_lets: - self.add_lets(context, state.local_builder, state.builder) - stack.pop() - state = stack[-1] - if not isinstance(state.x, ir.BaseIR): - state.ir_child_num += 1 - state.i += 1 - continue - - state.local_builder.append(' ') - child = state.children[state.i] - - child_children = child.render_children(self) - child_local_builder = state.local_builder - child_outermost_scope = state.outermost_scope - child_depth = state.depth - - if isinstance(state.x, ir.BaseIR): - if state.x.new_block(state.ir_child_num): - child_outermost_scope = state.depth - if isinstance(child, ir.BaseIR): - child_depth += 1 - lift_to = self.lifted_in_scope(child, context) - else: - lift_to = -1 - - if (lift_to >= 0 and - context[lift_to] and - context[lift_to].depth >= state.outermost_scope): - state.ir_child_num += 1 - (name, _) = context[lift_to].lifted_lets[id(child)] - - if id(child) in context[lift_to].visited: - state.local_builder.append(f'(Ref {name})') - state.i += 1 - continue - - context[lift_to].visited[id(child)] = child - - child_local_builder = [f'(Let {name} '] - - if id(child) in self.memo: - child_children = [] - stack.append(PostChildrenLifted(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, lift_to, name)) - continue - - assert(not (id(child) in self.binding_sites and len(self.binding_sites[id(child)].lifted_lets) > 0)) - head = child.render_head(self) - if head != '': - child_local_builder.append(head) - - stack.append(PostChildrenLifted(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, lift_to, name)) - continue - - if isinstance(child, ir.BaseIR): - if id(child) in self.memo: - child_local_builder.extend(self.memo[id(child)]) - state.i += 1 - continue - - insert_lets = id(child) in self.binding_sites and len(self.binding_sites[id(child)].lifted_lets) > 0 - if insert_lets: - child_local_builder = [] - context.append(BindingsStackFrame(self.binding_sites[id(child)])) - head = child.render_head(self) - if head != '': - child_local_builder.append(head) - - stack.append(PostChildren(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, insert_lets)) - continue - else: - head = child.render_head(self) - if head != '': - child_local_builder.append(head) - - new_state = PostChildren(child, child_children, child_local_builder, child_outermost_scope, child_depth, state.local_builder, False) - new_state.ir_child_num = state.ir_child_num - - stack.append(new_state) - continue - - @staticmethod - def add_lets(context, local_builder, builder): - sf = context[-1] - context.pop() - for let_body in sf.let_bodies: - builder.extend(let_body) - builder.extend(local_builder) - num_lets = len(sf.let_bodies) - for _ in range(num_lets): - builder.append(')') - + # At top of main loop, we are considering the node 'node' and its + # 'child_idx'th child, or if 'child_idx' = 'len(node.children)', we are + # about to do post-processing on 'node' before moving back up to its parent. + # + # 'stack' is a stack of `AnalysisStackFrame`s, one for each node on the path + # from 'root' to 'node. Each stack frame tracks the following immutable + # information: + # * 'node': The node corresponding to this stack frame. + # * 'min_binding_depth': If 'node' is to be bound in a let, the let may not + # rise higher than this (its depth must be >= 'min_binding_depth') + # * 'context': The binding context of 'node'. Maps variables bound above + # to the depth at which they were bound (more precisely, if + # 'context[var] == depth', then 'stack[depth-1].node' binds 'var' in the + # subtree rooted at 'stack[depth].node'). + # * 'new_bindings': The variables which were bound by 'node's parent. These + # must be subtracted out of 'node's free variables when updating the + # parent's free variables. + # + # Each stack frame also holds the following mutable state: + # * 'free_vars': The running union of free variables in the subtree rooted + # at 'node'. + # * 'visited': A set of visited descendants. For each descendant 'x' of + # 'node', 'id(x)' is added to 'visited'. This allows us to recognize when + # we see a node for a second time. + # * 'lifted_lets': For each descendant 'x' of 'node', if 'x' is to be bound + # in a let immediately above 'node', then 'lifted_lets' contains 'id(x)', + # along with the unique id to bind 'x' to. + # * 'child_idx': the child currently being visited (satisfies the invariant + # 'stack[i].node.children[stack[i].child_idx] is stack[i+1].node`). def compute_new_bindings(self, root: 'ir.BaseIR'): root_frame = AnalysisStackFrame(0, ({}, {}, {}), root) stack = [root_frame] @@ -470,8 +330,12 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): self.memo[id(child)] = jref continue - seen_in_scope = self.find_in_scope(child, stack, - frame.min_binding_depth) + for i in reversed(range(len(stack))): + if id(child) in stack[i].visited: + seen_in_scope = i + break + else: + seen_in_scope = -1 if seen_in_scope >= 0 and isinstance(child, ir.IR): # we've seen 'child' before, should not traverse (or we will @@ -495,10 +359,7 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): if child.name not in node.bindings(child_idx, default_value=0).keys(): (eval_c, _, _) = node.child_context_without_bindings( child_idx, frame.context) - try: - frame.free_vars.update({child.name: eval_c[child.name]}) - except: - raise + frame.free_vars.update({child.name: eval_c[child.name]}) continue stack.append(frame.make_child_frame(len(stack))) @@ -512,6 +373,114 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): binding_sites[id(root)] = BindingSite(frame.lifted_lets, 0, root) return binding_sites + def make_post_children(self, node, builder, context, outermost_scope, depth, local_builder=None): + if local_builder is None: + local_builder = builder + insert_lets = id(node) in self.binding_sites and len(self.binding_sites[id(node)].lifted_lets) > 0 + state = PrintStackFrame(node, node.render_children(self), local_builder, outermost_scope, depth, builder, insert_lets) + if not isinstance(node, ir.BaseIR): + state.ir_child_num = state.ir_child_num + if insert_lets: + state.local_builder = [] + context.append(BindingsStackFrame(self.binding_sites[id(node)])) + head = node.render_head(self) + if head != '': + state.local_builder.append(head) + return state + + def build_string(self, root): + root_builder = [] + context = [] + + if id(root) in self.memo: + return ''.join(self.memo[id(root)]) + + root_state = self.make_post_children(root, root_builder, context, 0, 1) + + stack = [root_state] + + while True: + state = stack[-1] + + if state.i >= len(state.children): + if state.lift_to_frame is not None: + if id(state.x) in self.memo: + state.local_builder.extend(self.memo[id(state.x)]) + else: + state.local_builder.append(state.x.render_tail(self)) + state.local_builder.append(' ') + # let_bodies is built post-order, which guarantees earlier + # lets can't refer to later lets + state.lift_to_frame.let_bodies.append(state.local_builder) + (name, _) = state.lift_to_frame.lifted_lets[id(state.x)] + # FIXME: can this be added on first visit? + state.builder.append(f'(Ref {name})') + stack.pop() + stack[-1].i += 1 + continue + else: + state.local_builder.append(state.x.render_tail(self)) + if state.insert_lets: + state.add_lets(context[-1].let_bodies) + context.pop() + stack.pop() + if not stack: + return ''.join(state.builder) + state = stack[-1] + if not isinstance(state.x, ir.BaseIR): + state.ir_child_num += 1 + state.i += 1 + continue + + state.local_builder.append(' ') + child = state.children[state.i] + + child_outermost_scope = state.outermost_scope + child_depth = state.depth + + if isinstance(state.x, ir.BaseIR) and state.x.new_block(state.ir_child_num): + child_outermost_scope = state.depth + lift_to_frame = None + if isinstance(child, ir.BaseIR): + child_depth += 1 + lift_to_frame = next((frame for frame in context if id(child) in frame.lifted_lets), None) + + if lift_to_frame and lift_to_frame.depth >= state.outermost_scope: + insert_lets = id(child) in self.binding_sites and len(self.binding_sites[id(child)].lifted_lets) > 0 + assert not insert_lets + state.ir_child_num += 1 + (name, _) = lift_to_frame.lifted_lets[id(child)] + + if id(child) in lift_to_frame.visited: + state.local_builder.append(f'(Ref {name})') + state.i += 1 + continue + + lift_to_frame.visited[id(child)] = child + + child_builder = [f'(Let {name} '] + if id(child) in self.memo: + new_state = PrintStackFrame(child, [], child_builder, child_outermost_scope, child_depth, state.local_builder, insert_lets, lift_to_frame) + else: + # children = child.render_children(self) + # head = child.render_head(self) + # if head != '': + # child_builder.append(head) + # new_state = PostChildrenLifted(child, children, child_builder, child_outermost_scope, child_depth, state.local_builder, insert_lets, lift_to_frame) + new_state = self.make_post_children(child, state.local_builder, context, child_outermost_scope, child_depth, child_builder) + new_state.lift_to_frame = lift_to_frame + stack.append(new_state) + continue + + if id(child) in self.memo: + state.local_builder.extend(self.memo[id(child)]) + state.i += 1 + continue + + new_state = self.make_post_children(child, state.local_builder, context, child_outermost_scope, child_depth) + stack.append(new_state) + continue + def __call__(self, root: 'ir.BaseIR') -> str: self.binding_sites = self.compute_new_bindings(root) From 636aa5e647f7a639652bea2af9c40a3e639ef9da Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Mon, 26 Aug 2019 10:51:52 -0400 Subject: [PATCH 28/47] refactoring, documenting --- hail/python/hail/ir/renderer.py | 214 ++++++++++++++++++++------------ 1 file changed, 132 insertions(+), 82 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 59e8e970931..4cd1996e471 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -1,6 +1,6 @@ from hail import ir import abc -from typing import Sequence, List, Set, Dict, Callable, Tuple +from typing import Sequence, List, Set, Dict, Callable, Tuple, Optional class Renderable(object): @@ -141,24 +141,22 @@ def __call__(self, x: 'Renderable'): class AnalysisStackFrame: - def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR', - new_bindings=(None, None, None)): + def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR'): # immutable self.min_binding_depth = min_binding_depth self.context = context self.node = x - (self._parent_eval_bindings, self._parent_agg_bindings, - self._parent_scan_bindings) = new_bindings # mutable - self.free_vars = {} + self._free_vars = {} self.visited: Dict[int, 'ir.BaseIR'] = {} self.lifted_lets: Dict[int, (str, 'ir.BaseIR')] = {} self.child_idx = 0 + self._child_bindings = None # compute depth at which we might bind this node def bind_depth(self) -> int: - if len(self.free_vars) > 0: - bind_depth = max(self.free_vars.values()) + if len(self._free_vars) > 0: + bind_depth = max(self._free_vars.values()) bind_depth = max(bind_depth, self.min_binding_depth) else: bind_depth = self.min_binding_depth @@ -178,22 +176,36 @@ def make_child_frame(self, depth: int) -> 'AnalysisStackFrame': agg_bindings = x.agg_bindings(i, 0).keys() scan_bindings = x.scan_bindings(i, 0).keys() new_bindings = (eval_bindings, agg_bindings, scan_bindings) + self._child_bindings = new_bindings child_context = x.child_context(i, self.context, depth) - return AnalysisStackFrame(child_outermost_scope, child_context, child, - new_bindings) + return AnalysisStackFrame(child_outermost_scope, child_context, child,) - def update_parent_free_vars(self, parent_free_vars: Set[str]): + def free_vars(self): + # subtract vars that will be bound by inserted lets + for (var, _) in self.lifted_lets.values(): + self._free_vars.pop(var, 0) + return self._free_vars + + def update_free_vars(self, child_frame: 'AnalysisStackFrame'): + child_free_vars = child_frame.free_vars() # subtract vars bound by parent from free_vars - for var in [*self._parent_eval_bindings, - *self._parent_agg_bindings, - *self._parent_scan_bindings]: - self.free_vars.pop(var, 0) + (eval_bindings, agg_bindings, scan_bindings) = self._child_bindings + for var in [*eval_bindings, *agg_bindings, *scan_bindings]: + child_free_vars.pop(var, 0) + # update parent's free variables + self._free_vars.update(child_free_vars) + + def update_parent_free_vars(self, parent_frame: 'AnalysisStackFrame'): + # subtract vars bound by parent from free_vars + (eval_bindings, agg_bindings, scan_bindings) = parent_frame._child_bindings + for var in [*eval_bindings, *agg_bindings, *scan_bindings]: + self._free_vars.pop(var, 0) # subtract vars that will be bound by inserted lets for (var, _) in self.lifted_lets.values(): - self.free_vars.pop(var, 0) + self._free_vars.pop(var, 0) # update parent's free variables - parent_free_vars.update(self.free_vars) + parent_frame._free_vars.update(self._free_vars) class BindingSite: @@ -212,17 +224,21 @@ def __init__(self, binding_site: BindingSite): class PrintStackFrame: - def __init__(self, x, children, local_builder, outermost_scope, depth, builder, insert_lets, lift_to_frame=None): - self.x = x - self.children = children + def __init__(self, node, children, local_builder, outermost_scope, depth, builder, insert_lets, lift_to_frame=None): + # immutable + self.node: Renderable = node + self.children: Sequence[Renderable] = children + self.min_binding_depth: int = outermost_scope + self.depth: int = depth + self.lift_to_frame: Optional[int] = lift_to_frame + self.insert_lets: bool = insert_lets + # mutable + # FIXME: Builder is always parent's local_builder. Shouldn't need to + # store builder. Maybe also make these plain strings. self.builder = builder self.local_builder = local_builder - self.outermost_scope = outermost_scope + self.child_idx = 0 self.ir_child_num = 0 - self.depth = depth - self.insert_lets = insert_lets - self.lift_to_frame = lift_to_frame - self.i = 0 def add_lets(self, let_bodies): for let_body in let_bodies: @@ -232,6 +248,22 @@ def add_lets(self, let_bodies): for _ in range(num_lets): self.builder.append(')') + def make_child_frame(self, renderer, builder, context, outermost_scope, depth, local_builder=None): + child = self.children[self.i] + if local_builder is None: + local_builder = builder + insert_lets = id(child) in renderer.binding_sites and len(renderer.binding_sites[id(child)].lifted_lets) > 0 + state = PrintStackFrame(child, child.render_children(renderer), local_builder, outermost_scope, depth, builder, insert_lets) + if not isinstance(child, ir.BaseIR): + state.ir_child_num = state.ir_child_num + if insert_lets: + state.local_builder = [] + context.append(BindingsStackFrame(renderer.binding_sites[id(child)])) + head = child.render_head(renderer) + if head != '': + state.local_builder.append(head) + return state + class CSERenderer(Renderer): def __init__(self, stop_at_jir=False): @@ -308,7 +340,7 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): binding_sites[id(node)] = \ BindingSite(frame.lifted_lets, len(stack), node) - frame.update_parent_free_vars(parent_frame.free_vars) + parent_frame.update_free_vars(frame) stack.pop() continue @@ -330,14 +362,9 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): self.memo[id(child)] = jref continue - for i in reversed(range(len(stack))): - if id(child) in stack[i].visited: - seen_in_scope = i - break - else: - seen_in_scope = -1 + seen_in_scope = next((i for i in reversed(range(len(stack))) if id(child) in stack[i].visited), None) - if seen_in_scope >= 0 and isinstance(child, ir.IR): + if seen_in_scope is not None and isinstance(child, ir.IR): # we've seen 'child' before, should not traverse (or we will # find too many lifts) if id(child) in stack[seen_in_scope].lifted_lets: @@ -350,7 +377,7 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): # variables. To prevent a parent from being lifted too high, # we must register 'child' as having the free variable 'uid', # which will be true when 'child' is replaced by "Ref uid". - frame.free_vars[uid] = seen_in_scope + frame._free_vars[uid] = seen_in_scope continue # first time visiting 'child' @@ -359,16 +386,14 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): if child.name not in node.bindings(child_idx, default_value=0).keys(): (eval_c, _, _) = node.child_context_without_bindings( child_idx, frame.context) - frame.free_vars.update({child.name: eval_c[child.name]}) + frame._free_vars[child.name] = eval_c[child.name] continue stack.append(frame.make_child_frame(len(stack))) continue - for (var, _) in root_frame.lifted_lets.values(): - var_depth = root_frame.free_vars.pop(var, 0) - assert var_depth == 0 - assert(len(root_frame.free_vars) == 0) + root_free_vars = root_frame.free_vars() + assert(len(root_free_vars) == 0) binding_sites[id(root)] = BindingSite(frame.lifted_lets, 0, root) return binding_sites @@ -388,6 +413,34 @@ def make_post_children(self, node, builder, context, outermost_scope, depth, loc state.local_builder.append(head) return state + # At top of main loop, we are considering the 'Renderable' 'node' and its + # 'child_idx'th child, or if 'child_idx' = 'len(node.children)', we are + # about to do post-processing on 'node' before moving back up to its parent. + # + # 'stack' is a stack of `PrintStackFrame`s, one for each node on the path + # from 'root' to 'node. Each stack frame tracks the following immutable + # information: + # * 'node': The 'Renderable' node corresponding to this stack frame. + # * 'children': The list of 'Renderable' children. + # * 'min_binding_depth': If 'node' is to be bound in a let, the let may not + # rise higher than this (its depth must be >= 'min_binding_depth') + # * 'depth': The depth of 'node' in the original tree, i.e. the number of + # BaseIR in 'stack', not counting other 'Renderable's. + # * 'lift_to_frame': The outermost frame in which 'node' was marked to be + # lifted in the analysis pass, if any, otherwise None. + # * 'insert_lets': True if any lets need to be inserted above 'node'. No + # node has both 'lift_to_frame' not None and 'insert_lets' True. + # + # Each stack frame also holds the following mutable state: + # * 'builder' and 'local_builder': If 'insert_lets', then 'builder' is the + # buffer building the parent's rendered IR, and 'local_builder' is the + # buffer building 'node's rendered IR. After traversing the subtree rooted + # at 'node', all lets will be added to 'builder' before copying + # 'local_builder' to 'builder'. + # If + # * 'child_idx': + # * 'ir_child_num': + def build_string(self, root): root_builder = [] context = [] @@ -395,89 +448,86 @@ def build_string(self, root): if id(root) in self.memo: return ''.join(self.memo[id(root)]) - root_state = self.make_post_children(root, root_builder, context, 0, 1) - - stack = [root_state] + stack = [self.make_post_children(root, root_builder, context, 0, 1)] while True: - state = stack[-1] - - if state.i >= len(state.children): - if state.lift_to_frame is not None: - if id(state.x) in self.memo: - state.local_builder.extend(self.memo[id(state.x)]) + frame = stack[-1] + node = frame.node + # child_idx = frame.child_idx + # frame.child_idx += 1 + + if frame.child_idx >= len(frame.children): + if frame.lift_to_frame is not None: + assert(not frame.insert_lets) + if id(node) in self.memo: + frame.local_builder.extend(self.memo[id(node)]) else: - state.local_builder.append(state.x.render_tail(self)) - state.local_builder.append(' ') + frame.local_builder.append(node.render_tail(self)) + frame.local_builder.append(' ') # let_bodies is built post-order, which guarantees earlier # lets can't refer to later lets - state.lift_to_frame.let_bodies.append(state.local_builder) - (name, _) = state.lift_to_frame.lifted_lets[id(state.x)] + frame.lift_to_frame.let_bodies.append(frame.local_builder) + (name, _) = frame.lift_to_frame.lifted_lets[id(node)] # FIXME: can this be added on first visit? - state.builder.append(f'(Ref {name})') + frame.builder.append(f'(Ref {name})') stack.pop() - stack[-1].i += 1 + stack[-1].child_idx += 1 continue else: - state.local_builder.append(state.x.render_tail(self)) - if state.insert_lets: - state.add_lets(context[-1].let_bodies) + frame.local_builder.append(node.render_tail(self)) + if frame.insert_lets: + frame.add_lets(context[-1].let_bodies) context.pop() stack.pop() if not stack: - return ''.join(state.builder) - state = stack[-1] - if not isinstance(state.x, ir.BaseIR): - state.ir_child_num += 1 - state.i += 1 + return ''.join(frame.builder) + frame = stack[-1] + if not isinstance(node, ir.BaseIR): + frame.ir_child_num += 1 + frame.child_idx += 1 continue - state.local_builder.append(' ') - child = state.children[state.i] + frame.local_builder.append(' ') + child = frame.children[frame.child_idx] - child_outermost_scope = state.outermost_scope - child_depth = state.depth + child_outermost_scope = frame.min_binding_depth + child_depth = frame.depth - if isinstance(state.x, ir.BaseIR) and state.x.new_block(state.ir_child_num): - child_outermost_scope = state.depth + if isinstance(frame.node, ir.BaseIR) and frame.node.new_block(frame.ir_child_num): + child_outermost_scope = frame.depth lift_to_frame = None if isinstance(child, ir.BaseIR): child_depth += 1 lift_to_frame = next((frame for frame in context if id(child) in frame.lifted_lets), None) - if lift_to_frame and lift_to_frame.depth >= state.outermost_scope: + if lift_to_frame and lift_to_frame.depth >= frame.min_binding_depth: insert_lets = id(child) in self.binding_sites and len(self.binding_sites[id(child)].lifted_lets) > 0 assert not insert_lets - state.ir_child_num += 1 + frame.ir_child_num += 1 (name, _) = lift_to_frame.lifted_lets[id(child)] if id(child) in lift_to_frame.visited: - state.local_builder.append(f'(Ref {name})') - state.i += 1 + frame.local_builder.append(f'(Ref {name})') + frame.child_idx += 1 continue lift_to_frame.visited[id(child)] = child child_builder = [f'(Let {name} '] if id(child) in self.memo: - new_state = PrintStackFrame(child, [], child_builder, child_outermost_scope, child_depth, state.local_builder, insert_lets, lift_to_frame) + new_state = PrintStackFrame(child, [], child_builder, child_outermost_scope, child_depth, frame.local_builder, insert_lets, lift_to_frame) else: - # children = child.render_children(self) - # head = child.render_head(self) - # if head != '': - # child_builder.append(head) - # new_state = PostChildrenLifted(child, children, child_builder, child_outermost_scope, child_depth, state.local_builder, insert_lets, lift_to_frame) - new_state = self.make_post_children(child, state.local_builder, context, child_outermost_scope, child_depth, child_builder) + new_state = self.make_post_children(child, frame.local_builder, context, child_outermost_scope, child_depth, child_builder) new_state.lift_to_frame = lift_to_frame stack.append(new_state) continue if id(child) in self.memo: - state.local_builder.extend(self.memo[id(child)]) - state.i += 1 + frame.local_builder.extend(self.memo[id(child)]) + frame.child_idx += 1 continue - new_state = self.make_post_children(child, state.local_builder, context, child_outermost_scope, child_depth) + new_state = self.make_post_children(child, frame.local_builder, context, child_outermost_scope, child_depth) stack.append(new_state) continue From f64ac3ef9c2d3be77496bd2b4f96b9d4d54235ca Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Mon, 26 Aug 2019 14:08:21 -0400 Subject: [PATCH 29/47] refactoring --- hail/python/hail/ir/base_ir.py | 12 +++++++ hail/python/hail/ir/ir.py | 6 ++++ hail/python/hail/ir/renderer.py | 57 +++++++++++++++------------------ 3 files changed, 43 insertions(+), 32 deletions(-) diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index 1e852174975..e50d75f9f91 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -89,6 +89,9 @@ def __hash__(self): def new_block(self, i: int) -> bool: ... + def renderable_new_block(self, i: int) -> bool: + return self.new_block(i) + @staticmethod def is_effectful() -> bool: return False @@ -116,9 +119,15 @@ def scan_bindings(self, i: int, default_value=None): def uses_agg_context(self, i: int) -> bool: return False + def renderable_uses_agg_context(self, i: int) -> bool: + return self.uses_agg_context(i) + def uses_scan_context(self, i: int) -> bool: return False + def renderable_uses_scan_context(self, i: int) -> bool: + return self.uses_scan_context(i) + def child_context_without_bindings(self, i: int, parent_context): (eval_c, agg_c, scan_c) = parent_context if self.uses_agg_context(i): @@ -189,6 +198,9 @@ def typ(self): def new_block(self, i: int) -> bool: return self.uses_agg_context(i) or self.uses_scan_context(i) + def renderable_new_block(self, i: int) -> bool: + return self.renderable_uses_agg_context(i) or self.renderable_uses_scan_context(i) + @abc.abstractmethod def _compute_type(self, env, agg_env): raise NotImplementedError(self) diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 434b877a209..eadc0e44c70 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -1401,6 +1401,9 @@ def __init__(self, agg_op, constructor_args, init_op_args, seq_op_args): def uses_agg_context(self, i: int): return i >= len(self.constructor_args) + (len(self.init_op_args) if self.init_op_args else 0) + def renderable_uses_agg_context(self, i: int): + return i == 2 + class ApplyScanOp(BaseApplyAggOp): @typecheck_method(agg_op=str, @@ -1413,6 +1416,9 @@ def __init__(self, agg_op, constructor_args, init_op_args, seq_op_args): def uses_scan_context(self, i: int): return i >= len(self.constructor_args) + (len(self.init_op_args) if self.init_op_args else 0) + def renderable_uses_agg_context(self, i: int): + return i == 2 + class Begin(IR): @typecheck_method(xs=sequenceof(IR)) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 4cd1996e471..729efcb97ca 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -238,7 +238,6 @@ def __init__(self, node, children, local_builder, outermost_scope, depth, builde self.builder = builder self.local_builder = local_builder self.child_idx = 0 - self.ir_child_num = 0 def add_lets(self, let_bodies): for let_body in let_bodies: @@ -249,17 +248,19 @@ def add_lets(self, let_bodies): self.builder.append(')') def make_child_frame(self, renderer, builder, context, outermost_scope, depth, local_builder=None): - child = self.children[self.i] + child = self.children[self.child_idx] + return self.make(child, renderer, builder, context, outermost_scope, depth, local_builder) + + @staticmethod + def make(node, renderer, builder, context, outermost_scope, depth, local_builder=None): if local_builder is None: local_builder = builder - insert_lets = id(child) in renderer.binding_sites and len(renderer.binding_sites[id(child)].lifted_lets) > 0 - state = PrintStackFrame(child, child.render_children(renderer), local_builder, outermost_scope, depth, builder, insert_lets) - if not isinstance(child, ir.BaseIR): - state.ir_child_num = state.ir_child_num + insert_lets = id(node) in renderer.binding_sites and len(renderer.binding_sites[id(node)].lifted_lets) > 0 + state = PrintStackFrame(node, node.render_children(renderer), local_builder, outermost_scope, depth, builder, insert_lets) if insert_lets: state.local_builder = [] - context.append(BindingsStackFrame(renderer.binding_sites[id(child)])) - head = child.render_head(renderer) + context.append(BindingsStackFrame(renderer.binding_sites[id(node)])) + head = node.render_head(renderer) if head != '': state.local_builder.append(head) return state @@ -325,11 +326,6 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): frame.child_idx += 1 if child_idx >= len(node.children): - if len(stack) <= 1: - break - - parent_frame = stack[-2] - # mark node as visited at potential let insertion site if not node.is_effectful(): stack[frame.bind_depth()].visited[id(node)] = node @@ -340,9 +336,10 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): binding_sites[id(node)] = \ BindingSite(frame.lifted_lets, len(stack), node) - parent_frame.update_free_vars(frame) - stack.pop() + if not stack: + break + stack[-1].update_free_vars(frame) continue child = node.children[child_idx] @@ -403,8 +400,6 @@ def make_post_children(self, node, builder, context, outermost_scope, depth, loc local_builder = builder insert_lets = id(node) in self.binding_sites and len(self.binding_sites[id(node)].lifted_lets) > 0 state = PrintStackFrame(node, node.render_children(self), local_builder, outermost_scope, depth, builder, insert_lets) - if not isinstance(node, ir.BaseIR): - state.ir_child_num = state.ir_child_num if insert_lets: state.local_builder = [] context.append(BindingsStackFrame(self.binding_sites[id(node)])) @@ -432,14 +427,15 @@ def make_post_children(self, node, builder, context, outermost_scope, depth, loc # node has both 'lift_to_frame' not None and 'insert_lets' True. # # Each stack frame also holds the following mutable state: - # * 'builder' and 'local_builder': If 'insert_lets', then 'builder' is the - # buffer building the parent's rendered IR, and 'local_builder' is the - # buffer building 'node's rendered IR. After traversing the subtree rooted - # at 'node', all lets will be added to 'builder' before copying - # 'local_builder' to 'builder'. - # If - # * 'child_idx': - # * 'ir_child_num': + # * 'child_idx': The index of the 'Renderable' child currently being + # visited. + # * 'builder': The buffer building the parent's rendered IR. + # * 'local_builder': The builder building 'node's IR. + # If 'insert_lets', all lets will be added to 'builder' before copying + # 'local_builder' to 'builder. If 'lift_to_frame', 'local_builder' will be + # added to 'lift_to_frame's list of lifted lets, while only "(Ref ...)" will + # be added to 'builder'. If neither, then it is safe for 'local_builder' to + # be 'builder', to save copying. def build_string(self, root): root_builder = [] @@ -448,7 +444,7 @@ def build_string(self, root): if id(root) in self.memo: return ''.join(self.memo[id(root)]) - stack = [self.make_post_children(root, root_builder, context, 0, 1)] + stack = [PrintStackFrame.make(root, self, root_builder, context, 0, 1)] while True: frame = stack[-1] @@ -482,8 +478,6 @@ def build_string(self, root): if not stack: return ''.join(frame.builder) frame = stack[-1] - if not isinstance(node, ir.BaseIR): - frame.ir_child_num += 1 frame.child_idx += 1 continue @@ -493,7 +487,7 @@ def build_string(self, root): child_outermost_scope = frame.min_binding_depth child_depth = frame.depth - if isinstance(frame.node, ir.BaseIR) and frame.node.new_block(frame.ir_child_num): + if isinstance(frame.node, ir.BaseIR) and frame.node.renderable_new_block(frame.child_idx): child_outermost_scope = frame.depth lift_to_frame = None if isinstance(child, ir.BaseIR): @@ -503,7 +497,6 @@ def build_string(self, root): if lift_to_frame and lift_to_frame.depth >= frame.min_binding_depth: insert_lets = id(child) in self.binding_sites and len(self.binding_sites[id(child)].lifted_lets) > 0 assert not insert_lets - frame.ir_child_num += 1 (name, _) = lift_to_frame.lifted_lets[id(child)] if id(child) in lift_to_frame.visited: @@ -517,7 +510,7 @@ def build_string(self, root): if id(child) in self.memo: new_state = PrintStackFrame(child, [], child_builder, child_outermost_scope, child_depth, frame.local_builder, insert_lets, lift_to_frame) else: - new_state = self.make_post_children(child, frame.local_builder, context, child_outermost_scope, child_depth, child_builder) + new_state = frame.make_child_frame(self, frame.local_builder, context, child_outermost_scope, child_depth, child_builder) new_state.lift_to_frame = lift_to_frame stack.append(new_state) continue @@ -527,7 +520,7 @@ def build_string(self, root): frame.child_idx += 1 continue - new_state = self.make_post_children(child, frame.local_builder, context, child_outermost_scope, child_depth) + new_state = frame.make_child_frame(self, frame.local_builder, context, child_outermost_scope, child_depth) stack.append(new_state) continue From 3842f38effa89f5386587f5b11dc836eefb098e2 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 28 Aug 2019 11:53:02 -0400 Subject: [PATCH 30/47] refactoring --- hail/python/hail/ir/renderer.py | 81 ++++++++++++++------------------- 1 file changed, 34 insertions(+), 47 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 729efcb97ca..95a6ecc333c 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -140,6 +140,7 @@ def __call__(self, x: 'Renderable'): Context = (Vars, Vars, Vars) +# FIXME: replace these by NamedTuples, or use slots class AnalysisStackFrame: def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR'): # immutable @@ -209,10 +210,9 @@ def update_parent_free_vars(self, parent_frame: 'AnalysisStackFrame'): class BindingSite: - def __init__(self, lifted_lets: Dict[int, Tuple[str, 'ir.BaseIR']], depth: int, node: 'ir.BaseIR'): + def __init__(self, lifted_lets: Dict[int, str], depth: int): self.depth = depth self.lifted_lets = lifted_lets - self.node = node class BindingsStackFrame: @@ -237,7 +237,7 @@ def __init__(self, node, children, local_builder, outermost_scope, depth, builde # store builder. Maybe also make these plain strings. self.builder = builder self.local_builder = local_builder - self.child_idx = 0 + self.child_idx = -1 def add_lets(self, let_bodies): for let_body in let_bodies: @@ -247,19 +247,19 @@ def add_lets(self, let_bodies): for _ in range(num_lets): self.builder.append(')') - def make_child_frame(self, renderer, builder, context, outermost_scope, depth, local_builder=None): + def make_child_frame(self, renderer, binding_sites, builder, context, outermost_scope, depth, local_builder=None): child = self.children[self.child_idx] - return self.make(child, renderer, builder, context, outermost_scope, depth, local_builder) + return self.make(child, renderer, binding_sites, builder, context, outermost_scope, depth, local_builder) @staticmethod - def make(node, renderer, builder, context, outermost_scope, depth, local_builder=None): + def make(node, renderer, binding_sites, builder, context, outermost_scope, depth, local_builder=None): if local_builder is None: local_builder = builder - insert_lets = id(node) in renderer.binding_sites and len(renderer.binding_sites[id(node)].lifted_lets) > 0 + insert_lets = id(node) in binding_sites and len(binding_sites[id(node)].lifted_lets) > 0 state = PrintStackFrame(node, node.render_children(renderer), local_builder, outermost_scope, depth, builder, insert_lets) if insert_lets: state.local_builder = [] - context.append(BindingsStackFrame(renderer.binding_sites[id(node)])) + context.append(BindingsStackFrame(binding_sites[id(node)])) head = node.render_head(renderer) if head != '': state.local_builder.append(head) @@ -273,7 +273,6 @@ def __init__(self, stop_at_jir=False): self.jirs = {} self.memo: Dict[int, Sequence[str]] = {} self.uid_count = 0 - self.binding_sites: Dict[int, BindingSite] = {} def uid(self) -> str: self.uid_count += 1 @@ -289,7 +288,7 @@ def add_jir(self, jir): # 'child_idx'th child, or if 'child_idx' = 'len(node.children)', we are # about to do post-processing on 'node' before moving back up to its parent. # - # 'stack' is a stack of `AnalysisStackFrame`s, one for each node on the path + # 'stack' is a stack of 'AnalysisStackFrame's, one for each node on the path # from 'root' to 'node. Each stack frame tracks the following immutable # information: # * 'node': The node corresponding to this stack frame. @@ -313,8 +312,14 @@ def add_jir(self, jir): # in a let immediately above 'node', then 'lifted_lets' contains 'id(x)', # along with the unique id to bind 'x' to. # * 'child_idx': the child currently being visited (satisfies the invariant - # 'stack[i].node.children[stack[i].child_idx] is stack[i+1].node`). - def compute_new_bindings(self, root: 'ir.BaseIR'): + # 'stack[i].node.children[stack[i].child_idx] is stack[i+1].node'). + # + # Returns a Dict summarizing of all lets to be inserted. For each 'node' + # which will have lets inserted immediately above it, maps 'id(node)' to a + # 'BindingSite' recording the depth of 'node' and a Dict 'lifted_lets', + # where for each descendant 'x' which will be bound above 'node', + # 'lifted_lets' maps 'id(x)' to the unique id 'x' will be bound to. + def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: root_frame = AnalysisStackFrame(0, ({}, {}, {}), root) stack = [root_frame] binding_sites = {} @@ -334,7 +339,7 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): # binding sites if frame.lifted_lets: binding_sites[id(node)] = \ - BindingSite(frame.lifted_lets, len(stack), node) + BindingSite(frame.lifted_lets, len(stack)) stack.pop() if not stack: @@ -392,27 +397,14 @@ def compute_new_bindings(self, root: 'ir.BaseIR'): root_free_vars = root_frame.free_vars() assert(len(root_free_vars) == 0) - binding_sites[id(root)] = BindingSite(frame.lifted_lets, 0, root) + binding_sites[id(root)] = BindingSite(frame.lifted_lets, 0) return binding_sites - def make_post_children(self, node, builder, context, outermost_scope, depth, local_builder=None): - if local_builder is None: - local_builder = builder - insert_lets = id(node) in self.binding_sites and len(self.binding_sites[id(node)].lifted_lets) > 0 - state = PrintStackFrame(node, node.render_children(self), local_builder, outermost_scope, depth, builder, insert_lets) - if insert_lets: - state.local_builder = [] - context.append(BindingsStackFrame(self.binding_sites[id(node)])) - head = node.render_head(self) - if head != '': - state.local_builder.append(head) - return state - # At top of main loop, we are considering the 'Renderable' 'node' and its # 'child_idx'th child, or if 'child_idx' = 'len(node.children)', we are # about to do post-processing on 'node' before moving back up to its parent. # - # 'stack' is a stack of `PrintStackFrame`s, one for each node on the path + # 'stack' is a stack of 'PrintStackFrame's, one for each node on the path # from 'root' to 'node. Each stack frame tracks the following immutable # information: # * 'node': The 'Renderable' node corresponding to this stack frame. @@ -437,22 +429,22 @@ def make_post_children(self, node, builder, context, outermost_scope, depth, loc # be added to 'builder'. If neither, then it is safe for 'local_builder' to # be 'builder', to save copying. - def build_string(self, root): + def build_string(self, root, binding_sites): root_builder = [] - context = [] + context: List[BindingsStackFrame] = [] if id(root) in self.memo: return ''.join(self.memo[id(root)]) - stack = [PrintStackFrame.make(root, self, root_builder, context, 0, 1)] + stack = [PrintStackFrame.make(root, self, binding_sites, root_builder, context, 0, 1)] while True: frame = stack[-1] node = frame.node - # child_idx = frame.child_idx - # frame.child_idx += 1 + frame.child_idx += 1 + child_idx = frame.child_idx - if frame.child_idx >= len(frame.children): + if child_idx >= len(frame.children): if frame.lift_to_frame is not None: assert(not frame.insert_lets) if id(node) in self.memo: @@ -467,7 +459,6 @@ def build_string(self, root): # FIXME: can this be added on first visit? frame.builder.append(f'(Ref {name})') stack.pop() - stack[-1].child_idx += 1 continue else: frame.local_builder.append(node.render_tail(self)) @@ -477,31 +468,29 @@ def build_string(self, root): stack.pop() if not stack: return ''.join(frame.builder) - frame = stack[-1] - frame.child_idx += 1 continue frame.local_builder.append(' ') - child = frame.children[frame.child_idx] + child = frame.children[child_idx] child_outermost_scope = frame.min_binding_depth child_depth = frame.depth if isinstance(frame.node, ir.BaseIR) and frame.node.renderable_new_block(frame.child_idx): child_outermost_scope = frame.depth - lift_to_frame = None if isinstance(child, ir.BaseIR): child_depth += 1 lift_to_frame = next((frame for frame in context if id(child) in frame.lifted_lets), None) + else: + lift_to_frame = None if lift_to_frame and lift_to_frame.depth >= frame.min_binding_depth: - insert_lets = id(child) in self.binding_sites and len(self.binding_sites[id(child)].lifted_lets) > 0 + insert_lets = id(child) in binding_sites and len(binding_sites[id(child)].lifted_lets) > 0 assert not insert_lets (name, _) = lift_to_frame.lifted_lets[id(child)] if id(child) in lift_to_frame.visited: frame.local_builder.append(f'(Ref {name})') - frame.child_idx += 1 continue lift_to_frame.visited[id(child)] = child @@ -510,21 +499,19 @@ def build_string(self, root): if id(child) in self.memo: new_state = PrintStackFrame(child, [], child_builder, child_outermost_scope, child_depth, frame.local_builder, insert_lets, lift_to_frame) else: - new_state = frame.make_child_frame(self, frame.local_builder, context, child_outermost_scope, child_depth, child_builder) + new_state = frame.make_child_frame(self, binding_sites, frame.local_builder, context, child_outermost_scope, child_depth, child_builder) new_state.lift_to_frame = lift_to_frame stack.append(new_state) continue if id(child) in self.memo: frame.local_builder.extend(self.memo[id(child)]) - frame.child_idx += 1 continue - new_state = frame.make_child_frame(self, frame.local_builder, context, child_outermost_scope, child_depth) + new_state = frame.make_child_frame(self, binding_sites, frame.local_builder, context, child_outermost_scope, child_depth) stack.append(new_state) continue def __call__(self, root: 'ir.BaseIR') -> str: - self.binding_sites = self.compute_new_bindings(root) - - return self.build_string(root) + binding_sites = self.compute_new_bindings(root) + return self.build_string(root, binding_sites) From 92c1f6b9343a9a8034ef0b5e504383de1827b532 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 28 Aug 2019 13:34:20 -0400 Subject: [PATCH 31/47] refactoring --- hail/python/hail/ir/renderer.py | 75 +++++++++++++++++---------------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 95a6ecc333c..3d570088f71 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -224,7 +224,7 @@ def __init__(self, binding_site: BindingSite): class PrintStackFrame: - def __init__(self, node, children, local_builder, outermost_scope, depth, builder, insert_lets, lift_to_frame=None): + def __init__(self, node, children, builder, outermost_scope, depth, insert_lets, lift_to_frame=None): # immutable self.node: Renderable = node self.children: Sequence[Renderable] = children @@ -233,36 +233,31 @@ def __init__(self, node, children, local_builder, outermost_scope, depth, builde self.lift_to_frame: Optional[int] = lift_to_frame self.insert_lets: bool = insert_lets # mutable - # FIXME: Builder is always parent's local_builder. Shouldn't need to - # store builder. Maybe also make these plain strings. self.builder = builder - self.local_builder = local_builder self.child_idx = -1 - def add_lets(self, let_bodies): + def add_lets(self, let_bodies, out_builder): for let_body in let_bodies: - self.builder.extend(let_body) - self.builder.extend(self.local_builder) + out_builder.extend(let_body) + out_builder.extend(self.builder) num_lets = len(let_bodies) for _ in range(num_lets): - self.builder.append(')') + out_builder.append(')') - def make_child_frame(self, renderer, binding_sites, builder, context, outermost_scope, depth, local_builder=None): + def make_child_frame(self, renderer, binding_sites, builder, context, outermost_scope, depth): child = self.children[self.child_idx] - return self.make(child, renderer, binding_sites, builder, context, outermost_scope, depth, local_builder) + return self.make(child, renderer, binding_sites, builder, context, outermost_scope, depth) @staticmethod - def make(node, renderer, binding_sites, builder, context, outermost_scope, depth, local_builder=None): - if local_builder is None: - local_builder = builder + def make(node, renderer, binding_sites, builder, context, outermost_scope, depth): insert_lets = id(node) in binding_sites and len(binding_sites[id(node)].lifted_lets) > 0 - state = PrintStackFrame(node, node.render_children(renderer), local_builder, outermost_scope, depth, builder, insert_lets) + state = PrintStackFrame(node, node.render_children(renderer), builder, outermost_scope, depth, insert_lets) if insert_lets: - state.local_builder = [] + state.builder = [] context.append(BindingsStackFrame(binding_sites[id(node)])) head = node.render_head(renderer) if head != '': - state.local_builder.append(head) + state.builder.append(head) return state @@ -378,7 +373,7 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: # Since we are not traversing 'child', we don't know its free # variables. To prevent a parent from being lifted too high, # we must register 'child' as having the free variable 'uid', - # which will be true when 'child' is replaced by "Ref uid". + # which will be true when 'child' is replaced by "(Ref uid)". frame._free_vars[uid] = seen_in_scope continue @@ -448,29 +443,31 @@ def build_string(self, root, binding_sites): if frame.lift_to_frame is not None: assert(not frame.insert_lets) if id(node) in self.memo: - frame.local_builder.extend(self.memo[id(node)]) + frame.builder.extend(self.memo[id(node)]) else: - frame.local_builder.append(node.render_tail(self)) - frame.local_builder.append(' ') + frame.builder.append(node.render_tail(self)) + frame.builder.append(' ') # let_bodies is built post-order, which guarantees earlier # lets can't refer to later lets - frame.lift_to_frame.let_bodies.append(frame.local_builder) + frame.lift_to_frame.let_bodies.append(frame.builder) (name, _) = frame.lift_to_frame.lifted_lets[id(node)] - # FIXME: can this be added on first visit? - frame.builder.append(f'(Ref {name})') stack.pop() continue else: - frame.local_builder.append(node.render_tail(self)) + frame.builder.append(node.render_tail(self)) + stack.pop() if frame.insert_lets: - frame.add_lets(context[-1].let_bodies) + if not stack: + out_builder = root_builder + else: + out_builder = stack[-1].builder + frame.add_lets(context[-1].let_bodies, out_builder) context.pop() - stack.pop() if not stack: - return ''.join(frame.builder) + return ''.join(root_builder) continue - frame.local_builder.append(' ') + frame.builder.append(' ') child = frame.children[child_idx] child_outermost_scope = frame.min_binding_depth @@ -489,26 +486,30 @@ def build_string(self, root, binding_sites): assert not insert_lets (name, _) = lift_to_frame.lifted_lets[id(child)] + child_builder = [f'(Let {name} '] + + if id(child) in self.memo: + new_state = PrintStackFrame(child, [], child_builder, child_outermost_scope, child_depth, insert_lets, lift_to_frame) + stack.append(new_state) + continue + + frame.builder.append(f'(Ref {name})') + if id(child) in lift_to_frame.visited: - frame.local_builder.append(f'(Ref {name})') continue lift_to_frame.visited[id(child)] = child - child_builder = [f'(Let {name} '] - if id(child) in self.memo: - new_state = PrintStackFrame(child, [], child_builder, child_outermost_scope, child_depth, frame.local_builder, insert_lets, lift_to_frame) - else: - new_state = frame.make_child_frame(self, binding_sites, frame.local_builder, context, child_outermost_scope, child_depth, child_builder) - new_state.lift_to_frame = lift_to_frame + new_state = frame.make_child_frame(self, binding_sites, child_builder, context, child_outermost_scope, child_depth) + new_state.lift_to_frame = lift_to_frame stack.append(new_state) continue if id(child) in self.memo: - frame.local_builder.extend(self.memo[id(child)]) + frame.builder.extend(self.memo[id(child)]) continue - new_state = frame.make_child_frame(self, binding_sites, frame.local_builder, context, child_outermost_scope, child_depth) + new_state = frame.make_child_frame(self, binding_sites, frame.builder, context, child_outermost_scope, child_depth) stack.append(new_state) continue From 2786aec2420eaad7c8abb5c4e52c003b511f8c1f Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 28 Aug 2019 14:10:39 -0400 Subject: [PATCH 32/47] use namedtuple and __slots__ --- hail/python/hail/backend/backend.py | 13 ++--- hail/python/hail/ir/renderer.py | 78 ++++++----------------------- hail/python/test/hail/test_ir.py | 8 +-- 3 files changed, 22 insertions(+), 77 deletions(-) diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 232d6d93ea4..24c260e638e 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -6,7 +6,7 @@ from hail.expr.table_type import * from hail.expr.matrix_type import * from hail.expr.blockmatrix_type import * -from hail.ir.renderer import Renderer, CSERenderer +from hail.ir.renderer import Renderer from hail.table import Table from hail.matrixtable import MatrixTable @@ -99,16 +99,9 @@ def fs(self): def _to_java_ir(self, ir): if not hasattr(ir, '_jir'): - r = CSERenderer(stop_at_jir=True) - r_old = Renderer(stop_at_jir=True) + r = Renderer(stop_at_jir=True) # FIXME parse should be static - no_cse = r_old(ir) - print('without cse:') - print(no_cse) - cse = r(ir) - print('with cse:') - print(cse) - ir._jir = ir.parse(cse, ir_map=r.jirs) + ir._jir = ir.parse(r(ir), ir_map=r.jirs) return ir._jir def execute(self, ir, timed=False): diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 3d570088f71..c58de336377 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -1,6 +1,7 @@ from hail import ir import abc from typing import Sequence, List, Set, Dict, Callable, Tuple, Optional +from collections import namedtuple class Renderable(object): @@ -88,60 +89,14 @@ def is_empty(self) -> bool: return self._idx < 0 -class Renderer(object): - def __init__(self, stop_at_jir=False): - self.stop_at_jir = stop_at_jir - self.count = 0 - self.jirs = {} - - def add_jir(self, jir): - jir_id = f'm{self.count}' - self.count += 1 - self.jirs[jir_id] = jir - return jir_id - - def __call__(self, x: 'Renderable'): - stack = RQStack() - builder = [] - - while x is not None or stack.non_empty(): - if x is not None: - # TODO: it would be nice to put the JavaIR logic in BaseIR somewhere but this isn't trivial - if self.stop_at_jir and hasattr(x, '_jir'): - jir_id = self.add_jir(x._jir) - if isinstance(x, ir.MatrixIR): - builder.append(f'(JavaMatrix {jir_id})') - elif isinstance(x, ir.TableIR): - builder.append(f'(JavaTable {jir_id})') - elif isinstance(x, ir.BlockMatrixIR): - builder.append(f'(JavaBlockMatrix {jir_id})') - else: - assert isinstance(x, ir.IR) - builder.append(f'(JavaIR {jir_id})') - else: - head = x.render_head(self) - if head != '': - builder.append(x.render_head(self)) - stack.push(RenderableQueue(x.render_children(self), x.render_tail(self))) - x = None - else: - top = stack.peek() - if top.exhausted(): - stack.pop() - builder.append(top.tail) - else: - builder.append(' ') - x = top.pop() - - return ''.join(builder) - - Vars = Dict[str, int] Context = (Vars, Vars, Vars) # FIXME: replace these by NamedTuples, or use slots class AnalysisStackFrame: + __slots__ = ['min_binding_depth', 'context', 'node', '_free_vars', 'visited', + 'lifted_lets', 'child_idx', '_child_bindings'] def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR'): # immutable self.min_binding_depth = min_binding_depth @@ -209,21 +164,18 @@ def update_parent_free_vars(self, parent_frame: 'AnalysisStackFrame'): parent_frame._free_vars.update(self._free_vars) -class BindingSite: - def __init__(self, lifted_lets: Dict[int, str], depth: int): - self.depth = depth - self.lifted_lets = lifted_lets - +BindingSite = namedtuple('BindingSite', 'depth lifted_lets') +BindingsStackFrame = namedtuple('BindingsStackFrame', 'depth lifted_lets visited let_bodies') -class BindingsStackFrame: - def __init__(self, binding_site: BindingSite): - self.depth = binding_site.depth - self.lifted_lets = binding_site.lifted_lets - self.visited = {} - self.let_bodies = [] +def make_bindings_stack_frame(site: BindingSite): + return BindingsStackFrame(depth=site.depth, + lifted_lets=site.lifted_lets, + visited={}, + let_bodies=[]) class PrintStackFrame: + __slots__ = ['node', 'children', 'min_binding_depth', 'depth', 'lift_to_frame', 'insert_lets', 'builder', 'child_idx'] def __init__(self, node, children, builder, outermost_scope, depth, insert_lets, lift_to_frame=None): # immutable self.node: Renderable = node @@ -254,14 +206,14 @@ def make(node, renderer, binding_sites, builder, context, outermost_scope, depth state = PrintStackFrame(node, node.render_children(renderer), builder, outermost_scope, depth, insert_lets) if insert_lets: state.builder = [] - context.append(BindingsStackFrame(binding_sites[id(node)])) + context.append(make_bindings_stack_frame(binding_sites[id(node)])) head = node.render_head(renderer) if head != '': state.builder.append(head) return state -class CSERenderer(Renderer): +class Renderer: def __init__(self, stop_at_jir=False): self.stop_at_jir = stop_at_jir self.jir_count = 0 @@ -334,7 +286,7 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: # binding sites if frame.lifted_lets: binding_sites[id(node)] = \ - BindingSite(frame.lifted_lets, len(stack)) + BindingSite(lifted_lets=frame.lifted_lets, depth=len(stack)) stack.pop() if not stack: @@ -392,7 +344,7 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: root_free_vars = root_frame.free_vars() assert(len(root_free_vars) == 0) - binding_sites[id(root)] = BindingSite(frame.lifted_lets, 0) + binding_sites[id(root)] = BindingSite(lifted_lets=frame.lifted_lets, depth=0) return binding_sites # At top of main loop, we are considering the 'Renderable' 'node' and its diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index 34792cf08ca..71fcd796937 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -1,7 +1,7 @@ import unittest import hail as hl import hail.ir as ir -from hail.ir.renderer import CSERenderer +from hail.ir.renderer import Renderer from hail.expr import construct_expr from hail.expr.types import tint32 from hail.utils.java import Env @@ -364,7 +364,7 @@ def test_cse(self): ' (ApplyBinaryPrimOp `+`' ' (Ref __cse_1)' ' (Ref __cse_1)))') - self.assertEqual(expected, CSERenderer()(x)) + self.assertEqual(expected, Renderer()(x)) def test_cse2(self): x = ir.I32(5) @@ -380,7 +380,7 @@ def test_cse2(self): ' (Ref __cse_2)' ' (I32 4))' ' (Ref __cse_2))))') - self.assertEqual(expected, CSERenderer()(div)) + self.assertEqual(expected, Renderer()(div)) def test_cse_ifs(self): outer_repeated = ir.I32(5) @@ -396,7 +396,7 @@ def test_cse_ifs(self): ' (I32 5)))' ' (I32 5))' ) - self.assertEqual(expected, CSERenderer()(cond)) + self.assertEqual(expected, Renderer()(cond)) # # def test_foo(self): # array = ir.MakeArray([ir.I32(5)], tint32) From 0e54c39718cbf617b15df34e99d3af8b1f0e8f10 Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Thu, 29 Aug 2019 10:24:15 -0400 Subject: [PATCH 33/47] [hail] Fix performance problems in aggregators by binding arguments --- .../hail/expr/aggregators/aggregators.py | 100 ++++++++++-------- 1 file changed, 57 insertions(+), 43 deletions(-) diff --git a/hail/python/hail/expr/aggregators/aggregators.py b/hail/python/hail/expr/aggregators/aggregators.py index 8d35dfef0f1..e41a5a600e1 100644 --- a/hail/python/hail/expr/aggregators/aggregators.py +++ b/hail/python/hail/expr/aggregators/aggregators.py @@ -757,7 +757,7 @@ def mean(expr) -> Float64Expression: :class:`.Expression` of type :py:data:`.tfloat64` Mean value of records of `expr`. """ - return sum(expr)/count_where(hl.is_defined(expr)) + return hl.bind(lambda expr: sum(expr) / count_where(hl.is_defined(expr)), expr, _ctx=_agg_func.context) @typecheck(expr=expr_float64) @@ -794,20 +794,24 @@ def stats(expr) -> StructExpression: `n`, and `sum`. """ - return hl.bind(lambda aggs: - hl.bind(lambda mean: hl.struct( - mean = mean, - stdev = hl.sqrt(hl.float64(aggs.sumsq - (2 * mean * aggs.sum) + (aggs.n_def * mean ** 2)) / aggs.n_def), - min = hl.float64(aggs.min), - max = hl.float64(aggs.max), - n = aggs.n_def, - sum = hl.float64(aggs.sum) - ), hl.float64(aggs.sum)/aggs.n_def), - hl.struct(n_def = count_where(hl.is_defined(expr)), - sum = sum(expr), - sumsq = sum(expr ** 2), - min = min(expr), - max = max(expr))) + return hl.bind( + lambda expr: hl.bind( + lambda aggs: hl.bind( + lambda mean: hl.struct( + mean=mean, + stdev=hl.sqrt(hl.float64( + aggs.sumsq - (2 * mean * aggs.sum) + (aggs.n_def * mean ** 2)) / aggs.n_def), + min=hl.float64(aggs.min), + max=hl.float64(aggs.max), + n=aggs.n_def, + sum=hl.float64(aggs.sum) + ), hl.float64(aggs.sum) / aggs.n_def), + hl.struct(n_def=count_where(hl.is_defined(expr)), + sum=sum(expr), + sumsq=sum(expr ** 2), + min=min(expr), + max=max(expr))), + expr, _ctx=_agg_func.context) @typecheck(expr=expr_oneof(expr_int64, expr_float64)) @@ -1150,6 +1154,8 @@ def call_stats(call, alleles) -> StructExpression: return _agg_func('CallStats', [call], t, [], init_op_args=[n_alleles]) +_bin_idx_f = None +_result_from_agg_f = None @typecheck(expr=expr_float64, start=expr_float64, end=expr_float64, bins=expr_int32) def hist(expr, start, end, bins) -> StructExpression: @@ -1195,18 +1201,23 @@ def hist(expr, start, end, bins) -> StructExpression: :class:`.StructExpression` Struct expression with fields `bin_edges`, `bin_freq`, `n_smaller`, and `n_larger`. """ - - bin_idx_f = hl.experimental.define_function( - lambda s, e, nbins, binsize, v: - (hl.case() - .when(v < s, -1) - .when(v > e, nbins) - .when(v == e, nbins - 1) - .default(hl.int32(hl.floor((v - s) / binsize)))), - hl.tfloat64, hl.tfloat64, hl.tint32, hl.tfloat64, hl.tfloat64) - - bin_idx = bin_idx_f(start, end, bins, hl.float64(end - start) / bins, expr) - freq_dict = hl.agg.filter(hl.is_defined(expr), hl.agg.group_by(bin_idx, hl.agg.count())) + global _bin_idx_f, _result_from_agg_f + if _bin_idx_f is None: + _bin_idx_f = hl.experimental.define_function( + lambda s, e, nbins, binsize, v: + (hl.case() + .when(v < s, -1) + .when(v > e, nbins) + .when(v == e, nbins - 1) + .default(hl.int32(hl.floor((v - s) / binsize)))), + hl.tfloat64, hl.tfloat64, hl.tint32, hl.tfloat64, hl.tfloat64) + + freq_dict = hl.bind( + lambda expr: hl.agg.filter(hl.is_defined(expr), + hl.agg.group_by( + _bin_idx_f(start, end, bins, hl.float64(end - start) / bins, expr), + hl.agg.count())), + expr, _ctx=_agg_func.context) def result(s, nbins, bs, freq_dict): return hl.struct( @@ -1227,11 +1238,12 @@ def wrap_errors(s, e, nbins, freq_dict): hl.float64(e - s) / nbins)) .or_error(hl.literal("'hist' requires positive 'bins', but bins=") + hl.str(nbins))) - result_from_agg_f = hl.experimental.define_function( - wrap_errors, - hl.tfloat64, hl.tfloat64, hl.tint32, hl.tdict(hl.tint32, hl.tint64)) + if _result_from_agg_f is None: + _result_from_agg_f = hl.experimental.define_function( + wrap_errors, + hl.tfloat64, hl.tfloat64, hl.tint32, hl.tdict(hl.tint32, hl.tint64)) - return result_from_agg_f(start, end, bins, freq_dict) + return _result_from_agg_f(start, end, bins, freq_dict) @typecheck(x=expr_float64, y=expr_float64, label=nullable(oneof(expr_str, expr_array(expr_str))), n_divisions=int) @@ -1529,18 +1541,20 @@ def corr(x, y) -> Float64Expression: ------- :class:`.Float64Expression` """ - - return hl.bind(lambda a: - (a.n * a.xy - a.x * a.y) / - hl.sqrt((a.n * a.xsq - a.x ** 2) * - (a.n * a.ysq - a.y ** 2)), - hl.agg.filter(hl.is_defined(x) & hl.is_defined(y), - hl.struct(x=hl.agg.sum(x), - y=hl.agg.sum(y), - xsq=hl.agg.sum(x ** 2), - ysq=hl.agg.sum(y ** 2), - xy=hl.agg.sum(x * y), - n=hl.agg.count()))) + return hl.bind( + lambda x, y: hl.bind( + lambda a: + (a.n * a.xy - a.x * a.y) / + hl.sqrt((a.n * a.xsq - a.x ** 2) * + (a.n * a.ysq - a.y ** 2)), + hl.agg.filter(hl.is_defined(x) & hl.is_defined(y), + hl.struct(x=hl.agg.sum(x), + y=hl.agg.sum(y), + xsq=hl.agg.sum(x ** 2), + ysq=hl.agg.sum(y ** 2), + xy=hl.agg.sum(x * y), + n=hl.agg.count()))), + x, y, _ctx=_agg_func.context) @typecheck(group=expr_any, agg_expr=agg_expr(expr_any)) From e3be6636550f82fd78b6198ac0ffca9a2076d18b Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Fri, 30 Aug 2019 06:34:45 -0400 Subject: [PATCH 34/47] fix let lifting bug --- hail/python/hail/ir/ir.py | 2 +- .../is/hail/expr/ir/ExtractAggregators.scala | 8 +++---- .../scala/is/hail/expr/ir/agg/Extract.scala | 22 +++++++++++++++++-- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 55a75985e0e..15e6b605209 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -2012,7 +2012,7 @@ def delete(env, name): _subst(ir.body, delete(env, ir.name))) elif isinstance(ir, AggLet): return AggLet(ir.name, - _subst(ir.value), + _subst(ir.value, agg_env, {}), _subst(ir.body, delete(env, ir.name)), ir.is_scan) elif isinstance(ir, ArrayMap): diff --git a/hail/src/main/scala/is/hail/expr/ir/ExtractAggregators.scala b/hail/src/main/scala/is/hail/expr/ir/ExtractAggregators.scala index 699bb5015c1..0416a557752 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ExtractAggregators.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ExtractAggregators.scala @@ -133,7 +133,7 @@ object ExtractAggregators { val transformed = this.extract(aggBody, ab, newBuilder, aggLetAB, result, newAggregator, keyedAggregator, arrayAggregator) // collect lets that depend on `name`, push the rest up - val (dependent, independent) = aggLetAB.result().partition(l => Mentions(l.value, name)) + val (dependent, independent) = agg.Extract.partitionDependentLets(aggLetAB.result(), name) ab3 ++= independent val (initOp, seqOp) = newBuilder.result().map { case AggOps(x, y) => (x, y) }.unzip @@ -172,11 +172,11 @@ object ExtractAggregators { val transformed = this.extract[RVAgg](aggBody, newRVAggBuilder, newBuilder, aggLetAB, newRef, newAggregator, keyedAggregator, arrayAggregator) // collect lets that depend on `elementName`, push the rest up - val (dependent, independent) = aggLetAB.result().partition(l => Mentions(l.value, elementName)) + val (dependent, independent) = agg.Extract.partitionDependentLets(aggLetAB.result(), elementName) ab3 ++= independent val nestedAggs = newRVAggBuilder.result() - val agg = arrayAggregator(nestedAggs.map(_.rvAgg).toArray) + val arrAgg = arrayAggregator(nestedAggs.map(_.rvAgg).toArray) val rt = TArray(TTuple(nestedAggs.map(_.rt): _*)) newRef._typ = -rt.elementType @@ -193,7 +193,7 @@ object ExtractAggregators { val (initOp, seqOp) = newBuilder.result().map { case AggOps(x, y) => (x, y) }.unzip val i = ab.length - ab += IRAgg[RVAgg](i, agg, rt) + ab += IRAgg[RVAgg](i, arrAgg, rt) ab2 += AggOps( Some(InitOp(i, FastIndexedSeq[IR](Begin(initOp.flatten.toFastIndexedSeq)) ++ knownLengthIRSeq, aggSigCheck)), Let( diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala b/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala index ad44cc35871..d9b8a844c35 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala @@ -10,6 +10,7 @@ import is.hail.io.CodecSpec import is.hail.rvd.{RVDContext, RVDType} import is.hail.utils._ +import scala.collection.mutable import scala.language.{existentials, postfixOps} object TableMapIRNew { @@ -227,6 +228,23 @@ case class Aggs(postAggIR: IR, init: IR, seqPerElt: IR, aggs: Array[AggSignature } object Extract { + def partitionDependentLets(lets: Array[AggLet], name: String): (Array[AggLet], Array[AggLet]) = { + val depBindings = mutable.HashSet.empty[String] + depBindings += name + + val dep = new ArrayBuilder[AggLet] + val indep = new ArrayBuilder[AggLet] + + lets.foreach { l => + val fv = FreeVariables(l.value, supportsAgg = false, supportsScan = false) + if (fv.eval.m.keysIterator.exists(k => depBindings.contains(k))) { + dep += l + depBindings += l.name + } else + indep += l + } + (dep.result(), indep.result()) + } def addLets(ir: IR, lets: Array[AggLet]): IR = { assert(lets.areDistinct()) @@ -318,7 +336,7 @@ object Extract { val newLet = new ArrayBuilder[AggLet]() val transformed = this.extract(aggBody, ab, newSeq, newLet, result) - val (dependent, independent) = newLet.result().partition(l => Mentions(l.value, name)) + val (dependent, independent) = partitionDependentLets(newLet.result(), name) letBuilder ++= independent seqBuilder += ArrayFor(array, name, addLets(Begin(newSeq.result()), dependent)) transformed @@ -350,7 +368,7 @@ object Extract { val newRef = Ref(genUID(), null) val transformed = this.extract(aggBody, newAggs, newSeq, newLet, newRef) - val (dependent, independent) = newLet.result().partition(l => Mentions(l.value, elementName)) + val (dependent, independent) = partitionDependentLets(newLet.result(), elementName) letBuilder ++= independent val i = ab.length From baa5251a91eaa696590659b26bf31b7baf27f68a Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Fri, 30 Aug 2019 07:20:42 -0400 Subject: [PATCH 35/47] bump! From 367819d89455fa5f0a129a92daa3a11716de984a Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 4 Sep 2019 11:00:55 -0400 Subject: [PATCH 36/47] enable AggLets, other refactoring --- hail/python/hail/ir/base_ir.py | 23 +-- hail/python/hail/ir/ir.py | 2 +- hail/python/hail/ir/renderer.py | 278 +++++++++++++++++++++++-------- hail/python/test/hail/test_ir.py | 22 ++- 4 files changed, 236 insertions(+), 89 deletions(-) diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index e50d75f9f91..4e01fcc28fa 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -33,7 +33,10 @@ def render_head(self, r): head_str = self.head_str() if head_str != '': head_str = f' {head_str}' - return f'({self._ir_name()}{head_str}' + trailing_space = '' + if len(self.children) > 0: + trailing_space = ' ' + return f'({self._ir_name()}{head_str}{trailing_space}' def render_tail(self, r): return ')' @@ -182,8 +185,6 @@ def map_ir(self, f): return self.copy(*new_children) - # FIXME: This is only used to compute the free variables in a subtree, - # in an incorrect way. Should just define a free variables method directly. @property def bound_variables(self): return {v for child in self.children for v in child.bound_variables} @@ -196,23 +197,17 @@ def typ(self): return self._type def new_block(self, i: int) -> bool: - return self.uses_agg_context(i) or self.uses_scan_context(i) - - def renderable_new_block(self, i: int) -> bool: - return self.renderable_uses_agg_context(i) or self.renderable_uses_scan_context(i) + return False @abc.abstractmethod def _compute_type(self, env, agg_env): raise NotImplementedError(self) def parse(self, code, ref_map={}, ir_map={}): - try: - return Env.hail().expr.ir.IRParser.parse_value_ir( - code, - {k: t._parsable_string() for k, t in ref_map.items()}, - ir_map) - except: - raise + return Env.hail().expr.ir.IRParser.parse_value_ir( + code, + {k: t._parsable_string() for k, t in ref_map.items()}, + ir_map) class TableIR(BaseIR): diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index eadc0e44c70..b19135cdf9b 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -1347,7 +1347,7 @@ def copy(self, *args): return new_instance(self.agg_op, constr_args, init_op_args if len(init_op_args) != 0 else None, seq_op_args) def head_str(self): - return f' {self.agg_op} ' + return f'{self.agg_op}' def render_children(self, r): return [ diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index c58de336377..a4a46466814 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -93,19 +93,28 @@ def is_empty(self) -> bool: Context = (Vars, Vars, Vars) -# FIXME: replace these by NamedTuples, or use slots class AnalysisStackFrame: - __slots__ = ['min_binding_depth', 'context', 'node', '_free_vars', 'visited', - 'lifted_lets', 'child_idx', '_child_bindings'] - def __init__(self, min_binding_depth: int, context: Context, x: 'ir.BaseIR'): + __slots__ = ['min_binding_depth', 'min_value_binding_depth', 'scan_scope', + 'context', 'node', '_free_vars', 'visited', 'agg_visited', + 'scan_visited', 'lifted_lets', 'agg_lifted_lets', + 'scan_lifted_lets', 'child_idx', '_child_bindings'] + + def __init__(self, min_binding_depth: int, min_value_binding_depth: int, + scan_scope: bool, context: Context, x: 'ir.BaseIR'): # immutable self.min_binding_depth = min_binding_depth + self.min_value_binding_depth = min_value_binding_depth + self.scan_scope = scan_scope self.context = context self.node = x # mutable self._free_vars = {} - self.visited: Dict[int, 'ir.BaseIR'] = {} - self.lifted_lets: Dict[int, (str, 'ir.BaseIR')] = {} + self.visited: Set[int] = set() + self.agg_visited: Set[int] = set() + self.scan_visited: Set[int] = set() + self.lifted_lets: Dict[int, str] = {} + self.agg_lifted_lets: Dict[int, str] = {} + self.scan_lifted_lets: Dict[int, str] = {} self.child_idx = 0 self._child_bindings = None @@ -122,10 +131,18 @@ def make_child_frame(self, depth: int) -> 'AnalysisStackFrame': x = self.node i = self.child_idx - 1 child = x.children[i] + child_min_binding_depth = self.min_binding_depth + child_min_value_binding_depth = self.min_value_binding_depth + child_scan_scope = False if x.new_block(i): - child_outermost_scope = depth - else: - child_outermost_scope = self.min_binding_depth + child_min_binding_depth = depth + child_min_value_binding_depth = depth + elif x.uses_agg_context(i): + child_min_value_binding_depth = depth + child_scan_scope = False + elif x.uses_scan_context(i): + child_min_value_binding_depth = depth + child_scan_scope = True # compute vars bound in 'child' by 'node' eval_bindings = x.bindings(i, 0).keys() @@ -135,11 +152,15 @@ def make_child_frame(self, depth: int) -> 'AnalysisStackFrame': self._child_bindings = new_bindings child_context = x.child_context(i, self.context, depth) - return AnalysisStackFrame(child_outermost_scope, child_context, child,) + return AnalysisStackFrame(child_min_binding_depth, child_min_value_binding_depth, child_scan_scope, child_context, child,) def free_vars(self): # subtract vars that will be bound by inserted lets - for (var, _) in self.lifted_lets.values(): + def bound_vars(): + yield from self.lifted_lets.values() + yield from self.agg_lifted_lets.values() + yield from self.scan_lifted_lets.values() + for var in bound_vars(): self._free_vars.pop(var, 0) return self._free_vars @@ -164,25 +185,41 @@ def update_parent_free_vars(self, parent_frame: 'AnalysisStackFrame'): parent_frame._free_vars.update(self._free_vars) -BindingSite = namedtuple('BindingSite', 'depth lifted_lets') -BindingsStackFrame = namedtuple('BindingsStackFrame', 'depth lifted_lets visited let_bodies') +BindingSite = namedtuple( + 'BindingSite', + 'depth lifted_lets agg_lifted_lets scan_lifted_lets') +BindingsStackFrame = namedtuple( + 'BindingsStackFrame', + 'depth lifted_lets agg_lifted_lets scan_lifted_lets visited agg_visited' + ' scan_visited let_bodies') def make_bindings_stack_frame(site: BindingSite): return BindingsStackFrame(depth=site.depth, lifted_lets=site.lifted_lets, + agg_lifted_lets=site.agg_lifted_lets, + scan_lifted_lets=site.scan_lifted_lets, visited={}, + agg_visited={}, + scan_visited={}, let_bodies=[]) class PrintStackFrame: - __slots__ = ['node', 'children', 'min_binding_depth', 'depth', 'lift_to_frame', 'insert_lets', 'builder', 'child_idx'] - def __init__(self, node, children, builder, outermost_scope, depth, insert_lets, lift_to_frame=None): + __slots__ = ['node', 'children', 'min_binding_depth', + 'min_value_binding_depth', 'scan_scope', 'depth', + 'lift_to_frame', 'insert_lets', 'builder', 'child_idx'] + + def __init__(self, node, children, builder, min_binding_depth, + min_value_binding_depth, scan_scope, depth, insert_lets, + lift_to_frame=None): # immutable self.node: Renderable = node self.children: Sequence[Renderable] = children - self.min_binding_depth: int = outermost_scope + self.min_binding_depth: int = min_binding_depth + self.min_value_binding_depth: int = min_value_binding_depth + self.scan_scope: bool = scan_scope self.depth: int = depth - self.lift_to_frame: Optional[int] = lift_to_frame + self.lift_to_frame: Optional[BindingsStackFrame] = lift_to_frame self.insert_lets: bool = insert_lets # mutable self.builder = builder @@ -196,14 +233,23 @@ def add_lets(self, let_bodies, out_builder): for _ in range(num_lets): out_builder.append(')') - def make_child_frame(self, renderer, binding_sites, builder, context, outermost_scope, depth): + def make_child_frame(self, renderer, binding_sites, builder, context, + min_binding_depth, min_value_binding_depth, scan_scope, + depth): child = self.children[self.child_idx] - return self.make(child, renderer, binding_sites, builder, context, outermost_scope, depth) + return self.make(child, renderer, binding_sites, builder, context, + min_binding_depth, min_value_binding_depth, scan_scope, + depth) @staticmethod - def make(node, renderer, binding_sites, builder, context, outermost_scope, depth): - insert_lets = id(node) in binding_sites and len(binding_sites[id(node)].lifted_lets) > 0 - state = PrintStackFrame(node, node.render_children(renderer), builder, outermost_scope, depth, insert_lets) + def make(node, renderer, binding_sites, builder, context, min_binding_depth, + min_value_binding_depth, scan_scope, depth): + insert_lets = (id(node) in binding_sites + and (len(binding_sites[id(node)].lifted_lets) > 0 + or len(binding_sites[id(node)].agg_lifted_lets) > 0)) + state = PrintStackFrame(node, node.render_children(renderer), builder, + min_binding_depth, min_value_binding_depth, + scan_scope, depth, insert_lets) if insert_lets: state.builder = [] context.append(make_bindings_stack_frame(binding_sites[id(node)])) @@ -267,7 +313,7 @@ def add_jir(self, jir): # where for each descendant 'x' which will be bound above 'node', # 'lifted_lets' maps 'id(x)' to the unique id 'x' will be bound to. def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: - root_frame = AnalysisStackFrame(0, ({}, {}, {}), root) + root_frame = AnalysisStackFrame(0, 0, False, ({}, {}, {}), root) stack = [root_frame] binding_sites = {} @@ -279,14 +325,25 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: if child_idx >= len(node.children): # mark node as visited at potential let insertion site + # FIXME: factor out to AnalysisStackFrame method if not node.is_effectful(): - stack[frame.bind_depth()].visited[id(node)] = node + bind_depth = frame.bind_depth() + if bind_depth < frame.min_value_binding_depth: + if frame.scan_scope: + stack[bind_depth].scan_visited.add(id(node)) + else: + stack[bind_depth].agg_visited.add(id(node)) + else: + stack[bind_depth].visited.add(id(node)) # if any lets being inserted here, add node to registry of # binding sites - if frame.lifted_lets: - binding_sites[id(node)] = \ - BindingSite(lifted_lets=frame.lifted_lets, depth=len(stack)) + if frame.lifted_lets or frame.agg_lifted_lets or frame.scan_lifted_lets: + binding_sites[id(node)] = BindingSite( + lifted_lets=frame.lifted_lets, + agg_lifted_lets=frame.agg_lifted_lets, + scan_lifted_lets=frame.scan_lifted_lets, + depth=len(stack)) stack.pop() if not stack: @@ -311,23 +368,47 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: self.memo[id(child)] = jref continue - seen_in_scope = next((i for i in reversed(range(len(stack))) if id(child) in stack[i].visited), None) - - if seen_in_scope is not None and isinstance(child, ir.IR): - # we've seen 'child' before, should not traverse (or we will - # find too many lifts) - if id(child) in stack[seen_in_scope].lifted_lets: - (uid, _) = stack[seen_in_scope].lifted_lets[id(child)] + if isinstance(child, ir.IR): + seen_in_scope = next((i for i in reversed(range(frame.min_value_binding_depth, + len(stack))) + if id(child) in stack[i].visited), + None) + lets = None + if seen_in_scope is not None: + lets = stack[seen_in_scope].lifted_lets else: - # second time we've seen 'child', lift to a let - uid = self.uid() - stack[seen_in_scope].lifted_lets[id(child)] = (uid, child) - # Since we are not traversing 'child', we don't know its free - # variables. To prevent a parent from being lifted too high, - # we must register 'child' as having the free variable 'uid', - # which will be true when 'child' is replaced by "(Ref uid)". - frame._free_vars[uid] = seen_in_scope - continue + if frame.scan_scope: + seen_in_scope = next( + (i for i in reversed(range(frame.min_binding_depth, + frame.min_value_binding_depth)) + if id(child) in stack[i].scan_visited), + None) + if seen_in_scope is not None: + lets = stack[seen_in_scope].scan_lifted_lets + else: + seen_in_scope = next( + (i for i in reversed(range(frame.min_binding_depth, + frame.min_value_binding_depth)) + if id(child) in stack[i].agg_visited), + None) + if seen_in_scope is not None: + lets = stack[seen_in_scope].agg_lifted_lets + + if lets is not None: + # we've seen 'child' before, should not traverse (or we will + # find too many lifts) + if id(child) in lets: + uid = lets[id(child)] + else: + # second time we've seen 'child', lift to a let + uid = self.uid() + lets[id(child)] = uid + # Since we are not traversing 'child', we don't know its free + # variables. To prevent a parent from being lifted too high, + # we must register 'child' as having the free variable 'uid', + # which will be true when 'child' is replaced by "(Ref uid)". + frame._free_vars[uid] = seen_in_scope + continue # first time visiting 'child' @@ -335,7 +416,10 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: if child.name not in node.bindings(child_idx, default_value=0).keys(): (eval_c, _, _) = node.child_context_without_bindings( child_idx, frame.context) - frame._free_vars[child.name] = eval_c[child.name] + try: + frame._free_vars[child.name] = eval_c[child.name] + except: + raise continue stack.append(frame.make_child_frame(len(stack))) @@ -344,7 +428,11 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: root_free_vars = root_frame.free_vars() assert(len(root_free_vars) == 0) - binding_sites[id(root)] = BindingSite(lifted_lets=frame.lifted_lets, depth=0) + binding_sites[id(root)] = BindingSite( + lifted_lets=frame.lifted_lets, + agg_lifted_lets=frame.agg_lifted_lets, + scan_lifted_lets=frame.scan_lifted_lets, + depth=0) return binding_sites # At top of main loop, we are considering the 'Renderable' 'node' and its @@ -383,7 +471,7 @@ def build_string(self, root, binding_sites): if id(root) in self.memo: return ''.join(self.memo[id(root)]) - stack = [PrintStackFrame.make(root, self, binding_sites, root_builder, context, 0, 1)] + stack = [PrintStackFrame.make(root, self, binding_sites, root_builder, context, 0, 0, False, 1)] while True: frame = stack[-1] @@ -402,7 +490,6 @@ def build_string(self, root, binding_sites): # let_bodies is built post-order, which guarantees earlier # lets can't refer to later lets frame.lift_to_frame.let_bodies.append(frame.builder) - (name, _) = frame.lift_to_frame.lifted_lets[id(node)] stack.pop() continue else: @@ -419,40 +506,94 @@ def build_string(self, root, binding_sites): return ''.join(root_builder) continue - frame.builder.append(' ') + if child_idx > 0: + frame.builder.append(' ') child = frame.children[child_idx] - child_outermost_scope = frame.min_binding_depth + child_min_binding_depth = frame.min_binding_depth + child_min_value_binding_depth = frame.min_value_binding_depth + child_scan_scope = False child_depth = frame.depth - if isinstance(frame.node, ir.BaseIR) and frame.node.renderable_new_block(frame.child_idx): - child_outermost_scope = frame.depth + if isinstance(node, ir.BaseIR): + if node.renderable_new_block(child_idx): + child_min_binding_depth = frame.depth + child_min_value_binding_depth = frame.depth + if node.renderable_uses_agg_context(child_idx): + child_min_value_binding_depth = frame.depth + child_scan_scope = False + if node.renderable_uses_scan_context(child_idx): + child_min_value_binding_depth = frame.depth + child_scan_scope = True + if isinstance(child, ir.BaseIR): child_depth += 1 - lift_to_frame = next((frame for frame in context if id(child) in frame.lifted_lets), None) + lift_to_frame = next( + (c for c in context + if c.depth >= child_min_value_binding_depth + and id(child) in c.lifted_lets), + None) + if lift_to_frame: + lift_type = 'value' + else: + lift_to_frame = next( + (c for c in context + if child_min_binding_depth <= c.depth < child_min_value_binding_depth + and id(child) in c.agg_lifted_lets), + None) + if lift_to_frame: + lift_type = 'agg' + else: + lift_to_frame = next( + (c for c in context + if child_min_binding_depth <= c.depth < child_min_value_binding_depth + and id(child) in c.scan_lifted_lets), + None) + if lift_to_frame: + lift_type = 'scan' else: lift_to_frame = None - if lift_to_frame and lift_to_frame.depth >= frame.min_binding_depth: - insert_lets = id(child) in binding_sites and len(binding_sites[id(child)].lifted_lets) > 0 + if lift_to_frame: + insert_lets = (id(child) in binding_sites + and (len(binding_sites[id(child)].lifted_lets) > 0 + or len(binding_sites[id(child)].agg_lifted_lets) > 0 + or len(binding_sites[id(child)].scan_lifted_lets > 0))) assert not insert_lets - (name, _) = lift_to_frame.lifted_lets[id(child)] + if lift_type == 'value': + name = lift_to_frame.lifted_lets[id(child)] + child_builder = [f'(Let {name} '] + visited = lift_to_frame.visited + elif lift_type == 'agg': + name = lift_to_frame.agg_lifted_lets[id(child)] + child_builder = [f'(AggLet {name} False '] + visited = lift_to_frame.agg_visited + else: + name = lift_to_frame.scan_lifted_lets[id(child)] + child_builder = [f'(AggLet {name} True '] + visited = lift_to_frame.scan_visited - child_builder = [f'(Let {name} '] + frame.builder.append(f'(Ref {name})') if id(child) in self.memo: - new_state = PrintStackFrame(child, [], child_builder, child_outermost_scope, child_depth, insert_lets, lift_to_frame) - stack.append(new_state) + child_builder.extend(self.memo[id(child)]) + child_builder.append(' ') + lift_to_frame.let_bodies.append(child_builder) continue + # new_state = PrintStackFrame(child, [], child_builder, child_min_binding_depth, child_depth, insert_lets, lift_to_frame) + # stack.append(new_state) + # continue - frame.builder.append(f'(Ref {name})') - - if id(child) in lift_to_frame.visited: + if id(child) in visited: continue - - lift_to_frame.visited[id(child)] = child - - new_state = frame.make_child_frame(self, binding_sites, child_builder, context, child_outermost_scope, child_depth) + visited[id(child)] = child + + new_state = frame.make_child_frame(self, binding_sites, + child_builder, context, + child_min_binding_depth, + child_min_value_binding_depth, + child_scan_scope, + child_depth) new_state.lift_to_frame = lift_to_frame stack.append(new_state) continue @@ -461,7 +602,12 @@ def build_string(self, root, binding_sites): frame.builder.extend(self.memo[id(child)]) continue - new_state = frame.make_child_frame(self, binding_sites, frame.builder, context, child_outermost_scope, child_depth) + new_state = frame.make_child_frame(self, binding_sites, + frame.builder, context, + child_min_binding_depth, + child_min_value_binding_depth, + child_scan_scope, + child_depth) stack.append(new_state) continue diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index 71fcd796937..d09ac8ffb19 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -397,11 +397,17 @@ def test_cse_ifs(self): ' (I32 5))' ) self.assertEqual(expected, Renderer()(cond)) - # - # def test_foo(self): - # array = ir.MakeArray([ir.I32(5)], tint32) - # ref = ir.Ref('x') - # sum = ir.ApplyBinaryPrimOp('+', ref, ref) - # map = ir.ArrayMap(array, 'x', sum) - # print(NewCSE()(map)) - # self.assertTrue(False) + + def test_agg_cse(self): + x = ir.GetField(ir.Ref('row'), 'idx') + inner_sum = ir.ApplyBinaryPrimOp('+', x, x) + agg = ir.ApplyAggOp('AggOp', [], [], [inner_sum]) + outer_sum = ir.ApplyBinaryPrimOp('+', agg, agg) + table_agg = ir.TableAggregate(ir.TableRange(5, 1), outer_sum) + expected = ( + '(TableAggregate (TableRange 5 1)' + ' (AggLet __cse_1 False (GetField idx (Ref row))' + ' (Let __cse_2 (ApplyAggOp AggOp () None' + ' ((ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))))' + ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2)))))') + self.assertEqual(expected, Renderer()(table_agg)) From 6320d30481de68fbd26f3454c844f36d95b57bb2 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 4 Sep 2019 15:24:51 -0400 Subject: [PATCH 37/47] put back old renderer --- hail/python/hail/backend/backend.py | 2 +- hail/python/hail/ir/ir.py | 4 ++ hail/python/hail/ir/renderer.py | 60 ++++++++++++++++++++++++++--- hail/python/test/hail/test_ir.py | 10 ++--- 4 files changed, 65 insertions(+), 11 deletions(-) diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 24c260e638e..77a87152ec8 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -6,7 +6,7 @@ from hail.expr.table_type import * from hail.expr.matrix_type import * from hail.expr.blockmatrix_type import * -from hail.ir.renderer import Renderer +from hail.ir.renderer import CSERenderer, Renderer from hail.table import Table from hail.matrixtable import MatrixTable diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index b19135cdf9b..f8be1f265ba 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -1349,6 +1349,10 @@ def copy(self, *args): def head_str(self): return f'{self.agg_op}' + # Overloaded to add space after 'agg_op' even if there are no children. + def render_head(self, r): + return f'({self._ir_name()} {self.agg_op} ' + def render_children(self, r): return [ ParensRenderer(self.constructor_args), diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index a4a46466814..4ef7bdd6804 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -89,6 +89,54 @@ def is_empty(self) -> bool: return self._idx < 0 +class Renderer(object): + def __init__(self, stop_at_jir=False): + self.stop_at_jir = stop_at_jir + self.count = 0 + self.jirs = {} + + def add_jir(self, jir): + jir_id = f'm{self.count}' + self.count += 1 + self.jirs[jir_id] = jir + return jir_id + + def __call__(self, x: 'Renderable'): + stack = RQStack() + builder = [] + + while x is not None or stack.non_empty(): + if x is not None: + # TODO: it would be nice to put the JavaIR logic in BaseIR somewhere but this isn't trivial + if self.stop_at_jir and hasattr(x, '_jir'): + jir_id = self.add_jir(x._jir) + if isinstance(x, ir.MatrixIR): + builder.append(f'(JavaMatrix {jir_id})') + elif isinstance(x, ir.TableIR): + builder.append(f'(JavaTable {jir_id})') + elif isinstance(x, ir.BlockMatrixIR): + builder.append(f'(JavaBlockMatrix {jir_id})') + else: + assert isinstance(x, ir.IR) + builder.append(f'(JavaIR {jir_id})') + else: + head = x.render_head(self) + if head != '': + builder.append(x.render_head(self)) + stack.push(RenderableQueue(x.render_children(self), x.render_tail(self))) + x = None + else: + top = stack.peek() + if top.exhausted(): + stack.pop() + builder.append(top.tail) + else: + builder.append(' ') + x = top.pop() + + return ''.join(builder) + + Vars = Dict[str, int] Context = (Vars, Vars, Vars) @@ -133,7 +181,7 @@ def make_child_frame(self, depth: int) -> 'AnalysisStackFrame': child = x.children[i] child_min_binding_depth = self.min_binding_depth child_min_value_binding_depth = self.min_value_binding_depth - child_scan_scope = False + child_scan_scope = self.scan_scope if x.new_block(i): child_min_binding_depth = depth child_min_value_binding_depth = depth @@ -259,7 +307,7 @@ def make(node, renderer, binding_sites, builder, context, min_binding_depth, return state -class Renderer: +class CSERenderer: def __init__(self, stop_at_jir=False): self.stop_at_jir = stop_at_jir self.jir_count = 0 @@ -478,12 +526,14 @@ def build_string(self, root, binding_sites): node = frame.node frame.child_idx += 1 child_idx = frame.child_idx + if isinstance(node, ir.ApplyAggOp) and node.agg_op == 'Count': + print('...') if child_idx >= len(frame.children): if frame.lift_to_frame is not None: assert(not frame.insert_lets) if id(node) in self.memo: - frame.builder.extend(self.memo[id(node)]) + frame.builder.append(self.memo[id(node)]) else: frame.builder.append(node.render_tail(self)) frame.builder.append(' ') @@ -576,7 +626,7 @@ def build_string(self, root, binding_sites): frame.builder.append(f'(Ref {name})') if id(child) in self.memo: - child_builder.extend(self.memo[id(child)]) + child_builder.append(self.memo[id(child)]) child_builder.append(' ') lift_to_frame.let_bodies.append(child_builder) continue @@ -599,7 +649,7 @@ def build_string(self, root, binding_sites): continue if id(child) in self.memo: - frame.builder.extend(self.memo[id(child)]) + frame.builder.append(self.memo[id(child)]) continue new_state = frame.make_child_frame(self, binding_sites, diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index d09ac8ffb19..d6e6f098e0b 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -1,7 +1,7 @@ import unittest import hail as hl import hail.ir as ir -from hail.ir.renderer import Renderer +from hail.ir.renderer import CSERenderer from hail.expr import construct_expr from hail.expr.types import tint32 from hail.utils.java import Env @@ -364,7 +364,7 @@ def test_cse(self): ' (ApplyBinaryPrimOp `+`' ' (Ref __cse_1)' ' (Ref __cse_1)))') - self.assertEqual(expected, Renderer()(x)) + self.assertEqual(expected, CSERenderer()(x)) def test_cse2(self): x = ir.I32(5) @@ -380,7 +380,7 @@ def test_cse2(self): ' (Ref __cse_2)' ' (I32 4))' ' (Ref __cse_2))))') - self.assertEqual(expected, Renderer()(div)) + self.assertEqual(expected, CSERenderer()(div)) def test_cse_ifs(self): outer_repeated = ir.I32(5) @@ -396,7 +396,7 @@ def test_cse_ifs(self): ' (I32 5)))' ' (I32 5))' ) - self.assertEqual(expected, Renderer()(cond)) + self.assertEqual(expected, CSERenderer()(cond)) def test_agg_cse(self): x = ir.GetField(ir.Ref('row'), 'idx') @@ -410,4 +410,4 @@ def test_agg_cse(self): ' (Let __cse_2 (ApplyAggOp AggOp () None' ' ((ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))))' ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2)))))') - self.assertEqual(expected, Renderer()(table_agg)) + self.assertEqual(expected, CSERenderer()(table_agg)) From 07e37a30a7d17a893bb314eb865f1506fd8f945c Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 5 Sep 2019 15:10:30 -0400 Subject: [PATCH 38/47] fixes --- hail/python/hail/ir/ir.py | 38 +++++++++++++++- hail/python/hail/ir/renderer.py | 78 +++++++++++++------------------- hail/python/test/hail/test_ir.py | 18 ++++++++ 3 files changed, 86 insertions(+), 48 deletions(-) diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 20fc5d69e03..068cf9e1b21 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -321,9 +321,38 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self.value._compute_type(agg_env, None) - self.body._compute_type(env, _env_bind(agg_env, self.name, self.value._type)) + self.body._compute_type(env, _env_bind(agg_env, [(self.name, self.value._type)])) self._type = self.body._type + def agg_bindings(self, i, default_value=None): + if not self.is_scan and i == 1: + if default_value is None: + value = self.value._type + else: + value = default_value + return {self.name: value} + else: + return {} + + def scan_bindings(self, i, default_value=None): + if self.is_scan and i == 1: + if default_value is None: + value = self.value._type + else: + value = default_value + return {self.name: value} + else: + return {} + + def binds(self, i): + return i == 1 + + def uses_agg_context(self, i: int) -> bool: + return not self.is_scan and i == 0 + + def uses_scan_context(self, i: int) -> bool: + return self.is_scan and i == 0 + class Ref(IR): @typecheck_method(name=str) @@ -1422,6 +1451,13 @@ def _compute_type(self, env, agg_env): [a.typ for a in self.init_op_args] if self.init_op_args else None, [a.typ for a in self.seq_op_args]) + def new_block(self, i: int) -> bool: + n = len(self.constructor_args) + return self.init_op_args and i < n + len(self.init_op_args) + + def renderable_new_block(self, i: int) -> bool: + return i <= 1 + class ApplyAggOp(BaseApplyAggOp): @typecheck_method(agg_op=str, diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 4ef7bdd6804..bb66a0715e5 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -200,7 +200,7 @@ def make_child_frame(self, depth: int) -> 'AnalysisStackFrame': self._child_bindings = new_bindings child_context = x.child_context(i, self.context, depth) - return AnalysisStackFrame(child_min_binding_depth, child_min_value_binding_depth, child_scan_scope, child_context, child,) + return AnalysisStackFrame(child_min_binding_depth, child_min_value_binding_depth, child_scan_scope, child_context, child) def free_vars(self): # subtract vars that will be bound by inserted lets @@ -293,6 +293,7 @@ def make_child_frame(self, renderer, binding_sites, builder, context, def make(node, renderer, binding_sites, builder, context, min_binding_depth, min_value_binding_depth, scan_scope, depth): insert_lets = (id(node) in binding_sites + and depth == binding_sites[id(node)].depth and (len(binding_sites[id(node)].lifted_lets) > 0 or len(binding_sites[id(node)].agg_lifted_lets) > 0)) state = PrintStackFrame(node, node.render_children(renderer), builder, @@ -384,6 +385,8 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: else: stack[bind_depth].visited.add(id(node)) + stack.pop() + # if any lets being inserted here, add node to registry of # binding sites if frame.lifted_lets or frame.agg_lifted_lets or frame.scan_lifted_lets: @@ -393,7 +396,6 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: scan_lifted_lets=frame.scan_lifted_lets, depth=len(stack)) - stack.pop() if not stack: break stack[-1].update_free_vars(frame) @@ -416,8 +418,10 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: self.memo[id(child)] = jref continue + child_frame = frame.make_child_frame(len(stack)) + if isinstance(child, ir.IR): - seen_in_scope = next((i for i in reversed(range(frame.min_value_binding_depth, + seen_in_scope = next((i for i in reversed(range(child_frame.min_value_binding_depth, len(stack))) if id(child) in stack[i].visited), None) @@ -427,16 +431,16 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: else: if frame.scan_scope: seen_in_scope = next( - (i for i in reversed(range(frame.min_binding_depth, - frame.min_value_binding_depth)) + (i for i in reversed(range(child_frame.min_binding_depth, + child_frame.min_value_binding_depth)) if id(child) in stack[i].scan_visited), None) if seen_in_scope is not None: lets = stack[seen_in_scope].scan_lifted_lets else: seen_in_scope = next( - (i for i in reversed(range(frame.min_binding_depth, - frame.min_value_binding_depth)) + (i for i in reversed(range(child_frame.min_binding_depth, + child_frame.min_value_binding_depth)) if id(child) in stack[i].agg_visited), None) if seen_in_scope is not None: @@ -464,23 +468,12 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: if child.name not in node.bindings(child_idx, default_value=0).keys(): (eval_c, _, _) = node.child_context_without_bindings( child_idx, frame.context) - try: - frame._free_vars[child.name] = eval_c[child.name] - except: - raise + frame._free_vars[child.name] = eval_c[child.name] continue - stack.append(frame.make_child_frame(len(stack))) + stack.append(child_frame) continue - root_free_vars = root_frame.free_vars() - assert(len(root_free_vars) == 0) - - binding_sites[id(root)] = BindingSite( - lifted_lets=frame.lifted_lets, - agg_lifted_lets=frame.agg_lifted_lets, - scan_lifted_lets=frame.scan_lifted_lets, - depth=0) return binding_sites # At top of main loop, we are considering the 'Renderable' 'node' and its @@ -519,15 +512,13 @@ def build_string(self, root, binding_sites): if id(root) in self.memo: return ''.join(self.memo[id(root)]) - stack = [PrintStackFrame.make(root, self, binding_sites, root_builder, context, 0, 0, False, 1)] + stack = [PrintStackFrame.make(root, self, binding_sites, root_builder, context, 0, 0, False, 0)] while True: frame = stack[-1] node = frame.node frame.child_idx += 1 child_idx = frame.child_idx - if isinstance(node, ir.ApplyAggOp) and node.agg_op == 'Count': - print('...') if child_idx >= len(frame.children): if frame.lift_to_frame is not None: @@ -569,38 +560,31 @@ def build_string(self, root, binding_sites): if node.renderable_new_block(child_idx): child_min_binding_depth = frame.depth child_min_value_binding_depth = frame.depth - if node.renderable_uses_agg_context(child_idx): + elif node.renderable_uses_agg_context(child_idx): child_min_value_binding_depth = frame.depth child_scan_scope = False - if node.renderable_uses_scan_context(child_idx): + elif node.renderable_uses_scan_context(child_idx): child_min_value_binding_depth = frame.depth child_scan_scope = True if isinstance(child, ir.BaseIR): child_depth += 1 - lift_to_frame = next( - (c for c in context - if c.depth >= child_min_value_binding_depth - and id(child) in c.lifted_lets), - None) - if lift_to_frame: - lift_type = 'value' - else: - lift_to_frame = next( - (c for c in context - if child_min_binding_depth <= c.depth < child_min_value_binding_depth - and id(child) in c.agg_lifted_lets), - None) - if lift_to_frame: - lift_type = 'agg' - else: - lift_to_frame = next( - (c for c in context - if child_min_binding_depth <= c.depth < child_min_value_binding_depth - and id(child) in c.scan_lifted_lets), - None) - if lift_to_frame: + for c in context: + if c.depth >= child_min_value_binding_depth and id(child) in c.lifted_lets: + lift_to_frame = c + lift_type = 'value' + break + if child_min_binding_depth <= c.depth < child_min_value_binding_depth: + if id(child) in c.scan_lifted_lets: + lift_to_frame = c lift_type = 'scan' + break + if id(child) in c.agg_lifted_lets: + lift_to_frame = c + lift_type = 'agg' + break + else: + lift_to_frame = None else: lift_to_frame = None diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index 68514f63c83..383fafab1f6 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -415,3 +415,21 @@ def test_agg_cse(self): ' ((ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))))' ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2)))))') self.assertEqual(expected, CSERenderer()(table_agg)) + + def test_init_op(self): + x = ir.I32(5) + sum = ir.ApplyBinaryPrimOp('+', x, x) + agg = ir.ApplyAggOp('CallStats', [sum], [sum], [sum]) + top = ir.ApplyBinaryPrimOp('+', sum, agg) + expected = ( + '(Let __cse_1 (I32 5)' + ' (AggLet __cse_4 False (I32 5)' + ' (ApplyBinaryPrimOp `+`' + ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' + ' (ApplyAggOp CallStats' + ' ((Let __cse_3 (I32 5)' + ' (ApplyBinaryPrimOp `+` (Ref __cse_3) (Ref __cse_3))))' + ' ((Let __cse_3 (I32 5)' + ' (ApplyBinaryPrimOp `+` (Ref __cse_3) (Ref __cse_3))))' + ' ((ApplyBinaryPrimOp `+` (Ref __cse_4) (Ref __cse_4)))))))') + self.assertEqual(expected, CSERenderer()(top)) From 8c9a83a1cedd44ec32aec665770f9a8e8ba61be9 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 5 Sep 2019 15:18:14 -0400 Subject: [PATCH 39/47] Squashed commit of the following: MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit commit 07e37a30a7d17a893bb314eb865f1506fd8f945c Author: patrick-schultz Date: Thu Sep 5 15:10:30 2019 -0400 fixes commit f6b74c5348443242479cc0b82db02ca9b1186e7c Merge: 6320d3048 baa5251a9 Author: patrick-schultz Date: Wed Sep 4 15:28:04 2019 -0400 Merge branch 'tim-agglet-fix' into python-cse commit 6320d30481de68fbd26f3454c844f36d95b57bb2 Author: patrick-schultz Date: Wed Sep 4 15:24:51 2019 -0400 put back old renderer commit 367819d89455fa5f0a129a92daa3a11716de984a Author: patrick-schultz Date: Wed Sep 4 11:00:55 2019 -0400 enable AggLets, other refactoring commit baa5251a91eaa696590659b26bf31b7baf27f68a Author: Tim Poterba Date: Fri Aug 30 07:20:42 2019 -0400 bump! commit e3be6636550f82fd78b6198ac0ffca9a2076d18b Author: Tim Poterba Date: Fri Aug 30 06:34:45 2019 -0400 fix let lifting bug commit 0e54c39718cbf617b15df34e99d3af8b1f0e8f10 Author: Tim Poterba Date: Thu Aug 29 10:24:15 2019 -0400 [hail] Fix performance problems in aggregators by binding arguments commit 2786aec2420eaad7c8abb5c4e52c003b511f8c1f Author: patrick-schultz Date: Wed Aug 28 14:10:39 2019 -0400 use namedtuple and __slots__ commit 92c1f6b9343a9a8034ef0b5e504383de1827b532 Author: patrick-schultz Date: Wed Aug 28 13:34:20 2019 -0400 refactoring commit 3842f38effa89f5386587f5b11dc836eefb098e2 Author: patrick-schultz Date: Wed Aug 28 11:53:02 2019 -0400 refactoring commit f64ac3ef9c2d3be77496bd2b4f96b9d4d54235ca Author: patrick-schultz Date: Mon Aug 26 14:08:21 2019 -0400 refactoring commit 636aa5e647f7a639652bea2af9c40a3e639ef9da Author: patrick-schultz Date: Mon Aug 26 10:51:52 2019 -0400 refactoring, documenting commit 3d3671f81eca5dd114dacc56df1fb3e4e9b15934 Author: patrick-schultz Date: Thu Aug 22 15:35:42 2019 -0400 refactoring commit 6aa0ba14c3a7c35794f10d900e62d176992a63a6 Author: patrick-schultz Date: Thu Aug 22 10:10:54 2019 -0400 refactoring commit 427d8c78c9bc6fb067682dac6058ee753a5bc984 Author: patrick-schultz Date: Wed Aug 14 12:50:30 2019 -0400 wip commit fb6e986dd49de59107d776798a7d2e8c7b8f0324 Author: patrick-schultz Date: Wed Aug 14 10:43:19 2019 -0400 bugfixes commit 1cccb0516018652e9896d9308cc0aae4a397ffc2 Author: patrick-schultz Date: Tue Aug 13 10:05:40 2019 -0400 refactoring, bugfixes commit d2df59240d3e85521b9049b35b8d19aa61085fce Author: patrick-schultz Date: Mon Aug 12 11:46:21 2019 -0400 refactoring commit 6ca6b238d99fb3c59671161b003ed97734c27b75 Author: patrick-schultz Date: Mon Aug 12 10:48:31 2019 -0400 linkedlist -> stack commit 50ff007079592bbb23faffd56ffb6904cde1ade9 Author: patrick-schultz Date: Fri Aug 9 16:18:55 2019 -0400 recursion -> loop commit 55c32baaf379890d1c57d55ba865852293044515 Author: patrick-schultz Date: Fri Aug 9 15:34:17 2019 -0400 defunctionalize commit 3db71dbbcb2afd30a1565082a130240d80b069d2 Author: patrick-schultz Date: Thu Aug 8 15:05:37 2019 -0400 wip commit d196b2e5bf2892c8e6ab28a9a0223eb283b0630f Author: patrick-schultz Date: Thu Aug 8 14:34:15 2019 -0400 inline print commit 16eb524226f9aa3a6ac7b3011a2341d2ec75a882 Author: patrick-schultz Date: Thu Aug 8 11:36:28 2019 -0400 inline print_renderable commit 8460050d5fe6925e2c579e31e68abc4e2f25992d Author: patrick-schultz Date: Thu Aug 8 11:23:40 2019 -0400 factor out args into state object commit f10306ea0ac6441e8364dcde45edeaea5791ba1e Author: patrick-schultz Date: Thu Aug 8 10:50:38 2019 -0400 CPS commit 654f219200452e41487904c2a11dad9c17b2fb85 Author: patrick-schultz Date: Wed Aug 7 13:04:47 2019 -0400 second pass: loop -> recursion commit ed6ce8104248f168cb22557c313b338796a491f8 Author: patrick-schultz Date: Wed Aug 7 09:56:43 2019 -0400 more refactoring commit e12b7e552e07afc6a9cf2877f1bf04cd3ac86422 Author: patrick-schultz Date: Wed Aug 7 08:50:49 2019 -0400 some refactoring commit e0c76adf18cd93b4dddbdaeaa1309abf109c75fa Author: patrick-schultz Date: Fri Aug 2 15:11:21 2019 -0400 cleaning up commit 144a1bd5976090cf40b39459946c63ee9a34b3f2 Author: patrick-schultz Date: Thu Aug 1 16:40:53 2019 -0400 recursion -> loop commit 28f23e36e14014314dd0b28dc198ee30200acb15 Author: patrick-schultz Date: Thu Aug 1 16:28:47 2019 -0400 defunctionalization commit 6619b17264f2aedd3ed840ef33f595c39b5cdbe5 Author: patrick-schultz Date: Thu Aug 1 15:12:36 2019 -0400 inline recur commit c2e9f9268cfa1333d2fa65e1db3b9c409b9a73e9 Author: patrick-schultz Date: Thu Aug 1 15:02:48 2019 -0400 CPS transform commit 9e788d3c0a2b5c901969aabc3ffd0b3b6ccf0bfe Author: patrick-schultz Date: Thu Aug 1 14:29:51 2019 -0400 loop -> recursion commit 3c392e6c655652556e178c203b81d0af04459b73 Author: patrick-schultz Date: Thu Aug 1 14:04:06 2019 -0400 finally fully working? commit ed133da523882a25a01940019d7cdfab81809c46 Author: patrick-schultz Date: Wed Jul 31 14:40:18 2019 -0400 some refactoring commit 689ca407c3880c4e321129eeb193a0716171fcec Author: patrick-schultz Date: Mon Jul 29 14:57:16 2019 -0400 works with agg bindings commit 62e56806caf2803e1ce48e49934ff665369e3213 Author: patrick-schultz Date: Wed Jul 24 11:51:06 2019 -0400 New CSE IR renrerer in Python Doesn’t handle agg environments correctly. --- hail/python/hail/expr/matrix_type.py | 49 ++- hail/python/hail/expr/table_type.py | 15 +- hail/python/hail/ir/__init__.py | 1 + hail/python/hail/ir/base_ir.py | 91 ++++- hail/python/hail/ir/blockmatrix_ir.py | 24 ++ hail/python/hail/ir/ir.py | 344 ++++++++++++++++-- hail/python/hail/ir/matrix_ir.py | 128 +++++++ hail/python/hail/ir/table_ir.py | 68 ++++ .../scala/is/hail/expr/ir/LowerMatrixIR.scala | 7 +- 9 files changed, 680 insertions(+), 47 deletions(-) diff --git a/hail/python/hail/expr/matrix_type.py b/hail/python/hail/expr/matrix_type.py index d63010489d5..bb812010d07 100644 --- a/hail/python/hail/expr/matrix_type.py +++ b/hail/python/hail/expr/matrix_type.py @@ -122,22 +122,39 @@ def _rename(self, global_map, col_map, row_map, entry_map): [row_map.get(k, k) for k in self.row_key], self.entry_type._rename(entry_map)) - def global_env(self): - return {'global': self.global_type} - - def row_env(self): - return {'global': self.global_type, - 'va': self.row_type} - - def col_env(self): - return {'global': self.global_type, - 'sa': self.col_type} - - def entry_env(self): - return {'global': self.global_type, - 'va': self.row_type, - 'sa': self.col_type, - 'g': self.entry_type} + def global_env(self, default_value=None): + if default_value is None: + return {'global': self.global_type} + else: + return {'global': default_value} + + def row_env(self, default_value=None): + if default_value is None: + return {'global': self.global_type, + 'va': self.row_type} + else: + return {'global': default_value, + 'va': default_value} + + def col_env(self, default_value=None): + if default_value is None: + return {'global': self.global_type, + 'sa': self.col_type} + else: + return {'global': default_value, + 'sa': default_value} + + def entry_env(self, default_value=None): + if default_value is None: + return {'global': self.global_type, + 'va': self.row_type, + 'sa': self.col_type, + 'g': self.entry_type} + else: + return {'global': default_value, + 'va': default_value, + 'sa': default_value, + 'g': default_value} import pprint diff --git a/hail/python/hail/expr/table_type.py b/hail/python/hail/expr/table_type.py index 705dea5120b..06c6fe1a041 100644 --- a/hail/python/hail/expr/table_type.py +++ b/hail/python/hail/expr/table_type.py @@ -80,12 +80,17 @@ def _rename(self, global_map, row_map): self.row_type._rename(row_map), [row_map.get(k, k) for k in self.row_key]) - def row_env(self): - return {'global': self.global_type, - 'row': self.row_type} + def row_env(self, default_value=None): + if default_value is None: + return {'global': self.global_type, 'row': self.row_type} + else: + return {'global': default_value, 'row': default_value} - def global_env(self): - return {'global': self.global_type} + def global_env(self, default_value=None): + if default_value is None: + return {'global': self.global_type} + else: + return {'global': default_value} import pprint diff --git a/hail/python/hail/ir/__init__.py b/hail/python/hail/ir/__init__.py index 614d25abe21..38f22c98770 100644 --- a/hail/python/hail/ir/__init__.py +++ b/hail/python/hail/ir/__init__.py @@ -1,3 +1,4 @@ +from .base_ir import * from .ir import * from .register_functions import * from .register_aggregators import * diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index e116ffe40e0..2fb42c3c324 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -1,11 +1,24 @@ import abc -from typing import List +from typing import List, Tuple from hail.utils.java import Env +from hail.expr.types import HailType from .renderer import Renderer, Renderable, RenderableStr +def _env_bind(env, bindings): + if bindings: + if env: + res = env.copy() + res.update(bindings) + return res + else: + return dict(bindings) + else: + return env + + class BaseIR(Renderable): def __init__(self, *children): super().__init__() @@ -44,7 +57,8 @@ def head_str(self): def parse(self, code, ref_map, ir_map): return - @abc.abstractproperty + @property + @abc.abstractmethod def typ(self): return @@ -71,6 +85,59 @@ def _eq(self, other): def __hash__(self): return 31 + hash(str(self)) + @abc.abstractmethod + def new_block(self, i: int) -> bool: + ... + + @staticmethod + def is_effectful() -> bool: + return False + + def binds(self, i: int) -> bool: + return False + + def bindings(self, i: int, default_value=None): + """Compute variables bound in child 'i'. + + Returns + ------- + dict + mapping from bound variables to 'default_value', if provided, + otherwise to their types + """ + return {} + + def agg_bindings(self, i: int, default_value=None): + return {} + + def scan_bindings(self, i: int, default_value=None): + return {} + + def uses_agg_context(self, i: int) -> bool: + return False + + def uses_scan_context(self, i: int) -> bool: + return False + + def child_context_without_bindings(self, i: int, parent_context): + (eval_c, agg_c, scan_c) = parent_context + if self.uses_agg_context(i): + return (agg_c, None, None) + elif self.uses_scan_context(i): + return (scan_c, None, None) + else: + return parent_context + + def child_context(self, i: int, parent_context, default_value=None): + base = self.child_context_without_bindings(i, parent_context) + if not self.binds(i): + return base + eval_b = self.bindings(i, default_value) + agg_b = self.agg_bindings(i, default_value) + scan_b = self.scan_bindings(i, default_value) + (eval_c, agg_c, scan_c) = base + return _env_bind(eval_c, eval_b), _env_bind(agg_c, agg_b), _env_bind(scan_c, scan_b) + class IR(BaseIR): def __init__(self, *children): @@ -117,6 +184,9 @@ def typ(self): assert self._type is not None, self return self._type + def new_block(self, i: int) -> bool: + return False + @abc.abstractmethod def _compute_type(self, env, agg_env): raise NotImplementedError(self) @@ -143,9 +213,15 @@ def typ(self): assert self._type is not None, self return self._type + def new_block(self, i: int) -> bool: + return True + def parse(self, code, ref_map={}, ir_map={}): return Env.hail().expr.ir.IRParser.parse_table_ir(code, ref_map, ir_map) + global_env = {'global'} + row_env = {'global', 'row'} + class MatrixIR(BaseIR): def __init__(self, *children): @@ -162,9 +238,17 @@ def typ(self): assert self._type is not None, self return self._type + def new_block(self, i: int) -> bool: + return True + def parse(self, code, ref_map={}, ir_map={}): return Env.hail().expr.ir.IRParser.parse_matrix_ir(code, ref_map, ir_map) + global_env = {'global'} + row_env = {'global', 'va'} + col_env = {'global', 'sa'} + entry_env = {'global', 'sa', 'va', 'g'} + class BlockMatrixIR(BaseIR): def __init__(self, *children): @@ -181,6 +265,9 @@ def typ(self): assert self._type is not None, self return self._type + def new_block(self, i: int) -> bool: + return True + def parse(self, code, ref_map={}, ir_map={}): return Env.hail().expr.ir.IRParser.parse_blockmatrix_ir(code, ref_map, ir_map) diff --git a/hail/python/hail/ir/blockmatrix_ir.py b/hail/python/hail/ir/blockmatrix_ir.py index b6db226171f..704d29e5b9e 100644 --- a/hail/python/hail/ir/blockmatrix_ir.py +++ b/hail/python/hail/ir/blockmatrix_ir.py @@ -37,6 +37,16 @@ def __init__(self, child, f): def _compute_type(self): self._type = self.child.typ + def bindings(self, i: int, default_value=None): + if i == 1: + value = self.child.typ.element_type if default_value is None else default_value + return {'element': value} + else: + return {} + + def binds(self, i): + return {'element'} if i == 1 else {} + class BlockMatrixMap2(BlockMatrixIR): @typecheck_method(left=BlockMatrixIR, right=BlockMatrixIR, f=IR) @@ -50,6 +60,20 @@ def _compute_type(self): self.right.typ # Force self._type = self.left.typ + def bindings(self, i: int, default_value=None): + if i == 2: + if default_value is None: + l_value = self.left.typ.element_type + r_value = self.right.typ.element_type + else: + (l_value, r_value) = (default_value, default_value) + return {'l': l_value, 'r': r_value} + else: + return {} + + def binds(self, i): + return {'l', 'r'} if i == 2 else {} + class BlockMatrixDot(BlockMatrixIR): @typecheck_method(left=BlockMatrixIR, right=BlockMatrixIR) diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 1a70ce33bf2..f866b80cd3d 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -9,15 +9,11 @@ from hail.typecheck import * from hail.utils.misc import escape_str, dump_json, parsable_strings, escape_id from .base_ir import * +from .base_ir import _env_bind from .matrix_writer import MatrixWriter, MatrixNativeMultiWriter from .renderer import Renderer, Renderable, RenderableStr, ParensRenderer from .table_writer import TableWriter -def _env_bind(env, k, v): - env = env.copy() - env[k] = v - return env - class I32(IR): @typecheck_method(x=int) @@ -236,6 +232,9 @@ def _compute_type(self, env, agg_env): assert (self.cnsq.typ == self.altr.typ) self._type = self.cnsq.typ + def new_block(self, i): + return i == 1 or i == 2 + class Coalesce(IR): @typecheck_method(values=IR) @@ -280,9 +279,22 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self.value._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, self.name, self.value._type), agg_env) + self.body._compute_type(_env_bind(env, self.bindings(1)), agg_env) self._type = self.body._type + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.value._type + else: + value = default_value + return {self.name: value} + else: + return {} + + def binds(self, i): + return i == 1 + class AggLet(IR): @typecheck_method(name=str, value=IR, body=IR, is_scan=bool) @@ -309,9 +321,38 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self.value._compute_type(agg_env, None) - self.body._compute_type(env, _env_bind(agg_env, self.name, self.value._type)) + self.body._compute_type(env, _env_bind(agg_env, [(self.name, self.value._type)])) self._type = self.body._type + def agg_bindings(self, i, default_value=None): + if not self.is_scan and i == 1: + if default_value is None: + value = self.value._type + else: + value = default_value + return {self.name: value} + else: + return {} + + def scan_bindings(self, i, default_value=None): + if self.is_scan and i == 1: + if default_value is None: + value = self.value._type + else: + value = default_value + return {self.name: value} + else: + return {} + + def binds(self, i): + return i == 1 + + def uses_agg_context(self, i: int) -> bool: + return not self.is_scan and i == 0 + + def uses_scan_context(self, i: int) -> bool: + return self.is_scan and i == 0 + class Ref(IR): @typecheck_method(name=str) @@ -576,9 +617,22 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.nd._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, self.name, self.nd.typ.element_type), agg_env) + self.body._compute_type(_env_bind(env, [(self.name, self.nd.typ.element_type)]), agg_env) self._type = tndarray(self.body.typ, self.nd.typ.ndim) + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.nd.typ.element_type + else: + value = default_value + return {self.name: value} + else: + return {} + + def binds(self, i): + return i == 1 + class NDArrayRef(IR): @typecheck_method(nd=IR, idxs=sequenceof(IR)) @@ -697,6 +751,10 @@ def _compute_type(self, env, agg_env): self.path._compute_type(env, agg_env) self._type = tvoid + @staticmethod + def is_effectful() -> bool: + return True + class ArraySort(IR): @typecheck_method(a=IR, l_name=str, r_name=str, compare=IR) @@ -725,6 +783,19 @@ def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) self._type = self.a.typ + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.a.typ.element_type + else: + value = default_value + return {self.l_name: value, self.r_name: value} + else: + return {} + + def binds(self, i): + return i == 1 + class ToSet(IR): @typecheck_method(a=IR) @@ -832,9 +903,22 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, self.name, self.a.typ.element_type), agg_env) + self.body._compute_type(_env_bind(env, [(self.name, self.a.typ.element_type)]), agg_env) self._type = tarray(self.body.typ) + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.a.typ.element_type + else: + value = default_value + return {self.name: value} + else: + return {} + + def binds(self, i): + return i == 1 + class ArrayFilter(IR): @typecheck_method(a=IR, name=str, body=IR) @@ -860,9 +944,22 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, self.name, self.a.typ.element_type), agg_env) + self.body._compute_type(_env_bind(env, [(self.name, self.a.typ.element_type)]), agg_env) self._type = self.a.typ + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.a.typ.element_type + else: + value = default_value + return {self.name: value} + else: + return {} + + def binds(self, i): + return i == 1 + class ArrayFlatMap(IR): @typecheck_method(a=IR, name=str, body=IR) @@ -888,9 +985,21 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, self.name, self.a.typ.element_type), agg_env) + self.body._compute_type(_env_bind(env, [(self.name, self.a.typ.element_type)]), agg_env) self._type = tarray(self.body.typ.element_type) + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.a.typ.element_type + else: + value = default_value + return {self.name: value} + return {} + + def binds(self, i): + return i == 1 + class ArrayFold(IR): @typecheck_method(a=IR, zero=IR, accum_name=str, value_name=str, body=IR) @@ -921,12 +1030,23 @@ def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) self.zero._compute_type(env, agg_env) self.body._compute_type( - _env_bind( - _env_bind(env, self.value_name, self.a.typ.element_type), - self.accum_name, self.zero.typ), + _env_bind(env, [(self.value_name, self.a.typ.element_type), + (self.accum_name, self.zero.typ)]), agg_env) self._type = self.zero.typ + def bindings(self, i, default_value=None): + if i == 2: + if default_value is None: + return {self.accum_name: self.zero.typ, self.value_name: self.a.typ.element_type} + else: + return {self.accum_name: default_value, self.value_name: default_value} + else: + return {} + + def binds(self, i): + return i == 2 + class ArrayScan(IR): @typecheck_method(a=IR, zero=IR, accum_name=str, value_name=str, body=IR) @@ -957,12 +1077,23 @@ def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) self.zero._compute_type(env, agg_env) self.body._compute_type( - _env_bind( - _env_bind(env, self.value_name, self.a.typ.element_type), - self.accum_name, self.zero.typ), + _env_bind(env, [(self.value_name, self.a.typ.element_type), + (self.accum_name, self.zero.typ)]), agg_env) self._type = tarray(self.body.typ) + def bindings(self, i, default_value=None): + if i == 2: + if default_value is None: + return {self.accum_name: self.zero.typ, self.value_name: self.a.typ.element_type} + else: + return {self.accum_name: default_value, self.value_name: default_value} + else: + return {} + + def binds(self, i): + return i == 2 + class ArrayLeftJoinDistinct(IR): @typecheck_method(left=IR, right=IR, l_name=str, r_name=str, compare=IR, join=IR) @@ -990,6 +1121,20 @@ def _eq(self, other): def bound_variables(self): return {self.l_name, self.r_name} | super().bound_variables + def bindings(self, i, default_value=None): + if i == 2 or i == 3: + if default_value is None: + return {self.l_name: self.left.typ.element_type, + self.r_name: self.right.typ.element_type} + else: + return {self.l_name: default_value, + self.r_name: default_value} + else: + return {} + + def binds(self, i): + return i == 2 or i == 3 + class ArrayFor(IR): @typecheck_method(a=IR, value_name=str, body=IR) @@ -1015,9 +1160,22 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, self.value_name, self.a.typ.element_type), agg_env) + self.body._compute_type(_env_bind(env, [(self.value_name, self.a.typ.element_type)]), agg_env) self._type = tvoid + def bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.a.typ.element_type + else: + value = default_value + return {self.value_name: value} + else: + return {} + + def binds(self, i): + return i == 1 + class AggFilter(IR): @typecheck_method(cond=IR, agg_ir=IR, is_scan=bool) @@ -1042,6 +1200,12 @@ def _compute_type(self, env, agg_env): self.agg_ir._compute_type(env, agg_env) self._type = self.agg_ir.typ + def uses_agg_context(self, i: int): + return i == 0 and not self.is_scan + + def uses_scan_context(self, i: int): + return i == 0 and self.is_scan + class AggExplode(IR): @typecheck_method(array=IR, name=str, agg_body=IR, is_scan=bool) @@ -1068,9 +1232,32 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.array._compute_type(agg_env, None) - self.agg_body._compute_type(env, _env_bind(agg_env, self.name, self.array.typ.element_type)) + new_agg = _env_bind(agg_env, [(self.name, self.array.typ.element_type)]) + self.agg_body._compute_type(env, new_agg) self._type = self.agg_body.typ + def agg_bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.array.typ.element_type + else: + value = default_value + return {self.name: value} + else: + return {} + + def scan_bindings(self, i, default_value=None): + return self.agg_bindings(i, default_value) + + def binds(self, i): + return i == 1 + + def uses_agg_context(self, i: int): + return i == 0 and not self.is_scan + + def uses_scan_context(self, i: int): + return i == 0 and self.is_scan + class AggGroupBy(IR): @typecheck_method(key=IR, agg_ir=IR, is_scan=bool) @@ -1095,6 +1282,12 @@ def _compute_type(self, env, agg_env): self.agg_ir._compute_type(env, agg_env) self._type = tdict(self.key.typ, self.agg_ir.typ) + def uses_agg_context(self, i: int): + return i == 0 and not self.is_scan + + def uses_scan_context(self, i: int): + return i == 0 and self.is_scan + class AggArrayPerElement(IR): @typecheck_method(array=IR, element_name=str, index_name=str, agg_ir=IR, is_scan=bool) @@ -1118,14 +1311,43 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self.array._compute_type(agg_env, None) - self.agg_ir._compute_type(_env_bind(env, self.index_name, tint32), - _env_bind(agg_env, self.element_name, self.array.typ.element_type)) + self.agg_ir._compute_type(_env_bind(env, [(self.index_name, tint32)]), + _env_bind(agg_env, [(self.element_name, self.array.typ.element_type)])) self._type = tarray(self.agg_ir.typ) @property def bound_variables(self): return {self.element_name, self.index_name} | super().bound_variables + def uses_agg_context(self, i: int): + return i == 0 and not self.is_scan + + def uses_scan_context(self, i: int): + return i == 0 and self.is_scan + + def bindings(self, i, default_value=None): + if i == 1: + value = tint32 if default_value is None else default_value + return {self.index_name: value} + else: + return {} + + def agg_bindings(self, i, default_value=None): + if i == 1: + if default_value is None: + value = self.array.typ.element_type + else: + value = default_value + return {self.element_name: value} + else: + return {} + + def scan_bindings(self, i, default_value=None): + return self.agg_bindings(i, default_value) + + def binds(self, i): + return i == 1 + def _register(registry, name, f): registry[name].append(f) @@ -1183,7 +1405,11 @@ def copy(self, *args): return new_instance(self.agg_op, constr_args, init_op_args if len(init_op_args) != 0 else None, seq_op_args) def head_str(self): - return f' {self.agg_op} ' + return f'{self.agg_op}' + + # Overloaded to add space after 'agg_op' even if there are no children. + def render_head(self, r): + return f'({self._ir_name()} {self.agg_op} ' def render_children(self, r): return [ @@ -1225,6 +1451,10 @@ def _compute_type(self, env, agg_env): [a.typ for a in self.init_op_args] if self.init_op_args else None, [a.typ for a in self.seq_op_args]) + def new_block(self, i: int) -> bool: + n = len(self.constructor_args) + return self.init_op_args and i < n + len(self.init_op_args) + class ApplyAggOp(BaseApplyAggOp): @typecheck_method(agg_op=str, @@ -1234,6 +1464,9 @@ class ApplyAggOp(BaseApplyAggOp): def __init__(self, agg_op, constructor_args, init_op_args, seq_op_args): super().__init__(agg_op, constructor_args, init_op_args, seq_op_args) + def uses_agg_context(self, i: int): + return i >= len(self.constructor_args) + (len(self.init_op_args) if self.init_op_args else 0) + class ApplyScanOp(BaseApplyAggOp): @typecheck_method(agg_op=str, @@ -1243,6 +1476,9 @@ class ApplyScanOp(BaseApplyAggOp): def __init__(self, agg_op, constructor_args, init_op_args, seq_op_args): super().__init__(agg_op, constructor_args, init_op_args, seq_op_args) + def uses_scan_context(self, i: int): + return i >= len(self.constructor_args) + (len(self.init_op_args) if self.init_op_args else 0) + class Begin(IR): @typecheck_method(xs=sequenceof(IR)) @@ -1500,6 +1736,10 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self._type = self._typ + @staticmethod + def is_effectful() -> bool: + return True + _function_registry = defaultdict(list) _seeded_function_registry = defaultdict(list) @@ -1610,6 +1850,10 @@ def _compute_type(self, env, agg_env): self._type = self.return_type + @staticmethod + def is_effectful() -> bool: + return True + class Uniroot(IR): @typecheck_method(argname=str, function=IR, min=IR, max=IR) @@ -1635,11 +1879,21 @@ def _eq(self, other): return other.argname == self.argname def _compute_type(self, env, agg_env): - self.function._compute_type(_env_bind(env, self.argname, tfloat64), agg_env) + self.function._compute_type(_env_bind(env, [(self.argname, tfloat64)]), agg_env) self.min._compute_type(env, agg_env) self.max._compute_type(env, agg_env) self._type = tfloat64 + def bindings(self, i, default_value=None): + if i == 0: + value = tfloat64 if default_value is None else default_value + return {self.argname: value} + else: + return {} + + def binds(self, i): + return i == 0 + class TableCount(IR): @typecheck_method(child=TableIR) @@ -1702,6 +1956,18 @@ def _compute_type(self, env, agg_env): self.query._compute_type(self.child.typ.global_env(), self.child.typ.row_env()) self._type = self.query.typ + def new_block(self, i: int): + return i == 1 + + def bindings(self, i, default_value=None): + return self.child.typ.global_env(default_value) if i == 1 else {} + + def agg_bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class MatrixAggregate(IR): @typecheck_method(child=MatrixIR, query=IR) @@ -1718,6 +1984,18 @@ def _compute_type(self, env, agg_env): self.query._compute_type(self.child.typ.global_env(), self.child.typ.entry_env()) self._type = self.query.typ + def new_block(self, i: int): + return i == 1 + + def bindings(self, i, default_value=None): + return self.child.typ.global_env(default_value) if i == 1 else {} + + def agg_bindings(self, i, default_value=None): + return self.child.typ.entry_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class TableWrite(IR): @typecheck_method(child=TableIR, writer=TableWriter) @@ -1740,6 +2018,10 @@ def _compute_type(self, env, agg_env): self.child._compute_type() self._type = tvoid + @staticmethod + def is_effectful() -> bool: + return True + class MatrixWrite(IR): @typecheck_method(child=MatrixIR, matrix_writer=MatrixWriter) def __init__(self, child, matrix_writer): @@ -1761,6 +2043,10 @@ def _compute_type(self, env, agg_env): self.child._compute_type() self._type = tvoid + @staticmethod + def is_effectful() -> bool: + return True + class MatrixMultiWrite(IR): @typecheck_method(children=sequenceof(MatrixIR), writer=MatrixNativeMultiWriter) @@ -1782,6 +2068,10 @@ def _compute_type(self, env, agg_env): x._compute_type() self._type = tvoid + @staticmethod + def is_effectful() -> bool: + return True + class BlockMatrixWrite(IR): @typecheck_method(child=BlockMatrixIR, writer=BlockMatrixWriter) @@ -1803,6 +2093,10 @@ def _compute_type(self, env, agg_env): self.child._compute_type() self._type = tvoid + @staticmethod + def is_effectful() -> bool: + return True + class BlockMatrixMultiWrite(IR): @typecheck_method(block_matrices=sequenceof(BlockMatrixIR), writer=BlockMatrixMultiWriter) @@ -1825,6 +2119,10 @@ def _compute_type(self, env, agg_env): x._compute_type() self._type = tvoid + @staticmethod + def is_effectful() -> bool: + return True + class TableToValueApply(IR): def __init__(self, child, config): diff --git a/hail/python/hail/ir/matrix_ir.py b/hail/python/hail/ir/matrix_ir.py index a1fa0e47d7d..7a791550d00 100644 --- a/hail/python/hail/ir/matrix_ir.py +++ b/hail/python/hail/ir/matrix_ir.py @@ -10,6 +10,10 @@ def __init__(self, child, entry_expr, row_expr): self.entry_expr = entry_expr self.row_expr = row_expr + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): child_typ = self.child.typ self.entry_expr._compute_type(child_typ.col_env(), child_typ.entry_env()) @@ -22,6 +26,25 @@ def _compute_type(self): child_typ.row_key, self.entry_expr.typ) + def bindings(self, i, default_value=None): + if i == 1: + return self.child.typ.col_env(default_value) + elif i == 2: + return self.child.typ.global_env(default_value) + else: + return {} + + def agg_bindings(self, i, default_value=None): + if i == 1: + return self.child.typ.entry_env(default_value) + elif i == 2: + return self.child.typ.row_env(default_value) + else: + return {} + + def binds(self, i): + return i == 1 or i == 2 + class MatrixRead(MatrixIR): def __init__(self, reader, drop_cols=False, drop_rows=False): @@ -46,10 +69,20 @@ def __init__(self, child, pred): self.child = child self.pred = pred + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): self.pred._compute_type(self.child.typ.row_env(), None) self._type = self.child.typ + def bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class MatrixChooseCols(MatrixIR): def __init__(self, child, old_indices): @@ -74,6 +107,10 @@ def __init__(self, child, new_col, new_key): self.new_col = new_col self.new_key = new_key + @staticmethod + def new_block(i): + return i > 0 + def head_str(self): return '(' + ' '.join(f'"{escape_str(f)}"' for f in self.new_key) + ')' if self.new_key is not None else 'None' @@ -91,6 +128,18 @@ def _compute_type(self): child_typ.row_key, child_typ.entry_type) + def bindings(self, i, default_value=None): + return self.child.typ.col_env(default_value) if i == 1 else {} + + def agg_bindings(self, i, default_value=None): + return self.child.typ.entry_env(default_value) if i == 1 else {} + + def scan_bindings(self, i, default_value=None): + return self.child.typ.col_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class MatrixUnionCols(MatrixIR): def __init__(self, left, right): @@ -109,6 +158,10 @@ def __init__(self, child, new_entry): self.child = child self.new_entry = new_entry + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): child_typ = self.child.typ self.new_entry._compute_type(child_typ.entry_env(), None) @@ -120,6 +173,12 @@ def _compute_type(self): child_typ.row_key, self.new_entry.typ) + def bindings(self, i, default_value=None): + return self.child.typ.entry_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class MatrixFilterEntries(MatrixIR): def __init__(self, child, pred): @@ -127,10 +186,20 @@ def __init__(self, child, pred): self.child = child self.pred = pred + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): self.pred._compute_type(self.child.typ.entry_env(), None) self._type = self.child.typ + def bindings(self, i, default_value=None): + return self.child.typ.entry_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class MatrixKeyRowsBy(MatrixIR): def __init__(self, child, keys, is_sorted=False): @@ -164,6 +233,10 @@ def __init__(self, child, new_row): self.child = child self.new_row = new_row + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): child_typ = self.child.typ self.new_row._compute_type(child_typ.row_env(), child_typ.entry_env()) @@ -175,6 +248,18 @@ def _compute_type(self): child_typ.row_key, child_typ.entry_type) + def bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} + + def agg_bindings(self, i, default_value=None): + return self.child.typ.entry_env(default_value) if i == 1 else {} + + def scan_bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class MatrixMapGlobals(MatrixIR): def __init__(self, child, new_global): @@ -182,6 +267,10 @@ def __init__(self, child, new_global): self.child = child self.new_global = new_global + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): child_typ = self.child.typ self.new_global._compute_type(child_typ.global_env(), None) @@ -193,6 +282,12 @@ def _compute_type(self): child_typ.row_key, child_typ.entry_type) + def bindings(self, i, default_value=None): + return self.child.typ.global_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class MatrixFilterCols(MatrixIR): def __init__(self, child, pred): @@ -200,10 +295,20 @@ def __init__(self, child, pred): self.child = child self.pred = pred + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): self.pred._compute_type(self.child.typ.col_env(), None) self._type = self.child.typ + def bindings(self, i, default_value=None): + return self.child.typ.col_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class MatrixCollectColsByKey(MatrixIR): def __init__(self, child): @@ -229,6 +334,10 @@ def __init__(self, child, entry_expr, col_expr): self.entry_expr = entry_expr self.col_expr = col_expr + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): child_typ = self.child.typ self.entry_expr._compute_type(child_typ.row_env(), child_typ.entry_env()) @@ -241,6 +350,25 @@ def _compute_type(self): child_typ.row_key, self.entry_expr.typ) + def bindings(self, i, default_value=None): + if i == 1: + return self.child.typ.row_env(default_value) + elif i == 2: + return self.child.typ.global_env(default_value) + else: + return {} + + def agg_bindings(self, i, default_value=None): + if i == 1: + return self.child.typ.entry_env(default_value) + elif i == 2: + return self.child.typ.col_env(default_value) + else: + return {} + + def binds(self, i): + return i == 1 or i == 2 + class MatrixExplodeRows(MatrixIR): def __init__(self, child, path): diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index 089d226c869..f9879cf9b46 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -125,12 +125,26 @@ def __init__(self, child, new_globals): self.child = child self.new_globals = new_globals + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): self.new_globals._compute_type(self.child.typ.global_env(), None) self._type = hl.ttable(self.new_globals.typ, self.child.typ.row_type, self.child.typ.row_key) + @staticmethod + def new_block(i): + return i == 1 + + def bindings(self, i, default_value=None): + return self.child.typ.global_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class TableExplode(TableIR): def __init__(self, child, path): @@ -176,6 +190,10 @@ def __init__(self, child, new_row): self.child = child self.new_row = new_row + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): # agg_env for scans self.new_row._compute_type(self.child.typ.row_env(), self.child.typ.row_env()) @@ -184,6 +202,15 @@ def _compute_type(self): self.new_row.typ, self.child.typ.row_key) + def bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} + + def scan_bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class TableRead(TableIR): def __init__(self, reader, drop_rows=False): @@ -241,10 +268,20 @@ def __init__(self, child, pred): self.child = child self.pred = pred + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): self.pred._compute_type(self.child.typ.row_env(), None) self._type = self.child.typ + def bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class TableKeyByAndAggregate(TableIR): def __init__(self, child, expr, new_key, n_partitions, buffer_size): @@ -255,6 +292,10 @@ def __init__(self, child, expr, new_key, n_partitions, buffer_size): self.n_partitions = n_partitions self.buffer_size = buffer_size + @staticmethod + def new_block(i): + return i > 0 + def head_str(self): return f'{self.n_partitions} {self.buffer_size}' @@ -268,6 +309,20 @@ def _compute_type(self): self.new_key.typ._concat(self.expr.typ), list(self.new_key.typ)) + def bindings(self, i, default_value=None): + if i == 1: + return self.child.typ.global_env(default_value) + elif i == 2: + return self.child.typ.row_env(default_value) + else: + return {} + + def agg_bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 or i == 2 + class TableAggregateByKey(TableIR): def __init__(self, child, expr): @@ -275,6 +330,10 @@ def __init__(self, child, expr): self.child = child self.expr = expr + @staticmethod + def new_block(i): + return i > 0 + def _compute_type(self): child_typ = self.child.typ self.expr._compute_type(child_typ.global_env(), child_typ.row_env()) @@ -282,6 +341,15 @@ def _compute_type(self): child_typ.key_type._concat(self.expr.typ), child_typ.row_key) + def bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} + + def agg_bindings(self, i, default_value=None): + return self.child.typ.row_env(default_value) if i == 1 else {} + + def binds(self, i): + return i == 1 + class MatrixColsTable(TableIR): def __init__(self, child): diff --git a/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala index 2b67c8883d6..79b39c83837 100644 --- a/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala @@ -66,6 +66,11 @@ object LowerMatrixIR { BindingEnv(e, agg = Some(e), scan = Some(e)) } + def matrixGlobalSubstEnv(child: MatrixIR): BindingEnv[IRProxy] = { + val e = Env[IRProxy]("global" -> 'global.selectFields(child.typ.globalType.fieldNames: _*)) + BindingEnv(e, agg = Some(e), scan = Some(e)) + } + def matrixSubstEnvIR(child: MatrixIR, lowered: TableIR): BindingEnv[IR] = { val e = Env[IR]("global" -> SelectFields(Ref("global", lowered.typ.globalType), child.typ.globalType.fieldNames), "va" -> SelectFields(Ref("row", lowered.typ.rowType), child.typ.rowType.fieldNames)) @@ -128,7 +133,7 @@ object LowerMatrixIR { irRange(0, 'global (colsField).len) .filter('i ~> (let(sa = 'global (colsField)('i)) - in subst(pred, matrixSubstEnv(child)))))) + in subst(pred, matrixGlobalSubstEnv(child)))))) .mapRows('row.insertFields(entriesField -> 'global ('newColIdx).map('i ~> 'row (entriesField)('i)))) .mapGlobals('global .insertFields(colsField -> From eb1225f86b91b0ab33fd480631b3e4263e854182 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 5 Sep 2019 15:35:33 -0400 Subject: [PATCH 40/47] make _compute_type use bindings --- hail/python/hail/ir/ir.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index f866b80cd3d..736d7ad80ab 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -321,7 +321,7 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self.value._compute_type(agg_env, None) - self.body._compute_type(env, _env_bind(agg_env, [(self.name, self.value._type)])) + self.body._compute_type(env, _env_bind(agg_env, self.agg_bindings(1))) self._type = self.body._type def agg_bindings(self, i, default_value=None): @@ -617,7 +617,7 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.nd._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, [(self.name, self.nd.typ.element_type)]), agg_env) + self.body._compute_type(_env_bind(env, self.bindings(1)), agg_env) self._type = tndarray(self.body.typ, self.nd.typ.ndim) def bindings(self, i, default_value=None): @@ -903,7 +903,7 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, [(self.name, self.a.typ.element_type)]), agg_env) + self.body._compute_type(_env_bind(env, self.bindings(1)), agg_env) self._type = tarray(self.body.typ) def bindings(self, i, default_value=None): @@ -944,7 +944,7 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, [(self.name, self.a.typ.element_type)]), agg_env) + self.body._compute_type(_env_bind(env, self.bindings(1)), agg_env) self._type = self.a.typ def bindings(self, i, default_value=None): @@ -985,7 +985,7 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, [(self.name, self.a.typ.element_type)]), agg_env) + self.body._compute_type(_env_bind(env, self.bindings(1)), agg_env) self._type = tarray(self.body.typ.element_type) def bindings(self, i, default_value=None): @@ -1029,10 +1029,7 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) self.zero._compute_type(env, agg_env) - self.body._compute_type( - _env_bind(env, [(self.value_name, self.a.typ.element_type), - (self.accum_name, self.zero.typ)]), - agg_env) + self.body._compute_type(_env_bind(env, self.bindings(2)), agg_env) self._type = self.zero.typ def bindings(self, i, default_value=None): @@ -1076,10 +1073,7 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) self.zero._compute_type(env, agg_env) - self.body._compute_type( - _env_bind(env, [(self.value_name, self.a.typ.element_type), - (self.accum_name, self.zero.typ)]), - agg_env) + self.body._compute_type(_env_bind(env, self.bindings(2)), agg_env) self._type = tarray(self.body.typ) def bindings(self, i, default_value=None): @@ -1160,7 +1154,7 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.a._compute_type(env, agg_env) - self.body._compute_type(_env_bind(env, [(self.value_name, self.a.typ.element_type)]), agg_env) + self.body._compute_type(_env_bind(env, self.bindings(1)), agg_env) self._type = tvoid def bindings(self, i, default_value=None): @@ -1232,8 +1226,7 @@ def bound_variables(self): def _compute_type(self, env, agg_env): self.array._compute_type(agg_env, None) - new_agg = _env_bind(agg_env, [(self.name, self.array.typ.element_type)]) - self.agg_body._compute_type(env, new_agg) + self.agg_body._compute_type(env, _env_bind(agg_env, self.agg_bindings(1))) self._type = self.agg_body.typ def agg_bindings(self, i, default_value=None): @@ -1311,8 +1304,8 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self.array._compute_type(agg_env, None) - self.agg_ir._compute_type(_env_bind(env, [(self.index_name, tint32)]), - _env_bind(agg_env, [(self.element_name, self.array.typ.element_type)])) + self.agg_ir._compute_type(_env_bind(env, self.bindings(1)), + _env_bind(agg_env, self.agg_bindings(1))) self._type = tarray(self.agg_ir.typ) @property @@ -1879,7 +1872,7 @@ def _eq(self, other): return other.argname == self.argname def _compute_type(self, env, agg_env): - self.function._compute_type(_env_bind(env, [(self.argname, tfloat64)]), agg_env) + self.function._compute_type(_env_bind(env, self.bindings(0)), agg_env) self.min._compute_type(env, agg_env) self.max._compute_type(env, agg_env) self._type = tfloat64 From 705f3c3a3dde387f2b4dd952b9a59fbcc31ddd4e Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 5 Sep 2019 16:39:59 -0400 Subject: [PATCH 41/47] wip refactoring --- hail/python/hail/ir/ir.py | 2 +- hail/python/hail/ir/renderer.py | 484 +++++++++++++++++--------------- 2 files changed, 253 insertions(+), 233 deletions(-) diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 787678c10ce..1375d2b9e22 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -321,7 +321,7 @@ def _eq(self, other): def _compute_type(self, env, agg_env): self.value._compute_type(agg_env, None) - self.body._compute_type(env, _env_bind(agg_env, self.agg_bindings(1))) + self.body._compute_type(env, _env_bind(agg_env, {self.name: self.value._type})) self._type = self.body._type def agg_bindings(self, i, default_value=None): diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index bb66a0715e5..529a1916c61 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -141,106 +141,16 @@ def __call__(self, x: 'Renderable'): Context = (Vars, Vars, Vars) -class AnalysisStackFrame: - __slots__ = ['min_binding_depth', 'min_value_binding_depth', 'scan_scope', - 'context', 'node', '_free_vars', 'visited', 'agg_visited', - 'scan_visited', 'lifted_lets', 'agg_lifted_lets', - 'scan_lifted_lets', 'child_idx', '_child_bindings'] - - def __init__(self, min_binding_depth: int, min_value_binding_depth: int, - scan_scope: bool, context: Context, x: 'ir.BaseIR'): - # immutable - self.min_binding_depth = min_binding_depth - self.min_value_binding_depth = min_value_binding_depth - self.scan_scope = scan_scope - self.context = context - self.node = x - # mutable - self._free_vars = {} - self.visited: Set[int] = set() - self.agg_visited: Set[int] = set() - self.scan_visited: Set[int] = set() - self.lifted_lets: Dict[int, str] = {} - self.agg_lifted_lets: Dict[int, str] = {} - self.scan_lifted_lets: Dict[int, str] = {} - self.child_idx = 0 - self._child_bindings = None - - # compute depth at which we might bind this node - def bind_depth(self) -> int: - if len(self._free_vars) > 0: - bind_depth = max(self._free_vars.values()) - bind_depth = max(bind_depth, self.min_binding_depth) - else: - bind_depth = self.min_binding_depth - return bind_depth - - def make_child_frame(self, depth: int) -> 'AnalysisStackFrame': - x = self.node - i = self.child_idx - 1 - child = x.children[i] - child_min_binding_depth = self.min_binding_depth - child_min_value_binding_depth = self.min_value_binding_depth - child_scan_scope = self.scan_scope - if x.new_block(i): - child_min_binding_depth = depth - child_min_value_binding_depth = depth - elif x.uses_agg_context(i): - child_min_value_binding_depth = depth - child_scan_scope = False - elif x.uses_scan_context(i): - child_min_value_binding_depth = depth - child_scan_scope = True - - # compute vars bound in 'child' by 'node' - eval_bindings = x.bindings(i, 0).keys() - agg_bindings = x.agg_bindings(i, 0).keys() - scan_bindings = x.scan_bindings(i, 0).keys() - new_bindings = (eval_bindings, agg_bindings, scan_bindings) - self._child_bindings = new_bindings - child_context = x.child_context(i, self.context, depth) - - return AnalysisStackFrame(child_min_binding_depth, child_min_value_binding_depth, child_scan_scope, child_context, child) - - def free_vars(self): - # subtract vars that will be bound by inserted lets - def bound_vars(): - yield from self.lifted_lets.values() - yield from self.agg_lifted_lets.values() - yield from self.scan_lifted_lets.values() - for var in bound_vars(): - self._free_vars.pop(var, 0) - return self._free_vars - - def update_free_vars(self, child_frame: 'AnalysisStackFrame'): - child_free_vars = child_frame.free_vars() - # subtract vars bound by parent from free_vars - (eval_bindings, agg_bindings, scan_bindings) = self._child_bindings - for var in [*eval_bindings, *agg_bindings, *scan_bindings]: - child_free_vars.pop(var, 0) - # update parent's free variables - self._free_vars.update(child_free_vars) - - def update_parent_free_vars(self, parent_frame: 'AnalysisStackFrame'): - # subtract vars bound by parent from free_vars - (eval_bindings, agg_bindings, scan_bindings) = parent_frame._child_bindings - for var in [*eval_bindings, *agg_bindings, *scan_bindings]: - self._free_vars.pop(var, 0) - # subtract vars that will be bound by inserted lets - for (var, _) in self.lifted_lets.values(): - self._free_vars.pop(var, 0) - # update parent's free variables - parent_frame._free_vars.update(self._free_vars) - - BindingSite = namedtuple( 'BindingSite', 'depth lifted_lets agg_lifted_lets scan_lifted_lets') + BindingsStackFrame = namedtuple( 'BindingsStackFrame', 'depth lifted_lets agg_lifted_lets scan_lifted_lets visited agg_visited' ' scan_visited let_bodies') + def make_bindings_stack_frame(site: BindingSite): return BindingsStackFrame(depth=site.depth, lifted_lets=site.lifted_lets, @@ -252,80 +162,15 @@ def make_bindings_stack_frame(site: BindingSite): let_bodies=[]) -class PrintStackFrame: - __slots__ = ['node', 'children', 'min_binding_depth', - 'min_value_binding_depth', 'scan_scope', 'depth', - 'lift_to_frame', 'insert_lets', 'builder', 'child_idx'] - - def __init__(self, node, children, builder, min_binding_depth, - min_value_binding_depth, scan_scope, depth, insert_lets, - lift_to_frame=None): - # immutable - self.node: Renderable = node - self.children: Sequence[Renderable] = children - self.min_binding_depth: int = min_binding_depth - self.min_value_binding_depth: int = min_value_binding_depth - self.scan_scope: bool = scan_scope - self.depth: int = depth - self.lift_to_frame: Optional[BindingsStackFrame] = lift_to_frame - self.insert_lets: bool = insert_lets - # mutable - self.builder = builder - self.child_idx = -1 - - def add_lets(self, let_bodies, out_builder): - for let_body in let_bodies: - out_builder.extend(let_body) - out_builder.extend(self.builder) - num_lets = len(let_bodies) - for _ in range(num_lets): - out_builder.append(')') - - def make_child_frame(self, renderer, binding_sites, builder, context, - min_binding_depth, min_value_binding_depth, scan_scope, - depth): - child = self.children[self.child_idx] - return self.make(child, renderer, binding_sites, builder, context, - min_binding_depth, min_value_binding_depth, scan_scope, - depth) - - @staticmethod - def make(node, renderer, binding_sites, builder, context, min_binding_depth, - min_value_binding_depth, scan_scope, depth): - insert_lets = (id(node) in binding_sites - and depth == binding_sites[id(node)].depth - and (len(binding_sites[id(node)].lifted_lets) > 0 - or len(binding_sites[id(node)].agg_lifted_lets) > 0)) - state = PrintStackFrame(node, node.render_children(renderer), builder, - min_binding_depth, min_value_binding_depth, - scan_scope, depth, insert_lets) - if insert_lets: - state.builder = [] - context.append(make_bindings_stack_frame(binding_sites[id(node)])) - head = node.render_head(renderer) - if head != '': - state.builder.append(head) - return state - - -class CSERenderer: - def __init__(self, stop_at_jir=False): - self.stop_at_jir = stop_at_jir - self.jir_count = 0 - self.jirs = {} - self.memo: Dict[int, Sequence[str]] = {} +class CSEAnalysisPass: + def __init__(self, renderer): + self.renderer = renderer self.uid_count = 0 def uid(self) -> str: self.uid_count += 1 return f'__cse_{self.uid_count}' - def add_jir(self, jir): - jir_id = f'm{self.jir_count}' - self.jir_count += 1 - self.jirs[jir_id] = jir - return jir_id - # At top of main loop, we are considering the node 'node' and its # 'child_idx'th child, or if 'child_idx' = 'len(node.children)', we are # about to do post-processing on 'node' before moving back up to its parent. @@ -361,8 +206,8 @@ def add_jir(self, jir): # 'BindingSite' recording the depth of 'node' and a Dict 'lifted_lets', # where for each descendant 'x' which will be bound above 'node', # 'lifted_lets' maps 'id(x)' to the unique id 'x' will be bound to. - def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: - root_frame = AnalysisStackFrame(0, 0, False, ({}, {}, {}), root) + def __call__(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: + root_frame = self.StackFrame(0, 0, False, ({}, {}, {}), root) stack = [root_frame] binding_sites = {} @@ -374,7 +219,6 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: if child_idx >= len(node.children): # mark node as visited at potential let insertion site - # FIXME: factor out to AnalysisStackFrame method if not node.is_effectful(): bind_depth = frame.bind_depth() if bind_depth < frame.min_value_binding_depth: @@ -389,12 +233,8 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: # if any lets being inserted here, add node to registry of # binding sites - if frame.lifted_lets or frame.agg_lifted_lets or frame.scan_lifted_lets: - binding_sites[id(node)] = BindingSite( - lifted_lets=frame.lifted_lets, - agg_lifted_lets=frame.agg_lifted_lets, - scan_lifted_lets=frame.scan_lifted_lets, - depth=len(stack)) + if frame.has_lifted_lets(): + binding_sites[id(node)] = frame.make_binding_site(len(stack)) if not stack: break @@ -403,48 +243,31 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: child = node.children[child_idx] - if self.stop_at_jir and hasattr(child, '_jir'): - jir_id = self.add_jir(child._jir) - if isinstance(child, ir.MatrixIR): - jref = f'(JavaMatrix {jir_id})' - elif isinstance(child, ir.TableIR): - jref = f'(JavaTable {jir_id})' - elif isinstance(child, ir.BlockMatrixIR): - jref = f'(JavaBlockMatrix {jir_id})' - else: - assert isinstance(child, ir.IR) - jref = f'(JavaIR {jir_id})' - - self.memo[id(child)] = jref + if self.renderer.stop_at_jir and hasattr(child, '_jir'): + self.renderer._add_jir(child) continue child_frame = frame.make_child_frame(len(stack)) if isinstance(child, ir.IR): - seen_in_scope = next((i for i in reversed(range(child_frame.min_value_binding_depth, - len(stack))) - if id(child) in stack[i].visited), - None) - lets = None - if seen_in_scope is not None: - lets = stack[seen_in_scope].lifted_lets - else: - if frame.scan_scope: - seen_in_scope = next( - (i for i in reversed(range(child_frame.min_binding_depth, - child_frame.min_value_binding_depth)) - if id(child) in stack[i].scan_visited), - None) - if seen_in_scope is not None: - lets = stack[seen_in_scope].scan_lifted_lets + for i in reversed(range(child_frame.min_binding_depth, len(stack))): + frame = stack[i] + if i >= child_frame.min_value_binding_depth: + if id(child) in frame.visited: + seen_in_scope = i + lets = frame.lifted_lets + break else: - seen_in_scope = next( - (i for i in reversed(range(child_frame.min_binding_depth, - child_frame.min_value_binding_depth)) - if id(child) in stack[i].agg_visited), - None) - if seen_in_scope is not None: - lets = stack[seen_in_scope].agg_lifted_lets + if id(child) in frame.agg_visited: + seen_in_scope = i + lets = frame.agg_lifted_lets + break + if id(child) in frame.scan_visited: + seen_in_scope = i + lets = frame.scan_lifted_lets + break + else: + seen_in_scope = None if lets is not None: # we've seen 'child' before, should not traverse (or we will @@ -476,6 +299,112 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: return binding_sites + class StackFrame: + __slots__ = ['min_binding_depth', 'min_value_binding_depth', 'scan_scope', + 'context', 'node', '_free_vars', 'visited', 'agg_visited', + 'scan_visited', 'lifted_lets', 'agg_lifted_lets', + 'scan_lifted_lets', 'child_idx', '_child_bindings'] + + def __init__(self, min_binding_depth: int, min_value_binding_depth: int, + scan_scope: bool, context: Context, x: 'ir.BaseIR'): + # immutable + self.min_binding_depth = min_binding_depth + self.min_value_binding_depth = min_value_binding_depth + self.scan_scope = scan_scope + self.context = context + self.node = x + # mutable + self._free_vars = {} + self.visited: Set[int] = set() + self.agg_visited: Set[int] = set() + self.scan_visited: Set[int] = set() + self.lifted_lets: Dict[int, str] = {} + self.agg_lifted_lets: Dict[int, str] = {} + self.scan_lifted_lets: Dict[int, str] = {} + self.child_idx = 0 + self._child_bindings = None + + def has_lifted_lets(self): + return self.lifted_lets or self.agg_lifted_lets or self.scan_lifted_lets + + def make_binding_site(self, depth): + return BindingSite( + lifted_lets=self.lifted_lets, + agg_lifted_lets=self.agg_lifted_lets, + scan_lifted_lets=self.scan_lifted_lets, + depth=depth) + + # compute depth at which we might bind this node + def bind_depth(self) -> int: + if len(self._free_vars) > 0: + bind_depth = max(self._free_vars.values()) + bind_depth = max(bind_depth, self.min_binding_depth) + else: + bind_depth = self.min_binding_depth + return bind_depth + + def make_child_frame(self, depth: int): + x = self.node + i = self.child_idx - 1 + child = x.children[i] + child_min_binding_depth = self.min_binding_depth + child_min_value_binding_depth = self.min_value_binding_depth + child_scan_scope = self.scan_scope + if x.new_block(i): + child_min_binding_depth = depth + child_min_value_binding_depth = depth + elif x.uses_agg_context(i): + child_min_value_binding_depth = depth + child_scan_scope = False + elif x.uses_scan_context(i): + child_min_value_binding_depth = depth + child_scan_scope = True + + # compute vars bound in 'child' by 'node' + eval_bindings = x.bindings(i, 0).keys() + agg_bindings = x.agg_bindings(i, 0).keys() + scan_bindings = x.scan_bindings(i, 0).keys() + new_bindings = (eval_bindings, agg_bindings, scan_bindings) + self._child_bindings = new_bindings + child_context = x.child_context(i, self.context, depth) + + return CSEAnalysisPass.StackFrame(child_min_binding_depth, child_min_value_binding_depth, child_scan_scope, child_context, child) + + def free_vars(self): + # subtract vars that will be bound by inserted lets + def bound_vars(): + yield from self.lifted_lets.values() + yield from self.agg_lifted_lets.values() + yield from self.scan_lifted_lets.values() + for var in bound_vars(): + self._free_vars.pop(var, 0) + return self._free_vars + + def update_free_vars(self, child_frame): + child_free_vars = child_frame.free_vars() + # subtract vars bound by parent from free_vars + (eval_bindings, agg_bindings, scan_bindings) = self._child_bindings + for var in [*eval_bindings, *agg_bindings, *scan_bindings]: + child_free_vars.pop(var, 0) + # update parent's free variables + self._free_vars.update(child_free_vars) + + def update_parent_free_vars(self, parent_frame): + # subtract vars bound by parent from free_vars + (eval_bindings, agg_bindings, scan_bindings) = parent_frame._child_bindings + for var in [*eval_bindings, *agg_bindings, *scan_bindings]: + self._free_vars.pop(var, 0) + # subtract vars that will be bound by inserted lets + for (var, _) in self.lifted_lets.values(): + self._free_vars.pop(var, 0) + # update parent's free variables + parent_frame._free_vars.update(self._free_vars) + + +class CSEPrintPass: + def __init__(self, renderer): + self.renderer = renderer + # At top of main loop, we are considering the 'Renderable' 'node' and its # 'child_idx'th child, or if 'child_idx' = 'len(node.children)', we are # about to do post-processing on 'node' before moving back up to its parent. @@ -505,14 +434,15 @@ def compute_new_bindings(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: # be added to 'builder'. If neither, then it is safe for 'local_builder' to # be 'builder', to save copying. - def build_string(self, root, binding_sites): + def __call__(self, root, binding_sites): root_builder = [] context: List[BindingsStackFrame] = [] + memo = self.renderer.memo - if id(root) in self.memo: - return ''.join(self.memo[id(root)]) + if id(root) in memo: + return ''.join(memo[id(root)]) - stack = [PrintStackFrame.make(root, self, binding_sites, root_builder, context, 0, 0, False, 0)] + stack = [self.StackFrame.make(root, self, binding_sites, root_builder, context, 0, 0, False, 0)] while True: frame = stack[-1] @@ -523,10 +453,10 @@ def build_string(self, root, binding_sites): if child_idx >= len(frame.children): if frame.lift_to_frame is not None: assert(not frame.insert_lets) - if id(node) in self.memo: - frame.builder.append(self.memo[id(node)]) + if id(node) in memo: + frame.builder.append(memo[id(node)]) else: - frame.builder.append(node.render_tail(self)) + frame.builder.append(node.render_tail(self.renderer)) frame.builder.append(' ') # let_bodies is built post-order, which guarantees earlier # lets can't refer to later lets @@ -534,7 +464,7 @@ def build_string(self, root, binding_sites): stack.pop() continue else: - frame.builder.append(node.render_tail(self)) + frame.builder.append(node.render_tail(self.renderer)) stack.pop() if frame.insert_lets: if not stack: @@ -594,35 +524,38 @@ def build_string(self, root, binding_sites): or len(binding_sites[id(child)].agg_lifted_lets) > 0 or len(binding_sites[id(child)].scan_lifted_lets > 0))) assert not insert_lets + if lift_type == 'value': - name = lift_to_frame.lifted_lets[id(child)] - child_builder = [f'(Let {name} '] visited = lift_to_frame.visited + name = lift_to_frame.lifted_lets[id(child)] elif lift_type == 'agg': - name = lift_to_frame.agg_lifted_lets[id(child)] - child_builder = [f'(AggLet {name} False '] visited = lift_to_frame.agg_visited + name = lift_to_frame.agg_lifted_lets[id(child)] else: - name = lift_to_frame.scan_lifted_lets[id(child)] - child_builder = [f'(AggLet {name} True '] visited = lift_to_frame.scan_visited + name = lift_to_frame.scan_lifted_lets[id(child)] frame.builder.append(f'(Ref {name})') - if id(child) in self.memo: - child_builder.append(self.memo[id(child)]) + if id(child) in visited: + continue + + if lift_type == 'value': + child_builder = [f'(Let {name} '] + elif lift_type == 'agg': + child_builder = [f'(AggLet {name} False '] + else: + child_builder = [f'(AggLet {name} True '] + + if id(child) in memo: + child_builder.append(memo[id(child)]) child_builder.append(' ') lift_to_frame.let_bodies.append(child_builder) continue - # new_state = PrintStackFrame(child, [], child_builder, child_min_binding_depth, child_depth, insert_lets, lift_to_frame) - # stack.append(new_state) - # continue - if id(child) in visited: - continue visited[id(child)] = child - new_state = frame.make_child_frame(self, binding_sites, + new_state = frame.make_child_frame(self.renderer, binding_sites, child_builder, context, child_min_binding_depth, child_min_value_binding_depth, @@ -632,11 +565,11 @@ def build_string(self, root, binding_sites): stack.append(new_state) continue - if id(child) in self.memo: - frame.builder.append(self.memo[id(child)]) + if id(child) in memo: + frame.builder.append(memo[id(child)]) continue - new_state = frame.make_child_frame(self, binding_sites, + new_state = frame.make_child_frame(self.renderer, binding_sites, frame.builder, context, child_min_binding_depth, child_min_value_binding_depth, @@ -645,6 +578,93 @@ def build_string(self, root, binding_sites): stack.append(new_state) continue + class StackFrame: + __slots__ = ['node', 'children', 'min_binding_depth', + 'min_value_binding_depth', 'scan_scope', 'depth', + 'lift_to_frame', 'insert_lets', 'builder', 'child_idx'] + + def __init__(self, node, children, builder, min_binding_depth, + min_value_binding_depth, scan_scope, depth, insert_lets, + lift_to_frame=None): + # immutable + self.node: Renderable = node + self.children: Sequence[Renderable] = children + self.min_binding_depth: int = min_binding_depth + self.min_value_binding_depth: int = min_value_binding_depth + self.scan_scope: bool = scan_scope + self.depth: int = depth + self.lift_to_frame: Optional[BindingsStackFrame] = lift_to_frame + self.insert_lets: bool = insert_lets + # mutable + self.builder = builder + self.child_idx = -1 + + def add_lets(self, let_bodies, out_builder): + for let_body in let_bodies: + out_builder.extend(let_body) + out_builder.extend(self.builder) + num_lets = len(let_bodies) + for _ in range(num_lets): + out_builder.append(')') + + def make_child_frame(self, renderer, binding_sites, builder, context, + min_binding_depth, min_value_binding_depth, + scan_scope, depth): + child = self.children[self.child_idx] + return self.make(child, renderer, binding_sites, builder, context, + min_binding_depth, min_value_binding_depth, + scan_scope, depth) + + @staticmethod + def make(node, renderer, binding_sites, builder, context, + min_binding_depth, min_value_binding_depth, scan_scope, depth): + insert_lets = (id(node) in binding_sites + and depth == binding_sites[id(node)].depth + and (len(binding_sites[id(node)].lifted_lets) > 0 + or len(binding_sites[id(node)].agg_lifted_lets) > 0)) + state = CSEPrintPass.StackFrame(node, + node.render_children(renderer), + builder, + min_binding_depth, + min_value_binding_depth, + scan_scope, depth, insert_lets) + if insert_lets: + state.builder = [] + context.append(make_bindings_stack_frame(binding_sites[id(node)])) + head = node.render_head(renderer) + if head != '': + state.builder.append(head) + return state + + +class CSERenderer: + def __init__(self, stop_at_jir=False): + self.stop_at_jir = stop_at_jir + self.jir_count = 0 + self.jirs = {} + self.memo: Dict[int, Sequence[str]] = {} + + def add_jir(self, jir): + jir_id = f'm{self.jir_count}' + self.jir_count += 1 + self.jirs[jir_id] = jir + return jir_id + + def _add_jir(self, node): + jir_id = self.add_jir(node._jir) + if isinstance(node, ir.MatrixIR): + jref = f'(JavaMatrix {jir_id})' + elif isinstance(node, ir.TableIR): + jref = f'(JavaTable {jir_id})' + elif isinstance(node, ir.BlockMatrixIR): + jref = f'(JavaBlockMatrix {jir_id})' + else: + assert isinstance(node, ir.IR) + jref = f'(JavaIR {jir_id})' + + self.memo[id(node)] = jref + + def __call__(self, root: 'ir.BaseIR') -> str: - binding_sites = self.compute_new_bindings(root) - return self.build_string(root, binding_sites) + binding_sites = CSEAnalysisPass(self)(root) + return CSEPrintPass(self)(root, binding_sites) From 538c1be4b41498135eff2966c74e1fe388f030a5 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Fri, 6 Sep 2019 16:03:13 -0400 Subject: [PATCH 42/47] refactoring/documenting --- hail/python/hail/backend/backend.py | 2 +- hail/python/hail/ir/base_ir.py | 12 +- hail/python/hail/ir/renderer.py | 261 +++++++++++++++------------- 3 files changed, 150 insertions(+), 125 deletions(-) diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 77a87152ec8..d0e8b2f01ea 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -99,7 +99,7 @@ def fs(self): def _to_java_ir(self, ir): if not hasattr(ir, '_jir'): - r = Renderer(stop_at_jir=True) + r = CSERenderer(stop_at_jir=True) # FIXME parse should be static ir._jir = ir.parse(r(ir), ir_map=r.jirs) return ir._jir diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index 4e01fcc28fa..8e6c199bfad 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -99,9 +99,6 @@ def renderable_new_block(self, i: int) -> bool: def is_effectful() -> bool: return False - def binds(self, i: int) -> bool: - return False - def bindings(self, i: int, default_value=None): """Compute variables bound in child 'i'. @@ -142,13 +139,14 @@ def child_context_without_bindings(self, i: int, parent_context): def child_context(self, i: int, parent_context, default_value=None): base = self.child_context_without_bindings(i, parent_context) - if not self.binds(i): - return base eval_b = self.bindings(i, default_value) agg_b = self.agg_bindings(i, default_value) scan_b = self.scan_bindings(i, default_value) - (eval_c, agg_c, scan_c) = base - return _env_bind(eval_c, eval_b), _env_bind(agg_c, agg_b), _env_bind(scan_c, scan_b) + if eval_b or agg_b or scan_b: + (eval_c, agg_c, scan_c) = base + return _env_bind(eval_c, eval_b), _env_bind(agg_c, agg_b), _env_bind(scan_c, scan_b) + else: + return base class IR(BaseIR): diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 529a1916c61..d617b811801 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -145,22 +145,6 @@ def __call__(self, x: 'Renderable'): 'BindingSite', 'depth lifted_lets agg_lifted_lets scan_lifted_lets') -BindingsStackFrame = namedtuple( - 'BindingsStackFrame', - 'depth lifted_lets agg_lifted_lets scan_lifted_lets visited agg_visited' - ' scan_visited let_bodies') - - -def make_bindings_stack_frame(site: BindingSite): - return BindingsStackFrame(depth=site.depth, - lifted_lets=site.lifted_lets, - agg_lifted_lets=site.agg_lifted_lets, - scan_lifted_lets=site.scan_lifted_lets, - visited={}, - agg_visited={}, - scan_visited={}, - let_bodies=[]) - class CSEAnalysisPass: def __init__(self, renderer): @@ -175,31 +159,8 @@ def uid(self) -> str: # 'child_idx'th child, or if 'child_idx' = 'len(node.children)', we are # about to do post-processing on 'node' before moving back up to its parent. # - # 'stack' is a stack of 'AnalysisStackFrame's, one for each node on the path - # from 'root' to 'node. Each stack frame tracks the following immutable - # information: - # * 'node': The node corresponding to this stack frame. - # * 'min_binding_depth': If 'node' is to be bound in a let, the let may not - # rise higher than this (its depth must be >= 'min_binding_depth') - # * 'context': The binding context of 'node'. Maps variables bound above - # to the depth at which they were bound (more precisely, if - # 'context[var] == depth', then 'stack[depth-1].node' binds 'var' in the - # subtree rooted at 'stack[depth].node'). - # * 'new_bindings': The variables which were bound by 'node's parent. These - # must be subtracted out of 'node's free variables when updating the - # parent's free variables. - # - # Each stack frame also holds the following mutable state: - # * 'free_vars': The running union of free variables in the subtree rooted - # at 'node'. - # * 'visited': A set of visited descendants. For each descendant 'x' of - # 'node', 'id(x)' is added to 'visited'. This allows us to recognize when - # we see a node for a second time. - # * 'lifted_lets': For each descendant 'x' of 'node', if 'x' is to be bound - # in a let immediately above 'node', then 'lifted_lets' contains 'id(x)', - # along with the unique id to bind 'x' to. - # * 'child_idx': the child currently being visited (satisfies the invariant - # 'stack[i].node.children[stack[i].child_idx] is stack[i+1].node'). + # 'stack' is a stack of 'StackFrame's, one for each node on the path from + # 'root' to 'node. See 'StackFrame' for descriptions of the maintained state. # # Returns a Dict summarizing of all lets to be inserted. For each 'node' # which will have lets inserted immediately above it, maps 'id(node)' to a @@ -214,8 +175,8 @@ def __call__(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: while True: frame = stack[-1] node = frame.node - child_idx = frame.child_idx frame.child_idx += 1 + child_idx = frame.child_idx if child_idx >= len(node.children): # mark node as visited at potential let insertion site @@ -251,23 +212,23 @@ def __call__(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: if isinstance(child, ir.IR): for i in reversed(range(child_frame.min_binding_depth, len(stack))): - frame = stack[i] + cur = stack[i] if i >= child_frame.min_value_binding_depth: - if id(child) in frame.visited: + if id(child) in cur.visited: seen_in_scope = i - lets = frame.lifted_lets + lets = cur.lifted_lets break else: - if id(child) in frame.agg_visited: + if id(child) in cur.agg_visited: seen_in_scope = i - lets = frame.agg_lifted_lets + lets = cur.agg_lifted_lets break - if id(child) in frame.scan_visited: + if id(child) in cur.scan_visited: seen_in_scope = i - lets = frame.scan_lifted_lets + lets = cur.scan_lifted_lets break else: - seen_in_scope = None + lets = None if lets is not None: # we've seen 'child' before, should not traverse (or we will @@ -307,24 +268,55 @@ class StackFrame: def __init__(self, min_binding_depth: int, min_value_binding_depth: int, scan_scope: bool, context: Context, x: 'ir.BaseIR'): - # immutable + # Immutable: + + # The node corresponding to this stack frame. + self.node = x + # If 'node' is to be bound in a let, the let may not rise higher + # than this (its depth must be >= 'min_binding_depth'). self.min_binding_depth = min_binding_depth + # The binding context of 'node'. Maps variables bound above to the + # depth at which they were bound (more precisely, if + # 'context[var] == depth', then 'stack[depth-1].node' binds 'var' in + # the subtree rooted at 'stack[depth].node'). + self.context = context + # If 'node' is to be bound in a let above this (at a smaller depth) + # then it must be in an AggLet, with 'scan_scope' determining + # whether it is bound in the agg scope or scan scope. self.min_value_binding_depth = min_value_binding_depth self.scan_scope = scan_scope - self.context = context - self.node = x - # mutable - self._free_vars = {} + + # Mutable: + + # The running union of free variables in the subtree rooted at + # 'node'. For each free variable we store the depth at which it is + # bound (above 'node'). + self._free_vars: Dict[str, int] = {} + # Sets of visited descendants. For each descendant 'x' of 'node', + # 'id(x)' is added to 'visited'. This allows us to recognize when + # we see a node for a second time. + # A single descendant may need to be bound here in three separate + # scopes, so we must track each scope separately. self.visited: Set[int] = set() self.agg_visited: Set[int] = set() self.scan_visited: Set[int] = set() + # For each descendant 'x' of 'node', if 'x' is to be bound in a Let + # (resp. an AggLet) immediately above 'node', then 'lifted_lets' + # (resp. (agg/scan)_lifted_lets) contains 'id(x)', along with the + # unique id to bind 'x' to. self.lifted_lets: Dict[int, str] = {} self.agg_lifted_lets: Dict[int, str] = {} self.scan_lifted_lets: Dict[int, str] = {} - self.child_idx = 0 + # The child currently being visited (satisfies the invariant + # 'stack[i].node.children[stack[i].child_idx] is stack[i+1].node'). + # Starts at -1 because it is incremented at the top of the main loop. + self.child_idx = -1 + # The variables bound by this node in the currently active child. + # These must be subtracted out of the child's free variables when + # updating '_free_vars'. self._child_bindings = None - def has_lifted_lets(self): + def has_lifted_lets(self) -> bool: return self.lifted_lets or self.agg_lifted_lets or self.scan_lifted_lets def make_binding_site(self, depth): @@ -345,7 +337,7 @@ def bind_depth(self) -> int: def make_child_frame(self, depth: int): x = self.node - i = self.child_idx - 1 + i = self.child_idx child = x.children[i] child_min_binding_depth = self.min_binding_depth child_min_value_binding_depth = self.min_value_binding_depth @@ -389,17 +381,6 @@ def update_free_vars(self, child_frame): # update parent's free variables self._free_vars.update(child_free_vars) - def update_parent_free_vars(self, parent_frame): - # subtract vars bound by parent from free_vars - (eval_bindings, agg_bindings, scan_bindings) = parent_frame._child_bindings - for var in [*eval_bindings, *agg_bindings, *scan_bindings]: - self._free_vars.pop(var, 0) - # subtract vars that will be bound by inserted lets - for (var, _) in self.lifted_lets.values(): - self._free_vars.pop(var, 0) - # update parent's free variables - parent_frame._free_vars.update(self._free_vars) - class CSEPrintPass: def __init__(self, renderer): @@ -409,40 +390,22 @@ def __init__(self, renderer): # 'child_idx'th child, or if 'child_idx' = 'len(node.children)', we are # about to do post-processing on 'node' before moving back up to its parent. # - # 'stack' is a stack of 'PrintStackFrame's, one for each node on the path - # from 'root' to 'node. Each stack frame tracks the following immutable - # information: - # * 'node': The 'Renderable' node corresponding to this stack frame. - # * 'children': The list of 'Renderable' children. - # * 'min_binding_depth': If 'node' is to be bound in a let, the let may not - # rise higher than this (its depth must be >= 'min_binding_depth') - # * 'depth': The depth of 'node' in the original tree, i.e. the number of - # BaseIR in 'stack', not counting other 'Renderable's. - # * 'lift_to_frame': The outermost frame in which 'node' was marked to be - # lifted in the analysis pass, if any, otherwise None. - # * 'insert_lets': True if any lets need to be inserted above 'node'. No - # node has both 'lift_to_frame' not None and 'insert_lets' True. + # 'stack' is a stack of 'StackFrame's, one for each node on the path from + # 'root' to 'node'. See 'StackFrame' for descriptions of the maintained state. # - # Each stack frame also holds the following mutable state: - # * 'child_idx': The index of the 'Renderable' child currently being - # visited. - # * 'builder': The buffer building the parent's rendered IR. - # * 'local_builder': The builder building 'node's IR. - # If 'insert_lets', all lets will be added to 'builder' before copying - # 'local_builder' to 'builder. If 'lift_to_frame', 'local_builder' will be - # added to 'lift_to_frame's list of lifted lets, while only "(Ref ...)" will - # be added to 'builder'. If neither, then it is safe for 'local_builder' to - # be 'builder', to save copying. - - def __call__(self, root, binding_sites): + # 'bindings_stack' is a stack of 'BindingsStackFrame's, one for each + # potential binding site on the path from 'root' to 'node'. + + def __call__(self, root: ir.BaseIR, binding_sites: Dict[int, BindingSite]): root_builder = [] - context: List[BindingsStackFrame] = [] + bindings_stack: List[CSEPrintPass.BindingsStackFrame] = [] memo = self.renderer.memo if id(root) in memo: return ''.join(memo[id(root)]) - stack = [self.StackFrame.make(root, self, binding_sites, root_builder, context, 0, 0, False, 0)] + stack = [self.StackFrame.make(root, self.renderer, binding_sites, + root_builder, bindings_stack, 0, 0, False, 0)] while True: frame = stack[-1] @@ -471,8 +434,8 @@ def __call__(self, root, binding_sites): out_builder = root_builder else: out_builder = stack[-1].builder - frame.add_lets(context[-1].let_bodies, out_builder) - context.pop() + frame.add_lets(bindings_stack[-1].let_bodies, out_builder) + bindings_stack.pop() if not stack: return ''.join(root_builder) continue @@ -499,7 +462,7 @@ def __call__(self, root, binding_sites): if isinstance(child, ir.BaseIR): child_depth += 1 - for c in context: + for c in bindings_stack: if c.depth >= child_min_value_binding_depth and id(child) in c.lifted_lets: lift_to_frame = c lift_type = 'value' @@ -556,7 +519,7 @@ def __call__(self, root, binding_sites): visited[id(child)] = child new_state = frame.make_child_frame(self.renderer, binding_sites, - child_builder, context, + child_builder, bindings_stack, child_min_binding_depth, child_min_value_binding_depth, child_scan_scope, @@ -570,7 +533,7 @@ def __call__(self, root, binding_sites): continue new_state = frame.make_child_frame(self.renderer, binding_sites, - frame.builder, context, + frame.builder, bindings_stack, child_min_binding_depth, child_min_value_binding_depth, child_scan_scope, @@ -578,28 +541,66 @@ def __call__(self, root, binding_sites): stack.append(new_state) continue + BindingsStackFrame = namedtuple( + 'BindingsStackFrame', + 'depth lifted_lets agg_lifted_lets scan_lifted_lets visited agg_visited' + ' scan_visited let_bodies') + class StackFrame: __slots__ = ['node', 'children', 'min_binding_depth', 'min_value_binding_depth', 'scan_scope', 'depth', 'lift_to_frame', 'insert_lets', 'builder', 'child_idx'] - def __init__(self, node, children, builder, min_binding_depth, - min_value_binding_depth, scan_scope, depth, insert_lets, - lift_to_frame=None): - # immutable + def __init__(self, + node: Renderable, + children: Sequence[Renderable], + builder: Sequence[str], + min_binding_depth: int, + min_value_binding_depth: int, + scan_scope: bool, + depth: int, + insert_lets: bool, + lift_to_frame: 'Optional[CSEPrintPass.BindingsStackFrame]' = None): + # Immutable + + # The 'Renderable' node corresponding to this stack frame. self.node: Renderable = node + # The list of 'Renderable' children. self.children: Sequence[Renderable] = children + # If 'node' is to be bound in a let, the let may not rise higher + # than this (its depth must be >= 'min_binding_depth'). self.min_binding_depth: int = min_binding_depth + # If 'node' is to be bound higher than this, it must be bound by an + # AggLet, with 'scan_scope' determining whether it is bound in + # the agg scope or scan scope. self.min_value_binding_depth: int = min_value_binding_depth self.scan_scope: bool = scan_scope + # The depth of 'node' in the original tree, i.e. the number of + # BaseIR above this in the stack, not counting other 'Renderable's. self.depth: int = depth - self.lift_to_frame: Optional[BindingsStackFrame] = lift_to_frame + # The outermost frame above this in which 'node' was marked to be + # lifted in the analysis pass, if any, otherwise None. + self.lift_to_frame: Optional[CSEPrintPass.BindingsStackFrame] = lift_to_frame + # True if any lets need to be inserted above 'node'. No node has + # both 'lift_to_frame' not None and 'insert_lets' True. self.insert_lets: bool = insert_lets - # mutable - self.builder = builder + + # Mutable + + # The index of the 'Renderable' child currently being visited. + # Starts at -1 because it is incremented at the top of the main loop. self.child_idx = -1 + # The array of strings building 'node's IR. + # * If 'insert_lets', all lets will be added to the parent's + # 'builder' before appending this 'builder'. + # * If 'lift_to_frame', 'builder' will be added to 'lift_to_frame's + # list of lifted lets, while only "(Ref ...)" will be added to + # the parent's 'builder'. + # * If neither, then it is safe for 'builder' to be an alias of the + # parent's 'builder', to save copying. + self.builder = builder - def add_lets(self, let_bodies, out_builder): + def add_lets(self, let_bodies: Sequence[str], out_builder: Sequence[str]): for let_body in let_bodies: out_builder.extend(let_body) out_builder.extend(self.builder) @@ -607,17 +608,30 @@ def add_lets(self, let_bodies, out_builder): for _ in range(num_lets): out_builder.append(')') - def make_child_frame(self, renderer, binding_sites, builder, context, - min_binding_depth, min_value_binding_depth, - scan_scope, depth): + def make_child_frame(self, + renderer: 'CSERenderer', + binding_sites: Dict[int, BindingSite], + builder: Sequence[str], + bindings_stack: 'Sequence[CSEPrintPass.BindingsStackFrame]', + min_binding_depth: int, + min_value_binding_depth: int, + scan_scope: bool, + depth: int): child = self.children[self.child_idx] - return self.make(child, renderer, binding_sites, builder, context, + return self.make(child, renderer, binding_sites, builder, bindings_stack, min_binding_depth, min_value_binding_depth, scan_scope, depth) @staticmethod - def make(node, renderer, binding_sites, builder, context, - min_binding_depth, min_value_binding_depth, scan_scope, depth): + def make(node: Renderable, + renderer: 'CSERenderer', + binding_sites: Dict[int, BindingSite], + builder: Sequence[str], + bindings_stack: 'Sequence[CSEPrintPass.BindingsStackFrame]', + min_binding_depth: int, + min_value_binding_depth: int, + scan_scope: bool, + depth: int): insert_lets = (id(node) in binding_sites and depth == binding_sites[id(node)].depth and (len(binding_sites[id(node)].lifted_lets) > 0 @@ -630,12 +644,26 @@ def make(node, renderer, binding_sites, builder, context, scan_scope, depth, insert_lets) if insert_lets: state.builder = [] - context.append(make_bindings_stack_frame(binding_sites[id(node)])) + bindings_stack.append( + CSEPrintPass.StackFrame.make_bindings_stack_frame( + binding_sites[id(node)])) head = node.render_head(renderer) if head != '': state.builder.append(head) return state + @staticmethod + def make_bindings_stack_frame(site: BindingSite): + return CSEPrintPass.BindingsStackFrame( + depth=site.depth, + lifted_lets=site.lifted_lets, + agg_lifted_lets=site.agg_lifted_lets, + scan_lifted_lets=site.scan_lifted_lets, + visited={}, + agg_visited={}, + scan_visited={}, + let_bodies=[]) + class CSERenderer: def __init__(self, stop_at_jir=False): @@ -664,7 +692,6 @@ def _add_jir(self, node): self.memo[id(node)] = jref - def __call__(self, root: 'ir.BaseIR') -> str: binding_sites = CSEAnalysisPass(self)(root) return CSEPrintPass(self)(root, binding_sites) From bb2ce5165d0ce5bc5e1f847a01b629f3e7b85eeb Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Tue, 10 Sep 2019 09:32:56 -0400 Subject: [PATCH 43/47] fix --- hail/python/hail/ir/renderer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index d617b811801..2ef32b40a1a 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -396,7 +396,7 @@ def __init__(self, renderer): # 'bindings_stack' is a stack of 'BindingsStackFrame's, one for each # potential binding site on the path from 'root' to 'node'. - def __call__(self, root: ir.BaseIR, binding_sites: Dict[int, BindingSite]): + def __call__(self, root: 'ir.BaseIR', binding_sites: Dict[int, BindingSite]): root_builder = [] bindings_stack: List[CSEPrintPass.BindingsStackFrame] = [] memo = self.renderer.memo From 9d61e847e8fba5f24e7e79f6f56ee2de9fd2fed2 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 11 Sep 2019 09:15:09 -0400 Subject: [PATCH 44/47] use pytest assertions --- hail/python/test/hail/test_ir.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index 2b609359ac9..4807a65717d 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -370,7 +370,7 @@ def test_cse(self): ' (ApplyBinaryPrimOp `+`' ' (Ref __cse_1)' ' (Ref __cse_1)))') - self.assertEqual(expected, CSERenderer()(x)) + assert expected == CSERenderer()(x) def test_cse2(self): x = ir.I32(5) @@ -386,7 +386,7 @@ def test_cse2(self): ' (Ref __cse_2)' ' (I32 4))' ' (Ref __cse_2))))') - self.assertEqual(expected, CSERenderer()(div)) + assert expected == CSERenderer()(div) def test_cse_ifs(self): outer_repeated = ir.I32(5) @@ -402,7 +402,7 @@ def test_cse_ifs(self): ' (I32 5)))' ' (I32 5))' ) - self.assertEqual(expected, CSERenderer()(cond)) + assert expected == CSERenderer()(cond) def test_agg_cse(self): x = ir.GetField(ir.Ref('row'), 'idx') @@ -416,7 +416,7 @@ def test_agg_cse(self): ' (Let __cse_2 (ApplyAggOp AggOp () None' ' ((ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))))' ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2)))))') - self.assertEqual(expected, CSERenderer()(table_agg)) + assert expected == CSERenderer()(table_agg) def test_init_op(self): x = ir.I32(5) @@ -434,4 +434,4 @@ def test_init_op(self): ' ((Let __cse_3 (I32 5)' ' (ApplyBinaryPrimOp `+` (Ref __cse_3) (Ref __cse_3))))' ' ((ApplyBinaryPrimOp `+` (Ref __cse_4) (Ref __cse_4)))))))') - self.assertEqual(expected, CSERenderer()(top)) + assert expected == CSERenderer()(top) From eadff26459ded3fc010e5baab46120c5c54abf4e Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 11 Sep 2019 11:30:38 -0400 Subject: [PATCH 45/47] address comments --- hail/python/hail/ir/base_ir.py | 13 ++--- hail/python/hail/ir/ir.py | 6 +-- hail/python/hail/ir/renderer.py | 93 ++++++++++++++++++--------------- 3 files changed, 58 insertions(+), 54 deletions(-) diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index 8e6c199bfad..d33e835e6de 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -1,10 +1,7 @@ import abc -from typing import List, Tuple - from hail.utils.java import Env -from hail.expr.types import HailType -from .renderer import Renderer, Renderable, RenderableStr +from .renderer import Renderer, PlainRenderer, Renderable def _env_bind(env, bindings): @@ -26,10 +23,10 @@ def __init__(self, *children): self.children = children def __str__(self): - r = Renderer(stop_at_jir=False) + r = PlainRenderer(stop_at_jir=False) return r(self) - def render_head(self, r): + def render_head(self, r: Renderer): head_str = self.head_str() if head_str != '': head_str = f' {head_str}' @@ -38,13 +35,13 @@ def render_head(self, r): trailing_space = ' ' return f'({self._ir_name()}{head_str}{trailing_space}' - def render_tail(self, r): + def render_tail(self, r: Renderer): return ')' def _ir_name(self): return self.__class__.__name__ - def render_children(self, r): + def render_children(self, r: Renderer): return self.children def head_str(self): diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 7be6ef3bbe5..f0723a44491 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -1512,13 +1512,13 @@ def __init__(self, field, child): self.field = field self.child = child - def render_head(self, r: 'Renderer'): + def render_head(self, r: Renderer): return f'({self.field} ' - def render_tail(self, r: 'Renderer'): + def render_tail(self, r: Renderer): return ')' - def render_children(self, r: 'Renderer'): + def render_children(self, r: Renderer): return [self.child] @staticmethod diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 2ef32b40a1a..fc1afa57575 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -1,6 +1,6 @@ from hail import ir import abc -from typing import Sequence, List, Set, Dict, Callable, Tuple, Optional +from typing import Sequence, MutableSequence, List, Set, Dict, Optional from collections import namedtuple @@ -89,7 +89,13 @@ def is_empty(self) -> bool: return self._idx < 0 -class Renderer(object): +class Renderer: + @abc.abstractmethod + def add_jir(self, jir): + pass + + +class PlainRenderer(Renderer): def __init__(self, stop_at_jir=False): self.stop_at_jir = stop_at_jir self.count = 0 @@ -146,8 +152,40 @@ def __call__(self, x: 'Renderable'): 'depth lifted_lets agg_lifted_lets scan_lifted_lets') +class CSERenderer(Renderer): + def __init__(self, stop_at_jir=False): + self.stop_at_jir = stop_at_jir + self.jir_count = 0 + self.jirs = {} + self.memo: Dict[int, Sequence[str]] = {} + + def add_jir(self, jir): + jir_id = f'm{self.jir_count}' + self.jir_count += 1 + self.jirs[jir_id] = jir + return jir_id + + def _add_jir(self, node): + jir_id = self.add_jir(node._jir) + if isinstance(node, ir.MatrixIR): + jref = f'(JavaMatrix {jir_id})' + elif isinstance(node, ir.TableIR): + jref = f'(JavaTable {jir_id})' + elif isinstance(node, ir.BlockMatrixIR): + jref = f'(JavaBlockMatrix {jir_id})' + else: + assert isinstance(node, ir.IR) + jref = f'(JavaIR {jir_id})' + + self.memo[id(node)] = jref + + def __call__(self, root: 'ir.BaseIR') -> str: + binding_sites = CSEAnalysisPass(self)(root) + return CSEPrintPass(self)(root, binding_sites) + + class CSEAnalysisPass: - def __init__(self, renderer): + def __init__(self, renderer: CSERenderer): self.renderer = renderer self.uid_count = 0 @@ -230,6 +268,7 @@ def __call__(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: else: lets = None + # 'lets' is either assigned before one of the 'br/has if lets is not None: # we've seen 'child' before, should not traverse (or we will # find too many lifts) @@ -317,7 +356,7 @@ def __init__(self, min_binding_depth: int, min_value_binding_depth: int, self._child_bindings = None def has_lifted_lets(self) -> bool: - return self.lifted_lets or self.agg_lifted_lets or self.scan_lifted_lets + return bool(self.lifted_lets or self.agg_lifted_lets or self.scan_lifted_lets) def make_binding_site(self, depth): return BindingSite( @@ -383,7 +422,7 @@ def update_free_vars(self, child_frame): class CSEPrintPass: - def __init__(self, renderer): + def __init__(self, renderer: CSERenderer): self.renderer = renderer # At top of main loop, we are considering the 'Renderable' 'node' and its @@ -554,7 +593,7 @@ class StackFrame: def __init__(self, node: Renderable, children: Sequence[Renderable], - builder: Sequence[str], + builder: MutableSequence[str], min_binding_depth: int, min_value_binding_depth: int, scan_scope: bool, @@ -600,7 +639,7 @@ def __init__(self, # parent's 'builder', to save copying. self.builder = builder - def add_lets(self, let_bodies: Sequence[str], out_builder: Sequence[str]): + def add_lets(self, let_bodies: Sequence[str], out_builder: MutableSequence[str]): for let_body in let_bodies: out_builder.extend(let_body) out_builder.extend(self.builder) @@ -611,8 +650,8 @@ def add_lets(self, let_bodies: Sequence[str], out_builder: Sequence[str]): def make_child_frame(self, renderer: 'CSERenderer', binding_sites: Dict[int, BindingSite], - builder: Sequence[str], - bindings_stack: 'Sequence[CSEPrintPass.BindingsStackFrame]', + builder: MutableSequence[str], + bindings_stack: 'MutableSequence[CSEPrintPass.BindingsStackFrame]', min_binding_depth: int, min_value_binding_depth: int, scan_scope: bool, @@ -626,8 +665,8 @@ def make_child_frame(self, def make(node: Renderable, renderer: 'CSERenderer', binding_sites: Dict[int, BindingSite], - builder: Sequence[str], - bindings_stack: 'Sequence[CSEPrintPass.BindingsStackFrame]', + builder: MutableSequence[str], + bindings_stack: 'MutableSequence[CSEPrintPass.BindingsStackFrame]', min_binding_depth: int, min_value_binding_depth: int, scan_scope: bool, @@ -663,35 +702,3 @@ def make_bindings_stack_frame(site: BindingSite): agg_visited={}, scan_visited={}, let_bodies=[]) - - -class CSERenderer: - def __init__(self, stop_at_jir=False): - self.stop_at_jir = stop_at_jir - self.jir_count = 0 - self.jirs = {} - self.memo: Dict[int, Sequence[str]] = {} - - def add_jir(self, jir): - jir_id = f'm{self.jir_count}' - self.jir_count += 1 - self.jirs[jir_id] = jir - return jir_id - - def _add_jir(self, node): - jir_id = self.add_jir(node._jir) - if isinstance(node, ir.MatrixIR): - jref = f'(JavaMatrix {jir_id})' - elif isinstance(node, ir.TableIR): - jref = f'(JavaTable {jir_id})' - elif isinstance(node, ir.BlockMatrixIR): - jref = f'(JavaBlockMatrix {jir_id})' - else: - assert isinstance(node, ir.IR) - jref = f'(JavaIR {jir_id})' - - self.memo[id(node)] = jref - - def __call__(self, root: 'ir.BaseIR') -> str: - binding_sites = CSEAnalysisPass(self)(root) - return CSEPrintPass(self)(root, binding_sites) From cb8fd48deccf523c660ea1371418fbb200064b12 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Fri, 13 Sep 2019 12:32:38 -0400 Subject: [PATCH 46/47] fix --- hail/python/hail/backend/backend.py | 4 ++-- hail/python/hail/experimental/function.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index 9c2fc5ab629..6a5ea369d9d 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -193,7 +193,7 @@ def __init__(self): def _to_java_ir(self, ir): if not hasattr(ir, '_jir'): - r = Renderer(stop_at_jir=True) + r = CSERenderer(stop_at_jir=True) # FIXME parse should be static ir._jir = ir.parse(r(ir), ir_map=r.jirs) return ir._jir @@ -221,7 +221,7 @@ def fs(self): return self._fs def _render(self, ir): - r = Renderer() + r = CSERenderer() assert len(r.jirs) == 0 return r(ir) diff --git a/hail/python/hail/experimental/function.py b/hail/python/hail/experimental/function.py index d3153ed9389..422543ee774 100644 --- a/hail/python/hail/experimental/function.py +++ b/hail/python/hail/experimental/function.py @@ -22,7 +22,7 @@ def define_function(f, *param_types, _name=None): body = f(*(construct_expr(Ref(pn), pt) for pn, pt in zip(param_names, param_types))) ret_type = body.dtype - r = Renderer(stop_at_jir=True) + r = CSERenderer(stop_at_jir=True) code = r(body._ir) jbody = body._ir.parse(code, ref_map=dict(zip(param_names, param_types)), ir_map=r.jirs) From 9f4c2983e1777227e8cabb4e85bc28ffeb143edd Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Fri, 13 Sep 2019 13:48:07 -0400 Subject: [PATCH 47/47] allow cse on ir with free variables --- hail/python/hail/experimental/function.py | 5 +++-- hail/python/hail/ir/renderer.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/hail/python/hail/experimental/function.py b/hail/python/hail/experimental/function.py index 422543ee774..5f484436447 100644 --- a/hail/python/hail/experimental/function.py +++ b/hail/python/hail/experimental/function.py @@ -1,5 +1,6 @@ from hail.utils.java import Env -from hail.ir import Apply, Ref, Renderer, register_session_function +from hail.ir import Apply, Ref, register_session_function +from hail.ir.renderer import CSERenderer from hail.expr.types import hail_type from hail.expr.expressions import construct_expr, anytype, expr_any, unify_all from hail.typecheck import typecheck, nullable @@ -23,7 +24,7 @@ def define_function(f, *param_types, _name=None): ret_type = body.dtype r = CSERenderer(stop_at_jir=True) - code = r(body._ir) + code = r(body._ir, free_vars=param_names) jbody = body._ir.parse(code, ref_map=dict(zip(param_names, param_types)), ir_map=r.jirs) Env.hail().expr.ir.functions.IRFunctionRegistry.pyRegisterIR( diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index fc1afa57575..bc079ad6829 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -179,8 +179,10 @@ def _add_jir(self, node): self.memo[id(node)] = jref - def __call__(self, root: 'ir.BaseIR') -> str: - binding_sites = CSEAnalysisPass(self)(root) + def __call__(self, root: 'ir.BaseIR', free_vars=None) -> str: + if not free_vars: + free_vars = {} + binding_sites = CSEAnalysisPass(self)(root, free_vars) return CSEPrintPass(self)(root, binding_sites) @@ -205,8 +207,10 @@ def uid(self) -> str: # 'BindingSite' recording the depth of 'node' and a Dict 'lifted_lets', # where for each descendant 'x' which will be bound above 'node', # 'lifted_lets' maps 'id(x)' to the unique id 'x' will be bound to. - def __call__(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: - root_frame = self.StackFrame(0, 0, False, ({}, {}, {}), root) + def __call__(self, root: 'ir.BaseIR', free_vars) -> Dict[int, BindingSite]: + root_frame = self.StackFrame(0, 0, False, + ({var: 0 for var in free_vars}, {}, {}), + root) stack = [root_frame] binding_sites = {}