Skip to content

Commit

Permalink
Fix Java API for mapPartitionsWithIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
holdenk committed Mar 8, 2014
1 parent 0b7b7fd commit 8d849a1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
12 changes: 7 additions & 5 deletions core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
Expand Up @@ -73,11 +73,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition. * of the original partition.
*/ */
def mapPartitionsWithIndex[R: ClassTag]( def mapPartitionsWithIndex[R](f: MapPartitionsWithIndexFunction[T, R],
f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = {
preservesPartitioning: Boolean = false): JavaRDD[R] = import scala.collection.JavaConverters._
new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), def fn = (a: Int, b: Iterator[T]) => f.apply(a, asJavaIterator(b)).asScala
preservesPartitioning)) val newRdd = rdd.mapPartitionsWithIndex(fn, preservesPartitioning)(f.elementType())
new JavaRDD(newRdd)(f.elementType())
}


/** /**
* Return a new RDD by applying a function to all elements of this RDD. * Return a new RDD by applying a function to all elements of this RDD.
Expand Down
20 changes: 20 additions & 0 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
Expand Up @@ -414,6 +414,26 @@ public void javaDoubleRDDHistoGram() {
Assert.assertArrayEquals(expected_counts, histogram); Assert.assertArrayEquals(expected_counts, histogram);
} }


/**
 * Exercises the Java-facing {@code mapPartitionsWithIndex}.
 *
 * The supplied function multiplies every element of a partition by a counter that starts at
 * the partition index passed to it and increases by one per element. The test then checks
 * that the first element of the resulting RDD is 0.
 */
@Test
public void mapPartitionsWithIndex() {
  JavaRDD<Integer> input = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
  MapPartitionsWithIndexFunction<Integer, Integer> scaleByPosition =
    new MapPartitionsWithIndexFunction<Integer, Integer>() {
      @Override
      public java.util.Iterator<Integer> call(Integer partitionIndex,
                                              java.util.Iterator<Integer> elements) {
        List<Integer> scaled = new ArrayList<Integer>();
        // The multiplier begins at the partition index and advances once per element.
        for (int multiplier = partitionIndex; elements.hasNext(); multiplier++) {
          scaled.add(elements.next() * multiplier);
        }
        return scaled.iterator();
      }
    };
  JavaRDD<Integer> rddByIndex = input.mapPartitionsWithIndex(scaleByPosition, false);
  Assert.assertEquals(0, rddByIndex.first().intValue());
}


@Test @Test
public void map() { public void map() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
Expand Down

0 comments on commit 8d849a1

Please sign in to comment.