Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Simplify and genericize type parameters in Bagel

  • Loading branch information...
commit 0028caf3a4727623f70e23cd2f611f9797d0a3d3 1 parent 2d7057b
@ankurdave ankurdave authored
Showing with 129 additions and 85 deletions.
  1. +129 −85 bagel/src/main/scala/spark/bagel/Bagel.scala
View
214 bagel/src/main/scala/spark/bagel/Bagel.scala
@@ -6,54 +6,110 @@ import spark.SparkContext._
import scala.collection.mutable.ArrayBuffer
object Bagel extends Logging {
- def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest](
+ def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
+ C : Manifest, A : Manifest](
sc: SparkContext,
- verts: RDD[(String, V)],
- msgs: RDD[(String, M)]
+ vertices: RDD[(K, V)],
+ messages: RDD[(K, M)],
+ combiner: Combiner[M, C],
+ aggregator: Option[Aggregator[V, A]],
+ partitioner: Partitioner,
+ numSplits: Int
)(
- combiner: Combiner[M, C] = new DefaultCombiner[M],
- aggregator: Aggregator[V, A] = new NullAggregator[V],
- superstep: Int = 0,
- numSplits: Int = sc.defaultParallelism
- )(
- compute: (V, Option[C], A, Int) => (V, Iterable[M])
- ): RDD[V] = {
-
- logInfo("Starting superstep "+superstep+".")
- val startTime = System.currentTimeMillis
-
- val aggregated = agg(verts, aggregator)
- val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits)
- val grouped = verts.groupWith(combinedMsgs)
- val (processed, numMsgs, numActiveVerts) = comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
-
- val timeTaken = System.currentTimeMillis - startTime
- logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
-
- // Check stopping condition and iterate
- val noActivity = numMsgs == 0 && numActiveVerts == 0
- if (noActivity) {
- processed.map { case (id, (vert, msgs)) => vert }
- } else {
- val newVerts = processed.mapValues { case (vert, msgs) => vert }
- val newMsgs = processed.flatMap {
+ compute: (V, Option[C], Option[A], Int) => (V, Array[M])
+ ): RDD[(K, V)] = {
+ val splits = if (numSplits != 0) numSplits else sc.defaultParallelism
+
+ var superstep = 0
+ var verts = vertices
+ var msgs = messages
+ var noActivity = false
+ do {
+ logInfo("Starting superstep "+superstep+".")
+ val startTime = System.currentTimeMillis
+
+ val aggregated = agg(verts, aggregator)
+ val combinedMsgs = msgs.combineByKey(
+ combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners,
+ splits, partitioner)
+ val grouped = combinedMsgs.groupWith(verts)
+ val (processed, numMsgs, numActiveVerts) =
+ comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
+
+ val timeTaken = System.currentTimeMillis - startTime
+ logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
+
+ verts = processed.mapValues { case (vert, msgs) => vert }
+ msgs = processed.flatMap {
case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
}
- run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute)
- }
+ superstep += 1
+
+ noActivity = numMsgs == 0 && numActiveVerts == 0
+ } while (!noActivity)
+
+ verts
+ }
+
+ def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
+ C : Manifest](
+ sc: SparkContext,
+ vertices: RDD[(K, V)],
+ messages: RDD[(K, M)],
+ combiner: Combiner[M, C],
+ partitioner: Partitioner,
+ numSplits: Int
+ )(
+ compute: (V, Option[C], Int) => (V, Array[M])
+ ): RDD[(K, V)] = {
+ run[K, V, M, C, Nothing](
+ sc, vertices, messages, combiner, None, partitioner, numSplits)(
+ addAggregatorArg[K, V, M, C](compute))
+ }
+
+ def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
+ C : Manifest](
+ sc: SparkContext,
+ vertices: RDD[(K, V)],
+ messages: RDD[(K, M)],
+ combiner: Combiner[M, C],
+ numSplits: Int
+ )(
+ compute: (V, Option[C], Int) => (V, Array[M])
+ ): RDD[(K, V)] = {
+ val part = new HashPartitioner(numSplits)
+ run[K, V, M, C, Nothing](
+ sc, vertices, messages, combiner, None, part, numSplits)(
+ addAggregatorArg[K, V, M, C](compute))
+ }
+
+ def run[K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
+ sc: SparkContext,
+ vertices: RDD[(K, V)],
+ messages: RDD[(K, M)],
+ numSplits: Int
+ )(
+ compute: (V, Option[Array[M]], Int) => (V, Array[M])
+ ): RDD[(K, V)] = {
+ val part = new HashPartitioner(numSplits)
+ run[K, V, M, Array[M], Nothing](
+ sc, vertices, messages, new DefaultCombiner(), None, part, numSplits)(
+ addAggregatorArg[K, V, M, Array[M]](compute))
}
/**
- * Aggregates the given vertices using the given aggregator, or does
- * nothing if it is a NullAggregator.
+ * Aggregates the given vertices using the given aggregator, if it
+ * is specified.
*/
- def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A = aggregator match {
- case _: NullAggregator[_] =>
- None
- case _ =>
- verts.map {
- case (id, vert) => aggregator.createAggregator(vert)
- }.reduce(aggregator.mergeAggregators(_, _))
+ private def agg[K, V <: Vertex, A : Manifest](
+ verts: RDD[(K, V)],
+ aggregator: Option[Aggregator[V, A]]
+ ): Option[A] = aggregator match {
+ case Some(a) =>
+ Some(verts.map {
+ case (id, vert) => a.createAggregator(vert)
+ }.reduce(a.mergeAggregators(_, _)))
+ case None => None
}
/**
@@ -61,23 +117,27 @@ object Bagel extends Logging {
* function. Returns the processed RDD, the number of messages
* created, and the number of active vertices.
*/
- def comp[V <: Vertex, M <: Message, C](sc: SparkContext, grouped: RDD[(String, (Seq[V], Seq[C]))], compute: (V, Option[C]) => (V, Iterable[M])): (RDD[(String, (V, Iterable[M]))], Int, Int) = {
+ private def comp[K : Manifest, V <: Vertex, M <: Message[K], C](
+ sc: SparkContext,
+ grouped: RDD[(K, (Seq[C], Seq[V]))],
+ compute: (V, Option[C]) => (V, Array[M])
+ ): (RDD[(K, (V, Array[M]))], Int, Int) = {
var numMsgs = sc.accumulator(0)
var numActiveVerts = sc.accumulator(0)
val processed = grouped.flatMapValues {
- case (Seq(), _) => None
- case (Seq(v), c) =>
- val (newVert, newMsgs) =
- compute(v, c match {
- case Seq(comb) => Some(comb)
- case Seq() => None
- })
-
- numMsgs += newMsgs.size
- if (newVert.active)
- numActiveVerts += 1
-
- Some((newVert, newMsgs))
+ case (_, vs) if vs.size == 0 => None
+ case (c, vs) =>
+ val (newVert, newMsgs) =
+ compute(vs(0), c match {
+ case Seq(comb) => Some(comb)
+ case Seq() => None
+ })
+
+ numMsgs += newMsgs.size
+ if (newVert.active)
+ numActiveVerts += 1
+
+ Some((newVert, newMsgs))
}.cache
// Force evaluation of processed RDD for accurate performance measurements
@@ -90,16 +150,16 @@ object Bagel extends Logging {
* Converts a compute function that doesn't take an aggregator to
* one that does, so it can be passed to Bagel.run.
*/
- implicit def addAggregatorArg[
- V <: Vertex : Manifest, M <: Message : Manifest, C
+ private def addAggregatorArg[
+ K : Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C
](
- compute: (V, Option[C], Int) => (V, Iterable[M])
- ): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = {
- (vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) => compute(vert, messages, superstep)
+ compute: (V, Option[C], Int) => (V, Array[M])
+ ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = {
+ (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) =>
+ compute(vert, msgs, superstep)
}
}
-// TODO: Simplify Combiner interface and make it more OO.
trait Combiner[M, C] {
def createCombiner(msg: M): C
def mergeMsg(combiner: C, msg: M): C
@@ -111,18 +171,13 @@ trait Aggregator[V, A] {
def mergeAggregators(a: A, b: A): A
}
-class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] with Serializable {
- def createCombiner(msg: M): ArrayBuffer[M] =
- ArrayBuffer(msg)
- def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
- combiner += msg
- def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] =
- a ++= b
-}
-
-class NullAggregator[V] extends Aggregator[V, Option[Nothing]] with Serializable {
- def createAggregator(vert: V): Option[Nothing] = None
- def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None
+class DefaultCombiner[M : Manifest] extends Combiner[M, Array[M]] with Serializable {
+ def createCombiner(msg: M): Array[M] =
+ Array(msg)
+ def mergeMsg(combiner: Array[M], msg: M): Array[M] =
+ combiner :+ msg
+ def mergeCombiners(a: Array[M], b: Array[M]): Array[M] =
+ a ++ b
}
/**
@@ -132,7 +187,6 @@ class NullAggregator[V] extends Aggregator[V, Option[Nothing]] with Serializable
* inherit from java.io.Serializable or scala.Serializable.
*/
trait Vertex {
- def id: String
def active: Boolean
}
@@ -142,16 +196,6 @@ trait Vertex {
* Subclasses may contain a payload to deliver to the target vertex
* and must inherit from java.io.Serializable or scala.Serializable.
*/
-trait Message {
- def targetId: String
-}
-
-/**
- * Represents a directed edge between two vertices.
- *
- * Subclasses may store state along each edge and must inherit from
- * java.io.Serializable or scala.Serializable.
- */
-trait Edge {
- def targetId: String
+trait Message[K] {
+ def targetId: K
}
Please sign in to comment.
Something went wrong with that request. Please try again.