
Commit 268690b

Merge pull request #533 from mridulm/yarn
Fix performance issues and bugs in YARN branch.
mateiz committed Mar 30, 2013
2 parents: a113b88 + c951237
Showing 22 changed files with 1,136 additions and 452 deletions.
7 changes: 3 additions & 4 deletions core/src/main/scala/spark/CacheTracker.scala
@@ -5,8 +5,6 @@ import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -100,7 +98,7 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val actorName: String = "CacheTracker"

val timeout = 10.seconds
val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")

var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props[CacheTrackerActor], name = actorName)
@@ -211,8 +209,9 @@ private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, b
// TODO: also register a listener for when it unloads
logInfo("Computing partition " + split)
val elements = new ArrayBuffer[Any]
elements ++= rdd.compute(split)
try {
// compute can throw exceptions - ensure that lock is released if that happens.
elements ++= rdd.compute(split)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
//future.apply() // Wait for the reply from the cache tracker
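For reference, a minimal self-contained sketch of the configurable ask-timeout pattern introduced above: reading spark.akka.askTimeout with a 10-second default instead of hard-coding 10.seconds. The TimeoutSketch object and the use of scala.concurrent.duration are illustrative, not part of this commit:

import scala.concurrent.duration._

object TimeoutSketch {
  // Read the ask timeout from a system property, defaulting to 10 seconds,
  // mirroring the Duration.create(...) call added above.
  def askTimeout: FiniteDuration =
    System.getProperty("spark.akka.askTimeout", "10").toLong.seconds

  def main(args: Array[String]): Unit = {
    println(askTimeout)                            // 10 seconds (default)
    System.setProperty("spark.akka.askTimeout", "30")
    println(askTimeout)                            // 30 seconds
  }
}
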
25 changes: 16 additions & 9 deletions core/src/main/scala/spark/FetchFailedException.scala
@@ -3,18 +3,25 @@ package spark
import spark.storage.BlockManagerId

private[spark] class FetchFailedException(
val bmAddress: BlockManagerId,
val shuffleId: Int,
val mapId: Int,
val reduceId: Int,
taskEndReason: TaskEndReason,
message: String,
cause: Throwable)
extends Exception {

override def getMessage(): String =
"Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)

def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) =
this(FetchFailed(bmAddress, shuffleId, mapId, reduceId),
"Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId),
cause)

def this (shuffleId: Int, reduceId: Int, cause: Throwable) =
this(FetchFailed(null, shuffleId, -1, reduceId),
"Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause)

override def getMessage(): String = message


override def getCause(): Throwable = cause

def toTaskEndReason: TaskEndReason =
FetchFailed(bmAddress, shuffleId, mapId, reduceId)
def toTaskEndReason: TaskEndReason = taskEndReason

}
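
For reference, a simplified sketch of the constructor-delegation pattern used in the reworked exception above, where the auxiliary constructor builds the reason and message before calling the primary constructor. The Reason/Failed types, the FetchFailedSketchException name, and the sample IDs are stand-ins, not Spark's classes:

sealed trait Reason
case class Failed(shuffleId: Int, mapId: Int, reduceId: Int) extends Reason

class FetchFailedSketchException(
    val reason: Reason,
    message: String,
    cause: Throwable)
  extends Exception {

  // Auxiliary constructor for the "could not fetch locations from the master" case,
  // where no map output address or map ID is known yet.
  def this(shuffleId: Int, reduceId: Int, cause: Throwable) =
    this(Failed(shuffleId, -1, reduceId),
      "Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause)

  override def getMessage(): String = message
  override def getCause(): Throwable = cause
}

object FetchFailedSketch {
  def main(args: Array[String]): Unit = {
    val e = new FetchFailedSketchException(3, 7, null)
    println(e.getMessage)   // Unable to fetch locations from master: 3 7
  }
}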
4 changes: 4 additions & 0 deletions core/src/main/scala/spark/Logging.scala
@@ -69,6 +69,10 @@ trait Logging {
if (log.isErrorEnabled) log.error(msg, throwable)
}

protected def isTraceEnabled(): Boolean = {
log.isTraceEnabled
}

// Method for ensuring that logging is initialized, to avoid having multiple
// threads do it concurrently (as SLF4J initialization is not thread safe).
protected def initLogging() { log }
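For reference, a self-contained sketch of how the new isTraceEnabled() guard can be used to avoid building expensive log messages when tracing is off. The TraceGuardSketch trait is illustrative (it assumes an SLF4J binding on the classpath) and is not Spark's Logging trait:

import org.slf4j.{Logger, LoggerFactory}

trait TraceGuardSketch {
  private lazy val log: Logger = LoggerFactory.getLogger(getClass)

  protected def isTraceEnabled(): Boolean = log.isTraceEnabled

  // Only build the (potentially expensive) message when trace logging is enabled.
  def traceState(expensiveState: => String): Unit = {
    if (isTraceEnabled()) {
      log.trace("Current state: " + expensiveState)
    }
  }
}
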
110 changes: 75 additions & 35 deletions core/src/main/scala/spark/MapOutputTracker.scala
@@ -8,8 +8,6 @@ import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._

import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
@@ -37,13 +35,13 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
}

private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
val host: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val actorName: String = "MapOutputTracker"
private val masterHost: String = System.getProperty("spark.master.host", "localhost")
private val masterPort: Int = System.getProperty("spark.master.port", "7077").toInt
private val actorName: String = "MapOutputTracker"

val timeout = 10.seconds
val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")

var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
private var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]

// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@@ -52,14 +50,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea

// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]
private val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]

var trackerActor: ActorRef = if (isMaster) {
private var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
logInfo("Registered MapOutputTrackerActor actor")
actor
} else {
val url = "akka://spark@%s:%s/user/%s".format(host, port, actorName)
val url = "akka://spark@%s:%s/user/%s".format(masterHost, masterPort, actorName)
actorSystem.actorFor(url)
}

@@ -83,12 +81,11 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
}

def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.get(shuffleId) != null) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)) != null) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}

def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
var array = mapStatuses.get(shuffleId)
array.synchronized {
@@ -128,6 +125,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
val statuses = mapStatuses.get(shuffleId)
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@@ -138,29 +136,65 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
case e: InterruptedException =>
}
}
return mapStatuses.get(shuffleId).map(status =>
(status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId))))
} else {
}

// Either the fetch completed successfully while we waited, or someone fetched it
// between the get and the fetching.synchronized block.
fetchedStatuses = mapStatuses.get(shuffleId)
if (null == fetchedStatuses) {
// We have to do the fetch, get others to wait for us.
fetching += shuffleId
}
}
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
val hostPort = Utils.localHostPort()
val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
val fetchedStatuses = deserializeStatuses(fetchedBytes)

logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()

if (null != fetchedStatuses) {
// Not registered with fetching; simply return the statuses.
// Synchronize on fetchedStatuses for a consistent API: a remote fetch should not be
// modified locally, but local statuses ARE modified, so lock uniformly rather than
// relying on implicit knowledge of whether the statuses are remote or local
// (since we do not track that explicitly in the map).
fetchedStatuses.synchronized {
return fetchedStatuses.map(status =>
(status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId))))
}
}


try {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
val hostPort = Utils.localHostPort()
val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]

fetchedStatuses = deserializeStatuses(fetchedBytes)

logInfo("Got the output locations")
val prevStatus = mapStatuses.put(shuffleId, fetchedStatuses)
// enable this assertion ?
// assert (null == prevStatus)
} finally {
// release lock in try/finally
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
if (null != fetchedStatuses) {
fetchedStatuses.synchronized {
return fetchedStatuses.map(s =>
(s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
}
}
else {
// What now? Throw an exception?
throw new FetchFailedException(shuffleId, reduceId, null)
}
return fetchedStatuses.map(s =>
(s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
} else {
return statuses.map(s =>
(s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
statuses.synchronized {
return statuses.map(s =>
(s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
}
}
}

@@ -192,7 +226,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
// mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
mapStatuses.clear()
generation = newGen
}
}
@@ -230,18 +265,23 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
objOut.writeObject(statuses)
// Since statuses can be modified in parallel, sync on it
statuses.synchronized {
objOut.writeObject(statuses)
}
objOut.close()
out.toByteArray
}

// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
objIn.readObject().asInstanceOf[Array[MapStatus]]
objIn.readObject().asInstanceOf[Array[MapStatus]].
// drop all nulls from the statuses; not sure why they are occurring, but they cause an NPE downstream in the slave if present
filter( _ != null )
}
}

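For reference, the registerShuffle change above replaces a racy get-then-put with ConcurrentHashMap.putIfAbsent, making registration atomic. A minimal sketch of the idiom with simplified value types (RegisterSketch and the String-array payload are illustrative):

import java.util.concurrent.ConcurrentHashMap

object RegisterSketch {
  private val statuses = new ConcurrentHashMap[Int, Array[String]]()

  def registerShuffle(shuffleId: Int, numMaps: Int): Unit = {
    // putIfAbsent returns null only for the thread that actually inserted the entry,
    // so a duplicate registration of the same shuffle ID is detected atomically.
    if (statuses.putIfAbsent(shuffleId, new Array[String](numMaps)) != null) {
      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
    }
  }

  def main(args: Array[String]): Unit = {
    registerShuffle(0, 4)
    // registerShuffle(0, 4)  // would throw IllegalArgumentException
    println(statuses.get(0).length)  // 4
  }
}
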
44 changes: 43 additions & 1 deletion core/src/main/scala/spark/Utils.scala
@@ -65,6 +65,44 @@ private object Utils extends Logging {
return buf
}


private val shutdownDeletePaths = new collection.mutable.HashSet[String]()

// Register the path to be deleted via shutdown hook
def registerShutdownDeleteDir(file: File) {
val absolutePath = file.getAbsolutePath()
shutdownDeletePaths.synchronized {
shutdownDeletePaths += absolutePath
}
}

// Is the path already registered to be deleted via a shutdown hook?
def hasShutdownDeleteDir(file: File): Boolean = {
val absolutePath = file.getAbsolutePath()
shutdownDeletePaths.synchronized {
shutdownDeletePaths.contains(absolutePath)
}
}

// Note: returns true if the file is a child of some registered path (but not equal to it), false otherwise.
// This ensures that two shutdown hooks do not try to delete each other's paths, which would result in an
// IOException and incomplete cleanup.
def hasRootAsShutdownDeleteDir(file: File): Boolean = {

val absolutePath = file.getAbsolutePath()

var shutdownDeletePathsStr: String = ""
val retval = shutdownDeletePaths.synchronized {
shutdownDeletePathsStr = shutdownDeletePaths.mkString("[ ", ", ", " ]")

shutdownDeletePaths.find(path => ! absolutePath.equals(path) && absolutePath.startsWith(path) ).isDefined
}

logInfo("file = " + file + ", present as root ? " + retval + ", shutdownDeletePaths = " + shutdownDeletePathsStr)

retval
}

/** Create a temporary directory inside the given parent directory */
def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = {
var attempts = 0
@@ -83,10 +121,14 @@
}
} catch { case e: IOException => ; }
}

registerShutdownDeleteDir(dir)

// Add a shutdown hook to delete the temp dir when the JVM exits
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
override def run() {
Utils.deleteRecursively(dir)
// Delete only if no registered parent path will already delete this directory via its own hook.
if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir)
}
})
return dir
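For reference, the shutdown-hook bookkeeping above prevents two hooks from deleting nested temp directories: a directory's own hook skips deletion when a registered parent path will remove it anyway. A self-contained sketch of that ancestor check (ShutdownDirsSketch and the /tmp paths are illustrative):

import java.io.File
import scala.collection.mutable

object ShutdownDirsSketch {
  private val registered = mutable.HashSet[String]()

  def register(dir: File): Unit = registered.synchronized {
    registered += dir.getAbsolutePath
  }

  // True if some other registered path is a prefix of this one, i.e. a parent
  // directory whose own shutdown hook will already delete this directory.
  def hasRegisteredAncestor(dir: File): Boolean = {
    val abs = dir.getAbsolutePath
    registered.synchronized {
      registered.exists(p => abs != p && abs.startsWith(p))
    }
  }

  def main(args: Array[String]): Unit = {
    register(new File("/tmp/spark-local"))
    println(hasRegisteredAncestor(new File("/tmp/spark-local/shuffle-0")))  // true
    println(hasRegisteredAncestor(new File("/tmp/spark-local")))            // false
  }
}
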
31 changes: 24 additions & 7 deletions core/src/main/scala/spark/deploy/yarn/Client.scala
@@ -110,7 +110,7 @@ class Client(conf: Configuration, args: ClientArguments) extends Logging {
logInfo("Max mem capabililty of resources in this cluster " + maxMem)

// If the cluster does not have enough memory resources, exit.
val requestedMem = args.amMemory + args.numWorkers * args.workerMemory
val requestedMem = (args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + args.numWorkers * args.workerMemory
if (requestedMem > maxMem) {
logError("Cluster cannot satisfy memory resource request of " + requestedMem)
System.exit(1)
@@ -132,8 +132,9 @@ class Client(conf: Configuration, args: ClientArguments) extends Logging {
// Add them as local resources to the AM
val fs = FileSystem.get(conf)
Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
.foreach { case(destName, localPath) =>
if (null != localPath) {
.foreach { case(destName, _localPath) =>
val localPath: String = if (null != _localPath) _localPath.trim() else ""
if (! localPath.isEmpty()) {
val src = new Path(localPath)
val pathSuffix = appName + "/" + appId.getId() + destName
val dst = new Path(fs.getHomeDirectory(), pathSuffix)
@@ -210,21 +211,37 @@ class Client(conf: Configuration, args: ClientArguments) extends Logging {
amContainer.setLocalResources(localResources)
amContainer.setEnvironment(env)

var amMemory = java.lang.Math.max(args.amMemory,
newApp.getMinimumResourceCapability().getMemory() - YarnAllocationHandler.MEMORY_OVERHEAD)

val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()

var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
(if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD

// Extra options for the JVM
var JAVA_OPTS = ""

// Add Xmx for am memory
JAVA_OPTS += "-Xmx" + amMemory + "m "

// Disabled by default for now, but the properties are kept here so that people can refer to them if required. Remove once the cpuset version is pushed out.
// The context: the default GC for server-class machines ends up using all cores for GC, so if there are multiple containers on the same
// node, Spark's GC affects the performance of all the other containers (which may themselves be other Spark containers).
// Instead of using this, rely on YARN cpusets to enforce that Spark behaves 'properly' in multi-tenant environments. It is not clear how the
// default Java GC behaves when limited to a subset of the cores on a node.
if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) {
// In our experiments, using the (default) throughput collector has severe performance ramifications on multi-tenant machines
JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
JAVA_OPTS += " -XX:+CMSIncrementalMode "
JAVA_OPTS += " -XX:+CMSIncrementalPacing "
JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
}
if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
}

// Command for the ApplicationMaster
val commands = List[String]("java " +
" -server " +
JAVA_OPTS +
" spark.deploy.yarn.ApplicationMaster" +
" --class " + args.userClass +
@@ -240,7 +257,7 @@ class Client(conf: Configuration, args: ClientArguments) extends Logging {

val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
// Memory for the ApplicationMaster
capability.setMemory(args.amMemory)
capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
amContainer.setResource(capability)

return amContainer
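For reference, the ApplicationMaster memory handling above rounds the requested memory up to a multiple of YARN's minimum allocation and subtracts a fixed overhead before setting -Xmx, while the container request adds the overhead back on top. A worked sketch of that arithmetic; the 384 MB MEMORY_OVERHEAD value and the numbers in main are assumptions for illustration, the real constant comes from YarnAllocationHandler:

object AmMemorySketch {
  val MEMORY_OVERHEAD = 384  // assumed value, for illustration only

  // Round the requested AM memory up to a multiple of YARN's minimum allocation,
  // then leave MEMORY_OVERHEAD headroom for off-heap usage when sizing -Xmx.
  def xmxMb(requestedAmMemory: Int, minResMemory: Int): Int = {
    val rounded = ((requestedAmMemory / minResMemory) * minResMemory) +
      (if (requestedAmMemory % minResMemory != 0) minResMemory else 0)
    rounded - MEMORY_OVERHEAD
  }

  def main(args: Array[String]): Unit = {
    // Requesting 1000 MB with a 512 MB minimum allocation rounds up to 1024 MB,
    // so the JVM gets -Xmx640m while the container itself asks for 1000 + 384 MB.
    println(xmxMb(1000, 512))  // 640
  }
}
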
