Skip to content

Commit

Permalink
[Hail][feature] add outer option to union_cols (#7475)
Browse files Browse the repository at this point in the history
* add `outer` option to `union_cols`

* better test

* document `outer` option

* missed an `_outer`

* rename parameter, expand documentation

* fix typos
  • Loading branch information
patrick-schultz authored and danking committed Nov 11, 2019
1 parent 80a3513 commit 48ec6bd
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 22 deletions.
9 changes: 8 additions & 1 deletion hail/python/hail/ir/matrix_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,17 @@ def scan_bindings(self, i, default_value=None):


class MatrixUnionCols(MatrixIR):
def __init__(self, left, right):
def __init__(self, left, right, join_type):
super().__init__(left, right)
self.left = left
self.right = right
self.join_type = join_type

def head_str(self):
return f'{escape_id(self.join_type)}'

def _eq(self, other):
return self.join_type == other.join_type

def _compute_type(self):
self.right.typ # force
Expand Down
30 changes: 23 additions & 7 deletions hail/python/hail/matrixtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3570,8 +3570,9 @@ def union_rows(*datasets: 'MatrixTable', _check_cols=True) -> 'MatrixTable':
f"Datasets 0 and {wrong_keys+1} have different columns (or possibly different order).")
return MatrixTable(MatrixUnionRows(*[d._mir for d in datasets]))

