Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hail] Publicize and document Table.multi_way_zip_join #6488

Merged
merged 1 commit into from Jul 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion hail/python/hail/experimental/vcf_combiner.py
Expand Up @@ -153,7 +153,7 @@ def renumber_entry(entry, old_to_new) -> StructExpression:

def combine_gvcfs(mts):
"""merges vcfs using multi way join"""
ts = hl.Table._multi_way_zip_join([localize(mt) for mt in mts], 'data', 'g')
ts = hl.Table.multi_way_zip_join([localize(mt) for mt in mts], 'data', 'g')
combined = combine(ts)
return unlocalize(combined)

Expand Down
45 changes: 39 additions & 6 deletions hail/python/hail/table.py
Expand Up @@ -328,7 +328,7 @@ class Table(ExprContainer):
@staticmethod
def _from_java(jtir):
return Table(JavaTable(jtir))

def __init__(self, tir):
super(Table, self).__init__()

Expand Down Expand Up @@ -2963,7 +2963,7 @@ def from_spark(df, key=[]) -> 'Table':
----------
df : :class:`.pyspark.sql.DataFrame`
PySpark DataFrame.

key : :obj:`str` or :obj:`list` of :obj:`str`
Key fields.

Expand Down Expand Up @@ -3056,22 +3056,22 @@ def _same(self, other, tolerance=1e-6, absolute=False):
r = other
r = r.select_globals(**{right_global_value: r.globals})
r = r.select(**{right_value: r._value})

t = l._zip_join(r)

if not hl.eval(_values_similar(t[left_global_value], t[right_global_value], tolerance, absolute)):
g = hl.eval(t.globals)
print(f'Table._same: globals differ: {g[left_global_value]}, {g[right_global_value]}')
return False

if not t.all(_values_similar(t[left_value], t[right_value], tolerance, absolute)):
print('Table._same: rows differ:')
t = t.filter(~ _values_similar(t[left_value], t[right_value], tolerance, absolute))
bad_rows = t.take(10)
for r in bad_rows:
print(f' {r[left_value]}, {r[right_value]}')
return False

return True


Expand Down Expand Up @@ -3212,7 +3212,40 @@ def _unlocalize_entries(self, entries_field_name, cols_field_name, col_key) -> '

@staticmethod
@typecheck(tables=sequenceof(table_type), data_field_name=str, global_field_name=str)
def _multi_way_zip_join(tables, data_field_name, global_field_name) -> 'Table':
def multi_way_zip_join(tables, data_field_name, global_field_name) -> 'Table':
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe the behaviour of this will change in a backwards incompatible way. The only stabilization question that I think is unresolved is do we want these parameters to have default values?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding a default value is backwards compatible, so we can punt on that for now. But I also don't see any problem with just picking something reasonable like "row_fields" and "global_fields". I'm fine with either option, and ready to approve once you decide.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leave them off for now then.

"""Combine many tables in a zip join

Notes
-----
The row type of the returned table is a struct with the key fields, and
one extra field, `data_field_name`, which is an array of structs with
the non key fields, one per input. The array elements are missing if
their corresponding input had no row with that key or possibly if there
is another input with more rows with that key than the corresponding
input.

The global type of the returned table is an array of structs of the
global type of all of the inputs.

The types for every input must be identical, not merely compatible,
including the keys.

A zip join is similar to an outer join however rows are not duplicated
to create the full Cartesian product of duplicate keys. Instead, there
is exactly one entry in some `data_field_name` array for every row in
the inputs.

Parameters
----------
tables : :obj:`List[Table]`
A list of tables to combine
data_field_name : :obj:`str`
The name of the resulting data field
global_field_name : :obj:`str`
The name of the resulting global field

.. include:: _templates/experimental.rst
"""
if not tables:
raise ValueError('multi_way_zip_join must have at least one table as an argument')
head = tables[0]
Expand Down
10 changes: 5 additions & 5 deletions hail/python/test/hail/table/test_table.py
Expand Up @@ -453,7 +453,7 @@ def test_multi_way_zip_join(self):
{"id": 3, "name": "z", "data": 0.01}]
s = hl.tstruct(id=hl.tint32, name=hl.tstr, data=hl.tfloat64)
ts = [hl.Table.parallelize(r, schema=s, key='id') for r in [d1, d2, d3]]
joined = hl.Table._multi_way_zip_join(ts, '__data', '__globals').drop('__globals')
joined = hl.Table.multi_way_zip_join(ts, '__data', '__globals').drop('__globals')
dexpected = [{"id": 0, "__data": [{"name": "a", "data": 0.0},
{"name": "d", "data": 1.1},
None]},
Expand All @@ -476,10 +476,10 @@ def test_multi_way_zip_join(self):
self.assertTrue(expected._same(joined))

expected2 = expected.transmute(data=expected['__data'])
joined_same_name = hl.Table._multi_way_zip_join(ts, 'data', 'globals').drop('globals')
joined_same_name = hl.Table.multi_way_zip_join(ts, 'data', 'globals').drop('globals')
self.assertTrue(expected2._same(joined_same_name))

joined_nothing = hl.Table._multi_way_zip_join(ts, 'data', 'globals').drop('data', 'globals')
joined_nothing = hl.Table.multi_way_zip_join(ts, 'data', 'globals').drop('data', 'globals')
self.assertEqual(joined_nothing._force_count(), 5)

def test_multi_way_zip_join_globals(self):
Expand All @@ -490,14 +490,14 @@ def test_multi_way_zip_join_globals(self):
hl.struct(x=hl.null(hl.tint32)),
hl.struct(x=5),
hl.struct(x=0)]))
joined = hl.Table._multi_way_zip_join([t1, t2, t3], '__data', '__globals')
joined = hl.Table.multi_way_zip_join([t1, t2, t3], '__data', '__globals')
self.assertEqual(hl.eval(joined.globals), hl.eval(expected))

def test_multi_way_zip_join_key_downcast(self):
mt = hl.import_vcf(resource('sample.vcf.bgz'))
mt = mt.key_rows_by('locus')
ht = mt.rows()
j = hl.Table._multi_way_zip_join([ht, ht], 'd', 'g')
j = hl.Table.multi_way_zip_join([ht, ht], 'd', 'g')
j._force_count()

def test_index_maintains_count(self):
Expand Down