Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python IR binding and environment infrastructure #7000

Merged
merged 5 commits into from Sep 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 33 additions & 16 deletions hail/python/hail/expr/matrix_type.py
Expand Up @@ -122,22 +122,39 @@ def _rename(self, global_map, col_map, row_map, entry_map):
[row_map.get(k, k) for k in self.row_key],
self.entry_type._rename(entry_map))

def global_env(self):
return {'global': self.global_type}

def row_env(self):
return {'global': self.global_type,
'va': self.row_type}

def col_env(self):
return {'global': self.global_type,
'sa': self.col_type}

def entry_env(self):
return {'global': self.global_type,
'va': self.row_type,
'sa': self.col_type,
'g': self.entry_type}
def global_env(self, default_value=None):
if default_value is None:
return {'global': self.global_type}
else:
return {'global': default_value}

def row_env(self, default_value=None):
if default_value is None:
return {'global': self.global_type,
'va': self.row_type}
else:
return {'global': default_value,
'va': default_value}

def col_env(self, default_value=None):
if default_value is None:
return {'global': self.global_type,
'sa': self.col_type}
else:
return {'global': default_value,
'sa': default_value}

def entry_env(self, default_value=None):
if default_value is None:
return {'global': self.global_type,
'va': self.row_type,
'sa': self.col_type,
'g': self.entry_type}
else:
return {'global': default_value,
'va': default_value,
'sa': default_value,
'g': default_value}


import pprint
Expand Down
15 changes: 10 additions & 5 deletions hail/python/hail/expr/table_type.py
Expand Up @@ -80,12 +80,17 @@ def _rename(self, global_map, row_map):
self.row_type._rename(row_map),
[row_map.get(k, k) for k in self.row_key])

def row_env(self):
return {'global': self.global_type,
'row': self.row_type}
def row_env(self, default_value=None):
if default_value is None:
return {'global': self.global_type, 'row': self.row_type}
else:
return {'global': default_value, 'row': default_value}

def global_env(self):
return {'global': self.global_type}
def global_env(self, default_value=None):
if default_value is None:
return {'global': self.global_type}
else:
return {'global': default_value}


import pprint
Expand Down
1 change: 1 addition & 0 deletions hail/python/hail/ir/__init__.py
@@ -1,3 +1,4 @@
from .base_ir import *
from .ir import *
from .register_functions import *
from .register_aggregators import *
Expand Down
89 changes: 87 additions & 2 deletions hail/python/hail/ir/base_ir.py
@@ -1,11 +1,24 @@
import abc

from typing import List
from typing import List, Tuple

from hail.utils.java import Env
from hail.expr.types import HailType
from .renderer import Renderer, Renderable, RenderableStr


def _env_bind(env, bindings):
if bindings:
if env:
res = env.copy()
res.update(bindings)
return res
else:
return dict(bindings)
else:
return env


class BaseIR(Renderable):
def __init__(self, *children):
super().__init__()
Expand Down Expand Up @@ -44,7 +57,8 @@ def head_str(self):
def parse(self, code, ref_map, ir_map):
return

@abc.abstractproperty
@property
@abc.abstractmethod
def typ(self):
return

Expand All @@ -71,6 +85,57 @@ def _eq(self, other):
def __hash__(self):
return 31 + hash(str(self))

@abc.abstractmethod
def new_block(self, i: int) -> bool:
...

@staticmethod
def is_effectful() -> bool:
return False

def bindings(self, i: int, default_value=None):
"""Compute variables bound in child 'i'.

Returns
-------
dict
mapping from bound variables to 'default_value', if provided,
otherwise to their types
"""
return {}

def agg_bindings(self, i: int, default_value=None):
return {}

def scan_bindings(self, i: int, default_value=None):
return {}

def uses_agg_context(self, i: int) -> bool:
return False

def uses_scan_context(self, i: int) -> bool:
return False

def child_context_without_bindings(self, i: int, parent_context):
(eval_c, agg_c, scan_c) = parent_context
if self.uses_agg_context(i):
return (agg_c, None, None)
elif self.uses_scan_context(i):
return (scan_c, None, None)
else:
return parent_context

def child_context(self, i: int, parent_context, default_value=None):
base = self.child_context_without_bindings(i, parent_context)
eval_b = self.bindings(i, default_value)
agg_b = self.agg_bindings(i, default_value)
scan_b = self.scan_bindings(i, default_value)
if eval_b or agg_b or scan_b:
(eval_c, agg_c, scan_c) = base
return _env_bind(eval_c, eval_b), _env_bind(agg_c, agg_b), _env_bind(scan_c, scan_b)
else:
return base


class IR(BaseIR):
def __init__(self, *children):
Expand Down Expand Up @@ -117,6 +182,9 @@ def typ(self):
assert self._type is not None, self
return self._type

def new_block(self, i: int) -> bool:
return False

@abc.abstractmethod
def _compute_type(self, env, agg_env):
raise NotImplementedError(self)
Expand All @@ -143,9 +211,15 @@ def typ(self):
assert self._type is not None, self
return self._type

def new_block(self, i: int) -> bool:
return True

def parse(self, code, ref_map={}, ir_map={}):
return Env.hail().expr.ir.IRParser.parse_table_ir(code, ref_map, ir_map)

global_env = {'global'}
row_env = {'global', 'row'}


class MatrixIR(BaseIR):
def __init__(self, *children):
Expand All @@ -162,9 +236,17 @@ def typ(self):
assert self._type is not None, self
return self._type

def new_block(self, i: int) -> bool:
return True

def parse(self, code, ref_map={}, ir_map={}):
return Env.hail().expr.ir.IRParser.parse_matrix_ir(code, ref_map, ir_map)

global_env = {'global'}
row_env = {'global', 'va'}
col_env = {'global', 'sa'}
entry_env = {'global', 'sa', 'va', 'g'}


class BlockMatrixIR(BaseIR):
def __init__(self, *children):
Expand All @@ -181,6 +263,9 @@ def typ(self):
assert self._type is not None, self
return self._type

def new_block(self, i: int) -> bool:
return True

def parse(self, code, ref_map={}, ir_map={}):
return Env.hail().expr.ir.IRParser.parse_blockmatrix_ir(code, ref_map, ir_map)

Expand Down
24 changes: 24 additions & 0 deletions hail/python/hail/ir/blockmatrix_ir.py
Expand Up @@ -37,6 +37,16 @@ def __init__(self, child, f):
def _compute_type(self):
self._type = self.child.typ

def bindings(self, i: int, default_value=None):
if i == 1:
value = self.child.typ.element_type if default_value is None else default_value
return {'element': value}
else:
return {}

def binds(self, i):
return {'element'} if i == 1 else {}


class BlockMatrixMap2(BlockMatrixIR):
@typecheck_method(left=BlockMatrixIR, right=BlockMatrixIR, f=IR)
Expand All @@ -50,6 +60,20 @@ def _compute_type(self):
self.right.typ # Force
self._type = self.left.typ

def bindings(self, i: int, default_value=None):
if i == 2:
if default_value is None:
l_value = self.left.typ.element_type
r_value = self.right.typ.element_type
else:
(l_value, r_value) = (default_value, default_value)
return {'l': l_value, 'r': r_value}
else:
return {}

def binds(self, i):
return {'l', 'r'} if i == 2 else {}


class BlockMatrixDot(BlockMatrixIR):
@typecheck_method(left=BlockMatrixIR, right=BlockMatrixIR)
Expand Down