Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Spark 1165 rdd.intersection in python and java

Author: Prashant Sharma <prashant.s@imaginea.com>
Author: Prashant Sharma <scrapcodes@gmail.com>

Closes #80 from ScrapCodes/SPARK-1165/RDD.intersection and squashes the following commits:

9b015e9 [Prashant Sharma] Added a note, shuffle is required for intersection.
1fea813 [Prashant Sharma] correct the lines wrapping
d0c71f3 [Prashant Sharma] SPARK-1165 RDD.intersection in java
d6effee [Prashant Sharma] SPARK-1165 Implemented RDD.intersection in python.
  • Loading branch information...
commit 6e730edcde7ca6cbb5727dff7a42f7284b368528 1 parent b7cd9e9
@ScrapCodes ScrapCodes authored pwendell committed
View
8 core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -140,6 +140,14 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
*/
def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd))
+ /**
+ * Return the intersection of this RDD and another one. The output will not contain any duplicate
+ * elements, even if the input RDDs did.
+ *
+ * Note that this method performs a shuffle internally.
+ */
+ def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd))
+
// Double RDD functions
/** Add up the elements in this RDD. */
View
10 core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -126,6 +126,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.union(other.rdd))
+ /**
+ * Return the intersection of this RDD and another one. The output will not contain any duplicate
+ * elements, even if the input RDDs did.
+ *
+ * Note that this method performs a shuffle internally.
+ */
+ def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd.intersection(other.rdd))
+
+
// first() has to be overridden here so that the generated method has the signature
// 'public scala.Tuple2 first()'; if the trait's definition is used,
// then the method has the signature 'public java.lang.Object first()',
View
9 core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -106,6 +106,15 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
*/
def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))
+
+ /**
+ * Return the intersection of this RDD and another one. The output will not contain any duplicate
+ * elements, even if the input RDDs did.
+ *
+ * Note that this method performs a shuffle internally.
+ */
+ def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd))
+
/**
* Return an RDD with the elements from `this` that are not in `other`.
*
View
31 core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -110,6 +110,37 @@ public void sparkContextUnion() {
Assert.assertEquals(4, pUnion.count());
}
+ @SuppressWarnings("unchecked")
+ @Test
+ public void intersection() {
+ List<Integer> ints1 = Arrays.asList(1, 10, 2, 3, 4, 5);
+ List<Integer> ints2 = Arrays.asList(1, 6, 2, 3, 7, 8);
+ JavaRDD<Integer> s1 = sc.parallelize(ints1);
+ JavaRDD<Integer> s2 = sc.parallelize(ints2);
+
+ JavaRDD<Integer> intersections = s1.intersection(s2);
+ Assert.assertEquals(3, intersections.count());
+
+ ArrayList<Integer> list = new ArrayList<Integer>();
+ JavaRDD<Integer> empty = sc.parallelize(list);
+ JavaRDD<Integer> emptyIntersection = empty.intersection(s2);
+ Assert.assertEquals(0, emptyIntersection.count());
+
+ List<Double> doubles = Arrays.asList(1.0, 2.0);
+ JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles);
+ JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles);
+ JavaDoubleRDD dIntersection = d1.intersection(d2);
+ Assert.assertEquals(2, dIntersection.count());
+
+ List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
+ pairs.add(new Tuple2<Integer, Integer>(1, 2));
+ pairs.add(new Tuple2<Integer, Integer>(3, 4));
+ JavaPairRDD<Integer, Integer> p1 = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> p2 = sc.parallelizePairs(pairs);
+ JavaPairRDD<Integer, Integer> pIntersection = p1.intersection(p2);
+ Assert.assertEquals(2, pIntersection.count());
+ }
+
@Test
public void sortByKey() {
List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>();
View
17 python/pyspark/rdd.py
@@ -326,6 +326,23 @@ def union(self, other):
return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
self.ctx.serializer)
+ def intersection(self, other):
+ """
+ Return the intersection of this RDD and another one. The output will not
+ contain any duplicate elements, even if the input RDDs did.
+
+ Note that this method performs a shuffle internally.
+
+ >>> rdd1 = sc.parallelize([1, 10, 2, 3, 4, 5])
+ >>> rdd2 = sc.parallelize([1, 6, 2, 3, 7, 8])
+ >>> rdd1.intersection(rdd2).collect()
+ [1, 2, 3]
+ """
+ return self.map(lambda v: (v, None)) \
+ .cogroup(other.map(lambda v: (v, None))) \
+ .filter(lambda x: (len(x[1][0]) != 0) and (len(x[1][1]) != 0)) \
+ .keys()
+
def _reserialize(self):
if self._jrdd_deserializer == self.ctx.serializer:
return self
Please sign in to comment.
Something went wrong with that request. Please try again.