Added reduceByKey operation for RDDs containing pairs
mateiz committed Oct 4, 2010
1 parent 34ecced commit 9f20b6b
Showing 2 changed files with 38 additions and 2 deletions.
35 changes: 33 additions & 2 deletions src/scala/spark/RDD.scala
@@ -7,6 +7,7 @@ import java.util.Random
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.Map
+import scala.collection.mutable.HashMap
 
 import mesos._
 
@@ -27,7 +28,12 @@ abstract class RDD[T: ClassManifest](
   def filter(f: T => Boolean) = new FilteredRDD(this, sc.clean(f))
   def aggregateSplit() = new SplitRDD(this)
   def cache() = new CachedRDD(this)
-  def sample(withReplacement: Boolean, frac: Double, seed: Int) = new SampledRDD(this, withReplacement, frac, seed)
+  def sample(withReplacement: Boolean, frac: Double, seed: Int) =
+    new SampledRDD(this, withReplacement, frac, seed)
+
+  def flatMap[U: ClassManifest](f: T => Traversable[U]) =
+    new FlatMappedRDD(this, sc.clean(f))
 
   def foreach(f: T => Unit) {
     val cleanF = sc.clean(f)
@@ -140,6 +146,16 @@ extends RDD[T](prev.sparkContext) {
   override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
 }
 
+class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
+  prev: RDD[T], f: T => Traversable[U])
+extends RDD[U](prev.sparkContext) {
+  override def splits = prev.splits
+  override def preferredLocations(split: Split) = prev.preferredLocations(split)
+  override def iterator(split: Split) =
+    prev.iterator(split).toStream.flatMap(f).iterator
+  override def taskStarted(split: Split, slot: SlaveOffer) = prev.taskStarted(split, slot)
+}
+
 class SplitRDD[T: ClassManifest](prev: RDD[T])
 extends RDD[Array[T]](prev.sparkContext) {
   override def splits = prev.splits
@@ -281,7 +297,7 @@ extends RDD[T](sc) {
     s.asInstanceOf[UnionSplit[T]].preferredLocations()
 }
 
-@serializable class CartesianSplit(val s1: Split, val s2: Split) extends Split {}
+@serializable class CartesianSplit(val s1: Split, val s2: Split) extends Split
 
 @serializable
 class CartesianRDD[T: ClassManifest, U:ClassManifest](
@@ -310,3 +326,18 @@ extends RDD[Pair[T, U]](sc) {
     rdd2.taskStarted(currSplit.s2, slot)
   }
 }
+
+@serializable class PairRDDExtras[K, V](rdd: RDD[(K, V)]) {
+  def reduceByKey(func: (V, V) => V): Map[K, V] = {
+    def mergeMaps(m1: HashMap[K, V], m2: HashMap[K, V]): HashMap[K, V] = {
+      for ((k, v) <- m2) {
+        m1.get(k) match {
+          case None => m1(k) = v
+          case Some(w) => m1(k) = func(w, v)
+        }
+      }
+      return m1
+    }
+    rdd.map(pair => HashMap(pair)).reduce(mergeMaps)
+  }
+}
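
For context, a brief usage sketch of the two operations this diff adds. The example below is not part of the commit: it assumes a SparkContext named sc with a parallelize method, as in contemporaneous versions of the codebase, and an import SparkContext._ so the implicit conversion in the second file is in scope.

// Hypothetical driver program illustrating flatMap and reduceByKey.
val words = sc.parallelize(Seq("a", "b", "a", "c", "b", "a"))

// flatMap expands each element into zero or more outputs via Traversable.
val chars = words.flatMap(w => w.toList)

// reduceByKey merges all values that share a key. In this version it returns
// a local mutable Map on the driver rather than an RDD: each pair is first
// wrapped in a one-entry HashMap, and mergeMaps folds the maps together.
val counts = words.map(w => (w, 1)).reduceByKey(_ + _)
// counts == Map("a" -> 3, "b" -> 2, "c" -> 1)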
5 changes: 5 additions & 0 deletions src/scala/spark/SparkContext.scala
@@ -85,9 +85,14 @@ object SparkContext {
     def add(t1: Double, t2: Double): Double = t1 + t2
     def zero(initialValue: Double) = 0.0
   }
+
   implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
     def add(t1: Int, t2: Int): Int = t1 + t2
     def zero(initialValue: Int) = 0
   }
+
   // TODO: Add AccumulatorParams for other types, e.g. lists and strings
+
+  implicit def rddToPairRDDExtras[K, V](rdd: RDD[(K, V)]) =
+    new PairRDDExtras(rdd)
 }
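
The implicit conversion is what lets reduceByKey be called as if it were a method on any RDD of pairs: with the conversion in scope, the compiler rewrites rdd.reduceByKey(f) to rddToPairRDDExtras(rdd).reduceByKey(f). Below is a minimal standalone sketch of the same enrich-my-library pattern, using Seq in place of RDD so it runs without Spark; all names in it are hypothetical.

object ImplicitDemo {
  // A wrapper class supplies the extra method...
  class PairSeqExtras[K, V](pairs: Seq[(K, V)]) {
    def reduceByKey(func: (V, V) => V): Map[K, V] =
      pairs.groupBy(_._1).map { case (k, vs) => (k, vs.map(_._2).reduce(func)) }
  }

  // ...and an implicit def tells the compiler how to wrap the original type.
  implicit def seqToPairSeqExtras[K, V](pairs: Seq[(K, V)]) =
    new PairSeqExtras(pairs)

  def main(args: Array[String]) {
    // The compiler inserts seqToPairSeqExtras(...) automatically:
    println(Seq(("x", 1), ("x", 2), ("y", 3)).reduceByKey(_ + _))
    // prints Map(x -> 3, y -> 3)
  }
}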
