Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Fix Java API for mapPartitionsWithIndex

  • Loading branch information...
commit 215a9bf5bc36bb53c68112aa5cd8a52152d8cd69 1 parent 84f7ca1
@holdenk authored
View
12 core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -73,11 +73,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
- def mapPartitionsWithIndex[R: ClassTag](
- f: JFunction2[Int, java.util.Iterator[T], java.util.Iterator[R]],
- preservesPartitioning: Boolean = false): JavaRDD[R] =
- new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))),
- preservesPartitioning))
+ def mapPartitionsWithIndex[R](f: MapPartitionsWithIndexFunction[T, R],
+ preservesPartitioning: Boolean = false): JavaRDD[R] = {
+ import scala.collection.JavaConverters._
+ def fn = (a: Int, b: Iterator[T]) => f.apply(a, asJavaIterator(b)).asScala
+ val newRdd = rdd.mapPartitionsWithIndex(fn, preservesPartitioning)(f.elementType())
+ new JavaRDD(newRdd)(f.elementType())
+ }
/**
* Return a new RDD by applying a function to all elements of this RDD.
View
20 core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -378,6 +378,26 @@ public void javaDoubleRDDHistoGram() {
}
@Test
+ public void mapPartitionsWithIndex() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ JavaRDD<Integer> rddByIndex =
+ rdd.mapPartitionsWithIndex(new MapPartitionsWithIndexFunction<Integer, Integer>() {
+ @Override
+ public Iterator<Integer> call(Integer start, java.util.Iterator<Integer> iter) {
+ List<Integer> list = new ArrayList<Integer>();
+ int pos = start;
+ while (iter.hasNext()) {
+ list.add(iter.next() * pos);
+ pos += 1;
+ }
+ return list.iterator();
+ }
+ }, false);
+ Assert.assertEquals(0, rddByIndex.first().intValue());
+ }
+
+
+ @Test
public void map() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
Please sign in to comment.
Something went wrong with that request. Please try again.