Fix Python generating Terrible, Horrible, No Good, Very Bad IR (#5107)
* Fixes #5100

Fix Python generating Terrible, Horrible, No Good, Very Bad IR

This PR changes the Python select and key_by operators to generate
the IR we'd expect them to be generating (e.g. `ht.select('x')` emits
a `SelectFields` instead of a `MakeStruct`).
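
For example, a minimal sketch using `hl.utils.range_table` (`_tir` is an internal attribute, referenced here only as a way to peek at the generated IR):

    import hail as hl

    ht = hl.utils.range_table(10)
    ht = ht.annotate(x=ht.idx * 2)

    # A name-only selection like this should now lower to a SelectFields node;
    # previously the row was rebuilt field-by-field with MakeStruct/InsertFields.
    ir = ht.select('x')._tir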

In the process, I found and fixed a bug in group expressions for
`GroupedMatrixTable`. This is tested for both tables and matrix tables
in the new tests in `test_table` and `test_matrix_table`.
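
The new tests exercise grouped aggregation with computed keys, roughly along these lines (a sketch with illustrative field names, not the exact test code):

    ht = hl.utils.range_table(100)
    result = ht.group_by(bucket=ht.idx % 5).aggregate(n=hl.agg.count())

    mt = hl.balding_nichols_model(3, 25, 50)
    per_pop = mt.group_cols_by(mt.pop).aggregate(
        n_het=hl.agg.count_where(mt.GT.is_het()))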

Some timings:

    >>> mt = hl.read_matrix_table('/Users/tpoterba/data/profile.mt')
    >>> %timeit mt.select_entries('GT')._force_count_rows()

master:

    1.64 s ± 106 ms per loop

PR:

    967 ms ± 61.1 ms per loop

* Fix groupby

* address comments

* Fix

* add unfilter to fix pca

* address check_annotate_exprs issue
tpoterba authored and danking committed Jan 22, 2019
1 parent dc87d80 commit 8c29706
Showing 10 changed files with 423 additions and 362 deletions.
29 changes: 20 additions & 9 deletions hail/python/hail/expr/expressions/indices.py
@@ -1,12 +1,18 @@
from hail.typecheck import *
import hail as hl

from typing import List


class Indices(object):
@typecheck_method(source=anytype, axes=setof(str))
def __init__(self, source=None, axes=set()):
self.source = source
self.axes = axes
self._cached_key = None

def __hash__(self):
return 37 + hash((self.source, *self.axes))

def __eq__(self, other):
return isinstance(other, Indices) and self.source is other.source and self.axes == other.axes
@@ -31,24 +37,29 @@ def unify(*indices):
return Indices(src, axes)

@property
def key(self):
def protected_key(self) -> List[str]:
if self._cached_key is None:
self._cached_key = self._get_key()
return self._cached_key
else:
return self._cached_key

def _get_key(self):
if self.source is None:
return None
return []
elif isinstance(self.source, hl.Table):
if self == self.source._row_indices:
return self.source.key
return list(self.source.key)
else:
return None
return []
else:
assert isinstance(self.source, hl.MatrixTable)
if self == self.source._row_indices:
return self.source.row_key
return list(self.source.row_key)
elif self == self.source._col_indices:
return self.source.col_key
elif self == self.source._entry_indices:
return hl.struct(**self.source.row_key, **self.source.col_key)
return list(self.source.col_key)
else:
return None
return []

def __str__(self):
return 'Indices(axes={}, source={})'.format(self.axes, self.source)
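
A quick illustration of the new behaviour (a sketch; `_row_indices`, `_global_indices`, and `protected_key` are internal attributes):

    ht = hl.utils.range_table(5)  # keyed by 'idx'

    # protected_key now returns a list of key field names ([] when there is
    # no key for these indices) instead of a key struct or None.
    assert ht._row_indices.protected_key == ['idx']
    assert ht._global_indices.protected_key == []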
14 changes: 14 additions & 0 deletions hail/python/hail/expr/expressions/typed_expressions.py
@@ -1334,6 +1334,20 @@ def __ne__(self, other):
def __nonzero__(self):
return Expression.__nonzero__(self)

def _annotate_ordered(self, insertions_dict, field_order):
def get_type(field):
e = insertions_dict.get(field)
if e is None:
e = self._fields[field]
return e.dtype

new_type = hl.tstruct(**{f: get_type(f) for f in field_order})
indices, aggregations = unify_all(self, *insertions_dict.values())
return construct_expr(InsertFields(self._ir, [(field, expr._ir) for field, expr in insertions_dict.items()], field_order),
new_type,
indices,
aggregations)
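
`_annotate_ordered` builds an `InsertFields` node with an explicit field order; the user-visible behaviour it supports is that `StructExpression.annotate` keeps existing fields in place and appends new ones (a minimal sketch):

    s = hl.struct(a=1, b=2)
    s2 = s.annotate(b=10, c=3)

    # Existing fields keep their positions and new fields are appended,
    # so the resulting type is struct{a: int32, b: int32, c: int32}.
    print(s2.dtype)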

@typecheck_method(named_exprs=expr_any)
def annotate(self, **named_exprs):
"""Add new fields or recompute existing fields.
458 changes: 214 additions & 244 deletions hail/python/hail/matrixtable.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions hail/python/hail/methods/statgen.py
@@ -1384,6 +1384,7 @@ def hwe_normalized_pca(call_expr, k=10, compute_loadings=False) -> Tuple[List[fl
mt = mt.annotate_rows(__mean_gt=mt.__AC / mt.__n_called)
mt = mt.annotate_rows(
__hwe_scaled_std_dev=hl.sqrt(mt.__mean_gt * (2 - mt.__mean_gt) * n_variants / 2))
mt = mt._unfilter_entries()

normalized_gt = hl.or_else((mt.__gt - mt.__mean_gt) / mt.__hwe_scaled_std_dev, 0.0)
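
The added `_unfilter_entries` call makes filtered entries explicitly missing, so the `hl.or_else(..., 0.0)` normalization can zero-fill them. A usage sketch (dataset and filter are illustrative):

    mt = hl.balding_nichols_model(3, 50, 200)
    mt = mt.filter_entries(hl.rand_bool(0.95))  # filtered entries become missing, then 0.0
    eigenvalues, scores, loadings = hl.hwe_normalized_pca(mt.GT, k=5)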

126 changes: 54 additions & 72 deletions hail/python/hail/table.py
@@ -112,9 +112,9 @@ class GroupedTable(ExprContainer):
and :meth:`.GroupedTable.aggregate`.
"""

def __init__(self, parent: 'Table', groups):
def __init__(self, parent: 'Table', key_expr):
super(GroupedTable, self).__init__()
self._groups = groups
self._key_expr = key_expr
self._parent = parent
self._npartitions = None
self._buffer_size = 50
@@ -208,23 +208,19 @@ def aggregate(self, **named_exprs) -> 'Table':
:class:`.Table`
Aggregated table.
"""
if self._groups is None:
raise ValueError('GroupedTable cannot be aggregated if no groupings are specified.')

group_exprs = dict(self._groups)

for name, expr in named_exprs.items():
analyze(f'GroupedTable.aggregate: ({repr(name)})', expr, self._parent._global_indices, {self._parent._row_axis})
if not named_exprs.keys().isdisjoint(group_exprs.keys()):
intersection = set(named_exprs.keys()) & set(group_exprs.keys())
if not named_exprs.keys().isdisjoint(set(self._key_expr)):
intersection = set(named_exprs.keys()) & set(self._key_expr)
raise ValueError(
f'GroupedTable.aggregate: Group names and aggregration expression names overlap: {intersection}')

base, _ = self._parent._process_joins(*group_exprs.values(), *named_exprs.values())
base, _ = self._parent._process_joins(self._key_expr, *named_exprs.values())

key_struct = self._key_expr
return Table(TableKeyByAndAggregate(base._tir,
hl.struct(**named_exprs)._ir,
hl.struct(**group_exprs)._ir,
key_struct._ir,
self._npartitions,
self._buffer_size))
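
With the group key now carried as a struct expression (`_key_expr`), the disjointness check compares aggregation names against the key field names directly; reusing a name raises, as in this sketch:

    ht = hl.utils.range_table(100)
    grouped = ht.group_by(bucket=ht.idx % 5)
    agg_ht = grouped.aggregate(n=hl.agg.count())
    # grouped.aggregate(bucket=hl.agg.count())  # ValueError: names overlap with group names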

@@ -535,20 +531,22 @@ def key_by(self, *keys, **named_keys) -> 'Table':
:class:`.Table`
Table with a new key.
"""
key_fields = get_select_exprs("Table.key_by",
keys, named_keys, self._row_indices,
protect_keys=False)
key_fields, computed_keys = get_key_by_exprs("Table.key_by", keys, named_keys, self._row_indices)

new_row = self.row.annotate(**key_fields)
base, cleanup = self._process_joins(new_row)
if not computed_keys:
return Table(TableKeyBy(self._tir, key_fields))
else:
new_row = self.row.annotate(**computed_keys)
base, cleanup = self._process_joins(new_row)

return cleanup(Table(
TableKeyBy(
TableMapRows(
TableKeyBy(base._tir, []),
new_row._ir),
list(key_fields))))
return cleanup(Table(
TableKeyBy(
TableMapRows(
TableKeyBy(base._tir, []),
new_row._ir),
list(key_fields))))
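
The practical effect: keying by existing fields now emits a bare `TableKeyBy`, while computed keys still go through `TableMapRows` (a sketch):

    ht = hl.utils.range_table(10)
    ht = ht.annotate(x=ht.idx * 2)

    ht.key_by('x')           # existing field: TableKeyBy only
    ht.key_by(y=ht.idx % 3)  # computed key: TableMapRows + TableKeyBy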

@typecheck_method(named_exprs=expr_any)
def annotate_globals(self, **named_exprs) -> 'Table':
"""Add new global fields.
@@ -573,9 +571,8 @@ def annotate_globals(self, **named_exprs) -> 'Table':
:class:`.Table`
Table with new global field(s).
"""
named_exprs = {k: to_expr(v) for k, v in named_exprs.items()}
for k, v in named_exprs.items():
check_collisions(self._fields, k, self._global_indices)
caller = 'Table.annotate_globals'
check_annotate_exprs(caller, named_exprs, self._global_indices)
return self._select_globals('Table.annotate_globals', self.globals.annotate(**named_exprs))

def select_globals(self, *exprs, **named_exprs) -> 'Table':
@@ -614,23 +611,16 @@ def select_globals(self, *exprs, **named_exprs) -> 'Table':
:class:`.Table`
Table with specified global fields.
"""
exprs = [self[e] if not isinstance(e, Expression) else e for e in exprs]
named_exprs = {k: to_expr(v) for k, v in named_exprs.items()}
assignments = OrderedDict()

for e in exprs:
if not e._ir.is_nested_field:
raise ExpressionException("method 'select_globals' expects keyword arguments for complex expressions")
assert isinstance(e._ir, GetField)
assignments[e._ir.name] = e

for k, e in named_exprs.items():
check_collisions(self._fields, k, self._global_indices)
assignments[k] = e
caller = 'Table.select_globals'
new_globals = get_select_exprs(caller,
exprs,
named_exprs,
self._global_indices,
self._globals)

check_field_uniqueness(assignments.keys())
return self._select_globals('Table.select_globals', hl.struct(**assignments))
return self._select_globals(caller, new_globals)
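
`select_globals` now routes through the same `get_select_exprs` helper used for row selection; usage is unchanged (a sketch):

    ht = hl.utils.range_table(5)
    ht = ht.annotate_globals(pop='EUR', version=1)
    ht = ht.select_globals('version', dataset='1kg')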

@typecheck_method(named_exprs=expr_any)
def transmute_globals(self, **named_exprs) -> 'Table':
"""Similar to :meth:`.Table.annotate_globals`, but drops referenced fields.
@@ -656,13 +646,14 @@ def transmute_globals(self, **named_exprs) -> 'Table':
:class:`.Table`
"""
caller = 'Table.transmute_globals'
e = get_annotate_exprs(caller, named_exprs, self._global_indices)
fields_referenced = extract_refs_by_indices(e.values(), self._global_indices) - set(e.keys())
check_annotate_exprs(caller, named_exprs, self._global_indices)
fields_referenced = extract_refs_by_indices(named_exprs.values(), self._global_indices) - set(named_exprs.keys())

return self._select_globals(caller,
self.globals.annotate(**named_exprs).drop(*fields_referenced))


@typecheck_method(named_exprs=expr_any)
def transmute(self, **named_exprs) -> 'Table':
"""Add new fields and drop fields referenced.
@@ -723,12 +714,13 @@ def transmute(self, **named_exprs) -> 'Table':
Table with transmuted fields.
"""
caller = "Table.transmute"
e = get_annotate_exprs(caller, named_exprs, self._row_indices)
fields_referenced = extract_refs_by_indices(e.values(), self._row_indices) - set(e.keys())
check_annotate_exprs(caller, named_exprs, self._row_indices)
fields_referenced = extract_refs_by_indices(named_exprs.values(), self._row_indices) - set(named_exprs.keys())
fields_referenced -= set(self.key)

return self._select(caller, self.row.annotate(**e).drop(*fields_referenced))
return self._select(caller, self.row.annotate(**named_exprs).drop(*fields_referenced))
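
`transmute` behaviour is unchanged: new fields are added and the non-key row fields they reference are dropped (a sketch):

    ht = hl.utils.range_table(10)
    ht = ht.annotate(a=ht.idx + 1, b=ht.idx * 2)
    ht = ht.transmute(c=ht.a + ht.b)  # drops 'a' and 'b'; keeps key 'idx' and 'c'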

@typecheck_method(named_exprs=expr_any)
def annotate(self, **named_exprs) -> 'Table':
"""Add new fields.
@@ -755,8 +747,8 @@ def annotate(self, **named_exprs) -> 'Table':
Table with new fields.
"""
caller = "Table.annotate"
e = get_annotate_exprs(caller, named_exprs, self._row_indices)
return self._select(caller, self.row.annotate(**e))
check_annotate_exprs(caller, named_exprs, self._row_indices)
return self._select(caller, self.row.annotate(**named_exprs))

@typecheck_method(expr=expr_bool,
keep=bool)
@@ -894,10 +886,11 @@ def select(self, *exprs, **named_exprs) -> 'Table':
:class:`.Table`
Table with specified fields.
"""
row_exprs = get_select_exprs('Table.select',
exprs, named_exprs, self._row_indices,
protect_keys=True)
row = self.key.annotate(**row_exprs)
row = get_select_exprs('Table.select',
exprs,
named_exprs,
self._row_indices,
self._row)

return self._select('Table.select', row)
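
`get_select_exprs` now builds the new row directly from `self._row`, with key fields kept automatically; usage stays the same (a sketch):

    ht = hl.utils.range_table(10)
    ht = ht.annotate(a=ht.idx + 1, b=ht.idx * 2)
    ht = ht.select('a', doubled=ht.b * 2)  # row becomes: idx (key), a, doubled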

@@ -956,14 +949,14 @@ def drop(self, *exprs) -> 'Table':
table = self
if any(self._fields[field]._indices == self._global_indices for field in fields_to_drop):
# need to drop globals
new_global_fields = [f for f in table.globals if
f not in fields_to_drop]
table = table.select_globals(*new_global_fields)
table = table._select_globals('drop',
self._globals.drop(*[f for f in table.globals if f in fields_to_drop]))

if any(self._fields[field]._indices == self._row_indices for field in fields_to_drop):
# need to drop row fields
protected_key = set(self._row_indices.protected_key)
for f in fields_to_drop:
check_keys(f, self._row_indices)
check_keys('drop', f, protected_key)
row_fields = set(table.row)
to_drop = [f for f in fields_to_drop if f in row_fields]
table = table._select('drop', table.row.drop(*to_drop))
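
Row-field drops are checked against the new `protected_key` list, so dropping a key field still fails; clear the key first (a sketch, error message paraphrased):

    ht = hl.utils.range_table(10)
    ht = ht.annotate(a=ht.idx + 1)
    ht = ht.drop('a')              # fine
    # ht.drop('idx')               # error: 'idx' is a key field
    ht = ht.key_by().drop('idx')   # remove the key, then drop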
@@ -1099,23 +1092,12 @@ def group_by(self, *exprs, **named_exprs) -> 'GroupedTable':
:class:`.GroupedTable`
Grouped table; use :meth:`.GroupedTable.aggregate` to complete the aggregation.
"""
groups = []
for e in exprs:
if isinstance(e, str):
e = self[e]
else:
e = to_expr(e)
analyze('Table.group_by', e, self._row_indices)
if not e._ir.is_nested_field:
raise ExpressionException("method 'group_by' expects keyword arguments for complex expressions")
key = e._ir.name
groups.append((key, e))
for k, e in named_exprs.items():
e = to_expr(e)
analyze('Table.group_by', e, self._row_indices)
groups.append((k, e))

return GroupedTable(self, groups)
key, computed_key = get_key_by_exprs('Table.group_by',
exprs,
named_exprs,
self._row_indices,
override_protected_indices={self._global_indices})
return GroupedTable(self, self.row.annotate(**computed_key).select(*key))
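
`group_by` now builds its key with `get_key_by_exprs`, the same helper `key_by` uses, so string field names and computed keys mix freely (a sketch):

    ht = hl.utils.range_table(100)
    ht = ht.annotate(parity=ht.idx % 2)
    out = ht.group_by('parity', bucket=ht.idx % 5).aggregate(n=hl.agg.count())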

@typecheck_method(expr=expr_any, _localize=bool)
def aggregate(self, expr, _localize=True):
