Skip to content

Commit

Permalink
[query] Add expression source checking to annotate methods (#9207)
Browse files (browse the repository at this point in the history)
* [query] Add expression source checking to annotate methods

Fixes #9121

* fix annotate errors

* broadcast
  • Loading branch information
tpoterba committed Aug 5, 2020
1 parent 6bd1cee commit b8421e6
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 14 deletions.
16 changes: 8 additions & 8 deletions hail/python/hail/matrixtable.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -901,7 +901,7 @@ def annotate_globals(self, **named_exprs) -> 'MatrixTable':
"""

caller = "MatrixTable.annotate_globals"
check_annotate_exprs(caller, named_exprs, self._global_indices)
check_annotate_exprs(caller, named_exprs, self._global_indices, set())
return self._select_globals(caller, self.globals.annotate(**named_exprs))

@typecheck_method(named_exprs=expr_any)
Expand Down Expand Up @@ -953,7 +953,7 @@ def annotate_rows(self, **named_exprs) -> 'MatrixTable':
"""

caller = "MatrixTable.annotate_rows"
check_annotate_exprs(caller, named_exprs, self._row_indices)
check_annotate_exprs(caller, named_exprs, self._row_indices, {self._col_axis})
return self._select_rows(caller, self._rvrow.annotate(**named_exprs))

@typecheck_method(named_exprs=expr_any)
Expand Down Expand Up @@ -1000,7 +1000,7 @@ def annotate_cols(self, **named_exprs) -> 'MatrixTable':
Matrix table with new column-indexed field(s).
"""
caller = "MatrixTable.annotate_cols"
check_annotate_exprs(caller, named_exprs, self._col_indices)
check_annotate_exprs(caller, named_exprs, self._col_indices, {self._row_axis})
return self._select_cols(caller, self.col.annotate(**named_exprs))

@typecheck_method(named_exprs=expr_any)
Expand Down Expand Up @@ -1050,7 +1050,7 @@ def annotate_entries(self, **named_exprs) -> 'MatrixTable':
Matrix table with new row-and-column-indexed field(s).
"""
caller = "MatrixTable.annotate_entries"
check_annotate_exprs(caller, named_exprs, self._entry_indices)
check_annotate_exprs(caller, named_exprs, self._entry_indices, set())
return self._select_entries(caller, s=self.entry.annotate(**named_exprs))

def select_globals(self, *exprs, **named_exprs) -> 'MatrixTable':
Expand Down Expand Up @@ -1822,7 +1822,7 @@ def transmute_globals(self, **named_exprs) -> 'MatrixTable':
:class:`.MatrixTable`
"""
caller = 'MatrixTable.transmute_globals'
check_annotate_exprs(caller, named_exprs, self._global_indices)
check_annotate_exprs(caller, named_exprs, self._global_indices, set())
fields_referenced = extract_refs_by_indices(named_exprs.values(), self._global_indices) - set(named_exprs.keys())
return self._select_globals(caller,
self.globals.annotate(**named_exprs).drop(*fields_referenced))
Expand Down Expand Up @@ -1861,7 +1861,7 @@ def transmute_rows(self, **named_exprs) -> 'MatrixTable':
:class:`.MatrixTable`
"""
caller = 'MatrixTable.transmute_rows'
check_annotate_exprs(caller, named_exprs, self._row_indices)
check_annotate_exprs(caller, named_exprs, self._row_indices, {self._col_axis})
fields_referenced = extract_refs_by_indices(named_exprs.values(), self._row_indices) - set(named_exprs.keys())
fields_referenced -= set(self.row_key)

Expand Down Expand Up @@ -1901,7 +1901,7 @@ def transmute_cols(self, **named_exprs) -> 'MatrixTable':
:class:`.MatrixTable`
"""
caller = 'MatrixTable.transmute_cols'
check_annotate_exprs(caller, named_exprs, self._col_indices)
check_annotate_exprs(caller, named_exprs, self._col_indices, {self._row_axis})
fields_referenced = extract_refs_by_indices(named_exprs.values(), self._col_indices) - set(named_exprs.keys())
fields_referenced -= set(self.col_key)

Expand Down Expand Up @@ -1934,7 +1934,7 @@ def transmute_entries(self, **named_exprs) -> 'MatrixTable':
:class:`.MatrixTable`
"""
caller = 'MatrixTable.transmute_entries'
check_annotate_exprs(caller, named_exprs, self._entry_indices)
check_annotate_exprs(caller, named_exprs, self._entry_indices, set())
fields_referenced = extract_refs_by_indices(named_exprs.values(), self._entry_indices) - set(named_exprs.keys())

return self._select_entries(caller,
Expand Down
8 changes: 4 additions & 4 deletions hail/python/hail/table.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -601,7 +601,7 @@ def annotate_globals(self, **named_exprs) -> 'Table':
Table with new global field(s).
"""
caller = 'Table.annotate_globals'
check_annotate_exprs(caller, named_exprs, self._global_indices)
check_annotate_exprs(caller, named_exprs, self._global_indices, set())
return self._select_globals('Table.annotate_globals', self.globals.annotate(**named_exprs))

def select_globals(self, *exprs, **named_exprs) -> 'Table':
Expand Down Expand Up @@ -675,7 +675,7 @@ def transmute_globals(self, **named_exprs) -> 'Table':
:class:`.Table`
"""
caller = 'Table.transmute_globals'
check_annotate_exprs(caller, named_exprs, self._global_indices)
check_annotate_exprs(caller, named_exprs, self._global_indices, set())
fields_referenced = extract_refs_by_indices(named_exprs.values(), self._global_indices) - set(named_exprs.keys())

return self._select_globals(caller,
Expand Down Expand Up @@ -742,7 +742,7 @@ def transmute(self, **named_exprs) -> 'Table':
Table with transmuted fields.
"""
caller = "Table.transmute"
check_annotate_exprs(caller, named_exprs, self._row_indices)
check_annotate_exprs(caller, named_exprs, self._row_indices, set())
fields_referenced = extract_refs_by_indices(named_exprs.values(), self._row_indices) - set(named_exprs.keys())
fields_referenced -= set(self.key)

Expand Down Expand Up @@ -775,7 +775,7 @@ def annotate(self, **named_exprs) -> 'Table':
Table with new fields.
"""
caller = "Table.annotate"
check_annotate_exprs(caller, named_exprs, self._row_indices)
check_annotate_exprs(caller, named_exprs, self._row_indices, set())
return self._select(caller, self.row.annotate(**named_exprs))

@typecheck_method(expr=expr_bool,
Expand Down
6 changes: 4 additions & 2 deletions hail/python/hail/utils/misc.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -443,9 +443,11 @@ def is_top_level_field(e):
return s


def check_annotate_exprs(caller, named_exprs, indices):
def check_annotate_exprs(caller, named_exprs, indices, agg_axes):
from hail.expr.expressions import analyze
protected_key = set(indices.protected_key)
for k in named_exprs:
for k, v in named_exprs.items():
analyze(f'{caller}: field {k!r}', v, indices, agg_axes, broadcast=True)
check_keys(caller, k, protected_key)
check_collisions(caller, list(named_exprs), indices)
return named_exprs
Expand Down
6 changes: 6 additions & 0 deletions hail/python/test/hail/matrixtable/test_matrix_table.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -1580,6 +1580,12 @@ def test_invalid_field_ref_error(self):
with pytest.raises(hl.expr.ExpressionException, match='Found fields from 2 objects:'):
mt.annotate_entries(x = mt.GT.n_alt_alleles() * mt2.af)

def test_invalid_field_ref_annotate(self):
mt = hl.balding_nichols_model(2, 5, 5)
mt2 = hl.balding_nichols_model(2, 5, 5)
with pytest.raises(hl.expr.ExpressionException, match='source mismatch'):
mt.annotate_entries(x = mt2.af)


def test_read_write_all_types():
mt = create_all_values_matrix_table()
Expand Down

0 comments on commit b8421e6

Please sign in to comment.