Merge pull request alteryx#5 from markhamstra/streamingIterable
Streaming iterable
jhartlaub committed Jan 30, 2014
2 parents 13555f1 + 178ef28 commit 7f2d770
Showing 3 changed files with 95 additions and 1 deletion.
17 changes: 16 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -38,10 +38,12 @@ import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{Utils, BoundedPriorityQueue}
import org.apache.spark.util.{RDDiterable, Utils, BoundedPriorityQueue}

import org.apache.spark.SparkContext._
import org.apache.spark._
import scala.concurrent.duration.Duration
import java.util.concurrent.TimeUnit

/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -574,6 +576,8 @@ abstract class RDD[T: ClassManifest](
sc.runJob(this, (iter: Iterator[T]) => f(iter))
}



/**
* Return an array that contains all of the elements in this RDD.
*/
@@ -594,6 +598,17 @@ abstract class RDD[T: ClassManifest](
filter(f.isDefinedAt).map(f)
}

  /**
   * Return an Iterable that lazily fetches this RDD's partitions to the driver instead of
   * materializing the entire RDD at once.
   * @param prefetchPartitions how many partitions to prefetch ahead of the consumer; a larger
   *                           value increases parallelism but also increases the driver's
   *                           memory requirement
   * @param timeOut how long to wait for each partition fetch before failing
   * @return an Iterable over every element in this RDD
   */
  def toIterable(prefetchPartitions: Int = 1, timeOut: Duration = Duration(30, TimeUnit.SECONDS)): Iterable[T] = {
    new RDDiterable[T](this, prefetchPartitions, timeOut)
  }

/**
* Return an RDD with the elements from `this` that are not in `other`.
*
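Not part of the diff: a minimal usage sketch of the new API, assuming an existing SparkContext `sc`; the data sizes and prefetch/timeout values are arbitrary. It shows how a driver can stream an RDD's contents instead of collecting every partition at once.

import java.util.concurrent.TimeUnit
import scala.concurrent.duration.Duration

val nums = sc.parallelize(1 to 1000000, 100)   // `sc` is an existing SparkContext
// Fetch at most two partitions ahead of the consumer; wait up to 60s per partition.
val streamed = nums.toIterable(prefetchPartitions = 2,
                               timeOut = Duration(60, TimeUnit.SECONDS))
// Unlike collect(), only the partitions actually consumed (plus the prefetch
// window) are brought to the driver.
streamed.take(10).foreach(println)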
59 changes: 59 additions & 0 deletions core/src/main/scala/org/apache/spark/util/RDDiterable.scala
@@ -0,0 +1,59 @@
package org.apache.spark.util

import scala.collection.immutable.Queue
import scala.concurrent.{Await, Future}
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration.Duration
import scala.annotation.tailrec
import org.apache.spark.rdd.RDD

/**
 * Iterable whose iterator iterates over all elements of an RDD without fetching every
 * partition to the driver process at once.
 *
 * @param rdd the RDD to iterate over
 * @param prefetchPartitions the number of partitions to prefetch ahead of the consumer
 * @param timeOut how long to wait for each partition before failing
 * @tparam T the element type of the RDD
 */
class RDDiterable[T: ClassManifest](rdd: RDD[T], prefetchPartitions: Int, timeOut: Duration)
  extends Serializable with Iterable[T] {

  def iterator: Iterator[T] = new Iterator[T] {
    // Indices of partitions that have not yet been requested.
    var partitions = Range(0, rdd.partitions.size)
    // Futures for partitions that have been requested but not yet consumed.
    var pendingFetches = Queue.empty.enqueue(partitions.take(prefetchPartitions).map(par => fetchData(par)))
    partitions = partitions.drop(prefetchPartitions)
    // Iterator over the partition currently being consumed.
    var currentIterator: Iterator[T] = Iterator.empty

    @tailrec
    def hasNext(): Boolean = {
      if (currentIterator.hasNext) {
        true
      } else {
        // Keep the prefetch window full: request the next partition, if any remain.
        pendingFetches = partitions.headOption.map { partitionNo =>
          pendingFetches.enqueue(fetchData(partitionNo))
        }.getOrElse(pendingFetches)
        partitions = partitions.drop(1)

        if (pendingFetches.isEmpty) {
          currentIterator = Iterator.empty
          false
        } else {
          // Block until the oldest outstanding fetch completes, then recurse in case
          // the fetched partition is empty.
          val (future, pendingFetchesN) = pendingFetches.dequeue
          pendingFetches = pendingFetchesN
          currentIterator = Await.result(future, timeOut).iterator
          this.hasNext()
        }
      }
    }

    def next(): T = {
      hasNext()
      currentIterator.next()
    }
  }

  // Asynchronously collect a single partition to the driver as a one-task job.
  private def fetchData(partitionIndex: Int): Future[Seq[T]] = {
    val results = new ArrayBuffer[T]()
    rdd.context.submitJob[T, Array[T], Seq[T]](rdd,
      x => x.toArray,
      List(partitionIndex),
      (inx: Int, res: Array[T]) => results.appendAll(res),
      results.toSeq)
  }
}
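Also not part of the diff: because hasNext blocks in Await.result, a partition that does not arrive within timeOut surfaces as a java.util.concurrent.TimeoutException on the driver. A hedged consumption sketch, again assuming an existing SparkContext `sc` and arbitrary sizes:

import java.util.concurrent.{TimeUnit, TimeoutException}
import scala.concurrent.duration.Duration

val rdd = sc.parallelize(1 to 100000, 50)      // `sc` is an existing SparkContext
val it = rdd.toIterable(prefetchPartitions = 4,
                        timeOut = Duration(10, TimeUnit.SECONDS)).iterator
try {
  var total = 0L
  while (it.hasNext) total += it.next()        // partitions are fetched on demand
  println("sum = " + total)
} catch {
  case e: TimeoutException =>
    // Await.result threw: a partition did not arrive within timeOut. The iterable
    // itself does not retry, so the caller decides how to recover.
    println("partition fetch timed out: " + e)
}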
20 changes: 20 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -340,6 +340,26 @@ class RDDSuite extends FunSuite with SharedSparkContext {
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
}

test("toIterable") {
var nums = sc.makeRDD(Range(1, 1000), 100)
assert(nums.toIterable(prefetchPartitions = 10).size === 999)
assert(nums.toIterable().toArray === (1 to 999).toArray)

nums = sc.makeRDD(Range(1000, 1, -1), 100)
assert(nums.toIterable(prefetchPartitions = 10).size === 999)
assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1000, 1, -1).toArray)

nums = sc.makeRDD(Range(1, 100), 1000)
assert(nums.toIterable(prefetchPartitions = 10).size === 99)
assert(nums.toIterable(prefetchPartitions = 10).toArray === Range(1, 100).toArray)

nums = sc.makeRDD(Range(1, 1000), 100)
assert(nums.toIterable(prefetchPartitions = -1).size === 999)
assert(nums.toIterable().toArray === (1 to 999).toArray)
}



test("take") {
var nums = sc.makeRDD(Range(1, 1000), 1)
assert(nums.take(0).size === 0)