diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 717247ba2fc4d..9f8cbe9e8da96 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -38,10 +38,12 @@
 import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{Utils, BoundedPriorityQueue}
+import org.apache.spark.util.{RDDiterable, Utils, BoundedPriorityQueue}
 import org.apache.spark.SparkContext._
 import org.apache.spark._
+import scala.concurrent.duration.Duration
+import java.util.concurrent.TimeUnit
 
 /**
  * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -574,6 +576,8 @@ abstract class RDD[T: ClassManifest](
     sc.runJob(this, (iter: Iterator[T]) => f(iter))
   }
 
+
+
   /**
    * Return an array that contains all of the elements in this RDD.
    */
@@ -594,6 +598,17 @@ abstract class RDD[T: ClassManifest](
     filter(f.isDefinedAt).map(f)
   }
 
+  /**
+   * Return an iterable that lazily fetches the partitions of this RDD instead of collecting them all at once.
+   * @param prefetchPartitions How many partitions to prefetch ahead of the one being consumed. A larger value
+   *                           increases parallelism but also increases the driver's memory requirement.
+   * @param timeOut How long to wait for each partition fetch before failing.
+   * @return an Iterable over every element in this RDD
+   */
+  def toIterable(prefetchPartitions: Int = 1, timeOut: Duration = Duration(30, TimeUnit.SECONDS)): Iterable[T] = {
+    new RDDiterable[T](this, prefetchPartitions, timeOut)
+  }
+
   /**
    * Return an RDD with the elements from `this` that are not in `other`.
    *
diff --git a/core/src/main/scala/org/apache/spark/util/RDDiterable.scala b/core/src/main/scala/org/apache/spark/util/RDDiterable.scala
new file mode 100644
index 0000000000000..320c8d69d133a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/RDDiterable.scala
@@ -0,0 +1,59 @@
+package org.apache.spark.util
+
+import scala.collection.immutable.Queue
+import scala.concurrent.{Await, Future}
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration.Duration
+import scala.annotation.tailrec
+import org.apache.spark.rdd.RDD
+
+/** Iterable whose iterator iterates over all elements of an RDD without fetching all partitions to the driver process at once.
+ *
+ * @param rdd RDD to iterate over
+ * @param prefetchPartitions The number of partitions to prefetch ahead of the one being consumed
+ * @param timeOut How long to wait for each partition fetch before failing
+ * @tparam T the element type of the RDD
+ */
+class RDDiterable[T: ClassManifest](rdd: RDD[T], prefetchPartitions: Int, timeOut: Duration) extends Serializable with Iterable[T] {
+
+  def iterator = new Iterator[T] {
+    var partitions = Range(0, rdd.partitions.size)
+    var pendingFetches = Queue.empty[Future[Seq[T]]].enqueue(partitions.take(prefetchPartitions).map(par => fetchData(par)))
+    partitions = partitions.drop(prefetchPartitions)
+    var currentIterator: Iterator[T] = Iterator.empty
+    @tailrec
+    def hasNext: Boolean = {
+      if (currentIterator.hasNext) {
+        true
+      } else {
+        pendingFetches = partitions.headOption.map {
+          partitionNo =>
+            pendingFetches.enqueue(fetchData(partitionNo))
+        }.getOrElse(pendingFetches)
+        partitions = partitions.drop(1)
+
+        if (pendingFetches.isEmpty) {
+          currentIterator = Iterator.empty
+          false
+        } else {
+          val (future, pendingFetchesN) = pendingFetches.dequeue
+          pendingFetches = pendingFetchesN
+          currentIterator = Await.result(future, timeOut).iterator
+          this.hasNext
+        }
+      }
+    }
+    def next() = {
+      hasNext
+      currentIterator.next()
+    }
+  }
+  private def fetchData(partitionIndex: Int): Future[Seq[T]] = {
+    val results = new ArrayBuffer[T]()
+    rdd.context.submitJob[T, Array[T], Seq[T]](rdd,
+      x => x.toArray,
+      List(partitionIndex),
+      (index: Int, res: Array[T]) => results.appendAll(res),
+      results.toSeq)
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 354ab8ae5d7d5..7525b37ec6cc5 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -340,6 +340,26 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
   }
 
+  test("toIterable") {
+    var nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterable(prefetchPartitions = 10).size === 999)
+    assert(nums.toIterable().toArray === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1000, 1, -1), 100)
+    assert(nums.toIterable(prefetchPartitions = 10).size === 999)
+    assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)
+
+    nums = sc.makeRDD(Range(1, 100), 1000)
+    assert(nums.toIterable(prefetchPartitions = 10).size === 99)
+    assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1, 100).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.toIterable(prefetchPartitions = -1).size === 999)
+    assert(nums.toIterable().toArray === (1 to 999).toArray)
+  }
+
+
+
   test("take") {
     var nums = sc.makeRDD(Range(1, 1000), 1)
     assert(nums.take(0).size === 0)
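
For context, a minimal driver-side sketch of how the new toIterable method could be consumed once this patch is applied. The local-mode SparkContext setup, the sample data, and the object name ToIterableExample are illustrative assumptions, not part of the patch:

import java.util.concurrent.TimeUnit

import scala.concurrent.duration.Duration

import org.apache.spark.SparkContext

object ToIterableExample {
  def main(args: Array[String]) {
    // Local-mode context purely for illustration.
    val sc = new SparkContext("local[2]", "toIterable-example")
    try {
      val rdd = sc.makeRDD(1 to 100000, 50)
      // Iterate over the RDD one partition at a time, prefetching two partitions
      // ahead of the consumer, instead of collect()-ing everything into driver memory.
      val it = rdd.toIterable(prefetchPartitions = 2, timeOut = Duration(60, TimeUnit.SECONDS))
      val sum = it.foldLeft(0L)(_ + _)
      println("sum = " + sum)
    } finally {
      sc.stop()
    }
  }
}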