Added reduceByKey operation for RDDs containing pairs

mateiz committed Oct 4, 2010
1 parent 34ecced · commit 9f20b6b4338d70262be3a2256856080aa22ecbce
Showing 2 changed files with 38 additions and 2 deletions.
  1. +33 −2 src/scala/spark/RDD.scala
  2. +5 −0 src/scala/spark/SparkContext.scala
@@ -7,6 +7,7 @@ import java.util.Random
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.Map
+import scala.collection.mutable.HashMap
import mesos._
@@ -27,7 +28,12 @@ abstract class RDD[T: ClassManifest](
def filter(f: T => Boolean) = new FilteredRDD(this, sc.clean(f))
def aggregateSplit() = new SplitRDD(this)
def cache() = new CachedRDD(this)
- def sample(withReplacement: Boolean, frac: Double, seed: Int) = new SampledRDD(this, withReplacement, frac, seed)
+
+ def sample(withReplacement: Boolean, frac: Double, seed: Int) =
+ new SampledRDD(this, withReplacement, frac, seed)
+
+ def flatMap[U: ClassManifest](f: T => Traversable[U]) =
+ new FlatMappedRDD(this, sc.clean(f))
def foreach(f: T => Unit) {
val cleanF = sc.clean(f)
@@ -140,6 +146,16 @@ extends RDD[T](prev.sparkContext) {
override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
}
+class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
+ prev: RDD[T], f: T => Traversable[U])
+extends RDD[U](prev.sparkContext) {
+ override def splits = prev.splits
+ override def preferredLocations(split: Split) = prev.preferredLocations(split)
+ override def iterator(split: Split) =
+ prev.iterator(split).toStream.flatMap(f).iterator
+ override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+}
+
class SplitRDD[T: ClassManifest](prev: RDD[T])
extends RDD[Array[T]](prev.sparkContext) {
override def splits = prev.splits
@@ -281,7 +297,7 @@ extends RDD[T](sc) {
s.asInstanceOf[UnionSplit[T]].preferredLocations()
}
-@serializable class CartesianSplit(val s1: Split, val s2: Split) extends Split {}
+@serializable class CartesianSplit(val s1: Split, val s2: Split) extends Split
@serializable
class CartesianRDD[T: ClassManifest, U:ClassManifest](
@@ -310,3 +326,18 @@ extends RDD[Pair[T, U]](sc) {
rdd2.taskStarted(currSplit.s2, slot)
}
}
+
+@serializable class PairRDDExtras[K, V](rdd: RDD[(K, V)]) {
+ def reduceByKey(func: (V, V) => V): Map[K, V] = {
+ def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = {
+ for ((k, v) <- m2) {
+ m1.get(k) match {
+ case None => m1(k) = v
+ case Some(w) => m1(k) = func(w, v)
+ }
+ }
+ return m1
+ }
+ rdd.map(pair => HashMap(pair)).reduce(mergeMaps)
+ }
+}
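
A note while reading this hunk: each pair becomes a singleton HashMap and the maps are folded together with rdd.reduce, so reduceByKey in this version returns a mutable Map materialized at the caller rather than a new RDD. A minimal sketch of the same merge on plain Scala collections (illustrative values only, no SparkContext involved):

    import scala.collection.mutable.HashMap

    // Stand-in for the contents of an RDD[(String, Int)].
    val pairs = Seq(("a", 1), ("b", 1), ("a", 1))

    // Each pair becomes a one-entry HashMap; the fold combines values for
    // the same key with the supplied function (here, addition), exactly as
    // mergeMaps does above.
    val counts = pairs.map(p => HashMap(p)).reduce { (m1, m2) =>
      for ((k, v) <- m2) {
        m1(k) = m1.get(k).map(_ + v).getOrElse(v)
      }
      m1
    }
    // counts: HashMap("a" -> 2, "b" -> 1)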
@@ -85,9 +85,14 @@ object SparkContext {
def add(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
}
+
implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
def add(t1: Int, t2: Int): Int = t1 + t2
def zero(initialValue: Int) = 0
}
+
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
+
+ implicit def rddToPairRDDExtras[K, V](rdd: RDD[(K, V)]) =
+ new PairRDDExtras(rdd)
}
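
For context, a sketch of how the pieces added in this commit compose once the implicit conversion above is in scope. The import lines below (package spark, SparkContext._) are assumptions for illustration rather than part of the diff:

    import spark.RDD
    import spark.SparkContext._  // assumed import; brings rddToPairRDDExtras into scope

    // Word-count sketch combining the two operations added in this commit:
    // flatMap splits each line into words, reduceByKey sums per-word counts.
    // As noted above, reduceByKey here returns a local mutable Map, not another RDD.
    def wordCount(lines: RDD[String]) =
      lines.flatMap(line => line.split(" ").toSeq)
           .map(word => (word, 1))
           .reduceByKey(_ + _)

The implicit rddToPairRDDExtras is what makes reduceByKey appear as a method on any RDD of pairs without modifying the RDD class itself.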
