mapWith, flatMapWith and filterWith #510

Merged
merged 12 commits into from Mar 23, 2013
66 changes: 65 additions & 1 deletion core/src/main/scala/spark/RDD.scala
@@ -364,6 +364,62 @@ abstract class RDD[T: ClassManifest](
preservesPartitioning: Boolean = false): RDD[U] =
new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)

/**
* Maps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => U): RDD[U] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
val a = constructA(index)
iter.map(t => f(t, a))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
}
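
For context, a minimal usage sketch of mapWith (illustrative only, not part of this diff). It assumes a live SparkContext named sc, e.g. new SparkContext("local", "test") as in the tests below, and builds one random number generator per partition:

import java.util.Random

val data = sc.makeRDD(1 to 6, 2)
// constructA builds one Random per partition from that partition's index;
// f then uses the partition-local generator for every element it maps.
val perturbed = data.mapWith((index: Int) => new Random(index)) {
  (x: Int, prng: Random) => x + prng.nextDouble
}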

/**
* FlatMaps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => Seq[U]): RDD[U] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
val a = constructA(index)
iter.flatMap(t => f(t, a))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
}
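
Likewise, a sketch of flatMapWith (illustrative, reusing sc, data and the java.util.Random import from the mapWith sketch above); each input element produces a Seq of outputs built with the partition-local generator, and the Seqs are flattened into one RDD:

// Each element yields two jittered values drawn from the per-partition Random.
val jittered = data.flatMapWith((index: Int) => new Random(index)) {
  (x: Int, prng: Random) => Seq(x + prng.nextDouble, x - prng.nextDouble)
}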

/**
* Applies f to each element of this RDD, where f takes an additional parameter of type A.
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def foreachWith[A: ClassManifest](constructA: Int => A)
(f: (T, A) => Unit) {
def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
val a = constructA(index)
iter.map(t => {f(t, a); t})
}
(new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
}
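
A sketch of foreachWith (illustrative, reusing data from above); here the constructed value is simply the partition index itself, so the side effect can report which partition each element lives in:

// A = Int: constructA returns the partition index, which f logs next to each element.
data.foreachWith((index: Int) => index) {
  (x: Int, partition: Int) => println("partition " + partition + ": " + x)
}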

/**
* Filters this RDD with p, where p takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def filterWith[A: ClassManifest](constructA: Int => A)
(p: (T, A) => Boolean): RDD[T] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
val a = constructA(index)
iter.filter(t => p(t, a))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
}
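
A sketch of filterWith (illustrative, reusing data and the Random import from above), keeping roughly a third of each partition with a deterministic per-partition seed, much like the filterWith test below:

// Keep an element only when the partition-local generator draws 0 from {0, 1, 2}.
val sampled = data.filterWith((index: Int) => new Random(index + 42)) {
  (x: Int, prng: Random) => prng.nextInt(3) == 0
}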

/**
* Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
* second element in each RDD, etc. Assumes that the two RDDs have the *same number of
@@ -382,6 +438,14 @@ abstract class RDD[T: ClassManifest](
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}

/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}
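
A sketch of foreachPartition (illustrative, reusing data from above); f receives each partition's iterator exactly once, which is handy when a per-partition resource such as a connection should be set up a single time and reused:

// The iterator is consumed in one pass on the executor that holds the partition.
data.foreachPartition { iter =>
  println("partition sum: " + iter.sum)
}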

/**
* Return an array that contains all of the elements in this RDD.
*/
@@ -404,7 +468,7 @@ abstract class RDD[T: ClassManifest](

/**
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
60 changes: 60 additions & 0 deletions core/src/test/scala/spark/RDDSuite.scala
@@ -178,4 +178,64 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(prunedData.size === 1)
assert(prunedData(0) === 10)
}

test("mapWith") {
import java.util.Random
sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.mapWith(
(index: Int) => new Random(index + 42))
{(t: Int, prng: Random) => prng.nextDouble * t}.collect()
val prn42_3 = {
val prng42 = new Random(42)
prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
}
val prn43_3 = {
val prng43 = new Random(43)
prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
}
assert(randoms(2) === prn42_3)
assert(randoms(5) === prn43_3)
}

test("flatMapWith") {
import java.util.Random
sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.flatMapWith(
(index: Int) => new Random(index + 42))
{(t: Int, prng: Random) =>
val random = prng.nextDouble()
Seq(random * t, random * t * 10)}.
collect()
val prn42_3 = {
val prng42 = new Random(42)
prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
}
val prn43_3 = {
val prng43 = new Random(43)
prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
}
assert(randoms(5) === prn42_3 * 10)
assert(randoms(11) === prn43_3 * 10)
}

test("filterWith") {
import java.util.Random
sc = new SparkContext("local", "test")
val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
val sample = ints.filterWith(
(index: Int) => new Random(index + 42))
{(t: Int, prng: Random) => prng.nextInt(3) == 0}.
collect()
val checkSample = {
val prng42 = new Random(42)
val prng43 = new Random(43)
Array(1, 2, 3, 4, 5, 6).filter{i =>
if (i < 4) 0 == prng42.nextInt(3)
else 0 == prng43.nextInt(3)}
}
assert(sample.size === checkSample.size)
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
}
}