From 12772ac016be43fb4bae8bce01ecc598329b93b0 Mon Sep 17 00:00:00 2001 From: Tim Poterba Date: Mon, 26 Aug 2019 15:33:46 -0400 Subject: [PATCH] [hail] Better scaling on RVD.union Do a tree reduce instead of a linear reduce. This means that the java stack depth is log2(N) instead of N, and prevents stack overflow errors when unioning hundreds of tables together. --- hail/python/test/hail/table/test_table.py | 7 +++++++ hail/src/main/scala/is/hail/rvd/RVD.scala | 2 +- .../is/hail/utils/richUtils/RichIndexedSeq.scala | 15 +++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/hail/python/test/hail/table/test_table.py b/hail/python/test/hail/table/test_table.py index 01d6aea39ce..b1e6c6e3e5a 100644 --- a/hail/python/test/hail/table/test_table.py +++ b/hail/python/test/hail/table/test_table.py @@ -813,6 +813,13 @@ def test_union(self): self.assertTrue(t1.key_by().union(t2.key_by(), t3.key_by()) ._same(hl.utils.range_table(15).key_by())) + def test_nested_union(self): + N = 100 + M = 200 + t = hl.utils.range_table(N, n_partitions=16) + + assert hl.Table.union(*[t for _ in range(M)])._force_count() == N * M + def test_union_unify(self): t1 = hl.utils.range_table(2) t2 = t1.annotate(x=hl.int32(1), y='A') diff --git a/hail/src/main/scala/is/hail/rvd/RVD.scala b/hail/src/main/scala/is/hail/rvd/RVD.scala index ec97dac2761..f178360ddef 100644 --- a/hail/src/main/scala/is/hail/rvd/RVD.scala +++ b/hail/src/main/scala/is/hail/rvd/RVD.scala @@ -1449,7 +1449,7 @@ object RVD { val sc = first.sparkContext RVD.unkeyed(first.rowPType, ContextRDD.union(sc, rvds.map(_.crdd))) } else - rvds.reduce(_.orderedMerge(_, joinKey)) + rvds.toArray.treeReduce(_.orderedMerge(_, joinKey)) } def union(rvds: Seq[RVD]): RVD = diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichIndexedSeq.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichIndexedSeq.scala index 09aa8d50c32..4e0d9309214 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichIndexedSeq.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichIndexedSeq.scala @@ -1,5 +1,9 @@ package is.hail.utils.richUtils +import is.hail.utils._ + +import scala.reflect.ClassTag + /** Rich wrapper for an indexed sequence. * * Houses the generic binary search methods. All methods taking @@ -153,4 +157,15 @@ class RichIndexedSeq[T](val a: IndexedSeq[T]) extends AnyVal { } notFound(left) } + + def treeReduce(f: (T, T) => T)(implicit tct: ClassTag[T]): T = { + var is: IndexedSeq[T] = a + while (is.length > 1) { + is = is.iterator.grouped(2).map { + case Seq(x1, x2) => f(x1, x2) + case Seq(x1) => x1 + }.toFastIndexedSeq + } + is.head + } }