Smarter take/limit implementation.
rxin committed Sep 21, 2013
1 parent 119de80 commit 42571d3
Showing 2 changed files with 66 additions and 10 deletions.
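
For context, a hedged usage sketch of the kind of call this change speeds up (the RDD, the numbers, and `sc` being a live SparkContext are assumptions, not part of the commit): the old take() launched one runJob per partition, strictly one at a time, while the new one launches jobs over growing ranges of partitions.

// Hypothetical example, not from the commit: a selective filter leaves roughly one
// matching row per partition, so take(100) has to look at many partitions.
val rdd = sc.parallelize(1 to 1000000, 1000)   // 1000 partitions; 'sc' assumed in scope
val rare = rdd.filter(_ % 997 == 0)
val firstHundred = rare.take(100)
// Before this commit: up to ~100 sequential jobs, one partition each.
// After: one job on the first partition, then jobs over partition ranges whose size is
// estimated from the rows seen so far and overestimated by 50%.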
38 changes: 28 additions & 10 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -753,24 +753,42 @@ abstract class RDD[T: ClassManifest](
   }
 
   /**
-   * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
-   * it will be slow if a lot of partitions are required. In that case, use collect() to get the
-   * whole RDD instead.
+   * Take the first num elements of the RDD. It works by first scanning one partition, and uses the
+   * results from that partition to estimate the number of additional partitions needed to satisfy
+   * the limit.
    */
   def take(num: Int): Array[T] = {
     if (num == 0) {
       return new Array[T](0)
     }
+
     val buf = new ArrayBuffer[T]
-    var p = 0
-    while (buf.size < num && p < partitions.size) {
+    val totalParts = this.partitions.length
+    var partsScanned = 0
+    while (buf.size < num && partsScanned < totalParts) {
+      // The number of partitions to try in this iteration. It is ok for this number to be
+      // greater than totalParts because we actually cap it at totalParts in runJob.
+      var numPartsToTry = 1
+      if (partsScanned > 0) {
+        // If we didn't find any rows after the first iteration, just try all partitions next.
+        // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+        // by 50%.
+        if (buf.size == 0) {
+          numPartsToTry = totalParts - 1
+        } else {
+          numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+        }
+      }
+      numPartsToTry = math.max(0, numPartsToTry)  // guard against negative num of partitions
+
       val left = num - buf.size
-      val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true)
-      buf ++= res(0)
-      if (buf.size == num)
-        return buf.toArray
-      p += 1
+      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+      val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true)
+
+      res.foreach(buf ++= _.take(num - buf.size))
+      partsScanned += numPartsToTry
     }
+
     return buf.toArray
   }
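
The interpolation above is easier to follow with concrete numbers. Below is a minimal standalone sketch (not part of the commit; rowsPerPart, the hard-coded sizes, and the object name are invented for illustration) that replays the estimation loop for a hypothetical RDD whose partitions all yield the same number of rows:

// Standalone sketch (assumptions only, not from the commit): traces how the new
// take() grows the number of partitions scanned per pass.
object TakeEstimateDemo {
  def main(args: Array[String]): Unit = {
    val num = 500          // rows requested by take(num)
    val totalParts = 100   // partitions in the hypothetical RDD
    val rowsPerPart = 10   // assumed yield of every partition (made up for the demo)

    var collected = 0      // stands in for buf.size
    var partsScanned = 0
    var pass = 1
    while (collected < num && partsScanned < totalParts) {
      // Same estimation logic as the patched take(): try 1 partition first, then
      // interpolate from the observed yield and overestimate by 50%.
      var numPartsToTry = 1
      if (partsScanned > 0) {
        numPartsToTry =
          if (collected == 0) totalParts - 1
          else (1.5 * num * partsScanned / collected).toInt
      }
      numPartsToTry = math.max(0, numPartsToTry)

      // Cap at the partitions that actually remain, as runJob would.
      val partsThisPass = math.min(numPartsToTry, totalParts - partsScanned)
      collected = math.min(num, collected + partsThisPass * rowsPerPart)
      partsScanned += numPartsToTry
      println(s"pass $pass: ran $partsThisPass partition(s), have $collected/$num rows")
      pass += 1
    }
  }
}

Under these assumptions the first pass returns 10 of the 500 requested rows, so the second pass tries (1.5 * 500 * 1 / 10).toInt = 75 partitions and the scan finishes in two jobs, where the old one-partition-per-pass loop would have needed about 50.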

38 changes: 38 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -321,6 +321,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
   }
 
+  test("take") {
+    var nums = sc.makeRDD(Range(1, 1000), 1)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 2)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 1000)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+  }
+
   test("top with predefined ordering") {
     val nums = Array.range(1, 100000)
     val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