@typecheck_method(other=matrix_table_type)
def union_cols(self, other: 'MatrixTable') -> 'MatrixTable':
@typecheck_method(other=matrix_table_type,
row_join_type=enumeration('inner', 'outer'))
def union_cols(self, other: 'MatrixTable', row_join_type='inner') -> 'MatrixTable':
"""Take the union of dataset columns.
Examples
Expand All @@ -3593,10 +3594,22 @@ def union_cols(self, other: 'MatrixTable') -> 'MatrixTable':
The row fields in the resulting dataset are the row fields from the
first dataset; the row schemas do not need to match.
This method performs an inner join on rows and concatenates entries
from the two datasets for each row. Only distinct keys from each
dataset are included (equivalent to calling :meth:`.distinct_by_row`
on each dataset first).
This method creates a :class:`.MatrixTable` which contains all columns
from both input datasets. The set of rows included in the result is
determined by the `row_join_type` parameter.
- With the default value of ``'inner'``, an inner join is performed
on rows, so that only rows whose row key exists in both input datasets
are included. In this case, the entries for each row are the
concatenation of all entries of the corresponding rows in the input
datasets.
- With `row_join_type` set to ``'outer'``, an outer join is perfomed on
rows, so that row keys which exist in only one input dataset are also
included. For those rows, the entry fields for the columns coming
from the other dataset will be missing.
Only distinct row keys from each dataset are included (equivalent to
calling :meth:`.distinct_by_row` on each dataset first).
This method does not deduplicate; if a column key exists identically in
two datasets, then it will be duplicated in the result.
Expand All @@ -3605,6 +3618,9 @@ def union_cols(self, other: 'MatrixTable') -> 'MatrixTable':
----------
other : :class:`.MatrixTable`
Dataset to concatenate.
outer : bool
If `True`, perform an outer join on rows, otherwise perform an
inner join. Default `False`.
Returns
-------
Expand All @@ -3628,7 +3644,7 @@ def union_cols(self, other: 'MatrixTable') -> 'MatrixTable':
f' left: {", ".join(self.row_key.dtype.values())}\n'
f' right: {", ".join(other.row_key.dtype.values())}')

return MatrixTable(MatrixUnionCols(self._mir, other._mir))
return MatrixTable(MatrixUnionCols(self._mir, other._mir, row_join_type))

@typecheck_method(n=nullable(int), n_cols=nullable(int))
def head(self, n: Optional[int], n_cols: Optional[int] = None) -> 'MatrixTable':
Expand Down
16 changes: 16 additions & 0 deletions hail/python/test/hail/matrixtable/test_matrix_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,22 @@ def test_union_cols_distinct(self):
mt = mt.key_rows_by(x = mt.row_idx // 2)
assert mt.union_cols(mt).count_rows() == 5

def test_union_cols_outer(self):
r, c = 10, 10
mt = hl.utils.range_matrix_table(2*r, c)
mt = mt.annotate_entries(entry=hl.tuple([mt.row_idx, mt.col_idx]))
mt2 = hl.utils.range_matrix_table(2*r, c)
mt2 = mt2.key_rows_by(row_idx=mt2.row_idx + r)
mt2 = mt2.key_cols_by(col_idx=mt2.col_idx + c)
mt2 = mt2.annotate_entries(entry=hl.tuple([mt2.row_idx, mt2.col_idx]))
expected = hl.utils.range_matrix_table(3*r, 2*c)
missing = hl.null(hl.ttuple(hl.tint, hl.tint))
expected = expected.annotate_entries(entry=hl.cond(
expected.col_idx < c,
hl.cond(expected.row_idx < 2*r, hl.tuple([expected.row_idx, expected.col_idx]), missing),
hl.cond(expected.row_idx >= r, hl.tuple([expected.row_idx, expected.col_idx]), missing)))
assert mt.union_cols(mt2, row_join_type='outer')._same(expected)

def test_union_rows_different_col_schema(self):
mt = hl.utils.range_matrix_table(10, 10)
mt2 = hl.utils.range_matrix_table(10, 10)
Expand Down
25 changes: 19 additions & 6 deletions hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -449,28 +449,41 @@ object LowerMatrixIR {
}
}))

case MatrixUnionCols(left, right) =>
case MatrixUnionCols(left, right, joinType) =>
val rightEntries = genUID()
val rightCols = genUID()
val ll = lower(left, ab).distinct()
def handleMissingEntriesArray(entries: Symbol, cols: Symbol): IRProxy =
if (joinType == "inner")
'row(entries)
else
irIf('row(entries).isNA) {
irRange(0, 'global(cols).len)
.map('a ~> irToProxy(MakeStruct(right.typ.entryType.fieldNames.map(f => (f, NA(right.typ.entryType.fieldType(f)))))))
} {
'row(entries)
}
TableJoin(
ll,
lower(right, ab).distinct()
.mapRows('row
.insertFields(Symbol(rightEntries) -> 'row (entriesField))
.insertFields(Symbol(rightEntries) -> 'row(entriesField))
.selectFields(right.typ.rowKey :+ rightEntries: _*))
.mapGlobals('global
.insertFields(Symbol(rightCols) -> 'global (colsField))
.insertFields(Symbol(rightCols) -> 'global(colsField))
.selectFields(rightCols)),
"inner")
joinType)
.mapRows('row
.insertFields(entriesField ->
makeArray('row (entriesField), 'row (Symbol(rightEntries))).flatMap('a ~> 'a))
makeArray(
handleMissingEntriesArray(entriesField, colsField),
handleMissingEntriesArray(Symbol(rightEntries), Symbol(rightCols)))
.flatMap('a ~> 'a))
// TableJoin puts keys first; drop rightEntries, but also restore left row field order
.selectFields(ll.typ.rowType.fieldNames: _*))
.mapGlobals('global
.insertFields(colsField ->
makeArray('global (colsField), 'global (Symbol(rightCols))).flatMap('a ~> 'a))
makeArray('global(colsField), 'global(Symbol(rightCols))).flatMap('a ~> 'a))
.dropFields(Symbol(rightCols)))

case MatrixMapEntries(child, newEntries) =>
Expand Down
18 changes: 12 additions & 6 deletions hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -426,23 +426,29 @@ case class MatrixAggregateColsByKey(child: MatrixIR, entryExpr: IR, colExpr: IR)
lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound
}

case class MatrixUnionCols(left: MatrixIR, right: MatrixIR) extends MatrixIR {
case class MatrixUnionCols(left: MatrixIR, right: MatrixIR, joinType: String) extends MatrixIR {
require(joinType == "inner" || joinType == "outer")
lazy val children: IndexedSeq[BaseIR] = Array(left, right)

def copy(newChildren: IndexedSeq[BaseIR]): MatrixUnionCols = {
assert(newChildren.length == 2)
MatrixUnionCols(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[MatrixIR])
MatrixUnionCols(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[MatrixIR], joinType)
}

val typ: MatrixType = left.typ
val typ: MatrixType = if (joinType == "inner")
left.typ
else
left.typ.copy(
colType = TStruct(left.typ.colType.fields.map(f => f.copy(typ = -f.typ))),
entryType = TStruct(left.typ.entryType.fields.map(f => f.copy(typ = -f.typ))))

override def columnCount: Option[Int] =
left.columnCount.flatMap(leftCount => right.columnCount.map(rightCount => leftCount + rightCount))

lazy val rowCountUpperBound: Option[Long] = (left.rowCountUpperBound, right.rowCountUpperBound) match {
case (Some(l), Some(r)) => Some(l.min(r))
case (Some(l), None) => Some(l)
case (None, Some(r)) => Some(r)
case (Some(l), Some(r)) => if (joinType == "inner") Some(l.min(r)) else Some(l + r)
case (Some(l), None) => if (joinType == "inner") Some(l) else None
case (None, Some(r)) => if (joinType == "inner") Some(r) else None
case (None, None) => None
}
}
Expand Down
3 changes: 2 additions & 1 deletion hail/src/main/scala/is/hail/expr/ir/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1248,9 +1248,10 @@ object IRParser {
val newEntry = ir_value_expr(env.withRefMap(child.typ.refMap))(it)
MatrixMapEntries(child, newEntry)
case "MatrixUnionCols" =>
val joinType = identifier(it)
val left = matrix_ir(env)(it)
val right = matrix_ir(env)(it)
MatrixUnionCols(left, right)
MatrixUnionCols(left, right, joinType)
case "MatrixMapGlobals" =>
val child = matrix_ir(env)(it)
val newGlobals = ir_value_expr(env.withRefMap(child.typ.refMap))(it)
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ object PruneDeadFields {
case MatrixFilterEntries(child, pred) =>
val irDep = memoizeAndGetDep(pred, pred.typ, child.typ, memo)
memoizeMatrixIR(child, unify(child.typ, requestedType, irDep), memo)
case MatrixUnionCols(left, right) =>
case MatrixUnionCols(left, right, joinType) =>
val leftRequestedType = requestedType.copy(
rowKey = left.typ.rowKey,
rowType = unify(left.typ.rowType, requestedType.rowType, selectKey(left.typ.rowType, left.typ.rowKey))
Expand Down

0 comments on commit 48ec6bd

Please sign in to comment.