Bagel: Large-scale graph processing on Spark

Bagel is an implementation of the Pregel graph processing framework on Spark.

Bagel currently supports basic graph computation, combiners, and aggregators. Future work includes support for mutating the graph topology. Tests exist but currently don't run due to a Spark bug.

I would recommend that you refactor your code before merging; it is always harder, and less tempting, to do afterwards.


This looks great, Ankur, except for two naming things: can you change the package name from bagel to spark.bagel, and can you rename the Pregel class to Bagel?


Sure, I've done so.

This page is out of date. Refresh to see the latest.
159 bagel/src/main/scala/spark/bagel/Bagel.scala
@@ -0,0 +1,159 @@
+package spark.bagel
+import spark._
+import spark.SparkContext._
+import scala.collection.mutable.ArrayBuffer
+object Bagel extends Logging {
+  /**
+   * Runs the given compute function over the graph, one superstep at a
+   * time, until a superstep produces no messages and leaves no vertex
+   * active.
+   *
+   * In each superstep the vertices are aggregated with `aggregator`, the
+   * pending messages are combined per key with `combiner`, and `compute`
+   * is applied to every vertex together with its combined messages, the
+   * aggregated value and the superstep number.
+   *
+   * @param sc the SparkContext used to create the activity counters
+   * @param verts the vertices, keyed by vertex id
+   * @param msgs the messages to deliver in this superstep, keyed by target vertex id
+   * @param combiner how messages sharing a key are combined
+   * @param aggregator how the vertices are aggregated before each superstep
+   * @param superstep the number of the superstep to run first
+   * @param numSplits passed to combineByKey when combining the messages
+   * @param compute maps a vertex, its combined messages, the aggregated
+   *                value and the superstep number to the updated vertex
+   *                and its outgoing messages
+   * @return the vertices produced by the final superstep
+   */
+  def run[V <: Vertex : Manifest, M <: Message : Manifest, C : Manifest, A : Manifest](
+    sc: SparkContext,
+    verts: RDD[(String, V)],
+    msgs: RDD[(String, M)]
+  )(
+    combiner: Combiner[M, C] = new DefaultCombiner[M],
+    aggregator: Aggregator[V, A] = new NullAggregator[V],
+    superstep: Int = 0,
+    numSplits: Int = sc.numCores
+  )(
+    compute: (V, Option[C], A, Int) => (V, Iterable[M])
+  ): RDD[V] = {
+    logInfo("Starting superstep "+superstep+".")
+    val startTime = System.currentTimeMillis
+
+    val aggregated = agg(verts, aggregator)
+    val combinedMsgs = msgs.combineByKey(combiner.createCombiner, combiner.mergeMsg, combiner.mergeCombiners, numSplits)
+    val grouped = verts.groupWith(combinedMsgs)
+    val (processed, numMsgs, numActiveVerts) = comp[V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
+
+    val timeTaken = System.currentTimeMillis - startTime
+    logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
+
+    // Check stopping condition and iterate
+    val noActivity = numMsgs == 0 && numActiveVerts == 0
+    if (noActivity) {
+      // Nothing left to do: strip the ids and outboxes and return the vertices.
+      processed.map { case (id, (vert, msgs)) => vert }
+    } else {
+      val newVerts = processed.mapValues { case (vert, msgs) => vert }
+      // Re-key every outgoing message by the vertex it is addressed to.
+      val newMsgs = processed.flatMap {
+        case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
+      }
+      run(sc, newVerts, newMsgs)(combiner, aggregator, superstep + 1, numSplits)(compute)
+    }
+  }
+
+  /**
+   * Aggregates the given vertices using the given aggregator, or does
+   * nothing if it is a NullAggregator.
+   */
+  def agg[V <: Vertex, A : Manifest](verts: RDD[(String, V)], aggregator: Aggregator[V, A]): A = aggregator match {
+    case _: NullAggregator[_] =>
+      None
+    case _ =>
+      verts.map {
+        case (id, vert) => aggregator.createAggregator(vert)
+      }.reduce(aggregator.mergeAggregators(_, _))
+  }
+
+  /**
+   * Processes the given vertex-message RDD using the compute
+   * function. Returns the processed RDD, the number of messages
+   * created, and the number of active vertices.
+   *
+   * Keys that have messages but no vertex produce no output.
+   */
+  def comp[V <: Vertex, M <: Message, C](
+    sc: SparkContext,
+    grouped: RDD[(String, (Seq[V], Seq[C]))],
+    compute: (V, Option[C]) => (V, Iterable[M])
+  ): (RDD[(String, (V, Iterable[M]))], Int, Int) = {
+    var numMsgs = sc.accumulator(0)
+    var numActiveVerts = sc.accumulator(0)
+    val processed = grouped.flatMapValues {
+      case (Seq(), _) => None
+      case (Seq(v), c) =>
+        val (newVert, newMsgs) =
+          compute(v, c match {
+            case Seq(comb) => Some(comb)
+            case Seq() => None
+          })
+        numMsgs += newMsgs.size
+        if (newVert.active)
+          numActiveVerts += 1
+        Some((newVert, newMsgs))
+    }.cache
+
+    // Force evaluation of processed RDD for accurate performance measurements
+    processed.foreach(x => {})
+
+    (processed, numMsgs.value, numActiveVerts.value)
+  }
+
+  /**
+   * Converts a compute function that doesn't take an aggregator to
+   * one that does, so it can be passed to Bagel.run.
+   */
+  implicit def addAggregatorArg[
+    V <: Vertex : Manifest, M <: Message : Manifest, C
+  ](
+    compute: (V, Option[C], Int) => (V, Iterable[M])
+  ): (V, Option[C], Option[Nothing], Int) => (V, Iterable[M]) = {
+    // The aggregator argument is accepted and ignored.
+    (vert: V, messages: Option[C], aggregator: Option[Nothing], superstep: Int) => compute(vert, messages, superstep)
+  }
+}
+// TODO: Simplify Combiner interface and make it more OO.
+/**
+ * Describes how messages of type M that share a key are folded into a
+ * single combined value of type C. Bagel.run hands these three
+ * functions to combineByKey.
+ */
+trait Combiner[M, C] {
+  /** Starts a combined value from a single message. */
+  def createCombiner(msg: M): C
+  /** Folds one more message into an existing combined value. */
+  def mergeMsg(combiner: C, msg: M): C
+  /** Merges two combined values into one. */
+  def mergeCombiners(a: C, b: C): C
+}
+/**
+ * Describes how vertices of type V are reduced to a single value of
+ * type A. Bagel.agg maps createAggregator over the vertices and
+ * reduces the results with mergeAggregators.
+ */
+trait Aggregator[V, A] {
+  /** Produces the aggregate contribution of a single vertex. */
+  def createAggregator(vert: V): A
+  /** Merges two aggregate values into one. */
+  def mergeAggregators(a: A, b: A): A
+}
+/**
+ * Combiner that performs no reduction: it collects every message into
+ * an ArrayBuffer. Both merge operations append in place to their first
+ * argument and return it.
+ */
+class DefaultCombiner[M] extends Combiner[M, ArrayBuffer[M]] {
+  def createCombiner(msg: M): ArrayBuffer[M] =
+    ArrayBuffer(msg)
+  def mergeMsg(combiner: ArrayBuffer[M], msg: M): ArrayBuffer[M] =
+    combiner += msg
+  def mergeCombiners(a: ArrayBuffer[M], b: ArrayBuffer[M]): ArrayBuffer[M] =
+    a ++= b
+}
+/**
+ * Aggregator that aggregates nothing: every operation yields None.
+ * Bagel.agg recognises this class and skips the aggregation pass.
+ */
+class NullAggregator[V] extends Aggregator[V, Option[Nothing]] {
+  def createAggregator(vert: V): Option[Nothing] = None
+  def mergeAggregators(a: Option[Nothing], b: Option[Nothing]): Option[Nothing] = None
+}
+/**
+ * Represents a Bagel vertex.
+ *
+ * Subclasses may store state along with each vertex and must be
+ * annotated with @serializable.
+ */
+trait Vertex {
+  /** Identifier of this vertex. */
+  def id: String
+  /** Whether this vertex is still active; Bagel.comp counts active vertices after each superstep. */
+  def active: Boolean
+}
+/**
+ * Represents a Bagel message to a target vertex.
+ *
+ * Subclasses may contain a payload to deliver to the target vertex
+ * and must be annotated with @serializable.
+ */
+trait Message {
+  /** Id of the vertex this message is addressed to. */
+  def targetId: String
+}
+/**
+ * Represents a directed edge between two vertices.
+ *
+ * Subclasses may store state along each edge and must be annotated
+ * with @serializable.
+ */
+trait Edge {
+  /** Id of the vertex this edge points to. */
+  def targetId: String
+}
96 bagel/src/main/scala/spark/bagel/examples/ShortestPath.scala
@@ -0,0 +1,96 @@
+package spark.bagel.examples
+import spark._
+import spark.SparkContext._
+import scala.math.min
+import spark.bagel._
+import spark.bagel.Bagel._
+/**
+ * Example Bagel program that computes shortest paths over a graph read
+ * from a tab-separated text file.
+ *
+ * Lines matching "^\\s*#.*" are ignored. Three-field lines
+ * (source, target, edge value) define edges; two-field lines
+ * (vertex, value) define the initial messages.
+ */
+object ShortestPath {
+  def main(args: Array[String]) {
+    if (args.length < 4) {
+      System.err.println("Usage: ShortestPath <graphFile> <startVertex> " +
+                         "<numSplits> <host>")
+      System.exit(-1)
+    }
+    val graphFile = args(0)
+    val startVertex = args(1)
+    val numSplits = args(2).toInt
+    val host = args(3)
+    val sc = new SparkContext(host, "ShortestPath")
+
+    // Parse the graph data from a file into two RDDs, vertices and messages
+    val lines =
+      (sc.textFile(graphFile)
+       .filter(!_.matches("^\\s*#.*"))
+       .map(line => line.split("\t")))
+
+    val vertices: RDD[(String, SPVertex)] =
+      (lines.groupBy(line => line(0))
+       .map {
+         case (vertexId, lines) => {
+           // Only three-field lines contribute edges; other lines are skipped by collect.
+           val outEdges = lines.collect {
+             case Array(_, targetId, edgeValue) =>
+               new SPEdge(targetId, edgeValue.toInt)
+           }
+           (vertexId, new SPVertex(vertexId, Int.MaxValue, outEdges, true))
+         }
+       })
+
+    val messages: RDD[(String, SPMessage)] =
+      (lines.filter(_.length == 2)
+       .map {
+         case Array(vertexId, messageValue) =>
+           (vertexId, new SPMessage(vertexId, messageValue.toInt))
+       })
+
+    System.err.println("Read "+vertices.count()+" vertices and "+
+                       messages.count()+" messages.")
+
+    // Do the computation
+    val compute = addAggregatorArg {
+      (self: SPVertex, messageMinValue: Option[Int], superstep: Int) =>
+        val newValue = messageMinValue match {
+          case Some(minVal) => min(self.value, minVal)
+          case None => self.value
+        }
+        // Only notify neighbours when this vertex's value changed.
+        val outbox =
+          if (newValue != self.value)
+            self.outEdges.map(edge =>
+              new SPMessage(edge.targetId, newValue + edge.value))
+          else
+            List()
+        (new SPVertex(self.id, newValue, self.outEdges, false), outbox)
+    }
+    val result = Bagel.run(sc, vertices, messages)(combiner = MinCombiner, numSplits = numSplits)(compute)
+
+    // Print the result
+    System.err.println("Shortest path from "+startVertex+" to all vertices:")
+    val shortest = result.map(vertex =>
+      "%s\t%s\n".format(vertex.id, vertex.value match {
+        case x if x == Int.MaxValue => "inf"
+        case x => x
+      })).collect.mkString
+    println(shortest)
+  }
+}
+/** Combiner that reduces the messages for a vertex to the minimum of their values. */
+object MinCombiner extends Combiner[SPMessage, Int] {
+  def createCombiner(msg: SPMessage): Int =
+    msg.value
+  def mergeMsg(combiner: Int, msg: SPMessage): Int =
+    min(combiner, msg.value)
+  def mergeCombiners(a: Int, b: Int): Int =
+    min(a, b)
+}
+/** Shortest-path vertex; ShortestPath creates each one with a value of Int.MaxValue. */
+@serializable
+class SPVertex(
+  val id: String,
+  val value: Int,
+  val outEdges: Seq[SPEdge],
+  val active: Boolean
+) extends Vertex
+/** Directed edge to `targetId` carrying an integer value. */
+@serializable
+class SPEdge(
+  val targetId: String,
+  val value: Int
+) extends Edge
+/** Message delivering an integer value to the vertex `targetId`. */
+@serializable
+class SPMessage(
+  val targetId: String,
+  val value: Int
+) extends Message
158 bagel/src/main/scala/spark/bagel/examples/WikipediaPageRank.scala
@@ -0,0 +1,158 @@
+package spark.bagel.examples
+import spark._
+import spark.SparkContext._
+import spark.bagel._
+import spark.bagel.Bagel._
+import scala.collection.mutable.ArrayBuffer
+import scala.xml.{XML,NodeSeq}
+import com.esotericsoftware.kryo._
+/**
+ * Example Bagel program that runs PageRank over a tab-separated input
+ * file and prints every article whose rank is at least the threshold.
+ *
+ * Each input line is read as one article: field 1 is the title and
+ * field 3 is the body, from which link targets are extracted as XML.
+ */
+object WikipediaPageRank {
+  def main(args: Array[String]) {
+    if (args.length < 4) {
+      System.err.println("Usage: WikipediaPageRank <inputFile> <threshold> <numSplits> <host> [<noCombiner>]")
+      System.exit(-1)
+    }
+
+    System.setProperty("spark.serialization", "spark.KryoSerialization")
+    System.setProperty("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)
+
+    val inputFile = args(0)
+    val threshold = args(1).toDouble
+    val numSplits = args(2).toInt
+    val host = args(3)
+    // Any non-empty fifth argument switches the message combiner off.
+    val noCombiner = args.length > 4 && args(4).nonEmpty
+    val sc = new SparkContext(host, "WikipediaPageRank")
+
+    // Parse the Wikipedia page data into a graph
+    val input = sc.textFile(inputFile)
+
+    println("Counting vertices...")
+    val numVertices = input.count()
+    println("Done counting vertices.")
+
+    println("Parsing input file...")
+    val vertices: RDD[(String, PRVertex)] = input.map(line => {
+      val fields = line.split("\t")
+      val (title, body) = (fields(1), fields(3).replace("\\n", "\n"))
+      // A body of "\N" is treated as having no links; malformed XML is
+      // reported on stderr and also treated as having no links.
+      val links =
+        if (body == "\\N")
+          NodeSeq.Empty
+        else
+          try {
+            XML.loadString(body) \\ "link" \ "target"
+          } catch {
+            case e: org.xml.sax.SAXParseException =>
+              System.err.println("Article \""+title+"\" has malformed XML in body:\n"+body)
+            NodeSeq.Empty
+          }
+      val outEdges = ArrayBuffer(links.map(link => new PREdge(new String(link.text))): _*)
+      val id = new String(title)
+      (id, new PRVertex(id, 1.0 / numVertices, outEdges, true))
+    }).cache
+    println("Done parsing input file.")
+
+    // Do the computation
+    val epsilon = 0.01 / numVertices
+    val messages = sc.parallelize(List[(String, PRMessage)]())
+    val result =
+      if (noCombiner) {
+        Bagel.run(sc, vertices, messages)(numSplits = numSplits)(PRNoCombiner.compute(numVertices, epsilon))
+      } else {
+        Bagel.run(sc, vertices, messages)(combiner = PRCombiner, numSplits = numSplits)(PRCombiner.compute(numVertices, epsilon))
+      }
+
+    // Print the result
+    System.err.println("Articles with PageRank >= "+threshold+":")
+    val top = result.filter(_.value >= threshold).map(vertex =>
+      "%s\t%s\n".format(vertex.id, vertex.value)).collect.mkString
+    println(top)
+  }
+}
+/**
+ * Combiner that sums the values of the messages sent to a vertex, plus
+ * the PageRank compute function that consumes that sum.
+ */
+object PRCombiner extends Combiner[PRMessage, Double] {
+  def createCombiner(msg: PRMessage): Double =
+    msg.value
+  def mergeMsg(combiner: Double, msg: PRMessage): Double =
+    combiner + msg.value
+  def mergeCombiners(a: Double, b: Double): Double =
+    a + b
+
+  /**
+   * Computes one PageRank step for a vertex.
+   *
+   * A non-zero message sum yields 0.15 / numVertices + 0.85 * sum;
+   * otherwise the vertex keeps its value. The vertex terminates once
+   * superstep >= 10 and its value changed by less than epsilon, or
+   * unconditionally once superstep >= 30. A terminated vertex sends no
+   * messages and is marked inactive.
+   */
+  def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messageSum: Option[Double], superstep: Int): (PRVertex, Iterable[PRMessage]) = {
+    val newValue = messageSum match {
+      case Some(msgSum) if msgSum != 0 =>
+        0.15 / numVertices + 0.85 * msgSum
+      case _ => self.value
+    }
+    val terminate = (superstep >= 10 && (newValue - self.value).abs < epsilon) || superstep >= 30
+    // Share the new value equally among the outgoing edges.
+    val outbox =
+      if (!terminate)
+        self.outEdges.map(edge =>
+          new PRMessage(edge.targetId, newValue / self.outEdges.size))
+      else
+        ArrayBuffer[PRMessage]()
+    (new PRVertex(self.id, newValue, self.outEdges, !terminate), outbox)
+  }
+}
+/**
+ * PageRank variant for runs without a message combiner: messages arrive
+ * as an ArrayBuffer and are summed here before delegating to
+ * PRCombiner.compute.
+ */
+object PRNoCombiner extends DefaultCombiner[PRMessage] {
+  def compute(numVertices: Long, epsilon: Double)(self: PRVertex, messages: Option[ArrayBuffer[PRMessage]], superstep: Int): (PRVertex, Iterable[PRMessage]) =
+    PRCombiner.compute(numVertices, epsilon)(self, messages match {
+      case Some(msgs) => Some(msgs.map(_.value).sum)
+      case None => None
+    }, superstep)
+}
+/**
+ * PageRank vertex with mutable fields and a no-argument primary
+ * constructor; the auxiliary constructor fills in every field.
+ */
+@serializable class PRVertex() extends Vertex {
+  var id: String = _
+  var value: Double = _
+  var outEdges: ArrayBuffer[PREdge] = _
+  var active: Boolean = true
+
+  def this(id: String, value: Double, outEdges: ArrayBuffer[PREdge], active: Boolean) {
+    this()
+    this.id = id
+    this.value = value
+    this.outEdges = outEdges
+    this.active = active
+  }
+}
+/**
+ * PageRank message carrying a value to the vertex `targetId`; has a
+ * no-argument primary constructor and mutable fields.
+ */
+@serializable class PRMessage() extends Message {
+  var targetId: String = _
+  var value: Double = _
+
+  def this(targetId: String, value: Double) {
+    this()
+    this.targetId = targetId
+    this.value = value
+  }
+}
+/**
+ * PageRank edge pointing at the vertex `targetId`; has a no-argument
+ * primary constructor and a mutable field.
+ */
+@serializable class PREdge() extends Edge {
+  var targetId: String = _
+
+  def this(targetId: String) {
+    this()
+    this.targetId = targetId
+  }
+}
+/**
+ * Registers the PageRank vertex, message and edge classes with Kryo.
+ * WikipediaPageRank names this class in the "spark.kryo.registrator"
+ * system property.
+ */
+class PRKryoRegistrator extends KryoRegistrator {
+  def registerClasses(kryo: Kryo) {
+    kryo.register(classOf[PRVertex])
+    kryo.register(classOf[PRMessage])
+    kryo.register(classOf[PREdge])
+  }
+}
53 bagel/src/test/scala/bagel/BagelSuite.scala
@@ -0,0 +1,53 @@
+package spark.bagel
+import org.scalatest.{FunSuite, Assertions}
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+import scala.collection.mutable.ArrayBuffer
+import spark._
+import spark.bagel.Bagel._
+/** Vertex used by BagelSuite; the test compute functions add one to `age` on every call. */
+@serializable
+class TestVertex(
+  val id: String,
+  val active: Boolean,
+  val age: Int
+) extends Vertex
+/** Payload-free message used by BagelSuite. */
+@serializable
+class TestMessage(
+  val targetId: String
+) extends Message
+/** Checks the two ways a Bagel run can halt: all vertices inactive, or no messages sent. */
+class BagelSuite extends FunSuite with Assertions {
+  test("halting by voting") {
+    val sc = new SparkContext("local", "test")
+    val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(id, true, 0))))
+    val msgs = sc.parallelize(Array[(String, TestMessage)]())
+    val numSupersteps = 5
+    // Vertices stay active until the last superstep and never send messages.
+    val result =
+      Bagel.run(sc, verts, msgs)()(addAggregatorArg {
+        (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
+          (new TestVertex(self.id, superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
+      })
+    for (vert <- result.collect)
+      assert(vert.age === numSupersteps)
+  }
+
+  test("halting by message silence") {
+    val sc = new SparkContext("local", "test")
+    val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(id, false, 0))))
+    val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
+    val numSupersteps = 5
+    // Vertices start inactive; a vertex forwards its messages until the last superstep.
+    val result =
+      Bagel.run(sc, verts, msgs)()(addAggregatorArg {
+        (self: TestVertex, msgs: Option[ArrayBuffer[TestMessage]], superstep: Int) =>
+          val msgsOut =
+            msgs match {
+              case Some(ms) if (superstep < numSupersteps - 1) =>
+                ms
+              case _ =>
+                new ArrayBuffer[TestMessage]()
+            }
+          (new TestVertex(self.id, self.active, self.age + 1), msgsOut)
+      })
+    for (vert <- result.collect)
+      assert(vert.age === numSupersteps)
+  }
+}
6 project/build/SparkProject.scala
@@ -14,6 +14,8 @@ extends ParentProject(info) with IdeaProject
lazy val examples =
project("examples", "Spark Examples", new ExamplesProject(_), core)
+ lazy val bagel = project("bagel", "Bagel", new BagelProject(_), core)
class CoreProject(info: ProjectInfo)
extends DefaultProject(info) with Eclipsify with IdeaProject with DepJar with XmlTestReport
@@ -21,6 +23,10 @@ extends ParentProject(info) with IdeaProject
class ExamplesProject(info: ProjectInfo)
extends DefaultProject(info) with Eclipsify with IdeaProject
+ // Build definition for the "bagel" subproject; no settings beyond the DepJar and XmlTestReport mix-ins.
+ class BagelProject(info: ProjectInfo)
+ extends DefaultProject(info) with DepJar with XmlTestReport
+ {}
2  run
@@ -35,6 +35,7 @@ export JAVA_OPTS
# Build up classpath
@@ -60,6 +61,7 @@ CLASSPATH+=:$EXAMPLES_DIR/target/scala_2.8.1/classes
for jar in $CORE_DIR/lib/hadoop-0.20.2/lib/*.jar; do
export CLASSPATH # Needed for spark-shell
if [ -n "$SCALA_HOME" ]; then
Something went wrong with that request. Please try again.