diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index af0114bee3f49..c07d4b85b213b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -73,11 +73,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. */ - def mapPartitionsWithIndex[R: ClassTag]( - f: JFunction2[java.lang.Integer, java.util.Iterator[T], java.util.Iterator[R]], - preservesPartitioning: Boolean = false): JavaRDD[R] = - new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), - preservesPartitioning)) + def mapPartitionsWithIndex[R](f: MapPartitionsWithIndexFunction[T, R], + preservesPartitioning: Boolean = false): JavaRDD[R] = { + import scala.collection.JavaConverters._ + def fn = (a: Int, b: Iterator[T]) => f.apply(a, asJavaIterator(b)).asScala + val newRdd = rdd.mapPartitionsWithIndex(fn, preservesPartitioning)(f.elementType()) + new JavaRDD(newRdd)(f.elementType()) + } /** * Return a new RDD by applying a function to all elements of this RDD. diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 40e853c39ca99..2dfa52a86a355 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -414,6 +414,26 @@ public void javaDoubleRDDHistoGram() { Assert.assertArrayEquals(expected_counts, histogram); } + @Test + public void mapPartitionsWithIndex() { + JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); + JavaRDD rddByIndex = + rdd.mapPartitionsWithIndex(new MapPartitionsWithIndexFunction() { + @Override + public Iterator call(Integer start, java.util.Iterator iter) { + List list = new ArrayList(); + int pos = start; + while (iter.hasNext()) { + list.add(iter.next() * pos); + pos += 1; + } + return list.iterator(); + } + }, false); + Assert.assertEquals(0, rddByIndex.first().intValue()); + } + + @Test public void map() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));