Skip to content

Commit

Permalink
[hail] Parameterize check on union_rows. Don't check in split_multi (#6669)
Browse files Browse the repository at this point in the history

* [hail] Parameterize check on union_rows. Don't check in `split_multi`

When I timed the specific code block we hid behind the flag (the column-key
equality check in `union_rows`, now guarded by `_check_cols`), it took
about 1.5 seconds on the benchmark dataset.

* bump!

* bump!
  • Loading branch information
tpoterba authored and danking committed Jul 18, 2019
1 parent 7260fa8 commit 205ba0f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
17 changes: 9 additions & 8 deletions hail/python/hail/matrixtable.py
Expand Up @@ -3458,8 +3458,8 @@ def _select_globals(self, caller, s) -> 'MatrixTable':
analyze(caller, s, self._global_indices)
return cleanup(MatrixTable(MatrixMapGlobals(base._mir, s._ir)))

@typecheck(datasets=matrix_table_type)
def union_rows(*datasets: 'MatrixTable') -> 'MatrixTable':
@typecheck(datasets=matrix_table_type, _check_cols=bool)
def union_rows(*datasets: 'MatrixTable', _check_cols=True) -> 'MatrixTable':
"""Take the union of dataset rows.
Examples
Expand Down Expand Up @@ -3538,12 +3538,13 @@ def union_rows(*datasets: 'MatrixTable') -> 'MatrixTable':
raise ValueError(error_msg.format(
"col key types", 0, first.col_key.dtype, i+1, next.col_key.dtype
))
wrong_keys = hl.eval(hl.rbind(first.col_key.collect(_localize=False), lambda first_keys: (
hl.zip_with_index([mt.col_key.collect(_localize=False) for mt in datasets[1:]])
.find(lambda x: ~(x[1] == first_keys))[0])))
if wrong_keys is not None:
raise ValueError("'MatrixTable.union_rows' expects all datasets to have the same columns. " +
"Datasets 0 and {} have different columns (or possibly different order).".format(wrong_keys+1))
if _check_cols:
wrong_keys = hl.eval(hl.rbind(first.col_key.collect(_localize=False), lambda first_keys: (
hl.zip_with_index([mt.col_key.collect(_localize=False) for mt in datasets[1:]])
.find(lambda x: ~(x[1] == first_keys))[0])))
if wrong_keys is not None:
raise ValueError(f"'MatrixTable.union_rows' expects all datasets to have the same columns. " +
f"Datasets 0 and {wrong_keys+1} have different columns (or possibly different order).")
return MatrixTable(MatrixUnionRows(*[d._mir for d in datasets]))

@typecheck_method(other=matrix_table_type)
Expand Down
4 changes: 2 additions & 2 deletions hail/python/hail/methods/statgen.py
Expand Up @@ -2029,7 +2029,7 @@ def make_array(cond):

left = split_rows(make_array(lambda locus: locus == ds['locus']), False)
moved = split_rows(make_array(lambda locus: locus != ds['locus']), True)
return left.union(moved) if is_table else left.union_rows(moved)
return left.union(moved) if is_table else left.union_rows(moved, _check_cols=False)


@typecheck(ds=oneof(Table, MatrixTable),
Expand Down Expand Up @@ -2983,7 +2983,7 @@ def filter_alleles(mt: MatrixTable,

right = mt.filter_rows((mt.locus != mt.__new_locus) | (mt.alleles != mt.__new_alleles))
right = right.key_rows_by(locus=right.__new_locus, alleles=right.__new_alleles)
return left.union_rows(right).drop('__allele_inclusion', '__new_locus', '__new_alleles')
return left.union_rows(right, _check_cols=False).drop('__allele_inclusion', '__new_locus', '__new_alleles')


@typecheck(mt=MatrixTable, f=anytype, subset=bool)
Expand Down

0 comments on commit 205ba0f

Please sign in to comment.