From e005d8bfc22da9100a737af0c481cecb6d336958 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Wed, 6 Nov 2019 14:29:22 -0500 Subject: [PATCH 1/6] add `outer` option to `union_cols` --- hail/python/hail/ir/matrix_ir.py | 9 ++++++- hail/python/hail/matrixtable.py | 6 ++--- .../hail/matrixtable/test_matrix_table.py | 5 ++++ .../scala/is/hail/expr/ir/LowerMatrixIR.scala | 25 ++++++++++++++----- .../main/scala/is/hail/expr/ir/MatrixIR.scala | 18 ++++++++----- .../main/scala/is/hail/expr/ir/Parser.scala | 3 ++- .../is/hail/expr/ir/PruneDeadFields.scala | 2 +- 7 files changed, 50 insertions(+), 18 deletions(-) diff --git a/hail/python/hail/ir/matrix_ir.py b/hail/python/hail/ir/matrix_ir.py index c52e840414a..5f606deeea7 100644 --- a/hail/python/hail/ir/matrix_ir.py +++ b/hail/python/hail/ir/matrix_ir.py @@ -121,10 +121,17 @@ def scan_bindings(self, i, default_value=None): class MatrixUnionCols(MatrixIR): - def __init__(self, left, right): + def __init__(self, left, right, join_type): super().__init__(left, right) self.left = left self.right = right + self.join_type = join_type + + def head_str(self): + return f'{escape_id(self.join_type)}' + + def _eq(self, other): + return self.join_type == other.join_type def _compute_type(self): self.right.typ # force diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index b49174b14b5..5d62e22bcaf 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -3570,8 +3570,8 @@ def union_rows(*datasets: 'MatrixTable', _check_cols=True) -> 'MatrixTable': f"Datasets 0 and {wrong_keys+1} have different columns (or possibly different order).") return MatrixTable(MatrixUnionRows(*[d._mir for d in datasets])) - @typecheck_method(other=matrix_table_type) - def union_cols(self, other: 'MatrixTable') -> 'MatrixTable': + @typecheck_method(other=matrix_table_type, _outer=bool) + def union_cols(self, other: 'MatrixTable', _outer=False) -> 'MatrixTable': """Take the union of dataset columns. 
Examples @@ -3628,7 +3628,7 @@ def union_cols(self, other: 'MatrixTable') -> 'MatrixTable': f' left: {", ".join(self.row_key.dtype.values())}\n' f' right: {", ".join(other.row_key.dtype.values())}') - return MatrixTable(MatrixUnionCols(self._mir, other._mir)) + return MatrixTable(MatrixUnionCols(self._mir, other._mir, "outer" if _outer else "inner")) @typecheck_method(n=nullable(int), n_cols=nullable(int)) def head(self, n: Optional[int], n_cols: Optional[int] = None) -> 'MatrixTable': diff --git a/hail/python/test/hail/matrixtable/test_matrix_table.py b/hail/python/test/hail/matrixtable/test_matrix_table.py index dd703809716..c7f5dd35c4c 100644 --- a/hail/python/test/hail/matrixtable/test_matrix_table.py +++ b/hail/python/test/hail/matrixtable/test_matrix_table.py @@ -530,6 +530,11 @@ def test_union_cols_distinct(self): mt = mt.key_rows_by(x = mt.row_idx // 2) assert mt.union_cols(mt).count_rows() == 5 + def test_union_cols_outer(self): + mt = hl.utils.range_matrix_table(10, 1) + mt2 = mt.key_rows_by(row_idx=mt.row_idx + 5) + assert mt.union_cols(mt2, _outer=True).count_rows() == 15 + def test_union_rows_different_col_schema(self): mt = hl.utils.range_matrix_table(10, 10) mt2 = hl.utils.range_matrix_table(10, 10) diff --git a/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala index 62b058bc5d3..16404cad043 100644 --- a/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala @@ -449,28 +449,41 @@ object LowerMatrixIR { } })) - case MatrixUnionCols(left, right) => + case MatrixUnionCols(left, right, joinType) => val rightEntries = genUID() val rightCols = genUID() val ll = lower(left, ab).distinct() + def handleMissingEntriesArray(entries: Symbol, cols: Symbol): IRProxy = + if (joinType == "inner") + 'row(entries) + else + irIf('row(entries).isNA) { + irRange(0, 'global(cols).len) + .map('a ~> 
irToProxy(MakeStruct(right.typ.entryType.fieldNames.map(f => (f, NA(right.typ.entryType.fieldType(f))))))) + } { + 'row(entries) + } TableJoin( ll, lower(right, ab).distinct() .mapRows('row - .insertFields(Symbol(rightEntries) -> 'row (entriesField)) + .insertFields(Symbol(rightEntries) -> 'row(entriesField)) .selectFields(right.typ.rowKey :+ rightEntries: _*)) .mapGlobals('global - .insertFields(Symbol(rightCols) -> 'global (colsField)) + .insertFields(Symbol(rightCols) -> 'global(colsField)) .selectFields(rightCols)), - "inner") + joinType) .mapRows('row .insertFields(entriesField -> - makeArray('row (entriesField), 'row (Symbol(rightEntries))).flatMap('a ~> 'a)) + makeArray( + handleMissingEntriesArray(entriesField, colsField), + handleMissingEntriesArray(Symbol(rightEntries), Symbol(rightCols))) + .flatMap('a ~> 'a)) // TableJoin puts keys first; drop rightEntries, but also restore left row field order .selectFields(ll.typ.rowType.fieldNames: _*)) .mapGlobals('global .insertFields(colsField -> - makeArray('global (colsField), 'global (Symbol(rightCols))).flatMap('a ~> 'a)) + makeArray('global(colsField), 'global(Symbol(rightCols))).flatMap('a ~> 'a)) .dropFields(Symbol(rightCols))) case MatrixMapEntries(child, newEntries) => diff --git a/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala index 38b62c52be2..57760dc9f8b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala @@ -426,23 +426,29 @@ case class MatrixAggregateColsByKey(child: MatrixIR, entryExpr: IR, colExpr: IR) lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound } -case class MatrixUnionCols(left: MatrixIR, right: MatrixIR) extends MatrixIR { +case class MatrixUnionCols(left: MatrixIR, right: MatrixIR, joinType: String) extends MatrixIR { + require(joinType == "inner" || joinType == "outer") lazy val children: IndexedSeq[BaseIR] = Array(left, right) def 
copy(newChildren: IndexedSeq[BaseIR]): MatrixUnionCols = { assert(newChildren.length == 2) - MatrixUnionCols(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[MatrixIR]) + MatrixUnionCols(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[MatrixIR], joinType) } - val typ: MatrixType = left.typ + val typ: MatrixType = if (joinType == "inner") + left.typ + else + left.typ.copy( + colType = TStruct(left.typ.colType.fields.map(f => f.copy(typ = -f.typ))), + entryType = TStruct(left.typ.entryType.fields.map(f => f.copy(typ = -f.typ)))) override def columnCount: Option[Int] = left.columnCount.flatMap(leftCount => right.columnCount.map(rightCount => leftCount + rightCount)) lazy val rowCountUpperBound: Option[Long] = (left.rowCountUpperBound, right.rowCountUpperBound) match { - case (Some(l), Some(r)) => Some(l.min(r)) - case (Some(l), None) => Some(l) - case (None, Some(r)) => Some(r) + case (Some(l), Some(r)) => if (joinType == "inner") Some(l.min(r)) else Some(l + r) + case (Some(l), None) => if (joinType == "inner") Some(l) else None + case (None, Some(r)) => if (joinType == "inner") Some(r) else None case (None, None) => None } } diff --git a/hail/src/main/scala/is/hail/expr/ir/Parser.scala b/hail/src/main/scala/is/hail/expr/ir/Parser.scala index dbdea076f2d..d4ea63a1506 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Parser.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Parser.scala @@ -1248,9 +1248,10 @@ object IRParser { val newEntry = ir_value_expr(env.withRefMap(child.typ.refMap))(it) MatrixMapEntries(child, newEntry) case "MatrixUnionCols" => + val joinType = identifier(it) val left = matrix_ir(env)(it) val right = matrix_ir(env)(it) - MatrixUnionCols(left, right) + MatrixUnionCols(left, right, joinType) case "MatrixMapGlobals" => val child = matrix_ir(env)(it) val newGlobals = ir_value_expr(env.withRefMap(child.typ.refMap))(it) diff --git a/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala 
b/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala index f7fff6954a6..61d44273402 100644 --- a/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala +++ b/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala @@ -536,7 +536,7 @@ object PruneDeadFields { case MatrixFilterEntries(child, pred) => val irDep = memoizeAndGetDep(pred, pred.typ, child.typ, memo) memoizeMatrixIR(child, unify(child.typ, requestedType, irDep), memo) - case MatrixUnionCols(left, right) => + case MatrixUnionCols(left, right, joinType) => val leftRequestedType = requestedType.copy( rowKey = left.typ.rowKey, rowType = unify(left.typ.rowType, requestedType.rowType, selectKey(left.typ.rowType, left.typ.rowKey)) From 9e2ade1424c6eacea4d6f7e289aac7eb17c27683 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 7 Nov 2019 09:31:57 -0500 Subject: [PATCH 2/6] better test --- .../test/hail/matrixtable/test_matrix_table.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/hail/python/test/hail/matrixtable/test_matrix_table.py b/hail/python/test/hail/matrixtable/test_matrix_table.py index c7f5dd35c4c..8e25893fb97 100644 --- a/hail/python/test/hail/matrixtable/test_matrix_table.py +++ b/hail/python/test/hail/matrixtable/test_matrix_table.py @@ -531,9 +531,20 @@ def test_union_cols_distinct(self): assert mt.union_cols(mt).count_rows() == 5 def test_union_cols_outer(self): - mt = hl.utils.range_matrix_table(10, 1) - mt2 = mt.key_rows_by(row_idx=mt.row_idx + 5) - assert mt.union_cols(mt2, _outer=True).count_rows() == 15 + r, c = 10, 10 + mt = hl.utils.range_matrix_table(2*r, c) + mt = mt.annotate_entries(entry=hl.tuple([mt.row_idx, mt.col_idx])) + mt2 = hl.utils.range_matrix_table(2*r, c) + mt2 = mt2.key_rows_by(row_idx=mt2.row_idx + r) + mt2 = mt2.key_cols_by(col_idx=mt2.col_idx + c) + mt2 = mt2.annotate_entries(entry=hl.tuple([mt2.row_idx, mt2.col_idx])) + expected = hl.utils.range_matrix_table(3*r, 2*c) + missing = hl.null(hl.ttuple(hl.tint, hl.tint)) 
+ expected = expected.annotate_entries(entry=hl.cond( + expected.col_idx < c, + hl.cond(expected.row_idx < 2*r, hl.tuple([expected.row_idx, expected.col_idx]), missing), + hl.cond(expected.row_idx >= r, hl.tuple([expected.row_idx, expected.col_idx]), missing))) + assert mt.union_cols(mt2, _outer=True)._same(expected) def test_union_rows_different_col_schema(self): mt = hl.utils.range_matrix_table(10, 10) From 2cb653ff74fa0ae06984daf0ddab5ca349153dc4 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 7 Nov 2019 09:47:41 -0500 Subject: [PATCH 3/6] document `outer` option --- hail/python/hail/matrixtable.py | 10 +++++++--- hail/python/test/hail/matrixtable/test_matrix_table.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index 5d62e22bcaf..45bfe5896c0 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -3570,8 +3570,8 @@ def union_rows(*datasets: 'MatrixTable', _check_cols=True) -> 'MatrixTable': f"Datasets 0 and {wrong_keys+1} have different columns (or possibly different order).") return MatrixTable(MatrixUnionRows(*[d._mir for d in datasets])) - @typecheck_method(other=matrix_table_type, _outer=bool) - def union_cols(self, other: 'MatrixTable', _outer=False) -> 'MatrixTable': + @typecheck_method(other=matrix_table_type, outer=bool) + def union_cols(self, other: 'MatrixTable', outer=False) -> 'MatrixTable': """Take the union of dataset columns. Examples @@ -3596,7 +3596,8 @@ def union_cols(self, other: 'MatrixTable', _outer=False) -> 'MatrixTable': This method performs an inner join on rows and concatenates entries from the two datasets for each row. Only distinct keys from each dataset are included (equivalent to calling :meth:`.distinct_by_row` - on each dataset first). + on each dataset first). If ``outer=True``, then this will instead + perform an outer join on rows. 
This method does not deduplicate; if a column key exists identically in two datasets, then it will be duplicated in the result. @@ -3605,6 +3606,9 @@ def union_cols(self, other: 'MatrixTable', _outer=False) -> 'MatrixTable': ---------- other : :class:`.MatrixTable` Dataset to concatenate. + outer : bool + If `True`, perform an outer join on rows, otherwise perform an + inner join. Default `False`. Returns ------- diff --git a/hail/python/test/hail/matrixtable/test_matrix_table.py b/hail/python/test/hail/matrixtable/test_matrix_table.py index 8e25893fb97..dfe48c2e553 100644 --- a/hail/python/test/hail/matrixtable/test_matrix_table.py +++ b/hail/python/test/hail/matrixtable/test_matrix_table.py @@ -544,7 +544,7 @@ def test_union_cols_outer(self): expected.col_idx < c, hl.cond(expected.row_idx < 2*r, hl.tuple([expected.row_idx, expected.col_idx]), missing), hl.cond(expected.row_idx >= r, hl.tuple([expected.row_idx, expected.col_idx]), missing))) - assert mt.union_cols(mt2, _outer=True)._same(expected) + assert mt.union_cols(mt2, outer=True)._same(expected) def test_union_rows_different_col_schema(self): mt = hl.utils.range_matrix_table(10, 10) From b0b9431c481a75f59339783c39aa2229c0e8cfe3 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Thu, 7 Nov 2019 11:56:01 -0500 Subject: [PATCH 4/6] missed an `_outer` --- hail/python/hail/matrixtable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index 45bfe5896c0..25527b8becf 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -3632,7 +3632,7 @@ def union_cols(self, other: 'MatrixTable', outer=False) -> 'MatrixTable': f' left: {", ".join(self.row_key.dtype.values())}\n' f' right: {", ".join(other.row_key.dtype.values())}') - return MatrixTable(MatrixUnionCols(self._mir, other._mir, "outer" if _outer else "inner")) + return MatrixTable(MatrixUnionCols(self._mir, other._mir, "outer" if outer else "inner")) 
@typecheck_method(n=nullable(int), n_cols=nullable(int)) def head(self, n: Optional[int], n_cols: Optional[int] = None) -> 'MatrixTable': From 7ab56df341b6faacc2291bbc45bc50905b294eb0 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Fri, 8 Nov 2019 13:53:07 -0500 Subject: [PATCH 5/6] rename parameter, expand documentation --- hail/python/hail/matrixtable.py | 28 +++++++++++++------ .../hail/matrixtable/test_matrix_table.py | 2 +- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index 25527b8becf..d4e9a6ccb3c 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -3570,8 +3570,9 @@ def union_rows(*datasets: 'MatrixTable', _check_cols=True) -> 'MatrixTable': f"Datasets 0 and {wrong_keys+1} have different columns (or possibly different order).") return MatrixTable(MatrixUnionRows(*[d._mir for d in datasets])) - @typecheck_method(other=matrix_table_type, outer=bool) - def union_cols(self, other: 'MatrixTable', outer=False) -> 'MatrixTable': + @typecheck_method(other=matrix_table_type, + row_join_type=enumeration('inner', 'outer', 'left', 'right')) + def union_cols(self, other: 'MatrixTable', row_join_type='inner') -> 'MatrixTable': """Take the union of dataset columns. Examples @@ -3593,11 +3594,22 @@ def union_cols(self, other: 'MatrixTable', outer=False) -> 'MatrixTable': The row fields in the resulting dataset are the row fields from the first dataset; the row schemas do not need to match. - This method performs an inner join on rows and concatenates entries - from the two datasets for each row. Only distinct keys from each - dataset are included (equivalent to calling :meth:`.distinct_by_row` - on each dataset first). If ``outer=True``, then this will instead - perform an outer join on rows. + This method creates a :class:`.MatrixTable` which contains all columns + from both input datasets. 
The set of rows included in the result is + determined by the `row_join_type` parameter. + + - With the default ``row_join_type=inner``, an inner join is performed + on rows, so that only rows whose row key exists in both input datasets + are included. In this case, the entries for each row are the + concatenation of all entries of the corresponding rows in the input + datasets. + - With ``row_join_type=outer``, an outer join is perfomed on rows, so + that row keys which exist in only one input dataset are also included. + For those rows, the entrie fields for the columns coming from the + other dataset will be missing. + + Only distinct row keys from each dataset are included (equivalent to + calling :meth:`.distinct_by_row` on each dataset first). This method does not deduplicate; if a column key exists identically in two datasets, then it will be duplicated in the result. @@ -3632,7 +3644,7 @@ def union_cols(self, other: 'MatrixTable', outer=False) -> 'MatrixTable': f' left: {", ".join(self.row_key.dtype.values())}\n' f' right: {", ".join(other.row_key.dtype.values())}') - return MatrixTable(MatrixUnionCols(self._mir, other._mir, "outer" if outer else "inner")) + return MatrixTable(MatrixUnionCols(self._mir, other._mir, row_join_type)) @typecheck_method(n=nullable(int), n_cols=nullable(int)) def head(self, n: Optional[int], n_cols: Optional[int] = None) -> 'MatrixTable': diff --git a/hail/python/test/hail/matrixtable/test_matrix_table.py b/hail/python/test/hail/matrixtable/test_matrix_table.py index dfe48c2e553..264abf17cd6 100644 --- a/hail/python/test/hail/matrixtable/test_matrix_table.py +++ b/hail/python/test/hail/matrixtable/test_matrix_table.py @@ -544,7 +544,7 @@ def test_union_cols_outer(self): expected.col_idx < c, hl.cond(expected.row_idx < 2*r, hl.tuple([expected.row_idx, expected.col_idx]), missing), hl.cond(expected.row_idx >= r, hl.tuple([expected.row_idx, expected.col_idx]), missing))) - assert mt.union_cols(mt2, outer=True)._same(expected) + 
assert mt.union_cols(mt2, row_join_type='outer')._same(expected) def test_union_rows_different_col_schema(self): mt = hl.utils.range_matrix_table(10, 10) From a2c3e11212f828317954882fee87cfba864f2303 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Fri, 8 Nov 2019 15:16:04 -0500 Subject: [PATCH 6/6] fix typos --- hail/python/hail/matrixtable.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index d4e9a6ccb3c..8aab5fc4f83 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -3571,7 +3571,7 @@ def union_rows(*datasets: 'MatrixTable', _check_cols=True) -> 'MatrixTable': return MatrixTable(MatrixUnionRows(*[d._mir for d in datasets])) @typecheck_method(other=matrix_table_type, - row_join_type=enumeration('inner', 'outer', 'left', 'right')) + row_join_type=enumeration('inner', 'outer')) def union_cols(self, other: 'MatrixTable', row_join_type='inner') -> 'MatrixTable': """Take the union of dataset columns. @@ -3598,15 +3598,15 @@ def union_cols(self, other: 'MatrixTable', row_join_type='inner') -> 'MatrixTabl from both input datasets. The set of rows included in the result is determined by the `row_join_type` parameter. - - With the default ``row_join_type=inner``, an inner join is performed + - With the default value of ``'inner'``, an inner join is performed on rows, so that only rows whose row key exists in both input datasets are included. In this case, the entries for each row are the concatenation of all entries of the corresponding rows in the input datasets. - - With ``row_join_type=outer``, an outer join is perfomed on rows, so - that row keys which exist in only one input dataset are also included. - For those rows, the entrie fields for the columns coming from the - other dataset will be missing. 
+ - With `row_join_type` set to ``'outer'``, an outer join is performed on + rows, so that row keys which exist in only one input dataset are also + included. For those rows, the entry fields for the columns coming + from the other dataset will be missing. Only distinct row keys from each dataset are included (equivalent to calling :meth:`.distinct_by_row` on each dataset first).