[Hail][feature] add outer option to union_cols #7475

Merged
merged 6 commits into from Nov 11, 2019
Changes from 5 commits
9 changes: 8 additions & 1 deletion hail/python/hail/ir/matrix_ir.py
@@ -121,10 +121,17 @@ def scan_bindings(self, i, default_value=None):


 class MatrixUnionCols(MatrixIR):
-    def __init__(self, left, right):
+    def __init__(self, left, right, join_type):
         super().__init__(left, right)
         self.left = left
         self.right = right
+        self.join_type = join_type
+
+    def head_str(self):
+        return f'{escape_id(self.join_type)}'
+
+    def _eq(self, other):
+        return self.join_type == other.join_type

     def _compute_type(self):
         self.right.typ  # force
30 changes: 23 additions & 7 deletions hail/python/hail/matrixtable.py
@@ -3570,8 +3570,9 @@ def union_rows(*datasets: 'MatrixTable', _check_cols=True) -> 'MatrixTable':
                         f"Datasets 0 and {wrong_keys+1} have different columns (or possibly different order).")
         return MatrixTable(MatrixUnionRows(*[d._mir for d in datasets]))

-    @typecheck_method(other=matrix_table_type)
-    def union_cols(self, other: 'MatrixTable') -> 'MatrixTable':
+    @typecheck_method(other=matrix_table_type,
+                      row_join_type=enumeration('inner', 'outer', 'left', 'right'))
Contributor

you don't support left/right

Collaborator Author

Eek, lazy copy/pasting

Contributor

that's what review is for!

+    def union_cols(self, other: 'MatrixTable', row_join_type='inner') -> 'MatrixTable':
         """Take the union of dataset columns.

         Examples
@@ -3593,10 +3594,22 @@ def union_cols(self, other: 'MatrixTable') -> 'MatrixTable':
         The row fields in the resulting dataset are the row fields from the
         first dataset; the row schemas do not need to match.

-        This method performs an inner join on rows and concatenates entries
-        from the two datasets for each row. Only distinct keys from each
-        dataset are included (equivalent to calling :meth:`.distinct_by_row`
-        on each dataset first).
+        This method creates a :class:`.MatrixTable` which contains all columns
+        from both input datasets. The set of rows included in the result is
+        determined by the `row_join_type` parameter.
+
+        - With the default ``row_join_type=inner``, an inner join is performed
Contributor

We generally don't format the arg bit as code: should be something like:

with the default value of ``'inner'``

+          on rows, so that only rows whose row key exists in both input datasets
+          are included. In this case, the entries for each row are the
+          concatenation of all entries of the corresponding rows in the input
+          datasets.
+        - With ``row_join_type=outer``, an outer join is performed on rows, so
+          that row keys which exist in only one input dataset are also included.
+          For those rows, the entrie fields for the columns coming from the
Contributor

typo: entrie

+          other dataset will be missing.
+
+        Only distinct row keys from each dataset are included (equivalent to
+        calling :meth:`.distinct_by_row` on each dataset first).

         This method does not deduplicate; if a column key exists identically in
         two datasets, then it will be duplicated in the result.
@@ -3605,6 +3618,9 @@ def union_cols(self, other: 'MatrixTable') -> 'MatrixTable':
         ----------
         other : :class:`.MatrixTable`
             Dataset to concatenate.
+        row_join_type : :obj:`str`
+            If ``'outer'``, perform an outer join on rows; if ``'inner'``,
+            perform an inner join. Default ``'inner'``.

         Returns
         -------
@@ -3628,7 +3644,7 @@ def union_cols(self, other: 'MatrixTable') -> 'MatrixTable':
                              f' left: {", ".join(self.row_key.dtype.values())}\n'
                              f' right: {", ".join(other.row_key.dtype.values())}')

-        return MatrixTable(MatrixUnionCols(self._mir, other._mir))
+        return MatrixTable(MatrixUnionCols(self._mir, other._mir, row_join_type))

     @typecheck_method(n=nullable(int), n_cols=nullable(int))
     def head(self, n: Optional[int], n_cols: Optional[int] = None) -> 'MatrixTable':
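To make the inner/outer row semantics described in the `union_cols` docstring concrete, here is a small pure-Python sketch. It does not use Hail: `union_cols_model` and its dict-of-lists representation are hypothetical stand-ins for illustration only, and the model ignores Hail's row ordering and the `distinct_by_row` step.

```python
from typing import Dict, Hashable, List, Optional

# Hypothetical toy representation: row key -> list of entries, one per column.
Rows = Dict[Hashable, List[Optional[int]]]


def union_cols_model(left: Rows, right: Rows,
                     n_left_cols: int, n_right_cols: int,
                     row_join_type: str = 'inner') -> Rows:
    """Concatenate the entries of two toy datasets row by row."""
    if row_join_type not in ('inner', 'outer'):
        raise ValueError(f"unsupported row_join_type: {row_join_type!r}")
    if row_join_type == 'inner':
        # Keep only row keys present in both datasets.
        keys = [k for k in left if k in right]
    else:
        # Keep every row key that appears in either dataset.
        keys = list(left) + [k for k in right if k not in left]
    result: Rows = {}
    for k in keys:
        # A row absent from one side contributes missing entries for that
        # side's columns.
        left_entries = left.get(k, [None] * n_left_cols)
        right_entries = right.get(k, [None] * n_right_cols)
        result[k] = left_entries + right_entries
    return result


left = {1: [10, 11], 2: [20, 21]}
right = {2: [22], 3: [32]}
inner = union_cols_model(left, right, 2, 1)
outer = union_cols_model(left, right, 2, 1, row_join_type='outer')
```

With these inputs, the inner result keeps only row key `2`, while the outer result keeps keys `1`, `2`, and `3`, padding the side that lacks a row with `None`.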
16 changes: 16 additions & 0 deletions hail/python/test/hail/matrixtable/test_matrix_table.py
@@ -530,6 +530,22 @@ def test_union_cols_distinct(self):
         mt = mt.key_rows_by(x = mt.row_idx // 2)
         assert mt.union_cols(mt).count_rows() == 5

+    def test_union_cols_outer(self):
Contributor

can you add a test for correct entry joining?

+        r, c = 10, 10
+        mt = hl.utils.range_matrix_table(2*r, c)
+        mt = mt.annotate_entries(entry=hl.tuple([mt.row_idx, mt.col_idx]))
+        mt2 = hl.utils.range_matrix_table(2*r, c)
+        mt2 = mt2.key_rows_by(row_idx=mt2.row_idx + r)
+        mt2 = mt2.key_cols_by(col_idx=mt2.col_idx + c)
+        mt2 = mt2.annotate_entries(entry=hl.tuple([mt2.row_idx, mt2.col_idx]))
+        expected = hl.utils.range_matrix_table(3*r, 2*c)
+        missing = hl.null(hl.ttuple(hl.tint, hl.tint))
+        expected = expected.annotate_entries(entry=hl.cond(
+            expected.col_idx < c,
+            hl.cond(expected.row_idx < 2*r, hl.tuple([expected.row_idx, expected.col_idx]), missing),
+            hl.cond(expected.row_idx >= r, hl.tuple([expected.row_idx, expected.col_idx]), missing)))
+        assert mt.union_cols(mt2, row_join_type='outer')._same(expected)
+
     def test_union_rows_different_col_schema(self):
         mt = hl.utils.range_matrix_table(10, 10)
         mt2 = hl.utils.range_matrix_table(10, 10)
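The layout of the `expected` table in `test_union_cols_outer` can be checked by hand with a plain-Python mirror of the nested `hl.cond` expression. This is only a sketch: `expected_entry` is a hypothetical helper, not part of the test suite, and `None` stands in for Hail's missing value.

```python
def expected_entry(row_idx, col_idx, r=10, c=10):
    """Mirror of the nested hl.cond in the test, with None for missing."""
    if col_idx < c:
        # Columns from the left dataset, which covers rows [0, 2r).
        return (row_idx, col_idx) if row_idx < 2 * r else None
    # Columns from the right dataset, which covers rows [r, 3r).
    return (row_idx, col_idx) if row_idx >= r else None


r, c = 10, 10
grid = [[expected_entry(i, j, r, c) for j in range(2 * c)] for i in range(3 * r)]
n_missing = sum(cell is None for row in grid for cell in row)
```

Each dataset lacks `r` of the `3*r` joined rows, so `r * c` cells are missing on each side and `n_missing` is `2 * r * c`.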
25 changes: 19 additions & 6 deletions hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala
@@ -449,28 +449,41 @@ object LowerMatrixIR {
           }
         }))

-      case MatrixUnionCols(left, right) =>
+      case MatrixUnionCols(left, right, joinType) =>
         val rightEntries = genUID()
         val rightCols = genUID()
         val ll = lower(left, ab).distinct()
+        def handleMissingEntriesArray(entries: Symbol, cols: Symbol): IRProxy =
+          if (joinType == "inner")
+            'row(entries)
+          else
+            irIf('row(entries).isNA) {
+              irRange(0, 'global(cols).len)
+                .map('a ~> irToProxy(MakeStruct(right.typ.entryType.fieldNames.map(f => (f, NA(right.typ.entryType.fieldType(f)))))))
+            } {
+              'row(entries)
+            }
         TableJoin(
           ll,
           lower(right, ab).distinct()
             .mapRows('row
-              .insertFields(Symbol(rightEntries) -> 'row (entriesField))
+              .insertFields(Symbol(rightEntries) -> 'row(entriesField))
               .selectFields(right.typ.rowKey :+ rightEntries: _*))
             .mapGlobals('global
-              .insertFields(Symbol(rightCols) -> 'global (colsField))
+              .insertFields(Symbol(rightCols) -> 'global(colsField))
               .selectFields(rightCols)),
-          "inner")
+          joinType)
           .mapRows('row
             .insertFields(entriesField ->
-              makeArray('row (entriesField), 'row (Symbol(rightEntries))).flatMap('a ~> 'a))
+              makeArray(
+                handleMissingEntriesArray(entriesField, colsField),
+                handleMissingEntriesArray(Symbol(rightEntries), Symbol(rightCols)))
+                .flatMap('a ~> 'a))
             // TableJoin puts keys first; drop rightEntries, but also restore left row field order
             .selectFields(ll.typ.rowType.fieldNames: _*))
           .mapGlobals('global
             .insertFields(colsField ->
-              makeArray('global (colsField), 'global (Symbol(rightCols))).flatMap('a ~> 'a))
+              makeArray('global(colsField), 'global(Symbol(rightCols))).flatMap('a ~> 'a))
             .dropFields(Symbol(rightCols)))

       case MatrixMapEntries(child, newEntries) =>
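As a reading aid for `handleMissingEntriesArray` in the lowering above, the following Python sketch models the same idea on plain lists and dicts. The name `fill_missing_entries` and the list-of-dicts representation are assumptions made for illustration; the real code builds IR nodes rather than values.

```python
def fill_missing_entries(entries, n_cols, entry_fields, join_type):
    """Return the entries array, padding it when the whole row is missing."""
    if join_type == 'inner' or entries is not None:
        # An inner join never yields a missing entries array, and a present
        # array is passed through unchanged.
        return entries
    # Outer join with no matching row on this side: one struct per column,
    # with every entry field set to missing.
    return [{f: None for f in entry_fields} for _ in range(n_cols)]


fields = ['GT', 'DP']  # hypothetical entry field names
left_entries = [{'GT': 0, 'DP': 30}]
right_entries = None   # this row key was absent from the right dataset
joined = (fill_missing_entries(left_entries, 1, fields, 'outer')
          + fill_missing_entries(right_entries, 2, fields, 'outer'))
```

Padding before concatenation keeps the joined entries array the same length as the concatenated columns array, which is why the fill uses the column count of the side that is missing.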
18 changes: 12 additions & 6 deletions hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala
@@ -426,23 +426,29 @@ case class MatrixAggregateColsByKey(child: MatrixIR, entryExpr: IR, colExpr: IR)
   lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound
 }

-case class MatrixUnionCols(left: MatrixIR, right: MatrixIR) extends MatrixIR {
+case class MatrixUnionCols(left: MatrixIR, right: MatrixIR, joinType: String) extends MatrixIR {
+  require(joinType == "inner" || joinType == "outer")
   lazy val children: IndexedSeq[BaseIR] = Array(left, right)

   def copy(newChildren: IndexedSeq[BaseIR]): MatrixUnionCols = {
     assert(newChildren.length == 2)
-    MatrixUnionCols(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[MatrixIR])
+    MatrixUnionCols(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[MatrixIR], joinType)
   }

-  val typ: MatrixType = left.typ
+  val typ: MatrixType = if (joinType == "inner")
+    left.typ
+  else
+    left.typ.copy(
+      colType = TStruct(left.typ.colType.fields.map(f => f.copy(typ = -f.typ))),
+      entryType = TStruct(left.typ.entryType.fields.map(f => f.copy(typ = -f.typ))))

   override def columnCount: Option[Int] =
     left.columnCount.flatMap(leftCount => right.columnCount.map(rightCount => leftCount + rightCount))

   lazy val rowCountUpperBound: Option[Long] = (left.rowCountUpperBound, right.rowCountUpperBound) match {
-    case (Some(l), Some(r)) => Some(l.min(r))
-    case (Some(l), None) => Some(l)
-    case (None, Some(r)) => Some(r)
+    case (Some(l), Some(r)) => if (joinType == "inner") Some(l.min(r)) else Some(l + r)
+    case (Some(l), None) => if (joinType == "inner") Some(l) else None
+    case (None, Some(r)) => if (joinType == "inner") Some(r) else None
     case (None, None) => None
   }
 }
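The new `rowCountUpperBound` cases can be restated in Python to make the reasoning easier to follow. This is a sketch only: `row_count_upper_bound` is a hypothetical function, with `None` playing the role of Scala's `None` (bound unknown).

```python
from typing import Optional


def row_count_upper_bound(left: Optional[int], right: Optional[int],
                          join_type: str) -> Optional[int]:
    """Upper bound on joined row count; None means no bound is known."""
    if join_type == 'inner':
        # An inner join cannot produce more rows than either input, so any
        # known bound applies and the smaller of two known bounds wins.
        if left is not None and right is not None:
            return min(left, right)
        return left if left is not None else right
    # An outer join may keep every row from both sides, so a bound exists
    # only when both inputs are bounded.
    if left is not None and right is not None:
        return left + right
    return None
```

For an outer join a single known bound says nothing about the total, since the unbounded side can contribute arbitrarily many unmatched rows.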
3 changes: 2 additions & 1 deletion hail/src/main/scala/is/hail/expr/ir/Parser.scala
@@ -1248,9 +1248,10 @@ object IRParser {
         val newEntry = ir_value_expr(env.withRefMap(child.typ.refMap))(it)
         MatrixMapEntries(child, newEntry)
       case "MatrixUnionCols" =>
+        val joinType = identifier(it)
         val left = matrix_ir(env)(it)
         val right = matrix_ir(env)(it)
-        MatrixUnionCols(left, right)
+        MatrixUnionCols(left, right, joinType)
       case "MatrixMapGlobals" =>
         val child = matrix_ir(env)(it)
         val newGlobals = ir_value_expr(env.withRefMap(child.typ.refMap))(it)
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
@@ -536,7 +536,7 @@ object PruneDeadFields {
       case MatrixFilterEntries(child, pred) =>
         val irDep = memoizeAndGetDep(pred, pred.typ, child.typ, memo)
         memoizeMatrixIR(child, unify(child.typ, requestedType, irDep), memo)
-      case MatrixUnionCols(left, right) =>
+      case MatrixUnionCols(left, right, joinType) =>
         val leftRequestedType = requestedType.copy(
           rowKey = left.typ.rowKey,
           rowType = unify(left.typ.rowType, requestedType.rowType, selectKey(left.typ.rowType, left.typ.rowKey))