Skip to content

Commit

Permalink
[feature] Add unify option to Table.union (#5858)
Browse files Browse the repository at this point in the history
* [feature] Add `unify` option to `Table.union`

* address comments
  • Loading branch information
tpoterba authored and danking committed Apr 11, 2019
1 parent fb0a894 commit 49f8a6c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 15 deletions.
58 changes: 43 additions & 15 deletions hail/python/hail/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1883,8 +1883,8 @@ def add_index(self, name='idx') -> 'Table':

return self.annotate(**{name: hl.scan.count()})

@typecheck_method(tables=table_type)
def union(self, *tables) -> 'Table':
@typecheck_method(tables=table_type, unify=bool)
def union(self, *tables, unify: bool = False) -> 'Table':
"""Union the rows of multiple tables.
Examples
Expand All @@ -1896,32 +1896,60 @@ def union(self, *tables) -> 'Table':
Notes
-----
If a row appears in both tables identically, it is duplicated in the
result. The left and right tables must have the same schema and key.
If a row appears in more than one table identically, it is duplicated
in the result. All tables must have the same key names and types. They
must also have the same row types, unless the `unify` parameter is
``True``, in which case a field appearing in any table will be included
in the result, with missing values for tables that do not contain the
field. If a field appears in multiple tables with incompatible types,
like arrays and strings, then an error will be raised.
Parameters
----------
tables : varargs of :class:`.Table`
Tables to union.
unify : :obj:`bool`
Attempt to unify table fields.
Returns
-------
:class:`.Table`
Table with all rows from each component table.
"""
left_key = list(self.key)
left_key = self.key.dtype
for i, ht, in enumerate(tables):
right_key = list(ht.key)
if not ht.row.dtype == self.row.dtype:
raise ValueError(f"'union': table {i} has a different row type.\n"
f" Expected: {self.row.dtype}\n"
f" Table {i}: {ht.row.dtype}")
elif left_key != right_key:
if left_key != ht.key.dtype:
raise ValueError(f"'union': table {i} has a different key."
f" Expected: {left_key}\n"
f" Table {i}: {right_key}")
return Table(TableUnion([self._tir] + [table._tir for table in tables]))
f" Expected: {left_key}\n"
f" Table {i}: {ht.key.dtype}")

if not (unify or ht.row.dtype == self.row.dtype):
raise ValueError(f"'union': table {i} has a different row type.\n"
f" Expected: {self.row.dtype}\n"
f" Table {i}: {ht.row.dtype}")
all_tables = [self]
all_tables.extend(tables)

if unify and not len(set(ht.row_value.dtype for ht in all_tables)) == 1:
discovered = defaultdict(dict)
for i, ht in enumerate(all_tables):
for field_name in ht.row_value:
discovered[field_name][i] = ht[field_name]
all_fields = [{} for _ in all_tables]
for field_name, expr_dict in discovered.items():
*unified, can_unify = hl.expr.expressions.unify_exprs(*expr_dict.values())
if not can_unify:
raise ValueError(f"cannot unify field {field_name!r}: found fields of types "
f"{[str(t) for t in {e.dtype for e in expr_dict.values()}]}")
unified_map = dict(zip(expr_dict.keys(), unified))
default = hl.null(unified[0].dtype)
for i in range(len(all_tables)):
all_fields[i][field_name] = unified_map.get(i, default)

for i, t in enumerate(all_tables):
all_tables[i] = t.select(**all_fields[i])

return Table(TableUnion([table._tir for table in all_tables]))

@typecheck_method(n=int, _localize=bool)
def take(self, n, _localize=True):
Expand Down
22 changes: 22 additions & 0 deletions hail/python/test/hail/table/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,28 @@ def test_union(self):
self.assertTrue(t1.key_by().union(t2.key_by(), t3.key_by())
._same(hl.utils.range_table(15).key_by()))

def test_union_unify(self):
    """Exercise ``Table.union`` with ``unify=True`` on tables whose rows differ.

    Four tables derived from a two-row range table are unioned: the base
    table, one annotated with ``x`` (int32) and ``y``, one annotated with
    ``z`` and ``x`` (float64), and one re-keyed with ``idx`` shifted by 10.
    The assertions expect ``x`` to come back as float64, the row fields to
    be ``idx, x, y, z``, and fields absent from a source table to be
    ``None`` in that table's rows.
    """
    base = hl.utils.range_table(2)
    with_x_y = base.annotate(x=hl.int32(1), y='A')
    with_z_x = base.annotate(z=(1, 2, 3), x=hl.float64(1.5))
    shifted = base.key_by(idx=base.idx + 10)

    unioned = base.union(with_x_y, with_z_x, shifted, unify=True)

    assert unioned.x.dtype == hl.tfloat64
    assert list(unioned.row) == ['idx', 'x', 'y', 'z']

    make_row = hl.utils.Struct
    expected = []
    # For each shared key, the expected rows follow the order in which the
    # tables were passed to union: base, with_x_y, with_z_x.
    for idx in (0, 1):
        expected.append(make_row(idx=idx, x=None, y=None, z=None))
        expected.append(make_row(idx=idx, x=1.0, y='A', z=None))
        expected.append(make_row(idx=idx, x=1.5, y=None, z=(1, 2, 3)))
    # The re-keyed table only contributes rows with no annotated fields.
    for idx in (10, 11):
        expected.append(make_row(idx=idx, x=None, y=None, z=None))

    assert unioned.collect() == expected

def test_table_head_returns_right_number(self):
rt = hl.utils.range_table(10, 11)
par = hl.Table.parallelize([hl.Struct(x=x) for x in range(10)], schema='struct{x: int32}', n_partitions=11)
Expand Down

0 comments on commit 49f8a6c

Please sign in to comment.